diff --git a/declearn/optimizer/modules/_clipping.py b/declearn/optimizer/modules/_clipping.py
index 5343950d43c505af8938ccf4b151e21ca42c8f17..3d6617c88b2c247637a28de7f5b3584117efac92 100644
--- a/declearn/optimizer/modules/_clipping.py
+++ b/declearn/optimizer/modules/_clipping.py
@@ -30,11 +30,12 @@ class L2Clipping(OptiModule):
 
     This module implements the following algorithm:
 
-        Init(max_norm):
-        Step(max_norm):
-            norm = euclidean_norm(grads)
-            clip = min(norm, max_norm)
-            grads *= clip / max_norm
+        Init(max_norm, per_grad):
+            assign hyper-parameters
+        Step(grads):
+            norm = euclidean_norm(grads)  # parameter-wise
+            clip = min(max_norm / norm, 1.0)
+            grads *= clip
 
     In other words, (batch-averaged) gradients are clipped
     based on their L2 (euclidean) norm, based on a single,
@@ -67,7 +68,7 @@ class L2Clipping(OptiModule):
         gradients: Vector,
     ) -> Vector:
         l2_norm = (gradients**2).sum() ** 0.5
-        c_scale = (l2_norm / self.max_norm).minimum(1.0)
+        c_scale = (self.max_norm / l2_norm).minimum(1.0)
         return gradients * c_scale
 
     def get_config(
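
For reference, below is a minimal NumPy sketch of the corrected scaling, not declearn code: the `clip_l2` helper is hypothetical, and NumPy stands in for declearn's Vector API (whose `minimum` method is assumed to behave like `numpy.minimum`). The factor min(max_norm / norm, 1.0) caps the gradients' L2 norm at max_norm, whereas the pre-patch factor min(norm / max_norm, 1.0) shrank gradients that were already within the bound and never actually enforced it.

import numpy as np


def clip_l2(grads: np.ndarray, max_norm: float) -> np.ndarray:
    """Scale `grads` so that their global L2 norm is at most `max_norm`."""
    l2_norm = float(np.sqrt(np.sum(grads ** 2)))
    # Corrected formula: shrink only when the norm exceeds the threshold.
    # (The all-zero-gradients edge case is not handled in this sketch.)
    c_scale = min(max_norm / l2_norm, 1.0)
    return grads * c_scale


large = np.array([3.0, 4.0])                    # L2 norm = 5.0
print(np.linalg.norm(clip_l2(large, 1.0)))      # ~1.0 -> capped at max_norm
small = np.array([0.3, 0.4])                    # L2 norm = 0.5
print(np.allclose(clip_l2(small, 1.0), small))  # True -> left untouched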