From 98027bc4ff4ed17e9206c00e97953ebe0218b825 Mon Sep 17 00:00:00 2001
From: BIGAUD Nathan <nathan.bigaud@inria.fr>
Date: Wed, 29 Mar 2023 16:01:08 +0200
Subject: [PATCH] In-depth review of TOML parsing, first pass:

* Create "ModelConfig", "DataSplitConfig" and "ExperimentConfig"
* Relocate and update the parser to accept a TOML file as input
* Update split_data to use a config input
* Modularize and update run.py
* Small changes to the MNIST TOML example
---
 declearn/quickrun/_config.py     | 103 +++++++++++++
 declearn/quickrun/_parser.py     | 143 ++++++++++++++++++
 declearn/quickrun/_split_data.py | 168 ++++-----------------
 declearn/quickrun/run.py         | 241 ++++++++++---------------------
 examples/quickrun/config.toml    |  10 ++
 examples/quickrun/model.py       |  85 ++++++-----
 6 files changed, 403 insertions(+), 347 deletions(-)
 create mode 100644 declearn/quickrun/_config.py
 create mode 100644 declearn/quickrun/_parser.py

diff --git a/declearn/quickrun/_config.py b/declearn/quickrun/_config.py
new file mode 100644
index 00000000..a80e711b
--- /dev/null
+++ b/declearn/quickrun/_config.py
@@ -0,0 +1,103 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""TOML-parsable containers for quickrun configurations."""
+
+import dataclasses
+from typing import Dict, List, Literal, Optional, Union
+
+from declearn.utils import TomlConfig
+
+__all__ = [
+    "ModelConfig",
+    "DataSplitConfig",
+    "ExperimentConfig",
+]
+
+
+@dataclasses.dataclass
+class ModelConfig(TomlConfig):
+    """Dataclass providing a custom model's location and
+    class name."""
+
+    model_file: Optional[str] = None
+    model_name: str = "MyModel"
+
+
+@dataclasses.dataclass
+class DataSplitConfig(TomlConfig):
+    """Dataclass associated with the function
+    declearn.quickrun._split_data:split_data
+
+    export_folder: str
+        Path to the folder where to export shard-wise files.
+    n_shards: int
+        Number of shards between which to split the data.
+    data_file: str or None, default=None
+        Optional path to the data file to split.
+        If None, default to the MNIST example.
+    label_file: str or int or None, default=None
+        If str, path to the labels file to import. If int, column of
+        the data file to be used as labels. Required if data_file is
+        not None, ignored otherwise.
+    scheme: {"iid", "labels", "biased"}, default="iid"
+        Splitting scheme to use. In all cases, shards contain mutually-
+        exclusive samples and cover the full raw training data.
+        - If "iid", split the dataset through iid random sampling.
+        - If "labels", split into shards that hold all samples associated
+          with mutually-exclusive target classes.
+        - If "biased", split the dataset through random sampling according
+          to a shard-specific random labels distribution.
+    perc_train: float, default=0.8
+        Train/validation split within each client dataset; must be in
+        the ]0,1] range.
+    seed: int or None, default=None
+        Optional seed for the RNG used in all sampling operations.
+    """
+
+    export_folder: str = "."
+    n_shards: int = 5
+    data_file: Optional[str] = None
+    label_file: Optional[Union[str, int]] = None
+    scheme: Literal["iid", "labels", "biased"] = "iid"
+    perc_train: float = 0.8
+    seed: Optional[int] = None
+
+
+@dataclasses.dataclass
+class ExperimentConfig(TomlConfig):
+    """Dataclass associated with the function
+    declearn.quickrun._parser:parse_data_folder
+
+    data_folder: str or None
+        Absolute path to the main folder hosting the data, overriding
+        the folder argument if provided. If None, default to expected
+        prefix search in folder.
+    client_names: list or None
+        List of custom client names to look for in the data_folder.
+        If None, default to expected prefix search.
+    dataset_names: dict or None
+        Dict of custom dataset names to look for in each client folder.
+        Expects 'train_data', 'train_target', 'valid_data' and
+        'valid_target' as keys. If None, default to expected prefix search.
+    """
+
+    data_folder: Optional[str] = None
+    client_names: Optional[List[str]] = None
+    dataset_names: Optional[Dict[str, str]] = None
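Usage sketch (not part of the diff): the dataclasses above are parsed from dedicated sections of the experiment's TOML file through `TomlConfig.from_toml`, mirroring the call pattern used in run.py further down; paths and field values below are illustrative.

    from _config import DataSplitConfig, ModelConfig

    # Parse the [data] and [model] sections of a quickrun TOML file.
    data_config = DataSplitConfig.from_toml(
        "examples/quickrun/config.toml", False, "data"
    )
    model_config = ModelConfig.from_toml(
        "examples/quickrun/config.toml", False, "model"
    )
    # Fields may also be set directly, bypassing TOML parsing.
    data_config = DataSplitConfig(n_shards=3, scheme="labels", seed=0)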
diff --git a/declearn/quickrun/_parser.py b/declearn/quickrun/_parser.py
new file mode 100644
index 00000000..eae9c9fc
--- /dev/null
+++ b/declearn/quickrun/_parser.py
@@ -0,0 +1,143 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utils to parse the contents of a data folder into client-wise datasets."""
+
+import os
+from pathlib import Path
+from typing import Dict, Optional
+
+from declearn.test_utils import make_importable
+
+# Perform local imports.
+# pylint: disable=wrong-import-order, wrong-import-position
+with make_importable(os.path.dirname(__file__)):
+    from _config import ExperimentConfig
+# pylint: enable=wrong-import-order, wrong-import-position
+
+
+def parse_data_folder(
+    expe_config: ExperimentConfig,
+    folder: Optional[str] = None,
+) -> Dict:
+    """Parse a data folder following a standard format into a nested
+    dictionary.
+
+    The default expected format is:
+
+    folder/
+    └─── data*/
+        └─── client*/
+        │      train_data.* - training data
+        │      train_target.* - training labels
+        │      valid_data.* - validation data
+        │      valid_target.* - validation labels
+        └─── client*/
+        │    ...
+
+    Parameters
+    ----------
+    expe_config : ExperimentConfig
+        ExperimentConfig instance, see class documentation for details.
+    folder : str or None
+        The main experiment folder in which to look for a `data*` folder.
+        Ignored if `expe_config.data_folder` is provided.
+    """
+    data_folder = expe_config.data_folder
+    client_names = expe_config.client_names
+    dataset_names = expe_config.dataset_names
+
+    if not folder and not data_folder:
+        raise ValueError(
+            "Please provide either a parent folder or a data folder."
+        )
+    # Locate the data folder.
+    if not data_folder:
+        gen_folders = Path(folder).glob("data*")  # type: ignore
+        data_folder = next(gen_folders, False)  # type: ignore
+        if not data_folder:
+            raise ValueError(
+                f"No folder starting with 'data' found in {folder}. "
+                "Please store your data under a 'data_*' folder."
+            )
+        if next(gen_folders, False):
+            raise ValueError(
+                "More than one folder starting with 'data' found "
+                f"in {folder}. Please store your data under a single "
+                "parent folder."
+            )
+    # Get the client directories.
+    if client_names:
+        if isinstance(client_names, list):
+            valid_names = [
+                os.path.isdir(os.path.join(data_folder, n))
+                for n in client_names
+            ]
+            if sum(valid_names) != len(client_names):
+                raise ValueError(
+                    f"Not all provided client names could be found in {data_folder}"
+                )
+            clients = {n: {} for n in client_names}
+        else:
+            raise ValueError(
+                "Please provide a valid list of client names for "
+                "argument 'client_names'."
+            )
+    else:
+        gen_folders = Path(data_folder).glob("client*")  # type: ignore
+        first_client = next(gen_folders, False)
+        if not first_client:
+            raise ValueError(
+                f"No folder starting with 'client' found in {data_folder}. "
+                "Please store your individual client data under "
+                "a 'client*' folder."
+            )
+        clients = {str(first_client).rsplit("/", 1)[-1]: {}}
+        while client := next(gen_folders, False):
+            clients[str(client).rsplit("/", 1)[-1]] = {}
+    # Get the train and valid files.
+    data_items = [
+        "train_data",
+        "train_target",
+        "valid_data",
+        "valid_target",
+    ]
+    if dataset_names:
+        if set(data_items) != set(dataset_names.keys()):
+            raise ValueError(
+                "Please provide a properly formatted dictionary as input, "
+                f"using the following keys: {data_items}"
+            )
+    else:
+        dataset_names = {i: i for i in data_items}
+    for client, files in clients.items():
+        for k, v in dataset_names.items():
+            gen_file = Path(data_folder, client).glob(f"{v}*")  # type: ignore
+            file = next(gen_file, False)
+            if not file:
+                raise ValueError(
+                    f"Could not find a file named '{v}.*' in {client}"
+                )
+            if next(gen_file, False):
+                raise ValueError(
+                    f"Found more than one file named '{v}.*' in {client}"
+                )
+            files[k] = str(file)
+    return clients
diff --git a/declearn/quickrun/_split_data.py b/declearn/quickrun/_split_data.py
index 96c93ef0..e323258e 100644
--- a/declearn/quickrun/_split_data.py
+++ b/declearn/quickrun/_split_data.py
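Usage sketch (not part of the diff): `parse_data_folder` returns a dict mapping each client folder name to its four dataset file paths. Assuming the default layout produced by split_data below, the result looks as follows (paths are illustrative):

    from _config import ExperimentConfig
    from _parser import parse_data_folder

    clients = parse_data_folder(ExperimentConfig(), folder="examples/quickrun")
    # {
    #     "client_0": {
    #         "train_data": ".../data_iid/client_0/train_data.npy",
    #         "train_target": ".../data_iid/client_0/train_target.npy",
    #         "valid_data": ".../data_iid/client_0/valid_data.npy",
    #         "valid_target": ".../data_iid/client_0/valid_target.npy",
    #     },
    #     ...
    # }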
@@ -45,6 +45,13 @@ import pandas as pd
 import requests  # type: ignore
 
 from declearn.dataset import load_data_array
+from declearn.test_utils import make_importable
+
+# Perform local imports.
+# pylint: disable=wrong-import-order, wrong-import-position
+with make_importable(os.path.dirname(__file__)):
+    from _config import DataSplitConfig
+# pylint: enable=wrong-import-order, wrong-import-position
 
 SOURCE_URL = "https://pjreddie.com/media/files"
 
@@ -191,15 +198,7 @@ def _split_biased(
     return split
 
 
-def split_data(
-    folder: str,
-    n_shards: int = 5,
-    data: Optional[str] = None,
-    target: Optional[Union[str, int]] = None,
-    scheme: Literal["iid", "labels", "biased"] = "iid",
-    perc_train: float = 0.8,
-    seed: Optional[int] = None,
-) -> None:
+def split_data(data_config: DataSplitConfig) -> None:
     """Download and randomly split a dataset into shards.
 
     The resulting folder structure is :
@@ -215,32 +214,17 @@ def split_data(
 
     Parameters
     ----------
-    folder: str
-        Path to the folder where to export shard-wise files.
-    n_shards: int
-        Number of shards between which to split the data.
-    data: str or None, default=None
-        Optional path to a folder where to find the data.
-        If None, default to the MNIST example.
-    target: str or int or None, default=None
-        If str, path to the labels file to import. If int, column of
-        the data file to be used as labels. Required if data is not None,
-        ignored if data is None.
-    scheme: {"iid", "labels", "biased"}, default="iid"
-        Splitting scheme to use. In all cases, shards contain mutually-
-        exclusive samples and cover the full raw training data.
-        - If "iid", split the dataset through iid random sampling.
-        - If "labels", split into shards that hold all samples associated
-          with mutually-exclusive target classes.
-        - If "biased", split the dataset through random sampling according
-          to a shard-specific random labels distribution.
-    perc_train: float, default= 0.8
-        Train/validation split in each client dataset, must be in the
-        ]0,1] range.
-    seed: int or None, default=None
-        Optional seed to the RNG used for all sampling operations.
+    data_config: DataSplitConfig
+        A DataSplitConfig instance; see the class documentation for details.
     """
+
+    def np_save(folder, data, i, name):
+        data_dir = os.path.join(folder, f"client_{i}")
+        os.makedirs(data_dir, exist_ok=True)
+        np.save(os.path.join(data_dir, f"{name}.npy"), data)
+
     # Select the splitting function to be used.
+    scheme = data_config.scheme
     if scheme == "iid":
         func = _split_iid
     elif scheme == "labels":
@@ -250,115 +234,27 @@ def split_data(
     else:
         raise ValueError(f"Invalid 'scheme' value: '{scheme}'.")
     # Set up the RNG, download the raw dataset and split it.
-    rng = np.random.default_rng(seed)
-    inputs, labels = load_data(data, target)
-    print(f"Splitting data into {n_shards} shards using the {scheme} scheme")
-    split = func(inputs, labels, n_shards, rng)
+    rng = np.random.default_rng(data_config.seed)
+    inputs, labels = load_data(data_config.data_file, data_config.label_file)
+    print(
+        f"Splitting data into {data_config.n_shards} "
+        f"shards using the {scheme} scheme"
+    )
+    split = func(inputs, labels, data_config.n_shards, rng)
     # Export the resulting shard-wise data to files.
-    folder = os.path.join(folder, f"data_{scheme}")
-
-    def np_save(data, i, name):
-        data_dir = os.path.join(folder, f"client_{i}")
-        os.makedirs(data_dir, exist_ok=True)
-        np.save(os.path.join(data_dir, f"{name}.npy"), data)
-
+    folder = os.path.join(data_config.export_folder, f"data_{scheme}")
     for i, (inp, tgt) in enumerate(split):
+        perc_train = data_config.perc_train
         if not perc_train:
-            np_save(inp, i, "train_data")
-            np_save(tgt, i, "train_target")
+            np_save(folder, inp, i, "train_data")
+            np_save(folder, tgt, i, "train_target")
         else:
             if perc_train > 1.0 or perc_train < 0.0:
                 raise ValueError("perc_train should be a float in ]0,1]")
             n_train = round(len(inp) * perc_train)
             t_inp, t_tgt = inp[:n_train], tgt[:n_train]
            v_inp, v_tgt = inp[n_train:], tgt[n_train:]
-            np_save(t_inp, i, "train_data")
-            np_save(t_tgt, i, "train_target")
-            np_save(v_inp, i, "valid_data")
-            np_save(v_tgt, i, "valid_target")
-
-
-def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
-    """Set up and run a command-line arguments parser."""
-    usage = """
-        Download and split data into heterogeneous shards.
-
-        The implemented schemes are the following:
-        * "iid":
-            Split the dataset through iid random sampling.
-        * "labels":
-            Split the dataset into shards that hold all samples
-            that have mutually-exclusive target classes.
-        * "biased":
-            Split the dataset through random sampling according
-            to a shard-specific random labels distribution.
-        """
-    usage = re.sub("\n *(?=[a-z])", " ", textwrap.dedent(usage))
-    parser = argparse.ArgumentParser(
-        formatter_class=argparse.RawTextHelpFormatter,
-        usage=re.sub("- ", "-", usage),
-    )
-    parser.add_argument(
-        "--n_shards",
-        type=int,
-        default=5,
-        help="Number of shards between which to split the MNIST training data.",
-    )
-    parser.add_argument(
-        "--root",
-        default=".",
-        dest="folder",
-        help="Path to the root folder where to export raw and split data.",
-    )
-    parser.add_argument(
-        "--data_path",
-        default=None,  # CHECK
-        dest="data",
-        help="Path to the data to be split",
-    )
-    parser.add_argument(
-        "--target_path",
-        default=None,  # CHECK
-        dest="target",
-        help="Path to the labels to be split",
-    )
-    schemes_help = """
-        Splitting scheme(s) to use, among {"iid", "labels", "biased"}.
-        If this argument is not specified, all three values are used.
-        See details above on the schemes' definition.
-    """
-    parser.add_argument(
-        "--scheme",
-        action="append",
-        choices=["iid", "labels", "biased"],
-        default=["iid"],
-        dest="schemes",
-        nargs="+",
-        help=textwrap.dedent(schemes_help),
-    )
-    parser.add_argument(
-        "--seed",
-        default=20221109,
-        dest="seed",
-        type=int,
-        help="RNG seed to use (default: 20221109).",
-    )
-    return parser.parse_args(args)
-
-
-def main(args: Optional[List[str]] = None) -> None:
-    """Run splitting schemes based on commandline-input arguments."""
-    cmdargs = parse_args(args)
-    for scheme in cmdargs.schemes:
-        split_data(
-            folder=os.path.join(cmdargs.folder, f"data_{scheme}"),
-            n_shards=cmdargs.n_shards,
-            data=cmdargs.data,
-            target=cmdargs.target,
-            scheme=scheme,
-            seed=cmdargs.seed,
-        )
-
-
-if __name__ == "__main__":
-    main()
+            np_save(folder, t_inp, i, "train_data")
+            np_save(folder, t_tgt, i, "train_target")
+            np_save(folder, v_inp, i, "valid_data")
+            np_save(folder, v_tgt, i, "valid_target")
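Usage sketch (not part of the diff): with the new signature, the split is driven entirely by the config object; folder names follow the `data_{scheme}` and `client_{i}` patterns used above, and the export folder is an example.

    from _config import DataSplitConfig
    from _split_data import split_data

    config = DataSplitConfig(
        export_folder="examples/quickrun", n_shards=2, scheme="iid", seed=0
    )
    split_data(config)
    # Creates examples/quickrun/data_iid/client_{0,1}/, each holding
    # train_data.npy, train_target.npy, valid_data.npy and valid_target.npy.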
diff --git a/declearn/quickrun/run.py b/declearn/quickrun/run.py
index 6bd718de..e52a859d 100644
--- a/declearn/quickrun/run.py
+++ b/declearn/quickrun/run.py
@@ -37,12 +37,13 @@ import os
 import re
 import textwrap
 from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 from declearn.communication import NetworkClientConfig, NetworkServerConfig
 from declearn.dataset import InMemoryDataset
 from declearn.main import FederatedClient, FederatedServer
 from declearn.main.config import FLOptimConfig, FLRunConfig
+from declearn.model.api import Model
 from declearn.test_utils import make_importable
 from declearn.utils import run_as_processes
 
@@ -53,107 +54,49 @@ DEFAULT_FOLDER = "./examples/quickrun"
 
 # Perform local imports.
 # pylint: disable=wrong-import-order, wrong-import-position
 with make_importable(os.path.dirname(__file__)):
+    from _config import DataSplitConfig, ExperimentConfig, ModelConfig
+    from _parser import parse_data_folder
     from _split_data import split_data
 # pylint: enable=wrong-import-order, wrong-import-position
 
 
+def _get_model(folder: str, model_config: ModelConfig) -> Model:
+    """Import and return the target Model from the configured file."""
+    file = "model"
+    if m_file := model_config.model_file:
+        folder = os.path.dirname(m_file)
+        file = m_file.rsplit("/", 1)[-1].split(".")[0]
+    with make_importable(folder):
+        mod = importlib.import_module(file)
+        model = getattr(mod, model_config.model_name)
+    return model
+
+
 def _run_server(
     folder: str,
     network: NetworkServerConfig,
+    model_config: ModelConfig,
     optim: FLOptimConfig,
     config: FLRunConfig,
 ) -> None:
     """Routine to run a FL server, called by `run_declearn_experiment`."""
-    # get Model
-    name = "MyModel"
-    with make_importable(folder):
-        mod = importlib.import_module("model")
-        model_cls = getattr(mod, name)
-    model = model_cls
+    model = _get_model(folder, model_config)
     server = FederatedServer(model, network, optim)
     server.run(config)
 
 
-def parse_data_folder(folder: str) -> Dict:
-    """Utils parsing a data folder following a standard format into a nested"
-    dictionnary.
-
-    The expected format is :
-
-    folder/
-    └─── data*/
-        └─── client*/
-        │      train_data.* - training data
-        │      train_target.* - training labels
-        │      valid_data.* - validation data
-        │      valid_target.* - validation labels
-        └─── client*/
-        │    ...
-    """
-    # Get data dir
-    gen_folders = Path(folder).glob("data*")
-    data_folder = next(gen_folders, False)
-    if not data_folder:
-        raise ValueError(
-            f"No folder starting with 'data' found in {folder}. "
-            "Please store your data under a 'data_*' folder"
-        )
-    if next(gen_folders, False):
-        raise ValueError(
-            "More than one folder starting with 'data' found"
-            f"in {folder}. Please store your data under a single"
-            "parent folder"
-        )
-    # Get clients dir
-    gen_folders = data_folder.glob("client*")  # type: ignore
-    first_client = next(gen_folders, False)
-    if not first_client:
-        raise ValueError(
-            f"No folder starting with 'client' found in {data_folder}. "
-            "Please store your individual under client data under"
-            "a 'client*' folder"
-        )
-    clients = {str(first_client).rsplit("/", 1)[-1]: {}}
-    while client := next(gen_folders, False):
-        clients[str(client).rsplit("/", 1)[-1]] = {}
-    # Get train and valid files
-    data_items = [
-        "train_data",
-        "train_target",
-        "valid_data",
-        "valid_target",
-    ]
-    for client, files in clients.items():
-        for d in data_items:
-            gen_file = Path(data_folder / client).glob(f"{d}*")  # type: ignore
-            file = next(gen_file, False)
-            if not file:
-                raise ValueError(
-                    f"Could not find a file named '{d}.*' in {client}"
-                )
-            if next(gen_file, False):
-                raise ValueError(
-                    f"Found more than one file named '{d}.*' in {client}"
-                )
-            files[d] = str(file)
-
-    return clients
-
-
 def _run_client(
+    folder: str,
     network: NetworkClientConfig,
+    model_config: ModelConfig,
     name: str,
     paths: dict,
-    folder: str,
 ) -> None:
     """Routine to run a FL client, called by `run_declearn_experiment`."""
     # Overwrite client name based on folder name
     network.name = name
+    # Make the model file importable for the client process.
+    _ = _get_model(folder, model_config)
     # Wrap train and validation data as Dataset objects.
-    name = "MyModel"
-    with make_importable(folder):
-        mod = importlib.import_module("model")
-        model_cls = getattr(mod, name)  # pylint: disable=unused-variable
     train = InMemoryDataset(
         paths.get("train_data"),
         target=paths.get("train_target"),
@@ -167,8 +110,42 @@ def _run_client(
     client.run()
 
 
+def get_toml_folder(config: Optional[str] = None) -> Tuple[str, str]:
+    """Determine whether the provided config is a TOML file or a
+    directory, and return the path to the TOML config file and the
+    path to the main experiment folder."""
+    # Default to the MNIST example.
+    if not config:
+        config = DEFAULT_FOLDER
+    config = os.path.abspath(config)
+    # Check whether config points to a TOML file or to a directory.
+    if os.path.isfile(config):
+        toml = config
+        folder = os.path.dirname(config)
+    elif os.path.isdir(config):
+        folder = config
+        toml = f"{folder}/config.toml"
+    else:
+        raise ValueError(f"{config} is neither a file nor a directory")
+    return toml, folder
+
+
+def locate_or_create_split_data(toml: str, folder: str) -> Dict:
+    """Attempt to find split data as described by the config TOML file
+    or the default behavior. On failure, attempt to find the full
+    dataset according to the config TOML file and split it."""
+    expe_config = ExperimentConfig.from_toml(toml, False, "experiment")
+    try:
+        client_dict = parse_data_folder(expe_config, folder)
+    except ValueError:
+        data_config = DataSplitConfig.from_toml(toml, False, "data")
+        # Export the split data under the experiment folder.
+        data_config.export_folder = folder
+        split_data(data_config)
+        client_dict = parse_data_folder(expe_config, folder)
+    return client_dict
+
+
 def quickrun(
-    folder: Optional[str] = None,
+    config: Optional[str] = None,
     **kwargs: Any,
 ) -> None:
     """Run a server and its clients using multiprocessing.
@@ -176,28 +153,24 @@ def quickrun(
     The kwargs are the arguments expected by split_data, see
     [the documentation][declearn.quickrun._split_data]
     """
-    # default to the mnist exampl
-    if not folder:
-        folder = DEFAULT_FOLDER
-    folder = os.path.abspath(folder)
-    # Get datasets and client_names from folder
-    try:
-        client_dict = parse_data_folder(folder)
-    except ValueError:
-        split_data(folder, **kwargs)
-        client_dict = parse_data_folder(folder)
+    toml, folder = get_toml_folder(config)
+    # Locate split data, or split it if needed.
+    client_dict = locate_or_create_split_data(toml, folder)
     # Parse toml file to ServerConfig and ClientConfig
-    toml = f"{folder}/config.toml"
     ntk_server = NetworkServerConfig.from_toml(toml, False, "network_server")
     optim = FLOptimConfig.from_toml(toml, False, "optim")
     run = FLRunConfig.from_toml(toml, False, "run")
     ntk_client = NetworkClientConfig.from_toml(toml, False, "network_client")
+    model_config = ModelConfig.from_toml(toml, False, "model")
     # Set up a (func, args) tuple specifying the server process.
-    p_server = (_run_server, (folder, ntk_server, optim, run))
+    p_server = (_run_server, (folder, ntk_server, model_config, optim, run))
    # Set up the (func, args) tuples specifying client-wise processes.
     p_client = []
     for name, data_dict in client_dict.items():
-        client = (_run_client, (ntk_client, name, data_dict, folder))
+        client = (
+            _run_client,
+            (folder, ntk_client, model_config, name, data_dict),
+        )
         p_client.append(client)
     # Run each and every process in parallel.
     success, outputs = run_as_processes(p_server, *p_client)
@@ -210,29 +183,16 @@
 def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
     """Set up and run a command-line arguments parser."""
     usage = """
         Quickly run an example locally using declearn.
 
-        The script requires to be provided with the path to a folder
-        containing:
+        The script requires to be provided with the path to a TOML file
+        gathering all the elements required to configure an FL experiment,
+        or the path to a folder containing:
+        * a TOML file with all the elements required to configure an
+          FL experiment
         * A declearn model
-        * A TOML file with all the elements required to configurate an FL
-        experiment
         * A data folder, structured in a specific way
 
         If not provided with this, the script defaults to the MNIST example
         provided by declearn in `declearn.example.quickrun`.
-
-        Once launched, this script splits data into heterogeneous shards. It
-        then locally runs the FL experiment as layed out in the TOML file,
-        using privided model and data, and stores its result in the same folder.
-
-        The implemented schemes are the following:
-        * "iid":
-            Split the dataset through iid random sampling.
-        * "labels":
-            Split the dataset into shards that hold all samples
-            that have mutually-exclusive target classes.
-        * "biased":
-            Split the dataset through random sampling according
-            to a shard-specific random labels distribution.
     """
     usage = re.sub("\n *(?=[a-z])", " ", textwrap.dedent(usage))
     parser = argparse.ArgumentParser(
         formatter_class=argparse.RawTextHelpFormatter,
         usage=re.sub("- ", "-", usage),
     )
     parser.add_argument(
-        "--n_shards",
-        type=int,
-        default=5,
-        help="Number of shards between which to split the data.",
-    )
-    parser.add_argument(
-        "--root",
+        "--config",
         default=None,
-        dest="folder",
+        dest="config",
         help="Path to the root folder where to export data.",
     )
-    parser.add_argument(
-        "--data_path",
-        default=None,
-        dest="data",
-        help="Path to the data to be split",
-    )
-    parser.add_argument(
-        "--target_path",
-        default=None,
-        dest="target",
-        help="Path to the labels to be split",
-    )
-    schemes_help = """
-        Splitting scheme(s) to use, among {"iid", "labels", "biased"}.
-        If this argument is not specified, all "iid" is used.
-        See details above on the schemes' definition.
-    """
-    parser.add_argument(
-        "--scheme",
-        action="append",
-        choices=["iid", "labels", "biased"],
-        default=["iid"],
-        dest="schemes",
-        nargs="+",
-        help=textwrap.dedent(schemes_help),
-    )
-    parser.add_argument(
-        "--train_split",
-        default=0.8,
-        dest="perc_train",
-        type=float,
-        help="What proportion of the data to use for training vs validation",
-    )
-    parser.add_argument(
-        "--seed",
-        default=20221109,
-        dest="seed",
-        type=int,
-        help="RNG seed to use (default: 20221109).",
-    )
     return parser.parse_args(args)
 
 
 def main(args: Optional[List[str]] = None) -> None:
     """Quikcrun based on commandline-input arguments."""
     cmdargs = parse_args(args)
-    for scheme in cmdargs.schemes:
-        quickrun(
-            folder=cmdargs.folder,
-            n_shards=cmdargs.n_shards,
-            data=cmdargs.data,
-            target=cmdargs.target,
-            scheme=scheme,
-            perc_train=cmdargs.perc_train,
-            seed=cmdargs.seed,
-        )
+    quickrun(config=cmdargs.config)
 
 
 if __name__ == "__main__":
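Usage sketch (not part of the diff): the rewritten entry point is driven by a single `--config` argument; paths below are examples.

    from run import main

    # Point the script at a TOML config file...
    main(["--config", "examples/quickrun/config.toml"])
    # ...or at an experiment folder expected to contain a config.toml.
    main(["--config", "examples/quickrun"])
    # With no argument, it falls back to the DEFAULT_FOLDER MNIST example.
    main([])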
diff --git a/examples/quickrun/config.toml b/examples/quickrun/config.toml
index 7e941bef..32f00403 100644
--- a/examples/quickrun/config.toml
+++ b/examples/quickrun/config.toml
@@ -15,6 +15,7 @@ server_opt = 1.0 # The server learning rate
 [optim.client_opt]
     lrate = 0.01 # The client learning rate
     regularizers = [["lasso", {alpha = 0.1}]] # The list of regularizer modules, each a list
+    modules = [["momentum", {beta = 0.9}]] # The list of optimizer modules, each a list
 
 [run]
 rounds = 2 # Number of training rounds
@@ -33,3 +34,12 @@ rounds = 2 # Number of training rounds
     batch_size = 128 # Evaluation batch size
 
+[experiment]
+# Optional arguments to parse_data_folder (see ExperimentConfig),
+# e.g. data_folder = "..."
+
+[model]
+# Where to find the model file: model_file (path), model_name (object name)
+
+[data]
+# Optional arguments to split_data (see DataSplitConfig)
\ No newline at end of file
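Usage sketch (not part of the diff): the `[model]` section ties the TOML file to the model definition below, since `_get_model` imports the object named `model_name` from `model_file` if set, and from a `model.py` file next to the config otherwise; the custom path below is hypothetical.

    from _config import ModelConfig
    from run import _get_model

    # Default: load `MyModel` from `model.py` in the experiment folder.
    model = _get_model("examples/quickrun", ModelConfig())
    # Custom: load an object named `MyNet` from an arbitrary Python file.
    model = _get_model(
        "examples/quickrun",
        ModelConfig(model_file="path/to/my_net.py", model_name="MyNet"),
    )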
diff --git a/examples/quickrun/model.py b/examples/quickrun/model.py
index 211adbcd..8e5080bc 100644
--- a/examples/quickrun/model.py
+++ b/examples/quickrun/model.py
@@ -8,46 +8,45 @@ import torch.nn.functional as F
 from declearn.model.tensorflow import TensorflowModel
 from declearn.model.torch import TorchModel
 
-
-class Net(nn.Module):
-    def __init__(self):
-        super(Net, self).__init__()
-        self.conv1 = nn.Conv2d(1, 32, 3, 1)
-        self.conv2 = nn.Conv2d(32, 64, 3, 1)
-        self.dropout1 = nn.Dropout(0.25)
-        self.dropout2 = nn.Dropout(0.5)
-        self.fc1 = nn.Linear(9216, 128)
-        self.fc2 = nn.Linear(128, 10)
-
-    def forward(self, x):
-        x = torch.transpose(x, 3, 1)
-        x = self.conv1(x)
-        x = F.relu(x)
-        x = self.conv2(x)
-        x = F.relu(x)
-        x = F.max_pool2d(x, 2)
-        x = self.dropout1(x)
-        x = torch.flatten(x, 1)
-        x = self.fc1(x)
-        x = F.relu(x)
-        x = self.dropout2(x)
-        x = self.fc2(x)
-        output = F.log_softmax(x, dim=1)
-        return output
-
-
-MyModel = TorchModel(Net(), loss=nn.NLLLoss())
-
-# stack = [
-#     tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),
-#     tf.keras.layers.Conv2D(32, 3, 1, activation="relu"),
-#     tf.keras.layers.Conv2D(64, 3, 1, activation="relu"),
-#     tf.keras.layers.MaxPool2D(2),
-#     tf.keras.layers.Dropout(0.25),
-#     tf.keras.layers.Flatten(),
-#     tf.keras.layers.Dense(128, activation="relu"),
-#     tf.keras.layers.Dropout(0.5),
-#     tf.keras.layers.Dense(10, activation="softmax"),
-# ]
-# model = tf.keras.models.Sequential(stack)
-# MyModel = TensorflowModel(model, loss="sparse_categorical_crossentropy")
+# class Net(nn.Module):
+#     def __init__(self):
+#         super(Net, self).__init__()
+#         self.conv1 = nn.Conv2d(1, 32, 3, 1)
+#         self.conv2 = nn.Conv2d(32, 64, 3, 1)
+#         self.dropout1 = nn.Dropout(0.25)
+#         self.dropout2 = nn.Dropout(0.5)
+#         self.fc1 = nn.Linear(9216, 128)
+#         self.fc2 = nn.Linear(128, 10)
+
+#     def forward(self, x):
+#         x = torch.transpose(x, 3, 1)
+#         x = self.conv1(x)
+#         x = F.relu(x)
+#         x = self.conv2(x)
+#         x = F.relu(x)
+#         x = F.max_pool2d(x, 2)
+#         x = self.dropout1(x)
+#         x = torch.flatten(x, 1)
+#         x = self.fc1(x)
+#         x = F.relu(x)
+#         x = self.dropout2(x)
+#         x = self.fc2(x)
+#         output = F.log_softmax(x, dim=1)
+#         return output
+
+
+# MyModel = TorchModel(Net(), loss=nn.NLLLoss())
+
+stack = [
+    tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),
+    tf.keras.layers.Conv2D(32, 3, 1, activation="relu"),
+    tf.keras.layers.Conv2D(64, 3, 1, activation="relu"),
+    tf.keras.layers.MaxPool2D(2),
+    tf.keras.layers.Dropout(0.25),
+    tf.keras.layers.Flatten(),
+    tf.keras.layers.Dense(128, activation="relu"),
+    tf.keras.layers.Dropout(0.5),
+    tf.keras.layers.Dense(10, activation="softmax"),
+]
+model = tf.keras.models.Sequential(stack)
+MyModel = TensorflowModel(model, loss="sparse_categorical_crossentropy")
-- 
GitLab
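Closing note (not part of the diff): any model file exposing an object whose name matches `model_name` ("MyModel" by default) can stand in for examples/quickrun/model.py. A minimal hypothetical alternative, reusing the TorchModel wrapper shown commented out above:

    # my_model.py
    import torch.nn as nn

    from declearn.model.torch import TorchModel

    # A small MLP over flattened 28x28x1 MNIST inputs.
    net = nn.Sequential(
        nn.Flatten(),
        nn.Linear(784, 128),
        nn.ReLU(),
        nn.Linear(128, 10),
        nn.LogSoftmax(dim=1),
    )
    MyModel = TorchModel(net, loss=nn.NLLLoss())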