diff --git a/declearn/model/torch/_samplewise/functorch.py b/declearn/model/torch/_samplewise/functorch.py
index d39a3e16602918cd36e755550e1440622efb2866..37ea5ecbe07a716727bada27e05abac153e1bbb6 100644
--- a/declearn/model/torch/_samplewise/functorch.py
+++ b/declearn/model/torch/_samplewise/functorch.py
@@ -19,16 +19,8 @@
 
 from typing import List, Tuple
 
-# fmt: off
 import functorch  # type: ignore
-try:
-    import functorch.compile  # type: ignore
-    COMPILE_AVAILABLE = True
-except ModuleNotFoundError:
-    # pragma: no cover
-    COMPILE_AVAILABLE = False
 import torch
-# fmt: on
 
 from declearn.model.torch._samplewise.shared import (
     GetGradientsFunction,
@@ -73,12 +65,7 @@ def build_samplewise_grads_fn_backend(
 
     # Wrap the former function to compute and clip sample-wise gradients.
     in_dims = ([0] * inputs, 0 if y_true else None, 0 if s_wght else None)
-    grads_fn = functorch.vmap(grads_fn, in_dims, randomness="same")
-    # Compile the resulting function to decrease runtime costs.
-    if not COMPILE_AVAILABLE:
-        # pragma: no cover
-        return grads_fn
-    return functorch.compile.aot_function(grads_fn, functorch.compile.nop)
+    return functorch.vmap(grads_fn, in_dims, randomness="same")
 
 
 def get_params(