diff --git a/declearn/fairness/fairfed/__init__.py b/declearn/fairness/fairfed/__init__.py index 3979d8bb435de39007c692b81aa269082e114917..d0fdc4e5d221c9af6512b7af8590b10599c3285f 100644 --- a/declearn/fairness/fairfed/__init__.py +++ b/declearn/fairness/fairfed/__init__.py @@ -52,8 +52,8 @@ Backend ------- * [FairfedAggregator][declearn.fairness.fairfed.FairfedAggregator]: Fairfed-specific Aggregator using arbitrary averaging weights. -* [FairfedFairnessFunction][declearn.fairness.fairfed.FairfedFairnessFunction]: - FairFed-specific fairness function wrapper. +* [FairfedValueComputer][declearn.fairness.fairfed.FairfedValueComputer]: + Fairfed-specific synthetic fairness value computer. Messages -------- @@ -72,6 +72,6 @@ from ._messages import ( SecaggFairfedDelta, ) from ._aggregator import FairfedAggregator -from ._function import FairfedFairnessFunction +from ._fairfed import FairfedValueComputer from ._client import FairfedControllerClient from ._server import FairfedControllerServer diff --git a/declearn/fairness/fairfed/_client.py b/declearn/fairness/fairfed/_client.py index fe02ac505ecda72888cc2c64bc58133c929cdd38..24efbefcfcb0dfd2000b352a1d57b74335dc5c34 100644 --- a/declearn/fairness/fairfed/_client.py +++ b/declearn/fairness/fairfed/_client.py @@ -25,7 +25,7 @@ from declearn.communication.api import NetworkClient from declearn.communication.utils import verify_server_message_validity from declearn.fairness.api import FairnessControllerClient from declearn.fairness.fairfed._aggregator import FairfedAggregator -from declearn.fairness.fairfed._function import FairfedFairnessFunction +from declearn.fairness.fairfed._fairfed import FairfedValueComputer from declearn.fairness.fairfed._messages import ( FairfedDelta, FairfedDeltavg, @@ -53,6 +53,7 @@ class FairfedControllerClient(FairnessControllerClient): f_args: Dict[str, Any], beta: float, strict: bool = True, + target: int = 1, ) -> None: """Instantiate the client-side fairness controller. @@ -72,23 +73,24 @@ class FairfedControllerClient(FairnessControllerClient): Whether to stick strictly to the FairFed paper's setting and explicit formulas, or to use a broader adaptation of FairFed to more diverse settings. + target: + Choice of target label to focus on in `strict` mode. + Unused when `strict=False`. """ # arguments serve modularity; pylint: disable=too-many-arguments super().__init__(manager=manager, f_type=f_type, f_args=f_args) self.beta = beta - self._key_groups = ( - ((0, 0), (0, 1)) if strict else None - ) # type: Optional[Tuple[Tuple[Any, ...], Tuple[Any, ...]]] - self.fairfed_func = FairfedFairnessFunction( - self.fairness_function, strict=strict + self._fairfed = FairfedValueComputer( + f_type=self.fairness_function.f_type, strict=strict, target=target ) + self._fairfed.initialize(groups=self.fairness_function.groups) @property def strict( self, ) -> bool: - """Whether this controller strictly sticks to the FairFed paper.""" - return self.fairfed_func.strict + """Whether this function strictly sticks to the FairFed paper.""" + return self._fairfed.strict async def finalize_fairness_setup( self, @@ -113,7 +115,9 @@ class FairfedControllerClient(FairnessControllerClient): netwk, received, expected=FairfedFairness ) # Compute the absolute difference between local and global fairness. - fair_avg = self.fairfed_func.compute_synthetic_fairness_value(fairness) + fair_avg = self._fairfed.compute_synthetic_fairness_value( + values["fairness"] + ) my_delta = FairfedDelta(abs(fair_avg - fair_glb.fairness)) # Share it with the server for its (secure-)aggregation across clients. if secagg is None: diff --git a/declearn/fairness/fairfed/_fairfed.py b/declearn/fairness/fairfed/_fairfed.py new file mode 100644 index 0000000000000000000000000000000000000000..9a829ff4125d79246c1b5e86517b97b729ebe5a5 --- /dev/null +++ b/declearn/fairness/fairfed/_fairfed.py @@ -0,0 +1,154 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Fairfed-specific synthetic fairness value computer.""" + +import warnings +from typing import Any, Dict, List, Optional, Tuple + + +__all__ = [ + "FairfedValueComputer", +] + + +class FairfedValueComputer: + """Fairfed-specific synthetic fairness value computer. + + Strict mode + ----------- + This FairFed implementation comes in two flavors. + + - The "strict" mode sticks to the original FairFed paper: + - It applies only to binary classification tasks with + a single binary sensitive attributes. + - Clients must hold examples to each and every group. + - If `f_type` is not explicitly cited in the original + paper, a `RuntimeWarning` is warned. + - The synthetic fairness value is computed based on + fairness values for two groups: (y=`target`,s=1) + and (y=`target`,s=0). + + - The "non-strict" mode extends to broader settings: + - It applies to any number of sensitive groups. + - Clients may not hold examples of all groups. + - It applies to any type of group-fairness. + - The synthetic fairness value is computed as + the average of all absolute fairness values. + - The local fairness is only computed over groups + that have a least one sample in the local data. + """ + + def __init__( + self, + f_type: str, + strict: bool = True, + target: int = 1, + ) -> None: + """Instantiate the FairFed-specific fairness function wrapper. + + Parameters + ---------- + f_type: + Name of the fairness definition being optimized. + strict: + Whether to stick strictly to the FairFed paper's setting + and explicit formulas, or to use a broader adaptation of + FairFed to more diverse settings. + See class docstring for details. + target: + Choice of target label to focus on in `strict` mode. + Unused when `strict=False`. + """ + self.f_type = f_type + self.strict = strict + self.target = target + self._key_groups = ( + None + ) # type: Optional[Tuple[Tuple[Any, ...], Tuple[Any, ...]]] + + def initialize( + self, + groups: List[Tuple[Any, ...]], + ) -> None: + """Initialize the Fairfed synthetic value computer from group counts. + + Parameters + ---------- + groups: + List of sensitive group definitions. + """ + if self.strict: + self._key_groups = self.identify_key_groups(groups) + + def identify_key_groups( + self, + groups: List[Tuple[Any, ...]], + ) -> Tuple[Tuple[Any, ...], Tuple[Any, ...]]: + """Parse sensitive groups' definitions to identify 'key' ones.""" + if self.f_type not in ( + "demographic_parity", + "equality_of_opportunity", + "equalized_odds", + ): + warnings.warn( + f"Using fairness type '{self.f_type}' with FairFed in 'strict'" + " mode. This is supported, but beyond the original paper.", + RuntimeWarning, + ) + if len(groups) != 4: + raise RuntimeError( + "FairFed in 'strict' mode requires exactly 4 sensitive groups," + " arising from a binary target label and a binary attribute." + ) + key_groups = tuple(sorted([g for g in groups if g[0] == self.target])) + if len(key_groups) != 2: + raise KeyError( + f"Failed to identify the (target,attr_0);(target,attr_1) " + "pair of sensitive groups for FairFed in 'strict' mode " + f"with 'target' value {self.target}." + ) + return key_groups + + def compute_synthetic_fairness_value( + self, + fairness: Dict[Tuple[Any, ...], float], + ) -> float: + """Compute a synthetic fairness value from group-wise ones. + + If `self.strict`, compute the difference between the fairness + values associated with two key sensitive groups, as per the + original FairFed paper for the two definitions exposed by the + authors. + + Otherwise, compute the average of absolute group-wise fairness + values, that applies to more generic fairness formulations than + in the original paper, and may encompass broader information. + + Parameters + ---------- + fairness: + Group-wise fairness metrics, as a `{group_k: score_k}` dict. + + Returns + ------- + value: + Scalar value summarizing the computed fairness. + """ + if self._key_groups is None: + return sum(abs(x) for x in fairness.values()) / len(fairness) + return fairness[self._key_groups[0]] - fairness[self._key_groups[1]] diff --git a/declearn/fairness/fairfed/_function.py b/declearn/fairness/fairfed/_function.py deleted file mode 100644 index 5233611808d5ccc3e8fb9da0b26c010572981282..0000000000000000000000000000000000000000 --- a/declearn/fairness/fairfed/_function.py +++ /dev/null @@ -1,183 +0,0 @@ -# coding: utf-8 - -# Copyright 2023 Inria (Institut National de Recherche en Informatique -# et Automatique) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""FairFed-specific fairness function wrapper.""" - -import warnings -from typing import Any, Dict, Optional, Tuple - - -from declearn.fairness.api import FairnessFunction - - -class FairfedFairnessFunction: - """FairFed-specific fairness function wrapper.""" - - def __init__( - self, - wrapped: FairnessFunction, - strict: bool = True, - target: Optional[int] = None, - ) -> None: - """Instantiate the FairFed-specific fairness function wrapper. - - Parameters - ---------- - wrapped: - Initial `FairnessFunction` instance to wrap up for FairFed. - strict: - Whether to stick strictly to the FairFed paper's setting - and explicit formulas, or to use a broader adaptation of - FairFed to more diverse settings. - See details below. - target: - Optional choice of target label to focus on in `strict` mode. - Only used when `strict=True`. If `None`, use `wrapped.target` - when it exists, or else a default value of 1. - - Strict mode - ----------- - This FairFed implementation comes in two flavors. - - - The "strict" mode sticks to the original FairFed paper: - - It applies only to binary classification tasks with - a single binary sensitive attributes. - - Clients must hold examples to each and every group. - - If `wrapped.f_type` is not explicitly cited in the - original paper, a `RuntimeWarning` is warned. - - The synthetic fairness value is computed based on - fairness values for two groups: (y=`target`,s=1) - and (y=`target`,s=0). - - - The "non-strict" mode extends to broader settings: - - It applies to any number of sensitive groups. - - Clients may not hold examples of all groups. - - It applies to any type of group-fairness. - - The synthetic fairness value is computed as - the average of all absolute fairness values. - - The local fairness is only computed over groups - that have a least one sample in the local data. - """ - self.wrapped = wrapped - self._key_groups = ( - None - ) # type: Optional[Tuple[Tuple[Any, ...], Tuple[Any, ...]]] - if strict: - target = int( - getattr(wrapped, "target", 1) if target is None else target - ) - self._key_groups = self._identify_key_groups(target) - - @property - def f_type( - self, - ) -> str: - """Type of group-fairness being measured.""" - return self.wrapped.f_type - - @property - def strict( - self, - ) -> bool: - """Whether this function strictly sticks to the FairFed paper.""" - return self._key_groups is not None - - def _identify_key_groups( - self, - target: int, - ) -> Tuple[Tuple[Any, ...], Tuple[Any, ...]]: - """Parse sensitive groups' definitions to identify 'key' ones.""" - if self.f_type not in ( - "demographic_parity", - "equality_of_opportunity", - "equalized_odds", - ): - warnings.warn( - f"Using fairness type '{self.f_type}' with FairFed in 'strict'" - " mode. This is supported, but beyond the original paper.", - RuntimeWarning, - ) - if len(self.wrapped.groups) != 4: - raise RuntimeError( - "FairFed in 'strict' mode requires exactly 4 sensitive groups," - " arising from a binary target label and a binary attribute." - ) - groups = tuple( - sorted([grp for grp in self.wrapped.groups if grp[0] == target]) - ) - if len(groups) != 2: - raise KeyError( - f"Failed to identify the (target,attr_0);(target,attr_1) " - "pair of sensitive groups for FairFed in 'strict' mode " - f"with 'target' value {target}." - ) - return groups - - def compute_group_fairness_from_accuracy( - self, - accuracy: Dict[Tuple[Any, ...], float], - federated: bool, - ) -> Dict[Tuple[Any, ...], float]: - """Compute group-wise fairness values from group-wise accuracy metrics. - - Parameters - ---------- - accuracy: - Group-wise accuracy values of the model being evaluated on a - dataset. I.e. `{group_k: P(y_pred == y_true | group_k)}`. - federated: - Whether `accuracy` holds values computes federatively, that - is sum-aggregated local-group-count-weighted accuracies - `{group_k: sum_i(n_ik * accuracy_ik)}`. - - Returns - ------- - fairness: - Group-wise fairness metrics, as a `{group_k: score_k}` dict. - """ - if federated: - return self.wrapped.compute_from_federated_group_accuracy(accuracy) - return self.wrapped.compute_from_group_accuracy(accuracy) - - def compute_synthetic_fairness_value( - self, - fairness: Dict[Tuple[Any, ...], float], - ) -> float: - """Compute a synthetic fairness value from group-wise ones. - - If `self.strict`, compute the difference between the fairness - values associated with two key sensitive groups, as per the - original FairFed paper for the two definitions exposed by the - authors. - - Otherwise, compute the average of absolute group-wise fairness - values, that applies to more generic fairness formulations than - in the original paper, and may encompass broader information. - - Parameters - ---------- - fairness: - Group-wise fairness metrics, as a `{group_k: score_k}` dict. - - Returns - ------- - value: - Scalar value summarizing the computed fairness. - """ - if self._key_groups is None: - return sum(abs(x) for x in fairness.values()) / len(fairness) - return fairness[self._key_groups[0]] - fairness[self._key_groups[1]] diff --git a/declearn/fairness/fairfed/_server.py b/declearn/fairness/fairfed/_server.py index 3f1f89425662de47be21ce2a0e9beec9d058a613..24b53489124a3100e0d6f76a56b305fe65f84589 100644 --- a/declearn/fairness/fairfed/_server.py +++ b/declearn/fairness/fairfed/_server.py @@ -30,7 +30,7 @@ from declearn.fairness.api import ( instantiate_fairness_function, ) from declearn.fairness.fairfed._aggregator import FairfedAggregator -from declearn.fairness.fairfed._function import FairfedFairnessFunction +from declearn.fairness.fairfed._fairfed import FairfedValueComputer from declearn.fairness.fairfed._messages import ( FairfedDelta, FairfedDeltavg, @@ -59,6 +59,7 @@ class FairfedControllerServer(FairnessControllerServer): f_args: Optional[Dict[str, Any]] = None, beta: float = 1.0, strict: bool = True, + target: Optional[int] = None, ) -> None: """Instantiate the server-side Fed-FairGrad controller. @@ -75,15 +76,22 @@ class FairfedControllerServer(FairnessControllerServer): Whether to stick strictly to the FairFed paper's setting and explicit formulas, or to use a broader adaptation of FairFed to more diverse settings. + target: + If `strict=True`, target value of interest, on which to focus. + If None, try fetching from `f_args` or use default value `1`. """ + # arguments serve modularity; pylint: disable=too-many-arguments super().__init__(f_type=f_type, f_args=f_args) self.beta = beta # Set up a temporary fairness function, replaced at setup time. - fairfed_func = instantiate_fairness_function( + self._fairness = instantiate_fairness_function( "accuracy_parity", counts={} ) - self.fairfed_func = FairfedFairnessFunction( - wrapped=fairfed_func, strict=strict + # Set up an uninitialized FairFed value computer. + if target is None: + target = int(self.f_args.get("target", 1)) + self._fairfed = FairfedValueComputer( + f_type=self.f_type, strict=strict, target=target ) @property @@ -91,7 +99,7 @@ class FairfedControllerServer(FairnessControllerServer): self, ) -> bool: """Whether this controller strictly sticks to the FairFed paper.""" - return self.fairfed_func.strict + return self._fairfed.strict def prepare_fairness_setup_query( self, @@ -99,6 +107,7 @@ class FairfedControllerServer(FairnessControllerServer): query = super().prepare_fairness_setup_query() query.params["beta"] = self.beta query.params["strict"] = self.strict + query.params["target"] = self._fairfed.target return query async def finalize_fairness_setup( @@ -107,13 +116,11 @@ class FairfedControllerServer(FairnessControllerServer): counts: List[int], aggregator: Aggregator, ) -> Aggregator: - # Set up a fairness function. - fairfed_func = instantiate_fairness_function( + # Set up a fairness function and initialized the FairFed computer. + self._fairness = instantiate_fairness_function( self.f_type, counts=dict(zip(self.groups, counts)), **self.f_args ) - self.fairfed_func = FairfedFairnessFunction( - wrapped=fairfed_func, strict=self.fairfed_func.strict - ) + self._fairfed.initialize(groups=self.groups) # Force the use of a FairFed-specific averaging aggregator. warnings.warn( "Overriding Aggregator choice due to the use of FairFed.", @@ -128,13 +135,13 @@ class FairfedControllerServer(FairnessControllerServer): netwk: NetworkServer, secagg: Optional[Decrypter], ) -> Dict[str, Union[float, np.ndarray]]: - # Unpack group-wise accuracy values and compute fairness. + # Unpack group-wise accuracy values and compute fairness ones. accuracy = dict(zip(self.groups, values)) - fairness = self.fairfed_func.compute_group_fairness_from_accuracy( - accuracy, federated=True + fairness = self._fairness.compute_from_federated_group_accuracy( + accuracy ) # Share the absolute mean fairness with clients. - fair_avg = self.fairfed_func.compute_synthetic_fairness_value(fairness) + fair_avg = self._fairfed.compute_synthetic_fairness_value(fairness) await netwk.broadcast_message(FairfedFairness(fairness=fair_avg)) # Await and (secure-)aggregate clients' absolute fairness difference. received = await netwk.wait_for_messages() @@ -161,6 +168,6 @@ class FairfedControllerServer(FairnessControllerServer): metrics.update( {f"{self.f_type}_{key}": val for key, val in fairness.items()} ) - metrics[f"{self.f_type}_mean_abs"] = fair_avg + metrics["fairfed_value"] = fair_avg metrics["fairfed_deltavg"] = deltavg return metrics