Verified Commit c64886ea authored by ANDREY Paul

Fix 'L2Clipping' algorithm.

parent d4f0e84d
@@ -30,11 +30,12 @@ class L2Clipping(OptiModule):
     This module implements the following algorithm:
-        Init(max_norm):
-        Step(max_norm):
-            norm = euclidean_norm(grads)
-            clip = min(norm, max_norm)
-            grads *= clip / max_norm
+        Init(max_norm, per_grad):
+            assign hyper-parameters
+        Step(grads):
+            norm = euclidean_norm(grads)  # parameter-wise
+            clip = min(max_norm / norm, 1.0)
+            grads *= clip
 
     In other words, (batch-averaged) gradients are clipped
     based on their L2 (euclidean) norm, based on a single,
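
To make the corrected docstring rule concrete, here is a minimal standalone sketch. It uses plain NumPy arrays rather than declearn's Vector abstraction, and the l2_clip helper is made up for illustration: gradients whose euclidean norm exceeds max_norm are rescaled to that norm, while in-bound gradients pass through unchanged.

    import numpy as np

    def l2_clip(grads: np.ndarray, max_norm: float) -> np.ndarray:
        # Hypothetical helper mirroring the docstring's Step(grads) rule;
        # not declearn's actual API.
        norm = float(np.sqrt((grads ** 2).sum()))
        # Only ever shrink: the scale is 1.0 whenever the norm is in bounds.
        scale = min(max_norm / norm, 1.0) if norm > 0.0 else 1.0
        return grads * scale

    big = np.array([3.0, 4.0])    # norm 5.0 -> rescaled to norm 2.0
    small = np.array([0.3, 0.4])  # norm 0.5 -> returned unchanged
    assert np.isclose(np.linalg.norm(l2_clip(big, max_norm=2.0)), 2.0)
    assert np.allclose(l2_clip(small, max_norm=2.0), small)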
@@ -67,7 +68,7 @@ class L2Clipping(OptiModule):
         gradients: Vector,
     ) -> Vector:
         l2_norm = (gradients**2).sum() ** 0.5
-        c_scale = (l2_norm / self.max_norm).minimum(1.0)
+        c_scale = (self.max_norm / l2_norm).minimum(1.0)
         return gradients * c_scale
 
     def get_config(
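
The one-line change is easiest to see numerically. Below is a short sketch, with plain NumPy standing in for declearn's Vector type, comparing the two scaling factors on an in-bound and an oversized gradient:

    import numpy as np

    max_norm = 1.0
    for grads in (np.array([0.3, 0.4]), np.array([3.0, 4.0])):  # norms 0.5, 5.0
        l2_norm = (grads ** 2).sum() ** 0.5
        old_scale = min(l2_norm / max_norm, 1.0)  # pre-fix formula
        new_scale = min(max_norm / l2_norm, 1.0)  # post-fix formula
        print(f"norm={l2_norm}: old scale {old_scale}, new scale {new_scale}")
    # norm=0.5: old scale 0.5 (wrongly shrinks an in-bound gradient), new 1.0
    # norm=5.0: old scale 1.0 (never clips at all), new 0.2 (clips to max_norm)

In other words, the pre-fix expression inverted the intended behavior: it attenuated well-behaved gradients and passed oversized ones through untouched, which is what this commit corrects.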