Verified commit accb07d1, authored by Paul ANDREY
Revise TorchModel to support Torch 2.0.

- Torch 2.0 was released in March 2023, introducing a number of
  new features to the torch ecosystem while remaining non-breaking
  with respect to past versions on most points.
- One of the salient changes is the introduction of `torch.func`,
  which integrates features previously provided by the `functorch`
  package, on which we have been relying to efficiently compute
  and clip sample-wise gradients (see the sketch below).
- This commit therefore introduces a new backend code branch that
  uses `torch.func` when `torch~=2.0` is installed, and retains
  the existing `functorch`-based branch for older torch versions.
- As part of this effort, some of the `TorchModel` backend code
  was refactored, notably deferring the functional transform of
  the input `torch.nn.Module` to the first call to sample-wise
  gradients computation. This has the nice side effect of fixing
  cases where users want to train a model that is not compatible
  with functorch without using DP-SGD.
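For context, a minimal sketch (not part of the commit) of how the relevant
`functorch` entry points line up with their `torch.func` counterparts; the
assignments below are purely illustrative:

import torch

if torch.__version__.startswith("2."):
    # Torch 2.0: functional utilities are shipped as `torch.func`.
    grad = torch.func.grad  # counterpart of functorch.grad
    vmap = torch.func.vmap  # counterpart of functorch.vmap
    # Stateless forward passes use torch.func.functional_call(module, params, args).
else:
    # Torch 1.1X: the same features come from the standalone functorch package.
    import functorch  # type: ignore

    grad = functorch.grad
    vmap = functorch.vmap
    # Stateless forward passes use functorch.make_functional(module).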
parent 9f0035d7
Merge request !49: Add support for Torch 2.0
@@ -17,25 +17,21 @@
 """Model subclass to wrap PyTorch models."""

+import functools
 import io
-import functools
 import warnings
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple

-import functorch  # type: ignore
-try:
-    import functorch.compile  # type: ignore
-except ModuleNotFoundError:
-    COMPILE_AVAILABLE = False
-else:
-    COMPILE_AVAILABLE = True
 import numpy as np
 import torch
 from typing_extensions import Self  # future: import from typing (py >=3.11)

 from declearn.model.api import Model
 from declearn.model.torch.utils import AutoDeviceModule, select_device
+from declearn.model.torch._samplewise import (
+    GetGradientsFunction,
+    build_samplewise_grads_fn,
+)
 from declearn.model.torch._vector import TorchVector
 from declearn.model._utils import raise_on_stringsets_mismatch
 from declearn.typing import Batch
@@ -108,8 +104,6 @@ class TorchModel(Model):
             raise TypeError("'loss' should be a torch.nn.Module instance.")
         loss.reduction = "none"  # type: ignore
         self._loss_fn = AutoDeviceModule(loss, device=device)
-        # Compute and assign a functional version of the model.
-        self._func_model, _ = functorch.make_functional(self._model)

     @property
     def device_policy(
@@ -281,8 +275,8 @@ class TorchModel(Model):
         max_norm: float,
     ) -> TorchVector:
         """Compute and return batch-averaged sample-wise-clipped gradients."""
-        # Compute sample-wise clipped gradients, using functorch.
-        grads = self._compute_samplewise_gradients(batch, max_norm)
+        # Compute sample-wise clipped gradients, using functional torch.
+        grads = self._compute_samplewise_gradients(batch, clip=max_norm)
         # Batch-average the resulting sample-wise gradients.
         return TorchVector(
             {name: tensor.mean(dim=0) for name, tensor in grads.coefs.items()}
@@ -291,92 +285,48 @@ class TorchModel(Model):
     def _compute_samplewise_gradients(
         self,
         batch: Batch,
-        max_norm: Optional[float],
+        clip: Optional[float],
     ) -> TorchVector:
         """Compute and return stacked sample-wise gradients over a batch."""
-        # Unpack the inputs, gather parameters and list gradients to compute.
         inputs, y_true, s_wght = self._unpack_batch(batch)
-        params = []  # type: List[torch.nn.Parameter]
-        idxgrd = []  # type: List[int]
-        pnames = []  # type: List[str]
-        for index, (name, param) in enumerate(self._model.named_parameters()):
-            params.append(param)
-            if param.requires_grad:
-                idxgrd.append(index + 3)
-                pnames.append(name)
-        # Gather or build the sample-wise clipped gradients computing function.
         grads_fn = self._build_samplewise_grads_fn(
-            idxgrd=tuple(idxgrd),
             inputs=len(inputs),
             y_true=(y_true is not None),
            s_wght=(s_wght is not None),
         )
-        # Call it on the current inputs, with optional clipping.
         with torch.no_grad():
-            grads = grads_fn(inputs, y_true, s_wght, *params, clip=max_norm)
-        # Wrap the results into a TorchVector and return it.
-        return TorchVector(dict(zip(pnames, grads)))
+            grads = grads_fn(inputs, y_true, s_wght, clip=clip)  # type: ignore
+        return TorchVector(grads)

     @functools.lru_cache
     def _build_samplewise_grads_fn(
         self,
-        idxgrd: Tuple[int, ...],
         inputs: int,
         y_true: bool,
         s_wght: bool,
-    ) -> Callable[..., List[torch.Tensor]]:
-        """Build a functorch-based sample-wise gradients-computation function.
+    ) -> GetGradientsFunction:
+        """Build an optimized sample-wise gradients-computation function.

         This function is cached, i.e. repeated calls with the same parameters
         will return the same object - enabling to reduce runtime costs due to
         building and (when available) compiling the output function.

-        Parameters
-        ----------
-        idxgrd: tuple of int
-            Pre-incremented indices of the parameters that require gradients.
-        inputs: int
-            Number of input tensors.
-        y_true: bool
-            Whether a true labels tensor is provided.
-        s_wght: bool
-            Whether a sample weights tensor is provided.
-
         Returns
         -------
-        grads_fn: callable[inputs, y_true, s_wght, *params, /, clip]
-            Functorch-optimized function to efficiently compute sample-
-            wise gradients based on batched inputs, and optionally clip
-            them based on a maximum l2-norm value `clip`.
+        grads_fn: callable[[inputs, y_true, s_wght, clip], grads]
+            Function to efficiently compute and return sample-wise gradients
+            wrt trainable model parameters based on a batch of inputs, with
+            opt. clipping based on a maximum l2-norm value `clip`.
+
+        Note
+        ----
+        The underlying backend code depends on your Torch version, so as to
+        enable optimizing operations using either `functorch` for torch 1.1X
+        or `torch.func` for torch 2.X.
         """
-
-        def forward(inputs, y_true, s_wght, *params):
-            """Conduct the forward pass in a functional way."""
-            y_pred = self._func_model(params, *inputs)
-            return self._compute_loss(y_pred, y_true, s_wght)
-
-        def grads_fn(inputs, y_true, s_wght, *params, clip=None):
-            """Compute gradients and optionally clip them."""
-            gfunc = functorch.grad(forward, argnums=idxgrd)
-            grads = gfunc(inputs, y_true, None, *params)
-            if clip:
-                for grad in grads:
-                    # future: use torch.linalg.norm when supported by functorch
-                    norm = torch.norm(grad, p=2, keepdim=True)
-                    # false-positive; pylint: disable=no-member
-                    grad.mul_(torch.clamp(clip / norm, max=1))
-                    if s_wght is not None:
-                        grad.mul_(s_wght.to(grad.device))
-            return grads
-
-        # Wrap the former function to compute and clip sample-wise gradients.
-        in_axes = [[0] * inputs, 0 if y_true else None, 0 if s_wght else None]
-        in_axes.extend([None] * sum(1 for _ in self._model.parameters()))
-        grads_fn = functorch.vmap(grads_fn, tuple(in_axes))
-        # Compile the resulting function to decrease runtime costs.
-        if not COMPILE_AVAILABLE:
-            return grads_fn
-        return functorch.compile.aot_function(grads_fn, functorch.compile.nop)
+        return build_samplewise_grads_fn(
+            self._model, self._loss_fn, inputs, y_true, s_wght
+        )

     def apply_updates(
         self,
...
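As an aside, the clipping-then-averaging step performed by `_compute_clipped_gradients`
boils down to the following, sketched here on dummy per-sample gradients (shapes and
values made up for illustration):

import torch

per_sample_grads = torch.randn(32, 10)  # 32 samples, one 10-parameter tensor
max_norm = 1.0
# Clip each sample-wise gradient to an l2 norm of at most `max_norm`...
norms = per_sample_grads.norm(p=2, dim=1, keepdim=True)
clipped = per_sample_grads * torch.clamp(max_norm / norms, max=1.0)
# ... then average over the batch dimension, as the method above does.
batch_grad = clipped.mean(dim=0)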
New file: declearn/model/torch/_samplewise/__init__.py

# coding: utf-8
# Copyright 2023 Inria (Institut National de Recherche en Informatique
# et Automatique)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Torch-version-dependent code to compute sample-wise gradients."""
from typing import Callable, Dict, List, Optional
import torch
from .shared import GetGradientsFunction
if torch.__version__.startswith("2."):
from .torchfunc import build_samplewise_grads_fn_backend
elif torch.__version__.startswith("1.1"):
from .functorch import build_samplewise_grads_fn_backend
else:
# pragma: no cover
raise ImportError(f"Unsupported Torch version: {torch.__version__}")
__all__ = [
"GetGradientsFunction",
"build_samplewise_grads_fn",
]
def build_samplewise_grads_fn(
model: torch.nn.Module,
loss_fn: torch.nn.Module,
inputs: int,
y_true: bool,
s_wght: bool,
) -> GetGradientsFunction:
"""Build a torch-specific sample-wise gradients-computation function.
Parameters
----------
model: torch.nn.Module
Model that is to be trained.
loss_fn: torch.nn.Module
Loss-computing module, returning sample-wise loss values.
inputs: int
Number of input tensors.
y_true: bool
Whether a true labels tensor is provided.
s_wght: bool
Whether a sample weights tensor is provided.
Returns
-------
grads_fn: callable[[inputs, y_true, s_wght, clip], grads]
Function that efficiently computes and returns sample-wise gradients
wrt trainable model parameters based on a batch of inputs, with opt.
clipping based on a maximum l2-norm value `clip`.
Note
----
The underlying backend code depends on your Torch version, so as to
enable optimizing operations using either `functorch` for torch 1.1X
or `torch.func` for torch 2.X.
"""
return build_samplewise_grads_fn_backend(
model, loss_fn, inputs, y_true, s_wght
)
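A hypothetical usage sketch of the function defined above; the toy model, loss
and shapes are illustrative and not taken from the codebase:

import torch
from declearn.model.torch._samplewise import build_samplewise_grads_fn

model = torch.nn.Linear(4, 1)
loss_fn = torch.nn.MSELoss(reduction="none")  # sample-wise losses, as expected here
grads_fn = build_samplewise_grads_fn(
    model, loss_fn, inputs=1, y_true=True, s_wght=False
)
inputs = [torch.randn(8, 4)]  # batch of 8 samples
y_true = torch.randn(8, 1)
with torch.no_grad():
    grads = grads_fn(inputs, y_true, None, clip=1.0)
# `grads` maps parameter names to tensors with a leading per-sample dimension.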
New file: declearn/model/torch/_samplewise/functorch.py

# coding: utf-8
# Copyright 2023 Inria (Institut National de Recherche en Informatique
# et Automatique)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of `build_samplewise_grads_fn` for Torch 2.0."""
from typing import List, Tuple
# fmt: off
import functorch # type: ignore
try:
import functorch.compile # type: ignore
COMPILE_AVAILABLE = True
except ModuleNotFoundError:
# pragma: no cover
COMPILE_AVAILABLE = False
import torch
# fmt: on
from declearn.model.torch._samplewise.shared import (
GetGradientsFunction,
clip_and_scale_grads_inplace,
)
__all__ = [
"build_samplewise_grads_fn_backend",
]
def build_samplewise_grads_fn_backend(
model: torch.nn.Module,
loss_fn: torch.nn.Module,
inputs: int,
y_true: bool,
s_wght: bool,
) -> GetGradientsFunction:
"""Implementation of `build_samplewise_grads_fn` for Torch 1.1X."""
func_model, _ = functorch.make_functional(model)
def run_forward(inputs, y_true, s_wght, *params):
"""Run the forward pass in a functional way."""
y_pred = func_model(params, *inputs)
s_loss = loss_fn(y_pred, y_true)
if s_wght is not None:
s_loss.mul_(s_wght.to(s_loss.device))
return s_loss.mean()
def grads_fn(inputs, y_true, s_wght, clip=None):
"""Compute gradients and optionally clip them."""
params, idxgrd, pnames = get_params(model)
gfunc = functorch.grad(run_forward, argnums=tuple(idxgrd))
grads = gfunc(inputs, y_true, (None if clip else s_wght), *params)
if clip:
clip_and_scale_grads_inplace(grads, clip, s_wght)
return dict(zip(pnames, grads))
# Wrap the former function to compute and clip sample-wise gradients.
in_dims = ([0] * inputs, 0 if y_true else None, 0 if s_wght else None)
grads_fn = functorch.vmap(grads_fn, in_dims)
# Compile the resulting function to decrease runtime costs.
if not COMPILE_AVAILABLE:
# pragma: no cover
return grads_fn
return functorch.compile.aot_function(grads_fn, functorch.compile.nop)
def get_params(
model: torch.nn.Module,
) -> Tuple[List[torch.nn.Parameter], List[int], List[str]]:
"""Return a model's parameters and the index and name of trainable ones."""
params = [] # type: List[torch.nn.Parameter]
idxgrd = [] # type: List[int]
pnames = [] # type: List[str]
for idx, (name, param) in enumerate(model.named_parameters()):
params.append(param)
if param.requires_grad:
idxgrd.append(idx + 3)
pnames.append(name)
return params, idxgrd, pnames
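For readers unfamiliar with the (pre-2.0) functorch idiom used above, here is a
minimal standalone sketch of `make_functional`, which this backend relies on; the
toy model is illustrative:

import torch
import functorch  # type: ignore

model = torch.nn.Linear(3, 1)
# Split the module into a stateless callable and a tuple of its parameters.
func_model, params = functorch.make_functional(model)
x = torch.randn(5, 3)
y = func_model(params, x)  # forward pass with explicitly-passed parameters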
New file: declearn/model/torch/_samplewise/shared.py

# coding: utf-8
# Copyright 2023 Inria (Institut National de Recherche en Informatique
# et Automatique)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared code for torch-version-dependent backend code."""
from typing import Callable, Dict, Iterable, List, Optional
import torch
__all__ = [
"GetGradientsFunction",
"clip_and_scale_grads_inplace",
]
GetGradientsFunction = Callable[
[
List[torch.Tensor],
Optional[torch.Tensor],
Optional[torch.Tensor],
Optional[float],
],
Dict[str, torch.Tensor],
]
"""Signature for sample-wise gradients computation functions."""
def clip_and_scale_grads_inplace(
grads: Iterable[torch.Tensor],
clip: float,
wght: Optional[torch.Tensor] = None,
) -> None:
"""Clip a collection of tensors in-place, based on their euclidean norm.
Also apply an optional weight tensor to scale the clipped gradients.
"""
for grad in grads:
norm = torch.norm(grad, p=2, keepdim=True)
# false-positive; pylint: disable=no-member
grad.mul_(torch.clamp(clip / norm, max=1))
if wght is not None:
grad.mul_(wght.to(grad.device))
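A quick illustration of `clip_and_scale_grads_inplace` on a dummy gradient, with
values chosen so the effect is easy to check:

import torch

grad = torch.full((4,), 2.0)  # l2 norm is 4.0
clip_and_scale_grads_inplace([grad], clip=1.0)
# each entry is now 0.5, so the gradient's l2 norm has been clipped to 1.0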
New file: declearn/model/torch/_samplewise/torchfunc.py

# coding: utf-8
# Copyright 2023 Inria (Institut National de Recherche en Informatique
# et Automatique)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of `build_samplewise_grads_fn` for Torch 2.0."""
from typing import Dict, Tuple
import torch
from declearn.model.torch._samplewise.shared import (
GetGradientsFunction,
clip_and_scale_grads_inplace,
)
__all__ = [
"build_samplewise_grads_fn_backend",
]
def build_samplewise_grads_fn_backend(
model: torch.nn.Module,
loss_fn: torch.nn.Module,
inputs: int,
y_true: bool,
s_wght: bool,
) -> GetGradientsFunction:
"""Implementation of `build_samplewise_grads_fn` for Torch 2.0."""
def run_forward(params, frozen, inputs, y_true, s_wght):
"""Run the forward pass in a functional way."""
y_pred = torch.func.functional_call(model, [params, frozen], *inputs)
s_loss = loss_fn(y_pred, y_true)
if s_wght is not None:
s_loss.mul_(s_wght.to(s_loss.device))
return s_loss.mean()
get_grads = torch.func.grad(run_forward, argnums=0)
def get_clipped_grads(inputs, y_true, s_wght, clip=None):
"""Compute gradients and optionally clip them."""
params, frozen = get_params(model)
grads = get_grads(
params, frozen, inputs, y_true, None if clip else s_wght
)
if clip:
clip_and_scale_grads_inplace(grads.values(), clip, s_wght)
return grads
# Wrap the former function to compute and clip sample-wise gradients.
in_dims = ([0] * inputs, 0 if y_true else None, 0 if s_wght else None)
return torch.func.vmap(get_clipped_grads, in_dims)
def get_params(
model: torch.nn.Module,
) -> Tuple[Dict[str, torch.nn.Parameter], Dict[str, torch.nn.Parameter]]:
"""Return a model's parameters, split between trainable and frozen ones."""
params = {} # type: Dict[str, torch.nn.Parameter]
frozen = {} # type: Dict[str, torch.nn.Parameter]
for name, param in model.named_parameters():
(params if param.requires_grad else frozen)[name] = param
return params, frozen
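For reference, the same per-sample-gradient pattern can be written directly with
`torch.func`, outside of declearn; this standalone sketch (toy model, loss and
shapes are assumptions) mirrors what the backend above builds:

import torch

model = torch.nn.Linear(3, 1)
params = dict(model.named_parameters())

def loss_on_one_sample(params, x, y):
    # Stateless forward pass on a single (unbatched) sample.
    pred = torch.func.functional_call(model, params, (x.unsqueeze(0),))
    return torch.nn.functional.mse_loss(pred, y.unsqueeze(0))

x_batch, y_batch = torch.randn(8, 3), torch.randn(8, 1)
per_sample_grads = torch.func.vmap(
    torch.func.grad(loss_on_one_sample), in_dims=(None, 0, 0)
)(params, x_batch, y_batch)
# Maps each parameter name to a tensor of shape (8, *parameter.shape).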