Mentions légales du service

Skip to content
Snippets Groups Projects
Verified Commit c47dfc84 authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Add 'FairnessControllerServer.from_specs' generic constructor.

This generic constructor enables parsing any subclass instance from
serializable specifications. This is notably useful to ensure that
`FLOptimConfig` remains compatible with TOML parsing even when a
fairness algorithm is used.

Dedicated tests and further documentation are yet to be added.
parent f9ec48be
No related branches found
No related tags found
1 merge request!69Enable Fairness-Aware Federated Learning
......@@ -39,7 +39,11 @@ from declearn.secagg.messaging import (
SecaggFairnessCounts,
SecaggFairnessReply,
)
from declearn.utils import create_types_registry, register_type
from declearn.utils import (
access_registered,
create_types_registry,
register_type,
)
__all__ = [
"FairnessControllerServer",
......@@ -430,3 +434,46 @@ class FairnessControllerServer(metaclass=abc.ABCMeta):
Fairness(-related) metrics computed as part of this routine,
as a dict mapping scalar or numpy array values with their name.
"""
@staticmethod
def from_specs(
algorithm: str,
f_type: str,
f_args: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> "FairnessControllerServer":
"""Instantiate a 'FairnessControllerServer' from its specifications.
Parameters
----------
algorithm:
Name of the algorithm associated with the target controller class.
f_type:
Name of the fairness function to evaluate and optimize.
f_args:
Optional dict of keyword arguments to the fairness function.
**kwargs:
Any additional algorithm-specific instantiation keyword argument.
Returns
-------
controller:
`FairnessControllerServer` instance matching input specifications.
Raises
------
KeyError
If `algorithm` does not match any registered
`FairnessControllerServer` type.
"""
try:
cls = access_registered(
name=algorithm, group="FairnessControllerServer"
)
except Exception as exc:
raise KeyError(
"Failed to retrieve fairness controller with algorithm name "
f"'{algorithm}'."
) from exc
assert issubclass(cls, FairnessControllerServer)
return cls(f_type=f_type, f_args=f_args, **kwargs)
......@@ -110,7 +110,14 @@ class FLOptimConfig(TomlConfig):
field: dataclasses.Field, # future: dataclasses.Field[Optimizer]
inputs: Union[float, Dict[str, Any], Optimizer],
) -> Optimizer:
"""Field-specific parser to instantiate the client-side Optimizer."""
"""Field-specific parser to instantiate the client-side Optimizer.
This method supports specifying `client_opt`:
- as a float, parsed as the learning rate to a basic SGD optimzier
- as a dict, parsed a serialized Optimizer configuration
- as an `Optimizer` instance (requiring no parsing)
"""
return cls._parse_optimizer(field, inputs)
@classmethod
......@@ -119,7 +126,15 @@ class FLOptimConfig(TomlConfig):
field: dataclasses.Field, # future: dataclasses.Field[Optimizer]
inputs: Union[float, Dict[str, Any], Optimizer, None],
) -> Optimizer:
"""Field-specific parser to instantiate the server-side Optimizer."""
"""Field-specific parser to instantiate the server-side Optimizer.
This method supports specifying `server_opt`:
- as None (or missing kwarg), resulting in a basic `Optimizer(1.0)`
- as a float, parsed as the learning rate to a basic SGD optimzier
- as a dict, parsed a serialized Optimizer configuration
- as an `Optimizer` instance (requiring no parsing)
"""
return cls._parse_optimizer(field, inputs)
@classmethod
......@@ -155,6 +170,7 @@ class FLOptimConfig(TomlConfig):
- (opt.) config: dict specifying kwargs for the constructor
- any other field will be added to the `config` kwargs dict
- as None (or missing kwarg), using default AveragingAggregator()
- as an `Aggregator` instance (requiring no parsing)
"""
# Case when using the default value: delegate to the default parser.
if inputs is None:
......@@ -193,3 +209,35 @@ class FLOptimConfig(TomlConfig):
return obj
# Otherwise, raise a TypeError as inputs are unsupported.
raise TypeError("Unsupported inputs type for field 'aggregator'.")
@classmethod
def parse_fairness(
cls,
field: dataclasses.Field, # future: dataclasses.Field[<type>]
inputs: Union[Dict[str, Any], FairnessControllerServer, None],
) -> FairnessControllerServer:
"""Field-specific parser to instantiate a FairnessControllerServer.
This method supports specifying `fairness`:
- as None (or missing kwarg), using no fairness controller
- as a dict, parsed a FairnessControllerServer specifications:
- algorithm: str used to retrieve a registered type
- f_type: str used to define a group fairness function
- (opt.) f_args: dict to parametrize the fairness function
- any other field will be added to the `config` kwargs dict
- as a `FairnessControllerServer` instance (requiring no parsing)
"""
if inputs is None:
return cls.default_parser(field, inputs)
if isinstance(inputs, FairnessControllerServer):
return inputs
if isinstance(inputs, dict):
for key in ("algorithm", "f_type"):
if key not in inputs:
raise TypeError(
"Wrong format for FairnessControllerServer "
f"configuration: missing '{key}' field."
)
return FairnessControllerServer.from_specs(**inputs)
raise TypeError("Unsupported inputs type for field 'fairness.")
......@@ -99,6 +99,16 @@ class FairnessControllerTestSuite:
manager.train_data = build_mock_dataset(idx)
return manager
def test_setup_server_from_specs(
self,
) -> None:
"""Test instantiating a server-side controller 'from_specs'."""
server = self.server_cls.from_specs(
algorithm=self.server_cls.algorithm,
f_type="demographic_parity",
)
assert isinstance(server, self.server_cls)
def test_setup_client_from_setup_query(
self,
) -> None:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment