Add support for Torch 2.0
Torch 2.0 has been out for a couple of months now, and brings the promise of faster model training thanks to its new compilation engine, in addition to being the basis onto which future features will be added. Overall, Torch 2.0 preserves compatibility with previous versions, meaning that most of the declearn codebase is natively compatible with 2.0 simply by being compatible with the 1.10-1.13 versions. This is, however, not the case for the differential-privacy-oriented code, which has been relying on `functorch`, now deprecated in favor of `torch.func`. The latter builds on the same logic and preserves most function names and signatures, but still introduces some breaking changes.
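As an illustration, here is a minimal sketch of the kind of migration involved (the linear model and MSE loss are arbitrary placeholders, not declearn code): `functorch.make_functional` has no direct counterpart in `torch.func`, which instead relies on `torch.func.functional_call` together with an explicit parameters dict.

```python
"""Minimal sketch of the functorch -> torch.func migration (illustrative only)."""

import torch

model = torch.nn.Linear(4, 1)
inputs, targets = torch.randn(8, 4), torch.randn(8, 1)

if hasattr(torch, "func"):  # torch >= 2.0
    # 'make_functional' is gone; use 'functional_call' with an explicit params dict.
    params = dict(model.named_parameters())

    def loss_fn(params, inputs, targets):
        preds = torch.func.functional_call(model, params, (inputs,))
        return torch.nn.functional.mse_loss(preds, targets)

    grads = torch.func.grad(loss_fn)(params, inputs, targets)
else:  # torch 1.11-1.13, with functorch installed
    import functorch

    fmodel, params = functorch.make_functional(model)

    def loss_fn(params, inputs, targets):
        preds = fmodel(params, inputs)
        return torch.nn.functional.mse_loss(preds, targets)

    grads = functorch.grad(loss_fn)(params, inputs, targets)
```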
Hence, this issue is about adding support for Torch 2.0, while retaining support for versions 1.10-1.13 (or is it really 1.11-1.13, due to functorch?), since 2.0 is very young and these versions are still quite recent (1.10 was released in October 2021, 1.11 in March 2022).
Current progress:
- Write an alternative `TorchModel` for Torch 2.0.
- Refactor backend code to shrink and isolate version-dependent parts.
- Enable dynamically (and quietly) loading whichever version should be used (see the first sketch after this list).
- Specify two alternative dependency specifications (torch 1 & functorch | torch 2) in `pyproject.toml`.
- Handle torch-version-based linter issues, with as little blanket silencing as possible.
- Enable testing both cases in the CI/CD, for proper maintenance.
- Optimize models using `torch.compile` in Torch 2.0 (see the second sketch after this list).
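Regarding the dynamic-loading point, the snippet below is a minimal sketch of selecting a version-dependent backend module based on the installed torch release; the module names `_torchmodel_v1` and `_torchmodel_v2` are hypothetical placeholders, not actual declearn modules.

```python
"""Minimal sketch: quietly select a backend module based on the torch version."""

import importlib

import torch


def select_backend_module():
    """Import and return the backend matching the installed torch major version."""
    major = int(torch.__version__.split(".")[0])
    name = "_torchmodel_v2" if major >= 2 else "_torchmodel_v1"  # hypothetical names
    return importlib.import_module(name)
```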
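As for the `torch.compile` point, a minimal sketch of opting into compilation only when running under Torch 2.0 (the `maybe_compile` helper is illustrative, not part of declearn):

```python
"""Minimal sketch: opt into torch.compile when it is available (torch >= 2.0)."""

import torch


def maybe_compile(module: torch.nn.Module) -> torch.nn.Module:
    """Return a compiled version of the module when torch.compile exists."""
    if hasattr(torch, "compile"):  # 'torch.compile' was introduced in Torch 2.0
        return torch.compile(module)
    return module


model = maybe_compile(torch.nn.Linear(4, 1))
output = model(torch.randn(8, 4))  # first call triggers compilation under 2.0
```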