diff --git a/declearn/fairness/__init__.py b/declearn/fairness/__init__.py
index 15eb1aaca2e39fc0b807e224bbf725045ffee318..ede33109cf7d2aec898c0c5f309bccb51e5a365d 100644
--- a/declearn/fairness/__init__.py
+++ b/declearn/fairness/__init__.py
@@ -23,10 +23,13 @@ This module implements the following submodules:
     API to set up and run fairness-aware federated learning algorithms.
 * [core][declearn.fairness.core]:
     Core components and utils for fairness-aware (federated) machine learning.
+* [fairbatch][declearn.fairness.fairbatch]:
+    Fed-FairBatch / FedFB algorithm controllers and utils.
 * [fairgrad][declearn.fairness.fairgrad]:
     Fed-FairGrad algorithm controllers and utils.
 """
 
 from . import core
 from . import api
+from . import fairbatch
 from . import fairgrad
diff --git a/declearn/fairness/fairbatch/__init__.py b/declearn/fairness/fairbatch/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0dbd6f1759f284052741af57347548c339027df8
--- /dev/null
+++ b/declearn/fairness/fairbatch/__init__.py
@@ -0,0 +1,90 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Fed-FairBatch / FedFB algorithm controllers and utils.
+
+Introduction
+------------
+This module provides a twofold implementation of an adaptation of the
+FairBatch [1] algorithm for federated learning. On the one hand, the
+FedFB [2] algorithm is implemented, which adapts FairBatch in a rather
+straightforward manner but introduces some changes to the formulas of
+the initial paper. On the other hand, a custom algorithm, deemed
+Fed-FairBatch, is implemented, which is similar in intent to FedFB but
+sticks to the raw FairBatch formulas.
+
+FairBatch is a group-fairness-enforcing algorithm that relies on a
+specific form of loss reweighting, mediated by a specific batching
+of samples for SGD steps. Namely, in FairBatch, batches are drawn
+by concatenating group-wise sub-batches, the size of which is the
+product of the desired total batch size and group-wise sampling
+probabilities, with the latter being updated throughout training
+based on the current model's fairness.
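+
+For instance, with a total batch size of 100 and sampling probabilities
+`{(0, 0): 0.3, (0, 1): 0.2, (1, 0): 0.3, (1, 1): 0.2}` (an illustrative,
+assumed setting), each batch would concatenate sub-batches of 30, 20, 30
+and 20 samples drawn from the respective sensitive groups.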
+
+FairBatch was initially designed for binary classification tasks
+on data that have a single binary sensitive attribute. Both of our
+implementations currently stick to that setting, even though the
+FedFB authors use a formalism that arguably extends the formulas
+to more generic categorical sensitive attribute(s) - something the
+paper does not test.
+
+Controllers
+-----------
+* [FairbatchControllerClient]
+[declearn.fairness.fairbatch.FairbatchControllerClient]:
+    Client-side controller to implement Fed-FairBatch or FedFB.
+* [FairbatchControllerServer]
+[declearn.fairness.fairbatch.FairbatchControllerServer]:
+    Server-side controller to implement Fed-FairBatch or FedFB.
+
+Backend
+-------
+* [FairbatchSamplingController]
+[declearn.fairness.fairbatch.FairbatchSamplingController]:
+    ABC to compute and update Fairbatch sampling probabilities.
+* [setup_fairbatch_controller]
+[declearn.fairness.fairbatch.setup_fairbatch_controller]:
+    Instantiate a FairBatch sampling probabilities controller.
+* [setup_fedfb_controller]
+[declearn.fairness.fairbatch.setup_fedfb_controller]:
+    Instantiate a FedFB sampling probabilities controller.
+
+Messages
+--------
+* [FairbatchOkay][declearn.fairness.fairbatch.FairbatchOkay]:
+    Message for clients to signal that a FairBatch update went fine.
+* [FairbatchSamplingProbas]
+[declearn.fairness.fairbatch.FairbatchSamplingProbas]:
+    Message for server-emitted Fed-FairBatch/FedFB sampling probabilities.
+"""
+
+from ._messages import (
+    FairbatchOkay,
+    FairbatchSamplingProbas,
+)
+from ._sampling import (
+    FairbatchSamplingController,
+    setup_fairbatch_controller,
+)
+from ._fedfb import setup_fedfb_controller
+from ._client import FairbatchControllerClient
+from ._server import FairbatchControllerServer
diff --git a/declearn/fairness/fairbatch/_client.py b/declearn/fairness/fairbatch/_client.py
new file mode 100644
index 0000000000000000000000000000000000000000..41f4460f6d007cf5556d49730cd85c0e9bae15f7
--- /dev/null
+++ b/declearn/fairness/fairbatch/_client.py
@@ -0,0 +1,181 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Client-side Fed-FairBatch controller."""
+
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+
+from declearn.communication.api import NetworkClient
+from declearn.communication.utils import verify_server_message_validity
+from declearn.fairness.api import FairnessControllerClient
+from declearn.fairness.core import (
+    FairnessDataset,
+    instantiate_fairness_function,
+)
+from declearn.fairness.fairbatch._dataset import FairbatchDataset
+from declearn.fairness.fairbatch._messages import (
+    FairbatchSamplingProbas,
+    FairbatchOkay,
+)
+from declearn.messaging import Error
+from declearn.secagg.api import Encrypter
+from declearn.training import TrainingManager
+
+__all__ = [
+    "FairbatchControllerClient",
+]
+
+
+class FairbatchControllerClient(FairnessControllerClient):
+    """Client-side controller to implement Fed-FairBatch or FedFB."""
+
+    algorithm = "fedfairbatch"
+
+    def __init__(
+        self,
+        manager: TrainingManager,
+        f_type: str,
+        f_args: Dict[str, Any],
+    ) -> None:
+        """Instantiate the client-side fairness controller.
+
+        Parameters
+        ----------
+        manager:
+            `TrainingManager` instance wrapping the model being trained
+            and its training dataset (that must be a `FairnessDataset`).
+        f_type:
+            Name of the type of group-fairness function being optimized.
+        f_args:
+            Keyword arguments to the group-fairness function.
+        """
+        super().__init__(manager)
+        assert isinstance(self.manager.train_data, FairnessDataset)
+        self.manager.train_data = FairbatchDataset(self.manager.train_data)
+        self.fairness_function = instantiate_fairness_function(
+            f_type=f_type, counts=self.computer.counts, **f_args
+        )
+
+    async def finalize_fairness_setup(
+        self,
+        netwk: NetworkClient,
+        secagg: Optional[Encrypter],
+    ) -> None:
+        pass  # no action required beyond sharing group definitions and counts
+
+    async def _update_fairbatch_sampling_probas(
+        self,
+        netwk: NetworkClient,
+    ) -> None:
+        """Run a FairBatch-specific routine to update sampling probabilities.
+
+        Expect a message from the orchestrating server containing the new
+        sensitive group sampling probabilities, and apply them to the
+        training dataset.
+
+        Raises
+        ------
+        RuntimeError:
+            If the expected message is not received.
+            If the sampling probabilities' update fails.
+        """
+        # Receive aggregated sensitive weights.
+        received = await netwk.check_message()
+        message = await verify_server_message_validity(
+            netwk, received, expected=FairbatchSamplingProbas
+        )
+        probas = dict(zip(self.groups, message.probas))
+        # Set the received weights, handling and propagating exceptions if any.
+        try:
+            assert isinstance(self.manager.train_data, FairbatchDataset)
+            self.manager.train_data.set_sampling_probabilities(
+                group_probas=probas
+            )
+        except Exception as exc:
+            self.manager.logger.error(
+                "Exception encountered when setting FairBatch sampling"
+                "probabilities: %s",
+                repr(exc),
+            )
+            await netwk.send_message(Error(repr(exc)))
+            raise RuntimeError(
+                "FairBatch sampling probabilities update failed."
+            ) from exc
+        # If things went well, ping the server back to indicate so.
+        self.manager.logger.info("Updated FairBatch sampling probabilities.")
+        await netwk.send_message(FairbatchOkay())
+
+    def compute_fairness_measures(
+        self,
+        batch_size: int,
+        n_batch: Optional[int] = None,
+        thresh: Optional[float] = None,
+    ) -> List[float]:
+        # Compute group-wise accuracy scores and loss values.
+        accuracy, loss = self.computer.compute_groupwise_accuracy_and_loss(
+            model=self.manager.model,
+            batch_size=batch_size,
+            n_batch=n_batch,
+            thresh=thresh,
+        )
+        # Multiply these values by local sample counts, so that sharing
+        # and sum-aggregating them across clients yields weighted sums.
+        accuracy = {
+            key: val * self.computer.counts[key]
+            for key, val in accuracy.items()
+        }
+        loss = {
+            key: val * self.computer.counts[key] for key, val in loss.items()
+        }
+        # Return shareable group-wise values, ordered and filled out.
+        return [accuracy.get(group, 0.0) for group in self.groups] + [
+            loss.get(group, 0.0) for group in self.groups
+        ]
+
+    async def finalize_fairness_round(
+        self,
+        netwk: NetworkClient,
+        values: List[float],
+        secagg: Optional[Encrypter],
+    ) -> Dict[str, Union[float, np.ndarray]]:
+        # Await updated sampling probabilities from the server.
+        await self._update_fairbatch_sampling_probas(netwk)
+        # Recover raw accuracy and loss values for groups with local samples.
+        accuracy = {
+            key: val / self.computer.counts[key]
+            for key, val in zip(self.groups, values[: len(self.groups)])
+            if key in self.computer.counts
+        }
+        loss = {
+            key: val / self.computer.counts[key]
+            for key, val in zip(self.groups, values[len(self.groups) :])
+            if key in self.computer.counts
+        }
+        # Compute local fairness measures.
+        fairness = self.fairness_function.compute_from_group_accuracy(accuracy)
+        f_type = self.fairness_function.f_type
+        # Package and return accuracy and fairness metrics.
+        metrics = {
+            f"accuracy_{key}": val for key, val in accuracy.items()
+        }  # type: Dict[str, Union[float, np.ndarray]]
+        metrics.update({f"loss_{key}": val for key, val in loss.items()})
+        metrics.update(
+            {f"{f_type}_{key}": val for key, val in fairness.items()}
+        )
+        return metrics
diff --git a/declearn/fairness/fairbatch/_dataset.py b/declearn/fairness/fairbatch/_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fb1411afca94b5862e2588aa587eb52bbbe7703
--- /dev/null
+++ b/declearn/fairness/fairbatch/_dataset.py
@@ -0,0 +1,305 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""FairBatch-specific Dataset wrapper and subclass."""
+
+from typing import Any, Dict, Iterator, List, Sequence, Tuple
+
+import numpy as np
+
+from declearn.dataset import Dataset, DataSpecs
+from declearn.fairness.core import FairnessDataset
+from declearn.typing import Batch
+
+__all__ = [
+    "FairbatchDataset",
+]
+
+
+class FairbatchDataset(FairnessDataset):
+    """FairBatch-specific FairnessDataset subclass and wrapper."""
+
+    def __init__(
+        self,
+        base: FairnessDataset,
+    ) -> None:
+        """Instantiate a FairbatchDataset wrapping a FairnessDataset.
+
+        Parameters
+        ----------
+        base:
+            Base `FairnessDataset` instance to wrap so as to apply
+            group-wise subsampling as per the FairBatch algorithm.
+        """
+        self.base = base
+        # Assign a dictionary with sampling probability for each group.
+        self.groups = self.get_sensitive_group_definitions()
+        self._counts = self.base.get_sensitive_group_counts()
+        self._sampling_probas = {
+            group: 1.0 / len(self.groups) for group in self.groups
+        }
+
+    # Methods provided by the wrapped dataset (merely interfaced).
+
+    def get_data_specs(
+        self,
+    ) -> DataSpecs:
+        return self.base.get_data_specs()
+
+    def get_sensitive_group_definitions(
+        self,
+    ) -> List[Tuple[Any, ...]]:
+        return self.groups
+
+    def get_sensitive_group_counts(
+        self,
+    ) -> Dict[Tuple[Any, ...], int]:
+        return self._counts.copy()
+
+    def get_sensitive_group_subset(
+        self,
+        group: Tuple[Any, ...],
+    ) -> Dataset:
+        return self.base.get_sensitive_group_subset(group)
+
+    def set_sensitive_group_weights(
+        self,
+        weights: Dict[Tuple[Any, ...], float],
+        adjust_by_counts: bool = False,
+    ) -> None:
+        self.base.set_sensitive_group_weights(weights, adjust_by_counts)
+
+    # FairBatch-specific methods.
+
+    def get_sampling_probabilities(
+        self,
+    ) -> Dict[Tuple[Any, ...], float]:
+        """Access current group-wise sampling probabilities."""
+        return self._sampling_probas.copy()
+
+    def set_sampling_probabilities(
+        self,
+        group_probas: Dict[Tuple[Any, ...], float],
+    ) -> None:
+        """Assign new group-wise sampling probabilities.
+
+        If some groups are not present in the wrapped dataset,
+        scale the probabilities associated with all represented
+        groups so that they sum to 1.
+
+        Parameters
+        ----------
+        group_probas:
+            Dict of group-wise sampling probabilities, with
+            `{(s_attr_1, ..., s_attr_k): sampling_proba}` format.
+
+        Raises
+        ------
+        ValueError
+            If the input probabilities are not positive values
+            or if they do not cover (a superset of) all sensitive
+            groups present in the wrapped dataset.
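+
+        Example
+        -------
+        A minimal sketch, assuming `dataset` is a `FairbatchDataset`
+        whose groups arise from binary label and attribute values:
+
+        >>> dataset.set_sampling_probabilities(
+        ...     {(0, 0): 0.3, (0, 1): 0.2, (1, 0): 0.3, (1, 1): 0.2}
+        ... )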
+        """
+        # Verify that input match expectations.
+        if not all(x >= 0 for x in group_probas.values()):
+            raise ValueError(
+                f"'{self.__class__.__name__}.update_sampling_probabilities' "
+                "cannot have a negative probability value as parameter."
+            )
+        if not set(self.groups).issubset(group_probas):
+            raise ValueError(
+                "'FairbatchDataset.update_sampling_probabilities' requires "
+                "input values to cover (a superset of) local sensitive groups."
+            )
+        # Restrict and adjust probabilities to groups with samples.
+        probas = {group: group_probas[group] for group in self.groups}
+        total = sum(probas.values())
+        self._sampling_probas = {
+            key: val / total for key, val in probas.items()
+        }
+
+    def generate_batches(
+        self,
+        batch_size: int,
+        shuffle: bool = False,
+        drop_remainder: bool = True,
+        replacement: bool = False,
+        poisson: bool = False,
+    ) -> Iterator[Batch]:
+        # inherited signature; pylint: disable=too-many-arguments
+        # NOTE: we could add support for those, but let's start simple.
+        if not drop_remainder:
+            raise ValueError(
+                f"'{self.__class__.__name__}.generate_batches' does not "
+                "support argument value 'drop_remainder=False'."
+            )
+        # Compute the number of batches to yield.
+        nb_batches = sum(self._counts.values()) // batch_size
+        # Compute the group-wise number of samples per batch.
+        # NOTE: this number may be reduced if there are too few samples.
+        group_batch_size = {
+            group: round(proba * batch_size)
+            for group, proba in self._sampling_probas.items()
+        }
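+        # E.g. with `batch_size=100` and probabilities 0.3/0.2/0.3/0.2,
+        # each batch holds 30/20/30/20 samples from the four groups.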
+        # Yield batches made of a fixed number of samples from each group.
+        generators = [
+            self._generate_sensitive_group_batches(
+                group, nb_batches, g_batch_size, shuffle, replacement, poisson
+            )
+            for group, g_batch_size in group_batch_size.items()
+            if g_batch_size > 0
+        ]
+        for batches in zip(*generators):
+            yield self._concatenate_batches(batches)
+
+    @staticmethod
+    def _concatenate_batches(
+        batches: Sequence[Batch],
+    ) -> Batch:
+        """Concatenate batches of numpy array data."""
+        x_dat = np.concatenate([batch[0] for batch in batches], axis=0)
+        y_dat = (
+            None
+            if batches[0][1] is None
+            else np.concatenate([batch[1] for batch in batches], axis=0)
+        )
+        w_dat = (
+            None
+            if batches[0][2] is None
+            else np.concatenate([batch[2] for batch in batches], axis=0)
+        )
+        return x_dat, y_dat, w_dat
+
+    def _generate_sensitive_group_batches(
+        self,
+        group: Tuple[Any, ...],
+        nb_batches: int,
+        batch_size: int,
+        shuffle: bool,
+        replacement: bool,
+        poisson: bool,
+    ) -> Iterator[Batch]:
+        """Generate a fixed number of batches for a given sensitive group.
+
+        Parameters
+        ----------
+        group:
+            Sensitive group whose dataset to draw batches from.
+        nb_batches:
+            Number of batches to yield. The dataset will be iterated
+            over if needed to achieve this number.
+        batch_size:
+            Number of samples per batch (will be exact).
+        shuffle:
+            Whether to shuffle the dataset prior to drawing batches.
+        replacement:
+            Whether to draw with replacement between batches.
+        poisson:
+            Whether to use Poisson sampling rather than batching.
+        """
+        # backend method; pylint: disable=too-many-arguments
+        # Fetch the target sub-dataset and its samples count.
+        dataset = self.get_sensitive_group_subset(group)
+        n_samples = self._counts[group]
+        # Adjust batch size when needed and set up a batches generator.
+        n_repeats, batch_size = divmod(batch_size, n_samples)
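+        # E.g. with `batch_size=70` and `n_samples=30`, yield batches of
+        # 10 samples, each completed with 2 copies of the full subset.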
+        generator = self._generate_batches(
+            # fmt: off
+            dataset, group, nb_batches, batch_size,
+            shuffle, replacement, poisson,
+        )
+        # When the batch size is larger than the number of data points,
+        # make up a base batch with all points (duplicated if needed),
+        # that will be combined with further batches of data.
+        if n_repeats:
+            full = self._get_full_dataset(dataset, n_samples, group)
+            full = self._concatenate_batches([full] * n_repeats)
+            for batch in generator:
+                yield self._concatenate_batches((full, batch))
+        # Otherwise, merely yield from the generator.
+        else:
+            yield from generator
+
+    def _generate_batches(
+        self,
+        dataset: Dataset,
+        group: Tuple[Any, ...],
+        nb_batches: int,
+        batch_size: int,
+        shuffle: bool,
+        replacement: bool,
+        poisson: bool,
+    ) -> Iterator[Batch]:
+        """Backend to yield a fixed number of batches from a dataset."""
+        # backend method; pylint: disable=too-many-arguments
+        # Iterate multiple times over the sub-dataset if needed.
+        counter = 0
+        while counter < nb_batches:
+            # Yield batches from the sub-dataset.
+            generator = dataset.generate_batches(
+                batch_size=batch_size,
+                shuffle=shuffle,
+                drop_remainder=True,
+                replacement=replacement,
+                poisson=poisson,
+            )
+            for batch in generator:
+                yield batch
+                counter += 1
+                if counter == nb_batches:
+                    break
+            # Prevent infinite loops and raise an informative error.
+            if not counter:  # pragma: no cover
+                raise RuntimeError(
+                    f"'{self.__class__.__name__}.generate_batches' triggered "
+                    "an infinite loop; this happened when trying to extract "
+                    f"{batch_size}-samples batches for group {group}."
+                )
+
+    @staticmethod
+    def _get_full_dataset(
+        dataset: Dataset,
+        n_samples: int,
+        group: Tuple[Any, ...],
+    ) -> Batch:
+        """Return a batch containing an entire dataset's samples."""
+        try:
+            generator = dataset.generate_batches(
+                batch_size=n_samples,
+                shuffle=False,
+                drop_remainder=False,
+                replacement=False,
+                poisson=False,
+            )
+            return next(generator)
+        except StopIteration as exc:  # pragma: no cover
+            raise RuntimeError(
+                f"Failed to fetch the full subdataset for group '{group}'."
+            ) from exc
diff --git a/declearn/fairness/fairbatch/_fedfb.py b/declearn/fairness/fairbatch/_fedfb.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f3b1a1cf421d30ad3a7b54646eac9d343f2ff38
--- /dev/null
+++ b/declearn/fairness/fairbatch/_fedfb.py
@@ -0,0 +1,275 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""FedFB sampling probability controllers."""
+
+from typing import Any, Dict, Tuple
+
+import numpy as np
+
+from declearn.fairness.fairbatch._sampling import (
+    FairbatchDemographicParity,
+    FairbatchEqualizedOdds,
+    FairbatchEqualityOpportunity,
+    FairbatchSamplingController,
+    assign_sensitive_group_labels,
+)
+
+
+__all__ = [
+    "setup_fedfb_controller",
+]
+
+
+class FedFBEqualityOpportunity(FairbatchEqualityOpportunity):
+    """FedFB variant of Equality of Opportunity controller.
+
+    This variant introduces two changes compared with our Fed-FairBatch:
+    - The lambda parameter and difference of losses are written with a
+      different group ordering, albeit resulting in identical results.
+    - When comparing loss values over sensitive groups, the notations from
+      the FedFB paper indicate that the sums of losses over samples in the
+      groups are compared, rather than the averages of group-wise losses;
+      this implementation sticks to the FedFB paper.
+    """
+
+    f_type = "equality_of_opportunity"
+
+    def get_sampling_probas(
+        self,
+    ) -> Dict[Tuple[Any, ...], float]:
+        # Revert the sense of lambda (swap the (1, 0) and (1, 1) groups)
+        # to stick with notations from the FedFB paper.
+        probas = super().get_sampling_probas()
+        label_10 = self.groups["1_0"]
+        label_11 = self.groups["1_1"]
+        probas[label_10], probas[label_11] = probas[label_11], probas[label_10]
+        return probas
+
+    def update_from_losses(
+        self,
+        losses: Dict[Tuple[Any, ...], float],
+    ) -> None:
+        # Recover sum-aggregated losses for the two groups of interest.
+        # Do not scale: obtain sums of sample losses for each group.
+        # This differs from parent class (and centralized FairBatch)
+        # but sticks with the FedFB paper's notations and algorithm.
+        loss_10 = losses[self.groups["1_0"]] * self.counts[self.groups["1_0"]]
+        loss_11 = losses[self.groups["1_1"]] * self.counts[self.groups["1_1"]]
+        # Update lambda based on these and the alpha hyper-parameter.
+        # Note: this is the same as in the parent class, with the sense
+        # of groups (1, 0) and (1, 1) inverted, as in the FedFB paper.
+        if loss_11 > loss_10:
+            self.states["lambda"] = min(
+                self.states["lambda"] + self.alpha, self.states["p_tgt_1"]
+            )
+        elif loss_11 < loss_10:
+            self.states["lambda"] = max(self.states["lambda"] - self.alpha, 0)
+
+
+class FedFBEqualizedOdds(FairbatchEqualizedOdds):
+    """FedFB variant of Equalized Odds controller.
+
+    This variant introduces three changes compared with our Fed-FairBatch:
+    - The lambda parameters and difference of losses are written with a
+      different group ordering, albeit resulting in identical results.
+    - When comparing loss values over sensitive groups, the notations from
+      the FedFB paper indicate that the sums of losses over samples in the
+      groups are compared, rather than the averages of group-wise losses;
+      this implementation sticks to the FedFB paper.
+    - The update rule for lambda parameters has a distinct formula, with the
+      alpha hyper-parameter being here scaled by the difference in losses
+      and normalized by the L2 norm of differences in losses, and both groups'
+      lambda being updated at each step.
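+
+    In short, under the notations assumed here, the update for each
+    target label `i` reads
+    `lambda_i <- clip(lambda_i + alpha * d_i / ||d||_2, 0, p_i)`,
+    where `d_i` is the difference between sum-aggregated group losses.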
+    """
+
+    f_type = "equalized_odds"
+
+    def compute_initial_states(
+        self,
+    ) -> Dict[str, float]:
+        # Switch lambdas: apply to groups (-, 1) rather than (-, 0).
+        states = super().compute_initial_states()
+        states["lambda_1"] = states["p_trgt_0"] - states["lambda_1"]
+        states["lambda_2"] = states["p_trgt_1"] - states["lambda_2"]
+        return states
+
+    def get_sampling_probas(
+        self,
+    ) -> Dict[Tuple[Any, ...], float]:
+        # Rewrite the rules entirely, effectively swapping (0,1)/(0,0)
+        # and (1,1)/(1,0) groups compared with parent implementation.
+        states = self.states
+        return {
+            self.groups["0_0"]: states["p_trgt_0"] - states["lambda_1"],
+            self.groups["0_1"]: states["lambda_1"],
+            self.groups["1_0"]: states["p_trgt_1"] - states["lambda_2"],
+            self.groups["1_1"]: states["lambda_2"],
+        }
+
+    def update_from_losses(
+        self,
+        losses: Dict[Tuple[Any, ...], float],
+    ) -> None:
+        # Recover sum-aggregated losses for each sensitive group.
+        # Do not scale: obtain sums of sample losses for each group.
+        # This differs from parent class (and centralized FairBatch)
+        # but sticks with the FedFB paper's notations and algorithm.
+        labeled_losses = {
+            label: losses[group] * self.counts[group]
+            for label, group in self.groups.items()
+        }
+        # Compute aggregated-loss differences for each target label.
+        diff_loss_tgt_0 = labeled_losses["0_1"] - labeled_losses["0_0"]
+        diff_loss_tgt_1 = labeled_losses["1_1"] - labeled_losses["1_0"]
+        # Compute the euclidean norm of these values.
+        den = float(np.linalg.norm([diff_loss_tgt_0, diff_loss_tgt_1], ord=2))
+        # Update lambda_1 (affecting groups with y=0).
+        update = self.alpha * diff_loss_tgt_0 / den
+        self.states["lambda_1"] = min(
+            self.states["lambda_1"] + update, self.states["p_trgt_0"]
+        )
+        self.states["lambda_1"] = max(self.states["lambda_1"], 0)
+        # Update lambda_2 (affecting groups with y=1).
+        update = self.alpha * diff_loss_tgt_1 / den
+        self.states["lambda_2"] = min(
+            self.states["lambda_2"] + update, self.states["p_trgt_1"]
+        )
+        self.states["lambda_2"] = max(self.states["lambda_2"], 0)
+
+
+class FedFBDemographicParity(FairbatchDemographicParity):
+    """FairbatchSamplingController subclass for 'demographic_parity'."""
+
+    f_type = "demographic_parity"
+
+    def update_from_losses(
+        self,
+        losses: Dict[Tuple[Any, ...], float],
+    ) -> None:
+        # NOTE: losses' aggregation does not differ from the parent class.
+        # Recover sum-aggregated losses for each sensitive group.
+        # Obtain {k: n_k * mean-loss_k}, i.e. sums of sample losses.
+        labeled_losses = {
+            label: losses[group] * self.counts[group]
+            for label, group in self.groups.items()
+        }
+        # Normalize losses based on sensitive attribute counts.
+        # Obtain {k: sum(loss for samples in k) / n_samples_with_attr}.
+        labeled_losses["0_0"] /= self.states["n_attr_0"]
+        labeled_losses["0_1"] /= self.states["n_attr_1"]
+        labeled_losses["1_0"] /= self.states["n_attr_0"]
+        labeled_losses["1_1"] /= self.states["n_attr_1"]
+        # NOTE: this is where things differ from parent class.
+        # Compute an overall fairness value based on all losses.
+        f_val = (
+            -labeled_losses["0_0"]
+            + labeled_losses["0_1"]
+            + labeled_losses["1_0"]
+            - labeled_losses["1_1"]
+            + self.counts[self.groups["0_0"]] / self.states["n_attr_0"]
+            - self.counts[self.groups["0_1"]] / self.states["n_attr_1"]
+        )
+        # Update both lambdas based on this overall value.
+        # Note: in the binary attribute case, $mu_a / ||mu||_2$
+        # is equal to $sign(mu_1) / sqrt(2)$.
+        update = float(np.sign(f_val) * self.alpha / np.sqrt(2))
+        self.states["lambda_1"] = min(
+            self.states["lambda_1"] - update, self.states["p_attr_0"]
+        )
+        self.states["lambda_1"] = max(self.states["lambda_1"], 0)
+        self.states["lambda_2"] = min(
+            self.states["lambda_2"] - update, self.states["p_attr_1"]
+        )
+        self.states["lambda_2"] = max(self.states["lambda_2"], 0)
+
+
+def setup_fedfb_controller(
+    f_type: str,
+    counts: Dict[Tuple[Any, ...], int],
+    target: int = 1,
+    alpha: float = 0.005,
+) -> FairbatchSamplingController:
+    """Instantiate a FedFB sampling probabilities controller.
+
+    This is a drop-in replacement for `setup_fairbatch_controller`
+    that implements update rules matching the FedFB algorithm(s)
+    introduced in [1].
+
+    Parameters
+    ----------
+    f_type:
+        Type of group fairness to optimize for.
+    counts:
+        Dict mapping sensitive group definitions to their total
+        sample counts (across clients). These groups must arise
+        from the crossing of a binary target label and a binary
+        sensitive attribute.
+    target:
+        Target label to treat as positive.
+    alpha:
+        Alpha hyper-parameter, scaling the magnitude of sampling
+        probabilities' updates by the returned controller.
+
+    Returns
+    -------
+    controller:
+        FairBatch sampling probabilities controller matching inputs.
+
+    Raises
+    ------
+    KeyError
+        If `f_type` does not match any known or supported fairness type.
+    ValueError
+        If `counts` keys cannot be matched to canonical group labels.
+
+    References
+    ----------
+    [1] Zeng et al. (2022).
+        Improving Fairness via Federated Learning.
+        https://arxiv.org/abs/2110.15545
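+
+    Example
+    -------
+    A minimal sketch, using assumed illustrative group counts:
+
+    >>> counts = {(0, 0): 50, (0, 1): 50, (1, 0): 60, (1, 1): 40}
+    >>> controller = setup_fedfb_controller("equalized_odds", counts)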
+    """
+    controller_types = {
+        "demographic_parity": FedFBDemographicParity,
+        "equality_of_opportunity": FedFBEqualityOpportunity,
+        "equalized_odds": FedFBEqualizedOdds,
+    }
+    controller_cls = controller_types.get(f_type, None)
+    if controller_cls is None:
+        raise KeyError(
+            "Unknown or unsupported fairness type parameter for FairBatch "
+            f"controller initialization: '{f_type}'. Supported values are "
+            f"{list(controller_types)}."
+        )
+    # Match groups to canonical labels and instantiate the controller.
+    groups = assign_sensitive_group_labels(groups=list(counts), target=target)
+    kwargs = {"target": target} if f_type == "equality_of_opportunity" else {}
+    return controller_cls(  # type: ignore[abstract]
+        groups=groups, counts=counts, alpha=alpha, **kwargs
+    )
diff --git a/declearn/fairness/fairbatch/_messages.py b/declearn/fairness/fairbatch/_messages.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1b9ec19db977cb64ba736f194a459a3b5c23b0d
--- /dev/null
+++ b/declearn/fairness/fairbatch/_messages.py
@@ -0,0 +1,57 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Fed-FairBatch/Fed-FB specific messages."""
+
+import dataclasses
+from typing import List
+
+
+from declearn.messaging import Message
+
+
+__all__ = [
+    "FairbatchOkay",
+    "FairbatchSamplingProbas",
+]
+
+
+@dataclasses.dataclass
+class FairbatchOkay(Message):
+    """Message for client signal that Fed-FairBatch/FedFB update went fine."""
+
+    typekey = "fairbatch-okay"
+
+
+@dataclasses.dataclass
+class FairbatchSamplingProbas(Message):
+    """Message for server-emitted Fed-FairBatch/Fed-FB sampling probabilities.
+
+    Fields
+    ------
+    probas:
+        List of group-wise sampling probabilities, ordered based on
+        an agreed-upon sorted list of sensitive groups.
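+
+    For instance, `FairbatchSamplingProbas([0.3, 0.2, 0.3, 0.2])` would
+    convey probabilities for four sensitive groups, in the sorted order
+    agreed upon at setup time.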
+    """
+
+    probas: List[float]
+
+    typekey = "fairbatch-probas"
diff --git a/declearn/fairness/fairbatch/_sampling.py b/declearn/fairness/fairbatch/_sampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc2e1f9cc0d09978cffcfbd664ab78eddc4944ef
--- /dev/null
+++ b/declearn/fairness/fairbatch/_sampling.py
@@ -0,0 +1,459 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""FairBatch sampling probability controllers."""
+
+import abc
+from typing import Any, ClassVar, Dict, List, Literal, Tuple
+
+
+from declearn.fairness.core import instantiate_fairness_function
+
+
+__all__ = [
+    "FairbatchSamplingController",
+    "setup_fairbatch_controller",
+]
+
+
+GroupLabel = Literal["0_0", "0_1", "1_0", "1_1"]
+
+
+class FairbatchSamplingController(metaclass=abc.ABCMeta):
+    """ABC to compute and update Fairbatch sampling probabilities."""
+
+    f_type: ClassVar[str]
+
+    def __init__(
+        self,
+        groups: Dict[GroupLabel, Tuple[Any, ...]],
+        counts: Dict[Tuple[Any, ...], int],
+        alpha: float = 0.005,
+        **kwargs: Any,
+    ) -> None:
+        """Instantiate the Fairbatch sampling probabilities controller.
+
+        Parameters
+        ----------
+        groups:
+            Dict mapping canonical labels to sensitive group definitions.
+        counts:
+            Dict mapping sensitive group definitions to sample counts.
+        alpha:
+            Hyper-parameter controlling the update rule for internal
+            states, and thereby for sampling probabilities.
+        **kwargs:
+            Keyword arguments specific to the fairness definition in use.
+        """
+        # Assign input parameters as attributes.
+        self.groups = groups
+        self.counts = counts
+        self.total = sum(counts.values())
+        self.alpha = alpha
+        # Initialize internal states and sampling probabilities.
+        self.states = self.compute_initial_states()
+        # Initialize a fairness function.
+        self.f_func = instantiate_fairness_function(
+            f_type=self.f_type, counts=counts, **kwargs
+        )
+
+    @abc.abstractmethod
+    def compute_initial_states(
+        self,
+    ) -> Dict[str, float]:
+        """Return a dict containing initial internal states.
+
+        Returns
+        -------
+        states:
+            Dict associating float values to arbitrary names that
+            depend on the type of group-fairness being optimized.
+        """
+
+    @abc.abstractmethod
+    def get_sampling_probas(
+        self,
+    ) -> Dict[Tuple[Any, ...], float]:
+        """Return group-wise sampling probabilities.
+
+        Returns
+        -------
+        sampling_probas:
+            Dict mapping sensitive group definitions to their sampling
+            probabilities, as established via the FairBatch algorithm.
+        """
+
+    @abc.abstractmethod
+    def update_from_losses(
+        self,
+        losses: Dict[Tuple[Any, ...], float],
+    ) -> None:
+        """Update internal states based on group-wise losses.
+
+        Parameters
+        ----------
+        losses:
+            Group-wise model loss values, as a `{group_k: loss_k}` dict.
+        """
+
+    def update_from_federated_losses(
+        self,
+        losses: Dict[Tuple[Any, ...], float],
+    ) -> None:
+        """Update internal states based on federated group-wise losses.
+
+        Parameters
+        ----------
+        losses:
+            Group-wise model loss values, sum-aggregated across clients
+            with weighting by local group-wise sample counts; i.e. a
+            `{group_k: sum_i(n_ik * loss_ik)}` dict.
+
+        Raises
+        ------
+        KeyError
+            If any defined sensitive group does not have a loss value.
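+
+        Example
+        -------
+        With two clients holding 10 and 30 samples from group `k`, and
+        mean local losses of 0.5 and 0.9, the expected input value for
+        `k` is `10 * 0.5 + 30 * 0.9 = 32.0`; dividing it by the total
+        count (40) recovers the group's global mean loss (0.8).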
+        """
+        losses = {key: val / self.counts[key] for key, val in losses.items()}
+        self.update_from_losses(losses)
+
+
+class FairbatchEqualityOpportunity(FairbatchSamplingController):
+    """FairbatchSamplingController subclass for 'equality_of_opportunity'."""
+
+    f_type = "equality_of_opportunity"
+
+    def compute_initial_states(
+        self,
+    ) -> Dict[str, float]:
+        # Gather sample counts and share with positive target label.
+        nsmp_10 = self.counts[self.groups["1_0"]]
+        nsmp_11 = self.counts[self.groups["1_1"]]
+        p_tgt_1 = (nsmp_10 + nsmp_11) / self.total
+        # Assign the initial lambda and fixed quantities to re-use.
+        return {
+            "lambda": nsmp_10 / self.total,
+            "p_tgt_1": p_tgt_1,
+            "p_g_0_0": self.counts[self.groups["0_0"]] / self.total,
+            "p_g_0_1": self.counts[self.groups["0_1"]] / self.total,
+        }
+
+    def get_sampling_probas(
+        self,
+    ) -> Dict[Tuple[Any, ...], float]:
+        return {
+            self.groups["0_0"]: self.states["p_g_0_0"],
+            self.groups["0_1"]: self.states["p_g_0_1"],
+            self.groups["1_0"]: self.states["lambda"],
+            self.groups["1_1"]: self.states["p_tgt_1"] - self.states["lambda"],
+        }
+
+    def update_from_losses(
+        self,
+        losses: Dict[Tuple[Any, ...], float],
+    ) -> None:
+        # Gather mean-aggregated losses for the two groups of interest.
+        loss_10 = losses[self.groups["1_0"]]
+        loss_11 = losses[self.groups["1_1"]]
+        # Update lambda based on these and the alpha hyper-parameter.
+        if loss_10 > loss_11:
+            self.states["lambda"] = min(
+                self.states["lambda"] + self.alpha, self.states["p_tgt_1"]
+            )
+        elif loss_10 < loss_11:
+            self.states["lambda"] = max(self.states["lambda"] - self.alpha, 0)
+
+
+class FairbatchEqualizedOdds(FairbatchSamplingController):
+    """FairbatchSamplingController subclass for 'equalized_odds'."""
+
+    f_type = "equalized_odds"
+
+    def compute_initial_states(
+        self,
+    ) -> Dict[str, float]:
+        # Gather sample counts.
+        nsmp_00 = self.counts[self.groups["0_0"]]
+        nsmp_01 = self.counts[self.groups["0_1"]]
+        nsmp_10 = self.counts[self.groups["1_0"]]
+        nsmp_11 = self.counts[self.groups["1_1"]]
+        # Compute initial lambdas, and target-label-wise sample shares.
+        return {
+            "lambda_1": nsmp_00 / self.total,
+            "lambda_2": nsmp_10 / self.total,
+            "p_trgt_0": (nsmp_00 + nsmp_01) / self.total,
+            "p_trgt_1": (nsmp_10 + nsmp_11) / self.total,
+        }
+
+    def get_sampling_probas(
+        self,
+    ) -> Dict[Tuple[Any, ...], float]:
+        states = self.states
+        return {
+            self.groups["0_0"]: states["lambda_1"],
+            self.groups["0_1"]: states["p_trgt_0"] - states["lambda_1"],
+            self.groups["1_0"]: states["lambda_2"],
+            self.groups["1_1"]: states["p_trgt_1"] - states["lambda_2"],
+        }
+
+    def update_from_losses(
+        self,
+        losses: Dict[Tuple[Any, ...], float],
+    ) -> None:
+        # Compute loss differences for each target label.
+        diff_loss_tgt_0 = (
+            losses[self.groups["0_0"]] - losses[self.groups["0_1"]]
+        )
+        diff_loss_tgt_1 = (
+            losses[self.groups["1_0"]] - losses[self.groups["1_1"]]
+        )
+        # Update a lambda based on these and the alpha hyper-parameter.
+        if abs(diff_loss_tgt_0) > abs(diff_loss_tgt_1):
+            if diff_loss_tgt_0 > 0:
+                self.states["lambda_1"] = min(
+                    self.states["lambda_1"] + self.alpha,
+                    self.states["p_trgt_0"],
+                )
+            elif diff_loss_tgt_0 < 0:
+                self.states["lambda_1"] = max(
+                    self.states["lambda_1"] - self.alpha, 0
+                )
+        else:
+            if diff_loss_tgt_1 > 0:
+                self.states["lambda_2"] = min(
+                    self.states["lambda_2"] + self.alpha,
+                    self.states["p_trgt_1"],
+                )
+            elif diff_loss_tgt_1 < 0:
+                self.states["lambda_2"] = max(
+                    self.states["lambda_2"] - self.alpha, 0
+                )
+
+
+class FairbatchDemographicParity(FairbatchSamplingController):
+    """FairbatchSamplingController subclass for 'demographic_parity'."""
+
+    f_type = "demographic_parity"
+
+    def compute_initial_states(
+        self,
+    ) -> Dict[str, float]:
+        # Gather sample counts.
+        nsmp_00 = self.counts[self.groups["0_0"]]
+        nsmp_01 = self.counts[self.groups["0_1"]]
+        nsmp_10 = self.counts[self.groups["1_0"]]
+        nsmp_11 = self.counts[self.groups["1_1"]]
+        # Compute initial lambdas, and attribute-wise sample shares and counts.
+        return {
+            "lambda_1": nsmp_00 / self.total,
+            "lambda_2": nsmp_01 / self.total,
+            "p_attr_0": (nsmp_00 + nsmp_10) / self.total,
+            "p_attr_1": (nsmp_01 + nsmp_11) / self.total,
+            "n_attr_0": nsmp_00 + nsmp_10,
+            "n_attr_1": nsmp_01 + nsmp_11,
+        }
+
+    def get_sampling_probas(
+        self,
+    ) -> Dict[Tuple[Any, ...], float]:
+        states = self.states
+        return {
+            self.groups["0_0"]: states["lambda_1"],
+            self.groups["1_0"]: states["p_attr_0"] - states["lambda_1"],
+            self.groups["0_1"]: states["lambda_2"],
+            self.groups["1_1"]: states["p_attr_1"] - states["lambda_2"],
+        }
+
+    def update_from_losses(
+        self,
+        losses: Dict[Tuple[Any, ...], float],
+    ) -> None:
+        # Recover sum-aggregated losses for each sensitive group.
+        # Obtain {k: n_k * mean-loss_k}, i.e. sums of sample losses.
+        labeled_losses = {
+            label: losses[group] * self.counts[group]
+            for label, group in self.groups.items()
+        }
+        # Normalize losses based on sensitive attribute counts.
+        # Obtain {k: sum(loss for samples in k) / n_samples_with_attr}.
+        labeled_losses["0_0"] /= self.states["n_attr_0"]
+        labeled_losses["0_1"] /= self.states["n_attr_1"]
+        labeled_losses["1_0"] /= self.states["n_attr_0"]
+        labeled_losses["1_1"] /= self.states["n_attr_1"]
+        # Compute aggregated-loss differences for each target label.
+        diff_loss_tgt_0 = labeled_losses["0_0"] - labeled_losses["0_1"]
+        diff_loss_tgt_1 = labeled_losses["1_0"] - labeled_losses["1_1"]
+        # Update a lambda based on these and the alpha hyper-parameter.
+        if abs(diff_loss_tgt_0) > abs(diff_loss_tgt_1):
+            if diff_loss_tgt_0 > 0:
+                self.states["lambda_1"] = max(
+                    self.states["lambda_1"] - self.alpha, 0
+                )
+            elif diff_loss_tgt_0 < 0:
+                self.states["lambda_1"] = min(
+                    self.states["lambda_1"] + self.alpha,
+                    self.states["p_attr_0"],
+                )
+        else:
+            if diff_loss_tgt_1 > 0:
+                self.states["lambda_2"] = min(
+                    self.states["lambda_2"] + self.alpha,
+                    self.states["p_attr_1"],
+                )
+            elif diff_loss_tgt_1 < 0:
+                self.states["lambda_2"] = max(
+                    self.states["lambda_2"] - self.alpha, 0
+                )
+
+
+def assign_sensitive_group_labels(
+    groups: List[Tuple[Any, ...]],
+    target: int,
+) -> Dict[GroupLabel, Tuple[Any, ...]]:
+    """Parse sensitive group definitions to match canonical labels.
+
+    Parameters
+    ----------
+    groups:
+        List of sensitive group definitions, as a list of tuples.
+        These should be four tuples arising from the intersection
+        of binary labels (with any actual type).
+    target:
+        Value of the target label to treat as positive.
+
+    Returns
+    -------
+    labeled_groups:
+        Dict mapping canonical labels `"0_0", "0_1", "1_0", "1_1"`
+        to the input sensitive group definitions.
+
+    Raises
+    ------
+    ValueError
+        If 'groups' has improper length, values that do not appear
+        to be binary, or that do not match the specified 'target'.
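+
+    Example
+    -------
+    A quick sketch of the produced mapping:
+
+    >>> groups = [(0, "a"), (0, "b"), (1, "a"), (1, "b")]
+    >>> assign_sensitive_group_labels(groups, target=1)
+    {'0_0': (0, 'a'), '0_1': (0, 'b'), '1_0': (1, 'a'), '1_1': (1, 'b')}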
+    """
+    # Verify that groups can be identified as crossing two binary labels.
+    if len(groups) != 4:
+        raise ValueError(
+            "FairBatch requires counts over exactly 4 sensitive groups, "
+            "arising from a binary target label and a binary sensitive "
+            "attribute."
+        )
+    target_values = list({group[0] for group in groups})
+    s_attr_values = sorted(list({group[1] for group in groups}))
+    if not len(target_values) == len(s_attr_values) == 2:
+        raise ValueError(
+            "FairBatch requires sensitive groups to arise from a binary "
+            "target label and a binary sensitive attribute."
+        )
+    # Identify the positive and negative label values.
+    if target_values[0] == target:
+        postgt, negtgt = target_values
+    elif target_values[1] == target:
+        negtgt, postgt = target_values
+    else:
+        raise ValueError(
+            f"Received a target value of '{target}' that does not match any "
+            f"value in the sensitive group definitions: {target_values}."
+        )
+    # Match group definitions with canonical string labels.
+    return {
+        "0_0": (negtgt, s_attr_values[0]),
+        "0_1": (negtgt, s_attr_values[1]),
+        "1_0": (postgt, s_attr_values[0]),
+        "1_1": (postgt, s_attr_values[1]),
+    }
+
+
+def setup_fairbatch_controller(
+    f_type: str,
+    counts: Dict[Tuple[Any, ...], int],
+    target: int = 1,
+    alpha: float = 0.005,
+) -> FairbatchSamplingController:
+    """Instantiate a FairBatch sampling probabilities controller.
+
+    Parameters
+    ----------
+    f_type:
+        Type of group fairness to optimize for.
+    counts:
+        Dict mapping sensitive group definitions to their total
+        sample counts (across clients). These groups must arise
+        from the crossing of a binary target label and a binary
+        sensitive attribute.
+    target:
+        Target label to treat as positive.
+    alpha:
+        Alpha hyper-parameter, scaling the magnitude of sampling
+        probabilities' updates by the returned controller.
+
+    Returns
+    -------
+    controller:
+        FairBatch sampling probabilities controller matching inputs.
+
+    Raises
+    ------
+    KeyError
+        If `f_type` does not match any known or supported fairness type.
+    ValueError
+        If `counts` keys cannot be matched to canonical group labels.
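+
+    Example
+    -------
+    A minimal sketch, using assumed illustrative group counts:
+
+    >>> counts = {(0, 0): 50, (0, 1): 50, (1, 0): 60, (1, 1): 40}
+    >>> controller = setup_fairbatch_controller("equalized_odds", counts)
+    >>> controller.get_sampling_probas()
+    {(0, 0): 0.25, (0, 1): 0.25, (1, 0): 0.3, (1, 1): 0.2}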
+    """
+    controller_types = {
+        "demographic_parity": FairbatchDemographicParity,
+        "equality_of_opportunity": FairbatchEqualityOpportunity,
+        "equalized_odds": FairbatchEqualizedOdds,
+    }
+    controller_cls = controller_types.get(f_type, None)
+    if controller_cls is None:
+        raise KeyError(
+            "Unknown or unsupported fairness type parameter for FairBatch "
+            f"controller initialization: '{f_type}'. Supported values are "
+            f"{list(controller_types)}."
+        )
+    # Match groups to canonical labels and instantiate the controller.
+    groups = assign_sensitive_group_labels(groups=list(counts), target=target)
+    kwargs = {"target": target} if f_type == "equality_of_opportunity" else {}
+    return controller_cls(  # type: ignore[abstract]
+        groups=groups, counts=counts, alpha=alpha, **kwargs
+    )
diff --git a/declearn/fairness/fairbatch/_server.py b/declearn/fairness/fairbatch/_server.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e1c99af7144f23a029945e7118ca5168fdf1b0d
--- /dev/null
+++ b/declearn/fairness/fairbatch/_server.py
@@ -0,0 +1,190 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Server-side Fed-FairBatch/FedFB controller."""
+
+import warnings
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+
+from declearn.aggregator import Aggregator, SumAggregator
+from declearn.communication.api import NetworkServer
+from declearn.communication.utils import verify_client_messages_validity
+from declearn.fairness.api import FairnessControllerServer
+from declearn.fairness.fairbatch._fedfb import setup_fedfb_controller
+from declearn.fairness.fairbatch._messages import (
+    FairbatchOkay,
+    FairbatchSamplingProbas,
+)
+from declearn.fairness.fairbatch._sampling import setup_fairbatch_controller
+from declearn.messaging import FairnessSetupQuery
+from declearn.secagg.api import Decrypter
+
+
+__all__ = [
+    "FairbatchControllerServer",
+]
+
+
+class FairbatchControllerServer(FairnessControllerServer):
+    """Server-side controller to implement Fed-FairBatch or FedFB.
+
+    References
+    ----------
+    - [1]
+        Roh et al. (2020).
+        FairBatch: Batch Selection for Model Fairness.
+        https://arxiv.org/abs/2012.01696
+    - [2]
+        Zeng et al. (2022).
+        Improving Fairness via Federated Learning.
+        https://arxiv.org/abs/2110.15545
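+
+    Example
+    -------
+    A minimal sketch (parameter values are illustrative assumptions):
+
+    >>> controller = FairbatchControllerServer(
+    ...     f_type="demographic_parity", alpha=0.005, fedfb=False
+    ... )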
+    """
+
+    algorithm = "fed-fairbatch"
+
+    def __init__(
+        self,
+        f_type: str,
+        f_args: Optional[Dict[str, Any]] = None,
+        alpha: float = 0.005,
+        fedfb: bool = False,
+    ) -> None:
+        """Instantiate the server-side Fed-FairGrad controller.
+
+        Parameters
+        ----------
+        f_type:
+            Name of the fairness function to evaluate and optimize.
+        f_args:
+            Optional dict of keyword arguments to the fairness function.
+        alpha:
+            Hyper-parameter controlling the update rule for internal
+            states, and thereby for sampling probabilities.
+        fedfb:
+            Whether to use FedFB formulas rather than to stick
+            to those from the original FairBatch paper.
+        """
+        super().__init__(f_type=f_type, f_args=f_args)
+        # Choose whether to use FedFB or FairBatch update rules.
+        self._setup_function = (
+            setup_fedfb_controller if fedfb else setup_fairbatch_controller
+        )
+        # Set up a temporary controller that will be replaced at setup time.
+        self.sampling_controller = self._setup_function(
+            f_type=self.f_type,
+            counts={(0, 0): 1, (0, 1): 1, (1, 0): 1, (1, 1): 1},
+            target=self.f_args.get("target", 1),
+            alpha=alpha,
+        )
+
+    @property
+    def fedfb(self) -> bool:
+        """Whether this controller implements FedFB rather than Fed-FairBatch.
+
+        FedFB is a published adaptation of FairBatch to the federated
+        setting, that introduces changes to some FairBatch formulas.
+
+        Fed-FairBatch is a DecLearn-introduced variant of FedFB that
+        restores the original FairBatch formulas.
+        """
+        return self._setup_function is setup_fedfb_controller
+
+    def prepare_fairness_setup_query(
+        self,
+    ) -> FairnessSetupQuery:
+        query = super().prepare_fairness_setup_query()
+        query.params.update({"f_type": self.f_type, "f_args": self.f_args})
+        return query
+
+    async def finalize_fairness_setup(
+        self,
+        netwk: NetworkServer,
+        counts: List[int],
+        aggregator: Aggregator,
+    ) -> Aggregator:
+        # Set up the FairbatchSamplingController.
+        self.sampling_controller = self._setup_function(
+            f_type=self.f_type,
+            counts=dict(zip(self.groups, counts)),
+            target=self.f_args.get("target", 1),
+            alpha=self.sampling_controller.alpha,
+        )
+        # Send initial sampling probabilities to the clients.
+        await self._send_fairbatch_probas(netwk)
+        # Force the use of a SumAggregator.
+        if not isinstance(aggregator, SumAggregator):
+            warnings.warn(
+                "Overriding Aggregator choice to a 'SumAggregator', "
+                "due to the use of Fed-FairBatch.",
+                category=RuntimeWarning,
+            )
+            aggregator = SumAggregator()
+        return aggregator
+
+    async def _send_fairbatch_probas(
+        self,
+        netwk: NetworkServer,
+    ) -> None:
+        """Send FairBatch sensitive group sampling probabilities to clients.
+
+        Await for clients to ping back that things went fine on their side.
+        """
+        netwk.logger.info(
+            "Sending FairBatch sampling probabilities to clients."
+        )
+        probas = self.sampling_controller.get_sampling_probas()
+        p_list = [probas[group] for group in self.groups]
+        await netwk.broadcast_message(FairbatchSamplingProbas(p_list))
+        received = await netwk.wait_for_messages()
+        await verify_client_messages_validity(
+            netwk, received, expected=FairbatchOkay
+        )
+
+    async def finalize_fairness_round(
+        self,
+        round_i: int,
+        values: List[float],
+        netwk: NetworkServer,
+        secagg: Optional[Decrypter],
+    ) -> Dict[str, Union[float, np.ndarray]]:
+        # Unpack sum-aggregated group-wise accuracy and loss values.
+        accuracy = dict(zip(self.groups, values[: len(self.groups)]))
+        loss = dict(zip(self.groups, values[len(self.groups) :]))
+        # Update sampling probabilities and send them to clients.
+        self.sampling_controller.update_from_federated_losses(loss)
+        await self._send_fairbatch_probas(netwk)
+        # Package and return accuracy, loss and fairness metrics.
+        metrics = {
+            f"accuracy_{key}": val for key, val in accuracy.items()
+        }  # type: Dict[str, Union[float, np.ndarray]]
+        metrics.update({f"loss_{key}": val for key, val in loss.items()})
+        f_func = self.sampling_controller.f_func
+        fairness = f_func.compute_from_federated_group_accuracy(accuracy)
+        metrics.update(
+            {f"{self.f_type}_{key}": val for key, val in fairness.items()}
+        )
+        return metrics