diff --git a/declearn/model/tensorflow/_model.py b/declearn/model/tensorflow/_model.py index 05e313b8cf306fca54fef740c0ef42b549c97375..8407f5aeba083522ab5120c123012e8e14e26bb5 100644 --- a/declearn/model/tensorflow/_model.py +++ b/declearn/model/tensorflow/_model.py @@ -136,17 +136,48 @@ class TensorflowModel(Model): raise TypeError( "TensorflowModel requires TensorflowVector weights." ) + self._verify_weights_compatibility(weights, trainable_only=False) variables = {var.name: var for var in self._model.weights} - if set(weights.coefs).symmetric_difference(variables): - missing = set(variables).difference(weights.coefs) - unexpct = set(weights.coefs).difference(variables) + for name, value in weights.coefs.items(): + variables[name].assign(value, read_value=False) + + def _verify_weights_compatibility( + self, + vector: TensorflowVector, + trainable_only: bool = False, + ) -> None: + """Verify that a vector has the same names as the model's weights. + + Parameters + ---------- + vector: TensorflowVector + Vector wrapping weight-related coefficients (e.g. weight + values or gradient-based updates). + trainable_only: bool, default=False + Whether to restrict the comparison to the model's trainable + weights rather than to all of its weights. + + Raises + ------ + KeyError: + In case some expected keys are missing, or additional keys + are present. Be verbose about the identified mismatch(es). + """ + # Gather the variables to compare to the input vector. + if trainable_only: + weights = self._model.trainable_weights + else: + weights = self._model.weights + variables = {var.name: var for var in weights} + # Raise a verbose KeyError in case inputs do not match weights. 
+ if set(vector.coefs).symmetric_difference(variables): + missing = set(variables).difference(vector.coefs) + unexpct = set(vector.coefs).difference(variables) raise KeyError( "Mismatch between input and model weights' names:\n" + f"Missing key(s) in inputs: {missing}\n" * bool(missing) + f"Unexpected key(s) in inputs: {unexpct}\n" * bool(unexpct) ) - for name, value in weights.coefs.items(): - variables[name].assign(value, read_value=False) def compute_batch_gradients( self, @@ -155,11 +186,14 @@ class TensorflowModel(Model): ) -> TensorflowVector: data = self._unpack_batch(batch) if max_norm is None: - grad = self._compute_batch_gradients(*data) + grads = self._compute_batch_gradients(*data) else: norm = tf.constant(max_norm) - grad = self._compute_clipped_gradients(*data, norm) - return TensorflowVector({str(i): tns for i, tns in enumerate(grad)}) + grads = self._compute_clipped_gradients(*data, norm) + grads_and_vars = zip(grads, self._model.trainable_weights) + return TensorflowVector( + {var.name: grad for grad, var in grads_and_vars} + ) def _unpack_batch( self, @@ -237,6 +271,7 @@ class TensorflowModel(Model): self, updates: TensorflowVector, ) -> None: + self._verify_weights_compatibility(updates, trainable_only=True) # Delegate updates' application to a tensorflow Optimizer. values = (-1 * updates).coefs.values() zipped = zip(values, self._model.trainable_weights) diff --git a/test/model/model_testing.py b/test/model/model_testing.py index 4751a18fc3ae9c7ecfd809d541286c147027030d..3d632814f678f506a81a25dc7f384141fcbf55df 100644 --- a/test/model/model_testing.py +++ b/test/model/model_testing.py @@ -91,6 +91,9 @@ class ModelTestSuite: w_end = model.get_weights() assert w_srt == w_end assert isinstance(grads, test_case.vector_cls) + # Check that gradients and weights share the same labeling. + # Note: some weights may not have gradients (if they are frozen). + assert set(grads.coefs).issubset(w_end.coefs) def test_compute_batch_gradients_np( self,