diff --git a/declearn/main/_client.py b/declearn/main/_client.py
index b712d72150794aad9c0daa6e61a83f037a8a4f21..017eeaddef7d8674821bf6c9df0bf1f27390a12a 100644
--- a/declearn/main/_client.py
+++ b/declearn/main/_client.py
@@ -20,8 +20,11 @@
 import asyncio
 import dataclasses
 import logging
+import os
 from typing import Any, Dict, Optional, Union
 
+import numpy as np
+
 from declearn.communication import NetworkClientConfig, messaging
 from declearn.communication.api import NetworkClient
 from declearn.dataset import Dataset, load_dataset_from_json
@@ -356,6 +359,16 @@ class FederatedClient:
         assert self.trainmanager is not None
         # Run the training round.
         reply = self.trainmanager.training_round(message)
+        # Collect and optionally record batch-wise training losses.
+        # Note: collecting the losses also clears them from the model's memory.
+        losses = self.trainmanager.model.collect_training_losses()
+        if self.ckptr is not None:
+            self.ckptr.save_metrics(
+                metrics={"training_losses": np.array(losses)},
+                prefix="training_losses",
+                append=True,
+                timestamp=f"round_{message.round_i}",
+            )
         # Send training results (or error message) to the server.
         await self.netwk.send_message(reply)
 
@@ -412,7 +425,7 @@ class FederatedClient:
             message.loss,
         )
         if self.ckptr:
-            path = f"{self.ckptr.folder}/model_state_best.json"
+            path = os.path.join(self.ckptr.folder, "model_state_best.json")
             self.logger.info("Checkpointing final weights under %s.", path)
             assert self.trainmanager is not None  # for mypy
             self.trainmanager.model.set_weights(message.weights)
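
The change above follows a drain-and-save pattern: after each training round, the batch-wise losses recorded by the model are collected (which empties its buffer) and checkpointed under a round-indexed timestamp. A minimal sketch of the same pattern outside of `FederatedClient`, assuming `model` is any declearn `Model` and `ckptr` an optional `Checkpointer`:

```python
import numpy as np

def record_round_losses(model, ckptr, round_i: int) -> None:
    """Drain batch-wise training losses and optionally checkpoint them."""
    losses = model.collect_training_losses()  # returns and clears the buffer
    if ckptr is not None and losses:
        ckptr.save_metrics(
            metrics={"training_losses": np.array(losses)},
            prefix="training_losses",
            append=True,
            timestamp=f"round_{round_i}",
        )
```
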
diff --git a/declearn/main/utils/_checkpoint.py b/declearn/main/utils/_checkpoint.py
index e6677accf0ea31cc28e06e45a53ea9b39d2b36c2..bdd4c462475b2d05659fc54820e98b388bf8e1fd 100644
--- a/declearn/main/utils/_checkpoint.py
+++ b/declearn/main/utils/_checkpoint.py
@@ -23,7 +23,7 @@ from datetime import datetime
 from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
-import pandas as pd  # type: ignore
+import pandas as pd
 from typing_extensions import Self  # future: import from typing (py >=3.11)
 
 from declearn.model.api import Model
@@ -319,16 +319,17 @@ class Checkpointer:
         }
         # Filter out scalar metrics and write them to a csv file.
         scalars = {k: v for k, v in scores.items() if isinstance(v, float)}
-        fpath = os.path.join(self.folder, f"{prefix}.csv")
-        pd.DataFrame(scalars, index=[timestamp]).to_csv(
-            fpath,
-            sep=",",
-            mode=("a" if append else "w"),
-            header=not (append and os.path.isfile(fpath)),
-            index=True,
-            index_label="timestamp",
-            encoding="utf-8",
-        )
+        if scalars:
+            fpath = os.path.join(self.folder, f"{prefix}.csv")
+            pd.DataFrame(scalars, index=[timestamp]).to_csv(
+                fpath,
+                sep=",",
+                mode=("a" if append else "w"),
+                header=not (append and os.path.isfile(fpath)),
+                index=True,
+                index_label="timestamp",
+                encoding="utf-8",
+            )
         # Write the full set of metrics to a JSON file.
         jdump = json.dumps({timestamp: scores})[1:-1]  # bracket-less dict
         fpath = os.path.join(self.folder, f"{prefix}.json")
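
With the `if scalars:` guard above, a metrics dict that contains no float values (such as the array of batch-wise training losses saved by the client) no longer triggers a CSV write at all; only the JSON record is produced. A small self-contained illustration of the filtering rule kept from the original code:

```python
import numpy as np

# Same filter as in Checkpointer.save_metrics: only float values are
# eligible for the CSV output; array-valued metrics are skipped there.
scores = {"loss": 0.42, "training_losses": np.array([0.9, 0.7, 0.5])}
scalars = {k: v for k, v in scores.items() if isinstance(v, float)}
print(scalars)  # {'loss': 0.42} -> only this part would reach the CSV
```
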
diff --git a/declearn/main/utils/_training.py b/declearn/main/utils/_training.py
index 901d953868e64b48a84d30a945c59d680862edb8..a82e7114f1371fa29f8f6d4053e6487c578888a8 100644
--- a/declearn/main/utils/_training.py
+++ b/declearn/main/utils/_training.py
@@ -18,7 +18,7 @@
 """Wrapper to run local training and evaluation rounds in a FL process."""
 
 import logging
-from typing import Any, ClassVar, Dict, List, Optional, Union
+from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 
@@ -169,7 +169,7 @@ class TrainingManager:
             "Training local model for %s epochs | %s steps | %s seconds.",
             *params,
         )
-        effort = self._train_under_constraints(message.batches, *params)
+        effort = self.train_under_constraints(message.batches, *params)
         # Compute model updates and collect auxiliary variables.
         self.logger.info("Packing local updates to be sent to the server.")
         return messaging.TrainReply(
@@ -180,21 +180,25 @@ class TrainingManager:
             t_spent=round(effort["t_spent"], 3),
         )
 
-    def _train_under_constraints(
+    def train_under_constraints(
         self,
         batch_cfg: Dict[str, Any],
-        n_epoch: Optional[int],
-        n_steps: Optional[int],
-        timeout: Optional[int],
+        n_epoch: Optional[int] = 1,
+        n_steps: Optional[int] = None,
+        timeout: Optional[int] = None,
     ) -> Dict[str, float]:
-        """Backend code to run local SGD steps under effort constraints.
+        """Run local SGD steps under effort constraints.
+
+        This is the core backend of the `training_round` method,
+        which additionally handles message parsing and passing, as
+        well as exception catching.
 
         Parameters
         ----------
         batch_cfg: Dict[str, Any]
             Keyword arguments for `self.train_data.generate_batches`
             i.e. specifications of batches used in local SGD steps.
-        n_epoch: int or None, default=None
+        n_epoch: int or None, default=1
             Maximum number of local training epochs to perform.
             May be overridden by `n_steps` or `timeout`.
         n_steps: int or None, default=None
@@ -286,12 +290,7 @@ class TrainingManager:
         )
         # Try running the evaluation round.
         try:
-            # Update the model's weights and evaluate on the local dataset.
-            # Revise: make the weights' update optional.
-            self.model.set_weights(message.weights, trainable=True)
-            return self._evaluate_under_constraints(
-                message.batches, message.n_steps, message.timeout
-            )
+            return self._evaluation_round(message)
         # In case of failure, wrap the exception as an Error message.
         except Exception as exception:  # pylint: disable=broad-except
             self.logger.error(
@@ -299,13 +298,41 @@ class TrainingManager:
             )
             return messaging.Error(repr(exception))
 
-    def _evaluate_under_constraints(
+    def _evaluation_round(
+        self,
+        message: messaging.EvaluationRequest,
+    ) -> messaging.EvaluationReply:
+        """Backend to `evaluation_round`, without exception capture hooks."""
+        # Update the model's weights and evaluate on the local dataset.
+        # Revise: make the weights' update optional.
+        self.model.set_weights(message.weights, trainable=True)
+        metrics, states, effort = self.evaluate_under_constraints(
+            message.batches, message.n_steps, message.timeout
+        )
+        # Pack the resulting information into a message.
+        self.logger.info("Packing local results to be sent to the server.")
+        return messaging.EvaluationReply(
+            loss=float(metrics["loss"]),
+            metrics=states,
+            n_steps=int(effort["n_steps"]),
+            t_spent=round(effort["t_spent"], 3),
+        )
+
+    def evaluate_under_constraints(
         self,
         batch_cfg: Dict[str, Any],
         n_steps: Optional[int] = None,
         timeout: Optional[int] = None,
-    ) -> messaging.EvaluationReply:
-        """Backend code to run local loss computation under effort constraints.
+    ) -> Tuple[
+        Dict[str, Union[float, np.ndarray]],
+        Dict[str, Dict[str, Union[float, np.ndarray]]],
+        Dict[str, float],
+    ]:
+        """Run local loss computation under effort constraints.
+
+        This is the core backend of the `evaluation_round` method,
+        which additionally handles message parsing and passing, as
+        well as exception catching.
 
         Parameters
         ----------
@@ -320,10 +347,21 @@ class TrainingManager:
 
         Returns
         -------
-        reply: messaging.EvaluationReply
-            EvaluationReply message wrapping the computed loss on the
-            local validation (or, if absent, training) dataset as well
-            as the number of steps and the time taken to obtain it.
+        metrics:
+            Computed metrics, as a dict with float or array values.
+        states:
+            Partial metric states, which may be shared with other
+            agents so as to federatively compute final metric values
+            with the same specs as `metrics`.
+        effort:
+            Dictionary storing information on the computational
+            effort actually performed:
+            * n_epoch: int
+                Number of evaluation epochs completed.
+            * n_steps: int
+                Number of evaluation steps completed.
+            * t_spent: float
+                Time spent running evaluation steps (in seconds).
         """
         # Set up effort constraints under which to operate.
         constraints = ConstraintSet(
@@ -342,18 +380,12 @@ class TrainingManager:
                 break
         # Gather the computed metrics and computational effort information.
         effort = constraints.get_values()
-        result = self.metrics.get_result()
+        values = self.metrics.get_result()
         states = self.metrics.get_states()
         self.logger.log(
             LOGGING_LEVEL_MAJOR,
             "Local scalar evaluation metrics: %s",
-            {k: v for k, v in result.items() if isinstance(v, float)},
-        )
-        # Pack the result and computational effort information into a message.
-        self.logger.info("Packing local results to be sent to the server.")
-        return messaging.EvaluationReply(
-            loss=float(result["loss"]),
-            metrics=states,
-            n_steps=int(effort["n_steps"]),
-            t_spent=round(effort["t_spent"], 3),
+            {k: v for k, v in values.items() if isinstance(v, float)},
         )
+        # Return the metrics' values, their states and the effort information.
+        return values, states, effort
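
Making `train_under_constraints` and `evaluate_under_constraints` public means they can also be driven outside of a server-orchestrated round. A hedged sketch of a purely local evaluation, assuming `manager` is an already-built `TrainingManager` (imported here via its usual `declearn.main.utils` re-export):

```python
from typing import Any, Dict

from declearn.main.utils import TrainingManager

def local_evaluation(manager: TrainingManager, batch_cfg: Dict[str, Any]):
    """Evaluate the local model and report its scalar metrics (sketch)."""
    metrics, states, effort = manager.evaluate_under_constraints(
        batch_cfg, n_steps=None, timeout=None
    )
    scalars = {k: v for k, v in metrics.items() if isinstance(v, float)}
    print(f"ran {int(effort['n_steps'])} evaluation steps:", scalars)
    return metrics, states, effort
```
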
diff --git a/declearn/model/api/_model.py b/declearn/model/api/_model.py
index e74102bbdfdc1c174f633c08eaf6ea23b62e949d..f012dfee9c9c44979a029c073cf34b0d157ba6d4 100644
--- a/declearn/model/api/_model.py
+++ b/declearn/model/api/_model.py
@@ -18,7 +18,7 @@
 """Model abstraction API."""
 
 from abc import ABCMeta, abstractmethod
-from typing import Any, Dict, Generic, Optional, Set, Tuple, TypeVar
+from typing import Any, Dict, Generic, List, Optional, Set, Tuple, TypeVar
 
 import numpy as np
 from typing_extensions import Self  # future: import from typing (py >=3.11)
@@ -62,6 +62,8 @@ class Model(Generic[VectorT], metaclass=ABCMeta):
     ) -> None:
         """Instantiate a Model interface wrapping a 'model' object."""
         self._model = model
+        # Declare a private list in which to record batch-wise training losses.
+        self._loss_history = []  # type: List[float]
 
     def get_wrapped_model(self) -> Any:
         """Getter to access the wrapped framework-specific model object.
@@ -202,6 +204,10 @@ class Model(Generic[VectorT], metaclass=ABCMeta):
         to its trainable parameters for the given data batch.
         Optionally clip sample-wise gradients before batch-averaging.
 
+        Record the loss value over the batch, which may be collected
+        (and thereby cleared from the internal memory) by calling the
+        `collect_training_losses` method.
+
         Parameters
         ----------
         batch: declearn.typing.Batch
@@ -227,6 +233,25 @@ class Model(Generic[VectorT], metaclass=ABCMeta):
     ) -> None:
         """Apply updates to the model's weights."""
 
+    def collect_training_losses(
+        self,
+    ) -> List[float]:
+        """Collect batch-wise training losses accumulated over time.
+
+        Return all recorded batch-averaged loss values computed as
+        part of `compute_batch_gradients` calls, and clear them
+        from memory, so that the next time this method is called,
+        only new values are returned.
+
+        Returns
+        -------
+        losses:
+            List of batch-averaged loss values computed over inputs
+            to the `compute_batch_gradients` method.
+        """
+        losses, self._loss_history = self._loss_history, []
+        return losses
+
     @abstractmethod
     def compute_batch_predictions(
         self,
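
The `collect_training_losses` API is deliberately drain-style: each call returns only the losses recorded since the previous call and empties the internal buffer. A hedged sketch, assuming `model` is any concrete declearn `Model` and `batches` yields `declearn.typing.Batch` tuples:

```python
from typing import Iterable, List

from declearn.model.api import Model
from declearn.typing import Batch

def run_steps_and_collect(model: Model, batches: Iterable[Batch]) -> List[float]:
    """Run gradient computations, then drain the recorded batch losses."""
    for batch in batches:
        model.compute_batch_gradients(batch)  # records one loss value
    losses = model.collect_training_losses()  # returns and clears the buffer
    assert model.collect_training_losses() == []  # buffer is now empty
    return losses
```
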
diff --git a/declearn/model/haiku/_model.py b/declearn/model/haiku/_model.py
index 1f81dd8f475f722846177637b04f3a1246515f89..960b9f6c5ea207b244a56481bb91349e436b15ef 100644
--- a/declearn/model/haiku/_model.py
+++ b/declearn/model/haiku/_model.py
@@ -385,23 +385,29 @@ class HaikuModel(Model):
         rng = next(self._rng_gen)
         # Compute batch-averaged gradients, opt. clipped on a per-sample basis.
         if max_norm:
-            grads = self._clipped_grad_fn(
+            grads, loss = self._clipped_grads_and_loss_fn(
                 train_params, fixed_params, rng, inputs, max_norm
             )
             grads = [value.mean(0) for value in grads]
         else:
-            grads = jax.tree_util.tree_leaves(
-                self._grad_fn(train_params, fixed_params, rng, inputs)
+            loss, grads_tree = self._loss_and_grads_fn(
+                train_params, fixed_params, rng, inputs
             )
+            grads = jax.tree_util.tree_leaves(grads_tree)
+        # Record the batch-averaged loss value.
+        self._loss_history.append(float(np.array(loss).mean()))
         # Return the gradients, flattened into a JaxNumpyVector container.
         return JaxNumpyVector(dict(zip(self._trainable, grads)))
 
     @functools.cached_property
-    def _grad_fn(
+    def _loss_and_grads_fn(
         self,
-    ) -> Callable[[hk.Params, hk.Params, jax.Array, JaxBatch], hk.Params]:
+    ) -> Callable[
+        [hk.Params, hk.Params, jax.Array, JaxBatch],
+        Tuple[jax.Array, hk.Params],
+    ]:
         """Lazy-built jax function to compute batch-averaged gradients."""
-        return jax.jit(jax.grad(self._forward))
+        return jax.jit(jax.value_and_grad(self._forward))
 
     def _forward(
         self,
@@ -436,10 +442,11 @@ class HaikuModel(Model):
         return jnp.mean(s_loss)
 
     @functools.cached_property
-    def _clipped_grad_fn(
+    def _clipped_grads_and_loss_fn(
         self,
     ) -> Callable[
-        [hk.Params, hk.Params, jax.Array, JaxBatch, float], List[jax.Array]
+        [hk.Params, hk.Params, jax.Array, JaxBatch, float],
+        Tuple[List[jax.Array], jax.Array],
     ]:
         """Lazy-built jax function to compute clipped sample-wise gradients.
 
@@ -447,17 +454,17 @@ class HaikuModel(Model):
         applying optional parameters to pytrees.
         """
 
-        def clipped_grad_fn(
+        def clipped_grads_and_loss_fn(
             train_params: hk.Params,
             fixed_params: hk.Params,
             rng: jax.Array,
             batch: JaxBatch,
             max_norm: float,
-        ) -> List[jax.Array]:
+        ) -> Tuple[List[jax.Array], jax.Array]:
             """Compute and clip gradients wrt parameters for a sample."""
             inputs, y_true, s_wght = batch
             batch = (inputs, y_true, None)
-            grads = jax.grad(self._forward)(
+            loss, grads = jax.value_and_grad(self._forward)(
                 train_params, fixed_params, rng, batch
             )
             grads_flat = [
@@ -466,10 +473,10 @@ class HaikuModel(Model):
             ]
             if s_wght is not None:
                 grads_flat = [g * s_wght for g in grads_flat]
-            return grads_flat
+            return grads_flat, loss
 
         in_axes = [None, None, None, 0, None]  # map on inputs' first dimension
-        return jax.jit(jax.vmap(clipped_grad_fn, in_axes))
+        return jax.jit(jax.vmap(clipped_grads_and_loss_fn, in_axes))
 
     def _unpack_batch(
         self,
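
The haiku changes boil down to swapping `jax.grad` for `jax.value_and_grad`, so the batch-averaged loss comes out of the same jit-compiled call as the gradients. A standalone illustration of that pattern on a toy linear model (not declearn code):

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    """Mean squared error of a toy linear model."""
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}
x = jnp.arange(6.0).reshape(2, 3)
y = jnp.array([1.0, 2.0])
# value_and_grad returns (loss, grads), matching _loss_and_grads_fn above.
loss, grads = jax.jit(jax.value_and_grad(loss_fn))(params, x, y)
print(float(loss), {k: v.shape for k, v in grads.items()})
```
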
diff --git a/declearn/model/sklearn/_sgd.py b/declearn/model/sklearn/_sgd.py
index 3862d0af6dc21b9c0210270e9004f6c18795a0ce..160a6d3970295c4b2e6c6fd66653b8b8f945b07d 100644
--- a/declearn/model/sklearn/_sgd.py
+++ b/declearn/model/sklearn/_sgd.py
@@ -397,6 +397,11 @@ class SklearnSGDModel(Model):
         # Optionally re-weight gradients based on sample weights.
         if s_wght is not None:
             grad = [g * w for g, w in zip(grad, s_wght)]
+        # Compute and record the loss value on the entire batch.
+        loss = self.loss_function(
+            y_data, self._predict(x_data)  # type: ignore
+        )
+        self._loss_history.append(float(loss.mean()))
         # Batch-average the gradients and return them.
         return sum(grad) / len(grad)  # type: ignore
 
diff --git a/declearn/model/tensorflow/_model.py b/declearn/model/tensorflow/_model.py
index 59260eaededbf9182178d82ec282c68964867e27..812fad761cae780e58386253bd464697de6c69dc 100644
--- a/declearn/model/tensorflow/_model.py
+++ b/declearn/model/tensorflow/_model.py
@@ -237,10 +237,11 @@ class TensorflowModel(Model):
         with tf.device(self._device):
             data = self._unpack_batch(batch)
             if max_norm is None:
-                grads = self._compute_batch_gradients(*data)
+                grads, loss = self._compute_batch_gradients(*data)
             else:
                 norm = tf.constant(max_norm)
-                grads = self._compute_clipped_gradients(*data, norm)
+                grads, loss = self._compute_clipped_gradients(*data, norm)
+        self._loss_history.append(float(loss.numpy()))
         grads_and_vars = zip(grads, self._model.trainable_weights)
         return TensorflowVector(
             {var.name: grad for grad, var in grads_and_vars}
@@ -267,14 +268,14 @@ class TensorflowModel(Model):
         inputs: tf.Tensor,
         y_true: Optional[tf.Tensor],
         s_wght: Optional[tf.Tensor],
-    ) -> List[tf.Tensor]:
+    ) -> Tuple[List[tf.Tensor], tf.Tensor]:
         """Compute and return batch-averaged gradients of trainable weights."""
         with tf.GradientTape() as tape:
             y_pred = self._model(inputs, training=True)
             loss = self._model.compute_loss(inputs, y_true, y_pred, s_wght)
             loss = tf.reduce_mean(loss)
             grad = tape.gradient(loss, self._model.trainable_weights)
-        return grad
+        return grad, loss
 
     @tf.function  # optimize tensorflow runtime
     def _compute_clipped_gradients(
@@ -283,26 +284,26 @@ class TensorflowModel(Model):
         y_true: Optional[tf.Tensor],
         s_wght: Optional[tf.Tensor],
         max_norm: Union[tf.Tensor, float],
-    ) -> List[tf.Tensor]:
+    ) -> Tuple[List[tf.Tensor], tf.Tensor]:
         """Compute and return sample-wise-clipped batch-averaged gradients."""
-        grad = self._compute_samplewise_gradients(inputs, y_true)
+        grad, loss = self._compute_samplewise_gradients(inputs, y_true)
         if s_wght is None:
             s_wght = tf.cast(1, grad[0].dtype)
         grad = self._clip_and_average_gradients(grad, max_norm, s_wght)
-        return grad
+        return grad, loss
 
     @tf.function  # optimize tensorflow runtime
     def _compute_samplewise_gradients(
         self,
         inputs: tf.Tensor,
         y_true: Optional[tf.Tensor],
-    ) -> List[tf.Tensor]:
+    ) -> Tuple[List[tf.Tensor], tf.Tensor]:
         """Compute and return sample-wise gradients for a given batch."""
         with tf.GradientTape() as tape:
             y_pred = self._model(inputs, training=True)
             loss = self._model.compute_loss(inputs, y_true, y_pred)
             grad = tape.jacobian(loss, self._model.trainable_weights)
-        return grad
+        return grad, tf.reduce_mean(loss)
 
     @staticmethod
     @tf.function  # optimize tensorflow runtime
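
On the TensorFlow side, the gradient helpers now return the loss alongside the gradients, so `compute_batch_gradients` can record it without a second forward pass. A standalone illustration of the single-tape pattern on a toy model (not declearn code):

```python
import tensorflow as tf

w = tf.Variable([1.0, -2.0])
x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
y = tf.constant([0.0, 1.0])
with tf.GradientTape() as tape:
    pred = tf.linalg.matvec(x, w)
    loss = tf.reduce_mean(tf.square(pred - y))
grads = tape.gradient(loss, [w])
# One tape pass yields both the batch-averaged loss and the gradients.
print(float(loss.numpy()), grads[0].numpy())
```
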
diff --git a/declearn/model/torch/_model.py b/declearn/model/torch/_model.py
index 9451846d2bb9474cde12aaf7636ce89a145578e3..3f79113d72d36fec04c40383bb9d3de0c10b92e3 100644
--- a/declearn/model/torch/_model.py
+++ b/declearn/model/torch/_model.py
@@ -253,6 +253,7 @@ class TorchModel(Model):
         y_pred = self._model(*inputs)
         loss = self._compute_loss(y_pred, y_true, s_wght)
         loss.backward()
+        self._loss_history.append(float(loss.detach().cpu().numpy().mean()))
         # Collect weights' gradients and return them in a Vector container.
         grads = {
             k: p.grad.detach().clone()
@@ -320,7 +321,10 @@ class TorchModel(Model):
             s_wght=(s_wght is not None),
         )
         with torch.no_grad():
-            grads = grads_fn(inputs, y_true, s_wght, clip=clip)  # type: ignore
+            grads, loss = grads_fn(  # type: ignore
+                inputs, y_true, s_wght, clip=clip
+            )
+            self._loss_history.append(float(loss.cpu().numpy().mean()))
         return TorchVector(grads)
 
     @functools.lru_cache
diff --git a/declearn/model/torch/_samplewise/__init__.py b/declearn/model/torch/_samplewise/__init__.py
index 1448a18bb210279d99343f8637cc8d7587cf12a5..3060e34c173ae073afe54d8d92dff58b740dd267 100644
--- a/declearn/model/torch/_samplewise/__init__.py
+++ b/declearn/model/torch/_samplewise/__init__.py
@@ -62,10 +62,12 @@ def build_samplewise_grads_fn(
 
     Returns
     -------
-    grads_fn: callable[[inputs, y_true, s_wght, clip], grads]
+    grads_fn: callable[[inputs, y_true, s_wght, clip], (grads, loss)]
         Function that efficiently computes and returns sample-wise gradients
         wrt trainable model parameters based on a batch of inputs, with opt.
         clipping based on a maximum l2-norm value `clip`.
+        It returns the sample-wise gradients as a dict of tensors keyed by
+        parameter name, together with the sample-wise loss values as a tensor.
 
     Note
     ----
diff --git a/declearn/model/torch/_samplewise/functorch.py b/declearn/model/torch/_samplewise/functorch.py
index 37ea5ecbe07a716727bada27e05abac153e1bbb6..67c50619aa1a8a57016ae1bba445e57e3937533c 100644
--- a/declearn/model/torch/_samplewise/functorch.py
+++ b/declearn/model/torch/_samplewise/functorch.py
@@ -55,13 +55,13 @@ def build_samplewise_grads_fn_backend(
         """Compute gradients and optionally clip them."""
         params, idxgrd, pnames = get_params(model)
         buffers = list(model.buffers())
-        gfunc = functorch.grad(run_forward, argnums=tuple(idxgrd))
-        grads = gfunc(
+        gfunc = functorch.grad_and_value(run_forward, argnums=tuple(idxgrd))
+        grads, loss = gfunc(
             inputs, y_true, (None if clip else s_wght), buffers, *params
         )
         if clip:
             clip_and_scale_grads_inplace(grads, clip, s_wght)
-        return dict(zip(pnames, grads))
+        return dict(zip(pnames, grads)), loss.detach()
 
     # Wrap the former function to compute and clip sample-wise gradients.
     in_dims = ([0] * inputs, 0 if y_true else None, 0 if s_wght else None)
diff --git a/declearn/model/torch/_samplewise/shared.py b/declearn/model/torch/_samplewise/shared.py
index 451ae7c11998c3d337bffe988df2f6e131f5e539..009fa040e0189e8b1c39354ef33af1224aa0e6ef 100644
--- a/declearn/model/torch/_samplewise/shared.py
+++ b/declearn/model/torch/_samplewise/shared.py
@@ -17,7 +17,7 @@
 
 """Shared code for torch-version-dependent backend code."""
 
-from typing import Callable, Dict, Iterable, List, Optional
+from typing import Callable, Dict, Iterable, List, Optional, Tuple
 
 import torch
 
@@ -34,7 +34,7 @@ GetGradientsFunction = Callable[
         Optional[torch.Tensor],
         Optional[float],
     ],
-    Dict[str, torch.Tensor],
+    Tuple[Dict[str, torch.Tensor], torch.Tensor],
 ]
 """Signature for sample-wise gradients computation functions."""
 
diff --git a/declearn/model/torch/_samplewise/torchfunc.py b/declearn/model/torch/_samplewise/torchfunc.py
index 88aa5b7dca01dcdd77adda76b29d6b080c7487d2..14e9989f415c90465067513f02a16e975218aefc 100644
--- a/declearn/model/torch/_samplewise/torchfunc.py
+++ b/declearn/model/torch/_samplewise/torchfunc.py
@@ -51,22 +51,24 @@ def build_samplewise_grads_fn_backend(
             s_loss.mul_(s_wght.to(s_loss.device))
         return s_loss.mean()
 
-    get_grads = torch.func.grad(run_forward, argnums=0)
+    get_grads_and_loss = torch.func.grad_and_value(run_forward, argnums=0)
 
-    def get_clipped_grads(inputs, y_true, s_wght, clip=None):
+    def get_clipped_grads_and_loss(inputs, y_true, s_wght, clip=None):
         """Compute gradients and optionally clip them."""
         params, frozen = get_params(model)
         buffers = dict(model.named_buffers())
-        grads = get_grads(
+        grads, loss = get_grads_and_loss(
             params, frozen, buffers, inputs, y_true, None if clip else s_wght
         )
         if clip:
             clip_and_scale_grads_inplace(grads.values(), clip, s_wght)
-        return grads
+        return grads, loss.detach()
 
     # Wrap the former function to compute and clip sample-wise gradients.
     in_dims = ([0] * inputs, 0 if y_true else None, 0 if s_wght else None)
-    return torch.func.vmap(get_clipped_grads, in_dims, randomness="same")
+    return torch.func.vmap(
+        get_clipped_grads_and_loss, in_dims, randomness="same"
+    )
 
 
 def get_params(
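
Both torch backends compose `grad_and_value` with `vmap`, so per-sample gradients and per-sample losses come out of a single vectorized call, matching the `(grads, loss)` signature documented in `_samplewise/__init__.py`. A standalone illustration on a toy model (not declearn code; requires torch >= 2.0 for `torch.func`):

```python
import torch

def forward(w, x, y):
    """Squared error of a toy linear model on a single sample."""
    return ((x @ w - y) ** 2).mean()

w = torch.ones(3)
x = torch.randn(4, 3)
y = torch.randn(4)
# grad_and_value returns (grads, loss); vmap maps it over the batch axis.
grads_and_loss_fn = torch.func.grad_and_value(forward, argnums=0)
per_sample_fn = torch.func.vmap(grads_and_loss_fn, in_dims=(None, 0, 0))
grads, loss = per_sample_fn(w, x, y)
print(grads.shape, loss.shape)  # torch.Size([4, 3]) torch.Size([4])
```
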