From b641e27a016e9c3b2b5b2db09a220aa35d6560af Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Fri, 5 Jul 2024 16:17:54 +0200
Subject: [PATCH] Write shared unit tests for fairness controllers.

---
 .../fairness_controllers_testing.py          | 407 ++++++++++++++++++
 .../controllers/test_fairgrad_controllers.py |  72 ++++
 2 files changed, 479 insertions(+)
 create mode 100644 test/fairness/controllers/fairness_controllers_testing.py
 create mode 100644 test/fairness/controllers/test_fairgrad_controllers.py

diff --git a/test/fairness/controllers/fairness_controllers_testing.py b/test/fairness/controllers/fairness_controllers_testing.py
new file mode 100644
index 0000000..9255612
--- /dev/null
+++ b/test/fairness/controllers/fairness_controllers_testing.py
@@ -0,0 +1,407 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Shared unit tests for Fairness controllers."""
+
+import asyncio
+import logging
+from unittest import mock
+from typing import List, Optional, Tuple, Type, Union
+
+import numpy as np
+import pytest
+
+from declearn.aggregator import Aggregator
+from declearn.communication.api import NetworkServer
+from declearn.communication.utils import verify_server_message_validity
+from declearn.fairness.api import (
+    FairnessControllerClient,
+    FairnessControllerServer,
+    FairnessDataset,
+)
+from declearn.messaging import (
+    FairnessQuery,
+    FairnessReply,
+    FairnessSetupQuery,
+    SerializedMessage,
+)
+from declearn.metrics import MeanMetric
+from declearn.model.api import Model
+from declearn.secagg.api import Decrypter, Encrypter
+from declearn.secagg.messaging import SecaggFairnessReply
+from declearn.test_utils import (
+    assert_dict_equal,
+    build_secagg_controllers,
+    setup_mock_network_endpoints,
+)
+from declearn.training import TrainingManager
+
+
+# Define arbitrary group definitions and sample counts.
+CLIENT_COUNTS = [
+    {(0, 0): 10, (0, 1): 10, (1, 0): 10, (1, 1): 10},
+    {(0, 0): 10, (1, 0): 15, (1, 1): 10},
+    {(0, 0): 10, (0, 1): 5, (1, 0): 10},
+]
+TOTAL_COUNTS = {(0, 0): 30, (0, 1): 15, (1, 0): 35, (1, 1): 20}
+
+
+def build_mock_dataset(idx: int) -> mock.Mock:
+    """Return a mock FairnessDataset with deterministic group counts."""
+    counts = CLIENT_COUNTS[idx]
+    dataset = mock.create_autospec(FairnessDataset, instance=True)
+    dataset.get_sensitive_group_definitions.return_value = list(counts)
+    dataset.get_sensitive_group_counts.return_value = counts
+    return dataset
+
+
+class FairnessControllerTestSuite:
+    """Shared test suite for Fairness controllers."""
+
+    # Types of controllers associated with a given test suite subclass.
+    server_cls: Type[FairnessControllerServer]
+    client_cls: Type[FairnessControllerClient]
+
+    # Default expected local computed metrics. May be overridden by subclasses.
+    mock_client_metrics = [
+        {"accuracy": {group: 1.0 for group in CLIENT_COUNTS[idx]}}
+        for idx in range(len(CLIENT_COUNTS))
+    ]
+
+    def setup_server_controller(self) -> FairnessControllerServer:
+        """Instantiate and return a server-side fairness controller."""
+        return self.server_cls(f_type="accuracy_parity")
+
+    def test_setup_client_from_setup_query(
+        self,
+    ) -> None:
+        """Test that the server's setup query results in a proper client."""
+        server = self.setup_server_controller()
+        query = server.prepare_fairness_setup_query()
+        assert isinstance(query, FairnessSetupQuery)
+        manager = mock.create_autospec(TrainingManager, instance=True)
+        manager.train_data = (
+            mock.create_autospec(FairnessDataset, instance=True)
+        )
+        client = FairnessControllerClient.from_setup_query(query, manager)
+        assert isinstance(client, self.client_cls)
+        assert client.manager is manager
+        assert client.fairness_function.f_type == server.f_type
+
+    def setup_client_controller_from_server(
+        self,
+        server: FairnessControllerServer,
+        idx: int,
+    ) -> FairnessControllerClient:
+        """Instantiate and return a client-side fairness controller."""
+        manager = self.setup_mock_training_manager(idx)
+        query = server.prepare_fairness_setup_query()
+        return FairnessControllerClient.from_setup_query(query, manager)
+
+    def setup_mock_training_manager(
+        self,
+        idx: int,
+    ) -> mock.MagicMock:
+        """Set up and return a mock TrainingManager for a given client."""
+        manager = mock.create_autospec(TrainingManager, instance=True)
+        manager.aggrg = mock.create_autospec(Aggregator, instance=True)
+        manager.logger = mock.create_autospec(logging.Logger, instance=True)
+        manager.model = mock.create_autospec(Model, instance=True)
+        manager.train_data = build_mock_dataset(idx)
+        return manager
+
+    def setup_fairness_controllers_and_secagg(
+        self,
+        n_peers: int,
+        use_secagg: bool,
+    ) -> Tuple[
+        FairnessControllerServer,
+        List[FairnessControllerClient],
+        Optional[Decrypter],
+        Union[List[Encrypter], List[None]],
+    ]:
+        """Instantiate fairness and (optional) secagg controllers."""
+        # Instantiate the server and client controllers.
+        server = self.setup_server_controller()
+        clients = [
+            self.setup_client_controller_from_server(server, idx)
+            for idx in range(n_peers)
+        ]
+        # Optionally set up SecAgg controllers, then return.
+        if use_secagg:
+            decrypter, encrypters = build_secagg_controllers(n_peers)
+            return server, clients, decrypter, encrypters  # type: ignore
+        return server, clients, None, [None for _ in range(n_peers)]
+
+    @pytest.mark.parametrize(
+        "use_secagg", [False, True], ids=["clrtxt", "secagg"]
+    )
+    @pytest.mark.asyncio
+    async def test_exchange_sensitive_groups_list_and_counts(
+        self,
+        use_secagg: bool,
+    ) -> None:
+        """Test the exchange of sensitive group definitions and counts."""
+        n_peers = len(CLIENT_COUNTS)
+        # Instantiate the fairness and optional secagg controllers.
+        server, clients, decrypter, encrypters = (
+            self.setup_fairness_controllers_and_secagg(n_peers, use_secagg)
+        )
+        # Run setup coroutines, using mock network endpoints.
+        async with setup_mock_network_endpoints(n_peers) as network:
+            coro_server = server.exchange_sensitive_groups_list_and_counts(
+                netwk=network[0], secagg=decrypter
+            )
+            coro_clients = [
+                client.exchange_sensitive_groups_list_and_counts(
+                    netwk=network[1][idx], secagg=encrypters[idx]
+                )
+                for idx, client in enumerate(clients)
+            ]
+            counts, *_ = await asyncio.gather(coro_server, *coro_clients)
+        # Verify that the expected attributes were assigned the expected values.
+        assert isinstance(counts, list) and len(counts) == len(TOTAL_COUNTS)
+        assert dict(zip(server.groups, counts)) == TOTAL_COUNTS
+        assert all(client.groups == server.groups for client in clients)
+
+    @pytest.mark.parametrize(
+        "use_secagg", [False, True], ids=["clrtxt", "secagg"]
+    )
+    @pytest.mark.asyncio
+    async def test_finalize_fairness_setup(
+        self,
+        use_secagg: bool,
+    ) -> None:
+        """Test that 'finalize_fairness_setup' works properly.
+
+        This test should be overridden by subclasses to perform
+        algorithm-specific verification (and warnings-catching).
+        """
+        aggregator = mock.create_autospec(Aggregator, instance=True)
+        agg_final, *_ = await self.run_finalize_fairness_setup(
+            aggregator, use_secagg
+        )
+        # Verify that the server returns an Aggregator.
+        assert isinstance(agg_final, Aggregator)
+
+    async def run_finalize_fairness_setup(
+        self,
+        aggregator: Aggregator,
+        use_secagg: bool,
+    ) -> Tuple[
+        Aggregator, FairnessControllerServer, List[FairnessControllerClient]
+    ]:
+        """Run 'finalize_fairness_setup' and return controllers."""
+        n_peers = len(CLIENT_COUNTS)
+        # Instantiate the fairness and optional secagg controllers.
+        server, clients, decrypter, encrypters = (
+            self.setup_fairness_controllers_and_secagg(n_peers, use_secagg)
+        )
+        # Assign expected group definitions and counts.
+        server.groups = sorted(list(TOTAL_COUNTS))
+        for client in clients:
+            client.groups = server.groups.copy()
+        counts = [TOTAL_COUNTS[group] for group in server.groups]
+        # Run setup coroutines, using mock network endpoints.
+        async with setup_mock_network_endpoints(n_peers) as network:
+            coro_server = server.finalize_fairness_setup(
+                netwk=network[0],
+                secagg=decrypter,
+                counts=counts,
+                aggregator=aggregator,
+            )
+            coro_clients = [
+                client.finalize_fairness_setup(
+                    netwk=network[1][idx],
+                    secagg=encrypters[idx],
+                )
+                for idx, client in enumerate(clients)
+            ]
+            agg_final, *_ = await asyncio.gather(coro_server, *coro_clients)
+        # Return the resulting aggregator and controllers.
+        return agg_final, server, clients
+
+    def test_setup_fairness_metrics(
+        self,
+    ) -> None:
+        """Test that 'setup_fairness_metrics' has proper output type."""
+        server = self.setup_server_controller()
+        client = self.setup_client_controller_from_server(server, idx=0)
+        metrics = client.setup_fairness_metrics()
+        assert isinstance(metrics, list)
+        assert all(isinstance(metric, MeanMetric) for metric in metrics)
+
+    @pytest.mark.parametrize("idx", list(range(len(CLIENT_COUNTS))))
+    def test_compute_fairness_metrics(
+        self,
+        idx: int,
+    ) -> None:
+        """Test that metrics computation works for a given client."""
+        server = self.setup_server_controller()
+        client = self.setup_client_controller_from_server(server, idx)
+        client.groups = list(TOTAL_COUNTS)
+        # Run mock computations.
+        with mock.patch.object(
+            client.computer, "compute_groupwise_metrics"
+        ) as patch_compute:
+            patch_compute.return_value = self.mock_client_metrics[idx].copy()
+            share_values, local_values = client.compute_fairness_measures(32)
+        # Verify that expected shareable values were output.
+        patch_compute.assert_called_once()
+        assert isinstance(share_values, list)
+        expected_share = [
+            group_values.get(group, 0.0) * CLIENT_COUNTS[idx].get(group, 0.0)
+            for group_values in self.mock_client_metrics[idx].values()
+            for group in client.groups
+        ]
+        assert share_values == expected_share
+        # Verify that expected local values were output.
+        assert isinstance(local_values, dict)
+        expected_local = self.mock_client_metrics[idx].copy()
+        if "accuracy" in expected_local:
+            expected_local[client.fairness_function.f_type] = (
+                client.fairness_function.compute_from_group_accuracy(
+                    expected_local["accuracy"]
+                )
+            )
+        assert_dict_equal(local_values, expected_local)
+
+    @pytest.mark.parametrize(
+        "use_secagg", [False, True], ids=["clrtxt", "secagg"]
+    )
+    @pytest.mark.asyncio
+    async def test_receive_and_aggregate_fairness_metrics(
+        self,
+        use_secagg: bool,
+    ) -> None:
+        """Test that server-side aggregation of metrics works properly."""
+        # Set up a server controller and optionally some secagg controllers.
+        n_peers = len(CLIENT_COUNTS)
+        server = self.setup_server_controller()
+        server.groups = list(TOTAL_COUNTS)
+        decrypter, encrypters = (
+            build_secagg_controllers(n_peers) if use_secagg else (None, None)
+        )
+        # Set up a mock network endpoint receiving local metrics.
+        netwk = mock.create_autospec(NetworkServer, instance=True)
+        replies = {
+            f"client_{idx}": FairnessReply(
+                [
+                    group_values.get(group, 0.0)
+                    * CLIENT_COUNTS[idx].get(group, 0.0)
+                    for group_values in self.mock_client_metrics[idx].values()
+                    for group in list(TOTAL_COUNTS)
+                ]
+            )
+            for idx in range(len(self.mock_client_metrics))
+        }
+        if encrypters:
+            secagg_replies = {
+                key: SecaggFairnessReply.from_cleartext_message(
+                    cleartext=val, encrypter=encrypters[idx]
+                )
+                for idx, (key, val) in enumerate(replies.items())
+            }
+            netwk.wait_for_messages.return_value = {
+                key: SerializedMessage.from_message_string(val.to_string())
+                for key, val in secagg_replies.items()
+            }
+        else:
+            netwk.wait_for_messages.return_value = {
+                key: SerializedMessage.from_message_string(val.to_string())
+                for key, val in replies.items()
+            }
+        # Run the reception and (secure-)aggregation of these replies.
+        aggregated = await server.receive_and_aggregate_fairness_measures(
+            netwk=netwk, secagg=decrypter
+        )
+        # Verify that outputs match expectations.
+        assert isinstance(aggregated, list)
+        expected = [
+            sum(rv) for rv in zip(*[rep.values for rep in replies.values()])
+        ]
+        assert np.allclose(np.array(aggregated), np.array(expected))
+
+    @pytest.mark.parametrize(
+        "use_secagg", [False, True], ids=["clrtxt", "secagg"]
+    )
+    @pytest.mark.asyncio
+    async def test_fairness_end2end(
+        self,
+        use_secagg: bool,
+    ) -> None:
+        """Test that running both fairness setup and round routines works.
+
+        This end-to-end test verifies that running all unit-tested components
+        together does not raise exceptions. Details about individual
+        operations are left to the dedicated unit tests.
+        """
+        # Instantiate the fairness and optional secagg controllers.
+        n_peers = len(CLIENT_COUNTS)
+        decrypter = None  # type: Optional[Decrypter]
+        encrypters = [None] * n_peers  # type: List[Optional[Encrypter]]
+        if use_secagg:
+            decrypter, encrypters = (
+                build_secagg_controllers(n_peers)  # type: ignore
+            )
+        # Run end-to-end routines using mock communication endpoints.
+        async with setup_mock_network_endpoints(n_peers=n_peers) as netwk:
+
+            async def server_routine() -> None:
+                """Server-side fairness setup and round routine."""
+                nonlocal decrypter, netwk
+                server = self.setup_server_controller()
+                with pytest.warns() as warnings_record:
+                    await server.setup_fairness(
+                        netwk=netwk[0],
+                        aggregator=mock.create_autospec(
+                            Aggregator, instance=True
+                        ),
+                        secagg=decrypter,
+                    )
+                assert len(warnings_record) <= 1
+                await netwk[0].broadcast_message(FairnessQuery(round_i=0))
+                await server.run_fairness_round(
+                    netwk=netwk[0],
+                    secagg=decrypter,
+                )
+
+            async def client_routine(idx: int) -> None:
+                """Client-side fairness setup and round routine."""
+                nonlocal encrypters, netwk
+                # Instantiate the client-side controller.
+                received = await netwk[1][idx].recv_message()
+                setup_query = await verify_server_message_validity(
+                    netwk[1][idx], received, FairnessSetupQuery
+                )
+                client = FairnessControllerClient.from_setup_query(
+                    setup_query, manager=self.setup_mock_training_manager(idx)
+                )
+                # Run the fairness setup routine.
+                await client.setup_fairness(netwk[1][idx], encrypters[idx])
+                # Run the fairness round routine.
+                received = await netwk[1][idx].recv_message()
+                round_query = await verify_server_message_validity(
+                    netwk[1][idx], received, FairnessQuery
+                )
+                await client.run_fairness_round(
+                    netwk[1][idx], round_query, encrypters[idx]
+                )
+
+            await asyncio.gather(
+                server_routine(),
+                *[client_routine(idx) for idx in range(n_peers)],
+            )
diff --git a/test/fairness/controllers/test_fairgrad_controllers.py b/test/fairness/controllers/test_fairgrad_controllers.py
new file mode 100644
index 0000000..417b7d5
--- /dev/null
+++ b/test/fairness/controllers/test_fairgrad_controllers.py
@@ -0,0 +1,72 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for Fed-FairGrad controllers."""
+
+import os
+from unittest import mock
+
+import pytest
+
+from declearn.aggregator import Aggregator, SumAggregator
+from declearn.fairness.api import FairnessDataset
+from declearn.fairness.fairgrad import (
+    FairgradControllerClient,
+    FairgradControllerServer,
+)
+from declearn.test_utils import make_importable
+
+with make_importable(os.path.dirname(os.path.abspath(__file__))):
+    from fairness_controllers_testing import FairnessControllerTestSuite
+
+
+class TestFairgradControllers(FairnessControllerTestSuite):
+    """Unit tests for Fed-FairGrad controllers."""
+
+    server_cls = FairgradControllerServer
+    client_cls = FairgradControllerClient
+
+    @pytest.mark.parametrize(
+        "use_secagg", [False, True], ids=["clrtxt", "secagg"]
+    )
+    @pytest.mark.asyncio
+    async def test_finalize_fairness_setup(
+        self,
+        use_secagg: bool,
+    ) -> None:
+        aggregator = mock.create_autospec(Aggregator, instance=True)
+        with pytest.warns(RuntimeWarning, match="SumAggregator"):
+            agg_final, server, clients = (
+                await self.run_finalize_fairness_setup(aggregator, use_secagg)
+            )
+        # Verify that aggregators were replaced with a SumAggregator.
+        assert isinstance(agg_final, SumAggregator)
+        assert all(
+            isinstance(client.manager.aggrg, SumAggregator)
+            for client in clients
+        )
+        # Verify that FairGrad weights were shared and applied.
+        assert isinstance(server, FairgradControllerServer)
+        weights = server.weights_controller.get_current_weights(norm_nk=True)
+        expectw = dict(zip(server.groups, weights))
+        for client in clients:
+            mock_dst = client.manager.train_data
+            assert isinstance(mock_dst, FairnessDataset)
+            assert isinstance(mock_dst, mock.NonCallableMagicMock)
+            mock_dst.set_sensitive_group_weights.assert_called_once_with(
+                weights=expectw, adjust_by_counts=True
+            )
--
GitLab
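
Note on reuse: the FairGrad module above shows the intended pattern for these shared tests: subclass FairnessControllerTestSuite, point server_cls / client_cls at the algorithm's controller types, and override only what the algorithm changes (e.g. mock_client_metrics or test_finalize_fairness_setup). The sketch below illustrates that pattern for some other fairness algorithm; 'declearn.fairness.myalgo' and its controller classes are hypothetical placeholders, not actual declearn modules, and this file is not part of the patch.

# test_myalgo_controllers.py -- hypothetical sketch with placeholder names.
import os

from declearn.fairness.myalgo import (  # hypothetical algorithm package
    MyAlgoControllerClient,
    MyAlgoControllerServer,
)
from declearn.test_utils import make_importable

with make_importable(os.path.dirname(os.path.abspath(__file__))):
    from fairness_controllers_testing import FairnessControllerTestSuite


class TestMyAlgoControllers(FairnessControllerTestSuite):
    """Unit tests for the hypothetical MyAlgo fairness controllers."""

    server_cls = MyAlgoControllerServer
    client_cls = MyAlgoControllerClient

    # Override 'mock_client_metrics' here if the algorithm is expected to
    # report metrics other than plain group-wise accuracy, and override
    # 'test_finalize_fairness_setup' if its setup emits warnings or swaps
    # the aggregator, as the FairGrad subclass does.

Such a module is then collected like any other, e.g. via 'pytest test/fairness/controllers/'.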