From a87e891ea5805b65a87c730f7437bf4670558ac0 Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Mon, 15 Jul 2024 15:51:33 +0200 Subject: [PATCH] Improve coverage of 'fairness' unit tests. --- .../algorithms/test_fairfed_aggregator.py | 13 +++++ .../algorithms/test_fairfed_computer.py | 16 ++++++ .../controllers/test_fairbatch_controllers.py | 57 ++++++++++++++++++- .../controllers/test_fairfed_controllers.py | 30 ++++++++++ .../controllers/test_fairgrad_controllers.py | 54 +++++++++++++++++- 5 files changed, 165 insertions(+), 5 deletions(-) diff --git a/test/fairness/algorithms/test_fairfed_aggregator.py b/test/fairness/algorithms/test_fairfed_aggregator.py index b7ad02f..2238ad6 100644 --- a/test/fairness/algorithms/test_fairfed_aggregator.py +++ b/test/fairness/algorithms/test_fairfed_aggregator.py @@ -72,3 +72,16 @@ class TestFairfedAggregator: updates.__mul__.assert_called_once_with(expectw) assert model_updates.updates is updates.__mul__.return_value assert model_updates.weights == expectw + + def test_finalize_updates(self) -> None: + """Test that 'finalize_updates' works as expected.""" + # Set up a FairFed aggregator and initialize it. + n_samples = 100 + aggregator = FairfedAggregator(beta=0.1) + aggregator.initialize_local_weight(n_samples=n_samples) + # Prepare, then finalize updates. + updates = mock.create_autospec(Vector, instance=True) + output = aggregator.finalize_updates( + aggregator.prepare_for_sharing(updates, n_steps=mock.MagicMock()) + ) + assert output == (updates * n_samples / n_samples) diff --git a/test/fairness/algorithms/test_fairfed_computer.py b/test/fairness/algorithms/test_fairfed_computer.py index 959d6f4..b38e883 100644 --- a/test/fairness/algorithms/test_fairfed_computer.py +++ b/test/fairness/algorithms/test_fairfed_computer.py @@ -74,6 +74,22 @@ class TestFairfedValueComputer: warnings.simplefilter("ignore", RuntimeWarning) computer.identify_key_groups(GROUPS_EXTEND.copy()) + @pytest.mark.parametrize("f_type", F_TYPES) + def test_identify_key_groups_hybrid_exception( + self, + f_type: str, + ) -> None: + """Test 'identify_key_groups' exception raising with 'hybrid' groups. + + 'Hybrid' groups are groups that seemingly arise from a categorical + target that does not cross all sensitive attribute modalities. + """ + computer = FairfedValueComputer(f_type, strict=True, target=1) + with pytest.raises(KeyError): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RuntimeWarning) + computer.identify_key_groups([(0, 0), (0, 1), (1, 0), (2, 1)]) + @pytest.mark.parametrize("binary", [True, False], ids=["binary", "extend"]) @pytest.mark.parametrize("strict", [True, False], ids=["strict", "free"]) @pytest.mark.parametrize("f_type", F_TYPES[1:]) # avoid warning on AccPar diff --git a/test/fairness/controllers/test_fairbatch_controllers.py b/test/fairness/controllers/test_fairbatch_controllers.py index 5f55948..f6e5a67 100644 --- a/test/fairness/controllers/test_fairbatch_controllers.py +++ b/test/fairness/controllers/test_fairbatch_controllers.py @@ -17,6 +17,7 @@ """Unit tests for Fed-FairBatch controllers.""" +import asyncio import os from typing import List from unittest import mock @@ -24,6 +25,7 @@ from unittest import mock import pytest from declearn.aggregator import Aggregator, SumAggregator +from declearn.communication.utils import ErrorMessageException from declearn.fairness.api import ( FairnessControllerClient, FairnessControllerServer, @@ -34,7 +36,7 @@ from declearn.fairness.fairbatch import ( FairbatchDataset, FairbatchSamplingController, ) -from declearn.test_utils import make_importable +from declearn.test_utils import make_importable, setup_mock_network_endpoints with make_importable(os.path.dirname(os.path.abspath(__file__))): from fairness_controllers_testing import ( @@ -125,18 +127,20 @@ class TestFairbatchControllers(FairnessControllerTestSuite): with mock.patch( "declearn.fairness.fairbatch._server.setup_fairbatch_controller" ) as patch_setup_fairbatch: - FairbatchControllerServer( + controller = FairbatchControllerServer( f_type="demographic_parity", fedfb=False, ) + assert not controller.fedfb patch_setup_fairbatch.assert_called_once() with mock.patch( "declearn.fairness.fairbatch._server.setup_fedfb_controller" ) as patch_setup_fedfb: - FairbatchControllerServer( + controller = FairbatchControllerServer( f_type="demographic_parity", fedfb=True, ) + assert controller.fedfb patch_setup_fedfb.assert_called_once() def test_init_alpha_param(self) -> None: @@ -146,3 +150,50 @@ class TestFairbatchControllers(FairnessControllerTestSuite): f_type="demographic_parity", alpha=alpha ) assert server.sampling_controller.alpha is alpha + + @pytest.mark.asyncio + async def test_finalize_fairness_setup_error( + self, + ) -> None: + """Test that FairBatch probas update error-catching works properly.""" + n_peers = len(CLIENT_COUNTS) + # Instantiate the fairness controllers. + server = self.setup_server_controller() + clients = [ + self.setup_client_controller_from_server(server, idx) + for idx in range(n_peers) + ] + # Assign expected group definitions and counts. + server.groups = sorted(list(TOTAL_COUNTS)) + for client in clients: + client.groups = server.groups.copy() + counts = [TOTAL_COUNTS[group] for group in server.groups] + # Run setup coroutines, using mock network endpoints. + aggregator = mock.create_autospec(SumAggregator, instance=True) + async with setup_mock_network_endpoints(n_peers) as network: + coro_server = server.finalize_fairness_setup( + netwk=network[0], + secagg=None, + counts=counts, + aggregator=aggregator, + ) + coro_clients = [ + client.finalize_fairness_setup( + netwk=network[1][idx], + secagg=None, + ) + for idx, client in enumerate(clients) + ] + # Have the sampling probabilities' assignment fail. + with mock.patch.object( + FairbatchDataset, + "set_sampling_probabilities", + side_effect=Exception, + ) as patch_set_sampling_probabilities: + exc_server, *exc_clients = await asyncio.gather( + coro_server, *coro_clients, return_exceptions=True + ) + # Assert that expected exceptions were raised. + assert isinstance(exc_server, ErrorMessageException) + assert all(isinstance(exc, RuntimeError) for exc in exc_clients) + assert patch_set_sampling_probabilities.call_count == n_peers diff --git a/test/fairness/controllers/test_fairfed_controllers.py b/test/fairness/controllers/test_fairfed_controllers.py index c6e1513..157d879 100644 --- a/test/fairness/controllers/test_fairfed_controllers.py +++ b/test/fairness/controllers/test_fairfed_controllers.py @@ -156,3 +156,33 @@ class TestFairfedControllers(FairnessControllerTestSuite): assert server["fairfed_deltavg"] == ( sum(client["fairfed_delta"] for client in clients) / len(clients) ) + + @pytest.mark.parametrize( + "strict", [True, False], ids=["strict", "extended"] + ) + def test_init_params( + self, + strict: bool, + ) -> None: + """Test that instantiation parameters are properly passed.""" + rng = np.random.default_rng() + beta = abs(rng.normal()) + target = int(rng.choice(2)) + controller = FairfedControllerServer( + f_type="demographic_parity", + beta=beta, + strict=strict, + target=target, + ) + assert controller.beta == beta + assert controller.fairfed_computer.f_type == "demographic_parity" + assert controller.strict is strict + assert controller.fairfed_computer.strict is strict + assert controller.fairfed_computer.target is target + # Verify that parameters are transmitted to clients. + client = self.setup_client_controller_from_server(controller, idx=0) + assert isinstance(client, FairfedControllerClient) + assert client.beta == controller.beta + assert client.fairfed_computer.f_type == "demographic_parity" + assert client.strict is strict + assert client.fairfed_computer.strict is strict diff --git a/test/fairness/controllers/test_fairgrad_controllers.py b/test/fairness/controllers/test_fairgrad_controllers.py index 385122c..bd4af8c 100644 --- a/test/fairness/controllers/test_fairgrad_controllers.py +++ b/test/fairness/controllers/test_fairgrad_controllers.py @@ -17,6 +17,7 @@ """Unit tests for Fed-FairGrad controllers.""" +import asyncio import os from typing import List from unittest import mock @@ -24,6 +25,7 @@ from unittest import mock import pytest from declearn.aggregator import Aggregator, SumAggregator +from declearn.communication.utils import ErrorMessageException from declearn.fairness.api import ( FairnessDataset, FairnessControllerClient, @@ -34,10 +36,14 @@ from declearn.fairness.fairgrad import ( FairgradControllerServer, FairgradWeightsController, ) -from declearn.test_utils import make_importable +from declearn.test_utils import make_importable, setup_mock_network_endpoints with make_importable(os.path.dirname(os.path.abspath(__file__))): - from fairness_controllers_testing import FairnessControllerTestSuite + from fairness_controllers_testing import ( + CLIENT_COUNTS, + TOTAL_COUNTS, + FairnessControllerTestSuite, + ) class TestFairgradControllers(FairnessControllerTestSuite): @@ -103,3 +109,47 @@ class TestFairgradControllers(FairnessControllerTestSuite): self.verify_fairness_round_metrics(metrics) patch_update_weights.assert_called_once() self.verify_fairgrad_weights_coherence(server, clients) + + @pytest.mark.asyncio + async def test_finalize_fairness_setup_error( + self, + ) -> None: + """Test that FairGrad weights setup error-catching works properly.""" + n_peers = len(CLIENT_COUNTS) + # Instantiate the fairness controllers. + server = self.setup_server_controller() + clients = [ + self.setup_client_controller_from_server(server, idx) + for idx in range(n_peers) + ] + # Assign expected group definitions and counts. + # Have client datasets fail upon receiving sensitive group weights. + server.groups = sorted(list(TOTAL_COUNTS)) + for client in clients: + client.groups = server.groups.copy() + mock_dst = client.manager.train_data + assert isinstance(mock_dst, mock.NonCallableMagicMock) + mock_dst.set_sensitive_group_weights.side_effect = Exception + counts = [TOTAL_COUNTS[group] for group in server.groups] + # Run setup coroutines, using mock network endpoints. + aggregator = mock.create_autospec(SumAggregator, instance=True) + async with setup_mock_network_endpoints(n_peers) as network: + coro_server = server.finalize_fairness_setup( + netwk=network[0], + secagg=None, + counts=counts, + aggregator=aggregator, + ) + coro_clients = [ + client.finalize_fairness_setup( + netwk=network[1][idx], + secagg=None, + ) + for idx, client in enumerate(clients) + ] + exc_server, *exc_clients = await asyncio.gather( + coro_server, *coro_clients, return_exceptions=True + ) + # Assert that expected exceptions were raised. + assert isinstance(exc_server, ErrorMessageException) + assert all(isinstance(exc, RuntimeError) for exc in exc_clients) -- GitLab