From aff0ff3185113a97b6603092b79ba9d30659541c Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Fri, 15 Mar 2024 11:29:16 +0100
Subject: [PATCH] Add support for TensorFlow 2.16 / Keras 3.

- Apply some minor backward changes (and, when required, branches) to
  support breaking changes in Keras 3 (the default for TensorFlow 2.16
  and onwards), while preserving support for older versions.
- On the side, split a test about 'Model' serialization in two tests,
  one for config serializability, the other for instantiation from a
  non-serialized config dict.
- The current code has been tested (by running all unit tests) with
  both TensorFlow 2.11 and 2.16, in a Python 3.10 environment.
---
 declearn/model/tensorflow/_model.py      | 24 ++++++++++-----
 declearn/model/tensorflow/_optim.py      | 25 +++++++++++-----
 declearn/model/tensorflow/utils/_loss.py |  2 +-
 docs/release-notes/v2.4.0.md             | 11 ++++---
 pyproject.toml                           |  4 +--
 test/model/model_testing.py              | 20 +++++++++----
 test/model/test_haiku_model.py           | 11 +++++--
 test/model/test_sksgd_model.py           |  9 ------
 test/model/test_tflow_model.py           | 30 ++++++++++++++-----
 test/model/test_torch_model.py           | 38 +++++++++++++++++-------
 test/optimizer/test_tflow_optim.py       |  8 ++++-
 11 files changed, 125 insertions(+), 57 deletions(-)

diff --git a/declearn/model/tensorflow/_model.py b/declearn/model/tensorflow/_model.py
index 6d9dfe3c..212b3189 100644
--- a/declearn/model/tensorflow/_model.py
+++ b/declearn/model/tensorflow/_model.py
@@ -183,10 +183,20 @@ class TensorflowModel(Model):
         self,
         trainable: bool = False,
     ) -> TensorflowVector:
+        variables = self._get_weight_variables(trainable)
+        return TensorflowVector({var.name: var.value() for var in variables})
+
+    def _get_weight_variables(
+        self,
+        trainable: bool,
+    ) -> Iterable[tf.Variable]:
+        """Access TensorFlow Variables wrapping model weight tensors."""
         variables = (
             self._model.trainable_weights if trainable else self._model.weights
         )
-        return TensorflowVector({var.name: var.value() for var in variables})
+        if hasattr(tf_keras, "version") and tf_keras.version().startswith("3"):
+            variables = (var.value for var in variables)
+        return variables
 
     def set_weights(
         self,
@@ -198,7 +208,9 @@ class TensorflowModel(Model):
                 "TensorflowModel requires TensorflowVector weights."
             )
         self._verify_weights_compatibility(weights, trainable=trainable)
-        variables = {var.name: var for var in self._model.weights}
+        variables = {
+            var.name: var for var in self._get_weight_variables(trainable)
+        }
         with tf.device(self._device):
             for name, value in weights.coefs.items():
                 variables[name].assign(value, read_value=False)
@@ -225,9 +237,7 @@ class TensorflowModel(Model):
             In case some expected keys are missing, or additional keys
             are present. Be verbose about the identified mismatch(es).
         """
-        variables = (
-            self._model.trainable_weights if trainable else self._model.weights
-        )
+        variables = self._get_weight_variables(trainable)
         raise_on_stringsets_mismatch(
             received=set(vector.coefs),
             expected={var.name for var in variables},
@@ -247,7 +257,7 @@ class TensorflowModel(Model):
                 norm = tf.constant(max_norm)
                 grads, loss = self._compute_clipped_gradients(*data, norm)
         self._loss_history.append(float(loss.numpy()))
-        grads_and_vars = zip(grads, self._model.trainable_weights)
+        grads_and_vars = zip(grads, self._get_weight_variables(trainable=True))
         return TensorflowVector(
             {var.name: grad for grad, var in grads_and_vars}
         )
@@ -331,7 +341,7 @@ class TensorflowModel(Model):
     ) -> None:
         self._verify_weights_compatibility(updates, trainable=True)
         with tf.device(self._device):
-            for var in self._model.trainable_weights:
+            for var in self._get_weight_variables(trainable=True):
                 updt = updates.coefs[var.name]
                 if isinstance(updt, tf.IndexedSlices):
                     var.scatter_add(updt)
diff --git a/declearn/model/tensorflow/_optim.py b/declearn/model/tensorflow/_optim.py
index 9513f75f..6538b687 100644
--- a/declearn/model/tensorflow/_optim.py
+++ b/declearn/model/tensorflow/_optim.py
@@ -17,7 +17,7 @@
 
 """Hacky OptiModule subclass enabling the use of a keras Optimizer."""
 
-from typing import Any, Dict, Union
+from typing import Any, Dict, List, Union
 
 # fmt: off
 # pylint: disable=import-error,no-name-in-module
@@ -189,7 +189,7 @@ class TensorflowOptiModule(OptiModule):
                 key: tf.Variable(tf.zeros_like(grad), name=key)
                 for key, grad in gradients.coefs.items()
             }
-            self.optim.build(self._vars.values())
+            self.optim.build(list(self._vars.values()))
 
     def reset(self) -> None:
         """Reset this module to its uninitialized state.
