diff --git a/declearn/fairness/fairgrad/__init__.py b/declearn/fairness/fairgrad/__init__.py index 8eb00d8d2bc233bfa326eb5507ec2e14c621944f..532d3b6a1849866a1b43b79b8e17a9075a8e4a92 100644 --- a/declearn/fairness/fairgrad/__init__.py +++ b/declearn/fairness/fairgrad/__init__.py @@ -50,17 +50,23 @@ Controllers [declearn.fairness.fairgrad.FairgradControllerServer]: Server-side controller to implement Fed-FairGrad. +Backend +------- +* [FairgradWeightsController] +[declearn.fairness.fairgrad.FairgradWeightsController]: + Controller to implement Faigrad optimization constraints. + Messages -------- -* [FairgradSetupQuery][declearn.fairness.fairgrad.FairgradSetupQuery]: - Message for server-emitted Fed-FairGrad setup queries. +* [FairgradOkay][declearn.fairness.fairgrad.FairgradOkay]: + Message for client-emitted signal that Fed-FairGrad update went fine. * [FairgradWeights][declearn.fairness.fairgrad.FairgradWeights]: Message for server-emitted (Fed-)FairGrad loss weights sharing. """ from ._messages import ( - FairgradSetupQuery, + FairgradOkay, FairgradWeights, ) from ._client import FairgradControllerClient -from ._server import FairgradControllerServer +from ._server import FairgradControllerServer, FairgradWeightsController diff --git a/declearn/fairness/fairgrad/_client.py b/declearn/fairness/fairgrad/_client.py index 1fa63ac4417a90152f0c918ced1b12437ccade9d..3e21b272f2cdb9b540a71e6314df3fa66833af0c 100644 --- a/declearn/fairness/fairgrad/_client.py +++ b/declearn/fairness/fairgrad/_client.py @@ -17,24 +17,19 @@ """Client-side Fed-FairGrad controller.""" -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union +import numpy as np from declearn.communication.api import NetworkClient from declearn.communication.utils import verify_server_message_validity -from declearn.fairness.api import ( - FairnessAccuracy, - FairnessRoundQuery, - FairnessRoundReply, - FairnessControllerClient, - SecaggFairnessAccuracy, +from declearn.fairness.api import FairnessControllerClient +from declearn.fairness.core import ( + FairnessDataset, + instantiate_fairness_function, ) -from declearn.fairness.core import FairnessAccuracyComputer, FairnessDataset -from declearn.fairness.fairgrad._messages import ( - FairgradSetupQuery, - FairgradWeights, -) -from declearn.messaging import Error, SerializedMessage +from declearn.fairness.fairgrad._messages import FairgradOkay, FairgradWeights +from declearn.messaging import Error from declearn.secagg.api import Encrypter from declearn.training import TrainingManager @@ -46,98 +41,42 @@ __all__ = [ class FairgradControllerClient(FairnessControllerClient): """Client-side controller to implement Fed-FairGrad.""" - setup_query_cls = FairgradSetupQuery + algorithm = "fedfairgrad" def __init__( self, - ) -> None: - super().__init__() - self._accuracy_computer = ( - None - ) # type: Optional[FairnessAccuracyComputer] - - async def finalize_fairness_setup( - self, - netwk: NetworkClient, - manager: TrainingManager, - secagg: Optional[Encrypter], - params: Dict[str, Any], - ) -> TrainingManager: - assert isinstance(manager.train_data, FairnessDataset) - # Set up a controller to compute group-wise model accuracy. - self._accuracy_computer = FairnessAccuracyComputer(manager.train_data) - # Await initial loss weights from the server. - await self._update_fairgrad_weights(netwk, manager) - # Return the input TrainingManager. - return manager - - async def fairness_round( - self, - netwk: NetworkClient, manager: TrainingManager, - received: SerializedMessage[FairnessRoundQuery], - secagg: Optional[Encrypter], + f_type: str, + f_args: Dict[str, Any], ) -> None: - query = await verify_server_message_validity( - netwk, received, expected=FairnessRoundQuery - ) - await self._compute_and_send_groupwise_accuracy( - netwk, manager, query, secagg + """Instantiate the client-side fairness controller. + + Parameters + ---------- + manager: + `TrainingManager` instance wrapping the model being trained + and its training dataset (that must be a `FairnessDataset`). + f_type: + Name of the type of group-fairness function being optimized. + f_args: + Keyword arguments to the group-fairness function. + """ + super().__init__(manager) + self.fairness_function = instantiate_fairness_function( + f_type=f_type, counts=self.computer.counts, **f_args ) - await self._update_fairgrad_weights(netwk, manager) - async def _compute_and_send_groupwise_accuracy( + async def finalize_fairness_setup( self, netwk: NetworkClient, - manager: TrainingManager, - query: FairnessRoundQuery, secagg: Optional[Encrypter], ) -> None: - # Compute the count-weighted group-wise accuracy, handling exceptions. - try: - accuracy = self._compute_groupwise_accuracy(manager, query) - except Exception as exc: # pylint: disable=broad-except - manager.logger.error( - "Exception raised when computing group-wise accuracy: %s", exc - ) - await netwk.send_message(Error(repr(exc))) - raise RuntimeError("Group accuracy computation failed.") from exc - # Send the computed metrics to the server, optionally encrypted. - manager.logger.info("Sending group-wise accuracy to the server.") - reply = FairnessAccuracy(accuracy) - if secagg is None: - await netwk.send_message(reply) - else: - await netwk.send_message( - SecaggFairnessAccuracy.from_cleartext_message(reply, secagg) - ) - - def _compute_groupwise_accuracy( - self, - manager: TrainingManager, - query: FairnessRoundQuery, - ) -> List[float]: - """Compute (counts-weighted) accuracy over sensitive groups.""" - assert self._accuracy_computer is not None - # Compute group-wise accuracy scores. - accuracy = self._accuracy_computer.compute_groupwise_accuracy( - model=manager.model, - batch_size=query.batch_size, - n_batch=query.n_batch, - thresh=query.thresh, - ) - # Multiply these scores by sample counts. - accuracy = { - key: val * self._accuracy_computer.counts[key] - for key, val in accuracy.items() - } - # Return shareable group-wise values, ordered and filled out. - return [accuracy.get(group, 0.0) for group in self.groups] + # Await initial loss weights from the server. + await self._update_fairgrad_weights(netwk) async def _update_fairgrad_weights( self, netwk: NetworkClient, - manager: TrainingManager, ) -> None: """Run a FairGrad-specific routine to update sensitive group weights. @@ -158,17 +97,63 @@ class FairgradControllerClient(FairnessControllerClient): weights = dict(zip(self.groups, message.weights)) # Set the received weights, handling and propagating exceptions if any. try: - assert isinstance(manager.train_data, FairnessDataset) - manager.train_data.set_sensitive_group_weights( - weights, - adjust_by_counts=True, + assert isinstance(self.manager.train_data, FairnessDataset) + self.manager.train_data.set_sensitive_group_weights( + weights, adjust_by_counts=True ) - except (AssertionError, KeyError, TypeError) as exc: - manager.logger.error( + except Exception as exc: + self.manager.logger.error( "Exception encountered when setting FairGrad weights: %s", exc ) await netwk.send_message(Error(repr(exc))) raise RuntimeError("FairGrad weights update failed.") from exc # If things went well, ping the server back to indicate so. - manager.logger.info("Updated FairGrad weights.") - await netwk.send_message(FairnessRoundReply()) + self.manager.logger.info("Updated FairGrad weights.") + await netwk.send_message(FairgradOkay()) + + def compute_fairness_measures( + self, + batch_size: int, + n_batch: Optional[int] = None, + thresh: Optional[float] = None, + ) -> List[float]: + # Compute group-wise accuracy scores. + accuracy = self.computer.compute_groupwise_accuracy( + model=self.manager.model, + batch_size=batch_size, + n_batch=n_batch, + thresh=thresh, + ) + # Multiply these scores by sample counts. + accuracy = { + key: val * self.computer.counts[key] + for key, val in accuracy.items() + } + # Return shareable group-wise values, ordered and filled out. + return [accuracy.get(group, 0.0) for group in self.groups] + + async def finalize_fairness_round( + self, + netwk: NetworkClient, + values: List[float], + secagg: Optional[Encrypter], + ) -> Dict[str, Union[float, np.ndarray]]: + # Await updated loss weights from the server. + await self._update_fairgrad_weights(netwk) + # Recover raw accuracy scores for groups with local samples. + accuracy = { + key: val / self.computer.counts[key] + for key, val in zip(self.groups, values) + if key in self.computer.counts + } + # Compute local fairness measures. + fairness = self.fairness_function.compute_from_group_accuracy(accuracy) + f_type = self.fairness_function.f_type + # Package and return accuracy and fairness metrics. + metrics = { + f"accuracy_{key}": val for key, val in accuracy.items() + } # type: Dict[str, Union[float, np.ndarray]] + metrics.update( + {f"{f_type}_{key}": val for key, val in fairness.items()} + ) + return metrics diff --git a/declearn/fairness/fairgrad/_messages.py b/declearn/fairness/fairgrad/_messages.py index e374b5a180079a6ef97fffdea9fd8fdd0529168d..fd650cbce3b01f8aa7c95943b3b19de83f26c3e1 100644 --- a/declearn/fairness/fairgrad/_messages.py +++ b/declearn/fairness/fairgrad/_messages.py @@ -21,25 +21,20 @@ import dataclasses from typing import List -from declearn.fairness.api import FairnessSetupQuery from declearn.messaging import Message __all__ = [ - "FairgradSetupQuery", + "FairgradOkay", "FairgradWeights", ] @dataclasses.dataclass -class FairgradSetupQuery(FairnessSetupQuery): - """Message for server-emitted Fed-FairGrad setup queries. +class FairgradOkay(Message): + """Message for client-emitted signal that Fed-FairGrad update went fine.""" - This message is empty and merely signifies that Fed-FairGrad - should be set up by the client. - """ - - typekey = "fairgrad-setup" + typekey = "fairgrad-okay" @dataclasses.dataclass diff --git a/declearn/fairness/fairgrad/_server.py b/declearn/fairness/fairgrad/_server.py index f13f1c3cbac9139fcaae48ef251b7515450ac10f..b04fd3b3d06d72555a290619b78fd71045c0aae8 100644 --- a/declearn/fairness/fairgrad/_server.py +++ b/declearn/fairness/fairgrad/_server.py @@ -18,32 +18,28 @@ """Server-side Fed-FairGrad controller.""" import warnings -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np from declearn.aggregator import Aggregator, SumAggregator from declearn.communication.api import NetworkServer from declearn.communication.utils import verify_client_messages_validity -from declearn.fairness.api import ( - FairnessAccuracy, - FairnessRoundQuery, - FairnessRoundReply, - FairnessControllerServer, - FairnessSetupQuery, - SecaggFairnessAccuracy, -) +from declearn.fairness.api import FairnessControllerServer from declearn.fairness.core import instantiate_fairness_function -from declearn.fairness.fairgrad._messages import ( - FairgradSetupQuery, - FairgradWeights, -) +from declearn.fairness.fairgrad._messages import FairgradOkay, FairgradWeights +from declearn.messaging import FairnessSetupQuery from declearn.secagg.api import Decrypter -from declearn.secagg.messaging import aggregate_secagg_messages + + +__all__ = [ + "FairgradControllerServer", + "FairgradWeightsController", +] class FairgradWeightsController: - """Fairness controller to implement Faigrad optimization constraints.""" + """Controller to implement Faigrad optimization constraints.""" # attrs serve readability; pylint: disable=too-many-instance-attributes @@ -157,6 +153,8 @@ class FairgradWeightsController: class FairgradControllerServer(FairnessControllerServer): """Server-side controller to implement Fed-FairGrad.""" + algorithm = "fedfairgrad" + def __init__( self, f_type: str, @@ -182,16 +180,17 @@ class FairgradControllerServer(FairnessControllerServer): This may be set to 0.0 to try and enforce absolute fairness. """ super().__init__(f_type=f_type, f_args=f_args) - self.weights_controller = ( - None - ) # type: Optional[FairgradWeightsController] - self._eta = eta - self._eps = eps + # Set up a temporary controller that will be replaced at setup time. + self.weights_controller = FairgradWeightsController( + counts={}, f_type="accuracy_parity", eta=eta, eps=eps + ) def prepare_fairness_setup_query( self, ) -> FairnessSetupQuery: - return FairgradSetupQuery() + query = super().prepare_fairness_setup_query() + query.params.update({"f_type": self.f_type, "f_args": self.f_args}) + return query async def finalize_fairness_setup( self, @@ -203,8 +202,8 @@ class FairgradControllerServer(FairnessControllerServer): self.weights_controller = FairgradWeightsController( counts=dict(zip(self.groups, counts)), f_type=self.f_type, - eta=self._eta, - eps=self._eps, + eta=self.weights_controller.eta, + eps=self.weights_controller.eps, **self.f_args, ) # Send initial loss weights to the clients. @@ -228,50 +227,31 @@ class FairgradControllerServer(FairnessControllerServer): Await for clients to ping back that things went fine on their side. """ netwk.logger.info("Sending FairGrad weights to clients.") - assert self.weights_controller is not None weights = self.weights_controller.get_current_weights(norm_nk=True) await netwk.broadcast_message(FairgradWeights(weights=weights)) received = await netwk.wait_for_messages() await verify_client_messages_validity( - netwk, received, expected=FairnessRoundReply + netwk, received, expected=FairgradOkay ) - async def fairness_round( + async def finalize_fairness_round( self, + round_i: int, + values: List[float], netwk: NetworkServer, secagg: Optional[Decrypter], - ) -> None: - assert self.weights_controller is not None - # Send a query to clients and await group-wise accuracy metrics. - await netwk.broadcast_message( - FairnessRoundQuery() # TODO: receive a config and use it - ) - received = await netwk.wait_for_messages() - # When SecAgg is not set, expect and aggregate cleartext values. - if secagg is None: - replies = await verify_client_messages_validity( - netwk, received, expected=FairnessAccuracy - ) - accuracy = self._aggregate_cleartext_accuracy(replies) - # When SecAgg is set, expect and secure-aggregate encrypted values. - else: - sec_rep = await verify_client_messages_validity( - netwk, received, expected=SecaggFairnessAccuracy - ) - accuracy = aggregate_secagg_messages(sec_rep, secagg).values - # Compute global fairness and update FairGrad loss weights. - self.weights_controller.update_weights_based_on_accuracy( - accuracy=dict(zip(self.groups, accuracy)) - ) - # Send back the updated weights to the clients. + ) -> Dict[str, Union[float, np.ndarray]]: + # Unpack group-wise accuracy metrics and update loss weights. + accuracy = dict(zip(self.groups, values)) + self.weights_controller.update_weights_based_on_accuracy(accuracy) + # Send the updated weights to clients. await self._send_fairgrad_weights(netwk) - - def _aggregate_cleartext_accuracy( - self, - messages: Dict[str, FairnessAccuracy], - ) -> List[float]: - """Sum group-wise accuracy metrics received from clients.""" - accuracy = np.zeros(len(self.groups), dtype="float64") - for message in messages.values(): - accuracy += np.asarray(message.values, dtype="float64") - return accuracy.tolist() + # Package and return accuracy and fairness metrics. + metrics = { + f"accuracy_{key}": val for key, val in accuracy.items() + } # type: Dict[str, Union[float, np.ndarray]] + fairness = self.weights_controller.get_current_fairness() + metrics.update( + {f"{self.f_type}_{key}": val for key, val in fairness.items()} + ) + return metrics