Add sample-wise gradient clipping to the Model API
This Merge Request adds the `max_norm` optional argument to the `Model.compute_batch_gradients` method, effectively making sample-wise gradient clipping part of the declearn API. It also implements efficient sample-wise gradient computation, clipping, and subsequent batch-averaging through framework-specific backend methods, notably making use of `functorch` for `TorchModel` (which is listed as an additional dependency in the pyproject file, but is packaged together with `torch` itself in newer releases).
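As a rough illustration of the backend approach (not declearn's actual implementation), the sketch below uses `functorch` to compute per-sample gradients, clips each sample's full gradient vector to a maximum L2 norm, and then averages over the batch; the model, loss function, and shapes are purely illustrative.

```python
import torch
import functorch

# Illustrative model and data; not taken from the declearn codebase.
model = torch.nn.Linear(4, 1)
fmodel, params = functorch.make_functional(model)

def sample_loss(params, x, y):
    # Loss for a single sample (unsqueezed into a batch of size 1).
    pred = fmodel(params, x.unsqueeze(0))
    return torch.nn.functional.mse_loss(pred, y.unsqueeze(0))

# vmap(grad(...)) yields per-sample gradients in one vectorized call.
per_sample_grads = functorch.vmap(
    functorch.grad(sample_loss), in_dims=(None, 0, 0)
)

inputs = torch.randn(8, 4)
targets = torch.randn(8, 1)
grads = per_sample_grads(params, inputs, targets)  # leading dim == batch size

# Clip each sample's gradients based on their joint L2 norm across all
# parameters, then average over the batch (the max_norm value is illustrative).
max_norm = 1.0
flat = torch.cat([g.flatten(1) for g in grads], dim=1)   # (batch, n_coefs)
scale = (max_norm / (flat.norm(dim=1) + 1e-12)).clamp(max=1.0)
batch_grads = [
    (g * scale.view(-1, *([1] * (g.ndim - 1)))).mean(dim=0)
    for g in grads
]
```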
The first commits added by this MR were cherry-picked out of Merge Request !11 (merged), which aims at implementing Differential Privacy. This was done as part of an effort to clean up the repository's commit history by integrating the contents of the latter MR incrementally.
In addition, the following changes were added exclusively as part of this MR:
- Add an `sclip` parameter to `Optimizer.run_train_step`, which effectively exposes `max_norm` from `Model.compute_batch_gradients` (a usage sketch is provided after this list).
  - The rationale behind making this a method argument rather than an attribute is that it is only used as part of `run_train_step` and has no effect in `apply_gradients`.
  - Another reason behind this choice is that it may enable writing complex clipping schedulers or adaptive algorithms outside of the Optimizer API.
- Add a `compute_updates_from_gradients` method to `Optimizer` (see the second sketch after this list).
  - This method merely splits out most of the code from the `apply_gradients` method.
  - The rationale behind this is to enable modular uses of the Optimizer API, notably in derived applications or in hacky edge cases, and/or to implement future `Optimizer` unit tests.
- Add an `L2Clipping` optimizer plug-in module, implementing batch-level clipping (see the last sketch below).
  - This module anticipates possible requirements to implement central DP, as opposed to sample-level-protecting local DP.
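A hypothetical usage sketch of the `sclip` argument follows; the exact `run_train_step` signature and the decaying-threshold schedule are assumptions, shown only to illustrate how clipping schedules can live outside of the Optimizer API.

```python
# Hypothetical call signatures; shown only to illustrate the intent of `sclip`.
# A fixed sample-wise clipping threshold:
optimizer.run_train_step(model, batch, sclip=1.0)

# A clipping schedule kept outside of the Optimizer API, e.g. a decaying
# threshold (purely illustrative):
for step, batch in enumerate(training_batches):
    threshold = max(0.1, 1.0 * (0.99 ** step))
    optimizer.run_train_step(model, batch, sclip=threshold)
```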
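A second sketch illustrates the kind of modular use that `compute_updates_from_gradients` enables; the method and model call signatures shown here are assumptions, not a verbatim excerpt from the codebase.

```python
# Hypothetical modular training step, bypassing `run_train_step`:
# compute clipped gradients, turn them into updates, then apply them.
gradients = model.compute_batch_gradients(batch, max_norm=1.0)
updates = optimizer.compute_updates_from_gradients(model, gradients)
model.apply_updates(updates)  # assumed Model method; name may differ
```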
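Finally, a generic sketch of batch-level L2 clipping, i.e. rescaling already batch-averaged gradients when their global L2 norm exceeds a threshold; this is a conceptual illustration of the technique, not the `L2Clipping` plug-in's actual interface.

```python
import numpy as np

def clip_averaged_gradients(gradients, max_norm):
    """Rescale batch-averaged gradients to a maximum global L2 norm.

    Conceptual illustration of batch-level (rather than sample-level)
    clipping; the actual L2Clipping plug-in interface may differ.
    """
    total_norm = np.sqrt(sum(float(np.sum(g ** 2)) for g in gradients.values()))
    scale = min(1.0, max_norm / (total_norm + 1e-12))
    return {name: g * scale for name, g in gradients.items()}

# Example: two illustrative gradient arrays, clipped to a global norm <= 1.0.
grads = {"kernel": np.ones((4, 1)), "bias": np.zeros((1,))}
clipped = clip_averaged_gradients(grads, max_norm=1.0)
```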