From 248f73b7c2802ff230d2845789c77b8fb5bd537e Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Thu, 20 Jun 2024 16:23:31 +0200
Subject: [PATCH] Add unit tests for 'FederatedServer' covering fairness
 features.

---
 test/main/test_main_server.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/test/main/test_main_server.py b/test/main/test_main_server.py
index e1c37a4..1ae08bb 100644
--- a/test/main/test_main_server.py
+++ b/test/main/test_main_server.py
@@ -26,6 +26,7 @@ import pytest  # type: ignore
 from declearn.aggregator import Aggregator
 from declearn.communication import NetworkServerConfig
 from declearn.communication.api import NetworkServer
+from declearn.fairness.api import FairnessControllerServer
 from declearn.main import FederatedServer
 from declearn.main.config import FLOptimConfig
 from declearn.main.utils import Checkpointer
@@ -144,6 +145,9 @@ class TestFederatedServerInit:  # pylint: disable=too-many-public-methods
             client_opt=mock.create_autospec(Optimizer, instance=True),
             server_opt=mock.create_autospec(Optimizer, instance=True),
             aggregator=mock.create_autospec(Aggregator, instance=True),
+            fairness=mock.create_autospec(
+                FairnessControllerServer, instance=True
+            ),
         )
         server = FederatedServer(
             model=MOCK_MODEL, netwk=MOCK_NETWK, optim=optim
@@ -151,6 +155,7 @@ class TestFederatedServerInit:  # pylint: disable=too-many-public-methods
         assert server.c_opt is optim.client_opt
         assert server.optim is optim.server_opt
         assert server.aggrg is optim.aggregator
+        assert server.fairness is optim.fairness
 
     def test_optim_dict(self) -> None:
         """Test specifying 'optim' as a config dict."""
@@ -158,6 +163,9 @@ class TestFederatedServerInit:  # pylint: disable=too-many-public-methods
             "client_opt": mock.create_autospec(Optimizer, instance=True),
             "server_opt": mock.create_autospec(Optimizer, instance=True),
             "aggregator": mock.create_autospec(Aggregator, instance=True),
+            "fairness": mock.create_autospec(
+                FairnessControllerServer, instance=True
+            ),
         }
         server = FederatedServer(
             model=MOCK_MODEL, netwk=MOCK_NETWK, optim=optim
@@ -165,6 +173,7 @@ class TestFederatedServerInit:  # pylint: disable=too-many-public-methods
         assert server.c_opt is optim["client_opt"]
         assert server.optim is optim["server_opt"]
         assert server.aggrg is optim["aggregator"]
+        assert server.fairness is optim["fairness"]
 
     def test_optim_toml(self, tmp_path: str) -> None:
         """Test specifying 'optim' as a TOML file path."""
@@ -192,6 +201,7 @@ class TestFederatedServerInit:  # pylint: disable=too-many-public-methods
         assert server.c_opt.get_config() == config.client_opt.get_config()
         assert server.optim.get_config() == config.server_opt.get_config()
         assert server.aggrg.get_config() == config.aggregator.get_config()
+        assert server.fairness is None
 
     def test_optim_invalid(self) -> None:
         """Test specifying 'optim' with an invalid type."""
-- 
GitLab