diff --git a/declearn/quickrun/_config.py b/declearn/quickrun/_config.py
index fc30685f4d724b29a2657378bdb3e2ac4888fb8f..7614ed316dc6ded4e4de9dedce8121c4e1071a1a 100644
--- a/declearn/quickrun/_config.py
+++ b/declearn/quickrun/_config.py
@@ -82,7 +82,7 @@ class DataSplitConfig(TomlConfig):
     # Common args
     data_folder: Optional[str] = None
     # split_data args
-    n_shards: int = 5
+    n_shards: int = 3
     data_file: Optional[str] = None
     label_file: Optional[Union[str, int]] = None
     scheme: str = "iid"
diff --git a/declearn/quickrun/_parser.py b/declearn/quickrun/_parser.py
index eae9c9fccb39ee9e601595dff73c7529ad6021eb..a753d7dcfd26f5379dcb883af8947f13355d5ea3 100644
--- a/declearn/quickrun/_parser.py
+++ b/declearn/quickrun/_parser.py
@@ -63,6 +63,7 @@ def parse_data_folder(
     data_folder = expe_config.data_folder
     client_names = expe_config.client_names
     dataset_names = expe_config.dataset_names
+    scheme = expe_config.scheme
 
     if not folder and not data_folder:
         raise ValueError(
@@ -70,7 +71,8 @@ def parse_data_folder(
         )
     # Data_folder
     if not data_folder:
-        gen_folders = Path(folder).glob("data*")  # type: ignore
+        search_str = f"_{scheme}" if scheme else "*"
+        gen_folders = Path(folder).glob(f"data{search_str}")  # type: ignore
         data_folder = next(gen_folders, False)  # type: ignore
         if not data_folder:
             raise ValueError(
@@ -83,6 +85,8 @@ def parse_data_folder(
                 f"in {folder}. Please store your data under a single"
                 "parent folder"
             )
+    else:
+        data_folder = Path(data_folder)
     # Get clients dir
     if client_names:
         if isinstance(client_names, list):
diff --git a/declearn/quickrun/_split_data.py b/declearn/quickrun/_split_data.py
index 0fbf80692b8b2b155e048a120ee1b10926cdcf2d..f46e0dfa7eb2849c05c8e400e8df0150a1bb5aa1 100644
--- a/declearn/quickrun/_split_data.py
+++ b/declearn/quickrun/_split_data.py
@@ -111,6 +111,7 @@ def load_data(
         inputs = load_data_array(data)
         inputs = np.asarray(inputs)
     else:
+        print("\n\n", data, "\n\n")
         raise ValueError("The data path provided is not a valid file")
 
     if isinstance(target, int):
@@ -225,10 +226,16 @@ def split_data(data_config: DataSplitConfig, folder: str) -> None:
         np.save(os.path.join(data_dir, f"{name}.npy"), data)
 
     # Overwrite default folder if provided
+    scheme = data_config.scheme
+    name = f"data_{scheme}"
+    data_file = data_config.data_file
+    label_file = data_config.label_file
     if data_config.data_folder:
-        folder = data_config.data_folder
+        folder = os.path.dirname(data_config.data_folder)
+        name = os.path.split(data_config.data_folder)[-1]
+        data_file = os.path.abspath(data_config.data_file)
+        label_file = os.path.abspath(data_config.label_file)
     # Select the splitting function to be used.
-    scheme = data_config.scheme
     if scheme == "iid":
         func = _split_iid
     elif scheme == "labels":
@@ -239,14 +246,15 @@ def split_data(data_config: DataSplitConfig, folder: str) -> None:
         raise ValueError(f"Invalid 'scheme' value: '{scheme}'.")
     # Set up the RNG, download the raw dataset and split it.
     rng = np.random.default_rng(data_config.seed)
-    inputs, labels = load_data(data_config.data_file, data_config.label_file)
+
+    inputs, labels = load_data(data_file, label_file)
     print(
         f"Splitting data into {data_config.n_shards}"
         f"shards using the {scheme} scheme"
     )
     split = func(inputs, labels, data_config.n_shards, rng)
     # Export the resulting shard-wise data to files.
-    folder = os.path.join(folder, f"data_{scheme}")
+    folder = os.path.join(folder, name)
     for i, (inp, tgt) in enumerate(split):
         perc_train = data_config.perc_train
         if not perc_train:
diff --git a/declearn/quickrun/run.py b/declearn/quickrun/run.py
index 1a8f8a08300c1b0234cd7bc481d901ae3cffe8a7..21cd0028af162cce03344db5e8b4e8e50ee4a526 100644
--- a/declearn/quickrun/run.py
+++ b/declearn/quickrun/run.py
@@ -59,7 +59,8 @@ with make_importable(os.path.dirname(__file__)):
 # pylint: enable=wrong-import-order, wrong-import-position
 
 
-def _get_model(folder, model_config) -> Model:
+def get_model(folder, model_config) -> Model:
+    "Return a model instance from a model config instance"
     file = "model"
     if m_file := model_config.model_file:
         folder = os.path.dirname(m_file)
@@ -70,7 +71,7 @@ def _get_model(folder, model_config) -> Model:
     return model_cls
 
 
-def _run_server(
+def run_server(
     folder: str,
     network: NetworkServerConfig,
     model_config: ModelConfig,
@@ -79,7 +80,7 @@ def _run_server(
     expe_config: ExperimentConfig,
 ) -> None:
     """Routine to run a FL server, called by `run_declearn_experiment`."""
-    model = _get_model(folder, model_config)
+    model = get_model(folder, model_config)
     if expe_config.checkpoint:
         checkpoint = expe_config.checkpoint
     else:
@@ -91,7 +92,7 @@ def _run_server(
     server.run(config)
 
 
-def _run_client(
+def run_client(
     folder: str,
     network: NetworkClientConfig,
     model_config: ModelConfig,
@@ -103,7 +104,7 @@ def _run_client(
     # Overwrite client name based on folder name
     network.name = name
     # Make the model importable
-    _ = _get_model(folder, model_config)
+    _ = get_model(folder, model_config)
     # Add checkpointer
     if expe_config.checkpoint:
         checkpoint = expe_config.checkpoint
@@ -186,14 +187,14 @@ def quickrun(config: Optional[str] = None) -> None:
     expe_cfg = ExperimentConfig.from_toml(toml, False, "experiment")
     # Set up a (func, args) tuple specifying the server process.
     p_server = (
-        _run_server,
+        run_server,
         (folder, ntk_server_cfg, model_cfg, optim_cgf, run_cfg, expe_cfg),
     )
     # Set up the (func, args) tuples specifying client-wise processes.
     p_client = []
     for name, data_dict in client_dict.items():
         client = (
-            _run_client,
+            run_client,
             (folder, ntk_client_cfg, model_cfg, expe_cfg, name, data_dict),
         )
         p_client.append(client)
@@ -241,3 +242,4 @@ def main(args: Optional[List[str]] = None) -> None:
 
 if __name__ == "__main__":
     main()
+    # quickrun(config="examples/quickrun/config_custom.toml") # TODO
diff --git a/declearn/utils/_multiprocess.py b/declearn/utils/_multiprocess.py
index 72b343acb2f24c530fc8eb1ab4c7e18709c961fa..a39ac540d7d1001e466d3560ce3026edee38fca9 100644
--- a/declearn/utils/_multiprocess.py
+++ b/declearn/utils/_multiprocess.py
@@ -17,6 +17,7 @@
 
 """Utils to run concurrent routines parallelly using multiprocessing."""
 
+import functools
 import multiprocessing as mp
 import sys
 import traceback
@@ -110,21 +111,25 @@ def add_exception_catching(
     if not name:
         name = func.__name__
 
-    def wrapped(*args: Any, **kwargs: Any) -> Any:
-        """Call the wrapped function and catch exceptions or results."""
-        nonlocal name, queue
+    return functools.partial(wrapped, func=func, queue=queue, name=name)
 
-        try:
-            result = func(*args, **kwargs)
-        except Exception as exc:  # pylint: disable=broad-exception-caught
-            err = RuntimeError(
-                f"Exception of type {type(exc)} occurred:\n"
-                "".join(traceback.format_exception(type(exc), exc, tb=None))
-            )  # future: `traceback.format_exception(exc)` (py >=3.10)
-            queue.put((name, err))
-            sys.exit(1)
-        else:
-            queue.put((name, result))
-            sys.exit(0)
 
-    return wrapped
+def wrapped(
+    *args: Any,
+    func: Callable[..., Any],
+    queue: Queue,
+    name: str,
+) -> Any:
+    """Call the wrapped function and catch exceptions or results."""
+    try:
+        result = func(*args)
+    except Exception as exc:  # pylint: disable=broad-exception-caught
+        err = RuntimeError(
+            f"Exception of type {type(exc)} occurred:\n"
+            "".join(traceback.format_exception(type(exc), exc, tb=None))
+        )  # future: `traceback.format_exception(exc)` (py >=3.10)
+        queue.put((name, err))
+        sys.exit(1)
+    else:
+        queue.put((name, result))
+        sys.exit(0)