From 248f73b7c2802ff230d2845789c77b8fb5bd537e Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Thu, 20 Jun 2024 16:23:31 +0200 Subject: [PATCH] Add unit tests for 'FederatedServer' covering fairness features. --- test/main/test_main_server.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test/main/test_main_server.py b/test/main/test_main_server.py index e1c37a4..1ae08bb 100644 --- a/test/main/test_main_server.py +++ b/test/main/test_main_server.py @@ -26,6 +26,7 @@ import pytest # type: ignore from declearn.aggregator import Aggregator from declearn.communication import NetworkServerConfig from declearn.communication.api import NetworkServer +from declearn.fairness.api import FairnessControllerServer from declearn.main import FederatedServer from declearn.main.config import FLOptimConfig from declearn.main.utils import Checkpointer @@ -144,6 +145,9 @@ class TestFederatedServerInit: # pylint: disable=too-many-public-methods client_opt=mock.create_autospec(Optimizer, instance=True), server_opt=mock.create_autospec(Optimizer, instance=True), aggregator=mock.create_autospec(Aggregator, instance=True), + fairness=mock.create_autospec( + FairnessControllerServer, instance=True + ), ) server = FederatedServer( model=MOCK_MODEL, netwk=MOCK_NETWK, optim=optim @@ -151,6 +155,7 @@ class TestFederatedServerInit: # pylint: disable=too-many-public-methods assert server.c_opt is optim.client_opt assert server.optim is optim.server_opt assert server.aggrg is optim.aggregator + assert server.fairness is optim.fairness def test_optim_dict(self) -> None: """Test specifying 'optim' as a config dict.""" @@ -158,6 +163,9 @@ class TestFederatedServerInit: # pylint: disable=too-many-public-methods "client_opt": mock.create_autospec(Optimizer, instance=True), "server_opt": mock.create_autospec(Optimizer, instance=True), "aggregator": mock.create_autospec(Aggregator, instance=True), + "fairness": mock.create_autospec( + FairnessControllerServer, instance=True + ), } server = FederatedServer( model=MOCK_MODEL, netwk=MOCK_NETWK, optim=optim @@ -165,6 +173,7 @@ class TestFederatedServerInit: # pylint: disable=too-many-public-methods assert server.c_opt is optim["client_opt"] assert server.optim is optim["server_opt"] assert server.aggrg is optim["aggregator"] + assert server.fairness is optim["fairness"] def test_optim_toml(self, tmp_path: str) -> None: """Test specifying 'optim' as a TOML file path.""" @@ -192,6 +201,7 @@ class TestFederatedServerInit: # pylint: disable=too-many-public-methods assert server.c_opt.get_config() == config.client_opt.get_config() assert server.optim.get_config() == config.server_opt.get_config() assert server.aggrg.get_config() == config.aggregator.get_config() + assert server.fairness is None def test_optim_invalid(self) -> None: """Test specifying 'optim' with an invalid type.""" -- GitLab