Refactor 'TorchModel' backend code to compute sample-wise gradients.
This MR introduces some backend changes to TorchModel
:
- Cache the functorch-built vmapped functions that compute (and opt. clip) sample-wise gradients.
- When available, use
functorch.compile
to try and optimize these functions (resulting in lower execution runtimes on repeated calls thanks to their caching). - Note: with the release of torch-2.0, the
functorch.compile
API, which is experimental as of torch-1.13, is already deprecated. Hence this patch is going to be short-lived - but it may in turn pave the way to the future backend changes to add support for and take advantage of new torch-2.0 features.
Note: transitioning to torch 2.0 (optionally keeping support for torch 1.X at first) should be considered in the (near) future, but will require refactoring all of the functorch code. For the time being, this refactoring provides with some code optimization for (func)torch 1.13.