From 8846f56b986ff5d033803438034d3173ea062d6f Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Fri, 26 Jul 2024 10:52:55 +0200 Subject: [PATCH] Remove deprecated features planned for v2.6 removal. --- declearn/aggregator/_api.py | 43 ----- declearn/aggregator/_avg.py | 87 +--------- declearn/aggregator/_gma.py | 36 +---- declearn/communication/__init__.py | 8 - declearn/communication/api/_client.py | 17 -- declearn/communication/messaging/__init__.py | 57 ------- declearn/communication/messaging/_messages.py | 153 ------------------ declearn/dataset/__init__.py | 4 +- declearn/dataset/_base.py | 42 +---- declearn/main/_client.py | 10 +- declearn/metrics/_api.py | 41 ----- declearn/metrics/_wrapper.py | 44 ----- declearn/optimizer/modules/_scaffold.py | 19 +-- docs/release-notes/v2.6.0.md | 26 +++ test/aggregator/test_aggregator.py | 28 ---- test/main/test_main_client.py | 24 --- test/metrics/metric_testing.py | 26 --- test/metrics/test_metricset.py | 14 +- 18 files changed, 36 insertions(+), 643 deletions(-) delete mode 100644 declearn/communication/messaging/__init__.py delete mode 100644 declearn/communication/messaging/_messages.py diff --git a/declearn/aggregator/_api.py b/declearn/aggregator/_api.py index 468b837b..56813d50 100644 --- a/declearn/aggregator/_api.py +++ b/declearn/aggregator/_api.py @@ -19,7 +19,6 @@ import abc import dataclasses -import warnings from typing import Any, ClassVar, Dict, Generic, Type, TypeVar, Union from typing_extensions import Self # future: import from typing (py >=3.11) @@ -185,48 +184,6 @@ class Aggregator(Generic[ModelUpdatesT], metaclass=abc.ABCMeta): """Instantiate an Aggregator from its configuration dict.""" return cls(**config) - def aggregate( - self, - updates: Dict[str, Vector[T]], - n_steps: Dict[str, int], # revise: abstract~generalize kwargs use - ) -> Vector[T]: - """DEPRECATED - Aggregate input vectors into a single one. - - Parameters - ---------- - updates: dict[str, Vector] - Client-wise updates, as a dictionary with clients' names as - string keys and updates as Vector values. - n_steps: dict[str, int] - Client-wise number of local training steps performed during - the training round having produced the updates. - - Returns - ------- - gradients: Vector - Aggregated updates, as a Vector - treated as gradients by - the server-side optimizer. - - Raises - ------ - TypeError - If the input `updates` are an empty dict. - """ - warnings.warn( - "'Aggregator.aggregate' was deprecated in DecLearn v2.4 in favor " - "of new API methods. It will be removed in DecLearn v2.6 and/or " - "v3.0.", - DeprecationWarning, - ) - if not updates: - raise TypeError("'Aggregator.aggregate' received an empty dict.") - partials = [ - self.prepare_for_sharing(updates[client], n_steps[client]) - for client in updates - ] - aggregated = sum(partials[1:], start=partials[0]) - return self.finalize_updates(aggregated) - def list_aggregators() -> Dict[str, Type[Aggregator]]: """Return a mapping of registered Aggregator subclasses. diff --git a/declearn/aggregator/_avg.py b/declearn/aggregator/_avg.py index 05dfa5b4..333b4b9d 100644 --- a/declearn/aggregator/_avg.py +++ b/declearn/aggregator/_avg.py @@ -17,8 +17,7 @@ """FedAvg-like mean-aggregation class.""" -import warnings -from typing import Any, Dict, Optional +from typing import Any, Dict from declearn.aggregator._api import Aggregator, ModelUpdates @@ -44,7 +43,6 @@ class AveragingAggregator(Aggregator[ModelUpdates]): def __init__( self, steps_weighted: bool = True, - client_weights: Optional[Dict[str, float]] = None, ) -> None: """Instantiate an averaging aggregator. @@ -53,37 +51,14 @@ class AveragingAggregator(Aggregator[ModelUpdates]): steps_weighted: Whether to conduct a weighted averaging of local model updates based on local numbers of training steps. - client_weights: - DEPRECATED - this argument no longer affects computations, - save when using the deprecated 'aggregate' method. - Optional dict of client-wise base weights to use. - If None, homogeneous base weights are used. - - Notes - ----- - * One may specify `client_weights` and use `steps_weighted=True`. - In that case, the product of the client's base weight and their - number of training steps taken will be used (and unit-normed). - * One may use incomplete `client_weights`. In that case, unknown- - clients' base weights will be set to 1. """ self.steps_weighted = steps_weighted - self.client_weights = client_weights or {} - if client_weights: # pragma: no cover - warnings.warn( - f"'client_weights' argument to '{self.__class__.__name__}' was" - " deprecated in DecLearn v2.4 and is no longer used, saved by" - " the deprecated 'aggregate' method. It will be removed in" - " DecLearn v2.6 and/or v3.0.", - DeprecationWarning, - ) def get_config( self, ) -> Dict[str, Any]: return { "steps_weighted": self.steps_weighted, - "client_weights": self.client_weights, } def prepare_for_sharing( @@ -103,63 +78,3 @@ class AveragingAggregator(Aggregator[ModelUpdates]): updates: ModelUpdates, ) -> Vector: return updates.updates / updates.weights - - def aggregate( - self, - updates: Dict[str, Vector], - n_steps: Dict[str, int], - ) -> Vector: - # Make use of 'client_weights' as part of this DEPRECATED method. - with warnings.catch_warnings(): - warnings.simplefilter(action="ignore", category=DeprecationWarning) - weights = self.compute_client_weights(updates, n_steps) - steps_weighted = self.steps_weighted - try: - self.steps_weighted = True - return super().aggregate(updates, weights) # type: ignore - finally: - self.steps_weighted = steps_weighted - - def compute_client_weights( - self, - updates: Dict[str, Vector], - n_steps: Dict[str, int], - ) -> Dict[str, float]: - """Compute weights to use when averaging a given set of updates. - - This method is DEPRECATED as of DecLearn v2.4. - It will be removed in DecLearn 2.6 and/or 3.0. - - Parameters - ---------- - updates: dict[str, Vector] - Client-wise updates, as a dictionary with clients' names as - string keys and updates as Vector values. - n_steps: dict[str, int] - Client-wise number of local training steps performed during - the training round having produced the updates. - - Returns - ------- - weights: dict[str, float] - Client-wise updates-averaging weights, suited to the input - parameters and normalized so that they sum to 1. - """ - warnings.warn( - f"'{self.__class__.__name__}.compute_client_weights' was" - " deprecated in DecLearn v2.4. It will be removed in DecLearn" - " v2.6 and/or v3.0.", - DeprecationWarning, - ) - if self.steps_weighted: - weights = { - client: steps * self.client_weights.get(client, 1.0) - for client, steps in n_steps.items() - } - else: - weights = { - client: self.client_weights.get(client, 1.0) - for client in updates - } - total = sum(weights.values()) - return {client: weight / total for client, weight in weights.items()} diff --git a/declearn/aggregator/_gma.py b/declearn/aggregator/_gma.py index dfdcc1d7..4832d4a6 100644 --- a/declearn/aggregator/_gma.py +++ b/declearn/aggregator/_gma.py @@ -18,7 +18,6 @@ """Gradient Masked Averaging aggregation class.""" import dataclasses -import warnings from typing import Any, Dict, Optional, Tuple from typing_extensions import Self # future: import from typing (py >=3.11) @@ -99,7 +98,6 @@ class GradientMaskedAveraging(Aggregator[GMAModelUpdates]): self, threshold: float = 1.0, steps_weighted: bool = True, - client_weights: Optional[Dict[str, float]] = None, ) -> None: """Instantiate a gradient masked averaging aggregator. @@ -111,20 +109,9 @@ class GradientMaskedAveraging(Aggregator[GMAModelUpdates]): steps_weighted: bool, default=True Whether to weight updates based on the number of optimization steps taken by the clients (relative to one another). - client_weights: dict[str, float] or None, default=None - Optional dict of client-wise base weights to use. - If None, homogeneous base weights are used. - - Notes - ----- - * One may specify `client_weights` and use `steps_weighted=True`. - In that case, the product of the client's base weight and their - number of training steps taken will be used (and unit-normed). - * One may use incomplete `client_weights`. In that case, unknown- - clients' base weights will be set to 1. """ self.threshold = threshold - self._avg = AveragingAggregator(steps_weighted, client_weights) + self._avg = AveragingAggregator(steps_weighted) def get_config( self, @@ -162,24 +149,3 @@ class GradientMaskedAveraging(Aggregator[GMAModelUpdates]): scores = (1 - clip) * scores + clip # s = 1 if s > t else s # Correct outputs' magnitude and return them. return values * scores - - def compute_client_weights( # pragma: no cover - self, - updates: Dict[str, Vector], - n_steps: Dict[str, int], - ) -> Dict[str, float]: - """Compute weights to use when averaging a given set of updates. - - This method is DEPRECATED as of DecLearn v2.4. - It will be removed in DecLearn 2.6 and/or 3.0. - """ - # pylint: disable=duplicate-code - warnings.warn( - f"'{self.__class__.__name__}.compute_client_weights' was" - " deprecated in DecLearn v2.4. It will be removed in DecLearn" - " v2.6 and/or v3.0.", - DeprecationWarning, - ) - with warnings.catch_warnings(): - warnings.simplefilter(action="ignore", category=DeprecationWarning) - return self._avg.compute_client_weights(updates, n_steps) diff --git a/declearn/communication/__init__.py b/declearn/communication/__init__.py index 93a508e0..5f15e89b 100644 --- a/declearn/communication/__init__.py +++ b/declearn/communication/__init__.py @@ -28,7 +28,6 @@ This module contains the following core submodules: * [utils][declearn.communication.utils]: Utils related to network communication endpoints' setup and usage. - It re-exports publicly from `utils` the following elements: * [build_client][declearn.communication.build_client]: @@ -52,10 +51,6 @@ the associated third-party dependencies are available: * [websockets][declearn.communication.websockets]: WebSockets-based network communication endpoints. Requires the `websockets` third-party package. - -Additionnally, for retro-compatibility purposes, it exports the DEPRECATED -[messaging][declearn.communication.messaging] submodule, that should no -longer be used, as its contents were re-dispatched elsewhere in DecLearn. """ # Messaging API and base tools: @@ -79,6 +74,3 @@ try: from . import websockets except ImportError: # pragma: no cover _INSTALLABLE_BACKENDS["websockets"] = ("websockets",) - -# DEPRECATED submodule, kept for retro-compatibility until 2.6 and/or 3.0. -from . import messaging diff --git a/declearn/communication/api/_client.py b/declearn/communication/api/_client.py index fa5f1648..2482d315 100644 --- a/declearn/communication/api/_client.py +++ b/declearn/communication/api/_client.py @@ -339,20 +339,3 @@ class NetworkClient(metaclass=abc.ABCMeta): ) self.logger.critical(error) raise TypeError(error) - - async def check_message( - self, - timeout: Optional[float] = None, - ) -> SerializedMessage: - """Await a message from the server, with optional timeout. - - This method is DEPRECATED in favor of the `recv_message` one. - It acts as an alias and will be removed in v2.6 and/or v3.0. - """ - warnings.warn( - "'NetworkServer.check_message' was renamed as 'recv_message' " - "in DecLearn 2.4. It now acts as an alias, but will be removed " - "in version 2.6 and/or 3.0.", - DeprecationWarning, - ) - return await self.recv_message(timeout) diff --git a/declearn/communication/messaging/__init__.py b/declearn/communication/messaging/__init__.py deleted file mode 100644 index 510fe6d6..00000000 --- a/declearn/communication/messaging/__init__.py +++ /dev/null @@ -1,57 +0,0 @@ -# coding: utf-8 - -# Copyright 2023 Inria (Institut National de Recherche en Informatique -# et Automatique) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""DEPRECATED submodule defining messaging containers and flags. - -This submodule was deprecated in DecLearn 2.4 in favor of `declearn.messaging`. -It should no longer be used, and will be removed in 2.6 and/or 3.0. - -Most of its contents are re-exports of non-deprecated classes and functions -from 'declearn.messaging'. Others will trigger deprecation warnings (and may -cause failures) if used. - -Deprecated classes uniquely-defined here are: - -* [Empty][declearn.communication.messaging.Empty] -* [GetMessageRequest][declearn.communication.messaging.GetMessageRequest] -* [JoinReply][declearn.communication.messaging.JoinReply] -* [JoinRequest][declearn.communication.messaging.JoinRequest] - -The `flags` submodule is also re-exported, but should preferably be imported -as `declearn.communication.api.backend.flags`. -""" - -from declearn.communication.api.backend import flags - -from ._messages import ( - CancelTraining, - Empty, - Error, - GenericMessage, - GetMessageRequest, - EvaluationReply, - EvaluationRequest, - InitRequest, - JoinReply, - JoinRequest, - Message, - PrivacyRequest, - StopTraining, - TrainReply, - TrainRequest, - parse_message_from_string, -) diff --git a/declearn/communication/messaging/_messages.py b/declearn/communication/messaging/_messages.py deleted file mode 100644 index f81574c2..00000000 --- a/declearn/communication/messaging/_messages.py +++ /dev/null @@ -1,153 +0,0 @@ -# coding: utf-8 - -# Copyright 2023 Inria (Institut National de Recherche en Informatique -# et Automatique) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Dataclasses defining messages used in declearn communications.""" - -import abc -import dataclasses -import warnings -from typing import Any, Dict, Optional - - -from declearn.messaging import ( - CancelTraining, - Error, - EvaluationReply, - EvaluationRequest, - GenericMessage, - InitRequest, - Message, - PrivacyRequest, - SerializedMessage, - StopTraining, - TrainReply, - TrainRequest, -) - - -__all__ = [ - "CancelTraining", - "Empty", - "Error", - "EvaluationReply", - "EvaluationRequest", - "GenericMessage", - "GetMessageRequest", - "InitRequest", - "JoinReply", - "JoinRequest", - "Message", - "PrivacyRequest", - "StopTraining", - "TrainReply", - "TrainRequest", - "parse_message_from_string", -] - - -@dataclasses.dataclass -class DeprecatedMessage(Message, register=False, metaclass=abc.ABCMeta): - """DEPRECATED Message subtype.""" - - def __post_init__( - self, - ) -> None: - warnings.warn( - f"'{self.__class__.__name__}' was deprecated in DecLearn v2.4. " - "It should no longer be used and may cause failures. It will be " - "removed in DecLearn v2.6 and/or v3.0", - DeprecationWarning, - ) - - -@dataclasses.dataclass -class Empty(DeprecatedMessage): - """DEPRECATED empty message class.""" - - typekey = "empty" - - -@dataclasses.dataclass -class GetMessageRequest(DeprecatedMessage): - """DEPRECATED message-retrieval query message class.""" - - typekey = "get_message" - - timeout: Optional[int] = None - - -@dataclasses.dataclass -class JoinRequest(DeprecatedMessage): - """DEPRECATED process joining query message class.""" - - typekey = "join_request" - - name: str - data_info: Dict[str, Any] - version: Optional[str] = None - - -@dataclasses.dataclass -class JoinReply(DeprecatedMessage): - """DEPRECATED process joining reply message class.""" - - typekey = "join_reply" - - accept: bool - flag: str - - -def parse_message_from_string( - string: str, -) -> Message: - """DEPRECATED - Instantiate a Message from its serialized string. - - This function was DEPRECATED in DecLearn 2.4 and will be removed - in v2.6 and/or v3.0. Use the `declearn.messaging.SerializedMessage` - API to parse serialized message strings. - - Parameters - ---------- - string: - Serialized string dump of the message. - - Returns - ------- - message: - Message instance recovered from the input string. - - Raises - ------ - KeyError - If the string's typekey does not match any supported Message - subclass. - TypeError - If the string cannot be parsed to identify a message typekey. - ValueError - If the serialized data fails to be properly decoded. - """ - warnings.warn( - "'parse_message_from_string' was deprecated in DecLearn 2.4, in " - "favor of using 'declearn.messaging.SerializedMessage' to parse " - "and deserialize 'Message' instances from strings. It will be " - "removed in DecLearn version 2.6 and/or 3.0.", - DeprecationWarning, - ) - serialized = SerializedMessage.from_message_string( - string - ) # type: SerializedMessage[Any] - return serialized.deserialize() diff --git a/declearn/dataset/__init__.py b/declearn/dataset/__init__.py index 7436e1ce..d885c587 100644 --- a/declearn/dataset/__init__.py +++ b/declearn/dataset/__init__.py @@ -29,8 +29,6 @@ API tools Abstract base class defining an API to access training or testing data. * [DataSpec][declearn.dataset.DataSpecs]: Dataclass to wrap a dataset's metadata. -* [load_dataset_from_json][declearn.dataset.load_dataset_from_json] - DEPRECATED Utility function to parse a JSON into a dataset object. Dataset subclasses ------------------ @@ -67,6 +65,6 @@ Utility entry-point from . import utils from . import examples -from ._base import Dataset, DataSpecs, load_dataset_from_json +from ._base import Dataset, DataSpecs from ._inmemory import InMemoryDataset from ._split_data import split_data diff --git a/declearn/dataset/_base.py b/declearn/dataset/_base.py index c0f3bbe7..64cf1c83 100644 --- a/declearn/dataset/_base.py +++ b/declearn/dataset/_base.py @@ -19,11 +19,10 @@ import abc import dataclasses -import warnings from typing import Any, Iterator, List, Optional, Set, Tuple, Union from declearn.typing import Batch -from declearn.utils import access_registered, create_types_registry, json_load +from declearn.utils import create_types_registry __all__ = [ "DataSpecs", @@ -109,42 +108,3 @@ class Dataset(metaclass=abc.ABCMeta): Optional weights associated with the samples, that are typically used to balance a model's loss or metrics. """ - - -def load_dataset_from_json(path: str) -> Dataset: # pragma: no cover - """DEPRECATED Instantiate a dataset based on a JSON dump file. - - Parameters - ---------- - path: str - Path to a JSON file output by the `save_to_json` - method of the Dataset that is being reloaded. - The actual type of dataset should be specified - under the "name" field of that file. - - Returns - ------- - dataset: Dataset - Dataset (subclass) instance, reloaded from JSON. - - Raises - ------ - NotImplementedError - If the target `Dataset` does not implement a `load_from_json` - method (which was removed from the API in DecLearn 2.3.0). - """ - warnings.warn( - "'load_dataset_from_json' was deprecated in Declearn 2.4.0, after" - "'Dataset.load_from_json' was removed from the API in v2.3.0. It " - "may raise a 'NotImplementedError', and will be removed in DecLearn " - "2.6 and/or 3.0.", - category=DeprecationWarning, - stacklevel=2, - ) - dump = json_load(path) - cls = access_registered(dump["name"], group="Dataset") - if not hasattr(cls, "load_from_json"): - raise NotImplementedError( - f"Dataset class '{cls}' does not implement 'load_from_json'." - ) - return cls.load_from_json(path) diff --git a/declearn/main/_client.py b/declearn/main/_client.py index 506046e6..0b03c36a 100644 --- a/declearn/main/_client.py +++ b/declearn/main/_client.py @@ -32,7 +32,7 @@ from declearn.communication.utils import ( NetworkClientConfig, verify_server_message_validity, ) -from declearn.dataset import Dataset, load_dataset_from_json +from declearn.dataset import Dataset from declearn.fairness.api import FairnessControllerClient from declearn.main.utils import Checkpointer from declearn.messaging import Message, SerializedMessage @@ -116,16 +116,12 @@ class FederatedClient: if replace_netwk_logger: self.netwk.logger = self.logger # Assign the wrapped training dataset. - if isinstance(train_data, str): - train_data = load_dataset_from_json(train_data) if not isinstance(train_data, Dataset): - raise TypeError("'train_data' should be a Dataset or path to one.") + raise TypeError("'train_data' should be a Dataset.") self.train_data = train_data # Assign the wrapped validation dataset (if any). - if isinstance(valid_data, str): - valid_data = load_dataset_from_json(valid_data) if not (valid_data is None or isinstance(valid_data, Dataset)): - raise TypeError("'valid_data' should be a Dataset or path to one.") + raise TypeError("'valid_data' should be a Dataset.") self.valid_data = valid_data # Assign an optional checkpointer. if checkpoint is not None: diff --git a/declearn/metrics/_api.py b/declearn/metrics/_api.py index 04405777..46a53d50 100644 --- a/declearn/metrics/_api.py +++ b/declearn/metrics/_api.py @@ -18,7 +18,6 @@ """Iterative and federative evaluation metrics base class.""" import abc -import warnings from copy import deepcopy from typing import Any, ClassVar, Dict, Generic, Optional, Type, TypeVar, Union @@ -248,46 +247,6 @@ class Metric(Generic[MetricStateT], metaclass=abc.ABCMeta): ) self._states = deepcopy(states) # type: ignore - def agg_states( - self, - states: MetricStateT, - ) -> None: - """Aggregate provided state variables into self ones. - - This method is DEPRECATED as of DecLearn v2.4, in favor of - merely aggregating `MetricState` instances, using either - their `aggregate` method or the overloaded `+` operator. - It will be removed in DecLearn 2.6 and/or 3.0. - - This method is designed to aggregate results from multiple - similar metrics objects into a single one before computing - its results. - - Parameters - ---------- - states: - `MetricState` emitted by another instance of this class - via its `get_states` method. - - Raises - ------ - TypeError - If `states` is of improper type. - """ - warnings.warn( - "'Metric.agg_states' was deprecated in DecLearn v2.4, in favor " - "of aggregating 'MetricState' instances directly, and setting " - "final aggregated states using 'Metric.set_state'. It will be " - "removed in DecLearn 2.6 and/or 3.0.", - DeprecationWarning, - ) - if not isinstance(states, self.state_cls): - raise TypeError( - f"'{self.__class__.__name__}.set_states' expected " - f"'{self.state_cls}' inputs, got '{type(states)}'." - ) - self.set_states(self._states + states) - def __init_subclass__( cls, register: bool = True, diff --git a/declearn/metrics/_wrapper.py b/declearn/metrics/_wrapper.py index cbb5cee2..9e7a1709 100644 --- a/declearn/metrics/_wrapper.py +++ b/declearn/metrics/_wrapper.py @@ -17,7 +17,6 @@ """Wrapper for an ensemble of Metric objects.""" -import warnings from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import numpy as np @@ -215,49 +214,6 @@ class MetricSet: if metric.name in states: metric.set_states(states[metric.name]) - def agg_states( - self, - states: Dict[str, MetricState], - ) -> None: - """Aggregate provided state variables into self ones. - - This method is DEPRECATED as of DecLearn v2.4, in favor of - merely aggregating `MetricState` instances, using either - their `aggregate` method or the overloaded `+` operator. - It will be removed in DecLearn 2.6 and/or 3.0. - - This method is designed to aggregate results from multiple - similar metrics objects into a single one before computing - its results. - - Parameters - ---------- - states: dict[str, float or numpy.ndarray] - Dict of states emitted by another instance of this class - via its `get_states` method. - - Raises - ------ - KeyError - If any state variable is missing from `states`. - TypeError - If any state variable is of improper type. - ValueError - If any array state variable is of improper shape. - """ - warnings.warn( - "'MetricSet.agg_states' was deprecated in DecLearn v2.4, in favor " - "of aggregating 'MetricState' instances directly, and setting " - "final aggregated states using 'MetricSet.set_state'. It will be " - "removed in DecLearn 2.6 and/or 3.0.", - DeprecationWarning, - ) - with warnings.catch_warnings(): - warnings.simplefilter(action="ignore", category=DeprecationWarning) - for metric in self.metrics: - if metric.name in states: - metric.agg_states(states[metric.name]) - def get_config( self, ) -> Dict[str, Any]: diff --git a/declearn/optimizer/modules/_scaffold.py b/declearn/optimizer/modules/_scaffold.py index c5c9a4a8..277bac03 100644 --- a/declearn/optimizer/modules/_scaffold.py +++ b/declearn/optimizer/modules/_scaffold.py @@ -35,7 +35,7 @@ References import dataclasses import uuid import warnings -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Optional, Set, Tuple, Union from declearn.model.api import Vector from declearn.optimizer.modules._api import AuxVar, OptiModule @@ -386,25 +386,10 @@ class ScaffoldServerModule(OptiModule[ScaffoldAuxVar]): def __init__( self, - clients: Optional[List[str]] = None, ) -> None: - """Instantiate the server-side SCAFFOLD gradients-correction module. - - Parameters - ---------- - clients: - DEPRECATED and unused starting with declearn 2.4. - Optional list of known clients' id strings. - """ + """Instantiate the server-side SCAFFOLD gradients-correction module.""" self.s_state = 0.0 # type: Union[Vector, float] self.clients = set() # type: Set[str] - if clients: # pragma: no cover - warnings.warn( - "ScaffoldServerModule's 'clients' argument has been deprecated" - " as of declearn v2.4, and no longer has any effect. It will" - " be removed in declearn 2.6 and/or 3.0.", - DeprecationWarning, - ) def run( self, diff --git a/docs/release-notes/v2.6.0.md b/docs/release-notes/v2.6.0.md index 1c1da1aa..4d1bf837 100644 --- a/docs/release-notes/v2.6.0.md +++ b/docs/release-notes/v2.6.0.md @@ -131,6 +131,32 @@ fairness levels of the last model. ## Other changes +### Removal of deprecated features + +A number of features were deprecated in DecLearn 2.4.0 (whether legacy API +methods, submdules or methods that were re-organized or renamed, parameters +that were no longer used or a plainly-removed function). As of this new +release, those features that had been kept back with a deprecation warning +are now removed from the code. + +As a remainder, the removed features include: + +* Legacy aggregation methods: + - `declearn.aggregator.Aggregator.aggregate` + - `declearn.metrics.Metric.agg_states` + - `declearn.metrics.MetricSet.agg_states` +* Legacy instantiation parameters: + - `declearn.aggregator.AveragingAggregator` parameter `client_weights` + - `declearn.aggregator.GradientMaskedAveraging` parameter `client_weights` + - `declearn.optimizer.modules.ScaffoldServerModule` parameter `clients` +* Legacy names that were aliasing new locations: + - `declearn.communication.messaging` (moved to `declearn.messaging`) + - `declearn.communication.NetworkClient.check_message` (renamed + `recv_message`) +* `declearn.dataset.load_dataset_from_json` + +### New developer-oriented changes + A few minor changes are shipped with this new release, that are mostly of interest to developers - including end-users writing custom algorithms or bridging DecLearn APIs within their own orchestration code. diff --git a/test/aggregator/test_aggregator.py b/test/aggregator/test_aggregator.py index e619bcce..8d1f74b4 100644 --- a/test/aggregator/test_aggregator.py +++ b/test/aggregator/test_aggregator.py @@ -143,31 +143,3 @@ class TestAggregator: result = aggregator.finalize_updates(output) expect = aggregator.finalize_updates(updates_a + updates_b) assert result == expect - - # DEPRECATED: the following tests cover deprecated methods - - @pytest.mark.parametrize("framework", VECTOR_FRAMEWORKS) - def test_aggregate( - self, - agg_cls: Type[Aggregator], - updates: Dict[str, Vector], - ) -> None: - """Test that the legacy (deprecated) 'aggregate' method still works.""" - agg = agg_cls() - n_steps = {key: 10 for key in updates} - with pytest.warns(DeprecationWarning): - outputs = agg.aggregate(updates, n_steps) - ref_vec = list(updates.values())[0] - assert isinstance(outputs, type(ref_vec)) - assert outputs.shapes() == ref_vec.shapes() - assert outputs.dtypes() == ref_vec.dtypes() - - def test_aggregate_empty( - self, - agg_cls: Type[Aggregator], - ) -> None: - """Test that 'aggregate' raises the expected error on empty inputs.""" - agg = agg_cls() - with pytest.warns(DeprecationWarning): - with pytest.raises(TypeError): - agg.aggregate(updates={}, n_steps={}) diff --git a/test/main/test_main_client.py b/test/main/test_main_client.py index e002b9dc..345c1a20 100644 --- a/test/main/test_main_client.py +++ b/test/main/test_main_client.py @@ -139,17 +139,6 @@ class TestFederatedClientInit: # pylint: disable=too-many-public-methods client = FederatedClient(netwk=MOCK_NETWK, train_data=dataset) assert client.train_data is dataset - def test_train_data_str(self) -> None: - """Test specifying 'train_data' as a file path.""" - path = "mock_path_to_dataset.json" - with mock.patch( - "declearn.main._client.load_dataset_from_json", - return_value=mock.create_autospec(Dataset, instance=True), - ) as patched: - client = FederatedClient(netwk=MOCK_NETWK, train_data=path) - patched.assert_called_once_with(path) - assert client.train_data is patched.return_value - def test_train_data_invalid(self) -> None: """Test specifying 'train_data' as an invalid type.""" with pytest.raises(TypeError): @@ -172,19 +161,6 @@ class TestFederatedClientInit: # pylint: disable=too-many-public-methods ) assert client.valid_data is dataset - def test_valid_data_str(self) -> None: - """Test specifying 'valid_data' as a file path.""" - path = "mock_path_to_dataset.json" - with mock.patch( - "declearn.main._client.load_dataset_from_json", - return_value=mock.create_autospec(Dataset, instance=True), - ) as patched: - client = FederatedClient( - netwk=MOCK_NETWK, train_data=MOCK_DATASET, valid_data=path - ) - patched.assert_called_once_with(path) - assert client.valid_data is patched.return_value - def test_valid_data_invalid(self) -> None: """Test specifying 'valid_data' as an invalid type.""" with pytest.raises(TypeError): diff --git a/test/metrics/metric_testing.py b/test/metrics/metric_testing.py index 2adb2fd9..043c6341 100644 --- a/test/metrics/metric_testing.py +++ b/test/metrics/metric_testing.py @@ -177,32 +177,6 @@ class MetricTestSuite: expect = test_case.agg_scores assert_dict_equal(metric.get_result(), expect, np_tolerance=self.tol) - def test_legacy_agg_states(self, test_case: MetricTestCase) -> None: - """Test that the deprecated `agg_states` method works as expected.""" - # Set up and update two identical metrics. - metric = test_case.metric - metbis = deepcopy(test_case.metric) - metric.update(**test_case.inputs) - metbis.update(**test_case.inputs) - # Aggregate the second into the first. Verify that they now differ. - assert_dict_equal( - metric.get_states().to_dict(), metbis.get_states().to_dict() - ) - with pytest.warns(DeprecationWarning): - metbis.agg_states(metric.get_states()) - assert_dict_equal( - metric.get_states().to_dict(), test_case.states.to_dict() - ) - with pytest.raises(AssertionError): # assert not equal - assert_dict_equal( - metric.get_states().to_dict(), metbis.get_states().to_dict() - ) - # Verify the correctness of the aggregated states and scores. - states = test_case.agg_states - scores = test_case.agg_scores - assert_dict_equal(metbis.get_states().to_dict(), states.to_dict()) - assert_dict_equal(metbis.get_result(), scores, np_tolerance=self.tol) - def test_update_with_squeezable_inputs( self, test_case: MetricTestCase ) -> None: diff --git a/test/metrics/test_metricset.py b/test/metrics/test_metricset.py index 1ff2eae2..8eb7b445 100644 --- a/test/metrics/test_metricset.py +++ b/test/metrics/test_metricset.py @@ -18,7 +18,7 @@ """Unit tests for `declearn.metrics.MetricSet`.""" from unittest import mock -from typing import Dict, Tuple +from typing import Tuple import numpy as np import pytest @@ -129,18 +129,6 @@ class TestMetricSet: states["mse"] ) - def test_agg_states(self) -> None: - """Test that deprecated `MetricSet.agg_states` works as expected.""" - mae, mse, metrics = get_mock_metricset() - states = { - "mae": mae.get_states(), - "mse": mse.get_states(), - } # type: Dict[str, MetricState] - with pytest.warns(DeprecationWarning): - metrics.agg_states(states) - mae.agg_states.assert_called_once_with(states["mae"]) # type: ignore - mse.agg_states.assert_called_once_with(states["mse"]) # type: ignore - def test_get_config(self) -> None: """Test that `MetricSet.get_config` works as expected.""" mae = MeanAbsoluteError() -- GitLab