From e50c311ebc2f797f5110f35adec1062c2a53cb76 Mon Sep 17 00:00:00 2001 From: mattcieslak Date: Mon, 5 May 2025 10:52:48 -0400 Subject: [PATCH] stash --- babs/base.py | 14 +++ babs/status.py | 168 ++++++++++++++++++++++++++++++++++++ tests/pytest_in_docker.sh | 2 +- tests/test_status.py | 175 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 358 insertions(+), 1 deletion(-) create mode 100644 babs/status.py create mode 100644 tests/test_status.py diff --git a/babs/base.py b/babs/base.py index c6e1f818..93f6c192 100644 --- a/babs/base.py +++ b/babs/base.py @@ -13,6 +13,7 @@ from babs.scheduler import ( request_all_job_status, ) +from babs.status import StatusCollection from babs.system import validate_queue from babs.utils import ( combine_inclusion_dataframes, @@ -132,6 +133,7 @@ def _apply_config(self) -> None: - queue - container - input_datasets + - status_collection """ # Sanity check: the path `project_root` exists: @@ -167,6 +169,18 @@ def _apply_config(self) -> None: self.input_datasets = InputDatasets(self.processing_level, config_yaml['input_datasets']) self.input_datasets.update_abs_paths(Path(self.project_root) / 'analysis') + self._update_status_collection() + + def _update_status_collection(self) -> None: + """ + Update the status collection. 
+ """ + self.status_collection = StatusCollection( + self.list_sub_path_abs, + jobs=self.job_status_path_abs, + results=self.job_status_path_abs, + ) + def _update_inclusion_dataframe( self, initial_inclusion_df: pd.DataFrame | None = None ) -> None: diff --git a/babs/status.py b/babs/status.py new file mode 100644 index 00000000..34edbde9 --- /dev/null +++ b/babs/status.py @@ -0,0 +1,168 @@ +import os.path as op +from dataclasses import dataclass +from typing import Literal + +import pandas as pd + + +@dataclass +class StatusKey: + sub_id: str + ses_id: str | None = None + + def __post_init__(self): + self.key = (self.sub_id, self.ses_id) + + +@dataclass +class JobStatus: + """ + This class is used to get the status of the jobs. + """ + + job_id: int = -1 + task_id: int = -1 + state: str = 'Unknown' + time_used: str = 'Unknown' + time_limit: str = 'Unknown' + nodes: int = 0 + cpus: int = 0 + partition: str = 'Unknown' + name: str = 'Unknown' + sub_id: str | None = None + ses_id: str | None = None + + +@dataclass +class ResultStatus: + """ + This class is used to get the status of the results. + """ + + has_results: bool = False + is_failed: bool = False + result_location: Literal['branch', 'zip'] | None = None + submitted: bool = False + sub_id: str | None = None + ses_id: str | None = None + + +class StatusCollection: + """ + This class is used to get the status of the jobs and results. 
+ """ + + status_ids: list[StatusKey] + results: dict[tuple[str, str | None], ResultStatus] + jobs: dict[tuple[str, str | None], JobStatus] + + def __init__( + self, + inclusion_df: pd.DataFrame | str, + results: pd.DataFrame | str | None = None, + jobs: pd.DataFrame | str | None = None, + ): + if isinstance(inclusion_df, str): + inclusion_df = pd.read_csv(inclusion_df) + self.status_ids = [ + StatusKey(row['sub_id'], row.get('ses_id', None)) + for row in inclusion_df.to_dict(orient='records') + ] + self.status_ids.sort(key=lambda x: x.key) + self.results = {} + self.jobs = {} + self.update_results(results) + self.update_jobs(jobs) + + def update_jobs(self, new_jobs: pd.DataFrame | str | list[JobStatus] | None = None): + """ + Update the jobs in the StatusCollection. + + The total number of jobs in this dataframe will always be equal to the + number of subjects in the inclusion_df. + + Parameters + ---------- + new_jobs: pd.DataFrame | str | list[JobStatus] + The new jobs to update the StatusCollection with. If a str, it assumed + to be the path to a csv file containing the jobs. If a pd.DataFrame, + it is assumed to have the correct columns and will be converted to a + dict of JobStatus objects. If a list, it is assumed to be a list of + JobStatus objects. 
+ """ + # Ensure we have a list of JobStatus objects + if isinstance(new_jobs, str) or isinstance(new_jobs, pd.DataFrame): + job_list = _load_jobs(new_jobs) + elif isinstance(new_jobs, list): + job_list = new_jobs + else: + job_list = [] + + # Create a dictionary of JobStatus objects + jobs = {} + for job in job_list: + key = (job.sub_id, job.ses_id) + jobs[key] = job + # Update the results to show that the job has been submitted + self.results[key].submitted = True + + # For any keys not in the updated, keep the old jobstatus or make an empty new one + all_keys = {status_id.key for status_id in self.status_ids} + missing_jobs_keys = all_keys - set(jobs.keys()) + for key in missing_jobs_keys: + jobs[key] = self.jobs.get(key, JobStatus(sub_id=key[0], ses_id=key[1])) + self.jobs = jobs + + def update_results(self, new_results: pd.DataFrame | str | list[ResultStatus] | None = None): + # Ensure we have a list of ResultStatus objects + if isinstance(new_results, str) or isinstance(new_results, pd.DataFrame): + result_list = _load_results(new_results) + elif isinstance(new_results, list): + result_list = new_results + else: + result_list = [] + + # Create a dictionary of ResultStatus objects + results = {} + for result in result_list: + key = (result.sub_id, result.ses_id) + results[key] = result + + # Warn about any keys in results_df that are not in all_keys: + all_keys = {status_id.key for status_id in self.status_ids} + missing_results_keys = all_keys - set(results.keys()) + for key in missing_results_keys: + results[key] = self.results.get(key, ResultStatus(sub_id=key[0], ses_id=key[1])) + self.results = results + + def write_results(self, output_dir: str): + # Convert None to pd.NA for writing to CSV + results_df = pd.DataFrame(self.results.values()).replace({None: pd.NA}) + results_df.to_csv(op.join(output_dir, 'results.csv'), index=False) + jobs_df = pd.DataFrame(self.jobs.values()).replace({None: pd.NA}) + jobs_df.to_csv(op.join(output_dir, 'jobs.csv'), 
index=False) + + +def _load_jobs(new_jobs: pd.DataFrame | str): + if isinstance(new_jobs, str): + new_jobs = pd.read_csv(new_jobs) + # Convert nan to None + jobs_dict = new_jobs.replace({pd.NA: None, pd.NaT: None}).to_dict(orient='records') + return [JobStatus(**row) for row in jobs_dict] + + +def _load_results(new_results: pd.DataFrame | str): + if isinstance(new_results, str): + new_results = pd.read_csv(new_results) + # Convert nan to None + results_dict = new_results.replace({pd.NA: None, pd.NaT: None}).to_dict(orient='records') + return [ResultStatus(**row) for row in results_dict] + + +def results_from_branches(branches: list[str]): + """ + Get the results from the branches. + """ + results = [] + for branch in branches: + # ResultStatus has no 'branch' field (the old call raised TypeError); + # an existing results branch means completed results stored on a branch. + # TODO(review): parse sub_id/ses_id out of the branch name — confirm format + results.append(ResultStatus(has_results=True, result_location='branch')) + return results diff --git a/tests/pytest_in_docker.sh b/tests/pytest_in_docker.sh index 54c29620..30daa41b 100755 --- a/tests/pytest_in_docker.sh +++ b/tests/pytest_in_docker.sh @@ -10,5 +10,5 @@ docker run -it \ --cov-report=xml \ --cov=babs \ --pdb \ - /babs/tests/test_update_input_data.py + /babs/tests/test_status.py \ No newline at end of file diff --git a/tests/test_status.py b/tests/test_status.py new file mode 100644 index 00000000..badfc7ce --- /dev/null +++ b/tests/test_status.py @@ -0,0 +1,175 @@ +import random +from dataclasses import asdict + +import pandas as pd + +from babs.status import JobStatus, ResultStatus, StatusCollection, _load_jobs, _load_results + + +def make_inclusion_df(n_subjects: int, n_sessions: int, output_csv: str) -> None: + if n_sessions == 0: + df = pd.DataFrame( + { + 'sub_id': [f'sub-{i:02d}' for i in range(n_subjects)], + } + ) + else: + sub_ids = [] + ses_ids = [] + for subid in range(n_subjects): + for sesid in range(n_sessions): + sub_ids.append(f'sub-{subid:02d}') + ses_ids.append(f'ses-{sesid:02d}') + df = pd.DataFrame( + { + 'sub_id': sub_ids, + 'ses_id': ses_ids, + } + ) + df.to_csv(output_csv, index=False) + + +def test_status_sessionlevel_collection(): + n_subjects =
10 + n_sessions = 2 + output_csv = 'inclusion.csv' + make_inclusion_df(n_subjects, n_sessions, output_csv) + status_collection = StatusCollection(output_csv) + assert len(status_collection.jobs) == n_subjects * n_sessions + assert len(status_collection.results) == n_subjects * n_sessions + + +def test_status_subjectlevel_collection(): + n_subjects = 10 + n_sessions = 0 + output_csv = 'inclusion.csv' + make_inclusion_df(n_subjects, n_sessions, output_csv) + status_collection = StatusCollection(output_csv) + assert len(status_collection.jobs) == n_subjects + assert len(status_collection.results) == n_subjects + + +def check_equal( + status_collection_1: StatusCollection, + status_collection_2: StatusCollection, + jobs_or_results: str, +): + status_ids_1 = {status_id.key for status_id in status_collection_1.status_ids} + status_ids_2 = {status_id.key for status_id in status_collection_2.status_ids} + assert status_ids_1 == status_ids_2 + + for status_id in status_ids_1: + if jobs_or_results == 'results': + item1 = status_collection_1.results[status_id] + item2 = status_collection_2.results[status_id] + else: + item1 = status_collection_1.jobs[status_id] + item2 = status_collection_2.jobs[status_id] + + assert asdict(item1) == asdict(item2) + + +def test_update_status_collection_with_jobs(tmp_path_factory): + n_subjects = 10 + n_sessions = 2 + output_csv = 'inclusion.csv' + make_inclusion_df(n_subjects, n_sessions, output_csv) + status_collection = StatusCollection(output_csv) + + # choose 4 random status_ids to add results to + changed_status_ids = random.sample(status_collection.status_ids, 4) + updated_keys = {status_id.key for status_id in changed_status_ids} + update_jobs = [] + for jobnum, status_id in enumerate(changed_status_ids): + update_jobs.append( + JobStatus( + sub_id=status_id.sub_id, + ses_id=status_id.ses_id, + job_id=1, + task_id=jobnum + 1, + state='running', + time_used='1:00:00', + time_limit='2:00:00', + nodes=1, + cpus=1, + partition='standard', 
+ name='test', + ) + ) + status_collection.update_jobs(update_jobs) + for jobnum, status_id in enumerate(changed_status_ids): + assert status_collection.jobs[status_id.key].job_id == 1 + assert status_collection.jobs[status_id.key].task_id == jobnum + 1 + assert status_collection.jobs[status_id.key].state == 'running' + assert status_collection.jobs[status_id.key].time_used == '1:00:00' + assert status_collection.jobs[status_id.key].time_limit == '2:00:00' + assert status_collection.jobs[status_id.key].nodes == 1 + assert status_collection.jobs[status_id.key].cpus == 1 + assert status_collection.jobs[status_id.key].partition == 'standard' + assert status_collection.jobs[status_id.key].name == 'test' + + jobs_dir = tmp_path_factory.mktemp('jobs') + status_collection.write_results(str(jobs_dir)) + + new_status_collection = StatusCollection( + pd.read_csv(output_csv), jobs=str(jobs_dir / 'jobs.csv') + ) + check_equal(status_collection, new_status_collection, 'jobs') + + loaded_jobs = _load_jobs(str(jobs_dir / 'jobs.csv')) + assert len(loaded_jobs) == len(status_collection.jobs) + + pd_loaded_jobs = _load_jobs(pd.read_csv(str(jobs_dir / 'jobs.csv'))) + assert len(pd_loaded_jobs) == len(status_collection.jobs) + + +def test_update_status_collection_with_results(tmp_path_factory): + n_subjects = 10 + n_sessions = 2 + output_csv = 'inclusion.csv' + make_inclusion_df(n_subjects, n_sessions, output_csv) + status_collection = StatusCollection(output_csv) + + # choose 4 random status_ids to add results to + changed_status_ids = random.sample(status_collection.status_ids, 4) + updated_keys = {status_id.key for status_id in changed_status_ids} + update_results = [] + for status_id in changed_status_ids: + update_results.append( + ResultStatus( + sub_id=status_id.sub_id, + ses_id=status_id.ses_id, + result_location='branch', + is_failed=False, + has_results=True, + submitted=True, + ) + ) + status_collection.update_results(update_results) + for status_id in changed_status_ids: + 
assert status_collection.results[status_id.key].result_location == 'branch' + assert not status_collection.results[status_id.key].is_failed + assert status_collection.results[status_id.key].has_results + + # Find some status_ids that should not have results + unchanged_status_ids = [ + status_id + for status_id in status_collection.status_ids + if status_id.key not in updated_keys + ] + for status_id in unchanged_status_ids: + assert not status_collection.results[status_id.key].has_results + + results_dir = tmp_path_factory.mktemp('results') + status_collection.write_results(str(results_dir)) + + new_status_collection = StatusCollection( + pd.read_csv(output_csv), results=str(results_dir / 'results.csv') + ) + check_equal(status_collection, new_status_collection, 'results') + + loaded_results = _load_results(str(results_dir / 'results.csv')) + assert len(loaded_results) == len(status_collection.results) + + pd_loaded_results = _load_results(pd.read_csv(str(results_dir / 'results.csv'))) + assert len(pd_loaded_results) == len(status_collection.results)