From 1c26f2c201bf62a60902cfa9dc74b0c290982911 Mon Sep 17 00:00:00 2001
From: BIGAUD Nathan <nathan.bigaud@inria.fr>
Date: Thu, 30 Mar 2023 09:26:31 +0200
Subject: [PATCH] Adding checkpointer and metrics, re-shuffling data in labels
 split

---
 declearn/quickrun/_config.py     | 55 ++++++++++++++-----------
 declearn/quickrun/_split_data.py | 10 +++--
 declearn/quickrun/run.py         | 69 ++++++++++++++++++++++----------
 examples/quickrun/config.toml    | 19 ++++-----
 4 files changed, 93 insertions(+), 60 deletions(-)

diff --git a/declearn/quickrun/_config.py b/declearn/quickrun/_config.py
index a80e711b..c2559da5 100644
--- a/declearn/quickrun/_config.py
+++ b/declearn/quickrun/_config.py
@@ -18,7 +18,7 @@
 """TOML-parsable container for quickrun configurations."""
 
 import dataclasses
-from typing import Dict, List, Literal, Optional, Union
+from typing import Dict, List, Optional, Union
 
 from declearn.utils import TomlConfig
 
@@ -40,11 +40,13 @@ class ModelConfig(TomlConfig):
 
 @dataclasses.dataclass
 class DataSplitConfig(TomlConfig):
-    """Dataclass associated with the function
-    declearn.quickrun._split_data:split_data
+    """Dataclass associated with the functions
+    declearn.quickrun._split_data:split_data and
+    declearn.quickrun._parser:parse_data_folder
 
-    export_folder: str
-        Path to the folder where to export shard-wise files.
+    data_folder: str
+        Absolute path to the folder where to export shard-wise files,
+        and/or to the main folder hosting the data.
     n_shards: int
         Number of shards between which to split the data.
     data_file: str or None, default=None
@@ -67,37 +69,44 @@ class DataSplitConfig(TomlConfig):
         ]0,1] range.
     seed: int or None, default=None
         Optional seed to the RNG used for all sampling operations.
+    client_names: list or None
+        List of custom client names to look for in the data_folder.
+        If None, default to expected prefix search.
+    dataset_names: dict or None
+        Dict of custom dataset names, to look for in each client folder.
+        Expect 'train_data, train_target, valid_data, valid_target' as keys.
+        If None, , default to expected prefix search.
     """
 
-    export_folder: str = "."
+    # Common args
+    data_folder: Optional[str] = None
+    # split_data args
     n_shards: int = 5
     data_file: Optional[str] = None
     label_file: Optional[Union[str, int]] = None
-    scheme: Literal["iid", "labels", "biased"] = "iid"
+    scheme: str = "iid"
     perc_train: float = 0.8
     seed: Optional[int] = None
+    # parse_data_folder args
+    client_names: Optional[List[str]] = None
+    dataset_names: Optional[Dict[str, str]] = None
 
 
 @dataclasses.dataclass
 class ExperimentConfig(TomlConfig):
     """
 
-    Dataclass associated with the function
-    declearn.quickrun._parser:parse_data_folder
+    Dataclass providing kwargs to
+    declearn.main._server.FederatedServer
+    and declearn.main._client.FederatedClient
 
-    data_folder : str or none
-        Absolute path to the main folder hosting the data, overwriting
-        the folder argument if provided. If None, default to expected
-        prefix search in folder.
-    client_names: list or None
-        List of custom client names to look for in the data_folder.
-        If None, default to expected prefix search.
-    dataset_names: dict or None
-        Dict of custom dataset names, to look for in each client folder.
-        Expect 'train_data, train_target, valid_data, valid_target' as keys.
-        If None, , default to expected prefix search.
+    metrics: list[str] or None
+        List of Metric childclass names, defining evaluation metrics
+        to compute in addition to the model's loss.
+    checkpoint: str or None
+        The checkpoint folder path and use default values for other parameters
+        to be used so as to save round-wise model
     """
 
-    data_folder: Optional[str] = None
-    client_names: Optional[List[str]] = None
-    dataset_names: Optional[Dict[str, str]] = None
+    metrics: Optional[List[str]] = None
+    checkpoint: Optional[str] = None
diff --git a/declearn/quickrun/_split_data.py b/declearn/quickrun/_split_data.py
index e323258e..0fbf8069 100644
--- a/declearn/quickrun/_split_data.py
+++ b/declearn/quickrun/_split_data.py
@@ -167,7 +167,8 @@ def _split_labels(
         srt = idx * s_len
         end = (srt + s_len) if idx < (n_shards - 1) else len(order)
         shard = np.isin(target, order[srt:end])
-        split.append((inputs[shard], target[shard]))
+        shuffle = rng.permutation(shard.sum())
+        split.append((inputs[shard][shuffle], target[shard][shuffle]))
     return split
 
 
@@ -198,7 +199,7 @@ def _split_biased(
     return split
 
 
-def split_data(data_config: DataSplitConfig) -> None:
+def split_data(data_config: DataSplitConfig, folder: str) -> None:
     """Download and randomly split a dataset into shards.
 
     The resulting folder structure is :
@@ -223,6 +224,9 @@ def split_data(data_config: DataSplitConfig) -> None:
         os.makedirs(data_dir, exist_ok=True)
         np.save(os.path.join(data_dir, f"{name}.npy"), data)
 
+    # Overwrite default folder if provided
+    if data_config.data_folder:
+        folder = data_config.data_folder
     # Select the splitting function to be used.
     scheme = data_config.scheme
     if scheme == "iid":
@@ -242,7 +246,7 @@ def split_data(data_config: DataSplitConfig) -> None:
     )
     split = func(inputs, labels, data_config.n_shards, rng)
     # Export the resulting shard-wise data to files.
-    folder = os.path.join(data_config.export_folder, f"data_{scheme}")
+    folder = os.path.join(folder, f"data_{scheme}")
     for i, (inp, tgt) in enumerate(split):
         perc_train = data_config.perc_train
         if not perc_train:
diff --git a/declearn/quickrun/run.py b/declearn/quickrun/run.py
index e52a859d..1a8f8a08 100644
--- a/declearn/quickrun/run.py
+++ b/declearn/quickrun/run.py
@@ -36,7 +36,6 @@ import importlib
 import os
 import re
 import textwrap
-from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple
 
 from declearn.communication import NetworkClientConfig, NetworkServerConfig
@@ -77,10 +76,18 @@ def _run_server(
     model_config: ModelConfig,
     optim: FLOptimConfig,
     config: FLRunConfig,
+    expe_config: ExperimentConfig,
 ) -> None:
     """Routine to run a FL server, called by `run_declearn_experiment`."""
     model = _get_model(folder, model_config)
-    server = FederatedServer(model, network, optim)
+    if expe_config.checkpoint:
+        checkpoint = expe_config.checkpoint
+    else:
+        checkpoint = os.path.join(folder, "result")
+    checkpoint = os.path.join(checkpoint, "server")
+    server = FederatedServer(
+        model, network, optim, expe_config.metrics, checkpoint
+    )
     server.run(config)
 
 
@@ -88,6 +95,7 @@ def _run_client(
     folder: str,
     network: NetworkClientConfig,
     model_config: ModelConfig,
+    expe_config: ExperimentConfig,
     name: str,
     paths: dict,
 ) -> None:
