diff --git a/declearn/dataset/_inmemory.py b/declearn/dataset/_inmemory.py
index 650f2c88f402af14a49f624501285b9d978c5ff3..3d1f57368631015b624427b045fa714fdf178205 100644
--- a/declearn/dataset/_inmemory.py
+++ b/declearn/dataset/_inmemory.py
@@ -18,7 +18,8 @@
 """Dataset implementation to serve scikit-learn compatible in-memory data."""
 
 import os
-from typing import Any, Dict, Iterator, List, Optional, Set, Union
+import typing
+from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union
 
 import numpy as np
 import pandas as pd
@@ -29,7 +30,7 @@ from typing_extensions import Self  # future: import from typing (py >=3.11)
 
 from declearn.dataset._base import Dataset, DataSpecs
 from declearn.dataset.utils import load_data_array, save_data_array
-from declearn.typing import Batch
+from declearn.typing import Batch, DataArray
 from declearn.utils import json_dump, json_load, register_type
 
 __all__ = [
@@ -37,7 +38,7 @@ __all__ = [
 ]
 
 
-DataArray = Union[np.ndarray, pd.DataFrame, spmatrix]
+DATA_ARRAY_TYPES = typing.get_args(DataArray)
 
 
 @register_type(group="Dataset")
@@ -87,90 +88,173 @@ class InMemoryDataset(Dataset):
         an instance that is either a numpy ndarray, a pandas
         DataFrame or a scipy spmatrix.
 
-        See the `load_data_array` function in dataset._utils for details
-        on supported file formats.
+        See the `load_data_array` function in `dataset.utils`
+        for details on supported file formats.
 
         Parameters
         ----------
-        data: data array or str
+        data:
             Main data array which contains input features (and possibly
             more), or path to a dump file from which it is to be loaded.
-        target: data array or str or None, default=None
-            Optional data array containing target labels (for supervised
-            learning), or path to a dump file from which to load it.
-            If `data` is a pandas DataFrame (or a path to a csv file),
-            `target` may be the name of a column to use as labels (and
-            thus not to use as input feature unless listed in `f_cols`).
-        s_wght: int or str or function or None, default=None
-            Optional data array containing sample weights, or path to a
-            dump file from which to load it.
-            If `data` is a pandas DataFrame (or a path to a csv file),
-            `s_wght` may be the name of a column to use as labels (and
-            thus not to use as input feature unless listed in `f_cols`).
-        f_cols: list[int] or list[str] or None, default=None
-            Optional list of columns in `data` to use as input features
-            (other columns will not be included in the first array of
-            the batches yielded by `self.generate_batches(...)`).
-        expose_classes: bool, default=False
-            Whether the dataset should be used for classification, in
-            which case the unique values of `target` are exposed under
-            `self.classes` and exported by `self.get_data_specs()`).
-        expose_data_type: bool, default=False
-            Whether the dataset should expose the data type , in
-            which case check if the type is unique and exposed it
-            under `self.data_type` and exported by `self.get_data_specs()`).
-        seed: int or None, default=None
-            Optional seed for the random number generator based on which
-            the dataset is (optionally) shuffled when generating batches.
+        target:
+            Optional target labels, as a data array, or as a path to a
+            dump file, or as the name of a `data` column.
+        s_wght:
+            Optional sample weights, as a data array, or as a path to a
+            dump file, or as the name of a `data` column.
+        f_cols:
+            Optional list of columns in `data` to use as input features.
+            These may be specified as column names or indices. If None,
+            use all non-target, non-sample-weights columns of `data`.
+
+        Other parameters
+        ----------------
+        expose_classes:
+            Whether to expose unique target values as part of data specs.
+            This should only be used for classification datasets.
+        expose_data_type:
+            Whether to expose features' dtype, which will be verified to
+            be unique, as part of data specs.
+        seed:
+            Optional seed for the random number generator used for all
+            randomness-based operations required to generate batches
+            (e.g. to shuffle the data or sample from it).
         """
         # arguments serve modularity; pylint: disable=too-many-arguments
-        self._data_path = None  # type: Optional[str]
-        self._trgt_path = None  # type: Optional[str]
         # Assign the main data array.
-        if isinstance(data, str):
-            self._data_path = data
-            data = load_data_array(data)
-        self.data = data
+        data_array, src_path = self._parse_data_argument(data)
+        self.data = data_array
+        self._data_path = src_path
         # Assign the optional input features list.
-        self.f_cols = f_cols
+        self.f_cols = self._parse_fcols_argument(f_cols, data=self.data)
         # Assign the (optional) target data array.
-        if isinstance(target, str):
-            self._trgt_path = target
-            if (
-                isinstance(self.data, pd.DataFrame)
-                and target in self.data.columns
-            ):
-                if f_cols is None:
-                    self.f_cols = self.f_cols or list(self.data.columns)
-                    self.f_cols.remove(target)  # type: ignore
-                target = self.data[target]
-            else:
-                target = load_data_array(target)
-                if (
-                    isinstance(target, pd.DataFrame)
-                    and len(target.columns) == 1
-                ):
-                    target = target.iloc[:, 0]
-        self.target = target
+        data_array, src_path = self._parse_array_or_column_argument(
+            value=target, data=self.data, name="target"
+        )
+        self.target = data_array
+        self._trgt_path = src_path
+        if self.f_cols and src_path and src_path in self.f_cols:
+            self.f_cols.remove(src_path)  # type: ignore[arg-type]
         # Assign the (optional) sample weights data array.
-        if isinstance(s_wght, str):
-            self._wght_path = s_wght
-            if isinstance(self.data, pd.DataFrame):
-                if s_wght in self.data.columns:
-                    if f_cols is None:
-                        self.f_cols = self.f_cols or list(self.data.columns)
-                        self.f_cols.remove(s_wght)  # type: ignore
-                    s_wght = self.data[s_wght]
-            else:
-                s_wght = load_data_array(s_wght)
-        self.weights = s_wght
-        # Assign the 'expose_classes' attribute.
+        data_array, src_path = self._parse_array_or_column_argument(
+            value=s_wght, data=self.data, name="s_wght"
+        )
+        self.weights = data_array
+        self._wght_path = src_path
+        if self.f_cols and src_path and src_path in self.f_cols:
+            self.f_cols.remove(src_path)  # type: ignore[arg-type]
+        # Assign the 'expose_classes' and 'expose_data_type' attributes.
         self.expose_classes = expose_classes
         self.expose_data_type = expose_data_type
         # Assign a random number generator.
         self.seed = seed
         self._rng = np.random.default_rng(seed)
 
+    @staticmethod
+    def _parse_data_argument(
+        data: Union[DataArray, str],
+    ) -> Tuple[DataArray, Optional[str]]:
+        """Parse 'data' instantiation argument.
+
+        Return the definitive 'data' array, and its source path if any.
+        """
+        # Case when an array is provided directly.
+        if isinstance(data, DATA_ARRAY_TYPES):
+            return data, None
+        # Case when an invalid type is provided.
+        if not isinstance(data, str):
+            raise TypeError(
+                f"'data' must be a data array or str, not '{type(data)}'."
+            )
+        # Case when a string is provided: treat it as a file path.
+        try:
+            array = load_data_array(data)
+        except Exception as exc:
+            raise ValueError(
+                "Error while trying to load main 'data' array from file."
+            ) from exc
+        return array, data
+
+    @staticmethod
+    def _parse_fcols_argument(
+        f_cols: Union[List[str], List[int], None],
+        data: DataArray,
+    ) -> Union[List[str], List[int], None]:
+        """Type and value-check 'f_cols' argument.
+
+        Return its definitive value or raise an exception.
+        """
+        # Case when 'f_cols' is None: optionally replace with list of names.
+        if f_cols is None:
+            if isinstance(data, pd.DataFrame):
+                return list(data.columns)
+            return f_cols
+        # Case when 'f_cols' has an invalid type.
+        if not isinstance(f_cols, (list, tuple, set)):
+            raise TypeError(
+                f"'f_cols' must be None or a list, nor '{type(f_cols)}'."
+            )
+        # Case when 'f_cols' is a list of str: verify and return it.
+        if all(isinstance(col, str) for col in f_cols):
+            if not isinstance(data, pd.DataFrame):
+                raise ValueError(
+                    "'f_cols' is a list of str but 'data' is not a DataFrame."
+                )
+            if set(f_cols).issubset(data.columns):
+                return f_cols.copy()
+            raise ValueError(
+                "Specified 'f_cols' is not a subset of 'data' columns."
+            )
+        # Case when 'f_cols' is a list of int: verify and return it.
+        if all(isinstance(col, int) for col in f_cols):
+            if max(f_cols) >= data.shape[1]:  # type: ignore
+                raise ValueError(
+                    "Invalid 'f_cols' indices given 'data' shape."
+                )
+            return f_cols.copy()
+        # Case when 'f_cols' has mixed or invalid internal types.
+        raise TypeError(
+            "'f_cols' should be a list of all-int or all-str values."
+        )
+
+    @staticmethod
+    def _parse_array_or_column_argument(
+        value: Union[DataArray, str, None],
+        data: DataArray,
+        name: str,
+    ) -> Tuple[Optional[DataArray], Optional[str]]:
+        """Parse an optional array-or-column argument ('target' or 's_wght').
+
+        Return the parsed value (an optional data array) and its source
+        path or column name when relevant (an optional string).
+        """
+        # Case of a data array or None value: return as-is.
+        if isinstance(value, DATA_ARRAY_TYPES) or value is None:
+            return value, None
+        # Case of an invalid type: raise.
+        if not isinstance(value, str):
+            raise TypeError(
+                f"'{name}' must be a data array or str, not '{type(value)}'."
+            )
+        # Case of a string matching a 'data' column name: return it.
+        if isinstance(data, pd.DataFrame) and value in data.columns:
+            return data[value], value
+        # Case of a string matching nothing.
+        if not os.path.isfile(value):
+            raise ValueError(
+                f"'{name}' does not match any 'data' column nor file path."
+            )
+        # Case of a string matching a filepath.
+        try:
+            array = load_data_array(value)
+        except Exception as exc:
+            raise ValueError(
+                f"Error while trying to load '{name}' data from file."
+            ) from exc
+        if isinstance(array, pd.DataFrame) and len(array.columns) == 1:
+            array = array.iloc[:, 0]
+        return array, value
+
     @property
     def feats(
         self,
diff --git a/declearn/dataset/examples/_mnist.py b/declearn/dataset/examples/_mnist.py
index a839d2b28a48328283ec5be459868421d639d7d5..030db9522ad3ceaf3b001b91262da27d53f6eb99 100644
--- a/declearn/dataset/examples/_mnist.py
+++ b/declearn/dataset/examples/_mnist.py
@@ -71,7 +71,10 @@ def _load_mnist_data(
     name = f"{name}-ubyte.gz"
     # Optionally download the gzipped file from the internet.
     if folder is None or not os.path.isfile(os.path.join(folder, name)):
-        data = _download_mnist_file(name, folder)
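+        # Try the historical LeCun URL first, and fall back to the public S3
+        # mirror if that download fails.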
+        try:
+            data = _download_mnist_file(name, folder, lecun=True)
+        except RuntimeError:  # pragma: no cover
+            data = _download_mnist_file(name, folder, lecun=False)
         data = gzip.decompress(data)
     # Otherwise, read its contents from a local copy.
     else:
@@ -88,13 +91,20 @@ def _load_mnist_data(
     return (array / 255).astype(np.single) if images else array
 
 
-def _download_mnist_file(name: str, folder: Optional[str]) -> bytes:
+def _download_mnist_file(
+    name: str,
+    folder: Optional[str],
+    lecun: bool = True,
+) -> bytes:
     """Download a MNIST source file and opt. save it in a given folder."""
     # Download the file in memory.
     print(f"Downloading MNIST source file {name}.")
-    reply = requests.get(
-        f"http://yann.lecun.com/exdb/mnist/{name}", timeout=300
+    base_url = (
+        "http://yann.lecun.com/exdb"
+        if lecun
+        else "https://ossci-datasets.s3.amazonaws.com"
     )
+    reply = requests.get(f"{base_url}/mnist/{name}", timeout=300)
     try:
         reply.raise_for_status()
     except requests.HTTPError as exc:
diff --git a/declearn/dataset/utils/_save_load.py b/declearn/dataset/utils/_save_load.py
index 90d936f09e9a549c3bc1617e617c1eb84a960456..efb0ef4a21fd7cf858a22f4e5a7fb2a2a7782f0c 100644
--- a/declearn/dataset/utils/_save_load.py
+++ b/declearn/dataset/utils/_save_load.py
@@ -22,11 +22,12 @@ import os
 from typing import Any, Union
 
 import numpy as np
-import pandas as pd  # type: ignore
+import pandas as pd
 from scipy.sparse import spmatrix  # type: ignore
 from sklearn.datasets import load_svmlight_file  # type: ignore
 
 from declearn.dataset.utils._sparse import sparse_from_file, sparse_to_file
+from declearn.typing import DataArray
 
 __all__ = [
     "load_data_array",
@@ -34,9 +35,6 @@ __all__ = [
 ]
 
 
-DataArray = Union[np.ndarray, pd.DataFrame, spmatrix]
-
-
 def load_data_array(
     path: str,
     **kwargs: Any,
diff --git a/declearn/dataset/utils/_split_classif.py b/declearn/dataset/utils/_split_classif.py
index 0e3dbc98d348141e27b52f80605f1eb0176cc2c2..7f254188b03e71c5dfab1e2e3b80047fefcb6f42 100644
--- a/declearn/dataset/utils/_split_classif.py
+++ b/declearn/dataset/utils/_split_classif.py
@@ -17,9 +17,11 @@
 
 """Utils to split a multi-category classification dataset into shards."""
 
-from typing import List, Literal, Optional, Tuple, Type, Union
+import functools
+from typing import Any, List, Literal, Optional, Tuple, Type, Union
 
 import numpy as np
+import scipy.stats  # type: ignore
 from scipy.sparse import csr_matrix, spmatrix  # type: ignore
 
 
@@ -31,9 +33,10 @@ __all__ = [
 def split_multi_classif_dataset(
     dataset: Tuple[Union[np.ndarray, spmatrix], np.ndarray],
     n_shards: int,
-    scheme: Literal["iid", "labels", "biased"],
+    scheme: Literal["iid", "labels", "dirichlet", "biased"],
     p_valid: float = 0.2,
     seed: Optional[int] = None,
+    **kwargs: Any,
 ) -> List[Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]]:
     """Split a classification dataset into (opt. heterogeneous) shards.
 
@@ -42,6 +45,9 @@ def split_multi_classif_dataset(
     - If "iid", split the dataset through iid random sampling.
     - If "labels", split into shards that hold all samples associated
       with mutually-exclusive target classes.
+    - If "dirichlet", split the dataset through random sampling using
+      label-wise shard-assignment probabilities drawn from a symmetrical
+      Dirichlet distribution, parametrized by an `alpha` parameter.
     - If "biased", split the dataset through random sampling according
       to a shard-specific random labels distribution.
 
@@ -53,13 +59,17 @@ def split_multi_classif_dataset(
         be a scipy sparse matrix, that will temporarily be cast to CSR.
     n_shards: int
         Number of shards between which to split the dataset.
-    scheme: {"iid", "labels", "biased"}
+    scheme: {"iid", "labels", "dirichlet", "biased"}
         Splitting scheme to use. In all cases, shards contain mutually-
         exclusive samples and cover the full dataset. See details above.
     p_valid: float, default=0.2
         Share of each shard to turn into a validation subset.
     seed: int or None, default=None
         Optional seed to the RNG used for all sampling operations.
+    **kwargs:
+        Additional hyper-parameters specific to the split scheme.
+        Exhaustive list of possible values:
+            - `alpha: float = 0.5` for `scheme="dirichlet"`
 
     Returns
     -------
@@ -80,6 +90,10 @@ def split_multi_classif_dataset(
         func = split_iid
     elif scheme == "labels":
         func = split_labels
+    elif scheme == "dirichlet":
+        func = functools.partial(
+            split_dirichlet, alpha=kwargs.get("alpha", 0.5)
+        )
     elif scheme == "biased":
         func = split_biased
     else:
@@ -157,7 +171,14 @@ def split_biased(
     n_shards: int,
     rng: np.random.Generator,
 ) -> List[Tuple[np.ndarray, np.ndarray]]:
-    """Split a dataset into shards with heterogeneous label distributions."""
+    """Split a dataset into shards with heterogeneous label distributions.
+
+    Use a normal distribution to draw the logits of each shard's label
+    distribution.
+
+    This approach is not drawn from the literature. We advise end-users to
+    use a Dirichlet split instead, which is probably better-grounded.
+    """
     classes = np.unique(target)
     index = np.arange(len(target))
     s_len = len(target) // n_shards
@@ -178,6 +199,36 @@ def split_biased(
     return split
 
 
+def split_dirichlet(
+    inputs: Union[np.ndarray, csr_matrix],
+    target: np.ndarray,
+    n_shards: int,
+    rng: np.random.Generator,
+    alpha: float = 0.5,
+) -> List[Tuple[np.ndarray, np.ndarray]]:
+    """Split a dataset into shards with heterogeneous label distributions.
+
+    Use a symmetrical Dirichlet(alpha) distribution to sample, for each
+    label, the proportions of its samples that are assigned to each shard.
+
+    This approach has notably been used by Sturluson et al. (2021).
+    FedRAD: Federated Robust Adaptive Distillation. arXiv:2112.01405 [cs.LG]
+    """
+    classes = np.unique(target)
+    # Draw per-label proportion of samples to assign to each shard.
+    process = scipy.stats.dirichlet(alpha=[alpha] * n_shards)
+    c_probs = process.rvs(size=len(classes), random_state=rng)
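+    # 'c_probs' has shape (n_classes, n_shards): each row sums to one and sets
+    # the assignment probabilities of that label's samples across shards.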
+    # Randomly assign label-wise samples to shards based on these.
+    shard_i = [[] for _ in range(n_shards)]  # type: List[List[int]]
+    for lab_i, label in enumerate(classes):
+        index = np.where(target == label)[0]
+        s_idx = rng.choice(n_shards, size=len(index), p=c_probs[lab_i])
+        for i in range(n_shards):
+            shard_i[i].extend(index[s_idx == i])
+    # Gather the actual sample shards.
+    return [(inputs[index], target[index]) for index in shard_i]
+
+
 def train_valid_split(
     inputs: np.ndarray,
     target: np.ndarray,
diff --git a/declearn/model/haiku/__init__.py b/declearn/model/haiku/__init__.py
index 1adb8f1f2d0d9ec2593a3dcacc1d25bd808070cf..596d4c5c25ac11b7bf6be54e61c7642dbca6af12 100644
--- a/declearn/model/haiku/__init__.py
+++ b/declearn/model/haiku/__init__.py
@@ -17,9 +17,21 @@
 
 """Haiku models interfacing tools.
 
-This submodule provides with a generic interface to wrap up
-any Haiku module instance that is to be trained
-through gradient descent.
+Haiku is a Google DeepMind library that provides with tools to build
+artificial neural networks backed by the JAX computation library. We
+selected it as a primary candidate to support JAX-backed models, mostly
+because of its simplicity, which leaves aside some components that
+DecLearn already provides (such as optimization algorithms).
+
+In July 2023, Haiku development was announced to be stalled as far as
+new features are concerned, in favor of Flax, a competing Google project.
+
+DecLearn is planned to add support for Flax at some point (building on the
+existing Haiku-oriented code, notably as far as Jax NumPy is concerned).
+In the meanwhile, this submodule enables running code that operates using
+haiku, which probably does not cover a lot of use cases, but is bound to
+keep working at least for a while, until Google decides to drop maintenance
+altogether.
 
 This module exposes:
 * HaikuModel: Model subclass to wrap haiku.Model objects
diff --git a/declearn/model/haiku/_vector.py b/declearn/model/haiku/_vector.py
index be4aedb74bd9b8aaa52e978668bf60fe48063233..3c97a67f9f01341e46fe8497cd151b9b8544de9a 100644
--- a/declearn/model/haiku/_vector.py
+++ b/declearn/model/haiku/_vector.py
@@ -17,8 +17,7 @@
 
 """JaxNumpyVector data arrays container."""
 
-import warnings
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Callable, Dict, List, Set, Tuple, Type, Union
 
 import jax
 import jax.numpy as jnp
@@ -153,19 +152,9 @@ class JaxNumpyVector(Vector):
 
     def sum(
         self,
-        axis: Optional[int] = None,
-        keepdims: bool = False,
     ) -> Self:
-        if isinstance(axis, int) or keepdims:
-            warnings.warn(  # pragma: no cover
-                "The 'axis' and 'keepdims' arguments of 'JaxNumpyVector.sum' "
-                "have been deprecated as of declearn v2.3, and will be "
-                "removed in version 2.5 and/or 3.0.",
-                DeprecationWarning,
-            )
         coefs = {
-            key: jnp.array(jnp.sum(val, axis=axis, keepdims=keepdims))
-            for key, val in self.coefs.items()
+            key: jnp.array(jnp.sum(val)) for key, val in self.coefs.items()
         }
         return self.__class__(coefs)
 
diff --git a/declearn/model/sklearn/_np_vec.py b/declearn/model/sklearn/_np_vec.py
index 40a8805ebd6010cbd12e1d142e966b9167a245a5..6ed6af7919a07f5e46b5b340acb1e5a594d1f6bb 100644
--- a/declearn/model/sklearn/_np_vec.py
+++ b/declearn/model/sklearn/_np_vec.py
@@ -17,8 +17,7 @@
 
 """NumpyVector data arrays container."""
 
-import warnings
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Tuple, Union
 
 import numpy as np
 from typing_extensions import Self  # future: import from typing (Py>=3.11)
@@ -121,20 +120,8 @@ class NumpyVector(Vector):
 
     def sum(
         self,
-        axis: Optional[int] = None,
-        keepdims: bool = False,
     ) -> Self:
-        if isinstance(axis, int) or keepdims:
-            warnings.warn(  # pragma: no cover
-                "The 'axis' and 'keepdims' arguments of 'NumpyVector.sum' "
-                "have been deprecated as of declearn v2.3, and will be "
-                "removed in version 2.5 and/or 3.0.",
-                DeprecationWarning,
-            )
-        coefs = {
-            key: np.array(np.sum(val, axis=axis, keepdims=keepdims))
-            for key, val in self.coefs.items()
-        }
+        coefs = {key: np.array(np.sum(val)) for key, val in self.coefs.items()}
         return self.__class__(coefs)
 
     def flatten(
diff --git a/declearn/model/tensorflow/_vector.py b/declearn/model/tensorflow/_vector.py
index ef0a0d17861ad50f8c6851252b5152af155f6ddb..146754298402788b76ac5a133b3dcb8ccd94c0a1 100644
--- a/declearn/model/tensorflow/_vector.py
+++ b/declearn/model/tensorflow/_vector.py
@@ -18,10 +18,7 @@
 """TensorflowVector data arrays container."""
 
 import warnings
-from typing import (
-    # fmt: off
-    Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
-)
+from typing import Any, Callable, Dict, List, Set, Tuple, Type, TypeVar, Union
 
 # fmt: off
 import numpy as np
@@ -279,27 +276,8 @@ class TensorflowVector(Vector):
 
     def sum(
         self,
-        axis: Optional[int] = None,
-        keepdims: bool = False,
     ) -> Self:
-        if keepdims or (axis is not None):
-            if any(  # pragma: no cover
-                isinstance(x, tf.IndexedSlices) for x in self.coefs.values()
-            ):
-                warnings.warn(  # pragma: no cover
-                    "Calling `TensorflowVector.sum()` with non-default "
-                    "arguments and tf.IndexedSlices coefficients might "
-                    "result in unexpected outputs, due to the latter "
-                    "being converted to their dense counterpart.",
-                    category=RuntimeWarning,
-                )
-            warnings.warn(  # pragma: no cover
-                "The 'axis' and 'keepdims' arguments of 'TensorflowVector.sum'"
-                " have been deprecated as of declearn v2.3, and will be "
-                "removed in version 2.5 and/or 3.0.",
-                DeprecationWarning,
-            )
-        return self.apply_func(tf.reduce_sum, axis=axis, keepdims=keepdims)
+        return self.apply_func(tf.reduce_sum)
 
     def __pow__(
         self,
diff --git a/declearn/model/tensorflow/utils/_gpu.py b/declearn/model/tensorflow/utils/_gpu.py
index a6f7f9f31bc9ce4d700a76fae5cc4e4e4497c20b..f9b86ccf6e30da6461bf928228a81c17952ae6f1 100644
--- a/declearn/model/tensorflow/utils/_gpu.py
+++ b/declearn/model/tensorflow/utils/_gpu.py
@@ -59,7 +59,7 @@ def select_device(
 
     Returns
     -------
-    device: tf.config.LogicalDevice
+    device:
         Selected device, usable as `tf.device` argument.
     """
     idx = 0 if idx is None else idx
@@ -106,14 +106,15 @@ def move_layer_to_device(
 
     Returns
     -------
-    layer: tf_keras.layers.Layer
+    layer:
         Copy of the input layer, with its weights backed on `device`.
     """
     config = tf_keras.layers.serialize(layer)
-    weights = layer.get_weights()
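+    # Only fetch weights if the layer has been built (unbuilt layers have none).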
+    weights = layer.get_weights() if layer.built else None
     with tf.device(device):
         layer = tf_keras.layers.deserialize(config)
-        layer.set_weights(weights)
+        if weights:
+            layer.set_weights(weights)
     return layer
 
 
diff --git a/declearn/model/torch/_vector.py b/declearn/model/torch/_vector.py
index de9447956b13fe1996615a56a60b9d4c33b85fdf..40fcc58ec26b6b7c938c6a665ac4d88f1b95aa0a 100644
--- a/declearn/model/torch/_vector.py
+++ b/declearn/model/torch/_vector.py
@@ -17,8 +17,7 @@
 
 """TorchVector data arrays container."""
 
-import warnings
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Callable, Dict, List, Set, Tuple, Type, Union
 
 import numpy as np
 import torch
@@ -214,20 +213,8 @@ class TorchVector(Vector):
 
     def sum(
         self,
-        axis: Optional[int] = None,
-        keepdims: bool = False,
     ) -> Self:
-        if isinstance(axis, int) or keepdims:
-            warnings.warn(  # pragma: no cover
-                "The 'axis' and 'keepdims' arguments of 'TorchVector.sum' "
-                "have been deprecated as of declearn v2.3, and will be "
-                "removed in version 2.5 and/or 3.0.",
-                DeprecationWarning,
-            )
-        coefs = {
-            key: val.sum(dim=axis, keepdims=keepdims)
-            for key, val in self.coefs.items()
-        }
+        coefs = {key: val.sum() for key, val in self.coefs.items()}
         return self.__class__(coefs)
 
     def flatten(
diff --git a/declearn/typing.py b/declearn/typing.py
index c7b1c6429a093c14fc6110de2c8acbe220668cc8..916b1a718d8177e8ced2d165936cb176ca59070d 100644
--- a/declearn/typing.py
+++ b/declearn/typing.py
@@ -20,12 +20,16 @@
 from abc import ABCMeta, abstractmethod
 from typing import Any, Dict, List, Optional, Protocol, Tuple, Union
 
+import numpy as np
+import pandas as pd
 from numpy.typing import ArrayLike
+from scipy.sparse import spmatrix  # type: ignore
 from typing_extensions import Self  # future: import from typing (Py>=3.11)
 
 
 __all__ = [
     "Batch",
+    "DataArray",
     "SupportsConfig",
 ]
 
@@ -44,6 +48,15 @@ This type-hint designates (inputs, labels, weights) inputs, where:
 """  # this is rendered as a docstring for `Batch` in the docs
 
 
+DataArray = Union[np.ndarray, pd.DataFrame, pd.Series, spmatrix]
+"""Type-annotation alias for a union of data type structures.
+
+This alias covers types supported by [declearn.dataset.utils.save_data_array][]
+and its counterpart [declearn.dataset.utils.load_data_array][], and is hence
+used to annotate some dataset-interfacing tools under [declearn.dataset][].
+"""
+
+
 class SupportsConfig(Protocol, metaclass=ABCMeta):
     """Protocol for type annotation of objects with get/from_config methods.
 
diff --git a/declearn/version.py b/declearn/version.py
index ed20f25a5ca69ae11b600b2e902c4fc82afba255..93721d5083e5fbea3219d4d9eb0947c2adb45be7 100644
--- a/declearn/version.py
+++ b/declearn/version.py
@@ -17,5 +17,5 @@
 
 """DecLearn version information, as hard-coded constants."""
 
-VERSION = "2.4.0"
+VERSION = "2.5.0"
 """Version information of the installed DecLearn package."""
diff --git a/docs/release-notes/SUMMARY.md b/docs/release-notes/SUMMARY.md
index 568e1343869a8b2bcc974c538c5f1e7371c0fc10..6da663077a6734430c23d628a2ff73841170ab6b 100644
--- a/docs/release-notes/SUMMARY.md
+++ b/docs/release-notes/SUMMARY.md
@@ -1,3 +1,4 @@
+- [v2.5.0](v2.5.0.md)
 - [v2.4.0](v2.4.0.md)
 - [v2.3.2](v2.3.2.md)
 - [v2.3.1](v2.3.1.md)
diff --git a/docs/release-notes/v2.5.0.md b/docs/release-notes/v2.5.0.md
new file mode 100644
index 0000000000000000000000000000000000000000..27a4e923066459157ccb6334ef1279007ea96467
--- /dev/null
+++ b/docs/release-notes/v2.5.0.md
@@ -0,0 +1,91 @@
+# declearn v2.5.0
+
+Released: 13/05/2024
+
+## Release Highlight: Secure Aggregation
+
+### Overview
+
+This new version of DecLearn is mostly about enabling the use of Secure
+Aggregation (also known as SecAgg), _i.e._ methods that enable aggregating
+client-emitted information without revealing said information to the server
+in charge of this aggregation.
+
+DecLearn now implements both a generic API for SecAgg and a couple of practical
+solutions that are ready for use. This makes SecAgg easily usable as part of
+the existing federated learning process, and extensible by advanced users or
+researchers who would like to use or test their own method and/or setup.
+
+New features are mostly implemented under the new `declearn.secagg` submodule,
+with further changes to `declearn.main.FederatedClient` and `FederatedServer`
+integrating these features into the main process.
+
+### Usage and scope
+
+Setting up SecAgg is meant to be straightforward:
+
+- The server and clients agree on a SecAgg method and, when required, some
+  hyper-parameters in advance.
+- Clients need to hold a private Ed25519 identity key and share the associated
+  public key with all other clients in advance, so that the SecAgg setup can
+  include the verification that ephemeral-key-agreement information comes from
+  trusted peers. _(This is mandatory in current implementations, but may be
+  removed in custom setup implementations.)_
+- The server and clients must simply pass an additional `secagg` keyword
+  argument when instantiating their `FederatedServer` or `FederatedClient`,
+  which can take the form of a class, a dict or the path to a TOML file.
+- _Voilà!_
+
+At the moment, SecAgg is used to secure the aggregation of model parameter
+updates, optimizer auxiliary variables and evaluation metrics, as well as
+some metadata from training and evaluation rounds. In the future, we plan
+to cover metadata queries and (yet-to-be-implemented) federated analytics.
+
+### Available algorithms
+
+At the moment, DecLearn provides with the following SecAgg algorithms:
+
+- Masking-based SecAgg (`declearn.secagg.masking`), that uses pseudo-random
+  number generators (PRNG) to generate masks over a finite integer field so
+  that the sum of clients' masks is known to be zero.
+    - This is based on
+      [Bonawitz et al., 2016](https://dl.acm.org/doi/10.1145/3133956.3133982).
+    - The setup that produces pairwise PRNG seeds is conducted using the
+      [X3DH](https://www.signal.org/docs/specifications/x3dh/) protocol.
+    - This solution has very limited computation and communication overhead
+      and should be considered the default SecAgg solution with DecLearn.
+
+- Joye-Libert sum-homomorphic encryption (`declearn.secagg.joye_libert`), that
+  uses actual encryption, a modified summation operator, and aggregate-decryption
+  primitives that operate on a large biprime-defined integer field.
+    - This is based on
+      [Joye & Libert, 2013](https://marcjoye.github.io/papers/JL13aggreg.pdf).
+    - The setup that computes the public key as a sum of arbitrary private keys
+      involves the [X3DH](https://www.signal.org/docs/specifications/x3dh/)
+      protocol as well as
+      [Shamir Secret Sharing](https://dl.acm.org/doi/10.1145/359168.359176).
+    - This solution has a high computation and communication overhead. It is
+      not really suitable for models with many parameters (including few-layer
+      artificial neural networks).
+
+### Documentation
+
+In addition to the extensive in-code documentation, a guide on the SecAgg
+features was added to the user documentation, which may be found
+[here](../user-guide/secagg.md). If anything is not as clear as you would
+hope, do let us know by opening a GitLab or GitHub issue, or by dropping
+an e-mail to the package maintainers!
+
+## Other changes
+
+A few minor changes are shipped with this new release in addition to the new
+SecAgg features:
+
+- `InMemoryDataset` backend code was refactored and made more robust, and its
+  documentation was improved for readability purposes.
+- `declearn.typing.DataArray` was added as an alias.
+- `declearn.dataset.utils.split_multi_classif_dataset` was enhanced to support
+  a new scheme based on Dirichlet allocation. Unit tests were also added for
+  both this scheme and existing ones.
+- Unit tests for both `FederatedClient` and `FederatedServer` were added.
+- Deprecated keyword arguments of `Vector.sum` were removed, as previously planned.
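+
+As an illustration, here is a minimal sketch of using the new Dirichlet-based
+splitting scheme (the toy dataset below is made up for the example; `alpha` is
+an optional keyword argument specific to that scheme):
+
+```python
+import numpy as np
+
+from declearn.dataset.utils import split_multi_classif_dataset
+
+# Build a toy classification dataset: 1000 samples, 16 features, 4 classes.
+rng = np.random.default_rng(seed=0)
+dataset = (rng.normal(size=(1000, 16)), rng.choice(4, size=1000))
+
+# Split it into 5 shards with Dirichlet-drawn label heterogeneity.
+shards = split_multi_classif_dataset(
+    dataset, n_shards=5, scheme="dirichlet", seed=0, alpha=0.3
+)
+# Each shard is a ((x_train, y_train), (x_valid, y_valid)) tuple.
+```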
diff --git a/docs/user-guide/SUMMARY.md b/docs/user-guide/SUMMARY.md
index 7ce0e34f4c5bd1eb0610e2dfb4e0c7f380a5edfd..865118d0c1c485e2fcbed0dcf86f34eafd5a0916 100644
--- a/docs/user-guide/SUMMARY.md
+++ b/docs/user-guide/SUMMARY.md
@@ -4,3 +4,4 @@
 - [Hands-on usage](./usage.md)
 - [Guide to the Optimizer API](./optimizer.md)
 - [Local Differential Privacy capabilities](./local_dp.md)
+- [Secure Aggregation capabilities](./secagg.md)
diff --git a/docs/user-guide/index.md b/docs/user-guide/index.md
index 92a2b040d1d651f16543fec069c87b6d62c05733..362050cb74125820c35e2c0302ee6458b9aec886 100644
--- a/docs/user-guide/index.md
+++ b/docs/user-guide/index.md
@@ -16,3 +16,5 @@ This guide is structured this way:
     API, design principles and practical how-tos of the declearn Optimizer.
 - [Local Differential Privacy capabilities](./local_dp.md):<br/>
     Description of the local-DP features of declearn.
+- [Secure Aggregation capabilities](./secagg.md):<br/>
+    Description of the SecAgg features of declearn.
diff --git a/docs/user-guide/package.md b/docs/user-guide/package.md
index 76c0a32b03e45aa5d128958d5a308de135b1a00d..ed6eda864f604a20cd34f3e491512204069bd652 100644
--- a/docs/user-guide/package.md
+++ b/docs/user-guide/package.md
@@ -22,6 +22,8 @@ The package is organized into the following submodules:
   &emsp; Model interfacing API and implementations.
 - `optimizer`:<br/>
   &emsp; Framework-agnostic optimizer and algorithmic plug-ins API and tools.
+- `secagg`:<br/>
+  &emsp; Secure Aggregation API, methods and utils.
 - `typing`:<br/>
   &emsp; Type hinting utils, defined and exposed for code readability purposes.
 - `utils`:<br/>
@@ -198,6 +200,64 @@ You may learn more about our (non-abstract) `Optimizer` API by reading our
     - `declearn.dataset.torch.TorchDataset`
 - Extend: use `declearn.utils.register_type(group="Dataset")`.
 
+### Secure Aggregation (SecAgg)
+
+#### `SecaggConfigClient`
+- Import: `declearn.secagg.api.SecaggConfigClient`
+- Object: Set up Secure Aggregation based on wrapped configuration parameters.
+- Usage: Set up an `Encrypter` (see below) instance based on parameters and a
+  server-emitted setup request, jointly with other clients.
+- Examples:
+    - `declearn.secagg.joye_libert.JoyeLibertSecaggConfigClient`
+    - `declearn.secagg.masking.MaskingSecaggConfigClient`
+- Extend:
+    - Simply inherit from `SecaggConfigClient` (registration is automated).
+    - To avoid it, use `class MySetup(SecaggConfigClient, register=False)`.
+
+#### `SecaggConfigServer`
+- Import: `declearn.secagg.api.SecaggConfigServer`
+- Object: Set up Secure Aggregation based on wrapped configuration parameters.
+- Usage: Set up a `Decrypter` (see below) instance based on parameters and in
+  interaction with (a subset of) clients.
+- Examples:
+    - `declearn.secagg.joye_libert.JoyeLibertSecaggConfigServer`
+    - `declearn.secagg.masking.MaskingSecaggConfigServer`
+- Extend:
+    - Simply inherit from `SecaggConfigServer` (registration is automated).
+    - To avoid it, use `class MySetup(SecaggConfigServer, register=False)`.
+
+#### `Encrypter`
+- Import: `declearn.secagg.api.Encrypter`
+- Object: Encrypt values that need secure aggregation.
+- Usage: Encrypt shared data, typically packed as a `Message` or `Aggregate`.
+- Examples:
+    - `declearn.secagg.joye_libert.JoyeLibertEncrypter`
+    - `declearn.secagg.masking.MaskingEncrypter`
+- Extend: there is no type-registration, simply implement a subclass.
+
+#### `Decrypter`
+- Import: `declearn.secagg.api.Decrypter`
+- Object: Decrypt aggregated values to finalize secure aggregation.
+- Usage: Decrypt shared data, typically packed as a `Message` or
+  `SecureAggregate`.
+- Examples:
+    - `declearn.secagg.joye_libert.JoyeLibertDecrypter`
+    - `declearn.secagg.masking.MaskingDecrypter`
+- Extend: there is no type-registration, simply implement a subclass.
+
+#### `SecureAggregate`
+- Import: `declearn.secagg.api.SecureAggregate`
+- Object: Wrap up encrypted data from an `Aggregate` (e.g. `ModelUpdates`,
+  `AuxVar` or `MetricState` instance) and enable their aggregation.
+- Usage: Used by `Encrypter` and `Decrypter` to wrap up encrypted data.
+- Examples:
+    - `declearn.secagg.joye_libert.JlsAggregate`
+    - `declearn.secagg.masking.MaskedAggregate`
+- Extend:
+    - Simply inherit from `SecureAggregate` (registration is automated,
+      and is about making the class JSON-serializable).
+    - To avoid it, use `class MyClass(SecureAggregate, register=False)`.
+
 ## Full API Reference
 
 The full API reference, which is generated automatically from the code's
diff --git a/docs/user-guide/secagg.md b/docs/user-guide/secagg.md
new file mode 100644
index 0000000000000000000000000000000000000000..356219774b7869f83dd074cb6380acd90e90cfc9
--- /dev/null
+++ b/docs/user-guide/secagg.md
@@ -0,0 +1,319 @@
+# Secure Aggregation
+
+## Overview
+
+### What is Secure Aggregation?
+
+Secure Aggregation (often and hereafter abbreviated as SecAgg) is a generic
+term to describe methods that enable aggregating client-emitted information
+without revealing said information to the server in charge of this aggregation.
+In other words, SecAgg is about computing a public aggregate of private values
+in a secure way, limiting the amount of trust put into the server.
+
+Various methods have been proposed in the literature, which may use homomorphic
+encryption, one-time pads, pseudo-random masks, multi-party computation methods
+and so on. These methods come with various costs (both in terms of computation
+overhead and communication costs, usually increasing messages' size and/or
+frequency), and various features (_e.g._ some tolerate a given amount of
+information loss due to clients dropping from the process; some are better
+suited to certain security settings than others; some require a more involved
+setup...).
+
+### General capabilities
+
+DecLearn implements both a generic API for SecAgg and some practical solutions
+that are ready for use. In the current state of things, however, some
+important assumptions and limitations apply:
+
+- The current SecAgg implementations require clients to generate Ed25519
+  identity keys and share the associated public keys with other clients prior
+  to using DecLearn. The API also leans towards this requirement. In practice,
+  it would not be difficult to make it so that the server distributes clients'
+  public keys across the network, but we believe that the incurred loss in
+  security partially defeats the purpose of SecAgg, hence we prefer to provide
+  with the current behavior (that requires sharing keys via distinct channels)
+  and leave it up to end-users to set up alternatives if they want to.
+- There is no resilience to clients dropping from the process. If a client was
+  supposed to participate in a round but does not send any information, then
+  some new setup and secure aggregation would need to be performed. This may
+  change in the future, but is on par with the current limitations of the
+  framework.
+- There are no countermeasures to dishonest clients, i.e. there is no effort
+  put in verifying that outputs are coherent. This is not specific to SecAgg
+  either, and is again something that will hopefully be tackled in future
+  versions.
+
+### Details and caveats
+
+At the moment, the SecAgg API and its integration as part of the main federated
+learning orchestration classes has the following characteristics:
+
+- SecAgg is jointly parametrized by the server and clients, that must use
+  coherent parameters.
+    - The choice of using SecAgg and of a given SecAgg method must be the same
+      across all clients and the server, otherwise an error is early-raised.
+    - Some things are up to clients, notably the specification of identity keys
+      used to set up short-lived secrets, and more generally any parameter that
+      cannot be verified to be trustworthy if the server sets it up.
+    - Some things are up to the server, notably the choice of quantization
+      hyper-parameters, that is coupled with details on the expected magnitude
+      of shared quantities (gradients, metrics...).
+- SecAgg is set up anew every time the set of clients participating in a
+  training or validation round changes.
+    - Whenever participating clients do not match those of the previous round,
+      a SecAgg setup round is triggered so as to set up controllers anew across
+      participating clients.
+    - This is designed to guarantee that controllers are properly aligned and
+      there is no aggregation of mismatching encrypted values.
+    - This may however not be as optimal as theoretically-achievable in terms
+      of setup communication and computation costs.
+- SecAgg is used to protect training and evaluation rounds' results.
+    - Model updates, optimizer auxiliary variables and metrics are covered.
+    - Some metadata are secure-aggregated (number of steps), some are discarded
+      (number of epochs), some remain cleartext but are reduced (time spent for
+      computations is sent in cleartext and max-aggregated).
+- SecAgg does not (yet) cover other computations.
+    - Metadata queries are not secured (meaning the number of data samples may
+      be sent in cleartext).
+    - The topic of protecting specific or arbitrary quantities as part of
+      processes implemented by subclassing the current main classes remains
+      open.
+- SecAgg of `Aggregate`-inheriting objects (notably `ModelUpdates`, `AuxVar`
+  and `MetricState` instances) is based on their `prepare_for_secagg` method,
+  that may not always be defined. It is up to end-users to ensure that the
+  components they use (and custom components they might add) are properly
+  made compatible with SecAgg.
+
+## How to setup and use SecAgg
+
+From the end-user perspective, setting up and using SecAgg requires:
+
+- Having generated and shared clients' (long-lived) identity keys across
+  trusted peers prior to running DecLearn.
+- Having all peers specify coherent SecAgg parameters, that are passed to
+  the main `FederatedClient` and `FederatedServer` classes at instantiation.
+- Ensuring that the `Aggregator`, `OptiModule` plug-ins and `Metrics` used
+  for the experiment are compatible with SecAgg.
+- _Voilà!_
+
+### Available SecAgg algorithms
+
+At the moment, DecLearn provides with the following SecAgg algorithms:
+
+- Masking-based SecAgg (`declearn.secagg.masking`), that uses pseudo-random
+  number generators (PRNG) to generate masks over a finite integer field so
+  that the sum of clients' masks is known to be zero.
+    - This is based on
+      [Bonawitz et al., 2016](https://dl.acm.org/doi/10.1145/3133956.3133982).
+    - The setup that produces pairwise PRNG seeds is conducted using the
+      [X3DH](https://www.signal.org/docs/specifications/x3dh/) protocol.
+    - This solution has very limited computation and communication overhead
+      and should be considered the default SecAgg solution with DecLearn.
+
+- Joye-Libert sum-homomorphic encryption (`declearn.secagg.joye_libert`), that
+  uses actual encryption, a modified summation operator, and aggregate-decryption
+  primitives that operate on a large biprime-defined integer field.
+    - This is based on
+      [Joye & Libert, 2013](https://marcjoye.github.io/papers/JL13aggreg.pdf).
+    - The setup that computes the public key as a sum of arbitrary private keys
+      involves the [X3DH](https://www.signal.org/docs/specifications/x3dh/)
+      protocol as well as
+      [Shamir Secret Sharing](https://dl.acm.org/doi/10.1145/359168.359176).
+    - This solution has a high computation and communication overhead. It is
+      not really suitable for models with many parameters (including few-layer
+      artificial neural networks).
+
+### Hands-on example
+
+If we use the [MNIST example](https://gitlab.inria.fr/magnet/declearn/declearn2/-/tree/develop/examples/mnist/)
+implemented via Python scripts and want to use the DecLearn-provided
+masking-based algorithm for SecAgg (see above), we merely have to apply the
+following modifications:
+
+**1. Generate client Ed25519 identity keys**
+
+This may be done using dedicated tools, but here is a script to merely
+generate and dump identity keys to files on a local computer:
+
+```python
+from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey
+from cryptography.hazmat.primitives import serialization
+from declearn.secagg.utils import IdentityKeys
+
+# Generate some private Ed25519 keys and gather their public counterparts.
+n_clients = 5  # adjust to the number of clients
+private_keys = [Ed25519PrivateKey.generate() for _ in range(n_clients)]
+public_keys = [key.public_key() for key in private_keys]
+
+# Export all public keys as a single file with custom format.
+IdentityKeys(private_keys[0], trusted=public_keys).export_trusted_keys_to_file(
+    "trusted_public.keys"
+)
+
+# Export private keys as PEM files without password protection.
+for idx, key in enumerate(private_keys):
+    dat = key.private_bytes(
+        encoding=serialization.Encoding.PEM,
+        format=serialization.PrivateFormat.OpenSSH,
+        encryption_algorithm=serialization.NoEncryption(),
+    )
+    with open(f"private_{idx}.pem", "wb") as file:
+        file.write(dat)
+```
+
+**2. Write client-side SecAgg config**
+
+Add the following code as part of the client routine to set up and run the
+federated learning process. Paths and parameters may and should be adjusted
+for practical use cases.
+
+```python
+from declearn.secagg import parse_secagg_config_client
+from declearn.secagg.utils import IdentityKeys
+
+# Add this as part of the `run_client` function in the `run_client.py` file.
+
+client_idx = int(client_name.rsplit("_", 1)[-1])
+secagg = parse_secagg_config_client(
+    secagg_type="masking",
+    id_keys=IdentityKeys(
+        prv_key=f"client_{client_idx}.pem",
+        trusted="trusted_public.keys",
+    )
+)
+# Alternatively, write `secagg` as the dict of previous kwargs,
+# or as a `declearn.secagg.masking.MaskingSecaggConfigClient` instance.
+
+# Overload the `FederatedClient` instantiation in step (5) of the function.
+
+client = declearn.main.FederatedClient(
+    # ... what is already there
+    secagg=secagg,
+)
+```
+
+**3. Write server-side SecAgg config**
+
+Add the following code as part of the server routine to set up and run the
+federated learning process. Parameters may and should be adjusted to practical
+use cases.
+
+```python
+from declearn.secagg import parse_secagg_config_server
+
+# Add this as part of the `run_server` function in the `run_server.py` file.
+
+secagg = parse_secagg_config_server(
+    secagg_type="masking",
+    # You may tune hyper-parameters, controlling values' quantization.
+)
+# Alternatively, write `secagg` as the dict of previous kwargs,
+# or as a `declearn.secagg.masking.MaskingSecaggConfigServer` instance.
+
+# Overload the `FederatedServer` instantiation in step (4) of the function.
+
+server = declearn.main.FederatedServer(
+    # ... what is already there
+    secagg=secagg,
+)
+```
+
+Note that if the configurations do not match (whether as to the use of SecAgg
+or not, the choice of algorithm, its hyper-parameters, or the validity of
+trusted identity keys), an error will be raised at some (early) point when
+attempting to run the Federated Learning process.
+
+## How to implement a new SecAgg method
+
+The API for SecAgg relies on a number of abstractions, which may be divided into
+two categories:
+
+- backend controllers that provide with primitives to encrypt, aggregate
+  and decrypt values;
+- user-end controllers that parse configuration parameters and implement
+  a setup protocol to instantiate backend controllers.
+
+As such, it is possible to write a new setup for existing controllers; but in
+general, adding a new SecAgg algorithm to DecLearn will involve writing it all
+up. The first category is somewhat decoupled from the rest of the framework (it
+only needs to know what the `Vector` and `Aggregate` data structures are, and
+how to operate on them), while the second is very much coupled with the network
+communication and messaging APIs.
+
+### `Encrypter` and `Decrypter` controllers
+
+Primitives for the encryption of private values, aggregation of encrypted
+values and decryption of an aggregated value are to be implemented by a pair
+of `Encrypter` and `Decrypter` subclasses.
+
+These classes may (and often will) use any form of (private) time index, that
+is not required to be shared with encrypted values, as values are assumed to be
+encrypted in the same order by each and every client and decrypted in that same
+order by the server once aggregated.
+
+`Encrypter` has two abstract methods that need implementing:
+
+- `encrypt_uint`, that encrypts a scalar uint value (that may arise from the
+  quantization of a float value). It is called when encrypting floats, numpy
+  arrays and declearn `Vector` or `Aggregate` instances.
+- `wrap_into_secure_aggregate`, that is called to wrap up encrypted values
+  and other metadata and information from an `Aggregate` instance into a
+  `SecureAggregate` one. This should in most cases merely correspond to
+  instantiating the proper `SecureAggregate` subclass with a mix of input
+  arguments to the method and attributes from the `Encrypter` instance.
+
+`Decrypter` has an abstract class attribute and two abstract methods that
+need implementing:
+
+- `sum_encrypted`, that aggregates a list of two or more encrypted values
+  into a single one (in a way that makes their aggregate
+  decryptable into the sum of cleartext values).
+- `decrypt_uint`, that decrypts an input into a scalar uint value. It is called
+  when decrypting inputs into any supported type, and is the counterpart to the
+  `Encrypter.encrypt_uint` method.
+- `secure_aggregate_cls`, that is a class attribute that is merely the type of
+  the `SecureAggregate` subclass emitted by the paired `Encrypter`'s
+  `wrap_into_secure_aggregate` method.
+
+`SecureAggregate` is a third class that is used by the former two, and usually
+not directly accessed by end-users. It is an `Aggregate`-like wrapper for
+encrypted counterparts to `Aggregate` objects. The only abstract method that
+needs defining is `aggregate_encrypted`, which should mostly be the same as
+`Decrypter.sum_encrypted`. Subclasses may also embark additional metadata
+about the SecAgg algorithm's parameters, and conduct associate verification
+of coherence across peers.
+
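+The sketch below illustrates the overall structure of such a pair of
+subclasses. It is purely indicative: the method signatures used here are
+assumptions made for the example, not the actual abstract API.
+
+```python
+from declearn.secagg.api import Decrypter, Encrypter
+
+
+class MyEncrypter(Encrypter):
+    """Illustrative encryption controller (structure only)."""
+
+    def encrypt_uint(self, value):
+        # Encrypt a scalar uint value (e.g. arising from float quantization).
+        raise NotImplementedError
+
+    def wrap_into_secure_aggregate(self, *args, **kwargs):
+        # Wrap encrypted values and metadata from an 'Aggregate' into the
+        # paired 'SecureAggregate' subclass (not shown here).
+        raise NotImplementedError
+
+
+class MyDecrypter(Decrypter):
+    """Illustrative decryption controller (structure only)."""
+
+    secure_aggregate_cls = ...  # type of the paired 'SecureAggregate' subclass
+
+    def sum_encrypted(self, values):
+        # Aggregate a list of encrypted values into a single one.
+        raise NotImplementedError
+
+    def decrypt_uint(self, value):
+        # Decrypt an input back into a scalar uint value.
+        raise NotImplementedError
+```
+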
+### `SecaggConfigClient` and `SecaggConfigServer` endpoints
+
+Routines to set up matching `Encrypter` and `Decrypter` instances across a
+federated network of peers are to be implemented by a pair of
+`SecaggConfligClient` and `SecaggConfligServer` subclasses. In addition, a
+dedicated `SecaggSetupQuery` subclass (itself a `Message`) should be defined.
+
+These classes define a setup that is bound to be initiated by the server, that
+emits a `SecaggSetupQuery` message triggering the client's call to the setup
+routine. After that, any number of network communication exchanges may be run,
+depending on the setup being implemented.
+
+Both `SecaggConfigClient` and `SecaggConfigServer` subclasses must be decorated
+as `dataclasses.dataclass`, so as to benefit from TOML-parsing capabilities.
+They are automatically type-registered (which may be prevented by passing the
+`register=False` parameter at inheritance).
+
+`SecaggConfigClient` has an abstract class attribute and an abstract method:
+
+- `secagg_type`, that is a string class attribute that must be unique across
+  subclasses (for type-registration) and match that of the server-side class.
+- `setup_encrypter`, that takes a `NetworkClient` and a received serialized
+  `SecaggSetupQuery` message, and conducts any steps towards setting up and
+  returning an `Encrypter` instance.
+
+`SecaggConfigServer` has an abstract class attribute and two abstract methods:
+
+- `secagg_type`, that is a string class attribute that must be unique across
+  subclasses (for type-registration) and match that of the client-side class.
+- `prepare_secagg_setup_query`, that returns a `SecaggSetupQuery` to be sent
+  to (the subset of) clients that are meant to participate in the setup.
+- `finalize_secagg_setup`, that takes a `NetworkServer` and an optional set
+  of client names, is called right after sending the setup query to (these)
+  clients and may thereafter conduct any steps towards setting up and returning
+  a `Decrypter` instance.
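+
+To make the above more concrete, here is a bare-bones structural sketch. It is
+purely indicative: method signatures, `async` usage and parameter names are
+assumptions made for the example, not the actual abstract API.
+
+```python
+import dataclasses
+
+from declearn.secagg.api import SecaggConfigClient, SecaggConfigServer
+
+
+@dataclasses.dataclass
+class MySecaggConfigClient(SecaggConfigClient):
+    """Illustrative client-side SecAgg setup config (structure only)."""
+
+    secagg_type = "my-secagg"  # must match the server-side class
+
+    async def setup_encrypter(self, netwk, query):
+        # Run any required exchanges over 'netwk' based on the received
+        # setup 'query', then build and return an 'Encrypter' instance.
+        raise NotImplementedError
+
+
+@dataclasses.dataclass
+class MySecaggConfigServer(SecaggConfigServer):
+    """Illustrative server-side SecAgg setup config (structure only)."""
+
+    secagg_type = "my-secagg"  # must match the client-side class
+
+    def prepare_secagg_setup_query(self):
+        # Return a 'SecaggSetupQuery' message to be sent to clients.
+        raise NotImplementedError
+
+    async def finalize_secagg_setup(self, netwk, clients=None):
+        # Interact with (the subset of) clients over 'netwk', then
+        # build and return a 'Decrypter' instance.
+        raise NotImplementedError
+```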
diff --git a/pyproject.toml b/pyproject.toml
index 58af1de0fafb70b92141a2de4740a5738b742cd3..68f9dcb10c96f92ee50eebbd85b4834ca38820f0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,7 +8,7 @@ requires = [
 
 [project]
 name = "declearn"
-version = "2.4.0"
+version = "2.5.0"
 description = "Declearn - a python package for private decentralized learning."
 readme = "README.md"
 requires-python = ">=3.8"
diff --git a/test/dataset/test_split_multi_classif.py b/test/dataset/test_split_multi_classif.py
new file mode 100644
index 0000000000000000000000000000000000000000..39659fe3831a29f9ab012621c35e342003bea687
--- /dev/null
+++ b/test/dataset/test_split_multi_classif.py
@@ -0,0 +1,224 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for 'declearn.dataset.utils.split_multi_classif_dataset'."""
+
+from typing import List, Tuple, Type, Union
+
+import numpy as np
+import pytest
+from scipy.sparse import coo_matrix, spmatrix  # type: ignore
+from scipy.stats import chi2_contingency  # type: ignore
+
+from declearn.dataset.utils import split_multi_classif_dataset
+
+
+Array = Union[np.ndarray, spmatrix]
+
+
+@pytest.fixture(name="dataset", scope="module")
+def dataset_fixture(
+    sparse: bool,
+) -> Tuple[Array, np.ndarray]:
+    """Fixture providing with a random classification dataset."""
+    rng = np.random.default_rng(seed=0)
+    x_dat = rng.normal(size=(400, 32))
+    y_dat = rng.choice(4, size=400)
+    if sparse:
+        return coo_matrix(x_dat), y_dat
+    return x_dat, y_dat
+
+
+@pytest.mark.parametrize(
+    "sparse", [False, True], ids=["dense", "sparse"], scope="module"
+)
+class TestSplitMultiClassifDataset:
+    """Unit tests for `split_multi_classif_dataset`."""
+
+    @classmethod
+    def assert_expected_shard_shapes(
+        cls,
+        dataset: Tuple[Array, np.ndarray],
+        shards: List[
+            Tuple[Tuple[Array, np.ndarray], Tuple[Array, np.ndarray]]
+        ],
+        n_shards: int,
+        p_valid: float,
+    ) -> None:
+        """Verify that shards match expected shapes."""
+        assert isinstance(shards, list) and len(shards) == n_shards
+        n_samples = 0
+        for shard in shards:
+            cls._assert_valid_shard(
+                shard,
+                n_feats=dataset[0].shape[1],
+                n_label=len(np.unique(dataset[1])),
+                p_valid=p_valid,
+                x_type=type(dataset[0]),
+            )
+            n_samples += shard[0][0].shape[0] + shard[1][0].shape[0]
+        assert n_samples == dataset[0].shape[0]
+
+    @staticmethod
+    def _assert_valid_shard(
+        shard: Tuple[Tuple[Array, np.ndarray], Tuple[Array, np.ndarray]],
+        n_feats: int = 32,
+        n_label: int = 4,
+        p_valid: float = 0.2,
+        x_type: Union[Type[np.ndarray], Type[spmatrix]] = np.ndarray,
+    ) -> None:
+        """Assert that a given dataset shard matches expected specs."""
+        # Unpack arrays and verify that their types are coherent.
+        (x_train, y_train), (x_valid, y_valid) = shard
+        assert isinstance(x_train, x_type)
+        assert isinstance(x_valid, x_type)
+        assert isinstance(y_train, np.ndarray)
+        assert isinstance(y_valid, np.ndarray)
+        # Assert that array shapes match expectations.
+        assert x_train.ndim == x_valid.ndim == 2
+        assert x_train.shape[0] == y_train.shape[0]
+        assert x_valid.shape[0] == y_valid.shape[0]
+        assert x_train.shape[1] == x_valid.shape[1] == n_feats
+        assert y_train.ndim == y_valid.ndim == 1
+        # Assert that labels have proper values.
+        labels = list(range(n_label))
+        assert np.all(np.isin(y_train, labels))
+        assert np.all(np.isin(y_valid, labels))
+        # Assert that train/valid partition matches expectation.
+        s_valid = x_valid.shape[0] / (x_train.shape[0] + x_valid.shape[0])
+        assert abs(p_valid - s_valid) <= 0.02
+
+    @staticmethod
+    def get_label_counts(
+        shards: List[
+            Tuple[Tuple[Array, np.ndarray], Tuple[Array, np.ndarray]]
+        ],
+        n_label: int = 4,
+    ) -> np.ndarray:
+        """Return the shard-wise subset-wise label counts, stacked together."""
+        counts = [
+            np.bincount(y, minlength=n_label)
+            for (_, y_train), (_, y_valid) in shards
+            for y in (y_train, y_valid)
+        ]
+        return np.stack(counts)
+
+    def test_scheme_iid(
+        self,
+        dataset: Tuple[Array, np.ndarray],
+    ) -> None:
+        """Test that the iid scheme yields iid samples."""
+        shards = split_multi_classif_dataset(
+            dataset, n_shards=4, scheme="iid", p_valid=0.2, seed=0
+        )
+        # Verify that shards match expected shapes.
+        self.assert_expected_shard_shapes(
+            dataset, shards, n_shards=4, p_valid=0.2
+        )
+        # Verify that labels are iid-distributed across shards.
+        # To do so, run a chi2 contingency test (null hypothesis: label
+        # distributions are identical across shard subsets) and verify
+        # that this hypothesis is not rejected (high p-value).
+        y_counts = self.get_label_counts(shards)
+        assert chi2_contingency(y_counts).pvalue >= 0.90
+
+    def test_scheme_labels(
+        self,
+        dataset: Tuple[Array, np.ndarray],
+    ) -> None:
+        """Test that the labels scheme yields non-overlapping-labels shards."""
+        shards = split_multi_classif_dataset(
+            dataset, n_shards=2, scheme="labels", p_valid=0.4, seed=0
+        )
+        # Verify that shards match expected shapes.
+        self.assert_expected_shard_shapes(
+            dataset, shards, n_shards=2, p_valid=0.4
+        )
+        # Verify that labels are distributed without overlap across shards.
+        labels_train_0 = np.unique(shards[0][0][1])
+        labels_valid_0 = np.unique(shards[0][1][1])
+        labels_train_1 = np.unique(shards[1][0][1])
+        labels_valid_1 = np.unique(shards[1][1][1])
+        assert np.all(labels_train_0 == labels_valid_0)
+        assert np.all(labels_train_1 == labels_valid_1)
+        assert np.intersect1d(labels_train_0, labels_train_1).shape == (0,)
+
+    def test_scheme_biased(
+        self,
+        dataset: Tuple[Array, np.ndarray],
+    ) -> None:
+        """Test that the biased scheme yields disparate-distrib. shards."""
+        shards = split_multi_classif_dataset(
+            dataset, n_shards=4, scheme="biased", p_valid=0.2, seed=0
+        )
+        # Verify that shards match expected shapes.
+        self.assert_expected_shard_shapes(
+            dataset, shards, n_shards=4, p_valid=0.2
+        )
+        # Verify that labels have distinct distributions across shards.
+        # To do so, run a chi2 contingency test (null hypothesis: label
+        # distributions are identical across shard subsets) and verify
+        # that it is strongly rejected overall (very low p-value), while
+        # individual subsets' counts do not reject it.
+        y_counts = self.get_label_counts(shards)
+        assert chi2_contingency(y_counts).pvalue <= 1e-10
+        for i in range(len(shards)):
+            assert chi2_contingency(y_counts[i : i + 1]).pvalue >= 0.90
+
+    def test_scheme_dirichlet(
+        self,
+        dataset: Tuple[Array, np.ndarray],
+    ) -> None:
+        """Test the dirichlet scheme's disparity, with various alpha values."""
+        shards = split_multi_classif_dataset(
+            dataset, n_shards=2, scheme="dirichlet", seed=0
+        )
+        # Verify that shards match expected shapes.
+        self.assert_expected_shard_shapes(
+            dataset, shards, n_shards=2, p_valid=0.2
+        )
+        # Verify that labels are differently-distributed across shards.
+        # To do so, run a chi2 contingency test (null hypothesis: label
+        # distributions are identical across shard subsets) and verify
+        # that it is strongly rejected overall (very low p-value), while
+        # individual subsets' counts do not reject it.
+        y_counts = self.get_label_counts(shards)
+        pval_low = chi2_contingency(y_counts).pvalue
+        assert pval_low <= 1e-10
+        for i in range(len(shards)):
+            assert chi2_contingency(y_counts[i : i + 1]).pvalue >= 0.60
+        # Verify that a higher alpha yields a higher p-value, which
+        # nonetheless remains statistically significant (<= 0.05).
+        shards = split_multi_classif_dataset(
+            dataset, n_shards=2, scheme="dirichlet", seed=0, alpha=2.0
+        )
+        y_counts = self.get_label_counts(shards)
+        pval_mid = chi2_contingency(y_counts).pvalue
+        assert pval_low < pval_mid <= 0.05
+        # Verify that using a very high alpha value results in iid data.
+        shards = split_multi_classif_dataset(
+            dataset, n_shards=2, scheme="dirichlet", seed=0, alpha=100000.0
+        )
+        y_counts = self.get_label_counts(shards)
+        assert chi2_contingency(y_counts).pvalue >= 0.90
+
+    def test_error_labels_too_many_shards(
+        self,
+        dataset: Tuple[Array, np.ndarray],
+    ) -> None:
+        """Test that in 'labels' schemes 'n_shards > n_lab' raises an error."""
+        with pytest.raises(ValueError):
+            split_multi_classif_dataset(dataset, n_shards=8, scheme="labels")
diff --git a/test/functional/test_toy_clf_secagg.py b/test/functional/test_toy_clf_secagg.py
index c1b13091cd0d1f39daca6379282c5296ac700cbe..cc5723e41b5d8c8d71c28da465a581637b587f4b 100644
--- a/test/functional/test_toy_clf_secagg.py
+++ b/test/functional/test_toy_clf_secagg.py
@@ -85,7 +85,7 @@ def generate_toy_dataset(
     # Cluster samples based on features and assign them to clients thereof.
     # Also split client-wise data: 80% for training and 20% for validation.
     kclust = sklearn.cluster.KMeans(
-        n_clusters=n_clients, init="random", random_state=SEED
+        n_clusters=n_clients, init="random", n_init="auto", random_state=SEED
     ).fit_predict(inputs)
     datasets = []  # type: List[Tuple[InMemoryDataset, InMemoryDataset]]
     for i in range(n_clients):
diff --git a/test/main/test_main_client.py b/test/main/test_main_client.py
index 384c2b4c210b75d6feb0f92872129594d5296eb6..861f1bc372cd0fa36bb2eb3e7a6597695ec10539 100644
--- a/test/main/test_main_client.py
+++ b/test/main/test_main_client.py
@@ -43,6 +43,11 @@ MOCK_NETWK.name = "client"
 MOCK_DATASET = mock.create_autospec(Dataset, instance=True)
 
 
+def object_new(cls, *_, **__) -> Any:
+    """Wrapper for 'object.__new__' accepting/discarding *args and **kwargs."""
+    return object.__new__(cls)
+
+
 @contextlib.contextmanager
 def patch_class_constructor(
     cls: Type[Any],
@@ -69,7 +74,7 @@ def patch_class_constructor(
             yield patch
     finally:
         if new is object.__new__:
-            cls.__new__ = lambda cls, *args, **kwargs: object.__new__(cls)
+            cls.__new__ = object_new  # type: ignore[assignment]
         else:
             cls.__new__ = new