From 8ac8411a47e97e72c9a639747a8af5e6bec62c1c Mon Sep 17 00:00:00 2001 From: BIGAUD Nathan <nathan.bigaud@inria.fr> Date: Fri, 14 Apr 2023 11:38:06 +0200 Subject: [PATCH] Local MNIST example --- examples/mnist/client.py | 136 +++++++++++++++++++++++++++ examples/mnist/gen_ssl.py | 27 ++++++ examples/mnist/readme.md | 137 +++++++++++++++++++++++++++ examples/mnist/run.py | 72 ++++++++++++++ examples/mnist/server.py | 191 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 563 insertions(+) create mode 100644 examples/mnist/client.py create mode 100644 examples/mnist/gen_ssl.py create mode 100644 examples/mnist/readme.md create mode 100644 examples/mnist/run.py create mode 100644 examples/mnist/server.py diff --git a/examples/mnist/client.py b/examples/mnist/client.py new file mode 100644 index 00000000..a727a4a3 --- /dev/null +++ b/examples/mnist/client.py @@ -0,0 +1,136 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Script to run a federated client on the heart-disease example.""" + +import datetime +import logging +import os + +import declearn +import declearn.model.tensorflow +import fire + +FILEDIR = os.path.dirname(__file__) + + +def run_client( + client_name: str, + ca_cert: str, + data_folder: str, + protocol: str = "websockets", + serv_uri: str = "wss://localhost:8765", + verbose: bool = True, +) -> None: + + """Instantiate and run a given client. + + Parameters + --------- + client_name: str + Name of the client (i.e. center data from which to use). + ca_cert: str + Path to the certificate authority file that was used to + sign the server's SSL certificate. + data_folder: str + The parent folder of this client's data + protocol: str, default="websockets" + Name of the communication protocol to use. + serv_uri: str, default="wss://localhost:8765" + URI of the server to which to connect. + verbose: + Whether to log everything to the console, or filter out most non-error + information. + """ + + ### Optional: some convenience settings + + # Set CPU as device + declearn.utils.set_device_policy(gpu=False) + + # Set up logger and checkpointer + stamp = datetime.datetime.now() + stamp = stamp - datetime.timedelta( + minutes=stamp.minute % 5, + seconds=stamp.second, + microseconds=stamp.microsecond, + ) + stamp = stamp.strftime("%y-%m-%d_%H-%M") + checkpoint = os.path.join(FILEDIR, f"result_{stamp}", client_name) + logger = declearn.utils.get_logger( + name=client_name, + fpath=os.path.join(checkpoint, "logs.txt"), + ) + + # Reduce logger verbosity + if not verbose: + for handler in logger.handlers: + if isinstance(handler, logging.StreamHandler): + handler.setLevel(declearn.utils.LOGGING_LEVEL_MAJOR) + + ### (1-2) Interface training and optional validation data. + + # Target the proper dataset (specific to our MNIST setup). + data_folder = os.path.join(FILEDIR, data_folder, client_name) + + # Interface the data through the generic `InMemoryDataset` class. + train = declearn.dataset.InMemoryDataset( + os.path.join(data_folder, "train_data.npy"), + os.path.join(data_folder, "train_target.npy"), + ) + valid = declearn.dataset.InMemoryDataset( + os.path.join(data_folder, "valid_data.npy"), + os.path.join(data_folder, "valid_target.npy"), + ) + + ### (3) Define network communication parameters. + + # Here, use websockets protocol on localhost:8765, + # with SSL encryption. + network = declearn.communication.build_client( + protocol=protocol, + server_uri=serv_uri, + name=client_name, + certificate=ca_cert, + ) + + ### (4) Run any necessary import statement. + # We imported `import declearn.model.tensorflow` + + ### (5) Instantiate a FederatedClient and run it. + + client = declearn.main.FederatedClient( + netwk=network, + train_data=train, + valid_data=valid, + checkpoint=checkpoint, + logger=logger, + ) + client.run() + + +# This part should not be altered: it provides with an argument parser +# for `python client.py`). + + +def main(): + "fire-wrapped split data" + fire.Fire(run_client) + + +if __name__ == "__main__": + main() diff --git a/examples/mnist/gen_ssl.py b/examples/mnist/gen_ssl.py new file mode 100644 index 00000000..94f81e98 --- /dev/null +++ b/examples/mnist/gen_ssl.py @@ -0,0 +1,27 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Script to generate self-signed SSL certificates for the demo.""" + +import os + +from declearn.test_utils import generate_ssl_certificates + + +if __name__ == "__main__": + FILEDIR = os.path.dirname(os.path.abspath(__file__)) + generate_ssl_certificates(FILEDIR) diff --git a/examples/mnist/readme.md b/examples/mnist/readme.md new file mode 100644 index 00000000..5198ecd0 --- /dev/null +++ b/examples/mnist/readme.md @@ -0,0 +1,137 @@ +# Demo training task : MNIST + +## Overview + +**We are going to train a common model between three simulated clients on the +classic [MNIST dataset](http://yann.lecun.com/exdb/mnist/)**. The input of the +model is a set of images of handwritten digits, and the model needs to +determine which number between $0$ and $9$ each image corresponds to. + +## Setup + +To be able to experiment with this tutorial: + +* Clone the declearn repo, on the experimental branch: + +```bash +git clone -b experimental git@gitlab.inria.fr:magnet/declearn/declearn2.git declearn +``` + +* Create a dedicated virtual environment. +* Install declearn in it from the local repo: + +```bash +cd declearn && pip install .[websockets,tensorflow] && cd .. +``` + +In an FL experiment, we consider your data as a given. So before running +the experiment below, split the MNIST data using : + +```bash +declearn-split --folder "examples/mnist" --n_shards 3 +``` + +## Contents + +This script runs a FL experiment using MNIST. The folder is structured +the following way: + +``` +mnist/ +│ client.py - set up and launch a federated-learning client +│ gen_ssl.py - generate self-signed ssl certificates +│ run.py - launch both the server and clients in a single session +│ server.py - set up and launch a federated-learning server +└─── data - data split by client, created with the `split_data` util +└─── results - saved results from training procedure +``` + +## Run training routine + +The simplest way to run the demo is to run it locally, using multiprocessing. +For something closer to real life implementation, we also show a way to run +the demo from different terminals or machines. + +### Locally, for testing and experimentation + +**To simply run the demo**, use the bash command below. You can follow along +the code in the `hands-on` section of the package documentation. For more +details on what running the federated learning processes imply, see the last +section. + +```bash +cd +python run.py +``` + +Use : + +```bash +python run.py # note: python examples/heart-uci/run.py works as well +``` + +The `run.py` scripts collects the server and client routines defined under +the `server.py` and `client.py` scripts, and runs them concurrently under +a single python session using multiprocessing. + +This is the easiest way to launch the demo, e.g. to see the effects of +tweaking some learning parameters. + +### On separate terminals or machines + +**To run the examples from different terminals or machines**, +we first ensure data is appropriately distributed between machines, +and the machines can communicate over network using SSL-encrypted +communications. We give the code to simulate this on a single machine. +We then sequentially run the server then the clients on separate terminals. + +1. **Set up SSL certificates**:<br/> + Start by creating a signed SSL certificate for the server and sharing the + CA file with each and every clients. The CA may be self-signed. + + When testing locally, execute the `gen_ssl.py` script, to create a + self-signed root CA and an SSL certificate for "localhost": + + ```bash + python gen_ssl.py + ``` + + Note that in real-life applications, one would most likely use certificates + certificates signed by a trusted certificate authority instead. + Alternatively, `declearn.test_utils.gen_ssl_certificates` may be used to + generate a self-signed CA and a signed certificate for a given domain name + or IP address. + +2. **Run the server**:<br/> + Open a terminal and launch the server script for 1 to 4 clients, + specifying the path to the SSL certificate and private key files, + and network parameters. By default, things will run on the local + host, looking for `gen_ssl.py`-created PEM files. + + E.g., to use 2 clients: + + ```bash + python server.py 2 # use --help for details on network and SSL options + ``` + +3. **Run each client**:<br/> + Open a new terminal and launch the client script, specifying one of the + dataset-provider names, and optionally the path the CA file and network + parameters. By default, things will run on the local host, looking for + a `gen_ssl.py`-created CA PEM file. + + E.g., to launch a client using the "cleveland" dataset: + + ```bash + python client.py cleveland # use --help for details on other options + ``` + +Note that the server should be launched before the clients, otherwise the +latter might fail to connect which would cause the script to terminate. A +few seconds' delay is tolerable as clients will make multiple connection +attempts prior to failing. + +**To run the example in a real-life setting**, follow the instructions from +this section, after having generated and shared the appropriate PEM files to +set up SSL-encryption, and using additional script parameters to specify the +network host and port to use. diff --git a/examples/mnist/run.py b/examples/mnist/run.py new file mode 100644 index 00000000..80f24ef8 --- /dev/null +++ b/examples/mnist/run.py @@ -0,0 +1,72 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Demonstration script using the UCI Heart Disease Dataset.""" + +import glob +import os +import tempfile + +import fire + +from declearn.test_utils import generate_ssl_certificates, make_importable +from declearn.utils import run_as_processes + +# Perform local imports. +# pylint: disable=wrong-import-position, wrong-import-order +with make_importable(os.path.dirname(__file__)): + from client import run_client + from server import run_server +# pylint: enable=wrong-import-position, wrong-import-order + +FILEDIR = os.path.join(os.path.dirname(__file__)) +DATADIR = glob.glob(f"{FILEDIR}/data*")[0] + + +def run_demo(nb_clients: int = 3, data_folder: str = DATADIR) -> None: + """Run a server and its clients using multiprocessing. + + Parameters + ------ + + n_clients: int + number of clients to run. + data_folder: str + Relative path to the folder holding client's data + + """ + # Use a temporary directory for single-use self-signed SSL files. + with tempfile.TemporaryDirectory() as folder: + # Generate self-signed SSL certificates and gather their paths. + ca_cert, sv_cert, sv_pkey = generate_ssl_certificates(folder) + # Specify the server and client routines that need executing. + server = (run_server, (nb_clients, sv_cert, sv_pkey)) + clients = [ + (run_client, (f"client_{idx}", ca_cert, data_folder)) + for idx in range(nb_clients) + ] + # Run routines in isolated processes. Raise if any failed. + success, outp = run_as_processes(server, *clients) + if not success: + raise RuntimeError( + "Something went wrong during the demo. Exceptions caught:\n" + "\n".join(str(e) for e in outp if isinstance(e, RuntimeError)) + ) + + +if __name__ == "__main__": + fire.Fire(run_demo) diff --git a/examples/mnist/server.py b/examples/mnist/server.py new file mode 100644 index 00000000..552a9bb5 --- /dev/null +++ b/examples/mnist/server.py @@ -0,0 +1,191 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Script to run a federated server on the heart-disease example.""" + +import datetime +import os + +import fire +import tensorflow as tf # type: ignore + +import declearn + +FILEDIR = os.path.dirname(os.path.abspath(__file__)) +DEFAULT_CERT = os.path.join(FILEDIR, "server-cert.pem") +DEFAULT_PKEY = os.path.join(FILEDIR, "server-pkey.pem") + + +def run_server( + nb_clients: int, + certificate: str = DEFAULT_CERT, + private_key: str = DEFAULT_PKEY, + protocol: str = "websockets", + host: str = "localhost", + port: int = 8765, +) -> None: + """Instantiate and run the orchestrating server. + + Arguments + --------- + nb_clients: int + Exact number of clients used in this example. + certificate: str + Path to the (self-signed) SSL certificate to use. + private_key: str + Path to the associated private-key to use. + protocol: str, default="websockets" + Name of the communication protocol to use. + host: str, default="localhost" + Hostname or IP address on which to serve. + port: int, default=8765 + Communication port on which to serve. + """ + + ### Optional: some convenience settings + + # Set CPU as device + declearn.utils.set_device_policy(gpu=False) + + # Set up metrics suitable for MNIST. + metrics = declearn.metrics.MetricSet( + [ + declearn.metrics.MulticlassAccuracyPrecisionRecall( + labels=range(10) + ), + ] + ) + + # Set up checkpointing and logging. + stamp = datetime.datetime.now() + stamp = stamp - datetime.timedelta( + minutes=stamp.minute % 5, + seconds=stamp.second, + microseconds=stamp.microsecond, + ) + stamp = stamp.strftime("%y-%m-%d_%H-%M") + checkpoint = os.path.join(FILEDIR, f"result_{stamp}", "server") + # Set up a logger, records from which will go to a file. + logger = declearn.utils.get_logger( + name="Server", + fpath=os.path.join(checkpoint, "logs.txt"), + ) + + ### (1) Define a model + + # Here we use a scikit-learn SGD classifier and parametrize it + # into a L2-penalized binary logistic regression. + stack = [ + tf.keras.layers.InputLayer(input_shape=(28, 28, 1)), + tf.keras.layers.Conv2D(32, 3, 1, activation="relu"), + tf.keras.layers.MaxPool2D(2), + tf.keras.layers.Dropout(0.25), + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(10, activation="softmax"), + ] + model = declearn.model.tensorflow.TensorflowModel( + model=tf.keras.Sequential(stack), + loss="sparse_categorical_crossentropy", + ) + + ### (2) Define an optimization strategy + + # Set up the cient updates' aggregator. By default: FedAvg. + aggregator = declearn.aggregator.AveragingAggregator() + + # Set up the server-side optimizer (to refine aggregated updates). + # By default: no refinement (lrate=1.0, no plug-ins). + server_opt = declearn.optimizer.Optimizer( + lrate=1.0, + w_decay=0.0, + modules=None, + ) + + # Set up the client-side optimizer (for local SGD steps). + # By default: vanilla SGD, with a selected learning rate. + client_opt = declearn.optimizer.Optimizer( + lrate=0.001, + w_decay=0.0, + regularizers=None, + modules=None, + ) + + # Wrap all this into a FLOptimConfig. + optim = declearn.main.config.FLOptimConfig.from_params( + aggregator=aggregator, + server_opt=server_opt, + client_opt=client_opt, + ) + + ### (3) Define network communication parameters. + + # Here, use websockets protocol on localhost:8765, with SSL encryption. + network = declearn.communication.build_server( + protocol=protocol, + host=host, + port=port, + certificate=certificate, + private_key=private_key, + ) + + ### (4) Instantiate and run a FederatedServer. + + # Instanciate + server = declearn.main.FederatedServer( + model=model, + netwk=network, + optim=optim, + metrics=metrics, + checkpoint=checkpoint, + logger=logger, + ) + + # Set up the experiment's hyper-parameters. + # Registration rules: wait for 10 seconds at registration. + register = declearn.main.config.RegisterConfig(timeout=10) + # Training rounds hyper-parameters. By default, 1 epoch / round. + training = declearn.main.config.TrainingConfig( + batch_size=32, + n_epoch=1, + ) + # Evaluation rounds. by default, 1 epoch with train's batch size. + evaluate = declearn.main.config.EvaluateConfig( + batch_size=128, + ) + # Wrap all this into a FLRunConfig. + run_config = declearn.main.config.FLRunConfig.from_params( + rounds=5, # you may change the number of training rounds + register=register, + training=training, + evaluate=evaluate, + privacy=None, # you may set up local DP (DP-SGD) here + early_stop=None, # you may add an early-stopping cirterion here + ) + server.run(run_config) + + +# This part should not be altered: it provides with an argument parser. +# for `python server.py`. + + +def main(): + "fire-wrapped split data" + fire.Fire(run_server) + + +if __name__ == "__main__": + main() -- GitLab