-
- Downloads
Add and document 'torch.compile' support.
- Torch 2.0 introduced `torch.compile`, a novel utilty to optimize computations via their JIT-compiling into optimized kernels. At the moment, compiled modules cannot be saved (but their states, which are shared with the underlying original module, can), and are not compatible with `torch.func` functional execution either. - As declearn vows to support Torch 2.0, it seems crucial that end- users may use `torch.compile`. However, as long as Torch 1.10-13 versions (and previous versions of DecLearn 2.X) are supported, this handling should not break backward compatibility. - An initial approach was to enable compiling the handled module as part of `TorchModel.__init__`. However, this proves impractical, as it takes away the assumption that end-users should be able to use their customly-prepared module as-is - including pre-compiled ones. This is all the more true as their are many options to the `torch.compile` function, that DecLearn has no purpose handling. - Therefore, the approach implemented here is to detect and handle models that were compiled prior to being input into `TorchModel`. - Here are a few notes and caveats regarding the current implementation: - Some impractical steps were taken to ensure weights and gradients have the same name regardless of whether the module was compiled or not, _and regardless of the specific 2.x version of declearn_. When we move to declearn 3, it will be worth revising. - A positive consequence of the previous note is that the compilation of the module should not impair cases where some clients are using torch 1 and/or an older 2.x version of declearn. - The will to retain user-defined compilation option is mostly lost due to the current lack of recording of these info by torch. This is however expected to evolve in the future, which should enable sharing instructions with clients. See issue 101107 of pytorch: https://github.com/pytorch/pytorch/issues/101107 - A clumsy bugfix was introduced to avoid an issue where the wrapped compiled model would not take weights updates into account when running in evaluation mode. The status of this hack should be kept under close look as the issue I opened to report the bug is treated: https://github.com/pytorch/pytorch/issues/104984
Loading
Please register or sign in to comment