From fe08f5424217508b3f5b8644af0b01e6d2ccd49e Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Wed, 24 Jul 2024 11:13:05 +0200
Subject: [PATCH] Add 'frequency' options to skip some evaluation or fairness
 rounds.
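
'EvaluateConfig' and 'FairnessConfig' now expose a 'frequency'
attribute (default: 1, preserving the former behavior), so that
evaluation or fairness rounds may be run only once every 'frequency'
training rounds. A round is skipped when 'round_i % frequency' is
non-zero, unless 'force_run=True' is passed, which the server uses to
ensure the final model is evaluated when a checkpointer is set.

Illustrative sketch (assuming the dataclasses are re-exported from
'declearn.main.config'; the 'batch_size' values are arbitrary):

    from declearn.main.config import EvaluateConfig, FairnessConfig

    # Evaluate the global model once every 2 training rounds.
    evaluate = EvaluateConfig(batch_size=128, frequency=2)
    # Run a fairness round once every 5 training rounds.
    fairness = FairnessConfig(batch_size=32, frequency=5)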

---
 declearn/main/_server.py             | 49 +++++++++++++++++++++++++++++++++++--------
 declearn/main/config/_dataclasses.py | 18 +++++++++--
 2 files changed, 52 insertions(+), 15 deletions(-)

diff --git a/declearn/main/_server.py b/declearn/main/_server.py
index cb4d9e8..cb53805 100644
--- a/declearn/main/_server.py
+++ b/declearn/main/_server.py
@@ -137,7 +137,7 @@ class FederatedServer:
         self._decrypter = None  # type: Optional[Decrypter]
         self._secagg_peers = set()  # type: Set[str]
         # Set up private attributes to record the loss values and best weights.
-        self._loss = {}  # type: Dict[int, float]
+        self._losses = []  # type: List[float]
         self._best = None  # type: Optional[Vector]
         # Set up a private attribute to prevent redundant weights sharing.
         self._clients_holding_latest_model = set()  # type: Set[str]
@@ -284,7 +284,6 @@ class FederatedServer:
             # Iteratively run training and evaluation rounds.
             round_i = 0
             while True:
-                # Run (opt.) fairness; training; evaluation.
                 await self.fairness_round(round_i, config.fairness)
                 round_i += 1
                 await self.training_round(round_i, config.training)
@@ -292,9 +291,15 @@ class FederatedServer:
                 # Decide whether to keep training for at least one round.
                 if not self._keep_training(round_i, config.rounds, early_stop):
                     break
-            # When checkpointing, evaluate the last model's fairness.
+            # When checkpointing, force final evaluation and fairness rounds.
             if self.ckptr is not None:
-                await self.fairness_round(round_i, config.fairness)
+                if round_i % config.evaluate.frequency:
+                    await self.evaluation_round(
+                        round_i, config.evaluate, force_run=True
+                    )
+                await self.fairness_round(
+                    round_i, config.fairness, force_run=True
+                )
             # Interrupt training when time comes.
             self.logger.info("Stopping training.")
             await self.stop_training(round_i)
@@ -542,8 +547,12 @@ class FederatedServer:
         self,
         round_i: int,
         fairness_cfg: FairnessConfig,
+        force_run: bool = False,
     ) -> None:
-        """Orchestrate a fairness round.
+        """Orchestrate a fairness round, when configured to do so.
+
+        If no fairness controller is set up, or if `round_i` is to be
+        skipped based on `fairness_cfg.frequency`, do nothing.
 
         Parameters
         ----------
@@ -553,9 +562,15 @@ class FederatedServer:
             FairnessConfig dataclass instance wrapping data-batching
             and computational effort constraints hyper-parameters for
             fairness evaluation.
+        force_run:
+            Whether to disregard `fairness_cfg.frequency` and run the
+            round (provided a fairness controller is set up).
         """
+        # Early exit when fairness is not set or the round is to be skipped.
         if self.fairness is None:
             return
+        if (round_i % fairness_cfg.frequency) and not force_run:
+            return
         # Run SecAgg setup when needed.
         self.logger.info("Initiating fairness-enforcing round %s", round_i)
         clients = self.netwk.client_names  # FUTURE: enable sampling(?)
@@ -721,17 +736,27 @@
         self,
         round_i: int,
         valid_cfg: EvaluateConfig,
+        force_run: bool = False,
     ) -> None:
-        """Orchestrate an evaluation round.
+        """Orchestrate an evaluation round, when configured to do so.
+
+    If `round_i` is to be skipped based on `valid_cfg.frequency`,
+        do nothing.
 
         Parameters
         ----------
         round_i: int
-            Index of the evaluation round.
+            Index of the latest training round.
         valid_cfg: EvaluateConfig
             EvaluateConfig dataclass instance wrapping data-batching
             and computational effort constraints hyper-parameters.
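+        force_run:
+            Whether to disregard `valid_cfg.frequency` and run the
+            round anyway.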
         """
+        # Early exit when the evaluation round is to be skipped.
+        if (round_i % valid_cfg.frequency) and not force_run:
+            return
         # Select participating clients. Run SecAgg setup when needed.
         self.logger.info("Initiating evaluation round %s", round_i)
         clients = self._select_evaluation_round_participants()
@@ -766,8 +788,8 @@ class FederatedServer:
                 metrics, results if len(results) > 1 else {}
             )
         # Record the global loss, and update the kept "best" weights.
-        self._loss[round_i] = loss
-        if loss == min(self._loss.values()):
+        self._losses.append(loss)
+        if loss == min(self._losses):
             self._best = self.model.get_weights()
 
     def _select_evaluation_round_participants(
@@ -891,7 +913,7 @@ class FederatedServer:
             self.ckptr.save_metrics(
                 metrics=metrics,
                 prefix=f"metrics_{client}",
-                append=bool(self._loss),
+                append=bool(self._losses),
                 timestamp=timestamp,
             )
 
@@ -917,7 +939,7 @@ class FederatedServer:
             self.logger.info("Maximum number of training rounds reached.")
             return False
         if early_stop is not None:
-            early_stop.update(self._loss[round_i])
+            early_stop.update(self._losses[-1])
             if not early_stop.keep_training:
                 self.logger.info("Early stopping criterion reached.")
                 return False
@@ -937,7 +959,7 @@ class FederatedServer:
         self.logger.info("Recovering weights that yielded the lowest loss.")
         message = messaging.StopTraining(
             weights=self._best or self.model.get_weights(),
-            loss=min(self._loss.values()) if self._loss else float("nan"),
+            loss=min(self._losses, default=float("nan")),
             rounds=rounds,
         )
         self.logger.info("Notifying clients that training is over.")
diff --git a/declearn/main/config/_dataclasses.py b/declearn/main/config/_dataclasses.py
index 867343e..91f0410 100644
--- a/declearn/main/config/_dataclasses.py
+++ b/declearn/main/config/_dataclasses.py
@@ -127,12 +127,20 @@ class TrainingConfig:
 class EvaluateConfig(TrainingConfig):
     """Dataclass wrapping parameters for an evaluation round.
 
+    Exclusive attributes
+    --------------------
+    frequency: int
+        Run an evaluation round once every `frequency` training rounds.
+        By default (1), run an evaluation round after each training one.
+
     Please refer to the parent class `TrainingConfig` for details
-    on the wrapped parameters / attribute. Note that `n_epoch` is
-    dropped when this config is turned into an EvaluationRequest
-    message.
+    on the other wrapped parameters / attributes.
+
+    Note that `n_epoch` is dropped when this config is turned into
+    an EvaluationRequest message.
     """
 
+    frequency: int = 1
     drop_remainder: bool = False
 
     @property
@@ -238,6 +246,9 @@ class FairnessConfig:
     ----------
     batch_size: int
         Number of samples per processed data batch.
+    frequency: int
+        Run a fairness round once every `frequency` training rounds.
+        By default (1), run a fairness round before each training one.
     n_batch: int or None, default=None
         Optional maximum number of batches to draw.
         If None, use the entire training dataset.
@@ -249,5 +260,6 @@ class FairnessConfig:
     """
 
     batch_size: int = 32
+    frequency: int = 1
     n_batch: Optional[int] = None
     thresh: Optional[float] = None
-- 
GitLab