Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 9e9efb35 authored by BIGAUD Nathan's avatar BIGAUD Nathan Committed by ANDREY Paul
Browse files

Implement UCI heart disease dataset example.


Co-authored-by: default avatarPaul Andrey <paul.andrey@inria.fr>
parent 06250995
No related branches found
No related tags found
1 merge request!5Add examples folder
Pipeline #649056 passed
......@@ -388,6 +388,13 @@ that will run together to perform a federated learning process. Generic
remarks from the [Quickstart](#quickstart) section hold here as well, the
former section being an overly simple exemplification of the present one.
You can follow along on a concrete example that uses the UCI heart disease
dataset, that is stored in the `examples/uci-heart` folder. You may refer
to the `server.py` and `client.py` example scripts, that comprise comments
indicating how the code relates to the steps described below. For further
details on this example and on how to run it, please refer to its own
`readme.md` file.
#### Server setup instructions
1. Define a Model:
......
"""Script to run a federated client on the heart-disease example."""
import argparse
import os
import sys
import numpy as np
import pandas as pd # type: ignore
from declearn.communication import NetworkClientConfig
from declearn.dataset import InMemoryDataset
from declearn.main import FederatedClient
FILEDIR = os.path.dirname(os.path.abspath(__file__))
# Perform local imports.
sys.path.append(FILEDIR)
from data import get_data # pylint: disable=wrong-import-order
def run_client(
name: str,
ca_cert: str,
) -> None:
"""Instantiate and run a given client.
Arguments
---------
name: str
Name of the client (i.e. center data from which to use).
ca_cert: str
Path to the certificate authority file that was used to
sign the server's SSL certificate.
"""
# (1-2) Interface training and optional validation data.
# Load and randomly split the dataset.
path = os.path.join(FILEDIR, f"data/{name}.csv")
if not os.path.isfile(path):
get_data(os.path.join(FILEDIR, "data"), [name])
data = pd.read_csv(path)
data = data.loc[np.random.permutation(data.index)]
n_tr = round(len(data) * 0.8) # 80% train, 20% valid
# Wrap train and validation data as Dataset objects.
train = InMemoryDataset(
data=data.iloc[:n_tr],
target="num",
expose_classes=True, # share unique target labels with server
)
valid = InMemoryDataset(
data=data.iloc[n_tr:],
target="num",
)
# (3) Define network communication parameters.
# Here, use websockets protocol on localhost:8765, with SSL encryption.
network = NetworkClientConfig(
protocol="websockets",
server_uri="wss://localhost:8765",
name=name,
certificate=ca_cert,
)
# (4) Run any necessary import statement.
# => None are required in this example.
# (5) Instantiate a FederatedClient and run it.
client = FederatedClient(
network, train, valid, folder=f"{FILEDIR}/results/{name}"
)
client.run()
# Called when the script is called directly (using `python client.py`).
if __name__ == "__main__":
# Parse command-line arguments.
parser = argparse.ArgumentParser()
parser.add_argument(
"name",
type=str,
help="name of your client",
choices=["cleveland", "hungarian", "switzerland", "va"],
)
parser.add_argument(
"--cert_path",
dest="cert_path",
type=str,
help="path to the client-side ssl certification",
default=os.path.join(FILEDIR, "ca-cert.pem"),
)
args = parser.parse_args()
# Run the client routine.
run_client(args.name, args.cert_path)
"""Script to download and pre-process the UCI Heart Disease Dataset."""
import argparse
import os
from typing import List
import pandas as pd
NAMES = ("cleveland", "hungarian", "switzerland", "va")
COLNAMES = [
"age",
"sex",
"cp",
"trestbps",
"chol",
"fbs",
"restecg",
"thalach",
"exang",
"oldpeak",
"slope",
"ca",
"thal",
"num",
]
DATADIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
def get_data(
dir: str = DATADIR,
names: List[str] = NAMES,
) -> None:
"""Download and process the UCI heart disease dataset.
Arguments
---------
dir: str
Path to the folder where to write output csv files.
names: list[str]
Names of centers, the dataset from which to download,
pre-process and export as csv files.
"""
for name in names:
print(f"Downloading data from center {name}:")
url = (
"https://archive.ics.uci.edu/ml/machine-learning-databases/"
f"heart-disease/processed.{name}.data"
)
print(url)
# Download the dataset.
df = pd.read_csv(url, header=None, na_values="?")
df.columns = COLNAMES
# Drop unused columns and rows with missing values.
df.drop(columns=["ca", "chol", "fbs", "slope", "thal"], inplace=True)
df.dropna(inplace=True)
# Normalize quantitative variables.
for col in ("age", "trestbps", "thalach", "oldpeak"):
df[col] = (df[col] - df[col].mean()) / df[col].std()
# Binarize the target variable.
df["num"] = (df["num"] > 0).astype(int)
# Export the resulting dataset to a csv file.
os.makedirs(dir, exist_ok=True)
df.to_csv(f"{dir}/{name}.csv", index=False)
# Code executed when the script is called directly.
if __name__ == "__main__":
# Parse commandline parameters.
parser = argparse.ArgumentParser()
parser.add_argument(
"--dir",
type=str,
default=DATADIR,
help="folder where to write output csv files",
)
parser.add_argument(
"names",
action="append",
nargs="+",
help="name(s) of client center(s), data from which to prepare",
choices=["cleveland", "hungarian", "switzerland", "va"],
)
args = parser.parse_args()
# Download and pre-process the selected dataset(s).
get_data(dir=args.dir, names=args.names)
"""Script to generate self-signed SSL certificates for the demo."""
import os
from declearn.test_utils import generate_ssl_certificates
if __name__ == "__main__":
FILEDIR = os.path.dirname(os.path.abspath(__file__))
generate_ssl_certificates(FILEDIR)
# Demo training task : heart disease prediction
## Overview
**We use data from the UCI ML repository** - Heart disease dataset, available
[here](https://archive.ics.uci.edu/ml/datasets/Heart+Disease). The goal is to
predict a binary variable, indicating heart disease, from a set of health
indicators.
**To simply run the demo**, use the bash command below. You can follow along
the code in the `hands-on` section of the package documentation. For more
details on what running the federated learning processes imply, see the last
section.
```bash
python run.py
```
## Folder structure
```
heart-uci/
│ client.py - set up and launch a federated-learning client
│ data.py - download and preapte the dataset
│ gen_ssl.py - generate self-signed ssl certificates
│ run.py - launch both the server and clients in a single session
│ server.py - set up and launch a federated-learning server
| setup.sh - bash script to prepare client-wise and server isolated folders
└─── data - saved datasets as csv files
└─── results - saved results from training procedure
```
## Run training routine
The simplest way to run the demo is to run it locally, using multiprocessing.
For something closer to real life implementation, we also show a way to run
the demo from different terminals or machines.
### Locally, for testing and experimentation
Use :
```bash
python run.py # note: python examples/heart-uci/run.py works as well
```
The `run.py` scripts collects the server and client routines defined under
the `server.py` and `client.py` scripts, and runs them concurrently under
a single python session using multiprocessing.
This is the easiest way to launch the demo, e.g. to see the effects of
tweaking some learning parameters.
### On separate terminals or machines
**To run the examples from different terminals or machines**,
We first ensure data is appropriately distributed between machines,
and the machines can communicate over network using SSL-encrypted
communications. We give the code to simulate this on a single machine.
We then sequentially run the server then the clients on separate terminals.
1. **Set up self-signed SSL certificates**:<br/>
Start by running executing the `gen_ssl.py` script.
This creates self-signed SSL certificates:
```bash
python gen_ssl.py
```
Note that in real-life applications, one would most likely use certificates
signed by a trusted certificate authority instead.
2. **Run the server**:<br/>
Open a terminal and launch the server script for 1 to 4 clients,
using the generated SSL certificates:
```bash
python server.py 2 # use --help for details on SSL files options
```
3. **Run each client**:<br/>
Open a new terminal and launch the client script, using one of the
dataset's location name and the generated SSL certificate, e.g.:
```bash
python client.py cleveland # use --help for details on SSL files options
```
Note that the server should be launched before the clients, otherwise the
latter would fail to connect which might cause the script to terminate.
"""Demonstration script using the UCI Heart Disease Dataset."""
import os
import sys
import tempfile
# Perform local imports.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from client import run_client # pylint: disable=wrong-import-position
from server import run_server # pylint: disable=wrong-import-position
from declearn.test_utils import (
generate_ssl_certificates,
run_as_processes,
)
NAMES = ["cleveland", "hungarian", "switzerland", "va"]
def run_demo(
nb_clients: int = 4,
) -> None:
"""Run a server and its clients using multiprocessing."""
# Use a temporary directory for single-use self-signed SSL files.
with tempfile.TemporaryDirectory() as folder:
# Generate self-signed SSL certificates and gather their paths.
ca_cert, sv_cert, sv_pkey = generate_ssl_certificates(folder)
# Specify the server and client routines that need executing.
server = (run_server, (nb_clients, sv_cert, sv_pkey))
clients = [
(run_client, (name, ca_cert)) for name in NAMES[:nb_clients]
]
# Run routines in isolated processes. Raise if any failed.
exitcodes = run_as_processes(server, *clients)
if any(code != 0 for code in exitcodes):
raise RuntimeError("Something went wrong during the demo.")
if __name__ == "__main__":
run_demo()
"""Script to run a federated server on the heart-disease example."""
import argparse
import os
from declearn.communication import NetworkServerConfig
from declearn.main import FederatedServer
from declearn.model.sklearn import SklearnSGDModel
from declearn.optimizer.modules import MomentumModule, RMSPropModule
from declearn.strategy import strategy_from_config
FILEDIR = os.path.dirname(os.path.abspath(__file__))
def run_server(
nb_clients: int,
sv_cert: str,
sv_priv: str,
) -> None:
"""Instantiate and run the orchestrating server.
Arguments
---------
nb_clients: int
Exact number of clients used in this example.
sv_cert: str
Path to the (self-signed) SSL certificate to use.
sv_priv: str
Path to the associated private-key to use.
"""
# (1) Define a model
# Here we use a scikit-learn SGD classifier and parametrize it
# into a L2-penalized binary logistic regression.
model = SklearnSGDModel.from_parameters(
kind="classifier", loss="log_loss", penalty="l2", alpha=0.005
)
# (2) Define a strategy
# Configure the aggregator to use.
# Here, averaging weighted by the effective number
# of local gradient descent steps taken.
aggregator = {
"name": "Average",
"config": {"steps_weighted": True},
}
# Configure the client-side optimizer to use.
# Here, RMSProp optimizer with 0.02 learning rate.
client_opt = {
"lrate": 0.02,
"modules": [RMSPropModule()],
}
# Configure the server-side optimizer to use.
# Here, apply momentum to the updates and apply them (as lr=1.0).
server_opt = {
"lrate": 1.0,
"modules": [MomentumModule()],
}
# Wrap this up into a Strategy object$
config = {
"aggregator": aggregator,
"client_opt": client_opt,
"server_opt": server_opt,
}
strategy = strategy_from_config(config)
# (3) Define network communication parameters.
# Here, use websockets protocol on localhost:8765, with SSL encryption.
network = NetworkServerConfig(
protocol="websockets",
host="localhost",
port=8765,
certificate=sv_cert,
private_key=sv_priv,
)
# (4) Instantiate and run a FederatedServer.
server = FederatedServer(model, network, strategy)
# Here, we setup 20 rounds of training, with 30 samples per batch
# during training and 50 during validation; plus an early-stopping
# criterion if the global validation loss stops decreasing for 5 rounds.
server.run(
rounds=20,
regst_cfg={"min_clients": nb_clients},
train_cfg={"batch_size": 30, "drop_remainder": False},
valid_cfg={"batch_size": 50, "drop_remainder": False},
early_stop={"tolerance": 0.0, "patience": 5, "relative": False},
)
# Called when the script is called directly (using `python server.py`).
if __name__ == "__main__":
# Parse command-line arguments.
parser = argparse.ArgumentParser()
parser.add_argument(
"nb_clients",
type=int,
help="number of clients",
choices=[1, 2, 3, 4],
)
parser.add_argument(
"--cert_path",
dest="cert_path",
type=str,
help="path to the server-side ssl certification",
default=os.path.join(FILEDIR, "server-cert.pem"),
)
parser.add_argument(
"--key_path",
dest="key_path",
type=str,
help="path to the server-side ssl private key",
default=os.path.join(FILEDIR, "server-key.pem"),
)
args = parser.parse_args()
# Run the server routine.
run_server(args.nb_clients, args.cert_path, args.key_path)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment