diff --git a/declearn/fairness/fairbatch/__init__.py b/declearn/fairness/fairbatch/__init__.py index 63ba46d24dd6a1eb1287bf242055dfaf5bb34cda..56dbad116edfe73753e6e843a86cb52a261cef12 100644 --- a/declearn/fairness/fairbatch/__init__.py +++ b/declearn/fairness/fairbatch/__init__.py @@ -53,6 +53,8 @@ Controllers Backend ------- +* [FairbatchDataset][declearn.fairness.fairbatch.FairbatchDataset]: + FairBatch-specific FairnessDataset subclass and wrapper. * [FairbatchSamplingController] [declearn.fairness.fairbatch.FairbatchSamplingController]: ABC to compute and update Fairbatch sampling probabilities. @@ -79,5 +81,6 @@ from ._sampling import ( setup_fairbatch_controller, ) from ._fedfb import setup_fedfb_controller +from ._dataset import FairbatchDataset from ._client import FairbatchControllerClient from ._server import FairbatchControllerServer diff --git a/declearn/fairness/fairbatch/_client.py b/declearn/fairness/fairbatch/_client.py index 264e639bba0c1fa5fd67df75f2adbc45fbe01435..dd91840b5f2bca28f16fa4346aa6e4fa2795d569 100644 --- a/declearn/fairness/fairbatch/_client.py +++ b/declearn/fairness/fairbatch/_client.py @@ -66,6 +66,8 @@ class FairbatchControllerClient(FairnessControllerClient): # Force the use of a SumAggregator. if not isinstance(self.manager.aggrg, SumAggregator): self.manager.aggrg = SumAggregator() + # Receive and assign initial sampling probabilities. + await self._update_fairbatch_sampling_probas(netwk) async def _update_fairbatch_sampling_probas( self, diff --git a/declearn/fairness/fairbatch/_dataset.py b/declearn/fairness/fairbatch/_dataset.py index 83ffd70d6881c2c0771e1dc69ff9d44b85f807d3..2949273c0b60394f132faf672e0bb60813e024b1 100644 --- a/declearn/fairness/fairbatch/_dataset.py +++ b/declearn/fairness/fairbatch/_dataset.py @@ -47,7 +47,7 @@ class FairbatchDataset(FairnessDataset): """ self.base = base # Assign a dictionary with sampling probability for each group. - self.groups = self.get_sensitive_group_definitions() + self.groups = self.base.get_sensitive_group_definitions() self._counts = self.base.get_sensitive_group_counts() self._sampling_probas = { group: 1.0 / len(self.groups) for group in self.groups diff --git a/declearn/fairness/fairbatch/_server.py b/declearn/fairness/fairbatch/_server.py index 89719fa8f70456f3e9064a2cd0ef0829762a73d8..8f0592af47583f26234edd3a79b94d6e6d2a8bec 100644 --- a/declearn/fairness/fairbatch/_server.py +++ b/declearn/fairness/fairbatch/_server.py @@ -55,7 +55,7 @@ class FairbatchControllerServer(FairnessControllerServer): https://arxiv.org/abs/2110.15545 """ - algorithm = "fed-fairbatch" + algorithm = "fedfairbatch" def __init__( self,