diff --git a/declearn/aggregator/__init__.py b/declearn/aggregator/__init__.py index 76c285654fb199870664407309608b7e8a3e9574..1ac99d21dd7d75d4198cdb10de71cd61cce79f6e 100644 --- a/declearn/aggregator/__init__.py +++ b/declearn/aggregator/__init__.py @@ -17,22 +17,32 @@ """Model updates aggregating API and implementations. -An Aggregator is typically meant to be used on a round-wise basis by +An `Aggregator` is typically meant to be used on a round-wise basis by the orchestrating server of a centralized federated learning process, -to aggregate the client-wise model updated into a Vector that may then -be used as "gradients" by the server's Optimizer to update the global +to aggregate the client-wise model updated into a `Vector` that may then +be used as "gradients" by the server's `Optimizer` to update the global model. This declearn submodule provides with: +API tools +--------- + * [Aggregator][declearn.aggregator.Aggregator]: - abstract class defining an API for Vector aggregation + Abstract base class defining an API for Vector aggregation. +* [list_aggregators][declearn.aggregator.list_aggregators]: + Return a mapping of registered Aggregator subclasses. + + +Concrete classes +---------------- + * [AveragingAggregator][declearn.aggregator.AveragingAggregator]: - average-based-aggregation Aggregator subclass + Average-based-aggregation Aggregator subclass. * [GradientMaskedAveraging][declearn.aggregator.GradientMaskedAveraging]: - gradient Masked Averaging Aggregator subclass + Gradient Masked Averaging Aggregator subclass. """ -from ._api import Aggregator +from ._api import Aggregator, list_aggregators from ._base import AveragingAggregator from ._gma import GradientMaskedAveraging diff --git a/declearn/aggregator/_api.py b/declearn/aggregator/_api.py index 341e0ac169f74af6b942f7dce2f8fe452fc70a9d..778b1d3e0c92a7e0c0baedea85e5b35ae6ace0a7 100644 --- a/declearn/aggregator/_api.py +++ b/declearn/aggregator/_api.py @@ -18,12 +18,16 @@ """Model updates aggregation API.""" from abc import ABCMeta, abstractmethod -from typing import Any, ClassVar, Dict, TypeVar +from typing import Any, ClassVar, Dict, Type, TypeVar from typing_extensions import Self # future: import from typing (py >=3.11) from declearn.model.api import Vector -from declearn.utils import create_types_registry, register_type +from declearn.utils import ( + access_types_mapping, + create_types_registry, + register_type, +) __all__ = [ "Aggregator", @@ -132,3 +136,29 @@ class Aggregator(metaclass=ABCMeta): ) -> Self: """Instantiate an Aggregator from its configuration dict.""" return cls(**config) + + +def list_aggregators() -> Dict[str, Type[Aggregator]]: + """Return a mapping of registered Aggregator subclasses. + + This function aims at making it easy for end-users to list and access + all available Aggregator classes at any given time. The returned dict + uses unique identifier keys, which may be used to specify the desired + algorithm as part of a federated learning process without going through + the fuss of importing and instantiating it manually. + + Note that the mapping will include all declearn-provided plug-ins, + but also registered plug-ins provided by user or third-party code. + + See also + -------- + * [declearn.aggregator.Aggregator][]: + API-defining abstract base class for the aggregation algorithms. + + Returns + ------- + mapping: + Dictionary mapping unique str identifiers to `Aggregator` class + constructors. + """ + return access_types_mapping("Aggregator") diff --git a/test/aggregator/test_aggregator.py b/test/aggregator/test_aggregator.py index a12966619dde4456c7a4d5f64cf40e7d4927fb9d..f1b02def27aee7900813a73bef35d4dfaef8d24b 100644 --- a/test/aggregator/test_aggregator.py +++ b/test/aggregator/test_aggregator.py @@ -22,7 +22,7 @@ from typing import Dict, Type import pytest -from declearn.aggregator import Aggregator +from declearn.aggregator import Aggregator, list_aggregators from declearn.model.api import Vector from declearn.test_utils import ( FrameworkType, @@ -30,10 +30,9 @@ from declearn.test_utils import ( assert_dict_equal, assert_json_serializable_dict, ) -from declearn.utils import access_types_mapping -AGGREGATOR_CLASSES = access_types_mapping("Aggregator") +AGGREGATOR_CLASSES = list_aggregators() VECTOR_FRAMEWORKS = typing.get_args(FrameworkType)