diff --git a/declearn/model/torch/_model.py b/declearn/model/torch/_model.py
index 8eefac9cdb0b111f488c29de373bf3f9b7b911dc..21ce772adf4d606c9da584fcb9376fa3e353af6d 100644
--- a/declearn/model/torch/_model.py
+++ b/declearn/model/torch/_model.py
@@ -71,6 +71,25 @@ class TorchModel(Model):
       `update_device_policy` method.
     - You may consult the device policy currently enforced by a TorchModel
       instance by accessing its `device_policy` property.
+
+    Notes regarding `torch.compile` support (torch >=2.0):
+
+    - If you want the wrapped model to be optimized via `torch.compile`, it
+      should be compiled _prior_ to being wrapped using `TorchModel`.
+    - The compilation will not be used when computing sample-wise-clipped
+      gradients, as `torch.func` and `torch.compile` do not play along yet.
+    - The information that the module was compiled is saved as part of
+      the `TorchModel` config, so that `TorchModel.from_config` will
+      trigger compilation again when possible; this is however limited
+      to calling `torch.compile` without any of its optional arguments.
+    - Note that the previous point notably affects the way clients run
+      a server-emitted `TorchModel` as part of an FL process: clients
+      running Torch 1.X will use the un-optimized module, while clients
+      running Torch 2.0 will use compilation, albeit in a rather crude
+      flavor that may not suit some specific/advanced use cases.
+    - Enhanced support for `torch.compile` is on the roadmap. If you run
+      into issues and/or have requests or advice on that topic, feel free
+      to let us know by contacting us via mail or GitLab.
     """
 
     def __init__(
@@ -97,13 +116,19 @@ class TorchModel(Model):
         # Select the device where to place computations, and wrap the model.
         policy = get_device_policy()
         device = select_device(gpu=policy.gpu, idx=policy.idx)
-        model = AutoDeviceModule(model, device=device)
-        super().__init__(model)
+        super().__init__(AutoDeviceModule(model, device=device))
         # Assign loss module and set it not to reduce sample-wise values.
         if not isinstance(loss, torch.nn.Module):
             raise TypeError("'loss' should be a torch.nn.Module instance.")
         loss.reduction = "none"  # type: ignore
         self._loss_fn = AutoDeviceModule(loss, device=device)
+        # Detect torch-compiled models and extract the underlying module.
+        self._raw_model = self._model
+        if hasattr(torch, "compile") and hasattr(model, "_orig_mod"):
+            self._raw_model = AutoDeviceModule(
+                module=getattr(model, "_orig_mod"),
+                device=self._model.device,
+            )
 
     @property
     def device_policy(
@@ -131,12 +156,16 @@ class TorchModel(Model):
             "PyTorch JSON serialization relies on pickle, which may be unsafe."
         )
         with io.BytesIO() as buffer:
-            torch.save(self._model.module, buffer)
+            torch.save(self._raw_model.module, buffer)
             model = buffer.getbuffer().hex()
         with io.BytesIO() as buffer:
             torch.save(self._loss_fn.module, buffer)
             loss = buffer.getbuffer().hex()
-        return {"model": model, "loss": loss}
+        return {
+            "model": model,
+            "loss": loss,
+            "compile": self._raw_model is not self._model,
+        }
 
     @classmethod
     def from_config(
@@ -148,13 +177,15 @@ class TorchModel(Model):
             model = torch.load(buffer)
         with io.BytesIO(bytes.fromhex(config["loss"])) as buffer:
             loss = torch.load(buffer)
+        if config.get("compile", False) and hasattr(torch, "compile"):
+            model = torch.compile(model)
         return cls(model=model, loss=loss)
 
     def get_weights(
         self,
         trainable: bool = False,
     ) -> TorchVector:
-        params = self._model.named_parameters()
+        params = self._raw_model.named_parameters()
         if trainable:
             weights = {k: p.data for k, p in params if p.requires_grad}
         else:
@@ -171,12 +202,12 @@ class TorchModel(Model):
             raise TypeError("TorchModel requires TorchVector weights.")
         self._verify_weights_compatibility(weights, trainable=trainable)
         if trainable:
-            state_dict = self._model.state_dict()
+            state_dict = self._raw_model.state_dict()
             state_dict.update(weights.coefs)
         else:
             state_dict = weights.coefs
         # NOTE: this preserves the device placement of current states
-        self._model.load_state_dict(state_dict)
+        self._raw_model.load_state_dict(state_dict)
 
     def _verify_weights_compatibility(
         self,
@@ -200,12 +231,9 @@ class TorchModel(Model):
         In case some expected keys are missing, or additional keys
         are present. Be verbose about the identified mismatch(es).
         """
+        params = self._raw_model.named_parameters()
         received = set(vector.coefs)
-        expected = {
-            name
-            for name, param in self._model.named_parameters()
-            if (not trainable) or param.requires_grad
-        }
+        expected = {n for n, p in params if (not trainable) or p.requires_grad}
         raise_on_stringsets_mismatch(
             received, expected, context="model weights"
         )
@@ -235,7 +263,7 @@ class TorchModel(Model):
         # Collect weights' gradients and return them in a Vector container.
         grads = {
             k: p.grad.detach().clone()
-            for k, p in self._model.named_parameters()
+            for k, p in self._raw_model.named_parameters()
             if p.requires_grad
         }
         return TorchVector(grads)
@@ -324,8 +352,9 @@ class TorchModel(Model):
         enable optimizing operations using either `functorch` for torch
         1.1X or `torch.func` for torch 2.X.
         """
+        # NOTE: torch.func is not compatible with torch.compile yet.
         return build_samplewise_grads_fn(
-            self._model, self._loss_fn, inputs, y_true, s_wght
+            self._raw_model, self._loss_fn, inputs, y_true, s_wght
         )
 
     def apply_updates(
@@ -337,7 +366,7 @@ class TorchModel(Model):
         self._verify_weights_compatibility(updates, trainable=True)
         with torch.no_grad():
             for key, upd in updates.coefs.items():
-                tns = self._model.get_parameter(key)
+                tns = self._raw_model.get_parameter(key)
                 tns.add_(upd.to(tns.device))
 
     def compute_batch_predictions(
@@ -353,12 +382,32 @@ class TorchModel(Model):
                 "creating labels from the base inputs."
             )
         self._model.eval()
+        self._handle_torch_compile_eval_issue(inputs)
         with torch.no_grad():
             y_pred = self._model(*inputs).cpu().numpy()
         y_true = y_true.cpu().numpy()
         s_wght = None if s_wght is None else s_wght.cpu().numpy()
         return y_true, y_pred, s_wght  # type: ignore
 
+    def _handle_torch_compile_eval_issue(
+        self,
+        inputs: List[torch.Tensor],
+    ) -> None:
+        """Clumsily handle issues with `torch.compile` and `torch.no_grad`.
+
+        As of Torch 2.0.1, running a compiled model's first forward pass
+        within a `torch.no_grad` context results in the model's future
+        weight updates not being properly taken into account.
+
+        Therefore, when wrapping a compiled model, this method runs a
+        throwaway forward pass outside of a no-grad context on its first
+        call (and does nothing on subsequent calls).
+        """
+        if (self._raw_model is self._model) or hasattr(self, "__eval_called"):
+            return
+        self._model(*inputs)
+        setattr(self, "__eval_called", True)
+
     def loss_function(
         self,
         y_true: np.ndarray,
diff --git a/test/model/test_torch.py b/test/model/test_torch.py
index dafe8c8cb66b1fbfe12a7201d4bb35a0c4eb23b5..498b87a7549b49a104766df625b03b0370c5dac8 100644
--- a/test/model/test_torch.py
+++ b/test/model/test_torch.py
@@ -19,6 +19,7 @@
 
 import json
 import sys
+import typing
 from typing import Any, List, Literal, Tuple
 
 import numpy as np
@@ -69,6 +70,9 @@ class FlattenCNNOutput(torch.nn.Module):
         return inputs.view(*shape)
 
 
+Kind = Literal["MLP", "MLP-tune", "MLP-compile", "RNN", "CNN"]
+
+
 class TorchTestCase(ModelTestCase):
     """PyTorch test-case-provider fixture.
 
@@ -88,6 +92,11 @@ class TorchTestCase(ModelTestCase):
     - stack: 32 7x7 conv. filters, then 8x8 max pooling
              16 5x5 conv. filters, then 8x8 avg pooling
              1 output neuron with sigmoid activation
+
+
+    Additional scenarios include "MLP-tune", where the MLP's first hidden
+    layer is frozen, and "MLP-compile", where the MLP is optimized using
+    `torch.compile` (available in torch >=2.0 only).
     """
 
     vector_cls = TorchVector
@@ -95,11 +104,11 @@
 
     def __init__(
         self,
-        kind: Literal["MLP", "MLP-tune", "RNN", "CNN"],
+        kind: Kind,
         device: Literal["CPU", "GPU"],
     ) -> None:
         """Specify the desired model architecture."""
-        if kind not in ("MLP", "MLP-tune", "RNN", "CNN"):
+        if kind not in typing.get_args(Kind):
             raise ValueError(f"Invalid torch test architecture: '{kind}'.")
         self.kind = kind
         self.device = device
@@ -169,6 +178,8 @@ class TorchTestCase(ModelTestCase):
             torch.nn.Sigmoid(),
         ]
         nnmod = torch.nn.Sequential(*stack)
+        if self.kind == "MLP-compile":
+            nnmod = torch.compile(nnmod)  # type: ignore
         return TorchModel(nnmod, loss=torch.nn.BCELoss())
 
     def assert_correct_device(
@@ -184,13 +195,15 @@
 
 @pytest.fixture(name="test_case")
 def fixture_test_case(
-    kind: Literal["MLP", "MLP-tune", "RNN", "CNN"],
+    kind: Kind,
     device: Literal["CPU", "GPU"],
     cpu_only: bool,
 ) -> TorchTestCase:
     """Fixture to access a TorchTestCase."""
     if cpu_only and device == "GPU":
         pytest.skip(reason="--cpu-only mode")
+    if kind == "MLP-compile" and not hasattr(torch, "compile"):
+        pytest.skip(reason="'torch.compile' is unavailable")
     return TorchTestCase(kind, device)
 
 
 if torch.cuda.device_count():
@@ -200,7 +213,7 @@
 
 
 @pytest.mark.parametrize("device", DEVICES)
-@pytest.mark.parametrize("kind", ["MLP", "MLP-tune", "RNN", "CNN"])
+@pytest.mark.parametrize("kind", typing.get_args(Kind))
 class TestTorchModel(ModelTestSuite):
     """Unit tests for declearn.model.torch.TorchModel."""
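
For reference, a minimal sketch of how the `torch.compile` support patched in
above is meant to be used: the module is compiled *before* being wrapped, and
the boolean "compile" flag round-trips through `get_config`/`from_config`.
The `declearn.model.torch.TorchModel` import path matches the file touched by
this diff; the toy architecture and tensor shapes are made up for
illustration.

    import torch

    from declearn.model.torch import TorchModel

    # Toy architecture, for illustration only.
    net = torch.nn.Sequential(
        torch.nn.Linear(8, 4),
        torch.nn.ReLU(),
        torch.nn.Linear(4, 1),
        torch.nn.Sigmoid(),
    )
    # Compile *prior* to wrapping, guarded for torch < 2.0 runtimes.
    if hasattr(torch, "compile"):
        net = torch.compile(net)
    model = TorchModel(net, loss=torch.nn.BCELoss())

    # The config records *that* the module was compiled, not the arguments
    # passed to `torch.compile`; `from_config` merely calls `torch.compile`
    # again, without arguments, when the runtime supports it.
    config = model.get_config()
    restored = TorchModel.from_config(config)

This is also why a server-emitted model degrades gracefully across clients:
the raw module is what gets serialized, and each client re-compiles it (or
not) depending on its local torch version.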
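
The `_orig_mod` check in the patched `__init__` relies on the fact that
`torch.compile` wraps a `torch.nn.Module` in an `OptimizedModule` that keeps
a reference to the original module under `_orig_mod`; this is what lets
`TorchModel` both detect compiled models and keep a `_raw_model` handle for
operations that bypass compilation (weights access, `torch.func`-based
sample-wise gradients). A tiny illustration, assuming torch >= 2.0:

    import torch

    net = torch.nn.Linear(4, 1)
    opt = torch.compile(net)  # lazily wraps `net` in an OptimizedModule

    # The wrapper retains the original module; this is the attribute that
    # the patched __init__ detects via `hasattr(model, "_orig_mod")`.
    assert opt._orig_mod is net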