Mentions légales du service

Skip to content
Snippets Groups Projects
Verified Commit 8846f56b authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Remove deprecated features planned for v2.6 removal.

parent 391c2e48
No related branches found
No related tags found
1 merge request!70Finalize version 2.6.0
Showing
with 36 additions and 643 deletions
......@@ -19,7 +19,6 @@
import abc
import dataclasses
import warnings
from typing import Any, ClassVar, Dict, Generic, Type, TypeVar, Union
from typing_extensions import Self # future: import from typing (py >=3.11)
......@@ -185,48 +184,6 @@ class Aggregator(Generic[ModelUpdatesT], metaclass=abc.ABCMeta):
"""Instantiate an Aggregator from its configuration dict."""
return cls(**config)
def aggregate(
self,
updates: Dict[str, Vector[T]],
n_steps: Dict[str, int], # revise: abstract~generalize kwargs use
) -> Vector[T]:
"""DEPRECATED - Aggregate input vectors into a single one.
Parameters
----------
updates: dict[str, Vector]
Client-wise updates, as a dictionary with clients' names as
string keys and updates as Vector values.
n_steps: dict[str, int]
Client-wise number of local training steps performed during
the training round having produced the updates.
Returns
-------
gradients: Vector
Aggregated updates, as a Vector - treated as gradients by
the server-side optimizer.
Raises
------
TypeError
If the input `updates` are an empty dict.
"""
warnings.warn(
"'Aggregator.aggregate' was deprecated in DecLearn v2.4 in favor "
"of new API methods. It will be removed in DecLearn v2.6 and/or "
"v3.0.",
DeprecationWarning,
)
if not updates:
raise TypeError("'Aggregator.aggregate' received an empty dict.")
partials = [
self.prepare_for_sharing(updates[client], n_steps[client])
for client in updates
]
aggregated = sum(partials[1:], start=partials[0])
return self.finalize_updates(aggregated)
def list_aggregators() -> Dict[str, Type[Aggregator]]:
"""Return a mapping of registered Aggregator subclasses.
......
......@@ -17,8 +17,7 @@
"""FedAvg-like mean-aggregation class."""
import warnings
from typing import Any, Dict, Optional
from typing import Any, Dict
from declearn.aggregator._api import Aggregator, ModelUpdates
......@@ -44,7 +43,6 @@ class AveragingAggregator(Aggregator[ModelUpdates]):
def __init__(
self,
steps_weighted: bool = True,
client_weights: Optional[Dict[str, float]] = None,
) -> None:
"""Instantiate an averaging aggregator.
......@@ -53,37 +51,14 @@ class AveragingAggregator(Aggregator[ModelUpdates]):
steps_weighted:
Whether to conduct a weighted averaging of local model
updates based on local numbers of training steps.
client_weights:
DEPRECATED - this argument no longer affects computations,
save when using the deprecated 'aggregate' method.
Optional dict of client-wise base weights to use.
If None, homogeneous base weights are used.
Notes
-----
* One may specify `client_weights` and use `steps_weighted=True`.
In that case, the product of the client's base weight and their
number of training steps taken will be used (and unit-normed).
* One may use incomplete `client_weights`. In that case, unknown-
clients' base weights will be set to 1.
"""
self.steps_weighted = steps_weighted
self.client_weights = client_weights or {}
if client_weights: # pragma: no cover
warnings.warn(
f"'client_weights' argument to '{self.__class__.__name__}' was"
" deprecated in DecLearn v2.4 and is no longer used, saved by"
" the deprecated 'aggregate' method. It will be removed in"
" DecLearn v2.6 and/or v3.0.",
DeprecationWarning,
)
def get_config(
self,
) -> Dict[str, Any]:
return {
"steps_weighted": self.steps_weighted,
"client_weights": self.client_weights,
}
def prepare_for_sharing(
......@@ -103,63 +78,3 @@ class AveragingAggregator(Aggregator[ModelUpdates]):
updates: ModelUpdates,
) -> Vector:
return updates.updates / updates.weights
def aggregate(
self,
updates: Dict[str, Vector],
n_steps: Dict[str, int],
) -> Vector:
# Make use of 'client_weights' as part of this DEPRECATED method.
with warnings.catch_warnings():
warnings.simplefilter(action="ignore", category=DeprecationWarning)
weights = self.compute_client_weights(updates, n_steps)
steps_weighted = self.steps_weighted
try:
self.steps_weighted = True
return super().aggregate(updates, weights) # type: ignore
finally:
self.steps_weighted = steps_weighted
def compute_client_weights(
self,
updates: Dict[str, Vector],
n_steps: Dict[str, int],
) -> Dict[str, float]:
"""Compute weights to use when averaging a given set of updates.
This method is DEPRECATED as of DecLearn v2.4.
It will be removed in DecLearn 2.6 and/or 3.0.
Parameters
----------
updates: dict[str, Vector]
Client-wise updates, as a dictionary with clients' names as
string keys and updates as Vector values.
n_steps: dict[str, int]
Client-wise number of local training steps performed during
the training round having produced the updates.
Returns
-------
weights: dict[str, float]
Client-wise updates-averaging weights, suited to the input
parameters and normalized so that they sum to 1.
"""
warnings.warn(
f"'{self.__class__.__name__}.compute_client_weights' was"
" deprecated in DecLearn v2.4. It will be removed in DecLearn"
" v2.6 and/or v3.0.",
DeprecationWarning,
)
if self.steps_weighted:
weights = {
client: steps * self.client_weights.get(client, 1.0)
for client, steps in n_steps.items()
}
else:
weights = {
client: self.client_weights.get(client, 1.0)
for client in updates
}
total = sum(weights.values())
return {client: weight / total for client, weight in weights.items()}
......@@ -18,7 +18,6 @@
"""Gradient Masked Averaging aggregation class."""
import dataclasses
import warnings
from typing import Any, Dict, Optional, Tuple
from typing_extensions import Self # future: import from typing (py >=3.11)
......@@ -99,7 +98,6 @@ class GradientMaskedAveraging(Aggregator[GMAModelUpdates]):
self,
threshold: float = 1.0,
steps_weighted: bool = True,
client_weights: Optional[Dict[str, float]] = None,
) -> None:
"""Instantiate a gradient masked averaging aggregator.
......@@ -111,20 +109,9 @@ class GradientMaskedAveraging(Aggregator[GMAModelUpdates]):
steps_weighted: bool, default=True
Whether to weight updates based on the number of optimization
steps taken by the clients (relative to one another).
client_weights: dict[str, float] or None, default=None
Optional dict of client-wise base weights to use.
If None, homogeneous base weights are used.
Notes
-----
* One may specify `client_weights` and use `steps_weighted=True`.
In that case, the product of the client's base weight and their
number of training steps taken will be used (and unit-normed).
* One may use incomplete `client_weights`. In that case, unknown-
clients' base weights will be set to 1.
"""
self.threshold = threshold
self._avg = AveragingAggregator(steps_weighted, client_weights)
self._avg = AveragingAggregator(steps_weighted)
def get_config(
self,
......@@ -162,24 +149,3 @@ class GradientMaskedAveraging(Aggregator[GMAModelUpdates]):
scores = (1 - clip) * scores + clip # s = 1 if s > t else s
# Correct outputs' magnitude and return them.
return values * scores
def compute_client_weights( # pragma: no cover
self,
updates: Dict[str, Vector],
n_steps: Dict[str, int],
) -> Dict[str, float]:
"""Compute weights to use when averaging a given set of updates.
This method is DEPRECATED as of DecLearn v2.4.
It will be removed in DecLearn 2.6 and/or 3.0.
"""
# pylint: disable=duplicate-code
warnings.warn(
f"'{self.__class__.__name__}.compute_client_weights' was"
" deprecated in DecLearn v2.4. It will be removed in DecLearn"
" v2.6 and/or v3.0.",
DeprecationWarning,
)
with warnings.catch_warnings():
warnings.simplefilter(action="ignore", category=DeprecationWarning)
return self._avg.compute_client_weights(updates, n_steps)
......@@ -28,7 +28,6 @@ This module contains the following core submodules:
* [utils][declearn.communication.utils]:
Utils related to network communication endpoints' setup and usage.
It re-exports publicly from `utils` the following elements:
* [build_client][declearn.communication.build_client]:
......@@ -52,10 +51,6 @@ the associated third-party dependencies are available:
* [websockets][declearn.communication.websockets]:
WebSockets-based network communication endpoints.
Requires the `websockets` third-party package.
Additionnally, for retro-compatibility purposes, it exports the DEPRECATED
[messaging][declearn.communication.messaging] submodule, that should no
longer be used, as its contents were re-dispatched elsewhere in DecLearn.
"""
# Messaging API and base tools:
......@@ -79,6 +74,3 @@ try:
from . import websockets
except ImportError: # pragma: no cover
_INSTALLABLE_BACKENDS["websockets"] = ("websockets",)
# DEPRECATED submodule, kept for retro-compatibility until 2.6 and/or 3.0.
from . import messaging
......@@ -339,20 +339,3 @@ class NetworkClient(metaclass=abc.ABCMeta):
)
self.logger.critical(error)
raise TypeError(error)
async def check_message(
self,
timeout: Optional[float] = None,
) -> SerializedMessage:
"""Await a message from the server, with optional timeout.
This method is DEPRECATED in favor of the `recv_message` one.
It acts as an alias and will be removed in v2.6 and/or v3.0.
"""
warnings.warn(
"'NetworkServer.check_message' was renamed as 'recv_message' "
"in DecLearn 2.4. It now acts as an alias, but will be removed "
"in version 2.6 and/or 3.0.",
DeprecationWarning,
)
return await self.recv_message(timeout)
# coding: utf-8
# Copyright 2023 Inria (Institut National de Recherche en Informatique
# et Automatique)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DEPRECATED submodule defining messaging containers and flags.
This submodule was deprecated in DecLearn 2.4 in favor of `declearn.messaging`.
It should no longer be used, and will be removed in 2.6 and/or 3.0.
Most of its contents are re-exports of non-deprecated classes and functions
from 'declearn.messaging'. Others will trigger deprecation warnings (and may
cause failures) if used.
Deprecated classes uniquely-defined here are:
* [Empty][declearn.communication.messaging.Empty]
* [GetMessageRequest][declearn.communication.messaging.GetMessageRequest]
* [JoinReply][declearn.communication.messaging.JoinReply]
* [JoinRequest][declearn.communication.messaging.JoinRequest]
The `flags` submodule is also re-exported, but should preferably be imported
as `declearn.communication.api.backend.flags`.
"""
from declearn.communication.api.backend import flags
from ._messages import (
CancelTraining,
Empty,
Error,
GenericMessage,
GetMessageRequest,
EvaluationReply,
EvaluationRequest,
InitRequest,
JoinReply,
JoinRequest,
Message,
PrivacyRequest,
StopTraining,
TrainReply,
TrainRequest,
parse_message_from_string,
)
# coding: utf-8
# Copyright 2023 Inria (Institut National de Recherche en Informatique
# et Automatique)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataclasses defining messages used in declearn communications."""
import abc
import dataclasses
import warnings
from typing import Any, Dict, Optional
from declearn.messaging import (
CancelTraining,
Error,
EvaluationReply,
EvaluationRequest,
GenericMessage,
InitRequest,
Message,
PrivacyRequest,
SerializedMessage,
StopTraining,
TrainReply,
TrainRequest,
)
__all__ = [
"CancelTraining",
"Empty",
"Error",
"EvaluationReply",
"EvaluationRequest",
"GenericMessage",
"GetMessageRequest",
"InitRequest",
"JoinReply",
"JoinRequest",
"Message",
"PrivacyRequest",
"StopTraining",
"TrainReply",
"TrainRequest",
"parse_message_from_string",
]
@dataclasses.dataclass
class DeprecatedMessage(Message, register=False, metaclass=abc.ABCMeta):
"""DEPRECATED Message subtype."""
def __post_init__(
self,
) -> None:
warnings.warn(
f"'{self.__class__.__name__}' was deprecated in DecLearn v2.4. "
"It should no longer be used and may cause failures. It will be "
"removed in DecLearn v2.6 and/or v3.0",
DeprecationWarning,
)
@dataclasses.dataclass
class Empty(DeprecatedMessage):
"""DEPRECATED empty message class."""
typekey = "empty"
@dataclasses.dataclass
class GetMessageRequest(DeprecatedMessage):
"""DEPRECATED message-retrieval query message class."""
typekey = "get_message"
timeout: Optional[int] = None
@dataclasses.dataclass
class JoinRequest(DeprecatedMessage):
"""DEPRECATED process joining query message class."""
typekey = "join_request"
name: str
data_info: Dict[str, Any]
version: Optional[str] = None
@dataclasses.dataclass
class JoinReply(DeprecatedMessage):
"""DEPRECATED process joining reply message class."""
typekey = "join_reply"
accept: bool
flag: str
def parse_message_from_string(
string: str,
) -> Message:
"""DEPRECATED - Instantiate a Message from its serialized string.
This function was DEPRECATED in DecLearn 2.4 and will be removed
in v2.6 and/or v3.0. Use the `declearn.messaging.SerializedMessage`
API to parse serialized message strings.
Parameters
----------
string:
Serialized string dump of the message.
Returns
-------
message:
Message instance recovered from the input string.
Raises
------
KeyError
If the string's typekey does not match any supported Message
subclass.
TypeError
If the string cannot be parsed to identify a message typekey.
ValueError
If the serialized data fails to be properly decoded.
"""
warnings.warn(
"'parse_message_from_string' was deprecated in DecLearn 2.4, in "
"favor of using 'declearn.messaging.SerializedMessage' to parse "
"and deserialize 'Message' instances from strings. It will be "
"removed in DecLearn version 2.6 and/or 3.0.",
DeprecationWarning,
)
serialized = SerializedMessage.from_message_string(
string
) # type: SerializedMessage[Any]
return serialized.deserialize()
......@@ -29,8 +29,6 @@ API tools
Abstract base class defining an API to access training or testing data.
* [DataSpec][declearn.dataset.DataSpecs]:
Dataclass to wrap a dataset's metadata.
* [load_dataset_from_json][declearn.dataset.load_dataset_from_json]
DEPRECATED Utility function to parse a JSON into a dataset object.
Dataset subclasses
------------------
......@@ -67,6 +65,6 @@ Utility entry-point
from . import utils
from . import examples
from ._base import Dataset, DataSpecs, load_dataset_from_json
from ._base import Dataset, DataSpecs
from ._inmemory import InMemoryDataset
from ._split_data import split_data
......@@ -19,11 +19,10 @@
import abc
import dataclasses
import warnings
from typing import Any, Iterator, List, Optional, Set, Tuple, Union
from declearn.typing import Batch
from declearn.utils import access_registered, create_types_registry, json_load
from declearn.utils import create_types_registry
__all__ = [
"DataSpecs",
......@@ -109,42 +108,3 @@ class Dataset(metaclass=abc.ABCMeta):
Optional weights associated with the samples, that are
typically used to balance a model's loss or metrics.
"""
def load_dataset_from_json(path: str) -> Dataset: # pragma: no cover
"""DEPRECATED Instantiate a dataset based on a JSON dump file.
Parameters
----------
path: str
Path to a JSON file output by the `save_to_json`
method of the Dataset that is being reloaded.
The actual type of dataset should be specified
under the "name" field of that file.
Returns
-------
dataset: Dataset
Dataset (subclass) instance, reloaded from JSON.
Raises
------
NotImplementedError
If the target `Dataset` does not implement a `load_from_json`
method (which was removed from the API in DecLearn 2.3.0).
"""
warnings.warn(
"'load_dataset_from_json' was deprecated in Declearn 2.4.0, after"
"'Dataset.load_from_json' was removed from the API in v2.3.0. It "
"may raise a 'NotImplementedError', and will be removed in DecLearn "
"2.6 and/or 3.0.",
category=DeprecationWarning,
stacklevel=2,
)
dump = json_load(path)
cls = access_registered(dump["name"], group="Dataset")
if not hasattr(cls, "load_from_json"):
raise NotImplementedError(
f"Dataset class '{cls}' does not implement 'load_from_json'."
)
return cls.load_from_json(path)
......@@ -32,7 +32,7 @@ from declearn.communication.utils import (
NetworkClientConfig,
verify_server_message_validity,
)
from declearn.dataset import Dataset, load_dataset_from_json
from declearn.dataset import Dataset
from declearn.fairness.api import FairnessControllerClient
from declearn.main.utils import Checkpointer
from declearn.messaging import Message, SerializedMessage
......@@ -116,16 +116,12 @@ class FederatedClient:
if replace_netwk_logger:
self.netwk.logger = self.logger
# Assign the wrapped training dataset.
if isinstance(train_data, str):
train_data = load_dataset_from_json(train_data)
if not isinstance(train_data, Dataset):
raise TypeError("'train_data' should be a Dataset or path to one.")
raise TypeError("'train_data' should be a Dataset.")
self.train_data = train_data
# Assign the wrapped validation dataset (if any).
if isinstance(valid_data, str):
valid_data = load_dataset_from_json(valid_data)
if not (valid_data is None or isinstance(valid_data, Dataset)):
raise TypeError("'valid_data' should be a Dataset or path to one.")
raise TypeError("'valid_data' should be a Dataset.")
self.valid_data = valid_data
# Assign an optional checkpointer.
if checkpoint is not None:
......
......@@ -18,7 +18,6 @@
"""Iterative and federative evaluation metrics base class."""
import abc
import warnings
from copy import deepcopy
from typing import Any, ClassVar, Dict, Generic, Optional, Type, TypeVar, Union
......@@ -248,46 +247,6 @@ class Metric(Generic[MetricStateT], metaclass=abc.ABCMeta):
)
self._states = deepcopy(states) # type: ignore
def agg_states(
self,
states: MetricStateT,
) -> None:
"""Aggregate provided state variables into self ones.
This method is DEPRECATED as of DecLearn v2.4, in favor of
merely aggregating `MetricState` instances, using either
their `aggregate` method or the overloaded `+` operator.
It will be removed in DecLearn 2.6 and/or 3.0.
This method is designed to aggregate results from multiple
similar metrics objects into a single one before computing
its results.
Parameters
----------
states:
`MetricState` emitted by another instance of this class
via its `get_states` method.
Raises
------
TypeError
If `states` is of improper type.
"""
warnings.warn(
"'Metric.agg_states' was deprecated in DecLearn v2.4, in favor "
"of aggregating 'MetricState' instances directly, and setting "
"final aggregated states using 'Metric.set_state'. It will be "
"removed in DecLearn 2.6 and/or 3.0.",
DeprecationWarning,
)
if not isinstance(states, self.state_cls):
raise TypeError(
f"'{self.__class__.__name__}.set_states' expected "
f"'{self.state_cls}' inputs, got '{type(states)}'."
)
self.set_states(self._states + states)
def __init_subclass__(
cls,
register: bool = True,
......
......@@ -17,7 +17,6 @@
"""Wrapper for an ensemble of Metric objects."""
import warnings
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
......@@ -215,49 +214,6 @@ class MetricSet:
if metric.name in states:
metric.set_states(states[metric.name])
def agg_states(
self,
states: Dict[str, MetricState],
) -> None:
"""Aggregate provided state variables into self ones.
This method is DEPRECATED as of DecLearn v2.4, in favor of
merely aggregating `MetricState` instances, using either
their `aggregate` method or the overloaded `+` operator.
It will be removed in DecLearn 2.6 and/or 3.0.
This method is designed to aggregate results from multiple
similar metrics objects into a single one before computing
its results.
Parameters
----------
states: dict[str, float or numpy.ndarray]
Dict of states emitted by another instance of this class
via its `get_states` method.
Raises
------
KeyError
If any state variable is missing from `states`.
TypeError
If any state variable is of improper type.
ValueError
If any array state variable is of improper shape.
"""
warnings.warn(
"'MetricSet.agg_states' was deprecated in DecLearn v2.4, in favor "
"of aggregating 'MetricState' instances directly, and setting "
"final aggregated states using 'MetricSet.set_state'. It will be "
"removed in DecLearn 2.6 and/or 3.0.",
DeprecationWarning,
)
with warnings.catch_warnings():
warnings.simplefilter(action="ignore", category=DeprecationWarning)
for metric in self.metrics:
if metric.name in states:
metric.agg_states(states[metric.name])
def get_config(
self,
) -> Dict[str, Any]:
......
......@@ -35,7 +35,7 @@ References
import dataclasses
import uuid
import warnings
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Dict, Optional, Set, Tuple, Union
from declearn.model.api import Vector
from declearn.optimizer.modules._api import AuxVar, OptiModule
......@@ -386,25 +386,10 @@ class ScaffoldServerModule(OptiModule[ScaffoldAuxVar]):
def __init__(
self,
clients: Optional[List[str]] = None,
) -> None:
"""Instantiate the server-side SCAFFOLD gradients-correction module.
Parameters
----------
clients:
DEPRECATED and unused starting with declearn 2.4.
Optional list of known clients' id strings.
"""
"""Instantiate the server-side SCAFFOLD gradients-correction module."""
self.s_state = 0.0 # type: Union[Vector, float]
self.clients = set() # type: Set[str]
if clients: # pragma: no cover
warnings.warn(
"ScaffoldServerModule's 'clients' argument has been deprecated"
" as of declearn v2.4, and no longer has any effect. It will"
" be removed in declearn 2.6 and/or 3.0.",
DeprecationWarning,
)
def run(
self,
......
......@@ -131,6 +131,32 @@ fairness levels of the last model.
## Other changes
### Removal of deprecated features
A number of features were deprecated in DecLearn 2.4.0 (whether legacy API
methods, submdules or methods that were re-organized or renamed, parameters
that were no longer used or a plainly-removed function). As of this new
release, those features that had been kept back with a deprecation warning
are now removed from the code.
As a remainder, the removed features include:
* Legacy aggregation methods:
- `declearn.aggregator.Aggregator.aggregate`
- `declearn.metrics.Metric.agg_states`
- `declearn.metrics.MetricSet.agg_states`
* Legacy instantiation parameters:
- `declearn.aggregator.AveragingAggregator` parameter `client_weights`
- `declearn.aggregator.GradientMaskedAveraging` parameter `client_weights`
- `declearn.optimizer.modules.ScaffoldServerModule` parameter `clients`
* Legacy names that were aliasing new locations:
- `declearn.communication.messaging` (moved to `declearn.messaging`)
- `declearn.communication.NetworkClient.check_message` (renamed
`recv_message`)
* `declearn.dataset.load_dataset_from_json`
### New developer-oriented changes
A few minor changes are shipped with this new release, that are mostly of
interest to developers - including end-users writing custom algorithms or
bridging DecLearn APIs within their own orchestration code.
......
......@@ -143,31 +143,3 @@ class TestAggregator:
result = aggregator.finalize_updates(output)
expect = aggregator.finalize_updates(updates_a + updates_b)
assert result == expect
# DEPRECATED: the following tests cover deprecated methods
@pytest.mark.parametrize("framework", VECTOR_FRAMEWORKS)
def test_aggregate(
self,
agg_cls: Type[Aggregator],
updates: Dict[str, Vector],
) -> None:
"""Test that the legacy (deprecated) 'aggregate' method still works."""
agg = agg_cls()
n_steps = {key: 10 for key in updates}
with pytest.warns(DeprecationWarning):
outputs = agg.aggregate(updates, n_steps)
ref_vec = list(updates.values())[0]
assert isinstance(outputs, type(ref_vec))
assert outputs.shapes() == ref_vec.shapes()
assert outputs.dtypes() == ref_vec.dtypes()
def test_aggregate_empty(
self,
agg_cls: Type[Aggregator],
) -> None:
"""Test that 'aggregate' raises the expected error on empty inputs."""
agg = agg_cls()
with pytest.warns(DeprecationWarning):
with pytest.raises(TypeError):
agg.aggregate(updates={}, n_steps={})
......@@ -139,17 +139,6 @@ class TestFederatedClientInit: # pylint: disable=too-many-public-methods
client = FederatedClient(netwk=MOCK_NETWK, train_data=dataset)
assert client.train_data is dataset
def test_train_data_str(self) -> None:
"""Test specifying 'train_data' as a file path."""
path = "mock_path_to_dataset.json"
with mock.patch(
"declearn.main._client.load_dataset_from_json",
return_value=mock.create_autospec(Dataset, instance=True),
) as patched:
client = FederatedClient(netwk=MOCK_NETWK, train_data=path)
patched.assert_called_once_with(path)
assert client.train_data is patched.return_value
def test_train_data_invalid(self) -> None:
"""Test specifying 'train_data' as an invalid type."""
with pytest.raises(TypeError):
......@@ -172,19 +161,6 @@ class TestFederatedClientInit: # pylint: disable=too-many-public-methods
)
assert client.valid_data is dataset
def test_valid_data_str(self) -> None:
"""Test specifying 'valid_data' as a file path."""
path = "mock_path_to_dataset.json"
with mock.patch(
"declearn.main._client.load_dataset_from_json",
return_value=mock.create_autospec(Dataset, instance=True),
) as patched:
client = FederatedClient(
netwk=MOCK_NETWK, train_data=MOCK_DATASET, valid_data=path
)
patched.assert_called_once_with(path)
assert client.valid_data is patched.return_value
def test_valid_data_invalid(self) -> None:
"""Test specifying 'valid_data' as an invalid type."""
with pytest.raises(TypeError):
......
......@@ -177,32 +177,6 @@ class MetricTestSuite:
expect = test_case.agg_scores
assert_dict_equal(metric.get_result(), expect, np_tolerance=self.tol)
def test_legacy_agg_states(self, test_case: MetricTestCase) -> None:
"""Test that the deprecated `agg_states` method works as expected."""
# Set up and update two identical metrics.
metric = test_case.metric
metbis = deepcopy(test_case.metric)
metric.update(**test_case.inputs)
metbis.update(**test_case.inputs)
# Aggregate the second into the first. Verify that they now differ.
assert_dict_equal(
metric.get_states().to_dict(), metbis.get_states().to_dict()
)
with pytest.warns(DeprecationWarning):
metbis.agg_states(metric.get_states())
assert_dict_equal(
metric.get_states().to_dict(), test_case.states.to_dict()
)
with pytest.raises(AssertionError): # assert not equal
assert_dict_equal(
metric.get_states().to_dict(), metbis.get_states().to_dict()
)
# Verify the correctness of the aggregated states and scores.
states = test_case.agg_states
scores = test_case.agg_scores
assert_dict_equal(metbis.get_states().to_dict(), states.to_dict())
assert_dict_equal(metbis.get_result(), scores, np_tolerance=self.tol)
def test_update_with_squeezable_inputs(
self, test_case: MetricTestCase
) -> None:
......
......@@ -18,7 +18,7 @@
"""Unit tests for `declearn.metrics.MetricSet`."""
from unittest import mock
from typing import Dict, Tuple
from typing import Tuple
import numpy as np
import pytest
......@@ -129,18 +129,6 @@ class TestMetricSet:
states["mse"]
)
def test_agg_states(self) -> None:
"""Test that deprecated `MetricSet.agg_states` works as expected."""
mae, mse, metrics = get_mock_metricset()
states = {
"mae": mae.get_states(),
"mse": mse.get_states(),
} # type: Dict[str, MetricState]
with pytest.warns(DeprecationWarning):
metrics.agg_states(states)
mae.agg_states.assert_called_once_with(states["mae"]) # type: ignore
mse.agg_states.assert_called_once_with(states["mse"]) # type: ignore
def test_get_config(self) -> None:
"""Test that `MetricSet.get_config` works as expected."""
mae = MeanAbsoluteError()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment