From 5847d21115b25391f52812da2d909cfc99d7cd52 Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Tue, 23 Jul 2024 11:22:15 +0200 Subject: [PATCH] Add unit tests for 'FLOptimConfig'. --- test/main/test_config_optim.py | 256 +++++++++++++++++++++++++++++++++ 1 file changed, 256 insertions(+) create mode 100644 test/main/test_config_optim.py diff --git a/test/main/test_config_optim.py b/test/main/test_config_optim.py new file mode 100644 index 0000000..8c030a4 --- /dev/null +++ b/test/main/test_config_optim.py @@ -0,0 +1,256 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for 'declearn.main.config.FLOptimConfig'.""" + +import dataclasses +import os +from unittest import mock + +import pytest + +from declearn.aggregator import Aggregator, AveragingAggregator, SumAggregator +from declearn.fairness.api import FairnessControllerServer +from declearn.fairness.fairgrad import FairgradControllerServer +from declearn.main.config import FLOptimConfig +from declearn.optimizer import Optimizer +from declearn.optimizer.modules import AdamModule + + +FIELDS = {field.name: field for field in dataclasses.fields(FLOptimConfig)} + + +class TestFLOptimConfig: + """Unit tests for 'declearn.main.config.FLOptimConfig'.""" + + # unit tests; pylint: disable=too-many-public-methods + + # Client-side optimizer. + + def test_parse_client_opt_float(self) -> None: + """Test parsing 'client_opt' from a float input.""" + field = FIELDS["client_opt"] + optim = FLOptimConfig.parse_client_opt(field, 0.1) + assert isinstance(optim, Optimizer) + assert optim.lrate == 0.1 + assert optim.w_decay == 0.0 + assert not optim.modules + assert not optim.regularizers + + def test_parse_client_opt_dict(self) -> None: + """Test parsing 'client_opt' from a dict input.""" + field = FIELDS["client_opt"] + config = {"lrate": 0.1, "modules": ["adam"]} + optim = FLOptimConfig.parse_client_opt(field, config) + assert isinstance(optim, Optimizer) + assert optim.lrate == 0.1 + assert optim.w_decay == 0.0 + assert len(optim.modules) == 1 + assert isinstance(optim.modules[0], AdamModule) + assert not optim.regularizers + + def test_parse_client_opt_dict_error(self) -> None: + """Test parsing 'client_opt' from an invalid dict input.""" + field = FIELDS["client_opt"] + config = {"modules": ["adam"]} # missing 'lrate' + with pytest.raises(TypeError): + FLOptimConfig.parse_client_opt(field, config) + + def test_parse_client_opt_optimizer(self) -> None: + """Test parsing 'client_opt' from an Optimizer input.""" + field = FIELDS["client_opt"] + optim = mock.create_autospec(Optimizer, instance=True) + assert FLOptimConfig.parse_client_opt(field, optim) is optim + + def test_parse_client_opt_error(self) -> None: + """Test parsing 'client_opt' from an invalid-type input.""" + field = FIELDS["client_opt"] + with pytest.raises(TypeError): + FLOptimConfig.parse_client_opt(field, mock.MagicMock()) + + # Server-side optimizer. + # pylint: disable=duplicate-code + + def test_parse_server_opt_none(self) -> None: + """Test parsing 'server_opt' from None.""" + field = FIELDS["server_opt"] + optim = FLOptimConfig.parse_server_opt(field, None) + assert isinstance(optim, Optimizer) + assert optim.lrate == 1.0 + assert optim.w_decay == 0.0 + assert not optim.modules + assert not optim.regularizers + + def test_parse_server_opt_float(self) -> None: + """Test parsing 'server_opt' from a float input.""" + field = FIELDS["server_opt"] + optim = FLOptimConfig.parse_server_opt(field, 0.1) + assert isinstance(optim, Optimizer) + assert optim.lrate == 0.1 + assert optim.w_decay == 0.0 + assert not optim.modules + assert not optim.regularizers + + def test_parse_server_opt_dict(self) -> None: + """Test parsing 'server_opt' from a dict input.""" + field = FIELDS["server_opt"] + config = {"lrate": 0.1, "modules": ["adam"]} + optim = FLOptimConfig.parse_server_opt(field, config) + assert isinstance(optim, Optimizer) + assert optim.lrate == 0.1 + assert optim.w_decay == 0.0 + assert len(optim.modules) == 1 + assert isinstance(optim.modules[0], AdamModule) + assert not optim.regularizers + + def test_parse_server_opt_dict_error(self) -> None: + """Test parsing 'server_opt' from an invalid dict input.""" + field = FIELDS["server_opt"] + config = {"modules": ["adam"]} # missing 'lrate' + with pytest.raises(TypeError): + FLOptimConfig.parse_server_opt(field, config) + + def test_parse_server_opt_optimizer(self) -> None: + """Test parsing 'server_opt' from an Optimizer input.""" + field = FIELDS["server_opt"] + optim = mock.create_autospec(Optimizer, instance=True) + assert FLOptimConfig.parse_server_opt(field, optim) is optim + + def test_parse_server_opt_error(self) -> None: + """Test parsing 'server_opt' from an invalid-type input.""" + field = FIELDS["server_opt"] + with pytest.raises(TypeError): + FLOptimConfig.parse_server_opt(field, mock.MagicMock()) + + # pylint: enable=duplicate-code + # Aggregator. + + def test_parse_aggregator_none(self) -> None: + """Test parsing 'aggregator' from None.""" + field = FIELDS["aggregator"] + aggregator = FLOptimConfig.parse_aggregator(field, None) + assert isinstance(aggregator, AveragingAggregator) + + def test_parse_aggregator_str(self) -> None: + """Test parsing 'aggregator' from a string.""" + field = FIELDS["aggregator"] + aggregator = FLOptimConfig.parse_aggregator(field, "sum") + assert isinstance(aggregator, SumAggregator) + + def test_parse_aggregator_dict(self) -> None: + """Test parsing 'aggregator' from a dict.""" + field = FIELDS["aggregator"] + config = {"name": "averaging", "config": {"steps_weighted": False}} + aggregator = FLOptimConfig.parse_aggregator(field, config) + assert isinstance(aggregator, AveragingAggregator) + assert not aggregator.steps_weighted + + def test_parse_aggregator_dict_error(self) -> None: + """Test parsing 'aggregator' from an invalid dict.""" + field = FIELDS["aggregator"] + config = {"name": "adam", "group": "OptiModule"} # wrong target type + with pytest.raises(TypeError): + FLOptimConfig.parse_aggregator(field, config) + + def test_parse_aggregator_aggregator(self) -> None: + """Test parsing 'aggregator' from an Aggregator.""" + field = FIELDS["aggregator"] + aggregator = mock.create_autospec(Aggregator, instance=True) + assert FLOptimConfig.parse_aggregator(field, aggregator) is aggregator + + def test_parse_aggregator_error(self) -> None: + """Test parsing 'aggregator' from an invalid-type input.""" + field = FIELDS["aggregator"] + with pytest.raises(TypeError): + FLOptimConfig.parse_aggregator(field, mock.MagicMock()) + + # Fairness. + + def test_parse_fairness_none(self) -> None: + """Test parsing 'fairness' from None.""" + field = FIELDS["fairness"] + fairness = FLOptimConfig.parse_fairness(field, None) + assert fairness is None + + def test_parse_fairness_dict(self) -> None: + """Test parsing 'fairness' from a dict.""" + field = FIELDS["fairness"] + config = { + "algorithm": "fedfairgrad", + "f_type": "demographic_parity", + "eta": 0.1, + "eps": 0.0, + } + fairness = FLOptimConfig.parse_fairness(field, config) + assert isinstance(fairness, FairgradControllerServer) + assert fairness.f_type == "demographic_parity" + assert fairness.weights_controller.eta == 0.1 + assert fairness.weights_controller.eps == 0.0 + + def test_parse_fairness_dict_error(self) -> None: + """Test parsing 'fairness' from an invalid dict.""" + field = FIELDS["fairness"] + config = {"algorithm": "fedfairgrad"} # missing f_type choice + with pytest.raises(TypeError): + FLOptimConfig.parse_fairness(field, config) + + def test_parse_fairness_controller(self) -> None: + """Test parsing 'fairness' from a FairnessControllerServer.""" + field = FIELDS["fairness"] + fairness = mock.create_autospec( + FairnessControllerServer, instance=True + ) + assert FLOptimConfig.parse_fairness(field, fairness) is fairness + + # Functional test. + + def test_from_toml(self, tmp_path: str) -> None: + """Test parsing an arbitrary, complex TOML file.""" + # Set up an arbitrary TOML file parseabld into an FLOptimConfig. + path = os.path.join(tmp_path, "config.toml") + toml_config = """ + [optim] + aggregator = "sum" + client_opt = 0.001 + [optim.server_opt] + lrate = 1.0 + modules = [["adam", {beta_1=0.8, beta_2=0.9}]] + [optim.fairness] + algorithm = "fedfairgrad" + f_type = "equalized_odds" + eta = 0.1 + eps = 0.0 + """ + with open(path, "w", encoding="utf-8") as file: + file.write(toml_config) + # Parse the TOML file and verify that outputs match expectations. + optim = FLOptimConfig.from_toml(path, use_section="optim") + assert isinstance(optim, FLOptimConfig) + assert isinstance(optim.aggregator, SumAggregator) + assert isinstance(optim.client_opt, Optimizer) + assert optim.client_opt.lrate == 0.001 + assert not optim.client_opt.modules + assert isinstance(optim.server_opt, Optimizer) + assert optim.server_opt.lrate == 1.0 + assert len(optim.server_opt.modules) == 1 + assert isinstance(optim.server_opt.modules[0], AdamModule) + assert optim.server_opt.modules[0].ewma_1.beta == 0.8 + assert optim.server_opt.modules[0].ewma_2.beta == 0.9 + assert isinstance(optim.fairness, FairgradControllerServer) + assert optim.fairness.f_type == "equalized_odds" + assert optim.fairness.weights_controller.eta == 0.1 + assert optim.fairness.weights_controller.eps == 0.0 -- GitLab