diff --git a/declearn/__init__.py b/declearn/__init__.py index 08fe9b6ea7be6ad63be6b1725a49c0c1309b60c7..187c4ee3bbf39112d1e53c08b9dc886b85160d34 100644 --- a/declearn/__init__.py +++ b/declearn/__init__.py @@ -40,6 +40,8 @@ The package is organized into the following submodules: Data interfacing API and implementations. * [main][declearn.main]: Main classes implementing a Federated Learning process. +* [messaging][declearn.messaging]: + API and default classes to define parsable messages for applications. * [metrics][declearn.metrics]: Iterative and federative evaluation metrics computation tools. * [model][declearn.model]: @@ -61,6 +63,7 @@ from . import ( dataset, main, metrics, + messaging, model, optimizer, typing, diff --git a/declearn/communication/api/_client.py b/declearn/communication/api/_client.py index bd72df364a2a08e14494cfb339bee2caabe436c8..7bc3b64bb1b57d047c34522f25ae14b38e7e059a 100644 --- a/declearn/communication/api/_client.py +++ b/declearn/communication/api/_client.py @@ -38,6 +38,7 @@ from declearn.communication.api.backend.actions import ( Send, parse_action_from_string, ) +from declearn.messaging import Message, SerializedMessage from declearn.utils import create_types_registry, get_logger, register_type from declearn.version import VERSION @@ -247,14 +248,14 @@ class NetworkClient(metaclass=abc.ABCMeta): async def send_message( self, - message: str, + message: Message, ) -> None: """Send a message to the server. Parameters ---------- message: str - Message string that is to be delivered to the server. + Message instance that is to be delivered to the server. Raises ------ @@ -269,7 +270,7 @@ class NetworkClient(metaclass=abc.ABCMeta): The message sent here is designed to be received using the `NetworkServer.wait_for_messages` method. """ - query = Send(message) + query = Send(message.to_string()) reply = await self._exchange_action_messages(query) if isinstance(reply, Ping): return None @@ -287,7 +288,7 @@ class NetworkClient(metaclass=abc.ABCMeta): async def recv_message( self, timeout: Optional[int] = None, - ) -> str: + ) -> SerializedMessage: """Await a message from the server, with optional timeout. Parameters @@ -299,8 +300,8 @@ class NetworkClient(metaclass=abc.ABCMeta): Returns ------- - message: Message - Message received from the server. + message: SerializedMessage + Serialized message received from the server. Note ---- @@ -322,7 +323,7 @@ class NetworkClient(metaclass=abc.ABCMeta): query = Recv(timeout) reply = await self._exchange_action_messages(query) if isinstance(reply, Send): - return reply.content + return SerializedMessage.from_message_string(reply.content) # Handle the various kinds of failures and raise accordingly. if isinstance(reply, Reject): if reply.flag == flags.CHECK_MESSAGE_TIMEOUT: @@ -342,7 +343,7 @@ class NetworkClient(metaclass=abc.ABCMeta): async def check_message( self, timeout: Optional[int] = None, - ) -> str: + ) -> SerializedMessage: """Await a message from the server, with optional timeout. This method is DEPRECATED in favor of the `recv_message` one. diff --git a/declearn/communication/api/_server.py b/declearn/communication/api/_server.py index e1b931d173038e1e61a645d82557f6c83d7178de..d340c04e01c96f757920284e62c32b3eb8d5e966 100644 --- a/declearn/communication/api/_server.py +++ b/declearn/communication/api/_server.py @@ -21,11 +21,15 @@ import abc import asyncio import logging import types -from typing import Any, ClassVar, Dict, List, Optional, Set, Type, Tuple, Union +from typing import ( + # fmt: off + Any, ClassVar, Dict, List, Mapping, Optional, Set, Type, Tuple, Union +) from typing_extensions import Self # future: import from typing (py >=3.11) from declearn.communication.api.backend import MessagesHandler +from declearn.messaging import Message, SerializedMessage from declearn.utils import create_types_registry, get_logger, register_type @@ -222,7 +226,7 @@ class NetworkServer(metaclass=abc.ABCMeta): async def send_message( self, - message: str, + message: Message, client: str, timeout: Optional[int] = None, ) -> None: @@ -231,7 +235,7 @@ class NetworkServer(metaclass=abc.ABCMeta): Parameters ---------- message: str - Message string that is to be delivered to the client. + Message instance that is to be delivered to the client. client: str Identifier of the client to whom the message is addressed. timeout: int or None, default=None @@ -244,18 +248,18 @@ class NetworkServer(metaclass=abc.ABCMeta): If `timeout` is set and is reached while the message is yet to be collected by the client. """ - await self.handler.send_message(message, client, timeout) + await self.handler.send_message(message.to_string(), client, timeout) async def send_messages( self, - messages: Dict[str, str], + messages: Mapping[str, Message], timeout: Optional[int] = None, ) -> None: """Send messages to an ensemble of clients and await their collection. Parameters ---------- - messages: dict[str, str] + messages: dict[str, Message] Dict mapping client names to the messages addressed to them. timeout: int or None, default=None Optional maximum delay (in seconds) beyond which to stop @@ -275,7 +279,7 @@ class NetworkServer(metaclass=abc.ABCMeta): async def broadcast_message( self, - message: str, + message: Message, clients: Optional[Set[str]] = None, timeout: Optional[int] = None, ) -> None: @@ -284,7 +288,7 @@ class NetworkServer(metaclass=abc.ABCMeta): Parameters ---------- message: str - Message string that is to be delivered to the clients. + Message instance that is to be delivered to the clients. clients: set[str] or None, default=None Optional subset of registered clients, messages from whom to wait for. If None, set to `self.client_names`. @@ -306,7 +310,7 @@ class NetworkServer(metaclass=abc.ABCMeta): async def wait_for_messages( self, clients: Optional[Set[str]] = None, - ) -> Dict[str, str]: + ) -> Dict[str, SerializedMessage]: """Wait for messages from (a subset of) all clients. Parameters @@ -317,21 +321,24 @@ class NetworkServer(metaclass=abc.ABCMeta): Returns ------- - messages: dict[str, str] - A dictionary where the keys are the clients' names and - the values are message strings they sent to the server. + messages: + A dictionary mapping clients' names to the serialized + messages they sent to the server. """ if clients is None: clients = self.client_names routines = [self.handler.recv_message(client) for client in clients] received = await asyncio.gather(*routines, return_exceptions=False) - return dict(zip(clients, received)) + return { + client: SerializedMessage.from_message_string(string) + for client, string in zip(clients, received) + } async def wait_for_messages_with_timeout( self, timeout: int, clients: Optional[Set[str]] = None, - ) -> Tuple[Dict[str, str], List[str]]: + ) -> Tuple[Dict[str, SerializedMessage], List[str]]: """Wait for an ensemble of clients to have sent a message. Parameters @@ -358,13 +365,15 @@ class NetworkServer(metaclass=abc.ABCMeta): self.handler.recv_message(client, timeout) for client in clients ] received = await asyncio.gather(*routines, return_exceptions=True) - messages = {} # type: Dict[str, str] + messages = {} # type: Dict[str, SerializedMessage] timeouts = [] # type: List[str] - for client, message in zip(clients, received): - if isinstance(message, asyncio.TimeoutError): + for client, output in zip(clients, received): + if isinstance(output, asyncio.TimeoutError): timeouts.append(client) - elif isinstance(message, BaseException): - raise message + elif isinstance(output, BaseException): + raise output else: - messages[client] = message + messages[client] = SerializedMessage.from_message_string( + output + ) return messages, timeouts