diff --git a/test/aggregator/test_aggregator.py b/test/aggregator/test_aggregator.py index cde6a033743eb836eb43e6d65f7eebabeb195a35..e619bccee54e28d2fcef52735fc98814b7373c15 100644 --- a/test/aggregator/test_aggregator.py +++ b/test/aggregator/test_aggregator.py @@ -17,7 +17,6 @@ """Unit tests for the 'Aggregator' subclasses.""" -import typing from typing import Dict, Type import pytest @@ -29,12 +28,13 @@ from declearn.test_utils import ( GradientsTestCase, assert_dict_equal, assert_json_serializable_dict, + list_available_frameworks, ) from declearn.utils import set_device_policy AGGREGATOR_CLASSES = list_aggregators() -VECTOR_FRAMEWORKS = typing.get_args(FrameworkType) +VECTOR_FRAMEWORKS = list_available_frameworks() @pytest.fixture(name="updates") diff --git a/test/communication/test_grpc.py b/test/communication/test_grpc.py index fe6a1effd66f30d23c8d65875622d8a110436c2b..da80ba9571e8b492cc5777c388b3f52cfee3c9a9 100644 --- a/test/communication/test_grpc.py +++ b/test/communication/test_grpc.py @@ -31,10 +31,14 @@ import asyncio import uuid from typing import AsyncIterator, Dict, Iterator -import grpc # type: ignore import pytest import pytest_asyncio +try: + import grpc # type: ignore +except ModuleNotFoundError: + pytest.skip("GRPC is unavailable", allow_module_level=True) + from declearn.communication.api.backend.actions import Ping from declearn.communication.grpc._server import load_pem_file from declearn.communication.grpc import GrpcClient, GrpcServer diff --git a/test/functional/test_toy_reg.py b/test/functional/test_toy_reg.py index 93fc05fbe88351d7fc8c8d375ae47e11b7d98ac4..c37f941fc2042871f43ad27dfb0735a6bbcde45b 100644 --- a/test/functional/test_toy_reg.py +++ b/test/functional/test_toy_reg.py @@ -67,7 +67,6 @@ from declearn.utils import set_device_policy # optional frameworks' dependencies pylint: disable=ungrouped-imports # pylint: disable=duplicate-code -# false-positives; pylint: disable=no-member # tensorflow imports try: @@ -104,7 +103,8 @@ else: return (y_pred - y_true) ** 2 -# pylint: disable=duplicate-code +# pylint: enable=duplicate-code, ungrouped-imports + SEED = 0 R2_THRESHOLD = 0.9999 @@ -124,7 +124,7 @@ def get_model(framework: FrameworkType) -> Model: raise ValueError(f"Unrecognised model framework: '{framework}'.") -def _get_model_numpy() -> SklearnSGDModel: +def _get_model_numpy() -> Model: """Return a linear model with MSE loss in Sklearn, with zero weights.""" np.random.seed(SEED) # set seed model = SklearnSGDModel.from_parameters( @@ -133,10 +133,12 @@ def _get_model_numpy() -> SklearnSGDModel: return model -def _get_model_tflow() -> TensorflowModel: +def _get_model_tflow() -> Model: """Return a linear model with MSE loss in TensorFlow, with zero weights.""" tf.random.set_seed(SEED) # set seed - tfmod = tf.keras.Sequential(tf.keras.layers.Dense(units=1)) + tfmod = tf.keras.Sequential( # pylint: disable=no-member + tf.keras.layers.Dense(units=1) # pylint: disable=no-member + ) tfmod.build([None, 100]) model = TensorflowModel(tfmod, loss="mean_squared_error") with tf.device("CPU"): @@ -148,7 +150,7 @@ def _get_model_tflow() -> TensorflowModel: return model -def _get_model_torch() -> TorchModel: +def _get_model_torch() -> Model: """Return a linear model with MSE loss in Torch, with zero weights.""" torch.manual_seed(SEED) # set seed torchmod = torch.nn.Sequential( @@ -164,7 +166,7 @@ def _get_model_torch() -> TorchModel: return model -def _get_model_haiku() -> HaikuModel: +def _get_model_haiku() -> Model: """Return a linear model with MSE loss in Haiku, with zero weights.""" model = HaikuModel(haiku_model_fn, loss=haiku_loss_fn) model.initialize({"data_type": "float32", "features_shape": (100,)}) @@ -427,6 +429,9 @@ async def async_run_client( await client.async_run() +# similar structure in other functional tests; pylint: disable=duplicate-code + + @pytest.mark.asyncio async def test_declearn_federated( framework: FrameworkType, diff --git a/test/main/test_main_client.py b/test/main/test_main_client.py index 861f1bc372cd0fa36bb2eb3e7a6597695ec10539..e6ed7c38f3b780646b7b1f5ec25e50433d679dbd 100644 --- a/test/main/test_main_client.py +++ b/test/main/test_main_client.py @@ -29,7 +29,6 @@ from declearn.dataset import Dataset, DataSpecs from declearn.communication import NetworkClientConfig from declearn.communication.api import NetworkClient from declearn.main import FederatedClient -from declearn.main.privacy import DPTrainingManager from declearn.main.utils import Checkpointer, TrainingManager from declearn.metrics import MetricState from declearn.model.api import Model @@ -37,6 +36,13 @@ from declearn.secagg import messaging as secagg_messaging from declearn.secagg.api import SecaggConfigClient, SecaggSetupQuery from declearn.utils import LOGGING_LEVEL_MAJOR +try: + from declearn.main.privacy import DPTrainingManager +except ModuleNotFoundError: + DP_AVAILABLE = False +else: + DP_AVAILABLE = True + MOCK_NETWK = mock.create_autospec(NetworkClient, instance=True) MOCK_NETWK.name = "client" @@ -514,6 +520,8 @@ class TestFederatedClientInitialize: @pytest.mark.asyncio async def test_initialize_with_dpsgd(self) -> None: """Test that initialization with DP-SGD works properly.""" + if not DP_AVAILABLE: + pytest.skip(reason="Unavailable DP features (missing Opacus).") # Set up a mock network receiving an InitRequest and a PrivacyRequest. netwk = mock.create_autospec(NetworkClient, instance=True) netwk.name = "client" @@ -555,6 +563,8 @@ class TestFederatedClientInitialize: @pytest.mark.asyncio async def test_initialize_with_dpsgd_error_wrong_message(self) -> None: """Test error catching for DP-SGD setup with wrong second message.""" + if not DP_AVAILABLE: + pytest.skip(reason="Unavailable DP features (missing Opacus).") # Set up a mock network receiving a DP InitRequest but wrong follow-up. netwk = mock.create_autospec(NetworkClient, instance=True) netwk.name = "client" @@ -580,6 +590,8 @@ class TestFederatedClientInitialize: @pytest.mark.asyncio async def test_initialize_with_dpsgd_error_setup(self) -> None: """Test error catching for DP-SGD setup with client-side failure.""" + if not DP_AVAILABLE: + pytest.skip(reason="Unavailable DP features (missing Opacus).") # Set up a mock network receiving an InitRequest and a PrivacyRequest. netwk = mock.create_autospec(NetworkClient, instance=True) netwk.name = "client" diff --git a/test/main/test_train_manager_dp.py b/test/main/test_train_manager_dp.py index 352a1f2b7b7e92dc483c93af5432c11b364afbcc..3735f290b8e9da87be72d4dce2ce2024ed95307a 100644 --- a/test/main/test_train_manager_dp.py +++ b/test/main/test_train_manager_dp.py @@ -21,8 +21,12 @@ import os from typing import Any, Optional import pytest -from opacus.accountants import RDPAccountant # type: ignore -from opacus.accountants.utils import get_noise_multiplier # type: ignore + +try: + from opacus.accountants import RDPAccountant # type: ignore + from opacus.accountants.utils import get_noise_multiplier # type: ignore +except ModuleNotFoundError: + pytest.skip("Opacus is unavailable", allow_module_level=True) from declearn.communication import messaging from declearn.dataset import DataSpecs