diff --git a/declearn/model/torch/_samplewise/functorch.py b/declearn/model/torch/_samplewise/functorch.py
index fc8e613b77a3e925676001ac05df5bc347e0e1fd..d39a3e16602918cd36e755550e1440622efb2866 100644
--- a/declearn/model/torch/_samplewise/functorch.py
+++ b/declearn/model/torch/_samplewise/functorch.py
@@ -49,11 +49,13 @@ def build_samplewise_grads_fn_backend(
 ) -> GetGradientsFunction:
     """Implementation of `build_samplewise_grads_fn` for Torch 1.1X."""
 
-    func_model, _ = functorch.make_functional(model)
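+    # Use the buffer-aware functional API so that modules holding buffers
+    # (e.g. BatchNorm running statistics) are supported.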
+    func_model, *_ = functorch.make_functional_with_buffers(model)
 
-    def run_forward(inputs, y_true, s_wght, *params):
+    def run_forward(inputs, y_true, s_wght, buffers, *params):
         """Run the forward pass in a functional way."""
-        y_pred = func_model(params, *inputs)
+        y_pred = func_model(params, buffers, *inputs)
         s_loss = loss_fn(y_pred, y_true)
         if s_wght is not None:
             s_loss.mul_(s_wght.to(s_loss.device))
@@ -62,15 +64,21 @@ def build_samplewise_grads_fn_backend(
     def grads_fn(inputs, y_true, s_wght, clip=None):
         """Compute gradients and optionally clip them."""
         params, idxgrd, pnames = get_params(model)
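+        # Collect current buffer values to pass to the functional forward.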
+        buffers = list(model.buffers())
         gfunc = functorch.grad(run_forward, argnums=tuple(idxgrd))
-        grads = gfunc(inputs, y_true, (None if clip else s_wght), *params)
+        grads = gfunc(
+            inputs, y_true, (None if clip else s_wght), buffers, *params
+        )
         if clip:
             clip_and_scale_grads_inplace(grads, clip, s_wght)
         return dict(zip(pnames, grads))
 
     # Wrap the former function to compute and clip sample-wise gradients.
     in_dims = ([0] * inputs, 0 if y_true else None, 0 if s_wght else None)
-    grads_fn = functorch.vmap(grads_fn, in_dims)
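+    # randomness="same" makes random operations (e.g. Dropout) draw the same
+    # values for every sample instead of raising an error under vmap.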
+    grads_fn = functorch.vmap(grads_fn, in_dims, randomness="same")
     # Compile the resulting function to decrease runtime costs.
     if not COMPILE_AVAILABLE:
         # pragma: no cover
@@ -88,6 +96,8 @@
     for idx, (name, param) in enumerate(model.named_parameters()):
         params.append(param)
         if param.requires_grad:
-            idxgrd.append(idx + 3)
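+            # Parameters now follow (inputs, y_true, s_wght, buffers) in
+            # run_forward's signature, hence an argnums offset of 4.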
+            idxgrd.append(idx + 4)
             pnames.append(name)
     return params, idxgrd, pnames
diff --git a/declearn/model/torch/_samplewise/torchfunc.py b/declearn/model/torch/_samplewise/torchfunc.py
index d330d4f9f59647a48198b80b231ac5612ce5a6f0..88aa5b7dca01dcdd77adda76b29d6b080c7487d2 100644
--- a/declearn/model/torch/_samplewise/torchfunc.py
+++ b/declearn/model/torch/_samplewise/torchfunc.py
@@ -40,9 +40,14 @@ def build_samplewise_grads_fn_backend(
 ) -> GetGradientsFunction:
     """Implementation of `build_samplewise_grads_fn` for Torch 2.0."""
 
-    def run_forward(params, frozen, inputs, y_true, s_wght):
+    def run_forward(params, frozen, buffers, inputs, y_true, s_wght):
         """Run the forward pass in a functional way."""
-        y_pred = torch.func.functional_call(model, [params, frozen], *inputs)
+        # backend closure function; pylint: disable=too-many-arguments
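+        # Buffers are passed explicitly alongside parameters so that
+        # buffer-holding modules (e.g. BatchNorm) are handled properly.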
+        y_pred = torch.func.functional_call(
+            model, [params, frozen, buffers], *inputs
+        )
         s_loss = loss_fn(y_pred, y_true)
         if s_wght is not None:
             s_loss.mul_(s_wght.to(s_loss.device))
@@ -53,8 +58,10 @@
     def get_clipped_grads(inputs, y_true, s_wght, clip=None):
         """Compute gradients and optionally clip them."""
         params, frozen = get_params(model)
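+        # Collect current buffer values to pass to the functional call.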
+        buffers = dict(model.named_buffers())
         grads = get_grads(
-            params, frozen, inputs, y_true, None if clip else s_wght
+            params, frozen, buffers, inputs, y_true, None if clip else s_wght
         )
         if clip:
             clip_and_scale_grads_inplace(grads.values(), clip, s_wght)
@@ -62,7 +69,9 @@
 
     # Wrap the former function to compute and clip sample-wise gradients.
     in_dims = ([0] * inputs, 0 if y_true else None, 0 if s_wght else None)
-    return torch.func.vmap(get_clipped_grads, in_dims)
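+    # randomness="same" shares random draws (e.g. Dropout masks) across
+    # vmapped samples; the default would raise on random operations.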
+    return torch.func.vmap(get_clipped_grads, in_dims, randomness="same")
 
 
 def get_params(