diff --git a/declearn/main/_client.py b/declearn/main/_client.py index 1c09dc01777bafcb4e3be69ecc0d5b00ffc594e2..506046e634945a75160649fdd7d9e433cb59a5f5 100644 --- a/declearn/main/_client.py +++ b/declearn/main/_client.py @@ -670,7 +670,7 @@ class FederatedClient: self.ckptr.save_metrics( metrics=metrics, prefix="fairness_metrics", - append=(query.round_i > 1), + append=bool(query.round_i), timestamp=f"round_{query.round_i}", ) diff --git a/declearn/main/_server.py b/declearn/main/_server.py index 9fc16c6863b1324b3de5a056f00ec63487abfd1e..cb4d9e86f46dabb7409ce87b03b08ce5c733a4af 100644 --- a/declearn/main/_server.py +++ b/declearn/main/_server.py @@ -284,13 +284,17 @@ class FederatedServer: # Iteratively run training and evaluation rounds. round_i = 0 while True: + # Run (opt.) fairness; training; evaluation. + await self.fairness_round(round_i, config.fairness) round_i += 1 - if self.fairness is not None: - await self.fairness_round(round_i, config.fairness) await self.training_round(round_i, config.training) await self.evaluation_round(round_i, config.evaluate) + # Decide whether to keep training for at least one round. if not self._keep_training(round_i, config.rounds, early_stop): break + # When checkpointing, evaluate the last model's fairness. + if self.ckptr is not None: + await self.fairness_round(round_i, config.fairness) # Interrupt training when time comes. self.logger.info("Stopping training.") await self.stop_training(round_i) @@ -544,13 +548,14 @@ class FederatedServer: Parameters ---------- round_i: - Index of the training round. + Index of the latest training round (start at 0). fairness_cfg: FairnessConfig dataclass instance wrapping data-batching and computational effort constraints hyper-parameters for fairness evaluation. """ - assert self.fairness is not None + if self.fairness is None: + return # Run SecAgg setup when needed. 
self.logger.info("Initiating fairness-enforcing round %s", round_i) clients = self.netwk.client_names # FUTURE: enable sampling(?) @@ -575,7 +580,7 @@ class FederatedServer: self.ckptr.save_metrics( metrics=metrics, prefix="fairness_metrics", - append=(query.round_i > 1), + append=bool(query.round_i), timestamp=f"round_{query.round_i}", ) diff --git a/test/functional/test_toy_clf_fairness.py b/test/functional/test_toy_clf_fairness.py index 60f454087dc051d1336f3f9ed4ad00b686bd050c..41df97a8f2275f090099720de905c7e8b74cbe9d 100644 --- a/test/functional/test_toy_clf_fairness.py +++ b/test/functional/test_toy_clf_fairness.py @@ -207,10 +207,12 @@ async def test_toy_classif_fairness( Set up a toy dataset for fairness-aware federated learning. Use a given algorithm, with a given group-fairness definition. - Optionally use SecAgg. + Run training for 5 rounds. Optionally use SecAgg. - Verify that after training for 10 rounds, the learned model achieves - some accuracy and has become fairer that after 5 rounds. + When using mere monitoring, verify that hardcoded accuracy + and (un)fairness levels, taken as a baseline, are achieved. + When using another algorithm, verify that it achieves some + degraded accuracy, and better fairness than the baseline. """ # Set up the toy dataset and optional identity keys for SecAgg. datasets = generate_toy_dataset(n_clients=3) @@ -238,8 +240,8 @@ async def test_toy_classif_fairness( # Note that FairFed is bound to match the FedAvg baseline due to the # split across clients being uniform.
expected_fairness = { - "demographic_parity": 0.02, - "equalized_odds": 0.11, + "demographic_parity": 0.025, + "equalized_odds": 0.142, } if fairness.algorithm == "monitor": assert accuracy >= 0.76 diff --git a/test/main/test_main_client.py b/test/main/test_main_client.py index 47f55391f88278f9c762a760ebf80bf2ca7ad23b..e002b9dc9455dff31c52f3b91adfd912785da39e 100644 --- a/test/main/test_main_client.py +++ b/test/main/test_main_client.py @@ -1077,7 +1077,7 @@ class TestFederatedClientFairnessRound: if ckpt: client.ckptr = mock.create_autospec(Checkpointer, instance=True) # Call the 'fairness_round' routine and verify expected actions. - request = messaging.FairnessQuery(round_i=1) + request = messaging.FairnessQuery(round_i=0) await client.fairness_round(request) fairness.run_fairness_round.assert_awaited_once_with( netwk=netwk, query=request, secagg=None @@ -1088,7 +1088,7 @@ class TestFederatedClientFairnessRound: metrics=fairness.run_fairness_round.return_value, prefix="fairness_metrics", append=False, # first round, hence file creation or overwrite - timestamp="round_1", + timestamp="round_0", ) @pytest.mark.asyncio