diff --git a/test/functional/test_regression.py b/test/functional/test_regression.py index 506184f51fbae29248420532ae1cbf6a4e335c33..60b2f54a798824ba9c76115d0b1bc293188cd440 100644 --- a/test/functional/test_regression.py +++ b/test/functional/test_regression.py @@ -59,11 +59,22 @@ from declearn.main.config import FLOptimConfig, FLRunConfig from declearn.metrics import RSquared from declearn.model.api import Model from declearn.model.sklearn import SklearnSGDModel -from declearn.model.tensorflow import TensorflowModel -from declearn.model.torch import TorchModel from declearn.optimizer import Optimizer from declearn.test_utils import FrameworkType, run_as_processes +# pylint: disable=ungrouped-imports; optional frameworks' dependencies +try: + import tensorflow as tf # type: ignore + from declearn.model.tensorflow import TensorflowModel +except ModuleNotFoundError: + pass +try: + import torch + from declearn.model.torch import TorchModel +except ModuleNotFoundError: + pass + + SEED = 0 R2_THRESHOLD = 0.999