From 622d24169613f9af9a98f9e3b15412e0e0b95368 Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Fri, 17 Feb 2023 12:24:04 +0100
Subject: [PATCH] Change 'ScaffoldClient.collect_aux_var' behavior on unused
 module.

* Until now, calling the `collect_aux_var` method of a `ScaffoldClient`
  module that had not processed any gradients raised a RuntimeError.
* This commit changes that behavior to something less stringent: emit a
  RuntimeWarning that such a call is unexpected, but return a
  conventional, zero-valued scalar state.
* Conversely, upon receiving such a state, `ScaffoldServer` ignores the
  client (i.e. it is not part of the global state's update calculation).
* The rationale behind this change is that a client may not perform any
  optimization step, e.g. due to its DP budget having been saturated.
  In that case, we do not want the FL process to crash, but rather the
  client to be ignored in the update process.
---
 declearn/optimizer/modules/_scaffold.py | 12 ++++++++++--
 test/optimizer/test_scaffold.py         |  7 ++++---
 2 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/declearn/optimizer/modules/_scaffold.py b/declearn/optimizer/modules/_scaffold.py
index 46cae6b4..185edf69 100644
--- a/declearn/optimizer/modules/_scaffold.py
+++ b/declearn/optimizer/modules/_scaffold.py
@@ -31,6 +31,7 @@ References:
     https://arxiv.org/abs/1910.06378
 """
 
+import warnings
 from typing import Any, ClassVar, Dict, List, Optional, Union
 
 from declearn.model.api import Vector
@@ -139,7 +140,6 @@ class ScaffoldClientModule(OptiModule):
         where x are the shared model's weights, y_i are the local
         model's weights after K optimization steps with eta_l lr,
         c is the shared global state and c_i is the local state.
-        We add the notation: .
 
         Noting that (x - y_i) is in fact the difference between the
         local model's weights before and after running K training
@@ -154,7 +154,12 @@ class ScaffoldClientModule(OptiModule):
         gradients input to this module along the training steps.
         """
         if not self._steps:
-            raise RuntimeError("Cannot collect on a module that was not run.")
+            warnings.warn(
+                "Collecting auxiliary variables from a scaffold module "
+                "that was not run. Returning zero-valued scalar updates.",
+                category=RuntimeWarning,
+            )
+            return 0.0
         return self._grads / self._steps
 
     def process_aux_var(
@@ -304,6 +309,9 @@ class ScaffoldServerModule(OptiModule):
                     f"received by ScaffoldServerModule from client '{client}'."
                 )
             state = c_dict["state"]
+            if isinstance(state, float) and state == 0.0:
+                # Drop info from clients that have not processed gradients.
+                continue
             if isinstance(state, (Vector, float)):
                 s_new[client] = state
             else:
diff --git a/test/optimizer/test_scaffold.py b/test/optimizer/test_scaffold.py
index a7ae284b..1f20596f 100644
--- a/test/optimizer/test_scaffold.py
+++ b/test/optimizer/test_scaffold.py
@@ -38,9 +38,10 @@ def test_scaffold_client(mock_gradients: Vector) -> None:
     """Conduct a series of co-dependent unit tests on ScaffoldClientModule."""
     module = ScaffoldClientModule()
     assert module.delta == 0.0
-    # Test that initial aux_var collection fails.
-    with pytest.raises(RuntimeError):
-        module.collect_aux_var()
+    # Test that initial aux_var collection raises a warning and returns 0.
+    with pytest.warns(RuntimeWarning):
+        aux_var = module.collect_aux_var()
+    assert aux_var == {"state": 0.0}
     # Test run correctness (no correction at state 0).
     output = module.run(mock_gradients)
     assert output == mock_gradients
-- 
GitLab
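
For illustration, a minimal caller-side sketch of the patched client
behavior, mirroring the updated unit test. The
`declearn.optimizer.modules` import path is assumed from the patched
file's location, and the warning-capture code is illustrative rather
than prescribed by the patch:

    import warnings

    from declearn.optimizer.modules import ScaffoldClientModule

    module = ScaffoldClientModule()

    # The module has processed no gradients yet: collect_aux_var() now
    # emits a RuntimeWarning and packages the conventional zero-valued
    # scalar state, instead of raising a RuntimeError.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        aux_var = module.collect_aux_var()

    assert aux_var == {"state": 0.0}
    assert any(issubclass(w.category, RuntimeWarning) for w in caught)

Per the server-side hunk above, a `ScaffoldServerModule` receiving this
zero-valued scalar state skips that client when recomputing the global
state, so an idle client no longer crashes the FL process.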