Mentions légales du service

Skip to content
Snippets Groups Projects
Verified Commit 6b7a3d61 authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Revise the MNIST example, cleaning up the code and documentation.

parent c9f52458
No related branches found
No related tags found
1 merge request!51Revise the MNIST example, cleaning up the code and documentation.
Pipeline #833374 passed
File moved
# coding: utf-8
# Copyright 2023 Inria (Institut National de Recherche en Informatique
# et Automatique)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data-preparation script for the MNIST dataset."""
import os
from typing import Literal, Optional
import fire # type: ignore
from declearn.dataset.examples import load_mnist
from declearn.dataset.utils import (
save_data_array,
split_multi_classif_dataset,
)
DATADIR = os.path.join(os.path.dirname(__file__), "data")
def prepare_mnist(
nb_clients: int,
scheme: Literal["iid", "labels", "biased"] = "iid",
folder: str = DATADIR,
seed: Optional[int] = None,
) -> str:
"""Fetch and split the MNIST dataset to use it federatively.
Parameters
----------
nb_clients:
Number of shards between which to split the raw MNIST data.
scheme:
Splitting scheme to use. In all cases, shards contain mutually-
exclusive samples and cover the full dataset. See details below.
folder:
Path to the root folder where to export the raw and split data,
using adequately-named subfolders.
seed:
Optional seed to the RNG used for all sampling operations.
Data-splitting schemes
----------------------
- If "iid", split the dataset through iid random sampling.
- If "labels", split into shards that hold all samples associated
with mutually-exclusive target classes.
- If "biased", split the dataset through random sampling according
to a shard-specific random labels distribution.
"""
# Download (or reload) the raw MNIST data.
datadir_raw = os.path.join(folder, "mnist_raw")
dataset_raw = load_mnist(train=True, folder=datadir_raw)
# Split it based on the input arguments.
print(f"Splitting MNIST into {nb_clients} shards using '{scheme}' scheme.")
split_data = split_multi_classif_dataset(
dataset_raw, n_shards=nb_clients, scheme=scheme, seed=seed
)
# Export shard data into expected folder structure.
folder = os.path.join(folder, f"mnist_{scheme}")
for idx, ((x_t, y_t), (x_v, y_v)) in enumerate(split_data):
save_data_array(
os.path.join(folder, f"client_{idx}", "train_data"), x_t
)
save_data_array(
os.path.join(folder, f"client_{idx}", "train_target"), y_t
)
save_data_array(
os.path.join(folder, f"client_{idx}", "valid_data"), x_v
)
save_data_array(
os.path.join(folder, f"client_{idx}", "valid_target"), y_v
)
# Return the path to the split data folder.
return folder
if __name__ == "__main__":
fire.Fire(prepare_mnist)
...@@ -24,15 +24,6 @@ git clone git@gitlab.inria.fr:magnet/declearn/declearn2.git declearn ...@@ -24,15 +24,6 @@ git clone git@gitlab.inria.fr:magnet/declearn/declearn2.git declearn
cd declearn && pip install .[websockets,tensorflow] && cd .. cd declearn && pip install .[websockets,tensorflow] && cd ..
``` ```
In an FL experiment, we consider your data as a given. So before running
the experiment below, download and split the MNIST data using:
```bash
declearn-split --folder "examples/mnist" --n_shards 3
```
You may add `--seed <some_number>` if you want to ensure reproducibility.
## Contents ## Contents
This script runs a FL experiment using MNIST. The folder is structured This script runs a FL experiment using MNIST. The folder is structured
...@@ -40,12 +31,13 @@ the following way: ...@@ -40,12 +31,13 @@ the following way:
``` ```
mnist/ mnist/
│ client.py - set up and launch a federated-learning client │ generate_ssl.py - generate self-signed ssl certificates
│ gen_ssl.py - generate self-signed ssl certificates | prepare_data.py - fetch and split the MNIST dataset for FL use
│ run.py - launch both the server and clients in a single session | run_client.py - set up and launch a federated-learning client
│ server.py - set up and launch a federated-learning server │ run_demo.py - simulate the entire FL process in a single session
└─── data - data split by client, created with the `split_data` util │ run_server.py - set up and launch a federated-learning server
└─── results - saved results from training procedure └─── data - data folder, containing raw and split MNIST data
└─── results_<time> - saved results from training procedures
``` ```
## Run training routine ## Run training routine
...@@ -63,15 +55,18 @@ section. ...@@ -63,15 +55,18 @@ section.
```bash ```bash
cd declearn/examples/mnist/ cd declearn/examples/mnist/
python run.py # note: python declearn/examples/mnist/run.py works as well python run_demo.py # note: python declearn/examples/mnist/run.py works as well
``` ```
The `run.py` scripts collects the server and client routines defined under The `run_demo.py` scripts collects the server and client routines defined under
the `server.py` and `client.py` scripts, and runs them concurrently under the `run_server.py` and `run_client.py` scripts, and runs them concurrently
a single python session using multiprocessing. under a single python session using multiprocessing. It also prepares the data
by calling the `prepare_data.py` script, passing along input arguments, which
users are encouraged to play with - notably to vary the number of clients and
the data splitting scheme.
This is the easiest way to launch the demo, e.g. to see the effects of This is the easiest way to launch the demo, e.g. to see the effects of
tweaking some learning parameters. tweaking some learning parameters (by editing the `run_server.py` script).
### On separate terminals or machines ### On separate terminals or machines
...@@ -81,45 +76,75 @@ and the machines can communicate over network using SSL-encrypted ...@@ -81,45 +76,75 @@ and the machines can communicate over network using SSL-encrypted
communications. We give the code to simulate this on a single machine. communications. We give the code to simulate this on a single machine.
We then sequentially run the server then the clients on separate terminals. We then sequentially run the server then the clients on separate terminals.
1. **Set up SSL certificates**:<br/> 1. **Prepare the data**:<br/>
Start by creating a signed SSL certificate for the server and sharing the First, clients' data should be generated, by fetching and splitting the
CA file with each and every clients. The CA may be self-signed. MNIST dataset. This can be done in any way you want, but a practical and
easy one is to use the `prepare_data.py` script. Data may either be
prepared at a single location and then shared across clients (in the case
when distinct computers are used), or prepared redundantly at each place
using the same random seed and agreeing on clients' ordering.
To use the `prepare_data.py` script, simply run:
```bashand `SEED` may be
any int
python prepare_data.py <NB_CLIENTS> [--scheme=SCHEME] [--seed=SEED]
```
where `SCHEME` must be in `{"iid", "labels", "biased"}`.
Alternatively, you may use the `declearn-split` command-line utility, with
similar arguments:
```bash
declearn-split --n_shards=<NB_CLIENTS> [--scheme=SCHEME] [--seed=SEED]
```
2. **Set up SSL certificates**:<br/>
Create a signed SSL certificate for the server and share the CA file that
signed it with each and every clients. That CA may be self-signed.
When testing locally, execute the `gen_ssl.py` script, to create a When testing locally, execute the `generate_ssl.py` script, to create a
self-signed root CA and an SSL certificate for "localhost": self-signed root CA and an SSL certificate for "localhost":
```bash ```bash
python gen_ssl.py python generate_ssl.py
``` ```
Note that in real-life applications, one would most likely use certificates Note that in real-life applications, one would most likely use certificates
certificates signed by a trusted certificate authority instead. certificates signed by a trusted certificate authority instead.
Alternatively, `declearn.test_utils.gen_ssl_certificates` may be used to Alternatively, `declearn.test_utils.gen_ssl_certificates` may be used to
generate a self-signed CA and a signed certificate for a given domain name generate a self-signed CA and a signed certificate for a given domain name
or IP address. or IP address.
2. **Run the server**:<br/> 3. **Run the server**:<br/>
Open a terminal and launch the server script for 1 to 4 clients, Open a terminal and launch the server script for the desired number of
specifying the path to the SSL certificate and private key files, clients, specifying the path to the SSL certificate and private key files,
and network parameters. By default, things will run on the local and network parameters. By default, things will run on the local
host, looking for `gen_ssl.py`-created PEM files. host, looking for `generate_ssl.py`-created PEM files.
E.g., to use 2 clients: E.g., to use 2 clients:
```bash ```bash
python server.py 2 # use --help for details on network and SSL options python run_server.py 2 # use --help for details on network and SSL options
``` ```
Note that you may edit that script to change the model learned, the FL
and optimization algorithms used, and/or the training hyper-parameters,
including the introduction of sample-level differential privacy.
3. **Run each client**:<br/> 3. **Run each client**:<br/>
Open a new terminal and launch the client script, specifying one of the Open a new terminal and launch the client script, specifying the path to
dataset-provider names, and optionally the path the CA file and network the main data folder (e.g. "data/mnist_iid") and the client's name (e.g.
parameters. By default, things will run on the local host, looking for "client_0"), which are both used to determine where to get the prepared
a `gen_ssl.py`-created CA PEM file. data. Additional network parameters may also be passed; by default, things
will run on the localhost, looking for a `generate_ssl.py`-created CA PEM
file.
E.g., to launch a client using the "cleveland" dataset: E.g., to launch the first client after preparing iid-split data with the
`prepara_data.py` script, call:
```bash ```bash
python client.py cleveland # use --help for details on other options python run_client.py client_0 data/mnist_iid
# use --help for details on options
``` ```
Note that the server should be launched before the clients, otherwise the Note that the server should be launched before the clients, otherwise the
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Script to run a federated client on the heart-disease example.""" """Script to run a federated client on the MNIST example."""
import datetime import datetime
import logging import logging
...@@ -27,13 +27,14 @@ import declearn ...@@ -27,13 +27,14 @@ import declearn
import declearn.model.tensorflow import declearn.model.tensorflow
FILEDIR = os.path.dirname(__file__) FILEDIR = os.path.dirname(os.path.abspath(__file__))
DEFAULT_CERT = os.path.join(FILEDIR, "ca-cert.pem")
def run_client( def run_client(
client_name: str, client_name: str,
ca_cert: str,
data_folder: str, data_folder: str,
ca_cert: str = DEFAULT_CERT,
protocol: str = "websockets", protocol: str = "websockets",
serv_uri: str = "wss://localhost:8765", serv_uri: str = "wss://localhost:8765",
verbose: bool = True, verbose: bool = True,
...@@ -44,16 +45,16 @@ def run_client( ...@@ -44,16 +45,16 @@ def run_client(
--------- ---------
client_name: str client_name: str
Name of the client (i.e. center data from which to use). Name of the client (i.e. center data from which to use).
ca_cert: str
Path to the certificate authority file that was used to
sign the server's SSL certificate.
data_folder: str data_folder: str
The parent folder of this client's data The parent folder of this client's data
ca_cert: str, default="./ca-cert.pem"
Path to the certificate authority file that was used to
sign the server's SSL certificate.
protocol: str, default="websockets" protocol: str, default="websockets"
Name of the communication protocol to use. Name of the communication protocol to use.
serv_uri: str, default="wss://localhost:8765" serv_uri: str, default="wss://localhost:8765"
URI of the server to which to connect. URI of the server to which to connect.
verbose: verbose: bool, default=True
Whether to log everything to the console, or filter out most non-error Whether to log everything to the console, or filter out most non-error
information. information.
""" """
...@@ -94,7 +95,7 @@ def run_client( ...@@ -94,7 +95,7 @@ def run_client(
### (3) Define network communication parameters. ### (3) Define network communication parameters.
# Here, use websockets protocol on localhost:8765, # Here, by default, use websockets protocol on localhost:8765,
# with SSL encryption. # with SSL encryption.
network = declearn.communication.build_client( network = declearn.communication.build_client(
protocol=protocol, protocol=protocol,
...@@ -104,7 +105,7 @@ def run_client( ...@@ -104,7 +105,7 @@ def run_client(
) )
### (4) Run any necessary import statement. ### (4) Run any necessary import statement.
# We imported `import declearn.model.tensorflow` # We imported `import declearn.model.tensorflow`.
### (5) Instantiate a FederatedClient and run it. ### (5) Instantiate a FederatedClient and run it.
......
...@@ -15,11 +15,11 @@ ...@@ -15,11 +15,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Demonstration script using the UCI Heart Disease Dataset.""" """Demonstration script using the MNIST dataset."""
import glob
import os import os
import tempfile import tempfile
from typing import Literal, Optional
import fire # type: ignore import fire # type: ignore
...@@ -29,15 +29,17 @@ from declearn.utils import run_as_processes ...@@ -29,15 +29,17 @@ from declearn.utils import run_as_processes
# Perform local imports. # Perform local imports.
# pylint: disable=wrong-import-position, wrong-import-order # pylint: disable=wrong-import-position, wrong-import-order
with make_importable(os.path.dirname(__file__)): with make_importable(os.path.dirname(__file__)):
from client import run_client from prepare_data import prepare_mnist
from server import run_server from run_client import run_client
from run_server import run_server
# pylint: enable=wrong-import-position, wrong-import-order # pylint: enable=wrong-import-position, wrong-import-order
FILEDIR = os.path.join(os.path.dirname(__file__))
DATADIR = glob.glob(f"{FILEDIR}/data*")[0]
def run_demo(
def run_demo(nb_clients: int = 3, data_folder: str = DATADIR) -> None: nb_clients: int = 3,
scheme: Literal["iid", "labels", "biased"] = "iid",
seed: Optional[int] = None,
) -> None:
"""Run a server and its clients using multiprocessing. """Run a server and its clients using multiprocessing.
Parameters Parameters
...@@ -49,14 +51,19 @@ def run_demo(nb_clients: int = 3, data_folder: str = DATADIR) -> None: ...@@ -49,14 +51,19 @@ def run_demo(nb_clients: int = 3, data_folder: str = DATADIR) -> None:
Relative path to the folder holding client's data Relative path to the folder holding client's data
""" """
# Generate the MNIST split data for this demo.
data_folder = prepare_mnist(nb_clients, scheme, seed=seed)
# Use a temporary directory for single-use self-signed SSL files. # Use a temporary directory for single-use self-signed SSL files.
with tempfile.TemporaryDirectory() as folder: with tempfile.TemporaryDirectory() as folder:
# Generate self-signed SSL certificates and gather their paths. # Generate self-signed SSL certificates and gather their paths.
ca_cert, sv_cert, sv_pkey = generate_ssl_certificates(folder) ca_cert, sv_cert, sv_pkey = generate_ssl_certificates(folder)
# Specify the server and client routines that need executing. # Specify the server and client routines that need executing.
server = (run_server, (nb_clients, sv_cert, sv_pkey)) server = (run_server, (nb_clients, sv_cert, sv_pkey))
client_args = tuple(
[data_folder, ca_cert, "websockets", "wss://localhost:8765", False]
)
clients = [ clients = [
(run_client, (f"client_{idx}", ca_cert, data_folder)) (run_client, (f"client_{idx}", *client_args))
for idx in range(nb_clients) for idx in range(nb_clients)
] ]
# Run routines in isolated processes. Raise if any failed. # Run routines in isolated processes. Raise if any failed.
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Script to run a federated server on the heart-disease example.""" """Script to run a federated server on the MNIST example."""
import datetime import datetime
import os import os
...@@ -24,6 +24,7 @@ import fire # type: ignore ...@@ -24,6 +24,7 @@ import fire # type: ignore
import tensorflow as tf # type: ignore import tensorflow as tf # type: ignore
import declearn import declearn
import declearn.model.tensorflow
FILEDIR = os.path.dirname(os.path.abspath(__file__)) FILEDIR = os.path.dirname(os.path.abspath(__file__))
...@@ -82,8 +83,7 @@ def run_server( ...@@ -82,8 +83,7 @@ def run_server(
### (1) Define a model ### (1) Define a model
# Here we use a scikit-learn SGD classifier and parametrize it # Here we use a tensorflow-implemented small Convolutional Neural Network.
# into a L2-penalized binary logistic regression.
stack = [ stack = [
tf.keras.layers.InputLayer(input_shape=(28, 28, 1)), tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),
tf.keras.layers.Conv2D(32, 3, 1, activation="relu"), tf.keras.layers.Conv2D(32, 3, 1, activation="relu"),
...@@ -128,7 +128,7 @@ def run_server( ...@@ -128,7 +128,7 @@ def run_server(
### (3) Define network communication parameters. ### (3) Define network communication parameters.
# Here, use websockets protocol on localhost:8765, with SSL encryption. # Use user-provided parameters (or default WSS on localhost:8765).
network = declearn.communication.build_server( network = declearn.communication.build_server(
protocol=protocol, protocol=protocol,
host=host, host=host,
...@@ -161,7 +161,7 @@ def run_server( ...@@ -161,7 +161,7 @@ def run_server(
batch_size=32, batch_size=32,
n_epoch=1, n_epoch=1,
) )
# Evaluation rounds. by default, 1 epoch with train's batch size. # Evaluation rounds. By default, 1 epoch with train's batch size.
evaluate = declearn.main.config.EvaluateConfig( evaluate = declearn.main.config.EvaluateConfig(
batch_size=128, batch_size=128,
) )
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment