diff --git a/test/aggregator/test_aggregator.py b/test/aggregator/test_aggregator.py index cde6a033743eb836eb43e6d65f7eebabeb195a35..e619bccee54e28d2fcef52735fc98814b7373c15 100644 --- a/test/aggregator/test_aggregator.py +++ b/test/aggregator/test_aggregator.py @@ -17,7 +17,6 @@ """Unit tests for the 'Aggregator' subclasses.""" -import typing from typing import Dict, Type import pytest @@ -29,12 +28,13 @@ from declearn.test_utils import ( GradientsTestCase, assert_dict_equal, assert_json_serializable_dict, + list_available_frameworks, ) from declearn.utils import set_device_policy AGGREGATOR_CLASSES = list_aggregators() -VECTOR_FRAMEWORKS = typing.get_args(FrameworkType) +VECTOR_FRAMEWORKS = list_available_frameworks() @pytest.fixture(name="updates") diff --git a/test/communication/test_grpc.py b/test/communication/test_grpc.py index fe6a1effd66f30d23c8d65875622d8a110436c2b..da80ba9571e8b492cc5777c388b3f52cfee3c9a9 100644 --- a/test/communication/test_grpc.py +++ b/test/communication/test_grpc.py @@ -31,10 +31,14 @@ import asyncio import uuid from typing import AsyncIterator, Dict, Iterator -import grpc # type: ignore import pytest import pytest_asyncio +try: + import grpc # type: ignore +except ModuleNotFoundError: + pytest.skip("GRPC is unavailable", allow_module_level=True) + from declearn.communication.api.backend.actions import Ping from declearn.communication.grpc._server import load_pem_file from declearn.communication.grpc import GrpcClient, GrpcServer diff --git a/test/functional/test_toy_reg.py b/test/functional/test_toy_reg.py index 93fc05fbe88351d7fc8c8d375ae47e11b7d98ac4..c37f941fc2042871f43ad27dfb0735a6bbcde45b 100644 --- a/test/functional/test_toy_reg.py +++ b/test/functional/test_toy_reg.py @@ -67,7 +67,6 @@ from declearn.utils import set_device_policy # optional frameworks' dependencies pylint: disable=ungrouped-imports # pylint: disable=duplicate-code -# false-positives; pylint: disable=no-member # tensorflow imports try: @@ -104,7 +103,8 @@ else: return (y_pred - y_true) ** 2 -# pylint: disable=duplicate-code +# pylint: enable=duplicate-code, ungrouped-imports + SEED = 0 R2_THRESHOLD = 0.9999 @@ -124,7 +124,7 @@ def get_model(framework: FrameworkType) -> Model: raise ValueError(f"Unrecognised model framework: '{framework}'.") -def _get_model_numpy() -> SklearnSGDModel: +def _get_model_numpy() -> Model: """Return a linear model with MSE loss in Sklearn, with zero weights.""" np.random.seed(SEED) # set seed model = SklearnSGDModel.from_parameters( @@ -133,10 +133,12 @@ def _get_model_numpy() -> SklearnSGDModel: return model -def _get_model_tflow() -> TensorflowModel: +def _get_model_tflow() -> Model: """Return a linear model with MSE loss in TensorFlow, with zero weights.""" tf.random.set_seed(SEED) # set seed - tfmod = tf.keras.Sequential(tf.keras.layers.Dense(units=1)) + tfmod = tf.keras.Sequential( # pylint: disable=no-member + tf.keras.layers.Dense(units=1) # pylint: disable=no-member + ) tfmod.build([None, 100]) model = TensorflowModel(tfmod, loss="mean_squared_error") with tf.device("CPU"): @@ -148,7 +150,7 @@ def _get_model_tflow() -> TensorflowModel: return model -def _get_model_torch() -> TorchModel: +def _get_model_torch() -> Model: """Return a linear model with MSE loss in Torch, with zero weights.""" torch.manual_seed(SEED) # set seed torchmod = torch.nn.Sequential( @@ -164,7 +166,7 @@ def _get_model_torch() -> TorchModel: return model -def _get_model_haiku() -> HaikuModel: +def _get_model_haiku() -> Model: """Return a linear model with MSE loss in Haiku, with zero weights.""" model = HaikuModel(haiku_model_fn, loss=haiku_loss_fn) model.initialize({"data_type": "float32", "features_shape": (100,)}) @@ -427,6 +429,9 @@ async def async_run_client( await client.async_run() +# similar structure in other functional tests; pylint: disable=duplicate-code + + @pytest.mark.asyncio async def test_declearn_federated( framework: FrameworkType, diff --git a/test/main/test_train_manager_dp.py b/test/main/test_train_manager_dp.py index 352a1f2b7b7e92dc483c93af5432c11b364afbcc..3735f290b8e9da87be72d4dce2ce2024ed95307a 100644 --- a/test/main/test_train_manager_dp.py +++ b/test/main/test_train_manager_dp.py @@ -21,8 +21,12 @@ import os from typing import Any, Optional import pytest -from opacus.accountants import RDPAccountant # type: ignore -from opacus.accountants.utils import get_noise_multiplier # type: ignore + +try: + from opacus.accountants import RDPAccountant # type: ignore + from opacus.accountants.utils import get_noise_multiplier # type: ignore +except ModuleNotFoundError: + pytest.skip("Opacus is unavailable", allow_module_level=True) from declearn.communication import messaging from declearn.dataset import DataSpecs