From 125e2708881f527dd6a3a2301687f29f5aa7421f Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Thu, 15 Feb 2024 12:03:58 +0100
Subject: [PATCH] Integrate new Message API with communication endpoints.

- Revert some changes from a previous commit, that may be envisioned
  again when revising APIs in-depth for DecLearn 3.0.
- Stitch back the 'communication' and 'messaging' submodules together,
  having 'NetworkClient' and 'NetworkServer' send 'Message' instances
  (that are merely serialized to string prior to being exchanged) and
  parse received strings into 'SerializedMessage' instances.
- This change is a compromise between keeping things as before (which
  would not benefit from the introduced delayed-parsing capability of
  'SerializedMessage'), and changing them so deeply that current code
  making use of communication endpoint would require heavy revisions.
---
 declearn/__init__.py                  |  3 ++
 declearn/communication/api/_client.py | 17 +++++-----
 declearn/communication/api/_server.py | 49 ++++++++++++++++-----------
 3 files changed, 41 insertions(+), 28 deletions(-)

diff --git a/declearn/__init__.py b/declearn/__init__.py
index 08fe9b6e..187c4ee3 100644
--- a/declearn/__init__.py
+++ b/declearn/__init__.py
@@ -40,6 +40,8 @@ The package is organized into the following submodules:
     Data interfacing API and implementations.
 * [main][declearn.main]:
     Main classes implementing a Federated Learning process.
+* [messaging][declearn.messaging]:
+    API and default classes to define parsable messages for applications.
 * [metrics][declearn.metrics]:
     Iterative and federative evaluation metrics computation tools.
 * [model][declearn.model]:
@@ -61,6 +63,7 @@ from . import (
     dataset,
     main,
     metrics,
+    messaging,
     model,
     optimizer,
     typing,
diff --git a/declearn/communication/api/_client.py b/declearn/communication/api/_client.py
index bd72df36..7bc3b64b 100644
--- a/declearn/communication/api/_client.py
+++ b/declearn/communication/api/_client.py
@@ -38,6 +38,7 @@ from declearn.communication.api.backend.actions import (
     Send,
     parse_action_from_string,
 )
+from declearn.messaging import Message, SerializedMessage
 from declearn.utils import create_types_registry, get_logger, register_type
 from declearn.version import VERSION
 
@@ -247,14 +248,14 @@ class NetworkClient(metaclass=abc.ABCMeta):
 
     async def send_message(
         self,
-        message: str,
+        message: Message,
     ) -> None:
         """Send a message to the server.
 
         Parameters
         ----------
         message: str
-            Message string that is to be delivered to the server.
+            Message instance that is to be delivered to the server.
 
         Raises
         ------
@@ -269,7 +270,7 @@ class NetworkClient(metaclass=abc.ABCMeta):
         The message sent here is designed to be received using the
         `NetworkServer.wait_for_messages` method.
         """
-        query = Send(message)
+        query = Send(message.to_string())
         reply = await self._exchange_action_messages(query)
         if isinstance(reply, Ping):
             return None
@@ -287,7 +288,7 @@ class NetworkClient(metaclass=abc.ABCMeta):
     async def recv_message(
         self,
         timeout: Optional[int] = None,
-    ) -> str:
+    ) -> SerializedMessage:
         """Await a message from the server, with optional timeout.
 
         Parameters
@@ -299,8 +300,8 @@ class NetworkClient(metaclass=abc.ABCMeta):
 
         Returns
         -------
-        message: Message
-            Message received from the server.
+        message: SerializedMessage
+            Serialized message received from the server.
 
         Note
         ----
@@ -322,7 +323,7 @@ class NetworkClient(metaclass=abc.ABCMeta):
         query = Recv(timeout)
         reply = await self._exchange_action_messages(query)
         if isinstance(reply, Send):
-            return reply.content
+            return SerializedMessage.from_message_string(reply.content)
         # Handle the various kinds of failures and raise accordingly.
         if isinstance(reply, Reject):
             if reply.flag == flags.CHECK_MESSAGE_TIMEOUT:
@@ -342,7 +343,7 @@ class NetworkClient(metaclass=abc.ABCMeta):
     async def check_message(
         self,
         timeout: Optional[int] = None,
-    ) -> str:
+    ) -> SerializedMessage:
         """Await a message from the server, with optional timeout.
 
         This method is DEPRECATED in favor of the `recv_message` one.
diff --git a/declearn/communication/api/_server.py b/declearn/communication/api/_server.py
index e1b931d1..d340c04e 100644
--- a/declearn/communication/api/_server.py
+++ b/declearn/communication/api/_server.py
@@ -21,11 +21,15 @@ import abc
 import asyncio
 import logging
 import types
