From bbe43fba4e8c61a59057e8d518143b7f56453b9c Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Thu, 15 Feb 2024 17:52:46 +0100
Subject: [PATCH] Update unit tests for 'communication' module.

---
 test/communication/test_exchanges.py |  42 ++++++---
 test/communication/test_grpc.py      |  26 ++++--
 test/communication/test_server.py    | 123 ++++++++++++++++-----------
 3 files changed, 119 insertions(+), 72 deletions(-)

diff --git a/test/communication/test_exchanges.py b/test/communication/test_exchanges.py
index c91270b0..a2f98048 100644
--- a/test/communication/test_exchanges.py
+++ b/test/communication/test_exchanges.py
@@ -43,11 +43,11 @@ from typing import AsyncIterator, Dict, List, Optional, Tuple
 import pytest
 import pytest_asyncio
 
+from declearn import messaging
 from declearn.communication import (
     build_client,
     build_server,
     list_available_protocols,
-    messaging,
 )
 from declearn.communication.api import NetworkClient, NetworkServer
 
@@ -133,7 +133,7 @@ class TestNetworkRegister:
         self, server: NetworkServer, client: NetworkClient
     ) -> None:
         """Test that early registration requests are rejected."""
-        accepted = await client.register(data_info={"test": 42})
+        accepted = await client.register()
         assert not accepted
         assert not server.client_names
 
@@ -142,11 +142,11 @@ class TestNetworkRegister:
         self, server: NetworkServer, client: NetworkClient
     ) -> None:
         """Test that client registration works properly."""
