From 4de359e9734cc4cc372a2351d8e0ff2d29a0fa3d Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Wed, 6 Sep 2023 10:49:35 +0200 Subject: [PATCH] Drop 'functorch.compile' use. - The use of 'functorch.compile' over a function that takes variable-size batch inputs proves impossible, as the tracing on first call creates a computation graph with fixed dimensions. - As a result, the tentative compilation of per-sample clipped gradients computation prevents the proper use of DP-SGD with the functorch backend. - An alternative attempt was to compile the sample-wise function and vmap it afterwards, but this is currently unsupported (and unlikely to be as part of functorch, as development efforts have moved to 'torch.func'). - This commit therefore drops the use of 'functorch.compile'. --- declearn/model/torch/_samplewise/functorch.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/declearn/model/torch/_samplewise/functorch.py b/declearn/model/torch/_samplewise/functorch.py index d39a3e16..37ea5ecb 100644 --- a/declearn/model/torch/_samplewise/functorch.py +++ b/declearn/model/torch/_samplewise/functorch.py @@ -19,16 +19,8 @@ from typing import List, Tuple -# fmt: off import functorch # type: ignore -try: - import functorch.compile # type: ignore - COMPILE_AVAILABLE = True -except ModuleNotFoundError: - # pragma: no cover - COMPILE_AVAILABLE = False import torch -# fmt: on from declearn.model.torch._samplewise.shared import ( GetGradientsFunction, @@ -73,12 +65,7 @@ def build_samplewise_grads_fn_backend( # Wrap the former function to compute and clip sample-wise gradients. in_dims = ([0] * inputs, 0 if y_true else None, 0 if s_wght else None) - grads_fn = functorch.vmap(grads_fn, in_dims, randomness="same") - # Compile the resulting function to decrease runtime costs. - if not COMPILE_AVAILABLE: - # pragma: no cover - return grads_fn - return functorch.compile.aot_function(grads_fn, functorch.compile.nop) + return functorch.vmap(grads_fn, in_dims, randomness="same") def get_params( -- GitLab