Mentions légales du service

Skip to content
Snippets Groups Projects
Verified Commit 23726d57 authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Add 'TrainingManager.(train|evaluate)_under_constraints' to the API.

- Until now, the public methods for training and evaluation using
  the `declearn.main.utils.TrainingManager` class (and its DP-SGD
  counterpart) required message-wrapped inputs and emitted similar
  outputs.
- With this commit, the previously-private `train_under_constraints`
  and `evaluate_under_constraints` routines are made public and thus
  part of the declearn API, with some minor refactoring to make them
  more user-friendly.
- The rationale of this change is to enable using `TrainingManager`
  outside of our `FederatedClient`/`FederatedServer` orchestration,
  notably when simulating FL training or testing client-side code.
  It may also be helpful to end-users that would like to build on
  declearn but implement their own orchestration tools or algorithm
  loops.
- In the future (declearn >=3.0), we may go one step further and
  take all the messaging-related instructions out of the current
  class. The class may also be moved to a different namespace, e.g.
  a new 'declearn.train' module; but this is entirely out of scope
  for now.
parent e820470c
No related branches found
No related tags found
1 merge request: !58 "Enable recording training loss values"
......@@ -18,7 +18,7 @@
"""Wrapper to run local training and evaluation rounds in a FL process."""
import logging
from typing import Any, ClassVar, Dict, List, Optional, Union
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union
import numpy as np
......@@ -169,7 +169,7 @@ class TrainingManager:
"Training local model for %s epochs | %s steps | %s seconds.",
*params,
)
effort = self._train_under_constraints(message.batches, *params)
effort = self.train_under_constraints(message.batches, *params)
# Compute model updates and collect auxiliary variables.
self.logger.info("Packing local updates to be sent to the server.")
return messaging.TrainReply(
......@@ -180,21 +180,25 @@ class TrainingManager:
t_spent=round(effort["t_spent"], 3),
)
def _train_under_constraints(
def train_under_constraints(
self,
batch_cfg: Dict[str, Any],
n_epoch: Optional[int],
n_steps: Optional[int],
timeout: Optional[int],
n_epoch: Optional[int] = 1,
n_steps: Optional[int] = None,
timeout: Optional[int] = None,
) -> Dict[str, float]:
"""Backend code to run local SGD steps under effort constraints.
"""Run local SGD steps under effort constraints.
This is the core backend to the `training_round` method,
which further handles message parsing and passing, as well
as exception catching.
Parameters
----------
batch_cfg: Dict[str, Any]
Keyword arguments for `self.train_data.generate_batches`
i.e. specifications of batches used in local SGD steps.
n_epoch: int or None, default=None
n_epoch: int or None, default=1
Maximum number of local training epochs to perform.
May be overridden by `n_steps` or `timeout`.
n_steps: int or None, default=None
......@@ -286,12 +290,7 @@ class TrainingManager:
)
# Try running the evaluation round.
try:
# Update the model's weights and evaluate on the local dataset.
# Revise: make the weights' update optional.
self.model.set_weights(message.weights, trainable=True)
return self._evaluate_under_constraints(
message.batches, message.n_steps, message.timeout
)
return self._evaluation_round(message)
# In case of failure, wrap the exception as an Error message.
except Exception as exception: # pylint: disable=broad-except
self.logger.error(
......@@ -299,13 +298,41 @@ class TrainingManager:
)
return messaging.Error(repr(exception))
def _evaluate_under_constraints(
def _evaluation_round(
    self,
    message: messaging.EvaluationRequest,
) -> messaging.EvaluationReply:
    """Backend to `evaluation_round`, without exception capture hooks."""
    # Load the server-sent weights into the local model before evaluating.
    # Revise: make the weights' update optional.
    self.model.set_weights(message.weights, trainable=True)
    # Run the evaluation itself under the requested effort constraints.
    outputs = self.evaluate_under_constraints(
        message.batches, message.n_steps, message.timeout
    )
    metrics, states, effort = outputs
    # Wrap the evaluation outputs into a reply message for the server.
    self.logger.info("Packing local results to be sent to the server.")
    reply = messaging.EvaluationReply(
        loss=float(metrics["loss"]),
        metrics=states,
        n_steps=int(effort["n_steps"]),
        t_spent=round(effort["t_spent"], 3),
    )
    return reply
def evaluate_under_constraints(
self,
batch_cfg: Dict[str, Any],
n_steps: Optional[int] = None,
timeout: Optional[int] = None,
) -> messaging.EvaluationReply:
"""Backend code to run local loss computation under effort constraints.
) -> Tuple[
Dict[str, Union[float, np.ndarray]],
Dict[str, Dict[str, Union[float, np.ndarray]]],
Dict[str, float],
]:
"""Run local loss computation under effort constraints.
This is the core backend to the `evaluation_round` method,
which further handles message parsing and passing, as well
as exception catching.
Parameters
----------
......@@ -320,10 +347,21 @@ class TrainingManager:
Returns
-------
reply: messaging.EvaluationReply
EvaluationReply message wrapping the computed loss on the
local validation (or, if absent, training) dataset as well
as the number of steps and the time taken to obtain it.
metrics:
Computed metrics, as a dict with float or array values.
states:
Computed metrics, as partial values that may be shared
with other agents to federatively compute final values
with the same specs as `metrics`.
effort:
Dictionary storing information on the computational
effort effectively performed:
* n_epoch: int
Number of evaluation epochs completed.
* n_steps: int
Number of evaluation steps completed.
* t_spent: float
Time spent running training steps (in seconds).
"""
# Set up effort constraints under which to operate.
constraints = ConstraintSet(
......@@ -342,18 +380,12 @@ class TrainingManager:
break
# Gather the computed metrics and computational effort information.
effort = constraints.get_values()
result = self.metrics.get_result()
values = self.metrics.get_result()
states = self.metrics.get_states()
self.logger.log(
LOGGING_LEVEL_MAJOR,
"Local scalar evaluation metrics: %s",
{k: v for k, v in result.items() if isinstance(v, float)},
)
# Pack the result and computational effort information into a message.
self.logger.info("Packing local results to be sent to the server.")
return messaging.EvaluationReply(
loss=float(result["loss"]),
metrics=states,
n_steps=int(effort["n_steps"]),
t_spent=round(effort["t_spent"], 3),
{k: v for k, v in values.items() if isinstance(v, float)},
)
# Return the metrics' values, their states and the effort information.
return values, states, effort
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment