diff --git a/test/main/test_main_server.py b/test/main/test_main_server.py index 1ae08bb5bab6bda583c47b5d87a1eb81827db905..0e6e07264316a1e26bb56db3415d78078f6b4c47 100644 --- a/test/main/test_main_server.py +++ b/test/main/test_main_server.py @@ -20,21 +20,40 @@ import logging import os from unittest import mock +from typing import Optional, Type import pytest # type: ignore -from declearn.aggregator import Aggregator +from declearn.aggregator import Aggregator, ModelUpdates from declearn.communication import NetworkServerConfig from declearn.communication.api import NetworkServer from declearn.fairness.api import FairnessControllerServer from declearn.main import FederatedServer -from declearn.main.config import FLOptimConfig +from declearn.main.config import ( + FLOptimConfig, + EvaluateConfig, + FairnessConfig, + TrainingConfig, +) from declearn.main.utils import Checkpointer from declearn.metrics import MetricSet +from declearn.messaging import ( + EvaluationReply, + EvaluationRequest, + FairnessQuery, + Message, + SerializedMessage, + TrainRequest, + TrainReply, +) from declearn.model.api import Model from declearn.model.sklearn import SklearnSGDModel from declearn.optimizer import Optimizer -from declearn.secagg.api import SecaggConfigServer +from declearn.secagg.api import Decrypter, SecaggConfigServer +from declearn.secagg.messaging import ( + SecaggEvaluationReply, + SecaggTrainReply, +) from declearn.utils import serialize_object @@ -348,3 +367,273 @@ class TestFederatedServerInit: # pylint: disable=too-many-public-methods FederatedServer( MOCK_MODEL, MOCK_NETWK, MOCK_OPTIM, logger=mock.MagicMock() ) + + +class TestFederatedServerRoutines: + """Unit tests for 'FederatedServer' main unitary routines.""" + + async def setup_server( + self, + use_secagg: bool = False, + use_fairness: bool = False, + ) -> FederatedServer: + """Set up a FederatedServer wrapping mock controllers.""" + netwk = mock.create_autospec(NetworkServer, instance=True) + netwk.name = "server" + netwk.client_names = {"client_a", "client_b"} + optim = FLOptimConfig( + client_opt=mock.create_autospec(Optimizer, instance=True), + server_opt=mock.create_autospec(Optimizer, instance=True), + aggregator=mock.create_autospec(Aggregator, instance=True), + fairness=( + mock.create_autospec(FairnessControllerServer, instance=True) + if use_fairness + else None + ), + ) + secagg = ( + mock.create_autospec(SecaggConfigServer, instance=True) + if use_secagg + else None + ) + return FederatedServer( + model=mock.create_autospec(Model, instance=True), + netwk=netwk, + optim=optim, + metrics=mock.create_autospec(MetricSet, instance=True), + secagg=secagg, + checkpoint=mock.create_autospec(Checkpointer, instance=True), + ) + + @staticmethod + def setup_mock_serialized_message( + msg_cls: Type[Message], + wrapped: Optional[Message] = None, + ) -> mock.Base: + """Set up a mock SerializedMessage with given wrapped message type.""" + message = mock.create_autospec(SerializedMessage, instance=True) + message.message_cls = msg_cls + if wrapped is None: + wrapped = mock.create_autospec(msg_cls, instance=True) + message.deserialize.return_value = wrapped + return message + + @pytest.mark.parametrize("secagg", [False, True], ids=["clrtxt", "secagg"]) + @pytest.mark.asyncio + async def test_training_round( + self, + secagg: bool, + ) -> None: + """Test that the 'training_round' routine triggers expected calls.""" + # Set up a server with mocked attributes. + server = await self.setup_server(use_secagg=secagg) + assert isinstance(server.netwk, mock.NonCallableMagicMock) + assert isinstance(server.model, mock.NonCallableMagicMock) + assert isinstance(server.optim, mock.NonCallableMagicMock) + assert isinstance(server.aggrg, mock.NonCallableMagicMock) + # Mock-run a training routine. + reply_cls = ( + SecaggTrainReply if secagg else TrainReply # type: ignore + ) # type: Type[Message] + updates = mock.create_autospec(ModelUpdates, instance=True) + reply_msg = TrainReply( + n_epoch=1, n_steps=10, t_spent=0.0, updates=updates, aux_var={} + ) + wrapped = None if secagg else reply_msg + server.netwk.wait_for_messages.return_value = { + "client_a": self.setup_mock_serialized_message(reply_cls, wrapped), + "client_b": self.setup_mock_serialized_message(reply_cls, wrapped), + } + with mock.patch( + "declearn.secagg.messaging.aggregate_secagg_messages", + return_value=reply_msg, + ) as patch_aggregate_secagg_messages: + await server.training_round( + round_i=1, train_cfg=TrainingConfig(batch_size=8) + ) + # Verify that expected actions occured. + # (a) optional secagg setup + if secagg: + assert isinstance(server.secagg, mock.NonCallableMagicMock) + server.secagg.setup_decrypter.assert_awaited_once() + # (b) training request emission, including model weights + server.netwk.send_messages.assert_awaited_once() + queries = server.netwk.send_messages.await_args[0][0] + assert isinstance(queries, dict) + assert queries.keys() == server.netwk.client_names + for query in queries.values(): + assert isinstance(query, TrainRequest) + assert query.weights is server.model.get_weights.return_value + assert query.aux_var is server.optim.collect_aux_var.return_value + # (c) training reply reception + server.netwk.wait_for_messages.assert_awaited_once() + if secagg: + patch_aggregate_secagg_messages.assert_called_once() + else: + patch_aggregate_secagg_messages.assert_not_called() + # (d) updates aggregation and global model weights update + server.optim.process_aux_var.assert_called_once() + server.aggrg.finalize_updates.assert_called_once() + server.optim.apply_gradients.assert_called_once() + + @pytest.mark.parametrize("secagg", [False, True], ids=["clrtxt", "secagg"]) + @pytest.mark.asyncio + async def test_evaluation_round( + self, + secagg: bool, + ) -> None: + """Test that the 'evaluation_round' routine triggers expected calls.""" + # Set up a server with mocked attributes. + server = await self.setup_server(use_secagg=secagg) + assert isinstance(server.netwk, mock.NonCallableMagicMock) + assert isinstance(server.model, mock.NonCallableMagicMock) + assert isinstance(server.metrics, mock.NonCallableMagicMock) + assert isinstance(server.ckptr, mock.NonCallableMagicMock) + # Mock-run an evaluation routine. + reply_cls = ( + SecaggEvaluationReply # type: ignore + if secagg + else EvaluationReply + ) # type: Type[Message] + reply_msg = EvaluationReply( + loss=0.42, n_steps=10, t_spent=0.0, metrics={} + ) + wrapped = None if secagg else reply_msg + server.netwk.wait_for_messages.return_value = { + "client_a": self.setup_mock_serialized_message(reply_cls, wrapped), + "client_b": self.setup_mock_serialized_message(reply_cls, wrapped), + } + with mock.patch( + "declearn.secagg.messaging.aggregate_secagg_messages", + return_value=reply_msg, + ) as patch_aggregate_secagg_messages: + await server.evaluation_round( + round_i=1, valid_cfg=EvaluateConfig(batch_size=8) + ) + # Verify that expected actions occured. + # (a) optional secagg setup + if secagg: + assert isinstance(server.secagg, mock.NonCallableMagicMock) + server.secagg.setup_decrypter.assert_awaited_once() + # (b) evaluation request emission, including model weights + server.netwk.send_messages.assert_awaited_once() + queries = server.netwk.send_messages.await_args[0][0] + assert isinstance(queries, dict) + assert queries.keys() == server.netwk.client_names + for query in queries.values(): + assert isinstance(query, EvaluationRequest) + assert query.weights is server.model.get_weights.return_value + # (c) evaluation reply reception + server.netwk.wait_for_messages.assert_awaited_once() + if secagg: + patch_aggregate_secagg_messages.assert_called_once() + else: + patch_aggregate_secagg_messages.assert_not_called() + # (d) metrics aggregation + server.metrics.reset.assert_called_once() + server.metrics.set_states.assert_called_once() + server.metrics.get_result.assert_called_once() + # (e) checkpointing + server.ckptr.checkpoint.assert_called_once_with( + model=server.model, + optimizer=server.optim, + metrics=server.metrics.get_result.return_value, + ) + + @pytest.mark.asyncio + async def test_evaluation_round_skip( + self, + ) -> None: + """Test that 'evaluation_round' skips rounds when configured.""" + # Set up a server with mocked attributes. + server = await self.setup_server() + assert isinstance(server.netwk, mock.NonCallableMagicMock) + # Mock a call that should result in skipping the round. + await server.evaluation_round( + round_i=1, + valid_cfg=EvaluateConfig(batch_size=8, frequency=2), + ) + # Assert that no message was sent (routine was skipped). + server.netwk.broadcast_message.assert_not_called() + server.netwk.send_messages.assert_not_called() + server.netwk.send_message.assert_not_called() + + @pytest.mark.parametrize("secagg", [False, True], ids=["clrtxt", "secagg"]) + @pytest.mark.asyncio + async def test_fairness_round( + self, + secagg: bool, + ) -> None: + """Test that the 'fairness_round' routine triggers expected calls.""" + # Set up a server with mocked attributes. + server = await self.setup_server(use_secagg=secagg, use_fairness=True) + assert isinstance(server.netwk, mock.NonCallableMagicMock) + assert isinstance(server.model, mock.NonCallableMagicMock) + assert isinstance(server.fairness, mock.NonCallableMagicMock) + assert isinstance(server.ckptr, mock.NonCallableMagicMock) + # Mock-run a fairness routine. + await server.fairness_round( + round_i=0, + fairness_cfg=FairnessConfig(), + ) + # Verify that expected actions occured. + # (a) optional secagg setup + decrypter = None # type: Optional[Decrypter] + if secagg: + assert isinstance(server.secagg, mock.NonCallableMagicMock) + server.secagg.setup_decrypter.assert_awaited_once() + decrypter = server.secagg.setup_decrypter.return_value + # (b) fairness query emission, including model weights + server.netwk.send_messages.assert_awaited_once() + queries = server.netwk.send_messages.await_args[0][0] + assert isinstance(queries, dict) + assert queries.keys() == server.netwk.client_names + for query in queries.values(): + assert isinstance(query, FairnessQuery) + assert query.weights is server.model.get_weights.return_value + # (c) fairness controller round routine + server.fairness.run_fairness_round.assert_awaited_once_with( + netwk=server.netwk, secagg=decrypter + ) + # (d) checkpointing + server.ckptr.save_metrics.assert_called_once_with( + metrics=server.fairness.run_fairness_round.return_value, + prefix="fairness_metrics", + append=False, + timestamp="round_0", + ) + + @pytest.mark.asyncio + async def test_fairness_round_undefined( + self, + ) -> None: + """Test that 'fairness_round' early-exits when fairness is not set.""" + # Set up a server with mocked attributes and no fairness controller. + server = await self.setup_server(use_fairness=False) + assert isinstance(server.netwk, mock.NonCallableMagicMock) + assert server.fairness is None + # Call the fairness round routine.1 + await server.fairness_round( + round_i=0, + fairness_cfg=FairnessConfig(), + ) + # Assert that no message was sent (routine was skipped). + server.netwk.broadcast_message.assert_not_called() + server.netwk.send_messages.assert_not_called() + server.netwk.send_message.assert_not_called() + + @pytest.mark.asyncio + async def test_fairness_round_skip( + self, + ) -> None: + """Test that 'fairness_round' skips rounds when configured.""" + # Set up a server with a mocked fairness controller. + server = await self.setup_server(use_fairness=True) + assert isinstance(server.fairness, mock.NonCallableMagicMock) + # Mock a call that should result in skipping the round. + await server.fairness_round( + round_i=1, + fairness_cfg=FairnessConfig(frequency=2), + ) + # Assert that the round was skipped. + server.fairness.run_fairness_round.assert_not_called()