Mentions légales du service

Skip to content
Snippets Groups Projects
Verified Commit f95450f4 authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Update Fed-FairGrad implementation to API changes.

parent e599b23a
No related branches found
No related tags found
1 merge request!69Enable Fairness-Aware Federated Learning
......@@ -50,17 +50,23 @@ Controllers
[declearn.fairness.fairgrad.FairgradControllerServer]:
Server-side controller to implement Fed-FairGrad.
Backend
-------
* [FairgradWeightsController]
[declearn.fairness.fairgrad.FairgradWeightsController]:
Controller to implement Faigrad optimization constraints.
Messages
--------
* [FairgradSetupQuery][declearn.fairness.fairgrad.FairgradSetupQuery]:
Message for server-emitted Fed-FairGrad setup queries.
* [FairgradOkay][declearn.fairness.fairgrad.FairgradOkay]:
Message for client-emitted signal that Fed-FairGrad update went fine.
* [FairgradWeights][declearn.fairness.fairgrad.FairgradWeights]:
Message for server-emitted (Fed-)FairGrad loss weights sharing.
"""
from ._messages import (
FairgradSetupQuery,
FairgradOkay,
FairgradWeights,
)
from ._client import FairgradControllerClient
from ._server import FairgradControllerServer
from ._server import FairgradControllerServer, FairgradWeightsController
......@@ -17,24 +17,19 @@
"""Client-side Fed-FairGrad controller."""
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union
import numpy as np
from declearn.communication.api import NetworkClient
from declearn.communication.utils import verify_server_message_validity
from declearn.fairness.api import (
FairnessAccuracy,
FairnessRoundQuery,
FairnessRoundReply,
FairnessControllerClient,
SecaggFairnessAccuracy,
from declearn.fairness.api import FairnessControllerClient
from declearn.fairness.core import (
FairnessDataset,
instantiate_fairness_function,
)
from declearn.fairness.core import FairnessAccuracyComputer, FairnessDataset
from declearn.fairness.fairgrad._messages import (
FairgradSetupQuery,
FairgradWeights,
)
from declearn.messaging import Error, SerializedMessage
from declearn.fairness.fairgrad._messages import FairgradOkay, FairgradWeights
from declearn.messaging import Error
from declearn.secagg.api import Encrypter
from declearn.training import TrainingManager
......@@ -46,98 +41,42 @@ __all__ = [
class FairgradControllerClient(FairnessControllerClient):
"""Client-side controller to implement Fed-FairGrad."""
setup_query_cls = FairgradSetupQuery
algorithm = "fedfairgrad"
def __init__(
self,
) -> None:
super().__init__()
self._accuracy_computer = (
None
) # type: Optional[FairnessAccuracyComputer]
async def finalize_fairness_setup(
self,
netwk: NetworkClient,
manager: TrainingManager,
secagg: Optional[Encrypter],
params: Dict[str, Any],
) -> TrainingManager:
assert isinstance(manager.train_data, FairnessDataset)
# Set up a controller to compute group-wise model accuracy.
self._accuracy_computer = FairnessAccuracyComputer(manager.train_data)
# Await initial loss weights from the server.
await self._update_fairgrad_weights(netwk, manager)
# Return the input TrainingManager.
return manager
async def fairness_round(
self,
netwk: NetworkClient,
manager: TrainingManager,
received: SerializedMessage[FairnessRoundQuery],
secagg: Optional[Encrypter],
f_type: str,
f_args: Dict[str, Any],
) -> None:
query = await verify_server_message_validity(
netwk, received, expected=FairnessRoundQuery
)
await self._compute_and_send_groupwise_accuracy(
netwk, manager, query, secagg
"""Instantiate the client-side fairness controller.
Parameters
----------
manager:
`TrainingManager` instance wrapping the model being trained
and its training dataset (that must be a `FairnessDataset`).
f_type:
Name of the type of group-fairness function being optimized.
f_args:
Keyword arguments to the group-fairness function.
"""
super().__init__(manager)
self.fairness_function = instantiate_fairness_function(
f_type=f_type, counts=self.computer.counts, **f_args
)
await self._update_fairgrad_weights(netwk, manager)
async def _compute_and_send_groupwise_accuracy(
async def finalize_fairness_setup(
self,
netwk: NetworkClient,
manager: TrainingManager,
query: FairnessRoundQuery,
secagg: Optional[Encrypter],
) -> None:
# Compute the count-weighted group-wise accuracy, handling exceptions.
try:
accuracy = self._compute_groupwise_accuracy(manager, query)
except Exception as exc: # pylint: disable=broad-except
manager.logger.error(
"Exception raised when computing group-wise accuracy: %s", exc
)
await netwk.send_message(Error(repr(exc)))
raise RuntimeError("Group accuracy computation failed.") from exc
# Send the computed metrics to the server, optionally encrypted.
manager.logger.info("Sending group-wise accuracy to the server.")
reply = FairnessAccuracy(accuracy)
if secagg is None:
await netwk.send_message(reply)
else:
await netwk.send_message(
SecaggFairnessAccuracy.from_cleartext_message(reply, secagg)
)
def _compute_groupwise_accuracy(
self,
manager: TrainingManager,
query: FairnessRoundQuery,
) -> List[float]:
"""Compute (counts-weighted) accuracy over sensitive groups."""
assert self._accuracy_computer is not None
# Compute group-wise accuracy scores.
accuracy = self._accuracy_computer.compute_groupwise_accuracy(
model=manager.model,
batch_size=query.batch_size,
n_batch=query.n_batch,
thresh=query.thresh,
)
# Multiply these scores by sample counts.
accuracy = {
key: val * self._accuracy_computer.counts[key]
for key, val in accuracy.items()
}
# Return shareable group-wise values, ordered and filled out.
return [accuracy.get(group, 0.0) for group in self.groups]
# Await initial loss weights from the server.
await self._update_fairgrad_weights(netwk)
async def _update_fairgrad_weights(
self,
netwk: NetworkClient,
manager: TrainingManager,
) -> None:
"""Run a FairGrad-specific routine to update sensitive group weights.
......@@ -158,17 +97,63 @@ class FairgradControllerClient(FairnessControllerClient):
weights = dict(zip(self.groups, message.weights))
# Set the received weights, handling and propagating exceptions if any.
try:
assert isinstance(manager.train_data, FairnessDataset)
manager.train_data.set_sensitive_group_weights(
weights,
adjust_by_counts=True,
assert isinstance(self.manager.train_data, FairnessDataset)
self.manager.train_data.set_sensitive_group_weights(
weights, adjust_by_counts=True
)
except (AssertionError, KeyError, TypeError) as exc:
manager.logger.error(
except Exception as exc:
self.manager.logger.error(
"Exception encountered when setting FairGrad weights: %s", exc
)
await netwk.send_message(Error(repr(exc)))
raise RuntimeError("FairGrad weights update failed.") from exc
# If things went well, ping the server back to indicate so.
manager.logger.info("Updated FairGrad weights.")
await netwk.send_message(FairnessRoundReply())
self.manager.logger.info("Updated FairGrad weights.")
await netwk.send_message(FairgradOkay())
def compute_fairness_measures(
self,
batch_size: int,
n_batch: Optional[int] = None,
thresh: Optional[float] = None,
) -> List[float]:
# Compute group-wise accuracy scores.
accuracy = self.computer.compute_groupwise_accuracy(
model=self.manager.model,
batch_size=batch_size,
n_batch=n_batch,
thresh=thresh,
)
# Multiply these scores by sample counts.
accuracy = {
key: val * self.computer.counts[key]
for key, val in accuracy.items()
}
# Return shareable group-wise values, ordered and filled out.
return [accuracy.get(group, 0.0) for group in self.groups]
async def finalize_fairness_round(
self,
netwk: NetworkClient,
values: List[float],
secagg: Optional[Encrypter],
) -> Dict[str, Union[float, np.ndarray]]:
# Await updated loss weights from the server.
await self._update_fairgrad_weights(netwk)
# Recover raw accuracy scores for groups with local samples.
accuracy = {
key: val / self.computer.counts[key]
for key, val in zip(self.groups, values)
if key in self.computer.counts
}
# Compute local fairness measures.
fairness = self.fairness_function.compute_from_group_accuracy(accuracy)
f_type = self.fairness_function.f_type
# Package and return accuracy and fairness metrics.
metrics = {
f"accuracy_{key}": val for key, val in accuracy.items()
} # type: Dict[str, Union[float, np.ndarray]]
metrics.update(
{f"{f_type}_{key}": val for key, val in fairness.items()}
)
return metrics
......@@ -21,25 +21,20 @@ import dataclasses
from typing import List
from declearn.fairness.api import FairnessSetupQuery
from declearn.messaging import Message
__all__ = [
"FairgradSetupQuery",
"FairgradOkay",
"FairgradWeights",
]
@dataclasses.dataclass
class FairgradSetupQuery(FairnessSetupQuery):
"""Message for server-emitted Fed-FairGrad setup queries.
class FairgradOkay(Message):
"""Message for client-emitted signal that Fed-FairGrad update went fine."""
This message is empty and merely signifies that Fed-FairGrad
should be set up by the client.
"""
typekey = "fairgrad-setup"
typekey = "fairgrad-okay"
@dataclasses.dataclass
......
......@@ -18,32 +18,28 @@
"""Server-side Fed-FairGrad controller."""
import warnings
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
from declearn.aggregator import Aggregator, SumAggregator
from declearn.communication.api import NetworkServer
from declearn.communication.utils import verify_client_messages_validity
from declearn.fairness.api import (
FairnessAccuracy,
FairnessRoundQuery,
FairnessRoundReply,
FairnessControllerServer,
FairnessSetupQuery,
SecaggFairnessAccuracy,
)
from declearn.fairness.api import FairnessControllerServer
from declearn.fairness.core import instantiate_fairness_function
from declearn.fairness.fairgrad._messages import (
FairgradSetupQuery,
FairgradWeights,
)
from declearn.fairness.fairgrad._messages import FairgradOkay, FairgradWeights
from declearn.messaging import FairnessSetupQuery
from declearn.secagg.api import Decrypter
from declearn.secagg.messaging import aggregate_secagg_messages
__all__ = [
"FairgradControllerServer",
"FairgradWeightsController",
]
class FairgradWeightsController:
"""Fairness controller to implement Faigrad optimization constraints."""
"""Controller to implement Faigrad optimization constraints."""
# attrs serve readability; pylint: disable=too-many-instance-attributes
......@@ -157,6 +153,8 @@ class FairgradWeightsController:
class FairgradControllerServer(FairnessControllerServer):
"""Server-side controller to implement Fed-FairGrad."""
algorithm = "fedfairgrad"
def __init__(
self,
f_type: str,
......@@ -182,16 +180,17 @@ class FairgradControllerServer(FairnessControllerServer):
This may be set to 0.0 to try and enforce absolute fairness.
"""
super().__init__(f_type=f_type, f_args=f_args)
self.weights_controller = (
None
) # type: Optional[FairgradWeightsController]
self._eta = eta
self._eps = eps
# Set up a temporary controller that will be replaced at setup time.
self.weights_controller = FairgradWeightsController(
counts={}, f_type="accuracy_parity", eta=eta, eps=eps
)
def prepare_fairness_setup_query(
self,
) -> FairnessSetupQuery:
return FairgradSetupQuery()
query = super().prepare_fairness_setup_query()
query.params.update({"f_type": self.f_type, "f_args": self.f_args})
return query
async def finalize_fairness_setup(
self,
......@@ -203,8 +202,8 @@ class FairgradControllerServer(FairnessControllerServer):
self.weights_controller = FairgradWeightsController(
counts=dict(zip(self.groups, counts)),
f_type=self.f_type,
eta=self._eta,
eps=self._eps,
eta=self.weights_controller.eta,
eps=self.weights_controller.eps,
**self.f_args,
)
# Send initial loss weights to the clients.
......@@ -228,50 +227,31 @@ class FairgradControllerServer(FairnessControllerServer):
Await for clients to ping back that things went fine on their side.
"""
netwk.logger.info("Sending FairGrad weights to clients.")
assert self.weights_controller is not None
weights = self.weights_controller.get_current_weights(norm_nk=True)
await netwk.broadcast_message(FairgradWeights(weights=weights))
received = await netwk.wait_for_messages()
await verify_client_messages_validity(
netwk, received, expected=FairnessRoundReply
netwk, received, expected=FairgradOkay
)
async def fairness_round(
async def finalize_fairness_round(
self,
round_i: int,
values: List[float],
netwk: NetworkServer,
secagg: Optional[Decrypter],
) -> None:
assert self.weights_controller is not None
# Send a query to clients and await group-wise accuracy metrics.
await netwk.broadcast_message(
FairnessRoundQuery() # TODO: receive a config and use it
)
received = await netwk.wait_for_messages()
# When SecAgg is not set, expect and aggregate cleartext values.
if secagg is None:
replies = await verify_client_messages_validity(
netwk, received, expected=FairnessAccuracy
)
accuracy = self._aggregate_cleartext_accuracy(replies)
# When SecAgg is set, expect and secure-aggregate encrypted values.
else:
sec_rep = await verify_client_messages_validity(
netwk, received, expected=SecaggFairnessAccuracy
)
accuracy = aggregate_secagg_messages(sec_rep, secagg).values
# Compute global fairness and update FairGrad loss weights.
self.weights_controller.update_weights_based_on_accuracy(
accuracy=dict(zip(self.groups, accuracy))
)
# Send back the updated weights to the clients.
) -> Dict[str, Union[float, np.ndarray]]:
# Unpack group-wise accuracy metrics and update loss weights.
accuracy = dict(zip(self.groups, values))
self.weights_controller.update_weights_based_on_accuracy(accuracy)
# Send the updated weights to clients.
await self._send_fairgrad_weights(netwk)
def _aggregate_cleartext_accuracy(
self,
messages: Dict[str, FairnessAccuracy],
) -> List[float]:
"""Sum group-wise accuracy metrics received from clients."""
accuracy = np.zeros(len(self.groups), dtype="float64")
for message in messages.values():
accuracy += np.asarray(message.values, dtype="float64")
return accuracy.tolist()
# Package and return accuracy and fairness metrics.
metrics = {
f"accuracy_{key}": val for key, val in accuracy.items()
} # type: Dict[str, Union[float, np.ndarray]]
fairness = self.weights_controller.get_current_fairness()
metrics.update(
{f"{self.f_type}_{key}": val for key, val in fairness.items()}
)
return metrics
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment