diff --git a/declearn/communication/api/_server.py b/declearn/communication/api/_server.py index 1d28545b3b561db44687091aa4d4b25045c860e1..a4df5dbb9b7ba2fb644df68c18e997ba7d28b575 100644 --- a/declearn/communication/api/_server.py +++ b/declearn/communication/api/_server.py @@ -219,7 +219,6 @@ class NetworkServer(metaclass=ABCMeta): If the number of registered clients does not abide by the provided boundaries at the end of the process. - Returns ------- client_info: dict[str, dict[str, any]] diff --git a/declearn/communication/websockets/_server.py b/declearn/communication/websockets/_server.py index 89516d443a9dbef4e91bfdfd10128cdb2259d713..e2c618df71220d1a991d03f2a5e7a3f4002a6938 100644 --- a/declearn/communication/websockets/_server.py +++ b/declearn/communication/websockets/_server.py @@ -167,5 +167,6 @@ class WebsocketsServer(NetworkServer): """Stop the websockets server and purge information about clients.""" if self._server is not None: self._server.close() + await self._server.wait_closed() self._server = None await self.handler.purge() diff --git a/test/communication/test_exchanges.py b/test/communication/test_exchanges.py new file mode 100644 index 0000000000000000000000000000000000000000..2b06aa77f639c05869fdc4ff72248844dee3f70f --- /dev/null +++ b/test/communication/test_exchanges.py @@ -0,0 +1,254 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Minimal unit tests for declearn.communication network endpoint classes. + +The tests implemented here gradually assess that a NetworkServer and one +or multiple properly-configured NetworkClient instances can connect with +each other and exchange information over the localhost using the actual +protocol they rely upon (as opposed to assessing classes' behaviors with +mock communication I/O). + +Here, tests are run incrementally that assess: + - the possibility for connections to be set up + - the possibility for a server to perform clients' registration + - the possibility to exchange messages in both directions between + a server and registered clients + +All available communication protocols are used, with or without SSL. +In the last tier of tests, unencrypted communications stop being tested +(notably because they can fail on the local host due to clients _not_ +being identified as distinct by some protocols, such as gRPC) and tests +are run with either a single or three clients at once. +""" + +import asyncio +from typing import AsyncIterator, Dict, List, Optional, Tuple + +import pytest +import pytest_asyncio + +from declearn.communication import ( + build_client, + build_server, + list_available_protocols, + messaging, +) +from declearn.communication.api import NetworkClient, NetworkServer + + +### 1. Test that connections can properly be set up. + + +@pytest_asyncio.fixture(name="server") +async def server_fixture( + protocol: str, + ssl_cert: Dict[str, str], + ssl: bool, +) -> AsyncIterator[NetworkServer]: + """Fixture to provide with an instantiated and started NetworkServer.""" + server = build_server( + protocol=protocol, + host="127.0.0.1", # truly "localhost", but fails on the CI otherwise + port=8765, + certificate=ssl_cert["server_cert"] if ssl else None, + private_key=ssl_cert["server_pkey"] if ssl else None, + ) + await server.start() + yield server + await server.stop() + + +def client_from_server( + server: NetworkServer, + c_name: str = "client", + ca_ssl: Optional[str] = None, +) -> NetworkClient: + """Instantiate a NetworkClient based on a NetworkServer.""" + return build_client( + protocol=server.protocol, + server_uri=server.uri.replace("127.0.0.1", "localhost"), + name=c_name, + certificate=ca_ssl, + ) + + +@pytest.mark.parametrize("ssl", [True, False], ids=["ssl", "no_ssl"]) +@pytest.mark.parametrize("protocol", list_available_protocols()) +@pytest.mark.asyncio +async def test_network_connect( + server: NetworkServer, + ssl_cert: Dict[str, str], + ssl: bool, +) -> None: + """Test that connections can be set up for a given framework.""" + ca_ssl = ssl_cert["client_cert"] if ssl else None + client = client_from_server(server, c_name="client", ca_ssl=ca_ssl) + await client.start() + await client.stop() + + +### 2. Test that a client can be registered over network by a server. + + +@pytest_asyncio.fixture(name="client") +async def client_fixture( + server: NetworkServer, + ssl_cert: Dict[str, str], + ssl: bool, +) -> AsyncIterator[NetworkClient]: + """Fixture to provide with an instantiated and started NetworkClient.""" + client = client_from_server( + server, + c_name="client", + ca_ssl=ssl_cert["client_cert"] if ssl else None, + ) + await client.start() + yield client + await client.stop() + + +@pytest.mark.parametrize("ssl", [True, False], ids=["ssl", "no_ssl"]) +@pytest.mark.parametrize("protocol", list_available_protocols()) +class TestNetworkRegister: + """Unit tests for client-registration operations.""" + + @pytest.mark.asyncio + async def test_early_request( + self, server: NetworkServer, client: NetworkClient + ) -> None: + """Test that early registration requests are rejected.""" + accepted = await client.register(data_info={"test": 42}) + assert not accepted + assert not server.client_names + + @pytest.mark.asyncio + async def test_register( + self, server: NetworkServer, client: NetworkClient + ) -> None: + """Test that client registration works properly.""" + data_info, accepted = await asyncio.gather( + server.wait_for_clients(1), + client.register(data_info={"test": 42}), + ) + assert data_info == {"client": {"test": 42}} + assert accepted + assert server.client_names == {"client"} + + @pytest.mark.asyncio + async def test_register_late( + self, server: NetworkServer, client: NetworkClient + ) -> None: + """Test that late client registration fails properly.""" + # Wait for clients, with a timeout. + with pytest.raises(RuntimeError): + await server.wait_for_clients(timeout=1) + # Try registering after that timeout. + accepted = await client.register(data_info={"test": 42}) + assert not accepted + + +### 3. Test that a server and its registered clients can exchange messages. + + +@pytest_asyncio.fixture(name="agents") +async def agents_fixture( + server: NetworkServer, + n_clients: int, + ssl_cert: Dict[str, str], + ssl: bool, +) -> AsyncIterator[Tuple[NetworkServer, List[NetworkClient]]]: + """Fixture to provide with a server and pre-registered client(s).""" + # Instantiate the clients. + ca_ssl = ssl_cert["client_cert"] if ssl else None + clients = [ + client_from_server(server, c_name=f"client-{idx}", ca_ssl=ca_ssl) + for idx in range(n_clients) + ] + # Start the clients and have the server register them. + await asyncio.gather(*[client.start() for client in clients]) + await asyncio.gather( + server.wait_for_clients(n_clients, timeout=2), + *[client.register({}) for client in clients], + ) + # Yield the server and clients. On exit, stop the clients. + yield server, clients + await asyncio.gather(*[client.stop() for client in clients]) + + +@pytest.mark.parametrize("n_clients", [1, 3], ids=["1_client", "3_clients"]) +@pytest.mark.parametrize("ssl", [True], ids=["ssl"]) +@pytest.mark.parametrize("protocol", list_available_protocols()) +class TestNetworkExchanges: + """Unit tests for messaging-over-network operations. + + Note: the unit tests implemented here are grouped into a single + larger call, to avoid setup costs, while preserving code + and failure-information readability (hopefully). + """ + + @pytest.mark.asyncio + async def test_exchanges( + self, agents: Tuple[NetworkServer, List[NetworkClient]] + ) -> None: + """Run all tests with the same fixture-provided agents.""" + await self.clients_to_server(agents) + await self.server_to_clients_broadcast(agents) + await self.server_to_clients_individual(agents) + + async def clients_to_server( + self, agents: Tuple[NetworkServer, List[NetworkClient]] + ) -> None: + """Test that clients can send messages to the server.""" + server, clients = agents + coros = [] + for idx, client in enumerate(clients): + msg = messaging.GenericMessage(action="test", params={"idx": idx}) + coros.append(client.send_message(msg)) + messages, *_ = await asyncio.gather(server.wait_for_messages(), *coros) + assert messages == { + c.name: messaging.GenericMessage(action="test", params={"idx": i}) + for i, c in enumerate(clients) + } + + async def server_to_clients_broadcast( + self, agents: Tuple[NetworkServer, List[NetworkClient]] + ) -> None: + """Test that the server can send a shared message to all clients.""" + server, clients = agents + msg = messaging.GenericMessage(action="test", params={"value": 42}) + send = server.broadcast_message(msg) + recv = [client.check_message(timeout=1) for client in clients] + _, *replies = await asyncio.gather(send, *recv) + assert all(reply == msg for reply in replies) + + async def server_to_clients_individual( + self, agents: Tuple[NetworkServer, List[NetworkClient]] + ) -> None: + """Test that the server can send individual messages to clients.""" + server, clients = agents + messages = { + name: messaging.GenericMessage(action="test", params={"idx": idx}) + for idx, name in enumerate(server.client_names) + } # type: Dict[str, messaging.Message] + send = server.send_messages(messages) + recv = [client.check_message(timeout=1) for client in clients] + _, *replies = await asyncio.gather(send, *recv) + assert all( + reply == messages[client.name] + for client, reply in zip(clients, replies) + ) diff --git a/test/communication/test_routines.py b/test/communication/test_routines.py deleted file mode 100644 index b5dcc35a5c4dcdcd140020dd11b5b8d3462be41b..0000000000000000000000000000000000000000 --- a/test/communication/test_routines.py +++ /dev/null @@ -1,196 +0,0 @@ -# coding: utf-8 - -# Copyright 2023 Inria (Institut National de Recherche en Informatique -# et Automatique) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Functional test for declearn.communication classes. - -The test implemented here spawns a NetworkServer endpoint as well as one -or multiple NetworkClient ones, then runs parallelly routines that have -the clients register, and both sides exchange dummy messages. As such, -it only verifies that messages passing works, and does not constitute a -proper (ensemble of) unit test(s) of the classes. - -However, if this passes, it means that registration and basic -message passing work properly, using the following scenarios: -* gRPC or WebSockets protocol -* SSL-secured communications or not -* 1-client or 3-clients cases - -Note that the tests are somewhat slow when collected by pytest, -and that they make use of the multiprocessing library to isolate -the server and individual clients - which is not required when -running the code manually, and might require using '--full-trace' -pytest option to debug in case a test fails. - -Note: running code that uses `asyncio.gather` on concurrent coroutines -is unsuccessful with gRPC due to spawned clients sharing the same peer -context. This may be fixed by implementing proper authentication. -""" - -import asyncio -from typing import Any, Callable, Dict, List, Tuple - -import pytest - -from declearn.communication import ( - build_client, - build_server, - list_available_protocols, -) -from declearn.communication.api import NetworkClient, NetworkServer -from declearn.communication.messaging import GenericMessage -from declearn.test_utils import run_as_processes - - -async def client_routine( - client: NetworkClient, -) -> None: - """Basic client testing routine.""" - print("Registering") - await client.register({"foo": "bar"}) - print("Receiving") - message = await client.check_message() - print(message) - print("Sending") - await client.send_message(GenericMessage(action="maybe", params={})) - print("Receiving") - message = await client.check_message() - print(message) - print("Sending") - await client.send_message(message) - print("Done!") - - -async def server_routine( - server: NetworkServer, - nb_clients: int = 1, -) -> None: - """Basic server testing routine.""" - data_info = await server.wait_for_clients( - min_clients=nb_clients, max_clients=nb_clients, timeout=5 - ) - print(data_info) - print("Sending") - await server.broadcast_message( - GenericMessage(action="train", params={"let's": "go"}) - ) - print("Receiving") - messages = await server.wait_for_messages() - print(messages) - print("Sending") - messages = { - client: GenericMessage("hello", {"name": client}) - for client in server.client_names - } - await server.send_messages(messages) - print("Receiving") - messages = await server.wait_for_messages() - print(messages) - print("Closing") - - -@pytest.mark.parametrize("nb_clients", [1, 3], ids=["1_client", "3_clients"]) -@pytest.mark.parametrize("use_ssl", [False, True], ids=["ssl", "unsafe"]) -@pytest.mark.parametrize("protocol", list_available_protocols()) -def test_routines( - protocol: str, - nb_clients: int, - use_ssl: bool, - ssl_cert: Dict[str, str], -) -> None: - """Test that the defined server and client routines run properly.""" - run_test_routines(protocol, nb_clients, use_ssl, ssl_cert) - - -def run_test_routines( - protocol: str, - nb_clients: int, - use_ssl: bool, - ssl_cert: Dict[str, str], -) -> None: - """Test that the defined server and client routines run properly.""" - # Set up (func, args) tuples that specify concurrent routines. - args = (protocol, nb_clients, use_ssl, ssl_cert) - routines = [_build_server_func(*args)] - routines.extend(_build_client_funcs(*args)) - # Run the former using isolated processes. - success, outputs = run_as_processes(*routines) - # Assert that all processes terminated properly. - assert success, "Routines failed:\n" + "\n".join( - [str(exc) for exc in outputs if isinstance(exc, RuntimeError)] - ) - - -def _build_server_func( - protocol: str, - nb_clients: int, - use_ssl: bool, - ssl_cert: Dict[str, str], -) -> Tuple[Callable[..., None], Tuple[Any, ...]]: - """Return arguments to spawn and use a NetworkServer in a process.""" - server_cfg = { - "protocol": protocol, - "host": "127.0.0.1", - "port": 8765, - "certificate": ssl_cert["server_cert"] if use_ssl else None, - "private_key": ssl_cert["server_pkey"] if use_ssl else None, - } # type: Dict[str, Any] - - # Define a coroutine that spawns and runs a server. - async def server_coroutine() -> None: - """Spawn a client and run `server_routine` in its context.""" - nonlocal nb_clients, server_cfg - async with build_server(**server_cfg) as server: - await server_routine(server, nb_clients) - - # Define a routine that runs the former. - def server_func() -> None: - """Run `server_coroutine`.""" - asyncio.run(server_coroutine()) - - # Return the former as a (func, arg) tuple. - return (server_func, tuple()) - - -def _build_client_funcs( - protocol: str, - nb_clients: int, - use_ssl: bool, - ssl_cert: Dict[str, str], -) -> List[Tuple[Callable[..., None], Tuple[Any, ...]]]: - """Return arguments to spawn and use NetworkClient objects in processes.""" - certificate = ssl_cert["client_cert"] if use_ssl else None - server_uri = "localhost:8765" - if protocol == "websockets": - server_uri = f"ws{'s' * use_ssl}://{server_uri}" - - # Define a coroutine that spawns and runs a client. - async def client_coroutine( - name: str, - ) -> None: - """Spawn a client and run `client_routine` in its context.""" - nonlocal certificate, protocol, server_uri - args = (protocol, server_uri, name, certificate) - async with build_client(*args) as client: - await client_routine(client) - - # Define a routine that runs the former. - def client_func(name: str) -> None: - """Run `client_coroutine`.""" - asyncio.run(client_coroutine(name)) - - # Return a list of (func, args) tuples. - return [(client_func, (f"client_{idx}",)) for idx in range(nb_clients)] diff --git a/test/communication/test_server.py b/test/communication/test_server.py new file mode 100644 index 0000000000000000000000000000000000000000..0cf6613ec8a1cfe94541877f7de8623f4351c5e7 --- /dev/null +++ b/test/communication/test_server.py @@ -0,0 +1,374 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for `declearn.communication.api.NetworkServer` classes.""" + +import asyncio +from typing import AsyncIterator, Dict +from unittest import mock + +import pytest +import pytest_asyncio + +from declearn.communication import ( + build_server, + list_available_protocols, + messaging, +) +from declearn.communication.api import NetworkServer +from declearn.utils import access_types_mapping, get_logger + + +SERVER_CLASSES = access_types_mapping("NetworkServer") + + +@pytest.mark.parametrize("protocol", list_available_protocols()) +class TestNetworkServerInit: + """Unit tests for `declearn.communication.api.NetworkServer` classes. + + This class groups tests that revolve around instantiating + a server and accessing its attributes and properties. + """ + + def test_registered(self, protocol: str) -> None: + """Assert that the tested class is properly type-registered.""" + assert protocol in SERVER_CLASSES + assert issubclass(SERVER_CLASSES[protocol], NetworkServer) + assert SERVER_CLASSES[protocol].protocol == protocol + + def test_init_minimal(self, protocol: str) -> None: + """Test that instantiation with minimal parameters work.""" + cls = SERVER_CLASSES[protocol] + server = cls(host="127.0.0.1", port=8765) + assert isinstance(server, cls) + assert server.host == "127.0.0.1" + assert server.port == 8765 + assert server.handler.__class__.__name__ == "MessagesHandler" + + def test_init_ssl(self, protocol: str, ssl_cert: Dict[str, str]) -> None: + """Test that instantiation with optional SSL parameters works.""" + cls = SERVER_CLASSES[protocol] + server = cls( + host="127.0.0.1", + port=8765, + certificate=ssl_cert["server_cert"], + private_key=ssl_cert["server_pkey"], + ) + assert getattr(server, "_ssl") is not None + + def test_init_ssl_fails(self, protocol: str) -> None: + """Test that instantiation with invalid SSL parameters fails.""" + cls = SERVER_CLASSES[protocol] + with pytest.raises(ValueError): + cls( + host="127.0.0.1", + port=8765, + certificate="certificate", + private_key=None, + ) + with pytest.raises(ValueError): + cls( + host="127.0.0.1", + port=8765, + certificate=None, + private_key="private-key", + ) + + def test_init_logger(self, protocol: str) -> None: + """Test that the 'logger' argument is properly parsed.""" + cls = SERVER_CLASSES[protocol] + logger = get_logger(f"{cls.__name__}Test") + srv = cls("127.0.0.1", 8765, logger=logger) + assert srv.logger is logger + + def test_uri(self, protocol: str) -> None: + """Test that the `uri` property can properly be accessed.""" + cls = SERVER_CLASSES[protocol] + srv = cls("127.0.0.1", 8765) + assert isinstance(srv.uri, str) + + def test_client_names(self, protocol: str) -> None: + """Test that the `client_names` propety can properly be accessed.""" + cls = SERVER_CLASSES[protocol] + srv = cls("127.0.0.1", 8765) + assert srv.client_names == set() + srv.handler.registered_clients[mock.MagicMock()] = "mock" + assert srv.client_names == {"mock"} + + +@pytest_asyncio.fixture(name="server") +async def server_fixture( + protocol: str, +) -> AsyncIterator[NetworkServer]: + """Fixture to provide with an instantiated and started NetworkServer.""" + server = build_server( + protocol=protocol, + host="127.0.0.1", + port=8765, + ) + async with server: + yield server + + +@pytest.mark.parametrize("protocol", list_available_protocols()) +class TestNetworkServerRegister: + """Unit tests for `NetworkServer` client-registration methods.""" + + @pytest.mark.asyncio + async def test_server_early_request(self, server: NetworkServer) -> None: + """Test that early 'JoinRequest' are adequately rejected.""" + ctx = mock.MagicMock() + req = messaging.JoinRequest("mock", {}).to_string() + rep = await server.handler.handle_message(req, context=ctx) + assert isinstance(rep, messaging.JoinReply) + assert not rep.accept + assert rep.flag == messaging.flags.REGISTRATION_UNSTARTED + + @pytest.mark.asyncio + async def test_server_await_client(self, server: NetworkServer) -> None: + """Test 'wait_for_clients' with a single client.""" + clients_info = asyncio.create_task( + server.wait_for_clients(min_clients=1) + ) + join_request = messaging.JoinRequest("mock", {}) + server_reply = asyncio.create_task( + server.handler.handle_message(join_request.to_string(), context=0) + ) + info = await clients_info + assert info == {"mock": {}} + reply = await server_reply + assert isinstance(reply, messaging.JoinReply) + assert reply.accept + assert reply.flag == messaging.flags.REGISTERED_WELCOME + + @pytest.mark.asyncio + async def test_server_await_timeout(self, server: NetworkServer) -> None: + """Test 'wait_for_clients' with an expected timeout error.""" + with pytest.raises(RuntimeError): + await server.wait_for_clients(timeout=1) + + @pytest.mark.asyncio + async def test_server_await_clients(self, server: NetworkServer) -> None: + """Test 'wait_for_clients' with a race between many clients. + + Test that the following cases yield expected behaviors: + - valid join request with a new name + - valid join request with a duplicated name + - duplicated valid join request (same context) + - late join request (third client with only two places) + """ + # Set up a server waiting routine and join requests' posting. + clients_info = server.wait_for_clients( + min_clients=1, max_clients=2, timeout=2 + ) + join_replies = [] + for idx in range(3): + req = messaging.JoinRequest("mock", {}).to_string() + ctx = min(idx, 1) # first and second contexts will be the same + join_replies.append(server.handler.handle_message(req, ctx)) + # Run the former routines concurrently. Verify server-side results. + results = await asyncio.gather(clients_info, *join_replies) + assert results[0] == {"mock": {}, "mock.1": {}} + # Verify request-wise replies. + for idx, rep in enumerate(results[1:]): + assert isinstance(rep, messaging.JoinReply) + if idx < 2: # first and third requests will be accepted + assert rep.accept, idx + assert rep.flag == messaging.flags.REGISTERED_WELCOME + elif idx == 2: # second request is reundant with the first + assert rep.accept + assert rep.flag == messaging.flags.REGISTERED_ALREADY + else: # fourth request is a third client when only two are exp. + assert not rep.accept + assert rep.flag == messaging.flags.REGISTRATION_CLOSED + + +@pytest.mark.parametrize("protocol", list_available_protocols()) +class TestNetworkServerSend: + """Unit tests for `NetworkServer` message-sending methods.""" + + @pytest.mark.asyncio + async def test_broadcast_message(self, server: NetworkServer) -> None: + """Test 'broadcast_message' to all clients. + + Mock the message-sending backend, that has dedicated tests. + """ + handler = server.handler = mock.create_autospec(server.handler) + setattr(server.handler, "client_names", {"a", "b", "c"}) + msg = messaging.GenericMessage(action="test", params={}) + await server.broadcast_message(msg) + assert handler.send_message.await_count == 3 + handler.send_message.assert_has_awaits( + [mock.call(msg, client, 1, None) for client in ("a", "b", "c")], + any_order=True, + ) + + @pytest.mark.asyncio + async def test_broadcast_message_subset( + self, server: NetworkServer + ) -> None: + """Test 'broadcast_message' to a selected subset of clients. + + Mock the message-sending backend, that has dedicated tests. + """ + handler = server.handler = mock.create_autospec(server.handler) + msg = messaging.GenericMessage(action="test", params={}) + await server.broadcast_message(msg, clients={"a", "b"}) + assert handler.send_message.await_count == 2 + handler.send_message.assert_has_awaits( + [mock.call(msg, client, 1, None) for client in ("a", "b")], + any_order=True, + ) + + @pytest.mark.asyncio + async def test_send_messages(self, server: NetworkServer) -> None: + """Test 'send_messages', mocking the message-sending backend. + + Mock the message-sending backend, that has dedicated tests. + """ + handler = server.handler = mock.create_autospec(server.handler) + messages = { + str(i): messaging.GenericMessage(action="test", params={"idx": i}) + for i in range(3) + } # type: Dict[str, messaging.Message] + await server.send_messages(messages) + assert handler.send_message.await_count == 3 + handler.send_message.assert_has_awaits( + [mock.call(msg, clt, 1, None) for clt, msg in messages.items()], + any_order=True, + ) + + @pytest.mark.asyncio + async def test_send_messages_error(self, server: NetworkServer) -> None: + """Test 'send_messages' error-raising, mocking the backend.""" + handler = server.handler = mock.create_autospec(server.handler) + handler.send_message.side_effect = RuntimeError + messages = { + str(i): messaging.GenericMessage(action="test", params={"idx": i}) + for i in range(3) + } # type: Dict[str, messaging.Message] + with pytest.raises(RuntimeError): + await server.send_messages(messages) + + @pytest.mark.asyncio + async def test_send_message(self, server: NetworkServer) -> None: + """Test 'send_message' - enabling to mock it elsewhere.""" + server.handler.registered_clients = {0: "mock.0", 1: "mock.1"} + msg = messaging.GenericMessage(action="test", params={}) + # Create tasks to send a message and let the client collect it. + req = messaging.GetMessageRequest().to_string() + send = server.send_message(msg, client="mock.0") + recv = server.handler.handle_message(req, 0) + # Check that the send routine works, as does the collection one. + outpt, reply = await asyncio.gather(send, recv) + assert outpt is None + assert reply == msg + + @pytest.mark.asyncio + async def test_send_message_errors(self, server: NetworkServer) -> None: + """Test that 'send_message' raises expected exceptions.""" + server.handler.registered_clients = {0: "mock.0", 1: "mock.1"} + msg = messaging.GenericMessage(action="test", params={}) + # Test case when sending to an unknown client. + with pytest.raises(KeyError): + await server.send_message(msg, client="unknown") + # Test case when sending results with a timeout. + with pytest.raises(asyncio.TimeoutError): + await server.send_message(msg, client="mock.0", timeout=1) + + @pytest.mark.asyncio + async def test_reject_msg_request(self, server: NetworkServer) -> None: + """Test that 'send_message' properly handles clients' identity.""" + server.handler.registered_clients = {0: "mock.0", 1: "mock.1"} + msg = messaging.GenericMessage(action="test", params={}) + # Create tasks to send a message and have another client request one. + req = messaging.GetMessageRequest(timeout=1).to_string() + send = server.send_message(msg, client="mock.0", timeout=1) + recv = server.handler.handle_message(req, 1) + # Check that both routines time out. + excpt, reply = await asyncio.gather(send, recv, return_exceptions=True) + assert isinstance(excpt, asyncio.TimeoutError) + assert isinstance(reply, messaging.Error) + assert reply.message == messaging.flags.CHECK_MESSAGE_TIMEOUT + + +@pytest.mark.parametrize("protocol", list_available_protocols()) +class TestNetworkServerRecv: + """Unit tests for `NetworkServer` message-receiving methods.""" + + @pytest.mark.asyncio + async def test_wait_for_messages(self, server: NetworkServer) -> None: + """Test that 'wait_for_messages' works correctly.""" + server.handler.registered_clients = {0: "mock.0", 1: "mock.1"} + msg = messaging.GenericMessage(action="test", params={}) + # Create tasks to wait for messages, and receive them. + wait = server.wait_for_messages() + recv_0 = server.handler.handle_message(msg.to_string(), context=0) + recv_1 = server.handler.handle_message(msg.to_string(), context=1) + # Await all tasks and assert that results match expectations. + outp, reply_0, reply_1 = await asyncio.gather(wait, recv_0, recv_1) + assert outp == {"mock.0": msg, "mock.1": msg} + assert isinstance(reply_0, messaging.Empty) + assert isinstance(reply_1, messaging.Empty) + + @pytest.mark.asyncio + async def test_wait_for_messages_subset( + self, server: NetworkServer + ) -> None: + """Test 'wait_for_messages' for a subset of all clients.""" + server.handler.registered_clients = {0: "mock.0", 1: "mock.1"} + msg = messaging.GenericMessage(action="test", params={}) + # Create tasks to wait for messages of the 1st client and receive it. + wait = server.wait_for_messages(clients={"mock.1"}) + recv = server.handler.handle_message(msg.to_string(), context=1) + # Await all tasks and assert that results match expectations. + outp, reply = await asyncio.gather(wait, recv) + assert outp == {"mock.1": msg} + assert isinstance(reply, messaging.Empty) + + @pytest.mark.asyncio + async def test_wait_for_messages_errors( + self, server: NetworkServer + ) -> None: + """Test that 'wait_for_messages' raises expected errors.""" + server.handler.registered_clients = {0: "mock.0", 1: "mock.1"} + # Specify to wait for messages from an unknown client. + with pytest.raises(KeyError): + await server.wait_for_messages(clients={"unknown"}) + # Wait for a message that never comes, with a timeout. + with pytest.raises(asyncio.TimeoutError): + await server.wait_for_messages(timeout=1) + + @pytest.mark.asyncio + async def test_reject_send_request(self, server: NetworkServer) -> None: + """Test that 'wait_for_messages' properly handles clients' identity.""" + server.handler.registered_clients = {0: "mock.0", 1: "mock.1"} + msg_0 = messaging.GenericMessage(action="test-0", params={}) + msg_1 = messaging.GenericMessage(action="test-1", params={}) + # Create tasks to wait for messages of the 1st client and receive it. + wait_0 = server.wait_for_messages(clients={"mock.0"}) + wait_1 = server.wait_for_messages(clients={"mock.1"}) + recv_1 = server.handler.handle_message(msg_1.to_string(), context=1) + recv_0 = server.handler.handle_message(msg_0.to_string(), context=0) + # Await all tasks and assert that results match expectations. + outp_0, outp_1, reply_1, reply_0 = await asyncio.gather( + wait_0, wait_1, recv_1, recv_0 + ) + assert outp_0 == {"mock.0": msg_0} + assert outp_1 == {"mock.1": msg_1} + assert isinstance(reply_0, messaging.Empty) + assert isinstance(reply_1, messaging.Empty) diff --git a/test/conftest.py b/test/conftest.py index 5d51c2c90bb87f3062215fc1a537281c810a3e19..dfc786f240829604ef1012a072a5d622b276e692 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -20,17 +20,29 @@ import pytest -def pytest_addoption(parser) -> None: # type: ignore - """Add a '--fulltest' option to the pytest commandline.""" +def pytest_addoption(parser) -> None: + """Add some custom options to the pytest commandline.""" parser.addoption( "--fulltest", action="store_true", default=False, help="--fulltest: run all test scenarios in 'test_main.py'", ) + parser.addoption( + "--cpu-only", + action="store_true", + default=False, + help="--cpu-only: disable the use of GPU devices in tests", + ) @pytest.fixture(name="fulltest") -def fulltest_fixture(request) -> bool: # type: ignore +def fulltest_fixture(request) -> bool: """Gather the '--fulltest' option's value.""" return bool(request.config.getoption("--fulltest")) + + +@pytest.fixture(name="cpu_only") +def cpu_only_fixture(request) -> bool: + """Gather the '--cpu-only' option's value.""" + return bool(request.config.getoption("--cpu-only")) diff --git a/test/functional/test_main.py b/test/functional/test_main.py index 5e5a59e3535596a5566c3fee07ef59eac47b6c8c..2b0986a93c2ebc3c66f85730749eccd592928e8a 100644 --- a/test/functional/test_main.py +++ b/test/functional/test_main.py @@ -35,6 +35,7 @@ from declearn.model.api import Model from declearn.model.sklearn import SklearnSGDModel from declearn.main import FederatedClient, FederatedServer from declearn.test_utils import run_as_processes +from declearn.utils import set_device_policy # Select the subset of tests to run, based on framework availability. # Note: TensorFlow and Torch (-related) imports are delayed due to this. @@ -214,6 +215,7 @@ class DeclearnTestCase: self, ) -> None: """Set up and run a FederatedServer.""" + set_device_policy(gpu=False) # disable GPU use to avoid concurrence model = self.build_model() netwk = self.build_netwk_server() optim = self.build_optim_config() @@ -231,6 +233,7 @@ class DeclearnTestCase: name: str = "client", ) -> None: """Set up and run a FederatedClient.""" + set_device_policy(gpu=False) # disable GPU use to avoid concurrence netwk = self.build_netwk_client(name) train = self.build_dataset(size=1000) valid = self.build_dataset(size=250) @@ -287,7 +290,7 @@ def test_declearn( Note: Use unsecured websockets communication, which are less costful to establish than gRPC and/or SSL-secured ones (the latter due to the certificates-generation costs). - Note: if websockets is unavailable, use gRPC (warn) or fail. + Note: If websockets is unavailable, use gRPC (warn) or fail. """ if not fulltest: if (kind != "Reg") or (strategy == "FedAvgM"): diff --git a/test/functional/test_regression.py b/test/functional/test_regression.py index 506184f51fbae29248420532ae1cbf6a4e335c33..66a01b67b9d6359e46b82ab8d55649fe7fce46fd 100644 --- a/test/functional/test_regression.py +++ b/test/functional/test_regression.py @@ -47,8 +47,6 @@ import tempfile from typing import List, Tuple import numpy as np -import tensorflow as tf # type: ignore -import torch from sklearn.datasets import make_regression # type: ignore from sklearn.linear_model import SGDRegressor # type: ignore @@ -59,10 +57,22 @@ from declearn.main.config import FLOptimConfig, FLRunConfig from declearn.metrics import RSquared from declearn.model.api import Model from declearn.model.sklearn import SklearnSGDModel -from declearn.model.tensorflow import TensorflowModel -from declearn.model.torch import TorchModel from declearn.optimizer import Optimizer from declearn.test_utils import FrameworkType, run_as_processes +from declearn.utils import set_device_policy + +# pylint: disable=ungrouped-imports; optional frameworks' dependencies +try: + import tensorflow as tf # type: ignore + from declearn.model.tensorflow import TensorflowModel +except ModuleNotFoundError: + pass +try: + import torch + from declearn.model.torch import TorchModel +except ModuleNotFoundError: + pass + SEED = 0 R2_THRESHOLD = 0.999 @@ -72,6 +82,7 @@ R2_THRESHOLD = 0.999 def get_model(framework: FrameworkType) -> Model: """Set up a simple toy regression model.""" + set_device_policy(gpu=False) # disable GPU use to avoid concurrence if framework == "numpy": np.random.seed(SEED) # set seed model = SklearnSGDModel.from_parameters( diff --git a/test/model/test_tflow.py b/test/model/test_tflow.py index aefd62e45489b2e264babbc69830fec68458147b..2e89fb4c71f1f129a2ed2ab3d7d7eb014640bc14 100644 --- a/test/model/test_tflow.py +++ b/test/model/test_tflow.py @@ -159,8 +159,11 @@ class TensorflowTestCase(ModelTestCase): def fixture_test_case( kind: Literal["MLP", "MLP-tune", "RNN", "CNN"], device: Literal["CPU", "GPU"], + cpu_only: bool, ) -> TensorflowTestCase: """Fixture to access a TensorflowTestCase.""" + if cpu_only and (device == "GPU"): + pytest.skip(reason="--cpu-only mode") return TensorflowTestCase(kind, device) diff --git a/test/model/test_torch.py b/test/model/test_torch.py index 15031107378231ef5b1a4ec62714da1e7988a53e..41e60a8bb8d1e389fceae03678c7ac1cbc06ef73 100644 --- a/test/model/test_torch.py +++ b/test/model/test_torch.py @@ -176,8 +176,11 @@ class TorchTestCase(ModelTestCase): def fixture_test_case( kind: Literal["MLP", "MLP-tune", "RNN", "CNN"], device: Literal["CPU", "GPU"], + cpu_only: bool, ) -> TorchTestCase: """Fixture to access a TorchTestCase.""" + if cpu_only and device == "GPU": + pytest.skip(reason="--cpu-only mode") return TorchTestCase(kind, device) diff --git a/tox.ini b/tox.ini index 74598c0a97c018a722066867a5b6821984058e18..2d19fd2fd6a81761b7400d944be49b2977cb0242 100644 --- a/tox.ini +++ b/tox.ini @@ -13,11 +13,7 @@ commands= # run unit tests first pytest {posargs} \ --ignore=test/functional/ \ - --ignore=test/communication/test_routines.py \ test - # run separately to avoid (unexplained) side-effects - pytest {posargs} \ - test/communication/test_routines.py # run functional tests (that build on units) pytest {posargs} \ test/functional/