From 13b5c75792879db557e45e23900624c5ce6e17eb Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Wed, 26 Apr 2023 16:45:56 +0200
Subject: [PATCH] Add 'Model.get_wrapped_model' to the API.

---
 declearn/model/api/_model.py | 15 +++++++++++++++
 test/model/test_tflow.py     |  6 +++---
 test/model/test_torch.py     | 10 +++++-----
 3 files changed, 23 insertions(+), 8 deletions(-)

diff --git a/declearn/model/api/_model.py b/declearn/model/api/_model.py
index a3e6dfee..e74102bb 100644
--- a/declearn/model/api/_model.py
+++ b/declearn/model/api/_model.py
@@ -63,6 +63,21 @@ class Model(Generic[VectorT], metaclass=ABCMeta):
         """Instantiate a Model interface wrapping a 'model' object."""
         self._model = model
 
+    def get_wrapped_model(self) -> Any:
+        """Getter to access the wrapped framework-specific model object.
+
+        This getter should be used sparingly, so as to avoid undesirable
+        side effects. In particular, it should not be used in declearn
+        backend code (but may be in examples or tests), as it is merely
+        a way for end-users to access the wrapped model after training.
+
+        Returns
+        -------
+        model:
+            Wrapped model, of (framework/Model-subclass)-specific type.
+        """
+        return self._model
+
     @property
     @abstractmethod
     def device_policy(
diff --git a/test/model/test_tflow.py b/test/model/test_tflow.py
index a6da92e7..0387075f 100644
--- a/test/model/test_tflow.py
+++ b/test/model/test_tflow.py
@@ -183,7 +183,7 @@ class TestTensorflowModel(ModelTestSuite):
     ) -> None:
         """Check that `get_weights` behaves properly with frozen weights."""
         model = test_case.model
-        tfmod = getattr(model, "_model")  # type: tf.keras.Sequential
+        tfmod = model.get_wrapped_model()
         tfmod.layers[0].trainable = False  # freeze the first layer's weights
         w_all = model.get_weights()
         w_trn = model.get_weights(trainable=True)
@@ -198,7 +198,7 @@ class TestTensorflowModel(ModelTestSuite):
         """Check that `set_weights` behaves properly with frozen weights."""
         # Setup a model with some frozen weights, and gather trainable ones.
         model = test_case.model
-        tfmod = getattr(model, "_model")  # type: tf.keras.Sequential
+        tfmod = model.get_wrapped_model()
         tfmod.layers[0].trainable = False  # freeze the first layer's weights
         w_trn = model.get_weights(trainable=True)
         # Test that `set_weights` works if and only if properly parametrized.
@@ -217,7 +217,7 @@ class TestTensorflowModel(ModelTestSuite):
         policy = model.device_policy
         assert policy.gpu == (test_case.device == "GPU")
         assert policy.idx == 0
-        tfmod = getattr(model, "_model")
+        tfmod = model.get_wrapped_model()
         device = f"{test_case.device}:0"
         for var in tfmod.weights:
             assert var.device.endswith(device)
diff --git a/test/model/test_torch.py b/test/model/test_torch.py
index 67bdc5cc..dafe8c8c 100644
--- a/test/model/test_torch.py
+++ b/test/model/test_torch.py
@@ -236,8 +236,8 @@ class TestTorchModel(ModelTestSuite):
         # Verify that both models have the same device policy.
         assert model.device_policy == other.device_policy
         # Verify that both models have a similar structure of modules.
-        mod_a = list(getattr(model, "_model").modules())
-        mod_b = list(getattr(other, "_model").modules())
+        mod_a = list(model.get_wrapped_model().modules())
+        mod_b = list(other.get_wrapped_model().modules())
         assert len(mod_a) == len(mod_b)
         assert all(isinstance(a, type(b)) for a, b in zip(mod_a, mod_b))
         assert all(repr(a) == repr(b) for a, b in zip(mod_a, mod_b))
@@ -262,7 +262,7 @@ class TestTorchModel(ModelTestSuite):
     ) -> None:
         """Check that `get_weights` behaves properly with frozen weights."""
         model = test_case.model
-        ptmod = getattr(model, "_model")  # type: torch.nn.Module
+        ptmod = model.get_wrapped_model()
         next(ptmod.parameters()).requires_grad = False  # freeze some weights
         w_all = model.get_weights()
         w_trn = model.get_weights(trainable=True)
@@ -280,7 +280,7 @@ class TestTorchModel(ModelTestSuite):
         """Check that `set_weights` behaves properly with frozen weights."""
         # Setup a model with some frozen weights, and gather trainable ones.
         model = test_case.model
-        ptmod = getattr(model, "_model")  # type: torch.nn.Module
+        ptmod = model.get_wrapped_model()
         next(ptmod.parameters()).requires_grad = False  # freeze some weights
         w_trn = model.get_weights(trainable=True)
         # Test that `set_weights` works if and only if properly parametrized.
@@ -299,7 +299,7 @@ class TestTorchModel(ModelTestSuite):
         policy = model.device_policy
         assert policy.gpu == (test_case.device == "GPU")
         assert (policy.idx == 0) if policy.gpu else (policy.idx is None)
-        ptmod = getattr(model, "_model").module
+        ptmod = model.get_wrapped_model().module
         device_type = "cpu" if test_case.device == "CPU" else "cuda"
         for param in ptmod.parameters():
             assert param.device.type == device_type
--
GitLab
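For reference, a minimal end-user sketch of the new accessor, as opposed to the
private-attribute access the tests used to rely on. It assumes declearn's
TensorflowModel wrapper with a (model, loss) constructor; the network
architecture and save path below are purely illustrative:

    import tensorflow as tf
    from declearn.model.tensorflow import TensorflowModel

    # Wrap a Keras network in declearn's Model interface.
    network = tf.keras.Sequential([
        tf.keras.layers.Dense(32, activation="relu", input_shape=(8,)),
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ])
    model = TensorflowModel(network, loss="binary_crossentropy")

    # ... federated training orchestrated by declearn ...

    # Recover the raw Keras object afterwards, rather than reaching into
    # the private `_model` attribute, and use framework-native tooling.
    keras_model = model.get_wrapped_model()
    keras_model.save("trained-model")  # illustrative path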