diff --git a/declearn/aggregator/_api.py b/declearn/aggregator/_api.py
index 778b1d3e0c92a7e0c0baedea85e5b35ae6ace0a7..5f089b4180f86802cfda57d6f50d636622f19f6a 100644
--- a/declearn/aggregator/_api.py
+++ b/declearn/aggregator/_api.py
@@ -127,7 +127,7 @@ class Aggregator(metaclass=ABCMeta):
         self,
     ) -> Dict[str, Any]:
         """Return a JSON-serializable dict with this object's parameters."""
-        return {}
+        return {}  # pragma: no cover
 
     @classmethod
     def from_config(
diff --git a/declearn/communication/__init__.py b/declearn/communication/__init__.py
index e7f685f248bfc22bbc35300a0ef3f39d66660dba..b846ff0b4aa4c42bc35178685df27a701d2b1bee 100644
--- a/declearn/communication/__init__.py
+++ b/declearn/communication/__init__.py
@@ -69,11 +69,9 @@ from ._build import (
 # Concrete implementations using various protocols:
 try:
     from . import grpc
-except ImportError:
-    # pragma: no cover
+except ImportError:  # pragma: no cover
     _INSTALLABLE_BACKENDS["grpc"] = ("grpcio", "protobuf")
 try:
     from . import websockets
-except ImportError:
-    # pragma: no cover
+except ImportError:  # pragma: no cover
     _INSTALLABLE_BACKENDS["websockets"] = ("websockets",)
diff --git a/declearn/communication/_build.py b/declearn/communication/_build.py
index 1552eba942e197f45ce9be44314c6e7d45517528..854af5d38f7bfbd5df9bc4bcd140dbc6f82ab366 100644
--- a/declearn/communication/_build.py
+++ b/declearn/communication/_build.py
@@ -48,7 +48,7 @@ def raise_if_installable(
     exc: Optional[Exception] = None,
 ) -> None:
     """Raise a RuntimeError if a given protocol is missing but installable."""
-    if protocol in _INSTALLABLE_BACKENDS:
+    if protocol in _INSTALLABLE_BACKENDS:  # pragma: no cover
         raise RuntimeError(
             f"The '{protocol}' communication protocol network endpoints "
             "could not be imported, but could be installed by satisfying "
@@ -95,7 +95,7 @@ def build_client(
     protocol = protocol.strip().lower()
     try:
         cls = access_registered(name=protocol, group="NetworkClient")
-    except KeyError as exc:
+    except KeyError as exc:  # pragma: no cover
         raise_if_installable(protocol, exc)
         raise KeyError(
             "Failed to retrieve NetworkClient "
@@ -153,7 +153,7 @@ def build_server(
     protocol = protocol.strip().lower()
     try:
         cls = access_registered(name=protocol, group="NetworkServer")
-    except KeyError as exc:
+    except KeyError as exc:  # pragma: no cover
         raise_if_installable(protocol, exc)
         raise KeyError(
             "Failed to retrieve NetworkServer "
diff --git a/declearn/communication/websockets/_client.py b/declearn/communication/websockets/_client.py
index 6eeca4fb91f2afe55ed9247a2c5821fc28db0a7a..ea21cd5a943fa01a7c5f07e33365d69212cacae7 100644
--- a/declearn/communication/websockets/_client.py
+++ b/declearn/communication/websockets/_client.py
@@ -33,8 +33,9 @@ from declearn.communication.websockets._tools import (
     send_websockets_message,
 )
 
-
-CHUNK_LENGTH = 100000
+__all__ = [
+    "WebsocketsClient",
+]
 
 
 class WebsocketsClient(NetworkClient):
diff --git a/declearn/communication/websockets/_server.py b/declearn/communication/websockets/_server.py
index e2c618df71220d1a991d03f2a5e7a3f4002a6938..a8dad15c652f71505d1853d35d25192dc1d4b42c 100644
--- a/declearn/communication/websockets/_server.py
+++ b/declearn/communication/websockets/_server.py
@@ -33,8 +33,9 @@ from declearn.communication.websockets._tools import (
     send_websockets_message,
 )
 
-
-ADD_HEADER = False  # revise: drop this constant (choose a behaviour)
+__all__ = [
+    "WebsocketsServer",
+]
 
 
 class WebsocketsServer(NetworkServer):
@@ -106,18 +107,12 @@ class WebsocketsServer(NetworkServer):
     ) -> None:
         """Start the websockets server."""
         # Set up the websockets connections handling process.
-        extra_headers = (
-            ws.Headers(Connection="keep-alive")  # type: ignore
-            if ADD_HEADER
-            else None
-        )
         server = ws.serve(  # type: ignore  # pylint: disable=no-member
             self._handle_connection,
             host=self.host,
             port=self.port,
             logger=self.logger,
             ssl=self._ssl,
-            extra_headers=extra_headers,
             ping_timeout=None,  # disable timeout on keep-alive pings
         )
         # Run the websockets server.
diff --git a/declearn/communication/websockets/_tools.py b/declearn/communication/websockets/_tools.py
index d31850f04f6edc28ac4d6c280e47f1045aec3c99..df4c004ab6a1b56ae2705ef095925ed0ed8b7069 100644
--- a/declearn/communication/websockets/_tools.py
+++ b/declearn/communication/websockets/_tools.py
@@ -23,6 +23,13 @@ from typing import Union
 from websockets.legacy.protocol import WebSocketCommonProtocol
 
 
+__all__ = [
+    "StreamRefusedError",
+    "receive_websockets_message",
+    "send_websockets_message",
+]
+
+
 FLAG_STREAM_START = "STREAM_START"
 FLAG_STREAM_CLOSE = "STREAM_CLOSE"
 FLAG_STREAM_ALLOW = "STREAM_ALLOW"
diff --git a/declearn/data_info/_fields.py b/declearn/data_info/_fields.py
index bb48a5587abbf8da4697b4e1ab82d61e2c4e99fe..b46fec8204d2e7fca512427c75827c72f6df8ea0 100644
--- a/declearn/data_info/_fields.py
+++ b/declearn/data_info/_fields.py
@@ -73,11 +73,12 @@ class DataTypeField(DataInfoField):
         cls,
         value: Any,
     ) -> bool:
-        if isinstance(value, str):
-            try:
-                np.dtype(value)
-            except TypeError:
-                return False
+        if not isinstance(value, str):
+            return False
+        try:
+            np.dtype(value)
+        except TypeError:
+            return False
         return True
 
     @classmethod
@@ -159,12 +160,12 @@ class NbSamplesField(DataInfoField):
 
 
 @register_data_info_field
-class InputShapeField(DataInfoField):
+class InputShapeField(DataInfoField):  # pragma: no cover
     """Specifications for 'input_shape' data_info field."""
 
     field = "input_shape"
     types = (tuple, list)
-    doc = "Input features' batched shape, checked to be equal."
+    doc = "DEPRECATED - Input features' batched shape, checked to be equal."
 
     @classmethod
     def is_valid(
@@ -216,12 +217,12 @@ class InputShapeField(DataInfoField):
 
 
 @register_data_info_field
-class NbFeaturesField(DataInfoField):
+class NbFeaturesField(DataInfoField):  # pragma: no cover
     """Deprecated specifications for 'n_features' data_info field."""
 
     field = "n_features"
     types = (int,)
-    doc = "Number of input features, checked to be equal."
+    doc = "DEPRECATED - Number of input features, checked to be equal."
 
     @classmethod
     def is_valid(
diff --git a/declearn/dataset/_inmemory.py b/declearn/dataset/_inmemory.py
index 534cf45d92ca0f01ad0fb9ae52a511a5f8519934..1bd69fe735626ed6c41cabcd639014ea20601be1 100644
--- a/declearn/dataset/_inmemory.py
+++ b/declearn/dataset/_inmemory.py
@@ -19,7 +19,7 @@
 
 import os
 import warnings
-from typing import Any, ClassVar, Dict, Iterator, List, Optional, Set, Union
+from typing import Any, Dict, Iterator, List, Optional, Set, Union
 
 import numpy as np
 import pandas as pd
@@ -72,8 +72,6 @@ class InMemoryDataset(Dataset):
     # attributes serve clarity; pylint: disable=too-many-instance-attributes
     # arguments serve modularity; pylint: disable=too-many-arguments
 
-    _type_key: ClassVar[str] = "InMemoryDataset"
-
     def __init__(
         self,
         data: Union[DataArray, str],
@@ -149,6 +147,11 @@ class InMemoryDataset(Dataset):
                 target = self.data[target]
             else:
                 target = load_data_array(target)
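+                # Squeeze single-column frames (e.g. from csv) into series.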
+                if (
+                    isinstance(target, pd.DataFrame)
+                    and len(target.columns) == 1
+                ):
+                    target = target.iloc[:, 0]
         self.target = target
         # Assign the (optional) sample weights data array.
         if isinstance(s_wght, str):
@@ -196,7 +199,7 @@ class InMemoryDataset(Dataset):
             return set(np.unique(self.target).tolist())
         if isinstance(self.target, spmatrix):
             return set(np.unique(self.target.tocsr().data).tolist())
-        raise TypeError(
+        raise TypeError(  # pragma: no cover
             f"Invalid 'target' attribute type: '{type(self.target)}'."
         )
 
@@ -205,17 +208,17 @@ class InMemoryDataset(Dataset):
         """Unique data type."""
         if not self.expose_data_type:
             return None
-        if isinstance(self.data, pd.DataFrame):
-            dtypes = [str(t) for t in list(self.data.dtypes)]
+        if isinstance(self.feats, pd.DataFrame):
+            dtypes = {str(t) for t in list(self.feats.dtypes)}
             if len(dtypes) > 1:
                 raise ValueError(
                     "Cannot work with mixed data types:"
                     "ensure the `data` attribute has unique dtype"
                 )
-            return dtypes[0]
-        if isinstance(self.data, (pd.Series, np.ndarray, spmatrix)):
-            return str(self.data.dtype)
-        raise TypeError(
+            return list(dtypes)[0]
+        if isinstance(self.feats, (pd.Series, np.ndarray, spmatrix)):
+            return str(self.feats.dtype)
+        raise TypeError(  # pragma: no cover
             f"Invalid 'data' attribute type: '{type(self.target)}'."
         )
 
@@ -223,7 +226,7 @@ class InMemoryDataset(Dataset):
     def load_data_array(
         path: str,
         **kwargs: Any,
-    ) -> DataArray:
+    ) -> DataArray:  # pragma: no cover
         """Load a data array from a dump file.
 
         As of declearn v2.2, this staticmethod is DEPRECATED in favor of
@@ -244,7 +247,7 @@ class InMemoryDataset(Dataset):
     def save_data_array(
         path: str,
         array: Union[DataArray, pd.Series],
-    ) -> str:
+    ) -> str:  # pragma: no cover
         """Save a data array to a dump file.
 
         As of declearn v2.2, this staticmethod is DEPRECATED in favor of
@@ -308,7 +311,7 @@ class InMemoryDataset(Dataset):
         path = os.path.abspath(path)
         folder = os.path.dirname(path)
         info = {}  # type: Dict[str, Any]
-        info["type"] = self._type_key
+        info["type"] = "InMemoryDataset"  # NOTE: for backward compatibility
         # Optionally create data dumps. Record data dumps' paths.
         # fmt: off
         info["data"] = (
@@ -349,16 +352,11 @@ class InMemoryDataset(Dataset):
         if "config" not in dump:
             raise KeyError("Missing key in the JSON file: 'config'.")
         info = dump["config"]
-        for key in ("type", "data", "target", "s_wght", "f_cols"):
+        for key in ("data", "target", "s_wght", "f_cols"):
             if key not in info:
                 error = f"Missing key in the JSON file: 'config/{key}'."
                 raise KeyError(error)
-        key = info.pop("type")
-        if key != cls._type_key:
-            raise TypeError(
-                f"Incorrect 'type' field: got '{key}', "
-                f"expected '{cls._type_key}'."
-            )
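+        # Drop the legacy 'type' field, if any, without checking it.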
+        info.pop("type", None)
         # Instantiate the object and return it.
         return cls(**info)
 
diff --git a/declearn/main/utils/_early_stop.py b/declearn/main/utils/_early_stop.py
index 92155e4832127725a7845f3b1b625d112ea814bd..f703a9b95422733ce1e4fa93dbc8b49b91e49a01 100644
--- a/declearn/main/utils/_early_stop.py
+++ b/declearn/main/utils/_early_stop.py
@@ -105,7 +105,7 @@ class EarlyStopping:
             self._best_metric = metric
         if self.relative:
             diff /= self._best_metric
-        if diff < self.tolerance:
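+        # Use <= so that a null or tolerance-level diff counts as stagnation.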
+        if diff <= self.tolerance:
             self._n_iter_stuck += 1
         else:
             self._n_iter_stuck = 0
diff --git a/declearn/metrics/_api.py b/declearn/metrics/_api.py
index ab23eb0d24fdacabeefe9e35ca0ac112559327bb..fce098c316b488ecdf45cc09302ac2f492f8cbcc 100644
--- a/declearn/metrics/_api.py
+++ b/declearn/metrics/_api.py
@@ -347,7 +347,9 @@ class Metric(metaclass=ABCMeta):
         return s_wght
 
     @staticmethod
-    def normalize_weights(s_wght: np.ndarray) -> np.ndarray:
+    def normalize_weights(  # pragma: no cover
+        s_wght: np.ndarray,
+    ) -> np.ndarray:
         """Utility method to ensure weights sum to one.
 
         Note that this method may or may not be used depending on
diff --git a/declearn/metrics/_mean.py b/declearn/metrics/_mean.py
index 2c4189c675971944dab74ab58aa393140e5e3da0..aca63e392f91d19f7e2e863a0b121d8689a2dc55 100644
--- a/declearn/metrics/_mean.py
+++ b/declearn/metrics/_mean.py
@@ -23,6 +23,7 @@ from typing import ClassVar, Dict, Optional, Union
 import numpy as np
 
 from declearn.metrics._api import Metric
+from declearn.metrics._utils import squeeze_into_identical_shapes
 
 __all__ = [
     "MeanMetric",
@@ -129,6 +130,7 @@ class MeanAbsoluteError(MeanMetric):
         y_pred: np.ndarray,
     ) -> np.ndarray:
         # Sample-wise (sum of) absolute error function.
+        y_true, y_pred = squeeze_into_identical_shapes(y_true, y_pred)
         errors = np.abs(y_true - y_pred)
         while errors.ndim > 1:
             errors = errors.sum(axis=-1)
@@ -158,6 +160,7 @@ class MeanSquaredError(MeanMetric):
         y_pred: np.ndarray,
     ) -> np.ndarray:
         # Sample-wise (sum of) squared error function.
+        y_true, y_pred = squeeze_into_identical_shapes(y_true, y_pred)
         errors = np.square(y_true - y_pred)
         while errors.ndim > 1:
             errors = errors.sum(axis=-1)
diff --git a/declearn/metrics/_rsquared.py b/declearn/metrics/_rsquared.py
index d61fdfeaf34fc7c6cb3a47433861e42b8ec43e5e..f00b40b2267a523d054ae6229cb8ae3a36c152c8 100644
--- a/declearn/metrics/_rsquared.py
+++ b/declearn/metrics/_rsquared.py
@@ -22,6 +22,7 @@ from typing import ClassVar, Dict, Optional, Union
 import numpy as np
 
 from declearn.metrics._api import Metric
+from declearn.metrics._utils import squeeze_into_identical_shapes
 
 __all__ = [
     "RSquared",
@@ -113,6 +114,7 @@ class RSquared(Metric):
         y_pred: np.ndarray,
         s_wght: Optional[np.ndarray] = None,
     ) -> None:
+        y_true, y_pred = squeeze_into_identical_shapes(y_true, y_pred)
         # Verify sample weights' shape, or set up 1-valued ones.
         s_wght = self._prepare_sample_weights(s_wght, n_samples=len(y_pred))
         # Update the residual sum of squares. wSSr = sum(w * (y - p)^2)
diff --git a/declearn/metrics/_utils.py b/declearn/metrics/_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..757637a6d560883f18ab5fc931db4615deb17c47
--- /dev/null
+++ b/declearn/metrics/_utils.py
@@ -0,0 +1,54 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Backend utils for metrics' computations."""
+
+from typing import Tuple
+
+import numpy as np
+
+__all__ = [
+    "squeeze_into_identical_shapes",
+]
+
+
+def squeeze_into_identical_shapes(
+    y_true: np.ndarray,
+    y_pred: np.ndarray,
+) -> Tuple[np.ndarray, np.ndarray]:
+    """Verify that inputs have identical shapes, up to squeezable dims.
+
+    Return the input arrays, squeezed when needed.
+    Raise a ValueError if they cannot be made to match.
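+
+    For instance, a (32, 1)-shaped array and a (32,)-shaped one
+    are both returned with shape (32,).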
+    """
+    # Case of identical-shape inputs.
+    if y_true.shape == y_pred.shape:
+        return y_true, y_pred
+    # Case of identical-shape inputs up to squeezable dims.
+    y_true = np.squeeze(y_true)
+    y_pred = np.squeeze(y_pred)
+    if y_true.shape == y_pred.shape:
+        # Handle edge case of scalar values: preserve one dimension.
+        if not y_true.shape:
+            y_true = np.expand_dims(y_true, 0)
+            y_pred = np.expand_dims(y_pred, 0)
+        return y_true, y_pred
+    # Case of mismatching shapes.
+    raise ValueError(
+        "Received inputs with incompatible shapes: "
+        f"y_true has shape {y_true.shape}, y_pred has shape {y_pred.shape}."
+    )
diff --git a/declearn/model/haiku/_vector.py b/declearn/model/haiku/_vector.py
index 054883432d10a4d9e55de11d3e397def761afe39..6b391857d3d365125a76b8e926064ef3a418d058 100644
--- a/declearn/model/haiku/_vector.py
+++ b/declearn/model/haiku/_vector.py
@@ -156,7 +156,7 @@ class JaxNumpyVector(Vector):
             warnings.warn(  # pragma: no cover
                 "The 'axis' and 'keepdims' arguments of 'JaxNumpyVector.sum' "
                 "have been deprecated as of declearn v2.3, and will be "
-                "removed in version 2.6 and/or 3.0.",
+                "removed in version 2.5 and/or 3.0.",
                 DeprecationWarning,
             )
         coefs = {
diff --git a/declearn/model/sklearn/_np_vec.py b/declearn/model/sklearn/_np_vec.py
index 4a99a04b876a650f9f93ab34dad5f20a59a37fa7..499a6ae658ab0b916c0788cccb27fc1cecc6f3d7 100644
--- a/declearn/model/sklearn/_np_vec.py
+++ b/declearn/model/sklearn/_np_vec.py
@@ -127,7 +127,7 @@ class NumpyVector(Vector):
             warnings.warn(  # pragma: no cover
                 "The 'axis' and 'keepdims' arguments of 'NumpyVector.sum' "
                 "have been deprecated as of declearn v2.3, and will be "
-                "removed in version 2.6 and/or 3.0.",
+                "removed in version 2.5 and/or 3.0.",
                 DeprecationWarning,
             )
         coefs = {
diff --git a/declearn/model/tensorflow/_vector.py b/declearn/model/tensorflow/_vector.py
index b62dffc5390166188caf8d7ca68e0f4806789679..09d9ef58b7e58b5772a8eedf3ade4b03064378b0 100644
--- a/declearn/model/tensorflow/_vector.py
+++ b/declearn/model/tensorflow/_vector.py
@@ -295,7 +295,7 @@ class TensorflowVector(Vector):
             warnings.warn(  # pragma: no cover
                 "The 'axis' and 'keepdims' arguments of 'TensorflowVector.sum'"
                 " have been deprecated as of declearn v2.3, and will be "
-                "removed in version 2.6 and/or 3.0.",
+                "removed in version 2.5 and/or 3.0.",
                 DeprecationWarning,
             )
         return self.apply_func(tf.reduce_sum, axis=axis, keepdims=keepdims)
diff --git a/declearn/model/torch/_vector.py b/declearn/model/torch/_vector.py
index ff67cc30af3b6284dfe02cccde0096d05ccf5428..5034db537194aed56723a15bc2fc0847a6e06a4e 100644
--- a/declearn/model/torch/_vector.py
+++ b/declearn/model/torch/_vector.py
@@ -200,7 +200,7 @@ class TorchVector(Vector):
             warnings.warn(  # pragma: no cover
                 "The 'axis' and 'keepdims' arguments of 'TorchVector.sum' "
                 "have been deprecated as of declearn v2.3, and will be "
-                "removed in version 2.6 and/or 3.0.",
+                "removed in version 2.5 and/or 3.0.",
                 DeprecationWarning,
             )
         coefs = {
diff --git a/declearn/quickrun/_parser.py b/declearn/quickrun/_parser.py
index 422d0999d6625b0959ebbd0931ec0d841f13d88d..f15bfb4f2c1c4a19424aff6e110553bfd6bde4b7 100644
--- a/declearn/quickrun/_parser.py
+++ b/declearn/quickrun/_parser.py
@@ -103,8 +103,8 @@ def parse_data_folder(
 
 
 def get_data_folder_path(
-    data_folder: Optional[str],
-    root_folder: Optional[str],
+    data_folder: Optional[str] = None,
+    root_folder: Optional[str] = None,
 ) -> Path:
     """Return the path to a data folder.
 
@@ -158,7 +158,7 @@ def get_data_folder_path(
 
 def list_client_names(
     data_folder: Path,
-    client_names: Optional[List[str]],
+    client_names: Optional[List[str]] = None,
 ) -> List[str]:
     """List client-wise subdirectories under a data folder.
 
diff --git a/declearn/quickrun/_run.py b/declearn/quickrun/_run.py
index c33ae875dd1af0f5df6ca02e610877bcee62818a..46f14384f282f851a6dddbe664c624bc1664426b 100644
--- a/declearn/quickrun/_run.py
+++ b/declearn/quickrun/_run.py
@@ -143,19 +143,32 @@ def run_client(
 def get_toml_folder(config: str) -> Tuple[str, str]:
     """Return the path to an experiment's folder and TOML config file.
 
-    Determine if provided config is a file or a directory, and return:
-
-    * The path to the TOML config file
-    * The path to the main folder of the experiment
+    Parameters
+    ----------
+    config: str
+        Path to either a TOML config file (within an experiment folder),
+        or to the experiment folder containing a "config.toml" file.
+
+    Returns
+    -------
+    toml:
+        The path to the TOML config file.
+    folder:
+        The path to the main folder of the experiment.
+
+    Raises
+    ------
+    FileNotFoundError:
+        If the TOML config file cannot be found based on inputs.
     """
     config = os.path.abspath(config)
-    if os.path.isfile(config):
-        toml = config
-        folder = os.path.dirname(config)
-    elif os.path.isdir(config):
+    if os.path.isdir(config):
         folder = config
-        toml = f"{folder}/config.toml"
+        toml = os.path.join(folder, "config.toml")
     else:
+        toml = config
+        folder = os.path.dirname(toml)
+    if not os.path.isfile(toml):
         raise FileNotFoundError(
             f"Failed to find quickrun config file at '{config}'."
         )
diff --git a/pyproject.toml b/pyproject.toml
index 2161244d8d6de8a5b4e8408ccd0a494f5fe575e3..7d7131725797dfa461c43172bad177973ad35a9b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -47,10 +47,10 @@ dependencies = [
 
 [project.optional-dependencies]
 all = [  # all non-tests extra dependencies
-    "dm-haiku == 0.0.9",
+    "dm-haiku ~= 0.0.9",
     "functorch >= 1.10, < 3.0",
     "grpcio >= 1.45",
-    "jax[cpu] >= 0.4, < 0.5",
+    "jax[cpu] ~= 0.4.1",
     "opacus ~= 1.1",
     "protobuf >= 3.19",
     "tensorflow ~= 2.5",
@@ -66,8 +66,8 @@ grpc = [
     "protobuf >= 3.19",
 ]
 haiku = [
-    "dm-haiku == 0.0.9",
-    "jax[cpu] >= 0.4, < 0.5",  # NOTE: GPU support must be manually installed
+    "dm-haiku ~= 0.0.9",
+    "jax[cpu] ~= 0.4.1",  # NOTE: GPU support must be manually installed
 ]
 tensorflow = [
     "tensorflow ~= 2.5",
diff --git a/test/communication/test_exchanges.py b/test/communication/test_exchanges.py
index 2b06aa77f639c05869fdc4ff72248844dee3f70f..c91270b096e665a2de776e11fadf899c1725de02 100644
--- a/test/communication/test_exchanges.py
+++ b/test/communication/test_exchanges.py
@@ -37,6 +37,7 @@ are run with either a single or three clients at once.
 """
 
 import asyncio
+import secrets
 from typing import AsyncIterator, Dict, List, Optional, Tuple
 
 import pytest
@@ -51,7 +52,7 @@ from declearn.communication import (
 from declearn.communication.api import NetworkClient, NetworkServer
 
 
-### 1. Test that connections can properly be set up.
+### 1. Test that connections can be properly set up.
 
 
 @pytest_asyncio.fixture(name="server")
@@ -203,15 +204,18 @@ class TestNetworkExchanges:
 
     @pytest.mark.asyncio
     async def test_exchanges(
-        self, agents: Tuple[NetworkServer, List[NetworkClient]]
+        self,
+        agents: Tuple[NetworkServer, List[NetworkClient]],
     ) -> None:
         """Run all tests with the same fixture-provided agents."""
         await self.clients_to_server(agents)
         await self.server_to_clients_broadcast(agents)
         await self.server_to_clients_individual(agents)
+        await self.clients_to_server_large(agents)
 
     async def clients_to_server(
-        self, agents: Tuple[NetworkServer, List[NetworkClient]]
+        self,
+        agents: Tuple[NetworkServer, List[NetworkClient]],
     ) -> None:
         """Test that clients can send messages to the server."""
         server, clients = agents
@@ -226,7 +230,8 @@ class TestNetworkExchanges:
         }
 
     async def server_to_clients_broadcast(
-        self, agents: Tuple[NetworkServer, List[NetworkClient]]
+        self,
+        agents: Tuple[NetworkServer, List[NetworkClient]],
     ) -> None:
         """Test that the server can send a shared message to all clients."""
         server, clients = agents
@@ -237,7 +242,8 @@ class TestNetworkExchanges:
         assert all(reply == msg for reply in replies)
 
     async def server_to_clients_individual(
-        self, agents: Tuple[NetworkServer, List[NetworkClient]]
+        self,
+        agents: Tuple[NetworkServer, List[NetworkClient]],
     ) -> None:
         """Test that the server can send individual messages to clients."""
         server, clients = agents
@@ -252,3 +258,24 @@ class TestNetworkExchanges:
             reply == messages[client.name]
             for client, reply in zip(clients, replies)
         )
+
+    async def clients_to_server_large(
+        self,
+        agents: Tuple[NetworkServer, List[NetworkClient]],
+    ) -> None:
+        """Test that the clients can send large messages to the server."""
+        server, clients = agents
+        coros = []
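+        # Build a large (8 MiB hex string) payload to exercise chunking.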
+        large = secrets.token_bytes(2**22).hex()
+        for idx, client in enumerate(clients):
+            msg = messaging.GenericMessage(
+                action="test", params={"idx": idx, "content": large}
+            )
+            coros.append(client.send_message(msg))
+        messages, *_ = await asyncio.gather(server.wait_for_messages(), *coros)
+        assert messages == {
+            c.name: messaging.GenericMessage(
+                action="test", params={"idx": i, "content": large}
+            )
+            for i, c in enumerate(clients)
+        }
diff --git a/test/data_info/test_classes_field.py b/test/data_info/test_classes_field.py
new file mode 100644
index 0000000000000000000000000000000000000000..98addb7cf01db28b79cfffbc1d366751d0e21e42
--- /dev/null
+++ b/test/data_info/test_classes_field.py
@@ -0,0 +1,59 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for 'declearn.data_info.ClassesField'."""
+
+
+import numpy as np
+import pytest
+
+from declearn.data_info import ClassesField
+
+
+class TestClassesField:
+    """Unit tests for 'declearn.data_info.ClassesField'."""
+
+    def test_is_valid_list(self) -> None:
+        """Test `is_valid` with a valid list value."""
+        assert ClassesField.is_valid([0, 1])
+
+    def test_is_valid_set(self) -> None:
+        """Test `is_valid` with a valid set value."""
+        assert ClassesField.is_valid({0, 1})
+
+    def test_is_valid_tuple(self) -> None:
+        """Test `is_valid` with a valid tuple value."""
+        assert ClassesField.is_valid((0, 1))
+
+    def test_is_valid_array(self) -> None:
+        """Test `is_valid` with a valid numpy array value."""
+        assert ClassesField.is_valid(np.array([0, 1]))
+
+    def test_is_invalid_2d_array(self) -> None:
+        """Test `is_valid` with an invalid numpy array value."""
+        assert not ClassesField.is_valid(np.array([[0, 1], [2, 3]]))
+
+    def test_combine(self) -> None:
+        """Test `combine` with valid and compatible inputs."""
+        values = ([0, 1], (0, 1), {1, 2}, np.array([1, 3]))
+        assert ClassesField.combine(*values) == {0, 1, 2, 3}
+
+    def test_combine_fails(self) -> None:
+        """Test `combine` with some invalid inputs."""
+        values = ([0, 1], np.array([[0, 1], [2, 3]]))
+        with pytest.raises(ValueError):
+            ClassesField.combine(*values)
diff --git a/test/data_info/test_data_info_utils.py b/test/data_info/test_data_info_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d215fd1d311b397e97c3a4d75a2928d695a7c076
--- /dev/null
+++ b/test/data_info/test_data_info_utils.py
@@ -0,0 +1,144 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for 'declearn.data_info' high-level utils."""
+
+import uuid
+from typing import Any, Type
+from unittest import mock
+
+import pytest
+
+from declearn.data_info import (
+    DataInfoField,
+    aggregate_data_info,
+    get_data_info_fields_documentation,
+    register_data_info_field,
+)
+
+
+class TestAggregateDataInfo:
+    """Unit tests for 'declearn.data_info.aggregate_data_info'."""
+
+    def test_aggregate_data_info(self) -> None:
+        """Test aggregating valid, compatible data info."""
+        clients_data_info = [
+            {"n_samples": 10, "features_shape": (100,)},
+            {"n_samples": 32, "features_shape": (100,)},
+        ]
+        result = aggregate_data_info(clients_data_info)
+        assert result == {"n_samples": 42, "features_shape": (100,)}
+
+    def test_aggregate_data_info_required(self) -> None:
+        """Test aggregating a subset of valid, compatible data info."""
+        clients_data_info = [
+            {"n_samples": 10, "features_shape": (100,)},
+            {"n_samples": 32, "features_shape": (100,)},
+        ]
+        result = aggregate_data_info(
+            clients_data_info, required_fields={"n_samples"}
+        )
+        assert result == {"n_samples": 42}
+
+    def test_aggregate_data_info_missing_required(self) -> None:
+        """Test that a KeyError is raised on missing required data info."""
+        clients_data_info = [
+            {"n_samples": 10},
+            {"n_samples": 32},
+        ]
+        with pytest.raises(KeyError):
+            aggregate_data_info(
+                clients_data_info,
+                required_fields={"n_samples", "features_shape"},
+            )
+
+    def test_aggregate_data_info_invalid_values(self) -> None:
+        """Test that a ValueError is raised on invalid values."""
+        clients_data_info = [
+            {"n_samples": 10},
+            {"n_samples": -1},
+        ]
+        with pytest.raises(ValueError):
+            aggregate_data_info(clients_data_info)
+
+    def test_aggregate_data_info_incompatible_values(self) -> None:
+        """Test that a ValueError is raised on incompatible values."""
+        clients_data_info = [
+            {"features_shape": (28,)},
+            {"features_shape": (32,)},
+        ]
+        with pytest.raises(ValueError):
+            aggregate_data_info(clients_data_info)
+
+    def test_aggregate_data_info_undefined_field(self) -> None:
+        """Test that unspecified fields are handled as expected."""
+        clients_data_info = [
+            {"n_samples": 10, "undefined": "a"},
+            {"n_samples": 32, "undefined": "b"},
+        ]
+        with mock.patch("warnings.warn") as patch_warn:
+            result = aggregate_data_info(clients_data_info)
+        patch_warn.assert_called_once()
+        assert result == {"n_samples": 42, "undefined": ["a", "b"]}
+
+
+class TestRegisterDataInfoField:
+    """Unit tests for 'declearn.data_info.register_data_info_field'."""
+
+    def create_mock_cls(self) -> Type[DataInfoField]:
+        """Create and return a mock DataInfoField subclass."""
+
+        field_name = f"mock_field_{uuid.uuid4()}"
+
+        class MockDataInfoField(DataInfoField):
+            """Mock DataInfoField subclass."""
+
+            field = field_name
+            types = (str,)
+            doc = f"Documentation for '{field_name}'."
+
+            @classmethod
+            def combine(cls, *values: Any) -> Any:
+                return values
+
+        return MockDataInfoField
+
+    def test_register_data_info_field(self) -> None:
+        """Test that registrating a custom DataInfoField works."""
+        # Set up a mock DataInfoField subclass.
+        mock_cls = self.create_mock_cls()
+        # Test that it can be registered, and thereafter accessed.
+        register_data_info_field(mock_cls)
+        documentation = get_data_info_fields_documentation()
+        assert mock_cls.field in documentation
+        assert documentation[mock_cls.field] == mock_cls.doc
+
+    def test_register_data_info_field_invalid_type(self) -> None:
+        """Test that registering a non-DataInfoField subclass fails."""
+        with pytest.raises(TypeError):
+            register_data_info_field(int)  # type: ignore
+
+    def test_register_data_info_field_already_used(self) -> None:
+        """Test that registering twice under the same name fails."""
+        # Set up a couple of DataInfoField mock classes with same field name.
+        mock_cls = self.create_mock_cls()
+        mock_bis = self.create_mock_cls()
+        mock_bis.field = mock_cls.field
+        # Test that they cannot both be registered.
+        register_data_info_field(mock_cls)
+        with pytest.raises(KeyError):
+            register_data_info_field(mock_bis)
diff --git a/test/data_info/test_datatype_field.py b/test/data_info/test_datatype_field.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c2f1bf76c11fcea1ff5d628374e7dddeb6115a
--- /dev/null
+++ b/test/data_info/test_datatype_field.py
@@ -0,0 +1,57 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for 'declearn.data_info.DataTypeField'."""
+
+
+import numpy as np
+import pytest
+
+from declearn.data_info import DataTypeField
+
+
+class TestDataTypeField:
+    """Unit tests for 'declearn.data_info.DataTypeField'."""
+
+    def test_is_valid(self) -> None:
+        """Test `is_valid` with some valid values."""
+        assert DataTypeField.is_valid("float32")
+        assert DataTypeField.is_valid("float64")
+        assert DataTypeField.is_valid("int32")
+        assert DataTypeField.is_valid("uint8")
+
+    def test_is_not_valid(self) -> None:
+        """Test `is_valid` with invalid values."""
+        assert not DataTypeField.is_valid(np.int32)
+        assert not DataTypeField.is_valid("mocktype")
+
+    def test_combine(self) -> None:
+        """Test `combine` with valid and compatible inputs."""
+        values = ["float32", "float32"]
+        assert DataTypeField.combine(*values) == "float32"
+
+    def test_combine_invalid(self) -> None:
+        """Test `combine` with invalid inputs."""
+        values = ["float32", "mocktype"]
+        with pytest.raises(ValueError):
+            DataTypeField.combine(*values)
+
+    def test_combine_incompatible(self) -> None:
+        """Test `combine` with incompatible inputs."""
+        values = ["float32", "float16"]
+        with pytest.raises(ValueError):
+            DataTypeField.combine(*values)
diff --git a/test/data_info/test_nbsamples_field.py b/test/data_info/test_nbsamples_field.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc53f3dfdc31dd29ef21b9d0a2bb524034f5e739
--- /dev/null
+++ b/test/data_info/test_nbsamples_field.py
@@ -0,0 +1,52 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for 'declearn.data_info.NbSamplesField'."""
+
+
+import pytest
+
+from declearn.data_info import NbSamplesField
+
+
+class TestNbSamplesField:
+    """Unit tests for 'declearn.data_info.NbSamplesField'."""
+
+    def test_is_valid(self) -> None:
+        """Test `is_valid` with some valid input values."""
+        assert NbSamplesField.is_valid(32)
+        assert NbSamplesField.is_valid(100)
+        assert NbSamplesField.is_valid(8192)
+
+    def test_is_not_valid(self) -> None:
+        """Test `is_valid` with invalid values."""
+        assert not NbSamplesField.is_valid(16.5)
+        assert not NbSamplesField.is_valid(-12)
+        assert not NbSamplesField.is_valid(None)
+
+    def test_combine(self) -> None:
+        """Test `combine` with valid and compatible inputs."""
+        values = [32, 128]
+        assert NbSamplesField.combine(*values) == 160
+        values = [64, 64, 64, 64]
+        assert NbSamplesField.combine(*values) == 256
+
+    def test_combine_invalid(self) -> None:
+        """Test `combine` with invalid inputs."""
+        values = [128, -12]
+        with pytest.raises(ValueError):
+            NbSamplesField.combine(*values)
diff --git a/test/data_info/test_shape_field.py b/test/data_info/test_shape_field.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef9f61c4a06c28cb5d31a31fad560109a6d95089
--- /dev/null
+++ b/test/data_info/test_shape_field.py
@@ -0,0 +1,67 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for 'declearn.data_info.FeaturesShapeField'."""
+
+
+import pytest
+
+from declearn.data_info import FeaturesShapeField
+
+
+class TestFeaturesShapeField:
+    """Unit tests for 'declearn.data_info.FeaturesShapeField'."""
+
+    def test_is_valid(self) -> None:
+        """Test `is_valid` with some valid input values."""
+        # 1-d ; fixed 3-d (image-like) ; variable 2-d (text-like).
+        assert FeaturesShapeField.is_valid([32])
+        assert FeaturesShapeField.is_valid([64, 64, 3])
+        assert FeaturesShapeField.is_valid([None, 128])
+        # Same inputs, as tuples.
+        assert FeaturesShapeField.is_valid((32,))
+        assert FeaturesShapeField.is_valid((64, 64, 3))
+        assert FeaturesShapeField.is_valid((None, 128))
+
+    def test_is_not_valid(self) -> None:
+        """Test `is_valid` with invalid values."""
+        assert not FeaturesShapeField.is_valid(32)
+        assert not FeaturesShapeField.is_valid([32, -1])
+
+    def test_combine(self) -> None:
+        """Test `combine` with valid and compatible inputs."""
+        # 1-d inputs.
+        values = [[32], (32,)]
+        assert FeaturesShapeField.combine(*values) == (32,)
+        # 3-d fixed-size inputs.
+        values = [[16, 16, 3], (16, 16, 3)]
+        assert FeaturesShapeField.combine(*values) == (16, 16, 3)
+        # 2-d variable-size inputs.
+        values = [[None, 512], (None, 512)]  # type: ignore
+        assert FeaturesShapeField.combine(*values) == (None, 512)
+
+    def test_combine_invalid(self) -> None:
+        """Test `combine` with invalid inputs."""
+        values = [[32], [32, -1]]
+        with pytest.raises(ValueError):
+            FeaturesShapeField.combine(*values)
+
+    def test_combine_incompatible(self) -> None:
+        """Test `combine` with incompatible inputs."""
+        values = [(None, 32), (128,)]
+        with pytest.raises(ValueError):
+            FeaturesShapeField.combine(*values)
diff --git a/test/dataset/test_utils.py b/test/dataset/test_dataset_utils.py
similarity index 100%
rename from test/dataset/test_utils.py
rename to test/dataset/test_dataset_utils.py
diff --git a/test/dataset/test_inmemory.py b/test/dataset/test_inmemory.py
index ff6456c6a6fcb57729e0aff9bd8fb09528707cac..3551dd499058bdb064bc64c911f319dee5c77b36 100644
--- a/test/dataset/test_inmemory.py
+++ b/test/dataset/test_inmemory.py
@@ -15,13 +15,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Unit tests objects for 'declearn.dataset.InMemoryDataset'"""
+"""Unit tests for 'declearn.dataset.InMemoryDataset'"""
 
+import json
 import os
 
+import numpy as np
+import pandas as pd
 import pytest
+import scipy.sparse  # type: ignore
+import sklearn.datasets  # type: ignore
 
 from declearn.dataset import InMemoryDataset
+from declearn.dataset.utils import save_data_array
 from declearn.test_utils import make_importable
 
 # relative imports from `dataset_testbase.py`
@@ -32,6 +39,9 @@ with make_importable(os.path.dirname(__file__)):
 SEED = 0
 
 
+### Shared-test-suite-based tests, revolving around batch generation.
+
+
 class InMemoryDatasetTestToolbox(DatasetTestToolbox):
     """Toolbox for InMemoryDataset"""
 
@@ -44,10 +54,288 @@ class InMemoryDatasetTestToolbox(DatasetTestToolbox):
 
 
 @pytest.fixture(name="toolbox")
-def fixture_dataset() -> DatasetTestToolbox:
+def fixture_toolbox() -> DatasetTestToolbox:
     """Fixture to access a InMemoryDatasetTestToolbox."""
     return InMemoryDatasetTestToolbox()
 
 
 class TestInMemoryDataset(DatasetTestSuite):
     """Unit tests for declearn.dataset.InMemoryDataset."""
+
+
+### InMemoryDataset-specific unit tests.
+
+
+@pytest.fixture(name="dataset")
+def dataset_fixture() -> pd.DataFrame:
+    """Fixture providing with a small toy dataset."""
+    rng = np.random.default_rng(seed=SEED)
+    wgt = rng.normal(size=10).astype("float32")
+    data = {
+        "col_a": np.arange(10, dtype="float32"),
+        "col_b": rng.normal(size=10).astype("float32"),
+        "col_y": rng.choice(3, size=10, replace=True),
+        "col_w": wgt / sum(wgt),
+    }
+    return pd.DataFrame(data)
+
+
+class TestInMemoryDatasetInit:
+    """Unit tests for 'declearn.dataset.InMemoryDataset' instantiation."""
+
+    def test_from_inputs(
+        self,
+        dataset: pd.DataFrame,
+    ) -> None:
+        """Test instantiating with (x, y, w) array data."""
+        # Split data into distinct objects with various types.
+        y_dat = dataset.pop("col_y")
+        w_dat = dataset.pop("col_w").values
+        x_dat = scipy.sparse.coo_matrix(dataset.values)
+        # Test that an InMemoryDataset can be instantiated from that data.
+        dst = InMemoryDataset(data=x_dat, target=y_dat, s_wght=w_dat)
+        assert dst.feats is x_dat
+        assert dst.target is y_dat
+        assert dst.weights is w_dat
+
+    def test_from_dataframe(
+        self,
+        dataset: pd.DataFrame,
+    ) -> None:
+        """Test instantiating with a pandas DataFrame and column names."""
+        dst = InMemoryDataset(data=dataset, target="col_y", s_wght="col_w")
+        assert np.allclose(dst.feats, dataset[["col_a", "col_b"]])
+        assert np.allclose(dst.target, dataset["col_y"])  # type: ignore
+        assert np.allclose(dst.weights, dataset["col_w"])  # type: ignore
+
+    def test_from_dataframe_with_fcols_str(
+        self,
+        dataset: pd.DataFrame,
+    ) -> None:
+        """Test instantiating from a pandas Dataframe with string f_cols."""
+        dst = InMemoryDataset(
+            data=dataset, target="col_y", s_wght="col_w", f_cols=["col_a"]
+        )
+        assert np.allclose(dst.feats, dataset[["col_a"]])
+        assert np.allclose(dst.target, dataset["col_y"])  # type: ignore
+        assert np.allclose(dst.weights, dataset["col_w"])  # type: ignore
+
+    def test_from_dataframe_with_fcols_int(
+        self,
+        dataset: pd.DataFrame,
+    ) -> None:
+        """Test instantiating from a pandas Dataframe with string f_cols."""
+        dst = InMemoryDataset(
+            data=dataset, target="col_y", s_wght="col_w", f_cols=[1]
+        )
+        assert np.allclose(dst.feats, dataset[["col_b"]])
+        assert np.allclose(dst.target, dataset["col_y"])  # type: ignore
+        assert np.allclose(dst.weights, dataset["col_w"])  # type: ignore
+
+    def test_from_csv_file(
+        self,
+        dataset: pd.DataFrame,
+        tmp_path: str,
+    ) -> None:
+        """Test instantiating from a single csv file and column names."""
+        # Dump the dataset to a csv file and instantiate from it.
+        path = os.path.join(tmp_path, "dataset.csv")
+        dataset.to_csv(path, index=False)
+        dst = InMemoryDataset(data=path, target="col_y", s_wght="col_w")
+        # Test that the data matches expectations.
+        assert np.allclose(dst.feats, dataset[["col_a", "col_b"]])
+        assert np.allclose(dst.target, dataset["col_y"])  # type: ignore
+        assert np.allclose(dst.weights, dataset["col_w"])  # type: ignore
+
+    def test_from_csv_file_feats_only(
+        self,
+        dataset: pd.DataFrame,
+        tmp_path: str,
+    ) -> None:
+        """Test instantiating from a single csv file without y nor w."""
+        # Dump the dataset to a csv file and instantiate from it.
+        path = os.path.join(tmp_path, "dataset.csv")
+        dataset.to_csv(path, index=False)
+        dst = InMemoryDataset(data=path)
+        # Test that the data matches expectations.
+        assert np.allclose(dst.feats, dataset)
+        assert dst.target is None
+        assert dst.weights is None
+
+    def test_from_data_files(
+        self,
+        dataset: pd.DataFrame,
+        tmp_path: str,
+    ) -> None:
+        """Test instantiating from a collection of files."""
+        # Split data into distinct objects with various types.
+        y_dat = dataset.pop("col_y")
+        w_dat = dataset.pop("col_w").values
+        x_dat = scipy.sparse.coo_matrix(dataset.values)
+        # Save these objects to files.
+        x_path = save_data_array(os.path.join(tmp_path, "data_x"), x_dat)
+        y_path = save_data_array(os.path.join(tmp_path, "data_y"), y_dat)
+        w_path = save_data_array(os.path.join(tmp_path, "data_w"), w_dat)
+        # Test that an InMemoryDataset can be instantiated from these files.
+        dst = InMemoryDataset(data=x_path, target=y_path, s_wght=w_path)
+        assert isinstance(dst.feats, scipy.sparse.coo_matrix)
+        assert np.allclose(dst.feats.toarray(), x_dat.toarray())
+        assert isinstance(dst.target, pd.Series)
+        assert np.allclose(dst.target, y_dat)
+        assert isinstance(dst.weights, np.ndarray)
+        assert np.allclose(dst.weights, w_dat)  # type: ignore
+
+    def test_from_svmlight(
+        self,
+        dataset: pd.DataFrame,
+        tmp_path: str,
+    ) -> None:
+        """Test instantiating from a svmlight file."""
+        path = os.path.join(tmp_path, "dataset.svmlight")
+        sklearn.datasets.dump_svmlight_file(
+            scipy.sparse.coo_matrix(dataset[["col_a", "col_b"]].values),
+            dataset["col_y"].values,
+            path,
+        )
+        dst = InMemoryDataset.from_svmlight(path)
+        assert isinstance(dst.data, scipy.sparse.csr_matrix)
+        assert np.allclose(
+            dst.data.toarray(), dataset[["col_a", "col_b"]].values
+        )
+        assert isinstance(dst.target, np.ndarray)
+        assert np.allclose(dst.target, dataset["col_y"].to_numpy())
+        assert dst.weights is None
+
+
+class TestInMemoryDatasetProperties:
+    """Unit tests for 'declearn.dataset.InMemoryDataset' properties."""
+
+    def test_classes_array(
+        self,
+        dataset: pd.DataFrame,
+    ) -> None:
+        """Test (authorized) classes access with numpy array targets."""
+        dst = InMemoryDataset(
+            data=dataset, target=dataset["col_y"].values, expose_classes=True
+        )
+        assert dst.classes == {0, 1, 2}
+
+    def test_classes_series(
+        self,
+        dataset: pd.DataFrame,
+    ) -> None:
+        """Test (authorized) classes access with pandas Series targets."""
+        dst = InMemoryDataset(
+            data=dataset, target="col_y", expose_classes=True
+        )
+        assert dst.classes == {0, 1, 2}
+
+    def test_classes_dataframe(
+        self,
+        dataset: pd.DataFrame,
+    ) -> None:
+        """Test (authorized) classes access with pandas DataFrame targets."""
+        dst = InMemoryDataset(
+            data=dataset, target=dataset[["col_y"]], expose_classes=True
+        )
+        assert dst.classes == {0, 1, 2}
+
+    def test_classes_sparse(
+        self,
+        dataset: pd.DataFrame,
+    ) -> None:
+        """Test (authorized) classes access with scipy spmatrix targets."""
+        y_dat = scipy.sparse.coo_matrix(dataset[["col_y"]] + 1)
+        dst = InMemoryDataset(data=dataset, target=y_dat, expose_classes=True)
+        assert dst.classes == {1, 2, 3}
+
+    def test_data_type_dataframe(
+        self,
+        dataset: pd.DataFrame,
+    ) -> None:
+        """Test (authorized) data-type access with pandas DataFrame data."""
+        dst = InMemoryDataset(
+            data=dataset[["col_a", "col_b"]], expose_data_type=True
+        )
+        assert dst.data_type == "float32"
+
+    def test_data_type_dataframe_mixed(
+        self,
+        dataset: pd.DataFrame,
+    ) -> None:
+        """Test that an exception is raised with a mixed-type DataFrame."""
+        dst = InMemoryDataset(data=dataset, expose_data_type=True)
+        with pytest.raises(ValueError):
+            dst.data_type  # pylint: disable=pointless-statement
+
+    def test_data_type_series(
+        self,
+        dataset: pd.DataFrame,
+    ) -> None:
+        """Test (authorized) data-type access with pandas Series data."""
+        dst = InMemoryDataset(data=dataset["col_a"], expose_data_type=True)
+        assert dst.data_type == "float32"
+
+    def test_data_type_array(
+        self,
+        dataset: pd.DataFrame,
+    ) -> None:
+        """Test (authorized) data-type access with numpy array data."""
+        data = dataset[["col_a", "col_b"]].values
+        dst = InMemoryDataset(data=data, expose_data_type=True)
+        assert dst.data_type == "float32"
+
+    def test_data_type_sparse(
+        self,
+        dataset: pd.DataFrame,
+    ) -> None:
+        """Test (authorized) data-type access with scipy spmatrix data."""
+        data = scipy.sparse.coo_matrix(dataset[["col_a", "col_b"]].values)
+        dst = InMemoryDataset(data=data, expose_data_type=True)
+        assert dst.data_type == "float32"
+
+
+class TestInMemoryDatasetSaveLoad:
+    """Test JSON-file saving/loading features of InMemoryDataset."""
+
+    def test_save_load_json(
+        self,
+        dataset: pd.DataFrame,
+        tmp_path: str,
+    ) -> None:
+        """Test that a dataset can be saved to and loaded from JSON."""
+        dst = InMemoryDataset(dataset, target="col_y", s_wght="col_w")
+        # Test that the dataset can be saved to JSON.
+        path = os.path.join(tmp_path, "dataset.json")
+        dst.save_to_json(path)
+        assert os.path.isfile(path)
+        # Test that it can be reloaded from JSON.
+        bis = InMemoryDataset.load_from_json(path)
+        assert np.allclose(dst.data, bis.data)
+        assert np.allclose(dst.target, bis.target)  # type: ignore
+        assert np.allclose(dst.weights, bis.weights)  # type: ignore
+        assert dst.f_cols == bis.f_cols
+        assert dst.expose_classes == bis.expose_classes
+        assert dst.expose_data_type == bis.expose_data_type
+
+    def test_load_json_malformed(
+        self,
+        tmp_path: str,
+    ) -> None:
+        """Test with a JSON file that has nothing to do with a dataset."""
+        path = os.path.join(tmp_path, "dataset.json")
+        with open(path, "w", encoding="utf-8") as file:
+            json.dump({"not-a-dataset": "at-all"}, file)
+        with pytest.raises(KeyError):
+            InMemoryDataset.load_from_json(path)
+
+    def test_load_json_partial(
+        self,
+        tmp_path: str,
+    ) -> None:
+        """Test with a JSON file that contains a partial dataset config."""
+        path = os.path.join(tmp_path, "dataset.json")
+        with open(path, "w", encoding="utf-8") as file:
+            json.dump({"config": {"data": "mock", "target": "mock"}}, file)
+        with pytest.raises(KeyError):
+            InMemoryDataset.load_from_json(path)
diff --git a/test/functional/test_regression.py b/test/functional/test_toy_reg.py
similarity index 57%
rename from test/functional/test_regression.py
rename to test/functional/test_toy_reg.py
index ce5f805211d5fcf14208e2d22227b0ce8eee219c..469e06f3f1605cc1dd0713bdd14f58e994664dc5 100644
--- a/test/functional/test_regression.py
+++ b/test/functional/test_toy_reg.py
-The convergence results of those experiments is then compared.
+The convergence results of those experiments are then compared.
 
 """
 
+import asyncio
+import dataclasses
 import json
+import os
 import tempfile
 from typing import List, Tuple
 
 import numpy as np
-from sklearn.datasets import make_regression  # type: ignore
-from sklearn.linear_model import SGDRegressor  # type: ignore
+import pytest
+import sklearn.datasets  # type: ignore
+import sklearn.linear_model  # type: ignore
 
 from declearn.communication import NetworkClientConfig, NetworkServerConfig
-from declearn.dataset import InMemoryDataset
+from declearn.dataset import Dataset, InMemoryDataset
 from declearn.main import FederatedClient, FederatedServer
 from declearn.main.config import FLOptimConfig, FLRunConfig
 from declearn.metrics import RSquared
@@ -59,7 +63,7 @@ from declearn.model.api import Model
 from declearn.model.sklearn import SklearnSGDModel
 from declearn.optimizer import Optimizer
 from declearn.test_utils import FrameworkType
-from declearn.utils import run_as_processes, set_device_policy
+from declearn.utils import set_device_policy
 
 # optional frameworks' dependencies pylint: disable=ungrouped-imports
 # pylint: disable=duplicate-code
@@ -71,7 +75,8 @@ try:
 except ModuleNotFoundError:
     pass
 else:
-    from declearn.model.tensorflow import TensorflowModel
+    from declearn.dataset.tensorflow import TensorflowDataset
+    from declearn.model.tensorflow import TensorflowModel, TensorflowVector
 # torch imports
 try:
     import torch
@@ -79,8 +84,7 @@ except ModuleNotFoundError:
     pass
 else:
     from declearn.dataset.torch import TorchDataset
-    from declearn.model.torch import TorchModel
-# pylint: enable=duplicate-code
+    from declearn.model.torch import TorchModel, TorchVector
 # haiku and jax imports
 try:
     import haiku as hk
@@ -88,7 +92,7 @@ try:
 except ModuleNotFoundError:
     pass
 else:
-    from declearn.model.haiku import HaikuModel
+    from declearn.model.haiku import HaikuModel, JaxNumpyVector
 
     def haiku_model_fn(inputs: jax.Array) -> jax.Array:
         """Simple linear model implemented with Haiku."""
@@ -96,40 +100,81 @@ else:
 
     def haiku_loss_fn(y_pred: jax.Array, y_true: jax.Array) -> jax.Array:
         """Sample-wise squared error loss function."""
+        y_pred = jax.numpy.squeeze(y_pred)  # align shapes with 1-d targets
         return (y_pred - y_true) ** 2
 
 
+# pylint: disable=duplicate-code
+
 SEED = 0
-R2_THRESHOLD = 0.999
+R2_THRESHOLD = 0.9999
 
 
-set_device_policy(gpu=False)  # disable GPU use to avoid concurrence
+set_device_policy(gpu=False)  # disable GPU use to avoid device contention
 
 
 def get_model(framework: FrameworkType) -> Model:
-    """Set up a simple toy regression model."""
+    """Set up a simple toy regression model, with zero-valued weights."""
-    set_device_policy(gpu=False)  # disable GPU use to avoid concurrence
+    set_device_policy(gpu=False)  # disable GPU use to avoid device contention
     if framework == "numpy":
-        np.random.seed(SEED)  # set seed
-        model = SklearnSGDModel.from_parameters(
-            kind="regressor", loss="squared_error", penalty="none"
-        )  # type: Model
-    elif framework == "tensorflow":
-        tf.random.set_seed(SEED)  # set seed
-        tfmod = tf.keras.Sequential(tf.keras.layers.Dense(units=1))
-        tfmod.build([None, 100])
-        model = TensorflowModel(tfmod, loss="mean_squared_error")
-    elif framework == "torch":
-        torch.manual_seed(SEED)  # set seed
-        torchmod = torch.nn.Sequential(
-            torch.nn.Linear(100, 1, bias=True),
-            torch.nn.Flatten(0),
-        )
-        model = TorchModel(torchmod, loss=torch.nn.MSELoss())
-    elif framework == "jax":
-        model = HaikuModel(haiku_model_fn, loss=haiku_loss_fn)
-    else:
-        raise ValueError("unrecognised framework")
+        return _get_model_numpy()
+    if framework == "tensorflow":
+        return _get_model_tflow()
+    if framework == "torch":
+        return _get_model_torch()
+    if framework == "jax":
+        return _get_model_haiku()
+    raise ValueError(f"Unrecognised model framework: '{framework}'.")
+
+
+def _get_model_numpy() -> SklearnSGDModel:
+    """Return a linear model with MSE loss in Sklearn, with zero weights."""
+    np.random.seed(SEED)  # set seed
+    model = SklearnSGDModel.from_parameters(
+        kind="regressor", loss="squared_error", penalty="none"
+    )
+    return model
+
+
+def _get_model_tflow() -> TensorflowModel:
+    """Return a linear model with MSE loss in TensorFlow, with zero weights."""
+    tf.random.set_seed(SEED)  # set seed
+    tfmod = tf.keras.Sequential(tf.keras.layers.Dense(units=1))
+    tfmod.build([None, 100])
+    model = TensorflowModel(tfmod, loss="mean_squared_error")
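+    # Zero out the weights (presumably so all frameworks share a start point).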
+    zeros = {
+        key: tf.zeros_like(val)
+        for key, val in model.get_weights().coefs.items()
+    }
+    model.set_weights(TensorflowVector(zeros))
+    return model
+
+
+def _get_model_torch() -> TorchModel:
+    """Return a linear model with MSE loss in Torch, with zero weights."""
+    torch.manual_seed(SEED)  # set seed
+    torchmod = torch.nn.Sequential(
+        torch.nn.Linear(100, 1, bias=True),
+        torch.nn.Flatten(0),
+    )
+    model = TorchModel(torchmod, loss=torch.nn.MSELoss())
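+    # Zero out the weights, as in the TensorFlow case above.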
+    zeros = {
+        key: torch.zeros_like(val)
+        for key, val in model.get_weights().coefs.items()
+    }
+    model.set_weights(TorchVector(zeros))
+    return model
+
+
+def _get_model_haiku() -> HaikuModel:
+    """Return a linear model with MSE loss in Haiku, with zero weights."""
+    model = HaikuModel(haiku_model_fn, loss=haiku_loss_fn)
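+    # Initialize the model, as Haiku requires it before weights can be set.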
+    model.initialize({"data_type": "float32", "features_shape": (100,)})
+    zeros = {
+        key: jax.numpy.zeros_like(val)
+        for key, val in model.get_weights().coefs.items()
+    }
+    model.set_weights(JaxNumpyVector(zeros))
     return model
 
 
@@ -139,48 +184,13 @@ def get_dataset(framework: FrameworkType, inputs, labels):
         inputs = torch.from_numpy(inputs)
         labels = torch.from_numpy(labels)
         return TorchDataset(torch.utils.data.TensorDataset(inputs, labels))
-    return InMemoryDataset(inputs, labels)
-
-
-def prep_client_datasets(
-    framework: FrameworkType,
-    clients: int = 3,
-    n_train: int = 100,
-    n_valid: int = 50,
-) -> List[Tuple[InMemoryDataset, InMemoryDataset]]:
-    """Generate and split toy data for a regression problem.
-
-    Parameters
-    ----------
-    clients: int, default=3
-        Number of clients, i.e. of dataset shards to generate.
-    n_train: int, default=30
-        Number of training samples per client.
-    n_valid: int, default=30
-        Number of validation samples per client.
-
-    Returns
-    -------
-    datasets: list[(InMemoryDataset, InMemoryDataset)]
-        List of client-wise (train, valid) pair of datasets.
-    """
-
-    n_samples = clients * (n_train + n_valid)
-    # false-positive; pylint: disable=unbalanced-tuple-unpacking
-    inputs, target = make_regression(
-        n_samples, n_features=100, n_informative=10, random_state=SEED
-    )
-    inputs, target = inputs.astype("float32"), target.astype("float32")
-    # Wrap up the data into client-wise pairs of dataset.
-    out = []  # type: List[Tuple[InMemoryDataset, InMemoryDataset]]
-    for idx in range(clients):
-        start = (n_train + n_valid) * idx
-        mid = start + n_train
-        end = mid + n_valid
-        train = get_dataset(framework, inputs[start:mid], target[start:mid])
-        valid = get_dataset(framework, inputs[mid:end], target[mid:end])
-        out.append((train, valid))
-    return out
+    if framework == "tensorflow":
+        inputs = tf.convert_to_tensor(inputs)
+        labels = tf.convert_to_tensor(labels)
+        return TensorflowDataset(
+            tf.data.Dataset.from_tensor_slices((inputs, labels))
+        )
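+    # Expose data-type info, presumably needed by the sklearn-based model.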
+    return InMemoryDataset(inputs, labels, expose_data_type=True)
 
 
 def prep_full_dataset(
@@ -204,7 +214,7 @@ def prep_full_dataset(
     """
     n_samples = n_train + n_valid
     # false-positive; pylint: disable=unbalanced-tuple-unpacking
-    inputs, target = make_regression(
+    inputs, target = sklearn.datasets.make_regression(
         n_samples, n_features=100, n_informative=10, random_state=SEED
     )
     inputs, target = inputs.astype("float32"), target.astype("float32")
@@ -216,62 +226,150 @@ def prep_full_dataset(
     return out
 
 
-def test_declearn_experiment(
+def test_sklearn_baseline(
+    lrate: float = 0.04,
+    rounds: int = 10,
+    b_size: int = 10,
+) -> None:
+    """Run a baseline using scikit-learn to emulate a centralized setting.
+
+    This function does not use declearn. It sets up a single sklearn
+    model and performs training on the full dataset.
+
+    Parameters
+    ----------
+    lrate: float, default=0.04
+        Learning rate of the SGD algorithm.
+    rounds: int, default=10
+        Number of training rounds to perform, i.e. number of epochs.
+    b_size: int, default=10
+        Batch size for the training (and validation) data.
+        Batching will be performed without shuffling nor replacement,
+        and the final batch may be smaller than the others (no drop).
+    """
+    # Generate the full toy regression dataset, as numpy arrays.
+    train, valid = prep_full_dataset()
+    # Set up a scikit-learn model, implementing step-wise gradient descent.
+    sgd = sklearn.linear_model.SGDRegressor(
+        loss="squared_error",
+        penalty="l1",
+        alpha=0.1,
+        eta0=lrate / b_size,  # sklearn steps per sample; declearn per batch
+        learning_rate="constant",  # disable scaling, unused in declearn
+        max_iter=rounds,
+    )
+    # Iteratively train the model; each 'partial_fit' call runs one epoch.
+    for _ in range(rounds):
+        sgd.partial_fit(train[0], train[1])
+    assert sgd.score(valid[0], valid[1]) > R2_THRESHOLD
+
+
+def test_declearn_baseline(
     framework: FrameworkType,
-    lrate: float = 0.01,
+    lrate: float = 0.02,
     rounds: int = 10,
-    b_size: int = 1,
-    clients: int = 3,
+    b_size: int = 10,
 ) -> None:
-    """Run an experiment using declearn to perform a federative training.
+    """Run a baseline uing declearn APIs to emulate a centralized setting.
 
-    This function runs the experiment using declearn.
-    It sets up and runs the server and client-wise routines in separate
-    processes, to enable their concurrent execution.
+    This function uses declearn but sets up a single model and performs
+    training on the entire toy regression dataset.
 
     Parameters
     ----------
     framework: str
         Framework of the model to train and evaluate.
-    lrate: float, default=0.01
-        Learning rate of the SGD algorithm performed by clients.
+    lrate: float, default=0.02
+        Learning rate of the SGD algorithm.
     rounds: int, default=10
-        Number of FL training rounds to perform.
-        At each round, each client will perform a full epoch of training.
+        Number of training rounds to perform, i.e. number of epochs.
     b_size: int, default=10
-        Batch size fot the training (and validation) data.
+        Batch size for the training (and validation) data.
         Batching will be performed without shuffling nor replacement,
         and the final batch may be smaller than the others (no drop).
-    clients: int, default=3
-        Number of federated clients to set up and run.
     """
-    # pylint: disable=too-many-locals
-    with tempfile.TemporaryDirectory() as folder:
-        # Set up a (func, args) tuple specifying the server process.
-        p_server = (
-            _server_routine,
-            (folder, framework, lrate, rounds, b_size, clients),
+    # Generate the full toy regression dataset, as numpy arrays.
+    train, valid = prep_full_dataset()
+    dst_train = get_dataset(framework, *train)
+    # Set up a declearn model and a SGD optimizer with Lasso regularization.
+    model = get_model(framework)
+    model.initialize(dataclasses.asdict(dst_train.get_data_specs()))
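+    # NOTE: the doubled rate for the sklearn-based model below presumably
+    # offsets sklearn's halved squared-error loss (0.5 * (y - p) ** 2).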
+    optim = Optimizer(
+        lrate=lrate if framework != "numpy" else (lrate * 2),
+        regularizers=[("lasso", {"alpha": 0.1})],
+    )
+    # Iteratively train the model and evaluate it between rounds.
+    r_sq = RSquared()
+    scores = []  # type: List[float]
+    for _ in range(rounds):
+        for batch in dst_train.generate_batches(
+            batch_size=b_size, drop_remainder=False
+        ):
+            optim.run_train_step(model, batch)
+        pred = model.compute_batch_predictions((*valid, None))
+        r_sq.reset()
+        r_sq.update(*pred)
+        scores.append(r_sq.get_result()["r2"])  # type: ignore
+    # Check that the R2 increased through epochs to reach a high value.
+    assert all(scores[i + 1] >= scores[i] for i in range(rounds - 1))
+    assert scores[-1] >= R2_THRESHOLD
+
+
+def prep_client_datasets(
+    framework: FrameworkType,
+    clients: int = 3,
+    n_train: int = 100,
+    n_valid: int = 50,
+) -> List[Tuple[Dataset, Dataset]]:
+    """Generate and split data for a toy sparse regression problem.
+
+    Parameters
+    ----------
+    framework:
+        Name of the framework being tested, which determines the type
+        of Dataset instances being returned.
+    clients:
+        Number of clients, i.e. of dataset shards to generate.
+    n_train:
+        Number of training samples per client.
+    n_valid:
+        Number of validation samples per client.
+
+    Returns
+    -------
+    datasets:
+        List of client-wise (train, valid) pairs of Dataset instances.
+    """
+    train, valid = prep_full_dataset(
+        n_train=clients * n_train,
+        n_valid=clients * n_valid,
+    )
+    # Wrap up the data into client-wise pairs of dataset.
+    out = []  # type: List[Tuple[Dataset, Dataset]]
+    for idx in range(clients):
+        # Gather the client's training dataset.
+        srt = n_train * idx
+        end = n_train + srt
+        dst_train = get_dataset(
+            framework=framework,
+            inputs=train[0][srt:end],
+            labels=train[1][srt:end],
         )
-        # Set up the (func, args) tuples specifying client-wise processes.
-        datasets = prep_client_datasets(framework, clients)
-        p_client = []
-        for i, data in enumerate(datasets):
-            client = (_client_routine, (data[0], data[1], f"client_{i}"))
-            p_client.append(client)
-        # Run each and every process in parallel.
-        success, outputs = run_as_processes(p_server, *p_client)
-        assert success, "The FL process failed:\n" + "\n".join(
-            str(exc) for exc in outputs if isinstance(exc, RuntimeError)
+        # Gather the client's validation dataset.
+        srt = n_valid * idx
+        end = n_valid + srt
+        dst_valid = get_dataset(
+            framework=framework,
+            inputs=valid[0][srt:end],
+            labels=valid[1][srt:end],
         )
-        # Assert convergence
-        with open(f"{folder}/metrics.json", encoding="utf-8") as file:
-            r2_dict = json.load(file)
-            last_r2_dict = r2_dict.get(max(r2_dict.keys()))
-            final_r2 = float(last_r2_dict.get("r2"))
-            assert final_r2 > R2_THRESHOLD, "The FL training did not converge"
+        # Store both datasets into the output list.
+        out.append((dst_train, dst_valid))
+    return out
 
 
-def _server_routine(
+async def async_run_server(
     folder: str,
     framework: FrameworkType,
     lrate: float = 0.01,
@@ -289,7 +387,7 @@ def _server_routine(
     optim = FLOptimConfig.from_params(
         aggregator="averaging",
         client_opt={
-            "lrate": lrate,
+            "lrate": lrate if framework != "numpy" else (lrate * 2),
             "regularizers": [("lasso", {"alpha": 0.1})],
         },
         server_opt=1.0,
@@ -299,24 +397,24 @@ def _server_routine(
         netwk,
         optim,
         metrics=["r2"],
-        checkpoint=folder,
+        checkpoint={"folder": folder, "max_history": 1},
     )
     # Set up hyper-parameters and run training.
     config = FLRunConfig.from_params(
         rounds=rounds,
-        register={"min_clients": clients},
+        register={"min_clients": clients, "timeout": 10},
         training={
             "n_epoch": 1,
             "batch_size": b_size,
             "drop_remainder": False,
         },
     )
-    server.run(config)
+    await server.async_run(config)
 
 
-def _client_routine(
-    train: InMemoryDataset,
-    valid: InMemoryDataset,
+async def async_run_client(
+    train: Dataset,
+    valid: Dataset,
     name: str = "client",
 ) -> None:
     """Routine to run a FL client, called by `run_declearn_experiment`."""
@@ -324,83 +422,62 @@ def _client_routine(
         protocol="websockets", server_uri="ws://localhost:8765", name=name
     )
     client = FederatedClient(netwk, train, valid)
-    client.run()
-
-
-def test_declearn_baseline(
-    lrate: float = 0.01,
-    rounds: int = 10,
-    b_size: int = 1,
-) -> None:
-    """Run a baseline uing declearn APIs to emulate a centralized setting.
+    await client.async_run()
 
-    This function uses declearn but sets up a single model and performs
-    training on the concatenation of "client-wise" datasets.
 
-    Parameters
-    ----------
-    lrate: float, default=0.01
-        Learning rate of the SGD algorithm.
-    rounds: int, default=10
-        Number of training rounds to perform, i.e. number of epochs.
-    b_size: int, default=10
-        Batch size fot the training (and validation) data.
-        Batching will be performed without shuffling nor replacement,
-        and the final batch may be smaller than the others (no drop).
-    """
-    # Generate the client datasets, then centralize them into numpy arrays.
-    train, valid = prep_full_dataset()
-    d_train = InMemoryDataset(train[0], train[1])
-    # Set up a declearn model and a vanilla SGD optimizer.
-    model = get_model("numpy")
-    model.initialize({"features_shape": (d_train.data.shape[1],)})
-    opt = Optimizer(lrate=lrate, regularizers=[("lasso", {"alpha": 0.1})])
-    # Iteratively train the model, evaluating it after each epoch.
-    for _ in range(rounds):
-        # Run the training round.
-        for batch in d_train.generate_batches(batch_size=b_size):
-            grads = model.compute_batch_gradients(batch)
-            opt.apply_gradients(model, grads)
-    # Check the final R2 value.
-    r_sq = RSquared()
-    r_sq.update(*model.compute_batch_predictions((valid[0], valid[1], None)))
-    assert r_sq.get_result()["r2"] > R2_THRESHOLD
-
-
-def test_sklearn_baseline(
+@pytest.mark.asyncio
+async def test_declearn_federated(
+    framework: FrameworkType,
     lrate: float = 0.01,
     rounds: int = 10,
     b_size: int = 1,
+    clients: int = 3,
 ) -> None:
-    """Run a baseline using scikit-learn to emulate a centralized setting.
+    """Run an experiment using declearn to perform a federative training.
 
-    This function does not use declearn. It sets up a single sklearn
-    model and performs training on the full dataset.
+    This function runs the experiment using declearn.
+    It sets up and runs the server and client-wise routines as concurrent
+    asyncio coroutines within a single process.
 
     Parameters
     ----------
+    framework: str
+        Framework of the model to train and evaluate.
     lrate: float, default=0.01
-        Learning rate of the SGD algorithm.
+        Learning rate of the SGD algorithm performed by clients.
     rounds: int, default=10
-        Number of training rounds to perform, i.e. number of epochs.
+        Number of FL training rounds to perform.
+        At each round, each client will perform a full epoch of training.
-    b_size: int, default=10
-        Batch size fot the training (and validation) data.
+    b_size: int, default=1
+        Batch size for the training (and validation) data.
         Batching will be performed without shuffling nor replacement,
         and the final batch may be smaller than the others (no drop).
+    clients: int, default=3
+        Number of federated clients to set up and run.
     """
-    # Generate the client datasets, then centralize them into numpy arrays.
-
-    train, valid = prep_full_dataset()
-    # Set up a scikit-learn model, implementing step-wise gradient descent.
-    sgd = SGDRegressor(
-        loss="squared_error",
-        penalty="l1",
-        alpha=0.1,
-        eta0=lrate / b_size,  # adjust learning rate for (dropped) batch size
-        learning_rate="constant",  # disable scaling, unused in declearn
-        max_iter=rounds,
-    )
-    # Iteratively train the model, evaluating it after each epoch.
-    for _ in range(rounds):
-        sgd.partial_fit(train[0], train[1])
-    assert sgd.score(valid[0], valid[1]) > R2_THRESHOLD
+    datasets = prep_client_datasets(framework, clients)
+    with tempfile.TemporaryDirectory() as folder:
+        # Set up the server and client coroutines.
+        coro_server = async_run_server(
+            folder, framework, lrate, rounds, b_size, clients
+        )
+        coro_clients = [
+            async_run_client(train, valid, name=f"client_{i}")
+            for i, (train, valid) in enumerate(datasets)
+        ]
+        # Run the coroutines concurrently using asyncio.
+        outputs = await asyncio.gather(
+            coro_server, *coro_clients, return_exceptions=True
+        )
+        # Assert that no exceptions occurred during the process.
+        errors = "\n".join(
+            repr(exc) for exc in outputs if isinstance(exc, Exception)
+        )
+        assert not errors, f"The FL process failed:\n{errors}"
+        # Assert that the federated model converged above an expected value.
+        with open(
+            os.path.join(folder, "metrics.json"), encoding="utf-8"
+        ) as file:
+            metrics = json.load(file)
+        best_r2 = max(values["r2"] for values in metrics.values())
+        assert best_r2 >= R2_THRESHOLD, "The FL training did not converge"
diff --git a/test/main/test_data_info.py b/test/main/test_data_info.py
new file mode 100644
index 0000000000000000000000000000000000000000..de7aa14761870e9a8b0024a21bd69cdccfce1804
--- /dev/null
+++ b/test/main/test_data_info.py
@@ -0,0 +1,95 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for 'declearn.main.utils.aggregate_clients_data_info'."""
+
+from unittest import mock
+
+import pytest
+
+from declearn.main.utils import AggregationError, aggregate_clients_data_info
+
+
+class TestAggregateClientsDataInfo:
+    """Unit tests for 'declearn.main.utils.aggregate_clients_data_info'."""
+
+    def test_with_valid_inputs(self) -> None:
+        """Test 'aggregate_clients_data_info' with valid inputs."""
+        clients_data_info = {
+            "client_a": {"n_samples": 10},
+            "client_b": {"n_samples": 32},
+        }
+        results = aggregate_clients_data_info(clients_data_info, {"n_samples"})
+        assert results == {"n_samples": 42}
+
+    def test_with_missing_fields(self) -> None:
+        """Test 'aggregate_clients_data_info' with some missing fields."""
+        clients_data_info = {
+            "client_a": {"n_samples": 10},
+            "client_b": {"n_samples": 32},
+        }
+        with pytest.raises(AggregationError):
+            aggregate_clients_data_info(
+                clients_data_info,
+                required_fields={"n_samples", "features_shape"},
+            )
+
+    def test_with_invalid_values(self) -> None:
+        """Test 'aggregate_clients_data_info' with some invalid values."""
+        clients_data_info = {
+            "client_a": {"n_samples": 10},
+            "client_b": {"n_samples": -1},
+        }
+        with pytest.raises(AggregationError):
+            aggregate_clients_data_info(
+                clients_data_info, required_fields={"n_samples"}
+            )
+
+    def test_with_incompatible_values(self) -> None:
+        """Test 'aggregate_clients_data_info' with some incompatible values."""
+        clients_data_info = {
+            "client_a": {"features_shape": (100,)},
+            "client_b": {"features_shape": (128,)},
+        }
+        with pytest.raises(AggregationError):
+            aggregate_clients_data_info(
+                clients_data_info, required_fields={"features_shape"}
+            )
+
+    def test_with_unexpected_keyerror(self) -> None:
+        """Test 'aggregate_clients_data_info' with an unforeseen KeyError."""
+        with mock.patch(
+            "declearn.main.utils._data_info.aggregate_data_info"
+        ) as patch_agg:
+            patch_agg.side_effect = KeyError("Forced KeyError")
+            with pytest.raises(AggregationError):
+                aggregate_clients_data_info(
+                    clients_data_info={"client_a": {}, "client_b": {}},
+                    required_fields=set(),
+                )
+
+    def test_with_unexpected_exception(self) -> None:
+        """Test 'aggregate_clients_data_info' with an unforeseen Exception."""
+        with mock.patch(
+            "declearn.main.utils._data_info.aggregate_data_info"
+        ) as patch_agg:
+            patch_agg.side_effect = Exception("Forced Exception")
+            with pytest.raises(AggregationError):
+                aggregate_clients_data_info(
+                    clients_data_info={"client_a": {}, "client_b": {}},
+                    required_fields=set(),
+                )
diff --git a/test/main/test_early_stopping.py b/test/main/test_early_stopping.py
new file mode 100644
index 0000000000000000000000000000000000000000..764b59ff75b96bec975a1cc067f85b13394bc247
--- /dev/null
+++ b/test/main/test_early_stopping.py
@@ -0,0 +1,108 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for 'declearn.main.utils.EarlyStopping'."""
+
+
+from declearn.main.utils import EarlyStopping
+
+
+class TestEarlyStopping:
+    """Unit tests for 'declearn.main.utils.EarlyStopping'."""
+
+    def test_keep_training_initial(self) -> None:
+        """Test that a brand new EarlyStopping indicates to train."""
+        early_stop = EarlyStopping()
+        assert early_stop.keep_training
+
+    def test_update_first(self) -> None:
+        """Test that an instantiated EarlyStopping's update works."""
+        early_stop = EarlyStopping()
+        keep_going = early_stop.update(1.0)
+        assert keep_going
+        assert keep_going == early_stop.keep_training
+
+    def test_update_twice(self) -> None:
+        """Test that an EarlyStopping can be reached in a simple case."""
+        early_stop = EarlyStopping(tolerance=0.0, patience=1, decrease=True)
+        assert early_stop.update(1.0)
+        assert not early_stop.update(1.0)
+        assert not early_stop.keep_training
+
+    def test_reset_after_stopping(self) -> None:
+        """Test that 'EarlyStopping.reset()' works properly."""
+        # Reach the criterion once.
+        early_stop = EarlyStopping(tolerance=0.0, patience=1, decrease=True)
+        assert early_stop.update(1.0)
+        assert not early_stop.update(1.0)
+        assert not early_stop.keep_training
+        # Reset and test that the criterion has been properly reset.
+        early_stop.reset()
+        assert early_stop.keep_training
+        assert early_stop.update(1.0)
+        # Reach the criterion for the second time.
+        assert not early_stop.update(1.0)
+        assert not early_stop.keep_training
+
+    def test_with_two_steps_patience(self) -> None:
+        """Test an EarlyStopping criterion with 2-steps patience."""
+        early_stop = EarlyStopping(tolerance=0.0, patience=2, decrease=True)
+        assert early_stop.update(1.0)
+        assert early_stop.update(1.5)  # patience tempers stopping
+        assert early_stop.update(0.0)  # patience is reset
+        assert early_stop.update(0.5)  # patience tempers stopping
+        assert not early_stop.update(0.2)  # patience is exhausted
+
+    def test_with_absolute_tolerance_positive(self) -> None:
+        """Test an EarlyStopping criterion with 0.2 absolute tolerance."""
+        early_stop = EarlyStopping(tolerance=0.2, patience=1, decrease=True)
+        assert early_stop.update(1.0)
+        assert early_stop.update(0.7)
+        assert not early_stop.update(0.6)  # progress below tolerance
+
+    def test_with_absolute_tolerance_negative(self) -> None:
+        """Test an EarlyStopping criterion with -0.5 absolute tolerance."""
+        early_stop = EarlyStopping(tolerance=-0.5, patience=1, decrease=True)
+        assert early_stop.update(1.0)
+        assert early_stop.update(1.2)  # regression below tolerance
+        assert not early_stop.update(1.6)  # regression above tolerance
+
+    def test_with_relative_tolerance_positive(self) -> None:
+        """Test an EarlyStopping criterion with 0.1 relative tolerance."""
+        early_stop = EarlyStopping(
+            tolerance=0.1, patience=1, decrease=True, relative=True
+        )
+        assert early_stop.update(1.0)
+        assert early_stop.update(0.8)  # progress above tolerance
+        assert not early_stop.update(0.75)  # progress below tolerance
+
+    def test_with_relative_tolerance_negative(self) -> None:
+        """Test an EarlyStopping criterion with -0.1 relative tolerance."""
+        early_stop = EarlyStopping(
+            tolerance=-0.1, patience=1, decrease=True, relative=True
+        )
+        assert early_stop.update(1.0)
+        assert early_stop.update(0.80)  # progress
+        assert early_stop.update(0.85)  # regression below tolerance
+        assert not early_stop.update(0.89)  # regression above tolerance
+
+    def test_with_increasing_metric(self) -> None:
+        """Test an EarlyStopping that monitors an increasing metric."""
+        early_stop = EarlyStopping(tolerance=0.0, patience=1, decrease=False)
+        assert early_stop.update(1.0)
+        assert early_stop.update(2.0)  # progress
+        assert not early_stop.update(1.5)  # regression (no patience/tolerance)
diff --git a/test/metrics/test_mae_mse.py b/test/metrics/test_mae_mse.py
index cd55e880d6561de0d5b33636be8cb57e4f21fdb0..de472efdf1fa668752e61e26b14c6ff9b187a29c 100644
--- a/test/metrics/test_mae_mse.py
+++ b/test/metrics/test_mae_mse.py
@@ -24,7 +24,7 @@ import numpy as np
 import pytest
 
 from declearn.metrics import MeanAbsoluteError, MeanSquaredError, Metric
-from declearn.test_utils import make_importable
+from declearn.test_utils import assert_dict_equal, make_importable
 
 # relative imports from `metric_testing.py`
 with make_importable(os.path.dirname(__file__)):
@@ -69,7 +69,7 @@ class MeanMetricTestSuite(MetricTestSuite):
     """Unit tests suite for `MeanMetric` subclasses."""
 
     def test_update_errors(self, test_case: MetricTestCase) -> None:
-        """Test that `update` raises on improper `s_wght` shapes."""
+        """Test that `update` raises on improper input shapes."""
         metric = test_case.metric
         inputs = test_case.inputs
         # Test with multi-dimensional sample weights.
@@ -80,12 +80,34 @@ class MeanMetricTestSuite(MetricTestSuite):
         s_wght = np.ones(shape=(len(inputs["y_pred"]) + 2,))
         with pytest.raises(ValueError):
             metric.update(inputs["y_true"], inputs["y_pred"], s_wght)
+        # Test with mismatching-shape inputs (and no sample weights, so
+        # that the error cannot come from the previous 's_wght').
+        y_true = inputs["y_true"]
+        y_pred = np.stack([inputs["y_pred"], inputs["y_pred"]], axis=-1)
+        with pytest.raises(ValueError):
+            metric.update(y_true, y_pred, s_wght=None)
 
     def test_zero_result(self, test_case: MetricTestCase) -> None:
         """Test that `get_results` works with zero-valued divisor."""
         metric = test_case.metric
         assert metric.get_result() == {metric.name: 0.0}
 
+    def test_update_expanded_shape(self, test_case: MetricTestCase) -> None:
+        """Test that the metric supports expanded-dim input predictions."""
+        # Gather states with basic inputs.
+        metric, inputs = test_case.metric, test_case.inputs
+        metric.update(**inputs)
+        states = metric.get_states()
+        metric.reset()
+        # Do the same with expanded-dim predictions.
+        metric.update(
+            inputs["y_true"],
+            np.expand_dims(inputs["y_pred"], -1),
+            inputs.get("s_wght"),
+        )
+        st_bis = metric.get_states()
+        # Verify that results are the same.
+        assert_dict_equal(states, st_bis)
+
 
 @pytest.mark.parametrize("weighted", [False, True], ids=["base", "weighted"])
 @pytest.mark.parametrize("case", ["mae"])
diff --git a/test/optimizer/test_modules.py b/test/optimizer/test_modules.py
index 38c209380273f46bed6e44707e20edf53f95b8f1..479dec67480787ba9b9dc8aca0a88df6aea56d48 100644
--- a/test/optimizer/test_modules.py
+++ b/test/optimizer/test_modules.py
@@ -38,6 +38,7 @@ import os
 from typing import Type
 
 import pytest
+from declearn.optimizer import list_optim_modules
 from declearn.optimizer.modules import NoiseModule, OptiModule
 from declearn.test_utils import (
     FrameworkType,
@@ -46,7 +47,7 @@ from declearn.test_utils import (
     assert_json_serializable_dict,
     make_importable,
 )
-from declearn.utils import access_types_mapping, set_device_policy
+from declearn.utils import set_device_policy
 
 # relative imports from `optim_testing.py`
 with make_importable(os.path.dirname(__file__)):
@@ -54,7 +55,7 @@ with make_importable(os.path.dirname(__file__)):
 
 
 # Access the list of modules to test; remove some that have dedicated tests.
-OPTIMODULE_SUBCLASSES = access_types_mapping(group="OptiModule")
+OPTIMODULE_SUBCLASSES = list_optim_modules()
 OPTIMODULE_SUBCLASSES.pop("tensorflow-optim", None)
 OPTIMODULE_SUBCLASSES.pop("torch-optim", None)
 
diff --git a/test/optimizer/test_regularizers.py b/test/optimizer/test_regularizers.py
index 424a9df2af1c6db551a2b18c550e3ff7412cea39..9b5c21e34a18e8d00328e28330b32644082cb117 100644
--- a/test/optimizer/test_regularizers.py
+++ b/test/optimizer/test_regularizers.py
@@ -38,16 +38,16 @@ from typing import Type
 
 import pytest
 
+from declearn.optimizer import list_optim_regularizers
 from declearn.optimizer.regularizers import Regularizer
 from declearn.test_utils import make_importable
-from declearn.utils import access_types_mapping
 
 # relative imports from `optim_testing.py`
 with make_importable(os.path.dirname(__file__)):
     from optim_testing import PluginTestBase
 
 
-REGULARIZER_SUBCLASSES = access_types_mapping(group="Regularizer")
+REGULARIZER_SUBCLASSES = list_optim_regularizers()
 
 
 @pytest.mark.parametrize(
diff --git a/test/quickrun/test_quickrun_utils.py b/test/quickrun/test_quickrun_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bec273314210fc8e8fd14d20a35a845bcf632d2e
--- /dev/null
+++ b/test/quickrun/test_quickrun_utils.py
@@ -0,0 +1,266 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for some 'declearn.quickrun' backend utils."""
+
+import os
+import pathlib
+from typing import List
+
+import pytest
+
+from declearn.quickrun import parse_data_folder
+from declearn.quickrun._config import DataSourceConfig
+from declearn.quickrun._parser import (
+    get_data_folder_path,
+    list_client_names,
+)
+from declearn.quickrun._run import get_toml_folder
+
+
+class TestGetTomlFolder:
+    """Tests for 'declearn.quickrun._run.get_toml_folder'."""
+
+    def test_get_toml_folder_from_file(
+        self,
+        tmp_path: pathlib.Path,
+    ) -> None:
+        """Test that 'get_toml_folder' works with a TOML file path."""
+        config = os.path.join(tmp_path, "config.toml")
+        with open(config, "w", encoding="utf-8") as file:
+            file.write("")
+        toml, folder = get_toml_folder(config)
+        assert toml == config
+        assert folder == tmp_path.as_posix()
+
+    def test_get_toml_folder_from_folder(
+        self,
+        tmp_path: pathlib.Path,
+    ) -> None:
+        """Test that 'get_toml_folder' works with a folder path."""
+        config = os.path.join(tmp_path, "config.toml")
+        with open(config, "w", encoding="utf-8") as file:
+            file.write("")
+        toml, folder = get_toml_folder(tmp_path.as_posix())
+        assert toml == config
+        assert folder == tmp_path.as_posix()
+
+    def test_get_toml_folder_from_file_fails(
+        self,
+        tmp_path: pathlib.Path,
+    ) -> None:
+        """Test that it fails with a path to a non-existing file."""
+        config = os.path.join(tmp_path, "config.toml")
+        with pytest.raises(FileNotFoundError):
+            get_toml_folder(config)
+
+    def test_get_toml_folder_from_folder_fails(
+        self,
+        tmp_path: pathlib.Path,
+    ) -> None:
+        """Test that it fails with a folder lacking a 'config.toml' file."""
+        with pytest.raises(FileNotFoundError):
+            get_toml_folder(tmp_path.as_posix())
+
+
+class TestGetDataFolderPath:
+    """Tests for 'declearn.quickrun._parser.get_data_folder_path'."""
+
+    def test_get_data_folder_path_from_data_folder(
+        self,
+        tmp_path: pathlib.Path,
+    ) -> None:
+        """Test that it works with a valid 'data_folder' argument."""
+        path = get_data_folder_path(data_folder=tmp_path.as_posix())
+        assert isinstance(path, pathlib.Path)
+        assert path == tmp_path
+
+    def test_get_data_folder_path_from_root_folder(
+        self,
+        tmp_path: pathlib.Path,
+    ) -> None:
+        """Test that it works with a valid 'root_folder' argument."""
+        data_dir = os.path.join(tmp_path, "data")
+        os.makedirs(data_dir)
+        path = get_data_folder_path(root_folder=tmp_path.as_posix())
+        assert isinstance(path, pathlib.Path)
+        assert path.as_posix() == data_dir
+
+    def test_get_data_folder_path_fails_no_inputs(
+        self,
+    ) -> None:
+        """Test that it fails with no folder specification."""
+        with pytest.raises(ValueError):
+            get_data_folder_path(None, None)
+
+    def test_get_data_folder_path_fails_invalid_data_folder(
+        self,
+        tmp_path: pathlib.Path,
+    ) -> None:
+        """Test that it fails with an invalid data_folder."""
+        missing_folder = os.path.join(tmp_path, "data")
+        with pytest.raises(ValueError):
+            get_data_folder_path(data_folder=missing_folder)
+
+    def test_get_data_folder_path_fails_from_root_no_data(
+        self,
+        tmp_path: pathlib.Path,
+    ) -> None:
+        """Test that it fails with an invalid root_folder (no data)."""
+        with pytest.raises(ValueError):
+            get_data_folder_path(root_folder=tmp_path.as_posix())
+
+    def test_get_data_folder_path_fails_from_root_multiple_data(
+        self,
+        tmp_path: pathlib.Path,
+    ) -> None:
+        """Test that it fails with multiple data* under root_folder."""
+        os.makedirs(os.path.join(tmp_path, "data_1"))
+        os.makedirs(os.path.join(tmp_path, "data_2"))
+        with pytest.raises(ValueError):
+            get_data_folder_path(root_folder=tmp_path.as_posix())
+
+
+class TestListClientNames:
+    """Tests for the 'declearn.quickrun._parser.list_client_names' function."""
+
+    def test_list_client_names_from_folder(
+        self,
+        tmp_path: pathlib.Path,
+    ) -> None:
+        """Test that it works with a data folder."""
+        os.makedirs(os.path.join(tmp_path, "client_1"))
+        os.makedirs(os.path.join(tmp_path, "client_2"))
+        names = list_client_names(tmp_path)
+        assert isinstance(names, list)
+        assert sorted(names) == ["client_1", "client_2"]
+
+    def test_list_client_names_from_names(
+        self,
+        tmp_path: pathlib.Path,
+    ) -> None:
+        """Test that it works with pre-specified names."""
+        os.makedirs(os.path.join(tmp_path, "client_1"))
+        os.makedirs(os.path.join(tmp_path, "client_2"))
+        names = list_client_names(tmp_path, ["client_2"])
+        assert names == ["client_2"]
+
+    def test_list_client_names_fails_invalid_names(
+        self,
+        tmp_path: pathlib.Path,
+    ) -> None:
+        """Test that it works with invalid pre-specified names."""
+        with pytest.raises(ValueError):
+            list_client_names(tmp_path, "invalid-type")  # type: ignore
+        with pytest.raises(ValueError):
+            list_client_names(tmp_path, ["client_2"])
+
+
+class TestParseDataFolder:
+    """Docstring."""
+
+    @staticmethod
+    def setup_data_folder(
+        data_folder: str,
+        client_names: List[str],
+        file_names: List[str],
+    ) -> None:
+        """Set up a data folder, with client subfolders and empty files."""
+        for cname in client_names:
+            folder = os.path.join(data_folder, cname)
+            os.makedirs(folder)
+            for fname in file_names:
+                path = os.path.join(folder, fname)
+                with open(path, "w", encoding="utf-8") as file:
+                    file.write("")
+
+    def test_parse_data_folder_with_default_names(
+        self,
+        tmp_path: pathlib.Path,
+    ) -> None:
+        """Test 'parse_data_folder' with default file names."""
+        # Setup a data folder with a couple of clients and default files.
+        data_folder = tmp_path.as_posix()
+        client_names = ["client-1", "client-2"]
+        file_names = [
+            # fmt: off
+            "train_data", "train_target", "valid_data", "valid_target"
+        ]
+        self.setup_data_folder(data_folder, client_names, file_names)
+        # Write up the expected outputs.
+        expected = {
+            cname: {
+                fname: os.path.join(data_folder, cname, fname)
+                for fname in file_names
+            }
+            for cname in client_names
+        }
+        # Run the function and validate its outputs.
+        config = DataSourceConfig(
+            data_folder=data_folder,
+            client_names=None,
+            dataset_names=None,
+        )
+        clients = parse_data_folder(config)
+        assert clients == expected
+
+    def test_parse_data_folder_with_custom_names(
+        self,
+        tmp_path: pathlib.Path,
+    ) -> None:
+        """Test 'parse_data_folder' with custom file names."""
+        # Setup a data folder with a couple of clients and default files.
+        data_folder = tmp_path.as_posix()
+        client_names = ["client-1", "client-2"]
+        base_names = [
+            # fmt: off
+            "train_data", "train_target", "valid_data", "valid_target"
+        ]
+        file_names = ["x_train", "y_train", "x_valid", "y_valid"]
+        self.setup_data_folder(data_folder, client_names, file_names)
+        # Write up the expected outputs.
+        expected = {
+            cname: {
+                bname: os.path.join(data_folder, cname, fname)
+                for bname, fname in zip(base_names, file_names)
+            }
+            for cname in client_names
+        }
+        # Run the function and validate its outputs.
+        config = DataSourceConfig(
+            data_folder=data_folder,
+            client_names=None,
+            dataset_names=dict(zip(base_names, file_names)),
+        )
+        clients = parse_data_folder(config)
+        assert clients == expected
+        # Verify that it would not work without the argument.
+        config = DataSourceConfig(
+            data_folder=data_folder,
+            client_names=None,
+            dataset_names=None,
+        )
+        with pytest.raises(ValueError):
+            clients = parse_data_folder(config)
+
+    def test_parse_data_folder_fails_multiple_files(
+        self,
+        tmp_path: pathlib.Path,
+    ) -> None:
+        """Test that 'parse_data_folder' fails with same-name files."""
+        # Setup a data folder with a couple of clients and duplicated files.
+        data_folder = tmp_path.as_posix()
+        client_names = ["client-1", "client-2"]
+        file_names = [
+            # fmt: off
+            "train_data", "train_target", "valid_data", "valid_target",
+            "train_data.bis"  # duplicated name prefix
+        ]
+        self.setup_data_folder(data_folder, client_names, file_names)
+        # Verify that the expected exception is raised.
+        config = DataSourceConfig(
+            data_folder=data_folder,
+            client_names=None,
+            dataset_names=None,
+        )
+        with pytest.raises(ValueError):
+            parse_data_folder(config)