diff --git a/README.md b/README.md index 0d4863d830bb7a934b8591e16a4076c6a5cabef9..c8e5c40aa3ada052b248558c33a482565cb842a3 100644 --- a/README.md +++ b/README.md @@ -388,6 +388,13 @@ that will run together to perform a federated learning process. Generic remarks from the [Quickstart](#quickstart) section hold here as well, the former section being an overly simple exemplification of the present one. +You can follow along on a concrete example that uses the UCI heart disease +dataset, which is stored in the `examples/heart-uci` folder. You may refer +to the `server.py` and `client.py` example scripts, which include comments +indicating how the code relates to the steps described below. For further +details on this example and on how to run it, please refer to its own +`readme.md` file. + #### Server setup instructions 1. Define a Model: diff --git a/examples/heart-uci/client.py b/examples/heart-uci/client.py new file mode 100644 index 0000000000000000000000000000000000000000..c547bca2c98f1cf43f1f0a4354762bf63b4f97d8 --- /dev/null +++ b/examples/heart-uci/client.py @@ -0,0 +1,95 @@ +"""Script to run a federated client on the heart-disease example.""" + +import argparse +import os +import sys + +import numpy as np +import pandas as pd # type: ignore +from declearn.communication import NetworkClientConfig +from declearn.dataset import InMemoryDataset +from declearn.main import FederatedClient + +FILEDIR = os.path.dirname(os.path.abspath(__file__)) +# Perform local imports. +sys.path.append(FILEDIR) +from data import get_data # pylint: disable=wrong-import-order + + +def run_client( + name: str, + ca_cert: str, +) -> None: + """Instantiate and run a given client. + + Arguments + --------- + name: str + Name of the client (i.e. center data from which to use). + ca_cert: str + Path to the certificate authority file that was used to + sign the server's SSL certificate. + """ + + # (1-2) Interface training and optional validation data. 
+ + # Load and randomly split the dataset. + path = os.path.join(FILEDIR, "data", f"{name}.csv") + if not os.path.isfile(path): + get_data(os.path.join(FILEDIR, "data"), [name]) + data = pd.read_csv(path) + data = data.loc[np.random.permutation(data.index)] + n_tr = round(len(data) * 0.8) # 80% train, 20% valid + + # Wrap train and validation data as Dataset objects. + train = InMemoryDataset( + data=data.iloc[:n_tr], + target="num", + expose_classes=True, # share unique target labels with server + ) + valid = InMemoryDataset( + data=data.iloc[n_tr:], + target="num", + ) + + # (3) Define network communication parameters. + + # Here, use websockets protocol on localhost:8765, with SSL encryption. + network = NetworkClientConfig( + protocol="websockets", + server_uri="wss://localhost:8765", + name=name, + certificate=ca_cert, + ) + + # (4) Run any necessary import statement. + # => None are required in this example. + + # (5) Instantiate a FederatedClient and run it. + + client = FederatedClient( + network, train, valid, folder=f"{FILEDIR}/results/{name}" + ) + client.run() + + +# Called when the script is called directly (using `python client.py`). +if __name__ == "__main__": + # Parse command-line arguments. + parser = argparse.ArgumentParser() + parser.add_argument( + "name", + type=str, + help="name of your client", + choices=["cleveland", "hungarian", "switzerland", "va"], + ) + parser.add_argument( + "--cert_path", + dest="cert_path", + type=str, + help="path to the certificate authority (CA) certificate file", + default=os.path.join(FILEDIR, "ca-cert.pem"), + ) + args = parser.parse_args() + # Run the client routine. 
+ run_client(args.name, args.cert_path) diff --git a/examples/heart-uci/data.py b/examples/heart-uci/data.py new file mode 100644 index 0000000000000000000000000000000000000000..74a137c688bff217caddb6dfc31e909079fd9bc8 --- /dev/null +++ b/examples/heart-uci/data.py @@ -0,0 +1,87 @@ +"""Script to download and pre-process the UCI Heart Disease Dataset.""" + +import argparse +import os +from typing import List + +import pandas as pd + +NAMES = ("cleveland", "hungarian", "switzerland", "va") + +COLNAMES = [ + "age", + "sex", + "cp", + "trestbps", + "chol", + "fbs", + "restecg", + "thalach", + "exang", + "oldpeak", + "slope", + "ca", + "thal", + "num", +] + +DATADIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data") + + +def get_data( + dir: str = DATADIR, + names: List[str] = NAMES, +) -> None: + """Download and process the UCI heart disease dataset. + + Arguments + --------- + dir: str + Path to the folder where to write output csv files. + names: list[str] + Names of the centers whose data should be downloaded, + pre-processed and exported as csv files (one per center). + """ + for name in names: + print(f"Downloading data from center {name}:") + url = ( + "https://archive.ics.uci.edu/ml/machine-learning-databases/" + f"heart-disease/processed.{name}.data" + ) + print(url) + # Download the dataset. + df = pd.read_csv(url, header=None, na_values="?") + df.columns = COLNAMES + # Drop unused columns and rows with missing values. + df.drop(columns=["ca", "chol", "fbs", "slope", "thal"], inplace=True) + df.dropna(inplace=True) + # Normalize quantitative variables. + for col in ("age", "trestbps", "thalach", "oldpeak"): + df[col] = (df[col] - df[col].mean()) / df[col].std() + # Binarize the target variable. + df["num"] = (df["num"] > 0).astype(int) + # Export the resulting dataset to a csv file. + os.makedirs(dir, exist_ok=True) + df.to_csv(f"{dir}/{name}.csv", index=False) + + +# Code executed when the script is called directly. 
+if __name__ == "__main__": + # Parse commandline parameters. + parser = argparse.ArgumentParser() + parser.add_argument( + "--dir", + type=str, + default=DATADIR, + help="folder where to write output csv files", + ) + parser.add_argument( + "names", + type=str, + nargs="+", # yields one flat list of center names + help="name(s) of client center(s), data from which to prepare", + choices=["cleveland", "hungarian", "switzerland", "va"], + ) + args = parser.parse_args() + # Download and pre-process the selected dataset(s). + get_data(dir=args.dir, names=args.names) diff --git a/examples/heart-uci/gen_ssl.py b/examples/heart-uci/gen_ssl.py new file mode 100644 index 0000000000000000000000000000000000000000..c233c610a8bdc49b4b81e346f865e61e9e91bf81 --- /dev/null +++ b/examples/heart-uci/gen_ssl.py @@ -0,0 +1,10 @@ +"""Script to generate self-signed SSL certificates for the demo.""" + +import os + +from declearn.test_utils import generate_ssl_certificates + + +if __name__ == "__main__": + FILEDIR = os.path.dirname(os.path.abspath(__file__)) + generate_ssl_certificates(FILEDIR) diff --git a/examples/heart-uci/readme.md b/examples/heart-uci/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..a0869b066b99bc791b212603c9ccecb6fd47f5a3 --- /dev/null +++ b/examples/heart-uci/readme.md @@ -0,0 +1,87 @@ +# Demo training task : heart disease prediction + +## Overview + +**We use data from the UCI ML repository** - Heart disease dataset, available +[here](https://archive.ics.uci.edu/ml/datasets/Heart+Disease). The goal is to +predict a binary variable, indicating heart disease, from a set of health +indicators. + +**To simply run the demo**, use the bash command below. You can follow along +the code in the `hands-on` section of the package documentation. For more +details on what running the federated learning processes implies, see the last +section. 
+ +```bash +python run.py +``` + +## Folder structure + +``` +heart-uci/ +│ client.py - set up and launch a federated-learning client +│ data.py - download and prepare the dataset +│ gen_ssl.py - generate self-signed ssl certificates +│ run.py - launch both the server and clients in a single session +│ server.py - set up and launch a federated-learning server +│ setup.sh - bash script to prepare client-wise and server isolated folders +└─── data - saved datasets as csv files +└─── results - saved results from training procedure +``` + +## Run training routine + +The simplest way to run the demo is to run it locally, using multiprocessing. +For something closer to a real-life implementation, we also show a way to run +the demo from different terminals or machines. + +### Locally, for testing and experimentation + +Use : + +```bash +python run.py # note: python examples/heart-uci/run.py works as well +``` + +The `run.py` script collects the server and client routines defined under +the `server.py` and `client.py` scripts, and runs them concurrently under +a single python session using multiprocessing. + +This is the easiest way to launch the demo, e.g. to see the effects of +tweaking some learning parameters. + +### On separate terminals or machines + +**To run the examples from different terminals or machines**, +we first ensure that data is appropriately distributed between machines, +and that the machines can communicate over the network using SSL-encrypted +communications. We give the code to simulate this on a single machine. + +We then sequentially run the server then the clients on separate terminals. + +1. **Set up self-signed SSL certificates**:<br/> + Start by executing the `gen_ssl.py` script. + This creates self-signed SSL certificates: + ```bash + python gen_ssl.py + ``` + Note that in real-life applications, one would most likely use certificates + signed by a trusted certificate authority instead. + +2. 
**Run the server**:<br/> + Open a terminal and launch the server script for 1 to 4 clients, + using the generated SSL certificates: + ```bash + python server.py 2 # use --help for details on SSL files options + ``` + +3. **Run each client**:<br/> + Open a new terminal and launch the client script, using one of the + dataset's location names and the generated SSL certificate, e.g.: + ```bash + python client.py cleveland # use --help for details on SSL files options + ``` + +Note that the server should be launched before the clients, otherwise the +latter would fail to connect, which might cause the script to terminate. diff --git a/examples/heart-uci/run.py b/examples/heart-uci/run.py new file mode 100644 index 0000000000000000000000000000000000000000..10bd93cf68edd8670ed76b82e8c032cdff834927 --- /dev/null +++ b/examples/heart-uci/run.py @@ -0,0 +1,41 @@ +"""Demonstration script using the UCI Heart Disease Dataset.""" + +import os +import sys +import tempfile + +# Perform local imports. +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from client import run_client # pylint: disable=wrong-import-position +from server import run_server # pylint: disable=wrong-import-position + +from declearn.test_utils import ( + generate_ssl_certificates, + run_as_processes, +) + + +NAMES = ["cleveland", "hungarian", "switzerland", "va"] + + +def run_demo( + nb_clients: int = 4, +) -> None: + """Run a server and its clients using multiprocessing.""" + # Use a temporary directory for single-use self-signed SSL files. + with tempfile.TemporaryDirectory() as folder: + # Generate self-signed SSL certificates and gather their paths. + ca_cert, sv_cert, sv_pkey = generate_ssl_certificates(folder) + # Specify the server and client routines that need executing. + server = (run_server, (nb_clients, sv_cert, sv_pkey)) + clients = [ + (run_client, (name, ca_cert)) for name in NAMES[:nb_clients] + ] + # Run routines in isolated processes. Raise if any failed. 
+ exitcodes = run_as_processes(server, *clients) + if any(code != 0 for code in exitcodes): + raise RuntimeError("Something went wrong during the demo.") + + +if __name__ == "__main__": + run_demo() diff --git a/examples/heart-uci/server.py b/examples/heart-uci/server.py new file mode 100644 index 0000000000000000000000000000000000000000..4f7ef5ff9dd432b925a9423f69ccc14bf3071307 --- /dev/null +++ b/examples/heart-uci/server.py @@ -0,0 +1,125 @@ +"""Script to run a federated server on the heart-disease example.""" + +import argparse +import os + +from declearn.communication import NetworkServerConfig +from declearn.main import FederatedServer +from declearn.model.sklearn import SklearnSGDModel +from declearn.optimizer.modules import MomentumModule, RMSPropModule +from declearn.strategy import strategy_from_config + + +FILEDIR = os.path.dirname(os.path.abspath(__file__)) + + +def run_server( + nb_clients: int, + sv_cert: str, + sv_priv: str, +) -> None: + """Instantiate and run the orchestrating server. + + Arguments + --------- + nb_clients: int + Exact number of clients used in this example. + sv_cert: str + Path to the (self-signed) SSL certificate to use. + sv_priv: str + Path to the associated private-key to use. + """ + + # (1) Define a model. + + # Here we use a scikit-learn SGD classifier and parametrize it + # into an L2-penalized binary logistic regression. + model = SklearnSGDModel.from_parameters( + kind="classifier", loss="log_loss", penalty="l2", alpha=0.005 + ) + + # (2) Define a strategy. + + # Configure the aggregator to use. + # Here, averaging weighted by the effective number + # of local gradient descent steps taken. + aggregator = { + "name": "Average", + "config": {"steps_weighted": True}, + } + + # Configure the client-side optimizer to use. + # Here, RMSProp optimizer with 0.02 learning rate. + client_opt = { + "lrate": 0.02, + "modules": [RMSPropModule()], + } + + # Configure the server-side optimizer to use. 
+ # Here, apply momentum to the updates and apply them (as lr=1.0). + server_opt = { + "lrate": 1.0, + "modules": [MomentumModule()], + } + + # Wrap this up into a Strategy object. + config = { + "aggregator": aggregator, + "client_opt": client_opt, + "server_opt": server_opt, + } + strategy = strategy_from_config(config) + + # (3) Define network communication parameters. + + # Here, use websockets protocol on localhost:8765, with SSL encryption. + network = NetworkServerConfig( + protocol="websockets", + host="localhost", + port=8765, + certificate=sv_cert, + private_key=sv_priv, + ) + + # (4) Instantiate and run a FederatedServer. + + server = FederatedServer(model, network, strategy) + # Here, we set up 20 rounds of training, with 30 samples per batch + # during training and 50 during validation; plus an early-stopping + # criterion if the global validation loss stops decreasing for 5 rounds. + server.run( + rounds=20, + regst_cfg={"min_clients": nb_clients}, + train_cfg={"batch_size": 30, "drop_remainder": False}, + valid_cfg={"batch_size": 50, "drop_remainder": False}, + early_stop={"tolerance": 0.0, "patience": 5, "relative": False}, + ) + + +# Called when the script is called directly (using `python server.py`). +if __name__ == "__main__": + # Parse command-line arguments. + parser = argparse.ArgumentParser() + parser.add_argument( + "nb_clients", + type=int, + help="number of clients", + choices=[1, 2, 3, 4], + ) + parser.add_argument( + "--cert_path", + dest="cert_path", + type=str, + help="path to the server-side ssl certificate", + default=os.path.join(FILEDIR, "server-cert.pem"), + ) + parser.add_argument( + "--key_path", + dest="key_path", + type=str, + help="path to the server-side ssl private key", + default=os.path.join(FILEDIR, "server-key.pem"), + ) + args = parser.parse_args() + # Run the server routine. + run_server(args.nb_clients, args.cert_path, args.key_path)