Verified Commit 1c26f2c2, authored by BIGAUD Nathan, committed by ANDREY Paul

Adding checkpointer and metrics, re-shuffling data in labels split

parent 98027bc4
1 merge request: !41 Quickrun mode
@@ -18,7 +18,7 @@
 """TOML-parsable container for quickrun configurations."""

 import dataclasses
-from typing import Dict, List, Literal, Optional, Union
+from typing import Dict, List, Optional, Union

 from declearn.utils import TomlConfig
@@ -40,11 +40,13 @@ class ModelConfig(TomlConfig):

 @dataclasses.dataclass
 class DataSplitConfig(TomlConfig):
-    """Dataclass associated with the function
-    declearn.quickrun._split_data:split_data
+    """Dataclass associated with the functions
+    declearn.quickrun._split_data:split_data and
+    declearn.quickrun._parser:parse_data_folder

-    export_folder: str
-        Path to the folder where to export shard-wise files.
+    data_folder: str
+        Absolute path to the folder where to export shard-wise files,
+        and/or to the main folder hosting the data.
     n_shards: int
         Number of shards between which to split the data.
     data_file: str or None, default=None
@@ -67,37 +69,44 @@ class DataSplitConfig(TomlConfig):
         ]0,1] range.
     seed: int or None, default=None
         Optional seed to the RNG used for all sampling operations.
+    client_names: list or None
+        List of custom client names to look for in the data_folder.
+        If None, default to expected prefix search.
+    dataset_names: dict or None
+        Dict of custom dataset names, to look for in each client folder.
+        Expect 'train_data, train_target, valid_data, valid_target' as keys.
+        If None, default to expected prefix search.
     """

-    export_folder: str = "."
+    # Common args
+    data_folder: Optional[str] = None
+    # split_data args
     n_shards: int = 5
     data_file: Optional[str] = None
     label_file: Optional[Union[str, int]] = None
-    scheme: Literal["iid", "labels", "biased"] = "iid"
+    scheme: str = "iid"
     perc_train: float = 0.8
     seed: Optional[int] = None
+    # parse_data_folder args
+    client_names: Optional[List[str]] = None
+    dataset_names: Optional[Dict[str, str]] = None
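For context, a minimal sketch of how this dataclass is meant to be filled from the "[data]" table of a quickrun TOML file, relying on the TomlConfig.from_toml classmethod used later in this diff (the module path and file name are illustrative assumptions):

    from declearn.quickrun._config import DataSplitConfig  # assumed module path

    # Parse the "[data]" section of the TOML file into the dataclass;
    # unspecified fields keep their declared defaults (e.g. n_shards=5).
    data_config = DataSplitConfig.from_toml("config.toml", False, "data")
    print(data_config.scheme)  # "iid" unless overridden, e.g. to "labels"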
 @dataclasses.dataclass
 class ExperimentConfig(TomlConfig):
     """
-    Dataclass associated with the function
-    declearn.quickrun._parser:parse_data_folder
+    Dataclass providing kwargs to
+    declearn.main._server.FederatedServer
+    and declearn.main._client.FederatedClient

-    data_folder : str or none
-        Absolute path to the main folder hosting the data, overwriting
-        the folder argument if provided. If None, default to expected
-        prefix search in folder.
-    client_names: list or None
-        List of custom client names to look for in the data_folder.
-        If None, default to expected prefix search.
-    dataset_names: dict or None
-        Dict of custom dataset names, to look for in each client folder.
-        Expect 'train_data, train_target, valid_data, valid_target' as keys.
-        If None, default to expected prefix search.
+    metrics: list[str] or None
+        List of Metric child-class names, defining evaluation metrics
+        to compute in addition to the model's loss.
+    checkpoint: str or None
+        Path to the folder where to save round-wise model checkpoints,
+        relying on default values for all other checkpointing parameters.
     """

-    data_folder: Optional[str] = None
-    client_names: Optional[List[str]] = None
-    dataset_names: Optional[Dict[str, str]] = None
+    metrics: Optional[List[str]] = None
+    checkpoint: Optional[str] = None
@@ -167,7 +167,8 @@ def _split_labels(
         srt = idx * s_len
         end = (srt + s_len) if idx < (n_shards - 1) else len(order)
         shard = np.isin(target, order[srt:end])
-        split.append((inputs[shard], target[shard]))
+        shuffle = rng.permutation(shard.sum())
+        split.append((inputs[shard][shuffle], target[shard][shuffle]))
     return split
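To illustrate the added re-shuffling, here is a self-contained toy sketch (data and shapes are illustrative): without the permutation, samples in a label-based shard stay in their source order; with it, inputs and targets are re-ordered jointly, preserving their alignment.

    import numpy as np

    rng = np.random.default_rng(seed=0)
    inputs = np.arange(6)
    target = np.array([0, 1, 0, 2, 1, 0])
    # Boolean mask of the samples whose label belongs to this shard.
    shard = np.isin(target, [0, 1])
    # Random permutation of the selected samples' indices...
    shuffle = rng.permutation(shard.sum())
    # ...applied identically to inputs and targets.
    inp, tgt = inputs[shard][shuffle], target[shard][shuffle]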
@@ -198,7 +199,7 @@ def _split_biased(
     return split


-def split_data(data_config: DataSplitConfig) -> None:
+def split_data(data_config: DataSplitConfig, folder: str) -> None:
     """Download and randomly split a dataset into shards.

     The resulting folder structure is :
@@ -223,6 +224,9 @@ def split_data(data_config: DataSplitConfig) -> None:
         os.makedirs(data_dir, exist_ok=True)
         np.save(os.path.join(data_dir, f"{name}.npy"), data)

+    # Overwrite default folder if provided
+    if data_config.data_folder:
+        folder = data_config.data_folder
     # Select the splitting function to be used.
     scheme = data_config.scheme
     if scheme == "iid":
@@ -242,7 +246,7 @@ def split_data(data_config: DataSplitConfig) -> None:
     )
     split = func(inputs, labels, data_config.n_shards, rng)
     # Export the resulting shard-wise data to files.
-    folder = os.path.join(data_config.export_folder, f"data_{scheme}")
+    folder = os.path.join(folder, f"data_{scheme}")
     for i, (inp, tgt) in enumerate(split):
         perc_train = data_config.perc_train
         if not perc_train:
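A short usage sketch of the revised signature (paths illustrative): the new folder argument supplies the default export location, and the optional data_folder field of the config takes precedence over it when set.

    # Shards are exported under "<folder>/data_<scheme>", e.g.
    # "path/to/experiment/data_labels" when scheme = "labels".
    split_data(data_config, folder="path/to/experiment")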
@@ -36,7 +36,6 @@ import importlib
 import os
 import re
 import textwrap
-from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple

 from declearn.communication import NetworkClientConfig, NetworkServerConfig
@@ -77,10 +76,18 @@ def _run_server(
     model_config: ModelConfig,
     optim: FLOptimConfig,
     config: FLRunConfig,
+    expe_config: ExperimentConfig,
 ) -> None:
     """Routine to run a FL server, called by `run_declearn_experiment`."""
     model = _get_model(folder, model_config)
-    server = FederatedServer(model, network, optim)
+    if expe_config.checkpoint:
+        checkpoint = expe_config.checkpoint
+    else:
+        checkpoint = os.path.join(folder, "result")
+    checkpoint = os.path.join(checkpoint, "server")
+    server = FederatedServer(
+        model, network, optim, expe_config.metrics, checkpoint
+    )
     server.run(config)
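The same default-or-override logic is repeated in _run_client below; a hedged sketch of the shared behavior as a standalone helper (hypothetical, not part of this commit):

    import os
    from typing import Optional

    def _resolve_checkpoint(folder: str, override: Optional[str], subdir: str) -> str:
        """Return the checkpoint folder: the explicit override if given,
        else "<folder>/result", with a per-process subfolder appended."""
        base = override if override else os.path.join(folder, "result")
        return os.path.join(base, subdir)

    # E.g. _resolve_checkpoint("exp", None, "server") -> "exp/result/server"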
@@ -88,6 +95,7 @@ def _run_client(
     folder: str,
     network: NetworkClientConfig,
     model_config: ModelConfig,
+    expe_config: ExperimentConfig,
     name: str,
     paths: dict,
 ) -> None:
@@ -96,6 +104,12 @@ def _run_client(
     network.name = name
     # Make the model importable
     _ = _get_model(folder, model_config)
+    # Add checkpointer
+    if expe_config.checkpoint:
+        checkpoint = expe_config.checkpoint
+    else:
+        checkpoint = os.path.join(folder, "result")
+    checkpoint = os.path.join(checkpoint, name)
     # Wrap train and validation data as Dataset objects.
     train = InMemoryDataset(
         paths.get("train_data"),
@@ -106,7 +120,7 @@ def _run_client(
         paths.get("valid_data"),
         target=paths.get("valid_target"),
     )
-    client = FederatedClient(network, train, valid)
+    client = FederatedClient(network, train, valid, checkpoint)
     client.run()
@@ -132,22 +146,29 @@ def get_toml_folder(config: Optional[str] = None) -> Tuple[str, str]:

 def locate_or_create_split_data(toml: str, folder: str) -> Dict:
     """Attempts to find split data according to the config toml or
-    or the defualt behavior. If failed, attempts to find full data
+    the default behavior. If that fails, attempts to find full data
     according to the config toml and split it"""
-    expe_config = ExperimentConfig.from_toml(toml, False, "experiment")
+    data_config = DataSplitConfig.from_toml(toml, False, "data")
     try:
-        client_dict = parse_data_folder(expe_config, folder)
+        client_dict = parse_data_folder(data_config, folder)
     except ValueError:
-        data_config = DataSplitConfig.from_toml(toml, False, "data")
-        split_data(folder, data_config)
-        client_dict = parse_data_folder(expe_config, folder)
+        split_data(data_config, folder)
+        client_dict = parse_data_folder(data_config, folder)
     return client_dict

-def quickrun(
-    config: Optional[str] = None,
-    **kwargs: Any,
-) -> None:
+def server_to_client_network(
+    network_cfg: NetworkServerConfig,
+) -> NetworkClientConfig:
+    """Convert a server network config into a client network config."""
+    return NetworkClientConfig.from_params(
+        protocol=network_cfg.protocol,
+        server_uri=f"ws://localhost:{network_cfg.port}",
+        name="replaceme",
+    )
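A usage sketch of this new helper (values illustrative; assumes NetworkServerConfig exposes the same from_params constructor that NetworkClientConfig does):

    # Build the server-side network config, then derive the client-side one;
    # each client process later overwrites the "replaceme" placeholder name.
    server_cfg = NetworkServerConfig.from_params(
        protocol="websockets", host="127.0.0.1", port=8765
    )
    client_cfg = server_to_client_network(server_cfg)
    assert client_cfg.server_uri == "ws://localhost:8765"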
+
+
+def quickrun(config: Optional[str] = None) -> None:
     """Run a server and its clients using multiprocessing.

     The kwargs are the arguments expected by split_data,
@@ -156,20 +177,24 @@ def quickrun(
     toml, folder = get_toml_folder(config)
     # locate split data or split it if needed
     client_dict = locate_or_create_split_data(toml, folder)
-    # Parse toml file to ServerConfig and ClientConfig
-    ntk_server = NetworkServerConfig.from_toml(toml, False, "network_server")
-    optim = FLOptimConfig.from_toml(toml, False, "optim")
-    run = FLRunConfig.from_toml(toml, False, "run")
-    ntk_client = NetworkClientConfig.from_toml(toml, False, "network_client")
-    model_config = ModelConfig.from_toml(toml, False, "model")
+    # Parse toml files
+    ntk_server_cfg = NetworkServerConfig.from_toml(toml, False, "network")
+    ntk_client_cfg = server_to_client_network(ntk_server_cfg)
+    optim_cfg = FLOptimConfig.from_toml(toml, False, "optim")
+    run_cfg = FLRunConfig.from_toml(toml, False, "run")
+    model_cfg = ModelConfig.from_toml(toml, False, "model")
+    expe_cfg = ExperimentConfig.from_toml(toml, False, "experiment")
     # Set up a (func, args) tuple specifying the server process.
-    p_server = (_run_server, (folder, ntk_server, model_config, optim, run))
+    p_server = (
+        _run_server,
+        (folder, ntk_server_cfg, model_cfg, optim_cfg, run_cfg, expe_cfg),
+    )
     # Set up the (func, args) tuples specifying client-wise processes.
     p_client = []
     for name, data_dict in client_dict.items():
         client = (
             _run_client,
-            (folder, ntk_client, model_config, name, data_dict),
+            (folder, ntk_client_cfg, model_cfg, expe_cfg, name, data_dict),
         )
         p_client.append(client)
     # Run each and every process in parallel.
@@ -211,7 +236,7 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:

 def main(args: Optional[List[str]] = None) -> None:
     """Quickrun based on commandline-input arguments."""
     cmdargs = parse_args(args)
-    quickrun(folder=cmdargs.config)
+    quickrun(config=cmdargs.config)


 if __name__ == "__main__":
-[network_server]
+[network]
 protocol = "websockets"
 host = "127.0.0.1"
 port = 8765

-[network_client]
-protocol = "websockets"
-server_uri = "ws://localhost:8765"
-name = "replaceme"
-
 [optim]
 aggregator = "averaging" # The chosen aggregation strategy
 server_opt = 1.0 # The server learning rate

 [optim.client_opt]
-lrate = 0.01 # The client learning rate
-regularizers = [["lasso", {alpha = 0.1}]] # The list of regularizer modules, each a list
-modules = [["momentum", {"beta" = 0.9}]]
+lrate = 0.001 # The client learning rate
+modules = ["adam"] # The optimizer modules

 [run]
-rounds = 2 # Number of training rounds
+rounds = 10 # Number of training rounds

 [run.register]
 min_clients = 1

@@ -36,10 +30,11 @@ rounds = 2 # Number of training rounds
 [experiment]
 # all args for parse_data_folder
-# target_folder = !!!

 [model]
 # information on where to find the model file

 [data]
-# all args from split_data argparser
\ No newline at end of file
+# all args from split_data argparser
+scheme = "labels"
\ No newline at end of file
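Putting the pieces together, a hedged end-to-end sketch of how quickrun consumes this updated TOML file (the import path and file location are assumptions):

    from declearn.quickrun import quickrun  # assumed import path

    # Parses the [network], [optim], [run], [model], [experiment] and [data]
    # sections, locates or creates the split data, then launches the server
    # and client processes.
    quickrun(config="path/to/experiment/config.toml")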