diff --git a/declearn/fairness/fairbatch/_client.py b/declearn/fairness/fairbatch/_client.py index 378aafd925a34d3c574014a0ec55a2e7e3419695..07b29f8a3d9b5d3b1003445924484c6729fd34cb 100644 --- a/declearn/fairness/fairbatch/_client.py +++ b/declearn/fairness/fairbatch/_client.py @@ -15,12 +15,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Client-side Fed-FairBatch controller.""" +"""Client-side Fed-FairBatch/FedFB controller.""" from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np +from declearn.aggregator import SumAggregator from declearn.communication.api import NetworkClient from declearn.communication.utils import verify_server_message_validity from declearn.fairness.api import FairnessControllerClient @@ -77,7 +78,9 @@ class FairbatchControllerClient(FairnessControllerClient): netwk: NetworkClient, secagg: Optional[Encrypter], ) -> None: - pass # no action required beyond sharing group definitions and counts + # Force the use of a SumAggregator. + if not isinstance(self.manager.aggrg, SumAggregator): + self.manager.aggrg = SumAggregator() async def _update_fairbatch_sampling_probas( self, diff --git a/declearn/fairness/fairgrad/_client.py b/declearn/fairness/fairgrad/_client.py index 11308d07635e3702ed9ca6fb5b26719debc99342..6d9eab6169f015a476f47943ae877e34e3f2ffc0 100644 --- a/declearn/fairness/fairgrad/_client.py +++ b/declearn/fairness/fairgrad/_client.py @@ -21,6 +21,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np +from declearn.aggregator import SumAggregator from declearn.communication.api import NetworkClient from declearn.communication.utils import verify_server_message_validity from declearn.fairness.api import FairnessControllerClient @@ -71,6 +72,9 @@ class FairgradControllerClient(FairnessControllerClient): netwk: NetworkClient, secagg: Optional[Encrypter], ) -> None: + # Force the use of a SumAggregator. + if not isinstance(self.manager.aggrg, SumAggregator): + self.manager.aggrg = SumAggregator() # Await initial loss weights from the server. await self._update_fairgrad_weights(netwk)