@@ -206,7 +206,7 @@ class TensorflowOptiModule(OptiModule):
         policy = get_device_policy()
         self._device = select_device(gpu=policy.gpu, idx=policy.idx)
         with tf.device(self._device):
-            self._vars = {}
+            self._vars.clear()
             self.optim = self.optim.from_config(self.optim.get_config())
 
     def get_config(
@@ -222,11 +222,20 @@ class TensorflowOptiModule(OptiModule):
             key: (val.shape.as_list(), val.dtype.name)
             for key, val in self._vars.items()
         }
+        variables = self._get_optimizer_variables()
         state = TensorflowVector(
-            {var.name: var.value() for var in self.optim.variables()}
+            {str(i): v.value() for i, v in enumerate(variables)}
         )
         return {"specs": specs, "state": state}
 
+    def _get_optimizer_variables(
+        self,
+    ) -> List[tf.Variable]:
+        """Access wrapped optimizer's variables as 'tf.Variable' instances."""
+        if hasattr(tf_keras, "version") and tf_keras.version().startswith("3"):
+            return [var.value for var in self.optim.variables]
+        return self.optim.variables()
+
     def set_state(
         self,
         state: Dict[str, Any],
@@ -244,9 +253,9 @@ class TensorflowOptiModule(OptiModule):
                 key: tf.Variable(tf.zeros(shape, dtype), name=key)
                 for key, (shape, dtype) in state["specs"].items()
             }
-            self.optim.build(self._vars.values())
+            self.optim.build(list(self._vars.values()))
         # Restore optimizer variables' values from the input state dict.
-        opt_vars = {var.name: var for var in self.optim.variables()}
+        opt_vars = self._get_optimizer_variables()
         with tf.device(self._device):
-            for key, val in state["state"].coefs.items():
-                opt_vars[key].assign(val, read_value=False)
+            for var, val in zip(opt_vars, state["state"].coefs.values()):
+                var.assign(val, read_value=False)
diff --git a/declearn/model/tensorflow/utils/_loss.py b/declearn/model/tensorflow/utils/_loss.py
index 31069ee2..da49339d 100644
--- a/declearn/model/tensorflow/utils/_loss.py
+++ b/declearn/model/tensorflow/utils/_loss.py
@@ -48,7 +48,7 @@ class LossFunction(tf_keras.losses.Loss):
         reduction: str = tf_keras.losses.Reduction.NONE,
         name: Optional[str] = None,
     ) -> None:
