From 23726d573eeb503f1881f0eafccea627684ccd4a Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Thu, 21 Sep 2023 16:49:33 +0200
Subject: [PATCH] Add 'TrainingManager.(train|evaluate)_under_constraints' to
 the API.

- Until now, the public training and evaluation methods of the
  `declearn.main.utils.TrainingManager` class (and its DP-SGD
  counterpart) required message-wrapped inputs and returned
  message-wrapped outputs.
- With this commit, the previously private `train_under_constraints`
  and `evaluate_under_constraints` routines are made public, and thus
  part of the declearn API, with some minor refactoring to make them
  more user-friendly.
- The rationale for this change is to enable using `TrainingManager`
  outside of our `FederatedClient`/`FederatedServer` orchestration,
  notably when simulating FL training or testing client-side code.
  It may also help end-users who would like to build on declearn
  but implement their own orchestration tools or algorithm loops
  (see the usage sketch below).
- In the future (declearn >=3.0), we may go one step further and
  take all the messaging-related instructions out of the current
  class. The class may also be moved to a different namespace, e.g.
  a new 'declearn.train' module; but this is entirely out of scope
  for now.
---
 declearn/main/utils/_training.py | 94 +++++++++++++++++++++-----------
 1 file changed, 63 insertions(+), 31 deletions(-)

diff --git a/declearn/main/utils/_training.py b/declearn/main/utils/_training.py
index 901d9538..a82e7114 100644
--- a/declearn/main/utils/_training.py
+++ b/declearn/main/utils/_training.py
@@ -18,7 +18,7 @@
 """Wrapper to run local training and evaluation rounds in a FL process."""
 
 import logging
-from typing import Any, ClassVar, Dict, List, Optional, Union
+from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 
@@ -169,7 +169,7 @@ class TrainingManager:
             "Training local model for %s epochs | %s steps | %s seconds.",
             *params,
         )
-        effort = self._train_under_constraints(message.batches, *params)
+        effort = self.train_under_constraints(message.batches, *params)
         # Compute model updates and collect auxiliary variables.
         self.logger.info("Packing local updates to be sent to the server.")
         return messaging.TrainReply(
@@ -180,21 +180,25 @@ class TrainingManager:
             t_spent=round(effort["t_spent"], 3),
         )
 
-    def _train_under_constraints(
+    def train_under_constraints(
         self,
         batch_cfg: Dict[str, Any],
-        n_epoch: Optional[int],
-        n_steps: Optional[int],
-        timeout: Optional[int],
+        n_epoch: Optional[int] = 1,
+        n_steps: Optional[int] = None,
+        timeout: Optional[int] = None,
     ) -> Dict[str, float]:
-        """Backend code to run local SGD steps under effort constraints.
+        """Run local SGD steps under effort constraints.
+
+        This method is the core backend of `training_round`, which
+        additionally handles message parsing and passing, as well
+        as exception catching.
 
         Parameters
         ----------
         batch_cfg: Dict[str, Any]
             Keyword arguments for `self.train_data.generate_batches`
             i.e. specifications of batches used in local SGD steps.
-        n_epoch: int or None, default=None
+        n_epoch: int or None, default=1
             Maximum number of local training epochs to perform.
             May be overridden by `n_steps` or `timeout`.
         n_steps: int or None, default=None
@@ -286,12 +290,7 @@ class TrainingManager:
         )
         # Try running the evaluation round.
         try:
-            # Update the model's weights and evaluate on the local dataset.
-            # Revise: make the weights' update optional.
-            self.model.set_weights(message.weights, trainable=True)
-            return self._evaluate_under_constraints(
-                message.batches, message.n_steps, message.timeout
-            )
+            return self._evaluation_round(message)
         # In case of failure, wrap the exception as an Error message.
         except Exception as exception:  # pylint: disable=broad-except
             self.logger.error(
@@ -299,13 +298,41 @@ class TrainingManager:
             )
             return messaging.Error(repr(exception))
 
-    def _evaluate_under_constraints(
+    def _evaluation_round(
+        self,
+        message: messaging.EvaluationRequest,
+    ) -> messaging.EvaluationReply:
+        """Backend to `evaluation_round`, without exception capture hooks."""
+        # Update the model's weights and evaluate on the local dataset.
+        # Revise: make the weights' update optional.
+        self.model.set_weights(message.weights, trainable=True)
+        metrics, states, effort = self.evaluate_under_constraints(
+            message.batches, message.n_steps, message.timeout
+        )
+        # Pack the resulting information into a message.
+        self.logger.info("Packing local results to be sent to the server.")
+        return messaging.EvaluationReply(
+            loss=float(metrics["loss"]),
+            metrics=states,
+            n_steps=int(effort["n_steps"]),
+            t_spent=round(effort["t_spent"], 3),
+        )
+
+    def evaluate_under_constraints(
         self,
         batch_cfg: Dict[str, Any],
         n_steps: Optional[int] = None,
         timeout: Optional[int] = None,
-    ) -> messaging.EvaluationReply:
-        """Backend code to run local loss computation under effort constraints.
+    ) -> Tuple[
+        Dict[str, Union[float, np.ndarray]],
+        Dict[str, Dict[str, Union[float, np.ndarray]]],
+        Dict[str, float],
+    ]:
+        """Run local loss computation under effort constraints.
+
+        This method is the core backend of `evaluation_round`, which
+        additionally handles message parsing and passing, as well
+        as exception catching.
 
         Parameters
         ----------
@@ -320,10 +347,21 @@ class TrainingManager:
 
         Returns
         -------
-        reply: messaging.EvaluationReply
-            EvaluationReply message wrapping the computed loss on the
-            local validation (or, if absent, training) dataset as well
-            as the number of steps and the time taken to obtain it.
+        metrics:
+            Computed metrics, as a dict with float or array values.
+        states:
+            Partial metric states that may be shared with other
+            agents to federatively compute final values with the
+            same specs as `metrics`.
+        effort:
+            Dictionary storing information on the computational
+            effort actually performed:
+            * n_epoch: int
+                Number of evaluation epochs completed.
+            * n_steps: int
+                Number of evaluation steps completed.
+            * t_spent: float
+                Time spent running evaluation steps (in seconds).
         """
         # Set up effort constraints under which to operate.
         constraints = ConstraintSet(
@@ -342,18 +380,12 @@ class TrainingManager:
                 break
         # Gather the computed metrics and computational effort information.
         effort = constraints.get_values()
-        result = self.metrics.get_result()
+        values = self.metrics.get_result()
         states = self.metrics.get_states()
         self.logger.log(
             LOGGING_LEVEL_MAJOR,
             "Local scalar evaluation metrics: %s",
-            {k: v for k, v in result.items() if isinstance(v, float)},
-        )
-        # Pack the result and computational effort information into a message.
-        self.logger.info("Packing local results to be sent to the server.")
-        return messaging.EvaluationReply(
-            loss=float(result["loss"]),
-            metrics=states,
-            n_steps=int(effort["n_steps"]),
-            t_spent=round(effort["t_spent"], 3),
+            {k: v for k, v in values.items() if isinstance(v, float)},
         )
+        # Return the metrics' values, their states and the effort information.
+        return values, states, effort
-- 
GitLab