Mentions légales du service

Skip to content
Snippets Groups Projects
Verified Commit 83811146 authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Add 'declearn.dataset.utils.split_multi_classif_dataset'.

parent 8b38b10f
No related branches found
No related tags found
1 merge request!41Quickrun mode
......@@ -30,6 +30,13 @@ to and from various file formats:
Backend to load a sparse matrix from a dump file.
* [sparse_to_file][declearn.dataset.utils.sparse_to_file]:
    Backend to save a sparse matrix to a dump file.
Data splitting
--------------
* [split_multi_classif_dataset]
[declearn.dataset.utils.split_multi_classif_dataset]:
Split a classification dataset into (opt. heterogeneous) shards.
"""
from ._sparse import sparse_from_file, sparse_to_file
from ._save_load import load_data_array, save_data_array
from ._split_classif import split_multi_classif_dataset
# coding: utf-8
# Copyright 2023 Inria (Institut National de Recherche en Informatique
# et Automatique)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils to split a multi-category classification dataset into shards."""
from typing import List, Literal, Optional, Tuple
import numpy as np
__all__ = [
"split_multi_classif_dataset",
]
def split_multi_classif_dataset(
    dataset: Tuple[np.ndarray, np.ndarray],
    n_shards: int,
    scheme: Literal["iid", "labels", "biased"],
    p_valid: float = 0.2,
    seed: Optional[int] = None,
) -> List[Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]]:
    """Split a classification dataset into (opt. heterogeneous) shards.

    The data-splitting schemes are the following:

    - If "iid", split the dataset through iid random sampling.
    - If "labels", split into shards that hold all samples associated
      with mutually-exclusive target classes.
    - If "biased", split the dataset through random sampling according
      to a shard-specific random labels distribution.

    Parameters
    ----------
    dataset: tuple(np.ndarray, np.ndarray)
        Raw dataset, as a pair of numpy arrays that respectively contain
        the input features and (aligned) labels.
    n_shards: int
        Number of shards between which to split the dataset.
    scheme: {"iid", "labels", "biased"}
        Splitting scheme to use. In all cases, shards contain mutually-
        exclusive samples and cover the full dataset. See details above.
    p_valid: float, default=0.2
        Share of each shard to turn into a validation subset.
    seed: int or None, default=None
        Optional seed to the RNG used for all sampling operations.

    Returns
    -------
    shards: list[((np.ndarray, np.ndarray), (np.ndarray, np.ndarray))]
        List of dataset shards, where each element is formatted as a
        tuple of tuples: `((x_train, y_train), (x_valid, y_valid))`.

    Raises
    ------
    ValueError
        If `scheme` has an invalid value.
    """
    # Map scheme names to their backend splitting functions.
    schemes = {
        "iid": split_iid,
        "labels": split_labels,
        "biased": split_biased,
    }
    if scheme not in schemes:
        raise ValueError(f"Invalid 'scheme' value: '{scheme}'.")
    # Set up the RNG and split the dataset into shards.
    rng = np.random.default_rng(seed)
    inputs, target = dataset
    shards = schemes[scheme](inputs, target, n_shards, rng)
    # Further split shards into training and validation subsets, and return.
    return [train_valid_split(inp, tgt, p_valid, rng) for inp, tgt in shards]
def split_iid(
    inputs: np.ndarray,
    target: np.ndarray,
    n_shards: int,
    rng: np.random.Generator,
) -> List[Tuple[np.ndarray, np.ndarray]]:
    """Split a dataset into shards using iid sampling."""
    # Shuffle sample indices once, then carve them into contiguous chunks.
    shuffled = rng.permutation(len(inputs))
    size = len(inputs) // n_shards
    # Shards 0..n-2 hold `size` samples each; the last absorbs the remainder.
    bounds = [idx * size for idx in range(n_shards)] + [len(shuffled)]
    return [
        (inputs[shuffled[srt:end]], target[shuffled[srt:end]])
        for srt, end in zip(bounds[:-1], bounds[1:])
    ]
