diff --git a/declearn/main/_client.py b/declearn/main/_client.py index cbe48c27ca517cd6393e7f2b2176c03ffc94ddca..14ae68094582d6fa63d9bbd077dbef3c3f9cdc29 100644 --- a/declearn/main/_client.py +++ b/declearn/main/_client.py @@ -454,6 +454,10 @@ class FederatedClient: and should never be called in another context. """ assert self.trainmanager is not None + # When SecAgg is to be used, setup controllers first. + if self.secagg is not None: + received = await self.netwk.recv_message() + await self.setup_secagg(received) # Await and deserialize a FairnessSetupQuery. received = await self.netwk.recv_message() query = await verify_server_message_validity( diff --git a/declearn/main/_server.py b/declearn/main/_server.py index 1286103148fa8686bfed57dc1d95abbcfadf40f2..11f157a427b29a2586fea26b96e5773455b26138 100644 --- a/declearn/main/_server.py +++ b/declearn/main/_server.py @@ -355,6 +355,10 @@ class FederatedServer: await self._initialize_dpsgd(config) # If fairness-aware federated learning is configured, set it up. if self.fairness is not None: + # When SecAgg is to be used, setup controllers first. + if self.secagg is not None: + await self.setup_secagg() + # Call the setup routine of the held fairness controller. self.aggrg = await self.fairness.setup_fairness( netwk=self.netwk, aggregator=self.aggrg, secagg=self._decrypter )