From 13b5c75792879db557e45e23900624c5ce6e17eb Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Wed, 26 Apr 2023 16:45:56 +0200
Subject: [PATCH] Add 'Model.get_wrapped_model' to the API.

---
 declearn/model/api/_model.py | 15 +++++++++++++++
 test/model/test_tflow.py     |  6 +++---
 test/model/test_torch.py     | 10 +++++-----
 3 files changed, 23 insertions(+), 8 deletions(-)

diff --git a/declearn/model/api/_model.py b/declearn/model/api/_model.py
index a3e6dfee..e74102bb 100644
--- a/declearn/model/api/_model.py
+++ b/declearn/model/api/_model.py
@@ -63,6 +63,21 @@ class Model(Generic[VectorT], metaclass=ABCMeta):
         """Instantiate a Model interface wrapping a 'model' object."""
         self._model = model
 
+    def get_wrapped_model(self) -> Any:
+        """Getter to access the wrapped framework-specific model object.
+
+        This getter should be used sparingly, so as to avoid undesirable
+        side effects. In particular, it should not be used in declearn
+        backend code (but may be in examples or tests), as it is merely
+        a way for end-users to access the wrapped model after training.
+
+        Returns
+        -------
+        model:
+            Wrapped model, of (framework/Model-subclass)-specific type.
+        """
+        return self._model
+
     @property
     @abstractmethod
     def device_policy(
diff --git a/test/model/test_tflow.py b/test/model/test_tflow.py
index a6da92e7..0387075f 100644
--- a/test/model/test_tflow.py
+++ b/test/model/test_tflow.py
@@ -183,7 +183,7 @@ class TestTensorflowModel(ModelTestSuite):
     ) -> None:
         """Check that `get_weights` behaves properly with frozen weights."""
         model = test_case.model
-        tfmod = getattr(model, "_model")  # type: tf.keras.Sequential
+        tfmod = model.get_wrapped_model()
         tfmod.layers[0].trainable = False  # freeze the first layer's weights
         w_all = model.get_weights()
         w_trn = model.get_weights(trainable=True)
@@ -198,7 +198,7 @@ class TestTensorflowModel(ModelTestSuite):
         """Check that `set_weights` behaves properly with frozen weights."""
         # Setup a model with some frozen weights, and gather trainable ones.
         model = test_case.model
-        tfmod = getattr(model, "_model")  # type: tf.keras.Sequential
+        tfmod = model.get_wrapped_model()
         tfmod.layers[0].trainable = False  # freeze the first layer's weights
         w_trn = model.get_weights(trainable=True)
         # Test that `set_weights` works if and only if properly parametrized.
@@ -217,7 +217,7 @@ class TestTensorflowModel(ModelTestSuite):
         policy = model.device_policy
         assert policy.gpu == (test_case.device == "GPU")
         assert policy.idx == 0
-        tfmod = getattr(model, "_model")
+        tfmod = model.get_wrapped_model()
         device = f"{test_case.device}:0"
         for var in tfmod.weights:
             assert var.device.endswith(device)
diff --git a/test/model/test_torch.py b/test/model/test_torch.py
index 67bdc5cc..dafe8c8c 100644
--- a/test/model/test_torch.py
+++ b/test/model/test_torch.py
@@ -236,8 +236,8 @@ class TestTorchModel(ModelTestSuite):
         # Verify that both models have the same device policy.
         assert model.device_policy == other.device_policy
         # Verify that both models have a similar structure of modules.
-        mod_a = list(getattr(model, "_model").modules())
-        mod_b = list(getattr(other, "_model").modules())
+        mod_a = list(model.get_wrapped_model().modules())
+        mod_b = list(other.get_wrapped_model().modules())
         assert len(mod_a) == len(mod_b)
         assert all(isinstance(a, type(b)) for a, b in zip(mod_a, mod_b))
         assert all(repr(a) == repr(b) for a, b in zip(mod_a, mod_b))
@@ -262,7 +262,7 @@ class TestTorchModel(ModelTestSuite):
     ) -> None:
         """Check that `get_weights` behaves properly with frozen weights."""
         model = test_case.model
-        ptmod = getattr(model, "_model")  # type: torch.nn.Module
+        ptmod = model.get_wrapped_model()
         next(ptmod.parameters()).requires_grad = False  # freeze some weights
         w_all = model.get_weights()
         w_trn = model.get_weights(trainable=True)
@@ -280,7 +280,7 @@ class TestTorchModel(ModelTestSuite):
         """Check that `set_weights` behaves properly with frozen weights."""
         # Setup a model with some frozen weights, and gather trainable ones.
         model = test_case.model
-        ptmod = getattr(model, "_model")  # type: torch.nn.Module
+        ptmod = model.get_wrapped_model()
         next(ptmod.parameters()).requires_grad = False  # freeze some weights
         w_trn = model.get_weights(trainable=True)
         # Test that `set_weights` works if and only if properly parametrized.
@@ -299,7 +299,7 @@ class TestTorchModel(ModelTestSuite):
         policy = model.device_policy
         assert policy.gpu == (test_case.device == "GPU")
         assert (policy.idx == 0) if policy.gpu else (policy.idx is None)
-        ptmod = getattr(model, "_model").module
+        ptmod = model.get_wrapped_model().module
         device_type = "cpu" if test_case.device == "CPU" else "cuda"
         for param in ptmod.parameters():
             assert param.device.type == device_type
-- 
GitLab