From 3c1bfc1223c2b7b3ac6ea9ab0b57f7a2fd63d53f Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Mon, 3 Jun 2024 15:45:01 +0200 Subject: [PATCH] Update Fairness API to limit redundant computations. - Have clients' measures-computing routine return both shareable and local-use values. - This way, unscaled metrics that are to be post-processed and/or checkpointed can be kept as-is rather than scaled-then-descaled. --- declearn/fairness/api/_client.py | 29 +++++++++++-------- declearn/fairness/core/_accuracy.py | 18 ++++++++++++ declearn/fairness/core/_fair_func.py | 1 + declearn/fairness/fairbatch/_client.py | 39 +++++++++++--------------- declearn/fairness/fairgrad/_client.py | 25 ++++++++--------- 5 files changed, 64 insertions(+), 48 deletions(-) diff --git a/declearn/fairness/api/_client.py b/declearn/fairness/api/_client.py index 7dd98ac..7b60c90 100644 --- a/declearn/fairness/api/_client.py +++ b/declearn/fairness/api/_client.py @@ -254,19 +254,20 @@ class FairnessControllerClient(metaclass=abc.ABCMeta): # Optionally update the wrapped model's weights. if query.weights is not None: self.manager.model.set_weights(query.weights, trainable=True) - # Compute, opt. encrypt and share fairness-related metrics. - values = self.compute_fairness_measures( + # Compute some fairness-related values, split between two sets. + share_values, local_values = self.compute_fairness_measures( query.batch_size, query.n_batch, query.thresh ) - reply = FairnessReply(values=values) + # Share the first set of values for their (secure-)aggregation. + reply = FairnessReply(values=share_values) if secagg is None: await netwk.send_message(reply) else: await netwk.send_message( SecaggFairnessReply.from_cleartext_message(reply, secagg) ) - # Return computed values. - return values + # Return the second set of values. + return local_values @abc.abstractmethod def compute_fairness_measures( @@ -274,7 +275,7 @@ class FairnessControllerClient(metaclass=abc.ABCMeta): batch_size: int, n_batch: Optional[int] = None, thresh: Optional[float] = None, - ) -> List[float]: + ) -> Tuple[List[float], List[float]]: """Compute fairness measures based on a received query. By default, compute and return group-wise accuracy metrics, @@ -297,9 +298,13 @@ class FairnessControllerClient(metaclass=abc.ABCMeta): Returns ------- - values: - Computed values, as a deterministic-length ordered list - of float values. + share_values: + Values that are to be shared with the orchestrating server, + as a deterministic-length list of float values. + local_values: + Values that are to be used in local post-processing steps. + This may be a reference to `share_values`, but is typically + designed to contain unscaled measures to checkpoint. """ @abc.abstractmethod @@ -320,8 +325,10 @@ class FairnessControllerClient(metaclass=abc.ABCMeta): netwk: NetworkClient endpoint instance, connected to a server. values: - List of locally-computed evaluation metrics, already shared - with the server for their (secure-)aggregation. + List of locally-computed evaluation metrics. + This is the second set of `compute_fairness_measures` return + values; when this method is called, the first has already + been shared with the server for (secure-)aggregation. secagg: Optional SecAgg encryption controller. diff --git a/declearn/fairness/core/_accuracy.py b/declearn/fairness/core/_accuracy.py index 0f2c32c..e842ef3 100644 --- a/declearn/fairness/core/_accuracy.py +++ b/declearn/fairness/core/_accuracy.py @@ -310,3 +310,21 @@ class FairnessAccuracyComputer: g_losses[group] = float(results[ModelLoss.name]) # Return the pair of dicts storing results. return accuracy, g_losses + + def scale_metrics_by_sample_counts( + self, + metrics: Dict[Tuple[Any, ...], float], + ) -> Dict[Tuple[Any, ...], float]: + """Scale a dict of computed group-wise metrics by sample counts. + + Parameters + ---------- + metrics: + Pre-computed raw metrics, as a `{group_k: score_k}` dict. + + Returns + ------- + metrics: + Scaled matrics, as a `{group_k: n_k * score_k}` dict. + """ + return {key: val * self.counts[key] for key, val in metrics.items()} diff --git a/declearn/fairness/core/_fair_func.py b/declearn/fairness/core/_fair_func.py index 6fab534..0a042fc 100644 --- a/declearn/fairness/core/_fair_func.py +++ b/declearn/fairness/core/_fair_func.py @@ -220,6 +220,7 @@ class FairnessFunction(metaclass=abc.ABCMeta): Values' interpretation depend on the implemented group-fairness definition, but overall the fairer the accuracy towards a group, the closer the metric is to zero. + Raises ------ KeyError diff --git a/declearn/fairness/fairbatch/_client.py b/declearn/fairness/fairbatch/_client.py index 41f4460..378aafd 100644 --- a/declearn/fairness/fairbatch/_client.py +++ b/declearn/fairness/fairbatch/_client.py @@ -17,7 +17,7 @@ """Client-side Fed-FairBatch controller.""" -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np @@ -126,7 +126,7 @@ class FairbatchControllerClient(FairnessControllerClient): batch_size: int, n_batch: Optional[int] = None, thresh: Optional[float] = None, - ) -> List[float]: + ) -> Tuple[List[float], List[float]]: # Compute group-wise accuracy scores and loss values. accuracy, loss = self.computer.compute_groupwise_accuracy_and_loss( model=self.manager.model, @@ -134,18 +134,18 @@ class FairbatchControllerClient(FairnessControllerClient): n_batch=n_batch, thresh=thresh, ) - # Multiply these values by sample counts. - accuracy = { - key: val * self.computer.counts[key] - for key, val in accuracy.items() - } - loss = { - key: val * self.computer.counts[key] for key, val in loss.items() - } - # Return shareable group-wise values, ordered and filled out. - return [accuracy.get(group, 0.0) for group in self.groups] + [ - loss.get(group, 0.0) for group in self.groups + # Flatten local values for post-processing and checkpointing. + local_values = list(accuracy.values()) + list(loss.values()) + # Scale local values by sample counts for their aggregation. + accuracy = self.computer.scale_metrics_by_sample_counts(accuracy) + loss = self.computer.scale_metrics_by_sample_counts(loss) + # Flatten shareable values, ordered and filled-out. + share_values = [ + *[accuracy.get(group, 0.0) for group in self.groups], + *[loss.get(group, 0.0) for group in self.groups], ] + # Return both sets of values. + return share_values, local_values async def finalize_fairness_round( self, @@ -156,16 +156,9 @@ class FairbatchControllerClient(FairnessControllerClient): # Await updated loss weights from the server. await self._update_fairbatch_sampling_probas(netwk) # Recover raw accuracy and loss values for groups with local samples. - accuracy = { - key: val / self.computer.counts[key] - for key, val in zip(self.groups, values[: len(self.groups)]) - if key in self.computer.counts - } - loss = { - key: val / self.computer.counts[key] - for key, val in zip(self.groups, values[len(self.groups) :]) - if key in self.computer.counts - } + groups = list(self.computer.g_data) + accuracy = dict(zip(groups, values[: len(groups)])) + loss = dict(zip(groups, values[len(groups) :])) # Compute local fairness measures. fairness = self.fairness_function.compute_from_group_accuracy(accuracy) f_type = self.fairness_function.f_type diff --git a/declearn/fairness/fairgrad/_client.py b/declearn/fairness/fairgrad/_client.py index 3e21b27..11308d0 100644 --- a/declearn/fairness/fairgrad/_client.py +++ b/declearn/fairness/fairgrad/_client.py @@ -17,7 +17,7 @@ """Client-side Fed-FairGrad controller.""" -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np @@ -116,7 +116,7 @@ class FairgradControllerClient(FairnessControllerClient): batch_size: int, n_batch: Optional[int] = None, thresh: Optional[float] = None, - ) -> List[float]: + ) -> Tuple[List[float], List[float]]: # Compute group-wise accuracy scores. accuracy = self.computer.compute_groupwise_accuracy( model=self.manager.model, @@ -124,13 +124,14 @@ class FairgradControllerClient(FairnessControllerClient): n_batch=n_batch, thresh=thresh, ) - # Multiply these scores by sample counts. - accuracy = { - key: val * self.computer.counts[key] - for key, val in accuracy.items() - } - # Return shareable group-wise values, ordered and filled out. - return [accuracy.get(group, 0.0) for group in self.groups] + # Flatten local values for post-processing and checkpointing. + local_values = list(accuracy.values()) + # Scale local values by sample counts for their aggregation. + accuracy = self.computer.scale_metrics_by_sample_counts(accuracy) + # Flatten shareable values, ordered and filled-out. + share_values = [accuracy.get(group, 0.0) for group in self.groups] + # Return both sets of values. + return share_values, local_values async def finalize_fairness_round( self, @@ -141,11 +142,7 @@ class FairgradControllerClient(FairnessControllerClient): # Await updated loss weights from the server. await self._update_fairgrad_weights(netwk) # Recover raw accuracy scores for groups with local samples. - accuracy = { - key: val / self.computer.counts[key] - for key, val in zip(self.groups, values) - if key in self.computer.counts - } + accuracy = dict(zip(self.computer.g_data, values)) # Compute local fairness measures. fairness = self.fairness_function.compute_from_group_accuracy(accuracy) f_type = self.fairness_function.f_type -- GitLab