From dd17cd64d10bae0c6a75a49c698c7a3e1eba0d43 Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Thu, 27 Apr 2023 16:38:31 +0200 Subject: [PATCH] Add unit tests for 'TomlConfig' and revise its backend. --- declearn/utils/_toml_config.py | 67 +++--- test/utils/test_toml.py | 375 +++++++++++++++++++++++++++++++++ 2 files changed, 416 insertions(+), 26 deletions(-) create mode 100644 test/utils/test_toml.py diff --git a/declearn/utils/_toml_config.py b/declearn/utils/_toml_config.py index 012a2255..adc3c2b8 100644 --- a/declearn/utils/_toml_config.py +++ b/declearn/utils/_toml_config.py @@ -18,6 +18,7 @@ """Base class to define TOML-parsable configuration containers.""" import dataclasses +import os import typing import warnings @@ -99,7 +100,7 @@ def _isinstance_generic(inputs: Any, typevar: Type) -> bool: and all(_isinstance_generic(e, t) for e, t in zip(inputs, args)) ) # Unsupported cases. - raise TypeError( + raise TypeError( # pragma: no cover "Unsupported subscripted generic for instance check: " f"'{typevar}' with origin '{origin}'." ) @@ -112,26 +113,38 @@ def _parse_float(src: str) -> Optional[float]: def _instantiate_field( field: dataclasses.Field, # future: dataclasses.Field[T] (Py >=3.9) - *args: Any, **kwargs: Any, ) -> Any: # future: T """Instantiate a dataclass field from input args and kwargs. - This functions is meant to enable automatically building dataclass - fields that are annotated to be a union of types, notably optional - fields (i.e. Union[T, None]). + This function is meant to enable building dataclass object fields, + that requring instantiating from a received value or dict of kwargs. + + It supports doing so even when for fields that are annotated to be a + union of types (notably optional ones: Union[T, None]), and/or are a + `TomlConfig` subclass that should be instantiated via `from_params` + rather than `cls(*args, **kwargs)`. It will raise a TypeError if instantiation fails or if `field.type` has and unsupported typing origin. It may also raise any exception coming from the target type's `__init__` method. """ + + def _instantiate(cls: Type[Any]) -> Any: + """Try instantiating a given class from the kwargs.""" + if issubclass(cls, TomlConfig): + return cls.from_params(**kwargs) + return cls(**kwargs) + origin = typing.get_origin(field.type) - if origin is None: # raw type - return field.type(*args, **kwargs) - if origin is Union: # union of types, including optional + # Case of a raw type. + if origin is None: + return _instantiate(field.type) + # Case of a union of types (including optional). + if origin is Union: for cls in typing.get_args(field.type): try: - return cls(*args, **kwargs) + return _instantiate(cls) except TypeError: pass raise TypeError( @@ -219,12 +232,14 @@ class TomlConfig: for key in kwargs: warnings.warn( f"Unsupported keyword argument in {cls.__name__}.from_params: " - f"'{key}'. This argument was ignored." + f"'{key}'. This argument was ignored.", + category=RuntimeWarning, ) return cls(**fields) - @staticmethod + @classmethod def default_parser( + cls, field: dataclasses.Field, # future: dataclasses.Field[T] (Py >=3.9) inputs: Union[str, Dict[str, Any], T, None], ) -> Any: @@ -272,20 +287,19 @@ class TomlConfig: raise TypeError( f"Field '{field.name}' does not provide a default value." ) - # Case of str inputs: treat as the path to a TOML file to parse. - if isinstance(inputs, str): - # If the field implements TOML parsing, call it. - if issubclass(field.type, TomlConfig): - return field.type.from_toml(inputs) - # Otherwise, conduct minimal parsing. - with open(inputs, "rb") as file: - config = tomllib.load(file, parse_float=_parse_float) - section = config.get(field.name, config) # subsection or full file - return ( - _instantiate_field(field, **section) - if isinstance(section, dict) - else _instantiate_field(field, section) - ) + # Case of str inputs poiting to a file: try parsing it. + if isinstance(inputs, str) and os.path.isfile(inputs): + try: + with open(inputs, "rb") as file: + config = tomllib.load(file, parse_float=_parse_float) + except tomllib.TOMLDecodeError as exc: + raise TypeError( + f"Field {field.name}: could not parse secondary TOML file." + ) from exc + # Look for a subsection, otherwise keep the entire file. + section = config.get(field.name, config) + # Recursively call this parser. + return cls.default_parser(field, section) # Case of dict inputs: try instantiating the target type. if isinstance(inputs, dict): return _instantiate_field(field, **inputs) @@ -377,7 +391,8 @@ class TomlConfig: for name in config: warnings.warn( f"Unsupported section encountered in {path} TOML file: " - f"'{name}'. This section will be ignored." + f"'{name}'. This section will be ignored.", + category=RuntimeWarning, ) # Finally, instantiate the FLConfig container. return cls.from_params(**params) diff --git a/test/utils/test_toml.py b/test/utils/test_toml.py new file mode 100644 index 00000000..795dc0b1 --- /dev/null +++ b/test/utils/test_toml.py @@ -0,0 +1,375 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the `TomlConfig` util.""" + +import dataclasses +import os +import warnings +from typing import Any, Dict, List, Optional, Tuple, Union + +import pytest +from typing_extensions import Self + +from declearn.utils import TomlConfig + + +class Custom: + """Custom class that requires specific TOML parsing.""" + + def __init__( + self, + value: Any = "default", + _from_config: bool = False, + ) -> None: + """Default builder for Custom instances.""" + self.value = value + self.f_cfg = _from_config + + def __eq__(self, other: Any) -> bool: + if isinstance(other, Custom): + return self.value == other.value + return NotImplemented + + @classmethod + def from_config(cls, **kwargs: Any) -> Self: + """Alternative builder for Custom instances.""" + return cls(_from_config=True, **kwargs) + + +@dataclasses.dataclass +class DemoTomlConfig(TomlConfig): + """Demonstration TomlConfig subclass.""" + + req_int: int + req_lst: List[str] + opt_str: str = "default" + opt_tup: Optional[Tuple[int, int]] = None + opt_dct: Dict[str, float] = dataclasses.field(default_factory=dict) + opt_obj: Custom = dataclasses.field(default_factory=Custom) + opt_unc: Union[str, Custom, None] = None + + @classmethod + def parse_opt_tup( + cls, + field: dataclasses.Field, + inputs: Any, + ) -> Optional[Tuple[int, int]]: + """Custom parser for `opt_tup`, adding list-to-tuple conversion.""" + if isinstance(inputs, list): + inputs = tuple(inputs) + return cls.default_parser(field, inputs) + + +class TestTomlConfigDefaultParser: + """Unit tests for `TomlConfig.default_parser`, using a demo subclass.""" + + def test_int(self) -> None: + """Test that the parser works for an int field.""" + field = DemoTomlConfig.__dataclass_fields__["req_int"] + assert TomlConfig.default_parser(field, 42) == 42 + with pytest.raises(TypeError): + TomlConfig.default_parser(field, 42.0) + + def test_lst(self) -> None: + """Test that the parser works for a list of str field.""" + field = DemoTomlConfig.__dataclass_fields__["req_lst"] + # Test with a list of str. + value = ["this", "is", "a", "test"] + assert TomlConfig.default_parser(field, value) is value + # Test with an empty list. + value = [] + assert TomlConfig.default_parser(field, value) is value + # Test with a single str. + with pytest.raises(TypeError): + TomlConfig.default_parser(field, None) + # Test with a mixed-type list. + with pytest.raises(TypeError): + TomlConfig.default_parser(field, ["this", "fails", 0]) + + def test_opt_str(self) -> None: + """Test that the parser works for an optional str field.""" + field = DemoTomlConfig.__dataclass_fields__["opt_str"] + # Test without a value. + assert TomlConfig.default_parser(field, None) is field.default + # Test with a valid value. + assert TomlConfig.default_parser(field, "test") == "test" + # Test with an invalid value. + with pytest.raises(TypeError): + TomlConfig.default_parser(field, 0) + + def test_opt_tup(self) -> None: + """Test that the parser works for an optional tuple of int field.""" + field = DemoTomlConfig.__dataclass_fields__["opt_tup"] + # Test without a value. + assert TomlConfig.default_parser(field, None) is field.default + # Test with a valid value. + value = (12, 15) + assert TomlConfig.default_parser(field, value) is value + # Test with invalid values (wrong type, length or internal types). + with pytest.raises(TypeError): + TomlConfig.default_parser(field, [12, 15]) + with pytest.raises(TypeError): + TomlConfig.default_parser(field, (12, 15, 18)) + with pytest.raises(TypeError): + TomlConfig.default_parser(field, (12, "15")) + + def test_opt_dct(self) -> None: + """Test that the parser works for an optional dict field.""" + field = DemoTomlConfig.__dataclass_fields__["opt_dct"] + # Test without a value. + assert TomlConfig.default_parser(field, None) == {} + # Test with a valid value. + value = {"a": 0.0, "b": 1.0} + assert TomlConfig.default_parser(field, value) is value + # Test with invalid values (wrong type, key types or value types). + with pytest.raises(TypeError): + TomlConfig.default_parser(field, "test") + with pytest.raises(TypeError): + TomlConfig.default_parser(field, {0: 0.0}) # type: ignore + with pytest.raises(TypeError): + TomlConfig.default_parser(field, {"a": "val"}) + + def test_opt_obj(self) -> None: + """Test that the parser works for an optional custom object field.""" + field = DemoTomlConfig.__dataclass_fields__["opt_obj"] + # Test without a value. + assert TomlConfig.default_parser(field, None) == Custom() + # Test with a valid value. + value = Custom(value=42.0) + assert TomlConfig.default_parser(field, value) is value + # Test with kwargs for the object itself. + built = TomlConfig.default_parser(field, {"value": 0.0}) + assert isinstance(built, Custom) + assert built.value == 0.0 + # Test with an invalid value. + with pytest.raises(TypeError): + TomlConfig.default_parser(field, "invalid") + + def test_opt_unc(self) -> None: + """Test that the parser works for a union with custom types.""" + field = DemoTomlConfig.__dataclass_fields__["opt_unc"] + # Test without a value. + assert TomlConfig.default_parser(field, None) is None + # Test with a valid str value. + assert TomlConfig.default_parser(field, "unc") == "unc" + # Test with a valid Custom object.class TestTomlConfig: + value = Custom(value="test") + assert TomlConfig.default_parser(field, value) is value + # Test with kwargs for a Custom object. + built = TomlConfig.default_parser(field, {"value": "test"}) + assert isinstance(built, Custom) + assert built.value == "test" + # Test with an invalid value. + with pytest.raises(TypeError): + TomlConfig.default_parser(field, {"invalid": "kwarg"}) + + +class TestTomlConfigFromParams: + """Unit tests for `TomlConfig.from_params`, using a demo subclass.""" + + exhaustive_params = { + "req_int": 0, + "req_lst": ["test"], + "opt_str": "test", + "opt_tup": (0, 1), + "opt_dct": {"key": 0.0}, + "opt_obj": Custom(value=0), + "opt_unc": Custom(value=1), + } + + def test_all_params(self) -> None: + """Test that parsing from an exhaustive dict of valid params works.""" + parsed = DemoTomlConfig.from_params(**self.exhaustive_params) + assert isinstance(parsed, DemoTomlConfig) + for key, val in self.exhaustive_params.items(): + assert getattr(parsed, key) == val + + def test_partial_params(self) -> None: + """Test that parsing without some optional params works.""" + params = {"req_int": 0, "req_lst": ["test"]} + parsed = DemoTomlConfig.from_params(**params) + assert isinstance(parsed, DemoTomlConfig) + assert all(getattr(parsed, key) == val for key, val in params.items()) + + def test_bad_params(self) -> None: + """Test that parsing with some bad params fails.""" + # Missing required keys. + with pytest.raises(RuntimeError): + DemoTomlConfig.from_params(req_int=0, opt_str="test") + # Invalid value types. + with pytest.raises(RuntimeError): + DemoTomlConfig.from_params(req_int=0, req_lst=1) + + def test_extra_params(self) -> None: + """Test that provided with extra parameters raises a warning.""" + with pytest.warns(RuntimeWarning): + DemoTomlConfig.from_params(req_int=0, req_lst=["1"], extra=2) + + +class TestTomlConfigFromToml: + """Unit tests for `TomlConfig.from_toml`, using a demo subclass.""" + + exhaustive_toml = """ + req_int = 0 + req_lst = ["test"] + opt_str = "test" + opt_tup = [0, 1] + opt_dct = {key = 0.0} + opt_obj = {value = 0} + opt_unc = {value = 1} + """ + + def test_from_exhaustive_toml(self, tmp_path: str) -> None: + """Test that parsing from an exhaustive and valid TOML file works.""" + # Export the TOML config file. + path = os.path.join(tmp_path, "config.toml") + with open(path, "w", encoding="utf-8") as file: + file.write(self.exhaustive_toml) + # Parse it and verify the outputs. + parsed = DemoTomlConfig.from_toml(path) + assert isinstance(parsed, DemoTomlConfig) + for key, val in TestTomlConfigFromParams.exhaustive_params.items(): + assert getattr(parsed, key) == val + + def test_wrong_file_fails(self, tmp_path: str) -> None: + """Test that a proper error is raised when parsing an invalid file.""" + # Export the non-TOML file. + path = os.path.join(tmp_path, "config.toml") + with open(path, "w", encoding="utf-8") as file: + file.write("this is not a TOML file") + # Verify that the parsing error is properly caught and wrapped. + with pytest.raises(RuntimeError): + DemoTomlConfig.from_toml(path) + + def test_warn_user(self, tmp_path: str) -> None: + """Test that the 'warn_user' boolean flag works properly.""" + # Export the TOML config file, with an extra field. + path = os.path.join(tmp_path, "config.toml") + with open(path, "w", encoding="utf-8") as file: + file.write(self.exhaustive_toml + "\nunused = nan") + # Parse it, expecting a warning. + with pytest.warns(RuntimeWarning): + DemoTomlConfig.from_toml(path) + # Parse it again, expecting no warning. + with warnings.catch_warnings(): + warnings.simplefilter("error") + DemoTomlConfig.from_toml(path, warn_user=False) + + def test_use_section(self, tmp_path: str) -> None: + """Test that the 'use_section' option works properly.""" + # Export the TOML config file, within a wrapping section. + path = os.path.join(tmp_path, "config.toml") + with open(path, "w", encoding="utf-8") as file: + file.write("[mysection]\n" + self.exhaustive_toml) + # Verify that a basic attempt at parsing fails. + with pytest.raises(RuntimeError): + DemoTomlConfig.from_toml(path) + # Parse it and verify the outputs. + parsed = DemoTomlConfig.from_toml(path, use_section="mysection") + assert isinstance(parsed, DemoTomlConfig) + for key, val in TestTomlConfigFromParams.exhaustive_params.items(): + assert getattr(parsed, key) == val + + def test_use_section_fails(self, tmp_path: str) -> None: + """Test that the 'use_section' option fails properly.""" + # Export the TOML config file. + path = os.path.join(tmp_path, "config.toml") + with open(path, "w", encoding="utf-8") as file: + file.write(self.exhaustive_toml) + # Verify that targetting a non-existing section fails. + with pytest.raises(KeyError): + DemoTomlConfig.from_toml(path, use_section="mysection") + # Verify that the error can be skipped using `section_fails_ok`. + parsed = DemoTomlConfig.from_toml( + path, use_section="mysection", section_fail_ok=True + ) + assert isinstance(parsed, DemoTomlConfig) + for key, val in TestTomlConfigFromParams.exhaustive_params.items(): + assert getattr(parsed, key) == val + + +@dataclasses.dataclass +class ComplexTomlConfig(TomlConfig): + """Demonstration TomlConfig subclass with some nestedness.""" + + demo_a: DemoTomlConfig + demo_b: DemoTomlConfig + + +class TestTomlConfigNested: + """Unit test for a complex, nested 'TomlConfig' subclass.""" + + def test_multisection_config(self, tmp_path: str) -> None: + """Test parsing a multi-section TOML file using nested parsers.""" + # Export the multi-section TOML config file. + path = os.path.join(tmp_path, "config.toml") + toml = ( + f"[demo_a]\n{TestTomlConfigFromToml.exhaustive_toml}\n" + f"[demo_b]\n{TestTomlConfigFromToml.exhaustive_toml}\n" + ) + with open(path, "w", encoding="utf-8") as file: + file.write(toml) + # Verify that is can be properly parsed. + parsed = ComplexTomlConfig.from_toml(path) + assert isinstance(parsed, ComplexTomlConfig) + for fname in ("demo_a", "demo_b"): + field = getattr(parsed, fname) + assert isinstance(field, DemoTomlConfig) + for key, val in TestTomlConfigFromParams.exhaustive_params.items(): + assert getattr(field, key) == val + + def test_multifiles_config(self, tmp_path: str) -> None: + """Test parsing a multi-files TOML config using nested parsers.""" + # Export the multi-files TOML config. + # demo_a: section file + path_a = os.path.join(tmp_path, "demo_a.toml") + with open(path_a, "w", encoding="utf-8") as file: + file.write(f"[demo_a]\n{TestTomlConfigFromToml.exhaustive_toml}") + # demo_b: full-file + path_b = os.path.join(tmp_path, "demo_b.toml") + with open(path_b, "w", encoding="utf-8") as file: + file.write(TestTomlConfigFromToml.exhaustive_toml) + # main config file + path = os.path.join(tmp_path, "config.toml") + with open(path, "w", encoding="utf-8") as file: + file.write(f"demo_a = '{path_a}'\ndemo_b = '{path_b}'") + # Verify that is can be properly parsed. + parsed = ComplexTomlConfig.from_toml(path) + assert isinstance(parsed, ComplexTomlConfig) + for fname in ("demo_a", "demo_b"): + field = getattr(parsed, fname) + assert isinstance(field, DemoTomlConfig) + for key, val in TestTomlConfigFromParams.exhaustive_params.items(): + assert getattr(field, key) == val + + def test_multifiles_config_fails(self, tmp_path: str) -> None: + """Test parsing a multi-files TOML config with invalid contents.""" + # Export a badly-formatted secondary file. + path_bad = os.path.join(tmp_path, "bad.toml") + with open(path_bad, "w", encoding="utf-8") as file: + file.write("this file is not a TOML one") + # Export a main config file that points to the unproper one. + path = os.path.join(tmp_path, "config.toml") + with open(path, "w", encoding="utf-8") as file: + file.write(f"demo_a = '{path_bad}'\ndemo_b = '{path_bad}'") + # Verify the the parsing error is properly caught and wrapped. + with pytest.raises(RuntimeError): + ComplexTomlConfig.from_toml(path) + with pytest.raises(TypeError): + field = ComplexTomlConfig.__dataclass_fields__["demo_a"] + ComplexTomlConfig.default_parser(field, path_bad) -- GitLab