Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion biasanalyzer/cohort.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from biasanalyzer.concept import ConceptHierarchy
from biasanalyzer.config import load_cohort_creation_config
from biasanalyzer.database import BiasDatabase, OMOPCDMDatabase
from biasanalyzer.models import CohortDefinition
from biasanalyzer.models import CohortDefinition, DOMAIN_MAPPING
from biasanalyzer.utils import clean_string, hellinger_distance, notify_users


Expand Down Expand Up @@ -59,6 +59,9 @@ def get_concept_stats(
"""
Get cohort concept statistics such as concept prevalence
"""
if concept_type not in DOMAIN_MAPPING:
raise ValueError(f'input concept_type {concept_type} is not a valid concept type to get concept stats')

cohort_stats = self.bias_db.get_cohort_concept_stats(
self.cohort_id,
self.query_builder,
Expand Down
16 changes: 9 additions & 7 deletions biasanalyzer/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def _create_cohort_definition_table(self):
def _create_cohort_table(self):
self.conn.execute(f"""
CREATE TABLE IF NOT EXISTS {self.schema}.cohort (
subject_id BIGINT,
subject_id VARCHAR NOT NULL,
cohort_definition_id INTEGER,
cohort_start_date DATE,
cohort_end_date DATE,
Expand Down Expand Up @@ -288,12 +288,14 @@ def get_cohort_concept_stats(
)
concept_stats[concept_type] = self._execute_query(query)
cs_df = pd.DataFrame(concept_stats[concept_type])
# Combine concept_name and prevalence into a "details" column
cs_df["details"] = cs_df.apply(
lambda row: f"{row['concept_name']} (Code: {row['concept_code']}, "
f"Count: {row['count_in_cohort']}, Prevalence: {row['prevalence']:.3%})",
axis=1,
)

if not cs_df.empty:
# Combine concept_name and prevalence into a "details" column
cs_df["details"] = cs_df.apply(
lambda row: f"{row['concept_name']} (Code: {row['concept_code']}, "
f"Count: {row['count_in_cohort']}, Prevalence: {row['prevalence']:.3%})",
axis=1,
)

if print_concept_hierarchy:
filtered_cs_df = cs_df[cs_df["ancestor_concept_id"] != cs_df["descendant_concept_id"]]
Expand Down
28 changes: 14 additions & 14 deletions tests/query_based/test_cohort_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,10 @@ def test_cohort_creation_baseline(caplog, test_db):

patient_ids = set([item["subject_id"] for item in cohort.data])
assert_equal(len(patient_ids), 5)
assert_equal(patient_ids, {106, 108, 110, 111, 112})
assert_equal(patient_ids, {'106', '108', '110', '111', '112'})
# select two patients to check for cohort_start_date and cohort_end_date automatically computed
patient_106 = next(item for item in cohort.data if item["subject_id"] == 106)
patient_108 = next(item for item in cohort.data if item["subject_id"] == 108)
patient_106 = next(item for item in cohort.data if item["subject_id"] == '106')
patient_108 = next(item for item in cohort.data if item["subject_id"] == '108')

# Replace dates with actual values from your test data
assert_equal(
Expand Down Expand Up @@ -127,7 +127,7 @@ def test_cohort_creation_study(test_db):
assert cohort.data is not None, "Cohort creation wrongly returned None data"
patient_ids = set([item["subject_id"] for item in cohort.data])
assert_equal(len(patient_ids), 4)
assert_equal(patient_ids, {108, 110, 111, 112})
assert_equal(patient_ids, {'108', '110', '111', '112'})


def test_cohort_creation_study2(caplog, test_db):
Expand Down Expand Up @@ -155,7 +155,7 @@ def test_cohort_creation_study2(caplog, test_db):
assert cohort.data is not None, "Cohort creation wrongly returned None data"
patient_ids = set([item["subject_id"] for item in cohort.data])
assert_equal(len(patient_ids), 1)
assert_equal(patient_ids, {106})
assert_equal(patient_ids, {'106'})


def test_cohort_creation_all(caplog, test_db):
Expand Down Expand Up @@ -191,7 +191,7 @@ def test_cohort_creation_all(caplog, test_db):
patient_ids = set([item["subject_id"] for item in cohort.data])
print(f"patient_ids: {patient_ids}", flush=True)
assert_equal(len(patient_ids), 2)
assert_equal(patient_ids, {108, 110})
assert_equal(patient_ids, {'108', '110'})


def test_cohort_creation_multiple_temporary_groups_with_no_operator(test_db):
Expand All @@ -214,7 +214,7 @@ def test_cohort_creation_multiple_temporary_groups_with_no_operator(test_db):
patient_ids = set([item["subject_id"] for item in cohort.data])
print(f"patient_ids: {patient_ids}", flush=True)
assert_equal(len(patient_ids), 2)
assert_equal(patient_ids, {108, 110})
assert_equal(patient_ids, {'108', '110'})


def test_cohort_creation_mixed_domains(test_db):
Expand Down Expand Up @@ -242,7 +242,7 @@ def test_cohort_creation_mixed_domains(test_db):
patient_ids = set([item["subject_id"] for item in cohort.data])
print(f"patient_ids: {patient_ids}", flush=True)
assert_equal(len(patient_ids), 3)
assert_equal(patient_ids, {1, 2, 6})
assert_equal(patient_ids, {'1', '2', '6'})
start_dates = [item["cohort_start_date"] for item in cohort.data]
assert_equal(len(start_dates), 3)
assert_equal(start_dates, [datetime.date(2020, 6, 1), datetime.date(2020, 6, 1), datetime.date(2018, 1, 1)])
Expand Down Expand Up @@ -356,10 +356,10 @@ def test_cohort_creation_negative_instance(test_db):

patient_ids = set([item["subject_id"] for item in cohort.data])
assert_equal(len(patient_ids), 6) # Female patients 1, 2, 3, 5
assert_equal(patient_ids, {1, 2, 3, 5, 6, 7})
assert_equal(patient_ids, {'1', '2', '3', '5', '6', '7'})

# Verify dates for a specific patient (e.g., patient 1 with last diabetes diagnosis)
patient_1 = next(item for item in cohort.data if item["subject_id"] == 1)
patient_1 = next(item for item in cohort.data if item["subject_id"] == '1')
assert_equal(
patient_1["cohort_start_date"],
datetime.date(2020, 6, 1),
Expand Down Expand Up @@ -392,10 +392,10 @@ def test_cohort_creation_offset(test_db):

patient_ids = set([item["subject_id"] for item in cohort.data])
assert_equal(len(patient_ids), 6) # Female patients 1, 2, 3, 5
assert_equal(patient_ids, {1, 2, 3, 5, 6, 7})
assert_equal(patient_ids, {'1', '2', '3', '5', '6', '7'})

# Verify dates for a specific patient (e.g., patient 1 with offset)
patient_1 = next(item for item in cohort.data if item["subject_id"] == 1)
patient_1 = next(item for item in cohort.data if item["subject_id"] == '1')
# Diabetes on 2020-06-01: -730 days = 2018-06-02, +180 days = 2020-11-28
assert_equal(
patient_1["cohort_start_date"],
Expand Down Expand Up @@ -435,10 +435,10 @@ def test_cohort_creation_negative_instance_offset(test_db):

patient_ids = set([item["subject_id"] for item in cohort.data])
assert_equal(len(patient_ids), 6)
assert_equal(patient_ids, {1, 2, 3, 5, 6, 7})
assert_equal(patient_ids, {'1', '2', '3', '5', '6', '7'})

# Verify dates for a specific patient (e.g., patient 1 with last diabetes and offset)
patient_1 = next(item for item in cohort.data if item["subject_id"] == 1)
patient_1 = next(item for item in cohort.data if item["subject_id"] == '1')
# Last diabetes on 2020-06-01: +180 days = 2020-11-28
assert_equal(
patient_1["cohort_start_date"],
Expand Down
6 changes: 4 additions & 2 deletions tests/query_based/test_hierarchical_prevalence.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import pytest
from numpy.ma.testutils import assert_equal

from biasanalyzer.concept import ConceptHierarchy


Expand All @@ -25,8 +27,8 @@ def test_cohort_concept_hierarchical_prevalence(test_db, caplog):
cohort.get_concept_stats(vocab="dummy_invalid_vocab")

# test the cohort does not have procedure_occurrence related concepts
with pytest.raises(ValueError):
cohort.get_concept_stats(concept_type="procedure_occurrence")
cohort_stat, _ = cohort.get_concept_stats(concept_type="procedure_occurrence")
assert_equal(cohort_stat, {'procedure_occurrence': []})

concept_stats, _ = cohort.get_concept_stats(vocab="ICD10CM", print_concept_hierarchy=True)
assert concept_stats is not None, "Failed to fetch concept stats"
Expand Down