Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 71debe68 authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Merge branch 'scaffold-update' into 'develop'

Change 'ScaffoldClient.collect_aux_var' behavior on unused module.

See merge request !31
parents 10d8a40d 622d2416
No related branches found
No related tags found
1 merge request!31Change 'ScaffoldClient.collect_aux_var' behavior on unused module.
Pipeline #763612 waiting for manual action
......@@ -31,6 +31,7 @@ References:
https://arxiv.org/abs/1910.06378
"""
import warnings
from typing import Any, ClassVar, Dict, List, Optional, Union
from declearn.model.api import Vector
......@@ -139,7 +140,6 @@ class ScaffoldClientModule(OptiModule):
where x are the shared model's weights, y_i are the local
model's weights after K optimization steps with eta_l lr,
c is the shared global state and c_i is the local state.
We add the notation: .
Noting that (x - y_i) is in fact the difference between the
local model's weights before and after running K training
......@@ -154,7 +154,12 @@ class ScaffoldClientModule(OptiModule):
gradients input to this module along the training steps.
"""
if not self._steps:
raise RuntimeError("Cannot collect on a module that was not run.")
warnings.warn(
"Collecting auxiliary variables from a scaffold module "
"that was not run. Returned zero-valued scalar updates.",
category=RuntimeWarning,
)
return 0.0
return self._grads / self._steps
def process_aux_var(
......@@ -304,6 +309,9 @@ class ScaffoldServerModule(OptiModule):
f"received by ScaffoldServerModule from client '{client}'."
)
state = c_dict["state"]
if isinstance(state, float) and state == 0.0:
# Drop info from clients that have not processed gradients.
continue
if isinstance(state, (Vector, float)):
s_new[client] = state
else:
......
......@@ -38,9 +38,10 @@ def test_scaffold_client(mock_gradients: Vector) -> None:
"""Conduct a series of co-dependent unit tests on ScaffoldClientModule."""
module = ScaffoldClientModule()
assert module.delta == 0.0
# Test that initial aux_var collection fails.
with pytest.raises(RuntimeError):
module.collect_aux_var()
# Test that initial aux_var collection raises a warning and returns 0.
with pytest.warns(RuntimeWarning):
aux_var = module.collect_aux_var()
assert aux_var == {"state": 0.0}
# Test run correctness (no correction at state 0).
output = module.run(mock_gradients)
assert output == mock_gradients
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment