diff --git a/declearn/model/tensorflow/utils/_loss.py b/declearn/model/tensorflow/utils/_loss.py
index d70e26a0fe489e7ef95b37b22f8cdcd0e6ffde30..5a3ffdaae06ee144f51777747f1d2bc797dd3f04 100644
--- a/declearn/model/tensorflow/utils/_loss.py
+++ b/declearn/model/tensorflow/utils/_loss.py
@@ -85,22 +85,9 @@ def build_keras_loss(
     # Case when 'loss' is already a Loss object.
     if isinstance(loss, tf.keras.losses.Loss):
         loss.reduction = reduction
-    # Case when 'loss' is a string.
+    # Case when 'loss' is a string: deserialize and/or wrap into a Loss object.
     elif isinstance(loss, str):
-        cls = tf.keras.losses.deserialize(loss)
-        # Case when the string was deserialized into a function.
-        if inspect.isfunction(cls):
-            # Try altering the string to gather its object counterpart.
-            loss = "".join(word.capitalize() for word in loss.split("_"))
-            try:
-                loss = tf.keras.losses.deserialize(loss)
-                loss.reduction = reduction
-            # If this failed, try wrapping the function using LossFunction.
-            except (AttributeError, ValueError):
-                loss = LossFunction(cls)
-        # Case when the string was deserialized into a class.
-        else:
-            loss = cls(reduction=reduction)
+        loss = get_keras_loss_from_string(name=loss, reduction=reduction)
     # Case when 'loss' is a function: wrap it up using LossFunction.
     elif inspect.isfunction(loss):
         loss = LossFunction(loss, reduction=reduction)
@@ -111,3 +98,32 @@ def build_keras_loss(
         )
     # Otherwise, properly configure the reduction scheme and return.
     return loss
+
+
+def get_keras_loss_from_string(
+    name: str,
+    reduction: str,
+) -> tf.keras.losses.Loss:
+    """Instantiate a keras Loss object from a registered string identifier.
+
+    - If `name` matches a Loss registration name, return an instance.
+    - If it matches a loss function registration name, return either
+      an instance from its name-matching Loss subclass, or a custom
+      Loss subclass instance wrapping the function.
+    - If it does not match anything, raise a ValueError.
+    """
+    loss = tf.keras.losses.deserialize(name)
+    if isinstance(loss, tf.keras.losses.Loss):
+        loss.reduction = reduction
+        return loss
+    if inspect.isfunction(loss):
+        try:
+            name = "".join(word.capitalize() for word in name.split("_"))
+            return get_keras_loss_from_string(name, reduction)
+        except ValueError:
+            return LossFunction(
+                loss, reduction=reduction, name=getattr(loss, "__name__", None)
+            )
+    raise ValueError(
+        f"Name '{loss}' cannot be deserialized into a keras loss."
+    )
diff --git a/test/model/test_tflow.py b/test/model/test_tflow.py
index 0387075f66dab9bad3635d033e76775c56666624..168d1c0b9b2e1f3dc3c9d9fdd241282848215108 100644
--- a/test/model/test_tflow.py
+++ b/test/model/test_tflow.py
@@ -32,6 +32,7 @@ except ModuleNotFoundError:
     pytest.skip("TensorFlow is unavailable", allow_module_level=True)
 
 from declearn.model.tensorflow import TensorflowModel, TensorflowVector
+from declearn.model.tensorflow.utils import build_keras_loss
 from declearn.typing import Batch
 from declearn.utils import set_device_policy
 
@@ -221,3 +222,66 @@ class TestTensorflowModel(ModelTestSuite):
         device = f"{test_case.device}:0"
         for var in tfmod.weights:
             assert var.device.endswith(device)
+
+
+class TestBuildKerasLoss:
+    """Unit tests for `build_keras_loss` util function."""
+
+    def test_build_keras_loss_from_string_class_name(self) -> None:
+        """Test `build_keras_loss` with a valid class name string input."""
+        loss = build_keras_loss(
+            "BinaryCrossentropy", tf.keras.losses.Reduction.SUM
+        )
+        assert isinstance(loss, tf.keras.losses.BinaryCrossentropy)
+        assert loss.reduction == tf.keras.losses.Reduction.SUM
+
+    def test_build_keras_loss_from_string_function_name(self) -> None:
+        """Test `build_keras_loss` with a valid function name string input."""
+        loss = build_keras_loss(
+            "binary_crossentropy", tf.keras.losses.Reduction.SUM
+        )
+        assert isinstance(loss, tf.keras.losses.BinaryCrossentropy)
+        assert loss.reduction == tf.keras.losses.Reduction.SUM
+
+    def test_build_keras_loss_from_string_noclass_function_name(self) -> None:
+        """Test `build_keras_loss` with a valid function name string input."""
+        loss = build_keras_loss("mse", tf.keras.losses.Reduction.SUM)
+        assert isinstance(loss, tf.keras.losses.Loss)
+        assert hasattr(loss, "loss_fn")
+        assert loss.loss_fn is tf.keras.losses.mse
+        assert loss.reduction == tf.keras.losses.Reduction.SUM
+
+    def test_build_keras_loss_from_loss_instance(self) -> None:
+        """Test `build_keras_loss` with a valid keras Loss input."""
+        # Set up a BinaryCrossentropy loss instance.
+        loss = tf.keras.losses.BinaryCrossentropy(
+            reduction=tf.keras.losses.Reduction.SUM
+        )
+        assert loss.reduction == tf.keras.losses.Reduction.SUM
+        # Pass it through the util and verify that reduction changes.
+        loss = build_keras_loss(loss, tf.keras.losses.Reduction.NONE)
+        assert isinstance(loss, tf.keras.losses.BinaryCrossentropy)
+        assert loss.reduction == tf.keras.losses.Reduction.NONE
+
+    def test_build_keras_loss_from_loss_function(self) -> None:
+        """Test `build_keras_loss` with a valid keras loss function input."""
+        loss = build_keras_loss(
+            tf.keras.losses.binary_crossentropy, tf.keras.losses.Reduction.SUM
+        )
+        assert isinstance(loss, tf.keras.losses.Loss)
+        assert hasattr(loss, "loss_fn")
+        assert loss.loss_fn is tf.keras.losses.binary_crossentropy
+        assert loss.reduction == tf.keras.losses.Reduction.SUM
+
+    def test_build_keras_loss_from_custom_function(self) -> None:
+        """Test `build_keras_loss` with a valid custom loss function input."""
+
+        def loss_fn(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
+            """Custom loss function."""
+            return tf.reduce_sum(tf.cast(tf.equal(y_true, y_pred), tf.float32))
+
+        loss = build_keras_loss(loss_fn, tf.keras.losses.Reduction.SUM)
+        assert isinstance(loss, tf.keras.losses.Loss)
+        assert hasattr(loss, "loss_fn")
+        assert loss.loss_fn is loss_fn
+        assert loss.reduction == tf.keras.losses.Reduction.SUM