From 98027bc4ff4ed17e9206c00e97953ebe0218b825 Mon Sep 17 00:00:00 2001
From: BIGAUD Nathan <nathan.bigaud@inria.fr>
Date: Wed, 29 Mar 2023 16:01:08 +0200
Subject: [PATCH] In-depth review of TOML parsing, first pass: * Create
 "ModelConfig", "DataSplitConfig", and "ExperimentConfig" * Relocate and
 update the parser to accept a TOML file as input * Update split_data to use
 config input * Modularize and update run.py * Small changes to the MNIST
 TOML example

---
 declearn/quickrun/_config.py     | 103 +++++++++++++
 declearn/quickrun/_parser.py     | 143 ++++++++++++++++++
 declearn/quickrun/_split_data.py | 168 ++++-----------------
 declearn/quickrun/run.py         | 241 ++++++++++---------------------
 examples/quickrun/config.toml    |  10 ++
 examples/quickrun/model.py       |  85 ++++++-----
 6 files changed, 403 insertions(+), 347 deletions(-)
 create mode 100644 declearn/quickrun/_config.py
 create mode 100644 declearn/quickrun/_parser.py

diff --git a/declearn/quickrun/_config.py b/declearn/quickrun/_config.py
new file mode 100644
index 00000000..a80e711b
--- /dev/null
+++ b/declearn/quickrun/_config.py
@@ -0,0 +1,103 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""TOML-parsable container for quickrun configurations."""
+
+import dataclasses
+from typing import Dict, List, Literal, Optional, Union
+
+from declearn.utils import TomlConfig
+
+__all__ = [
+    "ModelConfig",
+    "DataSplitConfig",
+    "ExperimentConfig",
+]
+
+
+@dataclasses.dataclass
+class ModelConfig(TomlConfig):
+    """Dataclass used to provide custom model location and
+    class name"""
+
+    model_file: Optional[str] = None
+    model_name: str = "MyModel"
+
+
+@dataclasses.dataclass
+class DataSplitConfig(TomlConfig):
+    """Dataclass associated with the function
+    declearn.quickrun._split_data:split_data
+
+    export_folder: str
+        Path to the folder where to export shard-wise files.
+    n_shards: int
+        Number of shards between which to split the data.
+    data_file: str or None, default=None
+        Optional path to a folder where to find the data.
+        If None, default to the MNIST example.
+    label_file: str or int or None, default=None
+        If str, path to the labels file to import. If int, column of
+        the data file to be used as labels. Required if data_file is not
+        None, ignored if data_file is None.
+    scheme: {"iid", "labels", "biased"}, default="iid"
+        Splitting scheme to use. In all cases, shards contain mutually-
+        exclusive samples and cover the full raw training data.
+        - If "iid", split the dataset through iid random sampling.
+        - If "labels", split into shards that hold all samples associated
+        with mutually-exclusive target classes.
+        - If "biased", split the dataset through random sampling according
+        to a shard-specific random labels distribution.
+    perc_train: float, default=0.8
+        Train/validation split in each client dataset, must be in the
+        ]0,1] range.
+    seed: int or None, default=None
+        Optional seed to the RNG used for all sampling operations.
+    """
+
+    export_folder: str = "."
+    n_shards: int = 5
+    data_file: Optional[str] = None
+    label_file: Optional[Union[str, int]] = None
+    scheme: Literal["iid", "labels", "biased"] = "iid"
+    perc_train: float = 0.8
+    seed: Optional[int] = None
+
+
+@dataclasses.dataclass
+class ExperimentConfig(TomlConfig):
+    """
+
+    Dataclass associated with the function
+    declearn.quickrun._parser:parse_data_folder
+
+    data_folder : str or none
+        Absolute path to the main folder hosting the data, overwriting
+        the folder argument if provided. If None, default to expected
+        prefix search in folder.
+    client_names: list or None
+        List of custom client names to look for in the data_folder.
+        If None, default to expected prefix search.
+    dataset_names: dict or None
+        Dict of custom dataset names, to look for in each client folder.
+        Expect 'train_data, train_target, valid_data, valid_target' as keys.
+        If None, , default to expected prefix search.
+    """
+
+    data_folder: Optional[str] = None
+    client_names: Optional[List[str]] = None
+    dataset_names: Optional[Dict[str, str]] = None
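+
+
+# Note: these dataclasses are parsed from sections of the quickrun TOML file
+# by run.py; a sketch of the corresponding calls (assuming a "config.toml"):
+#
+#   model_config = ModelConfig.from_toml("config.toml", False, "model")
+#   data_config = DataSplitConfig.from_toml("config.toml", False, "data")
+#   expe_config = ExperimentConfig.from_toml("config.toml", False, "experiment")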
diff --git a/declearn/quickrun/_parser.py b/declearn/quickrun/_parser.py
new file mode 100644
index 00000000..eae9c9fc
--- /dev/null
+++ b/declearn/quickrun/_parser.py
@@ -0,0 +1,143 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+#TODO
+"""
+
+import os
+from pathlib import Path
+from typing import Dict, List, Optional
+
+from declearn.test_utils import make_importable
+
+# Perform local imports.
+# pylint: disable=wrong-import-order, wrong-import-position
+with make_importable(os.path.dirname(__file__)):
+    from _config import ExperimentConfig
+# pylint: enable=wrong-import-order, wrong-import-position
+
+
+def parse_data_folder(
+    expe_config: ExperimentConfig,
+    folder: Optional[str] = None,
+) -> Dict:
+    """Utils parsing a data folder following a standard format into a nested
+    dictionnary.
+
+    The default expected format is :
+
+        folder/
+        └─── data*/
+            └─── client*/
+            │      train_data.* - training data
+            │      train_target.* - training labels
+            │      valid_data.* - validation data
+            │      valid_target.* - validation labels
+            └─── client*/
+            │    ...
+
+    Parameters
+    ----------
+    expe_config : ExperimentConfig
+        ExperimentConfig instance, see class documentation for details.
+    folder : str or None
+        The main experiment folder in which to look for a `data*` folder.
+        Overridden by `expe_config.data_folder` if provided.
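+
+    Returns
+    -------
+    clients : dict
+        Nested dict mapping each client name to a dict with keys
+        'train_data', 'train_target', 'valid_data' and 'valid_target',
+        each pointing to the path of the matching file.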
+    """
+
+    data_folder = expe_config.data_folder
+    client_names = expe_config.client_names
+    dataset_names = expe_config.dataset_names
+
+    if not folder and not data_folder:
+        raise ValueError(
+            "Please provide either a parent folder or a data folder"
+        )
+    # Data_folder
+    if not data_folder:
+        gen_folders = Path(folder).glob("data*")  # type: ignore
+        data_folder = next(gen_folders, False)  # type: ignore
+        if not data_folder:
+            raise ValueError(
+                f"No folder starting with 'data' found in {folder}. "
+                "Please store your data under a 'data_*' folder"
+            )
+        if next(gen_folders, False):
+            raise ValueError(
+                "More than one folder starting with 'data' found"
+                f"in {folder}. Please store your data under a single"
+                "parent folder"
+            )
+    # Get clients dir
+    if client_names:
+        if isinstance(client_names, list):
+            valid_names = [
+                os.path.isdir(os.path.join(data_folder, n))
+                for n in client_names
+            ]
+            if sum(valid_names) != len(client_names):
+                raise ValueError(
+                    f"Not all provided client names could be found in {data_folder}"
+                )
+            clients = {n: {} for n in client_names}
+        else:
+            raise ValueError(
+                "Please provide a valid list of client names for "
+                "argument 'client_names'"
+            )
+    else:
+        gen_folders = Path(data_folder).glob("client*")  # type: ignore
+        first_client = next(gen_folders, False)
+        if not first_client:
+            raise ValueError(
+                f"No folder starting with 'client' found in {data_folder}. "
+                "Please store your individual under client data under"
+                "a 'client*' folder"
+            )
+        clients = {str(first_client).rsplit("/", 1)[-1]: {}}
+        while client := next(gen_folders, False):
+            clients[str(client).rsplit("/", 1)[-1]] = {}
+    # Get train and valid files
+    data_items = [
+        "train_data",
+        "train_target",
+        "valid_data",
+        "valid_target",
+    ]
+    if dataset_names:
+        if set(data_items) != set(dataset_names.keys()):
+            raise ValueError(
+                f"Please provide a properly formatted dictionnary as input"
+                f"using the follwoing keys : {str(data_items)}"
+            )
+    else:
+        dataset_names = {i: i for i in data_items}
+    for client, files in clients.items():
+        for k, v in dataset_names.items():
+            gen_file = Path(data_folder, client).glob(f"{v}*")  # type: ignore
+            file = next(gen_file, False)
+            if not file:
+                raise ValueError(
+                    f"Could not find a file named '{v}.*' in {client}"
+                )
+            if next(gen_file, False):
+                raise ValueError(
+                    f"Found more than one file named '{v}.*' in {client}"
+                )
+            files[k] = str(file)
+    return clients
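+
+
+# Expected usage, as a sketch (the folder path below is hypothetical):
+#
+#   expe_config = ExperimentConfig.from_toml("config.toml", False, "experiment")
+#   clients = parse_data_folder(expe_config, folder="./examples/quickrun")
+#   # {"client_0": {"train_data": "...", "train_target": "...", ...}, ...}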
diff --git a/declearn/quickrun/_split_data.py b/declearn/quickrun/_split_data.py
index 96c93ef0..e323258e 100644
--- a/declearn/quickrun/_split_data.py
+++ b/declearn/quickrun/_split_data.py
@@ -45,6 +45,13 @@ import pandas as pd
 import requests  # type: ignore
 
 from declearn.dataset import load_data_array
+from declearn.test_utils import make_importable
+
+# Perform local imports.
+# pylint: disable=wrong-import-order, wrong-import-position
+with make_importable(os.path.dirname(__file__)):
+    from _config import DataSplitConfig
+# pylint: enable=wrong-import-order, wrong-import-position
 
 SOURCE_URL = "https://pjreddie.com/media/files"
 
@@ -191,15 +198,7 @@ def _split_biased(
     return split
 
 
-def split_data(
-    folder: str,
-    n_shards: int = 5,
-    data: Optional[str] = None,
-    target: Optional[Union[str, int]] = None,
-    scheme: Literal["iid", "labels", "biased"] = "iid",
-    perc_train: float = 0.8,
-    seed: Optional[int] = None,
-) -> None:
+def split_data(data_config: DataSplitConfig) -> None:
     """Download and randomly split a dataset into shards.
 
     The resulting folder structure is :
@@ -215,32 +214,17 @@ def split_data(
 
     Parameters
     ----------
-    folder: str
-        Path to the folder where to export shard-wise files.
-    n_shards: int
-        Number of shards between which to split the data.
-    data: str or None, default=None
-        Optional path to a folder where to find the data.
-        If None, default to the MNIST example.
-    target: str or int or None, default=None
-        If str, path to the labels file to import. If int, column of
-        the data file to be used as labels. Required if data is not None,
-        ignored if data is None.
-    scheme: {"iid", "labels", "biased"}, default="iid"
-        Splitting scheme to use. In all cases, shards contain mutually-
-        exclusive samples and cover the full raw training data.
-        - If "iid", split the dataset through iid random sampling.
-        - If "labels", split into shards that hold all samples associated
-        with mutually-exclusive target classes.
-        - If "biased", split the dataset through random sampling according
-        to a shard-specific random labels distribution.
-    perc_train:  float, default= 0.8
-        Train/validation split in each client dataset, must be in the
-        ]0,1] range.
-    seed: int or None, default=None
-        Optional seed to the RNG used for all sampling operations.
+    data_config: DataSplitConfig
+        A DataSplitConfig instance, see class documentation for details.
     """
+
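+    # Local helper: save an array under "<folder>/client_<i>/<name>.npy",
+    # creating the client directory if needed.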
+    def np_save(folder, data, i, name):
+        data_dir = os.path.join(folder, f"client_{i}")
+        os.makedirs(data_dir, exist_ok=True)
+        np.save(os.path.join(data_dir, f"{name}.npy"), data)
+
     # Select the splitting function to be used.
+    scheme = data_config.scheme
     if scheme == "iid":
         func = _split_iid
     elif scheme == "labels":
@@ -250,115 +234,27 @@ def split_data(
     else:
         raise ValueError(f"Invalid 'scheme' value: '{scheme}'.")
     # Set up the RNG, download the raw dataset and split it.
-    rng = np.random.default_rng(seed)
-    inputs, labels = load_data(data, target)
-    print(f"Splitting data into {n_shards} shards using the {scheme} scheme")
-    split = func(inputs, labels, n_shards, rng)
+    rng = np.random.default_rng(data_config.seed)
+    inputs, labels = load_data(data_config.data_file, data_config.label_file)
+    print(
+        f"Splitting data into {data_config.n_shards}"
+        f"shards using the {scheme} scheme"
+    )
+    split = func(inputs, labels, data_config.n_shards, rng)
     # Export the resulting shard-wise data to files.
-    folder = os.path.join(folder, f"data_{scheme}")
-
-    def np_save(data, i, name):
-        data_dir = os.path.join(folder, f"client_{i}")
-        os.makedirs(data_dir, exist_ok=True)
-        np.save(os.path.join(data_dir, f"{name}.npy"), data)
-
+    folder = os.path.join(data_config.export_folder, f"data_{scheme}")
     for i, (inp, tgt) in enumerate(split):
+        perc_train = data_config.perc_train
         if not perc_train:
-            np_save(inp, i, "train_data")
-            np_save(tgt, i, "train_target")
+            np_save(folder, inp, i, "train_data")
+            np_save(folder, tgt, i, "train_target")
         else:
             if perc_train > 1.0 or perc_train < 0.0:
                 raise ValueError("perc_train should be a float in ]0,1]")
             n_train = round(len(inp) * perc_train)
             t_inp, t_tgt = inp[:n_train], tgt[:n_train]
             v_inp, v_tgt = inp[n_train:], tgt[n_train:]
-            np_save(t_inp, i, "train_data")
-            np_save(t_tgt, i, "train_target")
-            np_save(v_inp, i, "valid_data")
-            np_save(v_tgt, i, "valid_target")
-
-
-def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
-    """Set up and run a command-line arguments parser."""
-    usage = """
-        Download and split data into heterogeneous shards.
-
-        The implemented schemes are the following:
-        * "iid":
-            Split the dataset through iid random sampling.
-        * "labels":
-            Split the dataset into shards that hold all samples
-            that have mutually-exclusive target classes.
-        * "biased":
-            Split the dataset through random sampling according
-            to a shard-specific random labels distribution.
-    """
-    usage = re.sub("\n *(?=[a-z])", " ", textwrap.dedent(usage))
-    parser = argparse.ArgumentParser(
-        formatter_class=argparse.RawTextHelpFormatter,
-        usage=re.sub("- ", "-", usage),
-    )
-    parser.add_argument(
-        "--n_shards",
-        type=int,
-        default=5,
-        help="Number of shards between which to split the MNIST training data.",
-    )
-    parser.add_argument(
-        "--root",
-        default=".",
-        dest="folder",
-        help="Path to the root folder where to export raw and split data.",
-    )
-    parser.add_argument(
-        "--data_path",
-        default=None,  # CHECK
-        dest="data",
-        help="Path to the data to be split",
-    )
-    parser.add_argument(
-        "--target_path",
-        default=None,  # CHECK
-        dest="target",
-        help="Path to the labels to be split",
-    )
-    schemes_help = """
-        Splitting scheme(s) to use, among {"iid", "labels", "biased"}.
-        If this argument is not specified, all three values are used.
-        See details above on the schemes' definition.
-    """
-    parser.add_argument(
-        "--scheme",
-        action="append",
-        choices=["iid", "labels", "biased"],
-        default=["iid"],
-        dest="schemes",
-        nargs="+",
-        help=textwrap.dedent(schemes_help),
-    )
-    parser.add_argument(
-        "--seed",
-        default=20221109,
-        dest="seed",
-        type=int,
-        help="RNG seed to use (default: 20221109).",
-    )
-    return parser.parse_args(args)
-
-
-def main(args: Optional[List[str]] = None) -> None:
-    """Run splitting schemes based on commandline-input arguments."""
-    cmdargs = parse_args(args)
-    for scheme in cmdargs.schemes:
-        split_data(
-            folder=os.path.join(cmdargs.folder, f"data_{scheme}"),
-            n_shards=cmdargs.n_shards,
-            data=cmdargs.data,
-            target=cmdargs.target,
-            scheme=scheme,
-            seed=cmdargs.seed,
-        )
-
-
-if __name__ == "__main__":
-    main()
+            np_save(folder, t_inp, i, "train_data")
+            np_save(folder, t_tgt, i, "train_target")
+            np_save(folder, v_inp, i, "valid_data")
+            np_save(folder, v_tgt, i, "valid_target")
diff --git a/declearn/quickrun/run.py b/declearn/quickrun/run.py
index 6bd718de..e52a859d 100644
--- a/declearn/quickrun/run.py
+++ b/declearn/quickrun/run.py
@@ -37,12 +37,13 @@ import os
 import re
 import textwrap
 from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 from declearn.communication import NetworkClientConfig, NetworkServerConfig
 from declearn.dataset import InMemoryDataset
 from declearn.main import FederatedClient, FederatedServer
 from declearn.main.config import FLOptimConfig, FLRunConfig
+from declearn.model.api import Model
 from declearn.test_utils import make_importable
 from declearn.utils import run_as_processes
 
@@ -53,107 +54,49 @@ DEFAULT_FOLDER = "./examples/quickrun"
 # Perform local imports.
 # pylint: disable=wrong-import-order, wrong-import-position
 with make_importable(os.path.dirname(__file__)):
+    from _config import DataSplitConfig, ExperimentConfig, ModelConfig
+    from _parser import parse_data_folder
     from _split_data import split_data
 # pylint: enable=wrong-import-order, wrong-import-position
 
 
+def _get_model(folder: str, model_config: ModelConfig) -> Model:
+    """Import the model object from the default or user-provided model file."""
+    file = "model"
+    if m_file := model_config.model_file:
+        folder = os.path.dirname(m_file)
+        file = m_file.rsplit("/", 1)[-1].split(".")[0]
+    with make_importable(folder):
+        mod = importlib.import_module(file)
+        model = getattr(mod, model_config.model_name)
+    return model
+
+
 def _run_server(
     folder: str,
     network: NetworkServerConfig,
+    model_config: ModelConfig,
     optim: FLOptimConfig,
     config: FLRunConfig,
 ) -> None:
     """Routine to run a FL server, called by `run_declearn_experiment`."""
-    # get Model
-    name = "MyModel"
-    with make_importable(folder):
-        mod = importlib.import_module("model")
-        model_cls = getattr(mod, name)
-        model = model_cls
+    model = _get_model(folder, model_config)
     server = FederatedServer(model, network, optim)
     server.run(config)
 
 
-def parse_data_folder(folder: str) -> Dict:
-    """Utils parsing a data folder following a standard format into a nested"
-    dictionnary.
-
-    The expected format is :
-
-        folder/
-        └─── data*/
-            └─── client*/
-            │      train_data.* - training data
-            │      train_target.* - training labels
-            │      valid_data.* - validation data
-            │      valid_target.* - validation labels
-            └─── client*/
-            │    ...
-    """
-    # Get data dir
-    gen_folders = Path(folder).glob("data*")
-    data_folder = next(gen_folders, False)
-    if not data_folder:
-        raise ValueError(
-            f"No folder starting with 'data' found in {folder}. "
-            "Please store your data under a 'data_*' folder"
-        )
-    if next(gen_folders, False):
-        raise ValueError(
-            "More than one folder starting with 'data' found"
-            f"in {folder}. Please store your data under a single"
-            "parent folder"
-        )
-    # Get clients dir
-    gen_folders = data_folder.glob("client*")  # type: ignore
-    first_client = next(gen_folders, False)
-    if not first_client:
-        raise ValueError(
-            f"No folder starting with 'client' found in {data_folder}. "
-            "Please store your individual under client data under"
-            "a 'client*' folder"
-        )
-    clients = {str(first_client).rsplit("/", 1)[-1]: {}}
-    while client := next(gen_folders, False):
-        clients[str(client).rsplit("/", 1)[-1]] = {}
-    # Get train and valid files
-    data_items = [
-        "train_data",
-        "train_target",
-        "valid_data",
-        "valid_target",
-    ]
-    for client, files in clients.items():
-        for d in data_items:
-            gen_file = Path(data_folder / client).glob(f"{d}*")  # type: ignore
-            file = next(gen_file, False)
-            if not file:
-                raise ValueError(
-                    f"Could not find a file named '{d}.*' in {client}"
-                )
-            if next(gen_file, False):
-                raise ValueError(
-                    f"Found more than one file named '{d}.*' in {client}"
-                )
-            files[d] = str(file)
-
-    return clients
-
-
 def _run_client(
+    folder: str,
     network: NetworkClientConfig,
+    model_config: ModelConfig,
     name: str,
     paths: dict,
-    folder: str,
 ) -> None:
     """Routine to run a FL client, called by `run_declearn_experiment`."""
     # Overwrite client name based on folder name
     network.name = name
+    # Make the model importable
+    _ = _get_model(folder, model_config)
     # Wrap train and validation data as Dataset objects.
-    name = "MyModel"
-    with make_importable(folder):
-        mod = importlib.import_module("model")
-        model_cls = getattr(mod, name)  # pylint: disable=unused-variable
     train = InMemoryDataset(
         paths.get("train_data"),
         target=paths.get("train_target"),
@@ -167,8 +110,42 @@ def _run_client(
     client.run()
 
 
+def get_toml_folder(config: Optional[str] = None) -> Tuple[str, str]:
+    """Determine whether the provided config is a file or a directory,
+    and return:
+    * The path to the TOML config file
+    * The path to the main folder of the experiment
+    """
+    # Default to the MNIST example.
+    if not config:
+        config = DEFAULT_FOLDER
+    config = os.path.abspath(config)
+    # Check whether config is a TOML file or a directory.
+    if os.path.isfile(config):
+        toml = config
+        folder = os.path.dirname(config)
+    elif os.path.isdir(config):
+        folder = config
+        toml = f"{folder}/config.toml"
+    else:
+        raise ValueError(
+            f"Provided config path '{config}' is neither a file nor a folder."
+        )
+    return toml, folder
+
+
+def locate_or_create_split_data(toml: str, folder: str) -> Dict:
+    """Attempt to find split data according to the config TOML file or
+    the default behavior. If that fails, attempt to find the full data
+    according to the config TOML file and split it."""
+    expe_config = ExperimentConfig.from_toml(toml, False, "experiment")
+    try:
+        client_dict = parse_data_folder(expe_config, folder)
+    except ValueError:
+        data_config = DataSplitConfig.from_toml(toml, False, "data")
+        split_data(data_config)
+        client_dict = parse_data_folder(expe_config, folder)
+    return client_dict
+
+
 def quickrun(
-    folder: Optional[str] = None,
+    config: Optional[str] = None,
     **kwargs: Any,
 ) -> None:
     """Run a server and its clients using multiprocessing.
@@ -176,28 +153,24 @@ def quickrun(
     The kwargs are the arguments expected by split_data,
     see [the documentation][declearn.quickrun._split_data]
     """
-    # default to the mnist exampl
-    if not folder:
-        folder = DEFAULT_FOLDER
-    folder = os.path.abspath(folder)
-    # Get datasets and client_names from folder
-    try:
-        client_dict = parse_data_folder(folder)
-    except ValueError:
-        split_data(folder, **kwargs)
-        client_dict = parse_data_folder(folder)
+    toml, folder = get_toml_folder(config)
+    # Locate split data, or split it if needed.
+    client_dict = locate_or_create_split_data(toml, folder)
     # Parse toml file to ServerConfig and ClientConfig
-    toml = f"{folder}/config.toml"
     ntk_server = NetworkServerConfig.from_toml(toml, False, "network_server")
     optim = FLOptimConfig.from_toml(toml, False, "optim")
     run = FLRunConfig.from_toml(toml, False, "run")
     ntk_client = NetworkClientConfig.from_toml(toml, False, "network_client")
+    model_config = ModelConfig.from_toml(toml, False, "model")
     # Set up a (func, args) tuple specifying the server process.
-    p_server = (_run_server, (folder, ntk_server, optim, run))
+    p_server = (_run_server, (folder, ntk_server, model_config, optim, run))
     # Set up the (func, args) tuples specifying client-wise processes.
     p_client = []
     for name, data_dict in client_dict.items():
-        client = (_run_client, (ntk_client, name, data_dict, folder))
+        client = (
+            _run_client,
+            (folder, ntk_client, model_config, name, data_dict),
+        )
         p_client.append(client)
     # Run each and every process in parallel.
     success, outputs = run_as_processes(p_server, *p_client)
@@ -210,29 +183,16 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
     """Set up and run a command-line arguments parser."""
     usage = """
         Quickly run an example locally using declearn.
-        The script requires to be provided with the path to a folder
-        containing:
+        The script requires the path to a TOML file gathering all the
+        elements required to configure an FL experiment, or the path
+        to a folder containing:
+        * a TOML file with all the elements required to configure an
+        FL experiment
         * A declearn model
-        * A TOML file with all the elements required to configurate an FL
-        experiment
         * A data folder, structured in a specific way
 
         If not provided with this, the script defaults to the MNIST example
         provided by declearn in `declearn.example.quickrun`.
-
-        Once launched, this script splits data into heterogeneous shards. It
-        then locally runs the FL experiment as layed out in the TOML file,
-        using privided model and data, and stores its result in the same folder.
-
-        The implemented schemes are the following:
-        * "iid":
-            Split the dataset through iid random sampling.
-        * "labels":
-            Split the dataset into shards that hold all samples
-            that have mutually-exclusive target classes.
-        * "biased":
-            Split the dataset through random sampling according
-            to a shard-specific random labels distribution.
     """
     usage = re.sub("\n *(?=[a-z])", " ", textwrap.dedent(usage))
     parser = argparse.ArgumentParser(
@@ -240,73 +200,18 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
         usage=re.sub("- ", "-", usage),
     )
     parser.add_argument(
-        "--n_shards",
-        type=int,
-        default=5,
-        help="Number of shards between which to split the data.",
-    )
-    parser.add_argument(
-        "--root",
+        "--config",
         default=None,
-        dest="folder",
+        dest="config",
         help="Path to the root folder where to export data.",
     )
-    parser.add_argument(
-        "--data_path",
-        default=None,
-        dest="data",
-        help="Path to the data to be split",
-    )
-    parser.add_argument(
-        "--target_path",
-        default=None,
-        dest="target",
-        help="Path to the labels to be split",
-    )
-    schemes_help = """
-        Splitting scheme(s) to use, among {"iid", "labels", "biased"}.
-        If this argument is not specified, all "iid" is used.
-        See details above on the schemes' definition.
-    """
-    parser.add_argument(
-        "--scheme",
-        action="append",
-        choices=["iid", "labels", "biased"],
-        default=["iid"],
-        dest="schemes",
-        nargs="+",
-        help=textwrap.dedent(schemes_help),
-    )
-    parser.add_argument(
-        "--train_split",
-        default=0.8,
-        dest="perc_train",
-        type=float,
-        help="What proportion of the data to use for training vs validation",
-    )
-    parser.add_argument(
-        "--seed",
-        default=20221109,
-        dest="seed",
-        type=int,
-        help="RNG seed to use (default: 20221109).",
-    )
     return parser.parse_args(args)
 
 
 def main(args: Optional[List[str]] = None) -> None:
     """Quikcrun based on commandline-input arguments."""
     cmdargs = parse_args(args)
-    for scheme in cmdargs.schemes:
-        quickrun(
-            folder=cmdargs.folder,
-            n_shards=cmdargs.n_shards,
-            data=cmdargs.data,
-            target=cmdargs.target,
-            scheme=scheme,
-            perc_train=cmdargs.perc_train,
-            seed=cmdargs.seed,
-        )
+    quickrun(config=cmdargs.config)
 
 
 if __name__ == "__main__":
diff --git a/examples/quickrun/config.toml b/examples/quickrun/config.toml
index 7e941bef..32f00403 100644
--- a/examples/quickrun/config.toml
+++ b/examples/quickrun/config.toml
@@ -15,6 +15,7 @@ server_opt = 1.0 # The server learning rate
     [optim.client_opt]
     lrate = 0.01 # The client learning rate
     regularizers = [["lasso", {alpha = 0.1}]] # The list of regularizer modules, each a list
+    modules = [["momentum", {"beta" = 0.9}]]
 
 [run]
 rounds = 2 # Number of training rounds
@@ -33,3 +34,12 @@ rounds = 2 # Number of training rounds
     batch_size = 128 # Evaluation batch size
 
 
+[experiment]
+# Optional arguments for parse_data_folder:
+# data_folder, client_names, dataset_names
+
+[model]
+# Optional arguments locating the model file:
+# model_file, model_name
+
+[data]
+# Optional arguments for split_data (DataSplitConfig fields):
+# export_folder, n_shards, data_file, label_file, scheme, perc_train, seed
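+
+# Illustrative (commented-out) values for the optional sections above;
+# the file name below is hypothetical:
+#
+# [model]
+# model_file = "model.py"
+# model_name = "MyModel"
+#
+# [data]
+# scheme = "labels"
+# n_shards = 5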
\ No newline at end of file
diff --git a/examples/quickrun/model.py b/examples/quickrun/model.py
index 211adbcd..8e5080bc 100644
--- a/examples/quickrun/model.py
+++ b/examples/quickrun/model.py
@@ -8,46 +8,45 @@ import torch.nn.functional as F
 from declearn.model.tensorflow import TensorflowModel
 from declearn.model.torch import TorchModel
 
-
-class Net(nn.Module):
-    def __init__(self):
-        super(Net, self).__init__()
-        self.conv1 = nn.Conv2d(1, 32, 3, 1)
-        self.conv2 = nn.Conv2d(32, 64, 3, 1)
-        self.dropout1 = nn.Dropout(0.25)
-        self.dropout2 = nn.Dropout(0.5)
-        self.fc1 = nn.Linear(9216, 128)
-        self.fc2 = nn.Linear(128, 10)
-
-    def forward(self, x):
-        x = torch.transpose(x, 3, 1)
-        x = self.conv1(x)
-        x = F.relu(x)
-        x = self.conv2(x)
-        x = F.relu(x)
-        x = F.max_pool2d(x, 2)
-        x = self.dropout1(x)
-        x = torch.flatten(x, 1)
-        x = self.fc1(x)
-        x = F.relu(x)
-        x = self.dropout2(x)
-        x = self.fc2(x)
-        output = F.log_softmax(x, dim=1)
-        return output
-
-
-MyModel = TorchModel(Net(), loss=nn.NLLLoss())
-
-# stack = [
-#     tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),
-#     tf.keras.layers.Conv2D(32, 3, 1, activation="relu"),
-#     tf.keras.layers.Conv2D(64, 3, 1, activation="relu"),
-#     tf.keras.layers.MaxPool2D(2),
-#     tf.keras.layers.Dropout(0.25),
-#     tf.keras.layers.Flatten(),
-#     tf.keras.layers.Dense(128, activation="relu"),
-#     tf.keras.layers.Dropout(0.5),
-#     tf.keras.layers.Dense(10, activation="softmax"),
-# ]
-# model = tf.keras.models.Sequential(stack)
-# MyModel = TensorflowModel(model, loss="sparse_categorical_crossentropy")
+# class Net(nn.Module):
+#     def __init__(self):
+#         super(Net, self).__init__()
+#         self.conv1 = nn.Conv2d(1, 32, 3, 1)
+#         self.conv2 = nn.Conv2d(32, 64, 3, 1)
+#         self.dropout1 = nn.Dropout(0.25)
+#         self.dropout2 = nn.Dropout(0.5)
+#         self.fc1 = nn.Linear(9216, 128)
+#         self.fc2 = nn.Linear(128, 10)
+
+#     def forward(self, x):
+#         x = torch.transpose(x, 3, 1)
+#         x = self.conv1(x)
+#         x = F.relu(x)
+#         x = self.conv2(x)
+#         x = F.relu(x)
+#         x = F.max_pool2d(x, 2)
+#         x = self.dropout1(x)
+#         x = torch.flatten(x, 1)
+#         x = self.fc1(x)
+#         x = F.relu(x)
+#         x = self.dropout2(x)
+#         x = self.fc2(x)
+#         output = F.log_softmax(x, dim=1)
+#         return output
+
+
+# MyModel = TorchModel(Net(), loss=nn.NLLLoss())
+
+stack = [
+    tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),
+    tf.keras.layers.Conv2D(32, 3, 1, activation="relu"),
+    tf.keras.layers.Conv2D(64, 3, 1, activation="relu"),
+    tf.keras.layers.MaxPool2D(2),
+    tf.keras.layers.Dropout(0.25),
+    tf.keras.layers.Flatten(),
+    tf.keras.layers.Dense(128, activation="relu"),
+    tf.keras.layers.Dropout(0.5),
+    tf.keras.layers.Dense(10, activation="softmax"),
+]
+model = tf.keras.models.Sequential(stack)
+MyModel = TensorflowModel(model, loss="sparse_categorical_crossentropy")
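+
+# Note: "MyModel" matches the default `model_name` in ModelConfig, which is the
+# attribute name the quickrun script looks up in this file to get the model.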
-- 
GitLab