From a87e891ea5805b65a87c730f7437bf4670558ac0 Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Mon, 15 Jul 2024 15:51:33 +0200
Subject: [PATCH] Improve coverage of 'fairness' unit tests.

---
 .../algorithms/test_fairfed_aggregator.py     | 13 +++++
 .../algorithms/test_fairfed_computer.py       | 16 ++++++
 .../controllers/test_fairbatch_controllers.py | 57 ++++++++++++++++++-
 .../controllers/test_fairfed_controllers.py   | 30 ++++++++++
 .../controllers/test_fairgrad_controllers.py  | 54 +++++++++++++++++-
 5 files changed, 165 insertions(+), 5 deletions(-)

diff --git a/test/fairness/algorithms/test_fairfed_aggregator.py b/test/fairness/algorithms/test_fairfed_aggregator.py
index b7ad02f..2238ad6 100644
--- a/test/fairness/algorithms/test_fairfed_aggregator.py
+++ b/test/fairness/algorithms/test_fairfed_aggregator.py
@@ -72,3 +72,16 @@ class TestFairfedAggregator:
         updates.__mul__.assert_called_once_with(expectw)
         assert model_updates.updates is updates.__mul__.return_value
         assert model_updates.weights == expectw
+
+    def test_finalize_updates(self) -> None:
+        """Test that 'finalize_updates' works as expected."""
+        # Set up a FairFed aggregator and initialize it.
+        n_samples = 100
+        aggregator = FairfedAggregator(beta=0.1)
+        aggregator.initialize_local_weight(n_samples=n_samples)
+        # Prepare, then finalize updates.
+        updates = mock.create_autospec(Vector, instance=True)
+        output = aggregator.finalize_updates(
+            aggregator.prepare_for_sharing(updates, n_steps=mock.MagicMock())
+        )
+        assert output == (updates * n_samples / n_samples)
diff --git a/test/fairness/algorithms/test_fairfed_computer.py b/test/fairness/algorithms/test_fairfed_computer.py
index 959d6f4..b38e883 100644
--- a/test/fairness/algorithms/test_fairfed_computer.py
+++ b/test/fairness/algorithms/test_fairfed_computer.py
@@ -74,6 +74,22 @@ class TestFairfedValueComputer:
                 warnings.simplefilter("ignore", RuntimeWarning)
                 computer.identify_key_groups(GROUPS_EXTEND.copy())
 
+    @pytest.mark.parametrize("f_type", F_TYPES)
+    def test_identify_key_groups_hybrid_exception(
+        self,
+        f_type: str,
+    ) -> None:
+        """Test 'identify_key_groups' exception raising with 'hybrid' groups.
+
+        'Hybrid' groups are groups that seemingly arise from a categorical
+        target that does not cross all sensitive attribute modalities.
+        """
+        computer = FairfedValueComputer(f_type, strict=True, target=1)
+        with pytest.raises(KeyError):
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore", RuntimeWarning)
+                computer.identify_key_groups([(0, 0), (0, 1), (1, 0), (2, 1)])
+
     @pytest.mark.parametrize("binary", [True, False], ids=["binary", "extend"])
     @pytest.mark.parametrize("strict", [True, False], ids=["strict", "free"])
     @pytest.mark.parametrize("f_type", F_TYPES[1:])  # avoid warning on AccPar
diff --git a/test/fairness/controllers/test_fairbatch_controllers.py b/test/fairness/controllers/test_fairbatch_controllers.py
index 5f55948..f6e5a67 100644
--- a/test/fairness/controllers/test_fairbatch_controllers.py
+++ b/test/fairness/controllers/test_fairbatch_controllers.py
@@ -17,6 +17,7 @@
 
 """Unit tests for Fed-FairBatch controllers."""
 
+import asyncio
 import os
 from typing import List
 from unittest import mock
@@ -24,6 +25,7 @@ from unittest import mock
 import pytest
 
 from declearn.aggregator import Aggregator, SumAggregator
