Verified commit 16211c95, authored by ANDREY Paul

Fix or silence linter warnings on test code files.

Parent: e0b9b0cb
Merge request: !69 "Enable Fairness-Aware Federated Learning"
Pipeline: #1010732 (failed)
@@ -26,6 +26,7 @@ from declearn.fairness.api import FairnessFunction
 from declearn.fairness.fairgrad import FairgradWeightsController
 
+# pylint: disable=duplicate-code
 COUNTS = {(0, 0): 30, (0, 1): 15, (1, 0): 35, (1, 1): 20}
 F_TYPES = [
     "accuracy_parity",
@@ -33,6 +34,7 @@ F_TYPES = [
     "equality_of_opportunity",
     "equalized_odds",
 ]
+# pylint: enable=duplicate-code
 
 class TestFairgradWeightsController:
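
The duplicate-code warning (R0801) comes from pylint's similarities checker, which flags blocks of lines repeated across modules, as shared test fixtures often are; the disable/enable pair above scopes the suppression to just those lines, assuming a pylint version that honours inline control comments for this checker. A minimal, hypothetical illustration of the pattern:

    # These fixture constants are repeated verbatim in a sibling test
    # module, which would trigger R0801 (duplicate-code) across files.
    # pylint: disable=duplicate-code
    COUNTS = {(0, 0): 30, (0, 1): 15, (1, 0): 35, (1, 1): 20}
    F_TYPES = ["accuracy_parity", "equalized_odds"]
    # pylint: enable=duplicate-code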
......
@@ -49,6 +49,8 @@ with make_importable(os.path.dirname(os.path.abspath(__file__))):
 
 class TestFairbatchControllers(FairnessControllerTestSuite):
     """Unit tests for Fed-FairBatch / FedFB controllers."""
 
+    # similar code to FairGrad and parent code; pylint: disable=duplicate-code
+
     server_cls = FairbatchControllerServer
     client_cls = FairbatchControllerClient
@@ -169,13 +171,12 @@ class TestFairbatchControllers(FairnessControllerTestSuite):
         client.groups = server.groups.copy()
         counts = [TOTAL_COUNTS[group] for group in server.groups]
         # Run setup coroutines, using mock network endpoints.
-        aggregator = mock.create_autospec(SumAggregator, instance=True)
         async with setup_mock_network_endpoints(n_peers) as network:
             coro_server = server.finalize_fairness_setup(
                 netwk=network[0],
                 secagg=None,
                 counts=counts,
-                aggregator=aggregator,
+                aggregator=mock.create_autospec(SumAggregator, instance=True),
             )
             coro_clients = [
                 client.finalize_fairness_setup(
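
Inlining the autospec'd mock drops a single-use local without changing behavior: mock.create_autospec(SumAggregator, instance=True) still returns a mock instance whose attributes and call signatures are validated against the real class. A self-contained sketch of what autospeccing buys over a bare Mock (the SumAggregator below is an illustrative stand-in, not declearn's class):

    from unittest import mock

    class SumAggregator:  # illustrative stand-in for the real class
        def finalize_updates(self, updates):
            """Combine shared updates into a final value."""

    agg = mock.create_autospec(SumAggregator, instance=True)
    agg.finalize_updates("updates")  # fine: the method exists on the spec
    try:
        agg.finalise_updates("oops")  # typo: autospec raises at once
    except AttributeError as exc:
        print(exc)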
......
@@ -132,13 +132,12 @@ class TestFairgradControllers(FairnessControllerTestSuite):
         mock_dst.set_sensitive_group_weights.side_effect = Exception
         counts = [TOTAL_COUNTS[group] for group in server.groups]
         # Run setup coroutines, using mock network endpoints.
-        aggregator = mock.create_autospec(SumAggregator, instance=True)
         async with setup_mock_network_endpoints(n_peers) as network:
             coro_server = server.finalize_fairness_setup(
                 netwk=network[0],
                 secagg=None,
                 counts=counts,
-                aggregator=aggregator,
+                aggregator=mock.create_autospec(SumAggregator, instance=True),
             )
             coro_clients = [
                 client.finalize_fairness_setup(
......
@@ -72,9 +72,8 @@ class TestFairnessInMemoryDatasetInit:
         dst = FairnessInMemoryDataset(
             dataset, s_attr=s_attr, target="col_y", sensitive_target=True
         )
-        expected = pd.concat(
-            [pd.DataFrame(dataset["col_y"].rename("target")), s_attr],
-            axis=1,
+        expected = pd.DataFrame(
+            {"target": dataset["col_y"], "col_s": s_attr["col_s"]}
         ).apply(tuple, axis=1)
         assert isinstance(dst.sensitive, pd.Series)
         assert (dst.sensitive == expected).all()
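
The rewritten expectation builds the tuple-valued series directly: DataFrame.apply(tuple, axis=1) turns each row into a tuple, matching how the dataset encodes (target, sensitive-attribute) group membership. A small self-contained check with made-up data:

    import pandas as pd

    # Hypothetical columns mirroring the test's shape.
    frame = pd.DataFrame({"target": [0, 1, 1], "col_s": ["a", "b", "a"]})
    expected = frame.apply(tuple, axis=1)  # pandas Series of row tuples
    print(expected.tolist())  # [(0, 'a'), (1, 'b'), (1, 'a')]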
......
@@ -117,6 +117,7 @@ async def server_routine(
     n_clients: int = 3,
 ) -> None:
     """Run the FL routine of the server."""
+    # similar to SecAgg functional test; pylint: disable=duplicate-code
     model = SklearnSGDModel.from_parameters(
         kind="classifier",
         loss="log_loss",
@@ -221,9 +222,7 @@ async def test_toy_classif_fairness(
         coro_server, *coro_clients, return_exceptions=True
     )
     # Assert that no exceptions occurred during the process.
-    errors = "\n".join(
-        repr(exc) for exc in outputs if isinstance(exc, Exception)
-    )
+    errors = "\n".join(repr(e) for e in outputs if isinstance(e, Exception))
     assert not errors, f"The FL process failed:\n{errors}"
     # Load and parse utility and fairness metrics at the final round.
     u_metrics = pd.read_csv(os.path.join(tmp_path, "metrics.csv"))
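
The collapsed one-liner keeps the established pattern: asyncio.gather(..., return_exceptions=True) hands exceptions back as values instead of cancelling the remaining coroutines, so every peer runs to completion and all failures are reported together. A minimal sketch of that pattern outside the test harness:

    import asyncio

    async def ok() -> str:
        return "done"

    async def boom() -> str:
        raise RuntimeError("client failed")

    async def main() -> None:
        # Exceptions are returned as values, so both coroutines finish.
        outputs = await asyncio.gather(ok(), boom(), return_exceptions=True)
        errors = "\n".join(repr(e) for e in outputs if isinstance(e, Exception))
        if errors:
            print(f"Some coroutines failed:\n{errors}")

    asyncio.run(main())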
......
@@ -221,13 +221,11 @@ async def run_declearn_experiment(
         for i, (train, valid) in enumerate(datasets)
     ]
     # Run the coroutines concurrently using asyncio.
-    outputs = await asyncio.gather(
+    output = await asyncio.gather(
         coro_server, *coro_clients, return_exceptions=True
     )
     # Assert that no exceptions occurred during the process.
-    errors = "\n".join(
-        repr(exc) for exc in outputs if isinstance(exc, Exception)
-    )
+    errors = "\n".join(repr(e) for e in output if isinstance(e, Exception))
     assert not errors, f"The FL process failed:\n{errors}"
     # Assert that the experiment ran properly.
     with open(
......
@@ -149,7 +149,7 @@ class OptiModuleTestSuite(PluginTestBase):
     ) -> None:
         # For Noise-addition mechanisms, seed the (unsafe) RNG.
         if issubclass(cls, NoiseModule):
-            cls = functools.partial(
+            cls = functools.partial(  # type: ignore[misc]
                 cls, safe_mode=False, seed=0
             )  # type: ignore  # partial wraps the __init__ method
         # Run the unit test.
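
The added error code narrows the ignore: mypy rejects re-binding a variable annotated as a class (Type[...]) to a functools.partial object, and "# type: ignore[misc]" (the code the diff uses; the exact code can vary with the mypy version) silences only that diagnostic while leaving others on the line active. A reduced sketch of the pattern, with an illustrative stand-in class and signature:

    import functools
    from typing import Optional, Type

    class NoiseModule:  # illustrative stand-in
        def __init__(self, safe_mode: bool = True, seed: Optional[int] = None) -> None:
            self.safe_mode = safe_mode
            self.seed = seed

    def run_test(cls: Type[NoiseModule]) -> None:
        # Re-binding a Type[...] variable to a partial fails type-checking;
        # the bracketed code suppresses only that specific diagnostic.
        cls = functools.partial(  # type: ignore[misc]
            cls, safe_mode=False, seed=0
        )
        module = cls()
        print(module.safe_mode, module.seed)  # False 0

    run_test(NoiseModule)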
......