From 7468f705fb2dd4d020daf0586f85f09d3443582c Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Fri, 26 May 2023 10:44:55 +0200 Subject: [PATCH] Add and document 'torch.compile' support. - Torch 2.0 introduced `torch.compile`, a novel utilty to optimize computations via their JIT-compiling into optimized kernels. At the moment, compiled modules cannot be saved (but their states, which are shared with the underlying original module, can), and are not compatible with `torch.func` functional execution either. - As declearn vows to support Torch 2.0, it seems crucial that end- users may use `torch.compile`. However, as long as Torch 1.10-13 versions (and previous versions of DecLearn 2.X) are supported, this handling should not break backward compatibility. - An initial approach was to enable compiling the handled module as part of `TorchModel.__init__`. However, this proves impractical, as it takes away the assumption that end-users should be able to use their customly-prepared module as-is - including pre-compiled ones. This is all the more true as their are many options to the `torch.compile` function, that DecLearn has no purpose handling. - Therefore, the approach implemented here is to detect and handle models that were compiled prior to being input into `TorchModel`. - Here are a few notes and caveats regarding the current implementation: - Some impractical steps were taken to ensure weights and gradients have the same name regardless of whether the module was compiled or not, _and regardless of the specific 2.x version of declearn_. When we move to declearn 3, it will be worth revising. - A positive consequence of the previous note is that the compilation of the module should not impair cases where some clients are using torch 1 and/or an older 2.x version of declearn. - The will to retain user-defined compilation option is mostly lost due to the current lack of recording of these info by torch. This is however expected to evolve in the future, which should enable sharing instructions with clients. See issue 101107 of pytorch: https://github.com/pytorch/pytorch/issues/101107 - A clumsy bugfix was introduced to avoid an issue where the wrapped compiled model would not take weights updates into account when running in evaluation mode. The status of this hack should be kept under close look as the issue I opened to report the bug is treated: https://github.com/pytorch/pytorch/issues/104984 --- declearn/model/torch/_model.py | 79 +++++++++++++++++++++++++++------- test/model/test_torch.py | 21 +++++++-- 2 files changed, 81 insertions(+), 19 deletions(-) diff --git a/declearn/model/torch/_model.py b/declearn/model/torch/_model.py index 8eefac9c..21ce772a 100644 --- a/declearn/model/torch/_model.py +++ b/declearn/model/torch/_model.py @@ -71,6 +71,25 @@ class TorchModel(Model): `update_device_policy` method. - You may consult the device policy currently enforced by a TorchModel instance by accessing its `device_policy` property. + + Notes regarding `torch.compile` support (torch >=2.0): + + - If you want the wrapped model to be optimized via `torch.compile`, it + should be so _prior_ to being wrapped using `TorchModel`. + - The compilation will not be used when computing sample-wise-clipped + gradients, as `torch.func` and `torch.compile` do not play along yet. + - The information that the module was compiled will be saved as part of + the `TorchModel` config, so that using `TorchModel.from_config` will + trigger it again when possible; this is however limited to calling + `torch.compile`, meaning that any other argument will be lost. + - Note that the former point notably affects the way clients will run + a server-emitted `TorchModel` as part of a FL process: client that + run Torch 1.X will be able to use the un-optimized module, while + clients running Torch 2.0 will use compilation, but in a rather crude + flavor, that may not be suitable for some specific/advanced cases. + - Enhanced support for `torch.compile` is on the roadmap. If you run + into issues and/or have requests or advice on that topic, feel free + to let us know by contacting us via mail or GitLab. """ def __init__( @@ -97,13 +116,19 @@ class TorchModel(Model): # Select the device where to place computations, and wrap the model. policy = get_device_policy() device = select_device(gpu=policy.gpu, idx=policy.idx) - model = AutoDeviceModule(model, device=device) - super().__init__(model) + super().__init__(AutoDeviceModule(model, device=device)) # Assign loss module and set it not to reduce sample-wise values. if not isinstance(loss, torch.nn.Module): raise TypeError("'loss' should be a torch.nn.Module instance.") loss.reduction = "none" # type: ignore self._loss_fn = AutoDeviceModule(loss, device=device) + # Detect torch-compiled models and extract underlying module. + self._raw_model = self._model + if hasattr(torch, "compile") and hasattr(model, "_orig_mod"): + self._raw_model = AutoDeviceModule( + module=getattr(model, "_orig_mod"), + device=self._model.device, + ) @property def device_policy( @@ -131,12 +156,16 @@ class TorchModel(Model): "PyTorch JSON serialization relies on pickle, which may be unsafe." ) with io.BytesIO() as buffer: - torch.save(self._model.module, buffer) + torch.save(self._raw_model.module, buffer) model = buffer.getbuffer().hex() with io.BytesIO() as buffer: torch.save(self._loss_fn.module, buffer) loss = buffer.getbuffer().hex() - return {"model": model, "loss": loss} + return { + "model": model, + "loss": loss, + "compile": self._raw_model is not self._model, + } @classmethod def from_config( @@ -148,13 +177,15 @@ class TorchModel(Model): model = torch.load(buffer) with io.BytesIO(bytes.fromhex(config["loss"])) as buffer: loss = torch.load(buffer) + if config.get("compile", False) and hasattr(torch, "compile"): + model = torch.compile(model) return cls(model=model, loss=loss) def get_weights( self, trainable: bool = False, ) -> TorchVector: - params = self._model.named_parameters() + params = self._raw_model.named_parameters() if trainable: weights = {k: p.data for k, p in params if p.requires_grad} else: @@ -171,12 +202,12 @@ class TorchModel(Model): raise TypeError("TorchModel requires TorchVector weights.") self._verify_weights_compatibility(weights, trainable=trainable) if trainable: - state_dict = self._model.state_dict() + state_dict = self._raw_model.state_dict() state_dict.update(weights.coefs) else: state_dict = weights.coefs # NOTE: this preserves the device placement of current states - self._model.load_state_dict(state_dict) + self._raw_model.load_state_dict(state_dict) def _verify_weights_compatibility( self, @@ -200,12 +231,9 @@ class TorchModel(Model): In case some expected keys are missing, or additional keys are present. Be verbose about the identified mismatch(es). """ + params = self._raw_model.named_parameters() received = set(vector.coefs) - expected = { - name - for name, param in self._model.named_parameters() - if (not trainable) or param.requires_grad - } + expected = {n for n, p in params if (not trainable) or p.requires_grad} raise_on_stringsets_mismatch( received, expected, context="model weights" ) @@ -235,7 +263,7 @@ class TorchModel(Model): # Collect weights' gradients and return them in a Vector container. grads = { k: p.grad.detach().clone() - for k, p in self._model.named_parameters() + for k, p in self._raw_model.named_parameters() if p.requires_grad } return TorchVector(grads) @@ -324,8 +352,9 @@ class TorchModel(Model): enable optimizing operations using either `functorch` for torch 1.1X or `torch.func` for torch 2.X. """ + # NOTE: torch.func is not compatible with torch.compile yet return build_samplewise_grads_fn( - self._model, self._loss_fn, inputs, y_true, s_wght + self._raw_model, self._loss_fn, inputs, y_true, s_wght ) def apply_updates( @@ -337,7 +366,7 @@ class TorchModel(Model): self._verify_weights_compatibility(updates, trainable=True) with torch.no_grad(): for key, upd in updates.coefs.items(): - tns = self._model.get_parameter(key) + tns = self._raw_model.get_parameter(key) tns.add_(upd.to(tns.device)) def compute_batch_predictions( @@ -353,12 +382,32 @@ class TorchModel(Model): "creating labels from the base inputs." ) self._model.eval() + self._handle_torch_compile_eval_issue(inputs) with torch.no_grad(): y_pred = self._model(*inputs).cpu().numpy() y_true = y_true.cpu().numpy() s_wght = None if s_wght is None else s_wght.cpu().numpy() return y_true, y_pred, s_wght # type: ignore + def _handle_torch_compile_eval_issue( + self, + inputs: List[torch.Tensor], + ) -> None: + """Clumsily handle issues with `torch.compile` and `torch.no_grad`. + + As of Torch 2.0.1, running a compiled model's first forward pass + within a `torch.no_grad` context results in the model's future + weights updates not being properly taken into account. + + Therefore, when wrapping a compiled model, this method runs a lost + forward pass outside of a no-grad context on its first call (later + it does nothing). + """ + if (self._raw_model is self._model) or hasattr(self, "__eval_called"): + return + self._model(*inputs) + setattr(self, "__eval_called", True) + def loss_function( self, y_true: np.ndarray, diff --git a/test/model/test_torch.py b/test/model/test_torch.py index dafe8c8c..498b87a7 100644 --- a/test/model/test_torch.py +++ b/test/model/test_torch.py @@ -19,6 +19,7 @@ import json import sys +import typing from typing import Any, List, Literal, Tuple import numpy as np @@ -69,6 +70,9 @@ class FlattenCNNOutput(torch.nn.Module): return inputs.view(*shape) +Kind = Literal["MLP", "MLP-tune", "MLP-compile", "RNN", "CNN"] + + class TorchTestCase(ModelTestCase): """PyTorch test-case-provider fixture. @@ -88,6 +92,11 @@ class TorchTestCase(ModelTestCase): - stack: 32 7x7 conv. filters, then 8x8 max pooling 16 5x5 conv. filters, then 8x8 avg pooling 1 output neuron with sigmoid activation + + + Additional scenarios include "MLP-tune", where the MLP's first hidden + layer is frozen, and "MLP-compile", where the MLP is optimized using + `torch.compile` (available in torch >=2.0 only). """ vector_cls = TorchVector @@ -95,11 +104,11 @@ class TorchTestCase(ModelTestCase): def __init__( self, - kind: Literal["MLP", "MLP-tune", "RNN", "CNN"], + kind: Kind, device: Literal["CPU", "GPU"], ) -> None: """Specify the desired model architecture.""" - if kind not in ("MLP", "MLP-tune", "RNN", "CNN"): + if kind not in typing.get_args(Kind): raise ValueError(f"Invalid torch test architecture: '{kind}'.") self.kind = kind self.device = device @@ -169,6 +178,8 @@ class TorchTestCase(ModelTestCase): torch.nn.Sigmoid(), ] nnmod = torch.nn.Sequential(*stack) + if self.kind == "MLP-compile": + nnmod = torch.compile(nnmod) # type: ignore return TorchModel(nnmod, loss=torch.nn.BCELoss()) def assert_correct_device( @@ -184,13 +195,15 @@ class TorchTestCase(ModelTestCase): @pytest.fixture(name="test_case") def fixture_test_case( - kind: Literal["MLP", "MLP-tune", "RNN", "CNN"], + kind: Kind, device: Literal["CPU", "GPU"], cpu_only: bool, ) -> TorchTestCase: """Fixture to access a TorchTestCase.""" if cpu_only and device == "GPU": pytest.skip(reason="--cpu-only mode") + if kind == "MLP-compile" and not hasattr(torch, "compile"): + pytest.skip(reason="'torch.compile' is unavailable") return TorchTestCase(kind, device) @@ -200,7 +213,7 @@ if torch.cuda.device_count(): @pytest.mark.parametrize("device", DEVICES) -@pytest.mark.parametrize("kind", ["MLP", "MLP-tune", "RNN", "CNN"]) +@pytest.mark.parametrize("kind", typing.get_args(Kind)) class TestTorchModel(ModelTestSuite): """Unit tests for declearn.model.torch.TorchModel.""" -- GitLab