diff --git a/declearn/main/_client.py b/declearn/main/_client.py index 9a8df6852e90d6f0c5bf4c2b69745a791ba5cfc5..4f51ab72aab38aad7c3008465b862af848ce691b 100644 --- a/declearn/main/_client.py +++ b/declearn/main/_client.py @@ -25,7 +25,8 @@ from typing import Any, Dict, Optional, Union import numpy as np -from declearn.communication import NetworkClientConfig, messaging +from declearn import messaging +from declearn.communication import NetworkClientConfig from declearn.communication.api import NetworkClient from declearn.dataset import Dataset, load_dataset_from_json from declearn.main.utils import Checkpointer, TrainingManager @@ -164,21 +165,21 @@ class FederatedClient: await self.initialize() # Process server instructions as they come. while True: - message = await self.netwk.check_message() + message = await self.netwk.recv_message() stoprun = await self.handle_message(message) if stoprun: break async def handle_message( self, - message: messaging.Message, + message: messaging.SerializedMessage, ) -> bool: """Handle an incoming message from the server. Parameters ---------- - message: messaging.Message - Message instance that needs triage and processing. + message: messaging.SerializedMessage + Serialized message that needs triage and processing. Returns ------- @@ -186,18 +187,18 @@ class FederatedClient: Whether to interrupt the client's message-receiving loop. """ exit_loop = False - if isinstance(message, messaging.TrainRequest): - await self.training_round(message) - elif isinstance(message, messaging.EvaluationRequest): - await self.evaluation_round(message) - elif isinstance(message, messaging.StopTraining): - await self.stop_training(message) + if issubclass(message.message_cls, messaging.TrainRequest): + await self.training_round(message.deserialize()) + elif issubclass(message.message_cls, messaging.EvaluationRequest): + await self.evaluation_round(message.deserialize()) + elif issubclass(message.message_cls, messaging.StopTraining): + await self.stop_training(message.deserialize()) exit_loop = True - elif isinstance(message, messaging.CancelTraining): - await self.cancel_training(message) + elif issubclass(message.message_cls, messaging.CancelTraining): + await self.cancel_training(message.deserialize()) else: - error = "Unexpected instruction received from server:" - error += repr(message) + error = "Unexpected message type received from server: " + error += message.message_cls.__name__ self.logger.error(error) raise ValueError(error) return exit_loop @@ -213,13 +214,11 @@ class FederatedClient: If registration has failed 10 times (with a 1 minute delay between connection and registration attempts). """ - # revise: add validation dataset specs - data_info = dataclasses.asdict(self.train_data.get_data_specs()) for i in range(10): # max_attempts (10) self.logger.info( "Attempting to join training (attempt n°%s)", i + 1 ) - registered = await self.netwk.register(data_info) + registered = await self.netwk.register() if registered: break await asyncio.sleep(60) # delay_retries (1 minute) @@ -248,23 +247,54 @@ class FederatedClient: optim: Optimizer Optimizer that is to be used locally to train the model. """ - # Await initialization instructions. Report messages-unpacking errors. + # Await initialization instructions. self.logger.info("Awaiting initialization instructions from server.") + message = await self.netwk.recv_message() + # If a MetadataQuery is received, process it, then await InitRequest. + if message.message_cls is messaging.MetadataQuery: + await self._collect_and_send_metadata(message.deserialize()) + message = await self.netwk.recv_message() + # Perform initialization, catching errors to report them to the server. try: - message = await self.netwk.check_message() + if not issubclass(message.message_cls, messaging.InitRequest): + raise TypeError( + f"Awaited InitRequest message, got '{message.message_cls}'" + ) + await self._initialize_trainmanager(message.deserialize()) except Exception as exc: await self.netwk.send_message(messaging.Error(repr(exc))) raise RuntimeError("Initialization failed.") from exc - # Otherwise, check that the request is of valid type. - if not isinstance(message, messaging.InitRequest): - error = f"Awaited InitRequest message, got: '{message}'" - self.logger.error(error) - raise RuntimeError(error) # Send back an empty message to indicate that all went fine. self.logger.info("Notifying the server that initialization went fine.") - await self.netwk.send_message( - messaging.GenericMessage(action="InitializationOK", params={}) + await self.netwk.send_message(messaging.InitReply()) + + async def _collect_and_send_metadata( + self, + message: messaging.MetadataQuery, + ) -> None: + """Collect and report some metadata based on server instructions.""" + self.logger.info("Collecting metadata to send to the server.") + metadata = dataclasses.asdict(self.train_data.get_data_specs()) + if missing := set(message.fields).difference(metadata): + err_msg = f"Metadata query for undefined fields: {missing}." + await self.netwk.send_message(messaging.Error(err_msg)) + raise RuntimeError(err_msg) + data_info = {key: metadata[key] for key in message.fields} + self.logger.info( + "Sending training dataset metadata to the server: %s.", + list(data_info), ) + await self.netwk.send_message(messaging.MetadataReply(data_info)) + + async def _initialize_trainmanager( + self, + message: messaging.InitRequest, + ) -> None: + """Set up a TrainingManager based on server instructions. + + - Also await and set up DP constraints if instructed to do so. + - Checkpoint the model and optimizer if configured to do so. + """ # Wrap up the model and optimizer received from the server. self.trainmanager = TrainingManager( model=message.model, @@ -295,7 +325,7 @@ class FederatedClient: This method wraps the `make_private` one in the context of `initialize` and should never be called in another context. """ - message = await self.netwk.check_message() + message = await self.netwk.recv_message() if not isinstance(message, messaging.PrivacyRequest): msg = f"Expected a PrivacyRequest but received a '{type(message)}'" self.logger.error(msg) @@ -312,9 +342,7 @@ class FederatedClient: raise RuntimeError("DP-SGD initialization failed.") from exc # If things went right, notify the server. self.logger.info("Notifying the server that DP-SGD setup went fine.") - await self.netwk.send_message( - messaging.GenericMessage(action="privacy-ok", params={}) - ) + await self.netwk.send_message(messaging.PrivacyReply()) def make_private( self, diff --git a/declearn/main/_server.py b/declearn/main/_server.py index 67391464e9ac709e98963c5b5214ff5c3906f378..8dd9f48b6bf5625907fff31b51cfcc02e4377942 100644 --- a/declearn/main/_server.py +++ b/declearn/main/_server.py @@ -24,7 +24,8 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union import numpy as np -from declearn.communication import NetworkServerConfig, messaging +from declearn import messaging +from declearn.communication import NetworkServerConfig from declearn.communication.api import NetworkServer from declearn.main.config import ( EvaluateConfig, @@ -240,13 +241,14 @@ class FederatedServer: """ # Gather the RegisterConfig instance from the main FLRunConfig. regst_cfg = config.register - # Wait for clients to register and process their data information. + # Wait for clients to register. self.logger.info("Starting clients registration process.") - data_info = await self.netwk.wait_for_clients( + await self.netwk.wait_for_clients( regst_cfg.min_clients, regst_cfg.max_clients, regst_cfg.timeout ) self.logger.info("Clients' registration is now complete.") - await self._process_data_info(data_info) + # When needed, prompt clients for metadata and process them. + await self._require_and_process_data_info() # Serialize intialization information and send it to clients. message = messaging.InitRequest( model=self.model, @@ -262,25 +264,18 @@ class FederatedServer: self.logger.info("Waiting for clients' responses.") await self._collect_results( clients=self.netwk.client_names, - msgtype=messaging.GenericMessage, - context="initialization", + msgtype=messaging.InitReply, + context="Initialization", ) # If local differential privacy is configured, set it up. if config.privacy is not None: await self._initialize_dpsgd(config) self.logger.info("Initialization was successful.") - async def _process_data_info( + async def _require_and_process_data_info( self, - clients_data_info: Dict[str, Dict[str, Any]], ) -> None: - """Validate, aggregate and make use of clients' data-info. - - Parameters - ---------- - clients_data_info: dict[str, dict[str, any]] - Client-wise data-info dict, that are to be aggregated - and passed to the global model for initialization. + """Collect, validate, aggregate and make use of clients' data-info. Raises ------ @@ -290,6 +285,19 @@ class FederatedServer: raising. """ fields = self.model.required_data_info # revise: add optimizer, etc. + if not fields: + return + # Collect required metadata from clients. + query = messaging.MetadataQuery(list(fields)) + await self.netwk.broadcast_message(query) + replies = await self._collect_results( + self.netwk.client_names, + msgtype=messaging.MetadataReply, + context="Metadata collection", + ) + clients_data_info = { + client: reply.data_info for client, reply in replies.items() + } # Try aggregating the input data_info. try: info = aggregate_clients_data_info(clients_data_info, fields) @@ -298,7 +306,7 @@ class FederatedServer: messages = { client: messaging.CancelTraining(reason) for client, reason in exc.messages.items() - } # type: Dict[str, messaging.Message] + } await self.netwk.send_messages(messages) self.logger.error(exc.error) raise exc @@ -339,13 +347,14 @@ class FederatedServer: replies = await self.netwk.wait_for_messages(clients) results = {} # type: Dict[str, MessageT] errors = {} # type: Dict[str, str] - for client, message in replies.items(): - if isinstance(message, msgtype): - results[client] = message - elif isinstance(message, messaging.Error): - errors[client] = f"{context} failed: {message.message}" + for client, reply in replies.items(): + if issubclass(reply.message_cls, msgtype): + results[client] = reply.deserialize() + elif issubclass(reply.message_cls, messaging.Error): + err_msg = reply.deserialize().message + errors[client] = f"{context} failed: {err_msg}" else: - errors[client] = f"Unexpected message: {message}" + errors[client] = f"Unexpected message: {reply.message_cls}" # If any client has failed to send proper results, raise. # future: modularize errors-handling behaviour if errors: @@ -390,7 +399,7 @@ class FederatedServer: self.logger.info("Waiting for clients' responses.") await self._collect_results( clients=self.netwk.client_names, - msgtype=messaging.GenericMessage, + msgtype=messaging.PrivacyReply, context="Privacy initialization", ) self.logger.info("Privacy requests were processed by clients.") diff --git a/declearn/messaging/__init__.py b/declearn/messaging/__init__.py index 2463a962aa14eedd626e142859dba8d0875339d2..36e3a96502d77ac6231db82bc9ba89e22df01f3d 100644 --- a/declearn/messaging/__init__.py +++ b/declearn/messaging/__init__.py @@ -35,7 +35,11 @@ Base messages * [EvaluationRequest][declearn.messaging.EvaluationRequest] * [GenericMessage][declearn.messaging.GenericMessage] * [InitRequest][declearn.messaging.InitRequest] +* [InitReply][declearn.messaging.InitReply] +* [MetadataQuery][declearn.messaging.MetadataQuery] +* [MetadataReply][declearn.messaging.MetadataReply] * [PrivacyRequest][declearn.messaging.PrivacyRequest] +* [PrivacyReply][declearn.messaging.PrivacyReply] * [StopTraining][declearn.messaging.StopTraining] * [TrainReply][declearn.messaging.TrainReply] * [TrainRequest][declearn.messaging.TrainRequest] @@ -53,7 +57,11 @@ from ._base import ( EvaluationRequest, GenericMessage, InitRequest, + InitReply, + MetadataQuery, + MetadataReply, PrivacyRequest, + PrivacyReply, StopTraining, TrainReply, TrainRequest, diff --git a/declearn/messaging/_base.py b/declearn/messaging/_base.py index 9f7a695609df48e11f41eefe3223d2925f751b43..f2f946e6cf0740f6f0f270174895236c6ef826b0 100644 --- a/declearn/messaging/_base.py +++ b/declearn/messaging/_base.py @@ -38,7 +38,11 @@ __all__ = [ "EvaluationRequest", "GenericMessage", "InitRequest", + "InitReply", + "MetadataQuery", + "MetadataReply", "PrivacyRequest", + "PrivacyReply", "StopTraining", "TrainReply", "TrainRequest", @@ -135,6 +139,31 @@ class InitRequest(Message): return cls(**kwargs) +@dataclasses.dataclass +class InitReply(Message): + """Client-emitted message indicating that initialization went fine.""" + + typekey = "init_reply" + + +@dataclasses.dataclass +class MetadataQuery(Message): + """Server-emitted request for metadata on a client's dataset.""" + + typekey = "metadata_query" + + fields: List[str] + + +@dataclasses.dataclass +class MetadataReply(Message): + """Client-emitted metadata in response to a server request.""" + + typekey = "metadata_reply" + + data_info: Dict[str, Any] + + @dataclasses.dataclass class PrivacyRequest(Message): """Server-emitted request to set up local differential privacy.""" @@ -156,6 +185,13 @@ class PrivacyRequest(Message): n_steps: Optional[int] +@dataclasses.dataclass +class PrivacyReply(Message): + """Client-emitted message indicating that DP setup went fine.""" + + typekey = "privacy_reply" + + @dataclasses.dataclass class StopTraining(Message): """Server-emitted notification that the training process is over."""