diff --git a/declearn/dataset/__init__.py b/declearn/dataset/__init__.py index dc5759a4a714b836e3bf7348d27e48b79d98468d..3d3710f831e8d7df9c77011ea63f3f315ad46c0a 100644 --- a/declearn/dataset/__init__.py +++ b/declearn/dataset/__init__.py @@ -32,3 +32,5 @@ arrays from ._base import Dataset, DataSpecs, load_dataset_from_json from ._inmemory import InMemoryDataset + +from ._utils import load_data_array, save_data_array diff --git a/declearn/dataset/_inmemory.py b/declearn/dataset/_inmemory.py index 58bd35f20b194c3d9fb62c952db0c17f4d7ace13..5365b6c5ef3aeb98d4903e8644721e7f6c20f817 100644 --- a/declearn/dataset/_inmemory.py +++ b/declearn/dataset/_inmemory.py @@ -17,7 +17,6 @@ """Dataset implementation to serve scikit-learn compatible in-memory data.""" -import functools import os from typing import Any, ClassVar, Dict, Iterator, List, Optional, Set, Union @@ -29,7 +28,7 @@ from sklearn.datasets import load_svmlight_file # type: ignore from typing_extensions import Self # future: import from typing (py >=3.11) from declearn.dataset._base import Dataset, DataSpecs -from declearn.dataset._sparse import sparse_from_file, sparse_to_file +from declearn.dataset._utils import load_data_array, save_data_array from declearn.typing import Batch from declearn.utils import json_dump, json_load, register_type @@ -89,7 +88,7 @@ class InMemoryDataset(Dataset): an instance that is either a numpy ndarray, a pandas DataFrame or a scipy spmatrix. - See the `load_data_array` method for details + See the `load_data_array` function in dataset._utils for details on supported file formats. Parameters @@ -131,7 +130,7 @@ class InMemoryDataset(Dataset): # Assign the main data array. if isinstance(data, str): self._data_path = data - data = self.load_data_array(data) + data = load_data_array(data) self.data = data # Assign the optional input features list. self.f_cols = f_cols @@ -147,7 +146,7 @@ class InMemoryDataset(Dataset): self.f_cols.remove(target) # type: ignore target = self.data[target] else: - target = self.load_data_array(target) + target = load_data_array(target) self.target = target # Assign the (optional) sample weights data array. if isinstance(s_wght, str): @@ -159,7 +158,7 @@ class InMemoryDataset(Dataset): self.f_cols.remove(s_wght) # type: ignore s_wght = self.data[s_wght] else: - s_wght = self.load_data_array(s_wght) + s_wght = load_data_array(s_wght) self.weights = s_wght # Assign the 'expose_classes' attribute. 
self.expose_classes = expose_classes @@ -388,15 +387,15 @@ class InMemoryDataset(Dataset): # fmt: off info["data"] = ( self._data_path or - self.save_data_array(os.path.join(folder, "data"), self.data) + save_data_array(os.path.join(folder, "data"), self.data) ) info["target"] = None if self.target is None else ( self._trgt_path or - self.save_data_array(os.path.join(folder, "trgt"), self.target) + save_data_array(os.path.join(folder, "trgt"), self.target) ) info["s_wght"] = None if self.weights is None else ( self._wght_path or - self.save_data_array(os.path.join(folder, "wght"), self.weights) + save_data_array(os.path.join(folder, "wght"), self.weights) ) # fmt: on info["f_cols"] = self.f_cols diff --git a/declearn/dataset/_utils.py b/declearn/dataset/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9e6e2e0f07f73c342302e85d059203391ba857c4 --- /dev/null +++ b/declearn/dataset/_utils.py @@ -0,0 +1,149 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dataset implementation to serve scikit-learn compatible in-memory data.""" + +import functools +import os +from typing import Any, Union + +import numpy as np +import pandas as pd # type: ignore +from scipy.sparse import spmatrix # type: ignore +from sklearn.datasets import load_svmlight_file # type: ignore + +from declearn.dataset._sparse import sparse_from_file, sparse_to_file + +__all__ = ["load_data_array", "save_data_array"] + +DataArray = Union[np.ndarray, pd.DataFrame, spmatrix] + + +def load_data_array( + path: str, + **kwargs: Any, +) -> DataArray: + """Load a data array from a dump file. + + Supported file extensions are: + .csv: + csv file, comma-delimited by default. + Any keyword arguments to `pandas.read_csv` may be passed. + .npy: + Non-pickle numpy array dump, created with `numpy.save`. + .sparse: + Scipy sparse matrix dump, created with the custom + `declearn.data.sparse.sparse_to_file` function. + .svmlight: + SVMlight sparse matrix and labels array dump. + Parse using `sklearn.load_svmlight_file`, and + return either features or labels based on the + `which` int keyword argument (default: 0, for + features). + + Parameters + ---------- + path: str + Path to the data array dump file. + Extension must be adequate to enable proper parsing; + see list of supported extensions above. + **kwargs: + Extension-type-based keyword parameters may be passed. + See above for details. + + Returns + ------- + data: numpy.ndarray or pandas.DataFrame or scipy.spmatrix + Reloaded data array. + + Raises + ------ + TypeError: + If `path` is of unsupported extension. + Any exception raised by data-loading functions may also be + raised (e.g. if the file cannot be proprely parsed). 
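+
+    Examples
+    --------
+    A minimal, illustrative sketch (the file names are hypothetical and
+    the arrays are assumed to have been saved with `save_data_array`):
+
+    >>> x_train = load_data_array("client_0/train_data.npy")
+    >>> y_train = load_data_array("client_0/train_target.csv")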
+ """ + ext = os.path.splitext(path)[1] + if ext == ".csv": + return pd.read_csv(path, **kwargs) + if ext == ".npy": + return np.load(path, allow_pickle=False) + if ext == ".sparse": + return sparse_from_file(path) + if ext == ".svmlight": + which = kwargs.get("which", 0) + return load_svmlight_file(path)[which] + raise TypeError(f"Unsupported data array file extension: '{ext}'.") + + +def save_data_array( + path: str, + array: Union[DataArray, pd.Series], +) -> str: + """Save a data array to a dump file. + + Supported types of data arrays are: + pandas.DataFrame or pandas.Series: + Dump to a comma-separated ".csv" file. + numpy.ndarray: + Dump to a non-pickle ".npy" file. + scipy.sparse.spmatrix: + Dump to a ".sparse" file, using a custom format + and `declearn.data.sparse.sparse_to_file`. + + Parameters + ---------- + path: str + Path to the file where to dump the array. + Appropriate file extension will be added when + not present (i.e. `path` may be a basename). + array: data array structure (see above) + Data array that needs dumping to file. + See above for supported types and associated + behaviours. + + Returns + ------- + path: str + Path to the created file dump, based on the input + `path` and the chosen file extension (see above). + + Raises + ------ + TypeError: + If `array` is of unsupported type. + """ + # Select a file extension and set up the array-dumping function. + if isinstance(array, (pd.DataFrame, pd.Series)): + ext = ".csv" + save = functools.partial( + array.to_csv, sep=",", encoding="utf-8", index=False + ) + elif isinstance(array, np.ndarray): + ext = ".npy" + save = functools.partial(np.save, arr=array) + elif isinstance(array, spmatrix): + ext = ".sparse" + save = functools.partial(sparse_to_file, matrix=array) + else: + raise TypeError(f"Unsupported data array type: '{type(array)}'.") + # Ensure proper naming. Save the array. Return the path. + if not path.endswith(ext): + path += ext + os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True) + save(path) + return path diff --git a/declearn/quickrun/_run.py b/declearn/quickrun/_run.py new file mode 100644 index 0000000000000000000000000000000000000000..3348c794b7bbbd4d7d285c5c8326b0a25fbf5fd2 --- /dev/null +++ b/declearn/quickrun/_run.py @@ -0,0 +1,148 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""Script to quickly run a simulated FL experiment from a config folder."""
+import importlib.util
+from glob import glob
+from typing import Any, Dict, Optional
+
+from declearn.communication import NetworkClientConfig, NetworkServerConfig
+from declearn.dataset import InMemoryDataset
+from declearn.main import FederatedClient, FederatedServer
+from declearn.main.config import FLOptimConfig, FLRunConfig
+from declearn.utils import run_as_processes
+
+DEFAULT_FOLDER = "./examples/quickrun"
+
+
+def _run_server(
+    model: Any,
+    network: NetworkServerConfig,
+    optim: FLOptimConfig,
+    config: FLRunConfig,
+) -> None:
+    """Routine to run a FL server, called by `quickrun`."""
+    server = FederatedServer(model, network, optim)
+    server.run(config)
+
+
+def _parse_data_folder(folder: str) -> Dict[str, Dict[str, str]]:
+    """Parse a data folder following the expected structure.
+
+    Return a nested dictionary mapping each client subfolder's name
+    to the train and validation files it contains.
+    """
+    # Get the data directory.
+    data_folder = glob("data_*", root_dir=folder)
+    if len(data_folder) == 0:
+        raise ValueError(
+            f"No folder starting with 'data_' found in {folder}. "
+            "Please store your data under a 'data_*' folder."
+        )
+    if len(data_folder) > 1:
+        raise ValueError(
+            f"More than one folder starting with 'data_' found in {folder}. "
+            "Please store your data under a single parent folder."
+        )
+    data_folder = f"{folder}/{data_folder[0]}"
+    # Get the client directories.
+    clients_folders = glob("client_*", root_dir=data_folder)
+    if len(clients_folders) == 0:
+        raise ValueError(
+            f"No folder starting with 'client_' found in {data_folder}. "
+            "Please store each client's data under a 'client_*' folder."
+        )
+    clients = {c: {} for c in clients_folders}  # type: Dict[str, Dict[str, str]]
+    # Get the train and valid files for each client.
+    for c in clients.keys():
+        path = f"{data_folder}/{c}/"
+        data_items = [
+            "train_data",
+            "train_target",
+            "valid_data",
+            "valid_target",
+        ]
+        for d in data_items:
+            files = glob(f"{d}*", root_dir=path)
+            if len(files) != 1:
+                raise ValueError(
+                    f"Could not find a unique file named '{d}.*' in {path}"
+                )
+            clients[c][d] = files[0]
+
+    return clients
+
+
+def _run_client(
+    network: str,
+    name: str,
+    paths: dict,
+) -> None:
+    """Routine to run a FL client, called by `quickrun`."""
+    # Run the declearn FL client routine.
+    netwk = NetworkClientConfig.from_toml(network)
+    # Overwrite the client name based on the folder name.
+    netwk.name = name
+    # Wrap train and validation data as Dataset objects.
+    train = InMemoryDataset(
+        paths.get("train_data"),
+        target=paths.get("train_target"),
+        expose_classes=True,
+    )
+    valid = InMemoryDataset(
+        paths.get("valid_data"),
+        target=paths.get("valid_target"),
+    )
+    client = FederatedClient(netwk, train, valid)
+    client.run()
+
+
+def quickrun(
+    folder: Optional[str] = None,
+) -> None:
+    """Run a server and its clients using multiprocessing."""
+    # Default to the quickrun example folder.
+    if not folder:
+        folder = DEFAULT_FOLDER  # TODO check data was run
+    # Parse the toml file into network, optimization and run configs.
+    toml = f"{folder}/config.toml"
+    ntk_server = NetworkServerConfig.from_toml(toml, False, "network_server")
+    optim = FLOptimConfig.from_toml(toml, False, "optim")
+    run = FLRunConfig.from_toml(toml, False, "run")
+    ntk_client = NetworkClientConfig.from_toml(toml, False, "network_client")
+    # Load the model instance defined in the folder's `model.py` file.
+    spec = importlib.util.spec_from_file_location("model", f"{folder}/model.py")
+    if spec is None or spec.loader is None:
+        raise ImportError(f"Could not load the model file from '{folder}'.")
+    mod = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(mod)
+    model = getattr(mod, "MyModel")
+    # Set up a (func, args) tuple specifying the server process.
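+    # Each entry below is a (function, args) tuple; `run_as_processes`
+    # spawns one Python process per tuple and waits for all to finish.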
+ p_server = (_run_server, (model, ntk_server, optim, run)) + # Get datasets and client_names from folder + client_dict = _parse_data_folder(folder) + # Set up the (func, args) tuples specifying client-wise processes. + p_client = [] + for name, data_dict in client_dict.items(): + client = (_run_client, (ntk_client, name, data_dict)) + p_client.append(client) + # Run each and every process in parallel. + success, outputs = run_as_processes(p_server, *p_client) + assert success, "The FL process failed:\n" + "\n".join( + str(exc) for exc in outputs if isinstance(exc, RuntimeError) + ) + + +if __name__ == "__main__": + quickrun() diff --git a/declearn/quickrun/_split_data.py b/declearn/quickrun/_split_data.py new file mode 100644 index 0000000000000000000000000000000000000000..f748bf1c9910af0cf0019ecb5e9cace030c13766 --- /dev/null +++ b/declearn/quickrun/_split_data.py @@ -0,0 +1,385 @@ +# coding: utf-8 + +"""Script to split data into heterogeneous shards and save them. + +Available splitting scheme: + +* "iid", split the dataset through iid random sampling. +* "labels", split into shards that hold all samples associated +with mutually-exclusive target classes. +* "biased", split the dataset through random sampling according +to a shard-specific random labels distribution. + +Utilities provided here are limited to : + +* 2D Dataset that be directly loaded into numpy arrays, excluding for +instance sparse data +* Single-class classification problems + +""" + +import argparse +import io +import json +import os +import re +import textwrap +from typing import List, Literal, Optional, Tuple, Union + +import numpy as np +import pandas as pd +import requests # type: ignore + +from declearn.dataset import load_data_array + +SOURCE_URL = "https://pjreddie.com/media/files" +DEFAULT_FOLDER = "./examples/quickrun/data" +# TODO remove duplicate with _run.py + + +def load_mnist( + train: bool = True, +) -> Tuple[np.ndarray, np.ndarray]: + """Load the raw MNIST dataset. + + Arguments + --------- + train: bool, default=True + Whether to return the 60k training subset, or the 10k testing one. + Note that the test set should not be used as a validation set. + """ + # Load the desired subset of MNIST + tag = "train" if train else "test" + url = f"{SOURCE_URL}/mnist_{tag}.csv" + data = requests.get(url, verify=False, timeout=20).content + df = pd.read_csv(io.StringIO(data.decode("utf-8")), header=None, sep=",") + return df.iloc[:, 1:].to_numpy(), df[0].to_numpy()[:, None] + + +def load_data( + data: Optional[str] = None, + target: Optional[Union[str, int]] = None, +) -> Tuple[np.ndarray, np.ndarray]: + """Loads a dataset in memory from provided path(s). Requires + inputs type that can be recognised as array by numpy. + + Arguments + --------- + data: str or None, default=None + Path to the data file to import. If None, default to importing + the MNIST train dataset. + target: str or int or None, default=None + If str, path to the labels file to import. If int, column of + the data file to be used as labels. Required if data is not None, + ignored if data is None. + + Note + ----- + Sparse inputs will not be properly parsed by numpy. 
Currently, this function + only works with .csv and .npy files + + """ + if not data: + return load_mnist() + + if os.path.isfile(data): + inputs = load_data_array(data) + inputs = np.asarray(inputs) + else: + raise ValueError("The data path provided is not a valid file") + + if isinstance(target, int): + labels = inputs[target][:, None] + inputs = np.delete(inputs, target, axis=1) + if isinstance(target, str): + if os.path.isfile(target): + labels = load_data_array(target) + labels = np.asarray(labels) + else: + raise ValueError( + "The target provided is invalid, please provide a" + "valid path to a file with labels or indicate" + "which column to use as label in the inputs " + ) + return inputs, labels + + +def _split_iid( + inputs: np.ndarray, + target: np.ndarray, + n_shards: int, + rng: np.random.Generator, +) -> List[Tuple[np.ndarray, np.ndarray]]: + """Split a dataset into shards using iid sampling.""" + order = rng.permutation(len(inputs)) + s_len = len(inputs) // n_shards + split = [] # type: List[Tuple[np.ndarray, np.ndarray]] + for idx in range(n_shards): + srt = idx * s_len + end = (srt + s_len) if idx < (n_shards - 1) else len(order) + shard = order[srt:end] + split.append((inputs[shard], target[shard])) + return split + + +def _split_labels( + inputs: np.ndarray, + target: np.ndarray, + n_shards: int, + rng: np.random.Generator, +) -> List[Tuple[np.ndarray, np.ndarray]]: + """Split a dataset into shards with mutually-exclusive label classes.""" + classes = np.unique(target) + if n_shards > len(classes): + raise ValueError( + f"Cannot share {len(classes)} classes between {n_shards}" + "shards with mutually-exclusive labels." + ) + s_len = len(classes) // n_shards + order = rng.permutation(classes) + split = [] # type: List[Tuple[np.ndarray, np.ndarray]] + for idx in range(n_shards): + srt = idx * s_len + end = (srt + s_len) if idx < (n_shards - 1) else len(order) + shard = np.isin(target, order[srt:end]) + split.append((inputs[shard], target[shard])) + return split + + +def _split_biased( + inputs: np.ndarray, + target: np.ndarray, + n_shards: int, + rng: np.random.Generator, +) -> List[Tuple[np.ndarray, np.ndarray]]: + """Split a dataset into shards with heterogeneous label distributions.""" + classes = np.unique(target) + index = np.arange(len(target)) + s_len = len(target) // n_shards + split = [] # type: List[Tuple[np.ndarray, np.ndarray]] + for idx in range(n_shards): + if idx < (n_shards - 1): + # Draw a random distribution of labels for this node. + logits = np.exp(rng.normal(size=len(classes))) + lprobs = logits[target[index]] + lprobs = lprobs / lprobs.sum() + # Draw samples based on this distribution, without replacement. + shard = rng.choice(index, size=s_len, replace=False, p=lprobs) + index = index[~np.isin(index, shard)] + else: + # For the last node: use the remaining samples. + shard = index + split.append((inputs[shard], target[shard])) + return split + + +def export_shard_to_csv( + path: str, + inputs: np.ndarray, + target: np.ndarray, +) -> None: + """Export an MNIST shard to a csv file.""" + specs = {"dtype": inputs.dtype.char, "shape": list(inputs[0].shape)} + with open(path, "w", encoding="utf-8") as file: + file.write(f"{json.dumps(specs)},target") + for inp, tgt in zip(inputs, target): + file.write(f"\n{inp.tobytes().hex()},{int(tgt)}") + + +def load_mnist_from_csv( + path: str, +) -> Tuple[np.ndarray, np.ndarray]: + """Reload an MNIST shard from a csv file.""" + # Prepare data containers. 
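+    # (Each csv shard starts with a header holding a JSON spec of the
+    # inputs' dtype and shape, followed by one row per sample made of
+    # hex-encoded input bytes and an integer target, as written by
+    # `export_shard_to_csv`.)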
+ inputs = [] # type: List[np.ndarray] + target = [] # type: List[int] + # Parse the csv file. + with open(path, "r", encoding="utf-8") as file: + # Parse input features' specs from the csv header. + specs = json.loads(file.readline().rsplit(",", 1)[0]) + dtype = specs["dtype"] + shape = specs["shape"] + # Iteratively deserialize features and labels from rows. + for row in file: + inp, tgt = row.strip("\n").rsplit(",", 1) + dat = bytes.fromhex(inp) + inputs.append(np.frombuffer(dat, dtype=dtype).reshape(shape)) + target.append(int(tgt)) + # Assemble the data into numpy arrays and return. + return np.array(inputs), np.array(target) + + +def split_data( + folder: str = DEFAULT_FOLDER, # CHECK if good practice + n_shards: int = 5, + data: Optional[str] = None, + target: Optional[Union[str, int]] = None, + scheme: Literal["iid", "labels", "biased"] = "iid", + perc_train: float = 0.8, + seed: Optional[int] = None, + use_csv: bool = False, +) -> None: + """Download and randomly split the MNIST dataset into shards. + #TODO + Parameters + ---------- + folder: str + Path to the folder where to export shard-wise files. + n_shards: int + Number of shards between which to split the MNIST training data. + data: str or None, default=None + Optional path to a folder where to find or download the raw MNIST + data. If None, use a temporary folder. + scheme: {"iid", "labels", "biased"}, default="iid" + Splitting scheme to use. In all cases, shards contain mutually- + exclusive samples and cover the full raw training data. + - If "iid", split the dataset through iid random sampling. + - If "labels", split into shards that hold all samples associated + with mutually-exclusive target classes. + - If "biased", split the dataset through random sampling according + to a shard-specific random labels distribution. + seed: int or None, default=None + Optional seed to the RNG used for all sampling operations. + use_csv: bool, default=False + Whether to export shard-wise csv files rather than pairs of .npy + files. This uses twice as much disk space and requires using the + `load_mnist_from_csv` function to reload instead of `numpy.load` + but is mandatory to have compatibility with the Fed-BioMed API. + """ + # Select the splitting function to be used. + if scheme == "iid": + func = _split_iid + elif scheme == "labels": + func = _split_labels + elif scheme == "biased": + func = _split_biased + else: + raise ValueError(f"Invalid 'scheme' value: '{scheme}'.") + # Set up the RNG, download the raw dataset and split it. + rng = np.random.default_rng(seed) + inputs, labels = load_data(data, target) + os.makedirs(folder, exist_ok=True) + print(f"Splitting data into {n_shards} shards using the {scheme} scheme") + split = func(inputs, labels, n_shards, rng) + # Export the resulting shard-wise data to files. 
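+    # With the default settings, files are laid out as (paths illustrative):
+    #   <folder>/client_<i>/train_data.npy, <folder>/client_<i>/train_target.npy
+    #   <folder>/client_<i>/valid_data.npy, <folder>/client_<i>/valid_target.npy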
+
+    def np_save(data, i, name):
+        """Save an array as .npy under the i-th client's subfolder."""
+        cdir = os.path.join(folder, f"client_{i}")
+        os.makedirs(cdir, exist_ok=True)
+        np.save(os.path.join(cdir, f"{name}.npy"), data)
+
+    for i, (inp, tgt) in enumerate(split):
+        if use_csv:  # TODO
+            path = os.path.join(folder, f"shard_{i}.csv")
+            export_shard_to_csv(path, inp, tgt)
+            continue
+        if not perc_train:
+            np_save(inp, i, "train_data")
+            np_save(tgt, i, "train_target")
+        else:
+            if not 0.0 < perc_train <= 1.0:
+                raise ValueError("'perc_train' should be a float in ]0,1].")
+            n_train = round(len(inp) * perc_train)
+            t_inp, t_tgt = inp[:n_train], tgt[:n_train]
+            v_inp, v_tgt = inp[n_train:], tgt[n_train:]
+            np_save(t_inp, i, "train_data")
+            np_save(t_tgt, i, "train_target")
+            np_save(v_inp, i, "valid_data")
+            np_save(v_tgt, i, "valid_target")
+
+
+def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
+    """Set up and run a command-line arguments parser."""
+    usage = """
+        Download, or load, data and split it into heterogeneous shards.
+
+        This script automates the random splitting of a dataset (by
+        default, the MNIST digits-recognition images dataset's 60k
+        training samples) into shards, based on various schemes.
+        Shards contain mutually-exclusive samples and cover the full
+        raw dataset.
+
+        The implemented schemes are the following:
+        * "iid":
+            Split the dataset through iid random sampling.
+        * "labels":
+            Split the dataset into shards that hold all samples
+            that have mutually-exclusive target classes.
+        * "biased":
+            Split the dataset through random sampling according
+            to a shard-specific random labels distribution.
+    """
+    usage = re.sub("\n *(?=[a-z])", " ", textwrap.dedent(usage))
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.RawTextHelpFormatter,
+        usage=re.sub("- ", "-", usage),
+    )
+    parser.add_argument(
+        "--n_shards",
+        type=int,
+        default=5,
+        help="Number of shards between which to split the training data.",
+    )
+    parser.add_argument(
+        "--root",
+        default=".",
+        dest="folder",
+        help="Path to the root folder where to export raw and split data.",
+    )
+    parser.add_argument(
+        "--data_path",
+        default=None,  # CHECK
+        dest="data",
+        help="Path to the data file to be split.",
+    )
+    parser.add_argument(
+        "--target_path",
+        default=None,  # CHECK
+        dest="target",
+        help="Path to the labels file (or column) to be split.",
+    )
+    schemes_help = """
+        Splitting scheme(s) to use, among {"iid", "labels", "biased"}.
+        If this argument is not specified, the "iid" scheme is used.
+        See details above on the schemes' definition.
+ """ + parser.add_argument( + "--scheme", + action="append", + choices=["iid", "labels", "biased"], + default=["iid"], + dest="schemes", + nargs="+", + help=textwrap.dedent(schemes_help), + ) + parser.add_argument( + "--seed", + default=20221109, + dest="seed", + type=int, + help="RNG seed to use (default: 20221109).", + ) + parser.add_argument( + "--csv", + default=False, + dest="use_csv", + type=bool, + help="Export data as csv files (for use with Fed-BioMed).", + ) + return parser.parse_args(args) + + +def main(args: Optional[List[str]] = None) -> None: + """Run splitting schemes based on commandline-input arguments.""" + cmdargs = parse_args(args) + for scheme in cmdargs.schemes: + split_data( + folder=os.path.join(cmdargs.folder, f"data_{scheme}"), + n_shards=cmdargs.n_shards, + data=cmdargs.data, + target=cmdargs.target, + scheme=scheme, + seed=cmdargs.seed, + use_csv=cmdargs.use_csv, + ) + + +if __name__ == "__main__": + main() diff --git a/declearn/test_utils/__init__.py b/declearn/test_utils/__init__.py index ff6ffd54c1d7166ab64222f956937683346b0167..30e7aa42484d4e737c3aaf94e6bff08cfe856839 100644 --- a/declearn/test_utils/__init__.py +++ b/declearn/test_utils/__init__.py @@ -36,7 +36,6 @@ from ._assertions import ( ) from ._gen_ssl import generate_ssl_certificates from ._imports import make_importable -from ._multiprocess import run_as_processes from ._vectors import ( FrameworkType, GradientsTestCase, diff --git a/declearn/utils/__init__.py b/declearn/utils/__init__.py index 67a2e6633fe6a29efaacf1ebfc74c9cdd6f422b2..d639d77f49adee1414fffc8e643f829e0bc5effb 100644 --- a/declearn/utils/__init__.py +++ b/declearn/utils/__init__.py @@ -128,4 +128,5 @@ from ._serialize import ( deserialize_object, serialize_object, ) +from ._multiprocess import run_as_processes from ._toml_config import TomlConfig diff --git a/declearn/test_utils/_multiprocess.py b/declearn/utils/_multiprocess.py similarity index 100% rename from declearn/test_utils/_multiprocess.py rename to declearn/utils/_multiprocess.py diff --git a/declearn/utils/_toml_config.py b/declearn/utils/_toml_config.py index cfe96c2bb21b790db7eddab3f8175f9c21d2656e..3dacd2495a180bbcca20ac01a0c0de07db4a202a 100644 --- a/declearn/utils/_toml_config.py +++ b/declearn/utils/_toml_config.py @@ -293,10 +293,7 @@ class TomlConfig: raise TypeError(f"Failed to parse inputs for field {field.name}.") @classmethod - def from_toml( - cls, - path: str, - ) -> Self: + def from_toml(cls, path: str, warn: bool = True) -> Self: """Parse a structured configuration from a TOML file. The parsed TOML configuration file should be organized into sections @@ -315,6 +312,10 @@ class TomlConfig: path: str Path to a TOML configuration file, that provides with the hyper-parameters making up for the FL "run" configuration. + warn: bool, default=True + Boolean indicating whether to raise a warning when some + fields are unused. Useful for cases where unused fields are + expected, e.g. quickrun. Raises ------ diff --git a/examples/heart-uci/run.py b/examples/heart-uci/run.py index 4a5ac9ade7d8ef050d332c23edcfa4c76cb8401f..1a35c362f7aaa86d194736e25eba54f22510eb32 100644 --- a/examples/heart-uci/run.py +++ b/examples/heart-uci/run.py @@ -20,11 +20,8 @@ import os import tempfile -from declearn.test_utils import ( - generate_ssl_certificates, - make_importable, - run_as_processes, -) +from declearn.test_utils import generate_ssl_certificates, make_importable +from declearn.utils import run_as_processes # Perform local imports. 
# pylint: disable=wrong-import-position, wrong-import-order diff --git a/examples/quickrun/config.toml b/examples/quickrun/config.toml new file mode 100644 index 0000000000000000000000000000000000000000..21eef2d453836c5f0d0fb7b88280a846ac631a59 --- /dev/null +++ b/examples/quickrun/config.toml @@ -0,0 +1,31 @@ +[network_server] +protocol = "websockets" +host = "127.0.0.1" +port = 8765 + +[network_client] +protocol = "websockets" +server_uri = "ws://localhost:8765" +name = "replaceme" + +[optim] +aggregator = "averaging" +server_opt = 1.0 + + [optim.client_opt] + lrate = 0.001 + regularizers = ["lasso", {alpha = 0.1}] + +[run] +rounds = 10 + + [run.register] + min_clients = 3 + + [run.training] + n_epoch = 1 + batch_size = 48 + drop_remainder = false + + + diff --git a/examples/quickrun/data.py b/examples/quickrun/data.py new file mode 100644 index 0000000000000000000000000000000000000000..22b885e62d8c28dbe0ec21e1c40290520b92cae1 --- /dev/null +++ b/examples/quickrun/data.py @@ -0,0 +1,307 @@ +# coding: utf-8 + +"""Script to download and split MNIST data into heterogeneous shards.""" + +import argparse +import io +import json +import os +import re +import sys +import tempfile +import textwrap +from typing import List, Literal, Optional, Tuple + +import numpy as np +import pandas as pd +import requests # type: ignore + +SOURCE_URL = "https://pjreddie.com/media/files/" + +# TODO reduce arg numbers in functions using SplitConfig + + +def load_mnist( + folder: Optional[str] = None, + train: bool = True, +) -> Tuple[np.ndarray, np.ndarray]: + """Load the raw MNIST dataset, downloading it if needed. + + Arguments + --------- + folder: str or None, default=None + Optional path to a root folder where to find or download the + raw MNIST data. If None, use a temporary folder. + train: bool, default=True + Whether to return the 60k training subset, or the 10k testing one. + """ + # Optionally use a temporary folder where to download the raw data. + if folder is None: + with tempfile.TemporaryDirectory() as tmpdir: + return load_mnist(tmpdir, train) + # Load the desired subset of MNIST + tag = "train" if train else "test" + url = f"{SOURCE_URL}mnist_{tag}.csv" + data = requests.get(url, verify=False, timeout=20).content + df = pd.read_csv(io.StringIO(data.decode("utf-8")), header=None, sep=",") + return df.iloc[:, 1:].to_numpy(), df[0].to_numpy()[:, None] + + +def split_mnist( + folder: str, + n_shards: int, + scheme: Literal["iid", "labels", "biased"], + seed: Optional[int] = None, + mnist: Optional[str] = None, + use_csv: bool = False, +) -> None: + """Download and randomly split the MNIST dataset into shards. + + Parameters + ---------- + folder: str + Path to the folder where to export shard-wise files. + n_shards: int + Number of shards between which to split the MNIST training data. + scheme: {"iid", "labels", "biased"} + Splitting scheme to use. In all cases, shards contain mutually- + exclusive samples and cover the full raw training data. + - If "iid", split the dataset through iid random sampling. + - If "labels", split into shards that hold all samples associated + with mutually-exclusive target classes. + - If "biased", split the dataset through random sampling according + to a shard-specific random labels distribution. + seed: int or None, default=None + Optional seed to the RNG used for all sampling operations. + mnist: str or None, default=None + Optional path to a folder where to find or download the raw MNIST + data. If None, use a temporary folder. 
+ use_csv: bool, default=False + Whether to export shard-wise csv files rather than pairs of .npy + files. This uses twice as much disk space and requires using the + `load_mnist_from_csv` function to reload instead of `numpy.load` + but is mandatory to have compatibility with the Fed-BioMed API. + """ + # Select the splitting function to be used. + if scheme == "iid": + func = _split_iid + elif scheme == "labels": + func = _split_labels + elif scheme == "biased": + func = _split_biased + else: + raise ValueError(f"Invalid 'scheme' value: '{scheme}'.") + # Set up the RNG, download the raw dataset and split it. + rng = np.random.default_rng(seed) + inputs, target = load_mnist(mnist, train=True) + os.makedirs(folder, exist_ok=True) + print(f"Splitting MNIST into {n_shards} shards using the {scheme} scheme") + split = func(inputs, target, n_shards, rng) + # Export the resulting shard-wise data to files. + for idx, (inp, tgt) in enumerate(split): + if use_csv: + path = os.path.join(folder, f"shard_{idx}.csv") + export_shard_to_csv(path, inp, tgt) + else: + np.save(os.path.join(folder, f"shard_{idx}_inputs.npy"), inp) + np.save(os.path.join(folder, f"shard_{idx}_target.npy"), tgt) + + +def _split_iid( + inputs: np.ndarray, + target: np.ndarray, + n_shards: int, + rng: np.random.Generator, +) -> List[Tuple[np.ndarray, np.ndarray]]: + """Split a dataset into shards using iid sampling.""" + order = rng.permutation(len(inputs)) + s_len = len(inputs) // n_shards + split = [] # type: List[Tuple[np.ndarray, np.ndarray]] + for idx in range(n_shards): + srt = idx * s_len + end = (srt + s_len) if idx < (n_shards - 1) else len(order) + shard = order[srt:end] + split.append((inputs[shard], target[shard])) + return split + + +def _split_labels( + inputs: np.ndarray, + target: np.ndarray, + n_shards: int, + rng: np.random.Generator, +) -> List[Tuple[np.ndarray, np.ndarray]]: + """Split a dataset into shards with mutually-exclusive label classes.""" + classes = np.unique(target) + if n_shards > len(classes): + raise ValueError( + f"Cannot share {len(classes)} classes between {n_shards}" + "shards with mutually-exclusive labels." + ) + s_len = len(classes) // n_shards + order = rng.permutation(classes) + split = [] # type: List[Tuple[np.ndarray, np.ndarray]] + for idx in range(n_shards): + srt = idx * s_len + end = (srt + s_len) if idx < (n_shards - 1) else len(order) + shard = np.isin(target, order[srt:end]) + split.append((inputs[shard], target[shard])) + return split + + +def _split_biased( + inputs: np.ndarray, + target: np.ndarray, + n_shards: int, + rng: np.random.Generator, +) -> List[Tuple[np.ndarray, np.ndarray]]: + """Split a dataset into shards with heterogeneous label distributions.""" + classes = np.unique(target) + index = np.arange(len(target)) + s_len = len(target) // n_shards + split = [] # type: List[Tuple[np.ndarray, np.ndarray]] + for idx in range(n_shards): + if idx < (n_shards - 1): + # Draw a random distribution of labels for this node. + logits = np.exp(rng.normal(size=len(classes))) + lprobs = logits[target[index]] + lprobs = lprobs / lprobs.sum() + # Draw samples based on this distribution, without replacement. + shard = rng.choice(index, size=s_len, replace=False, p=lprobs) + index = index[~np.isin(index, shard)] + else: + # For the last node: use the remaining samples. 
+ shard = index + split.append((inputs[shard], target[shard])) + return split + + +def export_shard_to_csv( + path: str, + inputs: np.ndarray, + target: np.ndarray, +) -> None: + """Export an MNIST shard to a csv file.""" + specs = {"dtype": inputs.dtype.char, "shape": list(inputs[0].shape)} + with open(path, "w", encoding="utf-8") as file: + file.write(f"{json.dumps(specs)},target") + for inp, tgt in zip(inputs, target): + file.write(f"\n{inp.tobytes().hex()},{int(tgt)}") + + +def load_mnist_from_csv( + path: str, +) -> Tuple[np.ndarray, np.ndarray]: + """Reload an MNIST shard from a csv file.""" + # Prepare data containers. + inputs = [] # type: List[np.ndarray] + target = [] # type: List[int] + # Parse the csv file. + with open(path, "r", encoding="utf-8") as file: + # Parse input features' specs from the csv header. + specs = json.loads(file.readline().rsplit(",", 1)[0]) + dtype = specs["dtype"] + shape = specs["shape"] + # Iteratively deserialize features and labels from rows. + for row in file: + inp, tgt = row.strip("\n").rsplit(",", 1) + dat = bytes.fromhex(inp) + inputs.append(np.frombuffer(dat, dtype=dtype).reshape(shape)) + target.append(int(tgt)) + # Assemble the data into numpy arrays and return. + return np.array(inputs), np.array(target) + + +def report_download_progress( + chunk_number: int, chunk_size: int, file_size: int +): + if file_size != -1: + percent = min(1, (chunk_number * chunk_size) / file_size) + bar = "#" * int(64 * percent) + sys.stdout.write("\r0% |{:<64}| {}%".format(bar, int(percent * 100))) + + +def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace: + """Set up and run a command-line arguments parser.""" + usage = """ + Download and split MNIST data into heterogeneous shards. + + This script automates the random splitting of the MNIST digits- + recognition images dataset's 60k training samples into shards, + based on various schemes. Shards contain mutually-exclusive + samples and cover the full raw dataset. + + The implemented schemes are the following: + * "iid": + Split the dataset through iid random sampling. + * "labels": + Split the dataset into shards that hold all samples + that have mutually-exclusive target classes. + * "biased": + Split the dataset through random sampling according + to a shard-specific random labels distribution. + """ + usage = re.sub("\n *(?=[a-z])", " ", textwrap.dedent(usage)) + parser = argparse.ArgumentParser( + formatter_class=argparse.RawTextHelpFormatter, + usage=re.sub("- ", "-", usage), + ) + parser.add_argument( + "--n_shards", + type=int, + default=5, + help="Number of shards between which to split the MNIST training data.", + ) + parser.add_argument( + "--root", + default=".", + dest="folder", + help="Path to the root folder where to export raw and split data.", + ) + schemes_help = """ + Splitting scheme(s) to use, among {"iid", "labels", "biased"}. + If this argument is not specified, all three values are used. + See details above on the schemes' definition. 
+ """ + parser.add_argument( + "--scheme", + action="append", + choices=["iid", "labels", "biased"], + default=["iid"], + dest="schemes", + nargs="+", + help=textwrap.dedent(schemes_help), + ) + parser.add_argument( + "--seed", + default=20221109, + dest="seed", + type=int, + help="RNG seed to use (default: 20221109).", + ) + parser.add_argument( + "--csv", + default=False, + dest="use_csv", + type=bool, + help="Export data as csv files (for use with Fed-BioMed).", + ) + return parser.parse_args(args) + + +def main(args: Optional[List[str]] = None) -> None: + """Run splitting schemes based on commandline-input arguments.""" + cmdargs = parse_args(args) + for scheme in cmdargs.schemes or ["iid", "labels", "biased"]: + split_mnist( + folder=os.path.join(cmdargs.folder, f"mnist_{scheme}"), + n_shards=cmdargs.n_shards, + scheme=scheme, + seed=cmdargs.seed, + mnist=cmdargs.folder, + use_csv=cmdargs.use_csv, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/quickrun/model.py b/examples/quickrun/model.py new file mode 100644 index 0000000000000000000000000000000000000000..c9984ba70bf6e3a2a39ec8a8fe7af27719bd391e --- /dev/null +++ b/examples/quickrun/model.py @@ -0,0 +1,33 @@ +"""Wrapping a torch model""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from declearn.model.torch import TorchModel + + +class Net(nn.Module): + + """A basic CNN, directly copied from Torch's 60 min blitz""" + + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = torch.flatten(x, 1) # flatten all dimensions except batch + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + + +MyModel = TorchModel(Net(), loss=nn.NLLLoss()) diff --git a/test/communication/test_routines.py b/test/communication/test_routines.py new file mode 100644 index 0000000000000000000000000000000000000000..c00ff5a327ba0a982f39e9492d257abe858aff5a --- /dev/null +++ b/test/communication/test_routines.py @@ -0,0 +1,196 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functional test for declearn.communication classes. + +The test implemented here spawns a NetworkServer endpoint as well as one +or multiple NetworkClient ones, then runs parallelly routines that have +the clients register, and both sides exchange dummy messages. As such, +it only verifies that messages passing works, and does not constitute a +proper (ensemble of) unit test(s) of the classes. 
+ +However, if this passes, it means that registration and basic +message passing work properly, using the following scenarios: +* gRPC or WebSockets protocol +* SSL-secured communications or not +* 1-client or 3-clients cases + +Note that the tests are somewhat slow when collected by pytest, +and that they make use of the multiprocessing library to isolate +the server and individual clients - which is not required when +running the code manually, and might require using '--full-trace' +pytest option to debug in case a test fails. + +Note: running code that uses `asyncio.gather` on concurrent coroutines +is unsuccessful with gRPC due to spawned clients sharing the same peer +context. This may be fixed by implementing proper authentication. +""" + +import asyncio +from typing import Any, Callable, Dict, List, Tuple + +import pytest + +from declearn.communication import ( + build_client, + build_server, + list_available_protocols, +) +from declearn.communication.api import NetworkClient, NetworkServer +from declearn.communication.messaging import GenericMessage +from declearn.utils import run_as_processes + + +async def client_routine( + client: NetworkClient, +) -> None: + """Basic client testing routine.""" + print("Registering") + await client.register({"foo": "bar"}) + print("Receiving") + message = await client.check_message() + print(message) + print("Sending") + await client.send_message(GenericMessage(action="maybe", params={})) + print("Receiving") + message = await client.check_message() + print(message) + print("Sending") + await client.send_message(message) + print("Done!") + + +async def server_routine( + server: NetworkServer, + nb_clients: int = 1, +) -> None: + """Basic server testing routine.""" + data_info = await server.wait_for_clients( + min_clients=nb_clients, max_clients=nb_clients, timeout=5 + ) + print(data_info) + print("Sending") + await server.broadcast_message( + GenericMessage(action="train", params={"let's": "go"}) + ) + print("Receiving") + messages = await server.wait_for_messages() + print(messages) + print("Sending") + messages = { + client: GenericMessage("hello", {"name": client}) + for client in server.client_names + } + await server.send_messages(messages) + print("Receiving") + messages = await server.wait_for_messages() + print(messages) + print("Closing") + + +@pytest.mark.parametrize("nb_clients", [1, 3], ids=["1_client", "3_clients"]) +@pytest.mark.parametrize("use_ssl", [False, True], ids=["ssl", "unsafe"]) +@pytest.mark.parametrize("protocol", list_available_protocols()) +def test_routines( + protocol: str, + nb_clients: int, + use_ssl: bool, + ssl_cert: Dict[str, str], +) -> None: + """Test that the defined server and client routines run properly.""" + run_test_routines(protocol, nb_clients, use_ssl, ssl_cert) + + +def run_test_routines( + protocol: str, + nb_clients: int, + use_ssl: bool, + ssl_cert: Dict[str, str], +) -> None: + """Test that the defined server and client routines run properly.""" + # Set up (func, args) tuples that specify concurrent routines. + args = (protocol, nb_clients, use_ssl, ssl_cert) + routines = [_build_server_func(*args)] + routines.extend(_build_client_funcs(*args)) + # Run the former using isolated processes. + success, outputs = run_as_processes(*routines) + # Assert that all processes terminated properly. 
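+    # (`run_as_processes` returns a (success, outputs) pair: `success` is
+    # False if any spawned process failed, in which case `outputs` holds the
+    # associated RuntimeError instances, reported through the assert below.)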
+ assert success, "Routines failed:\n" + "\n".join( + [str(exc) for exc in outputs if isinstance(exc, RuntimeError)] + ) + + +def _build_server_func( + protocol: str, + nb_clients: int, + use_ssl: bool, + ssl_cert: Dict[str, str], +) -> Tuple[Callable[..., None], Tuple[Any, ...]]: + """Return arguments to spawn and use a NetworkServer in a process.""" + server_cfg = { + "protocol": protocol, + "host": "127.0.0.1", + "port": 8765, + "certificate": ssl_cert["server_cert"] if use_ssl else None, + "private_key": ssl_cert["server_pkey"] if use_ssl else None, + } # type: Dict[str, Any] + + # Define a coroutine that spawns and runs a server. + async def server_coroutine() -> None: + """Spawn a client and run `server_routine` in its context.""" + nonlocal nb_clients, server_cfg + async with build_server(**server_cfg) as server: + await server_routine(server, nb_clients) + + # Define a routine that runs the former. + def server_func() -> None: + """Run `server_coroutine`.""" + asyncio.run(server_coroutine()) + + # Return the former as a (func, arg) tuple. + return (server_func, tuple()) + + +def _build_client_funcs( + protocol: str, + nb_clients: int, + use_ssl: bool, + ssl_cert: Dict[str, str], +) -> List[Tuple[Callable[..., None], Tuple[Any, ...]]]: + """Return arguments to spawn and use NetworkClient objects in processes.""" + certificate = ssl_cert["client_cert"] if use_ssl else None + server_uri = "localhost:8765" + if protocol == "websockets": + server_uri = f"ws{'s' * use_ssl}://{server_uri}" + + # Define a coroutine that spawns and runs a client. + async def client_coroutine( + name: str, + ) -> None: + """Spawn a client and run `client_routine` in its context.""" + nonlocal certificate, protocol, server_uri + args = (protocol, server_uri, name, certificate) + async with build_client(*args) as client: + await client_routine(client) + + # Define a routine that runs the former. + def client_func(name: str) -> None: + """Run `client_coroutine`.""" + asyncio.run(client_coroutine(name)) + + # Return a list of (func, args) tuples. + return [(client_func, (f"client_{idx}",)) for idx in range(nb_clients)] diff --git a/test/functional/test_main.py b/test/functional/test_main.py index 2b0986a93c2ebc3c66f85730749eccd592928e8a..1f1a0bfe67ab2ac0aea456c62c744f12000a2e97 100644 --- a/test/functional/test_main.py +++ b/test/functional/test_main.py @@ -34,7 +34,7 @@ from declearn.dataset import InMemoryDataset from declearn.model.api import Model from declearn.model.sklearn import SklearnSGDModel from declearn.main import FederatedClient, FederatedServer -from declearn.test_utils import run_as_processes +from declearn.utils import run_as_processes from declearn.utils import set_device_policy # Select the subset of tests to run, based on framework availability. diff --git a/test/functional/test_regression.py b/test/functional/test_regression.py index 66a01b67b9d6359e46b82ab8d55649fe7fce46fd..e065e364f4618f0b20ceaf000c3fb8ca759460aa 100644 --- a/test/functional/test_regression.py +++ b/test/functional/test_regression.py @@ -58,7 +58,8 @@ from declearn.metrics import RSquared from declearn.model.api import Model from declearn.model.sklearn import SklearnSGDModel from declearn.optimizer import Optimizer -from declearn.test_utils import FrameworkType, run_as_processes +from declearn.test_utils import FrameworkType +from declearn.utils import run_as_processes from declearn.utils import set_device_policy # pylint: disable=ungrouped-imports; optional frameworks' dependencies
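
# Illustrative usage sketch of the quickrun tooling added above (assuming
# declearn is installed and the commands are run from the repository root):
#
#   python declearn/quickrun/_split_data.py --root examples/quickrun
#   python declearn/quickrun/_run.py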