Mentions légales du service

Skip to content
Snippets Groups Projects
Verified Commit 8ac8411a authored by BIGAUD Nathan's avatar BIGAUD Nathan Committed by ANDREY Paul
Browse files

Local MNIST example

parent 75c1ec35
No related branches found
No related tags found
1 merge request!41Quickrun mode
# coding: utf-8
# Copyright 2023 Inria (Institut National de Recherche en Informatique
# et Automatique)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script to run a federated client on the MNIST example."""
import datetime
import logging
import os
import declearn
import declearn.model.tensorflow
import fire
FILEDIR = os.path.dirname(__file__)
def run_client(
    client_name: str,
    ca_cert: str,
    data_folder: str,
    protocol: str = "websockets",
    serv_uri: str = "wss://localhost:8765",
    verbose: bool = True,
) -> None:
    """Instantiate and run a given client.

    Parameters
    ----------
    client_name: str
        Name of the client (i.e. center data from which to use).
    ca_cert: str
        Path to the certificate authority file that was used to
        sign the server's SSL certificate.
    data_folder: str
        The parent folder of this client's data. It is joined onto this
        script's folder, so an absolute path is used as-is.
    protocol: str, default="websockets"
        Name of the communication protocol to use.
    serv_uri: str, default="wss://localhost:8765"
        URI of the server to which to connect.
    verbose: bool, default=True
        Whether to log everything to the console, or filter out most
        non-error information.
    """
    ### Optional: some convenience settings
    # Set CPU as device
    declearn.utils.set_device_policy(gpu=False)
    # Set up logger and checkpointer. The timestamp is floored to a multiple
    # of five minutes (presumably so that processes launched around the same
    # time write under the same `result_<stamp>` folder).
    now = datetime.datetime.now()
    floored = now - datetime.timedelta(
        minutes=now.minute % 5,
        seconds=now.second,
        microseconds=now.microsecond,
    )
    stamp = floored.strftime("%y-%m-%d_%H-%M")
    checkpoint = os.path.join(FILEDIR, f"result_{stamp}", client_name)
    logger = declearn.utils.get_logger(
        name=client_name,
        fpath=os.path.join(checkpoint, "logs.txt"),
    )
    # Reduce console verbosity only. `logging.FileHandler` is a subclass of
    # `logging.StreamHandler`, so file handlers are excluded explicitly to
    # keep the log file complete.
    if not verbose:
        for handler in logger.handlers:
            is_stream = isinstance(handler, logging.StreamHandler)
            is_file = isinstance(handler, logging.FileHandler)
            if is_stream and not is_file:
                handler.setLevel(declearn.utils.LOGGING_LEVEL_MAJOR)

    ### (1-2) Interface training and optional validation data.
    # Target the proper dataset (specific to our MNIST setup).
    data_folder = os.path.join(FILEDIR, data_folder, client_name)
    # Interface the data through the generic `InMemoryDataset` class.
    train = declearn.dataset.InMemoryDataset(
        os.path.join(data_folder, "train_data.npy"),
        os.path.join(data_folder, "train_target.npy"),
    )
    valid = declearn.dataset.InMemoryDataset(
        os.path.join(data_folder, "valid_data.npy"),
        os.path.join(data_folder, "valid_target.npy"),
    )

    ### (3) Define network communication parameters.
    # By default, use the websockets protocol on localhost:8765,
    # with SSL encryption.
    network = declearn.communication.build_client(
        protocol=protocol,
        server_uri=serv_uri,
        name=client_name,
        certificate=ca_cert,
    )

    ### (4) Run any necessary import statement.
    # `declearn.model.tensorflow` was imported at the top of this file.

    ### (5) Instantiate a FederatedClient and run it.
    client = declearn.main.FederatedClient(
        netwk=network,
        train_data=train,
        valid_data=valid,
        checkpoint=checkpoint,
        logger=logger,
    )
    client.run()
# This part should not be altered: it provides an argument parser
# for `python client.py`.
def main():
    """Expose `run_client` as a command-line interface through `fire`."""
    fire.Fire(run_client)


# Only parse command-line arguments when executed as a script.
if __name__ == "__main__":
    main()
