diff --git a/declearn/quickrun/_config.py b/declearn/quickrun/_config.py index fc30685f4d724b29a2657378bdb3e2ac4888fb8f..7614ed316dc6ded4e4de9dedce8121c4e1071a1a 100644 --- a/declearn/quickrun/_config.py +++ b/declearn/quickrun/_config.py @@ -82,7 +82,7 @@ class DataSplitConfig(TomlConfig): # Common args data_folder: Optional[str] = None # split_data args - n_shards: int = 5 + n_shards: int = 3 data_file: Optional[str] = None label_file: Optional[Union[str, int]] = None scheme: str = "iid" diff --git a/declearn/quickrun/_parser.py b/declearn/quickrun/_parser.py index eae9c9fccb39ee9e601595dff73c7529ad6021eb..a753d7dcfd26f5379dcb883af8947f13355d5ea3 100644 --- a/declearn/quickrun/_parser.py +++ b/declearn/quickrun/_parser.py @@ -63,6 +63,7 @@ def parse_data_folder( data_folder = expe_config.data_folder client_names = expe_config.client_names dataset_names = expe_config.dataset_names + scheme = expe_config.scheme if not folder and not data_folder: raise ValueError( @@ -70,7 +71,8 @@ def parse_data_folder( ) # Data_folder if not data_folder: - gen_folders = Path(folder).glob("data*") # type: ignore + search_str = f"_{scheme}" if scheme else "*" + gen_folders = Path(folder).glob(f"data{search_str}") # type: ignore data_folder = next(gen_folders, False) # type: ignore if not data_folder: raise ValueError( @@ -83,6 +85,8 @@ def parse_data_folder( f"in {folder}. Please store your data under a single" "parent folder" ) + else: + data_folder = Path(data_folder) # Get clients dir if client_names: if isinstance(client_names, list): diff --git a/declearn/quickrun/_split_data.py b/declearn/quickrun/_split_data.py index 0fbf80692b8b2b155e048a120ee1b10926cdcf2d..f46e0dfa7eb2849c05c8e400e8df0150a1bb5aa1 100644 --- a/declearn/quickrun/_split_data.py +++ b/declearn/quickrun/_split_data.py @@ -111,6 +111,7 @@ def load_data( inputs = load_data_array(data) inputs = np.asarray(inputs) else: + print("\n\n", data, "\n\n") raise ValueError("The data path provided is not a valid file") if isinstance(target, int): @@ -225,10 +226,16 @@ def split_data(data_config: DataSplitConfig, folder: str) -> None: np.save(os.path.join(data_dir, f"{name}.npy"), data) # Overwrite default folder if provided + scheme = data_config.scheme + name = f"data_{scheme}" + data_file = data_config.data_file + label_file = data_config.label_file if data_config.data_folder: - folder = data_config.data_folder + folder = os.path.dirname(data_config.data_folder) + name = os.path.split(data_config.data_folder)[-1] + data_file = os.path.abspath(data_config.data_file) + label_file = os.path.abspath(data_config.label_file) # Select the splitting function to be used. - scheme = data_config.scheme if scheme == "iid": func = _split_iid elif scheme == "labels": @@ -239,14 +246,15 @@ def split_data(data_config: DataSplitConfig, folder: str) -> None: raise ValueError(f"Invalid 'scheme' value: '{scheme}'.") # Set up the RNG, download the raw dataset and split it. rng = np.random.default_rng(data_config.seed) - inputs, labels = load_data(data_config.data_file, data_config.label_file) + + inputs, labels = load_data(data_file, label_file) print( f"Splitting data into {data_config.n_shards}" f"shards using the {scheme} scheme" ) split = func(inputs, labels, data_config.n_shards, rng) # Export the resulting shard-wise data to files. - folder = os.path.join(folder, f"data_{scheme}") + folder = os.path.join(folder, name) for i, (inp, tgt) in enumerate(split): perc_train = data_config.perc_train if not perc_train: diff --git a/declearn/quickrun/run.py b/declearn/quickrun/run.py index 1a8f8a08300c1b0234cd7bc481d901ae3cffe8a7..21cd0028af162cce03344db5e8b4e8e50ee4a526 100644 --- a/declearn/quickrun/run.py +++ b/declearn/quickrun/run.py @@ -59,7 +59,8 @@ with make_importable(os.path.dirname(__file__)): # pylint: enable=wrong-import-order, wrong-import-position -def _get_model(folder, model_config) -> Model: +def get_model(folder, model_config) -> Model: + "Return a model instance from a model config instance" file = "model" if m_file := model_config.model_file: folder = os.path.dirname(m_file) @@ -70,7 +71,7 @@ def _get_model(folder, model_config) -> Model: return model_cls -def _run_server( +def run_server( folder: str, network: NetworkServerConfig, model_config: ModelConfig, @@ -79,7 +80,7 @@ def _run_server( expe_config: ExperimentConfig, ) -> None: """Routine to run a FL server, called by `run_declearn_experiment`.""" - model = _get_model(folder, model_config) + model = get_model(folder, model_config) if expe_config.checkpoint: checkpoint = expe_config.checkpoint else: @@ -91,7 +92,7 @@ def _run_server( server.run(config) -def _run_client( +def run_client( folder: str, network: NetworkClientConfig, model_config: ModelConfig, @@ -103,7 +104,7 @@ def _run_client( # Overwrite client name based on folder name network.name = name # Make the model importable - _ = _get_model(folder, model_config) + _ = get_model(folder, model_config) # Add checkpointer if expe_config.checkpoint: checkpoint = expe_config.checkpoint @@ -186,14 +187,14 @@ def quickrun(config: Optional[str] = None) -> None: expe_cfg = ExperimentConfig.from_toml(toml, False, "experiment") # Set up a (func, args) tuple specifying the server process. p_server = ( - _run_server, + run_server, (folder, ntk_server_cfg, model_cfg, optim_cgf, run_cfg, expe_cfg), ) # Set up the (func, args) tuples specifying client-wise processes. p_client = [] for name, data_dict in client_dict.items(): client = ( - _run_client, + run_client, (folder, ntk_client_cfg, model_cfg, expe_cfg, name, data_dict), ) p_client.append(client) @@ -241,3 +242,4 @@ def main(args: Optional[List[str]] = None) -> None: if __name__ == "__main__": main() + # quickrun(config="examples/quickrun/config_custom.toml") # TODO diff --git a/declearn/utils/_multiprocess.py b/declearn/utils/_multiprocess.py index 72b343acb2f24c530fc8eb1ab4c7e18709c961fa..a39ac540d7d1001e466d3560ce3026edee38fca9 100644 --- a/declearn/utils/_multiprocess.py +++ b/declearn/utils/_multiprocess.py @@ -17,6 +17,7 @@ """Utils to run concurrent routines parallelly using multiprocessing.""" +import functools import multiprocessing as mp import sys import traceback @@ -110,21 +111,25 @@ def add_exception_catching( if not name: name = func.__name__ - def wrapped(*args: Any, **kwargs: Any) -> Any: - """Call the wrapped function and catch exceptions or results.""" - nonlocal name, queue + return functools.partial(wrapped, func=func, queue=queue, name=name) - try: - result = func(*args, **kwargs) - except Exception as exc: # pylint: disable=broad-exception-caught - err = RuntimeError( - f"Exception of type {type(exc)} occurred:\n" - "".join(traceback.format_exception(type(exc), exc, tb=None)) - ) # future: `traceback.format_exception(exc)` (py >=3.10) - queue.put((name, err)) - sys.exit(1) - else: - queue.put((name, result)) - sys.exit(0) - return wrapped +def wrapped( + *args: Any, + func: Callable[..., Any], + queue: Queue, + name: str, +) -> Any: + """Call the wrapped function and catch exceptions or results.""" + try: + result = func(*args) + except Exception as exc: # pylint: disable=broad-exception-caught + err = RuntimeError( + f"Exception of type {type(exc)} occurred:\n" + "".join(traceback.format_exception(type(exc), exc, tb=None)) + ) # future: `traceback.format_exception(exc)` (py >=3.10) + queue.put((name, err)) + sys.exit(1) + else: + queue.put((name, result)) + sys.exit(0)