Commit b8222c84 authored by BIGAUD Nathan

Merge branch 'functorch' into 'develop'

Refactor 'TorchModel' backend code to compute sample-wise gradients.

See merge request !42
parents 594dfacb be449f94
Pipeline #789706 failed
@@ -17,11 +17,19 @@
 """Model subclass to wrap PyTorch models."""

+import functools
 import io
 import warnings
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple

 import functorch  # type: ignore
+
+try:
+    import functorch.compile  # type: ignore
+except ModuleNotFoundError:
+    COMPILE_AVAILABLE = False
+else:
+    COMPILE_AVAILABLE = True
+
 import numpy as np
 import torch
 from typing_extensions import Self  # future: import from typing (py >=3.11)
@@ -83,7 +91,9 @@ class TorchModel(Model):
         loss: torch.nn.Module
             Torch Module instance that defines the model's loss, that
             is to be minimized through training. Note that it will be
-            altered when wrapped.
+            altered when wrapped. It must expect `y_pred` and `y_true`
+            as input arguments (in that order) and will be used to get
+            sample-wise loss values (by removing any reduction scheme).
         """
         # Type-check the input model.
         if not isinstance(model, torch.nn.Module):
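
Note: as the unchanged code around the next hunk shows, the wrapped loss module has its `reduction` attribute forced to "none", which is what makes it return sample-wise loss values. A minimal sketch of that behaviour with a standard torch criterion (illustrative, not part of the patch):

    import torch

    loss_fn = torch.nn.CrossEntropyLoss()  # default: reduction="mean"
    loss_fn.reduction = "none"             # switch to sample-wise outputs

    y_pred = torch.randn(4, 3)             # logits for 4 samples, 3 classes
    y_true = torch.tensor([0, 2, 1, 2])
    print(loss_fn(y_pred, y_true).shape)   # torch.Size([4]): one value per sample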
@@ -99,7 +109,7 @@ class TorchModel(Model):
         loss.reduction = "none"  # type: ignore
         self._loss_fn = AutoDeviceModule(loss, device=device)
         # Compute and assign a functional version of the model.
-        self._func_model = functorch.make_functional(self._model)[0]
+        self._func_model, _ = functorch.make_functional(self._model)

     @property
     def device_policy(
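
Note: this is a cosmetic fix rather than a behaviour change. `functorch.make_functional` returns a `(functional_model, params)` pair, so tuple-unpacking with a discarded second element is equivalent to the former `[0]` indexing. A quick sketch of the returned objects:

    import functorch
    import torch

    model = torch.nn.Linear(3, 1)
    func_model, params = functorch.make_functional(model)
    # The functional model is stateless: parameters are passed in explicitly.
    x = torch.randn(2, 3)
    assert torch.allclose(func_model(params, x), model(x))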
@@ -265,117 +275,108 @@ class TorchModel(Model):
             loss.mul_(s_wght.to(loss.device))
         return loss.mean()

-    def _compute_samplewise_gradients(
-        self,
-        batch: Batch,
-    ) -> TorchVector:
-        """Compute and return stacked sample-wise gradients from a batch."""
-        # Delegate preparation of the gradients-computing function.
-        # fmt: off
-        grads_fn, data, params, pnames, in_axes = (
-            self._prepare_samplewise_gradients_computations(batch)
-        )
-        # Vectorize the function to compute sample-wise gradients.
-        with torch.no_grad():
-            grads = functorch.vmap(grads_fn, in_axes)(*data, *params)
-        # Wrap the results into a TorchVector and return it.
-        return TorchVector(dict(zip(pnames, grads)))
+    def _compute_clipped_gradients(
+        self,
+        batch: Batch,
+        max_norm: float,
+    ) -> TorchVector:
+        """Compute and return batch-averaged sample-wise-clipped gradients."""
+        # Compute sample-wise clipped gradients, using functorch.
+        grads = self._compute_samplewise_gradients(batch, max_norm)
+        # Batch-average the resulting sample-wise gradients.
+        return TorchVector(
+            {name: tensor.mean(dim=0) for name, tensor in grads.coefs.items()}
+        )
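
Note: clipping each sample's gradients to a maximum l2-norm before batch-averaging is the standard recipe for differentially-private (DP-SGD style) training. An illustrative sketch of the clip-then-average step on a plain stacked-gradients tensor (toy values, not from the patch):

    import torch

    max_norm = 1.0
    # Stacked per-sample gradients of one parameter: shape (batch, *shape).
    grads = torch.tensor([[3.0, 4.0], [0.3, 0.4]])  # l2-norms: 5.0 and 0.5
    norms = torch.norm(grads, p=2, dim=1, keepdim=True)
    clipped = grads * torch.clamp(max_norm / norms, max=1)  # only row 1 is rescaled
    averaged = clipped.mean(dim=0)  # batch-averaged clipped gradients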
-    def _compute_clipped_gradients(
-        self,
-        batch: Batch,
-        max_norm: float,
-    ) -> TorchVector:
-        """Compute and return batch-averaged sample-wise-clipped gradients."""
-        # Delegate preparation of the gradients-computing function.
-        # fmt: off
-        grads_fn, data, params, pnames, in_axes = (
-            self._prepare_samplewise_gradients_computations(batch)
-        )
-        # Compose it to clip output gradients on the way.
-        def clipped_grads_fn(inputs, y_true, s_wght, *params):
-            grads = grads_fn(inputs, y_true, None, *params)
-            for grad in grads:
-                # future: use torch.linalg.norm when supported by functorch
-                norm = torch.norm(grad, p=2, keepdim=True)
-                # false-positive; pylint: disable=no-member
-                grad.mul_(torch.clamp(max_norm / norm, max=1))
-                if s_wght is not None:
-                    grad.mul_(s_wght.to(grad.device))
-            return grads
-        # Vectorize the function to compute sample-wise clipped gradients.
-        with torch.no_grad():
-            grads = functorch.vmap(clipped_grads_fn, in_axes)(*data, *params)
-        # Wrap batch-averaged results into a TorchVector and return it.
-        return TorchVector(
-            {name: grad.mean(dim=0) for name, grad in zip(pnames, grads)}
-        )
+    def _compute_samplewise_gradients(
+        self,
+        batch: Batch,
+        max_norm: Optional[float],
+    ) -> TorchVector:
+        """Compute and return stacked sample-wise gradients over a batch."""
+        # Unpack the inputs, gather parameters and list gradients to compute.
+        inputs, y_true, s_wght = self._unpack_batch(batch)
+        params = []  # type: List[torch.nn.Parameter]
+        idxgrd = []  # type: List[int]
+        pnames = []  # type: List[str]
+        for index, (name, param) in enumerate(self._model.named_parameters()):
+            params.append(param)
+            if param.requires_grad:
+                idxgrd.append(index + 3)
+                pnames.append(name)
+        # Gather or build the sample-wise clipped gradients computing function.
+        grads_fn = self._build_samplewise_grads_fn(
+            idxgrd=tuple(idxgrd),
+            inputs=len(inputs),
+            y_true=(y_true is not None),
+            s_wght=(s_wght is not None),
+        )
+        # Call it on the current inputs, with optional clipping.
+        with torch.no_grad():
+            grads = grads_fn(inputs, y_true, s_wght, *params, clip=max_norm)
+        # Wrap the results into a TorchVector and return it.
+        return TorchVector(dict(zip(pnames, grads)))
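
Note: the arguments of the builder method below are deliberately reduced to hashable values (a tuple of indices, an int, two booleans) so that `functools.lru_cache` can memoize the built function across batches that share the same structure. A toy illustration of that hashability constraint (not from the patch):

    import functools

    @functools.lru_cache
    def build(shape: tuple) -> str:
        return f"grads function for inputs of shape {shape}"

    build((2, 3))    # computed, then cached
    build((2, 3))    # served from the cache
    # build([2, 3]) would raise TypeError: unhashable type: 'list'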
-    def _prepare_samplewise_gradients_computations(
-        self,
-        batch: Batch,
-    ) -> Tuple[
-        Callable[..., List[torch.Tensor]],
-        TensorBatch,
-        List[torch.nn.Parameter],
-        List[str],
-        Tuple[Any, ...],
-    ]:
-        """Prepare a function and parameters to compute sample-wise gradients.
-
-        Note: this method is merely implemented as a way to avoid code
-        redundancies between the `_compute_samplewise_gradients` method
-        and the `_compute_clipped_gradients` ones.
-
-        Parameters
-        ----------
-        batch: declearn.typing.Batch
-            Batch structure wrapping the input data, target labels and
-            optional sample weights based on which to compute gradients.
-
-        Returns
-        -------
-        grads_fn: function(*data, *params) -> List[torch.Tensor]
-            Functorch-issued gradients computation function.
-        data: tuple([torch.Tensor], torch.Tensor, torch.Tensor or None)
-            Tensor-converted data unpacked from `batch`.
-        params: list[torch.nn.Parameter]
-            Input parameters of the model, some of which require grads.
-        pnames: list[str]
-            Names of the parameters that require gradients.
-        in_axes: tuple(...)
-            Prepared `in_axes` parameter to `functorch.vmap`, suitable
-            to distribute `grads_fn` (or any compose that shares its
-            input signature) over a batch so as to compute sample-wise
-            gradients in a computationally-efficient manner.
-        """
-        # fmt: off
-        # Unpack and validate inputs.
-        data = (inputs, y_true, s_wght) = self._unpack_batch(batch)
-        # Gather parameters and list those that require gradients.
-        idxgrd = []  # type: List[int]
-        pnames = []  # type: List[str]
-        params = []  # type: List[torch.nn.Parameter]
-        for idx, (name, param) in enumerate(self._model.named_parameters()):
-            params.append(param)
-            if param.requires_grad:
-                pnames.append(name)
-                idxgrd.append(idx)
-        # Define a differentiable function wrapping the forward pass.
-        def forward(inputs, y_true, s_wght, *params):
-            """Conduct the forward pass in a functional way."""
-            y_pred = self._func_model(params, *inputs)
-            return self._compute_loss(y_pred, y_true, s_wght)
-        # Transform it into a sample-wise-gradients-computing function.
-        grads_fn = functorch.grad(forward, argnums=tuple(i+3 for i in idxgrd))
-        # Prepare `functorch.vmap` parameter to slice through data and params.
-        in_axes = [
-            [0] * len(inputs),
-            None if y_true is None else 0,
-            None if s_wght is None else 0,
-        ]
-        in_axes.extend([None] * len(params))
-        # Return all this prepared material.
-        return grads_fn, data, params, pnames, tuple(in_axes)
+    @functools.lru_cache
+    def _build_samplewise_grads_fn(
+        self,
+        idxgrd: Tuple[int, ...],
+        inputs: int,
+        y_true: bool,
+        s_wght: bool,
+    ) -> Callable[..., List[torch.Tensor]]:
+        """Build a functorch-based sample-wise gradients-computation function.
+
+        This function is cached, i.e. repeated calls with the same parameters
+        will return the same object - enabling to reduce runtime costs due to
+        building and (when available) compiling the output function.
+
+        Parameters
+        ----------
+        idxgrd: tuple of int
+            Pre-incremented indices of the parameters that require gradients.
+        inputs: int
+            Number of input tensors.
+        y_true: bool
+            Whether a true labels tensor is provided.
+        s_wght: bool
+            Whether a sample weights tensor is provided.
+
+        Returns
+        -------
+        grads_fn: callable[inputs, y_true, s_wght, *params, /, clip]
+            Functorch-optimized function to efficiently compute sample-
+            wise gradients based on batched inputs, and optionally clip
+            them based on a maximum l2-norm value `clip`.
+        """
+
+        def forward(inputs, y_true, s_wght, *params):
+            y_pred = self._func_model(params, *inputs)
+            return self._compute_loss(y_pred, y_true, s_wght)
+
+        def grads_fn(inputs, y_true, s_wght, *params, clip=None):
+            """Compute gradients and optionally clip them."""
+            gfunc = functorch.grad(forward, argnums=idxgrd)
+            grads = gfunc(inputs, y_true, None, *params)
+            if clip:
+                for grad in grads:
+                    # future: use torch.linalg.norm when supported by functorch
+                    norm = torch.norm(grad, p=2, keepdim=True)
+                    # false-positive; pylint: disable=no-member
+                    grad.mul_(torch.clamp(clip / norm, max=1))
+                    if s_wght is not None:
+                        grad.mul_(s_wght.to(grad.device))
+            return grads
+
+        # Wrap the former function to compute and clip sample-wise gradients.
+        in_axes = [[0] * inputs, 0 if y_true else None, 0 if s_wght else None]
+        in_axes.extend([None] * sum(1 for _ in self._model.parameters()))
+        grads_fn = functorch.vmap(grads_fn, tuple(in_axes))
+        # Compile the resulting function to decrease runtime costs.
+        if not COMPILE_AVAILABLE:
+            return grads_fn
+        return functorch.compile.aot_function(grads_fn, functorch.compile.nop)
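
Note: the core of the refactored code is the `functorch.grad` / `functorch.vmap` composition: `grad` turns the scalar-loss forward pass into a gradients-computing function, and `vmap` maps it over axis 0 of the data while broadcasting the parameters (axis None), yielding stacked per-sample gradients from a single vectorized call. A self-contained sketch on a toy linear model (all names illustrative; assumes the legacy functorch package):

    import functorch
    import torch

    model = torch.nn.Linear(3, 1)
    func_model, params = functorch.make_functional(model)

    def sample_loss(params, x, y):
        # Scalar squared-error loss for a single sample.
        pred = func_model(params, x.unsqueeze(0)).squeeze()
        return (pred - y) ** 2

    grad_fn = functorch.grad(sample_loss, argnums=0)
    # Map over axis 0 of the data; broadcast (None) the parameters.
    per_sample_grads = functorch.vmap(grad_fn, in_axes=(None, 0, 0))(
        params, torch.randn(8, 3), torch.randn(8)
    )
    print([tuple(g.shape) for g in per_sample_grads])  # [(8, 1, 3), (8, 1)]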
     def apply_updates(  # type: ignore  # Vector subtype specification
         self,
...