From 1fbecfacfed035d64d6fb316a9c8a62fc9553094 Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Thu, 23 Feb 2023 15:23:49 +0100 Subject: [PATCH] Add a generic functional test for `OptiModule.set_state`. * Add `test_set_state_results` to verify that resetting a modules' state enables running the same computation twice. This should enable detecting cases when information is missing from the returned state dict. * Note that at the moment noise-addition modules are ignored. We could look into a way to access and restore RNG states (when CSPRNG is not used). --- test/optimizer/test_modules.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/test/optimizer/test_modules.py b/test/optimizer/test_modules.py index ffd1e313..d6a1db0e 100644 --- a/test/optimizer/test_modules.py +++ b/test/optimizer/test_modules.py @@ -119,6 +119,25 @@ class TestOptiModule(PluginTestBase): module.set_state(states) assert module.get_state() == states + def test_set_state_results( + self, cls: Type[OptiModule], framework: FrameworkType + ) -> None: + """Test that an OptiModule's set_state yields deterministic results.""" + if issubclass(cls, NoiseModule): + pytest.skip("This test is mal-defined for RNG-based modules.") + # Run the module a first time. + module = cls() + test_case = GradientsTestCase(framework) + gradients = test_case.mock_gradient + module.run(gradients) + # Record states and results from a second run. + states = module.get_state() + result = module.run(gradients) + # Set up a new module from the state and verify its outputs. + module = cls() + module.set_state(states) + assert module.run(gradients) == result + def test_set_state_failure(self, cls: Type[OptiModule]) -> None: """Test that an OptiModule's set_state raises an excepted error.""" module = cls() -- GitLab