Mentions légales du service

Skip to content
Snippets Groups Projects
Verified Commit 99486665 authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Fix tests' resilience to missing optional dependencies.

- Run `Aggregator` unit tests for available frameworks only.
- Skip `gRPC`-dedicated tests when grpc is not installed.
- Skip `DPTrainingManager` tests when opacus is not installed.
- Fix toy regression functional tests failing when some frameworks
  are missing due to the use of too-specific type hints.
parent e1d5e4f0
No related branches found
No related tags found
No related merge requests found
......@@ -17,7 +17,6 @@
"""Unit tests for the 'Aggregator' subclasses."""
import typing
from typing import Dict, Type
import pytest
......@@ -29,12 +28,13 @@ from declearn.test_utils import (
GradientsTestCase,
assert_dict_equal,
assert_json_serializable_dict,
list_available_frameworks,
)
from declearn.utils import set_device_policy
AGGREGATOR_CLASSES = list_aggregators()
VECTOR_FRAMEWORKS = typing.get_args(FrameworkType)
VECTOR_FRAMEWORKS = list_available_frameworks()
@pytest.fixture(name="updates")
......
......@@ -31,10 +31,14 @@ import asyncio
import uuid
from typing import AsyncIterator, Dict, Iterator
import grpc # type: ignore
import pytest
import pytest_asyncio
try:
import grpc # type: ignore
except ModuleNotFoundError:
pytest.skip("GRPC is unavailable", allow_module_level=True)
from declearn.communication.api.backend.actions import Ping
from declearn.communication.grpc._server import load_pem_file
from declearn.communication.grpc import GrpcClient, GrpcServer
......
......@@ -67,7 +67,6 @@ from declearn.utils import set_device_policy
# optional frameworks' dependencies pylint: disable=ungrouped-imports
# pylint: disable=duplicate-code
# false-positives; pylint: disable=no-member
# tensorflow imports
try:
......@@ -104,7 +103,8 @@ else:
return (y_pred - y_true) ** 2
# pylint: disable=duplicate-code
# pylint: enable=duplicate-code, ungrouped-imports
SEED = 0
R2_THRESHOLD = 0.9999
......@@ -124,7 +124,7 @@ def get_model(framework: FrameworkType) -> Model:
raise ValueError(f"Unrecognised model framework: '{framework}'.")
def _get_model_numpy() -> SklearnSGDModel:
def _get_model_numpy() -> Model:
"""Return a linear model with MSE loss in Sklearn, with zero weights."""
np.random.seed(SEED) # set seed
model = SklearnSGDModel.from_parameters(
......@@ -133,10 +133,12 @@ def _get_model_numpy() -> SklearnSGDModel:
return model
def _get_model_tflow() -> TensorflowModel:
def _get_model_tflow() -> Model:
"""Return a linear model with MSE loss in TensorFlow, with zero weights."""
tf.random.set_seed(SEED) # set seed
tfmod = tf.keras.Sequential(tf.keras.layers.Dense(units=1))
tfmod = tf.keras.Sequential( # pylint: disable=no-member
tf.keras.layers.Dense(units=1) # pylint: disable=no-member
)
tfmod.build([None, 100])
model = TensorflowModel(tfmod, loss="mean_squared_error")
with tf.device("CPU"):
......@@ -148,7 +150,7 @@ def _get_model_tflow() -> TensorflowModel:
return model
def _get_model_torch() -> TorchModel:
def _get_model_torch() -> Model:
"""Return a linear model with MSE loss in Torch, with zero weights."""
torch.manual_seed(SEED) # set seed
torchmod = torch.nn.Sequential(
......@@ -164,7 +166,7 @@ def _get_model_torch() -> TorchModel:
return model
def _get_model_haiku() -> HaikuModel:
def _get_model_haiku() -> Model:
"""Return a linear model with MSE loss in Haiku, with zero weights."""
model = HaikuModel(haiku_model_fn, loss=haiku_loss_fn)
model.initialize({"data_type": "float32", "features_shape": (100,)})
......@@ -427,6 +429,9 @@ async def async_run_client(
await client.async_run()
# similar structure in other functional tests; pylint: disable=duplicate-code
@pytest.mark.asyncio
async def test_declearn_federated(
framework: FrameworkType,
......
......@@ -21,8 +21,12 @@ import os
from typing import Any, Optional
import pytest
from opacus.accountants import RDPAccountant # type: ignore
from opacus.accountants.utils import get_noise_multiplier # type: ignore
try:
from opacus.accountants import RDPAccountant # type: ignore
from opacus.accountants.utils import get_noise_multiplier # type: ignore
except ModuleNotFoundError:
pytest.skip("Opacus is unavailable", allow_module_level=True)
from declearn.communication import messaging
from declearn.dataset import DataSpecs
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment