diff --git a/test/functional/test_regression.py b/test/functional/test_toy_reg.py
similarity index 60%
rename from test/functional/test_regression.py
rename to test/functional/test_toy_reg.py
index da526f969abe2e1c7107506c206b2da9aba64c86..469e06f3f1605cc1dd0713bdd14f58e994664dc5 100644
--- a/test/functional/test_regression.py
+++ b/test/functional/test_toy_reg.py
@@ -42,16 +42,20 @@ The convergence results of those experiments is then compared.
 
 """
 
+import asyncio
+import dataclasses
 import json
+import os
 import tempfile
 from typing import List, Tuple
 
 import numpy as np
-from sklearn.datasets import make_regression  # type: ignore
-from sklearn.linear_model import SGDRegressor  # type: ignore
+import pytest
+import sklearn.datasets  # type: ignore
+import sklearn.linear_model  # type: ignore
 
 from declearn.communication import NetworkClientConfig, NetworkServerConfig
-from declearn.dataset import InMemoryDataset
+from declearn.dataset import Dataset, InMemoryDataset
 from declearn.main import FederatedClient, FederatedServer
 from declearn.main.config import FLOptimConfig, FLRunConfig
 from declearn.metrics import RSquared
@@ -59,7 +63,7 @@ from declearn.model.api import Model
 from declearn.model.sklearn import SklearnSGDModel
 from declearn.optimizer import Optimizer
 from declearn.test_utils import FrameworkType
-from declearn.utils import run_as_processes, set_device_policy
+from declearn.utils import set_device_policy
 
 # optional frameworks' dependencies pylint: disable=ungrouped-imports
 # pylint: disable=duplicate-code
@@ -72,7 +76,7 @@ except ModuleNotFoundError:
     pass
 else:
     from declearn.dataset.tensorflow import TensorflowDataset
-    from declearn.model.tensorflow import TensorflowModel
+    from declearn.model.tensorflow import TensorflowModel, TensorflowVector
 # torch imports
 try:
     import torch
@@ -80,8 +84,7 @@ except ModuleNotFoundError:
     pass
 else:
     from declearn.dataset.torch import TorchDataset
-    from declearn.model.torch import TorchModel
-# pylint: enable=duplicate-code
+    from declearn.model.torch import TorchModel, TorchVector
 # haiku and jax imports
 try:
     import haiku as hk
@@ -89,7 +92,7 @@ try:
 except ModuleNotFoundError:
     pass
 else:
-    from declearn.model.haiku import HaikuModel
+    from declearn.model.haiku import HaikuModel, JaxNumpyVector
 
     def haiku_model_fn(inputs: jax.Array) -> jax.Array:
         """Simple linear model implemented with Haiku."""
@@ -97,40 +100,81 @@ else:
 
     def haiku_loss_fn(y_pred: jax.Array, y_true: jax.Array) -> jax.Array:
         """Sample-wise squared error loss function."""
+        y_pred = jax.numpy.squeeze(y_pred)
         return (y_pred - y_true) ** 2
 
 
+# pylint: disable=duplicate-code
+
 SEED = 0
-R2_THRESHOLD = 0.999
+R2_THRESHOLD = 0.9999
 
 
 set_device_policy(gpu=False)  # disable GPU use to avoid concurrence
 
 
 def get_model(framework: FrameworkType) -> Model:
-    """Set up a simple toy regression model."""
+    """Set up a simple toy regression model, with zero-valued weights."""
     set_device_policy(gpu=False)  # disable GPU use to avoid concurrence
     if framework == "numpy":
-        np.random.seed(SEED)  # set seed
-        model = SklearnSGDModel.from_parameters(
-            kind="regressor", loss="squared_error", penalty="none"
-        )  # type: Model
-    elif framework == "tensorflow":
-        tf.random.set_seed(SEED)  # set seed
-        tfmod = tf.keras.Sequential(tf.keras.layers.Dense(units=1))
-        tfmod.build([None, 100])
-        model = TensorflowModel(tfmod, loss="mean_squared_error")
-    elif framework == "torch":
-        torch.manual_seed(SEED)  # set seed
-        torchmod = torch.nn.Sequential(
-            torch.nn.Linear(100, 1, bias=True),
-            torch.nn.Flatten(0),
-        )
-        model = TorchModel(torchmod, loss=torch.nn.MSELoss())
-    elif framework == "jax":
-        model = HaikuModel(haiku_model_fn, loss=haiku_loss_fn)
-    else:
-        raise ValueError("unrecognised framework")
+        return _get_model_numpy()
+    if framework == "tensorflow":
+        return _get_model_tflow()
+    if framework == "torch":
+        return _get_model_torch()
+    if framework == "jax":
+        return _get_model_haiku()
+    raise ValueError(f"Unrecognised model framework: '{framework}'.")
+
+
+def _get_model_numpy() -> SklearnSGDModel:
+    """Return a linear model with MSE loss in Sklearn, with zero weights."""
+    np.random.seed(SEED)  # set seed
+    model = SklearnSGDModel.from_parameters(
+        kind="regressor", loss="squared_error", penalty="none"
+    )
+    return model
+
+
+def _get_model_tflow() -> TensorflowModel:
+    """Return a linear model with MSE loss in TensorFlow, with zero weights."""
+    tf.random.set_seed(SEED)  # set seed
+    tfmod = tf.keras.Sequential(tf.keras.layers.Dense(units=1))
+    tfmod.build([None, 100])
+    model = TensorflowModel(tfmod, loss="mean_squared_error")
+    zeros = {
+        key: tf.zeros_like(val)
+        for key, val in model.get_weights().coefs.items()
+    }
+    model.set_weights(TensorflowVector(zeros))
+    return model
+
+
+def _get_model_torch() -> TorchModel:
+    """Return a linear model with MSE loss in Torch, with zero weights."""
+    torch.manual_seed(SEED)  # set seed
+    torchmod = torch.nn.Sequential(
+        torch.nn.Linear(100, 1, bias=True),
+        torch.nn.Flatten(0),
+    )
+    model = TorchModel(torchmod, loss=torch.nn.MSELoss())
+    zeros = {
+        key: torch.zeros_like(val)
+        for key, val in model.get_weights().coefs.items()
+    }
+    model.set_weights(TorchVector(zeros))
+    return model
+
+
+def _get_model_haiku() -> HaikuModel:
+    """Return a linear model with MSE loss in Haiku, with zero weights."""
+    model = HaikuModel(haiku_model_fn, loss=haiku_loss_fn)
+    model.initialize({"data_type": "float32", "features_shape": (100,)})
+    zeros = {
+        key: jax.numpy.zeros_like(val)
+        for key, val in model.get_weights().coefs.items()
+    }
+    model.set_weights(JaxNumpyVector(zeros))
     return model
 
 
@@ -149,46 +193,6 @@ def get_dataset(framework: FrameworkType, inputs, labels):
     return InMemoryDataset(inputs, labels, expose_data_type=True)
 
 
-def prep_client_datasets(
-    framework: FrameworkType,
-    clients: int = 3,
-    n_train: int = 100,
-    n_valid: int = 50,
-) -> List[Tuple[InMemoryDataset, InMemoryDataset]]:
-    """Generate and split toy data for a regression problem.
-
-    Parameters
-    ----------
-    clients: int, default=3
-        Number of clients, i.e. of dataset shards to generate.
-    n_train: int, default=30
-        Number of training samples per client.
-    n_valid: int, default=30
-        Number of validation samples per client.
-
-    Returns
-    -------
-    datasets: list[(InMemoryDataset, InMemoryDataset)]
-        List of client-wise (train, valid) pair of datasets.
-    """
-    n_samples = clients * (n_train + n_valid)
-    # false-positive; pylint: disable=unbalanced-tuple-unpacking
-    inputs, target = make_regression(
-        n_samples, n_features=100, n_informative=10, random_state=SEED
-    )
-    inputs, target = inputs.astype("float32"), target.astype("float32")
-    # Wrap up the data into client-wise pairs of dataset.
-    out = []  # type: List[Tuple[InMemoryDataset, InMemoryDataset]]
-    for idx in range(clients):
-        start = (n_train + n_valid) * idx
-        mid = start + n_train
-        end = mid + n_valid
-        train = get_dataset(framework, inputs[start:mid], target[start:mid])
-        valid = get_dataset(framework, inputs[mid:end], target[mid:end])
-        out.append((train, valid))
-    return out
-
-
 def prep_full_dataset(
     n_train: int = 300,
     n_valid: int = 150,
@@ -210,7 +214,7 @@ def prep_full_dataset(
     """
     n_samples = n_train + n_valid
     # false-positive; pylint: disable=unbalanced-tuple-unpacking
