From 1e88aa456ec945fbe268ceba80a2234260a0777f Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Fri, 5 Jul 2024 16:18:34 +0200 Subject: [PATCH] Fix some bugs in FairBatch controllers. --- declearn/fairness/fairbatch/__init__.py | 3 +++ declearn/fairness/fairbatch/_client.py | 2 ++ declearn/fairness/fairbatch/_dataset.py | 2 +- declearn/fairness/fairbatch/_server.py | 2 +- 4 files changed, 7 insertions(+), 2 deletions(-) diff --git a/declearn/fairness/fairbatch/__init__.py b/declearn/fairness/fairbatch/__init__.py index 63ba46d..56dbad1 100644 --- a/declearn/fairness/fairbatch/__init__.py +++ b/declearn/fairness/fairbatch/__init__.py @@ -53,6 +53,8 @@ Controllers Backend ------- +* [FairbatchDataset][declearn.fairness.fairbatch.FairbatchDataset]: + FairBatch-specific FairnessDataset subclass and wrapper. * [FairbatchSamplingController] [declearn.fairness.fairbatch.FairbatchSamplingController]: ABC to compute and update Fairbatch sampling probabilities. @@ -79,5 +81,6 @@ from ._sampling import ( setup_fairbatch_controller, ) from ._fedfb import setup_fedfb_controller +from ._dataset import FairbatchDataset from ._client import FairbatchControllerClient from ._server import FairbatchControllerServer diff --git a/declearn/fairness/fairbatch/_client.py b/declearn/fairness/fairbatch/_client.py index 264e639..dd91840 100644 --- a/declearn/fairness/fairbatch/_client.py +++ b/declearn/fairness/fairbatch/_client.py @@ -66,6 +66,8 @@ class FairbatchControllerClient(FairnessControllerClient): # Force the use of a SumAggregator. if not isinstance(self.manager.aggrg, SumAggregator): self.manager.aggrg = SumAggregator() + # Receive and assign initial sampling probabilities. + await self._update_fairbatch_sampling_probas(netwk) async def _update_fairbatch_sampling_probas( self, diff --git a/declearn/fairness/fairbatch/_dataset.py b/declearn/fairness/fairbatch/_dataset.py index 83ffd70..2949273 100644 --- a/declearn/fairness/fairbatch/_dataset.py +++ b/declearn/fairness/fairbatch/_dataset.py @@ -47,7 +47,7 @@ class FairbatchDataset(FairnessDataset): """ self.base = base # Assign a dictionary with sampling probability for each group. - self.groups = self.get_sensitive_group_definitions() + self.groups = self.base.get_sensitive_group_definitions() self._counts = self.base.get_sensitive_group_counts() self._sampling_probas = { group: 1.0 / len(self.groups) for group in self.groups diff --git a/declearn/fairness/fairbatch/_server.py b/declearn/fairness/fairbatch/_server.py index 89719fa..8f0592a 100644 --- a/declearn/fairness/fairbatch/_server.py +++ b/declearn/fairness/fairbatch/_server.py @@ -55,7 +55,7 @@ class FairbatchControllerServer(FairnessControllerServer): https://arxiv.org/abs/2110.15545 """ - algorithm = "fed-fairbatch" + algorithm = "fedfairbatch" def __init__( self, -- GitLab