Mentions légales du service

Skip to content
Snippets Groups Projects
Verified Commit 29c6fb3a authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Marginally improve some 'Aggregator' code and docs.

parent f9ee9b9a
No related branches found
No related tags found
1 merge request!44Minor gardening around the package
......@@ -18,7 +18,7 @@
"""Model updates aggregation API."""
from abc import ABCMeta, abstractmethod
from typing import Any, ClassVar, Dict
from typing import Any, ClassVar, Dict, TypeVar
from typing_extensions import Self # future: import from typing (py >=3.11)
......@@ -30,6 +30,9 @@ __all__ = [
]
T = TypeVar("T")
@create_types_registry
class Aggregator(metaclass=ABCMeta):
"""Abstract class defining an API for Vector aggregation.
......@@ -90,9 +93,9 @@ class Aggregator(metaclass=ABCMeta):
@abstractmethod
def aggregate(
self,
updates: Dict[str, Vector],
updates: Dict[str, Vector[T]],
n_steps: Dict[str, int], # revise: abstract~generalize kwargs use
) -> Vector:
) -> Vector[T]:
"""Aggregate input vectors into a single one.
Parameters
......@@ -109,6 +112,11 @@ class Aggregator(metaclass=ABCMeta):
gradients: Vector
Aggregated updates, as a Vector - treated as gradients by
the server-side optimizer.
Raises
------
TypeError
If the input `updates` are an empty dict.
"""
def get_config(
......
......@@ -19,7 +19,6 @@
from typing import Any, ClassVar, Dict, Optional
from typing_extensions import Self # future: import from typing (py >=3.11)
from declearn.aggregator._api import Aggregator
from declearn.model.api import Vector
......@@ -76,13 +75,6 @@ class AveragingAggregator(Aggregator):
"client_weights": self.client_weights,
}
@classmethod
def from_config(
cls,
config: Dict[str, Any],
) -> Self:
return cls(**config)
def aggregate(
self,
updates: Dict[str, Vector],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment