From c47dfc84b7da8dd65acbf87776502854b7af82ae Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Mon, 22 Jul 2024 16:44:13 +0200 Subject: [PATCH] Add 'FairnessControllerServer.from_specs' generic constructor. This generic constructor enables parsing any subclass instance from serializable specifications. This is notably useful to ensure that `FLOptimConfig` remains compatible with TOML parsing even when a fairness algorithm is used. Dedicated tests and further documentation are yet to be added. --- declearn/fairness/api/_server.py | 49 ++++++++++++++++- declearn/main/config/_strategy.py | 52 ++++++++++++++++++- .../fairness_controllers_testing.py | 10 ++++ 3 files changed, 108 insertions(+), 3 deletions(-) diff --git a/declearn/fairness/api/_server.py b/declearn/fairness/api/_server.py index c193b6b..561261d 100644 --- a/declearn/fairness/api/_server.py +++ b/declearn/fairness/api/_server.py @@ -39,7 +39,11 @@ from declearn.secagg.messaging import ( SecaggFairnessCounts, SecaggFairnessReply, ) -from declearn.utils import create_types_registry, register_type +from declearn.utils import ( + access_registered, + create_types_registry, + register_type, +) __all__ = [ "FairnessControllerServer", @@ -430,3 +434,46 @@ class FairnessControllerServer(metaclass=abc.ABCMeta): Fairness(-related) metrics computed as part of this routine, as a dict mapping scalar or numpy array values with their name. """ + + @staticmethod + def from_specs( + algorithm: str, + f_type: str, + f_args: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> "FairnessControllerServer": + """Instantiate a 'FairnessControllerServer' from its specifications. + + Parameters + ---------- + algorithm: + Name of the algorithm associated with the target controller class. + f_type: + Name of the fairness function to evaluate and optimize. + f_args: + Optional dict of keyword arguments to the fairness function. + **kwargs: + Any additional algorithm-specific instantiation keyword argument. + + Returns + ------- + controller: + `FairnessControllerServer` instance matching input specifications. + + Raises + ------ + KeyError + If `algorithm` does not match any registered + `FairnessControllerServer` type. + """ + try: + cls = access_registered( + name=algorithm, group="FairnessControllerServer" + ) + except Exception as exc: + raise KeyError( + "Failed to retrieve fairness controller with algorithm name " + f"'{algorithm}'." + ) from exc + assert issubclass(cls, FairnessControllerServer) + return cls(f_type=f_type, f_args=f_args, **kwargs) diff --git a/declearn/main/config/_strategy.py b/declearn/main/config/_strategy.py index 707b93a..85e6814 100644 --- a/declearn/main/config/_strategy.py +++ b/declearn/main/config/_strategy.py @@ -110,7 +110,14 @@ class FLOptimConfig(TomlConfig): field: dataclasses.Field, # future: dataclasses.Field[Optimizer] inputs: Union[float, Dict[str, Any], Optimizer], ) -> Optimizer: - """Field-specific parser to instantiate the client-side Optimizer.""" + """Field-specific parser to instantiate the client-side Optimizer. + + This method supports specifying `client_opt`: + + - as a float, parsed as the learning rate to a basic SGD optimzier + - as a dict, parsed a serialized Optimizer configuration + - as an `Optimizer` instance (requiring no parsing) + """ return cls._parse_optimizer(field, inputs) @classmethod @@ -119,7 +126,15 @@ class FLOptimConfig(TomlConfig): field: dataclasses.Field, # future: dataclasses.Field[Optimizer] inputs: Union[float, Dict[str, Any], Optimizer, None], ) -> Optimizer: - """Field-specific parser to instantiate the server-side Optimizer.""" + """Field-specific parser to instantiate the server-side Optimizer. + + This method supports specifying `server_opt`: + + - as None (or missing kwarg), resulting in a basic `Optimizer(1.0)` + - as a float, parsed as the learning rate to a basic SGD optimzier + - as a dict, parsed a serialized Optimizer configuration + - as an `Optimizer` instance (requiring no parsing) + """ return cls._parse_optimizer(field, inputs) @classmethod @@ -155,6 +170,7 @@ class FLOptimConfig(TomlConfig): - (opt.) config: dict specifying kwargs for the constructor - any other field will be added to the `config` kwargs dict - as None (or missing kwarg), using default AveragingAggregator() + - as an `Aggregator` instance (requiring no parsing) """ # Case when using the default value: delegate to the default parser. if inputs is None: @@ -193,3 +209,35 @@ class FLOptimConfig(TomlConfig): return obj # Otherwise, raise a TypeError as inputs are unsupported. raise TypeError("Unsupported inputs type for field 'aggregator'.") + + @classmethod + def parse_fairness( + cls, + field: dataclasses.Field, # future: dataclasses.Field[<type>] + inputs: Union[Dict[str, Any], FairnessControllerServer, None], + ) -> FairnessControllerServer: + """Field-specific parser to instantiate a FairnessControllerServer. + + This method supports specifying `fairness`: + + - as None (or missing kwarg), using no fairness controller + - as a dict, parsed a FairnessControllerServer specifications: + - algorithm: str used to retrieve a registered type + - f_type: str used to define a group fairness function + - (opt.) f_args: dict to parametrize the fairness function + - any other field will be added to the `config` kwargs dict + - as a `FairnessControllerServer` instance (requiring no parsing) + """ + if inputs is None: + return cls.default_parser(field, inputs) + if isinstance(inputs, FairnessControllerServer): + return inputs + if isinstance(inputs, dict): + for key in ("algorithm", "f_type"): + if key not in inputs: + raise TypeError( + "Wrong format for FairnessControllerServer " + f"configuration: missing '{key}' field." + ) + return FairnessControllerServer.from_specs(**inputs) + raise TypeError("Unsupported inputs type for field 'fairness.") diff --git a/test/fairness/controllers/fairness_controllers_testing.py b/test/fairness/controllers/fairness_controllers_testing.py index 41b774b..8641ef3 100644 --- a/test/fairness/controllers/fairness_controllers_testing.py +++ b/test/fairness/controllers/fairness_controllers_testing.py @@ -99,6 +99,16 @@ class FairnessControllerTestSuite: manager.train_data = build_mock_dataset(idx) return manager + def test_setup_server_from_specs( + self, + ) -> None: + """Test instantiating a server-side controller 'from_specs'.""" + server = self.server_cls.from_specs( + algorithm=self.server_cls.algorithm, + f_type="demographic_parity", + ) + assert isinstance(server, self.server_cls) + def test_setup_client_from_setup_query( self, ) -> None: -- GitLab