Verified Commit accb07d1 authored by ANDREY Paul

Revise TorchModel to support Torch 2.0.

- Torch 2.0 was released in March 2023, introducing a number of
  new features to the torch ecosystem while remaining non-breaking
  with respect to past versions on most points.
- One of the salient changes is the introduction of `torch.func`,
  which integrates features previously available from the
  `functorch` package, which we have been relying upon to
  efficiently compute and clip sample-wise gradients (see the
  illustrative sketch below).
- This commit therefore introduces a new backend code branch so
  as to make use of `torch.func` when `torch~=2.0` is used, and
  retains the existing `functorch`-based branch when older torch
  versions are used.
- As part of this effort, some of the `TorchModel` backend code
  was refactored, notably deferring the functional transform of
  the input `torch.nn.Module` to the first call to sample-wise
  gradients computation. This has the nice side effect of fixing
  cases where users want to train a non-functorch-compatible model
  without DP-SGD.
parent 9f0035d7
No related branches found
No related tags found
1 merge request: !49 Add support for Torch 2.0
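
The gist of the change can be sketched as follows. This is an illustrative example rather than code from the commit: the toy model, tensors and names are placeholders, and only the version-dependent functional APIs are the point.

import torch

model = torch.nn.Linear(4, 1)
loss_fn = torch.nn.MSELoss(reduction="none")
x, y = torch.randn(8, 4), torch.randn(8, 1)

if torch.__version__.startswith("2."):
    # Torch 2.x: `torch.func` handles parameters as {name: tensor} dicts.
    params = dict(model.named_parameters())

    def sample_loss(params, x_i, y_i):
        y_pred = torch.func.functional_call(model, params, (x_i,))
        return loss_fn(y_pred, y_i).mean()

    # Dict of per-sample gradients, stacked along a leading batch dimension.
    grads = torch.func.vmap(
        torch.func.grad(sample_loss), in_dims=(None, 0, 0)
    )(params, x, y)
else:
    # Torch 1.1x: `functorch` handles parameters as an ordered tuple.
    import functorch

    func_model, params = functorch.make_functional(model)

    def sample_loss(params, x_i, y_i):
        return loss_fn(func_model(params, x_i), y_i).mean()

    # Tuple of per-sample gradients, stacked along a leading batch dimension.
    grads = functorch.vmap(
        functorch.grad(sample_loss), in_dims=(None, 0, 0)
    )(params, x, y)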
@@ -17,25 +17,21 @@
"""Model subclass to wrap PyTorch models."""
import functools
import io
import functools
import warnings
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
import functorch # type: ignore
from typing import Any, Dict, List, Optional, Set, Tuple
try:
    import functorch.compile  # type: ignore
except ModuleNotFoundError:
    COMPILE_AVAILABLE = False
else:
    COMPILE_AVAILABLE = True
import numpy as np
import torch
from typing_extensions import Self # future: import from typing (py >=3.11)
from declearn.model.api import Model
from declearn.model.torch.utils import AutoDeviceModule, select_device
from declearn.model.torch._samplewise import (
    GetGradientsFunction,
    build_samplewise_grads_fn,
)
from declearn.model.torch._vector import TorchVector
from declearn.model._utils import raise_on_stringsets_mismatch
from declearn.typing import Batch
@@ -108,8 +104,6 @@ class TorchModel(Model):
            raise TypeError("'loss' should be a torch.nn.Module instance.")
        loss.reduction = "none"  # type: ignore
        self._loss_fn = AutoDeviceModule(loss, device=device)
        # Compute and assign a functional version of the model.
        self._func_model, _ = functorch.make_functional(self._model)

    @property
    def device_policy(
@@ -281,8 +275,8 @@
        max_norm: float,
    ) -> TorchVector:
        """Compute and return batch-averaged sample-wise-clipped gradients."""
        # Compute sample-wise clipped gradients, using functorch.
        grads = self._compute_samplewise_gradients(batch, max_norm)
        # Compute sample-wise clipped gradients, using functional torch.
        grads = self._compute_samplewise_gradients(batch, clip=max_norm)
        # Batch-average the resulting sample-wise gradients.
        return TorchVector(
            {name: tensor.mean(dim=0) for name, tensor in grads.coefs.items()}
@@ -291,92 +285,48 @@
    def _compute_samplewise_gradients(
        self,
        batch: Batch,
        max_norm: Optional[float],
        clip: Optional[float],
    ) -> TorchVector:
        """Compute and return stacked sample-wise gradients over a batch."""
        # Unpack the inputs, gather parameters and list gradients to compute.
        inputs, y_true, s_wght = self._unpack_batch(batch)
        params = []  # type: List[torch.nn.Parameter]
        idxgrd = []  # type: List[int]
        pnames = []  # type: List[str]
        for index, (name, param) in enumerate(self._model.named_parameters()):
            params.append(param)
            if param.requires_grad:
                idxgrd.append(index + 3)
                pnames.append(name)
        # Gather or build the sample-wise clipped gradients computing function.
        grads_fn = self._build_samplewise_grads_fn(
            idxgrd=tuple(idxgrd),
            inputs=len(inputs),
            y_true=(y_true is not None),
            s_wght=(s_wght is not None),
        )
        # Call it on the current inputs, with optional clipping.
        with torch.no_grad():
            grads = grads_fn(inputs, y_true, s_wght, *params, clip=max_norm)
        # Wrap the results into a TorchVector and return it.
        return TorchVector(dict(zip(pnames, grads)))
            grads = grads_fn(inputs, y_true, s_wght, clip=clip)  # type: ignore
        return TorchVector(grads)
    @functools.lru_cache
    def _build_samplewise_grads_fn(
        self,
        idxgrd: Tuple[int, ...],
        inputs: int,
        y_true: bool,
        s_wght: bool,
    ) -> Callable[..., List[torch.Tensor]]:
        """Build a functorch-based sample-wise gradients-computation function.
    ) -> GetGradientsFunction:
        """Build an optimized sample-wise gradients-computation function.

        This function is cached, i.e. repeated calls with the same parameters
        will return the same object - enabling to reduce runtime costs due to
        building and (when available) compiling the output function.

        Parameters
        ----------
        idxgrd: tuple of int
            Pre-incremented indices of the parameters that require gradients.
        inputs: int
            Number of input tensors.
        y_true: bool
            Whether a true labels tensor is provided.
        s_wght: bool
            Whether a sample weights tensor is provided.

        Returns
        -------
        grads_fn: callable[inputs, y_true, s_wght, *params, /, clip]
            Functorch-optimized function to efficiently compute sample-
            wise gradients based on batched inputs, and optionally clip
            them based on a maximum l2-norm value `clip`.
        grads_fn: callable[[inputs, y_true, s_wght, clip], grads]
            Function to efficiently compute and return sample-wise gradients
            wrt trainable model parameters based on a batch of inputs, with
            opt. clipping based on a maximum l2-norm value `clip`.

        Note
        ----
        The underlying backend code depends on your Torch version, so as to
        enable optimizing operations using either `functorch` for torch 1.1X
        or `torch.func` for torch 2.X.
        """
        def forward(inputs, y_true, s_wght, *params):
            """Conduct the forward pass in a functional way."""
            y_pred = self._func_model(params, *inputs)
            return self._compute_loss(y_pred, y_true, s_wght)

        def grads_fn(inputs, y_true, s_wght, *params, clip=None):
            """Compute gradients and optionally clip them."""
            gfunc = functorch.grad(forward, argnums=idxgrd)
            grads = gfunc(inputs, y_true, None, *params)
            if clip:
                for grad in grads:
                    # future: use torch.linalg.norm when supported by functorch
                    norm = torch.norm(grad, p=2, keepdim=True)
                    # false-positive; pylint: disable=no-member
                    grad.mul_(torch.clamp(clip / norm, max=1))
                    if s_wght is not None:
                        grad.mul_(s_wght.to(grad.device))
            return grads

        # Wrap the former function to compute and clip sample-wise gradients.
        in_axes = [[0] * inputs, 0 if y_true else None, 0 if s_wght else None]
        in_axes.extend([None] * sum(1 for _ in self._model.parameters()))
        grads_fn = functorch.vmap(grads_fn, tuple(in_axes))
        # Compile the resulting function to decrease runtime costs.
        if not COMPILE_AVAILABLE:
            return grads_fn
        return functorch.compile.aot_function(grads_fn, functorch.compile.nop)
        return build_samplewise_grads_fn(
            self._model, self._loss_fn, inputs, y_true, s_wght
        )
    def apply_updates(
        self,
......
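
As noted in the docstring above, the builder is wrapped with `functools.lru_cache`, which is what defers (and de-duplicates) the functional transform: the build step only runs the first time a given argument combination is requested, and never runs if sample-wise gradients are never needed. A minimal, standalone illustration of that standard-library behaviour (unrelated to the declearn code itself):

import functools

calls = []

@functools.lru_cache
def build(inputs: int, y_true: bool, s_wght: bool):
    calls.append((inputs, y_true, s_wght))
    return object()  # stand-in for the built gradients-computation function

assert build(1, True, False) is build(1, True, False)  # cached: built only once
assert len(calls) == 1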
# coding: utf-8
# Copyright 2023 Inria (Institut National de Recherche en Informatique
# et Automatique)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Torch-version-dependent code to compute sample-wise gradients."""
from typing import Callable, Dict, List, Optional
import torch
from .shared import GetGradientsFunction
if torch.__version__.startswith("2."):
from .torchfunc import build_samplewise_grads_fn_backend
elif torch.__version__.startswith("1.1"):
from .functorch import build_samplewise_grads_fn_backend
else:
# pragma: no cover
raise ImportError(f"Unsupported Torch version: {torch.__version__}")
__all__ = [
"GetGradientsFunction",
"build_samplewise_grads_fn",
]
def build_samplewise_grads_fn(
    model: torch.nn.Module,
    loss_fn: torch.nn.Module,
    inputs: int,
    y_true: bool,
    s_wght: bool,
) -> GetGradientsFunction:
    """Build a torch-specific sample-wise gradients-computation function.

    Parameters
    ----------
    model: torch.nn.Module
        Model that is to be trained.
    loss_fn: torch.nn.Module
        Loss-computing module, returning sample-wise loss values.
    inputs: int
        Number of input tensors.
    y_true: bool
        Whether a true labels tensor is provided.
    s_wght: bool
        Whether a sample weights tensor is provided.

    Returns
    -------
    grads_fn: callable[[inputs, y_true, s_wght, clip], grads]
        Function that efficiently computes and returns sample-wise gradients
        wrt trainable model parameters based on a batch of inputs, with opt.
        clipping based on a maximum l2-norm value `clip`.

    Note
    ----
    The underlying backend code depends on your Torch version, so as to
    enable optimizing operations using either `functorch` for torch 1.1X
    or `torch.func` for torch 2.X.
    """
    return build_samplewise_grads_fn_backend(
        model, loss_fn, inputs, y_true, s_wght
    )
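
A hypothetical usage sketch of the builder above, assuming a toy linear model; the call convention mirrors how `TorchModel` invokes the returned function, and the shapes in the final comment are indicative only:

import torch

from declearn.model.torch._samplewise import build_samplewise_grads_fn

model = torch.nn.Linear(4, 1)
loss_fn = torch.nn.MSELoss(reduction="none")  # must return sample-wise losses

grads_fn = build_samplewise_grads_fn(
    model, loss_fn, inputs=1, y_true=True, s_wght=False
)

batch_x = torch.randn(8, 4)
batch_y = torch.randn(8, 1)
with torch.no_grad():
    # Clip each sample-wise gradient to a maximum l2-norm of 2.0.
    grads = grads_fn([batch_x], batch_y, None, clip=2.0)

# `grads` maps parameter names to tensors with a leading sample dimension,
# e.g. {"weight": torch.Size([8, 1, 4]), "bias": torch.Size([8, 1])}.
print({name: tensor.shape for name, tensor in grads.items()})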
# coding: utf-8
# Copyright 2023 Inria (Institut National de Recherche en Informatique
# et Automatique)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of `build_samplewise_grads_fn` for Torch 2.0."""
from typing import List, Tuple
# fmt: off
import functorch # type: ignore
try:
import functorch.compile # type: ignore
COMPILE_AVAILABLE = True
except ModuleNotFoundError:
# pragma: no cover
COMPILE_AVAILABLE = False
import torch
# fmt: on
from declearn.model.torch._samplewise.shared import (
GetGradientsFunction,
clip_and_scale_grads_inplace,
)
__all__ = [
"build_samplewise_grads_fn_backend",
]
def build_samplewise_grads_fn_backend(
    model: torch.nn.Module,
    loss_fn: torch.nn.Module,
    inputs: int,
    y_true: bool,
    s_wght: bool,
) -> GetGradientsFunction:
    """Implementation of `build_samplewise_grads_fn` for Torch 1.1X."""
    func_model, _ = functorch.make_functional(model)

    def run_forward(inputs, y_true, s_wght, *params):
        """Run the forward pass in a functional way."""
        y_pred = func_model(params, *inputs)
        s_loss = loss_fn(y_pred, y_true)
        if s_wght is not None:
            s_loss.mul_(s_wght.to(s_loss.device))
        return s_loss.mean()

    def grads_fn(inputs, y_true, s_wght, clip=None):
        """Compute gradients and optionally clip them."""
        params, idxgrd, pnames = get_params(model)
        gfunc = functorch.grad(run_forward, argnums=tuple(idxgrd))
        grads = gfunc(inputs, y_true, (None if clip else s_wght), *params)
        if clip:
            clip_and_scale_grads_inplace(grads, clip, s_wght)
        return dict(zip(pnames, grads))

    # Wrap the former function to compute and clip sample-wise gradients.
    in_dims = ([0] * inputs, 0 if y_true else None, 0 if s_wght else None)
    grads_fn = functorch.vmap(grads_fn, in_dims)
    # Compile the resulting function to decrease runtime costs.
    if not COMPILE_AVAILABLE:  # pragma: no cover
        return grads_fn
    return functorch.compile.aot_function(grads_fn, functorch.compile.nop)


def get_params(
    model: torch.nn.Module,
) -> Tuple[List[torch.nn.Parameter], List[int], List[str]]:
    """Return a model's parameters and the index and name of trainable ones."""
    params = []  # type: List[torch.nn.Parameter]
    idxgrd = []  # type: List[int]
    pnames = []  # type: List[str]
    for idx, (name, param) in enumerate(model.named_parameters()):
        params.append(param)
        if param.requires_grad:
            idxgrd.append(idx + 3)
            pnames.append(name)
    return params, idxgrd, pnames
# coding: utf-8
# Copyright 2023 Inria (Institut National de Recherche en Informatique
# et Automatique)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared code for torch-version-dependent backend code."""
from typing import Callable, Dict, Iterable, List, Optional
import torch
__all__ = [
"GetGradientsFunction",
"clip_and_scale_grads_inplace",
]
GetGradientsFunction = Callable[
[
List[torch.Tensor],
Optional[torch.Tensor],
Optional[torch.Tensor],
Optional[float],
],
Dict[str, torch.Tensor],
]
"""Signature for sample-wise gradients computation functions."""
def clip_and_scale_grads_inplace(
    grads: Iterable[torch.Tensor],
    clip: float,
    wght: Optional[torch.Tensor] = None,
) -> None:
    """Clip a collection of tensors in-place, based on their euclidean norm.

    Also apply an optional weight tensor to scale the clipped gradients.
    """
    for grad in grads:
        norm = torch.norm(grad, p=2, keepdim=True)
        # false-positive; pylint: disable=no-member
        grad.mul_(torch.clamp(clip / norm, max=1))
        if wght is not None:
            grad.mul_(wght.to(grad.device))
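
For illustration, the clipping rule above rescales each tensor in place by min(1, clip / ||grad||_2), then optionally multiplies it by the sample-weight tensor. A small sanity-check sketch, assuming the helper defined above is in scope:

import torch

grads = [torch.full((4,), 5.0), torch.full((4,), 0.5)]  # l2-norms: 10.0 and 1.0
clip_and_scale_grads_inplace(grads, clip=2.0)
print(torch.linalg.norm(grads[0]))  # ~2.0: clipped down to the threshold
print(torch.linalg.norm(grads[1]))  # ~1.0: unchanged, already below the threshold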
# coding: utf-8
# Copyright 2023 Inria (Institut National de Recherche en Informatique
# et Automatique)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of `build_samplewise_grads_fn` for Torch 2.0."""
from typing import Dict, Tuple
import torch
from declearn.model.torch._samplewise.shared import (
GetGradientsFunction,
clip_and_scale_grads_inplace,
)
__all__ = [
"build_samplewise_grads_fn_backend",
]
def build_samplewise_grads_fn_backend(
model: torch.nn.Module,
loss_fn: torch.nn.Module,
inputs: int,
y_true: bool,
s_wght: bool,
) -> GetGradientsFunction:
"""Implementation of `build_samplewise_grads_fn` for Torch 2.0."""
def run_forward(params, frozen, inputs, y_true, s_wght):
"""Run the forward pass in a functional way."""
y_pred = torch.func.functional_call(model, [params, frozen], *inputs)
s_loss = loss_fn(y_pred, y_true)
if s_wght is not None:
s_loss.mul_(s_wght.to(s_loss.device))
return s_loss.mean()
get_grads = torch.func.grad(run_forward, argnums=0)
def get_clipped_grads(inputs, y_true, s_wght, clip=None):
"""Compute gradients and optionally clip them."""
params, frozen = get_params(model)
grads = get_grads(
params, frozen, inputs, y_true, None if clip else s_wght
)
if clip:
clip_and_scale_grads_inplace(grads.values(), clip, s_wght)
return grads
# Wrap the former function to compute and clip sample-wise gradients.
in_dims = ([0] * inputs, 0 if y_true else None, 0 if s_wght else None)
return torch.func.vmap(get_clipped_grads, in_dims)
def get_params(
model: torch.nn.Module,
) -> Tuple[Dict[str, torch.nn.Parameter], Dict[str, torch.nn.Parameter]]:
"""Return a model's parameters, split between trainable and frozen ones."""
params = {} # type: Dict[str, torch.nn.Parameter]
frozen = {} # type: Dict[str, torch.nn.Parameter]
for name, param in model.named_parameters():
(params if param.requires_grad else frozen)[name] = param
return params, frozen
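
A hypothetical sketch of the trainable/frozen split performed by `get_params` above: only the trainable dict is differentiated (via `argnums=0` on `run_forward`), while frozen parameters are still fed to `torch.func.functional_call` (which accepts a sequence of parameter dicts) so the forward pass remains complete. The toy model below is illustrative:

import torch

model = torch.nn.Linear(4, 1)
model.bias.requires_grad = False  # freeze the bias

params = {n: p for n, p in model.named_parameters() if p.requires_grad}
frozen = {n: p for n, p in model.named_parameters() if not p.requires_grad}
print(sorted(params), sorted(frozen))  # ['weight'] ['bias']

x = torch.randn(2, 4)
y_pred = torch.func.functional_call(model, [params, frozen], (x,))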