diff --git a/declearn/quickrun/_parser.py b/declearn/quickrun/_parser.py index 422d0999d6625b0959ebbd0931ec0d841f13d88d..f15bfb4f2c1c4a19424aff6e110553bfd6bde4b7 100644 --- a/declearn/quickrun/_parser.py +++ b/declearn/quickrun/_parser.py @@ -103,8 +103,8 @@ def parse_data_folder( def get_data_folder_path( - data_folder: Optional[str], - root_folder: Optional[str], + data_folder: Optional[str] = None, + root_folder: Optional[str] = None, ) -> Path: """Return the path to a data folder. @@ -158,7 +158,7 @@ def get_data_folder_path( def list_client_names( data_folder: Path, - client_names: Optional[List[str]], + client_names: Optional[List[str]] = None, ) -> List[str]: """List client-wise subdirectories under a data folder. diff --git a/declearn/quickrun/_run.py b/declearn/quickrun/_run.py index c33ae875dd1af0f5df6ca02e610877bcee62818a..46f14384f282f851a6dddbe664c624bc1664426b 100644 --- a/declearn/quickrun/_run.py +++ b/declearn/quickrun/_run.py @@ -143,19 +143,32 @@ def run_client( def get_toml_folder(config: str) -> Tuple[str, str]: """Return the path to an experiment's folder and TOML config file. - Determine if provided config is a file or a directory, and return: - - * The path to the TOML config file - * The path to the main folder of the experiment + Parameters + ---------- + config: str + Path to either a TOML config file (within an experiment folder), + or to the experiment folder containing a "config.toml" file. + + Returns + ------- + toml: + The path to the TOML config file. + folder: + The path to the main folder of the experiment. + + Raises + ------ + FileNotFoundError: + If the TOML config file cannot be found based on inputs. """ config = os.path.abspath(config) - if os.path.isfile(config): - toml = config - folder = os.path.dirname(config) - elif os.path.isdir(config): + if os.path.isdir(config): folder = config - toml = f"{folder}/config.toml" + toml = os.path.join(folder, "config.toml") else: + toml = config + folder = os.path.dirname(toml) + if not os.path.isfile(toml): raise FileNotFoundError( f"Failed to find quickrun config file at '{config}'." ) diff --git a/test/quickrun/test_utils.py b/test/quickrun/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5272641fb155197c350f8537f01c523b955be33b --- /dev/null +++ b/test/quickrun/test_utils.py @@ -0,0 +1,153 @@ +# coding: utf-8 + +"""Tests for some 'declearn.quickrun' backend utils.""" + +import os +import pathlib + +import pytest + +from declearn.quickrun import parse_data_folder +from declearn.quickrun._parser import ( + get_data_folder_path, + list_client_names, +) +from declearn.quickrun._run import get_toml_folder + + +class TestGetTomlFolder: + """Tests for 'declearn.quickrun._run.get_toml_folder'.""" + + def test_get_toml_folder_from_file( + self, + tmp_path: pathlib.Path, + ) -> None: + """Test that 'get_toml_folder' works with a TOML file path.""" + config = os.path.join(tmp_path, "config.toml") + with open(config, "w", encoding="utf-8") as file: + file.write("") + toml, folder = get_toml_folder(config) + assert toml == config + assert folder == tmp_path.as_posix() + + def test_get_toml_folder_from_folder( + self, + tmp_path: pathlib.Path, + ) -> None: + """Test that 'get_toml_folder' works with a folder path.""" + config = os.path.join(tmp_path, "config.toml") + with open(config, "w", encoding="utf-8") as file: + file.write("") + toml, folder = get_toml_folder(tmp_path.as_posix()) + assert toml == config + assert folder == tmp_path.as_posix() + + def test_get_toml_folder_from_file_fails( + self, + tmp_path: pathlib.Path, + ) -> None: + """Test that it fails with a path to a non-existing file.""" + config = os.path.join(tmp_path, "config.toml") + with pytest.raises(FileNotFoundError): + get_toml_folder(config) + + def test_get_toml_folder_from_folder_fails( + self, + tmp_path: pathlib.Path, + ) -> None: + """Test that it fails with a folder lacking a 'config.toml' file.""" + with pytest.raises(FileNotFoundError): + get_toml_folder(tmp_path.as_posix()) + + +class TestGetDataFolderPath: + """Tests for 'declearn.quickrun._parser.get_data_folder_path'.""" + + def test_get_data_folder_path_from_data_folder( + self, + tmp_path: pathlib.Path, + ) -> None: + """Test that it works with a valid 'data_folder' argument.""" + path = get_data_folder_path(data_folder=tmp_path.as_posix()) + assert isinstance(path, pathlib.Path) + assert path == tmp_path + + def test_get_data_folder_path_from_root_folder( + self, + tmp_path: pathlib.Path, + ) -> None: + """Test that it works with a valid 'root_folder' argument.""" + data_dir = os.path.join(tmp_path, "data") + os.makedirs(data_dir) + path = get_data_folder_path(root_folder=tmp_path.as_posix()) + assert isinstance(path, pathlib.Path) + assert path.as_posix() == data_dir + + def test_get_data_folder_path_fails_no_inputs( + self, + ) -> None: + """Test that it fails with no folder specification.""" + with pytest.raises(ValueError): + get_data_folder_path(None, None) + + def test_get_data_folder_path_fails_invalid_data_folder( + self, + tmp_path: pathlib.Path, + ) -> None: + """Test that it fails with an invalid data_folder.""" + missing_folder = os.path.join(tmp_path, "data") + with pytest.raises(ValueError): + get_data_folder_path(data_folder=missing_folder) + + def test_get_data_folder_path_fails_from_root_no_data( + self, + tmp_path: pathlib.Path, + ) -> None: + """Test that it fails with an invalid root_folder (no data).""" + with pytest.raises(ValueError): + get_data_folder_path(root_folder=tmp_path.as_posix()) + + def test_get_data_folder_path_fails_from_root_multiple_data( + self, + tmp_path: pathlib.Path, + ) -> None: + """Test that it fails with multiple data* under root_folder.""" + os.makedirs(os.path.join(tmp_path, "data_1")) + os.makedirs(os.path.join(tmp_path, "data_2")) + with pytest.raises(ValueError): + get_data_folder_path(root_folder=tmp_path.as_posix()) + + +class TestListClientNames: + """Tests for the 'declearn.quickrun._parser.list_client_names' function.""" + + def test_list_client_names_from_folder( + self, + tmp_path: pathlib.Path, + ) -> None: + """Test that it works with a data folder.""" + os.makedirs(os.path.join(tmp_path, "client_1")) + os.makedirs(os.path.join(tmp_path, "client_2")) + names = list_client_names(tmp_path) + assert isinstance(names, list) + assert sorted(names) == ["client_1", "client_2"] + + def test_list_client_names_from_names( + self, + tmp_path: pathlib.Path, + ) -> None: + """Test that it works with pre-specified names.""" + os.makedirs(os.path.join(tmp_path, "client_1")) + os.makedirs(os.path.join(tmp_path, "client_2")) + names = list_client_names(tmp_path, ["client_2"]) + assert names == ["client_2"] + + def test_list_client_names_fails_invalid_names( + self, + tmp_path: pathlib.Path, + ) -> None: + """Test that it works with invalid pre-specified names.""" + with pytest.raises(ValueError): + list_client_names(tmp_path, "invalid-type") # type: ignore + with pytest.raises(ValueError): + list_client_names(tmp_path, ["client_2"])