From 753b43f693b583bc97f63547b1d824a18052c66a Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Thu, 21 Sep 2023 17:32:54 +0200 Subject: [PATCH] Minor backend changes to 'Checkpointer'. --- declearn/main/utils/_checkpoint.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/declearn/main/utils/_checkpoint.py b/declearn/main/utils/_checkpoint.py index e6677acc..bdd4c462 100644 --- a/declearn/main/utils/_checkpoint.py +++ b/declearn/main/utils/_checkpoint.py @@ -23,7 +23,7 @@ from datetime import datetime from typing import Any, Dict, List, Optional, Union import numpy as np -import pandas as pd # type: ignore +import pandas as pd from typing_extensions import Self # future: import from typing (py >=3.11) from declearn.model.api import Model @@ -319,16 +319,17 @@ class Checkpointer: } # Filter out scalar metrics and write them to a csv file. scalars = {k: v for k, v in scores.items() if isinstance(v, float)} - fpath = os.path.join(self.folder, f"{prefix}.csv") - pd.DataFrame(scalars, index=[timestamp]).to_csv( - fpath, - sep=",", - mode=("a" if append else "w"), - header=not (append and os.path.isfile(fpath)), - index=True, - index_label="timestamp", - encoding="utf-8", - ) + if scalars: + fpath = os.path.join(self.folder, f"{prefix}.csv") + pd.DataFrame(scalars, index=[timestamp]).to_csv( + fpath, + sep=",", + mode=("a" if append else "w"), + header=not (append and os.path.isfile(fpath)), + index=True, + index_label="timestamp", + encoding="utf-8", + ) # Write the full set of metrics to a JSON file. jdump = json.dumps({timestamp: scores})[1:-1] # bracket-less dict fpath = os.path.join(self.folder, f"{prefix}.json") -- GitLab