diff --git a/declearn/model/tensorflow/_model.py b/declearn/model/tensorflow/_model.py index 8407f5aeba083522ab5120c123012e8e14e26bb5..fc1a9eb88dc4d42f07314276b266d20749661318 100644 --- a/declearn/model/tensorflow/_model.py +++ b/declearn/model/tensorflow/_model.py @@ -17,6 +17,7 @@ """Model subclass to wrap TensorFlow models.""" +import warnings from copy import deepcopy from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union @@ -98,6 +99,18 @@ class TensorflowModel(Model): if not self._model.built: data_info = aggregate_data_info([data_info], {"input_shape"}) self._model.build(data_info["input_shape"]) + # Warn about frozen weights. + # similar to TorchModel warning; pylint: disable=duplicate-code + if len(self._model.trainable_weights) < len(self._model.weights): + warnings.warn( + "'TensorflowModel' wraps a model with frozen weights.\n" + "This is not fully compatible with declearn v2.0.x: the " + "use of weight decay and/or of a loss-regularization " + "plug-in in an Optimizer will fail to produce updates " + "for this model.\n" + "This issue will be fixed in declearn v2.1.0." + ) + # pylint: enable=duplicate-code def get_config( self, diff --git a/declearn/model/torch/_model.py b/declearn/model/torch/_model.py index c6e021dd68c96f5c70a53d805d411eadb271fc14..488d7782be84ffc9ef38ec9cee4d38bf31ff6d37 100644 --- a/declearn/model/torch/_model.py +++ b/declearn/model/torch/_model.py @@ -84,7 +84,16 @@ class TorchModel(Model): self, data_info: Dict[str, Any], ) -> None: - return None + # Warn about frozen weights. + if not all(p.requires_grad for p in self._model.parameters()): + warnings.warn( + "'TorchModel' wraps a model with frozen weights.\n" + "This is not fully compatible with declearn v2.0.x: the " + "use of weight decay and/or of a loss-regularization " + "plug-in in an Optimizer will fail to produce updates " + "for this model.\n" + "This issue will be fixed in declearn v2.1.0." + ) def get_config( self,