Verified Commit fe08f542 authored by ANDREY Paul

Add 'frequency' options to skip some evaluation or fairness rounds.

parent aab0f428
1 merge request: !70 Finalize version 2.6.0
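
The core of the change is a modulo-based skip rule shared by the evaluation and fairness rounds. The standalone predicate below is an illustrative sketch (the helper name `should_run_round` is not part of the commit); it mirrors the checks added to `fairness_round` and `evaluation_round` in the diff below.

def should_run_round(round_i: int, frequency: int, force_run: bool = False) -> bool:
    """Return whether a round should run, mirroring the added skip checks."""
    # The diff skips a round when `round_i % frequency` is non-zero,
    # unless `force_run` is set (used for the final checkpointed model).
    return (round_i % frequency == 0) or force_run

assert should_run_round(2, frequency=2)                  # run
assert not should_run_round(3, frequency=2)              # skipped
assert should_run_round(3, frequency=2, force_run=True)  # forced
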
@@ -137,7 +137,7 @@ class FederatedServer:
self._decrypter = None # type: Optional[Decrypter]
self._secagg_peers = set() # type: Set[str]
# Set up private attributes to record the loss values and best weights.
self._loss = {} # type: Dict[int, float]
self._losses = [] # type: List[float]
self._best = None # type: Optional[Vector]
# Set up a private attribute to prevent redundant weights sharing.
self._clients_holding_latest_model = set() # type: Set[str]
@@ -284,7 +284,6 @@ class FederatedServer:
# Iteratively run training and evaluation rounds.
round_i = 0
while True:
# Run (opt.) fairness; training; evaluation.
await self.fairness_round(round_i, config.fairness)
round_i += 1
await self.training_round(round_i, config.training)
@@ -292,9 +291,15 @@ class FederatedServer:
# Decide whether to keep training for at least one round.
if not self._keep_training(round_i, config.rounds, early_stop):
break
# When checkpointing, evaluate the last model's fairness.
# When checkpointing, force evaluating the last model.
if self.ckptr is not None:
await self.fairness_round(round_i, config.fairness)
if round_i % config.evaluate.frequency:
await self.evaluation_round(
round_i, config.evaluate, force_run=True
)
await self.fairness_round(
round_i, config.fairness, force_run=True
)
# Interrupt training when time comes.
self.logger.info("Stopping training.")
await self.stop_training(round_i)
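
As a concrete illustration of the loop above (not part of the commit): with 5 training rounds and `config.evaluate.frequency = 2`, evaluation runs after rounds 2 and 4, and the checkpointing branch forces one last evaluation (and fairness round) on the final model, since 5 % 2 is non-zero.

rounds, frequency = 5, 2
evaluated = [i for i in range(1, rounds + 1) if i % frequency == 0]
if rounds % frequency:  # mirrors `if round_i % config.evaluate.frequency:`
    evaluated.append(rounds)  # forced run on the last model when checkpointing
print(evaluated)  # [2, 4, 5]
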
@@ -542,8 +547,12 @@ class FederatedServer:
self,
round_i: int,
fairness_cfg: FairnessConfig,
force_run: bool = False,
) -> None:
"""Orchestrate a fairness round.
"""Orchestrate a fairness round, when configured to do so.
If fairness is not set, or if `round_i` is to be skipped based
on `fairness_cfg.frequency`, do nothing.
Parameters
----------
@@ -553,9 +562,15 @@ class FederatedServer:
FairnessConfig dataclass instance wrapping data-batching
and computational effort constraints hyper-parameters for
fairness evaluation.
force_run:
Whether to disregard `fairness_cfg.frequency` and run the
round (provided a fairness controller is setup).
"""
# Early exit when fairness is not set or the round is to be skipped.
if self.fairness is None:
return
if (round_i % fairness_cfg.frequency) and not force_run:
return
# Run SecAgg setup when needed.
self.logger.info("Initiating fairness-enforcing round %s", round_i)
clients = self.netwk.client_names # FUTURE: enable sampling(?)
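
Note that fairness rounds are triggered before each training round (the loop calls `fairness_round` with `round_i` starting at 0), whereas evaluation rounds are triggered after training. The snippet below (illustrative only) lists which round indices each fires at for a frequency of 3.

frequency = 3
fairness_rounds = [i for i in range(0, 10) if i % frequency == 0]    # before training
evaluation_rounds = [i for i in range(1, 10) if i % frequency == 0]  # after training
print(fairness_rounds)    # [0, 3, 6, 9]
print(evaluation_rounds)  # [3, 6, 9]
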
@@ -721,17 +736,24 @@ class FederatedServer:
self,
round_i: int,
valid_cfg: EvaluateConfig,
force_run: bool = False,
) -> None:
"""Orchestrate an evaluation round.
"""Orchestrate an evaluation round, when configured to do so.
If `round_i` is to be skipped based on `valid_cfg.frequency`,
do nothing.
Parameters
----------
round_i: int
Index of the evaluation round.
Index of the latest training round.
valid_cfg: EvaluateConfig
EvaluateConfig dataclass instance wrapping data-batching
and computational effort constraints hyper-parameters.
"""
# Early exit when the evaluation round is to be skipped.
if (round_i % valid_cfg.frequency) and not force_run:
return
# Select participating clients. Run SecAgg setup when needed.
self.logger.info("Initiating evaluation round %s", round_i)
clients = self._select_evaluation_round_participants()
@@ -766,8 +788,8 @@ class FederatedServer:
metrics, results if len(results) > 1 else {}
)
# Record the global loss, and update the kept "best" weights.
self._loss[round_i] = loss
if loss == min(self._loss.values()):
self._losses.append(loss)
if loss == min(self._losses):
self._best = self.model.get_weights()
def _select_evaluation_round_participants(
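
The switch from a round-indexed dict to a plain list does not change the "best weights" logic; below is a minimal standalone sketch of that logic (with hypothetical names, not the library code).

from typing import Any, List, Optional

losses: List[float] = []            # stands in for `self._losses`
best_weights: Optional[Any] = None  # stands in for `self._best`

def record_round(loss: float, weights: Any) -> None:
    """Append the round's loss and keep the weights yielding the lowest loss."""
    global best_weights
    losses.append(loss)
    if loss == min(losses):
        best_weights = weights

record_round(0.9, "w1")  # best -> "w1"
record_round(0.7, "w2")  # best -> "w2"
record_round(0.8, "w3")  # best stays "w2"
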
@@ -891,7 +913,7 @@ class FederatedServer:
self.ckptr.save_metrics(
metrics=metrics,
prefix=f"metrics_{client}",
append=bool(self._loss),
append=bool(self._losses),
timestamp=timestamp,
)
@@ -917,7 +939,7 @@ class FederatedServer:
self.logger.info("Maximum number of training rounds reached.")
return False
if early_stop is not None:
early_stop.update(self._loss[round_i])
early_stop.update(self._losses[-1])
if not early_stop.keep_training:
self.logger.info("Early stopping criterion reached.")
return False
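
The early-stopping object only needs an `update(loss)` method and a `keep_training` property for the call above; the class below is an illustrative stand-in (not declearn's actual early-stopping implementation), stopping after a fixed number of rounds without improvement.

class SimpleEarlyStopping:
    """Minimal early-stopping criterion matching the interface used above."""

    def __init__(self, patience: int = 3) -> None:
        self.patience = patience
        self._best = float("inf")
        self._stale = 0  # rounds since the last improvement

    def update(self, loss: float) -> None:
        """Record the latest round's loss."""
        if loss < self._best:
            self._best = loss
            self._stale = 0
        else:
            self._stale += 1

    @property
    def keep_training(self) -> bool:
        """Whether training should go on."""
        return self._stale < self.patience
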
@@ -937,7 +959,7 @@ class FederatedServer:
self.logger.info("Recovering weights that yielded the lowest loss.")
message = messaging.StopTraining(
weights=self._best or self.model.get_weights(),
loss=min(self._loss.values()) if self._loss else float("nan"),
loss=min(self._losses, default=float("nan")),
rounds=rounds,
)
self.logger.info("Notifying clients that training is over.")
@@ -127,12 +127,20 @@ class TrainingConfig:
class EvaluateConfig(TrainingConfig):
"""Dataclass wrapping parameters for an evaluation round.
Exclusive attributes
--------------------
frequency: int
Number of training rounds to run between evaluation ones.
By default, run an evaluation round after each training one.
Please refer to the parent class `TrainingConfig` for details
on the wrapped parameters / attribute. Note that `n_epoch` is
dropped when this config is turned into an EvaluationRequest
message.
on the other wrapped parameters / attributes.
Note that `n_epoch` is dropped when this config is turned into
an EvaluationRequest message.
"""
frequency: int = 1
drop_remainder: bool = False
@property
@@ -238,6 +246,9 @@ class FairnessConfig:
----------
batch_size: int
Number of samples per processed data batch.
frequency: int
Number of training rounds to run between fairness ones.
By default, run a fairness round before each training one.
n_batch: int or None, default=None
Optional maximum number of batches to draw.
If None, use the entire training dataset.
@@ -249,5 +260,6 @@ class FairnessConfig:
"""
batch_size: int = 32
frequency: int = 1
n_batch: Optional[int] = None
thresh: Optional[float] = None
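
Putting the new options together, a hedged usage sketch (the import path is assumed, and `EvaluateConfig` inherits further batching parameters from `TrainingConfig`, here assumed to include `batch_size`):

from declearn.main.config import EvaluateConfig, FairnessConfig  # assumed import path

# Evaluate every 2nd training round; run a fairness round every 5 rounds.
evaluate_cfg = EvaluateConfig(batch_size=64, frequency=2)
fairness_cfg = FairnessConfig(batch_size=32, frequency=5)
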