diff --git a/declearn/quickrun/_split_data.py b/declearn/quickrun/_split_data.py
index f1eff42d312fcaf02174fba8fe39693c51b7769a..96c93ef0c068422e4dff3c8a258b68a0a0cbc48a 100644
--- a/declearn/quickrun/_split_data.py
+++ b/declearn/quickrun/_split_data.py
@@ -63,9 +63,15 @@ def load_mnist(
     # Load the desired subset of MNIST
     tag = "train" if train else "test"
     url = f"{SOURCE_URL}/mnist_{tag}.csv"
-    data = requests.get(url, verify=False, timeout=20).content
-    df = pd.read_csv(io.StringIO(data.decode("utf-8")), header=None, sep=",")
-    return df.iloc[:, 1:].to_numpy(), df[0].to_numpy()[:, None]
+    og_data = requests.get(url, verify=False, timeout=20).content
+    df = pd.read_csv(
+        io.StringIO(og_data.decode("utf-8")), header=None, sep=","
+    )
+    data = df.iloc[:, 1:].to_numpy()
+    data = (data.reshape(data.shape[0], 28, 28, 1) / 255).astype(np.single)
+    # Channel last : B,H,W,C
+    labels = df[0].to_numpy()
+    return data, labels
 
 
 def load_data(
@@ -101,7 +107,7 @@ def load_data(
         raise ValueError("The data path provided is not a valid file")
 
     if isinstance(target, int):
-        labels = inputs[target][:, None]
+        labels = inputs[:, target]
         inputs = np.delete(inputs, target, axis=1)
     if isinstance(target, str):
         if os.path.isfile(target):
@@ -265,7 +271,7 @@ def split_data(
                 raise ValueError("perc_train should be a float in ]0,1]")
             n_train = round(len(inp) * perc_train)
             t_inp, t_tgt = inp[:n_train], tgt[:n_train]
-            v_inp, v_tgt = inp[n_train:], inp[n_train:]
+            v_inp, v_tgt = inp[n_train:], tgt[n_train:]
             np_save(t_inp, i, "train_data")
             np_save(t_tgt, i, "train_target")
             np_save(v_inp, i, "valid_data")
diff --git a/declearn/quickrun/run.py b/declearn/quickrun/run.py
index c6562bbd0980bf662596eacdec77cb5bd9a67293..6bd718de9f8bb72f2adaec86cd06cb4a1e4cad3a 100644
--- a/declearn/quickrun/run.py
+++ b/declearn/quickrun/run.py
@@ -25,10 +25,10 @@ The script requires to be provided with the path to a folder containing:
 * A data folder, structured in a specific way
 
 If not provided with this, the script defaults to the MNIST example provided
-by declearn in `declearn.example.quickrun`. 
+by declearn in `declearn.example.quickrun`.
 
-The script then locally runs the FL experiment as layed out in the TOML file, 
-using privided model and data, and stores its result in the same folder.  
+The script then locally runs the FL experiment as layed out in the TOML file,
+using privided model and data, and stores its result in the same folder.
 """
 
 import argparse
@@ -58,12 +58,18 @@ with make_importable(os.path.dirname(__file__)):
 
 
 def _run_server(
-    model: str,
+    folder: str,
     network: NetworkServerConfig,
     optim: FLOptimConfig,
     config: FLRunConfig,
 ) -> None:
     """Routine to run a FL server, called by `run_declearn_experiment`."""
+    # get Model
+    name = "MyModel"
+    with make_importable(folder):
+        mod = importlib.import_module("model")
+        model_cls = getattr(mod, name)
+        model = model_cls
     server = FederatedServer(model, network, optim)
     server.run(config)
 
@@ -135,16 +141,19 @@ def parse_data_folder(folder: str) -> Dict:
 
 
 def _run_client(
-    network: str,
+    network: NetworkClientConfig,
     name: str,
     paths: dict,
+    folder: str,
 ) -> None:
     """Routine to run a FL client, called by `run_declearn_experiment`."""
-    # Run the declearn FL client routine.
-    netwk = NetworkClientConfig.from_toml(network)
     # Overwrite client name based on folder name
-    netwk.name = name
+    network.name = name
     # Wrap train and validation data as Dataset objects.
+    name = "MyModel"
+    with make_importable(folder):
+        mod = importlib.import_module("model")
+        model_cls = getattr(mod, name)  # pylint: disable=unused-variable
     train = InMemoryDataset(
         paths.get("train_data"),
         target=paths.get("train_target"),
@@ -154,7 +163,7 @@ def _run_client(
         paths.get("valid_data"),
         target=paths.get("valid_target"),
     )
-    client = FederatedClient(netwk, train, valid)
+    client = FederatedClient(network, train, valid)
     client.run()
 
 
@@ -167,7 +176,7 @@ def quickrun(
     The kwargs are the arguments expected by split_data,
     see [the documentation][declearn.quickrun._split_data]
     """
-    # default to the mnist example
+    # default to the mnist exampl
     if not folder:
         folder = DEFAULT_FOLDER
     folder = os.path.abspath(folder)
@@ -183,18 +192,12 @@ def quickrun(
     optim = FLOptimConfig.from_toml(toml, False, "optim")
     run = FLRunConfig.from_toml(toml, False, "run")
     ntk_client = NetworkClientConfig.from_toml(toml, False, "network_client")
-    # get Model
-    name = "MyModel"
-    with make_importable(folder):
-        mod = importlib.import_module("model")
-        model_cls = getattr(mod, name)
-        model = model_cls
     # Set up a (func, args) tuple specifying the server process.
-    p_server = (_run_server, (model, ntk_server, optim, run))
+    p_server = (_run_server, (folder, ntk_server, optim, run))
     # Set up the (func, args) tuples specifying client-wise processes.
     p_client = []
     for name, data_dict in client_dict.items():
-        client = (_run_client, (ntk_client, name, data_dict))
+        client = (_run_client, (ntk_client, name, data_dict, folder))
         p_client.append(client)
     # Run each and every process in parallel.
     success, outputs = run_as_processes(p_server, *p_client)
diff --git a/examples/quickrun/config.toml b/examples/quickrun/config.toml
index dd45f263225dc78b005df90419efd1d604846f58..7e941bef12c1a7ec9f7109cd455c3bace998d04a 100644
--- a/examples/quickrun/config.toml
+++ b/examples/quickrun/config.toml
@@ -9,27 +9,27 @@ server_uri = "ws://localhost:8765"
 name = "replaceme"
 
 [optim]
-aggregator = "averaging"
-server_opt = 1.0
+aggregator = "averaging" # The chosen aggregation strategy
+server_opt = 1.0 # The server learning rate
 
     [optim.client_opt]
-    lrate = 0.001
-    regularizers = [["lasso", {alpha = 0.1}]]
+    lrate = 0.01 # The client learning rate
+    regularizers = [["lasso", {alpha = 0.1}]] # The list of regularizer modules, each a list
 
 [run]
-rounds = 10
+rounds = 2 # Number of training rounds
 
     [run.register]
-    min_clients = 3
+    min_clients = 1
+    max_clients = 6
+    timeout = 5
 
     [run.training]
-    n_epoch = 1
-    batch_size = 48
-    drop_remainder = false
+    n_epoch = 1 # Number of local epochs
+    batch_size = 48 # Training batch size
+    drop_remainder = false # Whether to drop the last trainig examples
 
     [run.evaluate]
-    batch_size = 128
-
-
+    batch_size = 128 # Evaluation batch size
 
 
diff --git a/examples/quickrun/model.py b/examples/quickrun/model.py
index c9984ba70bf6e3a2a39ec8a8fe7af27719bd391e..211adbcdbfc1fe73123f0d50980df8d33208322f 100644
--- a/examples/quickrun/model.py
+++ b/examples/quickrun/model.py
@@ -1,33 +1,53 @@
 """Wrapping a torch model"""
 
+import tensorflow as tf
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from declearn.model.tensorflow import TensorflowModel
 from declearn.model.torch import TorchModel
 
 
 class Net(nn.Module):
-
-    """A basic CNN, directly copied from Torch's 60 min blitz"""
-
     def __init__(self):
-        super().__init__()
-        self.conv1 = nn.Conv2d(3, 6, 5)
-        self.pool = nn.MaxPool2d(2, 2)
-        self.conv2 = nn.Conv2d(6, 16, 5)
-        self.fc1 = nn.Linear(16 * 5 * 5, 120)
-        self.fc2 = nn.Linear(120, 84)
-        self.fc3 = nn.Linear(84, 10)
+        super(Net, self).__init__()
+        self.conv1 = nn.Conv2d(1, 32, 3, 1)
+        self.conv2 = nn.Conv2d(32, 64, 3, 1)
+        self.dropout1 = nn.Dropout(0.25)
+        self.dropout2 = nn.Dropout(0.5)
+        self.fc1 = nn.Linear(9216, 128)
+        self.fc2 = nn.Linear(128, 10)
 
     def forward(self, x):
-        x = self.pool(F.relu(self.conv1(x)))
-        x = self.pool(F.relu(self.conv2(x)))
-        x = torch.flatten(x, 1)  # flatten all dimensions except batch
-        x = F.relu(self.fc1(x))
-        x = F.relu(self.fc2(x))
-        x = self.fc3(x)
-        return x
+        x = torch.transpose(x, 3, 1)
+        x = self.conv1(x)
+        x = F.relu(x)
+        x = self.conv2(x)
+        x = F.relu(x)
+        x = F.max_pool2d(x, 2)
+        x = self.dropout1(x)
+        x = torch.flatten(x, 1)
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.dropout2(x)
+        x = self.fc2(x)
+        output = F.log_softmax(x, dim=1)
+        return output
 
 
 MyModel = TorchModel(Net(), loss=nn.NLLLoss())
+
+# stack = [
+#     tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),
+#     tf.keras.layers.Conv2D(32, 3, 1, activation="relu"),
+#     tf.keras.layers.Conv2D(64, 3, 1, activation="relu"),
+#     tf.keras.layers.MaxPool2D(2),
+#     tf.keras.layers.Dropout(0.25),
+#     tf.keras.layers.Flatten(),
+#     tf.keras.layers.Dense(128, activation="relu"),
+#     tf.keras.layers.Dropout(0.5),
+#     tf.keras.layers.Dense(10, activation="softmax"),
+# ]
+# model = tf.keras.models.Sequential(stack)
+# MyModel = TensorflowModel(model, loss="sparse_categorical_crossentropy")