Mentions légales du service

Skip to content
Snippets Groups Projects
Verified Commit 89c807bb authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Add unit tests for some quickrun backend utils.

parent 474f5447
No related branches found
No related tags found
1 merge request!57Improve tests coverage and fix test-digged bugs
......@@ -103,8 +103,8 @@ def parse_data_folder(
def get_data_folder_path(
data_folder: Optional[str],
root_folder: Optional[str],
data_folder: Optional[str] = None,
root_folder: Optional[str] = None,
) -> Path:
"""Return the path to a data folder.
......@@ -158,7 +158,7 @@ def get_data_folder_path(
def list_client_names(
data_folder: Path,
client_names: Optional[List[str]],
client_names: Optional[List[str]] = None,
) -> List[str]:
"""List client-wise subdirectories under a data folder.
......
......@@ -143,19 +143,32 @@ def run_client(
def get_toml_folder(config: str) -> Tuple[str, str]:
"""Return the path to an experiment's folder and TOML config file.
Determine if provided config is a file or a directory, and return:
* The path to the TOML config file
* The path to the main folder of the experiment
Parameters
----------
config: str
Path to either a TOML config file (within an experiment folder),
or to the experiment folder containing a "config.toml" file.
Returns
-------
toml:
The path to the TOML config file.
folder:
The path to the main folder of the experiment.
Raises
------
FileNotFoundError:
If the TOML config file cannot be found based on inputs.
"""
config = os.path.abspath(config)
if os.path.isfile(config):
toml = config
folder = os.path.dirname(config)
elif os.path.isdir(config):
if os.path.isdir(config):
folder = config
toml = f"{folder}/config.toml"
toml = os.path.join(folder, "config.toml")
else:
toml = config
folder = os.path.dirname(toml)
if not os.path.isfile(toml):
raise FileNotFoundError(
f"Failed to find quickrun config file at '{config}'."
)
......
# coding: utf-8
"""Tests for some 'declearn.quickrun' backend utils."""
import os
import pathlib
import pytest
from declearn.quickrun import parse_data_folder
from declearn.quickrun._parser import (
get_data_folder_path,
list_client_names,
)
from declearn.quickrun._run import get_toml_folder
class TestGetTomlFolder:
"""Tests for 'declearn.quickrun._run.get_toml_folder'."""
def test_get_toml_folder_from_file(
self,
tmp_path: pathlib.Path,
) -> None:
"""Test that 'get_toml_folder' works with a TOML file path."""
config = os.path.join(tmp_path, "config.toml")
with open(config, "w", encoding="utf-8") as file:
file.write("")
toml, folder = get_toml_folder(config)
assert toml == config
assert folder == tmp_path.as_posix()
def test_get_toml_folder_from_folder(
self,
tmp_path: pathlib.Path,
) -> None:
"""Test that 'get_toml_folder' works with a folder path."""
config = os.path.join(tmp_path, "config.toml")
with open(config, "w", encoding="utf-8") as file:
file.write("")
toml, folder = get_toml_folder(tmp_path.as_posix())
assert toml == config
assert folder == tmp_path.as_posix()
def test_get_toml_folder_from_file_fails(
self,
tmp_path: pathlib.Path,
) -> None:
"""Test that it fails with a path to a non-existing file."""
config = os.path.join(tmp_path, "config.toml")
with pytest.raises(FileNotFoundError):
get_toml_folder(config)
def test_get_toml_folder_from_folder_fails(
self,
tmp_path: pathlib.Path,
) -> None:
"""Test that it fails with a folder lacking a 'config.toml' file."""
with pytest.raises(FileNotFoundError):
get_toml_folder(tmp_path.as_posix())
class TestGetDataFolderPath:
"""Tests for 'declearn.quickrun._parser.get_data_folder_path'."""
def test_get_data_folder_path_from_data_folder(
self,
tmp_path: pathlib.Path,
) -> None:
"""Test that it works with a valid 'data_folder' argument."""
path = get_data_folder_path(data_folder=tmp_path.as_posix())
assert isinstance(path, pathlib.Path)
assert path == tmp_path
def test_get_data_folder_path_from_root_folder(
self,
tmp_path: pathlib.Path,
) -> None:
"""Test that it works with a valid 'root_folder' argument."""
data_dir = os.path.join(tmp_path, "data")
os.makedirs(data_dir)
path = get_data_folder_path(root_folder=tmp_path.as_posix())
assert isinstance(path, pathlib.Path)
assert path.as_posix() == data_dir
def test_get_data_folder_path_fails_no_inputs(
self,
) -> None:
"""Test that it fails with no folder specification."""
with pytest.raises(ValueError):
get_data_folder_path(None, None)
def test_get_data_folder_path_fails_invalid_data_folder(
self,
tmp_path: pathlib.Path,
) -> None:
"""Test that it fails with an invalid data_folder."""
missing_folder = os.path.join(tmp_path, "data")
with pytest.raises(ValueError):
get_data_folder_path(data_folder=missing_folder)
def test_get_data_folder_path_fails_from_root_no_data(
self,
tmp_path: pathlib.Path,
) -> None:
"""Test that it fails with an invalid root_folder (no data)."""
with pytest.raises(ValueError):
get_data_folder_path(root_folder=tmp_path.as_posix())
def test_get_data_folder_path_fails_from_root_multiple_data(
self,
tmp_path: pathlib.Path,
) -> None:
"""Test that it fails with multiple data* under root_folder."""
os.makedirs(os.path.join(tmp_path, "data_1"))
os.makedirs(os.path.join(tmp_path, "data_2"))
with pytest.raises(ValueError):
get_data_folder_path(root_folder=tmp_path.as_posix())
class TestListClientNames:
"""Tests for the 'declearn.quickrun._parser.list_client_names' function."""
def test_list_client_names_from_folder(
self,
tmp_path: pathlib.Path,
) -> None:
"""Test that it works with a data folder."""
os.makedirs(os.path.join(tmp_path, "client_1"))
os.makedirs(os.path.join(tmp_path, "client_2"))
names = list_client_names(tmp_path)
assert isinstance(names, list)
assert sorted(names) == ["client_1", "client_2"]
def test_list_client_names_from_names(
self,
tmp_path: pathlib.Path,
) -> None:
"""Test that it works with pre-specified names."""
os.makedirs(os.path.join(tmp_path, "client_1"))
os.makedirs(os.path.join(tmp_path, "client_2"))
names = list_client_names(tmp_path, ["client_2"])
assert names == ["client_2"]
def test_list_client_names_fails_invalid_names(
self,
tmp_path: pathlib.Path,
) -> None:
"""Test that it works with invalid pre-specified names."""
with pytest.raises(ValueError):
list_client_names(tmp_path, "invalid-type") # type: ignore
with pytest.raises(ValueError):
list_client_names(tmp_path, ["client_2"])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment