Mentions légales du service

Skip to content
Snippets Groups Projects
Commit e8afbe1b authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Merge branch 'hotfix-tf-gradient-names' into 'main'

Fix `TensorflowModel` gradients' labeling.

Closes #14

See merge request !26
parents 71554197 261ab379
No related branches found
No related tags found
1 merge request!26Fix `TensorflowModel` gradients' labeling
Pipeline #753973 waiting for manual action
...@@ -136,17 +136,48 @@ class TensorflowModel(Model): ...@@ -136,17 +136,48 @@ class TensorflowModel(Model):
raise TypeError( raise TypeError(
"TensorflowModel requires TensorflowVector weights." "TensorflowModel requires TensorflowVector weights."
) )
self._verify_weights_compatibility(weights, trainable_only=False)
variables = {var.name: var for var in self._model.weights} variables = {var.name: var for var in self._model.weights}
if set(weights.coefs).symmetric_difference(variables): for name, value in weights.coefs.items():
missing = set(variables).difference(weights.coefs) variables[name].assign(value, read_value=False)
unexpct = set(weights.coefs).difference(variables)
def _verify_weights_compatibility(
    self,
    vector: TensorflowVector,
    trainable_only: bool = False,
) -> None:
    """Check that a vector's keys match the model's weight names.

    Parameters
    ----------
    vector: TensorflowVector
        Vector wrapping weight-related coefficients (e.g. weight
        values or gradient-based updates).
    trainable_only: bool, default=False
        Whether to restrict the comparison to the model's trainable
        weights rather than to all of its weights.

    Raises
    ------
    KeyError:
        If some expected keys are missing from the inputs, or some
        additional keys are present. The error message is verbose
        about the identified mismatch(es).
    """
    # Select the reference variables and collect their names.
    source = (
        self._model.trainable_weights
        if trainable_only
        else self._model.weights
    )
    expected = {var.name for var in source}
    received = set(vector.coefs)
    # Fast path: names match exactly, nothing to report.
    if received == expected:
        return
    # Otherwise, raise a KeyError detailing each kind of mismatch.
    missing = expected - received
    unexpct = received - expected
    raise KeyError(
        "Mismatch between input and model weights' names:\n"
        + f"Missing key(s) in inputs: {missing}\n" * bool(missing)
        + f"Unexpected key(s) in inputs: {unexpct}\n" * bool(unexpct)
    )
for name, value in weights.coefs.items():
variables[name].assign(value, read_value=False)
def compute_batch_gradients( def compute_batch_gradients(
self, self,
...@@ -155,11 +186,14 @@ class TensorflowModel(Model): ...@@ -155,11 +186,14 @@ class TensorflowModel(Model):
) -> TensorflowVector: ) -> TensorflowVector:
data = self._unpack_batch(batch) data = self._unpack_batch(batch)
if max_norm is None: if max_norm is None:
grad = self._compute_batch_gradients(*data) grads = self._compute_batch_gradients(*data)
else: else:
norm = tf.constant(max_norm) norm = tf.constant(max_norm)
grad = self._compute_clipped_gradients(*data, norm) grads = self._compute_clipped_gradients(*data, norm)
return TensorflowVector({str(i): tns for i, tns in enumerate(grad)}) grads_and_vars = zip(grads, self._model.trainable_weights)
return TensorflowVector(
{var.name: grad for grad, var in grads_and_vars}
)
def _unpack_batch( def _unpack_batch(
self, self,
...@@ -237,6 +271,7 @@ class TensorflowModel(Model): ...@@ -237,6 +271,7 @@ class TensorflowModel(Model):
self, self,
updates: TensorflowVector, updates: TensorflowVector,
) -> None: ) -> None:
self._verify_weights_compatibility(updates, trainable_only=True)
# Delegate updates' application to a tensorflow Optimizer. # Delegate updates' application to a tensorflow Optimizer.
values = (-1 * updates).coefs.values() values = (-1 * updates).coefs.values()
zipped = zip(values, self._model.trainable_weights) zipped = zip(values, self._model.trainable_weights)
......
...@@ -91,6 +91,9 @@ class ModelTestSuite: ...@@ -91,6 +91,9 @@ class ModelTestSuite:
w_end = model.get_weights() w_end = model.get_weights()
assert w_srt == w_end assert w_srt == w_end
assert isinstance(grads, test_case.vector_cls) assert isinstance(grads, test_case.vector_cls)
# Check that gradients and weights share the same labeling.
# Note: some weights may not have gradients (if they are frozen).
assert set(grads.coefs).issubset(w_end.coefs)
def test_compute_batch_gradients_np( def test_compute_batch_gradients_np(
self, self,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment