diff --git a/test/main/test_main_server.py b/test/main/test_main_server.py index 0e6e07264316a1e26bb56db3415d78078f6b4c47..e8e416b20e204d31869098a51724fc2b6d0efa2b 100644 --- a/test/main/test_main_server.py +++ b/test/main/test_main_server.py @@ -18,9 +18,10 @@ """Unit tests for 'FederatedServer'.""" import logging +import math import os from unittest import mock -from typing import Optional, Type +from typing import Dict, List, Optional, Type import pytest # type: ignore @@ -31,8 +32,10 @@ from declearn.fairness.api import FairnessControllerServer from declearn.main import FederatedServer from declearn.main.config import ( FLOptimConfig, + FLRunConfig, EvaluateConfig, FairnessConfig, + RegisterConfig, TrainingConfig, ) from declearn.main.utils import Checkpointer @@ -41,8 +44,15 @@ from declearn.messaging import ( EvaluationReply, EvaluationRequest, FairnessQuery, + InitReply, + InitRequest, Message, + MetadataQuery, + MetadataReply, + PrivacyReply, + PrivacyRequest, SerializedMessage, + StopTraining, TrainRequest, TrainReply, ) @@ -372,8 +382,8 @@ class TestFederatedServerInit: # pylint: disable=too-many-public-methods class TestFederatedServerRoutines: """Unit tests for 'FederatedServer' main unitary routines.""" - async def setup_server( - self, + @staticmethod + async def setup_test_server( use_secagg: bool = False, use_fairness: bool = False, ) -> FederatedServer: @@ -391,11 +401,10 @@ class TestFederatedServerRoutines: else None ), ) - secagg = ( - mock.create_autospec(SecaggConfigServer, instance=True) - if use_secagg - else None - ) + secagg = None # type: Optional[SecaggConfigServer] + if use_secagg: + secagg = mock.create_autospec(SecaggConfigServer, instance=True) + secagg.secagg_type = "mock_secagg" # type: ignore return FederatedServer( model=mock.create_autospec(Model, instance=True), netwk=netwk, @@ -409,7 +418,7 @@ class TestFederatedServerRoutines: def setup_mock_serialized_message( msg_cls: Type[Message], wrapped: Optional[Message] = None, - ) -> mock.Base: + ) -> mock.NonCallableMagicMock: """Set up a mock SerializedMessage with given wrapped message type.""" message = mock.create_autospec(SerializedMessage, instance=True) message.message_cls = msg_cls @@ -418,6 +427,120 @@ class TestFederatedServerRoutines: message.deserialize.return_value = wrapped return message + @pytest.mark.parametrize( + "metadata", [False, True], ids=["nometa", "metadata"] + ) + @pytest.mark.parametrize("privacy", [False, True], ids=["nodp", "dpsgd"]) + @pytest.mark.parametrize( + "fairness", [False, True], ids=["unfair", "fairness"] + ) + @pytest.mark.parametrize("secagg", [False, True], ids=["clrtxt", "secagg"]) + @pytest.mark.asyncio + async def test_initialization( + self, + secagg: bool, + fairness: bool, + privacy: bool, + metadata: bool, + ) -> None: + """Test that the 'initialization' routine triggers expected calls.""" + # Set up a server with mocked attributes. + server = await self.setup_test_server( + use_secagg=secagg, use_fairness=fairness + ) + assert isinstance(server.netwk, mock.NonCallableMagicMock) + assert isinstance(server.model, mock.NonCallableMagicMock) + server.model.required_data_info = {"n_samples"} if metadata else {} + aggrg = server.aggrg + # Run the initialization routine. + config = FLRunConfig.from_params( + rounds=10, + register=RegisterConfig(0, 2, 120), + training={"batch_size": 8}, + privacy=( + {"budget": (1e-3, 0.0), "sclip_norm": 1.0} if privacy else None + ), + ) + server.netwk.wait_for_messages.side_effect = self._setup_init_replies( + metadata, privacy + ) + await server.initialization(config) + # Verify that the clients-registration routine was called. + server.netwk.wait_for_clients.assert_awaited_once_with(0, 2, 120) + # Verify that the expected number of message exchanges occured. + assert server.netwk.broadcast_message.await_count == ( + 1 + metadata + privacy + ) + queries = server.netwk.broadcast_message.await_args_list.copy() + # When configured, verify that metadata were queried and used. + if metadata: + query = queries.pop(0)[0][0] + assert isinstance(query, MetadataQuery) + assert query.fields == ["n_samples"] + server.model.initialize.assert_called_once_with({"n_samples": 200}) + # Verify that an InitRequest was sent with expected parameters. + query = queries.pop(0)[0][0] + assert isinstance(query, InitRequest) + assert query.dpsgd is privacy + if secagg: + assert query.secagg is not None + else: + assert query.secagg is None + assert query.fairness is fairness + # Verify that DP-SGD setup occurred when expected. + if privacy: + query = queries.pop(0)[0][0] + assert isinstance(query, PrivacyRequest) + assert query.budget == (1e-3, 0.0) + assert query.sclip_norm == 1.0 + assert query.rounds == 10 + # Verify that SecAgg setup occurred when expected. + decrypter = None # type: Optional[Decrypter] + if secagg: + assert isinstance(server.secagg, mock.NonCallableMagicMock) + if fairness: + server.secagg.setup_decrypter.assert_awaited_once() + decrypter = server.secagg.setup_decrypter.return_value + else: + server.secagg.setup_decrypter.assert_not_called() + # Verify that fairness setup occurred when expected. + if fairness: + assert isinstance(server.fairness, mock.NonCallableMagicMock) + server.fairness.setup_fairness.assert_awaited_once_with( + netwk=server.netwk, aggregator=aggrg, secagg=decrypter + ) + assert server.aggrg is server.fairness.setup_fairness.return_value + + def _setup_init_replies( + self, + metadata: bool, + privacy: bool, + ) -> List[Dict[str, mock.NonCallableMagicMock]]: + clients = ("client_a", "client_b") + messages = [] # type: List[Dict[str, mock.NonCallableMagicMock]] + if metadata: + msg = MetadataReply({"n_samples": 100}) + messages.append( + { + key: self.setup_mock_serialized_message(MetadataReply, msg) + for key in clients + } + ) + messages.append( + { + key: self.setup_mock_serialized_message(InitReply) + for key in clients + } + ) + if privacy: + messages.append( + { + key: self.setup_mock_serialized_message(PrivacyReply) + for key in clients + } + ) + return messages + @pytest.mark.parametrize("secagg", [False, True], ids=["clrtxt", "secagg"]) @pytest.mark.asyncio async def test_training_round( @@ -426,7 +549,7 @@ class TestFederatedServerRoutines: ) -> None: """Test that the 'training_round' routine triggers expected calls.""" # Set up a server with mocked attributes. - server = await self.setup_server(use_secagg=secagg) + server = await self.setup_test_server(use_secagg=secagg) assert isinstance(server.netwk, mock.NonCallableMagicMock) assert isinstance(server.model, mock.NonCallableMagicMock) assert isinstance(server.optim, mock.NonCallableMagicMock) @@ -484,7 +607,7 @@ class TestFederatedServerRoutines: ) -> None: """Test that the 'evaluation_round' routine triggers expected calls.""" # Set up a server with mocked attributes. - server = await self.setup_server(use_secagg=secagg) + server = await self.setup_test_server(use_secagg=secagg) assert isinstance(server.netwk, mock.NonCallableMagicMock) assert isinstance(server.model, mock.NonCallableMagicMock) assert isinstance(server.metrics, mock.NonCallableMagicMock) @@ -546,7 +669,7 @@ class TestFederatedServerRoutines: ) -> None: """Test that 'evaluation_round' skips rounds when configured.""" # Set up a server with mocked attributes. - server = await self.setup_server() + server = await self.setup_test_server() assert isinstance(server.netwk, mock.NonCallableMagicMock) # Mock a call that should result in skipping the round. await server.evaluation_round( @@ -566,7 +689,9 @@ class TestFederatedServerRoutines: ) -> None: """Test that the 'fairness_round' routine triggers expected calls.""" # Set up a server with mocked attributes. - server = await self.setup_server(use_secagg=secagg, use_fairness=True) + server = await self.setup_test_server( + use_secagg=secagg, use_fairness=True + ) assert isinstance(server.netwk, mock.NonCallableMagicMock) assert isinstance(server.model, mock.NonCallableMagicMock) assert isinstance(server.fairness, mock.NonCallableMagicMock) @@ -609,7 +734,7 @@ class TestFederatedServerRoutines: ) -> None: """Test that 'fairness_round' early-exits when fairness is not set.""" # Set up a server with mocked attributes and no fairness controller. - server = await self.setup_server(use_fairness=False) + server = await self.setup_test_server(use_fairness=False) assert isinstance(server.netwk, mock.NonCallableMagicMock) assert server.fairness is None # Call the fairness round routine.1 @@ -628,7 +753,7 @@ class TestFederatedServerRoutines: ) -> None: """Test that 'fairness_round' skips rounds when configured.""" # Set up a server with a mocked fairness controller. - server = await self.setup_server(use_fairness=True) + server = await self.setup_test_server(use_fairness=True) assert isinstance(server.fairness, mock.NonCallableMagicMock) # Mock a call that should result in skipping the round. await server.fairness_round( @@ -637,3 +762,145 @@ class TestFederatedServerRoutines: ) # Assert that the round was skipped. server.fairness.run_fairness_round.assert_not_called() + + @pytest.mark.asyncio + async def test_stop_training( + self, + ) -> None: + """Test that 'stop_training' triggers expected actions.""" + # Set up a server with mocked attributes. + server = await self.setup_test_server() + assert isinstance(server.netwk, mock.NonCallableMagicMock) + assert isinstance(server.model, mock.NonCallableMagicMock) + assert isinstance(server.ckptr, mock.NonCallableMagicMock) + server.ckptr.folder = "mock_folder" + # Call the 'stop_training' routine. + await server.stop_training(rounds=5) + # Verify that the expected message was broadcasted. + server.netwk.broadcast_message.assert_awaited_once() + message = server.netwk.broadcast_message.await_args[0][0] + assert isinstance(message, StopTraining) + assert message.weights is server.model.get_weights.return_value + assert math.isnan(message.loss) + assert message.rounds == 5 + # Verify that the expected checkpointing occured. + server.ckptr.save_model.assert_called_once_with( + server.model, timestamp="best" + ) + + +class TestFederatedServerRun: + """Unit tests for 'FederatedServer.run' and 'async_run' routines.""" + + # Unit tests for FLRunConfig parsing via synchronous 'run' method. + + def test_run_from_dict( + self, + ) -> None: + """Test that 'run' properly parses input dict config. + + Mock the actual underlying routine. + """ + server = FederatedServer( + model=MOCK_MODEL, netwk=MOCK_NETWK, optim=MOCK_OPTIM + ) + config = mock.create_autospec(dict, instance=True) + with mock.patch.object( + FLRunConfig, + "from_params", + return_value=mock.create_autospec(FLRunConfig, instance=True), + ) as patch_flrunconfig_from_params: + with mock.patch.object(server, "async_run") as patch_async_run: + server.run(config) + patch_flrunconfig_from_params.assert_called_once_with(**config) + patch_async_run.assert_called_once_with( + patch_flrunconfig_from_params.return_value + ) + + def test_run_from_toml( + self, + ) -> None: + """Test that 'run' properly parses input TOML file. + + Mock the actual underlying routine. + """ + server = FederatedServer( + model=MOCK_MODEL, netwk=MOCK_NETWK, optim=MOCK_OPTIM + ) + config = "mock_path.toml" + with mock.patch.object( + FLRunConfig, + "from_toml", + return_value=mock.create_autospec(FLRunConfig, instance=True), + ) as patch_flrunconfig_from_toml: + with mock.patch.object(server, "async_run") as patch_async_run: + server.run(config) + patch_flrunconfig_from_toml.assert_called_once_with(config) + patch_async_run.assert_called_once_with( + patch_flrunconfig_from_toml.return_value + ) + + def test_run_from_config( + self, + ) -> None: + """Test that 'run' properly uses input FLRunConfig. + + Mock the actual underlying routine. + """ + server = FederatedServer( + model=MOCK_MODEL, netwk=MOCK_NETWK, optim=MOCK_OPTIM + ) + config = mock.create_autospec(FLRunConfig, instance=True) + with mock.patch.object(server, "async_run") as patch_async_run: + server.run(config) + patch_async_run.assert_called_once_with(config) + + # Unit tests for overall actions sequence in 'async_run'. + + @pytest.mark.asyncio + async def test_async_run_actions_sequence(self) -> None: + """Test that 'async_run' triggers expected routines.""" + # Setup a server and a run config with mock attributes. + server = FederatedServer( + model=MOCK_MODEL, + netwk=MOCK_NETWK, + optim=MOCK_OPTIM, + checkpoint=mock.create_autospec(Checkpointer, instance=True), + ) + config = FLRunConfig( + rounds=10, + register=mock.create_autospec(RegisterConfig, instance=True), + training=mock.create_autospec(TrainingConfig, instance=True), + evaluate=mock.create_autospec(EvaluateConfig, instance=True), + fairness=mock.create_autospec(FairnessConfig, instance=True), + privacy=None, + early_stop=None, + ) + # Call 'async_run', mocking all underlying routines. + with mock.patch.object( + server, "initialization" + ) as patch_initialization: + with mock.patch.object(server, "training_round") as patch_training: + with mock.patch.object( + server, "evaluation_round" + ) as patch_evaluation: + with mock.patch.object( + server, "fairness_round" + ) as patch_fairness: + with mock.patch.object( + server, "stop_training" + ) as patch_stop_training: + await server.async_run(config) + # Verify that expected calls occured. + patch_initialization.assert_called_once_with(config) + patch_training.assert_has_calls( + [mock.call(idx, config.training) for idx in range(1, 11)] + ) + patch_evaluation.assert_has_calls( + [mock.call(idx, config.evaluate) for idx in range(1, 11)] + ) + patch_fairness.assert_has_calls( + [mock.call(idx, config.fairness) for idx in range(0, 10)] + + [mock.call(10, config.fairness, force_run=True)] + ) + patch_stop_training.assert_called_once_with(10)