diff --git a/declearn/dataset/_inmemory.py b/declearn/dataset/_inmemory.py index 650f2c88f402af14a49f624501285b9d978c5ff3..3d1f57368631015b624427b045fa714fdf178205 100644 --- a/declearn/dataset/_inmemory.py +++ b/declearn/dataset/_inmemory.py @@ -18,7 +18,8 @@ """Dataset implementation to serve scikit-learn compatible in-memory data.""" import os -from typing import Any, Dict, Iterator, List, Optional, Set, Union +import typing +from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union import numpy as np import pandas as pd @@ -29,7 +30,7 @@ from typing_extensions import Self # future: import from typing (py >=3.11) from declearn.dataset._base import Dataset, DataSpecs from declearn.dataset.utils import load_data_array, save_data_array -from declearn.typing import Batch +from declearn.typing import Batch, DataArray from declearn.utils import json_dump, json_load, register_type __all__ = [ @@ -37,7 +38,7 @@ __all__ = [ ] -DataArray = Union[np.ndarray, pd.DataFrame, spmatrix] +DATA_ARRAY_TYPES = typing.get_args(DataArray) @register_type(group="Dataset") @@ -87,90 +88,173 @@ class InMemoryDataset(Dataset): an instance that is either a numpy ndarray, a pandas DataFrame or a scipy spmatrix. - See the `load_data_array` function in dataset._utils for details - on supported file formats. + See the `load_data_array` function in `dataset.utils` + for details on supported file formats. Parameters ---------- - data: data array or str + data: Main data array which contains input features (and possibly more), or path to a dump file from which it is to be loaded. - target: data array or str or None, default=None - Optional data array containing target labels (for supervised - learning), or path to a dump file from which to load it. - If `data` is a pandas DataFrame (or a path to a csv file), - `target` may be the name of a column to use as labels (and - thus not to use as input feature unless listed in `f_cols`). - s_wght: int or str or function or None, default=None - Optional data array containing sample weights, or path to a - dump file from which to load it. - If `data` is a pandas DataFrame (or a path to a csv file), - `s_wght` may be the name of a column to use as labels (and - thus not to use as input feature unless listed in `f_cols`). - f_cols: list[int] or list[str] or None, default=None - Optional list of columns in `data` to use as input features - (other columns will not be included in the first array of - the batches yielded by `self.generate_batches(...)`). - expose_classes: bool, default=False - Whether the dataset should be used for classification, in - which case the unique values of `target` are exposed under - `self.classes` and exported by `self.get_data_specs()`). - expose_data_type: bool, default=False - Whether the dataset should expose the data type , in - which case check if the type is unique and exposed it - under `self.data_type` and exported by `self.get_data_specs()`). - seed: int or None, default=None - Optional seed for the random number generator based on which - the dataset is (optionally) shuffled when generating batches. + target: + Optional target labels, as a data array, or as a path to a + dump file, or as the name of a `data` column. + s_wght: + Optional sample weights, as a data array, or as a path to a + dump file, or as the name of a `data` column. + f_cols: + Optional list of columns in `data` to use as input features. + These may be specified as column names or indices. If None, + use all non-target, non-sample-weights columns of `data`. 
+ + Other parameters + ---------------- + expose_classes: + Whether to expose unique target values as part of data specs. + This should only be used for classification datasets. + expose_data_type: + Whether to expose features' dtype, which will be verified to + be unique, as part of data specs. + seed: + Optional seed for the random number generator used for all + randomness-based operations required to generate batches + (e.g. to shuffle the data or sample from it). """ # arguments serve modularity; pylint: disable=too-many-arguments - self._data_path = None # type: Optional[str] - self._trgt_path = None # type: Optional[str] # Assign the main data array. - if isinstance(data, str): - self._data_path = data - data = load_data_array(data) - self.data = data + data_array, src_path = self._parse_data_argument(data) + self.data = data_array + self._data_path = src_path # Assign the optional input features list. - self.f_cols = f_cols + self.f_cols = self._parse_fcols_argument(f_cols, data=self.data) # Assign the (optional) target data array. - if isinstance(target, str): - self._trgt_path = target - if ( - isinstance(self.data, pd.DataFrame) - and target in self.data.columns - ): - if f_cols is None: - self.f_cols = self.f_cols or list(self.data.columns) - self.f_cols.remove(target) # type: ignore - target = self.data[target] - else: - target = load_data_array(target) - if ( - isinstance(target, pd.DataFrame) - and len(target.columns) == 1 - ): - target = target.iloc[:, 0] - self.target = target + data_array, src_path = self._parse_array_or_column_argument( + value=target, data=self.data, name="target" + ) + self.target = data_array + self._trgt_path = src_path + if self.f_cols and src_path and src_path in self.f_cols: + self.f_cols.remove(src_path) # type: ignore[arg-type] # Assign the (optional) sample weights data array. - if isinstance(s_wght, str): - self._wght_path = s_wght - if isinstance(self.data, pd.DataFrame): - if s_wght in self.data.columns: - if f_cols is None: - self.f_cols = self.f_cols or list(self.data.columns) - self.f_cols.remove(s_wght) # type: ignore - s_wght = self.data[s_wght] - else: - s_wght = load_data_array(s_wght) - self.weights = s_wght - # Assign the 'expose_classes' attribute. + data_array, src_path = self._parse_array_or_column_argument( + value=s_wght, data=self.data, name="s_wght" + ) + self.weights = data_array + self._wght_path = src_path + if self.f_cols and src_path and src_path in self.f_cols: + self.f_cols.remove(src_path) # type: ignore[arg-type] + # Assign the 'expose_classes' and 'expose_data_type' attributes. self.expose_classes = expose_classes self.expose_data_type = expose_data_type # Assign a random number generator. self.seed = seed self._rng = np.random.default_rng(seed) + @staticmethod + def _parse_data_argument( + data: Union[DataArray, str], + ) -> Tuple[DataArray, Optional[str]]: + """Parse 'data' instantiation argument. + + Return the definitive 'data' array, and its source path if any. + """ + # Case when an array is provided directly. + if isinstance(data, DATA_ARRAY_TYPES): + return data, None + # Case when an invalid type is provided. + if not isinstance(data, str): + raise TypeError( + f"'data' must be a data array or str, not '{type(data)}'." + ) + # Case when a string is provided: treat it as a file path. + try: + array = load_data_array(data) + except Exception as exc: + raise ValueError( + "Error while trying to load main 'data' array from file." 
+            ) from exc
+        return array, data
+
+    @staticmethod
+    def _parse_fcols_argument(
+        f_cols: Union[List[str], List[int], None],
+        data: DataArray,
+    ) -> Union[List[str], List[int], None]:
+        """Type and value-check 'f_cols' argument.
+
+        Return its definitive value or raise an exception.
+        """
+        # Case when 'f_cols' is None: optionally replace with list of names.
+        if f_cols is None:
+            if isinstance(data, pd.DataFrame):
+                return list(data.columns)
+            return f_cols
+        # Case when 'f_cols' has an invalid type.
+        if not isinstance(f_cols, (list, tuple, set)):
+            raise TypeError(
+                f"'f_cols' must be None or a list, not '{type(f_cols)}'."
+            )
+        # Case when 'f_cols' is a list of str: verify and return it.
+        if all(isinstance(col, str) for col in f_cols):
+            if not isinstance(data, pd.DataFrame):
+                raise ValueError(
+                    "'f_cols' is a list of str but 'data' is not a DataFrame."
+                )
+            if set(f_cols).issubset(data.columns):
+                return f_cols.copy()
+            raise ValueError(
+                "Specified 'f_cols' is not a subset of 'data' columns."
+            )
+        # Case when 'f_cols' is a list of int: verify and return it.
+        if all(isinstance(col, int) for col in f_cols):
+            if max(f_cols) >= data.shape[1]:  # type: ignore
+                raise ValueError(
+                    "Invalid 'f_cols' indices given 'data' shape."
+                )
+            return f_cols.copy()
+        # Case when 'f_cols' has mixed or invalid internal types.
+        raise TypeError(
+            "'f_cols' should be a list of all-int or all-str values."
+        )
+
+    @staticmethod
+    def _parse_array_or_column_argument(
+        value: Union[DataArray, str, None],
+        data: DataArray,
+        name: str,
+    ) -> Tuple[DataArray, Optional[str]]:
+        """Parse an optional 'target' or 's_wght' argument.
+
+        Return the parsed data array (or None), together with its source
+        path or column name when relevant (optional string).
+        """
+        # Case of a data array or None value: return as-is.
+        if isinstance(value, DATA_ARRAY_TYPES) or value is None:
+            return value, None
+        # Case of an invalid type: raise.
+        if not isinstance(value, str):
+            raise TypeError(
+                f"'{name}' must be a data array or str, not '{type(value)}'."
+            )
+        # Case of a string matching a 'data' column name: return it.
+        if isinstance(data, pd.DataFrame) and value in data.columns:
+            return data[value], value
+        # Case of a string matching nothing.
+        if not os.path.isfile(value):
+            raise ValueError(
+                f"'{name}' does not match any 'data' column nor file path."
+            )
+        # Case of a string matching a filepath.
+        try:
+            array = load_data_array(value)
+        except Exception as exc:
+            raise ValueError(
+                f"Error while trying to load '{name}' data from file."
+            ) from exc
+        if isinstance(array, pd.DataFrame) and len(array.columns) == 1:
+            array = array.iloc[:, 0]
+        return array, value
+
     @property
     def feats(
         self,
diff --git a/declearn/dataset/examples/_mnist.py b/declearn/dataset/examples/_mnist.py
index a839d2b28a48328283ec5be459868421d639d7d5..030db9522ad3ceaf3b001b91262da27d53f6eb99 100644
--- a/declearn/dataset/examples/_mnist.py
+++ b/declearn/dataset/examples/_mnist.py
@@ -71,7 +71,10 @@ def _load_mnist_data(
     name = f"{name}-ubyte.gz"
     # Optionally download the gzipped file from the internet.
     if folder is None or not os.path.isfile(os.path.join(folder, name)):
-        data = _download_mnist_file(name, folder)
+        try:
+            data = _download_mnist_file(name, folder, lecun=True)
+        except RuntimeError:  # pragma: no cover
+            data = _download_mnist_file(name, folder, lecun=False)
         data = gzip.decompress(data)
     # Otherwise, read its contents from a local copy.
else: @@ -88,13 +91,20 @@ def _load_mnist_data( return (array / 255).astype(np.single) if images else array -def _download_mnist_file(name: str, folder: Optional[str]) -> bytes: +def _download_mnist_file( + name: str, + folder: Optional[str], + lecun: bool = True, +) -> bytes: """Download a MNIST source file and opt. save it in a given folder.""" # Download the file in memory. print(f"Downloading MNIST source file {name}.") - reply = requests.get( - f"http://yann.lecun.com/exdb/mnist/{name}", timeout=300 + base_url = ( + "http://yann.lecun.com/exdb" + if lecun + else "https://ossci-datasets.s3.amazonaws.com" ) + reply = requests.get(f"{base_url}/mnist/{name}", timeout=300) try: reply.raise_for_status() except requests.HTTPError as exc: diff --git a/declearn/dataset/utils/_save_load.py b/declearn/dataset/utils/_save_load.py index 90d936f09e9a549c3bc1617e617c1eb84a960456..efb0ef4a21fd7cf858a22f4e5a7fb2a2a7782f0c 100644 --- a/declearn/dataset/utils/_save_load.py +++ b/declearn/dataset/utils/_save_load.py @@ -22,11 +22,12 @@ import os from typing import Any, Union import numpy as np -import pandas as pd # type: ignore +import pandas as pd from scipy.sparse import spmatrix # type: ignore from sklearn.datasets import load_svmlight_file # type: ignore from declearn.dataset.utils._sparse import sparse_from_file, sparse_to_file +from declearn.typing import DataArray __all__ = [ "load_data_array", @@ -34,9 +35,6 @@ __all__ = [ ] -DataArray = Union[np.ndarray, pd.DataFrame, spmatrix] - - def load_data_array( path: str, **kwargs: Any, diff --git a/declearn/dataset/utils/_split_classif.py b/declearn/dataset/utils/_split_classif.py index 0e3dbc98d348141e27b52f80605f1eb0176cc2c2..7f254188b03e71c5dfab1e2e3b80047fefcb6f42 100644 --- a/declearn/dataset/utils/_split_classif.py +++ b/declearn/dataset/utils/_split_classif.py @@ -17,9 +17,11 @@ """Utils to split a multi-category classification dataset into shards.""" -from typing import List, Literal, Optional, Tuple, Type, Union +import functools +from typing import Any, List, Literal, Optional, Tuple, Type, Union import numpy as np +import scipy.stats # type: ignore from scipy.sparse import csr_matrix, spmatrix # type: ignore @@ -31,9 +33,10 @@ __all__ = [ def split_multi_classif_dataset( dataset: Tuple[Union[np.ndarray, spmatrix], np.ndarray], n_shards: int, - scheme: Literal["iid", "labels", "biased"], + scheme: Literal["iid", "labels", "dirichlet", "biased"], p_valid: float = 0.2, seed: Optional[int] = None, + **kwargs: Any, ) -> List[Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]]: """Split a classification dataset into (opt. heterogeneous) shards. @@ -42,6 +45,9 @@ def split_multi_classif_dataset( - If "iid", split the dataset through iid random sampling. - If "labels", split into shards that hold all samples associated with mutually-exclusive target classes. + - If "dirichlet", split the dataset through random sampling using + label-wise shard-assignment probabilities drawn from a symmetrical + Dirichlet distribution, parametrized by an `alpha` parameter. - If "biased", split the dataset through random sampling according to a shard-specific random labels distribution. @@ -53,13 +59,17 @@ def split_multi_classif_dataset( be a scipy sparse matrix, that will temporarily be cast to CSR. n_shards: int Number of shards between which to split the dataset. - scheme: {"iid", "labels", "biased"} + scheme: {"iid", "labels", "dirichlet", "biased"} Splitting scheme to use. 
In all cases, shards contain mutually- exclusive samples and cover the full dataset. See details above. p_valid: float, default=0.2 Share of each shard to turn into a validation subset. seed: int or None, default=None Optional seed to the RNG used for all sampling operations. + **kwargs: + Additional hyper-parameters specific to the split scheme. + Exhaustive list of possible values: + - `alpha: float = 0.5` for `scheme="dirichlet"` Returns ------- @@ -80,6 +90,10 @@ def split_multi_classif_dataset( func = split_iid elif scheme == "labels": func = split_labels + elif scheme == "dirichlet": + func = functools.partial( + split_dirichlet, alpha=kwargs.get("alpha", 0.5) + ) elif scheme == "biased": func = split_biased else: @@ -157,7 +171,14 @@ def split_biased( n_shards: int, rng: np.random.Generator, ) -> List[Tuple[np.ndarray, np.ndarray]]: - """Split a dataset into shards with heterogeneous label distributions.""" + """Split a dataset into shards with heterogeneous label distributions. + + Use a normal distribution to draw logits of labels distributions for + each and every node. + + This approach is not based on the litterature. We advise end-users to + use a Dirichlet split instead, which is probably better-grounded. + """ classes = np.unique(target) index = np.arange(len(target)) s_len = len(target) // n_shards @@ -178,6 +199,36 @@ def split_biased( return split +def split_dirichlet( + inputs: Union[np.ndarray, csr_matrix], + target: np.ndarray, + n_shards: int, + rng: np.random.Generator, + alpha: float = 0.5, +) -> List[Tuple[np.ndarray, np.ndarray]]: + """Split a dataset into shards with heterogeneous label distributions. + + Use a symmetrical multinomial Dirichlet(alpha) distribution to sample + the proportion of samples per label in each shard. + + This approach has notably been used by Sturluson et al. (2021). + FedRAD: Federated Robust Adaptive Distillation. arXiv:2112.01405 [cs.LG] + """ + classes = np.unique(target) + # Draw per-label proportion of samples to assign to each shard. + process = scipy.stats.dirichlet(alpha=[alpha] * n_shards) + c_probs = process.rvs(size=len(classes), random_state=rng) + # Randomly assign label-wise samples to shards based on these. + shard_i = [[] for _ in range(n_shards)] # type: List[List[int]] + for lab_i, label in enumerate(classes): + index = np.where(target == label)[0] + s_idx = rng.choice(n_shards, size=len(index), p=c_probs[lab_i]) + for i in range(n_shards): + shard_i[i].extend(index[s_idx == i]) + # Gather the actual sample shards. + return [(inputs[index], target[index]) for index in shard_i] + + def train_valid_split( inputs: np.ndarray, target: np.ndarray, diff --git a/declearn/model/haiku/__init__.py b/declearn/model/haiku/__init__.py index 1adb8f1f2d0d9ec2593a3dcacc1d25bd808070cf..596d4c5c25ac11b7bf6be54e61c7642dbca6af12 100644 --- a/declearn/model/haiku/__init__.py +++ b/declearn/model/haiku/__init__.py @@ -17,9 +17,21 @@ """Haiku models interfacing tools. -This submodule provides with a generic interface to wrap up -any Haiku module instance that is to be trained -through gradient descent. +Haiku is a Google DeepMind library that provides with tools to build +artificial neural networks backed by the JAX computation library. We +selected it as a primary candidate to support using JAX-backed models, +mostly because of its simplicity, that leaves apart some components +that DecLearn already provides (such as optimization algorithms). 
+ +In July 2023, Haiku development was announced to be stalled as far as +new features are concerned, in favor of Flax, a concurrent Google project. + +DecLearn is planned to add support for Flax at some point (building on the +existing Haiku-oriented code, notably as far as Jax NumPy is concerned). +In the meanwhile, this submodule enables running code that operates using +haiku, which probably does not cover a lot of use cases, but it bound to +keep working at least for a while, until Google decides to drop maintenance +altogether. This module exposes: * HaikuModel: Model subclass to wrap haiku.Model objects diff --git a/declearn/model/haiku/_vector.py b/declearn/model/haiku/_vector.py index be4aedb74bd9b8aaa52e978668bf60fe48063233..3c97a67f9f01341e46fe8497cd151b9b8544de9a 100644 --- a/declearn/model/haiku/_vector.py +++ b/declearn/model/haiku/_vector.py @@ -17,8 +17,7 @@ """JaxNumpyVector data arrays container.""" -import warnings -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Set, Tuple, Type, Union import jax import jax.numpy as jnp @@ -153,19 +152,9 @@ class JaxNumpyVector(Vector): def sum( self, - axis: Optional[int] = None, - keepdims: bool = False, ) -> Self: - if isinstance(axis, int) or keepdims: - warnings.warn( # pragma: no cover - "The 'axis' and 'keepdims' arguments of 'JaxNumpyVector.sum' " - "have been deprecated as of declearn v2.3, and will be " - "removed in version 2.5 and/or 3.0.", - DeprecationWarning, - ) coefs = { - key: jnp.array(jnp.sum(val, axis=axis, keepdims=keepdims)) - for key, val in self.coefs.items() + key: jnp.array(jnp.sum(val)) for key, val in self.coefs.items() } return self.__class__(coefs) diff --git a/declearn/model/sklearn/_np_vec.py b/declearn/model/sklearn/_np_vec.py index 40a8805ebd6010cbd12e1d142e966b9167a245a5..6ed6af7919a07f5e46b5b340acb1e5a594d1f6bb 100644 --- a/declearn/model/sklearn/_np_vec.py +++ b/declearn/model/sklearn/_np_vec.py @@ -17,8 +17,7 @@ """NumpyVector data arrays container.""" -import warnings -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Tuple, Union import numpy as np from typing_extensions import Self # future: import from typing (Py>=3.11) @@ -121,20 +120,8 @@ class NumpyVector(Vector): def sum( self, - axis: Optional[int] = None, - keepdims: bool = False, ) -> Self: - if isinstance(axis, int) or keepdims: - warnings.warn( # pragma: no cover - "The 'axis' and 'keepdims' arguments of 'NumpyVector.sum' " - "have been deprecated as of declearn v2.3, and will be " - "removed in version 2.5 and/or 3.0.", - DeprecationWarning, - ) - coefs = { - key: np.array(np.sum(val, axis=axis, keepdims=keepdims)) - for key, val in self.coefs.items() - } + coefs = {key: np.array(np.sum(val)) for key, val in self.coefs.items()} return self.__class__(coefs) def flatten( diff --git a/declearn/model/tensorflow/_vector.py b/declearn/model/tensorflow/_vector.py index ef0a0d17861ad50f8c6851252b5152af155f6ddb..146754298402788b76ac5a133b3dcb8ccd94c0a1 100644 --- a/declearn/model/tensorflow/_vector.py +++ b/declearn/model/tensorflow/_vector.py @@ -18,10 +18,7 @@ """TensorflowVector data arrays container.""" import warnings -from typing import ( - # fmt: off - Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union -) +from typing import Any, Callable, Dict, List, Set, Tuple, Type, TypeVar, Union # fmt: off import numpy as np @@ -279,27 +276,8 @@ class TensorflowVector(Vector): def sum( 
self, - axis: Optional[int] = None, - keepdims: bool = False, ) -> Self: - if keepdims or (axis is not None): - if any( # pragma: no cover - isinstance(x, tf.IndexedSlices) for x in self.coefs.values() - ): - warnings.warn( # pragma: no cover - "Calling `TensorflowVector.sum()` with non-default " - "arguments and tf.IndexedSlices coefficients might " - "result in unexpected outputs, due to the latter " - "being converted to their dense counterpart.", - category=RuntimeWarning, - ) - warnings.warn( # pragma: no cover - "The 'axis' and 'keepdims' arguments of 'TensorflowVector.sum'" - " have been deprecated as of declearn v2.3, and will be " - "removed in version 2.5 and/or 3.0.", - DeprecationWarning, - ) - return self.apply_func(tf.reduce_sum, axis=axis, keepdims=keepdims) + return self.apply_func(tf.reduce_sum) def __pow__( self, diff --git a/declearn/model/tensorflow/utils/_gpu.py b/declearn/model/tensorflow/utils/_gpu.py index a6f7f9f31bc9ce4d700a76fae5cc4e4e4497c20b..f9b86ccf6e30da6461bf928228a81c17952ae6f1 100644 --- a/declearn/model/tensorflow/utils/_gpu.py +++ b/declearn/model/tensorflow/utils/_gpu.py @@ -59,7 +59,7 @@ def select_device( Returns ------- - device: tf.config.LogicalDevice + device: Selected device, usable as `tf.device` argument. """ idx = 0 if idx is None else idx @@ -106,14 +106,15 @@ def move_layer_to_device( Returns ------- - layer: tf_keras.layers.Layer + layer: Copy of the input layer, with its weights backed on `device`. """ config = tf_keras.layers.serialize(layer) - weights = layer.get_weights() + weights = layer.get_weights() if layer.built else None with tf.device(device): layer = tf_keras.layers.deserialize(config) - layer.set_weights(weights) + if weights: + layer.set_weights(weights) return layer diff --git a/declearn/model/torch/_vector.py b/declearn/model/torch/_vector.py index de9447956b13fe1996615a56a60b9d4c33b85fdf..40fcc58ec26b6b7c938c6a665ac4d88f1b95aa0a 100644 --- a/declearn/model/torch/_vector.py +++ b/declearn/model/torch/_vector.py @@ -17,8 +17,7 @@ """TorchVector data arrays container.""" -import warnings -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Set, Tuple, Type, Union import numpy as np import torch @@ -214,20 +213,8 @@ class TorchVector(Vector): def sum( self, - axis: Optional[int] = None, - keepdims: bool = False, ) -> Self: - if isinstance(axis, int) or keepdims: - warnings.warn( # pragma: no cover - "The 'axis' and 'keepdims' arguments of 'TorchVector.sum' " - "have been deprecated as of declearn v2.3, and will be " - "removed in version 2.5 and/or 3.0.", - DeprecationWarning, - ) - coefs = { - key: val.sum(dim=axis, keepdims=keepdims) - for key, val in self.coefs.items() - } + coefs = {key: val.sum() for key, val in self.coefs.items()} return self.__class__(coefs) def flatten( diff --git a/declearn/typing.py b/declearn/typing.py index c7b1c6429a093c14fc6110de2c8acbe220668cc8..916b1a718d8177e8ced2d165936cb176ca59070d 100644 --- a/declearn/typing.py +++ b/declearn/typing.py @@ -20,12 +20,16 @@ from abc import ABCMeta, abstractmethod from typing import Any, Dict, List, Optional, Protocol, Tuple, Union +import numpy as np +import pandas as pd from numpy.typing import ArrayLike +from scipy.sparse import spmatrix # type: ignore from typing_extensions import Self # future: import from typing (Py>=3.11) __all__ = [ "Batch", + "DataArray", "SupportsConfig", ] @@ -44,6 +48,15 @@ This type-hint designates (inputs, labels, weights) inputs, where: """ # this is 
rendered as a docstring for `Batch` in the docs
 
 
+DataArray = Union[np.ndarray, pd.DataFrame, pd.Series, spmatrix]
+"""Type-annotation alias for a union of data type structures.
+
+This alias covers types supported by [declearn.dataset.utils.save_data_array][]
+and its counterpart [declearn.dataset.utils.load_data_array][], and is hence
+used to annotate some dataset-interfacing tools under [declearn.dataset][].
+"""
+
+
 class SupportsConfig(Protocol, metaclass=ABCMeta):
     """Protocol for type annotation of objects with get/from_config methods.
 
diff --git a/declearn/version.py b/declearn/version.py
index ed20f25a5ca69ae11b600b2e902c4fc82afba255..93721d5083e5fbea3219d4d9eb0947c2adb45be7 100644
--- a/declearn/version.py
+++ b/declearn/version.py
@@ -17,5 +17,5 @@
 
 """DecLearn version information, as hard-coded constants."""
 
-VERSION = "2.4.0"
+VERSION = "2.5.0"
 """Version information of the installed DecLearn package."""
diff --git a/docs/release-notes/SUMMARY.md b/docs/release-notes/SUMMARY.md
index 568e1343869a8b2bcc974c538c5f1e7371c0fc10..6da663077a6734430c23d628a2ff73841170ab6b 100644
--- a/docs/release-notes/SUMMARY.md
+++ b/docs/release-notes/SUMMARY.md
@@ -1,3 +1,4 @@
+- [v2.5.0](v2.5.0.md)
 - [v2.4.0](v2.4.0.md)
 - [v2.3.2](v2.3.2.md)
 - [v2.3.1](v2.3.1.md)
diff --git a/docs/release-notes/v2.5.0.md b/docs/release-notes/v2.5.0.md
new file mode 100644
index 0000000000000000000000000000000000000000..27a4e923066459157ccb6334ef1279007ea96467
--- /dev/null
+++ b/docs/release-notes/v2.5.0.md
@@ -0,0 +1,91 @@
+# declearn v2.5.0
+
+Released: 13/05/2024
+
+## Release Highlight: Secure Aggregation
+
+### Overview
+
+This new version of DecLearn is mostly about enabling the use of Secure
+Aggregation (also known as SecAgg), _i.e._ methods that enable aggregating
+client-emitted information without revealing said information to the server
+in charge of this aggregation.
+
+DecLearn now implements both a generic API for SecAgg and a couple of practical
+solutions that are ready for use. This makes SecAgg easily usable as part of
+the existing federated learning process, and extensible by advanced users or
+researchers who would like to use or test their own method and/or setup.
+
+New features are mostly implemented under the new `declearn.secagg` submodule,
+with further changes to `declearn.main.FederatedClient` and `FederatedServer`
+integrating these features into the main process.
+
+### Usage and scope
+
+Setting up SecAgg is meant to be straightforward:
+
+- The server and clients agree on a SecAgg method and, when required, some
+  hyper-parameters in advance.
+- Clients need to hold a private Ed25519 identity key and share the associated
+  public key with all other clients in advance, so that the SecAgg setup can
+  include the verification that ephemeral-key-agreement information comes from
+  trusted peers. _(This is mandatory in current implementations, but may be
+  removed in custom setup implementations.)_
+- The server and clients must simply pass an additional `secagg` keyword
+  argument when instantiating their `FederatedServer` or `FederatedClient`,
+  which can take the form of a class, a dict or the path to a TOML file.
+- _Voilà !_
+
+At the moment, SecAgg is used to secure the aggregation of model parameter
+updates, optimizer auxiliary variables and evaluation metrics, as well as
+some metadata from training and evaluation rounds. In the future, we plan
+to cover metadata queries and (yet-to-be-implemented) federated analytics.
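+
+As a minimal client-side sketch of this `secagg` argument (file paths are
+illustrative, other `FederatedClient` arguments are unchanged, and the
+[SecAgg guide](../user-guide/secagg.md) provides a full walkthrough):
+
+```python
+import declearn
+from declearn.secagg import parse_secagg_config_client
+from declearn.secagg.utils import IdentityKeys
+
+# Parse the client-side SecAgg configuration (here, the masking-based scheme).
+secagg = parse_secagg_config_client(
+    secagg_type="masking",
+    id_keys=IdentityKeys(
+        prv_key="private_0.pem",  # this client's private identity key
+        trusted="trusted_public.keys",  # trusted peers' public identity keys
+    ),
+)
+
+# Pass it as an additional keyword argument to the federated client.
+client = declearn.main.FederatedClient(
+    # ... other arguments, unchanged from a SecAgg-free setup,
+    secagg=secagg,
+)
+```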
+
+### Available algorithms
+
+At the moment, DecLearn provides with the following SecAgg algorithms:
+
+- Masking-based SecAgg (`declearn.secagg.masking`), that uses pseudo-random
+  number generators (PRNG) to generate masks over a finite integer field so
+  that the sum of clients' masks is known to be zero.
+  - This is based on
+    [Bonawitz et al., 2016](https://dl.acm.org/doi/10.1145/3133956.3133982).
+  - The setup that produces pairwise PRNG seeds is conducted using the
+    [X3DH](https://www.signal.org/docs/specifications/x3dh/) protocol.
+  - This solution has very limited computation and communication overhead
+    and should be considered the default SecAgg solution with DecLearn.
+
+- Joye-Libert sum-homomorphic encryption (`declearn.secagg.joye_libert`), that
+  uses actual encryption, a modified summation operator, and
+  aggregate-decryption primitives that operate on a large biprime-defined
+  integer field.
+  - This is based on
+    [Joye & Libert, 2013](https://marcjoye.github.io/papers/JL13aggreg.pdf).
+  - The setup that computes the public key as a sum of arbitrary private keys
+    involves the [X3DH](https://www.signal.org/docs/specifications/x3dh/)
+    protocol as well as
+    [Shamir Secret Sharing](https://dl.acm.org/doi/10.1145/359168.359176).
+  - This solution has a high computation and communication overhead. It is
+    not really suitable for models with many parameters (including few-layer
+    artificial neural networks).
+
+### Documentation
+
+In addition to the extensive in-code documentation, a guide on the SecAgg
+features was added to the user documentation, that may be found
+[here](../user-guide/secagg.md). If anything is not as clear as you would
+hope, do let us know by opening a GitLab or GitHub issue, or by dropping
+an e-mail to the package maintainers!
+
+## Other changes
+
+A few minor changes are shipped with this new release in addition to the new
+SecAgg features:
+
+- `InMemoryDataset` backend code was refactored and made more robust, and its
+  documentation was improved for readability purposes.
+- `declearn.typing.DataArray` was added as an alias.
+- `declearn.dataset.utils.split_multi_classif_dataset` was enhanced to support
+  a new scheme based on Dirichlet allocation. Unit tests were also added for
+  both this scheme and existing ones.
+- Unit tests for both `FederatedClient` and `FederatedServer` were added.
+- Deprecated keyword arguments of `Vector.sum` were removed, as due.
diff --git a/docs/user-guide/SUMMARY.md b/docs/user-guide/SUMMARY.md
index 7ce0e34f4c5bd1eb0610e2dfb4e0c7f380a5edfd..865118d0c1c485e2fcbed0dcf86f34eafd5a0916 100644
--- a/docs/user-guide/SUMMARY.md
+++ b/docs/user-guide/SUMMARY.md
@@ -4,3 +4,4 @@
 - [Hands-on usage](./usage.md)
 - [Guide to the Optimizer API](./optimizer.md)
 - [Local Differential Privacy capabilities](./local_dp.md)
+- [Secure Aggregation capabilities](./secagg.md)
diff --git a/docs/user-guide/index.md b/docs/user-guide/index.md
index 92a2b040d1d651f16543fec069c87b6d62c05733..362050cb74125820c35e2c0302ee6458b9aec886 100644
--- a/docs/user-guide/index.md
+++ b/docs/user-guide/index.md
@@ -16,3 +16,5 @@ This guide is structured this way:
   API, design principles and practical how-tos of the declearn Optimizer.
 - [Local Differential Privacy capabilities](./local_dp.md):<br/>
   Description of the local-DP features of declearn.
+- [Secure Aggregation capabilities](./secagg.md):<br/>
+  Description of the SecAgg features of declearn.
diff --git a/docs/user-guide/package.md b/docs/user-guide/package.md index 76c0a32b03e45aa5d128958d5a308de135b1a00d..ed6eda864f604a20cd34f3e491512204069bd652 100644 --- a/docs/user-guide/package.md +++ b/docs/user-guide/package.md @@ -22,6 +22,8 @@ The package is organized into the following submodules:   Model interfacing API and implementations. - `optimizer`:<br/>   Framework-agnostic optimizer and algorithmic plug-ins API and tools. +- `secagg`:<br/> +   Secure Aggregation API, methods and utils. - `typing`:<br/>   Type hinting utils, defined and exposed for code readability purposes. - `utils`:<br/> @@ -198,6 +200,64 @@ You may learn more about our (non-abstract) `Optimizer` API by reading our - `declearn.dataset.torch.TorchDataset` - Extend: use `declearn.utils.register_type(group="Dataset")`. +### Secure Aggregation (SecAgg) + +#### `SecaggConfigClient` +- Import: `declearn.secagg.api.SecaggConfigClient` +- Object: Set up Secure Aggregation based on wrapped configuration parameters. +- Usage: Set up an `Encrypter` (see below) instance based on parameters and a + server-emitted setup request, jointly with other clients. +- Examples: + - `declearn.secagg.joye_libert.JoyeLibertSecaggConfigClient` + - `declearn.secagg.masking.MaskingSecaggConfigClient` +- Extend: + - Simply inherit from `SecaggConfigClient` (registration is automated). + - To avoid it, use `class MySetup(SecaggConfigClient, register=False)`. + +#### `SecaggConfigServer` +- Import: `declearn.secagg.api.SecaggConfigServer` +- Object: Set up Secure Aggregation based on wrapped configuration parameters. +- Usage: Set up a `Decrypter` (see below) instance based on parameters and in + interaction with (a subset of) clients. +- Examples: + - `declearn.secagg.joye_libert.JoyeLibertSecaggConfigServer` + - `declearn.secagg.masking.MaskingSecaggConfigServer` +- Extend: + - Simply inherit from `SecaggConfigServer` (registration is automated). + - To avoid it, use `class MySetup(SecaggConfigServer, register=False)`. + +#### `Encrypter` +- Import: `declearn.secagg.api.Encrypter` +- Object: Encrypt values that need secure aggregation. +- Usage: Encrypt shared data, typically packed as a `Message` or `Aggregate`. +- Examples: + - `declearn.secagg.joye_libert.JoyeLibertEncrypter` + - `declearn.secagg.masking.MaskingEncrypter` +- Extend: there is no type-registration, simply implement a subclass. + +#### `Decrypter` +- Import: `declearn.secagg.api.Decrypter` +- Object: Decrypt aggregated values to finalize secure aggregation. +- Usage: Decrypt shared data, typically packed as a `Message` or + `SecureAggregate`. +- Examples: + - `declearn.secagg.joye_libert.JoyeLibertDecrypter` + - `declearn.secagg.masking.MaskingDecrypter` +- Extend: there is no type-registration, simply implement a subclass. + +#### `SecureAggregate` +- Import: `declearn.secagg.api.SecureAggregate` +- Object: Wrap up encrypted data from an `Aggregate` (e.g. `ModelUpdates`, + `AuxVar` or `MetricState` instance) and enable their aggregation. +- Usage: Used by `Encrypter` and `Decrypter` to wrap up encrypted data. +- Examples: + - `declearn.secagg.joye_libert.JlsAggregate` + - `declearn.secagg.masking.MaskedAggregate` +- Extend: + - Simply inherit from `SecureAggregate` (registration is automated, + and is about making the class JSON-serializable). + - To avoid it, use `class MyClass(SecureAggregate, register=False)`. 
+
 ## Full API Reference
 
 The full API reference, which is generated automatically from the code's
diff --git a/docs/user-guide/secagg.md b/docs/user-guide/secagg.md
new file mode 100644
index 0000000000000000000000000000000000000000..356219774b7869f83dd074cb6380acd90e90cfc9
--- /dev/null
+++ b/docs/user-guide/secagg.md
@@ -0,0 +1,319 @@
+# Secure Aggregation
+
+## Overview
+
+### What is Secure Aggregation?
+
+Secure Aggregation (often and hereafter abbreviated as SecAgg) is a generic
+term to describe methods that enable aggregating client-emitted information
+without revealing said information to the server in charge of this aggregation.
+In other words, SecAgg is about computing a public aggregate of private values
+in a secure way, limiting the amount of trust put into the server.
+
+Various methods have been proposed in the literature, that may use homomorphic
+encryption, one-time pads, pseudo-random masks, multi-party computation methods
+and so on. These methods come with various costs (both in terms of computation
+overhead and communication costs, usually increasing messages' size and/or
+frequency), and various features (_e.g._ some support a given amount of loss
+of information due to clients dropping from the process; some are better-suited
+to some security settings than others; some require more involved setup...).
+
+### General capabilities
+
+DecLearn implements both a generic API for SecAgg and some practical solutions
+that are ready for use. In the current state of things however, some important
+hypotheses are enforced:
+
+- The current SecAgg implementations require clients to generate Ed25519
+  identity keys and share the associated public keys with other clients prior
+  to using DecLearn. The API also leans towards this requirement. In practice,
+  it would not be difficult to make it so that the server distributes clients'
+  public keys across the network, but we believe that the incurred loss in
+  security partially defeats the purpose of SecAgg, hence we prefer to provide
+  with the current behavior (that requires sharing keys via distinct channels)
+  and leave it up to end-users to set up alternatives if they want to.
+- There is no resilience to clients dropping from the process. If a client was
+  supposed to participate in a round but does not send any information, then
+  some new setup and secure aggregation would need to be performed. This may
+  change in the future, but is on par with the current limitations of the
+  framework.
+- There are no countermeasures to dishonest clients, i.e. there is no effort
+  put in verifying that outputs are coherent. This is not specific to SecAgg
+  either, and is again something that will hopefully be tackled in future
+  versions.
+
+### Details and caveats
+
+At the moment, the SecAgg API and its integration as part of the main federated
+learning orchestration classes have the following characteristics:
+
+- SecAgg is jointly parametrized by the server and clients, that must use
+  coherent parameters.
+  - The choice of using SecAgg and of a given SecAgg method must be the same
+    across all clients and the server, otherwise an error is early-raised.
+  - Some things are up to clients, notably the specification of identity keys
+    used to set up short-lived secrets, and more generally any parameter that
+    cannot be verified to be trustworthy if the server sets it up.
+  - Some things are up to the server, notably the choice of quantization
+    hyper-parameters, that is coupled with details on the expected magnitude
+    of shared quantities (gradients, metrics...).
+- SecAgg is set up anew every time the set of clients participating in a
+  training or validation round changes.
+  - Whenever participating clients do not match those in the previous round,
+    a SecAgg setup round is triggered so as to set up controllers anew across
+    participating clients.
+  - This is designed to guarantee that controllers are properly aligned and
+    there is no aggregation of mismatching encrypted values.
+  - This may however not be as optimal as theoretically-achievable in terms
+    of setup communication and computation costs.
+- SecAgg is used to protect training and evaluation rounds' results.
+  - Model updates, optimizer auxiliary variables and metrics are covered.
+  - Some metadata are secure-aggregated (number of steps), some are discarded
+    (number of epochs), some remain cleartext but are reduced (time spent for
+    computations is sent in cleartext and max-aggregated).
+- SecAgg does not (yet) cover other computations.
+  - Metadata queries are not secured (meaning the number of data samples may
+    be sent in cleartext).
+  - The topic of protecting specific or arbitrary quantities as part of
+    processes implemented by subclassing the current main classes remains
+    open.
+- SecAgg of `Aggregate`-inheriting objects (notably `ModelUpdates`, `AuxVar`
+  and `MetricState` instances) is based on their `prepare_for_secagg` method,
+  that may not always be defined. It is up to end-users to ensure that the
+  components they use (and custom components they might add) are properly
+  made compatible with SecAgg.
+
+## How to set up and use SecAgg
+
+From the end-user perspective, setting up and using SecAgg requires:
+
+- Having generated and shared clients' (long-lived) identity keys across
+  trusted peers prior to running DecLearn.
+- Having all peers specify coherent SecAgg parameters, that are passed to
+  the main `FederatedClient` and `FederatedServer` classes at instantiation.
+- Ensuring that the `Aggregator`, `OptiModule` plug-ins and `Metrics` used
+  for the experiment are compatible with SecAgg.
+- _Voilà !_
+
+### Available SecAgg algorithms
+
+At the moment, DecLearn provides with the following SecAgg algorithms:
+
+- Masking-based SecAgg (`declearn.secagg.masking`), that uses pseudo-random
+  number generators (PRNG) to generate masks over a finite integer field so
+  that the sum of clients' masks is known to be zero.
+  - This is based on
+    [Bonawitz et al., 2016](https://dl.acm.org/doi/10.1145/3133956.3133982).
+  - The setup that produces pairwise PRNG seeds is conducted using the
+    [X3DH](https://www.signal.org/docs/specifications/x3dh/) protocol.
+  - This solution has very limited computation and communication overhead
+    and should be considered the default SecAgg solution with DecLearn.
+
+- Joye-Libert sum-homomorphic encryption (`declearn.secagg.joye_libert`), that
+  uses actual encryption, a modified summation operator, and
+  aggregate-decryption primitives that operate on a large biprime-defined
+  integer field.
+  - This is based on
+    [Joye & Libert, 2013](https://marcjoye.github.io/papers/JL13aggreg.pdf).
+  - The setup that computes the public key as a sum of arbitrary private keys
+    involves the [X3DH](https://www.signal.org/docs/specifications/x3dh/)
+    protocol as well as
+    [Shamir Secret Sharing](https://dl.acm.org/doi/10.1145/359168.359176).
+  - This solution has a high computation and communication overhead. It is
+    not really suitable for models with many parameters (including few-layer
+    artificial neural networks).
+
+### Hands-on example
+
+If we use the [MNIST example](https://gitlab.inria.fr/magnet/declearn/declearn2/-/tree/develop/examples/mnist/)
+implemented via Python scripts and want to use the DecLearn-provided
+masking-based algorithm for SecAgg (see above), we merely have to apply the
+following modifications:
+
+**1. Generate client Ed25519 identity keys**
+
+This may be done using dedicated tools, but here is a script to merely
+generate and dump identity keys to files on a local computer:
+
+```python
+from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey
+from cryptography.hazmat.primitives import serialization
+from declearn.secagg.utils import IdentityKeys
+
+# Generate some private Ed25519 keys and gather their public counterparts.
+n_clients = 5  # adjust to the number of clients
+private_keys = [Ed25519PrivateKey.generate() for _ in range(n_clients)]
+public_keys = [key.public_key() for key in private_keys]
+
+# Export all public keys as a single file with custom format.
+IdentityKeys(private_keys[0], trusted=public_keys).export_trusted_keys_to_file(
+    "trusted_public.keys"
+)
+
+# Export private keys as PEM files without password protection.
+for idx, key in enumerate(private_keys):
+    dat = key.private_bytes(
+        encoding=serialization.Encoding.PEM,
+        format=serialization.PrivateFormat.OpenSSH,
+        encryption_algorithm=serialization.NoEncryption(),
+    )
+    with open(f"private_{idx}.pem", "wb") as file:
+        file.write(dat)
+```
+
+**2. Write client-side SecAgg config**
+
+Add the following code as part of the client routine to set up and run the
+federated learning process. Paths and parameters may and should be adjusted
+for practical use cases.
+
+```python
+from declearn.secagg import parse_secagg_config_client
+from declearn.secagg.utils import IdentityKeys
+
+# Add this as part of the `run_client` function in the `run_client.py` file.
+
+client_idx = int(client_name.rsplit("_", 1)[-1])
+secagg = parse_secagg_config_client(
+    secagg_type="masking",
+    id_keys=IdentityKeys(
+        prv_key=f"private_{client_idx}.pem",
+        trusted="trusted_public.keys",
+    )
+)
+# Alternatively, write `secagg` as the dict of previous kwargs,
+# or as a `declearn.secagg.masking.MaskingSecaggConfigClient` instance.
+
+# Overload the `FederatedClient` instantiation in step (5) of the function.
+
+client = declearn.main.FederatedClient(
+    # ... what is already there
+    secagg=secagg,
+)
+```
+
+**3. Write server-side SecAgg config**
+
+Add the following code as part of the server routine to set up and run the
+federated learning process. Parameters may and should be adjusted for
+practical use cases.
+
+```python
+from declearn.secagg import parse_secagg_config_server
+
+# Add this as part of the `run_server` function in the `run_server.py` file.
+
+secagg = parse_secagg_config_server(
+    secagg_type="masking",
+    # You may tune hyper-parameters, controlling values' quantization.
+)
+# Alternatively, write `secagg` as the dict of previous kwargs,
+# or as a `declearn.secagg.masking.MaskingSecaggConfigServer` instance.
+
+# Overload the `FederatedServer` instantiation in step (4) of the function.
+
+server = declearn.main.FederatedServer(
+    # ...
what is already there + secagg=secagg, +) +``` + +Note that if the configurations do not match (whether as to the use of SecAgg +or not, the choice of algorithm, its hyper-parameters, or the validity of +trusted identity keys), an error will be raised at some (early) point when +attempting to run the Federated Learning process. + +## How to implement a new SecAgg method + +The API for SecAgg relies on a number of abstractions, that may be divided in +two categories: + +- backend controllers that provide with primitives to encrypt, aggregate + and decrypt values; +- user-end controllers that parse configuration parameters and implement + a setup protocol to instantiate backend controllers. + +As such, it is possible to write a new setup for existing controllers; but in +general, adding a new SecAgg algorithm to DecLearn will involve writing it all +up. The first category is somewhat decoupled from the rest of the framework (it +only needs to know what the `Vector` and `Aggregate` data structures are, and +how to operate on them), while the second is very much coupled with the network +communication and messaging APIs. + +### `Encrypter` and `Decrypter` controllers + +Primitives for the encryption of private values, aggregation of encrypted +values and decryption of an aggregated value are to be implemented by a pair +of `Encrypter` and `Decrypter` subclasses. + +These classes may (and often will) use any form of (private) time index, that +is not required to be shared with encrypted values, as values are assumed to be +encrypted in the same order by each and every client and decrypted in that same +order by the server once aggregated. + +`Encrypter` has two abstract methods that need implementing: + +- `encrypt_uint`, that encrypts a scalar uint value (that may arise from the + quantization of a float value). It is called when encrypting floats, numpy + arrays and declearn `Vector` or `Aggregate` instances. +- `wrap_into_secure_aggregate`, that is called to wrap up encrypted values + and other metadata and information from an `Aggregate` instance into a + `SecureAggregate` one. This should in most cases merely correspond to + instantiating the proper `SecureAggregate` subclass with a mix of input + arguments to the method and attributes from the `Encrypter` instance. + +`Decrypter` has an abstract class attribute and two abstract methods that +need implementing: + +- `sum_encrypted`, that aggregates a list of two or more encrypted values that + need aggregation into a single one (in a way that makes their aggregate + decryptable into the sum of cleartext values). +- `decrypt_uint`, that decrypts an input into a scalar uint value. It is called + when decrypting inputs into any supported type, and is the counterpart to the + `Encrypter.encrypt_uint` method. +- `secure_aggregate_cls`, that is a class attribute that is merely the type of + the `SecureAggregate` subclass emitted by the paired `Encrypter`'s + `wrap_into_secure_aggregate` method. + +`SecureAggregate` is a third class that is used by the former two, and usually +not directly accessed by end-users. It is an `Aggregate`-like wrapper for +encrypted counterparts to `Aggregate` objects. The only abstract method that +needs defining is `aggregate_encrypted`, which should mostly be the same as +`Decrypter.sum_encrypted`. Subclasses may also embark additional metadata +about the SecAgg algorithm's parameters, and conduct associate verification +of coherence across peers. 
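+
+As a rough illustration, a custom scheme could be laid out as in the skeleton
+below. The class and method names follow the descriptions above, but the exact
+signatures are indicative only (an assumption of this sketch) and should be
+checked against the abstract base classes when implementing a real scheme:
+
+```python
+from typing import List
+
+from declearn.secagg.api import Decrypter, Encrypter, SecureAggregate
+
+
+class MySecureAggregate(SecureAggregate):
+    """Wrapper for encrypted data of a hypothetical SecAgg scheme."""
+
+    def aggregate_encrypted(
+        self, val_a: List[int], val_b: List[int]
+    ) -> List[int]:
+        # Combine encrypted values so that their eventual decryption
+        # yields the sum of the cleartext values. (Assumed signature.)
+        raise NotImplementedError
+
+
+class MyEncrypter(Encrypter):
+    """Client-side primitives of the hypothetical scheme."""
+
+    def encrypt_uint(self, value: int) -> int:
+        # Encrypt a scalar uint value, e.g. by masking it over a finite
+        # integer field. (Assumed signature.)
+        raise NotImplementedError
+
+    def wrap_into_secure_aggregate(self, *args, **kwargs) -> MySecureAggregate:
+        # Pack encrypted values and metadata into a 'MySecureAggregate'.
+        raise NotImplementedError
+
+
+class MyDecrypter(Decrypter):
+    """Server-side primitives of the hypothetical scheme."""
+
+    secure_aggregate_cls = MySecureAggregate
+
+    def sum_encrypted(self, values: List[int]) -> int:
+        # Aggregate a list of two or more encrypted values into one.
+        # (Assumed signature.)
+        raise NotImplementedError
+
+    def decrypt_uint(self, value: int) -> int:
+        # Decrypt an aggregated value back into a scalar uint.
+        # (Assumed signature.)
+        raise NotImplementedError
+```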
+
+### `SecaggConfigClient` and `SecaggConfigServer` endpoints
+
+Routines to set up matching `Encrypter` and `Decrypter` instances across a
+federated network of peers are to be implemented by a pair of
+`SecaggConfigClient` and `SecaggConfigServer` subclasses. In addition, a
+dedicated `SecaggSetupQuery` subclass (itself a `Message`) should be defined.
+
+These classes define a setup that is bound to be initiated by the server, which
+emits a `SecaggSetupQuery` message triggering the client's call to the setup
+routine. After that, any number of network communication exchanges may be run,
+depending on the setup being implemented.
+
+Both `SecaggConfigClient` and `SecaggConfigServer` subclasses must be decorated
+as `dataclasses.dataclass`, so as to benefit from TOML-parsing capabilities.
+They are automatically type-registered (which may be prevented by passing the
+`register=False` parameter at inheritance).
+
+`SecaggConfigClient` has an abstract class attribute and an abstract method:
+
+- `secagg_type`, that is a string class attribute that must be unique across
+  subclasses (for type-registration) and match that of the server-side class.
+- `setup_encrypter`, that takes a `NetworkClient` and a received serialized
+  `SecaggSetupQuery` message, and conducts any steps towards setting up and
+  returning an `Encrypter` instance.
+
+`SecaggConfigServer` has an abstract class attribute and two abstract methods:
+
+- `secagg_type`, that is a string class attribute that must be unique across
+  subclasses (for type-registration) and match that of the client-side class.
+- `prepare_secagg_setup_query`, that returns a `SecaggSetupQuery` to be sent
+  to (the subset of) clients that are meant to participate in the setup.
+- `finalize_secagg_setup`, that takes a `NetworkServer` and an optional set
+  of client names, is called right after sending the setup query to (these)
+  clients and may thereafter conduct any steps towards setting up and returning
+  a `Decrypter` instance.
diff --git a/pyproject.toml b/pyproject.toml
index 58af1de0fafb70b92141a2de4740a5738b742cd3..68f9dcb10c96f92ee50eebbd85b4834ca38820f0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,7 +8,7 @@ requires = [
 
 [project]
 name = "declearn"
-version = "2.4.0"
+version = "2.5.0"
 description = "Declearn - a python package for private decentralized learning."
 readme = "README.md"
 requires-python = ">=3.8"
diff --git a/test/dataset/test_split_multi_classif.py b/test/dataset/test_split_multi_classif.py
new file mode 100644
index 0000000000000000000000000000000000000000..39659fe3831a29f9ab012621c35e342003bea687
--- /dev/null
+++ b/test/dataset/test_split_multi_classif.py
@@ -0,0 +1,224 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Unit tests for 'declearn.dataset.utils.split_multi_classif_dataset'.""" + +from typing import List, Tuple, Type, Union + +import numpy as np +import pytest +from scipy.sparse import coo_matrix, spmatrix # type: ignore +from scipy.stats import chi2_contingency # type: ignore + +from declearn.dataset.utils import split_multi_classif_dataset + + +Array = Union[np.ndarray, spmatrix] + + +@pytest.fixture(name="dataset", scope="module") +def dataset_fixture( + sparse: bool, +) -> Tuple[Array, np.ndarray]: + """Fixture providing with a random classification dataset.""" + rng = np.random.default_rng(seed=0) + x_dat = rng.normal(size=(400, 32)) + y_dat = rng.choice(4, size=400) + if sparse: + return coo_matrix(x_dat), y_dat + return x_dat, y_dat + + +@pytest.mark.parametrize( + "sparse", [False, True], ids=["dense", "sparse"], scope="module" +) +class TestSplitMultiClassifDataset: + """Unit tests for `split_multi_classif_dataset`.""" + + @classmethod + def assert_expected_shard_shapes( + cls, + dataset: Tuple[Array, np.ndarray], + shards: List[ + Tuple[Tuple[Array, np.ndarray], Tuple[Array, np.ndarray]] + ], + n_shards: int, + p_valid: float, + ) -> None: + """Verify that shards match expected shapes.""" + assert isinstance(shards, list) and len(shards) == n_shards + n_samples = 0 + for shard in shards: + cls._assert_valid_shard( + shard, + n_feats=dataset[0].shape[1], + n_label=len(np.unique(dataset[1])), + p_valid=p_valid, + x_type=type(dataset[0]), + ) + n_samples += shard[0][0].shape[0] + shard[1][0].shape[0] + assert n_samples == dataset[0].shape[0] + + @staticmethod + def _assert_valid_shard( + shard: Tuple[Tuple[Array, np.ndarray], Tuple[Array, np.ndarray]], + n_feats: int = 32, + n_label: int = 4, + p_valid: float = 0.2, + x_type: Union[Type[np.ndarray], Type[spmatrix]] = np.ndarray, + ) -> None: + """Assert that a given dataset shard matches expected specs.""" + # Unpack arrays and verify that their types are coherent. + (x_train, y_train), (x_valid, y_valid) = shard + assert isinstance(x_train, x_type) + assert isinstance(x_valid, x_type) + assert isinstance(y_train, np.ndarray) + assert isinstance(y_valid, np.ndarray) + # Assert that array shapes match expectations. + assert x_train.ndim == x_valid.ndim == 2 + assert x_train.shape[0] == y_train.shape[0] + assert x_valid.shape[0] == y_valid.shape[0] + assert x_train.shape[1] == x_valid.shape[1] == n_feats + assert y_train.ndim == y_valid.ndim == 1 + # Assert that labels have proper values. + labels = list(range(n_label)) + assert np.all(np.isin(y_train, labels)) + assert np.all(np.isin(y_valid, labels)) + # Assert that train/valid partition matches expectation. + s_valid = x_valid.shape[0] / (x_train.shape[0] + x_valid.shape[0]) + assert abs(p_valid - s_valid) <= 0.02 + + @staticmethod + def get_label_counts( + shards: List[ + Tuple[Tuple[Array, np.ndarray], Tuple[Array, np.ndarray]] + ], + n_label: int = 4, + ) -> np.ndarray: + """Return the shard-wise subset-wise label counts, stacked together.""" + counts = [ + np.bincount(y, minlength=n_label) + for (_, y_train), (_, y_valid) in shards + for y in (y_train, y_valid) + ] + return np.stack(counts) + + def test_scheme_iid( + self, + dataset: Tuple[Array, np.ndarray], + ) -> None: + """Test that the iid scheme yields iid samples.""" + shards = split_multi_classif_dataset( + dataset, n_shards=4, scheme="iid", p_valid=0.2, seed=0 + ) + # Verify that shards match expected shapes. 
+ self.assert_expected_shard_shapes( + dataset, shards, n_shards=4, p_valid=0.2 + ) + # Verify that labels are iid-distributed across shards. + # To do so, use a chi2-test with blatantly high acceptance rate + # for the null hypothesis that distributions differ, and verify + # that the hypothesis would still be rejected at that rate. + y_counts = self.get_label_counts(shards) + assert chi2_contingency(y_counts).pvalue >= 0.90 + + def test_scheme_labels( + self, + dataset: Tuple[Array, np.ndarray], + ) -> None: + """Test that the labels scheme yields non-overlapping-labels shards.""" + shards = split_multi_classif_dataset( + dataset, n_shards=2, scheme="labels", p_valid=0.4, seed=0 + ) + # Verify that shards match expected shapes. + self.assert_expected_shard_shapes( + dataset, shards, n_shards=2, p_valid=0.4 + ) + # Verify that labels are distributed without overlap across shards. + labels_train_0 = np.unique(shards[0][0][1]) + labels_valid_0 = np.unique(shards[0][1][1]) + labels_train_1 = np.unique(shards[1][0][1]) + labels_valid_1 = np.unique(shards[1][1][1]) + assert np.all(labels_train_0 == labels_valid_0) + assert np.all(labels_train_1 == labels_valid_1) + assert np.intersect1d(labels_train_0, labels_train_1).shape == (0,) + + def test_scheme_biased( + self, + dataset: Tuple[Array, np.ndarray], + ) -> None: + """Test that the biased scheme yields disparate-distrib. shards.""" + shards = split_multi_classif_dataset( + dataset, n_shards=4, scheme="biased", p_valid=0.2, seed=0 + ) + # Verify that shards match expected shapes. + self.assert_expected_shard_shapes( + dataset, shards, n_shards=4, p_valid=0.2 + ) + # Verify that labels have distinct distributions across shards. + # To do so, use a chi2-test (with the null hypothesis that + # distributions differ), and verify that the hypothesis is + # accepted overall with high confidence, and rejected on + # train/valid pairs with high confidence as well. + y_counts = self.get_label_counts(shards) + assert chi2_contingency(y_counts).pvalue <= 1e-10 + for i in range(len(shards)): + assert chi2_contingency(y_counts[i : i + 1]).pvalue >= 0.90 + + def test_scheme_dirichlet( + self, + dataset: Tuple[Array, np.ndarray], + ) -> None: + """Test the dirichlet scheme's disparity, with various alpha values.""" + shards = split_multi_classif_dataset( + dataset, n_shards=2, scheme="dirichlet", seed=0 + ) + # Verify that shards match expected shapes. + self.assert_expected_shard_shapes( + dataset, shards, n_shards=2, p_valid=0.2 + ) + # Verify that labels are differently-distributed across shards. + # To do so, use a chi2-test (with the null hypothesis that + # distributions differ), and verify that the hypothesis is + # accepted overall with high confidence, and rejected on + # train/valid pairs with high confidence as well. + y_counts = self.get_label_counts(shards) + pval_low = chi2_contingency(y_counts).pvalue + assert pval_low <= 1e-10 + for i in range(len(shards)): + assert chi2_contingency(y_counts[i : i + 1]).pvalue >= 0.60 + # Verify that using a higher alpha results in a high p-value. + shards = split_multi_classif_dataset( + dataset, n_shards=2, scheme="dirichlet", seed=0, alpha=2.0 + ) + y_counts = self.get_label_counts(shards) + pval_mid = chi2_contingency(y_counts).pvalue + assert pval_low < pval_mid <= 0.05 + # Verify that using a very high alpha value results in iid data. 
+ shards = split_multi_classif_dataset( + dataset, n_shards=2, scheme="dirichlet", seed=0, alpha=100000.0 + ) + y_counts = self.get_label_counts(shards) + assert chi2_contingency(y_counts).pvalue >= 0.90 + + def test_error_labels_too_many_shards( + self, + dataset: Tuple[Array, np.ndarray], + ) -> None: + """Test that in 'labels' schemes 'n_shards > n_lab' raises an error.""" + with pytest.raises(ValueError): + split_multi_classif_dataset(dataset, n_shards=8, scheme="labels") diff --git a/test/functional/test_toy_clf_secagg.py b/test/functional/test_toy_clf_secagg.py index c1b13091cd0d1f39daca6379282c5296ac700cbe..cc5723e41b5d8c8d71c28da465a581637b587f4b 100644 --- a/test/functional/test_toy_clf_secagg.py +++ b/test/functional/test_toy_clf_secagg.py @@ -85,7 +85,7 @@ def generate_toy_dataset( # Cluster samples based on features and assign them to clients thereof. # Also split client-wise data: 80% for training and 20% for validation. kclust = sklearn.cluster.KMeans( - n_clusters=n_clients, init="random", random_state=SEED + n_clusters=n_clients, init="random", n_init="auto", random_state=SEED ).fit_predict(inputs) datasets = [] # type: List[Tuple[InMemoryDataset, InMemoryDataset]] for i in range(n_clients): diff --git a/test/main/test_main_client.py b/test/main/test_main_client.py index 384c2b4c210b75d6feb0f92872129594d5296eb6..861f1bc372cd0fa36bb2eb3e7a6597695ec10539 100644 --- a/test/main/test_main_client.py +++ b/test/main/test_main_client.py @@ -43,6 +43,11 @@ MOCK_NETWK.name = "client" MOCK_DATASET = mock.create_autospec(Dataset, instance=True) +def object_new(cls, *_, **__) -> Any: + """Wrapper for 'object.__new__' accepting/discarding *args and **kwargs.""" + return object.__new__(cls) + + @contextlib.contextmanager def patch_class_constructor( cls: Type[Any], @@ -69,7 +74,7 @@ def patch_class_constructor( yield patch finally: if new is object.__new__: - cls.__new__ = lambda cls, *args, **kwargs: object.__new__(cls) + cls.__new__ = object_new # type: ignore[assignment] else: cls.__new__ = new