From 23feb7d350d5eb7a01b33e00315dc3ac08ecdb92 Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Fri, 21 Jun 2024 14:07:55 +0200 Subject: [PATCH] Add 'autofill_fields' class attribute to 'TomlConfig'. - Until now, 'TomlConfig' fields were either mandatory of had a fixed default value. In some cases, the latter may be refined based on other values; however, they could not be missing in TOML files. - With the added 'autofield_fields' class attribute, some fields can be explicitly marked as having a dynamically-created default value. In that case, it can safely be ignored in TOML files. - This is now used to support omitting 'evaluate' and 'fairness' fields when writing down a 'FLRunConfig' as a TOML file. --- declearn/main/config/_run_config.py | 2 ++ declearn/utils/_toml_config.py | 19 ++++++++--- test/utils/test_toml.py | 53 +++++++++++++++++++++++++++++ 3 files changed, 69 insertions(+), 5 deletions(-) diff --git a/declearn/main/config/_run_config.py b/declearn/main/config/_run_config.py index 25b93597..57076113 100644 --- a/declearn/main/config/_run_config.py +++ b/declearn/main/config/_run_config.py @@ -107,6 +107,8 @@ class FLRunConfig(TomlConfig): privacy: Optional[PrivacyConfig] = None early_stop: Optional[EarlyStopConfig] = None # type: ignore # is a type + autofill_fields = {"evaluate", "fairness"} + @classmethod def parse_register( cls, diff --git a/declearn/utils/_toml_config.py b/declearn/utils/_toml_config.py index adc3c2b8..a3c88602 100644 --- a/declearn/utils/_toml_config.py +++ b/declearn/utils/_toml_config.py @@ -27,7 +27,7 @@ try: except ModuleNotFoundError: import tomli as tomllib -from typing import Any, Dict, Optional, Type, TypeVar, Union +from typing import Any, ClassVar, Dict, Optional, Set, Type, TypeVar, Union from typing_extensions import Self # future: import from typing (py >=3.11) @@ -178,6 +178,14 @@ class TomlConfig: Instantiate by parsing inputs dicts (or objects). """ + autofill_fields: ClassVar[Set[str]] = set() + """Class attribute listing names of auto-fill fields. + + The listed fields do not have a formal default value, but one is + dynamically created upon parsing other fields. As a consequence, + they may safely been ignored in TOML files or input dict params. + """ + @classmethod def from_params( cls, @@ -334,8 +342,8 @@ class TomlConfig: hyper-parameters making up for the FL "run" configuration. warn_user: bool, default=True Boolean indicating whether to raise a warning when some - fields are unused. Useful for cases where unused fields are - expected, e.g. in declearn-quickrun mode. + fields are unused. Useful for cases where unused fields + are expected, e.g. in declearn-quickrun mode. use_section: optional(str), default=None If not None, points to a specific section of the TOML that should be used, rather than the whole file. Useful to parse @@ -381,10 +389,11 @@ class TomlConfig: elif ( field.default is dataclasses.MISSING and field.default_factory is dataclasses.MISSING + and field.name not in cls.autofill_fields ): raise RuntimeError( - "Missing required section in the TOML configuration " - f"file: '{field.name}'." + "Missing section in the TOML configuration file: " + f"'{field.name}'.", ) # Warn about remaining (unused) config sections. if warn_user: diff --git a/test/utils/test_toml.py b/test/utils/test_toml.py index ed8a04d9..1b0567d1 100644 --- a/test/utils/test_toml.py +++ b/test/utils/test_toml.py @@ -381,3 +381,56 @@ class TestTomlConfigNested: }["demo_a"] with pytest.raises(TypeError): ComplexTomlConfig.default_parser(field, path_bad) + + +@dataclasses.dataclass +class AutofillTomlConfig(TomlConfig): + """Demonstration TomlConfig subclass with an autofill field.""" + + base: int + auto: int + + autofill_fields = {"auto"} + + @classmethod + def from_params( + cls, + **kwargs: Any, + ) -> Self: + if "base" in kwargs: + kwargs.setdefault("auto", kwargs["base"]) + return super().from_params(**kwargs) + + +class TestTomlAutofill: + """Unit tests for a 'TomlConfig' subclass with an auto-fill field.""" + + def test_from_params_exhaustive(self) -> None: + """Test parsing kwargs with exhaustive values.""" + config = AutofillTomlConfig.from_params(base=0, auto=1) + assert config.base == 0 # false-positive; pylint: disable=no-member + assert config.auto == 1 # false-positive; pylint: disable=no-member + + def test_from_params_autofill(self) -> None: + """Test parsing kwargs without the auto-filled value.""" + config = AutofillTomlConfig.from_params(base=0) + assert config.base == 0 # false-positive; pylint: disable=no-member + assert config.auto == 0 # false-positive; pylint: disable=no-member + + def test_from_toml_exhaustive(self, tmp_path: str) -> None: + """Test parsing a TOML file with exhaustive values.""" + path = os.path.join(tmp_path, "config.toml") + with open(path, "w", encoding="utf-8") as file: + file.write("base = 0\nauto = 1") + config = AutofillTomlConfig.from_toml(path) + assert config.base == 0 # false-positive; pylint: disable=no-member + assert config.auto == 1 # false-positive; pylint: disable=no-member + + def test_from_toml_autofill(self, tmp_path: str) -> None: + """Test parsing a TOML file without the auto-filled value.""" + path = os.path.join(tmp_path, "config.toml") + with open(path, "w", encoding="utf-8") as file: + file.write("base = 0") + config = AutofillTomlConfig.from_toml(path) + assert config.base == 0 # false-positive; pylint: disable=no-member + assert config.auto == 0 # false-positive; pylint: disable=no-member -- GitLab