From 23726d573eeb503f1881f0eafccea627684ccd4a Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Thu, 21 Sep 2023 16:49:33 +0200
Subject: [PATCH] Add 'TrainingManager.(train|evaluate)_under_constraints' to
 the API.

- Until now, the public methods for training and evaluation using the
  `declearn.main.utils.TrainingManager` class (and its DP-SGD counterpart)
  required message-wrapped inputs and returned message-wrapped outputs.
- With this commit, the previously-private `train_under_constraints` and
  `evaluate_under_constraints` routines are made public and thus part of
  the declearn API, with some minor refactoring to make them more
  user-friendly.
- The rationale for this change is to enable using `TrainingManager`
  outside of our `FederatedClient`/`FederatedServer` orchestration,
  notably when simulating FL training or testing client-side code. It may
  also be helpful to end-users who would like to build on declearn but
  implement their own orchestration tools or algorithm loops; a usage
  sketch is provided below.
- In the future (declearn >=3.0), we may go one step further and take all
  the messaging-related instructions out of the current class. The class
  may also be moved to a different namespace, e.g. a new 'declearn.train'
  module; but this is entirely out of scope for now.
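- For illustration, the sketch below shows how the newly-public methods
  may be called directly. This is a minimal, hypothetical example: the
  `manager` variable stands for a readily-instantiated `TrainingManager`,
  and the `batch_cfg` contents are mere placeholder keyword arguments to
  the wrapped dataset's `generate_batches` method.

      manager = ...  # assumption: an already-set-up TrainingManager
      batch_cfg = {"batch_size": 32}  # illustrative batch specifications
      # Run local SGD steps, stopping after 1 epoch, 100 steps or
      # 60 seconds, whichever occurs first.
      effort = manager.train_under_constraints(
          batch_cfg, n_epoch=1, n_steps=100, timeout=60
      )
      print(f"Ran {effort['n_steps']} steps in {effort['t_spent']}s.")
      # Evaluate the locally-updated model under a time constraint.
      metrics, states, effort = manager.evaluate_under_constraints(
          batch_cfg, timeout=60
      )
      print(f"Local loss: {metrics['loss']}")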
---
 declearn/main/utils/_training.py | 94 +++++++++++++++++++++-----------
 1 file changed, 63 insertions(+), 31 deletions(-)

diff --git a/declearn/main/utils/_training.py b/declearn/main/utils/_training.py
index 901d9538..a82e7114 100644
--- a/declearn/main/utils/_training.py
+++ b/declearn/main/utils/_training.py
@@ -18,7 +18,7 @@
 """Wrapper to run local training and evaluation rounds in a FL process."""
 
 import logging
-from typing import Any, ClassVar, Dict, List, Optional, Union
+from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 
@@ -169,7 +169,7 @@ class TrainingManager:
             "Training local model for %s epochs | %s steps | %s seconds.",
             *params,
         )
-        effort = self._train_under_constraints(message.batches, *params)
+        effort = self.train_under_constraints(message.batches, *params)
         # Compute model updates and collect auxiliary variables.
         self.logger.info("Packing local updates to be sent to the server.")
         return messaging.TrainReply(
@@ -180,21 +180,25 @@ class TrainingManager:
             t_spent=round(effort["t_spent"], 3),
         )
 
-    def _train_under_constraints(
+    def train_under_constraints(
         self,
         batch_cfg: Dict[str, Any],
-        n_epoch: Optional[int],
-        n_steps: Optional[int],
-        timeout: Optional[int],
+        n_epoch: Optional[int] = 1,
+        n_steps: Optional[int] = None,
+        timeout: Optional[int] = None,
     ) -> Dict[str, float]:
-        """Backend code to run local SGD steps under effort constraints.
+        """Run local SGD steps under effort constraints.
+
+        This is the core backend to the `training_round` method,
+        which further handles message parsing and passing, as well
+        as exception catching.
 
         Parameters
         ----------
         batch_cfg: Dict[str, Any]
             Keyword arguments for `self.train_data.generate_batches`
             i.e. specifications of batches used in local SGD steps.
-        n_epoch: int or None, default=None
+        n_epoch: int or None, default=1
             Maximum number of local training epochs to perform.
             May be overridden by `n_steps` or `timeout`.
         n_steps: int or None, default=None
@@ -286,12 +290,7 @@ class TrainingManager:
         )
         # Try running the evaluation round.
         try:
-            # Update the model's weights and evaluate on the local dataset.
-            # Revise: make the weights' update optional.
-            self.model.set_weights(message.weights, trainable=True)
-            return self._evaluate_under_constraints(
-                message.batches, message.n_steps, message.timeout
-            )
+            return self._evaluation_round(message)
         # In case of failure, wrap the exception as an Error message.
         except Exception as exception:  # pylint: disable=broad-except
             self.logger.error(
@@ -299,13 +298,41 @@ class TrainingManager:
             )
             return messaging.Error(repr(exception))
 
-    def _evaluate_under_constraints(
+    def _evaluation_round(
+        self,
+        message: messaging.EvaluationRequest,
+    ) -> messaging.EvaluationReply:
+        """Backend to `evaluation_round`, without exception capture hooks."""
+        # Update the model's weights and evaluate on the local dataset.
+        # Revise: make the weights' update optional.
+        self.model.set_weights(message.weights, trainable=True)
+        metrics, states, effort = self.evaluate_under_constraints(
+            message.batches, message.n_steps, message.timeout
+        )
+        # Pack the resulting information into a message.
+        self.logger.info("Packing local results to be sent to the server.")
+        return messaging.EvaluationReply(
+            loss=float(metrics["loss"]),
+            metrics=states,
+            n_steps=int(effort["n_steps"]),
+            t_spent=round(effort["t_spent"], 3),
+        )
+
+    def evaluate_under_constraints(
         self,
         batch_cfg: Dict[str, Any],
         n_steps: Optional[int] = None,
         timeout: Optional[int] = None,
-    ) -> messaging.EvaluationReply:
-        """Backend code to run local loss computation under effort constraints.
+    ) -> Tuple[
+        Dict[str, Union[float, np.ndarray]],
+        Dict[str, Dict[str, Union[float, np.ndarray]]],
+        Dict[str, float],
+    ]:
+        """Run local loss computation under effort constraints.
+
+        This is the core backend to the `evaluation_round` method,
+        which further handles message parsing and passing, as well
+        as exception catching.
 
         Parameters
         ----------
@@ -320,10 +347,21 @@ class TrainingManager:
 
         Returns
         -------
-        reply: messaging.EvaluationReply
-            EvaluationReply message wrapping the computed loss on the
-            local validation (or, if absent, training) dataset as well
-            as the number of steps and the time taken to obtain it.
+        metrics:
+            Computed metrics, as a dict with float or array values.
+        states:
+            Computed metrics, as partial values that may be shared
+            with other agents to federatively compute final values
+            with the same specs as `metrics`.
+        effort:
+            Dictionary storing information on the computational
+            effort effectively performed:
+            * n_epoch: int
+                Number of evaluation epochs completed.
+            * n_steps: int
+                Number of evaluation steps completed.
+            * t_spent: float
+                Time spent running evaluation steps (in seconds).
         """
         # Set up effort constraints under which to operate.
         constraints = ConstraintSet(
@@ -342,18 +380,12 @@ class TrainingManager:
                 break
         # Gather the computed metrics and computational effort information.
        effort = constraints.get_values()
-        result = self.metrics.get_result()
+        values = self.metrics.get_result()
         states = self.metrics.get_states()
         self.logger.log(
             LOGGING_LEVEL_MAJOR,
             "Local scalar evaluation metrics: %s",
-            {k: v for k, v in result.items() if isinstance(v, float)},
-        )
-        # Pack the result and computational effort information into a message.
-        self.logger.info("Packing local results to be sent to the server.")
-        return messaging.EvaluationReply(
-            loss=float(result["loss"]),
-            metrics=states,
-            n_steps=int(effort["n_steps"]),
-            t_spent=round(effort["t_spent"], 3),
+            {k: v for k, v in values.items() if isinstance(v, float)},
         )
+        # Return the metrics' values, their states and the effort information.
+        return values, states, effort
--
GitLab