From ef8b5662e43533cd59f439f467b1ac08e7a419e3 Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Tue, 4 Jul 2023 13:43:23 +0200
Subject: [PATCH] Add support for torch model buffers in clipped gradients
 computation.

---
 declearn/model/torch/_samplewise/functorch.py | 15 +++++++++------
 declearn/model/torch/_samplewise/torchfunc.py | 12 ++++++++----
 2 files changed, 17 insertions(+), 10 deletions(-)

diff --git a/declearn/model/torch/_samplewise/functorch.py b/declearn/model/torch/_samplewise/functorch.py
index fc8e613b..d39a3e16 100644
--- a/declearn/model/torch/_samplewise/functorch.py
+++ b/declearn/model/torch/_samplewise/functorch.py
@@ -49,11 +49,11 @@ def build_samplewise_grads_fn_backend(
 ) -> GetGradientsFunction:
     """Implementation of `build_samplewise_grads_fn` for Torch 1.1X."""
 
-    func_model, _ = functorch.make_functional(model)
+    func_model, *_ = functorch.make_functional_with_buffers(model)
 
-    def run_forward(inputs, y_true, s_wght, *params):
+    def run_forward(inputs, y_true, s_wght, buffers, *params):
         """Run the forward pass in a functional way."""
-        y_pred = func_model(params, *inputs)
+        y_pred = func_model(params, buffers, *inputs)
         s_loss = loss_fn(y_pred, y_true)
         if s_wght is not None:
             s_loss.mul_(s_wght.to(s_loss.device))
@@ -62,15 +62,18 @@ def build_samplewise_grads_fn_backend(
     def grads_fn(inputs, y_true, s_wght, clip=None):
         """Compute gradients and optionally clip them."""
         params, idxgrd, pnames = get_params(model)
+        buffers = list(model.buffers())
         gfunc = functorch.grad(run_forward, argnums=tuple(idxgrd))
-        grads = gfunc(inputs, y_true, (None if clip else s_wght), *params)
+        grads = gfunc(
+            inputs, y_true, (None if clip else s_wght), buffers, *params
+        )
         if clip:
             clip_and_scale_grads_inplace(grads, clip, s_wght)
         return dict(zip(pnames, grads))
 
     # Wrap the former function to compute and clip sample-wise gradients.
     in_dims = ([0] * inputs, 0 if y_true else None, 0 if s_wght else None)
-    grads_fn = functorch.vmap(grads_fn, in_dims)
+    grads_fn = functorch.vmap(grads_fn, in_dims, randomness="same")
     # Compile the resulting function to decrease runtime costs.
     if not COMPILE_AVAILABLE:  # pragma: no cover
@@ -88,6 +91,6 @@ def get_params(
     for idx, (name, param) in enumerate(model.named_parameters()):
         params.append(param)
         if param.requires_grad:
-            idxgrd.append(idx + 3)
+            idxgrd.append(idx + 4)
             pnames.append(name)
     return params, idxgrd, pnames
diff --git a/declearn/model/torch/_samplewise/torchfunc.py b/declearn/model/torch/_samplewise/torchfunc.py
index d330d4f9..88aa5b7d 100644
--- a/declearn/model/torch/_samplewise/torchfunc.py
+++ b/declearn/model/torch/_samplewise/torchfunc.py
@@ -40,9 +40,12 @@ def build_samplewise_grads_fn_backend(
 ) -> GetGradientsFunction:
     """Implementation of `build_samplewise_grads_fn` for Torch 2.0."""
 
-    def run_forward(params, frozen, inputs, y_true, s_wght):
+    def run_forward(params, frozen, buffers, inputs, y_true, s_wght):
         """Run the forward pass in a functional way."""
-        y_pred = torch.func.functional_call(model, [params, frozen], *inputs)
+        # backend closure function; pylint: disable=too-many-arguments
+        y_pred = torch.func.functional_call(
+            model, [params, frozen, buffers], *inputs
+        )
         s_loss = loss_fn(y_pred, y_true)
         if s_wght is not None:
             s_loss.mul_(s_wght.to(s_loss.device))
@@ -53,8 +56,9 @@ def build_samplewise_grads_fn_backend(
     def get_clipped_grads(inputs, y_true, s_wght, clip=None):
         """Compute gradients and optionally clip them."""
         params, frozen = get_params(model)
+        buffers = dict(model.named_buffers())
         grads = get_grads(
-            params, frozen, inputs, y_true, None if clip else s_wght
+            params, frozen, buffers, inputs, y_true, None if clip else s_wght
        )
         if clip:
             clip_and_scale_grads_inplace(grads.values(), clip, s_wght)
@@ -62,7 +66,7 @@ def build_samplewise_grads_fn_backend(
 
     # Wrap the former function to compute and clip sample-wise gradients.
     in_dims = ([0] * inputs, 0 if y_true else None, 0 if s_wght else None)
-    return torch.func.vmap(get_clipped_grads, in_dims)
+    return torch.func.vmap(get_clipped_grads, in_dims, randomness="same")
 
 
 def get_params(
--
GitLab
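
For context (outside the patch itself): both backends follow the standard functional per-sample-gradients recipe, which this change extends so that the model's buffers (e.g. BatchNorm running statistics) are threaded through the functional call. Below is a minimal, self-contained sketch of that recipe for the Torch 2.0 (`torch.func`) path; the toy model, loss function and tensor shapes are illustrative assumptions, not declearn code.

# Illustrative sketch only (assumes PyTorch >= 2.0); not part of the patch.
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(4, 8),
    torch.nn.BatchNorm1d(8),  # owns running-stat buffers, hence this patch
    torch.nn.ReLU(),
    torch.nn.Linear(8, 1),
)
loss_fn = torch.nn.MSELoss()
# Eval mode: read running stats without the in-place buffer updates
# that torch.func.vmap would reject.
model.eval()

params = {k: v.detach() for k, v in model.named_parameters()}
buffers = {k: v.detach() for k, v in model.named_buffers()}

def sample_loss(params, buffers, x, y):
    """Compute the loss on a single sample via a stateless call."""
    y_pred = torch.func.functional_call(
        model, [params, buffers], x.unsqueeze(0)
    )
    return loss_fn(y_pred, y.unsqueeze(0))

# Differentiate w.r.t. params (argnums=0), then vectorize over the batch;
# params and buffers are shared across samples (in_dims=None).
grads_fn = torch.func.vmap(
    torch.func.grad(sample_loss, argnums=0),
    in_dims=(None, None, 0, 0),
    randomness="same",
)

x = torch.randn(16, 4)
y = torch.randn(16, 1)
per_sample_grads = grads_fn(params, buffers, x, y)
print({k: tuple(v.shape) for k, v in per_sample_grads.items()})  # leading dim 16

The patched declearn backends apply the same idea, with loss weighting and norm clipping handled inside the vmapped closure; `randomness="same"` makes random operations (e.g. dropout) draw the same values across the vectorized per-sample calls.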