From 8ac8411a47e97e72c9a639747a8af5e6bec62c1c Mon Sep 17 00:00:00 2001
From: BIGAUD Nathan <nathan.bigaud@inria.fr>
Date: Fri, 14 Apr 2023 11:38:06 +0200
Subject: [PATCH] Local MNIST example

---
 examples/mnist/client.py  | 136 +++++++++++++++++++++++++++
 examples/mnist/gen_ssl.py |  27 ++++++
 examples/mnist/readme.md  | 137 +++++++++++++++++++++++++++
 examples/mnist/run.py     |  72 ++++++++++++++
 examples/mnist/server.py  | 191 ++++++++++++++++++++++++++++++++++++++
 5 files changed, 563 insertions(+)
 create mode 100644 examples/mnist/client.py
 create mode 100644 examples/mnist/gen_ssl.py
 create mode 100644 examples/mnist/readme.md
 create mode 100644 examples/mnist/run.py
 create mode 100644 examples/mnist/server.py

diff --git a/examples/mnist/client.py b/examples/mnist/client.py
new file mode 100644
index 00000000..a727a4a3
--- /dev/null
+++ b/examples/mnist/client.py
@@ -0,0 +1,136 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Script to run a federated client on the heart-disease example."""
+
+import datetime
+import logging
+import os
+
+import declearn
+import declearn.model.tensorflow
+import fire
+
+FILEDIR = os.path.dirname(__file__)
+
+
+def run_client(
+    client_name: str,
+    ca_cert: str,
+    data_folder: str,
+    protocol: str = "websockets",
+    serv_uri: str = "wss://localhost:8765",
+    verbose: bool = True,
+) -> None:
+
+    """Instantiate and run a given client.
+
+    Parameters
+    ---------
+    client_name: str
+        Name of the client (i.e. center data from which to use).
+    ca_cert: str
+        Path to the certificate authority file that was used to
+        sign the server's SSL certificate.
+    data_folder: str
+        The parent folder of this client's data
+    protocol: str, default="websockets"
+        Name of the communication protocol to use.
+    serv_uri: str, default="wss://localhost:8765"
+        URI of the server to which to connect.
+    verbose:
+        Whether to log everything to the console, or filter out most non-error
+        information.
+    """
+
+    ### Optional: some convenience settings
+
+    # Set CPU as device
+    declearn.utils.set_device_policy(gpu=False)
+
+    # Set up logger and checkpointer
+    stamp = datetime.datetime.now()
+    stamp = stamp - datetime.timedelta(
+        minutes=stamp.minute % 5,
+        seconds=stamp.second,
+        microseconds=stamp.microsecond,
+    )
+    stamp = stamp.strftime("%y-%m-%d_%H-%M")
+    checkpoint = os.path.join(FILEDIR, f"result_{stamp}", client_name)
+    logger = declearn.utils.get_logger(
+        name=client_name,
+        fpath=os.path.join(checkpoint, "logs.txt"),
+    )
+
+    # Reduce logger verbosity
+    if not verbose:
+        for handler in logger.handlers:
+            if isinstance(handler, logging.StreamHandler):
+                handler.setLevel(declearn.utils.LOGGING_LEVEL_MAJOR)
+
+    ### (1-2) Interface training and optional validation data.
+
+    # Target the proper dataset (specific to our MNIST setup).
+    data_folder = os.path.join(FILEDIR, data_folder, client_name)
+
+    # Interface the data through the generic `InMemoryDataset` class.
+    train = declearn.dataset.InMemoryDataset(
+        os.path.join(data_folder, "train_data.npy"),
+        os.path.join(data_folder, "train_target.npy"),
+    )
+    valid = declearn.dataset.InMemoryDataset(
+        os.path.join(data_folder, "valid_data.npy"),
+        os.path.join(data_folder, "valid_target.npy"),
+    )
+
+    ### (3) Define network communication parameters.
+
+    # Here, use websockets protocol on localhost:8765,
+    # with SSL encryption.
+    network = declearn.communication.build_client(
+        protocol=protocol,
+        server_uri=serv_uri,
+        name=client_name,
+        certificate=ca_cert,
+    )
+
+    ### (4) Run any necessary import statement.
+    # We imported `import declearn.model.tensorflow`
+
+    ### (5) Instantiate a FederatedClient and run it.
+
+    client = declearn.main.FederatedClient(
+        netwk=network,
+        train_data=train,
+        valid_data=valid,
+        checkpoint=checkpoint,
+        logger=logger,
+    )
+    client.run()
+
+
+# This part should not be altered: it provides with an argument parser
+# for `python client.py`).
+
+
+def main():
+    "fire-wrapped split data"
+    fire.Fire(run_client)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/mnist/gen_ssl.py b/examples/mnist/gen_ssl.py
new file mode 100644
index 00000000..94f81e98
--- /dev/null
+++ b/examples/mnist/gen_ssl.py
@@ -0,0 +1,27 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Script to generate self-signed SSL certificates for the demo."""
+
+import os
+
+from declearn.test_utils import generate_ssl_certificates
+
+
+if __name__ == "__main__":
+    FILEDIR = os.path.dirname(os.path.abspath(__file__))
+    generate_ssl_certificates(FILEDIR)
diff --git a/examples/mnist/readme.md b/examples/mnist/readme.md
new file mode 100644
index 00000000..5198ecd0
--- /dev/null
+++ b/examples/mnist/readme.md
@@ -0,0 +1,137 @@
+# Demo training task : MNIST
+
+## Overview
+
+**We are going to train a common model between three simulated clients on the
+classic [MNIST dataset](http://yann.lecun.com/exdb/mnist/)**. The input of the
+model is a set of images of handwritten digits, and the model needs to
+determine which number between $0$ and $9$ each image corresponds to.
+
+## Setup
+
+To be able to experiment with this tutorial:
+
+* Clone the declearn repo, on the experimental branch:
+
+```bash
+git clone -b experimental git@gitlab.inria.fr:magnet/declearn/declearn2.git declearn
+```
+
+* Create a dedicated virtual environment.
+* Install declearn in it from the local repo:
+
+```bash
+cd declearn && pip install .[websockets,tensorflow] && cd ..
+```
+
+In an FL experiment, we consider your data as a given. So before running
+the experiment below, split the MNIST data using :
+
+```bash
+declearn-split --folder "examples/mnist" --n_shards 3 
+```
+
+## Contents
+
+This script runs a FL experiment using MNIST. The folder is structured
+the following way:
+
+```
+mnist/
+│   client.py  - set up and launch a federated-learning client
+│   gen_ssl.py - generate self-signed ssl certificates
+│   run.py     - launch both the server and clients in a single session
+│   server.py  - set up and launch a federated-learning server
+└─── data      - data split by client, created with the `split_data` util
+└─── results   - saved results from training procedure
+```
+
+## Run training routine
+
+The simplest way to run the demo is to run it locally, using multiprocessing.
+For something closer to real life implementation, we also show a way to run
+the demo from different terminals or machines.
+
+### Locally, for testing and experimentation
+
+**To simply run the demo**, use the bash command below. You can follow along
+the code in the `hands-on` section of the package documentation. For more
+details on what running the federated learning processes imply, see the last
+section.
+
+```bash
+cd
+python run.py
+```
+
+Use :
+
+```bash
+python run.py  # note: python examples/heart-uci/run.py works as well
+```
+
+The `run.py` scripts collects the server and client routines defined under
+the `server.py` and `client.py` scripts, and runs them concurrently under
+a single python session using multiprocessing.
+
+This is the easiest way to launch the demo, e.g. to see the effects of
+tweaking some learning parameters.
+
+### On separate terminals or machines
+
+**To run the examples from different terminals or machines**,
+we first ensure data is appropriately distributed between machines,
+and the machines can communicate over network using SSL-encrypted
+communications. We give the code to simulate this on a single machine.
+We then sequentially run the server then the clients on separate terminals.
+
+1. **Set up SSL certificates**:<br/>
+   Start by creating a signed SSL certificate for the server and sharing the
+   CA file with each and every clients. The CA may be self-signed.
+
+   When testing locally, execute the `gen_ssl.py` script, to create a
+   self-signed root CA and an SSL certificate for "localhost":
+
+   ```bash
+   python gen_ssl.py
+   ```
+
+   Note that in real-life applications, one would most likely use certificates
+   certificates signed by a trusted certificate authority instead.
+   Alternatively, `declearn.test_utils.gen_ssl_certificates` may be used to
+   generate a self-signed CA and a signed certificate for a given domain name
+   or IP address.
+
+2. **Run the server**:<br/>
+   Open a terminal and launch the server script for 1 to 4 clients,
+   specifying the path to the SSL certificate and private key files,
+   and network parameters. By default, things will run on the local
+   host, looking for `gen_ssl.py`-created PEM files.
+
+   E.g., to use 2 clients:
+
+    ```bash
+    python server.py 2  # use --help for details on network and SSL options
+    ```
+
+3. **Run each client**:<br/>
+   Open a new terminal and launch the client script, specifying one of the
+   dataset-provider names, and optionally the path the CA file and network
+   parameters. By default, things will run on the local host, looking for
+   a `gen_ssl.py`-created CA PEM file.
+
+   E.g., to launch a client using the "cleveland" dataset:
+
+    ```bash
+    python client.py cleveland   # use --help for details on other options
+    ```
+
+Note that the server should be launched before the clients, otherwise the
+latter might fail to connect which would cause the script to terminate. A
+few seconds' delay is tolerable as clients will make multiple connection
+attempts prior to failing.
+
+**To run the example in a real-life setting**, follow the instructions from
+this section, after having generated and shared the appropriate PEM files to
+set up SSL-encryption, and using additional script parameters to specify the
+network host and port to use.
diff --git a/examples/mnist/run.py b/examples/mnist/run.py
new file mode 100644
index 00000000..80f24ef8
--- /dev/null
+++ b/examples/mnist/run.py
@@ -0,0 +1,72 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Demonstration script using the UCI Heart Disease Dataset."""
+
+import glob
+import os
+import tempfile
+
+import fire
+
+from declearn.test_utils import generate_ssl_certificates, make_importable
+from declearn.utils import run_as_processes
+
+# Perform local imports.
+# pylint: disable=wrong-import-position, wrong-import-order
+with make_importable(os.path.dirname(__file__)):
+    from client import run_client
+    from server import run_server
+# pylint: enable=wrong-import-position, wrong-import-order
+
+FILEDIR = os.path.join(os.path.dirname(__file__))
+DATADIR = glob.glob(f"{FILEDIR}/data*")[0]
+
+
+def run_demo(nb_clients: int = 3, data_folder: str = DATADIR) -> None:
+    """Run a server and its clients using multiprocessing.
+
+    Parameters
+    ------
+
+    n_clients: int
+        number of clients to run.
+    data_folder: str
+        Relative path to the folder holding client's data
+
+    """
+    # Use a temporary directory for single-use self-signed SSL files.
+    with tempfile.TemporaryDirectory() as folder:
+        # Generate self-signed SSL certificates and gather their paths.
+        ca_cert, sv_cert, sv_pkey = generate_ssl_certificates(folder)
+        # Specify the server and client routines that need executing.
+        server = (run_server, (nb_clients, sv_cert, sv_pkey))
+        clients = [
+            (run_client, (f"client_{idx}", ca_cert, data_folder))
+            for idx in range(nb_clients)
+        ]
+        # Run routines in isolated processes. Raise if any failed.
+        success, outp = run_as_processes(server, *clients)
+        if not success:
+            raise RuntimeError(
+                "Something went wrong during the demo. Exceptions caught:\n"
+                "\n".join(str(e) for e in outp if isinstance(e, RuntimeError))
+            )
+
+
+if __name__ == "__main__":
+    fire.Fire(run_demo)
diff --git a/examples/mnist/server.py b/examples/mnist/server.py
new file mode 100644
index 00000000..552a9bb5
--- /dev/null
+++ b/examples/mnist/server.py
@@ -0,0 +1,191 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Script to run a federated server on the heart-disease example."""
+
+import datetime
+import os
+
+import fire
+import tensorflow as tf  # type: ignore
+
+import declearn
+
+FILEDIR = os.path.dirname(os.path.abspath(__file__))
+DEFAULT_CERT = os.path.join(FILEDIR, "server-cert.pem")
+DEFAULT_PKEY = os.path.join(FILEDIR, "server-pkey.pem")
+
+
+def run_server(
+    nb_clients: int,
+    certificate: str = DEFAULT_CERT,
+    private_key: str = DEFAULT_PKEY,
+    protocol: str = "websockets",
+    host: str = "localhost",
+    port: int = 8765,
+) -> None:
+    """Instantiate and run the orchestrating server.
+
+    Arguments
+    ---------
+    nb_clients: int
+        Exact number of clients used in this example.
+    certificate: str
+        Path to the (self-signed) SSL certificate to use.
+    private_key: str
+        Path to the associated private-key to use.
+    protocol: str, default="websockets"
+        Name of the communication protocol to use.
+    host: str, default="localhost"
+        Hostname or IP address on which to serve.
+    port: int, default=8765
+        Communication port on which to serve.
+    """
+
+    ### Optional: some convenience settings
+
+    # Set CPU as device
+    declearn.utils.set_device_policy(gpu=False)
+
+    # Set up metrics suitable for MNIST.
+    metrics = declearn.metrics.MetricSet(
+        [
+            declearn.metrics.MulticlassAccuracyPrecisionRecall(
+                labels=range(10)
+            ),
+        ]
+    )
+
+    # Set up checkpointing and logging.
+    stamp = datetime.datetime.now()
+    stamp = stamp - datetime.timedelta(
+        minutes=stamp.minute % 5,
+        seconds=stamp.second,
+        microseconds=stamp.microsecond,
+    )
+    stamp = stamp.strftime("%y-%m-%d_%H-%M")
+    checkpoint = os.path.join(FILEDIR, f"result_{stamp}", "server")
+    # Set up a logger, records from which will go to a file.
+    logger = declearn.utils.get_logger(
+        name="Server",
+        fpath=os.path.join(checkpoint, "logs.txt"),
+    )
+
+    ### (1) Define a model
+
+    # Here we use a scikit-learn SGD classifier and parametrize it
+    # into a L2-penalized binary logistic regression.
+    stack = [
+        tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),
+        tf.keras.layers.Conv2D(32, 3, 1, activation="relu"),
+        tf.keras.layers.MaxPool2D(2),
+        tf.keras.layers.Dropout(0.25),
+        tf.keras.layers.Flatten(),
+        tf.keras.layers.Dense(10, activation="softmax"),
+    ]
+    model = declearn.model.tensorflow.TensorflowModel(
+        model=tf.keras.Sequential(stack),
+        loss="sparse_categorical_crossentropy",
+    )
+
+    ### (2) Define an optimization strategy
+
+    # Set up the cient updates' aggregator. By default: FedAvg.
+    aggregator = declearn.aggregator.AveragingAggregator()
+
+    # Set up the server-side optimizer (to refine aggregated updates).
+    # By default: no refinement (lrate=1.0, no plug-ins).
+    server_opt = declearn.optimizer.Optimizer(
+        lrate=1.0,
+        w_decay=0.0,
+        modules=None,
+    )
+
+    # Set up the client-side optimizer (for local SGD steps).
+    # By default: vanilla SGD, with a selected learning rate.
+    client_opt = declearn.optimizer.Optimizer(
+        lrate=0.001,
+        w_decay=0.0,
+        regularizers=None,
+        modules=None,
+    )
+
+    # Wrap all this into a FLOptimConfig.
+    optim = declearn.main.config.FLOptimConfig.from_params(
+        aggregator=aggregator,
+        server_opt=server_opt,
+        client_opt=client_opt,
+    )
+
+    ### (3) Define network communication parameters.
+
+    # Here, use websockets protocol on localhost:8765, with SSL encryption.
+    network = declearn.communication.build_server(
+        protocol=protocol,
+        host=host,
+        port=port,
+        certificate=certificate,
+        private_key=private_key,
+    )
+
+    ### (4) Instantiate and run a FederatedServer.
+
+    # Instanciate
+    server = declearn.main.FederatedServer(
+        model=model,
+        netwk=network,
+        optim=optim,
+        metrics=metrics,
+        checkpoint=checkpoint,
+        logger=logger,
+    )
+
+    # Set up the experiment's hyper-parameters.
+    # Registration rules: wait for 10 seconds at registration.
+    register = declearn.main.config.RegisterConfig(timeout=10)
+    # Training rounds hyper-parameters. By default, 1 epoch / round.
+    training = declearn.main.config.TrainingConfig(
+        batch_size=32,
+        n_epoch=1,
+    )
+    # Evaluation rounds. by default,  1 epoch with train's batch size.
+    evaluate = declearn.main.config.EvaluateConfig(
+        batch_size=128,
+    )
+    # Wrap all this into a FLRunConfig.
+    run_config = declearn.main.config.FLRunConfig.from_params(
+        rounds=5,  # you may change the number of training rounds
+        register=register,
+        training=training,
+        evaluate=evaluate,
+        privacy=None,  # you may set up local DP (DP-SGD) here
+        early_stop=None,  # you may add an early-stopping cirterion here
+    )
+    server.run(run_config)
+
+
+# This part should not be altered: it provides with an argument parser.
+# for `python server.py`.
+
+
+def main():
+    "fire-wrapped split data"
+    fire.Fire(run_server)
+
+
+if __name__ == "__main__":
+    main()
-- 
GitLab