Mentions légales du service

Skip to content
Snippets Groups Projects
Verified Commit 580ce463 authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Extend unit tests for 'FederatedServer'.

parent 25c7b773
No related branches found
No related tags found
1 merge request!70Finalize version 2.6.0
...@@ -18,9 +18,10 @@ ...@@ -18,9 +18,10 @@
"""Unit tests for 'FederatedServer'.""" """Unit tests for 'FederatedServer'."""
import logging import logging
import math
import os import os
from unittest import mock from unittest import mock
from typing import Optional, Type from typing import Dict, List, Optional, Type
import pytest # type: ignore import pytest # type: ignore
...@@ -31,8 +32,10 @@ from declearn.fairness.api import FairnessControllerServer ...@@ -31,8 +32,10 @@ from declearn.fairness.api import FairnessControllerServer
from declearn.main import FederatedServer from declearn.main import FederatedServer
from declearn.main.config import ( from declearn.main.config import (
FLOptimConfig, FLOptimConfig,
FLRunConfig,
EvaluateConfig, EvaluateConfig,
FairnessConfig, FairnessConfig,
RegisterConfig,
TrainingConfig, TrainingConfig,
) )
from declearn.main.utils import Checkpointer from declearn.main.utils import Checkpointer
...@@ -41,8 +44,15 @@ from declearn.messaging import ( ...@@ -41,8 +44,15 @@ from declearn.messaging import (
EvaluationReply, EvaluationReply,
EvaluationRequest, EvaluationRequest,
FairnessQuery, FairnessQuery,
InitReply,
InitRequest,
Message, Message,
MetadataQuery,
MetadataReply,
PrivacyReply,
PrivacyRequest,
SerializedMessage, SerializedMessage,
StopTraining,
TrainRequest, TrainRequest,
TrainReply, TrainReply,
) )
...@@ -372,8 +382,8 @@ class TestFederatedServerInit: # pylint: disable=too-many-public-methods ...@@ -372,8 +382,8 @@ class TestFederatedServerInit: # pylint: disable=too-many-public-methods
class TestFederatedServerRoutines: class TestFederatedServerRoutines:
"""Unit tests for 'FederatedServer' main unitary routines.""" """Unit tests for 'FederatedServer' main unitary routines."""
async def setup_server( @staticmethod
self, async def setup_test_server(
use_secagg: bool = False, use_secagg: bool = False,
use_fairness: bool = False, use_fairness: bool = False,
) -> FederatedServer: ) -> FederatedServer:
...@@ -391,11 +401,10 @@ class TestFederatedServerRoutines: ...@@ -391,11 +401,10 @@ class TestFederatedServerRoutines:
else None else None
), ),
) )
secagg = ( secagg = None # type: Optional[SecaggConfigServer]
mock.create_autospec(SecaggConfigServer, instance=True) if use_secagg:
if use_secagg secagg = mock.create_autospec(SecaggConfigServer, instance=True)
else None secagg.secagg_type = "mock_secagg" # type: ignore
)
return FederatedServer( return FederatedServer(
model=mock.create_autospec(Model, instance=True), model=mock.create_autospec(Model, instance=True),
netwk=netwk, netwk=netwk,
...@@ -409,7 +418,7 @@ class TestFederatedServerRoutines: ...@@ -409,7 +418,7 @@ class TestFederatedServerRoutines:
def setup_mock_serialized_message( def setup_mock_serialized_message(
msg_cls: Type[Message], msg_cls: Type[Message],
wrapped: Optional[Message] = None, wrapped: Optional[Message] = None,
) -> mock.Base: ) -> mock.NonCallableMagicMock:
"""Set up a mock SerializedMessage with given wrapped message type.""" """Set up a mock SerializedMessage with given wrapped message type."""
message = mock.create_autospec(SerializedMessage, instance=True) message = mock.create_autospec(SerializedMessage, instance=True)
message.message_cls = msg_cls message.message_cls = msg_cls
...@@ -418,6 +427,120 @@ class TestFederatedServerRoutines: ...@@ -418,6 +427,120 @@ class TestFederatedServerRoutines:
message.deserialize.return_value = wrapped message.deserialize.return_value = wrapped
return message return message
@pytest.mark.parametrize(
"metadata", [False, True], ids=["nometa", "metadata"]
)
@pytest.mark.parametrize("privacy", [False, True], ids=["nodp", "dpsgd"])
@pytest.mark.parametrize(
"fairness", [False, True], ids=["unfair", "fairness"]
)
@pytest.mark.parametrize("secagg", [False, True], ids=["clrtxt", "secagg"])
@pytest.mark.asyncio
async def test_initialization(
self,
secagg: bool,
fairness: bool,
privacy: bool,
metadata: bool,
) -> None:
"""Test that the 'initialization' routine triggers expected calls."""
# Set up a server with mocked attributes.
server = await self.setup_test_server(
use_secagg=secagg, use_fairness=fairness
)
assert isinstance(server.netwk, mock.NonCallableMagicMock)
assert isinstance(server.model, mock.NonCallableMagicMock)
server.model.required_data_info = {"n_samples"} if metadata else {}
aggrg = server.aggrg
# Run the initialization routine.
config = FLRunConfig.from_params(
rounds=10,
register=RegisterConfig(0, 2, 120),
training={"batch_size": 8},
privacy=(
{"budget": (1e-3, 0.0), "sclip_norm": 1.0} if privacy else None
),
)
server.netwk.wait_for_messages.side_effect = self._setup_init_replies(
metadata, privacy
)
await server.initialization(config)
# Verify that the clients-registration routine was called.
server.netwk.wait_for_clients.assert_awaited_once_with(0, 2, 120)
# Verify that the expected number of message exchanges occured.
assert server.netwk.broadcast_message.await_count == (
1 + metadata + privacy
)
queries = server.netwk.broadcast_message.await_args_list.copy()
# When configured, verify that metadata were queried and used.
if metadata:
query = queries.pop(0)[0][0]
assert isinstance(query, MetadataQuery)
assert query.fields == ["n_samples"]
server.model.initialize.assert_called_once_with({"n_samples": 200})
# Verify that an InitRequest was sent with expected parameters.
query = queries.pop(0)[0][0]
assert isinstance(query, InitRequest)
assert query.dpsgd is privacy
if secagg:
assert query.secagg is not None
else:
assert query.secagg is None
assert query.fairness is fairness
# Verify that DP-SGD setup occurred when expected.
if privacy:
query = queries.pop(0)[0][0]
assert isinstance(query, PrivacyRequest)
assert query.budget == (1e-3, 0.0)
assert query.sclip_norm == 1.0
assert query.rounds == 10
# Verify that SecAgg setup occurred when expected.
decrypter = None # type: Optional[Decrypter]
if secagg:
assert isinstance(server.secagg, mock.NonCallableMagicMock)
if fairness:
server.secagg.setup_decrypter.assert_awaited_once()
decrypter = server.secagg.setup_decrypter.return_value
else:
server.secagg.setup_decrypter.assert_not_called()
# Verify that fairness setup occurred when expected.
if fairness:
assert isinstance(server.fairness, mock.NonCallableMagicMock)
server.fairness.setup_fairness.assert_awaited_once_with(
netwk=server.netwk, aggregator=aggrg, secagg=decrypter
)
assert server.aggrg is server.fairness.setup_fairness.return_value
def _setup_init_replies(
self,
metadata: bool,
privacy: bool,
) -> List[Dict[str, mock.NonCallableMagicMock]]:
clients = ("client_a", "client_b")
messages = [] # type: List[Dict[str, mock.NonCallableMagicMock]]
if metadata:
msg = MetadataReply({"n_samples": 100})
messages.append(
{
key: self.setup_mock_serialized_message(MetadataReply, msg)
for key in clients
}
)
messages.append(
{
key: self.setup_mock_serialized_message(InitReply)
for key in clients
}
)
if privacy:
messages.append(
{
key: self.setup_mock_serialized_message(PrivacyReply)
for key in clients
}
)
return messages
@pytest.mark.parametrize("secagg", [False, True], ids=["clrtxt", "secagg"]) @pytest.mark.parametrize("secagg", [False, True], ids=["clrtxt", "secagg"])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_training_round( async def test_training_round(
...@@ -426,7 +549,7 @@ class TestFederatedServerRoutines: ...@@ -426,7 +549,7 @@ class TestFederatedServerRoutines:
) -> None: ) -> None:
"""Test that the 'training_round' routine triggers expected calls.""" """Test that the 'training_round' routine triggers expected calls."""
# Set up a server with mocked attributes. # Set up a server with mocked attributes.
server = await self.setup_server(use_secagg=secagg) server = await self.setup_test_server(use_secagg=secagg)
assert isinstance(server.netwk, mock.NonCallableMagicMock) assert isinstance(server.netwk, mock.NonCallableMagicMock)
assert isinstance(server.model, mock.NonCallableMagicMock) assert isinstance(server.model, mock.NonCallableMagicMock)
assert isinstance(server.optim, mock.NonCallableMagicMock) assert isinstance(server.optim, mock.NonCallableMagicMock)
...@@ -484,7 +607,7 @@ class TestFederatedServerRoutines: ...@@ -484,7 +607,7 @@ class TestFederatedServerRoutines:
) -> None: ) -> None:
"""Test that the 'evaluation_round' routine triggers expected calls.""" """Test that the 'evaluation_round' routine triggers expected calls."""
# Set up a server with mocked attributes. # Set up a server with mocked attributes.
server = await self.setup_server(use_secagg=secagg) server = await self.setup_test_server(use_secagg=secagg)
assert isinstance(server.netwk, mock.NonCallableMagicMock) assert isinstance(server.netwk, mock.NonCallableMagicMock)
assert isinstance(server.model, mock.NonCallableMagicMock) assert isinstance(server.model, mock.NonCallableMagicMock)
assert isinstance(server.metrics, mock.NonCallableMagicMock) assert isinstance(server.metrics, mock.NonCallableMagicMock)
...@@ -546,7 +669,7 @@ class TestFederatedServerRoutines: ...@@ -546,7 +669,7 @@ class TestFederatedServerRoutines:
) -> None: ) -> None:
"""Test that 'evaluation_round' skips rounds when configured.""" """Test that 'evaluation_round' skips rounds when configured."""
# Set up a server with mocked attributes. # Set up a server with mocked attributes.
server = await self.setup_server() server = await self.setup_test_server()
assert isinstance(server.netwk, mock.NonCallableMagicMock) assert isinstance(server.netwk, mock.NonCallableMagicMock)
# Mock a call that should result in skipping the round. # Mock a call that should result in skipping the round.
await server.evaluation_round( await server.evaluation_round(
...@@ -566,7 +689,9 @@ class TestFederatedServerRoutines: ...@@ -566,7 +689,9 @@ class TestFederatedServerRoutines:
) -> None: ) -> None:
"""Test that the 'fairness_round' routine triggers expected calls.""" """Test that the 'fairness_round' routine triggers expected calls."""
# Set up a server with mocked attributes. # Set up a server with mocked attributes.
server = await self.setup_server(use_secagg=secagg, use_fairness=True) server = await self.setup_test_server(
use_secagg=secagg, use_fairness=True
)
assert isinstance(server.netwk, mock.NonCallableMagicMock) assert isinstance(server.netwk, mock.NonCallableMagicMock)
assert isinstance(server.model, mock.NonCallableMagicMock) assert isinstance(server.model, mock.NonCallableMagicMock)
assert isinstance(server.fairness, mock.NonCallableMagicMock) assert isinstance(server.fairness, mock.NonCallableMagicMock)
...@@ -609,7 +734,7 @@ class TestFederatedServerRoutines: ...@@ -609,7 +734,7 @@ class TestFederatedServerRoutines:
) -> None: ) -> None:
"""Test that 'fairness_round' early-exits when fairness is not set.""" """Test that 'fairness_round' early-exits when fairness is not set."""
# Set up a server with mocked attributes and no fairness controller. # Set up a server with mocked attributes and no fairness controller.
server = await self.setup_server(use_fairness=False) server = await self.setup_test_server(use_fairness=False)
assert isinstance(server.netwk, mock.NonCallableMagicMock) assert isinstance(server.netwk, mock.NonCallableMagicMock)
assert server.fairness is None assert server.fairness is None
# Call the fairness round routine.1 # Call the fairness round routine.1
...@@ -628,7 +753,7 @@ class TestFederatedServerRoutines: ...@@ -628,7 +753,7 @@ class TestFederatedServerRoutines:
) -> None: ) -> None:
"""Test that 'fairness_round' skips rounds when configured.""" """Test that 'fairness_round' skips rounds when configured."""
# Set up a server with a mocked fairness controller. # Set up a server with a mocked fairness controller.
server = await self.setup_server(use_fairness=True) server = await self.setup_test_server(use_fairness=True)
assert isinstance(server.fairness, mock.NonCallableMagicMock) assert isinstance(server.fairness, mock.NonCallableMagicMock)
# Mock a call that should result in skipping the round. # Mock a call that should result in skipping the round.
await server.fairness_round( await server.fairness_round(
...@@ -637,3 +762,145 @@ class TestFederatedServerRoutines: ...@@ -637,3 +762,145 @@ class TestFederatedServerRoutines:
) )
# Assert that the round was skipped. # Assert that the round was skipped.
server.fairness.run_fairness_round.assert_not_called() server.fairness.run_fairness_round.assert_not_called()
@pytest.mark.asyncio
async def test_stop_training(
self,
) -> None:
"""Test that 'stop_training' triggers expected actions."""
# Set up a server with mocked attributes.
server = await self.setup_test_server()
assert isinstance(server.netwk, mock.NonCallableMagicMock)
assert isinstance(server.model, mock.NonCallableMagicMock)
assert isinstance(server.ckptr, mock.NonCallableMagicMock)
server.ckptr.folder = "mock_folder"
# Call the 'stop_training' routine.
await server.stop_training(rounds=5)
# Verify that the expected message was broadcasted.
server.netwk.broadcast_message.assert_awaited_once()
message = server.netwk.broadcast_message.await_args[0][0]
assert isinstance(message, StopTraining)
assert message.weights is server.model.get_weights.return_value
assert math.isnan(message.loss)
assert message.rounds == 5
# Verify that the expected checkpointing occured.
server.ckptr.save_model.assert_called_once_with(
server.model, timestamp="best"
)
class TestFederatedServerRun:
"""Unit tests for 'FederatedServer.run' and 'async_run' routines."""
# Unit tests for FLRunConfig parsing via synchronous 'run' method.
def test_run_from_dict(
self,
) -> None:
"""Test that 'run' properly parses input dict config.
Mock the actual underlying routine.
"""
server = FederatedServer(
model=MOCK_MODEL, netwk=MOCK_NETWK, optim=MOCK_OPTIM
)
config = mock.create_autospec(dict, instance=True)
with mock.patch.object(
FLRunConfig,
"from_params",
return_value=mock.create_autospec(FLRunConfig, instance=True),
) as patch_flrunconfig_from_params:
with mock.patch.object(server, "async_run") as patch_async_run:
server.run(config)
patch_flrunconfig_from_params.assert_called_once_with(**config)
patch_async_run.assert_called_once_with(
patch_flrunconfig_from_params.return_value
)
def test_run_from_toml(
self,
) -> None:
"""Test that 'run' properly parses input TOML file.
Mock the actual underlying routine.
"""
server = FederatedServer(
model=MOCK_MODEL, netwk=MOCK_NETWK, optim=MOCK_OPTIM
)
config = "mock_path.toml"
with mock.patch.object(
FLRunConfig,
"from_toml",
return_value=mock.create_autospec(FLRunConfig, instance=True),
) as patch_flrunconfig_from_toml:
with mock.patch.object(server, "async_run") as patch_async_run:
server.run(config)
patch_flrunconfig_from_toml.assert_called_once_with(config)
patch_async_run.assert_called_once_with(
patch_flrunconfig_from_toml.return_value
)
def test_run_from_config(
self,
) -> None:
"""Test that 'run' properly uses input FLRunConfig.
Mock the actual underlying routine.
"""
server = FederatedServer(
model=MOCK_MODEL, netwk=MOCK_NETWK, optim=MOCK_OPTIM
)
config = mock.create_autospec(FLRunConfig, instance=True)
with mock.patch.object(server, "async_run") as patch_async_run:
server.run(config)
patch_async_run.assert_called_once_with(config)
# Unit tests for overall actions sequence in 'async_run'.
@pytest.mark.asyncio
async def test_async_run_actions_sequence(self) -> None:
"""Test that 'async_run' triggers expected routines."""
# Setup a server and a run config with mock attributes.
server = FederatedServer(
model=MOCK_MODEL,
netwk=MOCK_NETWK,
optim=MOCK_OPTIM,
checkpoint=mock.create_autospec(Checkpointer, instance=True),
)
config = FLRunConfig(
rounds=10,
register=mock.create_autospec(RegisterConfig, instance=True),
training=mock.create_autospec(TrainingConfig, instance=True),
evaluate=mock.create_autospec(EvaluateConfig, instance=True),
fairness=mock.create_autospec(FairnessConfig, instance=True),
privacy=None,
early_stop=None,
)
# Call 'async_run', mocking all underlying routines.
with mock.patch.object(
server, "initialization"
) as patch_initialization:
with mock.patch.object(server, "training_round") as patch_training:
with mock.patch.object(
server, "evaluation_round"
) as patch_evaluation:
with mock.patch.object(
server, "fairness_round"
) as patch_fairness:
with mock.patch.object(
server, "stop_training"
) as patch_stop_training:
await server.async_run(config)
# Verify that expected calls occured.
patch_initialization.assert_called_once_with(config)
patch_training.assert_has_calls(
[mock.call(idx, config.training) for idx in range(1, 11)]
)
patch_evaluation.assert_has_calls(
[mock.call(idx, config.evaluate) for idx in range(1, 11)]
)
patch_fairness.assert_has_calls(
[mock.call(idx, config.fairness) for idx in range(0, 10)]
+ [mock.call(10, config.fairness, force_run=True)]
)
patch_stop_training.assert_called_once_with(10)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment