diff --git a/declearn/fairness/fairbatch/_client.py b/declearn/fairness/fairbatch/_client.py
index 378aafd925a34d3c574014a0ec55a2e7e3419695..07b29f8a3d9b5d3b1003445924484c6729fd34cb 100644
--- a/declearn/fairness/fairbatch/_client.py
+++ b/declearn/fairness/fairbatch/_client.py
@@ -15,12 +15,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Client-side Fed-FairBatch controller."""
+"""Client-side Fed-FairBatch/FedFB controller."""
 
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 
+from declearn.aggregator import SumAggregator
 from declearn.communication.api import NetworkClient
 from declearn.communication.utils import verify_server_message_validity
 from declearn.fairness.api import FairnessControllerClient
@@ -77,7 +78,9 @@ class FairbatchControllerClient(FairnessControllerClient):
         netwk: NetworkClient,
         secagg: Optional[Encrypter],
     ) -> None:
-        pass  # no action required beyond sharing group definitions and counts
+        # Force the use of a SumAggregator.
+        if not isinstance(self.manager.aggrg, SumAggregator):
+            self.manager.aggrg = SumAggregator()
 
     async def _update_fairbatch_sampling_probas(
         self,
diff --git a/declearn/fairness/fairgrad/_client.py b/declearn/fairness/fairgrad/_client.py
index 11308d07635e3702ed9ca6fb5b26719debc99342..6d9eab6169f015a476f47943ae877e34e3f2ffc0 100644
--- a/declearn/fairness/fairgrad/_client.py
+++ b/declearn/fairness/fairgrad/_client.py
@@ -21,6 +21,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 
+from declearn.aggregator import SumAggregator
 from declearn.communication.api import NetworkClient
 from declearn.communication.utils import verify_server_message_validity
 from declearn.fairness.api import FairnessControllerClient
@@ -71,6 +72,9 @@ class FairgradControllerClient(FairnessControllerClient):
         netwk: NetworkClient,
         secagg: Optional[Encrypter],
     ) -> None:
+        # Force the use of a SumAggregator.
+        if not isinstance(self.manager.aggrg, SumAggregator):
+            self.manager.aggrg = SumAggregator()
         # Await initial loss weights from the server.
         await self._update_fairgrad_weights(netwk)