Verified commit aff0ff31, authored by ANDREY Paul

Add support for TensorFlow 2.16 / Keras 3.

- Apply some minor backward-compatible changes (and, where required,
  version-dependent branches, sketched below) to support breaking changes
  in Keras 3 (the default backend for TensorFlow 2.16 and onwards), while
  preserving support for older versions.
- On the side, split a test about 'Model' serialization into two tests:
  one for config serializability, the other for instantiation from a
  non-serialized config dict.
- The current code has been tested (by running all unit tests) with
  both TensorFlow 2.11 and 2.16, in a Python 3.10 environment.
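The version-dependent branching mentioned in the first point boils down to a small runtime check on the installed Keras backend, repeated wherever the Keras 2 and Keras 3 APIs diverge. A minimal sketch of that pattern is given below; the helper names are illustrative only (not part of the commit), and the assumption that Keras 3 wrapper variables expose the underlying `tf.Variable` via their `.value` attribute follows the branches visible in the diff that follows.

    from typing import List

    import tensorflow as tf
    import tensorflow.keras as tf_keras


    def running_keras3() -> bool:
        """Return True when the installed Keras backend is Keras 3."""
        # Keras 3 exposes a module-level `version()` function; tf.keras 2.x does not.
        return hasattr(tf_keras, "version") and tf_keras.version().startswith("3")


    def weight_variables(model: tf_keras.Model) -> List[tf.Variable]:
        """Return the tf.Variable objects backing a model's weights."""
        if running_keras3():
            # Keras 3 weights are wrapper objects; per this commit, their
            # `.value` attribute exposes the underlying tf.Variable.
            return [var.value for var in model.weights]
        return list(model.weights)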
parent dce82854
1 merge request: !63 Finalize DecLearn v2.4.0
Pipeline #943445 passed
@@ -183,10 +183,20 @@ class TensorflowModel(Model):
         self,
         trainable: bool = False,
     ) -> TensorflowVector:
+        variables = self._get_weight_variables(trainable)
+        return TensorflowVector({var.name: var.value() for var in variables})
+
+    def _get_weight_variables(
+        self,
+        trainable: bool,
+    ) -> Iterable[tf.Variable]:
+        """Access TensorFlow Variables wrapping model weight tensors."""
         variables = (
             self._model.trainable_weights if trainable else self._model.weights
         )
-        return TensorflowVector({var.name: var.value() for var in variables})
+        if hasattr(tf_keras, "version") and tf_keras.version().startswith("3"):
+            variables = (var.value for var in variables)
+        return variables
 
     def set_weights(
         self,
@@ -198,7 +208,9 @@ class TensorflowModel(Model):
                 "TensorflowModel requires TensorflowVector weights."
             )
         self._verify_weights_compatibility(weights, trainable=trainable)
-        variables = {var.name: var for var in self._model.weights}
+        variables = {
+            var.name: var for var in self._get_weight_variables(trainable)
+        }
         with tf.device(self._device):
             for name, value in weights.coefs.items():
                 variables[name].assign(value, read_value=False)
@@ -225,9 +237,7 @@ class TensorflowModel(Model):
             In case some expected keys are missing, or additional keys
             are present. Be verbose about the identified mismatch(es).
         """
-        variables = (
-            self._model.trainable_weights if trainable else self._model.weights
-        )
+        variables = self._get_weight_variables(trainable)
         raise_on_stringsets_mismatch(
             received=set(vector.coefs),
             expected={var.name for var in variables},
@@ -247,7 +257,7 @@ class TensorflowModel(Model):
             norm = tf.constant(max_norm)
             grads, loss = self._compute_clipped_gradients(*data, norm)
         self._loss_history.append(float(loss.numpy()))
-        grads_and_vars = zip(grads, self._model.trainable_weights)
+        grads_and_vars = zip(grads, self._get_weight_variables(trainable=True))
         return TensorflowVector(
             {var.name: grad for grad, var in grads_and_vars}
         )
@@ -331,7 +341,7 @@ class TensorflowModel(Model):
     ) -> None:
         self._verify_weights_compatibility(updates, trainable=True)
         with tf.device(self._device):
-            for var in self._model.trainable_weights:
+            for var in self._get_weight_variables(trainable=True):
                 updt = updates.coefs[var.name]
                 if isinstance(updt, tf.IndexedSlices):
                     var.scatter_add(updt)
@@ -17,7 +17,7 @@
 """Hacky OptiModule subclass enabling the use of a keras Optimizer."""
 
-from typing import Any, Dict, Union
+from typing import Any, Dict, List, Union
 
 # fmt: off
 # pylint: disable=import-error,no-name-in-module
@@ -189,7 +189,7 @@ class TensorflowOptiModule(OptiModule):
             key: tf.Variable(tf.zeros_like(grad), name=key)
             for key, grad in gradients.coefs.items()
         }
-        self.optim.build(self._vars.values())
+        self.optim.build(list(self._vars.values()))
 
     def reset(self) -> None:
         """Reset this module to its uninitialized state.
@@ -206,7 +206,7 @@ class TensorflowOptiModule(OptiModule):
         policy = get_device_policy()
         self._device = select_device(gpu=policy.gpu, idx=policy.idx)
         with tf.device(self._device):
-            self._vars = {}
+            self._vars.clear()
             self.optim = self.optim.from_config(self.optim.get_config())
 
     def get_config(
@@ -222,11 +222,20 @@ class TensorflowOptiModule(OptiModule):
             key: (val.shape.as_list(), val.dtype.name)
             for key, val in self._vars.items()
         }
+        variables = self._get_optimizer_variables()
         state = TensorflowVector(
-            {var.name: var.value() for var in self.optim.variables()}
+            {str(i): v.value() for i, v in enumerate(variables)}
         )
         return {"specs": specs, "state": state}
 
+    def _get_optimizer_variables(
+        self,
+    ) -> List[tf.Variable]:
+        """Access wrapped optimizer's variables as 'tf.Variable' instances."""
+        if hasattr(tf_keras, "version") and tf_keras.version().startswith("3"):
+            return [var.value for var in self.optim.variables]
+        return self.optim.variables()
+
     def set_state(
         self,
         state: Dict[str, Any],
@@ -244,9 +253,9 @@ class TensorflowOptiModule(OptiModule):
             key: tf.Variable(tf.zeros(shape, dtype), name=key)
             for key, (shape, dtype) in state["specs"].items()
         }
-        self.optim.build(self._vars.values())
+        self.optim.build(list(self._vars.values()))
         # Restore optimizer variables' values from the input state dict.
-        opt_vars = {var.name: var for var in self.optim.variables()}
+        opt_vars = self._get_optimizer_variables()
         with tf.device(self._device):
-            for key, val in state["state"].coefs.items():
-                opt_vars[key].assign(val, read_value=False)
+            for var, val in zip(opt_vars, state["state"].coefs.values()):
+                var.assign(val, read_value=False)
@@ -48,7 +48,7 @@ class LossFunction(tf_keras.losses.Loss):
         reduction: str = tf_keras.losses.Reduction.NONE,
         name: Optional[str] = None,
     ) -> None:
-        super().__init__(reduction, name)
+        super().__init__(reduction=reduction, name=name)
         self.loss_fn = tf_keras.losses.deserialize(loss_fn)
 
     def call(
@@ -274,10 +274,13 @@ Older TensorFlow versions (v2.5 to 2.10 included) were improperly marked as
 supported in spite of `TensorflowOptiModule` requiring at least version 2.11
 to work (due to changes of the Keras Optimizer API). This has been corrected.
 
-Upcoming TensorFlow versions (v2.16 and onwards) introduce backward-breaking
-changes, which are probably due to the backend swap from Keras 2 to Keras 3.
-To keep safe, these versions are currently marked as unsupported, awaiting
-further investigation once version 2.16 is finalized.
+The latest TensorFlow version (v2.16) introduces backward-breaking changes, due
+to the backend swap from Keras 2 to Keras 3. Our backend code was updated to
+both add support for this newer Keras backend, and preserve existing support.
+Note that at the moment, the CI does not support TensorFlow above 2.13, due to
+newer versions not being compatible with Python 3.8. As such, our code will be
+tested to remain backward-compatible. Forward compatibility has been (and will
+keep being) tested locally with a newer Python version.
 
 ### Deprecate `declearn.dataset.load_from_json`
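For illustration, the kind of Keras 2 / Keras 3 API drift referred to in the changelog entry above can be sketched as follows. This mirrors the version branches introduced in the diff (Keras 2's optimizers expose a `variables()` method, while the Keras 3 branch accesses a `variables` property and reads each item's `.value` attribute); the helper name is hypothetical and not part of the commit.

    from typing import List

    import tensorflow as tf
    import tensorflow.keras as tf_keras


    def optimizer_tf_variables(optim: tf_keras.optimizers.Optimizer) -> List[tf.Variable]:
        """Return an optimizer's state variables as tf.Variable objects."""
        if hasattr(tf_keras, "version") and tf_keras.version().startswith("3"):
            # Keras 3: `variables` is a property holding wrapper objects whose
            # `.value` attribute exposes the underlying tf.Variable.
            return [var.value for var in optim.variables]
        # Keras 2 (tf.keras): `variables()` is a method returning tf.Variable items.
        return optim.variables()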
@@ -53,7 +53,7 @@ all = [  # all non-tests extra dependencies
     "jax[cpu] ~= 0.4.1",
     "opacus ~= 1.4",
     "protobuf >= 3.19",
-    "tensorflow ~= 2.11, < 2.16",
+    "tensorflow ~= 2.11",
     "torch >= 1.13, < 3.0",
     "websockets >= 10.1, < 13.0",
 ]
@@ -71,7 +71,7 @@ haiku = [
     "jax[cpu] ~= 0.4.1",  # NOTE: GPU support must be manually installed
 ]
 tensorflow = [
-    "tensorflow ~= 2.11, < 2.16",  # starting with 2.16 TF upgrades to Keras 3
+    "tensorflow ~= 2.11",
 ]
 torch = [  # generic requirements for Torch
     "torch >= 1.13, < 3.0",
@@ -17,13 +17,14 @@
 """Shared testing code for TensorFlow and Torch models' unit tests."""
 
+import copy
 import json
 from typing import Any, Generic, List, Protocol, Tuple, Type, TypeVar, Union
 
 import numpy as np
 
 from declearn.model.api import Model, Vector
-from declearn.test_utils import to_numpy
+from declearn.test_utils import assert_json_serializable_dict, to_numpy
 from declearn.typing import Batch
 from declearn.utils import json_pack, json_unpack
@@ -59,14 +60,23 @@ class ModelTestCase(Protocol, Generic[VectorT]):
 class ModelTestSuite:
     """Unit tests for a declearn Model."""
 
-    def test_serialization(
+    def test_get_config(
         self,
         test_case: ModelTestCase,
     ) -> None:
-        """Check that the model can be JSON-(de)serialized properly."""
+        """Check that the model's config is JSON-serializable."""
         model = test_case.model
-        config = json.dumps(model.get_config())
-        other = model.from_config(json.loads(config))
+        config = model.get_config()
+        assert_json_serializable_dict(config)
+
+    def test_from_config(
+        self,
+        test_case: ModelTestCase,
+    ) -> None:
+        """Check that the model can be instantiated from its config."""
+        model = test_case.model
+        config = model.get_config()
+        other = model.from_config(copy.deepcopy(config))
         assert model.get_config() == other.get_config()
         assert model.device_policy == other.device_policy
@@ -234,11 +234,18 @@ class TestHaikuModel(ModelTestSuite):
     """Unit tests for declearn.model.tensorflow.TensorflowModel."""
 
     @pytest.mark.filterwarnings("ignore: Our custom Haiku serialization")
-    def test_serialization(
+    def test_get_config(
         self,
         test_case: ModelTestCase,
     ) -> None:
-        super().test_serialization(test_case)
+        super().test_get_config(test_case)
+
+    @pytest.mark.filterwarnings("ignore: Our custom Haiku serialization")
+    def test_from_config(
+        self,
+        test_case: ModelTestCase,
+    ) -> None:
+        super().test_from_config(test_case)
 
     @pytest.mark.parametrize(
         "criterion_type", ["names", "pytree", "predicate"]
@@ -120,15 +120,6 @@ def fixture_test_case(
 class TestSklearnSGDModel(ModelTestSuite):
     """Unit tests for declearn.model.sklearn.SklearnSGDModel."""
 
-    def test_serialization(  # type: ignore  # Liskov does not matter here
-        self,
-        test_case: SklearnSGDTestCase,
-    ) -> None:
-        # Avoid re-running tests that are unaltered by data parameters.
-        if test_case.s_weights or test_case.as_sparse:
-            return None
-        return super().test_serialization(test_case)
-
     def test_initialization(
         self,
         test_case: SklearnSGDTestCase,
@@ -30,7 +30,6 @@ try:
 except ModuleNotFoundError:
     pytest.skip("TensorFlow is unavailable", allow_module_level=True)
 else:
-    # pylint: disable=import-error,no-name-in-module
     import tensorflow.keras as tf_keras  # type: ignore
 
 from declearn.model.tensorflow import TensorflowModel, TensorflowVector
@@ -176,14 +175,23 @@ class TestTensorflowModel(ModelTestSuite):
         test_case: ModelTestCase,
     ) -> None:
         """Check that `get_weights` behaves properly with frozen weights."""
+        # Set up a model with a frozen layer.
         model = test_case.model
         tfmod = model.get_wrapped_model()
         tfmod.layers[0].trainable = False  # freeze the first layer's weights
-        w_all = model.get_weights()
-        w_trn = model.get_weights(trainable=True)
-        assert set(w_trn.coefs).issubset(w_all.coefs)  # check on keys
-        assert w_trn.coefs.keys() == {v.name for v in tfmod.trainable_weights}
-        assert w_all.coefs.keys() == {v.name for v in tfmod.weights}
+        # Access names of the model's variables (via TensorFlow/Keras API).
+        if hasattr(tf_keras, "version") and tf_keras.version().startswith("3"):
+            names_all_wghts = {v.value.name for v in tfmod.weights}
+            names_trainable = {v.value.name for v in tfmod.trainable_weights}
+        else:
+            names_all_wghts = {v.name for v in tfmod.weights}
+            names_trainable = {v.name for v in tfmod.trainable_weights}
+        # Verify that DecLearn-accessed weights' names match.
+        weights_all = model.get_weights()
+        weights_trn = model.get_weights(trainable=True)
+        assert set(weights_trn.coefs).issubset(weights_all.coefs)
+        assert weights_all.coefs.keys() == names_all_wghts
+        assert weights_trn.coefs.keys() == names_trainable
 
     def test_set_frozen_weights(
         self,
@@ -193,7 +201,7 @@ class TestTensorflowModel(ModelTestSuite):
         # Setup a model with some frozen weights, and gather trainable ones.
         model = test_case.model
         tfmod = model.get_wrapped_model()
-        tfmod.layers[0].trainable = False  # freeze the first layer's weights
+        tfmod.layers[-1].trainable = False  # freeze the last layer's weights
         w_trn = model.get_weights(trainable=True)
         # Test that `set_weights` works if and only if properly parametrized.
         with pytest.raises(KeyError):
@@ -213,7 +221,11 @@ class TestTensorflowModel(ModelTestSuite):
         assert policy.idx == 0
         tfmod = model.get_wrapped_model()
         device = f"{test_case.device}:0"
-        for var in tfmod.weights:
+        if hasattr(tf_keras, "version") and tf_keras.version().startswith("3"):
+            variables = [var.value for var in tfmod.weights]
+        else:
+            variables = tfmod.weights
+        for var in variables:
             assert var.device.endswith(device)
@@ -238,6 +250,8 @@ class TestBuildKerasLoss:
     def test_build_keras_loss_from_string_noclass_function_name(self) -> None:
         """Test `build_keras_loss` with a valid function name string input."""
+        if hasattr(tf_keras, "version") and tf_keras.version().startswith("3"):
+            pytest.skip("Skipping test that no longer works with Keras 3.")
         loss = build_keras_loss("mse", tf_keras.losses.Reduction.SUM)
         assert isinstance(loss, tf_keras.losses.Loss)
         assert hasattr(loss, "loss_fn")
@@ -17,7 +17,7 @@
 """Unit tests for TorchModel."""
 
-import json
+import copy
 import os
 import typing
 from typing import List, Literal, Tuple
@@ -219,7 +219,7 @@ class TestTorchModel(ModelTestSuite):
     """Unit tests for declearn.model.torch.TorchModel."""
 
     @pytest.mark.filterwarnings("ignore: PyTorch JSON serialization")
-    def test_serialization(
+    def test_get_config(
         self,
         test_case: ModelTestCase,
     ) -> None:
@@ -228,26 +228,44 @@ class TestTorchModel(ModelTestSuite):
             # due to the (de)serialization of a custom nn.Module
             # the expected model behaviour is, however, correct
             try:
-                self._test_serialization(test_case)
+                super().test_get_config(test_case)
             except AssertionError:
                 pytest.skip(
                     "skipping failed test due to custom nn.Module pickling"
                 )
-        self._test_serialization(test_case)
+        super().test_get_config(test_case)
 
-    def _test_serialization(
+    @pytest.mark.filterwarnings("ignore: PyTorch JSON serialization")
+    def test_from_config(
         self,
         test_case: ModelTestCase,
     ) -> None:
+        if getattr(test_case, "kind", "") == "RNN":
+            # NOTE: this test fails on python 3.8 but succeeds in 3.10
+            # due to the (de)serialization of a custom nn.Module
+            # the expected model behaviour is, however, correct
+            try:
+                self._test_from_config(test_case)
+            except AssertionError:
+                pytest.skip(
+                    "skipping failed test due to custom nn.Module pickling"
+                )
+        self._test_from_config(test_case)
+
+    def _test_from_config(
+        self,
+        test_case: ModelTestCase,
+    ) -> None:
-        """Check that the model can be JSON-(de)serialized properly.
+        """Check that the model can be instantiated from its config.
 
-        This method replaces the parent `test_serialization` one.
+        This method replaces the parent `test_from_config` one.
         """
         # Same setup as in parent test: a model and a config-based other.
         model = test_case.model
-        config = json.dumps(model.get_config())
-        other = model.from_config(json.loads(config))
-        # Verify that both models have the same device policy.
+        config = model.get_config()
+        other = model.from_config(copy.deepcopy(config))
+        # Verify that both models have the same config and device policy.
+        assert other.get_config() == config
         assert model.device_policy == other.device_policy
         # Verify that both models have a similar structure of modules.
         mod_a = list(model.get_wrapped_model().modules())
@@ -31,6 +31,8 @@ try:
     import tensorflow as tf  # type: ignore
 except ModuleNotFoundError:
     pytest.skip("TensorFlow is unavailable", allow_module_level=True)
+else:
+    import tensorflow.keras as tf_keras  # type: ignore
 
 # pylint: enable=duplicate-code
 from declearn.model.tensorflow import TensorflowOptiModule, TensorflowVector
@@ -201,7 +203,11 @@ class TestTensorflowOptiModule(OptiModuleTestSuite):
         grads = GradientsTestCase("tensorflow").mock_gradient
         updts = module.run(grads)
         # Assert that the outputs and internal states are properly placed.
+        if hasattr(tf_keras, "version") and tf_keras.version().startswith("3"):
+            optimizer_variables = [var.value for var in module.optim.variables]
+        else:
+            optimizer_variables = module.optim.variables()
         assert all(device in t.device for t in updts.coefs.values())
-        assert all(device in t.device for t in module.optim.variables())
+        assert all(device in t.device for t in optimizer_variables)
         # Reset device policy to run other tests on CPU as expected.
         set_device_policy(gpu=False)