From aff0ff3185113a97b6603092b79ba9d30659541c Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Fri, 15 Mar 2024 11:29:16 +0100 Subject: [PATCH] Add support for TensorFlow 2.16 / Keras 3. - Apply some minor backward changes (and, when required, branches) to support breaking changes in Keras 3 (the default for TensorFlow 2.16 and onwards), while preserving support for older versions. - On the side, split a test about 'Model' serialization in two tests, one for config serializability, the other for instantiation from a non-serialized config dict. - The current code has been tested (by running all unit tests) with both TensorFlow 2.11 and 2.16, in a Python 3.10 environment. --- declearn/model/tensorflow/_model.py | 24 ++++++++++----- declearn/model/tensorflow/_optim.py | 25 +++++++++++----- declearn/model/tensorflow/utils/_loss.py | 2 +- docs/release-notes/v2.4.0.md | 11 ++++--- pyproject.toml | 4 +-- test/model/model_testing.py | 20 +++++++++---- test/model/test_haiku_model.py | 11 +++++-- test/model/test_sksgd_model.py | 9 ------ test/model/test_tflow_model.py | 30 ++++++++++++++----- test/model/test_torch_model.py | 38 +++++++++++++++++------- test/optimizer/test_tflow_optim.py | 8 ++++- 11 files changed, 125 insertions(+), 57 deletions(-) diff --git a/declearn/model/tensorflow/_model.py b/declearn/model/tensorflow/_model.py index 6d9dfe3c..212b3189 100644 --- a/declearn/model/tensorflow/_model.py +++ b/declearn/model/tensorflow/_model.py @@ -183,10 +183,20 @@ class TensorflowModel(Model): self, trainable: bool = False, ) -> TensorflowVector: + variables = self._get_weight_variables(trainable) + return TensorflowVector({var.name: var.value() for var in variables}) + + def _get_weight_variables( + self, + trainable: bool, + ) -> Iterable[tf.Variable]: + """Access TensorFlow Variables wrapping model weight tensors.""" variables = ( self._model.trainable_weights if trainable else self._model.weights ) - return TensorflowVector({var.name: var.value() for var in variables}) + if hasattr(tf_keras, "version") and tf_keras.version().startswith("3"): + variables = (var.value for var in variables) + return variables def set_weights( self, @@ -198,7 +208,9 @@ class TensorflowModel(Model): "TensorflowModel requires TensorflowVector weights." ) self._verify_weights_compatibility(weights, trainable=trainable) - variables = {var.name: var for var in self._model.weights} + variables = { + var.name: var for var in self._get_weight_variables(trainable) + } with tf.device(self._device): for name, value in weights.coefs.items(): variables[name].assign(value, read_value=False) @@ -225,9 +237,7 @@ class TensorflowModel(Model): In case some expected keys are missing, or additional keys are present. Be verbose about the identified mismatch(es). """ - variables = ( - self._model.trainable_weights if trainable else self._model.weights - ) + variables = self._get_weight_variables(trainable) raise_on_stringsets_mismatch( received=set(vector.coefs), expected={var.name for var in variables}, @@ -247,7 +257,7 @@ class TensorflowModel(Model): norm = tf.constant(max_norm) grads, loss = self._compute_clipped_gradients(*data, norm) self._loss_history.append(float(loss.numpy())) - grads_and_vars = zip(grads, self._model.trainable_weights) + grads_and_vars = zip(grads, self._get_weight_variables(trainable=True)) return TensorflowVector( {var.name: grad for grad, var in grads_and_vars} ) @@ -331,7 +341,7 @@ class TensorflowModel(Model): ) -> None: self._verify_weights_compatibility(updates, trainable=True) with tf.device(self._device): - for var in self._model.trainable_weights: + for var in self._get_weight_variables(trainable=True): updt = updates.coefs[var.name] if isinstance(updt, tf.IndexedSlices): var.scatter_add(updt) diff --git a/declearn/model/tensorflow/_optim.py b/declearn/model/tensorflow/_optim.py index 9513f75f..6538b687 100644 --- a/declearn/model/tensorflow/_optim.py +++ b/declearn/model/tensorflow/_optim.py @@ -17,7 +17,7 @@ """Hacky OptiModule subclass enabling the use of a keras Optimizer.""" -from typing import Any, Dict, Union +from typing import Any, Dict, List, Union # fmt: off # pylint: disable=import-error,no-name-in-module @@ -189,7 +189,7 @@ class TensorflowOptiModule(OptiModule): key: tf.Variable(tf.zeros_like(grad), name=key) for key, grad in gradients.coefs.items() } - self.optim.build(self._vars.values()) + self.optim.build(list(self._vars.values())) def reset(self) -> None: """Reset this module to its uninitialized state. @@ -206,7 +206,7 @@ class TensorflowOptiModule(OptiModule): policy = get_device_policy() self._device = select_device(gpu=policy.gpu, idx=policy.idx) with tf.device(self._device): - self._vars = {} + self._vars.clear() self.optim = self.optim.from_config(self.optim.get_config()) def get_config( @@ -222,11 +222,20 @@ class TensorflowOptiModule(OptiModule): key: (val.shape.as_list(), val.dtype.name) for key, val in self._vars.items() } + variables = self._get_optimizer_variables() state = TensorflowVector( - {var.name: var.value() for var in self.optim.variables()} + {str(i): v.value() for i, v in enumerate(variables)} ) return {"specs": specs, "state": state} + def _get_optimizer_variables( + self, + ) -> List[tf.Variable]: + """Access wrapped optimizer's variables as 'tf.Variable' instances.""" + if hasattr(tf_keras, "version") and tf_keras.version().startswith("3"): + return [var.value for var in self.optim.variables] + return self.optim.variables() + def set_state( self, state: Dict[str, Any], @@ -244,9 +253,9 @@ class TensorflowOptiModule(OptiModule): key: tf.Variable(tf.zeros(shape, dtype), name=key) for key, (shape, dtype) in state["specs"].items() } - self.optim.build(self._vars.values()) + self.optim.build(list(self._vars.values())) # Restore optimizer variables' values from the input state dict. - opt_vars = {var.name: var for var in self.optim.variables()} + opt_vars = self._get_optimizer_variables() with tf.device(self._device): - for key, val in state["state"].coefs.items(): - opt_vars[key].assign(val, read_value=False) + for var, val in zip(opt_vars, state["state"].coefs.values()): + var.assign(val, read_value=False) diff --git a/declearn/model/tensorflow/utils/_loss.py b/declearn/model/tensorflow/utils/_loss.py index 31069ee2..da49339d 100644 --- a/declearn/model/tensorflow/utils/_loss.py +++ b/declearn/model/tensorflow/utils/_loss.py @@ -48,7 +48,7 @@ class LossFunction(tf_keras.losses.Loss): reduction: str = tf_keras.losses.Reduction.NONE, name: Optional[str] = None, ) -> None: - super().__init__(reduction, name) + super().__init__(reduction=reduction, name=name) self.loss_fn = tf_keras.losses.deserialize(loss_fn) def call( diff --git a/docs/release-notes/v2.4.0.md b/docs/release-notes/v2.4.0.md index e2af281a..cfaa2de8 100644 --- a/docs/release-notes/v2.4.0.md +++ b/docs/release-notes/v2.4.0.md @@ -274,10 +274,13 @@ Older TensorFlow versions (v2.5 to 2.10 included) were improperly marked as supported in spite of `TensorflowOptiModule` requiring at least version 2.11 to work (due to changes of the Keras Optimizer API). This has been corrected. -Upcoming TensorFlow versions (v2.16 and onwards) introduce backward-breaking -changes, which are probably due to the backend swap from Keras 2 to Keras 3. -To keep safe, these versions are currently marked as unsupported, awaiting -further investigation once version 2.16 is finalized. +The latest TensorFlow version (v2.16) introduces backward-breaking changes, due to the backend swap from Keras 2 to Keras 3. Our backend code was updated to +both add support for this newer Keras backend, and preserve existing support. + +Note that at the moment, the CI does not support TensorFlow above 2.13, due to +newer versions not being compatible with Python 3.8. As such, our code will be +tested to remain backward-compatible. Forward compatibility has been (and will +keep being) tested locally with a newer Python version. ### Deprecate `declearn.dataset.load_from_json` diff --git a/pyproject.toml b/pyproject.toml index 2223f3d6..e9224145 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,7 @@ all = [ # all non-tests extra dependencies "jax[cpu] ~= 0.4.1", "opacus ~= 1.4", "protobuf >= 3.19", - "tensorflow ~= 2.11, < 2.16", + "tensorflow ~= 2.11", "torch >= 1.13, < 3.0", "websockets >= 10.1, < 13.0", ] @@ -71,7 +71,7 @@ haiku = [ "jax[cpu] ~= 0.4.1", # NOTE: GPU support must be manually installed ] tensorflow = [ - "tensorflow ~= 2.11, < 2.16", # starting with 2.16 TF upgrades to Keras 3 + "tensorflow ~= 2.11", ] torch = [ # generic requirements for Torch "torch >= 1.13, < 3.0", diff --git a/test/model/model_testing.py b/test/model/model_testing.py index 4806e0df..76cf4583 100644 --- a/test/model/model_testing.py +++ b/test/model/model_testing.py @@ -17,13 +17,14 @@ """Shared testing code for TensorFlow and Torch models' unit tests.""" +import copy import json from typing import Any, Generic, List, Protocol, Tuple, Type, TypeVar, Union import numpy as np from declearn.model.api import Model, Vector -from declearn.test_utils import to_numpy +from declearn.test_utils import assert_json_serializable_dict, to_numpy from declearn.typing import Batch from declearn.utils import json_pack, json_unpack @@ -59,14 +60,23 @@ class ModelTestCase(Protocol, Generic[VectorT]): class ModelTestSuite: """Unit tests for a declearn Model.""" - def test_serialization( + def test_get_config( self, test_case: ModelTestCase, ) -> None: - """Check that the model can be JSON-(de)serialized properly.""" + """Check that the model's config is JSON-serializable.""" model = test_case.model - config = json.dumps(model.get_config()) - other = model.from_config(json.loads(config)) + config = model.get_config() + assert_json_serializable_dict(config) + + def test_from_config( + self, + test_case: ModelTestCase, + ) -> None: + """Check that the model can be instantiated from its config.""" + model = test_case.model + config = model.get_config() + other = model.from_config(copy.deepcopy(config)) assert model.get_config() == other.get_config() assert model.device_policy == other.device_policy diff --git a/test/model/test_haiku_model.py b/test/model/test_haiku_model.py index 7bb9ca7b..7e31f048 100644 --- a/test/model/test_haiku_model.py +++ b/test/model/test_haiku_model.py @@ -234,11 +234,18 @@ class TestHaikuModel(ModelTestSuite): """Unit tests for declearn.model.tensorflow.TensorflowModel.""" @pytest.mark.filterwarnings("ignore: Our custom Haiku serialization") - def test_serialization( + def test_get_config( self, test_case: ModelTestCase, ) -> None: - super().test_serialization(test_case) + super().test_get_config(test_case) + + @pytest.mark.filterwarnings("ignore: Our custom Haiku serialization") + def test_from_config( + self, + test_case: ModelTestCase, + ) -> None: + super().test_from_config(test_case) @pytest.mark.parametrize( "criterion_type", ["names", "pytree", "predicate"] diff --git a/test/model/test_sksgd_model.py b/test/model/test_sksgd_model.py index 35f1b238..f88de8a7 100644 --- a/test/model/test_sksgd_model.py +++ b/test/model/test_sksgd_model.py @@ -120,15 +120,6 @@ def fixture_test_case( class TestSklearnSGDModel(ModelTestSuite): """Unit tests for declearn.model.sklearn.SklearnSGDModel.""" - def test_serialization( # type: ignore # Liskov does not matter here - self, - test_case: SklearnSGDTestCase, - ) -> None: - # Avoid re-running tests that are unaltered by data parameters. - if test_case.s_weights or test_case.as_sparse: - return None - return super().test_serialization(test_case) - def test_initialization( self, test_case: SklearnSGDTestCase, diff --git a/test/model/test_tflow_model.py b/test/model/test_tflow_model.py index 9b1feb43..65805b5f 100644 --- a/test/model/test_tflow_model.py +++ b/test/model/test_tflow_model.py @@ -30,7 +30,6 @@ try: except ModuleNotFoundError: pytest.skip("TensorFlow is unavailable", allow_module_level=True) else: - # pylint: disable=import-error,no-name-in-module import tensorflow.keras as tf_keras # type: ignore from declearn.model.tensorflow import TensorflowModel, TensorflowVector @@ -176,14 +175,23 @@ class TestTensorflowModel(ModelTestSuite): test_case: ModelTestCase, ) -> None: """Check that `get_weights` behaves properly with frozen weights.""" + # Set up a model with a frozen layer. model = test_case.model tfmod = model.get_wrapped_model() tfmod.layers[0].trainable = False # freeze the first layer's weights - w_all = model.get_weights() - w_trn = model.get_weights(trainable=True) - assert set(w_trn.coefs).issubset(w_all.coefs) # check on keys - assert w_trn.coefs.keys() == {v.name for v in tfmod.trainable_weights} - assert w_all.coefs.keys() == {v.name for v in tfmod.weights} + # Access names of the model's variables (via TensorFlow/Keras API). + if hasattr(tf_keras, "version") and tf_keras.version().startswith("3"): + names_all_wghts = {v.value.name for v in tfmod.weights} + names_trainable = {v.value.name for v in tfmod.trainable_weights} + else: + names_all_wghts = {v.name for v in tfmod.weights} + names_trainable = {v.name for v in tfmod.trainable_weights} + # Verify that DecLearn-accessed weights' names match. + weights_all = model.get_weights() + weights_trn = model.get_weights(trainable=True) + assert set(weights_trn.coefs).issubset(weights_all.coefs) + assert weights_all.coefs.keys() == names_all_wghts + assert weights_trn.coefs.keys() == names_trainable def test_set_frozen_weights( self, @@ -193,7 +201,7 @@ class TestTensorflowModel(ModelTestSuite): # Setup a model with some frozen weights, and gather trainable ones. model = test_case.model tfmod = model.get_wrapped_model() - tfmod.layers[0].trainable = False # freeze the first layer's weights + tfmod.layers[-1].trainable = False # freeze the last layer's weights w_trn = model.get_weights(trainable=True) # Test that `set_weights` works if and only if properly parametrized. with pytest.raises(KeyError): @@ -213,7 +221,11 @@ class TestTensorflowModel(ModelTestSuite): assert policy.idx == 0 tfmod = model.get_wrapped_model() device = f"{test_case.device}:0" - for var in tfmod.weights: + if hasattr(tf_keras, "version") and tf_keras.version().startswith("3"): + variables = [var.value for var in tfmod.weights] + else: + variables = tfmod.weights + for var in variables: assert var.device.endswith(device) @@ -238,6 +250,8 @@ class TestBuildKerasLoss: def test_build_keras_loss_from_string_noclass_function_name(self) -> None: """Test `build_keras_loss` with a valid function name string input.""" + if hasattr(tf_keras, "version") and tf_keras.version().startswith("3"): + pytest.skip("Skipping test that no longer works with Keras 3.") loss = build_keras_loss("mse", tf_keras.losses.Reduction.SUM) assert isinstance(loss, tf_keras.losses.Loss) assert hasattr(loss, "loss_fn") diff --git a/test/model/test_torch_model.py b/test/model/test_torch_model.py index 74419484..725e81c8 100644 --- a/test/model/test_torch_model.py +++ b/test/model/test_torch_model.py @@ -17,7 +17,7 @@ """Unit tests for TorchModel.""" -import json +import copy import os import typing from typing import List, Literal, Tuple @@ -219,7 +219,7 @@ class TestTorchModel(ModelTestSuite): """Unit tests for declearn.model.torch.TorchModel.""" @pytest.mark.filterwarnings("ignore: PyTorch JSON serialization") - def test_serialization( + def test_get_config( self, test_case: ModelTestCase, ) -> None: @@ -228,26 +228,44 @@ class TestTorchModel(ModelTestSuite): # due to the (de)serialization of a custom nn.Module # the expected model behaviour is, however, correct try: - self._test_serialization(test_case) + super().test_get_config(test_case) except AssertionError: pytest.skip( "skipping failed test due to custom nn.Module pickling" ) - self._test_serialization(test_case) + super().test_get_config(test_case) - def _test_serialization( + @pytest.mark.filterwarnings("ignore: PyTorch JSON serialization") + def test_from_config( + self, + test_case: ModelTestCase, + ) -> None: + if getattr(test_case, "kind", "") == "RNN": + # NOTE: this test fails on python 3.8 but succeeds in 3.10 + # due to the (de)serialization of a custom nn.Module + # the expected model behaviour is, however, correct + try: + self._test_from_config(test_case) + except AssertionError: + pytest.skip( + "skipping failed test due to custom nn.Module pickling" + ) + self._test_from_config(test_case) + + def _test_from_config( self, test_case: ModelTestCase, ) -> None: - """Check that the model can be JSON-(de)serialized properly. + """Check that the model can be instantiated from its config. - This method replaces the parent `test_serialization` one. + This method replaces the parent `test_from_config` one. """ # Same setup as in parent test: a model and a config-based other. model = test_case.model - config = json.dumps(model.get_config()) - other = model.from_config(json.loads(config)) - # Verify that both models have the same device policy. + config = model.get_config() + other = model.from_config(copy.deepcopy(config)) + # Verify that both models have the same config and device policy. + assert other.get_config() == config assert model.device_policy == other.device_policy # Verify that both models have a similar structure of modules. mod_a = list(model.get_wrapped_model().modules()) diff --git a/test/optimizer/test_tflow_optim.py b/test/optimizer/test_tflow_optim.py index 81d3346a..7658fed0 100644 --- a/test/optimizer/test_tflow_optim.py +++ b/test/optimizer/test_tflow_optim.py @@ -31,6 +31,8 @@ try: import tensorflow as tf # type: ignore except ModuleNotFoundError: pytest.skip("TensorFlow is unavailable", allow_module_level=True) +else: + import tensorflow.keras as tf_keras # type: ignore # pylint: enable=duplicate-code from declearn.model.tensorflow import TensorflowOptiModule, TensorflowVector @@ -201,7 +203,11 @@ class TestTensorflowOptiModule(OptiModuleTestSuite): grads = GradientsTestCase("tensorflow").mock_gradient updts = module.run(grads) # Assert that the outputs and internal states are properly placed. + if hasattr(tf_keras, "version") and tf_keras.version().startswith("3"): + optimizer_variables = [var.value for var in module.optim.variables] + else: + optimizer_variables = module.optim.variables() assert all(device in t.device for t in updts.coefs.values()) - assert all(device in t.device for t in module.optim.variables()) + assert all(device in t.device for t in optimizer_variables) # Reset device policy to run other tests on CPU as expected. set_device_policy(gpu=False) -- GitLab