Mentions légales du service

Skip to content
Snippets Groups Projects
Verified Commit aff0ff31 authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Add support for TensorFlow 2.16 / Keras 3.

- Apply some minor backward changes (and, when required, branches) to
  support breaking changes in Keras 3 (the default for TensorFlow 2.16
  and onwards), while preserving support for older versions.
- On the side, split a test about 'Model' serialization in two tests,
  one for config serializability, the other for instantiation from a
  non-serialized config dict.
- The current code has been tested (by running all unit tests) with
  both TensorFlow 2.11 and 2.16, in a Python 3.10 environment.
parent dce82854
No related branches found
No related tags found
1 merge request!63Finalize DecLearn v2.4.0.
Pipeline #943445 passed
...@@ -183,10 +183,20 @@ class TensorflowModel(Model): ...@@ -183,10 +183,20 @@ class TensorflowModel(Model):
self, self,
trainable: bool = False, trainable: bool = False,
) -> TensorflowVector: ) -> TensorflowVector:
variables = self._get_weight_variables(trainable)
return TensorflowVector({var.name: var.value() for var in variables})
def _get_weight_variables(
self,
trainable: bool,
) -> Iterable[tf.Variable]:
"""Access TensorFlow Variables wrapping model weight tensors."""
variables = ( variables = (
self._model.trainable_weights if trainable else self._model.weights self._model.trainable_weights if trainable else self._model.weights
) )
return TensorflowVector({var.name: var.value() for var in variables}) if hasattr(tf_keras, "version") and tf_keras.version().startswith("3"):
variables = (var.value for var in variables)
return variables
def set_weights( def set_weights(
self, self,
...@@ -198,7 +208,9 @@ class TensorflowModel(Model): ...@@ -198,7 +208,9 @@ class TensorflowModel(Model):
"TensorflowModel requires TensorflowVector weights." "TensorflowModel requires TensorflowVector weights."
) )
self._verify_weights_compatibility(weights, trainable=trainable) self._verify_weights_compatibility(weights, trainable=trainable)
variables = {var.name: var for var in self._model.weights} variables = {
var.name: var for var in self._get_weight_variables(trainable)
}
with tf.device(self._device): with tf.device(self._device):
for name, value in weights.coefs.items(): for name, value in weights.coefs.items():
variables[name].assign(value, read_value=False) variables[name].assign(value, read_value=False)
...@@ -225,9 +237,7 @@ class TensorflowModel(Model): ...@@ -225,9 +237,7 @@ class TensorflowModel(Model):
In case some expected keys are missing, or additional keys In case some expected keys are missing, or additional keys
are present. Be verbose about the identified mismatch(es). are present. Be verbose about the identified mismatch(es).
""" """
variables = ( variables = self._get_weight_variables(trainable)
self._model.trainable_weights if trainable else self._model.weights
)
raise_on_stringsets_mismatch( raise_on_stringsets_mismatch(
received=set(vector.coefs), received=set(vector.coefs),
expected={var.name for var in variables}, expected={var.name for var in variables},
...@@ -247,7 +257,7 @@ class TensorflowModel(Model): ...@@ -247,7 +257,7 @@ class TensorflowModel(Model):
norm = tf.constant(max_norm) norm = tf.constant(max_norm)
grads, loss = self._compute_clipped_gradients(*data, norm) grads, loss = self._compute_clipped_gradients(*data, norm)
self._loss_history.append(float(loss.numpy())) self._loss_history.append(float(loss.numpy()))
grads_and_vars = zip(grads, self._model.trainable_weights) grads_and_vars = zip(grads, self._get_weight_variables(trainable=True))
return TensorflowVector( return TensorflowVector(
{var.name: grad for grad, var in grads_and_vars} {var.name: grad for grad, var in grads_and_vars}
) )
...@@ -331,7 +341,7 @@ class TensorflowModel(Model): ...@@ -331,7 +341,7 @@ class TensorflowModel(Model):
) -> None: ) -> None:
self._verify_weights_compatibility(updates, trainable=True) self._verify_weights_compatibility(updates, trainable=True)
with tf.device(self._device): with tf.device(self._device):
for var in self._model.trainable_weights: for var in self._get_weight_variables(trainable=True):
updt = updates.coefs[var.name] updt = updates.coefs[var.name]
if isinstance(updt, tf.IndexedSlices): if isinstance(updt, tf.IndexedSlices):
var.scatter_add(updt) var.scatter_add(updt)
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
"""Hacky OptiModule subclass enabling the use of a keras Optimizer.""" """Hacky OptiModule subclass enabling the use of a keras Optimizer."""
from typing import Any, Dict, Union from typing import Any, Dict, List, Union
# fmt: off # fmt: off
# pylint: disable=import-error,no-name-in-module # pylint: disable=import-error,no-name-in-module
...@@ -189,7 +189,7 @@ class TensorflowOptiModule(OptiModule): ...@@ -189,7 +189,7 @@ class TensorflowOptiModule(OptiModule):
key: tf.Variable(tf.zeros_like(grad), name=key) key: tf.Variable(tf.zeros_like(grad), name=key)
for key, grad in gradients.coefs.items() for key, grad in gradients.coefs.items()
} }
self.optim.build(self._vars.values()) self.optim.build(list(self._vars.values()))
def reset(self) -> None: def reset(self) -> None:
"""Reset this module to its uninitialized state. """Reset this module to its uninitialized state.
...@@ -206,7 +206,7 @@ class TensorflowOptiModule(OptiModule): ...@@ -206,7 +206,7 @@ class TensorflowOptiModule(OptiModule):
policy = get_device_policy() policy = get_device_policy()
self._device = select_device(gpu=policy.gpu, idx=policy.idx) self._device = select_device(gpu=policy.gpu, idx=policy.idx)
with tf.device(self._device): with tf.device(self._device):
self._vars = {} self._vars.clear()
self.optim = self.optim.from_config(self.optim.get_config()) self.optim = self.optim.from_config(self.optim.get_config())
def get_config( def get_config(
...@@ -222,11 +222,20 @@ class TensorflowOptiModule(OptiModule): ...@@ -222,11 +222,20 @@ class TensorflowOptiModule(OptiModule):
key: (val.shape.as_list(), val.dtype.name) key: (val.shape.as_list(), val.dtype.name)
for key, val in self._vars.items() for key, val in self._vars.items()
} }
variables = self._get_optimizer_variables()
state = TensorflowVector( state = TensorflowVector(
{var.name: var.value() for var in self.optim.variables()} {str(i): v.value() for i, v in enumerate(variables)}
) )
return {"specs": specs, "state": state} return {"specs": specs, "state": state}
def _get_optimizer_variables(
self,
) -> List[tf.Variable]:
"""Access wrapped optimizer's variables as 'tf.Variable' instances."""
if hasattr(tf_keras, "version") and tf_keras.version().startswith("3"):
return [var.value for var in self.optim.variables]
return self.optim.variables()
def set_state( def set_state(
self, self,
state: Dict[str, Any], state: Dict[str, Any],
...@@ -244,9 +253,9 @@ class TensorflowOptiModule(OptiModule): ...@@ -244,9 +253,9 @@ class TensorflowOptiModule(OptiModule):
key: tf.Variable(tf.zeros(shape, dtype), name=key) key: tf.Variable(tf.zeros(shape, dtype), name=key)
for key, (shape, dtype) in state["specs"].items() for key, (shape, dtype) in state["specs"].items()
} }
self.optim.build(self._vars.values()) self.optim.build(list(self._vars.values()))
# Restore optimizer variables' values from the input state dict. # Restore optimizer variables' values from the input state dict.
opt_vars = {var.name: var for var in self.optim.variables()} opt_vars = self._get_optimizer_variables()
with tf.device(self._device): with tf.device(self._device):
for key, val in state["state"].coefs.items(): for var, val in zip(opt_vars, state["state"].coefs.values()):
opt_vars[key].assign(val, read_value=False) var.assign(val, read_value=False)
...@@ -48,7 +48,7 @@ class LossFunction(tf_keras.losses.Loss): ...@@ -48,7 +48,7 @@ class LossFunction(tf_keras.losses.Loss):
reduction: str = tf_keras.losses.Reduction.NONE, reduction: str = tf_keras.losses.Reduction.NONE,
name: Optional[str] = None, name: Optional[str] = None,
) -> None: ) -> None:
super().__init__(reduction, name) super().__init__(reduction=reduction, name=name)
self.loss_fn = tf_keras.losses.deserialize(loss_fn) self.loss_fn = tf_keras.losses.deserialize(loss_fn)
def call( def call(
......
...@@ -274,10 +274,13 @@ Older TensorFlow versions (v2.5 to 2.10 included) were improperly marked as ...@@ -274,10 +274,13 @@ Older TensorFlow versions (v2.5 to 2.10 included) were improperly marked as
supported in spite of `TensorflowOptiModule` requiring at least version 2.11 supported in spite of `TensorflowOptiModule` requiring at least version 2.11
to work (due to changes of the Keras Optimizer API). This has been corrected. to work (due to changes of the Keras Optimizer API). This has been corrected.
Upcoming TensorFlow versions (v2.16 and onwards) introduce backward-breaking The latest TensorFlow version (v2.16) introduces backward-breaking changes, due to the backend swap from Keras 2 to Keras 3. Our backend code was updated to
changes, which are probably due to the backend swap from Keras 2 to Keras 3. both add support for this newer Keras backend, and preserve existing support.
To keep safe, these versions are currently marked as unsupported, awaiting
further investigation once version 2.16 is finalized. Note that at the moment, the CI does not support TensorFlow above 2.13, due to
newer versions not being compatible with Python 3.8. As such, our code will be
tested to remain backward-compatible. Forward compatibility has been (and will
keep being) tested locally with a newer Python version.
### Deprecate `declearn.dataset.load_from_json` ### Deprecate `declearn.dataset.load_from_json`
......
...@@ -53,7 +53,7 @@ all = [ # all non-tests extra dependencies ...@@ -53,7 +53,7 @@ all = [ # all non-tests extra dependencies
"jax[cpu] ~= 0.4.1", "jax[cpu] ~= 0.4.1",
"opacus ~= 1.4", "opacus ~= 1.4",
"protobuf >= 3.19", "protobuf >= 3.19",
"tensorflow ~= 2.11, < 2.16", "tensorflow ~= 2.11",
"torch >= 1.13, < 3.0", "torch >= 1.13, < 3.0",
"websockets >= 10.1, < 13.0", "websockets >= 10.1, < 13.0",
] ]
...@@ -71,7 +71,7 @@ haiku = [ ...@@ -71,7 +71,7 @@ haiku = [
"jax[cpu] ~= 0.4.1", # NOTE: GPU support must be manually installed "jax[cpu] ~= 0.4.1", # NOTE: GPU support must be manually installed
] ]
tensorflow = [ tensorflow = [
"tensorflow ~= 2.11, < 2.16", # starting with 2.16 TF upgrades to Keras 3 "tensorflow ~= 2.11",
] ]
torch = [ # generic requirements for Torch torch = [ # generic requirements for Torch
"torch >= 1.13, < 3.0", "torch >= 1.13, < 3.0",
......
...@@ -17,13 +17,14 @@ ...@@ -17,13 +17,14 @@
"""Shared testing code for TensorFlow and Torch models' unit tests.""" """Shared testing code for TensorFlow and Torch models' unit tests."""
import copy
import json import json
from typing import Any, Generic, List, Protocol, Tuple, Type, TypeVar, Union from typing import Any, Generic, List, Protocol, Tuple, Type, TypeVar, Union
import numpy as np import numpy as np
from declearn.model.api import Model, Vector from declearn.model.api import Model, Vector
from declearn.test_utils import to_numpy from declearn.test_utils import assert_json_serializable_dict, to_numpy
from declearn.typing import Batch from declearn.typing import Batch
from declearn.utils import json_pack, json_unpack from declearn.utils import json_pack, json_unpack
...@@ -59,14 +60,23 @@ class ModelTestCase(Protocol, Generic[VectorT]): ...@@ -59,14 +60,23 @@ class ModelTestCase(Protocol, Generic[VectorT]):
class ModelTestSuite: class ModelTestSuite:
"""Unit tests for a declearn Model.""" """Unit tests for a declearn Model."""
def test_serialization( def test_get_config(
self, self,
test_case: ModelTestCase, test_case: ModelTestCase,
) -> None: ) -> None:
"""Check that the model can be JSON-(de)serialized properly.""" """Check that the model's config is JSON-serializable."""
model = test_case.model model = test_case.model
config = json.dumps(model.get_config()) config = model.get_config()
other = model.from_config(json.loads(config)) assert_json_serializable_dict(config)
def test_from_config(
self,
test_case: ModelTestCase,
) -> None:
"""Check that the model can be instantiated from its config."""
model = test_case.model
config = model.get_config()
other = model.from_config(copy.deepcopy(config))
assert model.get_config() == other.get_config() assert model.get_config() == other.get_config()
assert model.device_policy == other.device_policy assert model.device_policy == other.device_policy
......
...@@ -234,11 +234,18 @@ class TestHaikuModel(ModelTestSuite): ...@@ -234,11 +234,18 @@ class TestHaikuModel(ModelTestSuite):
"""Unit tests for declearn.model.tensorflow.TensorflowModel.""" """Unit tests for declearn.model.tensorflow.TensorflowModel."""
@pytest.mark.filterwarnings("ignore: Our custom Haiku serialization") @pytest.mark.filterwarnings("ignore: Our custom Haiku serialization")
def test_serialization( def test_get_config(
self, self,
test_case: ModelTestCase, test_case: ModelTestCase,
) -> None: ) -> None:
super().test_serialization(test_case) super().test_get_config(test_case)
@pytest.mark.filterwarnings("ignore: Our custom Haiku serialization")
def test_from_config(
self,
test_case: ModelTestCase,
) -> None:
super().test_from_config(test_case)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"criterion_type", ["names", "pytree", "predicate"] "criterion_type", ["names", "pytree", "predicate"]
......
...@@ -120,15 +120,6 @@ def fixture_test_case( ...@@ -120,15 +120,6 @@ def fixture_test_case(
class TestSklearnSGDModel(ModelTestSuite): class TestSklearnSGDModel(ModelTestSuite):
"""Unit tests for declearn.model.sklearn.SklearnSGDModel.""" """Unit tests for declearn.model.sklearn.SklearnSGDModel."""
def test_serialization( # type: ignore # Liskov does not matter here
self,
test_case: SklearnSGDTestCase,
) -> None:
# Avoid re-running tests that are unaltered by data parameters.
if test_case.s_weights or test_case.as_sparse:
return None
return super().test_serialization(test_case)
def test_initialization( def test_initialization(
self, self,
test_case: SklearnSGDTestCase, test_case: SklearnSGDTestCase,
......
...@@ -30,7 +30,6 @@ try: ...@@ -30,7 +30,6 @@ try:
except ModuleNotFoundError: except ModuleNotFoundError:
pytest.skip("TensorFlow is unavailable", allow_module_level=True) pytest.skip("TensorFlow is unavailable", allow_module_level=True)
else: else:
# pylint: disable=import-error,no-name-in-module
import tensorflow.keras as tf_keras # type: ignore import tensorflow.keras as tf_keras # type: ignore
from declearn.model.tensorflow import TensorflowModel, TensorflowVector from declearn.model.tensorflow import TensorflowModel, TensorflowVector
...@@ -176,14 +175,23 @@ class TestTensorflowModel(ModelTestSuite): ...@@ -176,14 +175,23 @@ class TestTensorflowModel(ModelTestSuite):
test_case: ModelTestCase, test_case: ModelTestCase,
) -> None: ) -> None:
"""Check that `get_weights` behaves properly with frozen weights.""" """Check that `get_weights` behaves properly with frozen weights."""
# Set up a model with a frozen layer.
model = test_case.model model = test_case.model
tfmod = model.get_wrapped_model() tfmod = model.get_wrapped_model()
tfmod.layers[0].trainable = False # freeze the first layer's weights tfmod.layers[0].trainable = False # freeze the first layer's weights
w_all = model.get_weights() # Access names of the model's variables (via TensorFlow/Keras API).
w_trn = model.get_weights(trainable=True) if hasattr(tf_keras, "version") and tf_keras.version().startswith("3"):
assert set(w_trn.coefs).issubset(w_all.coefs) # check on keys names_all_wghts = {v.value.name for v in tfmod.weights}
assert w_trn.coefs.keys() == {v.name for v in tfmod.trainable_weights} names_trainable = {v.value.name for v in tfmod.trainable_weights}
assert w_all.coefs.keys() == {v.name for v in tfmod.weights} else:
names_all_wghts = {v.name for v in tfmod.weights}
names_trainable = {v.name for v in tfmod.trainable_weights}
# Verify that DecLearn-accessed weights' names match.
weights_all = model.get_weights()
weights_trn = model.get_weights(trainable=True)
assert set(weights_trn.coefs).issubset(weights_all.coefs)
assert weights_all.coefs.keys() == names_all_wghts
assert weights_trn.coefs.keys() == names_trainable
def test_set_frozen_weights( def test_set_frozen_weights(
self, self,
...@@ -193,7 +201,7 @@ class TestTensorflowModel(ModelTestSuite): ...@@ -193,7 +201,7 @@ class TestTensorflowModel(ModelTestSuite):
# Setup a model with some frozen weights, and gather trainable ones. # Setup a model with some frozen weights, and gather trainable ones.
model = test_case.model model = test_case.model
tfmod = model.get_wrapped_model() tfmod = model.get_wrapped_model()
tfmod.layers[0].trainable = False # freeze the first layer's weights tfmod.layers[-1].trainable = False # freeze the last layer's weights
w_trn = model.get_weights(trainable=True) w_trn = model.get_weights(trainable=True)
# Test that `set_weights` works if and only if properly parametrized. # Test that `set_weights` works if and only if properly parametrized.
with pytest.raises(KeyError): with pytest.raises(KeyError):
...@@ -213,7 +221,11 @@ class TestTensorflowModel(ModelTestSuite): ...@@ -213,7 +221,11 @@ class TestTensorflowModel(ModelTestSuite):
assert policy.idx == 0 assert policy.idx == 0
tfmod = model.get_wrapped_model() tfmod = model.get_wrapped_model()
device = f"{test_case.device}:0" device = f"{test_case.device}:0"
for var in tfmod.weights: if hasattr(tf_keras, "version") and tf_keras.version().startswith("3"):
variables = [var.value for var in tfmod.weights]
else:
variables = tfmod.weights
for var in variables:
assert var.device.endswith(device) assert var.device.endswith(device)
...@@ -238,6 +250,8 @@ class TestBuildKerasLoss: ...@@ -238,6 +250,8 @@ class TestBuildKerasLoss:
def test_build_keras_loss_from_string_noclass_function_name(self) -> None: def test_build_keras_loss_from_string_noclass_function_name(self) -> None:
"""Test `build_keras_loss` with a valid function name string input.""" """Test `build_keras_loss` with a valid function name string input."""
if hasattr(tf_keras, "version") and tf_keras.version().startswith("3"):
pytest.skip("Skipping test that no longer works with Keras 3.")
loss = build_keras_loss("mse", tf_keras.losses.Reduction.SUM) loss = build_keras_loss("mse", tf_keras.losses.Reduction.SUM)
assert isinstance(loss, tf_keras.losses.Loss) assert isinstance(loss, tf_keras.losses.Loss)
assert hasattr(loss, "loss_fn") assert hasattr(loss, "loss_fn")
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
"""Unit tests for TorchModel.""" """Unit tests for TorchModel."""
import json import copy
import os import os
import typing import typing
from typing import List, Literal, Tuple from typing import List, Literal, Tuple
...@@ -219,7 +219,7 @@ class TestTorchModel(ModelTestSuite): ...@@ -219,7 +219,7 @@ class TestTorchModel(ModelTestSuite):
"""Unit tests for declearn.model.torch.TorchModel.""" """Unit tests for declearn.model.torch.TorchModel."""
@pytest.mark.filterwarnings("ignore: PyTorch JSON serialization") @pytest.mark.filterwarnings("ignore: PyTorch JSON serialization")
def test_serialization( def test_get_config(
self, self,
test_case: ModelTestCase, test_case: ModelTestCase,
) -> None: ) -> None:
...@@ -228,26 +228,44 @@ class TestTorchModel(ModelTestSuite): ...@@ -228,26 +228,44 @@ class TestTorchModel(ModelTestSuite):
# due to the (de)serialization of a custom nn.Module # due to the (de)serialization of a custom nn.Module
# the expected model behaviour is, however, correct # the expected model behaviour is, however, correct
try: try:
self._test_serialization(test_case) super().test_get_config(test_case)
except AssertionError: except AssertionError:
pytest.skip( pytest.skip(
"skipping failed test due to custom nn.Module pickling" "skipping failed test due to custom nn.Module pickling"
) )
self._test_serialization(test_case) super().test_get_config(test_case)
def _test_serialization( @pytest.mark.filterwarnings("ignore: PyTorch JSON serialization")
def test_from_config(
self,
test_case: ModelTestCase,
) -> None:
if getattr(test_case, "kind", "") == "RNN":
# NOTE: this test fails on python 3.8 but succeeds in 3.10
# due to the (de)serialization of a custom nn.Module
# the expected model behaviour is, however, correct
try:
self._test_from_config(test_case)
except AssertionError:
pytest.skip(
"skipping failed test due to custom nn.Module pickling"
)
self._test_from_config(test_case)
def _test_from_config(
self, self,
test_case: ModelTestCase, test_case: ModelTestCase,
) -> None: ) -> None:
"""Check that the model can be JSON-(de)serialized properly. """Check that the model can be instantiated from its config.
This method replaces the parent `test_serialization` one. This method replaces the parent `test_from_config` one.
""" """
# Same setup as in parent test: a model and a config-based other. # Same setup as in parent test: a model and a config-based other.
model = test_case.model model = test_case.model
config = json.dumps(model.get_config()) config = model.get_config()
other = model.from_config(json.loads(config)) other = model.from_config(copy.deepcopy(config))
# Verify that both models have the same device policy. # Verify that both models have the same config and device policy.
assert other.get_config() == config
assert model.device_policy == other.device_policy assert model.device_policy == other.device_policy
# Verify that both models have a similar structure of modules. # Verify that both models have a similar structure of modules.
mod_a = list(model.get_wrapped_model().modules()) mod_a = list(model.get_wrapped_model().modules())
......
...@@ -31,6 +31,8 @@ try: ...@@ -31,6 +31,8 @@ try:
import tensorflow as tf # type: ignore import tensorflow as tf # type: ignore
except ModuleNotFoundError: except ModuleNotFoundError:
pytest.skip("TensorFlow is unavailable", allow_module_level=True) pytest.skip("TensorFlow is unavailable", allow_module_level=True)
else:
import tensorflow.keras as tf_keras # type: ignore
# pylint: enable=duplicate-code # pylint: enable=duplicate-code
from declearn.model.tensorflow import TensorflowOptiModule, TensorflowVector from declearn.model.tensorflow import TensorflowOptiModule, TensorflowVector
...@@ -201,7 +203,11 @@ class TestTensorflowOptiModule(OptiModuleTestSuite): ...@@ -201,7 +203,11 @@ class TestTensorflowOptiModule(OptiModuleTestSuite):
grads = GradientsTestCase("tensorflow").mock_gradient grads = GradientsTestCase("tensorflow").mock_gradient
updts = module.run(grads) updts = module.run(grads)
# Assert that the outputs and internal states are properly placed. # Assert that the outputs and internal states are properly placed.
if hasattr(tf_keras, "version") and tf_keras.version().startswith("3"):
optimizer_variables = [var.value for var in module.optim.variables]
else:
optimizer_variables = module.optim.variables()
assert all(device in t.device for t in updts.coefs.values()) assert all(device in t.device for t in updts.coefs.values())
assert all(device in t.device for t in module.optim.variables()) assert all(device in t.device for t in optimizer_variables)
# Reset device policy to run other tests on CPU as expected. # Reset device policy to run other tests on CPU as expected.
set_device_policy(gpu=False) set_device_policy(gpu=False)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment