diff --git a/declearn/quickrun/_split_data.py b/declearn/quickrun/_split_data.py index f748bf1c9910af0cf0019ecb5e9cace030c13766..a45e5c64a0d555d5c5f7b41fa9da087b62667ed4 100644 --- a/declearn/quickrun/_split_data.py +++ b/declearn/quickrun/_split_data.py @@ -173,51 +173,14 @@ def _split_biased( return split -def export_shard_to_csv( - path: str, - inputs: np.ndarray, - target: np.ndarray, -) -> None: - """Export an MNIST shard to a csv file.""" - specs = {"dtype": inputs.dtype.char, "shape": list(inputs[0].shape)} - with open(path, "w", encoding="utf-8") as file: - file.write(f"{json.dumps(specs)},target") - for inp, tgt in zip(inputs, target): - file.write(f"\n{inp.tobytes().hex()},{int(tgt)}") - - -def load_mnist_from_csv( - path: str, -) -> Tuple[np.ndarray, np.ndarray]: - """Reload an MNIST shard from a csv file.""" - # Prepare data containers. - inputs = [] # type: List[np.ndarray] - target = [] # type: List[int] - # Parse the csv file. - with open(path, "r", encoding="utf-8") as file: - # Parse input features' specs from the csv header. - specs = json.loads(file.readline().rsplit(",", 1)[0]) - dtype = specs["dtype"] - shape = specs["shape"] - # Iteratively deserialize features and labels from rows. - for row in file: - inp, tgt = row.strip("\n").rsplit(",", 1) - dat = bytes.fromhex(inp) - inputs.append(np.frombuffer(dat, dtype=dtype).reshape(shape)) - target.append(int(tgt)) - # Assemble the data into numpy arrays and return. - return np.array(inputs), np.array(target) - - def split_data( - folder: str = DEFAULT_FOLDER, # CHECK if good practice + folder: str = DEFAULT_FOLDER, n_shards: int = 5, data: Optional[str] = None, target: Optional[Union[str, int]] = None, scheme: Literal["iid", "labels", "biased"] = "iid", perc_train: float = 0.8, seed: Optional[int] = None, - use_csv: bool = False, ) -> None: """Download and randomly split the MNIST dataset into shards. #TODO @@ -267,10 +230,6 @@ def split_data( np.save(os.path.join(folder, f"client_{i}/{name}.npy"), data) for i, (inp, tgt) in enumerate(split): - if use_csv: # TODO - path = os.path.join(folder, f"shard_{i}.csv") - export_shard_to_csv(path, inp, tgt) - return if not perc_train: np_save(inp, i, "train_data") np_save(tgt, i, "train_target") @@ -356,13 +315,6 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace: type=int, help="RNG seed to use (default: 20221109).", ) - parser.add_argument( - "--csv", - default=False, - dest="use_csv", - type=bool, - help="Export data as csv files (for use with Fed-BioMed).", - ) return parser.parse_args(args) @@ -377,7 +329,6 @@ def main(args: Optional[List[str]] = None) -> None: target=cmdargs.target, scheme=scheme, seed=cmdargs.seed, - use_csv=cmdargs.use_csv, )