From c47dfc84b7da8dd65acbf87776502854b7af82ae Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Mon, 22 Jul 2024 16:44:13 +0200
Subject: [PATCH] Add 'FairnessControllerServer.from_specs' generic
 constructor.

This generic constructor enables parsing any subclass instance from
serializable specifications. This is notably useful to ensure that
`FLOptimConfig` remains compatible with TOML parsing even when a
fairness algorithm is used.

Dedicated tests and further documentation are yet to be added.
---
 declearn/fairness/api/_server.py              | 49 ++++++++++++++++-
 declearn/main/config/_strategy.py             | 52 ++++++++++++++++++-
 .../fairness_controllers_testing.py           | 10 ++++
 3 files changed, 108 insertions(+), 3 deletions(-)

diff --git a/declearn/fairness/api/_server.py b/declearn/fairness/api/_server.py
index c193b6b..561261d 100644
--- a/declearn/fairness/api/_server.py
+++ b/declearn/fairness/api/_server.py
@@ -39,7 +39,11 @@ from declearn.secagg.messaging import (
     SecaggFairnessCounts,
     SecaggFairnessReply,
 )
-from declearn.utils import create_types_registry, register_type
+from declearn.utils import (
+    access_registered,
+    create_types_registry,
+    register_type,
+)
 
 __all__ = [
     "FairnessControllerServer",
@@ -430,3 +434,46 @@ class FairnessControllerServer(metaclass=abc.ABCMeta):
             Fairness(-related) metrics computed as part of this routine,
             as a dict mapping scalar or numpy array values with their name.
         """
+
+    @staticmethod
+    def from_specs(
+        algorithm: str,
+        f_type: str,
+        f_args: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> "FairnessControllerServer":
+        """Instantiate a 'FairnessControllerServer' from its specifications.
+
+        Parameters
+        ----------
+        algorithm:
+            Name of the algorithm associated with the target controller class.
+        f_type:
+            Name of the fairness function to evaluate and optimize.
+        f_args:
+            Optional dict of keyword arguments to the fairness function.
+        **kwargs:
+            Any additional algorithm-specific instantiation keyword argument.
+
+        Returns
+        -------
+        controller:
+            `FairnessControllerServer` instance matching input specifications.
+
+        Raises
+        ------
+        KeyError
+            If `algorithm` does not match any registered
+            `FairnessControllerServer` type.
+        """
+        try:
+            cls = access_registered(
+                name=algorithm, group="FairnessControllerServer"
+            )
+        except Exception as exc:
+            raise KeyError(
+                "Failed to retrieve fairness controller with algorithm name "
+                f"'{algorithm}'."
+            ) from exc
+        assert issubclass(cls, FairnessControllerServer)
+        return cls(f_type=f_type, f_args=f_args, **kwargs)
diff --git a/declearn/main/config/_strategy.py b/declearn/main/config/_strategy.py
index 707b93a..85e6814 100644
--- a/declearn/main/config/_strategy.py
+++ b/declearn/main/config/_strategy.py
@@ -110,7 +110,14 @@ class FLOptimConfig(TomlConfig):
         field: dataclasses.Field,  # future: dataclasses.Field[Optimizer]
         inputs: Union[float, Dict[str, Any], Optimizer],
     ) -> Optimizer:
-        """Field-specific parser to instantiate the client-side Optimizer."""
+        """Field-specific parser to instantiate the client-side Optimizer.
+
+        This method supports specifying `client_opt`:
+
+        - as a float, parsed as the learning rate to a basic SGD optimzier
+        - as a dict, parsed a serialized Optimizer configuration
+        - as an `Optimizer` instance (requiring no parsing)
+        """
         return cls._parse_optimizer(field, inputs)
 
     @classmethod
@@ -119,7 +126,15 @@ class FLOptimConfig(TomlConfig):
         field: dataclasses.Field,  # future: dataclasses.Field[Optimizer]
         inputs: Union[float, Dict[str, Any], Optimizer, None],
     ) -> Optimizer:
-        """Field-specific parser to instantiate the server-side Optimizer."""
+        """Field-specific parser to instantiate the server-side Optimizer.
+
+        This method supports specifying `server_opt`:
+
+        - as None (or missing kwarg), resulting in a basic `Optimizer(1.0)`
+        - as a float, parsed as the learning rate to a basic SGD optimzier
+        - as a dict, parsed a serialized Optimizer configuration
+        - as an `Optimizer` instance (requiring no parsing)
+        """
         return cls._parse_optimizer(field, inputs)
 
     @classmethod
@@ -155,6 +170,7 @@ class FLOptimConfig(TomlConfig):
             - (opt.) config: dict specifying kwargs for the constructor
             - any other field will be added to the `config` kwargs dict
         - as None (or missing kwarg), using default AveragingAggregator()
+        - as an `Aggregator` instance (requiring no parsing)
         """
         # Case when using the default value: delegate to the default parser.
         if inputs is None:
@@ -193,3 +209,35 @@ class FLOptimConfig(TomlConfig):
             return obj
         # Otherwise, raise a TypeError as inputs are unsupported.
         raise TypeError("Unsupported inputs type for field 'aggregator'.")
+
+    @classmethod
+    def parse_fairness(
+        cls,
+        field: dataclasses.Field,  # future: dataclasses.Field[<type>]
+        inputs: Union[Dict[str, Any], FairnessControllerServer, None],
+    ) -> FairnessControllerServer:
+        """Field-specific parser to instantiate a FairnessControllerServer.
+
+        This method supports specifying `fairness`:
+
+        - as None (or missing kwarg), using no fairness controller
+        - as a dict, parsed a FairnessControllerServer specifications:
+            - algorithm: str used to retrieve a registered type
+            - f_type: str used to define a group fairness function
+            - (opt.) f_args: dict to parametrize the fairness function
+            - any other field will be added to the `config` kwargs dict
+        - as a `FairnessControllerServer` instance (requiring no parsing)
+        """
+        if inputs is None:
+            return cls.default_parser(field, inputs)
+        if isinstance(inputs, FairnessControllerServer):
+            return inputs
+        if isinstance(inputs, dict):
+            for key in ("algorithm", "f_type"):
+                if key not in inputs:
+                    raise TypeError(
+                        "Wrong format for FairnessControllerServer "
+                        f"configuration: missing '{key}' field."
+                    )
+            return FairnessControllerServer.from_specs(**inputs)
+        raise TypeError("Unsupported inputs type for field 'fairness.")
diff --git a/test/fairness/controllers/fairness_controllers_testing.py b/test/fairness/controllers/fairness_controllers_testing.py
index 41b774b..8641ef3 100644
--- a/test/fairness/controllers/fairness_controllers_testing.py
+++ b/test/fairness/controllers/fairness_controllers_testing.py
@@ -99,6 +99,16 @@ class FairnessControllerTestSuite:
         manager.train_data = build_mock_dataset(idx)
         return manager
 
+    def test_setup_server_from_specs(
+        self,
+    ) -> None:
+        """Test instantiating a server-side controller 'from_specs'."""
+        server = self.server_cls.from_specs(
+            algorithm=self.server_cls.algorithm,
+            f_type="demographic_parity",
+        )
+        assert isinstance(server, self.server_cls)
+
     def test_setup_client_from_setup_query(
         self,
     ) -> None:
-- 
GitLab