diff --git a/declearn/communication/__init__.py b/declearn/communication/__init__.py
index 819f28e1ac409a4bbed79d2e21ca0b05404dc163..e7f685f248bfc22bbc35300a0ef3f39d66660dba 100644
--- a/declearn/communication/__init__.py
+++ b/declearn/communication/__init__.py
@@ -70,8 +70,10 @@ from ._build import (
 try:
     from . import grpc
 except ImportError:
+    # pragma: no cover
     _INSTALLABLE_BACKENDS["grpc"] = ("grpcio", "protobuf")
 try:
     from . import websockets
 except ImportError:
+    # pragma: no cover
     _INSTALLABLE_BACKENDS["websockets"] = ("websockets",)
diff --git a/declearn/model/torch/_model.py b/declearn/model/torch/_model.py
index f8a5dd6f1396d6d6909edbf070ee616e6cc3a6ec..21ce772adf4d606c9da584fcb9376fa3e353af6d 100644
--- a/declearn/model/torch/_model.py
+++ b/declearn/model/torch/_model.py
@@ -17,25 +17,21 @@
 
 """Model subclass to wrap PyTorch models."""
 
-import functools
 import io
+import functools
 import warnings
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple
-
-import functorch  # type: ignore
+from typing import Any, Dict, List, Optional, Set, Tuple
 
-try:
-    import functorch.compile  # type: ignore
-except ModuleNotFoundError:
-    COMPILE_AVAILABLE = False
-else:
-    COMPILE_AVAILABLE = True
 import numpy as np
 import torch
 from typing_extensions import Self  # future: import from typing (py >=3.11)
 
 from declearn.model.api import Model
 from declearn.model.torch.utils import AutoDeviceModule, select_device
+from declearn.model.torch._samplewise import (
+    GetGradientsFunction,
+    build_samplewise_grads_fn,
+)
 from declearn.model.torch._vector import TorchVector
 from declearn.model._utils import raise_on_stringsets_mismatch
 from declearn.typing import Batch
@@ -75,6 +71,25 @@ class TorchModel(Model):
       `update_device_policy` method.
     - You may consult the device policy currently enforced by a TorchModel
      instance by accessing its `device_policy` property.
+
+    Notes regarding `torch.compile` support (torch >=2.0):
+
+    - If you want the wrapped model to be optimized via `torch.compile`, it
+      should be compiled _prior_ to being wrapped using `TorchModel`.
+    - The compilation will not be used when computing sample-wise-clipped
+      gradients, as `torch.func` and `torch.compile` do not play along yet.
+    - The information that the module was compiled will be saved as part of
+      the `TorchModel` config, so that using `TorchModel.from_config` will
+      trigger it again when possible; this is however limited to calling
+      `torch.compile`, meaning that any custom arguments will be lost.
+    - Note that the former point notably affects the way clients will run
+      a server-emitted `TorchModel` as part of an FL process: clients that
+      run Torch 1.X will be able to use the un-optimized module, while
+      clients running Torch 2.0 will use compilation, but in a rather crude
+      form that may not be suitable for some specific/advanced cases.
+    - Enhanced support for `torch.compile` is on the roadmap. If you run
+      into issues and/or have requests or advice on that topic, feel free
+      to let us know by contacting us via mail or GitLab.
     """
     def __init__(
        self,
@@ -101,15 +116,19 @@ class TorchModel(Model):
         # Select the device where to place computations, and wrap the model.
         policy = get_device_policy()
         device = select_device(gpu=policy.gpu, idx=policy.idx)
-        model = AutoDeviceModule(model, device=device)
-        super().__init__(model)
+        super().__init__(AutoDeviceModule(model, device=device))
         # Assign loss module and set it not to reduce sample-wise values.
if not isinstance(loss, torch.nn.Module): raise TypeError("'loss' should be a torch.nn.Module instance.") loss.reduction = "none" # type: ignore self._loss_fn = AutoDeviceModule(loss, device=device) - # Compute and assign a functional version of the model. - self._func_model, _ = functorch.make_functional(self._model) + # Detect torch-compiled models and extract underlying module. + self._raw_model = self._model + if hasattr(torch, "compile") and hasattr(model, "_orig_mod"): + self._raw_model = AutoDeviceModule( + module=getattr(model, "_orig_mod"), + device=self._model.device, + ) @property def device_policy( @@ -137,12 +156,16 @@ class TorchModel(Model): "PyTorch JSON serialization relies on pickle, which may be unsafe." ) with io.BytesIO() as buffer: - torch.save(self._model.module, buffer) + torch.save(self._raw_model.module, buffer) model = buffer.getbuffer().hex() with io.BytesIO() as buffer: torch.save(self._loss_fn.module, buffer) loss = buffer.getbuffer().hex() - return {"model": model, "loss": loss} + return { + "model": model, + "loss": loss, + "compile": self._raw_model is not self._model, + } @classmethod def from_config( @@ -154,13 +177,15 @@ class TorchModel(Model): model = torch.load(buffer) with io.BytesIO(bytes.fromhex(config["loss"])) as buffer: loss = torch.load(buffer) + if config.get("compile", False) and hasattr(torch, "compile"): + model = torch.compile(model) return cls(model=model, loss=loss) def get_weights( self, trainable: bool = False, ) -> TorchVector: - params = self._model.named_parameters() + params = self._raw_model.named_parameters() if trainable: weights = {k: p.data for k, p in params if p.requires_grad} else: @@ -177,12 +202,12 @@ class TorchModel(Model): raise TypeError("TorchModel requires TorchVector weights.") self._verify_weights_compatibility(weights, trainable=trainable) if trainable: - state_dict = self._model.state_dict() + state_dict = self._raw_model.state_dict() state_dict.update(weights.coefs) else: state_dict = weights.coefs # NOTE: this preserves the device placement of current states - self._model.load_state_dict(state_dict) + self._raw_model.load_state_dict(state_dict) def _verify_weights_compatibility( self, @@ -206,12 +231,9 @@ class TorchModel(Model): In case some expected keys are missing, or additional keys are present. Be verbose about the identified mismatch(es). """ + params = self._raw_model.named_parameters() received = set(vector.coefs) - expected = { - name - for name, param in self._model.named_parameters() - if (not trainable) or param.requires_grad - } + expected = {n for n, p in params if (not trainable) or p.requires_grad} raise_on_stringsets_mismatch( received, expected, context="model weights" ) @@ -241,7 +263,7 @@ class TorchModel(Model): # Collect weights' gradients and return them in a Vector container. grads = { k: p.grad.detach().clone() - for k, p in self._model.named_parameters() + for k, p in self._raw_model.named_parameters() if p.requires_grad } return TorchVector(grads) @@ -281,8 +303,8 @@ class TorchModel(Model): max_norm: float, ) -> TorchVector: """Compute and return batch-averaged sample-wise-clipped gradients.""" - # Compute sample-wise clipped gradients, using functorch. - grads = self._compute_samplewise_gradients(batch, max_norm) + # Compute sample-wise clipped gradients, using functional torch. + grads = self._compute_samplewise_gradients(batch, clip=max_norm) # Batch-average the resulting sample-wise gradients. 
         return TorchVector(
             {name: tensor.mean(dim=0) for name, tensor in grads.coefs.items()}
         )
@@ -291,92 +313,49 @@
     def _compute_samplewise_gradients(
         self,
         batch: Batch,
-        max_norm: Optional[float],
+        clip: Optional[float],
     ) -> TorchVector:
         """Compute and return stacked sample-wise gradients over a batch."""
-        # Unpack the inputs, gather parameters and list gradients to compute.
         inputs, y_true, s_wght = self._unpack_batch(batch)
-        params = []  # type: List[torch.nn.Parameter]
-        idxgrd = []  # type: List[int]
-        pnames = []  # type: List[str]
-        for index, (name, param) in enumerate(self._model.named_parameters()):
-            params.append(param)
-            if param.requires_grad:
-                idxgrd.append(index + 3)
-                pnames.append(name)
-        # Gather or build the sample-wise clipped gradients computing function.
         grads_fn = self._build_samplewise_grads_fn(
-            idxgrd=tuple(idxgrd),
             inputs=len(inputs),
             y_true=(y_true is not None),
             s_wght=(s_wght is not None),
         )
-        # Call it on the current inputs, with optional clipping.
         with torch.no_grad():
-            grads = grads_fn(inputs, y_true, s_wght, *params, clip=max_norm)
-        # Wrap the results into a TorchVector and return it.
-        return TorchVector(dict(zip(pnames, grads)))
+            grads = grads_fn(inputs, y_true, s_wght, clip=clip)  # type: ignore
+        return TorchVector(grads)
 
     @functools.lru_cache
     def _build_samplewise_grads_fn(
         self,
-        idxgrd: Tuple[int, ...],
         inputs: int,
         y_true: bool,
         s_wght: bool,
-    ) -> Callable[..., List[torch.Tensor]]:
-        """Build a functorch-based sample-wise gradients-computation function.
+    ) -> GetGradientsFunction:
+        """Build an optimized sample-wise gradients-computation function.
 
         This function is cached, i.e. repeated calls with the same parameters
         will return the same object - enabling to reduce runtime costs due to
         building and (when available) compiling the output function.
 
-        Parameters
-        ----------
-        idxgrd: tuple of int
-            Pre-incremented indices of the parameters that require gradients.
-        inputs: int
-            Number of input tensors.
-        y_true: bool
-            Whether a true labels tensor is provided.
-        s_wght: bool
-            Whether a sample weights tensor is provided.
-
         Returns
         -------
-        grads_fn: callable[inputs, y_true, s_wght, *params, /, clip]
-            Functorch-optimized function to efficiently compute sample-
-            wise gradients based on batched inputs, and optionally clip
-            them based on a maximum l2-norm value `clip`.
+        grads_fn: callable[[inputs, y_true, s_wght, clip], grads]
+            Function to efficiently compute and return sample-wise gradients
+            wrt trainable model parameters based on a batch of inputs, with
+            opt. clipping based on a maximum l2-norm value `clip`.
+
+        Note
+        ----
+        The underlying backend code depends on your Torch version, so as to
+        enable optimizing operations using either `functorch` for torch 1.1X
+        or `torch.func` for torch 2.X.
""" - - def forward(inputs, y_true, s_wght, *params): - """Conduct the forward pass in a functional way.""" - y_pred = self._func_model(params, *inputs) - return self._compute_loss(y_pred, y_true, s_wght) - - def grads_fn(inputs, y_true, s_wght, *params, clip=None): - """Compute gradients and optionally clip them.""" - gfunc = functorch.grad(forward, argnums=idxgrd) - grads = gfunc(inputs, y_true, None, *params) - if clip: - for grad in grads: - # future: use torch.linalg.norm when supported by functorch - norm = torch.norm(grad, p=2, keepdim=True) - # false-positive; pylint: disable=no-member - grad.mul_(torch.clamp(clip / norm, max=1)) - if s_wght is not None: - grad.mul_(s_wght.to(grad.device)) - return grads - - # Wrap the former function to compute and clip sample-wise gradients. - in_axes = [[0] * inputs, 0 if y_true else None, 0 if s_wght else None] - in_axes.extend([None] * sum(1 for _ in self._model.parameters())) - grads_fn = functorch.vmap(grads_fn, tuple(in_axes)) - # Compile the resulting function to decrease runtime costs. - if not COMPILE_AVAILABLE: - return grads_fn - return functorch.compile.aot_function(grads_fn, functorch.compile.nop) + # NOTE: torch.func is not compatible with torch.compile yet + return build_samplewise_grads_fn( + self._raw_model, self._loss_fn, inputs, y_true, s_wght + ) def apply_updates( self, @@ -387,7 +366,7 @@ class TorchModel(Model): self._verify_weights_compatibility(updates, trainable=True) with torch.no_grad(): for key, upd in updates.coefs.items(): - tns = self._model.get_parameter(key) + tns = self._raw_model.get_parameter(key) tns.add_(upd.to(tns.device)) def compute_batch_predictions( @@ -403,12 +382,32 @@ class TorchModel(Model): "creating labels from the base inputs." ) self._model.eval() + self._handle_torch_compile_eval_issue(inputs) with torch.no_grad(): y_pred = self._model(*inputs).cpu().numpy() y_true = y_true.cpu().numpy() s_wght = None if s_wght is None else s_wght.cpu().numpy() return y_true, y_pred, s_wght # type: ignore + def _handle_torch_compile_eval_issue( + self, + inputs: List[torch.Tensor], + ) -> None: + """Clumsily handle issues with `torch.compile` and `torch.no_grad`. + + As of Torch 2.0.1, running a compiled model's first forward pass + within a `torch.no_grad` context results in the model's future + weights updates not being properly taken into account. + + Therefore, when wrapping a compiled model, this method runs a lost + forward pass outside of a no-grad context on its first call (later + it does nothing). + """ + if (self._raw_model is self._model) or hasattr(self, "__eval_called"): + return + self._model(*inputs) + setattr(self, "__eval_called", True) + def loss_function( self, y_true: np.ndarray, diff --git a/declearn/model/torch/_samplewise/__init__.py b/declearn/model/torch/_samplewise/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1448a18bb210279d99343f8637cc8d7587cf12a5 --- /dev/null +++ b/declearn/model/torch/_samplewise/__init__.py @@ -0,0 +1,78 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Torch-version-dependent code to compute sample-wise gradients.""" + +from typing import Callable, Dict, List, Optional + +import torch + +from .shared import GetGradientsFunction + +if torch.__version__.startswith("2."): + from .torchfunc import build_samplewise_grads_fn_backend +elif torch.__version__.startswith("1.1"): + from .functorch import build_samplewise_grads_fn_backend +else: + # pragma: no cover + raise ImportError(f"Unsupported Torch version: {torch.__version__}") + + +__all__ = [ + "GetGradientsFunction", + "build_samplewise_grads_fn", +] + + +def build_samplewise_grads_fn( + model: torch.nn.Module, + loss_fn: torch.nn.Module, + inputs: int, + y_true: bool, + s_wght: bool, +) -> GetGradientsFunction: + """Build a torch-specific sample-wise gradients-computation function. + + Parameters + ---------- + model: torch.nn.Module + Model that is to be trained. + loss_fn: torch.nn.Module + Loss-computing module, returning sample-wise loss values. + inputs: int + Number of input tensors. + y_true: bool + Whether a true labels tensor is provided. + s_wght: bool + Whether a sample weights tensor is provided. + + Returns + ------- + grads_fn: callable[[inputs, y_true, s_wght, clip], grads] + Function that efficiently computes and returns sample-wise gradients + wrt trainable model parameters based on a batch of inputs, with opt. + clipping based on a maximum l2-norm value `clip`. + + Note + ---- + The underlying backend code depends on your Torch version, so as to + enable optimizing operations using either `functorch` for torch 1.1X + or `torch.func` for torch 2.X. + """ + return build_samplewise_grads_fn_backend( + model, loss_fn, inputs, y_true, s_wght + ) diff --git a/declearn/model/torch/_samplewise/functorch.py b/declearn/model/torch/_samplewise/functorch.py new file mode 100644 index 0000000000000000000000000000000000000000..d39a3e16602918cd36e755550e1440622efb2866 --- /dev/null +++ b/declearn/model/torch/_samplewise/functorch.py @@ -0,0 +1,96 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""Implementation of `build_samplewise_grads_fn` for Torch 1.1X."""
+
+from typing import List, Tuple
+
+# fmt: off
+import functorch  # type: ignore
+try:
+    import functorch.compile  # type: ignore
+    COMPILE_AVAILABLE = True
+except ModuleNotFoundError:
+    # pragma: no cover
+    COMPILE_AVAILABLE = False
+import torch
+# fmt: on
+
+from declearn.model.torch._samplewise.shared import (
+    GetGradientsFunction,
+    clip_and_scale_grads_inplace,
+)
+
+__all__ = [
+    "build_samplewise_grads_fn_backend",
+]
+
+
+def build_samplewise_grads_fn_backend(
+    model: torch.nn.Module,
+    loss_fn: torch.nn.Module,
+    inputs: int,
+    y_true: bool,
+    s_wght: bool,
+) -> GetGradientsFunction:
+    """Implementation of `build_samplewise_grads_fn` for Torch 1.1X."""
+
+    func_model, *_ = functorch.make_functional_with_buffers(model)
+
+    def run_forward(inputs, y_true, s_wght, buffers, *params):
+        """Run the forward pass in a functional way."""
+        y_pred = func_model(params, buffers, *inputs)
+        s_loss = loss_fn(y_pred, y_true)
+        if s_wght is not None:
+            s_loss.mul_(s_wght.to(s_loss.device))
+        return s_loss.mean()
+
+    def grads_fn(inputs, y_true, s_wght, clip=None):
+        """Compute gradients and optionally clip them."""
+        params, idxgrd, pnames = get_params(model)
+        buffers = list(model.buffers())
+        gfunc = functorch.grad(run_forward, argnums=tuple(idxgrd))
+        grads = gfunc(
+            inputs, y_true, (None if clip else s_wght), buffers, *params
+        )
+        if clip:
+            clip_and_scale_grads_inplace(grads, clip, s_wght)
+        return dict(zip(pnames, grads))
+
+    # Wrap the former function to compute and clip sample-wise gradients.
+    in_dims = ([0] * inputs, 0 if y_true else None, 0 if s_wght else None)
+    grads_fn = functorch.vmap(grads_fn, in_dims, randomness="same")
+    # Compile the resulting function to decrease runtime costs.
+    if not COMPILE_AVAILABLE:
+        # pragma: no cover
+        return grads_fn
+    return functorch.compile.aot_function(grads_fn, functorch.compile.nop)
+
+
+def get_params(
+    model: torch.nn.Module,
+) -> Tuple[List[torch.nn.Parameter], List[int], List[str]]:
+    """Return a model's parameters and the index and name of trainable ones."""
+    params = []  # type: List[torch.nn.Parameter]
+    idxgrd = []  # type: List[int]
+    pnames = []  # type: List[str]
+    for idx, (name, param) in enumerate(model.named_parameters()):
+        params.append(param)
+        if param.requires_grad:
+            idxgrd.append(idx + 4)
+            pnames.append(name)
+    return params, idxgrd, pnames
diff --git a/declearn/model/torch/_samplewise/shared.py b/declearn/model/torch/_samplewise/shared.py
new file mode 100644
index 0000000000000000000000000000000000000000..451ae7c11998c3d337bffe988df2f6e131f5e539
--- /dev/null
+++ b/declearn/model/torch/_samplewise/shared.py
@@ -0,0 +1,56 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Shared code for torch-version-dependent backend code.""" + +from typing import Callable, Dict, Iterable, List, Optional + +import torch + +__all__ = [ + "GetGradientsFunction", + "clip_and_scale_grads_inplace", +] + + +GetGradientsFunction = Callable[ + [ + List[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[float], + ], + Dict[str, torch.Tensor], +] +"""Signature for sample-wise gradients computation functions.""" + + +def clip_and_scale_grads_inplace( + grads: Iterable[torch.Tensor], + clip: float, + wght: Optional[torch.Tensor] = None, +) -> None: + """Clip a collection of tensors in-place, based on their euclidean norm. + + Also apply an optional weight tensor to scale the clipped gradients. + """ + for grad in grads: + norm = torch.norm(grad, p=2, keepdim=True) + # false-positive; pylint: disable=no-member + grad.mul_(torch.clamp(clip / norm, max=1)) + if wght is not None: + grad.mul_(wght.to(grad.device)) diff --git a/declearn/model/torch/_samplewise/torchfunc.py b/declearn/model/torch/_samplewise/torchfunc.py new file mode 100644 index 0000000000000000000000000000000000000000..88aa5b7dca01dcdd77adda76b29d6b080c7487d2 --- /dev/null +++ b/declearn/model/torch/_samplewise/torchfunc.py @@ -0,0 +1,80 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implementation of `build_samplewise_grads_fn` for Torch 2.0.""" + +from typing import Dict, Tuple + +import torch + +from declearn.model.torch._samplewise.shared import ( + GetGradientsFunction, + clip_and_scale_grads_inplace, +) + +__all__ = [ + "build_samplewise_grads_fn_backend", +] + + +def build_samplewise_grads_fn_backend( + model: torch.nn.Module, + loss_fn: torch.nn.Module, + inputs: int, + y_true: bool, + s_wght: bool, +) -> GetGradientsFunction: + """Implementation of `build_samplewise_grads_fn` for Torch 2.0.""" + + def run_forward(params, frozen, buffers, inputs, y_true, s_wght): + """Run the forward pass in a functional way.""" + # backend closure function; pylint: disable=too-many-arguments + y_pred = torch.func.functional_call( + model, [params, frozen, buffers], *inputs + ) + s_loss = loss_fn(y_pred, y_true) + if s_wght is not None: + s_loss.mul_(s_wght.to(s_loss.device)) + return s_loss.mean() + + get_grads = torch.func.grad(run_forward, argnums=0) + + def get_clipped_grads(inputs, y_true, s_wght, clip=None): + """Compute gradients and optionally clip them.""" + params, frozen = get_params(model) + buffers = dict(model.named_buffers()) + grads = get_grads( + params, frozen, buffers, inputs, y_true, None if clip else s_wght + ) + if clip: + clip_and_scale_grads_inplace(grads.values(), clip, s_wght) + return grads + + # Wrap the former function to compute and clip sample-wise gradients. 
+ in_dims = ([0] * inputs, 0 if y_true else None, 0 if s_wght else None) + return torch.func.vmap(get_clipped_grads, in_dims, randomness="same") + + +def get_params( + model: torch.nn.Module, +) -> Tuple[Dict[str, torch.nn.Parameter], Dict[str, torch.nn.Parameter]]: + """Return a model's parameters, split between trainable and frozen ones.""" + params = {} # type: Dict[str, torch.nn.Parameter] + frozen = {} # type: Dict[str, torch.nn.Parameter] + for name, param in model.named_parameters(): + (params if param.requires_grad else frozen)[name] = param + return params, frozen diff --git a/docs/setup.md b/docs/setup.md index 979ef90c9dfc77757bb0e2c99b51cca58f716b82..7b00c0aafd38d5390a402fb70dd21747b817de97 100644 --- a/docs/setup.md +++ b/docs/setup.md @@ -110,3 +110,10 @@ pip install declearn[all,tests] # install all extra and testing dependencies - On some systems, the square brackets used our pip install are not properly parsed. Try replacing `[` by `\[` and `]` by `\]`, or putting the instruction between quotes (`pip install "declearn[...]"`). +- Regarding Torch: declearn currently supports both late 1.10-1.13 versions and + 2.X ones. You may use either one freely, but may run into issues regarding + co-dependent versions of `torch`, `functorch` (in 1.10-1.13) and `opacus` (if + you want to use [differential privacy](./user-guide/local_dp.md) features). + You may use the `torch1` or `torch2` extra dependency specifier to explicitly + target either the 1.13 or latest 2.X torch version and install the proper + versions of the other packages (including opacus). diff --git a/pyproject.toml b/pyproject.toml index 11f8094e2f7a2c26d93e775db9f5daa590b9a129..2161244d8d6de8a5b4e8408ccd0a494f5fe575e3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,13 +48,13 @@ dependencies = [ [project.optional-dependencies] all = [ # all non-tests extra dependencies "dm-haiku == 0.0.9", - "functorch", + "functorch >= 1.10, < 3.0", "grpcio >= 1.45", "jax[cpu] >= 0.4, < 0.5", "opacus ~= 1.1", "protobuf >= 3.19", "tensorflow ~= 2.5", - "torch ~= 1.10", + "torch >= 1.10, < 3.0", "websockets ~= 10.1", ] # thematically grouped dependencies (part of "all") @@ -72,9 +72,19 @@ haiku = [ tensorflow = [ "tensorflow ~= 2.5", ] -torch = [ - "functorch", # note: functorch is internalized by torch 2.0 - "torch ~= 1.10", +torch = [ # generic requirements for Torch + "functorch >= 1.10, < 3.0", # note: unused with torch >=2.0 + "torch >= 1.10, < 3.0", +] +torch1 = [ # Torch 1.13 (latest pre-2.0 version) + Opacus + "functorch ~= 1.13.0", + "opacus ~= 1.3.0", + "torch ~= 1.13.0", +] +torch2 = [ # Torch 2.X version + Opacus + "functorch ~= 2.0", # unused, but useful to linters and pip + "opacus ~= 1.4", + "torch ~= 2.0", ] websockets = [ "websockets ~= 10.1", diff --git a/scripts/run_tests.sh b/scripts/run_tests.sh index 189986b730d1ee3e75ccd723bc49697598371879..c962e8b6a6645feb0b564a4779db95e107bcfb28 100644 --- a/scripts/run_tests.sh +++ b/scripts/run_tests.sh @@ -114,6 +114,7 @@ run_declearn_tests() { commands=( "run_unit_tests $@" "run_integration_tests $@" + "run_torch13_tests $@" ) run_commands "declearn test suite" "${commands[@]}" status=$? @@ -152,6 +153,33 @@ run_integration_tests() { } +run_torch13_tests() { + : ' + Verbosely run Torch 1.13-specific unit tests. + + Install Torch 1.13 at the start of this function, and re-install + torch >=2.0 at the end of it, together with its co-dependencies. + ' + echo "Re-installing torch 1.13 and its co-dependencies." + pip install .[torch1] + if [[ $? 
-eq 0 ]]; then + echo "Running unit tests for torch 1.13." + command="pytest $@ + --cov --cov-append --cov-report= + test/model/test_torch.py + " + run_command $command + status=$? + else + echo "\e[31mSkipping tests as installation failed.\e[0m" + status=1 + fi + echo "Re-installing torch 2.X and its co-dependencies." + pip install .[torch2] + return $status +} + + main() { if [[ $# -eq 0 ]]; then echo "Missing required positional argument." diff --git a/test/model/test_torch.py b/test/model/test_torch.py index dafe8c8cb66b1fbfe12a7201d4bb35a0c4eb23b5..498b87a7549b49a104766df625b03b0370c5dac8 100644 --- a/test/model/test_torch.py +++ b/test/model/test_torch.py @@ -19,6 +19,7 @@ import json import sys +import typing from typing import Any, List, Literal, Tuple import numpy as np @@ -69,6 +70,9 @@ class FlattenCNNOutput(torch.nn.Module): return inputs.view(*shape) +Kind = Literal["MLP", "MLP-tune", "MLP-compile", "RNN", "CNN"] + + class TorchTestCase(ModelTestCase): """PyTorch test-case-provider fixture. @@ -88,6 +92,11 @@ class TorchTestCase(ModelTestCase): - stack: 32 7x7 conv. filters, then 8x8 max pooling 16 5x5 conv. filters, then 8x8 avg pooling 1 output neuron with sigmoid activation + + + Additional scenarios include "MLP-tune", where the MLP's first hidden + layer is frozen, and "MLP-compile", where the MLP is optimized using + `torch.compile` (available in torch >=2.0 only). """ vector_cls = TorchVector @@ -95,11 +104,11 @@ class TorchTestCase(ModelTestCase): def __init__( self, - kind: Literal["MLP", "MLP-tune", "RNN", "CNN"], + kind: Kind, device: Literal["CPU", "GPU"], ) -> None: """Specify the desired model architecture.""" - if kind not in ("MLP", "MLP-tune", "RNN", "CNN"): + if kind not in typing.get_args(Kind): raise ValueError(f"Invalid torch test architecture: '{kind}'.") self.kind = kind self.device = device @@ -169,6 +178,8 @@ class TorchTestCase(ModelTestCase): torch.nn.Sigmoid(), ] nnmod = torch.nn.Sequential(*stack) + if self.kind == "MLP-compile": + nnmod = torch.compile(nnmod) # type: ignore return TorchModel(nnmod, loss=torch.nn.BCELoss()) def assert_correct_device( @@ -184,13 +195,15 @@ class TorchTestCase(ModelTestCase): @pytest.fixture(name="test_case") def fixture_test_case( - kind: Literal["MLP", "MLP-tune", "RNN", "CNN"], + kind: Kind, device: Literal["CPU", "GPU"], cpu_only: bool, ) -> TorchTestCase: """Fixture to access a TorchTestCase.""" if cpu_only and device == "GPU": pytest.skip(reason="--cpu-only mode") + if kind == "MLP-compile" and not hasattr(torch, "compile"): + pytest.skip(reason="'torch.compile' is unavailable") return TorchTestCase(kind, device) @@ -200,7 +213,7 @@ if torch.cuda.device_count(): @pytest.mark.parametrize("device", DEVICES) -@pytest.mark.parametrize("kind", ["MLP", "MLP-tune", "RNN", "CNN"]) +@pytest.mark.parametrize("kind", typing.get_args(Kind)) class TestTorchModel(ModelTestSuite): """Unit tests for declearn.model.torch.TorchModel."""
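Side note for reviewers: below is a minimal usage sketch of the `torch.compile` support introduced by this patch, following the wrapping pattern exercised in `test/model/test_torch.py` (the module is compiled *before* being handed to `TorchModel`). The small MLP and its dimensions are purely illustrative, not part of the diff.

```python
# Illustrative sketch only (not part of the diff): compile first, wrap second.
import torch

from declearn.model.torch import TorchModel

# Any torch.nn.Module works; this tiny MLP is just a placeholder.
module = torch.nn.Sequential(
    torch.nn.Linear(8, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 1),
    torch.nn.Sigmoid(),
)
# With torch >=2.0, optionally optimize the module *prior* to wrapping it.
if hasattr(torch, "compile"):
    module = torch.compile(module)
# Wrap it; the fact that it was compiled is recorded in the model's config,
# so that `TorchModel.from_config` can re-trigger a plain `torch.compile`.
model = TorchModel(module, loss=torch.nn.BCELoss())
rebuilt = TorchModel.from_config(model.get_config())
```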