From 23feb7d350d5eb7a01b33e00315dc3ac08ecdb92 Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Fri, 21 Jun 2024 14:07:55 +0200
Subject: [PATCH] Add 'autofill_fields' class attribute to 'TomlConfig'.

- Until now, 'TomlConfig' fields were either mandatory of had a
  fixed default value. In some cases, the latter may be refined
  based on other values; however, they could not be missing in
  TOML files.
- With the added 'autofield_fields' class attribute, some fields
  can be explicitly marked as having a dynamically-created default
  value. In that case, it can safely be ignored in TOML files.
- This is now used to support omitting 'evaluate' and 'fairness'
  fields when writing down a 'FLRunConfig' as a TOML file.
---
 declearn/main/config/_run_config.py |  2 ++
 declearn/utils/_toml_config.py      | 19 ++++++++---
 test/utils/test_toml.py             | 53 +++++++++++++++++++++++++++++
 3 files changed, 69 insertions(+), 5 deletions(-)

diff --git a/declearn/main/config/_run_config.py b/declearn/main/config/_run_config.py
index 25b93597..57076113 100644
--- a/declearn/main/config/_run_config.py
+++ b/declearn/main/config/_run_config.py
@@ -107,6 +107,8 @@ class FLRunConfig(TomlConfig):
     privacy: Optional[PrivacyConfig] = None
     early_stop: Optional[EarlyStopConfig] = None  # type: ignore  # is a type
 
+    autofill_fields = {"evaluate", "fairness"}
+
     @classmethod
     def parse_register(
         cls,
diff --git a/declearn/utils/_toml_config.py b/declearn/utils/_toml_config.py
index adc3c2b8..a3c88602 100644
--- a/declearn/utils/_toml_config.py
+++ b/declearn/utils/_toml_config.py
@@ -27,7 +27,7 @@ try:
 except ModuleNotFoundError:
     import tomli as tomllib
 
-from typing import Any, Dict, Optional, Type, TypeVar, Union
+from typing import Any, ClassVar, Dict, Optional, Set, Type, TypeVar, Union
 
 from typing_extensions import Self  # future: import from typing (py >=3.11)
 
@@ -178,6 +178,14 @@ class TomlConfig:
         Instantiate by parsing inputs dicts (or objects).
     """
 
+    autofill_fields: ClassVar[Set[str]] = set()
+    """Class attribute listing names of auto-fill fields.
+
+    The listed fields do not have a formal default value, but one is
+    dynamically created upon parsing other fields. As a consequence,
+    they may safely been ignored in TOML files or input dict params.
+    """
+
     @classmethod
     def from_params(
         cls,
@@ -334,8 +342,8 @@ class TomlConfig:
             hyper-parameters making up for the FL "run" configuration.
         warn_user: bool, default=True
             Boolean indicating whether to raise a warning when some
-            fields are unused. Useful for cases where unused fields are
-            expected, e.g. in declearn-quickrun mode.
+            fields are unused. Useful for cases where unused fields
+            are expected, e.g. in declearn-quickrun mode.
         use_section: optional(str), default=None
             If not None, points to a specific section of the TOML that
             should be used, rather than the whole file. Useful to parse
@@ -381,10 +389,11 @@ class TomlConfig:
             elif (
                 field.default is dataclasses.MISSING
                 and field.default_factory is dataclasses.MISSING
+                and field.name not in cls.autofill_fields
             ):
                 raise RuntimeError(
-                    "Missing required section in the TOML configuration "
-                    f"file: '{field.name}'."
+                    "Missing section in the TOML configuration file: "
+                    f"'{field.name}'.",
                 )
         # Warn about remaining (unused) config sections.
         if warn_user:
diff --git a/test/utils/test_toml.py b/test/utils/test_toml.py
index ed8a04d9..1b0567d1 100644
--- a/test/utils/test_toml.py
+++ b/test/utils/test_toml.py
@@ -381,3 +381,56 @@ class TestTomlConfigNested:
         }["demo_a"]
         with pytest.raises(TypeError):
             ComplexTomlConfig.default_parser(field, path_bad)
+
+
+@dataclasses.dataclass
+class AutofillTomlConfig(TomlConfig):
+    """Demonstration TomlConfig subclass with an autofill field."""
+
+    base: int
+    auto: int
+
+    autofill_fields = {"auto"}
+
+    @classmethod
+    def from_params(
+        cls,
+        **kwargs: Any,
+    ) -> Self:
+        if "base" in kwargs:
+            kwargs.setdefault("auto", kwargs["base"])
+        return super().from_params(**kwargs)
+
+
+class TestTomlAutofill:
+    """Unit tests for a 'TomlConfig' subclass with an auto-fill field."""
+
+    def test_from_params_exhaustive(self) -> None:
+        """Test parsing kwargs with exhaustive values."""
+        config = AutofillTomlConfig.from_params(base=0, auto=1)
+        assert config.base == 0  # false-positive; pylint: disable=no-member
+        assert config.auto == 1  # false-positive; pylint: disable=no-member
+
+    def test_from_params_autofill(self) -> None:
+        """Test parsing kwargs without the auto-filled value."""
+        config = AutofillTomlConfig.from_params(base=0)
+        assert config.base == 0  # false-positive; pylint: disable=no-member
+        assert config.auto == 0  # false-positive; pylint: disable=no-member
+
+    def test_from_toml_exhaustive(self, tmp_path: str) -> None:
+        """Test parsing a TOML file with exhaustive values."""
+        path = os.path.join(tmp_path, "config.toml")
+        with open(path, "w", encoding="utf-8") as file:
+            file.write("base = 0\nauto = 1")
+        config = AutofillTomlConfig.from_toml(path)
+        assert config.base == 0  # false-positive; pylint: disable=no-member
+        assert config.auto == 1  # false-positive; pylint: disable=no-member
+
+    def test_from_toml_autofill(self, tmp_path: str) -> None:
+        """Test parsing a TOML file without the auto-filled value."""
+        path = os.path.join(tmp_path, "config.toml")
+        with open(path, "w", encoding="utf-8") as file:
+            file.write("base = 0")
+        config = AutofillTomlConfig.from_toml(path)
+        assert config.base == 0  # false-positive; pylint: disable=no-member
+        assert config.auto == 0  # false-positive; pylint: disable=no-member
-- 
GitLab