Commit 2bca6bdb authored by ANDREY Paul

Update the RMSProp-implementation example.

parent 5df95e75
1 merge request: !19 Revise `Optimizer` and add state-access methods.
Pipeline #745949 failed
@@ -56,7 +56,7 @@ The RMSProp optimizer is part of the adaptative optimizer family, and was thus
 added to `_adaptative.py` file.
 ```python
-from declearn.optimizer.modules import OptiModule
+from declearn.optimizer.modules import OptiModule, EWMAModule
 class RMSPropModule(OptiModule):
@@ -82,37 +82,38 @@ class RMSPropModule(OptiModule):
         to the (divisor) adapative scaling term.
         """
-        # Reuse the existing momemtum module, see below
-        self.mom = MomentumModule(beta=beta)
+        # Reuse the existing EWMA module, see below
+        self.mom = EWMAModule(beta=beta)
         self.eps = eps
     # Allow access to the module's parameters
     def get_config(self,) -> Dict[str, Any]:
-        """Return a JSON-serializable dict with this module's parameters."""
-        return {"beta": self.mom.beta, "eps": self.eps}
+        return {"beta": self.ewma.beta, "eps": self.eps}
     # Define the actual transformations of the gradient
     def run(self, gradients: Vector) -> Vector:
-        """Apply RMSProp adaptation to input (pseudo-)gradients."""
-        v_t = self.mom.run(gradients**2)
+        v_t = self.ewma.run(gradients**2)
         scaling = (v_t**0.5) + self.eps
         return gradients / scaling
+    # Define the state-access methods; here states are handled by the EWMA
+    def get_state(self) -> Dict[str, Any]:
+        return self.ewma.get_state()
+    def set_state(self, state: Dict[str, Any],) -> None:
+        self.ewma.set_state(state)
 ```
-We here reuse the Momemtum module, defined in `modules/_base.py`. As a
+We here reuse the EWMA module, defined in the `modules/_momentum.py` file. As a
 module, it takes in a `Vector` and outputs a `Vector`. It has one parameter,
-$`\beta`$, and its `run` method looks like this:
+$`\beta`$, manages a state vector `state` and its `run` method looks like this:
 ```python
 def run(self, gradients: Vector) -> Vector:
-    """Apply Momentum acceleration to input (pseudo-)gradients."""
-    # Iteratively update the state class attribute with input gradients
     self.state = (self.beta * self.state) + ((1 - self.beta) * gradients)
     return self.state
 ```
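
To see how the updated example fits together, here is a minimal standalone sketch that mirrors it with plain NumPy arrays in place of declearn's `Vector`. The class names `EWMASketch` and `RMSPropSketch`, the `{"state": ...}` layout returned by `get_state`, and the zero initial state are assumptions made for illustration; they are not declearn's actual `EWMAModule` or `RMSPropModule`. The wrapped module is stored under the single attribute name `ewma` throughout.

```python
from typing import Any, Dict

import numpy as np


class EWMASketch:
    """Hypothetical stand-in for an EWMA module (not declearn's class)."""

    def __init__(self, beta: float = 0.9) -> None:
        self.beta = beta
        # Assumption: start from a scalar zero, which broadcasts on first use.
        self.state: Any = 0.0

    def run(self, gradients: np.ndarray) -> np.ndarray:
        # Same update rule as the `run` method shown in the diff.
        self.state = (self.beta * self.state) + ((1 - self.beta) * gradients)
        return self.state

    def get_state(self) -> Dict[str, Any]:
        # Assumed state layout: a single "state" entry.
        return {"state": self.state}

    def set_state(self, state: Dict[str, Any]) -> None:
        self.state = state["state"]


class RMSPropSketch:
    """Hypothetical RMSProp wrapper that delegates its state to the EWMA."""

    def __init__(self, beta: float = 0.9, eps: float = 1e-7) -> None:
        self.ewma = EWMASketch(beta=beta)
        self.eps = eps

    def get_config(self) -> Dict[str, Any]:
        return {"beta": self.ewma.beta, "eps": self.eps}

    def run(self, gradients: np.ndarray) -> np.ndarray:
        # Running average of squared gradients, used as a divisor.
        v_t = self.ewma.run(gradients**2)
        scaling = (v_t**0.5) + self.eps
        return gradients / scaling

    def get_state(self) -> Dict[str, Any]:
        return self.ewma.get_state()

    def set_state(self, state: Dict[str, Any]) -> None:
        self.ewma.set_state(state)


grads = np.array([1.0, -2.0, 0.5])
module = RMSPropSketch(beta=0.9, eps=1e-7)
first = module.run(grads)
saved = module.get_state()  # snapshot taken after the first step
second = module.run(grads)

# A fresh module restored from the snapshot reproduces the second step.
restored = RMSPropSketch(**module.get_config())
restored.set_state(saved)
replayed = restored.run(grads)
```

Under these assumptions, saving the state after one step and loading it into a fresh module yields the same output as continuing with the original module, which is the behaviour the state-access methods are meant to enable.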