diff --git a/AUTHORS b/AUTHORS
index b37d13179802baadb2c7884a7134c8c5cf1ec729..7815f8784177617514f4a770847189d29037a562 100644
--- a/AUTHORS
+++ b/AUTHORS
@@ -1,7 +1,7 @@
 This file maintains the list of present and past declearn authors.
 A secondary file listing punctual open-source contributors may complement it.
 
-Declearn 2.4 - 2.5
+Declearn 2.4 - 2.6
 - Paul Andrey
 
 Declearn 2.1 - 2.3
diff --git a/declearn/__init__.py b/declearn/__init__.py
index 0e52c31b2cab3e2cb4cf6e978c54270d9a90095a..287df3c4557a4b485f7e8b2093735fd518b8b70e 100644
--- a/declearn/__init__.py
+++ b/declearn/__init__.py
@@ -15,15 +15,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Declearn - a python package for private decentralized learning.
+"""DecLearn - a modular and extensible framework for Federated Learning.
 
-Declearn is a modular framework to set up and run federated learning
+DecLearn is a modular framework to set up and run federated learning
 processes. It is being developed by the MAGNET team of INRIA Lille,
 with the aim of providing users with a modular and extensible framework
 to implement federated learning algorithms and apply them to real-world
 (or simulated) data using any common machine learning framework.
 
-Declearn provides with abstractions that enable algorithms to be written
+DecLearn provides abstractions that enable algorithms to be written
 agnostic to the actual computation framework as well as with workable
 interfaces that cover some of the most popular frameworks, such as
 Scikit-Learn, TensorFlow and PyTorch.
diff --git a/declearn/aggregator/_api.py b/declearn/aggregator/_api.py
index 468b837b520c2cdd0e86a8d6b644e381c341ff9c..56813d50f9abb6452c72f19f76cf564b44de168b 100644
--- a/declearn/aggregator/_api.py
+++ b/declearn/aggregator/_api.py
@@ -19,7 +19,6 @@
 
 import abc
 import dataclasses
-import warnings
 from typing import Any, ClassVar, Dict, Generic, Type, TypeVar, Union
 
 from typing_extensions import Self  # future: import from typing (py >=3.11)
@@ -185,48 +184,6 @@ class Aggregator(Generic[ModelUpdatesT], metaclass=abc.ABCMeta):
         """Instantiate an Aggregator from its configuration dict."""
         return cls(**config)
 
-    def aggregate(
-        self,
-        updates: Dict[str, Vector[T]],
-        n_steps: Dict[str, int],  # revise: abstract~generalize kwargs use
-    ) -> Vector[T]:
-        """DEPRECATED - Aggregate input vectors into a single one.
-
-        Parameters
-        ----------
-        updates: dict[str, Vector]
-            Client-wise updates, as a dictionary with clients' names as
-            string keys and updates as Vector values.
-        n_steps: dict[str, int]
-            Client-wise number of local training steps performed during
-            the training round having produced the updates.
-
-        Returns
-        -------
-        gradients: Vector
-            Aggregated updates, as a Vector - treated as gradients by
-            the server-side optimizer.
-
-        Raises
-        ------
-        TypeError
-            If the input `updates` are an empty dict.
-        """
-        warnings.warn(
-            "'Aggregator.aggregate' was deprecated in DecLearn v2.4 in favor "
-            "of new API methods. It will be removed in DecLearn v2.6 and/or "
-            "v3.0.",
-            DeprecationWarning,
-        )
-        if not updates:
-            raise TypeError("'Aggregator.aggregate' received an empty dict.")
-        partials = [
-            self.prepare_for_sharing(updates[client], n_steps[client])
-            for client in updates
-        ]
-        aggregated = sum(partials[1:], start=partials[0])
-        return self.finalize_updates(aggregated)
-
 
 def list_aggregators() -> Dict[str, Type[Aggregator]]:
     """Return a mapping of registered Aggregator subclasses.
diff --git a/declearn/aggregator/_avg.py b/declearn/aggregator/_avg.py index 05dfa5b40edc04b82e60c3a702b833a657a741fb..333b4b9d77543ce896294b97833fe5eb7063698e 100644 --- a/declearn/aggregator/_avg.py +++ b/declearn/aggregator/_avg.py @@ -17,8 +17,7 @@ """FedAvg-like mean-aggregation class.""" -import warnings -from typing import Any, Dict, Optional +from typing import Any, Dict from declearn.aggregator._api import Aggregator, ModelUpdates @@ -44,7 +43,6 @@ class AveragingAggregator(Aggregator[ModelUpdates]): def __init__( self, steps_weighted: bool = True, - client_weights: Optional[Dict[str, float]] = None, ) -> None: """Instantiate an averaging aggregator. @@ -53,37 +51,14 @@ class AveragingAggregator(Aggregator[ModelUpdates]): steps_weighted: Whether to conduct a weighted averaging of local model updates based on local numbers of training steps. - client_weights: - DEPRECATED - this argument no longer affects computations, - save when using the deprecated 'aggregate' method. - Optional dict of client-wise base weights to use. - If None, homogeneous base weights are used. - - Notes - ----- - * One may specify `client_weights` and use `steps_weighted=True`. - In that case, the product of the client's base weight and their - number of training steps taken will be used (and unit-normed). - * One may use incomplete `client_weights`. In that case, unknown- - clients' base weights will be set to 1. """ self.steps_weighted = steps_weighted - self.client_weights = client_weights or {} - if client_weights: # pragma: no cover - warnings.warn( - f"'client_weights' argument to '{self.__class__.__name__}' was" - " deprecated in DecLearn v2.4 and is no longer used, saved by" - " the deprecated 'aggregate' method. It will be removed in" - " DecLearn v2.6 and/or v3.0.", - DeprecationWarning, - ) def get_config( self, ) -> Dict[str, Any]: return { "steps_weighted": self.steps_weighted, - "client_weights": self.client_weights, } def prepare_for_sharing( @@ -103,63 +78,3 @@ class AveragingAggregator(Aggregator[ModelUpdates]): updates: ModelUpdates, ) -> Vector: return updates.updates / updates.weights - - def aggregate( - self, - updates: Dict[str, Vector], - n_steps: Dict[str, int], - ) -> Vector: - # Make use of 'client_weights' as part of this DEPRECATED method. - with warnings.catch_warnings(): - warnings.simplefilter(action="ignore", category=DeprecationWarning) - weights = self.compute_client_weights(updates, n_steps) - steps_weighted = self.steps_weighted - try: - self.steps_weighted = True - return super().aggregate(updates, weights) # type: ignore - finally: - self.steps_weighted = steps_weighted - - def compute_client_weights( - self, - updates: Dict[str, Vector], - n_steps: Dict[str, int], - ) -> Dict[str, float]: - """Compute weights to use when averaging a given set of updates. - - This method is DEPRECATED as of DecLearn v2.4. - It will be removed in DecLearn 2.6 and/or 3.0. - - Parameters - ---------- - updates: dict[str, Vector] - Client-wise updates, as a dictionary with clients' names as - string keys and updates as Vector values. - n_steps: dict[str, int] - Client-wise number of local training steps performed during - the training round having produced the updates. - - Returns - ------- - weights: dict[str, float] - Client-wise updates-averaging weights, suited to the input - parameters and normalized so that they sum to 1. - """ - warnings.warn( - f"'{self.__class__.__name__}.compute_client_weights' was" - " deprecated in DecLearn v2.4. 
It will be removed in DecLearn" - " v2.6 and/or v3.0.", - DeprecationWarning, - ) - if self.steps_weighted: - weights = { - client: steps * self.client_weights.get(client, 1.0) - for client, steps in n_steps.items() - } - else: - weights = { - client: self.client_weights.get(client, 1.0) - for client in updates - } - total = sum(weights.values()) - return {client: weight / total for client, weight in weights.items()} diff --git a/declearn/aggregator/_gma.py b/declearn/aggregator/_gma.py index dfdcc1d70f99f1753c90d29ba856c45f7ade4c0f..4832d4a69aa73e0d6607c131633df14339531a3c 100644 --- a/declearn/aggregator/_gma.py +++ b/declearn/aggregator/_gma.py @@ -18,7 +18,6 @@ """Gradient Masked Averaging aggregation class.""" import dataclasses -import warnings from typing import Any, Dict, Optional, Tuple from typing_extensions import Self # future: import from typing (py >=3.11) @@ -99,7 +98,6 @@ class GradientMaskedAveraging(Aggregator[GMAModelUpdates]): self, threshold: float = 1.0, steps_weighted: bool = True, - client_weights: Optional[Dict[str, float]] = None, ) -> None: """Instantiate a gradient masked averaging aggregator. @@ -111,20 +109,9 @@ class GradientMaskedAveraging(Aggregator[GMAModelUpdates]): steps_weighted: bool, default=True Whether to weight updates based on the number of optimization steps taken by the clients (relative to one another). - client_weights: dict[str, float] or None, default=None - Optional dict of client-wise base weights to use. - If None, homogeneous base weights are used. - - Notes - ----- - * One may specify `client_weights` and use `steps_weighted=True`. - In that case, the product of the client's base weight and their - number of training steps taken will be used (and unit-normed). - * One may use incomplete `client_weights`. In that case, unknown- - clients' base weights will be set to 1. """ self.threshold = threshold - self._avg = AveragingAggregator(steps_weighted, client_weights) + self._avg = AveragingAggregator(steps_weighted) def get_config( self, @@ -162,24 +149,3 @@ class GradientMaskedAveraging(Aggregator[GMAModelUpdates]): scores = (1 - clip) * scores + clip # s = 1 if s > t else s # Correct outputs' magnitude and return them. return values * scores - - def compute_client_weights( # pragma: no cover - self, - updates: Dict[str, Vector], - n_steps: Dict[str, int], - ) -> Dict[str, float]: - """Compute weights to use when averaging a given set of updates. - - This method is DEPRECATED as of DecLearn v2.4. - It will be removed in DecLearn 2.6 and/or 3.0. - """ - # pylint: disable=duplicate-code - warnings.warn( - f"'{self.__class__.__name__}.compute_client_weights' was" - " deprecated in DecLearn v2.4. It will be removed in DecLearn" - " v2.6 and/or v3.0.", - DeprecationWarning, - ) - with warnings.catch_warnings(): - warnings.simplefilter(action="ignore", category=DeprecationWarning) - return self._avg.compute_client_weights(updates, n_steps) diff --git a/declearn/communication/__init__.py b/declearn/communication/__init__.py index 93a508e0e6e2ab8cedadae70efd3a24be1e8fd9b..5f15e89b754edbb12817387787a00c53571a3e20 100644 --- a/declearn/communication/__init__.py +++ b/declearn/communication/__init__.py @@ -28,7 +28,6 @@ This module contains the following core submodules: * [utils][declearn.communication.utils]: Utils related to network communication endpoints' setup and usage. 
- It re-exports publicly from `utils` the following elements: * [build_client][declearn.communication.build_client]: @@ -52,10 +51,6 @@ the associated third-party dependencies are available: * [websockets][declearn.communication.websockets]: WebSockets-based network communication endpoints. Requires the `websockets` third-party package. - -Additionnally, for retro-compatibility purposes, it exports the DEPRECATED -[messaging][declearn.communication.messaging] submodule, that should no -longer be used, as its contents were re-dispatched elsewhere in DecLearn. """ # Messaging API and base tools: @@ -79,6 +74,3 @@ try: from . import websockets except ImportError: # pragma: no cover _INSTALLABLE_BACKENDS["websockets"] = ("websockets",) - -# DEPRECATED submodule, kept for retro-compatibility until 2.6 and/or 3.0. -from . import messaging diff --git a/declearn/communication/api/_client.py b/declearn/communication/api/_client.py index fa5f1648a642ba16fb6f0de3f78ef67efe3145af..2482d315267c8f167525bc4f71083175c10641e7 100644 --- a/declearn/communication/api/_client.py +++ b/declearn/communication/api/_client.py @@ -339,20 +339,3 @@ class NetworkClient(metaclass=abc.ABCMeta): ) self.logger.critical(error) raise TypeError(error) - - async def check_message( - self, - timeout: Optional[float] = None, - ) -> SerializedMessage: - """Await a message from the server, with optional timeout. - - This method is DEPRECATED in favor of the `recv_message` one. - It acts as an alias and will be removed in v2.6 and/or v3.0. - """ - warnings.warn( - "'NetworkServer.check_message' was renamed as 'recv_message' " - "in DecLearn 2.4. It now acts as an alias, but will be removed " - "in version 2.6 and/or 3.0.", - DeprecationWarning, - ) - return await self.recv_message(timeout) diff --git a/declearn/communication/messaging/__init__.py b/declearn/communication/messaging/__init__.py deleted file mode 100644 index 510fe6d6fad88992296c6fcd649d895d5cddbc23..0000000000000000000000000000000000000000 --- a/declearn/communication/messaging/__init__.py +++ /dev/null @@ -1,57 +0,0 @@ -# coding: utf-8 - -# Copyright 2023 Inria (Institut National de Recherche en Informatique -# et Automatique) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""DEPRECATED submodule defining messaging containers and flags. - -This submodule was deprecated in DecLearn 2.4 in favor of `declearn.messaging`. -It should no longer be used, and will be removed in 2.6 and/or 3.0. - -Most of its contents are re-exports of non-deprecated classes and functions -from 'declearn.messaging'. Others will trigger deprecation warnings (and may -cause failures) if used. 
- -Deprecated classes uniquely-defined here are: - -* [Empty][declearn.communication.messaging.Empty] -* [GetMessageRequest][declearn.communication.messaging.GetMessageRequest] -* [JoinReply][declearn.communication.messaging.JoinReply] -* [JoinRequest][declearn.communication.messaging.JoinRequest] - -The `flags` submodule is also re-exported, but should preferably be imported -as `declearn.communication.api.backend.flags`. -""" - -from declearn.communication.api.backend import flags - -from ._messages import ( - CancelTraining, - Empty, - Error, - GenericMessage, - GetMessageRequest, - EvaluationReply, - EvaluationRequest, - InitRequest, - JoinReply, - JoinRequest, - Message, - PrivacyRequest, - StopTraining, - TrainReply, - TrainRequest, - parse_message_from_string, -) diff --git a/declearn/communication/messaging/_messages.py b/declearn/communication/messaging/_messages.py deleted file mode 100644 index f81574c2016fc7bbf796110ff20677e035e4b4a4..0000000000000000000000000000000000000000 --- a/declearn/communication/messaging/_messages.py +++ /dev/null @@ -1,153 +0,0 @@ -# coding: utf-8 - -# Copyright 2023 Inria (Institut National de Recherche en Informatique -# et Automatique) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Dataclasses defining messages used in declearn communications.""" - -import abc -import dataclasses -import warnings -from typing import Any, Dict, Optional - - -from declearn.messaging import ( - CancelTraining, - Error, - EvaluationReply, - EvaluationRequest, - GenericMessage, - InitRequest, - Message, - PrivacyRequest, - SerializedMessage, - StopTraining, - TrainReply, - TrainRequest, -) - - -__all__ = [ - "CancelTraining", - "Empty", - "Error", - "EvaluationReply", - "EvaluationRequest", - "GenericMessage", - "GetMessageRequest", - "InitRequest", - "JoinReply", - "JoinRequest", - "Message", - "PrivacyRequest", - "StopTraining", - "TrainReply", - "TrainRequest", - "parse_message_from_string", -] - - -@dataclasses.dataclass -class DeprecatedMessage(Message, register=False, metaclass=abc.ABCMeta): - """DEPRECATED Message subtype.""" - - def __post_init__( - self, - ) -> None: - warnings.warn( - f"'{self.__class__.__name__}' was deprecated in DecLearn v2.4. " - "It should no longer be used and may cause failures. 
It will be " - "removed in DecLearn v2.6 and/or v3.0", - DeprecationWarning, - ) - - -@dataclasses.dataclass -class Empty(DeprecatedMessage): - """DEPRECATED empty message class.""" - - typekey = "empty" - - -@dataclasses.dataclass -class GetMessageRequest(DeprecatedMessage): - """DEPRECATED message-retrieval query message class.""" - - typekey = "get_message" - - timeout: Optional[int] = None - - -@dataclasses.dataclass -class JoinRequest(DeprecatedMessage): - """DEPRECATED process joining query message class.""" - - typekey = "join_request" - - name: str - data_info: Dict[str, Any] - version: Optional[str] = None - - -@dataclasses.dataclass -class JoinReply(DeprecatedMessage): - """DEPRECATED process joining reply message class.""" - - typekey = "join_reply" - - accept: bool - flag: str - - -def parse_message_from_string( - string: str, -) -> Message: - """DEPRECATED - Instantiate a Message from its serialized string. - - This function was DEPRECATED in DecLearn 2.4 and will be removed - in v2.6 and/or v3.0. Use the `declearn.messaging.SerializedMessage` - API to parse serialized message strings. - - Parameters - ---------- - string: - Serialized string dump of the message. - - Returns - ------- - message: - Message instance recovered from the input string. - - Raises - ------ - KeyError - If the string's typekey does not match any supported Message - subclass. - TypeError - If the string cannot be parsed to identify a message typekey. - ValueError - If the serialized data fails to be properly decoded. - """ - warnings.warn( - "'parse_message_from_string' was deprecated in DecLearn 2.4, in " - "favor of using 'declearn.messaging.SerializedMessage' to parse " - "and deserialize 'Message' instances from strings. It will be " - "removed in DecLearn version 2.6 and/or 3.0.", - DeprecationWarning, - ) - serialized = SerializedMessage.from_message_string( - string - ) # type: SerializedMessage[Any] - return serialized.deserialize() diff --git a/declearn/dataset/__init__.py b/declearn/dataset/__init__.py index 7436e1ceacea4b5235fbf4f2d97bce198f3d386c..d885c587b28f705e3d9cc4e958a142a07cb64257 100644 --- a/declearn/dataset/__init__.py +++ b/declearn/dataset/__init__.py @@ -29,8 +29,6 @@ API tools Abstract base class defining an API to access training or testing data. * [DataSpec][declearn.dataset.DataSpecs]: Dataclass to wrap a dataset's metadata. -* [load_dataset_from_json][declearn.dataset.load_dataset_from_json] - DEPRECATED Utility function to parse a JSON into a dataset object. Dataset subclasses ------------------ @@ -67,6 +65,6 @@ Utility entry-point from . import utils from . import examples -from ._base import Dataset, DataSpecs, load_dataset_from_json +from ._base import Dataset, DataSpecs from ._inmemory import InMemoryDataset from ._split_data import split_data diff --git a/declearn/dataset/_base.py b/declearn/dataset/_base.py index c0f3bbe7fe9a6719b8823c2f417e647f0e3cea51..64cf1c8357c818d000b62f918b223b14a1a4ccdd 100644 --- a/declearn/dataset/_base.py +++ b/declearn/dataset/_base.py @@ -19,11 +19,10 @@ import abc import dataclasses -import warnings from typing import Any, Iterator, List, Optional, Set, Tuple, Union from declearn.typing import Batch -from declearn.utils import access_registered, create_types_registry, json_load +from declearn.utils import create_types_registry __all__ = [ "DataSpecs", @@ -109,42 +108,3 @@ class Dataset(metaclass=abc.ABCMeta): Optional weights associated with the samples, that are typically used to balance a model's loss or metrics. 
""" - - -def load_dataset_from_json(path: str) -> Dataset: # pragma: no cover - """DEPRECATED Instantiate a dataset based on a JSON dump file. - - Parameters - ---------- - path: str - Path to a JSON file output by the `save_to_json` - method of the Dataset that is being reloaded. - The actual type of dataset should be specified - under the "name" field of that file. - - Returns - ------- - dataset: Dataset - Dataset (subclass) instance, reloaded from JSON. - - Raises - ------ - NotImplementedError - If the target `Dataset` does not implement a `load_from_json` - method (which was removed from the API in DecLearn 2.3.0). - """ - warnings.warn( - "'load_dataset_from_json' was deprecated in Declearn 2.4.0, after" - "'Dataset.load_from_json' was removed from the API in v2.3.0. It " - "may raise a 'NotImplementedError', and will be removed in DecLearn " - "2.6 and/or 3.0.", - category=DeprecationWarning, - stacklevel=2, - ) - dump = json_load(path) - cls = access_registered(dump["name"], group="Dataset") - if not hasattr(cls, "load_from_json"): - raise NotImplementedError( - f"Dataset class '{cls}' does not implement 'load_from_json'." - ) - return cls.load_from_json(path) diff --git a/declearn/main/_client.py b/declearn/main/_client.py index 506046e634945a75160649fdd7d9e433cb59a5f5..0b03c36a38482841664e7b69295fda6b603149d8 100644 --- a/declearn/main/_client.py +++ b/declearn/main/_client.py @@ -32,7 +32,7 @@ from declearn.communication.utils import ( NetworkClientConfig, verify_server_message_validity, ) -from declearn.dataset import Dataset, load_dataset_from_json +from declearn.dataset import Dataset from declearn.fairness.api import FairnessControllerClient from declearn.main.utils import Checkpointer from declearn.messaging import Message, SerializedMessage @@ -116,16 +116,12 @@ class FederatedClient: if replace_netwk_logger: self.netwk.logger = self.logger # Assign the wrapped training dataset. - if isinstance(train_data, str): - train_data = load_dataset_from_json(train_data) if not isinstance(train_data, Dataset): - raise TypeError("'train_data' should be a Dataset or path to one.") + raise TypeError("'train_data' should be a Dataset.") self.train_data = train_data # Assign the wrapped validation dataset (if any). - if isinstance(valid_data, str): - valid_data = load_dataset_from_json(valid_data) if not (valid_data is None or isinstance(valid_data, Dataset)): - raise TypeError("'valid_data' should be a Dataset or path to one.") + raise TypeError("'valid_data' should be a Dataset.") self.valid_data = valid_data # Assign an optional checkpointer. if checkpoint is not None: diff --git a/declearn/main/_server.py b/declearn/main/_server.py index cb4d9e86f46dabb7409ce87b03b08ce5c733a4af..cb53805d8c186b44ab688ca64543799b46b5c00f 100644 --- a/declearn/main/_server.py +++ b/declearn/main/_server.py @@ -137,7 +137,7 @@ class FederatedServer: self._decrypter = None # type: Optional[Decrypter] self._secagg_peers = set() # type: Set[str] # Set up private attributes to record the loss values and best weights. - self._loss = {} # type: Dict[int, float] + self._losses = [] # type: List[float] self._best = None # type: Optional[Vector] # Set up a private attribute to prevent redundant weights sharing. self._clients_holding_latest_model = set() # type: Set[str] @@ -284,7 +284,6 @@ class FederatedServer: # Iteratively run training and evaluation rounds. round_i = 0 while True: - # Run (opt.) fairness; training; evaluation. 
             await self.fairness_round(round_i, config.fairness)
             round_i += 1
             await self.training_round(round_i, config.training)
@@ -292,9 +291,15 @@
             # Decide whether to keep training for at least one round.
             if not self._keep_training(round_i, config.rounds, early_stop):
                 break
-        # When checkpointing, evaluate the last model's fairness.
+        # When checkpointing, force-evaluate the last model.
         if self.ckptr is not None:
-            await self.fairness_round(round_i, config.fairness)
+            if round_i % config.evaluate.frequency:
+                await self.evaluation_round(
+                    round_i, config.evaluate, force_run=True
+                )
+            await self.fairness_round(
+                round_i, config.fairness, force_run=True
+            )
         # Interrupt training when time comes.
         self.logger.info("Stopping training.")
         await self.stop_training(round_i)
@@ -542,8 +547,12 @@
         self,
         round_i: int,
         fairness_cfg: FairnessConfig,
+        force_run: bool = False,
     ) -> None:
-        """Orchestrate a fairness round.
+        """Orchestrate a fairness round, when configured to do so.
+
+        If fairness is not set, or if `round_i` is to be skipped based
+        on `fairness_cfg.frequency`, do nothing.
 
         Parameters
         ----------
@@ -553,9 +562,15 @@
             FairnessConfig dataclass instance wrapping data-batching
             and computational effort constraints hyper-parameters for
             fairness evaluation.
+        force_run:
+            Whether to disregard `fairness_cfg.frequency` and run the
+            round (provided a fairness controller is set up).
         """
+        # Early exit when fairness is not set or the round is to be skipped.
         if self.fairness is None:
            return
+        if (round_i % fairness_cfg.frequency) and not force_run:
+            return
         # Run SecAgg setup when needed.
         self.logger.info("Initiating fairness-enforcing round %s", round_i)
         clients = self.netwk.client_names  # FUTURE: enable sampling(?)
@@ -721,17 +736,24 @@
         self,
         round_i: int,
         valid_cfg: EvaluateConfig,
+        force_run: bool = False,
     ) -> None:
-        """Orchestrate an evaluation round.
+        """Orchestrate an evaluation round, when configured to do so.
+
+        If `round_i` is to be skipped based on `valid_cfg.frequency`,
+        do nothing.
 
         Parameters
         ----------
         round_i: int
-            Index of the evaluation round.
+            Index of the latest training round.
         valid_cfg: EvaluateConfig
             EvaluateConfig dataclass instance wrapping data-batching
             and computational effort constraints hyper-parameters.
         """
+        # Early exit when the evaluation round is to be skipped.
+        if (round_i % valid_cfg.frequency) and not force_run:
+            return
         # Select participating clients. Run SecAgg setup when needed.
         self.logger.info("Initiating evaluation round %s", round_i)
         clients = self._select_evaluation_round_participants()
@@ -766,8 +788,8 @@
             metrics, results if len(results) > 1 else {}
         )
         # Record the global loss, and update the kept "best" weights.
- self._loss[round_i] = loss - if loss == min(self._loss.values()): + self._losses.append(loss) + if loss == min(self._losses): self._best = self.model.get_weights() def _select_evaluation_round_participants( @@ -891,7 +913,7 @@ class FederatedServer: self.ckptr.save_metrics( metrics=metrics, prefix=f"metrics_{client}", - append=bool(self._loss), + append=bool(self._losses), timestamp=timestamp, ) @@ -917,7 +939,7 @@ class FederatedServer: self.logger.info("Maximum number of training rounds reached.") return False if early_stop is not None: - early_stop.update(self._loss[round_i]) + early_stop.update(self._losses[-1]) if not early_stop.keep_training: self.logger.info("Early stopping criterion reached.") return False @@ -937,7 +959,7 @@ class FederatedServer: self.logger.info("Recovering weights that yielded the lowest loss.") message = messaging.StopTraining( weights=self._best or self.model.get_weights(), - loss=min(self._loss.values()) if self._loss else float("nan"), + loss=min(self._losses, default=float("nan")), rounds=rounds, ) self.logger.info("Notifying clients that training is over.") diff --git a/declearn/main/config/_dataclasses.py b/declearn/main/config/_dataclasses.py index 867343e689aa097273bd238bee54dbd32739e82f..91f04102477703cbfd3985c6d7091789258fe176 100644 --- a/declearn/main/config/_dataclasses.py +++ b/declearn/main/config/_dataclasses.py @@ -127,12 +127,20 @@ class TrainingConfig: class EvaluateConfig(TrainingConfig): """Dataclass wrapping parameters for an evaluation round. + Exclusive attributes + -------------------- + frequency: int + Number of training rounds to run between evaluation ones. + By default, run an evaluation round after each training one. + Please refer to the parent class `TrainingConfig` for details - on the wrapped parameters / attribute. Note that `n_epoch` is - dropped when this config is turned into an EvaluationRequest - message. + on the other wrapped parameters / attributes. + + Note that `n_epoch` is dropped when this config is turned into + an EvaluationRequest message. """ + frequency: int = 1 drop_remainder: bool = False @property @@ -238,6 +246,9 @@ class FairnessConfig: ---------- batch_size: int Number of samples per processed data batch. + frequency: int + Number of training rounds to run between fairness ones. + By default, run a fairness round before each training one. n_batch: int or None, default=None Optional maximum number of batches to draw. If None, use the entire training dataset. @@ -249,5 +260,6 @@ class FairnessConfig: """ batch_size: int = 32 + frequency: int = 1 n_batch: Optional[int] = None thresh: Optional[float] = None diff --git a/declearn/metrics/_api.py b/declearn/metrics/_api.py index 04405777044158eae800e14dfb8f20073cc9847f..46a53d50421ba469d651a1a5b66d21d222f49911 100644 --- a/declearn/metrics/_api.py +++ b/declearn/metrics/_api.py @@ -18,7 +18,6 @@ """Iterative and federative evaluation metrics base class.""" import abc -import warnings from copy import deepcopy from typing import Any, ClassVar, Dict, Generic, Optional, Type, TypeVar, Union @@ -248,46 +247,6 @@ class Metric(Generic[MetricStateT], metaclass=abc.ABCMeta): ) self._states = deepcopy(states) # type: ignore - def agg_states( - self, - states: MetricStateT, - ) -> None: - """Aggregate provided state variables into self ones. - - This method is DEPRECATED as of DecLearn v2.4, in favor of - merely aggregating `MetricState` instances, using either - their `aggregate` method or the overloaded `+` operator. 
- It will be removed in DecLearn 2.6 and/or 3.0. - - This method is designed to aggregate results from multiple - similar metrics objects into a single one before computing - its results. - - Parameters - ---------- - states: - `MetricState` emitted by another instance of this class - via its `get_states` method. - - Raises - ------ - TypeError - If `states` is of improper type. - """ - warnings.warn( - "'Metric.agg_states' was deprecated in DecLearn v2.4, in favor " - "of aggregating 'MetricState' instances directly, and setting " - "final aggregated states using 'Metric.set_state'. It will be " - "removed in DecLearn 2.6 and/or 3.0.", - DeprecationWarning, - ) - if not isinstance(states, self.state_cls): - raise TypeError( - f"'{self.__class__.__name__}.set_states' expected " - f"'{self.state_cls}' inputs, got '{type(states)}'." - ) - self.set_states(self._states + states) - def __init_subclass__( cls, register: bool = True, diff --git a/declearn/metrics/_wrapper.py b/declearn/metrics/_wrapper.py index cbb5cee2afc71f39ccafca789b818b10d79f3bc8..9e7a17098d67b9e67aac2d5cd83d6341129ad180 100644 --- a/declearn/metrics/_wrapper.py +++ b/declearn/metrics/_wrapper.py @@ -17,7 +17,6 @@ """Wrapper for an ensemble of Metric objects.""" -import warnings from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import numpy as np @@ -215,49 +214,6 @@ class MetricSet: if metric.name in states: metric.set_states(states[metric.name]) - def agg_states( - self, - states: Dict[str, MetricState], - ) -> None: - """Aggregate provided state variables into self ones. - - This method is DEPRECATED as of DecLearn v2.4, in favor of - merely aggregating `MetricState` instances, using either - their `aggregate` method or the overloaded `+` operator. - It will be removed in DecLearn 2.6 and/or 3.0. - - This method is designed to aggregate results from multiple - similar metrics objects into a single one before computing - its results. - - Parameters - ---------- - states: dict[str, float or numpy.ndarray] - Dict of states emitted by another instance of this class - via its `get_states` method. - - Raises - ------ - KeyError - If any state variable is missing from `states`. - TypeError - If any state variable is of improper type. - ValueError - If any array state variable is of improper shape. - """ - warnings.warn( - "'MetricSet.agg_states' was deprecated in DecLearn v2.4, in favor " - "of aggregating 'MetricState' instances directly, and setting " - "final aggregated states using 'MetricSet.set_state'. 
It will be " - "removed in DecLearn 2.6 and/or 3.0.", - DeprecationWarning, - ) - with warnings.catch_warnings(): - warnings.simplefilter(action="ignore", category=DeprecationWarning) - for metric in self.metrics: - if metric.name in states: - metric.agg_states(states[metric.name]) - def get_config( self, ) -> Dict[str, Any]: diff --git a/declearn/optimizer/modules/_scaffold.py b/declearn/optimizer/modules/_scaffold.py index c5c9a4a8c5aad2b38bf0ca47e04dbae2575eb015..277bac03c786c9f6fee74d629caa5a75d4e36ff7 100644 --- a/declearn/optimizer/modules/_scaffold.py +++ b/declearn/optimizer/modules/_scaffold.py @@ -35,7 +35,7 @@ References import dataclasses import uuid import warnings -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Optional, Set, Tuple, Union from declearn.model.api import Vector from declearn.optimizer.modules._api import AuxVar, OptiModule @@ -386,25 +386,10 @@ class ScaffoldServerModule(OptiModule[ScaffoldAuxVar]): def __init__( self, - clients: Optional[List[str]] = None, ) -> None: - """Instantiate the server-side SCAFFOLD gradients-correction module. - - Parameters - ---------- - clients: - DEPRECATED and unused starting with declearn 2.4. - Optional list of known clients' id strings. - """ + """Instantiate the server-side SCAFFOLD gradients-correction module.""" self.s_state = 0.0 # type: Union[Vector, float] self.clients = set() # type: Set[str] - if clients: # pragma: no cover - warnings.warn( - "ScaffoldServerModule's 'clients' argument has been deprecated" - " as of declearn v2.4, and no longer has any effect. It will" - " be removed in declearn 2.6 and/or 3.0.", - DeprecationWarning, - ) def run( self, diff --git a/declearn/training/_manager.py b/declearn/training/_manager.py index a3d158f8dcf3d0ccc4dae7ca8888cd6a1075111f..bfe21be65c3864a393cd519445dad8f5cdebe1c3 100644 --- a/declearn/training/_manager.py +++ b/declearn/training/_manager.py @@ -23,8 +23,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import tqdm +from declearn import messaging from declearn.aggregator import Aggregator -from declearn.communication import messaging from declearn.dataset import Dataset from declearn.metrics import ( MeanMetric, diff --git a/declearn/training/dp/_manager.py b/declearn/training/dp/_manager.py index 26aa2245a8e7f71b221a5da64ce00bec7ecd113a..291bea84caeee4eae16f2596f6f51b72b1da4ede 100644 --- a/declearn/training/dp/_manager.py +++ b/declearn/training/dp/_manager.py @@ -23,8 +23,8 @@ from typing import List, Optional, Tuple, Union from opacus.accountants import IAccountant, create_accountant # type: ignore from opacus.accountants.utils import get_noise_multiplier # type: ignore +from declearn import messaging from declearn.aggregator import Aggregator -from declearn.communication import messaging from declearn.dataset import Dataset from declearn.metrics import MetricInputType, MetricSet from declearn.model.api import Model diff --git a/declearn/version.py b/declearn/version.py index 70a1a14e9b2da930342dc069a7d713b7a1742893..f516287592a9cd133bb72923aeea2f28d3de271b 100644 --- a/declearn/version.py +++ b/declearn/version.py @@ -17,5 +17,5 @@ """DecLearn version information, as hard-coded constants.""" -VERSION = "2.6.0.dev2" +VERSION = "2.6.0" """Version information of the installed DecLearn package.""" diff --git a/declearn3.md b/declearn3.md new file mode 100644 index 0000000000000000000000000000000000000000..c64d7a10f8fc727a49ca87ff2bb61743897d72ab --- /dev/null +++ b/declearn3.md @@ -0,0 
+1,40 @@
+* Integrate Fairness-aware methods (update branch, merge) (2 weeks)
+
+* Add Analytics (API + processes + SecAgg (incl. Metrics)) (1 Month)
+
+
+* Refactor routines
+  - NOTES:
+    - Write some structure to handle Model + Optimizer(s) + Aggregator
+    - Write logic for setup (duplicate from Server to Client / Peer to Peer)
+    - Write logic for local training (provided data is available)
+    - Write logic for aggregation (then, modularize to support Gossip, etc.)
+
+* Revise serialization (move to Msgpack; revise custom code; screen perfs) (1 week)
+
+
+---
+
+
+* Configuration tools => improve usability and extensibility; align with Rosalie's work
+
+* Interface and use FLamby (for examples and/or benchmarks)
+  -> Coordinate with Paul & Edwige
+
+* New algorithms
+  - Personalization via Hybrid training (2 weeks)
+  - Things with Rosalie?
+
+* Profile performance (benchmark: with asv or, easier, using current logging)
+
+* Revise Network Communication:
+  - Modularize timeout on responses => go minimal on that
+  - Enable connection loss/re-connection => not now, wait for tests / actual problems
+  - Improve the way clients are identified by MessagesHandler? => test again for issues (see if required/interesting)
+  - Improve the tackling of MessagesHandler receiving multiple messages from or for the same client? => wait for roadmap on decentralized
+
+* Add client sampling
+
+* Split NetworkServer from FederatedServer
+
+* (Later) Quickrun mode revision
diff --git a/docs/release-notes/SUMMARY.md b/docs/release-notes/SUMMARY.md
index ff3d13e338b50e5bf84dc090f2b6f52f79280c7f..e134a525eed22649cec92b813ac904fb666e8216 100644
--- a/docs/release-notes/SUMMARY.md
+++ b/docs/release-notes/SUMMARY.md
@@ -1,3 +1,4 @@
+- [v2.6.0](v2.6.0.md)
 - [v2.5.1](v2.5.1.md)
 - [v2.5.0](v2.5.0.md)
 - [v2.4.1](v2.4.1.md)
diff --git a/docs/release-notes/v2.5.0.md b/docs/release-notes/v2.5.0.md
index 27a4e923066459157ccb6334ef1279007ea96467..da1ecd42946c62fe824610594132d29a3e471b5f 100644
--- a/docs/release-notes/v2.5.0.md
+++ b/docs/release-notes/v2.5.0.md
@@ -1,4 +1,4 @@
-# declearn v2.4.0
+# declearn v2.5.0
 
 Released: 13/05/2024
 
diff --git a/docs/release-notes/v2.6.0.md b/docs/release-notes/v2.6.0.md
new file mode 100644
index 0000000000000000000000000000000000000000..4d1bf837ab1ff84fc170f41c4975aac977dae912
--- /dev/null
+++ b/docs/release-notes/v2.6.0.md
@@ -0,0 +1,176 @@
+# declearn v2.6.0
+
+Released: 26/07/2024
+
+## Release Highlights
+
+### Group-Fairness capabilities
+
+This new version of DecLearn brings a whole new type of federated optimization
+algorithms to the party, introducing an API and various algorithms to measure
+and optimize the group fairness of the trained model over the union of
+clients' training datasets.
+
+This is the result of a year-long collaboration with Michaël Perrot and Brahim
+Erraji to design and evaluate algorithms to learn models under group-fairness
+constraints in a federated learning setting, using either newly-introduced
+algorithms or existing ones from the literature.
+
+A dedicated [guide on fairness features](../user-guide/fairness.md) was added
+to the documentation, which is the advised entry-point for people interested
+in getting acquainted with these new features. The guide explains what
+(group-)fairness in machine learning is, what the design choices (and limits)
+of our new API are, how the API works, which algorithms are available, and how
+to write custom fairness definitions or fairness-enforcing algorithms.
+
+As noted in the guide, end-users with an interest in fairness-aware federated
+learning are very welcome to get in touch if they have feedback, questions or
+requests about the current capabilities and possible future ones.
+
+To sum it up shortly:
+
+- The newly-introduced `declearn.fairness` submodule provides an API and
+  concrete algorithms to enforce fairness constraints in a federated learning
+  process.
+- When such an algorithm is to be used, the only required modifications to an
+  existing user-defined process are to:
+    - Plug a `declearn.fairness.api.FairnessControllerServer` subclass instance
+      (or its configuration) into the `declearn.main.config.FLOptimConfig` that
+      is defined by the server.
+    - Wrap each and every client's training dataset as a
+      `declearn.fairness.api.FairnessDataset`; for instance using
+      `declearn.fairness.core.FairnessInMemoryDataset`, which is an extension
+      of the base `declearn.dataset.InMemoryDataset`.
+- There are currently three available algorithms to enforce fairness:
+    - Fed-FairGrad, defined under `declearn.fairness.fairgrad`
+    - Fed-FairBatch/FedFB, defined under `declearn.fairness.fairbatch`
+    - FairFed, defined under `declearn.fairness.fairfed`
+- In addition, `declearn.fairness.monitor` provides an algorithm to merely
+  measure fairness throughout training, typically to evaluate baselines when
+  conducting experiments on fairness-enforcing algorithms.
+- There are currently four available group-fairness criteria that can be used
+  with the previous algorithms:
+    - Accuracy Parity
+    - Demographic Parity
+    - Equalized Odds
+    - Equality of Opportunity
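+
+As a minimal sketch of this setup (the constructor parameters below are
+illustrative assumptions, to be checked against the API reference; the
+`from_params` helper is inherited from `declearn.utils.TomlConfig`):
+
+```python
+from declearn.fairness.core import FairnessInMemoryDataset
+from declearn.fairness.fairgrad import FairgradControllerServer
+from declearn.main.config import FLOptimConfig
+
+# Server side: plug a fairness controller into the optimization config.
+optim = FLOptimConfig.from_params(
+    aggregator="averaging",
+    client_opt={"lrate": 0.01},
+    fairness=FairgradControllerServer(f_type="demographic_parity"),
+)
+
+# Client side: wrap the training data as a FairnessDataset, exposing the
+# sensitive attribute(s) that define groups (parameter names are ours).
+dataset = FairnessInMemoryDataset(
+    data="data.csv", target="label", s_attr=["gender"]
+)
+```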
+
+### Scheduler API for learning rates
+
+DecLearn 2.6.0 also introduces a long-awaited feature: scheduling rules for
+the learning rate (and/or weight decay factor), which adjust the scheduled
+value throughout training based on the number of training steps and/or
+rounds already taken.
+
+This takes the form of a new (and extensible) `Scheduler` API, implemented
+under the new `declearn.optimizer.schedulers` submodule. Instances of
+`Scheduler` subclasses (or their JSON-serializable specs) may be passed to
+`Optimizer.__init__` instead of float values to specify the `lrate` and/or
+`w_decay` parameters, resulting in time-varying values being computed and
+used rather than constant ones.
+
+`Scheduler` is easily extensible by end-users to write their own rules.
+At the moment, DecLearn natively provides:
+
+- Various kinds of decay (step, multi-step or round based; linear,
+  exponential, polynomial...);
+- Cyclic learning rates (based on [this](https://arxiv.org/pdf/1506.01186)
+  and [that](https://arxiv.org/abs/1608.03983) paper);
+- Linear warmup (step or round based; combinable with another scheduler
+  to use after the warmup period).
+
+The [user-guide on the Optimizer API](../user-guide/optimizer.md) was updated
+to cover this new feature, and remains the preferred entry-point for new
+users who want to grasp the overall design and specific features offered by
+this API. Users already familiar with `Optimizer` may simply check out the
+API docs for the new [`Scheduler`][declearn.optimizer.schedulers.Scheduler]
+API.
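+
+As an illustrative sketch (`ExponentialDecay` is a stand-in name here; see
+the `declearn.optimizer.schedulers` API docs for the concrete classes and
+their exact signatures):
+
+```python
+from declearn.optimizer import Optimizer
+from declearn.optimizer.schedulers import ExponentialDecay  # assumed name
+
+# Use a decaying learning rate instead of a constant float value.
+lrate = ExponentialDecay(base=0.1, rate=0.9)  # assumed signature
+optim = Optimizer(lrate=lrate, modules=["adam"])
+```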
+
+### `declearn.training` submodule reorganization
+
+DecLearn 2.6.0 introduces the `declearn.training` submodule, which merely
+refactors some unchanged classes previously made available under
+`declearn.main.utils` and `declearn.main.privacy`. The mapping of changes
+is the following:
+
+- `declearn.main.TrainingManager` -> `declearn.training.TrainingManager`
+- `declearn.main.privacy` -> `declearn.training.dp` (which remains a
+  manual-import submodule relying on the availability of the optional
+  `opacus` third-party dependency)
+
+The former `declearn.main.privacy` is deprecated and will be removed in
+DecLearn 2.8 and/or 3.0. It is kept for now as an alias re-export of
+`declearn.training.dp`, which raises a `DeprecationWarning` upon manual
+import.
+
+The `declearn.main.utils` submodule is kept, but importing `TrainingManager`
+from it is deprecated and will also be removed in version 2.8 and/or 3.0.
+For now, the class is merely re-exported from it.
+
+### Evaluation rounds can now be skipped
+
+Prior to this release, `FederatedServer` always deterministically ran training
+and evaluation rounds in alternation as part of a Federated Learning process.
+This can now be modularized, using the new `frequency` parameter as part of
+`declearn.main.config.EvaluateConfig` (_i.e._ the "evaluate" field of the
+`declearn.main.config.FLRunConfig` instance, dict or TOML file provided as
+input to `FederatedServer.run`).
+
+By default, `frequency=1`, meaning an evaluation round is run after each and
+every training round. But make it `frequency=N` and evaluation will only
+occur after the N-th, 2*N-th, ... training rounds. Note that if the server
+is checkpointing results, then an evaluation round will always be run after
+the last training round.
+
+Note that a similar parameter is available for `FairnessConfig`, albeit
+working slightly differently, because fairness evaluation rounds occur
+_before_ training. Hence, with `frequency=N`, fairness evaluation and
+constraint updates will occur before the 1st, N+1-th, 2*N+1-th, ... training
+rounds. Note that if the server is checkpointing results, then a fairness
+round will always be run after the last training round, for the sake of
+measuring the fairness levels of the last model.
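+
+For instance, mirroring the usage found in DecLearn's own test suite:
+
+```python
+from declearn.main.config import FLRunConfig
+
+# Run 10 training rounds, only evaluating the model every 5th round.
+# With server-side checkpointing, the last model is still force-evaluated.
+config = FLRunConfig.from_params(
+    rounds=10,
+    register={"min_clients": 2},
+    training={"n_epoch": 1, "batch_size": 32},
+    evaluate={"batch_size": 32, "frequency": 5},
+)
+```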
+
+## Other changes
+
+### Removal of deprecated features
+
+A number of features were deprecated in DecLearn 2.4.0 (whether legacy API
+methods, submodules or methods that were re-organized or renamed, parameters
+that were no longer used or a plainly-removed function). As of this new
+release, those features that had been kept back with a deprecation warning
+are now removed from the code.
+
+As a reminder, the removed features include:
+
+* Legacy aggregation methods:
+    - `declearn.aggregator.Aggregator.aggregate`
+    - `declearn.metrics.Metric.agg_states`
+    - `declearn.metrics.MetricSet.agg_states`
+* Legacy instantiation parameters:
+    - `declearn.aggregator.AveragingAggregator` parameter `client_weights`
+    - `declearn.aggregator.GradientMaskedAveraging` parameter `client_weights`
+    - `declearn.optimizer.modules.ScaffoldServerModule` parameter `clients`
+* Legacy names that were aliasing new locations:
+    - `declearn.communication.messaging` (moved to `declearn.messaging`)
+    - `declearn.communication.NetworkClient.check_message` (renamed
+      `recv_message`)
+* `declearn.dataset.load_dataset_from_json`
+
+### New developer-oriented changes
+
+A few minor changes are shipped with this new release, which are mostly of
+interest to developers - including end-users writing custom algorithms or
+bridging DecLearn APIs within their own orchestration code.
+
+- The `declearn.secagg.messaging.aggregate_secagg_messages` function was
+  introduced as a refactoring of previous backend code to combine and
+  decrypt an ensemble of client-emitted `SecaggMessage` instances into a
+  single aggregated cleartext `Message`; see the sketch after this list.
+- The `declearn.utils.TomlConfig` class, from which all TOML-parsing config
+  dataclasses of DecLearn inherit, now has a new `autofill_fields` class
+  attribute to indicate fields that may be left empty by users and will then
+  be dynamically filled when parsing all fields. For instance, this enables
+  not specifying `evaluate` when parsing a `FLRunConfig` instance from a
+  TOML file.
+- New unit tests were added, most notably for `FederatedServer`, which now
+  benefits from proper coverage by mere unit tests that verify high-level
+  logic and coherence of actions with inputs and documentation, whereas
+  the overall working keeps being assessed using functional tests.
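+
+As a hedged sketch of server-side usage (variable names are ours; the exact
+signature should be checked against the API reference):
+
+```python
+from typing import Dict
+
+from declearn.messaging import Message
+from declearn.secagg.api import Decrypter
+from declearn.secagg.messaging import SecaggMessage, aggregate_secagg_messages
+
+
+def recover_cleartext(
+    replies: Dict[str, SecaggMessage],
+    decrypter: Decrypter,
+) -> Message:
+    # Combine client-wise encrypted messages and decrypt the aggregate.
+    return aggregate_secagg_messages(replies, decrypter)
+```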
diff --git a/docs/user-guide/fl_process.md b/docs/user-guide/fl_process.md
index 6972092b14e41f2196f6978a66b0702759549a68..60ed43839dee1df8f8d394e35f2aa9f5dfa1ae16 100644
--- a/docs/user-guide/fl_process.md
+++ b/docs/user-guide/fl_process.md
@@ -15,13 +15,17 @@ exposed here.
   - the server may collect targetted metadata from clients when required
   - the server sets up the model, optimizers, aggregator and metrics
   - all clients receive instructions to set up these objects as well
+  - additional setup phases optionally occur to set up advanced features
+    (secure aggregation, differential privacy and/or group fairness)
 - Iteratively:
+  - (optionally) perform a fairness-related round
   - perform a training round
-  - perform an evaluation round
+  - (optionally) perform an evaluation round
   - decide whether to continue, based on the number of rounds taken
     or on the evolution of the global loss
 - Finally:
-  - restore the model weights that yielded the lowest global loss
+  - (optionally) evaluate the last model, if it was not already done
+  - restore the model weights that yielded the lowest global validation loss
   - notify clients that training is over, so they can disconnect
     and run their final routine (e.g. save the "best" model)
   - optionally checkpoint the "best" model
@@ -65,6 +69,8 @@ for dataset information (typically, features shape and/or dtype).
   - send specs to the clients so that they set up local counterpart objects
 - Client:
   - instantiate the model, optimizer, aggregator and metrics based on specs
+  - verify that the (optional) secure aggregation algorithm choice is
+    coherent with that of the server
 - messaging: (InitRequest <-> InitReply)
 
 #### (Optional) Local differential privacy setup
@@ -79,15 +85,84 @@ indicates to clients that it is to happen, as a secondary substep.
   - adjust the training process to use sample-wise gradient clipping and
     add gaussian noise to gradients, implementing the DP-SGD algorithm
   - set up a privacy accountant to monitor the use of the privacy budget
-- messaging: (PrivacyRequest <-> GenericMessage)
+- messaging: (PrivacyRequest <-> PrivacyReply)
+
+#### (Optional) Fairness-aware federated learning setup
+
+This step is optional; a flag in the InitRequest at a previous step
+indicates to clients that it is to happen, as a secondary substep.
+
+See our [guide on Fairness](./fairness.md) for further details on
+what (group) fairness is and how it is implemented in DecLearn.
+
+When Secure Aggregation is to be used, it is also set up as a first step
+to this routine, ensuring exchanged values are protected when possible.
+
+- Server:
+  - send hyper-parameters to set up a controller for fairness-aware
+    federated learning
+- Client:
+  - set up a controller based on the server-emitted query
+  - send back sensitive group definitions
+- messaging: (FairnessSetupQuery <-> FairnessGroups)
+- Server:
+  - define a sorted list of sensitive group definitions across clients
+    and share it with clients
+  - await associated sample counts from clients and (secure-)aggregate them
+- Client:
+  - await group definitions and send back group-wise sample counts
+- messaging: (FairnessGroups <-> FairnessCounts)
+- Server & Client: run algorithm-specific additional setup steps, which
+  may have side effects on the training data, model, optimizer and/or
+  aggregator; further communication may occur.
+
+### (Optional) Secure Aggregation setup
+
+When configured to be used, Secure Aggregation may be set up any number of
+times during the process, as fresh controllers will be required each and
+every time the set of clients participating in a round differs from that
+of the previous round.
+
+By default however, all clients participate in each and every round, so
+that a single setup will occur early in the overall FL process.
+
+See our [guide on Secure Aggregation](./secagg.md) for further details on
+what secure aggregation is and how it is implemented in DecLearn.
+
+- Server:
+  - send an algorithm-specific SecaggSetupQuery message to selected clients
+  - trigger an algorithm-dependent setup routine
+- Client:
+  - parse the query and execute the associated setup routine
+- Server & Client: perform algorithm-dependent computations and communication;
+  eventually, instantiate and assign respective encryption and decryption
+  controllers.
+- messaging: (SecaggSetupQuery <-> (algorithm-dependent Message))
+
+### (Optional) Fairness round
+
+This round only occurs when a fairness controller was set up, and may be
+configured to be periodically skipped.
+If fairness is set up, the first fairness round will always occur.
+If checkpointing is set up on the server side, the last model will undergo
+a fairness round, to evaluate its fairness prior to ending the FL process.
+
+- Server:
+  - send a query to clients, including computational effort constraints,
+    and current shared model weights (when not already held by clients)
+- Client:
+  - compute metrics that account for the fairness of the current model
+- messaging: (FairnessQuery <-> FairnessReply)
+- Server & Client: take any algorithm-specific additional actions to alter
+  training based on the exchanged values; further communication may happen.
 
 ### Training round
 
 - Server:
   - select clients that are to participate
   - send data-batching and effort constraints parameters
-  - send shared model trainable weights and (opt. client-specific) optimizer
-    auxiliary variables
+  - send current shared model trainable weights (to clients that do not
+    already hold them) and optimizer auxiliary variables (if any)
 - Client:
   - update model weights and optimizer auxiliary variables
   - perform training steps based on effort constraints
@@ -101,7 +176,11 @@
   - run global updates through the server's optimizer to modify and finally
     apply them
 
-### Evaluation round
+### (Optional) Evaluation round
+
+This round may be configured to be periodically skipped.
+If checkpointing is set up on the server side, the last model will always be
+evaluated prior to ending the FL process.
- Server: - select clients that are to participate diff --git a/docs/user-guide/package.md b/docs/user-guide/package.md index c976351fb3b83e07972f8113b526d09477e46525..5f7c690902f07649a2828d71da400436002d1fcb 100644 --- a/docs/user-guide/package.md +++ b/docs/user-guide/package.md @@ -12,6 +12,8 @@ The package is organized into the following submodules:   Tools to write and extend shareable metadata fields specifications. - `dataset`:<br/>   Data interfacing API and implementations. +- `fairness`:<br/> + Processes and components for fairness-aware federated learning. - `main`:<br/>   Main classes implementing a Federated Learning process. - `messaging`:<br/> @@ -24,6 +26,8 @@ The package is organized into the following submodules:   Framework-agnostic optimizer and algorithmic plug-ins API and tools. - `secagg`:<br/>   Secure Aggregation API, methods and utils. +- `training`:<br/> + Model training and evaluation orchestration tools. - `typing`:<br/>   Type hinting utils, defined and exposed for code readability purposes. - `utils`:<br/> @@ -270,6 +274,44 @@ You may learn more about our (non-abstract) `Optimizer` API by reading our and is about making the class JSON-serializable). - To avoid it, use `class MyClass(SecureAggregate, register=False)`. +### Fairness + +#### `FairnessFunction` +- Import: `declearn.fairness.api.FairnessFunction` +- Object: Define a group-fairness criterion. +- Usage: Compute fairness levels of a model based on group-wise accuracy. +- Examples: + - `declearn.fairness.core.DemographicParityFunction` + - `declearn.fairness.core.EqualizedOddsFunction` +- Extend: + - Simply inherit from `FairnessFunction` (registration is automated). + - To avoid it, use `class MyClass(FairnessFunction, register=False)`. + +#### `FairnessControllerServer` +- Import: `declearn.fairness.api.FairnessControllerServer` +- Object: Define server-side routines to monitor and enforce fairness. +- Usage: modify the federated optimization algorithm; orchestrate fairness + rounds to measure the trained model's fairness level and adjust training + based on it. +- Examples: + - `declearn.fairness.fairgrad.FairgradControllerServer` + - `declearn.fairness.fairbatch.FairbatchControllerServer` +- Extend: + - Simply inherit from `FairnessControllerServer` (registration is automated). + - To avoid it, use `class MyClass(FairnessControllerServer, register=False)`. + +#### `FairnessControllerClient` +- Import: `declearn.fairness.api.FairnessControllerClient` +- Object: Define client-side routines to monitor and enforce fairness. +- Usage: modify the federated optimization algorithm; measure a model's + local fairness level; adjust training based on server-emitted values. +- Examples: + - `declearn.fairness.fairgrad.FairgradControllerClient` + - `declearn.fairness.fairbatch.FairbatchControllerClient` +- Extend: + - Simply inherit from `FairnessControllerClient` (registration is automated). + - To avoid it, use `class MyClass(FairnessControllerClient, register=False)`. + ## Full API Reference The full API reference, which is generated automatically from the code's diff --git a/docs/user-guide/usage.md b/docs/user-guide/usage.md index 6b05d75a11bc12683a088c2c2f1cb09ea695ccf4..9392b9cdca362f0dbd03583e2400df65ea28a9b4 100644 --- a/docs/user-guide/usage.md +++ b/docs/user-guide/usage.md @@ -32,7 +32,9 @@ details on this example and on how to run it, please refer to its own used by clients to derive local step-wise updates from model gradients. 
 - Similarly, parameterize an `Optimizer` to be used by the server to
   (optionally) refine the aggregated model updates before applying them.
-  - Wrap these three objects into a `declearn.main.config.FLOptimConfig`,
+  - Optionally, parameterize a `FairnessControllerServer`, defining an
+    algorithm to enforce fairness constraints on the model being trained.
+  - Wrap these objects into a `declearn.main.config.FLOptimConfig`,
   possibly using its `from_config` method to specify the former three
   components via configuration dicts rather than actual instances.
   - Alternatively, write up a TOML configuration file that specifies these
@@ -56,6 +58,9 @@ details on this example and on how to run it, please refer to its own
   defines metrics to be computed by clients on their validation data.
   - Optionally provide the path to a folder where to write output files
   (model checkpoints and global loss history).
+  - Optionally parameterize and provide a `SecaggConfigServer` or its
+    configuration, to set up and use secure aggregation for all quantities
+    that support it (model weights, metrics and metadata).
 - Instantiate a `declearn.main.config.FLRunConfig` to specify the process:
   - Maximum number of training and evaluation rounds to run.
   - Registration parameters: exact or min/max number of clients to have
@@ -63,11 +68,16 @@ details on this example and on how to run it, please refer to its own
   - Training parameters: data-batching parameters and effort constraints
   (number of local epochs and/or steps to take, and optional timeout).
   - Evaluation parameters: data-batching parameters and effort constraints
-  (optional maximum number of steps (<=1 epoch) and optional timeout).
+  (optional maximum number of steps (<=1 epoch) and optional timeout);
+  optional frequency (to only evaluate after every N training rounds).
   - Early-stopping parameters (optionally): patience, tolerance, etc. as
   to the global model loss's evolution throughout rounds.
   - Local Differential-Privacy parameters (optionally): (epsilon, delta)
   budget, type of accountant, clipping norm threshold, RNG parameters.
+  - Fairness evaluation parameters (optionally): computational constraints
+  and optional frequency of fairness rounds; only used if fairness is
+  set up in the `FLOptimConfig`, and automatically filled if left
+  unspecified.
   - Alternatively, write up a TOML configuration file that specifies all
   of the former hyper-parameters.
 - Call the server's `run` method, passing it the former config object,
@@ -113,6 +123,9 @@ details on this example and on how to run it, please refer to its own
   concerns.
   - Optionally provide the path to a folder where to write output files
   (model checkpoints and local loss history).
+  - Optionally parameterize and provide a `SecaggConfigClient` or its
+    configuration, to set up and use secure aggregation for all quantities
+    that support it (model weights, metrics and metadata).
   - Call the client's `run` method and let the magic happen.
 
 ## Logging
diff --git a/pyproject.toml b/pyproject.toml
index 13163b577144f87ee587e488a82e66d958fef5c3..6efe3d3781174ae473687df0d948b8eb1795b89f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,8 +8,10 @@ requires = [
 
 [project]
 name = "declearn"
-version = "2.6.0.dev2"
-description = "Declearn - a python package for private decentralized learning."
+version = "2.6.0"
+description = """
+DecLearn - a modular and extensible framework for Federated Learning.
+""" readme = "README.md" requires-python = ">=3.8" license = {file = "LICENSE"} diff --git a/test/aggregator/test_aggregator.py b/test/aggregator/test_aggregator.py index e619bccee54e28d2fcef52735fc98814b7373c15..8d1f74b4c2da57b914d0024e0ff8733305c06c08 100644 --- a/test/aggregator/test_aggregator.py +++ b/test/aggregator/test_aggregator.py @@ -143,31 +143,3 @@ class TestAggregator: result = aggregator.finalize_updates(output) expect = aggregator.finalize_updates(updates_a + updates_b) assert result == expect - - # DEPRECATED: the following tests cover deprecated methods - - @pytest.mark.parametrize("framework", VECTOR_FRAMEWORKS) - def test_aggregate( - self, - agg_cls: Type[Aggregator], - updates: Dict[str, Vector], - ) -> None: - """Test that the legacy (deprecated) 'aggregate' method still works.""" - agg = agg_cls() - n_steps = {key: 10 for key in updates} - with pytest.warns(DeprecationWarning): - outputs = agg.aggregate(updates, n_steps) - ref_vec = list(updates.values())[0] - assert isinstance(outputs, type(ref_vec)) - assert outputs.shapes() == ref_vec.shapes() - assert outputs.dtypes() == ref_vec.dtypes() - - def test_aggregate_empty( - self, - agg_cls: Type[Aggregator], - ) -> None: - """Test that 'aggregate' raises the expected error on empty inputs.""" - agg = agg_cls() - with pytest.warns(DeprecationWarning): - with pytest.raises(TypeError): - agg.aggregate(updates={}, n_steps={}) diff --git a/test/functional/test_toy_clf_fairness.py b/test/functional/test_toy_clf_fairness.py index 41df97a8f2275f090099720de905c7e8b74cbe9d..b94b4f5b8995860c21da5a304d9e27abb5c25c0d 100644 --- a/test/functional/test_toy_clf_fairness.py +++ b/test/functional/test_toy_clf_fairness.py @@ -147,6 +147,7 @@ async def server_routine( rounds=5, register={"min_clients": n_clients, "timeout": 2}, training={"n_epoch": 1, "batch_size": 10}, + evaluate={"frequency": 5}, # only evaluate the last model fairness={"batch_size": 50}, ) with warnings.catch_warnings(): diff --git a/test/functional/test_toy_clf_secagg.py b/test/functional/test_toy_clf_secagg.py index 3599d22f0c677c577c1bb7ad876141d7dee79cc1..429284ce1de66de4feede8165716607e5a07814a 100644 --- a/test/functional/test_toy_clf_secagg.py +++ b/test/functional/test_toy_clf_secagg.py @@ -143,9 +143,10 @@ async def async_run_server( ) # Set up hyper-parameters and run training. 
diff --git a/test/functional/test_toy_clf_fairness.py b/test/functional/test_toy_clf_fairness.py
index 41df97a8f2275f090099720de905c7e8b74cbe9d..b94b4f5b8995860c21da5a304d9e27abb5c25c0d 100644
--- a/test/functional/test_toy_clf_fairness.py
+++ b/test/functional/test_toy_clf_fairness.py
@@ -147,6 +147,7 @@ async def server_routine(
         rounds=5,
         register={"min_clients": n_clients, "timeout": 2},
         training={"n_epoch": 1, "batch_size": 10},
+        evaluate={"frequency": 5},  # only evaluate the last model
         fairness={"batch_size": 50},
     )
     with warnings.catch_warnings():
diff --git a/test/functional/test_toy_clf_secagg.py b/test/functional/test_toy_clf_secagg.py
index 3599d22f0c677c577c1bb7ad876141d7dee79cc1..429284ce1de66de4feede8165716607e5a07814a 100644
--- a/test/functional/test_toy_clf_secagg.py
+++ b/test/functional/test_toy_clf_secagg.py
@@ -143,9 +143,10 @@ async def async_run_server(
     )
     # Set up hyper-parameters and run training.
     config = FLRunConfig.from_params(
-        rounds=10,
+        rounds=8,
         register={"min_clients": n_clients, "timeout": 2},
         training={"n_epoch": 1, "batch_size": 1, "drop_remainder": False},
+        evaluate={"frequency": 8},  # only evaluate the last model
     )
     await server.async_run(config)
diff --git a/test/functional/test_toy_reg.py b/test/functional/test_toy_reg.py
index 515e182e5899cea3fdbab1da8c9ae9cb1bdaae7a..9b5514d920cee9047be43fe4aeccfa96b9326d7d 100644
--- a/test/functional/test_toy_reg.py
+++ b/test/functional/test_toy_reg.py
@@ -411,6 +411,7 @@ async def async_run_server(
             "batch_size": b_size,
             "drop_remainder": False,
         },
+        evaluate={"frequency": 10},  # only evaluate the last model
     )
     await server.async_run(config)
diff --git a/test/main/test_main_client.py b/test/main/test_main_client.py
index e002b9dc9455dff31c52f3b91adfd912785da39e..345c1a2005d9fa27d5b7a9fd1e41a00000d293a2 100644
--- a/test/main/test_main_client.py
+++ b/test/main/test_main_client.py
@@ -139,17 +139,6 @@ class TestFederatedClientInit:  # pylint: disable=too-many-public-methods
         client = FederatedClient(netwk=MOCK_NETWK, train_data=dataset)
         assert client.train_data is dataset
 
-    def test_train_data_str(self) -> None:
-        """Test specifying 'train_data' as a file path."""
-        path = "mock_path_to_dataset.json"
-        with mock.patch(
-            "declearn.main._client.load_dataset_from_json",
-            return_value=mock.create_autospec(Dataset, instance=True),
-        ) as patched:
-            client = FederatedClient(netwk=MOCK_NETWK, train_data=path)
-        patched.assert_called_once_with(path)
-        assert client.train_data is patched.return_value
-
     def test_train_data_invalid(self) -> None:
         """Test specifying 'train_data' as an invalid type."""
         with pytest.raises(TypeError):
@@ -172,19 +161,6 @@ class TestFederatedClientInit:  # pylint: disable=too-many-public-methods
         )
         assert client.valid_data is dataset
 
-    def test_valid_data_str(self) -> None:
-        """Test specifying 'valid_data' as a file path."""
-        path = "mock_path_to_dataset.json"
-        with mock.patch(
-            "declearn.main._client.load_dataset_from_json",
-            return_value=mock.create_autospec(Dataset, instance=True),
-        ) as patched:
-            client = FederatedClient(
-                netwk=MOCK_NETWK, train_data=MOCK_DATASET, valid_data=path
-            )
-        patched.assert_called_once_with(path)
-        assert client.valid_data is patched.return_value
-
     def test_valid_data_invalid(self) -> None:
         """Test specifying 'valid_data' as an invalid type."""
         with pytest.raises(TypeError):
diff --git a/test/main/test_main_server.py b/test/main/test_main_server.py
index 1ae08bb5bab6bda583c47b5d87a1eb81827db905..e8e416b20e204d31869098a51724fc2b6d0efa2b 100644
--- a/test/main/test_main_server.py
+++ b/test/main/test_main_server.py
@@ -18,23 +18,52 @@
 """Unit tests for 'FederatedServer'."""
 
 import logging
+import math
 import os
 from unittest import mock
+from typing import Dict, List, Optional, Type
 
 import pytest  # type: ignore
 
-from declearn.aggregator import Aggregator
+from declearn.aggregator import Aggregator, ModelUpdates
 from declearn.communication import NetworkServerConfig
 from declearn.communication.api import NetworkServer
 from declearn.fairness.api import FairnessControllerServer
 from declearn.main import FederatedServer
-from declearn.main.config import FLOptimConfig
+from declearn.main.config import (
+    FLOptimConfig,
+    FLRunConfig,
+    EvaluateConfig,
+    FairnessConfig,
+    RegisterConfig,
+    TrainingConfig,
+)
 from declearn.main.utils import Checkpointer
 from declearn.metrics import MetricSet
+from declearn.messaging import (
+    EvaluationReply,
+    EvaluationRequest,
+    FairnessQuery,
+    InitReply,
+    InitRequest,
+    Message,
+    MetadataQuery,
+    MetadataReply,
+    PrivacyReply,
+    PrivacyRequest,
+    SerializedMessage,
+    StopTraining,
+    TrainRequest,
+    TrainReply,
+)
 from declearn.model.api import Model
 from declearn.model.sklearn import SklearnSGDModel
 from declearn.optimizer import Optimizer
-from declearn.secagg.api import SecaggConfigServer
+from declearn.secagg.api import Decrypter, SecaggConfigServer
+from declearn.secagg.messaging import (
+    SecaggEvaluationReply,
+    SecaggTrainReply,
+)
 from declearn.utils import serialize_object
 
 
@@ -348,3 +377,530 @@
             FederatedServer(
                 MOCK_MODEL, MOCK_NETWK, MOCK_OPTIM, logger=mock.MagicMock()
             )
+
+
+class TestFederatedServerRoutines:
+    """Unit tests for 'FederatedServer' main unitary routines."""
+
+    @staticmethod
+    async def setup_test_server(
+        use_secagg: bool = False,
+        use_fairness: bool = False,
+    ) -> FederatedServer:
+        """Set up a FederatedServer wrapping mock controllers."""
+        netwk = mock.create_autospec(NetworkServer, instance=True)
+        netwk.name = "server"
+        netwk.client_names = {"client_a", "client_b"}
+        optim = FLOptimConfig(
+            client_opt=mock.create_autospec(Optimizer, instance=True),
+            server_opt=mock.create_autospec(Optimizer, instance=True),
+            aggregator=mock.create_autospec(Aggregator, instance=True),
+            fairness=(
+                mock.create_autospec(FairnessControllerServer, instance=True)
+                if use_fairness
+                else None
+            ),
+        )
+        secagg = None  # type: Optional[SecaggConfigServer]
+        if use_secagg:
+            secagg = mock.create_autospec(SecaggConfigServer, instance=True)
+            secagg.secagg_type = "mock_secagg"  # type: ignore
+        return FederatedServer(
+            model=mock.create_autospec(Model, instance=True),
+            netwk=netwk,
+            optim=optim,
+            metrics=mock.create_autospec(MetricSet, instance=True),
+            secagg=secagg,
+            checkpoint=mock.create_autospec(Checkpointer, instance=True),
+        )
+
+    @staticmethod
+    def setup_mock_serialized_message(
+        msg_cls: Type[Message],
+        wrapped: Optional[Message] = None,
+    ) -> mock.NonCallableMagicMock:
+        """Set up a mock SerializedMessage with given wrapped message type."""
+        message = mock.create_autospec(SerializedMessage, instance=True)
+        message.message_cls = msg_cls
+        if wrapped is None:
+            wrapped = mock.create_autospec(msg_cls, instance=True)
+        message.deserialize.return_value = wrapped
+        return message
+
+    @pytest.mark.parametrize(
+        "metadata", [False, True], ids=["nometa", "metadata"]
+    )
+    @pytest.mark.parametrize("privacy", [False, True], ids=["nodp", "dpsgd"])
+    @pytest.mark.parametrize(
+        "fairness", [False, True], ids=["unfair", "fairness"]
+    )
+    @pytest.mark.parametrize("secagg", [False, True], ids=["clrtxt", "secagg"])
+    @pytest.mark.asyncio
+    async def test_initialization(
+        self,
+        secagg: bool,
+        fairness: bool,
+        privacy: bool,
+        metadata: bool,
+    ) -> None:
+        """Test that the 'initialization' routine triggers expected calls."""
+        # Set up a server with mocked attributes.
+        server = await self.setup_test_server(
+            use_secagg=secagg, use_fairness=fairness
+        )
+        assert isinstance(server.netwk, mock.NonCallableMagicMock)
+        assert isinstance(server.model, mock.NonCallableMagicMock)
+        server.model.required_data_info = {"n_samples"} if metadata else {}
+        aggrg = server.aggrg
+        # Run the initialization routine.
+        config = FLRunConfig.from_params(
+            rounds=10,
+            register=RegisterConfig(0, 2, 120),
+            training={"batch_size": 8},
+            privacy=(
+                {"budget": (1e-3, 0.0), "sclip_norm": 1.0} if privacy else None
+            ),
+        )
+        server.netwk.wait_for_messages.side_effect = self._setup_init_replies(
+            metadata, privacy
+        )
+        await server.initialization(config)
+        # Verify that the clients-registration routine was called.
+        server.netwk.wait_for_clients.assert_awaited_once_with(0, 2, 120)
+        # Verify that the expected number of message exchanges occurred.
+        assert server.netwk.broadcast_message.await_count == (
+            1 + metadata + privacy
+        )
+        queries = server.netwk.broadcast_message.await_args_list.copy()
+        # When configured, verify that metadata were queried and used.
+        if metadata:
+            query = queries.pop(0)[0][0]
+            assert isinstance(query, MetadataQuery)
+            assert query.fields == ["n_samples"]
+            server.model.initialize.assert_called_once_with({"n_samples": 200})
+        # Verify that an InitRequest was sent with expected parameters.
+        query = queries.pop(0)[0][0]
+        assert isinstance(query, InitRequest)
+        assert query.dpsgd is privacy
+        if secagg:
+            assert query.secagg is not None
+        else:
+            assert query.secagg is None
+        assert query.fairness is fairness
+        # Verify that DP-SGD setup occurred when expected.
+        if privacy:
+            query = queries.pop(0)[0][0]
+            assert isinstance(query, PrivacyRequest)
+            assert query.budget == (1e-3, 0.0)
+            assert query.sclip_norm == 1.0
+            assert query.rounds == 10
+        # Verify that SecAgg setup occurred when expected.
+        decrypter = None  # type: Optional[Decrypter]
+        if secagg:
+            assert isinstance(server.secagg, mock.NonCallableMagicMock)
+            if fairness:
+                server.secagg.setup_decrypter.assert_awaited_once()
+                decrypter = server.secagg.setup_decrypter.return_value
+            else:
+                server.secagg.setup_decrypter.assert_not_called()
+        # Verify that fairness setup occurred when expected.
+        if fairness:
+            assert isinstance(server.fairness, mock.NonCallableMagicMock)
+            server.fairness.setup_fairness.assert_awaited_once_with(
+                netwk=server.netwk, aggregator=aggrg, secagg=decrypter
+            )
+            assert server.aggrg is server.fairness.setup_fairness.return_value
+
+    def _setup_init_replies(
+        self,
+        metadata: bool,
+        privacy: bool,
+    ) -> List[Dict[str, mock.NonCallableMagicMock]]:
+        clients = ("client_a", "client_b")
+        messages = []  # type: List[Dict[str, mock.NonCallableMagicMock]]
+        if metadata:
+            msg = MetadataReply({"n_samples": 100})
+            messages.append(
+                {
+                    key: self.setup_mock_serialized_message(MetadataReply, msg)
+                    for key in clients
+                }
+            )
+        messages.append(
+            {
+                key: self.setup_mock_serialized_message(InitReply)
+                for key in clients
+            }
+        )
+        if privacy:
+            messages.append(
+                {
+                    key: self.setup_mock_serialized_message(PrivacyReply)
+                    for key in clients
+                }
+            )
+        return messages
+
+    @pytest.mark.parametrize("secagg", [False, True], ids=["clrtxt", "secagg"])
+    @pytest.mark.asyncio
+    async def test_training_round(
+        self,
+        secagg: bool,
+    ) -> None:
+        """Test that the 'training_round' routine triggers expected calls."""
+        # Set up a server with mocked attributes.
+        server = await self.setup_test_server(use_secagg=secagg)
+        assert isinstance(server.netwk, mock.NonCallableMagicMock)
+        assert isinstance(server.model, mock.NonCallableMagicMock)
+        assert isinstance(server.optim, mock.NonCallableMagicMock)
+        assert isinstance(server.aggrg, mock.NonCallableMagicMock)
+        # Mock-run a training routine.
+        reply_cls = (
+            SecaggTrainReply if secagg else TrainReply  # type: ignore
+        )  # type: Type[Message]
+        updates = mock.create_autospec(ModelUpdates, instance=True)
+        reply_msg = TrainReply(
+            n_epoch=1, n_steps=10, t_spent=0.0, updates=updates, aux_var={}
+        )
+        wrapped = None if secagg else reply_msg
+        server.netwk.wait_for_messages.return_value = {
+            "client_a": self.setup_mock_serialized_message(reply_cls, wrapped),
+            "client_b": self.setup_mock_serialized_message(reply_cls, wrapped),
+        }
+        with mock.patch(
+            "declearn.secagg.messaging.aggregate_secagg_messages",
+            return_value=reply_msg,
+        ) as patch_aggregate_secagg_messages:
+            await server.training_round(
+                round_i=1, train_cfg=TrainingConfig(batch_size=8)
+            )
+        # Verify that expected actions occurred.
+        # (a) optional secagg setup
+        if secagg:
+            assert isinstance(server.secagg, mock.NonCallableMagicMock)
+            server.secagg.setup_decrypter.assert_awaited_once()
+        # (b) training request emission, including model weights
+        server.netwk.send_messages.assert_awaited_once()
+        queries = server.netwk.send_messages.await_args[0][0]
+        assert isinstance(queries, dict)
+        assert queries.keys() == server.netwk.client_names
+        for query in queries.values():
+            assert isinstance(query, TrainRequest)
+            assert query.weights is server.model.get_weights.return_value
+            assert query.aux_var is server.optim.collect_aux_var.return_value
+        # (c) training reply reception
+        server.netwk.wait_for_messages.assert_awaited_once()
+        if secagg:
+            patch_aggregate_secagg_messages.assert_called_once()
+        else:
+            patch_aggregate_secagg_messages.assert_not_called()
+        # (d) updates aggregation and global model weights update
+        server.optim.process_aux_var.assert_called_once()
+        server.aggrg.finalize_updates.assert_called_once()
+        server.optim.apply_gradients.assert_called_once()
+
+    @pytest.mark.parametrize("secagg", [False, True], ids=["clrtxt", "secagg"])
+    @pytest.mark.asyncio
+    async def test_evaluation_round(
+        self,
+        secagg: bool,
+    ) -> None:
+        """Test that the 'evaluation_round' routine triggers expected calls."""
+        # Set up a server with mocked attributes.
+        server = await self.setup_test_server(use_secagg=secagg)
+        assert isinstance(server.netwk, mock.NonCallableMagicMock)
+        assert isinstance(server.model, mock.NonCallableMagicMock)
+        assert isinstance(server.metrics, mock.NonCallableMagicMock)
+        assert isinstance(server.ckptr, mock.NonCallableMagicMock)
+        # Mock-run an evaluation routine.
+        reply_cls = (
+            SecaggEvaluationReply  # type: ignore
+            if secagg
+            else EvaluationReply
+        )  # type: Type[Message]
+        reply_msg = EvaluationReply(
+            loss=0.42, n_steps=10, t_spent=0.0, metrics={}
+        )
+        wrapped = None if secagg else reply_msg
+        server.netwk.wait_for_messages.return_value = {
+            "client_a": self.setup_mock_serialized_message(reply_cls, wrapped),
+            "client_b": self.setup_mock_serialized_message(reply_cls, wrapped),
+        }
+        with mock.patch(
+            "declearn.secagg.messaging.aggregate_secagg_messages",
+            return_value=reply_msg,
+        ) as patch_aggregate_secagg_messages:
+            await server.evaluation_round(
+                round_i=1, valid_cfg=EvaluateConfig(batch_size=8)
+            )
+        # Verify that expected actions occurred.
+        # (a) optional secagg setup
+        if secagg:
+            assert isinstance(server.secagg, mock.NonCallableMagicMock)
+            server.secagg.setup_decrypter.assert_awaited_once()
+        # (b) evaluation request emission, including model weights
+        server.netwk.send_messages.assert_awaited_once()
+        queries = server.netwk.send_messages.await_args[0][0]
+        assert isinstance(queries, dict)
+        assert queries.keys() == server.netwk.client_names
+        for query in queries.values():
+            assert isinstance(query, EvaluationRequest)
+            assert query.weights is server.model.get_weights.return_value
+        # (c) evaluation reply reception
+        server.netwk.wait_for_messages.assert_awaited_once()
+        if secagg:
+            patch_aggregate_secagg_messages.assert_called_once()
+        else:
+            patch_aggregate_secagg_messages.assert_not_called()
+        # (d) metrics aggregation
+        server.metrics.reset.assert_called_once()
+        server.metrics.set_states.assert_called_once()
+        server.metrics.get_result.assert_called_once()
+        # (e) checkpointing
+        server.ckptr.checkpoint.assert_called_once_with(
+            model=server.model,
+            optimizer=server.optim,
+            metrics=server.metrics.get_result.return_value,
+        )
+
+    @pytest.mark.asyncio
+    async def test_evaluation_round_skip(
+        self,
+    ) -> None:
+        """Test that 'evaluation_round' skips rounds when configured."""
+        # Set up a server with mocked attributes.
+        server = await self.setup_test_server()
+        assert isinstance(server.netwk, mock.NonCallableMagicMock)
+        # Mock a call that should result in skipping the round.
+        await server.evaluation_round(
+            round_i=1,
+            valid_cfg=EvaluateConfig(batch_size=8, frequency=2),
+        )
+        # Assert that no message was sent (routine was skipped).
+        server.netwk.broadcast_message.assert_not_called()
+        server.netwk.send_messages.assert_not_called()
+        server.netwk.send_message.assert_not_called()
+
+    @pytest.mark.parametrize("secagg", [False, True], ids=["clrtxt", "secagg"])
+    @pytest.mark.asyncio
+    async def test_fairness_round(
+        self,
+        secagg: bool,
+    ) -> None:
+        """Test that the 'fairness_round' routine triggers expected calls."""
+        # Set up a server with mocked attributes.
+        server = await self.setup_test_server(
+            use_secagg=secagg, use_fairness=True
+        )
+        assert isinstance(server.netwk, mock.NonCallableMagicMock)
+        assert isinstance(server.model, mock.NonCallableMagicMock)
+        assert isinstance(server.fairness, mock.NonCallableMagicMock)
+        assert isinstance(server.ckptr, mock.NonCallableMagicMock)
+        # Mock-run a fairness routine.
+        await server.fairness_round(
+            round_i=0,
+            fairness_cfg=FairnessConfig(),
+        )
+        # Verify that expected actions occurred.
+        # (a) optional secagg setup
+        decrypter = None  # type: Optional[Decrypter]
+        if secagg:
+            assert isinstance(server.secagg, mock.NonCallableMagicMock)
+            server.secagg.setup_decrypter.assert_awaited_once()
+            decrypter = server.secagg.setup_decrypter.return_value
+        # (b) fairness query emission, including model weights
+        server.netwk.send_messages.assert_awaited_once()
+        queries = server.netwk.send_messages.await_args[0][0]
+        assert isinstance(queries, dict)
+        assert queries.keys() == server.netwk.client_names
+        for query in queries.values():
+            assert isinstance(query, FairnessQuery)
+            assert query.weights is server.model.get_weights.return_value
+        # (c) fairness controller round routine
+        server.fairness.run_fairness_round.assert_awaited_once_with(
+            netwk=server.netwk, secagg=decrypter
+        )
+        # (d) checkpointing
+        server.ckptr.save_metrics.assert_called_once_with(
+            metrics=server.fairness.run_fairness_round.return_value,
+            prefix="fairness_metrics",
+            append=False,
+            timestamp="round_0",
+        )
+
+    @pytest.mark.asyncio
+    async def test_fairness_round_undefined(
+        self,
+    ) -> None:
+        """Test that 'fairness_round' early-exits when fairness is not set."""
+        # Set up a server with mocked attributes and no fairness controller.
+        server = await self.setup_test_server(use_fairness=False)
+        assert isinstance(server.netwk, mock.NonCallableMagicMock)
+        assert server.fairness is None
+        # Call the fairness round routine.
+        await server.fairness_round(
+            round_i=0,
+            fairness_cfg=FairnessConfig(),
+        )
+        # Assert that no message was sent (routine was skipped).
+        server.netwk.broadcast_message.assert_not_called()
+        server.netwk.send_messages.assert_not_called()
+        server.netwk.send_message.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_fairness_round_skip(
+        self,
+    ) -> None:
+        """Test that 'fairness_round' skips rounds when configured."""
+        # Set up a server with a mocked fairness controller.
+        server = await self.setup_test_server(use_fairness=True)
+        assert isinstance(server.fairness, mock.NonCallableMagicMock)
+        # Mock a call that should result in skipping the round.
+        await server.fairness_round(
+            round_i=1,
+            fairness_cfg=FairnessConfig(frequency=2),
+        )
+        # Assert that the round was skipped.
+        server.fairness.run_fairness_round.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_stop_training(
+        self,
+    ) -> None:
+        """Test that 'stop_training' triggers expected actions."""
+        # Set up a server with mocked attributes.
+        server = await self.setup_test_server()
+        assert isinstance(server.netwk, mock.NonCallableMagicMock)
+        assert isinstance(server.model, mock.NonCallableMagicMock)
+        assert isinstance(server.ckptr, mock.NonCallableMagicMock)
+        server.ckptr.folder = "mock_folder"
+        # Call the 'stop_training' routine.
+        await server.stop_training(rounds=5)
+        # Verify that the expected message was broadcast.
+        server.netwk.broadcast_message.assert_awaited_once()
+        message = server.netwk.broadcast_message.await_args[0][0]
+        assert isinstance(message, StopTraining)
+        assert message.weights is server.model.get_weights.return_value
+        assert math.isnan(message.loss)
+        assert message.rounds == 5
+        # Verify that the expected checkpointing occurred.
+        server.ckptr.save_model.assert_called_once_with(
+            server.model, timestamp="best"
+        )
+
+
+class TestFederatedServerRun:
+    """Unit tests for 'FederatedServer.run' and 'async_run' routines."""
+
+    # Unit tests for FLRunConfig parsing via synchronous 'run' method.
+
+    def test_run_from_dict(
+        self,
+    ) -> None:
+        """Test that 'run' properly parses input dict config.
+
+        Mock the actual underlying routine.
+        """
+        server = FederatedServer(
+            model=MOCK_MODEL, netwk=MOCK_NETWK, optim=MOCK_OPTIM
+        )
+        config = mock.create_autospec(dict, instance=True)
+        with mock.patch.object(
+            FLRunConfig,
+            "from_params",
+            return_value=mock.create_autospec(FLRunConfig, instance=True),
+        ) as patch_flrunconfig_from_params:
+            with mock.patch.object(server, "async_run") as patch_async_run:
+                server.run(config)
+        patch_flrunconfig_from_params.assert_called_once_with(**config)
+        patch_async_run.assert_called_once_with(
+            patch_flrunconfig_from_params.return_value
+        )
+
+    def test_run_from_toml(
+        self,
+    ) -> None:
+        """Test that 'run' properly parses input TOML file.
+
+        Mock the actual underlying routine.
+        """
+        server = FederatedServer(
+            model=MOCK_MODEL, netwk=MOCK_NETWK, optim=MOCK_OPTIM
+        )
+        config = "mock_path.toml"
+        with mock.patch.object(
+            FLRunConfig,
+            "from_toml",
+            return_value=mock.create_autospec(FLRunConfig, instance=True),
+        ) as patch_flrunconfig_from_toml:
+            with mock.patch.object(server, "async_run") as patch_async_run:
+                server.run(config)
+        patch_flrunconfig_from_toml.assert_called_once_with(config)
+        patch_async_run.assert_called_once_with(
+            patch_flrunconfig_from_toml.return_value
+        )
+
+    def test_run_from_config(
+        self,
+    ) -> None:
+        """Test that 'run' properly uses input FLRunConfig.
+
+        Mock the actual underlying routine.
+        """
+        server = FederatedServer(
+            model=MOCK_MODEL, netwk=MOCK_NETWK, optim=MOCK_OPTIM
+        )
+        config = mock.create_autospec(FLRunConfig, instance=True)
+        with mock.patch.object(server, "async_run") as patch_async_run:
+            server.run(config)
+        patch_async_run.assert_called_once_with(config)
+
+    # Unit tests for overall actions sequence in 'async_run'.
+
+    @pytest.mark.asyncio
+    async def test_async_run_actions_sequence(self) -> None:
+        """Test that 'async_run' triggers expected routines."""
+        # Set up a server and a run config with mock attributes.
+        server = FederatedServer(
+            model=MOCK_MODEL,
+            netwk=MOCK_NETWK,
+            optim=MOCK_OPTIM,
+            checkpoint=mock.create_autospec(Checkpointer, instance=True),
+        )
+        config = FLRunConfig(
+            rounds=10,
+            register=mock.create_autospec(RegisterConfig, instance=True),
+            training=mock.create_autospec(TrainingConfig, instance=True),
+            evaluate=mock.create_autospec(EvaluateConfig, instance=True),
+            fairness=mock.create_autospec(FairnessConfig, instance=True),
+            privacy=None,
+            early_stop=None,
+        )
+        # Call 'async_run', mocking all underlying routines.
+        with mock.patch.object(
+            server, "initialization"
+        ) as patch_initialization:
+            with mock.patch.object(server, "training_round") as patch_training:
+                with mock.patch.object(
+                    server, "evaluation_round"
+                ) as patch_evaluation:
+                    with mock.patch.object(
+                        server, "fairness_round"
+                    ) as patch_fairness:
+                        with mock.patch.object(
+                            server, "stop_training"
+                        ) as patch_stop_training:
+                            await server.async_run(config)
+        # Verify that expected calls occurred.
+        patch_initialization.assert_called_once_with(config)
+        patch_training.assert_has_calls(
+            [mock.call(idx, config.training) for idx in range(1, 11)]
+        )
+        patch_evaluation.assert_has_calls(
+            [mock.call(idx, config.evaluate) for idx in range(1, 11)]
+        )
+        patch_fairness.assert_has_calls(
+            [mock.call(idx, config.fairness) for idx in range(0, 10)]
+            + [mock.call(10, config.fairness, force_run=True)]
+        )
+        patch_stop_training.assert_called_once_with(10)
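The three `test_run_from_*` cases above assert that `FederatedServer.run` accepts any of a parameters dict, a TOML file path, or a ready-made `FLRunConfig`. In usage terms, this may be sketched as follows; paths and values are placeholders, and `server` stands for any instantiated `FederatedServer`:

```python
from declearn.main import FederatedServer
from declearn.main.config import FLRunConfig

def launch(server: FederatedServer) -> None:
    """Sketch the three interchangeable ways of passing a run config."""
    # (a) a dict of parameters, parsed through FLRunConfig.from_params
    server.run(
        {"rounds": 10, "register": {"min_clients": 2},
         "training": {"batch_size": 32}}
    )
    # (b) a path to a TOML file, parsed through FLRunConfig.from_toml
    server.run("path/to/run_config.toml")
    # (c) a readily-instantiated FLRunConfig
    server.run(FLRunConfig.from_params(
        rounds=10,
        register={"min_clients": 2},
        training={"batch_size": 32},
    ))
```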
diff --git a/test/metrics/metric_testing.py b/test/metrics/metric_testing.py
index 2adb2fd9a7998bff3b87fe8760ff9ba0ec5e3d2b..043c63410e26e007897c41976f2f232c628f3abb 100644
--- a/test/metrics/metric_testing.py
+++ b/test/metrics/metric_testing.py
@@ -177,32 +177,6 @@ class MetricTestSuite:
         expect = test_case.agg_scores
         assert_dict_equal(metric.get_result(), expect, np_tolerance=self.tol)
 
-    def test_legacy_agg_states(self, test_case: MetricTestCase) -> None:
-        """Test that the deprecated `agg_states` method works as expected."""
-        # Set up and update two identical metrics.
-        metric = test_case.metric
-        metbis = deepcopy(test_case.metric)
-        metric.update(**test_case.inputs)
-        metbis.update(**test_case.inputs)
-        # Aggregate the second into the first. Verify that they now differ.
-        assert_dict_equal(
-            metric.get_states().to_dict(), metbis.get_states().to_dict()
-        )
-        with pytest.warns(DeprecationWarning):
-            metbis.agg_states(metric.get_states())
-        assert_dict_equal(
-            metric.get_states().to_dict(), test_case.states.to_dict()
-        )
-        with pytest.raises(AssertionError):  # assert not equal
-            assert_dict_equal(
-                metric.get_states().to_dict(), metbis.get_states().to_dict()
-            )
-        # Verify the correctness of the aggregated states and scores.
-        states = test_case.agg_states
-        scores = test_case.agg_scores
-        assert_dict_equal(metbis.get_states().to_dict(), states.to_dict())
-        assert_dict_equal(metbis.get_result(), scores, np_tolerance=self.tol)
-
     def test_update_with_squeezable_inputs(
         self, test_case: MetricTestCase
     ) -> None:
diff --git a/test/metrics/test_metricset.py b/test/metrics/test_metricset.py
index 1ff2eae25120c1776f963a503dfa7b4215884b45..8eb7b445f9f2f58e447d06fdffda7ac12560f0e8 100644
--- a/test/metrics/test_metricset.py
+++ b/test/metrics/test_metricset.py
@@ -18,7 +18,7 @@
 """Unit tests for `declearn.metrics.MetricSet`."""
 
 from unittest import mock
-from typing import Dict, Tuple
+from typing import Tuple
 
 import numpy as np
 import pytest
@@ -129,18 +129,6 @@ class TestMetricSet:
             states["mse"]
         )
 
-    def test_agg_states(self) -> None:
-        """Test that deprecated `MetricSet.agg_states` works as expected."""
-        mae, mse, metrics = get_mock_metricset()
-        states = {
-            "mae": mae.get_states(),
-            "mse": mse.get_states(),
-        }  # type: Dict[str, MetricState]
-        with pytest.warns(DeprecationWarning):
-            metrics.agg_states(states)
-        mae.agg_states.assert_called_once_with(states["mae"])  # type: ignore
-        mse.agg_states.assert_called_once_with(states["mse"])  # type: ignore
-
     def test_get_config(self) -> None:
         """Test that `MetricSet.get_config` works as expected."""
         mae = MeanAbsoluteError()
diff --git a/test/training/test_train_manager.py b/test/training/test_train_manager.py
index 271c2a1f52d1d8a88596f66315f2101c2f1403b2..f59ae5b154ed7bcc2457f6b5ce58aed7c7d876be 100644
--- a/test/training/test_train_manager.py
+++ b/test/training/test_train_manager.py
@@ -22,8 +22,8 @@
 from typing import Any, Iterator, Optional
 
 import numpy
 
+from declearn import messaging
 from declearn.aggregator import Aggregator
-from declearn.communication import messaging
 from declearn.dataset import Dataset
 from declearn.metrics import Metric, MetricSet
 from declearn.model.api import Model, Vector
diff --git a/test/training/test_train_manager_dp.py b/test/training/test_train_manager_dp.py
index 01c62d7ff9c09f32d190055b6f147c94886d17d1..83edd26761b3738937afbdee326789d5b6db7847 100644
--- a/test/training/test_train_manager_dp.py
+++ b/test/training/test_train_manager_dp.py
@@ -28,7 +28,7 @@ try:
 except ModuleNotFoundError:
     pytest.skip("Opacus is unavailable", allow_module_level=True)
 
-from declearn.communication import messaging
+from declearn import messaging
 from declearn.dataset import DataSpecs
 from declearn.optimizer.modules import GaussianNoiseModule
 from declearn.training.dp import DPTrainingManager
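The two training-test diffs above reflect the relocation of the messaging tools: message classes now live under `declearn.messaging` rather than `declearn.communication.messaging` (the path also used by the server-test imports earlier in this patch). Downstream code would update its imports accordingly; a minimal sketch:

```python
# Before (older DecLearn versions):
# from declearn.communication import messaging

# After: the submodule is exposed at the package root.
from declearn import messaging

# Message classes resolve from the new location, e.g.:
print(messaging.TrainRequest)
```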