From d10a42e2de75a854882f941aa5d4d840e098eb83 Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Thu, 24 Aug 2023 13:54:23 +0200 Subject: [PATCH] Deploy 'TensorflowDataset' in toy-regression integration test. --- test/functional/test_regression.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/test/functional/test_regression.py b/test/functional/test_regression.py index 8e1bda9a..da526f96 100644 --- a/test/functional/test_regression.py +++ b/test/functional/test_regression.py @@ -71,6 +71,7 @@ try: except ModuleNotFoundError: pass else: + from declearn.dataset.tensorflow import TensorflowDataset from declearn.model.tensorflow import TensorflowModel # torch imports try: @@ -139,6 +140,12 @@ def get_dataset(framework: FrameworkType, inputs, labels): inputs = torch.from_numpy(inputs) labels = torch.from_numpy(labels) return TorchDataset(torch.utils.data.TensorDataset(inputs, labels)) + if framework == "tensorflow": + inputs = tf.convert_to_tensor(inputs) + labels = tf.convert_to_tensor(labels) + return TensorflowDataset( + tf.data.Dataset.from_tensor_slices((inputs, labels)) + ) return InMemoryDataset(inputs, labels, expose_data_type=True) @@ -164,7 +171,6 @@ def prep_client_datasets( datasets: list[(InMemoryDataset, InMemoryDataset)] List of client-wise (train, valid) pair of datasets. """ - n_samples = clients * (n_train + n_valid) # false-positive; pylint: disable=unbalanced-tuple-unpacking inputs, target = make_regression( -- GitLab