From 1fbecfacfed035d64d6fb316a9c8a62fc9553094 Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Thu, 23 Feb 2023 15:23:49 +0100
Subject: [PATCH] Add a generic functional test for `OptiModule.set_state`.

* Add `test_set_state_results` to verify that resetting a modules' state
  enables running the same computation twice. This should enable detecting
  cases when information is missing from the returned state dict.

* Note that at the moment noise-addition modules are ignored. We could look
  into a way to access and restore RNG states (when CSPRNG is not used).
---
 test/optimizer/test_modules.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/test/optimizer/test_modules.py b/test/optimizer/test_modules.py
index ffd1e313..d6a1db0e 100644
--- a/test/optimizer/test_modules.py
+++ b/test/optimizer/test_modules.py
@@ -119,6 +119,25 @@ class TestOptiModule(PluginTestBase):
             module.set_state(states)
             assert module.get_state() == states
 
+    def test_set_state_results(
+        self, cls: Type[OptiModule], framework: FrameworkType
+    ) -> None:
+        """Test that an OptiModule's set_state yields deterministic results."""
+        if issubclass(cls, NoiseModule):
+            pytest.skip("This test is mal-defined for RNG-based modules.")
+        # Run the module a first time.
+        module = cls()
+        test_case = GradientsTestCase(framework)
+        gradients = test_case.mock_gradient
+        module.run(gradients)
+        # Record states and results from a second run.
+        states = module.get_state()
+        result = module.run(gradients)
+        # Set up a new module from the state and verify its outputs.
+        module = cls()
+        module.set_state(states)
+        assert module.run(gradients) == result
+
     def test_set_state_failure(self, cls: Type[OptiModule]) -> None:
         """Test that an OptiModule's set_state raises an excepted error."""
         module = cls()
-- 
GitLab