-        data_info, accepted = await asyncio.gather(
+        output, accepted = await asyncio.gather(
             server.wait_for_clients(1),
-            client.register(data_info={"test": 42}),
+            client.register(),
         )
-        assert data_info == {"client": {"test": 42}}
+        assert output is None
         assert accepted
         assert server.client_names == {"client"}
 
@@ -159,7 +159,7 @@ class TestNetworkRegister:
         with pytest.raises(RuntimeError):
             await server.wait_for_clients(timeout=1)
         # Try registering after that timeout.
-        accepted = await client.register(data_info={"test": 42})
+        accepted = await client.register()
         assert not accepted
 
 
@@ -184,7 +184,7 @@ async def agents_fixture(
     await asyncio.gather(*[client.start() for client in clients])
     await asyncio.gather(
         server.wait_for_clients(n_clients, timeout=2),
-        *[client.register({}) for client in clients],
+        *[client.register() for client in clients],
     )
     # Yield the server and clients. On exit, stop the clients.
     yield server, clients
@@ -223,7 +223,12 @@ class TestNetworkExchanges:
         for idx, client in enumerate(clients):
             msg = messaging.GenericMessage(action="test", params={"idx": idx})
             coros.append(client.send_message(msg))
-        messages, *_ = await asyncio.gather(server.wait_for_messages(), *coros)
+        protos, *_ = await asyncio.gather(server.wait_for_messages(), *coros)
+        assert all(
+            isinstance(proto, messaging.SerializedMessage)
+            for proto in protos.values()
+        )
+        messages = {key: proto.deserialize() for key, proto in protos.items()}
         assert messages == {
             c.name: messaging.GenericMessage(action="test", params={"idx": i})
             for i, c in enumerate(clients)
@@ -237,9 +242,12 @@ class TestNetworkExchanges:
         server, clients = agents
         msg = messaging.GenericMessage(action="test", params={"value": 42})
         send = server.broadcast_message(msg)
-        recv = [client.check_message(timeout=1) for client in clients]
+        recv = [client.recv_message(timeout=1) for client in clients]
         _, *replies = await asyncio.gather(send, *recv)
-        assert all(reply == msg for reply in replies)
+        assert all(
+            isinstance(reply, messaging.SerializedMessage) for reply in replies
+        )
+        assert all(reply.deserialize() == msg for reply in replies)
 
     async def server_to_clients_individual(
         self,
@@ -252,10 +260,13 @@ class TestNetworkExchanges:
             for idx, name in enumerate(server.client_names)
         }  # type: Dict[str, messaging.Message]
         send = server.send_messages(messages)
-        recv = [client.check_message(timeout=1) for client in clients]
+        recv = [client.recv_message(timeout=1) for client in clients]
         _, *replies = await asyncio.gather(send, *recv)
         assert all(
-            reply == messages[client.name]
+            isinstance(reply, messaging.SerializedMessage) for reply in replies
+        )
+        assert all(
+            reply.deserialize() == messages[client.name]
             for client, reply in zip(clients, replies)
         )
 
@@ -272,7 +283,12 @@ class TestNetworkExchanges:
                 action="test", params={"idx": idx, "content": large}
             )
             coros.append(client.send_message(msg))
-        messages, *_ = await asyncio.gather(server.wait_for_messages(), *coros)
+        protos, *_ = await asyncio.gather(server.wait_for_messages(), *coros)
+        assert all(
+            isinstance(proto, messaging.SerializedMessage)
+            for proto in protos.values()
+        )
+        messages = {key: proto.deserialize() for key, proto in protos.items()}
         assert messages == {
             c.name: messaging.GenericMessage(
                 action="test", params={"idx": i, "content": large}
diff --git a/test/communication/test_grpc.py b/test/communication/test_grpc.py
index 3d5f8856..9dcbb6d2 100644
--- a/test/communication/test_grpc.py
+++ b/test/communication/test_grpc.py
@@ -28,13 +28,14 @@ test scripts.
 """
 
 import asyncio
+import uuid
 from typing import AsyncIterator, Dict, Iterator
 
 import grpc  # type: ignore
 import pytest
 import pytest_asyncio
 
-from declearn.communication.messaging import Empty
+from declearn.communication.api.backend.actions import Ping
 from declearn.communication.grpc._server import load_pem_file
 from declearn.communication.grpc import GrpcClient, GrpcServer
 from declearn.communication.grpc.protobufs import message_pb2
@@ -43,6 +44,7 @@ from declearn.communication.grpc.protobufs.message_pb2_grpc import (
     MessageBoardStub,
     add_MessageBoardServicer_to_server,
 )
+from declearn.messaging import Message
 
 #################################################################
 # 0. Set up pytest fixtures to avoid redundant code in tests
@@ -52,6 +54,12 @@ PORT = 50051
 SERVER_URI = f"{HOST}:{PORT}"
 
 
+class StubMessage(Message):
+    """Minimal stub Message subclass."""
+
+    typekey = f"stub-{uuid.uuid4()}"
+
+
 class FakeMessageBoard(MessageBoardServicer):
     """Minimal MessageBoard implementation to test the connection."""
 
@@ -67,7 +75,7 @@ class FakeMessageBoard(MessageBoardServicer):
         request: message_pb2.Message,
         context: grpc.ServicerContext,
     ) -> Iterator[message_pb2.Message]:
-        yield message_pb2.Message(message=Empty().to_string())
+        yield message_pb2.Message(message=Ping().to_string())
 
 
 @pytest_asyncio.fixture(name="insecure_grpc_server")
@@ -260,7 +268,7 @@ async def test_client_with_insecure_grpc_server(
     """Unit test for minimal unsecured GrpcClient use."""
     # fixture; pylint: disable=unused-argument
     client = insecure_declearn_client
-    await client.send_message(Empty())
+    await client.send_message(StubMessage())
 
 
 @pytest.mark.asyncio
@@ -271,7 +279,7 @@ async def test_secure_client_with_secure_grpc_server(
     """Unit test for minimal secured GrpcClient use."""
     # fixture; pylint: disable=unused-argument
     client = secure_declearn_client
-    await client.send_message(Empty())
+    await client.send_message(StubMessage())
 
 
 @pytest.mark.asyncio
@@ -283,7 +291,7 @@ async def test_insecure_client_with_secure_grpc_server_fails(
     # fixture; pylint: disable=unused-argument
     client = insecure_declearn_client
     with pytest.raises(grpc.aio.AioRpcError):
-        await client.send_message(Empty())
+        await client.send_message(StubMessage())
 
 
 #################################################################
@@ -302,7 +310,7 @@ async def test_client_with_insecure_server(
     await asyncio.gather(
         server.wait_for_clients(1, timeout=5), client.register({})
     )
-    await client.send_message(Empty())
+    await client.send_message(StubMessage())
 
 
 @pytest.mark.asyncio
@@ -317,7 +325,7 @@ async def test_secure_client_with_secure_server(
     await asyncio.gather(
         server.wait_for_clients(1, timeout=5), client.register({})
     )
-    await client.send_message(Empty())
+    await client.send_message(StubMessage())
 
 
 @pytest.mark.asyncio
@@ -329,7 +337,7 @@ async def test_insecure_client_with_secure_server_fails(
     # fixture; pylint: disable=unused-argument
     client = insecure_declearn_client
     with pytest.raises(grpc.aio.AioRpcError):
-        await client.send_message(Empty())
+        await client.send_message(StubMessage())
 
 
 @pytest.mark.asyncio
@@ -341,4 +349,4 @@ async def test_secure_client_with_insecure_server_fails(
     # fixture; pylint: disable=unused-argument
     client = secure_declearn_client
     with pytest.raises(grpc.aio.AioRpcError):
-        await client.send_message(Empty())
+        await client.send_message(StubMessage())
diff --git a/test/communication/test_server.py b/test/communication/test_server.py
index 79b58ef5..69e77607 100644
--- a/test/communication/test_server.py
+++ b/test/communication/test_server.py
@@ -24,14 +24,15 @@ from unittest import mock
 import pytest
 import pytest_asyncio
 
-from declearn import __version__ as VERSION
+from declearn import messaging
 from declearn.communication import (
     build_server,
     list_available_protocols,
-    messaging,
 )
 from declearn.communication.api import NetworkServer
+from declearn.communication.api.backend import actions, flags
 from declearn.utils import access_types_mapping, get_logger
+from declearn.version import VERSION
 
 
 SERVER_CLASSES = access_types_mapping("NetworkServer")
@@ -133,28 +134,25 @@ class TestNetworkServerRegister:
     async def test_server_early_request(self, server: NetworkServer) -> None:
         """Test that early 'JoinRequest' are adequately rejected."""
         ctx = mock.MagicMock()
-        req = messaging.JoinRequest("mock", {}, VERSION).to_string()
+        req = actions.Join(name="mock", version=VERSION).to_string()
         rep = await server.handler.handle_message(req, context=ctx)
-        assert isinstance(rep, messaging.JoinReply)
-        assert not rep.accept
-        assert rep.flag == messaging.flags.REGISTRATION_UNSTARTED
+        assert isinstance(rep, actions.Reject)
+        assert rep.flag == flags.REGISTRATION_UNSTARTED
 
     @pytest.mark.asyncio
     async def test_server_await_client(self, server: NetworkServer) -> None:
         """Test 'wait_for_clients' with a single client."""
-        clients_info = asyncio.create_task(
+        wait_for_clients = asyncio.create_task(
             server.wait_for_clients(min_clients=1)
         )
-        join_request = messaging.JoinRequest("mock", {}, VERSION)
+        join_request = actions.Join(name="mock", version=VERSION)
         server_reply = asyncio.create_task(
             server.handler.handle_message(join_request.to_string(), context=0)
         )
-        info = await clients_info
-        assert info == {"mock": {}}
+        await wait_for_clients
         reply = await server_reply
-        assert isinstance(reply, messaging.JoinReply)
-        assert reply.accept
-        assert reply.flag == messaging.flags.REGISTERED_WELCOME
+        assert isinstance(reply, actions.Accept)
+        assert reply.flag == flags.REGISTERED_WELCOME
 
     @pytest.mark.asyncio
     async def test_server_await_timeout(self, server: NetworkServer) -> None:
@@ -173,29 +171,28 @@ class TestNetworkServerRegister:
         - late join request (third client with only two places)
         """
         # Set up a server waiting routine and join requests' posting.
-        clients_info = server.wait_for_clients(
+        wait_clients = server.wait_for_clients(
             min_clients=1, max_clients=2, timeout=2
         )
         join_replies = []
         for idx in range(3):
-            req = messaging.JoinRequest("mock", {}, VERSION).to_string()
+            req = actions.Join(name="mock", version=VERSION).to_string()
             ctx = min(idx, 1)  # first and second contexts will be the same
             join_replies.append(server.handler.handle_message(req, ctx))
         # Run the former routines concurrently. Verify server-side results.
-        results = await asyncio.gather(clients_info, *join_replies)
-        assert results[0] == {"mock": {}, "mock.1": {}}
+        results = await asyncio.gather(wait_clients, *join_replies)
+        assert results[0] is None
         # Verify request-wise replies.
         for idx, rep in enumerate(results[1:]):
-            assert isinstance(rep, messaging.JoinReply)
             if idx < 2:  # first and third requests will be accepted
-                assert rep.accept, idx
-                assert rep.flag == messaging.flags.REGISTERED_WELCOME
+                assert isinstance(rep, actions.Accept)
+                assert rep.flag == flags.REGISTERED_WELCOME
             elif idx == 2:  # second request is reundant with the first
-                assert rep.accept
-                assert rep.flag == messaging.flags.REGISTERED_ALREADY
+                assert isinstance(rep, actions.Accept)
+                assert rep.flag == flags.REGISTERED_ALREADY
             else:  # fourth request is a third client when only two are exp.
-                assert not rep.accept
-                assert rep.flag == messaging.flags.REGISTRATION_CLOSED
+                assert isinstance(rep, actions.Reject)
+                assert rep.flag == flags.REGISTRATION_CLOSED
 
 
 @pytest.mark.parametrize("protocol", list_available_protocols())
@@ -213,8 +210,9 @@ class TestNetworkServerSend:
         msg = messaging.GenericMessage(action="test", params={})
         await server.broadcast_message(msg)
         assert handler.send_message.await_count == 3
+        dump = msg.to_string()
         handler.send_message.assert_has_awaits(
-            [mock.call(msg, client, 1, None) for client in ("a", "b", "c")],
+            [mock.call(dump, client, None) for client in ("a", "b", "c")],
             any_order=True,
         )
 
@@ -230,8 +228,9 @@ class TestNetworkServerSend:
         msg = messaging.GenericMessage(action="test", params={})
         await server.broadcast_message(msg, clients={"a", "b"})
         assert handler.send_message.await_count == 2
+        dump = msg.to_string()
         handler.send_message.assert_has_awaits(
-            [mock.call(msg, client, 1, None) for client in ("a", "b")],
+            [mock.call(dump, client, None) for client in ("a", "b")],
             any_order=True,
         )
 
@@ -249,7 +248,10 @@ class TestNetworkServerSend:
         await server.send_messages(messages)
         assert handler.send_message.await_count == 3
         handler.send_message.assert_has_awaits(
-            [mock.call(msg, clt, 1, None) for clt, msg in messages.items()],
+            [
+                mock.call(msg.to_string(), clt, None)
+                for clt, msg in messages.items()
+            ],
             any_order=True,
         )
 
@@ -271,13 +273,14 @@ class TestNetworkServerSend:
         server.handler.registered_clients = {0: "mock.0", 1: "mock.1"}
         msg = messaging.GenericMessage(action="test", params={})
         # Create tasks to send a message and let the client collect it.
-        req = messaging.GetMessageRequest().to_string()
+        req = actions.Recv().to_string()
         send = server.send_message(msg, client="mock.0")
         recv = server.handler.handle_message(req, 0)
         # Check that the send routine works, as does the collection one.
         outpt, reply = await asyncio.gather(send, recv)
         assert outpt is None
-        assert reply == msg
+        assert isinstance(reply, actions.Send)
+        assert reply.content == msg.to_string()
 
     @pytest.mark.asyncio
     async def test_send_message_errors(self, server: NetworkServer) -> None:
@@ -296,7 +299,7 @@ class TestNetworkServerSend:
         """Test that 'send_message' properly handles clients' identity."""
         server.handler.registered_clients = {0: "mock.0", 1: "mock.1"}
         msg = messaging.GenericMessage(action="test", params={})
-        req = messaging.GetMessageRequest(timeout=1).to_string()
+        req = actions.Recv(timeout=1).to_string()
         # Create tasks to send a message and have another client request one.
         send = server.send_message(msg, client="mock.0", timeout=1)
         recv = server.handler.handle_message(req, 1)
@@ -305,8 +308,8 @@ class TestNetworkServerSend:
             send, recv, return_exceptions=True
         )  # type: Tuple[asyncio.TimeoutError, messaging.Error]
         assert isinstance(excpt, asyncio.TimeoutError)
-        assert isinstance(reply, messaging.Error)
-        assert reply.message == messaging.flags.CHECK_MESSAGE_TIMEOUT
+        assert isinstance(reply, actions.Reject)
+        assert reply.flag == flags.CHECK_MESSAGE_TIMEOUT
 
 
 @pytest.mark.parametrize("protocol", list_available_protocols())
@@ -318,15 +321,22 @@ class TestNetworkServerRecv:
         """Test that 'wait_for_messages' works correctly."""
         server.handler.registered_clients = {0: "mock.0", 1: "mock.1"}
         msg = messaging.GenericMessage(action="test", params={})
+        act = actions.Send(msg.to_string())
         # Create tasks to wait for messages, and receive them.
-        wait = server.wait_for_messages()
-        recv_0 = server.handler.handle_message(msg.to_string(), context=0)
-        recv_1 = server.handler.handle_message(msg.to_string(), context=1)
+        wait = server.wait_for_messages_with_timeout(2)
+        recv_0 = server.handler.handle_message(act.to_string(), context=0)
+        recv_1 = server.handler.handle_message(act.to_string(), context=1)
         # Await all tasks and assert that results match expectations.
         outp, reply_0, reply_1 = await asyncio.gather(wait, recv_0, recv_1)
-        assert outp == {"mock.0": msg, "mock.1": msg}
-        assert isinstance(reply_0, messaging.Empty)
-        assert isinstance(reply_1, messaging.Empty)
+        assert isinstance(outp[1], list) and not outp[1]
+        assert all(
+            isinstance(out, messaging.SerializedMessage)
+            for out in outp[0].values()
+        )
+        recv = {key: out.deserialize() for key, out in outp[0].items()}
+        assert recv == {"mock.0": msg, "mock.1": msg}
+        assert isinstance(reply_0, actions.Ping)
+        assert isinstance(reply_1, actions.Ping)
 
     @pytest.mark.asyncio
     async def test_wait_for_messages_subset(
@@ -336,12 +346,18 @@ class TestNetworkServerRecv:
         server.handler.registered_clients = {0: "mock.0", 1: "mock.1"}
         msg = messaging.GenericMessage(action="test", params={})
         # Create tasks to wait for messages of the 1st client and receive it.
-        wait = server.wait_for_messages(clients={"mock.1"})
-        recv = server.handler.handle_message(msg.to_string(), context=1)
+        wait = server.wait_for_messages_with_timeout(
+            timeout=2, clients={"mock.1"}
+        )
+        recv = server.handler.handle_message(
+            actions.Send(msg.to_string()).to_string(), context=1
+        )
         # Await all tasks and assert that results match expectations.
         outp, reply = await asyncio.gather(wait, recv)
-        assert outp == {"mock.1": msg}
-        assert isinstance(reply, messaging.Empty)
+        assert isinstance(outp[1], list) and not outp[1]
+        recv = {key: out.deserialize() for key, out in outp[0].items()}
+        assert recv == {"mock.1": msg}
+        assert isinstance(reply, actions.Ping)
 
     @pytest.mark.asyncio
     async def test_wait_for_messages_errors(
@@ -353,8 +369,9 @@ class TestNetworkServerRecv:
         with pytest.raises(KeyError):
             await server.wait_for_messages(clients={"unknown"})
         # Wait for a message that never comes, with a timeout.
-        with pytest.raises(asyncio.TimeoutError):
-            await server.wait_for_messages(timeout=1)
+        recv, miss = await server.wait_for_messages_with_timeout(timeout=1)
+        assert isinstance(recv, dict) and not recv
+        assert sorted(miss) == ["mock.0", "mock.1"]
 
     @pytest.mark.asyncio
     async def test_reject_send_request(self, server: NetworkServer) -> None:
@@ -365,13 +382,19 @@ class TestNetworkServerRecv:
         # Create tasks to wait for messages of the 1st client and receive it.
         wait_0 = server.wait_for_messages(clients={"mock.0"})
         wait_1 = server.wait_for_messages(clients={"mock.1"})
-        recv_1 = server.handler.handle_message(msg_1.to_string(), context=1)
-        recv_0 = server.handler.handle_message(msg_0.to_string(), context=0)
+        recv_1 = server.handler.handle_message(
+            actions.Send(msg_1.to_string()).to_string(), context=1
+        )
+        recv_0 = server.handler.handle_message(
+            actions.Send(msg_0.to_string()).to_string(), context=0
+        )
         # Await all tasks and assert that results match expectations.
         outp_0, outp_1, reply_1, reply_0 = await asyncio.gather(
             wait_0, wait_1, recv_1, recv_0
         )
-        assert outp_0 == {"mock.0": msg_0}
-        assert outp_1 == {"mock.1": msg_1}
-        assert isinstance(reply_0, messaging.Empty)
-        assert isinstance(reply_1, messaging.Empty)
+        recv_0 = {key: out.deserialize() for key, out in outp_0.items()}
+        recv_1 = {key: out.deserialize() for key, out in outp_1.items()}
+        assert recv_0 == {"mock.0": msg_0}
+        assert recv_1 == {"mock.1": msg_1}
+        assert isinstance(reply_0, actions.Ping)
+        assert isinstance(reply_1, actions.Ping)
-- 
GitLab