diff --git a/declearn/main/_client.py b/declearn/main/_client.py index 14ae68094582d6fa63d9bbd077dbef3c3f9cdc29..c1619e7cda5400db402f76abfda59f5e1d544642 100644 --- a/declearn/main/_client.py +++ b/declearn/main/_client.py @@ -454,15 +454,21 @@ class FederatedClient: and should never be called in another context. """ assert self.trainmanager is not None - # When SecAgg is to be used, setup controllers first. - if self.secagg is not None: + # Optionally setup SecAgg; await a FairnessSetupQuery. + try: + # When SecAgg is to be used, setup controllers first. + if self.secagg is not None: + received = await self.netwk.recv_message() + await self.setup_secagg(received) + # Await and deserialize a FairnessSetupQuery. received = await self.netwk.recv_message() - await self.setup_secagg(received) - # Await and deserialize a FairnessSetupQuery. - received = await self.netwk.recv_message() - query = await verify_server_message_validity( - self.netwk, received, expected=messaging.FairnessSetupQuery - ) + query = await verify_server_message_validity( + self.netwk, received, expected=messaging.FairnessSetupQuery + ) + except Exception as exc: + error = f"Fairness initialization failed: {repr(exc)}." + self.logger.critical(error) + raise RuntimeError(error) from exc # Instantiate a FairnessControllerClient and run its setup routine. try: self.fairness = FairnessControllerClient.from_setup_query( @@ -643,6 +649,16 @@ class FederatedClient: self.logger.critical(error) await self.netwk.send_message(messaging.Error(error)) raise RuntimeError(error) + # When SecAgg is to be used, verify that it was set up. + if self.secagg is not None and self._encrypter is None: + error = ( + "Refusing to participate in fairness-related round " + f"{query.round_i} as SecAgg is configured to be used " + "but was not set up." + ) + self.logger.error(error) + await self.netwk.send_message(messaging.Error(error)) + return # Otherwise, run the controller's routine. metrics = await self.fairness.fairness_round( netwk=self.netwk, query=query, secagg=self._encrypter diff --git a/test/main/test_main_client.py b/test/main/test_main_client.py index 4288131c274eb13c711fe57150f1933e6caf2a25..391daec98e2fe143e3ae1639f1270a98fa3eb45f 100644 --- a/test/main/test_main_client.py +++ b/test/main/test_main_client.py @@ -19,7 +19,7 @@ import contextlib import logging -from typing import Any, Iterator, Optional, Type +from typing import Any, Iterator, Optional, Tuple, Type from unittest import mock import pytest # type: ignore @@ -28,6 +28,7 @@ from declearn import messaging from declearn.dataset import Dataset, DataSpecs from declearn.communication import NetworkClientConfig from declearn.communication.api import NetworkClient +from declearn.fairness.api import FairnessControllerClient from declearn.main import FederatedClient from declearn.main.utils import Checkpointer from declearn.metrics import MetricState @@ -45,6 +46,9 @@ else: DP_AVAILABLE = True +# numerous but organized tests; pylint: disable=too-many-lines + + MOCK_NETWK = mock.create_autospec(NetworkClient, instance=True) MOCK_NETWK.name = "client" MOCK_DATASET = mock.create_autospec(Dataset, instance=True) @@ -356,6 +360,7 @@ class TestFederatedClientInitialize: def _setup_mock_init_request( secagg: Optional[str] = None, dpsgd: bool = False, + fairness: bool = False, ) -> messaging.SerializedMessage[messaging.InitRequest]: """Return a mock serialized InitRequest.""" init_req = messaging.InitRequest( @@ -364,6 +369,7 @@ class TestFederatedClientInitialize: aggrg=mock.MagicMock(), secagg=secagg, dpsgd=dpsgd, + fairness=fairness, ) msg_init = mock.create_autospec( messaging.SerializedMessage, instance=True @@ -518,6 +524,23 @@ class TestFederatedClientInitialize: assert isinstance(reply, messaging.Error) patched.assert_not_called() + def _setup_dpsgd_setup_query( + self, + ) -> Tuple[ + messaging.SerializedMessage[messaging.PrivacyRequest], + messaging.PrivacyRequest, + ]: + """Setup a mock PrivacyRequest and a wrapping SerializedMessage.""" + dp_query = mock.create_autospec( + messaging.PrivacyRequest, instance=True + ) + msg_priv = mock.create_autospec( + messaging.SerializedMessage, instance=True + ) + msg_priv.message_cls = messaging.PrivacyRequest + msg_priv.deserialize.return_value = dp_query + return msg_priv, dp_query + @pytest.mark.asyncio async def test_initialize_with_dpsgd(self) -> None: """Test that initialization with DP-SGD works properly.""" @@ -527,13 +550,7 @@ class TestFederatedClientInitialize: netwk = mock.create_autospec(NetworkClient, instance=True) netwk.name = "client" msg_init = self._setup_mock_init_request(secagg=None, dpsgd=True) - msg_priv = mock.create_autospec( - messaging.SerializedMessage, instance=True - ) - msg_priv.message_cls = messaging.PrivacyRequest - msg_priv.deserialize.return_value = dpconfig = mock.create_autospec( - messaging.PrivacyRequest, instance=True - ) + msg_priv, dp_query = self._setup_dpsgd_setup_query() netwk.recv_message.side_effect = [msg_init, msg_priv] # Set up a client wrapping the former network endpoint. client = FederatedClient(netwk=netwk, train_data=MOCK_DATASET) @@ -542,7 +559,10 @@ class TestFederatedClientInitialize: with patch_class_constructor(DPTrainingManager) as patch_dp: with patch_class_constructor(TrainingManager) as patch_tm: await client.initialize() - # Assert that a single InitReply was then sent to the server. + # Assert that a PrivacyReply and InitReply were sent to the server. + assert netwk.send_message.call_count == 2 + reply = netwk.send_message.call_args_list[0].args[0] + assert isinstance(reply, messaging.PrivacyReply) reply = netwk.send_message.call_args_list[1].args[0] assert isinstance(reply, messaging.InitReply) # Assert that a DPTrainingManager was set up. @@ -558,7 +578,7 @@ class TestFederatedClientInitialize: logger=patch_tm.return_value.logger, verbose=patch_tm.return_value.verbose, ) - patch_dp.return_value.make_private.assert_called_once_with(dpconfig) + patch_dp.return_value.make_private.assert_called_once_with(dp_query) assert client.trainmanager is patch_dp.return_value @pytest.mark.asyncio @@ -597,13 +617,7 @@ class TestFederatedClientInitialize: netwk = mock.create_autospec(NetworkClient, instance=True) netwk.name = "client" msg_init = self._setup_mock_init_request(secagg=None, dpsgd=True) - msg_priv = mock.create_autospec( - messaging.SerializedMessage, instance=True - ) - msg_priv.message_cls = messaging.PrivacyRequest - msg_priv.deserialize.return_value = mock.create_autospec( - messaging.PrivacyRequest, instance=True - ) + msg_priv, _ = self._setup_dpsgd_setup_query() netwk.recv_message.side_effect = [msg_init, msg_priv] # Set up a client wrapping the former network endpoint. client = FederatedClient(netwk=netwk, train_data=MOCK_DATASET) @@ -623,6 +637,157 @@ class TestFederatedClientInitialize: reply = netwk.send_message.call_args.args[0] assert isinstance(reply, messaging.Error) + def _setup_fairness_setup_query( + self, + ) -> Tuple[ + messaging.SerializedMessage[messaging.FairnessSetupQuery], + messaging.FairnessSetupQuery, + ]: + """Setup a mock FairnessSetupQuery and a wrapping SerializedMessage.""" + fs_query = mock.create_autospec( + messaging.FairnessSetupQuery, instance=True + ) + msg_fair = mock.create_autospec( + messaging.SerializedMessage, instance=True + ) + msg_fair.message_cls = messaging.FairnessSetupQuery + msg_fair.deserialize.return_value = fs_query + return msg_fair, fs_query + + @pytest.mark.asyncio + async def test_initialize_with_fairness(self) -> None: + """Test that initialization with fairness works properly.""" + # Set up a mock network receiving an InitRequest, + # then a FairnessSetupQuery. + netwk = mock.create_autospec(NetworkClient, instance=True) + netwk.name = "client" + msg_init = self._setup_mock_init_request(secagg=None, fairness=True) + msg_fair, fs_query = self._setup_fairness_setup_query() + netwk.recv_message.side_effect = [msg_init, msg_fair] + # Set up a client wrapping the former network endpoint. + client = FederatedClient(netwk=netwk, train_data=MOCK_DATASET) + # Attempt running initialization, patching fairness controller setup. + with mock.patch.object( + FairnessControllerClient, "from_setup_query" + ) as patch_fcc: + mock_controller = patch_fcc.return_value + mock_controller.setup_fairness = mock.AsyncMock() + await client.initialize() + # Assert that a controller was instantiated and set up. + patch_fcc.assert_called_once_with( + query=fs_query, manager=client.trainmanager + ) + mock_controller.setup_fairness.assert_awaited_once_with( + netwk=client.netwk, secagg=None + ) + assert client.fairness is mock_controller + # Assert that a single InitReply was then sent to the server. + netwk.send_message.assert_called_once() + reply = netwk.send_message.call_args[0][0] + assert isinstance(reply, messaging.InitReply) + + @pytest.mark.asyncio + async def test_initialize_with_fairness_and_secagg(self) -> None: + """Test that initialization with fairness and secagg works properly.""" + # Set up a mock network receiving an InitRequest with SecAgg, + # then a SecaggSetupQuery and finally a FairnessSetupQuery. + netwk = mock.create_autospec(NetworkClient, instance=True) + netwk.name = "client" + msg_init = self._setup_mock_init_request( + secagg="mock-secagg", fairness=True + ) + msg_sqry = mock.create_autospec( + messaging.SerializedMessage, instance=True + ) + msg_sqry.message_cls = SecaggSetupQuery + msg_sqry.deserialize.return_value = mock.create_autospec( + SecaggSetupQuery, instance=True + ) + msg_fair, fs_query = self._setup_fairness_setup_query() + netwk.recv_message.side_effect = [msg_init, msg_sqry, msg_fair] + # Set up a client with that endpoint and a matching mock secagg. + secagg = mock.create_autospec(SecaggConfigClient, instance=True) + secagg.secagg_type = "mock-secagg" + client = FederatedClient( + netwk=netwk, train_data=MOCK_DATASET, secagg=secagg + ) + # Attempt running initialization, patching fairness controller setup. + with mock.patch.object( + FairnessControllerClient, "from_setup_query" + ) as patch_fcc: + mock_controller = patch_fcc.return_value + mock_controller.setup_fairness = mock.AsyncMock() + await client.initialize() + # Assert that all three messages were fetched. + assert netwk.recv_message.call_count == 3 + # Assert that a secagg controller was set up. + secagg.setup_encrypter.assert_awaited_once_with(netwk, msg_sqry) + # Assert that a fairness controller was instantiated + # and then set up using the secagg controller. + patch_fcc.assert_called_once_with( + query=fs_query, manager=client.trainmanager + ) + mock_controller.setup_fairness.assert_awaited_once_with( + netwk=client.netwk, secagg=secagg.setup_encrypter.return_value + ) + assert client.fairness is mock_controller + # Assert that a single InitReply was then sent to the server. + netwk.send_message.assert_called_once() + reply = netwk.send_message.call_args[0][0] + assert isinstance(reply, messaging.InitReply) + + @pytest.mark.asyncio + async def test_initialize_with_fairness_error_wrong_message(self) -> None: + """Test error catching for fairness setup with wrong second message.""" + # Set up a mock network receiving an InitRequest but wrong follow-up. + netwk = mock.create_autospec(NetworkClient, instance=True) + netwk.name = "client" + msg_init = self._setup_mock_init_request(secagg=None, fairness=True) + netwk.recv_message.return_value = msg_init + # Set up a client wrapping the former network endpoint. + client = FederatedClient(netwk=netwk, train_data=MOCK_DATASET) + # Attempt running initialization, monitoring fairness controller setup. + with mock.patch.object( + FairnessControllerClient, "from_setup_query" + ) as patch_fcc: + with pytest.raises(RuntimeError): + await client.initialize() + # Assert that two messages were fetched, and an error was sent. + assert netwk.recv_message.call_count == 2 + netwk.send_message.assert_called_once() + reply = netwk.send_message.call_args.args[0] + assert isinstance(reply, messaging.Error) + # Assert that no fairness controller was set. + patch_fcc.assert_not_called() + + @pytest.mark.asyncio + async def test_initialize_with_fairness_error_setup(self) -> None: + """Test error catching for fairness setup with client-side failure.""" + # Set up a mock network receiving an InitRequest, + # then a FairnessSetupQuery. + netwk = mock.create_autospec(NetworkClient, instance=True) + netwk.name = "client" + msg_init = self._setup_mock_init_request(secagg=None, fairness=True) + msg_fair, _ = self._setup_fairness_setup_query() + netwk.recv_message.side_effect = [msg_init, msg_fair] + # Set up a client wrapping the former network endpoint. + client = FederatedClient(netwk=netwk, train_data=MOCK_DATASET) + # Attempt running initialization, monitoring and forcing setup failure. + with mock.patch.object( + FairnessControllerClient, "from_setup_query" + ) as patch_fcc: + patch_fcc.side_effect = TypeError + with pytest.raises(RuntimeError): + await client.initialize() + # Assert that setup was called (hence causing the exception). + patch_fcc.assert_called_once() + assert client.fairness is None + # Assert that both messages were fetched, and an error was sent. + assert netwk.recv_message.call_count == 2 + netwk.send_message.assert_called_once() + reply = netwk.send_message.call_args.args[0] + assert isinstance(reply, messaging.Error) + class TestFederatedClientSetupSecagg: """Unit tests for 'FederatedClient.setup_secagg'.""" @@ -875,6 +1040,118 @@ class TestFederatedClientEvaluationRound: train_manager.evaluation_round.assert_not_called() +class TestFederatedClientFairnessRound: + """Unit tests for 'FederatedClient.fairness_round'.""" + + @pytest.mark.parametrize("ckpt", [True, False], ids=["ckpt", "nockpt"]) + @pytest.mark.asyncio + async def test_fairness_round( + self, + ckpt: bool, + ) -> None: + """Test 'fairness_round' with fairness and without SecAgg.""" + # Set up a client with a mock NetworkClient, TrainingManager, + # FairnessClientController and optional Checkpointer. + netwk = mock.create_autospec(NetworkClient, instance=True) + netwk.name = "client" + client = FederatedClient(netwk, train_data=MOCK_DATASET) + client.trainmanager = mock.create_autospec( + TrainingManager, instance=True + ) + fairness = mock.create_autospec( + FairnessControllerClient, instance=True + ) + client.fairness = fairness + if ckpt: + client.ckptr = mock.create_autospec(Checkpointer, instance=True) + # Call the 'fairness_round' routine and verify expected actions. + request = messaging.FairnessQuery(round_i=1) + await client.fairness_round(request) + fairness.fairness_round.assert_awaited_once_with( + netwk=netwk, query=request, secagg=None + ) + # Verify that when a checkpointer is set, it is used. + if ckpt: + client.ckptr.save_metrics.assert_called_once_with( # type: ignore + metrics=fairness.fairness_round.return_value, + prefix="fairness_metrics", + append=True, + timestamp="round_1", + ) + + @pytest.mark.asyncio + async def test_fairness_round_secagg(self) -> None: + """Test 'fairness_round' with fairness and with SecAgg.""" + # Set up a client with a mock NetworkClient, TrainingManager, + # FairnessClientController and SecaggConfigClient. + netwk = mock.create_autospec(NetworkClient, instance=True) + netwk.name = "client" + secagg = mock.create_autospec(SecaggConfigClient, instance=True) + client = FederatedClient(netwk, train_data=MOCK_DATASET, secagg=secagg) + client.trainmanager = mock.create_autospec( + TrainingManager, instance=True + ) + fairness = mock.create_autospec( + FairnessControllerClient, instance=True + ) + client.fairness = fairness + # Call the SecAgg setup routine. + await client.setup_secagg( + mock.create_autospec(messaging.SerializedMessage) + ) + # Call the 'fairness_round' routine and verify expected actions. + request = messaging.FairnessQuery(round_i=1) + await client.fairness_round(request) + fairness.fairness_round.assert_awaited_once_with( + netwk=netwk, + query=request, + secagg=secagg.setup_encrypter.return_value, + ) + + @pytest.mark.asyncio + async def test_fairness_round_fairness_not_setup(self) -> None: + """Test 'fairness_round' without a fairness controller.""" + # Set up a client with a mock NetworkClient and TrainingManager, + # but no fairness controller. + netwk = mock.create_autospec(NetworkClient, instance=True) + netwk.name = "client" + client = FederatedClient(netwk, train_data=MOCK_DATASET) + client.trainmanager = mock.create_autospec( + TrainingManager, instance=True + ) + # Verify that running the routine raises a RuntimeError. + with pytest.raises(RuntimeError): + await client.fairness_round(messaging.FairnessQuery(round_i=1)) + # Verify that an Error message was sent. + netwk.send_message.assert_called_once() + reply = netwk.send_message.call_args.args[0] + assert isinstance(reply, messaging.Error) + + @pytest.mark.asyncio + async def test_fairness_round_secagg_not_setup(self) -> None: + """Test 'fairness_round' error with configured, not-setup SecAgg.""" + # Set up a client with a mock NetworkClient, TrainingManager, + # FairnessClientController and SecaggConfigClient. + netwk = mock.create_autospec(NetworkClient, instance=True) + netwk.name = "client" + secagg = mock.create_autospec(SecaggConfigClient, instance=True) + client = FederatedClient(netwk, train_data=MOCK_DATASET, secagg=secagg) + client.trainmanager = mock.create_autospec( + TrainingManager, instance=True + ) + fairness = mock.create_autospec( + FairnessControllerClient, instance=True + ) + client.fairness = fairness + # Run the routine and verify that an Error message was sent. + request = messaging.FairnessQuery(round_i=1) + await client.fairness_round(request) + netwk.send_message.assert_called_once() + reply = netwk.send_message.call_args.args[0] + assert isinstance(reply, messaging.Error) + fairness.fairness_round.assert_not_called() + + class TestFederatedClientMisc: """Unit tests for miscellaneous 'FederatedClient' methods."""