diff --git a/declearn/optimizer/modules/_scaffold.py b/declearn/optimizer/modules/_scaffold.py
index 46cae6b4c6affccdd690b02da81c3d982f6445c6..185edf69fef59e0266efdd0a050c6064828c1224 100644
--- a/declearn/optimizer/modules/_scaffold.py
+++ b/declearn/optimizer/modules/_scaffold.py
@@ -31,6 +31,7 @@ References:
     https://arxiv.org/abs/1910.06378
 """
 
+import warnings
 from typing import Any, ClassVar, Dict, List, Optional, Union
 
 from declearn.model.api import Vector
@@ -139,7 +140,6 @@ class ScaffoldClientModule(OptiModule):
         where x are the shared model's weights, y_i are the local
         model's weights after K optimization steps with eta_l lr,
         c is the shared global state and c_i is the local state.
-        We add the notation: .
 
         Noting that (x - y_i) is in fact the difference between the
         local model's weights before and after running K training
@@ -154,7 +154,12 @@
         gradients input to this module along the training steps.
         """
         if not self._steps:
-            raise RuntimeError("Cannot collect on a module that was not run.")
+            warnings.warn(
+                "Collecting auxiliary variables from a scaffold module "
+                "that was not run. Returned zero-valued scalar updates.",
+                category=RuntimeWarning,
+            )
+            return 0.0
         return self._grads / self._steps
 
     def process_aux_var(
@@ -304,6 +309,9 @@ class ScaffoldServerModule(OptiModule):
                     f"received by ScaffoldServerModule from client '{client}'."
                 )
             state = c_dict["state"]
+            if isinstance(state, float) and state == 0.0:
+                # Drop info from clients that have not processed gradients.
+                continue
             if isinstance(state, (Vector, float)):
                 s_new[client] = state
             else:
diff --git a/test/optimizer/test_scaffold.py b/test/optimizer/test_scaffold.py
index a7ae284b1c2f2be675a81d88d1895628a78be883..1f20596f1c8bcf97a387752d2b5b6988186de346 100644
--- a/test/optimizer/test_scaffold.py
+++ b/test/optimizer/test_scaffold.py
@@ -38,9 +38,10 @@ def test_scaffold_client(mock_gradients: Vector) -> None:
     """Conduct a series of co-dependent unit tests on ScaffoldClientModule."""
     module = ScaffoldClientModule()
     assert module.delta == 0.0
-    # Test that initial aux_var collection fails.
-    with pytest.raises(RuntimeError):
-        module.collect_aux_var()
+    # Test that initial aux_var collection raises a warning and returns 0.
+    with pytest.warns(RuntimeWarning):
+        aux_var = module.collect_aux_var()
+    assert aux_var == {"state": 0.0}
     # Test run correctness (no correction at state 0).
     output = module.run(mock_gradients)
     assert output == mock_gradients
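
For context, here is a minimal end-to-end sketch of the patched behaviour: a client whose module never ran now warns and emits a zero-valued scalar state, which the server-side module silently drops. This assumes ScaffoldClientModule and ScaffoldServerModule are importable from declearn.optimizer.modules and that the server's process_aux_var receives a {client_name: aux_dict} mapping, as the server hunk above suggests; the "idle-client" label is illustrative.

    # Sketch only, not part of the patch; see assumptions above.
    import warnings

    from declearn.optimizer.modules import (
        ScaffoldClientModule,
        ScaffoldServerModule,
    )

    client = ScaffoldClientModule()

    # Collecting before any run() call now warns and yields a zero scalar
    # state, instead of raising a RuntimeError as before this patch.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        aux_var = client.collect_aux_var()
    assert aux_var == {"state": 0.0}
    assert any(issubclass(w.category, RuntimeWarning) for w in caught)

    # Server side: the zero-valued scalar state is skipped, so an idle
    # client no longer contributes to the aggregated Scaffold state.
    server = ScaffoldServerModule()
    server.process_aux_var({"idle-client": aux_var})

The float-equality check in the server hunk only matches the exact 0.0 sentinel returned by an unrun client module, so genuine Vector-valued (or non-zero float) states are still aggregated as before.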