From 4cd7d94d737bd9cb470a8fb7196f62d3a97ffccc Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Fri, 5 Jul 2024 15:18:49 +0200 Subject: [PATCH] Refactor some fairness controllers code. - Expose some subroutines under setup and fairness round, for the mere sake of making tests easier to perform, as well as to enable variants over current algorithms in the future / in experiments. - Rename some methods and re-order some arguments. - Refactor server-side aggregation of metrics, making it part of 'FairnessControllerServer' rather than part of 'FederatedServer' backend code. --- declearn/fairness/api/_client.py | 38 +++++-- declearn/fairness/api/_server.py | 149 ++++++++++++++++++++++--- declearn/fairness/fairbatch/_client.py | 6 +- declearn/fairness/fairbatch/_server.py | 4 +- declearn/fairness/fairfed/_client.py | 2 +- declearn/fairness/fairfed/_server.py | 4 +- declearn/fairness/fairgrad/_client.py | 4 +- declearn/fairness/fairgrad/_server.py | 4 +- declearn/fairness/monitor/_client.py | 2 +- declearn/fairness/monitor/_server.py | 6 +- declearn/main/_client.py | 2 +- declearn/main/_server.py | 23 +--- test/main/test_main_client.py | 8 +- 13 files changed, 186 insertions(+), 66 deletions(-) diff --git a/declearn/fairness/api/_client.py b/declearn/fairness/api/_client.py index 6f187cc..3d08709 100644 --- a/declearn/fairness/api/_client.py +++ b/declearn/fairness/api/_client.py @@ -161,12 +161,36 @@ class FairnessControllerClient(metaclass=abc.ABCMeta): secagg: Optional SecAgg encryption controller. """ + # Agree on a list of sensitive groups and share local sample counts. + await self.exchange_sensitive_groups_list_and_counts(netwk, secagg) + # Run additional algorithm-specific setup steps. + await self.finalize_fairness_setup(netwk, secagg) + + async def exchange_sensitive_groups_list_and_counts( + self, + netwk: NetworkClient, + secagg: Optional[Encrypter], + ) -> None: + """Agree on a list of sensitive groups and share local sample counts. + + This method performs the following routine: + + - Send the list of local sensitive group definitions to the server. + - Await a unified list of sensitive groups in return. + - Assign the received list as `groups` attribute. + - Send (optionally-encrypted) group-wise sample counts to the server. + + Parameters + ---------- + netwk: + `NetworkClient` endpoint, connected to a server. + secagg: + Optional SecAgg encryption controller. + """ # Share sensitive groups definitions and received an ordered list. self.groups = await self._exchange_sensitive_groups_list(netwk) # Send group-wise sample counts for the server to (secure-)aggregate. await self._send_sensitive_groups_counts(netwk, secagg) - # Run additional algorithm-specific setup steps. - await self.finalize_fairness_setup(netwk, secagg) async def _exchange_sensitive_groups_list( self, @@ -220,7 +244,7 @@ class FairnessControllerClient(metaclass=abc.ABCMeta): Optional SecAgg encryption controller. """ - async def fairness_round( + async def run_fairness_round( self, netwk: NetworkClient, query: FairnessQuery, @@ -253,7 +277,7 @@ class FairnessControllerClient(metaclass=abc.ABCMeta): await netwk.send_message(Error(error)) raise RuntimeError(error) from exc # Run additional algorithm-specific steps. - return await self.finalize_fairness_round(netwk, values, secagg) + return await self.finalize_fairness_round(netwk, secagg, values) async def _compute_and_share_fairness_measures( self, @@ -374,8 +398,8 @@ class FairnessControllerClient(metaclass=abc.ABCMeta): async def finalize_fairness_round( self, netwk: NetworkClient, - values: Dict[str, Dict[Tuple[Any, ...], float]], secagg: Optional[Encrypter], + values: Dict[str, Dict[Tuple[Any, ...], float]], ) -> Dict[str, Union[float, np.ndarray]]: """Take actions to enforce fairness. @@ -387,13 +411,13 @@ class FairnessControllerClient(metaclass=abc.ABCMeta): ---------- netwk: NetworkClient endpoint instance, connected to a server. + secagg: + Optional SecAgg encryption controller. values: Nested dict of locally-computed group-wise metrics. This is the second set of `compute_fairness_measures` return values; when this method is called, the first has already been shared with the server for (secure-)aggregation. - secagg: - Optional SecAgg encryption controller. Returns ------- diff --git a/declearn/fairness/api/_server.py b/declearn/fairness/api/_server.py index 33deb01..8ef8a54 100644 --- a/declearn/fairness/api/_server.py +++ b/declearn/fairness/api/_server.py @@ -26,8 +26,10 @@ from declearn.aggregator import Aggregator from declearn.communication.api import NetworkServer from declearn.communication.utils import verify_client_messages_validity from declearn.messaging import ( + Error, FairnessCounts, FairnessGroups, + FairnessReply, FairnessSetupQuery, SerializedMessage, ) @@ -35,6 +37,7 @@ from declearn.secagg.api import Decrypter from declearn.secagg.messaging import ( aggregate_secagg_messages, SecaggFairnessCounts, + SecaggFairnessReply, ) from declearn.utils import create_types_registry, register_type @@ -67,7 +70,7 @@ class FairnessControllerServer(metaclass=abc.ABCMeta): def __init__( self, f_type: str, - f_args: Optional[Dict[str, Any]], + f_args: Optional[Dict[str, Any]] = None, ) -> None: """Instantiate the server-side fairness controller. @@ -82,6 +85,8 @@ class FairnessControllerServer(metaclass=abc.ABCMeta): self.f_args = f_args or {} self.groups = [] # type: List[Tuple[Any, ...]] + # Fairness Setup methods. + async def setup_fairness( self, netwk: NetworkServer, @@ -126,12 +131,14 @@ class FairnessControllerServer(metaclass=abc.ABCMeta): # Send a setup query to all clients. query = self.prepare_fairness_setup_query() await netwk.broadcast_message(query) - # Receive, aggregate, assign and send back sensitive group definitions. - self.groups = await self._exchange_sensitive_groups_list(netwk) - # Receive, (secure-)aggregate and return group-wise sample counts. - counts = await self._aggregate_sensitive_groups_counts(netwk, secagg) + # Agree on a list of sensitive groups and aggregate sample counts. + counts = await self.exchange_sensitive_groups_list_and_counts( + netwk, secagg + ) # Run additional algorithm-specific setup steps. - return await self.finalize_fairness_setup(netwk, counts, aggregator) + return await self.finalize_fairness_setup( + netwk, secagg, counts, aggregator + ) def prepare_fairness_setup_query( self, @@ -149,6 +156,40 @@ class FairnessControllerServer(metaclass=abc.ABCMeta): params={"f_type": self.f_type, "f_args": self.f_args}, ) + async def exchange_sensitive_groups_list_and_counts( + self, + netwk: NetworkServer, + secagg: Optional[Decrypter], + ) -> List[int]: + """Agree on a list of sensitive groups and aggregate sample counts. + + This method performs the following routine: + + - Await `FairnessGroups` messages from clients with group definitions. + - Assign a sorted list of sensitive groups as `groups` attribute. + - Share that list with clients. + - Await possibly-encrypted group-wise sample counts from clients. + - (Secure-)Aggregate these sample counts and return them. + + Parameters + ---------- + netwk: + `NetworkServer` endpoint, through which a fairness setup query + was previously sent to all clients. + secagg: + Optional SecAgg decryption controller. + + Returns + ------- + counts: + List of group-wise total sample count across clients, + sorted based on the newly-assigned `self.groups`. + """ + # Receive, aggregate, assign and send back sensitive group definitions. + self.groups = await self._exchange_sensitive_groups_list(netwk) + # Receive, (secure-)aggregate and return group-wise sample counts. + return await self._aggregate_sensitive_groups_counts(netwk, secagg) + @staticmethod async def _exchange_sensitive_groups_list( netwk: NetworkServer, @@ -213,6 +254,7 @@ class FairnessControllerServer(metaclass=abc.ABCMeta): async def finalize_fairness_setup( self, netwk: NetworkServer, + secagg: Optional[Decrypter], counts: List[int], aggregator: Aggregator, ) -> Aggregator: @@ -238,13 +280,90 @@ class FairnessControllerServer(metaclass=abc.ABCMeta): or may not have been altered compared with the input one. """ + # Fairness Round methods. + + async def run_fairness_round( + self, + netwk: NetworkServer, + secagg: Optional[Decrypter], + ) -> Dict[str, Union[float, np.ndarray]]: + """Secure-aggregate and post-process fairness measures. + + This method is to be run **after** having sent a `FairnessQuery` + to clients. It consists in receiving, (secure-)aggregating and + post-processing measures that clients produce as a reply to that + query. This may involve further algorithm-specific communications. + + Parameters + ---------- + netwk: + NetworkServer endpoint instance, to which clients are registered. + secagg: + Optional SecAgg decryption controller. + + Returns + ------- + metrics: + Fairness(-related) metrics computed as part of this routine, + as a dict mapping scalar or numpy array values with their name. + """ + values = await self.receive_and_aggregate_fairness_measures( + netwk, secagg + ) + return await self.finalize_fairness_round(netwk, secagg, values) + + async def receive_and_aggregate_fairness_measures( + self, + netwk: NetworkServer, + secagg: Optional[Decrypter], + ) -> List[float]: + """Await and (secure-)aggregate client-wise fairness-related metrics. + + This method is designed to be called after sending a `FairnessQuery` + to clients, and returns values that are yet to be parsed and used by + the algorithm-dependent `finalize_fairness_round` method. + + Parameters + ---------- + netwk: + NetworkServer endpoint instance, to which clients are registered. + secagg: + Optional SecAgg decryption controller. + + Returns + ------- + metrics: + List of sum-aggregated fairness-related metrics (as floats). + By default, these are group-wise accuracy values; this may + however be changed or expanded by algorithm-specific classes. + """ + received = await netwk.wait_for_messages() + # Case when expecting cleartext values. + if secagg is None: + replies = await verify_client_messages_validity( + netwk, received, expected=FairnessReply + ) + if len(set(len(r.values) for r in replies.values())) != 1: + error = "Clients sent fairness values of different lengths." + await netwk.broadcast_message(Error(error)) + raise RuntimeError(error) + return [ + sum(rval) + for rval in zip(*[reply.values for reply in replies.values()]) + ] + # Case when expecting encrypted values. + secagg_replies = await verify_client_messages_validity( + netwk, received, expected=SecaggFairnessReply + ) + agg_reply = aggregate_secagg_messages(secagg_replies, decrypter=secagg) + return agg_reply.values + @abc.abstractmethod async def finalize_fairness_round( self, - round_i: int, - values: List[float], netwk: NetworkServer, secagg: Optional[Decrypter], + values: List[float], ) -> Dict[str, Union[float, np.ndarray]]: """Orchestrate a round of actions to enforce fairness. @@ -254,21 +373,17 @@ class FairnessControllerServer(metaclass=abc.ABCMeta): Parameters ---------- - round_i: - Index of the current round (reflecting that of an upcoming - training round). - values: - Aggregated metrics resulting from the fairness evaluation - run by clients at this round. netwk: NetworkServer endpoint instance, to which clients are registered. secagg: Optional SecAgg decryption controller. + values: + Aggregated metrics resulting from the fairness evaluation + run by clients at this round. Returns ------- metrics: - Computed local fairness(-related) metrics computed as part - of this routine, as a dict mapping scalar or numpy array - values with their name. + Fairness(-related) metrics computed as part of this routine, + as a dict mapping scalar or numpy array values with their name. """ diff --git a/declearn/fairness/fairbatch/_client.py b/declearn/fairness/fairbatch/_client.py index ace949f..264e639 100644 --- a/declearn/fairness/fairbatch/_client.py +++ b/declearn/fairness/fairbatch/_client.py @@ -84,7 +84,7 @@ class FairbatchControllerClient(FairnessControllerClient): If the sampling pobabilities' update fails. """ # Receive aggregated sensitive weights. - received = await netwk.check_message() + received = await netwk.recv_message() message = await verify_server_message_validity( netwk, received, expected=FairbatchSamplingProbas ) @@ -114,15 +114,15 @@ class FairbatchControllerClient(FairnessControllerClient): thresh: Optional[float] = None, ) -> List[MeanMetric]: loss = self.computer.setup_loss_metric(model=self.manager.model) - metrics = super().setup_fairness_metrics() + metrics = super().setup_fairness_metrics(thresh=thresh) metrics.append(loss) return metrics async def finalize_fairness_round( self, netwk: NetworkClient, - values: Dict[str, Dict[Tuple[Any, ...], float]], secagg: Optional[Encrypter], + values: Dict[str, Dict[Tuple[Any, ...], float]], ) -> Dict[str, Union[float, np.ndarray]]: # Await updated loss weights from the server. await self._update_fairbatch_sampling_probas(netwk) diff --git a/declearn/fairness/fairbatch/_server.py b/declearn/fairness/fairbatch/_server.py index ecf5971..89719fa 100644 --- a/declearn/fairness/fairbatch/_server.py +++ b/declearn/fairness/fairbatch/_server.py @@ -107,6 +107,7 @@ class FairbatchControllerServer(FairnessControllerServer): async def finalize_fairness_setup( self, netwk: NetworkServer, + secagg: Optional[Decrypter], counts: List[int], aggregator: Aggregator, ) -> Aggregator: @@ -150,10 +151,9 @@ class FairbatchControllerServer(FairnessControllerServer): async def finalize_fairness_round( self, - round_i: int, - values: List[float], netwk: NetworkServer, secagg: Optional[Decrypter], + values: List[float], ) -> Dict[str, Union[float, np.ndarray]]: # Unpack group-wise accuracy and loss values. accuracy = dict(zip(self.groups, values[: len(self.groups)])) diff --git a/declearn/fairness/fairfed/_client.py b/declearn/fairness/fairfed/_client.py index 24efbef..c8a03b2 100644 --- a/declearn/fairness/fairfed/_client.py +++ b/declearn/fairness/fairfed/_client.py @@ -106,8 +106,8 @@ class FairfedControllerClient(FairnessControllerClient): async def finalize_fairness_round( self, netwk: NetworkClient, - values: Dict[str, Dict[Tuple[Any, ...], float]], secagg: Optional[Encrypter], + values: Dict[str, Dict[Tuple[Any, ...], float]], ) -> Dict[str, Union[float, np.ndarray]]: # Await absolute mean fairness across all clients. received = await netwk.recv_message() diff --git a/declearn/fairness/fairfed/_server.py b/declearn/fairness/fairfed/_server.py index 24b5348..99f274b 100644 --- a/declearn/fairness/fairfed/_server.py +++ b/declearn/fairness/fairfed/_server.py @@ -113,6 +113,7 @@ class FairfedControllerServer(FairnessControllerServer): async def finalize_fairness_setup( self, netwk: NetworkServer, + secagg: Optional[Decrypter], counts: List[int], aggregator: Aggregator, ) -> Aggregator: @@ -130,10 +131,9 @@ class FairfedControllerServer(FairnessControllerServer): async def finalize_fairness_round( self, - round_i: int, - values: List[float], netwk: NetworkServer, secagg: Optional[Decrypter], + values: List[float], ) -> Dict[str, Union[float, np.ndarray]]: # Unpack group-wise accuracy values and compute fairness ones. accuracy = dict(zip(self.groups, values)) diff --git a/declearn/fairness/fairgrad/_client.py b/declearn/fairness/fairgrad/_client.py index 4a75c06..aa127af 100644 --- a/declearn/fairness/fairgrad/_client.py +++ b/declearn/fairness/fairgrad/_client.py @@ -69,7 +69,7 @@ class FairgradControllerClient(FairnessControllerClient): If the weights' update fails. """ # Receive aggregated sensitive weights. - received = await netwk.check_message() + received = await netwk.recv_message() message = await verify_server_message_validity( netwk, received, expected=FairgradWeights ) @@ -93,8 +93,8 @@ class FairgradControllerClient(FairnessControllerClient): async def finalize_fairness_round( self, netwk: NetworkClient, - values: Dict[str, Dict[Tuple[Any, ...], float]], secagg: Optional[Encrypter], + values: Dict[str, Dict[Tuple[Any, ...], float]], ) -> Dict[str, Union[float, np.ndarray]]: # Await updated loss weights from the server. await self._update_fairgrad_weights(netwk) diff --git a/declearn/fairness/fairgrad/_server.py b/declearn/fairness/fairgrad/_server.py index 53ef1b7..822b2bc 100644 --- a/declearn/fairness/fairgrad/_server.py +++ b/declearn/fairness/fairgrad/_server.py @@ -189,6 +189,7 @@ class FairgradControllerServer(FairnessControllerServer): async def finalize_fairness_setup( self, netwk: NetworkServer, + secagg: Optional[Decrypter], counts: List[int], aggregator: Aggregator, ) -> Aggregator: @@ -230,10 +231,9 @@ class FairgradControllerServer(FairnessControllerServer): async def finalize_fairness_round( self, - round_i: int, - values: List[float], netwk: NetworkServer, secagg: Optional[Decrypter], + values: List[float], ) -> Dict[str, Union[float, np.ndarray]]: # Unpack group-wise accuracy metrics and update loss weights. accuracy = dict(zip(self.groups, values)) diff --git a/declearn/fairness/monitor/_client.py b/declearn/fairness/monitor/_client.py index 5b41158..01aa352 100644 --- a/declearn/fairness/monitor/_client.py +++ b/declearn/fairness/monitor/_client.py @@ -45,8 +45,8 @@ class FairnessMonitorClient(FairnessControllerClient): async def finalize_fairness_round( self, netwk: NetworkClient, - values: Dict[str, Dict[Tuple[Any, ...], float]], secagg: Optional[Encrypter], + values: Dict[str, Dict[Tuple[Any, ...], float]], ) -> Dict[str, Union[float, np.ndarray]]: return { f"{metric}_{group}": value diff --git a/declearn/fairness/monitor/_server.py b/declearn/fairness/monitor/_server.py index d319a9c..12f2c20 100644 --- a/declearn/fairness/monitor/_server.py +++ b/declearn/fairness/monitor/_server.py @@ -42,7 +42,7 @@ class FairnessMonitorServer(FairnessControllerServer): def __init__( self, f_type: str, - f_args: Optional[Dict[str, Any]], + f_args: Optional[Dict[str, Any]] = None, ) -> None: super().__init__(f_type, f_args) # Assign a temporary fairness functions, replaced at setup time. @@ -53,6 +53,7 @@ class FairnessMonitorServer(FairnessControllerServer): async def finalize_fairness_setup( self, netwk: NetworkServer, + secagg: Optional[Decrypter], counts: List[int], aggregator: Aggregator, ) -> Aggregator: @@ -65,10 +66,9 @@ class FairnessMonitorServer(FairnessControllerServer): async def finalize_fairness_round( self, - round_i: int, - values: List[float], netwk: NetworkServer, secagg: Optional[Decrypter], + values: List[float], ) -> Dict[str, Union[float, np.ndarray]]: # Unpack group-wise accuracy metrics and compute fairness ones. accuracy = dict(zip(self.groups, values)) diff --git a/declearn/main/_client.py b/declearn/main/_client.py index c1619e7..1b3f45e 100644 --- a/declearn/main/_client.py +++ b/declearn/main/_client.py @@ -660,7 +660,7 @@ class FederatedClient: await self.netwk.send_message(messaging.Error(error)) return # Otherwise, run the controller's routine. - metrics = await self.fairness.fairness_round( + metrics = await self.fairness.run_fairness_round( netwk=self.netwk, query=query, secagg=self._encrypter ) # Optionally save computed fairness metrics. diff --git a/declearn/main/_server.py b/declearn/main/_server.py index 11f157a..d0eecfc 100644 --- a/declearn/main/_server.py +++ b/declearn/main/_server.py @@ -565,27 +565,8 @@ class FederatedServer: weights=None, ) await self._send_request_with_optional_weights(query, clients) - # Await and (secure-)aggregate) results. - self.logger.info("Awaiting clients' fairness measures.") - if self._decrypter is None: - replies = await self._collect_results( - clients, messaging.FairnessReply, "fairness round" - ) - if len(set(len(r.values) for r in replies.values())) != 1: - error = "Clients sent fairness values of different lengths." - self.logger.error(error) - await self.netwk.broadcast_message(messaging.Error(error)) - raise RuntimeError(error) - values = [sum(c_values) for c_values in zip(*replies.values())] - else: - secagg_replies = await self._collect_results( - clients, secagg_messaging.SecaggFairnessReply, "fairness round" - ) - values = self._aggregate_secagg_replies(secagg_replies).values - # Have the fairness controller process results. - metrics = await self.fairness.finalize_fairness_round( - round_i=round_i, - values=values, + # Await, (secure-)aggregate and process fairness measures. + metrics = await self.fairness.run_fairness_round( netwk=self.netwk, secagg=self._decrypter, ) diff --git a/test/main/test_main_client.py b/test/main/test_main_client.py index 391daec..7a591c7 100644 --- a/test/main/test_main_client.py +++ b/test/main/test_main_client.py @@ -1067,13 +1067,13 @@ class TestFederatedClientFairnessRound: # Call the 'fairness_round' routine and verify expected actions. request = messaging.FairnessQuery(round_i=1) await client.fairness_round(request) - fairness.fairness_round.assert_awaited_once_with( + fairness.run_fairness_round.assert_awaited_once_with( netwk=netwk, query=request, secagg=None ) # Verify that when a checkpointer is set, it is used. if ckpt: client.ckptr.save_metrics.assert_called_once_with( # type: ignore - metrics=fairness.fairness_round.return_value, + metrics=fairness.run_fairness_round.return_value, prefix="fairness_metrics", append=True, timestamp="round_1", @@ -1102,7 +1102,7 @@ class TestFederatedClientFairnessRound: # Call the 'fairness_round' routine and verify expected actions. request = messaging.FairnessQuery(round_i=1) await client.fairness_round(request) - fairness.fairness_round.assert_awaited_once_with( + fairness.run_fairness_round.assert_awaited_once_with( netwk=netwk, query=request, secagg=secagg.setup_encrypter.return_value, @@ -1149,7 +1149,7 @@ class TestFederatedClientFairnessRound: netwk.send_message.assert_called_once() reply = netwk.send_message.call_args.args[0] assert isinstance(reply, messaging.Error) - fairness.fairness_round.assert_not_called() + fairness.run_fairness_round.assert_not_called() class TestFederatedClientMisc: -- GitLab