-        super().__init__(reduction, name)
+        super().__init__(reduction=reduction, name=name)
         self.loss_fn = tf_keras.losses.deserialize(loss_fn)
 
     def call(
diff --git a/docs/release-notes/v2.4.0.md b/docs/release-notes/v2.4.0.md
index e2af281a..cfaa2de8 100644
--- a/docs/release-notes/v2.4.0.md
+++ b/docs/release-notes/v2.4.0.md
@@ -274,10 +274,13 @@ Older TensorFlow versions (v2.5 to 2.10 included) were improperly marked as
 supported in spite of `TensorflowOptiModule` requiring at least version 2.11
 to work (due to changes of the Keras Optimizer API). This has been corrected.
 
-Upcoming TensorFlow versions (v2.16 and onwards) introduce backward-breaking
-changes, which are probably due to the backend swap from Keras 2 to Keras 3.
-To keep safe, these versions are currently marked as unsupported, awaiting
-further investigation once version 2.16 is finalized.
+The latest TensorFlow version (v2.16) introduces backward-breaking changes, due to the backend swap from Keras 2 to Keras 3. Our backend code was updated to
+both add support for this newer Keras backend, and preserve existing support.
+
+Note that at the moment, the CI does not support TensorFlow above 2.13, due to
+newer versions not being compatible with Python 3.8. As such, our code will be
+tested to remain backward-compatible. Forward compatibility has been (and will
+keep being) tested locally with a newer Python version.
 
 ### Deprecate `declearn.dataset.load_from_json`
 
diff --git a/pyproject.toml b/pyproject.toml
index 2223f3d6..e9224145 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -53,7 +53,7 @@ all = [  # all non-tests extra dependencies
     "jax[cpu] ~= 0.4.1",
     "opacus ~= 1.4",
     "protobuf >= 3.19",
-    "tensorflow ~= 2.11, < 2.16",
+    "tensorflow ~= 2.11",
     "torch >= 1.13, < 3.0",
     "websockets >= 10.1, < 13.0",
 ]
@@ -71,7 +71,7 @@ haiku = [
     "jax[cpu] ~= 0.4.1",  # NOTE: GPU support must be manually installed
 ]
 tensorflow = [
-    "tensorflow ~= 2.11, < 2.16",  # starting with 2.16 TF upgrades to Keras 3
+    "tensorflow ~= 2.11",
 ]
 torch = [  # generic requirements for Torch
     "torch >= 1.13, < 3.0",
diff --git a/test/model/model_testing.py b/test/model/model_testing.py
index 4806e0df..76cf4583 100644
--- a/test/model/model_testing.py
+++ b/test/model/model_testing.py
@@ -17,13 +17,14 @@
 
 """Shared testing code for TensorFlow and Torch models' unit tests."""
 
+import copy
 import json
 from typing import Any, Generic, List, Protocol, Tuple, Type, TypeVar, Union
 
 import numpy as np
 
 from declearn.model.api import Model, Vector
-from declearn.test_utils import to_numpy
+from declearn.test_utils import assert_json_serializable_dict, to_numpy
 from declearn.typing import Batch
 from declearn.utils import json_pack, json_unpack
 
@@ -59,14 +60,23 @@ class ModelTestCase(Protocol, Generic[VectorT]):
 class ModelTestSuite:
     """Unit tests for a declearn Model."""
 
-    def test_serialization(
+    def test_get_config(
         self,
         test_case: ModelTestCase,
     ) -> None:
-        """Check that the model can be JSON-(de)serialized properly."""
+        """Check that the model's config is JSON-serializable."""
         model = test_case.model
-        config = json.dumps(model.get_config())
-        other = model.from_config(json.loads(config))
+        config = model.get_config()
+        assert_json_serializable_dict(config)
+
+    def test_from_config(
+        self,
+        test_case: ModelTestCase,
+    ) -> None:
+        """Check that the model can be instantiated from its config."""
+        model = test_case.model
+        config = model.get_config()
+        other = model.from_config(copy.deepcopy(config))
         assert model.get_config() == other.get_config()
         assert model.device_policy == other.device_policy
 
diff --git a/test/model/test_haiku_model.py b/test/model/test_haiku_model.py
index 7bb9ca7b..7e31f048 100644
--- a/test/model/test_haiku_model.py
+++ b/test/model/test_haiku_model.py
@@ -234,11 +234,18 @@ class TestHaikuModel(ModelTestSuite):
     """Unit tests for declearn.model.tensorflow.TensorflowModel."""
 
     @pytest.mark.filterwarnings("ignore: Our custom Haiku serialization")
-    def test_serialization(
+    def test_get_config(
         self,
         test_case: ModelTestCase,
     ) -> None:
-        super().test_serialization(test_case)
+        super().test_get_config(test_case)
+
+    @pytest.mark.filterwarnings("ignore: Our custom Haiku serialization")
+    def test_from_config(
+        self,
+        test_case: ModelTestCase,
+    ) -> None:
+        super().test_from_config(test_case)
 
     @pytest.mark.parametrize(
         "criterion_type", ["names", "pytree", "predicate"]
diff --git a/test/model/test_sksgd_model.py b/test/model/test_sksgd_model.py
index 35f1b238..f88de8a7 100644
--- a/test/model/test_sksgd_model.py
+++ b/test/model/test_sksgd_model.py
@@ -120,15 +120,6 @@ def fixture_test_case(
 class TestSklearnSGDModel(ModelTestSuite):
     """Unit tests for declearn.model.sklearn.SklearnSGDModel."""
 
-    def test_serialization(  # type: ignore  # Liskov does not matter here
-        self,
-        test_case: SklearnSGDTestCase,
-    ) -> None:
-        # Avoid re-running tests that are unaltered by data parameters.
-        if test_case.s_weights or test_case.as_sparse:
-            return None
-        return super().test_serialization(test_case)
-
     def test_initialization(
         self,
         test_case: SklearnSGDTestCase,
diff --git a/test/model/test_tflow_model.py b/test/model/test_tflow_model.py
index 9b1feb43..65805b5f 100644
--- a/test/model/test_tflow_model.py
+++ b/test/model/test_tflow_model.py
@@ -30,7 +30,6 @@ try:
 except ModuleNotFoundError:
     pytest.skip("TensorFlow is unavailable", allow_module_level=True)
 else:
-    # pylint: disable=import-error,no-name-in-module
     import tensorflow.keras as tf_keras  # type: ignore
 
 from declearn.model.tensorflow import TensorflowModel, TensorflowVector
@@ -176,14 +175,23 @@ class TestTensorflowModel(ModelTestSuite):
         test_case: ModelTestCase,
     ) -> None:
         """Check that `get_weights` behaves properly with frozen weights."""
+        # Set up a model with a frozen layer.
         model = test_case.model
         tfmod = model.get_wrapped_model()
         tfmod.layers[0].trainable = False  # freeze the first layer's weights
-        w_all = model.get_weights()
-        w_trn = model.get_weights(trainable=True)
-        assert set(w_trn.coefs).issubset(w_all.coefs)  # check on keys
-        assert w_trn.coefs.keys() == {v.name for v in tfmod.trainable_weights}
-        assert w_all.coefs.keys() == {v.name for v in tfmod.weights}
+        # Access names of the model's variables (via TensorFlow/Keras API).
+        if hasattr(tf_keras, "version") and tf_keras.version().startswith("3"):
+            names_all_wghts = {v.value.name for v in tfmod.weights}
+            names_trainable = {v.value.name for v in tfmod.trainable_weights}
+        else:
+            names_all_wghts = {v.name for v in tfmod.weights}
+            names_trainable = {v.name for v in tfmod.trainable_weights}
+        # Verify that DecLearn-accessed weights' names match.
+        weights_all = model.get_weights()
+        weights_trn = model.get_weights(trainable=True)
+        assert set(weights_trn.coefs).issubset(weights_all.coefs)
+        assert weights_all.coefs.keys() == names_all_wghts
+        assert weights_trn.coefs.keys() == names_trainable
 
     def test_set_frozen_weights(
         self,
@@ -193,7 +201,7 @@ class TestTensorflowModel(ModelTestSuite):
         # Setup a model with some frozen weights, and gather trainable ones.
         model = test_case.model
         tfmod = model.get_wrapped_model()
-        tfmod.layers[0].trainable = False  # freeze the first layer's weights
+        tfmod.layers[-1].trainable = False  # freeze the last layer's weights
         w_trn = model.get_weights(trainable=True)
         # Test that `set_weights` works if and only if properly parametrized.
         with pytest.raises(KeyError):
@@ -213,7 +221,11 @@ class TestTensorflowModel(ModelTestSuite):
         assert policy.idx == 0
         tfmod = model.get_wrapped_model()
         device = f"{test_case.device}:0"
-        for var in tfmod.weights:
+        if hasattr(tf_keras, "version") and tf_keras.version().startswith("3"):
+            variables = [var.value for var in tfmod.weights]
+        else:
+            variables = tfmod.weights
+        for var in variables:
             assert var.device.endswith(device)
 
 
@@ -238,6 +250,8 @@ class TestBuildKerasLoss:
 
     def test_build_keras_loss_from_string_noclass_function_name(self) -> None:
         """Test `build_keras_loss` with a valid function name string input."""
+        if hasattr(tf_keras, "version") and tf_keras.version().startswith("3"):
+            pytest.skip("Skipping test that no longer works with Keras 3.")
         loss = build_keras_loss("mse", tf_keras.losses.Reduction.SUM)
         assert isinstance(loss, tf_keras.losses.Loss)
         assert hasattr(loss, "loss_fn")
diff --git a/test/model/test_torch_model.py b/test/model/test_torch_model.py
index 74419484..725e81c8 100644
--- a/test/model/test_torch_model.py
+++ b/test/model/test_torch_model.py
@@ -17,7 +17,7 @@
 
 """Unit tests for TorchModel."""
 
-import json
+import copy
 import os
 import typing
 from typing import List, Literal, Tuple
@@ -219,7 +219,7 @@ class TestTorchModel(ModelTestSuite):
     """Unit tests for declearn.model.torch.TorchModel."""
 
     @pytest.mark.filterwarnings("ignore: PyTorch JSON serialization")
-    def test_serialization(
+    def test_get_config(
         self,
         test_case: ModelTestCase,
     ) -> None:
@@ -228,26 +228,44 @@ class TestTorchModel(ModelTestSuite):
             #       due to the (de)serialization of a custom nn.Module
             #       the expected model behaviour is, however, correct
             try:
-                self._test_serialization(test_case)
+                super().test_get_config(test_case)
             except AssertionError:
                 pytest.skip(
                     "skipping failed test due to custom nn.Module pickling"
                 )
-        self._test_serialization(test_case)
+        super().test_get_config(test_case)
 
-    def _test_serialization(
+    @pytest.mark.filterwarnings("ignore: PyTorch JSON serialization")
+    def test_from_config(
+        self,
+        test_case: ModelTestCase,
+    ) -> None:
+        if getattr(test_case, "kind", "") == "RNN":
+            # NOTE: this test fails on python 3.8 but succeeds in 3.10
+            #       due to the (de)serialization of a custom nn.Module
+            #       the expected model behaviour is, however, correct
+            try:
+                self._test_from_config(test_case)
+            except AssertionError:
+                pytest.skip(
+                    "skipping failed test due to custom nn.Module pickling"
+                )
+        self._test_from_config(test_case)
+
+    def _test_from_config(
         self,
         test_case: ModelTestCase,
     ) -> None:
-        """Check that the model can be JSON-(de)serialized properly.
+        """Check that the model can be instantiated from its config.
 
-        This method replaces the parent `test_serialization` one.
+        This method replaces the parent `test_from_config` one.
         """
         # Same setup as in parent test: a model and a config-based other.
         model = test_case.model
-        config = json.dumps(model.get_config())
-        other = model.from_config(json.loads(config))
-        # Verify that both models have the same device policy.
+        config = model.get_config()
+        other = model.from_config(copy.deepcopy(config))
+        # Verify that both models have the same config and device policy.
+        assert other.get_config() == config
         assert model.device_policy == other.device_policy
         # Verify that both models have a similar structure of modules.
         mod_a = list(model.get_wrapped_model().modules())
diff --git a/test/optimizer/test_tflow_optim.py b/test/optimizer/test_tflow_optim.py
index 81d3346a..7658fed0 100644
--- a/test/optimizer/test_tflow_optim.py
+++ b/test/optimizer/test_tflow_optim.py
@@ -31,6 +31,8 @@ try:
         import tensorflow as tf  # type: ignore
 except ModuleNotFoundError:
     pytest.skip("TensorFlow is unavailable", allow_module_level=True)
+else:
+    import tensorflow.keras as tf_keras  # type: ignore
 # pylint: enable=duplicate-code
 
 from declearn.model.tensorflow import TensorflowOptiModule, TensorflowVector
@@ -201,7 +203,11 @@ class TestTensorflowOptiModule(OptiModuleTestSuite):
             grads = GradientsTestCase("tensorflow").mock_gradient
         updts = module.run(grads)
         # Assert that the outputs and internal states are properly placed.
+        if hasattr(tf_keras, "version") and tf_keras.version().startswith("3"):
+            optimizer_variables = [var.value for var in module.optim.variables]
+        else:
+            optimizer_variables = module.optim.variables()
         assert all(device in t.device for t in updts.coefs.values())
-        assert all(device in t.device for t in module.optim.variables())
+        assert all(device in t.device for t in optimizer_variables)
         # Reset device policy to run other tests on CPU as expected.
         set_device_policy(gpu=False)
-- 
GitLab