diff --git a/examples/mnist/gen_ssl.py b/examples/mnist/generate_ssl.py similarity index 100% rename from examples/mnist/gen_ssl.py rename to examples/mnist/generate_ssl.py diff --git a/examples/mnist/prepare_data.py b/examples/mnist/prepare_data.py new file mode 100644 index 0000000000000000000000000000000000000000..769109752eb3153ea014b22d319789c55c03d493 --- /dev/null +++ b/examples/mnist/prepare_data.py @@ -0,0 +1,93 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data-preparation script for the MNIST dataset.""" + +import os +from typing import Literal, Optional + +import fire # type: ignore + +from declearn.dataset.examples import load_mnist +from declearn.dataset.utils import ( + save_data_array, + split_multi_classif_dataset, +) + + +DATADIR = os.path.join(os.path.dirname(__file__), "data") + + +def prepare_mnist( + nb_clients: int, + scheme: Literal["iid", "labels", "biased"] = "iid", + folder: str = DATADIR, + seed: Optional[int] = None, +) -> str: + """Fetch and split the MNIST dataset to use it federatively. + + Parameters + ---------- + nb_clients: + Number of shards between which to split the raw MNIST data. + scheme: + Splitting scheme to use. In all cases, shards contain mutually- + exclusive samples and cover the full dataset. See details below. + folder: + Path to the root folder where to export the raw and split data, + using adequately-named subfolders. + seed: + Optional seed to the RNG used for all sampling operations. + + Data-splitting schemes + ---------------------- + + - If "iid", split the dataset through iid random sampling. + - If "labels", split into shards that hold all samples associated + with mutually-exclusive target classes. + - If "biased", split the dataset through random sampling according + to a shard-specific random labels distribution. + """ + # Download (or reload) the raw MNIST data. + datadir_raw = os.path.join(folder, "mnist_raw") + dataset_raw = load_mnist(train=True, folder=datadir_raw) + # Split it based on the input arguments. + print(f"Splitting MNIST into {nb_clients} shards using '{scheme}' scheme.") + split_data = split_multi_classif_dataset( + dataset_raw, n_shards=nb_clients, scheme=scheme, seed=seed + ) + # Export shard data into expected folder structure. + folder = os.path.join(folder, f"mnist_{scheme}") + for idx, ((x_t, y_t), (x_v, y_v)) in enumerate(split_data): + save_data_array( + os.path.join(folder, f"client_{idx}", "train_data"), x_t + ) + save_data_array( + os.path.join(folder, f"client_{idx}", "train_target"), y_t + ) + save_data_array( + os.path.join(folder, f"client_{idx}", "valid_data"), x_v + ) + save_data_array( + os.path.join(folder, f"client_{idx}", "valid_target"), y_v + ) + # Return the path to the split data folder. + return folder + + +if __name__ == "__main__": + fire.Fire(prepare_mnist) diff --git a/examples/mnist/readme.md b/examples/mnist/readme.md index bdff64b17812e116aac6836a5220b1b7e55c8218..250524b83091c2b805bece09b1b7d956f1f1032d 100644 --- a/examples/mnist/readme.md +++ b/examples/mnist/readme.md @@ -24,15 +24,6 @@ git clone git@gitlab.inria.fr:magnet/declearn/declearn2.git declearn cd declearn && pip install .[websockets,tensorflow] && cd .. ``` -In an FL experiment, we consider your data as a given. So before running -the experiment below, download and split the MNIST data using: - -```bash -declearn-split --folder "examples/mnist" --n_shards 3 -``` - -You may add `--seed <some_number>` if you want to ensure reproducibility. - ## Contents This script runs a FL experiment using MNIST. The folder is structured @@ -40,12 +31,13 @@ the following way: ``` mnist/ -│ client.py - set up and launch a federated-learning client -│ gen_ssl.py - generate self-signed ssl certificates -│ run.py - launch both the server and clients in a single session -│ server.py - set up and launch a federated-learning server -└─── data - data split by client, created with the `split_data` util -└─── results - saved results from training procedure +│ generate_ssl.py - generate self-signed ssl certificates +| prepare_data.py - fetch and split the MNIST dataset for FL use +| run_client.py - set up and launch a federated-learning client +│ run_demo.py - simulate the entire FL process in a single session +│ run_server.py - set up and launch a federated-learning server +└─── data - data folder, containing raw and split MNIST data +└─── results_<time> - saved results from training procedures ``` ## Run training routine @@ -63,15 +55,18 @@ section. ```bash cd declearn/examples/mnist/ -python run.py # note: python declearn/examples/mnist/run.py works as well +python run_demo.py # note: python declearn/examples/mnist/run.py works as well ``` -The `run.py` scripts collects the server and client routines defined under -the `server.py` and `client.py` scripts, and runs them concurrently under -a single python session using multiprocessing. +The `run_demo.py` scripts collects the server and client routines defined under +the `run_server.py` and `run_client.py` scripts, and runs them concurrently +under a single python session using multiprocessing. It also prepares the data +by calling the `prepare_data.py` script, passing along input arguments, which +users are encouraged to play with - notably to vary the number of clients and +the data splitting scheme. This is the easiest way to launch the demo, e.g. to see the effects of -tweaking some learning parameters. +tweaking some learning parameters (by editing the `run_server.py` script). ### On separate terminals or machines @@ -81,45 +76,75 @@ and the machines can communicate over network using SSL-encrypted communications. We give the code to simulate this on a single machine. We then sequentially run the server then the clients on separate terminals. -1. **Set up SSL certificates**:<br/> - Start by creating a signed SSL certificate for the server and sharing the - CA file with each and every clients. The CA may be self-signed. +1. **Prepare the data**:<br/> + First, clients' data should be generated, by fetching and splitting the + MNIST dataset. This can be done in any way you want, but a practical and + easy one is to use the `prepare_data.py` script. Data may either be + prepared at a single location and then shared across clients (in the case + when distinct computers are used), or prepared redundantly at each place + using the same random seed and agreeing on clients' ordering. + + To use the `prepare_data.py` script, simply run: + ```bashand `SEED` may be + any int + python prepare_data.py <NB_CLIENTS> [--scheme=SCHEME] [--seed=SEED] + ``` + where `SCHEME` must be in `{"iid", "labels", "biased"}`. + + Alternatively, you may use the `declearn-split` command-line utility, with + similar arguments: + ```bash + declearn-split --n_shards=<NB_CLIENTS> [--scheme=SCHEME] [--seed=SEED] + ``` + +2. **Set up SSL certificates**:<br/> + Create a signed SSL certificate for the server and share the CA file that + signed it with each and every clients. That CA may be self-signed. - When testing locally, execute the `gen_ssl.py` script, to create a + When testing locally, execute the `generate_ssl.py` script, to create a self-signed root CA and an SSL certificate for "localhost": ```bash - python gen_ssl.py + python generate_ssl.py ``` Note that in real-life applications, one would most likely use certificates certificates signed by a trusted certificate authority instead. + Alternatively, `declearn.test_utils.gen_ssl_certificates` may be used to generate a self-signed CA and a signed certificate for a given domain name or IP address. -2. **Run the server**:<br/> - Open a terminal and launch the server script for 1 to 4 clients, - specifying the path to the SSL certificate and private key files, +3. **Run the server**:<br/> + Open a terminal and launch the server script for the desired number of + clients, specifying the path to the SSL certificate and private key files, and network parameters. By default, things will run on the local - host, looking for `gen_ssl.py`-created PEM files. + host, looking for `generate_ssl.py`-created PEM files. E.g., to use 2 clients: ```bash - python server.py 2 # use --help for details on network and SSL options + python run_server.py 2 # use --help for details on network and SSL options ``` + Note that you may edit that script to change the model learned, the FL + and optimization algorithms used, and/or the training hyper-parameters, + including the introduction of sample-level differential privacy. + 3. **Run each client**:<br/> - Open a new terminal and launch the client script, specifying one of the - dataset-provider names, and optionally the path the CA file and network - parameters. By default, things will run on the local host, looking for - a `gen_ssl.py`-created CA PEM file. + Open a new terminal and launch the client script, specifying the path to + the main data folder (e.g. "data/mnist_iid") and the client's name (e.g. + "client_0"), which are both used to determine where to get the prepared + data. Additional network parameters may also be passed; by default, things + will run on the localhost, looking for a `generate_ssl.py`-created CA PEM + file. - E.g., to launch a client using the "cleveland" dataset: + E.g., to launch the first client after preparing iid-split data with the + `prepara_data.py` script, call: ```bash - python client.py cleveland # use --help for details on other options + python run_client.py client_0 data/mnist_iid + # use --help for details on options ``` Note that the server should be launched before the clients, otherwise the diff --git a/examples/mnist/client.py b/examples/mnist/run_client.py similarity index 90% rename from examples/mnist/client.py rename to examples/mnist/run_client.py index 7dc57bc8f4d3feec529b8dd5bb0dd827872275ff..b8d84b942570c9f9d9d6ff09ac7305bcccfd596d 100644 --- a/examples/mnist/client.py +++ b/examples/mnist/run_client.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Script to run a federated client on the heart-disease example.""" +"""Script to run a federated client on the MNIST example.""" import datetime import logging @@ -27,13 +27,14 @@ import declearn import declearn.model.tensorflow -FILEDIR = os.path.dirname(__file__) +FILEDIR = os.path.dirname(os.path.abspath(__file__)) +DEFAULT_CERT = os.path.join(FILEDIR, "ca-cert.pem") def run_client( client_name: str, - ca_cert: str, data_folder: str, + ca_cert: str = DEFAULT_CERT, protocol: str = "websockets", serv_uri: str = "wss://localhost:8765", verbose: bool = True, @@ -44,16 +45,16 @@ def run_client( --------- client_name: str Name of the client (i.e. center data from which to use). - ca_cert: str - Path to the certificate authority file that was used to - sign the server's SSL certificate. data_folder: str The parent folder of this client's data + ca_cert: str, default="./ca-cert.pem" + Path to the certificate authority file that was used to + sign the server's SSL certificate. protocol: str, default="websockets" Name of the communication protocol to use. serv_uri: str, default="wss://localhost:8765" URI of the server to which to connect. - verbose: + verbose: bool, default=True Whether to log everything to the console, or filter out most non-error information. """ @@ -94,7 +95,7 @@ def run_client( ### (3) Define network communication parameters. - # Here, use websockets protocol on localhost:8765, + # Here, by default, use websockets protocol on localhost:8765, # with SSL encryption. network = declearn.communication.build_client( protocol=protocol, @@ -104,7 +105,7 @@ def run_client( ) ### (4) Run any necessary import statement. - # We imported `import declearn.model.tensorflow` + # We imported `import declearn.model.tensorflow`. ### (5) Instantiate a FederatedClient and run it. diff --git a/examples/mnist/run.py b/examples/mnist/run_demo.py similarity index 76% rename from examples/mnist/run.py rename to examples/mnist/run_demo.py index 35edefc291850c4938566a0dc2dbd40d194e0a27..ebe5c0c3fa4ccafe567e86c1fd1963eb8b02033f 100644 --- a/examples/mnist/run.py +++ b/examples/mnist/run_demo.py @@ -15,11 +15,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Demonstration script using the UCI Heart Disease Dataset.""" +"""Demonstration script using the MNIST dataset.""" -import glob import os import tempfile +from typing import Literal, Optional import fire # type: ignore @@ -29,15 +29,17 @@ from declearn.utils import run_as_processes # Perform local imports. # pylint: disable=wrong-import-position, wrong-import-order with make_importable(os.path.dirname(__file__)): - from client import run_client - from server import run_server + from prepare_data import prepare_mnist + from run_client import run_client + from run_server import run_server # pylint: enable=wrong-import-position, wrong-import-order -FILEDIR = os.path.join(os.path.dirname(__file__)) -DATADIR = glob.glob(f"{FILEDIR}/data*")[0] - -def run_demo(nb_clients: int = 3, data_folder: str = DATADIR) -> None: +def run_demo( + nb_clients: int = 3, + scheme: Literal["iid", "labels", "biased"] = "iid", + seed: Optional[int] = None, +) -> None: """Run a server and its clients using multiprocessing. Parameters @@ -49,14 +51,19 @@ def run_demo(nb_clients: int = 3, data_folder: str = DATADIR) -> None: Relative path to the folder holding client's data """ + # Generate the MNIST split data for this demo. + data_folder = prepare_mnist(nb_clients, scheme, seed=seed) # Use a temporary directory for single-use self-signed SSL files. with tempfile.TemporaryDirectory() as folder: # Generate self-signed SSL certificates and gather their paths. ca_cert, sv_cert, sv_pkey = generate_ssl_certificates(folder) # Specify the server and client routines that need executing. server = (run_server, (nb_clients, sv_cert, sv_pkey)) + client_args = tuple( + [data_folder, ca_cert, "websockets", "wss://localhost:8765", False] + ) clients = [ - (run_client, (f"client_{idx}", ca_cert, data_folder)) + (run_client, (f"client_{idx}", *client_args)) for idx in range(nb_clients) ] # Run routines in isolated processes. Raise if any failed. diff --git a/examples/mnist/server.py b/examples/mnist/run_server.py similarity index 94% rename from examples/mnist/server.py rename to examples/mnist/run_server.py index dda597e6d95ee93adac78495966f353573586378..3e62efabc7d41057d90fc2ebe655c5e1349447b6 100644 --- a/examples/mnist/server.py +++ b/examples/mnist/run_server.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Script to run a federated server on the heart-disease example.""" +"""Script to run a federated server on the MNIST example.""" import datetime import os @@ -24,6 +24,7 @@ import fire # type: ignore import tensorflow as tf # type: ignore import declearn +import declearn.model.tensorflow FILEDIR = os.path.dirname(os.path.abspath(__file__)) @@ -82,8 +83,7 @@ def run_server( ### (1) Define a model - # Here we use a scikit-learn SGD classifier and parametrize it - # into a L2-penalized binary logistic regression. + # Here we use a tensorflow-implemented small Convolutional Neural Network. stack = [ tf.keras.layers.InputLayer(input_shape=(28, 28, 1)), tf.keras.layers.Conv2D(32, 3, 1, activation="relu"), @@ -128,7 +128,7 @@ def run_server( ### (3) Define network communication parameters. - # Here, use websockets protocol on localhost:8765, with SSL encryption. + # Use user-provided parameters (or default WSS on localhost:8765). network = declearn.communication.build_server( protocol=protocol, host=host, @@ -161,7 +161,7 @@ def run_server( batch_size=32, n_epoch=1, ) - # Evaluation rounds. by default, 1 epoch with train's batch size. + # Evaluation rounds. By default, 1 epoch with train's batch size. evaluate = declearn.main.config.EvaluateConfig( batch_size=128, )