From 71e66792f05a1f4035b8298767e98b172f93c549 Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Fri, 12 Jul 2024 10:32:32 +0200
Subject: [PATCH] Add unit tests for FairFed backend tools.

---
 .../algorithms/test_fairfed_aggregator.py     |  74 +++++++++
 .../algorithms/test_fairfed_computer.py       | 141 ++++++++++++++++++
 2 files changed, 215 insertions(+)
 create mode 100644 test/fairness/algorithms/test_fairfed_aggregator.py
 create mode 100644 test/fairness/algorithms/test_fairfed_computer.py

diff --git a/test/fairness/algorithms/test_fairfed_aggregator.py b/test/fairness/algorithms/test_fairfed_aggregator.py
new file mode 100644
index 0000000..b7ad02f
--- /dev/null
+++ b/test/fairness/algorithms/test_fairfed_aggregator.py
@@ -0,0 +1,74 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for FairFed-specific Aggregator subclass."""
+
+from unittest import mock
+
+
+from declearn.fairness.fairfed import FairfedAggregator
+from declearn.model.api import Vector
+
+
+class TestFairfedAggregator:
+    """Unit tests for 'declearn.fairness.fairfed.FairfedAggregator'."""
+
+    def test_init_beta(self) -> None:
+        """Test that the 'beta' parameter is properly assigned."""
+        beta = mock.create_autospec(float, instance=True)
+        aggregator = FairfedAggregator(beta=beta)
+        assert aggregator.beta is beta
+
+    def test_prepare_for_sharing_initial(self) -> None:
+        """Test that 'prepare_for_sharing' has expected outputs at first."""
+        # Set up an uninitialized aggregator and prepare mock updates.
+        aggregator = FairfedAggregator(beta=1.0)
+        updates = mock.create_autospec(Vector, instance=True)
+        model_updates = aggregator.prepare_for_sharing(updates, n_steps=10)
+        # Verify that outputs match expectations.
+        updates.__mul__.assert_called_once_with(1.0)
+        assert model_updates.updates is updates.__mul__.return_value
+        assert model_updates.weights == 1.0
+
+    def test_initialize_local_weight(self) -> None:
+        """Test that 'initialize_local_weight' works properly."""
+        # Set up an aggregator, initialize it and prepare mock updates.
+        n_samples = 100
+        aggregator = FairfedAggregator(beta=1.0)
+        aggregator.initialize_local_weight(n_samples=n_samples)
+        updates = mock.create_autospec(Vector, instance=True)
+        model_updates = aggregator.prepare_for_sharing(updates, n_steps=10)
+        # Verify that outputs match expectations.
+        updates.__mul__.assert_called_once_with(n_samples)
+        assert model_updates.updates is updates.__mul__.return_value
+        assert model_updates.weights == n_samples
+
+    def test_update_local_weight(self) -> None:
+        """Test that 'update_local_weight' works properly."""
+        # Set up a FairFed aggregator and initialize it.
+        n_samples = 100
+        aggregator = FairfedAggregator(beta=0.1)
+        aggregator.initialize_local_weight(n_samples=n_samples)
+        # Perform a local wiehgt update with arbitrary values.
+        aggregator.update_local_weight(delta_loc=2.0, delta_avg=5.0)
+        # Verify that updates have expected weight.
+        updates = mock.create_autospec(Vector, instance=True)
+        expectw = n_samples - 0.1 * (2.0 - 5.0)  # w_0 - beta * diff_delta
+        model_updates = aggregator.prepare_for_sharing(updates, n_steps=10)
+        updates.__mul__.assert_called_once_with(expectw)
+        assert model_updates.updates is updates.__mul__.return_value
+        assert model_updates.weights == expectw
diff --git a/test/fairness/algorithms/test_fairfed_computer.py b/test/fairness/algorithms/test_fairfed_computer.py
new file mode 100644
index 0000000..959d6f4
--- /dev/null
+++ b/test/fairness/algorithms/test_fairfed_computer.py
@@ -0,0 +1,141 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for FairFed-specific fairness value computer."""
+
+import warnings
+from typing import Any, List, Tuple
+
+import pytest
+
+from declearn.fairness.fairfed import FairfedValueComputer
+
+
+GROUPS_BINARY = [
+    (target, s_attr) for target in (0, 1) for s_attr in (0, 1)
+]  # type: List[Tuple[Any, ...]]
+GROUPS_EXTEND = [
+    (tgt, s_a, s_b) for tgt in (0, 1, 2) for s_a in (0, 1) for s_b in (1, 2)
+]  # type: List[Tuple[Any, ...]]
+F_TYPES = [
+    "accuracy_parity",
+    "demographic_parity",
+    "equality_of_opportunity",
+    "equalized_odds",
+]
+
+
+class TestFairfedValueComputer:
+    """Unit tests for 'declearn.fairness.fairfed.FairfedValueComputer'."""
+
+    @pytest.mark.parametrize("target", [1, 0], ids=["target1", "target0"])
+    @pytest.mark.parametrize("f_type", F_TYPES)
+    def test_identify_key_groups_binary(
+        self,
+        f_type: str,
+        target: int,
+    ) -> None:
+        """Test 'identify_key_groups' with binary target and attribute."""
+        computer = FairfedValueComputer(f_type, strict=True, target=target)
+        if f_type == "accuracy_parity":
+            with pytest.warns(RuntimeWarning):
+                key_groups = computer.identify_key_groups(GROUPS_BINARY.copy())
+        else:
+            key_groups = computer.identify_key_groups(GROUPS_BINARY.copy())
+        assert key_groups == ((target, 0), (target, 1))
+
+    @pytest.mark.parametrize("f_type", F_TYPES)
+    def test_identify_key_groups_extended_exception(
+        self,
+        f_type: str,
+    ) -> None:
+        """Test 'identify_key_groups' exception raising with extended groups.
+
+        'Extended' groups arise from a non-binary label intersected with
+        two distinct binary sensitive groups.
+        """
+        computer = FairfedValueComputer(f_type, strict=True, target=1)
+        with pytest.raises(RuntimeError):
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore", RuntimeWarning)
+                computer.identify_key_groups(GROUPS_EXTEND.copy())
+
+    @pytest.mark.parametrize("binary", [True, False], ids=["binary", "extend"])
+    @pytest.mark.parametrize("strict", [True, False], ids=["strict", "free"])
+    @pytest.mark.parametrize("f_type", F_TYPES[1:])  # avoid warning on AccPar
+    def test_initialize(
+        self,
+        f_type: str,
+        strict: bool,
+        binary: bool,
+    ) -> None:
+        """Test that 'initialize' raises an exception in expected cases."""
+        computer = FairfedValueComputer(f_type, strict=strict, target=1)
+        groups = (GROUPS_BINARY if binary else GROUPS_EXTEND).copy()
+        if strict and not binary:
+            with pytest.raises(RuntimeError):
+                computer.initialize(groups)
+        else:
+            computer.initialize(groups)
+
+    @pytest.mark.parametrize("strict", [True, False], ids=["strict", "free"])
+    def test_compute_synthetic_fairness_value_binary(
+        self,
+        strict: bool,
+    ) -> None:
+        """Test 'compute_synthetic_fairness_value' with 4 groups.
+
+        This test only applies to both strict and non-strict modes.
+        """
+        # Compute a synthetic value using arbitrary inputs.
+        fairness = {
+            group: float(idx) for idx, group in enumerate(GROUPS_BINARY)
+        }
+        computer = FairfedValueComputer(
+            f_type="demographic_parity",
+            strict=strict,
+            target=1,
+        )
+        computer.initialize(list(fairness))
+        value = computer.compute_synthetic_fairness_value(fairness)
+        # Verify that the ouput value matches expectations.
+        if strict:
+            expected = fairness[(1, 0)] - fairness[(1, 1)]
+        else:
+            expected = sum(fairness.values()) / len(fairness)
+        assert value == expected
+
+    def test_compute_synthetic_fairness_value_extended(
+        self,
+    ) -> None:
+        """Test 'compute_synthetic_fairness_value' with many groups.
+
+        This test only applies to the non-strict mode.
+        """
+        # Compute a synthetic value using arbitrary inputs.
+        fairness = {
+            group: float(idx) for idx, group in enumerate(GROUPS_EXTEND)
+        }
+        computer = FairfedValueComputer(
+            f_type="demographic_parity",
+            strict=False,
+        )
+        computer.initialize(list(fairness))
+        value = computer.compute_synthetic_fairness_value(fairness)
+        # Verify that the ouput value matches expectations.
+        expected = sum(fairness.values()) / len(fairness)
+        assert value == expected
-- 
GitLab