Mentions légales du service

Skip to content
Snippets Groups Projects
Verified Commit 5e7a4972 authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Fix (and change) the initialization phase messaging sequence.

- Always have clients send back an 'InitReply' after the first
  'TrainingManager' setup step happened without error.
- Then, rely on step-specific query-reply exchanges for further
  steps (SecAgg, DP-SGD and/or Fairness setup).
parent 2d355073
No related branches found
No related tags found
1 merge request!69Enable Fairness-Aware Federated Learning
......@@ -346,15 +346,15 @@ class FederatedClient:
except Exception as exc:
await self.netwk.send_message(messaging.Error(repr(exc)))
raise RuntimeError("Initialization failed.") from exc
# Send back an empty message to indicate that things went fine.
self.logger.info("Notifying the server that initialization went fine.")
await self.netwk.send_message(messaging.InitReply())
# If instructed to do so, run additional steps to set up DP-SGD.
if message.dpsgd:
await self._initialize_dpsgd()
# If instructed to do so, run additional steps to enforce fairness.
if message.fairness:
await self._initialize_fairness()
# Send back an empty message to indicate that all went fine.
self.logger.info("Notifying the server that initialization went fine.")
await self.netwk.send_message(messaging.InitReply())
# Optionally checkpoint the received model and optimizer.
if self.ckptr:
self.ckptr.checkpoint(
......@@ -396,7 +396,7 @@ class FederatedClient:
)
except Exception as exc:
raise RuntimeError("DP-SGD initialization failed.") from exc
self.logger.info("Received a request to set up DP-SGD.")
self.logger.info("Received DP-SGD setup instructions.")
try:
self.make_private(message)
except Exception as exc: # pylint: disable=broad-except
......@@ -469,6 +469,7 @@ class FederatedClient:
error = f"Fairness initialization failed: {repr(exc)}."
self.logger.critical(error)
raise RuntimeError(error) from exc
self.logger.info("Received fairness setup instructions.")
# Instantiate a FairnessControllerClient and run its setup routine.
try:
self.fairness = FairnessControllerClient.from_setup_query(
......
......@@ -559,12 +559,12 @@ class TestFederatedClientInitialize:
with patch_class_constructor(DPTrainingManager) as patch_dp:
with patch_class_constructor(TrainingManager) as patch_tm:
await client.initialize()
# Assert that a PrivacyReply and InitReply were sent to the server.
# Assert that an InitReply and a PrivacyReply were sent to the server.
assert netwk.send_message.call_count == 2
reply = netwk.send_message.call_args_list[0].args[0]
assert isinstance(reply, messaging.PrivacyReply)
reply = netwk.send_message.call_args_list[1].args[0]
assert isinstance(reply, messaging.InitReply)
reply = netwk.send_message.call_args_list[1].args[0]
assert isinstance(reply, messaging.PrivacyReply)
# Assert that a DPTrainingManager was set up.
patch_tm.assert_called_once()
patch_dp.assert_called_once_with(
......@@ -599,10 +599,13 @@ class TestFederatedClientInitialize:
with patch_class_constructor(TrainingManager) as patch_tm:
with pytest.raises(RuntimeError):
await client.initialize()
# Assert that two messages were fetched, and an error was sent.
# Assert that two messages were fetched, that first step went well
# (resulting in an InitReply) and then an Error was sent.
assert netwk.recv_message.call_count == 2
netwk.send_message.assert_called_once()
reply = netwk.send_message.call_args.args[0]
assert netwk.send_message.call_count == 2
reply = netwk.send_message.call_args_list[0].args[0]
assert isinstance(reply, messaging.InitReply)
reply = netwk.send_message.call_args_list[1].args[0]
assert isinstance(reply, messaging.Error)
# Assert that the initial TrainingManager was set, but not the DP one.
patch_tm.assert_called_once()
......@@ -631,10 +634,13 @@ class TestFederatedClientInitialize:
# Assert that TrainingManager was instantiated and DP one was called.
patch_tm.assert_called_once()
patch_dp.assert_called_once()
# Assert that both messages were fetched, and an error was sent.
# Assert that both messages were fetched, and an error was sent
# after the DP-SGD setup failed.
assert netwk.recv_message.call_count == 2
netwk.send_message.assert_called_once()
reply = netwk.send_message.call_args.args[0]
assert netwk.send_message.call_count == 2
reply = netwk.send_message.call_args_list[0].args[0]
assert isinstance(reply, messaging.InitReply)
reply = netwk.send_message.call_args_list[1].args[0]
assert isinstance(reply, messaging.Error)
def _setup_fairness_setup_query(
......@@ -752,10 +758,13 @@ class TestFederatedClientInitialize:
) as patch_fcc:
with pytest.raises(RuntimeError):
await client.initialize()
# Assert that two messages were fetched, and an error was sent.
# Assert that two messages were fetched, the first one answere with
# an InitReply, the second with an Error.
assert netwk.recv_message.call_count == 2
netwk.send_message.assert_called_once()
reply = netwk.send_message.call_args.args[0]
assert netwk.send_message.call_count == 2
reply = netwk.send_message.call_args_list[0].args[0]
assert isinstance(reply, messaging.InitReply)
reply = netwk.send_message.call_args_list[1].args[0]
assert isinstance(reply, messaging.Error)
# Assert that no fairness controller was set.
patch_fcc.assert_not_called()
......@@ -782,10 +791,13 @@ class TestFederatedClientInitialize:
# Assert that setup was called (hence causing the exception).
patch_fcc.assert_called_once()
assert client.fairness is None
# Assert that both messages were fetched, and an error was sent.
# Assert that both messages were fetched, and an error was sent
# after the fairness setup failed.
assert netwk.recv_message.call_count == 2
netwk.send_message.assert_called_once()
reply = netwk.send_message.call_args.args[0]
assert netwk.send_message.call_count == 2
reply = netwk.send_message.call_args_list[0].args[0]
assert isinstance(reply, messaging.InitReply)
reply = netwk.send_message.call_args_list[1].args[0]
assert isinstance(reply, messaging.Error)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment