diff --git a/declearn/optimizer/__init__.py b/declearn/optimizer/__init__.py index 520eeafa235b9b7ebce84c7327c96455deb09de0..44cc4396bfb6bfec5e264a18ce15b371948550d4 100644 --- a/declearn/optimizer/__init__.py +++ b/declearn/optimizer/__init__.py @@ -32,8 +32,16 @@ Submodules providing with plug-in algorithms: Gradients-alteration algorithms, implemented as plug-in modules. * [regularizers][declearn.optimizer.regularizers]: Loss-regularization algorithms, implemented as plug-in modules. + +Utils to list available plug-ins: + +* [list_optim_modules][declearn.optimizer.list_optim_modules]: + Return a mapping of registered OptiModule subclasses. +* [list_optim_regularizers][declearn.optimizer.list_optim_regularizers]: + Return a mapping of registered Regularizer subclasses. """ from . import modules, regularizers from ._base import Optimizer +from ._utils import list_optim_modules, list_optim_regularizers diff --git a/declearn/optimizer/_base.py b/declearn/optimizer/_base.py index 7480d8f8b077d0cf79ca369ce12c65e516182675..01214cd26bb1db1cc3fb496a5805f4c148f59068 100644 --- a/declearn/optimizer/_base.py +++ b/declearn/optimizer/_base.py @@ -109,6 +109,13 @@ class Optimizer: [1] Loshchilov & Hutter, 2019. Decoupled Weight Decay Regularization. https://arxiv.org/abs/1711.05101 + + See also + -------- + - [declearn.optimizer.list_optim_modules][]: + Return a mapping of registered OptiModule subclasses. + - [declearn.optimizer.list_optim_regularizers][]: + Return a mapping of registered Regularizer subclasses. """ def __init__( diff --git a/declearn/optimizer/_utils.py b/declearn/optimizer/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..acf56ace356d8917843db4a4a5daa18ee28bb127 --- /dev/null +++ b/declearn/optimizer/_utils.py @@ -0,0 +1,86 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utils to list available optimizer plug-ins (OptiModule and Regularizer).""" + +from typing import Dict, Type + +from declearn.optimizer.modules import OptiModule +from declearn.optimizer.regularizers import Regularizer +from declearn.utils import access_types_mapping + + +__all__ = [ + "list_optim_modules", + "list_optim_regularizers", +] + + +def list_optim_modules() -> Dict[str, Type[OptiModule]]: + """Return a mapping of registered OptiModule subclasses. + + This function aims at making it easy for end-users to list and access + all available OptiModule optimizer plug-ins at any given time. The + returned dict uses unique identifier keys, which may be used to add + the associated plug-in to a [declearn.optimizer.Optimizer][] without + going through the fuss of importing and instantiating it manually. + + Note that the mapping will include all declearn-provided plug-ins, + but also registered plug-ins provided by user or third-party code. + + See also + -------- + * [declearn.optimizer.modules.OptiModule][]: + API-defining abstract base class for the OptiModule plug-ins. + * [declearn.optimizer.list_optim_regularizers][]: + Counterpart function for Regularizer plug-ins. + + Returns + ------- + mapping: + Dictionary mapping unique str identifiers to OptiModule + class constructors. + """ + return access_types_mapping("OptiModule") + + +def list_optim_regularizers() -> Dict[str, Type[Regularizer]]: + """Return a mapping of registered Regularizer subclasses. + + This function aims at making it easy for end-users to list and access + all available Regularizer optimizer plug-ins at any given time. The + returned dict uses unique identifier keys, which may be used to add + the associated plug-in to a [declearn.optimizer.Optimizer][] without + going through the fuss of importing and instantiating it manually. + + Note that the mapping will include all declearn-provided plug-ins, + but also registered plug-ins provided by user or third-party code. + + See also + -------- + * [declearn.optimizer.regularizers.Regularizer][]: + API-defining abstract base class for the Regularizer plug-ins. + * [declearn.optimizer.list_optim_modules][]: + Counterpart function for OptiModule plug-ins. + + Returns + ------- + mapping: + Dictionary mapping unique str identifiers to Regularizer + class constructors. + """ + return access_types_mapping("Regularizer")