Drop 'functorch.compile' use.
- Using 'functorch.compile' on a function that takes variable-size batch inputs proves impossible, as tracing on the first call creates a computation graph with fixed dimensions.
- As a result, the tentative compilation of the per-sample clipped gradients computation prevents the proper use of DP-SGD with the functorch backend.
- An alternative attempt was to compile the sample-wise function and vmap it afterwards, but this is currently unsupported (and unlikely to ever be supported in functorch, as development efforts have moved to 'torch.func').
- This commit therefore drops the use of 'functorch.compile'.
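
The failure mode described above can be illustrated with a toy sketch in plain Python. It does not use functorch: 'trace_on_first_call' is a hypothetical stand-in for a tracing compiler that specializes on the first batch size it sees, and 'clip_and_sum' is an assumed, simplified version of per-sample gradient clipping (lists of floats instead of tensors).

```python
import math


def clip_and_sum(per_sample_grads, max_norm):
    """Clip each per-sample gradient to an L2 norm of `max_norm`, then sum them."""
    total = [0.0] * len(per_sample_grads[0])
    for grad in per_sample_grads:
        norm = math.sqrt(sum(g * g for g in grad))
        # Scale down only the gradients whose norm exceeds the threshold.
        scale = min(1.0, max_norm / norm) if norm > 0 else 1.0
        for i, g in enumerate(grad):
            total[i] += g * scale
    return total


def trace_on_first_call(fn):
    """Hypothetical stand-in for a tracing compiler (NOT the functorch API).

    It records the batch size seen on the first call and rejects any
    later call that uses a different one.
    """
    traced = {}

    def wrapped(per_sample_grads, max_norm):
        size = len(per_sample_grads)
        traced.setdefault("batch_size", size)
        if size != traced["batch_size"]:
            raise ValueError(
                f"graph traced for batch size {traced['batch_size']}, got {size}"
            )
        return fn(per_sample_grads, max_norm)

    return wrapped


compiled = trace_on_first_call(clip_and_sum)
full_batch = [[3.0, 4.0], [0.3, 0.4]]  # two samples
last_batch = [[6.0, 8.0]]              # smaller final batch

compiled(full_batch, 1.0)              # first call fixes the batch size at 2
try:
    compiled(last_batch, 1.0)          # a different batch size is rejected
except ValueError as exc:
    print(f"compiled version failed: {exc}")

print(clip_and_sum(last_batch, 1.0))   # the uncompiled function still works
```

In this sketch the uncompiled function handles both batches, which mirrors the choice made here: giving up compilation in exchange for correct behaviour on batches of any size.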