diff --git a/declearn/quickrun/_config.py b/declearn/quickrun/_config.py index c2559da5dd94a5397b5dae53954ba248abeceb9e..fc30685f4d724b29a2657378bdb3e2ac4888fb8f 100644 --- a/declearn/quickrun/_config.py +++ b/declearn/quickrun/_config.py @@ -18,8 +18,9 @@ """TOML-parsable container for quickrun configurations.""" import dataclasses -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union +from declearn.metrics import MetricInputType, MetricSet from declearn.utils import TomlConfig __all__ = [ @@ -108,5 +109,27 @@ class ExperimentConfig(TomlConfig): to be used so as to save round-wise model """ - metrics: Optional[List[str]] = None + metrics: Optional[MetricSet] = None checkpoint: Optional[str] = None + + def parse_metrics( + self, + inputs: Union[MetricSet, Dict[str, Any], List[MetricInputType], None], + ) -> Optional[MetricSet]: + """Parser for metrics.""" + if inputs is None or isinstance(inputs, MetricSet): + return None + try: + # Case of a manual listing of metrics (most expected). + if isinstance(inputs, (tuple, list)): + return MetricSet.from_specs(inputs) + # Case of a MetricSet config dict (unexpected but supported). + if isinstance(inputs, dict): + return MetricSet.from_config(inputs) + except (TypeError, ValueError) as exc: + raise TypeError( + f"Failed to parse inputs for field 'metrics': {exc}." + ) from exc + raise TypeError( + "Failed to parse inputs for field 'metrics': unproper type." + )