From d10a42e2de75a854882f941aa5d4d840e098eb83 Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Thu, 24 Aug 2023 13:54:23 +0200
Subject: [PATCH] Deploy 'TensorflowDataset' in toy-regression integration
 test.

---
 test/functional/test_regression.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/test/functional/test_regression.py b/test/functional/test_regression.py
index 8e1bda9a..da526f96 100644
--- a/test/functional/test_regression.py
+++ b/test/functional/test_regression.py
@@ -71,6 +71,7 @@ try:
 except ModuleNotFoundError:
     pass
 else:
+    from declearn.dataset.tensorflow import TensorflowDataset
     from declearn.model.tensorflow import TensorflowModel
 # torch imports
 try:
@@ -139,6 +140,12 @@ def get_dataset(framework: FrameworkType, inputs, labels):
         inputs = torch.from_numpy(inputs)
         labels = torch.from_numpy(labels)
         return TorchDataset(torch.utils.data.TensorDataset(inputs, labels))
+    if framework == "tensorflow":
+        inputs = tf.convert_to_tensor(inputs)
+        labels = tf.convert_to_tensor(labels)
+        return TensorflowDataset(
+            tf.data.Dataset.from_tensor_slices((inputs, labels))
+        )
     return InMemoryDataset(inputs, labels, expose_data_type=True)
 
 
@@ -164,7 +171,6 @@ def prep_client_datasets(
     datasets: list[(InMemoryDataset, InMemoryDataset)]
         List of client-wise (train, valid) pair of datasets.
     """
-
     n_samples = clients * (n_train + n_valid)
     # false-positive; pylint: disable=unbalanced-tuple-unpacking
     inputs, target = make_regression(
-- 
GitLab