-    inputs, target = make_regression(
+    inputs, target = sklearn.datasets.make_regression(
         n_samples, n_features=100, n_informative=10, random_state=SEED
     )
     inputs, target = inputs.astype("float32"), target.astype("float32")
@@ -222,62 +226,150 @@ def prep_full_dataset(
     return out
 
 
-def test_declearn_experiment(
+def test_sklearn_baseline(
+    lrate: float = 0.04,
+    rounds: int = 10,
+    b_size: int = 10,
+) -> None:
+    """Run a baseline using scikit-learn to emulate a centralized setting.
+
+    This function does not use declearn. It sets up a single sklearn
+    model and performs training on the full dataset.
+
+    Parameters
+    ----------
+    lrate: float, default=0.04
+        Learning rate of the SGD algorithm.
+    rounds: int, default=10
+        Number of training rounds to perform, i.e. number of epochs.
+    b_size: int, default=10
+        Batch size for the training (and validation) data.
+        Batching will be performed without shuffling nor replacement,
+        and the final batch may be smaller than the others (no drop).
+    """
+    # Generate the client datasets, then centralize them into numpy arrays.
+    train, valid = prep_full_dataset()
+    # Set up a scikit-learn model, implementing step-wise gradient descent.
+    sgd = sklearn.linear_model.SGDRegressor(
+        loss="squared_error",
+        penalty="l1",
+        alpha=0.1,
+        eta0=lrate / b_size,  # adjust learning rate for (dropped) batch size
+        learning_rate="constant",  # disable scaling, unused in declearn
+        max_iter=rounds,
+    )
+    # Iteratively train the model, evaluating it after each epoch.
+    for _ in range(rounds):
+        sgd.partial_fit(train[0], train[1])
+    assert sgd.score(valid[0], valid[1]) > R2_THRESHOLD
+
+
+def test_declearn_baseline(
     framework: FrameworkType,
-    lrate: float = 0.01,
+    lrate: float = 0.02,
     rounds: int = 10,
-    b_size: int = 1,
-    clients: int = 3,
+    b_size: int = 10,
 ) -> None:
-    """Run an experiment using declearn to perform a federative training.
+    """Run a baseline uing declearn APIs to emulate a centralized setting.
 
-    This function runs the experiment using declearn.
-    It sets up and runs the server and client-wise routines in separate
-    processes, to enable their concurrent execution.
+    This function uses declearn but sets up a single model and performs
+    training on the entire toy regression dataset.
 
     Parameters
     ----------
     framework: str
         Framework of the model to train and evaluate.
-    lrate: float, default=0.01
-        Learning rate of the SGD algorithm performed by clients.
+    lrate: float, default=0.02
+        Learning rate of the SGD algorithm.
     rounds: int, default=10
-        Number of FL training rounds to perform.
-        At each round, each client will perform a full epoch of training.
+        Number of training rounds to perform, i.e. number of epochs.
     b_size: int, default=10
         Batch size fot the training (and validation) data.
         Batching will be performed without shuffling nor replacement,
         and the final batch may be smaller than the others (no drop).
-    clients: int, default=3
-        Number of federated clients to set up and run.
     """
-    # pylint: disable=too-many-locals
-    with tempfile.TemporaryDirectory() as folder:
-        # Set up a (func, args) tuple specifying the server process.
-        p_server = (
-            _server_routine,
-            (folder, framework, lrate, rounds, b_size, clients),
+    # Generate the client datasets, then centralize them into numpy arrays.
+    train, valid = prep_full_dataset()
+    dst_train = get_dataset(framework, *train)
+    # Set up a declearn model and a SGD optimizer with Lasso regularization.
+    model = get_model(framework)
+    model.initialize(dataclasses.asdict(dst_train.get_data_specs()))
+    optim = Optimizer(
+        lrate=lrate if framework != "numpy" else (lrate * 2),
+        regularizers=[("lasso", {"alpha": 0.1})],
+    )
+    # Iteratively train the model and evaluate it between rounds.
+    r_sq = RSquared()
+    scores = []  # type: List[float]
+    for _ in range(rounds):
+        for batch in dst_train.generate_batches(
+            batch_size=b_size, drop_remainder=False
+        ):
+            optim.run_train_step(model, batch)
+        pred = model.compute_batch_predictions((*valid, None))
+        r_sq.reset()
+        r_sq.update(*pred)
+        scores.append(r_sq.get_result()["r2"])  # type: ignore
+    # Check that the R2 increased through epochs to reach a high value.
+    print(scores)
+    assert all(scores[i + 1] >= scores[i] for i in range(rounds - 1))
+    assert scores[-1] >= R2_THRESHOLD
+
+
+def prep_client_datasets(
+    framework: FrameworkType,
+    clients: int = 3,
+    n_train: int = 100,
+    n_valid: int = 50,
+) -> List[Tuple[Dataset, Dataset]]:
+    """Generate and split data for a toy sparse regression problem.
+
+    Parameters
+    ----------
+    framework:
+        Name of the framework being tested, based on which the Dataset
+        class choice may be adjusted as well.
+    clients:
+        Number of clients, i.e. of dataset shards to generate.
+    n_train:
+        Number of training samples per client.
+    n_valid:
+        Number of validation samples per client.
+
+    Returns
+    -------
+    datasets:
+        List of client-wise tuple of (train, valid) Dataset instances.
+    """
+    train, valid = prep_full_dataset(
+        n_train=clients * n_train,
+        n_valid=clients * n_valid,
+    )
+    # Wrap up the data into client-wise pairs of dataset.
+    out = []  # type: List[Tuple[Dataset, Dataset]]
+    for idx in range(clients):
+        # Gather the client's training dataset.
+        srt = n_train * idx
+        end = n_train + srt
+        dst_train = get_dataset(
+            framework=framework,
+            inputs=train[0][srt:end],
+            labels=train[1][srt:end],
         )
-        # Set up the (func, args) tuples specifying client-wise processes.
-        datasets = prep_client_datasets(framework, clients)
-        p_client = []
-        for i, data in enumerate(datasets):
-            client = (_client_routine, (data[0], data[1], f"client_{i}"))
-            p_client.append(client)
-        # Run each and every process in parallel.
-        success, outputs = run_as_processes(p_server, *p_client)
-        assert success, "The FL process failed:\n" + "\n".join(
-            str(exc) for exc in outputs if isinstance(exc, RuntimeError)
+        # Gather the client's validation dataset.
+        srt = n_valid * idx
+        end = n_valid + srt
+        dst_valid = get_dataset(
+            framework=framework,
+            inputs=valid[0][srt:end],
+            labels=valid[1][srt:end],
         )
-        # Assert convergence
-        with open(f"{folder}/metrics.json", encoding="utf-8") as file:
-            r2_dict = json.load(file)
-            last_r2_dict = r2_dict.get(max(r2_dict.keys()))
-            final_r2 = float(last_r2_dict.get("r2"))
-            assert final_r2 > R2_THRESHOLD, "The FL training did not converge"
+        # Store both datasets into the output list.
+        out.append((dst_train, dst_valid))
+    return out
 
 
-def _server_routine(
+async def async_run_server(
     folder: str,
     framework: FrameworkType,
     lrate: float = 0.01,
@@ -295,7 +387,7 @@ def _server_routine(
     optim = FLOptimConfig.from_params(
         aggregator="averaging",
         client_opt={
-            "lrate": lrate,
+            "lrate": lrate if framework != "numpy" else (lrate * 2),
             "regularizers": [("lasso", {"alpha": 0.1})],
         },
         server_opt=1.0,
@@ -305,24 +397,24 @@ def _server_routine(
         netwk,
         optim,
         metrics=["r2"],
-        checkpoint=folder,
+        checkpoint={"folder": folder, "max_history": 1},
     )
     # Set up hyper-parameters and run training.
     config = FLRunConfig.from_params(
         rounds=rounds,
-        register={"min_clients": clients},
+        register={"min_clients": clients, "timeout": 10},
         training={
             "n_epoch": 1,
             "batch_size": b_size,
             "drop_remainder": False,
         },
     )
-    server.run(config)
+    await server.async_run(config)
 
 
-def _client_routine(
-    train: InMemoryDataset,
-    valid: InMemoryDataset,
+async def async_run_client(
+    train: Dataset,
+    valid: Dataset,
     name: str = "client",
 ) -> None:
     """Routine to run a FL client, called by `run_declearn_experiment`."""
@@ -330,83 +422,62 @@ def _client_routine(
         protocol="websockets", server_uri="ws://localhost:8765", name=name
     )
     client = FederatedClient(netwk, train, valid)
-    client.run()
-
-
-def test_declearn_baseline(
-    lrate: float = 0.01,
-    rounds: int = 10,
-    b_size: int = 1,
-) -> None:
-    """Run a baseline uing declearn APIs to emulate a centralized setting.
-
-    This function uses declearn but sets up a single model and performs
-    training on the concatenation of "client-wise" datasets.
-
-    Parameters
-    ----------
-    lrate: float, default=0.01
-        Learning rate of the SGD algorithm.
-    rounds: int, default=10
-        Number of training rounds to perform, i.e. number of epochs.
-    b_size: int, default=10
-        Batch size fot the training (and validation) data.
-        Batching will be performed without shuffling nor replacement,
-        and the final batch may be smaller than the others (no drop).
-    """
-    # Generate the client datasets, then centralize them into numpy arrays.
-    train, valid = prep_full_dataset()
-    d_train = InMemoryDataset(train[0], train[1])
-    # Set up a declearn model and a vanilla SGD optimizer.
-    model = get_model("numpy")
-    model.initialize({"features_shape": (d_train.data.shape[1],)})
-    opt = Optimizer(lrate=lrate, regularizers=[("lasso", {"alpha": 0.1})])
-    # Iteratively train the model, evaluating it after each epoch.
-    for _ in range(rounds):
-        # Run the training round.
-        for batch in d_train.generate_batches(batch_size=b_size):
-            grads = model.compute_batch_gradients(batch)
-            opt.apply_gradients(model, grads)
-    # Check the final R2 value.
-    r_sq = RSquared()
-    r_sq.update(*model.compute_batch_predictions((valid[0], valid[1], None)))
-    assert r_sq.get_result()["r2"] > R2_THRESHOLD
+    await client.async_run()
 
 
-def test_sklearn_baseline(
+@pytest.mark.asyncio
+async def test_declearn_federated(
+    framework: FrameworkType,
     lrate: float = 0.01,
     rounds: int = 10,
     b_size: int = 1,
+    clients: int = 3,
 ) -> None:
-    """Run a baseline using scikit-learn to emulate a centralized setting.
+    """Run an experiment using declearn to perform a federative training.
 
-    This function does not use declearn. It sets up a single sklearn
-    model and performs training on the full dataset.
+    This function runs the experiment using declearn.
+    It sets up and runs the server and client-wise routines in separate
+    processes, to enable their concurrent execution.
 
     Parameters
     ----------
+    framework: str
+        Framework of the model to train and evaluate.
     lrate: float, default=0.01
-        Learning rate of the SGD algorithm.
+        Learning rate of the SGD algorithm performed by clients.
     rounds: int, default=10
-        Number of training rounds to perform, i.e. number of epochs.
+        Number of FL training rounds to perform.
+        At each round, each client will perform a full epoch of training.
     b_size: int, default=10
         Batch size fot the training (and validation) data.
         Batching will be performed without shuffling nor replacement,
         and the final batch may be smaller than the others (no drop).
+    clients: int, default=3
+        Number of federated clients to set up and run.
     """
-    # Generate the client datasets, then centralize them into numpy arrays.
-
-    train, valid = prep_full_dataset()
-    # Set up a scikit-learn model, implementing step-wise gradient descent.
-    sgd = SGDRegressor(
-        loss="squared_error",
-        penalty="l1",
-        alpha=0.1,
-        eta0=lrate / b_size,  # adjust learning rate for (dropped) batch size
-        learning_rate="constant",  # disable scaling, unused in declearn
-        max_iter=rounds,
-    )
-    # Iteratively train the model, evaluating it after each epoch.
-    for _ in range(rounds):
-        sgd.partial_fit(train[0], train[1])
-    assert sgd.score(valid[0], valid[1]) > R2_THRESHOLD
+    datasets = prep_client_datasets(framework, clients)
+    with tempfile.TemporaryDirectory() as folder:
+        # Set up the server and client coroutines.
+        coro_server = async_run_server(
+            folder, framework, lrate, rounds, b_size, clients
+        )
+        coro_clients = [
+            async_run_client(train, valid, name=f"client_{i}")
+            for i, (train, valid) in enumerate(datasets)
+        ]
+        # Run the coroutines concurrently using asyncio.
+        outputs = await asyncio.gather(
+            coro_server, *coro_clients, return_exceptions=True
+        )
+        # Assert that no exceptions occurred during the process.
+        errors = "\n".join(
+            repr(exc) for exc in outputs if isinstance(exc, Exception)
+        )
+        assert not errors, f"The FL process failed:\n{errors}"
+        # Assert that the federated model converged above an expected value.
+        with open(
+            os.path.join(folder, "metrics.json"), encoding="utf-8"
+        ) as file:
+            metrics = json.load(file)
+        best_r2 = max(values["r2"] for values in metrics.values())
+        assert best_r2 >= R2_THRESHOLD, "The FL training did not converge"