diff --git a/declearn/main/utils/_checkpoint.py b/declearn/main/utils/_checkpoint.py index e6677accf0ea31cc28e06e45a53ea9b39d2b36c2..bdd4c462475b2d05659fc54820e98b388bf8e1fd 100644 --- a/declearn/main/utils/_checkpoint.py +++ b/declearn/main/utils/_checkpoint.py @@ -23,7 +23,7 @@ from datetime import datetime from typing import Any, Dict, List, Optional, Union import numpy as np -import pandas as pd # type: ignore +import pandas as pd from typing_extensions import Self # future: import from typing (py >=3.11) from declearn.model.api import Model @@ -319,16 +319,17 @@ class Checkpointer: } # Filter out scalar metrics and write them to a csv file. scalars = {k: v for k, v in scores.items() if isinstance(v, float)} - fpath = os.path.join(self.folder, f"{prefix}.csv") - pd.DataFrame(scalars, index=[timestamp]).to_csv( - fpath, - sep=",", - mode=("a" if append else "w"), - header=not (append and os.path.isfile(fpath)), - index=True, - index_label="timestamp", - encoding="utf-8", - ) + if scalars: + fpath = os.path.join(self.folder, f"{prefix}.csv") + pd.DataFrame(scalars, index=[timestamp]).to_csv( + fpath, + sep=",", + mode=("a" if append else "w"), + header=not (append and os.path.isfile(fpath)), + index=True, + index_label="timestamp", + encoding="utf-8", + ) # Write the full set of metrics to a JSON file. jdump = json.dumps({timestamp: scores})[1:-1] # bracket-less dict fpath = os.path.join(self.folder, f"{prefix}.json")