From 2bca6bdbe71d8c123eb0ab940075479a94e76b28 Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Fri, 27 Jan 2023 12:40:36 +0100
Subject: [PATCH] Update the RMSProp-implementation example.

---
 examples/adding_rmsprop/readme.md | 29 +++++++++++++++--------------
 1 file changed, 15 insertions(+), 14 deletions(-)

diff --git a/examples/adding_rmsprop/readme.md b/examples/adding_rmsprop/readme.md
index 3a4e37a5..7d47d6e9 100644
--- a/examples/adding_rmsprop/readme.md
+++ b/examples/adding_rmsprop/readme.md
@@ -56,7 +56,7 @@ The RMSProp optimizer is part of the adaptative optimizer family, and was
 thus added to `_adaptative.py` file.
 
 ```python
-from declearn.optimizer.modules import OptiModule
+from declearn.optimizer.modules import OptiModule, EWMAModule
 
 
 class RMSPropModule(OptiModule):
@@ -82,37 +82,38 @@ class RMSPropModule(OptiModule):
         to the (divisor) adapative scaling term.
         """
-        # Reuse the existing momemtum module, see below
+        # Reuse the existing EWMA module, see below
-        self.mom = MomentumModule(beta=beta)
+        self.ewma = EWMAModule(beta=beta)
         self.eps = eps
 
     # Allow access to the module's parameters
 
     def get_config(self,) -> Dict[str, Any]:
-        """Return a JSON-serializable dict with this module's parameters."""
-        return {"beta": self.mom.beta, "eps": self.eps}
+        return {"beta": self.ewma.beta, "eps": self.eps}
 
     # Define the actual transformations of the gradient
 
     def run(self, gradients: Vector) -> Vector:
-        """Apply RMSProp adaptation to input (pseudo-)gradients."""
-        v_t = self.mom.run(gradients**2)
+        v_t = self.ewma.run(gradients**2)
         scaling = (v_t**0.5) + self.eps
         return gradients / scaling
+
+    # Define the state-access methods; here, states are handled by the EWMA
+
+    def get_state(self) -> Dict[str, Any]:
+        return self.ewma.get_state()
+
+    def set_state(self, state: Dict[str, Any],) -> None:
+        self.ewma.set_state(state)
 ```
 
-We here reuse the Momemtum module, defined in `modules/_base.py`. As a
+We here reuse the EWMA module, defined in the `modules/_momentum.py` file. As a
 module, it takes in a `Vector` and outputs a `Vector`. It has one parameter,
-$`\beta`$, and its `run` method looks like this:
+$`\beta`$, and maintains a state vector `state`; its `run` method looks like this:
 
 ```python
 def run(self, gradients: Vector) -> Vector:
-    """Apply Momentum acceleration to input (pseudo-)gradients."""
-
-    # Iteratively update the state class attribute with input gradients
     self.state = (self.beta * self.state) + ((1 - self.beta) * gradients)
     return self.state
-
 ```
--
GitLab
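
To see the patched transform end to end, below is a minimal standalone sketch of the
same logic, using NumPy arrays in place of declearn's `Vector` type; the function
name `rmsprop_transform` and the default `beta` and `eps` values are illustrative,
not part of declearn's API:

```python
from typing import Tuple

import numpy as np


def rmsprop_transform(
    gradients: np.ndarray,
    state: np.ndarray,
    beta: float = 0.9,
    eps: float = 1e-7,
) -> Tuple[np.ndarray, np.ndarray]:
    """Scale gradients by the root of an EWMA of their squared values."""
    # EWMA step over squared gradients, mirroring EWMAModule.run.
    state = (beta * state) + ((1 - beta) * gradients**2)  # i.e. v_t
    # RMSProp scaling, mirroring RMSPropModule.run.
    return gradients / ((state**0.5) + eps), state


# Usage sketch: the EWMA state persists across successive calls, which is
# why the module above delegates get_state/set_state to its EWMA sub-module.
gradients = np.array([0.5, -1.0, 2.0])
state = np.zeros_like(gradients)
for _ in range(3):
    scaled, state = rmsprop_transform(gradients, state)
```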