Revise TorchModel to support Torch 2.0.
- Torch 2.0 was released in march 2023, introducing a number of new features to the torch ecosystem, while being non-breaking wrt past versions on most points. - One of the salient changes is the introduction of `torch.func`, that integrates features previously introduced as part of the `functorch` package, which we have been relying upon in order to efficiently compute and clip sample-wise gradients. - This commit therefore introduces a new backend code branch so as to make use of `torch.func` when `torch~=2.0` is used, and retain the existing `functorch`-based branch when older torch versions are used. - As part of this effort, some of the `TorchModel` backend code was refactored, notably deferring the functional transform of the input `torch.nn.Module` to the first call for sample-wise gradients computation. This has the nice side effect of fixing cases when users want to train a non-functorch-compatible model without DP-SGD.
parent
9f0035d7
No related branches found
No related tags found
Showing
- declearn/model/torch/_model.py 26 additions, 76 deletionsdeclearn/model/torch/_model.py
- declearn/model/torch/_samplewise/__init__.py 78 additions, 0 deletionsdeclearn/model/torch/_samplewise/__init__.py
- declearn/model/torch/_samplewise/functorch.py 93 additions, 0 deletionsdeclearn/model/torch/_samplewise/functorch.py
- declearn/model/torch/_samplewise/shared.py 56 additions, 0 deletionsdeclearn/model/torch/_samplewise/shared.py
- declearn/model/torch/_samplewise/torchfunc.py 76 additions, 0 deletionsdeclearn/model/torch/_samplewise/torchfunc.py
declearn/model/torch/_samplewise/__init__.py
0 → 100644
declearn/model/torch/_samplewise/shared.py
0 → 100644
Please register or sign in to comment