Verified Commit c8bf5b93 authored by ANDREY Paul
Add integration tests for fairness-aware federated learning.

parent d9a9a276
Merge request !69: Enable Fairness-Aware Federated Learning
Pipeline #1010682 failed
# coding: utf-8
# Copyright 2023 Inria (Institut National de Recherche en Informatique
# et Automatique)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Integration test using fairness algorithms (and opt. SecAgg) on toy data.
* Set up a toy classification dataset with a sensitive attribute, and
some client heterogeneity.
* Run a federated learning experiment...
"""
import asyncio
import os
from typing import List, Optional, Tuple
import numpy as np
import pandas as pd
import pytest
from declearn.dataset.utils import split_multi_classif_dataset
from declearn.fairness.api import FairnessControllerServer
from declearn.fairness.core import FairnessInMemoryDataset
from declearn.fairness.fairbatch import FairbatchControllerServer
from declearn.fairness.fairfed import FairfedControllerServer
from declearn.fairness.fairgrad import FairgradControllerServer
from declearn.fairness.monitor import FairnessMonitorServer
from declearn.main import FederatedClient, FederatedServer
from declearn.main.config import FLRunConfig
from declearn.model.sklearn import SklearnSGDModel
from declearn.secagg.utils import IdentityKeys
from declearn.test_utils import (
MockNetworkClient,
MockNetworkServer,
make_importable,
)
with make_importable(os.path.dirname(__file__)):
from test_toy_clf_secagg import setup_masking_idkeys
SEED = 0
def generate_toy_dataset(
n_train: int = 300,
n_valid: int = 150,
n_clients: int = 3,
) -> List[Tuple[FairnessInMemoryDataset, FairnessInMemoryDataset]]:
"""Generate datasets to a toy fairness-aware classification problem."""
# Generate a toy classification dataset with a sensitive attribute.
n_samples = n_train + n_valid
inputs, s_attr, target = _generate_toy_data(n_samples)
# Split samples uniformly across clients, with 80%/20% train/valid splits.
shards = split_multi_classif_dataset(
dataset=(np.concatenate([inputs, s_attr], axis=1), target.ravel()),
n_shards=n_clients,
scheme="iid",
p_valid=0.2,
seed=SEED,
)
# Wrap the resulting data as fairness in-memory datasets and return them.
return [
(
FairnessInMemoryDataset(
# fmt: off
data=x_train[:, :-1], s_attr=x_train[:, -1:], target=y_train,
expose_classes=True,
),
FairnessInMemoryDataset(
data=x_valid[:, :-1], s_attr=x_valid[:, -1:], target=y_valid
),
)
for (x_train, y_train), (x_valid, y_valid) in shards
]
def _generate_toy_data(
n_samples: int = 100,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Build a toy classification dataset with a binary sensitive attribute.
- Draw random normal features X, random coefficients B and random noise N.
- Compute L = XB + N, min-max normalize it into [0, 1] probabilities P.
- Draw random binary sensitive attribute values S.
- Define Y = 1{P >= 0.8}*1{S == 1} + 1{P >= 0.5}*1{S == 0}.
Return X, S and Y matrices, as numpy arrays.
"""
rng = np.random.default_rng(SEED)
x_dat = rng.normal(size=(n_samples, 10), scale=10.0)
s_dat = rng.choice(2, size=(n_samples, 1))
theta = rng.normal(size=(10, 1), scale=5.0)
noise = rng.normal(size=(n_samples, 1), scale=5.0)
logit = np.matmul(x_dat, theta) + noise
y_dat = (logit - logit.min()) / (logit.max() - logit.min())
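    # Use a group-dependent threshold (0.8 if S == 1 else 0.5), so that the
    # sensitive attribute directly affects label assignment.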
y_dat = (y_dat >= np.where(s_dat == 1, 0.8, 0.5)).astype("float32")
return x_dat.astype("float32"), s_dat.astype("float32"), y_dat
async def server_routine(
fairness: FairnessControllerServer,
secagg: bool,
folder: str,
n_clients: int = 3,
) -> None:
"""Run the FL routine of the server."""
model = SklearnSGDModel.from_parameters(
kind="classifier",
loss="log_loss",
penalty="none",
dtype="float32",
)
netwk = MockNetworkServer(
host="localhost",
port=8765,
heartbeat=0.1,
)
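    # Vanilla-SGD optimization setup: the scalar values below stand for
    # client- and server-side learning rates, while the parametrized
    # fairness controller is plugged into the federated process.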
optim = {
"client_opt": 0.05,
"server_opt": 1.0,
"fairness": fairness,
}
server = FederatedServer(
model,
netwk=netwk,
optim=optim,
metrics=["binary-classif"],
secagg={"secagg_type": "masking"} if secagg else None,
checkpoint={"folder": folder, "max_history": 1},
)
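    # Run 5 rounds, each comprising 1 epoch over batches of 10 samples,
    # with dedicated fairness-evaluation rounds using batches of 50.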
config = FLRunConfig.from_params(
rounds=5,
register={"min_clients": n_clients, "timeout": 2},
training={"n_epoch": 1, "batch_size": 10},
fairness={"batch_size": 50},
)
await server.async_run(config)
async def client_routine(
train_dst: FairnessInMemoryDataset,
valid_dst: FairnessInMemoryDataset,
id_keys: Optional[IdentityKeys],
) -> None:
"""Run the FL routine of a given client."""
netwk = MockNetworkClient(
server_uri="mock://localhost:8765",
name="client",
)
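    # Opt into masking-based SecAgg when identity keys are provided.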
secagg = (
{"secagg_type": "masking", "id_keys": id_keys} if id_keys else None
)
client = FederatedClient(
netwk=netwk,
train_data=train_dst,
valid_data=valid_dst,
verbose=False,
secagg=secagg,
)
await client.async_run()
@pytest.fixture(name="fairness")
def fairness_fixture(
algorithm: str,
f_type: str,
) -> FairnessControllerServer:
"""Server-side fairness controller providing fixture."""
if algorithm == "fairbatch":
return FairbatchControllerServer(f_type, alpha=0.005, fedfb=False)
if algorithm == "fairfed":
return FairfedControllerServer(f_type, beta=1.0, strict=True)
if algorithm == "fairgrad":
return FairgradControllerServer(f_type, eta=0.5, eps=1e-6)
return FairnessMonitorServer(f_type)
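
# Parametrize the test over four algorithms (three fairness-enforcing ones
# plus a mere monitoring baseline), two group-fairness definitions, and the
# optional use of SecAgg, i.e. 4 * 2 * 2 = 16 scenarios.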
@pytest.mark.parametrize("secagg", [False, True], ids=["clrtxt", "secagg"])
@pytest.mark.parametrize("f_type", ["demographic_parity", "equalized_odds"])
@pytest.mark.parametrize(
"algorithm", ["fairbatch", "fairfed", "fairgrad", "monitor"]
)
@pytest.mark.asyncio
async def test_toy_classif_fairness(
fairness: FairnessControllerServer,
secagg: bool,
tmp_path: str,
) -> None:
"""Test a given fairness-aware federated learning algorithm on toy data.
Set up a toy dataset for fairness-aware federated learning.
Use a given algorithm, with a given group-fairness definition.
Optionally use SecAgg.
Verify that after training for 10 rounds, the learned model achieves
some accuracy and has become fairer that after 5 rounds.
"""
# Set up the toy dataset and optional identity keys for SecAgg.
datasets = generate_toy_dataset(n_clients=3)
clients_id_keys = setup_masking_idkeys(secagg, n_clients=3)
# Set up and run the fairness-aware federated learning experiment.
coro_server = server_routine(fairness, secagg, folder=tmp_path)
coro_clients = [
client_routine(train_dst, valid_dst, id_keys)
for (train_dst, valid_dst), id_keys in zip(datasets, clients_id_keys)
]
outputs = await asyncio.gather(
coro_server, *coro_clients, return_exceptions=True
)
# Assert that no exceptions occurred during the process.
errors = "\n".join(
repr(exc) for exc in outputs if isinstance(exc, Exception)
)
assert not errors, f"The FL process failed:\n{errors}"
# Load and parse utility and fairness metrics at the final round.
u_metrics = pd.read_csv(os.path.join(tmp_path, "metrics.csv"))
f_metrics = pd.read_csv(os.path.join(tmp_path, "fairness_metrics.csv"))
accuracy = u_metrics.iloc[-1]["accuracy"]
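    # Compute the average absolute value of group-wise fairness scores.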
fairness_cols = [f"{fairness.f_type}_{group}" for group in fairness.groups]
fairness_mean_abs = f_metrics.iloc[-1][fairness_cols].abs().mean()
# Verify that the FedAvg baseline matches expected accuracy and fairness,
# or that other algorithms achieve lower accuracy and better fairness.
# Note that FairFed is bound to match the FedAvg baseline due to the
# split across clients being uniform.
expected_fairness = {
"demographic_parity": 0.02,
"equalized_odds": 0.11,
}
if fairness.algorithm == "monitor":
assert accuracy >= 0.76
assert fairness_mean_abs > expected_fairness[fairness.f_type]
elif fairness.algorithm == "fairfed":
assert accuracy >= 0.72
assert fairness_mean_abs > expected_fairness[fairness.f_type]
else:
assert 0.76 > accuracy > 0.54
assert fairness_mean_abs < expected_fairness[fairness.f_type]

test_toy_clf_secagg.py
@@ -27,6 +27,7 @@ import asyncio
 import json
 import os
 import tempfile
+import warnings
 from typing import List, Optional, Tuple, Union

 import pytest
@@ -147,7 +148,9 @@ async def async_run_server(
         register={"min_clients": n_clients, "timeout": 2},
         training={"n_epoch": 1, "batch_size": 1, "drop_remainder": False},
     )
-    await server.async_run(config)
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore", RuntimeWarning)
+        await server.async_run(config)


 async def async_run_client(
@@ -170,6 +173,18 @@ async def async_run_client(
     await client.async_run()


+def setup_masking_idkeys(
+    secagg: bool,
+    n_clients: int,
+) -> Union[List[IdentityKeys], List[None]]:
+    """Setup identity keys for SecAgg, or a list of None values."""
+    if not secagg:
+        return [None for _ in range(n_clients)]
+    prv_keys = [Ed25519PrivateKey.generate() for _ in range(n_clients)]
+    pub_keys = [key.public_key() for key in prv_keys]
+    return [IdentityKeys(key, trusted=pub_keys) for key in prv_keys]
+
+
 async def run_declearn_experiment(
     scaffold: bool,
     secagg: bool,
@@ -197,14 +212,7 @@ async def run_declearn_experiment(
     """
     # Set up the toy dataset(s) and optional identity keys (for SecAgg).
     n_clients = len(datasets)
-    if secagg:
-        prv_keys = [Ed25519PrivateKey.generate() for _ in range(n_clients)]
-        pub_keys = [key.public_key() for key in prv_keys]
-        id_keys = [
-            IdentityKeys(key, trusted=pub_keys) for key in prv_keys
-        ]  # type: Union[List[IdentityKeys], List[None]]
-    else:
-        id_keys = [None for _ in range(n_clients)]
+    id_keys = setup_masking_idkeys(secagg=secagg, n_clients=n_clients)
     with tempfile.TemporaryDirectory() as folder:
         # Set up the server and client coroutines.
         coro_server = async_run_server(folder, scaffold, secagg, n_clients)