Mentions légales du service

Skip to content
Snippets Groups Projects
Verified Commit 16211c95 authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Fix or silence linter warnings on test code files.

parent e0b9b0cb
No related branches found
No related tags found
1 merge request!69Enable Fairness-Aware Federated Learning
Pipeline #1010732 failed
...@@ -26,6 +26,7 @@ from declearn.fairness.api import FairnessFunction ...@@ -26,6 +26,7 @@ from declearn.fairness.api import FairnessFunction
from declearn.fairness.fairgrad import FairgradWeightsController from declearn.fairness.fairgrad import FairgradWeightsController
# pylint: disable=duplicate-code
COUNTS = {(0, 0): 30, (0, 1): 15, (1, 0): 35, (1, 1): 20} COUNTS = {(0, 0): 30, (0, 1): 15, (1, 0): 35, (1, 1): 20}
F_TYPES = [ F_TYPES = [
"accuracy_parity", "accuracy_parity",
...@@ -33,6 +34,7 @@ F_TYPES = [ ...@@ -33,6 +34,7 @@ F_TYPES = [
"equality_of_opportunity", "equality_of_opportunity",
"equalized_odds", "equalized_odds",
] ]
# pylint: enable=duplicate-code
class TestFairgradWeightsController: class TestFairgradWeightsController:
......
...@@ -49,6 +49,8 @@ with make_importable(os.path.dirname(os.path.abspath(__file__))): ...@@ -49,6 +49,8 @@ with make_importable(os.path.dirname(os.path.abspath(__file__))):
class TestFairbatchControllers(FairnessControllerTestSuite): class TestFairbatchControllers(FairnessControllerTestSuite):
"""Unit tests for Fed-FairBatch / FedFB controllers.""" """Unit tests for Fed-FairBatch / FedFB controllers."""
# similar code to FairGrad and parent code; pylint: disable=duplicate-code
server_cls = FairbatchControllerServer server_cls = FairbatchControllerServer
client_cls = FairbatchControllerClient client_cls = FairbatchControllerClient
...@@ -169,13 +171,12 @@ class TestFairbatchControllers(FairnessControllerTestSuite): ...@@ -169,13 +171,12 @@ class TestFairbatchControllers(FairnessControllerTestSuite):
client.groups = server.groups.copy() client.groups = server.groups.copy()
counts = [TOTAL_COUNTS[group] for group in server.groups] counts = [TOTAL_COUNTS[group] for group in server.groups]
# Run setup coroutines, using mock network endpoints. # Run setup coroutines, using mock network endpoints.
aggregator = mock.create_autospec(SumAggregator, instance=True)
async with setup_mock_network_endpoints(n_peers) as network: async with setup_mock_network_endpoints(n_peers) as network:
coro_server = server.finalize_fairness_setup( coro_server = server.finalize_fairness_setup(
netwk=network[0], netwk=network[0],
secagg=None, secagg=None,
counts=counts, counts=counts,
aggregator=aggregator, aggregator=mock.create_autospec(SumAggregator, instance=True),
) )
coro_clients = [ coro_clients = [
client.finalize_fairness_setup( client.finalize_fairness_setup(
......
...@@ -132,13 +132,12 @@ class TestFairgradControllers(FairnessControllerTestSuite): ...@@ -132,13 +132,12 @@ class TestFairgradControllers(FairnessControllerTestSuite):
mock_dst.set_sensitive_group_weights.side_effect = Exception mock_dst.set_sensitive_group_weights.side_effect = Exception
counts = [TOTAL_COUNTS[group] for group in server.groups] counts = [TOTAL_COUNTS[group] for group in server.groups]
# Run setup coroutines, using mock network endpoints. # Run setup coroutines, using mock network endpoints.
aggregator = mock.create_autospec(SumAggregator, instance=True)
async with setup_mock_network_endpoints(n_peers) as network: async with setup_mock_network_endpoints(n_peers) as network:
coro_server = server.finalize_fairness_setup( coro_server = server.finalize_fairness_setup(
netwk=network[0], netwk=network[0],
secagg=None, secagg=None,
counts=counts, counts=counts,
aggregator=aggregator, aggregator=mock.create_autospec(SumAggregator, instance=True),
) )
coro_clients = [ coro_clients = [
client.finalize_fairness_setup( client.finalize_fairness_setup(
......
...@@ -72,9 +72,8 @@ class TestFairnessInMemoryDatasetInit: ...@@ -72,9 +72,8 @@ class TestFairnessInMemoryDatasetInit:
dst = FairnessInMemoryDataset( dst = FairnessInMemoryDataset(
dataset, s_attr=s_attr, target="col_y", sensitive_target=True dataset, s_attr=s_attr, target="col_y", sensitive_target=True
) )
expected = pd.concat( expected = pd.DataFrame(
[pd.DataFrame(dataset["col_y"].rename("target")), s_attr], {"target": dataset["col_y"], "col_s": s_attr["col_s"]}
axis=1,
).apply(tuple, axis=1) ).apply(tuple, axis=1)
assert isinstance(dst.sensitive, pd.Series) assert isinstance(dst.sensitive, pd.Series)
assert (dst.sensitive == expected).all() assert (dst.sensitive == expected).all()
......
...@@ -117,6 +117,7 @@ async def server_routine( ...@@ -117,6 +117,7 @@ async def server_routine(
n_clients: int = 3, n_clients: int = 3,
) -> None: ) -> None:
"""Run the FL routine of the server.""" """Run the FL routine of the server."""
# similar to SecAgg functional test; pylint: disable=duplicate-code
model = SklearnSGDModel.from_parameters( model = SklearnSGDModel.from_parameters(
kind="classifier", kind="classifier",
loss="log_loss", loss="log_loss",
...@@ -221,9 +222,7 @@ async def test_toy_classif_fairness( ...@@ -221,9 +222,7 @@ async def test_toy_classif_fairness(
coro_server, *coro_clients, return_exceptions=True coro_server, *coro_clients, return_exceptions=True
) )
# Assert that no exceptions occurred during the process. # Assert that no exceptions occurred during the process.
errors = "\n".join( errors = "\n".join(repr(e) for e in outputs if isinstance(e, Exception))
repr(exc) for exc in outputs if isinstance(exc, Exception)
)
assert not errors, f"The FL process failed:\n{errors}" assert not errors, f"The FL process failed:\n{errors}"
# Load and parse utility and fairness metrics at the final round. # Load and parse utility and fairness metrics at the final round.
u_metrics = pd.read_csv(os.path.join(tmp_path, "metrics.csv")) u_metrics = pd.read_csv(os.path.join(tmp_path, "metrics.csv"))
......
...@@ -221,13 +221,11 @@ async def run_declearn_experiment( ...@@ -221,13 +221,11 @@ async def run_declearn_experiment(
for i, (train, valid) in enumerate(datasets) for i, (train, valid) in enumerate(datasets)
] ]
# Run the coroutines concurrently using asyncio. # Run the coroutines concurrently using asyncio.
outputs = await asyncio.gather( output = await asyncio.gather(
coro_server, *coro_clients, return_exceptions=True coro_server, *coro_clients, return_exceptions=True
) )
# Assert that no exceptions occurred during the process. # Assert that no exceptions occurred during the process.
errors = "\n".join( errors = "\n".join(repr(e) for e in output if isinstance(e, Exception))
repr(exc) for exc in outputs if isinstance(exc, Exception)
)
assert not errors, f"The FL process failed:\n{errors}" assert not errors, f"The FL process failed:\n{errors}"
# Assert that the experiment ran properly. # Assert that the experiment ran properly.
with open( with open(
......
...@@ -149,7 +149,7 @@ class OptiModuleTestSuite(PluginTestBase): ...@@ -149,7 +149,7 @@ class OptiModuleTestSuite(PluginTestBase):
) -> None: ) -> None:
# For Noise-addition mechanisms, seed the (unsafe) RNG. # For Noise-addition mechanisms, seed the (unsafe) RNG.
if issubclass(cls, NoiseModule): if issubclass(cls, NoiseModule):
cls = functools.partial( cls = functools.partial( # type: ignore[misc]
cls, safe_mode=False, seed=0 cls, safe_mode=False, seed=0
) # type: ignore # partial wraps the __init__ method ) # type: ignore # partial wraps the __init__ method
# Run the unit test. # Run the unit test.
......
0% — Loading, or an error occurred.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment