diff --git a/declearn/dataset/utils/__init__.py b/declearn/dataset/utils/__init__.py index dfebe91da3623f4f8cb87a5b3a6f1aae5de1681e..c19f9cb1efe452c6763bff5c3286ca5e9f442dc0 100644 --- a/declearn/dataset/utils/__init__.py +++ b/declearn/dataset/utils/__init__.py @@ -30,6 +30,13 @@ to and from various file formats: Backend to load a sparse matrix from a dump file. * [sparse_to_file][declearn.dataset.utils.sparse_to_file]: Backend to save a sparse matrix to a dump file + +Data splitting +-------------- +* [split_multi_classif_dataset] +[declearn.dataset.utils.split_multi_classif_dataset]: + Split a classification dataset into (opt. heterogeneous) shards. """ from ._sparse import sparse_from_file, sparse_to_file from ._save_load import load_data_array, save_data_array +from ._split_classif import split_multi_classif_dataset diff --git a/declearn/dataset/utils/_split_classif.py b/declearn/dataset/utils/_split_classif.py new file mode 100644 index 0000000000000000000000000000000000000000..7d941d463742fc3fd9dca22695a2057d08f2551b --- /dev/null +++ b/declearn/dataset/utils/_split_classif.py @@ -0,0 +1,170 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utils to split a multi-category classification dataset into shards.""" + +from typing import List, Literal, Optional, Tuple + +import numpy as np + + +__all__ = [ + "split_multi_classif_dataset", +] + + +def split_multi_classif_dataset( + dataset: Tuple[np.ndarray, np.ndarray], + n_shards: int, + scheme: Literal["iid", "labels", "biased"], + p_valid: float = 0.2, + seed: Optional[int] = None, +) -> List[Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]]: + """Split a classification dataset into (opt. heterogeneous) shards. + + The data-splitting schemes are the following: + + - If "iid", split the dataset through iid random sampling. + - If "labels", split into shards that hold all samples associated + with mutually-exclusive target classes. + - If "biased", split the dataset through random sampling according + to a shard-specific random labels distribution. + + Parameters + ---------- + dataset: tuple(np.ndarray, np.ndarray) + Raw dataset, as a pair of numpy arrays that respectively contain + the input features and (aligned) labels. + n_shards: int + Number of shards between which to split the dataset. + scheme: {"iid", "labels", "biased"} + Splitting scheme to use. In all cases, shards contain mutually- + exclusive samples and cover the full dataset. See details above. + p_valid: float, default=0.2 + Share of each shard to turn into a validation subset. + seed: int or None, default=None + Optional seed to the RNG used for all sampling operations. + + Returns + ------- + shards: list[((np.ndarray, np.ndarray), (np.ndarray, np.ndarray))] + List of dataset shards, where each element is formatted as a + tuple of tuples: `((x_train, y_train), (x_valid, y_valid))`. 
+ + Raises + ------ + ValueError + If `scheme` has an invalid value. + """ + # Select the splitting function to be used. + if scheme == "iid": + func = split_iid + elif scheme == "labels": + func = split_labels + elif scheme == "biased": + func = split_biased + else: + raise ValueError(f"Invalid 'scheme' value: '{scheme}'.") + # Set up the RNG and split the dataset into shards. + rng = np.random.default_rng(seed) + inputs, target = dataset + split = func(inputs, target, n_shards, rng) + # Further split shards into training and validation subsets, and return. + return [train_valid_split(inp, tgt, p_valid, rng) for inp, tgt in split] + + +def split_iid( + inputs: np.ndarray, + target: np.ndarray, + n_shards: int, + rng: np.random.Generator, +) -> List[Tuple[np.ndarray, np.ndarray]]: + """Split a dataset into shards using iid sampling.""" + order = rng.permutation(len(inputs)) + s_len = len(inputs) // n_shards + split = [] # type: List[Tuple[np.ndarray, np.ndarray]] + for idx in range(n_shards): + srt = idx * s_len + end = (srt + s_len) if idx < (n_shards - 1) else len(order) + shard = order[srt:end] + split.append((inputs[shard], target[shard])) + return split + + +def split_labels( + inputs: np.ndarray, + target: np.ndarray, + n_shards: int, + rng: np.random.Generator, +) -> List[Tuple[np.ndarray, np.ndarray]]: + """Split a dataset into shards with mutually-exclusive label classes.""" + classes = np.unique(target) + if n_shards > len(classes): + raise ValueError( + f"Cannot share {len(classes)} classes between {n_shards}" + "shards with mutually-exclusive labels." + ) + s_len = len(classes) // n_shards + order = rng.permutation(classes) + split = [] # type: List[Tuple[np.ndarray, np.ndarray]] + for idx in range(n_shards): + srt = idx * s_len + end = (srt + s_len) if idx < (n_shards - 1) else len(order) + shard = np.isin(target, order[srt:end]) + split.append((inputs[shard], target[shard])) + return split + + +def split_biased( + inputs: np.ndarray, + target: np.ndarray, + n_shards: int, + rng: np.random.Generator, +) -> List[Tuple[np.ndarray, np.ndarray]]: + """Split a dataset into shards with heterogeneous label distributions.""" + classes = np.unique(target) + index = np.arange(len(target)) + s_len = len(target) // n_shards + split = [] # type: List[Tuple[np.ndarray, np.ndarray]] + for idx in range(n_shards): + if idx < (n_shards - 1): + # Draw a random distribution of labels for this node. + logits = np.exp(rng.normal(size=len(classes))) + lprobs = logits[target[index]] + lprobs = lprobs / lprobs.sum() + # Draw samples based on this distribution, without replacement. + shard = rng.choice(index, size=s_len, replace=False, p=lprobs) + index = index[~np.isin(index, shard)] + else: + # For the last node: use the remaining samples. 
+ shard = index + split.append((inputs[shard], target[shard])) + return split + + +def train_valid_split( + inputs: np.ndarray, + target: np.ndarray, + p_valid: float, + rng: np.random.Generator, +) -> Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]: + """Split a dataset between train and validation using iid sampling.""" + order = rng.permutation(len(inputs)) + v_len = np.ceil(len(inputs) * p_valid).astype(int) + train = inputs[order[v_len:]], target[order[v_len:]] + valid = inputs[order[:v_len]], target[order[:v_len]] + return train, valid diff --git a/declearn/quickrun/_split_data.py b/declearn/quickrun/_split_data.py index fd8b3387fc4a4c72a2b06860500ce6aaabd83f25..4e595f424ab7a3d5c33b6df0142042c25cb225ae 100644 --- a/declearn/quickrun/_split_data.py +++ b/declearn/quickrun/_split_data.py @@ -33,12 +33,12 @@ instance sparse data """ import os -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple, Union import numpy as np from declearn.dataset.examples import load_mnist -from declearn.dataset.utils import load_data_array +from declearn.dataset.utils import load_data_array, split_multi_classif_dataset from declearn.quickrun._config import DataSplitConfig @@ -91,76 +91,6 @@ def load_data( return inputs, labels -def _split_iid( - inputs: np.ndarray, - target: np.ndarray, - n_shards: int, - rng: np.random.Generator, -) -> List[Tuple[np.ndarray, np.ndarray]]: - """Split a dataset into shards using iid sampling.""" - order = rng.permutation(len(inputs)) - s_len = len(inputs) // n_shards - split = [] # type: List[Tuple[np.ndarray, np.ndarray]] - for idx in range(n_shards): - srt = idx * s_len - end = (srt + s_len) if idx < (n_shards - 1) else len(order) - shard = order[srt:end] - split.append((inputs[shard], target[shard])) - return split - - -def _split_labels( - inputs: np.ndarray, - target: np.ndarray, - n_shards: int, - rng: np.random.Generator, -) -> List[Tuple[np.ndarray, np.ndarray]]: - """Split a dataset into shards with mutually-exclusive label classes.""" - classes = np.unique(target) - if n_shards > len(classes): - raise ValueError( - f"Cannot share {len(classes)} classes between {n_shards}" - "shards with mutually-exclusive labels." - ) - s_len = len(classes) // n_shards - order = rng.permutation(classes) - split = [] # type: List[Tuple[np.ndarray, np.ndarray]] - for idx in range(n_shards): - srt = idx * s_len - end = (srt + s_len) if idx < (n_shards - 1) else len(order) - shard = np.isin(target, order[srt:end]) - shuffle = rng.permutation(shard.sum()) - split.append((inputs[shard][shuffle], target[shard][shuffle])) - return split - - -def _split_biased( - inputs: np.ndarray, - target: np.ndarray, - n_shards: int, - rng: np.random.Generator, -) -> List[Tuple[np.ndarray, np.ndarray]]: - """Split a dataset into shards with heterogeneous label distributions.""" - classes = np.unique(target) - index = np.arange(len(target)) - s_len = len(target) // n_shards - split = [] # type: List[Tuple[np.ndarray, np.ndarray]] - for idx in range(n_shards): - if idx < (n_shards - 1): - # Draw a random distribution of labels for this node. - logits = np.exp(rng.normal(size=len(classes))) - lprobs = logits[target[index]] - lprobs = lprobs / lprobs.sum() - # Draw samples based on this distribution, without replacement. - shard = rng.choice(index, size=s_len, replace=False, p=lprobs) - index = index[~np.isin(index, shard)] - else: - # For the last node: use the remaining samples. 
- shard = index - split.append((inputs[shard], target[shard])) - return split - - def split_data(data_config: DataSplitConfig, folder: str) -> None: """Download and randomly split a dataset into shards. @@ -180,54 +110,33 @@ def split_data(data_config: DataSplitConfig, folder: str) -> None: data_config: DataSplitConfig A DataSplitConfig instance, see class documentation for details """ - - def np_save(folder, data, i, name): - data_dir = os.path.join(folder, f"client_{i}") - os.makedirs(data_dir, exist_ok=True) - np.save(os.path.join(data_dir, f"{name}.npy"), data) - - # Overwrite default folder if provided - scheme = data_config.scheme - name = f"data_{scheme}" - data_file = data_config.data_file - label_file = data_config.label_file + # Select output folder. if data_config.data_folder: folder = os.path.dirname(data_config.data_folder) - name = os.path.split(data_config.data_folder)[-1] - data_file = os.path.abspath(data_config.data_file) - label_file = os.path.abspath(data_config.label_file) - # Select the splitting function to be used. - if scheme == "iid": - func = _split_iid - elif scheme == "labels": - func = _split_labels - elif scheme == "biased": - func = _split_biased else: - raise ValueError(f"Invalid 'scheme' value: '{scheme}'.") - # Set up the RNG, download the raw dataset and split it. - rng = np.random.default_rng(data_config.seed) - - inputs, labels = load_data(data_file, label_file) + folder = f"data_{data_config.scheme}" + # Value-check the 'perc_train' parameter. + if not 0.0 < data_config.perc_train <= 1.0: + raise ValueError("'perc_train' should be a float in ]0,1]") + # Load the dataset and split it. + inputs, labels = load_data(data_config.data_file, data_config.label_file) print( - f"Splitting data into {data_config.n_shards}" - f"shards using the {scheme} scheme" + f"Splitting data into {data_config.n_shards} shards " + f"using the '{data_config.scheme}' scheme." + ) + split = split_multi_classif_dataset( + dataset=(inputs, labels), + n_shards=data_config.n_shards, + scheme=data_config.scheme, # type: ignore + p_valid=(1 - data_config.perc_train), + seed=data_config.seed, ) - split = func(inputs, labels, data_config.n_shards, rng) # Export the resulting shard-wise data to files. - folder = os.path.join(folder, name) - for i, (inp, tgt) in enumerate(split): - perc_train = data_config.perc_train - if not perc_train: - np_save(folder, inp, i, "train_data") - np_save(folder, tgt, i, "train_target") - else: - if perc_train > 1.0 or perc_train < 0.0: - raise ValueError("perc_train should be a float in ]0,1]") - n_train = round(len(inp) * perc_train) - t_inp, t_tgt = inp[:n_train], tgt[:n_train] - v_inp, v_tgt = inp[n_train:], tgt[n_train:] - np_save(folder, t_inp, i, "train_data") - np_save(folder, t_tgt, i, "train_target") - np_save(folder, v_inp, i, "valid_data") - np_save(folder, v_tgt, i, "valid_target") + for idx, ((x_train, y_train), (x_valid, y_valid)) in enumerate(split): + subdir = os.path.join(folder, f"client_{idx}") + os.makedirs(subdir, exist_ok=True) + np.save(os.path.join(subdir, "train_data.npy"), x_train) + np.save(os.path.join(subdir, "train_target.npy"), y_train) + if len(x_valid): + np.save(os.path.join(subdir, "valid_data.npy"), x_valid) + np.save(os.path.join(subdir, "valid_target.npy"), y_valid)
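
Usage sketch for the new helper: the snippet below illustrates how `split_multi_classif_dataset` may be called on its own, based on the signature and return format introduced above. The toy arrays and their shapes are illustrative assumptions; note that, as written, the "biased" scheme indexes a per-class logits vector with the raw label values, so labels are expected to be integer-coded in `[0, n_classes)` (MNIST-style digit labels, as used by the quickrun tooling, satisfy this).

```python
import numpy as np

from declearn.dataset.utils import split_multi_classif_dataset

# Toy data (illustrative): 1000 samples, 8 features, integer labels in [0, 10).
rng = np.random.default_rng(seed=0)
x_data = rng.normal(size=(1000, 8))
y_data = rng.integers(0, 10, size=1000)

# Split into 4 shards with shard-specific (heterogeneous) label distributions,
# keeping 20% of each shard aside as a validation subset.
shards = split_multi_classif_dataset(
    dataset=(x_data, y_data),
    n_shards=4,
    scheme="biased",
    p_valid=0.2,
    seed=0,
)

# Each shard is formatted as ((x_train, y_train), (x_valid, y_valid)).
(x_train, y_train), (x_valid, y_valid) = shards[0]
```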