diff --git a/test/functional/test_regression.py b/test/functional/test_regression.py
index 8e1bda9a31ad09009712e1a267aec220a72ca198..da526f969abe2e1c7107506c206b2da9aba64c86 100644
--- a/test/functional/test_regression.py
+++ b/test/functional/test_regression.py
@@ -71,6 +71,7 @@ try:
 except ModuleNotFoundError:
     pass
 else:
+    from declearn.dataset.tensorflow import TensorflowDataset
     from declearn.model.tensorflow import TensorflowModel
 # torch imports
 try:
@@ -139,6 +140,12 @@ def get_dataset(framework: FrameworkType, inputs, labels):
         inputs = torch.from_numpy(inputs)
         labels = torch.from_numpy(labels)
         return TorchDataset(torch.utils.data.TensorDataset(inputs, labels))
+    if framework == "tensorflow":
+        inputs = tf.convert_to_tensor(inputs)
+        labels = tf.convert_to_tensor(labels)
+        return TensorflowDataset(
+            tf.data.Dataset.from_tensor_slices((inputs, labels))
+        )
     return InMemoryDataset(inputs, labels, expose_data_type=True)
 
 
@@ -164,7 +171,6 @@ def prep_client_datasets(
     datasets: list[(InMemoryDataset, InMemoryDataset)]
         List of client-wise (train, valid) pair of datasets.
     """
-
     n_samples = clients * (n_train + n_valid)
     # false-positive; pylint: disable=unbalanced-tuple-unpacking
     inputs, target = make_regression(