Mentions légales du service

Skip to content
Snippets Groups Projects
Verified Commit 730db92c authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Add tests for 'declearn.quickrun.parse_data_folder'.

parent 89c807bb
Branches
Tags
1 merge request!57Improve tests coverage and fix test-digged bugs
......@@ -4,10 +4,12 @@
import os
import pathlib
from typing import List
import pytest
from declearn.quickrun import parse_data_folder
from declearn.quickrun._config import DataSourceConfig
from declearn.quickrun._parser import (
get_data_folder_path,
list_client_names,
......@@ -151,3 +153,114 @@ class TestListClientNames:
list_client_names(tmp_path, "invalid-type") # type: ignore
with pytest.raises(ValueError):
list_client_names(tmp_path, ["client_2"])
class TestParseDataFolder:
"""Docstring."""
@staticmethod
def setup_data_folder(
data_folder: str,
client_names: List[str],
file_names: List[str],
) -> None:
"""Set up a data folder, with client subfolders and empty files."""
for cname in client_names:
folder = os.path.join(data_folder, cname)
os.makedirs(folder)
for fname in file_names:
path = os.path.join(folder, fname)
with open(path, "w", encoding="utf-8") as file:
file.write("")
def test_parse_data_folder_with_default_names(
self,
tmp_path: pathlib.Path,
) -> None:
"""Test 'parse_data_folder' with default file names."""
# Setup a data folder with a couple of clients and default files.
data_folder = tmp_path.as_posix()
client_names = ["client-1", "client-2"]
file_names = [
# fmt: off
"train_data", "train_target", "valid_data", "valid_target"
]
self.setup_data_folder(data_folder, client_names, file_names)
# Write up the expected outputs.
expected = {
cname: {
fname: os.path.join(data_folder, cname, fname)
for fname in file_names
}
for cname in client_names
}
# Run the function and validate its outputs.
config = DataSourceConfig(
data_folder=data_folder,
client_names=None,
dataset_names=None,
)
clients = parse_data_folder(config)
assert clients == expected
def test_parse_data_folder_with_custom_names(
self,
tmp_path: pathlib.Path,
) -> None:
"""Test 'parse_data_folder' with custom file names."""
# Setup a data folder with a couple of clients and default files.
data_folder = tmp_path.as_posix()
client_names = ["client-1", "client-2"]
base_names = [
# fmt: off
"train_data", "train_target", "valid_data", "valid_target"
]
file_names = ["x_train", "y_train", "x_valid", "y_valid"]
self.setup_data_folder(data_folder, client_names, file_names)
# Write up the expected outputs.
expected = {
cname: {
bname: os.path.join(data_folder, cname, fname)
for bname, fname in zip(base_names, file_names)
}
for cname in client_names
}
# Run the function and validate its outputs.
config = DataSourceConfig(
data_folder=data_folder,
client_names=None,
dataset_names=dict(zip(base_names, file_names)),
)
clients = parse_data_folder(config)
assert clients == expected
# Verify that it would not work without the argument.
config = DataSourceConfig(
data_folder=data_folder,
client_names=None,
dataset_names=None,
)
with pytest.raises(ValueError):
clients = parse_data_folder(config)
def test_parse_data_folder_fails_multiple_files(
self,
tmp_path: pathlib.Path,
) -> None:
"""Test that 'parse_data_folder' fails with same-name files."""
# Setup a data folder with a couple of clients and duplicated files.
data_folder = tmp_path.as_posix()
client_names = ["client-1", "client-2"]
file_names = [
# fmt: off
"train_data", "train_target", "valid_data", "valid_target",
"train_data.bis" # duplicated name prefix
]
self.setup_data_folder(data_folder, client_names, file_names)
# Verify that the expected exception is raised.
config = DataSourceConfig(
data_folder=data_folder,
client_names=None,
dataset_names=None,
)
with pytest.raises(ValueError):
parse_data_folder(config)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment