diff --git a/declearn/fairness/fairgrad/__init__.py b/declearn/fairness/fairgrad/__init__.py
index 8eb00d8d2bc233bfa326eb5507ec2e14c621944f..532d3b6a1849866a1b43b79b8e17a9075a8e4a92 100644
--- a/declearn/fairness/fairgrad/__init__.py
+++ b/declearn/fairness/fairgrad/__init__.py
@@ -50,17 +50,23 @@ Controllers
 [declearn.fairness.fairgrad.FairgradControllerServer]:
     Server-side controller to implement Fed-FairGrad.
 
+Backend
+-------
+* [FairgradWeightsController]
+[declearn.fairness.fairgrad.FairgradWeightsController]:
+    Controller to implement Faigrad optimization constraints.
+
 Messages
 --------
-* [FairgradSetupQuery][declearn.fairness.fairgrad.FairgradSetupQuery]:
-    Message for server-emitted Fed-FairGrad setup queries.
+* [FairgradOkay][declearn.fairness.fairgrad.FairgradOkay]:
+    Message for client-emitted signal that Fed-FairGrad update went fine.
 * [FairgradWeights][declearn.fairness.fairgrad.FairgradWeights]:
     Message for server-emitted (Fed-)FairGrad loss weights sharing.
 """
 
 from ._messages import (
-    FairgradSetupQuery,
+    FairgradOkay,
     FairgradWeights,
 )
 from ._client import FairgradControllerClient
-from ._server import FairgradControllerServer
+from ._server import FairgradControllerServer, FairgradWeightsController
diff --git a/declearn/fairness/fairgrad/_client.py b/declearn/fairness/fairgrad/_client.py
index 1fa63ac4417a90152f0c918ced1b12437ccade9d..3e21b272f2cdb9b540a71e6314df3fa66833af0c 100644
--- a/declearn/fairness/fairgrad/_client.py
+++ b/declearn/fairness/fairgrad/_client.py
@@ -17,24 +17,19 @@
 
 """Client-side Fed-FairGrad controller."""
 
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
+import numpy as np
 
 from declearn.communication.api import NetworkClient
 from declearn.communication.utils import verify_server_message_validity
-from declearn.fairness.api import (
-    FairnessAccuracy,
-    FairnessRoundQuery,
-    FairnessRoundReply,
-    FairnessControllerClient,
-    SecaggFairnessAccuracy,
+from declearn.fairness.api import FairnessControllerClient
+from declearn.fairness.core import (
+    FairnessDataset,
+    instantiate_fairness_function,
 )
-from declearn.fairness.core import FairnessAccuracyComputer, FairnessDataset
-from declearn.fairness.fairgrad._messages import (
-    FairgradSetupQuery,
-    FairgradWeights,
-)
-from declearn.messaging import Error, SerializedMessage
+from declearn.fairness.fairgrad._messages import FairgradOkay, FairgradWeights
+from declearn.messaging import Error
 from declearn.secagg.api import Encrypter
 from declearn.training import TrainingManager
 
@@ -46,98 +41,42 @@ __all__ = [
 class FairgradControllerClient(FairnessControllerClient):
     """Client-side controller to implement Fed-FairGrad."""
 
-    setup_query_cls = FairgradSetupQuery
+    algorithm = "fedfairgrad"
 
     def __init__(
         self,
-    ) -> None:
-        super().__init__()
-        self._accuracy_computer = (
-            None
-        )  # type: Optional[FairnessAccuracyComputer]
-
-    async def finalize_fairness_setup(
-        self,
-        netwk: NetworkClient,
-        manager: TrainingManager,
-        secagg: Optional[Encrypter],
-        params: Dict[str, Any],
-    ) -> TrainingManager:
-        assert isinstance(manager.train_data, FairnessDataset)
-        # Set up a controller to compute group-wise model accuracy.
-        self._accuracy_computer = FairnessAccuracyComputer(manager.train_data)
-        # Await initial loss weights from the server.
-        await self._update_fairgrad_weights(netwk, manager)
-        # Return the input TrainingManager.
-        return manager
-
-    async def fairness_round(
-        self,
-        netwk: NetworkClient,
         manager: TrainingManager,
-        received: SerializedMessage[FairnessRoundQuery],
-        secagg: Optional[Encrypter],
+        f_type: str,
+        f_args: Dict[str, Any],
     ) -> None:
-        query = await verify_server_message_validity(
-            netwk, received, expected=FairnessRoundQuery
-        )
-        await self._compute_and_send_groupwise_accuracy(
-            netwk, manager, query, secagg
+        """Instantiate the client-side fairness controller.
+
+        Parameters
+        ----------
+        manager:
+            `TrainingManager` instance wrapping the model being trained
+            and its training dataset (that must be a `FairnessDataset`).
+        f_type:
+            Name of the type of group-fairness function being optimized.
+        f_args:
+            Keyword arguments to the group-fairness function.
+        """
+        super().__init__(manager)
+        self.fairness_function = instantiate_fairness_function(
+            f_type=f_type, counts=self.computer.counts, **f_args
         )
-        await self._update_fairgrad_weights(netwk, manager)
 
-    async def _compute_and_send_groupwise_accuracy(
+    async def finalize_fairness_setup(
         self,
         netwk: NetworkClient,
-        manager: TrainingManager,
-        query: FairnessRoundQuery,
         secagg: Optional[Encrypter],
     ) -> None:
-        # Compute the count-weighted group-wise accuracy, handling exceptions.
-        try:
-            accuracy = self._compute_groupwise_accuracy(manager, query)
-        except Exception as exc:  # pylint: disable=broad-except
-            manager.logger.error(
-                "Exception raised when computing group-wise accuracy: %s", exc
-            )
-            await netwk.send_message(Error(repr(exc)))
-            raise RuntimeError("Group accuracy computation failed.") from exc
-        # Send the computed metrics to the server, optionally encrypted.
-        manager.logger.info("Sending group-wise accuracy to the server.")
-        reply = FairnessAccuracy(accuracy)
-        if secagg is None:
-            await netwk.send_message(reply)
-        else:
-            await netwk.send_message(
-                SecaggFairnessAccuracy.from_cleartext_message(reply, secagg)
-            )
-
-    def _compute_groupwise_accuracy(
-        self,
-        manager: TrainingManager,
-        query: FairnessRoundQuery,
-    ) -> List[float]:
-        """Compute (counts-weighted) accuracy over sensitive groups."""
-        assert self._accuracy_computer is not None
-        # Compute group-wise accuracy scores.
-        accuracy = self._accuracy_computer.compute_groupwise_accuracy(
-            model=manager.model,
-            batch_size=query.batch_size,
-            n_batch=query.n_batch,
-            thresh=query.thresh,
-        )
-        # Multiply these scores by sample counts.
-        accuracy = {
-            key: val * self._accuracy_computer.counts[key]
-            for key, val in accuracy.items()
-        }
-        # Return shareable group-wise values, ordered and filled out.
-        return [accuracy.get(group, 0.0) for group in self.groups]
+        # Await initial loss weights from the server.
+        await self._update_fairgrad_weights(netwk)
 
     async def _update_fairgrad_weights(
         self,
         netwk: NetworkClient,
-        manager: TrainingManager,
     ) -> None:
         """Run a FairGrad-specific routine to update sensitive group weights.
 
