From 31bda9c22e8d97d728cc8594fd650ebffd0a8c54 Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Tue, 23 Jul 2024 16:47:49 +0200
Subject: [PATCH] Fix fairness round indexing and enable last-model fairness
 evaluation.

- Fairness rounds occur prior to training ones. Consequently, this
  commit modifies their indexing to align with that of the _previous_
  training round, so that utility and fairness metrics recorded under
  a given round index refer to the same model state in checkpoint
  files (see the schedule sketched below).
- In addition, this commit ensures that a final fairness evaluation
  round is run prior to ending the FL process when checkpointing is
  set up at the server level.
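
For illustration, the resulting round schedule reads as follows, with
`fairness(i)`, `training(i)` and `evaluation(i)` used as shorthand
notations (not actual API names), and with fairness rounds being
no-ops when no fairness algorithm is configured:

    fairness(0)                 # evaluates the initial model
    training(1); evaluation(1)
    fairness(1)                 # evaluates the model after round 1
    ...
    training(N); evaluation(N)
    fairness(N)                 # last model; run only when a
                                # server-side checkpointer is set up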
---
 declearn/main/_client.py                 |  2 +-
 declearn/main/_server.py                 | 15 ++++++++++-----
 test/functional/test_toy_clf_fairness.py | 12 +++++++-----
 test/main/test_main_client.py            |  4 ++--
 4 files changed, 20 insertions(+), 13 deletions(-)
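
Reviewer note (patch comment area; not part of the commit message):
with the new indexing, the first fairness round carries index 0, so
`append=bool(query.round_i)` is False exactly once, (over)writing the
"fairness_metrics" checkpoint file, after which every later round
appends to it. Below is a minimal, runnable sketch of that behaviour;
the `save_metrics` stub is hypothetical and only mirrors the keyword
arguments of the `Checkpointer.save_metrics` calls touched by this
patch:

    # Stand-in for the real `Checkpointer.save_metrics`; illustrative
    # only, not the actual declearn implementation.
    def save_metrics(metrics, prefix, append, timestamp):
        mode = "append to" if append else "(over)write"
        print(f"{mode} '{prefix}' file with entry '{timestamp}'")

    for round_i in range(3):
        save_metrics(
            metrics={},                   # placeholder metrics dict
            prefix="fairness_metrics",
            append=bool(round_i),         # False at round 0 only
            timestamp=f"round_{round_i}",
        )
    # -> (over)write 'fairness_metrics' file with entry 'round_0'
    # -> append to 'fairness_metrics' file with entry 'round_1'
    # -> append to 'fairness_metrics' file with entry 'round_2'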

diff --git a/declearn/main/_client.py b/declearn/main/_client.py
index 1c09dc0..506046e 100644
--- a/declearn/main/_client.py
+++ b/declearn/main/_client.py
@@ -670,7 +670,7 @@ class FederatedClient:
             self.ckptr.save_metrics(
                 metrics=metrics,
                 prefix="fairness_metrics",
-                append=(query.round_i > 1),
+                append=bool(query.round_i),
                 timestamp=f"round_{query.round_i}",
             )
 
diff --git a/declearn/main/_server.py b/declearn/main/_server.py
index 9fc16c6..cb4d9e8 100644
--- a/declearn/main/_server.py
+++ b/declearn/main/_server.py
@@ -284,13 +284,17 @@ class FederatedServer:
             # Iteratively run training and evaluation rounds.
             round_i = 0
             while True:
+                # Run (optional) fairness round, then training and evaluation.
+                await self.fairness_round(round_i, config.fairness)
                 round_i += 1
-                if self.fairness is not None:
-                    await self.fairness_round(round_i, config.fairness)
                 await self.training_round(round_i, config.training)
                 await self.evaluation_round(round_i, config.evaluate)
+                # Decide whether to keep training for at least one more round.
                 if not self._keep_training(round_i, config.rounds, early_stop):
                     break
+            # When checkpointing, evaluate the last model's fairness.
+            if self.ckptr is not None:
+                await self.fairness_round(round_i, config.fairness)
             # Interrupt training when time comes.
             self.logger.info("Stopping training.")
             await self.stop_training(round_i)
@@ -544,13 +548,14 @@ class FederatedServer:
         Parameters
         ----------
         round_i:
-            Index of the training round.
+            Index of the latest training round (0 before any training).
         fairness_cfg:
             FairnessConfig dataclass instance wrapping data-batching
             and computational effort constraints hyper-parameters for
             fairness evaluation.
         """
-        assert self.fairness is not None
+        if self.fairness is None:
+            return
         # Run SecAgg setup when needed.
         self.logger.info("Initiating fairness-enforcing round %s", round_i)
         clients = self.netwk.client_names  # FUTURE: enable sampling(?)
@@ -575,7 +580,7 @@ class FederatedServer:
             self.ckptr.save_metrics(
                 metrics=metrics,
                 prefix="fairness_metrics",
-                append=(query.round_i > 1),
+                append=bool(query.round_i),
                 timestamp=f"round_{query.round_i}",
             )
 
diff --git a/test/functional/test_toy_clf_fairness.py b/test/functional/test_toy_clf_fairness.py
index 60f4540..41df97a 100644
--- a/test/functional/test_toy_clf_fairness.py
+++ b/test/functional/test_toy_clf_fairness.py
@@ -207,10 +207,12 @@ async def test_toy_classif_fairness(
 
     Set up a toy dataset for fairness-aware federated learning.
     Use a given algorithm, with a given group-fairness definition.
-    Optionally use SecAgg.
+    Run training for 5 rounds. Optionally use SecAgg.
 
-    Verify that after training for 10 rounds, the learned model achieves
-    some accuracy and has become fairer that after 5 rounds.
+    When using mere monitoring, verify that hardcoded accuracy
+    and (un)fairness levels, taken as a baseline, are achieved.
+    When using another algorithm, verify that it achieves somewhat
+    degraded accuracy, and better fairness than the baseline.
     """
     # Set up the toy dataset and optional identity keys for SecAgg.
     datasets = generate_toy_dataset(n_clients=3)
@@ -238,8 +240,8 @@ async def test_toy_classif_fairness(
     # Note that FairFed is bound to match the FedAvg baseline due to the
     # split across clients being uniform.
     expected_fairness = {
-        "demographic_parity": 0.02,
-        "equalized_odds": 0.11,
+        "demographic_parity": 0.025,
+        "equalized_odds": 0.142,
     }
     if fairness.algorithm == "monitor":
         assert accuracy >= 0.76
diff --git a/test/main/test_main_client.py b/test/main/test_main_client.py
index 47f5539..e002b9d 100644
--- a/test/main/test_main_client.py
+++ b/test/main/test_main_client.py
@@ -1077,7 +1077,7 @@ class TestFederatedClientFairnessRound:
         if ckpt:
             client.ckptr = mock.create_autospec(Checkpointer, instance=True)
         # Call the 'fairness_round' routine and verify expected actions.
-        request = messaging.FairnessQuery(round_i=1)
+        request = messaging.FairnessQuery(round_i=0)
         await client.fairness_round(request)
         fairness.run_fairness_round.assert_awaited_once_with(
             netwk=netwk, query=request, secagg=None
@@ -1088,7 +1088,7 @@ class TestFederatedClientFairnessRound:
                 metrics=fairness.run_fairness_round.return_value,
                 prefix="fairness_metrics",
                 append=False,  # first round, hence file creation or overwrite
-                timestamp="round_1",
+                timestamp="round_0",
             )
 
     @pytest.mark.asyncio
-- 
GitLab