diff --git a/cadetrdm/batch_running/case.py b/cadetrdm/batch_running/case.py index e94c1db..ffd3ee0 100644 --- a/cadetrdm/batch_running/case.py +++ b/cadetrdm/batch_running/case.py @@ -21,6 +21,7 @@ def __init__( environment: Environment| None = None, name: str | None = None, study: Study | None = None, + run_method: str = "main" ) -> None: if study is not None: warnings.warn( @@ -43,6 +44,7 @@ def __init__( self._options_hash = options.get_hash() self.results_branch = None self._current_environment = None + self.run_method = run_method def __str__(self): return self.name @@ -213,7 +215,9 @@ def run_study(self, force=False, container_adapter: "ContainerAdapter" = None, c self.status = "failed" return False else: - self.project_repo.module.main(self.options, str(self.project_repo.path)) + module = self.project_repo.module + run_method = getattr(module, self.run_method) + run_method(self.options, str(self.project_repo.path)) print("Command execution successful.") self.status = 'finished' diff --git a/cadetrdm/repositories.py b/cadetrdm/repositories.py index cf1fa2a..ffd108f 100644 --- a/cadetrdm/repositories.py +++ b/cadetrdm/repositories.py @@ -14,6 +14,7 @@ from stat import S_IREAD, S_IWRITE import tarfile import tempfile +from types import ModuleType from typing import List, Optional, Any from urllib.request import urlretrieve @@ -758,10 +759,19 @@ def add_list_of_remotes_in_readme_file(self, repo_identifier: str, remotes_url_l class ProjectRepo(BaseRepo): - def __init__(self, path=None, output_folder=None, - search_parent_directories=True, suppress_lfs_warning=False, - url=None, branch=None, options=None, - *args, **kwargs): + def __init__( + self, + path: os.PathLike = None, + output_folder = None, + search_parent_directories: bool = True, + suppress_lfs_warning: bool = False, + url: str = None, + branch: str = None, + options: Options | None = None, + package_dir: str | None = None, + *args: Any, + **kwargs: Any, + ) -> None: """ Class for Project-Repositories. Handles interaction between the project repo and the output (i.e. results) repo. @@ -780,6 +790,8 @@ def __init__(self, path=None, output_folder=None, from a system without git-lfs :param branch: Optional branch to check out upon initialization + :param package_dir: + Name of the directory containing the main package. :param options: Options dictionary containing ... :param args: @@ -823,24 +835,30 @@ def __init__(self, path=None, output_folder=None, if branch is not None: self.checkout(branch) + self._package_dir = package_dir + @property def name(self): return self.path.parts[-1] @property - def module(self): + def package_dir(self) -> str: + if self._package_dir is None: + return self.name + return self._package_dir + + @property + def module(self) -> ModuleType: cur_dir = os.getcwd() - os.chdir(self.path) try: sys.path.insert(0, str(self.path)) - module = importlib.import_module(self.name) + os.chdir(self.path) + return importlib.import_module(self.package_dir) finally: sys.path.remove(str(self.path)) os.chdir(cur_dir) - return module - def _update_version(self, metadata, cadetrdm_version): current_version = Version.coerce(metadata["cadet_rdm_version"])