diff --git a/declearn/__init__.py b/declearn/__init__.py
index e1c51560184af1eed71b79fa8c02de4c1c34de7d..0e52c31b2cab3e2cb4cf6e978c54270d9a90095a 100644
--- a/declearn/__init__.py
+++ b/declearn/__init__.py
@@ -38,6 +38,8 @@ The package is organized into the following submodules:
     Tools to write and extend shareable metadata fields specifications.
 * [dataset][declearn.dataset]:
     Data interfacing API and implementations.
+* [fairness][declearn.fairness]:
+    Processes and components for fairness-aware federated learning.
 * [main][declearn.main]:
     Main classes implementing a Federated Learning process.
 * [messaging][declearn.messaging]:
@@ -50,6 +52,8 @@ The package is organized into the following submodules:
     Framework-agnostic optimizer and algorithmic plug-ins API and tools.
 * [secagg][declearn.secagg]:
     Secure Aggregation API, methods and utils.
+* [training][declearn.training]:
+    Model training and evaluation orchestration tools.
 * [typing][declearn.typing]:
     Type hinting utils, defined and exposed for code readability purposes.
 * [utils][declearn.utils]:
@@ -63,12 +67,14 @@ from . import (
     communication,
     data_info,
     dataset,
+    fairness,
    main,
     metrics,
     messaging,
     model,
     optimizer,
     secagg,
+    training,
     typing,
     utils,
     version,
diff --git a/declearn/aggregator/__init__.py b/declearn/aggregator/__init__.py
index a0fdfaab69535e46f28e1fc326a773d3c8eb6cc3..7b600bd8fbd2b4358b181d9839989c23701faf16 100644
--- a/declearn/aggregator/__init__.py
+++ b/declearn/aggregator/__init__.py
@@ -43,8 +43,11 @@ Concrete classes
     Average-based-aggregation Aggregator subclass.
 * [GradientMaskedAveraging][declearn.aggregator.GradientMaskedAveraging]:
     Gradient Masked Averaging Aggregator subclass.
+* [SumAggregator][declearn.aggregator.SumAggregator]:
+    Sum-aggregation Aggregator subclass.
 """
 
 from ._api import Aggregator, ModelUpdates, list_aggregators
 from ._avg import AveragingAggregator
 from ._gma import GradientMaskedAveraging
+from ._sum import SumAggregator
diff --git a/declearn/aggregator/_sum.py b/declearn/aggregator/_sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..396dba01bacb5ab6363185d2996d99d288c418ae
--- /dev/null
+++ b/declearn/aggregator/_sum.py
@@ -0,0 +1,53 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Sum-aggregation Aggregator subclass."""
+
+
+from declearn.aggregator._api import Aggregator, ModelUpdates
+from declearn.model.api import Vector
+
+
+__all__ = [
+    "SumAggregator",
+]
+
+
+class SumAggregator(Aggregator[ModelUpdates]):
+    """Sum-aggregation Aggregator subclass.
+
+    This class implements the mere summation of client-wise model
+    updates. It is therefore targeted at algorithms that perform
+    some processing on model updates (e.g. via sample weights) so
+    that mere summation is the proper way to recover gradients of
+    the global model.
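+
+    A minimal usage sketch is given below. Values are made up, and the
+    `NumpyVector` import path is an assumption for illustration only:
+
+    >>> import numpy as np
+    >>> from declearn.aggregator import SumAggregator
+    >>> from declearn.model.sklearn import NumpyVector
+    >>> agg = SumAggregator()
+    >>> # Client side: wrap local (already-weighted) model updates.
+    >>> upd_a = agg.prepare_for_sharing(
+    ...     NumpyVector({"w": np.array([1.0, 2.0])}), n_steps=10
+    ... )
+    >>> upd_b = agg.prepare_for_sharing(
+    ...     NumpyVector({"w": np.array([3.0, 4.0])}), n_steps=20
+    ... )
+    >>> # Server side: sum-aggregate and recover the global updates.
+    >>> total = agg.finalize_updates(upd_a + upd_b)
+    >>> # 'total' then wraps the coefficient-wise sum [4.0, 6.0].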
+    """
+
+    name = "sum"
+
+    def prepare_for_sharing(
+        self,
+        updates: Vector,
+        n_steps: int,
+    ) -> ModelUpdates:
+        return ModelUpdates(updates=updates, weights=1)
+
+    def finalize_updates(
+        self,
+        updates: ModelUpdates,
+    ) -> Vector:
+        return updates.updates
diff --git a/declearn/fairness/__init__.py b/declearn/fairness/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc283d7d4669c8178c69440c7fdf5117db5d99eb
--- /dev/null
+++ b/declearn/fairness/__init__.py
@@ -0,0 +1,104 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Processes and components for fairness-aware federated learning.
+
+Introduction
+------------
+
+This module provides a general API and some specific algorithms
+to measure and enforce group fairness as part of a federated learning
+process in DecLearn.
+
+Group fairness refers to a setting where a classifier is trained over
+data that can be split between various subsets based on one or more
+categorical sensitive attributes, usually combined with the target
+label. In such a setting, the model's fairness is defined and evaluated
+by comparing its accuracy over the various subgroups, using one of the
+various definitions proposed in the literature.
+
+The algorithms and shared API implemented in this module consider that
+the fairness being measured (and optimized) is to be computed over the
+union of all training datasets held by clients. The API is designed to
+be compatible with any number of sensitive groups, with regimes where
+individual clients do not necessarily hold samples from each and every
+group, and with all group fairness definitions that can be expressed
+in the canonical form introduced in paper [1]. However, some restrictions
+may be enforced by concrete algorithms, in alignment with those set by
+their original authors.
+
+Currently, concrete algorithms include:
+
+- Fed-FairGrad, adapted from [1]
+- Fed-FairBatch, adapted from [2], and the FedFB variant based on [3]
+- FairFed, based on [4]
+
+In addition, a "monitor-only" algorithm is provided, that merely uses
+the shared API to measure client-wise and global fairness throughout
+training without altering the training algorithm.
+
+
+API-defining and core submodules
+--------------------------------
+
+* [api][declearn.fairness.api]:
+    Abstract and base components for fairness-aware federated learning.
+* [core][declearn.fairness.core]:
+    Built-in concrete components for fairness-aware federated learning.
+
+Algorithms submodules
+---------------------
+
+* [fairbatch][declearn.fairness.fairbatch]:
+    Fed-FairBatch / FedFB algorithm controllers and utils.
+* [fairfed][declearn.fairness.fairfed]:
+    FairFed algorithm controllers and utils.
+* [fairgrad][declearn.fairness.fairgrad]:
+    Fed-FairGrad algorithm controllers and utils.
+* [monitor][declearn.fairness.monitor]:
+    Fairness-monitoring controllers, that leave training unaltered.
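+
+As a hedged illustration, a server-side controller for one of these
+algorithms may be instantiated generically; the algorithm and function
+names below are assumptions derived from the registration rules stated
+in this module, not verified values:
+
+    from declearn.fairness.api import FairnessControllerServer
+
+    # Controllers are registered under their submodule's name.
+    controller = FairnessControllerServer.from_specs(
+        algorithm="fairgrad",
+        f_type="demographic_parity",
+    )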
+
+Note that the controllers implemented under these submodules
+are type-registered under the submodule's name.
+
+References
+----------
+
+- [1]
+    Maheshwari & Perrot (2023).
+    FairGrad: Fairness Aware Gradient Descent.
+    https://openreview.net/forum?id=0f8tU3QwWD
+- [2]
+    Roh et al. (2020).
+    FairBatch: Batch Selection for Model Fairness.
+    https://arxiv.org/abs/2012.01696
+- [3]
+    Zeng et al. (2022).
+    Improving Fairness via Federated Learning.
+    https://arxiv.org/abs/2110.15545
+- [4]
+    Ezzeldin et al. (2021).
+    FairFed: Enabling Group Fairness in Federated Learning.
+    https://arxiv.org/abs/2110.00857
+"""
+
+from . import api
+from . import core
+from . import fairbatch
+from . import fairfed
+from . import fairgrad
+from . import monitor
diff --git a/declearn/fairness/api/__init__.py b/declearn/fairness/api/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3ff82f4563ba193030fbd24c218cd7c76cab281
--- /dev/null
+++ b/declearn/fairness/api/__init__.py
@@ -0,0 +1,57 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Abstract and base components for fairness-aware federated learning.
+
+Endpoint Controller ABCs
+------------------------
+
+* [FairnessControllerClient][declearn.fairness.api.FairnessControllerClient]:
+    Abstract base class for client-side fairness controllers.
+* [FairnessControllerServer][declearn.fairness.api.FairnessControllerServer]:
+    Abstract base class for server-side fairness controllers.
+
+Group-fairness functions
+------------------------
+API-defining ABC and generic constructor:
+
+* [FairnessFunction][declearn.fairness.api.FairnessFunction]:
+    Abstract base class for group-fairness functions.
+* [instantiate_fairness_function]\
+[declearn.fairness.api.instantiate_fairness_function]:
+    Instantiate a FairnessFunction from its specifications.
+
+Built-in concrete implementations may be found in [declearn.fairness.core][].
+
+Dataset subclass
+----------------
+
+* [FairnessDataset][declearn.fairness.api.FairnessDataset]:
+    Abstract base class for Fairness-aware `Dataset` interfaces.
+
+Backend
+-------
+
+* [FairnessMetricsComputer][declearn.fairness.api.FairnessMetricsComputer]:
+    Utility dataset-handler to compute group-wise evaluation metrics.
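+
+A hedged usage sketch of the generic constructor follows; the group
+definitions, counts and accuracy values are made up, and the
+'accuracy_parity' name is an assumed registration name:
+
+    from declearn.fairness.api import instantiate_fairness_function
+
+    # Instantiate a registered fairness function from group-wise counts.
+    func = instantiate_fairness_function(
+        f_type="accuracy_parity",
+        counts={(0, "a"): 60, (1, "a"): 40, (0, "b"): 70, (1, "b"): 30},
+    )
+    # Compute group-wise fairness scores from group-wise accuracy values.
+    scores = func.compute_from_group_accuracy(
+        {(0, "a"): 0.9, (1, "a"): 0.8, (0, "b"): 0.85, (1, "b"): 0.7}
+    )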
+""" + +from ._dataset import FairnessDataset +from ._fair_func import FairnessFunction, instantiate_fairness_function +from ._metrics import FairnessMetricsComputer +from ._client import FairnessControllerClient +from ._server import FairnessControllerServer diff --git a/declearn/fairness/api/_client.py b/declearn/fairness/api/_client.py new file mode 100644 index 0000000000000000000000000000000000000000..f9ce30bb5dd9dbe8c614ede974909165ade280a0 --- /dev/null +++ b/declearn/fairness/api/_client.py @@ -0,0 +1,479 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Client-side ABC for fairness-aware federated learning controllers.""" + +import abc +from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union + +import numpy as np + +from declearn.communication.api import NetworkClient +from declearn.communication.utils import verify_server_message_validity +from declearn.fairness.api._dataset import FairnessDataset +from declearn.fairness.api._fair_func import instantiate_fairness_function +from declearn.messaging import ( + Error, + FairnessCounts, + FairnessGroups, + FairnessQuery, + FairnessReply, + FairnessSetupQuery, +) +from declearn.fairness.api._metrics import FairnessMetricsComputer +from declearn.metrics import MeanMetric +from declearn.secagg.api import Encrypter +from declearn.secagg.messaging import ( + SecaggFairnessCounts, + SecaggFairnessReply, +) +from declearn.training import TrainingManager +from declearn.utils import ( + access_registered, + create_types_registry, + register_type, +) + +__all__ = [ + "FairnessControllerClient", +] + + +@create_types_registry(name="FairnessControllerClient") +class FairnessControllerClient(metaclass=abc.ABCMeta): + """Abstract base class for client-side fairness controllers. + + Usage + ----- + A `FairnessControllerClient` (subclass) instance has two main + routines that are to be called as part of a federated learning + process, in addition to a static method from the base API class: + + - `from_setup_query`: + This is a static method that can be called generically from + the base `FairnessControllerClient` type to instantiate a + controller from a server-emitted `FairnessSetupQuery`. + - `setup_fairness`: + This routine is to be called only once, after instantiating + from a `FairnessSetupQuery`. It triggers the following process: + - Run a basic routine to exchange sensitive group definitions + and associated (encrypted) sample counts. + - Perform any additional algorithm-specific setup actions. + - `run_fairness_round`: + This routine is to be called once per round, before the next + training round occurs, upon receiving a `FairnessQuery` from + the server. It triggers the following process: + - Run a basic routine to compute fairness-related metrics + and send (some of) their (encrypted) values to the server. + - Perform any additonal algorithm-specific round actions. 
+ + Inheritance + ----------- + Algorithm-specific subclasses should define the following abstract + attribute and methods: + + - `algorithm`: + Abstract string class attribute. Name under which this controller + and its server-side counterpart classes are registered. + - `finalize_fairness_setup`: + Method implementing any algorithm-specific setup actions. + - `finalize_fairness_round`: + Method implementing any algorithm-specific round actions. + + Additionally, they may overload or override the following method: + + - `setup_fairness_metrics`: + Method that defines metrics being computed as part of fairness + rounds. By default, group-wise accuracy values are computed and + shared with the server, and the local fairness is computed from + them (but not sent to the server). + + By default, subclasses are type-registered under their `algorithm` + name and "FairnessControllerClient" group upon declaration. This can + be prevented by passing `register=False` to the inheritance parameters + (e.g. `class Cls(FairnessControllerClient, register=False)`). + See `declearn.utils.register_type` for details on types registration. + """ + + algorithm: ClassVar[str] + """Name of the fairness-enforcing algorithm. + + This name should be unique across 'FairnessControllerClient' classes, + and shared with a unique paired 'FairnessControllerServer'. It is used + for type-registration and to enable instantiating a client controller + based on server-emitted instructions in a federated setting. + """ + + def __init_subclass__( + cls, + register: bool = True, + ) -> None: + """Automatically type-register subclasses.""" + if register: + register_type(cls, cls.algorithm, group="FairnessControllerClient") + + def __init__( + self, + manager: TrainingManager, + f_type: str, + f_args: Dict[str, Any], + ) -> None: + """Instantiate the client-side fairness controller. + + Parameters + ---------- + manager: + `TrainingManager` instance wrapping the model being trained + and its training dataset (that must be a `FairnessDataset`). + f_type: + Name of the type of group-fairness function being optimized. + f_args: + Keyword arguments to the group-fairness function. + """ + if not isinstance(manager.train_data, FairnessDataset): + raise TypeError( + "Cannot set up fairness without a 'FairnessDataset' " + "as training dataset." + ) + self.manager = manager + self.computer = FairnessMetricsComputer(manager.train_data) + self.fairness_function = instantiate_fairness_function( + f_type=f_type, counts=self.computer.counts, **f_args + ) + self.groups = [] # type: List[Tuple[Any, ...]] + + @staticmethod + def from_setup_query( + query: FairnessSetupQuery, + manager: TrainingManager, + ) -> "FairnessControllerClient": + """Instantiate a controller from a server-emitted query. + + Parameters + ---------- + query: + `FairnessSetupQuery` received from the server. + manager: + `TrainingManager` wrapping the model to train. + + Returns + ------- + controller: + `FairnessControllerClient` instance, the type and parameters + of which depend on the input `query`, that wraps `manager`. + """ + try: + cls = access_registered( + name=query.algorithm, group="FairnessControllerClient" + ) + assert issubclass(cls, FairnessControllerClient) + except Exception as exc: + raise ValueError( + "Failed to retrieve a 'FairnessControllerClient' class " + "matching the input 'FairnessSetupQuery' message." 
+            ) from exc
+        return cls(manager=manager, **query.params)
+
+    async def setup_fairness(
+        self,
+        netwk: NetworkClient,
+        secagg: Optional[Encrypter],
+    ) -> None:
+        """Participate in a routine to initialize fairness-aware learning.
+
+        This routine has the following structure:
+
+        - Exchange with the server to agree on an ordered list of sensitive
+          groups defined by the intersection of 1+ sensitive attributes
+          and (opt.) a classification target label.
+        - Send (encrypted) group-wise training sample counts, that the server
+          is to (secure-)aggregate.
+        - Perform any additional actions specific to the algorithm in use.
+          - On the client side, optionally alter the `TrainingManager` used.
+          - On the server side, optionally alter the `Aggregator` used.
+
+        Parameters
+        ----------
+        netwk:
+            NetworkClient endpoint, registered to a server.
+        secagg:
+            Optional SecAgg encryption controller.
+        """
+        # Agree on a list of sensitive groups and share local sample counts.
+        await self.exchange_sensitive_groups_list_and_counts(netwk, secagg)
+        # Run additional algorithm-specific setup steps.
+        await self.finalize_fairness_setup(netwk, secagg)
+
+    async def exchange_sensitive_groups_list_and_counts(
+        self,
+        netwk: NetworkClient,
+        secagg: Optional[Encrypter],
+    ) -> None:
+        """Agree on a list of sensitive groups and share local sample counts.
+
+        This method performs the following routine:
+
+        - Send the list of local sensitive group definitions to the server.
+        - Await a unified list of sensitive groups in return.
+        - Assign the received list as `groups` attribute.
+        - Send (optionally-encrypted) group-wise sample counts to the server.
+
+        Parameters
+        ----------
+        netwk:
+            `NetworkClient` endpoint, connected to a server.
+        secagg:
+            Optional SecAgg encryption controller.
+        """
+        # Share sensitive group definitions and receive an ordered list.
+        self.groups = await self._exchange_sensitive_groups_list(netwk)
+        # Send group-wise sample counts for the server to (secure-)aggregate.
+        await self._send_sensitive_groups_counts(netwk, secagg)
+
+    async def _exchange_sensitive_groups_list(
+        self,
+        netwk: NetworkClient,
+    ) -> List[Tuple[Any, ...]]:
+        """Exchange sensitive group definitions and return a unified list."""
+        # Gather local sensitive groups and their sample counts.
+        counts = self.computer.counts
+        groups = list(counts)
+        # Share them and receive a unified, ordered list of groups.
+        await netwk.send_message(FairnessGroups(groups=groups))
+        received = await netwk.recv_message()
+        message = await verify_server_message_validity(
+            netwk, received, expected=FairnessGroups
+        )
+        return message.groups
+
+    async def _send_sensitive_groups_counts(
+        self,
+        netwk: NetworkClient,
+        secagg: Optional[Encrypter],
+    ) -> None:
+        """Send (opt. encrypted) group-wise sample counts to the server."""
+        counts = self.computer.counts
+        reply = FairnessCounts([counts.get(group, 0) for group in self.groups])
+        if secagg is None:
+            await netwk.send_message(reply)
+        else:
+            await netwk.send_message(
+                SecaggFairnessCounts.from_cleartext_message(reply, secagg)
+            )
+
+    @abc.abstractmethod
+    async def finalize_fairness_setup(
+        self,
+        netwk: NetworkClient,
+        secagg: Optional[Encrypter],
+    ) -> None:
+        """Finalize the fairness setup routine.
+
+        This method is called as part of `setup_fairness`, and should
+        be defined by concrete subclasses to implement setup behavior
+        once the initial exchange of sensitive group definitions and
+        sample counts has been performed.
+ + Parameters + ---------- + netwk: + NetworkClient endpoint, registered to a server. + secagg: + Optional SecAgg encryption controller. + """ + + async def run_fairness_round( + self, + netwk: NetworkClient, + query: FairnessQuery, + secagg: Optional[Encrypter], + ) -> Dict[str, Union[float, np.ndarray]]: + """Participate in a round of actions to enforce fairness. + + Parameters + ---------- + netwk: + NetworkClient endpoint instance, connected to a server. + query: + `FairnessQuery` message to participate in a fairness round. + secagg: + Optional SecAgg encryption controller. + + Returns + ------- + metrics: + Fairness(-related) metrics computed as part of this routine, + as a `{name: value}` dict with scalar or numpy array values. + """ + try: + values = await self._compute_and_share_fairness_measures( + netwk, query, secagg + ) + except Exception as exc: + error = f"Error encountered in fairness round: {repr(exc)}" + self.manager.logger.error(error) + await netwk.send_message(Error(error)) + raise RuntimeError(error) from exc + # Run additional algorithm-specific steps. + return await self.finalize_fairness_round(netwk, secagg, values) + + async def _compute_and_share_fairness_measures( + self, + netwk: NetworkClient, + query: FairnessQuery, + secagg: Optional[Encrypter], + ) -> Dict[str, Dict[Tuple[Any, ...], float]]: + """Compute, share (encrypted) and return fairness measures.""" + # Optionally update the wrapped model's weights. + if query.weights is not None: + self.manager.model.set_weights(query.weights, trainable=True) + # Compute some fairness-related values, split between two sets. + share_values, local_values = self.compute_fairness_measures( + query.batch_size, query.n_batch, query.thresh + ) + # Share the first set of values for their (secure-)aggregation. + reply = FairnessReply(values=share_values) + if secagg is None: + await netwk.send_message(reply) + else: + await netwk.send_message( + SecaggFairnessReply.from_cleartext_message(reply, secagg) + ) + # Return the second set of values. + return local_values + + def compute_fairness_measures( + self, + batch_size: int, + n_batch: Optional[int] = None, + thresh: Optional[float] = None, + ) -> Tuple[List[float], Dict[str, Dict[Tuple[Any, ...], float]]]: + """Compute fairness measures based on a received query. + + By default, compute and return group-wise accuracy metrics, + weighted by group-wise sample counts. This may be modified + by algorithm-specific subclasses depending on algorithms' + needs. + + Parameters + ---------- + batch_size: + Number of samples per batch when computing predictions. + n_batch: + Optional maximum number of batches to draw per category. + If None, use the entire wrapped dataset. + thresh: + Optional binarization threshold for binary classification + models' output scores. If None, use 0.5 by default, or 0.0 + for `SklearnSGDModel` instances. + Unused for multinomial classifiers (argmax over scores). + + Returns + ------- + share_values: + Values that are to be shared with the orchestrating server, + as a deterministic-length list of float values. + local_values: + Values that are to be used in local post-processing steps, + as a nested dictionary of group-wise metrics. + """ + # Compute group-wise metrics. + metrics = self.setup_fairness_metrics(thresh=thresh) + local_values = self.computer.compute_groupwise_metrics( + metrics=metrics, + model=self.manager.model, + batch_size=batch_size, + n_batch=n_batch, + ) + # Gather sample-counts-scaled values to share with the server. 
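+        # (Scaling by local counts makes client-wise values sum-aggregatable
+        # across clients, hence SecAgg-compatible: the server may divide the
+        # summed values by total group counts to recover global averages.)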
+        scaled_values = {
+            key: self.computer.scale_metrics_by_sample_counts(val)
+            for key, val in local_values.items()
+        }
+        share_values = [
+            scaled_values[key].get(group, 0.0)
+            for key in sorted(scaled_values)
+            for group in self.groups
+        ]
+        # Compute group-wise local fairness measures.
+        if "accuracy" in local_values:
+            fairness = self.fairness_function.compute_from_group_accuracy(
+                local_values["accuracy"]
+            )
+            local_values[self.fairness_function.f_type] = fairness
+        # Return both shareable and local values.
+        return share_values, local_values
+
+    def setup_fairness_metrics(
+        self,
+        thresh: Optional[float] = None,
+    ) -> List[MeanMetric]:
+        """Set up metrics to be computed group-wise and shared with the server.
+
+        By default, this method returns a single accuracy-computation
+        metric. It may be overloaded to compute additional metrics
+        depending on the needs of the fairness-enforcing algorithm
+        being implemented.
+
+        Parameters
+        ----------
+        thresh:
+            Optional binarization threshold for binary classification
+            models' output scores. Used to set up accuracy computations.
+
+        Returns
+        -------
+        metrics:
+            List of `MeanMetric` instances, that each compute a unique
+            scalar float metric (per sensitive group) and have distinct
+            names.
+        """
+        accuracy = self.computer.setup_accuracy_metric(
+            self.manager.model, thresh=thresh
+        )
+        return [accuracy]
+
+    @abc.abstractmethod
+    async def finalize_fairness_round(
+        self,
+        netwk: NetworkClient,
+        secagg: Optional[Encrypter],
+        values: Dict[str, Dict[Tuple[Any, ...], float]],
+    ) -> Dict[str, Union[float, np.ndarray]]:
+        """Take actions to enforce fairness.
+
+        This method is designed to be called after an initial query
+        has been received and responded to, resulting in computing
+        and sharing fairness(-related) metrics.
+
+        Parameters
+        ----------
+        netwk:
+            NetworkClient endpoint instance, connected to a server.
+        secagg:
+            Optional SecAgg encryption controller.
+        values:
+            Nested dict of locally-computed group-wise metrics.
+            This is the second set of `compute_fairness_measures` return
+            values; when this method is called, the first has already
+            been shared with the server for (secure-)aggregation.
+
+        Returns
+        -------
+        metrics:
+            Computed fairness(-related) metrics to checkpoint, as a
+            `{name: value}` dict with scalar or numpy array values.
+        """
diff --git a/declearn/fairness/api/_dataset.py b/declearn/fairness/api/_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e4af5386cbc050a9afec507f6b30d886fe04573
--- /dev/null
+++ b/declearn/fairness/api/_dataset.py
@@ -0,0 +1,124 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Fairness-aware Dataset abstract base subclass."""
+
+from abc import ABCMeta, abstractmethod
+from typing import Any, Dict, List, Tuple
+
+
+from declearn.dataset import Dataset
+
+
+__all__ = [
+    "FairnessDataset",
+]
+
+
+class FairnessDataset(Dataset, metaclass=ABCMeta):
+    """Abstract base class for Fairness-aware Dataset interfaces.
+
+    This `declearn.dataset.Dataset` abstract subclass adds API methods
+    related to group fairness to the base dataset API. These revolve
+    around accessing sensitive group definitions, sample counts and
+    dataset subsets. They further add the possibility to modify samples'
+    weights based on the sensitive group to which they belong.
+    """
+
+    @abstractmethod
+    def get_sensitive_group_definitions(
+        self,
+    ) -> List[Tuple[Any, ...]]:
+        """Return the exhaustive list of sensitive groups for this dataset.
+
+        Returns
+        -------
+        groups:
+            List of tuples of values that define a sensitive group, as
+            the intersection of one or more sensitive attributes, and
+            the model's target when defined.
+        """
+
+    @abstractmethod
+    def get_sensitive_group_counts(
+        self,
+    ) -> Dict[Tuple[Any, ...], int]:
+        """Return sensitive attributes' combinations and counts.
+
+        Returns
+        -------
+        values:
+            Dict holding the number of samples for each and every sensitive
+            group. Its keys are tuples holding the values of the attributes
+            that define the sensitive groups (based on their intersection).
+        """
+
+    @abstractmethod
+    def get_sensitive_group_subset(
+        self,
+        group: Tuple[Any, ...],
+    ) -> Dataset:
+        """Return samples that belong to a given sensitive group.
+
+        Parameters
+        ----------
+        group:
+            Tuple of values that define a sensitive group (as samples that
+            have these values as sensitive attributes and/or target label).
+
+        Returns
+        -------
+        dataset:
+            `Dataset` instance, holding the subset of samples that belong
+            to the specified sensitive group.
+
+        Raises
+        ------
+        KeyError
+            If the specified group does not match any sensitive group
+            defined for this instance.
+        """
+
+    @abstractmethod
+    def set_sensitive_group_weights(
+        self,
+        weights: Dict[Tuple[Any, ...], float],
+        adjust_by_counts: bool = False,
+    ) -> None:
+        """Assign weights associated with samples' sensitive group membership.
+
+        This method updates the sample weights yielded by this dataset's
+        `generate_batches` method, to become the product of the raw sample
+        weights and the values associated with the sensitive attributes.
+
+        Parameters
+        ----------
+        weights:
+            Dict associating weights with tuples of values characterizing the
+            sensitive groups they are to apply to.
+        adjust_by_counts:
+            Whether to multiply input group weights `w_k` by the number of
+            samples for their group. This is notably useful in federated
+            contexts, where `weights` may in fact be input as `w_k^t / n_k`
+            and thereby adjusted to `w_k^t * n_{i,k} / n_k`.
+
+        Raises
+        ------
+        KeyError
+            If not all local sensitive groups have weights defined as part
+            of the inputs.
+        """
diff --git a/declearn/fairness/api/_fair_func.py b/declearn/fairness/api/_fair_func.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a042fc3956590c090f81ed30c392e563fedc14d
--- /dev/null
+++ b/declearn/fairness/api/_fair_func.py
@@ -0,0 +1,271 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ABC and generic constructor for group-fairness functions.""" + +import abc +import functools +from typing import Any, ClassVar, Dict, List, Tuple + +import numpy as np + +from declearn.utils import ( + access_registered, + create_types_registry, + register_type, +) + +__all__ = [ + "FairnessFunction", + "instantiate_fairness_function", +] + + +@create_types_registry(name="FairnessFunction") +class FairnessFunction(metaclass=abc.ABCMeta): + """Abstract base class for group-fairness functions. + + Group Fairness + -------------- + + This abstract base class defines a common API and shared backend code + to compute group-wise fairness metrics for group-fairness definitions + that can be written using the following canonical form, introduced by + Maheshwari & Perrot (2023) [1]: + + $$ F_k(h, T) = C_k^0 + \\sum_{k'} C_k^{k'} P(h(x) \\neq y | T_{k'}) $$ + + Where $F_k$ is the fairness metric associated with sensitive group $k$, + $h$ is the evaluated classifier, $T$ denotes a data distribution (that + is approximated based on an empirical dataset) and $T_k$ denotes the + distribution of samples belonging to group $k$. + + Scope + ----- + + This class implements the formula above based on empirical sample counts + and group-wise accuracy estimates. It also implements its adaptation to + the federated learning setting, where $P(h(x) \\neq y | T_{k'})$ can be + computed from client-wise accuracy values for local samples belonging to + group $k'$, as their weighted average based on client-wise group-wise + sample counts. Here, clients merely need to send values scaled by their + local counts, while the server only needs to access the total group-wise + counts, which makes these computations compatible with SecAgg. + + This class does not implement the computation of group-wise counts, + nor the evaluation of the group-wise accuracy of a given model, but + merely instruments these quantities together with the definition of + a group-fairness notion in order to compute the latter's values. + + Inheritance + ----------- + + Subclasses of `FairnessFunction` are required to: + + - specify a `f_type` string class attribute, which is meant to be + unique across subclasses; + - implement the `compute_fairness_constants` abstract method, which + is called upon instantiation to define the $C_k^{k'}$ constants + from the canonical form of the group-fairness function. + + By default, subclasses are type-registered under the `f_type` name, + enabling instantiation with the `instantiate_fairness_function` generic + constructor. This can be disabled by passing the `register=False` kwarg + at inheritance (e.g. `class MyFunc(FairnessFunction, register=False):`). + + References + ---------- + [1] Maheshwari & Perrot (2023). + FairGrad: Fairness Aware Gradient Descent. + https://openreview.net/forum?id=0f8tU3QwWD + """ + + f_type: ClassVar[str] + + def __init__( + self, + counts: Dict[Tuple[Any, ...], int], + ) -> None: + """Instantiate a group-fairness function from its defining parameters. 
+
+        Parameters
+        ----------
+        counts:
+            Group-wise counts for all target label and sensitive attribute(s)
+            intersected values, with format `{(label, *attrs): count}`.
+        """
+        self._counts = counts.copy()
+        c_k0, c_kk = self.compute_fairness_constants()
+        self._ck0 = c_k0
+        self._ckk = c_kk
+
+    def __init_subclass__(
+        cls,
+        register: bool = True,
+    ) -> None:
+        """Automatically type-register subclasses."""
+        if register:
+            register_type(cls, name=cls.f_type, group="FairnessFunction")
+
+    @functools.cached_property
+    def groups(self) -> List[Tuple[Any, ...]]:
+        """Sorted list of defined sensitive groups for this function."""
+        return sorted(self._counts)
+
+    @functools.cached_property
+    def constants(self) -> Tuple[np.ndarray, np.ndarray]:
+        """Constants defining the fairness function."""
+        return self._ck0.copy(), self._ckk.copy()
+
+    @abc.abstractmethod
+    def compute_fairness_constants(
+        self,
+    ) -> Tuple[np.ndarray, np.ndarray]:
+        """Compute fairness constants associated with this function.
+
+        This method computes and returns fairness constants that are based
+        on the specific group-fairness definition, on the sensitive groups'
+        sample counts and on any other class-specific hyper-parameter set
+        at instantiation.
+
+        This method is notably called upon instantiation, to produce values
+        that are instrumental in all subsequent fairness computations using
+        the defined setting.
+
+        Cached values may be accessed using the `constants` property getter.
+
+        Returns
+        -------
+        c_k0:
+            1-d array containing $C_k^0$ constants for each and every group.
+            May be a single-value array, notably when $C_k^0 = 0 \\forall k$.
+        c_kk:
+            2-d array containing $C_k^{k'}$ constants for each and every pair
+            of sensitive groups.
+
+        Raises
+        ------
+        ValueError
+            If the fairness constants' computation fails.
+        """
+
+    def compute_from_group_accuracy(
+        self,
+        accuracy: Dict[Tuple[Any, ...], float],
+    ) -> Dict[Tuple[Any, ...], float]:
+        """Compute the fairness function from group-wise accuracy metrics.
+
+        Parameters
+        ----------
+        accuracy:
+            Group-wise accuracy values of the model being evaluated on a
+            dataset. I.e. `{group_k: P(y_pred == y_true | group_k)}`.
+
+        Returns
+        -------
+        fairness:
+            Group-wise fairness metrics, as a `{group_k: score_k}` dict.
+            Values' interpretation depends on the implemented group-fairness
+            definition, but overall the fairer the accuracy towards a group,
+            the closer the metric is to zero.
+
+        Raises
+        ------
+        KeyError
+            If any defined sensitive group does not have an accuracy metric.
+        """
+        if diff := set(self.groups).difference(accuracy):
+            raise KeyError(
+                f"Group accuracies from groups {diff} are missing from inputs"
+                f" to '{self.__class__.__name__}.compute_from_group_accuracy'."
+            )
+        # Compute F_k = C_k^0 + Sum_{k'}(C_k^k' * (1 - acc_k'))
+        cerr = 1 - np.array([accuracy[group] for group in self.groups])
+        c_k0, c_kk = self.constants
+        f_k = c_k0 + np.dot(c_kk, cerr)
+        # Wrap up results as a {group: score} dict, for readability purposes.
+        return dict(zip(self.groups, f_k.tolist()))
+
+    def compute_from_federated_group_accuracy(
+        self,
+        accuracy: Dict[Tuple[Any, ...], float],
+    ) -> Dict[Tuple[Any, ...], float]:
+        """Compute the fairness function from federated group-wise accuracy.
+
+        Parameters
+        ----------
+        accuracy:
+            Group-wise sum-aggregated local-group-count-weighted accuracies
+            of a given model over an ensemble of local datasets.
+            I.e. `{group_k: sum_i(n_ik * accuracy_ik)}`.
+
+        Returns
+        -------
+        fairness:
+            Group-wise fairness metrics, as a `{group_k: score_k}` dict.
+            Values' interpretation depends on the implemented group-fairness
+            definition, but overall the fairer the accuracy towards a group,
+            the closer the metric is to zero.
+
+        Raises
+        ------
+        KeyError
+            If any defined sensitive group does not have an accuracy metric.
+        """
+        accuracy = {
+            key: val / cnt
+            for key, val in accuracy.items()
+            if (cnt := self._counts.get(key)) is not None
+        }
+        return self.compute_from_group_accuracy(accuracy)
+
+    def get_specs(
+        self,
+    ) -> Dict[str, Any]:
+        """Return specifications of this fairness function.
+
+        Returns
+        -------
+        specs:
+            Dict of keyword arguments that may be passed to the
+            `declearn.fairness.api.instantiate_fairness_function`
+            generic constructor to recover a copy of this instance.
+        """
+        return {"f_type": self.f_type, "counts": self._counts.copy()}
+
+
+def instantiate_fairness_function(
+    f_type: str,
+    counts: Dict[Tuple[Any, ...], int],
+    **kwargs: Any,
+) -> FairnessFunction:
+    """Instantiate a FairnessFunction from its specifications.
+
+    Parameters
+    ----------
+    f_type:
+        Name of the type of group-fairness function to instantiate.
+    counts:
+        Group-wise counts for all target label and sensitive attribute(s)
+        intersected values, with format `{(label, *attrs): count}`.
+    **kwargs:
+        Any keyword argument for the instantiation of the target function
+        may be passed.
+    """
+    cls = access_registered(name=f_type, group="FairnessFunction")
+    assert issubclass(cls, FairnessFunction)
+    return cls(counts=counts, **kwargs)
diff --git a/declearn/fairness/api/_metrics.py b/declearn/fairness/api/_metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..5565e7c6b0410ea3e2c0477ed8f142e69ea800a1
--- /dev/null
+++ b/declearn/fairness/api/_metrics.py
@@ -0,0 +1,251 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utility dataset-handler to compute group-wise model evaluation metrics."""
+
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+
+from declearn.fairness.api._dataset import FairnessDataset
+from declearn.metrics import Accuracy, MeanMetric, MetricSet
+from declearn.model.api import Model
+from declearn.model.sklearn import SklearnSGDModel
+
+
+__all__ = [
+    "FairnessMetricsComputer",
+]
+
+
+class ModelLoss(MeanMetric, register=False):
+    """Metric container to compute a model's loss iteratively."""
+
+    name = "loss"
+
+    def __init__(
+        self,
+        model: Model,
+    ) -> None:
+        super().__init__()
+        self.model = model
+
+    def metric_func(
+        self,
+        y_true: np.ndarray,
+        y_pred: np.ndarray,
+    ) -> np.ndarray:
+        return self.model.loss_function(y_true, y_pred)
+
+
+class FairnessMetricsComputer:
+    """Utility dataset-handler to compute group-wise evaluation metrics.
+
+    This class aims at making fairness evaluation of models readable,
+    by internalizing the computation of group-wise accuracy (or loss)
+    metrics, that may then be passed to a `FairnessFunction` instance
+    so as to compute the associated fairness values.
+
+    In federated contexts, clients' group-wise accuracy scores should
+    be weighted by their group-wise counts, sum-aggregated and passed
+    to `FairnessFunction.compute_from_federated_group_accuracy`.
+
+    Attributes
+    ----------
+    counts: Dict[Tuple[Any, ...], int]
+        Category-wise number of data samples.
+    g_data: Dict[Tuple[Any, ...], FairnessDataset]
+        Category-wise sub-datasets, over which the accuracy of a model
+        may be computed via the `compute_groupwise_metrics` method.
+    """
+
+    def __init__(
+        self,
+        dataset: FairnessDataset,
+    ) -> None:
+        """Wrap up a `FairnessDataset` to facilitate metrics computation.
+
+        Parameters
+        ----------
+        dataset:
+            `FairnessDataset` instance, that wraps samples over which
+            to estimate models' evaluation metrics, and defines the
+            partition of that data into sensitive groups.
+        """
+        self.counts = dataset.get_sensitive_group_counts()
+        self.g_data = {
+            group: dataset.get_sensitive_group_subset(group)
+            for group in dataset.get_sensitive_group_definitions()
+        }
+
+    def compute_groupwise_metrics(
+        self,
+        metrics: List[MeanMetric],
+        model: Model,
+        batch_size: int = 32,
+        n_batch: Optional[int] = None,
+    ) -> Dict[str, Dict[Tuple[Any, ...], float]]:
+        """Compute an ensemble of mean metrics over group-wise sub-datasets.
+
+        Parameters
+        ----------
+        metrics:
+            List of `MeanMetric` instances defining metrics to compute,
+            that are required to be scalar float values.
+        model:
+            Model that is to be evaluated.
+        batch_size: int, default=32
+            Number of samples per batch when computing predictions.
+        n_batch: int or None, default=None
+            Optional maximum number of batches to draw per group.
+            If None, use the entire wrapped dataset.
+
+        Returns
+        -------
+        metrics:
+            Computed group-wise metrics, as a nested dictionary
+            with `{metric.name: {group: value}}` structure.
+        """
+        metricset = MetricSet(metrics)
+        output = {
+            metric.name: {} for metric in metrics
+        }  # type: Dict[str, Dict[Tuple[Any, ...], float]]
+        for group in self.g_data:
+            values = self.compute_metrics_over_sensitive_group(
+                group, metricset, model, batch_size, n_batch
+            )
+            for metric in metrics:
+                output[metric.name][group] = float(values[metric.name])
+        return output
+
+    def compute_metrics_over_sensitive_group(
+        self,
+        group: Tuple[Any, ...],
+        metrics: MetricSet,
+        model: Model,
+        batch_size: int = 32,
+        n_batch: Optional[int] = None,
+    ) -> Dict[str, Union[float, np.ndarray]]:
+        """Compute some metrics for a given model and sensitive group.
+
+        Parameters
+        ----------
+        group: tuple
+            Tuple of sensitive attribute values defining the group
+            over which to compute the metrics.
+        metrics: MetricSet
+            Ensemble of metrics that need to be computed.
+        model: Model
+            Model that is to be evaluated.
+        batch_size: int, default=32
+            Number of samples per batch when computing predictions.
+        n_batch: int or None, default=None
+            Optional maximum number of batches to draw.
+            If None, use the entire wrapped dataset.
+
+        Returns
+        -------
+        metrics:
+            Dict storing resulting metrics.
+
+        Raises
+        ------
+        KeyError
+            If `group` is an invalid key to the existing combinations
+            of sensitive attribute values.
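+
+        A hedged usage sketch (the dataset, model and group values below
+        are hypothetical placeholders):
+
+            computer = FairnessMetricsComputer(fairness_dataset)
+            metricset = MetricSet([Accuracy(thresh=0.5)])
+            scores = computer.compute_metrics_over_sensitive_group(
+                group=(1, "female"),
+                metrics=metricset,
+                model=model,
+                batch_size=64,
+            )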
+ """ + # arguments serve modularity; pylint: disable=too-many-arguments + # Prepare to iterate over batches from the target group. + if group not in self.g_data: + raise KeyError(f"Invalid sensitive group: '{group}'.") + gen_batches = self.g_data[group].generate_batches( + batch_size, shuffle=(n_batch is not None), drop_remainder=False + ) + # Iteratively evaluate the model. + metrics.reset() + for idx, batch in enumerate(gen_batches): + if n_batch and (idx == n_batch): + break + # Run the model in inference, and round up output scores. + batch_predictions = model.compute_batch_predictions(batch) + metrics.update(*batch_predictions) + # Return the computed metrics. + return metrics.get_result() + + def setup_accuracy_metric( + self, + model: Model, + thresh: Optional[float] = None, + ) -> MeanMetric: + """Return a Metric object to compute a model's accuracy. + + Parameters + ---------- + model: Model + Model that needs to be evaluated. + thresh: int or None, default=None + Optional binarization threshold for binary classification + models' output scores. If None, use 0.5 by default, or 0.0 + for `SklearnSGDModel` instances. + Unused for multinomial classifiers (argmax over scores). + + Returns + ------- + metric: + `MeanMetric` subclass that computes the average accuracy + from pre-computed model predictions. + """ + if thresh is None: + thresh = 0.0 if isinstance(model, SklearnSGDModel) else 0.5 + return Accuracy(thresh=thresh) + + def setup_loss_metric( + self, + model: Model, + ) -> MeanMetric: + """Compute a model's accuracy and loss over each sensitive group. + + Parameters + ---------- + model: Model + Model that needs to be evaluated. + + Returns + ------- + metric: + `MeanMetric` subclass that computes the average loss of + the input `model` based on pre-computed predictions. + """ + return ModelLoss(model) + + def scale_metrics_by_sample_counts( + self, + metrics: Dict[Tuple[Any, ...], float], + ) -> Dict[Tuple[Any, ...], float]: + """Scale a dict of computed group-wise metrics by sample counts. + + Parameters + ---------- + metrics: + Pre-computed raw metrics, as a `{group_k: score_k}` dict. + + Returns + ------- + metrics: + Scaled matrics, as a `{group_k: n_k * score_k}` dict. + """ + return {key: val * self.counts[key] for key, val in metrics.items()} diff --git a/declearn/fairness/api/_server.py b/declearn/fairness/api/_server.py new file mode 100644 index 0000000000000000000000000000000000000000..561261dca4a5116a08aecba088f464f70865320a --- /dev/null +++ b/declearn/fairness/api/_server.py @@ -0,0 +1,479 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""Server-side ABC for fairness-aware federated learning controllers."""
+
+import abc
+from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+
+from declearn.aggregator import Aggregator
+from declearn.communication.api import NetworkServer
+from declearn.communication.utils import verify_client_messages_validity
+from declearn.messaging import (
+    Error,
+    FairnessCounts,
+    FairnessGroups,
+    FairnessReply,
+    FairnessSetupQuery,
+    SerializedMessage,
+)
+from declearn.secagg.api import Decrypter
+from declearn.secagg.messaging import (
+    aggregate_secagg_messages,
+    SecaggFairnessCounts,
+    SecaggFairnessReply,
+)
+from declearn.utils import (
+    access_registered,
+    create_types_registry,
+    register_type,
+)
+
+__all__ = [
+    "FairnessControllerServer",
+]
+
+
+@create_types_registry(name="FairnessControllerServer")
+class FairnessControllerServer(metaclass=abc.ABCMeta):
+    """Abstract base class for server-side fairness controllers.
+
+    Usage
+    -----
+    A `FairnessControllerServer` (subclass) instance has two main
+    routines that are to be called as part of a federated learning
+    process:
+
+    - `setup_fairness`:
+        This routine is to be called only once, during the setup of the
+        overall federated learning task. It triggers the following process:
+          - Send a `FairnessSetupQuery` to clients so that they
+            instantiate a counterpart `FairnessControllerClient`.
+          - Run a basic routine to exchange sensitive group definitions
+            and (secure-)aggregate associated sample counts.
+          - Perform any additional algorithm-specific setup actions.
+
+    - `run_fairness_round`:
+        This routine is to be called once per round, before the next
+        training round occurs. A `FairnessQuery` should be sent to
+        clients prior to calling it. It triggers the following process:
+          - Run a basic routine to receive and (secure-)aggregate
+            metrics computed by clients that relate to fairness.
+          - Perform any additional algorithm-specific round actions.
+
+    Inheritance
+    -----------
+    Algorithm-specific subclasses should define the following abstract
+    attribute and methods:
+
+    - `algorithm`:
+        Abstract string class attribute. Name under which this controller
+        and its client-side counterpart classes are registered.
+    - `finalize_fairness_setup`:
+        Method implementing any algorithm-specific setup actions.
+    - `finalize_fairness_round`:
+        Method implementing any algorithm-specific round actions.
+
+    By default, subclasses are type-registered under their `algorithm`
+    name and "FairnessControllerServer" group upon declaration. This can
+    be prevented by passing `register=False` to the inheritance parameters
+    (e.g. `class Cls(FairnessControllerServer, register=False)`).
+    See `declearn.utils.register_type` for details on types registration.
+    """
+
+    algorithm: ClassVar[str]
+    """Name of the fairness-enforcing algorithm.
+
+    This name should be unique across 'FairnessControllerServer' classes,
+    and shared with a unique paired 'FairnessControllerClient'. It is used
+    for type-registration and to enable instructing clients to instantiate
+    a controller matching that chosen by the server in a federated setting.
+    """
+
+    def __init_subclass__(
+        cls,
+        register: bool = True,
+    ) -> None:
+        """Automatically type-register subclasses."""
+        if register:
+            register_type(cls, cls.algorithm, group="FairnessControllerServer")
+
+    def __init__(
+        self,
+        f_type: str,
+        f_args: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """Instantiate the server-side fairness controller.
+
+        Parameters
+        ----------
+        f_type:
+            Name of the fairness function to evaluate and optimize.
+        f_args:
+            Optional dict of keyword arguments to the fairness function.
+        """
+        self.f_type = f_type
+        self.f_args = f_args or {}
+        self.groups = []  # type: List[Tuple[Any, ...]]
+
+    # Fairness Setup methods.
+
+    async def setup_fairness(
+        self,
+        netwk: NetworkServer,
+        aggregator: Aggregator,
+        secagg: Optional[Decrypter],
+    ) -> Aggregator:
+        """Orchestrate a routine to initialize fairness-aware learning.
+
+        This routine has the following structure:
+
+        - Send a setup query to clients, resulting in the instantiation
+          of client-side controllers matching this one.
+        - Exchange with clients to agree on an ordered list of sensitive
+          groups defined by the intersection of 1+ sensitive attributes
+          and (opt.) a classification target label.
+        - Receive and (secure-)aggregate group-wise sample counts across
+          clients' training datasets.
+        - Perform any additional actions specific to the algorithm in use.
+          - On the server side, optionally alter the `Aggregator` used.
+          - On the client side, optionally alter the `TrainingManager` used.
+
+        Parameters
+        ----------
+        netwk:
+            NetworkServer endpoint, to which clients are registered.
+        aggregator:
+            Aggregator instance that was set up notwithstanding fairness.
+        secagg:
+            Optional SecAgg decryption controller.
+
+        Warns
+        -----
+        RuntimeWarning
+            If the returned aggregator differs from the input one.
+
+        Returns
+        -------
+        aggregator:
+            `Aggregator` instance to use in the FL process, that may
+            or may not have been altered compared with the input one.
+        """
+        # Send a setup query to all clients.
+        query = self.prepare_fairness_setup_query()
+        await netwk.broadcast_message(query)
+        # Agree on a list of sensitive groups and aggregate sample counts.
+        counts = await self.exchange_sensitive_groups_list_and_counts(
+            netwk, secagg
+        )
+        # Run additional algorithm-specific setup steps.
+        return await self.finalize_fairness_setup(
+            netwk, secagg, counts, aggregator
+        )
+
+    def prepare_fairness_setup_query(
+        self,
+    ) -> FairnessSetupQuery:
+        """Return a request to set up fairness, broadcastable to clients.
+
+        Returns
+        -------
+        message:
+            `FairnessSetupQuery` instance to be sent to clients in order
+            to trigger the Fairness setup protocol.
+        """
+        return FairnessSetupQuery(
+            algorithm=self.algorithm,
+            params={"f_type": self.f_type, "f_args": self.f_args},
+        )
+
+    async def exchange_sensitive_groups_list_and_counts(
+        self,
+        netwk: NetworkServer,
+        secagg: Optional[Decrypter],
+    ) -> List[int]:
+        """Agree on a list of sensitive groups and aggregate sample counts.
+
+        This method performs the following routine:
+
+        - Await `FairnessGroups` messages from clients with group definitions.
+        - Assign a sorted list of sensitive groups as `groups` attribute.
+        - Share that list with clients.
+        - Await possibly-encrypted group-wise sample counts from clients.
+        - (Secure-)Aggregate these sample counts and return them.
+
+        Parameters
+        ----------
+        netwk:
+            `NetworkServer` endpoint, through which a fairness setup query
+            was previously sent to all clients.
+        secagg:
+            Optional SecAgg decryption controller.
+
+        Returns
+        -------
+        counts:
+            List of group-wise total sample counts across clients,
+            sorted based on the newly-assigned `self.groups`.
+        """
+        # Receive, aggregate, assign and send back sensitive group definitions.
+        self.groups = await self._exchange_sensitive_groups_list(netwk)
+        # Receive, (secure-)aggregate and return group-wise sample counts.
+ return await self._aggregate_sensitive_groups_counts(netwk, secagg) + + @staticmethod + async def _exchange_sensitive_groups_list( + netwk: NetworkServer, + ) -> List[Tuple[Any, ...]]: + """Receive, aggregate, share and return sensitive group definitions.""" + received = await netwk.wait_for_messages() + # Verify and deserialize client-wise sensitive group definitions. + messages = await verify_client_messages_validity( + netwk, received, expected=FairnessGroups + ) + # Gather the sorted union of all existing definitions. + unique = {group for msg in messages.values() for group in msg.groups} + groups = sorted(list(unique)) + # Send it to clients, and expect their reply (encrypted counts). + await netwk.broadcast_message(FairnessGroups(groups=groups)) + return groups + + async def _aggregate_sensitive_groups_counts( + self, + netwk: NetworkServer, + secagg: Optional[Decrypter], + ) -> List[int]: + """Receive, (secure-)aggregate and return group-wise sample counts.""" + received = await netwk.wait_for_messages() + if secagg is None: + return await self._aggregate_sensitive_groups_counts_cleartext( + netwk=netwk, received=received, n_groups=len(self.groups) + ) + return await self._aggregate_sensitive_groups_counts_encrypted( + netwk=netwk, received=received, decrypter=secagg + ) + + @staticmethod + async def _aggregate_sensitive_groups_counts_cleartext( + netwk: NetworkServer, + received: Dict[str, SerializedMessage], + n_groups: int, + ) -> List[int]: + """Deserialize and aggregate cleartext group-wise counts.""" + replies = await verify_client_messages_validity( + netwk, received, expected=FairnessCounts + ) + counts = np.zeros(n_groups, dtype="uint64") + for message in replies.values(): + counts = counts + np.asarray(message.counts, dtype="uint64") + return counts.tolist() + + @staticmethod + async def _aggregate_sensitive_groups_counts_encrypted( + netwk: NetworkServer, + received: Dict[str, SerializedMessage], + decrypter: Decrypter, + ) -> List[int]: + """Deserialize and secure-aggregate encrypted group-wise counts.""" + replies = await verify_client_messages_validity( + netwk, received, expected=SecaggFairnessCounts + ) + aggregated = aggregate_secagg_messages(replies, decrypter) + return aggregated.counts + + @abc.abstractmethod + async def finalize_fairness_setup( + self, + netwk: NetworkServer, + secagg: Optional[Decrypter], + counts: List[int], + aggregator: Aggregator, + ) -> Aggregator: + """Finalize the fairness setup routine and return an Aggregator. + + This method is called as part of `setup_fairness`, and should + be defined by concrete subclasses to implement setup behavior + once the initial query/reply messages have been exchanged. + + The returned `Aggregator` may either be the input `aggregator` + or a new or modified version of it, depending on the needs of + the fairness-aware federated learning process being implemented. + + Warns + ----- + RuntimeWarning + If the returned aggregator differs from the input one. + + Returns + ------- + aggregator: + `Aggregator` instance to use in the FL process, that may + or may not have been altered compared with the input one. + """ + + # Fairness Round methods. + + async def run_fairness_round( + self, + netwk: NetworkServer, + secagg: Optional[Decrypter], + ) -> Dict[str, Union[float, np.ndarray]]: + """Secure-aggregate and post-process fairness measures. + + This method is to be run **after** having sent a `FairnessQuery` + to clients. 
It consists in receiving, (secure-)aggregating and
+        post-processing measures that clients produce as a reply to that
+        query. This may involve further algorithm-specific communications.
+
+        Parameters
+        ----------
+        netwk:
+            NetworkServer endpoint instance, to which clients are registered.
+        secagg:
+            Optional SecAgg decryption controller.
+
+        Returns
+        -------
+        metrics:
+            Fairness(-related) metrics computed as part of this routine,
+            as a dict mapping names to scalar or numpy array values.
+        """
+        values = await self.receive_and_aggregate_fairness_measures(
+            netwk, secagg
+        )
+        return await self.finalize_fairness_round(netwk, secagg, values)
+
+    async def receive_and_aggregate_fairness_measures(
+        self,
+        netwk: NetworkServer,
+        secagg: Optional[Decrypter],
+    ) -> List[float]:
+        """Await and (secure-)aggregate client-wise fairness-related metrics.
+
+        This method is designed to be called after sending a `FairnessQuery`
+        to clients, and returns values that are yet to be parsed and used by
+        the algorithm-dependent `finalize_fairness_round` method.
+
+        Parameters
+        ----------
+        netwk:
+            NetworkServer endpoint instance, to which clients are registered.
+        secagg:
+            Optional SecAgg decryption controller.
+
+        Returns
+        -------
+        metrics:
+            List of sum-aggregated fairness-related metrics (as floats).
+            By default, these are group-wise accuracy values; this may
+            however be changed or expanded by algorithm-specific classes.
+        """
+        received = await netwk.wait_for_messages()
+        # Case when expecting cleartext values.
+        if secagg is None:
+            replies = await verify_client_messages_validity(
+                netwk, received, expected=FairnessReply
+            )
+            if len(set(len(r.values) for r in replies.values())) != 1:
+                error = "Clients sent fairness values of different lengths."
+                await netwk.broadcast_message(Error(error))
+                raise RuntimeError(error)
+            return [
+                sum(rval)
+                for rval in zip(*[reply.values for reply in replies.values()])
+            ]
+        # Case when expecting encrypted values.
+        secagg_replies = await verify_client_messages_validity(
+            netwk, received, expected=SecaggFairnessReply
+        )
+        agg_reply = aggregate_secagg_messages(secagg_replies, decrypter=secagg)
+        return agg_reply.values
+
+    @abc.abstractmethod
+    async def finalize_fairness_round(
+        self,
+        netwk: NetworkServer,
+        secagg: Optional[Decrypter],
+        values: List[float],
+    ) -> Dict[str, Union[float, np.ndarray]]:
+        """Orchestrate a round of actions to enforce fairness.
+
+        This method is designed to be called after an initial query
+        has been sent and responded to by clients, resulting in the
+        federated computation of fairness(-related) metrics.
+
+        Parameters
+        ----------
+        netwk:
+            NetworkServer endpoint instance, to which clients are registered.
+        secagg:
+            Optional SecAgg decryption controller.
+        values:
+            Aggregated metrics resulting from the fairness evaluation
+            run by clients at this round.
+
+        Returns
+        -------
+        metrics:
+            Fairness(-related) metrics computed as part of this routine,
+            as a dict mapping names to scalar or numpy array values.
+        """
+
+    @staticmethod
+    def from_specs(
+        algorithm: str,
+        f_type: str,
+        f_args: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> "FairnessControllerServer":
+        """Instantiate a 'FairnessControllerServer' from its specifications.
+
+        Parameters
+        ----------
+        algorithm:
+            Name of the algorithm associated with the target controller class.
+        f_type:
+            Name of the fairness function to evaluate and optimize.
+        f_args:
+            Optional dict of keyword arguments to the fairness function.
+        **kwargs:
+            Any additional algorithm-specific instantiation keyword argument.
+
+        Returns
+        -------
+        controller:
+            `FairnessControllerServer` instance matching input specifications.
+
+        Raises
+        ------
+        KeyError
+            If `algorithm` does not match any registered
+            `FairnessControllerServer` type.
+        """
+        try:
+            cls = access_registered(
+                name=algorithm, group="FairnessControllerServer"
+            )
+        except Exception as exc:
+            raise KeyError(
+                "Failed to retrieve fairness controller with algorithm name "
+                f"'{algorithm}'."
+            ) from exc
+        assert issubclass(cls, FairnessControllerServer)
+        return cls(f_type=f_type, f_args=f_args, **kwargs)
diff --git a/declearn/fairness/core/__init__.py b/declearn/fairness/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ee5702c240d64820970c234438fd77e850ed676
--- /dev/null
+++ b/declearn/fairness/core/__init__.py
@@ -0,0 +1,57 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Built-in concrete components for fairness-aware federated learning.
+
+Dataset subclass
+----------------
+
+* [FairnessInMemoryDataset][declearn.fairness.core.FairnessInMemoryDataset]:
+    Fairness-aware `InMemoryDataset` subclass.
+
+
+Group-fairness functions
+------------------------
+Concrete implementations of various fairness functions:
+
+* [AccuracyParityFunction][declearn.fairness.core.AccuracyParityFunction]:
+    Accuracy Parity group-fairness function.
+* [DemographicParityFunction]\
+[declearn.fairness.core.DemographicParityFunction]:
+    Demographic Parity group-fairness function for binary classifiers.
+* [EqualityOfOpportunityFunction]\
+[declearn.fairness.core.EqualityOfOpportunityFunction]:
+    Equality of Opportunity group-fairness function.
+* [EqualizedOddsFunction][declearn.fairness.core.EqualizedOddsFunction]:
+    Equalized Odds group-fairness function.
+
+Abstractions and a generic constructor may be found in
+[declearn.fairness.api][]. An additional util may be used to list available
+functions, either declared here or by third-party and end-user code:
+
+* [list_fairness_functions][declearn.fairness.core.list_fairness_functions]:
+    Return a mapping of registered FairnessFunction subclasses.
+"""
+
+from ._functions import (
+    AccuracyParityFunction,
+    DemographicParityFunction,
+    EqualityOfOpportunityFunction,
+    EqualizedOddsFunction,
+    list_fairness_functions,
+)
+from ._inmemory import FairnessInMemoryDataset
diff --git a/declearn/fairness/core/_functions.py b/declearn/fairness/core/_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d4dbc735892c311a98a6c85a2088bd97254618f
--- /dev/null
+++ b/declearn/fairness/core/_functions.py
@@ -0,0 +1,384 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Concrete implementations of various group-fairness functions."""
+
+from typing import Any, Dict, List, Tuple, Type, Union
+
+import numpy as np
+
+from declearn.fairness.api import FairnessFunction
+from declearn.utils import access_types_mapping
+
+__all__ = (
+    "AccuracyParityFunction",
+    "DemographicParityFunction",
+    "EqualityOfOpportunityFunction",
+    "EqualizedOddsFunction",
+    "list_fairness_functions",
+)
+
+
+def list_fairness_functions() -> Dict[str, Type[FairnessFunction]]:
+    """Return a mapping of registered FairnessFunction subclasses.
+
+    This function aims at making it easy for end-users to list and access
+    all available FairnessFunction classes at any given time. The returned
+    dict uses unique identifier keys, which may be used to select the
+    associated function within a
+    [declearn.fairness.api.FairnessControllerServer][].
+
+    Note that the mapping will include all declearn-provided functions,
+    but also registered functions provided by user or third-party code.
+
+    See also
+    --------
+    * [declearn.fairness.api.FairnessFunction][]:
+        API-defining abstract base class for the FairnessFunction classes.
+
+    Returns
+    -------
+    mapping:
+        Dictionary mapping unique str identifiers to `FairnessFunction`
+        class constructors.
+    """
+    return access_types_mapping("FairnessFunction")
+
+
+class AccuracyParityFunction(FairnessFunction):
+    """Accuracy Parity group-fairness function.
+
+    Definition
+    ----------
+
+    Accuracy Parity is achieved when
+
+    $$ \\forall r \\in S , P(h(x) = y | s = r) == P(h(x) = y) $$
+
+    where $S$ denotes possible values of a (set of intersected) sensitive
+    attribute(s), $y$ is the true target classification label and $h$ is
+    the evaluated classifier.
+
+    In other words, Accuracy Parity is achieved when the model's accuracy
+    is independent from the sensitive attribute(s) (but not necessarily
+    balanced across specific target classes).
+
+    Formula
+    -------
+
+    For any sensitive group $k = (l, r)$ defined by the intersection of a
+    given true label and a value of the sensitive attribute(s), Accuracy
+    Parity can be expressed in the canonical form from [1]:
+
+    $$ F_k(h, T) = C_k^0 + \\sum_{k'} C_k^{k'} P(h(x) \\neq y | T_{k'}) $$
+
+    using the following constants:
+
+    - $ C_k^{k'} = (n_{k'} / n) - 1{k_s = k'_s} * (n_{k'} / n_s) $
+    - $ C_k^0 = 0 $
+
+    where $n$ denotes the number of samples in the empirical dataset used,
+    and its subscripted counterparts are the numbers of samples that belong
+    to a given sensitive group and/or have a given sensitive attribute(s)
+    value.
+
+    This results in all scores associated with a given target label being
+    equal, as the partition into sensitive groups could be done regardless
+    of the target label.
+
+    References
+    ----------
+    [1] Maheshwari & Perrot (2023).
+        FairGrad: Fairness Aware Gradient Descent.
+        https://openreview.net/forum?id=0f8tU3QwWD
+    """
+
+    f_type = "accuracy_parity"
+
+    def compute_fairness_constants(
+        self,
+    ) -> Tuple[np.ndarray, np.ndarray]:
+        counts = self._counts
+        # Compute the sensitive-attributes-wise and total number of samples.
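+        # NOTE: group keys are `(label, *attrs)` tuples, hence `categ[1:]`
+        # identifies the sensitive attribute(s) value shared across labels.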
+        s_tot = {}  # type: Dict[Any, int]  # attributes-wise number of samples
+        for categ, count in counts.items():
+            s_tot[categ[1:]] = s_tot.get(categ[1:], 0) + count
+        total = sum(s_tot.values())
+        # Compute fairness constraint constants C_k^k'.
+        c_kk = np.zeros((len(counts), len(counts)))
+        for idx, cat_i in enumerate(counts):
+            for jdx, cat_j in enumerate(counts):
+                coef = counts[cat_j] / total
+                if cat_i[1:] == cat_j[1:]:  # same sensitive attributes
+                    coef -= counts[cat_j] / s_tot[cat_i[1:]]
+                c_kk[idx, jdx] = coef
+        # Return the computed constants.
+        c_k0 = np.array([0.0])
+        return c_k0, c_kk
+
+
+class DemographicParityFunction(FairnessFunction):
+    """Demographic Parity group-fairness function for binary classifiers.
+
+    Note that this implementation is restricted to binary classification.
+
+    Definition
+    ----------
+
+    Demographic Parity is achieved when
+
+    $$ \\forall l \\in Y, \\forall r \\in S,
+    P(h(x) = l | s = r) == P(h(x) = l) $$
+
+    where $Y$ denotes possible target labels, $S$ denotes possible values
+    of a (set of intersected) sensitive attribute(s), and $h$ is the
+    evaluated classifier.
+
+    In other words, Demographic Parity is achieved when the probability to
+    predict any given label is independent from the sensitive attribute(s)
+    (regardless of whether that label is accurate or not).
+
+    Formula
+    -------
+
+    When considering a binary classification task, for any sensitive group
+    $k = (l, r)$ defined by the intersection of a given true label and a
+    value of the sensitive attribute(s), Demographic Parity can be expressed
+    in the canonical form from [1]:
+
+    $$ F_k(h, T) = C_k^0 + \\sum_{k'} C_k^{k'} P(h(x) \\neq y | T_{k'}) $$
+
+    using the following constants:
+
+    - $ C_{l,r}^0 = (n_k / n_r) - (n_l / n) $
+    - $ C_{l,r}^{l,r} = (n_k / n) - (n_k / n_r) $
+    - $ C_{l,r}^{l',r} = (n_{k'} / n_r) - (n_{k'} / n) $
+    - $ C_{l,r}^{l,r'} = n_{k'} / n $
+    - $ C_{l,r}^{l',r'} = - n_{k'} / n $
+
+    where $n$ denotes the number of samples in the empirical dataset used,
+    and its subscripted counterparts are the numbers of samples that belong
+    to a given sensitive group and/or have given sensitive attribute(s) or
+    true label values.
+
+    References
+    ----------
+    [1] Maheshwari & Perrot (2023).
+        FairGrad: Fairness Aware Gradient Descent.
+        https://openreview.net/forum?id=0f8tU3QwWD
+    """
+
+    f_type = "demographic_parity"
+
+    def compute_fairness_constants(
+        self,
+    ) -> Tuple[np.ndarray, np.ndarray]:
+        counts = self._counts
+        # Check that the labels are binary.
+        if len(set(key[0] for key in counts)) != 2:
+            raise ValueError("Demographic Parity requires binary labels.")
+        # Compute label-wise, attribute-wise and total number of samples.
+        l_tot = {}  # type: Dict[Any, int]  # label-wise number of samples
+        s_tot = {}  # type: Dict[Tuple[Any, ...], int]  # attribute-wise
+        for categ, count in counts.items():
+            l_tot[categ[0]] = l_tot.get(categ[0], 0) + count
+            s_tot[categ[1:]] = s_tot.get(categ[1:], 0) + count
+        total = sum(l_tot.values())
+        # Compute fairness constraint constants C_k^k'.
+        c_k0 = np.zeros(len(counts))
+        c_kk = np.zeros((len(counts), len(counts)))
+        for idx, cat_i in enumerate(counts):
+            # Compute the C_k^0 constant.
+            c_k0[idx] = counts[cat_i] / s_tot[cat_i[1:]]
+            c_k0[idx] -= l_tot[cat_i[0]] / total
+            # Compute all other C_k^k' constants.
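+            # NOTE: `cat[0]` is the target label and `cat[1:]` the sensitive
+            # attribute(s) value; the branches below thus realize the four
+            # C_k^{k'} cases listed in the class docstring.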
+            for jdx, cat_j in enumerate(counts):
+                value = counts[cat_j] / total  # n_k' / n
+                if cat_i[1:] == cat_j[1:]:  # same sensitive attributes
+                    value -= counts[cat_j] / s_tot[cat_i[1:]]  # n_k' / n_s
+                if cat_i[0] != cat_j[0]:  # distinct label
+                    value *= -1
+                c_kk[idx, jdx] = value
+        # Return the computed constants.
+        return c_k0, c_kk
+
+
+class EqualizedOddsFunction(FairnessFunction):
+    """Equalized Odds group-fairness function.
+
+    Definition
+    ----------
+
+    Equalized Odds is achieved when
+
+    $$ \\forall l \\in Y, \\forall r \\in S,
+    P(h(x) = y | y = l) == P(h(x) = y | y = l, s = r) $$
+
+    where $Y$ denotes possible target labels, $S$ denotes possible values
+    of a (set of intersected) sensitive attribute(s), and $h$ is the
+    evaluated classifier.
+
+    In other words, Equalized Odds is achieved when the probability that
+    the model predicts the correct label is independent from the sensitive
+    attribute(s).
+
+    Formula
+    -------
+
+    For any sensitive group $k = (l, r)$ defined by the intersection of a
+    given true label and a value of the sensitive attribute(s), Equalized
+    Odds can be expressed in the canonical form from [1]:
+
+    $$ F_k(h, T) = C_k^0 + \\sum_{k'} C_k^{k'} P(h(x) \\neq y | T_{k'}) $$
+
+    using the following constants:
+
+    - $ C_k^k = (n_k / n_l) - 1 $
+    - $ C_k^{k'} = (n_{k'} / n_l) * 1{k_l = k'_l} $
+    - $ C_k^0 = 0 $
+
+    where $n$ denotes the number of samples in the empirical dataset used,
+    and its subscripted counterparts are the numbers of samples that belong
+    to a given sensitive group and/or have a given true label.
+
+    References
+    ----------
+    [1] Maheshwari & Perrot (2023).
+        FairGrad: Fairness Aware Gradient Descent.
+        https://openreview.net/forum?id=0f8tU3QwWD
+    """
+
+    f_type = "equalized_odds"
+
+    def compute_fairness_constants(
+        self,
+    ) -> Tuple[np.ndarray, np.ndarray]:
+        counts = self._counts
+        # Compute the label-wise number of samples.
+        l_tot = {}  # type: Dict[Any, int]  # label-wise number of samples
+        for categ, count in counts.items():
+            l_tot[categ[0]] = l_tot.get(categ[0], 0) + count
+        # Compute fairness constraint constants C_k^k'.
+        c_kk = np.zeros((len(counts), len(counts)))
+        for idx, cat_i in enumerate(counts):
+            for jdx, cat_j in enumerate(counts):
+                if cat_i[0] != cat_j[0]:  # distinct target label
+                    continue
+                c_kk[idx, jdx] = counts[cat_j] / l_tot[cat_i[0]] - (idx == jdx)
+        # Return the computed constants.
+        c_k0 = np.array([0.0], dtype=c_kk.dtype)
+        return c_k0, c_kk
+
+
+class EqualityOfOpportunityFunction(EqualizedOddsFunction):
+    """Equality of Opportunity group-fairness function.
+
+    Definition
+    ----------
+
+    Equality of Opportunity is achieved when
+
+    $$ \\forall l \\in Y' \\subset Y, \\forall r \\in S,
+    P(h(x) = y | y = l) == P(h(x) = y | y = l, s = r) $$
+
+    where $Y$ denotes possible target labels, $S$ denotes possible values
+    of a (set of intersected) sensitive attribute(s), and $h$ is the
+    evaluated classifier.
+
+    In other words, Equality of Opportunity is equivalent to Equalized Odds
+    but restricted to a subset of possible target labels. It is therefore
+    achieved when the probability that the model predicts the correct label
+    is independent from the sensitive attribute(s), for a subset of correct
+    labels.
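+
+    For instance, with binary labels, one may restrict the constraint to
+    the positive class only, which is this class's default behavior
+    (`target=1`).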
+
+    Formula
+    -------
+
+    For any sensitive group $k = (l, r)$ defined by the intersection of a
+    given true label and a value of the sensitive attribute(s), Equality
+    of Opportunity can be expressed in the canonical form from [1]:
+
+    $$ F_k(h, T) = C_k^0 + \\sum_{k'} C_k^{k'} P(h(x) \\neq y | T_{k'}) $$
+
+    using the following constants:
+
+    - For $k$ such that $k_l \\in Y'$:
+        - $ C_k^k = (n_k / n_l) - 1 $
+        - $ C_k^{k'} = (n_{k'} / n_l) * 1{k_l = k'_l} $
+    - All other constants are null.
+
+    where $n$ denotes the number of samples in the empirical dataset used,
+    and its subscripted counterparts are the numbers of samples that belong
+    to a given sensitive group and/or have a given true label.
+
+    References
+    ----------
+    [1] Maheshwari & Perrot (2023).
+        FairGrad: Fairness Aware Gradient Descent.
+        https://openreview.net/forum?id=0f8tU3QwWD
+    """
+
+    f_type = "equality_of_opportunity"
+
+    def __init__(
+        self,
+        counts: Dict[Tuple[Any, ...], int],
+        target: Union[int, List[int]] = 1,
+    ) -> None:
+        """Instantiate an Equality of Opportunity group-fairness function.
+
+        Parameters
+        ----------
+        counts:
+            Group-wise counts for all target label and sensitive attribute(s)
+            intersected values, with format `{(label, *attrs): count}`.
+        target:
+            Label(s) that fairness constraints are to be restricted to.
+        """
+        # Parse 'target' inputs.
+        if isinstance(target, int):
+            self._target = {target}
+        elif isinstance(target, (list, tuple, set)):
+            self._target = set(target)
+        else:
+            raise TypeError("'target' should be an int or list of ints.")
+        # Verify that 'target' is a subset of target labels.
+        targets = set(int(group[0]) for group in counts)
+        if not self._target.issubset(targets):
+            raise ValueError(
+                "'target' should be a subset of target label values present "
+                "in sensitive groups' definitions."
+            )
+        # Delegate remainder of instantiation to the parent class.
+        super().__init__(counts=counts)
+
+    def compute_fairness_constants(
+        self,
+    ) -> Tuple[np.ndarray, np.ndarray]:
+        # Compute equalized odds constants (as if all targets were desirable).
+        c_k0, c_kk = super().compute_fairness_constants()
+        # Zero-out constants associated with undesired targets.
+        for idx, cat_i in enumerate(self._counts):
+            if int(cat_i[0]) not in self._target:
+                c_kk[idx] = 0.0
+        # Return the computed constants.
+        return c_k0, c_kk
+
+    def get_specs(
+        self,
+    ) -> Dict[str, Any]:
+        specs = super().get_specs()
+        specs["target"] = list(self._target)
+        return specs
diff --git a/declearn/fairness/core/_inmemory.py b/declearn/fairness/core/_inmemory.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4855e676d4559453a0542bd639a924e92847e79
--- /dev/null
+++ b/declearn/fairness/core/_inmemory.py
@@ -0,0 +1,269 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Fairness-aware InMemoryDataset subclass."""
+
+import warnings
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import pandas as pd
+import scipy.sparse  # type: ignore
+
+from declearn.dataset import InMemoryDataset
+from declearn.dataset.utils import load_data_array
+from declearn.fairness.api import FairnessDataset
+from declearn.typing import DataArray
+
+
+__all__ = [
+    "FairnessInMemoryDataset",
+]
+
+
+class FairnessInMemoryDataset(FairnessDataset, InMemoryDataset):
+    """Fairness-aware InMemoryDataset subclass.
+
+    This class extends `declearn.dataset.InMemoryDataset` to
+    enable its use in fairness-aware federated learning. New
+    parameters are added to its `__init__`: `s_attr` as well
+    as `sensitive_target`, which are used to define sensitive
+    groups within the held dataset. Additionally, API methods
+    from `declearn.fairness.api.FairnessDataset` are defined,
+    enabling access to sensitive groups' metadata and samples,
+    as well as changing sample weights based on the group to
+    which samples belong.
+    """
+
+    def __init__(
+        self,
+        data: Union[DataArray, str],
+        *,
+        s_attr: Union[DataArray, str, List[int], List[str]],
+        target: Optional[Union[DataArray, str]] = None,
+        s_wght: Optional[Union[DataArray, str]] = None,
+        f_cols: Optional[Union[List[int], List[str]]] = None,
+        sensitive_target: bool = True,
+        expose_classes: bool = False,
+        expose_data_type: bool = False,
+        seed: Optional[int] = None,
+    ) -> None:
+        """Instantiate the memory-fitting-data `FairnessDataset` interface.
+
+        Please refer to `declearn.dataset.InMemoryDataset`, which this class
+        extends, for generalities about supported input formats.
+
+        Parameters
+        ----------
+        data:
+            Main data array which contains input features (and possibly
+            more), or path to a dump file from which it is to be loaded.
+        s_attr:
+            Sensitive attributes, that define group-fairness constraints.
+            May be a data array, the path to a dump file, a list of `data`
+            column names or a list of `data` column indices.
+        target:
+            Optional target labels, as a data array, or as a path to a
+            dump file, or as the name of a `data` column.
+        s_wght:
+            Optional sample weights, as a data array, or as a path to a
+            dump file, or as the name of a `data` column.
+        f_cols:
+            Optional list of columns in `data` to use as input features.
+            These may be specified as column names or indices. If None,
+            use all non-target, non-sample-weights columns of `data`.
+
+        Other parameters
+        ----------------
+        sensitive_target:
+            Whether to define sensitive groups based on the intersection
+            of `s_attr` sensitive attributes and `target` target labels,
+            or merely on `s_attr`.
+        expose_classes:
+            Whether to expose unique target values as part of data specs.
+            This should only be used for classification datasets.
+        expose_data_type:
+            Whether to expose features' dtype, which will be verified to
+            be unique, as part of data specs.
+        seed:
+            Optional seed for the random number generator used for all
+            randomness-based operations required to generate batches
+            (e.g. to shuffle the data or sample from it).
+        """
+        # inherited signature; pylint: disable=too-many-arguments
+        super().__init__(
+            data=data,
+            target=target,
+            s_wght=s_wght,
+            f_cols=f_cols,
+            expose_classes=expose_classes,
+            expose_data_type=expose_data_type,
+            seed=seed,
+        )
+        # Pre-emptively declare attributes to deal with fairness balancing.
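+        # NOTE: `_smp_wght` retains the initial sample weights, so that
+        # `set_sensitive_group_weights` can rescale them anew each time
+        # rather than compound successive group-wise adjustments.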
+        self.sensitive = pd.Series()  # type: pd.Series[Any]
+        self._smp_wght = self.weights  # type: DataArray
+        # Actually set up sensitive groups based on specific parameters.
+        self._set_sensitive_data(sensitive=s_attr, use_label=sensitive_target)
+
+    def _set_sensitive_data(
+        self,
+        sensitive: Union[DataArray, str, List[int], List[str]],
+        use_label: bool = True,
+    ) -> None:
+        """Define sensitive attributes based on which to filter samples.
+
+        This method updates in-place the `sensitive` attribute of this
+        dataset instance.
+
+        Parameters
+        ----------
+        sensitive:
+            Sensitive attributes, either as a pandas DataFrame storing data
+            that is aligned with that already interfaced by this Dataset,
+            or as a list of columns that are part of `self.data` (only when
+            the latter is a pandas DataFrame).
+        use_label:
+            Whether to use the target labels (when defined) as an additional
+            sensitive attribute, placed first in the list. Default: True.
+
+        Raises
+        ------
+        TypeError
+            If the inputs are of an improper type.
+        ValueError
+            If 'sensitive' is parsed into a data array with an improper
+            length.
+        """
+        # Gather (and/or validate) sensitive data as a data array.
+        s_data = self._parse_sensitive_data(sensitive)
+        if len(s_data) != len(self.data):
+            raise ValueError(
+                "The passed 'sensitive' data was parsed into a DataFrame with"
+                " a number of records that does not match the base data."
+            )
+        # Optionally add target labels as a first sensitive category.
+        if use_label:
+            if self.target is None:
+                warnings.warn(
+                    f"'{self.__class__.__name__}.set_sensitive_data' was"
+                    " called with 'use_label=True', but there are no labels"
+                    " defined for this instance.",
+                    RuntimeWarning,
+                )
+            else:
+                target = (
+                    self.target.rename("target")
+                    if isinstance(self.target, pd.Series)
+                    else pd.Series(self.target, name="target")
+                )
+                s_data = pd.concat([target, s_data], axis=1)
+        # Wrap sensitive data as a Series of tuples of values.
+        self.sensitive = pd.Series(zip(*[s_data[c] for c in s_data.columns]))
+
+    def _parse_sensitive_data(
+        self,
+        sensitive: Union[DataArray, str, List[int], List[str]],
+    ) -> pd.DataFrame:
+        """Process inputs to `set_sensitive_data` into a data array."""
+        # Handle cases when 'sensitive' is a file path or a list of columns.
+        if isinstance(sensitive, str):
+            sensitive = load_data_array(sensitive)
+        elif isinstance(sensitive, list):
+            if isinstance(self.data, pd.DataFrame) and all(
+                col in self.data.columns for col in sensitive
+            ):
+                sensitive = self.data[sensitive]
+            elif all(
+                isinstance(col, int) and (col < self.data.shape[1])
+                for col in sensitive
+            ):
+                sensitive = (
+                    self.data.iloc[:, sensitive]  # type: ignore[index]
+                    if isinstance(self.data, pd.DataFrame)
+                    else self.data[:, sensitive]  # type: ignore[index]
+                )
+            else:
+                raise TypeError(
+                    "'sensitive' was passed as a list, but matches neither"
+                    " data column names nor indices."
+                )
+        # Type-check and optionally convert sensitive attributes to pandas.
+        if isinstance(sensitive, pd.DataFrame):
+            return sensitive
+        if isinstance(sensitive, np.ndarray):
+            return pd.DataFrame(sensitive)
+        if isinstance(sensitive, scipy.sparse.spmatrix):
+            return pd.DataFrame(sensitive.toarray())
+        raise TypeError(
+            "'sensitive' should be a numpy array, scipy matrix, pandas"
+            " DataFrame, path to such a structure's file dump, or list"
+            " of 'data' column names or indices to slice off."
+        )
+
+    def get_sensitive_group_definitions(
+        self,
+    ) -> List[Tuple[Any, ...]]:
+        return sorted(self.sensitive.unique().tolist())
+
+    def get_sensitive_group_counts(
+        self,
+    ) -> Dict[Tuple[Any, ...], int]:
+        return self.sensitive.value_counts().sort_index().to_dict()
+
+    def get_sensitive_group_subset(
+        self,
+        group: Tuple[Any, ...],
+    ) -> InMemoryDataset:
+        mask = self.sensitive == group
+        inputs = self.feats[mask]
+        target = None if self.target is None else self.target[mask]
+        s_wght = (
+            None
+            if self._smp_wght is None
+            else self._smp_wght[mask]  # type: ignore
+        )
+        return InMemoryDataset(
+            data=inputs,
+            target=target,
+            s_wght=s_wght,
+            expose_classes=self.expose_classes,
+            expose_data_type=self.expose_data_type,
+            seed=self.seed,
+        )
+
+    def set_sensitive_group_weights(
+        self,
+        weights: Dict[Tuple[Any, ...], float],
+        adjust_by_counts: bool = False,
+    ) -> None:
+        # Optionally adjust input weights based on local group-wise counts.
+        if adjust_by_counts:
+            counts = self.get_sensitive_group_counts()
+            weights = {
+                key: val * counts.get(key, 0) for key, val in weights.items()
+            }
+        # Define or adjust sample weights based on sensitive attributes.
+        sample_weights = self.sensitive.apply(weights.get)
+        if sample_weights.isnull().any():
+            raise KeyError(
+                f"'{self.__class__.__name__}.set_sensitive_group_weights'"
+                " received input weights that seemingly do not cover all"
+                " existing sensitive groups."
+            )
+        if self._smp_wght is not None:
+            sample_weights *= self._smp_wght  # type: ignore[assignment]
+        self.weights = sample_weights
diff --git a/declearn/fairness/fairbatch/__init__.py b/declearn/fairness/fairbatch/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4cc03a78804a9cb7f9e9f2606db5547f48d4752f
--- /dev/null
+++ b/declearn/fairness/fairbatch/__init__.py
@@ -0,0 +1,104 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Fed-FairBatch / FedFB algorithm controllers and utils.
+
+Introduction
+------------
+FairBatch [1] is a group-fairness-enforcing algorithm that relies
+on a specific form of loss reweighting mediated via the batching
+of samples for SGD steps. Namely, in FairBatch, batches are drawn
+by concatenating group-wise sub-batches, the size of which is the
+product of the desired total batch size and group-wise sampling
+probabilities, with the latter being updated throughout training
+based on the measured fairness of the current model.
+
+This module provides a twofold adaptation of FairBatch to
+the federated learning setting. On the one hand, a straightforward
+adaptation using the law of total probability is proposed, that is
+not based on any published paper. On the other hand, the FedFB [2]
+algorithm is implemented, which adapts FairBatch in a similar way
+but further introduces changes in formulas compared with the base
+paper.
Both variants are available via a single pair of classes,
+with a boolean switch to choose between them.
+
+Originally, FairBatch was designed for binary classification tasks
+on data that have a single binary sensitive attribute. Both our
+implementations currently stick to that setting, in spite of the
+FedFB authors using a formalism that arguably extends formulas to
+more generic categorical sensitive attribute(s), which is not
+tested in the paper.
+
+Finally, it is worth noting that the translation of the sampling
+probabilities into the data batching process is done in accordance
+with the reference implementation by the original FairBatch authors.
+More details may be found in the documentation of `FairbatchDataset`
+(a backend tool that end-users do not need to use directly).
+
+Controllers
+-----------
+* [FairbatchControllerClient]
+[declearn.fairness.fairbatch.FairbatchControllerClient]:
+    Client-side controller to implement Fed-FairBatch or FedFB.
+* [FairbatchControllerServer]
+[declearn.fairness.fairbatch.FairbatchControllerServer]:
+    Server-side controller to implement Fed-FairBatch or FedFB.
+
+Backend
+-------
+* [FairbatchDataset][declearn.fairness.fairbatch.FairbatchDataset]:
+    FairBatch-specific FairnessDataset subclass and wrapper.
+* [FairbatchSamplingController]
+[declearn.fairness.fairbatch.FairbatchSamplingController]:
+    ABC to compute and update Fairbatch sampling probabilities.
+* [setup_fairbatch_controller]
+[declearn.fairness.fairbatch.setup_fairbatch_controller]:
+    Instantiate a FairBatch sampling probabilities controller.
+* [setup_fedfb_controller]
+[declearn.fairness.fairbatch.setup_fedfb_controller]:
+    Instantiate a FedFB sampling probabilities controller.
+
+Messages
+--------
+* [FairbatchOkay][declearn.fairness.fairbatch.FairbatchOkay]
+* [FairbatchSamplingProbas]
+[declearn.fairness.fairbatch.FairbatchSamplingProbas]
+
+References
+----------
+- [1]
+    Roh et al. (2020).
+    FairBatch: Batch Selection for Model Fairness.
+    https://arxiv.org/abs/2012.01696
+- [2]
+    Zeng et al. (2022).
+    Improving Fairness via Federated Learning.
+    https://arxiv.org/abs/2110.15545
+"""
+
+from ._messages import (
+    FairbatchOkay,
+    FairbatchSamplingProbas,
+)
+from ._sampling import (
+    FairbatchSamplingController,
+    setup_fairbatch_controller,
+)
+from ._fedfb import setup_fedfb_controller
+from ._dataset import FairbatchDataset
+from ._client import FairbatchControllerClient
+from ._server import FairbatchControllerServer
diff --git a/declearn/fairness/fairbatch/_client.py b/declearn/fairness/fairbatch/_client.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f7891615bda4cc5dac14fea3d078e864a807f46
--- /dev/null
+++ b/declearn/fairness/fairbatch/_client.py
@@ -0,0 +1,136 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Client-side Fed-FairBatch/FedFB controller."""
+
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+
+from declearn.aggregator import SumAggregator
+from declearn.communication.api import NetworkClient
+from declearn.communication.utils import verify_server_message_validity
+from declearn.fairness.api import (
+    FairnessControllerClient,
+    FairnessDataset,
+)
+from declearn.fairness.fairbatch._dataset import FairbatchDataset
+from declearn.fairness.fairbatch._messages import (
+    FairbatchSamplingProbas,
+    FairbatchOkay,
+)
+from declearn.messaging import Error
+from declearn.metrics import MeanMetric
+from declearn.secagg.api import Encrypter
+from declearn.training import TrainingManager
+
+__all__ = [
+    "FairbatchControllerClient",
+]
+
+
+class FairbatchControllerClient(FairnessControllerClient):
+    """Client-side controller to implement Fed-FairBatch or FedFB."""
+
+    algorithm = "fairbatch"
+
+    def __init__(
+        self,
+        manager: TrainingManager,
+        f_type: str,
+        f_args: Dict[str, Any],
+    ) -> None:
+        super().__init__(manager=manager, f_type=f_type, f_args=f_args)
+        assert isinstance(self.manager.train_data, FairnessDataset)
+        self.manager.train_data = FairbatchDataset(self.manager.train_data)
+
+    async def finalize_fairness_setup(
+        self,
+        netwk: NetworkClient,
+        secagg: Optional[Encrypter],
+    ) -> None:
+        # Force the use of a SumAggregator.
+        if not isinstance(self.manager.aggrg, SumAggregator):
+            self.manager.aggrg = SumAggregator()
+        # Receive and assign initial sampling probabilities.
+        await self._update_fairbatch_sampling_probas(netwk)
+
+    async def _update_fairbatch_sampling_probas(
+        self,
+        netwk: NetworkClient,
+    ) -> None:
+        """Run a FairBatch-specific routine to update sampling probabilities.
+
+        Expect a message from the orchestrating server containing the new
+        sensitive group sampling probabilities, and apply them to the
+        training dataset.
+
+        Raises
+        ------
+        RuntimeError
+            If the expected message is not received.
+            If the sampling probabilities' update fails.
+        """
+        # Receive aggregated sensitive weights.
+        received = await netwk.recv_message()
+        message = await verify_server_message_validity(
+            netwk, received, expected=FairbatchSamplingProbas
+        )
+        probas = dict(zip(self.groups, message.probas))
+        # Set the received weights, handling and propagating exceptions if any.
+        try:
+            assert isinstance(self.manager.train_data, FairbatchDataset)
+            self.manager.train_data.set_sampling_probabilities(
+                group_probas=probas
+            )
+        except Exception as exc:
+            self.manager.logger.error(
+                "Exception encountered when setting FairBatch sampling"
+                " probabilities: %s",
+                repr(exc),
+            )
+            await netwk.send_message(Error(repr(exc)))
+            raise RuntimeError(
+                "FairBatch sampling probabilities update failed."
+            ) from exc
+        # If things went well, ping the server back to indicate so.
+        self.manager.logger.info("Updated FairBatch sampling probabilities.")
+        await netwk.send_message(FairbatchOkay())
+
+    def setup_fairness_metrics(
+        self,
+        thresh: Optional[float] = None,
+    ) -> List[MeanMetric]:
+        loss = self.computer.setup_loss_metric(model=self.manager.model)
+        metrics = super().setup_fairness_metrics(thresh=thresh)
+        metrics.append(loss)
+        return metrics
+
+    async def finalize_fairness_round(
+        self,
+        netwk: NetworkClient,
+        secagg: Optional[Encrypter],
+        values: Dict[str, Dict[Tuple[Any, ...], float]],
+    ) -> Dict[str, Union[float, np.ndarray]]:
+        # Await updated sampling probabilities from the server.
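+        # NOTE: this re-runs the same exchange as at setup time, enabling
+        # the server to revise sampling probabilities after every round.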
+        await self._update_fairbatch_sampling_probas(netwk)
+        # Return group-wise local accuracy, model loss and fairness scores.
+        return {
+            f"{metric}_{group}": value
+            for metric, m_dict in values.items()
+            for group, value in m_dict.items()
+        }
diff --git a/declearn/fairness/fairbatch/_dataset.py b/declearn/fairness/fairbatch/_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..b620e34370574db2d0d539458acd4a79d1cbc077
--- /dev/null
+++ b/declearn/fairness/fairbatch/_dataset.py
@@ -0,0 +1,339 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""FairBatch-specific Dataset wrapper and subclass."""
+
+from typing import Any, Dict, Iterator, List, Sequence, Tuple
+
+import numpy as np
+
+from declearn.dataset import Dataset, DataSpecs
+from declearn.fairness.api import FairnessDataset
+from declearn.typing import Batch
+
+__all__ = [
+    "FairbatchDataset",
+]
+
+
+class FairbatchDataset(FairnessDataset):
+    """FairBatch-specific FairnessDataset subclass and wrapper.
+
+    FairBatch is an algorithm to enforce group fairness when learning
+    a classifier, initially designed for the centralized setting, and
+    extendable to the federated one. It mostly relies on changing the
+    way training data batches are drawn: instead of drawing uniformly
+    from the full dataset, FairBatch introduces sampling probabilities
+    attached to the sensitive groups, that are updated throughout time
+    to reflect the model's current fairness levels.
+
+    This class is both a subclass to `FairnessDataset` and a wrapper
+    that is designed to hold such a dataset. It implements a couple
+    of algorithm-specific methods to set or get group-wise sampling
+    probabilities, and transparently introduces the FairBatch logic
+    into the API-defined `generate_batches` method.
+
+    This implementation is based both on the original FairBatch paper
+    and on the reference implementation by the paper's authors. Hence,
+    instead of effectively assigning drawing probabilities to samples
+    based on their sensitive group, batches are in fact drawn as the
+    concatenation of fixed-size sub-batches, drawn from data subsets
+    defined by samples' sensitive group.
+
+    As in the reference implementation (a worked example is given below):
+
+    - output batches always have the same number of samples - this is
+      true even when using `drop_remainder=False`, that merely adds a
+      batch to the sequence of generated batches;
+    - the number of batches to yield is computed based on the total
+      number of samples and the full batch size;
+    - when a subset is exhausted, it is drawn from anew; hence, samples
+      may be seen multiple times in a single "epoch" depending on the
+      groups' number of samples and sampling probabilities;
+    - in the extreme case when a subset is smaller than the number of
+      samples that should be drawn from it for any batch, samples may
+      even be included multiple times in the same batch.
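+
+    As an illustration, with hypothetical figures: given `batch_size=32`
+    and sampling probabilities `{(0, 0): 0.25, (0, 1): 0.25, (1, 0): 0.3,
+    (1, 1): 0.2}`, each yielded batch concatenates sub-batches of 8, 8,
+    10 and 6 samples (`round(proba * batch_size)`) drawn from the four
+    group subsets; with 100 samples in total, `100 // 32 = 3` batches
+    are yielded per epoch (4 when `drop_remainder=False`).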
+
+    In the federated setting, clients may not hold samples to each and
+    every sensitive group. In this implementation, when a client has no
+    samples for a given group, it adjusts the sampling probabilities of
+    all groups for which it has samples. In other words, sampling
+    probabilities are adjusted so that the total batch size is the same
+    across clients, in spite of some clients possibly not having samples
+    for some groups.
+    """
+
+    def __init__(
+        self,
+        base: FairnessDataset,
+    ) -> None:
+        """Instantiate a FairbatchDataset wrapping a FairnessDataset.
+
+        Parameters
+        ----------
+        base:
+            Base `FairnessDataset` instance to wrap so as to apply
+            group-wise subsampling as per the FairBatch algorithm.
+        """
+        self.base = base
+        # Assign a dictionary with sampling probability for each group.
+        self.groups = self.base.get_sensitive_group_definitions()
+        self._counts = self.base.get_sensitive_group_counts()
+        self._sampling_probas = {
+            group: 1.0 / len(self.groups) for group in self.groups
+        }
+
+    # Methods provided by the wrapped dataset (merely interfaced).
+
+    def get_data_specs(
+        self,
+    ) -> DataSpecs:
+        return self.base.get_data_specs()
+
+    def get_sensitive_group_definitions(
+        self,
+    ) -> List[Tuple[Any, ...]]:
+        return self.groups
+
+    def get_sensitive_group_counts(
+        self,
+    ) -> Dict[Tuple[Any, ...], int]:
+        return self._counts.copy()
+
+    def get_sensitive_group_subset(
+        self,
+        group: Tuple[Any, ...],
+    ) -> Dataset:
+        return self.base.get_sensitive_group_subset(group)
+
+    def set_sensitive_group_weights(
+        self,
+        weights: Dict[Tuple[Any, ...], float],
+        adjust_by_counts: bool = False,
+    ) -> None:
+        self.base.set_sensitive_group_weights(weights, adjust_by_counts)
+
+    # FairBatch-specific methods.
+
+    def get_sampling_probabilities(
+        self,
+    ) -> Dict[Tuple[Any, ...], float]:
+        """Access current group-wise sampling probabilities."""
+        return self._sampling_probas.copy()
+
+    def set_sampling_probabilities(
+        self,
+        group_probas: Dict[Tuple[Any, ...], float],
+    ) -> None:
+        """Assign new group-wise sampling probabilities.
+
+        If some groups are not present in the wrapped dataset,
+        scale the probabilities associated with all represented
+        groups so that they sum to 1.
+
+        Parameters
+        ----------
+        group_probas:
+            Dict of group-wise sampling probabilities, with
+            `{(s_attr_1, ..., s_attr_k): sampling_proba}` format.
+
+        Raises
+        ------
+        ValueError
+            If the input probabilities are not positive values
+            or if they do not cover (a superset of) all sensitive
+            groups present in the wrapped dataset.
+        """
+        # Verify that inputs match expectations.
+        if not all(x >= 0 for x in group_probas.values()):
+            raise ValueError(
+                f"'{self.__class__.__name__}.set_sampling_probabilities' "
+                "cannot have a negative probability value as parameter."
+            )
+        if not set(self.groups).issubset(group_probas):
+            raise ValueError(
+                "'FairbatchDataset.set_sampling_probabilities' requires "
+                "input values to cover (a superset of) local sensitive groups."
+            )
+        # Restrict and adjust probabilities to groups with samples.
+        probas = {group: group_probas[group] for group in self.groups}
+        total = sum(probas.values())
+        self._sampling_probas = {
+            key: val / total for key, val in probas.items()
+        }
+
+    def generate_batches(
+        self,
+        batch_size: int,
+        shuffle: bool = False,
+        drop_remainder: bool = True,
+        replacement: bool = False,
+        poisson: bool = False,
+    ) -> Iterator[Batch]:
+        # inherited signature; pylint: disable=too-many-arguments
+        # Compute the number of batches to yield.
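+        # NOTE: this count only depends on the total number of samples and
+        # on the full batch size, not on group-wise sampling probabilities.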
+        nb_samples = sum(self._counts.values())
+        nb_batches = nb_samples // batch_size
+        if (not drop_remainder) and (nb_samples % batch_size):
+            nb_batches += 1
+        # Compute the group-wise number of samples per batch.
+        # NOTE: this number may be reduced if there are too few samples.
+        group_batch_size = {
+            group: round(proba * batch_size)
+            for group, proba in self._sampling_probas.items()
+        }
+        # Yield batches made of a fixed number of samples from each group.
+        generators = [
+            self._generate_sensitive_group_batches(
+                group, nb_batches, g_batch_size, shuffle, replacement, poisson
+            )
+            for group, g_batch_size in group_batch_size.items()
+            if g_batch_size > 0
+        ]
+        for batches in zip(*generators):
+            yield self._concatenate_batches(batches)
+
+    @staticmethod
+    def _concatenate_batches(
+        batches: Sequence[Batch],
+    ) -> Batch:
+        """Concatenate batches of numpy array data."""
+        x_dat = np.concatenate([batch[0] for batch in batches], axis=0)
+        y_dat = (
+            None
+            if batches[0][1] is None
+            else np.concatenate([batch[1] for batch in batches], axis=0)
+        )
+        w_dat = (
+            None
+            if batches[0][2] is None
+            else np.concatenate([batch[2] for batch in batches], axis=0)
+        )
+        return x_dat, y_dat, w_dat
+
+    def _generate_sensitive_group_batches(
+        self,
+        group: Tuple[Any, ...],
+        nb_batches: int,
+        batch_size: int,
+        shuffle: bool,
+        replacement: bool,
+        poisson: bool,
+    ) -> Iterator[Batch]:
+        """Generate a fixed number of batches for a given sensitive group.
+
+        Parameters
+        ----------
+        group:
+            Sensitive group defining the sub-dataset from which to draw.
+        nb_batches:
+            Number of batches to yield. The dataset will be iterated
+            over if needed to achieve this number.
+        batch_size:
+            Number of samples per batch (will be exact).
+        shuffle:
+            Whether to shuffle the dataset prior to drawing batches.
+        replacement:
+            Whether to draw with replacement between batches.
+        poisson:
+            Whether to use poisson sampling rather than batching.
+        """
+        # backend method; pylint: disable=too-many-arguments
+        args = (shuffle, replacement, poisson)
+        # Fetch the target sub-dataset and its samples count.
+        dataset = self.get_sensitive_group_subset(group)
+        n_samples = self._counts[group]
+        # When the dataset is large enough, merely yield batches.
+        if batch_size <= n_samples:
+            yield from self._generate_batches(
+                dataset, group, nb_batches, batch_size, *args
+            )
+        # When the batch size is larger than the number of data points,
+        # make up a base batch with all points (duplicated if needed),
+        # that will be combined with further batches of data.
+        else:
+            n_repeats, batch_size = divmod(batch_size, n_samples)
+            # Gather the full subset, optionally duplicated.
+            full = self._get_full_dataset(dataset, n_samples, group)
+            if n_repeats > 1:
+                full = self._concatenate_batches([full] * n_repeats)
+            # Add up further (batch-varying) samples (when needed).
+            if batch_size:
+                for batch in self._generate_batches(
+                    dataset, group, nb_batches, batch_size, *args
+                ):
+                    yield self._concatenate_batches([full, batch])
+            else:  # edge case: require exactly N times the full dataset
+                for _ in range(nb_batches):
+                    yield full
+
+    def _generate_batches(
+        self,
+        dataset: Dataset,
+        group: Tuple[Any, ...],
+        nb_batches: int,
+        batch_size: int,
+        shuffle: bool,
+        replacement: bool,
+        poisson: bool,
+    ) -> Iterator[Batch]:
+        """Backend to yield a fixed number of batches from a dataset."""
+        # backend method; pylint: disable=too-many-arguments
+        # Iterate multiple times over the sub-dataset if needed.
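+        # NOTE: `counter` tracks batches yielded across passes over the
+        # sub-dataset, so that exactly `nb_batches` batches are produced.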
+        counter = 0
+        while counter < nb_batches:
+            # Yield batches from the sub-dataset.
+            generator = dataset.generate_batches(
+                batch_size=batch_size,
+                shuffle=shuffle,
+                drop_remainder=True,
+                replacement=replacement,
+                poisson=poisson,
+            )
+            for batch in generator:
+                yield batch
+                counter += 1
+                if counter == nb_batches:
+                    break
+            # Prevent infinite loops and raise an informative error.
+            if not counter:  # pragma: no cover
+                raise RuntimeError(
+                    f"'{self.__class__.__name__}.generate_batches' triggered "
+                    "an infinite loop; this happened when trying to extract "
+                    f"{batch_size}-samples batches for group {group}."
+                )
+
+    @staticmethod
+    def _get_full_dataset(
+        dataset: Dataset,
+        n_samples: int,
+        group: Tuple[Any, ...],
+    ) -> Batch:
+        """Return a batch containing an entire dataset's samples."""
+        try:
+            generator = dataset.generate_batches(
+                batch_size=n_samples,
+                shuffle=False,
+                drop_remainder=False,
+                replacement=False,
+                poisson=False,
+            )
+            return next(generator)
+        except StopIteration as exc:  # pragma: no cover
+            raise RuntimeError(
+                f"Failed to fetch the full subdataset for group '{group}'."
+            ) from exc
diff --git a/declearn/fairness/fairbatch/_fedfb.py b/declearn/fairness/fairbatch/_fedfb.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d940f51b3def72e5724cbae1230217ad4f77955
--- /dev/null
+++ b/declearn/fairness/fairbatch/_fedfb.py
@@ -0,0 +1,266 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""FedFB sampling probability controllers."""
+
+from typing import Any, Dict, Tuple
+
+import numpy as np
+
+from declearn.fairness.fairbatch._sampling import (
+    FairbatchDemographicParity,
+    FairbatchEqualizedOdds,
+    FairbatchEqualityOpportunity,
+    FairbatchSamplingController,
+    assign_sensitive_group_labels,
+)
+
+
+__all__ = [
+    "setup_fedfb_controller",
+]
+
+
+class FedFBEqualityOpportunity(FairbatchEqualityOpportunity):
+    """FedFB variant of Equality of Opportunity controller.
+
+    This variant introduces two changes as compared with our Fed-FairBatch:
+
+    - The lambda parameter and difference of losses are written with a
+      different group ordering, albeit resulting in identical results.
+    - When comparing loss values over sensitive groups, the notations from
+      the FedFB paper indicate that the sums of losses over samples in the
+      groups are compared, rather than the averages of group-wise losses;
+      this implementation sticks to the FedFB paper.
+    """
+
+    f_type = "equality_of_opportunity"
+
+    def get_sampling_probas(
+        self,
+    ) -> Dict[Tuple[Any, ...], float]:
+        # Revert the sense of lambda (invert the (1, 0) and (1, 1) groups)
+        # to stick with notations from the FedFB paper.
+        probas = super().get_sampling_probas()
+        label_10 = self.groups["1_0"]
+        label_11 = self.groups["1_1"]
+        probas[label_10], probas[label_11] = probas[label_11], probas[label_10]
+        return probas
+
+    def update_from_losses(
+        self,
+        losses: Dict[Tuple[Any, ...], float],
+    ) -> None:
+        # Recover sum-aggregated losses for the two groups of interest.
+        # Do not scale: obtain sums of sample losses for each group.
+        # This differs from parent class (and centralized FairBatch)
+        # but sticks with the FedFB paper's notations and algorithm.
+        loss_10 = losses[self.groups["1_0"]] * self.counts[self.groups["1_0"]]
+        loss_11 = losses[self.groups["1_1"]] * self.counts[self.groups["1_1"]]
+        # Update lambda based on these and the alpha hyper-parameter.
+        # Note: this is the same as in the parent class, inverting the
+        # sense of groups (1, 0) and (1, 1), to stick with the FedFB paper.
+        if loss_11 > loss_10:
+            self.states["lambda"] = min(
+                self.states["lambda"] + self.alpha, self.states["p_tgt_1"]
+            )
+        elif loss_11 < loss_10:
+            self.states["lambda"] = max(self.states["lambda"] - self.alpha, 0)
+
+
+class FedFBEqualizedOdds(FairbatchEqualizedOdds):
+    """FedFB variant of Equalized Odds controller.
+
+    This variant introduces three changes as compared with our Fed-FairBatch:
+
+    - The lambda parameters and difference of losses are written with a
+      different group ordering, albeit resulting in identical results.
+    - When comparing loss values over sensitive groups, the notations from
+      the FedFB paper indicate that the sums of losses over samples in the
+      groups are compared, rather than the averages of group-wise losses;
+      this implementation sticks to the FedFB paper.
+    - The update rule for lambda parameters has a distinct formula, with the
+      alpha hyper-parameter being here scaled by the difference in losses
+      and normalized by the L2 norm of differences in losses, and both
+      groups' lambda being updated at each step.
+    """
+
+    f_type = "equalized_odds"
+
+    def compute_initial_states(
+        self,
+    ) -> Dict[str, float]:
+        # Switch lambdas: apply to groups (-, 1) rather than (-, 0).
+        states = super().compute_initial_states()
+        states["lambda_1"] = states["p_trgt_0"] - states["lambda_1"]
+        states["lambda_2"] = states["p_trgt_1"] - states["lambda_2"]
+        return states
+
+    def get_sampling_probas(
+        self,
+    ) -> Dict[Tuple[Any, ...], float]:
+        # Rewrite the rules entirely, effectively swapping (0,1)/(0,0)
+        # and (1,1)/(1,0) groups compared with the parent implementation.
+        states = self.states
+        return {
+            self.groups["0_0"]: states["p_trgt_0"] - states["lambda_1"],
+            self.groups["0_1"]: states["lambda_1"],
+            self.groups["1_0"]: states["p_trgt_1"] - states["lambda_2"],
+            self.groups["1_1"]: states["lambda_2"],
+        }
+
+    def update_from_losses(
+        self,
+        losses: Dict[Tuple[Any, ...], float],
+    ) -> None:
+        # Recover sum-aggregated losses for each sensitive group.
+        # Do not scale: obtain sums of sample losses for each group.
+        # This differs from parent class (and centralized FairBatch)
+        # but sticks with the FedFB paper's notations and algorithm.
+        labeled_losses = {
+            label: losses[group] * self.counts[group]
+            for label, group in self.groups.items()
+        }
+        # Compute aggregated-loss differences for each target label.
+        diff_loss_tgt_0 = labeled_losses["0_1"] - labeled_losses["0_0"]
+        diff_loss_tgt_1 = labeled_losses["1_1"] - labeled_losses["1_0"]
+        # Compute the euclidean norm of these values.
+        den = float(np.linalg.norm([diff_loss_tgt_0, diff_loss_tgt_1], ord=2))
+        # Update lambda_1 (affecting groups with y=0).
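+        # NOTE: each update below is a step of magnitude at most `alpha`
+        # (since |diff| <= den), clipped so that each lambda remains in
+        # [0, p_trgt], keeping all sampling probabilities non-negative.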
+        update = self.alpha * diff_loss_tgt_0 / den
+        self.states["lambda_1"] = min(
+            self.states["lambda_1"] + update, self.states["p_trgt_0"]
+        )
+        self.states["lambda_1"] = max(self.states["lambda_1"], 0)
+        # Update lambda_2 (affecting groups with y=1).
+        update = self.alpha * diff_loss_tgt_1 / den
+        self.states["lambda_2"] = min(
+            self.states["lambda_2"] + update, self.states["p_trgt_1"]
+        )
+        self.states["lambda_2"] = max(self.states["lambda_2"], 0)
+
+
+class FedFBDemographicParity(FairbatchDemographicParity):
+    """FedFB variant of Demographic Parity controller."""
+
+    f_type = "demographic_parity"
+
+    def update_from_losses(
+        self,
+        losses: Dict[Tuple[Any, ...], float],
+    ) -> None:
+        # NOTE: losses' aggregation does not differ from the parent class.
+        # pylint: disable=duplicate-code
+        # Recover sum-aggregated losses for each sensitive group.
+        # Obtain {k: n_k * mean(loss over group k)}, i.e. group-wise
+        # sums of sample losses.
+        labeled_losses = {
+            label: losses[group] * self.counts[group]
+            for label, group in self.groups.items()
+        }
+        # Normalize losses based on sensitive attribute counts.
+        # Obtain {k: sum(loss for samples in k) / n_samples_with_attr}.
+        labeled_losses["0_0"] /= self.states["n_attr_0"]
+        labeled_losses["0_1"] /= self.states["n_attr_1"]
+        labeled_losses["1_0"] /= self.states["n_attr_0"]
+        labeled_losses["1_1"] /= self.states["n_attr_1"]
+        # NOTE: this is where things differ from the parent class.
+        # pylint: enable=duplicate-code
+        # Compute an overall fairness value based on all losses.
+        f_val = (
+            -labeled_losses["0_0"]
+            + labeled_losses["0_1"]
+            + labeled_losses["1_0"]
+            - labeled_losses["1_1"]
+            + self.counts[self.groups["0_0"]] / self.states["n_attr_0"]
+            - self.counts[self.groups["0_1"]] / self.states["n_attr_1"]
+        )
+        # Update both lambdas based on this overall value.
+        # Note: in the binary attribute case, $mu_a / ||mu||_2$
+        # is equal to $sign(mu_1) / sqrt(2)$.
+        update = float(np.sign(f_val) * self.alpha / np.sqrt(2))
+        self.states["lambda_1"] = min(
+            self.states["lambda_1"] - update, self.states["p_attr_0"]
+        )
+        self.states["lambda_1"] = max(self.states["lambda_1"], 0)
+        self.states["lambda_2"] = min(
+            self.states["lambda_2"] - update, self.states["p_attr_1"]
+        )
+        self.states["lambda_2"] = max(self.states["lambda_2"], 0)
+
+
+def setup_fedfb_controller(
+    f_type: str,
+    counts: Dict[Tuple[Any, ...], int],
+    target: int = 1,
+    alpha: float = 0.005,
+) -> FairbatchSamplingController:
+    """Instantiate a FedFB sampling probabilities controller.
+
+    This is a drop-in replacement for `setup_fairbatch_controller`
+    that implements update rules matching the Fed-FB algorithm(s)
+    as introduced in [1].
+
+    Parameters
+    ----------
+    f_type:
+        Type of group fairness to optimize for.
+    counts:
+        Dict mapping sensitive group definitions to their total
+        sample counts (across clients). These groups must arise
+        from the crossing of a binary target label and a binary
+        sensitive attribute.
+    target:
+        Target label to treat as positive.
+    alpha:
+        Alpha hyper-parameter, scaling the magnitude of sampling
+        probabilities' updates by the returned controller.
+
+    Returns
+    -------
+    controller:
+        FairBatch sampling probabilities controller matching inputs.
+
+    Raises
+    ------
+    KeyError
+        If `f_type` does not match any known or supported fairness type.
+    ValueError
+        If `counts` keys cannot be matched to canonical group labels.
+
+    References
+    ----------
+    [1] Zeng et al. (2022).
+        Improving Fairness via Federated Learning.
+ https://arxiv.org/abs/2110.15545 + """ + # known duplicate of fairbatch setup; pylint: disable=duplicate-code + controller_types = { + "demographic_parity": FedFBDemographicParity, + "equality_of_opportunity": FedFBEqualityOpportunity, + "equalized_odds": FedFBEqualizedOdds, + } + controller_cls = controller_types.get(f_type, None) + if controller_cls is None: + raise KeyError( + "Unknown or unsupported fairness type parameter for FairBatch " + f"controller initialization: '{f_type}'. Supported values are " + f"{list(controller_types)}." + ) + # Match groups to canonical labels and instantiate the controller. + groups = assign_sensitive_group_labels(groups=list(counts), target=target) + kwargs = {"target": target} if f_type == "equality_of_opportunity" else {} + return controller_cls( # type: ignore[abstract] + groups=groups, counts=counts, alpha=alpha, **kwargs + ) diff --git a/declearn/fairness/fairbatch/_messages.py b/declearn/fairness/fairbatch/_messages.py new file mode 100644 index 0000000000000000000000000000000000000000..c1b9ec19db977cb64ba736f194a459a3b5c23b0d --- /dev/null +++ b/declearn/fairness/fairbatch/_messages.py @@ -0,0 +1,53 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Fed-FairBatch/Fed-FB specific messages.""" + +import dataclasses +from typing import List + + +from declearn.messaging import Message + + +__all__ = [ + "FairbatchOkay", + "FairbatchSamplingProbas", +] + + +@dataclasses.dataclass +class FairbatchOkay(Message): + """Message for client signal that Fed-FairBatch/FedFB update went fine.""" + + typekey = "fairbatch-okay" + + +@dataclasses.dataclass +class FairbatchSamplingProbas(Message): + """Message for server-emitted Fed-FairBatch/Fed-FB sampling probabilities. + + Fields + ------ + probas: + List of group-wise sampling probabilities, ordered based on + an agreed-upon sorted list of sensitive groups. + """ + + probas: List[float] + + typekey = "fairbatch-probas" diff --git a/declearn/fairness/fairbatch/_sampling.py b/declearn/fairness/fairbatch/_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..a7ffb9f0ddd7478c16ee0378befdb3a8afd15301 --- /dev/null +++ b/declearn/fairness/fairbatch/_sampling.py @@ -0,0 +1,435 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
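[Editorial aside: a minimal usage sketch of the sampling-probability
controllers defined below in this module, assuming the private import path
from this diff; the group definitions and counts are purely hypothetical.]

from declearn.fairness.fairbatch._sampling import setup_fairbatch_controller

counts = {(0, "a"): 40, (0, "b"): 60, (1, "a"): 30, (1, "b"): 70}
controller = setup_fairbatch_controller(
    f_type="equalized_odds", counts=counts, target=1, alpha=0.005
)
probas = controller.get_sampling_probas()
# Group-wise probabilities are derived from the lambda states and sum to 1.
assert abs(sum(probas.values()) - 1.0) < 1e-9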
+
+"""FairBatch sampling probability controllers."""
+
+import abc
+from typing import Any, ClassVar, Dict, List, Literal, Tuple
+
+
+from declearn.fairness.api import instantiate_fairness_function
+
+
+__all__ = [
+    "FairbatchSamplingController",
+    "setup_fairbatch_controller",
+]
+
+
+GroupLabel = Literal["0_0", "0_1", "1_0", "1_1"]
+
+
+class FairbatchSamplingController(metaclass=abc.ABCMeta):
+    """ABC to compute and update Fairbatch sampling probabilities."""
+
+    f_type: ClassVar[str]
+
+    def __init__(
+        self,
+        groups: Dict[GroupLabel, Tuple[Any, ...]],
+        counts: Dict[Tuple[Any, ...], int],
+        alpha: float = 0.005,
+        **kwargs: Any,
+    ) -> None:
+        """Instantiate the Fairbatch sampling probabilities controller.
+
+        Parameters
+        ----------
+        groups:
+            Dict mapping canonical labels to sensitive group definitions.
+        counts:
+            Dict mapping sensitive group definitions to sample counts.
+        alpha:
+            Hyper-parameter controlling the update rule for internal
+            states and, thereby, sampling probabilities.
+        **kwargs:
+            Keyword arguments specific to the fairness definition in use.
+        """
+        # Assign input parameters as attributes.
+        self.groups = groups
+        self.counts = counts
+        self.total = sum(counts.values())
+        self.alpha = alpha
+        # Initialize internal states and sampling probabilities.
+        self.states = self.compute_initial_states()
+        # Initialize a fairness function.
+        self.f_func = instantiate_fairness_function(
+            f_type=self.f_type, counts=counts, **kwargs
+        )
+
+    @abc.abstractmethod
+    def compute_initial_states(
+        self,
+    ) -> Dict[str, float]:
+        """Return a dict containing initial internal states.
+
+        Returns
+        -------
+        states:
+            Dict associating float values to arbitrary names that
+            depend on the type of group-fairness being optimized.
+        """
+
+    @abc.abstractmethod
+    def get_sampling_probas(
+        self,
+    ) -> Dict[Tuple[Any, ...], float]:
+        """Return group-wise sampling probabilities.
+
+        Returns
+        -------
+        sampling_probas:
+            Dict mapping sensitive group definitions to their sampling
+            probabilities, as established via the FairBatch algorithm.
+        """
+
+    @abc.abstractmethod
+    def update_from_losses(
+        self,
+        losses: Dict[Tuple[Any, ...], float],
+    ) -> None:
+        """Update internal states based on group-wise losses.
+
+        Parameters
+        ----------
+        losses:
+            Group-wise model loss values, as a `{group_k: loss_k}` dict.
+        """
+
+    def update_from_federated_losses(
+        self,
+        losses: Dict[Tuple[Any, ...], float],
+    ) -> None:
+        """Update internal states based on federated group-wise losses.
+
+        Parameters
+        ----------
+        losses:
+            Group-wise sum-aggregated local-group-count-weighted model
+            loss values, computed over an ensemble of local datasets.
+            I.e. a `{group_k: sum_i(n_ik * loss_ik)}` dict.
+
+        Raises
+        ------
+        KeyError
+            If any defined sensitive group does not have a loss value.
+        """
+        losses = {key: val / self.counts[key] for key, val in losses.items()}
+        self.update_from_losses(losses)
+
+
+class FairbatchEqualityOpportunity(FairbatchSamplingController):
+    """FairbatchSamplingController subclass for 'equality_of_opportunity'."""
+
+    f_type = "equality_of_opportunity"
+
+    def compute_initial_states(
+        self,
+    ) -> Dict[str, float]:
+        # Gather sample counts and the share of samples with target label 1.
+        nsmp_10 = self.counts[self.groups["1_0"]]
+        nsmp_11 = self.counts[self.groups["1_1"]]
+        p_tgt_1 = (nsmp_10 + nsmp_11) / self.total
+        # Assign the initial lambda and fixed quantities to re-use.
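+        # (Here, `lambda` is the sampling probability assigned to group
+        # (1, 0); that of group (1, 1) is derived as `p_tgt_1 - lambda`.)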
+        return {
+            "lambda": nsmp_10 / self.total,
+            "p_tgt_1": p_tgt_1,
+            "p_g_0_0": self.counts[self.groups["0_0"]] / self.total,
+            "p_g_0_1": self.counts[self.groups["0_1"]] / self.total,
+        }
+
+    def get_sampling_probas(
+        self,
+    ) -> Dict[Tuple[Any, ...], float]:
+        return {
+            self.groups["0_0"]: self.states["p_g_0_0"],
+            self.groups["0_1"]: self.states["p_g_0_1"],
+            self.groups["1_0"]: self.states["lambda"],
+            self.groups["1_1"]: self.states["p_tgt_1"] - self.states["lambda"],
+        }
+
+    def update_from_losses(
+        self,
+        losses: Dict[Tuple[Any, ...], float],
+    ) -> None:
+        # Gather mean-aggregated losses for the two groups of interest.
+        loss_10 = losses[self.groups["1_0"]]
+        loss_11 = losses[self.groups["1_1"]]
+        # Update lambda based on these and the alpha hyper-parameter.
+        if loss_10 > loss_11:
+            self.states["lambda"] = min(
+                self.states["lambda"] + self.alpha, self.states["p_tgt_1"]
+            )
+        elif loss_10 < loss_11:
+            self.states["lambda"] = max(self.states["lambda"] - self.alpha, 0)
+
+
+class FairbatchEqualizedOdds(FairbatchSamplingController):
+    """FairbatchSamplingController subclass for 'equalized_odds'."""
+
+    f_type = "equalized_odds"
+
+    def compute_initial_states(
+        self,
+    ) -> Dict[str, float]:
+        # Gather sample counts.
+        nsmp_00 = self.counts[self.groups["0_0"]]
+        nsmp_01 = self.counts[self.groups["0_1"]]
+        nsmp_10 = self.counts[self.groups["1_0"]]
+        nsmp_11 = self.counts[self.groups["1_1"]]
+        # Compute initial lambdas, and target-label-wise sample shares.
+        return {
+            "lambda_1": nsmp_00 / self.total,
+            "lambda_2": nsmp_10 / self.total,
+            "p_trgt_0": (nsmp_00 + nsmp_01) / self.total,
+            "p_trgt_1": (nsmp_10 + nsmp_11) / self.total,
+        }
+
+    def get_sampling_probas(
+        self,
+    ) -> Dict[Tuple[Any, ...], float]:
+        states = self.states
+        return {
+            self.groups["0_0"]: states["lambda_1"],
+            self.groups["0_1"]: states["p_trgt_0"] - states["lambda_1"],
+            self.groups["1_0"]: states["lambda_2"],
+            self.groups["1_1"]: states["p_trgt_1"] - states["lambda_2"],
+        }
+
+    def update_from_losses(
+        self,
+        losses: Dict[Tuple[Any, ...], float],
+    ) -> None:
+        # Compute loss differences for each target label.
+        diff_loss_tgt_0 = (
+            losses[self.groups["0_0"]] - losses[self.groups["0_1"]]
+        )
+        diff_loss_tgt_1 = (
+            losses[self.groups["1_0"]] - losses[self.groups["1_1"]]
+        )
+        # Update a lambda based on these and the alpha hyper-parameter.
+        if abs(diff_loss_tgt_0) > abs(diff_loss_tgt_1):
+            if diff_loss_tgt_0 > 0:
+                self.states["lambda_1"] = min(
+                    self.states["lambda_1"] + self.alpha,
+                    self.states["p_trgt_0"],
+                )
+            elif diff_loss_tgt_0 < 0:
+                self.states["lambda_1"] = max(
+                    self.states["lambda_1"] - self.alpha, 0
+                )
+        else:
+            if diff_loss_tgt_1 > 0:
+                self.states["lambda_2"] = min(
+                    self.states["lambda_2"] + self.alpha,
+                    self.states["p_trgt_1"],
+                )
+            elif diff_loss_tgt_1 < 0:
+                self.states["lambda_2"] = max(
+                    self.states["lambda_2"] - self.alpha, 0
+                )
+
+
+class FairbatchDemographicParity(FairbatchSamplingController):
+    """FairbatchSamplingController subclass for 'demographic_parity'."""
+
+    f_type = "demographic_parity"
+
+    def compute_initial_states(
+        self,
+    ) -> Dict[str, float]:
+        # Gather sample counts.
+        nsmp_00 = self.counts[self.groups["0_0"]]
+        nsmp_01 = self.counts[self.groups["0_1"]]
+        nsmp_10 = self.counts[self.groups["1_0"]]
+        nsmp_11 = self.counts[self.groups["1_1"]]
+        # Compute initial lambdas, attribute-wise sample shares and counts.
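+        # (Here, lambda_1 and lambda_2 are the sampling probabilities of
+        # groups (0, 0) and (0, 1); their complements within each attribute
+        # share go to groups (1, 0) and (1, 1) respectively.)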
+        return {
+            "lambda_1": nsmp_00 / self.total,
+            "lambda_2": nsmp_01 / self.total,
+            "p_attr_0": (nsmp_00 + nsmp_10) / self.total,
+            "p_attr_1": (nsmp_01 + nsmp_11) / self.total,
+            "n_attr_0": nsmp_00 + nsmp_10,
+            "n_attr_1": nsmp_01 + nsmp_11,
+        }
+
+    def get_sampling_probas(
+        self,
+    ) -> Dict[Tuple[Any, ...], float]:
+        states = self.states
+        return {
+            self.groups["0_0"]: states["lambda_1"],
+            self.groups["1_0"]: states["p_attr_0"] - states["lambda_1"],
+            self.groups["0_1"]: states["lambda_2"],
+            self.groups["1_1"]: states["p_attr_1"] - states["lambda_2"],
+        }
+
+    def update_from_losses(
+        self,
+        losses: Dict[Tuple[Any, ...], float],
+    ) -> None:
+        # Recover sum-aggregated losses for each sensitive group.
+        # Obtain {k: sum(loss for all samples in group k)}.
+        labeled_losses = {
+            label: losses[group] * self.counts[group]
+            for label, group in self.groups.items()
+        }
+        # Normalize losses based on sensitive attribute counts.
+        # Obtain {k: sum(loss for samples in k) / n_samples_with_attr}.
+        labeled_losses["0_0"] /= self.states["n_attr_0"]
+        labeled_losses["0_1"] /= self.states["n_attr_1"]
+        labeled_losses["1_0"] /= self.states["n_attr_0"]
+        labeled_losses["1_1"] /= self.states["n_attr_1"]
+        # Compute aggregated-loss differences for each target label.
+        diff_loss_tgt_0 = labeled_losses["0_0"] - labeled_losses["0_1"]
+        diff_loss_tgt_1 = labeled_losses["1_0"] - labeled_losses["1_1"]
+        # Update a lambda based on these and the alpha hyper-parameter.
+        if abs(diff_loss_tgt_0) > abs(diff_loss_tgt_1):
+            if diff_loss_tgt_0 > 0:
+                self.states["lambda_1"] = max(
+                    self.states["lambda_1"] - self.alpha, 0
+                )
+            elif diff_loss_tgt_0 < 0:
+                self.states["lambda_1"] = min(
+                    self.states["lambda_1"] + self.alpha,
+                    self.states["p_attr_0"],
+                )
+        else:
+            if diff_loss_tgt_1 > 0:
+                self.states["lambda_2"] = min(
+                    self.states["lambda_2"] + self.alpha,
+                    self.states["p_attr_1"],
+                )
+            elif diff_loss_tgt_1 < 0:
+                self.states["lambda_2"] = max(
+                    self.states["lambda_2"] - self.alpha, 0
+                )
+
+
+def assign_sensitive_group_labels(
+    groups: List[Tuple[Any, ...]],
+    target: int,
+) -> Dict[GroupLabel, Tuple[Any, ...]]:
+    """Parse sensitive group definitions to match canonical labels.
+
+    Parameters
+    ----------
+    groups:
+        List of sensitive group definitions, as a list of tuples.
+        These should be four tuples arising from the crossing of
+        two binary labels (of any actual type).
+    target:
+        Value of the target label to treat as positive.
+
+    Returns
+    -------
+    labeled_groups:
+        Dict mapping canonical labels `"0_0", "0_1", "1_0", "1_1"`
+        to the input sensitive group definitions.
+
+    Raises
+    ------
+    ValueError
+        If 'groups' has improper length, holds values that do not
+        appear to be binary, or does not match the specified 'target'.
+    """
+    # Verify that groups can be identified as crossing two binary labels.
+    if len(groups) != 4:
+        raise ValueError(
+            "FairBatch requires counts over exactly 4 sensitive groups, "
+            "arising from a binary target label and a binary sensitive "
+            "attribute."
+        )
+    target_values = list({group[0] for group in groups})
+    s_attr_values = sorted(list({group[1] for group in groups}))
+    if not len(target_values) == len(s_attr_values) == 2:
+        raise ValueError(
+            "FairBatch requires sensitive groups to arise from a binary "
+            "target label and a binary sensitive attribute."
+        )
+    # Identify the positive and negative label values.
+    if target_values[0] == target:
+        postgt, negtgt = target_values
+    elif target_values[1] == target:
+        negtgt, postgt = target_values
+    else:
+        raise ValueError(
+            f"Received a target value of '{target}' that does not match any "
+            f"value in the sensitive group definitions: {target_values}."
+        )
+    # Match group definitions with canonical string labels.
+    return {
+        "0_0": (negtgt, s_attr_values[0]),
+        "0_1": (negtgt, s_attr_values[1]),
+        "1_0": (postgt, s_attr_values[0]),
+        "1_1": (postgt, s_attr_values[1]),
+    }
+
+
+def setup_fairbatch_controller(
+    f_type: str,
+    counts: Dict[Tuple[Any, ...], int],
+    target: int = 1,
+    alpha: float = 0.005,
+) -> FairbatchSamplingController:
+    """Instantiate a FairBatch sampling probabilities controller.
+
+    Parameters
+    ----------
+    f_type:
+        Type of group fairness to optimize for.
+    counts:
+        Dict mapping sensitive group definitions to their total
+        sample counts (across clients). These groups must arise
+        from the crossing of a binary target label and a binary
+        sensitive attribute.
+    target:
+        Target label to treat as positive.
+    alpha:
+        Alpha hyper-parameter, scaling the magnitude of the sampling
+        probability updates performed by the returned controller.
+
+    Returns
+    -------
+    controller:
+        FairBatch sampling probabilities controller matching inputs.
+
+    Raises
+    ------
+    KeyError
+        If `f_type` does not match any known or supported fairness type.
+    ValueError
+        If `counts` keys cannot be matched to canonical group labels.
+    """
+    controller_types = {
+        "demographic_parity": FairbatchDemographicParity,
+        "equality_of_opportunity": FairbatchEqualityOpportunity,
+        "equalized_odds": FairbatchEqualizedOdds,
+    }
+    controller_cls = controller_types.get(f_type, None)
+    if controller_cls is None:
+        raise KeyError(
+            "Unknown or unsupported fairness type parameter for FairBatch "
+            f"controller initialization: '{f_type}'. Supported values are "
+            f"{list(controller_types)}."
+        )
+    # Match groups to canonical labels and instantiate the controller.
+    groups = assign_sensitive_group_labels(groups=list(counts), target=target)
+    kwargs = {"target": target} if f_type == "equality_of_opportunity" else {}
+    return controller_cls(  # type: ignore[abstract]
+        groups=groups, counts=counts, alpha=alpha, **kwargs
+    )
diff --git a/declearn/fairness/fairbatch/_server.py b/declearn/fairness/fairbatch/_server.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f88aa9a186ac6ab004a7a744d8100a5b690761
--- /dev/null
+++ b/declearn/fairness/fairbatch/_server.py
@@ -0,0 +1,192 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
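[Editorial aside: a quick sketch of the canonical-label assignment performed
by `assign_sensitive_group_labels` above, with hypothetical binary groups.]

from declearn.fairness.fairbatch._sampling import assign_sensitive_group_labels

groups = [(0, "a"), (0, "b"), (1, "a"), (1, "b")]
labels = assign_sensitive_group_labels(groups, target=1)
# labels == {"0_0": (0, "a"), "0_1": (0, "b"),
#            "1_0": (1, "a"), "1_1": (1, "b")}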
+
+"""Server-side Fed-FairBatch/FedFB controller."""
+
+import warnings
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+
+from declearn.aggregator import Aggregator, SumAggregator
+from declearn.communication.api import NetworkServer
+from declearn.communication.utils import verify_client_messages_validity
+from declearn.fairness.api import FairnessControllerServer
+from declearn.fairness.fairbatch._fedfb import setup_fedfb_controller
+from declearn.fairness.fairbatch._messages import (
+    FairbatchOkay,
+    FairbatchSamplingProbas,
+)
+from declearn.fairness.fairbatch._sampling import setup_fairbatch_controller
+from declearn.secagg.api import Decrypter
+
+
+__all__ = [
+    "FairbatchControllerServer",
+]
+
+
+class FairbatchControllerServer(FairnessControllerServer):
+    """Server-side controller to implement Fed-FairBatch or FedFB.
+
+    FairBatch [1] is a group-fairness-enforcing algorithm that relies
+    on a specific form of loss reweighting mediated via the batching
+    of samples for SGD steps. Namely, in FairBatch, batches are drawn
+    by concatenating group-wise sub-batches, the size of which is the
+    product of the desired total batch size and group-wise sampling
+    probabilities, with the latter being updated throughout training
+    based on the measured fairness of the current model.
+
+    This controller implements an adaptation of FairBatch for federated
+    learning, which is limited to the setting of the original paper, i.e.
+    a binary classification task on data that have a single binary
+    sensitive attribute.
+
+    The `fedfb` instantiation parameter controls whether formulas from
+    the original paper should be used for computing and updating group
+    sampling probabilities (the default), or be replaced with variants
+    introduced in the FedFB algorithm from paper [2].
+
+    References
+    ----------
+    - [1]
+        Roh et al. (2020).
+        FairBatch: Batch Selection for Model Fairness.
+        https://arxiv.org/abs/2012.01696
+    - [2]
+        Zeng et al. (2022).
+        Improving Fairness via Federated Learning.
+        https://arxiv.org/abs/2110.15545
+    """
+
+    algorithm = "fairbatch"
+
+    def __init__(
+        self,
+        f_type: str,
+        f_args: Optional[Dict[str, Any]] = None,
+        alpha: float = 0.005,
+        fedfb: bool = False,
+    ) -> None:
+        """Instantiate the server-side Fed-FairBatch controller.
+
+        Parameters
+        ----------
+        f_type:
+            Name of the fairness function to evaluate and optimize.
+        f_args:
+            Optional dict of keyword arguments to the fairness function.
+        alpha:
+            Hyper-parameter controlling the update rule for internal
+            states and, thereby, sampling probabilities.
+        fedfb:
+            Whether to use FedFB formulas rather than to stick
+            to those from the original FairBatch paper.
+        """
+        super().__init__(f_type=f_type, f_args=f_args)
+        # Choose whether to use FedFB or FairBatch update rules.
+        self._setup_function = (
+            setup_fedfb_controller if fedfb else setup_fairbatch_controller
+        )
+        # Set up a temporary controller that will be replaced at setup time.
+        self.sampling_controller = self._setup_function(
+            f_type=self.f_type,
+            counts={(0, 0): 1, (0, 1): 1, (1, 0): 1, (1, 1): 1},
+            target=self.f_args.get("target", 1),
+            alpha=alpha,
+        )
+
+    @property
+    def fedfb(self) -> bool:
+        """Whether this controller implements FedFB rather than Fed-FairBatch.
+
+        FedFB is a published adaptation of FairBatch to the federated
+        setting, which introduces changes to some FairBatch formulas.
+
+        Fed-FairBatch is a DecLearn-introduced variant of FedFB that
+        restores the original FairBatch formulas.
+ """ + return self._setup_function is setup_fedfb_controller + + async def finalize_fairness_setup( + self, + netwk: NetworkServer, + secagg: Optional[Decrypter], + counts: List[int], + aggregator: Aggregator, + ) -> Aggregator: + # Set up the FairbatchWeightsController. + self.sampling_controller = self._setup_function( + f_type=self.f_type, + counts=dict(zip(self.groups, counts)), + target=self.f_args.get("target", 1), + alpha=self.sampling_controller.alpha, + ) + # Send initial loss weights to the clients. + await self._send_fairbatch_probas(netwk) + # Force the use of a SumAggregator. + if not isinstance(aggregator, SumAggregator): + warnings.warn( + "Overriding Aggregator choice to a 'SumAggregator', " + "due to the use of Fed-FairBatch.", + category=RuntimeWarning, + ) + aggregator = SumAggregator() + return aggregator + + async def _send_fairbatch_probas( + self, + netwk: NetworkServer, + ) -> None: + """Send FairBatch sensitive group sampling probabilities to clients. + + Await for clients to ping back that things went fine on their side. + """ + netwk.logger.info( + "Sending FairBatch sampling probabilities to clients." + ) + probas = self.sampling_controller.get_sampling_probas() + p_list = [probas[group] for group in self.groups] + await netwk.broadcast_message(FairbatchSamplingProbas(p_list)) + received = await netwk.wait_for_messages() + await verify_client_messages_validity( + netwk, received, expected=FairbatchOkay + ) + + async def finalize_fairness_round( + self, + netwk: NetworkServer, + secagg: Optional[Decrypter], + values: List[float], + ) -> Dict[str, Union[float, np.ndarray]]: + # Unpack group-wise accuracy and loss values. + accuracy = dict(zip(self.groups, values[: len(self.groups)])) + loss = dict(zip(self.groups, values[len(self.groups) :])) + # Update sampling probabilities and send them to clients. + self.sampling_controller.update_from_federated_losses(loss) + await self._send_fairbatch_probas(netwk) + # Package and return accuracy, loss and fairness metrics. + metrics = { + f"accuracy_{key}": val for key, val in accuracy.items() + } # type: Dict[str, Union[float, np.ndarray]] + metrics.update({f"loss_{key}": val for key, val in loss.items()}) + f_func = self.sampling_controller.f_func + fairness = f_func.compute_from_federated_group_accuracy(accuracy) + metrics.update( + {f"{self.f_type}_{key}": val for key, val in fairness.items()} + ) + return metrics diff --git a/declearn/fairness/fairfed/__init__.py b/declearn/fairness/fairfed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..077601e26c03707c6367d2395beced438be697b6 --- /dev/null +++ b/declearn/fairness/fairfed/__init__.py @@ -0,0 +1,85 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""FairFed algorithm controllers and utils. 
+
+Introduction
+------------
+This module provides an implementation of FairFed [1], an
+algorithm that aims at enforcing fairness in a federated learning
+setting by weighting client-wise model updates' averaging based on
+differences between the global and local fairness of the (prior
+version of the) shared model, using somewhat ad hoc discrepancy
+metrics to summarize fairness as scalar values.
+
+This algorithm was originally designed for settings where a binary
+classifier is trained over data with a single binary sensitive
+attribute, with the authors showcasing their generic formulas over
+a limited set of group fairness definitions. DecLearn expands it to
+a broader case, enabling the use of arbitrary fairness definitions
+over data that may have non-binary and/or many sensitive attributes.
+A 'strict' mode is made available to stick to the original paper.
+
+Additionally, the algorithm's authors suggest combining it with other
+mechanisms that aim at enforcing model fairness during local training
+steps. At the moment, our implementation does not support this, but
+future refactoring of the routines that make up a federated learning
+process may enable doing so.
+
+Controllers
+-----------
+* [FairfedControllerClient]
+[declearn.fairness.fairfed.FairfedControllerClient]:
+    Client-side controller to implement FairFed.
+* [FairfedControllerServer]
+[declearn.fairness.fairfed.FairfedControllerServer]:
+    Server-side controller to implement FairFed.
+
+Backend
+-------
+* [FairfedAggregator][declearn.fairness.fairfed.FairfedAggregator]:
+    Fairfed-specific Aggregator using arbitrary averaging weights.
+* [FairfedValueComputer][declearn.fairness.fairfed.FairfedValueComputer]:
+    Fairfed-specific synthetic fairness value computer.
+
+Messages
+--------
+* [FairfedDelta][declearn.fairness.fairfed.FairfedDelta]
+* [FairfedDeltavg][declearn.fairness.fairfed.FairfedDeltavg]
+* [FairfedFairness][declearn.fairness.fairfed.FairfedFairness]
+* [FairfedOkay][declearn.fairness.fairfed.FairfedOkay]
+* [SecaggFairfedDelta][declearn.fairness.fairfed.SecaggFairfedDelta]
+
+References
+----------
+- [1]
+    Ezzeldin et al. (2021).
+    FairFed: Enabling Group Fairness in Federated Learning.
+    https://arxiv.org/abs/2110.00857
+"""
+
+from ._messages import (
+    FairfedDelta,
+    FairfedDeltavg,
+    FairfedFairness,
+    FairfedOkay,
+    SecaggFairfedDelta,
+)
+from ._aggregator import FairfedAggregator
+from ._fairfed import FairfedValueComputer
+from ._client import FairfedControllerClient
+from ._server import FairfedControllerServer
diff --git a/declearn/fairness/fairfed/_aggregator.py b/declearn/fairness/fairfed/_aggregator.py
new file mode 100644
index 0000000000000000000000000000000000000000..92c87bb422793806d652aa3e2c73ed0ceef3e523
--- /dev/null
+++ b/declearn/fairness/fairfed/_aggregator.py
@@ -0,0 +1,85 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
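[Editorial aside: a hedged sketch of the re-weighting rule implemented by the
`FairfedAggregator` below, using hypothetical per-client weights and fairness
gaps; plain dict arithmetic, no declearn imports required.]

beta = 1.0
weights = {"client_a": 100.0, "client_b": 200.0}  # initialized to local counts
deltas = {"client_a": 0.30, "client_b": 0.10}  # |local - global| fairness gaps
delta_avg = sum(deltas.values()) / len(deltas)  # 0.20
weights = {k: w - beta * (deltas[k] - delta_avg) for k, w in weights.items()}
# Clients whose fairness deviates more than average lose averaging weight:
# weights == {"client_a": 99.9, "client_b": 200.1}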
+ +"""FairFed-specific Aggregator subclass.""" + + +from declearn.aggregator import Aggregator, ModelUpdates +from declearn.model.api import Vector + +__all__ = [ + "FairfedAggregator", +] + + +class FairfedAggregator(Aggregator, register=False): + """Fairfed-specific Aggregator using arbitrary averaging weights.""" + + name = "fairfed" + + def __init__( + self, + beta: float = 1.0, + ) -> None: + """Instantiate the Fairfed-specific weight averaging aggregator. + + Parameters + ---------- + beta: + Hyper-parameter controlling the magnitude of averaging weights' + updates across rounds. + """ + self.beta = beta + self._weight = 1.0 + + def initialize_local_weight( + self, + n_samples: int, + ) -> None: + """Initialize the local averaging weight based on dataset size.""" + self._weight = n_samples + + def update_local_weight( + self, + delta_loc: float, + delta_avg: float, + ) -> None: + """Update the local averaging weight based on fairness measures. + + Parameters + ---------- + delta_loc: + Absolute difference between the local and global fairness values. + delta_avg: + Average of `delta_loc` values across all clients. + """ + update = self.beta * (delta_loc - delta_avg) + self._weight -= update + + def prepare_for_sharing( + self, + updates: Vector, + n_steps: int, + ) -> ModelUpdates: + updates = updates * self._weight + return ModelUpdates(updates=updates, weights=self._weight) + + def finalize_updates( + self, + updates: ModelUpdates, + ) -> Vector: + return updates.updates / updates.weights diff --git a/declearn/fairness/fairfed/_client.py b/declearn/fairness/fairfed/_client.py new file mode 100644 index 0000000000000000000000000000000000000000..3e27b3348e49d4b19999297284bd2c9f47cb2a91 --- /dev/null +++ b/declearn/fairness/fairfed/_client.py @@ -0,0 +1,152 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Client-side FairFed controller.""" + +from typing import Any, Dict, Optional, Tuple, Union + +import numpy as np + +from declearn.communication.api import NetworkClient +from declearn.communication.utils import verify_server_message_validity +from declearn.fairness.api import FairnessControllerClient +from declearn.fairness.fairfed._aggregator import FairfedAggregator +from declearn.fairness.fairfed._fairfed import FairfedValueComputer +from declearn.fairness.fairfed._messages import ( + FairfedDelta, + FairfedDeltavg, + FairfedFairness, + FairfedOkay, + SecaggFairfedDelta, +) +from declearn.secagg.api import Encrypter +from declearn.training import TrainingManager + +__all__ = [ + "FairfedControllerClient", +] + + +class FairfedControllerClient(FairnessControllerClient): + """Client-side controller to implement FairFed.""" + + algorithm = "fairfed" + + def __init__( + self, + manager: TrainingManager, + f_type: str, + f_args: Dict[str, Any], + beta: float, + strict: bool = True, + target: int = 1, + ) -> None: + """Instantiate the client-side fairness controller. 
+
+        Parameters
+        ----------
+        manager:
+            `TrainingManager` instance wrapping the model being trained
+            and its training dataset (that must be a `FairnessDataset`).
+        f_type:
+            Name of the type of group-fairness function being optimized.
+        f_args:
+            Keyword arguments to the group-fairness function.
+        beta:
+            Hyper-parameter controlling the magnitude of averaging weights'
+            updates across rounds.
+        strict:
+            Whether to stick strictly to the FairFed paper's setting
+            and explicit formulas, or to use a broader adaptation of
+            FairFed to more diverse settings.
+        target:
+            Choice of target label to focus on in `strict` mode.
+            Unused when `strict=False`.
+        """
+        # arguments serve modularity; pylint: disable=too-many-arguments
+        super().__init__(manager=manager, f_type=f_type, f_args=f_args)
+        self.beta = beta
+        self.fairfed_computer = FairfedValueComputer(
+            f_type=self.fairness_function.f_type, strict=strict, target=target
+        )
+        self.fairfed_computer.initialize(groups=self.fairness_function.groups)
+
+    @property
+    def strict(
+        self,
+    ) -> bool:
+        """Whether this controller strictly sticks to the FairFed paper."""
+        return self.fairfed_computer.strict
+
+    async def finalize_fairness_setup(
+        self,
+        netwk: NetworkClient,
+        secagg: Optional[Encrypter],
+    ) -> None:
+        # Force the use of a FairFed-specific aggregator.
+        self.manager.aggrg = FairfedAggregator(beta=self.beta)
+        self.manager.aggrg.initialize_local_weight(
+            n_samples=sum(self.computer.counts.values())
+        )
+
+    async def finalize_fairness_round(
+        self,
+        netwk: NetworkClient,
+        secagg: Optional[Encrypter],
+        values: Dict[str, Dict[Tuple[Any, ...], float]],
+    ) -> Dict[str, Union[float, np.ndarray]]:
+        # Await the global synthetic fairness value from the server.
+        received = await netwk.recv_message()
+        fair_glb = await verify_server_message_validity(
+            netwk, received, expected=FairfedFairness
+        )
+        # Compute the absolute difference between local and global fairness.
+        fair_avg = self.fairfed_computer.compute_synthetic_fairness_value(
+            values[self.fairness_function.f_type]
+        )
+        my_delta = FairfedDelta(abs(fair_avg - fair_glb.fairness))
+        # Share it with the server for (secure-)aggregation across clients.
+        if secagg is None:
+            await netwk.send_message(my_delta)
+        else:
+            await netwk.send_message(
+                SecaggFairfedDelta.from_cleartext_message(my_delta, secagg)
+            )
+        # Await the mean absolute fairness difference across clients.
+        received = await netwk.recv_message()
+        deltavg = await verify_server_message_validity(
+            netwk, received, expected=FairfedDeltavg
+        )
+        # Update the aggregation weight of this client.
+        assert isinstance(self.manager.aggrg, FairfedAggregator)
+        self.manager.aggrg.update_local_weight(
+            delta_loc=my_delta.delta,
+            delta_avg=deltavg.deltavg,
+        )
+        # Signal the server that things went well.
+        await netwk.send_message(FairfedOkay())
+        # Flatten group-wise local accuracy and fairness scores.
+        metrics = {
+            f"{metric}_{group}": value
+            for metric, m_dict in values.items()
+            for group, value in m_dict.items()
+        }  # type: Dict[str, Union[float, np.ndarray]]
+        # Add FairFed-specific metrics, then return.
+ metrics["fairfed_value"] = fair_avg + metrics["fairfed_delta"] = my_delta.delta + metrics["fairfed_deltavg"] = deltavg.deltavg + return metrics diff --git a/declearn/fairness/fairfed/_fairfed.py b/declearn/fairness/fairfed/_fairfed.py new file mode 100644 index 0000000000000000000000000000000000000000..9a829ff4125d79246c1b5e86517b97b729ebe5a5 --- /dev/null +++ b/declearn/fairness/fairfed/_fairfed.py @@ -0,0 +1,154 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Fairfed-specific synthetic fairness value computer.""" + +import warnings +from typing import Any, Dict, List, Optional, Tuple + + +__all__ = [ + "FairfedValueComputer", +] + + +class FairfedValueComputer: + """Fairfed-specific synthetic fairness value computer. + + Strict mode + ----------- + This FairFed implementation comes in two flavors. + + - The "strict" mode sticks to the original FairFed paper: + - It applies only to binary classification tasks with + a single binary sensitive attributes. + - Clients must hold examples to each and every group. + - If `f_type` is not explicitly cited in the original + paper, a `RuntimeWarning` is warned. + - The synthetic fairness value is computed based on + fairness values for two groups: (y=`target`,s=1) + and (y=`target`,s=0). + + - The "non-strict" mode extends to broader settings: + - It applies to any number of sensitive groups. + - Clients may not hold examples of all groups. + - It applies to any type of group-fairness. + - The synthetic fairness value is computed as + the average of all absolute fairness values. + - The local fairness is only computed over groups + that have a least one sample in the local data. + """ + + def __init__( + self, + f_type: str, + strict: bool = True, + target: int = 1, + ) -> None: + """Instantiate the FairFed-specific fairness function wrapper. + + Parameters + ---------- + f_type: + Name of the fairness definition being optimized. + strict: + Whether to stick strictly to the FairFed paper's setting + and explicit formulas, or to use a broader adaptation of + FairFed to more diverse settings. + See class docstring for details. + target: + Choice of target label to focus on in `strict` mode. + Unused when `strict=False`. + """ + self.f_type = f_type + self.strict = strict + self.target = target + self._key_groups = ( + None + ) # type: Optional[Tuple[Tuple[Any, ...], Tuple[Any, ...]]] + + def initialize( + self, + groups: List[Tuple[Any, ...]], + ) -> None: + """Initialize the Fairfed synthetic value computer from group counts. + + Parameters + ---------- + groups: + List of sensitive group definitions. 
+ """ + if self.strict: + self._key_groups = self.identify_key_groups(groups) + + def identify_key_groups( + self, + groups: List[Tuple[Any, ...]], + ) -> Tuple[Tuple[Any, ...], Tuple[Any, ...]]: + """Parse sensitive groups' definitions to identify 'key' ones.""" + if self.f_type not in ( + "demographic_parity", + "equality_of_opportunity", + "equalized_odds", + ): + warnings.warn( + f"Using fairness type '{self.f_type}' with FairFed in 'strict'" + " mode. This is supported, but beyond the original paper.", + RuntimeWarning, + ) + if len(groups) != 4: + raise RuntimeError( + "FairFed in 'strict' mode requires exactly 4 sensitive groups," + " arising from a binary target label and a binary attribute." + ) + key_groups = tuple(sorted([g for g in groups if g[0] == self.target])) + if len(key_groups) != 2: + raise KeyError( + f"Failed to identify the (target,attr_0);(target,attr_1) " + "pair of sensitive groups for FairFed in 'strict' mode " + f"with 'target' value {self.target}." + ) + return key_groups + + def compute_synthetic_fairness_value( + self, + fairness: Dict[Tuple[Any, ...], float], + ) -> float: + """Compute a synthetic fairness value from group-wise ones. + + If `self.strict`, compute the difference between the fairness + values associated with two key sensitive groups, as per the + original FairFed paper for the two definitions exposed by the + authors. + + Otherwise, compute the average of absolute group-wise fairness + values, that applies to more generic fairness formulations than + in the original paper, and may encompass broader information. + + Parameters + ---------- + fairness: + Group-wise fairness metrics, as a `{group_k: score_k}` dict. + + Returns + ------- + value: + Scalar value summarizing the computed fairness. + """ + if self._key_groups is None: + return sum(abs(x) for x in fairness.values()) / len(fairness) + return fairness[self._key_groups[0]] - fairness[self._key_groups[1]] diff --git a/declearn/fairness/fairfed/_messages.py b/declearn/fairness/fairfed/_messages.py new file mode 100644 index 0000000000000000000000000000000000000000..d60f233cbe40e5a81f8e2e4563e53acdc0756d6a --- /dev/null +++ b/declearn/fairness/fairfed/_messages.py @@ -0,0 +1,120 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""FairFed specific messages.""" + +import dataclasses + +from typing_extensions import Self # future: import from typing (py>=3.11) + +from declearn.messaging import Message +from declearn.secagg.api import Decrypter, Encrypter +from declearn.secagg.messaging import SecaggMessage + + +__all__ = [ + "FairfedDelta", + "FairfedDeltavg", + "FairfedFairness", + "FairfedOkay", + "SecaggFairfedDelta", +] + + +@dataclasses.dataclass +class FairfedOkay(Message): + """Message for client-emitted signal that Fairfed update went fine.""" + + typekey = "fairfed-okay" + + +@dataclasses.dataclass +class FairfedFairness(Message): + """Message for server-emitted Fairfed global fairness value sharing. + + Fields + ------ + fairness: + Global fairness (or accuracy) value. + """ + + fairness: float + + typekey = "fairfed-fairness" + + +@dataclasses.dataclass +class FairfedDelta(Message): + """Message for client-emitted Fairfed absolute fairness difference. + + Fields + ------ + delta: + Local absolute difference in fairness (or accuracy). + """ + + delta: float + + typekey = "fairfed-delta" + + +@dataclasses.dataclass +class SecaggFairfedDelta(SecaggMessage[FairfedDelta]): + """SecAgg-wrapped 'FairfedDelta' message.""" + + typekey = "secagg-fairfed-delta" + + delta: int + + @classmethod + def from_cleartext_message( + cls, + cleartext: FairfedDelta, + encrypter: Encrypter, + ) -> Self: + delta = encrypter.encrypt_float(cleartext.delta) + return cls(delta=delta) + + def decrypt_wrapped_message( + self, + decrypter: Decrypter, + ) -> FairfedDelta: + delta = decrypter.decrypt_float(self.delta) + return FairfedDelta(delta=delta) + + def aggregate( + self, + other: Self, + decrypter: Decrypter, + ) -> Self: + delta = decrypter.sum_encrypted([self.delta, other.delta]) + return self.__class__(delta=delta) + + +@dataclasses.dataclass +class FairfedDeltavg(Message): + """Message for server-emitted Fairfed average absolute fairness difference. + + Fields + ------ + deltavg: + Average absolute difference in fairness (or accuracy). + """ + + deltavg: float + + typekey = "fairfed-deltavg" diff --git a/declearn/fairness/fairfed/_server.py b/declearn/fairness/fairfed/_server.py new file mode 100644 index 0000000000000000000000000000000000000000..8bb5625b6b15b2916b6566f04bf9af5a2755a755 --- /dev/null +++ b/declearn/fairness/fairfed/_server.py @@ -0,0 +1,203 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
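[Editorial aside: a minimal cleartext sketch of the exchange that the
messages above support, with hypothetical per-client delta values.]

from declearn.fairness.fairfed._messages import FairfedDelta, FairfedDeltavg

deltas = [FairfedDelta(delta=0.30), FairfedDelta(delta=0.10)]  # one per client
deltavg = FairfedDeltavg(deltavg=sum(m.delta for m in deltas) / len(deltas))
# The server broadcasts `deltavg` (here 0.20); each client then uses it to
# update its averaging weight, as orchestrated by the server code below.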
+
+"""Server-side FairFed controller."""
+
+import warnings
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+
+from declearn.aggregator import Aggregator
+from declearn.communication.api import NetworkServer
+from declearn.communication.utils import verify_client_messages_validity
+from declearn.fairness.api import (
+    FairnessControllerServer,
+    instantiate_fairness_function,
+)
+from declearn.fairness.fairfed._aggregator import FairfedAggregator
+from declearn.fairness.fairfed._fairfed import FairfedValueComputer
+from declearn.fairness.fairfed._messages import (
+    FairfedDelta,
+    FairfedDeltavg,
+    FairfedFairness,
+    FairfedOkay,
+    SecaggFairfedDelta,
+)
+from declearn.messaging import FairnessSetupQuery
+from declearn.secagg.api import Decrypter
+from declearn.secagg.messaging import aggregate_secagg_messages
+
+
+__all__ = [
+    "FairfedControllerServer",
+]
+
+
+class FairfedControllerServer(FairnessControllerServer):
+    """Server-side controller to implement FairFed.
+
+    FairFed [1] is an algorithm that aims at enforcing fairness in
+    a federated learning setting by altering the aggregation rule
+    for client-wise model updates. It conducts a weighted averaging
+    of these updates that is based on discrepancy metrics between
+    global and client-wise fairness measures.
+
+    This algorithm was originally designed for settings where a binary
+    classifier is trained over data with a single binary sensitive
+    attribute, with the authors showcasing their generic formulas over
+    a limited set of group fairness definitions. DecLearn expands it to
+    a broader case, enabling the use of arbitrary fairness definitions
+    over data that may have non-binary and/or many sensitive attributes.
+    A 'strict' mode, which is turned on by default and can be disabled
+    at instantiation, is made available to stick to the original paper.
+
+    It is worth noting that the authors of FairFed suggest combining it
+    with other mechanisms that aim at enforcing local model fairness; at
+    the moment, this is not implemented in DecLearn, unless end-users
+    implement a custom and specific `Model` subclass to do so.
+
+    References
+    ----------
+    - [1]
+        Ezzeldin et al. (2021).
+        FairFed: Enabling Group Fairness in Federated Learning.
+        https://arxiv.org/abs/2110.00857
+    """
+
+    algorithm = "fairfed"
+
+    def __init__(
+        self,
+        f_type: str,
+        f_args: Optional[Dict[str, Any]] = None,
+        beta: float = 1.0,
+        strict: bool = True,
+        target: Optional[int] = None,
+    ) -> None:
+        """Instantiate the server-side FairFed controller.
+
+        Parameters
+        ----------
+        f_type:
+            Name of the fairness function to evaluate and optimize.
+        f_args:
+            Optional dict of keyword arguments to the fairness function.
+        beta:
+            Hyper-parameter controlling the magnitude of updates to
+            clients' averaging weights.
+        strict:
+            Whether to stick strictly to the FairFed paper's setting
+            and explicit formulas, or to use a broader adaptation of
+            FairFed to more diverse settings.
+        target:
+            If `strict=True`, the target label value on which to focus.
+            If None, try fetching it from `f_args`, else use the default
+            value `1`.
+        """
+        # arguments serve modularity; pylint: disable=too-many-arguments
+        super().__init__(f_type=f_type, f_args=f_args)
+        self.beta = beta
+        # Set up a temporary fairness function, replaced at setup time.
+        self._fairness = instantiate_fairness_function(
+            "accuracy_parity", counts={}
+        )
+        # Set up an uninitialized FairFed value computer.
+        if target is None:
+            target = int(self.f_args.get("target", 1))
+        self.fairfed_computer = FairfedValueComputer(
+            f_type=self.f_type, strict=strict, target=target
+        )
+
+    @property
+    def strict(
+        self,
+    ) -> bool:
+        """Whether this controller strictly sticks to the FairFed paper."""
+        return self.fairfed_computer.strict
+
+    def prepare_fairness_setup_query(
+        self,
+    ) -> FairnessSetupQuery:
+        query = super().prepare_fairness_setup_query()
+        query.params["beta"] = self.beta
+        query.params["strict"] = self.strict
+        query.params["target"] = self.fairfed_computer.target
+        return query
+
+    async def finalize_fairness_setup(
+        self,
+        netwk: NetworkServer,
+        secagg: Optional[Decrypter],
+        counts: List[int],
+        aggregator: Aggregator,
+    ) -> Aggregator:
+        # Set up a fairness function and initialize the FairFed computer.
+        self._fairness = instantiate_fairness_function(
+            self.f_type, counts=dict(zip(self.groups, counts)), **self.f_args
+        )
+        self.fairfed_computer.initialize(groups=self.groups)
+        # Force the use of a FairFed-specific averaging aggregator.
+        warnings.warn(
+            "Overriding Aggregator choice due to the use of FairFed.",
+            category=RuntimeWarning,
+        )
+        return FairfedAggregator(beta=self.beta)
+
+    async def finalize_fairness_round(
+        self,
+        netwk: NetworkServer,
+        secagg: Optional[Decrypter],
+        values: List[float],
+    ) -> Dict[str, Union[float, np.ndarray]]:
+        # Unpack group-wise accuracy values and compute fairness ones.
+        accuracy = dict(zip(self.groups, values))
+        fairness = self._fairness.compute_from_federated_group_accuracy(
+            accuracy
+        )
+        # Share the synthetic global fairness value with clients.
+        fair_avg = self.fairfed_computer.compute_synthetic_fairness_value(
+            fairness
+        )
+        await netwk.broadcast_message(FairfedFairness(fairness=fair_avg))
+        # Await and (secure-)aggregate clients' absolute fairness differences.
+        received = await netwk.wait_for_messages()
+        if secagg is None:
+            replies = await verify_client_messages_validity(
+                netwk, received, expected=FairfedDelta
+            )
+            deltavg = sum(r.delta for r in replies.values()) / len(replies)
+        else:
+            sec_rep = await verify_client_messages_validity(
+                netwk, received, expected=SecaggFairfedDelta
+            )
+            deltavg = aggregate_secagg_messages(sec_rep, secagg).delta
+        # Share the computed value with clients and await their okay signal.
+        await netwk.broadcast_message(FairfedDeltavg(deltavg=deltavg))
+        received = await netwk.wait_for_messages()
+        await verify_client_messages_validity(
+            netwk, received, expected=FairfedOkay
+        )
+        # Package and return accuracy, fairness and computed average metrics.
+        metrics = {
+            f"accuracy_{key}": val for key, val in accuracy.items()
+        }  # type: Dict[str, Union[float, np.ndarray]]
+        metrics.update(
+            {f"{self.f_type}_{key}": val for key, val in fairness.items()}
+        )
+        metrics["fairfed_value"] = fair_avg
+        metrics["fairfed_deltavg"] = deltavg
+        return metrics
diff --git a/declearn/fairness/fairgrad/__init__.py b/declearn/fairness/fairgrad/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7967c1ad17e8fd6e993c03ab87cfe64fd17de63e
--- /dev/null
+++ b/declearn/fairness/fairgrad/__init__.py
@@ -0,0 +1,78 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Fed-FairGrad algorithm controllers and utils.
+
+Introduction
+------------
+This module provides an implementation of Fed-FairGrad,
+a yet-to-be-published algorithm that adapts the FairGrad [1]
+algorithm to the federated learning setting.
+
+FairGrad aims at minimizing the training loss of a model under
+group-fairness constraints, with an optional epsilon tolerance.
+It relies on reweighting the loss using weights that are based
+on sensitive groups, and are updated throughout training based
+on estimates of the current fairness of the trained model.
+
+Fed-FairGrad formulates the same problem, and adjusts client-wise
+weights based on the distribution of group-wise data across clients.
+In its current version, the algorithm keeps weights fixed across the
+local training steps taken between model aggregation steps, and
+updates them based on robust estimates of the aggregated model's
+fairness on the federated training data.
+
+This algorithm is designed for settings where a classifier is trained
+over data with any number of categorical sensitive attributes. It may
+evolve as more theoretical and/or empirical results are obtained as to
+its performance (both in terms of utility and fairness).
+
+Controllers
+-----------
+* [FairgradControllerClient]
+[declearn.fairness.fairgrad.FairgradControllerClient]:
+    Client-side controller to implement Fed-FairGrad.
+* [FairgradControllerServer]
+[declearn.fairness.fairgrad.FairgradControllerServer]:
+    Server-side controller to implement Fed-FairGrad.
+
+Backend
+-------
+* [FairgradWeightsController]
+[declearn.fairness.fairgrad.FairgradWeightsController]:
+    Controller to implement FairGrad optimization constraints.
+
+Messages
+--------
+* [FairgradOkay][declearn.fairness.fairgrad.FairgradOkay]
+* [FairgradWeights][declearn.fairness.fairgrad.FairgradWeights]
+
+
+References
+----------
+- [1]
+    Maheshwari & Perrot (2023).
+    FairGrad: Fairness Aware Gradient Descent.
+    https://openreview.net/forum?id=0f8tU3QwWD
+"""
+
+from ._messages import (
+    FairgradOkay,
+    FairgradWeights,
+)
+from ._client import FairgradControllerClient
+from ._server import FairgradControllerServer, FairgradWeightsController
diff --git a/declearn/fairness/fairgrad/_client.py b/declearn/fairness/fairgrad/_client.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cfd150721c023843b43cfb69148173389863fc2
--- /dev/null
+++ b/declearn/fairness/fairgrad/_client.py
@@ -0,0 +1,106 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Client-side Fed-FairGrad controller.""" + +from typing import Any, Dict, Optional, Tuple, Union + +import numpy as np + +from declearn.aggregator import SumAggregator +from declearn.communication.api import NetworkClient +from declearn.communication.utils import verify_server_message_validity +from declearn.fairness.api import ( + FairnessControllerClient, + FairnessDataset, +) +from declearn.fairness.fairgrad._messages import FairgradOkay, FairgradWeights +from declearn.messaging import Error +from declearn.secagg.api import Encrypter + +__all__ = [ + "FairgradControllerClient", +] + + +class FairgradControllerClient(FairnessControllerClient): + """Client-side controller to implement Fed-FairGrad.""" + + algorithm = "fairgrad" + + async def finalize_fairness_setup( + self, + netwk: NetworkClient, + secagg: Optional[Encrypter], + ) -> None: + # Force the use of a SumAggregator. + if not isinstance(self.manager.aggrg, SumAggregator): + self.manager.aggrg = SumAggregator() + # Await initial loss weights from the server. + await self._update_fairgrad_weights(netwk) + + async def _update_fairgrad_weights( + self, + netwk: NetworkClient, + ) -> None: + """Run a FairGrad-specific routine to update sensitive group weights. + + Expect a message from the orchestrating server containing the new + sensitive group weights, and apply them to the training dataset. + + Raises + ------ + RuntimeError: + If the expected message is not received. + If the weights' update fails. + """ + # Receive aggregated sensitive weights. + received = await netwk.recv_message() + message = await verify_server_message_validity( + netwk, received, expected=FairgradWeights + ) + weights = dict(zip(self.groups, message.weights)) + # Set the received weights, handling and propagating exceptions if any. + try: + assert isinstance(self.manager.train_data, FairnessDataset) + self.manager.train_data.set_sensitive_group_weights( + weights, adjust_by_counts=True + ) + except Exception as exc: + self.manager.logger.error( + "Exception encountered when setting FairGrad weights: %s", exc + ) + await netwk.send_message(Error(repr(exc))) + raise RuntimeError("FairGrad weights update failed.") from exc + # If things went well, ping the server back to indicate so. + self.manager.logger.info("Updated FairGrad weights.") + await netwk.send_message(FairgradOkay()) + + async def finalize_fairness_round( + self, + netwk: NetworkClient, + secagg: Optional[Encrypter], + values: Dict[str, Dict[Tuple[Any, ...], float]], + ) -> Dict[str, Union[float, np.ndarray]]: + # Await updated loss weights from the server. + await self._update_fairgrad_weights(netwk) + # Return group-wise local accuracy and fairness scores. + return { + f"{metric}_{group}": value + for metric, m_dict in values.items() + for group, value in m_dict.items() + } diff --git a/declearn/fairness/fairgrad/_messages.py b/declearn/fairness/fairgrad/_messages.py new file mode 100644 index 0000000000000000000000000000000000000000..fd650cbce3b01f8aa7c95943b3b19de83f26c3e1 --- /dev/null +++ b/declearn/fairness/fairgrad/_messages.py @@ -0,0 +1,53 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Fed-FairGrad specific messages.""" + +import dataclasses +from typing import List + + +from declearn.messaging import Message + + +__all__ = [ + "FairgradOkay", + "FairgradWeights", +] + + +@dataclasses.dataclass +class FairgradOkay(Message): + """Message for client-emitted signal that Fed-FairGrad update went fine.""" + + typekey = "fairgrad-okay" + + +@dataclasses.dataclass +class FairgradWeights(Message): + """Message for server-emitted (Fed-)FairGrad loss weights sharing. + + Fields + ------ + weights: + List of group-wise loss weights, ordered based on + an agreed-upon sorted list of sensitive groups. + """ + + weights: List[float] + + typekey = "fairgrad-weights" diff --git a/declearn/fairness/fairgrad/_server.py b/declearn/fairness/fairgrad/_server.py new file mode 100644 index 0000000000000000000000000000000000000000..b4e4efcdc8a7a65df46847dea170cc3c8b5d18d3 --- /dev/null +++ b/declearn/fairness/fairgrad/_server.py @@ -0,0 +1,273 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Server-side Fed-FairGrad controller.""" + +import warnings +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np + +from declearn.aggregator import Aggregator, SumAggregator +from declearn.communication.api import NetworkServer +from declearn.communication.utils import verify_client_messages_validity +from declearn.fairness.api import ( + FairnessControllerServer, + instantiate_fairness_function, +) +from declearn.fairness.fairgrad._messages import FairgradOkay, FairgradWeights +from declearn.secagg.api import Decrypter + + +__all__ = [ + "FairgradControllerServer", + "FairgradWeightsController", +] + + +class FairgradWeightsController: + """Controller to implement FairGrad optimization constraints.""" + + # attrs serve readability; pylint: disable=too-many-instance-attributes + + def __init__( + self, + counts: Dict[Tuple[Any, ...], int], + f_type: str = "accuracy_parity", + eta: float = 1e-2, + eps: float = 1e-6, + **kwargs: Any, + ) -> None: + """Instantiate the FairGrad controller. + + Parameters + ---------- + counts: + Group-wise counts for all target-label and sensitive-attributes + combinations, with format `{(label, *attrs): count}`. + f_type: + Name of the type of fairness based on which to constrain the + optimization problem. By default, "accuracy_parity". + eta: + Learning rate of the controller, impacting the update rule for + fairness constraints and associated weights.
+ As a rule of thumb, it should be between 1/5 and 1/10 of the + model weights optimizer's learning rate. + eps: + Epsilon value introducing some small tolerance to unfairness. + **kwargs: + Optional keyword arguments to the constants-computing function. + Supported arguments: + - `target: int|list[int]` for "equality_of_opportunity". + """ + # arguments serve modularity; pylint: disable=too-many-arguments + # Store some input parameters. + self.counts = np.array(list(counts.values())) + self.eta = eta + self.eps = eps + self.total = sum(self.counts) # n_samples + # Compute the fairness constraint constants. + self.function = instantiate_fairness_function( + f_type=f_type, counts=counts, **kwargs + ) + # Initialize the gradient-weighting constraint parameters. + n_groups = len(counts) + self.f_k = np.zeros(n_groups) + self._upper = np.zeros(n_groups) # lambda_k^t + self._lower = np.zeros(n_groups) # delta_k^t + + def update_weights_based_on_accuracy( + self, + accuracy: Dict[Tuple[Any, ...], float], + ) -> None: + """Update the held fairness constraint and loss weight parameters. + + Parameters + ---------- + accuracy: + Dict containing group-wise accuracy metrics, formatted + as `{group_k: sum_i(n_ik * accuracy_ik)}`. + """ + f_k = self.function.compute_from_federated_group_accuracy(accuracy) + self.f_k = np.array(list(f_k.values())) + self._upper = np.maximum( + 0, self._upper + self.eta * (self.f_k - self.eps) + ) + self._lower = np.maximum( + 0, self._lower - self.eta * (self.f_k + self.eps) + ) + + def get_current_weights( + self, + norm_nk: bool = True, + ) -> List[float]: + """Return current loss weights for each sensitive group. + + Parameters + ---------- + norm_nk: + Whether to divide output weights by `n_k`. + This is useful in Fed-FairGrad to turn the + base weights into client-wise ones. + + Returns + ------- + weights: + List of group-wise loss weights. + Group definitions may be accessed as `groups` attribute. + """ + # Compute P_k := P(sample \in group_k). + p_tk = self.counts / self.total + # Compute group weights as P_k + Sum_k'(c_k'^k (lambda_k' - delta_k')). + ld_k = self._upper - self._lower + c_kk = self.function.constants[1] + weights = p_tk + np.dot(ld_k, c_kk) + # Optionally normalize weights by group-wise total sample counts. + if norm_nk: + weights /= self.counts + # Output the ordered list of group-wise loss weights. + return weights.tolist() + + def get_current_fairness( + self, + ) -> Dict[Tuple[Any, ...], float]: + """Return the group-wise current fairness level.""" + return { + key: float(val) for key, val in zip(self.function.groups, self.f_k) + } + + +class FairgradControllerServer(FairnessControllerServer): + """Server-side controller to implement Fed-FairGrad. + + FairGrad [1] is an algorithm to learn a model under group-fairness + constraints, that relies on reweighting its training loss based on + the current group-wise fairness levels of the model. + + This controller, together with its client-side counterpart, implements + a straightforward adaptation of FairGrad to the federated learning + setting, where the fairness level of the model is computed robustly + and federatively at the start of each training round, and kept as-is + for all local training steps within that round. + + This algorithm may be applied using any group-fairness definition, + with any number of sensitive attributes and, thereof, groups that + is compatible with the chosen definition. + + References + ---------- + - [1] + Maheshwari & Perrot (2023). 
+ FairGrad: Fairness Aware Gradient Descent. + https://openreview.net/forum?id=0f8tU3QwWD + """ + + algorithm = "fairgrad" + + def __init__( + self, + f_type: str, + f_args: Optional[Dict[str, Any]] = None, + eta: float = 1e-2, + eps: float = 1e-6, + ) -> None: + """Instantiate the server-side Fed-FairGrad controller. + + Parameters + ---------- + f_type: + Name of the fairness function to evaluate and optimize. + f_args: + Optional dict of keyword arguments to the fairness function. + eta: + Learning rate of the controller, impacting the update rule for + fairness constraints and associated weights. + As a rule of thumb, it should be between 1/5 and 1/10 of the + model weights optimizer's learning rate. + eps: + Epsilon value introducing some small tolerance to unfairness. + This may be set to 0.0 to try and enforce absolute fairness. + """ + super().__init__(f_type=f_type, f_args=f_args) + # Set up a temporary controller that will be replaced at setup time. + self.weights_controller = FairgradWeightsController( + counts={}, f_type="accuracy_parity", eta=eta, eps=eps + ) + + async def finalize_fairness_setup( + self, + netwk: NetworkServer, + secagg: Optional[Decrypter], + counts: List[int], + aggregator: Aggregator, + ) -> Aggregator: + # Set up the FairgradWeightsController. + self.weights_controller = FairgradWeightsController( + counts=dict(zip(self.groups, counts)), + f_type=self.f_type, + eta=self.weights_controller.eta, + eps=self.weights_controller.eps, + **self.f_args, + ) + # Send initial loss weights to the clients. + await self._send_fairgrad_weights(netwk) + # Force the use of a SumAggregator. + if not isinstance(aggregator, SumAggregator): + warnings.warn( + "Overriding Aggregator choice to a 'SumAggregator', " + "due to the use of Fed-FairGrad.", + category=RuntimeWarning, + ) + aggregator = SumAggregator() + return aggregator + + async def _send_fairgrad_weights( + self, + netwk: NetworkServer, + ) -> None: + """Send FairGrad sensitive group loss weights to clients. + + Await clients' confirmation that things went fine on their side. + """ + netwk.logger.info("Sending FairGrad weights to clients.") + weights = self.weights_controller.get_current_weights(norm_nk=True) + await netwk.broadcast_message(FairgradWeights(weights=weights)) + received = await netwk.wait_for_messages() + await verify_client_messages_validity( + netwk, received, expected=FairgradOkay + ) + + async def finalize_fairness_round( + self, + netwk: NetworkServer, + secagg: Optional[Decrypter], + values: List[float], + ) -> Dict[str, Union[float, np.ndarray]]: + # Unpack group-wise accuracy metrics and update loss weights. + accuracy = dict(zip(self.groups, values)) + self.weights_controller.update_weights_based_on_accuracy(accuracy) + # Send the updated weights to clients. + await self._send_fairgrad_weights(netwk) + # Package and return accuracy and fairness metrics.
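+        # (Illustration, assuming f_type "accuracy_parity" with sensitive
+        # groups (0,) and (1,): output keys would be "accuracy_(0,)",
+        # "accuracy_(1,)", "accuracy_parity_(0,)" and "accuracy_parity_(1,)".)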
+ metrics = { + f"accuracy_{key}": val for key, val in accuracy.items() + } # type: Dict[str, Union[float, np.ndarray]] + fairness = self.weights_controller.get_current_fairness() + metrics.update( + {f"{self.f_type}_{key}": val for key, val in fairness.items()} + ) + return metrics diff --git a/declearn/fairness/monitor/__init__.py b/declearn/fairness/monitor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6aa1939ff84dbf1537cae3ea30a573eb6f5a5289 --- /dev/null +++ b/declearn/fairness/monitor/__init__.py @@ -0,0 +1,40 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Fairness-monitoring controllers, that leave training unaltered. + +Introduction +------------ +This submodule implements dummy fairness-aware learning controllers, that +implement fairness metrics' computation, hence enabling their monitoring +throughout training, without altering the model's training process itself. + +These controllers may therefore be used to monitor fairness metrics of any +baseline federated learning algorithm, notably for comparison purposes with +fairness-aware algorithms implemented using other controllers (FairBatch, +Fed-FairGrad, ...). + +Controllers +----------- +* [FairnessMonitorClient][declearn.fairness.monitor.FairnessMonitorClient]: + Client-side controller to monitor fairness without altering training. +* [FairnessMonitorServer][declearn.fairness.monitor.FairnessMonitorServer]: + Server-side controller to monitor fairness without altering training. +""" + +from ._client import FairnessMonitorClient +from ._server import FairnessMonitorServer diff --git a/declearn/fairness/monitor/_client.py b/declearn/fairness/monitor/_client.py new file mode 100644 index 0000000000000000000000000000000000000000..01aa352cc9f6f5dd887fb99d59cac01c8a14e0fc --- /dev/null +++ b/declearn/fairness/monitor/_client.py @@ -0,0 +1,55 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Client-side controller to monitor fairness without altering training.""" + +from typing import Any, Dict, Optional, Tuple, Union + +import numpy as np + +from declearn.secagg.api import Encrypter +from declearn.communication.api import NetworkClient +from declearn.fairness.api import FairnessControllerClient + +__all__ = [ + "FairnessMonitorClient", +] + + +class FairnessMonitorClient(FairnessControllerClient): + """Client-side controller to monitor fairness without altering training.""" + + algorithm = "monitor" + + async def finalize_fairness_setup( + self, + netwk: NetworkClient, + secagg: Optional[Encrypter], + ) -> None: + pass + + async def finalize_fairness_round( + self, + netwk: NetworkClient, + secagg: Optional[Encrypter], + values: Dict[str, Dict[Tuple[Any, ...], float]], + ) -> Dict[str, Union[float, np.ndarray]]: + return { + f"{metric}_{group}": value + for metric, m_dict in values.items() + for group, value in m_dict.items() + } diff --git a/declearn/fairness/monitor/_server.py b/declearn/fairness/monitor/_server.py new file mode 100644 index 0000000000000000000000000000000000000000..dbe6c6c56d2e0e61139751e319886f98cf182bd2 --- /dev/null +++ b/declearn/fairness/monitor/_server.py @@ -0,0 +1,95 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Server-side controller to monitor fairness without altering training.""" + +from typing import Any, Dict, List, Optional, Union + +import numpy as np + +from declearn.aggregator import Aggregator +from declearn.secagg.api import Decrypter +from declearn.communication.api import NetworkServer +from declearn.fairness.api import ( + FairnessControllerServer, + instantiate_fairness_function, +) + +__all__ = [ + "FairnessMonitorServer", +] + + +class FairnessMonitorServer(FairnessControllerServer): + """Server-side controller to monitor fairness without altering training. + + This controller, together with its client-side counterpart, + does not alter the training procedure of the model, but adds + computation and communication steps to measure its fairness + level at the start of each and every training round. + + It is compatible with any group-fairness definition implemented + in DecLearn, and any number of sensitive groups compatible with + the chosen definition. + """ + + algorithm = "monitor" + + def __init__( + self, + f_type: str, + f_args: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__(f_type, f_args) + # Assign a temporary fairness functions, replaced at setup time. 
+ self.function = instantiate_fairness_function( + f_type="accuracy_parity", counts={} + ) + + async def finalize_fairness_setup( + self, + netwk: NetworkServer, + secagg: Optional[Decrypter], + counts: List[int], + aggregator: Aggregator, + ) -> Aggregator: + self.function = instantiate_fairness_function( + f_type=self.f_type, + counts=dict(zip(self.groups, counts)), + **self.f_args, + ) + return aggregator + + async def finalize_fairness_round( + self, + netwk: NetworkServer, + secagg: Optional[Decrypter], + values: List[float], + ) -> Dict[str, Union[float, np.ndarray]]: + # Unpack group-wise accuracy metrics and compute fairness ones. + accuracy = dict(zip(self.groups, values)) + fairness = self.function.compute_from_federated_group_accuracy( + accuracy + ) + # Package and return these metrics. + metrics = { + f"accuracy_{key}": val for key, val in accuracy.items() + } # type: Dict[str, Union[float, np.ndarray]] + metrics.update( + {f"{self.f_type}_{key}": val for key, val in fairness.items()} + ) + return metrics diff --git a/declearn/main/__init__.py b/declearn/main/__init__.py index 49164f430f19ad0b92757fb122b139b7acd8d1dc..1b63808121d281580e94b948def022f8bfdba490 100644 --- a/declearn/main/__init__.py +++ b/declearn/main/__init__.py @@ -30,18 +30,17 @@ This module also implements the following submodules, used by the former: Server-side dataclasses that specify an FL process's parameters. The main classes implemented here are `FLRunConfig` and `FLOptimConfig`, that implement parameters' parsing from python objects or from TOML files. -* [privacy][declearn.main.privacy]: - Differentially-Private training routine utils. - The main class implemented here is `DPTrainingManager` that implements - client-side DP-SGD training. This module is to be manually imported or - lazy-imported by `FederatedClient`, and may trigger warnings or errors - in the absence of the 'opacus' third-party dependency. * [utils][declearn.main.utils]: Various utils to the FL process. The main class of interest for end-users is `TrainingManager`, that implements client-side training and evaluation routines, and may therefore be leveraged in a non-FL setting or to implement other FL process routines than the centralized one defined here. + +Finally, the [privacy][declearn.main.privacy] submodule is DEPRECATED as of +DecLearn 2.6 and will be removed in DecLearn 2.8. It was moved and renamed +to [declearn.training.dp][]. It can still be manually imported under its +deprecated name (containing re-exports of moved contents). """ from .
import utils diff --git a/declearn/main/_client.py b/declearn/main/_client.py index 27a9c915ffeb1da6d0f9ddfc5bf4174f5c43420d..506046e634945a75160649fdd7d9e433cb59a5f5 100644 --- a/declearn/main/_client.py +++ b/declearn/main/_client.py @@ -33,11 +33,13 @@ from declearn.communication.utils import ( verify_server_message_validity, ) from declearn.dataset import Dataset, load_dataset_from_json -from declearn.main.utils import Checkpointer, TrainingManager +from declearn.fairness.api import FairnessControllerClient +from declearn.main.utils import Checkpointer from declearn.messaging import Message, SerializedMessage +from declearn.training import TrainingManager from declearn.secagg import parse_secagg_config_client from declearn.secagg.api import Encrypter, SecaggConfigClient, SecaggSetupQuery -from declearn.secagg.messaging import SecaggEvaluationReply, SecaggTrainReply +from declearn.secagg import messaging as secagg_messaging from declearn.utils import LOGGING_LEVEL_MAJOR, get_logger @@ -142,8 +144,9 @@ class FederatedClient: self.logger.warning(msg) warnings.warn(msg, UserWarning, stacklevel=-1) self.verbose = bool(verbose) - # Create a TrainingManager slot, populated at initialization phase. + # Create slots that are (opt.) populated during initialization. self.trainmanager = None # type: Optional[TrainingManager] + self.fairness = None # type: Optional[FairnessControllerClient] @staticmethod def _parse_netwk(netwk) -> Tuple[NetworkClient, bool]: @@ -247,6 +250,8 @@ class FederatedClient: await self.training_round(message.deserialize()) elif issubclass(message.message_cls, messaging.EvaluationRequest): await self.evaluation_round(message.deserialize()) + elif issubclass(message.message_cls, messaging.FairnessQuery): + await self.fairness_round(message.deserialize()) elif issubclass(message.message_cls, SecaggSetupQuery): await self.setup_secagg(message) # note: keep serialized elif issubclass(message.message_cls, messaging.StopTraining): @@ -341,12 +346,15 @@ class FederatedClient: except Exception as exc: await self.netwk.send_message(messaging.Error(repr(exc))) raise RuntimeError("Initialization failed.") from exc + # Send back an empty message to indicate that things went fine. + self.logger.info("Notifying the server that initialization went fine.") + await self.netwk.send_message(messaging.InitReply()) # If instructed to do so, run additional steps to set up DP-SGD. if message.dpsgd: await self._initialize_dpsgd() - # Send back an empty message to indicate that all went fine. - self.logger.info("Notifying the server that initialization went fine.") - await self.netwk.send_message(messaging.InitReply()) + # If instructed to do so, run additional steps to enforce fairness. + if message.fairness: + await self._initialize_fairness() # Optionally checkpoint the received model and optimizer. 
if self.ckptr: self.ckptr.checkpoint( @@ -388,7 +396,7 @@ class FederatedClient: ) except Exception as exc: raise RuntimeError("DP-SGD initialization failed.") from exc - self.logger.info("Received a request to set up DP-SGD.") + self.logger.info("Received DP-SGD setup instructions.") try: self.make_private(message) except Exception as exc: # pylint: disable=broad-except @@ -423,8 +431,7 @@ class FederatedClient: # fmt: off # lazy-import the DPTrainingManager, that involves some optional, # heavy-loadtime dependencies; pylint: disable=import-outside-toplevel - from declearn.main.privacy import DPTrainingManager - + from declearn.training.dp import DPTrainingManager # pylint: enable=import-outside-toplevel self.trainmanager = DPTrainingManager( model=self.trainmanager.model, @@ -438,6 +445,47 @@ class FederatedClient: ) self.trainmanager.make_private(message) + async def _initialize_fairness( + self, + ) -> None: + """Set up a fairness-enforcing algorithm as part of initialization. + + This method is optionally called in the context of `initialize` + and should never be called in another context. + """ + assert self.trainmanager is not None + # Optionally setup SecAgg; await a FairnessSetupQuery. + try: + # When SecAgg is to be used, setup controllers first. + if self.secagg is not None: + received = await self.netwk.recv_message() + await self.setup_secagg(received) + # Await and deserialize a FairnessSetupQuery. + received = await self.netwk.recv_message() + query = await verify_server_message_validity( + self.netwk, received, expected=messaging.FairnessSetupQuery + ) + except Exception as exc: + error = f"Fairness initialization failed: {repr(exc)}." + self.logger.critical(error) + raise RuntimeError(error) from exc + self.logger.info("Received fairness setup instructions.") + # Instantiate a FairnessControllerClient and run its setup routine. + try: + self.fairness = FairnessControllerClient.from_setup_query( + query=query, manager=self.trainmanager + ) + await self.fairness.setup_fairness( + netwk=self.netwk, secagg=self._encrypter + ) + except Exception as exc: + error = ( + f"Fairness-aware federated learning setup failed: {repr(exc)}." + ) + self.logger.critical(error) + await self.netwk.send_message(messaging.Error(error)) + raise RuntimeError(error) from exc + async def setup_secagg( self, received: SerializedMessage[SecaggSetupQuery], @@ -513,7 +561,7 @@ class FederatedClient: if self._encrypter is not None and isinstance( reply, messaging.TrainReply ): - reply = SecaggTrainReply.from_cleartext_message( + reply = secagg_messaging.SecaggTrainReply.from_cleartext_message( cleartext=reply, encrypter=self._encrypter ) # Send training results (or error message) to the server. @@ -564,12 +612,68 @@ class FederatedClient: reply.metrics.clear() # Optionally SecAgg-encrypt results. if self._encrypter is not None: - reply = SecaggEvaluationReply.from_cleartext_message( + msg_cls = secagg_messaging.SecaggEvaluationReply + reply = msg_cls.from_cleartext_message( cleartext=reply, encrypter=self._encrypter ) # Send evaluation results (or error message) to the server. await self.netwk.send_message(reply) + async def fairness_round( + self, + query: messaging.FairnessQuery, + ) -> None: + """Handle a server request to run a fairness-related round. + + The nature of the round depends on the fairness-aware learning + algorithm that was optionally set up during the initialization + phase. In case no such algorithm was set up, this method will + raise a process-crashing exception. 
+ + Parameters + ---------- + query: + `FairnessQuery` message from the server. + + Raises + ------ + RuntimeError + If no fairness controller was set up for this instance. + """ + assert self.trainmanager is not None + # If no fairness controller was set up, raise a RuntimeError. + if self.fairness is None: + error = ( + "Received a query to participate in a fairness round, " + "but no fairness controller was set up." + ) + self.logger.critical(error) + await self.netwk.send_message(messaging.Error(error)) + raise RuntimeError(error) + # When SecAgg is to be used, verify that it was set up. + if self.secagg is not None and self._encrypter is None: + error = ( + "Refusing to participate in fairness-related round " + f"{query.round_i} as SecAgg is configured to be used " + "but was not set up." + ) + self.logger.error(error) + await self.netwk.send_message(messaging.Error(error)) + return + # Otherwise, run the controller's routine. + metrics = await self.fairness.run_fairness_round( + netwk=self.netwk, query=query, secagg=self._encrypter + ) + # Optionally save computed fairness metrics. + # similar to server code; pylint: disable=duplicate-code + if self.ckptr is not None: + self.ckptr.save_metrics( + metrics=metrics, + prefix="fairness_metrics", + append=bool(query.round_i), + timestamp=f"round_{query.round_i}", + ) + async def stop_training( self, message: messaging.StopTraining, diff --git a/declearn/main/_server.py b/declearn/main/_server.py index bf23803ba08373542cd01d2c1c9830f52b735188..cb4d9e86f46dabb7409ce87b03b08ce5c733a4af 100644 --- a/declearn/main/_server.py +++ b/declearn/main/_server.py @@ -33,6 +33,7 @@ from declearn.communication import NetworkServerConfig from declearn.communication.api import NetworkServer from declearn.main.config import ( EvaluateConfig, + FairnessConfig, FLOptimConfig, FLRunConfig, TrainingConfig, @@ -47,13 +48,9 @@ from declearn.metrics import MetricInputType, MetricSet from declearn.metrics._mean import MeanState from declearn.model.api import Model, Vector from declearn.optimizer.modules import AuxVar +from declearn.secagg import messaging as secagg_messaging from declearn.secagg import parse_secagg_config_server from declearn.secagg.api import Decrypter, SecaggConfigServer -from declearn.secagg.messaging import ( - SecaggEvaluationReply, - SecaggMessage, - SecaggTrainReply, -) from declearn.utils import deserialize_object, get_logger @@ -128,6 +125,7 @@ class FederatedServer: self.aggrg = optim.aggregator self.optim = optim.server_opt self.c_opt = optim.client_opt + self.fairness = optim.fairness # note: optional # Assign the wrapped MetricSet. self.metrics = MetricSet.from_specs(metrics) # Assign an optional checkpointer. @@ -270,7 +268,8 @@ class FederatedServer: specify the federated learning process, including clients registration, training and validation rounds' setup, plus optional elements: local differential-privacy parameters, - and/or an early-stopping criterion. + fairness evaluation parameters, and/or an early-stopping + criterion. """ # Instantiate the early-stopping criterion, if any. early_stop = None # type: Optional[EarlyStopping] @@ -285,11 +284,17 @@ class FederatedServer: # Iteratively run training and evaluation rounds. round_i = 0 while True: + # Run (opt.) fairness; training; evaluation. + await self.fairness_round(round_i, config.fairness) round_i += 1 await self.training_round(round_i, config.training) await self.evaluation_round(round_i, config.evaluate) + # Decide whether to keep training for at least one round. 
if not self._keep_training(round_i, config.rounds, early_stop): break + # When checkpointing, evaluate the last model's fairness. + if self.ckptr is not None: + await self.fairness_round(round_i, config.fairness) # Interrupt training when time comes. self.logger.info("Stopping training.") await self.stop_training(round_i) @@ -337,6 +342,7 @@ class FederatedServer: metrics=self.metrics.get_config()["metrics"], dpsgd=config.privacy is not None, secagg=None if self.secagg is None else self.secagg.secagg_type, + fairness=self.fairness is not None, ) self.logger.info("Sending initialization requests to clients.") await self.netwk.broadcast_message(message) @@ -351,6 +357,15 @@ class FederatedServer: # If local differential privacy is configured, set it up. if config.privacy is not None: await self._initialize_dpsgd(config) + # If fairness-aware federated learning is configured, set it up. + if self.fairness is not None: + # When SecAgg is to be used, setup controllers first. + if self.secagg is not None: + await self.setup_secagg() + # Call the setup routine of the held fairness controller. + self.aggrg = await self.fairness.setup_fairness( + netwk=self.netwk, aggregator=self.aggrg, secagg=self._decrypter + ) self.logger.info("Initialization was successful.") async def _require_and_process_data_info( @@ -515,15 +530,59 @@ class FederatedServer: def _aggregate_secagg_replies( self, - replies: Mapping[str, SecaggMessage[MessageT]], + replies: Mapping[str, secagg_messaging.SecaggMessage[MessageT]], ) -> MessageT: """Secure-Aggregate (and decrypt) client-issued encrypted messages.""" assert self._decrypter is not None - encrypted = list(replies.values()) - aggregate = encrypted[0] - for message in encrypted[1:]: - aggregate = aggregate.aggregate(message, decrypter=self._decrypter) - return aggregate.decrypt_wrapped_message(decrypter=self._decrypter) + return secagg_messaging.aggregate_secagg_messages( + replies, decrypter=self._decrypter + ) + + async def fairness_round( + self, + round_i: int, + fairness_cfg: FairnessConfig, + ) -> None: + """Orchestrate a fairness round. + + Parameters + ---------- + round_i: + Index of the latest training round (start at 0). + fairness_cfg: + FairnessConfig dataclass instance wrapping data-batching + and computational effort constraints hyper-parameters for + fairness evaluation. + """ + if self.fairness is None: + return + # Run SecAgg setup when needed. + self.logger.info("Initiating fairness-enforcing round %s", round_i) + clients = self.netwk.client_names # FUTURE: enable sampling(?) + if self.secagg is not None and clients.difference(self._secagg_peers): + await self.setup_secagg(clients) + # Send a query to clients, including model weights when required. + query = messaging.FairnessQuery( + round_i=round_i, + batch_size=fairness_cfg.batch_size, + n_batch=fairness_cfg.n_batch, + thresh=fairness_cfg.thresh, + weights=None, + ) + await self._send_request_with_optional_weights(query, clients) + # Await, (secure-)aggregate and process fairness measures. + metrics = await self.fairness.run_fairness_round( + netwk=self.netwk, + secagg=self._decrypter, + ) + # Optionally save computed fairness metrics. 
+ if self.ckptr is not None: + self.ckptr.save_metrics( + metrics=metrics, + prefix="fairness_metrics", + append=bool(query.round_i), + timestamp=f"round_{query.round_i}", + ) async def training_round( self, @@ -554,7 +613,7 @@ class FederatedServer: ) else: secagg_results = await self._collect_results( - clients, SecaggTrainReply, "training" + clients, secagg_messaging.SecaggTrainReply, "training" ) results = { "aggregated": self._aggregate_secagg_replies(secagg_results) @@ -599,7 +658,11 @@ class FederatedServer: async def _send_request_with_optional_weights( self, - msg_light: Union[messaging.TrainRequest, messaging.EvaluationRequest], + msg_light: Union[ + messaging.TrainRequest, + messaging.EvaluationRequest, + messaging.FairnessQuery, + ], clients: Set[str], ) -> None: """Send a request to clients, sparingly adding model weights to it. @@ -683,7 +746,7 @@ class FederatedServer: ) else: secagg_results = await self._collect_results( - clients, SecaggEvaluationReply, "evaluation" + clients, secagg_messaging.SecaggEvaluationReply, "evaluation" ) results = { "aggregated": self._aggregate_secagg_replies(secagg_results) diff --git a/declearn/main/config/__init__.py b/declearn/main/config/__init__.py index c26bb7eef8287fac3cdfbad24678055809e9456c..42e7c0e5cc1c5d869d9efbe35068eb09e1452a00 100644 --- a/declearn/main/config/__init__.py +++ b/declearn/main/config/__init__.py @@ -33,6 +33,8 @@ The following dataclasses are articulated by `FLRunConfig`: * [EvaluateConfig][declearn.main.config.EvaluateConfig]: Hyper-parameters for an evaluation round. +* [FairnessConfig][declearn.main.config.FairnessConfig]: + Dataclass wrapping parameters for fairness evaluation rounds. * [RegisterConfig][declearn.main.config.RegisterConfig]: Hyper-parameters for clients registration. * [TrainingConfig][declearn.main.config.TrainingConfig]: @@ -41,6 +43,7 @@ The following dataclasses are articulated by `FLRunConfig`: from ._dataclasses import ( EvaluateConfig, + FairnessConfig, PrivacyConfig, RegisterConfig, TrainingConfig, diff --git a/declearn/main/config/_dataclasses.py b/declearn/main/config/_dataclasses.py index 0c4f56148de97402d3aed1355b164d3486870d55..867343e689aa097273bd238bee54dbd32739e82f 100644 --- a/declearn/main/config/_dataclasses.py +++ b/declearn/main/config/_dataclasses.py @@ -22,6 +22,7 @@ from typing import Any, Dict, Optional, Tuple __all__ = [ "EvaluateConfig", + "FairnessConfig", "PrivacyConfig", "RegisterConfig", "TrainingConfig", @@ -223,3 +224,30 @@ class PrivacyConfig: accountants = ("rdp", "gdp", "prv") if self.accountant not in accountants: raise TypeError(f"'accountant' should be one of {accountants}") + + +@dataclasses.dataclass +class FairnessConfig: + """Dataclass wrapping parameters for fairness evaluation rounds. + + The parameters wrapped by this class are those of + `declearn.fairness.core.FairnessAccuracyComputer` + metrics-computation methods. + + Attributes + ---------- + batch_size: int + Number of samples per processed data batch. + n_batch: int or None, default=None + Optional maximum number of batches to draw. + If None, use the entire training dataset. + thresh: float or None, default=None + Optional binarization threshold for binary classification + models' output scores. If None, use 0.5 by default, or 0.0 + for `SklearnSGDModel` instances. + Unused for multinomial classifiers (argmax over scores). 
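+
+    Example
+    -------
+    An illustrative TOML section for `FLRunConfig` (values are arbitrary,
+    not recommended defaults):
+
+        [fairness]
+        batch_size = 64
+        n_batch = 16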
+ """ + + batch_size: int = 32 + n_batch: Optional[int] = None + thresh: Optional[float] = None diff --git a/declearn/main/config/_run_config.py b/declearn/main/config/_run_config.py index ec68dbbd5353d59d13955547656e48bf3f9d42ac..570761133fbb8aec1b3fca42daa4bdaefd033311 100644 --- a/declearn/main/config/_run_config.py +++ b/declearn/main/config/_run_config.py @@ -25,6 +25,7 @@ from typing_extensions import Self # future: import from typing (py >=3.11) from declearn.main.utils import EarlyStopConfig from declearn.main.config._dataclasses import ( EvaluateConfig, + FairnessConfig, PrivacyConfig, RegisterConfig, TrainingConfig, @@ -66,6 +67,10 @@ class FLRunConfig(TomlConfig): and data-batching instructions. - evaluate: EvaluateConfig Parameters for validation rounds, similar to training ones. + - fairness: FairnessConfig or None + Parameters for fairness evaluation rounds. + Only used when an algorithm to enforce fairness is set up, + as part of the process's federated optimization configuration. - privacy: PrivacyConfig or None Optional parameters to set up local differential privacy, by having clients use the DP-SGD algorithm for training. @@ -90,15 +95,20 @@ class FLRunConfig(TomlConfig): batch size will be used for evaluation as well. - If `privacy` is provided and the 'poisson' parameter is unspecified for `training`, it will be set to True by default rather than False. + - If `fairness` is not provided or lacks a 'batch_size' parameter, + that of evaluation (or, by extension, training) will be used. """ rounds: int register: RegisterConfig training: TrainingConfig evaluate: EvaluateConfig + fairness: FairnessConfig privacy: Optional[PrivacyConfig] = None early_stop: Optional[EarlyStopConfig] = None # type: ignore # is a type + autofill_fields = {"evaluate", "fairness"} + @classmethod def parse_register( cls, @@ -128,7 +138,7 @@ class FLRunConfig(TomlConfig): # If evaluation batch size is not set, use the same as training. # Note: if inputs have invalid formats, let the parent method fail. evaluate = kwargs.setdefault("evaluate", {}) - if isinstance(evaluate, dict): + if isinstance(evaluate, dict) and ("batch_size" not in evaluate): training = kwargs.get("training") if isinstance(training, dict): evaluate.setdefault("batch_size", training.get("batch_size")) @@ -141,5 +151,14 @@ class FLRunConfig(TomlConfig): training = kwargs.get("training") if isinstance(training, dict): training.setdefault("poisson", True) + # If fairness batch size is not set, use the same as evaluation. + # Note: if inputs have invalid formats, let the parent method fail. + fairness = kwargs.setdefault("fairness", {}) + if isinstance(fairness, dict) and ("batch_size" not in fairness): + evaluate = kwargs.get("evaluate") + if isinstance(evaluate, dict): + fairness.setdefault("batch_size", evaluate.get("batch_size")) + elif isinstance(evaluate, EvaluateConfig): + fairness.setdefault("batch_size", evaluate.batch_size) # Delegate the rest of the work to the parent method. 
return super().from_params(**kwargs) diff --git a/declearn/main/config/_strategy.py b/declearn/main/config/_strategy.py index 333072a74f9d122aa4d76df2fae8347c4ebad5cd..85e681466dc07041062e9fa89307447f13c12e47 100644 --- a/declearn/main/config/_strategy.py +++ b/declearn/main/config/_strategy.py @@ -19,10 +19,11 @@ import dataclasses import functools -from typing import Any, Dict, Union +from typing import Any, Dict, Optional, Union from declearn.aggregator import Aggregator, AveragingAggregator +from declearn.fairness.api import FairnessControllerServer from declearn.optimizer import Optimizer from declearn.utils import TomlConfig, access_registered, deserialize_object @@ -59,6 +60,9 @@ class FLOptimConfig(TomlConfig): - aggregator: Aggregator, default=AverageAggregator() Client weights aggregator to be used by the server so as to conduct the round-wise aggregation of client updates. + - fairness: Fairness or None, default=None + Optional `FairnessControllerServer` instance specifying + an algorithm to enforce fairness of the trained model. Notes ----- @@ -98,6 +102,7 @@ aggregator: Aggregator = dataclasses.field( default_factory=AveragingAggregator ) + fairness: Optional[FairnessControllerServer] = None @classmethod def parse_client_opt( cls, field: dataclasses.Field, # future: dataclasses.Field[Optimizer] inputs: Union[float, Dict[str, Any], Optimizer], ) -> Optimizer: - """Field-specific parser to instantiate the client-side Optimizer.""" + """Field-specific parser to instantiate the client-side Optimizer. + + This method supports specifying `client_opt`: + + - as a float, parsed as the learning rate to a basic SGD optimizer + - as a dict, parsed as a serialized Optimizer configuration + - as an `Optimizer` instance (requiring no parsing) + """ return cls._parse_optimizer(field, inputs) @classmethod def parse_server_opt( cls, field: dataclasses.Field, # future: dataclasses.Field[Optimizer] inputs: Union[float, Dict[str, Any], Optimizer, None], ) -> Optimizer: - """Field-specific parser to instantiate the server-side Optimizer.""" + """Field-specific parser to instantiate the server-side Optimizer. + + This method supports specifying `server_opt`: + + - as None (or missing kwarg), resulting in a basic `Optimizer(1.0)` + - as a float, parsed as the learning rate to a basic SGD optimizer + - as a dict, parsed as a serialized Optimizer configuration + - as an `Optimizer` instance (requiring no parsing) + """ return cls._parse_optimizer(field, inputs) @classmethod @@ -150,6 +170,7 @@ - (opt.) config: dict specifying kwargs for the constructor - any other field will be added to the `config` kwargs dict - as None (or missing kwarg), using default AveragingAggregator() + - as an `Aggregator` instance (requiring no parsing) """ # Case when using the default value: delegate to the default parser. if inputs is None: @@ -188,3 +209,35 @@ return obj # Otherwise, raise a TypeError as inputs are unsupported. raise TypeError("Unsupported inputs type for field 'aggregator'.") + + @classmethod + def parse_fairness( + cls, + field: dataclasses.Field, # future: dataclasses.Field[<type>] + inputs: Union[Dict[str, Any], FairnessControllerServer, None], + ) -> FairnessControllerServer: + """Field-specific parser to instantiate a FairnessControllerServer.
+ + This method supports specifying `fairness`: + + - as None (or missing kwarg), using no fairness controller + - as a dict, parsed as FairnessControllerServer specifications: + - algorithm: str used to retrieve a registered type + - f_type: str used to define a group fairness function + - (opt.) f_args: dict to parametrize the fairness function + - any other field will be added to the `config` kwargs dict + - as a `FairnessControllerServer` instance (requiring no parsing) + """ + if inputs is None: + return cls.default_parser(field, inputs) + if isinstance(inputs, FairnessControllerServer): + return inputs + if isinstance(inputs, dict): + for key in ("algorithm", "f_type"): + if key not in inputs: + raise TypeError( + "Wrong format for FairnessControllerServer " + f"configuration: missing '{key}' field." + ) + return FairnessControllerServer.from_specs(**inputs) + raise TypeError("Unsupported inputs type for field 'fairness'.") diff --git a/declearn/main/privacy/__init__.py b/declearn/main/privacy/__init__.py index aa256c7ffb56e80e0d915b068d2c3eed4a92d77c..b81a0b985bf94fdac768772812e4c784d5159d33 100644 --- a/declearn/main/privacy/__init__.py +++ b/declearn/main/privacy/__init__.py @@ -15,10 +15,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Submodule implementing Differential-Privacy-oriented tools. +"""DEPRECATED Submodule implementing Differential-Privacy-oriented tools. -* [DPTrainingManager][declearn.main.privacy.DPTrainingManager]: +This module was moved to `declearn.training.dp` as of DecLearn 2.6, and is +only re-exported for retro-compatibility. It will be removed in DecLearn 2.8. + +* [DPTrainingManager][declearn.training.dp.DPTrainingManager]: TrainingManager subclass implementing Differential Privacy mechanisms. """ -from ._dp_trainer import DPTrainingManager +# pragma: no cover + +import warnings + +from declearn.training.dp import DPTrainingManager + +warnings.warn( + "'declearn.main.privacy' was moved to `declearn.training.dp` and is only " + "re-exported for retro-compatibility. It will be removed in DecLearn 2.8.", + DeprecationWarning, +) + +__all__ = ["DPTrainingManager"] diff --git a/declearn/main/utils/__init__.py b/declearn/main/utils/__init__.py index 5b903e43cd7b249c401b293a477e103ee86a9330..604a96d81d4623742ef3f4529c228be4ea9a98e1 100644 --- a/declearn/main/utils/__init__.py +++ b/declearn/main/utils/__init__.py @@ -17,13 +17,6 @@ """Utils for the main federated learning training and evaluation processes. -TrainingManager --------------- -The main class implemented here is `TrainingManager`, that is used by clients -and may also be used to perform centralized machine learning using declearn: - -* [TrainingManager][declearn.main.utils.TrainingManager]: - End-user utils -------------- Utils that may be composed into the main orchestration classes: @@ -45,8 +38,20 @@ Backend utils to aggregate clients' dataset information: * [AggregationError][declearn.main.utils.AggregationError]: Custom exception that may be raised by `aggregate_clients_data_info`. -Backend: effort constraints ---------------------------- + +DEPRECATED TrainingManager +-------------------------- +This class has been moved to `declearn.training.TrainingManager` as of +DecLearn 2.6. It is re-exported merely for retro-compatibility purposes, +but this import path will be removed in DecLearn 2.8. + +* [TrainingManager][declearn.training.TrainingManager]: + Class wrapping the logic for local training and evaluation rounds.
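+
+A minimal sketch of the migration, assuming no other code changes:
+
+    # Deprecated import path (re-exported until DecLearn 2.8):
+    from declearn.main.utils import TrainingManager
+    # New import path:
+    from declearn.training import TrainingManager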
+ + +DEPRECATED Backend: effort constraints +-------------------------------------- + Backend utils that are used to specify and articulate effort constraints for training and evaluation rounds: @@ -56,10 +61,21 @@ for training and evaluation rounds: Utility class to wrap sets of Constraint instances. * [TimeoutConstraint][declearn.main.utils.TimeoutConstraint]: Class implementing a simple time-based constraint. + +The following components have been moved elsewhere and made private as of +DecLearn 2.6. They are re-exported from this module for retro-compatibility +but will be removed in DecLearn 2.8. +**If you are using them, let us know so that we may amend this decision.** """ +# Deprecated re-exports. FUTURE: remove these (DecLearn >=2.8) +from declearn.training import TrainingManager +from declearn.training._constraints import ( + Constraint, + ConstraintSet, + TimeoutConstraint, +) + from ._checkpoint import Checkpointer -from ._constraints import Constraint, ConstraintSet, TimeoutConstraint from ._data_info import AggregationError, aggregate_clients_data_info from ._early_stop import EarlyStopping, EarlyStopConfig -from ._training import TrainingManager diff --git a/declearn/messaging/__init__.py b/declearn/messaging/__init__.py index 17f5e5e42fd6be0fbe6566dbcd68817937447f67..4717448a32d34efe57b3242627fc5eaae4316c6e 100644 --- a/declearn/messaging/__init__.py +++ b/declearn/messaging/__init__.py @@ -44,6 +44,14 @@ Base messages * [TrainReply][declearn.messaging.TrainReply] * [TrainRequest][declearn.messaging.TrainRequest] +Fairness algorithms messages +---------------------------- + +* [FairnessCounts][declearn.messaging.FairnessCounts] +* [FairnessGroups][declearn.messaging.FairnessGroups] +* [FairnessQuery][declearn.messaging.FairnessQuery] +* [FairnessReply][declearn.messaging.FairnessReply] +* [FairnessSetupQuery][declearn.messaging.FairnessSetupQuery] """ from ._api import ( @@ -66,3 +74,10 @@ from ._base import ( TrainReply, TrainRequest, ) +from ._fairness import ( + FairnessCounts, + FairnessGroups, + FairnessQuery, + FairnessReply, + FairnessSetupQuery, +) diff --git a/declearn/messaging/_base.py b/declearn/messaging/_base.py index 52bb60109f8d76dfbf331e06a99ad539459472ae..6cec34d26be39428434367ecdff0330d4a77c96b 100644 --- a/declearn/messaging/_base.py +++ b/declearn/messaging/_base.py @@ -122,6 +122,7 @@ class InitRequest(Message): metrics: List[MetricInputType] = dataclasses.field(default_factory=list) dpsgd: bool = False secagg: Optional[str] = None + fairness: bool = False def to_kwargs(self) -> Dict[str, Any]: data = {} # type: Dict[str, Any] @@ -131,6 +132,7 @@ class InitRequest(Message): data["metrics"] = self.metrics data["dpsgd"] = self.dpsgd data["secagg"] = self.secagg + data["fairness"] = self.fairness return data @classmethod diff --git a/declearn/messaging/_fairness.py b/declearn/messaging/_fairness.py new file mode 100644 index 0000000000000000000000000000000000000000..b850221d0df18d35a44f0423bdb5818c6e88982a --- /dev/null +++ b/declearn/messaging/_fairness.py @@ -0,0 +1,136 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Messages for fairness-aware federated learning setup and rounds.""" + +import dataclasses +from typing import Any, Dict, List, Optional, Tuple + +from typing_extensions import Self # future: import from typing (py >=3.11) + +from declearn.messaging._api import Message +from declearn.model.api import Vector + +__all__ = [ + "FairnessCounts", + "FairnessGroups", + "FairnessQuery", + "FairnessReply", + "FairnessSetupQuery", +] + + +@dataclasses.dataclass +class FairnessCounts(Message): + """Message for client-emitted sample counts across sensitive groups. + + Fields + ------ + counts: + List of group-wise sample counts, ordered based on + an agreed-upon sorted list of sensitive groups. + """ + + counts: List[int] + + typekey = "fairness-counts" + + +@dataclasses.dataclass +class FairnessGroups(Message): + """Message to exchange a list of unique sensitive group definitions. + + This message may be exchanged both ways, with clients sharing the + list of groups for which they have samples and the server sharing + back a unified, sorted list of all sensitive groups across clients. + + Fields + ------ + groups: + List of sensitive group definitions, defined by tuples of values + corresponding to those of one or more sensitive attributes and + (optionally) a target label. + """ + + groups: List[Tuple[Any, ...]] + + typekey = "fairness-groups" + + @classmethod + def from_kwargs( + cls, + **kwargs: Any, + ) -> Self: + kwargs["groups"] = [tuple(group) for group in kwargs["groups"]] + return super().from_kwargs(**kwargs) + + +@dataclasses.dataclass +class FairnessQuery(Message): + """Base Message for server-emitted fairness-computation queries. + + This message conveys hyper-parameters used when evaluating a model's + accuracy and/or loss over group-wise samples (from which fairness is + derived). Model weights may be attached. + + Algorithm-specific information should be conveyed using ad-hoc + messages exchanged as part of fairness-enforcement routines. + """ + + typekey = "fairness-request" + + round_i: int + batch_size: int = 32 + n_batch: Optional[int] = None + thresh: Optional[float] = None + weights: Optional[Vector] = None + + +@dataclasses.dataclass +class FairnessReply(Message): + """Base Message for client-emitted fairness-computation results. + + This message conveys results from the evaluation of a model's accuracy + and/or loss over group-wise samples (from which fairness is derived). + + This information is generically stored as a list of `values`, the + meaning and structure of which is left up to algorithm-specific + controllers. + """ + + typekey = "fairness-reply" + + values: List[float] = dataclasses.field(default_factory=list) + + +@dataclasses.dataclass +class FairnessSetupQuery(Message): + """Message to instruct clients to instantiate a fairness controller. + + Fields + ------ + algorithm: + Name of the algorithm, under which the target controller type + is expected to be registered. + params: + Dict of instantiation keyword arguments to the controller.
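+
+    Example
+    -------
+    A sketch of a query setting up a Fed-FairGrad client controller
+    (parameter values are illustrative, not prescriptive):
+
+        FairnessSetupQuery(
+            algorithm="fairgrad",
+            params={"f_type": "accuracy_parity"},
+        )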
+ """ + + typekey = "fairness-setup-query" + + algorithm: str + params: Dict[str, Any] = dataclasses.field(default_factory=dict) diff --git a/declearn/metrics/__init__.py b/declearn/metrics/__init__.py index 445dbfe0be659711708be35b67a6f2c4e7d5071a..66898a2cea4ec0bc0db84bba7d2dc2072adddb63 100644 --- a/declearn/metrics/__init__.py +++ b/declearn/metrics/__init__.py @@ -40,6 +40,12 @@ Utils Classification metrics ---------------------- +* [Accuracy][declearn.metrics.Accuracy]: + Single-label classification accuracy. + Identified name: "accuracy". +* [BinaryRocAUC][declearn.metrics.BinaryRocAUC]: + Receiver Operator Curve and its Area Under the Curve for binary classif. + Identifier name: "binary-roc". * [BinaryAccuracyPrecisionRecall]\ [declearn.metrics.BinaryAccuracyPrecisionRecall]: Accuracy, precision, recall and confusion matrix for binary classif. @@ -48,9 +54,6 @@ Classification metrics [declearn.metrics.MulticlassAccuracyPrecisionRecall]: Accuracy, precision, recall and confusion matrix for multiclass classif. Identifier name: "multi-classif". -* [BinaryRocAUC][declearn.metrics.BinaryRocAUC]: - Receiver Operator Curve and its Area Under the Curve for binary classif. - Identifier name: "binary-roc". Regression metrics ------------------ @@ -71,7 +74,12 @@ from ._classif import ( BinaryAccuracyPrecisionRecall, MulticlassAccuracyPrecisionRecall, ) -from ._mean import MeanMetric, MeanAbsoluteError, MeanSquaredError +from ._mean import ( + Accuracy, + MeanMetric, + MeanAbsoluteError, + MeanSquaredError, +) from ._roc_auc import BinaryRocAUC from ._rsquared import RSquared from ._wrapper import MetricInputType, MetricSet diff --git a/declearn/metrics/_mean.py b/declearn/metrics/_mean.py index 29aa924d93369b94c1396d04f64a7f0b1ce88d9b..7ed7b1293bf7f19d2588c51e46667ed463498f0a 100644 --- a/declearn/metrics/_mean.py +++ b/declearn/metrics/_mean.py @@ -27,6 +27,7 @@ from declearn.metrics._api import Metric, MetricState from declearn.metrics._utils import squeeze_into_identical_shapes __all__ = [ + "Accuracy", "MeanMetric", "MeanAbsoluteError", "MeanSquaredError", @@ -176,3 +177,40 @@ class MeanSquaredError(MeanMetric): while errors.ndim > 1: errors = errors.sum(axis=-1) return errors + + +class Accuracy(MeanMetric, register=False): + """Metric container to compute classification accuracy iteratively. + + This metric applies to a single-label classification model, + and computes the (opt. weighted) mean sample-wise accuracy. + It requires true labels to be formatted as sample-wise scalar + values (as opposed to one-hot encoded values). + + Computed metric is the following: + + * accuracy: float + Mean accuracy, averaged across samples. 
+ """ + + name = "accuracy" + + def __init__( + self, + thresh: float, + ) -> None: + super().__init__() + self.thresh = thresh + + def metric_func( + self, + y_true: np.ndarray, + y_pred: np.ndarray, + ) -> np.ndarray: + y_pred = ( + y_pred > self.thresh + if (y_pred.ndim == 1) or (y_pred.shape[1] == 1) + else y_pred.max(axis=1) + ) + y_true, y_pred = squeeze_into_identical_shapes(y_true, y_pred) + return y_pred == y_true diff --git a/declearn/metrics/_wrapper.py b/declearn/metrics/_wrapper.py index e56746466479a2c76e2c00dbf845c5ac9fbc5bf4..cbb5cee2afc71f39ccafca789b818b10d79f3bc8 100644 --- a/declearn/metrics/_wrapper.py +++ b/declearn/metrics/_wrapper.py @@ -18,7 +18,7 @@ """Wrapper for an ensemble of Metric objects.""" import warnings -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import numpy as np from typing_extensions import Self # future: import from typing (py >=3.11) @@ -52,7 +52,7 @@ class MetricSet: def __init__( self, - metrics: List[MetricInputType], + metrics: Sequence[MetricInputType], ) -> None: """Instantiate the grouped ensemble of Metric instances. diff --git a/declearn/optimizer/schedulers/__init__.py b/declearn/optimizer/schedulers/__init__.py index 624d08caabb8c001a1a80102263c8a6b9c42d289..141782db36e0bc4982a2561d1144c33b73d85f9a 100644 --- a/declearn/optimizer/schedulers/__init__.py +++ b/declearn/optimizer/schedulers/__init__.py @@ -61,9 +61,9 @@ Cyclic rate rules Warmup schedulers ----------------- -* [Warmup][declearn.optimizers.schedulers.Warmup]: +* [Warmup][declearn.optimizer.schedulers.Warmup]: Scheduler (wrapper) setting up a linear warmup over steps. -* [WarmupRounds][declearn.optimizers.schedulers.WarmupRounds]: +* [WarmupRounds][declearn.optimizer.schedulers.WarmupRounds]: Scheduler (wrapper) setting up a linear warmup over rounds. """ diff --git a/declearn/optimizer/schedulers/_api.py b/declearn/optimizer/schedulers/_api.py index 5a3a7e5bcf9530f8bb74b035aabc85ec7ac0dc99..d14cc27dabf45dde6cd3575c43470d52c3c7f4d2 100644 --- a/declearn/optimizer/schedulers/_api.py +++ b/declearn/optimizer/schedulers/_api.py @@ -99,10 +99,10 @@ class Scheduler(metaclass=abc.ABCMeta): Inheritance ----------- - When a subclass inheriting from `OptiModule` is declared, it is - automatically registered under the "OptiModule" group using its + When a subclass inheriting from `Scheduler` is declared, it is + automatically registered under the "Scheduler" group using its class-attribute `name`. This can be prevented by adding `register=False` - to the inheritance specs (e.g. `class MyCls(OptiModule, register=False)`). + to the inheritance specs (e.g. `class MyCls(Scheduler, register=False)`). See `declearn.utils.register_type` for details on types registration. 
""" diff --git a/declearn/secagg/messaging.py b/declearn/secagg/messaging.py index 8a9d55fe17b43caee82136cb11e4c1aec7c61abe..a2949c418851c9f4be6fc108da765a5c33a5f8fb 100644 --- a/declearn/secagg/messaging.py +++ b/declearn/secagg/messaging.py @@ -19,20 +19,29 @@ import abc import dataclasses -from typing import Dict, Generic, TypeVar +from typing import Dict, Generic, List, Mapping, TypeVar from typing_extensions import Self # future: import from typing (py >=3.11) from declearn.aggregator import ModelUpdates -from declearn.messaging import EvaluationReply, Message, TrainReply +from declearn.messaging import ( + EvaluationReply, + FairnessCounts, + FairnessReply, + Message, + TrainReply, +) from declearn.metrics import MetricState from declearn.optimizer.modules import AuxVar from declearn.secagg.api import Decrypter, Encrypter, SecureAggregate __all__ = [ "SecaggEvaluationReply", + "SecaggFairnessCounts", + "SecaggFairnessReply", "SecaggMessage", "SecaggTrainReply", + "aggregate_secagg_messages", ] @@ -66,7 +75,7 @@ class SecaggMessage( Parameters ---------- - cleartext: + cleartext:1 Message that needs encryption prior to sharing. encrypter: Controller to be used for message contents' encryption. @@ -98,6 +107,33 @@ class SecaggMessage( """Aggregate two clients' SecaggMessage instances into one.""" +def aggregate_secagg_messages( + messages: Mapping[str, SecaggMessage[MessageT]], + decrypter: Decrypter, +) -> MessageT: + """Secure-Aggregate (and decrypt) client-issued encrypted messages. + + Parameters + ---------- + messages: + Mapping of client-wise `SecaggMessage` instances, wrapping + similar messages that need secure aggregation. + decrypter: + Decryption controller to use when aggregating inputs. + + Returns + ------- + message: + Cleartext message resulting from the secure aggregation + of input `messages`. 
+ """ + encrypted = list(messages.values()) + aggregate = encrypted[0] + for message in encrypted[1:]: + aggregate = aggregate.aggregate(message, decrypter=decrypter) + return aggregate.decrypt_wrapped_message(decrypter=decrypter) + + @dataclasses.dataclass class SecaggTrainReply(SecaggMessage[TrainReply]): """SecAgg-wrapped 'TrainReply' message.""" @@ -228,3 +264,80 @@ class SecaggEvaluationReply(SecaggMessage[EvaluationReply]): return self.__class__( loss=loss, n_steps=n_steps, t_spent=t_spent, metrics=metrics ) + + +@dataclasses.dataclass +class SecaggFairnessCounts(SecaggMessage[FairnessCounts]): + """SecAgg counterpart of the 'FairnessCounts' message class.""" + + counts: List[int] + + typekey = "secagg-fairness-counts" + + @classmethod + def from_cleartext_message( + cls, + cleartext: FairnessCounts, + encrypter: Encrypter, + ) -> Self: + counts = [encrypter.encrypt_uint(val) for val in cleartext.counts] + return cls(counts=counts) + + def decrypt_wrapped_message( + self, + decrypter: Decrypter, + ) -> FairnessCounts: + counts = [decrypter.decrypt_uint(val) for val in self.counts] + return FairnessCounts(counts=counts) + + def aggregate( + self, + other: Self, + decrypter: Decrypter, + ) -> Self: + counts = [ + decrypter.sum_encrypted([v_a, v_b]) + for v_a, v_b in zip(self.counts, other.counts) + ] + return self.__class__(counts=counts) + + +@dataclasses.dataclass +class SecaggFairnessReply(SecaggMessage[FairnessReply]): + """SecAgg-wrapped 'FairnessReply' message.""" + + typekey = "secagg_fairness_reply" + + values: List[int] + + @classmethod + def from_cleartext_message( + cls, + cleartext: FairnessReply, + encrypter: Encrypter, + ) -> Self: + values = [encrypter.encrypt_float(value) for value in cleartext.values] + return cls(values=values) + + def decrypt_wrapped_message( + self, + decrypter: Decrypter, + ) -> FairnessReply: + values = [decrypter.decrypt_float(value) for value in self.values] + return FairnessReply(values=values) + + def aggregate( + self, + other: Self, + decrypter: Decrypter, + ) -> Self: + if len(self.values) != len(other.values): + raise ValueError( + "Cannot aggregate SecAgg-protected fairness values with " + "distinct shapes." 
+ ) + values = [ + decrypter.sum_encrypted([v_a, v_b]) + for v_a, v_b in zip(self.values, other.values) + ] + return self.__class__(values=values) diff --git a/declearn/test_utils/__init__.py b/declearn/test_utils/__init__.py index 37915bf7b7f2697688d5a44f4d4594114b93b515..f05bbb7bd6dccb657ffccc26b45cbd03b28fe52f 100644 --- a/declearn/test_utils/__init__.py +++ b/declearn/test_utils/__init__.py @@ -38,7 +38,12 @@ from ._assertions import ( from ._convert import to_numpy from ._gen_ssl import generate_ssl_certificates from ._imports import make_importable -from ._network import MockNetworkClient, MockNetworkServer +from ._network import ( + MockNetworkClient, + MockNetworkServer, + setup_mock_network_endpoints, +) +from ._secagg import build_secagg_controllers from ._vectors import ( FrameworkType, GradientsTestCase, diff --git a/declearn/test_utils/_network.py b/declearn/test_utils/_network.py index db0515329fe155c2b5516c047fcf872cfe5f31fc..e0ad213ce779de4cc3764785b0566e9282c10540 100644 --- a/declearn/test_utils/_network.py +++ b/declearn/test_utils/_network.py @@ -18,9 +18,13 @@ """Fake network communication endpoints relying on shared memory objects.""" import asyncio +import contextlib import logging import uuid -from typing import Dict, Mapping, Optional, Set, TypeVar, Union +from typing import ( + # fmt: off + AsyncIterator, Dict, List, Mapping, Optional, Set, Tuple, TypeVar, Union +) from declearn.communication.api import NetworkClient, NetworkServer @@ -31,6 +35,7 @@ from declearn.messaging import Message, SerializedMessage __all__ = [ "MockNetworkClient", "MockNetworkServer", + "setup_mock_network_endpoints", ] @@ -179,3 +184,47 @@ class MockNetworkClient(NetworkClient, register=False): ) -> SerializedMessage: # Force the use of a timeout, to prevent tests from being stuck. return await super().recv_message(timeout=timeout or 5) + + +@contextlib.asynccontextmanager +async def setup_mock_network_endpoints( + n_peers: int, + port: int = 8765, +) -> AsyncIterator[Tuple[MockNetworkServer, List[MockNetworkClient]]]: + """Instantiate, start and register mock network communication endpoints. + + This is an async context manager, that returns network endpoints, + and ensures they are all properly closed upon leaving the context. + + Parameters + ---------- + n_peers: + Number of client endpoints to instantiate. + port: + Mock port number to use. + + Returns + ------- + server: + `MockNetworkServer` instance to which clients are registered. + clients: + List of `MockNetworkClient` instances, registered to the server. + """ + # Instantiate the endpoints. + server = MockNetworkServer(port=port) + clients = [ + MockNetworkClient(f"mock://localhost:{port}", name=f"client_{i}") + for i in range(n_peers) + ] + async with contextlib.AsyncExitStack() as stack: + # Start the endpoints and ensure they will be properly closed. + await stack.enter_async_context(server) # type: ignore + for client in clients: + await stack.enter_async_context(client) # type: ignore + # Register the clients with the server. + await asyncio.gather( + server.wait_for_clients(n_peers), + *[client.register() for client in clients], + ) + # Yield the started, registered endpoints. 
+ yield server, clients diff --git a/declearn/test_utils/_secagg.py b/declearn/test_utils/_secagg.py new file mode 100644 index 0000000000000000000000000000000000000000..73c6620f5cb4b7ed5eb1ef321955f5ba7d067afc --- /dev/null +++ b/declearn/test_utils/_secagg.py @@ -0,0 +1,61 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Routine to set up some SecAgg controllers.""" + +import secrets +from typing import List, Tuple + + +from declearn.secagg.masking import MaskingDecrypter, MaskingEncrypter + +__all__ = [ + "build_secagg_controllers", +] + + +def build_secagg_controllers( + n_peers: int, +) -> Tuple[MaskingDecrypter, List[MaskingEncrypter]]: + """Setup aligned masking-based encrypters and decrypter. + + Parameters + ---------- + n_peers: + Number of clients for which to set up an encrypter. + + Returns + ------- + decrypter: + `MaskingDecrypter` instance. + encrypters: + List of `MaskingEncrypter` instances with compatible seeds. + """ + n_pairs = int(n_peers * (n_peers - 1) / 2) + s_keys = [secrets.randbits(32) for _ in range(n_pairs)] + clients = [] # type: List[MaskingEncrypter] + starts = [n_peers - i - 1 for i in range(n_peers)] + starts = [sum(starts[:i]) for i in range(n_peers)] + for idx in range(n_peers): + pos = s_keys[starts[idx] : starts[idx] + n_peers - idx - 1] + neg = [s_keys[starts[j] + idx - j - 1] for j in range(idx)] + clients.append( + MaskingEncrypter(pos_masks_seeds=pos, neg_masks_seeds=neg) + ) + server = MaskingDecrypter(n_peers=n_peers) + return server, clients diff --git a/declearn/training/__init__.py b/declearn/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f1fd445de312e222e8692619891474354ae7171f --- /dev/null +++ b/declearn/training/__init__.py @@ -0,0 +1,39 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Model training and evaluation orchestration tools. + +Classes +------- +The main class implemented here is `TrainingManager`, that is used by clients +and may also be used to perform centralized machine learning using declearn: + +* [TrainingManager][declearn.training.TrainingManager]: + Class wrapping the logic for local training and evaluation rounds. + +Submodules +---------- + +* [dp][declearn.training.dp]: + Differentially-Private training routine utils. 
+ The main class implemented here is `DPTrainingManager` that implements + client-side DP-SGD training. This module is to be manually imported or + lazy-imported (e.g. by `declearn.main.FederatedClient`), and may trigger + warnings or errors in the absence of the 'opacus' third-party dependency. +""" + +from ._manager import TrainingManager diff --git a/declearn/main/utils/_constraints.py b/declearn/training/_constraints.py similarity index 100% rename from declearn/main/utils/_constraints.py rename to declearn/training/_constraints.py diff --git a/declearn/main/utils/_training.py b/declearn/training/_manager.py similarity index 98% rename from declearn/main/utils/_training.py rename to declearn/training/_manager.py index 09182c7bb22324bd2e8e621046b14abf737738cf..a3d158f8dcf3d0ccc4dae7ca8888cd6a1075111f 100644 --- a/declearn/main/utils/_training.py +++ b/declearn/training/_manager.py @@ -26,17 +26,20 @@ import tqdm from declearn.aggregator import Aggregator from declearn.communication import messaging from declearn.dataset import Dataset -from declearn.main.utils._constraints import ( - Constraint, - ConstraintSet, - TimeoutConstraint, -) from declearn.metrics import ( - # fmt: off - MeanMetric, Metric, MetricInputType, MetricSet, MetricState + MeanMetric, + Metric, + MetricInputType, + MetricSet, + MetricState, ) from declearn.model.api import Model from declearn.optimizer import Optimizer +from declearn.training._constraints import ( + Constraint, + ConstraintSet, + TimeoutConstraint, +) from declearn.typing import Batch from declearn.utils import LOGGING_LEVEL_MAJOR, get_logger @@ -260,7 +263,7 @@ class TrainingManager: while not (stop_training or epochs.saturated): for batch in self.train_data.generate_batches(**batch_cfg): try: - self._run_train_step(batch) + self.run_train_step(batch) except StopIteration as exc: self.logger.warning("Interrupting training round: %s", exc) stop_training = True @@ -277,7 +280,7 @@ class TrainingManager: effort.update(constraints.get_values()) return effort - def _run_train_step( + def run_train_step( self, batch: Batch, ) -> None: diff --git a/declearn/training/dp/__init__.py b/declearn/training/dp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0076b9694d5c90149594e9b7ee978482cc2608db --- /dev/null +++ b/declearn/training/dp/__init__.py @@ -0,0 +1,24 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Submodule implementing Differential-Privacy-oriented training tools. + +* [DPTrainingManager][declearn.training.dp.DPTrainingManager]: + TrainingManager subclass implementing Differential Privacy mechanisms. 
+""" + +from ._manager import DPTrainingManager diff --git a/declearn/main/privacy/_dp_trainer.py b/declearn/training/dp/_manager.py similarity index 99% rename from declearn/main/privacy/_dp_trainer.py rename to declearn/training/dp/_manager.py index 379ef2d58e8862e62355fbaa48eacbe6f0453c88..26aa2245a8e7f71b221a5da64ce00bec7ecd113a 100644 --- a/declearn/main/privacy/_dp_trainer.py +++ b/declearn/training/dp/_manager.py @@ -26,11 +26,11 @@ from opacus.accountants.utils import get_noise_multiplier # type: ignore from declearn.aggregator import Aggregator from declearn.communication import messaging from declearn.dataset import Dataset -from declearn.main.utils import TrainingManager from declearn.metrics import MetricInputType, MetricSet from declearn.model.api import Model from declearn.optimizer import Optimizer from declearn.optimizer.modules import GaussianNoiseModule +from declearn.training import TrainingManager from declearn.typing import Batch @@ -205,7 +205,7 @@ class DPTrainingManager(TrainingManager): epsilon = self.accountant.get_epsilon(delta=delta) return epsilon, delta - def _run_train_step( + def run_train_step( self, batch: Batch, ) -> None: diff --git a/declearn/utils/_toml_config.py b/declearn/utils/_toml_config.py index adc3c2b87cd6fbf0825b37ff35e4905f032e8b98..a3c886021231c37d03934467d5797008e898d242 100644 --- a/declearn/utils/_toml_config.py +++ b/declearn/utils/_toml_config.py @@ -27,7 +27,7 @@ try: except ModuleNotFoundError: import tomli as tomllib -from typing import Any, Dict, Optional, Type, TypeVar, Union +from typing import Any, ClassVar, Dict, Optional, Set, Type, TypeVar, Union from typing_extensions import Self # future: import from typing (py >=3.11) @@ -178,6 +178,14 @@ class TomlConfig: Instantiate by parsing inputs dicts (or objects). """ + autofill_fields: ClassVar[Set[str]] = set() + """Class attribute listing names of auto-fill fields. + + The listed fields do not have a formal default value, but one is + dynamically created upon parsing other fields. As a consequence, + they may safely been ignored in TOML files or input dict params. + """ + @classmethod def from_params( cls, @@ -334,8 +342,8 @@ class TomlConfig: hyper-parameters making up for the FL "run" configuration. warn_user: bool, default=True Boolean indicating whether to raise a warning when some - fields are unused. Useful for cases where unused fields are - expected, e.g. in declearn-quickrun mode. + fields are unused. Useful for cases where unused fields + are expected, e.g. in declearn-quickrun mode. use_section: optional(str), default=None If not None, points to a specific section of the TOML that should be used, rather than the whole file. Useful to parse @@ -381,10 +389,11 @@ class TomlConfig: elif ( field.default is dataclasses.MISSING and field.default_factory is dataclasses.MISSING + and field.name not in cls.autofill_fields ): raise RuntimeError( - "Missing required section in the TOML configuration " - f"file: '{field.name}'." + "Missing section in the TOML configuration file: " + f"'{field.name}'.", ) # Warn about remaining (unused) config sections. 
         if warn_user:
diff --git a/declearn/version.py b/declearn/version.py
index 559d814484c76615d5e1765f6924d96f8b3ed9ed..70a1a14e9b2da930342dc069a7d713b7a1742893 100644
--- a/declearn/version.py
+++ b/declearn/version.py
@@ -17,5 +17,5 @@
 
 """DecLearn version information, as hard-coded constants."""
 
-VERSION = "2.6.0.dev1"
+VERSION = "2.6.0.dev2"
 """Version information of the installed DecLearn package."""
diff --git a/docs/user-guide/SUMMARY.md b/docs/user-guide/SUMMARY.md
index 865118d0c1c485e2fcbed0dcf86f34eafd5a0916..f6ee456404523e87decd1bfe1fb1422ad3f856e3 100644
--- a/docs/user-guide/SUMMARY.md
+++ b/docs/user-guide/SUMMARY.md
@@ -5,3 +5,4 @@
 - [Guide to the Optimizer API](./optimizer.md)
 - [Local Differential Privacy capabilities](./local_dp.md)
 - [Secure Aggregation capabilities](./secagg.md)
+- [Fairness capabilities](./fairness.md)
diff --git a/docs/user-guide/fairness.md b/docs/user-guide/fairness.md
new file mode 100644
index 0000000000000000000000000000000000000000..0d03d19237d2e6c281c5e0b090fdb5c3873acf22
--- /dev/null
+++ b/docs/user-guide/fairness.md
@@ -0,0 +1,459 @@
+# Fairness
+
+DecLearn provides algorithms and tools to measure and (attempt to) enforce
+fairness constraints when training a machine learning model federatively.
+This guide introduces what (group) fairness is, which settings and
+algorithms are available in DecLearn, how to use them, and how to implement
+custom algorithms that fit into the API.
+
+## Overview
+
+### What is Fairness?
+
+Fairness in machine learning is a wide area of research and algorithms that
+aims at formalizing, measuring and correcting various algorithmic biases
+deemed undesirable. This is notably the case when such biases result in
+models under-performing for some individuals or groups in a way that is
+correlated with attributes such as gender, ethnicity or other
+socio-demographic characteristics that can also be a source of unfair
+discrimination in real life.
+
+Defining what fairness is and how to formalize it is a research topic
+_per se_, one that has remained open and active over the past few years.
+So is the understanding of both the causes and consequences of unfairness
+in machine learning.
+
+### What is Group Fairness?
+
+Group Fairness is one of the main families of approaches to defining fairness
+in machine learning. It applies to classification problems, and to data that
+can be divided into non-overlapping subsets, designated as sensitive groups,
+defined by the intersected values of one or more categorical attributes
+(designated as sensitive attributes) and (usually, but not always) the target
+label.
+
+For instance, when learning a classifier over a human population, sensitive
+attributes may include gender, ethnicity, age groups, etc. Defining relevant
+attributes and assigning samples to them can be a sensitive issue, which may
+motivate a recourse to other families of fairness approaches.
+
+Formally, we can note $\mathcal{Y}$ the set of values for the target label,
+$\mathcal{S}$ the set of (intersected) sensitive attribute values, $S$ the
+random variable over $\mathcal{S}$ denoting a sample's sensitive attribute
+values, $Y$ the random variable over $\mathcal{Y}$ denoting its true target
+label, and $\hat{Y}$ the random variable denoting the label predicted for it
+by the evaluated classifier.
+
+Various group fairness definitions exist, which can overall be summarized as
+achieving a balance between the group-wise accuracy scores of the evaluated
+model.
+The detail of that balance varies with the definitions; some common
+choices include:
+
+- Demographic Parity (also known as Statistical Parity):
+  $$
+  \forall a, b \in \mathcal{S}, \forall y \in \mathcal{Y},
+  \mathbb{P}(\hat{Y} = y | S = a) = \mathbb{P}(\hat{Y} = y | S = b)
+  $$
+- Accuracy Parity:
+  $$
+  \forall a \in \mathcal{S},
+  \mathbb{P}(\hat{Y} = Y | S = a) = \mathbb{P}(\hat{Y} = Y)
+  $$
+- Equalized Odds:
+  $$
+  \forall a \in \mathcal{S}, \forall y \in \mathcal{Y},
+  \mathbb{P}(\hat{Y} = y | Y = y) = \mathbb{P}(\hat{Y} = y | Y = y, S = a)
+  $$
+
+### Fairness in DecLearn
+
+Starting with version 2.6, and following a year-long collaboration with
+researcher colleagues to develop and evaluate fairness-enforcing federated
+learning algorithms, DecLearn provides an API and algorithms that
+(attempt to) enforce group fairness by altering the model training itself
+(as opposed to pre-processing and post-processing methods, which can be
+applied outside of DecLearn).
+
+The dedicated API, shared tools and provided algorithms are implemented
+under the `declearn.fairness` submodule, and integrated into the main
+`declearn.main.FederatedServer` and `declearn.main.FederatedClient` classes,
+making it possible to plug a group-fairness algorithm into any federated
+learning process.
+
+Currently, the implemented features enable end-users to choose among a
+variety of group fairness definitions, to measure fairness throughout model
+training, and optionally to use one of various algorithms that aim at
+enforcing fairness while training a model federatively. The provided
+algorithms are either taken from the literature, original algorithms
+awaiting publication, or adaptations of centralized algorithms to the
+federated learning setting.
+
+It is worth noting that the fairness API and DecLearn-provided algorithms
+are fully compatible with the use of secure aggregation (protecting
+fairness-related values such as group-wise sample counts, accuracy and
+fairness scores), and with any advanced federated optimization strategy
+(apart from some algorithms that enforce, with a warning, a given
+aggregation rule).
+
+Local differential privacy can also be used, but the accounting might not
+be correct for all algorithms (namely, Fed-FairBatch/FedFB), hence we would
+advise careful use after informed analysis of the selected algorithm.
+
+## Details and caveats
+
+### Overall setting
+
+Currently, the DecLearn fairness API is designed so that the fairness being
+measured and optimized is computed over the union of all training datasets
+held by clients.
+
+The API is designed to be compatible with any number of sensitive groups,
+with regimes where individual clients do not necessarily hold samples for
+each and every group, and with all group fairness definitions that can be
+expressed in a form that was introduced in the FairGrad paper (Maheshwari
+& Perrot, 2023). However, some additional restrictions may be enforced by
+concrete definitions and/or algorithms.
+
+### Available group-fairness definitions
+
+As of version 2.6.0, DecLearn provides the following group-fairness
+definitions:
+
+- **Accuracy Parity**, achieved when the model's accuracy is independent
+  from the sensitive attribute(s) - but not necessarily balanced across
+  target labels.
+- **Demographic Parity**, achieved when the probability to predict a given
+  label is independent from the sensitive attribute(s) - regardless of
+  whether that label is accurate or not.
+  _In DecLearn, it is restricted to
+  binary classification tasks._
+- **Equalized Odds**, achieved when the probability to predict the correct
+  label is independent from the sensitive attribute(s).
+- **Equality of Opportunity**, which is similar to Equalized Odds but is
+  restricted to an arbitrary subset of target labels.
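+To make these definitions more concrete, here is a minimal, self-contained
+sketch (plain NumPy, not part of the DecLearn API; all names are
+illustrative) of computing group-wise accuracies and a worst-case
+accuracy-parity gap from sample-wise predictions and sensitive attribute
+values:
+
+```python
+import numpy as np
+
+# Toy inputs: true labels, predicted labels, sensitive attribute values.
+y_true = np.array([0, 1, 1, 0, 1, 0, 1, 1])
+y_pred = np.array([0, 1, 0, 0, 1, 1, 1, 0])
+s_attr = np.array([0, 0, 0, 0, 1, 1, 1, 1])
+
+# Compute the model's accuracy within each sensitive group.
+group_acc = {
+    group: float(np.mean(y_pred[s_attr == group] == y_true[s_attr == group]))
+    for group in np.unique(s_attr)
+}
+overall_acc = float(np.mean(y_pred == y_true))
+
+# Accuracy parity holds when all group-wise accuracies match the overall
+# one; the worst-case absolute gap summarizes how far the model is from it.
+max_gap = max(abs(acc - overall_acc) for acc in group_acc.values())
+print(group_acc, overall_acc, max_gap)
+```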
+
+### Available algorithms
+
+As of version 2.6.0, DecLearn provides the following algorithms, each of
+which may impose restrictions as to the supported group-fairness definition
+and/or number of sensitive groups:
+
+- [**Fed-FairGrad**][declearn.fairness.fairgrad], an adaptation of FairGrad
+  (Maheshwari & Perrot, 2023) to the federated learning setting.<br/>
+  This algorithm reweighs the training loss based on the current fairness
+  levels of the model, so that advantaged groups contribute less than
+  disadvantaged ones, and may even contribute negatively (effectively trading
+  accuracy off in favor of fairness).
+- [**Fed-FairBatch**][declearn.fairness.fairbatch], a custom adaptation of
+  FairBatch (Roh et al., 2020) to the federated learning setting.<br/>
+  This algorithm alters the way training data batches are drawn, enforcing
+  sampling probabilities that are based on the current fairness levels of the
+  model, so that advantaged groups are under-represented and disadvantaged
+  groups are over-represented relative to raw group-wise sample counts.
+- **FedFB** (Zeng et al., 2022), an arXiv-published alternative adaptation
+  of FairBatch that is similar to Fed-FairBatch but introduces further
+  formula changes with respect to the original FairBatch.
+- [**FairFed**][declearn.fairness.fairfed] (Ezzeldin et al., 2021), an
+  algorithm designed for federated learning, with the caveat that its
+  authors designed it to be combined with local fairness-enforcing
+  algorithms, something that is not yet effortlessly available in
+  DecLearn.<br/>
+  This algorithm modifies the aggregation rule based on the discrepancy
+  between client-wise fairness levels, so that clients for which the model
+  is more unfair weigh more in the overall model updates than clients for
+  which the model is fairer.
+
+### Shared algorithm structure
+
+The current API sets up a shared structure for all implemented algorithms,
+which is divided into two phases. Each of these comprises a basic part
+that is shared across algorithms, and an algorithm-specific part that has
+varying computation and communication costs depending on the algorithm.
+
+- The **setup** phase, which occurs as an additional step of the overall
+  federated learning initialization phase. During that phase:
+    - Each client sends the server the list of sensitive group definitions
+      for which they hold samples.
+    - The server sends back the ordered list of sensitive group definitions
+      across the union of client datasets.
+    - Each client communicates the (optionally encrypted) sample counts
+      associated with these definitions.
+    - The server (secure-)aggregates these sample counts to initialize the
+      fairness function on the union of client datasets.
+    - Any algorithm-specific additional steps occur. For this, the controllers
+      have access to the network communication endpoints and optional secure
+      aggregation controllers. On the server side, the `Aggregator` may be
+      changed (with a warning). On the client side, side effects may occur
+      on the `TrainingManager` (hence altering future training rounds).
+
+- The **fairness round**, which is designed to occur prior to training rounds
+  (and is implemented as such as part of `FederatedServer`). During that
+  phase:
+    - Each client evaluates the fairness of the current (global) model on
+      their training set. By default, group-wise accuracy scores are computed
+      and sent to the server, while the local fairness scores are computed
+      and kept locally. Specific algorithms may change the metrics computed
+      and shared. Shared metrics are encrypted when secure aggregation is
+      set up.
+    - The server (secure-)aggregates client-emitted metrics. It will usually
+      compute the global fairness scores of the model based on group-wise
+      accuracy scores.
+    - Any algorithm-specific additional steps occur. For this, the controllers
+      have access to the network communication endpoints and optional secure
+      aggregation controllers. On the client side, side effects may occur on
+      the `TrainingManager` (hence altering future training rounds).
+    - On both sides, computed metrics are returned, so that they can be
+      checkpointed as part of the overall federated learning process.
+
+### A word to fellow researchers in machine learning fairness
+
+If you are an end-user with a need for fairness, we hope that the current
+state of things offers a suitable set of ready-for-use state-of-the-art
+algorithms. If you are a researcher however, it is likely that you will
+run into limitations of the current API at some point, whether because
+you need to change some assumptions or use an algorithm that has less
+in common with the structure of the currently-available ones than we
+anticipated. Maybe you even want to work on something different from the
+group fairness family of approaches.
+
+As with the rest of DecLearn features, if you run into issues, have trouble
+figuring out how to implement something or have suggestions as to how and/or
+why the API should evolve, you are welcome to contact us by e-mail or by
+opening an issue on our GitLab or GitHub repository, so that we can figure
+out solutions and plan evolutions in future DecLearn versions.
+
+On our side, we plan to keep collaborating with colleagues to design and
+evaluate fairness-aware federated learning methods. As in the past, this
+will likely happen first in parallel versions of DecLearn prior to being
+robustified and integrated into the stable branch - but at any rate, we are
+aware, and want to make clear, that the current API and its limitations are
+but the product of a first iteration, aimed at providing solid
+implementations of existing algorithms and laying the groundwork for more
+research-driven iterations.
+
+### Caveats and future work
+
+The current API was abstracted from an ensemble of concrete algorithms.
+As such, it may be limiting for end-users who would like to implement
+alternative algorithms to tackle fairness-aware federated learning.
+
+The current API assumes that fairness is to be estimated over the training
+dataset of clients, and without client sampling. More generally, available
+algorithms are mostly designed assuming that all clients participate in
+each fairness evaluation and model training round. The current API offers
+some space to use and test algorithms with client sampling, and with limited
+computational effort put into the fairness evaluation, but it is not as
+modular as one might want, in part because we believe further theoretical
+analysis of the implications at hand is required to inform implementation
+choices.
+
+The current API requires fairness definitions to be expressed using the form
+introduced in the FairGrad paper. While we do believe this form to be a
+clever choice, and use it as leverage to keep computations consistent and
+to limit the room left for formula implementation errors, we are open to
+revising it if definitions that do not fit into that form turn out to be
+desired by DecLearn users.
+
+## How to set up fairness-aware federated learning
+
+### Description
+
+Making a federated learning process fairness-aware only requires a couple
+of changes to the code executed by the server and clients.
+
+First, clients need to interface their training data using a subclass of
+`declearn.fairness.api.FairnessDataset`, which extends the base
+`declearn.dataset.Dataset` API to define sensitive attributes and access
+their definitions, sample counts and group-wise sub-datasets. As with the
+base `Dataset`, DecLearn provides a `FairnessInMemoryDataset` that is
+suited for tabular data that fits in memory. For other uses, end-users
+need to write their own subclass, implementing the various abstract methods
+regarding metadata fetching and data batching.
+
+Second, the server needs to select and configure a fairness algorithm, and
+make it part of the federated optimization configuration (wrapped or parsed
+using `declearn.main.config.FLOptimConfig`). Some hyper-parameters
+controlling the computational effort put into evaluating the model's
+fairness throughout training may also be specified as part of the run
+configuration (`declearn.main.config.FLRunConfig`), which is otherwise
+auto-filled to use the same batch size as in evaluation rounds, and the
+entire training dataset, so as to compute fairness estimates that are as
+robust as possible.
+
+### Hands-on example
+
+For this example, we are going to use the
+[UCI Heart Disease](https://archive.ics.uci.edu/dataset/45/heart+disease)
+dataset, for which we already provide a [base example](https://gitlab.inria.fr/magnet/declearn/declearn2/-/tree/develop/examples/heart-uci/)
+implemented via Python scripts.
+
+This is a binary classification task, for which we are going to define a
+single binary sensitive attribute: patients' biological sex.
+
+**1. Interface training datasets to define sensitive groups**
+
+On the client side, we simply need to wrap the training dataset as a
+`FairnessInMemoryDataset` rather than a base `InMemoryDataset`. This
+results in a simple edit in steps (1-2) of the initial
+[client script](https://gitlab.inria.fr/magnet/declearn/declearn2/-/tree/develop/examples/heart-uci/client.py):
+
+```python
+from declearn.fairness.core import FairnessInMemoryDataset
+
+train = FairnessInMemoryDataset(
+    data=data.iloc[:n_tr],  # unchanged
+    target=target,  # unchanged
+    s_attr=["sex"],  # define sensitive attribute(s)
+    sensitive_target=True,  # define fairness relative to Y x S (default)
+)
+```
+
+instead of the initial
+
+```python
+train = InMemoryDataset(
+    data=data.iloc[:n_tr],
+    target=target,
+)
+```
+
+Note that:
+
+- The validation dataset does not need to be wrapped as a `FairnessDataset`,
+  as fairness is only evaluated on the training dataset during the process.
+- By default, sensitive attribute columns are _not_ excluded from feature
+  columns. The `f_cols` parameter may be used to exclude them. Alternatively,
+  one could pop the sensitive attribute column(s) from `data` and pass the
+  resulting DataFrame or numpy array directly as `s_attr`, as sketched after
+  these notes.
+- It is important that all clients declare the sensitive attributes in the
+  same order. This is also true of feature columns in general in DecLearn.
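+As an illustration of the pop-and-pass alternative mentioned above, one
+might write the following (a hedged sketch, building on the same `data`,
+`n_tr` and `target` variables as the base example; whether to exclude the
+sensitive column from the features is a modeling choice of your own):
+
+```python
+# Split the sensitive column out of the feature columns, then pass it
+# explicitly as a DataFrame through the `s_attr` parameter.
+s_data = data.iloc[:n_tr][["sex"]]
+train = FairnessInMemoryDataset(
+    data=data.iloc[:n_tr].drop(columns=["sex"]),
+    target=target,
+    s_attr=s_data,
+    sensitive_target=True,
+)
+```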
+
+**2. Configure a fairness algorithm on the server side**
+
+On the server side, a `declearn.fairness.api.FairnessControllerServer`
+subclass must be selected, instantiated and plugged into the `FLOptimConfig`
+object (or the dict passed as the `optim` instantiation parameter to
+`FederatedServer`).
+
+For instance, to merely measure the model's fairness without altering the
+training process (typically to assess the fairness of a baseline approach),
+one may edit step 2 of the initial
+[server script](https://gitlab.inria.fr/magnet/declearn/declearn2/-/tree/develop/examples/heart-uci/server.py)
+as follows:
+
+```python
+from declearn.fairness.monitor import FairnessMonitorServer
+
+fairness = FairnessMonitorServer(
+    f_type="demographic_parity",  # choose any fairness definition
+)
+
+optim = FLOptimConfig.from_params(
+    aggregator=aggregator,  # unchanged
+    client_opt=client_opt,  # unchanged
+    server_opt=server_opt,  # unchanged
+    fairness=fairness,  # set up the fairness monitoring algorithm
+)
+```
+
+To use a fairness-enforcing algorithm, simply instantiate another type of
+controller. For instance, to use Fed-FairGrad:
+
+```python
+from declearn.fairness.fairgrad import FairgradControllerServer
+
+fairness = FairgradControllerServer(
+    f_type="demographic_parity",  # choose any fairness definition
+    eta=0.1,  # adjust this based on the SGD learning rate and empirical tuning
+    eps=0.0,  # change this to configure epsilon-fairness
+)
+```
+
+Equivalently, the choice of fairness controller class and parameters may be
+specified using a configuration dict, which may in turn be parsed from a
+TOML file (see the TOML sketch after the notes below):
+
+```python
+fairness = {
+    # mandatory parameters:
+    "algorithm": "fairgrad",  # name of the algorithm
+    "f_type": "demographic_parity",  # name of the group-fairness definition
+    # optional, algorithm-dependent hyper-parameters:
+    "eta": 0.1,
+    "eps": 0.0,
+}
+```
+
+Notes:
+
+- `declearn.fairness.core.list_fairness_functions` may be used to review all
+  available fairness definitions and their registration name.
+- All `FairnessControllerServer` subclasses expose an optional `f_args`
+  parameter that can be used to parameterize the fairness definition, _e.g._
+  to define target labels on which to focus when using
+  "equality_of_opportunity".
+- `FLRunConfig` exposes an optional `fairness` field that can be used to
+  reduce the computational effort put into evaluating fairness. Otherwise,
+  it is automatically filled with a mix of default values and values parsed
+  from the training and/or evaluation round configuration.
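+For instance, assuming the overall federated optimization configuration is
+parsed from a TOML file via `FLOptimConfig`, the configuration dict above
+might translate into a section along the following lines (a sketch; the
+section and key names should be checked against your own configuration
+layout):
+
+```toml
+[fairness]
+algorithm = "fairgrad"
+f_type = "demographic_parity"
+eta = 0.1
+eps = 0.0
+```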
+
+## How to implement custom fairness definitions and algorithms
+
+As with most DecLearn components, group fairness definitions and algorithms
+are designed to be extendable by third-party developers and end-users. On
+the one hand, new definitions of group fairness may be implemented by
+subclassing `declearn.fairness.api.FairnessFunction`. On the other hand,
+new algorithms for fairness-aware federated learning may be implemented by
+subclassing both `declearn.fairness.api.FairnessControllerServer` and
+`declearn.fairness.api.FairnessControllerClient`.
+
+### Implement a new group fairness definition
+
+Group fairness definitions are implemented by subclassing the API-defining
+abstract base class [declearn.fairness.api.FairnessFunction][]. Extensive
+details can be found in the API docs of that class. Overall, the process
+is the following:
+
+- Declare a `FairnessFunction` subclass.
+- Define its `f_type` string class attribute, which must be unique across
+  subclasses, and will be used to type-register this class and make it
+  available in controllers.
+- Define its `compute_fairness_constants` method, which must return
+  $C_k^{k'}$ constants defining the group-wise fairness level computations
+  based on group-wise sample counts.
+
+The latter method, and overall API, echo the generic formulation for fairness
+functions introduced in the FairGrad paper (Maheshwari & Perrot, 2023). If
+this is limiting for your application, please let us know. If you are using
+definitions that are specific to your custom algorithm, you may "simply"
+tweak around the API when implementing controllers (see the following
+section).
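+Schematically, and as far as that formulation goes, the fairness level of
+each sensitive group $k$ is a linear combination of group-wise accuracies,
+with constants that depend only on group-wise sample counts (writing $G$
+for the random variable denoting a sample's sensitive group, and $Y$,
+$\hat{Y}$ as above):
+
+$$
+F_k(h) = \sum_{k'} C_k^{k'} \, \mathbb{P}(\hat{Y} = Y \mid G = k')
+$$
+
+so that implementing a new definition essentially boils down to deriving
+the $C_k^{k'}$ constants from the sample counts.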
+### Implement a new fairness-enforcing algorithm
+
+Fairness-enforcing algorithms are implemented by subclassing the API-defining
+abstract base classes [declearn.fairness.api.FairnessControllerServer][]
+and [declearn.fairness.api.FairnessControllerClient][]. Extensive details
+can be found in the API docs of these classes. Overall, the process is
+the following:
+
+- Declare the paired subclasses.
+- Define their `algorithm` string class attribute, which must be the same
+  for paired classes, and distinct from that of other subclasses. It is
+  used to register the types and produce serializable, generic instantiation
+  instructions.
+- Define their respective `finalize_fairness_setup` methods, to take any
+  algorithm-specific steps once sensitive group definitions and sample
+  counts have been exchanged.
+- Define their respective `finalize_fairness_round` methods, to take any
+  algorithm-specific steps once fairness-related metrics have been computed
+  by clients and (secure-)aggregated by the server.
+- Optionally, overload or override the client's `setup_fairness_metrics`
+  method, which defines the fairness-related metrics being computed and
+  shared as part of fairness rounds.
+
+When overloading the `__init__` method of subclasses, you may add additional
+restrictions as to the supported fairness definitions and/or number and
+nature of sensitive groups.
+
+## References
+
+- Maheshwari & Perrot (2023).
+  FairGrad: Fairness Aware Gradient Descent.
+  <https://openreview.net/forum?id=0f8tU3QwWD>
+- Roh et al. (2020).
+  FairBatch: Batch Selection for Model Fairness.
+  <https://arxiv.org/abs/2012.01696>
+- Zeng et al. (2022).
+  Improving Fairness via Federated Learning.
+  <https://arxiv.org/abs/2110.15545>
+- Ezzeldin et al. (2021).
+  FairFed: Enabling Group Fairness in Federated Learning.
+  <https://arxiv.org/abs/2110.00857>
diff --git a/docs/user-guide/index.md b/docs/user-guide/index.md
index 362050cb74125820c35e2c0302ee6458b9aec886..86b4099233724e8dbf0bf4b71ef32e720618c5da 100644
--- a/docs/user-guide/index.md
+++ b/docs/user-guide/index.md
@@ -18,3 +18,5 @@ This guide is structured this way:
     Description of the local-DP features of declearn.
 - [Secure Aggregation capabilities](./secagg.md):<br/>
     Description of the SecAgg features of declearn.
+- [Fairness capabilities](./fairness.md):<br/>
+    Description of the fairness-aware federated learning features of declearn.
diff --git a/pyproject.toml b/pyproject.toml
index 846c6d4192bb4de1233d44eea747718723f511b2..13163b577144f87ee587e488a82e66d958fef5c3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,7 +8,7 @@ requires = [
 
 [project]
 name = "declearn"
-version = "2.6.0.dev1"
+version = "2.6.0.dev2"
 description = "Declearn - a python package for private decentralized learning."
 readme = "README.md"
 requires-python = ">=3.8"
diff --git a/scripts/gen_docs.py b/scripts/gen_docs.py
index 645ca0ff357916c297c8304b2116f31ffedfeb07..742e64ec999d10a0da590252db8812f2c92ad298 100644
--- a/scripts/gen_docs.py
+++ b/scripts/gen_docs.py
@@ -165,7 +165,11 @@ def _generate_private_submodules_content_doc(
 ) -> Dict[str, str]:
     """Create files for public contents from a module's private submodules."""
     pub_obj = {}
-    for key, obj in module.members.items():
+    if module.exports is None:
+        members = module.members
+    else:
+        members = {str(k): module.members[str(k)] for k in module.exports}
+    for key, obj in members.items():
         if obj.is_module or obj.module.name in pub_mod or key.startswith("_"):
             continue
         if not (obj.docstring or obj.is_class or obj.is_function):
diff --git a/test/fairness/algorithms/test_fairbatch_dataset.py b/test/fairness/algorithms/test_fairbatch_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..a749d27000904065165919fcd62422c60bbd0ce2
--- /dev/null
+++ b/test/fairness/algorithms/test_fairbatch_dataset.py
@@ -0,0 +1,204 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for FairBatch dataset wrapper."""
+
+from unittest import mock
+
+import numpy as np
+import pandas as pd
+import pytest
+
+from declearn.fairness.api import FairnessDataset
+from declearn.fairness.core import FairnessInMemoryDataset
+from declearn.fairness.fairbatch import FairbatchDataset
+
+
+COUNTS = {(0, 0): 30, (0, 1): 15, (1, 0): 35, (1, 1): 20}
+
+
+class TestFairbatchDataset:
+    """Unit tests for 'declearn.fairness.fairbatch.FairbatchDataset'."""
+
+    def setup_mock_base_dataset(self) -> mock.Mock:
+        """Return a mock FairnessDataset with arbitrary groupwise counts."""
+        base = mock.create_autospec(FairnessDataset, instance=True)
+        base.get_sensitive_group_definitions.return_value = list(COUNTS)
+        base.get_sensitive_group_counts.return_value = COUNTS
+        return base
+
+    def test_wrapped_methods(self) -> None:
+        """Test that API-defined methods are properly wrapped."""
+        # Instantiate a FairbatchDataset wrapping a mock FairnessDataset.
+        base = mock.create_autospec(FairnessDataset, instance=True)
+        data = FairbatchDataset(base)
+        # Test API-defined getters.
+        assert data.get_data_specs() is base.get_data_specs.return_value
+        assert data.get_sensitive_group_definitions() is (
+            base.get_sensitive_group_definitions.return_value
+        )
+        assert data.get_sensitive_group_counts() is (
+            base.get_sensitive_group_counts.return_value.copy()
+        )
+        group = mock.create_autospec(tuple, instance=True)
+        assert data.get_sensitive_group_subset(group) is (
+            base.get_sensitive_group_subset.return_value
+        )
+        base.get_sensitive_group_subset.assert_called_once_with(group)
+        # Test API-defined setter.
+        weights = mock.create_autospec(dict, instance=True)
+        adjust_by_counts = mock.create_autospec(bool, instance=True)
+        data.set_sensitive_group_weights(weights, adjust_by_counts)
+        base.set_sensitive_group_weights.assert_called_once_with(
+            weights, adjust_by_counts
+        )
+
+    def test_get_sampling_probabilities_initial(self) -> None:
+        """Test 'get_sampling_probabilities' upon initialization."""
+        # Instantiate a FairbatchDataset wrapping a mock FairnessDataset.
+        base = self.setup_mock_base_dataset()
+        data = FairbatchDataset(base)
+        # Access initial sampling probabilities and verify their value.
+        probas = data.get_sampling_probabilities()
+        assert isinstance(probas, dict)
+        assert probas.keys() == COUNTS.keys()
+        assert all(isinstance(val, float) for val in probas.values())
+        expected = {key: 1 / len(COUNTS) for key in COUNTS}
+        assert probas == expected
+
+    def test_set_sampling_probabilities_simple(self) -> None:
+        """Test 'set_sampling_probabilities' with matching groups."""
+        data = FairbatchDataset(base=self.setup_mock_base_dataset())
+        # Assign arbitrary probabilities that match local groups.
+        probas = {group: idx / 10 for idx, group in enumerate(COUNTS, 1)}
+        data.set_sampling_probabilities(group_probas=probas)
+        # Test that inputs were assigned.
+        assert data.get_sampling_probabilities() == probas
+
+    def test_set_sampling_probabilities_unnormalized(self) -> None:
+        """Test 'set_sampling_probabilities' with un-normalized values."""
+        data = FairbatchDataset(base=self.setup_mock_base_dataset())
+        # Assign arbitrary probabilities that do not sum to 1.
+        probas = {group: float(idx) for idx, group in enumerate(COUNTS, 1)}
+        expect = {key: val / 10 for key, val in probas.items()}
+        data.set_sampling_probabilities(group_probas=probas)
+        # Test that inputs were corrected, then assigned.
+        assert data.get_sampling_probabilities() == expect
+
+    def test_set_sampling_probabilities_superset(self) -> None:
+        """Test 'set_sampling_probabilities' with unrepresented groups."""
+        data = FairbatchDataset(base=self.setup_mock_base_dataset())
+        # Assign arbitrary probabilities that cover a superset of local groups.
+        probas = {group: idx / 10 for idx, group in enumerate(COUNTS, 1)}
+        expect = probas.copy()
+        probas[(2, 0)] = probas[(2, 1)] = 0.2
+        data.set_sampling_probabilities(group_probas=probas)
+        # Test that inputs were corrected, then assigned.
+        assert data.get_sampling_probabilities() == expect
+
+    def test_set_sampling_probabilities_invalid_values(self) -> None:
+        """Test 'set_sampling_probabilities' with negative values."""
+        probas = {group: float(idx) for idx, group in enumerate(COUNTS, -2)}
+        data = FairbatchDataset(base=self.setup_mock_base_dataset())
+        with pytest.raises(ValueError):
+            data.set_sampling_probabilities(group_probas=probas)
+
+    def test_set_sampling_probabilities_invalid_groups(self) -> None:
+        """Test 'set_sampling_probabilities' with missing groups."""
+        probas = {
+            group: idx / 6 for idx, group in enumerate(list(COUNTS)[1:], 1)
+        }
+        data = FairbatchDataset(base=self.setup_mock_base_dataset())
+        with pytest.raises(ValueError):
+            data.set_sampling_probabilities(group_probas=probas)
+
+    def setup_simple_dataset(self) -> FairbatchDataset:
+        """Set up a simple FairbatchDataset with arbitrary data.
+
+        Samples have a single feature, reflecting the sensitive
+        group to which they belong.
+        """
+        samples = [
+            sample
+            for idx, (group, n_samples) in enumerate(COUNTS.items())
+            for sample in [(group[0], group[1], idx)] * n_samples
+        ]
+        base = FairnessInMemoryDataset(
+            data=pd.DataFrame(samples, columns=["target", "s_attr", "value"]),
+            f_cols=["value"],
+            target="target",
+            s_attr=["s_attr"],
+            sensitive_target=True,
+        )
+        # Wrap it up as a FairbatchDataset; sampling probabilities are
+        # left to be assigned by the caller.
+        return FairbatchDataset(base)
+
+    def test_generate_batches_simple(self) -> None:
+        """Test that 'generate_batches' has expected behavior."""
+        # Set up a simple dataset and assign arbitrary sampling probabilities.
+        data = self.setup_simple_dataset()
+        data.set_sampling_probabilities(
+            {group: idx / 10 for idx, group in enumerate(COUNTS, start=1)}
+        )
+        # Generate batches with a low batch size.
+        # Verify that outputs match expectations.
+        batches = list(data.generate_batches(batch_size=10))
+        assert len(batches) == 10
+        expect_x = np.array(
+            [[idx] for idx in range(len(COUNTS)) for _ in range(idx + 1)]
+        )
+        expect_y = np.array(
+            [lab for idx, (lab, _) in enumerate(COUNTS, 1) for _ in range(idx)]
+        )
+        for batch in batches:
+            assert isinstance(batch, tuple) and (len(batch) == 3)
+            assert isinstance(batch[0], np.ndarray)
+            assert (batch[0] == expect_x).all()
+            assert isinstance(batch[1], np.ndarray)
+            assert (batch[1] == expect_y).all()
+            assert batch[2] is None
+
+    def test_generate_batches_large(self) -> None:
+        """Test that 'generate_batches' has expected behavior."""
+        # Set up a simple dataset and assign arbitrary sampling probabilities.
+        data = self.setup_simple_dataset()
+        data.set_sampling_probabilities(
+            {group: idx / 10 for idx, group in enumerate(COUNTS, start=1)}
+        )
+        # Generate batches with a high batch size.
+        # Verify that outputs match expectations.
+ batches = list(data.generate_batches(batch_size=100)) + assert len(batches) == 1 + assert isinstance(batches[0][0], np.ndarray) + assert isinstance(batches[0][1], np.ndarray) + assert batches[0][2] is None + expect_x = np.array( + [ + [idx] + for idx in range(len(COUNTS)) + for _ in range(10 * (idx + 1)) + ] + ) + expect_y = np.array( + [ + lab + for idx, (lab, _) in enumerate(COUNTS, 1) + for _ in range(idx * 10) + ] + ) + assert (batches[0][0] == expect_x).all() + assert (batches[0][1] == expect_y).all() diff --git a/test/fairness/algorithms/test_fairbatch_sampling.py b/test/fairness/algorithms/test_fairbatch_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..5584f05092ea1923c7533cbce1fb3e8bdcdf38dc --- /dev/null +++ b/test/fairness/algorithms/test_fairbatch_sampling.py @@ -0,0 +1,185 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for FairBatch sampling probability controllers.""" + + +import pytest + +from declearn.fairness.api import FairnessFunction +from declearn.fairness.fairbatch import ( + FairbatchSamplingController, + setup_fairbatch_controller, + setup_fedfb_controller, +) + + +ALPHA = 0.05 +COUNTS = {(0, 0): 30, (0, 1): 15, (1, 0): 35, (1, 1): 20} +F_TYPES = [ + "demographic_parity", + "equality_of_opportunity", + "equalized_odds", +] + + +@pytest.fixture(name="controller") +def controller_fixture( + f_type: str, + fedfb: bool, +) -> FairbatchSamplingController: + """Fixture providing with a given 'FairbatchSamplingController'.""" + if fedfb: + return setup_fedfb_controller(f_type, counts=COUNTS, alpha=ALPHA) + return setup_fairbatch_controller(f_type, counts=COUNTS, alpha=ALPHA) + + +@pytest.mark.parametrize("fedfb", [False, True], ids=["FedFairBatch", "FedFB"]) +@pytest.mark.parametrize("f_type", F_TYPES) +class TestFairbatchSamplingController: + """Shared unit tests for all 'FairbatchSamplingController' subclasses.""" + + def test_init( + self, + controller: FairbatchSamplingController, + ) -> None: + """Test that the instantiated controller is coherent.""" + # Assert that groups and counts attributes have expected specs. + assert isinstance(controller.groups, dict) + assert isinstance(controller.counts, dict) + assert all(isinstance(key, str) for key in controller.groups) + assert set(controller.groups.values()) == set(controller.counts) + assert controller.counts == COUNTS + # Verify that other hyper-parameters were properly passed. 
+        assert controller.alpha == ALPHA
+        assert isinstance(controller.f_func, FairnessFunction)
+        assert controller.f_func.f_type == controller.f_type
+
+    def test_compute_initial_states(
+        self,
+        controller: FairbatchSamplingController,
+    ) -> None:
+        """Test that 'compute_initial_states' has proper output types."""
+        states = controller.compute_initial_states()
+        assert all(
+            isinstance(key, str) and isinstance(val, (int, float))
+            for key, val in states.items()
+        )
+        assert controller.states == states  # initial states
+
+    def test_get_sampling_probas(
+        self,
+        controller: FairbatchSamplingController,
+    ) -> None:
+        """Test that 'get_sampling_probas' outputs coherent values."""
+        probas = controller.get_sampling_probas()
+        # Assert that probabilities are a dict with expected keys.
+        assert isinstance(probas, dict)
+        assert set(probas.keys()) == set(controller.counts)
+        # Assert that values are floats and sum to one (up to a small epsilon).
+        assert all(isinstance(val, float) for val in probas.values())
+        assert abs(1 - sum(probas.values())) < 0.001
+
+    def test_update_from_losses(
+        self,
+        controller: FairbatchSamplingController,
+    ) -> None:
+        """Test that 'update_from_losses' alters states and output probas.
+
+        This test does not verify that the maths match the original papers.
+        """
+        # Record initial states and sampling probabilities.
+        initial_states = controller.states.copy()
+        initial_probas = controller.get_sampling_probas()
+        # Perform an update with arbitrary loss values.
+        losses = {group: float(idx) for idx, group in enumerate(COUNTS)}
+        controller.update_from_losses(losses)
+        # Verify that states and probas have changed.
+        assert controller.states != initial_states
+        probas = controller.get_sampling_probas()
+        assert probas.keys() == initial_probas.keys()
+        assert probas != initial_probas
+        # Verify that output probabilities sum to one (up to a small epsilon).
+        assert abs(1 - sum(probas.values())) < 0.001
+
+    def test_update_from_federated_losses(
+        self,
+        controller: FairbatchSamplingController,
+    ) -> None:
+        """Test that 'update_from_federated_losses' has expected outputs."""
+        # Update from arbitrary counts-scaled losses and gather states.
+        losses = {group: float(idx) for idx, group in enumerate(COUNTS)}
+        controller.update_from_federated_losses(
+            {key: val * controller.counts[key] for key, val in losses.items()}
+        )
+        states = controller.states.copy()
+        # Reset states and use unscaled values via basic update method.
+        controller.states = controller.compute_initial_states()
+        controller.update_from_losses(losses)
+        # Assert that resulting states are the same.
+ assert controller.states == states + + +@pytest.mark.parametrize("fedfb", [False, True], ids=["FedFairBatch", "FedFB"]) +@pytest.mark.parametrize("f_type", F_TYPES) +def test_setup_controller_parameters( + f_type: str, + fedfb: bool, +) -> None: + """Test that controller setup properly passes input parameters.""" + function = setup_fedfb_controller if fedfb else setup_fairbatch_controller + controller = function(f_type=f_type, counts=COUNTS.copy(), target=0) + assert controller.f_type == f_type + if f_type == "equality_of_opportunity": + assert controller.f_func.get_specs()["target"] == [0] + + +@pytest.mark.parametrize("fedfb", [False, True], ids=["FedFairBatch", "FedFB"]) +def test_setup_controller_invalid_ftype( + fedfb: bool, +) -> None: + """Test that controller setup raises a KeyError on invalid 'f_type'.""" + function = setup_fedfb_controller if fedfb else setup_fairbatch_controller + with pytest.raises(KeyError): + function(f_type="invalid_f_type", counts=COUNTS.copy()) + + +@pytest.mark.parametrize("fedfb", [False, True], ids=["FedFairBatch", "FedFB"]) +def test_setup_controller_invalid_groups( + fedfb: bool, +) -> None: + """Test that controller setup raises a ValueError on invalid groups.""" + function = setup_fedfb_controller if fedfb else setup_fairbatch_controller + # Case when there are more than 4 groups. + counts = COUNTS.copy() + counts[(2, 0)] = counts[(2, 1)] = 5 # add a third target label value + with pytest.raises(ValueError): + function(f_type="demographic_parity", counts=counts) + # Case when there are 4 ill-defined groups. + counts = {(0, 0): 10, (1, 0): 10, (2, 0): 10, (2, 1): 10} + with pytest.raises(ValueError): + function(f_type="demographic_parity", counts=counts) + + +@pytest.mark.parametrize("fedfb", [False, True], ids=["FedFairBatch", "FedFB"]) +def test_setup_controller_invalid_target( + fedfb: bool, +) -> None: + """Test that controller setup raises a ValueError on invalid target.""" + function = setup_fedfb_controller if fedfb else setup_fairbatch_controller + with pytest.raises(ValueError): + function(f_type="demographic_parity", counts=COUNTS.copy(), target=2) diff --git a/test/fairness/algorithms/test_fairfed_aggregator.py b/test/fairness/algorithms/test_fairfed_aggregator.py new file mode 100644 index 0000000000000000000000000000000000000000..2238ad6e07283bf6b438d1437e909340702b0323 --- /dev/null +++ b/test/fairness/algorithms/test_fairfed_aggregator.py @@ -0,0 +1,87 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""Unit tests for FairFed-specific Aggregator subclass."""
+
+from unittest import mock
+
+
+from declearn.fairness.fairfed import FairfedAggregator
+from declearn.model.api import Vector
+
+
+class TestFairfedAggregator:
+    """Unit tests for 'declearn.fairness.fairfed.FairfedAggregator'."""
+
+    def test_init_beta(self) -> None:
+        """Test that the 'beta' parameter is properly assigned."""
+        beta = mock.create_autospec(float, instance=True)
+        aggregator = FairfedAggregator(beta=beta)
+        assert aggregator.beta is beta
+
+    def test_prepare_for_sharing_initial(self) -> None:
+        """Test that 'prepare_for_sharing' has expected outputs at first."""
+        # Set up an uninitialized aggregator and prepare mock updates.
+        aggregator = FairfedAggregator(beta=1.0)
+        updates = mock.create_autospec(Vector, instance=True)
+        model_updates = aggregator.prepare_for_sharing(updates, n_steps=10)
+        # Verify that outputs match expectations.
+        updates.__mul__.assert_called_once_with(1.0)
+        assert model_updates.updates is updates.__mul__.return_value
+        assert model_updates.weights == 1.0
+
+    def test_initialize_local_weight(self) -> None:
+        """Test that 'initialize_local_weight' works properly."""
+        # Set up an aggregator, initialize it and prepare mock updates.
+        n_samples = 100
+        aggregator = FairfedAggregator(beta=1.0)
+        aggregator.initialize_local_weight(n_samples=n_samples)
+        updates = mock.create_autospec(Vector, instance=True)
+        model_updates = aggregator.prepare_for_sharing(updates, n_steps=10)
+        # Verify that outputs match expectations.
+        updates.__mul__.assert_called_once_with(n_samples)
+        assert model_updates.updates is updates.__mul__.return_value
+        assert model_updates.weights == n_samples
+
+    def test_update_local_weight(self) -> None:
+        """Test that 'update_local_weight' works properly."""
+        # Set up a FairFed aggregator and initialize it.
+        n_samples = 100
+        aggregator = FairfedAggregator(beta=0.1)
+        aggregator.initialize_local_weight(n_samples=n_samples)
+        # Perform a local weight update with arbitrary values.
+        aggregator.update_local_weight(delta_loc=2.0, delta_avg=5.0)
+        # Verify that updates have expected weight.
+        updates = mock.create_autospec(Vector, instance=True)
+        expectw = n_samples - 0.1 * (2.0 - 5.0)  # w_0 - beta * diff_delta
+        model_updates = aggregator.prepare_for_sharing(updates, n_steps=10)
+        updates.__mul__.assert_called_once_with(expectw)
+        assert model_updates.updates is updates.__mul__.return_value
+        assert model_updates.weights == expectw
+
+    def test_finalize_updates(self) -> None:
+        """Test that 'finalize_updates' works as expected."""
+        # Set up a FairFed aggregator and initialize it.
+        n_samples = 100
+        aggregator = FairfedAggregator(beta=0.1)
+        aggregator.initialize_local_weight(n_samples=n_samples)
+        # Prepare, then finalize updates.
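+        # (Finalization divides the weighted updates by the summed weights;
+        # with a single client this yields 'updates * n_samples / n_samples'.)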
+        updates = mock.create_autospec(Vector, instance=True)
+        output = aggregator.finalize_updates(
+            aggregator.prepare_for_sharing(updates, n_steps=mock.MagicMock())
+        )
+        assert output == (updates * n_samples / n_samples)
diff --git a/test/fairness/algorithms/test_fairfed_computer.py b/test/fairness/algorithms/test_fairfed_computer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b38e883812061110c6fd12651b0d6e78be687d87
--- /dev/null
+++ b/test/fairness/algorithms/test_fairfed_computer.py
@@ -0,0 +1,157 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for FairFed-specific fairness value computer."""
+
+import warnings
+from typing import Any, List, Tuple
+
+import pytest
+
+from declearn.fairness.fairfed import FairfedValueComputer
+
+
+GROUPS_BINARY = [
+    (target, s_attr) for target in (0, 1) for s_attr in (0, 1)
+]  # type: List[Tuple[Any, ...]]
+GROUPS_EXTEND = [
+    (tgt, s_a, s_b) for tgt in (0, 1, 2) for s_a in (0, 1) for s_b in (1, 2)
+]  # type: List[Tuple[Any, ...]]
+F_TYPES = [
+    "accuracy_parity",
+    "demographic_parity",
+    "equality_of_opportunity",
+    "equalized_odds",
+]
+
+
+class TestFairfedValueComputer:
+    """Unit tests for 'declearn.fairness.fairfed.FairfedValueComputer'."""
+
+    @pytest.mark.parametrize("target", [1, 0], ids=["target1", "target0"])
+    @pytest.mark.parametrize("f_type", F_TYPES)
+    def test_identify_key_groups_binary(
+        self,
+        f_type: str,
+        target: int,
+    ) -> None:
+        """Test 'identify_key_groups' with binary target and attribute."""
+        computer = FairfedValueComputer(f_type, strict=True, target=target)
+        if f_type == "accuracy_parity":
+            with pytest.warns(RuntimeWarning):
+                key_groups = computer.identify_key_groups(GROUPS_BINARY.copy())
+        else:
+            key_groups = computer.identify_key_groups(GROUPS_BINARY.copy())
+        assert key_groups == ((target, 0), (target, 1))
+
+    @pytest.mark.parametrize("f_type", F_TYPES)
+    def test_identify_key_groups_extended_exception(
+        self,
+        f_type: str,
+    ) -> None:
+        """Test 'identify_key_groups' exception raising with extended groups.
+
+        'Extended' groups arise from a non-binary label intersected with
+        two distinct binary sensitive attributes.
+        """
+        computer = FairfedValueComputer(f_type, strict=True, target=1)
+        with pytest.raises(RuntimeError):
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore", RuntimeWarning)
+                computer.identify_key_groups(GROUPS_EXTEND.copy())
+
+    @pytest.mark.parametrize("f_type", F_TYPES)
+    def test_identify_key_groups_hybrid_exception(
+        self,
+        f_type: str,
+    ) -> None:
+        """Test 'identify_key_groups' exception raising with 'hybrid' groups.
+
+        'Hybrid' groups are groups that seemingly arise from a categorical
+        target that does not cross all sensitive attribute modalities.
+        """
+        computer = FairfedValueComputer(f_type, strict=True, target=1)
+        with pytest.raises(KeyError):
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore", RuntimeWarning)
+                computer.identify_key_groups([(0, 0), (0, 1), (1, 0), (2, 1)])
+
+    @pytest.mark.parametrize("binary", [True, False], ids=["binary", "extend"])
+    @pytest.mark.parametrize("strict", [True, False], ids=["strict", "free"])
+    @pytest.mark.parametrize("f_type", F_TYPES[1:])  # avoid warning on AccPar
+    def test_initialize(
+        self,
+        f_type: str,
+        strict: bool,
+        binary: bool,
+    ) -> None:
+        """Test that 'initialize' raises an exception in expected cases."""
+        computer = FairfedValueComputer(f_type, strict=strict, target=1)
+        groups = (GROUPS_BINARY if binary else GROUPS_EXTEND).copy()
+        if strict and not binary:
+            with pytest.raises(RuntimeError):
+                computer.initialize(groups)
+        else:
+            computer.initialize(groups)
+
+    @pytest.mark.parametrize("strict", [True, False], ids=["strict", "free"])
+    def test_compute_synthetic_fairness_value_binary(
+        self,
+        strict: bool,
+    ) -> None:
+        """Test 'compute_synthetic_fairness_value' with 4 groups.
+
+        This test applies to both strict and non-strict modes.
+        """
+        # Compute a synthetic value using arbitrary inputs.
+        fairness = {
+            group: float(idx) for idx, group in enumerate(GROUPS_BINARY)
+        }
+        computer = FairfedValueComputer(
+            f_type="demographic_parity",
+            strict=strict,
+            target=1,
+        )
+        computer.initialize(list(fairness))
+        value = computer.compute_synthetic_fairness_value(fairness)
+        # Verify that the output value matches expectations.
+        if strict:
+            expected = fairness[(1, 0)] - fairness[(1, 1)]
+        else:
+            expected = sum(fairness.values()) / len(fairness)
+        assert value == expected
+
+    def test_compute_synthetic_fairness_value_extended(
+        self,
+    ) -> None:
+        """Test 'compute_synthetic_fairness_value' with many groups.
+
+        This test only applies to the non-strict mode.
+        """
+        # Compute a synthetic value using arbitrary inputs.
+        fairness = {
+            group: float(idx) for idx, group in enumerate(GROUPS_EXTEND)
+        }
+        computer = FairfedValueComputer(
+            f_type="demographic_parity",
+            strict=False,
+        )
+        computer.initialize(list(fairness))
+        value = computer.compute_synthetic_fairness_value(fairness)
+        # Verify that the output value matches expectations.
+        expected = sum(fairness.values()) / len(fairness)
+        assert value == expected
diff --git a/test/fairness/algorithms/test_fairgrad_weights_controller.py b/test/fairness/algorithms/test_fairgrad_weights_controller.py
new file mode 100644
index 0000000000000000000000000000000000000000..4259ee4e7de455d29f170050aee2ec7265a6b008
--- /dev/null
+++ b/test/fairness/algorithms/test_fairgrad_weights_controller.py
@@ -0,0 +1,134 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Unit tests for FairGrad weights computation controller.""" + +from unittest import mock + +import numpy as np +import pytest + +from declearn.fairness.api import FairnessFunction +from declearn.fairness.fairgrad import FairgradWeightsController + + +# pylint: disable=duplicate-code +COUNTS = {(0, 0): 30, (0, 1): 15, (1, 0): 35, (1, 1): 20} +F_TYPES = [ + "accuracy_parity", + "demographic_parity", + "equality_of_opportunity", + "equalized_odds", +] +# pylint: enable=duplicate-code + + +class TestFairgradWeightsController: + """Unit tests for 'FairgradWeightsController'. + + These tests cover both the formal behavior of methods + and the correctness of the wrapped math operations. + """ + + @pytest.mark.parametrize("f_type", F_TYPES) + def test_init( + self, + f_type: str, + ) -> None: + """Test that instantiation hyper-parameters are properly passed.""" + eta = mock.create_autospec(float, instance=True) + eps = mock.create_autospec(float, instance=True) + controller = FairgradWeightsController( + counts=COUNTS.copy(), + f_type=f_type, + eta=eta, + eps=eps, + ) + assert controller.eta is eta + assert controller.eps is eps + assert controller.total == sum(COUNTS.values()) + assert isinstance(controller.function, FairnessFunction) + assert controller.function.f_type == f_type + assert (controller.f_k == 0).all() + + def test_get_current_weights_initial(self) -> None: + """Test that initial weights are properly computed / accessed.""" + controller = FairgradWeightsController( + counts=COUNTS.copy(), f_type="accuracy_parity" + ) + # Verify that initial weights match expectations (i.e. P(T_k)). + weights = controller.get_current_weights(norm_nk=False) + expectw = [val / controller.total for val in COUNTS.values()] + assert weights == expectw + # Verify that 'norm_nk' parameter has proper effect. + weights = controller.get_current_weights(norm_nk=True) + expectw = [1 / controller.total] * len(COUNTS) + assert weights == expectw + + @pytest.mark.parametrize("exact", [True, False], ids=["exact", "epsilon"]) + @pytest.mark.parametrize("f_type", F_TYPES) + def test_update_weights_based_on_accuracy( + self, + f_type: str, + exact: bool, + ) -> None: + """Test that weights update works properly.""" + # Setup a controller and update its weights using arbitrary values. + controller = FairgradWeightsController( + counts=COUNTS.copy(), f_type=f_type, eps=0.0 if exact else 0.01 + ) + accuracy = {group: 0.2 * idx for idx, group in enumerate(COUNTS)} + controller.update_weights_based_on_accuracy(accuracy) + # Verify that proper values were assigned as current fairness. + f_k = controller.function.compute_from_federated_group_accuracy( + accuracy + ) + assert (controller.f_k == np.array(list(f_k.values()))).all() + # Verify that expected weights are returned. + c_kk = controller.function.constants[1] + p_tk = controller.counts / controller.total + if exact: + w_tk = controller.eta * controller.f_k + else: + w_tk = np.abs(controller.f_k) + w_tk = controller.eta * np.where( + w_tk > controller.eps, w_tk - controller.eps, 0.0 + ) + expectw = p_tk + np.dot(w_tk, c_kk) + weights = controller.get_current_weights(norm_nk=False) + assert np.allclose(expectw, np.array(weights), atol=0.001) + # Same check with 'norm_nk = True'. 
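+        # ('norm_nk=True' divides each group's weight by its sample count.)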
+ expectw /= controller.counts + weights = controller.get_current_weights(norm_nk=True) + assert np.allclose(expectw, np.array(weights), atol=0.001) + + @pytest.mark.parametrize("f_type", F_TYPES) + def test_get_current_fairness( + self, + f_type: str, + ) -> None: + """Test that access to current fairness values works properly.""" + controller = FairgradWeightsController( + counts=COUNTS.copy(), f_type=f_type + ) + accuracy = {group: 0.2 * idx for idx, group in enumerate(COUNTS)} + controller.update_weights_based_on_accuracy(accuracy) + fairness = controller.get_current_fairness() + assert isinstance(fairness, dict) + assert fairness == dict( + zip(controller.function.groups, controller.f_k) + ) diff --git a/test/fairness/controllers/fairness_controllers_testing.py b/test/fairness/controllers/fairness_controllers_testing.py new file mode 100644 index 0000000000000000000000000000000000000000..8641ef38bab6166db9f6c0f47444a7030203fb59 --- /dev/null +++ b/test/fairness/controllers/fairness_controllers_testing.py @@ -0,0 +1,509 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared unit tests for Fairness controllers.""" + +import asyncio +import logging +import warnings +from unittest import mock +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import numpy as np +import pytest + +from declearn.aggregator import Aggregator +from declearn.communication.api import NetworkServer +from declearn.communication.utils import verify_server_message_validity +from declearn.fairness.api import ( + FairnessControllerClient, + FairnessControllerServer, + FairnessDataset, +) +from declearn.messaging import ( + FairnessQuery, + FairnessReply, + FairnessSetupQuery, + SerializedMessage, +) +from declearn.metrics import MeanMetric +from declearn.model.api import Model +from declearn.secagg.api import Decrypter, Encrypter +from declearn.secagg.messaging import SecaggFairnessReply +from declearn.test_utils import ( + assert_dict_equal, + build_secagg_controllers, + setup_mock_network_endpoints, +) +from declearn.training import TrainingManager + + +# Define arbitrary group definitions and sample counts. +CLIENT_COUNTS = [ + {(0, 0): 10, (0, 1): 10, (1, 0): 10, (1, 1): 10}, + {(0, 0): 10, (1, 0): 15, (1, 1): 10}, + {(0, 0): 10, (0, 1): 5, (1, 0): 10}, +] +TOTAL_COUNTS = {(0, 0): 30, (0, 1): 15, (1, 0): 35, (1, 1): 20} + + +def build_mock_dataset(idx: int) -> mock.Mock: + """Return a mock FairnessDataset with deterministic group counts.""" + counts = CLIENT_COUNTS[idx] + dataset = mock.create_autospec(FairnessDataset, instance=True) + dataset.get_sensitive_group_definitions.return_value = list(counts) + dataset.get_sensitive_group_counts.return_value = counts + return dataset + + +class FairnessControllerTestSuite: + """Shared test suite for Fairness controllers.""" + + # Types of controllers associated with a given test suite subclass. 
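+    # (Concrete test suites must assign the classes under test here.)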
+    server_cls: Type[FairnessControllerServer]
+    client_cls: Type[FairnessControllerClient]
+
+    # Default expected local computed metrics. May be overloaded by subclasses.
+    mock_client_metrics = [
+        {"accuracy": {group: 1.0 for group in CLIENT_COUNTS[idx]}}
+        for idx in range(len(CLIENT_COUNTS))
+    ]
+
+    def setup_server_controller(self) -> FairnessControllerServer:
+        """Instantiate and return a server-side fairness controller."""
+        return self.server_cls(f_type="accuracy_parity")
+
+    def setup_mock_training_manager(
+        self,
+        idx: int,
+    ) -> mock.MagicMock:
+        """Setup and return a mock TrainingManager for a given client."""
+        manager = mock.create_autospec(TrainingManager, instance=True)
+        manager.aggrg = mock.create_autospec(Aggregator, instance=True)
+        manager.logger = mock.create_autospec(logging.Logger, instance=True)
+        manager.model = mock.create_autospec(Model, instance=True)
+        manager.train_data = build_mock_dataset(idx)
+        return manager
+
+    def test_setup_server_from_specs(
+        self,
+    ) -> None:
+        """Test instantiating a server-side controller 'from_specs'."""
+        server = self.server_cls.from_specs(
+            algorithm=self.server_cls.algorithm,
+            f_type="demographic_parity",
+        )
+        assert isinstance(server, self.server_cls)
+
+    def test_setup_client_from_setup_query(
+        self,
+    ) -> None:
+        """Test that the server's setup query results in a proper client."""
+        server = self.setup_server_controller()
+        query = server.prepare_fairness_setup_query()
+        assert isinstance(query, FairnessSetupQuery)
+        manager = self.setup_mock_training_manager(idx=0)
+        client = FairnessControllerClient.from_setup_query(query, manager)
+        assert isinstance(client, self.client_cls)
+        assert client.manager is manager
+        assert client.fairness_function.f_type == server.f_type
+
+    def setup_client_controller_from_server(
+        self,
+        server: FairnessControllerServer,
+        idx: int,
+    ) -> FairnessControllerClient:
+        """Instantiate and return a client-side fairness controller."""
+        manager = self.setup_mock_training_manager(idx)
+        query = server.prepare_fairness_setup_query()
+        return FairnessControllerClient.from_setup_query(query, manager)
+
+    def setup_fairness_controllers_and_secagg(
+        self,
+        n_peers: int,
+        use_secagg: bool,
+    ) -> Tuple[
+        FairnessControllerServer,
+        List[FairnessControllerClient],
+        Optional[Decrypter],
+        Union[List[Encrypter], List[None]],
+    ]:
+        """Instantiate fairness and (optional) secagg controllers."""
+        # Instantiate the server and client controllers.
+        server = self.setup_server_controller()
+        clients = [
+            self.setup_client_controller_from_server(server, idx)
+            for idx in range(n_peers)
+        ]
+        # Optionally set up SecAgg controllers, then return.
+        if use_secagg:
+            decrypter, encrypters = build_secagg_controllers(n_peers)
+            return server, clients, decrypter, encrypters  # type: ignore
+        return server, clients, None, [None for _ in range(n_peers)]
+
+    @pytest.mark.parametrize(
+        "use_secagg", [False, True], ids=["clrtxt", "secagg"]
+    )
+    @pytest.mark.asyncio
+    async def test_exchange_sensitive_groups_list_and_counts(
+        self,
+        use_secagg: bool,
+    ) -> None:
+        """Test the exchange of sensitive groups' definitions and counts."""
+        n_peers = len(CLIENT_COUNTS)
+        # Instantiate the fairness and optional secagg controllers.
+        server, clients, decrypter, encrypters = (
+            self.setup_fairness_controllers_and_secagg(n_peers, use_secagg)
+        )
+        # Run setup coroutines, using mock network endpoints.
+ async with setup_mock_network_endpoints(n_peers) as network: + coro_server = server.exchange_sensitive_groups_list_and_counts( + netwk=network[0], secagg=decrypter + ) + coro_clients = [ + client.exchange_sensitive_groups_list_and_counts( + netwk=network[1][idx], secagg=encrypters[idx] + ) + for idx, client in enumerate(clients) + ] + counts, *_ = await asyncio.gather(coro_server, *coro_clients) + # Verify that expected attributes were assigned with expected values. + assert isinstance(counts, list) and len(counts) == len(TOTAL_COUNTS) + assert dict(zip(server.groups, counts)) == TOTAL_COUNTS + assert all(client.groups == server.groups for client in clients) + + @pytest.mark.parametrize( + "use_secagg", [False, True], ids=["clrtxt", "secagg"] + ) + @pytest.mark.asyncio + async def test_finalize_fairness_setup( + self, + use_secagg: bool, + ) -> None: + """Test that 'finalize_fairness_setup' works properly. + + This test should be overridden by subclasses to perform + algorithm-specific verification (and warnings-catching). + """ + aggregator = mock.create_autospec(Aggregator, instance=True) + agg_final, *_ = await self.run_finalize_fairness_setup( + aggregator, use_secagg + ) + # Verify that the server returns an Aggregator. + assert isinstance(agg_final, Aggregator) + + async def run_finalize_fairness_setup( + self, + aggregator: Aggregator, + use_secagg: bool, + ) -> Tuple[ + Aggregator, FairnessControllerServer, List[FairnessControllerClient] + ]: + """Run 'finalize_fairness_setup' and return controllers.""" + n_peers = len(CLIENT_COUNTS) + # Instantiate the fairness and optional secagg controllers. + server, clients, decrypter, encrypters = ( + self.setup_fairness_controllers_and_secagg(n_peers, use_secagg) + ) + # Assign expected group definitions and counts. + server.groups = sorted(list(TOTAL_COUNTS)) + for client in clients: + client.groups = server.groups.copy() + counts = [TOTAL_COUNTS[group] for group in server.groups] + # Run setup coroutines, using mock network endpoints. + async with setup_mock_network_endpoints(n_peers) as network: + coro_server = server.finalize_fairness_setup( + netwk=network[0], + secagg=decrypter, + counts=counts, + aggregator=aggregator, + ) + coro_clients = [ + client.finalize_fairness_setup( + netwk=network[1][idx], + secagg=encrypters[idx], + ) + for idx, client in enumerate(clients) + ] + agg_final, *_ = await asyncio.gather(coro_server, *coro_clients) + # Return the resulting aggregator and controllers. + return agg_final, server, clients + + def test_setup_fairness_metrics( + self, + ) -> None: + """Test that 'setup_fairness_metrics' has proper output type.""" + server = self.setup_server_controller() + client = self.setup_client_controller_from_server(server, idx=0) + metrics = client.setup_fairness_metrics() + assert isinstance(metrics, list) + assert all(isinstance(metric, MeanMetric) for metric in metrics) + + @pytest.mark.parametrize("idx", list(range(len(CLIENT_COUNTS)))) + def test_compute_fairness_metrics( + self, + idx: int, + ) -> None: + """Test that metrics computation works for a given client.""" + server = self.setup_server_controller() + client = self.setup_client_controller_from_server(server, idx) + client.groups = list(TOTAL_COUNTS) + # Run mock computations. 
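+        # ('compute_groupwise_metrics' is patched to return fixed values.)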
+ with mock.patch.object( + client.computer, "compute_groupwise_metrics" + ) as patch_compute: + patch_compute.return_value = self.mock_client_metrics[idx].copy() + share_values, local_values = client.compute_fairness_measures(32) + # Verify that expected shareable values were output. + patch_compute.assert_called_once() + assert isinstance(share_values, list) + expected_share = [ + group_values.get(group, 0.0) * CLIENT_COUNTS[idx].get(group, 0.0) + for group_values in self.mock_client_metrics[idx].values() + for group in client.groups + ] + assert share_values == expected_share + # Verify that expected local values were output. + assert isinstance(local_values, dict) + expected_local = self.mock_client_metrics[idx].copy() + if "accuracy" in expected_local: + expected_local[client.fairness_function.f_type] = ( + client.fairness_function.compute_from_group_accuracy( + expected_local["accuracy"] + ) + ) + assert_dict_equal(local_values, expected_local) + + @pytest.mark.parametrize( + "use_secagg", [False, True], ids=["clrtxt", "secagg"] + ) + @pytest.mark.asyncio + async def test_receive_and_aggregate_fairness_metrics( + self, + use_secagg: bool, + ) -> None: + """Test that server-side aggregation of metrics works properly.""" + # Setup a server controller and optionally some secagg controllers. + n_peers = len(CLIENT_COUNTS) + server = self.setup_server_controller() + server.groups = list(TOTAL_COUNTS) + decrypter, encrypters = ( + build_secagg_controllers(n_peers) if use_secagg else (None, None) + ) + # Setup a mock network endpoint receiving local metrics. + netwk = mock.create_autospec(NetworkServer, instance=True) + replies = { + f"client_{idx}": FairnessReply( + [ + g_val.get(group, 0.0) * CLIENT_COUNTS[idx].get(group, 0.0) + for g_val in self.mock_client_metrics[idx].values() + for group in list(TOTAL_COUNTS) + ] + ) + for idx in range(len(self.mock_client_metrics)) + } + if encrypters: + secagg_replies = { + key: SecaggFairnessReply.from_cleartext_message( + cleartext=val, encrypter=encrypters[idx] + ) + for idx, (key, val) in enumerate(replies.items()) + } + netwk.wait_for_messages.return_value = { + key: SerializedMessage.from_message_string(val.to_string()) + for key, val in secagg_replies.items() + } + else: + netwk.wait_for_messages.return_value = { + key: SerializedMessage.from_message_string(val.to_string()) + for key, val in replies.items() + } + # Run the reception and (secure-)aggregation of these replies. + aggregated = await server.receive_and_aggregate_fairness_measures( + netwk=netwk, secagg=decrypter + ) + # Verify that outputs match expectations. + assert isinstance(aggregated, list) + expected = [ + sum(rv) for rv in zip(*[rep.values for rep in replies.values()]) + ] + assert np.allclose(np.array(aggregated), np.array(expected)) + + @pytest.mark.parametrize( + "use_secagg", [False, True], ids=["clrtxt", "secagg"] + ) + @pytest.mark.asyncio + async def test_finalize_fairness_round( + self, + use_secagg: bool, + ) -> None: + """Test that 'finalize_fairness_round' works properly. + + This test should be overridden by subclasses to perform + algorithm-specific verification. + """ + _, _, metrics = await self.run_finalize_fairness_round(use_secagg) + self.verify_fairness_round_metrics(metrics) + + async def run_finalize_fairness_round( + self, + use_secagg: bool, + ) -> Tuple[ + FairnessControllerServer, + List[FairnessControllerClient], + List[Dict[str, Union[float, np.ndarray]]], + ]: + """Run 'finalize_fairness_round' after mocking previous steps. 
+
+        Return the server and client controllers, as well as the list
+        of output metrics dictionaries returned by the executed routines.
+        """
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore", RuntimeWarning)
+            _, server, clients = await self.run_finalize_fairness_setup(
+                mock.MagicMock(),
+                use_secagg,
+            )
+        # Run mock client computations and compute expected aggregate.
+        share_vals = []  # type: List[List[float]]
+        local_vals = []  # type: List[Dict[str, Dict[Tuple[Any, ...], float]]]
+        for idx, client in enumerate(clients):
+            with mock.patch.object(
+                client.computer,
+                "compute_groupwise_metrics",
+                return_value=self.mock_client_metrics[idx].copy(),
+            ):
+                client_values = client.compute_fairness_measures(32)
+            share_vals.append(client_values[0])
+            local_vals.append(client_values[1])
+        server_values = [float(sum(values)) for values in zip(*share_vals)]
+        # Setup optional SecAgg and mock network communication endpoints.
+        # Run the tested method.
+        n_peers = len(clients)
+        decrypter, encrypters = (
+            build_secagg_controllers(n_peers)
+            if use_secagg
+            else (None, [None] * n_peers)  # type: ignore
+        )
+        async with setup_mock_network_endpoints(n_peers) as netwk:
+            metrics = await asyncio.gather(
+                server.finalize_fairness_round(
+                    netwk=netwk[0],
+                    secagg=decrypter,
+                    values=server_values,
+                ),
+                *[
+                    client.finalize_fairness_round(
+                        netwk=netwk[1][idx],
+                        secagg=encrypters[idx],
+                        values=local_vals[idx],
+                    )
+                    for idx, client in enumerate(clients)
+                ],
+            )
+        return server, clients, metrics
+
+    def verify_fairness_round_metrics(
+        self,
+        metrics: List[Dict[str, Union[float, np.ndarray]]],
+    ) -> None:
+        """Verify that metrics output by fairness rounds match expectations.
+
+        Input `metrics` contains the server-side metrics followed by each
+        of the client-side ones, all formatted as dictionaries.
+        """
+        # Verify that all output metrics are dict with proper inner types.
+        for m_dict in metrics:
+            assert isinstance(m_dict, dict)
+            assert all(isinstance(key, str) for key in m_dict)
+            assert all(
+                isinstance(val, (float, np.ndarray)) for val in m_dict.values()
+            )
+        # Verify that client dictionaries have the same keys.
+        keys = list(metrics[1].keys())
+        assert all(set(m_dict).issubset(keys) for m_dict in metrics[2:])
+
+    @pytest.mark.parametrize(
+        "use_secagg", [False, True], ids=["clrtxt", "secagg"]
+    )
+    @pytest.mark.asyncio
+    async def test_fairness_end2end(
+        self,
+        use_secagg: bool,
+    ) -> None:
+        """Test that running both fairness setup and round routines works.
+
+        This end-to-end test is about verifying that running all unit-tested
+        components together does not raise exceptions. Details about unitary
+        operations are left up to unit tests.
+        """
+        # Instantiate the fairness and optional secagg controllers.
+        n_peers = len(CLIENT_COUNTS)
+        decrypter = None  # type: Optional[Decrypter]
+        encrypters = [None] * n_peers  # type: List[Optional[Encrypter]]
+        if use_secagg:
+            decrypter, encrypters = build_secagg_controllers(  # type: ignore
+                n_peers
+            )
+        # Run end-to-end routines using mock communication endpoints.
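+        # (The server and client routines defined below are gathered so as
+        # to run concurrently, as each awaits messages from the other side.)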
+ async with setup_mock_network_endpoints(n_peers=n_peers) as netwk: + + async def server_routine() -> None: + """Server-side fairness setup and round routine.""" + nonlocal decrypter, netwk + server = self.setup_server_controller() + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RuntimeWarning) + await server.setup_fairness( + netwk=netwk[0], + aggregator=mock.create_autospec( + Aggregator, instance=True + ), + secagg=decrypter, + ) + await netwk[0].broadcast_message(FairnessQuery(round_i=0)) + await server.run_fairness_round( + netwk=netwk[0], + secagg=decrypter, + ) + + async def client_routine(idx: int) -> None: + """Client-side fairness setup and round routine.""" + nonlocal encrypters, netwk + # Instantiate the client-side controller. + received = await netwk[1][idx].recv_message() + setup_query = await verify_server_message_validity( + netwk[1][idx], received, FairnessSetupQuery + ) + client = FairnessControllerClient.from_setup_query( + setup_query, manager=self.setup_mock_training_manager(idx) + ) + # Run the fairness setup routine. + await client.setup_fairness(netwk[1][idx], encrypters[idx]) + # Run the fairness round routine. + received = await netwk[1][idx].recv_message() + round_query = await verify_server_message_validity( + netwk[1][idx], received, FairnessQuery + ) + await client.run_fairness_round( + netwk[1][idx], round_query, encrypters[idx] + ) + + await asyncio.gather( + server_routine(), + *[client_routine(idx) for idx in range(n_peers)], + ) diff --git a/test/fairness/controllers/test_fairbatch_controllers.py b/test/fairness/controllers/test_fairbatch_controllers.py new file mode 100644 index 0000000000000000000000000000000000000000..e815623598d6e37aa748505e7551dbd2042de9b0 --- /dev/null +++ b/test/fairness/controllers/test_fairbatch_controllers.py @@ -0,0 +1,200 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for Fed-FairBatch controllers.""" + +import asyncio +import os +from typing import List +from unittest import mock + +import pytest + +from declearn.aggregator import Aggregator, SumAggregator +from declearn.communication.utils import ErrorMessageException +from declearn.fairness.api import ( + FairnessControllerClient, + FairnessControllerServer, +) +from declearn.fairness.fairbatch import ( + FairbatchControllerClient, + FairbatchControllerServer, + FairbatchDataset, + FairbatchSamplingController, +) +from declearn.test_utils import make_importable, setup_mock_network_endpoints + +with make_importable(os.path.dirname(os.path.abspath(__file__))): + from fairness_controllers_testing import ( + FairnessControllerTestSuite, + CLIENT_COUNTS, + TOTAL_COUNTS, + ) + + +class TestFairbatchControllers(FairnessControllerTestSuite): + """Unit tests for Fed-FairBatch / FedFB controllers.""" + + # similar code to FairGrad and parent code; pylint: disable=duplicate-code + + server_cls = FairbatchControllerServer + client_cls = FairbatchControllerClient + + mock_client_metrics = [ + { + "accuracy": {group: 1.0 for group in CLIENT_COUNTS[idx]}, + "loss": {group: 0.05 for group in CLIENT_COUNTS[idx]}, + } + for idx in range(len(CLIENT_COUNTS)) + ] + + def setup_server_controller(self) -> FairbatchControllerServer: + return self.server_cls(f_type="equalized_odds") + + @pytest.mark.parametrize( + "use_secagg", [False, True], ids=["clrtxt", "secagg"] + ) + @pytest.mark.asyncio + async def test_finalize_fairness_setup( + self, + use_secagg: bool, + ) -> None: + aggregator = mock.create_autospec(Aggregator, instance=True) + with pytest.warns(RuntimeWarning, match="SumAggregator"): + agg_final, server, clients = ( + await self.run_finalize_fairness_setup(aggregator, use_secagg) + ) + # Verify that aggregators were replaced with a SumAggregator. + assert isinstance(agg_final, SumAggregator) + assert all( + isinstance(client.manager.aggrg, SumAggregator) + for client in clients + ) + # Verify that the sampling controller was properly instantiated. + assert isinstance(server, FairbatchControllerServer) + assert server.sampling_controller.counts == TOTAL_COUNTS + # Verify that FairBatch sampling probas were shared and applied. 
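+        # (Each client should hold probas renormalized over its own groups.)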
+        self.verify_fairbatch_sampling_probas_coherence(server, clients)
+
+    def verify_fairbatch_sampling_probas_coherence(
+        self,
+        server: FairnessControllerServer,
+        clients: List[FairnessControllerClient],
+    ) -> None:
+        """Verify that FairBatch sampling probas were shared and applied."""
+        assert isinstance(server, FairbatchControllerServer)
+        probas = server.sampling_controller.get_sampling_probas()
+        for client in clients:
+            dst = client.manager.train_data
+            assert isinstance(dst, FairbatchDataset)
+            total = sum(probas[group] for group in dst.groups)
+            expected = {group: probas[group] / total for group in dst.groups}
+            assert dst.get_sampling_probabilities() == expected
+
+    @pytest.mark.parametrize(
+        "use_secagg", [False, True], ids=["clrtxt", "secagg"]
+    )
+    @pytest.mark.asyncio
+    async def test_finalize_fairness_round(
+        self,
+        use_secagg: bool,
+    ) -> None:
+        with mock.patch.object(
+            FairbatchSamplingController,
+            "update_from_federated_losses",
+        ) as patch_update_sampling_probas:
+            server, clients, metrics = await self.run_finalize_fairness_round(
+                use_secagg
+            )
+        self.verify_fairness_round_metrics(metrics)
+        patch_update_sampling_probas.assert_called_once()
+        self.verify_fairbatch_sampling_probas_coherence(server, clients)
+
+    def test_init_fedfb_param(self) -> None:
+        """Test that the server-side 'fedfb' parameter is enforced."""
+        with mock.patch(
+            "declearn.fairness.fairbatch._server.setup_fairbatch_controller"
+        ) as patch_setup_fairbatch:
+            controller = FairbatchControllerServer(
+                f_type="demographic_parity",
+                fedfb=False,
+            )
+        assert not controller.fedfb
+        patch_setup_fairbatch.assert_called_once()
+        with mock.patch(
+            "declearn.fairness.fairbatch._server.setup_fedfb_controller"
+        ) as patch_setup_fedfb:
+            controller = FairbatchControllerServer(
+                f_type="demographic_parity",
+                fedfb=True,
+            )
+        assert controller.fedfb
+        patch_setup_fedfb.assert_called_once()
+
+    def test_init_alpha_param(self) -> None:
+        """Test that the server-side 'alpha' parameter is properly passed."""
+        alpha = mock.MagicMock()
+        server = FairbatchControllerServer(
+            f_type="demographic_parity", alpha=alpha
+        )
+        assert server.sampling_controller.alpha is alpha
+
+    @pytest.mark.asyncio
+    async def test_finalize_fairness_setup_error(
+        self,
+    ) -> None:
+        """Test that FairBatch probas update error-catching works properly."""
+        n_peers = len(CLIENT_COUNTS)
+        # Instantiate the fairness controllers.
+        server = self.setup_server_controller()
+        clients = [
+            self.setup_client_controller_from_server(server, idx)
+            for idx in range(n_peers)
+        ]
+        # Assign expected group definitions and counts.
+        server.groups = sorted(list(TOTAL_COUNTS))
+        for client in clients:
+            client.groups = server.groups.copy()
+        counts = [TOTAL_COUNTS[group] for group in server.groups]
+        # Run setup coroutines, using mock network endpoints.
+        async with setup_mock_network_endpoints(n_peers) as network:
+            coro_server = server.finalize_fairness_setup(
+                netwk=network[0],
+                secagg=None,
+                counts=counts,
+                aggregator=mock.create_autospec(SumAggregator, instance=True),
+            )
+            coro_clients = [
+                client.finalize_fairness_setup(
+                    netwk=network[1][idx],
+                    secagg=None,
+                )
+                for idx, client in enumerate(clients)
+            ]
+            # Have the sampling probabilities' assignment fail.
+ with mock.patch.object( + FairbatchDataset, + "set_sampling_probabilities", + side_effect=Exception, + ) as patch_set_sampling_probabilities: + exc_server, *exc_clients = await asyncio.gather( + coro_server, *coro_clients, return_exceptions=True + ) + # Assert that expected exceptions were raised. + assert isinstance(exc_server, ErrorMessageException) + assert all(isinstance(exc, RuntimeError) for exc in exc_clients) + assert patch_set_sampling_probabilities.call_count == n_peers diff --git a/test/fairness/controllers/test_fairfed_controllers.py b/test/fairness/controllers/test_fairfed_controllers.py new file mode 100644 index 0000000000000000000000000000000000000000..157d879c2f35da776889c6a505ed9402d6614bfe --- /dev/null +++ b/test/fairness/controllers/test_fairfed_controllers.py @@ -0,0 +1,188 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for FairFed controllers.""" + +import os +from typing import Dict, List, Union +from unittest import mock + +import numpy as np +import pytest + +from declearn.aggregator import Aggregator +from declearn.fairness.fairfed import ( + FairfedAggregator, + FairfedControllerClient, + FairfedControllerServer, + FairfedValueComputer, +) +from declearn.test_utils import make_importable + +with make_importable(os.path.dirname(os.path.abspath(__file__))): + from fairness_controllers_testing import ( + CLIENT_COUNTS, + FairnessControllerTestSuite, + ) + + +class TestFairfedControllers(FairnessControllerTestSuite): + """Unit tests for FairFed controllers.""" + + server_cls = FairfedControllerServer + client_cls = FairfedControllerClient + + def setup_server_controller(self) -> FairfedControllerServer: + return self.server_cls(f_type="equality_of_opportunity", strict=False) + + @pytest.mark.parametrize( + "use_secagg", [False, True], ids=["clrtxt", "secagg"] + ) + @pytest.mark.asyncio + async def test_finalize_fairness_setup( + self, + use_secagg: bool, + ) -> None: + aggregator = mock.create_autospec(Aggregator, instance=True) + with mock.patch.object( + FairfedValueComputer, "initialize" + ) as patch_initialize_computer: + with mock.patch.object( + FairfedAggregator, "initialize_local_weight" + ) as patch_initialize_aggregator: + with pytest.warns(RuntimeWarning, match="Aggregator"): + agg_final, server, clients = ( + await self.run_finalize_fairness_setup( + aggregator, use_secagg + ) + ) + # Verify that aggregators were replaced with a FairfedAggregator. + assert isinstance(agg_final, FairfedAggregator) + assert all( + isinstance(client.manager.aggrg, FairfedAggregator) + for client in clients + ) + # Verify that all FairFed computers were initialized. + calls = [ + mock.call(groups=client.fairness_function.groups) + for client in clients + ] + calls.append(mock.call(groups=server.groups)) + patch_initialize_computer.assert_has_calls(calls, any_order=True) + # Verify that all FairFed aggregators were initialized. 
+ calls = [ + mock.call(n_samples=sum(client_counts.values())) + for client_counts in CLIENT_COUNTS + ] + patch_initialize_aggregator.assert_has_calls(calls, any_order=True) + + @pytest.mark.parametrize( + "use_secagg", [False, True], ids=["clrtxt", "secagg"] + ) + @pytest.mark.asyncio + async def test_finalize_fairness_round( + self, + use_secagg: bool, + ) -> None: + # Run the routine. + with mock.patch.object( + FairfedAggregator, "update_local_weight" + ) as patch_update_fairfed_local_weight: + server, clients, metrics = await self.run_finalize_fairness_round( + use_secagg + ) + # Verify output metrics, including coherence of fairfed-specific ones. + self.verify_fairness_round_metrics(metrics) + # Verify that expected client-wise weights update occurred. + calls = [ + mock.call( + delta_loc=client_metrics["fairfed_delta"], + delta_avg=client_metrics["fairfed_deltavg"], + ) + for client_metrics in metrics[1:] + ] + patch_update_fairfed_local_weight.assert_has_calls( + calls, any_order=True + ) + # Verify that fairfed synthetic values were properly computed. + assert isinstance(server, FairfedControllerServer) + fairness = { + group: float(metrics[0][f"{server.f_type}_{group}"]) + for group in server.groups + } + assert metrics[0]["fairfed_value"] == ( + server.fairfed_computer.compute_synthetic_fairness_value(fairness) + ) + for client, client_metrics in zip(clients, metrics[1:]): + assert isinstance(client, FairfedControllerClient) + fairness = { + group: float(client_metrics[f"{server.f_type}_{group}"]) + for group in client.fairness_function.groups + } + assert client_metrics["fairfed_value"] == ( + client.fairfed_computer.compute_synthetic_fairness_value( + fairness + ) + ) + + def verify_fairness_round_metrics( + self, + metrics: List[Dict[str, Union[float, np.ndarray]]], + ) -> None: + # Perform basic verifications. + super().verify_fairness_round_metrics(metrics) + # Verify that computed fairfed delta values are coherent. + server = metrics[0] + clients = metrics[1:] + for client in clients: + assert client["fairfed_delta"] == abs( + client["fairfed_value"] - server["fairfed_value"] + ) + assert client["fairfed_deltavg"] == server["fairfed_deltavg"] + assert server["fairfed_deltavg"] == ( + sum(client["fairfed_delta"] for client in clients) / len(clients) + ) + + @pytest.mark.parametrize( + "strict", [True, False], ids=["strict", "extended"] + ) + def test_init_params( + self, + strict: bool, + ) -> None: + """Test that instantiation parameters are properly passed.""" + rng = np.random.default_rng() + beta = abs(rng.normal()) + target = int(rng.choice(2)) + controller = FairfedControllerServer( + f_type="demographic_parity", + beta=beta, + strict=strict, + target=target, + ) + assert controller.beta == beta + assert controller.fairfed_computer.f_type == "demographic_parity" + assert controller.strict is strict + assert controller.fairfed_computer.strict is strict + assert controller.fairfed_computer.target is target + # Verify that parameters are transmitted to clients. 
+ client = self.setup_client_controller_from_server(controller, idx=0) + assert isinstance(client, FairfedControllerClient) + assert client.beta == controller.beta + assert client.fairfed_computer.f_type == "demographic_parity" + assert client.strict is strict + assert client.fairfed_computer.strict is strict diff --git a/test/fairness/controllers/test_fairgrad_controllers.py b/test/fairness/controllers/test_fairgrad_controllers.py new file mode 100644 index 0000000000000000000000000000000000000000..cf6612f973c1ff88010faae5d1d534124292f8e0 --- /dev/null +++ b/test/fairness/controllers/test_fairgrad_controllers.py @@ -0,0 +1,154 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for Fed-FairGrad controllers.""" + +import asyncio +import os +from typing import List +from unittest import mock + +import pytest + +from declearn.aggregator import Aggregator, SumAggregator +from declearn.communication.utils import ErrorMessageException +from declearn.fairness.api import ( + FairnessDataset, + FairnessControllerClient, + FairnessControllerServer, +) +from declearn.fairness.fairgrad import ( + FairgradControllerClient, + FairgradControllerServer, + FairgradWeightsController, +) +from declearn.test_utils import make_importable, setup_mock_network_endpoints + +with make_importable(os.path.dirname(os.path.abspath(__file__))): + from fairness_controllers_testing import ( + CLIENT_COUNTS, + TOTAL_COUNTS, + FairnessControllerTestSuite, + ) + + +class TestFairgradControllers(FairnessControllerTestSuite): + """Unit tests for Fed-FairGrad controllers.""" + + server_cls = FairgradControllerServer + client_cls = FairgradControllerClient + + @pytest.mark.parametrize( + "use_secagg", [False, True], ids=["clrtxt", "secagg"] + ) + @pytest.mark.asyncio + async def test_finalize_fairness_setup( + self, + use_secagg: bool, + ) -> None: + aggregator = mock.create_autospec(Aggregator, instance=True) + with pytest.warns(RuntimeWarning, match="SumAggregator"): + agg_final, server, clients = ( + await self.run_finalize_fairness_setup(aggregator, use_secagg) + ) + # Verify that aggregators were replaced with a SumAggregator. + assert isinstance(agg_final, SumAggregator) + assert all( + isinstance(client.manager.aggrg, SumAggregator) + for client in clients + ) + # Verify that FairgradWeights were shared and applied. 
+ self.verify_fairgrad_weights_coherence(server, clients) + + def verify_fairgrad_weights_coherence( + self, + server: FairnessControllerServer, + clients: List[FairnessControllerClient], + ) -> None: + """Verify that FairGrad weights were shared to clients and applied.""" + assert isinstance(server, FairgradControllerServer) + weights = server.weights_controller.get_current_weights(norm_nk=True) + expectw = dict(zip(server.groups, weights)) + for client in clients: + mock_dst = client.manager.train_data + assert isinstance(mock_dst, FairnessDataset) + assert isinstance(mock_dst, mock.NonCallableMagicMock) + mock_dst.set_sensitive_group_weights.assert_called_with( + weights=expectw, adjust_by_counts=True + ) + + @pytest.mark.parametrize( + "use_secagg", [False, True], ids=["clrtxt", "secagg"] + ) + @pytest.mark.asyncio + async def test_finalize_fairness_round( + self, + use_secagg: bool, + ) -> None: + with mock.patch.object( + FairgradWeightsController, + "update_weights_based_on_accuracy", + ) as patch_update_weights: + server, clients, metrics = await self.run_finalize_fairness_round( + use_secagg + ) + self.verify_fairness_round_metrics(metrics) + patch_update_weights.assert_called_once() + self.verify_fairgrad_weights_coherence(server, clients) + + @pytest.mark.asyncio + async def test_finalize_fairness_setup_error( + self, + ) -> None: + """Test that FairGrad weights setup error-catching works properly.""" + n_peers = len(CLIENT_COUNTS) + # Instantiate the fairness controllers. + server = self.setup_server_controller() + clients = [ + self.setup_client_controller_from_server(server, idx) + for idx in range(n_peers) + ] + # Assign expected group definitions and counts. + # Have client datasets fail upon receiving sensitive group weights. + server.groups = sorted(list(TOTAL_COUNTS)) + for client in clients: + client.groups = server.groups.copy() + mock_dst = client.manager.train_data + assert isinstance(mock_dst, mock.NonCallableMagicMock) + mock_dst.set_sensitive_group_weights.side_effect = Exception + counts = [TOTAL_COUNTS[group] for group in server.groups] + # Run setup coroutines, using mock network endpoints. + async with setup_mock_network_endpoints(n_peers) as network: + coro_server = server.finalize_fairness_setup( + netwk=network[0], + secagg=None, + counts=counts, + aggregator=mock.create_autospec(SumAggregator, instance=True), + ) + coro_clients = [ + client.finalize_fairness_setup( + netwk=network[1][idx], + secagg=None, + ) + for idx, client in enumerate(clients) + ] + exc_server, *exc_clients = await asyncio.gather( + coro_server, *coro_clients, return_exceptions=True + ) + # Assert that expected exceptions were raised. + assert isinstance(exc_server, ErrorMessageException) + assert all(isinstance(exc, RuntimeError) for exc in exc_clients) diff --git a/test/fairness/controllers/test_monitor_controllers.py b/test/fairness/controllers/test_monitor_controllers.py new file mode 100644 index 0000000000000000000000000000000000000000..ea8b86ea6076db49015171e5e3d48be03d9d3159 --- /dev/null +++ b/test/fairness/controllers/test_monitor_controllers.py @@ -0,0 +1,54 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for mere-monitoring fairness controllers.""" + +import os +from unittest import mock + +import pytest + +from declearn.aggregator import Aggregator +from declearn.fairness.monitor import ( + FairnessMonitorClient, + FairnessMonitorServer, +) +from declearn.test_utils import make_importable + +with make_importable(os.path.dirname(os.path.abspath(__file__))): + from fairness_controllers_testing import FairnessControllerTestSuite + + +class TestFairnessMonitorControllers(FairnessControllerTestSuite): + """Unit tests for mere-monitoring fairness controllers.""" + + server_cls = FairnessMonitorServer + client_cls = FairnessMonitorClient + + @pytest.mark.parametrize( + "use_secagg", [False, True], ids=["clrtxt", "secagg"] + ) + @pytest.mark.asyncio + async def test_finalize_fairness_setup( + self, + use_secagg: bool, + ) -> None: + aggregator = mock.create_autospec(Aggregator, instance=True) + agg_final, *_ = await self.run_finalize_fairness_setup( + aggregator, use_secagg + ) + assert agg_final is aggregator diff --git a/test/fairness/test_accuracy_computer.py b/test/fairness/test_accuracy_computer.py new file mode 100644 index 0000000000000000000000000000000000000000..a46964b6fd0c9e3b3fb61dca2079b43270d326d0 --- /dev/null +++ b/test/fairness/test_accuracy_computer.py @@ -0,0 +1,181 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for 'declearn.fairness.api.FairnessMetricsComputer'.""" + +from typing import Optional +from unittest import mock + +import numpy as np +import pytest + +from declearn.dataset import Dataset +from declearn.fairness.api import FairnessMetricsComputer, FairnessDataset +from declearn.metrics import MeanMetric, MetricSet +from declearn.model.api import Model + + +N_BATCHES = 8 +GROUPS = [(0, 0), (0, 1), (1, 0), (1, 1)] + + +@pytest.fixture(name="dataset") +def dataset_fixture() -> FairnessDataset: + """Mock FairnessDataset providing fixture.""" + # Set up a mock FairnessDataset. + dataset = mock.create_autospec(FairnessDataset, instance=True) + dataset.get_sensitive_group_definitions.return_value = GROUPS.copy() + # Set up a mock Dataset. + subdataset = mock.create_autospec(Dataset, instance=True) + batches = [mock.MagicMock() for _ in range(N_BATCHES)] + subdataset.generate_batches.return_value = iter(batches) + # Have the FairnessDataset return the Dataset for any group. 
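+    # (Every group subset lookup yields the same mock with N_BATCHES batches.)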
+    dataset.get_sensitive_group_subset.return_value = subdataset
+    return dataset
+
+
+class TestFairnessMetricsComputer:
+    """Unit tests for 'declearn.fairness.api.FairnessMetricsComputer'."""
+
+    @pytest.mark.parametrize("n_batch", [None, 4, 12])
+    def test_compute_metrics_over_sensitive_groups(
+        self,
+        dataset: FairnessDataset,
+        n_batch: Optional[int],
+    ) -> None:
+        """Test the 'compute_metrics_over_sensitive_groups' method."""
+        # Set up mock objects and run (mocked) computations.
+        computer = FairnessMetricsComputer(dataset)
+        metrics = mock.create_autospec(MetricSet, instance=True)
+        model = mock.create_autospec(Model, instance=True)
+        mock_pred = (mock.MagicMock(), mock.MagicMock(), None)
+        model.compute_batch_predictions.return_value = mock_pred
+        results = computer.compute_metrics_over_sensitive_group(
+            group=GROUPS[0],
+            metrics=metrics,
+            model=model,
+            batch_size=8,
+            n_batch=n_batch,
+        )
+        # Verify that expected (mocked) computations happened.
+        expected_nbatches = min(n_batch or N_BATCHES, N_BATCHES)
+        assert results is metrics.get_result.return_value
+        metrics.reset.assert_called_once()
+        assert metrics.update.call_count == expected_nbatches
+        assert model.compute_batch_predictions.call_count == expected_nbatches
+        subset = computer.g_data[GROUPS[0]]
+        subset.generate_batches.assert_called_once_with(  # type: ignore
+            batch_size=8, shuffle=n_batch is not None, drop_remainder=False
+        )
+
+    def test_setup_accuracy_metric(
+        self,
+        dataset: FairnessDataset,
+    ) -> None:
+        """Verify that 'setup_accuracy_metric' works properly."""
+        # Set up an accuracy metric with an arbitrary threshold.
+        computer = FairnessMetricsComputer(dataset)
+        model = mock.create_autospec(Model, instance=True)
+        metric = computer.setup_accuracy_metric(model, thresh=0.65)
+        # Verify that the metric performs expected computations.
+        assert isinstance(metric, MeanMetric)
+        metric.update(y_true=np.ones(4), y_pred=np.ones(4) * 0.7)
+        assert metric.get_result()[metric.name] == 1.0
+        metric.reset()
+        metric.update(y_true=np.ones(4), y_pred=np.ones(4) * 0.6)
+        assert metric.get_result()[metric.name] == 0.0
+
+    def test_setup_loss_metric(
+        self,
+        dataset: FairnessDataset,
+    ) -> None:
+        """Verify that 'setup_loss_metric' works properly."""
+        # Set up a loss metric wrapping the model's loss function.
+        computer = FairnessMetricsComputer(dataset)
+        model = mock.create_autospec(Model, instance=True)
+
+        def mock_loss_function(
+            y_true: np.ndarray,
+            y_pred: np.ndarray,
+            s_wght: Optional[np.ndarray] = None,
+        ) -> np.ndarray:
+            """Mock model loss function."""
+            # API-defined signature; pylint: disable=unused-argument
+            return np.ones_like(y_pred) * 0.05
+
+        model.loss_function.side_effect = mock_loss_function
+        metric = computer.setup_loss_metric(model)
+        # Verify that the metric performs expected computations.
+        assert isinstance(metric, MeanMetric)
+        metric.update(y_true=np.ones(4), y_pred=np.ones(4))
+        assert metric.get_result()[metric.name] == 0.05
+        model.loss_function.assert_called_once()
+
+    def test_compute_groupwise_metrics(
+        self,
+        dataset: FairnessDataset,
+    ) -> None:
+        """Test the 'compute_groupwise_metrics' method."""
+        # Set up mock objects and run (mocked) computations.
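+        # (The per-group computation method is patched below, so this test
+        # only covers the orchestration logic.)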
+        computer = FairnessMetricsComputer(dataset)
+        model = mock.create_autospec(Model, instance=True)
+        metrics = [
+            computer.setup_accuracy_metric(model),
+            computer.setup_loss_metric(model),
+        ]
+        with mock.patch.object(
+            computer, "compute_metrics_over_sensitive_group"
+        ) as patch_compute_metrics_over_sensitive_group:
+            results = computer.compute_groupwise_metrics(
+                metrics=metrics,
+                model=model,
+                batch_size=16,
+                n_batch=32,
+            )
+        # Verify that outputs have expected types and dict keys.
+        assert isinstance(results, dict)
+        assert set(results) == {metric.name for metric in metrics}
+        for m_dict in results.values():
+            assert isinstance(m_dict, dict)
+            assert set(m_dict) == set(GROUPS)
+            assert all(isinstance(value, float) for value in m_dict.values())
+        # Verify that expected calls occurred.
+        patch_compute_metrics_over_sensitive_group.assert_has_calls(
+            [mock.call(group, mock.ANY, model, 16, 32) for group in GROUPS],
+            any_order=True,
+        )
+
+    def test_scale_metrics_by_sample_counts(
+        self,
+    ) -> None:
+        """Test that 'scale_metrics_by_sample_counts' works properly."""
+        # Set up a mock FairnessDataset and wrap it up with a metrics computer.
+        dataset = mock.create_autospec(FairnessDataset, instance=True)
+        dataset.get_sensitive_group_definitions.return_value = GROUPS
+        dataset.get_sensitive_group_counts.return_value = {
+            group: idx for idx, group in enumerate(GROUPS, start=1)
+        }
+        computer = FairnessMetricsComputer(dataset)
+        # Test the 'scale_metrics_by_sample_counts' method.
+        metrics = {
+            group: float(idx) for idx, group in enumerate(GROUPS, start=1)
+        }
+        metrics = computer.scale_metrics_by_sample_counts(metrics)
+        expected = {
+            group: float(idx**2) for idx, group in enumerate(GROUPS, start=1)
+        }
+        assert metrics == expected
diff --git a/test/fairness/test_fairness_functions.py b/test/fairness/test_fairness_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..a40746d33b63559886aa7867bec20ef3b95f1cf9
--- /dev/null
+++ b/test/fairness/test_fairness_functions.py
@@ -0,0 +1,317 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Unit tests for built-in group-fairness functions.""" + +import abc +from typing import Any, ClassVar, Dict, List, Tuple, Type, Union + +import numpy as np +import pytest + +from declearn.fairness.api import ( + FairnessFunction, + instantiate_fairness_function, +) +from declearn.fairness.core import ( + AccuracyParityFunction, + DemographicParityFunction, + EqualityOfOpportunityFunction, + EqualizedOddsFunction, + list_fairness_functions, +) +from declearn.test_utils import assert_dict_equal + + +class FairnessFunctionTestSuite(metaclass=abc.ABCMeta): + """Shared unit tests suite for 'FairnessFunction' subclasses.""" + + cls_func: ClassVar[Type[FairnessFunction]] + + @property + @abc.abstractmethod + def expected_constants(self) -> Tuple[np.ndarray, np.ndarray]: + """Expected values for fairness constants with `counts`.""" + + @property + @abc.abstractmethod + def expected_fairness(self) -> Dict[Tuple[Any, ...], float]: + """Expected fairness values with `counts` and `accuracy`.""" + + @property + def counts(self) -> Dict[Tuple[Any, ...], int]: + """Deterministic sensitive group counts for unit tests.""" + return {(0, 0): 40, (0, 1): 20, (1, 0): 30, (1, 1): 10} + + @property + def accuracy(self) -> Dict[Tuple[Any, ...], float]: + """Deterministic group-wise accuracy values for unit tests.""" + return {(0, 0): 0.7, (0, 1): 0.4, (1, 0): 0.8, (1, 1): 0.6} + + def setup_function( + self, + ) -> FairnessFunction: + """Return a FairnessFunction used in shared unit tests.""" + return self.cls_func(counts=self.counts.copy()) + + def test_properties( + self, + ) -> None: + """Test that properties are properly defined.""" + func = self.setup_function() + assert func.groups == list(self.counts) + expected = func.compute_fairness_constants() + assert (func.constants[0] == expected[0]).all() + assert (func.constants[1] == expected[1]).all() + + def test_compute_fairness_constants( + self, + ) -> None: + """Test that 'compute_fairness_constants' returns expected values.""" + func = self.setup_function() + c_k0, c_kk = func.compute_fairness_constants() + # Assert that types and shapes match generic expectations. + n_groups = len(self.counts) + assert isinstance(c_k0, np.ndarray) + assert c_k0.shape in ((1,), (n_groups,)) + assert isinstance(c_kk, np.ndarray) + assert c_kk.shape == (n_groups, n_groups) + # Assert that values match expectations. + expected = self.expected_constants + assert np.allclose(c_k0, expected[0]) + assert np.allclose(c_kk, expected[1]) + + def test_compute_from_group_accuracy( + self, + ) -> None: + """Test that fairness computations work properly.""" + func = self.setup_function() + fairness = func.compute_from_group_accuracy(self.accuracy) + # Assert that types and keys match generic expectations. + assert isinstance(fairness, dict) + assert list(fairness) == func.groups + assert all(isinstance(x, float) for x in fairness.values()) + # Assert that values match expectations. 
+        expected = self.expected_fairness
+        deltas = [abs(fairness[k] - expected[k]) for k in func.groups]
+        assert all(x < 1e-10 for x in deltas), deltas
+
+    def test_compute_from_group_accuracy_error(
+        self,
+    ) -> None:
+        """Test that fairness computations fail on improper inputs."""
+        func = self.setup_function()
+        accuracy = self.accuracy
+        red_accr = {key: accuracy[key] for key in list(accuracy)[:2]}
+        with pytest.raises(KeyError):
+            func.compute_from_group_accuracy(red_accr)
+
+    def test_compute_from_federated_group_accuracy(
+        self,
+    ) -> None:
+        """Test that pseudo-federated fairness computations work properly."""
+        func = self.setup_function()
+        counts = self.counts
+        accuracy = {
+            key: val * counts[key] for key, val in self.accuracy.items()
+        }
+        fairness = func.compute_from_federated_group_accuracy(accuracy)
+        expected = self.expected_fairness
+        deltas = [abs(fairness[k] - expected[k]) for k in func.groups]
+        assert all(x < 1e-10 for x in deltas)
+
+    def test_get_specs(
+        self,
+    ) -> None:
+        """Test that 'get_specs' works properly."""
+        func = self.setup_function()
+        specs = func.get_specs()
+        assert isinstance(specs, dict)
+        assert specs["f_type"] == func.f_type
+        assert specs["counts"] == self.counts
+
+    def test_instantiation_from_specs(
+        self,
+    ) -> None:
+        """Test that instantiation from specs works properly."""
+        func = self.setup_function()
+        fbis = instantiate_fairness_function(**func.get_specs())
+        assert isinstance(fbis, func.__class__)
+        assert_dict_equal(fbis.get_specs(), func.get_specs())
+
+
+class TestAccuracyParityFunction(FairnessFunctionTestSuite):
+    """Unit tests for 'AccuracyParityFunction'."""
+
+    cls_func = AccuracyParityFunction
+
+    @property
+    def expected_constants(self) -> Tuple[np.ndarray, np.ndarray]:
+        c_k0 = np.array(0.0)
+        c_kk = [  # (n_k' / n) - 1{s == s'}*(n_k' / n_s)
+            # fmt: off
+            [0.4 - 4/7, 0.2 - 0.0, 0.3 - 3/7, 0.1 - 0.0],
+            [0.4 - 0.0, 0.2 - 2/3, 0.3 - 0.0, 0.1 - 1/3],
+            [0.4 - 4/7, 0.2 - 0.0, 0.3 - 3/7, 0.1 - 0.0],
+            [0.4 - 0.0, 0.2 - 2/3, 0.3 - 0.0, 0.1 - 1/3],
+        ]
+        return c_k0, np.array(c_kk)
+
+    @property
+    def expected_fairness(self) -> Dict[Tuple[Any, ...], float]:
+        c_kk = self.expected_constants[1]
+        accuracy = self.accuracy
+        acc = [accuracy[k] for k in ((0, 0), (0, 1), (1, 0), (1, 1))]
+        f_s0 = -sum(c * a for c, a in zip(c_kk[0], acc))
+        f_s1 = -sum(c * a for c, a in zip(c_kk[1], acc))
+        return {(0, 0): f_s0, (0, 1): f_s1, (1, 0): f_s0, (1, 1): f_s1}
+
+
+class TestDemographicParityFunction(FairnessFunctionTestSuite):
+    """Unit tests for 'DemographicParityFunction'."""
+
+    cls_func = DemographicParityFunction
+
+    @property
+    def expected_constants(self) -> Tuple[np.ndarray, np.ndarray]:
+        # (n_k / n_s) - (n_y / n)
+        c_k0 = [
+            # fmt: off
+            4/7 - 0.6, 2/3 - 0.6, 3/7 - 0.4, 1/3 - 0.4
+        ]
+        # diagonal: (n_k / n) - (n_k / n_s)
+        # reverse-diagonal: -n_k' / n
+        # c_(y,s)^(y,s'): n_k' / n
+        # c_(y,s)^(y',s): (n_k' / n_s) - (n_k' / n)
+        c_kk = [
+            # fmt: off
+            [0.4 - 4/7, 0.2 - 0.0, 3/7 - 0.3, 0.0 - 0.1],
+            [0.4 - 0.0, 0.2 - 2/3, 0.0 - 0.3, 1/3 - 0.1],
+            [4/7 - 0.4, 0.0 - 0.2, 0.3 - 3/7, 0.1 - 0.0],
+            [0.0 - 0.4, 2/3 - 0.2, 0.3 - 0.0, 0.1 - 1/3],
+        ]
+        return np.array(c_k0), np.array(c_kk)
+
+    @property
+    def expected_fairness(self) -> Dict[Tuple[Any, ...], float]:
+        c_k0, c_kk = self.expected_constants
+        accuracy = self.accuracy
+        acc = [accuracy[k] for k in ((0, 0), (0, 1), (1, 0), (1, 1))]
+        f_00 = (
+            c_k0[0] + c_kk[0].sum() - sum(c * a for c, a in zip(c_kk[0], acc))
+        )
+        f_01 = (
+            c_k0[1] + c_kk[1].sum() - sum(c * a for c, a in zip(c_kk[1], acc))
+        )
+        return {(0, 0): f_00, (0, 1): f_01, (1, 0): -f_00, (1, 1): -f_01}
+
+    def test_error_nonbinary_target(self) -> None:
+        """Test that a ValueError is raised on non-binary target labels."""
+        with pytest.raises(ValueError):
+            DemographicParityFunction(
+                counts={(y, s): 1 for y in (0, 1, 2) for s in (0, 1)}
+            )
+
+
+class TestEqualizedOddsFunction(FairnessFunctionTestSuite):
+    """Unit tests for 'EqualizedOddsFunction'."""
+
+    cls_func = EqualizedOddsFunction
+
+    @property
+    def expected_constants(self) -> Tuple[np.ndarray, np.ndarray]:
+        c_k0 = np.array(0.0)
+        # diagonal: (n_k / n_y) - 1
+        # otherwise: 1{y == y'} * (n_k' / n_y)
+        c_kk = [
+            # fmt: off
+            [4/6 - 1.0, 2/6 - 0.0, 0.0 - 0.0, 0.0 - 0.0],
+            [4/6 - 0.0, 2/6 - 1.0, 0.0 - 0.0, 0.0 - 0.0],
+            [0.0 - 0.0, 0.0 - 0.0, 3/4 - 1.0, 1/4 - 0.0],
+            [0.0 - 0.0, 0.0 - 0.0, 3/4 - 0.0, 1/4 - 1.0],
+        ]
+        return c_k0, np.array(c_kk)
+
+    @property
+    def expected_fairness(self) -> Dict[Tuple[Any, ...], float]:
+        c_kk = self.expected_constants[1]
+        accuracy = self.accuracy
+        groups = ((0, 0), (0, 1), (1, 0), (1, 1))
+        acc = [accuracy[k] for k in groups]
+        return {
+            group: -sum(c * a for c, a in zip(c_kk[i], acc))
+            for i, group in enumerate(groups)
+        }
+
+
+@pytest.mark.parametrize("target", [0, 1, [0, 1]])
+class TestEqualityOfOpportunity(TestEqualizedOddsFunction):
+    """Unit tests for 'EqualityOfOpportunityFunction'."""
+
+    cls_func = EqualityOfOpportunityFunction
+
+    target: Union[int, List[int]]  # set via fixture
+
+    @pytest.fixture(autouse=True)
+    def init_attrs(self, target: Union[int, List[int]]) -> None:
+        """Set up the desired 'target' parametrizing the fairness function."""
+        self.target = target
+
+    def setup_function(self) -> EqualityOfOpportunityFunction:
+        return EqualityOfOpportunityFunction(
+            counts=self.counts.copy(),
+            target=self.target,
+        )
+
+    @property
+    def expected_constants(self) -> Tuple[np.ndarray, np.ndarray]:
+        c_k0, c_kk = super().expected_constants
+        if self.target == 0:
+            c_kk[:, 2:] = 0.0
+        elif self.target == 1:
+            c_kk[:, :2] = 0.0
+        return c_k0, c_kk
+
+    def test_error_wrong_target_value(self) -> None:
+        """Test that a ValueError is raised if 'target' is misspecified."""
+        with pytest.raises(ValueError):
+            EqualityOfOpportunityFunction(counts=self.counts, target=2)
+
+    def test_error_wrong_target_type(self) -> None:
+        """Test that a TypeError is raised if 'target' has an improper type."""
+        with pytest.raises(TypeError):
+            EqualityOfOpportunityFunction(
+                counts=self.counts,
+                target="wrong-type",  # type: ignore
+            )
+
+
+def test_list_fairness_functions() -> None:
+    """Test 'declearn.fairness.core.list_fairness_functions'."""
+    mapping = list_fairness_functions()
+    assert isinstance(mapping, dict)
+    assert all(
+        isinstance(key, str) and issubclass(val, FairnessFunction)
+        for key, val in mapping.items()
+    )
+    for cls in (
+        AccuracyParityFunction,
+        DemographicParityFunction,
+        EqualityOfOpportunityFunction,
+        EqualizedOddsFunction,
+    ):
+        assert mapping.get(cls.f_type) is cls  # type: ignore
diff --git a/test/fairness/test_fairness_inmemory_dataset.py b/test/fairness/test_fairness_inmemory_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..3551e34d677c4c857575a3ac6c88b335a864509d
--- /dev/null
+++ b/test/fairness/test_fairness_inmemory_dataset.py
@@ -0,0 +1,276 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for 'declearn.fairness.core.FairnessInMemoryDataset'."""
+
+import os
+from unittest import mock
+
+import numpy as np
+import pandas as pd
+import pytest
+from scipy.sparse import coo_matrix  # type: ignore
+
+from declearn.dataset import InMemoryDataset
+from declearn.fairness.core import FairnessInMemoryDataset
+
+
+SEED = 0
+
+
+@pytest.fixture(name="dataset")
+def dataset_fixture() -> pd.DataFrame:
+    """Fixture providing a small toy dataset."""
+    rng = np.random.default_rng(seed=SEED)
+    wgt = rng.normal(size=100).astype("float32")
+    data = {
+        "col_a": np.arange(100, dtype="float32"),
+        "col_b": rng.normal(size=100).astype("float32"),
+        "col_y": rng.choice(2, size=100, replace=True),
+        "col_w": wgt / sum(wgt),
+        "col_s": rng.choice(2, size=100, replace=True),
+    }
+    return pd.DataFrame(data)
+
+
+class TestFairnessInMemoryDatasetInit:
+    """Unit tests for 'declearn.fairness.core.FairnessInMemoryDataset' init."""
+
+    def test_init_sattr_dataframe_target_none(
+        self,
+        dataset: pd.DataFrame,
+    ) -> None:
+        """Test instantiating with a sensitive attribute and no target."""
+        s_attr = pd.DataFrame(dataset.pop("col_s"))
+        with pytest.warns(RuntimeWarning):  # due to sensitive_target=True
+            dst = FairnessInMemoryDataset(
+                dataset, s_attr=s_attr, sensitive_target=True
+            )
+        assert isinstance(dst.sensitive, pd.Series)
+        assert (dst.sensitive == s_attr.apply(tuple, axis=1)).all()
+
+    def test_init_sattr_dataframe_target(
+        self,
+        dataset: pd.DataFrame,
+    ) -> None:
+        """Test instantiating with both sensitive attribute and target."""
+        s_attr = pd.DataFrame(dataset.pop("col_s"))
+        dst = FairnessInMemoryDataset(
+            dataset, s_attr=s_attr, target="col_y", sensitive_target=True
+        )
+        expected = pd.DataFrame(
+            {"target": dataset["col_y"], "col_s": s_attr["col_s"]}
+        ).apply(tuple, axis=1)
+        assert isinstance(dst.sensitive, pd.Series)
+        assert (dst.sensitive == expected).all()
+
+    def test_init_sattr_dataframe_no_target(
+        self,
+        dataset: pd.DataFrame,
+    ) -> None:
+        """Test instantiating with a sensitive attribute, ignoring target."""
+        s_attr = pd.DataFrame(dataset.pop("col_s"))
+        dst = FairnessInMemoryDataset(
+            dataset, s_attr=s_attr, target="col_y", sensitive_target=False
+        )
+        assert isinstance(dst.sensitive, pd.Series)
+        assert (dst.sensitive == s_attr.apply(tuple, axis=1)).all()
+
+    def test_init_sattr_one_column(
+        self,
+        dataset: pd.DataFrame,
+    ) -> None:
+        """Test instantiating with sensitive attributes as a column name."""
+        dst = FairnessInMemoryDataset(
+            dataset, s_attr=["col_s"], sensitive_target=False
+        )
+        expected = dataset[["col_s"]].apply(tuple, axis=1)
+        assert (dst.sensitive == expected).all()
+
+    def test_init_sattr_multiple_columns(
+        self,
+        dataset: pd.DataFrame,
+    ) -> None:
+        """Test instantiating with sensitive attributes as column names."""
+        dst = FairnessInMemoryDataset(
+            dataset, s_attr=["col_s", "col_y"], sensitive_target=False
+        )
+        expected = dataset[["col_s", "col_y"]].apply(tuple, axis=1)
+        assert (dst.sensitive == expected).all()
+
+    def test_init_sattr_column_indices(
+        self,
+        dataset: pd.DataFrame,
+    ) -> None:
+        """Test instantiating with sensitive attributes as column indices."""
+        dst = FairnessInMemoryDataset(
+            dataset, s_attr=[2, 4], sensitive_target=False
+        )
+        expected = dataset.iloc[:, [2, 4]].apply(tuple, axis=1)
+        assert (dst.sensitive == expected).all()
+
+    def test_init_sattr_path(
+        self,
+        dataset: pd.DataFrame,
+        tmp_path: str,
+    ) -> None:
+        """Test instantiating with a sensitive attribute as a file dump."""
+        path = os.path.join(tmp_path, "s_attr.csv")
+        s_attr = dataset.pop("col_s")
+        s_attr.to_csv(path, index=False)
+        dst = FairnessInMemoryDataset(
+            dataset, s_attr=path, target="col_y", sensitive_target=True
+        )
+        expected = pd.DataFrame(
+            {"target": dataset["col_y"], "col_s": s_attr}
+        ).apply(tuple, axis=1)
+        assert (dst.sensitive == expected).all()
+
+    def test_init_sattr_array(
+        self,
+        dataset: pd.DataFrame,
+    ) -> None:
+        """Test instantiating with a sensitive attribute array."""
+        s_attr = dataset.pop("col_s").values
+        dst = FairnessInMemoryDataset(
+            dataset, s_attr=s_attr, sensitive_target=False
+        )
+        assert isinstance(dst.sensitive, pd.Series)
+        expected = pd.DataFrame(s_attr).apply(tuple, axis=1)
+        assert (dst.sensitive == expected).all()
+
+    def test_init_sattr_spmatrix(
+        self,
+        dataset: pd.DataFrame,
+    ) -> None:
+        """Test instantiating with a sensitive attribute spmatrix."""
+        s_attr = coo_matrix(dataset.pop("col_s").values).T
+        dst = FairnessInMemoryDataset(
+            dataset, s_attr=s_attr, sensitive_target=False
+        )
+        assert isinstance(dst.sensitive, pd.Series)
+        expected = pd.DataFrame(s_attr.toarray()).apply(tuple, axis=1)
+        assert (dst.sensitive == expected).all()
+
+    def test_init_error_wrong_sattr_cols(
+        self,
+        dataset: pd.DataFrame,
+    ) -> None:
+        """Test that an error is raised if 's_attr' is an invalid list."""
+        with pytest.raises(TypeError):
+            FairnessInMemoryDataset(
+                dataset, s_attr=["col_wrong"], sensitive_target=False
+            )
+
+    def test_init_error_wrong_sattr_type(
+        self,
+        dataset: pd.DataFrame,
+    ) -> None:
+        """Test that an error is raised if 's_attr' has an unsupported type."""
+        with pytest.raises(TypeError):
+            FairnessInMemoryDataset(
+                dataset, s_attr=mock.MagicMock(), sensitive_target=False
+            )
+
+    def test_init_error_wrong_sattr_shape(
+        self,
+        dataset: pd.DataFrame,
+    ) -> None:
+        """Test that an error is raised if 's_attr' has an unsupported shape."""
+        s_attr = dataset.pop("col_s").values[:20]
+        with pytest.raises(ValueError):
+            FairnessInMemoryDataset(
+                dataset, s_attr=s_attr, sensitive_target=False
+            )
+
+
+@pytest.fixture(name="fdst")
+def fdst_fixture(
+    dataset: pd.DataFrame,
+) -> FairnessInMemoryDataset:
+    """Fixture providing a wrapped small toy dataset."""
+    return FairnessInMemoryDataset(
+        data=pd.DataFrame(dataset),
+        target="col_y",
+        s_wght="col_w",
+        s_attr=["col_s"],
+        sensitive_target=True,
+    )
+
+
+class TestFairnessInMemoryDataset:
+    """Unit tests for 'declearn.fairness.core.FairnessInMemoryDataset'."""
+
+    def test_get_sensitive_group_definitions(
+        self,
+        fdst: FairnessInMemoryDataset,
+    ) -> None:
+        """Test that sensitive group definitions match expectations."""
+        groups = fdst.get_sensitive_group_definitions()
+        assert groups == [(0, 0), (0, 1), (1, 0), (1, 1)]
+
+    def test_get_sensitive_group_counts(
+        self,
+        fdst: FairnessInMemoryDataset,
+    ) -> None:
+        """Test that sensitive group counts match wrapped data."""
+        counts = fdst.get_sensitive_group_counts()
+        assert set(counts) == set(fdst.get_sensitive_group_definitions())
+        assert isinstance(fdst.data, pd.DataFrame)  # by construction here
+        assert counts == fdst.data[["col_y", "col_s"]].value_counts().to_dict()
+
+    def test_get_sensitive_group_subset(
+        self,
+        fdst: FairnessInMemoryDataset,
+    ) -> None:
+        """Test that sensitive group subset access works properly."""
+        subset = fdst.get_sensitive_group_subset(group=(0, 0))
+        assert isinstance(subset, InMemoryDataset)
+        assert (subset.target == 0).all()
+        assert (subset.feats["col_s"] == 0).all()
+        assert len(subset.data) == fdst.get_sensitive_group_counts()[(0, 0)]
+
+    def test_set_sensitive_group_weights(
+        self,
+        fdst: FairnessInMemoryDataset,
+    ) -> None:
+        """Test that sensitive group weighting works properly."""
+        # Assert that initial sample weights are based on specified column.
+        assert isinstance(fdst.data, pd.DataFrame)
+        expected = fdst.data["col_w"]
+        assert (fdst.weights == expected).all()
+        # Set sensitive group weights.
+        weights = {(0, 0): 0.1, (0, 1): 0.2, (1, 0): 0.3, (1, 1): 0.4}
+        fdst.set_sensitive_group_weights(weights, adjust_by_counts=False)
+        # Assert that resulting sample weights match expectations.
+        sgroup = fdst.data[["col_y", "col_s"]].apply(tuple, axis=1)
+        expected = fdst.data["col_w"] * sgroup.apply(weights.get)
+        assert (fdst.weights == expected).all()
+        # Do it again with counts-adjustment.
+        counts = fdst.get_sensitive_group_counts()
+        weights = {key: val / counts[key] for key, val in weights.items()}
+        fdst.set_sensitive_group_weights(weights, adjust_by_counts=True)
+        assert np.allclose(fdst.weights, expected)
+
+    def test_set_sensitive_group_weights_keyerror(
+        self,
+        fdst: FairnessInMemoryDataset,
+    ) -> None:
+        """Test setting sensitive group weights with missing groups."""
+        weights = {(0, 0): 0.1, (0, 1): 0.2}
+        with pytest.raises(KeyError):
+            fdst.set_sensitive_group_weights(weights)
diff --git a/test/functional/test_toy_clf_fairness.py b/test/functional/test_toy_clf_fairness.py
new file mode 100644
index 0000000000000000000000000000000000000000..41df97a8f2275f090099720de905c7e8b74cbe9d
--- /dev/null
+++ b/test/functional/test_toy_clf_fairness.py
@@ -0,0 +1,254 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Integration test using fairness algorithms (and opt. SecAgg) on toy data.
+
+* Set up a toy classification dataset with a sensitive attribute, and
+  some client heterogeneity.
+* Run a federated learning experiment...
+""" + +import asyncio +import os +import warnings +from typing import List, Optional, Tuple + +import numpy as np +import pandas as pd +import pytest + +from declearn.dataset.utils import split_multi_classif_dataset +from declearn.fairness.api import FairnessControllerServer +from declearn.fairness.core import FairnessInMemoryDataset +from declearn.fairness.fairbatch import FairbatchControllerServer +from declearn.fairness.fairfed import FairfedControllerServer +from declearn.fairness.fairgrad import FairgradControllerServer +from declearn.fairness.monitor import FairnessMonitorServer +from declearn.main import FederatedClient, FederatedServer +from declearn.main.config import FLRunConfig +from declearn.model.sklearn import SklearnSGDModel +from declearn.secagg.utils import IdentityKeys +from declearn.test_utils import ( + MockNetworkClient, + MockNetworkServer, + make_importable, +) + +with make_importable(os.path.dirname(__file__)): + from test_toy_clf_secagg import setup_masking_idkeys + + +SEED = 0 + + +def generate_toy_dataset( + n_train: int = 300, + n_valid: int = 150, + n_clients: int = 3, +) -> List[Tuple[FairnessInMemoryDataset, FairnessInMemoryDataset]]: + """Generate datasets to a toy fairness-aware classification problem.""" + # Generate a toy classification dataset with a sensitive attribute. + n_samples = n_train + n_valid + inputs, s_attr, target = _generate_toy_data(n_samples) + # Split samples uniformly across clients, with 80%/20% train/valid splits. + shards = split_multi_classif_dataset( + dataset=(np.concatenate([inputs, s_attr], axis=1), target.ravel()), + n_shards=n_clients, + scheme="iid", + p_valid=0.2, + seed=SEED, + ) + # Wrap the resulting data as fairness in-memory datasets and return them. + return [ + ( + FairnessInMemoryDataset( + # fmt: off + data=x_train[:, :-1], s_attr=x_train[:, -1:], target=y_train, + expose_classes=True, + ), + FairnessInMemoryDataset( + data=x_valid[:, :-1], s_attr=x_valid[:, -1:], target=y_valid + ), + ) + for (x_train, y_train), (x_valid, y_valid) in shards + ] + + +def _generate_toy_data( + n_samples: int = 100, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Build a toy classification dataset with a binary sensitive attribute. + + - Draw random normal features X, random coefficients B and random noise N. + - Compute L = XB + N, min-max normalize it into [0, 1] probabilities P. + - Draw random binary sensitive attribute values S. + - Define Y = 1{P >= 0.8}*1{S == 1} + 1{P >= 0.5}*1{S == 0}. + + Return X, S and Y matrices, as numpy arrays. 
+ """ + rng = np.random.default_rng(SEED) + x_dat = rng.normal(size=(n_samples, 10), scale=10.0) + s_dat = rng.choice(2, size=(n_samples, 1)) + theta = rng.normal(size=(10, 1), scale=5.0) + noise = rng.normal(size=(n_samples, 1), scale=5.0) + logit = np.matmul(x_dat, theta) + noise + y_dat = (logit - logit.min()) / (logit.max() - logit.min()) + y_dat = (y_dat >= np.where(s_dat == 1, 0.8, 0.5)).astype("float32") + return x_dat.astype("float32"), s_dat.astype("float32"), y_dat + + +async def server_routine( + fairness: FairnessControllerServer, + secagg: bool, + folder: str, + n_clients: int = 3, +) -> None: + """Run the FL routine of the server.""" + # similar to SecAgg functional test; pylint: disable=duplicate-code + model = SklearnSGDModel.from_parameters( + kind="classifier", + loss="log_loss", + penalty="none", + dtype="float32", + ) + netwk = MockNetworkServer( + host="localhost", + port=8765, + heartbeat=0.1, + ) + optim = { + "client_opt": 0.05, + "server_opt": 1.0, + "fairness": fairness, + } + server = FederatedServer( + model, + netwk=netwk, + optim=optim, + metrics=["binary-classif"], + secagg={"secagg_type": "masking"} if secagg else None, + checkpoint={"folder": folder, "max_history": 1}, + ) + config = FLRunConfig.from_params( + rounds=5, + register={"min_clients": n_clients, "timeout": 2}, + training={"n_epoch": 1, "batch_size": 10}, + fairness={"batch_size": 50}, + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RuntimeWarning) + await server.async_run(config) + + +async def client_routine( + train_dst: FairnessInMemoryDataset, + valid_dst: FairnessInMemoryDataset, + id_keys: Optional[IdentityKeys], +) -> None: + """Run the FL routine of a given client.""" + netwk = MockNetworkClient( + server_uri="mock://localhost:8765", + name="client", + ) + secagg = ( + {"secagg_type": "masking", "id_keys": id_keys} if id_keys else None + ) + client = FederatedClient( + netwk=netwk, + train_data=train_dst, + valid_data=valid_dst, + verbose=False, + secagg=secagg, + ) + await client.async_run() + + +@pytest.fixture(name="fairness") +def fairness_fixture( + algorithm: str, + f_type: str, +) -> FairnessControllerServer: + """Server-side fairness controller providing fixture.""" + if algorithm == "fairbatch": + return FairbatchControllerServer(f_type, alpha=0.005, fedfb=False) + if algorithm == "fairfed": + return FairfedControllerServer(f_type, beta=1.0, strict=True) + if algorithm == "fairgrad": + return FairgradControllerServer(f_type, eta=0.5, eps=1e-6) + return FairnessMonitorServer(f_type) + + +@pytest.mark.parametrize("secagg", [False, True], ids=["clrtxt", "secagg"]) +@pytest.mark.parametrize("f_type", ["demographic_parity", "equalized_odds"]) +@pytest.mark.parametrize( + "algorithm", ["fairbatch", "fairfed", "fairgrad", "monitor"] +) +@pytest.mark.asyncio +async def test_toy_classif_fairness( + fairness: FairnessControllerServer, + secagg: bool, + tmp_path: str, +) -> None: + """Test a given fairness-aware federated learning algorithm on toy data. + + Set up a toy dataset for fairness-aware federated learning. + Use a given algorithm, with a given group-fairness definition. + Run training for 5 rounds. Optionally use SecAgg. + + When using mere monitoring, verify that hardcoded accuracy + and (un)fairness levels, taken as a baseline, are achieved. + When using another algorithm, verify that is achieves some + degraded accuracy, and better fairness than the baseline. + """ + # Set up the toy dataset and optional identity keys for SecAgg. 
+    datasets = generate_toy_dataset(n_clients=3)
+    clients_id_keys = setup_masking_idkeys(secagg, n_clients=3)
+    # Set up and run the fairness-aware federated learning experiment.
+    coro_server = server_routine(fairness, secagg, folder=tmp_path)
+    coro_clients = [
+        client_routine(train_dst, valid_dst, id_keys)
+        for (train_dst, valid_dst), id_keys in zip(datasets, clients_id_keys)
+    ]
+    outputs = await asyncio.gather(
+        coro_server, *coro_clients, return_exceptions=True
+    )
+    # Assert that no exceptions occurred during the process.
+    errors = "\n".join(repr(e) for e in outputs if isinstance(e, Exception))
+    assert not errors, f"The FL process failed:\n{errors}"
+    # Load and parse utility and fairness metrics at the final round.
+    u_metrics = pd.read_csv(os.path.join(tmp_path, "metrics.csv"))
+    f_metrics = pd.read_csv(os.path.join(tmp_path, "fairness_metrics.csv"))
+    accuracy = u_metrics.iloc[-1]["accuracy"]
+    fairness_cols = [f"{fairness.f_type}_{group}" for group in fairness.groups]
+    fairness_mean_abs = f_metrics.iloc[-1][fairness_cols].abs().mean()
+    # Verify that the FedAvg baseline matches expected accuracy and fairness,
+    # or that other algorithms achieve lower accuracy and better fairness.
+    # Note that FairFed is bound to match the FedAvg baseline due to the
+    # split across clients being uniform.
+    expected_fairness = {
+        "demographic_parity": 0.025,
+        "equalized_odds": 0.142,
+    }
+    if fairness.algorithm == "monitor":
+        assert accuracy >= 0.76
+        assert fairness_mean_abs > expected_fairness[fairness.f_type]
+    elif fairness.algorithm == "fairfed":
+        assert accuracy >= 0.72
+        assert fairness_mean_abs > expected_fairness[fairness.f_type]
+    else:
+        assert 0.76 > accuracy > 0.54
+        assert fairness_mean_abs < expected_fairness[fairness.f_type]
diff --git a/test/functional/test_toy_clf_secagg.py b/test/functional/test_toy_clf_secagg.py
index cc5723e41b5d8c8d71c28da465a581637b587f4b..3599d22f0c677c577c1bb7ad876141d7dee79cc1 100644
--- a/test/functional/test_toy_clf_secagg.py
+++ b/test/functional/test_toy_clf_secagg.py
@@ -170,6 +170,18 @@ async def async_run_client(
     await client.async_run()
 
 
+def setup_masking_idkeys(
+    secagg: bool,
+    n_clients: int,
+) -> Union[List[IdentityKeys], List[None]]:
+    """Set up identity keys for SecAgg, or a list of None values."""
+    if not secagg:
+        return [None for _ in range(n_clients)]
+    prv_keys = [Ed25519PrivateKey.generate() for _ in range(n_clients)]
+    pub_keys = [key.public_key() for key in prv_keys]
+    return [IdentityKeys(key, trusted=pub_keys) for key in prv_keys]
+
+
 async def run_declearn_experiment(
     scaffold: bool,
     secagg: bool,
@@ -197,14 +209,7 @@ async def run_declearn_experiment(
     """
     # Set up the toy dataset(s) and optional identity keys (for SecAgg).
     n_clients = len(datasets)
-    if secagg:
-        prv_keys = [Ed25519PrivateKey.generate() for _ in range(n_clients)]
-        pub_keys = [key.public_key() for key in prv_keys]
-        id_keys = [
-            IdentityKeys(key, trusted=pub_keys) for key in prv_keys
-        ]  # type: Union[List[IdentityKeys], List[None]]
-    else:
-        id_keys = [None for _ in range(n_clients)]
+    id_keys = setup_masking_idkeys(secagg=secagg, n_clients=n_clients)
     with tempfile.TemporaryDirectory() as folder:
         # Set up the server and client coroutines.
         coro_server = async_run_server(folder, scaffold, secagg, n_clients)
@@ -213,13 +218,11 @@
             for i, (train, valid) in enumerate(datasets)
         ]
         # Run the coroutines concurrently using asyncio.
-        outputs = await asyncio.gather(
+        output = await asyncio.gather(
             coro_server, *coro_clients, return_exceptions=True
         )
         # Assert that no exceptions occurred during the process.
-        errors = "\n".join(
-            repr(exc) for exc in outputs if isinstance(exc, Exception)
-        )
+        errors = "\n".join(repr(e) for e in output if isinstance(e, Exception))
         assert not errors, f"The FL process failed:\n{errors}"
         # Assert that the experiment ran properly.
         with open(
diff --git a/test/functional/test_toy_reg.py b/test/functional/test_toy_reg.py
index c37f941fc2042871f43ad27dfb0735a6bbcde45b..515e182e5899cea3fdbab1da8c9ae9cb1bdaae7a 100644
--- a/test/functional/test_toy_reg.py
+++ b/test/functional/test_toy_reg.py
@@ -74,6 +74,7 @@ try:
 except ModuleNotFoundError:
     pass
 else:
+    import tensorflow.keras as tf_keras  # type: ignore
    from declearn.dataset.tensorflow import TensorflowDataset
     from declearn.model.tensorflow import TensorflowModel, TensorflowVector
 # torch imports
@@ -136,9 +137,7 @@ def _get_model_numpy() -> Model:
 def _get_model_tflow() -> Model:
     """Return a linear model with MSE loss in TensorFlow, with zero weights."""
     tf.random.set_seed(SEED)  # set seed
-    tfmod = tf.keras.Sequential(  # pylint: disable=no-member
-        tf.keras.layers.Dense(units=1)  # pylint: disable=no-member
-    )
+    tfmod = tf_keras.Sequential([tf_keras.layers.Dense(units=1)])
     tfmod.build([None, 100])
     model = TensorflowModel(tfmod, loss="mean_squared_error")
     with tf.device("CPU"):
diff --git a/test/main/test_config_optim.py b/test/main/test_config_optim.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fb4128af361f0bfd2dc24a209e0e3c170561fd2
--- /dev/null
+++ b/test/main/test_config_optim.py
@@ -0,0 +1,256 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for 'declearn.main.config.FLOptimConfig'."""
+
+import dataclasses
+import os
+from unittest import mock
+
+import pytest
+
+from declearn.aggregator import Aggregator, AveragingAggregator, SumAggregator
+from declearn.fairness.api import FairnessControllerServer
+from declearn.fairness.fairgrad import FairgradControllerServer
+from declearn.main.config import FLOptimConfig
+from declearn.optimizer import Optimizer
+from declearn.optimizer.modules import AdamModule
+
+
+FIELDS = {field.name: field for field in dataclasses.fields(FLOptimConfig)}
+
+
+class TestFLOptimConfig:
+    """Unit tests for 'declearn.main.config.FLOptimConfig'."""
+
+    # unit tests; pylint: disable=too-many-public-methods
+
+    # Client-side optimizer.
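+    # As the tests below illustrate, 'client_opt' may be specified as a
+    # float (a plain learning rate), a config dict (which must at least
+    # specify 'lrate'), or a ready-made Optimizer instance (used as-is).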
+
+    def test_parse_client_opt_float(self) -> None:
+        """Test parsing 'client_opt' from a float input."""
+        field = FIELDS["client_opt"]
+        optim = FLOptimConfig.parse_client_opt(field, 0.1)
+        assert isinstance(optim, Optimizer)
+        assert optim.lrate == 0.1
+        assert optim.w_decay == 0.0
+        assert not optim.modules
+        assert not optim.regularizers
+
+    def test_parse_client_opt_dict(self) -> None:
+        """Test parsing 'client_opt' from a dict input."""
+        field = FIELDS["client_opt"]
+        config = {"lrate": 0.1, "modules": ["adam"]}
+        optim = FLOptimConfig.parse_client_opt(field, config)
+        assert isinstance(optim, Optimizer)
+        assert optim.lrate == 0.1
+        assert optim.w_decay == 0.0
+        assert len(optim.modules) == 1
+        assert isinstance(optim.modules[0], AdamModule)
+        assert not optim.regularizers
+
+    def test_parse_client_opt_dict_error(self) -> None:
+        """Test parsing 'client_opt' from an invalid dict input."""
+        field = FIELDS["client_opt"]
+        config = {"modules": ["adam"]}  # missing 'lrate'
+        with pytest.raises(TypeError):
+            FLOptimConfig.parse_client_opt(field, config)
+
+    def test_parse_client_opt_optimizer(self) -> None:
+        """Test parsing 'client_opt' from an Optimizer input."""
+        field = FIELDS["client_opt"]
+        optim = mock.create_autospec(Optimizer, instance=True)
+        assert FLOptimConfig.parse_client_opt(field, optim) is optim
+
+    def test_parse_client_opt_error(self) -> None:
+        """Test parsing 'client_opt' from an invalid-type input."""
+        field = FIELDS["client_opt"]
+        with pytest.raises(TypeError):
+            FLOptimConfig.parse_client_opt(field, mock.MagicMock())
+
+    # Server-side optimizer.
+    # pylint: disable=duplicate-code
+
+    def test_parse_server_opt_none(self) -> None:
+        """Test parsing 'server_opt' from None."""
+        field = FIELDS["server_opt"]
+        optim = FLOptimConfig.parse_server_opt(field, None)
+        assert isinstance(optim, Optimizer)
+        assert optim.lrate == 1.0
+        assert optim.w_decay == 0.0
+        assert not optim.modules
+        assert not optim.regularizers
+
+    def test_parse_server_opt_float(self) -> None:
+        """Test parsing 'server_opt' from a float input."""
+        field = FIELDS["server_opt"]
+        optim = FLOptimConfig.parse_server_opt(field, 0.1)
+        assert isinstance(optim, Optimizer)
+        assert optim.lrate == 0.1
+        assert optim.w_decay == 0.0
+        assert not optim.modules
+        assert not optim.regularizers
+
+    def test_parse_server_opt_dict(self) -> None:
+        """Test parsing 'server_opt' from a dict input."""
+        field = FIELDS["server_opt"]
+        config = {"lrate": 0.1, "modules": ["adam"]}
+        optim = FLOptimConfig.parse_server_opt(field, config)
+        assert isinstance(optim, Optimizer)
+        assert optim.lrate == 0.1
+        assert optim.w_decay == 0.0
+        assert len(optim.modules) == 1
+        assert isinstance(optim.modules[0], AdamModule)
+        assert not optim.regularizers
+
+    def test_parse_server_opt_dict_error(self) -> None:
+        """Test parsing 'server_opt' from an invalid dict input."""
+        field = FIELDS["server_opt"]
+        config = {"modules": ["adam"]}  # missing 'lrate'
+        with pytest.raises(TypeError):
+            FLOptimConfig.parse_server_opt(field, config)
+
+    def test_parse_server_opt_optimizer(self) -> None:
+        """Test parsing 'server_opt' from an Optimizer input."""
+        field = FIELDS["server_opt"]
+        optim = mock.create_autospec(Optimizer, instance=True)
+        assert FLOptimConfig.parse_server_opt(field, optim) is optim
+
+    def test_parse_server_opt_error(self) -> None:
+        """Test parsing 'server_opt' from an invalid-type input."""
+        field = FIELDS["server_opt"]
+        with pytest.raises(TypeError):
+            FLOptimConfig.parse_server_opt(field, mock.MagicMock())
+
+    # pylint: enable=duplicate-code
+    # Aggregator.
+
+    def test_parse_aggregator_none(self) -> None:
+        """Test parsing 'aggregator' from None."""
+        field = FIELDS["aggregator"]
+        aggregator = FLOptimConfig.parse_aggregator(field, None)
+        assert isinstance(aggregator, AveragingAggregator)
+
+    def test_parse_aggregator_str(self) -> None:
+        """Test parsing 'aggregator' from a string."""
+        field = FIELDS["aggregator"]
+        aggregator = FLOptimConfig.parse_aggregator(field, "sum")
+        assert isinstance(aggregator, SumAggregator)
+
+    def test_parse_aggregator_dict(self) -> None:
+        """Test parsing 'aggregator' from a dict."""
+        field = FIELDS["aggregator"]
+        config = {"name": "averaging", "config": {"steps_weighted": False}}
+        aggregator = FLOptimConfig.parse_aggregator(field, config)
+        assert isinstance(aggregator, AveragingAggregator)
+        assert not aggregator.steps_weighted
+
+    def test_parse_aggregator_dict_error(self) -> None:
+        """Test parsing 'aggregator' from an invalid dict."""
+        field = FIELDS["aggregator"]
+        config = {"name": "adam", "group": "OptiModule"}  # wrong target type
+        with pytest.raises(TypeError):
+            FLOptimConfig.parse_aggregator(field, config)
+
+    def test_parse_aggregator_aggregator(self) -> None:
+        """Test parsing 'aggregator' from an Aggregator."""
+        field = FIELDS["aggregator"]
+        aggregator = mock.create_autospec(Aggregator, instance=True)
+        assert FLOptimConfig.parse_aggregator(field, aggregator) is aggregator
+
+    def test_parse_aggregator_error(self) -> None:
+        """Test parsing 'aggregator' from an invalid-type input."""
+        field = FIELDS["aggregator"]
+        with pytest.raises(TypeError):
+            FLOptimConfig.parse_aggregator(field, mock.MagicMock())
+
+    # Fairness.
+
+    def test_parse_fairness_none(self) -> None:
+        """Test parsing 'fairness' from None."""
+        field = FIELDS["fairness"]
+        fairness = FLOptimConfig.parse_fairness(field, None)
+        assert fairness is None
+
+    def test_parse_fairness_dict(self) -> None:
+        """Test parsing 'fairness' from a dict."""
+        field = FIELDS["fairness"]
+        config = {
+            "algorithm": "fairgrad",
+            "f_type": "demographic_parity",
+            "eta": 0.1,
+            "eps": 0.0,
+        }
+        fairness = FLOptimConfig.parse_fairness(field, config)
+        assert isinstance(fairness, FairgradControllerServer)
+        assert fairness.f_type == "demographic_parity"
+        assert fairness.weights_controller.eta == 0.1
+        assert fairness.weights_controller.eps == 0.0
+
+    def test_parse_fairness_dict_error(self) -> None:
+        """Test parsing 'fairness' from an invalid dict."""
+        field = FIELDS["fairness"]
+        config = {"algorithm": "fairgrad"}  # missing f_type choice
+        with pytest.raises(TypeError):
+            FLOptimConfig.parse_fairness(field, config)
+
+    def test_parse_fairness_controller(self) -> None:
+        """Test parsing 'fairness' from a FairnessControllerServer."""
+        field = FIELDS["fairness"]
+        fairness = mock.create_autospec(
+            FairnessControllerServer, instance=True
+        )
+        assert FLOptimConfig.parse_fairness(field, fairness) is fairness
+
+    # Functional test.
+
+    def test_from_toml(self, tmp_path: str) -> None:
+        """Test parsing an arbitrary, complex TOML file."""
+        # Set up an arbitrary TOML file parseable into an FLOptimConfig.
+        path = os.path.join(tmp_path, "config.toml")
+        toml_config = """
+        [optim]
+        aggregator = "sum"
+        client_opt = 0.001
+        [optim.server_opt]
+        lrate = 1.0
+        modules = [["adam", {beta_1=0.8, beta_2=0.9}]]
+        [optim.fairness]
+        algorithm = "fairgrad"
+        f_type = "equalized_odds"
+        eta = 0.1
+        eps = 0.0
+        """
+        with open(path, "w", encoding="utf-8") as file:
+            file.write(toml_config)
+        # Parse the TOML file and verify that outputs match expectations.
+        optim = FLOptimConfig.from_toml(path, use_section="optim")
+        assert isinstance(optim, FLOptimConfig)
+        assert isinstance(optim.aggregator, SumAggregator)
+        assert isinstance(optim.client_opt, Optimizer)
+        assert optim.client_opt.lrate == 0.001
+        assert not optim.client_opt.modules
+        assert isinstance(optim.server_opt, Optimizer)
+        assert optim.server_opt.lrate == 1.0
+        assert len(optim.server_opt.modules) == 1
+        assert isinstance(optim.server_opt.modules[0], AdamModule)
+        assert optim.server_opt.modules[0].ewma_1.beta == 0.8
+        assert optim.server_opt.modules[0].ewma_2.beta == 0.9
+        assert isinstance(optim.fairness, FairgradControllerServer)
+        assert optim.fairness.f_type == "equalized_odds"
+        assert optim.fairness.weights_controller.eta == 0.1
+        assert optim.fairness.weights_controller.eps == 0.0
diff --git a/test/main/test_main_client.py b/test/main/test_main_client.py
index e6ed7c38f3b780646b7b1f5ec25e50433d679dbd..e002b9dc9455dff31c52f3b91adfd912785da39e 100644
--- a/test/main/test_main_client.py
+++ b/test/main/test_main_client.py
@@ -19,7 +19,7 @@
 import contextlib
 import logging
-from typing import Any, Iterator, Optional, Type
+from typing import Any, Iterator, Optional, Tuple, Type
 from unittest import mock
 
 import pytest  # type: ignore
@@ -28,22 +28,27 @@ from declearn import messaging
 from declearn.dataset import Dataset, DataSpecs
 from declearn.communication import NetworkClientConfig
 from declearn.communication.api import NetworkClient
+from declearn.fairness.api import FairnessControllerClient
 from declearn.main import FederatedClient
-from declearn.main.utils import Checkpointer, TrainingManager
+from declearn.main.utils import Checkpointer
 from declearn.metrics import MetricState
 from declearn.model.api import Model
 from declearn.secagg import messaging as secagg_messaging
 from declearn.secagg.api import SecaggConfigClient, SecaggSetupQuery
+from declearn.training import TrainingManager
 from declearn.utils import LOGGING_LEVEL_MAJOR
 
 try:
-    from declearn.main.privacy import DPTrainingManager
+    from declearn.training.dp import DPTrainingManager
 except ModuleNotFoundError:
     DP_AVAILABLE = False
 else:
     DP_AVAILABLE = True
 
 
+# numerous but organized tests; pylint: disable=too-many-lines
+
+
 MOCK_NETWK = mock.create_autospec(NetworkClient, instance=True)
 MOCK_NETWK.name = "client"
 MOCK_DATASET = mock.create_autospec(Dataset, instance=True)
@@ -355,6 +360,7 @@ class TestFederatedClientInitialize:
     def _setup_mock_init_request(
         secagg: Optional[str] = None,
         dpsgd: bool = False,
+        fairness: bool = False,
     ) -> messaging.SerializedMessage[messaging.InitRequest]:
         """Return a mock serialized InitRequest."""
         init_req = messaging.InitRequest(
@@ -363,6 +369,7 @@ class TestFederatedClientInitialize:
             aggrg=mock.MagicMock(),
             secagg=secagg,
             dpsgd=dpsgd,
+            fairness=fairness,
         )
         msg_init = mock.create_autospec(
             messaging.SerializedMessage, instance=True
@@ -517,6 +524,23 @@ class TestFederatedClientInitialize:
         assert isinstance(reply, messaging.Error)
         patched.assert_not_called()
 
+    def _setup_dpsgd_setup_query(
+        self,
+    ) -> Tuple[
+        messaging.SerializedMessage[messaging.PrivacyRequest],
+        messaging.PrivacyRequest,
+    ]:
+        """Set up a mock PrivacyRequest and a wrapping SerializedMessage."""
+        dp_query = mock.create_autospec(
+            messaging.PrivacyRequest, instance=True
+        )
+        msg_priv = mock.create_autospec(
+            messaging.SerializedMessage, instance=True
+        )
+        msg_priv.message_cls = messaging.PrivacyRequest
+        msg_priv.deserialize.return_value = dp_query
+        return msg_priv, dp_query
+
     @pytest.mark.asyncio
     async def test_initialize_with_dpsgd(self) -> None:
         """Test that initialization with DP-SGD works properly."""
@@ -526,13 +550,7 @@ class TestFederatedClientInitialize:
         netwk = mock.create_autospec(NetworkClient, instance=True)
         netwk.name = "client"
         msg_init = self._setup_mock_init_request(secagg=None, dpsgd=True)
-        msg_priv = mock.create_autospec(
-            messaging.SerializedMessage, instance=True
-        )
-        msg_priv.message_cls = messaging.PrivacyRequest
-        msg_priv.deserialize.return_value = dpconfig = mock.create_autospec(
-            messaging.PrivacyRequest, instance=True
-        )
+        msg_priv, dp_query = self._setup_dpsgd_setup_query()
         netwk.recv_message.side_effect = [msg_init, msg_priv]
         # Set up a client wrapping the former network endpoint.
         client = FederatedClient(netwk=netwk, train_data=MOCK_DATASET)
@@ -541,9 +559,12 @@ class TestFederatedClientInitialize:
         with patch_class_constructor(DPTrainingManager) as patch_dp:
             with patch_class_constructor(TrainingManager) as patch_tm:
                 await client.initialize()
-        # Assert that a single InitReply was then sent to the server.
-        reply = netwk.send_message.call_args_list[1].args[0]
+        # Assert that an InitReply and a PrivacyReply were sent to the server.
+        assert netwk.send_message.call_count == 2
+        reply = netwk.send_message.call_args_list[0].args[0]
         assert isinstance(reply, messaging.InitReply)
+        reply = netwk.send_message.call_args_list[1].args[0]
+        assert isinstance(reply, messaging.PrivacyReply)
         # Assert that a DPTrainingManager was set up.
         patch_tm.assert_called_once()
         patch_dp.assert_called_once_with(
@@ -557,7 +578,7 @@ class TestFederatedClientInitialize:
             logger=patch_tm.return_value.logger,
             verbose=patch_tm.return_value.verbose,
         )
-        patch_dp.return_value.make_private.assert_called_once_with(dpconfig)
+        patch_dp.return_value.make_private.assert_called_once_with(dp_query)
         assert client.trainmanager is patch_dp.return_value
 
     @pytest.mark.asyncio
@@ -578,10 +599,13 @@ class TestFederatedClientInitialize:
         with patch_class_constructor(TrainingManager) as patch_tm:
             with pytest.raises(RuntimeError):
                 await client.initialize()
-        # Assert that two messages were fetched, and an error was sent.
+        # Assert that two messages were fetched, that the first step went
+        # well (resulting in an InitReply) and then an Error was sent.
         assert netwk.recv_message.call_count == 2
-        netwk.send_message.assert_called_once()
-        reply = netwk.send_message.call_args.args[0]
+        assert netwk.send_message.call_count == 2
+        reply = netwk.send_message.call_args_list[0].args[0]
+        assert isinstance(reply, messaging.InitReply)
+        reply = netwk.send_message.call_args_list[1].args[0]
         assert isinstance(reply, messaging.Error)
         # Assert that the initial TrainingManager was set, but not the DP one.
         patch_tm.assert_called_once()
@@ -596,13 +620,7 @@ class TestFederatedClientInitialize:
         netwk = mock.create_autospec(NetworkClient, instance=True)
         netwk.name = "client"
         msg_init = self._setup_mock_init_request(secagg=None, dpsgd=True)
-        msg_priv = mock.create_autospec(
-            messaging.SerializedMessage, instance=True
-        )
-        msg_priv.message_cls = messaging.PrivacyRequest
-        msg_priv.deserialize.return_value = mock.create_autospec(
-            messaging.PrivacyRequest, instance=True
-        )
+        msg_priv, _ = self._setup_dpsgd_setup_query()
         netwk.recv_message.side_effect = [msg_init, msg_priv]
         # Set up a client wrapping the former network endpoint.
         client = FederatedClient(netwk=netwk, train_data=MOCK_DATASET)
@@ -616,10 +634,170 @@ class TestFederatedClientInitialize:
         # Assert that TrainingManager was instantiated and DP one was called.
         patch_tm.assert_called_once()
         patch_dp.assert_called_once()
-        # Assert that both messages were fetched, and an error was sent.
+        # Assert that both messages were fetched, and an error was sent
+        # after the DP-SGD setup failed.
         assert netwk.recv_message.call_count == 2
+        assert netwk.send_message.call_count == 2
+        reply = netwk.send_message.call_args_list[0].args[0]
+        assert isinstance(reply, messaging.InitReply)
+        reply = netwk.send_message.call_args_list[1].args[0]
+        assert isinstance(reply, messaging.Error)
+
+    def _setup_fairness_setup_query(
+        self,
+    ) -> Tuple[
+        messaging.SerializedMessage[messaging.FairnessSetupQuery],
+        messaging.FairnessSetupQuery,
+    ]:
+        """Set up a mock FairnessSetupQuery and a wrapping SerializedMessage."""
+        fs_query = mock.create_autospec(
+            messaging.FairnessSetupQuery, instance=True
+        )
+        msg_fair = mock.create_autospec(
+            messaging.SerializedMessage, instance=True
+        )
+        msg_fair.message_cls = messaging.FairnessSetupQuery
+        msg_fair.deserialize.return_value = fs_query
+        return msg_fair, fs_query
+
+    @pytest.mark.asyncio
+    async def test_initialize_with_fairness(self) -> None:
+        """Test that initialization with fairness works properly."""
+        # Set up a mock network receiving an InitRequest,
+        # then a FairnessSetupQuery.
+        netwk = mock.create_autospec(NetworkClient, instance=True)
+        netwk.name = "client"
+        msg_init = self._setup_mock_init_request(secagg=None, fairness=True)
+        msg_fair, fs_query = self._setup_fairness_setup_query()
+        netwk.recv_message.side_effect = [msg_init, msg_fair]
+        # Set up a client wrapping the former network endpoint.
+        client = FederatedClient(netwk=netwk, train_data=MOCK_DATASET)
+        # Attempt running initialization, patching fairness controller setup.
+        with mock.patch.object(
+            FairnessControllerClient, "from_setup_query"
+        ) as patch_fcc:
+            mock_controller = patch_fcc.return_value
+            mock_controller.setup_fairness = mock.AsyncMock()
+            await client.initialize()
+        # Assert that a controller was instantiated and set up.
+        patch_fcc.assert_called_once_with(
+            query=fs_query, manager=client.trainmanager
+        )
+        mock_controller.setup_fairness.assert_awaited_once_with(
+            netwk=client.netwk, secagg=None
+        )
+        assert client.fairness is mock_controller
+        # Assert that a single InitReply was then sent to the server.
         netwk.send_message.assert_called_once()
-        reply = netwk.send_message.call_args.args[0]
+        reply = netwk.send_message.call_args[0][0]
+        assert isinstance(reply, messaging.InitReply)
+
+    @pytest.mark.asyncio
+    async def test_initialize_with_fairness_and_secagg(self) -> None:
+        """Test that initialization with fairness and secagg works properly."""
+        # Set up a mock network receiving an InitRequest with SecAgg,
+        # then a SecaggSetupQuery and finally a FairnessSetupQuery.
+        netwk = mock.create_autospec(NetworkClient, instance=True)
+        netwk.name = "client"
+        msg_init = self._setup_mock_init_request(
+            secagg="mock-secagg", fairness=True
+        )
+        msg_sqry = mock.create_autospec(
+            messaging.SerializedMessage, instance=True
+        )
+        msg_sqry.message_cls = SecaggSetupQuery
+        msg_sqry.deserialize.return_value = mock.create_autospec(
+            SecaggSetupQuery, instance=True
+        )
+        msg_fair, fs_query = self._setup_fairness_setup_query()
+        netwk.recv_message.side_effect = [msg_init, msg_sqry, msg_fair]
+        # Set up a client with that endpoint and a matching mock secagg.
+        secagg = mock.create_autospec(SecaggConfigClient, instance=True)
+        secagg.secagg_type = "mock-secagg"
+        client = FederatedClient(
+            netwk=netwk, train_data=MOCK_DATASET, secagg=secagg
+        )
+        # Attempt running initialization, patching fairness controller setup.
+        with mock.patch.object(
+            FairnessControllerClient, "from_setup_query"
+        ) as patch_fcc:
+            mock_controller = patch_fcc.return_value
+            mock_controller.setup_fairness = mock.AsyncMock()
+            await client.initialize()
+        # Assert that all three messages were fetched.
+        assert netwk.recv_message.call_count == 3
+        # Assert that a secagg controller was set up.
+        secagg.setup_encrypter.assert_awaited_once_with(netwk, msg_sqry)
+        # Assert that a fairness controller was instantiated
+        # and then set up using the secagg controller.
+        patch_fcc.assert_called_once_with(
+            query=fs_query, manager=client.trainmanager
+        )
+        mock_controller.setup_fairness.assert_awaited_once_with(
+            netwk=client.netwk, secagg=secagg.setup_encrypter.return_value
+        )
+        assert client.fairness is mock_controller
+        # Assert that a single InitReply was then sent to the server.
+        netwk.send_message.assert_called_once()
+        reply = netwk.send_message.call_args[0][0]
+        assert isinstance(reply, messaging.InitReply)
+
+    @pytest.mark.asyncio
+    async def test_initialize_with_fairness_error_wrong_message(self) -> None:
+        """Test error catching for fairness setup with wrong second message."""
+        # Set up a mock network receiving an InitRequest but wrong follow-up.
+        netwk = mock.create_autospec(NetworkClient, instance=True)
+        netwk.name = "client"
+        msg_init = self._setup_mock_init_request(secagg=None, fairness=True)
+        netwk.recv_message.return_value = msg_init
+        # Set up a client wrapping the former network endpoint.
+        client = FederatedClient(netwk=netwk, train_data=MOCK_DATASET)
+        # Attempt running initialization, monitoring fairness controller setup.
+        with mock.patch.object(
+            FairnessControllerClient, "from_setup_query"
+        ) as patch_fcc:
+            with pytest.raises(RuntimeError):
+                await client.initialize()
+        # Assert that two messages were fetched, the first one answered with
+        # an InitReply, the second with an Error.
+        assert netwk.recv_message.call_count == 2
+        assert netwk.send_message.call_count == 2
+        reply = netwk.send_message.call_args_list[0].args[0]
+        assert isinstance(reply, messaging.InitReply)
+        reply = netwk.send_message.call_args_list[1].args[0]
+        assert isinstance(reply, messaging.Error)
+        # Assert that no fairness controller was set.
+        patch_fcc.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_initialize_with_fairness_error_setup(self) -> None:
+        """Test error catching for fairness setup with client-side failure."""
+        # Set up a mock network receiving an InitRequest,
+        # then a FairnessSetupQuery.
+        netwk = mock.create_autospec(NetworkClient, instance=True)
+        netwk.name = "client"
+        msg_init = self._setup_mock_init_request(secagg=None, fairness=True)
+        msg_fair, _ = self._setup_fairness_setup_query()
+        netwk.recv_message.side_effect = [msg_init, msg_fair]
+        # Set up a client wrapping the former network endpoint.
+        client = FederatedClient(netwk=netwk, train_data=MOCK_DATASET)
+        # Attempt running initialization, monitoring and forcing setup failure.
+        with mock.patch.object(
+            FairnessControllerClient, "from_setup_query"
+        ) as patch_fcc:
+            patch_fcc.side_effect = TypeError
+            with pytest.raises(RuntimeError):
+                await client.initialize()
+        # Assert that setup was called (hence causing the exception).
+        patch_fcc.assert_called_once()
+        assert client.fairness is None
+        # Assert that both messages were fetched, and an error was sent
+        # after the fairness setup failed.
+        assert netwk.recv_message.call_count == 2
+        assert netwk.send_message.call_count == 2
+        reply = netwk.send_message.call_args_list[0].args[0]
+        assert isinstance(reply, messaging.InitReply)
+        reply = netwk.send_message.call_args_list[1].args[0]
         assert isinstance(reply, messaging.Error)
 
@@ -874,6 +1052,118 @@ class TestFederatedClientEvaluationRound:
         train_manager.evaluation_round.assert_not_called()
 
 
+class TestFederatedClientFairnessRound:
+    """Unit tests for 'FederatedClient.fairness_round'."""
+
+    @pytest.mark.parametrize("ckpt", [True, False], ids=["ckpt", "nockpt"])
+    @pytest.mark.asyncio
+    async def test_fairness_round(
+        self,
+        ckpt: bool,
+    ) -> None:
+        """Test 'fairness_round' with fairness and without SecAgg."""
+        # Set up a client with a mock NetworkClient, TrainingManager,
+        # FairnessClientController and optional Checkpointer.
+        netwk = mock.create_autospec(NetworkClient, instance=True)
+        netwk.name = "client"
+        client = FederatedClient(netwk, train_data=MOCK_DATASET)
+        client.trainmanager = mock.create_autospec(
+            TrainingManager, instance=True
+        )
+        fairness = mock.create_autospec(
+            FairnessControllerClient, instance=True
+        )
+        client.fairness = fairness
+        if ckpt:
+            client.ckptr = mock.create_autospec(Checkpointer, instance=True)
+        # Call the 'fairness_round' routine and verify expected actions.
+        request = messaging.FairnessQuery(round_i=0)
+        await client.fairness_round(request)
+        fairness.run_fairness_round.assert_awaited_once_with(
+            netwk=netwk, query=request, secagg=None
+        )
+        # Verify that when a checkpointer is set, it is used.
+        if ckpt:
+            client.ckptr.save_metrics.assert_called_once_with(  # type: ignore
+                metrics=fairness.run_fairness_round.return_value,
+                prefix="fairness_metrics",
+                append=False,  # first round, hence file creation or overwrite
+                timestamp="round_0",
+            )
+
+    @pytest.mark.asyncio
+    async def test_fairness_round_secagg(self) -> None:
+        """Test 'fairness_round' with fairness and with SecAgg."""
+        # Set up a client with a mock NetworkClient, TrainingManager,
+        # FairnessClientController and SecaggConfigClient.
+        netwk = mock.create_autospec(NetworkClient, instance=True)
+        netwk.name = "client"
+        secagg = mock.create_autospec(SecaggConfigClient, instance=True)
+        client = FederatedClient(netwk, train_data=MOCK_DATASET, secagg=secagg)
+        client.trainmanager = mock.create_autospec(
+            TrainingManager, instance=True
+        )
+        fairness = mock.create_autospec(
+            FairnessControllerClient, instance=True
+        )
+        client.fairness = fairness
+        # Call the SecAgg setup routine.
+        await client.setup_secagg(
+            mock.create_autospec(messaging.SerializedMessage)
+        )
+        # Call the 'fairness_round' routine and verify expected actions.
+        request = messaging.FairnessQuery(round_i=1)
+        await client.fairness_round(request)
+        fairness.run_fairness_round.assert_awaited_once_with(
+            netwk=netwk,
+            query=request,
+            secagg=secagg.setup_encrypter.return_value,
+        )
+
+    @pytest.mark.asyncio
+    async def test_fairness_round_fairness_not_setup(self) -> None:
+        """Test 'fairness_round' without a fairness controller."""
+        # Set up a client with a mock NetworkClient and TrainingManager,
+        # but no fairness controller.
+        netwk = mock.create_autospec(NetworkClient, instance=True)
+        netwk.name = "client"
+        client = FederatedClient(netwk, train_data=MOCK_DATASET)
+        client.trainmanager = mock.create_autospec(
+            TrainingManager, instance=True
+        )
+        # Verify that running the routine raises a RuntimeError.
+        with pytest.raises(RuntimeError):
+            await client.fairness_round(messaging.FairnessQuery(round_i=1))
+        # Verify that an Error message was sent.
+        netwk.send_message.assert_called_once()
+        reply = netwk.send_message.call_args.args[0]
+        assert isinstance(reply, messaging.Error)
+
+    @pytest.mark.asyncio
+    async def test_fairness_round_secagg_not_setup(self) -> None:
+        """Test 'fairness_round' error with configured, not-setup SecAgg."""
+        # Set up a client with a mock NetworkClient, TrainingManager,
+        # FairnessClientController and SecaggConfigClient.
+        netwk = mock.create_autospec(NetworkClient, instance=True)
+        netwk.name = "client"
+        secagg = mock.create_autospec(SecaggConfigClient, instance=True)
+        client = FederatedClient(netwk, train_data=MOCK_DATASET, secagg=secagg)
+        client.trainmanager = mock.create_autospec(
+            TrainingManager, instance=True
+        )
+        fairness = mock.create_autospec(
+            FairnessControllerClient, instance=True
+        )
+        client.fairness = fairness
+        # Run the routine and verify that an Error message was sent.
+        request = messaging.FairnessQuery(round_i=1)
+        await client.fairness_round(request)
+        netwk.send_message.assert_called_once()
+        reply = netwk.send_message.call_args.args[0]
+        assert isinstance(reply, messaging.Error)
+        fairness.run_fairness_round.assert_not_called()
+
+
 
 class TestFederatedClientMisc:
     """Unit tests for miscellaneous 'FederatedClient' methods."""
diff --git a/test/main/test_main_server.py b/test/main/test_main_server.py
index e1c37a4a6cc243f3b97af9b1f2655fb1b671babd..1ae08bb5bab6bda583c47b5d87a1eb81827db905 100644
--- a/test/main/test_main_server.py
+++ b/test/main/test_main_server.py
@@ -26,6 +26,7 @@ import pytest  # type: ignore
 from declearn.aggregator import Aggregator
 from declearn.communication import NetworkServerConfig
 from declearn.communication.api import NetworkServer
+from declearn.fairness.api import FairnessControllerServer
 from declearn.main import FederatedServer
 from declearn.main.config import FLOptimConfig
 from declearn.main.utils import Checkpointer
@@ -144,6 +145,9 @@ class TestFederatedServerInit:  # pylint: disable=too-many-public-methods
             client_opt=mock.create_autospec(Optimizer, instance=True),
             server_opt=mock.create_autospec(Optimizer, instance=True),
             aggregator=mock.create_autospec(Aggregator, instance=True),
+            fairness=mock.create_autospec(
+                FairnessControllerServer, instance=True
+            ),
         )
         server = FederatedServer(
             model=MOCK_MODEL, netwk=MOCK_NETWK, optim=optim
@@ -151,6 +155,7 @@ class TestFederatedServerInit:  # pylint: disable=too-many-public-methods
         assert server.c_opt is optim.client_opt
         assert server.optim is optim.server_opt
         assert server.aggrg is optim.aggregator
+        assert server.fairness is optim.fairness
 
     def test_optim_dict(self) -> None:
         """Test specifying 'optim' as a config dict."""
         optim = {
             "client_opt": mock.create_autospec(Optimizer, instance=True),
             "server_opt": mock.create_autospec(Optimizer, instance=True),
             "aggregator": mock.create_autospec(Aggregator, instance=True),
+            "fairness": mock.create_autospec(
+                FairnessControllerServer, instance=True
+            ),
         }
         server = FederatedServer(
             model=MOCK_MODEL, netwk=MOCK_NETWK, optim=optim
@@ -165,6 +173,7 @@ class TestFederatedServerInit:  # pylint: disable=too-many-public-methods
         assert server.c_opt is optim["client_opt"]
         assert server.optim is optim["server_opt"]
         assert server.aggrg is optim["aggregator"]
+        assert server.fairness is optim["fairness"]
 
     def test_optim_toml(self, tmp_path: str) -> None:
         """Test specifying 'optim' as a TOML file path."""
@@ -192,6 +201,7 @@ class TestFederatedServerInit:  # pylint: disable=too-many-public-methods
         assert server.c_opt.get_config() == config.client_opt.get_config()
         assert server.optim.get_config() == config.server_opt.get_config()
         assert server.aggrg.get_config() == config.aggregator.get_config()
+        assert server.fairness is None
 
     def test_optim_invalid(self) -> None:
         """Test specifying 'optim' with an invalid type."""
diff --git a/test/optimizer/test_modules.py b/test/optimizer/test_modules.py
index 123edf20db0101dfbac27aed62379a560cc7b03a..d3a91a808cf4f24e0f0c0109d3c269421413c50a 100644
--- a/test/optimizer/test_modules.py
+++ b/test/optimizer/test_modules.py
@@ -149,7 +149,7 @@ class OptiModuleTestSuite(PluginTestBase):
     ) -> None:
         # For Noise-addition mechanisms, seed the (unsafe) RNG.
         if issubclass(cls, NoiseModule):
-            cls = functools.partial(
+            cls = functools.partial(  # type: ignore[misc]
                 cls, safe_mode=False, seed=0
             )  # type: ignore  # partial wraps the __init__ method
         # Run the unit test.
diff --git a/test/main/test_train_manager.py b/test/training/test_train_manager.py
similarity index 99%
rename from test/main/test_train_manager.py
rename to test/training/test_train_manager.py
index 980b9f83209192cf51738a83c3e40c26d076f737..271c2a1f52d1d8a88596f66315f2101c2f1403b2 100644
--- a/test/main/test_train_manager.py
+++ b/test/training/test_train_manager.py
@@ -25,10 +25,10 @@ import numpy
 
 from declearn.aggregator import Aggregator
 from declearn.communication import messaging
 from declearn.dataset import Dataset
-from declearn.main.utils import TrainingManager
 from declearn.metrics import Metric, MetricSet
 from declearn.model.api import Model, Vector
 from declearn.optimizer import Optimizer
+from declearn.training import TrainingManager
 
 MockArray = mock.create_autospec(numpy.ndarray)
diff --git a/test/main/test_train_manager_dp.py b/test/training/test_train_manager_dp.py
similarity index 99%
rename from test/main/test_train_manager_dp.py
rename to test/training/test_train_manager_dp.py
index 3735f290b8e9da87be72d4dce2ce2024ed95307a..01c62d7ff9c09f32d190055b6f147c94886d17d1 100644
--- a/test/main/test_train_manager_dp.py
+++ b/test/training/test_train_manager_dp.py
@@ -30,8 +30,8 @@ except ModuleNotFoundError:
 
 from declearn.communication import messaging
 from declearn.dataset import DataSpecs
-from declearn.main.privacy import DPTrainingManager
 from declearn.optimizer.modules import GaussianNoiseModule
+from declearn.training.dp import DPTrainingManager
 from declearn.test_utils import make_importable
 
 with make_importable(os.path.dirname(__file__)):
diff --git a/test/utils/test_toml.py b/test/utils/test_toml.py
index ed8a04d981648591e1c9a30536daae04cd2af2c5..1b0567d11691468c8933d23c175b98c4f94b0ef7 100644
--- a/test/utils/test_toml.py
+++ b/test/utils/test_toml.py
@@ -381,3 +381,56 @@ class TestTomlConfigNested:
         }["demo_a"]
         with pytest.raises(TypeError):
             ComplexTomlConfig.default_parser(field, path_bad)
+
+
+@dataclasses.dataclass
+class AutofillTomlConfig(TomlConfig):
+    """Demonstration TomlConfig subclass with an autofill field."""
+
+    base: int
+    auto: int
+
+    autofill_fields = {"auto"}
+
+    @classmethod
+    def from_params(
+        cls,
+        **kwargs: Any,
+    ) -> Self:
+        if "base" in kwargs:
+            kwargs.setdefault("auto", kwargs["base"])
+        return super().from_params(**kwargs)
+
+
+class TestTomlAutofill:
+    """Unit tests for a 'TomlConfig' subclass with an auto-fill field."""
+
+    def test_from_params_exhaustive(self) -> None:
+        """Test parsing kwargs with exhaustive values."""
+        config = AutofillTomlConfig.from_params(base=0, auto=1)
+        assert config.base == 0  # false-positive; pylint: disable=no-member
+        assert config.auto == 1  # false-positive; pylint: disable=no-member
+
+    def test_from_params_autofill(self) -> None:
+        """Test parsing kwargs without the auto-filled value."""
+        config = AutofillTomlConfig.from_params(base=0)
+        assert config.base == 0  # false-positive; pylint: disable=no-member
+        assert config.auto == 0  # false-positive; pylint: disable=no-member
+
+    def test_from_toml_exhaustive(self, tmp_path: str) -> None:
+        """Test parsing a TOML file with exhaustive values."""
+        path = os.path.join(tmp_path, "config.toml")
+        with open(path, "w", encoding="utf-8") as file:
+            file.write("base = 0\nauto = 1")
+        config = AutofillTomlConfig.from_toml(path)
+        assert config.base == 0  # false-positive; pylint: disable=no-member
+        assert config.auto == 1  # false-positive; pylint: disable=no-member
+
+    def test_from_toml_autofill(self, tmp_path: str) -> None:
+        """Test parsing a TOML file without the auto-filled value."""
+        path = os.path.join(tmp_path, "config.toml")
+        with open(path, "w", encoding="utf-8") as file:
+            file.write("base = 0")
+        config = AutofillTomlConfig.from_toml(path)
+        assert config.base == 0  # false-positive; pylint: disable=no-member
+        assert config.auto == 0  # false-positive; pylint: disable=no-member
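
For reference, a minimal usage sketch of the auto-fill behaviour pinned down by the tests above, reusing the 'AutofillTomlConfig' demo subclass defined in this diff (the values are illustrative and not part of the patch):

    # 'auto' is backfilled from 'base' when it is left unspecified...
    config = AutofillTomlConfig.from_params(base=42)
    assert (config.base, config.auto) == (42, 42)
    # ...whereas an explicitly-provided value takes precedence.
    config = AutofillTomlConfig.from_params(base=42, auto=7)
    assert (config.base, config.auto) == (42, 7)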