Mentions légales du service

Skip to content
Snippets Groups Projects

Refactor 'TorchModel' backend code to compute sample-wise gradients.

Merged ANDREY Paul requested to merge functorch into develop
All threads resolved!

This MR introduces some backend changes to TorchModel:

  • Cache the functorch-built vmapped functions that compute (and opt. clip) sample-wise gradients.
  • When available, use functorch.compile to try and optimize these functions (resulting in lower execution runtimes on repeated calls thanks to their caching).
  • Note: with the release of torch-2.0, the functorch.compile API, which is experimental as of torch-1.13, is already deprecated. Hence this patch is going to be short-lived - but it may in turn pave the way to the future backend changes to add support for and take advantage of new torch-2.0 features.

Note: transitioning to torch 2.0 (optionally keeping support for torch 1.X at first) should be considered in the (near) future, but will require refactoring all of the functorch code. For the time being, this refactoring provides with some code optimization for (func)torch 1.13.

Merge request reports

Loading
Loading

Activity

Filter activity
  • Approvals
  • Assignees & reviewers
  • Comments (from bots)
  • Comments (from users)
  • Commits & branches
  • Edits
  • Labels
  • Lock status
  • Mentions
  • Merge request status
  • Tracking
Please register or sign in to reply
Loading