diff --git a/declearn/model/torch/_samplewise/functorch.py b/declearn/model/torch/_samplewise/functorch.py
index d39a3e16602918cd36e755550e1440622efb2866..37ea5ecbe07a716727bada27e05abac153e1bbb6 100644
--- a/declearn/model/torch/_samplewise/functorch.py
+++ b/declearn/model/torch/_samplewise/functorch.py
@@ -19,16 +19,8 @@
 
 from typing import List, Tuple
 
-# fmt: off
 import functorch  # type: ignore
-try:
-    import functorch.compile  # type: ignore
-    COMPILE_AVAILABLE = True
-except ModuleNotFoundError:
-    # pragma: no cover
-    COMPILE_AVAILABLE = False
 import torch
-# fmt: on
 
 from declearn.model.torch._samplewise.shared import (
     GetGradientsFunction,
@@ -73,12 +65,7 @@ def build_samplewise_grads_fn_backend(
 
     # Wrap the former function to compute and clip sample-wise gradients.
     in_dims = ([0] * inputs, 0 if y_true else None, 0 if s_wght else None)
-    grads_fn = functorch.vmap(grads_fn, in_dims, randomness="same")
-    # Compile the resulting function to decrease runtime costs.
-    if not COMPILE_AVAILABLE:
-        # pragma: no cover
-        return grads_fn
-    return functorch.compile.aot_function(grads_fn, functorch.compile.nop)
+    return functorch.vmap(grads_fn, in_dims, randomness="same")
 
 
 def get_params(
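
For context, here is a minimal, self-contained sketch (not part of the patch, and not declearn's actual API) of the per-sample gradients pattern that the patch keeps: a single-sample gradients function wrapped with functorch.vmap, using the same in_dims/randomness arguments as above. The toy model, loss and variable names are illustrative only; note also that functorch is deprecated in recent torch releases in favour of torch.func, which exposes an equivalent vmap.

# Illustrative sketch of the vmap pattern retained by this patch; the model,
# loss and names below are assumptions, not declearn code.
import functorch  # type: ignore
import torch

model = torch.nn.Linear(4, 1)
func, params = functorch.make_functional(model)

def grads_fn(inputs, y_true, s_wght):
    """Compute gradients of the (sample-weighted) loss for a single sample."""
    def loss_fn(params):
        loss = torch.nn.functional.mse_loss(func(params, inputs), y_true)
        return loss * s_wght
    return functorch.grad(loss_fn)(params)

# Map over the batch dimension of each argument; randomness="same" makes
# random layers (e.g. dropout) draw identical masks across samples.
in_dims = (0, 0, 0)
batched_grads_fn = functorch.vmap(grads_fn, in_dims, randomness="same")

inputs = torch.randn(8, 4)
y_true = torch.randn(8, 1)
s_wght = torch.ones(8)
grads = batched_grads_fn(inputs, y_true, s_wght)  # per-sample gradients, batch dim 8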