diff --git a/declearn/quickrun/run.py b/declearn/quickrun/run.py index 136e2afa7a712ed7acbcf6f5f9706aa23bc7710c..ae51b5f22f24d09c0c8b4eb43907f0cffab5c725 100644 --- a/declearn/quickrun/run.py +++ b/declearn/quickrun/run.py @@ -57,19 +57,23 @@ from declearn.utils import LOGGING_LEVEL_MAJOR, get_logger, run_as_processes __all__ = ["quickrun"] -DEFAULT_FOLDER = "./examples/quickrun" + +DEFAULT_FOLDER = os.path.join( + os.path.dirname(__file__), + "../../examples/quickrun", +) def get_model(folder, model_config) -> Model: "Return a model instance from a model config instance" - file = "model" - if m_file := model_config.model_file: - folder = os.path.dirname(m_file) - file = m_file.rsplit("/", 1)[-1].split(".")[0] - with make_importable(folder): - mod = importlib.import_module(file) - model_cls = getattr(mod, model_config.model_name) - return model_cls + path = model_config.model_file or os.path.join(folder, "model.py") + path = os.path.abspath(path) + if not os.path.isfile(path): + raise FileNotFoundError("Model file not found: '{path}'.") + with make_importable(os.path.dirname(path)): + mod = importlib.import_module(os.path.basename(path)[:-3]) + model = getattr(mod, model_config.model_name) + return model def get_checkpoint(folder: str, expe_config: ExperimentConfig) -> str: @@ -152,6 +156,10 @@ def get_toml_folder(config: Optional[str] = None) -> Tuple[str, str]: elif os.path.isdir(config): folder = config toml = f"{folder}/config.toml" + else: + raise FileNotFoundError( + f"Failed to find quickrun config file at '{config}'." + ) return toml, folder