# coding: utf-8
# Copyright 2023 Inria (Institut National de Recherche en Informatique
# et Automatique)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script to generate self-signed SSL certificates for the demo."""
import os
from declearn.test_utils import generate_ssl_certificates
if __name__ == "__main__":
    # Generate the SSL files into the folder that holds this script.
    generate_ssl_certificates(os.path.dirname(os.path.abspath(__file__)))
# Demo training task : MNIST
## Overview
**We are going to train a common model between three simulated clients on the
classic [MNIST dataset](http://yann.lecun.com/exdb/mnist/)**. The input of the
model is a set of images of handwritten digits, and the model needs to
determine which number between $0$ and $9$ each image corresponds to.
## Setup
To be able to experiment with this tutorial:
* Clone the declearn repo, on the experimental branch:
```bash
git clone -b experimental git@gitlab.inria.fr:magnet/declearn/declearn2.git declearn
```
* Create a dedicated virtual environment.
* Install declearn in it from the local repo:
```bash
cd declearn && pip install .[websockets,tensorflow] && cd ..
```
In an FL experiment, we consider your data as a given. So before running
the experiment below, split the MNIST data using :
```bash
declearn-split --folder "examples/mnist" --n_shards 3
```
## Contents
This script runs a FL experiment using MNIST. The folder is structured
the following way:
```
mnist/
│ client.py - set up and launch a federated-learning client
│ gen_ssl.py - generate self-signed ssl certificates
│ run.py - launch both the server and clients in a single session
│ server.py - set up and launch a federated-learning server
└─── data - data split by client, created with the `split_data` util
└─── results - saved results from training procedure
```
## Run training routine
The simplest way to run the demo is to run it locally, using multiprocessing.
For something closer to real life implementation, we also show a way to run
the demo from different terminals or machines.
### Locally, for testing and experimentation
**To simply run the demo**, use the bash command below. You can follow along
with the code in the `hands-on` section of the package documentation. For more
details on what running the federated learning processes implies, see the last
section.
```bash
cd examples/mnist
python run.py
```
Use :
```bash
python run.py # note: python examples/mnist/run.py works as well
```
The `run.py` script collects the server and client routines defined under
the `server.py` and `client.py` scripts, and runs them concurrently under
a single python session using multiprocessing.
This is the easiest way to launch the demo, e.g. to see the effects of
tweaking some learning parameters.
### On separate terminals or machines
**To run the examples from different terminals or machines**,
we first ensure data is appropriately distributed between machines,
and the machines can communicate over network using SSL-encrypted
communications. We give the code to simulate this on a single machine.
We then sequentially run the server then the clients on separate terminals.
1. **Set up SSL certificates**:<br/>
Start by creating a signed SSL certificate for the server and sharing the
CA file with each and every client. The CA may be self-signed.
When testing locally, execute the `gen_ssl.py` script, to create a
self-signed root CA and an SSL certificate for "localhost":
```bash
python gen_ssl.py
```
Note that in real-life applications, one would most likely use
certificates signed by a trusted certificate authority instead.
Alternatively, `declearn.test_utils.generate_ssl_certificates` may be used to
generate a self-signed CA and a signed certificate for a given domain name
or IP address.
2. **Run the server**:<br/>
Open a terminal and launch the server script for 1 to 4 clients,
specifying the path to the SSL certificate and private key files,
and network parameters. By default, things will run on the local
host, looking for `gen_ssl.py`-created PEM files.
E.g., to use 2 clients:
```bash
python server.py 2 # use --help for details on network and SSL options
```
3. **Run each client**:<br/>
Open a new terminal and launch the client script, specifying the client's
name, the path to the CA file and the folder holding the split data, and
optionally network parameters. By default, things will run on the local host.
E.g., to launch the client named "client_0" (replace the placeholders):
```bash
python client.py client_0 <path/to/ca/file> <data_folder> # use --help for details on other options
```
```
Note that the server should be launched before the clients, otherwise the
latter might fail to connect which would cause the script to terminate. A
few seconds' delay is tolerable as clients will make multiple connection
attempts prior to failing.
**To run the example in a real-life setting**, follow the instructions from
this section, after having generated and shared the appropriate PEM files to
set up SSL-encryption, and using additional script parameters to specify the
network host and port to use.
# coding: utf-8
# Copyright 2023 Inria (Institut National de Recherche en Informatique
# et Automatique)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Demonstration script using the MNIST dataset."""
import glob
import os
import tempfile
import fire
from declearn.test_utils import generate_ssl_certificates, make_importable
from declearn.utils import run_as_processes
# Perform local imports.
# pylint: disable=wrong-import-position, wrong-import-order
with make_importable(os.path.dirname(__file__)):
from client import run_client
from server import run_server
# pylint: enable=wrong-import-position, wrong-import-order
# Absolute path to the folder holding this script. Using an absolute path
# avoids globbing "/data*" (the filesystem root) when `__file__` is a bare
# relative file name, in which case `os.path.dirname` returns "".
FILEDIR = os.path.dirname(os.path.abspath(__file__))
# Default data folder: first entry (in sorted order, for determinism) of
# FILEDIR whose name starts with "data". An IndexError is raised at import
# time when there is none, i.e. when the data was not split beforehand.
DATADIR = sorted(glob.glob(os.path.join(FILEDIR, "data*")))[0]
def run_demo(nb_clients: int = 3, data_folder: str = DATADIR) -> None:
    """Run a server and its clients using multiprocessing.

    Parameters
    ----------
    nb_clients: int
        Number of clients to run.
    data_folder: str
        Path to the folder holding the clients' data. Each client reads
        from the sub-folder bearing its name ("client_0", "client_1", ...).

    Raises
    ------
    RuntimeError
        If the server or any of the client routines failed.
    """
    # Use a temporary directory for single-use self-signed SSL files.
    with tempfile.TemporaryDirectory() as folder:
        # Generate self-signed SSL certificates and gather their paths.
        ca_cert, sv_cert, sv_pkey = generate_ssl_certificates(folder)
        # Specify the server and client routines that need executing.
        server = (run_server, (nb_clients, sv_cert, sv_pkey))
        clients = [
            (run_client, (f"client_{idx}", ca_cert, data_folder))
            for idx in range(nb_clients)
        ]
        # Run routines in isolated processes. Raise if any failed.
        success, outp = run_as_processes(server, *clients)
        if not success:
            # Only `RuntimeError` outputs are reported. The explicit `+`
            # matters: adjacent literals would be merged into a single
            # string and used as the separator of `str.join`.
            caught = "\n".join(
                str(exc) for exc in outp if isinstance(exc, RuntimeError)
            )
            raise RuntimeError(
                "Something went wrong during the demo. Exceptions caught:\n"
                + caught
            )


if __name__ == "__main__":
    fire.Fire(run_demo)
# coding: utf-8
# Copyright 2023 Inria (Institut National de Recherche en Informatique
# et Automatique)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script to run a federated server on the MNIST example."""
import datetime
import os
import fire
import tensorflow as tf # type: ignore
import declearn
FILEDIR = os.path.dirname(os.path.abspath(__file__))
DEFAULT_CERT = os.path.join(FILEDIR, "server-cert.pem")
DEFAULT_PKEY = os.path.join(FILEDIR, "server-pkey.pem")
def run_server(
    nb_clients: int,
    certificate: str = DEFAULT_CERT,
    private_key: str = DEFAULT_PKEY,
    protocol: str = "websockets",
    host: str = "localhost",
    port: int = 8765,
) -> None:
    """Instantiate and run the orchestrating server.

    Arguments
    ---------
    nb_clients: int
        Exact number of clients used in this example.
    certificate: str
        Path to the (self-signed) SSL certificate to use.
    private_key: str
        Path to the associated private-key to use.
    protocol: str, default="websockets"
        Name of the communication protocol to use.
    host: str, default="localhost"
        Hostname or IP address on which to serve.
    port: int, default=8765
        Communication port on which to serve.
    """
    ### Optional: some convenience settings
    # Set CPU as device
    declearn.utils.set_device_policy(gpu=False)
    # Set up metrics suitable for MNIST (ten classes, labeled 0 to 9).
    metrics = declearn.metrics.MetricSet(
        [
            declearn.metrics.MulticlassAccuracyPrecisionRecall(
                labels=range(10)
            ),
        ]
    )
    # Set up checkpointing and logging. The timestamp is floored to a
    # multiple of five minutes (presumably so that processes launched around
    # the same time write under the same `result_<stamp>` folder).
    now = datetime.datetime.now()
    floored = now - datetime.timedelta(
        minutes=now.minute % 5,
        seconds=now.second,
        microseconds=now.microsecond,
    )
    stamp = floored.strftime("%y-%m-%d_%H-%M")
    checkpoint = os.path.join(FILEDIR, f"result_{stamp}", "server")
    # Set up a logger, records from which will go to a file.
    logger = declearn.utils.get_logger(
        name="Server",
        fpath=os.path.join(checkpoint, "logs.txt"),
    )

    ### (1) Define a model
    # Here we use a small Keras convolutional network: a single convolution
    # layer followed by max-pooling, dropout and a 10-class softmax output.
    stack = [
        tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),
        tf.keras.layers.Conv2D(32, 3, 1, activation="relu"),
        tf.keras.layers.MaxPool2D(2),
        tf.keras.layers.Dropout(0.25),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10, activation="softmax"),
    ]
    model = declearn.model.tensorflow.TensorflowModel(
        model=tf.keras.Sequential(stack),
        loss="sparse_categorical_crossentropy",
    )

    ### (2) Define an optimization strategy
    # Set up the client updates' aggregator. By default: FedAvg.
    aggregator = declearn.aggregator.AveragingAggregator()
    # Set up the server-side optimizer (to refine aggregated updates).
    # By default: no refinement (lrate=1.0, no plug-ins).
    server_opt = declearn.optimizer.Optimizer(
        lrate=1.0,
        w_decay=0.0,
        modules=None,
    )
    # Set up the client-side optimizer (for local SGD steps).
    # By default: vanilla SGD, with a selected learning rate.
    client_opt = declearn.optimizer.Optimizer(
        lrate=0.001,
        w_decay=0.0,
        regularizers=None,
        modules=None,
    )
    # Wrap all this into a FLOptimConfig.
    optim = declearn.main.config.FLOptimConfig.from_params(
        aggregator=aggregator,
        server_opt=server_opt,
        client_opt=client_opt,
    )

    ### (3) Define network communication parameters.
    # By default, use the websockets protocol on localhost:8765,
    # with SSL encryption.
    network = declearn.communication.build_server(
        protocol=protocol,
        host=host,
        port=port,
        certificate=certificate,
        private_key=private_key,
    )

    ### (4) Instantiate and run a FederatedServer.
    # Instantiate
    server = declearn.main.FederatedServer(
        model=model,
        netwk=network,
        optim=optim,
        metrics=metrics,
        checkpoint=checkpoint,
        logger=logger,
    )
    # Set up the experiment's hyper-parameters.
    # Registration rules: expect exactly `nb_clients` clients, and wait for
    # at most 10 seconds at registration.
    # NOTE(review): registration is expected to fail when fewer clients
    # join in time - confirm against declearn's `RegisterConfig` docs.
    register = declearn.main.config.RegisterConfig(
        min_clients=nb_clients,
        max_clients=nb_clients,
        timeout=10,
    )
    # Training rounds hyper-parameters: here, 1 epoch / round.
    training = declearn.main.config.TrainingConfig(
        batch_size=32,
        n_epoch=1,
    )
    # Evaluation rounds hyper-parameters: here, a batch size of 128.
    evaluate = declearn.main.config.EvaluateConfig(
        batch_size=128,
    )
    # Wrap all this into a FLRunConfig.
    run_config = declearn.main.config.FLRunConfig.from_params(
        rounds=5,  # you may change the number of training rounds
        register=register,
        training=training,
        evaluate=evaluate,
        privacy=None,  # you may set up local DP (DP-SGD) here
        early_stop=None,  # you may add an early-stopping criterion here
    )
    server.run(run_config)
# This part should not be altered: it provides an argument parser
# for `python server.py`.
def main():
    """Expose `run_server` as a command-line interface through `fire`."""
    fire.Fire(run_server)


# Only parse command-line arguments when executed as a script.
if __name__ == "__main__":
    main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment