14 changes: 14 additions & 0 deletions babs/base.py
@@ -13,6 +13,7 @@
from babs.scheduler import (
request_all_job_status,
)
from babs.status import StatusCollection
from babs.system import validate_queue
from babs.utils import (
combine_inclusion_dataframes,
@@ -132,6 +133,7 @@ def _apply_config(self) -> None:
- queue
- container
- input_datasets
- status_collection

"""
# Sanity check: the path `project_root` exists:
@@ -167,6 +169,18 @@ def _apply_config(self) -> None:
self.input_datasets = InputDatasets(self.processing_level, config_yaml['input_datasets'])
self.input_datasets.update_abs_paths(Path(self.project_root) / 'analysis')

self._update_status_collection()

def _update_status_collection(self) -> None:
"""
Update the status collection.
"""
self.status_collection = StatusCollection(
self.list_sub_path_abs,
jobs=self.job_status_path_abs,
results=self.job_status_path_abs,
)

def _update_inclusion_dataframe(
self, initial_inclusion_df: pd.DataFrame | None = None
) -> None:
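For context, a hedged sketch of how downstream code might consume the new status_collection attribute; the babs object and the filtering shown here are illustrative and not part of this diff:

# Hypothetical consumer; assumes `babs` is a BABS instance whose
# _apply_config() has already run, so babs.status_collection exists.
# StatusCollection.results maps (sub_id, ses_id) -> ResultStatus.
pending = [
    key
    for key, result in babs.status_collection.results.items()
    if not result.submitted
]
print(f'{len(pending)} subject/session entries not yet submitted')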
168 changes: 168 additions & 0 deletions babs/status.py
@@ -0,0 +1,168 @@
import os.path as op
from dataclasses import dataclass
from typing import Literal

import pandas as pd


@dataclass
class StatusKey:
sub_id: str
ses_id: str | None = None

def __post_init__(self):
self.key = (self.sub_id, self.ses_id)


@dataclass
class JobStatus:
"""
This class is used to get the status of the jobs.
"""

job_id: int = -1
task_id: int = -1
state: str = 'Unknown'
time_used: str = 'Unknown'
time_limit: str = 'Unknown'
nodes: int = 0
cpus: int = 0
partition: str = 'Unknown'
name: str = 'Unknown'
sub_id: str | None = None
ses_id: str | None = None


@dataclass
class ResultStatus:
"""
This class is used to get the status of the results.
"""

has_results: bool = False
is_failed: bool = False
result_location: Literal['branch', 'zip'] | None = None
submitted: bool = False
sub_id: str | None = None
ses_id: str | None = None


class StatusCollection:
"""
This class is used to get the status of the jobs and results.
"""

status_ids: list[StatusKey]
results: dict[tuple[str, str | None], ResultStatus]
jobs: dict[tuple[str, str | None], JobStatus]

def __init__(
self,
inclusion_df: pd.DataFrame | str,
results: pd.DataFrame | str | None = None,
jobs: pd.DataFrame | str | None = None,
):
if isinstance(inclusion_df, str):
inclusion_df = pd.read_csv(inclusion_df)
self.status_ids = [
StatusKey(row['sub_id'], row.get('ses_id', None))
for row in inclusion_df.to_dict(orient='records')
]
self.status_ids.sort(key=lambda x: x.key)
self.results = {}
self.jobs = {}
self.update_results(results)
self.update_jobs(jobs)

def update_jobs(self, new_jobs: pd.DataFrame | str | list[JobStatus] | None = None):
"""
Update the jobs in the StatusCollection.

The total number of jobs in the collection will always equal the number of
rows (subject or subject/session entries) in the inclusion dataframe.

Parameters
----------
new_jobs: pd.DataFrame | str | list[JobStatus]
The new jobs to update the StatusCollection with. If a str, it is assumed
to be the path to a CSV file containing the jobs. If a pd.DataFrame,
it is assumed to have the correct columns and will be converted to a
dict of JobStatus objects. If a list, it is assumed to be a list of
JobStatus objects.
"""
# Ensure we have a list of JobStatus objects
if isinstance(new_jobs, (str, pd.DataFrame)):
job_list = _load_jobs(new_jobs)
elif isinstance(new_jobs, list):
job_list = new_jobs
else:
job_list = []

# Create a dictionary of JobStatus objects
jobs = {}
for job in job_list:
key = (job.sub_id, job.ses_id)
jobs[key] = job
# Mark the corresponding result as submitted, if it is tracked
if key in self.results:
    self.results[key].submitted = True

# For any keys not in the update, keep the old JobStatus or create an empty one
all_keys = {status_id.key for status_id in self.status_ids}
missing_jobs_keys = all_keys - set(jobs.keys())
for key in missing_jobs_keys:
jobs[key] = self.jobs.get(key, JobStatus(sub_id=key[0], ses_id=key[1]))
self.jobs = jobs

def update_results(self, new_results: pd.DataFrame | str | list[ResultStatus] | None = None):
"""
Update the results in the StatusCollection.

Accepts the same inputs as update_jobs: a path to a CSV file, a
pd.DataFrame with the ResultStatus columns, or a list of ResultStatus
objects.
"""
# Ensure we have a list of ResultStatus objects
if isinstance(new_results, (str, pd.DataFrame)):
result_list = _load_results(new_results)
elif isinstance(new_results, list):
result_list = new_results
else:
result_list = []

# Create a dictionary of ResultStatus objects
results = {}
for result in result_list:
key = (result.sub_id, result.ses_id)
results[key] = result

# For any keys not in the update, keep the old ResultStatus or create an empty one
all_keys = {status_id.key for status_id in self.status_ids}
missing_results_keys = all_keys - set(results.keys())
for key in missing_results_keys:
results[key] = self.results.get(key, ResultStatus(sub_id=key[0], ses_id=key[1]))
self.results = results

def write_results(self, output_dir: str):
"""
Write the results and jobs to results.csv and jobs.csv in output_dir.
"""
# Convert None to pd.NA for writing to CSV
results_df = pd.DataFrame(self.results.values()).replace({None: pd.NA})
results_df.to_csv(op.join(output_dir, 'results.csv'), index=False)
jobs_df = pd.DataFrame(self.jobs.values()).replace({None: pd.NA})
jobs_df.to_csv(op.join(output_dir, 'jobs.csv'), index=False)


def _load_jobs(new_jobs: pd.DataFrame | str) -> list[JobStatus]:
if isinstance(new_jobs, str):
new_jobs = pd.read_csv(new_jobs)
# Convert nan to None
jobs_dict = new_jobs.replace({pd.NA: None, pd.NaT: None}).to_dict(orient='records')
return [JobStatus(**row) for row in jobs_dict]


def _load_results(new_results: pd.DataFrame | str) -> list[ResultStatus]:
if isinstance(new_results, str):
new_results = pd.read_csv(new_results)
# Convert nan to None
results_dict = new_results.replace({pd.NA: None, pd.NaT: None}).to_dict(orient='records')
return [ResultStatus(**row) for row in results_dict]


def results_from_branches(branches: list[str]):
"""
Build ResultStatus objects for results found on result branches.
"""
results = []
for branch in branches:
    # ResultStatus does not store the branch name itself, so record only
    # that results exist on a branch; extracting sub_id/ses_id from the
    # branch name is format-specific and not implemented here.
    results.append(ResultStatus(has_results=True, result_location='branch'))
return results
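A minimal usage sketch of the new status API, assuming an inclusion CSV with sub_id (and optionally ses_id) columns; the paths and IDs below are illustrative:

import pandas as pd

from babs.status import JobStatus, StatusCollection

# Build the collection from the inclusion list; jobs and results start
# as empty defaults for every subject/session entry.
collection = StatusCollection('inclusion.csv')

# Record a submitted job for one entry that exists in the inclusion list.
collection.update_jobs(
    [JobStatus(sub_id='sub-00', ses_id='ses-00', job_id=1, task_id=1, state='pending')]
)

# Persist both tables (the output directory must already exist), then
# round-trip them into a fresh collection.
collection.write_results('status_dir')
reloaded = StatusCollection(
    pd.read_csv('inclusion.csv'),
    jobs='status_dir/jobs.csv',
    results='status_dir/results.csv',
)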
2 changes: 1 addition & 1 deletion tests/pytest_in_docker.sh
@@ -10,5 +10,5 @@ docker run -it \
--cov-report=xml \
--cov=babs \
--pdb \
-/babs/tests/test_update_input_data.py
+/babs/tests/test_status.py

175 changes: 175 additions & 0 deletions tests/test_status.py
@@ -0,0 +1,175 @@
import random
from dataclasses import asdict

import pandas as pd

from babs.status import JobStatus, ResultStatus, StatusCollection, _load_jobs, _load_results


def make_inclusion_df(n_subjects: int, n_sessions: int, output_csv: str) -> None:
if n_sessions == 0:
df = pd.DataFrame(
{
'sub_id': [f'sub-{i:02d}' for i in range(n_subjects)],
}
)
else:
sub_ids = []
ses_ids = []
for subid in range(n_subjects):
for sesid in range(n_sessions):
sub_ids.append(f'sub-{subid:02d}')
ses_ids.append(f'ses-{sesid:02d}')
df = pd.DataFrame(
{
'sub_id': sub_ids,
'ses_id': ses_ids,
}
)
df.to_csv(output_csv, index=False)


def test_status_sessionlevel_collection(tmp_path):
n_subjects = 10
n_sessions = 2
output_csv = str(tmp_path / 'inclusion.csv')
make_inclusion_df(n_subjects, n_sessions, output_csv)
status_collection = StatusCollection(output_csv)
assert len(status_collection.jobs) == n_subjects * n_sessions
assert len(status_collection.results) == n_subjects * n_sessions


def test_status_subjectlevel_collection(tmp_path):
n_subjects = 10
n_sessions = 0
output_csv = str(tmp_path / 'inclusion.csv')
make_inclusion_df(n_subjects, n_sessions, output_csv)
status_collection = StatusCollection(output_csv)
assert len(status_collection.jobs) == n_subjects
assert len(status_collection.results) == n_subjects


def check_equal(
status_collection_1: StatusCollection,
status_collection_2: StatusCollection,
jobs_or_results: str,
):
status_ids_1 = {status_id.key for status_id in status_collection_1.status_ids}
status_ids_2 = {status_id.key for status_id in status_collection_2.status_ids}
assert status_ids_1 == status_ids_2

for status_id in status_ids_1:
if jobs_or_results == 'results':
item1 = status_collection_1.results[status_id]
item2 = status_collection_2.results[status_id]
else:
item1 = status_collection_1.jobs[status_id]
item2 = status_collection_2.jobs[status_id]

assert asdict(item1) == asdict(item2)


def test_update_status_collection_with_jobs(tmp_path_factory):
n_subjects = 10
n_sessions = 2
output_csv = str(tmp_path_factory.mktemp('inclusion') / 'inclusion.csv')
make_inclusion_df(n_subjects, n_sessions, output_csv)
status_collection = StatusCollection(output_csv)

# choose 4 random status_ids to update with job info
changed_status_ids = random.sample(status_collection.status_ids, 4)
update_jobs = []
for jobnum, status_id in enumerate(changed_status_ids):
update_jobs.append(
JobStatus(
sub_id=status_id.sub_id,
ses_id=status_id.ses_id,
job_id=1,
task_id=jobnum + 1,
state='running',
time_used='1:00:00',
time_limit='2:00:00',
nodes=1,
cpus=1,
partition='standard',
name='test',
)
)
status_collection.update_jobs(update_jobs)
for jobnum, status_id in enumerate(changed_status_ids):
assert status_collection.jobs[status_id.key].job_id == 1
assert status_collection.jobs[status_id.key].task_id == jobnum + 1
assert status_collection.jobs[status_id.key].state == 'running'
assert status_collection.jobs[status_id.key].time_used == '1:00:00'
assert status_collection.jobs[status_id.key].time_limit == '2:00:00'
assert status_collection.jobs[status_id.key].nodes == 1
assert status_collection.jobs[status_id.key].cpus == 1
assert status_collection.jobs[status_id.key].partition == 'standard'
assert status_collection.jobs[status_id.key].name == 'test'

jobs_dir = tmp_path_factory.mktemp('jobs')
status_collection.write_results(str(jobs_dir))

new_status_collection = StatusCollection(
pd.read_csv(output_csv), jobs=str(jobs_dir / 'jobs.csv')
)
check_equal(status_collection, new_status_collection, 'jobs')

loaded_jobs = _load_jobs(str(jobs_dir / 'jobs.csv'))
assert len(loaded_jobs) == len(status_collection.jobs)

pd_loaded_jobs = _load_jobs(pd.read_csv(str(jobs_dir / 'jobs.csv')))
assert len(pd_loaded_jobs) == len(status_collection.jobs)


def test_update_status_collection_with_results(tmp_path_factory):
n_subjects = 10
n_sessions = 2
output_csv = str(tmp_path_factory.mktemp('inclusion') / 'inclusion.csv')
make_inclusion_df(n_subjects, n_sessions, output_csv)
status_collection = StatusCollection(output_csv)

# choose 4 random status_ids to add results to
changed_status_ids = random.sample(status_collection.status_ids, 4)
updated_keys = {status_id.key for status_id in changed_status_ids}
update_results = []
for status_id in changed_status_ids:
update_results.append(
ResultStatus(
sub_id=status_id.sub_id,
ses_id=status_id.ses_id,
result_location='branch',
is_failed=False,
has_results=True,
submitted=True,
)
)
status_collection.update_results(update_results)
for status_id in changed_status_ids:
assert status_collection.results[status_id.key].result_location == 'branch'
assert not status_collection.results[status_id.key].is_failed
assert status_collection.results[status_id.key].has_results

# Find some status_ids that should not have results
unchanged_status_ids = [
status_id
for status_id in status_collection.status_ids
if status_id.key not in updated_keys
]
for status_id in unchanged_status_ids:
assert not status_collection.results[status_id.key].has_results

results_dir = tmp_path_factory.mktemp('results')
status_collection.write_results(str(results_dir))

new_status_collection = StatusCollection(
pd.read_csv(output_csv), results=str(results_dir / 'results.csv')
)
check_equal(status_collection, new_status_collection, 'results')

loaded_results = _load_results(str(results_dir / 'results.csv'))
assert len(loaded_results) == len(status_collection.results)

pd_loaded_results = _load_results(pd.read_csv(str(results_dir / 'results.csv')))
assert len(pd_loaded_results) == len(status_collection.results)