@@ -96,6 +104,12 @@ def _run_client(
     network.name = name
     # Make the model importable
     _ = _get_model(folder, model_config)
+    # Add checkpointer
+    if expe_config.checkpoint:
+        checkpoint = expe_config.checkpoint
+    else:
+        checkpoint = os.path.join(folder, "result")
+    checkpoint = os.path.join(checkpoint, name)
     # Wrap train and validation data as Dataset objects.
     train = InMemoryDataset(
         paths.get("train_data"),
@@ -106,7 +120,7 @@ def _run_client(
         paths.get("valid_data"),
         target=paths.get("valid_target"),
     )
-    client = FederatedClient(network, train, valid)
+    client = FederatedClient(network, train, valid, checkpoint)
     client.run()
 
 
@@ -132,22 +146,29 @@ def get_toml_folder(config: Optional[str] = None) -> Tuple[str, str]:
 
 def locate_or_create_split_data(toml: str, folder: str) -> Dict:
     """Attempts to find split data according to the config toml or
-    or the defualt behavior. If failed, attempts to find full data
+    or the default behavior. If failed, attempts to find full data
     according to the config toml and split it"""
-    expe_config = ExperimentConfig.from_toml(toml, False, "experiment")
+    data_config = DataSplitConfig.from_toml(toml, False, "data")
     try:
-        client_dict = parse_data_folder(expe_config, folder)
+        client_dict = parse_data_folder(data_config, folder)
     except ValueError:
-        data_config = DataSplitConfig.from_toml(toml, False, "data")
-        split_data(folder, data_config)
-        client_dict = parse_data_folder(expe_config,folder)
+        split_data(data_config, folder)
+        client_dict = parse_data_folder(data_config, folder)
     return client_dict
 
 
-def quickrun(
-    config: Optional[str] = None,
-    **kwargs: Any,
-) -> None:
+def server_to_client_network(
+    network_cfg: NetworkServerConfig,
+) -> NetworkClientConfig:
+    "Converts server network config to client network config"
+    return NetworkClientConfig.from_params(
+        protocol=network_cfg.protocol,
+        server_uri=f"ws://localhost:{network_cfg.port}",
+        name="replaceme",
+    )
+
+
+def quickrun(config: Optional[str] = None) -> None:
     """Run a server and its clients using multiprocessing.
 
     The kwargs are the arguments expected by split_data,
@@ -156,20 +177,24 @@ def quickrun(
     toml, folder = get_toml_folder(config)
     # locate split data or split it if needed
     client_dict = locate_or_create_split_data(toml, folder)
-    # Parse toml file to ServerConfig and ClientConfig
-    ntk_server = NetworkServerConfig.from_toml(toml, False, "network_server")
-    optim = FLOptimConfig.from_toml(toml, False, "optim")
-    run = FLRunConfig.from_toml(toml, False, "run")
-    ntk_client = NetworkClientConfig.from_toml(toml, False, "network_client")
-    model_config = ModelConfig.from_toml(toml, False, "model")
+    # Parse toml files
+    ntk_server_cfg = NetworkServerConfig.from_toml(toml, False, "network")
+    ntk_client_cfg = server_to_client_network(ntk_server_cfg)
+    optim_cgf = FLOptimConfig.from_toml(toml, False, "optim")
+    run_cfg = FLRunConfig.from_toml(toml, False, "run")
+    model_cfg = ModelConfig.from_toml(toml, False, "model")
+    expe_cfg = ExperimentConfig.from_toml(toml, False, "experiment")
     # Set up a (func, args) tuple specifying the server process.
-    p_server = (_run_server, (folder, ntk_server, model_config, optim, run))
+    p_server = (
+        _run_server,
+        (folder, ntk_server_cfg, model_cfg, optim_cgf, run_cfg, expe_cfg),
+    )
     # Set up the (func, args) tuples specifying client-wise processes.
     p_client = []
     for name, data_dict in client_dict.items():
         client = (
             _run_client,
-            (folder, ntk_client, model_config, name, data_dict),
+            (folder, ntk_client_cfg, model_cfg, expe_cfg, name, data_dict),
         )
         p_client.append(client)
     # Run each and every process in parallel.
@@ -211,7 +236,7 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
 def main(args: Optional[List[str]] = None) -> None:
     """Quikcrun based on commandline-input arguments."""
     cmdargs = parse_args(args)
-    quickrun(folder=cmdargs.config)
+    quickrun(config=cmdargs.config)
 
 
 if __name__ == "__main__":
diff --git a/examples/quickrun/config.toml b/examples/quickrun/config.toml
index 32f00403..a25f77f7 100644
--- a/examples/quickrun/config.toml
+++ b/examples/quickrun/config.toml
@@ -1,24 +1,18 @@
-[network_server]
+[network]
 protocol = "websockets"
 host = "127.0.0.1"
 port = 8765
 
-[network_client]
-protocol = "websockets"
-server_uri = "ws://localhost:8765"
-name = "replaceme"
-
 [optim]
 aggregator = "averaging" # The chosen aggregation strategy
 server_opt = 1.0 # The server learning rate
 
     [optim.client_opt]
-    lrate = 0.01 # The client learning rate
-    regularizers = [["lasso", {alpha = 0.1}]] # The list of regularizer modules, each a list
-    modules = [["momentum", {"beta" = 0.9}]]
+    lrate = 0.001 # The client learning rate
+    modules = ["adam"] # The optimzer modules 
 
 [run]
-rounds = 2 # Number of training rounds
+rounds = 10 # Number of training rounds
 
     [run.register]
     min_clients = 1
@@ -36,10 +30,11 @@ rounds = 2 # Number of training rounds
 
 [experiment]
 # all args for parse_data_folder
-# target_folder = !!!
+
 
 [model]
 # information on where to find the model file
 
 [data]
-# all args from split_data argparser
\ No newline at end of file
+# all args from split_data argparser
+scheme = "labels"
\ No newline at end of file
-- 
GitLab