diff --git a/declearn/main/config/_run_config.py b/declearn/main/config/_run_config.py index 25b9359731491a77900a8b88fc3483818e4d4165..570761133fbb8aec1b3fca42daa4bdaefd033311 100644 --- a/declearn/main/config/_run_config.py +++ b/declearn/main/config/_run_config.py @@ -107,6 +107,8 @@ class FLRunConfig(TomlConfig): privacy: Optional[PrivacyConfig] = None early_stop: Optional[EarlyStopConfig] = None # type: ignore # is a type + autofill_fields = {"evaluate", "fairness"} + @classmethod def parse_register( cls, diff --git a/declearn/utils/_toml_config.py b/declearn/utils/_toml_config.py index adc3c2b87cd6fbf0825b37ff35e4905f032e8b98..a3c886021231c37d03934467d5797008e898d242 100644 --- a/declearn/utils/_toml_config.py +++ b/declearn/utils/_toml_config.py @@ -27,7 +27,7 @@ try: except ModuleNotFoundError: import tomli as tomllib -from typing import Any, Dict, Optional, Type, TypeVar, Union +from typing import Any, ClassVar, Dict, Optional, Set, Type, TypeVar, Union from typing_extensions import Self # future: import from typing (py >=3.11) @@ -178,6 +178,14 @@ class TomlConfig: Instantiate by parsing inputs dicts (or objects). """ + autofill_fields: ClassVar[Set[str]] = set() + """Class attribute listing names of auto-fill fields. + + The listed fields do not have a formal default value, but one is + dynamically created upon parsing other fields. As a consequence, + they may safely been ignored in TOML files or input dict params. + """ + @classmethod def from_params( cls, @@ -334,8 +342,8 @@ class TomlConfig: hyper-parameters making up for the FL "run" configuration. warn_user: bool, default=True Boolean indicating whether to raise a warning when some - fields are unused. Useful for cases where unused fields are - expected, e.g. in declearn-quickrun mode. + fields are unused. Useful for cases where unused fields + are expected, e.g. in declearn-quickrun mode. use_section: optional(str), default=None If not None, points to a specific section of the TOML that should be used, rather than the whole file. Useful to parse @@ -381,10 +389,11 @@ class TomlConfig: elif ( field.default is dataclasses.MISSING and field.default_factory is dataclasses.MISSING + and field.name not in cls.autofill_fields ): raise RuntimeError( - "Missing required section in the TOML configuration " - f"file: '{field.name}'." + "Missing section in the TOML configuration file: " + f"'{field.name}'.", ) # Warn about remaining (unused) config sections. if warn_user: diff --git a/test/utils/test_toml.py b/test/utils/test_toml.py index ed8a04d981648591e1c9a30536daae04cd2af2c5..1b0567d11691468c8933d23c175b98c4f94b0ef7 100644 --- a/test/utils/test_toml.py +++ b/test/utils/test_toml.py @@ -381,3 +381,56 @@ class TestTomlConfigNested: }["demo_a"] with pytest.raises(TypeError): ComplexTomlConfig.default_parser(field, path_bad) + + +@dataclasses.dataclass +class AutofillTomlConfig(TomlConfig): + """Demonstration TomlConfig subclass with an autofill field.""" + + base: int + auto: int + + autofill_fields = {"auto"} + + @classmethod + def from_params( + cls, + **kwargs: Any, + ) -> Self: + if "base" in kwargs: + kwargs.setdefault("auto", kwargs["base"]) + return super().from_params(**kwargs) + + +class TestTomlAutofill: + """Unit tests for a 'TomlConfig' subclass with an auto-fill field.""" + + def test_from_params_exhaustive(self) -> None: + """Test parsing kwargs with exhaustive values.""" + config = AutofillTomlConfig.from_params(base=0, auto=1) + assert config.base == 0 # false-positive; pylint: disable=no-member + assert config.auto == 1 # false-positive; pylint: disable=no-member + + def test_from_params_autofill(self) -> None: + """Test parsing kwargs without the auto-filled value.""" + config = AutofillTomlConfig.from_params(base=0) + assert config.base == 0 # false-positive; pylint: disable=no-member + assert config.auto == 0 # false-positive; pylint: disable=no-member + + def test_from_toml_exhaustive(self, tmp_path: str) -> None: + """Test parsing a TOML file with exhaustive values.""" + path = os.path.join(tmp_path, "config.toml") + with open(path, "w", encoding="utf-8") as file: + file.write("base = 0\nauto = 1") + config = AutofillTomlConfig.from_toml(path) + assert config.base == 0 # false-positive; pylint: disable=no-member + assert config.auto == 1 # false-positive; pylint: disable=no-member + + def test_from_toml_autofill(self, tmp_path: str) -> None: + """Test parsing a TOML file without the auto-filled value.""" + path = os.path.join(tmp_path, "config.toml") + with open(path, "w", encoding="utf-8") as file: + file.write("base = 0") + config = AutofillTomlConfig.from_toml(path) + assert config.base == 0 # false-positive; pylint: disable=no-member + assert config.auto == 0 # false-positive; pylint: disable=no-member