@@ -158,17 +97,63 @@ class FairgradControllerClient(FairnessControllerClient):
         weights = dict(zip(self.groups, message.weights))
         # Set the received weights, handling and propagating exceptions if any.
         try:
-            assert isinstance(manager.train_data, FairnessDataset)
-            manager.train_data.set_sensitive_group_weights(
-                weights,
-                adjust_by_counts=True,
+            assert isinstance(self.manager.train_data, FairnessDataset)
+            self.manager.train_data.set_sensitive_group_weights(
+                weights, adjust_by_counts=True
             )
-        except (AssertionError, KeyError, TypeError) as exc:
-            manager.logger.error(
+        except Exception as exc:
+            self.manager.logger.error(
                 "Exception encountered when setting FairGrad weights: %s", exc
             )
             await netwk.send_message(Error(repr(exc)))
             raise RuntimeError("FairGrad weights update failed.") from exc
         # If things went well, ping the server back to indicate so.
-        manager.logger.info("Updated FairGrad weights.")
-        await netwk.send_message(FairnessRoundReply())
+        self.manager.logger.info("Updated FairGrad weights.")
+        await netwk.send_message(FairgradOkay())
+
+    def compute_fairness_measures(
+        self,
+        batch_size: int,
+        n_batch: Optional[int] = None,
+        thresh: Optional[float] = None,
+    ) -> List[float]:
+        # Compute group-wise accuracy scores.
+        accuracy = self.computer.compute_groupwise_accuracy(
+            model=self.manager.model,
+            batch_size=batch_size,
+            n_batch=n_batch,
+            thresh=thresh,
+        )
+        # Multiply these scores by sample counts.
+        accuracy = {
+            key: val * self.computer.counts[key]
+            for key, val in accuracy.items()
+        }
+        # Return shareable group-wise values, ordered and filled out.
+        return [accuracy.get(group, 0.0) for group in self.groups]
+
+    async def finalize_fairness_round(
+        self,
+        netwk: NetworkClient,
+        values: List[float],
+        secagg: Optional[Encrypter],
+    ) -> Dict[str, Union[float, np.ndarray]]:
+        # Await updated loss weights from the server.
+        await self._update_fairgrad_weights(netwk)
+        # Recover raw accuracy scores for groups with local samples.
+        accuracy = {
+            key: val / self.computer.counts[key]
+            for key, val in zip(self.groups, values)
+            if key in self.computer.counts
+        }
+        # Compute local fairness measures.
+        fairness = self.fairness_function.compute_from_group_accuracy(accuracy)
+        f_type = self.fairness_function.f_type
+        # Package and return accuracy and fairness metrics.
+        metrics = {
+            f"accuracy_{key}": val for key, val in accuracy.items()
+        }  # type: Dict[str, Union[float, np.ndarray]]
+        metrics.update(
+            {f"{f_type}_{key}": val for key, val in fairness.items()}
+        )
+        return metrics
diff --git a/declearn/fairness/fairgrad/_messages.py b/declearn/fairness/fairgrad/_messages.py
index e374b5a180079a6ef97fffdea9fd8fdd0529168d..fd650cbce3b01f8aa7c95943b3b19de83f26c3e1 100644
--- a/declearn/fairness/fairgrad/_messages.py
+++ b/declearn/fairness/fairgrad/_messages.py
@@ -21,25 +21,20 @@ import dataclasses
 from typing import List
 
 
-from declearn.fairness.api import FairnessSetupQuery
 from declearn.messaging import Message
 
 
 __all__ = [
-    "FairgradSetupQuery",
+    "FairgradOkay",
     "FairgradWeights",
 ]
 
 
 @dataclasses.dataclass
-class FairgradSetupQuery(FairnessSetupQuery):
-    """Message for server-emitted Fed-FairGrad setup queries.
+class FairgradOkay(Message):
+    """Message for client-emitted signal that Fed-FairGrad update went fine."""
 
-    This message is empty and merely signifies that Fed-FairGrad
-    should be set up by the client.
-    """
-
-    typekey = "fairgrad-setup"
+    typekey = "fairgrad-okay"
 
 
 @dataclasses.dataclass
diff --git a/declearn/fairness/fairgrad/_server.py b/declearn/fairness/fairgrad/_server.py
index f13f1c3cbac9139fcaae48ef251b7515450ac10f..b04fd3b3d06d72555a290619b78fd71045c0aae8 100644
--- a/declearn/fairness/fairgrad/_server.py
+++ b/declearn/fairness/fairgrad/_server.py
@@ -18,32 +18,28 @@
 """Server-side Fed-FairGrad controller."""
 
 import warnings
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 
 from declearn.aggregator import Aggregator, SumAggregator
 from declearn.communication.api import NetworkServer
 from declearn.communication.utils import verify_client_messages_validity
-from declearn.fairness.api import (
-    FairnessAccuracy,
-    FairnessRoundQuery,
-    FairnessRoundReply,
-    FairnessControllerServer,
-    FairnessSetupQuery,
-    SecaggFairnessAccuracy,
-)
+from declearn.fairness.api import FairnessControllerServer
 from declearn.fairness.core import instantiate_fairness_function
-from declearn.fairness.fairgrad._messages import (
-    FairgradSetupQuery,
-    FairgradWeights,
-)
+from declearn.fairness.fairgrad._messages import FairgradOkay, FairgradWeights
+from declearn.messaging import FairnessSetupQuery
 from declearn.secagg.api import Decrypter
-from declearn.secagg.messaging import aggregate_secagg_messages
+
+
+__all__ = [
+    "FairgradControllerServer",
+    "FairgradWeightsController",
+]
 
 
 class FairgradWeightsController:
-    """Fairness controller to implement Faigrad optimization constraints."""
+    """Controller to implement Faigrad optimization constraints."""
 
     # attrs serve readability; pylint: disable=too-many-instance-attributes
 
@@ -157,6 +153,8 @@ class FairgradWeightsController:
 class FairgradControllerServer(FairnessControllerServer):
     """Server-side controller to implement Fed-FairGrad."""
 
+    algorithm = "fedfairgrad"
+
     def __init__(
         self,
         f_type: str,
@@ -182,16 +180,17 @@ class FairgradControllerServer(FairnessControllerServer):
             This may be set to 0.0 to try and enforce absolute fairness.
         """
         super().__init__(f_type=f_type, f_args=f_args)
-        self.weights_controller = (
-            None
-        )  # type: Optional[FairgradWeightsController]
-        self._eta = eta
-        self._eps = eps
+        # Set up a temporary controller that will be replaced at setup time.
+        self.weights_controller = FairgradWeightsController(
+            counts={}, f_type="accuracy_parity", eta=eta, eps=eps
+        )
 
     def prepare_fairness_setup_query(
         self,
     ) -> FairnessSetupQuery:
-        return FairgradSetupQuery()
+        query = super().prepare_fairness_setup_query()
+        query.params.update({"f_type": self.f_type, "f_args": self.f_args})
+        return query
 
     async def finalize_fairness_setup(
         self,
@@ -203,8 +202,8 @@ class FairgradControllerServer(FairnessControllerServer):
         self.weights_controller = FairgradWeightsController(
             counts=dict(zip(self.groups, counts)),
             f_type=self.f_type,
-            eta=self._eta,
-            eps=self._eps,
+            eta=self.weights_controller.eta,
+            eps=self.weights_controller.eps,
             **self.f_args,
         )
         # Send initial loss weights to the clients.
@@ -228,50 +227,31 @@ class FairgradControllerServer(FairnessControllerServer):
         Await for clients to ping back that things went fine on their side.
         """
         netwk.logger.info("Sending FairGrad weights to clients.")
-        assert self.weights_controller is not None
         weights = self.weights_controller.get_current_weights(norm_nk=True)
         await netwk.broadcast_message(FairgradWeights(weights=weights))
         received = await netwk.wait_for_messages()
         await verify_client_messages_validity(
-            netwk, received, expected=FairnessRoundReply
+            netwk, received, expected=FairgradOkay
         )
 
-    async def fairness_round(
+    async def finalize_fairness_round(
         self,
+        round_i: int,
+        values: List[float],
         netwk: NetworkServer,
         secagg: Optional[Decrypter],
-    ) -> None:
-        assert self.weights_controller is not None
-        # Send a query to clients and await group-wise accuracy metrics.
-        await netwk.broadcast_message(
-            FairnessRoundQuery()  # TODO: receive a config and use it
-        )
-        received = await netwk.wait_for_messages()
-        # When SecAgg is not set, expect and aggregate cleartext values.
-        if secagg is None:
-            replies = await verify_client_messages_validity(
-                netwk, received, expected=FairnessAccuracy
-            )
-            accuracy = self._aggregate_cleartext_accuracy(replies)
-        # When SecAgg is set, expect and secure-aggregate encrypted values.
-        else:
-            sec_rep = await verify_client_messages_validity(
-                netwk, received, expected=SecaggFairnessAccuracy
-            )
-            accuracy = aggregate_secagg_messages(sec_rep, secagg).values
-        # Compute global fairness and update FairGrad loss weights.
-        self.weights_controller.update_weights_based_on_accuracy(
-            accuracy=dict(zip(self.groups, accuracy))
-        )
-        # Send back the updated weights to the clients.
+    ) -> Dict[str, Union[float, np.ndarray]]:
+        # Unpack group-wise accuracy metrics and update loss weights.
+        accuracy = dict(zip(self.groups, values))
+        self.weights_controller.update_weights_based_on_accuracy(accuracy)
+        # Send the updated weights to clients.
         await self._send_fairgrad_weights(netwk)
-
-    def _aggregate_cleartext_accuracy(
-        self,
-        messages: Dict[str, FairnessAccuracy],
-    ) -> List[float]:
-        """Sum group-wise accuracy metrics received from clients."""
-        accuracy = np.zeros(len(self.groups), dtype="float64")
-        for message in messages.values():
-            accuracy += np.asarray(message.values, dtype="float64")
-        return accuracy.tolist()
+        # Package and return accuracy and fairness metrics.
+        metrics = {
+            f"accuracy_{key}": val for key, val in accuracy.items()
+        }  # type: Dict[str, Union[float, np.ndarray]]
+        fairness = self.weights_controller.get_current_fairness()
+        metrics.update(
+            {f"{self.f_type}_{key}": val for key, val in fairness.items()}
+        )
+        return metrics