diff --git a/declearn/dataset/__init__.py b/declearn/dataset/__init__.py index dc5759a4a714b836e3bf7348d27e48b79d98468d..b440360accdd5471fa19afb2280d02a76c479b71 100644 --- a/declearn/dataset/__init__.py +++ b/declearn/dataset/__init__.py @@ -23,12 +23,37 @@ actually being loaded (from a source file, a database, another API...). This declearn submodule provides with: +API tools +--------- * [Dataset][declearn.dataset.Dataset]: - abstract class defining an API to access training or testing data + Abstract base class defining an API to access training or testing data. +* [DataSpec][declearn.dataset.DataSpecs]: + Dataclass to wrap a dataset's metadata. +* [load_dataset_from_json][declearn.dataset.load_dataset_from_json] + Utility function to parse a JSON into a dataset object. + +Dataset subclasses +------------------ * [InMemoryDataset][declearn.dataset.InMemoryDataset]: - Dataset subclass serving numpy(-like) memory-loaded data -arrays + Dataset subclass serving numpy(-like) memory-loaded data arrays. + +Utility submodules +------------------ +* [examples] + Utils to fetch and prepare some open-source datasets. +* [utils] + Utils to manipulate datasets (load, save, split...). + +Utility entry-point +------------------- +* [split_data][declearn.dataset.split_data] + Utility to split a single dataset into shards. This function builds + on more unitary utils, and is installed as a command-line entry-point + together with declearn. """ +from . import utils +from . import examples from ._base import Dataset, DataSpecs, load_dataset_from_json from ._inmemory import InMemoryDataset +from ._split_data import split_data diff --git a/declearn/dataset/_inmemory.py b/declearn/dataset/_inmemory.py index 58bd35f20b194c3d9fb62c952db0c17f4d7ace13..09df296a3294af4563655f4959fe3475a9fd40a8 100644 --- a/declearn/dataset/_inmemory.py +++ b/declearn/dataset/_inmemory.py @@ -17,19 +17,19 @@ """Dataset implementation to serve scikit-learn compatible in-memory data.""" -import functools import os +import warnings from typing import Any, ClassVar, Dict, Iterator, List, Optional, Set, Union import numpy as np -import pandas as pd # type: ignore +import pandas as pd from numpy.typing import ArrayLike from scipy.sparse import spmatrix # type: ignore from sklearn.datasets import load_svmlight_file # type: ignore from typing_extensions import Self # future: import from typing (py >=3.11) from declearn.dataset._base import Dataset, DataSpecs -from declearn.dataset._sparse import sparse_from_file, sparse_to_file +from declearn.dataset.utils import load_data_array, save_data_array from declearn.typing import Batch from declearn.utils import json_dump, json_load, register_type @@ -89,7 +89,7 @@ class InMemoryDataset(Dataset): an instance that is either a numpy ndarray, a pandas DataFrame or a scipy spmatrix. - See the `load_data_array` method for details + See the `load_data_array` function in dataset._utils for details on supported file formats. Parameters @@ -131,7 +131,7 @@ class InMemoryDataset(Dataset): # Assign the main data array. if isinstance(data, str): self._data_path = data - data = self.load_data_array(data) + data = load_data_array(data) self.data = data # Assign the optional input features list. self.f_cols = f_cols @@ -147,7 +147,7 @@ class InMemoryDataset(Dataset): self.f_cols.remove(target) # type: ignore target = self.data[target] else: - target = self.load_data_array(target) + target = load_data_array(target) self.target = target # Assign the (optional) sample weights data array. 
if isinstance(s_wght, str): @@ -159,7 +159,7 @@ class InMemoryDataset(Dataset): self.f_cols.remove(s_wght) # type: ignore s_wght = self.data[s_wght] else: - s_wght = self.load_data_array(s_wght) + s_wght = load_data_array(s_wght) self.weights = s_wght # Assign the 'expose_classes' attribute. self.expose_classes = expose_classes @@ -187,7 +187,8 @@ class InMemoryDataset(Dataset): if (not self.expose_classes) or (self.target is None): return None if isinstance(self.target, pd.DataFrame): - return set(self.target.unstack().unique().tolist()) # type: ignore + c_list = self.target.unstack().unique().tolist() # type: ignore + return set(c_list) if isinstance(self.target, pd.Series): return set(self.target.unique().tolist()) if isinstance(self.target, np.ndarray): @@ -224,57 +225,19 @@ class InMemoryDataset(Dataset): ) -> DataArray: """Load a data array from a dump file. - Supported file extensions - ------------------------- - - `.csv`: - csv file, comma-delimited by default. - Any keyword arguments to `pandas.read_csv` may be passed. - - `.npy`: - Non-pickle numpy array dump, created with `numpy.save`. - - `.sparse`: - Scipy sparse matrix dump, created with the custom - `declearn.data.sparse.sparse_to_file` function. - - `.svmlight`: - SVMlight sparse matrix and labels array dump. - Parse using `sklearn.load_svmlight_file`, and - return either features or labels based on the - `which` int keyword argument (default: 0, for - features). + As of declearn v2.2, this staticmethod is DEPRECATED in favor of + `declearn.dataset.utils.load_data_array`, which is now calls. It + will be removed in v2.4 and/or v3.0. - Parameters - ---------- - path: str - Path to the data array dump file. - Extension must be adequate to enable proper parsing; - see list of supported extensions above. - **kwargs: - Extension-type-based keyword parameters may be passed. - See above for details. - - Returns - ------- - data: numpy.ndarray or pandas.DataFrame or scipy.spmatrix - Reloaded data array. - - Raises - ------ - TypeError - If `path` is of unsupported extension. - - Any exception raised by data-loading functions may also be raised - (e.g. if the file cannot be proprely parsed). + See [declearn.dataset.utils.load_data_array][] for more details. """ - ext = os.path.splitext(path)[1] - if ext == ".csv": - return pd.read_csv(path, **kwargs) - if ext == ".npy": - return np.load(path, allow_pickle=False) - if ext == ".sparse": - return sparse_from_file(path) - if ext == ".svmlight": - which = kwargs.get("which", 0) - return load_svmlight_file(path)[which] - raise TypeError(f"Unsupported data array file extension: '{ext}'.") + warnings.warn( + "'InMemoryDataset.load_data_array' has been deprecated in favor" + " of `declearn.dataset.utils.load_data_array`. It will be removed" + " in version 2.4 and/or 3.0.", + category=DeprecationWarning, + ) + return load_data_array(path, **kwargs) @staticmethod def save_data_array( @@ -283,58 +246,19 @@ class InMemoryDataset(Dataset): ) -> str: """Save a data array to a dump file. - Supported types of data arrays - ------------------------------ - - `pandas.DataFrame` or `pandas.Series`: - Dump to a comma-separated `.csv` file. - - `numpy.ndarray`: - Dump to a non-pickle `.npy` file. - - `scipy.sparse.spmatrix`: - Dump to a `.sparse` file, using a custom format - and `declearn.data.sparse.sparse_to_file`. - - Parameters - ---------- - path: str - Path to the file where to dump the array. - Appropriate file extension will be added when - not present (i.e. `path` may be a basename). 
-        array: data array structure (see above)
-            Data array that needs dumping to file.
-            See above for supported types and associated
-            behaviours.
+        As of declearn v2.2, this staticmethod is DEPRECATED in favor of
+        `declearn.dataset.utils.save_data_array`, which it now calls. It
+        will be removed in v2.4 and/or v3.0.
-        Returns
-        -------
-        path: str
-            Path to the created file dump, based on the input
-            `path` and the chosen file extension (see above).
-
-        Raises
-        ------
-        TypeError
-            If `array` is of unsupported type.
+        See [declearn.dataset.utils.save_data_array][] for more details.
         """
-        # Select a file extension and set up the array-dumping function.
-        if isinstance(array, (pd.DataFrame, pd.Series)):
-            ext = ".csv"
-            save = functools.partial(
-                array.to_csv, sep=",", encoding="utf-8", index=False
-            )
-        elif isinstance(array, np.ndarray):
-            ext = ".npy"
-            save = functools.partial(np.save, arr=array)
-        elif isinstance(array, spmatrix):
-            ext = ".sparse"
-            save = functools.partial(sparse_to_file, matrix=array)
-        else:
-            raise TypeError(f"Unsupported data array type: '{type(array)}'.")
-        # Ensure proper naming. Save the array. Return the path.
-        if not path.endswith(ext):
-            path += ext
-        os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
-        save(path)
-        return path
+        warnings.warn(
+            "'InMemoryDataset.save_data_array' has been deprecated in favor"
+            " of `declearn.dataset.utils.save_data_array`. It will be removed"
+            " in version 2.4 and/or 3.0.",
+            category=DeprecationWarning,
+        )
+        return save_data_array(path, array)

     @classmethod
     def from_svmlight(
@@ -388,15 +312,15 @@ class InMemoryDataset(Dataset):
         # fmt: off
         info["data"] = (
             self._data_path or
-            self.save_data_array(os.path.join(folder, "data"), self.data)
+            save_data_array(os.path.join(folder, "data"), self.data)
         )
         info["target"] = None if self.target is None else (
             self._trgt_path or
-            self.save_data_array(os.path.join(folder, "trgt"), self.target)
+            save_data_array(os.path.join(folder, "trgt"), self.target)
         )
         info["s_wght"] = None if self.weights is None else (
             self._wght_path or
-            self.save_data_array(os.path.join(folder, "wght"), self.weights)
+            save_data_array(os.path.join(folder, "wght"), self.weights)
         )
         # fmt: on
         info["f_cols"] = self.f_cols
diff --git a/declearn/dataset/_split_data.py b/declearn/dataset/_split_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..017c9c08d0a3866f99573ae7b9bf02dfc8837473
--- /dev/null
+++ b/declearn/dataset/_split_data.py
@@ -0,0 +1,222 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Script to split data into heterogeneous shards and save them.
+
+Available splitting schemes:
+
+* "iid", split the dataset through iid random sampling.
+* "labels", split into shards that hold all samples associated
+  with mutually-exclusive target classes.
+* "biased", split the dataset through random sampling according + to a shard-specific random labels distribution. + +Utilities provided here are limited to: + +* (>=2-)d dataset that be directly loaded into numpy arrays or sparse matrices. +* Single-label, multinomial classification problems. +""" + +import os +from typing import Optional, Tuple, Union + +import fire # type: ignore +import numpy as np +import pandas as pd +from scipy.sparse import spmatrix # type: ignore + +from declearn.dataset.examples import load_mnist +from declearn.dataset.utils import ( + load_data_array, + save_data_array, + split_multi_classif_dataset, +) + + +__all__ = [ + "split_data", +] + + +def load_data( + data: Optional[str] = None, + target: Optional[Union[str, int]] = None, +) -> Tuple[Union[np.ndarray, spmatrix], np.ndarray]: + """Load a dataset in memory from provided path(s). + + This functions supports `.csv`, `.npy`, `.svmlight` and `.sparse` + file formats. See [declearn.dataset.utils.load_data_array][] for + details. + + Arguments + --------- + data: str or None, default=None + Path to the data file to import. + If None, default to importing the MNIST train dataset. + target: str or int or None, default=None + If str, path to the labels file to import, or name of a `data` + column to use as labels (only if `data` points to a csv file). + If int, index of a `data` column of to use as labels). + Required if data is not None, ignored if data is None. + + Returns + ------- + inputs: + Input features, as a numpy array or scipy sparse matrix. + labels: + Ground-truth labels, as a numpy array. + """ + # Case when no arguments are provided: return the default MNIST dataset. + if not data: + return load_mnist(train=True) + # Otherwise, load the dataset, then load or extract the target labels. + inputs = load_data_array(data) + if isinstance(target, str): + # Case when 'target' points to a separate data file. + if os.path.isfile(target): + labels = load_data_array(target) + if isinstance(labels, spmatrix): + labels = labels.toarray() + elif isinstance(labels, pd.DataFrame): + labels = labels.values + # Case when 'target' is the name of a column in a csv file. + elif isinstance(inputs, pd.DataFrame) and target in inputs: + labels = inputs.pop(target).values + inputs = inputs.values + else: + raise ValueError( + "Invalid 'target' value: either the file is missing, or it " + "points to a column that is not present in the loaded data." + ) + elif isinstance(target, int): + # Case when 'target' is the index of a data column. + inputs, labels = _extract_column_by_index(inputs, target) + else: + raise TypeError("Invalid type for 'target': should be str or int.") + return inputs, labels + + +def _extract_column_by_index( + inputs: Union[np.ndarray, spmatrix, pd.DataFrame], + target: int, +) -> Tuple[Union[np.ndarray, spmatrix], np.ndarray]: + """Backend to extract a column by index in a data array.""" + if target > inputs.shape[1]: + raise ValueError( + f"Invalid 'target' value: index {target} is out of range " + f"for the dataset, that has {inputs.shape[1]} columns." 
+ ) + if isinstance(inputs, pd.DataFrame): + inputs = inputs.values + if isinstance(inputs, np.ndarray): + labels = inputs[:, target] + inputs = np.delete(inputs, target, axis=1) + elif isinstance(inputs, spmatrix): + labels = inputs.getcol(target).toarray().ravel() + csc = inputs.tocsc() # sparse matrix with efficient column slicing + idx = [i for i in range(inputs.shape[1]) if i != target] + inputs = type(inputs)(csc[:, idx]) + return inputs, labels + + +def split_data( + folder: str = ".", + data_file: Optional[str] = None, + label_file: Optional[Union[str, int]] = None, + n_shards: int = 3, + scheme: str = "iid", + perc_train: float = 0.8, + seed: Optional[int] = None, +) -> None: + """Randomly split a dataset into shards. + + The resulting folder structure is : + folder/ + └─── data*/ + └─── client*/ + │ train_data.* - training data + │ train_target.* - training labels + │ valid_data.* - validation data + │ valid_target.* - validation labels + └─── client*/ + │ ... + + Parameters + ---------- + folder: str, default = "." + Path to the folder where to add a data folder + holding output shard-wise files + data_file: str or None, default=None + Optional path to a folder where to find the data. + If None, default to the MNIST example. + label_file: str or int or None, default=None + If str, path to the labels file to import, or name of a `data` + column to use as labels (only if `data` points to a csv file). + If int, index of a `data` column of to use as labels). + Required if data is not None, ignored if data is None. + n_shards: int + Number of shards between which to split the data. + scheme: {"iid", "labels", "biased"}, default="iid" + Splitting scheme(s) to use. In all cases, shards contain mutually- + exclusive samples and cover the full raw training data. + - If "iid", split the dataset through iid random sampling. + - If "labels", split into shards that hold all samples associated + with mutually-exclusive target classes. + - If "biased", split the dataset through random sampling according + to a shard-specific random labels distribution. + perc_train: float, default= 0.8 + Train/validation split in each client dataset, must be in the + ]0,1] range. + seed: int or None, default=None + Optional seed to the RNG used for all sampling operations. + """ + # pylint: disable=too-many-arguments,too-many-locals + # Select output folder. + folder = os.path.join(folder, f"data_{scheme}") + # Value-check the 'perc_train' parameter. + if not (isinstance(perc_train, float) and (0.0 < perc_train <= 1.0)): + raise ValueError("'perc_train' should be a float in ]0,1]") + # Load the dataset and split it. + inputs, labels = load_data(data_file, label_file) + print( + f"Splitting data into {n_shards} shards using the '{scheme}' scheme." + ) + split = split_multi_classif_dataset( + dataset=(inputs, labels), + n_shards=n_shards, + scheme=scheme, # type: ignore + p_valid=(1 - perc_train), + seed=seed, + ) + # Export the resulting shard-wise data to files. + for idx, ((x_train, y_train), (x_valid, y_valid)) in enumerate(split): + subdir = os.path.join(folder, f"client_{idx}") + os.makedirs(subdir, exist_ok=True) + save_data_array(os.path.join(subdir, "train_data"), x_train) + save_data_array(os.path.join(subdir, "train_target"), y_train) + if x_valid.shape[0]: + save_data_array(os.path.join(subdir, "valid_data"), x_valid) + save_data_array(os.path.join(subdir, "valid_target"), y_valid) + + +def main() -> None: + "Fire-wrapped `split_data`." 
+ fire.Fire(split_data) + + +if __name__ == "__main__": + main() diff --git a/declearn/dataset/examples/__init__.py b/declearn/dataset/examples/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..15c4c9a8e00290144095309ffe9595ecc52e9167 --- /dev/null +++ b/declearn/dataset/examples/__init__.py @@ -0,0 +1,29 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utils to fetch and prepare some open-source datasets. + +Datasets +-------- +* [load_heart_uci][declearn.dataset.examples.load_heart_uci]: + Load and/or download a pre-processed UCI heart disease dataset. +* [load_mnist][declearn.dataset.examples.load_mnist]: + Load and/or download the MNIST digit-classification dataset. +""" + +from ._heart_uci import load_heart_uci +from ._mnist import load_mnist diff --git a/declearn/dataset/examples/_heart_uci.py b/declearn/dataset/examples/_heart_uci.py new file mode 100644 index 0000000000000000000000000000000000000000..457ee2c8bff2755b878b633b71128c7458abbebe --- /dev/null +++ b/declearn/dataset/examples/_heart_uci.py @@ -0,0 +1,99 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Util to download and pre-process the UCI Heart Disease dataset.""" + +import os +from typing import Literal, Optional, Tuple + +import pandas as pd # type: ignore + +__all__ = [ + "load_heart_uci", +] + + +def load_heart_uci( + name: Literal["cleveland", "hungarian", "switzerland", "va"], + folder: Optional[str] = None, +) -> Tuple[pd.DataFrame, str]: + """Load and/or download a pre-processed UCI Heart Disease dataset. + + See [https://archive.ics.uci.edu/ml/datasets/Heart+Disease] for + information on the UCI Heart Disease dataset. + + Arguments + --------- + name: str + Name of a center, the dataset from which to return. + folder: str or None, default=None + Optional path to a folder where to write output csv files. + If the file already exists in that folder, read from it. + + Returns + ------- + data: pd.DataFrame + Pre-processed dataset from the `name` center. + May be passed as `data` of a declearn `InMemoryDataset`. + target: str + Name of the target column in `data`. + May be passed as `target` of a declearn `InMemoryDataset`. + """ + # If the file already exists, read and return it. 
+    if folder is not None:
+        path = os.path.join(folder, f"data_{name}.csv")
+        if os.path.isfile(path):
+            data = pd.read_csv(path)
+            return data, "num"
+    # Otherwise, download and pre-process the data, and optionally save it.
+    data = download_heart_uci_shard(name)
+    if folder is not None:
+        os.makedirs(folder, exist_ok=True)
+        data.to_csv(path, index=False)
+    return data, "num"
+
+
+def download_heart_uci_shard(
+    name: Literal["cleveland", "hungarian", "switzerland", "va"],
+) -> pd.DataFrame:
+    """Download and preprocess a subset of the Heart UCI dataset."""
+    print(f"Downloading Heart Disease UCI dataset from center {name}.")
+    url = (
+        "https://archive.ics.uci.edu/ml/machine-learning-databases/"
+        f"heart-disease/processed.{name}.data"
+    )
+    # Download the dataset.
+    data = pd.read_csv(url, header=None, na_values="?")
+    columns = [
+        # fmt: off
+        "age", "sex", "cp", "trestbps", "chol", "fbs", "restecg",
+        "thalach", "exang", "oldpeak", "slope", "ca", "thal", "num",
+    ]
+    data = data.set_axis(columns, axis=1, copy=False)
+    # Drop unused columns and rows with missing values.
+    data.drop(columns=["ca", "chol", "fbs", "slope", "thal"], inplace=True)
+    data.dropna(inplace=True)
+    data.reset_index(inplace=True, drop=True)
+    # Normalize quantitative variables.
+    for col in ("age", "trestbps", "thalach", "oldpeak"):
+        data[col] = (  # type: ignore
+            data[col] - data[col].mean()  # type: ignore
+        ) / data[col].std()
+    # Binarize the target variable.
+    data["num"] = (data["num"] > 0).astype(int)
+    # Return the prepared dataframe.
+    return data
diff --git a/declearn/dataset/examples/_mnist.py b/declearn/dataset/examples/_mnist.py
new file mode 100644
index 0000000000000000000000000000000000000000..a839d2b28a48328283ec5be459868421d639d7d5
--- /dev/null
+++ b/declearn/dataset/examples/_mnist.py
@@ -0,0 +1,110 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Util to download the MNIST digit-classification dataset."""
+
+import gzip
+import os
+from typing import Optional, Tuple
+
+import numpy as np
+import requests
+
+
+__all__ = [
+    "load_mnist",
+]
+
+
+def load_mnist(
+    train: bool = True,
+    folder: Optional[str] = None,
+) -> Tuple[np.ndarray, np.ndarray]:
+    """Load and/or download the MNIST digit-classification dataset.
+
+    See [https://en.wikipedia.org/wiki/MNIST_database] for information
+    on the MNIST dataset.
+
+    Arguments
+    ---------
+    train: bool, default=True
+        Whether to return the 60k training subset, or the 10k testing one.
+    folder: str or None, default=None
+        Optional path to a root folder where to find or download the
+        raw MNIST data. If None, download the data but only store it
+        in memory.
+
+    Returns
+    -------
+    images: np.ndarray
+        Input images, as a (n_images, 28, 28) float numpy array.
+        May be passed as `data` of a declearn `InMemoryDataset`.
+    labels: np.ndarray
+        Target labels, as a (n_images) int numpy array.
+ May be passed as `target` of a declearn `InMemoryDataset`. + """ + tag = "train" if train else "t10k" + images = _load_mnist_data(folder, tag, images=True) + labels = _load_mnist_data(folder, tag, images=False) + return images, labels + + +def _load_mnist_data( + folder: Optional[str], tag: str, images: bool +) -> np.ndarray: + """Load (and/or download) and return data from a raw MNIST file.""" + name = f"{tag}-images-idx3" if images else f"{tag}-labels-idx1" + name = f"{name}-ubyte.gz" + # Optionally download the gzipped file from the internet. + if folder is None or not os.path.isfile(os.path.join(folder, name)): + data = _download_mnist_file(name, folder) + data = gzip.decompress(data) + # Otherwise, read its contents from a local copy. + else: + with gzip.open(os.path.join(folder, name), "rb") as file: + data = file.read() + # Read and parse the source data into a numpy array. + if images: + shape, off = [ + int(data[i : i + 4].hex(), 16) for i in range(4, 16, 4) + ], 16 + else: + shape, off = [int(data[4:8].hex(), 16)], 8 + array = np.frombuffer(bytearray(data[off:]), dtype="uint8").reshape(shape) + return (array / 255).astype(np.single) if images else array + + +def _download_mnist_file(name: str, folder: Optional[str]) -> bytes: + """Download a MNIST source file and opt. save it in a given folder.""" + # Download the file in memory. + print(f"Downloading MNIST source file {name}.") + reply = requests.get( + f"http://yann.lecun.com/exdb/mnist/{name}", timeout=300 + ) + try: + reply.raise_for_status() + except requests.HTTPError as exc: + raise RuntimeError( + f"Failed to download MNIST source file {name}." + ) from exc + # Optionally dump the file to disk. + if folder is not None: + os.makedirs(folder, exist_ok=True) + with open(os.path.join(folder, name), "wb") as file: + file.write(reply.content) + # Return the downloaded data. + return reply.content diff --git a/declearn/dataset/utils/__init__.py b/declearn/dataset/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1516e660f8c896b2b81942f45ea92ee5d36d43b3 --- /dev/null +++ b/declearn/dataset/utils/__init__.py @@ -0,0 +1,42 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utils to manipulate datasets (load, save, split...). + +Data loading and saving +----------------------- +declearn provides with utils to load and save array-like data tensors +to and from various file formats: + +* [load_data_array][declearn.dataset.utils.load_data_array]: + Load a data array (numpy, scipy, pandas) from a dump file. +* [save_data_array][declearn.dataset.utils.save_data_array]: + Save a data array (numpy, scipy, pandas) to a dump file. +* [sparse_from_file][declearn.dataset.utils.sparse_from_file]: + Backend to load a sparse matrix from a dump file. 
+* [sparse_to_file][declearn.dataset.utils.sparse_to_file]: + Backend to save a sparse matrix to a dump file + +Data splitting +-------------- +* [split_multi_classif_dataset] +[declearn.dataset.utils.split_multi_classif_dataset]: + Split a classification dataset into (opt. heterogeneous) shards. +""" +from ._save_load import load_data_array, save_data_array +from ._sparse import sparse_from_file, sparse_to_file +from ._split_classif import split_multi_classif_dataset diff --git a/declearn/dataset/utils/_save_load.py b/declearn/dataset/utils/_save_load.py new file mode 100644 index 0000000000000000000000000000000000000000..90d936f09e9a549c3bc1617e617c1eb84a960456 --- /dev/null +++ b/declearn/dataset/utils/_save_load.py @@ -0,0 +1,156 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utils to save and load array-like data to and from various file formats.""" + +import functools +import os +from typing import Any, Union + +import numpy as np +import pandas as pd # type: ignore +from scipy.sparse import spmatrix # type: ignore +from sklearn.datasets import load_svmlight_file # type: ignore + +from declearn.dataset.utils._sparse import sparse_from_file, sparse_to_file + +__all__ = [ + "load_data_array", + "save_data_array", +] + + +DataArray = Union[np.ndarray, pd.DataFrame, spmatrix] + + +def load_data_array( + path: str, + **kwargs: Any, +) -> DataArray: + """Load a data array from a dump file. + + Supported file extensions + ------------------------- + - `.csv`: + csv file, comma-delimited by default. + Any keyword arguments to `pandas.read_csv` may be passed. + - `.npy`: + Non-pickle numpy array dump, created with `numpy.save`. + - `.sparse`: + Scipy sparse matrix dump, created with the custom + `declearn.data.sparse.sparse_to_file` function. + - `.svmlight`: + SVMlight sparse matrix and labels array dump. + Parse using `sklearn.load_svmlight_file`, and + return either features or labels based on the + `which` int keyword argument (default: 0, for + features). + + Parameters + ---------- + path: str + Path to the data array dump file. + Extension must be adequate to enable proper parsing; + see list of supported extensions above. + **kwargs: + Extension-type-based keyword parameters may be passed. + See above for details. + + Returns + ------- + data: numpy.ndarray or pandas.DataFrame or scipy.spmatrix + Reloaded data array. + + Raises + ------ + TypeError + If `path` is of unsupported extension. + + Any exception raised by data-loading functions may also be raised + (e.g. if the file cannot be proprely parsed). 
+ """ + ext = os.path.splitext(path)[1] + if ext == ".csv": + return pd.read_csv(path, **kwargs) + if ext == ".npy": + return np.load(path, allow_pickle=False) + if ext == ".sparse": + return sparse_from_file(path) + if ext == ".svmlight": + which = kwargs.get("which", 0) + return load_svmlight_file(path)[which] + raise TypeError(f"Unsupported data array file extension: '{ext}'.") + + +def save_data_array( + path: str, + array: Union[DataArray, pd.Series], +) -> str: + """Save a data array to a dump file. + + Supported types of data arrays + ------------------------------ + - `pandas.DataFrame` or `pandas.Series`: + Dump to a comma-separated `.csv` file. + - `numpy.ndarray`: + Dump to a non-pickle `.npy` file. + - `scipy.sparse.spmatrix`: + Dump to a `.sparse` file, using a custom format + and `declearn.data.sparse.sparse_to_file`. + + Parameters + ---------- + path: str + Path to the file where to dump the array. + Appropriate file extension will be added when + not present (i.e. `path` may be a basename). + array: data array structure (see above) + Data array that needs dumping to file. + See above for supported types and associated + behaviours. + + Returns + ------- + path: str + Path to the created file dump, based on the input + `path` and the chosen file extension (see above). + + Raises + ------ + TypeError + If `array` is of unsupported type. + """ + # Select a file extension and set up the array-dumping function. + if isinstance(array, (pd.DataFrame, pd.Series)): + ext = ".csv" + save = functools.partial( + array.to_csv, sep=",", encoding="utf-8", index=False + ) + elif isinstance(array, np.ndarray): + ext = ".npy" + save = functools.partial(np.save, arr=array) + elif isinstance(array, spmatrix): + ext = ".sparse" + save = functools.partial(sparse_to_file, matrix=array) + else: + raise TypeError(f"Unsupported data array type: '{type(array)}'.") + # Ensure proper naming. Save the array. Return the path. + if not path.endswith(ext): + path += ext + os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True) + save(path) + return path diff --git a/declearn/dataset/_sparse.py b/declearn/dataset/utils/_sparse.py similarity index 70% rename from declearn/dataset/_sparse.py rename to declearn/dataset/utils/_sparse.py index 355537e83c3439affa46fc91e3cfe9d850aba7de..6bd3be35aedc1053cdae0b9c9a2cbfd200dd16e6 100644 --- a/declearn/dataset/_sparse.py +++ b/declearn/dataset/utils/_sparse.py @@ -17,18 +17,16 @@ """Sparse matrix file dumping and loading utils, inspired by svmlight. -The format used is mostly similar to the SVMlight one -(see for example `sklearn.datasets.dump_svmlight_file`), -but enables storing a single matrix rather than a (X, y) -pair of arrays. It also records the input matrix's dtype -and type of sparse format, which are thus restored when -reloading - while the scikit-learn implementation always -returns a CSR matrix and requires inputing the dtype. - -This implementation does not use any tricks (e.g. cython -or interfacing an external c++ tool) to optimize dump or -load runtimes, it may therefore be slower than using the -scikit-learn functions or any third-party alternative. +The format used is mostly similar to the SVMlight one (see for example +`sklearn.datasets.dump_svmlight_file`), but enables storing a single +matrix rather than a (X, y) pair of arrays. 
It also records the input +matrix's dtype and type of sparse format, which are thus restored when +reloading - while the scikit-learn implementation always returns a CSR +matrix and requires inputing the dtype. + +This implementation does not use any tricks (e.g. cython or interfacing an +external c++ tool) to optimize dump or load runtimes. It may therefore be +slower than using the scikit-learn functions or any third-party alternative. """ import json @@ -68,7 +66,8 @@ def sparse_to_file( ) -> None: """Dump a scipy sparse matrix as a text file. - See function `sparse_from_file` to reload from the dump file. + See the [`sparse_from_file`][declearn.dataset.utils.sparse_from_file] + counterpart function to reload the dumped data from the created file. Parameters ---------- @@ -85,11 +84,13 @@ def sparse_to_file( If 'matrix' is of unsupported type, i.e. not a BSR, CSC, CSR, COO, DIA, DOK or LIL sparse matrix. - Note: the format used is mostly similar to the SVMlight one - (see for example `sklearn.datasets.dump_svmlight_file`), but - enables storing a single matrix rather than a (X, y) pair of - arrays. It also records the input matrix's dtype and type of - sparse format, which are restored upon reloading. + Note + ---- + The format used is mostly similar to the SVMlight one (see for example + `sklearn.datasets.dump_svmlight_file`), but enables storing a single + matrix rather than a (X, y) pair of arrays. It also records the input + matrix's dtype and type of sparse format, which are restored when the + counterpart `sparse_from_file` function is used to reload it. """ if os.path.splitext(path)[1] != ".sparse": path += ".sparse" @@ -116,7 +117,8 @@ def sparse_to_file( def sparse_from_file(path: str) -> spmatrix: """Return a scipy sparse matrix loaded from a text file. - See function `sparse_to_file` to create reloadable dump files. + See the [`sparse_to_file`][declearn.dataset.utils.sparse_to_file] + counterpart function to create reloadable sparse data dump files. Parameters ---------- @@ -139,12 +141,13 @@ def sparse_from_file(path: str) -> spmatrix: i.e. "bsr", "csv", "csc", "coo", "dia", "dok" or "lil". - Note: the format used is mostly similar to the SVMlight one - (see for example `sklearn.datasets.load_svmlight_file`), but - the file must store a single matrix rather than a (X, y) pair - of arrays. It must also record some metadata in its header, - which are notably used to restore the initial matrix's dtype - and type of sparse format. + Note + ---- + The format used is mostly similar to the SVMlight one (see for example + `sklearn.datasets.load_svmlight_file`), but the file must store a single + matrix rather than a (X, y) pair of arrays. It must also record some + metadata in its header, which are notably used to restore the initial + matrix's dtype and type of sparse format. """ with open(path, "r", encoding="utf-8") as file: # Read and parse the file's header. @@ -161,7 +164,10 @@ def sparse_from_file(path: str) -> spmatrix: cnv = int if lil.dtype.kind == "i" else float # Iteratively parse and fill-in row data. for rix, row in enumerate(file): - for field in row.strip(" \n").split(" "): + row = row.strip(" \n") + if not row: # all-zeros row + continue + for field in row.split(" "): ind, val = field.split(":") lil[rix, int(ind)] = cnv(val) # Convert the matrix to its initial format and return. 
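The relocated `sparse_to_file` / `sparse_from_file` pair keeps the behaviour documented above: a single sparse matrix is dumped to a `.sparse` text file and reloaded with its original dtype and sparse format restored. A minimal round-trip sketch of the public API exposed under `declearn.dataset.utils` (the file path and matrix contents are purely illustrative):

    import numpy as np
    from scipy.sparse import coo_matrix

    from declearn.dataset.utils import sparse_from_file, sparse_to_file

    # Dump a small COO matrix to a '.sparse' text file (illustrative path).
    matrix = coo_matrix(np.eye(4, dtype=np.float32))
    sparse_to_file("/tmp/identity.sparse", matrix)

    # Reload it: both the dtype and the COO sparse format are restored.
    reloaded = sparse_from_file("/tmp/identity.sparse")
    assert reloaded.format == "coo"
    assert reloaded.dtype == np.float32
    assert np.array_equal(reloaded.toarray(), matrix.toarray())
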
diff --git a/declearn/dataset/utils/_split_classif.py b/declearn/dataset/utils/_split_classif.py new file mode 100644 index 0000000000000000000000000000000000000000..0e3dbc98d348141e27b52f80605f1eb0176cc2c2 --- /dev/null +++ b/declearn/dataset/utils/_split_classif.py @@ -0,0 +1,192 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utils to split a multi-category classification dataset into shards.""" + +from typing import List, Literal, Optional, Tuple, Type, Union + +import numpy as np +from scipy.sparse import csr_matrix, spmatrix # type: ignore + + +__all__ = [ + "split_multi_classif_dataset", +] + + +def split_multi_classif_dataset( + dataset: Tuple[Union[np.ndarray, spmatrix], np.ndarray], + n_shards: int, + scheme: Literal["iid", "labels", "biased"], + p_valid: float = 0.2, + seed: Optional[int] = None, +) -> List[Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]]: + """Split a classification dataset into (opt. heterogeneous) shards. + + The data-splitting schemes are the following: + + - If "iid", split the dataset through iid random sampling. + - If "labels", split into shards that hold all samples associated + with mutually-exclusive target classes. + - If "biased", split the dataset through random sampling according + to a shard-specific random labels distribution. + + Parameters + ---------- + dataset: tuple(np.ndarray|spmatrix, np.ndarray) + Raw dataset, as a pair of numpy arrays that respectively contain + the input features and (aligned) labels. Input features may also + be a scipy sparse matrix, that will temporarily be cast to CSR. + n_shards: int + Number of shards between which to split the dataset. + scheme: {"iid", "labels", "biased"} + Splitting scheme to use. In all cases, shards contain mutually- + exclusive samples and cover the full dataset. See details above. + p_valid: float, default=0.2 + Share of each shard to turn into a validation subset. + seed: int or None, default=None + Optional seed to the RNG used for all sampling operations. + + Returns + ------- + shards: + List of dataset shards, where each element is formatted as a + tuple of tuples: `((x_train, y_train), (x_valid, y_valid))`. + Input features will be of same type as `inputs`. + + Raises + ------ + TypeError + If `inputs` is not a numpy array or scipy sparse matrix. + ValueError + If `scheme` has an invalid value. + """ + # Select the splitting function to be used. + if scheme == "iid": + func = split_iid + elif scheme == "labels": + func = split_labels + elif scheme == "biased": + func = split_biased + else: + raise ValueError(f"Invalid 'scheme' value: '{scheme}'.") + # Set up the RNG and unpack the dataset. + rng = np.random.default_rng(seed) + inputs, target = dataset + # Optionally handle sparse matrix inputs. 
+    sp_type = None  # type: Optional[Type[spmatrix]]
+    if isinstance(inputs, spmatrix):
+        sp_type = type(inputs)
+        inputs = csr_matrix(inputs)
+    elif not isinstance(inputs, np.ndarray):
+        raise TypeError(
+            "'inputs' should be a numpy array or scipy sparse matrix."
+        )
+    # Split the dataset into shards.
+    split = func(inputs, target, n_shards, rng)
+    # Further split shards into training and validation subsets.
+    shards = [train_valid_split(inp, tgt, p_valid, rng) for inp, tgt in split]
+    # Optionally convert back sparse inputs, then return.
+    if sp_type is not None:
+        shards = [
+            ((sp_type(xt), yt), (sp_type(xv), yv))
+            for (xt, yt), (xv, yv) in shards
+        ]
+    return shards
+
+
+def split_iid(
+    inputs: Union[np.ndarray, csr_matrix],
+    target: np.ndarray,
+    n_shards: int,
+    rng: np.random.Generator,
+) -> List[Tuple[np.ndarray, np.ndarray]]:
+    """Split a dataset into shards using iid sampling."""
+    order = rng.permutation(inputs.shape[0])
+    s_len = inputs.shape[0] // n_shards
+    split = []  # type: List[Tuple[np.ndarray, np.ndarray]]
+    for idx in range(n_shards):
+        srt = idx * s_len
+        end = (srt + s_len) if idx < (n_shards - 1) else len(order)
+        shard = order[srt:end]
+        split.append((inputs[shard], target[shard]))
+    return split
+
+
+def split_labels(
+    inputs: Union[np.ndarray, csr_matrix],
+    target: np.ndarray,
+    n_shards: int,
+    rng: np.random.Generator,
+) -> List[Tuple[np.ndarray, np.ndarray]]:
+    """Split a dataset into shards with mutually-exclusive label classes."""
+    classes = np.unique(target)
+    if n_shards > len(classes):
+        raise ValueError(
+            f"Cannot share {len(classes)} classes between {n_shards} "
+            "shards with mutually-exclusive labels."
+        )
+    s_len = len(classes) // n_shards
+    order = rng.permutation(classes)
+    split = []  # type: List[Tuple[np.ndarray, np.ndarray]]
+    for idx in range(n_shards):
+        srt = idx * s_len
+        end = (srt + s_len) if idx < (n_shards - 1) else len(order)
+        shard = np.isin(target, order[srt:end])
+        split.append((inputs[shard], target[shard]))
+    return split
+
+
+def split_biased(
+    inputs: Union[np.ndarray, csr_matrix],
+    target: np.ndarray,
+    n_shards: int,
+    rng: np.random.Generator,
+) -> List[Tuple[np.ndarray, np.ndarray]]:
+    """Split a dataset into shards with heterogeneous label distributions."""
+    classes = np.unique(target)
+    index = np.arange(len(target))
+    s_len = len(target) // n_shards
+    split = []  # type: List[Tuple[np.ndarray, np.ndarray]]
+    for idx in range(n_shards):
+        if idx < (n_shards - 1):
+            # Draw a random distribution of labels for this node.
+            logits = np.exp(rng.normal(size=len(classes)))
+            lprobs = logits[target[index]]
+            lprobs = lprobs / lprobs.sum()
+            # Draw samples based on this distribution, without replacement.
+            shard = rng.choice(index, size=s_len, replace=False, p=lprobs)
+            index = index[~np.isin(index, shard)]
+        else:
+            # For the last node: use the remaining samples.
+ shard = index + split.append((inputs[shard], target[shard])) + return split + + +def train_valid_split( + inputs: np.ndarray, + target: np.ndarray, + p_valid: float, + rng: np.random.Generator, +) -> Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]: + """Split a dataset between train and validation using iid sampling.""" + order = rng.permutation(inputs.shape[0]) + v_len = np.ceil(inputs.shape[0] * p_valid).astype(int) + train = inputs[order[v_len:]], target[order[v_len:]] + valid = inputs[order[:v_len]], target[order[:v_len]] + return train, valid diff --git a/declearn/main/utils/_training.py b/declearn/main/utils/_training.py index 76e46bcc672408ec9d58993d34ffe56c36ce6ffa..901d953868e64b48a84d30a945c59d680862edb8 100644 --- a/declearn/main/utils/_training.py +++ b/declearn/main/utils/_training.py @@ -33,7 +33,7 @@ from declearn.metrics import MeanMetric, Metric, MetricInputType, MetricSet from declearn.model.api import Model from declearn.optimizer import Optimizer from declearn.typing import Batch -from declearn.utils import get_logger +from declearn.utils import LOGGING_LEVEL_MAJOR, get_logger __all__ = [ "TrainingManager", @@ -344,7 +344,8 @@ class TrainingManager: effort = constraints.get_values() result = self.metrics.get_result() states = self.metrics.get_states() - self.logger.info( + self.logger.log( + LOGGING_LEVEL_MAJOR, "Local scalar evaluation metrics: %s", {k: v for k, v in result.items() if isinstance(v, float)}, ) diff --git a/declearn/quickrun/__init__.py b/declearn/quickrun/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..31a8832e13c2a4a965cdfeb0ae0ba2d5ee4019f2 --- /dev/null +++ b/declearn/quickrun/__init__.py @@ -0,0 +1,41 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Script to quickly run a simulated FL example locally using declearn. + +This submodule, which is not imported by default, mainly aims at providing +with the `declearn-quickrun` command-line entry-point so as to easily set +up and run simulated federated learning experiments on a single computer. + +It exposes the following, merely as a way to make the documentation of that +util available to end-users: + +- [quickrun][declearn.quickrun.quickrun]: + Backend function of the `declearn-quickrun` command-line entry-point. +- [parse_data_folder][declearn.quickrun.parse_data_folder]: + Util to parse through a data folder used in a quickrun experiment. +- [DataSourceConfig][declearn.quickrun.DataSourceConfig]: + Dataclass and TOML parser for data-parsing hyper-parameters. +- [ExperimentConfig][declearn.quickrun.ExperimentConfig]: + Dataclass and TOML parser for experiment-defining hyper-parameters. +- [ModelConfig][declearn.quickrun.ModelConfig]: + Dataclass and TOML parser for model-defining hyper-parameters. 
+""" + +from ._config import DataSourceConfig, ExperimentConfig, ModelConfig +from ._parser import parse_data_folder +from ._run import quickrun diff --git a/declearn/quickrun/_config.py b/declearn/quickrun/_config.py new file mode 100644 index 0000000000000000000000000000000000000000..7355a215db164a69c2e12e893ae49439bf7da22d --- /dev/null +++ b/declearn/quickrun/_config.py @@ -0,0 +1,112 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""TOML-parsable container for quickrun configurations.""" + +import dataclasses +from typing import Any, Dict, List, Optional, Union + +from declearn.metrics import MetricInputType, MetricSet +from declearn.utils import TomlConfig + +__all__ = [ + "DataSourceConfig", + "ExperimentConfig", + "ModelConfig", +] + + +@dataclasses.dataclass +class ModelConfig(TomlConfig): + """Dataclass used to provide custom model location and name. + + Attributes + ---------- + model_file: str or None + Path to the python file under which the model is declared. + If None, look for "model.py" parallel to the "config.toml" one. + model_name: str, default="model" + Name of the variable storing the declearn Model to train, + declared in the main scope of the model file. + """ + + model_file: Optional[str] = None + model_name: str = "model" + + +@dataclasses.dataclass +class DataSourceConfig(TomlConfig): + """Dataclass associated with the quickrun's `parse_data_folder` function. + + Attributes + ---------- + data_folder: str + Absolute path to the to the main folder hosting the data. + client_names: list or None + List of custom client names to look for in the data folder. + If None, default to all subdirectories of the data folder. + dataset_names: dict or None + Dict of custom dataset names, to look for in each client folder. + Expect 'train_data, train_target, valid_data, valid_target' as keys. + If None, files will be expected to be prefixed using the former keys. + """ + + data_folder: Optional[str] = None + client_names: Optional[List[str]] = None + dataset_names: Optional[Dict[str, str]] = None + + +@dataclasses.dataclass +class ExperimentConfig(TomlConfig): + """Dataclass providing kwargs to `FederatedServer` and `FederatedClient`. + + Attributes + ---------- + metrics: MetricSet or None + Optional MetricSet instance, defining evaluation metrics to compute + in addition to the model's loss. It may be parsed from a list of + Metric names or (name, config) tuples), or from a MetricSet config + dict. + checkpoint: str or None + The checkpoint folder path where to save the server's and client-wise + outputs (round-wise model weights, evaluation metrics, logs, etc.). 
+ """ + + metrics: Optional[MetricSet] = None + checkpoint: Optional[str] = None + + def parse_metrics( + self, + inputs: Union[MetricSet, Dict[str, Any], List[MetricInputType], None], + ) -> Optional[MetricSet]: + """Parser for metrics.""" + if inputs is None or isinstance(inputs, MetricSet): + return None + try: + # Case of a manual listing of metrics (most expected). + if isinstance(inputs, (tuple, list)): + return MetricSet.from_specs(inputs) + # Case of a MetricSet config dict (unexpected but supported). + if isinstance(inputs, dict): + return MetricSet.from_config(inputs) + except (TypeError, ValueError) as exc: + raise TypeError( + f"Failed to parse inputs for field 'metrics': {exc}." + ) from exc + raise TypeError( + "Failed to parse inputs for field 'metrics': unproper type." + ) diff --git a/declearn/quickrun/_parser.py b/declearn/quickrun/_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..422d0999d6625b0959ebbd0931ec0d841f13d88d --- /dev/null +++ b/declearn/quickrun/_parser.py @@ -0,0 +1,194 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Util to parse the contents of a data folder into a nested dict of paths.""" + +import os +from pathlib import Path +from typing import Dict, List, Optional + +from declearn.quickrun._config import DataSourceConfig + + +__all__ = [ + "parse_data_folder", +] + + +def parse_data_folder( + data_config: DataSourceConfig, + folder: Optional[str] = None, +) -> Dict[str, Dict[str, str]]: + """Parse the contents of a data folder into a nested dict of file paths. + + This function expects the folder to abide by the following standard: + + folder/ + └─── data*/ + └─── client*/ + │ train_data.* - training data + │ train_target.* - training labels + │ valid_data.* - validation data + │ valid_target.* - validation labels + └─── client*/ + │ ... + + Parameters + ---------- + data_config: DataSourceConfig + DataSourceConfig instance; see its documentation for details. + folder: str or None + The main experiment folder in which to look for a `data*` folder. + Overridden by `data_config.data_folder` when specified. + + Returns + ------- + paths: + Nested directory containing the parsed file paths, with structure + `{client_name: {file_key_name: file_path}}`, where the key names + are always the same: "train_data", "train_target", "valid_data" + and "valid_target". + """ + # Identify the root data folder. + data_folder = get_data_folder_path(data_config.data_folder, folder) + # Identify clients' data folders. + client_names = list_client_names(data_folder, data_config.client_names) + clients = {c: {} for c in client_names} # type: Dict[str, Dict[str, str]] + # Set up a mapping between expected files and their naming. 
+    data_items = [
+        "train_data",
+        "train_target",
+        "valid_data",
+        "valid_target",
+    ]
+    dataset_names = data_config.dataset_names
+    if dataset_names:
+        if set(data_items) != dataset_names.keys():
+            raise ValueError(
+                "Please provide a properly formatted dictionary as input, "
+                f"using the following keys: {data_items}"
+            )
+    else:
+        dataset_names = {name: name for name in data_items}
+    # Gather client-wise file paths.
+    for client, paths in clients.items():
+        client_dir = data_folder.joinpath(client)
+        for key, val in dataset_names.items():
+            files = [p for p in client_dir.glob(f"{val}*") if p.is_file()]
+            if not files:
+                raise ValueError(
+                    f"Could not find a '{val}.*' file for client '{client}'."
+                )
+            if len(files) > 1:
+                raise ValueError(
+                    f"Found multiple '{val}.*' files for client '{client}'."
+                )
+            paths[key] = files[0].as_posix()
+    # Return the nested dictionary of parsed file paths.
+    return clients
+
+
+def get_data_folder_path(
+    data_folder: Optional[str],
+    root_folder: Optional[str],
+) -> Path:
+    """Return the path to a data folder.
+
+    Parameters
+    ----------
+    data_folder:
+        Optional user-specified data folder.
+    root_folder:
+        Root folder, under which to look up a 'data*' folder.
+        Unused if `data_folder` is not None.
+
+    Returns
+    -------
+    dirpath:
+        pathlib.Path wrapping the path to the identified data folder.
+
+    Raises
+    ------
+    ValueError
+        If the input arguments point to non-existing folders, or a data
+        folder cannot be unambiguously found under the root folder.
+    """
+    # Case when a data folder is explicitly designated.
+    if isinstance(data_folder, str):
+        if os.path.isdir(data_folder):
+            return Path(data_folder)
+        raise ValueError(
+            f"{data_folder} is not a valid path. To use an example "
+            "dataset, run `declearn-split` first."
+        )
+    # Case when working from a root folder.
+    if not isinstance(root_folder, str):
+        raise ValueError(
+            "Please provide either a data folder or its parent folder."
+        )
+    folders = list(Path(root_folder).glob("data*"))
+    if not folders:
+        raise ValueError(
+            f"No folder starting with 'data' found under {root_folder}. "
+            "Please store your split data under a 'data_*' folder. "
+            "To use an example dataset, run `declearn-split` first."
+        )
+    if len(folders) > 1:
+        raise ValueError(
+            "More than one folder starting with 'data' found under "
+            f"{root_folder}. Please store your data under a single "
+            "parent folder, or specify the target data folder."
+        )
+    return folders[0]
+
+
+def list_client_names(
+    data_folder: Path,
+    client_names: Optional[List[str]],
+) -> List[str]:
+    """List client-wise subdirectories under a data folder.
+
+    Parameters
+    ----------
+    data_folder:
+        `pathlib.Path` designating the main data folder.
+    client_names:
+        Optional list of clients to restrict the outputs to.
+
+    Raises
+    ------
+    ValueError
+        If `client_names` is of improper type, or lists names that cannot
+        be found under `data_folder`.
+    """
+    # Case when client names are provided: verify that they can be found.
+    if client_names:
+        if not isinstance(client_names, list):
+            raise ValueError(
+                "Please provide a valid list of client names for "
+                "argument 'client_names'."
+            )
+        if not all(
+            data_folder.joinpath(name).is_dir() for name in client_names
+        ):
+            raise ValueError(
+                "Not all provided client names could be found under "
+                f"{data_folder}"
+            )
+        return client_names.copy()
+    # Otherwise, list subdirectories of the data folder.
+ return [path.name for path in data_folder.iterdir() if path.is_dir()] diff --git a/declearn/quickrun/_run.py b/declearn/quickrun/_run.py new file mode 100644 index 0000000000000000000000000000000000000000..293cf5297c9b9c8f3236c5a1bef9d50f8c4a1089 --- /dev/null +++ b/declearn/quickrun/_run.py @@ -0,0 +1,258 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Script to quickly run a simulated FL example locally using declearn. + +The script requires to be provided with the path to a folder containing: + +* A python file in which a declearn model is instantiated (in main scope). +* A TOML file with all the elements required to configure an FL experiment. +* A data folder, structured in a specific way. + +If not provided with this, the script defaults to the MNIST example provided +by declearn in `declearn.example.quickrun`. + +The script then locally runs the FL experiment as layed out in the TOML file, +using privided model and data, and stores its result in the same folder. +""" + +import importlib +import logging +import os +from datetime import datetime +from typing import Dict, Tuple + +import fire # type: ignore + +from declearn.communication import NetworkClientConfig, NetworkServerConfig +from declearn.dataset import InMemoryDataset +from declearn.main import FederatedClient, FederatedServer +from declearn.main.config import FLOptimConfig, FLRunConfig +from declearn.model.api import Model +from declearn.quickrun._config import ( + DataSourceConfig, + ExperimentConfig, + ModelConfig, +) +from declearn.quickrun._parser import parse_data_folder +from declearn.test_utils import make_importable +from declearn.utils import ( + LOGGING_LEVEL_MAJOR, + get_logger, + run_as_processes, + set_device_policy, +) + +__all__ = ["quickrun"] + + +def get_model(folder: str, model_config: ModelConfig) -> Model: + "Return a model instance from a model config instance" + path = model_config.model_file or os.path.join(folder, "model.py") + path = os.path.abspath(path) + if not os.path.isfile(path): + raise FileNotFoundError("Model file not found: '{path}'.") + with make_importable(os.path.dirname(path)): + mod = importlib.import_module(os.path.basename(path)[:-3]) + model = getattr(mod, model_config.model_name) + return model + + +def get_checkpoint(folder: str, expe_config: ExperimentConfig) -> str: + """Return the checkpoint folder, either default or as given in config""" + if expe_config.checkpoint: + checkpoint = expe_config.checkpoint + else: + timestamp = datetime.now().strftime("%y-%m-%d_%H-%M") + checkpoint = os.path.join(folder, f"result_{timestamp}") + return checkpoint + + +def run_server( + folder: str, + network: NetworkServerConfig, + model_config: ModelConfig, + optim: FLOptimConfig, + config: FLRunConfig, + expe_config: ExperimentConfig, +) -> None: + """Routine to run a FL server, called by `run_declearn_experiment`.""" + # arguments serve modularity; pylint: 
disable=too-many-arguments + set_device_policy(gpu=False) + model = get_model(folder, model_config) + checkpoint = get_checkpoint(folder, expe_config) + checkpoint = os.path.join(checkpoint, "server") + logger = get_logger("Server", fpath=os.path.join(checkpoint, "logger.txt")) + server = FederatedServer( + model, network, optim, expe_config.metrics, checkpoint, logger + ) + server.run(config) + + +def run_client( + folder: str, + network: NetworkClientConfig, + model_config: ModelConfig, + expe_config: ExperimentConfig, + name: str, + paths: Dict[str, str], +) -> None: + """Routine to run a FL client, called by `run_declearn_experiment`.""" + # arguments serve modularity; pylint: disable=too-many-arguments + # Overwrite client name based on folder name. + network.name = name + # Make the model importable and disable GPU use. + set_device_policy(gpu=False) + _ = get_model(folder, model_config) + # Add checkpointer. + checkpoint = get_checkpoint(folder, expe_config) + checkpoint = os.path.join(checkpoint, name) + # Set up a logger: write everything to file, but filter console outputs. + logger = get_logger(name, fpath=os.path.join(checkpoint, "logs.txt")) + for handler in logger.handlers: + if isinstance(handler, logging.StreamHandler): + handler.setLevel(LOGGING_LEVEL_MAJOR) + # Wrap train and validation data as Dataset objects. + train = InMemoryDataset( + paths.get("train_data"), + target=paths.get("train_target"), + expose_classes=True, + ) + valid = InMemoryDataset( + paths.get("valid_data"), + target=paths.get("valid_target"), + ) + client = FederatedClient(network, train, valid, checkpoint, logger=logger) + client.run() + + +def get_toml_folder(config: str) -> Tuple[str, str]: + """Return the path to an experiment's folder and TOML config file. + + Determine if provided config is a file or a directory, and return: + + * The path to the TOML config file + * The path to the main folder of the experiment + """ + config = os.path.abspath(config) + if os.path.isfile(config): + toml = config + folder = os.path.dirname(config) + elif os.path.isdir(config): + folder = config + toml = f"{folder}/config.toml" + else: + raise FileNotFoundError( + f"Failed to find quickrun config file at '{config}'." + ) + return toml, folder + + +def locate_split_data(toml: str, folder: str) -> Dict: + """Attempt to find split data according to the config toml or default.""" + data_config = DataSourceConfig.from_toml(toml, False, "data") + client_dict = parse_data_folder(data_config, folder) + return client_dict + + +def server_to_client_network( + network_cfg: NetworkServerConfig, +) -> NetworkClientConfig: + "Convert server network config to client network config." + return NetworkClientConfig.from_params( + protocol=network_cfg.protocol, + server_uri=f"ws://localhost:{network_cfg.port}", + name="replaceme", + ) + + +def quickrun(config: str) -> None: + """Run a server and its clients using multiprocessing. + + The script requires to be provided with the path to a TOML file + with all the elements required to configurate an FL experiment, + or the path to a folder containing : + + * A TOML file with all the elements required to configure an FL experiment. + * A python file in which a declearn model is instantiated (in main scope). + * A data folder, structured in a specific way: + folder/ + [client_a]/ + train_data.(csv|npy|sparse|svmlight) + train_target.(csv|npy|sparse|svmlight) + valid_data.(csv|npy|sparse|svmlight) + valid_target.(csv|npy|sparse|svmlight) + [client_b]/ + ... + ... 
+ + Parameters + ---------- + config: str + Path to either a toml file or a properly formatted folder + containing the elements required to launch the experiment. + + Notes + ----- + - The data folder structure may be obtained by using the `declearn-split` + commandline entry-point, or the `declearn.dataset.split_data` util. + - The quickrun mode works by simulating a federated learning process, where + all clients operate under parallel python processes, and communicate over + the localhost using un-encrypted websockets communications. + - When run without any argument, this script/function operates on a basic + MNIST example, for demonstration purposes. + - You may refer to a more detailed MNIST example on our GitLab repository. + See the `examples/mnist_quickrun` folder. + """ + # main script; pylint: disable=too-many-locals + toml, folder = get_toml_folder(config) + # locate split data or split it if needed + client_dict = locate_split_data(toml, folder) + # Parse toml files + ntk_server_cfg = NetworkServerConfig.from_toml(toml, False, "network") + ntk_client_cfg = server_to_client_network(ntk_server_cfg) + optim_cgf = FLOptimConfig.from_toml(toml, False, "optim") + run_cfg = FLRunConfig.from_toml(toml, False, "run") + model_cfg = ModelConfig.from_toml(toml, False, "model", True) + expe_cfg = ExperimentConfig.from_toml(toml, False, "experiment", True) + # Set up a (func, args) tuple specifying the server process. + p_server = ( + run_server, + (folder, ntk_server_cfg, model_cfg, optim_cgf, run_cfg, expe_cfg), + ) + # Set up the (func, args) tuples specifying client-wise processes. + p_client = [] + for name, data_dict in client_dict.items(): + client = ( + run_client, + (folder, ntk_client_cfg, model_cfg, expe_cfg, name, data_dict), + ) + p_client.append(client) + # Run each and every process in parallel. 
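+    # `run_as_processes` returns a global success flag and per-process
+    # outputs, in which caught exceptions are wrapped as RuntimeError.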
+ success, outputs = run_as_processes(p_server, *p_client) + assert success, "The FL process failed:\n" + "\n".join( + str(exc) for exc in outputs if isinstance(exc, RuntimeError) + ) + + +def main() -> None: + """Fire-wrapped `quickrun`.""" + fire.Fire(quickrun) + + +if __name__ == "__main__": + main() diff --git a/declearn/test_utils/__init__.py b/declearn/test_utils/__init__.py index ff6ffd54c1d7166ab64222f956937683346b0167..30e7aa42484d4e737c3aaf94e6bff08cfe856839 100644 --- a/declearn/test_utils/__init__.py +++ b/declearn/test_utils/__init__.py @@ -36,7 +36,6 @@ from ._assertions import ( ) from ._gen_ssl import generate_ssl_certificates from ._imports import make_importable -from ._multiprocess import run_as_processes from ._vectors import ( FrameworkType, GradientsTestCase, diff --git a/declearn/test_utils/_vectors.py b/declearn/test_utils/_vectors.py index 7b7722c0cf6a8039932c3a84bd149c4f1b41d4d9..fdc53dd79d6161b8de6f7d0d0ed004b3954ecea0 100644 --- a/declearn/test_utils/_vectors.py +++ b/declearn/test_utils/_vectors.py @@ -22,7 +22,7 @@ import typing from typing import List, Literal, Optional, Type import numpy as np -import pkg_resources +import pkg_resources # type: ignore from numpy.typing import ArrayLike from declearn.model.api import Vector @@ -60,7 +60,9 @@ class GradientsTestCase: """ def __init__( - self, framework: FrameworkType, seed: Optional[int] = 0 + self, + framework: FrameworkType, + seed: Optional[int] = 0, ) -> None: """Instantiate the parametrized test-case.""" if framework not in list_available_frameworks(): diff --git a/declearn/utils/__init__.py b/declearn/utils/__init__.py index 67a2e6633fe6a29efaacf1ebfc74c9cdd6f422b2..451517d3055da76d332b370297643257d79159a2 100644 --- a/declearn/utils/__init__.py +++ b/declearn/utils/__init__.py @@ -80,6 +80,15 @@ Utils to access or update parameters defining a global device-selection policy. * [set_device_policy][declearn.utils.set_device_policy]: Update the current global device policy. +Logging utils +------------- +Utils to set up and configure loggers: + +* [get_logger][declearn.utils.get_logger]: + Access or create a logger, automating basic handlers' configuration. +* [LOGGING_LEVEL_MAJOR][declearn.utils.LOGGING_LEVEL_MAJOR]: + Custom "MAJOR" severity level, between stdlib "INFO" and "WARNING". + Miscellaneous ------------- @@ -89,8 +98,8 @@ Miscellaneous Automatically build a dataclass matching a function's signature. * [dataclass_from_init][declearn.utils.dataclass_from_init]: Automatically build a dataclass matching a class's init signature. -* [get_logger][declearn.utils.get_logger]: - Access or create a logger, automating basic handlers' configuration. +* [run_as_processes][declearn.utils.run_as_processes]: + Run coroutines concurrently within individual processes. """ from ._dataclass import ( @@ -110,8 +119,10 @@ from ._json import ( json_unpack, ) from ._logging import ( + LOGGING_LEVEL_MAJOR, get_logger, ) +from ._multiprocess import run_as_processes from ._numpy import ( deserialize_numpy, serialize_numpy, diff --git a/declearn/utils/_logging.py b/declearn/utils/_logging.py index b1c926918541db206dc71f3d7b1a935177daf596..07f47627eb4bd0de5a2a41cb4084536eff09978f 100644 --- a/declearn/utils/_logging.py +++ b/declearn/utils/_logging.py @@ -22,6 +22,18 @@ import os from typing import Optional +__all__ = [ + "get_logger", + "LOGGING_LEVEL_MAJOR", +] + + +# Add a logging level between INFO and WARNING. 
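+# (With the standard library's defaults, INFO is 20 and WARNING is 30,
+# so MAJOR resolves to 25.)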
+LOGGING_LEVEL_MAJOR = (logging.WARNING + logging.INFO) // 2 +"""Custom "MAJOR" severity level, between stdlib "INFO" and "WARNING".""" +logging.addLevelName(level=LOGGING_LEVEL_MAJOR, levelName="MAJOR") + + DEFAULT_FORMAT = "%(asctime)s:%(name)s:%(levelname)s: %(message)s" diff --git a/declearn/test_utils/_multiprocess.py b/declearn/utils/_multiprocess.py similarity index 82% rename from declearn/test_utils/_multiprocess.py rename to declearn/utils/_multiprocess.py index 5214e7bf556d80aac0cf08e98e04694a9ece4482..a39ac540d7d1001e466d3560ce3026edee38fca9 100644 --- a/declearn/test_utils/_multiprocess.py +++ b/declearn/utils/_multiprocess.py @@ -17,9 +17,11 @@ """Utils to run concurrent routines parallelly using multiprocessing.""" +import functools import multiprocessing as mp import sys import traceback +from queue import Queue from typing import Any, Callable, Dict, List, Optional, Tuple, Union __all__ = [ @@ -60,7 +62,9 @@ def run_as_processes( indicates that the process was interrupted while running. """ # Wrap routines into named processes and set up exceptions catching. - queue = mp.Queue() # type: ignore # mp.Queue[Union[Any, RuntimeError]] + queue = ( + mp.Manager().Queue() + ) # type: Queue[Tuple[str, Union[Any, RuntimeError]]] names = [] # type: List[str] count = {} # type: Dict[str, int] processes = [] # type: List[mp.Process] @@ -100,28 +104,32 @@ def run_as_processes( def add_exception_catching( func: Callable[..., Any], - queue: mp.Queue, + queue: Queue, name: Optional[str] = None, ) -> Callable[..., Any]: """Wrap a function to catch exceptions and put them in a Queue.""" if not name: name = func.__name__ - def wrapped(*args: Any, **kwargs: Any) -> Any: - """Call the wrapped function and catch exceptions or results.""" - nonlocal name, queue + return functools.partial(wrapped, func=func, queue=queue, name=name) - try: - result = func(*args, **kwargs) - except Exception as exc: # pylint: disable=broad-exception-caught - err = RuntimeError( - f"Exception of type {type(exc)} occurred:\n" - "".join(traceback.format_exception(type(exc), exc, tb=None)) - ) # future: `traceback.format_exception(exc)` (py >=3.10) - queue.put((name, err)) - sys.exit(1) - else: - queue.put((name, result)) - sys.exit(0) - return wrapped +def wrapped( + *args: Any, + func: Callable[..., Any], + queue: Queue, + name: str, +) -> Any: + """Call the wrapped function and catch exceptions or results.""" + try: + result = func(*args) + except Exception as exc: # pylint: disable=broad-exception-caught + err = RuntimeError( + f"Exception of type {type(exc)} occurred:\n" + "".join(traceback.format_exception(type(exc), exc, tb=None)) + ) # future: `traceback.format_exception(exc)` (py >=3.10) + queue.put((name, err)) + sys.exit(1) + else: + queue.put((name, result)) + sys.exit(0) diff --git a/declearn/utils/_toml_config.py b/declearn/utils/_toml_config.py index cfe96c2bb21b790db7eddab3f8175f9c21d2656e..012a2255d09a48b58314e5dc2ddc789255b4754b 100644 --- a/declearn/utils/_toml_config.py +++ b/declearn/utils/_toml_config.py @@ -296,6 +296,9 @@ class TomlConfig: def from_toml( cls, path: str, + warn_user: bool = True, + use_section: Optional[str] = None, + section_fail_ok: bool = False, ) -> Self: """Parse a structured configuration from a TOML file. @@ -315,6 +318,17 @@ class TomlConfig: path: str Path to a TOML configuration file, that provides with the hyper-parameters making up for the FL "run" configuration. 
+ warn_user: bool, default=True + Boolean indicating whether to raise a warning when some + fields are unused. Useful for cases where unused fields are + expected, e.g. in declearn-quickrun mode. + use_section: optional(str), default=None + If not None, points to a specific section of the TOML that + should be used, rather than the whole file. Useful to parse + orchestrating TOML files, e.g. in declearn-quickrun mode. + section_fail_ok: bool, default=False + If True, allow the section specified in use_section to be + missing from the TOML file without raising an Error. Raises ------ @@ -338,6 +352,12 @@ class TomlConfig: "Failed to parse the TOML configuration file." ) from exc # Look for expected config sections in the parsed TOML file. + if isinstance(use_section, str): + try: + config = config[use_section] + except KeyError as exc: + if not section_fail_ok: + raise KeyError("Specified section not found") from exc params = {} # type: Dict[str, Any] for field in dataclasses.fields(cls): # Case when the section is provided: set it up for parsing. @@ -353,10 +373,11 @@ class TomlConfig: f"file: '{field.name}'." ) # Warn about remaining (unused) config sections. - for name in config: - warnings.warn( - f"Unsupported section encountered in {path} TOML file: " - f"'{name}'. This section will be ignored." - ) + if warn_user: + for name in config: + warnings.warn( + f"Unsupported section encountered in {path} TOML file: " + f"'{name}'. This section will be ignored." + ) # Finally, instantiate the FLConfig container. return cls.from_params(**params) diff --git a/docs/quickstart.md b/docs/quickstart.md index 72b041025376b0a29cc152db0209fee50ca9e2c2..be711050f18ecd067b2898ade1bdca86d082b035 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -1,38 +1,105 @@ # Quickstart -This section provides with demonstration code on how to run a simple federated -learning task using declearn, that requires minimal adjustments to be run for -real (mainly, to provide with a valid network configuration and actual data). +**Here's where to start if you want to quickly understand what `declearn` +does**. This tutorial exepects a basic understanding of +[federated learning](https://en.wikipedia.org/wiki/Federated_learning). -You may find even more concrete examples on our gitlab repository -[here](https://gitlab.inria.fr/magnet/declearn/declearn2/examples). -The Heart UCI example may notably be run as-is, either locally or on a -real-life network with minimal command-line parametrization. +We show different ways to use `declearn` on a well-known example, the +[MNIST dataset](http://yann.lecun.com/exdb/mnist/) +(see [section 1](#1-federated-learning-on-the-mnist-dataset)). +We then look at how to use declearn on your own problem +(see [section 2](#2-federated-learning-on-your-own-dataset)). -## Setting +## 1. Federated learning on the MNIST dataset -Here is a quickstart example on how to set up a federated learning process -to learn a LASSO logistic regression model (using a scikit-learn backend) -using pre-processed data, formatted as csv files with a "label" column, +**We are going to train a common model between three simulated clients on the +classic [MNIST dataset](http://yann.lecun.com/exdb/mnist/)**. The input of the +model is a set of images of handwritten digits, and the model needs to +determine to which number between $0$ and $9$ each image corresponds. +We show two ways to use `declearn` on this problem. + +### 1.1. 
Quickrun mode + +**The quickrun mode is the simplest way to simulate a federated learning +process on a single machine with `declearn`**. It does not require to +understand the details of the `declearn` implementation. It requires a basic +understanding of federated learning. + +--- +**To test this on the MNIST example**, you can follow along the jupyter +notebook provided +[here](https://gitlab.inria.fr/magnet/declearn/declearn2/-/blob/develop/examples/mnist_quickrun/mnist.ipynb), +which we recommend running on [Google Colab](https://colab.research.google.com) +to skip on setting up git, python, a virtual environment, etc. + +You may find a (possibly not entirely up-to-date) pre-hosted version of that +notebook +[here](https://colab.research.google.com/drive/13sBDOQeorI6dfziSoyRpU4q4iGuESIPo?usp=sharing). + +--- + +**If you want to run this locally**, the detailed notebook can be boiled down +to five shell commands. Set up a dedicated `conda` or `venv` environment, and +run: + +```bash +git clone https://gitlab.inria.fr/magnet/declearn/declearn2 && +cd declearn2 && +pip install .[tensorflow,websockets] && +declearn-split --folder "examples/mnist_quickrun" && +declearn-quickrun --config "examples/mnist_quickrun/config.toml" +``` + +**To better understand the details** of what happens under the hood you can +look at what the key element of the declearn process are in +[section 1.2.](#12-python-script). To understand how to use the quickrun mode +in practice, see [section 2.1.](#21-quickrun-on-your-problem). + +### 1.2. Python script + +#### MNIST + +**The quickrun mode abstracts away a lot of important elements** of the +process, and is only designed to simulate an FL experiment: the clients all +run on the same machine. In real life deployment, a `declearn` experiment is +built in python. + +--- +**To see what this looks like in practice**, you can head to the all-python +MNIST example `examples/mnist/` in the `declearn` repository, which you can +access [here](https://gitlab.inria.fr/magnet/declearn/declearn2/-/tree/develop/examples/mnist/). + +This version of the example may either be used to run a simulated process on +a single computer, or to deploy the example over a real-life network. + +--- + +#### Stylized structure + +At a very high-level, declearn is structured around two key objects. The +`Clients` hold the data and perform calculations locally. The `Server` owns +the model and the global training process. They communicate over a `network`, +the central endpoint of which is hosted by the `Server`. + +We provide below a stylized view of the main elements of the `Server` and +`Client` scripts. For more details, you can look at the hands-on usage +[section](user-guide/usage.md) of the documentation. + +We show what a `Client` and `Server` script can look like on a hypothetical +LASSO logistic regression model, using a scikit-learn backend and +pre-processed data. The data is in csv files with a "label" column, where each client has two files: one for training, the other for validation. 
 Here, the code uses:
 
-- standard FedAvg strategy (SGD for local steps, averaging of updates weighted
-  by clients' training dataset size, no modifications of server-side updates)
-- 10 rounds of training, with 5 local epochs performed at each round and
-  128-samples batch size
-- at least 1 and at most 3 clients, awaited for 180 seconds by the server
-- network communications using gRPC, on host "example.com" and port 8888
-
-Note that this example code may easily be adjusted to suit use cases, using
-other types of models, alternative federated learning algorithms and/or
-modifying the communication, training and validation hyper-parameters.
-Please refer to the [Hands-on usage](./user-guide/usage.md) section for a more
-detailed and general description of how to set up a federated learning
-task and process with declearn.
+* **Aggregation**: the standard `FedAvg` strategy.
+* **Optimizer**: standard SGD for both client and server.
+* **Training**: 10 rounds of training, with 5 local epochs performed at each
+  round and a batch size of 128 samples. At least 1 and at most 3 clients,
+  awaited for at most 180 seconds by the server.
+* **Network**: communications using `websockets`.
 
-## Server-side script
+The server-side script:
 
 ```python
 import declearn
 
@@ -41,7 +108,7 @@ model = declearn.model.sklearn.SklearnSGDModel.from_parameters(
     kind="classifier", loss="log_loss", penalty="l1"
 )
 netwk = declearn.communication.NetworkServerConfig(
-    protocol="grpc", host="example.com", port=8888,
+    protocol="websockets", host="127.0.0.1", port=8888,
     certificate="path/to/certificate.pem",
     private_key="path/to/private_key.pem"
 )
@@ -60,14 +127,14 @@ config = declearn.main.config.FLRunConfig.from_params(
 server.run(config)
 ```
 
-## Client-side script
+The client-side script:
 
 ```python
 import declearn
 
 netwk = declearn.communication.NetworkClientConfig(
-    protocol="grpc",
-    server_uri="example.com:8888",
+    protocol="websockets",
+    server_uri="wss://127.0.0.1:8888",
     name="client_name",
     certificate="path/to/root_ca.pem"
 )
@@ -82,22 +149,176 @@ client = declearn.main.FederatedClient(
 client.run()
 ```
 
-## Simulating this experiment locally
-
-To simulate the previous experiment on a single computer, you may set up
-network communications to go through the localhost, and resort to one of
-two possibilities:
-
-1. Run the server and client-wise scripts parallelly, e.g. in distinct
-   terminals.
-2. Use declearn-provided tools to run the server and clients' routines
-   concurrently using multiprocessing.
-
-While technically similar (both solutions resolve on isolating the agents'
-routines in separate python processes that communicate over the localhost),
-the second solution offers more practicality in terms of offering a single
-entrypoint for your experiment, and optionally automatically stopping any
-running agent in case one of the other has failed.
-To find out more about this solution, please have a look at the Heart UCI
-example [implemented here](https://gitlab.inria.fr/magnet/declearn/declearn2\
--/tree/develop/examples/heart-uci).
+## 2. Federated learning on your own dataset
+
+### 2.1. Quickrun on your problem
+
+Using the `declearn-quickrun` mode requires a configuration file, some data,
+and a model file:
+
+* A TOML file, to store your experiment configurations.
+  In the MNIST example: `examples/mnist_quickrun/config.toml`.
+* A folder with your data, split by client.
+  In the MNIST example: `examples/mnist_quickrun/data_iid`
+  (after running `declearn-split --folder "examples/mnist_quickrun"`).
+* A python model file, to declare your model wrapped in a `declearn` object.
+  In the MNIST example: `examples/mnist_quickrun/model.py`.
+
+#### The TOML file
+
+TOML is a minimal, human-readable configuration file format.
+We use it to store all the configurations of an FL experiment.
+The TOML is parsed by python as a dictionary with each `[header]`
+as a key. For more details, see the [TOML doc](https://toml.io/en/).
+
+This file is your main entry point to everything else.
+The absolute path to this file should be given as an argument in:
+
+```bash
+declearn-quickrun --config <path_to_toml_file>
+```
+
+The TOML file has six sections, some of which are optional. Note that the order
+does not matter, and that we give illustrative, not necessarily functional
+examples.
+
+**`[network]`: Network configuration** used by both client and server,
+most notably the port, host, and SSL certificates. An example:
+
+```python
+[network]
+    protocol = "websockets" # Protocol used; to keep things simple, use websockets
+    host = "127.0.0.1" # Address used; works as-is on most setups
+    port = 8765 # Port used; works as-is on most setups
+```
+
+This section is parsed as the initialization arguments to the `NetworkServer`
+class. Check its [documentation][declearn.communication.api.NetworkServer]
+to see all available fields. Note it is also used to initialize a
+[`NetworkClient`][declearn.communication.api.NetworkClient], mirroring the
+server.
+
+**`[data]`: Where to find your data**. This is particularly useful if you have
+split your data yourself, using custom names for files and folders. An example:
+
+```python
+[data]
+    data_folder = "./custom/data_custom" # Your main data folder
+    client_names = ["client_a", "client_b", "client_c"] # The names of your client folders
+
+    [data.dataset_names] # The names of train and test datasets
+    train_data = "cifar_train"
+    train_target = "label_train"
+    valid_data = "cifar_valid"
+    valid_target = "label_valid"
+```
+
+This section is parsed as the fields of a `DataSourceConfig` dataclass.
+Check its [documentation][declearn.quickrun.DataSourceConfig] to see
+all available fields. This `DataSourceConfig` is then parsed by the
+[`parse_data_folder`][declearn.quickrun.parse_data_folder] function.
+
+**`[optim]`: Optimization options** for both client and server, with
+three distinct sub-sections: the server-side aggregator (i) and optimizer (ii),
+and the client optimizer (iii). An example:
+
+```python
+[optim]
+    aggregator = "averaging" # The basic server aggregation strategy
+
+    [optim.server_opt] # Server optimization strategy
+    lrate = 1.0 # Server learning rate
+
+    [optim.client_opt] # Client optimization strategy
+    lrate = 0.001 # Client learning rate
+    modules = [["momentum", {"beta" = 0.9}]] # List of optimizer modules used
+    regularizers = [["lasso", {alpha = 0.1}]] # List of regularizer modules
+```
+
+This section is parsed as the fields of a `FLOptimConfig` dataclass. Check its
+[documentation][declearn.main.config.FLOptimConfig] to see more details on
+these three sub-sections. For more details on available fields within those
+subsections, you can navigate inside the documentation of the
+[`Aggregator`][declearn.aggregator.Aggregator] and
+[`Optimizer`][declearn.optimizer.Optimizer] classes.
+
+**`[run]`: Training process options** for both client and server. Most notably,
+includes the number of rounds as well as the registration, training, and
+evaluation parameters. An example:
+
+```python
+[run]
+    rounds = 10 # Number of overall training rounds
+
+    [run.register] # Client registration options
+    min_clients = 1 # The minimum number of clients that need to connect
+    max_clients = 6 # The maximum number of clients that can connect
+    timeout = 5 # How long to wait for clients, in seconds
+
+    [run.training] # Client training procedure
+    n_epoch = 1 # Number of local epochs
+    batch_size = 48 # Training batch size
+    drop_remainder = false # Whether to drop the last training examples
+
+    [run.evaluate]
+    batch_size = 128 # Evaluation batch size
+```
+
+This section is parsed as the fields of a `FLRunConfig` dataclass. Check its
+[documentation][declearn.main.config.FLRunConfig] to see more details on the
+sub-sections. For more details on available fields within those subsections,
+you can navigate inside the documentation of `FLRunConfig` to the relevant
+dataclass, for instance [`TrainingConfig`][declearn.main.config.TrainingConfig].
+
+**`[model]`: Optional section**, where to find the model. An example:
+
+```python
+[model]
+# The path to the model file
+model_file = "./custom/model_custom.py"
+# The name of your model object, if different from "model"
+model_name = "MyCustomModel"
+```
+
+This section is parsed as the fields of a `ModelConfig` dataclass. Check its
+[documentation][declearn.quickrun.ModelConfig] to see all available fields.
+
+**`[experiment]`: Optional section**, what to report during the experiment and
+where to report it. An example:
+
+```python
+[experiment]
+metrics=[["multi-classif",{labels = [0,1,2,3,4,5,6,7,8,9]}]] # Accuracy metric
+checkpoint = "./result_custom" # Custom location for results
+```
+
+This section is parsed as the fields of an `ExperimentConfig` dataclass.
+Check its [documentation][declearn.quickrun.ExperimentConfig] to see all
+available fields.
+
+#### The data
+
+Your data, in a standard tabular format, split by client. Within each client
+folder, we expect four files: training data and labels, validation data and
+labels.
+
+If your data is not already split by client, we provide an experimental
+data splitting utility. It currently has a limited scope, only dealing
+with classification tasks, excluding multi-label. You can call it using
+`declearn-split --folder <path_to_original_data>`. For more details, refer to
+the [documentation][declearn.dataset.split_data].
+
+#### The model file
+
+The model file should just contain the model you built for
+your data, e.g. a `torch` model, wrapped in a declearn object.
+See `examples/mnist_quickrun/model.py` for an example.
+
+The wrapped model should be named "model" by default. If you use any other
+name, you have to specify it in the TOML file, as demonstrated in
+`./custom/config_custom.toml`.
+
+### 2.2. Using declearn's full capabilities
+
+To upgrade your experimental setting beyond the `quickrun` mode, you may move
+on to the hands-on usage [section](user-guide/usage.md) of the documentation.
diff --git a/docs/setup.md b/docs/setup.md
index 547f44b1126135e50cdf39724d0c2d178a623cdf..979ef90c9dfc77757bb0e2c99b51cca58f716b82 100644
--- a/docs/setup.md
+++ b/docs/setup.md
@@ -107,3 +107,6 @@ pip install declearn[all,tests]  # install all extra and testing dependencies
   package, and then to manually install the dependencies listed in the
   `pyproject.toml` file, using `conda install` rather than `pip install`
   whenever it is possible.
+- On some systems, the square brackets used in our pip install are not properly
+  parsed.
Try replacing `[` by `\[` and `]` by `\]`, or putting the instruction + between quotes (`pip install "declearn[...]"`). diff --git a/examples/heart-uci/client.py b/examples/heart-uci/client.py index 1d0c8c1580edea70407f65dff1ef1288dcc1b039..3713b310909f0a0eb3082bdc522e8b801563cc15 100644 --- a/examples/heart-uci/client.py +++ b/examples/heart-uci/client.py @@ -18,26 +18,22 @@ """Script to run a federated client on the heart-disease example.""" import os +from typing import Literal import numpy as np -import pandas as pd # type: ignore + from declearn.communication import NetworkClientConfig from declearn.dataset import InMemoryDataset +from declearn.dataset.examples import load_heart_uci from declearn.main import FederatedClient from declearn.test_utils import make_importable, setup_client_argparse FILEDIR = os.path.dirname(__file__) -# Perform local imports. -# pylint: disable=wrong-import-order, wrong-import-position -with make_importable(FILEDIR): - from data import get_data -# pylint: enable=wrong-import-order, wrong-import-position - def run_client( - name: str, + name: Literal["cleveland", "hungarian", "switzerland", "va"], ca_cert: str, protocol: str = "websockets", serv_uri: str = "wss://localhost:8765", @@ -59,23 +55,20 @@ def run_client( # (1-2) Interface training and optional validation data. - # Load and randomly split the dataset. - path = os.path.join(FILEDIR, f"data/{name}.csv") - if not os.path.isfile(path): - get_data(os.path.join(FILEDIR, "data"), [name]) - data = pd.read_csv(path) + # Load and randomly split the dataset. Note: target is a str (column name). + data, target = load_heart_uci(name, folder=os.path.join(FILEDIR, "data")) data = data.loc[np.random.permutation(data.index)] n_tr = round(len(data) * 0.8) # 80% train, 20% valid # Wrap train and validation data as Dataset objects. train = InMemoryDataset( data=data.iloc[:n_tr], - target="num", + target=target, expose_classes=True, # share unique target labels with server ) valid = InMemoryDataset( data=data.iloc[n_tr:], - target="num", + target=target, ) # (3) Define network communication parameters. diff --git a/examples/heart-uci/data.py b/examples/heart-uci/data.py index 20ed3d6cb63264d99279ac3425bf0da801232324..d866d76707651fdb1a8b703fe66115c085e47b93 100644 --- a/examples/heart-uci/data.py +++ b/examples/heart-uci/data.py @@ -19,67 +19,12 @@ import argparse import os -from typing import Collection -import pandas as pd +from declearn.dataset.examples import load_heart_uci -NAMES = ("cleveland", "hungarian", "switzerland", "va") - -COLNAMES = [ - "age", - "sex", - "cp", - "trestbps", - "chol", - "fbs", - "restecg", - "thalach", - "exang", - "oldpeak", - "slope", - "ca", - "thal", - "num", -] DATADIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data") - - -def get_data( - folder: str = DATADIR, - names: Collection[str] = NAMES, -) -> None: - """Download and process the UCI heart disease dataset. - - Arguments - --------- - folder: str - Path to the folder where to write output csv files. - names: list[str] - Names of centers, the dataset from which to download, - pre-process and export as csv files. - """ - for name in names: - print(f"Downloading data from center {name}:") - url = ( - "https://archive.ics.uci.edu/ml/machine-learning-databases/" - f"heart-disease/processed.{name}.data" - ) - print(url) - # Download the dataset. - df = pd.read_csv(url, header=None, na_values="?") - df.columns = COLNAMES - # Drop unused columns and rows with missing values. 
- df.drop(columns=["ca", "chol", "fbs", "slope", "thal"], inplace=True) - df.dropna(inplace=True) - # Normalize quantitative variables. - for col in ("age", "trestbps", "thalach", "oldpeak"): - df[col] = (df[col] - df[col].mean()) / df[col].std() - # Binarize the target variable. - df["num"] = (df["num"] > 0).astype(int) - # Export the resulting dataset to a csv file. - os.makedirs(folder, exist_ok=True) - df.to_csv(f"{folder}/{name}.csv", index=False) +NAMES = ("cleveland", "hungarian", "switzerland", "va") # Code executed when the script is called directly. @@ -101,4 +46,5 @@ if __name__ == "__main__": ) args = parser.parse_args() # Download and pre-process the selected dataset(s). - get_data(folder=args.folder, names=args.names) + for name in args.names: + load_heart_uci(name=name, folder=args.folder) diff --git a/examples/heart-uci/run.py b/examples/heart-uci/run.py index 4a5ac9ade7d8ef050d332c23edcfa4c76cb8401f..1a35c362f7aaa86d194736e25eba54f22510eb32 100644 --- a/examples/heart-uci/run.py +++ b/examples/heart-uci/run.py @@ -20,11 +20,8 @@ import os import tempfile -from declearn.test_utils import ( - generate_ssl_certificates, - make_importable, - run_as_processes, -) +from declearn.test_utils import generate_ssl_certificates, make_importable +from declearn.utils import run_as_processes # Perform local imports. # pylint: disable=wrong-import-position, wrong-import-order diff --git a/examples/mnist/client.py b/examples/mnist/client.py new file mode 100644 index 0000000000000000000000000000000000000000..7dc57bc8f4d3feec529b8dd5bb0dd827872275ff --- /dev/null +++ b/examples/mnist/client.py @@ -0,0 +1,131 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Script to run a federated client on the heart-disease example.""" + +import datetime +import logging +import os + +import fire # type: ignore + +import declearn +import declearn.model.tensorflow + + +FILEDIR = os.path.dirname(__file__) + + +def run_client( + client_name: str, + ca_cert: str, + data_folder: str, + protocol: str = "websockets", + serv_uri: str = "wss://localhost:8765", + verbose: bool = True, +) -> None: + """Instantiate and run a given client. + + Parameters + --------- + client_name: str + Name of the client (i.e. center data from which to use). + ca_cert: str + Path to the certificate authority file that was used to + sign the server's SSL certificate. + data_folder: str + The parent folder of this client's data + protocol: str, default="websockets" + Name of the communication protocol to use. + serv_uri: str, default="wss://localhost:8765" + URI of the server to which to connect. + verbose: + Whether to log everything to the console, or filter out most non-error + information. 
+ """ + + ### Optional: some convenience settings + + # Set CPU as device + declearn.utils.set_device_policy(gpu=False) + + # Set up logger and checkpointer + stamp = datetime.datetime.now().strftime("%y-%m-%d_%H-%M") + checkpoint = os.path.join(FILEDIR, f"result_{stamp}", client_name) + logger = declearn.utils.get_logger( + name=client_name, + fpath=os.path.join(checkpoint, "logs.txt"), + ) + + # Reduce logger verbosity + if not verbose: + for handler in logger.handlers: + if isinstance(handler, logging.StreamHandler): + handler.setLevel(declearn.utils.LOGGING_LEVEL_MAJOR) + + ### (1-2) Interface training and optional validation data. + + # Target the proper dataset (specific to our MNIST setup). + data_folder = os.path.join(FILEDIR, data_folder, client_name) + + # Interface the data through the generic `InMemoryDataset` class. + train = declearn.dataset.InMemoryDataset( + os.path.join(data_folder, "train_data.npy"), + os.path.join(data_folder, "train_target.npy"), + ) + valid = declearn.dataset.InMemoryDataset( + os.path.join(data_folder, "valid_data.npy"), + os.path.join(data_folder, "valid_target.npy"), + ) + + ### (3) Define network communication parameters. + + # Here, use websockets protocol on localhost:8765, + # with SSL encryption. + network = declearn.communication.build_client( + protocol=protocol, + server_uri=serv_uri, + name=client_name, + certificate=ca_cert, + ) + + ### (4) Run any necessary import statement. + # We imported `import declearn.model.tensorflow` + + ### (5) Instantiate a FederatedClient and run it. + + client = declearn.main.FederatedClient( + netwk=network, + train_data=train, + valid_data=valid, + checkpoint=checkpoint, + logger=logger, + ) + client.run() + + +# This part should not be altered: it provides with an argument parser +# for `python client.py`. + + +def main(): + "Fire-wrapped `run_client`." + fire.Fire(run_client) + + +if __name__ == "__main__": + main() diff --git a/examples/mnist/gen_ssl.py b/examples/mnist/gen_ssl.py new file mode 100644 index 0000000000000000000000000000000000000000..94f81e982c85f0813abf342337387922482b0eeb --- /dev/null +++ b/examples/mnist/gen_ssl.py @@ -0,0 +1,27 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Script to generate self-signed SSL certificates for the demo.""" + +import os + +from declearn.test_utils import generate_ssl_certificates + + +if __name__ == "__main__": + FILEDIR = os.path.dirname(os.path.abspath(__file__)) + generate_ssl_certificates(FILEDIR) diff --git a/examples/mnist/readme.md b/examples/mnist/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..bdff64b17812e116aac6836a5220b1b7e55c8218 --- /dev/null +++ b/examples/mnist/readme.md @@ -0,0 +1,133 @@ +# Demo training task : MNIST + +## Overview + +**We are going to train a common model between three simulated clients on the +classic [MNIST dataset](http://yann.lecun.com/exdb/mnist/)**. 
The input of the +model is a set of images of handwritten digits, and the model needs to +determine to which digit between $0$ and $9$ each image corresponds. + +## Setup + +To be able to experiment with this tutorial: + +* Clone the declearn repo (you may specify a given release branch or tag): + +```bash +git clone git@gitlab.inria.fr:magnet/declearn/declearn2.git declearn +``` + +* Create a dedicated virtual environment. +* Install declearn in it from the local repo: + +```bash +cd declearn && pip install .[websockets,tensorflow] && cd .. +``` + +In an FL experiment, we consider your data as a given. So before running +the experiment below, download and split the MNIST data using: + +```bash +declearn-split --folder "examples/mnist" --n_shards 3 +``` + +You may add `--seed <some_number>` if you want to ensure reproducibility. + +## Contents + +This script runs a FL experiment using MNIST. The folder is structured +the following way: + +``` +mnist/ +│ client.py - set up and launch a federated-learning client +│ gen_ssl.py - generate self-signed ssl certificates +│ run.py - launch both the server and clients in a single session +│ server.py - set up and launch a federated-learning server +└─── data - data split by client, created with the `split_data` util +└─── results - saved results from training procedure +``` + +## Run training routine + +The simplest way to run the demo is to run it locally, using multiprocessing. +For something closer to real life implementation, we also show a way to run +the demo from different terminals or machines. + +### Locally, for testing and experimentation + +**To simply run the demo**, use the bash command below. You can follow along +the code in the `hands-on` section of the package documentation. For more +details on what running the federated learning processes imply, see the last +section. + +```bash +cd declearn/examples/mnist/ +python run.py # note: python declearn/examples/mnist/run.py works as well +``` + +The `run.py` scripts collects the server and client routines defined under +the `server.py` and `client.py` scripts, and runs them concurrently under +a single python session using multiprocessing. + +This is the easiest way to launch the demo, e.g. to see the effects of +tweaking some learning parameters. + +### On separate terminals or machines + +**To run the examples from different terminals or machines**, +we first ensure data is appropriately distributed between machines, +and the machines can communicate over network using SSL-encrypted +communications. We give the code to simulate this on a single machine. +We then sequentially run the server then the clients on separate terminals. + +1. **Set up SSL certificates**:<br/> + Start by creating a signed SSL certificate for the server and sharing the + CA file with each and every clients. The CA may be self-signed. + + When testing locally, execute the `gen_ssl.py` script, to create a + self-signed root CA and an SSL certificate for "localhost": + + ```bash + python gen_ssl.py + ``` + + Note that in real-life applications, one would most likely use certificates + certificates signed by a trusted certificate authority instead. + Alternatively, `declearn.test_utils.gen_ssl_certificates` may be used to + generate a self-signed CA and a signed certificate for a given domain name + or IP address. + +2. **Run the server**:<br/> + Open a terminal and launch the server script for 1 to 4 clients, + specifying the path to the SSL certificate and private key files, + and network parameters. 
By default, things will run on the local
+   host, looking for `gen_ssl.py`-created PEM files.
+
+   E.g., to use 2 clients:
+
+   ```bash
+   python server.py 2  # use --help for details on network and SSL options
+   ```
+
+3. **Run each client**:<br/>
+   Open a new terminal and launch the client script, specifying the name of
+   one of the client-wise data subfolders, and optionally the path to the CA
+   file and network parameters. By default, things will run on the local
+   host, looking for a `gen_ssl.py`-created CA PEM file.
+
+   E.g., to launch the first client (which uses the `client_0` data shard):
+
+   ```bash
+   python client.py client_0  # use --help for details on other options
+   ```
+
+Note that the server should be launched before the clients, otherwise the
+latter might fail to connect which would cause the script to terminate. A
+few seconds' delay is tolerable as clients will make multiple connection
+attempts prior to failing.
+
+**To run the example in a real-life setting**, follow the instructions from
+this section, after having generated and shared the appropriate PEM files to
+set up SSL-encryption, and using additional script parameters to specify the
+network host and port to use.
diff --git a/examples/mnist/run.py b/examples/mnist/run.py
new file mode 100644
index 0000000000000000000000000000000000000000..35edefc291850c4938566a0dc2dbd40d194e0a27
--- /dev/null
+++ b/examples/mnist/run.py
@@ -0,0 +1,72 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Demonstration script running the MNIST example using multiprocessing."""
+
+import glob
+import os
+import tempfile
+
+import fire  # type: ignore
+
+from declearn.test_utils import generate_ssl_certificates, make_importable
+from declearn.utils import run_as_processes
+
+# Perform local imports.
+# pylint: disable=wrong-import-position, wrong-import-order
+with make_importable(os.path.dirname(__file__)):
+    from client import run_client
+    from server import run_server
+# pylint: enable=wrong-import-position, wrong-import-order
+
+FILEDIR = os.path.join(os.path.dirname(__file__))
+DATADIR = glob.glob(f"{FILEDIR}/data*")[0]
+
+
+def run_demo(nb_clients: int = 3, data_folder: str = DATADIR) -> None:
+    """Run a server and its clients using multiprocessing.
+
+    Parameters
+    ----------
+    nb_clients: int
+        Number of clients to run.
+    data_folder: str
+        Relative path to the folder holding clients' data.
+    """
+    # Use a temporary directory for single-use self-signed SSL files.
+    with tempfile.TemporaryDirectory() as folder:
+        # Generate self-signed SSL certificates and gather their paths.
+        ca_cert, sv_cert, sv_pkey = generate_ssl_certificates(folder)
+        # Specify the server and client routines that need executing.
+        server = (run_server, (nb_clients, sv_cert, sv_pkey))
+        clients = [
+            (run_client, (f"client_{idx}", ca_cert, data_folder))
+            for idx in range(nb_clients)
+        ]
+        # Run routines in isolated processes. Raise if any failed.
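+        # Note: the server and clients communicate over localhost, using
+        # the self-signed SSL certificates generated above.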
+        success, outp = run_as_processes(server, *clients)
+    if not success:
+        raise RuntimeError(
+            "Something went wrong during the demo. Exceptions caught:\n"
+            + "\n".join(str(e) for e in outp if isinstance(e, RuntimeError))
+        )
+
+
+if __name__ == "__main__":
+    fire.Fire(run_demo)
diff --git a/examples/mnist/server.py b/examples/mnist/server.py
new file mode 100644
index 0000000000000000000000000000000000000000..dda597e6d95ee93adac78495966f353573586378
--- /dev/null
+++ b/examples/mnist/server.py
@@ -0,0 +1,190 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Script to run a federated server on the MNIST example."""
+
+import datetime
+import os
+
+import fire  # type: ignore
+import tensorflow as tf  # type: ignore
+
+import declearn
+
+
+FILEDIR = os.path.dirname(os.path.abspath(__file__))
+DEFAULT_CERT = os.path.join(FILEDIR, "server-cert.pem")
+DEFAULT_PKEY = os.path.join(FILEDIR, "server-pkey.pem")
+
+
+def run_server(
+    nb_clients: int,
+    certificate: str = DEFAULT_CERT,
+    private_key: str = DEFAULT_PKEY,
+    protocol: str = "websockets",
+    host: str = "localhost",
+    port: int = 8765,
+) -> None:
+    """Instantiate and run the orchestrating server.
+
+    Arguments
+    ---------
+    nb_clients: int
+        Exact number of clients used in this example.
+    certificate: str
+        Path to the (self-signed) SSL certificate to use.
+    private_key: str
+        Path to the associated private-key to use.
+    protocol: str, default="websockets"
+        Name of the communication protocol to use.
+    host: str, default="localhost"
+        Hostname or IP address on which to serve.
+    port: int, default=8765
+        Communication port on which to serve.
+    """
+
+    ### Optional: some convenience settings
+
+    # Set CPU as device
+    declearn.utils.set_device_policy(gpu=False)
+
+    # Set up metrics suitable for MNIST.
+    metrics = declearn.metrics.MetricSet(
+        [
+            declearn.metrics.MulticlassAccuracyPrecisionRecall(
+                labels=range(10)
+            ),
+        ]
+    )
+
+    # Set up checkpointing and logging.
+    stamp = datetime.datetime.now().strftime("%y-%m-%d_%H-%M")
+    checkpoint = os.path.join(FILEDIR, f"result_{stamp}", "server")
+    # Set up a logger, records from which will go to a file.
+    logger = declearn.utils.get_logger(
+        name="Server",
+        fpath=os.path.join(checkpoint, "logs.txt"),
+    )
+
+    ### (1) Define a model
+
+    # Here we define a small convolutional neural network for MNIST,
+    # wrapped with declearn's TensorFlow (Keras) model interface.
+    stack = [
+        tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),
+        tf.keras.layers.Conv2D(32, 3, 1, activation="relu"),
+        tf.keras.layers.MaxPool2D(2),
+        tf.keras.layers.Dropout(0.25),
+        tf.keras.layers.Flatten(),
+        tf.keras.layers.Dense(10, activation="softmax"),
+    ]
+    model = declearn.model.tensorflow.TensorflowModel(
+        model=tf.keras.Sequential(stack),
+        loss="sparse_categorical_crossentropy",
+    )
+
+    ### (2) Define an optimization strategy
+
+    # Set up the client updates' aggregator. By default: FedAvg.
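+    # (The `AveragingAggregator` merges client updates by averaging them.)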
+ aggregator = declearn.aggregator.AveragingAggregator() + + # Set up the server-side optimizer (to refine aggregated updates). + # By default: no refinement (lrate=1.0, no plug-ins). + server_opt = declearn.optimizer.Optimizer( + lrate=1.0, + w_decay=0.0, + modules=None, + ) + + # Set up the client-side optimizer (for local SGD steps). + # By default: vanilla SGD, with a selected learning rate. + client_opt = declearn.optimizer.Optimizer( + lrate=0.001, + w_decay=0.0, + regularizers=None, + modules=None, + ) + + # Wrap all this into a FLOptimConfig. + optim = declearn.main.config.FLOptimConfig.from_params( + aggregator=aggregator, + server_opt=server_opt, + client_opt=client_opt, + ) + + ### (3) Define network communication parameters. + + # Here, use websockets protocol on localhost:8765, with SSL encryption. + network = declearn.communication.build_server( + protocol=protocol, + host=host, + port=port, + certificate=certificate, + private_key=private_key, + ) + + ### (4) Instantiate and run a FederatedServer. + + # Instanciate + server = declearn.main.FederatedServer( + model=model, + netwk=network, + optim=optim, + metrics=metrics, + checkpoint=checkpoint, + logger=logger, + ) + + # Set up the experiment's hyper-parameters. + # Registration rules: wait for exactly `nb_clients`, at most 5 minutes. + register = declearn.main.config.RegisterConfig( + min_clients=nb_clients, + max_clients=nb_clients, + timeout=300, + ) + # Training rounds hyper-parameters. By default, 1 epoch / round. + training = declearn.main.config.TrainingConfig( + batch_size=32, + n_epoch=1, + ) + # Evaluation rounds. by default, 1 epoch with train's batch size. + evaluate = declearn.main.config.EvaluateConfig( + batch_size=128, + ) + # Wrap all this into a FLRunConfig. + run_config = declearn.main.config.FLRunConfig.from_params( + rounds=5, # you may change the number of training rounds + register=register, + training=training, + evaluate=evaluate, + privacy=None, # you may set up local DP (DP-SGD) here + early_stop=None, # you may add an early-stopping cirterion here + ) + server.run(run_config) + + +# This part should not be altered: it provides with an argument parser. +# for `python server.py`. + + +def main(): + "Fire-wrapped `run_server`." + fire.Fire(run_server) + + +if __name__ == "__main__": + main() diff --git a/examples/mnist_quickrun/config.toml b/examples/mnist_quickrun/config.toml new file mode 100644 index 0000000000000000000000000000000000000000..b01422a743b1631857f799686e13be08f60df34b --- /dev/null +++ b/examples/mnist_quickrun/config.toml @@ -0,0 +1,41 @@ +# This is a minimal TOML file for the MNIST example +# It contains the bare minimum to make the experiment run. +# See quickstart for more details. + +# The TOML is parsed by python as dictionnary with each `[header]` +# as a key. Note the "=" sign and the absence of quotes around keys. 
+# For more details, see the full doc : https://toml.io/en/ + +[network] # Network configuration used by both client and server + protocol = "websockets" # Protocol used, to keep things simple use websocket + host = "127.0.0.1" # Address used, works as-is on most set ups + port = 8765 # Port used, works as-is on most set ups + +[data] # Where to find your data + data_folder = "examples/mnist_quickrun/data_iid" + +[optim] # Optimization options for both client and server + aggregator = "averaging" # Server aggregation strategy + + [optim.client_opt] # Client optimization strategy + lrate = 0.001 # Client learning rate + modules = ["adam"] # List of optimizer modules used + + [optim.server_opt] # Server optimization strategy + lrate = 1.0 # Server learning rate + +[run] # Training process option for both client and server + rounds = 10 # Number of overall training rounds + + [run.register] # Client registration options + timeout = 5 # How long to wait for clients, in seconds + + [run.training] # Client training options + batch_size = 48 # Training batch size + + [run.evaluate] # Client evaluation options + batch_size = 128 # Evaluation batch size + +[experiment] # What to report during the experiment and where to report it + metrics=[["multi-classif",{labels = [0,1,2,3,4,5,6,7,8,9]}]] # Accuracy metric + diff --git a/examples/mnist_quickrun/mnist.ipynb b/examples/mnist_quickrun/mnist.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..a55dc891404420c67a26b96019212f4c22864cba --- /dev/null +++ b/examples/mnist_quickrun/mnist.ipynb @@ -0,0 +1,525 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This notebook is meant to be run in google colab. You can find import your local copy of the file in the the [colab welcome page](https://colab.research.google.com/)." 
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "s9bpLdH5ThpJ"
+   },
+   "source": [
+    "# Setting up declearn"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "Clzf4NTja121"
+   },
+   "source": [
+    "We first clone the repository, to get both the package itself and the `examples` folder used in this tutorial, then navigate to the package directory, and finally install the required dependencies."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "u2QDwb0_QQ_f",
+    "outputId": "cac0761c-b229-49b0-d71d-c7b5cef919b3"
+   },
+   "outputs": [],
+   "source": [
+    "# you may want to specify a release branch or tag\n",
+    "!git clone https://gitlab.inria.fr/magnet/declearn/declearn2"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "9kDHh_AfPG2l",
+    "outputId": "74e2f85f-7f93-40ae-a218-f4403470d72c"
+   },
+   "outputs": [],
+   "source": [
+    "cd declearn2"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "Un212t1GluHB",
+    "outputId": "0ea67577-da6e-4f80-a412-7b7a79803aa1"
+   },
+   "outputs": [],
+   "source": [
+    "# Install the package, with TensorFlow and Websockets extra dependencies.\n",
+    "# You may want to work in a dedicated virtual environment.\n",
+    "!pip install .[tensorflow,websockets]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "hC8Fty8YTy9P"
+   },
+   "source": [
+    "# Running your first experiment"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "rcWcZJdob1IG"
+   },
+   "source": [
+    "We are going to train a common model between three simulated clients on the classic [MNIST dataset](http://yann.lecun.com/exdb/mnist/). The input of the model is a set of images of handwritten digits, and the model needs to determine which number between 0 and 9 each image corresponds to."
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KlY_vVtFHv2P" + }, + "source": [ + "## The model\n", + "\n", + "To do this, we will use a simple CNN, defined in `examples/mnist_quickrun/model.py`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "C7D52a8_dEr7", + "outputId": "a25223f8-c8eb-4998-d7fd-4b8bfde92486" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: \"sequential\"\n", + "_________________________________________________________________\n", + " Layer (type) Output Shape Param # \n", + "=================================================================\n", + " conv2d (Conv2D) (None, 26, 26, 8) 80 \n", + " \n", + " max_pooling2d (MaxPooling2D (None, 13, 13, 8) 0 \n", + " ) \n", + " \n", + " dropout (Dropout) (None, 13, 13, 8) 0 \n", + " \n", + " flatten (Flatten) (None, 1352) 0 \n", + " \n", + " dense (Dense) (None, 64) 86592 \n", + " \n", + " dropout_1 (Dropout) (None, 64) 0 \n", + " \n", + " dense_1 (Dense) (None, 10) 650 \n", + " \n", + "=================================================================\n", + "Total params: 87,322\n", + "Trainable params: 87,322\n", + "Non-trainable params: 0\n", + "_________________________________________________________________\n" + ] + } + ], + "source": [ + "from examples.mnist_quickrun.model import network\n", + "network.summary()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HoBcOs9hH2QA" + }, + "source": [ + "## The data\n", + "\n", + "We start by splitting the MNIST dataset between 3 clients and storing the output in the `examples/mnist_quickrun` folder. For this we use an experimental utility provided by `declearn`. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "quduXkpIWFjL", + "outputId": "ddf7d45d-acf0-44ee-ce77-357c0987a2a1" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading MNIST source file train-images-idx3-ubyte.gz.\n", + "Downloading MNIST source file train-labels-idx1-ubyte.gz.\n", + "Splitting data into 3 shards using the 'iid' scheme.\n" + ] + } + ], + "source": [ + "from declearn.dataset import split_data\n", + "\n", + "split_data(folder=\"examples/mnist_quickrun\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The python code above is equivalent to running `declearn-split examples/mnist_quickrun/` in a shell command-line." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3-2hKmz-2RF4" + }, + "source": [ + "Here is what the first image of the first client looks like:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 430 + }, + "id": "MLVI9GOZ1TGd", + "outputId": "f34a6a93-cb5f-4a45-bc24-4146ea119d1a" + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAbo0lEQVR4nO3df2zV9fXH8dct0itoe7ta29tKy1r8gYrUDKR2Kv6goXQJESQLikvAGJxYjMicpkZBtiXdMPPrNAz/cTATUcQJRDNJsNgStxZDlRCmVsq6UQYtwsK9pUhh7fv7B+HqhfLjc7m3597yfCQ3offe03v8eO3Ty7188DnnnAAAGGBp1gsAAC5OBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJi4xHqBU/X19Wnv3r3KyMiQz+ezXgcA4JFzTl1dXSooKFBa2plf5yRdgPbu3avCwkLrNQAAF6i9vV0jRow44+1JF6CMjAxJJxbPzMw03gYA4FU4HFZhYWHk5/mZJCxAy5Yt04svvqiOjg6Vlpbq1Vdf1YQJE845d/K33TIzMwkQAKSwc72NkpAPIaxevVoLFy7U4sWL9dlnn6m0tFSVlZXav39/Ih4OAJCCEhKgl156SXPnztVDDz2kG264Qa+99pqGDx+uP/3pT4l4OABACop7gI4dO6bm5mZVVFR89yBpaaqoqFBjY+Np9+/p6VE4HI66AAAGv7gH6MCBA+rt7VVeXl7U9Xl5eero6Djt/rW1tQoEApELn4ADgIuD+R9ErampUSgUilza29utVwIADIC4fwouJydHQ4YMUWdnZ9T1nZ2dCgaDp93f7/fL7/fHew0AQJKL+yug9PR0jRs3TnV1dZHr+vr6VFdXp/Ly8ng/HAAgRSXkzwEtXLhQs2fP1vjx4zVhwgS9/PLL6u7u1kMPPZSIhwMApKCEBGjmzJn65ptvtGjRInV0dOjmm2/Whg0bTvtgAgDg4uVzzjnrJb4vHA4rEAgoFApxJgQASEHn+3Pc/FNwAICLEwECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGDiEusFAJyfnp4ezzMHDhyI6bFef/31mOa8+vjjjz3P1NfXe57x+XyeZ2IVyz/TnXfemYBNkh+vgAAAJggQAMBE3AP0wgsvyOfzRV1Gjx4d74cBAKS4hLwHdOONN+qjjz767kEu4a0mAEC0hJThkksuUTAYTMS3BgAMEgl5D2jnzp0qKChQSUmJHnzwQe3evfuM9+3p6VE4HI66AAAGv7gHqKysTCtXrtSGDRu0fPlytbW16Y477lBXV1e/96+trVUgEIhcCgsL470SACAJxT1AVVVV+ulPf6qxY8eqsrJSf/3rX3Xo0CG98847/d6/pqZGoVAocmlvb4/3SgCAJJTwTwdkZWXp2muvVWtra7+3+/1++f3+RK8BAEgyCf9zQIcPH9auXbuUn5+f6IcCAKSQuAfoqaeeUkNDg/71r3/p73//u6ZPn64hQ4bogQceiPdDAQBSWNx/C27Pnj164IEHdPDgQV155ZW6/fbb1dTUpCuvvDLeDwUASGFxD9Dbb78d728JJLXe3l7PMxs3bvQ8s2TJEs8zn376qeeZZJeWltxnEAuFQtYrpIzk/jcJABi0CBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATCf8L6ZD8nHMxzfX19XmeGagTScaymyR98803nmeqqqo8z2zfvt3zDAbekCFDPM8UFRUlYJPBiVdAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMMHZsKEdO3bENHfzzTd7nvnHP/7heebbb7/1PDN+/HjPM0gNGRkZnmdGjx4d02P95S9/8Txz1VVXxfRYFyNeAQEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJjgZ6SDzz3/+0/PMrFmzErBJ//bv3+95pqamJgGbXBwuvfTSmOYefPDBOG/SvyeeeMLzTFZWlucZThCanHgFBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCY4GSkg8yBAwc8z3zxxRcxPZbP5/M809TU5Hlmy5YtnmfS0mL7f6thw4bFNOfVzJkzPc/cf//9nmfKyso8z0jS5ZdfHtMc4AWvgAAAJggQAMCE5wBt3rxZU6dOVUFBgXw+n9atWxd1u3NOixYtUn5+voYNG6aKigrt3LkzXvsCAAYJzwHq7u5WaWmpli1b1u/tS5cu1SuvvKLXXntNW7Zs0WWXXabKykodPXr0gpcFAAwenj+EUFVVpaqqqn5vc87p5Zdf1nPPPad7771XkvTGG28oLy9P69ati+lNVADA4BTX94Da2trU0dGhioqKyHWBQEBlZWVqbGzsd6anp0fhcDjqAgAY/OIaoI6ODklSXl5e1PV5eXmR205VW1urQCAQuRQWFsZzJQBAkjL/FFxNTY1CoVDk0t7ebr0SAGAAxDVAwWBQktTZ2Rl1fWdnZ+S2U/n9fmVmZkZdAACDX1wDVFxcrGAwqLq6ush14XBYW7ZsUXl5eTwfCgCQ4jx/Cu7w4cNqbW2NfN3W1qZt27YpOztbRUVFWrBggX7zm9/ommuuUXFxsZ5//nkVFBRo2rRp8dwbAJDiPAdo69atuvvuuyNfL1y4UJI0e/ZsrVy5Uk8//bS6u7v1yCOP6NChQ7r99tu1YcMGXXrppfHbG
gCQ8nzOOWe9xPeFw2EFAgGFQiHeD4rBrFmzPM+sXr06psfKzs72PBPLiU+//vprzzN+v9/zjCSNHz8+pjkA3znfn+Pmn4IDAFycCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYMLzX8cAnPTf//7X88zUqVM9z6xatcrzTElJiecZAAOLV0AAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAmfc85ZL/F94XBYgUBAoVBImZmZ1uuknMbGRs8zt99+ewI2iZ/hw4d7npkzZ05Mj7VkyRLPM7HsN3ToUM8zQ4YM8TwDWDjfn+O8AgIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATHAy0kEmHA57nlm+fHlMj/Xss8/GNJfMrr76as8ze/bs8Txzzz33eJ6JZbdYTZ8+3fPMrbfe6nkmPT3d8wySHycjBQAkNQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABCcjhfr6+mKae+qppzzPvPvuu55n/vOf/3iewcD7+c9/7nlm0aJFnmeCwaDnGQwsTkYKAEhqBAgAYMJzgDZv3qypU6eqoKBAPp9P69ati7p9zpw58vl8UZcpU6bEa18AwCDhOUDd3d0qLS3VsmXLznifKVOmaN++fZHLW2+9dUFLAgAGn0u8DlRVVamqquqs9/H7/bxRCAA4q4S8B1RfX6/c3Fxdd911mjdvng4ePHjG+/b09CgcDkddAACDX9wDNGXKFL3xxhuqq6vT7373OzU0NKiqqkq9vb393r+2tlaBQCByKSwsjPdKAIAk5Pm34M7l/vvvj/z6pptu0tixYzVq1CjV19dr0qRJp92/pqZGCxcujHwdDoeJEABcBBL+MeySkhLl5OSotbW139v9fr8yMzOjLgCAwS/hAdqzZ48OHjyo/Pz8RD8UACCFeP4tuMOHD0e9mmlra9O2bduUnZ2t7OxsLVmyRDNmzFAwGNSuXbv09NNP6+qrr1ZlZWVcFwcApDbPAdq6davuvvvuyNcn37+ZPXu2li9fru3bt+vPf/6zDh06pIKCAk2ePFm//vWv5ff747c1ACDlcTJSDKhjx455nvnf//7neebNN9/0PCNJX331leeZP/zhD55nkuw/OzOvvvqq55nHHnssAZsgnjgZKQAgqREgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEZ8MGLtDWrVs9z7z77rueZ7788kvPMx9++KHnGUnq7e2Nac6r4cOHe57ZsWOH55mRI0d6nkHsOBs2ACCpESAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmLrFeAEh148ePH5CZWHz99dcxzV1//fVx3qR/R44c8TzT2dnpeYaTkSYnXgEBAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACY4GSkwiF111VXWK5xVVlaW55mSkpL4LwITvAICAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAExwMlIMqD179nieGT58uOeZ7OxszzMDqaenx/NMW1ub55nf//73nmcG0oQJEzzP5OTkJGATWOAVEADABAECAJjwFKDa2lrdcsstysjIUG5urqZNm6aWlpao+xw9elTV1dW64oordPnll2vGjBnq7OyM69IAgNTnKUANDQ2qrq5WU1OTNm7cqOPHj2vy5Mnq7u6O3OfJJ5/U+++/rzVr1qihoUF79+7VfffdF/fFAQCpzdOHEDZs2BD19cqVK5Wbm6vm5mZNnDhRoVBIr7/+ulatWqV77rlHkrRixQpdf/31ampq0q233hq/zQEAKe2C3gMKhUKSvvvEUXNzs44fP66KiorIfUaPHq2ioiI1Njb2+z16enoUDoejLgCAwS/mAPX19WnBggW67bbbNGbMGElSR0eH0tPTT/t73vPy8tTR0dHv96mtrVUgEIhcCgsLY10JAJBCYg5QdXW1duzYobfffvuCFqipqVEoFIpc2tvbL+j7AQBSQ0x/EHX+/Pn64IMPtHnzZo0YMSJyfTAY1LFjx3To0KGoV0GdnZ0KBoP9fi+/3y+/3x/LGgCAFObpFZBzTvPnz9fatWu1adMmFRcXR90+btw4DR06VHV1dZHrWlpatHv3bpWXl8dnYwDAoODpFVB1dbVWrVql9evXKyMjI/K+TiAQ0LBhwxQIBPTwww9r4cKFys7OVmZmph5//HGVl5fzCTgAQBRPAVq+fLkk6a677oq6fsWKFZozZ44k6f/+7/+UlpamGTNmqKenR5WVlfrjH/8Yl2UBAIOHzznnrJf4vnA4rEAgoFAopMzMTOt1cBaffvqp55nvf0T/fF122WWeZ4qKijzPDKSuri7PM6eedWQwaG5u9jxz8803x38RxNX5/hznXHAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwEdPfiIrB5ciRIzHN/fjHP/Y8E8vJ17u7uz3P7N+/3/MMLkxbW5vnmcLCwgRsglTBKyAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQnI4XWr18f01wsJxbFwLr11ltjmlu2bJnnmdzcXM8zPp/P8wwGD14BAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmOBkpVFpaar3CRWfmzJkDMjN58mTPM5I0bNiwmOYAL3gFBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCY4GSk0A033BDTXG9vb5w3AXAx4RUQAMAEAQIAmPAUoNraWt1yyy3KyMhQbm6upk2bppaWlqj73HXXXfL5fFGXRx99NK5LAwBSn6cANTQ0qLq6Wk1NTdq4caOOHz+uyZMnq7u7O+p+c+fO1b59+yKXpUuXxnVpAEDq8/QhhA0bNkR9vXLlSuXm5qq5uVkTJ06MXD98+HAFg8H4bAgAGJQu6D2gUCgkScrOzo66/s0331ROTo7GjBmjmpoaHTly5Izfo6enR+FwOOoCABj8Yv4Ydl9fnxYsWKDbbrtNY8aMiVw/a9YsjRw5UgUFBdq+fbueeeYZtbS06L333uv3+9TW1mrJkiWxrgEASFE+55yLZXDevHn68MMP9cknn2jEiBFnvN+mTZs0adIktba2atSoUafd3tPTo56ensjX4XBYhYWFCoVCyszMjGU1AIChcDisQCBwzp/j
Mb0Cmj9/vj744ANt3rz5rPGRpLKyMkk6Y4D8fr/8fn8sawAAUpinADnn9Pjjj2vt2rWqr69XcXHxOWe2bdsmScrPz49pQQDA4OQpQNXV1Vq1apXWr1+vjIwMdXR0SJICgYCGDRumXbt2adWqVfrJT36iK664Qtu3b9eTTz6piRMnauzYsQn5BwAApCZP7wH5fL5+r1+xYoXmzJmj9vZ2/exnP9OOHTvU3d2twsJCTZ8+Xc8999x5v59zvr93CABITgl5D+hcrSosLFRDQ4OXbwkAuEhxLjgAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgIlLrBc4lXNOkhQOh403AQDE4uTP75M/z88k6QLU1dUlSSosLDTeBABwIbq6uhQIBM54u8+dK1EDrK+vT3v37lVGRoZ8Pl/UbeFwWIWFhWpvb1dmZqbRhvY4DidwHE7gOJzAcTghGY6Dc05dXV0qKChQWtqZ3+lJuldAaWlpGjFixFnvk5mZeVE/wU7iOJzAcTiB43ACx+EE6+Nwtlc+J/EhBACACQIEADCRUgHy+/1avHix/H6/9SqmOA4ncBxO4DicwHE4IZWOQ9J9CAEAcHFIqVdAAIDBgwABAEwQIACACQIEADCRMgFatmyZfvjDH+rSSy9VWVmZPv30U+uVBtwLL7wgn88XdRk9erT1Wgm3efNmTZ06VQUFBfL5fFq3bl3U7c45LVq0SPn5+Ro2bJgqKiq0c+dOm2UT6FzHYc6cOac9P6ZMmWKzbILU1tbqlltuUUZGhnJzczVt2jS1tLRE3efo0aOqrq7WFVdcocsvv1wzZsxQZ2en0caJcT7H4a677jrt+fDoo48abdy/lAjQ6tWrtXDhQi1evFifffaZSktLVVlZqf3791uvNuBuvPFG7du3L3L55JNPrFdKuO7ubpWWlmrZsmX93r506VK98soreu2117RlyxZddtllqqys1NGjRwd408Q613GQpClTpkQ9P956660B3DDxGhoaVF1draamJm3cuFHHjx/X5MmT1d3dHbnPk08+qffff19r1qxRQ0OD9u7dq/vuu89w6/g7n+MgSXPnzo16PixdutRo4zNwKWDChAmuuro68nVvb68rKChwtbW1hlsNvMWLF7vS0lLrNUxJcmvXro183dfX54LBoHvxxRcj1x06dMj5/X731ltvGWw4ME49Ds45N3v2bHfvvfea7GNl//79TpJraGhwzp34dz906FC3Zs2ayH2+/PJLJ8k1NjZarZlwpx4H55y788473RNPPGG31HlI+ldAx44dU3NzsyoqKiLXpaWlqaKiQo2NjYab2di5c6cKCgpUUlKiBx98ULt377ZeyVRbW5s6Ojqinh+BQEBlZWUX5fOjvr5eubm5uu666zRv3jwdPHjQeqWECoVCkqTs7GxJUnNzs44fPx71fBg9erSKiooG9fPh1ONw0ptvvqmcnByNGTNGNTU1OnLkiMV6Z5R0JyM91YEDB9Tb26u8vLyo6/Py8vTVV18ZbWWjrKxMK1eu1HXXXad9+/ZpyZIluuOOO7Rjxw5lZGRYr2eio6NDkvp9fpy87WIxZcoU3XfffSouLtauXbv07LPPqqqqSo2NjRoyZIj1enHX19enBQsW6LbbbtOYMWMknXg+pKenKysrK+q+g/n50N9xkKRZs2Zp5MiRKigo0Pbt2/XMM8+opaVF7733nuG20ZI+QPhOVVVV5Ndjx45VWVmZRo4cqXfeeUcPP/yw4WZIBvfff3/k1zfddJPGjh2rUaNGqb6+XpMmTTLcLDGqq6u1Y8eOi+J90LM503F45JFHIr++6aablJ+fr0mTJmnXrl0aNWrUQK/Zr6T/LbicnBwNGTLktE+xdHZ2KhgMGm2VHLKysnTttdeqtbXVehUzJ58DPD9OV1JSopycnEH5/Jg/f74++OADffzxx1F/fUswGNSxY8d06NChqPsP1ufDmY5Df8rKyiQpqZ4PSR+g9PR0jRs3TnV1dZHr+vr6VFdXp/LycsPN7B0+fFi7du1Sfn6+9SpmiouLFQwGo54f4XBYW7ZsueifH3v27NHBgwcH1fPDOaf58+dr7dq12rRpk4qLi6NuHzdunIYOHRr1fGhpadHu3bsH1fPhXMehP9u2bZOk5Ho+WH8K4ny8/fbbzu/3u5UrV7ovvvjCPfLIIy4rK8t1dHRYrzagfvGLX7j6+nrX1tbm/va3v7mKigqXk5Pj9u/fb71aQnV1dbnPP//cff75506Se+mll9znn3/u/v3vfzvnnPvtb3/rsrKy3Pr169327dvdvffe64qLi923335rvHl8ne04dHV1uaeeeso1Nja6trY299FHH7kf/ehH7pprrnFHjx61Xj1u5s2b5wKBgKuvr3f79u2LXI4cORK5z6OPPuqKiorcpk2b3NatW115ebkrLy833Dr+znUcWltb3a9+9Su3detW19bW5tavX+9KSkrcxIkTjTePlhIBcs65V1991RUVFbn09HQ3YcIE19TUZL3SgJs5c6bLz8936enp7qqrrnIzZ850ra2t1msl3Mcff+wknXaZPXu2c+7ER7Gff/55l5eX5/x+v5s0aZJraWmxXToBznYcjhw54iZPnuyuvPJKN3ToUDdy5Eg3d+7cQfc/af3980tyK1asiNzn22+/dY899pj7wQ9+4IYPH+6mT5/u9u3bZ7d0ApzrOOzevdtNnDjRZWdnO7/f766++mr3y1/+0oVCIdvFT8FfxwAAMJH07wEBAAYnAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMDE/wPgnA/bT9IQRgAAAABJRU5ErkJggg==", + "text/plain": [ + "<Figure size 640x480 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "images = np.load(\"examples/mnist_quickrun/data_iid/client_0/train_data.npy\")\n", + "sample_img = images[0]\n", + "sample_fig = plt.imshow(sample_img,cmap='Greys')\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1vNWNGjefSfH" + }, + "source": [ + "For more information on how the `split_data` function works, you can look at the documentation. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "-wORmq5DYfRF", + "outputId": "4d79da63-ccad-4622-e600-ac36fae1ff3f" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Randomly split a dataset into shards.\n", + "\n", + " The resulting folder structure is :\n", + " folder/\n", + " └─── data*/\n", + " └─── client*/\n", + " │ train_data.* - training data\n", + " │ train_target.* - training labels\n", + " │ valid_data.* - validation data\n", + " │ valid_target.* - validation labels\n", + " └─── client*/\n", + " │ ...\n", + "\n", + " Parameters\n", + " ----------\n", + " folder: str, default = \".\"\n", + " Path to the folder where to add a data folder\n", + " holding output shard-wise files\n", + " data_file: str or None, default=None\n", + " Optional path to a folder where to find the data.\n", + " If None, default to the MNIST example.\n", + " target_file: str or int or None, default=None\n", + " If str, path to the labels file to import, or name of a `data`\n", + " column to use as labels (only if `data` points to a csv file).\n", + " If int, index of a `data` column of to use as labels).\n", + " Required if data is not None, ignored if data is None.\n", + " n_shards: int\n", + " Number of shards between which to split the data.\n", + " scheme: {\"iid\", \"labels\", \"biased\"}, default=\"iid\"\n", + " Splitting scheme(s) to use. In all cases, shards contain mutually-\n", + " exclusive samples and cover the full raw training data.\n", + " - If \"iid\", split the dataset through iid random sampling.\n", + " - If \"labels\", split into shards that hold all samples associated\n", + " with mutually-exclusive target classes.\n", + " - If \"biased\", split the dataset through random sampling according\n", + " to a shard-specific random labels distribution.\n", + " perc_train: float, default= 0.8\n", + " Train/validation split in each client dataset, must be in the\n", + " ]0,1] range.\n", + " seed: int or None, default=None\n", + " Optional seed to the RNG used for all sampling operations.\n", + " \n" + ] + } + ], + "source": [ + "print(split_data.__doc__)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kZtbxlwUftKd" + }, + "source": [ + "## Quickrun\n", + "\n", + "We can now run our experiment. As explained in the section 2.1 of the [quickstart documentation](https://magnet.gitlabpages.inria.fr/declearn/docs/latest/quickstart), using the `declearn-quickrun` entry-point requires a configuration file, some data, and a model:\n", + "\n", + "* A TOML file, to store your experiment configurations. Here: \n", + "`examples/mnist_quickrun/config.toml`.\n", + "* A folder with your data, split by client. Here: `examples/mnist_quickrun/data_iid`\n", + "* A model python file, to declare your model wrapped in a `declearn` object. Here: `examples/mnist_quickrun/model.py`.\n", + "\n", + "We then only have to run the `quickrun` util with the path to the TOML file:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "1n_mvTIIWpRf" + }, + "outputs": [], + "source": [ + "from declearn.quickrun import quickrun\n", + "\n", + "quickrun(config=\"examples/mnist_quickrun/config.toml\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The python code above is equivalent to running `declearn-quickrun examples/mnist_quickrun/config.toml` in a shell command-line." 
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "O0kuw7UxJqKk"
+   },
+   "source": [
+    "The output obtained is the combination of the CLI output of our server and of our clients, going through the following steps:\n",
+    "\n",
+    "* `INFO:Server:Starting clients registration process.`: a first registration step, where clients register with the server.\n",
+    "* `INFO:Server:Sending initialization requests to clients.`: the initialization of the objects needed for training, on both the server and client sides.\n",
+    "* `Server:INFO: Initiating training round 1`: the training starts; each client makes its local update(s) and sends the results to the server, which aggregates them.\n",
+    "* `INFO: Initiating evaluation round 1`: the model is evaluated at each round.\n",
+    "* `Server:INFO: Stopping training`: the training is finalized."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "wo6NDugiOH6V"
+   },
+   "source": [
+    "## Results\n",
+    "\n",
+    "You can have a look at the results in the `examples/mnist_quickrun/result_*` folder, including the evolution of the metrics during training."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "zlm5El13SvnG"
+   },
+   "outputs": [],
+   "source": [
+    "import pandas as pd\n",
+    "import glob\n",
+    "import os\n",
+    "\n",
+    "res_file = glob.glob('examples/mnist_quickrun/result*')\n",
+    "res = pd.read_csv(os.path.join(res_file[0], 'server/metrics.csv'))\n",
+    "res_fig = res.plot()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "Kd_MBQt9OJ40"
+   },
+   "source": [
+    "# Experiment further\n",
+    "\n",
+    "You can change the TOML config file to experiment with different strategies."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "E3OOeAYJRGqU"
+   },
+   "source": [
+    "For instance, try splitting the data in a very heterogeneous way, by distributing digits in a mutually exclusive way between clients."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "BNPLnpQuQ8Au"
+   },
+   "outputs": [],
+   "source": [
+    "split_data(folder=\"examples/mnist_quickrun\", scheme='labels')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "Xfs-3wH-3Eio"
+   },
+   "source": [
+    "And change the `examples/mnist_quickrun/config.toml` file with:\n",
+    "\n",
+    "```\n",
+    "[data]\n",
+    "    data_folder = \"examples/mnist_quickrun/data_labels\"\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "ZZVFNO07O1ry"
+   },
+   "source": [
+    "If you run the model as is, you should see a drop in performance.\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "7kFa0EbINJXq"
+   },
+   "outputs": [],
+   "source": [
+    "quickrun(config=\"examples/mnist_quickrun/config.toml\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "XV6JfaRzR3ee"
+   },
+   "source": [
+    "Now try modifying the `examples/mnist_quickrun/config.toml` file like this, to implement the [Scaffold algorithm](https://arxiv.org/abs/1910.06378), and run the experiment again. 
\n", + "\n", + "```\n", + " [optim]\n", + "\n", + " [optim.client_opt]\n", + " lrate = 0.005 \n", + " modules = [\"scaffold-client\"] \n", + "\n", + " [optim.server_opt]\n", + " lrate = 1.0 \n", + " modules = [\"scaffold-client\"]\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "FK6c9HDjSdGZ" + }, + "outputs": [], + "source": [ + "quickrun(config=\"examples/mnist_quickrun/config.toml\")" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [ + "s9bpLdH5ThpJ", + "KlY_vVtFHv2P", + "HoBcOs9hH2QA", + "kZtbxlwUftKd", + "wo6NDugiOH6V", + "Kd_MBQt9OJ40" + ], + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/examples/mnist_quickrun/model.py b/examples/mnist_quickrun/model.py new file mode 100644 index 0000000000000000000000000000000000000000..01dd16a094fa00fab9f6c5336f2c7c2825c67cb2 --- /dev/null +++ b/examples/mnist_quickrun/model.py @@ -0,0 +1,21 @@ +"""Simple TensorFlow-backed CNN model for the MNIST quickrun example.""" + +import tensorflow as tf + +from declearn.model.tensorflow import TensorflowModel + +stack = [ + tf.keras.layers.InputLayer(input_shape=(28, 28, 1)), + tf.keras.layers.Conv2D(8, 3, 1, activation="relu"), + tf.keras.layers.MaxPool2D(2), + tf.keras.layers.Dropout(0.25), + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(64, activation="relu"), + tf.keras.layers.Dropout(0.5), + tf.keras.layers.Dense(10, activation="softmax"), +] +network = tf.keras.models.Sequential(stack) + +# This needs to be called "model"; otherwise, a different name must be +# specified via the experiment's TOML configuration file. +model = TensorflowModel(network, loss="sparse_categorical_crossentropy") diff --git a/examples/mnist_quickrun/readme.md b/examples/mnist_quickrun/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..afcc573c93bd50b040dc4613b506dbe76260fe18 --- /dev/null +++ b/examples/mnist_quickrun/readme.md @@ -0,0 +1,31 @@ +# Demo training task : MNIST in Quickrun Mode + +## Overview + +**We are going to use the declearn-quickrun tool to easily run a simulated +federated learning experiment on the classic +[MNIST dataset](http://yann.lecun.com/exdb/mnist/)**. The input of the model +is a set of images of handwritten digits, and the model needs to determine to +which digit between $0$ and $9$ each image corresponds. + +## Setup + +A Jupyter Notebook tutorial is provided, that you may import and run on Google +Colab so as to avoid having to set up a local python environment. + +Alternatively, you may run the notebook on your personal computer, or follow +its instructions to install declearn and operate the quickrun tools directly +from a shell command-line. 
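+
+Whichever route you pick, the experiment boils down to two steps: splitting
+the MNIST data into client-wise shards, then running the federated process
+described by the TOML configuration. The minimal sketch below uses the Python
+API with this example's paths; it mirrors what the `declearn-split` and
+`declearn-quickrun` command-line entry-points do:
+
+```python
+from declearn.dataset import split_data
+from declearn.quickrun import quickrun
+
+# Download MNIST and split it into 3 iid shards (the default scheme),
+# written under this example's folder.
+split_data(folder="examples/mnist_quickrun")
+# Run the simulated federated learning experiment end-to-end.
+quickrun(config="examples/mnist_quickrun/config.toml")
+```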
+ +## Contents + +This example's folder is structured the following way: + +``` +mnist/ +│ config.toml - configuration file for the quickrun FL experiment +| mnist.ipynb - tutorial for this example, as a jupyter notebook +| model.py - python file declaring the model to be trained +└─── data_iid - mnist data generated with `declearn-split` +└─── results_* - results generated after running `declearn-quickrun` +``` diff --git a/pyproject.toml b/pyproject.toml index af3e8de13bc5cf5c41cfdcca31ee1dba0fb440a8..70ca5ffd51716408a94016d3818cbb7610ae0c45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,9 @@ classifiers = [ ] dependencies = [ "cryptography >= 35.0", + "fire >= 0.4", "pandas >= 1.2", + "requests ~= 2.18", "scikit-learn >= 1.0", "tomli >= 2.0 ; python_version < '3.11'", "typing_extensions >= 4.0", @@ -128,3 +130,7 @@ packages = ["declearn"] [tool.setuptools.package-data] declearn = ["py.typed"] + +[project.scripts] +declearn-quickrun = "declearn.quickrun._run:main" +declearn-split = "declearn.dataset._split_data:main" diff --git a/test/dataset/test_examples.py b/test/dataset/test_examples.py new file mode 100644 index 0000000000000000000000000000000000000000..84c6790f316ca05bc7d1ad5194c7e4079019f112 --- /dev/null +++ b/test/dataset/test_examples.py @@ -0,0 +1,61 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functional tests for 'declearn.dataset.examples' utils.""" + +from unittest import mock + +import numpy as np +import pandas as pd # type: ignore + +from declearn.dataset.examples import ( + load_heart_uci, + load_mnist, +) + + +def test_load_heart_uci(tmpdir: str) -> None: + """Functional tests for 'declearn.dataset.example.load_heart_uci'.""" + # Test that downloading the dataset works. + data, tcol = load_heart_uci("va", folder=tmpdir) + assert isinstance(data, pd.DataFrame) + assert tcol in data.columns + # Test that re-loading the dataset works. + with mock.patch( + "declearn.dataset.examples._heart_uci.download_heart_uci_shard" + ) as patch_download: + data_bis, tcol_bis = load_heart_uci("va", folder=tmpdir) + patch_download.assert_not_called() + assert np.allclose(data.values, data_bis.values) + assert tcol == tcol_bis + + +def test_load_mnist(tmpdir: str) -> None: + """Functional tests for 'declearn.dataset.example.load_mnist'.""" + # Test that downloading the (test) dataset works. + images, labels = load_mnist(train=False, folder=tmpdir) + assert isinstance(images, np.ndarray) + assert images.shape == (10000, 28, 28) + assert isinstance(labels, np.ndarray) + assert labels.shape == (images.shape[0],) + assert (np.unique(labels) == np.arange(10)).all() + # Test that re-loading the dataset works. 
+ with mock.patch("requests.get") as patch_download: + img_bis, lab_bis = load_mnist(train=False, folder=tmpdir) + patch_download.assert_not_called() + assert (img_bis == images).all() + assert (lab_bis == labels).all() diff --git a/test/dataset/test_utils.py b/test/dataset/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ce7f28e02db9af0342f35f887ae5f746acfff89e --- /dev/null +++ b/test/dataset/test_utils.py @@ -0,0 +1,168 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for 'declearn.dataset.utils' functions.""" + +import json +import os +from typing import Type + +import numpy as np +import pandas as pd # type: ignore +import pytest +import scipy.sparse # type: ignore +from sklearn.datasets import dump_svmlight_file # type: ignore + +from declearn.dataset.utils import ( + load_data_array, + save_data_array, + sparse_from_file, + sparse_to_file, +) + + +def build_sparse_data() -> scipy.sparse.coo_matrix: + """Build a random-valued COO sparse matrix.""" + rng = np.random.default_rng(seed=0) + val = rng.normal(size=20) + idx = rng.choice(128, size=20) + jdx = rng.choice(32, size=20) + data = scipy.sparse.coo_matrix((val, (idx, jdx))) + return data + + +class TestSaveLoadDataArray: + """Unitary functional tests for data arrays loading and saving utils.""" + + def test_save_load_csv(self, tmpdir: str) -> None: + """Test '(save|load)_data_array' with pandas/csv data.""" + cat = np.random.choice(["a", "b", "c"], size=100) + num = np.random.normal(size=100).round(6) + data = pd.DataFrame({"cat": cat, "num": num}) + base = os.path.join(tmpdir, "data") + # Test that the data can properly be saved. + path = save_data_array(base, data) + assert isinstance(path, str) + assert path.startswith(base) and path.endswith(".csv") + assert os.path.isfile(path) + # Test that the data can properly be reloaded. + dbis = load_data_array(path) + assert isinstance(dbis, pd.DataFrame) + assert np.all(data.values == dbis.values) + + def test_save_load_npy(self, tmpdir: str) -> None: + """Test '(save|load)_data_array' with numpy data.""" + data = np.random.normal(size=(128, 32)) + base = os.path.join(tmpdir, "data") + # Test that the data can properly be saved. + path = save_data_array(base, data) + assert isinstance(path, str) + assert path.startswith(base) and path.endswith(".npy") + assert os.path.isfile(path) + # Test that the data can properly be reloaded. + dbis = load_data_array(path) + assert isinstance(dbis, np.ndarray) + assert np.all(data == dbis) + + def test_save_load_sparse(self, tmpdir: str) -> None: + """Test '(save|load)_data_array' with sparse data.""" + data = build_sparse_data() + base = os.path.join(tmpdir, "data") + # Test that the data can properly be saved. 
+ path = save_data_array(base, data) + assert isinstance(path, str) + assert path.startswith(base) and path.endswith(".sparse") + assert os.path.isfile(path) + # Test that the data can properly be reloaded. + dbis = load_data_array(path) + assert isinstance(dbis, scipy.sparse.coo_matrix) + assert data.shape == dbis.shape + assert data.nnz == dbis.nnz + assert np.all(data.toarray() == dbis.toarray()) + + def test_load_svmlight(self, tmpdir: str) -> None: + """Test 'load_data_array' with svmlight data.""" + # Save some data to svmlight using scikit-learn. + path = os.path.join(tmpdir, "data.svmlight") + x_dat = np.random.normal(size=(100, 32)) + y_dat = np.random.normal(size=100) + dump_svmlight_file(x_dat, y_dat, path) + # Test that the data can properly be reloaded with declearn. + x_bis = load_data_array(path) + y_bis = load_data_array(path, which=1) + assert isinstance(x_bis, scipy.sparse.csr_matrix) + assert np.allclose(x_bis.toarray(), x_dat) + assert isinstance(y_bis, np.ndarray) + assert np.allclose(y_bis, y_dat) + + +SPARSE_TYPES = [ + scipy.sparse.bsr_matrix, + scipy.sparse.csc_matrix, + scipy.sparse.csr_matrix, + scipy.sparse.coo_matrix, + scipy.sparse.dia_matrix, + scipy.sparse.dok_matrix, + scipy.sparse.lil_matrix, +] + + +class TestSaveLoadSparse: + """Unit tests for custom sparse data dump and load utils.""" + + @pytest.mark.parametrize("sparse_cls", SPARSE_TYPES) + def test_sparse_to_from_file( + self, + sparse_cls: Type[scipy.sparse.spmatrix], + tmpdir: str, + ) -> None: + """Test that 'sparse_(to|from)_file' works properly.""" + data = build_sparse_data() + data = sparse_cls(data) + path = os.path.join(tmpdir, "data.sparse") + # Test that the data can properly be saved. + sparse_to_file(path, data) + assert os.path.isfile(path) + # Test that the data can properly be reloaded. 
+ dbis = sparse_from_file(path) + assert isinstance(dbis, sparse_cls) + assert data.shape == dbis.shape + assert data.nnz == dbis.nnz + assert np.all(data.toarray() == dbis.toarray()) + + def test_sparse_to_file_fails(self, tmpdir: str) -> None: + """Test that a TypeError is raised with a bad input type.""" + data = np.random.normal(size=(128, 32)) + with pytest.raises(TypeError): + sparse_to_file(os.path.join(tmpdir, "data.sparse"), data) + + def test_sparse_from_file_keyerror(self, tmpdir: str) -> None: + """Test that a KeyError is raised with a wrongful header.""" + path = os.path.join(tmpdir, "data.sparse") + with open(path, "w", encoding="utf-8") as file: + file.write("Wrongful header\n") + with pytest.raises(KeyError): + sparse_from_file(path) + + def test_sparse_from_file_typeerror(self, tmpdir: str) -> None: + """Test that a TypeError is raised with an unknown spmatrix type.""" + path = os.path.join(tmpdir, "data.sparse") + header = {"stype": "bad", "dtype": "int32", "shape": [128, 32]} + with open(path, "w", encoding="utf-8") as file: + file.write(json.dumps(header) + "\n") + with pytest.raises(TypeError): + sparse_from_file(path) diff --git a/test/functional/test_main.py b/test/functional/test_main.py index 2b0986a93c2ebc3c66f85730749eccd592928e8a..1f1a0bfe67ab2ac0aea456c62c744f12000a2e97 100644 --- a/test/functional/test_main.py +++ b/test/functional/test_main.py @@ -34,7 +34,7 @@ from declearn.dataset import InMemoryDataset from declearn.model.api import Model from declearn.model.sklearn import SklearnSGDModel from declearn.main import FederatedClient, FederatedServer -from declearn.test_utils import run_as_processes +from declearn.utils import run_as_processes from declearn.utils import set_device_policy # Select the subset of tests to run, based on framework availability. diff --git a/test/functional/test_regression.py b/test/functional/test_regression.py index 66a01b67b9d6359e46b82ab8d55649fe7fce46fd..e065e364f4618f0b20ceaf000c3fb8ca759460aa 100644 --- a/test/functional/test_regression.py +++ b/test/functional/test_regression.py @@ -58,7 +58,8 @@ from declearn.metrics import RSquared from declearn.model.api import Model from declearn.model.sklearn import SklearnSGDModel from declearn.optimizer import Optimizer -from declearn.test_utils import FrameworkType, run_as_processes +from declearn.test_utils import FrameworkType +from declearn.utils import run_as_processes from declearn.utils import set_device_policy # pylint: disable=ungrouped-imports; optional frameworks' dependencies