From 07ed21db1f83cbf31d3b39ce4f291c8076696cc0 Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Fri, 16 Feb 2024 15:01:39 +0100 Subject: [PATCH] Add unit tests for 'declearn.communication.api.backend' code. --- test/communication/backend/test_actions.py | 122 +++++ .../backend/test_messages_handler.py | 424 ++++++++++++++++++ 2 files changed, 546 insertions(+) create mode 100644 test/communication/backend/test_actions.py create mode 100644 test/communication/backend/test_messages_handler.py diff --git a/test/communication/backend/test_actions.py b/test/communication/backend/test_actions.py new file mode 100644 index 00000000..ba53e745 --- /dev/null +++ b/test/communication/backend/test_actions.py @@ -0,0 +1,122 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for 'declearn.communication.api.backend.actions'.""" + +import dataclasses +import json + +import pytest + +from declearn.communication.api.backend.actions import ( + Accept, + ActionMessage, + Drop, + Join, + LegacyReject, + LegacyMessageError, + Ping, + Recv, + Reject, + Send, + parse_action_from_string, +) +from declearn.communication.api.backend import flags + + +def assert_action_is_serializable( + action: ActionMessage, +) -> None: + """Test that a given 'ActionMessage' is (un)serializable.""" + string = action.to_string() + assert isinstance(string, str) + result = parse_action_from_string(string) + assert isinstance(result, action.__class__) + assert dataclasses.asdict(result) == dataclasses.asdict(action) + + +class TestActionMessage: + """Unit tests for 'ActionMessage' subclasses.""" + + def test_accept(self) -> None: + """Test that 'Accept' is serializable.""" + action = Accept(flag=flags.REGISTERED_WELCOME) + assert_action_is_serializable(action) + + def test_drop(self) -> None: + """Test that 'Drop' is serializable.""" + action = Drop() + assert_action_is_serializable(action) + + def test_join(self) -> None: + """Test that 'Join' is serializable.""" + action = Join(name="client", version="version") + assert_action_is_serializable(action) + + def test_ping(self) -> None: + """Test that 'Ping' is serializable.""" + action = Ping() + assert_action_is_serializable(action) + + def test_recv(self) -> None: + """Test that 'Recv' is serializable.""" + action = Recv(timeout=1) + assert_action_is_serializable(action) + + def test_reject(self) -> None: + """Test that 'Reject' is serializable.""" + action = Reject(flag=flags.REJECT_UNREGISTERED) + assert_action_is_serializable(action) + + def test_send(self) -> None: + """Test that 'Send' is serializable.""" + action = Send(content="stub-content") + assert_action_is_serializable(action) + + +class TestParseActionErrors: + """Unit tests for exception-raising action string parsing.""" + + def test_invalid_json(self) -> None: + """Test that a ValueError is raised on invalid action string.""" + with pytest.raises(ValueError): + parse_action_from_string("{invalid-json}") + + def test_no_action_key(self) -> None: + """Test that a ValueError is raised on invalid json dump.""" + string = json.dumps({"data": "stub"}) + with pytest.raises(ValueError): + parse_action_from_string(string) + + def test_invalid_action_key(self) -> None: + """Test that a KeyError is raised on invalid action key.""" + string = json.dumps({"action": "stub-action"}) + with pytest.raises(KeyError): + parse_action_from_string(string) + + def test_legacy_message(self) -> None: + """Test that a LegacyMessageError is raised on Message dump.""" + string = json.dumps({"typekey": "stub", "data": "stub-data"}) + with pytest.raises(LegacyMessageError): + parse_action_from_string(string) + + def test_parse_legacy_reject_action(self) -> None: + """Test that a 'LegacyReject' action cannot be properly parsed.""" + action = LegacyReject() + string = action.to_string() + with pytest.raises(LegacyMessageError): + parse_action_from_string(string) diff --git a/test/communication/backend/test_messages_handler.py b/test/communication/backend/test_messages_handler.py new file mode 100644 index 00000000..36ff5753 --- /dev/null +++ b/test/communication/backend/test_messages_handler.py @@ -0,0 +1,424 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for 'declearn.communication.api.backend.MessagesHandler'.""" + +import asyncio +import logging +import time +from unittest import mock + +import pytest + +from declearn.communication.api.backend.actions import ( + Accept, + Drop, + Join, + LegacyReject, + Ping, + Recv, + Reject, + Send, +) +from declearn.communication.api.backend import MessagesHandler, flags +from declearn.version import VERSION + + +@pytest.fixture(name="handler") +def fixture_handler() -> MessagesHandler: + """Setup a MessagesHandler with a mock logger and a 0.1 heartbeat.""" + logger = mock.create_autospec(logging.Logger) + return MessagesHandler(logger, heartbeat=0.1) + + +@pytest.mark.asyncio +class TestMessagesHandler: + """Unit tests for 'declearn.communication.api.backend.MessagesHandler'.""" + + # unit tests namespace; pylint: disable=too-many-public-methods + + async def test_handle_invalid_action( + self, + handler: MessagesHandler, + ) -> None: + """Test that an invalid message is rejected.""" + query = "invalid-action-string" + reply = await handler.handle_message(query, context=mock.MagicMock()) + assert isinstance(reply, Reject) + assert reply.flag == flags.INVALID_MESSAGE + + async def test_handle_legacy_message( + self, + handler: MessagesHandler, + ) -> None: + """Test that an invalid message is rejected.""" + query = LegacyReject().to_string() + reply = await handler.handle_message(query, context=mock.MagicMock()) + assert isinstance(reply, LegacyReject) + + async def test_handle_join_open( + self, + handler: MessagesHandler, + ) -> None: + """Test Join action handling with open registration.""" + handler.open_clients_registration() + query = Join(name="name", version=VERSION).to_string() + reply = await handler.handle_message(query, context="context") + assert isinstance(reply, Accept) + assert reply.flag == flags.REGISTERED_WELCOME + assert handler.registered_clients["context"] == "name" + + async def test_handle_join_close( + self, + handler: MessagesHandler, + ) -> None: + """Test Join action handling with close registration.""" + handler.close_clients_registration() + query = Join(name="name", version=VERSION).to_string() + reply = await handler.handle_message(query, context="context") + assert isinstance(reply, Reject) + assert reply.flag == flags.REGISTRATION_CLOSED + assert "context" not in handler.registered_clients + + async def test_handle_join_wrong_version( + self, + handler: MessagesHandler, + ) -> None: + """Test Join action handling with declearn version mismatch.""" + handler.open_clients_registration() + query = Join(name="name", version="mock.version.string").to_string() + reply = await handler.handle_message(query, context="context") + assert isinstance(reply, Reject) + assert reply.flag == flags.REJECT_INCOMPATIBLE_VERSION + assert "context" not in handler.registered_clients + + async def test_handle_join_redundant( + self, + handler: MessagesHandler, + ) -> None: + """Test Join action handling for a pre-registered client.""" + # Register the client a first time and close registration. + handler.open_clients_registration() + query = Join(name="name", version=VERSION).to_string() + reply = await handler.handle_message(query, context="context") + assert isinstance(reply, Accept) + assert reply.flag == flags.REGISTERED_WELCOME + handler.close_clients_registration() + # Re-process the same registration request. + reply = await handler.handle_message(query, context="context") + assert isinstance(reply, Accept) + assert reply.flag == flags.REGISTERED_ALREADY + + async def test_handle_unregistered( + self, + handler: MessagesHandler, + ) -> None: + """Test handling of a non-Join action from an unregistered client.""" + handler.close_clients_registration() + query = Ping().to_string() + reply = await handler.handle_message(query, "context") + assert isinstance(reply, Reject) + assert reply.flag == flags.REJECT_UNREGISTERED + + async def test_handle_unexpected_type( + self, + handler: MessagesHandler, + ) -> None: + """Test handling of a server-reserved action from a client.""" + handler.registered_clients = {"context": "client"} + query = Reject(flag="stub").to_string() + reply = await handler.handle_message(query, context="context") + assert isinstance(reply, Reject) + assert reply.flag == flags.INVALID_MESSAGE + + async def test_handle_recv( + self, + handler: MessagesHandler, + ) -> None: + """Test handling of a Recv action with a pending message.""" + handler.registered_clients = {"context": "client"} + handler.outgoing_messages["client"] = "message" + query = Recv(timeout=1).to_string() + reply = await handler.handle_message(query, context="context") + assert isinstance(reply, Send) + assert reply.content == "message" + + async def test_handle_recv_timeout( + self, + handler: MessagesHandler, + ) -> None: + """Test handling of a Recv action that times out.""" + handler.registered_clients = {"context": "client"} + query = Recv(timeout=0.2).to_string() + start = time.time() + reply = await handler.handle_message(query, context="context") + delay = time.time() - start + assert isinstance(reply, Reject) + assert reply.flag == flags.CHECK_MESSAGE_TIMEOUT + assert 0.2 <= delay <= 0.2 + handler.heartbeat + + async def test_handle_send( + self, + handler: MessagesHandler, + ) -> None: + """Test handling of a Send action from a registered client.""" + handler.registered_clients = {"context": "client"} + query = Send(content="message").to_string() + reply = await handler.handle_message(query, context="context") + assert isinstance(reply, Ping) + assert handler.incoming_messages["client"] == "message" + + async def test_handle_send_over_pending( + self, + handler: MessagesHandler, + ) -> None: + """Test handling of a Send action with a pending message.""" + handler.registered_clients = {"context": "client"} + handler.incoming_messages["client"] = "pending" + query = Send(content="message").to_string() + coro = handler.handle_message(query, context="context") + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(coro, timeout=0.2) + assert handler.incoming_messages["client"] == "pending" + + async def test_handle_drop( + self, + handler: MessagesHandler, + ) -> None: + """Test handling of a Drop action from a registered client.""" + handler.registered_clients = {"context": "client"} + query = Drop().to_string() + reply = await handler.handle_message(query, context="context") + assert isinstance(reply, Ping) + assert "content" not in handler.registered_clients + + async def test_handle_ping( + self, + handler: MessagesHandler, + ) -> None: + """Test handling of a Ping action from a registered client.""" + handler.registered_clients = {"context": "client"} + query = Ping().to_string() + reply = await handler.handle_message(query, context="context") + assert isinstance(reply, Ping) + + async def test_post_message( + self, + handler: MessagesHandler, + ) -> None: + """Test posting a message adressed to a registered client.""" + handler.registered_clients = {"context": "client"} + handler.post_message("message", "client") + assert handler.outgoing_messages["client"] == "message" + + async def test_post_message_overwrite( + self, + handler: MessagesHandler, + ) -> None: + """Test posting a message that overwrites another pending one.""" + handler.registered_clients = {"context": "client"} + handler.outgoing_messages["client"] = "pending" + handler.post_message("message", "client") + handler.logger.warning.assert_called_once() # type: ignore + assert handler.outgoing_messages["client"] == "message" + + async def test_post_message_invalid_client( + self, + handler: MessagesHandler, + ) -> None: + """Test posting a message to an invalid client name.""" + with pytest.raises(KeyError): + handler.post_message("message", "client") + + async def test_send_message( + self, + handler: MessagesHandler, + ) -> None: + """Test sending a message that gets collected.""" + handler.registered_clients = {"context": "client"} + send_outp, recv_reply = await asyncio.gather( + handler.send_message("message", "client", timeout=1), + handler.handle_message(Recv(timeout=1).to_string(), "context"), + ) + assert send_outp is None + assert isinstance(recv_reply, Send) and recv_reply.content == "message" + + async def test_send_message_timeout( + self, + handler: MessagesHandler, + ) -> None: + """Test sending a message that is not collected before timeout.""" + handler.registered_clients = {"context": "client"} + with pytest.raises(asyncio.TimeoutError): + await handler.send_message("message", "client", timeout=0.1) + + async def test_check_message( + self, + handler: MessagesHandler, + ) -> None: + """Test collecting a pending message posted by a client.""" + handler.registered_clients = {"context": "client"} + handler.incoming_messages["client"] = "message" + output = handler.check_message("client") + assert output == "message" + assert "client" not in handler.incoming_messages + + async def test_check_message_no_message( + self, + handler: MessagesHandler, + ) -> None: + """Test checking for a client's message that is not there.""" + handler.registered_clients = {"context": "client"} + output = handler.check_message("client") + assert output is None + + async def test_check_message_invalid_client( + self, + handler: MessagesHandler, + ) -> None: + """Test checking for a client's message with invalid name.""" + with pytest.raises(KeyError): + handler.check_message("client") + + async def test_recv_message( + self, + handler: MessagesHandler, + ) -> None: + """Test receiving a message that gets posted.""" + handler.registered_clients = {"context": "client"} + recv_message, send_reply = await asyncio.gather( + handler.recv_message("client", timeout=1), + handler.handle_message(Send("message").to_string(), "context"), + ) + assert recv_message == "message" + assert isinstance(send_reply, Ping) + + async def test_recv_message_timeout( + self, + handler: MessagesHandler, + ) -> None: + """Test receiving a message that is not posted before timeout.""" + handler.registered_clients = {"context": "client"} + with pytest.raises(asyncio.TimeoutError): + await handler.recv_message("client", timeout=0.1) + + async def test_wait_for_clients_single_client( + self, + handler: MessagesHandler, + ) -> None: + """Test awaiting registration from a single client.""" + coro_wait_for_client = handler.wait_for_clients( + min_clients=1, max_clients=None, timeout=1.0 + ) + coro_register_client = handler.handle_message( + Join("client", VERSION).to_string(), "context" + ) + outp_wait, outp_join = await asyncio.gather( + coro_wait_for_client, coro_register_client + ) + assert outp_wait is None + assert isinstance(outp_join, Accept) + assert handler.client_names == {"client"} + + async def test_wait_for_clients_with_timeout( + self, + handler: MessagesHandler, + ) -> None: + """Test awaiting registration for a given delay.""" + # Have the handler await 0.2 seconds and two clients join. + coro_wait_for_clients = handler.wait_for_clients( + min_clients=1, max_clients=None, timeout=0.2 + ) + coro_register_client_a = handler.handle_message( + Join("client", VERSION).to_string(), "context-a" + ) + coro_register_client_b = handler.handle_message( + Join("client", VERSION).to_string(), "context-b" + ) + start = time.time() + outp_wait, outp_join_a, outp_join_b = await asyncio.gather( + coro_wait_for_clients, + coro_register_client_a, + coro_register_client_b, + ) + delay = time.time() - start + # Verify that the delay was respected and both clients were registered. + assert 0.2 <= delay <= 0.2 + handler.heartbeat + assert outp_wait is None + assert isinstance(outp_join_a, Accept) + assert isinstance(outp_join_b, Accept) + assert handler.client_names == {"client", "client.1"} + + async def test_wait_for_clients_not_enough_error( + self, + handler: MessagesHandler, + ) -> None: + """Test awaiting registration from too many clients.""" + # Have the handler await 0.2 seconds for 2 clients but only 1 join. + coro_wait_for_client = handler.wait_for_clients( + min_clients=2, max_clients=None, timeout=0.2 + ) + coro_register_client = handler.handle_message( + Join("client", VERSION).to_string(), "context" + ) + start = time.time() + excp_wait, outp_join = await asyncio.gather( + coro_wait_for_client, coro_register_client, return_exceptions=True + ) + delay = time.time() - start + # Verify that the delay was respected and a RuntimeError was raised. + assert 0.2 <= delay <= 0.2 + handler.heartbeat + assert isinstance(excp_wait, RuntimeError) # type: ignore + # Verify that in spite of initial acceptance, handler was purged. + assert isinstance(outp_join, Accept) # type: ignore + assert not handler.registered_clients + + async def test_wait_for_clients_too_many_error( + self, + handler: MessagesHandler, + ) -> None: + """Test awaiting registration with too many concurrent requests.""" + # Have the handler await maximum 2 clients, but 3 attempt joining. + coro_wait_for_clients = handler.wait_for_clients( + min_clients=1, max_clients=2, timeout=0.2 + ) + coro_register_client_a = handler.handle_message( + Join("client", VERSION).to_string(), "context-a" + ) + coro_register_client_b = handler.handle_message( + Join("client", VERSION).to_string(), "context-b" + ) + coro_register_client_c = handler.handle_message( + Join("client", VERSION).to_string(), "context-c" + ) + start = time.time() + excp_wait, *join_replies = await asyncio.gather( + coro_wait_for_clients, + coro_register_client_a, + coro_register_client_b, + coro_register_client_c, + return_exceptions=True, + ) + delay = time.time() - start + # Verify that requests were all accepted due to concurrence. + assert delay < 0.2 + assert all( + isinstance(reply, Accept) for reply in join_replies # type: ignore + ) + # Verify that this resulting in a RuntimeError and purging the handler. + assert isinstance(excp_wait, RuntimeError) # type: ignore + assert not handler.registered_clients -- GitLab