From cf016b5f28f761536dc87ce3cc9bf88f662caa16 Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Tue, 10 Oct 2023 16:02:17 +0200
Subject: [PATCH] Improve 'mnist_quickrun' example documentation and add torch
 variant.

---
 docs/quickstart.md                     |  5 ++-
 examples/mnist_quickrun/config.toml    | 10 +++---
 examples/mnist_quickrun/mnist.ipynb    | 11 +++++--
 examples/mnist_quickrun/model.py       | 18 +++++++++++
 examples/mnist_quickrun/model_torch.py | 42 ++++++++++++++++++++++++++
 5 files changed, 78 insertions(+), 8 deletions(-)
 create mode 100644 examples/mnist_quickrun/model_torch.py

diff --git a/docs/quickstart.md b/docs/quickstart.md
index f2c4d8f6..71981ce3 100644
--- a/docs/quickstart.md
+++ b/docs/quickstart.md
@@ -288,7 +288,10 @@ where to report it. An example:
 
 ```python
 [experiment]
-metrics=[["multi-classif",{labels = [0,1,2,3,4,5,6,7,8,9]}]] # Accuracy metric
+metrics = [
+    # Multi-class Accuracy, Precision, Recall and F1-Score.
+    ["multi-classif", {labels = [0,1,2,3,4,5,6,7,8,9]}]
+]
 checkpoint = "./result_custom" # Custom location for results
 ```
 
diff --git a/examples/mnist_quickrun/config.toml b/examples/mnist_quickrun/config.toml
index b01422a7..ed48c7d2 100644
--- a/examples/mnist_quickrun/config.toml
+++ b/examples/mnist_quickrun/config.toml
@@ -1,6 +1,6 @@
 # This is a minimal TOML file for the MNIST example
 # It contains the bare minimum to make the experiment run.
-# See quickstart for more details. 
+# See quickstart for more details.
 
 # The TOML is parsed by python as dictionnary with each `[header]`
 # as a key. Note the "=" sign and the absence of quotes around keys.
@@ -12,7 +12,7 @@
     port = 8765 # Port used, works as-is on most set ups
 
 [data] # Where to find your data
-    data_folder = "examples/mnist_quickrun/data_iid" 
+    data_folder = "examples/mnist_quickrun/data_iid"
 
 [optim] # Optimization options for both client and server
     aggregator = "averaging" # Server aggregation strategy
@@ -37,5 +37,7 @@
     batch_size = 128 # Evaluation batch size
 
 [experiment] # What to report during the experiment and where to report it
-    metrics=[["multi-classif",{labels = [0,1,2,3,4,5,6,7,8,9]}]] # Accuracy metric
-
+    metrics = [
+        # Multi-class Accuracy, Precision, Recall and F1-Score.
+        ["multi-classif", {labels = [0,1,2,3,4,5,6,7,8,9]}]
+    ]
diff --git a/examples/mnist_quickrun/mnist.ipynb b/examples/mnist_quickrun/mnist.ipynb
index d4dc2e62..b44a6eb7 100644
--- a/examples/mnist_quickrun/mnist.ipynb
+++ b/examples/mnist_quickrun/mnist.ipynb
@@ -23,7 +23,9 @@
     "id": "Clzf4NTja121"
    },
    "source": [
-    "We first clone the repo, to have both the package itself and the `examples` folder we will use in this tutorial, then naviguate to the package directory, and finally install the required dependencies"
+    "We first clone the repo to have both the package itself and the `examples` folder we will use in this tutorial, then navigate to the package directory, and finally install the required dependencies.\n",
+    "\n",
+    "**If you have already cloned the repository and/or installed declearn, you may skip the following commands.** Simply make sure that your current working directory is the folder that contains the `examples/mnist_quickrun` subfolder (as cloned or downloaded from the repo)."
    ]
   },
   {
@@ -100,7 +102,9 @@
    "source": [
     "## The model\n",
     "\n",
-    "To do this, we will use a simple CNN, defined in `examples/mnist_quickrun/model.py`"
+    "To do this, we will use a simple CNN, defined in `examples/mnist_quickrun/model.py`.\n",
+    "\n",
+    "Here, the model is implemented in TensorFlow, but this is merely an implementation detail. If you update the `config.toml` file to point to `examples/mnist_quickrun/model_torch.py` instead, you will train a model with the same architecture, implemented in Torch."
    ]
   },
   {
@@ -147,7 +151,8 @@
    ],
    "source": [
     "from examples.mnist_quickrun.model import network\n",
-    "network.summary()"
+    "\n",
+    "network.summary()  # network is a `tensorflow.keras.Model` instance"
    ]
   },
   {
diff --git a/examples/mnist_quickrun/model.py b/examples/mnist_quickrun/model.py
index 01dd16a0..6e0e7616 100644
--- a/examples/mnist_quickrun/model.py
+++ b/examples/mnist_quickrun/model.py
@@ -1,9 +1,27 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 """Simple TensorFlow-backed CNN model for the MNIST quickrun example."""
 
 import tensorflow as tf
 
 from declearn.model.tensorflow import TensorflowModel
 
+
 stack = [
     tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),
     tf.keras.layers.Conv2D(8, 3, 1, activation="relu"),
diff --git a/examples/mnist_quickrun/model_torch.py b/examples/mnist_quickrun/model_torch.py
new file mode 100644
index 00000000..59ed94c7
--- /dev/null
+++ b/examples/mnist_quickrun/model_torch.py
@@ -0,0 +1,42 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Simple Torch-backed CNN model for the MNIST quickrun example."""
+
+import torch
+
+from declearn.model.torch import TorchModel
+
+
+stack = [
+    torch.nn.Unflatten(dim=0, unflattened_size=(-1, 1)),
+    torch.nn.Conv2d(1, 8, 3, 1),
+    torch.nn.ReLU(),
+    torch.nn.MaxPool2d(2),
+    torch.nn.Dropout(0.25),
+    torch.nn.Flatten(),
+    torch.nn.Linear(1352, 64),
+    torch.nn.ReLU(),
+    torch.nn.Dropout(0.5),
+    torch.nn.Linear(64, 10),
+    torch.nn.Softmax(dim=-1),
+]
+network = torch.nn.Sequential(*stack)
+
+# This needs to be called "model"; otherwise, a different name must be
+# specified via the experiment's TOML configuration file.
+model = TorchModel(network, loss=torch.nn.CrossEntropyLoss())
-- 
GitLab
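
A note on the new Torch stack: the leading `torch.nn.Unflatten(dim=0, unflattened_size=(-1, 1))` layer is what inserts the channel axis that `Conv2d` expects, and the `1352` input size of the first `Linear` layer corresponds to the 8 x 13 x 13 feature map left after convolution and pooling. The standalone snippet below (not part of the diff above) is a minimal shape check of that same stack; it assumes inputs are batched as `(batch, 28, 28)` float tensors, which is an assumption about the data layout rather than something stated in the patch, and it does not require declearn.

```python
import torch

# Same layer stack as in examples/mnist_quickrun/model_torch.py.
stack = [
    torch.nn.Unflatten(dim=0, unflattened_size=(-1, 1)),  # (N, 28, 28) -> (N, 1, 28, 28)
    torch.nn.Conv2d(1, 8, 3, 1),                          # -> (N, 8, 26, 26)
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(2),                                # -> (N, 8, 13, 13)
    torch.nn.Dropout(0.25),
    torch.nn.Flatten(),                                   # -> (N, 1352)
    torch.nn.Linear(1352, 64),
    torch.nn.ReLU(),
    torch.nn.Dropout(0.5),
    torch.nn.Linear(64, 10),
    torch.nn.Softmax(dim=-1),
]
network = torch.nn.Sequential(*stack)

# A dummy batch of 32 MNIST-shaped images (assumed input layout).
images = torch.rand(32, 28, 28)
probas = network(images)
print(probas.shape)  # torch.Size([32, 10]): one probability vector per image
```

If the printed shape is `torch.Size([32, 10])`, the stack is consistent with the notebook's claim that the Torch variant shares its architecture with the TensorFlow model defined in `model.py`.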