Mentions légales du service

Skip to content

Refactor 'TorchModel' backend code to compute sample-wise gradients.

ANDREY Paul requested to merge functorch into develop

This MR introduces some backend changes to TorchModel:

  • Cache the functorch-built vmapped functions that compute (and opt. clip) sample-wise gradients.
  • When available, use functorch.compile to try and optimize these functions (resulting in lower execution runtimes on repeated calls thanks to their caching).
  • Note: with the release of torch-2.0, the functorch.compile API, which is experimental as of torch-1.13, is already deprecated. Hence this patch is going to be short-lived - but it may in turn pave the way to the future backend changes to add support for and take advantage of new torch-2.0 features.

Note: transitioning to torch 2.0 (optionally keeping support for torch 1.X at first) should be considered in the (near) future, but will require refactoring all of the functorch code. For the time being, this refactoring provides with some code optimization for (func)torch 1.13.

Merge request reports