Verified Commit 747399db authored by ANDREY Paul

Implement monitoring-only fairness controllers.

parent 23feb7d3
1 merge request: !69 Enable Fairness-Aware Federated Learning
declearn/fairness/__init__.py

@@ -34,6 +34,8 @@ Algorithms submodules
     FairFed algorithm controllers and utils.
 * [fairgrad][declearn.fairness.fairgrad]:
     Fed-FairGrad algorithm controllers and utils.
+* [monitor][declearn.fairness.monitor]:
+    Fairness-monitoring controllers that leave training unaltered.
 """
 
 from . import api
@@ -41,3 +43,4 @@ from . import core
 from . import fairbatch
 from . import fairfed
 from . import fairgrad
+from . import monitor

declearn/fairness/monitor/__init__.py

# coding: utf-8
# Copyright 2023 Inria (Institut National de Recherche en Informatique
# et Automatique)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fairness-monitoring controllers, that leave training unaltered.
Introduction
------------
This submodule implements dummy fairness-aware learning controllers, that
implement fairness metrics' computation, hence enabling their monitoring
throughout training, without altering the model's training process itself.
These controllers may therefore be used to monitor fairness metrics of any
baseline federated learning algorithm, notably for comparison purposes with
fairness-aware algorithms implemented using other controllers (FairBatch,
Fed-FairGrad, ...).
Controllers
-----------
* [FairnessMonitorClient][declearn.fairness.monitor.FairnessMonitorClient]:
Client-side controller to monitor fairness without altering training.
* [FairnessMonitorServer][declearn.fairness.monitor.FairnessMonitorServer]:
Server-side controller to monitor fairness without altering training.
"""
from ._client import FairnessMonitorClient
from ._server import FairnessMonitorServer
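
For orientation, here is a minimal usage sketch of the two controllers
exported above. It is hypothetical and only uses names visible in this
commit: the "accuracy_parity" fairness function name is borrowed from the
temporary default in _server.py below, and the wiring of these controllers
into an actual federated learning process is left to declearn's fairness
API (not shown here).

# Hedged sketch: `training_manager` would be a pre-built TrainingManager
# wrapping a model and a FairnessDataset (hence it is left commented out).
from declearn.fairness.monitor import (
    FairnessMonitorClient,
    FairnessMonitorServer,
)

# Server side: monitor a group-fairness function without altering training.
server_ctrl = FairnessMonitorServer(
    f_type="accuracy_parity",
    f_args=None,
)

# Client side: mirror the server-side configuration.
# client_ctrl = FairnessMonitorClient(
#     manager=training_manager,
#     f_type="accuracy_parity",
#     f_args={},
# )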

declearn/fairness/monitor/_client.py

# coding: utf-8
# Copyright 2023 Inria (Institut National de Recherche en Informatique
# et Automatique)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Client-side controller to monitor fairness without altering training."""
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
from declearn.secagg.api import Encrypter
from declearn.communication.api import NetworkClient
from declearn.fairness.api import (
FairnessControllerClient,
instantiate_fairness_function,
)
from declearn.training import TrainingManager
__all__ = [
"FairnessMonitorClient",
]
class FairnessMonitorClient(FairnessControllerClient):
"""Client-side controller to monitor fairness without altering training."""
algorithm = "monitor"
def __init__(
self,
manager: TrainingManager,
f_type: str,
f_args: Dict[str, Any],
) -> None:
"""Instantiate the client-side fairness controller.
Parameters
----------
manager:
`TrainingManager` instance wrapping the model being trained
and its training dataset (that must be a `FairnessDataset`).
f_type:
Name of the type of group-fairness function being monitored.
f_args:
Keyword arguments to the group-fairness function.
"""
super().__init__(manager)
self.fairness_function = instantiate_fairness_function(
f_type=f_type, counts=self.computer.counts, **f_args
)
async def finalize_fairness_setup(
self,
netwk: NetworkClient,
secagg: Optional[Encrypter],
) -> None:
pass

    def compute_fairness_measures(
        self,
        batch_size: int,
        n_batch: Optional[int] = None,
        thresh: Optional[float] = None,
    ) -> Tuple[List[float], List[float]]:
        # Compute group-wise accuracy scores.
        accuracy = self.computer.compute_groupwise_accuracy(
            model=self.manager.model,
            batch_size=batch_size,
            n_batch=n_batch,
            thresh=thresh,
        )
        # Flatten local values for post-processing and checkpointing.
        local_values = list(accuracy.values())
        # Scale local values by sample counts for their aggregation.
        accuracy = self.computer.scale_metrics_by_sample_counts(accuracy)
        # Flatten shareable values, ordered over all defined groups and
        # zero-filled for groups without local samples.
        share_values = [accuracy.get(group, 0.0) for group in self.groups]
        # Return both sets of values.
        return share_values, local_values

    async def finalize_fairness_round(
        self,
        netwk: NetworkClient,
        values: List[float],
        secagg: Optional[Encrypter],
    ) -> Dict[str, Union[float, np.ndarray]]:
        # Recover raw accuracy scores for groups with local samples.
        accuracy = dict(zip(self.computer.g_data, values))
        # Compute local fairness measures.
        fairness = self.fairness_function.compute_from_group_accuracy(accuracy)
        f_type = self.fairness_function.f_type
        # Package and return accuracy and fairness metrics.
        metrics = {
            f"accuracy_{key}": val for key, val in accuracy.items()
        }  # type: Dict[str, Union[float, np.ndarray]]
        metrics.update(
            {f"{f_type}_{key}": val for key, val in fairness.items()}
        )
        return metrics
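
As a side note on compute_fairness_measures: scaling the group-wise
accuracies by sample counts before sharing them makes the shared values
summable across clients (e.g. via secure aggregation). The toy snippet
below, which is not declearn code and uses made-up numbers, illustrates
the arithmetic for a single sensitive group:

# Hypothetical per-client sample counts and raw group accuracies.
counts = {"client_a": 80, "client_b": 20}
accuracy = {"client_a": 0.90, "client_b": 0.50}

# Each client would share its count-scaled accuracy (one float per group).
shared = {name: counts[name] * accuracy[name] for name in counts}

# Summing shared values and dividing by the total sample count yields the
# accuracy over the union of both clients' samples: (72 + 10) / 100 = 0.82.
federated_accuracy = sum(shared.values()) / sum(counts.values())
assert abs(federated_accuracy - 0.82) < 1e-12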

declearn/fairness/monitor/_server.py

# coding: utf-8
# Copyright 2023 Inria (Institut National de Recherche en Informatique
# et Automatique)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Server-side controller to monitor fairness without altering training."""
from typing import Any, Dict, List, Optional, Union
import numpy as np
from declearn.aggregator import Aggregator
from declearn.secagg.api import Decrypter
from declearn.communication.api import NetworkServer
from declearn.fairness.api import (
FairnessControllerServer,
instantiate_fairness_function,
)
__all__ = [
"FairnessMonitorServer",
]
class FairnessMonitorServer(FairnessControllerServer):
"""Server-side controller to monitor fairness without altering training."""
algorithm = "monitor"
def __init__(
self,
f_type: str,
f_args: Optional[Dict[str, Any]],
) -> None:
super().__init__(f_type, f_args)
# Assign a temporary fairness functions, replaced at setup time.
self.function = instantiate_fairness_function(
f_type="accuracy_parity", counts={}
)

    async def finalize_fairness_setup(
        self,
        netwk: NetworkServer,
        counts: List[int],
        aggregator: Aggregator,
    ) -> Aggregator:
        # Replace the temporary fairness function with the configured one,
        # now that group-wise sample counts are known.
        self.function = instantiate_fairness_function(
            f_type=self.f_type,
            counts=dict(zip(self.groups, counts)),
            **self.f_args,
        )
        # Monitoring does not alter the updates' aggregation scheme.
        return aggregator

    async def finalize_fairness_round(
        self,
        round_i: int,
        values: List[float],
        netwk: NetworkServer,
        secagg: Optional[Decrypter],
    ) -> Dict[str, Union[float, np.ndarray]]:
        # Unpack group-wise accuracy metrics and compute fairness ones.
        accuracy = dict(zip(self.groups, values))
        fairness = self.function.compute_from_federated_group_accuracy(
            accuracy
        )
        # Package and return these metrics.
        metrics = {
            f"accuracy_{key}": val for key, val in accuracy.items()
        }  # type: Dict[str, Union[float, np.ndarray]]
        metrics.update(
            {f"{self.f_type}_{key}": val for key, val in fairness.items()}
        )
        return metrics
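
To make the returned format concrete, the toy lines below, with made-up
group labels and values, mimic the dictionary built by
finalize_fairness_round: one "accuracy_<group>" entry per sensitive group,
plus one "<f_type>_<group>" entry per fairness value.

# Hypothetical inputs standing in for synchronized federated metrics.
f_type = "accuracy_parity"
accuracy = {"group_0": 0.82, "group_1": 0.74}
fairness = {"group_0": 0.04, "group_1": -0.04}

# Same packaging logic as in the method above.
metrics = {f"accuracy_{key}": val for key, val in accuracy.items()}
metrics.update({f"{f_type}_{key}": val for key, val in fairness.items()})
# -> {"accuracy_group_0": 0.82, "accuracy_group_1": 0.74,
#     "accuracy_parity_group_0": 0.04, "accuracy_parity_group_1": -0.04}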