From 2bca6bdbe71d8c123eb0ab940075479a94e76b28 Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Fri, 27 Jan 2023 12:40:36 +0100
Subject: [PATCH] Update the RMSProp-implementation example.

---
 examples/adding_rmsprop/readme.md | 29 +++++++++++++++--------------
 1 file changed, 15 insertions(+), 14 deletions(-)

diff --git a/examples/adding_rmsprop/readme.md b/examples/adding_rmsprop/readme.md
index 3a4e37a5..7d47d6e9 100644
--- a/examples/adding_rmsprop/readme.md
+++ b/examples/adding_rmsprop/readme.md
@@ -56,7 +56,7 @@ The RMSProp optimizer is part of the adaptative optimizer family, and was thus
 added to `_adaptative.py` file.
 
 ```python
-from declearn.optimizer.modules import OptiModule
+from declearn.optimizer.modules import OptiModule, EWMAModule
 
 
 class RMSPropModule(OptiModule):
@@ -82,37 +82,38 @@ class RMSPropModule(OptiModule):
             to the (divisor) adapative scaling term.
         """
 
-        # Reuse the existing momemtum module, see below
+        # Reuse the existing EWMA module, see below
 
-        self.mom = MomentumModule(beta=beta)
+        self.mom = EWMAModule(beta=beta)
         self.eps = eps
 
     # Allow access to the module's parameters
 
     def get_config(self,) -> Dict[str, Any]:
-        """Return a JSON-serializable dict with this module's parameters."""
-        return {"beta": self.mom.beta, "eps": self.eps}
+        return {"beta": self.ewma.beta, "eps": self.eps}
 
     # Define the actual transformations of the gradient
 
     def run(self, gradients: Vector) -> Vector:
-        """Apply RMSProp adaptation to input (pseudo-)gradients."""
-        v_t = self.mom.run(gradients**2)
+        v_t = self.ewma.run(gradients**2)
         scaling = (v_t**0.5) + self.eps
         return gradients / scaling
+
+    # Define the state-access methods; here states are handled by the EWMA
+
+    def get_state(self) -> Dict[str, Any]:
+        return self.ewma.get_state()
+
+    def set_state(self, state: Dict[str, Any],) -> None:
+        self.ewma.set_state(state)
 ```
 
-We here reuse the Momemtum module, defined in `modules/_base.py`. As a
+We here reuse the EWMA module, defined in the `modules/_momentum.py` file. As a
 module, it takes in a `Vector` and outputs a `Vector`. It has one parameter,
-$`\beta`$, and its `run` method looks like this:
+$`\beta`$, manages a state vector `state` and its `run` method looks like this:
 
 ```python
     def run(self, gradients: Vector) -> Vector:
-        """Apply Momentum acceleration to input (pseudo-)gradients."""
-
-        # Iteratively update the state class attribute with input gradients
-
         self.state = (self.beta * self.state) + ((1 - self.beta) * gradients)
         return self.state
-
 ```
-- 
GitLab