diff --git a/declearn/fairness/__init__.py b/declearn/fairness/__init__.py index 15eb1aaca2e39fc0b807e224bbf725045ffee318..ede33109cf7d2aec898c0c5f309bccb51e5a365d 100644 --- a/declearn/fairness/__init__.py +++ b/declearn/fairness/__init__.py @@ -23,10 +23,13 @@ This module implements the following submodules: API to set up and run fairness-aware federated learning algorithms. * [core][declearn.fairness.core]: Core components and utils for fairness-aware (federated) machine learning. +* [fairbatch][declearn.fairness.fairbatch]: + Fed-FairBatch / FedB algorithm controllers and utils. * [fairgrad][declearn.fairness.fairgrad]: Fed-FairGrad algorithm controllers and utils. """ from . import core from . import api +from . import fairbatch from . import fairgrad diff --git a/declearn/fairness/fairbatch/__init__.py b/declearn/fairness/fairbatch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0dbd6f1759f284052741af57347548c339027df8 --- /dev/null +++ b/declearn/fairness/fairbatch/__init__.py @@ -0,0 +1,85 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Fed-FairBatch / FedFB algorithm controllers and utils. + +Introduction +------------ +This module provides with a double-fold implementation of an adaptation +of the FairBatch [1] algorithm for federated learning. On the one hand, +the FedFB [2] algorithm is implemented, that both adapts FairBatch in a +straightforward manner and introduces changes in formulas compared with +the initial paper. On the other hand, the a custom algorithm deemed as +Fed-FairBatch is implemented, that is similar in intent to FedFB but +sticks to the raw FairBatch formulas. + +FairBatch is a group-fairness-enforcing algorithm that relies on a +specific form of loss reweighting mediated by a specific batching +of samples for SGD steps. Namely, in FairBatch, batches are drawn +by concatenating group-wise sub-batches, the size of which is the +byproduct of the desired total batch size and group-wise sampling +probabilities, with the latter being updated throughout training +based on the current model's fairness. + +Initially, FairBatch is designed for binary classification tasks +on data that have a single binary sensitive attribute. Both our +implementations currently stick to that setting, in spite of the +FedFB authors using a formalism that arguably extend formulas to +more generic categorical sensitive attribute(s) - which is not +tested in the paper. + +Controllers +----------- +* [FairbatchControllerClient] +[declearn.fairness.fairbatch.FairgradControllerClient]: + Client-side controller to implement Fed-FairBatch or FedFB. +* [FairbatchControllerServer] +[declearn.fairness.fairbatch.FairgradControllerServer]: + Server-side controller to implement Fed-FairBatch or FedFB. + +Backend +------- +* [FairbatchSamplingController] +[declearn.fairness.fairbatch.FairbatchSamplingController]: + ABC to compute and update Fairbatch sampling probabilities. +* [setup_fairbatch_controller] +[declearn.fairness.fairbatch.setup_fairbatch_controller]: + Instantiate a FairBatch sampling probabilities controller. +* [setup_fedfb_controller] +[declearn.fairness.fairbatch.setup_fedfb_controller]: + Instantiate a FedFB sampling probabilities controller. + +Messages +-------- +* [FairbatchOkay][declearn.fairness.fairbatch.FairbatchOkay]: + Message for client signal that Fed-FairBatch/FedFB update went fine. +* [FairbatchSamplingProbas] +[declearn.fairness.fairbatch.FairbatchSamplingProbas]: + Message for server-emitted Fed-FairBatch/Fed-FB sampling probabilities. +""" + +from ._messages import ( + FairbatchOkay, + FairbatchSamplingProbas, +) +from ._sampling import ( + FairbatchSamplingController, + setup_fairbatch_controller, +) +from ._fedfb import setup_fedfb_controller +from ._client import FairbatchControllerClient +from ._server import FairbatchControllerServer diff --git a/declearn/fairness/fairbatch/_client.py b/declearn/fairness/fairbatch/_client.py new file mode 100644 index 0000000000000000000000000000000000000000..41f4460f6d007cf5556d49730cd85c0e9bae15f7 --- /dev/null +++ b/declearn/fairness/fairbatch/_client.py @@ -0,0 +1,180 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Client-side Fed-FairBatch controller.""" + +from typing import Any, Dict, List, Optional, Union + +import numpy as np + +from declearn.communication.api import NetworkClient +from declearn.communication.utils import verify_server_message_validity +from declearn.fairness.api import FairnessControllerClient +from declearn.fairness.core import ( + FairnessDataset, + instantiate_fairness_function, +) +from declearn.fairness.fairbatch._dataset import FairbatchDataset +from declearn.fairness.fairbatch._messages import ( + FairbatchSamplingProbas, + FairbatchOkay, +) +from declearn.messaging import Error +from declearn.secagg.api import Encrypter +from declearn.training import TrainingManager + +__all__ = [ + "FairbatchControllerClient", +] + + +class FairbatchControllerClient(FairnessControllerClient): + """Client-side controller to implement Fed-FairBatch or FedFB.""" + + algorithm = "fedfairbatch" + + def __init__( + self, + manager: TrainingManager, + f_type: str, + f_args: Dict[str, Any], + ) -> None: + """Instantiate the client-side fairness controller. + + Parameters + ---------- + manager: + `TrainingManager` instance wrapping the model being trained + and its training dataset (that must be a `FairnessDataset`). + f_type: + Name of the type of group-fairness function being optimized. + f_args: + Keyword arguments to the group-fairness function. + """ + super().__init__(manager) + assert isinstance(self.manager.train_data, FairnessDataset) + self.manager.train_data = FairbatchDataset(self.manager.train_data) + self.fairness_function = instantiate_fairness_function( + f_type=f_type, counts=self.computer.counts, **f_args + ) + + async def finalize_fairness_setup( + self, + netwk: NetworkClient, + secagg: Optional[Encrypter], + ) -> None: + pass # no action required beyond sharing group definitions and counts + + async def _update_fairbatch_sampling_probas( + self, + netwk: NetworkClient, + ) -> None: + """Run a FairBatch-specific routine to update sampling probabilities. + + Expect a message from the orchestrating server containing the new + sensitive group sampling probabilities, and apply them to the + training dataset. + + Raises + ------ + RuntimeError: + If the expected message is not received. + If the sampling pobabilities' update fails. + """ + # Receive aggregated sensitive weights. + received = await netwk.check_message() + message = await verify_server_message_validity( + netwk, received, expected=FairbatchSamplingProbas + ) + probas = dict(zip(self.groups, message.probas)) + # Set the received weights, handling and propagating exceptions if any. + try: + assert isinstance(self.manager.train_data, FairbatchDataset) + self.manager.train_data.set_sampling_probabilities( + group_probas=probas + ) + except Exception as exc: + self.manager.logger.error( + "Exception encountered when setting FairBatch sampling" + "probabilities: %s", + repr(exc), + ) + await netwk.send_message(Error(repr(exc))) + raise RuntimeError( + "FairBatch sampling probabilities update failed." + ) from exc + # If things went well, ping the server back to indicate so. + self.manager.logger.info("Updated FairBatch sampling probabilities.") + await netwk.send_message(FairbatchOkay()) + + def compute_fairness_measures( + self, + batch_size: int, + n_batch: Optional[int] = None, + thresh: Optional[float] = None, + ) -> List[float]: + # Compute group-wise accuracy scores and loss values. + accuracy, loss = self.computer.compute_groupwise_accuracy_and_loss( + model=self.manager.model, + batch_size=batch_size, + n_batch=n_batch, + thresh=thresh, + ) + # Multiply these values by sample counts. + accuracy = { + key: val * self.computer.counts[key] + for key, val in accuracy.items() + } + loss = { + key: val * self.computer.counts[key] for key, val in loss.items() + } + # Return shareable group-wise values, ordered and filled out. + return [accuracy.get(group, 0.0) for group in self.groups] + [ + loss.get(group, 0.0) for group in self.groups + ] + + async def finalize_fairness_round( + self, + netwk: NetworkClient, + values: List[float], + secagg: Optional[Encrypter], + ) -> Dict[str, Union[float, np.ndarray]]: + # Await updated loss weights from the server. + await self._update_fairbatch_sampling_probas(netwk) + # Recover raw accuracy and loss values for groups with local samples. + accuracy = { + key: val / self.computer.counts[key] + for key, val in zip(self.groups, values[: len(self.groups)]) + if key in self.computer.counts + } + loss = { + key: val / self.computer.counts[key] + for key, val in zip(self.groups, values[len(self.groups) :]) + if key in self.computer.counts + } + # Compute local fairness measures. + fairness = self.fairness_function.compute_from_group_accuracy(accuracy) + f_type = self.fairness_function.f_type + # Package and return accuracy and fairness metrics. + metrics = { + f"accuracy_{key}": val for key, val in accuracy.items() + } # type: Dict[str, Union[float, np.ndarray]] + metrics.update({f"loss_{key}": val for key, val in loss.items()}) + metrics.update( + {f"{f_type}_{key}": val for key, val in fairness.items()} + ) + return metrics diff --git a/declearn/fairness/fairbatch/_dataset.py b/declearn/fairness/fairbatch/_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..8fb1411afca94b5862e2588aa587eb52bbbe7703 --- /dev/null +++ b/declearn/fairness/fairbatch/_dataset.py @@ -0,0 +1,292 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""FairBatch-specific Dataset wrapper and subclass.""" + +from typing import Any, Dict, Iterator, List, Sequence, Tuple + +import numpy as np + +from declearn.dataset import Dataset, DataSpecs +from declearn.fairness.core import FairnessDataset +from declearn.typing import Batch + +__all__ = [ + "FairbatchDataset", +] + + +class FairbatchDataset(FairnessDataset): + """FairBatch-specific FairnessDataset subclass and wrapper.""" + + def __init__( + self, + base: FairnessDataset, + ) -> None: + """Instantiate a FairbatchDataset wrapping a FairnessDataset. + + Parameters + ---------- + base: + Base `FairnessDataset` instance to wrap so as to apply + group-wise subsampling as per the FairBatch algorithm. + """ + self.base = base + # Assign a dictionary with sampling probability for each group. + self.groups = self.get_sensitive_group_definitions() + self._counts = self.base.get_sensitive_group_counts() + self._sampling_probas = { + group: 1.0 / len(self.groups) for group in self.groups + } + + # Methods provided by the wrapped dataset (merely interfaced). + + def get_data_specs( + self, + ) -> DataSpecs: + return self.base.get_data_specs() + + def get_sensitive_group_definitions( + self, + ) -> List[Tuple[Any, ...]]: + return self.groups + + def get_sensitive_group_counts( + self, + ) -> Dict[Tuple[Any, ...], int]: + return self._counts.copy() + + def get_sensitive_group_subset( + self, + group: Tuple[Any, ...], + ) -> Dataset: + return self.base.get_sensitive_group_subset(group) + + def set_sensitive_group_weights( + self, + weights: Dict[Tuple[Any, ...], float], + adjust_by_counts: bool = False, + ) -> None: + self.base.set_sensitive_group_weights(weights, adjust_by_counts) + + # FairBatch-specific methods. + + def get_sampling_probabilities( + self, + ) -> Dict[Tuple[Any, ...], float]: + """Access current group-wise sampling probabilities.""" + return self._sampling_probas.copy() + + def set_sampling_probabilities( + self, + group_probas: Dict[Tuple[Any, ...], float], + ) -> None: + """Assign new group-wise sampling probabilities. + + If some groups are not present in the wrapped dataset, + scale the probabilities associated with all represented + groups so that they sum to 1. + + Parameters + ---------- + group_probas: + Dict of group-wise sampling probabilities, with + `{(s_attr_1, ..., s_attr_k): sampling_proba}` format. + + Raises + ------ + ValueError + If the input probabilities are not positive values + or if they do not cover (a superset of) all sensitive + groups present in the wrapped dataset. + """ + # Verify that input match expectations. + if not all(x >= 0 for x in group_probas.values()): + raise ValueError( + f"'{self.__class__.__name__}.update_sampling_probabilities' " + "cannot have a negative probability value as parameter." + ) + if not set(self.groups).issubset(group_probas): + raise ValueError( + "'FairbatchDataset.update_sampling_probabilities' requires " + "input values to cover (a superset of) local sensitive groups." + ) + # Restrict and adjust probabilities to groups with samples. + probas = {group: group_probas[group] for group in self.groups} + total = sum(probas.values()) + self._sampling_probas = { + key: val / total for key, val in probas.items() + } + + def generate_batches( + self, + batch_size: int, + shuffle: bool = False, + drop_remainder: bool = True, + replacement: bool = False, + poisson: bool = False, + ) -> Iterator[Batch]: + # inherited signature; pylint: disable=too-many-arguments + # NOTE: we could add support for those, but let's start simple. + if not drop_remainder: + raise ValueError( + f"'{self.__class__.__name__}.generate_batches' does not " + "support argument value 'drop_remainder=False'." + ) + # Compute the number of batches to yield. + nb_batches = sum(self._counts.values()) // batch_size + # Compute the group-wise number of samples per batch. + # NOTE: this number may be reduced if there are too few samples. + group_batch_size = { + group: round(proba * batch_size) + for group, proba in self._sampling_probas.items() + } + # Yield batches made of a fixed number of samples from each group. + generators = [ + self._generate_sensitive_group_batches( + group, nb_batches, g_batch_size, shuffle, replacement, poisson + ) + for group, g_batch_size in group_batch_size.items() + if g_batch_size > 0 + ] + for batches in zip(*generators): + yield self._concatenate_batches(batches) + + @staticmethod + def _concatenate_batches( + batches: Sequence[Batch], + ) -> Batch: + """Concatenate batches of numpy array data.""" + x_dat = np.concatenate([batch[0] for batch in batches], axis=0) + y_dat = ( + None + if batches[0][1] is None + else np.concatenate([batch[1] for batch in batches], axis=0) + ) + w_dat = ( + None + if batches[0][2] is None + else np.concatenate([batch[2] for batch in batches], axis=0) + ) + return x_dat, y_dat, w_dat + + def _generate_sensitive_group_batches( + self, + group: Tuple[Any, ...], + nb_batches: int, + batch_size: int, + shuffle: bool, + replacement: bool, + poisson: bool, + ) -> Iterator[Batch]: + """Generate a fixed number of batches for a given sensitive group. + + Parameters + ---------- + group: + Sensitive group, the dataset from which to draw from. + nb_batches: + Number of batches to yield. The dataset will be iterated + over if needed to achieve this number. + batch_size: + Number of samples per batch (will be exact). + shuffle: + Whether to shuffle the dataset prior to drawing batches. + replacement: + Whether to draw with replacement between batches. + poisson: + Whether to use poisson sampling rather than batching. + """ + # backend method; pylint: disable=too-many-arguments + # Fetch the target sub-dataset and its samples count. + dataset = self.get_sensitive_group_subset(group) + n_samples = self._counts[group] + # Adjust batch size when needed and set up a batches generator. + n_repeats, batch_size = divmod(batch_size, n_samples) + generator = self._generate_batches( + # fmt: off + dataset, group, nb_batches, batch_size, + shuffle, replacement, poisson, + ) + # When the batch size is larger than the number of data points, + # make up a base batch will all points (duplicated if needed), + # that will be combined with further batches of data. + if n_repeats: + full = self._get_full_dataset(dataset, n_samples, group) + full = self._concatenate_batches([full] * n_repeats) + for batch in generator: + yield self._concatenate_batches((full, batch)) + # Otherwise, merely yield from the generator. + else: + yield from generator + + def _generate_batches( + self, + dataset: Dataset, + group: Tuple[Any, ...], + nb_batches: int, + batch_size: int, + shuffle: bool, + replacement: bool, + poisson: bool, + ) -> Iterator[Batch]: + """Backend to yield a fixed number of batches from a dataset.""" + # backend method; pylint: disable=too-many-arguments + # Iterate multiple times over the sub-dataset if needed. + counter = 0 + while counter < nb_batches: + # Yield batches from the sub-dataset. + generator = dataset.generate_batches( + batch_size=batch_size, + shuffle=shuffle, + drop_remainder=True, + replacement=replacement, + poisson=poisson, + ) + for batch in generator: + yield batch + counter += 1 + if counter == nb_batches: + break + # Prevent infinite loops and raise an informative error. + if not counter: # pragma: no cover + raise RuntimeError( + f"'{self.__class__.__name__}.generate_batches' triggered " + "an infinite loop; this happened when trying to extract " + f"{batch_size}-samples batches for group {group}." + ) + + @staticmethod + def _get_full_dataset( + dataset: Dataset, + n_samples: int, + group: Tuple[Any, ...], + ) -> Batch: + """Return a batch containing an entire dataset's samples.""" + try: + generator = dataset.generate_batches( + batch_size=n_samples, + shuffle=False, + drop_remainder=False, + replacement=False, + poisson=False, + ) + return next(generator) + except StopIteration as exc: # pragma: no cover + raise RuntimeError( + f"Failed to fetch the full subdataset for group '{group}'." + ) from exc diff --git a/declearn/fairness/fairbatch/_fedfb.py b/declearn/fairness/fairbatch/_fedfb.py new file mode 100644 index 0000000000000000000000000000000000000000..8f3b1a1cf421d30ad3a7b54646eac9d343f2ff38 --- /dev/null +++ b/declearn/fairness/fairbatch/_fedfb.py @@ -0,0 +1,263 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""FedFB sampling probability controllers.""" + +from typing import Any, Dict, Tuple + +import numpy as np + +from declearn.fairness.fairbatch._sampling import ( + FairbatchDemographicParity, + FairbatchEqualizedOdds, + FairbatchEqualityOpportunity, + FairbatchSamplingController, + assign_sensitive_group_labels, +) + + +__all__ = [ + "setup_fedfb_controller", +] + + +class FedFBEqualityOpportunity(FairbatchEqualityOpportunity): + """FedFB variant of Equality of Opportunity controller. + + This variant introduces two changes as compared with our FedFairBatch: + - The lambda parameter and difference of losses are written with a + different group ordering, albeit resulting in identical results. + - When comparing loss values over sensitive groups, the notations from + the FedFB paper indicate that the sums of losses over samples in the + groups are compared, rather than the averages of group-wise losses; + this implementation sticks to the FedFB paper. + """ + + f_type = "equality_of_opportunity" + + def get_sampling_probas( + self, + ) -> Dict[Tuple[Any, ...], float]: + # Revert the sense of lambda (invert (1, 0) and (0, 1) groups) + # to stick with notations from the FedFB paper. + probas = super().get_sampling_probas() + label_10 = self.groups["1_0"] + label_11 = self.groups["1_1"] + probas[label_10], probas[label_11] = probas[label_11], probas[label_10] + return probas + + def update_from_losses( + self, + losses: Dict[Tuple[Any, ...], float], + ) -> None: + # Recover sum-aggregated losses for the two groups of interest. + # Do not scale: obtain sums of sample losses for each group. + # This differs from parent class (and centralized FairBatch) + # but sticks with the FedFB paper's notations and algorithm. + loss_10 = losses[self.groups["1_0"]] * self.counts[self.groups["1_0"]] + loss_11 = losses[self.groups["1_1"]] * self.counts[self.groups["1_1"]] + # Update lambda based on these and the alpha hyper-parameter. + # Note: this is the same as in parent class, inverting sense of + # groups (0, 0) and (1, 0), to stick with the FedFB paper. + if loss_11 > loss_10: + self.states["lambda"] = min( + self.states["lambda"] + self.alpha, self.states["p_tgt_1"] + ) + elif loss_11 < loss_10: + self.states["lambda"] = max(self.states["lambda"] - self.alpha, 0) + + +class FedFBEqualizedOdds(FairbatchEqualizedOdds): + """FedFB variant of Equalized Odds controller. + + This variant introduces three changes as compared with our FedFairBatch: + - The lambda parameters and difference of losses are written with a + different group ordering, albeit resulting in identical results. + - When comparing loss values over sensitive groups, the notations from + the FedFB paper indicate that the sums of losses over samples in the + groups are compared, rather than the averages of group-wise losses; + this implementation sticks to the FedFB paper. + - The update rule for lambda parameters has a distinct formula, with the + alpha hyper-parameter being here scaled by the difference in losses + and normalized by the L2 norm of differences in losses, and both groups' + lambda being updated at each step. + """ + + f_type = "equalized_odds" + + def compute_initial_states( + self, + ) -> Dict[str, float]: + # Switch lambdas: apply to groups (-, 1) rather than (-, 0). + states = super().compute_initial_states() + states["lambda_1"] = states["p_trgt_0"] - states["lambda_1"] + states["lambda_2"] = states["p_trgt_1"] - states["lambda_2"] + return states + + def get_sampling_probas( + self, + ) -> Dict[Tuple[Any, ...], float]: + # Rewrite the rules entirely, effectively swapping (0,1)/(0,0) + # and (1,1)/(1,0) groups compared with parent implementation. + states = self.states + return { + self.groups["0_0"]: states["p_trgt_0"] - states["lambda_1"], + self.groups["0_1"]: states["lambda_1"], + self.groups["1_0"]: states["p_trgt_1"] - states["lambda_2"], + self.groups["1_1"]: states["lambda_2"], + } + + def update_from_losses( + self, + losses: Dict[Tuple[Any, ...], float], + ) -> None: + # Recover sum-aggregated losses for each sensitive group. + # Do not scale: obtain sums of sample losses for each group. + # This differs from parent class (and centralized FairBatch) + # but sticks with the FedFB paper's notations and algorithm. + labeled_losses = { + label: losses[group] * self.counts[group] + for label, group in self.groups.items() + } + # Compute aggregated-loss differences for each target label. + diff_loss_tgt_0 = labeled_losses["0_1"] - labeled_losses["0_0"] + diff_loss_tgt_1 = labeled_losses["1_1"] - labeled_losses["1_0"] + # Compute the euclidean norm of these values. + den = float(np.linalg.norm([diff_loss_tgt_0, diff_loss_tgt_1], ord=2)) + # Update lambda_1 (affecting groups with y=0). + update = self.alpha * diff_loss_tgt_0 / den + self.states["lambda_1"] = min( + self.states["lambda_1"] + update, self.states["p_trgt_0"] + ) + self.states["lambda_1"] = max(self.states["lambda_1"], 0) + # Update lambda_1 (affecting groups with y=1). + update = self.alpha * diff_loss_tgt_1 / den + self.states["lambda_2"] = min( + self.states["lambda_2"] + update, self.states["p_trgt_1"] + ) + self.states["lambda_2"] = max(self.states["lambda_2"], 0) + + +class FedFBDemographicParity(FairbatchDemographicParity): + """FairbatchSamplingController subclass for 'demographic_parity'.""" + + f_type = "demographic_parity" + + def update_from_losses( + self, + losses: Dict[Tuple[Any, ...], float], + ) -> None: + # NOTE: losses' aggregation does not defer from parent class. + # Recover sum-aggregated losses for each sensitive group. + # Obtain {k: n_k * Sum(loss for all samples in group k)}. + labeled_losses = { + label: losses[group] * self.counts[group] + for label, group in self.groups.items() + } + # Normalize losses based on sensitive attribute counts. + # Obtain {k: sum(loss for samples in k) / n_samples_with_attr}. + labeled_losses["0_0"] /= self.states["n_attr_0"] + labeled_losses["0_1"] /= self.states["n_attr_1"] + labeled_losses["1_0"] /= self.states["n_attr_0"] + labeled_losses["1_1"] /= self.states["n_attr_1"] + # NOTE: this is where things differ from parent class. + # Compute an overall fairness value based on all losses. + f_val = ( + -labeled_losses["0_0"] + + labeled_losses["0_1"] + + labeled_losses["1_0"] + - labeled_losses["1_1"] + + self.counts[self.groups["0_0"]] / self.states["n_attr_0"] + - self.counts[self.groups["0_1"]] / self.states["n_attr_1"] + ) + # Update both lambdas based on this overall value. + # Note: in the binary attribute case, $mu_a / ||mu||_2$ + # is equal to $sign(mu_1) / sqrt(2)$. + update = float(np.sign(f_val) * self.alpha / np.sqrt(2)) + self.states["lambda_1"] = min( + self.states["lambda_1"] - update, self.states["p_attr_0"] + ) + self.states["lambda_1"] = max(self.states["lambda_1"], 0) + self.states["lambda_2"] = min( + self.states["lambda_2"] - update, self.states["p_attr_1"] + ) + self.states["lambda_2"] = max(self.states["lambda_2"], 0) + + +def setup_fedfb_controller( + f_type: str, + counts: Dict[Tuple[Any, ...], int], + target: int = 1, + alpha: float = 0.005, +) -> FairbatchSamplingController: + """Instantiate a FedFB sampling probabilities controller. + + This is a drop-in replacement for `setup_fedfairbatch_controller` + that implemented update rules matching the Fed-FB algorithm(s) as + introduced in [1]. + + Parameters + ---------- + f_type: + Type of group fairness to optimize for. + counts: + Dict mapping sensitive group definitions to their total + sample counts (across clients). These groups must arise + from the crossing of a binary target label and a binary + sensitive attribute. + target: + Target label to treat as positive. + alpha: + Alpha hyper-parameter, scaling the magnitude of sampling + probabilities' updates by the returned controller. + + Returns + ------- + controller: + FairBatch sampling probabilities controller matching inputs. + + Raises + ------ + KeyError + If `f_type` does not match any known or supported fairness type. + ValueError + If `counts` keys cannot be matched to canonical group labels. + + References + ---------- + [1] Zeng et al. (2022). + Improving Fairness via Federated Learning. + https://arxiv.org/abs/2110.15545 + """ + controller_types = { + "demographic_parity": FedFBDemographicParity, + "equality_of_opportunity": FedFBEqualityOpportunity, + "equalized_odds": FedFBEqualizedOdds, + } + controller_cls = controller_types.get(f_type, None) + if controller_cls is None: + raise KeyError( + "Unknown or unsupported fairness type parameter for FairBatch " + f"controller initialization: '{f_type}'. Supported values are " + f"{list(controller_types)}." + ) + # Match groups to canonical labels and instantiate the controller. + groups = assign_sensitive_group_labels(groups=list(counts), target=target) + kwargs = {"target": target} if f_type == "equality_of_opportunity" else {} + return controller_cls( # type: ignore[abstract] + groups=groups, counts=counts, alpha=alpha, **kwargs + ) diff --git a/declearn/fairness/fairbatch/_messages.py b/declearn/fairness/fairbatch/_messages.py new file mode 100644 index 0000000000000000000000000000000000000000..c1b9ec19db977cb64ba736f194a459a3b5c23b0d --- /dev/null +++ b/declearn/fairness/fairbatch/_messages.py @@ -0,0 +1,53 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Fed-FairBatch/Fed-FB specific messages.""" + +import dataclasses +from typing import List + + +from declearn.messaging import Message + + +__all__ = [ + "FairbatchOkay", + "FairbatchSamplingProbas", +] + + +@dataclasses.dataclass +class FairbatchOkay(Message): + """Message for client signal that Fed-FairBatch/FedFB update went fine.""" + + typekey = "fairbatch-okay" + + +@dataclasses.dataclass +class FairbatchSamplingProbas(Message): + """Message for server-emitted Fed-FairBatch/Fed-FB sampling probabilities. + + Fields + ------ + probas: + List of group-wise sampling probabilities, ordered based on + an agreed-upon sorted list of sensitive groups. + """ + + probas: List[float] + + typekey = "fairbatch-probas" diff --git a/declearn/fairness/fairbatch/_sampling.py b/declearn/fairness/fairbatch/_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..dc2e1f9cc0d09978cffcfbd664ab78eddc4944ef --- /dev/null +++ b/declearn/fairness/fairbatch/_sampling.py @@ -0,0 +1,435 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""FairBatch sampling probability controllers.""" + +import abc +from typing import Any, ClassVar, Dict, List, Literal, Tuple + + +from declearn.fairness.core import instantiate_fairness_function + + +__all__ = [ + "FairbatchSamplingController", + "setup_fairbatch_controller", +] + + +GroupLabel = Literal["0_0", "0_1", "1_0", "1_1"] + + +class FairbatchSamplingController(metaclass=abc.ABCMeta): + """ABC to compute and update Fairbatch sampling probabilities.""" + + f_type: ClassVar[str] + + def __init__( + self, + groups: Dict[GroupLabel, Tuple[Any, ...]], + counts: Dict[Tuple[Any, ...], int], + alpha: float = 0.005, + **kwargs: Any, + ) -> None: + """Instantiate the Fairbatch sampling probabilities controller. + + Parameters + ---------- + groups: + Dict mapping canonical labels to sensitive group definitions. + counts: + Dict mapping sensitive group definitions to sample counts. + alpha: + Hyper-parameter controlling the update rule for internal + states and thereof sampling probabilities. + **kwargs: + Keyword arguments specific to the fairness definition in use. + """ + # Assign input parameters as attributes. + self.groups = groups + self.counts = counts + self.total = sum(counts.values()) + self.alpha = alpha + # Initialize internal states and sampling probabilities. + self.states = self.compute_initial_states() + # Initialize a fairness function. + self.f_func = instantiate_fairness_function( + f_type=self.f_type, counts=counts, **kwargs + ) + + @abc.abstractmethod + def compute_initial_states( + self, + ) -> Dict[str, float]: + """Return a dict containing initial internal states. + + Returns + ------- + states: + Dict associating float values to arbitrary names that + depend on the type of group-fairness being optimized. + """ + + @abc.abstractmethod + def get_sampling_probas( + self, + ) -> Dict[Tuple[Any, ...], float]: + """Return group-wise sampling probabilities. + + Returns + ------- + sampling_probas: + Dict mapping sensitive group definitions to their sampling + probabilities, as establised via the FairBatch algorithm. + """ + + @abc.abstractmethod + def update_from_losses( + self, + losses: Dict[Tuple[Any, ...], float], + ) -> None: + """Update internal states based on group-wise losses. + + Parameters + ---------- + losses: + Group-wise model loss values, as a `{group_k: loss_k}` dict. + """ + + def update_from_federated_losses( + self, + losses: Dict[Tuple[Any, ...], float], + ) -> None: + """Update internal states based on federated group-wise losses. + + Parameters + ---------- + losses: + Group-wise sum-aggregated local-group-count-weighted model + loss values, computed over an ensemble of local datasets. + I.e. `{group_k: sum_i(n_ik * loss_ik)}` dict. + + Raises + ------ + KeyError + If any defined sensitive group does not have a loss value. + """ + losses = {key: val / self.counts[key] for key, val in losses.items()} + self.update_from_losses(losses) + + +class FairbatchEqualityOpportunity(FairbatchSamplingController): + """FairbatchSamplingController subclass for 'equality_of_opportunity'.""" + + f_type = "equality_of_opportunity" + + def compute_initial_states( + self, + ) -> Dict[str, float]: + # Gather sample counts and share with positive target label. + nsmp_10 = self.counts[self.groups["1_0"]] + nsmp_11 = self.counts[self.groups["1_1"]] + p_tgt_1 = (nsmp_10 + nsmp_11) / self.total + # Assign the initial lambda and fixed quantities to re-use. + return { + "lambda": nsmp_10 / self.total, + "p_tgt_1": p_tgt_1, + "p_g_0_0": self.counts[self.groups["0_0"]] / self.total, + "p_g_0_1": self.counts[self.groups["0_1"]] / self.total, + } + + def get_sampling_probas( + self, + ) -> Dict[Tuple[Any, ...], float]: + return { + self.groups["0_0"]: self.states["p_g_0_0"], + self.groups["0_1"]: self.states["p_g_0_1"], + self.groups["1_0"]: self.states["lambda"], + self.groups["1_1"]: self.states["p_tgt_1"] - self.states["lambda"], + } + + def update_from_losses( + self, + losses: Dict[Tuple[Any, ...], float], + ) -> None: + # Gather mean-aggregated losses for the two groups of interest. + loss_10 = losses[self.groups["1_0"]] + loss_11 = losses[self.groups["1_1"]] + # Update lambda based on these and the alpha hyper-parameter. + if loss_10 > loss_11: + self.states["lambda"] = min( + self.states["lambda"] + self.alpha, self.states["p_tgt_1"] + ) + elif loss_10 < loss_11: + self.states["lambda"] = max(self.states["lambda"] - self.alpha, 0) + + +class FairbatchEqualizedOdds(FairbatchSamplingController): + """FairbatchSamplingController subclass for 'equalized_odds'.""" + + f_type = "equalized_odds" + + def compute_initial_states( + self, + ) -> Dict[str, float]: + # Gather sample counts. + nsmp_00 = self.counts[self.groups["0_0"]] + nsmp_01 = self.counts[self.groups["0_1"]] + nsmp_10 = self.counts[self.groups["1_0"]] + nsmp_11 = self.counts[self.groups["1_1"]] + # Compute initial lambas, and attribute-wise sample counts. + return { + "lambda_1": nsmp_00 / self.total, + "lambda_2": nsmp_10 / self.total, + "p_trgt_0": (nsmp_00 + nsmp_01) / self.total, + "p_trgt_1": (nsmp_10 + nsmp_11) / self.total, + } + + def get_sampling_probas( + self, + ) -> Dict[Tuple[Any, ...], float]: + states = self.states + return { + self.groups["0_0"]: states["lambda_1"], + self.groups["0_1"]: states["p_trgt_0"] - states["lambda_1"], + self.groups["1_0"]: states["lambda_2"], + self.groups["1_1"]: states["p_trgt_1"] - states["lambda_2"], + } + + def update_from_losses( + self, + losses: Dict[Tuple[Any, ...], float], + ) -> None: + # Compute loss differences for each target label. + diff_loss_tgt_0 = ( + losses[self.groups["0_0"]] - losses[self.groups["0_1"]] + ) + diff_loss_tgt_1 = ( + losses[self.groups["1_0"]] - losses[self.groups["1_1"]] + ) + # Update a lambda based on these and the alpha hyper-parameter. + if abs(diff_loss_tgt_0) > abs(diff_loss_tgt_1): + if diff_loss_tgt_0 > 0: + self.states["lambda_1"] = min( + self.states["lambda_1"] + self.alpha, + self.states["p_trgt_0"], + ) + elif diff_loss_tgt_0 < 0: + self.states["lambda_1"] = max( + self.states["lambda_1"] - self.alpha, 0 + ) + else: + if diff_loss_tgt_1 > 0: + self.states["lambda_2"] = min( + self.states["lambda_2"] + self.alpha, + self.states["p_trgt_1"], + ) + elif diff_loss_tgt_1 < 0: + self.states["lambda_2"] = max( + self.states["lambda_2"] - self.alpha, 0 + ) + + +class FairbatchDemographicParity(FairbatchSamplingController): + """FairbatchSamplingController subclass for 'demographic_parity'.""" + + f_type = "demographic_parity" + + def compute_initial_states( + self, + ) -> Dict[str, float]: + # Gather sample counts. + nsmp_00 = self.counts[self.groups["0_0"]] + nsmp_01 = self.counts[self.groups["0_1"]] + nsmp_10 = self.counts[self.groups["1_0"]] + nsmp_11 = self.counts[self.groups["1_1"]] + # Compute initial lambas, and target-label-wise sample counts. + return { + "lambda_1": nsmp_00 / self.total, + "lambda_2": nsmp_01 / self.total, + "p_attr_0": (nsmp_00 + nsmp_10) / self.total, + "p_attr_1": (nsmp_01 + nsmp_11) / self.total, + "n_attr_0": nsmp_00 + nsmp_10, + "n_attr_1": nsmp_01 + nsmp_11, + } + + def get_sampling_probas( + self, + ) -> Dict[Tuple[Any, ...], float]: + states = self.states + return { + self.groups["0_0"]: states["lambda_1"], + self.groups["1_0"]: states["p_attr_0"] - states["lambda_1"], + self.groups["0_1"]: states["lambda_2"], + self.groups["1_1"]: states["p_attr_1"] - states["lambda_2"], + } + + def update_from_losses( + self, + losses: Dict[Tuple[Any, ...], float], + ) -> None: + # Recover sum-aggregated losses for each sensitive group. + # Obtain {k: n_k * Sum(loss for all samples in group k)}. + labeled_losses = { + label: losses[group] * self.counts[group] + for label, group in self.groups.items() + } + # Normalize losses based on sensitive attribute counts. + # Obtain {k: sum(loss for samples in k) / n_samples_with_attr}. + labeled_losses["0_0"] /= self.states["n_attr_0"] + labeled_losses["0_1"] /= self.states["n_attr_1"] + labeled_losses["1_0"] /= self.states["n_attr_0"] + labeled_losses["1_1"] /= self.states["n_attr_1"] + # Compute aggregated-loss differences for each target label. + diff_loss_tgt_0 = labeled_losses["0_0"] - labeled_losses["0_1"] + diff_loss_tgt_1 = labeled_losses["1_0"] - labeled_losses["1_1"] + # Update a lambda based on these and the alpha hyper-parameter. + if abs(diff_loss_tgt_0) > abs(diff_loss_tgt_1): + if diff_loss_tgt_0 > 0: + self.states["lambda_1"] = max( + self.states["lambda_1"] - self.alpha, 0 + ) + elif diff_loss_tgt_0 < 0: + self.states["lambda_1"] = min( + self.states["lambda_1"] + self.alpha, + self.states["p_attr_0"], + ) + else: + if diff_loss_tgt_1 > 0: + self.states["lambda_2"] = min( + self.states["lambda_2"] + self.alpha, + self.states["p_attr_1"], + ) + elif diff_loss_tgt_1 < 0: + self.states["lambda_2"] = max( + self.states["lambda_2"] - self.alpha, 0 + ) + + +def assign_sensitive_group_labels( + groups: List[Tuple[Any, ...]], + target: int, +) -> Dict[GroupLabel, Tuple[Any, ...]]: + """Parse sensitive group definitions to match canonical labels. + + Parameters + ---------- + groups: + List of sensitive group definitions, as a list of tuples. + These should be four tuples arising from the intersection + of binary labels (with any actual type). + target: + Value of the target label to treat as positive. + + Returns + ------- + labeled_groups: + Dict mapping canonical labels `"0_0", "0_1", "1_0", "1_1"` + to the input sensitive group definitions. + + Raises + ------ + ValueError + If 'groups' has unproper length, values that do not appear + to be binary, or that do not match the specified 'target'. + """ + # Verify that groups can be identified as crossing two binary labels. + if len(groups) != 4: + raise ValueError( + "FairBatch requires counts over exactly 4 sensitive groups, " + "arising from a binary target label and a binary sensitive " + "attribute." + ) + target_values = list({group[0] for group in groups}) + s_attr_values = sorted(list({group[1] for group in groups})) + if not len(target_values) == len(s_attr_values) == 2: + raise ValueError( + "FairBatch requires sensitive groups to arise from a binary " + "target label and a binary sensitive attribute." + ) + # Identify the positive and negative label values. + if target_values[0] == target: + postgt, negtgt = target_values + elif target_values[1] == target: + negtgt, postgt = target_values + else: + raise ValueError( + f"Received a target value of '{target}' that does not match any " + f"value in the sensitive group definitions: {target_values}." + ) + # Match group definitions with canonical string labels. + return { + "0_0": (negtgt, s_attr_values[0]), + "0_1": (negtgt, s_attr_values[1]), + "1_0": (postgt, s_attr_values[0]), + "1_1": (postgt, s_attr_values[1]), + } + + +def setup_fairbatch_controller( + f_type: str, + counts: Dict[Tuple[Any, ...], int], + target: int = 1, + alpha: float = 0.005, +) -> FairbatchSamplingController: + """Instantiate a FairBatch sampling probabilities controller. + + Parameters + ---------- + f_type: + Type of group fairness to optimize for. + counts: + Dict mapping sensitive group definitions to their total + sample counts (across clients). These groups must arise + from the crossing of a binary target label and a binary + sensitive attribute. + target: + Target label to treat as positive. + alpha: + Alpha hyper-parameter, scaling the magnitude of sampling + probabilities' updates by the returned controller. + + Returns + ------- + controller: + FairBatch sampling probabilities controller matching inputs. + + Raises + ------ + KeyError + If `f_type` does not match any known or supported fairness type. + ValueError + If `counts` keys cannot be matched to canonical group labels. + """ + controller_types = { + "demographic_parity": FairbatchDemographicParity, + "equality_of_opportunity": FairbatchEqualityOpportunity, + "equalized_odds": FairbatchEqualizedOdds, + } + controller_cls = controller_types.get(f_type, None) + if controller_cls is None: + raise KeyError( + "Unknown or unsupported fairness type parameter for FairBatch " + f"controller initialization: '{f_type}'. Supported values are " + f"{list(controller_types)}." + ) + # Match groups to canonical labels and instantiate the controller. + groups = assign_sensitive_group_labels(groups=list(counts), target=target) + kwargs = {"target": target} if f_type == "equality_of_opportunity" else {} + return controller_cls( # type: ignore[abstract] + groups=groups, counts=counts, alpha=alpha, **kwargs + ) diff --git a/declearn/fairness/fairbatch/_server.py b/declearn/fairness/fairbatch/_server.py new file mode 100644 index 0000000000000000000000000000000000000000..4e1c99af7144f23a029945e7118ca5168fdf1b0d --- /dev/null +++ b/declearn/fairness/fairbatch/_server.py @@ -0,0 +1,182 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Server-side Fed-FairBatch/FedFB controller.""" + +import warnings +from typing import Any, Dict, List, Optional, Union + +import numpy as np + +from declearn.aggregator import Aggregator, SumAggregator +from declearn.communication.api import NetworkServer +from declearn.communication.utils import verify_client_messages_validity +from declearn.fairness.api import FairnessControllerServer +from declearn.fairness.fairbatch._fedfb import setup_fedfb_controller +from declearn.fairness.fairbatch._messages import ( + FairbatchOkay, + FairbatchSamplingProbas, +) +from declearn.fairness.fairbatch._sampling import setup_fairbatch_controller +from declearn.messaging import FairnessSetupQuery +from declearn.secagg.api import Decrypter + + +__all__ = [ + "FairbatchControllerServer", +] + + +class FairbatchControllerServer(FairnessControllerServer): + """Server-side controller to implement Fed-FairBatch or FedFB. + + References + ---------- + - [1] + Roh et al. (2020). + FairBatch: Batch Selection for Model Fairness. + https://arxiv.org/abs/2012.01696 + - [2] + Zeng et al. (2022). + Improving Fairness via Federated Learning. + https://arxiv.org/abs/2110.15545 + """ + + algorithm = "fed-fairbatch" + + def __init__( + self, + f_type: str, + f_args: Optional[Dict[str, Any]] = None, + alpha: float = 0.005, + fedfb: bool = False, + ) -> None: + """Instantiate the server-side Fed-FairGrad controller. + + Parameters + ---------- + f_type: + Name of the fairness function to evaluate and optimize. + f_args: + Optional dict of keyword arguments to the fairness function. + alpha: + Hyper-parameter controlling the update rule for internal + states and thereof sampling probabilities. + fedfb: + Whether to use FedFB formulas rather than to stick + to those from the original FairBatch paper. + """ + super().__init__(f_type=f_type, f_args=f_args) + # Choose whether to use FedFB or FairBatch update rules. + self._setup_function = ( + setup_fedfb_controller if fedfb else setup_fairbatch_controller + ) + # Set up a temporary controller that will be replaced at setup time. + self.sampling_controller = self._setup_function( + f_type=self.f_type, + counts={(0, 0): 1, (0, 1): 1, (1, 0): 1, (1, 1): 1}, + target=self.f_args.get("target", 1), + alpha=alpha, + ) + + @property + def fedfb(self) -> bool: + """Whether this controller implements FedFB rather than Fed-FairBatch. + + FedFB is a published adaptation of FairBatch to the federated + setting, that introduces changes to some FairBatch formulas. + + Fed-FairBatch is a DecLearn-introduced variant of FedFB that + restores the original FairBatch formulas. + """ + return self._setup_function is setup_fedfb_controller + + def prepare_fairness_setup_query( + self, + ) -> FairnessSetupQuery: + query = super().prepare_fairness_setup_query() + query.params.update({"f_type": self.f_type, "f_args": self.f_args}) + return query + + async def finalize_fairness_setup( + self, + netwk: NetworkServer, + counts: List[int], + aggregator: Aggregator, + ) -> Aggregator: + # Set up the FairbatchWeightsController. + self.sampling_controller = self._setup_function( + f_type=self.f_type, + counts=dict(zip(self.groups, counts)), + target=self.f_args.get("target", 1), + alpha=self.sampling_controller.alpha, + ) + # Send initial loss weights to the clients. + await self._send_fairbatch_probas(netwk) + # Force the use of a SumAggregator. + if not isinstance(aggregator, SumAggregator): + warnings.warn( + "Overriding Aggregator choice to a 'SumAggregator', " + "due to the use of Fed-FairBatch.", + category=RuntimeWarning, + ) + aggregator = SumAggregator() + return aggregator + + async def _send_fairbatch_probas( + self, + netwk: NetworkServer, + ) -> None: + """Send FairBatch sensitive group sampling probabilities to clients. + + Await for clients to ping back that things went fine on their side. + """ + netwk.logger.info( + "Sending FairBatch sampling probabilities to clients." + ) + probas = self.sampling_controller.get_sampling_probas() + p_list = [probas[group] for group in self.groups] + await netwk.broadcast_message(FairbatchSamplingProbas(p_list)) + received = await netwk.wait_for_messages() + await verify_client_messages_validity( + netwk, received, expected=FairbatchOkay + ) + + async def finalize_fairness_round( + self, + round_i: int, + values: List[float], + netwk: NetworkServer, + secagg: Optional[Decrypter], + ) -> Dict[str, Union[float, np.ndarray]]: + # Unpack group-wise accuracy and loss values. + accuracy = dict(zip(self.groups, values[: len(self.groups)])) + loss = dict(zip(self.groups, values[len(self.groups) :])) + # Update sampling probabilities and send them to clients. + self.sampling_controller.update_from_federated_losses(loss) + await self._send_fairbatch_probas(netwk) + # Package and return accuracy, loss and fairness metrics. + metrics = { + f"accuracy_{key}": val for key, val in accuracy.items() + } # type: Dict[str, Union[float, np.ndarray]] + metrics.update({f"loss_{key}": val for key, val in loss.items()}) + f_func = self.sampling_controller.f_func + fairness = f_func.compute_from_federated_group_accuracy(accuracy) + metrics.update( + {f"{self.f_type}_{key}": val for key, val in fairness.items()} + ) + return metrics