From 1e88aa456ec945fbe268ceba80a2234260a0777f Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Fri, 5 Jul 2024 16:18:34 +0200
Subject: [PATCH] Fix some bugs in FairBatch controllers.

---
 declearn/fairness/fairbatch/__init__.py | 3 +++
 declearn/fairness/fairbatch/_client.py  | 2 ++
 declearn/fairness/fairbatch/_dataset.py | 2 +-
 declearn/fairness/fairbatch/_server.py  | 2 +-
 4 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/declearn/fairness/fairbatch/__init__.py b/declearn/fairness/fairbatch/__init__.py
index 63ba46d..56dbad1 100644
--- a/declearn/fairness/fairbatch/__init__.py
+++ b/declearn/fairness/fairbatch/__init__.py
@@ -53,6 +53,8 @@ Controllers
 
 Backend
 -------
+* [FairbatchDataset][declearn.fairness.fairbatch.FairbatchDataset]:
+    FairBatch-specific FairnessDataset subclass and wrapper.
 * [FairbatchSamplingController]
 [declearn.fairness.fairbatch.FairbatchSamplingController]:
     ABC to compute and update Fairbatch sampling probabilities.
@@ -79,5 +81,6 @@ from ._sampling import (
     setup_fairbatch_controller,
 )
 from ._fedfb import setup_fedfb_controller
+from ._dataset import FairbatchDataset
 from ._client import FairbatchControllerClient
 from ._server import FairbatchControllerServer
diff --git a/declearn/fairness/fairbatch/_client.py b/declearn/fairness/fairbatch/_client.py
index 264e639..dd91840 100644
--- a/declearn/fairness/fairbatch/_client.py
+++ b/declearn/fairness/fairbatch/_client.py
@@ -66,6 +66,8 @@ class FairbatchControllerClient(FairnessControllerClient):
         # Force the use of a SumAggregator.
         if not isinstance(self.manager.aggrg, SumAggregator):
             self.manager.aggrg = SumAggregator()
+        # Receive and assign initial sampling probabilities.
+        await self._update_fairbatch_sampling_probas(netwk)
 
     async def _update_fairbatch_sampling_probas(
         self,
diff --git a/declearn/fairness/fairbatch/_dataset.py b/declearn/fairness/fairbatch/_dataset.py
index 83ffd70..2949273 100644
--- a/declearn/fairness/fairbatch/_dataset.py
+++ b/declearn/fairness/fairbatch/_dataset.py
@@ -47,7 +47,7 @@ class FairbatchDataset(FairnessDataset):
         """
         self.base = base
         # Assign a dictionary with sampling probability for each group.
-        self.groups = self.get_sensitive_group_definitions()
+        self.groups = self.base.get_sensitive_group_definitions()
         self._counts = self.base.get_sensitive_group_counts()
         self._sampling_probas = {
             group: 1.0 / len(self.groups) for group in self.groups
diff --git a/declearn/fairness/fairbatch/_server.py b/declearn/fairness/fairbatch/_server.py
index 89719fa..8f0592a 100644
--- a/declearn/fairness/fairbatch/_server.py
+++ b/declearn/fairness/fairbatch/_server.py
@@ -55,7 +55,7 @@ class FairbatchControllerServer(FairnessControllerServer):
         https://arxiv.org/abs/2110.15545
     """
 
-    algorithm = "fed-fairbatch"
+    algorithm = "fedfairbatch"
 
     def __init__(
         self,
-- 
GitLab