Mentions légales du service

Skip to content
Snippets Groups Projects

Fix `TensorflowModel` gradients' labeling

Merged ANDREY Paul requested to merge hotfix-tf-gradient-names into main
1 file
+ 43
− 8
Compare changes
  • Side-by-side
  • Inline
@@ -136,17 +136,48 @@ class TensorflowModel(Model):
@@ -136,17 +136,48 @@ class TensorflowModel(Model):
raise TypeError(
raise TypeError(
"TensorflowModel requires TensorflowVector weights."
"TensorflowModel requires TensorflowVector weights."
)
)
 
self._verify_weights_compatibility(weights, trainable_only=False)
variables = {var.name: var for var in self._model.weights}
variables = {var.name: var for var in self._model.weights}
if set(weights.coefs).symmetric_difference(variables):
for name, value in weights.coefs.items():
missing = set(variables).difference(weights.coefs)
variables[name].assign(value, read_value=False)
unexpct = set(weights.coefs).difference(variables)
 
def _verify_weights_compatibility(
 
self,
 
vector: TensorflowVector,
 
trainable_only: bool = False,
 
) -> None:
 
"""Verify that a vector has the same names as the model's weights.
 
 
Parameters
 
----------
 
vector: TensorflowVector
 
Vector wrapping weight-related coefficients (e.g. weight
 
values or gradient-based updates).
 
trainable_only: bool, default=False
 
Whether to restrict the comparison to the model's trainable
 
weights rather than to all of its weights.
 
 
Raises
 
------
 
KeyError:
 
In case some expected keys are missing, or additional keys
 
are present. Be verbose about the identified mismatch(es).
 
"""
 
# Gather the variables to compare to the input vector.
 
if trainable_only:
 
weights = self._model.trainable_weights
 
else:
 
weights = self._model.weights
 
variables = {var.name: var for var in weights}
 
# Raise a verbose KeyError in case inputs do not match weights.
 
if set(vector.coefs).symmetric_difference(variables):
 
missing = set(variables).difference(vector.coefs)
 
unexpct = set(vector.coefs).difference(variables)
raise KeyError(
raise KeyError(
"Mismatch between input and model weights' names:\n"
"Mismatch between input and model weights' names:\n"
+ f"Missing key(s) in inputs: {missing}\n" * bool(missing)
+ f"Missing key(s) in inputs: {missing}\n" * bool(missing)
+ f"Unexpected key(s) in inputs: {unexpct}\n" * bool(unexpct)
+ f"Unexpected key(s) in inputs: {unexpct}\n" * bool(unexpct)
)
)
for name, value in weights.coefs.items():
variables[name].assign(value, read_value=False)
def compute_batch_gradients(
def compute_batch_gradients(
self,
self,
@@ -155,11 +186,14 @@ class TensorflowModel(Model):
@@ -155,11 +186,14 @@ class TensorflowModel(Model):
) -> TensorflowVector:
) -> TensorflowVector:
data = self._unpack_batch(batch)
data = self._unpack_batch(batch)
if max_norm is None:
if max_norm is None:
grad = self._compute_batch_gradients(*data)
grads = self._compute_batch_gradients(*data)
else:
else:
norm = tf.constant(max_norm)
norm = tf.constant(max_norm)
grad = self._compute_clipped_gradients(*data, norm)
grads = self._compute_clipped_gradients(*data, norm)
return TensorflowVector({str(i): tns for i, tns in enumerate(grad)})
grads_and_vars = zip(grads, self._model.trainable_weights)
 
return TensorflowVector(
 
{var.name: grad for grad, var in grads_and_vars}
 
)
def _unpack_batch(
def _unpack_batch(
self,
self,
@@ -237,6 +271,7 @@ class TensorflowModel(Model):
@@ -237,6 +271,7 @@ class TensorflowModel(Model):
self,
self,
updates: TensorflowVector,
updates: TensorflowVector,
) -> None:
) -> None:
 
self._verify_weights_compatibility(updates, trainable_only=True)
# Delegate updates' application to a tensorflow Optimizer.
# Delegate updates' application to a tensorflow Optimizer.
values = (-1 * updates).coefs.values()
values = (-1 * updates).coefs.values()
zipped = zip(values, self._model.trainable_weights)
zipped = zip(values, self._model.trainable_weights)
Loading