diff --git a/declearn/model/torch/_model.py b/declearn/model/torch/_model.py
index 8eefac9cdb0b111f488c29de373bf3f9b7b911dc..21ce772adf4d606c9da584fcb9376fa3e353af6d 100644
--- a/declearn/model/torch/_model.py
+++ b/declearn/model/torch/_model.py
@@ -71,6 +71,31 @@ class TorchModel(Model):
       `update_device_policy` method.
     - You may consult the device policy currently enforced by a TorchModel
       instance by accessing its `device_policy` property.
+
+    Notes regarding `torch.compile` support (torch >=2.0):
+
+    - If you want the wrapped model to be optimized via `torch.compile`,
+      compile it _prior_ to wrapping it (see the example below).
+    - The compilation will not be used when computing sample-wise-clipped
+      gradients, as `torch.func` and `torch.compile` are not yet compatible.
+    - The information that the module was compiled is saved as part of the
+      `TorchModel` config, so that `TorchModel.from_config` will trigger
+      compilation again when possible; this is however limited to a plain
+      `torch.compile` call, meaning that any custom arguments are lost.
+    - Note that the previous point notably affects the way clients run a
+      server-emitted `TorchModel` as part of a FL process: clients that
+      run Torch 1.X will use the un-optimized module, while clients that
+      run Torch 2.0 will use compilation, albeit in a rather crude flavor
+      that may not be suitable for some specific or advanced cases.
+    - Enhanced support for `torch.compile` is on the roadmap. If you run
+      into issues and/or have requests or advice on that topic, feel free
+      to let us know by contacting us via mail or GitLab.
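+
+    As a minimal sketch (the module and loss used here are placeholders):
+
+    >>> module = torch.nn.Linear(8, 1)
+    >>> module = torch.compile(module)  # requires torch >=2.0
+    >>> model = TorchModel(module, loss=torch.nn.MSELoss())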
     """
 
     def __init__(
@@ -97,13 +122,21 @@
         # Select the device where to place computations, and wrap the model.
         policy = get_device_policy()
         device = select_device(gpu=policy.gpu, idx=policy.idx)
-        model = AutoDeviceModule(model, device=device)
-        super().__init__(model)
+        super().__init__(AutoDeviceModule(model, device=device))
         # Assign loss module and set it not to reduce sample-wise values.
         if not isinstance(loss, torch.nn.Module):
             raise TypeError("'loss' should be a torch.nn.Module instance.")
         loss.reduction = "none"  # type: ignore
         self._loss_fn = AutoDeviceModule(loss, device=device)
+        # Detect torch-compiled models and extract underlying module.
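+        # NOTE: the compiled wrapper prefixes parameter names with
+        # '_orig_mod.'; the raw module keeps names consistent.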
+        self._raw_model = self._model
+        if hasattr(torch, "compile") and hasattr(model, "_orig_mod"):
+            self._raw_model = AutoDeviceModule(
+                module=getattr(model, "_orig_mod"),
+                device=self._model.device,
+            )
 
     @property
     def device_policy(
@@ -131,12 +164,18 @@
             "PyTorch JSON serialization relies on pickle, which may be unsafe."
         )
         with io.BytesIO() as buffer:
-            torch.save(self._model.module, buffer)
+            torch.save(self._raw_model.module, buffer)
             model = buffer.getbuffer().hex()
         with io.BytesIO() as buffer:
             torch.save(self._loss_fn.module, buffer)
             loss = buffer.getbuffer().hex()
-        return {"model": model, "loss": loss}
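+        # Record whether the module was torch-compiled, enabling
+        # `from_config` to trigger its re-compilation.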
+        return {
+            "model": model,
+            "loss": loss,
+            "compile": self._raw_model is not self._model,
+        }
 
     @classmethod
     def from_config(
@@ -148,13 +187,16 @@
             model = torch.load(buffer)
         with io.BytesIO(bytes.fromhex(config["loss"])) as buffer:
             loss = torch.load(buffer)
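+        # Re-compile the module if it was compiled prior to serialization.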
+        if config.get("compile", False) and hasattr(torch, "compile"):
+            model = torch.compile(model)
         return cls(model=model, loss=loss)
 
     def get_weights(
         self,
         trainable: bool = False,
     ) -> TorchVector:
-        params = self._model.named_parameters()
+        params = self._raw_model.named_parameters()
         if trainable:
             weights = {k: p.data for k, p in params if p.requires_grad}
         else:
@@ -171,12 +213,12 @@
             raise TypeError("TorchModel requires TorchVector weights.")
         self._verify_weights_compatibility(weights, trainable=trainable)
         if trainable:
-            state_dict = self._model.state_dict()
+            state_dict = self._raw_model.state_dict()
             state_dict.update(weights.coefs)
         else:
             state_dict = weights.coefs
         # NOTE: this preserves the device placement of current states
-        self._model.load_state_dict(state_dict)
+        self._raw_model.load_state_dict(state_dict)
 
     def _verify_weights_compatibility(
         self,
@@ -200,12 +242,9 @@
             In case some expected keys are missing, or additional keys
             are present. Be verbose about the identified mismatch(es).
         """
+        params = self._raw_model.named_parameters()
         received = set(vector.coefs)
-        expected = {
-            name
-            for name, param in self._model.named_parameters()
-            if (not trainable) or param.requires_grad
-        }
+        expected = {n for n, p in params if (not trainable) or p.requires_grad}
         raise_on_stringsets_mismatch(
             received, expected, context="model weights"
         )
@@ -235,7 +274,7 @@
         # Collect weights' gradients and return them in a Vector container.
         grads = {
             k: p.grad.detach().clone()
-            for k, p in self._model.named_parameters()
+            for k, p in self._raw_model.named_parameters()
             if p.requires_grad
         }
         return TorchVector(grads)
@@ -324,8 +363,10 @@
         enable optimizing operations using either `functorch` for torch 1.1X
         or `torch.func` for torch 2.X.
         """
+        # NOTE: torch.func is not compatible with torch.compile yet
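+        # (hence the raw module is used rather than its compiled wrapper)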
         return build_samplewise_grads_fn(
-            self._model, self._loss_fn, inputs, y_true, s_wght
+            self._raw_model, self._loss_fn, inputs, y_true, s_wght
         )
 
     def apply_updates(
@@ -337,7 +378,7 @@
         self._verify_weights_compatibility(updates, trainable=True)
         with torch.no_grad():
             for key, upd in updates.coefs.items():
-                tns = self._model.get_parameter(key)
+                tns = self._raw_model.get_parameter(key)
                 tns.add_(upd.to(tns.device))
 
     def compute_batch_predictions(
@@ -353,12 +394,32 @@
                 "creating labels from the base inputs."
             )
         self._model.eval()
+        self._handle_torch_compile_eval_issue(inputs)
         with torch.no_grad():
             y_pred = self._model(*inputs).cpu().numpy()
         y_true = y_true.cpu().numpy()
         s_wght = None if s_wght is None else s_wght.cpu().numpy()
         return y_true, y_pred, s_wght  # type: ignore
 
+    def _handle_torch_compile_eval_issue(
+        self,
+        inputs: List[torch.Tensor],
+    ) -> None:
+        """Clumsily handle issues with `torch.compile` and `torch.no_grad`.
+
+        As of Torch 2.0.1, running a compiled model's first forward pass
+        within a `torch.no_grad` context results in the model's future
+        weight updates not being properly taken into account.
+
+        Therefore, when wrapping a compiled model, this method runs a
+        throwaway forward pass outside of any no-grad context on its
+        first call (and does nothing on subsequent calls).
+        """
+        if (self._raw_model is self._model) or hasattr(self, "__eval_called"):
+            return
+        self._model(*inputs)
+        setattr(self, "__eval_called", True)
+
     def loss_function(
         self,
         y_true: np.ndarray,
diff --git a/test/model/test_torch.py b/test/model/test_torch.py
index dafe8c8cb66b1fbfe12a7201d4bb35a0c4eb23b5..498b87a7549b49a104766df625b03b0370c5dac8 100644
--- a/test/model/test_torch.py
+++ b/test/model/test_torch.py
@@ -19,6 +19,7 @@
 
 import json
 import sys
+import typing
 from typing import Any, List, Literal, Tuple
 
 import numpy as np
@@ -69,6 +70,9 @@ class FlattenCNNOutput(torch.nn.Module):
         return inputs.view(*shape)
 
 
+Kind = Literal["MLP", "MLP-tune", "MLP-compile", "RNN", "CNN"]
+
+
 class TorchTestCase(ModelTestCase):
     """PyTorch test-case-provider fixture.
 
@@ -88,6 +92,11 @@ class TorchTestCase(ModelTestCase):
         - stack: 32 7x7 conv. filters, then 8x8 max pooling
                  16 5x5 conv. filters, then 8x8 avg pooling
                  1 output neuron with sigmoid activation
+
+
+    Additional scenarios include "MLP-tune", where the MLP's first hidden
+    layer is frozen, and "MLP-compile", where the MLP is optimized using
+    `torch.compile` (available in torch >=2.0 only).
     """
 
     vector_cls = TorchVector
@@ -95,11 +104,11 @@ class TorchTestCase(ModelTestCase):
 
     def __init__(
         self,
-        kind: Literal["MLP", "MLP-tune", "RNN", "CNN"],
+        kind: Kind,
         device: Literal["CPU", "GPU"],
     ) -> None:
         """Specify the desired model architecture."""
-        if kind not in ("MLP", "MLP-tune", "RNN", "CNN"):
+        if kind not in typing.get_args(Kind):
             raise ValueError(f"Invalid torch test architecture: '{kind}'.")
         self.kind = kind
         self.device = device
@@ -169,6 +178,8 @@ class TorchTestCase(ModelTestCase):
                 torch.nn.Sigmoid(),
             ]
         nnmod = torch.nn.Sequential(*stack)
+        if self.kind == "MLP-compile":
+            nnmod = torch.compile(nnmod)  # type: ignore
         return TorchModel(nnmod, loss=torch.nn.BCELoss())
 
     def assert_correct_device(
@@ -184,13 +195,15 @@ class TorchTestCase(ModelTestCase):
 
 @pytest.fixture(name="test_case")
 def fixture_test_case(
-    kind: Literal["MLP", "MLP-tune", "RNN", "CNN"],
+    kind: Kind,
     device: Literal["CPU", "GPU"],
     cpu_only: bool,
 ) -> TorchTestCase:
     """Fixture to access a TorchTestCase."""
     if cpu_only and device == "GPU":
         pytest.skip(reason="--cpu-only mode")
+    if kind == "MLP-compile" and not hasattr(torch, "compile"):
+        pytest.skip(reason="'torch.compile' is unavailable")
     return TorchTestCase(kind, device)
 
 
@@ -200,7 +213,7 @@ if torch.cuda.device_count():
 
 
 @pytest.mark.parametrize("device", DEVICES)
-@pytest.mark.parametrize("kind", ["MLP", "MLP-tune", "RNN", "CNN"])
+@pytest.mark.parametrize("kind", typing.get_args(Kind))
 class TestTorchModel(ModelTestSuite):
     """Unit tests for declearn.model.torch.TorchModel."""