diff --git a/declearn/communication/__init__.py b/declearn/communication/__init__.py index 97f875ae524fc2684e9bbcc2d0a00c0d096a2a50..93a508e0e6e2ab8cedadae70efd3a24be1e8fd9b 100644 --- a/declearn/communication/__init__.py +++ b/declearn/communication/__init__.py @@ -21,12 +21,15 @@ This is done by defining server-side and client-side network communication endpoints for federated learning processes, as well as suitable messages to be transmitted, and the available communication protocols. -This module contains the following core submodule: +This module contains the following core submodules: * [api][declearn.communication.api]: Base API to define client- and server-side communication endpoints. +* [utils][declearn.communication.utils]: + Utils related to network communication endpoints' setup and usage. -It also exposes the following core utility functions and dataclasses: + +It re-exports publicly from `utils` the following elements: * [build_client][declearn.communication.build_client]: Instantiate a NetworkClient, selecting its subclass based on protocol name. @@ -57,7 +60,8 @@ longer be used, as its contents were re-dispatched elsewhere in DecLearn. # Messaging API and base tools: from . import api -from ._build import ( +from . import utils +from .utils import ( _INSTALLABLE_BACKENDS, NetworkClientConfig, NetworkServerConfig, diff --git a/declearn/communication/utils/__init__.py b/declearn/communication/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6b2b8a783bcfc1d58bfecba18806bb41591782fd --- /dev/null +++ b/declearn/communication/utils/__init__.py @@ -0,0 +1,65 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utils related to network communication endpoints' setup and usage. + + +Endpoints setup utils +--------------------- + +* [build_client][declearn.communication.utils.build_client]: + Instantiate a NetworkClient, selecting its subclass based on protocol name. +* [build_server][declearn.communication.utils.build_server]: + Instantiate a NetworkServer, selecting its subclass based on protocol name. +* [list_available_protocols]\ +[declearn.communication.utils.list_available_protocols]: + Return the list of readily-available network protocols. +* [NetworkClientConfig][declearn.communication.utils.NetworkClientConfig]: + TOML-parsable dataclass for network clients' instantiation. +* [NetworkServerConfig][declearn.communication.utils.NetworkServerConfig]: + TOML-parsable dataclass for network servers' instantiation. + + +Message-type control utils +-------------------------- + +* [ErrorMessageException][declearn.communication.utils.ErrorMessageException]: + Exception raised when an unexpected 'Error' message is received. +* [MessageTypeException][declearn.communication.utils.MessageTypeException]: + Exception raised when a received 'Message' has wrong type. +* [verify_client_messages_validity]\ +[declearn.communication.utils.verify_client_messages_validity]: + Verify that received serialized messages match an expected type. +* [verify_server_message_validity]\ +[declearn.communication.utils.verify_server_message_validity]: + Verify that a received serialized message matches expected type. 
+""" + +from ._build import ( + _INSTALLABLE_BACKENDS, + NetworkClientConfig, + NetworkServerConfig, + build_client, + build_server, + list_available_protocols, +) +from ._parse import ( + ErrorMessageException, + MessageTypeException, + verify_client_messages_validity, + verify_server_message_validity, +) diff --git a/declearn/communication/_build.py b/declearn/communication/utils/_build.py similarity index 100% rename from declearn/communication/_build.py rename to declearn/communication/utils/_build.py diff --git a/declearn/communication/utils/_parse.py b/declearn/communication/utils/_parse.py new file mode 100644 index 0000000000000000000000000000000000000000..17bd70b37fbde913ca50a81687a505533769faa3 --- /dev/null +++ b/declearn/communication/utils/_parse.py @@ -0,0 +1,166 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Utils to type-check received messages from the server or clients.""" + +from typing import Dict, Type, TypeVar + + +from declearn.communication.api import NetworkClient, NetworkServer +from declearn.messaging import Error, Message, SerializedMessage + + +__all__ = [ + "ErrorMessageException", + "MessageTypeException", + "verify_client_messages_validity", + "verify_server_message_validity", +] + + +class ErrorMessageException(Exception): + """Exception raised when an unexpected 'Error' message is received.""" + + +class MessageTypeException(Exception): + """Exception raised when a received 'Message' has wrong type.""" + + +MessageT = TypeVar("MessageT", bound=Message) + + +async def verify_client_messages_validity( + netwk: NetworkServer, + received: Dict[str, SerializedMessage], + expected: Type[MessageT], +) -> Dict[str, MessageT]: + """Verify that received serialized messages match an expected type. + + - If all received messages match the expected type, deserialize them. + - If any received message is an unexpected `Error` message, send an + `Error` to non-error-sending clients, then raise. + - If any received message belongs to any other type, send an `Error` + to each and every client, then raise. + + Parameters + ---------- + netwk: + `NetworkServer` endpoint, from which the processed messages + were received. + received: + Received `SerializedMessage` instances to type-check and deserialize. + expected: + Expected `Message` subtype. Any subclass will be considered + as valid. + + Returns + ------- + messages: + Deserialized messages from `received`, with `expected` type, + wrapped as a `{client_name: client_message}` dict. + + Raises + ------ + ErrorMessageException + If any `received` message wraps an unexpected `Error` message. + MessageTypeException + If any `received` wrapped message does not match `expected` type. + """ + # Iterate over received messages to identify any unexpected 'Error' ones + # or unexpected-type messages. 
+ wrong_types = "" + unexp_errors = {} # type: Dict[str, str] + for client, srm in received.items(): + if issubclass(srm.message_cls, expected): + pass + elif issubclass(srm.message_cls, Error): + unexp_errors[client] = srm.deserialize().message + else: + wrong_types += f"\n\t{client}: '{srm.message_cls}'" + # In case of Error messages, send an Error to other clients and raise. + if unexp_errors: + await netwk.broadcast_message( + Error("Some clients reported errors."), + clients=set(received).difference(unexp_errors), + ) + error = "".join( + f"\n\t{key}:{val}" for key, val in unexp_errors.items() + ) + raise ErrorMessageException( + f"Expected '{expected.__name__}' messages, got the following " + f"Error messages:{error}" + ) + # In case of improper messages, send an Error to all clients and raise. + if wrong_types: + error = ( + f"Expected '{expected.__name__}' messages, got the following " + f"unproper message types:{wrong_types}" + ) + await netwk.broadcast_message(Error(error), clients=set(received)) + raise MessageTypeException(error) + # If everything is fine, deserialize and return the received messages. + return {cli: srm.deserialize() for cli, srm in received.items()} + + +async def verify_server_message_validity( + netwk: NetworkClient, + received: SerializedMessage, + expected: Type[MessageT], +) -> MessageT: + """Verify that a received serialized message matches expected type. + + - If the received message matches expected type, deserialize it. + - If the received message is an unexpected `Error` message, raise. + - If it belongs to any other type, send an `Error` to the server, + then raise. + + Parameters + ---------- + netwk: + `NetworkClient` endpoint, from which the processed message + was received. + received: + Received `SerializedMessage` to type-check and deserialize. + expected: + Expected `Message` subtype. Any subclass will be considered + as valid. 
+ + Returns + ------- + message: + Deserialized `Message` from `received`, with `expected` type. + + Raises + ------ + ErrorMessageException + If `received` wraps an unexpected `Error` message. + MessageTypeException + If `received` wrapped message does not match `expected` type. + """ + # If a proper message is received, deserialize and return it. + if issubclass(received.message_cls, expected): + return received.deserialize() + # When an Error is received, merely raise using its content. + error = f"Expected a '{expected}' message" + if issubclass(received.message_cls, Error): + msg = received.deserialize() + error = f"{error}, received an Error message: '{msg.message}'." + raise ErrorMessageException(error) + # Otherwise, send an Error to the server, then raise. + error = f"{error}, got a '{received.message_cls}'." + await netwk.send_message(Error(error)) + raise MessageTypeException(error) diff --git a/declearn/main/_client.py b/declearn/main/_client.py index 4f51ab72aab38aad7c3008465b862af848ce691b..aa387742886def2c3faccf05dbaf4bace1041a13 100644 --- a/declearn/main/_client.py +++ b/declearn/main/_client.py @@ -26,8 +26,11 @@ from typing import Any, Dict, Optional, Union import numpy as np from declearn import messaging -from declearn.communication import NetworkClientConfig from declearn.communication.api import NetworkClient +from declearn.communication.utils import ( + NetworkClientConfig, + verify_server_message_validity, +) from declearn.dataset import Dataset, load_dataset_from_json from declearn.main.utils import Checkpointer, TrainingManager from declearn.utils import LOGGING_LEVEL_MAJOR, get_logger @@ -249,24 +252,43 @@ class FederatedClient: """ # Await initialization instructions. self.logger.info("Awaiting initialization instructions from server.") - message = await self.netwk.recv_message() + received = await self.netwk.recv_message() # If a MetadataQuery is received, process it, then await InitRequest. 
- if message.message_cls is messaging.MetadataQuery: - await self._collect_and_send_metadata(message.deserialize()) - message = await self.netwk.recv_message() + if issubclass(received.message_cls, messaging.MetadataQuery): + await self._collect_and_send_metadata(received.deserialize()) + received = await self.netwk.recv_message() + # Ensure that an 'InitRequest' was received. + message = await verify_server_message_validity( + self.netwk, received, expected=messaging.InitRequest + ) # Perform initialization, catching errors to report them to the server. try: - if not issubclass(message.message_cls, messaging.InitRequest): - raise TypeError( - f"Awaited InitRequest message, got '{message.message_cls}'" - ) - await self._initialize_trainmanager(message.deserialize()) + self.trainmanager = TrainingManager( + model=message.model, + optim=message.optim, + aggrg=message.aggrg, + train_data=self.train_data, + valid_data=self.valid_data, + metrics=message.metrics, + logger=self.logger, + verbose=self.verbose, + ) except Exception as exc: await self.netwk.send_message(messaging.Error(repr(exc))) raise RuntimeError("Initialization failed.") from exc + # If instructed to do so, run additional steps to set up DP-SGD. + if message.dpsgd: + await self._initialize_dpsgd() # Send back an empty message to indicate that all went fine. self.logger.info("Notifying the server that initialization went fine.") await self.netwk.send_message(messaging.InitReply()) + # Optionally checkpoint the received model and optimizer. + if self.ckptr: + self.ckptr.checkpoint( + model=self.trainmanager.model, + optimizer=self.trainmanager.optim, + first_call=True, + ) async def _collect_and_send_metadata( self, @@ -286,37 +308,6 @@ class FederatedClient: ) await self.netwk.send_message(messaging.MetadataReply(data_info)) - async def _initialize_trainmanager( - self, - message: messaging.InitRequest, - ) -> None: - """Set up a TrainingManager based on server instructions. 
- - - Also await and set up DP constraints if instructed to do so. - - Checkpoint the model and optimizer if configured to do so. - """ - # Wrap up the model and optimizer received from the server. - self.trainmanager = TrainingManager( - model=message.model, - optim=message.optim, - aggrg=message.aggrg, - train_data=self.train_data, - valid_data=self.valid_data, - metrics=message.metrics, - logger=self.logger, - verbose=self.verbose, - ) - # If instructed to do so, await a PrivacyRequest to set up DP-SGD. - if message.dpsgd: - await self._initialize_dpsgd() - # Optionally checkpoint the received model and optimizer. - if self.ckptr: - self.ckptr.checkpoint( - model=self.trainmanager.model, - optimizer=self.trainmanager.optim, - first_call=True, - ) - async def _initialize_dpsgd( self, ) -> None: @@ -325,12 +316,13 @@ class FederatedClient: This method wraps the `make_private` one in the context of `initialize` and should never be called in another context. """ - message = await self.netwk.recv_message() - if not isinstance(message, messaging.PrivacyRequest): - msg = f"Expected a PrivacyRequest but received a '{type(message)}'" - self.logger.error(msg) - await self.netwk.send_message(messaging.Error(msg)) - raise RuntimeError(f"DP-SGD initialization failed: {msg}.") + received = await self.netwk.recv_message() + try: + message = await verify_server_message_validity( + self.netwk, received, expected=messaging.PrivacyRequest + ) + except Exception as exc: + raise RuntimeError("DP-SGD initialization failed.") from exc self.logger.info("Received a request to set up DP-SGD.") try: self.make_private(message) diff --git a/docs/release-notes/v2.4.0.md b/docs/release-notes/v2.4.0.md index 43ce11979f155a2973421fad00b5fb784c6affa1..f75412991e6ddd9db5f6da73da44c5ed26f525da 100644 --- a/docs/release-notes/v2.4.0.md +++ b/docs/release-notes/v2.4.0.md @@ -207,6 +207,15 @@ return `SerializedMessage` instances rather than `Message` ones. 
simulated contexts (including tests), setting a low heartbeat can cut runtime down significantly. +### New `declearn.communication.utils` submodule + +Introduce the `declearn.communication.utils` submodule, and move existing +`declearn.communication` utils to it. Keep re-exporting them from the parent +module to preserve code compatibility. + +Add `verify_client_messages_validity` and `verify_server_message_validity` as +part of the new submodule, that refactor some backend code from orchestration +classes related to the filtering and type-checking of exchanged messages. ## Usability updates diff --git a/test/communication/test_utils.py b/test/communication/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..362a76d120b7b423c0d4a3a60e7af51c0ea3de6f --- /dev/null +++ b/test/communication/test_utils.py @@ -0,0 +1,190 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for message-type-verification utils.""" + +import dataclasses +from unittest import mock + +import pytest + +from declearn.communication.api import NetworkClient, NetworkServer +from declearn.communication.utils import ( + ErrorMessageException, + MessageTypeException, + verify_client_messages_validity, + verify_server_message_validity, +) +from declearn.messaging import Error, Message, SerializedMessage + + +@dataclasses.dataclass +class SimpleMessage(Message, register=False): # type: ignore[call-arg] + """Stub Message subclass for this module's unit tests.""" + + typekey = "simple" + + content: str + + +@pytest.mark.asyncio +async def test_verify_client_messages_validity_expected_simple(): + """Test 'verify_client_messages_validity' with valid messages.""" + # Setup simple messages and have the server expect them. + netwk = mock.create_autospec(NetworkServer, instance=True) + messages = {f"client_{i}": SimpleMessage(f"message_{i}") for i in range(3)} + received = { + key: SerializedMessage(type(val), val.to_string().split("\n", 1)[1]) + for key, val in messages.items() + } + results = await verify_client_messages_validity( + netwk=netwk, received=received, expected=SimpleMessage + ) + # Assert that results match expectations, and no message was sent. + assert isinstance(results, dict) + assert results == messages + netwk.broadcast_message.assert_not_called() + + +@pytest.mark.asyncio +async def test_verify_client_messages_validity_expected_error(): + """Test 'verify_client_messages_validity' with expected Error messages.""" + # Setup Error messages and have the server expect them. 
+ netwk = mock.create_autospec(NetworkServer, instance=True) + messages = {f"client_{i}": Error(f"message_{i}") for i in range(3)} + received = { + key: SerializedMessage(type(val), val.to_string().split("\n", 1)[1]) + for key, val in messages.items() + } + results = await verify_client_messages_validity( + netwk=netwk, received=received, expected=Error + ) + # Assert that results match expectations, and no message was sent. + assert isinstance(results, dict) + assert results == messages + netwk.broadcast_message.assert_not_called() + + +@pytest.mark.asyncio +async def test_verify_client_messages_validity_unexpected_types(): + """Test 'verify_client_messages_validity' with invalid messages.""" + # Setup simple messages, but have the server expect Error messages. + netwk = mock.create_autospec(NetworkServer, instance=True) + messages = {f"client_{i}": SimpleMessage(f"message_{i}") for i in range(3)} + received = { + key: SerializedMessage(type(val), val.to_string().split("\n", 1)[1]) + for key, val in messages.items() + } + # Assert that an exception is raised. + with pytest.raises(MessageTypeException): + await verify_client_messages_validity( + netwk=netwk, received=received, expected=Error + ) + # Assert that an Error message was broadcast to all clients. + netwk.broadcast_message.assert_awaited_once_with( + message=Error(mock.ANY), clients=set(received) + ) + + +@pytest.mark.asyncio +async def test_verify_client_messages_validity_unexpected_error(): + """Test 'verify_client_messages_validity' with 'Error' messages.""" + # Setup simple messages, but have one be an Error. + netwk = mock.create_autospec(NetworkServer, instance=True) + messages = {f"client_{i}": SimpleMessage(f"message_{i}") for i in range(2)} + messages["client_2"] = Error("error_message") + received = { + key: SerializedMessage(type(val), val.to_string().split("\n", 1)[1]) + for key, val in messages.items() + } + # Assert that an exception is raised. 
+ with pytest.raises(ErrorMessageException): + await verify_client_messages_validity( + netwk=netwk, received=received, expected=SimpleMessage + ) + # Assert that an Error message was broadcast to non-Error-sending clients. + netwk.broadcast_message.assert_awaited_once_with( + message=Error(mock.ANY), clients={"client_0", "client_1"} + ) + + +@pytest.mark.asyncio +async def test_verify_server_message_validity_expected_simple(): + """Test 'verify_server_message_validity' with a valid message.""" + # Setup a simple message matching client expectations. + netwk = mock.create_autospec(NetworkClient, instance=True) + message = SimpleMessage("message") + received = SerializedMessage( + SimpleMessage, message.to_string().split("\n", 1)[1] + ) + result = await verify_server_message_validity( + netwk=netwk, received=received, expected=SimpleMessage + ) + # Assert that results match expectations, and no message was sent. + assert isinstance(result, Message) + assert result == message + netwk.send_message.assert_not_called() + + +@pytest.mark.asyncio +async def test_verify_server_message_validity_expected_error(): + """Test 'verify_server_message_validity' with an expected Error message.""" + # Setup an Error message matching client expectations. + netwk = mock.create_autospec(NetworkClient, instance=True) + message = Error("message") + received = SerializedMessage(Error, message.to_string().split("\n", 1)[1]) + result = await verify_server_message_validity( + netwk=netwk, received=received, expected=Error + ) + # Assert that results match expectations, and no message was sent. + assert isinstance(result, Message) + assert result == message + netwk.send_message.assert_not_called() + + +@pytest.mark.asyncio +async def test_verify_server_message_validity_unexpected_type(): + """Test 'verify_server_message_validity' with an unexpected message.""" + # Setup a simple message, but have the client expect an Error one. 
+ netwk = mock.create_autospec(NetworkClient, instance=True) + message = SimpleMessage("message") + received = SerializedMessage( + SimpleMessage, message.to_string().split("\n", 1)[1] + ) + # Assert that an exception is raised. + with pytest.raises(MessageTypeException): + await verify_server_message_validity( + netwk=netwk, received=received, expected=Error + ) + # Assert that an Error was sent to the server. + netwk.send_message.assert_awaited_once_with(message=Error(mock.ANY)) + + +@pytest.mark.asyncio +async def test_verify_server_message_validity_unexpected_error(): + """Test 'verify_server_message_validity' with an unexpected 'Error'.""" + # Setup an unexpected Error message. + netwk = mock.create_autospec(NetworkClient, instance=True) + message = Error("message") + received = SerializedMessage(Error, message.to_string().split("\n", 1)[1]) + # Assert that an exception is raised. + with pytest.raises(ErrorMessageException): + await verify_server_message_validity( + netwk=netwk, received=received, expected=SimpleMessage + ) + # Assert that no Error was sent to the server. + netwk.send_message.assert_not_called()