diff --git a/declearn/model/torch/_model.py b/declearn/model/torch/_model.py
index 59334cc9098ea9128a50a213473e27724ff67a4b..7491c3f53b444f8216143a39ce0d15249c810ebc 100644
--- a/declearn/model/torch/_model.py
+++ b/declearn/model/torch/_model.py
@@ -17,11 +17,19 @@
 
 """Model subclass to wrap PyTorch models."""
 
+import functools
 import io
 import warnings
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple
 
 import functorch  # type: ignore
+
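+# functorch.compile may be unavailable in older functorch releases; when it
+# is missing, the built gradients functions are not compiled ahead of time.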
+try:
+    import functorch.compile  # type: ignore
+except ModuleNotFoundError:
+    COMPILE_AVAILABLE = False
+else:
+    COMPILE_AVAILABLE = True
 import numpy as np
 import torch
 from typing_extensions import Self  # future: import from typing (py >=3.11)
@@ -83,7 +91,9 @@ class TorchModel(Model):
         loss: torch.nn.Module
             Torch Module instance that defines the model's loss, that
             is to be minimized through training. Note that it will be
-            altered when wrapped.
+            altered when wrapped. It must expect `y_pred` and `y_true`
+            as input arguments (in that order) and will be used to get
+            sample-wise loss values (by removing any reduction scheme).
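+            Standard torch losses such as `torch.nn.CrossEntropyLoss`
+            or `torch.nn.MSELoss` follow this convention.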
         """
         # Type-check the input model.
         if not isinstance(model, torch.nn.Module):
@@ -99,7 +109,7 @@ class TorchModel(Model):
         loss.reduction = "none"  # type: ignore
         self._loss_fn = AutoDeviceModule(loss, device=device)
         # Compute and assign a functional version of the model.
-        self._func_model = functorch.make_functional(self._model)[0]
+        self._func_model, _ = functorch.make_functional(self._model)
 
     @property
     def device_policy(
@@ -265,117 +275,108 @@ class TorchModel(Model):
             loss.mul_(s_wght.to(loss.device))
         return loss.mean()
 
-    def _compute_samplewise_gradients(
+    def _compute_clipped_gradients(
         self,
         batch: Batch,
+        max_norm: float,
     ) -> TorchVector:
-        """Compute and return stacked sample-wise gradients from a batch."""
-        # Delegate preparation of the gradients-computing function.
-        # fmt: off
-        grads_fn, data, params, pnames, in_axes = (
-            self._prepare_samplewise_gradients_computations(batch)
+        """Compute and return batch-averaged sample-wise-clipped gradients."""
+        # Compute sample-wise clipped gradients, using functorch.
+        grads = self._compute_samplewise_gradients(batch, max_norm)
+        # Batch-average the resulting sample-wise gradients.
+        return TorchVector(
+            {name: tensor.mean(dim=0) for name, tensor in grads.coefs.items()}
         )
-        # Vectorize the function to compute sample-wise gradients.
-        with torch.no_grad():
-            grads = functorch.vmap(grads_fn, in_axes)(*data, *params)
-        # Wrap the results into a TorchVector and return it.
-        return TorchVector(dict(zip(pnames, grads)))
 
-    def _compute_clipped_gradients(
+    def _compute_samplewise_gradients(
         self,
         batch: Batch,
-        max_norm: float,
+        max_norm: Optional[float],
     ) -> TorchVector:
-        """Compute and return batch-averaged sample-wise-clipped gradients."""
-        # Delegate preparation of the gradients-computing function.
-        # fmt: off
-        grads_fn, data, params, pnames, in_axes = (
-            self._prepare_samplewise_gradients_computations(batch)
+        """Compute and return stacked sample-wise gradients over a batch."""
+        # Unpack the inputs, gather parameters and list gradients to compute.
+        inputs, y_true, s_wght = self._unpack_batch(batch)
+        params = []  # type: List[torch.nn.Parameter]
+        idxgrd = []  # type: List[int]
+        pnames = []  # type: List[str]
+        for index, (name, param) in enumerate(self._model.named_parameters()):
+            params.append(param)
+            if param.requires_grad:
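+                # Offset by 3 to skip the (inputs, y_true, s_wght) arguments.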
+                idxgrd.append(index + 3)
+                pnames.append(name)
+        # Gather or build the function computing sample-wise gradients.
+        grads_fn = self._build_samplewise_grads_fn(
+            idxgrd=tuple(idxgrd),
+            inputs=len(inputs),
+            y_true=(y_true is not None),
+            s_wght=(s_wght is not None),
         )
-        # Compose it to clip output gradients on the way.
-        def clipped_grads_fn(inputs, y_true, s_wght, *params):
-            grads = grads_fn(inputs, y_true, None, *params)
-            for grad in grads:
-                # future: use torch.linalg.norm when supported by functorch
-                norm = torch.norm(grad, p=2, keepdim=True)
-                # false-positive; pylint: disable=no-member
-                grad.mul_(torch.clamp(max_norm / norm, max=1))
-                if s_wght is not None:
-                    grad.mul_(s_wght.to(grad.device))
-            return grads
-        # Vectorize the function to compute sample-wise clipped gradients.
+        # Call it on the current inputs, with optional clipping.
         with torch.no_grad():
-            grads = functorch.vmap(clipped_grads_fn, in_axes)(*data, *params)
-        # Wrap batch-averaged results into a TorchVector and return it.
-        return TorchVector(
-            {name: grad.mean(dim=0) for name, grad in zip(pnames, grads)}
-        )
+            grads = grads_fn(inputs, y_true, s_wght, *params, clip=max_norm)
+        # Wrap the results into a TorchVector and return it.
+        return TorchVector(dict(zip(pnames, grads)))
 
-    def _prepare_samplewise_gradients_computations(
+    @functools.lru_cache
+    def _build_samplewise_grads_fn(
         self,
-        batch: Batch,
-    ) -> Tuple[
-        Callable[..., List[torch.Tensor]],
-        TensorBatch,
-        List[torch.nn.Parameter],
-        List[str],
-        Tuple[Any, ...],
-    ]:
-        """Prepare a function an parameters to compute sample-wise gradients.
-
-        Note: this method is merely implemented as a way to avoid code
-        redundancies between the `_compute_samplewise_gradients` method
-        and the `_compute_clipped_gradients` ones.
+        idxgrd: Tuple[int, ...],
+        inputs: int,
+        y_true: bool,
+        s_wght: bool,
+    ) -> Callable[..., List[torch.Tensor]]:
+        """Build a functorch-based sample-wise gradients-computation function.
+
+        This method is cached: repeated calls with the same parameters
+        return the same object, which avoids paying the cost of building
+        and (when available) compiling the output function more than once.
 
         Parameters
         ----------
-        batch: declearn.typing.Batch
-            Batch structure wrapping the input data, target labels and
-            optional sample weights based on which to compute gradients.
+        idxgrd: tuple of int
+            Pre-incremented indices of the parameters that require gradients.
+        inputs: int
+            Number of input tensors.
+        y_true: bool
+            Whether a true labels tensor is provided.
+        s_wght: bool
+            Whether a sample weights tensor is provided.
 
         Returns
         -------
-        grads_fn: function(*data, *params) -> List[torch.Tensor]
-            Functorch-issued gradients computation function.
-        data: tuple([torch.Tensor], torch.Tensor, torch.Tensor or None)
-            Tensor-converted data unpacked from `batch`.
-        params: list[torch.nn.Parameter]
-            Input parameters of the model, some of which require grads.
-        pnames: list[str]
-            Names of the parameters that require gradients.
-        in_axes: tuple(...)
-            Prepared `in_axes` parameter to `functorch.vmap`, suitable
-            to distribute `grads_fn` (or any compose that shares its
-            input signature) over a batch so as to compute sample-wise
-            gradients in a computationally-efficient manner.
+        grads_fn: callable(inputs, y_true, s_wght, *params, clip=None)
+            Functorch-optimized function to efficiently compute sample-
+            wise gradients based on batched inputs, and optionally clip
+            them based on a maximum l2-norm value `clip`.
         """
-        # fmt: off
-        # Unpack and validate inputs.
-        data = (inputs, y_true, s_wght) = self._unpack_batch(batch)
-        # Gather parameters and list those that require gradients.
-        idxgrd = []  # type: List[int]
-        pnames = []  # type: List[str]
-        params = []  # type: List[torch.nn.Parameter]
-        for idx, (name, param) in enumerate(self._model.named_parameters()):
-            params.append(param)
-            if param.requires_grad:
-                pnames.append(name)
-                idxgrd.append(idx)
-        # Define a differentiable function wrapping the forward pass.
+
         def forward(inputs, y_true, s_wght, *params):
+            """Conduct the forward pass in a functional way."""
             y_pred = self._func_model(params, *inputs)
             return self._compute_loss(y_pred, y_true, s_wght)
-        # Transform it into a sample-wise-gradients-computing function.
-        grads_fn = functorch.grad(forward, argnums=tuple(i+3 for i in idxgrd))
-        # Prepare `functools.vmap` parameter to slice through data and params.
-        in_axes = [
-            [0] * len(inputs),
-            None if y_true is None else 0,
-            None if s_wght is None else 0,
-        ]
-        in_axes.extend([None] * len(params))
-        # Return all this prepared material.
-        return grads_fn, data, params, pnames, tuple(in_axes)
+
+        def grads_fn(inputs, y_true, s_wght, *params, clip=None):
+            """Compute gradients and optionally clip them."""
+            gfunc = functorch.grad(forward, argnums=idxgrd)
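+            # When clipping, compute grads of the unweighted loss, so that
+            # sample weights only apply after the clipping operation below.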
+            grads = gfunc(inputs, y_true, None if clip else s_wght, *params)
+            if clip:
+                for grad in grads:
+                    # future: use torch.linalg.norm when supported by functorch
+                    norm = torch.norm(grad, p=2, keepdim=True)
+                    # false-positive; pylint: disable=no-member
+                    grad.mul_(torch.clamp(clip / norm, max=1))
+                    if s_wght is not None:
+                        grad.mul_(s_wght.to(grad.device))
+            return grads
+
+        # Vectorize the above function to compute sample-wise gradients.
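+        # Data tensors are mapped over their batch dimension (axis 0), while
+        # model parameters are broadcast (axis None) across all samples.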
+        in_axes = [[0] * inputs, 0 if y_true else None, 0 if s_wght else None]
+        in_axes.extend([None] * sum(1 for _ in self._model.parameters()))
+        grads_fn = functorch.vmap(grads_fn, tuple(in_axes))
+        # Compile the resulting function to decrease runtime costs.
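+        # Using `aot_function` with the `nop` compiler traces the function
+        # ahead of time, trimming per-call overheads without backend codegen.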
+        if not COMPILE_AVAILABLE:
+            return grads_fn
+        return functorch.compile.aot_function(grads_fn, functorch.compile.nop)
 
     def apply_updates(  # type: ignore  # Vector subtype specification
         self,