diff --git a/declearn/communication/__init__.py b/declearn/communication/__init__.py
index 819f28e1ac409a4bbed79d2e21ca0b05404dc163..e7f685f248bfc22bbc35300a0ef3f39d66660dba 100644
--- a/declearn/communication/__init__.py
+++ b/declearn/communication/__init__.py
@@ -70,8 +70,10 @@ from ._build import (
 try:
     from . import grpc
 except ImportError:
+    # pragma: no cover
     _INSTALLABLE_BACKENDS["grpc"] = ("grpcio", "protobuf")
 try:
     from . import websockets
 except ImportError:
+    # pragma: no cover
     _INSTALLABLE_BACKENDS["websockets"] = ("websockets",)
diff --git a/declearn/model/torch/_model.py b/declearn/model/torch/_model.py
index f8a5dd6f1396d6d6909edbf070ee616e6cc3a6ec..21ce772adf4d606c9da584fcb9376fa3e353af6d 100644
--- a/declearn/model/torch/_model.py
+++ b/declearn/model/torch/_model.py
@@ -17,25 +17,21 @@
 
 """Model subclass to wrap PyTorch models."""
 
-import functools
 import io
+import functools
 import warnings
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple
-
-import functorch  # type: ignore
+from typing import Any, Dict, List, Optional, Set, Tuple
 
-try:
-    import functorch.compile  # type: ignore
-except ModuleNotFoundError:
-    COMPILE_AVAILABLE = False
-else:
-    COMPILE_AVAILABLE = True
 import numpy as np
 import torch
 from typing_extensions import Self  # future: import from typing (py >=3.11)
 
 from declearn.model.api import Model
 from declearn.model.torch.utils import AutoDeviceModule, select_device
+from declearn.model.torch._samplewise import (
+    GetGradientsFunction,
+    build_samplewise_grads_fn,
+)
 from declearn.model.torch._vector import TorchVector
 from declearn.model._utils import raise_on_stringsets_mismatch
 from declearn.typing import Batch
@@ -75,6 +71,25 @@ class TorchModel(Model):
       `update_device_policy` method.
     - You may consult the device policy currently enforced by a TorchModel
      instance by accessing its `device_policy` property.
+
+    Notes regarding `torch.compile` support (torch >=2.0):
+
+    - If you want the wrapped model to be optimized via `torch.compile`, it
+      should be compiled _prior_ to being wrapped using `TorchModel`.
+    - The compilation will not be used when computing sample-wise-clipped
+      gradients, as `torch.func` and `torch.compile` do not play along yet.
+    - The information that the module was compiled will be saved as part of
+      the `TorchModel` config, so that using `TorchModel.from_config` will
+      trigger it again when possible; this is however limited to calling
+      `torch.compile`, meaning that any custom arguments will be lost.
+    - Note that the former point notably affects the way clients will run
+      a server-emitted `TorchModel` as part of an FL process: clients that
+      run Torch 1.X will be able to use the un-optimized module, while
+      clients running Torch 2.0 will use compilation, but in a rather crude
+      form that may not be suitable for some specific/advanced cases.
+    - Enhanced support for `torch.compile` is on the roadmap. If you run
+      into issues and/or have requests or advice on that topic, feel free
+      to let us know by contacting us via mail or GitLab.
     """
     def __init__(
        self,
@@ -101,15 +116,19 @@ class TorchModel(Model):
         # Select the device where to place computations, and wrap the model.
         policy = get_device_policy()
         device = select_device(gpu=policy.gpu, idx=policy.idx)
-        model = AutoDeviceModule(model, device=device)
-        super().__init__(model)
+        super().__init__(AutoDeviceModule(model, device=device))
         # Assign loss module and set it not to reduce sample-wise values.
if not isinstance(loss, torch.nn.Module): raise TypeError("'loss' should be a torch.nn.Module instance.") loss.reduction = "none" # type: ignore self._loss_fn = AutoDeviceModule(loss, device=device) - # Compute and assign a functional version of the model. - self._func_model, _ = functorch.make_functional(self._model) + # Detect torch-compiled models and extract underlying module. + self._raw_model = self._model + if hasattr(torch, "compile") and hasattr(model, "_orig_mod"): + self._raw_model = AutoDeviceModule( + module=getattr(model, "_orig_mod"), + device=self._model.device, + ) @property def device_policy( @@ -137,12 +156,16 @@ class TorchModel(Model): "PyTorch JSON serialization relies on pickle, which may be unsafe." ) with io.BytesIO() as buffer: - torch.save(self._model.module, buffer) + torch.save(self._raw_model.module, buffer) model = buffer.getbuffer().hex() with io.BytesIO() as buffer: torch.save(self._loss_fn.module, buffer) loss = buffer.getbuffer().hex() - return {"model": model, "loss": loss} + return { + "model": model, + "loss": loss, + "compile": self._raw_model is not self._model, + } @classmethod def from_config( @@ -154,13 +177,15 @@ class TorchModel(Model): model = torch.load(buffer) with io.BytesIO(bytes.fromhex(config["loss"])) as buffer: loss = torch.load(buffer) + if config.get("compile", False) and hasattr(torch, "compile"): + model = torch.compile(model) return cls(model=model, loss=loss) def get_weights( self, trainable: bool = False, ) -> TorchVector: - params = self._model.named_parameters() + params = self._raw_model.named_parameters() if trainable: weights = {k: p.data for k, p in params if p.requires_grad} else: @@ -177,12 +202,12 @@ class TorchModel(Model): raise TypeError("TorchModel requires TorchVector weights.") self._verify_weights_compatibility(weights, trainable=trainable) if trainable: - state_dict = self._model.state_dict() + state_dict = self._raw_model.state_dict() state_dict.update(weights.coefs) else: state_dict = weights.coefs # NOTE: this preserves the device placement of current states - self._model.load_state_dict(state_dict) + self._raw_model.load_state_dict(state_dict) def _verify_weights_compatibility( self, @@ -206,12 +231,9 @@ class TorchModel(Model): In case some expected keys are missing, or additional keys are present. Be verbose about the identified mismatch(es). """ + params = self._raw_model.named_parameters() received = set(vector.coefs) - expected = { - name - for name, param in self._model.named_parameters() - if (not trainable) or param.requires_grad - } + expected = {n for n, p in params if (not trainable) or p.requires_grad} raise_on_stringsets_mismatch( received, expected, context="model weights" ) @@ -241,7 +263,7 @@ class TorchModel(Model): # Collect weights' gradients and return them in a Vector container. grads = { k: p.grad.detach().clone() - for k, p in self._model.named_parameters() + for k, p in self._raw_model.named_parameters() if p.requires_grad } return TorchVector(grads) @@ -281,8 +303,8 @@ class TorchModel(Model): max_norm: float, ) -> TorchVector: """Compute and return batch-averaged sample-wise-clipped gradients.""" - # Compute sample-wise clipped gradients, using functorch. - grads = self._compute_samplewise_gradients(batch, max_norm) + # Compute sample-wise clipped gradients, using functional torch. + grads = self._compute_samplewise_gradients(batch, clip=max_norm) # Batch-average the resulting sample-wise gradients. 
         return TorchVector(
             {name: tensor.mean(dim=0) for name, tensor in grads.coefs.items()}
         )
@@ -291,92 +313,49 @@
     def _compute_samplewise_gradients(
         self,
         batch: Batch,
-        max_norm: Optional[float],
+        clip: Optional[float],
     ) -> TorchVector:
         """Compute and return stacked sample-wise gradients over a batch."""
-        # Unpack the inputs, gather parameters and list gradients to compute.
         inputs, y_true, s_wght = self._unpack_batch(batch)
-        params = []  # type: List[torch.nn.Parameter]
-        idxgrd = []  # type: List[int]
-        pnames = []  # type: List[str]
-        for index, (name, param) in enumerate(self._model.named_parameters()):
-            params.append(param)
-            if param.requires_grad:
-                idxgrd.append(index + 3)
-                pnames.append(name)
-        # Gather or build the sample-wise clipped gradients computing function.
         grads_fn = self._build_samplewise_grads_fn(
-            idxgrd=tuple(idxgrd),
             inputs=len(inputs),
             y_true=(y_true is not None),
             s_wght=(s_wght is not None),
         )
-        # Call it on the current inputs, with optional clipping.
         with torch.no_grad():
-            grads = grads_fn(inputs, y_true, s_wght, *params, clip=max_norm)
-        # Wrap the results into a TorchVector and return it.
-        return TorchVector(dict(zip(pnames, grads)))
+            grads = grads_fn(inputs, y_true, s_wght, clip=clip)  # type: ignore
+        return TorchVector(grads)
 
     @functools.lru_cache
     def _build_samplewise_grads_fn(
         self,
-        idxgrd: Tuple[int, ...],
         inputs: int,
         y_true: bool,
         s_wght: bool,
-    ) -> Callable[..., List[torch.Tensor]]:
-        """Build a functorch-based sample-wise gradients-computation function.
+    ) -> GetGradientsFunction:
+        """Build an optimized sample-wise gradients-computation function.
 
         This function is cached, i.e. repeated calls with the same parameters
         will return the same object - enabling to reduce runtime costs due to
         building and (when available) compiling the output function.
 
-        Parameters
-        ----------
-        idxgrd: tuple of int
-            Pre-incremented indices of the parameters that require gradients.
-        inputs: int
-            Number of input tensors.
-        y_true: bool
-            Whether a true labels tensor is provided.
-        s_wght: bool
-            Whether a sample weights tensor is provided.
-
         Returns
         -------
-        grads_fn: callable[inputs, y_true, s_wght, *params, /, clip]
-            Functorch-optimized function to efficiently compute sample-
-            wise gradients based on batched inputs, and optionally clip
-            them based on a maximum l2-norm value `clip`.
+        grads_fn: callable[[inputs, y_true, s_wght, clip], grads]
+            Function to efficiently compute and return sample-wise gradients
+            wrt trainable model parameters based on a batch of inputs, with
+            opt. clipping based on a maximum l2-norm value `clip`.
+
+        Note
+        ----
+        The underlying backend code depends on your Torch version, so as to
+        enable optimizing operations using either `functorch` for torch 1.1X
+        or `torch.func` for torch 2.X.
""" - - def forward(inputs, y_true, s_wght, *params): - """Conduct the forward pass in a functional way.""" - y_pred = self._func_model(params, *inputs) - return self._compute_loss(y_pred, y_true, s_wght) - - def grads_fn(inputs, y_true, s_wght, *params, clip=None): - """Compute gradients and optionally clip them.""" - gfunc = functorch.grad(forward, argnums=idxgrd) - grads = gfunc(inputs, y_true, None, *params) - if clip: - for grad in grads: - # future: use torch.linalg.norm when supported by functorch - norm = torch.norm(grad, p=2, keepdim=True) - # false-positive; pylint: disable=no-member - grad.mul_(torch.clamp(clip / norm, max=1)) - if s_wght is not None: - grad.mul_(s_wght.to(grad.device)) - return grads - - # Wrap the former function to compute and clip sample-wise gradients. - in_axes = [[0] * inputs, 0 if y_true else None, 0 if s_wght else None] - in_axes.extend([None] * sum(1 for _ in self._model.parameters())) - grads_fn = functorch.vmap(grads_fn, tuple(in_axes)) - # Compile the resulting function to decrease runtime costs. - if not COMPILE_AVAILABLE: - return grads_fn - return functorch.compile.aot_function(grads_fn, functorch.compile.nop) + # NOTE: torch.func is not compatible with torch.compile yet + return build_samplewise_grads_fn( + self._raw_model, self._loss_fn, inputs, y_true, s_wght + ) def apply_updates( self, @@ -387,7 +366,7 @@ class TorchModel(Model): self._verify_weights_compatibility(updates, trainable=True) with torch.no_grad(): for key, upd in updates.coefs.items(): - tns = self._model.get_parameter(key) + tns = self._raw_model.get_parameter(key) tns.add_(upd.to(tns.device)) def compute_batch_predictions( @@ -403,12 +382,32 @@ class TorchModel(Model): "creating labels from the base inputs." ) self._model.eval() + self._handle_torch_compile_eval_issue(inputs) with torch.no_grad(): y_pred = self._model(*inputs).cpu().numpy() y_true = y_true.cpu().numpy() s_wght = None if s_wght is None else s_wght.cpu().numpy() return y_true, y_pred, s_wght # type: ignore + def _handle_torch_compile_eval_issue( + self, + inputs: List[torch.Tensor], + ) -> None: + """Clumsily handle issues with `torch.compile` and `torch.no_grad`. + + As of Torch 2.0.1, running a compiled model's first forward pass + within a `torch.no_grad` context results in the model's future + weights updates not being properly taken into account. + + Therefore, when wrapping a compiled model, this method runs a lost + forward pass outside of a no-grad context on its first call (later + it does nothing). + """ + if (self._raw_model is self._model) or hasattr(self, "__eval_called"): + return + self._model(*inputs) + setattr(self, "__eval_called", True) + def loss_function( self, y_true: np.ndarray, diff --git a/declearn/model/torch/_samplewise/__init__.py b/declearn/model/torch/_samplewise/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1448a18bb210279d99343f8637cc8d7587cf12a5 --- /dev/null +++ b/declearn/model/torch/_samplewise/__init__.py @@ -0,0 +1,78 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Torch-version-dependent code to compute sample-wise gradients.""" + +from typing import Callable, Dict, List, Optional + +import torch + +from .shared import GetGradientsFunction + +if torch.__version__.startswith("2."): + from .torchfunc import build_samplewise_grads_fn_backend +elif torch.__version__.startswith("1.1"): + from .functorch import build_samplewise_grads_fn_backend +else: + # pragma: no cover + raise ImportError(f"Unsupported Torch version: {torch.__version__}") + + +__all__ = [ + "GetGradientsFunction", + "build_samplewise_grads_fn", +] + + +def build_samplewise_grads_fn( + model: torch.nn.Module, + loss_fn: torch.nn.Module, + inputs: int, + y_true: bool, + s_wght: bool, +) -> GetGradientsFunction: + """Build a torch-specific sample-wise gradients-computation function. + + Parameters + ---------- + model: torch.nn.Module + Model that is to be trained. + loss_fn: torch.nn.Module + Loss-computing module, returning sample-wise loss values. + inputs: int + Number of input tensors. + y_true: bool + Whether a true labels tensor is provided. + s_wght: bool + Whether a sample weights tensor is provided. + + Returns + ------- + grads_fn: callable[[inputs, y_true, s_wght, clip], grads] + Function that efficiently computes and returns sample-wise gradients + wrt trainable model parameters based on a batch of inputs, with opt. + clipping based on a maximum l2-norm value `clip`. + + Note + ---- + The underlying backend code depends on your Torch version, so as to + enable optimizing operations using either `functorch` for torch 1.1X + or `torch.func` for torch 2.X. + """ + return build_samplewise_grads_fn_backend( + model, loss_fn, inputs, y_true, s_wght + ) diff --git a/declearn/model/torch/_samplewise/functorch.py b/declearn/model/torch/_samplewise/functorch.py new file mode 100644 index 0000000000000000000000000000000000000000..d39a3e16602918cd36e755550e1440622efb2866 --- /dev/null +++ b/declearn/model/torch/_samplewise/functorch.py @@ -0,0 +1,96 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""Implementation of `build_samplewise_grads_fn` for Torch 1.1X."""
+
+from typing import List, Tuple
+
+# fmt: off
+import functorch  # type: ignore
+try:
+    import functorch.compile  # type: ignore
+    COMPILE_AVAILABLE = True
+except ModuleNotFoundError:
+    # pragma: no cover
+    COMPILE_AVAILABLE = False
+import torch
+# fmt: on
+
+from declearn.model.torch._samplewise.shared import (
+    GetGradientsFunction,
+    clip_and_scale_grads_inplace,
+)
+
+__all__ = [
+    "build_samplewise_grads_fn_backend",
+]
+
+
+def build_samplewise_grads_fn_backend(
+    model: torch.nn.Module,
+    loss_fn: torch.nn.Module,
+    inputs: int,
+    y_true: bool,
+    s_wght: bool,
+) -> GetGradientsFunction:
+    """Implementation of `build_samplewise_grads_fn` for Torch 1.1X."""
+
+    func_model, *_ = functorch.make_functional_with_buffers(model)
+
+    def run_forward(inputs, y_true, s_wght, buffers, *params):
+        """Run the forward pass in a functional way."""
+        y_pred = func_model(params, buffers, *inputs)
+        s_loss = loss_fn(y_pred, y_true)
+        if s_wght is not None:
+            s_loss.mul_(s_wght.to(s_loss.device))
+        return s_loss.mean()
+
+    def grads_fn(inputs, y_true, s_wght, clip=None):
+        """Compute gradients and optionally clip them."""
+        params, idxgrd, pnames = get_params(model)
+        buffers = list(model.buffers())
+        gfunc = functorch.grad(run_forward, argnums=tuple(idxgrd))
+        grads = gfunc(
+            inputs, y_true, (None if clip else s_wght), buffers, *params
+        )
+        if clip:
+            clip_and_scale_grads_inplace(grads, clip, s_wght)
+        return dict(zip(pnames, grads))
+
+    # Wrap the former function to compute and clip sample-wise gradients.
+    in_dims = ([0] * inputs, 0 if y_true else None, 0 if s_wght else None)
+    grads_fn = functorch.vmap(grads_fn, in_dims, randomness="same")
+    # Compile the resulting function to decrease runtime costs.
+    if not COMPILE_AVAILABLE:
+        # pragma: no cover
+        return grads_fn
+    return functorch.compile.aot_function(grads_fn, functorch.compile.nop)
+
+
+def get_params(
+    model: torch.nn.Module,
+) -> Tuple[List[torch.nn.Parameter], List[int], List[str]]:
+    """Return a model's parameters and the index and name of trainable ones."""
+    params = []  # type: List[torch.nn.Parameter]
+    idxgrd = []  # type: List[int]
+    pnames = []  # type: List[str]
+    for idx, (name, param) in enumerate(model.named_parameters()):
+        params.append(param)
+        if param.requires_grad:
+            idxgrd.append(idx + 4)
+            pnames.append(name)
+    return params, idxgrd, pnames
diff --git a/declearn/model/torch/_samplewise/shared.py b/declearn/model/torch/_samplewise/shared.py
new file mode 100644
index 0000000000000000000000000000000000000000..451ae7c11998c3d337bffe988df2f6e131f5e539
--- /dev/null
+++ b/declearn/model/torch/_samplewise/shared.py
@@ -0,0 +1,56 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Shared code for torch-version-dependent backend code.""" + +from typing import Callable, Dict, Iterable, List, Optional + +import torch + +__all__ = [ + "GetGradientsFunction", + "clip_and_scale_grads_inplace", +] + + +GetGradientsFunction = Callable[ + [ + List[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[float], + ], + Dict[str, torch.Tensor], +] +"""Signature for sample-wise gradients computation functions.""" + + +def clip_and_scale_grads_inplace( + grads: Iterable[torch.Tensor], + clip: float, + wght: Optional[torch.Tensor] = None, +) -> None: + """Clip a collection of tensors in-place, based on their euclidean norm. + + Also apply an optional weight tensor to scale the clipped gradients. + """ + for grad in grads: + norm = torch.norm(grad, p=2, keepdim=True) + # false-positive; pylint: disable=no-member + grad.mul_(torch.clamp(clip / norm, max=1)) + if wght is not None: + grad.mul_(wght.to(grad.device)) diff --git a/declearn/model/torch/_samplewise/torchfunc.py b/declearn/model/torch/_samplewise/torchfunc.py new file mode 100644 index 0000000000000000000000000000000000000000..88aa5b7dca01dcdd77adda76b29d6b080c7487d2 --- /dev/null +++ b/declearn/model/torch/_samplewise/torchfunc.py @@ -0,0 +1,80 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implementation of `build_samplewise_grads_fn` for Torch 2.0.""" + +from typing import Dict, Tuple + +import torch + +from declearn.model.torch._samplewise.shared import ( + GetGradientsFunction, + clip_and_scale_grads_inplace, +) + +__all__ = [ + "build_samplewise_grads_fn_backend", +] + + +def build_samplewise_grads_fn_backend( + model: torch.nn.Module, + loss_fn: torch.nn.Module, + inputs: int, + y_true: bool, + s_wght: bool, +) -> GetGradientsFunction: + """Implementation of `build_samplewise_grads_fn` for Torch 2.0.""" + + def run_forward(params, frozen, buffers, inputs, y_true, s_wght): + """Run the forward pass in a functional way.""" + # backend closure function; pylint: disable=too-many-arguments + y_pred = torch.func.functional_call( + model, [params, frozen, buffers], *inputs + ) + s_loss = loss_fn(y_pred, y_true) + if s_wght is not None: + s_loss.mul_(s_wght.to(s_loss.device)) + return s_loss.mean() + + get_grads = torch.func.grad(run_forward, argnums=0) + + def get_clipped_grads(inputs, y_true, s_wght, clip=None): + """Compute gradients and optionally clip them.""" + params, frozen = get_params(model) + buffers = dict(model.named_buffers()) + grads = get_grads( + params, frozen, buffers, inputs, y_true, None if clip else s_wght + ) + if clip: + clip_and_scale_grads_inplace(grads.values(), clip, s_wght) + return grads + + # Wrap the former function to compute and clip sample-wise gradients. 
+ in_dims = ([0] * inputs, 0 if y_true else None, 0 if s_wght else None) + return torch.func.vmap(get_clipped_grads, in_dims, randomness="same") + + +def get_params( + model: torch.nn.Module, +) -> Tuple[Dict[str, torch.nn.Parameter], Dict[str, torch.nn.Parameter]]: + """Return a model's parameters, split between trainable and frozen ones.""" + params = {} # type: Dict[str, torch.nn.Parameter] + frozen = {} # type: Dict[str, torch.nn.Parameter] + for name, param in model.named_parameters(): + (params if param.requires_grad else frozen)[name] = param + return params, frozen diff --git a/docs/setup.md b/docs/setup.md index 979ef90c9dfc77757bb0e2c99b51cca58f716b82..7b00c0aafd38d5390a402fb70dd21747b817de97 100644 --- a/docs/setup.md +++ b/docs/setup.md @@ -110,3 +110,10 @@ pip install declearn[all,tests] # install all extra and testing dependencies - On some systems, the square brackets used our pip install are not properly parsed. Try replacing `[` by `\[` and `]` by `\]`, or putting the instruction between quotes (`pip install "declearn[...]"`). +- Regarding Torch: declearn currently supports both late 1.10-1.13 versions and + 2.X ones. You may use either one freely, but may run into issues regarding + co-dependent versions of `torch`, `functorch` (in 1.10-1.13) and `opacus` (if + you want to use [differential privacy](./user-guide/local_dp.md) features). + You may use the `torch1` or `torch2` extra dependency specifier to explicitly + target either the 1.13 or latest 2.X torch version and install the proper + versions of the other packages (including opacus). diff --git a/pyproject.toml b/pyproject.toml index 11f8094e2f7a2c26d93e775db9f5daa590b9a129..2161244d8d6de8a5b4e8408ccd0a494f5fe575e3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,13 +48,13 @@ dependencies = [ [project.optional-dependencies] all = [ # all non-tests extra dependencies "dm-haiku == 0.0.9", - "functorch", + "functorch >= 1.10, < 3.0", "grpcio >= 1.45", "jax[cpu] >= 0.4, < 0.5", "opacus ~= 1.1", "protobuf >= 3.19", "tensorflow ~= 2.5", - "torch ~= 1.10", + "torch >= 1.10, < 3.0", "websockets ~= 10.1", ] # thematically grouped dependencies (part of "all") @@ -72,9 +72,19 @@ haiku = [ tensorflow = [ "tensorflow ~= 2.5", ] -torch = [ - "functorch", # note: functorch is internalized by torch 2.0 - "torch ~= 1.10", +torch = [ # generic requirements for Torch + "functorch >= 1.10, < 3.0", # note: unused with torch >=2.0 + "torch >= 1.10, < 3.0", +] +torch1 = [ # Torch 1.13 (latest pre-2.0 version) + Opacus + "functorch ~= 1.13.0", + "opacus ~= 1.3.0", + "torch ~= 1.13.0", +] +torch2 = [ # Torch 2.X version + Opacus + "functorch ~= 2.0", # unused, but useful to linters and pip + "opacus ~= 1.4", + "torch ~= 2.0", ] websockets = [ "websockets ~= 10.1", diff --git a/scripts/run_tests.sh b/scripts/run_tests.sh index 189986b730d1ee3e75ccd723bc49697598371879..c962e8b6a6645feb0b564a4779db95e107bcfb28 100644 --- a/scripts/run_tests.sh +++ b/scripts/run_tests.sh @@ -114,6 +114,7 @@ run_declearn_tests() { commands=( "run_unit_tests $@" "run_integration_tests $@" + "run_torch13_tests $@" ) run_commands "declearn test suite" "${commands[@]}" status=$? @@ -152,6 +153,33 @@ run_integration_tests() { } +run_torch13_tests() { + : ' + Verbosely run Torch 1.13-specific unit tests. + + Install Torch 1.13 at the start of this function, and re-install + torch >=2.0 at the end of it, together with its co-dependencies. + ' + echo "Re-installing torch 1.13 and its co-dependencies." + pip install .[torch1] + if [[ $? 
-eq 0 ]]; then + echo "Running unit tests for torch 1.13." + command="pytest $@ + --cov --cov-append --cov-report= + test/model/test_torch.py + " + run_command $command + status=$? + else + echo "\e[31mSkipping tests as installation failed.\e[0m" + status=1 + fi + echo "Re-installing torch 2.X and its co-dependencies." + pip install .[torch2] + return $status +} + + main() { if [[ $# -eq 0 ]]; then echo "Missing required positional argument." diff --git a/test/model/test_torch.py b/test/model/test_torch.py index dafe8c8cb66b1fbfe12a7201d4bb35a0c4eb23b5..498b87a7549b49a104766df625b03b0370c5dac8 100644 --- a/test/model/test_torch.py +++ b/test/model/test_torch.py @@ -19,6 +19,7 @@ import json import sys +import typing from typing import Any, List, Literal, Tuple import numpy as np @@ -69,6 +70,9 @@ class FlattenCNNOutput(torch.nn.Module): return inputs.view(*shape) +Kind = Literal["MLP", "MLP-tune", "MLP-compile", "RNN", "CNN"] + + class TorchTestCase(ModelTestCase): """PyTorch test-case-provider fixture. @@ -88,6 +92,11 @@ class TorchTestCase(ModelTestCase): - stack: 32 7x7 conv. filters, then 8x8 max pooling 16 5x5 conv. filters, then 8x8 avg pooling 1 output neuron with sigmoid activation + + + Additional scenarios include "MLP-tune", where the MLP's first hidden + layer is frozen, and "MLP-compile", where the MLP is optimized using + `torch.compile` (available in torch >=2.0 only). """ vector_cls = TorchVector @@ -95,11 +104,11 @@ class TorchTestCase(ModelTestCase): def __init__( self, - kind: Literal["MLP", "MLP-tune", "RNN", "CNN"], + kind: Kind, device: Literal["CPU", "GPU"], ) -> None: """Specify the desired model architecture.""" - if kind not in ("MLP", "MLP-tune", "RNN", "CNN"): + if kind not in typing.get_args(Kind): raise ValueError(f"Invalid torch test architecture: '{kind}'.") self.kind = kind self.device = device @@ -169,6 +178,8 @@ class TorchTestCase(ModelTestCase): torch.nn.Sigmoid(), ] nnmod = torch.nn.Sequential(*stack) + if self.kind == "MLP-compile": + nnmod = torch.compile(nnmod) # type: ignore return TorchModel(nnmod, loss=torch.nn.BCELoss()) def assert_correct_device( @@ -184,13 +195,15 @@ class TorchTestCase(ModelTestCase): @pytest.fixture(name="test_case") def fixture_test_case( - kind: Literal["MLP", "MLP-tune", "RNN", "CNN"], + kind: Kind, device: Literal["CPU", "GPU"], cpu_only: bool, ) -> TorchTestCase: """Fixture to access a TorchTestCase.""" if cpu_only and device == "GPU": pytest.skip(reason="--cpu-only mode") + if kind == "MLP-compile" and not hasattr(torch, "compile"): + pytest.skip(reason="'torch.compile' is unavailable") return TorchTestCase(kind, device) @@ -200,7 +213,7 @@ if torch.cuda.device_count(): @pytest.mark.parametrize("device", DEVICES) -@pytest.mark.parametrize("kind", ["MLP", "MLP-tune", "RNN", "CNN"]) +@pytest.mark.parametrize("kind", typing.get_args(Kind)) class TestTorchModel(ModelTestSuite): """Unit tests for declearn.model.torch.TorchModel."""
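Side note for reviewers: below is a minimal usage sketch of the `torch.compile` support introduced by this patch, following the wrapping pattern exercised in `test/model/test_torch.py` (the module is compiled *before* being handed to `TorchModel`). The small MLP and its dimensions are purely illustrative, not part of the diff.

```python
# Illustrative sketch only (not part of the diff): compile first, wrap second.
import torch

from declearn.model.torch import TorchModel

# Any torch.nn.Module works; this tiny MLP is just a placeholder.
module = torch.nn.Sequential(
    torch.nn.Linear(8, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 1),
    torch.nn.Sigmoid(),
)
# With torch >=2.0, optionally optimize the module *prior* to wrapping it.
if hasattr(torch, "compile"):
    module = torch.compile(module)
# Wrap it; the fact that it was compiled is recorded in the model's config,
# so that `TorchModel.from_config` can re-trigger a plain `torch.compile`.
model = TorchModel(module, loss=torch.nn.BCELoss())
rebuilt = TorchModel.from_config(model.get_config())
```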