diff --git a/declearn/model/tensorflow/utils/_loss.py b/declearn/model/tensorflow/utils/_loss.py index d70e26a0fe489e7ef95b37b22f8cdcd0e6ffde30..5a3ffdaae06ee144f51777747f1d2bc797dd3f04 100644 --- a/declearn/model/tensorflow/utils/_loss.py +++ b/declearn/model/tensorflow/utils/_loss.py @@ -85,22 +85,9 @@ def build_keras_loss( # Case when 'loss' is already a Loss object. if isinstance(loss, tf.keras.losses.Loss): loss.reduction = reduction - # Case when 'loss' is a string. + # Case when 'loss' is a string: deserialize and/or wrap into a Loss object. elif isinstance(loss, str): - cls = tf.keras.losses.deserialize(loss) - # Case when the string was deserialized into a function. - if inspect.isfunction(cls): - # Try altering the string to gather its object counterpart. - loss = "".join(word.capitalize() for word in loss.split("_")) - try: - loss = tf.keras.losses.deserialize(loss) - loss.reduction = reduction - # If this failed, try wrapping the function using LossFunction. - except (AttributeError, ValueError): - loss = LossFunction(cls) - # Case when the string was deserialized into a class. - else: - loss = cls(reduction=reduction) + loss = get_keras_loss_from_string(name=loss, reduction=reduction) # Case when 'loss' is a function: wrap it up using LossFunction. elif inspect.isfunction(loss): loss = LossFunction(loss, reduction=reduction) @@ -111,3 +98,32 @@ def build_keras_loss( ) # Otherwise, properly configure the reduction scheme and return. return loss + + +def get_keras_loss_from_string( + name: str, + reduction: str, +) -> tf.keras.losses.Loss: + """Instantiate a keras Loss object from a registered string identifier. + + - If `name` matches a Loss registration name, return an instance. + - If it matches a loss function registration name, return either + an instance from its name-matching Loss subclass, or a custom + Loss subclass instance wrapping the function. + - If it does not match anything, raise a ValueError. + """ + loss = tf.keras.losses.deserialize(name) + if isinstance(loss, tf.keras.losses.Loss): + loss.reduction = reduction + return loss + if inspect.isfunction(loss): + try: + name = "".join(word.capitalize() for word in name.split("_")) + return get_keras_loss_from_string(name, reduction) + except ValueError: + return LossFunction( + loss, reduction=reduction, name=getattr(loss, "__name__", None) + ) + raise ValueError( + f"Name '{loss}' cannot be deserialized into a keras loss." + ) diff --git a/test/model/test_tflow.py b/test/model/test_tflow.py index 0387075f66dab9bad3635d033e76775c56666624..168d1c0b9b2e1f3dc3c9d9fdd241282848215108 100644 --- a/test/model/test_tflow.py +++ b/test/model/test_tflow.py @@ -32,6 +32,7 @@ except ModuleNotFoundError: pytest.skip("TensorFlow is unavailable", allow_module_level=True) from declearn.model.tensorflow import TensorflowModel, TensorflowVector +from declearn.model.tensorflow.utils import build_keras_loss from declearn.typing import Batch from declearn.utils import set_device_policy @@ -221,3 +222,66 @@ class TestTensorflowModel(ModelTestSuite): device = f"{test_case.device}:0" for var in tfmod.weights: assert var.device.endswith(device) + + +class TestBuildKerasLoss: + """Unit tests for `build_keras_loss` util function.""" + + def test_build_keras_loss_from_string_class_name(self) -> None: + """Test `build_keras_loss` with a valid class name string input.""" + loss = build_keras_loss( + "BinaryCrossentropy", tf.keras.losses.Reduction.SUM + ) + assert isinstance(loss, tf.keras.losses.BinaryCrossentropy) + assert loss.reduction == tf.keras.losses.Reduction.SUM + + def test_build_keras_loss_from_string_function_name(self) -> None: + """Test `build_keras_loss` with a valid function name string input.""" + loss = build_keras_loss( + "binary_crossentropy", tf.keras.losses.Reduction.SUM + ) + assert isinstance(loss, tf.keras.losses.BinaryCrossentropy) + assert loss.reduction == tf.keras.losses.Reduction.SUM + + def test_build_keras_loss_from_string_noclass_function_name(self) -> None: + """Test `build_keras_loss` with a valid function name string input.""" + loss = build_keras_loss("mse", tf.keras.losses.Reduction.SUM) + assert isinstance(loss, tf.keras.losses.Loss) + assert hasattr(loss, "loss_fn") + assert loss.loss_fn is tf.keras.losses.mse + assert loss.reduction == tf.keras.losses.Reduction.SUM + + def test_build_keras_loss_from_loss_instance(self) -> None: + """Test `build_keras_loss` with a valid keras Loss input.""" + # Set up a BinaryCrossentropy loss instance. + loss = tf.keras.losses.BinaryCrossentropy( + reduction=tf.keras.losses.Reduction.SUM + ) + assert loss.reduction == tf.keras.losses.Reduction.SUM + # Pass it through the util and verify that reduction changes. + loss = build_keras_loss(loss, tf.keras.losses.Reduction.NONE) + assert isinstance(loss, tf.keras.losses.BinaryCrossentropy) + assert loss.reduction == tf.keras.losses.Reduction.NONE + + def test_build_keras_loss_from_loss_function(self) -> None: + """Test `build_keras_loss` with a valid keras loss function input.""" + loss = build_keras_loss( + tf.keras.losses.binary_crossentropy, tf.keras.losses.Reduction.SUM + ) + assert isinstance(loss, tf.keras.losses.Loss) + assert hasattr(loss, "loss_fn") + assert loss.loss_fn is tf.keras.losses.binary_crossentropy + assert loss.reduction == tf.keras.losses.Reduction.SUM + + def test_build_keras_loss_from_custom_function(self) -> None: + """Test `build_keras_loss` with a valid custom loss function input.""" + + def loss_fn(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor: + """Custom loss function.""" + return tf.reduce_sum(tf.cast(tf.equal(y_true, y_pred), tf.float32)) + + loss = build_keras_loss(loss_fn, tf.keras.losses.Reduction.SUM) + assert isinstance(loss, tf.keras.losses.Loss) + assert hasattr(loss, "loss_fn") + assert loss.loss_fn is loss_fn + assert loss.reduction == tf.keras.losses.Reduction.SUM