Mentions légales du service

Skip to content

Add support for Torch 2.0

ANDREY Paul requested to merge 27-add-support-for-torch-2-0 into develop

This MR tackles issue #27 (closed), adding support for Torch 2.0 in declearn, while also keeping support for Torch 1.10-1.13.

  • Most TorchModel operations remain exactly the same.
  • The only part that requires different backend code branches is the computation of sample-wise-clipped gradients: for torch 1.10-1.13 we keep on relying on functorch, while for torch 2.X the new torch.func submodule (in favor of which functorch is being deprecated) is used.
  • Agents using heterogeneous versions of Torch should be able to co-operate smoothly, with no additional information exchange to do so.
  • The optional use of torch.compile to (hopefully) optimize runtime is enabled via a new torch_compile: bool = False argument to TorchModel.__init__.
  • By default, pip install declearn[all] and pip install declearn[torch] will result in installing the latest 2.X version (and latest functorch and, optionally, opacus versions). By contrast, explicit [torch1] and [torch2] dependency specifiers are now provided to hopefully solve some common co-dependency issues between torch, functorch and opacus. This is documented as part of the package installation guide.
  • A consequence of the former point is that the CI/CD pipeline will by default run with torch 2.X, which makes sense as it is bound to become increasingly used by end-users, and is also mandatory for linters to analyze the code properly.
  • Tests were added to cover the 1.1X backend branch as part of CI/CD.
  • The former was achieved through an effort to refactor the tox and CI/CD configuration files, which goes beyond the initial focus of this MR.
  • Crude support for torch.compile was implemented. The following PyTorch issues should be watched:

Closes #27 (closed)

Edited by ANDREY Paul

Merge request reports