From 6360878f27a504818ca1a5429ce5ec4a42a1fb05 Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Mon, 6 Feb 2023 11:12:10 +0100
Subject: [PATCH] Modularize tests' execution based on frameworks'
 availability.

---
 test/model/test_tflow.py |  9 ++++++---
 test/model/test_torch.py |  6 +++++-
 test/test_main.py        | 31 +++++++++++++++++++++----------
 3 files changed, 32 insertions(+), 14 deletions(-)

diff --git a/test/model/test_tflow.py b/test/model/test_tflow.py
index da93b464..a69a67f0 100644
--- a/test/model/test_tflow.py
+++ b/test/model/test_tflow.py
@@ -9,9 +9,12 @@ from typing import Any, List, Literal
 import numpy as np
 import pytest
 
-with warnings.catch_warnings():  # silence tensorflow import-time warnings
-    warnings.simplefilter("ignore")
-    import tensorflow as tf  # type: ignore
+try:
+    with warnings.catch_warnings():  # silence tensorflow import-time warnings
+        warnings.simplefilter("ignore")
+        import tensorflow as tf  # type: ignore
+except ModuleNotFoundError:
+    pytest.skip("TensorFlow is unavailable", allow_module_level=True)
 
 from declearn.model.tensorflow import TensorflowModel, TensorflowVector
 from declearn.typing import Batch
diff --git a/test/model/test_torch.py b/test/model/test_torch.py
index f2d30f65..53b8b9d1 100644
--- a/test/model/test_torch.py
+++ b/test/model/test_torch.py
@@ -7,7 +7,11 @@ from typing import Any, List, Literal, Tuple
 
 import numpy as np
 import pytest
-import torch
+
+try:
+    import torch
+except ModuleNotFoundError:
+    pytest.skip("PyTorch is unavailable", allow_module_level=True)
 
 from declearn.model.torch import TorchModel, TorchVector
 from declearn.typing import Batch
diff --git a/test/test_main.py b/test/test_main.py
index f2537265..2b22b600 100644
--- a/test/test_main.py
+++ b/test/test_main.py
@@ -9,11 +9,6 @@ from typing import Any, Dict, Literal, Optional
 import numpy as np
 import pytest
 
-with warnings.catch_warnings():  # silence tensorflow import-time warnings
-    warnings.simplefilter("ignore")
-    import tensorflow as tf  # type: ignore
-import torch
-
 from declearn.communication import (
     build_client,
     build_server,
@@ -23,11 +18,26 @@ from declearn.communication.api import NetworkClient, NetworkServer
 from declearn.dataset import InMemoryDataset
 from declearn.model.api import Model
 from declearn.model.sklearn import SklearnSGDModel
-from declearn.model.tensorflow import TensorflowModel
-from declearn.model.torch import TorchModel
 from declearn.main import FederatedClient, FederatedServer
 from declearn.test_utils import run_as_processes
 
+# Select the subset of tests to run, based on framework availability.
+# Note: TensorFlow and Torch (-related) imports are delayed due to this.
+# pylint: disable=ungrouped-imports
+FRAMEWORKS = ["Sksgd", "Tflow", "Torch"]
+try:
+    import tensorflow as tf
+except ModuleNotFoundError:
+    FRAMEWORKS.remove("Tflow")
+else:
+    from declearn.model.tensorflow import TensorflowModel
+try:
+    import torch
+except ModuleNotFoundError:
+    FRAMEWORKS.remove("Torch")
+else:
+    from declearn.model.torch import TorchModel
+
 
 class DeclearnTestCase:
     """Test-case for the "main" federated learning orchestrating classes."""
@@ -73,7 +83,7 @@ class DeclearnTestCase:
 
     def _build_tflow_model(
         self,
-    ) -> TensorflowModel:
+    ) -> Model:
         """Return a TensorflowModel suitable for the learning task."""
         if self.kind == "Reg":
             output_layer = tf.keras.layers.Dense(1)
@@ -97,8 +107,9 @@ class DeclearnTestCase:
 
     def _build_torch_model(
         self,
-    ) -> TorchModel:
+    ) -> Model:
         """Return a TorchModel suitable for the learning task."""
+        # Build the model and return it.
         stack = [
             torch.nn.Linear(32, 32),
             torch.nn.ReLU(),
@@ -245,7 +256,7 @@ def run_test_case(
 
 
 @pytest.mark.parametrize("strategy", ["FedAvg", "FedAvgM", "Scaffold"])
-@pytest.mark.parametrize("framework", ["Sksgd", "Tflow", "Torch"])
+@pytest.mark.parametrize("framework", FRAMEWORKS)
 @pytest.mark.parametrize("kind", ["Reg", "Bin", "Clf"])
 @pytest.mark.filterwarnings("ignore: PyTorch JSON serialization")
 def test_declearn(
-- 
GitLab