Mentions légales du service

Skip to content
Snippets Groups Projects
Verified Commit 248f73b7 authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Add unit tests for 'FederatedServer' covering fairness features.

parent f4189e37
No related branches found
No related tags found
1 merge request!69Enable Fairness-Aware Federated Learning
...@@ -26,6 +26,7 @@ import pytest # type: ignore ...@@ -26,6 +26,7 @@ import pytest # type: ignore
from declearn.aggregator import Aggregator from declearn.aggregator import Aggregator
from declearn.communication import NetworkServerConfig from declearn.communication import NetworkServerConfig
from declearn.communication.api import NetworkServer from declearn.communication.api import NetworkServer
from declearn.fairness.api import FairnessControllerServer
from declearn.main import FederatedServer from declearn.main import FederatedServer
from declearn.main.config import FLOptimConfig from declearn.main.config import FLOptimConfig
from declearn.main.utils import Checkpointer from declearn.main.utils import Checkpointer
...@@ -144,6 +145,9 @@ class TestFederatedServerInit: # pylint: disable=too-many-public-methods ...@@ -144,6 +145,9 @@ class TestFederatedServerInit: # pylint: disable=too-many-public-methods
client_opt=mock.create_autospec(Optimizer, instance=True), client_opt=mock.create_autospec(Optimizer, instance=True),
server_opt=mock.create_autospec(Optimizer, instance=True), server_opt=mock.create_autospec(Optimizer, instance=True),
aggregator=mock.create_autospec(Aggregator, instance=True), aggregator=mock.create_autospec(Aggregator, instance=True),
fairness=mock.create_autospec(
FairnessControllerServer, instance=True
),
) )
server = FederatedServer( server = FederatedServer(
model=MOCK_MODEL, netwk=MOCK_NETWK, optim=optim model=MOCK_MODEL, netwk=MOCK_NETWK, optim=optim
...@@ -151,6 +155,7 @@ class TestFederatedServerInit: # pylint: disable=too-many-public-methods ...@@ -151,6 +155,7 @@ class TestFederatedServerInit: # pylint: disable=too-many-public-methods
assert server.c_opt is optim.client_opt assert server.c_opt is optim.client_opt
assert server.optim is optim.server_opt assert server.optim is optim.server_opt
assert server.aggrg is optim.aggregator assert server.aggrg is optim.aggregator
assert server.fairness is optim.fairness
def test_optim_dict(self) -> None: def test_optim_dict(self) -> None:
"""Test specifying 'optim' as a config dict.""" """Test specifying 'optim' as a config dict."""
...@@ -158,6 +163,9 @@ class TestFederatedServerInit: # pylint: disable=too-many-public-methods ...@@ -158,6 +163,9 @@ class TestFederatedServerInit: # pylint: disable=too-many-public-methods
"client_opt": mock.create_autospec(Optimizer, instance=True), "client_opt": mock.create_autospec(Optimizer, instance=True),
"server_opt": mock.create_autospec(Optimizer, instance=True), "server_opt": mock.create_autospec(Optimizer, instance=True),
"aggregator": mock.create_autospec(Aggregator, instance=True), "aggregator": mock.create_autospec(Aggregator, instance=True),
"fairness": mock.create_autospec(
FairnessControllerServer, instance=True
),
} }
server = FederatedServer( server = FederatedServer(
model=MOCK_MODEL, netwk=MOCK_NETWK, optim=optim model=MOCK_MODEL, netwk=MOCK_NETWK, optim=optim
...@@ -165,6 +173,7 @@ class TestFederatedServerInit: # pylint: disable=too-many-public-methods ...@@ -165,6 +173,7 @@ class TestFederatedServerInit: # pylint: disable=too-many-public-methods
assert server.c_opt is optim["client_opt"] assert server.c_opt is optim["client_opt"]
assert server.optim is optim["server_opt"] assert server.optim is optim["server_opt"]
assert server.aggrg is optim["aggregator"] assert server.aggrg is optim["aggregator"]
assert server.fairness is optim["fairness"]
def test_optim_toml(self, tmp_path: str) -> None: def test_optim_toml(self, tmp_path: str) -> None:
"""Test specifying 'optim' as a TOML file path.""" """Test specifying 'optim' as a TOML file path."""
...@@ -192,6 +201,7 @@ class TestFederatedServerInit: # pylint: disable=too-many-public-methods ...@@ -192,6 +201,7 @@ class TestFederatedServerInit: # pylint: disable=too-many-public-methods
assert server.c_opt.get_config() == config.client_opt.get_config() assert server.c_opt.get_config() == config.client_opt.get_config()
assert server.optim.get_config() == config.server_opt.get_config() assert server.optim.get_config() == config.server_opt.get_config()
assert server.aggrg.get_config() == config.aggregator.get_config() assert server.aggrg.get_config() == config.aggregator.get_config()
assert server.fairness is None
def test_optim_invalid(self) -> None: def test_optim_invalid(self) -> None:
"""Test specifying 'optim' with an invalid type.""" """Test specifying 'optim' with an invalid type."""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment