diff --git a/README.md b/README.md index 7045421d47ebbba50d495b3d1ab09f4d677e35e1..366561648cc5995488e0c08c2447649c1ed94461 100644 --- a/README.md +++ b/README.md @@ -62,15 +62,19 @@ file. Some third-party requirements are optional, and may not be installed. These are also specified as part of the `pyproject.toml` file, and may be divided into two categories:<br/> -(a) dependencies of optional, applied declearn components -(such as the TensorFlow and PyTorch tensor libraries) that are not imported -with declearn by default ;<br/> +(a) dependencies of optional, applied declearn components (such as the PyTorch +and Tensorflow tensor libraries, or the gRPC and websockets network +communication backends) that are not imported with declearn by default<br/> (b) dependencies for running tests on the package (mainly pytest and some of its plug-ins) The second category is more developer-oriented, while the first may or may not be relevant depending on the use case to which you wish to apply `declearn`. +In the `pyproject.toml` file, the `[project.optional-dependencies]` tables +`all` and `test` respectively list the first and (first + second) categories, +while additional tables redundantly list dependencies unit by unit. + ### Using a virtual environment (optional) It is generally advised to use a virtual environment, to avoid any dependency @@ -113,8 +117,13 @@ To also install optional requirements, add the name of the extras between brackets to the `pip install` command, _e.g._ running one of the following: ```bash +# Examples of cherry-picked installation instructions. +pip install .[grpc] # install dependencies to support gRPC communications pip install .[torch] # install declearn.model.torch submodule dependencies pip install .[tensorflow,torch] # install both tensorflow and torch + +# Instructions to install bundles of optional components. +pip install .[all] # install all optional dependencies, save for testing pip install .[tests] # install all optional dependencies plus testing ones ``` diff --git a/declearn/communication/__init__.py b/declearn/communication/__init__.py index 3f0b4b5ef678a851dbc7d90743305bd3c2f95978..393a85e021dbd6ad8d0b75fc3958526298626d89 100644 --- a/declearn/communication/__init__.py +++ b/declearn/communication/__init__.py @@ -2,18 +2,15 @@ """Submodule implementing client/server communications. -This module contains the following submodules: +This module contains the following core submodules: * api: Base API to define client- and server-side communication endpoints. * messaging: Message dataclasses defining information containers to be exchanged between communication endpoints. -* grpc: - gRPC-based network communication endpoints. -* websockets: - WebSockets-based network communication endpoints. -It also exposes the following functions: + +It also exposes the following core utility functions: * build_client: Instantiate a NetworkClient, selecting its subclass based on protocol name. * build_server: @@ -23,6 +20,14 @@ It also exposes the following functions: classes are registered (hence available to `build_client`/`build_server`). +Finally, it defines the following protocol-specific submodules, provided +the associated third-party dependencies are available: +* grpc: + gRPC-based network communication endpoints. + Requires the `grpcio` and `protobuf` third-party packages. +* websockets: + WebSockets-based network communication endpoints. + Requires the `websockets` third-party package. """ # Messaging and Communications API and base tools: @@ -34,8 +39,15 @@ from ._build import ( build_client, build_server, list_available_protocols, + _INSTALLABLE_BACKENDS, ) # Concrete implementations using various protocols: -from . import grpc -from . import websockets +try: + from . import grpc +except ImportError: + _INSTALLABLE_BACKENDS["grpc"] = ("grpcio", "protobuf") +try: + from . import websockets +except ImportError: + _INSTALLABLE_BACKENDS["websockets"] = ("websockets",) diff --git a/declearn/communication/_build.py b/declearn/communication/_build.py index d9638f8d5f290bc088e0dc4e311a6bde19b86d86..d57b1fccaca8be521daed67e98f23b62a30f14ab 100644 --- a/declearn/communication/_build.py +++ b/declearn/communication/_build.py @@ -4,7 +4,7 @@ import dataclasses import logging -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union from declearn.communication.api import NetworkClient, NetworkServer @@ -17,9 +17,27 @@ __all__ = [ "build_client", "build_server", "list_available_protocols", + "_INSTALLABLE_BACKENDS", ] +_INSTALLABLE_BACKENDS = {} # type: Dict[str, Tuple[str, ...]] + + +def raise_if_installable( + protocol: str, + exc: Optional[Exception] = None, +) -> None: + """Raise a RuntimeError if a given protocol is missing but installable.""" + if protocol in _INSTALLABLE_BACKENDS: + raise RuntimeError( + f"The '{protocol}' communication protocol network endpoints " + "could not be imported, but could be installed by satisfying " + f"the following dependencies: {_INSTALLABLE_BACKENDS[protocol]}, " + f"or by running `pip install declearn[{protocol}]`." + ) from exc + + def build_client( protocol: str, server_uri: str, @@ -54,6 +72,7 @@ def build_client( try: cls = access_registered(name=protocol, group="NetworkClient") except KeyError as exc: + raise_if_installable(protocol, exc) raise KeyError( "Failed to retrieve NetworkClient " f"class for protocol '{protocol}'." @@ -155,6 +174,7 @@ def build_server( try: cls = access_registered(name=protocol, group="NetworkServer") except KeyError as exc: + raise_if_installable(protocol, exc) raise KeyError( "Failed to retrieve NetworkServer " f"class for protocol '{protocol}'." diff --git a/pyproject.toml b/pyproject.toml index e7faa0880309cc72c6bf2e3dd7c7d13e4294637f..4326b0f73b38b02b56bcece39d7c5dc6cece87ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,22 +36,29 @@ dependencies = [ "cryptography >= 35.0", "grpcio >= 1.45", "pandas >= 1.2", - "protobuf >= 3.19", + "scikit-learn >= 1.0", "typing_extensions >= 4.0", "websockets >= 10.1", ] [project.optional-dependencies] -all = [ - "opacus >= 1.1", +all = [ # all non-tests extra dependencies "functorch >= 0.1", + "grpcio >= 1.45", + "opacus >= 1.1", + "protobuf >= 3.19", "tensorflow >= 2.5", "torch >= 1.10", + "websockets >= 10.1", ] dp = [ "opacus >= 1.1", ] +grpc = [ + "grpcio >= 1.45", + "protobuf >= 3.19", +] tensorflow = [ "tensorflow >= 2.5", ] @@ -59,6 +66,9 @@ torch = [ "functorch >= 0.1", # note: functorch is included with torch>=1.13 "torch >= 1.10", ] +websockets = [ + "websockets >= 10.1", +] tests = [ # test-specific dependencies "black ~= 22.0", @@ -66,11 +76,14 @@ tests = [ "pylint >= 2.14", "pytest >= 6.1", "pytest-asyncio", - # other extra dependencies - "opacus >= 1.1", + # other extra dependencies (copy of "all") "functorch >= 0.1", + "grpcio >= 1.45", + "opacus >= 1.1", + "protobuf >= 3.19", "tensorflow >= 2.5", "torch >= 1.10", + "websockets >= 10.1", ] [project.urls] diff --git a/test/test_main.py b/test/test_main.py index 85fb51d97d473526dd83c838bc05641c5cc013dc..ebda7f3463d244ed024095573a2903344b106dc5 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -15,8 +15,12 @@ with warnings.catch_warnings(): # silence tensorflow import-time warnings import torch from typing_extensions import Literal # future: import from typing (Py>=3.8) -from declearn.communication import build_client, build_server -from declearn.communication.api import Client, Server +from declearn.communication import ( + build_client, + build_server, + list_available_protocols, +) +from declearn.communication.api import NetworkClient, NetworkServer from declearn.dataset import InMemoryDataset from declearn.model.api import Model from declearn.model.sklearn import SklearnSGDModel @@ -141,8 +145,8 @@ class DeclearnTestCase: def build_netwk_server( self, - ) -> Server: - """Return a communication Server.""" + ) -> NetworkServer: + """Return a NetworkServer instance.""" return build_server( self.protocol, host="127.0.0.1", @@ -154,8 +158,8 @@ class DeclearnTestCase: def build_netwk_client( self, name: str = "client", - ) -> Client: - """Return a communication Client.""" + ) -> NetworkClient: + """Return a NetworkClient instance.""" server_uri = "localhost:8765" if self.protocol == "websockets": server_uri = f"ws{'s' * self.use_ssl}://" + server_uri @@ -257,12 +261,19 @@ def test_declearn( Note: Use unsecured websockets communication, which are less costful to establish than gRPC and/or SSL-secured ones (the latter due to the certificates-generation costs). + Note: if websockets is unavailable, use gRPC (warn) or fail. """ if not fulltest: if (kind != "Reg") or (strategy == "FedAvgM"): pytest.skip("skip scenario (no --fulltest option)") + protocol = "websockets" # type: Literal["grpc", "websockets"] + if "websockets" not in list_available_protocols(): + if "grpc" not in list_available_protocols(): + pytest.fail("Both 'grpc' and 'websockets' are unavailable.") + protocol = "grpc" + warnings.warn("Using 'grpc' as 'websockets' is unavailable.") # fmt: off run_test_case( kind, framework, strategy, - nb_clients=2, protocol="websockets", use_ssl=False, rounds=2, + nb_clients=2, protocol=protocol, use_ssl=False, rounds=2, )