diff --git a/docs/quickstart.md b/docs/quickstart.md index f2c4d8f681d06434a4046c1bc630858e3920a076..71981ce36c2d75b478dee1aebeffed5a82f84d81 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -288,7 +288,10 @@ where to report it. An example: ```python [experiment] -metrics=[["multi-classif",{labels = [0,1,2,3,4,5,6,7,8,9]}]] # Accuracy metric +metrics = [ + # Multi-label Accuracy, Precision, Recall and F1-Score. + ["multi-classif", {labels = [0,1,2,3,4,5,6,7,8,9]}] +] checkpoint = "./result_custom" # Custom location for results ``` diff --git a/examples/mnist_quickrun/config.toml b/examples/mnist_quickrun/config.toml index b01422a743b1631857f799686e13be08f60df34b..ed48c7d20e18d43933b5fb679fbd9c66e460ceb1 100644 --- a/examples/mnist_quickrun/config.toml +++ b/examples/mnist_quickrun/config.toml @@ -1,6 +1,6 @@ # This is a minimal TOML file for the MNIST example # It contains the bare minimum to make the experiment run. -# See quickstart for more details. +# See quickstart for more details. # The TOML is parsed by python as dictionnary with each `[header]` # as a key. Note the "=" sign and the absence of quotes around keys. @@ -12,7 +12,7 @@ port = 8765 # Port used, works as-is on most set ups [data] # Where to find your data - data_folder = "examples/mnist_quickrun/data_iid" + data_folder = "examples/mnist_quickrun/data_iid" [optim] # Optimization options for both client and server aggregator = "averaging" # Server aggregation strategy @@ -37,5 +37,7 @@ batch_size = 128 # Evaluation batch size [experiment] # What to report during the experiment and where to report it - metrics=[["multi-classif",{labels = [0,1,2,3,4,5,6,7,8,9]}]] # Accuracy metric - + metrics = [ + # Multi-label Accuracy, Precision, Recall and F-Score. + ["multi-classif", {labels = [0,1,2,3,4,5,6,7,8,9]}] + ] diff --git a/examples/mnist_quickrun/mnist.ipynb b/examples/mnist_quickrun/mnist.ipynb index d4dc2e6249b9dc33ac19ea9388b2f20d442ec799..b44a6eb714b6826d69352945a4879df6d36782a2 100644 --- a/examples/mnist_quickrun/mnist.ipynb +++ b/examples/mnist_quickrun/mnist.ipynb @@ -23,7 +23,9 @@ "id": "Clzf4NTja121" }, "source": [ - "We first clone the repo, to have both the package itself and the `examples` folder we will use in this tutorial, then naviguate to the package directory, and finally install the required dependencies" + "We first clone the repo, to have both the package itself and the `examples` folder we will use in this tutorial, then naviguate to the package directory, and finally install the required dependencies.\n", + "\n", + "**If you have already cloned the repository and/or installed declearn, you may skip the following commands.** Simply make sure to set your current working directory to the folder under which the `examples/mnist_quickrun` subfolder may be found (as cloned or downloaded from the repo)." ] }, { @@ -100,7 +102,9 @@ "source": [ "## The model\n", "\n", - "To do this, we will use a simple CNN, defined in `examples/mnist_quickrun/model.py`" + "To do this, we will use a simple CNN, defined in `examples/mnist_quickrun/model.py`.\n", + "\n", + "Here, the model is implemented in TensorFlow, which is merely an implementation detail. If you update the `config.toml` file to use the `examples/mnist_quickrun/model_torch.py`, you will train a model with the same architecture, but implemented with Torch." ] }, { @@ -147,7 +151,8 @@ ], "source": [ "from examples.mnist_quickrun.model import network\n", - "network.summary()" + "\n", + "network.summary() # network is a `tensorflow.keras.Model` instance" ] }, { diff --git a/examples/mnist_quickrun/model.py b/examples/mnist_quickrun/model.py index 01dd16a094fa00fab9f6c5336f2c7c2825c67cb2..6e0e76162bf022900d026883d61b63e24eb58c9a 100644 --- a/examples/mnist_quickrun/model.py +++ b/examples/mnist_quickrun/model.py @@ -1,9 +1,27 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Simple TensorFlow-backed CNN model for the MNIST quickrun example.""" import tensorflow as tf from declearn.model.tensorflow import TensorflowModel + stack = [ tf.keras.layers.InputLayer(input_shape=(28, 28, 1)), tf.keras.layers.Conv2D(8, 3, 1, activation="relu"), diff --git a/examples/mnist_quickrun/model_torch.py b/examples/mnist_quickrun/model_torch.py new file mode 100644 index 0000000000000000000000000000000000000000..59ed94c75ed28a14191ed55733a8006739757798 --- /dev/null +++ b/examples/mnist_quickrun/model_torch.py @@ -0,0 +1,42 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Simple Torch-backed CNN model for the MNIST quickrun example.""" + +import torch + +from declearn.model.torch import TorchModel + + +stack = [ + torch.nn.Unflatten(dim=0, unflattened_size=(-1, 1)), + torch.nn.Conv2d(1, 8, 3, 1), + torch.nn.ReLU(), + torch.nn.MaxPool2d(2), + torch.nn.Dropout(0.25), + torch.nn.Flatten(), + torch.nn.Linear(1352, 64), + torch.nn.ReLU(), + torch.nn.Dropout(0.5), + torch.nn.Linear(64, 10), + torch.nn.Softmax(dim=-1), +] +network = torch.nn.Sequential(*stack) + +# This needs to be called "model"; otherwise, a different name must be +# specified via the experiment's TOML configuration file. +model = TorchModel(network, loss=torch.nn.CrossEntropyLoss())