From 29b30d48e0248085b45c6351e9c1eddef6b8768d Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Mon, 27 May 2024 11:51:37 +0200
Subject: [PATCH] Use a (complex) scheduler in 'test_main.py'.

---
 test/functional/test_main.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/test/functional/test_main.py b/test/functional/test_main.py
index 491ffd95..40500dbd 100644
--- a/test/functional/test_main.py
+++ b/test/functional/test_main.py
@@ -106,7 +106,12 @@ class DeclearnTestCase:
         """Return a TensorflowModel suitable for the learning task."""
         if self.kind == "Reg":
             output_layer = tf_keras.layers.Dense(1)
-            loss = "mse"
+            if hasattr(tf_keras, "version") and tf_keras.version().startswith(
+                "3"
+            ):
+                loss = "MeanSquaredError"
+            else:
+                loss = "mse"
         elif self.kind == "Bin":
             output_layer = tf_keras.layers.Dense(1, activation="sigmoid")
             loss = "binary_crossentropy"
@@ -204,14 +209,19 @@ class DeclearnTestCase:
         """Return parameters to instantiate a FLOptimConfig."""
         client_modules = []
         server_modules = []
+        # Optionally use Scaffold and/or sever-side momentum.
         if self.strategy == "Scaffold":
             client_modules.append("scaffold-client")
             server_modules.append("scaffold-server")
         if self.strategy in ("FedAvgM", "ScaffoldM"):
             server_modules.append("momentum")
+        # Use a warmup over 100 steps followed by exponential decay.
+        exp_decay = ("exponential-decay", {"base": 0.01, "rate": 0.8})
+        cli_lrate = ("warmup", {"base": exp_decay, "warmup": 10})
+        # Return the federated optimization configuration.
         return {
             "aggregator": "averaging",
-            "client_opt": {"lrate": 0.01, "modules": client_modules},
+            "client_opt": {"lrate": cli_lrate, "modules": client_modules},
             "server_opt": {"lrate": 1.0, "modules": server_modules},
         }
 
-- 
GitLab