From bbe43fba4e8c61a59057e8d518143b7f56453b9c Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Thu, 15 Feb 2024 17:52:46 +0100 Subject: [PATCH] Update unit tests for 'communication' module. --- test/communication/test_exchanges.py | 42 ++++++--- test/communication/test_grpc.py | 26 ++++-- test/communication/test_server.py | 123 ++++++++++++++++----------- 3 files changed, 119 insertions(+), 72 deletions(-) diff --git a/test/communication/test_exchanges.py b/test/communication/test_exchanges.py index c91270b0..a2f98048 100644 --- a/test/communication/test_exchanges.py +++ b/test/communication/test_exchanges.py @@ -43,11 +43,11 @@ from typing import AsyncIterator, Dict, List, Optional, Tuple import pytest import pytest_asyncio +from declearn import messaging from declearn.communication import ( build_client, build_server, list_available_protocols, - messaging, ) from declearn.communication.api import NetworkClient, NetworkServer @@ -133,7 +133,7 @@ class TestNetworkRegister: self, server: NetworkServer, client: NetworkClient ) -> None: """Test that early registration requests are rejected.""" - accepted = await client.register(data_info={"test": 42}) + accepted = await client.register() assert not accepted assert not server.client_names @@ -142,11 +142,11 @@ class TestNetworkRegister: self, server: NetworkServer, client: NetworkClient ) -> None: """Test that client registration works properly.""" - data_info, accepted = await asyncio.gather( + output, accepted = await asyncio.gather( server.wait_for_clients(1), - client.register(data_info={"test": 42}), + client.register(), ) - assert data_info == {"client": {"test": 42}} + assert output is None assert accepted assert server.client_names == {"client"} @@ -159,7 +159,7 @@ class TestNetworkRegister: with pytest.raises(RuntimeError): await server.wait_for_clients(timeout=1) # Try registering after that timeout. - accepted = await client.register(data_info={"test": 42}) + accepted = await client.register() assert not accepted @@ -184,7 +184,7 @@ async def agents_fixture( await asyncio.gather(*[client.start() for client in clients]) await asyncio.gather( server.wait_for_clients(n_clients, timeout=2), - *[client.register({}) for client in clients], + *[client.register() for client in clients], ) # Yield the server and clients. On exit, stop the clients. yield server, clients @@ -223,7 +223,12 @@ class TestNetworkExchanges: for idx, client in enumerate(clients): msg = messaging.GenericMessage(action="test", params={"idx": idx}) coros.append(client.send_message(msg)) - messages, *_ = await asyncio.gather(server.wait_for_messages(), *coros) + protos, *_ = await asyncio.gather(server.wait_for_messages(), *coros) + assert all( + isinstance(proto, messaging.SerializedMessage) + for proto in protos.values() + ) + messages = {key: proto.deserialize() for key, proto in protos.items()} assert messages == { c.name: messaging.GenericMessage(action="test", params={"idx": i}) for i, c in enumerate(clients) @@ -237,9 +242,12 @@ class TestNetworkExchanges: server, clients = agents msg = messaging.GenericMessage(action="test", params={"value": 42}) send = server.broadcast_message(msg) - recv = [client.check_message(timeout=1) for client in clients] + recv = [client.recv_message(timeout=1) for client in clients] _, *replies = await asyncio.gather(send, *recv) - assert all(reply == msg for reply in replies) + assert all( + isinstance(reply, messaging.SerializedMessage) for reply in replies + ) + assert all(reply.deserialize() == msg for reply in replies) async def server_to_clients_individual( self, @@ -252,10 +260,13 @@ class TestNetworkExchanges: for idx, name in enumerate(server.client_names) } # type: Dict[str, messaging.Message] send = server.send_messages(messages) - recv = [client.check_message(timeout=1) for client in clients] + recv = [client.recv_message(timeout=1) for client in clients] _, *replies = await asyncio.gather(send, *recv) assert all( - reply == messages[client.name] + isinstance(reply, messaging.SerializedMessage) for reply in replies + ) + assert all( + reply.deserialize() == messages[client.name] for client, reply in zip(clients, replies) ) @@ -272,7 +283,12 @@ class TestNetworkExchanges: action="test", params={"idx": idx, "content": large} ) coros.append(client.send_message(msg)) - messages, *_ = await asyncio.gather(server.wait_for_messages(), *coros) + protos, *_ = await asyncio.gather(server.wait_for_messages(), *coros) + assert all( + isinstance(proto, messaging.SerializedMessage) + for proto in protos.values() + ) + messages = {key: proto.deserialize() for key, proto in protos.items()} assert messages == { c.name: messaging.GenericMessage( action="test", params={"idx": i, "content": large} diff --git a/test/communication/test_grpc.py b/test/communication/test_grpc.py index 3d5f8856..9dcbb6d2 100644 --- a/test/communication/test_grpc.py +++ b/test/communication/test_grpc.py @@ -28,13 +28,14 @@ test scripts. """ import asyncio +import uuid from typing import AsyncIterator, Dict, Iterator import grpc # type: ignore import pytest import pytest_asyncio -from declearn.communication.messaging import Empty +from declearn.communication.api.backend.actions import Ping from declearn.communication.grpc._server import load_pem_file from declearn.communication.grpc import GrpcClient, GrpcServer from declearn.communication.grpc.protobufs import message_pb2 @@ -43,6 +44,7 @@ from declearn.communication.grpc.protobufs.message_pb2_grpc import ( MessageBoardStub, add_MessageBoardServicer_to_server, ) +from declearn.messaging import Message ################################################################# # 0. Set up pytest fixtures to avoid redundant code in tests @@ -52,6 +54,12 @@ PORT = 50051 SERVER_URI = f"{HOST}:{PORT}" +class StubMessage(Message): + """Minimal stub Message subclass.""" + + typekey = f"stub-{uuid.uuid4()}" + + class FakeMessageBoard(MessageBoardServicer): """Minimal MessageBoard implementation to test the connection.""" @@ -67,7 +75,7 @@ class FakeMessageBoard(MessageBoardServicer): request: message_pb2.Message, context: grpc.ServicerContext, ) -> Iterator[message_pb2.Message]: - yield message_pb2.Message(message=Empty().to_string()) + yield message_pb2.Message(message=Ping().to_string()) @pytest_asyncio.fixture(name="insecure_grpc_server") @@ -260,7 +268,7 @@ async def test_client_with_insecure_grpc_server( """Unit test for minimal unsecured GrpcClient use.""" # fixture; pylint: disable=unused-argument client = insecure_declearn_client - await client.send_message(Empty()) + await client.send_message(StubMessage()) @pytest.mark.asyncio @@ -271,7 +279,7 @@ async def test_secure_client_with_secure_grpc_server( """Unit test for minimal secured GrpcClient use.""" # fixture; pylint: disable=unused-argument client = secure_declearn_client - await client.send_message(Empty()) + await client.send_message(StubMessage()) @pytest.mark.asyncio @@ -283,7 +291,7 @@ async def test_insecure_client_with_secure_grpc_server_fails( # fixture; pylint: disable=unused-argument client = insecure_declearn_client with pytest.raises(grpc.aio.AioRpcError): - await client.send_message(Empty()) + await client.send_message(StubMessage()) ################################################################# @@ -302,7 +310,7 @@ async def test_client_with_insecure_server( await asyncio.gather( server.wait_for_clients(1, timeout=5), client.register({}) ) - await client.send_message(Empty()) + await client.send_message(StubMessage()) @pytest.mark.asyncio @@ -317,7 +325,7 @@ async def test_secure_client_with_secure_server( await asyncio.gather( server.wait_for_clients(1, timeout=5), client.register({}) ) - await client.send_message(Empty()) + await client.send_message(StubMessage()) @pytest.mark.asyncio @@ -329,7 +337,7 @@ async def test_insecure_client_with_secure_server_fails( # fixture; pylint: disable=unused-argument client = insecure_declearn_client with pytest.raises(grpc.aio.AioRpcError): - await client.send_message(Empty()) + await client.send_message(StubMessage()) @pytest.mark.asyncio @@ -341,4 +349,4 @@ async def test_secure_client_with_insecure_server_fails( # fixture; pylint: disable=unused-argument client = secure_declearn_client with pytest.raises(grpc.aio.AioRpcError): - await client.send_message(Empty()) + await client.send_message(StubMessage()) diff --git a/test/communication/test_server.py b/test/communication/test_server.py index 79b58ef5..69e77607 100644 --- a/test/communication/test_server.py +++ b/test/communication/test_server.py @@ -24,14 +24,15 @@ from unittest import mock import pytest import pytest_asyncio -from declearn import __version__ as VERSION +from declearn import messaging from declearn.communication import ( build_server, list_available_protocols, - messaging, ) from declearn.communication.api import NetworkServer +from declearn.communication.api.backend import actions, flags from declearn.utils import access_types_mapping, get_logger +from declearn.version import VERSION SERVER_CLASSES = access_types_mapping("NetworkServer") @@ -133,28 +134,25 @@ class TestNetworkServerRegister: async def test_server_early_request(self, server: NetworkServer) -> None: """Test that early 'JoinRequest' are adequately rejected.""" ctx = mock.MagicMock() - req = messaging.JoinRequest("mock", {}, VERSION).to_string() + req = actions.Join(name="mock", version=VERSION).to_string() rep = await server.handler.handle_message(req, context=ctx) - assert isinstance(rep, messaging.JoinReply) - assert not rep.accept - assert rep.flag == messaging.flags.REGISTRATION_UNSTARTED + assert isinstance(rep, actions.Reject) + assert rep.flag == flags.REGISTRATION_UNSTARTED @pytest.mark.asyncio async def test_server_await_client(self, server: NetworkServer) -> None: """Test 'wait_for_clients' with a single client.""" - clients_info = asyncio.create_task( + wait_for_clients = asyncio.create_task( server.wait_for_clients(min_clients=1) ) - join_request = messaging.JoinRequest("mock", {}, VERSION) + join_request = actions.Join(name="mock", version=VERSION) server_reply = asyncio.create_task( server.handler.handle_message(join_request.to_string(), context=0) ) - info = await clients_info - assert info == {"mock": {}} + await wait_for_clients reply = await server_reply - assert isinstance(reply, messaging.JoinReply) - assert reply.accept - assert reply.flag == messaging.flags.REGISTERED_WELCOME + assert isinstance(reply, actions.Accept) + assert reply.flag == flags.REGISTERED_WELCOME @pytest.mark.asyncio async def test_server_await_timeout(self, server: NetworkServer) -> None: @@ -173,29 +171,28 @@ class TestNetworkServerRegister: - late join request (third client with only two places) """ # Set up a server waiting routine and join requests' posting. - clients_info = server.wait_for_clients( + wait_clients = server.wait_for_clients( min_clients=1, max_clients=2, timeout=2 ) join_replies = [] for idx in range(3): - req = messaging.JoinRequest("mock", {}, VERSION).to_string() + req = actions.Join(name="mock", version=VERSION).to_string() ctx = min(idx, 1) # first and second contexts will be the same join_replies.append(server.handler.handle_message(req, ctx)) # Run the former routines concurrently. Verify server-side results. - results = await asyncio.gather(clients_info, *join_replies) - assert results[0] == {"mock": {}, "mock.1": {}} + results = await asyncio.gather(wait_clients, *join_replies) + assert results[0] is None # Verify request-wise replies. for idx, rep in enumerate(results[1:]): - assert isinstance(rep, messaging.JoinReply) if idx < 2: # first and third requests will be accepted - assert rep.accept, idx - assert rep.flag == messaging.flags.REGISTERED_WELCOME + assert isinstance(rep, actions.Accept) + assert rep.flag == flags.REGISTERED_WELCOME elif idx == 2: # second request is reundant with the first - assert rep.accept - assert rep.flag == messaging.flags.REGISTERED_ALREADY + assert isinstance(rep, actions.Accept) + assert rep.flag == flags.REGISTERED_ALREADY else: # fourth request is a third client when only two are exp. - assert not rep.accept - assert rep.flag == messaging.flags.REGISTRATION_CLOSED + assert isinstance(rep, actions.Reject) + assert rep.flag == flags.REGISTRATION_CLOSED @pytest.mark.parametrize("protocol", list_available_protocols()) @@ -213,8 +210,9 @@ class TestNetworkServerSend: msg = messaging.GenericMessage(action="test", params={}) await server.broadcast_message(msg) assert handler.send_message.await_count == 3 + dump = msg.to_string() handler.send_message.assert_has_awaits( - [mock.call(msg, client, 1, None) for client in ("a", "b", "c")], + [mock.call(dump, client, None) for client in ("a", "b", "c")], any_order=True, ) @@ -230,8 +228,9 @@ class TestNetworkServerSend: msg = messaging.GenericMessage(action="test", params={}) await server.broadcast_message(msg, clients={"a", "b"}) assert handler.send_message.await_count == 2 + dump = msg.to_string() handler.send_message.assert_has_awaits( - [mock.call(msg, client, 1, None) for client in ("a", "b")], + [mock.call(dump, client, None) for client in ("a", "b")], any_order=True, ) @@ -249,7 +248,10 @@ class TestNetworkServerSend: await server.send_messages(messages) assert handler.send_message.await_count == 3 handler.send_message.assert_has_awaits( - [mock.call(msg, clt, 1, None) for clt, msg in messages.items()], + [ + mock.call(msg.to_string(), clt, None) + for clt, msg in messages.items() + ], any_order=True, ) @@ -271,13 +273,14 @@ class TestNetworkServerSend: server.handler.registered_clients = {0: "mock.0", 1: "mock.1"} msg = messaging.GenericMessage(action="test", params={}) # Create tasks to send a message and let the client collect it. - req = messaging.GetMessageRequest().to_string() + req = actions.Recv().to_string() send = server.send_message(msg, client="mock.0") recv = server.handler.handle_message(req, 0) # Check that the send routine works, as does the collection one. outpt, reply = await asyncio.gather(send, recv) assert outpt is None - assert reply == msg + assert isinstance(reply, actions.Send) + assert reply.content == msg.to_string() @pytest.mark.asyncio async def test_send_message_errors(self, server: NetworkServer) -> None: @@ -296,7 +299,7 @@ class TestNetworkServerSend: """Test that 'send_message' properly handles clients' identity.""" server.handler.registered_clients = {0: "mock.0", 1: "mock.1"} msg = messaging.GenericMessage(action="test", params={}) - req = messaging.GetMessageRequest(timeout=1).to_string() + req = actions.Recv(timeout=1).to_string() # Create tasks to send a message and have another client request one. send = server.send_message(msg, client="mock.0", timeout=1) recv = server.handler.handle_message(req, 1) @@ -305,8 +308,8 @@ class TestNetworkServerSend: send, recv, return_exceptions=True ) # type: Tuple[asyncio.TimeoutError, messaging.Error] assert isinstance(excpt, asyncio.TimeoutError) - assert isinstance(reply, messaging.Error) - assert reply.message == messaging.flags.CHECK_MESSAGE_TIMEOUT + assert isinstance(reply, actions.Reject) + assert reply.flag == flags.CHECK_MESSAGE_TIMEOUT @pytest.mark.parametrize("protocol", list_available_protocols()) @@ -318,15 +321,22 @@ class TestNetworkServerRecv: """Test that 'wait_for_messages' works correctly.""" server.handler.registered_clients = {0: "mock.0", 1: "mock.1"} msg = messaging.GenericMessage(action="test", params={}) + act = actions.Send(msg.to_string()) # Create tasks to wait for messages, and receive them. - wait = server.wait_for_messages() - recv_0 = server.handler.handle_message(msg.to_string(), context=0) - recv_1 = server.handler.handle_message(msg.to_string(), context=1) + wait = server.wait_for_messages_with_timeout(2) + recv_0 = server.handler.handle_message(act.to_string(), context=0) + recv_1 = server.handler.handle_message(act.to_string(), context=1) # Await all tasks and assert that results match expectations. outp, reply_0, reply_1 = await asyncio.gather(wait, recv_0, recv_1) - assert outp == {"mock.0": msg, "mock.1": msg} - assert isinstance(reply_0, messaging.Empty) - assert isinstance(reply_1, messaging.Empty) + assert isinstance(outp[1], list) and not outp[1] + assert all( + isinstance(out, messaging.SerializedMessage) + for out in outp[0].values() + ) + recv = {key: out.deserialize() for key, out in outp[0].items()} + assert recv == {"mock.0": msg, "mock.1": msg} + assert isinstance(reply_0, actions.Ping) + assert isinstance(reply_1, actions.Ping) @pytest.mark.asyncio async def test_wait_for_messages_subset( @@ -336,12 +346,18 @@ class TestNetworkServerRecv: server.handler.registered_clients = {0: "mock.0", 1: "mock.1"} msg = messaging.GenericMessage(action="test", params={}) # Create tasks to wait for messages of the 1st client and receive it. - wait = server.wait_for_messages(clients={"mock.1"}) - recv = server.handler.handle_message(msg.to_string(), context=1) + wait = server.wait_for_messages_with_timeout( + timeout=2, clients={"mock.1"} + ) + recv = server.handler.handle_message( + actions.Send(msg.to_string()).to_string(), context=1 + ) # Await all tasks and assert that results match expectations. outp, reply = await asyncio.gather(wait, recv) - assert outp == {"mock.1": msg} - assert isinstance(reply, messaging.Empty) + assert isinstance(outp[1], list) and not outp[1] + recv = {key: out.deserialize() for key, out in outp[0].items()} + assert recv == {"mock.1": msg} + assert isinstance(reply, actions.Ping) @pytest.mark.asyncio async def test_wait_for_messages_errors( @@ -353,8 +369,9 @@ class TestNetworkServerRecv: with pytest.raises(KeyError): await server.wait_for_messages(clients={"unknown"}) # Wait for a message that never comes, with a timeout. - with pytest.raises(asyncio.TimeoutError): - await server.wait_for_messages(timeout=1) + recv, miss = await server.wait_for_messages_with_timeout(timeout=1) + assert isinstance(recv, dict) and not recv + assert sorted(miss) == ["mock.0", "mock.1"] @pytest.mark.asyncio async def test_reject_send_request(self, server: NetworkServer) -> None: @@ -365,13 +382,19 @@ class TestNetworkServerRecv: # Create tasks to wait for messages of the 1st client and receive it. wait_0 = server.wait_for_messages(clients={"mock.0"}) wait_1 = server.wait_for_messages(clients={"mock.1"}) - recv_1 = server.handler.handle_message(msg_1.to_string(), context=1) - recv_0 = server.handler.handle_message(msg_0.to_string(), context=0) + recv_1 = server.handler.handle_message( + actions.Send(msg_1.to_string()).to_string(), context=1 + ) + recv_0 = server.handler.handle_message( + actions.Send(msg_0.to_string()).to_string(), context=0 + ) # Await all tasks and assert that results match expectations. outp_0, outp_1, reply_1, reply_0 = await asyncio.gather( wait_0, wait_1, recv_1, recv_0 ) - assert outp_0 == {"mock.0": msg_0} - assert outp_1 == {"mock.1": msg_1} - assert isinstance(reply_0, messaging.Empty) - assert isinstance(reply_1, messaging.Empty) + recv_0 = {key: out.deserialize() for key, out in outp_0.items()} + recv_1 = {key: out.deserialize() for key, out in outp_1.items()} + assert recv_0 == {"mock.0": msg_0} + assert recv_1 == {"mock.1": msg_1} + assert isinstance(reply_0, actions.Ping) + assert isinstance(reply_1, actions.Ping) -- GitLab