diff --git a/declearn/fairness/fairfed/__init__.py b/declearn/fairness/fairfed/__init__.py
index e5eefbc71c9383608a0657f0027f75d8effcba28..3979d8bb435de39007c692b81aa269082e114917 100644
--- a/declearn/fairness/fairfed/__init__.py
+++ b/declearn/fairness/fairfed/__init__.py
@@ -29,8 +29,9 @@ This algorithm was originally designed for settings where a binary
 classifier is trained over data with a single binary sensitive
 attribute, with the authors showcasing their generic formulas over
 a limited set of group fairness definitions. DecLearn expands it to
-a broader case, enabling the use of arbitraty fairness definitions
+a broader case, enabling the use of arbitrary fairness definitions
 over data that may have non-binary and/or many sensitive attributes.
+A 'strict' mode is made available to stick to the original paper.
 
 Additionally, the algorithm's authors suggest combining it with other
 mechanisms that aim at enforcing model fairness during local training
@@ -51,6 +52,8 @@ Backend
 -------
 * [FairfedAggregator][declearn.fairness.fairfed.FairfedAggregator]:
     Fairfed-specific Aggregator using arbitrary averaging weights.
+* [FairfedFairnessFunction][declearn.fairness.fairfed.FairfedFairnessFunction]:
+    FairFed-specific fairness function wrapper.
 
 Messages
 --------
@@ -69,5 +72,6 @@ from ._messages import (
     SecaggFairfedDelta,
 )
 from ._aggregator import FairfedAggregator
+from ._function import FairfedFairnessFunction
 from ._client import FairfedControllerClient
 from ._server import FairfedControllerServer
diff --git a/declearn/fairness/fairfed/_client.py b/declearn/fairness/fairfed/_client.py
index 1a6ae1720981e66a918143c2211033df0c96e53c..9596ebb052faa267548a4baef988c88e140c838d 100644
--- a/declearn/fairness/fairfed/_client.py
+++ b/declearn/fairness/fairfed/_client.py
@@ -26,6 +26,7 @@ from declearn.communication.utils import verify_server_message_validity
 from declearn.fairness.api import FairnessControllerClient
 from declearn.fairness.core import instantiate_fairness_function
 from declearn.fairness.fairfed._aggregator import FairfedAggregator
+from declearn.fairness.fairfed._function import FairfedFairnessFunction
 from declearn.fairness.fairfed._messages import (
     FairfedDelta,
     FairfedDeltavg,
@@ -52,6 +53,7 @@ class FairfedControllerClient(FairnessControllerClient):
         f_type: str,
         f_args: Dict[str, Any],
         beta: float,
+        strict: bool = True,
     ) -> None:
         """Instantiate the client-side fairness controller.
 
@@ -67,12 +69,30 @@ class FairfedControllerClient(FairnessControllerClient):
         beta:
             Hyper-parameter controlling the magnitude of averaging weights'
             updates across rounds.
+        strict:
+            Whether to stick strictly to the FairFed paper's setting
+            and explicit formulas, or to use a broader adaptation of
+            FairFed to more diverse settings.
         """
+        # arguments serve modularity; pylint: disable=too-many-arguments
         super().__init__(manager)
         self.beta = beta
-        self.fairness_function = instantiate_fairness_function(
+        fairness_function = instantiate_fairness_function(
             f_type=f_type, counts=self.computer.counts, **f_args
         )
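+        # Wrap the function to compute FairFed's synthetic fairness value,
+        # either strictly as per the paper or in a generalized fashion.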
+        self.fairfed_func = FairfedFairnessFunction(
+            fairness_function, strict=strict
+        )
+
+    @property
+    def strict(
+        self,
+    ) -> bool:
+        """Whether this controller strictly sticks to the FairFed paper."""
+        return self.fairfed_func.strict
 
     async def finalize_fairness_setup(
         self,
@@ -98,7 +118,9 @@ class FairfedControllerClient(FairnessControllerClient):
             n_batch=n_batch,
             thresh=thresh,
         )
-        fairness = self.fairness_function.compute_from_group_accuracy(accuracy)
+        fairness = self.fairfed_func.compute_group_fairness_from_accuracy(
+            accuracy, federated=False
+        )
         # Flatten local values for post-processing and checkpointing.
         local_values = list(accuracy.values()) + list(fairness.values())
         # Scale accuracy values by sample counts for their aggregation.
@@ -124,7 +146,7 @@ class FairfedControllerClient(FairnessControllerClient):
             netwk, received, expected=FairfedFairness
         )
         # Compute the absolute difference between local and global fairness.
-        fair_avg = sum(abs(x) for x in fairness.values()) / len(groups)
+        fair_avg = self.fairfed_func.compute_synthetic_fairness_value(fairness)
         my_delta = FairfedDelta(abs(fair_avg - fair_glb.fairness))
         # Share it with the server for its (secure-)aggregation across clients.
         if secagg is None:
@@ -150,7 +172,7 @@ class FairfedControllerClient(FairnessControllerClient):
         metrics = {
             f"accuracy_{key}": val for key, val in accuracy.items()
         }  # type: Dict[str, Union[float, np.ndarray]]
-        f_type = self.fairness_function.f_type
+        f_type = self.fairfed_func.f_type
         metrics.update(
             {f"{f_type}_{key}": val for key, val in fairness.items()}
         )
diff --git a/declearn/fairness/fairfed/_function.py b/declearn/fairness/fairfed/_function.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5700229c4a5a68f25d7be031b1bb36c83733d7a
--- /dev/null
+++ b/declearn/fairness/fairfed/_function.py
@@ -0,0 +1,183 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""FairFed-specific fairness function wrapper."""
+
+import warnings
+from typing import Any, Dict, Optional, Tuple
+
+from declearn.fairness.core import FairnessFunction
+
+
+class FairfedFairnessFunction:
+    """FairFed-specific fairness function wrapper."""
+
+    def __init__(
+        self,
+        wrapped: FairnessFunction,
+        strict: bool = True,
+        target: Optional[int] = None,
+    ) -> None:
+        """Instantiate the FairFed-specific fairness function wrapper.
+
+        Parameters
+        ----------
+        wrapped:
+            Initial `FairnessFunction` instance to wrap up for FairFed.
+        strict:
+            Whether to stick strictly to the FairFed paper's setting
+            and explicit formulas, or to use a broader adaptation of
+            FairFed to more diverse settings.
+            See details below.
+        target:
+            Optional choice of target label to focus on in `strict` mode.
+            Only used when `strict=True`. If `None`, use `wrapped.target`
+            when it exists, or else a default value of 1.
+
+        Strict mode
+        -----------
+        This FairFed implementation comes in two flavors.
+
+        - The "strict" mode sticks to the original FairFed paper:
+            - It applies only to binary classification tasks with
+              a single binary sensitive attribute.
+            - Clients must hold examples of each and every group.
+            - If `wrapped.f_type` is not explicitly cited in the
+              original paper, a `RuntimeWarning` is emitted.
+            - The synthetic fairness value is computed based on
+              fairness values for two groups: (y=`target`,s=1)
+              and (y=`target`,s=0).
+
+        - The "non-strict" mode extends to broader settings:
+            - It applies to any number of sensitive groups.
+            - Clients may not hold examples of all groups.
+            - It applies to any type of group-fairness.
+            - The synthetic fairness value is computed as
+              the average of all absolute fairness values.
+            - The local fairness is computed only over groups
+              that have at least one sample in the local data.
+        """
+        self.wrapped = wrapped
+        self._key_groups = (
+            None
+        )  # type: Optional[Tuple[Tuple[Any, ...], Tuple[Any, ...]]]
+        if strict:
+            target = int(
+                getattr(wrapped, "target", 1) if target is None else target
+            )
+            self._key_groups = self._identify_key_groups(target)
+
+    @property
+    def f_type(
+        self,
+    ) -> str:
+        """Type of group-fairness being measured."""
+        return self.wrapped.f_type
+
+    @property
+    def strict(
+        self,
+    ) -> bool:
+        """Whether this function strictly sticks to the FairFed paper."""
+        return self._key_groups is not None
+
+    def _identify_key_groups(
+        self,
+        target: int,
+    ) -> Tuple[Tuple[Any, ...], Tuple[Any, ...]]:
+        """Parse sensitive groups' definitions to identify 'key' ones."""
+        if self.f_type not in (
+            "demographic_parity",
+            "equality_of_opportunity",
+            "equalized_odds",
+        ):
+            warnings.warn(
+                f"Using fairness type '{self.f_type}' with FairFed in 'strict'"
+                " mode. This is supported, but beyond the original paper.",
+                RuntimeWarning,
+            )
+        if len(self.wrapped.groups) != 4:
+            raise RuntimeError(
+                "FairFed in 'strict' mode requires exactly 4 sensitive groups,"
+                " arising from a binary target label and a binary attribute."
+            )
+        groups = tuple(
+            sorted([grp for grp in self.wrapped.groups if grp[0] == target])
+        )
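+        # E.g. with groups {(0, 0), (0, 1), (1, 0), (1, 1)} and target=1,
+        # this yields the key pair ((1, 0), (1, 1)).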
+        if len(groups) != 2:
+            raise KeyError(
+                f"Failed to identify the (target,attr_0);(target,attr_1) "
+                "pair of sensitive groups for FairFed in 'strict' mode "
+                f"with 'target' value {target}."
+            )
+        return groups
+
+    def compute_group_fairness_from_accuracy(
+        self,
+        accuracy: Dict[Tuple[Any, ...], float],
+        federated: bool,
+    ) -> Dict[Tuple[Any, ...], float]:
+        """Compute group-wise fairness values from group-wise accuracy metrics.
+
+        Parameters
+        ----------
+        accuracy:
+            Group-wise accuracy values of the model being evaluated on a
+            dataset. I.e. `{group_k: P(y_pred == y_true | group_k)}`.
+        federated:
+            Whether `accuracy` holds values computed federatively, i.e.
+            sum-aggregated, local-group-count-weighted accuracies
+            `{group_k: sum_i(n_ik * accuracy_ik)}`.
+
+        Returns
+        -------
+        fairness:
+            Group-wise fairness metrics, as a `{group_k: score_k}` dict.
+        """
+        if federated:
+            return self.wrapped.compute_from_federated_group_accuracy(accuracy)
+        return self.wrapped.compute_from_group_accuracy(accuracy)
+
+    def compute_synthetic_fairness_value(
+        self,
+        fairness: Dict[Tuple[Any, ...], float],
+    ) -> float:
+        """Compute a synthetic fairness value from group-wise ones.
+
+        If `self.strict`, compute the difference between the fairness
+        values associated with two key sensitive groups, as per the
+        original FairFed paper for the two definitions exposed by the
+        authors.
+
+        Otherwise, compute the average of absolute group-wise fairness
+        values, which applies to more generic fairness formulations than
+        those in the original paper and may encompass broader information.
+
+        Parameters
+        ----------
+        fairness:
+            Group-wise fairness metrics, as a `{group_k: score_k}` dict.
+
+        Returns
+        -------
+        value:
+            Scalar value summarizing the computed fairness.
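+
+        Example
+        -------
+        A minimal sketch, with made-up fairness values over four groups
+        and `target=1` (so that strict-mode key groups are `(1, 0)` and
+        `(1, 1)`):
+
+        >>> fairness = {
+        ...     (0, 0): 0.05, (0, 1): -0.05, (1, 0): 0.1, (1, 1): -0.1
+        ... }
+        >>> # strict mode: fairness[(1, 0)] - fairness[(1, 1)] = 0.2
+        >>> # non-strict mode: mean of absolute values = 0.075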
+        """
+        if self._key_groups is None:
+            return sum(abs(x) for x in fairness.values()) / len(fairness)
+        return fairness[self._key_groups[0]] - fairness[self._key_groups[1]]
diff --git a/declearn/fairness/fairfed/_server.py b/declearn/fairness/fairfed/_server.py
index d6bda0bc868cc342068d5c90a0135453560b60b1..2483c5f45f48476a36c4849554a06b2a3227305f 100644
--- a/declearn/fairness/fairfed/_server.py
+++ b/declearn/fairness/fairfed/_server.py
@@ -28,6 +28,7 @@ from declearn.communication.utils import verify_client_messages_validity
 from declearn.fairness.api import FairnessControllerServer
 from declearn.fairness.core import instantiate_fairness_function
 from declearn.fairness.fairfed._aggregator import FairfedAggregator
+from declearn.fairness.fairfed._function import FairfedFairnessFunction
 from declearn.fairness.fairfed._messages import (
     FairfedDelta,
     FairfedDeltavg,
@@ -35,6 +36,7 @@ from declearn.fairness.fairfed._messages import (
     FairfedOkay,
     SecaggFairfedDelta,
 )
+from declearn.messaging import FairnessSetupQuery
 from declearn.secagg.api import Decrypter
 from declearn.secagg.messaging import aggregate_secagg_messages
 
@@ -54,6 +56,7 @@ class FairfedControllerServer(FairnessControllerServer):
         f_type: str,
         f_args: Optional[Dict[str, Any]] = None,
         beta: float = 1.0,
+        strict: bool = True,
     ) -> None:
         """Instantiate the server-side Fed-FairGrad controller.
 
@@ -66,13 +69,35 @@ class FairfedControllerServer(FairnessControllerServer):
         beta:
             Hyper-parameter controlling the magnitude of updates
             to clients' averaging weights across rounds.
+        strict:
+            Whether to stick strictly to the FairFed paper's setting
+            and explicit formulas, or to use a broader adaptation of
+            FairFed to more diverse settings.
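+
+        Example
+        -------
+        A minimal sketch (argument values are illustrative):
+
+        >>> controller = FairfedControllerServer(
+        ...     f_type="demographic_parity", beta=1.0, strict=True
+        ... )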
         """
         super().__init__(f_type=f_type, f_args=f_args)
         self.beta = beta
+        self._strict = bool(strict)
         # Set up a temporary fairness function, replaced at setup time.
-        self.fairness_func = instantiate_fairness_function(
+        fairness_func = instantiate_fairness_function(
             "accuracy_parity", counts={}
         )
+        # Wrap it with `strict=False`: strict mode needs the actual
+        # sensitive groups, which are only known at setup time.
+        self.fairfed_func = FairfedFairnessFunction(
+            wrapped=fairness_func, strict=False
+        )
+
+    @property
+    def strict(
+        self,
+    ) -> bool:
+        """Whether this controller strictly sticks to the FairFed paper."""
+        return self._strict
+
+    def prepare_fairness_setup_query(
+        self,
+    ) -> FairnessSetupQuery:
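+        # Pack FairFed-specific hyper-parameters into the setup query,
+        # so that clients instantiate matching controllers.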
+        query = super().prepare_fairness_setup_query()
+        query.params["beta"] = self.beta
+        query.params["strict"] = self.strict
+        return query
 
     async def finalize_fairness_setup(
         self,
@@ -81,9 +106,12 @@ class FairfedControllerServer(FairnessControllerServer):
         aggregator: Aggregator,
     ) -> Aggregator:
         # Set up a fairness function.
-        self.fairness_func = instantiate_fairness_function(
+        fairness_func = instantiate_fairness_function(
             self.f_type, counts=dict(zip(self.groups, counts)), **self.f_args
         )
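+        # Re-wrap it for FairFed, preserving the configured strict mode.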
+        self.fairfed_func = FairfedFairnessFunction(
+            wrapped=fairness_func, strict=self._strict
+        )
         # Force the use of a FairFed-specific averaging aggregator.
         warnings.warn(
             "Overriding Aggregator choice due to the use of FairFed.",
@@ -100,11 +128,11 @@ class FairfedControllerServer(FairnessControllerServer):
     ) -> Dict[str, Union[float, np.ndarray]]:
         # Unpack group-wise accuracy values and compute fairness.
         accuracy = dict(zip(self.groups, values))
-        fairness = self.fairness_func.compute_from_federated_group_accuracy(
-            accuracy
+        fairness = self.fairfed_func.compute_group_fairness_from_accuracy(
+            accuracy, federated=True
         )
         # Share the synthetic fairness value with clients.
-        fair_avg = sum(abs(x) for x in fairness.values()) / len(fairness)
+        fair_avg = self.fairfed_func.compute_synthetic_fairness_value(fairness)
         await netwk.broadcast_message(FairfedFairness(fairness=fair_avg))
         # Await and (secure-)aggregate clients' absolute fairness difference.
         received = await netwk.wait_for_messages()