ANDREY Paul authored
- Torch 2.0 was released in March 2023, introducing a number of new features to the torch ecosystem while remaining non-breaking with respect to past versions on most points.
- One of the salient changes is the introduction of `torch.func`, which integrates features previously provided by the `functorch` package, on which we have been relying to efficiently compute and clip sample-wise gradients.
- This commit therefore introduces a new backend code branch that uses `torch.func` when `torch~=2.0` is installed, and retains the existing `functorch`-based branch for older torch versions.
- As part of this effort, some of the `TorchModel` backend code was refactored, notably deferring the functional transform of the input `torch.nn.Module` until the first call to compute sample-wise gradients. This has the nice side effect of fixing cases where users want to train a non-functorch-compatible model without DP-SGD.
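The version-based branching described above could be sketched as follows. This is a minimal illustration, not the actual backend code: the helper names (`parse_torch_version`, `select_backend`) and the returned labels are hypothetical, and real code would import from `torch.func` or `functorch` rather than return a string.

```python
def parse_torch_version(version: str) -> tuple:
    """Parse a torch version string into a comparable (major, minor) tuple.

    Strips any local or build suffix, e.g. "2.0.1+cu118" -> (2, 0).
    """
    base = version.split("+")[0]
    parts = base.split(".")
    return (int(parts[0]), int(parts[1]))


def select_backend(version: str) -> str:
    """Return which sample-wise-gradients backend branch to use.

    Hypothetical helper mirroring the branch described in the commit:
    `torch.func` for torch 2.0 and above, `functorch` otherwise.
    """
    if parse_torch_version(version) >= (2, 0):
        return "torch.func"
    return "functorch"
```

For example, `select_backend("2.0.1+cu118")` yields `"torch.func"`, while `select_backend("1.13.1")` yields `"functorch"`.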
functorch.py 3.07 KiB