Mentions légales du service

Skip to content
Snippets Groups Projects
Verified Commit 29b30d48 authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Use a (complex) scheduler in 'test_main.py'.

parent 521affb4
No related branches found
No related tags found
1 merge request!66Add a Scheduler API to enable time-based learning rate (and weight decay) adjustments
Pipeline #982202 passed
......@@ -106,7 +106,12 @@ class DeclearnTestCase:
"""Return a TensorflowModel suitable for the learning task."""
if self.kind == "Reg":
output_layer = tf_keras.layers.Dense(1)
loss = "mse"
if hasattr(tf_keras, "version") and tf_keras.version().startswith(
"3"
):
loss = "MeanSquaredError"
else:
loss = "mse"
elif self.kind == "Bin":
output_layer = tf_keras.layers.Dense(1, activation="sigmoid")
loss = "binary_crossentropy"
......@@ -204,14 +209,19 @@ class DeclearnTestCase:
"""Return parameters to instantiate a FLOptimConfig."""
client_modules = []
server_modules = []
        # Optionally use Scaffold and/or server-side momentum.
if self.strategy == "Scaffold":
client_modules.append("scaffold-client")
server_modules.append("scaffold-server")
if self.strategy in ("FedAvgM", "ScaffoldM"):
server_modules.append("momentum")
        # Use a warmup over 10 steps followed by exponential decay.
exp_decay = ("exponential-decay", {"base": 0.01, "rate": 0.8})
cli_lrate = ("warmup", {"base": exp_decay, "warmup": 10})
# Return the federated optimization configuration.
return {
"aggregator": "averaging",
"client_opt": {"lrate": 0.01, "modules": client_modules},
"client_opt": {"lrate": cli_lrate, "modules": client_modules},
"server_opt": {"lrate": 1.0, "modules": server_modules},
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment