From 19bfaf8e1ea88a940f03286406e30f1f10015dd8 Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Thu, 3 Aug 2023 17:46:41 +0200
Subject: [PATCH] Fix 'L2Clipping' algorithm.
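
The previous implementation computed the scaling factor as
min(norm / max_norm, 1.0): gradients whose norm was below the
threshold were wrongly shrunk, while gradients whose norm exceeded
it were never actually rescaled down to max_norm. The fix inverts
the ratio, yielding scale = min(max_norm / norm, 1.0).

A minimal self-contained sketch of the two formulas, using plain
numpy arrays in place of declearn's 'Vector' API (function names
are illustrative only, not part of the library):

    import numpy as np

    def scale_before_fix(grads: np.ndarray, max_norm: float) -> float:
        # Inverted ratio: shrinks small gradients, never enforces max_norm.
        norm = float(np.sqrt((grads ** 2).sum()))
        return min(norm / max_norm, 1.0)

    def scale_after_fix(grads: np.ndarray, max_norm: float) -> float:
        # Correct ratio: identity below the threshold, rescales above it.
        norm = float(np.sqrt((grads ** 2).sum()))
        return min(max_norm / norm, 1.0)

    grads = np.array([3.0, 4.0])            # L2 norm = 5.0
    print(scale_before_fix(grads, 10.0))    # 0.5 -> small gradient wrongly halved
    print(scale_after_fix(grads, 10.0))     # 1.0 -> left untouched, as intended
    print(scale_after_fix(grads, 1.0))      # 0.2 -> norm rescaled from 5.0 to 1.0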

---
 declearn/optimizer/modules/_clipping.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/declearn/optimizer/modules/_clipping.py b/declearn/optimizer/modules/_clipping.py
index 5343950d..3d6617c8 100644
--- a/declearn/optimizer/modules/_clipping.py
+++ b/declearn/optimizer/modules/_clipping.py
@@ -30,11 +30,12 @@ class L2Clipping(OptiModule):
 
     This module implements the following algorithm:
 
-        Init(max_norm):
-        Step(max_norm):
-            norm = euclidean_norm(grads)
-            clip = min(norm, max_norm)
-            grads *= clip / max_norm
+        Init(max_norm, per_grad):
+            assign hyper-parameters
+        Step(grads):
+            norm = euclidean_norm(grads)  # parameter-wise
+            clip = min(max_norm / norm, 1.0)
+            grads *= clip
 
     In other words, (batch-averaged) gradients are clipped
 based on their L2 (euclidean) norm, using a single,
@@ -67,7 +68,7 @@ class L2Clipping(OptiModule):
         gradients: Vector,
     ) -> Vector:
         l2_norm = (gradients**2).sum() ** 0.5
-        c_scale = (l2_norm / self.max_norm).minimum(1.0)
+        c_scale = (self.max_norm / l2_norm).minimum(1.0)
         return gradients * c_scale
 
     def get_config(
-- 
GitLab