diff --git a/.gitignore b/.gitignore index 484ea2a8a37b23757e427944b7d1107ee6bceb3a..b5ae727ef68f7d6e03447bf5e50f4c500927d045 100644 --- a/.gitignore +++ b/.gitignore @@ -17,9 +17,10 @@ coverage.xml examples/*/*.csv examples/*/*.pem examples/*/*/ -# Documentation online rendering files. +# Documentation rendered files. public/ site/ docs/api-reference/*/ docs/api-reference/SUMMARY.md docs/api-reference/typing.md +docs/api-reference/version.md diff --git a/AUTHORS b/AUTHORS index b23a2041d0487cd5d38c2cd2a8adad2999b46453..9ed557a5d84b7e8447fe0de08146faa0632612b1 100644 --- a/AUTHORS +++ b/AUTHORS @@ -1,6 +1,9 @@ This file maintains the list of present and past declearn authors. A secondary file listing punctual open-source contributors may complement it. +Declearn 2.4 +- Paul Andrey + Declearn 2.1 - 2.3 - Paul Andrey - Nathan Bigaud diff --git a/declearn/aggregator/_api.py b/declearn/aggregator/_api.py index d8817dee83b5c027dd825b19940fddf36d654959..468b837b520c2cdd0e86a8d6b644e381c341ff9c 100644 --- a/declearn/aggregator/_api.py +++ b/declearn/aggregator/_api.py @@ -43,7 +43,23 @@ T = TypeVar("T") @dataclasses.dataclass class ModelUpdates(Aggregate, base_cls=True, register=True): - """Base dataclass for model updates' sharing and aggregation.""" + """Base dataclass for model updates' sharing and aggregation. + + Each and every `Aggregator` subclass is expected to be coupled with + one (or multiple) `ModelUpdates` (sub)type(s), that define which data + is exchanged and how it is aggregated across a network of peers. An + instance resulting from the aggregation of multiple peers' data may + be passed to an appropriate `Aggregator` instance for finalization + into a `Vector` of model updates. + + This class also defines whether contents are compatible with secure + aggregation, and whether some fields should remain in cleartext no + matter what. + + Note that subclasses are automatically type-registered, and should be + decorated as `dataclasses.dataclass`. To prevent registration, simply + pass `register=False` at inheritance. + """ updates: Vector weights: Union[int, float] diff --git a/declearn/communication/__init__.py b/declearn/communication/__init__.py index 97f875ae524fc2684e9bbcc2d0a00c0d096a2a50..93a508e0e6e2ab8cedadae70efd3a24be1e8fd9b 100644 --- a/declearn/communication/__init__.py +++ b/declearn/communication/__init__.py @@ -21,12 +21,15 @@ This is done by defining server-side and client-side network communication endpoints for federated learning processes, as well as suitable messages to be transmitted, and the available communication protocols. -This module contains the following core submodule: +This module contains the following core submodules: * [api][declearn.communication.api]: Base API to define client- and server-side communication endpoints. +* [utils][declearn.communication.utils]: + Utils related to network communication endpoints' setup and usage. -It also exposes the following core utility functions and dataclasses: + +It re-exports publicly from `utils` the following elements: * [build_client][declearn.communication.build_client]: Instantiate a NetworkClient, selecting its subclass based on protocol name. @@ -57,7 +60,8 @@ longer be used, as its contents were re-dispatched elsewhere in DecLearn. # Messaging API and base tools: from . import api -from ._build import ( +from . 
import utils +from .utils import ( _INSTALLABLE_BACKENDS, NetworkClientConfig, NetworkServerConfig, diff --git a/declearn/communication/api/backend/flags.py b/declearn/communication/api/backend/flags.py index c9400d60be33ce9c9c6e4b6744da058da5c06503..98d9d23861013a9251535d211194e9e4c766cf5e 100644 --- a/declearn/communication/api/backend/flags.py +++ b/declearn/communication/api/backend/flags.py @@ -26,27 +26,32 @@ from declearn.version import VERSION __all__ = [ "CHECK_MESSAGE_TIMEOUT", "INVALID_MESSAGE", - "REGISTERED_WELCOME", "REGISTERED_ALREADY", - "REGISTRATION_UNSTARTED", - "REGISTRATION_OPEN", + "REGISTERED_WELCOME", "REGISTRATION_CLOSED", + "REGISTRATION_OPEN", + "REGISTRATION_UNSTARTED", + "REJECT_INCOMPATIBLE_VERSION", "REJECT_UNREGISTERED", + "REJECT_UNREGISTERED_CHUNKED", ] # Registration flags. -REGISTRATION_UNSTARTED = "registration is not opened yet" -REGISTRATION_OPEN = "registration is open" -REGISTRATION_CLOSED = "registration is closed" -REGISTERED_WELCOME = "welcome, you are now registered" REGISTERED_ALREADY = "you were already registered" +REGISTERED_WELCOME = "welcome, you are now registered" +REGISTRATION_CLOSED = "registration is closed" +REGISTRATION_OPEN = "registration is open" +REGISTRATION_UNSTARTED = "registration is not opened yet" # Error flags. CHECK_MESSAGE_TIMEOUT = "no available message at timeout" INVALID_MESSAGE = "invalid message" -REJECT_UNREGISTERED = "rejected: not a registered user" REJECT_INCOMPATIBLE_VERSION = ( "cannot register due to the DecLearn version in use; " f"please update to `declearn ~= {VERSION}`" ) +REJECT_UNREGISTERED = "rejected: not a registered user" +REJECT_UNREGISTERED_CHUNKED = ( + "chunked messages from unregistered clients are not allowed" +) diff --git a/declearn/communication/grpc/_server.py b/declearn/communication/grpc/_server.py index eb50e92f82ffe3ab75315d9be62642de4adf17b6..d46984a91b6d8cedc975514ad9f44053d6eefff2 100644 --- a/declearn/communication/grpc/_server.py +++ b/declearn/communication/grpc/_server.py @@ -27,14 +27,12 @@ import grpc # type: ignore from cryptography.hazmat.primitives import serialization from declearn.communication.api import NetworkServer -from declearn.communication.api.backend import MessagesHandler +from declearn.communication.api.backend import MessagesHandler, actions, flags from declearn.communication.grpc.protobufs import message_pb2 from declearn.communication.grpc.protobufs.message_pb2_grpc import ( MessageBoardServicer, add_MessageBoardServicer_to_server, ) -from declearn.communication.messaging import Error - __all__ = [ "GrpcServer", @@ -223,9 +221,7 @@ class GrpcServicer(MessageBoardServicer): # async is needed; pylint: disable=invalid-overridden-method # Case when an unknown peer attempts sending a stream: send an error. if context.peer() not in self.handler.registered_clients: - error = Error( - "Chunked messages from unregistered clients are not allowed." 
- ) + error = actions.Reject(flags.REJECT_UNREGISTERED_CHUNKED) self.handler.logger.warning( "Refused a chunks-streaming request from client %s", context.peer(), diff --git a/declearn/communication/utils/__init__.py b/declearn/communication/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6b2b8a783bcfc1d58bfecba18806bb41591782fd --- /dev/null +++ b/declearn/communication/utils/__init__.py @@ -0,0 +1,65 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utils related to network communication endpoints' setup and usage. + + +Endpoints setup utils +--------------------- + +* [build_client][declearn.communication.utils.build_client]: + Instantiate a NetworkClient, selecting its subclass based on protocol name. +* [build_server][declearn.communication.utils.build_server]: + Instantiate a NetworkServer, selecting its subclass based on protocol name. +* [list_available_protocols]\ +[declearn.communication.utils.list_available_protocols]: + Return the list of readily-available network protocols. +* [NetworkClientConfig][declearn.communication.utils.NetworkClientConfig]: + TOML-parsable dataclass for network clients' instantiation. +* [NetworkServerConfig][declearn.communication.utils.NetworkServerConfig]: + TOML-parsable dataclass for network servers' instantiation. + + +Message-type control utils +-------------------------- + +* [ErrorMessageException][declearn.communication.utils.ErrorMessageException]: + Exception raised when an unexpected 'Error' message is received. +* [MessageTypeException][declearn.communication.utils.MessageTypeException]: + Exception raised when a received 'Message' has wrong type. +* [verify_client_messages_validity]\ +[declearn.communication.utils.verify_client_messages_validity]: + Verify that received serialized messages match an expected type. +* [verify_server_message_validity]\ +[declearn.communication.utils.verify_server_message_validity]: + Verify that a received serialized message matches expected type. 
+""" + +from ._build import ( + _INSTALLABLE_BACKENDS, + NetworkClientConfig, + NetworkServerConfig, + build_client, + build_server, + list_available_protocols, +) +from ._parse import ( + ErrorMessageException, + MessageTypeException, + verify_client_messages_validity, + verify_server_message_validity, +) diff --git a/declearn/communication/_build.py b/declearn/communication/utils/_build.py similarity index 100% rename from declearn/communication/_build.py rename to declearn/communication/utils/_build.py diff --git a/declearn/communication/utils/_parse.py b/declearn/communication/utils/_parse.py new file mode 100644 index 0000000000000000000000000000000000000000..17bd70b37fbde913ca50a81687a505533769faa3 --- /dev/null +++ b/declearn/communication/utils/_parse.py @@ -0,0 +1,166 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utils to type-check received messages from the server or clients.""" + +from typing import Dict, Type, TypeVar + + +from declearn.communication.api import NetworkClient, NetworkServer +from declearn.messaging import Error, Message, SerializedMessage + + +__all__ = [ + "ErrorMessageException", + "MessageTypeException", + "verify_client_messages_validity", + "verify_server_message_validity", +] + + +class ErrorMessageException(Exception): + """Exception raised when an unexpected 'Error' message is received.""" + + +class MessageTypeException(Exception): + """Exception raised when a received 'Message' has wrong type.""" + + +MessageT = TypeVar("MessageT", bound=Message) + + +async def verify_client_messages_validity( + netwk: NetworkServer, + received: Dict[str, SerializedMessage], + expected: Type[MessageT], +) -> Dict[str, MessageT]: + """Verify that received serialized messages match an expected type. + + - If all received messages matches expected type, deserialize them. + - If any received message is an unexpected `Error` message, send an + `Error` to non-error-send clients, then raise. + - If any received message belongs to any other type, send an `Error` + to each and every client, then raise. + + Parameters + ---------- + netwk: + `NetworkClient` endpoint, from which the processed message + was received. + received: + Received `SerializedMessage` to type-check and deserialize. + expected: + Expected `Message` subtype. Any subclass will be considered + as valid. + + Returns + ------- + messages: + Deserialized messages from `received`, with `expected` type, + wrapped as a `{client_name: client_message}` dict. + + Raises + ------ + ErrorMessageException + If any `received` message wraps an unexpected `Error` message. + MessageTypeException + If any `received` wrapped message does not match `expected` type. + """ + # Iterate over received messages to identify any unexpected 'Error' ones + # or unexpected-type message. 
+ wrong_types = "" + unexp_errors = {} # type: Dict[str, str] + for client, srm in received.items(): + if issubclass(srm.message_cls, expected): + pass + elif issubclass(srm.message_cls, Error): + unexp_errors[client] = srm.deserialize().message + else: + wrong_types += f"\n\t{client}: '{srm.message_cls}'" + # In case of Error messages, send an Error to other clients and raise. + if unexp_errors: + await netwk.broadcast_message( + Error("Some clients reported errors."), + clients=set(received).difference(unexp_errors), + ) + error = "".join( + f"\n\t{key}:{val}" for key, val in unexp_errors.items() + ) + raise ErrorMessageException( + f"Expected '{expected.__name__}' messages, got the following " + f"Error messages:{error}" + ) + # In case of improper messages, send an Error to all clients and raise. + if wrong_types: + error = ( + f"Expected '{expected.__name__}' messages, got the following " + f"improper message types:{wrong_types}" + ) + await netwk.broadcast_message(Error(error), clients=set(received)) + raise MessageTypeException(error) + # If everything is fine, deserialize and return the received messages. + return {cli: srm.deserialize() for cli, srm in received.items()} + + +async def verify_server_message_validity( + netwk: NetworkClient, + received: SerializedMessage, + expected: Type[MessageT], +) -> MessageT: + """Verify that a received serialized message matches expected type. + + - If the received message matches expected type, deserialize it. + - If the received message is an unexpected `Error` message, raise. + - If it belongs to any other type, send an `Error` to the server, + then raise. + + Parameters + ---------- + netwk: + `NetworkClient` endpoint, from which the processed message + was received. + received: + Received `SerializedMessage` to type-check and deserialize. + expected: + Expected `Message` subtype. Any subclass will be considered + as valid. + + Returns + ------- + message: + Deserialized `Message` from `received`, with `expected` type. + + Raises + ------ + ErrorMessageException + If `received` wraps an unexpected `Error` message. + MessageTypeException + If `received` wrapped message does not match `expected` type. + """ + # If a proper message is received, deserialize and return it. + if issubclass(received.message_cls, expected): + return received.deserialize() + # When an Error is received, merely raise using its content. + error = f"Expected a '{expected}' message" + if issubclass(received.message_cls, Error): + msg = received.deserialize() + error = f"{error}, received an Error message: '{msg.message}'." + raise ErrorMessageException(error) + # Otherwise, send an Error to the server, then raise. + error = f"{error}, got a '{received.message_cls}'." + await netwk.send_message(Error(error)) + raise MessageTypeException(error) diff --git a/declearn/data_info/__init__.py b/declearn/data_info/__init__.py index 72470e2e2d84828d2fee74ba51d96e958413210c..484f0ddfce351784f7991cf8146c3bb87edda4bd 100644 --- a/declearn/data_info/__init__.py +++ b/declearn/data_info/__init__.py @@ -52,14 +52,6 @@ Field specifications Specification for the 'features_shape' field. * [NbSamplesField][declearn.data_info.NbSamplesField]: Specification for the 'n_samples' field. - -Deprecated field specifications -------------------------------- - -* [InputShapeField][declearn.data_info.InputShapeField]: - Deprecacted as of v2.2 in favor of FeaturesShapeField. -* [NbFeaturesField][declearn.data_info.NbFeaturesField]: - Deprecacted as of v2.2 in favor of FeaturesShapeField. 
""" from ._base import ( @@ -71,8 +63,6 @@ from ._base import ( from ._fields import ( ClassesField, DataTypeField, - InputShapeField, FeaturesShapeField, - NbFeaturesField, NbSamplesField, ) diff --git a/declearn/data_info/_fields.py b/declearn/data_info/_fields.py index b46fec8204d2e7fca512427c75827c72f6df8ea0..1fc4d78d0531a9d447a729bcf1d8f7ab48de0836 100644 --- a/declearn/data_info/_fields.py +++ b/declearn/data_info/_fields.py @@ -17,8 +17,7 @@ """DataInfoField subclasses specifying common 'data_info' metadata fields.""" -import warnings -from typing import Any, List, Optional, Set, Tuple +from typing import Any, Optional, Set, Tuple import numpy as np @@ -29,8 +28,6 @@ __all__ = [ "DataTypeField", "FeaturesShapeField", "NbSamplesField", - "InputShapeField", # deprecated as of v2.2 - "NbFeaturesField", # deprecated as of v2.2 ] @@ -154,104 +151,3 @@ class NbSamplesField(DataInfoField): ) -> int: super().combine(*values) # type-check inputs return sum(values) - - -# Deprecated fields - - -@register_data_info_field -class InputShapeField(DataInfoField): # pragma: no cover - """Specifications for 'input_shape' data_info field.""" - - field = "input_shape" - types = (tuple, list) - doc = "DEPRECATED - Input features' batched shape, checked to be equal." - - @classmethod - def is_valid( - cls, - value: Any, - ) -> bool: - return ( - isinstance(value, cls.types) - and (len(value) >= 2) - and all(isinstance(val, int) or (val is None) for val in value) - ) - - @classmethod - def combine( - cls, - *values: Any, - ) -> List[Optional[int]]: - # Warn about this class being deprecated. - warnings.warn( - "'NbFeaturesField has been deprecated as of declearn v2.2," - " and will be removed in v2.4 and/or v3.0." - " Please use 'SingleInputShapeField' instead.", - DeprecationWarning, - stacklevel=3, - ) - # Type check each and every input shape. - super().combine(*values) - # Check that all shapes are of same length. - unique = list({len(shp) for shp in values}) - if len(unique) != 1: - raise ValueError( - f"Cannot combine '{cls.field}': inputs have various lengths." - ) - # Fill-in the unified shape: except all-None or (None or unique) value. - # Note: batching dimension is set to None by default (no check). - shape = [None] * unique[0] # type: List[Optional[int]] - for i in range(1, unique[0]): - val = [shp[i] for shp in values if shp[i] is not None] - if not val: # all None - shape[i] = None - elif len(set(val)) > 1: - raise ValueError( - f"Cannot combine '{cls.field}': provided shapes differ." - ) - else: - shape[i] = val[0] - # Return the combined shape. - return shape - - -@register_data_info_field -class NbFeaturesField(DataInfoField): # pragma: no cover - """Deprecated specifications for 'n_features' data_info field.""" - - field = "n_features" - types = (int,) - doc = "DEPRECATED - Number of input features, checked to be equal." - - @classmethod - def is_valid( - cls, - value: Any, - ) -> bool: - return isinstance(value, int) and (value > 0) - - @classmethod - def combine( - cls, - *values: Any, - ) -> int: - # Warn about this class being deprecated. - warnings.warn( - "'NbFeaturesField has been deprecated as of declearn v2.2," - " and will be removed in v2.4 and/or v3.0." - " Please use 'SingleInputShapeField' instead.", - DeprecationWarning, - stacklevel=3, - ) - # Perform the values' combination. - unique = list(set(values)) - if len(unique) != 1: - raise ValueError( - f"Cannot combine '{cls.field}': non-unique inputs." 
- ) - if not cls.is_valid(unique[0]): - raise ValueError( - f"Cannot combine '{cls.field}': invalid unique value." - ) - return unique[0] diff --git a/declearn/dataset/__init__.py b/declearn/dataset/__init__.py index bc6504024ffe5291e44b8a8c507e5fd131548329..7436e1ceacea4b5235fbf4f2d97bce198f3d386c 100644 --- a/declearn/dataset/__init__.py +++ b/declearn/dataset/__init__.py @@ -30,7 +30,7 @@ API tools * [DataSpec][declearn.dataset.DataSpecs]: Dataclass to wrap a dataset's metadata. * [load_dataset_from_json][declearn.dataset.load_dataset_from_json] - Utility function to parse a JSON into a dataset object. + DEPRECATED Utility function to parse a JSON into a dataset object. Dataset subclasses ------------------ diff --git a/declearn/dataset/_base.py b/declearn/dataset/_base.py index 004986ac8f24302ade1129e7c6baa247aa4bc15a..c0f3bbe7fe9a6719b8823c2f417e647f0e3cea51 100644 --- a/declearn/dataset/_base.py +++ b/declearn/dataset/_base.py @@ -17,9 +17,9 @@ """Dataset abstraction API.""" +import abc +import dataclasses import warnings -from abc import ABCMeta, abstractmethod -from dataclasses import dataclass from typing import Any, Iterator, List, Optional, Set, Tuple, Union from declearn.typing import Batch @@ -31,17 +31,9 @@ __all__ = [ ] -@dataclass +@dataclasses.dataclass class DataSpecs: - """Dataclass to wrap a dataset's metadata. - - Note - ---- - The `n_features` attribute has been deprecated as of declearn 2.2 - and will be removed in v2.4 and/or v3.0. It should therefore not - be used, whether at instantiation or afterwards. Please use the - `features_shape` attribute instead. - """ + """Dataclass to wrap a dataset's metadata.""" n_samples: int features_shape: Optional[ @@ -49,37 +41,10 @@ class DataSpecs: ] = None classes: Optional[Set[Any]] = None data_type: Optional[str] = None - n_features: Optional[int] = None # DEPRECATED as of declearn v2.2 - - def __post_init__(self): # pragma: no cover - # future: remove this (declearn >=2.4) - if isinstance(self.features_shape, int): - self.features_shape = (self.features_shape,) - warnings.warn( - "'features_shape' has replaced now-deprecated 'n_features'" - " and should therefore be passed as a tuple or list.", - RuntimeWarning, - stacklevel=3, - ) - if self.n_features is not None: - warnings.warn( - "'DataSepc.n_features' has been deprecated as of declearn v2.2" - " and should therefore no longer be used. It will be removed" - " in v2.4 and/or v3.0.", - RuntimeWarning, - stacklevel=3, - ) - if self.features_shape[-1] != self.n_features: - raise ValueError( - "Both 'features_shape' and deprecated 'n_features' were " - "passed to 'DataSpecs.__init__', with incoherent values." - ) - if self.features_shape: - self.n_features = self.features_shape[-1] @create_types_registry -class Dataset(metaclass=ABCMeta): +class Dataset(metaclass=abc.ABCMeta): """Abstract class defining an API to access training or testing data. A 'Dataset' is an interface towards data that exposes methods @@ -93,13 +58,13 @@ class Dataset(metaclass=ABCMeta): straightforward to specify as part of FL algorithms. """ - @abstractmethod + @abc.abstractmethod def get_data_specs( self, ) -> DataSpecs: """Return a DataSpecs object describing this dataset.""" - @abstractmethod + @abc.abstractmethod def generate_batches( # pylint: disable=too-many-arguments self, batch_size: int, @@ -146,8 +111,8 @@ class Dataset(metaclass=ABCMeta): """ -def load_dataset_from_json(path: str) -> Dataset: - """Instantiate a dataset based on a JSON dump file. 
+def load_dataset_from_json(path: str) -> Dataset: # pragma: no cover + """DEPRECATED Instantiate a dataset based on a JSON dump file. Parameters ---------- @@ -161,7 +126,25 @@ def load_dataset_from_json(path: str) -> Dataset: ------- dataset: Dataset Dataset (subclass) instance, reloaded from JSON. + + Raises + ------ + NotImplementedError + If the target `Dataset` does not implement a `load_from_json` + method (which was removed from the API in DecLearn 2.3.0). """ + warnings.warn( + "'load_dataset_from_json' was deprecated in Declearn 2.4.0, after" + "'Dataset.load_from_json' was removed from the API in v2.3.0. It " + "may raise a 'NotImplementedError', and will be removed in DecLearn " + "2.6 and/or 3.0.", + category=DeprecationWarning, + stacklevel=2, + ) dump = json_load(path) cls = access_registered(dump["name"], group="Dataset") + if not hasattr(cls, "load_from_json"): + raise NotImplementedError( + f"Dataset class '{cls}' does not implement 'load_from_json'." + ) return cls.load_from_json(path) diff --git a/declearn/dataset/_inmemory.py b/declearn/dataset/_inmemory.py index 1bd69fe735626ed6c41cabcd639014ea20601be1..650f2c88f402af14a49f624501285b9d978c5ff3 100644 --- a/declearn/dataset/_inmemory.py +++ b/declearn/dataset/_inmemory.py @@ -18,7 +18,6 @@ """Dataset implementation to serve scikit-learn compatible in-memory data.""" import os -import warnings from typing import Any, Dict, Iterator, List, Optional, Set, Union import numpy as np @@ -222,48 +221,6 @@ class InMemoryDataset(Dataset): f"Invalid 'data' attribute type: '{type(self.target)}'." ) - @staticmethod - def load_data_array( - path: str, - **kwargs: Any, - ) -> DataArray: # pragma: no cover - """Load a data array from a dump file. - - As of declearn v2.2, this staticmethod is DEPRECATED in favor of - `declearn.dataset.utils.load_data_array`, which is now calls. It - will be removed in v2.4 and/or v3.0. - - See [declearn.dataset.utils.load_data_array][] for more details. - """ - warnings.warn( - "'InMemoryDataset.load_data_array' has been deprecated in favor" - " of `declearn.dataset.utils.load_data_array`. It will be removed" - " in version 2.4 and/or 3.0.", - category=DeprecationWarning, - ) - return load_data_array(path, **kwargs) - - @staticmethod - def save_data_array( - path: str, - array: Union[DataArray, pd.Series], - ) -> str: # pragma: no cover - """Save a data array to a dump file. - - As of declearn v2.2, this staticmethod is DEPRECATED in favor of - `declearn.dataset.utils.save_data_array`, which is now calls. It - will be removed in v2.4 and/or v3.0. - - See [declearn.dataset.utils.save_data_array][] for more details. - """ - warnings.warn( - "'InMemoryDataset.save_data_array' has been deprecated in favor" - " of `declearn.dataset.utils.save_data_array`. 
It will be removed" - " in version 2.4 and/or 3.0.", - category=DeprecationWarning, - ) - return save_data_array(path, array) - @classmethod def from_svmlight( cls, diff --git a/declearn/dataset/tensorflow/_tensorflow.py b/declearn/dataset/tensorflow/_tensorflow.py index 070d2f76d81d3a07c6f0a32ed074576bc39f0e11..02bf602d95660088a412c7882f7ef1bfb7cee11e 100644 --- a/declearn/dataset/tensorflow/_tensorflow.py +++ b/declearn/dataset/tensorflow/_tensorflow.py @@ -334,7 +334,7 @@ def get_batch_function( def get_stack_function( batch_mode: BatchMode, -) -> Callable[[Union[List[None], List[tf.Tensor]]], Optional[tf.Tensor],]: +) -> Callable[[Union[List[None], List[tf.Tensor]]], Optional[tf.Tensor]]: """Return a function to stack sample-wise atomic elements.""" if batch_mode == "default": return _stack_default diff --git a/declearn/main/_client.py b/declearn/main/_client.py index 4f51ab72aab38aad7c3008465b862af848ce691b..75491a6625151e40d9a0962e5670d9a6dc9ee1dd 100644 --- a/declearn/main/_client.py +++ b/declearn/main/_client.py @@ -26,8 +26,11 @@ from typing import Any, Dict, Optional, Union import numpy as np from declearn import messaging -from declearn.communication import NetworkClientConfig from declearn.communication.api import NetworkClient +from declearn.communication.utils import ( + NetworkClientConfig, + verify_server_message_validity, +) from declearn.dataset import Dataset, load_dataset_from_json from declearn.main.utils import Checkpointer, TrainingManager from declearn.utils import LOGGING_LEVEL_MAJOR, get_logger @@ -63,12 +66,12 @@ class FederatedClient: In the latter three cases, the object's default logger will be set to that of this `FederatedClient`. train_data: Dataset or str - Dataset instance wrapping the training data, or path to - a JSON file from which it can be instantiated. + Dataset instance wrapping the training data. + (DEPRECATED) May be a path to a JSON dump file. valid_data: Dataset or str or None - Optional Dataset instance wrapping validation data, or - path to a JSON file from which it can be instantiated. + Optional Dataset instance wrapping validation data. If None, run evaluation rounds over `train_data`. + (DEPRECATED) May be a path to a JSON dump file. checkpoint: Checkpointer or dict or str or None, default=None Optional Checkpointer instance or instantiation dict to be used so as to save round-wise model, optimizer and metrics. @@ -249,24 +252,43 @@ class FederatedClient: """ # Await initialization instructions. self.logger.info("Awaiting initialization instructions from server.") - message = await self.netwk.recv_message() + received = await self.netwk.recv_message() # If a MetadataQuery is received, process it, then await InitRequest. - if message.message_cls is messaging.MetadataQuery: - await self._collect_and_send_metadata(message.deserialize()) - message = await self.netwk.recv_message() + if issubclass(received.message_cls, messaging.MetadataQuery): + await self._collect_and_send_metadata(received.deserialize()) + received = await self.netwk.recv_message() + # Ensure that an 'InitRequest' was received. + message = await verify_server_message_validity( + self.netwk, received, expected=messaging.InitRequest + ) # Perform initialization, catching errors to report them to the server. 
try: - if not issubclass(message.message_cls, messaging.InitRequest): - raise TypeError( - f"Awaited InitRequest message, got '{message.message_cls}'" - ) - await self._initialize_trainmanager(message.deserialize()) + self.trainmanager = TrainingManager( + model=message.model, + optim=message.optim, + aggrg=message.aggrg, + train_data=self.train_data, + valid_data=self.valid_data, + metrics=message.metrics, + logger=self.logger, + verbose=self.verbose, + ) except Exception as exc: await self.netwk.send_message(messaging.Error(repr(exc))) raise RuntimeError("Initialization failed.") from exc + # If instructed to do so, run additional steps to set up DP-SGD. + if message.dpsgd: + await self._initialize_dpsgd() # Send back an empty message to indicate that all went fine. self.logger.info("Notifying the server that initialization went fine.") await self.netwk.send_message(messaging.InitReply()) + # Optionally checkpoint the received model and optimizer. + if self.ckptr: + self.ckptr.checkpoint( + model=self.trainmanager.model, + optimizer=self.trainmanager.optim, + first_call=True, + ) async def _collect_and_send_metadata( self, @@ -286,37 +308,6 @@ class FederatedClient: ) await self.netwk.send_message(messaging.MetadataReply(data_info)) - async def _initialize_trainmanager( - self, - message: messaging.InitRequest, - ) -> None: - """Set up a TrainingManager based on server instructions. - - - Also await and set up DP constraints if instructed to do so. - - Checkpoint the model and optimizer if configured to do so. - """ - # Wrap up the model and optimizer received from the server. - self.trainmanager = TrainingManager( - model=message.model, - optim=message.optim, - aggrg=message.aggrg, - train_data=self.train_data, - valid_data=self.valid_data, - metrics=message.metrics, - logger=self.logger, - verbose=self.verbose, - ) - # If instructed to do so, await a PrivacyRequest to set up DP-SGD. - if message.dpsgd: - await self._initialize_dpsgd() - # Optionally checkpoint the received model and optimizer. - if self.ckptr: - self.ckptr.checkpoint( - model=self.trainmanager.model, - optimizer=self.trainmanager.optim, - first_call=True, - ) - async def _initialize_dpsgd( self, ) -> None: @@ -325,12 +316,13 @@ class FederatedClient: This method wraps the `make_private` one in the context of `initialize` and should never be called in another context. 
""" - message = await self.netwk.recv_message() - if not isinstance(message, messaging.PrivacyRequest): - msg = f"Expected a PrivacyRequest but received a '{type(message)}'" - self.logger.error(msg) - await self.netwk.send_message(messaging.Error(msg)) - raise RuntimeError(f"DP-SGD initialization failed: {msg}.") + received = await self.netwk.recv_message() + try: + message = await verify_server_message_validity( + self.netwk, received, expected=messaging.PrivacyRequest + ) + except Exception as exc: + raise RuntimeError("DP-SGD initialization failed.") from exc self.logger.info("Received a request to set up DP-SGD.") try: self.make_private(message) diff --git a/declearn/main/_server.py b/declearn/main/_server.py index 8dd9f48b6bf5625907fff31b51cfcc02e4377942..865efb5e6b71222cf690fb3ac25cb0b7e66571ba 100644 --- a/declearn/main/_server.py +++ b/declearn/main/_server.py @@ -18,6 +18,7 @@ """Server-side main Federated Learning orchestrating class.""" import asyncio +import copy import dataclasses import logging from typing import Any, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union @@ -150,6 +151,8 @@ class FederatedServer: # Set up private attributes to record the loss values and best weights. self._loss = {} # type: Dict[int, float] self._best = None # type: Optional[Vector] + # Set up a private attribute to prevent redundant weights sharing. + self._clients_holding_latest_model = set() # type: Set[str] def run( self, @@ -453,13 +456,48 @@ class FederatedServer: TrainingConfig dataclass instance wrapping data-batching and computational effort constraints hyper-parameters. """ - message = messaging.TrainRequest( + # Set up the base training request. + msg_light = messaging.TrainRequest( round_i=round_i, - weights=self.model.get_weights(trainable=True), + weights=None, aux_var=self.optim.collect_aux_var(), **train_cfg.message_params, ) - await self.netwk.broadcast_message(message, clients) + # Send it to clients, sparingly joining model weights. + await self._send_request_with_optional_weights(msg_light, clients) + + async def _send_request_with_optional_weights( + self, + msg_light: Union[messaging.TrainRequest, messaging.EvaluationRequest], + clients: Set[str], + ) -> None: + """Send a request to clients, sparingly adding model weights to it. + + Transmit the input message to all clients, adding a copy of the + global model weights for clients that do not already hold them. + + Parameters + ---------- + msg_light: + Message to send, with a 'weights' field left to None. + clients: + Name of the clients to whom the message is adressed. + """ + # Identify clients that do not already hold latest model weights. + needs_weights = clients.difference(self._clients_holding_latest_model) + # If any client does not hold latest weights, ensure they get it. + if needs_weights: + msg_heavy = copy.copy(msg_light) + msg_heavy.weights = self.model.get_weights(trainable=True) + messages = { + client: msg_heavy if client in needs_weights else msg_light + for client in clients + } + await self.netwk.send_messages(messages) + self._clients_holding_latest_model.update(needs_weights) + # If no client requires weights, do not even access them. + else: + await self.netwk.broadcast_message(msg_light, clients) def _conduct_global_update( self, @@ -482,6 +520,8 @@ class FederatedServer: updates = sum(msg.updates for msg in results.values()) gradients = self.aggrg.finalize_updates(updates) self.optim.apply_gradients(self.model, gradients) + # Record that no clients hold the updated model. 
+ self._clients_holding_latest_model.clear() async def evaluation_round( self, @@ -547,12 +587,14 @@ class FederatedServer: EvaluateConfig dataclass instance wrapping data-batching and computational effort constraints hyper-parameters. """ - message = messaging.EvaluationRequest( + # Set up the base evaluation request. + msg_light = messaging.EvaluationRequest( round_i=round_i, - weights=self.model.get_weights(trainable=True), + weights=None, **valid_cfg.message_params, ) - await self.netwk.broadcast_message(message, clients) + # Send it to clients, sparingly joining model weights. + await self._send_request_with_optional_weights(msg_light, clients) def _aggregate_evaluation_results( self, diff --git a/declearn/main/utils/_training.py b/declearn/main/utils/_training.py index 4aa02369de55b190cd0f7619686bd436732ec547..09182c7bb22324bd2e8e621046b14abf737738cf 100644 --- a/declearn/main/utils/_training.py +++ b/declearn/main/utils/_training.py @@ -177,7 +177,11 @@ class TrainingManager: """Backend to `training_round`, without exception capture hooks.""" # Unpack and apply model weights and optimizer auxiliary variables. self.logger.info("Applying server updates to local objects.") - self.model.set_weights(message.weights, trainable=True) + if message.weights is None: + start_weights = self.model.get_weights(trainable=True) + else: + start_weights = message.weights + self.model.set_weights(start_weights, trainable=True) self.optim.process_aux_var(message.aux_var) self.optim.start_round() # trigger loss regularizer's `on_round_start` # Train under instructed effort constraints. @@ -190,7 +194,7 @@ class TrainingManager: # Compute and preprocess model updates and collect auxiliary variables. self.logger.info("Packing local updates to be sent to the server.") updates = self.aggrg.prepare_for_sharing( - updates=message.weights - self.model.get_weights(trainable=True), + updates=start_weights - self.model.get_weights(trainable=True), n_steps=int(effort["n_steps"]), ) aux_var = self.optim.collect_aux_var() @@ -331,8 +335,8 @@ class TrainingManager: ) -> messaging.EvaluationReply: """Backend to `evaluation_round`, without exception capture hooks.""" # Update the model's weights and evaluate on the local dataset. - # Revise: make the weights' update optional. - self.model.set_weights(message.weights, trainable=True) + if message.weights is not None: + self.model.set_weights(message.weights, trainable=True) metrics, states, effort = self.evaluate_under_constraints( message.batches, message.n_steps, message.timeout ) diff --git a/declearn/messaging/__init__.py b/declearn/messaging/__init__.py index 36e3a96502d77ac6231db82bc9ba89e22df01f3d..17f5e5e42fd6be0fbe6566dbcd68817937447f67 100644 --- a/declearn/messaging/__init__.py +++ b/declearn/messaging/__init__.py @@ -22,7 +22,7 @@ Message API and tools * [Message][declearn.messaging.Message]: Abstract base dataclass to define parsable messages. -* [SerializedMessage][declearn.messaging.SerializedMessage]]: +* [SerializedMessage][declearn.messaging.SerializedMessage]: Container for serialized Message instances. 
diff --git a/declearn/messaging/_base.py b/declearn/messaging/_base.py index f2f946e6cf0740f6f0f270174895236c6ef826b0..88dae220dabbd15a53d21743c8dc8816b103e601 100644 --- a/declearn/messaging/_base.py +++ b/declearn/messaging/_base.py @@ -74,7 +74,7 @@ class EvaluationRequest(Message): typekey = "eval_request" round_i: int - weights: Vector + weights: Optional[Vector] batches: Dict[str, Any] n_steps: Optional[int] timeout: Optional[int] @@ -210,7 +210,7 @@ class TrainRequest(Message): typekey = "train_request" round_i: int - weights: Vector + weights: Optional[Vector] aux_var: Dict[str, AuxVar] batches: Dict[str, Any] n_epoch: Optional[int] = None diff --git a/declearn/metrics/__init__.py b/declearn/metrics/__init__.py index dd5269697017e16d8a262f79bd09980924cf496c..445dbfe0be659711708be35b67a6f2c4e7d5071a 100644 --- a/declearn/metrics/__init__.py +++ b/declearn/metrics/__init__.py @@ -25,7 +25,7 @@ Abstractions ------------ * [Metric][declearn.metrics.Metric]: Abstract base class defining an API for metrics' computation. -* [MetricState][declearn.metric.MetricState]: +* [MetricState][declearn.metrics.MetricState]: Abstract base class for Metrics intermediate aggregatable states. * [MeanMetric][declearn.metrics.MeanMetric]: Abstract class that defines a template for simple scores' averaging. diff --git a/declearn/metrics/_api.py b/declearn/metrics/_api.py index e655a8c3af3ced33e0a3fd474c5312c4c66143ec..04405777044158eae800e14dfb8f20073cc9847f 100644 --- a/declearn/metrics/_api.py +++ b/declearn/metrics/_api.py @@ -41,7 +41,21 @@ __all__ = [ class MetricState( Aggregate, base_cls=True, register=False, metaclass=abc.ABCMeta ): - """Abstract base class for Metrics intermediate aggregatable states.""" + """Abstract base class for Metrics intermediate aggregatable states. + + Each and every `Metric` subclass is expected to be coupled with one + (or multiple) `MetricState` subtypes, which are used to exchange and + aggregate partial results across a network of peers, which can in the + end be passed to a single `Metric` instance for metrics' finalization. + + This class also defines whether contents are compatible with secure + aggregation, and whether some fields should remain in cleartext no + matter what. + + Note that subclasses are automatically type-registered, and should be + decorated as `dataclasses.dataclass`. To prevent registration, simply + pass `register=False` at inheritance. + """ _group_key = "MetricState" diff --git a/declearn/model/sklearn/_sgd.py b/declearn/model/sklearn/_sgd.py index 160a6d3970295c4b2e6c6fd66653b8b8f945b07d..17263da5a57ced83ff95517307ddf0dfda23bfbb 100644 --- a/declearn/model/sklearn/_sgd.py +++ b/declearn/model/sklearn/_sgd.py @@ -489,12 +489,13 @@ class SklearnSGDModel(Model): ) -> Callable[[np.ndarray, np.ndarray], np.ndarray]: """Return a function to compute point-wise loss for a given batch.""" # fmt: off - # Gather / instantiate a loss function from the wrapped model's specs. - if hasattr(self._model, "loss_function_"): - loss_smp = self._model.loss_function_.py_loss - else: - loss_cls, *args = self._model.loss_functions[self._model.loss] - loss_smp = loss_cls(*args).py_loss + # Instantiate a loss function from the wrapped model's specs. + loss_cls, *args = self._model.loss_functions[self._model.loss] + if self._model.loss in ( + "huber", "epsilon_insensitive", "squared_epsilon_insensitive" + ): + args = (self._model.epsilon,) + loss_smp = loss_cls(*args).py_loss # Wrap it to support batched inputs. 
def loss_1d(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray: return np.array([loss_smp(*smp) for smp in zip(y_pred, y_true)]) diff --git a/declearn/model/tensorflow/_model.py b/declearn/model/tensorflow/_model.py index 812fad761cae780e58386253bd464697de6c69dc..212b3189da0c57630be528bdddd71bfed8da5114 100644 --- a/declearn/model/tensorflow/_model.py +++ b/declearn/model/tensorflow/_model.py @@ -21,7 +21,12 @@ from copy import deepcopy from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import numpy as np +# fmt: off +# pylint: disable=import-error,no-name-in-module import tensorflow as tf # type: ignore +import tensorflow.keras as tf_keras # type: ignore +# pylint: enable=import-error,no-name-in-module +# fmt: on from numpy.typing import ArrayLike from typing_extensions import Self # future: import from typing (py >=3.11) @@ -46,7 +51,7 @@ __all__ = [ class TensorflowModel(Model): """Model wrapper for TensorFlow Model instances. - This `Model` subclass is designed to wrap a `tf.keras.Model` instance + This `Model` subclass is designed to wrap a `tf_keras.Model` instance to be trained federatively. Notes regarding device management (CPU, GPU, etc.): @@ -69,9 +74,9 @@ class TensorflowModel(Model): def __init__( self, - model: tf.keras.layers.Layer, - loss: Union[str, tf.keras.losses.Loss], - metrics: Optional[List[Union[str, tf.keras.metrics.Metric]]] = None, + model: tf_keras.layers.Layer, + loss: Union[str, tf_keras.losses.Loss], + metrics: Optional[List[Union[str, tf_keras.metrics.Metric]]] = None, _from_config: bool = False, **kwargs: Any, ) -> None: @@ -92,19 +97,19 @@ class TensorflowModel(Model): compiled with the model and computed using the `evaluate` method of the returned TensorflowModel instance. **kwargs: Any - Any additional keyword argument to `tf.keras.Model.compile` + Any additional keyword argument to `tf_keras.Model.compile` may be passed. """ # Type-check the input Model and wrap it up. - if not isinstance(model, tf.keras.layers.Layer): + if not isinstance(model, tf_keras.layers.Layer): raise TypeError( - "'model' should be a tf.keras.layers.Layer instance." + "'model' should be a tf_keras.layers.Layer instance." ) - if not isinstance(model, tf.keras.Model): - model = tf.keras.Sequential([model]) + if not isinstance(model, tf_keras.Model): + model = tf_keras.Sequential([model]) super().__init__(model) # Ensure the loss is a keras.Loss object and set its reduction to none. - loss = build_keras_loss(loss, reduction=tf.keras.losses.Reduction.NONE) + loss = build_keras_loss(loss, reduction=tf_keras.losses.Reduction.NONE) # Select the device where to place computations and move the model. policy = get_device_policy() self._device = select_device(gpu=policy.gpu, idx=policy.idx) @@ -150,9 +155,9 @@ class TensorflowModel(Model): def get_config( self, ) -> Dict[str, Any]: - config = tf.keras.layers.serialize(self._model) # type: Dict[str, Any] + config = tf_keras.layers.serialize(self._model) # type: Dict[str, Any] kwargs = deepcopy(self._kwargs) - loss = tf.keras.losses.serialize(kwargs.pop("loss")) + loss = tf_keras.losses.serialize(kwargs.pop("loss")) return {"model": config, "loss": loss, "kwargs": kwargs} @classmethod @@ -169,8 +174,8 @@ class TensorflowModel(Model): device = select_device(gpu=policy.gpu, idx=policy.idx) # Deserialize the model and loss keras objects on the device. 
with tf.device(device): - model = tf.keras.layers.deserialize(config["model"]) - loss = tf.keras.losses.deserialize(config["loss"]) + model = tf_keras.layers.deserialize(config["model"]) + loss = tf_keras.losses.deserialize(config["loss"]) # Instantiate the TensorflowModel, avoiding device-to-device copies. return cls(model, loss, **config["kwargs"], _from_config=True) @@ -178,10 +183,20 @@ class TensorflowModel(Model): self, trainable: bool = False, ) -> TensorflowVector: + variables = self._get_weight_variables(trainable) + return TensorflowVector({var.name: var.value() for var in variables}) + + def _get_weight_variables( + self, + trainable: bool, + ) -> Iterable[tf.Variable]: + """Access TensorFlow Variables wrapping model weight tensors.""" variables = ( self._model.trainable_weights if trainable else self._model.weights ) - return TensorflowVector({var.name: var.value() for var in variables}) + if hasattr(tf_keras, "version") and tf_keras.version().startswith("3"): + variables = (var.value for var in variables) + return variables def set_weights( self, @@ -193,7 +208,9 @@ class TensorflowModel(Model): "TensorflowModel requires TensorflowVector weights." ) self._verify_weights_compatibility(weights, trainable=trainable) - variables = {var.name: var for var in self._model.weights} + variables = { + var.name: var for var in self._get_weight_variables(trainable) + } with tf.device(self._device): for name, value in weights.coefs.items(): variables[name].assign(value, read_value=False) @@ -220,9 +237,7 @@ class TensorflowModel(Model): In case some expected keys are missing, or additional keys are present. Be verbose about the identified mismatch(es). """ - variables = ( - self._model.trainable_weights if trainable else self._model.weights - ) + variables = self._get_weight_variables(trainable) raise_on_stringsets_mismatch( received=set(vector.coefs), expected={var.name for var in variables}, @@ -242,7 +257,7 @@ class TensorflowModel(Model): norm = tf.constant(max_norm) grads, loss = self._compute_clipped_gradients(*data, norm) self._loss_history.append(float(loss.numpy())) - grads_and_vars = zip(grads, self._model.trainable_weights) + grads_and_vars = zip(grads, self._get_weight_variables(trainable=True)) return TensorflowVector( {var.name: grad for grad, var in grads_and_vars} ) @@ -326,7 +341,7 @@ class TensorflowModel(Model): ) -> None: self._verify_weights_compatibility(updates, trainable=True) with tf.device(self._device): - for var in self._model.trainable_weights: + for var in self._get_weight_variables(trainable=True): updt = updates.coefs[var.name] if isinstance(updt, tf.IndexedSlices): var.scatter_add(updt) diff --git a/declearn/model/tensorflow/_optim.py b/declearn/model/tensorflow/_optim.py index 8ee030d7456005bbf9797e95af6f4ea1ced3af4a..6538b6871b7373637b1769c0f41f14951f7a72a0 100644 --- a/declearn/model/tensorflow/_optim.py +++ b/declearn/model/tensorflow/_optim.py @@ -17,9 +17,14 @@ """Hacky OptiModule subclass enabling the use of a keras Optimizer.""" -from typing import Any, Dict, Union +from typing import Any, Dict, List, Union +# fmt: off +# pylint: disable=import-error,no-name-in-module import tensorflow as tf # type: ignore +import tensorflow.keras as tf_keras # type: ignore +# pylint: enable=import-error,no-name-in-module +# fmt: on from declearn.model.api import Vector from declearn.model.tensorflow.utils import select_device @@ -81,7 +86,7 @@ class TensorflowOptiModule(OptiModule): def __init__( self, - optim: Union[tf.keras.optimizers.Optimizer, str, Dict[str, 
Any]], + optim: Union[tf_keras.optimizers.Optimizer, str, Dict[str, Any]], ) -> None: """Instantiate a hacky tensorflow optimizer plug-in module. @@ -105,7 +110,7 @@ class TensorflowOptiModule(OptiModule): self._device = select_device(gpu=policy.gpu, idx=policy.idx) # Wrap the provided optimizer, enforcing a fixed learning rate of 1. # Also prevent the use of weight-decay or built-in ema (~momentum). - self.optim = tf.keras.optimizers.get(optim) + self.optim = tf_keras.optimizers.get(optim) config = self.optim.get_config() config["weight_decay"] = 0 config["use_ema"] = False @@ -184,7 +189,7 @@ class TensorflowOptiModule(OptiModule): key: tf.Variable(tf.zeros_like(grad), name=key) for key, grad in gradients.coefs.items() } - self.optim.build(self._vars.values()) + self.optim.build(list(self._vars.values())) def reset(self) -> None: """Reset this module to its uninitialized state. @@ -201,13 +206,13 @@ class TensorflowOptiModule(OptiModule): policy = get_device_policy() self._device = select_device(gpu=policy.gpu, idx=policy.idx) with tf.device(self._device): - self._vars = {} + self._vars.clear() self.optim = self.optim.from_config(self.optim.get_config()) def get_config( self, ) -> Dict[str, Any]: - optim = tf.keras.optimizers.serialize(self.optim) + optim = tf_keras.optimizers.serialize(self.optim) return {"optim": optim} def get_state( @@ -217,11 +222,20 @@ class TensorflowOptiModule(OptiModule): key: (val.shape.as_list(), val.dtype.name) for key, val in self._vars.items() } + variables = self._get_optimizer_variables() state = TensorflowVector( - {var.name: var.value() for var in self.optim.variables()} + {str(i): v.value() for i, v in enumerate(variables)} ) return {"specs": specs, "state": state} + def _get_optimizer_variables( + self, + ) -> List[tf.Variable]: + """Access wrapped optimizer's variables as 'tf.Variable' instances.""" + if hasattr(tf_keras, "version") and tf_keras.version().startswith("3"): + return [var.value for var in self.optim.variables] + return self.optim.variables() + def set_state( self, state: Dict[str, Any], @@ -239,9 +253,9 @@ class TensorflowOptiModule(OptiModule): key: tf.Variable(tf.zeros(shape, dtype), name=key) for key, (shape, dtype) in state["specs"].items() } - self.optim.build(self._vars.values()) + self.optim.build(list(self._vars.values())) # Restore optimizer variables' values from the input state dict. 
- opt_vars = {var.name: var for var in self.optim.variables()} + opt_vars = self._get_optimizer_variables() with tf.device(self._device): - for key, val in state["state"].coefs.items(): - opt_vars[key].assign(val, read_value=False) + for var, val in zip(opt_vars, state["state"].coefs.values()): + var.assign(val, read_value=False) diff --git a/declearn/model/tensorflow/utils/_gpu.py b/declearn/model/tensorflow/utils/_gpu.py index bea73eff4e01df9a24bbc5f42162ea9ffdbffb63..a6f7f9f31bc9ce4d700a76fae5cc4e4e4497c20b 100644 --- a/declearn/model/tensorflow/utils/_gpu.py +++ b/declearn/model/tensorflow/utils/_gpu.py @@ -21,7 +21,12 @@ import functools import warnings from typing import Any, Callable, Optional, Union +# fmt: off +# pylint: disable=import-error,no-name-in-module import tensorflow as tf # type: ignore +import tensorflow.keras as tf_keras # type: ignore +# pylint: enable=import-error,no-name-in-module +# fmt: on __all__ = [ @@ -82,9 +87,9 @@ def select_device( def move_layer_to_device( - layer: tf.keras.layers.Layer, + layer: tf_keras.layers.Layer, device: Union[tf.config.LogicalDevice, str], -) -> tf.keras.layers.Layer: +) -> tf_keras.layers.Layer: """Create a copy of an input keras layer placed on a given device. This functions creates a copy of the input layer and of all its weights. @@ -101,13 +106,13 @@ def move_layer_to_device( Returns ------- - layer: tf.keras.layers.Layer + layer: tf_keras.layers.Layer Copy of the input layer, with its weights backed on `device`. """ - config = tf.keras.layers.serialize(layer) + config = tf_keras.layers.serialize(layer) weights = layer.get_weights() with tf.device(device): - layer = tf.keras.layers.deserialize(config) + layer = tf_keras.layers.deserialize(config) layer.set_weights(weights) return layer diff --git a/declearn/model/tensorflow/utils/_loss.py b/declearn/model/tensorflow/utils/_loss.py index 5a3ffdaae06ee144f51777747f1d2bc797dd3f04..da49339d90414e3c7c3dab4d43d9cbc28611b200 100644 --- a/declearn/model/tensorflow/utils/_loss.py +++ b/declearn/model/tensorflow/utils/_loss.py @@ -21,7 +21,12 @@ import inspect from typing import Any, Callable, Dict, Optional, Union +# fmt: off +# pylint: disable=import-error,no-name-in-module import tensorflow as tf # type: ignore +import tensorflow.keras as tf_keras # type: ignore +# pylint: enable=import-error,no-name-in-module +# fmt: on __all__ = [ @@ -33,18 +38,18 @@ __all__ = [ CallableLoss = Callable[[tf.Tensor, tf.Tensor], tf.Tensor] -@tf.keras.utils.register_keras_serializable(package="declearn") -class LossFunction(tf.keras.losses.Loss): +@tf_keras.utils.register_keras_serializable(package="declearn") +class LossFunction(tf_keras.losses.Loss): """Generic loss function container enabling reduction strategy control.""" def __init__( self, loss_fn: Union[str, CallableLoss], - reduction: str = tf.keras.losses.Reduction.NONE, + reduction: str = tf_keras.losses.Reduction.NONE, name: Optional[str] = None, ) -> None: - super().__init__(reduction, name) - self.loss_fn = tf.keras.losses.deserialize(loss_fn) + super().__init__(reduction=reduction, name=name) + self.loss_fn = tf_keras.losses.deserialize(loss_fn) def call( self, @@ -59,14 +64,14 @@ class LossFunction(tf.keras.losses.Loss): ) -> Dict[str, Any]: # inherited docstring; pylint: disable=missing-docstring config = super().get_config() # type: Dict[str, Any] - config["loss_fn"] = tf.keras.losses.serialize(self.loss_fn) + config["loss_fn"] = tf_keras.losses.serialize(self.loss_fn) return config def build_keras_loss( - loss: Union[str, 
tf.keras.losses.Loss, CallableLoss], - reduction: str = tf.keras.losses.Reduction.NONE, -) -> tf.keras.losses.Loss: + loss: Union[str, tf_keras.losses.Loss, CallableLoss], + reduction: str = tf_keras.losses.Reduction.NONE, +) -> tf_keras.losses.Loss: """Type-check, deserialize and/or wrap a keras loss into a Loss object. Parameters @@ -79,11 +84,11 @@ def build_keras_loss( Returns ------- - loss_obj: tf.keras.losses.Loss + loss_obj: tf_keras.losses.Loss Loss object, configured to apply the `reduction` scheme. """ # Case when 'loss' is already a Loss object. - if isinstance(loss, tf.keras.losses.Loss): + if isinstance(loss, tf_keras.losses.Loss): loss.reduction = reduction # Case when 'loss' is a string: deserialize and/or wrap into a Loss object. elif isinstance(loss, str): @@ -92,7 +97,7 @@ def build_keras_loss( elif inspect.isfunction(loss): loss = LossFunction(loss, reduction=reduction) # Case when 'loss' is of invalid type: raise a TypeError. - if not isinstance(loss, tf.keras.losses.Loss): + if not isinstance(loss, tf_keras.losses.Loss): raise TypeError( "'loss' should be a keras Loss object or the name of one." ) @@ -103,7 +108,7 @@ def build_keras_loss( def get_keras_loss_from_string( name: str, reduction: str, -) -> tf.keras.losses.Loss: +) -> tf_keras.losses.Loss: """Instantiate a keras Loss object from a registered string identifier. - If `name` matches a Loss registration name, return an instance. @@ -112,8 +117,8 @@ def get_keras_loss_from_string( Loss subclass instance wrapping the function. - If it does not match anything, raise a ValueError. """ - loss = tf.keras.losses.deserialize(name) - if isinstance(loss, tf.keras.losses.Loss): + loss = tf_keras.losses.deserialize(name) + if isinstance(loss, tf_keras.losses.Loss): loss.reduction = reduction return loss if inspect.isfunction(loss): diff --git a/declearn/model/torch/_model.py b/declearn/model/torch/_model.py index 639ac398d80f36cbb6f97585511de71db510a7b1..1ad815967221aacfddde1b15c7b79a2cf526565a 100644 --- a/declearn/model/torch/_model.py +++ b/declearn/model/torch/_model.py @@ -373,7 +373,7 @@ class TorchModel(Model): def compute_batch_predictions( self, batch: Batch, - ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray],]: + ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]: inputs, y_true, s_wght = self._unpack_batch(batch) if y_true is None: raise TypeError( diff --git a/declearn/optimizer/modules/_api.py b/declearn/optimizer/modules/_api.py index 4aab241d8461cb08f106cc163f6d9100ba787f90..bbc13e80e2b3a17a10baf0deab0a85736244c5d1 100644 --- a/declearn/optimizer/modules/_api.py +++ b/declearn/optimizer/modules/_api.py @@ -42,7 +42,22 @@ T = TypeVar("T") @dataclasses.dataclass class AuxVar(Aggregate, base_cls=True, register=False, metaclass=abc.ABCMeta): - """Abstract base class for OptiModule auxiliary variables.""" + """Abstract base class for OptiModule auxiliary variables. + + Each and every `OptiModule` subclass that requires information to be + exchanged is expected to be coupled with one (or multiple) `AuxVar` + subtype(s). These may be used to transmit information from a server + to its clients, and/or to exchange and aggregate data from clients + into a single `AuxVar` instance to be processed by the server. + + This class also defines whether contents are compatible with secure + aggregation, and whether some fields should remain in cleartext no + matter what. + + Note that subclasses are automatically type-registered, and should be + decorated as `dataclasses.dataclass`. 
To prevent registration, simply + pass `register=False` at inheritance. + """ _group_key = "AuxVar" diff --git a/declearn/utils/_multiprocess.py b/declearn/utils/_multiprocess.py index 314c71638ba4819c449849e2e9b3e4de1853ea2d..44bbdd8ba97e72f1a3d9f115cbc7f1f9d4bdfd63 100644 --- a/declearn/utils/_multiprocess.py +++ b/declearn/utils/_multiprocess.py @@ -199,32 +199,26 @@ def add_exception_catching( name: str, ) -> Callable[..., Any]: """Wrap a function to catch exceptions and put them in a Queue.""" - return functools.partial( - _run_with_exception_catching, func=func, queue=queue, name=name - ) - -def _run_with_exception_catching( - *args: Any, - func: Callable[..., Any], - queue: Queue, # Queue[Tuple[str, Union[Any, RuntimeError]]] (py >=3.9) - name: str, - **kwargs: Any, -) -> Any: - """Call the wrapped function and catch exceptions or results.""" - try: - result = func(*args, **kwargs) - except Exception as exc: # pylint: disable=broad-exception-caught - err = RuntimeError( - f"Exception of type {type(exc)} occurred:\n" - "".join(traceback.format_exception(type(exc), exc, tb=None)) - ) # future: `traceback.format_exception(exc)` (py >=3.10) - queue.put((name, err)) - sys.exit(1) - else: - queue.put((name, result)) + @functools.wraps(func) + def wrapped(*args, **kwargs): + """Call the wrapped function and queue exceptions or results.""" + nonlocal name, queue + try: + result = func(*args, **kwargs) + except Exception as exc: # pylint: disable=broad-exception-caught + err = RuntimeError( + f"Exception of type {type(exc)} occurred:\n" + + "".join(traceback.format_exception(type(exc), exc, tb=None)) + ) # future: `traceback.format_exception(exc)` (py >=3.10) + queue.put((name, err)) + sys.exit(1) + else: + queue.put((name, result)) sys.exit(0) + return wrapped + def run_processes( processes: List[mp.Process], diff --git a/declearn/version.py b/declearn/version.py index 2c2a2ce58b75bbaa833cbdf2232f1886091e5575..ed20f25a5ca69ae11b600b2e902c4fc82afba255 100644 --- a/declearn/version.py +++ b/declearn/version.py @@ -17,5 +17,5 @@ """DecLearn version information, as hard-coded constants.""" -VERSION = "2.4.0b" +VERSION = "2.4.0" """Version information of the installed DecLearn package.""" diff --git a/docs/devs-guide/contribute.md b/docs/devs-guide/contribute.md index f1e7280ecbedfccce8da317708375a56d6c39602..d69577d43162ba8f6286d21ab7ef4d9519fe83e7 100644 --- a/docs/devs-guide/contribute.md +++ b/docs/devs-guide/contribute.md @@ -60,7 +60,7 @@ The **coding rules** are fairly simple: are detailed in the [docstrings style guide](./docs-style.md). - Type-hint the code, abiding by [PEP 484](https://peps.python.org/pep-0484/); note that the use of Any and of "type: ignore" comments is authorized, but - should be remain sparse. + should remain parsimonious. - Lint your code with [mypy](http://mypy-lang.org/) (for static type checking) and [pylint](https://pylint.pycqa.org/en/latest/) (for more general linting); do use "type: ..." and "pylint: disable=..." 
comments where you think it
diff --git a/docs/release-notes/SUMMARY.md b/docs/release-notes/SUMMARY.md
index 796597aa24bc08afc7ab7b7ed358654cdd05c8e6..568e1343869a8b2bcc974c538c5f1e7371c0fc10 100644
--- a/docs/release-notes/SUMMARY.md
+++ b/docs/release-notes/SUMMARY.md
@@ -1,3 +1,4 @@
+- [v2.4.0](v2.4.0.md)
 - [v2.3.2](v2.3.2.md)
 - [v2.3.1](v2.3.1.md)
 - [v2.3.0](v2.3.0.md)
diff --git a/docs/release-notes/v2.4.0.md b/docs/release-notes/v2.4.0.md
new file mode 100644
index 0000000000000000000000000000000000000000..cfaa2de8767ba86413dd980eae7bf49cc2302f8a
--- /dev/null
+++ b/docs/release-notes/v2.4.0.md
@@ -0,0 +1,292 @@
+# declearn v2.4.0
+
+Released: XX/XX/XXXX
+
+**Important notice**:<br/>
+DecLearn 2.4 derogates from SemVer by revising some of the major DecLearn
+component APIs.
+
+This is mitigated in two ways:
+
+- No changes relate to the main process setup code, meaning that end-users
+  that do not use custom components (aggregator, optimodule, metric, etc.)
+  should not see any difference, and their code will work as before (as an
+  illustration, our examples' code remains unchanged).
+- Key methods that were deprecated in favor of new API ones are kept for
+  two more minor versions, and are still tested to work as before.
+
+Any end-user encountering issues due to the released or planned evolutions of
+DecLearn is invited to contact us via GitLab, GitHub or e-mail so that we can
+provide assistance and/or update our roadmap so that changes do not hinder
+the usability of DecLearn 2.x for research and applications.
+
+## New version policy (and future roadmap)
+
+As noted above, v2.4 does not fully abide by SemVer rules. In the future, more
+partially-breaking changes and API revisions may be introduced, incrementally
+paving the way towards the next major release, while trying as much as possible
+not to break end-user code.
+
+To prevent unforeseen incompatibilities and cryptic bugs from arising, from
+this version onwards, the server and clients are expected and verified to use
+the same `major.minor` version of DecLearn.
+This policy may be updated in the future, e.g. to specify that clients may
+have a newer minor version than the server (and most probably not the other
+way around).
+
+To avoid unhappy surprises, we are starting to maintain a public roadmap on
+our GitLab. Although it may change, it should provide interested users (notably
+those developing custom components or processes on top of DecLearn) with a way
+to anticipate changes, and voice any concerns or advice they might have.
+
+
+## Revise all aggregation APIs
+
+### Revise the overall design for aggregation and introduce `Aggregate` API
+
+This release introduces the `Aggregate` API, which is based on an abstract base
+dataclass acting as a template for data structures that require sharing across
+peers and aggregation.
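+
+As a quick illustration (the `SharedStats` structure and its fields below are
+purely hypothetical and merely sketch the kind of dataclass this API supports;
+the actual capabilities are detailed right after):
+
+```python
+import dataclasses
+
+from declearn.utils import Aggregate
+
+
+@dataclasses.dataclass
+class SharedStats(Aggregate, register=False):
+    """Hypothetical data structure shared and aggregated across peers."""
+
+    sum_of_losses: float  # summed across peers by the default rule
+    n_steps: int          # summed across peers by the default rule
+
+
+# Assuming the sum-aggregation design described below supports the '+'
+# operator, two peers' instances may be combined as:
+#     total = stats_a + stats_b
+```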
+
+The `declearn.utils.Aggregate` ABC acts as a shared ancestor providing a base
+API and shared backend code to define data structures that:
+
+  - are serializable to and deserializable from JSON, and may therefore be
+    preserved across network communications
+  - are aggregatable into an instance of the same structure
+  - use summation as the default aggregation rule for fields, which is
+    overridable by redefining the `default_aggregate` method
+  - can implement custom `aggregate_<field.name>` methods to override the
+    default summation rule
+  - implement a `prepare_for_secagg` method that
+      - enables defining which fields merely require sum-aggregation and need
+        encryption when using SecAgg, and which fields are to be preserved in
+        cleartext (and therefore go through the usual default or custom
+        aggregation methods)
+      - can be made to raise a `NotImplementedError` when SecAgg cannot be
+        achieved on a data structure
+
+This new ABC currently has three main children:
+
+  - `AuxVar`: replaces plain dict for `Optimizer` auxiliary variables
+  - `MetricState`: replaces plain dict for `Metric` intermediate states
+  - `ModelUpdates`: replaces sharing of updates as `Vector` and `n_steps`
+
+Each of these is defined jointly with another (pre-existing, revised) API for
+components that (a) produce `Aggregate` data structures based on some input
+data and/or computations; (b) produce some output results based on a received
+`Aggregate` structure, meant to result from the aggregation of multiple peers'
+produced data.
+
+### Revise `Aggregator` API, introducing `ModelUpdates`
+
+The `Aggregator` API was revised to make use of the new `ModelUpdates` data
+structure (inheriting `Aggregate`).
+
+- `Aggregator.prepare_for_sharing` pre-processes an input `Vector` containing
+  raw model updates and an integer indicating the number of local SGD steps
+  into a `ModelUpdates` structure.
+- `Aggregator.finalize_updates` receives a `ModelUpdates` resulting from the
+  aggregation of peers' instances, and performs final computations to produce
+  a `Vector` of aggregated model updates.
+- The legacy `Aggregator.aggregate` method is deprecated (but still works).
+
+### Revise auxiliary variables for `Optimizer`, introducing `AuxVar`
+
+The `OptiModule` API (and, consequently, `Optimizer`) was revised with regard
+to the design and signature of auxiliary-variables-related methods, to make
+use of the new `AuxVar` data structure (inheriting `Aggregate`).
+
+- `OptiModule.collect_aux_var` now emits either `None` or an `AuxVar` instance
+  (the precise type of which is module-dependent), instead of a mere dict.
+- `OptiModule.process_aux_var` now expects a properly-typed `AuxVar` instance
+  that _already_ aggregates clients' data, externalizing the aggregation rules
+  to the `AuxVar` subtypes while keeping the finalization logic within the
+  `OptiModule` subclasses.
+- `Optimizer.collect_aux_var` therefore emits a `{name: aux_var}` dict.
+- `Optimizer.process_aux_var` therefore expects a `{name: aux_var}` dict,
+  rather than having distinct signatures on the client and server sides.
+- It is now expected that server-side components will send the _same_ data
+  to all clients, rather than sending client-wise values.
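+
+To give a concrete sense of the resulting flow in custom orchestration code
+(outside of `FederatedServer` / `FederatedClient`), here is a minimal sketch;
+the use of the `+` operator to aggregate `AuxVar` instances is an assumption
+based on their sum-aggregation design, and all surrounding setup is left out:
+
+```python
+from typing import Dict, List
+
+from declearn.optimizer import Optimizer
+from declearn.optimizer.modules import AuxVar
+
+
+def run_aux_var_exchange(
+    client_optims: List[Optimizer],
+    server_optim: Optimizer,
+) -> None:
+    """Illustrative auxiliary-variables round trip (sketch only)."""
+    # Collect {module_aux_name: AuxVar} dicts from each client's optimizer.
+    client_aux = [optim.collect_aux_var() for optim in client_optims]
+    # Aggregate client-wise AuxVar instances into a single one per module.
+    aggregated: Dict[str, AuxVar] = {}
+    for aux_dict in client_aux:
+        for name, aux_var in aux_dict.items():
+            if name in aggregated:
+                aggregated[name] = aggregated[name] + aux_var
+            else:
+                aggregated[name] = aux_var
+    # Have the server-side optimizer process the aggregated structures.
+    server_optim.process_aux_var(aggregated)
+    # The server then emits the same data back to every client.
+    server_aux = server_optim.collect_aux_var()
+    for optim in client_optims:
+        optim.process_aux_var(server_aux)
+```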
+
+The backend code of `ScaffoldClientModule` and `ScaffoldServerModule` was
+heavily revised to alter the distribution of information and computations:
+
+- Client-side modules are now the sole owners of their local state, and send
+  sum-aggregatable updates to the server, which are therefore SecAgg-compatible.
+- The server consequently shares the same information with all clients, namely
+  the current global state.
+- To keep track of the (possibly growing with time) number of unique clients,
+  clients generate a random uuid that is sent with their state updates and
+  preserved in cleartext when SecAgg is used.
+- As a consequence, the server component knows which clients contributed to a
+  given round, but receives an aggregate of local updates rather than the
+  client-wise state values.
+
+### Revise `Metric` API, introducing `MetricState`
+
+The `Metric` API was revised to make use of the new `MetricState` data
+structure (inheriting `Aggregate`).
+
+- `Metric.build_initial_states` generates a "zero-state" `MetricState` instance
+  (it replaces the previously-private `_build_states` method that returned a
+  dict).
+- `Metric.get_states` returns a (Metric-type-dependent) `MetricState`
+  instance, instead of a mere dict.
+- `Metric.set_states` assigns an incoming `MetricState` into the instance, that
+  may be finalized into results using the unchanged `get_result` method.
+- The legacy `Metric.agg_states` is deprecated, in favor of `set_states` (but
+  it still works).
+
+
+## Revise backend communications and messaging APIs
+
+This release introduces some important backend changes to the communication
+and messaging APIs of DecLearn, resulting in more robust code (that is also
+easier to test and maintain), more efficient message parsing (possibly-costly
+de-serialization is now delayed until after validity verification) and the
+extensibility of application messages, making it easy to define and use
+custom message structures in downstream applications.
+
+The most important API change is that network communication endpoints now
+return `SerializedMessage` instances rather than `Message` ones.
+
+### New `declearn.communication.api.backend` submodule
+
+- Introduce a new `ActionMessage` minimal API under its `actions`
+  submodule, which defines hard-coded, lightweight and easy-to-parse
+  data structures designed to convey information and content across
+  network communications, agnostic to the content's nature.
+- Revise and expose the `MessagesHandler` util, which now builds on
+  the `ActionMessage` API to model remote calls and answer them.
+- Move the `declearn.communication.messaging.flags` submodule to
+  `declearn.communication.api.backend.flags`.
+
+### New `declearn.messaging` submodule
+
+- Revise the `Message` API to make it extendable, with automated
+  type-registration of subclasses by default.
+- Introduce `SerializedMessage` as a wrapper for received messages,
+  which parses the exact `Message` subtype (enabling logic tests and
+  message filtering) but delays actual content de-serialization and
+  `Message` object recovery (preventing undue resource use for
+  unwanted messages that end up being discarded).
+- Move most existing `Message` subclasses to the new submodule, for
+  retro-compatibility purposes. In DecLearn 3.0 these will probably
+  be re-dispatched to make it clear that concrete messages only make
+  sense in the context of specific multi-agent processes.
+- Drop backend-oriented `Message` subclasses that are replaced with
+  the new `ActionMessage` backbone structures.
+- Deprecate the `declearn.communication.messaging` submodule, which
+  is temporarily maintained, re-exporting moved contents as well as
+  deprecated message types (which are bound to be rejected if sent).
+
+### Revise `NetworkClient` and `NetworkServer`
+
+- Have message-receiving methods return `SerializedMessage` instances
+  rather than finalized de-serialized `Message` ones.
+- Stop sending and expecting 'data_info' with registration requests.
+- Rename `NetworkClient.check_message` to `recv_message` (keep
+  the former as an alias, with a `DeprecationWarning`).
+- Improve the use of (optional) timeouts when sending or expecting
+  messages, as well as overall exception handling:
+  - `NetworkClient.recv_message` may either raise a `TimeoutError`
+    (in case of timeout) or `RuntimeError` (in case of rejection).
+  - `NetworkServer.send_messages` and `broadcast_message` quietly
+    stop waiting for clients to collect messages after the (opt.)
+    timeout delay has passed. Messages may still be collected.
+  - `NetworkServer.wait_for_messages` no longer accepts a timeout.
+  - `NetworkServer.wait_for_messages_with_timeout` implements the
+    possibility to set up a timeout. It returns both received client
+    replies and a list of clients that failed to answer.
+  - All timeouts can now be specified as float values (which is
+    mostly useful for testing purposes or simulated environments).
+- Add a `heartbeat` instantiation parameter, with a default value of
+  1 second, that is passed to the underlying `MessagesHandler`. In
+  simulated contexts (including tests), setting a low heartbeat can
+  cut runtime down significantly.
+
+### New `declearn.communication.utils` submodule
+
+Introduce the `declearn.communication.utils` submodule, and move existing
+`declearn.communication` utils to it. Keep re-exporting them from the parent
+module to preserve code compatibility.
+
+Add `verify_client_messages_validity` and `verify_server_message_validity` as
+part of the new submodule, which refactor some backend code from orchestration
+classes related to the filtering and type-checking of exchanged messages.
+
+## Usability updates
+
+A few minor changes were made in the hope that they improve DecLearn usability
+for end-users.
+
+### Record and save training losses
+
+The `Model` API was updated so that `Model.compute_batch_gradients` now records
+the computed batch-averaged model loss as a float value in an internal buffer,
+and the newly-introduced `Model.collect_training_losses` method enables getting
+all stored values (and purging the buffer on the way).
+
+The `FederatedClient` was consequently updated to collect and export training
+loss values at the end of each and every training round when a `Checkpointer`
+is attached to it (otherwise, values are purged from memory but not recorded to
+disk).
+
+### Add a verbosity option for `FederatedClient` and `TrainingManager`.
+
+`FederatedClient` and `TrainingManager` now both accept a `verbose: bool=True`
+instantiation keyword argument, which changes:
+
+- (a) the default logger verbosity level: if `logger=None` is also passed,
+  the default logger will have 'info'-level verbosity if `verbose` and
+  `declearn.utils.LOGGING_LEVEL_MAJOR`-level if `not verbose`, so that only
+  evaluation metrics and errors are logged.
+- (b) the optional display of a progressbar when conducting training or
+  evaluation rounds; if `not verbose`, no progressbar is used.
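+
+For example, a client may be set up to run quietly as follows (a minimal
+sketch: the network configuration file and dataset path are placeholders,
+and all other arguments are left to their defaults):
+
+```python
+from declearn.dataset import InMemoryDataset
+from declearn.main import FederatedClient
+
+client = FederatedClient(
+    netwk="client_network.toml",  # placeholder TOML network configuration
+    train_data=InMemoryDataset("data/train.csv", target="label"),
+    verbose=False,  # only evaluation metrics and errors are logged
+)
+client.run()
+```
+
+Setting `verbose=True` (the default) restores info-level logs and the
+display of progress bars during training and evaluation rounds.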
+
+### Add 'TrainingManager.(train|evaluate)_under_constraints' to the API.
+
+These public methods enable running training or evaluation rounds without
+relying on `Message` structures to specify parameters or collect results.
+
+### Modularize 'run_as_processes' input specs.
+
+The `declearn.utils.run_as_processes` util was modularized so that routines
+can be specified in various ways. Previously, they could only be passed as
+`(func, args)` tuples. Now, they can either be passed as `(func, args)`,
+`(func, kwargs)` or `(func, args, kwargs)`, where `args` is still a tuple
+of positional arguments, and `kwargs` is a dict of keyword ones.
+
+## Other changes
+
+### Fix redundant sharing of model weights with clients
+
+`FederatedServer` now keeps track of clients having received the latest global
+model weights, and avoids sending them redundantly with future training (or
+evaluation) requests. To achieve this, `TrainRequest` and `EvaluationRequest`
+now support setting their `weights` field to `None`.
+
+### Update TensorFlow supported versions
+
+Older TensorFlow versions (v2.5 to 2.10 included) were improperly marked as
+supported in spite of `TensorflowOptiModule` requiring at least version 2.11
+to work (due to changes of the Keras Optimizer API). This has been corrected.
+
+The latest TensorFlow version (v2.16) introduces backward-breaking changes,
+due to the backend swap from Keras 2 to Keras 3. Our backend code was updated
+to both add support for this newer Keras backend, and preserve existing
+support.
+
+Note that at the moment, the CI does not support TensorFlow above 2.13, due to
+newer versions not being compatible with Python 3.8. As such, our code will be
+tested to remain backward-compatible. Forward compatibility has been (and will
+keep being) tested locally with a newer Python version.
+
+### Deprecate `declearn.dataset.load_from_json`
+
+As the `save_to_json` and `load_from_json` methods were removed from the
+`declearn.dataset.Dataset` API in DecLearn 2.3.0, there is no longer a
+guarantee that this function works (except with `InMemoryDataset`).
+
+As a consequence, this function should have been deprecated; it is now
+documented as such, and is planned for removal in DecLearn 2.6 and/or 3.0.
diff --git a/docs/setup.md b/docs/setup.md
index 7b00c0aafd38d5390a402fb70dd21747b817de97..0ae2ead3edb1b48dcadf0d071c27ffecae105076 100644
--- a/docs/setup.md
+++ b/docs/setup.md
@@ -1,5 +1,17 @@
 # Installation guide
 
+This guide provides all the required information to install `declearn`.
+
+**TL;DR**:<br/>
+If you want to install the latest stable version with all of its optional
+dependencies, simply run `pip install declearn[all]` from your desired
+python (preferably virtual) environment.
+
+**Important note**:<br/>
+When running a federated process with DecLearn, the server and all clients
+should use the same `major.minor` version; otherwise, clients' registration
+will fail verbosely, prompting them to install the same version as the server's.
+
 ## Requirements
 
 - python >= 3.8
diff --git a/docs/user-guide/fl_process.md b/docs/user-guide/fl_process.md
index 85f9fd1d8bc8f70b9738e5371d07469735b7e2eb..6972092b14e41f2196f6978a66b0702759549a68 100644
--- a/docs/user-guide/fl_process.md
+++ b/docs/user-guide/fl_process.md
@@ -11,8 +11,10 @@ exposed here.
## Overall process orchestrated by the server - Initially: - - have the clients connect and register for training - - prepare model and optimizer objects on both sides + - the clients connect to the server and register for training + - the server may collect targetted metadata from clients when required + - the server sets up the model, optimizers, aggregator and metrics + - all clients receive instructions to set up these objects as well - Iteratively: - perform a training round - perform an evaluation round @@ -36,28 +38,40 @@ exposed here. registered, optionally under a given timeout delay) - close registration (reject future requests) - Client: - - gather metadata about the local training dataset - (_e.g._ dimensions and unique labels) - - connect to the server and send a request to join training, - including the former information + - connect to the server and send a request to join training - await the server's response (retry after a timeout if the request came in too soon, i.e. registration is not opened yet) -- messaging : (JoinRequest <-> JoinReply) ### Post-registration initialization +#### (Optional) Metadata exchange + +This step is optional, and depends on the trained model's requirement +for dataset information (typically, features shape and/or dtype). + +- Server: + - query clients for targetted metadata about the local training datasets +- Client: + - collect and send back queried metadata +- messaging: (MetadataQuery <-> MetadataReply) - Server: - - validate and aggregate clients-transmitted metadata - - finalize the model's initialization using those metadata - - send the model, local optimizer and evaluation metrics specs to clients + - validate and aggregate received information + - pass it to the model so as to finalize its initialization + +#### Initialization of the federated optimization problem + +- Server: + - set up the model, local and global optimizer, aggregator and metrics + - send specs to the clients so that they set up local counterpart objects - Client: - - instantiate the model, optimizer and metrics based on server instructions -- messaging: (InitRequest <-> GenericMessage) + - instantiate the model, optimizer, aggregator and metrics based on specs +- messaging: (InitRequest <-> InitReply) + +#### (Optional) Local differential privacy setup -### (Optional) Local differential privacy setup +This step is optional; a flag in the InitRequest at the previous step +indicates to clients that it is to happen, as a secondary substep. -- This step is optional; a flag in the InitRequest at the previous step - indicates to clients that it is to happen, as a secondary substep. - Server: - send hyper-parameters to set up local differential privacy, including dp-specific hyper-parameters and information on the planned training @@ -91,8 +105,8 @@ exposed here. - Server: - select clients that are to participate - - send data-batching parameters and shared model trainable weights - - (_send effort constraints, unused for now_) + - send data-batching parameters and effort constraints + - send shared model trainable weights - Client: - update model weights - perform evaluation steps based on effort constraints diff --git a/docs/user-guide/optimizer.md b/docs/user-guide/optimizer.md index 103ceb1d3e6e061f4f6d261ca335eb5b4e988c1d..0579f58f2583568f54d0b2d6c5196f2d105cf2b2 100644 --- a/docs/user-guide/optimizer.md +++ b/docs/user-guide/optimizer.md @@ -280,7 +280,7 @@ To implement Scaffold in Declearn, one needs to set up both server-side and client-side OptiModule plug-ins. 
The client-side module is in charge of both correcting input gradients and computing the required quantities to update the states at the end of each training round, while the server-side module merely -manages the computation and distribution of the global and correction states. +manages the computation and distribution of the global referencestate. The following snippet sets up a pair of client-side and server-side optimizers that implement Scaffold, here with a 0.001 learning rate on the client side and @@ -447,18 +447,17 @@ Declearn introduces the notion of "auxiliary variables" to cover such cases: - The packaging and distribution of module-wise auxiliary variables is done by `Optimizer.collect_aux_var` and `process_aux_var`, which orchestrate calls to the plugged-in modules' methods of the same name. -- The management and compartementalization of client-wise auxiliary variables - information is also automated as part of `declearn.main.FederatedServer`, to - prevent information leakage between clients. +- Exchanged information is formatted via dedicated `AuxVar` data structures + (inheriting `declearn.optimizer.modules.AuxVar`) that define how to aggregate + peers' data, and indicate how to use secure aggregation on top of it (when it + is possible to do so). #### OptiModule and Optimizer auxiliary variables API At the level of any `OptiModule`: -- `OptiModule.collect_aux_var` should output a dict that may either have a - simple `{key: value}` structure (for server-purposed or shared-across-clients - information), or a nested `{client_name: {key: value}}` structure (that is to - be split in order to send distinct information to the clients). +- `OptiModule.collect_aux_var` should output either `None` or an instance of + a module-specific `AuxVar` subclass wrapping data to be shared. - `OptiModule.process_aux_var` should expect a dict that has the same structure as that emitted by `collect_aux_var` (of this module class, or of a @@ -466,11 +465,12 @@ At the level of any `OptiModule`: At the level of a wrapping `Optimizer`: -- `Optimizer.collect_aux_var` emits a `{module_aux_name: module_emitted_dict}` - dict. +- `Optimizer.collect_aux_var` outputs a `{module_aux_name: module_aux_var}` + dict to be shared. -- `Optimizer.process_aux_var` expects a `{module_aux_name: module_emitted_dict}` - dict as well. +- `Optimizer.process_aux_var` expects a `{module_aux_name: module_aux_var}` + dict as well, containing either server-emitted or aggregated clients-emitted + data. As a consequence, you should note that: @@ -478,10 +478,8 @@ As a consequence, you should note that: that have the same `name` or `aux_name`. - If you are using our `Optimizer` within your own orchestration code (_i.e._ outside of our `FederatedServer` / `FederatedClient` main classes), it is up - to you to handle the restructuration of auxiliary variables to ensure that - (a) each client gets its own information (and not that of others), and that - (b) client-wise auxiliary variables are concatenated properly for the - server-side optimizer to process. + to you to handle the aggregation of client-wise auxiliary variables into the + module-wise single instance that the server should receive. #### Integration to the Declearn FL process @@ -651,5 +649,5 @@ In some cases, you might want to clip your batch-averaged gradients, _e.g._ to prevent exploding gradients issues. 
This is possible in Declearn, thanks to a couple of `OptiModule` subclasses: `L2Clipping` (name: `'l2-clipping'`) clips arrays of weights based on their L2-norm, while `L2GlobalClipping` (name: -`'l2-global-clipping`) clips all weights based on their global L2-norm (as if +`'l2-global-clipping'`) clips all weights based on their global L2-norm (as if concatenated into a single array). diff --git a/docs/user-guide/package.md b/docs/user-guide/package.md index 4ac2de145d0afb14d727cd67ea6a4b1172ad295c..76c0a32b03e45aa5d128958d5a308de135b1a00d 100644 --- a/docs/user-guide/package.md +++ b/docs/user-guide/package.md @@ -14,6 +14,8 @@ The package is organized into the following submodules:   Data interfacing API and implementations. - `main`:<br/>   Main classes implementing a Federated Learning process. +- `messaging`:<br/> +   API and default classes to define parsable messages for applications. - `metrics`:<br/>   Iterative and federative evaluation metrics computation tools. - `model`:<br/> @@ -24,6 +26,8 @@ The package is organized into the following submodules:   Type hinting utils, defined and exposed for code readability purposes. - `utils`:<br/>   Shared utils used (extensively) across all of declearn. +- `version`:<br/> +   DecLearn version information, as hard-coded constants. ## Main abstractions @@ -34,7 +38,9 @@ well as references on how to extend the support of `declearn` backend (notably, (de)serialization and configuration utils) to new custom concrete implementations inheriting the abstraction. -### `Model` +### Model and Tensors + +#### `Model` - Import: `declearn.model.api.Model` - Object: Interface framework-specific machine learning models. - Usage: Compute gradients, apply updates, compute loss... @@ -44,7 +50,7 @@ new custom concrete implementations inheriting the abstraction. - `declearn.model.torch.TorchModel` - Extend: use `declearn.utils.register_type(group="Model")` -### `Vector` +#### `Vector` - Import: `declearn.model.api.Vector` - Object: Interface framework-specific data structures. - Usage: Wrap and operate on model weights, gradients, updates... @@ -54,7 +60,34 @@ new custom concrete implementations inheriting the abstraction. - `declearn.model.torch.TorchVector` - Extend: use `declearn.model.api.register_vector_type` -### `OptiModule` +### Federated Optimization + +You may learn more about our (non-abstract) `Optimizer` API by reading our +[Optimizer guide](./optimizer.md). + +#### `Aggregator` +- Import: `declearn.aggregator.Aggregator` +- Object: Define model updates aggregation algorithms. +- Usage: Post-process client updates; finalize aggregated global ones. +- Examples: + - `declearn.aggregator.AveragingAggregator` + - `declearn.aggregator.GradientMaskedAveraging` +- Extend: + - Simply inherit from `Aggregator` (registration is automated). + - To avoid it, use `class MyAggregator(Aggregator, register=False)`. + +#### `ModelUpdates` +- Import: `declearn.aggregator.ModelUpdates` +- Object: Define exchanged model updates data and their aggregation. +- Usage: Share and aggregate client's updates for a given `Aggregator`. +- Examples: + - Each `Aggregator` has its own dedicated/supported `ModelUpdates` type(s). +- Extend: + - Simply inherit from `ModelUpdates` (registration is automated). + - Define a `name` class attribute and decorate as a `dataclass`. + - To avoid it, use `class MyModelUpdates(ModelUpdates, register=False)`. + +#### `OptiModule` - Import: `declearn.optimizer.modules.OptiModule` - Object: Define optimization algorithm bricks. 
- Usage: Plug into a `declearn.optimizer.Optimizer`. @@ -67,19 +100,33 @@ new custom concrete implementations inheriting the abstraction. - Simply inherit from `OptiModule` (registration is automated). - To avoid it, use `class MyModule(OptiModule, register=False)`. -### `Regularizer` -- Import: `declearn.optimizer.modules.Regularizer` +#### `Regularizer` +- Import: `declearn.optimizer.regularizers.Regularizer` - Object: Define loss-regularization terms as gradients modifiers. - Usage: Plug into a `declearn.optimizer.Optimizer`. - Examples: - - `declearn.optimizer.regularizer.FedProxRegularizer` - - `declearn.optimizer.regularizer.LassoRegularizer` - - `declearn.optimizer.regularizer.RidgeRegularizer` + - `declearn.optimizer.regularizers.FedProxRegularizer` + - `declearn.optimizer.regularizers.LassoRegularizer` + - `declearn.optimizer.regularizers.RidgeRegularizer` - Extend: - Simply inherit from `Regularizer` (registration is automated). - To avoid it, use `class MyRegularizer(Regularizer, register=False)`. -### `Metric` +#### `AuxVar` +- Import: `declearn.optimizer.modules.AuxVar` +- Object: Define exchanged data between a pair of `OptiModules` across the + clients/server boundary, and their aggregation. +- Usage: Share information from server to clients and reciprocally. +- Examples: + - `declearn.optimizer.modules.ScaffoldAuxVar` +- Extend: + - Simply inherit from `AuxVar` (registration is automated). + - Define a `name` class attribute and decorate as a `dataclass`. + - To avoid it, use `class MyAuxVar(AuxVar, register=False)`. + +### Evaluation Metrics + +#### `Metric` - Import: `declearn.metrics.Metric` - Object: Define evaluation metrics to compute iteratively and federatively. - Usage: Compute local and federated metrics based on local data streams. @@ -89,9 +136,22 @@ new custom concrete implementations inheriting the abstraction. - `declearn.metric.MuticlassAccuracyPrecisionRecall` - Extend: - Simply inherit from `Metric` (registration is automated). - - To avoid it, use `class MyMetric(Metric, register=False)` + - To avoid it, use `class MyMetric(Metric, register=False)`. + +#### `MetricState` +- Import: `declearn.metrics.MetricState` +- Object: Define exchanged data to compute a `Metric` and their aggregation. +- Usage: Share locally-computed metrics for their aggregation into global ones. +- Examples: + - Each `Metric` has its own dedicated/supported `MetricState` type(s). +- Extend: + - Simply inherit from `MetricState` (registration is automated). + - Define a `name` class attribute and decorate as a `dataclass`. + - To avoid it, use `class MyMetricState(MetricState, register=False)`. -### `NetworkClient` +### Network communication + +#### `NetworkClient` - Import: `declearn.communication.api.NetworkClient` - Object: Instantiate a network communication client endpoint. - Usage: Register for training, send and receive messages. @@ -102,7 +162,7 @@ new custom concrete implementations inheriting the abstraction. - Simply inherit from `NetworkClient` (registration is automated). - To avoid it, use `class MyClient(NetworkClient, register=False)`. -### `NetworkServer` +#### `NetworkServer` - Import: `declearn.communication.api.NetworkServer` - Object: Instantiate a network communication server endpoint. - Usage: Receive clients' requests, send and receive messages. @@ -113,13 +173,30 @@ new custom concrete implementations inheriting the abstraction. - Simply inherit from `NetworkServer` (registration is automated). 
- To avoid it, use `class MyServer(NetworkServer, register=False)`. -### `Dataset` +#### `Message` +- Import: `declearn.messaging.Message` +- Object: Define serializable/parsable message types and their data. +- Usage: Exchanged via communication endpoints to transmit data and + trigger behaviors based on type analysis. +- Examples: + - `declearn.messages.TrainRequest` + - `declearn.messages.TrainReply` + - `declearn.messages.Error` +- Extend: + - Simply inherit from `Message` (registration is automated). + - To avoid it, use `class MyMessage(Message, register=False)`. + +### Dataset + +#### `Dataset` - Import: `declearn.dataset.Dataset` - Object: Interface data sources agnostic to their format. - Usage: Yield (inputs, labels, weights) data batches, expose metadata. - Examples: - `declearn.dataset.InMemoryDataset` -- Extend: use `declearn.utils.register_type(group="Dataset")` + - `declearn.dataset.tensorflow.TensorflowDataset` + - `declearn.dataset.torch.TorchDataset` +- Extend: use `declearn.utils.register_type(group="Dataset")`. ## Full API Reference diff --git a/examples/heart-uci/run.py b/examples/heart-uci/run.py index 57937dd47e1b04f63952a62e8b1b5511f636d81a..676db0b6a5466000ff94f0915ac00f811623774c 100644 --- a/examples/heart-uci/run.py +++ b/examples/heart-uci/run.py @@ -51,9 +51,12 @@ def run_demo( # Run routines in isolated processes. Raise if any failed. success, outp = run_as_processes(server, *clients) if not success: + exceptions = "\n".join( + str(e) for e in outp if isinstance(e, RuntimeError) + ) raise RuntimeError( "Something went wrong during the demo. Exceptions caught:\n" - "\n".join(str(e) for e in outp if isinstance(e, RuntimeError)) + + exceptions ) diff --git a/pyproject.toml b/pyproject.toml index d9d4a1d34ea18ce2f1f4a0383bb4b5324c401b16..e9224145e21ae0eac94847a78efe33a29b8bbb01 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ requires = [ [project] name = "declearn" -version = "2.3.1" +version = "2.4.0" description = "Declearn - a python package for private decentralized learning." readme = "README.md" requires-python = ">=3.8" @@ -53,7 +53,7 @@ all = [ # all non-tests extra dependencies "jax[cpu] ~= 0.4.1", "opacus ~= 1.4", "protobuf >= 3.19", - "tensorflow ~= 2.5", + "tensorflow ~= 2.11", "torch >= 1.13, < 3.0", "websockets >= 10.1, < 13.0", ] @@ -71,7 +71,7 @@ haiku = [ "jax[cpu] ~= 0.4.1", # NOTE: GPU support must be manually installed ] tensorflow = [ - "tensorflow ~= 2.5", + "tensorflow ~= 2.11", ] torch = [ # generic requirements for Torch "torch >= 1.13, < 3.0", @@ -94,7 +94,7 @@ docs = [ ] # test-specific dependencies tests = [ - "black ~= 23.0", + "black ~= 24.0", "mypy ~= 1.0", "pylint ~= 3.0", "pytest ~= 7.4", diff --git a/scripts/gen_docs.py b/scripts/gen_docs.py index a2c6808659be8c78f77e22babcd007f9fb4a422a..645ca0ff357916c297c8304b2116f31ffedfeb07 100644 --- a/scripts/gen_docs.py +++ b/scripts/gen_docs.py @@ -151,6 +151,9 @@ def _generate_public_submodules_doc( pub_mod = {} for key, mod in module.modules.items(): if not key.startswith("_"): + if isinstance(mod, griffe.dataclasses.Alias): + key = f"{key} (alias re-export)" + mod = mod.target pub_mod[key] = generate_module_docs(mod, docdir) return pub_mod diff --git a/scripts/run_tests.sh b/scripts/run_tests.sh index 605532c2b91c068eccd759db8e77567069784cdd..b2cb7703c5a50df7a45d841d2775e498b94f3d40 100644 --- a/scripts/run_tests.sh +++ b/scripts/run_tests.sh @@ -203,7 +203,7 @@ run_torch13_tests() { ' echo "Installing torch 1.13 and its co-dependencies." 
TORCH_DEPS=$(pip freeze | grep -e torch -e opacus) - pip install .[torch1] + pip install "opacus == 1.4.0" "torch ~=1.13.0" if [[ $? -eq 0 ]]; then echo "Running unit tests for torch 1.13." command="pytest $@ @@ -211,6 +211,7 @@ run_torch13_tests() { test/model/test_torch_model.py " echo -e "\e[34m$command\e[0m" + $command status=$? else echo "\e[31mSkipping tests as installation failed.\e[0m" diff --git a/test/communication/test_utils.py b/test/communication/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..362a76d120b7b423c0d4a3a60e7af51c0ea3de6f --- /dev/null +++ b/test/communication/test_utils.py @@ -0,0 +1,190 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for message-type-verification utils.""" + +import dataclasses +from unittest import mock + +import pytest + +from declearn.communication.api import NetworkClient, NetworkServer +from declearn.communication.utils import ( + ErrorMessageException, + MessageTypeException, + verify_client_messages_validity, + verify_server_message_validity, +) +from declearn.messaging import Error, Message, SerializedMessage + + +@dataclasses.dataclass +class SimpleMessage(Message, register=False): # type: ignore[call-arg] + """Stub Message subclass for this module's unit tests.""" + + typekey = "simple" + + content: str + + +@pytest.mark.asyncio +async def test_verify_client_messages_validity_expected_simple(): + """Test 'verify_client_messages_validity' with valid messages.""" + # Setup simple messages and have the server except them. + netwk = mock.create_autospec(NetworkServer, instance=True) + messages = {f"client_{i}": SimpleMessage(f"message_{i}") for i in range(3)} + received = { + key: SerializedMessage(type(val), val.to_string().split("\n", 1)[1]) + for key, val in messages.items() + } + results = await verify_client_messages_validity( + netwk=netwk, received=received, expected=SimpleMessage + ) + # Assert that results match expectations, and no message was sent. + assert isinstance(results, dict) + assert results == messages + netwk.broadcast_message.assert_not_called() + + +@pytest.mark.asyncio +async def test_verify_client_messages_validity_expected_error(): + """Test 'verify_client_messages_validity' with expected Error messages.""" + # Setup simple messages and have the server except them. + netwk = mock.create_autospec(NetworkServer, instance=True) + messages = {f"client_{i}": Error(f"message_{i}") for i in range(3)} + received = { + key: SerializedMessage(type(val), val.to_string().split("\n", 1)[1]) + for key, val in messages.items() + } + results = await verify_client_messages_validity( + netwk=netwk, received=received, expected=Error + ) + # Assert that results match expectations, and no message was sent. 
+ assert isinstance(results, dict) + assert results == messages + netwk.broadcast_message.assert_not_called() + + +@pytest.mark.asyncio +async def test_verify_client_messages_validity_unexpected_types(): + """Test 'verify_client_messages_validity' with invalid messages.""" + # Setup simple messages, but have the server except Error messages. + netwk = mock.create_autospec(NetworkServer, instance=True) + messages = {f"client_{i}": SimpleMessage(f"message_{i}") for i in range(3)} + received = { + key: SerializedMessage(type(val), val.to_string().split("\n", 1)[1]) + for key, val in messages.items() + } + # Assert that an exception is raised. + with pytest.raises(MessageTypeException): + await verify_client_messages_validity( + netwk=netwk, received=received, expected=Error + ) + # Assert that an Error message was broadcast to all clients. + netwk.broadcast_message.assert_awaited_once_with( + message=Error(mock.ANY), clients=set(received) + ) + + +@pytest.mark.asyncio +async def test_verify_client_messages_validity_unexpected_error(): + """Test 'verify_client_messages_validity' with 'Error' messages.""" + # Setup simple messages, but have one be an Error. + netwk = mock.create_autospec(NetworkServer, instance=True) + messages = {f"client_{i}": SimpleMessage(f"message_{i}") for i in range(2)} + messages["client_2"] = Error("error_message") + received = { + key: SerializedMessage(type(val), val.to_string().split("\n", 1)[1]) + for key, val in messages.items() + } + # Assert that an exception is raised. + with pytest.raises(ErrorMessageException): + await verify_client_messages_validity( + netwk=netwk, received=received, expected=SimpleMessage + ) + # Assert that an Error message was broadcast to non-Error-sending clients. + netwk.broadcast_message.assert_awaited_once_with( + message=Error(mock.ANY), clients={"client_0", "client_1"} + ) + + +@pytest.mark.asyncio +async def test_verify_server_message_validity_expected_simple(): + """Test 'verify_server_message_validity' with a valid message.""" + # Setup a simple message matching client expectations. + netwk = mock.create_autospec(NetworkClient, instance=True) + message = SimpleMessage("message") + received = SerializedMessage( + SimpleMessage, message.to_string().split("\n", 1)[1] + ) + result = await verify_server_message_validity( + netwk=netwk, received=received, expected=SimpleMessage + ) + # Assert that results match expectations, and no message was sent. + assert isinstance(result, Message) + assert result == message + netwk.send_message.assert_not_called() + + +@pytest.mark.asyncio +async def test_verify_server_message_validity_expected_error(): + """Test 'verify_server_message_validity' with an expected Error message.""" + # Setup a simple message matching client expectations. + netwk = mock.create_autospec(NetworkClient, instance=True) + message = Error("message") + received = SerializedMessage(Error, message.to_string().split("\n", 1)[1]) + result = await verify_server_message_validity( + netwk=netwk, received=received, expected=Error + ) + # Assert that results match expectations, and no message was sent. + assert isinstance(result, Message) + assert result == message + netwk.send_message.assert_not_called() + + +@pytest.mark.asyncio +async def test_verify_server_message_validity_unexpected_type(): + """Test 'verify_server_message_validity' with an unexpected message.""" + # Setup a simple message, but have the client except an Error one. 
+ netwk = mock.create_autospec(NetworkClient, instance=True) + message = SimpleMessage("message") + received = SerializedMessage( + SimpleMessage, message.to_string().split("\n", 1)[1] + ) + # Assert that an exception is raised. + with pytest.raises(MessageTypeException): + await verify_server_message_validity( + netwk=netwk, received=received, expected=Error + ) + # Assert that an Error was sent to the server. + netwk.send_message.assert_awaited_once_with(message=Error(mock.ANY)) + + +@pytest.mark.asyncio +async def test_verify_server_message_validity_unexpected_error(): + """Test 'verify_server_message_validity' with an unexpected 'Error'.""" + # Setup an unexpected Error message. + netwk = mock.create_autospec(NetworkClient, instance=True) + message = Error("message") + received = SerializedMessage(Error, message.to_string().split("\n", 1)[1]) + # Assert that an exception is raised. + with pytest.raises(ErrorMessageException): + await verify_server_message_validity( + netwk=netwk, received=received, expected=SimpleMessage + ) + # Assert that no Error was sent to the server. + netwk.send_message.assert_not_called() diff --git a/test/dataset/dataset_testbase.py b/test/dataset/dataset_testbase.py index 7dc36b80c52fb4cf863a2de0c83f142510cfb8c9..818df3f3286c8bfcaf89224b950fca2bcfb9bdbb 100644 --- a/test/dataset/dataset_testbase.py +++ b/test/dataset/dataset_testbase.py @@ -27,7 +27,6 @@ from declearn.test_utils import assert_batch_equal, to_numpy class DatasetTestToolbox: - """TestCase fixture-provider protocol.""" # pylint: disable=too-few-public-methods @@ -46,8 +45,7 @@ class DatasetTestToolbox: class DatasetTestSuite: - - """Base tests for declearn Dataset abstract methods""" + """Base tests for declearn Dataset abstract methods.""" def test_generate_batches_batchsize(self, toolbox: DatasetTestToolbox): """Test batch_size argument to test_generate_batches method""" diff --git a/test/dataset/test_torch_dataset.py b/test/dataset/test_torch_dataset.py index 1c8e8ce1feda176907b212166c4a3d93a9dc7a00..8af2fa4a43f83302c7faf3faff674f96260487cf 100644 --- a/test/dataset/test_torch_dataset.py +++ b/test/dataset/test_torch_dataset.py @@ -43,8 +43,7 @@ SEED = 0 class CustomDataset(torch.utils.data.Dataset): - - """Basic torch.utils.data.Dataset for testing purposes""" + """Basic torch.utils.data.Dataset for testing purposes.""" def __init__(self, inputs, labels, weights) -> None: self.inputs = inputs @@ -66,8 +65,7 @@ class CustomDataset(torch.utils.data.Dataset): class TorchDatasetTestToolbox(DatasetTestToolbox): - - """Toolbox for Torch Dataset""" + """Toolbox for Torch Dataset.""" # pylint: disable=too-few-public-methods diff --git a/test/functional/test_main.py b/test/functional/test_main.py index 7fc21958d5b0bcb0502e7499a25017fa141e8521..491ffd95d3628ec6408e13b31b23cc87c06315e2 100644 --- a/test/functional/test_main.py +++ b/test/functional/test_main.py @@ -42,10 +42,12 @@ from declearn.utils import set_device_policy # pylint: disable=ungrouped-imports FRAMEWORKS = ["Sksgd", "Tflow", "Torch"] try: - import tensorflow as tf # type: ignore + import tensorflow # type: ignore # pylint: disable=unused-import except ModuleNotFoundError: FRAMEWORKS.remove("Tflow") else: + # pylint: disable=import-error,no-name-in-module + import tensorflow.keras as tf_keras # type: ignore from declearn.model.tensorflow import TensorflowModel try: import torch @@ -103,23 +105,23 @@ class DeclearnTestCase: ) -> Model: """Return a TensorflowModel suitable for the learning task.""" if self.kind == "Reg": - output_layer = 
tf.keras.layers.Dense(1) + output_layer = tf_keras.layers.Dense(1) loss = "mse" elif self.kind == "Bin": - output_layer = tf.keras.layers.Dense(1, activation="sigmoid") + output_layer = tf_keras.layers.Dense(1, activation="sigmoid") loss = "binary_crossentropy" elif self.kind == "Clf": - output_layer = tf.keras.layers.Dense(4, activation="softmax") + output_layer = tf_keras.layers.Dense(4, activation="softmax") loss = "sparse_categorical_crossentropy" else: raise ValueError("Invalid 'kind' attribute.") stack = [ - tf.keras.layers.InputLayer((32,)), - tf.keras.layers.Dense(32, activation="relu"), - tf.keras.layers.Dense(8, activation="relu"), + tf_keras.layers.InputLayer((32,)), + tf_keras.layers.Dense(32, activation="relu"), + tf_keras.layers.Dense(8, activation="relu"), output_layer, ] - model = tf.keras.Sequential(stack) + model = tf_keras.Sequential(stack) return TensorflowModel(model, loss, metrics=None) def _build_torch_model( diff --git a/test/main/test_checkpoint.py b/test/main/test_checkpoint.py index 03ee3a2c42f31cfb379bd72f3b425b5562421791..61c26f17ddb10b21eca2a796210fcd3171eb540b 100644 --- a/test/main/test_checkpoint.py +++ b/test/main/test_checkpoint.py @@ -89,8 +89,7 @@ def create_config_file(checkpointer: Checkpointer, type_obj: str) -> str: class TestCheckpointer: - - """Unit tests for Checkpointer class""" + """Unit tests for Checkpointer class.""" def test_init_default(self, tmp_path: str) -> None: """Test `Checkpointer.__init__` with `max_history=None`.""" diff --git a/test/metrics/test_binary_roc.py b/test/metrics/test_binary_roc.py index c1a290743f007ee1c1dcd4ee8f65f4db18deabaa..5010f712bfed2cf8ecd9ac3c2a15fa5f48cd33c0 100644 --- a/test/metrics/test_binary_roc.py +++ b/test/metrics/test_binary_roc.py @@ -78,13 +78,11 @@ def test_case_fixture( ) -def _test_case_1d() -> ( - Tuple[ - Dict[str, np.ndarray], - Dict[str, Union[float, np.ndarray]], - Dict[str, Union[float, np.ndarray]], - ] -): +def _test_case_1d() -> Tuple[ + Dict[str, np.ndarray], + Dict[str, Union[float, np.ndarray]], + Dict[str, Union[float, np.ndarray]], +]: """Return a test case with 1-D samples (standard binary classif).""" # similar inputs as for Binary APR; pylint: disable=duplicate-code inputs = { @@ -124,13 +122,11 @@ def _test_case_1d() -> ( return inputs, states, scores -def _test_case_2d() -> ( - Tuple[ - Dict[str, np.ndarray], - Dict[str, Union[float, np.ndarray]], - Dict[str, Union[float, np.ndarray]], - ] -): +def _test_case_2d() -> Tuple[ + Dict[str, np.ndarray], + Dict[str, Union[float, np.ndarray]], + Dict[str, Union[float, np.ndarray]], +]: """Return a test case with 2-D samples (multilabel binary classif).""" # similar inputs as for Binary APR; pylint: disable=duplicate-code inputs = { diff --git a/test/model/model_testing.py b/test/model/model_testing.py index 4806e0df541368272f737b8a9669b16de94761a4..76cf45834469b95a82cfc5472221972f67e89e06 100644 --- a/test/model/model_testing.py +++ b/test/model/model_testing.py @@ -17,13 +17,14 @@ """Shared testing code for TensorFlow and Torch models' unit tests.""" +import copy import json from typing import Any, Generic, List, Protocol, Tuple, Type, TypeVar, Union import numpy as np from declearn.model.api import Model, Vector -from declearn.test_utils import to_numpy +from declearn.test_utils import assert_json_serializable_dict, to_numpy from declearn.typing import Batch from declearn.utils import json_pack, json_unpack @@ -59,14 +60,23 @@ class ModelTestCase(Protocol, Generic[VectorT]): class ModelTestSuite: """Unit tests for a declearn 
Model.""" - def test_serialization( + def test_get_config( self, test_case: ModelTestCase, ) -> None: - """Check that the model can be JSON-(de)serialized properly.""" + """Check that the model's config is JSON-serializable.""" model = test_case.model - config = json.dumps(model.get_config()) - other = model.from_config(json.loads(config)) + config = model.get_config() + assert_json_serializable_dict(config) + + def test_from_config( + self, + test_case: ModelTestCase, + ) -> None: + """Check that the model can be instantiated from its config.""" + model = test_case.model + config = model.get_config() + other = model.from_config(copy.deepcopy(config)) assert model.get_config() == other.get_config() assert model.device_policy == other.device_policy diff --git a/test/model/test_haiku_model.py b/test/model/test_haiku_model.py index 7bb9ca7b014fe2579743ece0f6cdba240a41b8a1..7e31f04854c36508a9619dd59c62295981bc1c0b 100644 --- a/test/model/test_haiku_model.py +++ b/test/model/test_haiku_model.py @@ -234,11 +234,18 @@ class TestHaikuModel(ModelTestSuite): """Unit tests for declearn.model.tensorflow.TensorflowModel.""" @pytest.mark.filterwarnings("ignore: Our custom Haiku serialization") - def test_serialization( + def test_get_config( self, test_case: ModelTestCase, ) -> None: - super().test_serialization(test_case) + super().test_get_config(test_case) + + @pytest.mark.filterwarnings("ignore: Our custom Haiku serialization") + def test_from_config( + self, + test_case: ModelTestCase, + ) -> None: + super().test_from_config(test_case) @pytest.mark.parametrize( "criterion_type", ["names", "pytree", "predicate"] diff --git a/test/model/test_sksgd_model.py b/test/model/test_sksgd_model.py index 35f1b23804a018acaa1aa27451807edc9281eed8..f88de8a7267bf1f8d97184198ab459c98f2aefef 100644 --- a/test/model/test_sksgd_model.py +++ b/test/model/test_sksgd_model.py @@ -120,15 +120,6 @@ def fixture_test_case( class TestSklearnSGDModel(ModelTestSuite): """Unit tests for declearn.model.sklearn.SklearnSGDModel.""" - def test_serialization( # type: ignore # Liskov does not matter here - self, - test_case: SklearnSGDTestCase, - ) -> None: - # Avoid re-running tests that are unaltered by data parameters. 
- if test_case.s_weights or test_case.as_sparse: - return None - return super().test_serialization(test_case) - def test_initialization( self, test_case: SklearnSGDTestCase, diff --git a/test/model/test_tflow_model.py b/test/model/test_tflow_model.py index b24e083a114109ebe8b7d879a5b903e0f9a837e7..65805b5fdde68ef61ecb46d2c31c18005b101086 100644 --- a/test/model/test_tflow_model.py +++ b/test/model/test_tflow_model.py @@ -29,6 +29,8 @@ try: import tensorflow as tf # type: ignore except ModuleNotFoundError: pytest.skip("TensorFlow is unavailable", allow_module_level=True) +else: + import tensorflow.keras as tf_keras # type: ignore from declearn.model.tensorflow import TensorflowModel, TensorflowVector from declearn.model.tensorflow.utils import build_keras_loss @@ -106,32 +108,32 @@ class TensorflowTestCase(ModelTestCase): """Suited toy binary-classification keras model.""" if self.kind.startswith("MLP"): stack = [ - tf.keras.layers.Dense(32, activation="relu"), - tf.keras.layers.Dense(16, activation="relu"), - tf.keras.layers.Dense(1, activation="sigmoid"), + tf_keras.layers.Dense(32, activation="relu"), + tf_keras.layers.Dense(16, activation="relu"), + tf_keras.layers.Dense(1, activation="sigmoid"), ] shape = [None, 64] if self.kind == "MLP-tune": stack[0].trainable = False elif self.kind == "RNN": stack = [ - tf.keras.layers.Embedding(100, 32), - tf.keras.layers.LSTM(16, activation="tanh"), - tf.keras.layers.Dense(1, activation="sigmoid"), + tf_keras.layers.Embedding(100, 32), + tf_keras.layers.LSTM(16, activation="tanh"), + tf_keras.layers.Dense(1, activation="sigmoid"), ] shape = [None, 128] elif self.kind == "CNN": cnn_kwargs = {"padding": "same", "activation": "relu"} stack = [ - tf.keras.layers.Conv2D(32, 7, **cnn_kwargs), - tf.keras.layers.MaxPool2D((8, 8)), - tf.keras.layers.Conv2D(16, 5, **cnn_kwargs), - tf.keras.layers.AveragePooling2D((8, 8)), - tf.keras.layers.Reshape((16,)), - tf.keras.layers.Dense(1, activation="sigmoid"), + tf_keras.layers.Conv2D(32, 7, **cnn_kwargs), + tf_keras.layers.MaxPool2D((8, 8)), + tf_keras.layers.Conv2D(16, 5, **cnn_kwargs), + tf_keras.layers.AveragePooling2D((8, 8)), + tf_keras.layers.Reshape((16,)), + tf_keras.layers.Dense(1, activation="sigmoid"), ] shape = [None, 64, 64, 3] - tfmod = tf.keras.Sequential(stack) + tfmod = tf_keras.Sequential(stack) tfmod.build(shape) # as model is built, no data_info is required return TensorflowModel(tfmod, loss="binary_crossentropy", metrics=None) @@ -173,14 +175,23 @@ class TestTensorflowModel(ModelTestSuite): test_case: ModelTestCase, ) -> None: """Check that `get_weights` behaves properly with frozen weights.""" + # Set up a model with a frozen layer. model = test_case.model tfmod = model.get_wrapped_model() tfmod.layers[0].trainable = False # freeze the first layer's weights - w_all = model.get_weights() - w_trn = model.get_weights(trainable=True) - assert set(w_trn.coefs).issubset(w_all.coefs) # check on keys - assert w_trn.coefs.keys() == {v.name for v in tfmod.trainable_weights} - assert w_all.coefs.keys() == {v.name for v in tfmod.weights} + # Access names of the model's variables (via TensorFlow/Keras API). + if hasattr(tf_keras, "version") and tf_keras.version().startswith("3"): + names_all_wghts = {v.value.name for v in tfmod.weights} + names_trainable = {v.value.name for v in tfmod.trainable_weights} + else: + names_all_wghts = {v.name for v in tfmod.weights} + names_trainable = {v.name for v in tfmod.trainable_weights} + # Verify that DecLearn-accessed weights' names match. 
+ weights_all = model.get_weights() + weights_trn = model.get_weights(trainable=True) + assert set(weights_trn.coefs).issubset(weights_all.coefs) + assert weights_all.coefs.keys() == names_all_wghts + assert weights_trn.coefs.keys() == names_trainable def test_set_frozen_weights( self, @@ -190,7 +201,7 @@ class TestTensorflowModel(ModelTestSuite): # Setup a model with some frozen weights, and gather trainable ones. model = test_case.model tfmod = model.get_wrapped_model() - tfmod.layers[0].trainable = False # freeze the first layer's weights + tfmod.layers[-1].trainable = False # freeze the last layer's weights w_trn = model.get_weights(trainable=True) # Test that `set_weights` works if and only if properly parametrized. with pytest.raises(KeyError): @@ -210,7 +221,11 @@ class TestTensorflowModel(ModelTestSuite): assert policy.idx == 0 tfmod = model.get_wrapped_model() device = f"{test_case.device}:0" - for var in tfmod.weights: + if hasattr(tf_keras, "version") and tf_keras.version().startswith("3"): + variables = [var.value for var in tfmod.weights] + else: + variables = tfmod.weights + for var in variables: assert var.device.endswith(device) @@ -220,48 +235,50 @@ class TestBuildKerasLoss: def test_build_keras_loss_from_string_class_name(self) -> None: """Test `build_keras_loss` with a valid class name string input.""" loss = build_keras_loss( - "BinaryCrossentropy", tf.keras.losses.Reduction.SUM + "BinaryCrossentropy", tf_keras.losses.Reduction.SUM ) - assert isinstance(loss, tf.keras.losses.BinaryCrossentropy) - assert loss.reduction == tf.keras.losses.Reduction.SUM + assert isinstance(loss, tf_keras.losses.BinaryCrossentropy) + assert loss.reduction == tf_keras.losses.Reduction.SUM def test_build_keras_loss_from_string_function_name(self) -> None: """Test `build_keras_loss` with a valid function name string input.""" loss = build_keras_loss( - "binary_crossentropy", tf.keras.losses.Reduction.SUM + "binary_crossentropy", tf_keras.losses.Reduction.SUM ) - assert isinstance(loss, tf.keras.losses.BinaryCrossentropy) - assert loss.reduction == tf.keras.losses.Reduction.SUM + assert isinstance(loss, tf_keras.losses.BinaryCrossentropy) + assert loss.reduction == tf_keras.losses.Reduction.SUM def test_build_keras_loss_from_string_noclass_function_name(self) -> None: """Test `build_keras_loss` with a valid function name string input.""" - loss = build_keras_loss("mse", tf.keras.losses.Reduction.SUM) - assert isinstance(loss, tf.keras.losses.Loss) + if hasattr(tf_keras, "version") and tf_keras.version().startswith("3"): + pytest.skip("Skipping test that no longer works with Keras 3.") + loss = build_keras_loss("mse", tf_keras.losses.Reduction.SUM) + assert isinstance(loss, tf_keras.losses.Loss) assert hasattr(loss, "loss_fn") - assert loss.loss_fn is tf.keras.losses.mse - assert loss.reduction == tf.keras.losses.Reduction.SUM + assert loss.loss_fn is tf_keras.losses.mse + assert loss.reduction == tf_keras.losses.Reduction.SUM def test_build_keras_loss_from_loss_instance(self) -> None: """Test `build_keras_loss` with a valid keras Loss input.""" # Set up a BinaryCrossentropy loss instance. - loss = tf.keras.losses.BinaryCrossentropy( - reduction=tf.keras.losses.Reduction.SUM + loss = tf_keras.losses.BinaryCrossentropy( + reduction=tf_keras.losses.Reduction.SUM ) - assert loss.reduction == tf.keras.losses.Reduction.SUM + assert loss.reduction == tf_keras.losses.Reduction.SUM # Pass it through the util and verify that reduction changes. 
- loss = build_keras_loss(loss, tf.keras.losses.Reduction.NONE) - assert isinstance(loss, tf.keras.losses.BinaryCrossentropy) - assert loss.reduction == tf.keras.losses.Reduction.NONE + loss = build_keras_loss(loss, tf_keras.losses.Reduction.NONE) + assert isinstance(loss, tf_keras.losses.BinaryCrossentropy) + assert loss.reduction == tf_keras.losses.Reduction.NONE def test_build_keras_loss_from_loss_function(self) -> None: """Test `build_keras_loss` with a valid keras loss function input.""" loss = build_keras_loss( - tf.keras.losses.binary_crossentropy, tf.keras.losses.Reduction.SUM + tf_keras.losses.binary_crossentropy, tf_keras.losses.Reduction.SUM ) - assert isinstance(loss, tf.keras.losses.Loss) + assert isinstance(loss, tf_keras.losses.Loss) assert hasattr(loss, "loss_fn") - assert loss.loss_fn is tf.keras.losses.binary_crossentropy - assert loss.reduction == tf.keras.losses.Reduction.SUM + assert loss.loss_fn is tf_keras.losses.binary_crossentropy + assert loss.reduction == tf_keras.losses.Reduction.SUM def test_build_keras_loss_from_custom_function(self) -> None: """Test `build_keras_loss` with a valid custom loss function input.""" @@ -270,8 +287,8 @@ class TestBuildKerasLoss: """Custom loss function.""" return tf.reduce_sum(tf.cast(tf.equal(y_true, y_pred), tf.float32)) - loss = build_keras_loss(loss_fn, tf.keras.losses.Reduction.SUM) - assert isinstance(loss, tf.keras.losses.Loss) + loss = build_keras_loss(loss_fn, tf_keras.losses.Reduction.SUM) + assert isinstance(loss, tf_keras.losses.Loss) assert hasattr(loss, "loss_fn") assert loss.loss_fn is loss_fn - assert loss.reduction == tf.keras.losses.Reduction.SUM + assert loss.reduction == tf_keras.losses.Reduction.SUM diff --git a/test/model/test_torch_model.py b/test/model/test_torch_model.py index 74419484009f347e28886d2244837aea1d40af90..725e81c863da0614319620c110f0b2684120e361 100644 --- a/test/model/test_torch_model.py +++ b/test/model/test_torch_model.py @@ -17,7 +17,7 @@ """Unit tests for TorchModel.""" -import json +import copy import os import typing from typing import List, Literal, Tuple @@ -219,7 +219,7 @@ class TestTorchModel(ModelTestSuite): """Unit tests for declearn.model.torch.TorchModel.""" @pytest.mark.filterwarnings("ignore: PyTorch JSON serialization") - def test_serialization( + def test_get_config( self, test_case: ModelTestCase, ) -> None: @@ -228,26 +228,44 @@ class TestTorchModel(ModelTestSuite): # due to the (de)serialization of a custom nn.Module # the expected model behaviour is, however, correct try: - self._test_serialization(test_case) + super().test_get_config(test_case) except AssertionError: pytest.skip( "skipping failed test due to custom nn.Module pickling" ) - self._test_serialization(test_case) + super().test_get_config(test_case) - def _test_serialization( + @pytest.mark.filterwarnings("ignore: PyTorch JSON serialization") + def test_from_config( + self, + test_case: ModelTestCase, + ) -> None: + if getattr(test_case, "kind", "") == "RNN": + # NOTE: this test fails on python 3.8 but succeeds in 3.10 + # due to the (de)serialization of a custom nn.Module + # the expected model behaviour is, however, correct + try: + self._test_from_config(test_case) + except AssertionError: + pytest.skip( + "skipping failed test due to custom nn.Module pickling" + ) + self._test_from_config(test_case) + + def _test_from_config( self, test_case: ModelTestCase, ) -> None: - """Check that the model can be JSON-(de)serialized properly. + """Check that the model can be instantiated from its config. 
- This method replaces the parent `test_serialization` one. + This method replaces the parent `test_from_config` one. """ # Same setup as in parent test: a model and a config-based other. model = test_case.model - config = json.dumps(model.get_config()) - other = model.from_config(json.loads(config)) - # Verify that both models have the same device policy. + config = model.get_config() + other = model.from_config(copy.deepcopy(config)) + # Verify that both models have the same config and device policy. + assert other.get_config() == config assert model.device_policy == other.device_policy # Verify that both models have a similar structure of modules. mod_a = list(model.get_wrapped_model().modules()) diff --git a/test/optimizer/test_tflow_optim.py b/test/optimizer/test_tflow_optim.py index 81d3346aadd791e41977c997130a88ba8f74aa6d..7658fed03cd42216fe752f4c696d8a52a2755b07 100644 --- a/test/optimizer/test_tflow_optim.py +++ b/test/optimizer/test_tflow_optim.py @@ -31,6 +31,8 @@ try: import tensorflow as tf # type: ignore except ModuleNotFoundError: pytest.skip("TensorFlow is unavailable", allow_module_level=True) +else: + import tensorflow.keras as tf_keras # type: ignore # pylint: enable=duplicate-code from declearn.model.tensorflow import TensorflowOptiModule, TensorflowVector @@ -201,7 +203,11 @@ class TestTensorflowOptiModule(OptiModuleTestSuite): grads = GradientsTestCase("tensorflow").mock_gradient updts = module.run(grads) # Assert that the outputs and internal states are properly placed. + if hasattr(tf_keras, "version") and tf_keras.version().startswith("3"): + optimizer_variables = [var.value for var in module.optim.variables] + else: + optimizer_variables = module.optim.variables() assert all(device in t.device for t in updts.coefs.values()) - assert all(device in t.device for t in module.optim.variables()) + assert all(device in t.device for t in optimizer_variables) # Reset device policy to run other tests on CPU as expected. set_device_policy(gpu=False)
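Note on the Keras 2 / Keras 3 compatibility pattern applied throughout the test changes above: each updated test checks `hasattr(tf_keras, "version") and tf_keras.version().startswith("3")` and, under Keras 3, reaches the backend variable through `Variable.value` before reading its `name` or `device`. The sketch below merely factors that repeated inline pattern into helpers for illustration; the helper names (`is_keras_v3`, `variable_names`) are hypothetical and are not part of DecLearn or of the diff above.

    from typing import Iterable, Set

    import tensorflow.keras as tf_keras  # type: ignore


    def is_keras_v3() -> bool:
        """Return whether the Keras package in use is Keras 3."""
        # Keras 3 exposes a module-level `version()` function, which
        # legacy Keras 2 (tf.keras) does not provide.
        return hasattr(tf_keras, "version") and tf_keras.version().startswith("3")


    def variable_names(variables: Iterable) -> Set[str]:
        """Gather variable names consistently across Keras 2 and Keras 3."""
        # Keras 3 wraps backend variables: the name that matches DecLearn's
        # records is found on `Variable.value`; Keras 2 exposes it as `.name`.
        if is_keras_v3():
            return {var.value.name for var in variables}
        return {var.name for var in variables}

With such helpers, assertions like those in `test_get_frozen_weights` could read, for instance, `assert weights_all.coefs.keys() == variable_names(tfmod.weights)`; the device and optimizer-variable checks above apply the same version test before dereferencing `.value`.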