def split_labels(
    inputs: np.ndarray,
    target: np.ndarray,
    n_shards: int,
    rng: np.random.Generator,
) -> List[Tuple[np.ndarray, np.ndarray]]:
    """Split a dataset into shards with mutually-exclusive label classes.

    Raises
    ------
    ValueError
        If there are fewer classes than requested shards.
    """
    classes = np.unique(target)
    if n_shards > len(classes):
        # Fix: the two concatenated f-string parts lacked a separating
        # space, yielding messages such as "between 3shards".
        raise ValueError(
            f"Cannot share {len(classes)} classes between {n_shards} "
            "shards with mutually-exclusive labels."
        )
    # Assign a contiguous slice of (shuffled) classes to each shard;
    # the last shard absorbs any remainder classes.
    s_len = len(classes) // n_shards
    order = rng.permutation(classes)
    split = []  # type: List[Tuple[np.ndarray, np.ndarray]]
    for idx in range(n_shards):
        srt = idx * s_len
        end = (srt + s_len) if idx < (n_shards - 1) else len(order)
        shard = np.isin(target, order[srt:end])
        split.append((inputs[shard], target[shard]))
    return split
def split_biased(
    inputs: np.ndarray,
    target: np.ndarray,
    n_shards: int,
    rng: np.random.Generator,
) -> List[Tuple[np.ndarray, np.ndarray]]:
    """Split a dataset into shards with heterogeneous label distributions."""
    classes = np.unique(target)
    index = np.arange(len(target))
    s_len = len(target) // n_shards
    split = []  # type: List[Tuple[np.ndarray, np.ndarray]]
    for idx in range(n_shards):
        if idx < (n_shards - 1):
            # Draw a random distribution of labels for this node.
            logits = np.exp(rng.normal(size=len(classes)))
            # Map each sample's label to its class index. Fix: the former
            # `logits[target[index]]` assumed labels were exactly 0..K-1
            # integers; `np.searchsorted` over the sorted `classes` array
            # is identical in that case and also supports non-contiguous,
            # non-zero-based or non-integer labels.
            c_ind = np.searchsorted(classes, target[index])
            lprobs = logits[c_ind]
            lprobs = lprobs / lprobs.sum()
            # Draw samples based on this distribution, without replacement.
            shard = rng.choice(index, size=s_len, replace=False, p=lprobs)
            index = index[~np.isin(index, shard)]
        else:
            # For the last node: use the remaining samples.
            shard = index
        split.append((inputs[shard], target[shard]))
    return split
def train_valid_split(
    inputs: np.ndarray,
    target: np.ndarray,
    p_valid: float,
    rng: np.random.Generator,
) -> Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]:
    """Split a dataset between train and validation using iid sampling."""
    # Shuffle sample indices, reserve the first ceil(n * p_valid) of them
    # for validation and keep the remaining ones for training.
    shuffled = rng.permutation(len(inputs))
    n_valid = int(np.ceil(len(inputs) * p_valid))
    v_ind, t_ind = shuffled[:n_valid], shuffled[n_valid:]
    return (inputs[t_ind], target[t_ind]), (inputs[v_ind], target[v_ind])
