diff --git a/declearn/fairness/fairfed/__init__.py b/declearn/fairness/fairfed/__init__.py
index e5eefbc71c9383608a0657f0027f75d8effcba28..3979d8bb435de39007c692b81aa269082e114917 100644
--- a/declearn/fairness/fairfed/__init__.py
+++ b/declearn/fairness/fairfed/__init__.py
@@ -29,8 +29,9 @@ This algorithm was originally designed for settings where a binary
 classifier is trained over data with a single binary sensitive
 attribute, with the authors showcasing their generic formulas over
 a limited set of group fairness definitions. DecLearn expands it to
-a broader case, enabling the use of arbitraty fairness definitions
+a broader case, enabling the use of arbitrary fairness definitions
 over data that may have non-binary and/or many sensitive attributes.
+A 'strict' mode is made available to stick to the original paper.
 
 Additionally, the algorithm's authors suggest combining it with other
 mechanisms that aim at enforcing model fairness during local training
@@ -51,6 +52,8 @@ Backend
 -------
 * [FairfedAggregator][declearn.fairness.fairfed.FairfedAggregator]:
     Fairfed-specific Aggregator using arbitrary averaging weights.
+* [FairfedFairnessFunction][declearn.fairness.fairfed.FairfedFairnessFunction]:
+    FairFed-specific fairness function wrapper.
 
 Messages
 --------
@@ -69,5 +72,6 @@ from ._messages import (
     SecaggFairfedDelta,
 )
 from ._aggregator import FairfedAggregator
+from ._function import FairfedFairnessFunction
 from ._client import FairfedControllerClient
 from ._server import FairfedControllerServer
diff --git a/declearn/fairness/fairfed/_client.py b/declearn/fairness/fairfed/_client.py
index 1a6ae1720981e66a918143c2211033df0c96e53c..9596ebb052faa267548a4baef988c88e140c838d 100644
--- a/declearn/fairness/fairfed/_client.py
+++ b/declearn/fairness/fairfed/_client.py
@@ -26,6 +26,7 @@ from declearn.communication.utils import verify_server_message_validity
 from declearn.fairness.api import FairnessControllerClient
 from declearn.fairness.core import instantiate_fairness_function
 from declearn.fairness.fairfed._aggregator import FairfedAggregator
+from declearn.fairness.fairfed._function import FairfedFairnessFunction
 from declearn.fairness.fairfed._messages import (
     FairfedDelta,
     FairfedDeltavg,
@@ -52,6 +53,7 @@ class FairfedControllerClient(FairnessControllerClient):
         f_type: str,
         f_args: Dict[str, Any],
         beta: float,
+        strict: bool = True,
     ) -> None:
         """Instantiate the client-side fairness controller.
 
@@ -67,12 +69,27 @@ class FairfedControllerClient(FairnessControllerClient):
         beta:
             Hyper-parameter controlling the magnitude of averaging
             weights' updates across rounds.
+        strict:
+            Whether to stick strictly to the FairFed paper's setting
+            and explicit formulas, or to use a broader adaptation of
+            FairFed to more diverse settings.
""" + # arguments serve modularity; pylint: disable=too-many-arguments super().__init__(manager) self.beta = beta - self.fairness_function = instantiate_fairness_function( + self._key_groups = ( + ((0, 0), (0, 1)) if strict else None + ) # type: Optional[Tuple[Tuple[Any, ...], Tuple[Any, ...]]] + fairness_function = instantiate_fairness_function( f_type=f_type, counts=self.computer.counts, **f_args ) + self.fairfed_func = FairfedFairnessFunction( + fairness_function, strict=strict + ) + + @property + def strict( + self, + ) -> bool: + """Whether this controller strictly sticks to the FairFed paper.""" + return self.fairfed_func.strict async def finalize_fairness_setup( self, @@ -98,7 +118,9 @@ class FairfedControllerClient(FairnessControllerClient): n_batch=n_batch, thresh=thresh, ) - fairness = self.fairness_function.compute_from_group_accuracy(accuracy) + fairness = self.fairfed_func.compute_group_fairness_from_accuracy( + accuracy, federated=False + ) # Flatten local values for post-processing and checkpointing. local_values = list(accuracy.values()) + list(fairness.values()) # Scale accuracy values by sample counts for their aggregation. @@ -124,7 +146,7 @@ class FairfedControllerClient(FairnessControllerClient): netwk, received, expected=FairfedFairness ) # Compute the absolute difference between local and global fairness. - fair_avg = sum(abs(x) for x in fairness.values()) / len(groups) + fair_avg = self.fairfed_func.compute_synthetic_fairness_value(fairness) my_delta = FairfedDelta(abs(fair_avg - fair_glb.fairness)) # Share it with the server for its (secure-)aggregation across clients. if secagg is None: @@ -150,7 +172,7 @@ class FairfedControllerClient(FairnessControllerClient): metrics = { f"accuracy_{key}": val for key, val in accuracy.items() } # type: Dict[str, Union[float, np.ndarray]] - f_type = self.fairness_function.f_type + f_type = self.fairfed_func.f_type metrics.update( {f"{f_type}_{key}": val for key, val in fairness.items()} ) diff --git a/declearn/fairness/fairfed/_function.py b/declearn/fairness/fairfed/_function.py new file mode 100644 index 0000000000000000000000000000000000000000..c5700229c4a5a68f25d7be031b1bb36c83733d7a --- /dev/null +++ b/declearn/fairness/fairfed/_function.py @@ -0,0 +1,183 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""FairFed-specific fairness function wrapper.""" + +import warnings +from typing import Any, Dict, Optional, Tuple + + +from declearn.fairness.core import FairnessFunction + + +class FairfedFairnessFunction: + """FairFed-specific fairness function wrapper.""" + + def __init__( + self, + wrapped: FairnessFunction, + strict: bool = True, + target: Optional[int] = None, + ) -> None: + """Instantiate the FairFed-specific fairness function wrapper. + + Parameters + ---------- + wrapped: + Initial `FairnessFunction` instance to wrap up for FairFed. 
+        strict:
+            Whether to stick strictly to the FairFed paper's setting
+            and explicit formulas, or to use a broader adaptation of
+            FairFed to more diverse settings.
+            See details below.
+        target:
+            Optional choice of target label to focus on in `strict` mode.
+            Only used when `strict=True`. If `None`, use `wrapped.target`
+            when it exists, or else a default value of 1.
+
+        Strict mode
+        -----------
+        This FairFed implementation comes in two flavors.
+
+        - The "strict" mode sticks to the original FairFed paper:
+            - It applies only to binary classification tasks with
+              a single binary sensitive attribute.
+            - Clients must hold examples of each and every group.
+            - If `wrapped.f_type` is not explicitly cited in the
+              original paper, a `RuntimeWarning` is emitted.
+            - The synthetic fairness value is computed as the
+              difference between the fairness values of groups
+              (y=`target`,s=0) and (y=`target`,s=1).
+        - The "non-strict" mode extends to broader settings:
+            - It applies to any number of sensitive groups.
+            - Clients may not hold examples of all groups.
+            - It applies to any group-fairness definition.
+            - The synthetic fairness value is computed as
+              the average of all absolute fairness values.
+            - The local fairness is only computed over groups
+              that have at least one sample in the local data.
+        """
+        self.wrapped = wrapped
+        self._key_groups = (
+            None
+        )  # type: Optional[Tuple[Tuple[Any, ...], Tuple[Any, ...]]]
+        if strict:
+            target = int(
+                getattr(wrapped, "target", 1) if target is None else target
+            )
+            self._key_groups = self._identify_key_groups(target)
+
+    @property
+    def f_type(
+        self,
+    ) -> str:
+        """Type of group-fairness being measured."""
+        return self.wrapped.f_type
+
+    @property
+    def strict(
+        self,
+    ) -> bool:
+        """Whether this function strictly sticks to the FairFed paper."""
+        return self._key_groups is not None
+
+    def _identify_key_groups(
+        self,
+        target: int,
+    ) -> Tuple[Tuple[Any, ...], Tuple[Any, ...]]:
+        """Parse sensitive groups' definitions to identify 'key' ones."""
+        if self.f_type not in (
+            "demographic_parity",
+            "equality_of_opportunity",
+            "equalized_odds",
+        ):
+            warnings.warn(
+                f"Using fairness type '{self.f_type}' with FairFed in 'strict'"
+                " mode. This is supported, but beyond the original paper.",
+                RuntimeWarning,
+            )
+        if len(self.wrapped.groups) != 4:
+            raise RuntimeError(
+                "FairFed in 'strict' mode requires exactly 4 sensitive groups,"
+                " arising from a binary target label and a binary attribute."
+            )
+        groups = tuple(
+            sorted([grp for grp in self.wrapped.groups if grp[0] == target])
+        )
+        if len(groups) != 2:
+            raise KeyError(
+                "Failed to identify the (target,attr_0);(target,attr_1) "
+                "pair of sensitive groups for FairFed in 'strict' mode "
+                f"with 'target' value {target}."
+            )
+        return groups
+
+    def compute_group_fairness_from_accuracy(
+        self,
+        accuracy: Dict[Tuple[Any, ...], float],
+        federated: bool,
+    ) -> Dict[Tuple[Any, ...], float]:
+        """Compute group-wise fairness values from group-wise accuracy metrics.
+
+        Parameters
+        ----------
+        accuracy:
+            Group-wise accuracy values of the model being evaluated on a
+            dataset. I.e. `{group_k: P(y_pred == y_true | group_k)}`.
+        federated:
+            Whether `accuracy` holds values computed federatively, i.e.
+            sum-aggregated, local-group-count-weighted accuracies
+            `{group_k: sum_i(n_ik * accuracy_ik)}`.
+
+        Returns
+        -------
+        fairness:
+            Group-wise fairness metrics, as a `{group_k: score_k}` dict.
+        """
+        if federated:
+            return self.wrapped.compute_from_federated_group_accuracy(accuracy)
+        return self.wrapped.compute_from_group_accuracy(accuracy)
+
+    def compute_synthetic_fairness_value(
+        self,
+        fairness: Dict[Tuple[Any, ...], float],
+    ) -> float:
+        """Compute a synthetic fairness value from group-wise ones.
+
+        If `self.strict`, compute the difference between the fairness
+        values associated with the two key sensitive groups, as per the
+        original FairFed paper for the two definitions showcased by the
+        authors.
+
+        Otherwise, compute the average of absolute group-wise fairness
+        values, which applies to more generic fairness formulations than
+        in the original paper, and may encompass broader information.
+
+        Parameters
+        ----------
+        fairness:
+            Group-wise fairness metrics, as a `{group_k: score_k}` dict.
+
+        Returns
+        -------
+        value:
+            Scalar value summarizing the computed fairness.
+        """
+        if self._key_groups is None:
+            return sum(abs(x) for x in fairness.values()) / len(fairness)
+        return fairness[self._key_groups[0]] - fairness[self._key_groups[1]]
diff --git a/declearn/fairness/fairfed/_server.py b/declearn/fairness/fairfed/_server.py
index d6bda0bc868cc342068d5c90a0135453560b60b1..2483c5f45f48476a36c4849554a06b2a3227305f 100644
--- a/declearn/fairness/fairfed/_server.py
+++ b/declearn/fairness/fairfed/_server.py
@@ -28,6 +28,7 @@ from declearn.communication.utils import verify_client_messages_validity
 from declearn.fairness.api import FairnessControllerServer
 from declearn.fairness.core import instantiate_fairness_function
 from declearn.fairness.fairfed._aggregator import FairfedAggregator
+from declearn.fairness.fairfed._function import FairfedFairnessFunction
 from declearn.fairness.fairfed._messages import (
     FairfedDelta,
     FairfedDeltavg,
@@ -35,6 +36,7 @@
     FairfedOkay,
     SecaggFairfedDelta,
 )
+from declearn.messaging import FairnessSetupQuery
 from declearn.secagg.api import Decrypter
 from declearn.secagg.messaging import aggregate_secagg_messages
 
@@ -54,6 +56,7 @@
         f_type: str,
         f_args: Optional[Dict[str, Any]] = None,
         beta: float = 1.0,
+        strict: bool = True,
     ) -> None:
-        """Instantiate the server-side Fed-FairGrad controller.
+        """Instantiate the server-side FairFed controller.
 
@@ -66,13 +69,35 @@
         beta:
             Hyper-parameter controlling the magnitude of updates
-            to clients' averaging weights updates.
+            to clients' averaging weights across rounds.
+        strict:
+            Whether to stick strictly to the FairFed paper's setting
+            and explicit formulas, or to use a broader adaptation of
+            FairFed to more diverse settings.
         """
         super().__init__(f_type=f_type, f_args=f_args)
         self.beta = beta
         # Set up a temporary fairness function, replaced at setup time.
-        self.fairness_func = instantiate_fairness_function(
+        fairfed_func = instantiate_fairness_function(
             "accuracy_parity", counts={}
         )
+        self.fairfed_func = FairfedFairnessFunction(
+            wrapped=fairfed_func, strict=strict
+        )
+
+    @property
+    def strict(
+        self,
+    ) -> bool:
+        """Whether this controller strictly sticks to the FairFed paper."""
+        return self.fairfed_func.strict
+
+    def prepare_fairness_setup_query(
+        self,
+    ) -> FairnessSetupQuery:
+        query = super().prepare_fairness_setup_query()
+        query.params["beta"] = self.beta
+        query.params["strict"] = self.strict
+        return query
 
     async def finalize_fairness_setup(
         self,
@@ -81,9 +106,12 @@
         aggregator: Aggregator,
     ) -> Aggregator:
         # Set up a fairness function.
-        self.fairness_func = instantiate_fairness_function(
+        fairfed_func = instantiate_fairness_function(
             self.f_type, counts=dict(zip(self.groups, counts)), **self.f_args
         )
+        self.fairfed_func = FairfedFairnessFunction(
+            wrapped=fairfed_func, strict=self.strict
+        )
         # Force the use of a FairFed-specific averaging aggregator.
         warnings.warn(
             "Overriding Aggregator choice due to the use of FairFed.",
@@ -100,11 +128,11 @@
     ) -> Dict[str, Union[float, np.ndarray]]:
         # Unpack group-wise accuracy values and compute fairness.
         accuracy = dict(zip(self.groups, values))
-        fairness = self.fairness_func.compute_from_federated_group_accuracy(
-            accuracy
+        fairness = self.fairfed_func.compute_group_fairness_from_accuracy(
+            accuracy, federated=True
         )
-        # Share the absolute mean fairness with clients.
-        fair_avg = sum(abs(x) for x in fairness.values()) / len(fairness)
+        # Share the synthetic fairness value with clients.
+        fair_avg = self.fairfed_func.compute_synthetic_fairness_value(fairness)
         await netwk.broadcast_message(FairfedFairness(fairness=fair_avg))
         # Await and (secure-)aggregate clients' absolute fairness difference.
         received = await netwk.wait_for_messages()
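Reviewer note: the snippet below is an illustrative sketch, not part of the patch, contrasting the two modes that `FairfedFairnessFunction` introduces. Only `instantiate_fairness_function` and `FairfedFairnessFunction` come from this diff; the `"demographic_parity"` registration name, the group keys and all numeric values are assumptions made for the example.

```python
# Sketch assuming a binary label and a binary sensitive attribute, i.e.
# four groups keyed as (label, attribute) tuples, and assuming that
# "demographic_parity" is registered and only requires group-wise counts.
from declearn.fairness.core import instantiate_fairness_function
from declearn.fairness.fairfed import FairfedFairnessFunction

counts = {(0, 0): 40, (0, 1): 60, (1, 0): 50, (1, 1): 50}
wrapped = instantiate_fairness_function("demographic_parity", counts=counts)

# Toy group-wise fairness scores, standing in for the output of
# compute_group_fairness_from_accuracy().
fairness = {(0, 0): 0.02, (0, 1): -0.02, (1, 0): 0.05, (1, 1): -0.05}

# Strict mode: difference between the two key groups, here (y=1, s=0)
# and (y=1, s=1) since the default target label is 1.
strict_func = FairfedFairnessFunction(wrapped, strict=True)
print(strict_func.compute_synthetic_fairness_value(fairness))  # 0.1

# Non-strict mode: average of absolute group-wise fairness values.
loose_func = FairfedFairnessFunction(wrapped, strict=False)
print(loose_func.compute_synthetic_fairness_value(fairness))  # ~0.035
```

In a federated run these calls are not made by hand: the server's `strict` choice is forwarded to clients via `prepare_fairness_setup_query`, so both sides wrap their fairness function consistently.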