diff --git a/declearn/fairness/fairbatch/__init__.py b/declearn/fairness/fairbatch/__init__.py
index 63ba46d24dd6a1eb1287bf242055dfaf5bb34cda..56dbad116edfe73753e6e843a86cb52a261cef12 100644
--- a/declearn/fairness/fairbatch/__init__.py
+++ b/declearn/fairness/fairbatch/__init__.py
@@ -53,6 +53,8 @@ Controllers
 
 Backend
 -------
+* [FairbatchDataset][declearn.fairness.fairbatch.FairbatchDataset]:
+    FairBatch-specific FairnessDataset subclass and wrapper.
 * [FairbatchSamplingController]
 [declearn.fairness.fairbatch.FairbatchSamplingController]:
     ABC to compute and update Fairbatch sampling probabilities.
@@ -79,5 +81,6 @@ from ._sampling import (
     setup_fairbatch_controller,
 )
 from ._fedfb import setup_fedfb_controller
+from ._dataset import FairbatchDataset
 from ._client import FairbatchControllerClient
 from ._server import FairbatchControllerServer
diff --git a/declearn/fairness/fairbatch/_client.py b/declearn/fairness/fairbatch/_client.py
index 264e639bba0c1fa5fd67df75f2adbc45fbe01435..dd91840b5f2bca28f16fa4346aa6e4fa2795d569 100644
--- a/declearn/fairness/fairbatch/_client.py
+++ b/declearn/fairness/fairbatch/_client.py
@@ -66,6 +66,8 @@ class FairbatchControllerClient(FairnessControllerClient):
         # Force the use of a SumAggregator.
         if not isinstance(self.manager.aggrg, SumAggregator):
             self.manager.aggrg = SumAggregator()
+        # Receive and assign initial sampling probabilities.
+        await self._update_fairbatch_sampling_probas(netwk)
 
     async def _update_fairbatch_sampling_probas(
         self,
diff --git a/declearn/fairness/fairbatch/_dataset.py b/declearn/fairness/fairbatch/_dataset.py
index 83ffd70d6881c2c0771e1dc69ff9d44b85f807d3..2949273c0b60394f132faf672e0bb60813e024b1 100644
--- a/declearn/fairness/fairbatch/_dataset.py
+++ b/declearn/fairness/fairbatch/_dataset.py
@@ -47,7 +47,7 @@ class FairbatchDataset(FairnessDataset):
         """
         self.base = base
         # Assign a dictionary with sampling probability for each group.
-        self.groups = self.get_sensitive_group_definitions()
+        self.groups = self.base.get_sensitive_group_definitions()
         self._counts = self.base.get_sensitive_group_counts()
         self._sampling_probas = {
             group: 1.0 / len(self.groups) for group in self.groups
diff --git a/declearn/fairness/fairbatch/_server.py b/declearn/fairness/fairbatch/_server.py
index 89719fa8f70456f3e9064a2cd0ef0829762a73d8..8f0592af47583f26234edd3a79b94d6e6d2a8bec 100644
--- a/declearn/fairness/fairbatch/_server.py
+++ b/declearn/fairness/fairbatch/_server.py
@@ -55,7 +55,7 @@ class FairbatchControllerServer(FairnessControllerServer):
         https://arxiv.org/abs/2110.15545
     """
 
-    algorithm = "fed-fairbatch"
+    algorithm = "fedfairbatch"
 
     def __init__(
         self,