Mentions légales du service

Skip to content
Snippets Groups Projects
Verified Commit 27e0122a authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Fix tests' resilience to missing optional dependencies.

- Run `Aggregator` unit tests for available frameworks only.
- Skip `gRPC`-dedicated tests when grpc is not installed.
- Skip `DPTrainingManager` tests when opacus is not installed.
- Skip some `FederatedClient` tests when opacus is not installed.
- Fix toy regression functional tests failing when some frameworks
  are missing due to the use of too-specific type hints.
parent 72ccc11d
No related branches found
No related tags found
No related merge requests found
......@@ -17,7 +17,6 @@
"""Unit tests for the 'Aggregator' subclasses."""
import typing
from typing import Dict, Type
import pytest
......@@ -29,12 +28,13 @@ from declearn.test_utils import (
GradientsTestCase,
assert_dict_equal,
assert_json_serializable_dict,
list_available_frameworks,
)
from declearn.utils import set_device_policy
AGGREGATOR_CLASSES = list_aggregators()
VECTOR_FRAMEWORKS = typing.get_args(FrameworkType)
VECTOR_FRAMEWORKS = list_available_frameworks()
@pytest.fixture(name="updates")
......
......@@ -31,10 +31,14 @@ import asyncio
import uuid
from typing import AsyncIterator, Dict, Iterator
import grpc # type: ignore
import pytest
import pytest_asyncio
try:
import grpc # type: ignore
except ModuleNotFoundError:
pytest.skip("GRPC is unavailable", allow_module_level=True)
from declearn.communication.api.backend.actions import Ping
from declearn.communication.grpc._server import load_pem_file
from declearn.communication.grpc import GrpcClient, GrpcServer
......
......@@ -67,7 +67,6 @@ from declearn.utils import set_device_policy
# optional frameworks' dependencies pylint: disable=ungrouped-imports
# pylint: disable=duplicate-code
# false-positives; pylint: disable=no-member
# tensorflow imports
try:
......@@ -104,7 +103,8 @@ else:
return (y_pred - y_true) ** 2
# pylint: disable=duplicate-code
# pylint: enable=duplicate-code, ungrouped-imports
SEED = 0
R2_THRESHOLD = 0.9999
......@@ -124,7 +124,7 @@ def get_model(framework: FrameworkType) -> Model:
raise ValueError(f"Unrecognised model framework: '{framework}'.")
def _get_model_numpy() -> SklearnSGDModel:
def _get_model_numpy() -> Model:
"""Return a linear model with MSE loss in Sklearn, with zero weights."""
np.random.seed(SEED) # set seed
model = SklearnSGDModel.from_parameters(
......@@ -133,10 +133,12 @@ def _get_model_numpy() -> SklearnSGDModel:
return model
def _get_model_tflow() -> TensorflowModel:
def _get_model_tflow() -> Model:
"""Return a linear model with MSE loss in TensorFlow, with zero weights."""
tf.random.set_seed(SEED) # set seed
tfmod = tf.keras.Sequential(tf.keras.layers.Dense(units=1))
tfmod = tf.keras.Sequential( # pylint: disable=no-member
tf.keras.layers.Dense(units=1) # pylint: disable=no-member
)
tfmod.build([None, 100])
model = TensorflowModel(tfmod, loss="mean_squared_error")
with tf.device("CPU"):
......@@ -148,7 +150,7 @@ def _get_model_tflow() -> TensorflowModel:
return model
def _get_model_torch() -> TorchModel:
def _get_model_torch() -> Model:
"""Return a linear model with MSE loss in Torch, with zero weights."""
torch.manual_seed(SEED) # set seed
torchmod = torch.nn.Sequential(
......@@ -164,7 +166,7 @@ def _get_model_torch() -> TorchModel:
return model
def _get_model_haiku() -> HaikuModel:
def _get_model_haiku() -> Model:
"""Return a linear model with MSE loss in Haiku, with zero weights."""
model = HaikuModel(haiku_model_fn, loss=haiku_loss_fn)
model.initialize({"data_type": "float32", "features_shape": (100,)})
......@@ -427,6 +429,9 @@ async def async_run_client(
await client.async_run()
# similar structure in other functional tests; pylint: disable=duplicate-code
@pytest.mark.asyncio
async def test_declearn_federated(
framework: FrameworkType,
......
......@@ -29,7 +29,6 @@ from declearn.dataset import Dataset, DataSpecs
from declearn.communication import NetworkClientConfig
from declearn.communication.api import NetworkClient
from declearn.main import FederatedClient
from declearn.main.privacy import DPTrainingManager
from declearn.main.utils import Checkpointer, TrainingManager
from declearn.metrics import MetricState
from declearn.model.api import Model
......@@ -37,6 +36,13 @@ from declearn.secagg import messaging as secagg_messaging
from declearn.secagg.api import SecaggConfigClient, SecaggSetupQuery
from declearn.utils import LOGGING_LEVEL_MAJOR
try:
from declearn.main.privacy import DPTrainingManager
except ModuleNotFoundError:
DP_AVAILABLE = False
else:
DP_AVAILABLE = True
MOCK_NETWK = mock.create_autospec(NetworkClient, instance=True)
MOCK_NETWK.name = "client"
......@@ -514,6 +520,8 @@ class TestFederatedClientInitialize:
@pytest.mark.asyncio
async def test_initialize_with_dpsgd(self) -> None:
"""Test that initialization with DP-SGD works properly."""
if not DP_AVAILABLE:
pytest.skip(reason="Unavailable DP features (missing Opacus).")
# Set up a mock network receiving an InitRequest and a PrivacyRequest.
netwk = mock.create_autospec(NetworkClient, instance=True)
netwk.name = "client"
......@@ -555,6 +563,8 @@ class TestFederatedClientInitialize:
@pytest.mark.asyncio
async def test_initialize_with_dpsgd_error_wrong_message(self) -> None:
"""Test error catching for DP-SGD setup with wrong second message."""
if not DP_AVAILABLE:
pytest.skip(reason="Unavailable DP features (missing Opacus).")
# Set up a mock network receiving a DP InitRequest but wrong follow-up.
netwk = mock.create_autospec(NetworkClient, instance=True)
netwk.name = "client"
......@@ -580,6 +590,8 @@ class TestFederatedClientInitialize:
@pytest.mark.asyncio
async def test_initialize_with_dpsgd_error_setup(self) -> None:
"""Test error catching for DP-SGD setup with client-side failure."""
if not DP_AVAILABLE:
pytest.skip(reason="Unavailable DP features (missing Opacus).")
# Set up a mock network receiving an InitRequest and a PrivacyRequest.
netwk = mock.create_autospec(NetworkClient, instance=True)
netwk.name = "client"
......
......@@ -21,8 +21,12 @@ import os
from typing import Any, Optional
import pytest
from opacus.accountants import RDPAccountant # type: ignore
from opacus.accountants.utils import get_noise_multiplier # type: ignore
try:
from opacus.accountants import RDPAccountant # type: ignore
from opacus.accountants.utils import get_noise_multiplier # type: ignore
except ModuleNotFoundError:
pytest.skip("Opacus is unavailable", allow_module_level=True)
from declearn.communication import messaging
from declearn.dataset import DataSpecs
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment