From 69699d65e34cc1098e8bb3e67db41bd262a69f82 Mon Sep 17 00:00:00 2001 From: BIGAUD Nathan <nathan.bigaud@inria.fr> Date: Thu, 30 Mar 2023 20:59:24 +0200 Subject: [PATCH] Simplyfing defualt config and model --- examples/quickrun/config.toml | 57 +++++++++++++++++------------------ examples/quickrun/model.py | 36 ++-------------------- 2 files changed, 30 insertions(+), 63 deletions(-) diff --git a/examples/quickrun/config.toml b/examples/quickrun/config.toml index a25f77f7..eaa59344 100644 --- a/examples/quickrun/config.toml +++ b/examples/quickrun/config.toml @@ -1,40 +1,39 @@ -[network] -protocol = "websockets" -host = "127.0.0.1" -port = 8765 +# This is a minimal TOML file for the MNIST example +# It contains the bare minimum to make the experiment run +# It can be used as a template for other experiments. +# See `examples/quickrun/readme.md` for more details. -[optim] -aggregator = "averaging" # The chosen aggregation strategy -server_opt = 1.0 # The server learning rate - [optim.client_opt] - lrate = 0.001 # The client learning rate - modules = ["adam"] # The optimzer modules +[network] # Network configuration used by both client and server +protocol = "websockets" # Protocol used, to keep things simple use websocket +host = "127.0.0.1" # Address used, works as is on most set ups +port = 8765 # Port used, works as is on most set ups -[run] -rounds = 10 # Number of training rounds +[model] # Information on where to find the model file - [run.register] - min_clients = 1 - max_clients = 6 - timeout = 5 +[data] # How to split your data +scheme = "iid" # SPlit your data iid between simulated clients - [run.training] - n_epoch = 1 # Number of local epochs - batch_size = 48 # Training batch size - drop_remainder = false # Whether to drop the last trainig examples +[optim] # Optimizers options for both client and server +aggregator = "averaging" # Server aggregation strategy - [run.evaluate] - batch_size = 128 # Evaluation batch size + [optim.client_opt] # Client optimization strategy + lrate = 0.001 # Client learning rate + modules = ["adam"] # List of optimizer modules used + +[run] # Training process option for both client and server +rounds = 10 # Number of overall training rounds + [run.register] # Client registration options + timeout = 5 # How long to wait for clients, in seconds -[experiment] -# all args for parse_data_folder + [run.training] # Client training procedure + batch_size = 48 # Training batch size + + [run.evaluate] + # batch_size = 128 # Evaluation batch size -[model] -# information on where to find the model file +[experiment] # What to report during the experiment and where to report it +metrics=[["multi-classif",{labels = [0,1,2,3,4,5,6,7,8,9]}]] -[data] -# all args from split_data argparser -scheme = "labels" \ No newline at end of file diff --git a/examples/quickrun/model.py b/examples/quickrun/model.py index 8e5080bc..0b5e53cc 100644 --- a/examples/quickrun/model.py +++ b/examples/quickrun/model.py @@ -1,41 +1,8 @@ -"""Wrapping a torch model""" +"""Wrapping a simple CNN for the MNIST example""" import tensorflow as tf -import torch -import torch.nn as nn -import torch.nn.functional as F from declearn.model.tensorflow import TensorflowModel -from declearn.model.torch import TorchModel - -# class Net(nn.Module): -# def __init__(self): -# super(Net, self).__init__() -# self.conv1 = nn.Conv2d(1, 32, 3, 1) -# self.conv2 = nn.Conv2d(32, 64, 3, 1) -# self.dropout1 = nn.Dropout(0.25) -# self.dropout2 = nn.Dropout(0.5) -# self.fc1 = nn.Linear(9216, 128) -# self.fc2 = nn.Linear(128, 10) - -# def forward(self, x): -# x = torch.transpose(x, 3, 1) -# x = self.conv1(x) -# x = F.relu(x) -# x = self.conv2(x) -# x = F.relu(x) -# x = F.max_pool2d(x, 2) -# x = self.dropout1(x) -# x = torch.flatten(x, 1) -# x = self.fc1(x) -# x = F.relu(x) -# x = self.dropout2(x) -# x = self.fc2(x) -# output = F.log_softmax(x, dim=1) -# return output - - -# MyModel = TorchModel(Net(), loss=nn.NLLLoss()) stack = [ tf.keras.layers.InputLayer(input_shape=(28, 28, 1)), @@ -49,4 +16,5 @@ stack = [ tf.keras.layers.Dense(10, activation="softmax"), ] model = tf.keras.models.Sequential(stack) + MyModel = TensorflowModel(model, loss="sparse_categorical_crossentropy") -- GitLab