From 1e388de1aeae7a0c730858e946b911f9d36cffbf Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Fri, 10 Feb 2023 19:01:14 +0100 Subject: [PATCH] Warn about limited support for frozen neural network weights. --- declearn/model/tensorflow/_model.py | 13 +++++++++++++ declearn/model/torch/_model.py | 11 ++++++++++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/declearn/model/tensorflow/_model.py b/declearn/model/tensorflow/_model.py index 8407f5ae..fc1a9eb8 100644 --- a/declearn/model/tensorflow/_model.py +++ b/declearn/model/tensorflow/_model.py @@ -17,6 +17,7 @@ """Model subclass to wrap TensorFlow models.""" +import warnings from copy import deepcopy from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union @@ -98,6 +99,18 @@ class TensorflowModel(Model): if not self._model.built: data_info = aggregate_data_info([data_info], {"input_shape"}) self._model.build(data_info["input_shape"]) + # Warn about frozen weights. + # similar to TorchModel warning; pylint: disable=duplicate-code + if len(self._model.trainable_weights) < len(self._model.weights): + warnings.warn( + "'TensorflowModel' wraps a model with frozen weights.\n" + "This is not fully compatible with declearn v2.0.x: the " + "use of weight decay and/or of a loss-regularization " + "plug-in in an Optimizer will fail to produce updates " + "for this model.\n" + "This issue will be fixed in declearn v2.1.0." + ) + # pylint: enable=duplicate-code def get_config( self, diff --git a/declearn/model/torch/_model.py b/declearn/model/torch/_model.py index c6e021dd..488d7782 100644 --- a/declearn/model/torch/_model.py +++ b/declearn/model/torch/_model.py @@ -84,7 +84,16 @@ class TorchModel(Model): self, data_info: Dict[str, Any], ) -> None: - return None + # Warn about frozen weights. + if not all(p.requires_grad for p in self._model.parameters()): + warnings.warn( + "'TorchModel' wraps a model with frozen weights.\n" + "This is not fully compatible with declearn v2.0.x: the " + "use of weight decay and/or of a loss-regularization " + "plug-in in an Optimizer will fail to produce updates " + "for this model.\n" + "This issue will be fixed in declearn v2.1.0." + ) def get_config( self, -- GitLab