diff --git a/test/optimizer/test_modules.py b/test/optimizer/test_modules.py index ffd1e3132ffb64691b04ebdd03cdad4bd5d4e355..d6a1db0e28f89cfe78fe4bf792b12cabb24c1d7a 100644 --- a/test/optimizer/test_modules.py +++ b/test/optimizer/test_modules.py @@ -119,6 +119,25 @@ class TestOptiModule(PluginTestBase): module.set_state(states) assert module.get_state() == states + def test_set_state_results( + self, cls: Type[OptiModule], framework: FrameworkType + ) -> None: + """Test that an OptiModule's set_state yields deterministic results.""" + if issubclass(cls, NoiseModule): + pytest.skip("This test is mal-defined for RNG-based modules.") + # Run the module a first time. + module = cls() + test_case = GradientsTestCase(framework) + gradients = test_case.mock_gradient + module.run(gradients) + # Record states and results from a second run. + states = module.get_state() + result = module.run(gradients) + # Set up a new module from the state and verify its outputs. + module = cls() + module.set_state(states) + assert module.run(gradients) == result + def test_set_state_failure(self, cls: Type[OptiModule]) -> None: """Test that an OptiModule's set_state raises an excepted error.""" module = cls()