+from declearn.communication.utils import ErrorMessageException
 from declearn.fairness.api import (
     FairnessControllerClient,
     FairnessControllerServer,
@@ -34,7 +36,7 @@ from declearn.fairness.fairbatch import (
     FairbatchDataset,
     FairbatchSamplingController,
 )
-from declearn.test_utils import make_importable
+from declearn.test_utils import make_importable, setup_mock_network_endpoints
 
 with make_importable(os.path.dirname(os.path.abspath(__file__))):
     from fairness_controllers_testing import (
@@ -125,18 +127,20 @@ class TestFairbatchControllers(FairnessControllerTestSuite):
         with mock.patch(
             "declearn.fairness.fairbatch._server.setup_fairbatch_controller"
         ) as patch_setup_fairbatch:
-            FairbatchControllerServer(
+            controller = FairbatchControllerServer(
                 f_type="demographic_parity",
                 fedfb=False,
             )
+            assert not controller.fedfb
             patch_setup_fairbatch.assert_called_once()
         with mock.patch(
             "declearn.fairness.fairbatch._server.setup_fedfb_controller"
         ) as patch_setup_fedfb:
-            FairbatchControllerServer(
+            controller = FairbatchControllerServer(
                 f_type="demographic_parity",
                 fedfb=True,
             )
+            assert controller.fedfb
             patch_setup_fedfb.assert_called_once()
 
     def test_init_alpha_param(self) -> None:
@@ -146,3 +150,50 @@ class TestFairbatchControllers(FairnessControllerTestSuite):
             f_type="demographic_parity", alpha=alpha
         )
         assert server.sampling_controller.alpha is alpha
+
+    @pytest.mark.asyncio
+    async def test_finalize_fairness_setup_error(
+        self,
+    ) -> None:
+        """Test that FairBatch probas update error-catching works properly."""
+        n_peers = len(CLIENT_COUNTS)
+        # Instantiate the fairness controllers.
+        server = self.setup_server_controller()
+        clients = [
+            self.setup_client_controller_from_server(server, idx)
+            for idx in range(n_peers)
+        ]
+        # Assign expected group definitions and counts.
+        server.groups = sorted(list(TOTAL_COUNTS))
+        for client in clients:
+            client.groups = server.groups.copy()
+        counts = [TOTAL_COUNTS[group] for group in server.groups]
+        # Run setup coroutines, using mock network endpoints.
+        aggregator = mock.create_autospec(SumAggregator, instance=True)
+        async with setup_mock_network_endpoints(n_peers) as network:
+            coro_server = server.finalize_fairness_setup(
+                netwk=network[0],
+                secagg=None,
+                counts=counts,
+                aggregator=aggregator,
+            )
+            coro_clients = [
+                client.finalize_fairness_setup(
+                    netwk=network[1][idx],
+                    secagg=None,
+                )
+                for idx, client in enumerate(clients)
+            ]
+            # Have the sampling probabilities' assignment fail.
+            with mock.patch.object(
+                FairbatchDataset,
+                "set_sampling_probabilities",
+                side_effect=Exception,
+            ) as patch_set_sampling_probabilities:
+                exc_server, *exc_clients = await asyncio.gather(
+                    coro_server, *coro_clients, return_exceptions=True
+                )
+        # Assert that expected exceptions were raised.
+        assert isinstance(exc_server, ErrorMessageException)
+        assert all(isinstance(exc, RuntimeError) for exc in exc_clients)
+        assert patch_set_sampling_probabilities.call_count == n_peers
diff --git a/test/fairness/controllers/test_fairfed_controllers.py b/test/fairness/controllers/test_fairfed_controllers.py
index c6e1513..157d879 100644
--- a/test/fairness/controllers/test_fairfed_controllers.py
+++ b/test/fairness/controllers/test_fairfed_controllers.py
@@ -156,3 +156,33 @@ class TestFairfedControllers(FairnessControllerTestSuite):
         assert server["fairfed_deltavg"] == (
             sum(client["fairfed_delta"] for client in clients) / len(clients)
         )
+
+    @pytest.mark.parametrize(
+        "strict", [True, False], ids=["strict", "extended"]
+    )
+    def test_init_params(
+        self,
+        strict: bool,
+    ) -> None:
+        """Test that instantiation parameters are properly passed."""
+        rng = np.random.default_rng()
+        beta = abs(rng.normal())
+        target = int(rng.choice(2))
+        controller = FairfedControllerServer(
+            f_type="demographic_parity",
+            beta=beta,
+            strict=strict,
+            target=target,
+        )
+        assert controller.beta == beta
+        assert controller.fairfed_computer.f_type == "demographic_parity"
+        assert controller.strict is strict
+        assert controller.fairfed_computer.strict is strict
+        assert controller.fairfed_computer.target is target
+        # Verify that parameters are transmitted to clients.
+        client = self.setup_client_controller_from_server(controller, idx=0)
+        assert isinstance(client, FairfedControllerClient)
+        assert client.beta == controller.beta
+        assert client.fairfed_computer.f_type == "demographic_parity"
+        assert client.strict is strict
+        assert client.fairfed_computer.strict is strict
diff --git a/test/fairness/controllers/test_fairgrad_controllers.py b/test/fairness/controllers/test_fairgrad_controllers.py
index 385122c..bd4af8c 100644
--- a/test/fairness/controllers/test_fairgrad_controllers.py
+++ b/test/fairness/controllers/test_fairgrad_controllers.py
@@ -17,6 +17,7 @@
 
 """Unit tests for Fed-FairGrad controllers."""
 
+import asyncio
 import os
 from typing import List
 from unittest import mock
@@ -24,6 +25,7 @@ from unittest import mock
 import pytest
 
 from declearn.aggregator import Aggregator, SumAggregator
+from declearn.communication.utils import ErrorMessageException
 from declearn.fairness.api import (
     FairnessDataset,
     FairnessControllerClient,
@@ -34,10 +36,14 @@ from declearn.fairness.fairgrad import (
     FairgradControllerServer,
     FairgradWeightsController,
 )
-from declearn.test_utils import make_importable
+from declearn.test_utils import make_importable, setup_mock_network_endpoints
 
 with make_importable(os.path.dirname(os.path.abspath(__file__))):
-    from fairness_controllers_testing import FairnessControllerTestSuite
+    from fairness_controllers_testing import (
+        CLIENT_COUNTS,
+        TOTAL_COUNTS,
+        FairnessControllerTestSuite,
+    )
 
 
 class TestFairgradControllers(FairnessControllerTestSuite):
@@ -103,3 +109,47 @@ class TestFairgradControllers(FairnessControllerTestSuite):
         self.verify_fairness_round_metrics(metrics)
         patch_update_weights.assert_called_once()
         self.verify_fairgrad_weights_coherence(server, clients)
+
+    @pytest.mark.asyncio
+    async def test_finalize_fairness_setup_error(
+        self,
+    ) -> None:
+        """Test that FairGrad weights setup error-catching works properly."""
+        n_peers = len(CLIENT_COUNTS)
+        # Instantiate the fairness controllers.
+        server = self.setup_server_controller()
+        clients = [
+            self.setup_client_controller_from_server(server, idx)
+            for idx in range(n_peers)
+        ]
+        # Assign expected group definitions and counts.
+        # Have client datasets fail upon receiving sensitive group weights.
+        server.groups = sorted(list(TOTAL_COUNTS))
+        for client in clients:
+            client.groups = server.groups.copy()
+            mock_dst = client.manager.train_data
+            assert isinstance(mock_dst, mock.NonCallableMagicMock)
+            mock_dst.set_sensitive_group_weights.side_effect = Exception
+        counts = [TOTAL_COUNTS[group] for group in server.groups]
+        # Run setup coroutines, using mock network endpoints.
+        aggregator = mock.create_autospec(SumAggregator, instance=True)
+        async with setup_mock_network_endpoints(n_peers) as network:
+            coro_server = server.finalize_fairness_setup(
+                netwk=network[0],
+                secagg=None,
+                counts=counts,
+                aggregator=aggregator,
+            )
+            coro_clients = [
+                client.finalize_fairness_setup(
+                    netwk=network[1][idx],
+                    secagg=None,
+                )
+                for idx, client in enumerate(clients)
+            ]
+            exc_server, *exc_clients = await asyncio.gather(
+                coro_server, *coro_clients, return_exceptions=True
+            )
+        # Assert that expected exceptions were raised.
+        assert isinstance(exc_server, ErrorMessageException)
+        assert all(isinstance(exc, RuntimeError) for exc in exc_clients)
-- 
GitLab