diff --git a/declearn/model/api/_model.py b/declearn/model/api/_model.py index a3e6dfee994f0a15ae9c289d049b73690aacfd1f..e74102bbdfdc1c174f633c08eaf6ea23b62e949d 100644 --- a/declearn/model/api/_model.py +++ b/declearn/model/api/_model.py @@ -63,6 +63,21 @@ class Model(Generic[VectorT], metaclass=ABCMeta): """Instantiate a Model interface wrapping a 'model' object.""" self._model = model + def get_wrapped_model(self) -> Any: + """Getter to access the wrapped framework-specific model object. + + This getter should be used sparingly, so as to avoid undesirable + side effects. In particular, it should not be used in declearn + backend code (but may be in examples or tests), as it is merely + a way for end-users to access the wrapped model after training. + + Returns + ------- + model: + Wrapped model, of (framework/Model-subclass)-specific type. + """ + return self._model + @property @abstractmethod def device_policy( diff --git a/test/model/test_tflow.py b/test/model/test_tflow.py index a6da92e767b7db9760357a00777c46116d13e577..0387075f66dab9bad3635d033e76775c56666624 100644 --- a/test/model/test_tflow.py +++ b/test/model/test_tflow.py @@ -183,7 +183,7 @@ class TestTensorflowModel(ModelTestSuite): ) -> None: """Check that `get_weights` behaves properly with frozen weights.""" model = test_case.model - tfmod = getattr(model, "_model") # type: tf.keras.Sequential + tfmod = model.get_wrapped_model() tfmod.layers[0].trainable = False # freeze the first layer's weights w_all = model.get_weights() w_trn = model.get_weights(trainable=True) @@ -198,7 +198,7 @@ class TestTensorflowModel(ModelTestSuite): """Check that `set_weights` behaves properly with frozen weights.""" # Setup a model with some frozen weights, and gather trainable ones. model = test_case.model - tfmod = getattr(model, "_model") # type: tf.keras.Sequential + tfmod = model.get_wrapped_model() tfmod.layers[0].trainable = False # freeze the first layer's weights w_trn = model.get_weights(trainable=True) # Test that `set_weights` works if and only if properly parametrized. @@ -217,7 +217,7 @@ class TestTensorflowModel(ModelTestSuite): policy = model.device_policy assert policy.gpu == (test_case.device == "GPU") assert policy.idx == 0 - tfmod = getattr(model, "_model") + tfmod = model.get_wrapped_model() device = f"{test_case.device}:0" for var in tfmod.weights: assert var.device.endswith(device) diff --git a/test/model/test_torch.py b/test/model/test_torch.py index 67bdc5cce547fcd98cf7b64a18fed9d81ddfb8d5..dafe8c8cb66b1fbfe12a7201d4bb35a0c4eb23b5 100644 --- a/test/model/test_torch.py +++ b/test/model/test_torch.py @@ -236,8 +236,8 @@ class TestTorchModel(ModelTestSuite): # Verify that both models have the same device policy. assert model.device_policy == other.device_policy # Verify that both models have a similar structure of modules. - mod_a = list(getattr(model, "_model").modules()) - mod_b = list(getattr(other, "_model").modules()) + mod_a = list(model.get_wrapped_model().modules()) + mod_b = list(other.get_wrapped_model().modules()) assert len(mod_a) == len(mod_b) assert all(isinstance(a, type(b)) for a, b in zip(mod_a, mod_b)) assert all(repr(a) == repr(b) for a, b in zip(mod_a, mod_b)) @@ -262,7 +262,7 @@ class TestTorchModel(ModelTestSuite): ) -> None: """Check that `get_weights` behaves properly with frozen weights.""" model = test_case.model - ptmod = getattr(model, "_model") # type: torch.nn.Module + ptmod = model.get_wrapped_model() next(ptmod.parameters()).requires_grad = False # freeze some weights w_all = model.get_weights() w_trn = model.get_weights(trainable=True) @@ -280,7 +280,7 @@ class TestTorchModel(ModelTestSuite): """Check that `set_weights` behaves properly with frozen weights.""" # Setup a model with some frozen weights, and gather trainable ones. model = test_case.model - ptmod = getattr(model, "_model") # type: torch.nn.Module + ptmod = model.get_wrapped_model() next(ptmod.parameters()).requires_grad = False # freeze some weights w_trn = model.get_weights(trainable=True) # Test that `set_weights` works if and only if properly parametrized. @@ -299,7 +299,7 @@ class TestTorchModel(ModelTestSuite): policy = model.device_policy assert policy.gpu == (test_case.device == "GPU") assert (policy.idx == 0) if policy.gpu else (policy.idx is None) - ptmod = getattr(model, "_model").module + ptmod = model.get_wrapped_model().module device_type = "cpu" if test_case.device == "CPU" else "cuda" for param in ptmod.parameters(): assert param.device.type == device_type