Mentions légales du service

Skip to content
Snippets Groups Projects
Verified Commit 63218f59 authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Lift the MNIST example files.

parent a9dd17de
No related branches found
No related tags found
1 merge request!41Quickrun mode
......@@ -21,9 +21,11 @@ import datetime
import logging
import os
import fire # type: ignore
import declearn
import declearn.model.tensorflow
import fire
FILEDIR = os.path.dirname(__file__)
......@@ -36,7 +38,6 @@ def run_client(
serv_uri: str = "wss://localhost:8765",
verbose: bool = True,
) -> None:
"""Instantiate and run a given client.
Parameters
......@@ -63,13 +64,7 @@ def run_client(
declearn.utils.set_device_policy(gpu=False)
# Set up logger and checkpointer
stamp = datetime.datetime.now()
stamp = stamp - datetime.timedelta(
minutes=stamp.minute % 5,
seconds=stamp.second,
microseconds=stamp.microsecond,
)
stamp = stamp.strftime("%y-%m-%d_%H-%M")
stamp = datetime.datetime.now().strftime("%y-%m-%d_%H-%M")
checkpoint = os.path.join(FILEDIR, f"result_{stamp}", client_name)
logger = declearn.utils.get_logger(
name=client_name,
......@@ -124,11 +119,11 @@ def run_client(
# This part should not be altered: it provides with an argument parser
# for `python client.py`).
# for `python client.py`.
def main():
"fire-wrapped split data"
"Fire-wrapped `run_client`."
fire.Fire(run_client)
......
......@@ -5,16 +5,16 @@
**We are going to train a common model between three simulated clients on the
classic [MNIST dataset](http://yann.lecun.com/exdb/mnist/)**. The input of the
model is a set of images of handwritten digits, and the model needs to
determine which number between $0$ and $9$ each image corresponds to.
determine to which digit between $0$ and $9$ each image corresponds.
## Setup
To be able to experiment with this tutorial:
* Clone the declearn repo, on the experimental branch:
* Clone the declearn repo (you may specify a given release branch or tag):
```bash
git clone -b experimental git@gitlab.inria.fr:magnet/declearn/declearn2.git declearn
git clone git@gitlab.inria.fr:magnet/declearn/declearn2.git declearn
```
* Create a dedicated virtual environment.
......@@ -25,12 +25,14 @@ cd declearn && pip install .[websockets,tensorflow] && cd ..
```
In an FL experiment, we consider your data as a given. So before running
the experiment below, split the MNIST data using :
the experiment below, download and split the MNIST data using:
```bash
declearn-split --folder "examples/mnist" --n_shards 3
declearn-split --folder "examples/mnist" --n_shards 3
```
You may add `--seed <some_number>` if you want to ensure reproducibility.
## Contents
This script runs a FL experiment using MNIST. The folder is structured
......@@ -60,14 +62,8 @@ details on what running the federated learning processes imply, see the last
section.
```bash
cd
python run.py
```
Use :
```bash
python run.py # note: python examples/heart-uci/run.py works as well
cd declearn/examples/mnist/
python run.py # note: python declearn/examples/mnist/run.py works as well
```
The `run.py` scripts collects the server and client routines defined under
......
......@@ -21,7 +21,7 @@ import glob
import os
import tempfile
import fire
import fire # type: ignore
from declearn.test_utils import generate_ssl_certificates, make_importable
from declearn.utils import run_as_processes
......
......@@ -20,11 +20,12 @@
import datetime
import os
import fire
import fire # type: ignore
import tensorflow as tf # type: ignore
import declearn
FILEDIR = os.path.dirname(os.path.abspath(__file__))
DEFAULT_CERT = os.path.join(FILEDIR, "server-cert.pem")
DEFAULT_PKEY = os.path.join(FILEDIR, "server-pkey.pem")
......@@ -71,13 +72,7 @@ def run_server(
)
# Set up checkpointing and logging.
stamp = datetime.datetime.now()
stamp = stamp - datetime.timedelta(
minutes=stamp.minute % 5,
seconds=stamp.second,
microseconds=stamp.microsecond,
)
stamp = stamp.strftime("%y-%m-%d_%H-%M")
stamp = datetime.datetime.now().strftime("%y-%m-%d_%H-%M")
checkpoint = os.path.join(FILEDIR, f"result_{stamp}", "server")
# Set up a logger, records from which will go to a file.
logger = declearn.utils.get_logger(
......@@ -155,8 +150,12 @@ def run_server(
)
# Set up the experiment's hyper-parameters.
# Registration rules: wait for 10 seconds at registration.
register = declearn.main.config.RegisterConfig(timeout=10)
# Registration rules: wait for exactly `nb_clients`, at most 5 minutes.
register = declearn.main.config.RegisterConfig(
min_clients=nb_clients,
max_clients=nb_clients,
timeout=300,
)
# Training rounds hyper-parameters. By default, 1 epoch / round.
training = declearn.main.config.TrainingConfig(
batch_size=32,
......@@ -183,7 +182,7 @@ def run_server(
def main():
"fire-wrapped split data"
"Fire-wrapped `run_server`."
fire.Fire(run_server)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment