From 5e7a49723cd0f0af1082b76651eb2997aed5c970 Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Fri, 19 Jul 2024 14:18:04 +0200 Subject: [PATCH] Fix (and change) the initialization phase messaging sequence. - Always have clients send back an 'InitReply' after the first 'TrainingManager' setup step happened without error. - Then, rely on step-specific query-reply exchanges for further steps (SecAgg, DP-SGD and/or Fairness setup). --- declearn/main/_client.py | 9 ++++---- test/main/test_main_client.py | 42 ++++++++++++++++++++++------------- 2 files changed, 32 insertions(+), 19 deletions(-) diff --git a/declearn/main/_client.py b/declearn/main/_client.py index 1b3f45e..26204b9 100644 --- a/declearn/main/_client.py +++ b/declearn/main/_client.py @@ -346,15 +346,15 @@ class FederatedClient: except Exception as exc: await self.netwk.send_message(messaging.Error(repr(exc))) raise RuntimeError("Initialization failed.") from exc + # Send back an empty message to indicate that things went fine. + self.logger.info("Notifying the server that initialization went fine.") + await self.netwk.send_message(messaging.InitReply()) # If instructed to do so, run additional steps to set up DP-SGD. if message.dpsgd: await self._initialize_dpsgd() # If instructed to do so, run additional steps to enforce fairness. if message.fairness: await self._initialize_fairness() - # Send back an empty message to indicate that all went fine. - self.logger.info("Notifying the server that initialization went fine.") - await self.netwk.send_message(messaging.InitReply()) # Optionally checkpoint the received model and optimizer. 
if self.ckptr: self.ckptr.checkpoint( @@ -396,7 +396,7 @@ class FederatedClient: ) except Exception as exc: raise RuntimeError("DP-SGD initialization failed.") from exc - self.logger.info("Received a request to set up DP-SGD.") + self.logger.info("Received DP-SGD setup instructions.") try: self.make_private(message) except Exception as exc: # pylint: disable=broad-except @@ -469,6 +469,7 @@ class FederatedClient: error = f"Fairness initialization failed: {repr(exc)}." self.logger.critical(error) raise RuntimeError(error) from exc + self.logger.info("Received fairness setup instructions.") # Instantiate a FairnessControllerClient and run its setup routine. try: self.fairness = FairnessControllerClient.from_setup_query( diff --git a/test/main/test_main_client.py b/test/main/test_main_client.py index 7a591c7..d3d3c2e 100644 --- a/test/main/test_main_client.py +++ b/test/main/test_main_client.py @@ -559,12 +559,12 @@ class TestFederatedClientInitialize: with patch_class_constructor(DPTrainingManager) as patch_dp: with patch_class_constructor(TrainingManager) as patch_tm: await client.initialize() - # Assert that a PrivacyReply and InitReply were sent to the server. + # Assert that an InitReply and a PrivacyReply were sent to the server. assert netwk.send_message.call_count == 2 reply = netwk.send_message.call_args_list[0].args[0] - assert isinstance(reply, messaging.PrivacyReply) - reply = netwk.send_message.call_args_list[1].args[0] assert isinstance(reply, messaging.InitReply) + reply = netwk.send_message.call_args_list[1].args[0] + assert isinstance(reply, messaging.PrivacyReply) # Assert that a DPTrainingManager was set up. patch_tm.assert_called_once() patch_dp.assert_called_once_with( @@ -599,10 +599,13 @@ class TestFederatedClientInitialize: with patch_class_constructor(TrainingManager) as patch_tm: with pytest.raises(RuntimeError): await client.initialize() - # Assert that two messages were fetched, and an error was sent. 
+ # Assert that two messages were fetched, that the first step went well + # (resulting in an InitReply) and then an Error was sent. assert netwk.recv_message.call_count == 2 - netwk.send_message.assert_called_once() - reply = netwk.send_message.call_args.args[0] + assert netwk.send_message.call_count == 2 + reply = netwk.send_message.call_args_list[0].args[0] + assert isinstance(reply, messaging.InitReply) + reply = netwk.send_message.call_args_list[1].args[0] assert isinstance(reply, messaging.Error) # Assert that the initial TrainingManager was set, but not the DP one. patch_tm.assert_called_once() @@ -631,10 +634,13 @@ class TestFederatedClientInitialize: # Assert that TrainingManager was instantiated and DP one was called. patch_tm.assert_called_once() patch_dp.assert_called_once() - # Assert that both messages were fetched, and an error was sent. + # Assert that both messages were fetched, and an error was sent + # after the DP-SGD setup failed. assert netwk.recv_message.call_count == 2 - netwk.send_message.assert_called_once() - reply = netwk.send_message.call_args.args[0] + assert netwk.send_message.call_count == 2 + reply = netwk.send_message.call_args_list[0].args[0] + assert isinstance(reply, messaging.InitReply) + reply = netwk.send_message.call_args_list[1].args[0] assert isinstance(reply, messaging.Error) def _setup_fairness_setup_query( @@ -752,10 +758,13 @@ class TestFederatedClientInitialize: ) as patch_fcc: with pytest.raises(RuntimeError): await client.initialize() - # Assert that two messages were fetched, and an error was sent. + # Assert that two messages were fetched, the first one answered with + # an InitReply, the second with an Error. 
assert netwk.recv_message.call_count == 2 - netwk.send_message.assert_called_once() - reply = netwk.send_message.call_args.args[0] + assert netwk.send_message.call_count == 2 + reply = netwk.send_message.call_args_list[0].args[0] + assert isinstance(reply, messaging.InitReply) + reply = netwk.send_message.call_args_list[1].args[0] assert isinstance(reply, messaging.Error) # Assert that no fairness controller was set. patch_fcc.assert_not_called() @@ -782,10 +791,13 @@ class TestFederatedClientInitialize: # Assert that setup was called (hence causing the exception). patch_fcc.assert_called_once() assert client.fairness is None - # Assert that both messages were fetched, and an error was sent. + # Assert that both messages were fetched, and an error was sent + # after the fairness setup failed. assert netwk.recv_message.call_count == 2 - netwk.send_message.assert_called_once() - reply = netwk.send_message.call_args.args[0] + assert netwk.send_message.call_count == 2 + reply = netwk.send_message.call_args_list[0].args[0] + assert isinstance(reply, messaging.InitReply) + reply = netwk.send_message.call_args_list[1].args[0] assert isinstance(reply, messaging.Error) -- GitLab