From 4de359e9734cc4cc372a2351d8e0ff2d29a0fa3d Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Wed, 6 Sep 2023 10:49:35 +0200
Subject: [PATCH] Drop 'functorch.compile' use.

- Using 'functorch.compile' on a function that takes variable-size batch
  inputs proves unworkable, as tracing on the first call creates a
  computation graph with fixed input dimensions (see the sketch below).
- As a result, the attempted compilation of the per-sample clipped
  gradients' computation prevents the proper use of DP-SGD with the
  functorch backend, since DP-SGD relies on Poisson sampling, which
  yields batches of varying size.
- An alternative attempt was to compile the sample-wise function and vmap
  it afterwards, but this is currently unsupported (and unlikely ever to
  be supported in functorch, as development efforts have moved to
  'torch.func').
- This commit therefore drops the use of 'functorch.compile'.
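
A minimal sketch of the failure mode (hypothetical 'sum_fn' and input
shapes, not the actual declearn code; exact behaviour may vary across
functorch versions):

    import functorch
    import functorch.compile
    import torch

    def sum_fn(inputs: torch.Tensor) -> torch.Tensor:
        # Stand-in for the per-sample clipped-gradients computation.
        return inputs.sum(dim=-1)

    # 'aot_function' traces 'sum_fn' on its first call, recording a
    # graph whose input dimensions are fixed.
    fn = functorch.compile.aot_function(sum_fn, functorch.compile.nop)
    fn(torch.ones(32, 8))  # first call: graph traced for batch size 32
    fn(torch.ones(16, 8))  # variable-size batch: fails or silently
                           # reuses the stale fixed-shape graph

    # Alternative attempt: vmap over a compiled sample-wise function.
    # This is expected to raise, as vmap does not support aot-compiled
    # callables.
    per_sample = functorch.compile.aot_function(
        sum_fn, functorch.compile.nop
    )
    functorch.vmap(per_sample)(torch.ones(16, 8))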
---
 declearn/model/torch/_samplewise/functorch.py | 15 +--------------
 1 file changed, 1 insertion(+), 14 deletions(-)

diff --git a/declearn/model/torch/_samplewise/functorch.py b/declearn/model/torch/_samplewise/functorch.py
index d39a3e16..37ea5ecb 100644
--- a/declearn/model/torch/_samplewise/functorch.py
+++ b/declearn/model/torch/_samplewise/functorch.py
@@ -19,16 +19,8 @@
 
 from typing import List, Tuple
 
-# fmt: off
 import functorch  # type: ignore
-try:
-    import functorch.compile  # type: ignore
-    COMPILE_AVAILABLE = True
-except ModuleNotFoundError:
-    # pragma: no cover
-    COMPILE_AVAILABLE = False
 import torch
-# fmt: on
 
 from declearn.model.torch._samplewise.shared import (
     GetGradientsFunction,
@@ -73,12 +65,7 @@ def build_samplewise_grads_fn_backend(
 
     # Wrap the former function to compute and clip sample-wise gradients.
     in_dims = ([0] * inputs, 0 if y_true else None, 0 if s_wght else None)
-    grads_fn = functorch.vmap(grads_fn, in_dims, randomness="same")
-    # Compile the resulting function to decrease runtime costs.
-    if not COMPILE_AVAILABLE:
-        # pragma: no cover
-        return grads_fn
-    return functorch.compile.aot_function(grads_fn, functorch.compile.nop)
+    return functorch.vmap(grads_fn, in_dims, randomness="same")
 
 
 def get_params(
-- 
GitLab