Mentions légales du service

Skip to content
Snippets Groups Projects
Verified Commit 0450ee40 authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Fix client-side aggregator choice for FairBatch and FairGrad.

parent a48ba2dd
No related branches found
No related tags found
1 merge request!69Enable Fairness-Aware Federated Learning
...@@ -15,12 +15,13 @@ ...@@ -15,12 +15,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Client-side Fed-FairBatch controller.""" """Client-side Fed-FairBatch/FedFB controller."""
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
from declearn.aggregator import SumAggregator
from declearn.communication.api import NetworkClient from declearn.communication.api import NetworkClient
from declearn.communication.utils import verify_server_message_validity from declearn.communication.utils import verify_server_message_validity
from declearn.fairness.api import FairnessControllerClient from declearn.fairness.api import FairnessControllerClient
...@@ -77,7 +78,9 @@ class FairbatchControllerClient(FairnessControllerClient): ...@@ -77,7 +78,9 @@ class FairbatchControllerClient(FairnessControllerClient):
netwk: NetworkClient, netwk: NetworkClient,
secagg: Optional[Encrypter], secagg: Optional[Encrypter],
) -> None: ) -> None:
pass # no action required beyond sharing group definitions and counts # Force the use of a SumAggregator.
if not isinstance(self.manager.aggrg, SumAggregator):
self.manager.aggrg = SumAggregator()
async def _update_fairbatch_sampling_probas( async def _update_fairbatch_sampling_probas(
self, self,
......
...@@ -21,6 +21,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union ...@@ -21,6 +21,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
from declearn.aggregator import SumAggregator
from declearn.communication.api import NetworkClient from declearn.communication.api import NetworkClient
from declearn.communication.utils import verify_server_message_validity from declearn.communication.utils import verify_server_message_validity
from declearn.fairness.api import FairnessControllerClient from declearn.fairness.api import FairnessControllerClient
...@@ -71,6 +72,9 @@ class FairgradControllerClient(FairnessControllerClient): ...@@ -71,6 +72,9 @@ class FairgradControllerClient(FairnessControllerClient):
netwk: NetworkClient, netwk: NetworkClient,
secagg: Optional[Encrypter], secagg: Optional[Encrypter],
) -> None: ) -> None:
# Force the use of a SumAggregator.
if not isinstance(self.manager.aggrg, SumAggregator):
self.manager.aggrg = SumAggregator()
# Await initial loss weights from the server. # Await initial loss weights from the server.
await self._update_fairgrad_weights(netwk) await self._update_fairgrad_weights(netwk)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment