From 622d24169613f9af9a98f9e3b15412e0e0b95368 Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Fri, 17 Feb 2023 12:24:04 +0100
Subject: [PATCH] Change 'ScaffoldClient.collect_aux_var' behavior on unused
 module.

* Until now, calling the `collect_aux_var` method of a `ScaffoldClient`
  module that had not processed any gradients raised a RuntimeError.
* This commit changes that behavior to something less stringent: emit a
  RuntimeWarning that such a call is unexpected, but return a
  conventional, zero-valued scalar state.
* Conversely, upon receiving such a state, `ScaffoldServer` ignores the
  client (i.e. it is not part of the global state's update calculation).
* The rationale behind this change is that a client may not perform any
  optimization step, e.g. due to its DP budget having been saturated.
  In that case, we do not want the FL process to crash, but rather the
  client to be ignored in the update process.
---
 declearn/optimizer/modules/_scaffold.py | 12 ++++++++++--
 test/optimizer/test_scaffold.py         |  7 ++++---
 2 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/declearn/optimizer/modules/_scaffold.py b/declearn/optimizer/modules/_scaffold.py
index 46cae6b4..185edf69 100644
--- a/declearn/optimizer/modules/_scaffold.py
+++ b/declearn/optimizer/modules/_scaffold.py
@@ -31,6 +31,7 @@ References:
     https://arxiv.org/abs/1910.06378
 """
 
+import warnings
 from typing import Any, ClassVar, Dict, List, Optional, Union
 
 from declearn.model.api import Vector
@@ -139,7 +140,6 @@ class ScaffoldClientModule(OptiModule):
         where x are the shared model's weights, y_i are the local
         model's weights after K optimization steps with eta_l lr,
         c is the shared global state and c_i is the local state.
-        We add the notation: .
 
         Noting that (x - y_i) is in fact the difference between the
         local model's weights before and after running K training
@@ -154,7 +154,12 @@ class ScaffoldClientModule(OptiModule):
         gradients input to this module along the training steps.
         """
         if not self._steps:
-            raise RuntimeError("Cannot collect on a module that was not run.")
+            warnings.warn(
+                "Collecting auxiliary variables from a scaffold module "
+                "that was not run. Returning zero-valued scalar updates.",
+                category=RuntimeWarning,
+            )
+            return 0.0
         return self._grads / self._steps
 
     def process_aux_var(
@@ -304,6 +309,9 @@ class ScaffoldServerModule(OptiModule):
                     f"received by ScaffoldServerModule from client '{client}'."
                 )
             state = c_dict["state"]
+            if isinstance(state, float) and state == 0.0:
+                # Drop info from clients that have not processed gradients.
+                continue
             if isinstance(state, (Vector, float)):
                 s_new[client] = state
             else:
diff --git a/test/optimizer/test_scaffold.py b/test/optimizer/test_scaffold.py
index a7ae284b..1f20596f 100644
--- a/test/optimizer/test_scaffold.py
+++ b/test/optimizer/test_scaffold.py
@@ -38,9 +38,10 @@ def test_scaffold_client(mock_gradients: Vector) -> None:
     """Conduct a series of co-dependent unit tests on ScaffoldClientModule."""
     module = ScaffoldClientModule()
     assert module.delta == 0.0
-    # Test that initial aux_var collection fails.
-    with pytest.raises(RuntimeError):
-        module.collect_aux_var()
+    # Test that initial aux_var collection raises a warning and returns 0.
+    with pytest.warns(RuntimeWarning):
+        aux_var = module.collect_aux_var()
+    assert aux_var == {"state": 0.0}
     # Test run correctness (no correction at state 0).
     output = module.run(mock_gradients)
     assert output == mock_gradients
-- 
GitLab
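
For illustration, a minimal caller-side sketch of the patched client
behavior, mirroring the updated unit test. The
`declearn.optimizer.modules` import path is assumed from the patched
file's location, and the warning-capture code is illustrative rather
than prescribed by the patch:

    import warnings

    from declearn.optimizer.modules import ScaffoldClientModule

    module = ScaffoldClientModule()

    # The module has processed no gradients yet: collect_aux_var() now
    # emits a RuntimeWarning and packages the conventional zero-valued
    # scalar state, instead of raising a RuntimeError.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        aux_var = module.collect_aux_var()

    assert aux_var == {"state": 0.0}
    assert any(issubclass(w.category, RuntimeWarning) for w in caught)

Per the server-side hunk above, a `ScaffoldServerModule` receiving this
zero-valued scalar state skips that client when recomputing the global
state, so an idle client no longer crashes the FL process.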