Mentions légales du service

Skip to content
Snippets Groups Projects
Verified Commit 4de359e9 authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Drop 'functorch.compile' use.

- The use of 'functorch.compile' over a function that takes variable-size
  batch inputs proves impossible, as the tracing on first call creates a
  computation graph with fixed dimensions.
- As a result, the tentative compilation of per-sample clipped gradients
  computation prevents the proper use of DP-SGD with the functorch backend.
- An alternative attempt was to compile the sample-wise function and vmap
  it afterwards, but this is currently unsupported (and unlikely to be as
  part of functorch, as development efforts have moved to 'torch.func').
- This commit therefore drops the use of 'functorch.compile'.
parent c65d0b4c
No related branches found
No related tags found
No related merge requests found
Pipeline #852811 passed
...@@ -19,16 +19,8 @@ ...@@ -19,16 +19,8 @@
from typing import List, Tuple from typing import List, Tuple
# fmt: off
import functorch # type: ignore import functorch # type: ignore
try:
import functorch.compile # type: ignore
COMPILE_AVAILABLE = True
except ModuleNotFoundError:
# pragma: no cover
COMPILE_AVAILABLE = False
import torch import torch
# fmt: on
from declearn.model.torch._samplewise.shared import ( from declearn.model.torch._samplewise.shared import (
GetGradientsFunction, GetGradientsFunction,
...@@ -73,12 +65,7 @@ def build_samplewise_grads_fn_backend( ...@@ -73,12 +65,7 @@ def build_samplewise_grads_fn_backend(
# Wrap the former function to compute and clip sample-wise gradients. # Wrap the former function to compute and clip sample-wise gradients.
in_dims = ([0] * inputs, 0 if y_true else None, 0 if s_wght else None) in_dims = ([0] * inputs, 0 if y_true else None, 0 if s_wght else None)
grads_fn = functorch.vmap(grads_fn, in_dims, randomness="same") return functorch.vmap(grads_fn, in_dims, randomness="same")
# Compile the resulting function to decrease runtime costs.
if not COMPILE_AVAILABLE:
# pragma: no cover
return grads_fn
return functorch.compile.aot_function(grads_fn, functorch.compile.nop)
def get_params( def get_params(
......
0% Loading or loading failed.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment