Mentions légales du service

Skip to content
Snippets Groups Projects
Verified Commit fd8abbd3 authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Add the 'list_aggregators' util function.

parent c3a61861
No related branches found
No related tags found
1 merge request!44Minor gardening around the package
......@@ -17,22 +17,32 @@
"""Model updates aggregating API and implementations.
An Aggregator is typically meant to be used on a round-wise basis by
An `Aggregator` is typically meant to be used on a round-wise basis by
the orchestrating server of a centralized federated learning process,
to aggregate the client-wise model updated into a Vector that may then
be used as "gradients" by the server's Optimizer to update the global
to aggregate the client-wise model updated into a `Vector` that may then
be used as "gradients" by the server's `Optimizer` to update the global
model.
This declearn submodule provides with:
API tools
---------
* [Aggregator][declearn.aggregator.Aggregator]:
abstract class defining an API for Vector aggregation
Abstract base class defining an API for Vector aggregation.
* [list_aggregators][declearn.aggregator.list_aggregators]:
Return a mapping of registered Aggregator subclasses.
Concrete classes
----------------
* [AveragingAggregator][declearn.aggregator.AveragingAggregator]:
average-based-aggregation Aggregator subclass
Average-based-aggregation Aggregator subclass.
* [GradientMaskedAveraging][declearn.aggregator.GradientMaskedAveraging]:
gradient Masked Averaging Aggregator subclass
Gradient Masked Averaging Aggregator subclass.
"""
from ._api import Aggregator
from ._api import Aggregator, list_aggregators
from ._base import AveragingAggregator
from ._gma import GradientMaskedAveraging
......@@ -18,12 +18,16 @@
"""Model updates aggregation API."""
from abc import ABCMeta, abstractmethod
from typing import Any, ClassVar, Dict, TypeVar
from typing import Any, ClassVar, Dict, Type, TypeVar
from typing_extensions import Self # future: import from typing (py >=3.11)
from declearn.model.api import Vector
from declearn.utils import create_types_registry, register_type
from declearn.utils import (
access_types_mapping,
create_types_registry,
register_type,
)
__all__ = [
"Aggregator",
......@@ -132,3 +136,29 @@ class Aggregator(metaclass=ABCMeta):
) -> Self:
"""Instantiate an Aggregator from its configuration dict."""
return cls(**config)
def list_aggregators() -> Dict[str, Type[Aggregator]]:
"""Return a mapping of registered Aggregator subclasses.
This function aims at making it easy for end-users to list and access
all available Aggregator classes at any given time. The returned dict
uses unique identifier keys, which may be used to specify the desired
algorithm as part of a federated learning process without going through
the fuss of importing and instantiating it manually.
Note that the mapping will include all declearn-provided plug-ins,
but also registered plug-ins provided by user or third-party code.
See also
--------
* [declearn.aggregator.Aggregator][]:
API-defining abstract base class for the aggregation algorithms.
Returns
-------
mapping:
Dictionary mapping unique str identifiers to `Aggregator` class
constructors.
"""
return access_types_mapping("Aggregator")
......@@ -22,7 +22,7 @@ from typing import Dict, Type
import pytest
from declearn.aggregator import Aggregator
from declearn.aggregator import Aggregator, list_aggregators
from declearn.model.api import Vector
from declearn.test_utils import (
FrameworkType,
......@@ -30,10 +30,9 @@ from declearn.test_utils import (
assert_dict_equal,
assert_json_serializable_dict,
)
from declearn.utils import access_types_mapping
AGGREGATOR_CLASSES = access_types_mapping("Aggregator")
AGGREGATOR_CLASSES = list_aggregators()
VECTOR_FRAMEWORKS = typing.get_args(FrameworkType)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment