From 604620ee07f1b2e6f0624c12bf902afa2eb44a3f Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Thu, 30 May 2024 16:38:30 +0200
Subject: [PATCH] Improve Fairness integration into the main orchestrating
 classes.

* Integrate Fairness configuration on the server side, split
  between 'FLOptimConfig' and 'FLRunConfig'.
* Implement generic fairness measures computation on the server
  side. Implement default computations on the client side. Leave
  additional actions up to subclasses.
  This will potentially be refactored to be coherent as to what
  goes in the fairness controllers and what goes in the existing
  main classes, and as to what is part of the shared API and what
  is left to algorithm-specific subclasses.
* Move fairness round query and reply messages to 'messaging'.
---
 declearn/fairness/api/__init__.py     |   2 -
 declearn/fairness/api/_controllers.py | 109 +++++++++++++++++++++++---
 declearn/fairness/api/_messages.py    |  47 +----------
 declearn/fairness/fairgrad/_client.py |   6 +-
 declearn/main/_client.py              |  18 ++---
 declearn/main/_server.py              |  93 +++++++++++++++++-----
 declearn/main/config/__init__.py      |   3 +
 declearn/main/config/_dataclasses.py  |  28 +++++++
 declearn/main/config/_run_config.py   |  19 ++++-
 declearn/main/config/_strategy.py     |   7 +-
 declearn/messaging/__init__.py        |   4 +
 declearn/messaging/_base.py           |  40 ++++++++++
 declearn/secagg/messaging.py          |  51 +++++++++++-
 13 files changed, 334 insertions(+), 93 deletions(-)

diff --git a/declearn/fairness/api/__init__.py b/declearn/fairness/api/__init__.py
index e60366cd..832b9e11 100644
--- a/declearn/fairness/api/__init__.py
+++ b/declearn/fairness/api/__init__.py
@@ -21,8 +21,6 @@ from ._messages import (
     FairnessAccuracy,
     FairnessCounts,
     FairnessGroups,
-    FairnessRoundQuery,
-    FairnessRoundReply,
     SecaggFairnessAccuracy,
     SecaggFairnessCounts,
 )
diff --git a/declearn/fairness/api/_controllers.py b/declearn/fairness/api/_controllers.py
index 06b25786..52277c0a 100644
--- a/declearn/fairness/api/_controllers.py
+++ b/declearn/fairness/api/_controllers.py
@@ -32,13 +32,15 @@ from declearn.communication.utils import (
 from declearn.fairness.api._messages import (
     FairnessCounts,
     FairnessGroups,
-    FairnessRoundQuery,
     SecaggFairnessCounts,
 )
-from declearn.fairness.core import FairnessDataset
-from declearn.messaging import Error, Message, SerializedMessage
+from declearn.fairness.core import FairnessAccuracyComputer, FairnessDataset
+from declearn.messaging import Error, FairnessQuery, FairnessReply, Message
 from declearn.secagg.api import Decrypter, Encrypter
-from declearn.secagg.messaging import aggregate_secagg_messages
+from declearn.secagg.messaging import (
+    aggregate_secagg_messages,
+    SecaggFairnessReply,
+)
 from declearn.training import TrainingManager
 
 __all__ = [
@@ -170,12 +172,11 @@ class FairnessControllerClient(metaclass=abc.ABCMeta):
             or may not have been altered compared with the input one.
         """
 
-    @abc.abstractmethod
     async def fairness_round(
         self,
         netwk: NetworkClient,
+        query: FairnessQuery,
         manager: TrainingManager,
-        received: SerializedMessage[FairnessRoundQuery],
         secagg: Optional[Encrypter],
     ) -> None:
         """Participate in a round of actions to enforce fairness.
@@ -184,11 +185,93 @@ class FairnessControllerClient(metaclass=abc.ABCMeta):
         ----------
         netwk:
             NetworkClient endpoint instance, connected to a server.
+        query:
+            `FairnessQuery` message to participate in a fairness round.
+        manager:
+            TrainingManager instance holding the local model, optimizer, etc.
+            This method may (and usually does) have side effects on this.
+        secagg:
+            Optional SecAgg encryption controller.
+        """
+        values = self.compute_fairness_measures(query, manager)
+        reply = FairnessReply(values=values)
+        if secagg is None:
+            await netwk.send_message(reply)
+        else:
+            await netwk.send_message(
+                SecaggFairnessReply.from_cleartext_message(reply, secagg)
+            )
+        await self.finalize_fairness_round(netwk, values, manager, secagg)
+
+    def compute_fairness_measures(
+        self,
+        query: FairnessQuery,
+        manager: TrainingManager,
+    ) -> List[float]:
+        """Compute fairness measures based on a received query.
+
+        By default, compute and return group-wise accuracy metrics,
+        weighted by group-wise sample counts. This may be modified
+        by algorithm-specific subclasses depending on algorithms'
+        needs.
+
+        Parameters
+        ----------
+        query:
+            `FairnessQuery` message with computational effort constraints,
+            and optionally model weights to assign before evaluation.
+        manager:
+            TrainingManager instance holding the model to evaluate and the
+            training dataset on which to do so.
+
+        Returns
+        -------
+        values:
+            Computed values, as a deterministic-length ordered list
+            of float values.
+        """
+        assert isinstance(manager.train_data, FairnessDataset)
+        if query.weights is not None:
+            manager.model.set_weights(query.weights, trainable=True)
+        # Compute group-wise accuracy metrics.
+        computer = FairnessAccuracyComputer(manager.train_data)
+        accuracy = computer.compute_groupwise_accuracy(
+            model=manager.model,
+            batch_size=query.batch_size,
+            n_batch=query.n_batch,
+            thresh=query.thresh,
+        )
+        # Scale computed accuracy metrics by sample counts.
+        accuracy = {
+            key: val * computer.counts[key] for key, val in accuracy.items()
+        }
+        # Gather ordered values (filling-in groups without samples).
+        return [accuracy.get(group, 0.0) for group in self.groups]
+
+    @abc.abstractmethod
+    async def finalize_fairness_round(
+        self,
+        netwk: NetworkClient,
+        values: List[float],
+        manager: TrainingManager,
+        secagg: Optional[Encrypter],
+    ) -> None:
+        """Take actions to enforce fairness.
+
+        This method is designed to be called after an initial query
+        has been received and responded to, resulting in computing
+        and sharing fairness(-related) metrics.
+
+        Parameters
+        ----------
+        netwk:
+            NetworkClient endpoint instance, connected to a server.
+        values:
+            List of locally-computed evaluation metrics, already shared
+            with the server for their (secure-)aggregation.
         manager:
             TrainingManager instance holding the local model, optimizer, etc.
             This method may (and usually does) have side effects on this.
-        received:
-            Serialized query message to participated in a fairness round.
         secagg:
             Optional SecAgg encryption controller.
         """
@@ -367,19 +450,27 @@ class FairnessControllerServer(metaclass=abc.ABCMeta):
         """
 
     @abc.abstractmethod
-    async def fairness_round(
+    async def finalize_fairness_round(
         self,
         round_i: int,
+        values: List[float],
         netwk: NetworkServer,
         secagg: Optional[Decrypter],
     ) -> None:
         """Orchestrate a round of actions to enforce fairness.
 
+        This method is designed to be called after an initial query
+        has been sent and responded to by clients, resulting in the
+        federated computation of fairness(-related) metrics.
+
         Parameters
         ----------
         round_i:
             Index of the current round (reflecting that of an upcoming
             training round).
+        values:
+            Aggregated metrics resulting from the fairness evaluation
+            run by clients at this round.
         netwk:
             NetworkServer endpoint instance, to which clients are registered.
         secagg:
diff --git a/declearn/fairness/api/_messages.py b/declearn/fairness/api/_messages.py
index eae28a45..7beb0785 100644
--- a/declearn/fairness/api/_messages.py
+++ b/declearn/fairness/api/_messages.py
@@ -18,7 +18,7 @@
 """API messages for fairness-aware federated learning setup and rounds."""
 
 import dataclasses
-from typing import Any, List, Optional, Tuple
+from typing import Any, List, Tuple
 
 from typing_extensions import Self  # future: import from typing (py >=3.11)
 
@@ -30,8 +30,6 @@ __all__ = [
     "FairnessAccuracy",
     "FairnessCounts",
     "FairnessGroups",
-    "FairnessRoundQuery",
-    "FairnessRoundReply",
     "SecaggFairnessAccuracy",
     "SecaggFairnessCounts",
 ]
@@ -168,46 +166,3 @@ class FairnessGroups(Message):
     ) -> Self:
         kwargs["groups"] = [tuple(group) for group in kwargs["groups"]]
         return super().from_kwargs(**kwargs)
-
-
-@dataclasses.dataclass
-class FairnessRoundQuery(Message):
-    """Base Message for server-emitted fairness-computation queries.
-
-    The base `FairnessRoundQuery` defines information that is used
-    when evaluating a model's accuracy and/or loss over group-wise
-    training samples.
-
-    Subclasses may be defined to add algorithm-specific information.
-
-    Fields
-    ------
-    batch_size:
-        Number of samples per batch when computing metrics.
-    n_batch:
-        Optional maximum number of batches to draw per group.
-        If None, use the entire wrapped dataset.
-    thresh:
-        Optional binarization threshold for binary classification
-        models' output scores. If None, use 0.5 by default, or 0.0
-        for `SklearnSGDModel` instances.
-        Unused for multinomial classifiers (argmax over scores).
-    """
-
-    batch_size: int = 32
-    n_batch: Optional[int] = None
-    thresh: Optional[float] = None
-
-    typekey = "fairness-round-query"
-
-
-@dataclasses.dataclass
-class FairnessRoundReply(Message):
-    """Base Message for client-emitted fairness-round end signal.
-
-    By default this message is empty, merely noticing that things
-    went well. Subclasses may be used to convey algorithm-specific
-    results or information.
-    """
-
-    typekey = "fairness-round-reply"
diff --git a/declearn/fairness/fairgrad/_client.py b/declearn/fairness/fairgrad/_client.py
index 1af21bad..1fa63ac4 100644
--- a/declearn/fairness/fairgrad/_client.py
+++ b/declearn/fairness/fairgrad/_client.py
@@ -17,7 +17,7 @@
 
 """Client-side Fed-FairGrad controller."""
 
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
 
 
 from declearn.communication.api import NetworkClient
@@ -27,7 +27,6 @@ from declearn.fairness.api import (
     FairnessRoundQuery,
     FairnessRoundReply,
     FairnessControllerClient,
-    FairnessSetupQuery,
     SecaggFairnessAccuracy,
 )
 from declearn.fairness.core import FairnessAccuracyComputer, FairnessDataset
@@ -60,8 +59,9 @@ class FairgradControllerClient(FairnessControllerClient):
     async def finalize_fairness_setup(
         self,
         netwk: NetworkClient,
-        query: FairnessSetupQuery,
         manager: TrainingManager,
+        secagg: Optional[Encrypter],
+        params: Dict[str, Any],
     ) -> TrainingManager:
         assert isinstance(manager.train_data, FairnessDataset)
         # Set up a controller to compute group-wise model accuracy.
diff --git a/declearn/main/_client.py b/declearn/main/_client.py
index 75ac07ec..7770c4ff 100644
--- a/declearn/main/_client.py
+++ b/declearn/main/_client.py
@@ -36,7 +36,6 @@ from declearn.dataset import Dataset, load_dataset_from_json
 from declearn.fairness.api import (
     FairnessControllerClient,
     FairnessSetupQuery,
-    FairnessRoundQuery,
 )
 from declearn.main.utils import Checkpointer
 from declearn.messaging import Message, SerializedMessage
@@ -254,8 +253,8 @@ class FederatedClient:
             await self.training_round(message.deserialize())
         elif issubclass(message.message_cls, messaging.EvaluationRequest):
             await self.evaluation_round(message.deserialize())
-        elif issubclass(message.message_cls, FairnessRoundQuery):
-            await self.fairness_round(message)  # note: keep serialized
+        elif issubclass(message.message_cls, messaging.FairnessQuery):
+            await self.fairness_round(message.deserialize())
         elif issubclass(message.message_cls, SecaggSetupQuery):
             await self.setup_secagg(message)  # note: keep serialized
         elif issubclass(message.message_cls, messaging.StopTraining):
@@ -622,7 +621,7 @@ class FederatedClient:
 
     async def fairness_round(
         self,
-        received: SerializedMessage[FairnessRoundQuery],
+        query: messaging.FairnessQuery,
     ) -> None:
         """Handle a server request to run a fairness-related round.
 
@@ -633,8 +632,8 @@ class FederatedClient:
 
         Parameters
         ----------
-        received:
-            Serialized `FairnessRoundQuery` message from the server.
+        query:
+            `FairnessQuery` message from the server.
 
         Raises
         ------
@@ -645,9 +644,8 @@ class FederatedClient:
         # If no fairness controller was set up, raise a RuntimeError.
         if self.fairness is None:
             error = (
-                "Received a query to participate in a fairness round "
-                f"('{received.message_cls.__name__}'), but no fairness "
-                "controller was set up."
+                "Received a query to participate in a fairness round, "
+                "but no fairness controller was set up."
             )
             self.logger.critical(error)
             await self.netwk.send_message(messaging.Error(error))
@@ -655,8 +653,8 @@ class FederatedClient:
         # Otherwise, run the controller's routine.
         await self.fairness.fairness_round(
             netwk=self.netwk,
+            query=query,
             manager=self.trainmanager,
-            received=received,
             secagg=self._encrypter,
         )
 
diff --git a/declearn/main/_server.py b/declearn/main/_server.py
index be099444..7b68744b 100644
--- a/declearn/main/_server.py
+++ b/declearn/main/_server.py
@@ -31,9 +31,9 @@ import numpy as np
 from declearn import messaging
 from declearn.communication import NetworkServerConfig
 from declearn.communication.api import NetworkServer
-from declearn.fairness.api import FairnessControllerServer
 from declearn.main.config import (
     EvaluateConfig,
+    FairnessConfig,
     FLOptimConfig,
     FLRunConfig,
     TrainingConfig,
@@ -48,14 +48,9 @@ from declearn.metrics import MetricInputType, MetricSet
 from declearn.metrics._mean import MeanState
 from declearn.model.api import Model, Vector
 from declearn.optimizer.modules import AuxVar
+from declearn.secagg import messaging as secagg_messaging
 from declearn.secagg import parse_secagg_config_server
 from declearn.secagg.api import Decrypter, SecaggConfigServer
-from declearn.secagg.messaging import (
-    SecaggEvaluationReply,
-    SecaggMessage,
-    SecaggTrainReply,
-    aggregate_secagg_messages,
-)
 from declearn.utils import deserialize_object, get_logger
 
 
@@ -79,7 +74,6 @@ class FederatedServer:
         optim: Union[FLOptimConfig, str, Dict[str, Any]],
         metrics: Union[MetricSet, List[MetricInputType], None] = None,
         secagg: Union[SecaggConfigServer, Dict[str, Any], None] = None,
-        fairness: Union[FairnessControllerServer, None] = None,
         checkpoint: Union[Checkpointer, Dict[str, Any], str, None] = None,
         logger: Union[logging.Logger, str, None] = None,
     ) -> None:
@@ -108,8 +102,6 @@ class FederatedServer:
         secagg: SecaggConfigServer or dict or None, default=None
             Optional SecAgg config and setup controller
             or dict of kwargs to set one up.
-        fairness: FairnessControllerServer of None, default=None
-            Optional Fairness-aware Federated Learning controller.
         checkpoint: Checkpointer or dict or str or None, default=None
             Optional Checkpointer instance or instantiation dict to be
             used so as to save round-wise model, optimizer and metrics.
@@ -133,6 +125,7 @@ class FederatedServer:
         self.aggrg = optim.aggregator
         self.optim = optim.server_opt
         self.c_opt = optim.client_opt
+        self.fairness = optim.fairness  # note: optional
         # Assign the wrapped MetricSet.
         self.metrics = MetricSet.from_specs(metrics)
         # Assign an optional checkpointer.
@@ -143,8 +136,6 @@ class FederatedServer:
         self.secagg = self._parse_secagg(secagg)
         self._decrypter = None  # type: Optional[Decrypter]
         self._secagg_peers = set()  # type: Set[str]
-        # Assign the optional FairnessControllerServer.
-        self.fairness = fairness  # TODO: add proper parser and alternatives
         # Set up private attributes to record the loss values and best weights.
         self._loss = {}  # type: Dict[int, float]
         self._best = None  # type: Optional[Vector]
@@ -277,7 +268,8 @@ class FederatedServer:
             specify the federated learning process, including clients
             registration, training and validation rounds' setup, plus
             optional elements: local differential-privacy parameters,
-            and/or an early-stopping criterion.
+            fairness evaluation parameters, and/or an early-stopping
+            criterion.
         """
         # Instantiate the early-stopping criterion, if any.
         early_stop = None  # type: Optional[EarlyStopping]
@@ -293,7 +285,8 @@ class FederatedServer:
             round_i = 0
             while True:
                 round_i += 1
-                # TODO: await self.fairness_round(round_i, config.fairness)
+                if self.fairness is not None:
+                    await self.fairness_round(round_i, config.fairness)
                 await self.training_round(round_i, config.training)
                 await self.evaluation_round(round_i, config.evaluate)
                 if not self._keep_training(round_i, config.rounds, early_stop):
@@ -529,11 +522,69 @@ class FederatedServer:
 
     def _aggregate_secagg_replies(
         self,
-        replies: Mapping[str, SecaggMessage[MessageT]],
+        replies: Mapping[str, secagg_messaging.SecaggMessage[MessageT]],
     ) -> MessageT:
         """Secure-Aggregate (and decrypt) client-issued encrypted messages."""
         assert self._decrypter is not None
-        return aggregate_secagg_messages(replies, decrypter=self._decrypter)
+        return secagg_messaging.aggregate_secagg_messages(
+            replies, decrypter=self._decrypter
+        )
+
+    async def fairness_round(
+        self,
+        round_i: int,
+        fairness_cfg: FairnessConfig,
+    ) -> None:
+        """Orchestrate a fairness round.
+
+        Parameters
+        ----------
+        round_i:
+            Index of the training round.
+        fairness_cfg:
+            FairnessConfig dataclass instance wrapping data-batching
+            and computational effort constraints hyper-parameters for
+            fairness evaluation.
+        """
+        assert self.fairness is not None
+        # Run SecAgg setup when needed.
+        self.logger.info("Initiating fairness-enforcing round %s", round_i)
+        clients = self.netwk.client_names  # FUTURE: enable sampling(?)
+        if self.secagg is not None and clients.difference(self._secagg_peers):
+            await self.setup_secagg(clients)
+        # Send a query to clients, including model weights when required.
+        query = messaging.FairnessQuery(
+            round_i=round_i,
+            batch_size=fairness_cfg.batch_size,
+            n_batch=fairness_cfg.n_batch,
+            thresh=fairness_cfg.thresh,
+            weights=None,
+        )
+        await self._send_request_with_optional_weights(query, clients)
+        # Await and (secure-)aggregate) results.
+        self.logger.info("Awaiting clients' fairness measures.")
+        if self._decrypter is None:
+            replies = await self._collect_results(
+                clients, messaging.FairnessReply, "fairness round"
+            )
+            if len(set(len(r.values) for r in replies.values())) != 1:
+                error = "Clients sent fairness values of different lengths."
+                self.logger.error(error)
+                await self.netwk.broadcast_message(messaging.Error(error))
+                raise RuntimeError(error)
+            values = [sum(c_values) for c_values in zip(*replies.values())]
+        else:
+            secagg_replies = await self._collect_results(
+                clients, secagg_messaging.SecaggFairnessReply, "fairness round"
+            )
+            values = self._aggregate_secagg_replies(secagg_replies).values
+        # Have the fairness controller process results.
+        await self.fairness.finalize_fairness_round(
+            round_i=round_i,
+            values=values,
+            netwk=self.netwk,
+            secagg=self._decrypter,
+        )
 
     async def training_round(
         self,
@@ -564,7 +615,7 @@ class FederatedServer:
             )
         else:
             secagg_results = await self._collect_results(
-                clients, SecaggTrainReply, "training"
+                clients, secagg_messaging.SecaggTrainReply, "training"
             )
             results = {
                 "aggregated": self._aggregate_secagg_replies(secagg_results)
@@ -609,7 +660,11 @@ class FederatedServer:
 
     async def _send_request_with_optional_weights(
         self,
-        msg_light: Union[messaging.TrainRequest, messaging.EvaluationRequest],
+        msg_light: Union[
+            messaging.TrainRequest,
+            messaging.EvaluationRequest,
+            messaging.FairnessQuery,
+        ],
         clients: Set[str],
     ) -> None:
         """Send a request to clients, sparingly adding model weights to it.
@@ -693,7 +748,7 @@ class FederatedServer:
             )
         else:
             secagg_results = await self._collect_results(
-                clients, SecaggEvaluationReply, "evaluation"
+                clients, secagg_messaging.SecaggEvaluationReply, "evaluation"
             )
             results = {
                 "aggregated": self._aggregate_secagg_replies(secagg_results)
diff --git a/declearn/main/config/__init__.py b/declearn/main/config/__init__.py
index c26bb7ee..42e7c0e5 100644
--- a/declearn/main/config/__init__.py
+++ b/declearn/main/config/__init__.py
@@ -33,6 +33,8 @@ The following dataclasses are articulated by `FLRunConfig`:
 
 * [EvaluateConfig][declearn.main.config.EvaluateConfig]:
     Hyper-parameters for an evaluation round.
+* [FairnessConfig][declearn.main.config.FairnessConfig]:
+    Dataclass wrapping parameters for fairness evaluation rounds.
 * [RegisterConfig][declearn.main.config.RegisterConfig]:
     Hyper-parameters for clients registration.
 * [TrainingConfig][declearn.main.config.TrainingConfig]:
@@ -41,6 +43,7 @@ The following dataclasses are articulated by `FLRunConfig`:
 
 from ._dataclasses import (
     EvaluateConfig,
+    FairnessConfig,
     PrivacyConfig,
     RegisterConfig,
     TrainingConfig,
diff --git a/declearn/main/config/_dataclasses.py b/declearn/main/config/_dataclasses.py
index 0c4f5614..867343e6 100644
--- a/declearn/main/config/_dataclasses.py
+++ b/declearn/main/config/_dataclasses.py
@@ -22,6 +22,7 @@ from typing import Any, Dict, Optional, Tuple
 
 __all__ = [
     "EvaluateConfig",
+    "FairnessConfig",
     "PrivacyConfig",
     "RegisterConfig",
     "TrainingConfig",
@@ -223,3 +224,30 @@ class PrivacyConfig:
         accountants = ("rdp", "gdp", "prv")
         if self.accountant not in accountants:
             raise TypeError(f"'accountant' should be one of {accountants}")
+
+
+@dataclasses.dataclass
+class FairnessConfig:
+    """Dataclass wrapping parameters for fairness evaluation rounds.
+
+    The parameters wrapped by this class are those of
+    `declearn.fairness.core.FairnessAccuracyComputer`
+    metrics-computation methods.
+
+    Attributes
+    ----------
+    batch_size: int
+        Number of samples per processed data batch.
+    n_batch: int or None, default=None
+        Optional maximum number of batches to draw.
+        If None, use the entire training dataset.
+    thresh: float or None, default=None
+        Optional binarization threshold for binary classification
+        models' output scores. If None, use 0.5 by default, or 0.0
+        for `SklearnSGDModel` instances.
+        Unused for multinomial classifiers (argmax over scores).
+    """
+
+    batch_size: int = 32
+    n_batch: Optional[int] = None
+    thresh: Optional[float] = None
diff --git a/declearn/main/config/_run_config.py b/declearn/main/config/_run_config.py
index ec68dbbd..25b93597 100644
--- a/declearn/main/config/_run_config.py
+++ b/declearn/main/config/_run_config.py
@@ -25,6 +25,7 @@ from typing_extensions import Self  # future: import from typing (py >=3.11)
 from declearn.main.utils import EarlyStopConfig
 from declearn.main.config._dataclasses import (
     EvaluateConfig,
+    FairnessConfig,
     PrivacyConfig,
     RegisterConfig,
     TrainingConfig,
@@ -66,6 +67,10 @@ class FLRunConfig(TomlConfig):
         and data-batching instructions.
     - evaluate: EvaluateConfig
         Parameters for validation rounds, similar to training ones.
+    - fairness: FairnessConfig or None
+        Parameters for fairness evaluation rounds.
+        Only used when an algorithm to enforce fairness is set up,
+        as part of the process's federated optimization configuration.
     - privacy: PrivacyConfig or None
         Optional parameters to set up local differential privacy,
         by having clients use the DP-SGD algorithm for training.
@@ -90,12 +95,15 @@ class FLRunConfig(TomlConfig):
       batch size will be used for evaluation as well.
     - If `privacy` is provided and the 'poisson' parameter is unspecified
       for `training`, it will be set to True by default rather than False.
+    - If `fairness` is not provided or lacks a 'batch_size' parameter,
+      that of evaluation (or, by extension, training) will be used.
     """
 
     rounds: int
     register: RegisterConfig
     training: TrainingConfig
     evaluate: EvaluateConfig
+    fairness: FairnessConfig
     privacy: Optional[PrivacyConfig] = None
     early_stop: Optional[EarlyStopConfig] = None  # type: ignore  # is a type
 
@@ -128,7 +136,7 @@ class FLRunConfig(TomlConfig):
         # If evaluation batch size is not set, use the same as training.
         # Note: if inputs have invalid formats, let the parent method fail.
         evaluate = kwargs.setdefault("evaluate", {})
-        if isinstance(evaluate, dict):
+        if isinstance(evaluate, dict) and ("batch_size" not in evaluate):
             training = kwargs.get("training")
             if isinstance(training, dict):
                 evaluate.setdefault("batch_size", training.get("batch_size"))
@@ -141,5 +149,14 @@ class FLRunConfig(TomlConfig):
             training = kwargs.get("training")
             if isinstance(training, dict):
                 training.setdefault("poisson", True)
+        # If fairness batch size is not set, use the same as evaluation.
+        # Note: if inputs have invalid formats, let the parent method fail.
+        fairness = kwargs.setdefault("fairness", {})
+        if isinstance(fairness, dict) and ("batch_size" not in fairness):
+            evaluate = kwargs.get("evaluate")
+            if isinstance(evaluate, dict):
+                fairness.setdefault("batch_size", evaluate.get("batch_size"))
+            elif isinstance(evaluate, EvaluateConfig):
+                fairness.setdefault("batch_size", evaluate.batch_size)
         # Delegate the rest of the work to the parent method.
         return super().from_params(**kwargs)
diff --git a/declearn/main/config/_strategy.py b/declearn/main/config/_strategy.py
index 333072a7..707b93a2 100644
--- a/declearn/main/config/_strategy.py
+++ b/declearn/main/config/_strategy.py
@@ -19,10 +19,11 @@
 
 import dataclasses
 import functools
-from typing import Any, Dict, Union
+from typing import Any, Dict, Optional, Union
 
 
 from declearn.aggregator import Aggregator, AveragingAggregator
+from declearn.fairness.api import FairnessControllerServer
 from declearn.optimizer import Optimizer
 from declearn.utils import TomlConfig, access_registered, deserialize_object
 
@@ -59,6 +60,9 @@ class FLOptimConfig(TomlConfig):
     - aggregator: Aggregator, default=AverageAggregator()
         Client weights aggregator to be used by the server so as
         to conduct the round-wise aggregation of client udpates.
+    - fairness: Fairness or None, default=None
+        Optional `FairnessControllerServer` instance specifying
+        an algorithm to enforce fairness of the trained model.
 
     Notes
     -----
@@ -98,6 +102,7 @@ class FLOptimConfig(TomlConfig):
     aggregator: Aggregator = dataclasses.field(
         default_factory=AveragingAggregator
     )
+    fairness: Optional[FairnessControllerServer] = None
 
     @classmethod
     def parse_client_opt(
diff --git a/declearn/messaging/__init__.py b/declearn/messaging/__init__.py
index 17f5e5e4..515a1c24 100644
--- a/declearn/messaging/__init__.py
+++ b/declearn/messaging/__init__.py
@@ -33,6 +33,8 @@ Base messages
 * [Error][declearn.messaging.Error]
 * [EvaluationReply][declearn.messaging.EvaluationReply]
 * [EvaluationRequest][declearn.messaging.EvaluationRequest]
+* [FairnessQuery][declearn.messaging.FairnessQuery]
+* [FairnessReply][declearn.messaging.FairnessReply]
 * [GenericMessage][declearn.messaging.GenericMessage]
 * [InitRequest][declearn.messaging.InitRequest]
 * [InitReply][declearn.messaging.InitReply]
@@ -55,6 +57,8 @@ from ._base import (
     Error,
     EvaluationReply,
     EvaluationRequest,
+    FairnessQuery,
+    FairnessReply,
     GenericMessage,
     InitRequest,
     InitReply,
diff --git a/declearn/messaging/_base.py b/declearn/messaging/_base.py
index 6cec34d2..fe029e23 100644
--- a/declearn/messaging/_base.py
+++ b/declearn/messaging/_base.py
@@ -36,6 +36,8 @@ __all__ = [
     "Error",
     "EvaluationReply",
     "EvaluationRequest",
+    "FairnessQuery",
+    "FairnessReply",
     "GenericMessage",
     "InitRequest",
     "InitReply",
@@ -100,6 +102,44 @@ class EvaluationReply(Message):
         return kwargs
 
 
+@dataclasses.dataclass
+class FairnessQuery(Message):
+    """Base Message for server-emitted fairness-computation queries.
+
+    This message conveys hyper-parameters used when evaluating a model's
+    accuracy and/or loss over group-wise samples (from which fairness is
+    derived). Model weights may be attached.
+
+    Algorithm-specific information should be conveyed using ad-hoc
+    messages exchanged as part of fairness-enforcement routines.
+    """
+
+    typekey = "fairness-request"
+
+    round_i: int
+    batch_size: int = 32
+    n_batch: Optional[int] = None
+    thresh: Optional[float] = None
+    weights: Optional[Vector] = None
+
+
+@dataclasses.dataclass
+class FairnessReply(Message):
+    """Base Message for client-emitted fairness-computation results.
+
+    This message conveys results from the evaluation of a model's accuracy
+    and/or loss over group-wise samples (from which fairness is derived).
+
+    This information is generically stored as a list of `values`, the
+    mearning and structure of which is left up to algorithm-specific
+    controllers.
+    """
+
+    typekey = "fairness-reply"
+
+    values: List[float] = dataclasses.field(default_factory=list)
+
+
 @dataclasses.dataclass
 class GenericMessage(Message):
     """Generic message format, with action/params pair."""
diff --git a/declearn/secagg/messaging.py b/declearn/secagg/messaging.py
index 43a67167..9a906705 100644
--- a/declearn/secagg/messaging.py
+++ b/declearn/secagg/messaging.py
@@ -19,18 +19,24 @@
 
 import abc
 import dataclasses
-from typing import Dict, Generic, Mapping, TypeVar
+from typing import Dict, Generic, List, Mapping, TypeVar
 
 from typing_extensions import Self  # future: import from typing (py >=3.11)
 
 from declearn.aggregator import ModelUpdates
-from declearn.messaging import EvaluationReply, Message, TrainReply
+from declearn.messaging import (
+    EvaluationReply,
+    FairnessReply,
+    Message,
+    TrainReply,
+)
 from declearn.metrics import MetricState
 from declearn.optimizer.modules import AuxVar
 from declearn.secagg.api import Decrypter, Encrypter, SecureAggregate
 
 __all__ = [
     "SecaggEvaluationReply",
+    "SecaggFairnessReply",
     "SecaggMessage",
     "SecaggTrainReply",
     "aggregate_secagg_messages",
@@ -256,3 +262,44 @@ class SecaggEvaluationReply(SecaggMessage[EvaluationReply]):
         return self.__class__(
             loss=loss, n_steps=n_steps, t_spent=t_spent, metrics=metrics
         )
+
+
+@dataclasses.dataclass
+class SecaggFairnessReply(SecaggMessage[FairnessReply]):
+    """SecAgg-wrapped 'FairnessReply' message."""
+
+    typekey = "secagg_fairness_reply"
+
+    values: List[int]
+
+    @classmethod
+    def from_cleartext_message(
+        cls,
+        cleartext: FairnessReply,
+        encrypter: Encrypter,
+    ) -> Self:
+        values = [encrypter.encrypt_float(value) for value in cleartext.values]
+        return cls(values=values)
+
+    def decrypt_wrapped_message(
+        self,
+        decrypter: Decrypter,
+    ) -> FairnessReply:
+        values = [decrypter.decrypt_float(value) for value in self.values]
+        return FairnessReply(values=values)
+
+    def aggregate(
+        self,
+        other: Self,
+        decrypter: Decrypter,
+    ) -> Self:
+        if len(self.values) != len(other.values):
+            raise ValueError(
+                "Cannot aggregate SecAgg-protected fairness values with "
+                "distinct shapes."
+            )
+        values = [
+            decrypter.sum_encrypted([v_a, v_b])
+            for v_a, v_b in zip(self.values, other.values)
+        ]
+        return self.__class__(values=values)
-- 
GitLab