diff --git a/declearn/model/torch/_model.py b/declearn/model/torch/_model.py
index 59334cc9098ea9128a50a213473e27724ff67a4b..7491c3f53b444f8216143a39ce0d15249c810ebc 100644
--- a/declearn/model/torch/_model.py
+++ b/declearn/model/torch/_model.py
@@ -17,11 +17,19 @@
 
 """Model subclass to wrap PyTorch models."""
 
+import functools
 import io
 import warnings
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple
 
 import functorch  # type: ignore
+
+try:
+    import functorch.compile  # type: ignore
+except ModuleNotFoundError:
+    COMPILE_AVAILABLE = False
+else:
+    COMPILE_AVAILABLE = True
 import numpy as np
 import torch
 from typing_extensions import Self  # future: import from typing (py >=3.11)
@@ -83,7 +91,9 @@ class TorchModel(Model):
         loss: torch.nn.Module
             Torch Module instance that defines the model's loss, that
             is to be minimized through training. Note that it will be
-            altered when wrapped.
+            altered when wrapped. It must expect `y_pred` and `y_true`
+            as input arguments (in that order) and will be used to get
+            sample-wise loss values (by removing any reduction scheme).
         """
         # Type-check the input model.
         if not isinstance(model, torch.nn.Module):
@@ -99,7 +109,7 @@ class TorchModel(Model):
             loss.reduction = "none"  # type: ignore
         self._loss_fn = AutoDeviceModule(loss, device=device)
         # Compute and assign a functional version of the model.
-        self._func_model = functorch.make_functional(self._model)[0]
+        self._func_model, _ = functorch.make_functional(self._model)
 
     @property
     def device_policy(
@@ -265,117 +275,108 @@ class TorchModel(Model):
             loss.mul_(s_wght.to(loss.device))
         return loss.mean()
 
-    def _compute_samplewise_gradients(
+    def _compute_clipped_gradients(
         self,
         batch: Batch,
+        max_norm: float,
     ) -> TorchVector:
-        """Compute and return stacked sample-wise gradients from a batch."""
-        # Delegate preparation of the gradients-computing function.
-        # fmt: off
-        grads_fn, data, params, pnames, in_axes = (
-            self._prepare_samplewise_gradients_computations(batch)
+        """Compute and return batch-averaged sample-wise-clipped gradients."""
+        # Compute sample-wise clipped gradients, using functorch.
+        grads = self._compute_samplewise_gradients(batch, max_norm)
+        # Batch-average the resulting sample-wise gradients.
+        return TorchVector(
+            {name: tensor.mean(dim=0) for name, tensor in grads.coefs.items()}
         )
-        # Vectorize the function to compute sample-wise gradients.
-        with torch.no_grad():
-            grads = functorch.vmap(grads_fn, in_axes)(*data, *params)
-        # Wrap the results into a TorchVector and return it.
-        return TorchVector(dict(zip(pnames, grads)))
 
-    def _compute_clipped_gradients(
+    def _compute_samplewise_gradients(
         self,
         batch: Batch,
-        max_norm: float,
+        max_norm: Optional[float],
     ) -> TorchVector:
-        """Compute and return batch-averaged sample-wise-clipped gradients."""
-        # Delegate preparation of the gradients-computing function.
-        # fmt: off
-        grads_fn, data, params, pnames, in_axes = (
-            self._prepare_samplewise_gradients_computations(batch)
+        """Compute and return stacked sample-wise gradients over a batch."""
+        # Unpack the inputs, gather parameters and list gradients to compute.
+        inputs, y_true, s_wght = self._unpack_batch(batch)
+        params = []  # type: List[torch.nn.Parameter]
+        idxgrd = []  # type: List[int]
+        pnames = []  # type: List[str]
+        for index, (name, param) in enumerate(self._model.named_parameters()):
+            params.append(param)
+            if param.requires_grad:
+                idxgrd.append(index + 3)
+                pnames.append(name)
+        # Gather or build the sample-wise clipped gradients computing function.
+        grads_fn = self._build_samplewise_grads_fn(
+            idxgrd=tuple(idxgrd),
+            inputs=len(inputs),
+            y_true=(y_true is not None),
+            s_wght=(s_wght is not None),
        )
-        # Compose it to clip output gradients on the way.
-        def clipped_grads_fn(inputs, y_true, s_wght, *params):
-            grads = grads_fn(inputs, y_true, None, *params)
-            for grad in grads:
-                # future: use torch.linalg.norm when supported by functorch
-                norm = torch.norm(grad, p=2, keepdim=True)
-                # false-positive; pylint: disable=no-member
-                grad.mul_(torch.clamp(max_norm / norm, max=1))
-                if s_wght is not None:
-                    grad.mul_(s_wght.to(grad.device))
-            return grads
-        # Vectorize the function to compute sample-wise clipped gradients.
+        # Call it on the current inputs, with optional clipping.
         with torch.no_grad():
-            grads = functorch.vmap(clipped_grads_fn, in_axes)(*data, *params)
-        # Wrap batch-averaged results into a TorchVector and return it.
-        return TorchVector(
-            {name: grad.mean(dim=0) for name, grad in zip(pnames, grads)}
-        )
+            grads = grads_fn(inputs, y_true, s_wght, *params, clip=max_norm)
+        # Wrap the results into a TorchVector and return it.
+        return TorchVector(dict(zip(pnames, grads)))
 
-    def _prepare_samplewise_gradients_computations(
+    @functools.lru_cache
+    def _build_samplewise_grads_fn(
         self,
-        batch: Batch,
-    ) -> Tuple[
-        Callable[..., List[torch.Tensor]],
-        TensorBatch,
-        List[torch.nn.Parameter],
-        List[str],
-        Tuple[Any, ...],
-    ]:
-        """Prepare a function an parameters to compute sample-wise gradients.
-
-        Note: this method is merely implemented as a way to avoid code
-        redundancies between the `_compute_samplewise_gradients` method
-        and the `_compute_clipped_gradients` ones.
+        idxgrd: Tuple[int, ...],
+        inputs: int,
+        y_true: bool,
+        s_wght: bool,
+    ) -> Callable[..., List[torch.Tensor]]:
+        """Build a functorch-based sample-wise gradients-computation function.
+
+        This function is cached, i.e. repeated calls with the same parameters
+        will return the same object - enabling to reduce runtime costs due to
+        building and (when available) compiling the output function.
 
         Parameters
         ----------
-        batch: declearn.typing.Batch
-            Batch structure wrapping the input data, target labels and
-            optional sample weights based on which to compute gradients.
+        idxgrd: tuple of int
+            Pre-incremented indices of the parameters that require gradients.
+        inputs: int
+            Number of input tensors.
+        y_true: bool
+            Whether a true labels tensor is provided.
+        s_wght: bool
+            Whether a sample weights tensor is provided.
 
         Returns
         -------
-        grads_fn: function(*data, *params) -> List[torch.Tensor]
-            Functorch-issued gradients computation function.
-        data: tuple([torch.Tensor], torch.Tensor, torch.Tensor or None)
-            Tensor-converted data unpacked from `batch`.
-        params: list[torch.nn.Parameter]
-            Input parameters of the model, some of which require grads.
-        pnames: list[str]
-            Names of the parameters that require gradients.
-        in_axes: tuple(...)
-            Prepared `in_axes` parameter to `functorch.vmap`, suitable
-            to distribute `grads_fn` (or any compose that shares its
-            input signature) over a batch so as to compute sample-wise
-            gradients in a computationally-efficient manner.
+        grads_fn: callable[inputs, y_true, s_wght, *params, /, clip]
+            Functorch-optimized function to efficiently compute sample-
+            wise gradients based on batched inputs, and optionally clip
+            them based on a maximum l2-norm value `clip`.
         """
-        # fmt: off
-        # Unpack and validate inputs.
-        data = (inputs, y_true, s_wght) = self._unpack_batch(batch)
-        # Gather parameters and list those that require gradients.
-        idxgrd = []  # type: List[int]
-        pnames = []  # type: List[str]
-        params = []  # type: List[torch.nn.Parameter]
-        for idx, (name, param) in enumerate(self._model.named_parameters()):
-            params.append(param)
-            if param.requires_grad:
-                pnames.append(name)
-                idxgrd.append(idx)
-        # Define a differentiable function wrapping the forward pass.
+
         def forward(inputs, y_true, s_wght, *params):
+            """Conduct the forward pass in a functional way."""
             y_pred = self._func_model(params, *inputs)
             return self._compute_loss(y_pred, y_true, s_wght)
-        # Transform it into a sample-wise-gradients-computing function.
-        grads_fn = functorch.grad(forward, argnums=tuple(i+3 for i in idxgrd))
-        # Prepare `functools.vmap` parameter to slice through data and params.
-        in_axes = [
-            [0] * len(inputs),
-            None if y_true is None else 0,
-            None if s_wght is None else 0,
-        ]
-        in_axes.extend([None] * len(params))
-        # Return all this prepared material.
-        return grads_fn, data, params, pnames, tuple(in_axes)
+
+        def grads_fn(inputs, y_true, s_wght, *params, clip=None):
+            """Compute gradients and optionally clip them."""
+            gfunc = functorch.grad(forward, argnums=idxgrd)
+            grads = gfunc(inputs, y_true, None, *params)
+            if clip:
+                for grad in grads:
+                    # future: use torch.linalg.norm when supported by functorch
+                    norm = torch.norm(grad, p=2, keepdim=True)
+                    # false-positive; pylint: disable=no-member
+                    grad.mul_(torch.clamp(clip / norm, max=1))
+                    if s_wght is not None:
+                        grad.mul_(s_wght.to(grad.device))
+            return grads
+
+        # Wrap the former function to compute and clip sample-wise gradients.
+        in_axes = [[0] * inputs, 0 if y_true else None, 0 if s_wght else None]
+        in_axes.extend([None] * sum(1 for _ in self._model.parameters()))
+        grads_fn = functorch.vmap(grads_fn, tuple(in_axes))
+        # Compile the resulting function to decrease runtime costs.
+        if not COMPILE_AVAILABLE:
+            return grads_fn
+        return functorch.compile.aot_function(grads_fn, functorch.compile.nop)
 
     def apply_updates(  # type: ignore  # Vector subtype specification
         self,
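Note for reviewers: the patch above relies on the functorch grad-then-vmap idiom to obtain per-sample gradients in a single vectorized pass. The following standalone sketch (not part of the diff) illustrates that pattern under assumed names; the toy linear model, MSE loss, and data shapes are purely illustrative and do not appear in declearn.

# Illustrative sketch of the functorch grad + vmap pattern used by
# `_build_samplewise_grads_fn`; model, loss and shapes are hypothetical.
import functorch
import torch

model = torch.nn.Linear(4, 1)
func_model, params = functorch.make_functional(model)
loss_fn = torch.nn.MSELoss(reduction="none")

def forward(inputs, y_true, *params):
    """Compute the scalar loss for a single sample."""
    y_pred = func_model(params, inputs)
    return loss_fn(y_pred, y_true).mean()

# Differentiate the loss with respect to the model's two parameters
# (weight and bias), which occupy positions 2 and 3 of `forward`.
grads_fn = functorch.grad(forward, argnums=(2, 3))
# Vectorize over the batch dimension of the data, not of the parameters.
in_dims = (0, 0) + (None,) * len(params)
batched_grads_fn = functorch.vmap(grads_fn, in_dims)

inputs = torch.randn(32, 4)
y_true = torch.randn(32, 1)
grads = batched_grads_fn(inputs, y_true, *params)
# Each returned tensor stacks per-sample gradients: (32, *parameter_shape).
print([tuple(g.shape) for g in grads])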