-from typing import Any, ClassVar, Dict, List, Optional, Set, Type, Tuple, Union
+from typing import (
+    # fmt: off
+    Any, ClassVar, Dict, List, Mapping, Optional, Set, Type, Tuple, Union
+)
 
 from typing_extensions import Self  # future: import from typing (py >=3.11)
 
 from declearn.communication.api.backend import MessagesHandler
+from declearn.messaging import Message, SerializedMessage
 from declearn.utils import create_types_registry, get_logger, register_type
 
 
@@ -222,7 +226,7 @@ class NetworkServer(metaclass=abc.ABCMeta):
 
     async def send_message(
         self,
-        message: str,
+        message: Message,
         client: str,
         timeout: Optional[int] = None,
     ) -> None:
@@ -231,7 +235,7 @@ class NetworkServer(metaclass=abc.ABCMeta):
         Parameters
         ----------
         message: str
-            Message string that is to be delivered to the client.
+            Message instance that is to be delivered to the client.
         client: str
             Identifier of the client to whom the message is addressed.
         timeout: int or None, default=None
@@ -244,18 +248,18 @@ class NetworkServer(metaclass=abc.ABCMeta):
             If `timeout` is set and is reached while the message is
             yet to be collected by the client.
         """
-        await self.handler.send_message(message, client, timeout)
+        await self.handler.send_message(message.to_string(), client, timeout)
 
     async def send_messages(
         self,
-        messages: Dict[str, str],
+        messages: Mapping[str, Message],
         timeout: Optional[int] = None,
     ) -> None:
         """Send messages to an ensemble of clients and await their collection.
 
         Parameters
         ----------
-        messages: dict[str, str]
+        messages: dict[str, Message]
             Dict mapping client names to the messages addressed to them.
         timeout: int or None, default=None
             Optional maximum delay (in seconds) beyond which to stop
@@ -275,7 +279,7 @@ class NetworkServer(metaclass=abc.ABCMeta):
 
     async def broadcast_message(
         self,
-        message: str,
+        message: Message,
         clients: Optional[Set[str]] = None,
         timeout: Optional[int] = None,
     ) -> None:
@@ -284,7 +288,7 @@ class NetworkServer(metaclass=abc.ABCMeta):
         Parameters
         ----------
         message: str
-            Message string that is to be delivered to the clients.
+            Message instance that is to be delivered to the clients.
         clients: set[str] or None, default=None
             Optional subset of registered clients, messages from
             whom to wait for. If None, set to `self.client_names`.
@@ -306,7 +310,7 @@ class NetworkServer(metaclass=abc.ABCMeta):
     async def wait_for_messages(
         self,
         clients: Optional[Set[str]] = None,
-    ) -> Dict[str, str]:
+    ) -> Dict[str, SerializedMessage]:
         """Wait for messages from (a subset of) all clients.
 
         Parameters
@@ -317,21 +321,24 @@ class NetworkServer(metaclass=abc.ABCMeta):
 
         Returns
         -------
-        messages: dict[str, str]
-            A dictionary where the keys are the clients' names and
-            the values are message strings they sent to the server.
+        messages:
+            A dictionary mapping clients' names to the serialized
+            messages they sent to the server.
         """
         if clients is None:
             clients = self.client_names
         routines = [self.handler.recv_message(client) for client in clients]
         received = await asyncio.gather(*routines, return_exceptions=False)
-        return dict(zip(clients, received))
+        return {
+            client: SerializedMessage.from_message_string(string)
+            for client, string in zip(clients, received)
+        }
 
     async def wait_for_messages_with_timeout(
         self,
         timeout: int,
         clients: Optional[Set[str]] = None,
-    ) -> Tuple[Dict[str, str], List[str]]:
+    ) -> Tuple[Dict[str, SerializedMessage], List[str]]:
         """Wait for an ensemble of clients to have sent a message.
 
         Parameters
@@ -358,13 +365,15 @@ class NetworkServer(metaclass=abc.ABCMeta):
             self.handler.recv_message(client, timeout) for client in clients
         ]
         received = await asyncio.gather(*routines, return_exceptions=True)
-        messages = {}  # type: Dict[str, str]
+        messages = {}  # type: Dict[str, SerializedMessage]
         timeouts = []  # type: List[str]
-        for client, message in zip(clients, received):
-            if isinstance(message, asyncio.TimeoutError):
+        for client, output in zip(clients, received):
+            if isinstance(output, asyncio.TimeoutError):
                 timeouts.append(client)
-            elif isinstance(message, BaseException):
-                raise message
+            elif isinstance(output, BaseException):
+                raise output
             else:
-                messages[client] = message
+                messages[client] = SerializedMessage.from_message_string(
+                    output
+                )
         return messages, timeouts
-- 
GitLab