diff --git a/declearn/dataset/utils/__init__.py b/declearn/dataset/utils/__init__.py
index dfebe91da3623f4f8cb87a5b3a6f1aae5de1681e..c19f9cb1efe452c6763bff5c3286ca5e9f442dc0 100644
--- a/declearn/dataset/utils/__init__.py
+++ b/declearn/dataset/utils/__init__.py
@@ -30,6 +30,13 @@ to and from various file formats:
     Backend to load a sparse matrix from a dump file.
 * [sparse_to_file][declearn.dataset.utils.sparse_to_file]:
     Backend to save a sparse matrix to a dump file
+
+Data splitting
+--------------
+* [split_multi_classif_dataset]
+[declearn.dataset.utils.split_multi_classif_dataset]:
+    Split a classification dataset into (opt. heterogeneous) shards.
 """
 from ._sparse import sparse_from_file, sparse_to_file
 from ._save_load import load_data_array, save_data_array
+from ._split_classif import split_multi_classif_dataset
diff --git a/declearn/dataset/utils/_split_classif.py b/declearn/dataset/utils/_split_classif.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d941d463742fc3fd9dca22695a2057d08f2551b
--- /dev/null
+++ b/declearn/dataset/utils/_split_classif.py
@@ -0,0 +1,170 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utils to split a multi-category classification dataset into shards."""
+
+from typing import List, Literal, Optional, Tuple
+
+import numpy as np
+
+
+__all__ = [
+    "split_multi_classif_dataset",
+]
+
+
+def split_multi_classif_dataset(
+    dataset: Tuple[np.ndarray, np.ndarray],
+    n_shards: int,
+    scheme: Literal["iid", "labels", "biased"],
+    p_valid: float = 0.2,
+    seed: Optional[int] = None,
+) -> List[Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]]:
+    """Split a classification dataset into (opt. heterogeneous) shards.
+
+    The data-splitting schemes are the following:
+
+    - If "iid", split the dataset through iid random sampling.
+    - If "labels", split into shards that hold all samples associated
+      with mutually-exclusive target classes.
+    - If "biased", split the dataset through random sampling according
+      to a shard-specific random labels distribution.
+
+    Parameters
+    ----------
+    dataset: tuple(np.ndarray, np.ndarray)
+        Raw dataset, as a pair of numpy arrays that respectively contain
+        the input features and (aligned) labels.
+    n_shards: int
+        Number of shards between which to split the dataset.
+    scheme: {"iid", "labels", "biased"}
+        Splitting scheme to use. In all cases, shards contain
+        mutually-exclusive samples and cover the full dataset.
+        See details above.
+    p_valid: float, default=0.2
+        Share of each shard to turn into a validation subset.
+    seed: int or None, default=None
+        Optional seed to the RNG used for all sampling operations.
+
+    Returns
+    -------
+    shards: list[((np.ndarray, np.ndarray), (np.ndarray, np.ndarray))]
+        List of dataset shards, where each element is formatted as a
+        tuple of tuples: `((x_train, y_train), (x_valid, y_valid))`.
+
+    Raises
+    ------
+    ValueError
+        If `scheme` has an invalid value.
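+
+    Examples
+    --------
+    Minimal usage sketch, on illustrative toy data (arrays' shapes and
+    values below are arbitrary placeholders):
+
+    >>> x = np.random.normal(size=(1000, 32))
+    >>> y = np.random.choice(10, size=1000)  # integer class labels
+    >>> shards = split_multi_classif_dataset(
+    ...     dataset=(x, y), n_shards=5, scheme="iid"
+    ... )
+    >>> (x_train, y_train), (x_valid, y_valid) = shards[0]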
+    """
+    # Select the splitting function to be used.
+    if scheme == "iid":
+        func = split_iid
+    elif scheme == "labels":
+        func = split_labels
+    elif scheme == "biased":
+        func = split_biased
+    else:
+        raise ValueError(f"Invalid 'scheme' value: '{scheme}'.")
+    # Set up the RNG and split the dataset into shards.
+    rng = np.random.default_rng(seed)
+    inputs, target = dataset
+    split = func(inputs, target, n_shards, rng)
+    # Further split shards into training and validation subsets, and return.
+    return [train_valid_split(inp, tgt, p_valid, rng) for inp, tgt in split]
+
+
+def split_iid(
+    inputs: np.ndarray,
+    target: np.ndarray,
+    n_shards: int,
+    rng: np.random.Generator,
+) -> List[Tuple[np.ndarray, np.ndarray]]:
+    """Split a dataset into shards using iid sampling."""
+    order = rng.permutation(len(inputs))
+    s_len = len(inputs) // n_shards
+    split = []  # type: List[Tuple[np.ndarray, np.ndarray]]
+    for idx in range(n_shards):
+        srt = idx * s_len
+        end = (srt + s_len) if idx < (n_shards - 1) else len(order)
+        shard = order[srt:end]
+        split.append((inputs[shard], target[shard]))
+    return split
+
+
+def split_labels(
+    inputs: np.ndarray,
+    target: np.ndarray,
+    n_shards: int,
+    rng: np.random.Generator,
+) -> List[Tuple[np.ndarray, np.ndarray]]:
+    """Split a dataset into shards with mutually-exclusive label classes."""
+    classes = np.unique(target)
+    if n_shards > len(classes):
+        raise ValueError(
+            f"Cannot share {len(classes)} classes between {n_shards}"
+            "shards with mutually-exclusive labels."
+        )
+    s_len = len(classes) // n_shards
+    order = rng.permutation(classes)
+    split = []  # type: List[Tuple[np.ndarray, np.ndarray]]
+    for idx in range(n_shards):
+        srt = idx * s_len
+        end = (srt + s_len) if idx < (n_shards - 1) else len(order)
+        shard = np.isin(target, order[srt:end])
+        split.append((inputs[shard], target[shard]))
+    return split
+
+
+def split_biased(
+    inputs: np.ndarray,
+    target: np.ndarray,
+    n_shards: int,
+    rng: np.random.Generator,
+) -> List[Tuple[np.ndarray, np.ndarray]]:
+    """Split a dataset into shards with heterogeneous label distributions."""
+    classes = np.unique(target)
+    index = np.arange(len(target))
+    s_len = len(target) // n_shards
+    split = []  # type: List[Tuple[np.ndarray, np.ndarray]]
+    for idx in range(n_shards):
+        if idx < (n_shards - 1):
+            # Draw a random distribution of labels for this node.
+            logits = np.exp(rng.normal(size=len(classes)))
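+            # Note: indexing 'logits' by 'target' assumes that labels
+            # are integer class indices in the [0, len(classes)) range.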
+            lprobs = logits[target[index]]
+            lprobs = lprobs / lprobs.sum()
+            # Draw samples based on this distribution, without replacement.
+            shard = rng.choice(index, size=s_len, replace=False, p=lprobs)
+            index = index[~np.isin(index, shard)]
+        else:
+            # For the last node: use the remaining samples.
+            shard = index
+        split.append((inputs[shard], target[shard]))
+    return split
+
+
+def train_valid_split(
+    inputs: np.ndarray,
+    target: np.ndarray,
+    p_valid: float,
+    rng: np.random.Generator,
+) -> Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]:
+    """Split a dataset between train and validation using iid sampling."""
+    order = rng.permutation(len(inputs))
+    v_len = np.ceil(len(inputs) * p_valid).astype(int)
+    train = inputs[order[v_len:]], target[order[v_len:]]
+    valid = inputs[order[:v_len]], target[order[:v_len]]
+    return train, valid
diff --git a/declearn/quickrun/_split_data.py b/declearn/quickrun/_split_data.py
index fd8b3387fc4a4c72a2b06860500ce6aaabd83f25..4e595f424ab7a3d5c33b6df0142042c25cb225ae 100644
--- a/declearn/quickrun/_split_data.py
+++ b/declearn/quickrun/_split_data.py
@@ -33,12 +33,12 @@ instance sparse data
 """
 
 import os
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
 
 import numpy as np
 
 from declearn.dataset.examples import load_mnist
-from declearn.dataset.utils import load_data_array
+from declearn.dataset.utils import load_data_array, split_multi_classif_dataset
 from declearn.quickrun._config import DataSplitConfig
 
 
@@ -91,76 +91,6 @@ def load_data(
     return inputs, labels
 
 
-def _split_iid(
-    inputs: np.ndarray,
-    target: np.ndarray,
-    n_shards: int,
-    rng: np.random.Generator,
-) -> List[Tuple[np.ndarray, np.ndarray]]:
-    """Split a dataset into shards using iid sampling."""
-    order = rng.permutation(len(inputs))
-    s_len = len(inputs) // n_shards
-    split = []  # type: List[Tuple[np.ndarray, np.ndarray]]
-    for idx in range(n_shards):
-        srt = idx * s_len
-        end = (srt + s_len) if idx < (n_shards - 1) else len(order)
-        shard = order[srt:end]
-        split.append((inputs[shard], target[shard]))
-    return split
-
-
-def _split_labels(
-    inputs: np.ndarray,
-    target: np.ndarray,
-    n_shards: int,
-    rng: np.random.Generator,
-) -> List[Tuple[np.ndarray, np.ndarray]]:
-    """Split a dataset into shards with mutually-exclusive label classes."""
-    classes = np.unique(target)
-    if n_shards > len(classes):
-        raise ValueError(
-            f"Cannot share {len(classes)} classes between {n_shards}"
-            "shards with mutually-exclusive labels."
-        )
-    s_len = len(classes) // n_shards
-    order = rng.permutation(classes)
-    split = []  # type: List[Tuple[np.ndarray, np.ndarray]]
-    for idx in range(n_shards):
-        srt = idx * s_len
-        end = (srt + s_len) if idx < (n_shards - 1) else len(order)
-        shard = np.isin(target, order[srt:end])
-        shuffle = rng.permutation(shard.sum())
-        split.append((inputs[shard][shuffle], target[shard][shuffle]))
-    return split
-
-
-def _split_biased(
-    inputs: np.ndarray,
-    target: np.ndarray,
-    n_shards: int,
-    rng: np.random.Generator,
-) -> List[Tuple[np.ndarray, np.ndarray]]:
-    """Split a dataset into shards with heterogeneous label distributions."""
-    classes = np.unique(target)
-    index = np.arange(len(target))
-    s_len = len(target) // n_shards
-    split = []  # type: List[Tuple[np.ndarray, np.ndarray]]
-    for idx in range(n_shards):
-        if idx < (n_shards - 1):
-            # Draw a random distribution of labels for this node.
-            logits = np.exp(rng.normal(size=len(classes)))
-            lprobs = logits[target[index]]
-            lprobs = lprobs / lprobs.sum()
-            # Draw samples based on this distribution, without replacement.
-            shard = rng.choice(index, size=s_len, replace=False, p=lprobs)
-            index = index[~np.isin(index, shard)]
-        else:
-            # For the last node: use the remaining samples.
-            shard = index
-        split.append((inputs[shard], target[shard]))
-    return split
-
-
 def split_data(data_config: DataSplitConfig, folder: str) -> None:
     """Download and randomly split a dataset into shards.
 
@@ -180,54 +110,33 @@ def split_data(data_config: DataSplitConfig, folder: str) -> None:
     data_config: DataSplitConfig
         A DataSplitConfig instance, see class documentation for details
     """
-
-    def np_save(folder, data, i, name):
-        data_dir = os.path.join(folder, f"client_{i}")
-        os.makedirs(data_dir, exist_ok=True)
-        np.save(os.path.join(data_dir, f"{name}.npy"), data)
-
-    # Overwrite default folder if provided
-    scheme = data_config.scheme
-    name = f"data_{scheme}"
-    data_file = data_config.data_file
-    label_file = data_config.label_file
+    # Select output folder.
     if data_config.data_folder:
-        folder = os.path.dirname(data_config.data_folder)
-        name = os.path.split(data_config.data_folder)[-1]
-        data_file = os.path.abspath(data_config.data_file)
-        label_file = os.path.abspath(data_config.label_file)
+        folder = data_config.data_folder
-    # Select the splitting function to be used.
-    if scheme == "iid":
-        func = _split_iid
-    elif scheme == "labels":
-        func = _split_labels
-    elif scheme == "biased":
-        func = _split_biased
     else:
-        raise ValueError(f"Invalid 'scheme' value: '{scheme}'.")
-    # Set up the RNG, download the raw dataset and split it.
-    rng = np.random.default_rng(data_config.seed)
-
-    inputs, labels = load_data(data_file, label_file)
+        folder = f"data_{data_config.scheme}"
+    # Value-check the 'perc_train' parameter.
+    if not 0.0 < data_config.perc_train <= 1.0:
+        raise ValueError("'perc_train' should be a float in ]0,1]")
+    # Load the dataset and split it.
+    inputs, labels = load_data(data_config.data_file, data_config.label_file)
     print(
-        f"Splitting data into {data_config.n_shards}"
-        f"shards using the {scheme} scheme"
+        f"Splitting data into {data_config.n_shards} shards "
+        f"using the '{data_config.scheme}' scheme."
+    )
+    split = split_multi_classif_dataset(
+        dataset=(inputs, labels),
+        n_shards=data_config.n_shards,
+        scheme=data_config.scheme,  # type: ignore
+        p_valid=(1 - data_config.perc_train),
+        seed=data_config.seed,
     )
-    split = func(inputs, labels, data_config.n_shards, rng)
     # Export the resulting shard-wise data to files.
-    folder = os.path.join(folder, name)
-    for i, (inp, tgt) in enumerate(split):
-        perc_train = data_config.perc_train
-        if not perc_train:
-            np_save(folder, inp, i, "train_data")
-            np_save(folder, tgt, i, "train_target")
-        else:
-            if perc_train > 1.0 or perc_train < 0.0:
-                raise ValueError("perc_train should be a float in ]0,1]")
-            n_train = round(len(inp) * perc_train)
-            t_inp, t_tgt = inp[:n_train], tgt[:n_train]
-            v_inp, v_tgt = inp[n_train:], tgt[n_train:]
-            np_save(folder, t_inp, i, "train_data")
-            np_save(folder, t_tgt, i, "train_target")
-            np_save(folder, v_inp, i, "valid_data")
-            np_save(folder, v_tgt, i, "valid_target")
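+    # Each shard is written to its own 'client_{idx}' subfolder, as
+    # 'train_data.npy' / 'train_target.npy' files, plus 'valid_*.npy'
+    # files whenever a validation subset was drawn.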
+    for idx, ((x_train, y_train), (x_valid, y_valid)) in enumerate(split):
+        subdir = os.path.join(folder, f"client_{idx}")
+        os.makedirs(subdir, exist_ok=True)
+        np.save(os.path.join(subdir, "train_data.npy"), x_train)
+        np.save(os.path.join(subdir, "train_target.npy"), y_train)
+        if len(x_valid):
+            np.save(os.path.join(subdir, "valid_data.npy"), x_valid)
+            np.save(os.path.join(subdir, "valid_target.npy"), y_valid)