diff --git a/declearn/aggregator/_api.py b/declearn/aggregator/_api.py index 778b1d3e0c92a7e0c0baedea85e5b35ae6ace0a7..5f089b4180f86802cfda57d6f50d636622f19f6a 100644 --- a/declearn/aggregator/_api.py +++ b/declearn/aggregator/_api.py @@ -127,7 +127,7 @@ class Aggregator(metaclass=ABCMeta): self, ) -> Dict[str, Any]: """Return a JSON-serializable dict with this object's parameters.""" - return {} + return {} # pragma: no cover @classmethod def from_config( diff --git a/declearn/communication/__init__.py b/declearn/communication/__init__.py index e7f685f248bfc22bbc35300a0ef3f39d66660dba..b846ff0b4aa4c42bc35178685df27a701d2b1bee 100644 --- a/declearn/communication/__init__.py +++ b/declearn/communication/__init__.py @@ -69,11 +69,9 @@ from ._build import ( # Concrete implementations using various protocols: try: from . import grpc -except ImportError: - # pragma: no cover +except ImportError: # pragma: no cover _INSTALLABLE_BACKENDS["grpc"] = ("grpcio", "protobuf") try: from . import websockets -except ImportError: - # pragma: no cover +except ImportError: # pragma: no cover _INSTALLABLE_BACKENDS["websockets"] = ("websockets",) diff --git a/declearn/communication/_build.py b/declearn/communication/_build.py index 1552eba942e197f45ce9be44314c6e7d45517528..854af5d38f7bfbd5df9bc4bcd140dbc6f82ab366 100644 --- a/declearn/communication/_build.py +++ b/declearn/communication/_build.py @@ -48,7 +48,7 @@ def raise_if_installable( exc: Optional[Exception] = None, ) -> None: """Raise a RuntimeError if a given protocol is missing but installable.""" - if protocol in _INSTALLABLE_BACKENDS: + if protocol in _INSTALLABLE_BACKENDS: # pragma: no cover raise RuntimeError( f"The '{protocol}' communication protocol network endpoints " "could not be imported, but could be installed by satisfying " @@ -95,7 +95,7 @@ def build_client( protocol = protocol.strip().lower() try: cls = access_registered(name=protocol, group="NetworkClient") - except KeyError as exc: + except KeyError as exc: # pragma: no cover raise_if_installable(protocol, exc) raise KeyError( "Failed to retrieve NetworkClient " @@ -153,7 +153,7 @@ def build_server( protocol = protocol.strip().lower() try: cls = access_registered(name=protocol, group="NetworkServer") - except KeyError as exc: + except KeyError as exc: # pragma: no cover raise_if_installable(protocol, exc) raise KeyError( "Failed to retrieve NetworkServer " diff --git a/declearn/communication/websockets/_client.py b/declearn/communication/websockets/_client.py index 6eeca4fb91f2afe55ed9247a2c5821fc28db0a7a..ea21cd5a943fa01a7c5f07e33365d69212cacae7 100644 --- a/declearn/communication/websockets/_client.py +++ b/declearn/communication/websockets/_client.py @@ -33,8 +33,9 @@ from declearn.communication.websockets._tools import ( send_websockets_message, ) - -CHUNK_LENGTH = 100000 +__all__ = [ + "WebsocketsClient", +] class WebsocketsClient(NetworkClient): diff --git a/declearn/communication/websockets/_server.py b/declearn/communication/websockets/_server.py index e2c618df71220d1a991d03f2a5e7a3f4002a6938..a8dad15c652f71505d1853d35d25192dc1d4b42c 100644 --- a/declearn/communication/websockets/_server.py +++ b/declearn/communication/websockets/_server.py @@ -33,8 +33,9 @@ from declearn.communication.websockets._tools import ( send_websockets_message, ) - -ADD_HEADER = False # revise: drop this constant (choose a behaviour) +__all__ = [ + "WebsocketsServer", +] class WebsocketsServer(NetworkServer): @@ -106,18 +107,12 @@ class WebsocketsServer(NetworkServer): ) -> None: """Start the 
websockets server.""" # Set up the websockets connections handling process. - extra_headers = ( - ws.Headers(Connection="keep-alive") # type: ignore - if ADD_HEADER - else None - ) server = ws.serve( # type: ignore # pylint: disable=no-member self._handle_connection, host=self.host, port=self.port, logger=self.logger, ssl=self._ssl, - extra_headers=extra_headers, ping_timeout=None, # disable timeout on keep-alive pings ) # Run the websockets server. diff --git a/declearn/communication/websockets/_tools.py b/declearn/communication/websockets/_tools.py index d31850f04f6edc28ac4d6c280e47f1045aec3c99..df4c004ab6a1b56ae2705ef095925ed0ed8b7069 100644 --- a/declearn/communication/websockets/_tools.py +++ b/declearn/communication/websockets/_tools.py @@ -23,6 +23,13 @@ from typing import Union from websockets.legacy.protocol import WebSocketCommonProtocol +__all__ = [ + "StreamRefusedError", + "receive_websockets_message", + "send_websockets_message", +] + + FLAG_STREAM_START = "STREAM_START" FLAG_STREAM_CLOSE = "STREAM_CLOSE" FLAG_STREAM_ALLOW = "STREAM_ALLOW" diff --git a/declearn/data_info/_fields.py b/declearn/data_info/_fields.py index bb48a5587abbf8da4697b4e1ab82d61e2c4e99fe..b46fec8204d2e7fca512427c75827c72f6df8ea0 100644 --- a/declearn/data_info/_fields.py +++ b/declearn/data_info/_fields.py @@ -73,11 +73,12 @@ class DataTypeField(DataInfoField): cls, value: Any, ) -> bool: - if isinstance(value, str): - try: - np.dtype(value) - except TypeError: - return False + if not isinstance(value, str): + return False + try: + np.dtype(value) + except TypeError: + return False return True @classmethod @@ -159,12 +160,12 @@ class NbSamplesField(DataInfoField): @register_data_info_field -class InputShapeField(DataInfoField): +class InputShapeField(DataInfoField): # pragma: no cover """Specifications for 'input_shape' data_info field.""" field = "input_shape" types = (tuple, list) - doc = "Input features' batched shape, checked to be equal." + doc = "DEPRECATED - Input features' batched shape, checked to be equal." @classmethod def is_valid( @@ -216,12 +217,12 @@ class InputShapeField(DataInfoField): @register_data_info_field -class NbFeaturesField(DataInfoField): +class NbFeaturesField(DataInfoField): # pragma: no cover """Deprecated specifications for 'n_features' data_info field.""" field = "n_features" types = (int,) - doc = "Number of input features, checked to be equal." + doc = "DEPRECATED - Number of input features, checked to be equal." 
@classmethod def is_valid( diff --git a/declearn/dataset/_inmemory.py b/declearn/dataset/_inmemory.py index 534cf45d92ca0f01ad0fb9ae52a511a5f8519934..1bd69fe735626ed6c41cabcd639014ea20601be1 100644 --- a/declearn/dataset/_inmemory.py +++ b/declearn/dataset/_inmemory.py @@ -19,7 +19,7 @@ import os import warnings -from typing import Any, ClassVar, Dict, Iterator, List, Optional, Set, Union +from typing import Any, Dict, Iterator, List, Optional, Set, Union import numpy as np import pandas as pd @@ -72,8 +72,6 @@ class InMemoryDataset(Dataset): # attributes serve clarity; pylint: disable=too-many-instance-attributes # arguments serve modularity; pylint: disable=too-many-arguments - _type_key: ClassVar[str] = "InMemoryDataset" - def __init__( self, data: Union[DataArray, str], @@ -149,6 +147,11 @@ class InMemoryDataset(Dataset): target = self.data[target] else: target = load_data_array(target) + if ( + isinstance(target, pd.DataFrame) + and len(target.columns) == 1 + ): + target = target.iloc[:, 0] self.target = target # Assign the (optional) sample weights data array. if isinstance(s_wght, str): @@ -196,7 +199,7 @@ class InMemoryDataset(Dataset): return set(np.unique(self.target).tolist()) if isinstance(self.target, spmatrix): return set(np.unique(self.target.tocsr().data).tolist()) - raise TypeError( + raise TypeError( # pragma: no cover f"Invalid 'target' attribute type: '{type(self.target)}'." ) @@ -205,17 +208,17 @@ class InMemoryDataset(Dataset): """Unique data type.""" if not self.expose_data_type: return None - if isinstance(self.data, pd.DataFrame): - dtypes = [str(t) for t in list(self.data.dtypes)] + if isinstance(self.feats, pd.DataFrame): + dtypes = {str(t) for t in list(self.feats.dtypes)} if len(dtypes) > 1: raise ValueError( "Cannot work with mixed data types:" "ensure the `data` attribute has unique dtype" ) - return dtypes[0] - if isinstance(self.data, (pd.Series, np.ndarray, spmatrix)): - return str(self.data.dtype) - raise TypeError( + return list(dtypes)[0] + if isinstance(self.feats, (pd.Series, np.ndarray, spmatrix)): + return str(self.feats.dtype) + raise TypeError( # pragma: no cover f"Invalid 'data' attribute type: '{type(self.target)}'." ) @@ -223,7 +226,7 @@ class InMemoryDataset(Dataset): def load_data_array( path: str, **kwargs: Any, - ) -> DataArray: + ) -> DataArray: # pragma: no cover """Load a data array from a dump file. As of declearn v2.2, this staticmethod is DEPRECATED in favor of @@ -244,7 +247,7 @@ class InMemoryDataset(Dataset): def save_data_array( path: str, array: Union[DataArray, pd.Series], - ) -> str: + ) -> str: # pragma: no cover """Save a data array to a dump file. As of declearn v2.2, this staticmethod is DEPRECATED in favor of @@ -308,7 +311,7 @@ class InMemoryDataset(Dataset): path = os.path.abspath(path) folder = os.path.dirname(path) info = {} # type: Dict[str, Any] - info["type"] = self._type_key + info["type"] = "InMemoryDataset" # NOTE: for backward compatibility # Optionally create data dumps. Record data dumps' paths. # fmt: off info["data"] = ( @@ -349,16 +352,11 @@ class InMemoryDataset(Dataset): if "config" not in dump: raise KeyError("Missing key in the JSON file: 'config'.") info = dump["config"] - for key in ("type", "data", "target", "s_wght", "f_cols"): + for key in ("data", "target", "s_wght", "f_cols"): if key not in info: error = f"Missing key in the JSON file: 'config/{key}'." 
raise KeyError(error) - key = info.pop("type") - if key != cls._type_key: - raise TypeError( - f"Incorrect 'type' field: got '{key}', " - f"expected '{cls._type_key}'." - ) + info.pop("type", None) # Instantiate the object and return it. return cls(**info) diff --git a/declearn/main/utils/_early_stop.py b/declearn/main/utils/_early_stop.py index 92155e4832127725a7845f3b1b625d112ea814bd..f703a9b95422733ce1e4fa93dbc8b49b91e49a01 100644 --- a/declearn/main/utils/_early_stop.py +++ b/declearn/main/utils/_early_stop.py @@ -105,7 +105,7 @@ class EarlyStopping: self._best_metric = metric if self.relative: diff /= self._best_metric - if diff < self.tolerance: + if diff <= self.tolerance: self._n_iter_stuck += 1 else: self._n_iter_stuck = 0 diff --git a/declearn/metrics/_api.py b/declearn/metrics/_api.py index ab23eb0d24fdacabeefe9e35ca0ac112559327bb..fce098c316b488ecdf45cc09302ac2f492f8cbcc 100644 --- a/declearn/metrics/_api.py +++ b/declearn/metrics/_api.py @@ -347,7 +347,9 @@ class Metric(metaclass=ABCMeta): return s_wght @staticmethod - def normalize_weights(s_wght: np.ndarray) -> np.ndarray: + def normalize_weights( # pragma: no cover + s_wght: np.ndarray, + ) -> np.ndarray: """Utility method to ensure weights sum to one. Note that this method may or may not be used depending on diff --git a/declearn/metrics/_mean.py b/declearn/metrics/_mean.py index 2c4189c675971944dab74ab58aa393140e5e3da0..aca63e392f91d19f7e2e863a0b121d8689a2dc55 100644 --- a/declearn/metrics/_mean.py +++ b/declearn/metrics/_mean.py @@ -23,6 +23,7 @@ from typing import ClassVar, Dict, Optional, Union import numpy as np from declearn.metrics._api import Metric +from declearn.metrics._utils import squeeze_into_identical_shapes __all__ = [ "MeanMetric", @@ -129,6 +130,7 @@ class MeanAbsoluteError(MeanMetric): y_pred: np.ndarray, ) -> np.ndarray: # Sample-wise (sum of) absolute error function. + y_true, y_pred = squeeze_into_identical_shapes(y_true, y_pred) errors = np.abs(y_true - y_pred) while errors.ndim > 1: errors = errors.sum(axis=-1) @@ -158,6 +160,7 @@ class MeanSquaredError(MeanMetric): y_pred: np.ndarray, ) -> np.ndarray: # Sample-wise (sum of) squared error function. + y_true, y_pred = squeeze_into_identical_shapes(y_true, y_pred) errors = np.square(y_true - y_pred) while errors.ndim > 1: errors = errors.sum(axis=-1) diff --git a/declearn/metrics/_rsquared.py b/declearn/metrics/_rsquared.py index d61fdfeaf34fc7c6cb3a47433861e42b8ec43e5e..f00b40b2267a523d054ae6229cb8ae3a36c152c8 100644 --- a/declearn/metrics/_rsquared.py +++ b/declearn/metrics/_rsquared.py @@ -22,6 +22,7 @@ from typing import ClassVar, Dict, Optional, Union import numpy as np from declearn.metrics._api import Metric +from declearn.metrics._utils import squeeze_into_identical_shapes __all__ = [ "RSquared", @@ -113,6 +114,7 @@ class RSquared(Metric): y_pred: np.ndarray, s_wght: Optional[np.ndarray] = None, ) -> None: + y_true, y_pred = squeeze_into_identical_shapes(y_true, y_pred) # Verify sample weights' shape, or set up 1-valued ones. s_wght = self._prepare_sample_weights(s_wght, n_samples=len(y_pred)) # Update the residual sum of squares. 
wSSr = sum(w * (y - p)^2) diff --git a/declearn/metrics/_utils.py b/declearn/metrics/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..757637a6d560883f18ab5fc931db4615deb17c47 --- /dev/null +++ b/declearn/metrics/_utils.py @@ -0,0 +1,54 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Backend utils for metrics' computations.""" + +from typing import Tuple + +import numpy as np + +__all__ = [ + "squeeze_into_identical_shapes", +] + + +def squeeze_into_identical_shapes( + y_true: np.ndarray, + y_pred: np.ndarray, +) -> Tuple[np.ndarray, np.ndarray]: + """Verify that inputs have identical shapes, up to squeezable dims. + + Return the input arrays, squeezed when needed. + Raise a ValueError if they cannot be made to match. + """ + # Case of identical-shape inputs. + if y_true.shape == y_pred.shape: + return y_true, y_pred + # Case of identical-shape inputs up to squeezable dims. + y_true = np.squeeze(y_true) + y_pred = np.squeeze(y_pred) + if y_true.shape == y_pred.shape: + # Handle edge case of scalar values: preserve one dimension. + if not y_true.shape: + y_true = np.expand_dims(y_true, 0) + y_pred = np.expand_dims(y_pred, 0) + return y_true, y_pred + # Case of mismatching shapes. + raise ValueError( + "Received inputs with incompatible shapes: " + f"y_true has shape {y_true.shape}, y_pred has shape {y_pred.shape}." 
+ ) diff --git a/declearn/model/haiku/_vector.py b/declearn/model/haiku/_vector.py index 054883432d10a4d9e55de11d3e397def761afe39..6b391857d3d365125a76b8e926064ef3a418d058 100644 --- a/declearn/model/haiku/_vector.py +++ b/declearn/model/haiku/_vector.py @@ -156,7 +156,7 @@ class JaxNumpyVector(Vector): warnings.warn( # pragma: no cover "The 'axis' and 'keepdims' arguments of 'JaxNumpyVector.sum' " "have been deprecated as of declearn v2.3, and will be " - "removed in version 2.6 and/or 3.0.", + "removed in version 2.5 and/or 3.0.", DeprecationWarning, ) coefs = { diff --git a/declearn/model/sklearn/_np_vec.py b/declearn/model/sklearn/_np_vec.py index 4a99a04b876a650f9f93ab34dad5f20a59a37fa7..499a6ae658ab0b916c0788cccb27fc1cecc6f3d7 100644 --- a/declearn/model/sklearn/_np_vec.py +++ b/declearn/model/sklearn/_np_vec.py @@ -127,7 +127,7 @@ class NumpyVector(Vector): warnings.warn( # pragma: no cover "The 'axis' and 'keepdims' arguments of 'NumpyVector.sum' " "have been deprecated as of declearn v2.3, and will be " - "removed in version 2.6 and/or 3.0.", + "removed in version 2.5 and/or 3.0.", DeprecationWarning, ) coefs = { diff --git a/declearn/model/tensorflow/_vector.py b/declearn/model/tensorflow/_vector.py index b62dffc5390166188caf8d7ca68e0f4806789679..09d9ef58b7e58b5772a8eedf3ade4b03064378b0 100644 --- a/declearn/model/tensorflow/_vector.py +++ b/declearn/model/tensorflow/_vector.py @@ -295,7 +295,7 @@ class TensorflowVector(Vector): warnings.warn( # pragma: no cover "The 'axis' and 'keepdims' arguments of 'TensorflowVector.sum'" " have been deprecated as of declearn v2.3, and will be " - "removed in version 2.6 and/or 3.0.", + "removed in version 2.5 and/or 3.0.", DeprecationWarning, ) return self.apply_func(tf.reduce_sum, axis=axis, keepdims=keepdims) diff --git a/declearn/model/torch/_vector.py b/declearn/model/torch/_vector.py index ff67cc30af3b6284dfe02cccde0096d05ccf5428..5034db537194aed56723a15bc2fc0847a6e06a4e 100644 --- a/declearn/model/torch/_vector.py +++ b/declearn/model/torch/_vector.py @@ -200,7 +200,7 @@ class TorchVector(Vector): warnings.warn( # pragma: no cover "The 'axis' and 'keepdims' arguments of 'TorchVector.sum' " "have been deprecated as of declearn v2.3, and will be " - "removed in version 2.6 and/or 3.0.", + "removed in version 2.5 and/or 3.0.", DeprecationWarning, ) coefs = { diff --git a/declearn/quickrun/_parser.py b/declearn/quickrun/_parser.py index 422d0999d6625b0959ebbd0931ec0d841f13d88d..f15bfb4f2c1c4a19424aff6e110553bfd6bde4b7 100644 --- a/declearn/quickrun/_parser.py +++ b/declearn/quickrun/_parser.py @@ -103,8 +103,8 @@ def parse_data_folder( def get_data_folder_path( - data_folder: Optional[str], - root_folder: Optional[str], + data_folder: Optional[str] = None, + root_folder: Optional[str] = None, ) -> Path: """Return the path to a data folder. @@ -158,7 +158,7 @@ def get_data_folder_path( def list_client_names( data_folder: Path, - client_names: Optional[List[str]], + client_names: Optional[List[str]] = None, ) -> List[str]: """List client-wise subdirectories under a data folder. diff --git a/declearn/quickrun/_run.py b/declearn/quickrun/_run.py index c33ae875dd1af0f5df6ca02e610877bcee62818a..46f14384f282f851a6dddbe664c624bc1664426b 100644 --- a/declearn/quickrun/_run.py +++ b/declearn/quickrun/_run.py @@ -143,19 +143,32 @@ def run_client( def get_toml_folder(config: str) -> Tuple[str, str]: """Return the path to an experiment's folder and TOML config file. 
- Determine if provided config is a file or a directory, and return: - - * The path to the TOML config file - * The path to the main folder of the experiment + Parameters + ---------- + config: str + Path to either a TOML config file (within an experiment folder), + or to the experiment folder containing a "config.toml" file. + + Returns + ------- + toml: + The path to the TOML config file. + folder: + The path to the main folder of the experiment. + + Raises + ------ + FileNotFoundError: + If the TOML config file cannot be found based on inputs. """ config = os.path.abspath(config) - if os.path.isfile(config): - toml = config - folder = os.path.dirname(config) - elif os.path.isdir(config): + if os.path.isdir(config): folder = config - toml = f"{folder}/config.toml" + toml = os.path.join(folder, "config.toml") else: + toml = config + folder = os.path.dirname(toml) + if not os.path.isfile(toml): raise FileNotFoundError( f"Failed to find quickrun config file at '{config}'." ) diff --git a/pyproject.toml b/pyproject.toml index 2161244d8d6de8a5b4e8408ccd0a494f5fe575e3..7d7131725797dfa461c43172bad177973ad35a9b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,10 +47,10 @@ dependencies = [ [project.optional-dependencies] all = [ # all non-tests extra dependencies - "dm-haiku == 0.0.9", + "dm-haiku ~= 0.0.9", "functorch >= 1.10, < 3.0", "grpcio >= 1.45", - "jax[cpu] >= 0.4, < 0.5", + "jax[cpu] ~= 0.4.1", "opacus ~= 1.1", "protobuf >= 3.19", "tensorflow ~= 2.5", @@ -66,8 +66,8 @@ grpc = [ "protobuf >= 3.19", ] haiku = [ - "dm-haiku == 0.0.9", - "jax[cpu] >= 0.4, < 0.5", # NOTE: GPU support must be manually installed + "dm-haiku ~= 0.0.9", + "jax[cpu] ~= 0.4.1", # NOTE: GPU support must be manually installed ] tensorflow = [ "tensorflow ~= 2.5", diff --git a/test/communication/test_exchanges.py b/test/communication/test_exchanges.py index 2b06aa77f639c05869fdc4ff72248844dee3f70f..c91270b096e665a2de776e11fadf899c1725de02 100644 --- a/test/communication/test_exchanges.py +++ b/test/communication/test_exchanges.py @@ -37,6 +37,7 @@ are run with either a single or three clients at once. """ import asyncio +import secrets from typing import AsyncIterator, Dict, List, Optional, Tuple import pytest @@ -51,7 +52,7 @@ from declearn.communication import ( from declearn.communication.api import NetworkClient, NetworkServer -### 1. Test that connections can properly be set up. +### 1. Test that connections can be properly set up. 
@pytest_asyncio.fixture(name="server") @@ -203,15 +204,18 @@ class TestNetworkExchanges: @pytest.mark.asyncio async def test_exchanges( - self, agents: Tuple[NetworkServer, List[NetworkClient]] + self, + agents: Tuple[NetworkServer, List[NetworkClient]], ) -> None: """Run all tests with the same fixture-provided agents.""" await self.clients_to_server(agents) await self.server_to_clients_broadcast(agents) await self.server_to_clients_individual(agents) + await self.clients_to_server_large(agents) async def clients_to_server( - self, agents: Tuple[NetworkServer, List[NetworkClient]] + self, + agents: Tuple[NetworkServer, List[NetworkClient]], ) -> None: """Test that clients can send messages to the server.""" server, clients = agents @@ -226,7 +230,8 @@ class TestNetworkExchanges: } async def server_to_clients_broadcast( - self, agents: Tuple[NetworkServer, List[NetworkClient]] + self, + agents: Tuple[NetworkServer, List[NetworkClient]], ) -> None: """Test that the server can send a shared message to all clients.""" server, clients = agents @@ -237,7 +242,8 @@ class TestNetworkExchanges: assert all(reply == msg for reply in replies) async def server_to_clients_individual( - self, agents: Tuple[NetworkServer, List[NetworkClient]] + self, + agents: Tuple[NetworkServer, List[NetworkClient]], ) -> None: """Test that the server can send individual messages to clients.""" server, clients = agents @@ -252,3 +258,24 @@ class TestNetworkExchanges: reply == messages[client.name] for client, reply in zip(clients, replies) ) + + async def clients_to_server_large( + self, + agents: Tuple[NetworkServer, List[NetworkClient]], + ) -> None: + """Test that the clients can send large messages to the server.""" + server, clients = agents + coros = [] + large = secrets.token_bytes(2**22).hex() + for idx, client in enumerate(clients): + msg = messaging.GenericMessage( + action="test", params={"idx": idx, "content": large} + ) + coros.append(client.send_message(msg)) + messages, *_ = await asyncio.gather(server.wait_for_messages(), *coros) + assert messages == { + c.name: messaging.GenericMessage( + action="test", params={"idx": i, "content": large} + ) + for i, c in enumerate(clients) + } diff --git a/test/data_info/test_classes_field.py b/test/data_info/test_classes_field.py new file mode 100644 index 0000000000000000000000000000000000000000..98addb7cf01db28b79cfffbc1d366751d0e21e42 --- /dev/null +++ b/test/data_info/test_classes_field.py @@ -0,0 +1,59 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for 'declearn.data_info.ClassesField'.""" + + +import numpy as np +import pytest + +from declearn.data_info import ClassesField + + +class TestClassesField: + """Unit tests for 'declearn.data_info.ClassesField'.""" + + def test_is_valid_list(self) -> None: + """Test `is_valid` with a valid list value.""" + assert ClassesField.is_valid([0, 1]) + + def test_is_valid_set(self) -> None: + """Test `is_valid` with a valid set value.""" + assert ClassesField.is_valid({0, 1}) + + def test_is_valid_tuple(self) -> None: + """Test `is_valid` with a valid tuple value.""" + assert ClassesField.is_valid((0, 1)) + + def test_is_valid_array(self) -> None: + """Test `is_valid` with a valid numpy array value.""" + assert ClassesField.is_valid(np.array([0, 1])) + + def test_is_invalid_2d_array(self) -> None: + """Test `is_valid` with an invalid numpy array value.""" + assert not ClassesField.is_valid(np.array([[0, 1], [2, 3]])) + + def test_combine(self) -> None: + """Test `combine` with valid and compatible inputs.""" + values = ([0, 1], (0, 1), {1, 2}, np.array([1, 3])) + assert ClassesField.combine(*values) == {0, 1, 2, 3} + + def test_combine_fails(self) -> None: + """Test `combine` with some invalid inputs.""" + values = ([0, 1], np.array([[0, 1], [2, 3]])) + with pytest.raises(ValueError): + ClassesField.combine(*values) diff --git a/test/data_info/test_data_info_utils.py b/test/data_info/test_data_info_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d215fd1d311b397e97c3a4d75a2928d695a7c076 --- /dev/null +++ b/test/data_info/test_data_info_utils.py @@ -0,0 +1,144 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""Unit tests for 'declearn.data_info' high-level utils."""
+
+import uuid
+from typing import Any, Type
+from unittest import mock
+
+import pytest
+
+from declearn.data_info import (
+    DataInfoField,
+    aggregate_data_info,
+    get_data_info_fields_documentation,
+    register_data_info_field,
+)
+
+
+class TestAggregateDataInfo:
+    """Unit tests for 'declearn.data_info.aggregate_data_info'."""
+
+    def test_aggregate_data_info(self) -> None:
+        """Test aggregating valid, compatible data info."""
+        clients_data_info = [
+            {"n_samples": 10, "features_shape": (100,)},
+            {"n_samples": 32, "features_shape": (100,)},
+        ]
+        result = aggregate_data_info(clients_data_info)
+        assert result == {"n_samples": 42, "features_shape": (100,)}
+
+    def test_aggregate_data_info_required(self) -> None:
+        """Test aggregating a subset of valid, compatible data info."""
+        clients_data_info = [
+            {"n_samples": 10, "features_shape": (100,)},
+            {"n_samples": 32, "features_shape": (100,)},
+        ]
+        result = aggregate_data_info(
+            clients_data_info, required_fields={"n_samples"}
+        )
+        assert result == {"n_samples": 42}
+
+    def test_aggregate_data_info_missing_required(self) -> None:
+        """Test that a KeyError is raised on missing required data info."""
+        clients_data_info = [
+            {"n_samples": 10},
+            {"n_samples": 32},
+        ]
+        with pytest.raises(KeyError):
+            aggregate_data_info(
+                clients_data_info,
+                required_fields={"n_samples", "features_shape"},
+            )
+
+    def test_aggregate_data_info_invalid_values(self) -> None:
+        """Test that a ValueError is raised on invalid values."""
+        clients_data_info = [
+            {"n_samples": 10},
+            {"n_samples": -1},
+        ]
+        with pytest.raises(ValueError):
+            aggregate_data_info(clients_data_info)
+
+    def test_aggregate_data_info_incompatible_values(self) -> None:
+        """Test that a ValueError is raised on incompatible values."""
+        clients_data_info = [
+            {"features_shape": (28,)},
+            {"features_shape": (32,)},
+        ]
+        with pytest.raises(ValueError):
+            aggregate_data_info(clients_data_info)
+
+    def test_aggregate_data_info_undefined_field(self) -> None:
+        """Test that unspecified fields are handled as expected."""
+        clients_data_info = [
+            {"n_samples": 10, "undefined": "a"},
+            {"n_samples": 32, "undefined": "b"},
+        ]
+        with mock.patch("warnings.warn") as patch_warn:
+            result = aggregate_data_info(clients_data_info)
+        patch_warn.assert_called_once()
+        assert result == {"n_samples": 42, "undefined": ["a", "b"]}
+
+
+class TestRegisterDataInfoField:
+    """Unit tests for 'declearn.data_info.register_data_info_field'."""
+
+    def create_mock_cls(self) -> Type[DataInfoField]:
+        """Create and return a mock DataInfoField subclass."""
+
+        field_name = f"mock_field_{uuid.uuid4()}"
+
+        class MockDataInfoField(DataInfoField):
+            """Mock DataInfoField subclass."""
+
+            field = field_name
+            types = (str,)
+            doc = f"Documentation for '{field_name}'."
+
+            @classmethod
+            def combine(cls, *values: Any) -> Any:
+                return values
+
+        return MockDataInfoField
+
+    def test_register_data_info_field(self) -> None:
+        """Test that registering a custom DataInfoField works."""
+        # Set up a mock DataInfoField subclass.
+        mock_cls = self.create_mock_cls()
+        # Test that it can be registered, and thereafter accessed.
+ register_data_info_field(mock_cls) + documentation = get_data_info_fields_documentation() + assert mock_cls.field in documentation + assert documentation[mock_cls.field] == mock_cls.doc + + def test_register_data_info_field_invalid_type(self) -> None: + """Test that registering a non-DataInfoField subclass fails.""" + with pytest.raises(TypeError): + register_data_info_field(int) # type: ignore + + def test_register_data_info_field_already_used(self) -> None: + """Test that registering twice under the same name fails.""" + # Set up a couple of DataInfoField mock classes with same field name. + mock_cls = self.create_mock_cls() + mock_bis = self.create_mock_cls() + mock_bis.field = mock_cls.field + # Test that they cannot both be registered. + register_data_info_field(mock_cls) + with pytest.raises(KeyError): + register_data_info_field(mock_bis) diff --git a/test/data_info/test_datatype_field.py b/test/data_info/test_datatype_field.py new file mode 100644 index 0000000000000000000000000000000000000000..c4c2f1bf76c11fcea1ff5d628374e7dddeb6115a --- /dev/null +++ b/test/data_info/test_datatype_field.py @@ -0,0 +1,57 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for 'declearn.data_info.DataTypeField'.""" + + +import numpy as np +import pytest + +from declearn.data_info import DataTypeField + + +class TestDataTypeField: + """Unit tests for 'declearn.data_info.DataTypeField'.""" + + def test_is_valid(self) -> None: + """Test `is_valid` with some valid values.""" + assert DataTypeField.is_valid("float32") + assert DataTypeField.is_valid("float64") + assert DataTypeField.is_valid("int32") + assert DataTypeField.is_valid("uint8") + + def test_is_not_valid(self) -> None: + """Test `is_valid` with invalid values.""" + assert not DataTypeField.is_valid(np.int32) + assert not DataTypeField.is_valid("mocktype") + + def test_combine(self) -> None: + """Test `combine` with valid and compatible inputs.""" + values = ["float32", "float32"] + assert DataTypeField.combine(*values) == "float32" + + def test_combine_invalid(self) -> None: + """Test `combine` with invalid inputs.""" + values = ["float32", "mocktype"] + with pytest.raises(ValueError): + DataTypeField.combine(*values) + + def test_combine_incompatible(self) -> None: + """Test `combine` with incompatible inputs.""" + values = ["float32", "float16"] + with pytest.raises(ValueError): + DataTypeField.combine(*values) diff --git a/test/data_info/test_nbsamples_field.py b/test/data_info/test_nbsamples_field.py new file mode 100644 index 0000000000000000000000000000000000000000..fc53f3dfdc31dd29ef21b9d0a2bb524034f5e739 --- /dev/null +++ b/test/data_info/test_nbsamples_field.py @@ -0,0 +1,52 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for 'declearn.data_info.NbSamplesField'.""" + + +import pytest + +from declearn.data_info import NbSamplesField + + +class TestNbSamplesField: + """Unit tests for 'declearn.data_info.NbSamplesField'.""" + + def test_is_valid(self) -> None: + """Test `is_valid` with some valid input values.""" + assert NbSamplesField.is_valid(32) + assert NbSamplesField.is_valid(100) + assert NbSamplesField.is_valid(8192) + + def test_is_not_valid(self) -> None: + """Test `is_valid` with invalid values.""" + assert not NbSamplesField.is_valid(16.5) + assert not NbSamplesField.is_valid(-12) + assert not NbSamplesField.is_valid(None) + + def test_combine(self) -> None: + """Test `combine` with valid and compatible inputs.""" + values = [32, 128] + assert NbSamplesField.combine(*values) == 160 + values = [64, 64, 64, 64] + assert NbSamplesField.combine(*values) == 256 + + def test_combine_invalid(self) -> None: + """Test `combine` with invalid inputs.""" + values = [128, -12] + with pytest.raises(ValueError): + NbSamplesField.combine(*values) diff --git a/test/data_info/test_shape_field.py b/test/data_info/test_shape_field.py new file mode 100644 index 0000000000000000000000000000000000000000..ef9f61c4a06c28cb5d31a31fad560109a6d95089 --- /dev/null +++ b/test/data_info/test_shape_field.py @@ -0,0 +1,67 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for 'declearn.data_info.FeaturesShapeField'.""" + + +import pytest + +from declearn.data_info import FeaturesShapeField + + +class TestFeaturesShapeField: + """Unit tests for 'declearn.data_info.FeaturesShapeField'.""" + + def test_is_valid(self) -> None: + """Test `is_valid` with some valid input values.""" + # 1-d ; fixed 3-d (image-like) ; variable 2-d (text-like). + assert FeaturesShapeField.is_valid([32]) + assert FeaturesShapeField.is_valid([64, 64, 3]) + assert FeaturesShapeField.is_valid([None, 128]) + # Same inputs, as tuples. + assert FeaturesShapeField.is_valid((32,)) + assert FeaturesShapeField.is_valid((64, 64, 3)) + assert FeaturesShapeField.is_valid((None, 128)) + + def test_is_not_valid(self) -> None: + """Test `is_valid` with invalid values.""" + assert not FeaturesShapeField.is_valid(32) + assert not FeaturesShapeField.is_valid([32, -1]) + + def test_combine(self) -> None: + """Test `combine` with valid and compatible inputs.""" + # 1-d inputs. + values = [[32], (32,)] + assert FeaturesShapeField.combine(*values) == (32,) + # 3-d fixed-size inputs. + values = [[16, 16, 3], (16, 16, 3)] + assert FeaturesShapeField.combine(*values) == (16, 16, 3) + # 2-d variable-size inputs. 
+ values = [[None, 512], (None, 512)] # type: ignore + assert FeaturesShapeField.combine(*values) == (None, 512) + + def test_combine_invalid(self) -> None: + """Test `combine` with invalid inputs.""" + values = [[32], [32, -1]] + with pytest.raises(ValueError): + FeaturesShapeField.combine(*values) + + def test_combine_incompatible(self) -> None: + """Test `combine` with incompatible inputs.""" + values = [(None, 32), (128,)] + with pytest.raises(ValueError): + FeaturesShapeField.combine(*values) diff --git a/test/dataset/test_utils.py b/test/dataset/test_dataset_utils.py similarity index 100% rename from test/dataset/test_utils.py rename to test/dataset/test_dataset_utils.py diff --git a/test/dataset/test_inmemory.py b/test/dataset/test_inmemory.py index ff6456c6a6fcb57729e0aff9bd8fb09528707cac..3551dd499058bdb064bc64c911f319dee5c77b36 100644 --- a/test/dataset/test_inmemory.py +++ b/test/dataset/test_inmemory.py @@ -15,13 +15,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Unit tests objects for 'declearn.dataset.InMemoryDataset'""" +"""Unit tests for 'declearn.dataset.InMemoryDataset'""" +import json import os +import numpy as np +import pandas as pd import pytest +import scipy.sparse # type: ignore +import sklearn.datasets # type: ignore + from declearn.dataset import InMemoryDataset +from declearn.dataset.utils import save_data_array from declearn.test_utils import make_importable # relative imports from `dataset_testbase.py` @@ -32,6 +39,9 @@ with make_importable(os.path.dirname(__file__)): SEED = 0 +### Shared-tests-based tests, revolving around batches generation. + + class InMemoryDatasetTestToolbox(DatasetTestToolbox): """Toolbox for InMemoryDataset""" @@ -44,10 +54,288 @@ class InMemoryDatasetTestToolbox(DatasetTestToolbox): @pytest.fixture(name="toolbox") -def fixture_dataset() -> DatasetTestToolbox: +def fixture_toolbox() -> DatasetTestToolbox: """Fixture to access a InMemoryDatasetTestToolbox.""" return InMemoryDatasetTestToolbox() class TestInMemoryDataset(DatasetTestSuite): """Unit tests for declearn.dataset.InMemoryDataset.""" + + +### InMemoryDataset-specific unit tests. + + +@pytest.fixture(name="dataset") +def dataset_fixture() -> pd.DataFrame: + """Fixture providing with a small toy dataset.""" + rng = np.random.default_rng(seed=SEED) + wgt = rng.normal(size=10).astype("float32") + data = { + "col_a": np.arange(10, dtype="float32"), + "col_b": rng.normal(size=10).astype("float32"), + "col_y": rng.choice(3, size=10, replace=True), + "col_w": wgt / sum(wgt), + } + return pd.DataFrame(data) + + +class TestInMemoryDatasetInit: + """Unit tests for 'declearn.dataset.InMemoryDataset' instantiation.""" + + def test_from_inputs( + self, + dataset: pd.DataFrame, + ) -> None: + """Test instantiating with (x, y, w) array data.""" + # Split data into distinct objects with various types. + y_dat = dataset.pop("col_y") + w_dat = dataset.pop("col_w").values + x_dat = scipy.sparse.coo_matrix(dataset.values) + # Test that an InMemoryDataset can be instantiated from that data. 
+        dst = InMemoryDataset(data=x_dat, target=y_dat, s_wght=w_dat)
+        assert dst.feats is x_dat
+        assert dst.target is y_dat
+        assert dst.weights is w_dat
+
+    def test_from_dataframe(
+        self,
+        dataset: pd.DataFrame,
+    ) -> None:
+        """Test instantiating with a pandas DataFrame and column names."""
+        dst = InMemoryDataset(data=dataset, target="col_y", s_wght="col_w")
+        assert np.allclose(dst.feats, dataset[["col_a", "col_b"]])
+        assert np.allclose(dst.target, dataset["col_y"])  # type: ignore
+        assert np.allclose(dst.weights, dataset["col_w"])  # type: ignore
+
+    def test_from_dataframe_with_fcols_str(
+        self,
+        dataset: pd.DataFrame,
+    ) -> None:
+        """Test instantiating from a pandas DataFrame with string f_cols."""
+        dst = InMemoryDataset(
+            data=dataset, target="col_y", s_wght="col_w", f_cols=["col_a"]
+        )
+        assert np.allclose(dst.feats, dataset[["col_a"]])
+        assert np.allclose(dst.target, dataset["col_y"])  # type: ignore
+        assert np.allclose(dst.weights, dataset["col_w"])  # type: ignore
+
+    def test_from_dataframe_with_fcols_int(
+        self,
+        dataset: pd.DataFrame,
+    ) -> None:
+        """Test instantiating from a pandas DataFrame with integer f_cols."""
+        dst = InMemoryDataset(
+            data=dataset, target="col_y", s_wght="col_w", f_cols=[1]
+        )
+        assert np.allclose(dst.feats, dataset[["col_b"]])
+        assert np.allclose(dst.target, dataset["col_y"])  # type: ignore
+        assert np.allclose(dst.weights, dataset["col_w"])  # type: ignore
+
+    def test_from_csv_file(
+        self,
+        dataset: pd.DataFrame,
+        tmp_path: str,
+    ) -> None:
+        """Test instantiating from a single csv file and column names."""
+        # Dump the dataset to a csv file and instantiate from it.
+        path = os.path.join(tmp_path, "dataset.csv")
+        dataset.to_csv(path, index=False)
+        dst = InMemoryDataset(data=path, target="col_y", s_wght="col_w")
+        # Test that the data matches expectations.
+        assert np.allclose(dst.feats, dataset[["col_a", "col_b"]])
+        assert np.allclose(dst.target, dataset["col_y"])  # type: ignore
+        assert np.allclose(dst.weights, dataset["col_w"])  # type: ignore
+
+    def test_from_csv_file_feats_only(
+        self,
+        dataset: pd.DataFrame,
+        tmp_path: str,
+    ) -> None:
+        """Test instantiating from a single csv file without y nor w."""
+        # Dump the dataset to a csv file and instantiate from it.
+        path = os.path.join(tmp_path, "dataset.csv")
+        dataset.to_csv(path, index=False)
+        dst = InMemoryDataset(data=path)
+        # Test that the data matches expectations.
+        assert np.allclose(dst.feats, dataset)
+        assert dst.target is None
+        assert dst.weights is None
+
+    def test_from_data_files(
+        self,
+        dataset: pd.DataFrame,
+        tmp_path: str,
+    ) -> None:
+        """Test instantiating from a collection of files."""
+        # Split data into distinct objects with various types.
+        y_dat = dataset.pop("col_y")
+        w_dat = dataset.pop("col_w").values
+        x_dat = scipy.sparse.coo_matrix(dataset.values)
+        # Save these objects to files.
+        x_path = save_data_array(os.path.join(tmp_path, "data_x"), x_dat)
+        y_path = save_data_array(os.path.join(tmp_path, "data_y"), y_dat)
+        w_path = save_data_array(os.path.join(tmp_path, "data_w"), w_dat)
+        # Test that an InMemoryDataset can be instantiated from these files.
+ dst = InMemoryDataset(data=x_path, target=y_path, s_wght=w_path) + assert isinstance(dst.feats, scipy.sparse.coo_matrix) + assert np.allclose(dst.feats.toarray(), x_dat.toarray()) + assert isinstance(dst.target, pd.Series) + assert np.allclose(dst.target, y_dat) + assert isinstance(dst.weights, np.ndarray) + assert np.allclose(dst.weights, w_dat) # type: ignore + + def test_from_svmlight( + self, + dataset: pd.DataFrame, + tmp_path: str, + ) -> None: + """Test instantiating from a svmlight file.""" + path = os.path.join(tmp_path, "dataset.svmlight") + sklearn.datasets.dump_svmlight_file( + scipy.sparse.coo_matrix(dataset[["col_a", "col_b"]].values), + dataset["col_y"].values, + path, + ) + dst = InMemoryDataset.from_svmlight(path) + assert isinstance(dst.data, scipy.sparse.csr_matrix) + assert np.allclose( + dst.data.toarray(), dataset[["col_a", "col_b"]].values + ) + assert isinstance(dst.target, np.ndarray) + assert np.allclose(dst.target, dataset["col_y"].to_numpy()) + assert dst.weights is None + + +class TestInMemoryDatasetProperties: + """Unit tests for 'declearn.dataset.InMemoryDataset' properties.""" + + def test_classes_array( + self, + dataset: pd.DataFrame, + ) -> None: + """Test (authorized) classes access with numpy array targets.""" + dst = InMemoryDataset( + data=dataset, target=dataset["col_y"].values, expose_classes=True + ) + assert dst.classes == {0, 1, 2} + + def test_classes_series( + self, + dataset: pd.DataFrame, + ) -> None: + """Test (authorized) classes access with pandas Series targets.""" + dst = InMemoryDataset( + data=dataset, target="col_y", expose_classes=True + ) + assert dst.classes == {0, 1, 2} + + def test_classes_dataframe( + self, + dataset: pd.DataFrame, + ) -> None: + """Test (authorized) classes access with pandas DataFrame targets.""" + dst = InMemoryDataset( + data=dataset, target=dataset[["col_y"]], expose_classes=True + ) + assert dst.classes == {0, 1, 2} + + def test_classes_sparse( + self, + dataset: pd.DataFrame, + ) -> None: + """Test (authorized) classes access with scipy spmatrix targets.""" + y_dat = scipy.sparse.coo_matrix(dataset[["col_y"]] + 1) + dst = InMemoryDataset(data=dataset, target=y_dat, expose_classes=True) + assert dst.classes == {1, 2, 3} + + def test_data_type_dataframe( + self, + dataset: pd.DataFrame, + ) -> None: + """Test (authorized) data-type access with pandas DataFrame data.""" + dst = InMemoryDataset( + data=dataset[["col_a", "col_b"]], expose_data_type=True + ) + assert dst.data_type == "float32" + + def test_data_type_dataframe_mixed( + self, + dataset: pd.DataFrame, + ) -> None: + """Test that an exception is raised with a mixed-type DataFrame.""" + dst = InMemoryDataset(data=dataset, expose_data_type=True) + with pytest.raises(ValueError): + dst.data_type # pylint: disable=pointless-statement + + def test_data_type_series( + self, + dataset: pd.DataFrame, + ) -> None: + """Test (authorized) data-type access with pandas Series data.""" + dst = InMemoryDataset(data=dataset["col_a"], expose_data_type=True) + assert dst.data_type == "float32" + + def test_data_type_array( + self, + dataset: pd.DataFrame, + ) -> None: + """Test (authorized) data-type access with numpy array data.""" + data = dataset[["col_a", "col_b"]].values + dst = InMemoryDataset(data=data, expose_data_type=True) + assert dst.data_type == "float32" + + def test_data_type_sparse( + self, + dataset: pd.DataFrame, + ) -> None: + """Test (authorized) data-type access with scipy spmatrix data.""" + data = scipy.sparse.coo_matrix(dataset[["col_a", 
"col_b"]].values) + dst = InMemoryDataset(data=data, expose_data_type=True) + assert dst.data_type == "float32" + + +class TestInMemoryDatasetSaveLoad: + """Test JSON-file saving/loading features of InMemoryDataset.""" + + def test_save_load_json( + self, + dataset: pd.DataFrame, + tmp_path: str, + ) -> None: + """Test that a dataset can be saved to and loaded from JSON.""" + dst = InMemoryDataset(dataset, target="col_y", s_wght="col_w") + # Test that the dataset can be saved to JSON. + path = os.path.join(tmp_path, "dataset.json") + dst.save_to_json(path) + assert os.path.isfile(path) + # Test that it can be reloaded from JSON. + bis = InMemoryDataset.load_from_json(path) + assert np.allclose(dst.data, bis.data) + assert np.allclose(dst.target, bis.target) # type: ignore + assert np.allclose(dst.weights, bis.weights) # type: ignore + assert dst.f_cols == bis.f_cols + assert dst.expose_classes == bis.expose_classes + assert dst.expose_data_type == bis.expose_data_type + + def test_load_json_malformed( + self, + tmp_path: str, + ) -> None: + """Test with a JSON file that has nothing to do with a dataset.""" + path = os.path.join(tmp_path, "dataset.json") + with open(path, "w", encoding="utf-8") as file: + json.dump({"not-a-dataset": "at-all"}, file) + with pytest.raises(KeyError): + InMemoryDataset.load_from_json(path) + + def test_load_json_partial( + self, + tmp_path: str, + ) -> None: + """Test with a JSON file that contains a partial dataset config.""" + path = os.path.join(tmp_path, "dataset.json") + with open(path, "w", encoding="utf-8") as file: + json.dump({"config": {"data": "mock", "target": "mock"}}, file) + with pytest.raises(KeyError): + InMemoryDataset.load_from_json(path) diff --git a/test/functional/test_regression.py b/test/functional/test_toy_reg.py similarity index 57% rename from test/functional/test_regression.py rename to test/functional/test_toy_reg.py index ce5f805211d5fcf14208e2d22227b0ce8eee219c..469e06f3f1605cc1dd0713bdd14f58e994664dc5 100644 --- a/test/functional/test_regression.py +++ b/test/functional/test_toy_reg.py @@ -42,16 +42,20 @@ The convergence results of those experiments is then compared. 
""" +import asyncio +import dataclasses import json +import os import tempfile from typing import List, Tuple import numpy as np -from sklearn.datasets import make_regression # type: ignore -from sklearn.linear_model import SGDRegressor # type: ignore +import pytest +import sklearn.datasets # type: ignore +import sklearn.linear_model # type: ignore from declearn.communication import NetworkClientConfig, NetworkServerConfig -from declearn.dataset import InMemoryDataset +from declearn.dataset import Dataset, InMemoryDataset from declearn.main import FederatedClient, FederatedServer from declearn.main.config import FLOptimConfig, FLRunConfig from declearn.metrics import RSquared @@ -59,7 +63,7 @@ from declearn.model.api import Model from declearn.model.sklearn import SklearnSGDModel from declearn.optimizer import Optimizer from declearn.test_utils import FrameworkType -from declearn.utils import run_as_processes, set_device_policy +from declearn.utils import set_device_policy # optional frameworks' dependencies pylint: disable=ungrouped-imports # pylint: disable=duplicate-code @@ -71,7 +75,8 @@ try: except ModuleNotFoundError: pass else: - from declearn.model.tensorflow import TensorflowModel + from declearn.dataset.tensorflow import TensorflowDataset + from declearn.model.tensorflow import TensorflowModel, TensorflowVector # torch imports try: import torch @@ -79,8 +84,7 @@ except ModuleNotFoundError: pass else: from declearn.dataset.torch import TorchDataset - from declearn.model.torch import TorchModel -# pylint: enable=duplicate-code + from declearn.model.torch import TorchModel, TorchVector # haiku and jax imports try: import haiku as hk @@ -88,7 +92,7 @@ try: except ModuleNotFoundError: pass else: - from declearn.model.haiku import HaikuModel + from declearn.model.haiku import HaikuModel, JaxNumpyVector def haiku_model_fn(inputs: jax.Array) -> jax.Array: """Simple linear model implemented with Haiku.""" @@ -96,40 +100,81 @@ else: def haiku_loss_fn(y_pred: jax.Array, y_true: jax.Array) -> jax.Array: """Sample-wise squared error loss function.""" + y_pred = jax.numpy.squeeze(y_pred) return (y_pred - y_true) ** 2 +# pylint: disable=duplicate-code + SEED = 0 -R2_THRESHOLD = 0.999 +R2_THRESHOLD = 0.9999 set_device_policy(gpu=False) # disable GPU use to avoid concurrence def get_model(framework: FrameworkType) -> Model: - """Set up a simple toy regression model.""" + """Set up a simple toy regression model, with zero-valued weights.""" set_device_policy(gpu=False) # disable GPU use to avoid concurrence if framework == "numpy": - np.random.seed(SEED) # set seed - model = SklearnSGDModel.from_parameters( - kind="regressor", loss="squared_error", penalty="none" - ) # type: Model - elif framework == "tensorflow": - tf.random.set_seed(SEED) # set seed - tfmod = tf.keras.Sequential(tf.keras.layers.Dense(units=1)) - tfmod.build([None, 100]) - model = TensorflowModel(tfmod, loss="mean_squared_error") - elif framework == "torch": - torch.manual_seed(SEED) # set seed - torchmod = torch.nn.Sequential( - torch.nn.Linear(100, 1, bias=True), - torch.nn.Flatten(0), - ) - model = TorchModel(torchmod, loss=torch.nn.MSELoss()) - elif framework == "jax": - model = HaikuModel(haiku_model_fn, loss=haiku_loss_fn) - else: - raise ValueError("unrecognised framework") + return _get_model_numpy() + if framework == "tensorflow": + return _get_model_tflow() + if framework == "torch": + return _get_model_torch() + if framework == "jax": + return _get_model_haiku() + raise ValueError(f"Unrecognised model framework: 
'{framework}'.") + + +def _get_model_numpy() -> SklearnSGDModel: + """Return a linear model with MSE loss in Sklearn, with zero weights.""" + np.random.seed(SEED) # set seed + model = SklearnSGDModel.from_parameters( + kind="regressor", loss="squared_error", penalty="none" + ) + return model + + +def _get_model_tflow() -> TensorflowModel: + """Return a linear model with MSE loss in TensorFlow, with zero weights.""" + tf.random.set_seed(SEED) # set seed + tfmod = tf.keras.Sequential(tf.keras.layers.Dense(units=1)) + tfmod.build([None, 100]) + model = TensorflowModel(tfmod, loss="mean_squared_error") + zeros = { + key: tf.zeros_like(val) + for key, val in model.get_weights().coefs.items() + } + model.set_weights(TensorflowVector(zeros)) + return model + + +def _get_model_torch() -> TorchModel: + """Return a linear model with MSE loss in Torch, with zero weights.""" + torch.manual_seed(SEED) # set seed + torchmod = torch.nn.Sequential( + torch.nn.Linear(100, 1, bias=True), + torch.nn.Flatten(0), + ) + model = TorchModel(torchmod, loss=torch.nn.MSELoss()) + zeros = { + key: torch.zeros_like(val) + for key, val in model.get_weights().coefs.items() + } + model.set_weights(TorchVector(zeros)) + return model + + +def _get_model_haiku() -> HaikuModel: + """Return a linear model with MSE loss in Haiku, with zero weights.""" + model = HaikuModel(haiku_model_fn, loss=haiku_loss_fn) + model.initialize({"data_type": "float32", "features_shape": (100,)}) + zeros = { + key: jax.numpy.zeros_like(val) + for key, val in model.get_weights().coefs.items() + } + model.set_weights(JaxNumpyVector(zeros)) return model @@ -139,48 +184,13 @@ def get_dataset(framework: FrameworkType, inputs, labels): inputs = torch.from_numpy(inputs) labels = torch.from_numpy(labels) return TorchDataset(torch.utils.data.TensorDataset(inputs, labels)) - return InMemoryDataset(inputs, labels) - - -def prep_client_datasets( - framework: FrameworkType, - clients: int = 3, - n_train: int = 100, - n_valid: int = 50, -) -> List[Tuple[InMemoryDataset, InMemoryDataset]]: - """Generate and split toy data for a regression problem. - - Parameters - ---------- - clients: int, default=3 - Number of clients, i.e. of dataset shards to generate. - n_train: int, default=30 - Number of training samples per client. - n_valid: int, default=30 - Number of validation samples per client. - - Returns - ------- - datasets: list[(InMemoryDataset, InMemoryDataset)] - List of client-wise (train, valid) pair of datasets. - """ - - n_samples = clients * (n_train + n_valid) - # false-positive; pylint: disable=unbalanced-tuple-unpacking - inputs, target = make_regression( - n_samples, n_features=100, n_informative=10, random_state=SEED - ) - inputs, target = inputs.astype("float32"), target.astype("float32") - # Wrap up the data into client-wise pairs of dataset. 
-    out = []  # type: List[Tuple[InMemoryDataset, InMemoryDataset]]
-    for idx in range(clients):
-        start = (n_train + n_valid) * idx
-        mid = start + n_train
-        end = mid + n_valid
-        train = get_dataset(framework, inputs[start:mid], target[start:mid])
-        valid = get_dataset(framework, inputs[mid:end], target[mid:end])
-        out.append((train, valid))
-    return out
+    if framework == "tensorflow":
+        inputs = tf.convert_to_tensor(inputs)
+        labels = tf.convert_to_tensor(labels)
+        return TensorflowDataset(
+            tf.data.Dataset.from_tensor_slices((inputs, labels))
+        )
+    return InMemoryDataset(inputs, labels, expose_data_type=True)
 
 
 def prep_full_dataset(
@@ -204,7 +214,7 @@
     """
     n_samples = n_train + n_valid
     # false-positive; pylint: disable=unbalanced-tuple-unpacking
-    inputs, target = make_regression(
+    inputs, target = sklearn.datasets.make_regression(
         n_samples, n_features=100, n_informative=10, random_state=SEED
     )
     inputs, target = inputs.astype("float32"), target.astype("float32")
@@ -216,62 +226,150 @@
     return out
 
 
-def test_declearn_experiment(
+def test_sklearn_baseline(
+    lrate: float = 0.04,
+    rounds: int = 10,
+    b_size: int = 10,
+) -> None:
+    """Run a baseline using scikit-learn to emulate a centralized setting.
+
+    This function does not use declearn. It sets up a single sklearn
+    model and performs training on the full dataset.
+
+    Parameters
+    ----------
+    lrate: float, default=0.04
+        Learning rate of the SGD algorithm.
+    rounds: int, default=10
+        Number of training rounds to perform, i.e. number of epochs.
+    b_size: int, default=10
+        Batch size for the training (and validation) data.
+        Batching will be performed without shuffling nor replacement,
+        and the final batch may be smaller than the others (no drop).
+    """
+    # Generate the client datasets, then centralize them into numpy arrays.
+    train, valid = prep_full_dataset()
+    # Set up a scikit-learn model, implementing step-wise gradient descent.
+    sgd = sklearn.linear_model.SGDRegressor(
+        loss="squared_error",
+        penalty="l1",
+        alpha=0.1,
+        eta0=lrate / b_size,  # adjust learning rate for (dropped) batch size
+        learning_rate="constant",  # disable scaling, unused in declearn
+        max_iter=rounds,
+    )
+    # Iteratively train the model, evaluating it after each epoch.
+    for _ in range(rounds):
+        sgd.partial_fit(train[0], train[1])
+    assert sgd.score(valid[0], valid[1]) > R2_THRESHOLD
+
+
+def test_declearn_baseline(
     framework: FrameworkType,
-    lrate: float = 0.01,
+    lrate: float = 0.02,
     rounds: int = 10,
-    b_size: int = 1,
-    clients: int = 3,
+    b_size: int = 10,
 ) -> None:
-    """Run an experiment using declearn to perform a federative training.
+    """Run a baseline using declearn APIs to emulate a centralized setting.
 
-    This function runs the experiment using declearn.
-    It sets up and runs the server and client-wise routines in separate
-    processes, to enable their concurrent execution.
+    This function uses declearn but sets up a single model and performs
+    training on the entire toy regression dataset.
 
     Parameters
     ----------
     framework: str
         Framework of the model to train and evaluate.
-    lrate: float, default=0.01
-        Learning rate of the SGD algorithm performed by clients.
+    lrate: float, default=0.02
+        Learning rate of the SGD algorithm.
     rounds: int, default=10
-        Number of FL training rounds to perform.
-        At each round, each client will perform a full epoch of training.
+        Number of training rounds to perform, i.e. number of epochs.
     b_size: int, default=10
         Batch size fot the training (and validation) data.
         Batching will be performed without shuffling nor replacement,
         and the final batch may be smaller than the others (no drop).
-    clients: int, default=3
-        Number of federated clients to set up and run.
     """
-    # pylint: disable=too-many-locals
-    with tempfile.TemporaryDirectory() as folder:
-        # Set up a (func, args) tuple specifying the server process.
-        p_server = (
-            _server_routine,
-            (folder, framework, lrate, rounds, b_size, clients),
+    # Generate the client datasets, then centralize them into numpy arrays.
+    train, valid = prep_full_dataset()
+    dst_train = get_dataset(framework, *train)
+    # Set up a declearn model and an SGD optimizer with Lasso regularization.
+    model = get_model(framework)
+    model.initialize(dataclasses.asdict(dst_train.get_data_specs()))
+    optim = Optimizer(
+        lrate=lrate if framework != "numpy" else (lrate * 2),
+        regularizers=[("lasso", {"alpha": 0.1})],
+    )
+    # Iteratively train the model and evaluate it between rounds.
+    r_sq = RSquared()
+    scores = []  # type: List[float]
+    for _ in range(rounds):
+        for batch in dst_train.generate_batches(
+            batch_size=b_size, drop_remainder=False
+        ):
+            optim.run_train_step(model, batch)
+        pred = model.compute_batch_predictions((*valid, None))
+        r_sq.reset()
+        r_sq.update(*pred)
+        scores.append(r_sq.get_result()["r2"])  # type: ignore
+    # Check that the R2 increased through epochs to reach a high value.
+    print(scores)
+    assert all(scores[i + 1] >= scores[i] for i in range(rounds - 1))
+    assert scores[-1] >= R2_THRESHOLD
+
+
+def prep_client_datasets(
+    framework: FrameworkType,
+    clients: int = 3,
+    n_train: int = 100,
+    n_valid: int = 50,
+) -> List[Tuple[Dataset, Dataset]]:
+    """Generate and split data for a toy sparse regression problem.
+
+    Parameters
+    ----------
+    framework:
+        Name of the framework being tested, based on which the Dataset
+        class choice may be adjusted as well.
+    clients:
+        Number of clients, i.e. of dataset shards to generate.
+    n_train:
+        Number of training samples per client.
+    n_valid:
+        Number of validation samples per client.
+
+    Returns
+    -------
+    datasets:
+        List of client-wise tuple of (train, valid) Dataset instances.
+    """
+    train, valid = prep_full_dataset(
+        n_train=clients * n_train,
+        n_valid=clients * n_valid,
+    )
+    # Wrap up the data into client-wise pairs of dataset.
+    out = []  # type: List[Tuple[Dataset, Dataset]]
+    for idx in range(clients):
+        # Gather the client's training dataset.
+        srt = n_train * idx
+        end = n_train + srt
+        dst_train = get_dataset(
+            framework=framework,
+            inputs=train[0][srt:end],
+            labels=train[1][srt:end],
         )
-        # Set up the (func, args) tuples specifying client-wise processes.
-        datasets = prep_client_datasets(framework, clients)
-        p_client = []
-        for i, data in enumerate(datasets):
-            client = (_client_routine, (data[0], data[1], f"client_{i}"))
-            p_client.append(client)
-        # Run each and every process in parallel.
-        success, outputs = run_as_processes(p_server, *p_client)
-        assert success, "The FL process failed:\n" + "\n".join(
-            str(exc) for exc in outputs if isinstance(exc, RuntimeError)
+        # Gather the client's validation dataset.
+        srt = n_valid * idx
+        end = n_valid + srt
+        dst_valid = get_dataset(
+            framework=framework,
+            inputs=valid[0][srt:end],
+            labels=valid[1][srt:end],
         )
-        # Assert convergence
-        with open(f"{folder}/metrics.json", encoding="utf-8") as file:
-            r2_dict = json.load(file)
-        last_r2_dict = r2_dict.get(max(r2_dict.keys()))
-        final_r2 = float(last_r2_dict.get("r2"))
-        assert final_r2 > R2_THRESHOLD, "The FL training did not converge"
+        # Store both datasets into the output list.
+        out.append((dst_train, dst_valid))
+    return out
 
 
-def _server_routine(
+async def async_run_server(
     folder: str,
     framework: FrameworkType,
     lrate: float = 0.01,
@@ -289,7 +387,7 @@
     optim = FLOptimConfig.from_params(
         aggregator="averaging",
         client_opt={
-            "lrate": lrate,
+            "lrate": lrate if framework != "numpy" else (lrate * 2),
            "regularizers": [("lasso", {"alpha": 0.1})],
        },
        server_opt=1.0,
@@ -299,24 +397,24 @@
         netwk,
         optim,
         metrics=["r2"],
-        checkpoint=folder,
+        checkpoint={"folder": folder, "max_history": 1},
     )
     # Set up hyper-parameters and run training.
     config = FLRunConfig.from_params(
         rounds=rounds,
-        register={"min_clients": clients},
+        register={"min_clients": clients, "timeout": 10},
         training={
             "n_epoch": 1,
             "batch_size": b_size,
             "drop_remainder": False,
         },
     )
-    server.run(config)
+    await server.async_run(config)
 
 
-def _client_routine(
-    train: InMemoryDataset,
-    valid: InMemoryDataset,
+async def async_run_client(
+    train: Dataset,
+    valid: Dataset,
     name: str = "client",
 ) -> None:
     """Routine to run a FL client, called by `run_declearn_experiment`."""
@@ -324,83 +422,62 @@
         protocol="websockets", server_uri="ws://localhost:8765", name=name
     )
     client = FederatedClient(netwk, train, valid)
-    client.run()
-
-
-def test_declearn_baseline(
-    lrate: float = 0.01,
-    rounds: int = 10,
-    b_size: int = 1,
-) -> None:
-    """Run a baseline uing declearn APIs to emulate a centralized setting.
+    await client.async_run()
 
-    This function uses declearn but sets up a single model and performs
-    training on the concatenation of "client-wise" datasets.
 
-    Parameters
-    ----------
-    lrate: float, default=0.01
-        Learning rate of the SGD algorithm.
-    rounds: int, default=10
-        Number of training rounds to perform, i.e. number of epochs.
-    b_size: int, default=10
-        Batch size fot the training (and validation) data.
-        Batching will be performed without shuffling nor replacement,
-        and the final batch may be smaller than the others (no drop).
-    """
-    # Generate the client datasets, then centralize them into numpy arrays.
-    train, valid = prep_full_dataset()
-    d_train = InMemoryDataset(train[0], train[1])
-    # Set up a declearn model and a vanilla SGD optimizer.
-    model = get_model("numpy")
-    model.initialize({"features_shape": (d_train.data.shape[1],)})
-    opt = Optimizer(lrate=lrate, regularizers=[("lasso", {"alpha": 0.1})])
-    # Iteratively train the model, evaluating it after each epoch.
-    for _ in range(rounds):
-        # Run the training round.
-        for batch in d_train.generate_batches(batch_size=b_size):
-            grads = model.compute_batch_gradients(batch)
-            opt.apply_gradients(model, grads)
-    # Check the final R2 value.
-    r_sq = RSquared()
-    r_sq.update(*model.compute_batch_predictions((valid[0], valid[1], None)))
-    assert r_sq.get_result()["r2"] > R2_THRESHOLD
-
-
-def test_sklearn_baseline(
+@pytest.mark.asyncio
+async def test_declearn_federated(
+    framework: FrameworkType,
     lrate: float = 0.01,
     rounds: int = 10,
     b_size: int = 1,
+    clients: int = 3,
 ) -> None:
-    """Run a baseline using scikit-learn to emulate a centralized setting.
+    """Run an experiment using declearn to perform a federative training.
 
-    This function does not use declearn. It sets up a single sklearn
-    model and performs training on the full dataset.
+    This function runs the experiment using declearn.
+    It sets up and runs the server and client-wise routines in separate
+    processes, to enable their concurrent execution.
 
     Parameters
     ----------
+    framework: str
+        Framework of the model to train and evaluate.
     lrate: float, default=0.01
-        Learning rate of the SGD algorithm.
+        Learning rate of the SGD algorithm performed by clients.
     rounds: int, default=10
-        Number of training rounds to perform, i.e. number of epochs.
+        Number of FL training rounds to perform.
+        At each round, each client will perform a full epoch of training.
     b_size: int, default=10
         Batch size fot the training (and validation) data.
         Batching will be performed without shuffling nor replacement,
         and the final batch may be smaller than the others (no drop).
+    clients: int, default=3
+        Number of federated clients to set up and run.
     """
-    # Generate the client datasets, then centralize them into numpy arrays.
-
-    train, valid = prep_full_dataset()
-    # Set up a scikit-learn model, implementing step-wise gradient descent.
-    sgd = SGDRegressor(
-        loss="squared_error",
-        penalty="l1",
-        alpha=0.1,
-        eta0=lrate / b_size,  # adjust learning rate for (dropped) batch size
-        learning_rate="constant",  # disable scaling, unused in declearn
-        max_iter=rounds,
-    )
-    # Iteratively train the model, evaluating it after each epoch.
-    for _ in range(rounds):
-        sgd.partial_fit(train[0], train[1])
-    assert sgd.score(valid[0], valid[1]) > R2_THRESHOLD
+    datasets = prep_client_datasets(framework, clients)
+    with tempfile.TemporaryDirectory() as folder:
+        # Set up the server and client coroutines.
+        coro_server = async_run_server(
+            folder, framework, lrate, rounds, b_size, clients
+        )
+        coro_clients = [
+            async_run_client(train, valid, name=f"client_{i}")
+            for i, (train, valid) in enumerate(datasets)
+        ]
+        # Run the coroutines concurrently using asyncio.
+        outputs = await asyncio.gather(
+            coro_server, *coro_clients, return_exceptions=True
+        )
+        # Assert that no exceptions occurred during the process.
+        errors = "\n".join(
+            repr(exc) for exc in outputs if isinstance(exc, Exception)
+        )
+        assert not errors, f"The FL process failed:\n{errors}"
+        # Assert that the federated model converged above an expected value.
+        with open(
+            os.path.join(folder, "metrics.json"), encoding="utf-8"
+        ) as file:
+            metrics = json.load(file)
+        best_r2 = max(values["r2"] for values in metrics.values())
+        assert best_r2 >= R2_THRESHOLD, "The FL training did not converge"
diff --git a/test/main/test_data_info.py b/test/main/test_data_info.py
new file mode 100644
index 0000000000000000000000000000000000000000..de7aa14761870e9a8b0024a21bd69cdccfce1804
--- /dev/null
+++ b/test/main/test_data_info.py
@@ -0,0 +1,95 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for 'declearn.main.utils.aggregate_clients_data_info'."""
+
+from unittest import mock
+
+import pytest
+
+from declearn.main.utils import AggregationError, aggregate_clients_data_info
+
+
+class TestAggregateClientsDataInfo:
+    """Unit tests for 'declearn.main.utils.aggregate_clients_data_info'."""
+
+    def test_with_valid_inputs(self) -> None:
+        """Test 'aggregate_clients_data_info' with valid inputs."""
+        clients_data_info = {
+            "client_a": {"n_samples": 10},
+            "client_b": {"n_samples": 32},
+        }
+        results = aggregate_clients_data_info(clients_data_info, {"n_samples"})
+        assert results == {"n_samples": 42}
+
+    def test_with_missing_fields(self) -> None:
+        """Test 'aggregate_clients_data_info' with some missing fields."""
+        clients_data_info = {
+            "client_a": {"n_samples": 10},
+            "client_b": {"n_samples": 32},
+        }
+        with pytest.raises(AggregationError):
+            aggregate_clients_data_info(
+                clients_data_info,
+                required_fields={"n_samples", "features_shape"},
+            )
+
+    def test_with_invalid_values(self) -> None:
+        """Test 'aggregate_clients_data_info' with some invalid values."""
+        clients_data_info = {
+            "client_a": {"n_samples": 10},
+            "client_b": {"n_samples": -1},
+        }
+        with pytest.raises(AggregationError):
+            aggregate_clients_data_info(
+                clients_data_info, required_fields={"n_samples"}
+            )
+
+    def test_with_incompatible_values(self) -> None:
+        """Test 'aggregate_clients_data_info' with some incompatible values."""
+        clients_data_info = {
+            "client_a": {"features_shape": (100,)},
+            "client_b": {"features_shape": (128,)},
+        }
+        with pytest.raises(AggregationError):
+            aggregate_clients_data_info(
+                clients_data_info, required_fields={"features_shape"}
+            )
+
+    def test_with_unexpected_keyerror(self) -> None:
+        """Test 'aggregate_clients_data_info' with an unforeseen KeyError."""
+        with mock.patch(
+            "declearn.main.utils._data_info.aggregate_data_info"
+        ) as patch_agg:
+            patch_agg.side_effect = KeyError("Forced KeyError")
+            with pytest.raises(AggregationError):
+                aggregate_clients_data_info(
+                    clients_data_info={"client_a": {}, "client_b": {}},
+                    required_fields=set(),
+                )
+
+    def test_with_unexpected_exception(self) -> None:
+        """Test 'aggregate_clients_data_info' with an unforeseen Exception."""
+        with mock.patch(
+            "declearn.main.utils._data_info.aggregate_data_info"
+        ) as patch_agg:
+            patch_agg.side_effect = Exception("Forced Exception")
+            with pytest.raises(AggregationError):
+                aggregate_clients_data_info(
+                    clients_data_info={"client_a": {}, "client_b": {}},
+                    required_fields=set(),
+                )
diff --git a/test/main/test_early_stopping.py b/test/main/test_early_stopping.py
new file mode 100644
index 0000000000000000000000000000000000000000..764b59ff75b96bec975a1cc067f85b13394bc247
--- /dev/null
+++ b/test/main/test_early_stopping.py
@@ -0,0 +1,108 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for 'declearn.main.utils.EarlyStopping'."""
+
+
+from declearn.main.utils import EarlyStopping
+
+
+class TestEarlyStopping:
+    """Unit tests for 'declearn.main.utils.EarlyStopping'."""
+
+    def test_keep_training_initial(self) -> None:
+        """Test that a brand new EarlyStopping indicates to train."""
+        early_stop = EarlyStopping()
+        assert early_stop.keep_training
+
+    def test_update_first(self) -> None:
+        """Test that an instantiated EarlyStopping's update works."""
+        early_stop = EarlyStopping()
+        keep_going = early_stop.update(1.0)
+        assert keep_going
+        assert keep_going == early_stop.keep_training
+
+    def test_update_twice(self) -> None:
+        """Test that an EarlyStopping can be reached in a simple case."""
+        early_stop = EarlyStopping(tolerance=0.0, patience=1, decrease=True)
+        assert early_stop.update(1.0)
+        assert not early_stop.update(1.0)
+        assert not early_stop.keep_training
+
+    def test_reset_after_stopping(self) -> None:
+        """Test that 'EarlyStopping.reset()' works properly."""
+        # Reach the criterion once.
+        early_stop = EarlyStopping(tolerance=0.0, patience=1, decrease=True)
+        assert early_stop.update(1.0)
+        assert not early_stop.update(1.0)
+        assert not early_stop.keep_training
+        # Reset and test that the criterion has been properly reset.
+        early_stop.reset()
+        assert early_stop.keep_training
+        assert early_stop.update(1.0)
+        # Reach the criterion for the second time.
+        assert not early_stop.update(1.0)
+        assert not early_stop.keep_training
+
+    def test_with_two_steps_patience(self) -> None:
+        """Test an EarlyStopping criterion with 2-steps patience."""
+        early_stop = EarlyStopping(tolerance=0.0, patience=2, decrease=True)
+        assert early_stop.update(1.0)
+        assert early_stop.update(1.5)  # patience tempers stopping
+        assert early_stop.update(0.0)  # patience is reset
+        assert early_stop.update(0.5)  # patience tempers stopping
+        assert not early_stop.update(0.2)  # patience is exhausted
+
+    def test_with_absolute_tolerance_positive(self) -> None:
+        """Test an EarlyStopping criterion with 0.2 absolute tolerance."""
+        early_stop = EarlyStopping(tolerance=0.2, patience=1, decrease=True)
+        assert early_stop.update(1.0)
+        assert early_stop.update(0.7)
+        assert not early_stop.update(0.6)  # progress below tolerance
+
+    def test_with_absolute_tolerance_negative(self) -> None:
+        """Test an EarlyStopping criterion with -0.5 absolute tolerance."""
+        early_stop = EarlyStopping(tolerance=-0.5, patience=1, decrease=True)
+        assert early_stop.update(1.0)
+        assert early_stop.update(1.2)  # regression below tolerance
+        assert not early_stop.update(1.6)  # regression above tolerance
+
+    def test_with_relative_tolerance_positive(self) -> None:
+        """Test an EarlyStopping criterion with 0.1 relative tolerance."""
+        early_stop = EarlyStopping(
+            tolerance=0.1, patience=1, decrease=True, relative=True
+        )
+        assert early_stop.update(1.0)
+        assert early_stop.update(0.8)  # progress above tolerance
+        assert not early_stop.update(0.75)  # progress below tolerance
+
+    def test_with_relative_tolerance_negative(self) -> None:
+        """Test an EarlyStopping criterion with -0.1 relative tolerance."""
+        early_stop = EarlyStopping(
+            tolerance=-0.1, patience=1, decrease=True, relative=True
+        )
+        assert early_stop.update(1.0)
+        assert early_stop.update(0.80)  # progress
+        assert early_stop.update(0.85)  # regression below tolerance
+        assert not early_stop.update(0.89)  # regression above tolerance
+
+    def test_with_increasing_metric(self) -> None:
+        """Test an EarlyStopping that monitors an increasing metric."""
+        early_stop = EarlyStopping(tolerance=0.0, patience=1, decrease=False)
+        assert early_stop.update(1.0)
+        assert early_stop.update(2.0)  # progress
+        assert not early_stop.update(1.5)  # regression (no patience/tolerance)
diff --git a/test/metrics/test_mae_mse.py b/test/metrics/test_mae_mse.py
index cd55e880d6561de0d5b33636be8cb57e4f21fdb0..de472efdf1fa668752e61e26b14c6ff9b187a29c 100644
--- a/test/metrics/test_mae_mse.py
+++ b/test/metrics/test_mae_mse.py
@@ -24,7 +24,7 @@
 import numpy as np
 import pytest
 from declearn.metrics import MeanAbsoluteError, MeanSquaredError, Metric
-from declearn.test_utils import make_importable
+from declearn.test_utils import assert_dict_equal, make_importable
 
 # relative imports from `metric_testing.py`
 with make_importable(os.path.dirname(__file__)):
@@ -69,7 +69,7 @@ class MeanMetricTestSuite(MetricTestSuite):
     """Unit tests suite for `MeanMetric` subclasses."""
 
     def test_update_errors(self, test_case: MetricTestCase) -> None:
-        """Test that `update` raises on improper `s_wght` shapes."""
+        """Test that `update` raises on improper input shapes."""
         metric = test_case.metric
         inputs = test_case.inputs
         # Test with multi-dimensional sample weights.
@@ -80,12 +80,34 @@ class MeanMetricTestSuite(MetricTestSuite):
         s_wght = np.ones(shape=(len(inputs["y_pred"]) + 2,))
         with pytest.raises(ValueError):
             metric.update(inputs["y_true"], inputs["y_pred"], s_wght)
+        # Test with mismatching-shape inputs.
+        y_true = inputs["y_true"]
+        y_pred = np.stack([inputs["y_pred"], inputs["y_pred"]], axis=-1)
+        with pytest.raises(ValueError):
+            metric.update(y_true, y_pred, s_wght)
 
     def test_zero_result(self, test_case: MetricTestCase) -> None:
         """Test that `get_results` works with zero-valued divisor."""
         metric = test_case.metric
         assert metric.get_result() == {metric.name: 0.0}
 
+    def test_update_expanded_shape(self, test_case: MetricTestCase) -> None:
+        """Test that the metric supports expanded-dim input predictions."""
+        # Gather states with basic inputs.
+        metric, inputs = test_case.metric, test_case.inputs
+        metric.update(**inputs)
+        states = metric.get_states()
+        metric.reset()
+        # Do the same with expanded-dim predictions.
+        metric.update(
+            inputs["y_true"],
+            np.expand_dims(inputs["y_pred"], -1),
+            inputs.get("s_wght"),
+        )
+        st_bis = metric.get_states()
+        # Verify that results are the same.
+        assert_dict_equal(states, st_bis)
+
 
 @pytest.mark.parametrize("weighted", [False, True], ids=["base", "weighted"])
 @pytest.mark.parametrize("case", ["mae"])
diff --git a/test/optimizer/test_modules.py b/test/optimizer/test_modules.py
index 38c209380273f46bed6e44707e20edf53f95b8f1..479dec67480787ba9b9dc8aca0a88df6aea56d48 100644
--- a/test/optimizer/test_modules.py
+++ b/test/optimizer/test_modules.py
@@ -38,6 +38,7 @@ import os
 from typing import Type
 
 import pytest
+from declearn.optimizer import list_optim_modules
 from declearn.optimizer.modules import NoiseModule, OptiModule
 from declearn.test_utils import (
     FrameworkType,
@@ -46,7 +47,7 @@ from declearn.test_utils import (
     assert_json_serializable_dict,
     make_importable,
 )
-from declearn.utils import access_types_mapping, set_device_policy
+from declearn.utils import set_device_policy
 
 # relative imports from `optim_testing.py`
 with make_importable(os.path.dirname(__file__)):
@@ -54,7 +55,7 @@ with make_importable(os.path.dirname(__file__)):
     from optim_testing import PluginTestBase
 
 
 # Access the list of modules to test; remove some that have dedicated tests.
-OPTIMODULE_SUBCLASSES = access_types_mapping(group="OptiModule")
+OPTIMODULE_SUBCLASSES = list_optim_modules()
 OPTIMODULE_SUBCLASSES.pop("tensorflow-optim", None)
 OPTIMODULE_SUBCLASSES.pop("torch-optim", None)
diff --git a/test/optimizer/test_regularizers.py b/test/optimizer/test_regularizers.py
index 424a9df2af1c6db551a2b18c550e3ff7412cea39..9b5c21e34a18e8d00328e28330b32644082cb117 100644
--- a/test/optimizer/test_regularizers.py
+++ b/test/optimizer/test_regularizers.py
@@ -38,16 +38,16 @@ from typing import Type
 
 import pytest
 
+from declearn.optimizer import list_optim_regularizers
 from declearn.optimizer.regularizers import Regularizer
 from declearn.test_utils import make_importable
-from declearn.utils import access_types_mapping
 
 # relative imports from `optim_testing.py`
 with make_importable(os.path.dirname(__file__)):
     from optim_testing import PluginTestBase
 
 
-REGULARIZER_SUBCLASSES = access_types_mapping(group="Regularizer")
+REGULARIZER_SUBCLASSES = list_optim_regularizers()
 
 
 @pytest.mark.parametrize(
diff --git a/test/quickrun/test_quickrun_utils.py b/test/quickrun/test_quickrun_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bec273314210fc8e8fd14d20a35a845bcf632d2e
--- /dev/null
+++ b/test/quickrun/test_quickrun_utils.py
@@ -0,0 +1,266 @@
+# coding: utf-8
+
+"""Tests for some 'declearn.quickrun' backend utils."""
+
+import os
+import pathlib
+from typing import List
+
+import pytest
+
+from declearn.quickrun import parse_data_folder
+from declearn.quickrun._config import DataSourceConfig
+from declearn.quickrun._parser import (
+    get_data_folder_path,
+    list_client_names,
+)
+from declearn.quickrun._run import get_toml_folder
+
+
+class TestGetTomlFolder:
+    """Tests for 'declearn.quickrun._run.get_toml_folder'."""
+
+    def test_get_toml_folder_from_file(
+        self,
+        tmp_path: pathlib.Path,
+    ) -> None:
+        """Test that 'get_toml_folder' works with a TOML file path."""
+        config = os.path.join(tmp_path, "config.toml")
+        with open(config, "w", encoding="utf-8") as file:
+            file.write("")
+        toml, folder = get_toml_folder(config)
+        assert toml == config
+        assert folder == tmp_path.as_posix()
+
+    def test_get_toml_folder_from_folder(
+        self,
+        tmp_path: pathlib.Path,
+    ) -> None:
+        """Test that 'get_toml_folder' works with a folder path."""
+        config = os.path.join(tmp_path, "config.toml")
+        with open(config, "w", encoding="utf-8") as file:
+            file.write("")
+        toml, folder = get_toml_folder(tmp_path.as_posix())
+        assert toml == config
+        assert folder == tmp_path.as_posix()
+
+    def test_get_toml_folder_from_file_fails(
+        self,
+        tmp_path: pathlib.Path,
+    ) -> None:
+        """Test that it fails with a path to a non-existing file."""
+        config = os.path.join(tmp_path, "config.toml")
+        with pytest.raises(FileNotFoundError):
+            get_toml_folder(config)
+
+    def test_get_toml_folder_from_folder_fails(
+        self,
+        tmp_path: pathlib.Path,
+    ) -> None:
+        """Test that it fails with a folder lacking a 'config.toml' file."""
+        with pytest.raises(FileNotFoundError):
+            get_toml_folder(tmp_path.as_posix())
+
+
+class TestGetDataFolderPath:
+    """Tests for 'declearn.quickrun._parser.get_data_folder_path'."""
+
+    def test_get_data_folder_path_from_data_folder(
+        self,
+        tmp_path: pathlib.Path,
+    ) -> None:
+        """Test that it works with a valid 'data_folder' argument."""
+        path = get_data_folder_path(data_folder=tmp_path.as_posix())
+        assert isinstance(path, pathlib.Path)
+        assert path == tmp_path
+
+    def test_get_data_folder_path_from_root_folder(
+        self,
+        tmp_path: pathlib.Path,
+    ) -> None:
+ """Test that it works with a valid 'root_folder' argument.""" + data_dir = os.path.join(tmp_path, "data") + os.makedirs(data_dir) + path = get_data_folder_path(root_folder=tmp_path.as_posix()) + assert isinstance(path, pathlib.Path) + assert path.as_posix() == data_dir + + def test_get_data_folder_path_fails_no_inputs( + self, + ) -> None: + """Test that it fails with no folder specification.""" + with pytest.raises(ValueError): + get_data_folder_path(None, None) + + def test_get_data_folder_path_fails_invalid_data_folder( + self, + tmp_path: pathlib.Path, + ) -> None: + """Test that it fails with an invalid data_folder.""" + missing_folder = os.path.join(tmp_path, "data") + with pytest.raises(ValueError): + get_data_folder_path(data_folder=missing_folder) + + def test_get_data_folder_path_fails_from_root_no_data( + self, + tmp_path: pathlib.Path, + ) -> None: + """Test that it fails with an invalid root_folder (no data).""" + with pytest.raises(ValueError): + get_data_folder_path(root_folder=tmp_path.as_posix()) + + def test_get_data_folder_path_fails_from_root_multiple_data( + self, + tmp_path: pathlib.Path, + ) -> None: + """Test that it fails with multiple data* under root_folder.""" + os.makedirs(os.path.join(tmp_path, "data_1")) + os.makedirs(os.path.join(tmp_path, "data_2")) + with pytest.raises(ValueError): + get_data_folder_path(root_folder=tmp_path.as_posix()) + + +class TestListClientNames: + """Tests for the 'declearn.quickrun._parser.list_client_names' function.""" + + def test_list_client_names_from_folder( + self, + tmp_path: pathlib.Path, + ) -> None: + """Test that it works with a data folder.""" + os.makedirs(os.path.join(tmp_path, "client_1")) + os.makedirs(os.path.join(tmp_path, "client_2")) + names = list_client_names(tmp_path) + assert isinstance(names, list) + assert sorted(names) == ["client_1", "client_2"] + + def test_list_client_names_from_names( + self, + tmp_path: pathlib.Path, + ) -> None: + """Test that it works with pre-specified names.""" + os.makedirs(os.path.join(tmp_path, "client_1")) + os.makedirs(os.path.join(tmp_path, "client_2")) + names = list_client_names(tmp_path, ["client_2"]) + assert names == ["client_2"] + + def test_list_client_names_fails_invalid_names( + self, + tmp_path: pathlib.Path, + ) -> None: + """Test that it works with invalid pre-specified names.""" + with pytest.raises(ValueError): + list_client_names(tmp_path, "invalid-type") # type: ignore + with pytest.raises(ValueError): + list_client_names(tmp_path, ["client_2"]) + + +class TestParseDataFolder: + """Docstring.""" + + @staticmethod + def setup_data_folder( + data_folder: str, + client_names: List[str], + file_names: List[str], + ) -> None: + """Set up a data folder, with client subfolders and empty files.""" + for cname in client_names: + folder = os.path.join(data_folder, cname) + os.makedirs(folder) + for fname in file_names: + path = os.path.join(folder, fname) + with open(path, "w", encoding="utf-8") as file: + file.write("") + + def test_parse_data_folder_with_default_names( + self, + tmp_path: pathlib.Path, + ) -> None: + """Test 'parse_data_folder' with default file names.""" + # Setup a data folder with a couple of clients and default files. + data_folder = tmp_path.as_posix() + client_names = ["client-1", "client-2"] + file_names = [ + # fmt: off + "train_data", "train_target", "valid_data", "valid_target" + ] + self.setup_data_folder(data_folder, client_names, file_names) + # Write up the expected outputs. 
+        expected = {
+            cname: {
+                fname: os.path.join(data_folder, cname, fname)
+                for fname in file_names
+            }
+            for cname in client_names
+        }
+        # Run the function and validate its outputs.
+        config = DataSourceConfig(
+            data_folder=data_folder,
+            client_names=None,
+            dataset_names=None,
+        )
+        clients = parse_data_folder(config)
+        assert clients == expected
+
+    def test_parse_data_folder_with_custom_names(
+        self,
+        tmp_path: pathlib.Path,
+    ) -> None:
+        """Test 'parse_data_folder' with custom file names."""
+        # Setup a data folder with a couple of clients and custom-named files.
+        data_folder = tmp_path.as_posix()
+        client_names = ["client-1", "client-2"]
+        base_names = [
+            # fmt: off
+            "train_data", "train_target", "valid_data", "valid_target"
+        ]
+        file_names = ["x_train", "y_train", "x_valid", "y_valid"]
+        self.setup_data_folder(data_folder, client_names, file_names)
+        # Write up the expected outputs.
+        expected = {
+            cname: {
+                bname: os.path.join(data_folder, cname, fname)
+                for bname, fname in zip(base_names, file_names)
+            }
+            for cname in client_names
+        }
+        # Run the function and validate its outputs.
+        config = DataSourceConfig(
+            data_folder=data_folder,
+            client_names=None,
+            dataset_names=dict(zip(base_names, file_names)),
+        )
+        clients = parse_data_folder(config)
+        assert clients == expected
+        # Verify that it would not work without the argument.
+        config = DataSourceConfig(
+            data_folder=data_folder,
+            client_names=None,
+            dataset_names=None,
+        )
+        with pytest.raises(ValueError):
+            clients = parse_data_folder(config)
+
+    def test_parse_data_folder_fails_multiple_files(
+        self,
+        tmp_path: pathlib.Path,
+    ) -> None:
+        """Test that 'parse_data_folder' fails with same-name files."""
+        # Setup a data folder with a couple of clients and duplicated files.
+        data_folder = tmp_path.as_posix()
+        client_names = ["client-1", "client-2"]
+        file_names = [
+            # fmt: off
+            "train_data", "train_target", "valid_data", "valid_target",
+            "train_data.bis"  # duplicated name prefix
+        ]
+        self.setup_data_folder(data_folder, client_names, file_names)
+        # Verify that the expected exception is raised.
+        config = DataSourceConfig(
+            data_folder=data_folder,
+            client_names=None,
+            dataset_names=None,
+        )
+        with pytest.raises(ValueError):
+            parse_data_folder(config)
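Editor's note on the asyncio-based orchestration introduced in the rewritten `test_declearn_federated` (test/functional/test_regression.py above): the former process-based runner is replaced by coroutines gathered on a single event loop. The standalone sketch below illustrates only that `asyncio.gather(..., return_exceptions=True)` pattern; `run_server` and `run_client` are hypothetical stand-ins that merely sleep and are not declearn APIs.

import asyncio


async def run_server(rounds: int) -> str:
    # Hypothetical stand-in for the server routine: emulate per-round work.
    for _ in range(rounds):
        await asyncio.sleep(0.01)
    return "server done"


async def run_client(name: str, rounds: int) -> str:
    # Hypothetical stand-in for a client routine: emulate per-round work.
    for _ in range(rounds):
        await asyncio.sleep(0.01)
    return f"{name} done"


async def main() -> None:
    # Gather the server and client coroutines concurrently, as the test does.
    coros = [run_server(rounds=3)]
    coros += [run_client(f"client_{i}", rounds=3) for i in range(3)]
    outputs = await asyncio.gather(*coros, return_exceptions=True)
    # Collect exceptions (if any) instead of letting them propagate.
    errors = [repr(exc) for exc in outputs if isinstance(exc, Exception)]
    assert not errors, "\n".join(errors)


if __name__ == "__main__":
    asyncio.run(main())

Using `return_exceptions=True` means a failing participant surfaces as a collected error rather than cancelling its peers mid-round, which is why the test asserts on the gathered outputs before checking the recorded metrics.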