diff --git a/declearn/fairness/__init__.py b/declearn/fairness/__init__.py index 327ba3bac3d40975154ed1ef36e3bb287d372c84..6bfd887023ed43c71ec14b01ab3b61fc8c635217 100644 --- a/declearn/fairness/__init__.py +++ b/declearn/fairness/__init__.py @@ -34,6 +34,8 @@ Algorithms submodules FairFed algorithm controllers and utils. * [fairgrad][declearn.fairness.fairgrad]: Fed-FairGrad algorithm controllers and utils. +* [monitor][declearn.fairness.monitor]: + Fairness-monitoring controllers, that leave training unaltered. """ from . import api @@ -41,3 +43,4 @@ from . import core from . import fairbatch from . import fairfed from . import fairgrad +from . import monitor diff --git a/declearn/fairness/monitor/__init__.py b/declearn/fairness/monitor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6aa1939ff84dbf1537cae3ea30a573eb6f5a5289 --- /dev/null +++ b/declearn/fairness/monitor/__init__.py @@ -0,0 +1,40 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Fairness-monitoring controllers, that leave training unaltered. + +Introduction +------------ +This submodule implements dummy fairness-aware learning controllers, that +implement fairness metrics' computation, hence enabling their monitoring +throughout training, without altering the model's training process itself. + +These controllers may therefore be used to monitor fairness metrics of any +baseline federated learning algorithm, notably for comparison purposes with +fairness-aware algorithms implemented using other controllers (FairBatch, +Fed-FairGrad, ...). + +Controllers +----------- +* [FairnessMonitorClient][declearn.fairness.monitor.FairnessMonitorClient]: + Client-side controller to monitor fairness without altering training. +* [FairnessMonitorServer][declearn.fairness.monitor.FairnessMonitorServer]: + Server-side controller to monitor fairness without altering training. +""" + +from ._client import FairnessMonitorClient +from ._server import FairnessMonitorServer diff --git a/declearn/fairness/monitor/_client.py b/declearn/fairness/monitor/_client.py new file mode 100644 index 0000000000000000000000000000000000000000..e47231bd90dec34d351f0d3365e7cfd5e3efde37 --- /dev/null +++ b/declearn/fairness/monitor/_client.py @@ -0,0 +1,112 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Client-side controller to monitor fairness without altering training.""" + +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np + +from declearn.secagg.api import Encrypter +from declearn.communication.api import NetworkClient +from declearn.fairness.api import ( + FairnessControllerClient, + instantiate_fairness_function, +) +from declearn.training import TrainingManager + +__all__ = [ + "FairnessMonitorClient", +] + + +class FairnessMonitorClient(FairnessControllerClient): + """Client-side controller to monitor fairness without altering training.""" + + algorithm = "monitor" + + def __init__( + self, + manager: TrainingManager, + f_type: str, + f_args: Dict[str, Any], + ) -> None: + """Instantiate the client-side fairness controller. + + Parameters + ---------- + manager: + `TrainingManager` instance wrapping the model being trained + and its training dataset (that must be a `FairnessDataset`). + f_type: + Name of the type of group-fairness function being monitored. + f_args: + Keyword arguments to the group-fairness function. + """ + super().__init__(manager) + self.fairness_function = instantiate_fairness_function( + f_type=f_type, counts=self.computer.counts, **f_args + ) + + async def finalize_fairness_setup( + self, + netwk: NetworkClient, + secagg: Optional[Encrypter], + ) -> None: + pass + + def compute_fairness_measures( + self, + batch_size: int, + n_batch: Optional[int] = None, + thresh: Optional[float] = None, + ) -> Tuple[List[float], List[float]]: + # Compute group-wise accuracy scores. + accuracy = self.computer.compute_groupwise_accuracy( + model=self.manager.model, + batch_size=batch_size, + n_batch=n_batch, + thresh=thresh, + ) + # Flatten local values for post-processing and checkpointing. + local_values = list(accuracy.values()) + # Scale local values by sample counts for their aggregation. + accuracy = self.computer.scale_metrics_by_sample_counts(accuracy) + # Flatten shareable values, ordered and filled-out. + share_values = [accuracy.get(group, 0.0) for group in self.groups] + # Return both sets of values. + return share_values, local_values + + async def finalize_fairness_round( + self, + netwk: NetworkClient, + values: List[float], + secagg: Optional[Encrypter], + ) -> Dict[str, Union[float, np.ndarray]]: + # Recover raw accuracy scores for groups with local samples. + accuracy = dict(zip(self.computer.g_data, values)) + # Compute local fairness measures. + fairness = self.fairness_function.compute_from_group_accuracy(accuracy) + f_type = self.fairness_function.f_type + # Package and return accuracy and fairness metrics. + metrics = { + f"accuracy_{key}": val for key, val in accuracy.items() + } # type: Dict[str, Union[float, np.ndarray]] + metrics.update( + {f"{f_type}_{key}": val for key, val in fairness.items()} + ) + return metrics diff --git a/declearn/fairness/monitor/_server.py b/declearn/fairness/monitor/_server.py new file mode 100644 index 0000000000000000000000000000000000000000..d319a9c703c5d0a572864b2e9bc146b591779733 --- /dev/null +++ b/declearn/fairness/monitor/_server.py @@ -0,0 +1,85 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Server-side controller to monitor fairness without altering training.""" + +from typing import Any, Dict, List, Optional, Union + +import numpy as np + +from declearn.aggregator import Aggregator +from declearn.secagg.api import Decrypter +from declearn.communication.api import NetworkServer +from declearn.fairness.api import ( + FairnessControllerServer, + instantiate_fairness_function, +) + +__all__ = [ + "FairnessMonitorServer", +] + + +class FairnessMonitorServer(FairnessControllerServer): + """Server-side controller to monitor fairness without altering training.""" + + algorithm = "monitor" + + def __init__( + self, + f_type: str, + f_args: Optional[Dict[str, Any]], + ) -> None: + super().__init__(f_type, f_args) + # Assign a temporary fairness functions, replaced at setup time. + self.function = instantiate_fairness_function( + f_type="accuracy_parity", counts={} + ) + + async def finalize_fairness_setup( + self, + netwk: NetworkServer, + counts: List[int], + aggregator: Aggregator, + ) -> Aggregator: + self.function = instantiate_fairness_function( + f_type=self.f_type, + counts=dict(zip(self.groups, counts)), + **self.f_args, + ) + return aggregator + + async def finalize_fairness_round( + self, + round_i: int, + values: List[float], + netwk: NetworkServer, + secagg: Optional[Decrypter], + ) -> Dict[str, Union[float, np.ndarray]]: + # Unpack group-wise accuracy metrics and compute fairness ones. + accuracy = dict(zip(self.groups, values)) + fairness = self.function.compute_from_federated_group_accuracy( + accuracy + ) + # Package and return these metrics. + metrics = { + f"accuracy_{key}": val for key, val in accuracy.items() + } # type: Dict[str, Union[float, np.ndarray]] + metrics.update( + {f"{self.f_type}_{key}": val for key, val in fairness.items()} + ) + return metrics