From 82174e1eed2766f3530ba9b89df8f1393094e48e Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Wed, 26 Apr 2023 09:35:52 +0200 Subject: [PATCH] Update mnist-quickrun example and add a readme file. --- examples/mnist_quickrun/mnist.ipynb | 1067 +++++++++++++-------------- examples/mnist_quickrun/readme.md | 31 + 2 files changed, 539 insertions(+), 559 deletions(-) create mode 100644 examples/mnist_quickrun/readme.md diff --git a/examples/mnist_quickrun/mnist.ipynb b/examples/mnist_quickrun/mnist.ipynb index 7753314d..a55dc891 100644 --- a/examples/mnist_quickrun/mnist.ipynb +++ b/examples/mnist_quickrun/mnist.ipynb @@ -1,576 +1,525 @@ { - "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This notebook is meant to be run in google colab. You can find import your local copy of the file in the the [colab welcome page](https://colab.research.google.com/)." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "s9bpLdH5ThpJ" - }, - "source": [ - "# Setting up your declearn " - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Clzf4NTja121" - }, - "source": [ - "We first clone the repo, to have both the package itself and the `examples` folder we will use in this tutorial, then naviguate to the package directory, and finally install the required dependencies" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "u2QDwb0_QQ_f", - "outputId": "cac0761c-b229-49b0-d71d-c7b5cef919b3" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Cloning into 'declearn2'...\n", - "warning: redirecting to https://gitlab.inria.fr/magnet/declearn/declearn2.git/\n", - "remote: Enumerating objects: 4997, done.\u001b[K\n", - "remote: Counting objects: 100% (79/79), done.\u001b[K\n", - "remote: Compressing objects: 100% (79/79), done.\u001b[K\n", - "remote: Total 4997 (delta 39), 
reused 0 (delta 0), pack-reused 4918\u001b[K\n", - "Receiving objects: 100% (4997/4997), 1.15 MiB | 777.00 KiB/s, done.\n", - "Resolving deltas: 100% (3248/3248), done.\n" - ] - } - ], - "source": [ - "!git clone -b experimental https://gitlab.inria.fr/magnet/declearn/declearn2" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "9kDHh_AfPG2l", - "outputId": "74e2f85f-7f93-40ae-a218-f4403470d72c" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/content/declearn2\n" - ] - } - ], - "source": [ - "cd declearn2" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Un212t1GluHB", - "outputId": "0ea67577-da6e-4f80-a412-7b7a79803aa1" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", - "Processing /content/declearn2\n", - " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", - " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", - " Installing backend dependencies ... \u001b[?25l\u001b[?25hdone\n", - " Preparing metadata (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n", - "Requirement already satisfied: cryptography>=35.0 in /usr/local/lib/python3.9/dist-packages (from declearn==2.1.0) (40.0.1)\n", - "Requirement already satisfied: scikit-learn>=1.0 in /usr/local/lib/python3.9/dist-packages (from declearn==2.1.0) (1.2.2)\n", - "Requirement already satisfied: requests~=2.18 in /usr/local/lib/python3.9/dist-packages (from declearn==2.1.0) (2.27.1)\n", - "Requirement already satisfied: pandas>=1.2 in /usr/local/lib/python3.9/dist-packages (from declearn==2.1.0) (1.5.3)\n", - "Requirement already satisfied: tomli>=2.0 in /usr/local/lib/python3.9/dist-packages (from declearn==2.1.0) (2.0.1)\n", - "Collecting fire>=0.4\n", - " Downloading fire-0.5.0.tar.gz (88 kB)\n", - "\u001b[2K \u001b[90mâ”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”\u001b[0m \u001b[32m88.3/88.3 kB\u001b[0m \u001b[31m1.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - "Requirement already satisfied: typing-extensions>=4.0 in /usr/local/lib/python3.9/dist-packages (from declearn==2.1.0) (4.5.0)\n", - "Collecting websockets~=10.1\n", - " Downloading websockets-10.4-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (106 kB)\n", - "\u001b[2K \u001b[90mâ”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”â”\u001b[0m \u001b[32m106.5/106.5 kB\u001b[0m \u001b[31m3.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: cffi>=1.12 in /usr/local/lib/python3.9/dist-packages (from cryptography>=35.0->declearn==2.1.0) (1.15.1)\n", - "Requirement already satisfied: six in /usr/local/lib/python3.9/dist-packages (from fire>=0.4->declearn==2.1.0) (1.16.0)\n", - "Requirement already satisfied: termcolor in /usr/local/lib/python3.9/dist-packages (from fire>=0.4->declearn==2.1.0) (2.2.0)\n", - "Requirement already satisfied: 
pytz>=2020.1 in /usr/local/lib/python3.9/dist-packages (from pandas>=1.2->declearn==2.1.0) (2022.7.1)\n", - "Requirement already satisfied: numpy>=1.20.3 in /usr/local/lib/python3.9/dist-packages (from pandas>=1.2->declearn==2.1.0) (1.22.4)\n", - "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.9/dist-packages (from pandas>=1.2->declearn==2.1.0) (2.8.2)\n", - "Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.9/dist-packages (from requests~=2.18->declearn==2.1.0) (2.0.12)\n", - "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.9/dist-packages (from requests~=2.18->declearn==2.1.0) (1.26.15)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.9/dist-packages (from requests~=2.18->declearn==2.1.0) (2022.12.7)\n", - "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.9/dist-packages (from requests~=2.18->declearn==2.1.0) (3.4)\n", - "Requirement already satisfied: scipy>=1.3.2 in /usr/local/lib/python3.9/dist-packages (from scikit-learn>=1.0->declearn==2.1.0) (1.10.1)\n", - "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.9/dist-packages (from scikit-learn>=1.0->declearn==2.1.0) (3.1.0)\n", - "Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.9/dist-packages (from scikit-learn>=1.0->declearn==2.1.0) (1.2.0)\n", - "Requirement already satisfied: pycparser in /usr/local/lib/python3.9/dist-packages (from cffi>=1.12->cryptography>=35.0->declearn==2.1.0) (2.21)\n", - "Building wheels for collected packages: fire, declearn\n", - " Building wheel for fire (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", - " Created wheel for fire: filename=fire-0.5.0-py2.py3-none-any.whl size=116952 sha256=ab01943c400d3267450974ec56a6572193bed40710845edd44623e56c7757799\n", - " Stored in directory: /root/.cache/pip/wheels/f7/f1/89/b9ea2bf8f80ec027a88fef1d354b3816b4d3d29530988972f6\n", - " Building wheel for declearn (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - " Created wheel for declearn: filename=declearn-2.1.0-py3-none-any.whl size=276123 sha256=4969a91ded8b704c8c9497bcda8f514f847c49098715d659cc8e96a947ec887f\n", - " Stored in directory: /tmp/pip-ephem-wheel-cache-fgkx9jiw/wheels/cc/79/79/6586306a117d40a1f8b251a22e50583b8abb2d7e855a62ecf7\n", - "Successfully built fire declearn\n", - "Installing collected packages: websockets, fire, declearn\n", - "Successfully installed declearn-2.1.0 fire-0.5.0 websockets-10.4\n" - ] - } - ], - "source": [ - "!pip install .[websockets]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hC8Fty8YTy9P" - }, - "source": [ - "# Running your first experiment" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "rcWcZJdob1IG" - }, - "source": [ - "We are going to train a common model between three simulated clients on the classic [MNIST dataset](http://yann.lecun.com/exdb/mnist/). The input of the model is a set of images of handwritten digits, and the model needs to determine which number between 0 and 9 each image corresponds to." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "KlY_vVtFHv2P" - }, - "source": [ - "## The model\n", - "\n", - "To do this, we will use a simple CNN, defined in `examples/mnist_quickrun/model.py`" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "C7D52a8_dEr7", - "outputId": "a25223f8-c8eb-4998-d7fd-4b8bfde92486" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Model: \"sequential\"\n", - "_________________________________________________________________\n", - " Layer (type) Output Shape Param # \n", - "=================================================================\n", - " conv2d (Conv2D) (None, 26, 26, 8) 80 \n", - " \n", - " max_pooling2d (MaxPooling2D (None, 13, 13, 8) 0 \n", - " ) \n", - " \n", - " dropout (Dropout) (None, 13, 13, 8) 0 \n", - " \n", - " flatten (Flatten) (None, 1352) 0 \n", - " \n", - " dense (Dense) (None, 64) 86592 \n", - " \n", - " dropout_1 (Dropout) (None, 64) 0 \n", - " \n", - " dense_1 (Dense) (None, 10) 650 \n", - " \n", - "=================================================================\n", - "Total params: 87,322\n", - "Trainable params: 87,322\n", - "Non-trainable params: 0\n", - "_________________________________________________________________\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/content/declearn2/declearn/model/tensorflow/utils/_gpu.py:66: UserWarning: Cannot use a GPU device: either CUDA is unavailable or no GPU is visible to tensorflow.\n", - " warnings.warn(\n" - ] - } - ], - "source": [ - "from examples.mnist_quickrun.model import model\n", - "model.summary()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HoBcOs9hH2QA" - }, - "source": [ - "## The data\n", - "\n", - "We start by splitting the MNIST dataset between 3 clients and storing the output in the `examples/mnist_quickrun` folder. 
For this we use an experimental utility provided by `declearn`. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "quduXkpIWFjL", - "outputId": "ddf7d45d-acf0-44ee-ce77-357c0987a2a1" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Downloading MNIST source file train-images-idx3-ubyte.gz.\n", - "Downloading MNIST source file train-labels-idx1-ubyte.gz.\n", - "Splitting data into 3 shards using the 'iid' scheme.\n" - ] - } - ], - "source": [ - "from declearn.dataset import split_data\n", - "split_data(folder=\"examples/mnist_quickrun\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "3-2hKmz-2RF4" - }, - "source": [ - "Here is what the first image of the first client looks like:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 430 - }, - "id": "MLVI9GOZ1TGd", - "outputId": "f34a6a93-cb5f-4a45-bc24-4146ea119d1a" - }, - "outputs": [ - { - "data": { - "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAbo0lEQVR4nO3df2zV9fXH8dct0itoe7ta29tKy1r8gYrUDKR2Kv6goXQJESQLikvAGJxYjMicpkZBtiXdMPPrNAz/cTATUcQJRDNJsNgStxZDlRCmVsq6UQYtwsK9pUhh7fv7B+HqhfLjc7m3597yfCQ3offe03v8eO3Ty7188DnnnAAAGGBp1gsAAC5OBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJi4xHqBU/X19Wnv3r3KyMiQz+ezXgcA4JFzTl1dXSooKFBa2plf5yRdgPbu3avCwkLrNQAAF6i9vV0jRow44+1JF6CMjAxJJxbPzMw03gYA4FU4HFZhYWHk5/mZJCxAy5Yt04svvqiOjg6Vlpbq1Vdf1YQJE845d/K33TIzMwkQAKSwc72NkpAPIaxevVoLFy7U4sWL9dlnn6m0tFSVlZXav39/Ih4OAJCCEhKgl156SXPnztVDDz2kG264Qa+99pqGDx+uP/3pT4l4OABACop7gI4dO6bm5mZVVFR89yBpaaqoqFBjY+Np9+/p6VE4HI66AAAGv7gH6MCBA+rt7VVeXl7U9Xl5eero6Djt/rW1tQoEApELn4ADgIuD+R9ErampUSgUilza29utVwIADIC4fwouJydHQ4YMUWdnZ9T1nZ2dCgaDp93f7/fL7/fHew0AQJKL+yug9PR0jRs3TnV1dZHr+vr6VFdXp/Ly8ng/HAAgRSXkzwEtXLhQs2fP1vjx4zVhwgS9/PLL6u7u1kMPPZSIhwMApKCEBGjmzJn65ptvtGjRInV0dOjmm2/Whg0bTvtgAgDg4uVzzjnrJb4vHA4rEAgoFApxJgQASEHn+3Pc/FNwAICLEwECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGDiEusFAJyfnp4ezzMHDhyI6bFef/31mOa8+vjjjz3P1NfXe57x+XyeZ2IVyz/TnXfemYBNkh+vgAAAJggQAMBE3AP0wgsvyOfzRV1Gjx4d74cBAKS4hLwHdOONN+qjjz767kEu4a0mAEC0hJThkksuUTAYTMS3BgAMEgl5D2jnzp0qKChQSUmJHnzwQe3evfuM9+3p6VE4HI66AAAGv7gHqKysTCtXrtSGDRu0fPlytbW16Y477lBXV1e/96+trVUgEIhcCgsL470SACAJxT1AVVVV+ulPf6qxY8eqsrJSf/3rX3Xo0CG98847/d6/pqZGoVAocmlvb4/3SgCAJJTwTwdkZWXp2muvVWtra7+3+/1++f3+RK8BAEgyCf9zQIcPH9auXbuUn5+f6IcCAKSQuAfoqaeeUkNDg/71r3/p73//u6ZPn64hQ4bogQceiPdDAQBSWNx/C27Pnj164IEHdPDgQV155ZW6/fbb1dTUpCuvvDLeDwUASGFxD9Dbb78d728JJLXe3l7PMxs3bvQ8s2TJEs8zn376qeeZZJeWltxnEAuFQtYrpIzk/jcJABi0CBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATCf8L6ZD8nHM
xzfX19XmeGagTScaymyR98803nmeqqqo8z2zfvt3zDAbekCFDPM8UFRUlYJPBiVdAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMMHZsKEdO3bENHfzzTd7nvnHP/7heebbb7/1PDN+/HjPM0gNGRkZnmdGjx4d02P95S9/8Txz1VVXxfRYFyNeAQEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJjgZ6SDzz3/+0/PMrFmzErBJ//bv3+95pqamJgGbXBwuvfTSmOYefPDBOG/SvyeeeMLzTFZWlucZThCanHgFBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCY4GSkg8yBAwc8z3zxxRcxPZbP5/M809TU5Hlmy5YtnmfS0mL7f6thw4bFNOfVzJkzPc/cf//9nmfKyso8z0jS5ZdfHtMc4AWvgAAAJggQAMCE5wBt3rxZU6dOVUFBgXw+n9atWxd1u3NOixYtUn5+voYNG6aKigrt3LkzXvsCAAYJzwHq7u5WaWmpli1b1u/tS5cu1SuvvKLXXntNW7Zs0WWXXabKykodPXr0gpcFAAwenj+EUFVVpaqqqn5vc87p5Zdf1nPPPad7771XkvTGG28oLy9P69ati+lNVADA4BTX94Da2trU0dGhioqKyHWBQEBlZWVqbGzsd6anp0fhcDjqAgAY/OIaoI6ODklSXl5e1PV5eXmR205VW1urQCAQuRQWFsZzJQBAkjL/FFxNTY1CoVDk0t7ebr0SAGAAxDVAwWBQktTZ2Rl1fWdnZ+S2U/n9fmVmZkZdAACDX1wDVFxcrGAwqLq6ush14XBYW7ZsUXl5eTwfCgCQ4jx/Cu7w4cNqbW2NfN3W1qZt27YpOztbRUVFWrBggX7zm9/ommuuUXFxsZ5//nkVFBRo2rRp8dwbAJDiPAdo69atuvvuuyNfL1y4UJI0e/ZsrVy5Uk8//bS6u7v1yCOP6NChQ7r99tu1YcMGXXrppfHbGgCQ8nzOOWe9xPeFw2EFAgGFQiHeD4rBrFmzPM+sXr06psfKzs72PBPLiU+//vprzzN+v9/zjCSNHz8+pjkA3znfn+Pmn4IDAFycCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYMLzX8cAnPTf//7X88zUqVM9z6xatcrzTElJiecZAAOLV0AAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAmfc85ZL/F94XBYgUBAoVBImZmZ1uuknMbGRs8zt99+ewI2iZ/hw4d7npkzZ05Mj7VkyRLPM7HsN3ToUM8zQ4YM8TwDWDjfn+O8AgIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATHAy0kEmHA57nlm+fHlMj/Xss8/GNJfMrr76as8ze/bs8Txzzz33eJ6JZbdYTZ8+3fPMrbfe6nkmPT3d8wySHycjBQAkNQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABCcjhfr6+mKae+qppzzPvPvuu55n/vOf/3iewcD7+c9/7nlm0aJFnmeCwaDnGQwsTkYKAEhqBAgAYMJzgDZv3qypU6eqoKBAPp9P69ati7p9zpw58vl8UZcpU6bEa18AwCDhOUDd3d0qLS3VsmXLznifKVOmaN++fZHLW2+9dUFLAgAGn0u8DlRVVamqquqs9/H7/bxRCAA4q4S8B1RfX6/c3Fxdd911mjdvng4ePHjG+/b09CgcDkddAACDX9wDNGXKFL3xxhuqq6v
T7373OzU0NKiqqkq9vb393r+2tlaBQCByKSwsjPdKAIAk5Pm34M7l/vvvj/z6pptu0tixYzVq1CjV19dr0qRJp92/pqZGCxcujHwdDoeJEABcBBL+MeySkhLl5OSotbW139v9fr8yMzOjLgCAwS/hAdqzZ48OHjyo/Pz8RD8UACCFeP4tuMOHD0e9mmlra9O2bduUnZ2t7OxsLVmyRDNmzFAwGNSuXbv09NNP6+qrr1ZlZWVcFwcApDbPAdq6davuvvvuyNcn37+ZPXu2li9fru3bt+vPf/6zDh06pIKCAk2ePFm//vWv5ff747c1ACDlcTJSDKhjx455nvnf//7neebNN9/0PCNJX331leeZP/zhD55nkuw/OzOvvvqq55nHHnssAZsgnjgZKQAgqREgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEZ8MGLtDWrVs9z7z77rueZ7788kvPMx9++KHnGUnq7e2Nac6r4cOHe57ZsWOH55mRI0d6nkHsOBs2ACCpESAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmLrFeAEh148ePH5CZWHz99dcxzV1//fVx3qR/R44c8TzT2dnpeYaTkSYnXgEBAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACY4GSkwiF111VXWK5xVVlaW55mSkpL4LwITvAICAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAExwMlIMqD179nieGT58uOeZ7OxszzMDqaenx/NMW1ub55nf//73nmcG0oQJEzzP5OTkJGATWOAVEADABAECAJjwFKDa2lrdcsstysjIUG5urqZNm6aWlpao+xw9elTV1dW64oordPnll2vGjBnq7OyM69IAgNTnKUANDQ2qrq5WU1OTNm7cqOPHj2vy5Mnq7u6O3OfJJ5/U+++/rzVr1qihoUF79+7VfffdF/fFAQCpzdOHEDZs2BD19cqVK5Wbm6vm5mZNnDhRoVBIr7/+ulatWqV77rlHkrRixQpdf/31ampq0q233hq/zQEAKe2C3gMKhUKSvvvEUXNzs44fP66KiorIfUaPHq2ioiI1Njb2+z16enoUDoejLgCAwS/mAPX19WnBggW67bbbNGbMGElSR0eH0tPTT/t73vPy8tTR0dHv96mtrVUgEIhcCgsLY10JAJBCYg5QdXW1duzYobfffvuCFqipqVEoFIpc2tvbL+j7AQBSQ0x/EHX+/Pn64IMPtHnzZo0YMSJyfTAY1LFjx3To0KGoV0GdnZ0KBoP9fi+/3y+/3x/LGgCAFObpFZBzTvPnz9fatWu1adMmFRcXR90+btw4DR06VHV1dZHrWlpatHv3bpWXl8dnYwDAoODpFVB1dbVWrVql9evXKyMjI/K+TiAQ0LBhwxQIBPTwww9r4cKFys7OVmZmph5//HGVl5fzCTgAQBRPAVq+fLkk6a677oq6fsWKFZozZ44k6f/+7/+UlpamGTNmqKenR5WVlfrjH/8Yl2UBAIOHzznnrJf4vnA4rEAgoFAopMzMTOt1cBaffvqp55nvf0T/fF122WWeZ4qKijzPDKSuri7PM6eedWQwaG5u9jxz8803x38RxNX5/hznXHAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwEdPfiIrB5ciRIzHN/fjHP/Y8E8vJ17u7uz3P7N+/3/MMLkxbW5vnmcLCwgRsglTBKyAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQnI4XWr18f01wsJxbFwLr11ltjmlu2bJn
nmdzcXM8zPp/P8wwGD14BAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmOBkpVFpaar3CRWfmzJkDMjN58mTPM5I0bNiwmOYAL3gFBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCY4GSk0A033BDTXG9vb5w3AXAx4RUQAMAEAQIAmPAUoNraWt1yyy3KyMhQbm6upk2bppaWlqj73HXXXfL5fFGXRx99NK5LAwBSn6cANTQ0qLq6Wk1NTdq4caOOHz+uyZMnq7u7O+p+c+fO1b59+yKXpUuXxnVpAEDq8/QhhA0bNkR9vXLlSuXm5qq5uVkTJ06MXD98+HAFg8H4bAgAGJQu6D2gUCgkScrOzo66/s0331ROTo7GjBmjmpoaHTly5Izfo6enR+FwOOoCABj8Yv4Ydl9fnxYsWKDbbrtNY8aMiVw/a9YsjRw5UgUFBdq+fbueeeYZtbS06L333uv3+9TW1mrJkiWxrgEASFE+55yLZXDevHn68MMP9cknn2jEiBFnvN+mTZs0adIktba2atSoUafd3tPTo56ensjX4XBYhYWFCoVCyszMjGU1AIChcDisQCBwzp/jMb0Cmj9/vj744ANt3rz5rPGRpLKyMkk6Y4D8fr/8fn8sawAAUpinADnn9Pjjj2vt2rWqr69XcXHxOWe2bdsmScrPz49pQQDA4OQpQNXV1Vq1apXWr1+vjIwMdXR0SJICgYCGDRumXbt2adWqVfrJT36iK664Qtu3b9eTTz6piRMnauzYsQn5BwAApCZP7wH5fL5+r1+xYoXmzJmj9vZ2/exnP9OOHTvU3d2twsJCTZ8+Xc8999x5v59zvr93CABITgl5D+hcrSosLFRDQ4OXbwkAuEhxLjgAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgIlLrBc4lXNOkhQOh403AQDE4uTP75M/z88k6QLU1dUlSSosLDTeBABwIbq6uhQIBM54u8+dK1EDrK+vT3v37lVGRoZ8Pl/UbeFwWIWFhWpvb1dmZqbRhvY4DidwHE7gOJzAcTghGY6Dc05dXV0qKChQWtqZ3+lJuldAaWlpGjFixFnvk5mZeVE/wU7iOJzAcTiB43ACx+EE6+Nwtlc+J/EhBACACQIEADCRUgHy+/1avHix/H6/9SqmOA4ncBxO4DicwHE4IZWOQ9J9CAEAcHFIqVdAAIDBgwABAEwQIACACQIEADCRMgFatmyZfvjDH+rSSy9VWVmZPv30U+uVBtwLL7wgn88XdRk9erT1Wgm3efNmTZ06VQUFBfL5fFq3bl3U7c45LVq0SPn5+Ro2bJgqKiq0c+dOm2UT6FzHYc6cOac9P6ZMmWKzbILU1tbqlltuUUZGhnJzczVt2jS1tLRE3efo0aOqrq7WFVdcocsvv1wzZsxQZ2en0caJcT7H4a677jrt+fDoo48abdy/lAjQ6tWrtXDhQi1evFifffaZSktLVVlZqf3791uvNuBuvPFG7du3L3L55JNPrFdKuO7ubpWWlmrZsmX93r506VK98soreu2117RlyxZddtllqqys1NGjRwd408Q613GQpClTpkQ9P956660B3DDxGhoaVF1draamJm3cuFHHjx/X5MmT1d3dHbnPk08+qffff19r1qxRQ0OD9u7dq/vuu89w6/g7n+MgSXPnzo16PixdutRo4zNwKWDChAmuuro68nVvb68rKChwtbW1hlsNvMWLF7vS0lLrNUxJcmvXro183dfX54LBoHvxxRcj1x06dMj5/X731ltvGWw4ME49Ds45N3v2bHfvvfea7GNl//79TpJ
raGhwzp34dz906FC3Zs2ayH2+/PJLJ8k1NjZarZlwpx4H55y788473RNPPGG31HlI+ldAx44dU3NzsyoqKiLXpaWlqaKiQo2NjYab2di5c6cKCgpUUlKiBx98ULt377ZeyVRbW5s6Ojqinh+BQEBlZWUX5fOjvr5eubm5uu666zRv3jwdPHjQeqWECoVCkqTs7GxJUnNzs44fPx71fBg9erSKiooG9fPh1ONw0ptvvqmcnByNGTNGNTU1OnLkiMV6Z5R0JyM91YEDB9Tb26u8vLyo6/Py8vTVV18ZbWWjrKxMK1eu1HXXXad9+/ZpyZIluuOOO7Rjxw5lZGRYr2eio6NDkvp9fpy87WIxZcoU3XfffSouLtauXbv07LPPqqqqSo2NjRoyZIj1enHX19enBQsW6LbbbtOYMWMknXg+pKenKysrK+q+g/n50N9xkKRZs2Zp5MiRKigo0Pbt2/XMM8+opaVF7733nuG20ZI+QPhOVVVV5Ndjx45VWVmZRo4cqXfeeUcPP/yw4WZIBvfff3/k1zfddJPGjh2rUaNGqb6+XpMmTTLcLDGqq6u1Y8eOi+J90LM503F45JFHIr++6aablJ+fr0mTJmnXrl0aNWrUQK/Zr6T/LbicnBwNGTLktE+xdHZ2KhgMGm2VHLKysnTttdeqtbXVehUzJ58DPD9OV1JSopycnEH5/Jg/f74++OADffzxx1F/fUswGNSxY8d06NChqPsP1ufDmY5Df8rKyiQpqZ4PSR+g9PR0jRs3TnV1dZHr+vr6VFdXp/LycsPN7B0+fFi7du1Sfn6+9SpmiouLFQwGo54f4XBYW7ZsueifH3v27NHBgwcH1fPDOaf58+dr7dq12rRpk4qLi6NuHzdunIYOHRr1fGhpadHu3bsH1fPhXMehP9u2bZOk5Ho+WH8K4ny8/fbbzu/3u5UrV7ovvvjCPfLIIy4rK8t1dHRYrzagfvGLX7j6+nrX1tbm/va3v7mKigqXk5Pj9u/fb71aQnV1dbnPP//cff75506Se+mll9znn3/u/v3vfzvnnPvtb3/rsrKy3Pr169327dvdvffe64qLi923335rvHl8ne04dHV1uaeeeso1Nja6trY299FHH7kf/ehH7pprrnFHjx61Xj1u5s2b5wKBgKuvr3f79u2LXI4cORK5z6OPPuqKiorcpk2b3NatW115ebkrLy833Dr+znUcWltb3a9+9Su3detW19bW5tavX+9KSkrcxIkTjTePlhIBcs65V1991RUVFbn09HQ3YcIE19TUZL3SgJs5c6bLz8936enp7qqrrnIzZ850ra2t1msl3Mcff+wknXaZPXu2c+7ER7Gff/55l5eX5/x+v5s0aZJraWmxXToBznYcjhw54iZPnuyuvPJKN3ToUDdy5Eg3d+7cQfc/af3980tyK1asiNzn22+/dY899pj7wQ9+4IYPH+6mT5/u9u3bZ7d0ApzrOOzevdtNnDjRZWdnO7/f766++mr3y1/+0oVCIdvFT8FfxwAAMJH07wEBAAYnAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMDE/wPgnA/bT9IQRgAAAABJRU5ErkJggg==", - "text/plain": [ - "<Figure size 640x480 with 1 Axes>" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "\n", - "images = np.load(\"examples/mnist_quickrun/data_iid/client_0/train_data.npy\")\n", - "sample_img = images[0]\n", - "sample_fig = 
plt.imshow(sample_img,cmap='Greys')\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1vNWNGjefSfH" - }, - "source": [ - "For more information on how the `split_data` function works, you can look at the documentation. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "-wORmq5DYfRF", - "outputId": "4d79da63-ccad-4622-e600-ac36fae1ff3f" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Randomly split a dataset into shards.\n", - "\n", - " The resulting folder structure is :\n", - " folder/\n", - " └─── data*/\n", - " └─── client*/\n", - " │ train_data.* - training data\n", - " │ train_target.* - training labels\n", - " │ valid_data.* - validation data\n", - " │ valid_target.* - validation labels\n", - " └─── client*/\n", - " │ ...\n", - "\n", - " Parameters\n", - " ----------\n", - " folder: str, default = \".\"\n", - " Path to the folder where to add a data folder\n", - " holding output shard-wise files\n", - " data_file: str or None, default=None\n", - " Optional path to a folder where to find the data.\n", - " If None, default to the MNIST example.\n", - " target_file: str or int or None, default=None\n", - " If str, path to the labels file to import. If int, column of\n", - " the data file to be used as labels. Required if data is not None,\n", - " ignored if data is None.\n", - " n_shards: int\n", - " Number of shards between which to split the data.\n", - " scheme: {\"iid\", \"labels\", \"biased\"}, default=\"iid\"\n", - " Splitting scheme(s) to use. 
In all cases, shards contain mutually-\n", - " exclusive samples and cover the full raw training data.\n", - " - If \"iid\", split the dataset through iid random sampling.\n", - " - If \"labels\", split into shards that hold all samples associated\n", - " with mutually-exclusive target classes.\n", - " - If \"biased\", split the dataset through random sampling according\n", - " to a shard-specific random labels distribution.\n", - " perc_train: float, default= 0.8\n", - " Train/validation split in each client dataset, must be in the\n", - " ]0,1] range.\n", - " seed: int or None, default=None\n", - " Optional seed to the RNG used for all sampling operations.\n", - " \n" - ] - } - ], - "source": [ - "print(split_data.__doc__)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "kZtbxlwUftKd" - }, - "source": [ - "## Quickrun\n", - "\n", - "We can now run our experiment. As explained in the section 2.1 of the [quickstart documentation](https://magnet.gitlabpages.inria.fr/declearn/docs/2.1/quickstart), using the mode `declearn-quickrun` requires a configuration file, some data, and a model:\n", - "\n", - "* A TOML file, to store your experiment configurations. Here: \n", - "`examples/mnist_quickrun/config.toml`.\n", - "* A folder with your data, split by client. Here: `examples/mnist_quickrun/data_iid`\n", - "* A model file, to store your model wrapped in a `declearn` object. 
Here: `examples/mnist_quickrun/model.py`.\n", - "\n", - "We then only have to run the `quickrun` util with the path to the TOML file:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "1n_mvTIIWpRf" - }, - "outputs": [], - "source": [ - "from declearn.quickrun import quickrun\n", - "quickrun(config=\"examples/mnist_quickrun/config.toml\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "O0kuw7UxJqKk" - }, - "source": [ - "The output obtained is the combination of the CLI output of our server and our clients, going through: \n", - "\n", - "* `INFO:Server:Starting clients registration process.` : a first registration step, where clients register with the server\n", - "* `INFO:Server:Sending initialization requests to clients.`: the initilization of the object needed for training on both the server and clients side.\n", - "* `Server:INFO: Initiating training round 1`: the training starts, where each client makes its local update(s) and send the result to the server which aggregates them\n", - "* `INFO: Initiating evaluation round 1`: the model is evaluated at each round\n", - "* `Server:INFO: Stopping training`: the training is finalized " - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "wo6NDugiOH6V" - }, - "source": [ - "## Results \n", - "\n", - "You can have a look at the results in the `examples/mnist_quickrun/result_*` folder, including the metrics evolution during training. 
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "zlm5El13SvnG" - }, - "outputs": [], - "source": [ - "import pandas as pd\n", - "import glob\n", - "import os \n", - "\n", - "res_file = glob.glob('examples/mnist_quickrun/result*') \n", - "res = pd.read_csv(os.path.join(res_file[0],'server/metrics.csv'))\n", - "res_fig = res.plot()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Kd_MBQt9OJ40" - }, - "source": [ - "# Experiment further\n", - "\n", - "\n", - "You can change the TOML config file to experiment with different strategies." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "E3OOeAYJRGqU" - }, - "source": [ - "For instance, try splitting the data in a very heterogenous way, by distributing digits in mutually exclusive way between clients. " - ] + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This notebook is meant to be run in Google Colab. You can import your local copy of the file from the [Colab welcome page](https://colab.research.google.com/)."
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "s9bpLdH5ThpJ" + }, + "source": [ + "# Setting up your declearn " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Clzf4NTja121" + }, + "source": [ + "We first clone the repo, to have both the package itself and the `examples` folder we will use in this tutorial, then navigate to the package directory, and finally install the required dependencies." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "BNPLnpQuQ8Au" - }, - "outputs": [], - "source": [ - "split_data(folder=\"examples/mnist_quickrun\",scheme='labels')" - ] + "id": "u2QDwb0_QQ_f", + "outputId": "cac0761c-b229-49b0-d71d-c7b5cef919b3" + }, + "outputs": [], + "source": [ + "# you may want to specify a release branch or tag\n", + "!git clone https://gitlab.inria.fr/magnet/declearn/declearn2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "markdown", - "metadata": { - "id": "Xfs-3wH-3Eio" - }, - "source": [ - "And change the `examples/mnist_quickrun/config.toml` file with:\n", - "\n", - "```\n", - "[data] \n", - " data_folder = \"examples/mnist_quickrun/data_labels\" \n", - "```" - ] + "id": "9kDHh_AfPG2l", + "outputId": "74e2f85f-7f93-40ae-a218-f4403470d72c" + }, + "outputs": [], + "source": [ + "cd declearn2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "markdown", - "metadata": { - "id": "ZZVFNO07O1ry" - }, - "source": [ - "If you run the model as is, you should see a drop of performance\n", - "\n" - ] + "id": "Un212t1GluHB", + "outputId": "0ea67577-da6e-4f80-a412-7b7a79803aa1" + }, + "outputs": [], + "source": [ + "# Install the package, with TensorFlow
and Websockets extra dependencies.\n", + "# You may want to work in a dedicated virtual environment.\n", + "!pip install .[tensorflow,websockets]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hC8Fty8YTy9P" + }, + "source": [ + "# Running your first experiment" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rcWcZJdob1IG" + }, + "source": [ + "We are going to train a common model between three simulated clients on the classic [MNIST dataset](http://yann.lecun.com/exdb/mnist/). The input of the model is a set of images of handwritten digits, and the model needs to determine which number between 0 and 9 each image corresponds to." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KlY_vVtFHv2P" + }, + "source": [ + "## The model\n", + "\n", + "To do this, we will use a simple CNN, defined in `examples/mnist_quickrun/model.py`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "C7D52a8_dEr7", + "outputId": "a25223f8-c8eb-4998-d7fd-4b8bfde92486" + }, + "outputs": [ { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "7kFa0EbINJXq" - }, - "outputs": [], - "source": [ - "quickrun(config=\"examples/mnist_quickrun/config.toml\")" - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: \"sequential\"\n", + "_________________________________________________________________\n", + " Layer (type) Output Shape Param # \n", + "=================================================================\n", + " conv2d (Conv2D) (None, 26, 26, 8) 80 \n", + " \n", + " max_pooling2d (MaxPooling2D (None, 13, 13, 8) 0 \n", + " ) \n", + " \n", + " dropout (Dropout) (None, 13, 13, 8) 0 \n", + " \n", + " flatten (Flatten) (None, 1352) 0 \n", + " \n", + " dense (Dense) (None, 64) 86592 \n", + " \n", + " dropout_1 (Dropout) (None, 64) 0 \n", + " \n", + " dense_1 (Dense) (None, 10) 650 \n", + " \n", + 
"=================================================================\n", + "Total params: 87,322\n", + "Trainable params: 87,322\n", + "Non-trainable params: 0\n", + "_________________________________________________________________\n" + ] + } + ], + "source": [ + "from examples.mnist_quickrun.model import network\n", + "network.summary()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HoBcOs9hH2QA" + }, + "source": [ + "## The data\n", + "\n", + "We start by splitting the MNIST dataset between 3 clients and storing the output in the `examples/mnist_quickrun` folder. For this we use an experimental utility provided by `declearn`. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "quduXkpIWFjL", + "outputId": "ddf7d45d-acf0-44ee-ce77-357c0987a2a1" + }, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "XV6JfaRzR3ee" - }, - "source": [ - "Now try modifying the `examples/mnist_quickrun/config.toml` file like this, to implement the [scaffold algorithm](https://arxiv.org/abs/1910.06378) and running the experiment again. \n", - "\n", - "```\n", - " [optim]\n", - "\n", - " [optim.client_opt]\n", - " lrate = 0.005 \n", - " modules = [\"scaffold-client\"] \n", - "\n", - " [optim.server_opt]\n", - " lrate = 1.0 \n", - " modules = [\"scaffold-client\"]\n", - "```" - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading MNIST source file train-images-idx3-ubyte.gz.\n", + "Downloading MNIST source file train-labels-idx1-ubyte.gz.\n", + "Splitting data into 3 shards using the 'iid' scheme.\n" + ] + } + ], + "source": [ + "from declearn.dataset import split_data\n", + "\n", + "split_data(folder=\"examples/mnist_quickrun\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The python code above is equivalent to running `declearn-split examples/mnist_quickrun/` in a shell command-line." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3-2hKmz-2RF4" + }, + "source": [ + "Here is what the first image of the first client looks like:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 430 }, + "id": "MLVI9GOZ1TGd", + "outputId": "f34a6a93-cb5f-4a45-bc24-4146ea119d1a" + }, + "outputs": [ { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "FK6c9HDjSdGZ" - }, - "outputs": [], - "source": [ - "quickrun(config=\"examples/mnist_quickrun/config.toml\")" + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAbo0lEQVR4nO3df2zV9fXH8dct0itoe7ta29tKy1r8gYrUDKR2Kv6goXQJESQLikvAGJxYjMicpkZBtiXdMPPrNAz/cTATUcQJRDNJsNgStxZDlRCmVsq6UQYtwsK9pUhh7fv7B+HqhfLjc7m3597yfCQ3offe03v8eO3Ty7188DnnnAAAGGBp1gsAAC5OBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJi4xHqBU/X19Wnv3r3KyMiQz+ezXgcA4JFzTl1dXSooKFBa2plf5yRdgPbu3avCwkLrNQAAF6i9vV0jRow44+1JF6CMjAxJJxbPzMw03gYA4FU4HFZhYWHk5/mZJCxAy5Yt04svvqiOjg6Vlpbq1Vdf1YQJE845d/K33TIzMwkQAKSwc72NkpAPIaxevVoLFy7U4sWL9dlnn6m0tFSVlZXav39/Ih4OAJCCEhKgl156SXPnztVDDz2kG264Qa+99pqGDx+uP/3pT4l4OABACop7gI4dO6bm5mZVVFR89yBpaaqoqFBjY+Np9+/p6VE4HI66AAAGv7gH6MCBA+rt7VVeXl7U9Xl5eero6Djt/rW1tQoEApELn4ADgIuD+R9ErampUSgUilza29utVwIADIC4fwouJydHQ4YMUWdnZ9T1nZ2dCgaDp93f7/fL7/fHew0AQJKL+yug9PR0jRs3TnV1dZHr+vr6VFdXp/Ly8ng/HAAgRSXkzwEtXLhQs2fP1vjx4zVhwgS9/PLL6u7u1kMPPZSIhwMApKCEBGjmzJn65ptvtGjRInV0dOjmm2/Whg0bTvtgAgDg4uVzzjnrJb4vHA4rEAgoFApxJgQASEHn+3Pc/FNwAICLEwECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGDiEusFAJyfnp4ezzMHDhyI6bFef/31
mOa8+vjjjz3P1NfXe57x+XyeZ2IVyz/TnXfemYBNkh+vgAAAJggQAMBE3AP0wgsvyOfzRV1Gjx4d74cBAKS4hLwHdOONN+qjjz767kEu4a0mAEC0hJThkksuUTAYTMS3BgAMEgl5D2jnzp0qKChQSUmJHnzwQe3evfuM9+3p6VE4HI66AAAGv7gHqKysTCtXrtSGDRu0fPlytbW16Y477lBXV1e/96+trVUgEIhcCgsL470SACAJxT1AVVVV+ulPf6qxY8eqsrJSf/3rX3Xo0CG98847/d6/pqZGoVAocmlvb4/3SgCAJJTwTwdkZWXp2muvVWtra7+3+/1++f3+RK8BAEgyCf9zQIcPH9auXbuUn5+f6IcCAKSQuAfoqaeeUkNDg/71r3/p73//u6ZPn64hQ4bogQceiPdDAQBSWNx/C27Pnj164IEHdPDgQV155ZW6/fbb1dTUpCuvvDLeDwUASGFxD9Dbb78d728JJLXe3l7PMxs3bvQ8s2TJEs8zn376qeeZZJeWltxnEAuFQtYrpIzk/jcJABi0CBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATCf8L6ZD8nHMxzfX19XmeGagTScaymyR98803nmeqqqo8z2zfvt3zDAbekCFDPM8UFRUlYJPBiVdAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMMHZsKEdO3bENHfzzTd7nvnHP/7heebbb7/1PDN+/HjPM0gNGRkZnmdGjx4d02P95S9/8Txz1VVXxfRYFyNeAQEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJjgZ6SDzz3/+0/PMrFmzErBJ//bv3+95pqamJgGbXBwuvfTSmOYefPDBOG/SvyeeeMLzTFZWlucZThCanHgFBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCY4GSkg8yBAwc8z3zxxRcxPZbP5/M809TU5Hlmy5YtnmfS0mL7f6thw4bFNOfVzJkzPc/cf//9nmfKyso8z0jS5ZdfHtMc4AWvgAAAJggQAMCE5wBt3rxZU6dOVUFBgXw+n9atWxd1u3NOixYtUn5+voYNG6aKigrt3LkzXvsCAAYJzwHq7u5WaWmpli1b1u/tS5cu1SuvvKLXXntNW7Zs0WWXXabKykodPXr0gpcFAAwenj+EUFVVpaqqqn5vc87p5Zdf1nPPPad7771XkvTGG28oLy9P69ati+lNVADA4BTX94Da2trU0dGhioqKyHWBQEBlZWVqbGzsd6anp0fhcDjqAgAY/OIaoI6ODklSXl5e1PV5eXmR205VW1urQCAQuRQWFsZzJQBAkjL/FFxNTY1CoVDk0t7ebr0SAGAAxDVAwWBQktTZ2Rl1fWdnZ+S2U/n9fmVmZkZdAACDX1wDVFxcrGAwqLq6ush14XBYW7ZsUXl5eTwfCgCQ4jx/Cu7w4cNqbW2NfN3W1qZt27YpOztbRUVFWrBggX7zm9/ommuuUXFxsZ5//nkVFBRo2rRp8dwbAJDiPAdo69atuvvuuyNfL1y4UJI0e/ZsrVy5Uk8//bS6u7v1yCOP6NChQ7r99tu1YcMGXXrppfHbGgCQ8nzOOWe9xPeFw2EFAgGFQiHeD4rBrFmzPM+sXr06psfKzs72PBPLiU+//vprzzN+v9/zjCSNHz8+pjkA3znfn+Pmn4IDAFycCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYMLzX8cAnPTf//7X88zUqVM9z6xatcrzTElJiecZAAOLV0AAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIA
mCBAAAATBAgAYIIAAQBMECAAgAmfc85ZL/F94XBYgUBAoVBImZmZ1uuknMbGRs8zt99+ewI2iZ/hw4d7npkzZ05Mj7VkyRLPM7HsN3ToUM8zQ4YM8TwDWDjfn+O8AgIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATHAy0kEmHA57nlm+fHlMj/Xss8/GNJfMrr76as8ze/bs8Txzzz33eJ6JZbdYTZ8+3fPMrbfe6nkmPT3d8wySHycjBQAkNQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABCcjhfr6+mKae+qppzzPvPvuu55n/vOf/3iewcD7+c9/7nlm0aJFnmeCwaDnGQwsTkYKAEhqBAgAYMJzgDZv3qypU6eqoKBAPp9P69ati7p9zpw58vl8UZcpU6bEa18AwCDhOUDd3d0qLS3VsmXLznifKVOmaN++fZHLW2+9dUFLAgAGn0u8DlRVVamqquqs9/H7/bxRCAA4q4S8B1RfX6/c3Fxdd911mjdvng4ePHjG+/b09CgcDkddAACDX9wDNGXKFL3xxhuqq6vT7373OzU0NKiqqkq9vb393r+2tlaBQCByKSwsjPdKAIAk5Pm34M7l/vvvj/z6pptu0tixYzVq1CjV19dr0qRJp92/pqZGCxcujHwdDoeJEABcBBL+MeySkhLl5OSotbW139v9fr8yMzOjLgCAwS/hAdqzZ48OHjyo/Pz8RD8UACCFeP4tuMOHD0e9mmlra9O2bduUnZ2t7OxsLVmyRDNmzFAwGNSuXbv09NNP6+qrr1ZlZWVcFwcApDbPAdq6davuvvvuyNcn37+ZPXu2li9fru3bt+vPf/6zDh06pIKCAk2ePFm//vWv5ff747c1ACDlcTJSDKhjx455nvnf//7neebNN9/0PCNJX331leeZP/zhD55nkuw/OzOvvvqq55nHHnssAZsgnjgZKQAgqREgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEZ8MGLtDWrVs9z7z77rueZ7788kvPMx9++KHnGUnq7e2Nac6r4cOHe57ZsWOH55mRI0d6nkHsOBs2ACCpESAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmLrFeAEh148ePH5CZWHz99dcxzV1//fVx3qR/R44c8TzT2dnpeYaTkSYnXgEBAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACY4GSkwiF111VXWK5xVVlaW55mSkpL4LwITvAICAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAExwMlIMqD179nieGT58uOeZ7OxszzMDqaenx/NMW1ub55nf//73nmcG0oQJEzzP5OTkJGATWOAVEADABAECAJjwFKDa2lrdcsstysjIUG5urqZNm6aWlpao+xw9elTV1dW64oordPnll2vGjBnq7OyM69IAgNTnKUANDQ2qrq5WU1OTNm7cqOPHj2vy5Mnq7u6O3OfJJ5/U+++/rzVr1qihoUF79+7VfffdF/fFAQCpzdOHEDZs2BD19cqVK5Wbm6vm5mZNnDhRoVBIr7/+ulatWqV77rlHkrRixQpdf/31ampq0q233hq/zQEAKe2C3gMKhUKSvvvEUXNzs44fP66KiorIfUaPHq2ioiI1Njb2+z16enoUDoejLgCAwS/mAPX19WnBggW67bbbNGbMGElSR0eH0tPTT/t73vPy8tTR0dHv96mtrVUgEIhcCgsLY10JAJBCYg5QdXW1duzYobfffvuCFqipqVEoFIpc2tvbL+j7AQBSQ0x/EHX+/Pn6
4IMPtHnzZo0YMSJyfTAY1LFjx3To0KGoV0GdnZ0KBoP9fi+/3y+/3x/LGgCAFObpFZBzTvPnz9fatWu1adMmFRcXR90+btw4DR06VHV1dZHrWlpatHv3bpWXl8dnYwDAoODpFVB1dbVWrVql9evXKyMjI/K+TiAQ0LBhwxQIBPTwww9r4cKFys7OVmZmph5//HGVl5fzCTgAQBRPAVq+fLkk6a677oq6fsWKFZozZ44k6f/+7/+UlpamGTNmqKenR5WVlfrjH/8Yl2UBAIOHzznnrJf4vnA4rEAgoFAopMzMTOt1cBaffvqp55nvf0T/fF122WWeZ4qKijzPDKSuri7PM6eedWQwaG5u9jxz8803x38RxNX5/hznXHAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwEdPfiIrB5ciRIzHN/fjHP/Y8E8vJ17u7uz3P7N+/3/MMLkxbW5vnmcLCwgRsglTBKyAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQnI4XWr18f01wsJxbFwLr11ltjmlu2bJnnmdzcXM8zPp/P8wwGD14BAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmOBkpVFpaar3CRWfmzJkDMjN58mTPM5I0bNiwmOYAL3gFBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCY4GSk0A033BDTXG9vb5w3AXAx4RUQAMAEAQIAmPAUoNraWt1yyy3KyMhQbm6upk2bppaWlqj73HXXXfL5fFGXRx99NK5LAwBSn6cANTQ0qLq6Wk1NTdq4caOOHz+uyZMnq7u7O+p+c+fO1b59+yKXpUuXxnVpAEDq8/QhhA0bNkR9vXLlSuXm5qq5uVkTJ06MXD98+HAFg8H4bAgAGJQu6D2gUCgkScrOzo66/s0331ROTo7GjBmjmpoaHTly5Izfo6enR+FwOOoCABj8Yv4Ydl9fnxYsWKDbbrtNY8aMiVw/a9YsjRw5UgUFBdq+fbueeeYZtbS06L333uv3+9TW1mrJkiWxrgEASFE+55yLZXDevHn68MMP9cknn2jEiBFnvN+mTZs0adIktba2atSoUafd3tPTo56ensjX4XBYhYWFCoVCyszMjGU1AIChcDisQCBwzp/jMb0Cmj9/vj744ANt3rz5rPGRpLKyMkk6Y4D8fr/8fn8sawAAUpinADnn9Pjjj2vt2rWqr69XcXHxOWe2bdsmScrPz49pQQDA4OQpQNXV1Vq1apXWr1+vjIwMdXR0SJICgYCGDRumXbt2adWqVfrJT36iK664Qtu3b9eTTz6piRMnauzYsQn5BwAApCZP7wH5fL5+r1+xYoXmzJmj9vZ2/exnP9OOHTvU3d2twsJCTZ8+Xc8999x5v59zvr93CABITgl5D+hcrSosLFRDQ4OXbwkAuEhxLjgAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgIlLrBc4lXNOkhQOh403AQDE4uTP75M/z88k6QLU1dUlSSosLDTeBABwIbq6uhQIBM54u8+dK1EDrK+vT3v37lVGRoZ8Pl/UbeFwWIWFhWpvb1dmZqbRhvY4DidwHE7gOJzAcTghGY6Dc05dXV0qKChQWtqZ3+lJuldAaWlpGjFixFnvk5mZeVE/wU7iOJzAcTiB43ACx+EE6+Nwtlc+J/EhBACACQIEADCRUgHy+/1avHix/H6/9SqmOA4ncBxO4DicwHE4IZWOQ9J9
CAEAcHFIqVdAAIDBgwABAEwQIACACQIEADCRMgFatmyZfvjDH+rSSy9VWVmZPv30U+uVBtwLL7wgn88XdRk9erT1Wgm3efNmTZ06VQUFBfL5fFq3bl3U7c45LVq0SPn5+Ro2bJgqKiq0c+dOm2UT6FzHYc6cOac9P6ZMmWKzbILU1tbqlltuUUZGhnJzczVt2jS1tLRE3efo0aOqrq7WFVdcocsvv1wzZsxQZ2en0caJcT7H4a677jrt+fDoo48abdy/lAjQ6tWrtXDhQi1evFifffaZSktLVVlZqf3791uvNuBuvPFG7du3L3L55JNPrFdKuO7ubpWWlmrZsmX93r506VK98soreu2117RlyxZddtllqqys1NGjRwd408Q613GQpClTpkQ9P956660B3DDxGhoaVF1draamJm3cuFHHjx/X5MmT1d3dHbnPk08+qffff19r1qxRQ0OD9u7dq/vuu89w6/g7n+MgSXPnzo16PixdutRo4zNwKWDChAmuuro68nVvb68rKChwtbW1hlsNvMWLF7vS0lLrNUxJcmvXro183dfX54LBoHvxxRcj1x06dMj5/X731ltvGWw4ME49Ds45N3v2bHfvvfea7GNl//79TpJraGhwzp34dz906FC3Zs2ayH2+/PJLJ8k1NjZarZlwpx4H55y788473RNPPGG31HlI+ldAx44dU3NzsyoqKiLXpaWlqaKiQo2NjYab2di5c6cKCgpUUlKiBx98ULt377ZeyVRbW5s6Ojqinh+BQEBlZWUX5fOjvr5eubm5uu666zRv3jwdPHjQeqWECoVCkqTs7GxJUnNzs44fPx71fBg9erSKiooG9fPh1ONw0ptvvqmcnByNGTNGNTU1OnLkiMV6Z5R0JyM91YEDB9Tb26u8vLyo6/Py8vTVV18ZbWWjrKxMK1eu1HXXXad9+/ZpyZIluuOOO7Rjxw5lZGRYr2eio6NDkvp9fpy87WIxZcoU3XfffSouLtauXbv07LPPqqqqSo2NjRoyZIj1enHX19enBQsW6LbbbtOYMWMknXg+pKenKysrK+q+g/n50N9xkKRZs2Zp5MiRKigo0Pbt2/XMM8+opaVF7733nuG20ZI+QPhOVVVV5Ndjx45VWVmZRo4cqXfeeUcPP/yw4WZIBvfff3/k1zfddJPGjh2rUaNGqb6+XpMmTTLcLDGqq6u1Y8eOi+J90LM503F45JFHIr++6aablJ+fr0mTJmnXrl0aNWrUQK/Zr6T/LbicnBwNGTLktE+xdHZ2KhgMGm2VHLKysnTttdeqtbXVehUzJ58DPD9OV1JSopycnEH5/Jg/f74++OADffzxx1F/fUswGNSxY8d06NChqPsP1ufDmY5Df8rKyiQpqZ4PSR+g9PR0jRs3TnV1dZHr+vr6VFdXp/LycsPN7B0+fFi7du1Sfn6+9SpmiouLFQwGo54f4XBYW7ZsueifH3v27NHBgwcH1fPDOaf58+dr7dq12rRpk4qLi6NuHzdunIYOHRr1fGhpadHu3bsH1fPhXMehP9u2bZOk5Ho+WH8K4ny8/fbbzu/3u5UrV7ovvvjCPfLIIy4rK8t1dHRYrzagfvGLX7j6+nrX1tbm/va3v7mKigqXk5Pj9u/fb71aQnV1dbnPP//cff75506Se+mll9znn3/u/v3vfzvnnPvtb3/rsrKy3Pr169327dvdvffe64qLi923335rvHl8ne04dHV1uaeeeso1Nja6trY299FHH7kf/ehH7pprrnFHjx61Xj1u5s2b5wKBgKuvr3f79u2LXI4cORK5z6OPPuqKiorcpk2b3NatW115ebkrLy833Dr+znUcWltb3a9+9Su3detW19bW5tavX+9KSkrcxIkTjTePlhIBcs65V1991RUVFbn09HQ3YcIE19TUZL3SgJs5c6bLz8936enp7qqrrnIzZ850ra2t1msl3Mcff+wknXaZPXu2c+7ER7Gff/55l5eX5/x+v5s0aZJraWmx
XToBznYcjhw54iZPnuyuvPJKN3ToUDdy5Eg3d+7cQfc/af3980tyK1asiNzn22+/dY899pj7wQ9+4IYPH+6mT5/u9u3bZ7d0ApzrOOzevdtNnDjRZWdnO7/f766++mr3y1/+0oVCIdvFT8FfxwAAMJH07wEBAAYnAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMDE/wPgnA/bT9IQRgAAAABJRU5ErkJggg==", + "text/plain": [ + "<Figure size 640x480 with 1 Axes>" ] + }, + "metadata": {}, + "output_type": "display_data" } - ], - "metadata": { + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "images = np.load(\"examples/mnist_quickrun/data_iid/client_0/train_data.npy\")\n", + "sample_img = images[0]\n", + "sample_fig = plt.imshow(sample_img,cmap='Greys')\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1vNWNGjefSfH" + }, + "source": [ + "For more information on how the `split_data` function works, you can look at the documentation. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { "colab": { - "collapsed_sections": [ - "s9bpLdH5ThpJ", - "KlY_vVtFHv2P", - "HoBcOs9hH2QA", - "kZtbxlwUftKd", - "wo6NDugiOH6V", - "Kd_MBQt9OJ40" - ], - "provenance": [] + "base_uri": "https://localhost:8080/" }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" + "id": "-wORmq5DYfRF", + "outputId": "4d79da63-ccad-4622-e600-ac36fae1ff3f" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Randomly split a dataset into shards.\n", + "\n", + " The resulting folder structure is :\n", + " folder/\n", + " └─── data*/\n", + " └─── client*/\n", + " │ train_data.* - training data\n", + " │ train_target.* - training labels\n", + " │ valid_data.* - validation data\n", + " │ valid_target.* - validation labels\n", + " └─── client*/\n", + " │ ...\n", + "\n", + " Parameters\n", + " ----------\n", + " folder: str, default = \".\"\n", + " Path to the folder where to add a data folder\n", + " holding output shard-wise files\n", + " data_file: str or None, 
default=None\n", + " Optional path to a folder where to find the data.\n", + " If None, default to the MNIST example.\n", + " target_file: str or int or None, default=None\n", + " If str, path to the labels file to import, or name of a `data`\n", + " column to use as labels (only if `data` points to a csv file).\n", + " If int, index of a `data` column of to use as labels).\n", + " Required if data is not None, ignored if data is None.\n", + " n_shards: int\n", + " Number of shards between which to split the data.\n", + " scheme: {\"iid\", \"labels\", \"biased\"}, default=\"iid\"\n", + " Splitting scheme(s) to use. In all cases, shards contain mutually-\n", + " exclusive samples and cover the full raw training data.\n", + " - If \"iid\", split the dataset through iid random sampling.\n", + " - If \"labels\", split into shards that hold all samples associated\n", + " with mutually-exclusive target classes.\n", + " - If \"biased\", split the dataset through random sampling according\n", + " to a shard-specific random labels distribution.\n", + " perc_train: float, default= 0.8\n", + " Train/validation split in each client dataset, must be in the\n", + " ]0,1] range.\n", + " seed: int or None, default=None\n", + " Optional seed to the RNG used for all sampling operations.\n", + " \n" + ] } + ], + "source": [ + "print(split_data.__doc__)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kZtbxlwUftKd" + }, + "source": [ + "## Quickrun\n", + "\n", + "We can now run our experiment. As explained in the section 2.1 of the [quickstart documentation](https://magnet.gitlabpages.inria.fr/declearn/docs/latest/quickstart), using the `declearn-quickrun` entry-point requires a configuration file, some data, and a model:\n", + "\n", + "* A TOML file, to store your experiment configurations. Here: \n", + "`examples/mnist_quickrun/config.toml`.\n", + "* A folder with your data, split by client. 
Here: `examples/mnist_quickrun/data_iid`\n", + "* A model python file, to declare your model wrapped in a `declearn` object. Here: `examples/mnist_quickrun/model.py`.\n", + "\n", + "We then only have to run the `quickrun` util with the path to the TOML file:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "1n_mvTIIWpRf" + }, + "outputs": [], + "source": [ + "from declearn.quickrun import quickrun\n", + "\n", + "quickrun(config=\"examples/mnist_quickrun/config.toml\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The python code above is equivalent to running `declearn-quickrun examples/mnist_quickrun/config.toml` in a shell command-line." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "O0kuw7UxJqKk" + }, + "source": [ + "The output obtained is the combination of the CLI output of our server and our clients, going through: \n", + "\n", + "* `INFO:Server:Starting clients registration process.` : a first registration step, where clients register with the server\n", + "* `INFO:Server:Sending initialization requests to clients.`: the initialization of the objects needed for training on both the server and client sides.\n", + "* `Server:INFO: Initiating training round 1`: the training starts, where each client makes its local update(s) and sends the result to the server which aggregates them\n", + "* `INFO: Initiating evaluation round 1`: the model is evaluated at each round\n", + "* `Server:INFO: Stopping training`: the training is finalized " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wo6NDugiOH6V" + }, + "source": [ + "## Results \n", + "\n", + "You can have a look at the results in the `examples/mnist_quickrun/result_*` folder, including the metrics evolution during training. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "zlm5El13SvnG" + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import glob\n", + "import os \n", + "\n", + "res_file = glob.glob('examples/mnist_quickrun/result*') \n", + "res = pd.read_csv(os.path.join(res_file[0],'server/metrics.csv'))\n", + "res_fig = res.plot()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Kd_MBQt9OJ40" + }, + "source": [ + "# Experiment further\n", + "\n", + "\n", + "You can change the TOML config file to experiment with different strategies." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "E3OOeAYJRGqU" + }, + "source": [ + "For instance, try splitting the data in a very heterogeneous way, by distributing digits in a mutually exclusive way between clients. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "BNPLnpQuQ8Au" + }, + "outputs": [], + "source": [ + "split_data(folder=\"examples/mnist_quickrun\",scheme='labels')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Xfs-3wH-3Eio" + }, + "source": [ + "And change the `examples/mnist_quickrun/config.toml` file with:\n", + "\n", + "```\n", + "[data] \n", + " data_folder = \"examples/mnist_quickrun/data_labels\" \n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZZVFNO07O1ry" + }, + "source": [ + "If you run the model as is, you should see a drop in performance.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "7kFa0EbINJXq" + }, + "outputs": [], + "source": [ + "quickrun(config=\"examples/mnist_quickrun/config.toml\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XV6JfaRzR3ee" + }, + "source": [ + "Now try modifying the `examples/mnist_quickrun/config.toml` file like this, to implement the [scaffold algorithm](https://arxiv.org/abs/1910.06378) and running the experiment again. 
\n", + "\n", + "```\n", + " [optim]\n", + "\n", + " [optim.client_opt]\n", + " lrate = 0.005 \n", + " modules = [\"scaffold-client\"] \n", + "\n", + " [optim.server_opt]\n", + " lrate = 1.0 \n", + " modules = [\"scaffold-server\"]\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "FK6c9HDjSdGZ" + }, + "outputs": [], + "source": [ + "quickrun(config=\"examples/mnist_quickrun/config.toml\")" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [ + "s9bpLdH5ThpJ", + "KlY_vVtFHv2P", + "HoBcOs9hH2QA", + "kZtbxlwUftKd", + "wo6NDugiOH6V", + "Kd_MBQt9OJ40" + ], + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" }, - "nbformat": 4, - "nbformat_minor": 0 + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + } + }, + "nbformat": 4, + "nbformat_minor": 1 } diff --git a/examples/mnist_quickrun/readme.md b/examples/mnist_quickrun/readme.md new file mode 100644 index 00000000..afcc573c --- /dev/null +++ b/examples/mnist_quickrun/readme.md @@ -0,0 +1,31 @@ +# Demo training task : MNIST in Quickrun Mode + +## Overview + +**We are going to use the declearn-quickrun tool to easily run a simulated +federated learning experiment on the classic +[MNIST dataset](http://yann.lecun.com/exdb/mnist/)**. The input of the model +is a set of images of handwritten digits, and the model needs to determine to +which digit between $0$ and $9$ each image corresponds. + +## Setup + +A Jupyter Notebook tutorial is provided, that you may import and run on Google +Colab so as to avoid having to set up a local python environment. 
+ +Alternatively, you may run the notebook on your personal computer, or follow +its instructions to install declearn and operate the quickrun tools directly +from a shell command-line. + +## Contents + +This example's folder is structured the following way: + +``` +mnist_quickrun/ +│ config.toml - configuration file for the quickrun FL experiment +│ mnist.ipynb - tutorial for this example, as a jupyter notebook +│ model.py - python file declaring the model to be trained +└─── data_iid - mnist data generated with `declearn-split` +└─── result_* - results generated after running `declearn-quickrun` +``` -- GitLab