From fe08f5424217508b3f5b8644af0b01e6d2ccd49e Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Wed, 24 Jul 2024 11:13:05 +0200 Subject: [PATCH] Add 'frequency' options to skip some evaluation or fairness rounds. --- declearn/main/_server.py | 46 ++++++++++++++++++++-------- declearn/main/config/_dataclasses.py | 18 +++++++++-- 2 files changed, 49 insertions(+), 15 deletions(-) diff --git a/declearn/main/_server.py b/declearn/main/_server.py index cb4d9e8..cb53805 100644 --- a/declearn/main/_server.py +++ b/declearn/main/_server.py @@ -137,7 +137,7 @@ class FederatedServer: self._decrypter = None # type: Optional[Decrypter] self._secagg_peers = set() # type: Set[str] # Set up private attributes to record the loss values and best weights. - self._loss = {} # type: Dict[int, float] + self._losses = [] # type: List[float] self._best = None # type: Optional[Vector] # Set up a private attribute to prevent redundant weights sharing. self._clients_holding_latest_model = set() # type: Set[str] @@ -284,7 +284,6 @@ class FederatedServer: # Iteratively run training and evaluation rounds. round_i = 0 while True: - # Run (opt.) fairness; training; evaluation. await self.fairness_round(round_i, config.fairness) round_i += 1 await self.training_round(round_i, config.training) @@ -292,9 +291,15 @@ class FederatedServer: # Decide whether to keep training for at least one round. if not self._keep_training(round_i, config.rounds, early_stop): break - # When checkpointing, evaluate the last model's fairness. + # When checkpointing, force evaluating the last model. if self.ckptr is not None: - await self.fairness_round(round_i, config.fairness) + if round_i % config.evaluate.frequency: + await self.evaluation_round( + round_i, config.evaluate, force_run=True + ) + await self.fairness_round( + round_i, config.fairness, force_run=True + ) # Interrupt training when time comes. 
self.logger.info("Stopping training.") await self.stop_training(round_i) @@ -542,8 +547,12 @@ class FederatedServer: self, round_i: int, fairness_cfg: FairnessConfig, + force_run: bool = False, ) -> None: - """Orchestrate a fairness round. + """Orchestrate a fairness round, when configured to do so. + + If fairness is not set, or if `round_i` is to be skipped based + on `fairness_cfg.frequency`, do nothing. Parameters ---------- @@ -553,9 +562,15 @@ FairnessConfig dataclass instance wrapping data-batching and computational effort constraints hyper-parameters for fairness evaluation. + force_run: + Whether to disregard `fairness_cfg.frequency` and run the + round (provided a fairness controller is set up). """ + # Early exit when fairness is not set or the round is to be skipped. if self.fairness is None: return + if (round_i % fairness_cfg.frequency) and not force_run: + return # Run SecAgg setup when needed. self.logger.info("Initiating fairness-enforcing round %s", round_i) clients = self.netwk.client_names # FUTURE: enable sampling(?) @@ -721,17 +736,24 @@ class FederatedServer: self, round_i: int, valid_cfg: EvaluateConfig, + force_run: bool = False, ) -> None: - """Orchestrate an evaluation round. + """Orchestrate an evaluation round, when configured to do so. + + If `round_i` is to be skipped based on `valid_cfg.frequency`, + do nothing. Parameters ---------- round_i: int - Index of the evaluation round. + Index of the latest training round. valid_cfg: EvaluateConfig EvaluateConfig dataclass instance wrapping data-batching and computational effort constraints hyper-parameters. """ + # Early exit when the evaluation round is to be skipped. + if (round_i % valid_cfg.frequency) and not force_run: + return # Select participating clients. Run SecAgg setup when needed. 
self.logger.info("Initiating evaluation round %s", round_i) clients = self._select_evaluation_round_participants() @@ -766,8 +788,8 @@ class FederatedServer: metrics, results if len(results) > 1 else {} ) # Record the global loss, and update the kept "best" weights. - self._loss[round_i] = loss - if loss == min(self._loss.values()): + self._losses.append(loss) + if loss == min(self._losses): self._best = self.model.get_weights() def _select_evaluation_round_participants( @@ -891,7 +913,7 @@ class FederatedServer: self.ckptr.save_metrics( metrics=metrics, prefix=f"metrics_{client}", - append=bool(self._loss), + append=bool(self._losses), timestamp=timestamp, ) @@ -917,7 +939,7 @@ class FederatedServer: self.logger.info("Maximum number of training rounds reached.") return False if early_stop is not None: - early_stop.update(self._loss[round_i]) + early_stop.update(self._losses[-1]) if not early_stop.keep_training: self.logger.info("Early stopping criterion reached.") return False @@ -937,7 +959,7 @@ class FederatedServer: self.logger.info("Recovering weights that yielded the lowest loss.") message = messaging.StopTraining( weights=self._best or self.model.get_weights(), - loss=min(self._loss.values()) if self._loss else float("nan"), + loss=min(self._losses, default=float("nan")), rounds=rounds, ) self.logger.info("Notifying clients that training is over.") diff --git a/declearn/main/config/_dataclasses.py b/declearn/main/config/_dataclasses.py index 867343e..91f0410 100644 --- a/declearn/main/config/_dataclasses.py +++ b/declearn/main/config/_dataclasses.py @@ -127,12 +127,20 @@ class TrainingConfig: class EvaluateConfig(TrainingConfig): """Dataclass wrapping parameters for an evaluation round. + Exclusive attributes + -------------------- + frequency: int + Number of training rounds to run between evaluation ones. + By default, run an evaluation round after each training one. 
+ Please refer to the parent class `TrainingConfig` for details - on the wrapped parameters / attribute. Note that `n_epoch` is - dropped when this config is turned into an EvaluationRequest - message. + on the other wrapped parameters / attributes. + + Note that `n_epoch` is dropped when this config is turned into + an EvaluationRequest message. """ + frequency: int = 1 drop_remainder: bool = False @property @@ -238,6 +246,9 @@ class FairnessConfig: ---------- batch_size: int Number of samples per processed data batch. + frequency: int + Number of training rounds to run between fairness ones. + By default, run a fairness round before each training one. n_batch: int or None, default=None Optional maximum number of batches to draw. If None, use the entire training dataset. @@ -249,5 +260,6 @@ class FairnessConfig: """ batch_size: int = 32 + frequency: int = 1 n_batch: Optional[int] = None thresh: Optional[float] = None -- GitLab