diff --git a/declearn/main/_client.py b/declearn/main/_client.py
index b712d72150794aad9c0daa6e61a83f037a8a4f21..017eeaddef7d8674821bf6c9df0bf1f27390a12a 100644
--- a/declearn/main/_client.py
+++ b/declearn/main/_client.py
@@ -20,8 +20,11 @@
 import asyncio
 import dataclasses
 import logging
+import os
 from typing import Any, Dict, Optional, Union

+import numpy as np
+
 from declearn.communication import NetworkClientConfig, messaging
 from declearn.communication.api import NetworkClient
 from declearn.dataset import Dataset, load_dataset_from_json
@@ -356,6 +359,16 @@ class FederatedClient:
         assert self.trainmanager is not None
         # Run the training round.
         reply = self.trainmanager.training_round(message)
+        # Collect and optionally record batch-wise training losses.
+        # Note: collection enables purging them from memory.
+        losses = self.trainmanager.model.collect_training_losses()
+        if self.ckptr is not None:
+            self.ckptr.save_metrics(
+                metrics={"training_losses": np.array(losses)},
+                prefix="training_losses",
+                append=True,
+                timestamp=f"round_{message.round_i}",
+            )
         # Send training results (or error message) to the server.
         await self.netwk.send_message(reply)

@@ -412,7 +425,7 @@ class FederatedClient:
             message.loss,
         )
         if self.ckptr:
-            path = f"{self.ckptr.folder}/model_state_best.json"
+            path = os.path.join(self.ckptr.folder, "model_state_best.json")
             self.logger.info("Checkpointing final weights under %s.", path)
             assert self.trainmanager is not None  # for mypy
             self.trainmanager.model.set_weights(message.weights)
diff --git a/declearn/main/utils/_checkpoint.py b/declearn/main/utils/_checkpoint.py
index e6677accf0ea31cc28e06e45a53ea9b39d2b36c2..bdd4c462475b2d05659fc54820e98b388bf8e1fd 100644
--- a/declearn/main/utils/_checkpoint.py
+++ b/declearn/main/utils/_checkpoint.py
@@ -23,7 +23,7 @@ from datetime import datetime
 from typing import Any, Dict, List, Optional, Union

 import numpy as np
-import pandas as pd  # type: ignore
+import pandas as pd
 from typing_extensions import Self  # future: import from typing (py >=3.11)

 from declearn.model.api import Model
@@ -319,16 +319,17 @@ class Checkpointer:
         }
         # Filter out scalar metrics and write them to a csv file.
         scalars = {k: v for k, v in scores.items() if isinstance(v, float)}
-        fpath = os.path.join(self.folder, f"{prefix}.csv")
-        pd.DataFrame(scalars, index=[timestamp]).to_csv(
-            fpath,
-            sep=",",
-            mode=("a" if append else "w"),
-            header=not (append and os.path.isfile(fpath)),
-            index=True,
-            index_label="timestamp",
-            encoding="utf-8",
-        )
+        if scalars:
+            fpath = os.path.join(self.folder, f"{prefix}.csv")
+            pd.DataFrame(scalars, index=[timestamp]).to_csv(
+                fpath,
+                sep=",",
+                mode=("a" if append else "w"),
+                header=not (append and os.path.isfile(fpath)),
+                index=True,
+                index_label="timestamp",
+                encoding="utf-8",
+            )
         # Write the full set of metrics to a JSON file.
         jdump = json.dumps({timestamp: scores})[1:-1]  # bracket-less dict
         fpath = os.path.join(self.folder, f"{prefix}.json")
diff --git a/declearn/main/utils/_training.py b/declearn/main/utils/_training.py
index 901d953868e64b48a84d30a945c59d680862edb8..a82e7114f1371fa29f8f6d4053e6487c578888a8 100644
--- a/declearn/main/utils/_training.py
+++ b/declearn/main/utils/_training.py
@@ -18,7 +18,7 @@
 """Wrapper to run local training and evaluation rounds in a FL process."""

 import logging
-from typing import Any, ClassVar, Dict, List, Optional, Union
+from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union

 import numpy as np

@@ -169,7 +169,7 @@ class TrainingManager:
             "Training local model for %s epochs | %s steps | %s seconds.",
             *params,
         )
-        effort = self._train_under_constraints(message.batches, *params)
+        effort = self.train_under_constraints(message.batches, *params)
         # Compute model updates and collect auxiliary variables.
         self.logger.info("Packing local updates to be sent to the server.")
         return messaging.TrainReply(
@@ -180,21 +180,25 @@ class TrainingManager:
             t_spent=round(effort["t_spent"], 3),
         )

-    def _train_under_constraints(
+    def train_under_constraints(
         self,
         batch_cfg: Dict[str, Any],
-        n_epoch: Optional[int],
-        n_steps: Optional[int],
-        timeout: Optional[int],
+        n_epoch: Optional[int] = 1,
+        n_steps: Optional[int] = None,
+        timeout: Optional[int] = None,
     ) -> Dict[str, float]:
-        """Backend code to run local SGD steps under effort constraints.
+        """Run local SGD steps under effort constraints.
+
+        This is the core backend to the `training_round` method,
+        which further handles message parsing and passing, as well
+        as exception catching.

         Parameters
         ----------
         batch_cfg: Dict[str, Any]
             Keyword arguments for `self.train_data.generate_batches`
             i.e. specifications of batches used in local SGD steps.
-        n_epoch: int or None, default=None
+        n_epoch: int or None, default=1
             Maximum number of local training epochs to perform.
             May be overridden by `n_steps` or `timeout`.
         n_steps: int or None, default=None
@@ -286,12 +290,7 @@ class TrainingManager:
         )
         # Try running the evaluation round.
         try:
-            # Update the model's weights and evaluate on the local dataset.
-            # Revise: make the weights' update optional.
-            self.model.set_weights(message.weights, trainable=True)
-            return self._evaluate_under_constraints(
-                message.batches, message.n_steps, message.timeout
-            )
+            return self._evaluation_round(message)
         # In case of failure, wrap the exception as an Error message.
         except Exception as exception:  # pylint: disable=broad-except
             self.logger.error(
@@ -299,13 +298,41 @@ class TrainingManager:
             )
             return messaging.Error(repr(exception))

-    def _evaluate_under_constraints(
+    def _evaluation_round(
+        self,
+        message: messaging.EvaluationRequest,
+    ) -> messaging.EvaluationReply:
+        """Backend to `evaluation_round`, without exception capture hooks."""
+        # Update the model's weights and evaluate on the local dataset.
+        # Revise: make the weights' update optional.
+        self.model.set_weights(message.weights, trainable=True)
+        metrics, states, effort = self.evaluate_under_constraints(
+            message.batches, message.n_steps, message.timeout
+        )
+        # Pack the resulting information into a message.
+        self.logger.info("Packing local results to be sent to the server.")
+        return messaging.EvaluationReply(
+            loss=float(metrics["loss"]),
+            metrics=states,
+            n_steps=int(effort["n_steps"]),
+            t_spent=round(effort["t_spent"], 3),
+        )
+
+    def evaluate_under_constraints(
         self,
         batch_cfg: Dict[str, Any],
         n_steps: Optional[int] = None,
         timeout: Optional[int] = None,
-    ) -> messaging.EvaluationReply:
-        """Backend code to run local loss computation under effort constraints.
+    ) -> Tuple[
+        Dict[str, Union[float, np.ndarray]],
+        Dict[str, Dict[str, Union[float, np.ndarray]]],
+        Dict[str, float],
+    ]:
+        """Run local loss computation under effort constraints.
+
+        This is the core backend to the `evaluation_round` method,
+        which further handles message parsing and passing, as well
+        as exception catching.

         Parameters
         ----------
@@ -320,10 +347,21 @@ class TrainingManager:

         Returns
         -------
-        reply: messaging.EvaluationReply
-            EvaluationReply message wrapping the computed loss on the
-            local validation (or, if absent, training) dataset as well
-            as the number of steps and the time taken to obtain it.
+        metrics:
+            Computed metrics, as a dict with float or array values.
+        states:
+            Computed metrics, as partial values that may be shared
+            with other agents to federatively compute final values
+            with the same specs as `metrics`.
+        effort:
+            Dictionary storing information on the computational
+            effort effectively performed:
+            * n_epoch: int
+                Number of evaluation epochs completed.
+            * n_steps: int
+                Number of evaluation steps completed.
+            * t_spent: float
+                Time spent running evaluation steps (in seconds).
         """
         # Set up effort constraints under which to operate.
         constraints = ConstraintSet(
@@ -342,18 +380,12 @@ class TrainingManager:
                 break
         # Gather the computed metrics and computational effort information.
         effort = constraints.get_values()
-        result = self.metrics.get_result()
+        values = self.metrics.get_result()
         states = self.metrics.get_states()
         self.logger.log(
             LOGGING_LEVEL_MAJOR,
             "Local scalar evaluation metrics: %s",
-            {k: v for k, v in result.items() if isinstance(v, float)},
-        )
-        # Pack the result and computational effort information into a message.
-        self.logger.info("Packing local results to be sent to the server.")
-        return messaging.EvaluationReply(
-            loss=float(result["loss"]),
-            metrics=states,
-            n_steps=int(effort["n_steps"]),
-            t_spent=round(effort["t_spent"], 3),
+            {k: v for k, v in values.items() if isinstance(v, float)},
         )
+        # Return the metrics' values, their states and the effort information.
+        return values, states, effort
diff --git a/declearn/model/api/_model.py b/declearn/model/api/_model.py
index e74102bbdfdc1c174f633c08eaf6ea23b62e949d..f012dfee9c9c44979a029c073cf34b0d157ba6d4 100644
--- a/declearn/model/api/_model.py
+++ b/declearn/model/api/_model.py
@@ -18,7 +18,7 @@
 """Model abstraction API."""

 from abc import ABCMeta, abstractmethod
-from typing import Any, Dict, Generic, Optional, Set, Tuple, TypeVar
+from typing import Any, Dict, Generic, List, Optional, Set, Tuple, TypeVar

 import numpy as np
 from typing_extensions import Self  # future: import from typing (py >=3.11)
@@ -62,6 +62,8 @@ class Model(Generic[VectorT], metaclass=ABCMeta):
     ) -> None:
         """Instantiate a Model interface wrapping a 'model' object."""
         self._model = model
+        # Declare a private list where to record batch-wise training losses.
+        self._loss_history = []  # type: List[float]

     def get_wrapped_model(self) -> Any:
         """Getter to access the wrapped framework-specific model object.
@@ -202,6 +204,10 @@ class Model(Generic[VectorT], metaclass=ABCMeta):
         to its trainable parameters for the given data batch.
         Optionally clip sample-wise gradients before batch-averaging.

+        Record the loss value over the batch, which may be collected
+        (and thereby purged from the internal memory) by calling the
+        `collect_training_losses` method.
+
         Parameters
         ----------
         batch: declearn.typing.Batch
@@ -227,6 +233,25 @@ class Model(Generic[VectorT], metaclass=ABCMeta):
     ) -> None:
         """Apply updates to the model's weights."""

+    def collect_training_losses(
+        self,
+    ) -> List[float]:
+        """Collect batch-wise training losses accumulated over time.
+
+        Return all recorded batch-averaged loss values computed as
+        part of `compute_batch_gradients` calls, and clear them
+        from memory, so that next time this method is called, only
+        new values are returned.
+
+        Returns
+        -------
+        losses:
+            List of batch-averaged loss values computed over inputs
+            to the `compute_batch_gradients` method.
+        """
+        losses, self._loss_history = self._loss_history, []
+        return losses
+
     @abstractmethod
     def compute_batch_predictions(
         self,
diff --git a/declearn/model/haiku/_model.py b/declearn/model/haiku/_model.py
index 1f81dd8f475f722846177637b04f3a1246515f89..960b9f6c5ea207b244a56481bb91349e436b15ef 100644
--- a/declearn/model/haiku/_model.py
+++ b/declearn/model/haiku/_model.py
@@ -385,23 +385,29 @@ class HaikuModel(Model):
         rng = next(self._rng_gen)
         # Compute batch-averaged gradients, opt. clipped on a per-sample basis.
         if max_norm:
-            grads = self._clipped_grad_fn(
+            grads, loss = self._clipped_grads_and_loss_fn(
                 train_params, fixed_params, rng, inputs, max_norm
             )
             grads = [value.mean(0) for value in grads]
         else:
-            grads = jax.tree_util.tree_leaves(
-                self._grad_fn(train_params, fixed_params, rng, inputs)
+            loss, grads_tree = self._loss_and_grads_fn(
+                train_params, fixed_params, rng, inputs
             )
+            grads = jax.tree_util.tree_leaves(grads_tree)
+        # Record the batch-averaged loss value.
+        self._loss_history.append(float(np.array(loss).mean()))
         # Return the gradients, flattened into a JaxNumpyVector container.
         return JaxNumpyVector(dict(zip(self._trainable, grads)))

     @functools.cached_property
-    def _grad_fn(
+    def _loss_and_grads_fn(
         self,
-    ) -> Callable[[hk.Params, hk.Params, jax.Array, JaxBatch], hk.Params]:
+    ) -> Callable[
+        [hk.Params, hk.Params, jax.Array, JaxBatch],
+        Tuple[jax.Array, hk.Params],
+    ]:
         """Lazy-built jax function to compute batch-averaged gradients."""
-        return jax.jit(jax.grad(self._forward))
+        return jax.jit(jax.value_and_grad(self._forward))

     def _forward(
         self,
@@ -436,10 +442,11 @@ class HaikuModel(Model):
         return jnp.mean(s_loss)

     @functools.cached_property
-    def _clipped_grad_fn(
+    def _clipped_grads_and_loss_fn(
         self,
     ) -> Callable[
-        [hk.Params, hk.Params, jax.Array, JaxBatch, float], List[jax.Array]
+        [hk.Params, hk.Params, jax.Array, JaxBatch, float],
+        Tuple[List[jax.Array], jax.Array],
     ]:
         """Lazy-built jax function to compute clipped sample-wise gradients.

@@ -447,17 +454,17 @@ class HaikuModel(Model):
         applying optional parameters to pytrees.
""" - def clipped_grad_fn( + def clipped_grads_and_loss_fn( train_params: hk.Params, fixed_params: hk.Params, rng: jax.Array, batch: JaxBatch, max_norm: float, - ) -> List[jax.Array]: + ) -> Tuple[List[jax.Array], jax.Array]: """Compute and clip gradients wrt parameters for a sample.""" inputs, y_true, s_wght = batch batch = (inputs, y_true, None) - grads = jax.grad(self._forward)( + loss, grads = jax.value_and_grad(self._forward)( train_params, fixed_params, rng, batch ) grads_flat = [ @@ -466,10 +473,10 @@ class HaikuModel(Model): ] if s_wght is not None: grads_flat = [g * s_wght for g in grads_flat] - return grads_flat + return grads_flat, loss in_axes = [None, None, None, 0, None] # map on inputs' first dimension - return jax.jit(jax.vmap(clipped_grad_fn, in_axes)) + return jax.jit(jax.vmap(clipped_grads_and_loss_fn, in_axes)) def _unpack_batch( self, diff --git a/declearn/model/sklearn/_sgd.py b/declearn/model/sklearn/_sgd.py index 3862d0af6dc21b9c0210270e9004f6c18795a0ce..160a6d3970295c4b2e6c6fd66653b8b8f945b07d 100644 --- a/declearn/model/sklearn/_sgd.py +++ b/declearn/model/sklearn/_sgd.py @@ -397,6 +397,11 @@ class SklearnSGDModel(Model): # Optionally re-weight gradients based on sample weights. if s_wght is not None: grad = [g * w for g, w in zip(grad, s_wght)] + # Compute and record the loss value on the entire batch. + loss = self.loss_function( + y_data, self._predict(x_data) # type: ignore + ) + self._loss_history.append(float(loss.mean())) # Batch-average the gradients and return them. return sum(grad) / len(grad) # type: ignore diff --git a/declearn/model/tensorflow/_model.py b/declearn/model/tensorflow/_model.py index 59260eaededbf9182178d82ec282c68964867e27..812fad761cae780e58386253bd464697de6c69dc 100644 --- a/declearn/model/tensorflow/_model.py +++ b/declearn/model/tensorflow/_model.py @@ -237,10 +237,11 @@ class TensorflowModel(Model): with tf.device(self._device): data = self._unpack_batch(batch) if max_norm is None: - grads = self._compute_batch_gradients(*data) + grads, loss = self._compute_batch_gradients(*data) else: norm = tf.constant(max_norm) - grads = self._compute_clipped_gradients(*data, norm) + grads, loss = self._compute_clipped_gradients(*data, norm) + self._loss_history.append(float(loss.numpy())) grads_and_vars = zip(grads, self._model.trainable_weights) return TensorflowVector( {var.name: grad for grad, var in grads_and_vars} @@ -267,14 +268,14 @@ class TensorflowModel(Model): inputs: tf.Tensor, y_true: Optional[tf.Tensor], s_wght: Optional[tf.Tensor], - ) -> List[tf.Tensor]: + ) -> Tuple[List[tf.Tensor], tf.Tensor]: """Compute and return batch-averaged gradients of trainable weights.""" with tf.GradientTape() as tape: y_pred = self._model(inputs, training=True) loss = self._model.compute_loss(inputs, y_true, y_pred, s_wght) loss = tf.reduce_mean(loss) grad = tape.gradient(loss, self._model.trainable_weights) - return grad + return grad, loss @tf.function # optimize tensorflow runtime def _compute_clipped_gradients( @@ -283,26 +284,26 @@ class TensorflowModel(Model): y_true: Optional[tf.Tensor], s_wght: Optional[tf.Tensor], max_norm: Union[tf.Tensor, float], - ) -> List[tf.Tensor]: + ) -> Tuple[List[tf.Tensor], tf.Tensor]: """Compute and return sample-wise-clipped batch-averaged gradients.""" - grad = self._compute_samplewise_gradients(inputs, y_true) + grad, loss = self._compute_samplewise_gradients(inputs, y_true) if s_wght is None: s_wght = tf.cast(1, grad[0].dtype) grad = self._clip_and_average_gradients(grad, max_norm, s_wght) - return grad + 
+        return grad, loss

     @tf.function  # optimize tensorflow runtime
     def _compute_samplewise_gradients(
         self,
         inputs: tf.Tensor,
         y_true: Optional[tf.Tensor],
-    ) -> List[tf.Tensor]:
+    ) -> Tuple[List[tf.Tensor], tf.Tensor]:
         """Compute and return sample-wise gradients for a given batch."""
         with tf.GradientTape() as tape:
             y_pred = self._model(inputs, training=True)
             loss = self._model.compute_loss(inputs, y_true, y_pred)
         grad = tape.jacobian(loss, self._model.trainable_weights)
-        return grad
+        return grad, tf.reduce_mean(loss)

     @staticmethod
     @tf.function  # optimize tensorflow runtime
diff --git a/declearn/model/torch/_model.py b/declearn/model/torch/_model.py
index 9451846d2bb9474cde12aaf7636ce89a145578e3..3f79113d72d36fec04c40383bb9d3de0c10b92e3 100644
--- a/declearn/model/torch/_model.py
+++ b/declearn/model/torch/_model.py
@@ -253,6 +253,7 @@ class TorchModel(Model):
         y_pred = self._model(*inputs)
         loss = self._compute_loss(y_pred, y_true, s_wght)
         loss.backward()
+        self._loss_history.append(float(loss.detach().cpu().numpy().mean()))
         # Collect weights' gradients and return them in a Vector container.
         grads = {
             k: p.grad.detach().clone()
@@ -320,7 +321,10 @@ class TorchModel(Model):
             s_wght=(s_wght is not None),
         )
         with torch.no_grad():
-            grads = grads_fn(inputs, y_true, s_wght, clip=clip)  # type: ignore
+            grads, loss = grads_fn(
+                inputs, y_true, s_wght, clip=clip
+            )  # type: ignore
+            self._loss_history.append(float(loss.cpu().numpy().mean()))
         return TorchVector(grads)

     @functools.lru_cache
diff --git a/declearn/model/torch/_samplewise/__init__.py b/declearn/model/torch/_samplewise/__init__.py
index 1448a18bb210279d99343f8637cc8d7587cf12a5..3060e34c173ae073afe54d8d92dff58b740dd267 100644
--- a/declearn/model/torch/_samplewise/__init__.py
+++ b/declearn/model/torch/_samplewise/__init__.py
@@ -62,10 +62,12 @@ def build_samplewise_grads_fn(

     Returns
     -------
-    grads_fn: callable[[inputs, y_true, s_wght, clip], grads]
+    grads_fn: callable[[inputs, y_true, s_wght, clip], (grads, loss)]
         Function that efficiently computes and returns sample-wise gradients
         wrt trainable model parameters based on a batch of inputs, with opt.
         clipping based on a maximum l2-norm value `clip`.
+        It returns the sample-wise gradients as a dict of tensors with their
+        parameter name as key, plus the sample-wise loss values as a tensor.

     Note
     ----
diff --git a/declearn/model/torch/_samplewise/functorch.py b/declearn/model/torch/_samplewise/functorch.py
index 37ea5ecbe07a716727bada27e05abac153e1bbb6..67c50619aa1a8a57016ae1bba445e57e3937533c 100644
--- a/declearn/model/torch/_samplewise/functorch.py
+++ b/declearn/model/torch/_samplewise/functorch.py
@@ -55,13 +55,13 @@ def build_samplewise_grads_fn_backend(
         """Compute gradients and optionally clip them."""
         params, idxgrd, pnames = get_params(model)
         buffers = list(model.buffers())
-        gfunc = functorch.grad(run_forward, argnums=tuple(idxgrd))
-        grads = gfunc(
+        gfunc = functorch.grad_and_value(run_forward, argnums=tuple(idxgrd))
+        grads, loss = gfunc(
             inputs, y_true, (None if clip else s_wght), buffers, *params
         )
         if clip:
             clip_and_scale_grads_inplace(grads, clip, s_wght)
-        return dict(zip(pnames, grads))
+        return dict(zip(pnames, grads)), loss.detach()

     # Wrap the former function to compute and clip sample-wise gradients.
     in_dims = ([0] * inputs, 0 if y_true else None, 0 if s_wght else None)
diff --git a/declearn/model/torch/_samplewise/shared.py b/declearn/model/torch/_samplewise/shared.py
index 451ae7c11998c3d337bffe988df2f6e131f5e539..009fa040e0189e8b1c39354ef33af1224aa0e6ef 100644
--- a/declearn/model/torch/_samplewise/shared.py
+++ b/declearn/model/torch/_samplewise/shared.py
@@ -17,7 +17,7 @@

 """Shared code for torch-version-dependent backend code."""

-from typing import Callable, Dict, Iterable, List, Optional
+from typing import Callable, Dict, Iterable, List, Optional, Tuple

 import torch

@@ -34,7 +34,7 @@ GetGradientsFunction = Callable[
         Optional[torch.Tensor],
         Optional[float],
     ],
-    Dict[str, torch.Tensor],
+    Tuple[Dict[str, torch.Tensor], torch.Tensor],
 ]
 """Signature for sample-wise gradients computation functions."""

diff --git a/declearn/model/torch/_samplewise/torchfunc.py b/declearn/model/torch/_samplewise/torchfunc.py
index 88aa5b7dca01dcdd77adda76b29d6b080c7487d2..14e9989f415c90465067513f02a16e975218aefc 100644
--- a/declearn/model/torch/_samplewise/torchfunc.py
+++ b/declearn/model/torch/_samplewise/torchfunc.py
@@ -51,22 +51,24 @@ def build_samplewise_grads_fn_backend(
             s_loss.mul_(s_wght.to(s_loss.device))
         return s_loss.mean()

-    get_grads = torch.func.grad(run_forward, argnums=0)
+    get_grads_and_loss = torch.func.grad_and_value(run_forward, argnums=0)

-    def get_clipped_grads(inputs, y_true, s_wght, clip=None):
+    def get_clipped_grads_and_loss(inputs, y_true, s_wght, clip=None):
         """Compute gradients and optionally clip them."""
         params, frozen = get_params(model)
         buffers = dict(model.named_buffers())
-        grads = get_grads(
+        grads, loss = get_grads_and_loss(
             params, frozen, buffers, inputs, y_true, None if clip else s_wght
         )
         if clip:
             clip_and_scale_grads_inplace(grads.values(), clip, s_wght)
-        return grads
+        return grads, loss.detach()

     # Wrap the former function to compute and clip sample-wise gradients.
     in_dims = ([0] * inputs, 0 if y_true else None, 0 if s_wght else None)
-    return torch.func.vmap(get_clipped_grads, in_dims, randomness="same")
+    return torch.func.vmap(
+        get_clipped_grads_and_loss, in_dims, randomness="same"
+    )


 def get_params(
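
Reviewer note (appended outside the patch): the short sketch below illustrates how the loss-collection API introduced above is meant to be used. It is a minimal, hypothetical example with a torch backend: the wrapped module, loss and batches are made up, and only `compute_batch_gradients` and `collect_training_losses` come from the declearn `Model` API touched by this diff; in an actual FL run, these calls happen inside `TrainingManager` and `FederatedClient`, which checkpoints the collected losses as shown in `_client.py`.

import torch

from declearn.model.torch import TorchModel

# Hypothetical toy model wrapped with declearn's torch interface.
model = TorchModel(torch.nn.Linear(8, 1), loss=torch.nn.MSELoss())

# Each gradients computation now also records one batch-averaged loss value.
for _ in range(4):
    batch = (torch.randn(16, 8), torch.randn(16, 1), None)  # inputs, labels, weights
    grads = model.compute_batch_gradients(batch)
    # (in a real setup, `grads` would be refined and applied via a declearn Optimizer)

# Collect the recorded losses; this also purges them from the model's memory.
losses = model.collect_training_losses()
assert len(losses) == 4
assert model.collect_training_losses() == []  # history was cleared by the previous call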