diff --git a/test/functional/test_main.py b/test/functional/test_main.py index 491ffd95d3628ec6408e13b31b23cc87c06315e2..40500dbd1cc08991b1b7b3a29badb2814771fe56 100644 --- a/test/functional/test_main.py +++ b/test/functional/test_main.py @@ -106,7 +106,12 @@ class DeclearnTestCase: """Return a TensorflowModel suitable for the learning task.""" if self.kind == "Reg": output_layer = tf_keras.layers.Dense(1) - loss = "mse" + if hasattr(tf_keras, "version") and tf_keras.version().startswith( + "3" + ): + loss = "MeanSquaredError" + else: + loss = "mse" elif self.kind == "Bin": output_layer = tf_keras.layers.Dense(1, activation="sigmoid") loss = "binary_crossentropy" @@ -204,14 +209,19 @@ class DeclearnTestCase: """Return parameters to instantiate a FLOptimConfig.""" client_modules = [] server_modules = [] + # Optionally use Scaffold and/or sever-side momentum. if self.strategy == "Scaffold": client_modules.append("scaffold-client") server_modules.append("scaffold-server") if self.strategy in ("FedAvgM", "ScaffoldM"): server_modules.append("momentum") + # Use a warmup over 100 steps followed by exponential decay. + exp_decay = ("exponential-decay", {"base": 0.01, "rate": 0.8}) + cli_lrate = ("warmup", {"base": exp_decay, "warmup": 10}) + # Return the federated optimization configuration. return { "aggregator": "averaging", - "client_opt": {"lrate": 0.01, "modules": client_modules}, + "client_opt": {"lrate": cli_lrate, "modules": client_modules}, "server_opt": {"lrate": 1.0, "modules": server_modules}, }