Mentions légales du service

Skip to content
Snippets Groups Projects
Verified Commit a87e891e authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Improve coverage of 'fairness' unit tests.

parent 743998fb
No related branches found
No related tags found
1 merge request!69Enable Fairness-Aware Federated Learning
...@@ -72,3 +72,16 @@ class TestFairfedAggregator: ...@@ -72,3 +72,16 @@ class TestFairfedAggregator:
updates.__mul__.assert_called_once_with(expectw) updates.__mul__.assert_called_once_with(expectw)
assert model_updates.updates is updates.__mul__.return_value assert model_updates.updates is updates.__mul__.return_value
assert model_updates.weights == expectw assert model_updates.weights == expectw
def test_finalize_updates(self) -> None:
"""Test that 'finalize_updates' works as expected."""
# Set up a FairFed aggregator and initialize it.
n_samples = 100
aggregator = FairfedAggregator(beta=0.1)
aggregator.initialize_local_weight(n_samples=n_samples)
# Prepare, then finalize updates.
updates = mock.create_autospec(Vector, instance=True)
output = aggregator.finalize_updates(
aggregator.prepare_for_sharing(updates, n_steps=mock.MagicMock())
)
assert output == (updates * n_samples / n_samples)
...@@ -74,6 +74,22 @@ class TestFairfedValueComputer: ...@@ -74,6 +74,22 @@ class TestFairfedValueComputer:
warnings.simplefilter("ignore", RuntimeWarning) warnings.simplefilter("ignore", RuntimeWarning)
computer.identify_key_groups(GROUPS_EXTEND.copy()) computer.identify_key_groups(GROUPS_EXTEND.copy())
@pytest.mark.parametrize("f_type", F_TYPES)
def test_identify_key_groups_hybrid_exception(
self,
f_type: str,
) -> None:
"""Test 'identify_key_groups' exception raising with 'hybrid' groups.
'Hybrid' groups are groups that seemingly arise from a categorical
target that does not cross all sensitive attribute modalities.
"""
computer = FairfedValueComputer(f_type, strict=True, target=1)
with pytest.raises(KeyError):
with warnings.catch_warnings():
warnings.simplefilter("ignore", RuntimeWarning)
computer.identify_key_groups([(0, 0), (0, 1), (1, 0), (2, 1)])
@pytest.mark.parametrize("binary", [True, False], ids=["binary", "extend"]) @pytest.mark.parametrize("binary", [True, False], ids=["binary", "extend"])
@pytest.mark.parametrize("strict", [True, False], ids=["strict", "free"]) @pytest.mark.parametrize("strict", [True, False], ids=["strict", "free"])
@pytest.mark.parametrize("f_type", F_TYPES[1:]) # avoid warning on AccPar @pytest.mark.parametrize("f_type", F_TYPES[1:]) # avoid warning on AccPar
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
"""Unit tests for Fed-FairBatch controllers.""" """Unit tests for Fed-FairBatch controllers."""
import asyncio
import os import os
from typing import List from typing import List
from unittest import mock from unittest import mock
...@@ -24,6 +25,7 @@ from unittest import mock ...@@ -24,6 +25,7 @@ from unittest import mock
import pytest import pytest
from declearn.aggregator import Aggregator, SumAggregator from declearn.aggregator import Aggregator, SumAggregator
from declearn.communication.utils import ErrorMessageException
from declearn.fairness.api import ( from declearn.fairness.api import (
FairnessControllerClient, FairnessControllerClient,
FairnessControllerServer, FairnessControllerServer,
...@@ -34,7 +36,7 @@ from declearn.fairness.fairbatch import ( ...@@ -34,7 +36,7 @@ from declearn.fairness.fairbatch import (
FairbatchDataset, FairbatchDataset,
FairbatchSamplingController, FairbatchSamplingController,
) )
from declearn.test_utils import make_importable from declearn.test_utils import make_importable, setup_mock_network_endpoints
with make_importable(os.path.dirname(os.path.abspath(__file__))): with make_importable(os.path.dirname(os.path.abspath(__file__))):
from fairness_controllers_testing import ( from fairness_controllers_testing import (
...@@ -125,18 +127,20 @@ class TestFairbatchControllers(FairnessControllerTestSuite): ...@@ -125,18 +127,20 @@ class TestFairbatchControllers(FairnessControllerTestSuite):
with mock.patch( with mock.patch(
"declearn.fairness.fairbatch._server.setup_fairbatch_controller" "declearn.fairness.fairbatch._server.setup_fairbatch_controller"
) as patch_setup_fairbatch: ) as patch_setup_fairbatch:
FairbatchControllerServer( controller = FairbatchControllerServer(
f_type="demographic_parity", f_type="demographic_parity",
fedfb=False, fedfb=False,
) )
assert not controller.fedfb
patch_setup_fairbatch.assert_called_once() patch_setup_fairbatch.assert_called_once()
with mock.patch( with mock.patch(
"declearn.fairness.fairbatch._server.setup_fedfb_controller" "declearn.fairness.fairbatch._server.setup_fedfb_controller"
) as patch_setup_fedfb: ) as patch_setup_fedfb:
FairbatchControllerServer( controller = FairbatchControllerServer(
f_type="demographic_parity", f_type="demographic_parity",
fedfb=True, fedfb=True,
) )
assert controller.fedfb
patch_setup_fedfb.assert_called_once() patch_setup_fedfb.assert_called_once()
def test_init_alpha_param(self) -> None: def test_init_alpha_param(self) -> None:
...@@ -146,3 +150,50 @@ class TestFairbatchControllers(FairnessControllerTestSuite): ...@@ -146,3 +150,50 @@ class TestFairbatchControllers(FairnessControllerTestSuite):
f_type="demographic_parity", alpha=alpha f_type="demographic_parity", alpha=alpha
) )
assert server.sampling_controller.alpha is alpha assert server.sampling_controller.alpha is alpha
@pytest.mark.asyncio
async def test_finalize_fairness_setup_error(
self,
) -> None:
"""Test that FairBatch probas update error-catching works properly."""
n_peers = len(CLIENT_COUNTS)
# Instantiate the fairness controllers.
server = self.setup_server_controller()
clients = [
self.setup_client_controller_from_server(server, idx)
for idx in range(n_peers)
]
# Assign expected group definitions and counts.
server.groups = sorted(list(TOTAL_COUNTS))
for client in clients:
client.groups = server.groups.copy()
counts = [TOTAL_COUNTS[group] for group in server.groups]
# Run setup coroutines, using mock network endpoints.
aggregator = mock.create_autospec(SumAggregator, instance=True)
async with setup_mock_network_endpoints(n_peers) as network:
coro_server = server.finalize_fairness_setup(
netwk=network[0],
secagg=None,
counts=counts,
aggregator=aggregator,
)
coro_clients = [
client.finalize_fairness_setup(
netwk=network[1][idx],
secagg=None,
)
for idx, client in enumerate(clients)
]
# Have the sampling probabilities' assignment fail.
with mock.patch.object(
FairbatchDataset,
"set_sampling_probabilities",
side_effect=Exception,
) as patch_set_sampling_probabilities:
exc_server, *exc_clients = await asyncio.gather(
coro_server, *coro_clients, return_exceptions=True
)
# Assert that expected exceptions were raised.
assert isinstance(exc_server, ErrorMessageException)
assert all(isinstance(exc, RuntimeError) for exc in exc_clients)
assert patch_set_sampling_probabilities.call_count == n_peers
...@@ -156,3 +156,33 @@ class TestFairfedControllers(FairnessControllerTestSuite): ...@@ -156,3 +156,33 @@ class TestFairfedControllers(FairnessControllerTestSuite):
assert server["fairfed_deltavg"] == ( assert server["fairfed_deltavg"] == (
sum(client["fairfed_delta"] for client in clients) / len(clients) sum(client["fairfed_delta"] for client in clients) / len(clients)
) )
@pytest.mark.parametrize(
"strict", [True, False], ids=["strict", "extended"]
)
def test_init_params(
self,
strict: bool,
) -> None:
"""Test that instantiation parameters are properly passed."""
rng = np.random.default_rng()
beta = abs(rng.normal())
target = int(rng.choice(2))
controller = FairfedControllerServer(
f_type="demographic_parity",
beta=beta,
strict=strict,
target=target,
)
assert controller.beta == beta
assert controller.fairfed_computer.f_type == "demographic_parity"
assert controller.strict is strict
assert controller.fairfed_computer.strict is strict
assert controller.fairfed_computer.target is target
# Verify that parameters are transmitted to clients.
client = self.setup_client_controller_from_server(controller, idx=0)
assert isinstance(client, FairfedControllerClient)
assert client.beta == controller.beta
assert client.fairfed_computer.f_type == "demographic_parity"
assert client.strict is strict
assert client.fairfed_computer.strict is strict
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
"""Unit tests for Fed-FairGrad controllers.""" """Unit tests for Fed-FairGrad controllers."""
import asyncio
import os import os
from typing import List from typing import List
from unittest import mock from unittest import mock
...@@ -24,6 +25,7 @@ from unittest import mock ...@@ -24,6 +25,7 @@ from unittest import mock
import pytest import pytest
from declearn.aggregator import Aggregator, SumAggregator from declearn.aggregator import Aggregator, SumAggregator
from declearn.communication.utils import ErrorMessageException
from declearn.fairness.api import ( from declearn.fairness.api import (
FairnessDataset, FairnessDataset,
FairnessControllerClient, FairnessControllerClient,
...@@ -34,10 +36,14 @@ from declearn.fairness.fairgrad import ( ...@@ -34,10 +36,14 @@ from declearn.fairness.fairgrad import (
FairgradControllerServer, FairgradControllerServer,
FairgradWeightsController, FairgradWeightsController,
) )
from declearn.test_utils import make_importable from declearn.test_utils import make_importable, setup_mock_network_endpoints
with make_importable(os.path.dirname(os.path.abspath(__file__))): with make_importable(os.path.dirname(os.path.abspath(__file__))):
from fairness_controllers_testing import FairnessControllerTestSuite from fairness_controllers_testing import (
CLIENT_COUNTS,
TOTAL_COUNTS,
FairnessControllerTestSuite,
)
class TestFairgradControllers(FairnessControllerTestSuite): class TestFairgradControllers(FairnessControllerTestSuite):
...@@ -103,3 +109,47 @@ class TestFairgradControllers(FairnessControllerTestSuite): ...@@ -103,3 +109,47 @@ class TestFairgradControllers(FairnessControllerTestSuite):
self.verify_fairness_round_metrics(metrics) self.verify_fairness_round_metrics(metrics)
patch_update_weights.assert_called_once() patch_update_weights.assert_called_once()
self.verify_fairgrad_weights_coherence(server, clients) self.verify_fairgrad_weights_coherence(server, clients)
@pytest.mark.asyncio
async def test_finalize_fairness_setup_error(
self,
) -> None:
"""Test that FairGrad weights setup error-catching works properly."""
n_peers = len(CLIENT_COUNTS)
# Instantiate the fairness controllers.
server = self.setup_server_controller()
clients = [
self.setup_client_controller_from_server(server, idx)
for idx in range(n_peers)
]
# Assign expected group definitions and counts.
# Have client datasets fail upon receiving sensitive group weights.
server.groups = sorted(list(TOTAL_COUNTS))
for client in clients:
client.groups = server.groups.copy()
mock_dst = client.manager.train_data
assert isinstance(mock_dst, mock.NonCallableMagicMock)
mock_dst.set_sensitive_group_weights.side_effect = Exception
counts = [TOTAL_COUNTS[group] for group in server.groups]
# Run setup coroutines, using mock network endpoints.
aggregator = mock.create_autospec(SumAggregator, instance=True)
async with setup_mock_network_endpoints(n_peers) as network:
coro_server = server.finalize_fairness_setup(
netwk=network[0],
secagg=None,
counts=counts,
aggregator=aggregator,
)
coro_clients = [
client.finalize_fairness_setup(
netwk=network[1][idx],
secagg=None,
)
for idx, client in enumerate(clients)
]
exc_server, *exc_clients = await asyncio.gather(
coro_server, *coro_clients, return_exceptions=True
)
# Assert that expected exceptions were raised.
assert isinstance(exc_server, ErrorMessageException)
assert all(isinstance(exc, RuntimeError) for exc in exc_clients)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment