diff --git a/test/model/test_tflow.py b/test/model/test_tflow.py index da93b4643423e1240f7a2a6e628ca004bc3d4da6..a69a67f0863c0fb8608e4ab7d6f350bc7200650a 100644 --- a/test/model/test_tflow.py +++ b/test/model/test_tflow.py @@ -9,9 +9,12 @@ from typing import Any, List, Literal import numpy as np import pytest -with warnings.catch_warnings(): # silence tensorflow import-time warnings - warnings.simplefilter("ignore") - import tensorflow as tf # type: ignore +try: + with warnings.catch_warnings(): # silence tensorflow import-time warnings + warnings.simplefilter("ignore") + import tensorflow as tf # type: ignore +except ModuleNotFoundError: + pytest.skip("TensorFlow is unavailable", allow_module_level=True) from declearn.model.tensorflow import TensorflowModel, TensorflowVector from declearn.typing import Batch diff --git a/test/model/test_torch.py b/test/model/test_torch.py index f2d30f652329eb678b68abbd896e623ca5bf8691..53b8b9d174c1b651c0619dd65275739159718e3a 100644 --- a/test/model/test_torch.py +++ b/test/model/test_torch.py @@ -7,7 +7,11 @@ from typing import Any, List, Literal, Tuple import numpy as np import pytest -import torch + +try: + import torch +except ModuleNotFoundError: + pytest.skip("PyTorch is unavailable", allow_module_level=True) from declearn.model.torch import TorchModel, TorchVector from declearn.typing import Batch diff --git a/test/test_main.py b/test/test_main.py index f25372659bc62c400e4d1b453d69fbf54e0d2b84..2b22b6001299691608f0ef2af0c6484788e10262 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -9,11 +9,6 @@ from typing import Any, Dict, Literal, Optional import numpy as np import pytest -with warnings.catch_warnings(): # silence tensorflow import-time warnings - warnings.simplefilter("ignore") - import tensorflow as tf # type: ignore -import torch - from declearn.communication import ( build_client, build_server, @@ -23,11 +18,26 @@ from declearn.communication.api import NetworkClient, NetworkServer from declearn.dataset import InMemoryDataset from declearn.model.api import Model from declearn.model.sklearn import SklearnSGDModel -from declearn.model.tensorflow import TensorflowModel -from declearn.model.torch import TorchModel from declearn.main import FederatedClient, FederatedServer from declearn.test_utils import run_as_processes +# Select the subset of tests to run, based on framework availability. +# Note: TensorFlow and Torch (-related) imports are delayed due to this. +# pylint: disable=ungrouped-imports +FRAMEWORKS = ["Sksgd", "Tflow", "Torch"] +try: + import tensorflow as tf +except ModuleNotFoundError: + FRAMEWORKS.remove("Tflow") +else: + from declearn.model.tensorflow import TensorflowModel +try: + import torch +except ModuleNotFoundError: + FRAMEWORKS.remove("Torch") +else: + from declearn.model.torch import TorchModel + class DeclearnTestCase: """Test-case for the "main" federated learning orchestrating classes.""" @@ -73,7 +83,7 @@ class DeclearnTestCase: def _build_tflow_model( self, - ) -> TensorflowModel: + ) -> Model: """Return a TensorflowModel suitable for the learning task.""" if self.kind == "Reg": output_layer = tf.keras.layers.Dense(1) @@ -97,8 +107,9 @@ class DeclearnTestCase: def _build_torch_model( self, - ) -> TorchModel: + ) -> Model: """Return a TorchModel suitable for the learning task.""" + # Build the model and return it. stack = [ torch.nn.Linear(32, 32), torch.nn.ReLU(), @@ -245,7 +256,7 @@ def run_test_case( @pytest.mark.parametrize("strategy", ["FedAvg", "FedAvgM", "Scaffold"]) -@pytest.mark.parametrize("framework", ["Sksgd", "Tflow", "Torch"]) +@pytest.mark.parametrize("framework", FRAMEWORKS) @pytest.mark.parametrize("kind", ["Reg", "Bin", "Clf"]) @pytest.mark.filterwarnings("ignore: PyTorch JSON serialization") def test_declearn(