......@@ -33,12 +33,12 @@ instance sparse data
"""
import os
from typing import List, Optional, Tuple, Union
from typing import Optional, Tuple, Union
import numpy as np
from declearn.dataset.examples import load_mnist
from declearn.dataset.utils import load_data_array
from declearn.dataset.utils import load_data_array, split_multi_classif_dataset
from declearn.quickrun._config import DataSplitConfig
......@@ -91,76 +91,6 @@ def load_data(
return inputs, labels
def _split_iid(
inputs: np.ndarray,
target: np.ndarray,
n_shards: int,
rng: np.random.Generator,
) -> List[Tuple[np.ndarray, np.ndarray]]:
"""Split a dataset into shards using iid sampling."""
order = rng.permutation(len(inputs))
s_len = len(inputs) // n_shards
split = [] # type: List[Tuple[np.ndarray, np.ndarray]]
for idx in range(n_shards):
srt = idx * s_len
end = (srt + s_len) if idx < (n_shards - 1) else len(order)
shard = order[srt:end]
split.append((inputs[shard], target[shard]))
return split
def _split_labels(
inputs: np.ndarray,
target: np.ndarray,
n_shards: int,
rng: np.random.Generator,
) -> List[Tuple[np.ndarray, np.ndarray]]:
"""Split a dataset into shards with mutually-exclusive label classes."""
classes = np.unique(target)
if n_shards > len(classes):
raise ValueError(
f"Cannot share {len(classes)} classes between {n_shards}"
"shards with mutually-exclusive labels."
)
s_len = len(classes) // n_shards
order = rng.permutation(classes)
split = [] # type: List[Tuple[np.ndarray, np.ndarray]]
for idx in range(n_shards):
srt = idx * s_len
end = (srt + s_len) if idx < (n_shards - 1) else len(order)
shard = np.isin(target, order[srt:end])
shuffle = rng.permutation(shard.sum())
split.append((inputs[shard][shuffle], target[shard][shuffle]))
return split
def _split_biased(
inputs: np.ndarray,
target: np.ndarray,
n_shards: int,
rng: np.random.Generator,
) -> List[Tuple[np.ndarray, np.ndarray]]:
"""Split a dataset into shards with heterogeneous label distributions."""
classes = np.unique(target)
index = np.arange(len(target))
s_len = len(target) // n_shards
split = [] # type: List[Tuple[np.ndarray, np.ndarray]]
for idx in range(n_shards):
if idx < (n_shards - 1):
# Draw a random distribution of labels for this node.
logits = np.exp(rng.normal(size=len(classes)))
lprobs = logits[target[index]]
lprobs = lprobs / lprobs.sum()
# Draw samples based on this distribution, without replacement.
shard = rng.choice(index, size=s_len, replace=False, p=lprobs)
index = index[~np.isin(index, shard)]
else:
# For the last node: use the remaining samples.
shard = index
split.append((inputs[shard], target[shard]))
return split
def split_data(data_config: DataSplitConfig, folder: str) -> None:
"""Download and randomly split a dataset into shards.
......@@ -180,54 +110,33 @@ def split_data(data_config: DataSplitConfig, folder: str) -> None:
data_config: DataSplitConfig
A DataSplitConfig instance, see class documentation for details
"""
def np_save(folder, data, i, name):
data_dir = os.path.join(folder, f"client_{i}")
os.makedirs(data_dir, exist_ok=True)
np.save(os.path.join(data_dir, f"{name}.npy"), data)
# Overwrite default folder if provided
scheme = data_config.scheme
name = f"data_{scheme}"
data_file = data_config.data_file
label_file = data_config.label_file
# Select output folder.
if data_config.data_folder:
folder = os.path.dirname(data_config.data_folder)
name = os.path.split(data_config.data_folder)[-1]
data_file = os.path.abspath(data_config.data_file)
label_file = os.path.abspath(data_config.label_file)
# Select the splitting function to be used.
if scheme == "iid":
func = _split_iid
elif scheme == "labels":
func = _split_labels
elif scheme == "biased":
func = _split_biased
else:
raise ValueError(f"Invalid 'scheme' value: '{scheme}'.")
# Set up the RNG, download the raw dataset and split it.
rng = np.random.default_rng(data_config.seed)
inputs, labels = load_data(data_file, label_file)
folder = f"data_{data_config.scheme}"
# Value-check the 'perc_train' parameter.
if not 0.0 < data_config.perc_train <= 1.0:
raise ValueError("'perc_train' should be a float in ]0,1]")
# Load the dataset and split it.
inputs, labels = load_data(data_config.data_file, data_config.label_file)
print(
f"Splitting data into {data_config.n_shards}"
f"shards using the {scheme} scheme"
f"Splitting data into {data_config.n_shards} shards "
f"using the '{data_config.scheme}' scheme."
)
split = split_multi_classif_dataset(
dataset=(inputs, labels),
n_shards=data_config.n_shards,
scheme=data_config.scheme, # type: ignore
p_valid=(1 - data_config.perc_train),
seed=data_config.seed,
)
split = func(inputs, labels, data_config.n_shards, rng)
# Export the resulting shard-wise data to files.
folder = os.path.join(folder, name)
for i, (inp, tgt) in enumerate(split):
perc_train = data_config.perc_train
if not perc_train:
np_save(folder, inp, i, "train_data")
np_save(folder, tgt, i, "train_target")
else:
if perc_train > 1.0 or perc_train < 0.0:
raise ValueError("perc_train should be a float in ]0,1]")
n_train = round(len(inp) * perc_train)
t_inp, t_tgt = inp[:n_train], tgt[:n_train]
v_inp, v_tgt = inp[n_train:], tgt[n_train:]
np_save(folder, t_inp, i, "train_data")
np_save(folder, t_tgt, i, "train_target")
np_save(folder, v_inp, i, "valid_data")
np_save(folder, v_tgt, i, "valid_target")
for idx, ((x_train, y_train), (x_valid, y_valid)) in enumerate(split):
subdir = os.path.join(folder, f"client_{idx}")
os.makedirs(subdir, exist_ok=True)
np.save(os.path.join(subdir, "train_data.npy"), x_train)
np.save(os.path.join(subdir, "train_target.npy"), y_train)
if len(x_valid):
np.save(os.path.join(subdir, "valid_data.npy"), x_valid)
np.save(os.path.join(subdir, "valid_target.npy"), y_valid)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment