From cf016b5f28f761536dc87ce3cc9bf88f662caa16 Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Tue, 10 Oct 2023 16:02:17 +0200
Subject: [PATCH] Improve 'mnist_quickrun' example documentation and add torch
 variant.

---
 docs/quickstart.md                     |  5 ++-
 examples/mnist_quickrun/config.toml    | 10 +++---
 examples/mnist_quickrun/mnist.ipynb    | 11 +++++--
 examples/mnist_quickrun/model.py       | 18 +++++++++++
 examples/mnist_quickrun/model_torch.py | 42 ++++++++++++++++++++++++++
 5 files changed, 78 insertions(+), 8 deletions(-)
 create mode 100644 examples/mnist_quickrun/model_torch.py

diff --git a/docs/quickstart.md b/docs/quickstart.md
index f2c4d8f6..71981ce3 100644
--- a/docs/quickstart.md
+++ b/docs/quickstart.md
@@ -288,7 +288,10 @@ where to report it. An example:
 
 ```python
 [experiment]
-metrics=[["multi-classif",{labels = [0,1,2,3,4,5,6,7,8,9]}]] # Accuracy metric
+metrics = [
+    # Multi-label Accuracy, Precision, Recall and F1-Score.
+    ["multi-classif", {labels = [0,1,2,3,4,5,6,7,8,9]}]
+]
 checkpoint = "./result_custom" # Custom location for results
 ```
 
diff --git a/examples/mnist_quickrun/config.toml b/examples/mnist_quickrun/config.toml
index b01422a7..ed48c7d2 100644
--- a/examples/mnist_quickrun/config.toml
+++ b/examples/mnist_quickrun/config.toml
@@ -1,6 +1,6 @@
 # This is a minimal TOML file for the MNIST example
 # It contains the bare minimum to make the experiment run.
-# See quickstart for more details. 
+# See quickstart for more details.
 
 # The TOML is parsed by python as dictionnary with each `[header]`
 # as a key. Note the "=" sign and the absence of quotes around keys.
@@ -12,7 +12,7 @@
     port = 8765 # Port used, works as-is on most set ups
 
 [data] # Where to find your data
-    data_folder = "examples/mnist_quickrun/data_iid" 
+    data_folder = "examples/mnist_quickrun/data_iid"
 
 [optim] # Optimization options for both client and server
     aggregator = "averaging" # Server aggregation strategy
@@ -37,5 +37,7 @@
     batch_size = 128 # Evaluation batch size
 
 [experiment] # What to report during the experiment and where to report it
-    metrics=[["multi-classif",{labels = [0,1,2,3,4,5,6,7,8,9]}]] # Accuracy metric
-
+    metrics = [
+        # Multi-label Accuracy, Precision, Recall and F-Score.
+        ["multi-classif", {labels = [0,1,2,3,4,5,6,7,8,9]}]
+    ]
diff --git a/examples/mnist_quickrun/mnist.ipynb b/examples/mnist_quickrun/mnist.ipynb
index d4dc2e62..b44a6eb7 100644
--- a/examples/mnist_quickrun/mnist.ipynb
+++ b/examples/mnist_quickrun/mnist.ipynb
@@ -23,7 +23,9 @@
     "id": "Clzf4NTja121"
    },
    "source": [
-    "We first clone the repo, to have both the package itself and the `examples` folder we will use in this tutorial, then naviguate to the package directory, and finally install the required dependencies"
+    "We first clone the repo, to have both the package itself and the `examples` folder we will use in this tutorial, then naviguate to the package directory, and finally install the required dependencies.\n",
+    "\n",
+    "**If you have already cloned the repository and/or installed declearn, you may skip the following commands.** Simply make sure to set your current working directory to the folder under which the `examples/mnist_quickrun` subfolder may be found (as cloned or downloaded from the repo)."
    ]
   },
   {
@@ -100,7 +102,9 @@
    "source": [
     "## The model\n",
     "\n",
-    "To do this, we will use a simple CNN, defined in `examples/mnist_quickrun/model.py`"
+    "To do this, we will use a simple CNN, defined in `examples/mnist_quickrun/model.py`.\n",
+    "\n",
+    "Here, the model is implemented in TensorFlow, which is merely an implementation detail. If you update the `config.toml` file to use the `examples/mnist_quickrun/model_torch.py`, you will train a model with the same architecture, but implemented with Torch."
    ]
   },
   {
@@ -147,7 +151,8 @@
    ],
    "source": [
     "from examples.mnist_quickrun.model import network\n",
-    "network.summary()"
+    "\n",
+    "network.summary()  # network is a `tensorflow.keras.Model` instance"
    ]
   },
   {
diff --git a/examples/mnist_quickrun/model.py b/examples/mnist_quickrun/model.py
index 01dd16a0..6e0e7616 100644
--- a/examples/mnist_quickrun/model.py
+++ b/examples/mnist_quickrun/model.py
@@ -1,9 +1,27 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 """Simple TensorFlow-backed CNN model for the MNIST quickrun example."""
 
 import tensorflow as tf
 
 from declearn.model.tensorflow import TensorflowModel
 
+
 stack = [
     tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),
     tf.keras.layers.Conv2D(8, 3, 1, activation="relu"),
diff --git a/examples/mnist_quickrun/model_torch.py b/examples/mnist_quickrun/model_torch.py
new file mode 100644
index 00000000..59ed94c7
--- /dev/null
+++ b/examples/mnist_quickrun/model_torch.py
@@ -0,0 +1,42 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Simple Torch-backed CNN model for the MNIST quickrun example."""
+
+import torch
+
+from declearn.model.torch import TorchModel
+
+
+stack = [
+    torch.nn.Unflatten(dim=0, unflattened_size=(-1, 1)),
+    torch.nn.Conv2d(1, 8, 3, 1),
+    torch.nn.ReLU(),
+    torch.nn.MaxPool2d(2),
+    torch.nn.Dropout(0.25),
+    torch.nn.Flatten(),
+    torch.nn.Linear(1352, 64),
+    torch.nn.ReLU(),
+    torch.nn.Dropout(0.5),
+    torch.nn.Linear(64, 10),
+    torch.nn.Softmax(dim=-1),
+]
+network = torch.nn.Sequential(*stack)
+
+# This needs to be called "model"; otherwise, a different name must be
+# specified via the experiment's TOML configuration file.
+model = TorchModel(network, loss=torch.nn.CrossEntropyLoss())
-- 
GitLab