Fix `TensorflowModel` gradients' labeling

Merged ANDREY Paul requested to merge hotfix-tf-gradient-names into main
2 files +46 −8
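Note: before this fix, `compute_batch_gradients` labeled the output gradients by their enumeration index ("0", "1", ...), whereas model weights are collected and assigned under the wrapped variables' names. The fix keys gradients by `var.name`, so that all `TensorflowVector` instances exchanged by the model use the same labels. A minimal sketch of what those names look like in plain TensorFlow (the layer name is illustrative):

    import tensorflow as tf

    # Build a tiny keras model with a single, explicitly-named layer.
    model = tf.keras.Sequential(
        [tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(2, name="dense")]
    )
    # tf.keras names variables after their layer and role.
    print([var.name for var in model.trainable_weights])
    # e.g. ['dense/kernel:0', 'dense/bias:0']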
@@ -136,17 +136,48 @@ class TensorflowModel(Model):
             raise TypeError(
                 "TensorflowModel requires TensorflowVector weights."
             )
+        self._verify_weights_compatibility(weights, trainable_only=False)
         variables = {var.name: var for var in self._model.weights}
-        if set(weights.coefs).symmetric_difference(variables):
-            missing = set(variables).difference(weights.coefs)
-            unexpct = set(weights.coefs).difference(variables)
-            raise KeyError(
-                "Mismatch between input and model weights' names:\n"
-                + f"Missing key(s) in inputs: {missing}\n" * bool(missing)
-                + f"Unexpected key(s) in inputs: {unexpct}\n" * bool(unexpct)
-            )
         for name, value in weights.coefs.items():
             variables[name].assign(value, read_value=False)
+
+    def _verify_weights_compatibility(
+        self,
+        vector: TensorflowVector,
+        trainable_only: bool = False,
+    ) -> None:
+        """Verify that a vector has the same names as the model's weights.
+
+        Parameters
+        ----------
+        vector: TensorflowVector
+            Vector wrapping weight-related coefficients (e.g. weight
+            values or gradient-based updates).
+        trainable_only: bool, default=False
+            Whether to restrict the comparison to the model's trainable
+            weights rather than to all of its weights.
+
+        Raises
+        ------
+        KeyError:
+            In case some expected keys are missing, or additional keys
+            are present. Be verbose about the identified mismatch(es).
+        """
+        # Gather the variables to compare to the input vector.
+        if trainable_only:
+            weights = self._model.trainable_weights
+        else:
+            weights = self._model.weights
+        variables = {var.name: var for var in weights}
+        # Raise a verbose KeyError in case inputs do not match weights.
+        if set(vector.coefs).symmetric_difference(variables):
+            missing = set(variables).difference(vector.coefs)
+            unexpct = set(vector.coefs).difference(variables)
+            raise KeyError(
+                "Mismatch between input and model weights' names:\n"
+                + f"Missing key(s) in inputs: {missing}\n" * bool(missing)
+                + f"Unexpected key(s) in inputs: {unexpct}\n" * bool(unexpct)
+            )
 
     def compute_batch_gradients(
         self,
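Note: the new `_verify_weights_compatibility` helper centralizes the check that `set_weights` previously performed inline, and `apply_updates` now reuses it (last hunk below). A standalone sketch of the check, with hypothetical names standing in for the model's variables and a vector's `coefs`:

    # Hypothetical variable names, with one missing and one unexpected key.
    variables = {"dense/kernel:0": ..., "dense/bias:0": ...}
    coefs = {"dense/kernel:0": ..., "0": ...}
    # Raise a verbose KeyError on any key mismatch, as the helper does.
    if set(coefs).symmetric_difference(variables):
        missing = set(variables).difference(coefs)  # {'dense/bias:0'}
        unexpct = set(coefs).difference(variables)  # {'0'}
        raise KeyError(
            "Mismatch between input and model weights' names:\n"
            + f"Missing key(s) in inputs: {missing}\n" * bool(missing)
            + f"Unexpected key(s) in inputs: {unexpct}\n" * bool(unexpct)
        )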
@@ -155,11 +186,14 @@ class TensorflowModel(Model):
     ) -> TensorflowVector:
         data = self._unpack_batch(batch)
         if max_norm is None:
-            grad = self._compute_batch_gradients(*data)
+            grads = self._compute_batch_gradients(*data)
         else:
             norm = tf.constant(max_norm)
-            grad = self._compute_clipped_gradients(*data, norm)
-        return TensorflowVector({str(i): tns for i, tns in enumerate(grad)})
+            grads = self._compute_clipped_gradients(*data, norm)
+        grads_and_vars = zip(grads, self._model.trainable_weights)
+        return TensorflowVector(
+            {var.name: grad for grad, var in grads_and_vars}
+        )
 
     def _unpack_batch(
         self,
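Note: gradients are now returned under the same names as the trainable weights they relate to, which is what allows `apply_updates` to verify them (see below). The new labeling can be reproduced in plain TensorFlow, independently of declearn:

    import tensorflow as tf

    model = tf.keras.Sequential(
        [tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(2, name="dense")]
    )
    inputs = tf.random.normal((8, 4))
    # Compute mean-squared-output gradients over a mock batch.
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(inputs)))
    grads = tape.gradient(loss, model.trainable_weights)
    # Name-keyed gradients, as compute_batch_gradients now returns them.
    coefs = {
        var.name: grad
        for grad, var in zip(grads, model.trainable_weights)
    }
    print(sorted(coefs))  # e.g. ['dense/bias:0', 'dense/kernel:0']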
@@ -237,6 +271,7 @@ class TensorflowModel(Model):
         self,
         updates: TensorflowVector,
     ) -> None:
+        self._verify_weights_compatibility(updates, trainable_only=True)
         # Delegate updates' application to a tensorflow Optimizer.
         values = (-1 * updates).coefs.values()
         zipped = zip(values, self._model.trainable_weights)
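Note: the lines that follow (cut off in this hunk) delegate the updates' application to a tensorflow Optimizer, zipping the negated values with `trainable_weights` positionally; checking the names first ensures the input vector actually matches those weights. A sketch of that delegation in plain TensorFlow, assuming a keras SGD optimizer (model and update values are illustrative):

    import tensorflow as tf

    model = tf.keras.Sequential(
        [tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(2, name="dense")]
    )
    optim = tf.keras.optimizers.SGD(learning_rate=1.0)
    # Name-keyed updates, as _verify_weights_compatibility expects them.
    updates = {var.name: tf.ones_like(var) for var in model.trainable_weights}
    # Negate the updates so that the optimizer's subtraction adds them.
    values = [-1 * val for val in updates.values()]
    optim.apply_gradients(zip(values, model.trainable_weights))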