diff --git a/declearn/fairness/api/_client.py b/declearn/fairness/api/_client.py index 7dd98ac5e50372f1f9d603c0185224701f5dc1bb..7b60c902f9737adebe953e8c625b27a9cce61973 100644 --- a/declearn/fairness/api/_client.py +++ b/declearn/fairness/api/_client.py @@ -254,19 +254,20 @@ class FairnessControllerClient(metaclass=abc.ABCMeta): # Optionally update the wrapped model's weights. if query.weights is not None: self.manager.model.set_weights(query.weights, trainable=True) - # Compute, opt. encrypt and share fairness-related metrics. - values = self.compute_fairness_measures( + # Compute some fairness-related values, split between two sets. + share_values, local_values = self.compute_fairness_measures( query.batch_size, query.n_batch, query.thresh ) - reply = FairnessReply(values=values) + # Share the first set of values for their (secure-)aggregation. + reply = FairnessReply(values=share_values) if secagg is None: await netwk.send_message(reply) else: await netwk.send_message( SecaggFairnessReply.from_cleartext_message(reply, secagg) ) - # Return computed values. - return values + # Return the second set of values. + return local_values @abc.abstractmethod def compute_fairness_measures( @@ -274,7 +275,7 @@ class FairnessControllerClient(metaclass=abc.ABCMeta): batch_size: int, n_batch: Optional[int] = None, thresh: Optional[float] = None, - ) -> List[float]: + ) -> Tuple[List[float], List[float]]: """Compute fairness measures based on a received query. By default, compute and return group-wise accuracy metrics, @@ -297,9 +298,13 @@ class FairnessControllerClient(metaclass=abc.ABCMeta): Returns ------- - values: - Computed values, as a deterministic-length ordered list - of float values. + share_values: + Values that are to be shared with the orchestrating server, + as a deterministic-length list of float values. + local_values: + Values that are to be used in local post-processing steps. + This may be a reference to `share_values`, but is typically + designed to contain unscaled measures to checkpoint. """ @abc.abstractmethod @@ -320,8 +325,10 @@ class FairnessControllerClient(metaclass=abc.ABCMeta): netwk: NetworkClient endpoint instance, connected to a server. values: - List of locally-computed evaluation metrics, already shared - with the server for their (secure-)aggregation. + List of locally-computed evaluation metrics. + This is the second set of `compute_fairness_measures` return + values; when this method is called, the first has already + been shared with the server for (secure-)aggregation. secagg: Optional SecAgg encryption controller. diff --git a/declearn/fairness/core/_accuracy.py b/declearn/fairness/core/_accuracy.py index 0f2c32c32238d7f09cc5c1e2f37d1388f6420639..e842ef37b4d743b35338bbb6fc1846e67346d89a 100644 --- a/declearn/fairness/core/_accuracy.py +++ b/declearn/fairness/core/_accuracy.py @@ -310,3 +310,21 @@ class FairnessAccuracyComputer: g_losses[group] = float(results[ModelLoss.name]) # Return the pair of dicts storing results. return accuracy, g_losses + + def scale_metrics_by_sample_counts( + self, + metrics: Dict[Tuple[Any, ...], float], + ) -> Dict[Tuple[Any, ...], float]: + """Scale a dict of computed group-wise metrics by sample counts. + + Parameters + ---------- + metrics: + Pre-computed raw metrics, as a `{group_k: score_k}` dict. + + Returns + ------- + metrics: + Scaled matrics, as a `{group_k: n_k * score_k}` dict. + """ + return {key: val * self.counts[key] for key, val in metrics.items()} diff --git a/declearn/fairness/core/_fair_func.py b/declearn/fairness/core/_fair_func.py index 6fab534317380ef2c0063a75dfaa54dc9f74b8b4..0a042fc3956590c090f81ed30c392e563fedc14d 100644 --- a/declearn/fairness/core/_fair_func.py +++ b/declearn/fairness/core/_fair_func.py @@ -220,6 +220,7 @@ class FairnessFunction(metaclass=abc.ABCMeta): Values' interpretation depend on the implemented group-fairness definition, but overall the fairer the accuracy towards a group, the closer the metric is to zero. + Raises ------ KeyError diff --git a/declearn/fairness/fairbatch/_client.py b/declearn/fairness/fairbatch/_client.py index 41f4460f6d007cf5556d49730cd85c0e9bae15f7..378aafd925a34d3c574014a0ec55a2e7e3419695 100644 --- a/declearn/fairness/fairbatch/_client.py +++ b/declearn/fairness/fairbatch/_client.py @@ -17,7 +17,7 @@ """Client-side Fed-FairBatch controller.""" -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np @@ -126,7 +126,7 @@ class FairbatchControllerClient(FairnessControllerClient): batch_size: int, n_batch: Optional[int] = None, thresh: Optional[float] = None, - ) -> List[float]: + ) -> Tuple[List[float], List[float]]: # Compute group-wise accuracy scores and loss values. accuracy, loss = self.computer.compute_groupwise_accuracy_and_loss( model=self.manager.model, @@ -134,18 +134,18 @@ class FairbatchControllerClient(FairnessControllerClient): n_batch=n_batch, thresh=thresh, ) - # Multiply these values by sample counts. - accuracy = { - key: val * self.computer.counts[key] - for key, val in accuracy.items() - } - loss = { - key: val * self.computer.counts[key] for key, val in loss.items() - } - # Return shareable group-wise values, ordered and filled out. - return [accuracy.get(group, 0.0) for group in self.groups] + [ - loss.get(group, 0.0) for group in self.groups + # Flatten local values for post-processing and checkpointing. + local_values = list(accuracy.values()) + list(loss.values()) + # Scale local values by sample counts for their aggregation. + accuracy = self.computer.scale_metrics_by_sample_counts(accuracy) + loss = self.computer.scale_metrics_by_sample_counts(loss) + # Flatten shareable values, ordered and filled-out. + share_values = [ + *[accuracy.get(group, 0.0) for group in self.groups], + *[loss.get(group, 0.0) for group in self.groups], ] + # Return both sets of values. + return share_values, local_values async def finalize_fairness_round( self, @@ -156,16 +156,9 @@ class FairbatchControllerClient(FairnessControllerClient): # Await updated loss weights from the server. await self._update_fairbatch_sampling_probas(netwk) # Recover raw accuracy and loss values for groups with local samples. - accuracy = { - key: val / self.computer.counts[key] - for key, val in zip(self.groups, values[: len(self.groups)]) - if key in self.computer.counts - } - loss = { - key: val / self.computer.counts[key] - for key, val in zip(self.groups, values[len(self.groups) :]) - if key in self.computer.counts - } + groups = list(self.computer.g_data) + accuracy = dict(zip(groups, values[: len(groups)])) + loss = dict(zip(groups, values[len(groups) :])) # Compute local fairness measures. fairness = self.fairness_function.compute_from_group_accuracy(accuracy) f_type = self.fairness_function.f_type diff --git a/declearn/fairness/fairgrad/_client.py b/declearn/fairness/fairgrad/_client.py index 3e21b272f2cdb9b540a71e6314df3fa66833af0c..11308d07635e3702ed9ca6fb5b26719debc99342 100644 --- a/declearn/fairness/fairgrad/_client.py +++ b/declearn/fairness/fairgrad/_client.py @@ -17,7 +17,7 @@ """Client-side Fed-FairGrad controller.""" -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np @@ -116,7 +116,7 @@ class FairgradControllerClient(FairnessControllerClient): batch_size: int, n_batch: Optional[int] = None, thresh: Optional[float] = None, - ) -> List[float]: + ) -> Tuple[List[float], List[float]]: # Compute group-wise accuracy scores. accuracy = self.computer.compute_groupwise_accuracy( model=self.manager.model, @@ -124,13 +124,14 @@ class FairgradControllerClient(FairnessControllerClient): n_batch=n_batch, thresh=thresh, ) - # Multiply these scores by sample counts. - accuracy = { - key: val * self.computer.counts[key] - for key, val in accuracy.items() - } - # Return shareable group-wise values, ordered and filled out. - return [accuracy.get(group, 0.0) for group in self.groups] + # Flatten local values for post-processing and checkpointing. + local_values = list(accuracy.values()) + # Scale local values by sample counts for their aggregation. + accuracy = self.computer.scale_metrics_by_sample_counts(accuracy) + # Flatten shareable values, ordered and filled-out. + share_values = [accuracy.get(group, 0.0) for group in self.groups] + # Return both sets of values. + return share_values, local_values async def finalize_fairness_round( self, @@ -141,11 +142,7 @@ class FairgradControllerClient(FairnessControllerClient): # Await updated loss weights from the server. await self._update_fairgrad_weights(netwk) # Recover raw accuracy scores for groups with local samples. - accuracy = { - key: val / self.computer.counts[key] - for key, val in zip(self.groups, values) - if key in self.computer.counts - } + accuracy = dict(zip(self.computer.g_data, values)) # Compute local fairness measures. fairness = self.fairness_function.compute_from_group_accuracy(accuracy) f_type = self.fairness_function.f_type