Mentions légales du service

Skip to content
Snippets Groups Projects
_handler.py 16.92 KiB
# coding: utf-8

# Copyright 2023 Inria (Institut National de Recherche en Informatique
# et Automatique)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Protocol-agnostic server-side network messages handler."""

import asyncio
import logging
import math
from typing import Any, Dict, Optional, Set, Union


from declearn.communication.api.backend import flags
from declearn.communication.api.backend.actions import (
    Accept,
    ActionMessage,
    Drop,
    Join,
    LegacyMessageError,
    LegacyReject,
    Ping,
    Recv,
    Reject,
    Send,
    parse_action_from_string,
)
from declearn.version import VERSION


class MessagesHandler:
    """Minimal protocol-agnostic server-side messages handler."""

    def __init__(
        self,
        logger: logging.Logger,
        heartbeat: float = 1.0,
    ) -> None:
        # Assign parameters as attributes.
        self.logger = logger
        self.heartbeat = heartbeat
        # Set up containers for client identifiers and pending messages.
        self.registered_clients = {}  # type: Dict[Any, str]
        self.outgoing_messages = {}  # type: Dict[str, str]
        self.incoming_messages = {}  # type: Dict[str, str]
        # Mark client-registration as unopened.
        self.registration_status = flags.REGISTRATION_UNSTARTED

    @property
    def client_names(self) -> Set[str]:
        """Names of the registered clients."""
        return set(self.registered_clients.values())

    async def purge(
        self,
    ) -> None:
        """Close opened connections and purge information about users.

        This resets the instance as though it was first initialized.
        User registration will be marked as unstarted.
        """
        self.registered_clients.clear()
        self.outgoing_messages.clear()
        self.incoming_messages.clear()
        self.registration_status = flags.REGISTRATION_UNSTARTED

    async def handle_message(
        self,
        string: str,
        context: Any,
    ) -> ActionMessage:
        """Handle an incoming message from a client.

        Parameters
        ----------
        string: str
            Received message, as a string that can be parsed back
            into an `ActionMessage` instance.
        context: hashable
            Communications-protocol-specific hashable object that
            may be used to uniquely identify (and thereof contact)
            the client that sent the message being handled.

        Returns
        -------
        message: ActionMessage
            Message to return to the sender, the specific type of
            which depends on the type of incoming request, errors
            encountered, etc.
        """
        # Parse the incoming message. If it is incorrect, reject it.
        try:
            message = parse_action_from_string(string)
        except (KeyError, TypeError, ValueError) as exc:
            self.logger.info(
                "Exception encountered while parsing received message: %s",
                repr(exc),
            )
            return Reject(flags.INVALID_MESSAGE)
        except LegacyMessageError as exc:
            self.logger.info(repr(exc))
            return LegacyReject()
        # Case: join request from a (new) client. Handle it.
        if isinstance(message, Join):
            return await self._handle_join_request(message, context)
        # Case: unregistered client. Reject message.
        if context not in self.registered_clients:
            return Reject(flags.REJECT_UNREGISTERED)
        # Case: registered client. Handle it.
        return await self._handle_registered_client_message(message, context)

    async def _handle_registered_client_message(
        self,
        message: ActionMessage,
        context: Any,
    ) -> ActionMessage:
        """Backend to handle a message from a registered client."""
        # Case: message-receiving request. Handle it.
        if isinstance(message, Recv):
            return await self._handle_recv_request(message, context)
        # Case: message-sending request. Handle it.
        if isinstance(message, Send):
            return await self._handle_send_request(message, context)
        # Case: drop message from a client. Handle it.
        if isinstance(message, Drop):
            return await self._handle_drop_request(message, context)
        # Case: ping request. Ping back.
        if isinstance(message, Ping):
            return Ping()
        # Case: unsupported message. Reject it.
        self.logger.error(
            "TypeError: received a message of unexpected type '%s'",
            type(message).__name__,
        )
        return Reject(flags.INVALID_MESSAGE)

    async def _handle_join_request(
        self,
        message: Join,
        context: Any,
    ) -> Union[Accept, Reject]:
        """Handle a join request."""
        # Case when client is already registered: warn but send OK.
        if context in self.registered_clients:
            self.logger.info(
                "Client %s is already registered.",
                self.registered_clients[context],
            )
            return Accept(flags.REGISTERED_ALREADY)
        # Case when registration is not opened: warn and reject.
        if self.registration_status != flags.REGISTRATION_OPEN:
            self.logger.info("Rejecting registration request.")
            return Reject(flag=self.registration_status)
        # Case when the client uses an incompatible declearn version.
        if (err := self._verify_version_compatibility(message)) is not None:
            return err
        # Case when registration is opened: register the client.
        self._register_client(message, context)
        return Accept(flag=flags.REGISTERED_WELCOME)

    def _verify_version_compatibility(
        self,
        message: Join,
    ) -> Optional[Reject]:
        """Return an 'Error' if a 'JoinRequest' is of incompatible version."""
        if message.version.split(".")[:2] == VERSION.split(".")[:2]:
            return None
        self.logger.info(
            "Received a registration request under name %s, that is "
            "invalid due to the client using DecLearn '%s'.",
            message.name,
            message.version,
        )
        return Reject(
            "Cannot register due to the DecLearn version in use. "
            f"Please update to `declearn ~= {VERSION}`."
        )

    def _register_client(
        self,
        message: Join,
        context: Any,
    ) -> None:
        """Register a user based on their Join request and context object."""
        # Alias the user name if needed to avoid duplication issues.
        name = message.name
        used = self.client_names
        if name in used:
            idx = sum(other.rsplit(".", 1)[0] == name for other in used)
            name = f"{name}.{idx}"
        # Register the user, recording context and received data information.
        self.logger.info("Registering client '%s' for training.", name)
        self.registered_clients[context] = name

    async def _handle_send_request(
        self,
        message: Send,
        context: Any,
    ) -> Union[Ping, Reject]:
        """Handle a message-sending request (client-to-server)."""
        name = self.registered_clients[context]
        # Wait for any previous message from this client to be collected.
        while self.incoming_messages.get(name):
            await asyncio.sleep(self.heartbeat)
        # Record the received message, and return a ping-back response.
        self.incoming_messages[name] = message.content
        return Ping()

    async def _handle_recv_request(
        self,
        message: Recv,
        context: Any,
    ) -> Union[Send, Reject]:
        """Handle a message-receiving request."""
        # Set up the optional timeout mechanism.
        timeout = message.timeout
        countdown = (
            max(math.ceil(timeout / self.heartbeat), 1) if timeout else -1
        )
        # Wait for a message to be available or timeout to be reached.
        name = self.registered_clients[context]
        while (not self.outgoing_messages.get(name)) and countdown:
            await asyncio.sleep(self.heartbeat)
            countdown -= 1
        # Either send back the collected message, or a timeout error.
        try:
            content = self.outgoing_messages.pop(name)
        except KeyError:
            return Reject(flags.CHECK_MESSAGE_TIMEOUT)
        return Send(content)

    async def _handle_drop_request(
        self,
        message: Drop,
        context: Any,
    ) -> Ping:
        """Handle a drop request from a client."""
        name = self.registered_clients.pop(context)
        reason = (
            f"reason: '{message.reason}'" if message.reason else "no reason"
        )
        self.logger.info("Client %s has dropped with %s.", name, reason)
        return Ping()

    def post_message(
        self,
        message: str,
        client: str,
    ) -> None:
        """Post a message to be requested by a given client.

        Parameters
        ----------
        message: str
            Message string that is to be posted for the client to collect.
        client: str
            Name of the client to whom the message is addressed.

        Notes
        -----
        This method merely makes the message available for the client
        to request, without any guarantee that it is received.
        See the `send_message` async method to wait for the posted
        message to have been requested by and thus sent to the client.
        """
        if client not in self.client_names:
            raise KeyError(f"Unkown destinatory client '{client}'.")
        if client in self.outgoing_messages:
            self.logger.warning(
                "Overwriting pending message uncollected by client '%s'.",
                client,
            )
        self.outgoing_messages[client] = message

    async def send_message(
        self,
        message: str,
        client: str,
        timeout: Optional[float] = None,
    ) -> None:
        """Post a message for a client and wait for it to be collected.

        Parameters
        ----------
        message: str
            Message string that is to be posted for the client to collect.
        client: str
            Name of the client to whom the message is addressed.
        timeout: float or None, default=None
            Optional maximum delay (in seconds) beyond which to stop
            waiting for collection and raise an asyncio.TimeoutError.

        Raises
        ------
        asyncio.TimeoutError
            If `timeout` is set and is reached while the message is
            yet to be collected by the client.

        Notes
        -----
        See the `post_message` method to synchronously post a message
        and move on without guarantees that it was collected.
        """
        # Post the message. Wait for it to have been collected.
        self.post_message(message, client)
        countdown = (
            max(math.ceil(timeout / self.heartbeat), 1) if timeout else -1
        )
        while self.outgoing_messages.get(client, False) and countdown:
            await asyncio.sleep(self.heartbeat)
            countdown -= 1
        # If the message is still there, raise a TimeoutError.
        if self.outgoing_messages.get(client):
            raise asyncio.TimeoutError(
                "Timeout reached before the sent message was collected."
            )

    def check_message(
        self,
        client: str,
    ) -> Optional[str]:
        """Check whether a message was received from a given client.

        Parameters
        ----------
        client: str
            Name of the client whose emitted message to check for.

        Returns
        -------
        message:
            Collected message that was sent by `client`, if any.
            In case no message is available, return None.

        Notes
        -----
        See the `recv_message` async method to wait for a message
        from the client to be available, collect and return it.
        """
        if client not in self.client_names:
            raise KeyError(f"Unregistered checked-for client '{client}'.")
        return self.incoming_messages.pop(client, None)

    async def recv_message(
        self,
        client: str,
        timeout: Optional[float] = None,
    ) -> str:
        """Wait for a message to be received from a given client.

        Parameters
        ----------
        client: str
            Name of the client whose emitted message to check for.
        timeout: float or None, default=None
            Optional maximum delay (in seconds) beyond which to stop
            waiting for a message and raise an asyncio.TimeoutError.

        Raises
        ------
        asyncio.TimeoutError
            If `timeout` is set and is reached while no message has
            been received from the client.

        Returns
        -------
        message:
            Collected message that was sent by `client`.

        Notes
        -----
        See the `check_message` method to synchronously check whether
        a message from the client is available and return it or None.
        """
        countdown = (
            max(math.ceil(timeout / self.heartbeat), 1) if timeout else -1
        )
        while countdown:
            message = self.check_message(client)
            if message is not None:
                return message
            await asyncio.sleep(self.heartbeat)
            countdown -= 1
        raise asyncio.TimeoutError(
            "Timeout reached before a message was received."
        )

    def open_clients_registration(
        self,
    ) -> None:
        """Make this servicer accept registration of new clients."""
        self.registration_status = flags.REGISTRATION_OPEN

    def close_clients_registration(
        self,
    ) -> None:
        """Make this servicer reject registration of new clients."""
        self.registration_status = flags.REGISTRATION_CLOSED

    async def wait_for_clients(
        self,
        min_clients: int = 1,
        max_clients: Optional[int] = None,
        timeout: Optional[float] = None,
    ) -> None:
        """Wait for clients to register for training, with given criteria.

        Parameters
        ----------
        min_clients: int, default=1
            Minimum number of clients required. Corrected to be >= 1.
            If `timeout` is None, used as the exact number of clients
            required - once reached, registration will be closed.
        max_clients: int or None, default=None
            Maximum number of clients authorized to register.
        timeout: float or None, default=None
            Optional maximum waiting time (in seconds) beyond which
            to close registration and either return or raise.

        Raises
        ------
        RuntimeError
            If the number of registered clients does not abide by the
            provided boundaries at the end of the process.
        """
        # Ensure any collected information is purged in case of failure
        # (due to raised errors or wrong number of registered clients).
        try:
            await self._wait_for_clients(min_clients, max_clients, timeout)
        except Exception as exc:  # re-raise; pylint: disable=broad-except
            await self.purge()
            raise exc

    async def _wait_for_clients(
        self,
        min_clients: int = 1,
        max_clients: Optional[int] = None,
        timeout: Optional[float] = None,
    ) -> None:
        """Backend of `wait_for_clients` method, without safeguards."""
        # Parse information on the required number of clients.
        min_clients = max(min_clients, 1)
        max_clients = -1 if max_clients is None else max_clients
        if max_clients < 0:
            max_clients = (
                min_clients if timeout is None else math.inf  # type: ignore
            )
        else:
            max_clients = max(min_clients, max_clients)
        # Wait for the required number of clients to have joined.
        self.open_clients_registration()
        countdown = (
            max(math.ceil(timeout / self.heartbeat), 1) if timeout else -1
        )
        while countdown and (len(self.registered_clients) < max_clients):
            await asyncio.sleep(self.heartbeat)
            countdown -= 1
        self.close_clients_registration()
        # Check whether all requirements have been checked.
        n_clients = len(self.registered_clients)
        if not min_clients <= n_clients <= max_clients:
            raise RuntimeError(
                f"The number of registered clients is {n_clients}, which "
                f"is out of the [{min_clients}, {max_clients}] range."
            )