From 71e66792f05a1f4035b8298767e98b172f93c549 Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Fri, 12 Jul 2024 10:32:32 +0200 Subject: [PATCH] Add unit tests for FairFed backend tools. --- .../algorithms/test_fairfed_aggregator.py | 74 +++++++++ .../algorithms/test_fairfed_computer.py | 141 ++++++++++++++++++ 2 files changed, 215 insertions(+) create mode 100644 test/fairness/algorithms/test_fairfed_aggregator.py create mode 100644 test/fairness/algorithms/test_fairfed_computer.py diff --git a/test/fairness/algorithms/test_fairfed_aggregator.py b/test/fairness/algorithms/test_fairfed_aggregator.py new file mode 100644 index 0000000..b7ad02f --- /dev/null +++ b/test/fairness/algorithms/test_fairfed_aggregator.py @@ -0,0 +1,74 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for FairFed-specific Aggregator subclass.""" + +from unittest import mock + + +from declearn.fairness.fairfed import FairfedAggregator +from declearn.model.api import Vector + + +class TestFairfedAggregator: + """Unit tests for 'declearn.fairness.fairfed.FairfedAggregator'.""" + + def test_init_beta(self) -> None: + """Test that the 'beta' parameter is properly assigned.""" + beta = mock.create_autospec(float, instance=True) + aggregator = FairfedAggregator(beta=beta) + assert aggregator.beta is beta + + def test_prepare_for_sharing_initial(self) -> None: + """Test that 'prepare_for_sharing' has expected outputs at first.""" + # Set up an uninitialized aggregator and prepare mock updates. + aggregator = FairfedAggregator(beta=1.0) + updates = mock.create_autospec(Vector, instance=True) + model_updates = aggregator.prepare_for_sharing(updates, n_steps=10) + # Verify that outputs match expectations. + updates.__mul__.assert_called_once_with(1.0) + assert model_updates.updates is updates.__mul__.return_value + assert model_updates.weights == 1.0 + + def test_initialize_local_weight(self) -> None: + """Test that 'initialize_local_weight' works properly.""" + # Set up an aggregator, initialize it and prepare mock updates. + n_samples = 100 + aggregator = FairfedAggregator(beta=1.0) + aggregator.initialize_local_weight(n_samples=n_samples) + updates = mock.create_autospec(Vector, instance=True) + model_updates = aggregator.prepare_for_sharing(updates, n_steps=10) + # Verify that outputs match expectations. + updates.__mul__.assert_called_once_with(n_samples) + assert model_updates.updates is updates.__mul__.return_value + assert model_updates.weights == n_samples + + def test_update_local_weight(self) -> None: + """Test that 'update_local_weight' works properly.""" + # Set up a FairFed aggregator and initialize it. + n_samples = 100 + aggregator = FairfedAggregator(beta=0.1) + aggregator.initialize_local_weight(n_samples=n_samples) + # Perform a local wiehgt update with arbitrary values. + aggregator.update_local_weight(delta_loc=2.0, delta_avg=5.0) + # Verify that updates have expected weight. + updates = mock.create_autospec(Vector, instance=True) + expectw = n_samples - 0.1 * (2.0 - 5.0) # w_0 - beta * diff_delta + model_updates = aggregator.prepare_for_sharing(updates, n_steps=10) + updates.__mul__.assert_called_once_with(expectw) + assert model_updates.updates is updates.__mul__.return_value + assert model_updates.weights == expectw diff --git a/test/fairness/algorithms/test_fairfed_computer.py b/test/fairness/algorithms/test_fairfed_computer.py new file mode 100644 index 0000000..959d6f4 --- /dev/null +++ b/test/fairness/algorithms/test_fairfed_computer.py @@ -0,0 +1,141 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for FairFed-specific fairness value computer.""" + +import warnings +from typing import Any, List, Tuple + +import pytest + +from declearn.fairness.fairfed import FairfedValueComputer + + +GROUPS_BINARY = [ + (target, s_attr) for target in (0, 1) for s_attr in (0, 1) +] # type: List[Tuple[Any, ...]] +GROUPS_EXTEND = [ + (tgt, s_a, s_b) for tgt in (0, 1, 2) for s_a in (0, 1) for s_b in (1, 2) +] # type: List[Tuple[Any, ...]] +F_TYPES = [ + "accuracy_parity", + "demographic_parity", + "equality_of_opportunity", + "equalized_odds", +] + + +class TestFairfedValueComputer: + """Unit tests for 'declearn.fairness.fairfed.FairfedValueComputer'.""" + + @pytest.mark.parametrize("target", [1, 0], ids=["target1", "target0"]) + @pytest.mark.parametrize("f_type", F_TYPES) + def test_identify_key_groups_binary( + self, + f_type: str, + target: int, + ) -> None: + """Test 'identify_key_groups' with binary target and attribute.""" + computer = FairfedValueComputer(f_type, strict=True, target=target) + if f_type == "accuracy_parity": + with pytest.warns(RuntimeWarning): + key_groups = computer.identify_key_groups(GROUPS_BINARY.copy()) + else: + key_groups = computer.identify_key_groups(GROUPS_BINARY.copy()) + assert key_groups == ((target, 0), (target, 1)) + + @pytest.mark.parametrize("f_type", F_TYPES) + def test_identify_key_groups_extended_exception( + self, + f_type: str, + ) -> None: + """Test 'identify_key_groups' exception raising with extended groups. + + 'Extended' groups arise from a non-binary label intersected with + two distinct binary sensitive groups. + """ + computer = FairfedValueComputer(f_type, strict=True, target=1) + with pytest.raises(RuntimeError): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RuntimeWarning) + computer.identify_key_groups(GROUPS_EXTEND.copy()) + + @pytest.mark.parametrize("binary", [True, False], ids=["binary", "extend"]) + @pytest.mark.parametrize("strict", [True, False], ids=["strict", "free"]) + @pytest.mark.parametrize("f_type", F_TYPES[1:]) # avoid warning on AccPar + def test_initialize( + self, + f_type: str, + strict: bool, + binary: bool, + ) -> None: + """Test that 'initialize' raises an exception in expected cases.""" + computer = FairfedValueComputer(f_type, strict=strict, target=1) + groups = (GROUPS_BINARY if binary else GROUPS_EXTEND).copy() + if strict and not binary: + with pytest.raises(RuntimeError): + computer.initialize(groups) + else: + computer.initialize(groups) + + @pytest.mark.parametrize("strict", [True, False], ids=["strict", "free"]) + def test_compute_synthetic_fairness_value_binary( + self, + strict: bool, + ) -> None: + """Test 'compute_synthetic_fairness_value' with 4 groups. + + This test only applies to both strict and non-strict modes. + """ + # Compute a synthetic value using arbitrary inputs. + fairness = { + group: float(idx) for idx, group in enumerate(GROUPS_BINARY) + } + computer = FairfedValueComputer( + f_type="demographic_parity", + strict=strict, + target=1, + ) + computer.initialize(list(fairness)) + value = computer.compute_synthetic_fairness_value(fairness) + # Verify that the ouput value matches expectations. + if strict: + expected = fairness[(1, 0)] - fairness[(1, 1)] + else: + expected = sum(fairness.values()) / len(fairness) + assert value == expected + + def test_compute_synthetic_fairness_value_extended( + self, + ) -> None: + """Test 'compute_synthetic_fairness_value' with many groups. + + This test only applies to the non-strict mode. + """ + # Compute a synthetic value using arbitrary inputs. + fairness = { + group: float(idx) for idx, group in enumerate(GROUPS_EXTEND) + } + computer = FairfedValueComputer( + f_type="demographic_parity", + strict=False, + ) + computer.initialize(list(fairness)) + value = computer.compute_synthetic_fairness_value(fairness) + # Verify that the ouput value matches expectations. + expected = sum(fairness.values()) / len(fairness) + assert value == expected -- GitLab