From 1c26f2c201bf62a60902cfa9dc74b0c290982911 Mon Sep 17 00:00:00 2001 From: BIGAUD Nathan <nathan.bigaud@inria.fr> Date: Thu, 30 Mar 2023 09:26:31 +0200 Subject: [PATCH] Adding checkpointer and metrics, re-shuffling data in labels split --- declearn/quickrun/_config.py | 55 ++++++++++++++----------- declearn/quickrun/_split_data.py | 10 +++-- declearn/quickrun/run.py | 69 ++++++++++++++++++++++---------- examples/quickrun/config.toml | 19 ++++----- 4 files changed, 93 insertions(+), 60 deletions(-) diff --git a/declearn/quickrun/_config.py b/declearn/quickrun/_config.py index a80e711b..c2559da5 100644 --- a/declearn/quickrun/_config.py +++ b/declearn/quickrun/_config.py @@ -18,7 +18,7 @@ """TOML-parsable container for quickrun configurations.""" import dataclasses -from typing import Dict, List, Literal, Optional, Union +from typing import Dict, List, Optional, Union from declearn.utils import TomlConfig @@ -40,11 +40,13 @@ class ModelConfig(TomlConfig): @dataclasses.dataclass class DataSplitConfig(TomlConfig): - """Dataclass associated with the function - declearn.quickrun._split_data:split_data + """Dataclass associated with the functions + declearn.quickrun._split_data:split_data and + declearn.quickrun._parser:parse_data_folder - export_folder: str - Path to the folder where to export shard-wise files. + data_folder: str + Absolute path to the folder where to export shard-wise files, + and/or to the main folder hosting the data. n_shards: int Number of shards between which to split the data. data_file: str or None, default=None @@ -67,37 +69,44 @@ class DataSplitConfig(TomlConfig): ]0,1] range. seed: int or None, default=None Optional seed to the RNG used for all sampling operations. + client_names: list or None + List of custom client names to look for in the data_folder. + If None, default to expected prefix search. + dataset_names: dict or None + Dict of custom dataset names, to look for in each client folder. 
+ Expect 'train_data, train_target, valid_data, valid_target' as keys. + If None, default to expected prefix search. """ - export_folder: str = "." + # Common args + data_folder: Optional[str] = None + # split_data args n_shards: int = 5 data_file: Optional[str] = None label_file: Optional[Union[str, int]] = None - scheme: Literal["iid", "labels", "biased"] = "iid" + scheme: str = "iid" perc_train: float = 0.8 seed: Optional[int] = None + # parse_data_folder args + client_names: Optional[List[str]] = None + dataset_names: Optional[Dict[str, str]] = None @dataclasses.dataclass class ExperimentConfig(TomlConfig): """ - Dataclass associated with the function - declearn.quickrun._parser:parse_data_folder + Dataclass providing kwargs to + declearn.main._server.FederatedServer + and declearn.main._client.FederatedClient - data_folder : str or none - Absolute path to the main folder hosting the data, overwriting - the folder argument if provided. If None, default to expected - prefix search in folder. - client_names: list or None - List of custom client names to look for in the data_folder. - If None, default to expected prefix search. - dataset_names: dict or None - Dict of custom dataset names, to look for in each client folder. - Expect 'train_data, train_target, valid_data, valid_target' as keys. - If None, , default to expected prefix search. + metrics: list[str] or None + List of Metric childclass names, defining evaluation metrics + to compute in addition to the model's loss. 
+ checkpoint: str or None + Optional path to the checkpoint folder, used to save round-wise model + weights; other checkpointer parameters take their default values. """ - data_folder: Optional[str] = None - client_names: Optional[List[str]] = None - dataset_names: Optional[Dict[str, str]] = None + metrics: Optional[List[str]] = None + checkpoint: Optional[str] = None diff --git a/declearn/quickrun/_split_data.py b/declearn/quickrun/_split_data.py index e323258e..0fbf8069 100644 --- a/declearn/quickrun/_split_data.py +++ b/declearn/quickrun/_split_data.py @@ -167,7 +167,8 @@ def _split_labels( srt = idx * s_len end = (srt + s_len) if idx < (n_shards - 1) else len(order) shard = np.isin(target, order[srt:end]) - split.append((inputs[shard], target[shard])) + shuffle = rng.permutation(shard.sum()) + split.append((inputs[shard][shuffle], target[shard][shuffle])) return split @@ -198,7 +199,7 @@ def _split_biased( return split -def split_data(data_config: DataSplitConfig) -> None: +def split_data(data_config: DataSplitConfig, folder: str) -> None: """Download and randomly split a dataset into shards. The resulting folder structure is : @@ -223,6 +224,9 @@ def split_data(data_config: DataSplitConfig) -> None: os.makedirs(data_dir, exist_ok=True) np.save(os.path.join(data_dir, f"{name}.npy"), data) + # Overwrite default folder if provided + if data_config.data_folder: + folder = data_config.data_folder # Select the splitting function to be used. scheme = data_config.scheme if scheme == "iid": @@ -242,7 +246,7 @@ def split_data(data_config: DataSplitConfig) -> None: ) split = func(inputs, labels, data_config.n_shards, rng) # Export the resulting shard-wise data to files. 
- folder = os.path.join(data_config.export_folder, f"data_{scheme}") + folder = os.path.join(folder, f"data_{scheme}") for i, (inp, tgt) in enumerate(split): perc_train = data_config.perc_train if not perc_train: diff --git a/declearn/quickrun/run.py b/declearn/quickrun/run.py index e52a859d..1a8f8a08 100644 --- a/declearn/quickrun/run.py +++ b/declearn/quickrun/run.py @@ -36,7 +36,6 @@ import importlib import os import re import textwrap -from pathlib import Path from typing import Any, Dict, List, Optional, Tuple from declearn.communication import NetworkClientConfig, NetworkServerConfig @@ -77,10 +76,18 @@ def _run_server( model_config: ModelConfig, optim: FLOptimConfig, config: FLRunConfig, + expe_config: ExperimentConfig, ) -> None: """Routine to run a FL server, called by `run_declearn_experiment`.""" model = _get_model(folder, model_config) - server = FederatedServer(model, network, optim) + if expe_config.checkpoint: + checkpoint = expe_config.checkpoint + else: + checkpoint = os.path.join(folder, "result") + checkpoint = os.path.join(checkpoint, "server") + server = FederatedServer( + model, network, optim, expe_config.metrics, checkpoint + ) server.run(config) @@ -88,6 +95,7 @@ def _run_client( folder: str, network: NetworkClientConfig, model_config: ModelConfig, + expe_config: ExperimentConfig, name: str, paths: dict, ) -> None: @@ -96,6 +104,12 @@ def _run_client( network.name = name # Make the model importable _ = _get_model(folder, model_config) + # Add checkpointer + if expe_config.checkpoint: + checkpoint = expe_config.checkpoint + else: + checkpoint = os.path.join(folder, "result") + checkpoint = os.path.join(checkpoint, name) # Wrap train and validation data as Dataset objects. 
train = InMemoryDataset( paths.get("train_data"), @@ -106,7 +120,7 @@ def _run_client( paths.get("valid_data"), target=paths.get("valid_target"), ) - client = FederatedClient(network, train, valid) + client = FederatedClient(network, train, valid, checkpoint) client.run() @@ -132,22 +146,29 @@ def get_toml_folder(config: Optional[str] = None) -> Tuple[str, str]: def locate_or_create_split_data(toml: str, folder: str) -> Dict: """Attempts to find split data according to the config toml or - or the defualt behavior. If failed, attempts to find full data + or the default behavior. If failed, attempts to find full data according to the config toml and split it""" - expe_config = ExperimentConfig.from_toml(toml, False, "experiment") + data_config = DataSplitConfig.from_toml(toml, False, "data") try: - client_dict = parse_data_folder(expe_config, folder) + client_dict = parse_data_folder(data_config, folder) except ValueError: - data_config = DataSplitConfig.from_toml(toml, False, "data") - split_data(folder, data_config) - client_dict = parse_data_folder(expe_config,folder) + split_data(data_config, folder) + client_dict = parse_data_folder(data_config, folder) return client_dict -def quickrun( - config: Optional[str] = None, - **kwargs: Any, -) -> None: +def server_to_client_network( + network_cfg: NetworkServerConfig, +) -> NetworkClientConfig: + "Converts server network config to client network config" + return NetworkClientConfig.from_params( + protocol=network_cfg.protocol, + server_uri=f"ws://localhost:{network_cfg.port}", + name="replaceme", + ) + + +def quickrun(config: Optional[str] = None) -> None: """Run a server and its clients using multiprocessing. 
The kwargs are the arguments expected by split_data, @@ -156,20 +177,24 @@ def quickrun( toml, folder = get_toml_folder(config) # locate split data or split it if needed client_dict = locate_or_create_split_data(toml, folder) - # Parse toml file to ServerConfig and ClientConfig - ntk_server = NetworkServerConfig.from_toml(toml, False, "network_server") - optim = FLOptimConfig.from_toml(toml, False, "optim") - run = FLRunConfig.from_toml(toml, False, "run") - ntk_client = NetworkClientConfig.from_toml(toml, False, "network_client") - model_config = ModelConfig.from_toml(toml, False, "model") + # Parse toml files + ntk_server_cfg = NetworkServerConfig.from_toml(toml, False, "network") + ntk_client_cfg = server_to_client_network(ntk_server_cfg) + optim_cgf = FLOptimConfig.from_toml(toml, False, "optim") + run_cfg = FLRunConfig.from_toml(toml, False, "run") + model_cfg = ModelConfig.from_toml(toml, False, "model") + expe_cfg = ExperimentConfig.from_toml(toml, False, "experiment") # Set up a (func, args) tuple specifying the server process. - p_server = (_run_server, (folder, ntk_server, model_config, optim, run)) + p_server = ( + _run_server, + (folder, ntk_server_cfg, model_cfg, optim_cgf, run_cfg, expe_cfg), + ) # Set up the (func, args) tuples specifying client-wise processes. p_client = [] for name, data_dict in client_dict.items(): client = ( _run_client, - (folder, ntk_client, model_config, name, data_dict), + (folder, ntk_client_cfg, model_cfg, expe_cfg, name, data_dict), ) p_client.append(client) # Run each and every process in parallel. 
@@ -211,7 +236,7 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace: def main(args: Optional[List[str]] = None) -> None: """Quikcrun based on commandline-input arguments.""" cmdargs = parse_args(args) - quickrun(folder=cmdargs.config) + quickrun(config=cmdargs.config) if __name__ == "__main__": diff --git a/examples/quickrun/config.toml b/examples/quickrun/config.toml index 32f00403..a25f77f7 100644 --- a/examples/quickrun/config.toml +++ b/examples/quickrun/config.toml @@ -1,24 +1,18 @@ -[network_server] +[network] protocol = "websockets" host = "127.0.0.1" port = 8765 -[network_client] -protocol = "websockets" -server_uri = "ws://localhost:8765" -name = "replaceme" - [optim] aggregator = "averaging" # The chosen aggregation strategy server_opt = 1.0 # The server learning rate [optim.client_opt] - lrate = 0.01 # The client learning rate - regularizers = [["lasso", {alpha = 0.1}]] # The list of regularizer modules, each a list - modules = [["momentum", {"beta" = 0.9}]] + lrate = 0.001 # The client learning rate + modules = ["adam"] # The optimizer modules [run] -rounds = 2 # Number of training rounds +rounds = 10 # Number of training rounds [run.register] min_clients = 1 @@ -36,10 +30,11 @@ rounds = 2 # Number of training rounds [experiment] # all args for parse_data_folder -# target_folder = !!! + [model] # information on where to find the model file [data] -# all args from split_data argparser \ No newline at end of file +# all args from split_data argparser +scheme = "labels" \ No newline at end of file -- GitLab