diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 24c2fc5..a85b1c3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,4 +1,4 @@ -name: test +name: Tests and Lint on: # Trigger workflow on pull requests to the main branch @@ -11,32 +11,46 @@ on: - main jobs: - test: + tests: + name: Run Unit Tests and Lint runs-on: ubuntu-22.04 steps: # Step 1: Checkout the repository - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - # Step 2: Set up Python - - name: Set up Python - uses: actions/setup-python@v4 + # Step 3: Install Poetry + - name: Install Poetry + uses: snok/install-poetry@v1 with: - python-version: "3.8.18" # Use your project’s Python version + version: 2.2.1 + virtualenvs-create: true + virtualenvs-in-project: true - # Step 3: Install Poetry - - name: Install Poetry + - name: Verify Poetry install run: | - curl -sSL https://install.python-poetry.org | POETRY_VERSION=1.8.4 python3 - - echo "$HOME/.local/bin" >> $GITHUB_PATH + echo "PATH is: $PATH" + which poetry + poetry --version - # Step 4: Install dependencies + # Step 2: Set up Python + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" # Use project’s Python version + cache: "poetry" + + # Step 5: Install dependencies - name: Install dependencies run: | poetry install --with dev - # Step 5: Run Tests + # Step 6: Run Tests - name: Run tests run: | poetry run pytest -s --cov=biasanalyzer --cov-config=.coveragerc + + # Step 7: Run Ruff check + - name: Run ruff + run: poetry run ruff check . diff --git a/biasanalyzer/__init__.py b/biasanalyzer/__init__.py index b794fd4..3dc1f76 100644 --- a/biasanalyzer/__init__.py +++ b/biasanalyzer/__init__.py @@ -1 +1 @@ -__version__ = '0.1.0' +__version__ = "0.1.0" diff --git a/biasanalyzer/api.py b/biasanalyzer/api.py index 90a688e..57d947c 100644 --- a/biasanalyzer/api.py +++ b/biasanalyzer/api.py @@ -1,13 +1,15 @@ import time -from pydantic import ValidationError from typing import List -from biasanalyzer.database import OMOPCDMDatabase, BiasDatabase + +from IPython.display import display +from ipytree import Tree +from ipywidgets import Label, VBox +from pydantic import ValidationError + from biasanalyzer.cohort import CohortAction from biasanalyzer.config import load_config -from ipywidgets import VBox, Label -from ipytree import Tree -from IPython.display import display -from biasanalyzer.utils import get_direction_arrow, notify_users, build_concept_tree +from biasanalyzer.database import BiasDatabase, OMOPCDMDatabase +from biasanalyzer.utils import build_concept_tree, get_direction_arrow, notify_users class BIAS: @@ -22,48 +24,55 @@ def __init__(self, config_file_path=None): def set_config(self, config_file_path: str): if not config_file_path: - notify_users('no configuration file specified. ' - 'Call set_config(config_file_path) next to specify configurations') + notify_users( + "no configuration file specified. Call set_config(config_file_path) next to specify configurations" + ) else: try: self.config = load_config(config_file_path) - notify_users(f'configuration specified in {config_file_path} loaded successfully') + notify_users(f"configuration specified in {config_file_path} loaded successfully") except FileNotFoundError: - notify_users('specified configuration file does not exist. ' - 'Call set_config(config_file_path) next to specify a valid configuration file', - level='error') + notify_users( + "specified configuration file does not exist. 
" + "Call set_config(config_file_path) next to specify a valid configuration file", + level="error", + ) except ValidationError as ex: - notify_users(f'configuration yaml file is not valid with validation error: {ex}', level='error') + notify_users(f"configuration yaml file is not valid with validation error: {ex}", level="error") def set_root_omop(self): if not self.config: - notify_users('no valid configuration to set root OMOP CDM data. ' - 'Call set_config(config_file_path) to specify configurations first.') + notify_users( + "no valid configuration to set root OMOP CDM data. " + "Call set_config(config_file_path) to specify configurations first." + ) return self.cleanup() - db_type = self.config['root_omop_cdm_database']['database_type'] - if db_type == 'postgresql': - user = self.config['root_omop_cdm_database']['username'] - password = self.config['root_omop_cdm_database']['password'] - host = self.config['root_omop_cdm_database']['hostname'] - port = self.config['root_omop_cdm_database']['port'] - db = self.config['root_omop_cdm_database']['database'] + db_type = self.config["root_omop_cdm_database"]["database_type"] + if db_type == "postgresql": + user = self.config["root_omop_cdm_database"]["username"] + password = self.config["root_omop_cdm_database"]["password"] + host = self.config["root_omop_cdm_database"]["hostname"] + port = self.config["root_omop_cdm_database"]["port"] + db = self.config["root_omop_cdm_database"]["database"] db_url = f"postgresql://{user}:{password}@{host}:{port}/{db}" self.omop_cdm_db = OMOPCDMDatabase(db_url) - self.bias_db = BiasDatabase(':memory:', omop_db_url=db_url) - elif db_type == 'duckdb': - db_path = self.config['root_omop_cdm_database'].get('database', ":memory:") + self.bias_db = BiasDatabase(":memory:", omop_db_url=db_url) + elif db_type == "duckdb": + db_path = self.config["root_omop_cdm_database"].get("database", ":memory:") self.omop_cdm_db = OMOPCDMDatabase(db_path) - self.bias_db = BiasDatabase(':memory:', omop_db_url=db_path) + self.bias_db = BiasDatabase(":memory:", omop_db_url=db_path) else: notify_users(f"Unsupported database type: {db_type}") def _set_cohort_action(self): if self.omop_cdm_db is None: - notify_users('A valid OMOP CDM must be set before creating a cohort. ' - 'Call set_root_omop first to set a valid root OMOP CDM') + notify_users( + "A valid OMOP CDM must be set before creating a cohort. " + "Call set_root_omop first to set a valid root OMOP CDM" + ) return None if self.cohort_action is None: self.cohort_action = CohortAction(self.omop_cdm_db, self.bias_db) @@ -71,25 +80,31 @@ def _set_cohort_action(self): def get_domains_and_vocabularies(self): if self.omop_cdm_db is None: - notify_users('A valid OMOP CDM must be set before getting domains. ' - 'Call set_root_omop first to set a valid root OMOP CDM') + notify_users( + "A valid OMOP CDM must be set before getting domains. " + "Call set_root_omop first to set a valid root OMOP CDM" + ) return None return self.omop_cdm_db.get_domains_and_vocabularies() def get_concepts(self, search_term, domain=None, vocabulary=None): if self.omop_cdm_db is None: - notify_users('A valid OMOP CDM must be set before getting concepts. ' - 'Call set_root_omop first to set a valid root OMOP CDM') + notify_users( + "A valid OMOP CDM must be set before getting concepts. 
" + "Call set_root_omop first to set a valid root OMOP CDM" + ) return None if domain is None and vocabulary is None: - notify_users('either domain or vocabulary must be set to constrain the number of returned concepts') + notify_users("either domain or vocabulary must be set to constrain the number of returned concepts") return None return self.omop_cdm_db.get_concepts(search_term, domain, vocabulary) def get_concept_hierarchy(self, concept_id): if self.omop_cdm_db is None: - notify_users('A valid OMOP CDM must be set before getting concepts. ' - 'Call set_root_omop first to set a valid root OMOP CDM') + notify_users( + "A valid OMOP CDM must be set before getting concepts. " + "Call set_root_omop first to set a valid root OMOP CDM" + ) return None return self.omop_cdm_db.get_concept_hierarchy(concept_id) @@ -98,20 +113,21 @@ def display_concept_tree(self, concept_tree: dict, level: int = 0, show_in_text_ Recursively prints the concept hierarchy tree in an indented format for display. """ details = concept_tree.get("details", {}) - if 'parents' in concept_tree: - tree_type = 'parents' - elif 'children' in concept_tree: - tree_type = 'children' + if "parents" in concept_tree: + tree_type = "parents" + elif "children" in concept_tree: + tree_type = "children" else: - notify_users('The input concept tree must contain parents or children key as the type of the tree.') - return '' + notify_users("The input concept tree must contain parents or children key as the type of the tree.") + return "" if show_in_text_format: if details: direction_arrow = get_direction_arrow(tree_type) print( " " * level + f"{direction_arrow} {details['concept_name']} (ID: {details['concept_id']}, " - f"Code: {details['concept_code']})") + f"Code: {details['concept_code']})" + ) for child in concept_tree.get(tree_type, []): if child: @@ -128,9 +144,9 @@ def display_concept_tree(self, concept_tree: dict, level: int = 0, show_in_text_ display(VBox([Label("Concept Hierarchy"), tree])) return root_node - - def create_cohort(self, cohort_name: str, cohort_desc: str, query_or_yaml_file: str, created_by: str, - delay: float=0): + def create_cohort( + self, cohort_name: str, cohort_desc: str, query_or_yaml_file: str, created_by: str, delay: float = 0 + ): """ API method that allows to create a cohort :param cohort_name: name of the cohort @@ -149,17 +165,15 @@ def create_cohort(self, cohort_name: str, cohort_desc: str, query_or_yaml_file: if delay > 0: notify_users(f"[DEBUG] Simulating long-running task with {delay} seconds delay...") time.sleep(delay) - notify_users('cohort created successfully') + notify_users("cohort created successfully") return created_cohort else: - notify_users('failed to create a valid cohort action object') + notify_users("failed to create a valid cohort action object") return None - - def get_cohorts_concept_stats(self, cohorts: List[int], - concept_type: str='condition_occurrence', - filter_count: int=0, - vocab=None): + def get_cohorts_concept_stats( + self, cohorts: List[int], concept_type: str = "condition_occurrence", filter_count: int = 0, vocab=None + ): """ compute concept statistics such as concept prevalence in a union of multiple cohorts :param cohorts: list of cohort ids @@ -170,26 +184,25 @@ def get_cohorts_concept_stats(self, cohorts: List[int], :return: ConceptHierarchy object """ if not cohorts: - notify_users('The input cohorts list is empty. At least one cohort id must be provided.') + notify_users("The input cohorts list is empty. 
At least one cohort id must be provided.") return None c_action = self._set_cohort_action() if c_action: - return c_action.get_cohorts_concept_stats(cohorts, concept_type=concept_type, filter_count=filter_count, - vocab=vocab) + return c_action.get_cohorts_concept_stats( + cohorts, concept_type=concept_type, filter_count=filter_count, vocab=vocab + ) else: - notify_users('failed to get concept prevalence stats for the union of cohorts') + notify_users("failed to get concept prevalence stats for the union of cohorts") return None - def compare_cohorts(self, cohort_id1, cohort_id2): c_action = self._set_cohort_action() if c_action: return c_action.compare_cohorts(cohort_id1, cohort_id2) else: - notify_users('failed to create a valid cohort action object') + notify_users("failed to create a valid cohort action object") return None - def cleanup(self): if self.bias_db: self.bias_db.close() diff --git a/biasanalyzer/background/threading_utils.py b/biasanalyzer/background/threading_utils.py index b65198f..de89671 100644 --- a/biasanalyzer/background/threading_utils.py +++ b/biasanalyzer/background/threading_utils.py @@ -1,6 +1,7 @@ import threading import traceback + class BackgroundResult: def __init__(self): self.value = None @@ -12,6 +13,7 @@ def set(self, result, error=None): self.error = error self.ready = True + def run_in_background(func, *args, result_holder=None, on_complete=None, **kwargs): """ Run a time-consuming function in background @@ -22,6 +24,7 @@ def run_in_background(func, *args, result_holder=None, on_complete=None, **kwarg :param kwargs: any keyword arguments of the function to be passed in as a dict :return: a background thread """ + def wrapper(): try: print("[*] Background task started...", flush=True) diff --git a/biasanalyzer/cohort.py b/biasanalyzer/cohort.py index 00166d7..4d5af20 100644 --- a/biasanalyzer/cohort.py +++ b/biasanalyzer/cohort.py @@ -1,15 +1,17 @@ +from datetime import datetime from functools import reduce +from typing import List + import pandas as pd -from datetime import datetime -from tqdm.auto import tqdm from pydantic import ValidationError -from typing import List -from biasanalyzer.models import CohortDefinition -from biasanalyzer.config import load_cohort_creation_config -from biasanalyzer.database import OMOPCDMDatabase, BiasDatabase -from biasanalyzer.utils import hellinger_distance, clean_string, notify_users +from tqdm.auto import tqdm + from biasanalyzer.cohort_query_builder import CohortQueryBuilder from biasanalyzer.concept import ConceptHierarchy +from biasanalyzer.config import load_cohort_creation_config +from biasanalyzer.database import BiasDatabase, OMOPCDMDatabase +from biasanalyzer.models import CohortDefinition +from biasanalyzer.utils import clean_string, hellinger_distance, notify_users class CohortData: @@ -17,7 +19,7 @@ def __init__(self, cohort_id: int, bias_db: BiasDatabase, omop_db: OMOPCDMDataba self.cohort_id = cohort_id self.bias_db = bias_db self.omop_db = omop_db - self._cohort_data = None # cache the cohort data + self._cohort_data = None # cache the cohort data self._metadata = None self.query_builder = CohortQueryBuilder(cohort_creation=False) @@ -37,7 +39,7 @@ def metadata(self): self._metadata = self.bias_db.get_cohort_definition(self.cohort_id) return self._metadata - def get_stats(self, variable=''): + def get_stats(self, variable=""): """ Get aggregation statistics for the cohort in BiasDatabase. variable is optional with a default empty string. 
Supported variables are: age, gender, @@ -51,23 +53,26 @@ def get_distributions(self, variable): """ return self.bias_db.get_cohort_distributions(self.cohort_id, variable) - def get_concept_stats(self, concept_type='condition_occurrence', filter_count=0, - vocab=None, print_concept_hierarchy=False): + def get_concept_stats( + self, concept_type="condition_occurrence", filter_count=0, vocab=None, print_concept_hierarchy=False + ): """ Get cohort concept statistics such as concept prevalence """ - cohort_stats = self.bias_db.get_cohort_concept_stats(self.cohort_id, - self.query_builder, - concept_type=concept_type, - filter_count=filter_count, - vocab=vocab, - print_concept_hierarchy=print_concept_hierarchy) - return (cohort_stats, - ConceptHierarchy.build_concept_hierarchy_from_results(self.cohort_id, concept_type, - cohort_stats[concept_type], - filter_count=filter_count, - vocab=vocab)) - + cohort_stats = self.bias_db.get_cohort_concept_stats( + self.cohort_id, + self.query_builder, + concept_type=concept_type, + filter_count=filter_count, + vocab=vocab, + print_concept_hierarchy=print_concept_hierarchy, + ) + return ( + cohort_stats, + ConceptHierarchy.build_concept_hierarchy_from_results( + self.cohort_id, concept_type, cohort_stats[concept_type], filter_count=filter_count, vocab=vocab + ), + ) def __del__(self): self._cohort_data = None @@ -80,8 +85,7 @@ def __init__(self, omop_db: OMOPCDMDatabase, bias_db: BiasDatabase): self.bias_db = bias_db self._query_builder = CohortQueryBuilder() - def create_cohort(self, cohort_name: str, description: str, query_or_yaml_file: str, - created_by: str): + def create_cohort(self, cohort_name: str, description: str, query_or_yaml_file: str, created_by: str): """ Create a new cohort by executing a query on OMOP CDM database and storing the result in BiasDatabase. The query can be passed in directly @@ -96,21 +100,23 @@ def create_cohort(self, cohort_name: str, description: str, query_or_yaml_file: stages = [ "Built query", "Executed query on OMOP database to get cohort data", - "Inserted cohort data into DuckDB - Done" + "Inserted cohort data into DuckDB - Done", ] progress = tqdm(total=len(stages), desc="Cohort creation", unit="stage", dynamic_ncols=True, leave=True) progress.set_postfix_str(stages[0]) - if query_or_yaml_file.endswith('.yaml') or query_or_yaml_file.endswith('.yml'): + if query_or_yaml_file.endswith(".yaml") or query_or_yaml_file.endswith(".yml"): try: cohort_config = load_cohort_creation_config(query_or_yaml_file) - tqdm.write(f'configuration specified in {query_or_yaml_file} loaded successfully') + tqdm.write(f"configuration specified in {query_or_yaml_file} loaded successfully") except FileNotFoundError: - notify_users('specified cohort creation configuration file does not exist. Make sure ' - 'the configuration file name with path is specified correctly.') + notify_users( + "specified cohort creation configuration file does not exist. Make sure " + "the configuration file name with path is specified correctly." 
+ ) return None except ValidationError as ex: - notify_users(f'cohort creation configuration yaml file is not valid with validation error: {ex}') + notify_users(f"cohort creation configuration yaml file is not valid with validation error: {ex}") return None query = self._query_builder.build_query_cohort_creation(cohort_config) @@ -130,7 +136,7 @@ def create_cohort(self, cohort_name: str, description: str, query_or_yaml_file: description=description, created_date=datetime.now().date(), creation_info=clean_string(query), - created_by=created_by + created_by=created_by, ) cohort_def_id = self.bias_db.create_cohort_definition(cohort_def, progress_obj=tqdm) progress.update(1) @@ -138,7 +144,7 @@ def create_cohort(self, cohort_name: str, description: str, query_or_yaml_file: progress.set_postfix_str(stages[2]) # Store cohort_definition and cohort data into BiasDatabase cohort_df = pd.DataFrame(result) - cohort_df['cohort_definition_id'] = cohort_def_id + cohort_df["cohort_definition_id"] = cohort_def_id cohort_df = cohort_df.rename(columns={"person_id": "subject_id"}) self.bias_db.create_cohort_in_bulk(cohort_df) progress.update(1) @@ -147,7 +153,7 @@ def create_cohort(self, cohort_name: str, description: str, query_or_yaml_file: return CohortData(cohort_id=cohort_def_id, bias_db=self.bias_db, omop_db=self.omop_db) else: progress.update(2) - notify_users(f"No cohort is created due to empty results being returned from query") + notify_users("No cohort is created due to empty results being returned from query") return None except Exception as e: progress.update(2) @@ -156,20 +162,21 @@ def create_cohort(self, cohort_name: str, description: str, query_or_yaml_file: omop_session.close() return None - def get_cohorts_concept_stats(self, cohorts: List[int], - concept_type: str = 'condition_occurrence', - filter_count: int = 0, - vocab=None): - cohort_concept_stats = [self.bias_db.get_cohort_concept_stats(c, self._query_builder, - concept_type=concept_type, - filter_count=filter_count, - vocab=vocab) - for c in cohorts] - hierarchies = [ConceptHierarchy.build_concept_hierarchy_from_results(c, concept_type, - c_stats.get(concept_type, []), - filter_count=filter_count, - vocab=vocab) - for c, c_stats in zip(cohorts, cohort_concept_stats)] + def get_cohorts_concept_stats( + self, cohorts: List[int], concept_type: str = "condition_occurrence", filter_count: int = 0, vocab=None + ): + cohort_concept_stats = [ + self.bias_db.get_cohort_concept_stats( + c, self._query_builder, concept_type=concept_type, filter_count=filter_count, vocab=vocab + ) + for c in cohorts + ] + hierarchies = [ + ConceptHierarchy.build_concept_hierarchy_from_results( + c, concept_type, c_stats.get(concept_type, []), filter_count=filter_count, vocab=vocab + ) + for c, c_stats in zip(cohorts, cohort_concept_stats) + ] return reduce(lambda h1, h2: h1.union(h2), hierarchies).to_dict() def compare_cohorts(self, cohort_id_1: int, cohort_id_2: int): @@ -180,11 +187,9 @@ def compare_cohorts(self, cohort_id_1: int, cohort_id_2: int): for variable in self.bias_db.cohort_distribution_variables: cohort_1_stats = self.bias_db.get_cohort_distributions(cohort_id_1, variable=variable) cohort_2_stats = self.bias_db.get_cohort_distributions(cohort_id_2, variable=variable) - cohort_1_probs = [entry['probability'] for entry in cohort_1_stats] - cohort_2_probs = [entry['probability'] for entry in cohort_2_stats] + cohort_1_probs = [entry["probability"] for entry in cohort_1_stats] + cohort_2_probs = [entry["probability"] for entry in cohort_2_stats] dist 
= hellinger_distance(cohort_1_probs, cohort_2_probs) - results.append({ - f'{variable}_hellinger_distance': dist - }) + results.append({f"{variable}_hellinger_distance": dist}) return results diff --git a/biasanalyzer/cohort_query_builder.py b/biasanalyzer/cohort_query_builder.py index 6c3f82a..7f93f79 100644 --- a/biasanalyzer/cohort_query_builder.py +++ b/biasanalyzer/cohort_query_builder.py @@ -1,30 +1,33 @@ +# ruff: noqa: S608 +import importlib.resources import os import sys -import importlib.resources -from biasanalyzer.models import TemporalEventGroup, DOMAIN_MAPPING + from jinja2 import Environment, FileSystemLoader +from biasanalyzer.models import DOMAIN_MAPPING, TemporalEventGroup + class CohortQueryBuilder: def __init__(self, cohort_creation=True): """Get the path to SQL templates, whether running from source or installed.""" try: - if sys.version_info >= (3, 9): # pragma: no cover + if sys.version_info >= (3, 9): # pragma: no cover # Python 3.9+: Use importlib.resources.files() template_path = importlib.resources.files("biasanalyzer").joinpath("sql_templates") - else: + else: # pragma: no cover # Python 3.8: Use importlib.resources.path() (context manager) with importlib.resources.path("biasanalyzer", "sql_templates") as p: template_path = str(p) - except ModuleNotFoundError: # pragma: no cover + except ModuleNotFoundError: # pragma: no cover template_path = os.path.join(os.path.dirname(__file__), "sql_templates") - print(f'template_path: {template_path}') - self.env = Environment(loader=FileSystemLoader(template_path), extensions=['jinja2.ext.do']) + print(f"template_path: {template_path}") + self.env = Environment(loader=FileSystemLoader(template_path), extensions=["jinja2.ext.do"]) if cohort_creation: self.env.globals.update( - demographics_filter=self._load_macro('demographics_filter'), - temporal_event_filter=self.temporal_event_filter + demographics_filter=self._load_macro("demographics_filter"), + temporal_event_filter=self.temporal_event_filter, ) def _extract_domains(self, events): @@ -40,7 +43,7 @@ def _load_macro(self, macro_name): """ Load a macro from macros.sql.j2 into the Jinja2 environment. """ - macros_template = self.env.get_template('macros.sql.j2') + macros_template = self.env.get_template("macros.sql.j2") return macros_template.module.__dict__[macro_name] def build_query_cohort_creation(self, cohort_config: dict) -> str: @@ -49,30 +52,31 @@ def build_query_cohort_creation(self, cohort_config: dict) -> str: :param cohort_config: dict object loaded from yaml file for building sql query. :return: The rendered SQL query. 
""" - inclusion_criteria = cohort_config.get('inclusion_criteria') - exclusion_criteria = cohort_config.get('exclusion_criteria', {}) + inclusion_criteria = cohort_config.get("inclusion_criteria") + exclusion_criteria = cohort_config.get("exclusion_criteria", {}) inclusion_events = inclusion_criteria.get("temporal_events", []) exclusion_events = exclusion_criteria.get("temporal_events", []) - temporal_events = bool(inclusion_events) # Only inclusion_events matter for cohort dates + temporal_events = bool(inclusion_events) # Only inclusion_events matter for cohort dates all_domains = self._extract_domains(inclusion_events + exclusion_events) # Filter DOMAIN_MAPPING to exclude domains with table: None - valid_domains = {k: v for k, v in DOMAIN_MAPPING.items() if v.get('table')} + valid_domains = {k: v for k, v in DOMAIN_MAPPING.items() if v.get("table")} ranked_domains = {dt: valid_domains[dt] for dt in all_domains if dt in valid_domains} if not temporal_events: # For demographic only inclusion criteria, filter DOMAIN_MAPPING to exclude domains with table: None ranked_domains = valid_domains - template = self.env.get_template(f"cohort_creation_query.sql.j2") + template = self.env.get_template("cohort_creation_query.sql.j2") return template.render( inclusion_criteria=inclusion_criteria, exclusion_criteria=exclusion_criteria, ranked_domains=ranked_domains, - temporal_events=temporal_events + temporal_events=temporal_events, ) - def build_concept_prevalence_query(self, db_schema: str, omop_alias: str, concept_type: str, cid: int, - filter_count: int, vocab: str) -> str: + def build_concept_prevalence_query( + self, db_schema: str, omop_alias: str, concept_type: str, cid: int, filter_count: int, vocab: str + ) -> str: """ Build a SQL query for concept prevalence statistics for a given domain and cohort. :param db_schema: BiasDatabase database schema under which all tables are stored. @@ -103,7 +107,7 @@ def build_concept_prevalence_query(self, db_schema: str, omop_alias: str, concep start_date_column=DOMAIN_MAPPING[concept_type]["start_date"], cid=cid, filter_count=filter_count, - vocab=effective_vocab + vocab=effective_vocab, ) @staticmethod @@ -149,12 +153,11 @@ def render_event(event): {adjusted_start} AS adjusted_start, {adjusted_end} AS adjusted_end FROM {rank_table} - WHERE concept_id = {event['event_concept_id']}{instance_condition} + WHERE concept_id = {event["event_concept_id"]}{instance_condition} """ return base_sql - @staticmethod def render_event_group(event_group, alias_prefix="evt"): """ @@ -166,7 +169,7 @@ def render_event_group(event_group, alias_prefix="evt"): Returns: str: SQL query string for the event group. 
""" - queries = [] # accumulate SQL queries when called recursively with nested event groups + queries = [] # accumulate SQL queries when called recursively with nested event groups if "events" not in event_group: # Single event return CohortQueryBuilder.render_event(event_group) else: @@ -174,7 +177,7 @@ def render_event_group(event_group, alias_prefix="evt"): event_sql = CohortQueryBuilder.render_event_group(event, f"{alias_prefix}_{i}") if event_sql: queries.append(event_sql) - if not queries: # pragma: no cover + if not queries: # pragma: no cover return "" if event_group["operator"] == "AND": @@ -196,7 +199,7 @@ def render_event_group(event_group, alias_prefix="evt"): combined_sql = f""" SELECT person_id, event_start_date, event_end_date, adjusted_start, adjusted_end FROM ( - {' UNION ALL '.join(f'({q})' for q in queries)} + {" UNION ALL ".join(f"({q})" for q in queries)} ) AS all_events WHERE person_id IN ( {person_id_sql} @@ -205,8 +208,10 @@ def render_event_group(event_group, alias_prefix="evt"): return combined_sql elif event_group["operator"] == "OR": - return (f"SELECT person_id, event_start_date, event_end_date, adjusted_start, adjusted_end " - f"FROM ({' UNION '.join(queries)}) AS {alias_prefix}_or") + return ( + f"SELECT person_id, event_start_date, event_end_date, adjusted_start, adjusted_end " + f"FROM ({' UNION '.join(queries)}) AS {alias_prefix}_or" + ) elif event_group["operator"] == "NOT": not_query = queries[0] # Return a query that selects all persons from a base table (e.g., person), @@ -223,12 +228,12 @@ def render_event_group(event_group, alias_prefix="evt"): if len(queries) == 1: # the other query is the timestamp event which has to be handled here as it depends on the other # event in the BEFORE operator - timestamp_event = next((e for e in event_group['events'] if e["event_type"] == "date"), None) - non_timestamp_event = next((e for e in event_group['events'] if e["event_type"] != "date"), None) + timestamp_event = next((e for e in event_group["events"] if e["event_type"] == "date"), None) + non_timestamp_event = next((e for e in event_group["events"] if e["event_type"] != "date"), None) if timestamp_event and non_timestamp_event: timestamp = timestamp_event["timestamp"] - timestamp_event_index = event_group['events'].index(timestamp_event) - non_timestamp_event_index = event_group['events'].index(non_timestamp_event) + timestamp_event_index = event_group["events"].index(timestamp_event) + non_timestamp_event_index = event_group["events"].index(non_timestamp_event) if timestamp_event_index < non_timestamp_event_index: # timestamp needs to happen before non-timestamp event return f""" @@ -259,7 +264,8 @@ def render_event_group(event_group, alias_prefix="evt"): AND {e1_alias}.event_start_date < {e2_alias}.event_start_date {interval_sql} UNION ALL - SELECT {e2_alias}.person_id, {e2_alias}.event_start_date, {e2_alias}.event_end_date, + SELECT {e2_alias}.person_id, {e2_alias}.event_start_date, + {e2_alias}.event_end_date, {e2_alias}.adjusted_start, {e2_alias}.adjusted_end FROM ({queries[1]}) AS {e2_alias} JOIN ({queries[0]}) AS {e1_alias} @@ -269,7 +275,7 @@ def render_event_group(event_group, alias_prefix="evt"): """ return "" # pragma: no cover - def temporal_event_filter(self, event_groups, alias='c'): + def temporal_event_filter(self, event_groups, alias="c"): """ Generates the SQL filter for temporal event criteria. 
@@ -285,14 +291,14 @@ def temporal_event_filter(self, event_groups, alias='c'): for i, event_group in enumerate(event_groups): group_sql = self.render_event_group(event_group) if group_sql: - if alias == 'ex': + if alias == "ex": # exclusion criteria filters.append(f"AND {alias}.person_id IN (SELECT person_id FROM ({group_sql}) AS ex_subquery_{i})") else: filters.append(f"({group_sql})") if not filters: # pragma: no cover return "" - if alias == 'ex': + if alias == "ex": # For exclusion, combine with AND as filters return " ".join(filters) else: @@ -312,9 +318,11 @@ def temporal_event_filter(self, event_groups, alias='c'): # events: # - event_type: drug_exposure # event_concept_id: 67890 - return (f"SELECT person_id, event_start_date, event_end_date, " - f"adjusted_start, adjusted_end FROM " - f"({' UNION ALL '.join(filters)}) AS combined_events") + return ( + f"SELECT person_id, event_start_date, event_end_date, " + f"adjusted_start, adjusted_end FROM " + f"({' UNION ALL '.join(filters)}) AS combined_events" + ) # Single event group case with operator defined return filters[0] diff --git a/biasanalyzer/concept.py b/biasanalyzer/concept.py index 1c98504..5d86989 100644 --- a/biasanalyzer/concept.py +++ b/biasanalyzer/concept.py @@ -1,6 +1,7 @@ -import networkx as nx -from typing import List, Optional, Union from _collections import deque +from typing import List, Optional, Union + +import networkx as nx class ConceptNode: @@ -61,8 +62,9 @@ def to_dict(self, include_children: bool = True, include_union_metrics: bool = F "parent_ids": list(self._ch.graph.predecessors(self.id)), } if include_children: - data["children"] = [c.to_dict(include_children=True, include_union_metrics=include_union_metrics) - for c in self.children] + data["children"] = [ + c.to_dict(include_children=True, include_union_metrics=include_union_metrics) for c in self.children + ] return data @@ -86,8 +88,9 @@ def _normalize_identifier(identifier: str) -> str: return "+".join(parts) @classmethod - def build_concept_hierarchy_from_results(cls, cohort_id: int, concept_type: str, results: List[dict], - filter_count=0, vocab=None): + def build_concept_hierarchy_from_results( + cls, cohort_id: int, concept_type: str, results: List[dict], filter_count=0, vocab=None + ): """ build concept hierarchy tree managed by networkx from list of dicts returned from the concept prevalence SQL with cache management. 
cohort_id, concept_type, and filter_count are used for caching to uniquely identify @@ -147,20 +150,19 @@ def get_root_nodes(self, serialization: bool = False) -> List: roots = [n for n in self.graph.nodes if self.graph.in_degree(n) == 0] root_nodes = [ConceptNode(r, self) for r in roots] if serialization: - return [rn.to_dict(include_children=False) for rn in root_nodes] + return [rn.to_dict(include_children=False) for rn in root_nodes] else: return root_nodes def get_leaf_nodes(self, serialization: bool = False) -> List: leaves = [n for n in self.graph.nodes if self.graph.out_degree(n) == 0] - leave_nodes = [ConceptNode(l, self) for l in leaves] + leave_nodes = [ConceptNode(lv, self) for lv in leaves] if serialization: return [ln.to_dict(include_children=False) for ln in leave_nodes] else: return leave_nodes - def iter_nodes(self, root_id: int, order: str = "bfs", - serialization: bool = False): + def iter_nodes(self, root_id: int, order: str = "bfs", serialization: bool = False): """Iterate nodes in BFS or DFS order from a given root.""" if root_id not in self.graph.nodes: raise ValueError(f"Root node {root_id} not found in graph.") @@ -187,9 +189,7 @@ def iter_nodes(self, root_id: int, order: str = "bfs", raise ValueError("order must be 'bfs' or 'dfs'") def union(self, other: "ConceptHierarchy") -> "ConceptHierarchy": - new_ident = ConceptHierarchy._normalize_identifier( - f"{self.identifier}+{other.identifier}" - ) + new_ident = ConceptHierarchy._normalize_identifier(f"{self.identifier}+{other.identifier}") if new_ident in ConceptHierarchy._graph_cache: return ConceptHierarchy._graph_cache[new_ident] @@ -217,8 +217,17 @@ def to_dict(self, root_id: Optional[int] = None, include_union_metrics: bool = F if root_id is not None: if root_id not in self.graph: raise ValueError(f"Input concept id {root_id} not found in the concept hierarchy graph") - return {"hierarchy": [ConceptNode(root_id, self).to_dict(include_children=True, - include_union_metrics=include_union_metrics)]} + return { + "hierarchy": [ + ConceptNode(root_id, self).to_dict( + include_children=True, include_union_metrics=include_union_metrics + ) + ] + } - return {"hierarchy": [r.to_dict(include_children=True, include_union_metrics=include_union_metrics) - for r in self.get_root_nodes()]} + return { + "hierarchy": [ + r.to_dict(include_children=True, include_union_metrics=include_union_metrics) + for r in self.get_root_nodes() + ] + } diff --git a/biasanalyzer/config.py b/biasanalyzer/config.py index 141a047..7acd211 100644 --- a/biasanalyzer/config.py +++ b/biasanalyzer/config.py @@ -1,5 +1,6 @@ import yaml -from biasanalyzer.models import Configuration, CohortCreationConfig + +from biasanalyzer.models import CohortCreationConfig, Configuration def load_config(config_file): @@ -10,7 +11,7 @@ def load_config(config_file): def load_cohort_creation_config(config_file): - with open(config_file, encoding='utf-8') as f: + with open(config_file, encoding="utf-8") as f: config = yaml.safe_load(f) CohortCreationConfig(**config) return config diff --git a/biasanalyzer/database.py b/biasanalyzer/database.py index b40674e..cb449ab 100644 --- a/biasanalyzer/database.py +++ b/biasanalyzer/database.py @@ -1,16 +1,25 @@ -import duckdb -import pandas as pd +# ruff: noqa: S608 import gc -from typing import Optional from datetime import datetime -from tqdm.auto import tqdm -from sqlalchemy.orm import sessionmaker -from sqlalchemy.exc import SQLAlchemyError +from typing import Optional + +import duckdb +import pandas as pd from sqlalchemy 
import create_engine, text +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import sessionmaker +from tqdm.auto import tqdm + from biasanalyzer.models import CohortDefinition -from biasanalyzer.sql import (AGE_DISTRIBUTION_QUERY, GENDER_DISTRIBUTION_QUERY, AGE_STATS_QUERY, - GENDER_STATS_QUERY, RACE_STATS_QUERY, ETHNICITY_STATS_QUERY) -from biasanalyzer.utils import build_concept_hierarchy, print_hierarchy, find_roots, notify_users +from biasanalyzer.sql import ( + AGE_DISTRIBUTION_QUERY, + AGE_STATS_QUERY, + ETHNICITY_STATS_QUERY, + GENDER_DISTRIBUTION_QUERY, + GENDER_STATS_QUERY, + RACE_STATS_QUERY, +) +from biasanalyzer.utils import build_concept_hierarchy, find_roots, notify_users, print_hierarchy class BiasDatabase: @@ -22,12 +31,13 @@ class BiasDatabase: "age": AGE_STATS_QUERY, "gender": GENDER_STATS_QUERY, "race": RACE_STATS_QUERY, - "ethnicity": ETHNICITY_STATS_QUERY + "ethnicity": ETHNICITY_STATS_QUERY, } _instance = None # indicating a singleton with only one instance of the class ever created + def __new__(cls, *args, **kwargs): if cls._instance is None: - cls._instance = super(BiasDatabase, cls).__new__(cls) + cls._instance = super().__new__(cls) cls._instance._initialize(*args, **kwargs) # Initialize only once return cls._instance @@ -35,17 +45,17 @@ def _initialize(self, db_url, omop_db_url=None): # by default, duckdb uses in memory database self.conn = duckdb.connect(db_url) self.schema = "biasanalyzer" - self.omop_alias = 'omop' + self.omop_alias = "omop" self.conn.execute(f"CREATE SCHEMA IF NOT EXISTS {self.schema}") self.omop_cdm_db_url = omop_db_url if omop_db_url is not None: - if omop_db_url.startswith('postgresql://'): + if omop_db_url.startswith("postgresql://"): # omop db is postgreSQL self.load_postgres_extension() self.conn.execute(f""" ATTACH '{self.omop_cdm_db_url}' as {self.omop_alias} (TYPE postgres) """) - elif omop_db_url.endswith('.duckdb'): + elif omop_db_url.endswith(".duckdb"): self.conn.execute(f""" ATTACH '{self.omop_cdm_db_url}' as {self.omop_alias} """) @@ -60,13 +70,13 @@ def _initialize(self, db_url, omop_db_url=None): def _create_cohort_definition_table(self): try: - self.conn.execute(f'CREATE SEQUENCE {self.schema}.id_sequence START 1') + self.conn.execute(f"CREATE SEQUENCE {self.schema}.id_sequence START 1") except duckdb.Error as e: if "already exists" in str(e).lower(): notify_users("Sequence already exists, skipping creation.") else: raise - self.conn.execute(f''' + self.conn.execute(f""" CREATE TABLE IF NOT EXISTS {self.schema}.cohort_definition ( id INTEGER DEFAULT nextval('{self.schema}.id_sequence'), name VARCHAR NOT NULL, @@ -76,11 +86,11 @@ def _create_cohort_definition_table(self): created_by VARCHAR, PRIMARY KEY (id) ) - ''') + """) notify_users("Cohort Definition table created.") def _create_cohort_table(self): - self.conn.execute(f''' + self.conn.execute(f""" CREATE TABLE IF NOT EXISTS {self.schema}.cohort ( subject_id BIGINT, cohort_definition_id INTEGER, @@ -88,12 +98,12 @@ def _create_cohort_table(self): cohort_end_date DATE, FOREIGN KEY (cohort_definition_id) REFERENCES {self.schema}.cohort_definition(id) ) - ''') + """) try: - self.conn.execute(f''' + self.conn.execute(f""" CREATE INDEX idx_cohort_dates ON {self.schema}.cohort (cohort_definition_id, cohort_start_date, cohort_end_date); - ''') + """) except duckdb.Error as e: if "already exists" in str(e).lower(): notify_users("Index already exists, skipping creation.") @@ -106,16 +116,19 @@ def load_postgres_extension(self): self.conn.execute("LOAD postgres;") 
def create_cohort_definition(self, cohort_definition: CohortDefinition, progress_obj=None): - self.conn.execute(f''' + self.conn.execute( + f""" INSERT INTO {self.schema}.cohort_definition (name, description, created_date, creation_info, created_by) VALUES (?, ?, ?, ?, ?) - ''', ( - cohort_definition.name, - cohort_definition.description, - cohort_definition.created_date or datetime.now(), - cohort_definition.creation_info, - cohort_definition.created_by - )) + """, + ( + cohort_definition.name, + cohort_definition.description, + cohort_definition.created_date or datetime.now(), + cohort_definition.creation_info, + cohort_definition.created_by, + ), + ) if progress_obj is None: notify_users("Cohort definition inserted successfully.") # pragma: no cover else: @@ -128,16 +141,16 @@ def create_cohort_definition(self, cohort_definition: CohortDefinition, progress def create_cohort_in_bulk(self, cohort_df: pd.DataFrame): # make duckdb to treat cohort_df dataframe as a virtual table named "cohort_df" self.conn.register("cohort_df", cohort_df) - self.conn.execute(f''' + self.conn.execute(f""" INSERT INTO {self.schema}.cohort (subject_id, cohort_definition_id, cohort_start_date, cohort_end_date) SELECT subject_id, cohort_definition_id, cohort_start_date, cohort_end_date FROM cohort_df - ''') + """) def get_cohort_definition(self, cohort_definition_id): - results = self.conn.execute(f''' + results = self.conn.execute(f""" SELECT id, name, description, created_date, creation_info, created_by FROM {self.schema}.cohort_definition WHERE id = {cohort_definition_id} - ''') + """) headers = [desc[0] for desc in results.description] row = results.fetchall() if len(row) == 0: @@ -146,10 +159,10 @@ def get_cohort_definition(self, cohort_definition_id): return dict(zip(headers, row[0])) def get_cohort(self, cohort_definition_id): - results = self.conn.execute(f''' + results = self.conn.execute(f""" SELECT subject_id, cohort_definition_id, cohort_start_date, cohort_end_date FROM {self.schema}.cohort WHERE cohort_definition_id = {cohort_definition_id} - ''') + """) headers = [desc[0] for desc in results.description] rows = results.fetchall() return [dict(zip(headers, row)) for row in rows] @@ -161,7 +174,7 @@ def _execute_query(self, query_str): rows = results.fetchall() return [dict(zip(headers, row)) for row in rows] - def get_cohort_basic_stats(self, cohort_definition_id: int, variable=''): + def get_cohort_basic_stats(self, cohort_definition_id: int, variable=""): """ Get aggregation statistics for a cohort from the cohort table. :param cohort_definition_id: cohort definition id representing the cohort @@ -174,12 +187,16 @@ def get_cohort_basic_stats(self, cohort_definition_id: int, variable=''): if variable: query_str = self.__class__.stats_queries.get(variable) if query_str is None: - raise ValueError(f"Statistics for variable '{variable}' is not available. " - f"Valid variables are {self.__class__.stats_queries.keys()}") - stats_query = query_str.format(ba_schema=self.schema, omop=self.omop_alias, cohort_definition_id=cohort_definition_id) + raise ValueError( + f"Statistics for variable '{variable}' is not available. 
" + f"Valid variables are {self.__class__.stats_queries.keys()}" + ) + stats_query = query_str.format( + ba_schema=self.schema, omop=self.omop_alias, cohort_definition_id=cohort_definition_id + ) else: # Query the cohort data to get basic statistics - stats_query = f''' + stats_query = f""" WITH cohort_Duration AS ( SELECT subject_id, @@ -202,11 +219,11 @@ def get_cohort_basic_stats(self, cohort_definition_id: int, variable=''): CAST(PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY duration_days) AS INT) AS median_duration, ROUND(STDDEV(duration_days), 2) AS stddev_duration FROM cohort_Duration - ''' + """ return self._execute_query(stats_query) except Exception as e: - notify_users(f"Error computing cohort basic statistics: {e}", level='error') + notify_users(f"Error computing cohort basic statistics: {e}", level="error") return None @property @@ -220,18 +237,27 @@ def get_cohort_distributions(self, cohort_definition_id: int, variable: str): try: query_str = self.__class__.distribution_queries.get(variable) if not query_str: - raise ValueError(f"Distribution for variable '{variable}' is not available. " - f"Valid variables are {self.__class__.distribution_queries.keys()}") - query = query_str.format(ba_schema=self.schema, omop=self.omop_alias, - cohort_definition_id=cohort_definition_id) + raise ValueError( + f"Distribution for variable '{variable}' is not available. " + f"Valid variables are {self.__class__.distribution_queries.keys()}" + ) + query = query_str.format( + ba_schema=self.schema, omop=self.omop_alias, cohort_definition_id=cohort_definition_id + ) return self._execute_query(query) except Exception as e: - notify_users(f"Error computing cohort {variable} distributions: {e}", level='error') + notify_users(f"Error computing cohort {variable} distributions: {e}", level="error") return None - def get_cohort_concept_stats(self, cohort_definition_id: int, qry_builder, - concept_type='condition_occurrence', filter_count=0, vocab=None, - print_concept_hierarchy=False): + def get_cohort_concept_stats( + self, + cohort_definition_id: int, + qry_builder, + concept_type="condition_occurrence", + filter_count=0, + vocab=None, + print_concept_hierarchy=False, + ): """ Get concept statistics for a cohort from the cohort table. """ @@ -241,35 +267,40 @@ def get_cohort_concept_stats(self, cohort_definition_id: int, qry_builder, # validate input vocab if it is not None if vocab is not None: valid_vocabs = self._execute_query(f"SELECT distinct vocabulary_id FROM {self.omop_alias}.concept") - valid_vocab_ids = [row['vocabulary_id'] for row in valid_vocabs] + valid_vocab_ids = [row["vocabulary_id"] for row in valid_vocabs] if vocab not in valid_vocab_ids: - err_msg = (f"input {vocab} is not a valid vocabulary in OMOP. " - f"Supported vocabulary ids are: {valid_vocab_ids}") - notify_users(err_msg, level='error') + err_msg = ( + f"input {vocab} is not a valid vocabulary in OMOP. 
" + f"Supported vocabulary ids are: {valid_vocab_ids}" + ) + notify_users(err_msg, level="error") raise ValueError(err_msg) - query = qry_builder.build_concept_prevalence_query(self.schema, self.omop_alias, concept_type, - cohort_definition_id, filter_count, vocab) + query = qry_builder.build_concept_prevalence_query( + self.schema, self.omop_alias, concept_type, cohort_definition_id, filter_count, vocab + ) concept_stats[concept_type] = self._execute_query(query) cs_df = pd.DataFrame(concept_stats[concept_type]) # Combine concept_name and prevalence into a "details" column cs_df["details"] = cs_df.apply( lambda row: f"{row['concept_name']} (Code: {row['concept_code']}, " - f"Count: {row['count_in_cohort']}, Prevalence: {row['prevalence']:.3%})", axis=1) + f"Count: {row['count_in_cohort']}, Prevalence: {row['prevalence']:.3%})", + axis=1, + ) if print_concept_hierarchy: - filtered_cs_df = cs_df[cs_df['ancestor_concept_id'] != cs_df['descendant_concept_id']] + filtered_cs_df = cs_df[cs_df["ancestor_concept_id"] != cs_df["descendant_concept_id"]] roots = find_roots(filtered_cs_df) hierarchy = build_concept_hierarchy(filtered_cs_df) - notify_users(f'cohort concept hierarchy for {concept_type} with root concept ids {roots}:') + notify_users(f"cohort concept hierarchy for {concept_type} with root concept ids {roots}:") for root in roots: - root_detail = cs_df[(cs_df['ancestor_concept_id'] == root) - & (cs_df['descendant_concept_id'] == root)]['details'].iloc[0] + root_detail = cs_df[ + (cs_df["ancestor_concept_id"] == root) & (cs_df["descendant_concept_id"] == root) + ]["details"].iloc[0] print_hierarchy(hierarchy, parent=root, level=0, parent_details=root_detail) return concept_stats except Exception as e: - err_msg = f"Error computing cohort concept stats: {e}" - raise ValueError(err_msg) + raise ValueError("Error computing cohort concept stats") from e def close(self): if self.conn: @@ -281,46 +312,47 @@ def close(self): class OMOPCDMDatabase: _instance = None # indicating a singleton with only one instance of the class ever created _database_type = None + def __new__(cls, *args, **kwargs): if cls._instance is None: - cls._instance = super(OMOPCDMDatabase, cls).__new__(cls) + cls._instance = super().__new__(cls) cls._instance._initialize(*args, **kwargs) # Initialize only once return cls._instance def _initialize(self, db_url): - if db_url.endswith('.duckdb'): + if db_url.endswith(".duckdb"): # close any potential global connections if any - for obj in gc.get_objects(): # pragma: no cover + for obj in gc.get_objects(): # pragma: no cover if isinstance(obj, duckdb.DuckDBPyConnection): try: obj.close() except Exception as e: - notify_users(f'failed to close the lingering duckdb connection before opening a new one: {e}') + notify_users(f"failed to close the lingering duckdb connection before opening a new one: {e}") # Handle DuckDB connection try: self.engine = duckdb.connect(db_url) notify_users(f"Connected to the DuckDB database: {db_url}.") except duckdb.Error as e: # pragma: no cover - notify_users(f"Failed to connect to DuckDB: {e}", level='error') + notify_users(f"Failed to connect to DuckDB: {e}", level="error") self.Session = self.engine # Use engine directly for DuckDB - self._database_type = 'duckdb' + self._database_type = "duckdb" else: # pragma: no cover # Handle PostgreSQL connection try: self.engine = create_engine( db_url, echo=False, - connect_args={'options': '-c default_transaction_read_only=on'} # Enforce read-only transactions + connect_args={"options": "-c 
default_transaction_read_only=on"}, # Enforce read-only transactions ) self.Session = sessionmaker(bind=self.engine) notify_users("Connected to the OMOP CDM database (read-only).") - self._database_type = 'postgresql' + self._database_type = "postgresql" except SQLAlchemyError as e: - notify_users(f"Failed to connect to the database: {e}", level='error') + notify_users(f"Failed to connect to the database: {e}", level="error") def get_session(self): - if self._database_type == 'duckdb': + if self._database_type == "duckdb": return self.engine else: # pragma: no cover # postgresql connection: provide a new session for read-only queries @@ -328,7 +360,7 @@ def get_session(self): def execute_query(self, query, params=None): try: - if self._database_type == 'duckdb': + if self._database_type == "duckdb": # DuckDB query execution results = self.engine.execute(query, params).fetchall() headers = [desc[0] for desc in self.engine.execute(query, params).description] @@ -344,10 +376,10 @@ def execute_query(self, query, params=None): return [dict(zip(headers, row)) for row in results] except duckdb.Error as e: - notify_users(f"Error executing query: {e}", level='error') + notify_users(f"Error executing query: {e}", level="error") return [] except SQLAlchemyError as e: # pragma: no cover - notify_users(f"Error executing query: {e}", level='error') + notify_users(f"Error executing query: {e}", level="error") if omop_session: omop_session.close() return [] @@ -361,11 +393,11 @@ def get_domains_and_vocabularies(self) -> list: def get_concepts(self, search_term: str, domain: Optional[str], vocab: Optional[str]) -> list: search_term_exact = search_term.lower() - search_term_suffix = f'{search_term_exact} ' - search_term_prefix = f' {search_term_exact}' - search_term_prefix_suffix = f' {search_term_exact} ' + search_term_suffix = f"{search_term_exact} " + search_term_prefix = f" {search_term_exact}" + search_term_prefix_suffix = f" {search_term_exact} " - if self._database_type == 'duckdb': + if self._database_type == "duckdb": # Use positional parameters and ? as placeholder to meet duckdb syntax requirement base_query = """ SELECT concept_id, concept_name, valid_start_date, valid_end_date, domain_id, vocabulary_id \ @@ -385,16 +417,20 @@ def get_concepts(self, search_term: str, domain: Optional[str], vocab: Optional[ if domain is not None and vocab is not None: condition_str = "domain_id = ? AND vocabulary_id = ?" - params = [domain, vocab, search_term_exact, search_term_prefix, search_term_suffix, - search_term_prefix_suffix] + params = [ + domain, + vocab, + search_term_exact, + search_term_prefix, + search_term_suffix, + search_term_prefix_suffix, + ] elif domain is None: condition_str = "vocabulary_id = ?" - params = [vocab, search_term_exact, search_term_prefix, search_term_suffix, - search_term_prefix_suffix] + params = [vocab, search_term_exact, search_term_prefix, search_term_suffix, search_term_prefix_suffix] else: condition_str = "domain_id = ?" 
- params = [domain, search_term_exact, search_term_prefix, search_term_suffix, - search_term_prefix_suffix] + params = [domain, search_term_exact, search_term_prefix, search_term_suffix, search_term_prefix_suffix] else: # pragma: no cover # Use named parameters with :param_name syntax for SQLAlchemy/PostgreSQL @@ -418,19 +454,19 @@ def get_concepts(self, search_term: str, domain: Optional[str], vocab: Optional[ "search_term_exact": search_term_exact, "search_term_prefix": search_term_prefix, "search_term_suffix": search_term_suffix, - "search_term_prefix_suffix": search_term_prefix_suffix + "search_term_prefix_suffix": search_term_prefix_suffix, } if domain is not None and vocab is not None: condition_str = "domain_id = :domain AND vocabulary_id = :vocabulary" - params['domain'] = domain - params['vocabulary'] = vocab + params["domain"] = domain + params["vocabulary"] = vocab elif domain is None: condition_str = "vocabulary_id = :vocabulary" - params['vocabulary'] = vocab + params["vocabulary"] = vocab else: condition_str = "domain_id = :domain" - params['domain'] = domain + params["domain"] = domain query = base_query.format(condition_str=condition_str) return self.execute_query(query, params=params) @@ -444,11 +480,7 @@ def get_concept_hierarchy(self, concept_id: int): # this check is important to avoid SQL injection risk raise ValueError("concept_id must be an integer") - stages = [ - "Queried concept hierarchy", - "Fetched concept details", - "Built hierarchy tree" - ] + stages = ["Queried concept hierarchy", "Fetched concept details", "Built hierarchy tree"] progress = tqdm(total=len(stages), desc="Concept Hierarchy", unit="stage") progress.set_postfix_str(stages[0]) @@ -475,7 +507,9 @@ def get_concept_hierarchy(self, concept_id: int): progress.set_postfix_str(stages[1]) # Collect all unique concept IDs involved in the hierarchy using set comprehension - concept_ids = {row['ancestor_concept_id'] for row in results} | {row['descendant_concept_id'] for row in results} + concept_ids = {row["ancestor_concept_id"] for row in results} | { + row["descendant_concept_id"] for row in results + } # Fetch details of each concept concept_details = {} if concept_ids: @@ -488,7 +522,7 @@ def get_concept_hierarchy(self, concept_id: int): """ result = self.execute_query(query) - concept_details = {row['concept_id']: row for row in result} + concept_details = {row["concept_id"]: row for row in result} progress.update(1) progress.set_postfix_str(stages[2]) @@ -496,19 +530,23 @@ def get_concept_hierarchy(self, concept_id: int): hierarchy = {} reverse_hierarchy = {} for row in results: - ancestor_id = row['ancestor_concept_id'] - descendant_id = row['descendant_concept_id'] + ancestor_id = row["ancestor_concept_id"] + descendant_id = row["descendant_concept_id"] ancestor_entry = hierarchy.setdefault( - ancestor_id, {"details": concept_details[ancestor_id], "children": []}) + ancestor_id, {"details": concept_details[ancestor_id], "children": []} + ) descendant_entry = hierarchy.setdefault( - descendant_id, {"details": concept_details[descendant_id], "children": []}) + descendant_id, {"details": concept_details[descendant_id], "children": []} + ) ancestor_entry["children"].append(descendant_entry) desc_entry_rev = reverse_hierarchy.setdefault( - descendant_id, {"details": concept_details[descendant_id], "parents": []}) + descendant_id, {"details": concept_details[descendant_id], "parents": []} + ) ancestor_entry_rev = reverse_hierarchy.setdefault( - ancestor_id, {"details": concept_details[ancestor_id], 
"parents": []}) + ancestor_id, {"details": concept_details[ancestor_id], "parents": []} + ) desc_entry_rev["parents"].append(ancestor_entry_rev) progress.update(1) progress.close() @@ -516,7 +554,6 @@ def get_concept_hierarchy(self, concept_id: int): # Return the parent hierarchy and children hierarchy of the specified concept return reverse_hierarchy[concept_id], hierarchy[concept_id] - def close(self): if isinstance(self.engine, duckdb.DuckDBPyConnection): self.engine.close() diff --git a/biasanalyzer/models.py b/biasanalyzer/models.py index 72790fa..35bbbc7 100644 --- a/biasanalyzer/models.py +++ b/biasanalyzer/models.py @@ -1,7 +1,7 @@ -from pydantic import BaseModel, StrictStr, ConfigDict, field_validator, model_validator -from typing import Optional, Literal, List, Union from datetime import date +from typing import List, Literal, Optional, Union +from pydantic import BaseModel, ConfigDict, StrictStr, field_validator, model_validator DOMAIN_MAPPING = { "condition_occurrence": { @@ -9,50 +9,50 @@ "concept_id": "condition_concept_id", "start_date": "condition_start_date", "end_date": "condition_end_date", - "default_vocab": "SNOMED" # for use by concept prevalence query + "default_vocab": "SNOMED", # for use by concept prevalence query }, "drug_exposure": { "table": "drug_exposure", "concept_id": "drug_concept_id", "start_date": "drug_exposure_start_date", "end_date": "drug_exposure_end_date", - "default_vocab": "RxNorm" # for use by concept prevalence query + "default_vocab": "RxNorm", # for use by concept prevalence query }, "procedure_occurrence": { "table": "procedure_occurrence", "concept_id": "procedure_concept_id", "start_date": "procedure_date", "end_date": "procedure_date", - "default_vocab": "SNOMED" # for use by concept prevalence query + "default_vocab": "SNOMED", # for use by concept prevalence query }, "visit_occurrence": { "table": "visit_occurrence", "concept_id": "visit_concept_id", "start_date": "visit_start_date", "end_date": "visit_end_date", - "default_vocab": "SNOMED" # for use by concept prevalence query + "default_vocab": "SNOMED", # for use by concept prevalence query }, "measurement": { "table": "measurement", "concept_id": "measurement_concept_id", "start_date": "measurement_date", "end_date": "measurement_date", - "default_vocab": "LOINC" # for use by concept prevalence query + "default_vocab": "LOINC", # for use by concept prevalence query }, "observation": { "table": "observation", "concept_id": "observation_concept_id", "start_date": "observation_date", "end_date": "observation_date", - "default_vocab": "SNOMED" # for use by concept prevalence query + "default_vocab": "SNOMED", # for use by concept prevalence query }, "date": { # Special case for static timestamps "table": None, "concept_id": None, "start_date": "timestamp", - "end_date": "timestamp", - "default_vocab": None - } + "end_date": "timestamp", + "default_vocab": None, + }, } EVENT_TYPE_LITERAL = Literal[tuple(DOMAIN_MAPPING.keys())] @@ -60,7 +60,7 @@ ###===========System Configuration==============### class RootOMOPCDM(BaseModel): - model_config = ConfigDict(extra='ignore') + model_config = ConfigDict(extra="ignore") username: StrictStr password: StrictStr hostname: StrictStr @@ -69,10 +69,13 @@ class RootOMOPCDM(BaseModel): class Configuration(BaseModel): - model_config = ConfigDict(extra='ignore') + model_config = ConfigDict(extra="ignore") root_omop_cdm_database: RootOMOPCDM + + ###===========System Configuration==============### + ###===========CohortDefinition Model==============### class 
CohortDefinition(BaseModel): model_config = ConfigDict(from_attributes=True) @@ -81,8 +84,11 @@ class CohortDefinition(BaseModel): created_date: date creation_info: str created_by: str + + ###===========CohortDefinition Model==============### + ###===========Cohort Model====================### class Cohort(BaseModel): model_config = ConfigDict(from_attributes=True) @@ -90,12 +96,15 @@ class Cohort(BaseModel): subject_id: int cohort_start_date: Optional[date] cohort_end_date: Optional[date] + + ###===========Cohort Model====================### + ###=========CohortCreationConfig==================### class DemographicsCriteria(BaseModel): # Gender with "male" and "female" as valid input - gender: Optional[Literal['male', 'female']] = None + gender: Optional[Literal["male", "female"]] = None # Minimum birth year min_birth_year: Optional[int] = None max_birth_year: Optional[int] = None @@ -165,7 +174,7 @@ def validate_events_list(cls, values): return values - def get_interval_sql(self, e1_alias='e1', e2_alias='e2') -> str: + def get_interval_sql(self, e1_alias="e1", e2_alias="e2") -> str: """Generate SQL for the interval.""" if not self.interval: # pragma: no cover return "" @@ -184,4 +193,6 @@ class CohortCreationConfig(BaseModel): # cohort creation criteria inclusion_criteria: CohortCreationCriteria exclusion_criteria: Optional[CohortCreationCriteria] = None + + ###=========CohortCreationConfig==================### diff --git a/biasanalyzer/module_test.py b/biasanalyzer/module_test.py index cacc79d..b3c0093 100644 --- a/biasanalyzer/module_test.py +++ b/biasanalyzer/module_test.py @@ -1,60 +1,72 @@ +import os import pprint -from biasanalyzer.api import BIAS import time -import os + import pandas as pd +from biasanalyzer.api import BIAS + def cohort_creation_template_test(bias_obj): - cohort_data = bias_obj.create_cohort('COVID-19 patients', 'COVID-19 patients', - os.path.join(os.path.dirname(__file__), '..', 'tests', 'assets', - 'cohort_creation', - 'extras', - 'diabetes_example2', - 'cohort_creation_config_baseline_example2.yaml'), - # 'covid_example3', - # 'cohort_creation_config_baseline_example3.yaml'), - # 'test_cohort_creation_condition_occurrence_config_study.yaml'), - 'system') + cohort_data = bias_obj.create_cohort( + "COVID-19 patients", + "COVID-19 patients", + os.path.join( + os.path.dirname(__file__), + "..", + "tests", + "assets", + "cohort_creation", + "extras", + "diabetes_example2", + "cohort_creation_config_baseline_example2.yaml", + ), + # 'covid_example3', + # 'cohort_creation_config_baseline_example3.yaml'), + # 'test_cohort_creation_condition_occurrence_config_study.yaml'), + "system", + ) if cohort_data: md = cohort_data.metadata - print(f'cohort_definition: {md}') - print(f'The first five records in the cohort {cohort_data.data[:5]}') - print(f'the cohort stats: {cohort_data.get_stats()}') - print(f'the cohort age stats: {cohort_data.get_stats("age")}') - print(f'the cohort gender stats: {cohort_data.get_stats("gender")}') - print(f'the cohort race stats: {cohort_data.get_stats("race")}') - print(f'the cohort ethnicity stats: {cohort_data.get_stats("ethnicity")}') - print(f'the cohort age distributions: {cohort_data.get_distributions("age")}') - print(f'the cohort gender distributions: {cohort_data.get_distributions("gender")}') - compare_stats = bias_obj.compare_cohorts(cohort_data.metadata['id'], cohort_data.metadata['id']) - print(f'compare_stats: {compare_stats}') + print(f"cohort_definition: {md}") + print(f"The first five records in the cohort 
{cohort_data.data[:5]}") + print(f"the cohort stats: {cohort_data.get_stats()}") + print(f"the cohort age stats: {cohort_data.get_stats('age')}") + print(f"the cohort gender stats: {cohort_data.get_stats('gender')}") + print(f"the cohort race stats: {cohort_data.get_stats('race')}") + print(f"the cohort ethnicity stats: {cohort_data.get_stats('ethnicity')}") + print(f"the cohort age distributions: {cohort_data.get_distributions('age')}") + print(f"the cohort gender distributions: {cohort_data.get_distributions('gender')}") + compare_stats = bias_obj.compare_cohorts(cohort_data.metadata["id"], cohort_data.metadata["id"]) + print(f"compare_stats: {compare_stats}") return def condition_cohort_test(bias_obj): - baseline_cohort_query = ('SELECT c.person_id, MIN(c.condition_start_date) as cohort_start_date, ' - 'MAX(c.condition_end_date) as cohort_end_date ' - 'FROM condition_occurrence c JOIN ' - 'person p ON c.person_id = p.person_id ' - 'WHERE c.condition_concept_id = 201826 GROUP BY c.person_id') - cohort_data = bias_obj.create_cohort('Diabetics patients', 'Diabetics patients', - baseline_cohort_query, 'system') + baseline_cohort_query = ( + "SELECT c.person_id, MIN(c.condition_start_date) as cohort_start_date, " + "MAX(c.condition_end_date) as cohort_end_date " + "FROM condition_occurrence c JOIN " + "person p ON c.person_id = p.person_id " + "WHERE c.condition_concept_id = 201826 GROUP BY c.person_id" + ) + cohort_data = bias_obj.create_cohort("Diabetics patients", "Diabetics patients", baseline_cohort_query, "system") if cohort_data: md = cohort_data.metadata - print(f'cohort_definition: {md}') - print(f'The first five records in the cohort {cohort_data.data[:5]}') - print(f'the cohort stats: {cohort_data.get_stats()}') - print(f'the cohort age stats: {cohort_data.get_stats("age")}') - print(f'the cohort gender stats: {cohort_data.get_stats("gender")}') - print(f'the cohort race stats: {cohort_data.get_stats("race")}') - print(f'the cohort ethnicity stats: {cohort_data.get_stats("ethnicity")}') - print(f'the cohort age distributions: {cohort_data.get_distributions("age")}') + print(f"cohort_definition: {md}") + print(f"The first five records in the cohort {cohort_data.data[:5]}") + print(f"the cohort stats: {cohort_data.get_stats()}") + print(f"the cohort age stats: {cohort_data.get_stats('age')}") + print(f"the cohort gender stats: {cohort_data.get_stats('gender')}") + print(f"the cohort race stats: {cohort_data.get_stats('race')}") + print(f"the cohort ethnicity stats: {cohort_data.get_stats('ethnicity')}") + print(f"the cohort age distributions: {cohort_data.get_distributions('age')}") t1 = time.time() - _, cohort_concept_hierarchy = cohort_data.get_concept_stats(concept_type='condition_occurrence', - filter_count=5000) + _, cohort_concept_hierarchy = cohort_data.get_concept_stats( + concept_type="condition_occurrence", filter_count=5000 + ) concept_node = cohort_concept_hierarchy.get_node(concept_id=201826) - print(f'concept_node 201826 metric: {concept_node.get_metrics(md["id"])}') + print(f"concept_node 201826 metric: {concept_node.get_metrics(md['id'])}") # Print the root node root_nodes = cohort_concept_hierarchy.get_root_nodes() @@ -64,53 +76,54 @@ def condition_cohort_test(bias_obj): print(f"Root: {root}", flush=True) print(f"Leaves: {leaves}", flush=True) for node in cohort_concept_hierarchy.iter_nodes(root_nodes[0].id, serialization=True): - print(node) + print(node) hier_dict = cohort_concept_hierarchy.to_dict() - pprint.pprint(hier_dict, indent=2) - + with 
open("diabetics_condition_occurrence_hierarchy_dict.txt", "w") as cof: + pprint.pprint(hier_dict, indent=2, stream=cof) - _, cohort_de_concept_hierarchy = cohort_data.get_concept_stats(concept_type='drug_exposure', - filter_count=500) + _, cohort_de_concept_hierarchy = cohort_data.get_concept_stats(concept_type="drug_exposure", filter_count=500) de_hier_dict = cohort_de_concept_hierarchy.to_dict() - pprint.pprint(de_hier_dict, indent=2) - compare_stats = bias_obj.compare_cohorts(cohort_data.metadata['id'], cohort_data.metadata['id']) - print(f'compare_stats: {compare_stats}') + with open("diabetics_drug_exposure_hierarchy_dict.txt", "w") as dof: + pprint.pprint(de_hier_dict, indent=2, stream=dof) + # compare_stats = bias_obj.compare_cohorts(cohort_data.metadata['id'], cohort_data.metadata['id']) + # print(f'compare_stats: {compare_stats}') + print(f"times taken for computing cohort concept hierarcy: {time.time() - t1}") return def concept_test(bias_obj): - print(f'domains and vocabularies: \n{pd.DataFrame(bias_obj.get_domains_and_vocabularies())}') + print(f"domains and vocabularies: \n{pd.DataFrame(bias_obj.get_domains_and_vocabularies())}") # calling get_concepts() without passing in domain and vocabulary should raise an exception bias_obj.get_concepts("COVID-19") concepts = bias_obj.get_concepts("COVID-19", "Condition", "SNOMED") - print(f'concepts for COVID-19 in Condition domain with SNOMED vocabulary: \n{pd.DataFrame(concepts)}') + print(f"concepts for COVID-19 in Condition domain with SNOMED vocabulary: \n{pd.DataFrame(concepts)}") concepts = bias_obj.get_concepts("COVID-19", domain="Condition") - print(f'concepts for COVID-19 in Condition domain: \n{pd.DataFrame(concepts)}') + print(f"concepts for COVID-19 in Condition domain: \n{pd.DataFrame(concepts)}") concepts = bias_obj.get_concepts("COVID-19", vocabulary="SNOMED") - print(f'concepts for COVID-19 in SNOMED vocabulary: \n{pd.DataFrame(concepts)}') + print(f"concepts for COVID-19 in SNOMED vocabulary: \n{pd.DataFrame(concepts)}") parent_concept_tree, children_concept_tree = bias_obj.get_concept_hierarchy(37311061) - print('parent concept hierarchy for COVID-19 in text format:') + print("parent concept hierarchy for COVID-19 in text format:") print(bias_obj.display_concept_tree(parent_concept_tree)) - print('children concept hierarchy for COVID-19 in text format:') + print("children concept hierarchy for COVID-19 in text format:") print(bias_obj.display_concept_tree(children_concept_tree)) - print(f'parent concept hierarchy for COVID-19 in widget tree format:') + print("parent concept hierarchy for COVID-19 in widget tree format:") bias_obj.display_concept_tree(parent_concept_tree, show_in_text_format=False) - print(f'children concept hierarchy for COVID-19 in widget tree format:') + print("children concept hierarchy for COVID-19 in widget tree format:") bias_obj.display_concept_tree(children_concept_tree, show_in_text_format=False) return -if __name__ == '__main__': +if __name__ == "__main__": bias = None - pd.set_option('display.max_rows', None) - pd.set_option('display.max_columns', None) - pd.set_option('display.width', 1000) + pd.set_option("display.max_rows", None) + pd.set_option("display.max_columns", None) + pd.set_option("display.width", 1000) try: bias = BIAS() # bias.set_config(os.path.join(os.path.dirname(__file__), '..', 'config_duckdb.yaml')) - bias.set_config(os.path.join(os.path.dirname(__file__), '..', 'config.yaml')) + bias.set_config(os.path.join(os.path.dirname(__file__), "..", "config.yaml")) 
bias.set_root_omop() cohort_creation_template_test(bias) @@ -119,4 +132,4 @@ def concept_test(bias_obj): finally: if bias is not None: bias.cleanup() - print('done') + print("done") diff --git a/biasanalyzer/sql.py b/biasanalyzer/sql.py index a47c7c6..84c7cd8 100644 --- a/biasanalyzer/sql.py +++ b/biasanalyzer/sql.py @@ -1,6 +1,6 @@ # SQL templates for querying in OMOP database -AGE_DISTRIBUTION_QUERY = ''' +AGE_DISTRIBUTION_QUERY = """ WITH Age_Cohort AS ( SELECT p.person_id, EXTRACT(YEAR FROM @@ -42,9 +42,9 @@ ROUND(bin_count * 1.0 / SUM(bin_count) OVER (), 2) AS probability -- Normalize to get probability FROM Age_Distribution ORDER BY age_bin -''' +""" -GENDER_DISTRIBUTION_QUERY = ''' +GENDER_DISTRIBUTION_QUERY = """ WITH Gender_Categories AS ( SELECT 'male' AS gender, 8507 AS gender_concept_id UNION ALL SELECT 'female', 8532 @@ -76,9 +76,9 @@ ROUND(COALESCE(gender_count, 0) * 1.0 / SUM(COALESCE(gender_count, 0)) OVER (), 2) AS probability FROM Gender_Distribution ORDER BY gender; -''' +""" -AGE_STATS_QUERY = ''' +AGE_STATS_QUERY = """ WITH Age_Cohort AS ( SELECT p.person_id, EXTRACT(YEAR FROM @@ -100,9 +100,9 @@ CAST(PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY age) AS INT) AS median_age, ROUND(STDDEV(age), 2) as stddev_age FROM Age_Cohort -''' +""" -GENDER_STATS_QUERY = ''' +GENDER_STATS_QUERY = """ SELECT CASE WHEN p.gender_concept_id = 8507 THEN 'male' @@ -114,9 +114,9 @@ FROM {ba_schema}.cohort c JOIN {omop}.person p ON c.subject_id = p.person_id WHERE c.cohort_definition_id = {cohort_definition_id} GROUP BY p.gender_concept_id -''' +""" -RACE_STATS_QUERY = ''' +RACE_STATS_QUERY = """ SELECT CASE WHEN p.race_concept_id = 8516 THEN 'Black or African American' @@ -131,9 +131,9 @@ FROM {ba_schema}.cohort c JOIN {omop}.person p ON c.subject_id = p.person_id WHERE c.cohort_definition_id = {cohort_definition_id} GROUP BY p.race_concept_id -''' +""" -ETHNICITY_STATS_QUERY = ''' +ETHNICITY_STATS_QUERY = """ SELECT CASE WHEN p.ethnicity_concept_id = 38003563 THEN 'Hispanic or Latino' @@ -145,4 +145,4 @@ FROM {ba_schema}.cohort c JOIN {omop}.person p ON c.subject_id = p.person_id WHERE c.cohort_definition_id = {cohort_definition_id} GROUP BY p.ethnicity_concept_id -''' +""" diff --git a/biasanalyzer/utils.py b/biasanalyzer/utils.py index c2a2e2b..ad6b47d 100644 --- a/biasanalyzer/utils.py +++ b/biasanalyzer/utils.py @@ -1,8 +1,8 @@ -import numpy as np -import re -from ipytree import Node import logging +import re +import numpy as np +from ipytree import Node logger = logging.getLogger(__name__) @@ -29,14 +29,14 @@ def notify_users(message: str, level: str = "info"): def get_direction_arrow(tree_type): # the two unicodes are for up and down arrows - return "\U0001F53C" if tree_type == 'parents' else "\U0001F53D" + return "\U0001f53c" if tree_type == "parents" else "\U0001f53d" def clean_string(text): # replace newlines and tabs with a space - text = re.sub(r'[\n\t]', ' ', text) + text = re.sub(r"[\n\t]", " ", text) # Replace multiple spaces with a single space - text = re.sub(r'\s+', ' ', text) + text = re.sub(r"\s+", " ", text) # Strip leading and trailing spaces return text.strip() @@ -55,23 +55,21 @@ def hellinger_distance(p, q): return np.sqrt(0.5 * np.sum((np.sqrt(p) - np.sqrt(q)) ** 2)) -def build_concept_hierarchy(df, parent_col="ancestor_concept_id", child_col="descendant_concept_id", - details_col="details"): +def build_concept_hierarchy( + df, parent_col="ancestor_concept_id", child_col="descendant_concept_id", details_col="details" +): """ Builds a hierarchy using only direct 
parent-child relationships to remove duplicate branches. """ grouped = df.groupby(parent_col) - hierarchy = { - parent: list(zip(group[child_col], group[details_col])) - for parent, group in grouped - } + hierarchy = {parent: list(zip(group[child_col], group[details_col])) for parent, group in grouped} return hierarchy def build_concept_tree(concept_tree: dict, tree_type: str) -> Node: """ - Recursively builds an ipytree Node for a given concept tree. - """ + Recursively builds an ipytree Node for a given concept tree. + """ # Extract concept details details = concept_tree.get("details", {}) concept_name = details.get("concept_name", "Unknown Concept") @@ -111,5 +109,5 @@ def print_hierarchy(hierarchy, parent=None, level=0, parent_details=None): print(parent_details) level += 1 for child, details in hierarchy[parent]: - print(f" " * level + details) + print(" " * level + details) print_hierarchy(hierarchy, parent=child, level=level + 1) diff --git a/poetry.lock b/poetry.lock index 8224fb5..28cf8c9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. [[package]] name = "annotated-types" @@ -6,6 +6,7 @@ version = "0.7.0" description = "Reusable constraint types to use with typing.Annotated" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53"}, {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, @@ -20,6 +21,8 @@ version = "0.1.4" description = "Disable App Nap on macOS >= 10.9" optional = false python-versions = ">=3.6" +groups = ["main"] +markers = "sys_platform == \"darwin\"" files = [ {file = "appnope-0.1.4-py2.py3-none-any.whl", hash = "sha256:502575ee11cd7a28c0205f379b525beefebab9d161b7c964670864014ed7213c"}, {file = "appnope-0.1.4.tar.gz", hash = "sha256:1de3860566df9caf38f01f86f65e0e13e379af54f9e4bee1e66b48f2efffd1ee"}, @@ -31,6 +34,7 @@ version = "3.0.0" description = "Annotate AST trees with source code positions" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "asttokens-3.0.0-py3-none-any.whl", hash = "sha256:e3078351a059199dd5138cb1c706e6430c05eff2ff136af5eb4790f9d28932e2"}, {file = "asttokens-3.0.0.tar.gz", hash = "sha256:0dcd8baa8d62b0c1d118b399b2ddba3c4aff271d0d7a9e0d4c1681c79035bbc7"}, @@ -46,6 +50,7 @@ version = "0.2.0" description = "Specifications for callback functions passed in to an API" optional = false python-versions = "*" +groups = ["main"] files = [ {file = "backcall-0.2.0-py2.py3-none-any.whl", hash = "sha256:fbbce6a29f263178a1f7915c1940bde0ec2b2a967566fe1c65c1dfb7422bd255"}, {file = "backcall-0.2.0.tar.gz", hash = "sha256:5cbdbf27be5e7cfadb448baf0aa95508f91f2bbc6c6437cd9cd06e2a4c215e1e"}, @@ -57,10 +62,12 @@ version = "0.4.6" description = "Cross-platform colored terminal text." 
optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +groups = ["main", "dev"] files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +markers = {main = "sys_platform == \"win32\" or platform_system == \"Windows\"", dev = "sys_platform == \"win32\""} [[package]] name = "comm" @@ -68,6 +75,7 @@ version = "0.2.2" description = "Jupyter Python Comm implementation, for usage in ipykernel, xeus-python etc." optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "comm-0.2.2-py3-none-any.whl", hash = "sha256:e6fb86cb70ff661ee8c9c14e7d36d6de3b4066f1441be4063df9c5009f0a64d3"}, {file = "comm-0.2.2.tar.gz", hash = "sha256:3fd7a84065306e07bea1773df6eb8282de51ba82f77c72f9c85716ab11fe980e"}, @@ -85,6 +93,7 @@ version = "7.6.1" description = "Code coverage measurement for Python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "coverage-7.6.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b06079abebbc0e89e6163b8e8f0e16270124c154dc6e4a47b413dd538859af16"}, {file = "coverage-7.6.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cf4b19715bccd7ee27b6b120e7e9dd56037b9c0681dcc1adc9ba9db3d417fa36"}, @@ -164,7 +173,7 @@ files = [ tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""} [package.extras] -toml = ["tomli"] +toml = ["tomli ; python_full_version <= \"3.11.0a6\""] [[package]] name = "decorator" @@ -172,6 +181,7 @@ version = "5.1.1" description = "Decorators for Humans" optional = false python-versions = ">=3.5" +groups = ["main"] files = [ {file = "decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186"}, {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, @@ -183,6 +193,7 @@ version = "1.1.3" description = "DuckDB in-process database" optional = false python-versions = ">=3.7.0" +groups = ["main"] files = [ {file = "duckdb-1.1.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:1c0226dc43e2ee4cc3a5a4672fddb2d76fd2cf2694443f395c02dd1bea0b7fce"}, {file = "duckdb-1.1.3-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:7c71169fa804c0b65e49afe423ddc2dc83e198640e3b041028da8110f7cd16f7"}, @@ -244,6 +255,7 @@ version = "0.13.6" description = "SQLAlchemy driver for duckdb" optional = false python-versions = "<4,>=3.8" +groups = ["main"] files = [ {file = "duckdb_engine-0.13.6-py3-none-any.whl", hash = "sha256:cedd44252cce5f42de88752026925154a566c407987116a242d250642904ba84"}, {file = "duckdb_engine-0.13.6.tar.gz", hash = "sha256:221ec7759e157fd8d4fcb0bd64f603c5a4b1889186f30d805a91b10a73f8c59a"}, @@ -260,6 +272,8 @@ version = "1.2.2" description = "Backport of PEP 654 (exception groups)" optional = false python-versions = ">=3.7" +groups = ["dev"] +markers = "python_version <= \"3.10\"" files = [ {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"}, @@ -274,13 +288,14 @@ version = "2.2.0" description = "Get the currently executing AST node of a frame, and other information" optional = false python-versions = ">=3.8" +groups 
= ["main"] files = [ {file = "executing-2.2.0-py2.py3-none-any.whl", hash = "sha256:11387150cad388d62750327a53d3339fad4888b39a6fe233c3afbb54ecffd3aa"}, {file = "executing-2.2.0.tar.gz", hash = "sha256:5d108c028108fe2551d1a7b2e8b713341e2cb4fc0aa7dcf966fa4327a5226755"}, ] [package.extras] -tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich"] +tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich ; python_version >= \"3.11\""] [[package]] name = "greenlet" @@ -288,6 +303,8 @@ version = "3.1.1" description = "Lightweight in-process concurrent programming" optional = false python-versions = ">=3.7" +groups = ["main"] +markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\"" files = [ {file = "greenlet-3.1.1-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:0bbae94a29c9e5c7e4a2b7f0aae5c17e8e90acbfd3bf6270eeba60c39fce3563"}, {file = "greenlet-3.1.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0fde093fb93f35ca72a556cf72c92ea3ebfda3d79fc35bb19fbe685853869a83"}, @@ -374,6 +391,7 @@ version = "2.0.0" description = "brain-dead simple config-ini parsing" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, @@ -385,6 +403,7 @@ version = "8.12.3" description = "IPython: Productive Interactive Computing" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "ipython-8.12.3-py3-none-any.whl", hash = "sha256:b0340d46a933d27c657b211a329d0be23793c36595acf9e6ef4164bc01a1804c"}, {file = "ipython-8.12.3.tar.gz", hash = "sha256:3910c4b54543c2ad73d06579aa771041b7d5707b033bd488669b4cf544e3b363"}, @@ -424,6 +443,7 @@ version = "0.2.2" description = "A Tree Widget using jsTree" optional = false python-versions = "*" +groups = ["main"] files = [ {file = "ipytree-0.2.2-py2.py3-none-any.whl", hash = "sha256:744dc1a02c3ec26df8a5ecd87d085a67dc8232a1def6048834403ddcf3b64143"}, {file = "ipytree-0.2.2.tar.gz", hash = "sha256:d53d739bbaaa45415733cd06e0dc420a2af3d173438617db472a517bc7a61e56"}, @@ -438,6 +458,7 @@ version = "8.1.5" description = "Jupyter interactive widgets" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "ipywidgets-8.1.5-py3-none-any.whl", hash = "sha256:3290f526f87ae6e77655555baba4f36681c555b8bdbbff430b70e52c34c86245"}, {file = "ipywidgets-8.1.5.tar.gz", hash = "sha256:870e43b1a35656a80c18c9503bbf2d16802db1cb487eec6fab27d683381dde17"}, @@ -459,6 +480,7 @@ version = "0.19.2" description = "An autocompletion tool for Python that can be used for text editors." optional = false python-versions = ">=3.6" +groups = ["main"] files = [ {file = "jedi-0.19.2-py2.py3-none-any.whl", hash = "sha256:a8ef22bde8490f57fe5c7681a3c83cb58874daf72b4784de3cce5b6ef6edb5b9"}, {file = "jedi-0.19.2.tar.gz", hash = "sha256:4770dc3de41bde3966b02eb84fbcf557fb33cce26ad23da12c742fb50ecb11f0"}, @@ -478,6 +500,7 @@ version = "3.1.6" description = "A very fast and expressive template engine." 
optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67"}, {file = "jinja2-3.1.6.tar.gz", hash = "sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d"}, @@ -495,6 +518,7 @@ version = "3.0.13" description = "Jupyter interactive widgets for JupyterLab" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "jupyterlab_widgets-3.0.13-py3-none-any.whl", hash = "sha256:e3cda2c233ce144192f1e29914ad522b2f4c40e77214b0cc97377ca3d323db54"}, {file = "jupyterlab_widgets-3.0.13.tar.gz", hash = "sha256:a2966d385328c1942b683a8cd96b89b8dd82c8b8f81dda902bb2bc06d46f5bed"}, @@ -506,6 +530,7 @@ version = "2.1.5" description = "Safely add untrusted strings to HTML/XML markup." optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a17a92de5231666cfbe003f0e4b9b3a7ae3afb1ec2845aadc2bacc93ff85febc"}, {file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:72b6be590cc35924b02c78ef34b467da4ba07e4e0f0454a2c5907f473fc50ce5"}, @@ -575,6 +600,7 @@ version = "0.1.7" description = "Inline Matplotlib backend for Jupyter" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "matplotlib_inline-0.1.7-py3-none-any.whl", hash = "sha256:df192d39a4ff8f21b1895d72e6a13f5fcc5099f00fa84384e0ea28c2cc0653ca"}, {file = "matplotlib_inline-0.1.7.tar.gz", hash = "sha256:8423b23ec666be3d16e16b60bdd8ac4e86e840ebd1dd11a30b9f117f2fa0ab90"}, @@ -589,6 +615,7 @@ version = "3.1" description = "Python package for creating and manipulating graphs and networks" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "networkx-3.1-py3-none-any.whl", hash = "sha256:4f33f68cb2afcf86f28a45f43efc27a9386b535d567d2127f8f61d51dec58d36"}, {file = "networkx-3.1.tar.gz", hash = "sha256:de346335408f84de0eada6ff9fafafff9bcda11f0a0dfaa931133debb146ab61"}, @@ -607,6 +634,8 @@ version = "1.24.4" description = "Fundamental package for array computing in Python" optional = false python-versions = ">=3.8" +groups = ["main"] +markers = "python_version < \"3.12\"" files = [ {file = "numpy-1.24.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0bfb52d2169d58c1cdb8cc1f16989101639b34c7d3ce60ed70b19c63eba0b64"}, {file = "numpy-1.24.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ed094d4f0c177b1b8e7aa9cba7d6ceed51c0e569a5318ac0ca9a090680a6a1b1"}, @@ -638,12 +667,60 @@ files = [ {file = "numpy-1.24.4.tar.gz", hash = "sha256:80f5e3a4e498641401868df4208b74581206afbee7cf7b8329daae82676d9463"}, ] +[[package]] +name = "numpy" +version = "1.26.4" +description = "Fundamental package for array computing in Python" +optional = false +python-versions = ">=3.9" +groups = ["main"] +markers = "python_version == \"3.12\"" +files = [ + {file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"}, + {file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"}, + {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4"}, + {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f"}, + {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a"}, + {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2"}, + {file = "numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07"}, + {file = "numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5"}, + {file = "numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71"}, + {file = "numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef"}, + {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e"}, + {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5"}, + {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a"}, + {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a"}, + {file = "numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20"}, + {file = "numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2"}, + {file = "numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218"}, + {file = "numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b"}, + {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b"}, + {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed"}, + {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a"}, + {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0"}, + {file = "numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110"}, + {file = "numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818"}, + {file = "numpy-1.26.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c"}, + {file = "numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be"}, + {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764"}, + {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3"}, + {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd"}, + {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c"}, + {file = "numpy-1.26.4-cp39-cp39-win32.whl", hash = "sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6"}, + {file = "numpy-1.26.4-cp39-cp39-win_amd64.whl", hash = "sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0"}, + {file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"}, +] + [[package]] name = "packaging" version = "24.2" description = "Core utilities for Python packages" optional = false python-versions = ">=3.8" +groups = ["main", "dev"] files = [ {file = "packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759"}, {file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"}, @@ -655,6 +732,7 @@ version = "2.0.3" description = "Powerful data structures for data analysis, time series, and statistics" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "pandas-2.0.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e4c7c9f27a4185304c7caf96dc7d91bc60bc162221152de697c98eb0b2648dd8"}, {file = "pandas-2.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f167beed68918d62bffb6ec64f2e1d8a7d297a038f86d4aed056b9493fca407f"}, @@ -687,7 +765,7 @@ files = [ numpy = [ {version = ">=1.20.3", markers = "python_version < \"3.10\""}, {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, - {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, + {version = ">=1.21.0", markers = "python_version == \"3.10\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -722,6 +800,7 @@ version = "0.8.4" description = "A Python Parser" optional = false python-versions = ">=3.6" +groups = ["main"] files = [ {file = "parso-0.8.4-py2.py3-none-any.whl", hash = "sha256:a418670a20291dacd2dddc80c377c5c3791378ee1e8d12bffc35420643d43f18"}, {file = "parso-0.8.4.tar.gz", hash = "sha256:eb3a7b58240fb99099a345571deecc0f9540ea5f4dd2fe14c2a99d6b281ab92d"}, @@ -737,6 +816,8 @@ version = "4.9.0" description = "Pexpect allows easy control of interactive console applications." 
optional = false python-versions = "*" +groups = ["main"] +markers = "sys_platform != \"win32\"" files = [ {file = "pexpect-4.9.0-py2.py3-none-any.whl", hash = "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523"}, {file = "pexpect-4.9.0.tar.gz", hash = "sha256:ee7d41123f3c9911050ea2c2dac107568dc43b2d3b0c7557a33212c398ead30f"}, @@ -751,6 +832,7 @@ version = "0.7.5" description = "Tiny 'shelve'-like database with concurrency support" optional = false python-versions = "*" +groups = ["main"] files = [ {file = "pickleshare-0.7.5-py2.py3-none-any.whl", hash = "sha256:9649af414d74d4df115d5d718f82acb59c9d418196b7b4290ed47a12ce62df56"}, {file = "pickleshare-0.7.5.tar.gz", hash = "sha256:87683d47965c1da65cdacaf31c8441d12b8044cdec9aca500cd78fc2c683afca"}, @@ -762,6 +844,7 @@ version = "1.5.0" description = "plugin and hook calling mechanisms for python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"}, {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"}, @@ -777,6 +860,7 @@ version = "3.0.50" description = "Library for building powerful interactive command lines in Python" optional = false python-versions = ">=3.8.0" +groups = ["main"] files = [ {file = "prompt_toolkit-3.0.50-py3-none-any.whl", hash = "sha256:9b6427eb19e479d98acff65196a307c555eb567989e6d88ebbb1b509d9779198"}, {file = "prompt_toolkit-3.0.50.tar.gz", hash = "sha256:544748f3860a2623ca5cd6d2795e7a14f3d0e1c3c9728359013f79877fc89bab"}, @@ -791,6 +875,7 @@ version = "2.9.10" description = "psycopg2 - Python-PostgreSQL Database Adapter" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "psycopg2-2.9.10-cp310-cp310-win32.whl", hash = "sha256:5df2b672140f95adb453af93a7d669d7a7bf0a56bcd26f1502329166f4a61716"}, {file = "psycopg2-2.9.10-cp310-cp310-win_amd64.whl", hash = "sha256:c6f7b8561225f9e711a9c47087388a97fdc948211c10a4bccbf0ba68ab7b3b5a"}, @@ -809,6 +894,8 @@ version = "0.7.0" description = "Run a subprocess in a pseudo terminal" optional = false python-versions = "*" +groups = ["main"] +markers = "sys_platform != \"win32\"" files = [ {file = "ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35"}, {file = "ptyprocess-0.7.0.tar.gz", hash = "sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220"}, @@ -820,6 +907,7 @@ version = "0.2.3" description = "Safely evaluate AST nodes without side effects" optional = false python-versions = "*" +groups = ["main"] files = [ {file = "pure_eval-0.2.3-py3-none-any.whl", hash = "sha256:1db8e35b67b3d218d818ae653e27f06c3aa420901fa7b081ca98cbedc874e0d0"}, {file = "pure_eval-0.2.3.tar.gz", hash = "sha256:5f4e983f40564c576c7c8635ae88db5956bb2229d7e9237d03b3c0b0190eaf42"}, @@ -834,6 +922,7 @@ version = "2.10.6" description = "Data validation using Python type hints" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "pydantic-2.10.6-py3-none-any.whl", hash = "sha256:427d664bf0b8a2b34ff5dd0f5a18df00591adcee7198fbd71981054cef37b584"}, {file = "pydantic-2.10.6.tar.gz", hash = "sha256:ca5daa827cce33de7a42be142548b0096bf05a7e7b365aebfa5f8eeec7128236"}, @@ -846,7 +935,7 @@ typing-extensions = ">=4.12.2" [package.extras] email = ["email-validator (>=2.0.0)"] -timezone = ["tzdata"] +timezone = ["tzdata ; python_version >= \"3.9\" and platform_system == 
\"Windows\""] [[package]] name = "pydantic-core" @@ -854,6 +943,7 @@ version = "2.27.2" description = "Core functionality for Pydantic validation and serialization" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "pydantic_core-2.27.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2d367ca20b2f14095a8f4fa1210f5a7b78b8a20009ecced6b12818f455b1e9fa"}, {file = "pydantic_core-2.27.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:491a2b73db93fab69731eaee494f320faa4e093dbed776be1a829c2eb222c34c"}, @@ -966,6 +1056,7 @@ version = "2.19.1" description = "Pygments is a syntax highlighting package written in Python." optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "pygments-2.19.1-py3-none-any.whl", hash = "sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c"}, {file = "pygments-2.19.1.tar.gz", hash = "sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f"}, @@ -980,6 +1071,7 @@ version = "8.3.4" description = "pytest: simple powerful testing with Python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6"}, {file = "pytest-8.3.4.tar.gz", hash = "sha256:965370d062bce11e73868e0335abac31b4d3de0e82f4007408d242b4f8610761"}, @@ -1002,6 +1094,7 @@ version = "5.0.0" description = "Pytest plugin for measuring coverage." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "pytest-cov-5.0.0.tar.gz", hash = "sha256:5837b58e9f6ebd335b0f8060eecce69b662415b16dc503883a02f45dfeb14857"}, {file = "pytest_cov-5.0.0-py3-none-any.whl", hash = "sha256:4f0764a1219df53214206bf1feea4633c3b558a2925c8b59f144f682861ce652"}, @@ -1020,6 +1113,7 @@ version = "2.9.0.post0" description = "Extensions to the standard Python datetime module" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +groups = ["main"] files = [ {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, @@ -1034,6 +1128,7 @@ version = "2024.2" description = "World timezone definitions, modern and historical" optional = false python-versions = "*" +groups = ["main"] files = [ {file = "pytz-2024.2-py2.py3-none-any.whl", hash = "sha256:31c7c1817eb7fae7ca4b8c7ee50c72f93aa2dd863de768e1ef4245d426aa0725"}, {file = "pytz-2024.2.tar.gz", hash = "sha256:2aa355083c50a0f93fa581709deac0c9ad65cca8a9e9beac660adcbd493c798a"}, @@ -1045,6 +1140,7 @@ version = "6.0.2" description = "YAML parser and emitter for Python" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "PyYAML-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086"}, {file = "PyYAML-6.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf"}, @@ -1101,12 +1197,43 @@ files = [ {file = "pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e"}, ] +[[package]] +name = "ruff" +version = "0.14.3" +description = "An extremely fast Python linter and code formatter, written in Rust." 
+optional = false +python-versions = ">=3.7" +groups = ["dev"] +files = [ + {file = "ruff-0.14.3-py3-none-linux_armv6l.whl", hash = "sha256:876b21e6c824f519446715c1342b8e60f97f93264012de9d8d10314f8a79c371"}, + {file = "ruff-0.14.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b6fd8c79b457bedd2abf2702b9b472147cd860ed7855c73a5247fa55c9117654"}, + {file = "ruff-0.14.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:71ff6edca490c308f083156938c0c1a66907151263c4abdcb588602c6e696a14"}, + {file = "ruff-0.14.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:786ee3ce6139772ff9272aaf43296d975c0217ee1b97538a98171bf0d21f87ed"}, + {file = "ruff-0.14.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cd6291d0061811c52b8e392f946889916757610d45d004e41140d81fb6cd5ddc"}, + {file = "ruff-0.14.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a497ec0c3d2c88561b6d90f9c29f5ae68221ac00d471f306fa21fa4264ce5fcd"}, + {file = "ruff-0.14.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:e231e1be58fc568950a04fbe6887c8e4b85310e7889727e2b81db205c45059eb"}, + {file = "ruff-0.14.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:469e35872a09c0e45fecf48dd960bfbce056b5db2d5e6b50eca329b4f853ae20"}, + {file = "ruff-0.14.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3d6bc90307c469cb9d28b7cfad90aaa600b10d67c6e22026869f585e1e8a2db0"}, + {file = "ruff-0.14.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e2f8a0bbcffcfd895df39c9a4ecd59bb80dca03dc43f7fb63e647ed176b741e"}, + {file = "ruff-0.14.3-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:678fdd7c7d2d94851597c23ee6336d25f9930b460b55f8598e011b57c74fd8c5"}, + {file = "ruff-0.14.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:1ec1ac071e7e37e0221d2f2dbaf90897a988c531a8592a6a5959f0603a1ecf5e"}, + {file = "ruff-0.14.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:afcdc4b5335ef440d19e7df9e8ae2ad9f749352190e96d481dc501b753f0733e"}, + {file = "ruff-0.14.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:7bfc42f81862749a7136267a343990f865e71fe2f99cf8d2958f684d23ce3dfa"}, + {file = "ruff-0.14.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:a65e448cfd7e9c59fae8cf37f9221585d3354febaad9a07f29158af1528e165f"}, + {file = "ruff-0.14.3-py3-none-win32.whl", hash = "sha256:f3d91857d023ba93e14ed2d462ab62c3428f9bbf2b4fbac50a03ca66d31991f7"}, + {file = "ruff-0.14.3-py3-none-win_amd64.whl", hash = "sha256:d7b7006ac0756306db212fd37116cce2bd307e1e109375e1c6c106002df0ae5f"}, + {file = "ruff-0.14.3-py3-none-win_arm64.whl", hash = "sha256:26eb477ede6d399d898791d01961e16b86f02bc2486d0d1a7a9bb2379d055dc1"}, + {file = "ruff-0.14.3.tar.gz", hash = "sha256:4ff876d2ab2b161b6de0aa1f5bd714e8e9b4033dc122ee006925fbacc4f62153"}, +] + [[package]] name = "scipy" version = "1.10.1" description = "Fundamental algorithms for scientific computing in Python" optional = false python-versions = "<3.12,>=3.8" +groups = ["main"] +markers = "python_version < \"3.12\"" files = [ {file = "scipy-1.10.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e7354fd7527a4b0377ce55f286805b34e8c54b91be865bac273f527e1b839019"}, {file = "scipy-1.10.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:4b3f429188c66603a1a5c549fb414e4d3bdc2a24792e061ffbd607d3d75fd84e"}, @@ -1139,12 +1266,65 @@ dev = ["click", "doit (>=0.36.0)", "flake8", "mypy", "pycodestyle", "pydevtool", doc = ["matplotlib (>2)", "numpydoc", "pydata-sphinx-theme (==0.9.0)", 
"sphinx (!=4.1.0)", "sphinx-design (>=0.2.0)"] test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] +[[package]] +name = "scipy" +version = "1.14.1" +description = "Fundamental algorithms for scientific computing in Python" +optional = false +python-versions = ">=3.10" +groups = ["main"] +markers = "python_version == \"3.12\"" +files = [ + {file = "scipy-1.14.1-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:b28d2ca4add7ac16ae8bb6632a3c86e4b9e4d52d3e34267f6e1b0c1f8d87e389"}, + {file = "scipy-1.14.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:d0d2821003174de06b69e58cef2316a6622b60ee613121199cb2852a873f8cf3"}, + {file = "scipy-1.14.1-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:8bddf15838ba768bb5f5083c1ea012d64c9a444e16192762bd858f1e126196d0"}, + {file = "scipy-1.14.1-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:97c5dddd5932bd2a1a31c927ba5e1463a53b87ca96b5c9bdf5dfd6096e27efc3"}, + {file = "scipy-1.14.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2ff0a7e01e422c15739ecd64432743cf7aae2b03f3084288f399affcefe5222d"}, + {file = "scipy-1.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e32dced201274bf96899e6491d9ba3e9a5f6b336708656466ad0522d8528f69"}, + {file = "scipy-1.14.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8426251ad1e4ad903a4514712d2fa8fdd5382c978010d1c6f5f37ef286a713ad"}, + {file = "scipy-1.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:a49f6ed96f83966f576b33a44257d869756df6cf1ef4934f59dd58b25e0327e5"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:2da0469a4ef0ecd3693761acbdc20f2fdeafb69e6819cc081308cc978153c675"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:c0ee987efa6737242745f347835da2cc5bb9f1b42996a4d97d5c7ff7928cb6f2"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:3a1b111fac6baec1c1d92f27e76511c9e7218f1695d61b59e05e0fe04dc59617"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:8475230e55549ab3f207bff11ebfc91c805dc3463ef62eda3ccf593254524ce8"}, + {file = "scipy-1.14.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:278266012eb69f4a720827bdd2dc54b2271c97d84255b2faaa8f161a158c3b37"}, + {file = "scipy-1.14.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fef8c87f8abfb884dac04e97824b61299880c43f4ce675dd2cbeadd3c9b466d2"}, + {file = "scipy-1.14.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b05d43735bb2f07d689f56f7b474788a13ed8adc484a85aa65c0fd931cf9ccd2"}, + {file = "scipy-1.14.1-cp311-cp311-win_amd64.whl", hash = "sha256:716e389b694c4bb564b4fc0c51bc84d381735e0d39d3f26ec1af2556ec6aad94"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:631f07b3734d34aced009aaf6fedfd0eb3498a97e581c3b1e5f14a04164a456d"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:af29a935803cc707ab2ed7791c44288a682f9c8107bc00f0eccc4f92c08d6e07"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:2843f2d527d9eebec9a43e6b406fb7266f3af25a751aa91d62ff416f54170bc5"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:eb58ca0abd96911932f688528977858681a59d61a7ce908ffd355957f7025cfc"}, + {file = "scipy-1.14.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:30ac8812c1d2aab7131a79ba62933a2a76f582d5dbbc695192453dae67ad6310"}, + {file = "scipy-1.14.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f9ea80f2e65bdaa0b7627fb00cbeb2daf163caa015e59b7516395fe3bd1e066"}, + {file = "scipy-1.14.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:edaf02b82cd7639db00dbff629995ef185c8df4c3ffa71a5562a595765a06ce1"}, + {file = "scipy-1.14.1-cp312-cp312-win_amd64.whl", hash = "sha256:2ff38e22128e6c03ff73b6bb0f85f897d2362f8c052e3b8ad00532198fbdae3f"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:1729560c906963fc8389f6aac023739ff3983e727b1a4d87696b7bf108316a79"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:4079b90df244709e675cdc8b93bfd8a395d59af40b72e339c2287c91860deb8e"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:e0cf28db0f24a38b2a0ca33a85a54852586e43cf6fd876365c86e0657cfe7d73"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:0c2f95de3b04e26f5f3ad5bb05e74ba7f68b837133a4492414b3afd79dfe540e"}, + {file = "scipy-1.14.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b99722ea48b7ea25e8e015e8341ae74624f72e5f21fc2abd45f3a93266de4c5d"}, + {file = "scipy-1.14.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5149e3fd2d686e42144a093b206aef01932a0059c2a33ddfa67f5f035bdfe13e"}, + {file = "scipy-1.14.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e4f5a7c49323533f9103d4dacf4e4f07078f360743dec7f7596949149efeec06"}, + {file = "scipy-1.14.1-cp313-cp313-win_amd64.whl", hash = "sha256:baff393942b550823bfce952bb62270ee17504d02a1801d7fd0719534dfb9c84"}, + {file = "scipy-1.14.1.tar.gz", hash = "sha256:5a275584e726026a5699459aa72f828a610821006228e841b94275c4a7c08417"}, +] + +[package.dependencies] +numpy = ">=1.23.5,<2.3" + +[package.extras] +dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy (==1.10.0)", "pycodestyle", "pydevtool", "rich-click", "ruff (>=0.0.292)", "types-psutil", "typing_extensions"] +doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.13.1)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0,<=7.3.7)", "sphinx-design (>=0.4.0)"] +test = ["Cython", "array-api-strict (>=2.0)", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja ; sys_platform != \"emscripten\"", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] + [[package]] name = "six" version = "1.17.0" description = "Python 2 and 3 compatibility utilities" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +groups = ["main"] files = [ {file = "six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274"}, {file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"}, @@ -1156,6 +1336,7 @@ version = "2.0.37" description = "Database Abstraction Library" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "SQLAlchemy-2.0.37-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:da36c3b0e891808a7542c5c89f224520b9a16c7f5e4d6a1156955605e54aef0e"}, {file = "SQLAlchemy-2.0.37-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e7402ff96e2b073a98ef6d6142796426d705addd27b9d26c3b32dbaa06d7d069"}, @@ -1251,6 +1432,7 @@ version = "0.6.3" description = "Extract data from python stack frames and tracebacks for 
informative displays" optional = false python-versions = "*" +groups = ["main"] files = [ {file = "stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695"}, {file = "stack_data-0.6.3.tar.gz", hash = "sha256:836a778de4fec4dcd1dcd89ed8abff8a221f58308462e1c4aa2a3cf30148f0b9"}, @@ -1270,6 +1452,8 @@ version = "2.2.1" description = "A lil' TOML parser" optional = false python-versions = ">=3.8" +groups = ["dev"] +markers = "python_version <= \"3.10\"" files = [ {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"}, {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"}, @@ -1311,6 +1495,7 @@ version = "4.67.1" description = "Fast, Extensible Progress Meter" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2"}, {file = "tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2"}, @@ -1332,6 +1517,7 @@ version = "5.14.3" description = "Traitlets Python configuration system" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f"}, {file = "traitlets-5.14.3.tar.gz", hash = "sha256:9ed0579d3502c94b4b3732ac120375cda96f923114522847de4b3bb98b96b6b7"}, @@ -1347,6 +1533,7 @@ version = "4.12.2" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, @@ -1358,6 +1545,7 @@ version = "2025.1" description = "Provider of IANA time zone data" optional = false python-versions = ">=2" +groups = ["main"] files = [ {file = "tzdata-2025.1-py2.py3-none-any.whl", hash = "sha256:7e127113816800496f027041c570f50bcd464a020098a3b6b199517772303639"}, {file = "tzdata-2025.1.tar.gz", hash = "sha256:24894909e88cdb28bd1636c6887801df64cb485bd593f2fd83ef29075a81d694"}, @@ -1369,6 +1557,7 @@ version = "0.2.13" description = "Measures the displayed width of unicode strings in a terminal" optional = false python-versions = "*" +groups = ["main"] files = [ {file = "wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859"}, {file = "wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5"}, @@ -1380,12 +1569,13 @@ version = "4.0.13" description = "Jupyter interactive widgets for Jupyter Notebook" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "widgetsnbextension-4.0.13-py3-none-any.whl", hash = "sha256:74b2692e8500525cc38c2b877236ba51d34541e6385eeed5aec15a70f88a6c71"}, {file = "widgetsnbextension-4.0.13.tar.gz", hash = "sha256:ffcb67bc9febd10234a362795f643927f4e0c05d9342c727b65d2384f8feacb6"}, ] [metadata] -lock-version = "2.0" -python-versions = ">=3.8.10,<3.12" -content-hash = "7ef5e6a3bec2bcef8429f74816408f554f0d021da19349481077a67065489833" +lock-version = "2.1" +python-versions = ">=3.8.10,<3.13" 
+content-hash = "a588906e6e90991a51af15f24a95ed2841c834dc72b79872e843b54e889f56c5" diff --git a/pyproject.toml b/pyproject.toml index 7600f55..34f0901 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,28 +10,74 @@ include = [ {path = "biasanalyzer/sql_templates/*.sql", format=["sdist", "wheel"]} ] [tool.poetry.dependencies] -python = ">=3.8.10,<3.12" +python = ">=3.8.10,<3.13" duckdb = "^1.1.1" pandas = "2.0.3" -scipy = "1.10.1" -numpy = "1.24.4" + +scipy = [ + {version = ">=1.10.1,<1.11", markers = "python_version<'3.12'"}, + {version = ">=1.14.1,<1.15", markers = "python_version>='3.12'"} +] +numpy = [ + {version = ">=1.24.4,<1.25", markers = "python_version<'3.12'"}, + {version = ">=1.25.0,<=1.26.4", markers = "python_version>='3.12'"} +] duckdb-engine = "^0.13.2" sqlalchemy = "^2.0.35" pyyaml = "^6.0.2" pydantic = "^2.9.2" -psycopg2 = "^2.9.1" +psycopg2 = "^2.9.9" ipytree = "^0.2.2" ipywidgets = "^8.1.5" jinja2 = "3.1.6" tqdm = "4.67.1" networkx = "3.1" -[tool.poetry.dev-dependencies] -pytest = "^8.3.3" - [tool.poetry.group.dev.dependencies] +pytest = "^8.3.3" pytest-cov = "5.0.0" +ruff = "^0.14.3" [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" + +# ---------------------------- +# Ruff configuration +# ---------------------------- +[tool.ruff] +# Linting targets your source and test code +target-version = "py38" +src = ["biasanalyzer", "tests", "scripts"] +line-length = 120 + +# Enable full linting set roughly equivalent to Flake8 + isort + pyupgrade +lint.select = [ + "E", # pycodestyle errors + "F", # pyflakes + "I", # isort-style import sorting + "B", # flake8-bugbear + "UP", # pyupgrade + "N", # pep8-naming + "S", # security (flake8-bandit subset) +] +lint.ignore = ["S101", "N805", "S701"] + +# Automatically fix simple issues on save (optional) +fix = true + +# Exclude build/test caches and notebooks +exclude = [ + ".venv", + ".git", + "__pycache__", + "build", + "dist", + "notebooks", +] + +[tool.ruff.format] +# Ruff's built-in formatter replaces Black +indent-style = "space" +docstring-code-format = true +docstring-code-line-length = "dynamic" diff --git a/scripts/ingest_csvs_to_omop_duckdb.py b/scripts/ingest_csvs_to_omop_duckdb.py index f4d8525..a96ec38 100644 --- a/scripts/ingest_csvs_to_omop_duckdb.py +++ b/scripts/ingest_csvs_to_omop_duckdb.py @@ -1,3 +1,5 @@ +# ruff: noqa: S608 + """ This script ingests both clinical and vocabulary OMOP CSV exports into a single DckDB database for downstream use of the core BiasAnalyzer python library. 
@@ -8,17 +10,18 @@ --output data/omop.duckdb """ -import duckdb -import time import argparse import sys +import time from pathlib import Path +import duckdb + def load_csv_to_duckdb(con, csv_path: Path, table_name: str): """Load a single CSV file into DuckDB.""" t0 = time.time() - print(f'loading {table_name} from {csv_path}') + print(f"loading {table_name} from {csv_path}") con.execute(f""" CREATE OR REPLACE TABLE {table_name} AS SELECT * FROM read_csv_auto('{csv_path}', header=True, quote='', parallel=True) @@ -45,12 +48,19 @@ def ingest_directory(con, csv_dir: Path): def main(): parser = argparse.ArgumentParser(description="Ingest OMOP CSVs into DuckDB") - parser.add_argument("--clinical", type=Path, required=False, - help="Directory containing OMOP clinical CSVs (person, condition_occurrence, etc.)") - parser.add_argument("--vocab", type=Path, required=False, - help="Directory containing OMOP vocabulary CSVs (concept, concept_relationship, etc.)") - parser.add_argument("--output", type=Path, required=True, - help="Output DuckDB file path") + parser.add_argument( + "--clinical", + type=Path, + required=False, + help="Directory containing OMOP clinical CSVs (person, condition_occurrence, etc.)", + ) + parser.add_argument( + "--vocab", + type=Path, + required=False, + help="Directory containing OMOP vocabulary CSVs (concept, concept_relationship, etc.)", + ) + parser.add_argument("--output", type=Path, required=True, help="Output DuckDB file path") args = parser.parse_args() @@ -85,5 +95,6 @@ def main(): print(f"Ingestion complete with {len(all_results)} tables loaded. Details shown below:") print(f"\n{all_results}") + if __name__ == "__main__": main() diff --git a/tests/conftest.py b/tests/conftest.py index 1ae39cc..e71d1ed 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,9 @@ -import pytest +import os + import duckdb +import pytest from biasanalyzer.api import BIAS from biasanalyzer.config import load_config -import os @pytest.fixture @@ -15,9 +16,9 @@ def fresh_bias_obj(): @pytest.fixture(scope="session") def test_db(): - config_file = os.path.join(os.path.dirname(__file__), 'assets', 'config', 'test_config.yaml') + config_file = os.path.join(os.path.dirname(__file__), "assets", "config", "test_config.yaml") config = load_config(config_file) - db_path = config['root_omop_cdm_database']['database'] + db_path = config["root_omop_cdm_database"]["database"] conn = duckdb.connect(db_path) conn.execute(""" CREATE TABLE IF NOT EXISTS person ( @@ -145,7 +146,7 @@ def test_db(): (37311061, 'COVID-19', '2012-04-01', '2020-04-01', '840539006', 'SNOMED', 'Condition'), (4041664, 'Difficulty breathing', '2012-04-01', '2020-04-01', '230145002', 'SNOMED', 'Condition'), (316139, 'Heart failure', '2012-04-01', '2020-04-01', '84114007', 'SNOMED', 'Condition'), - (201826, 'Type 2 diabetes mellitus', '2012-04-01', '2020-04-01', '44054006', 'SNOMED', 'Condition'); + (201826, 'Type 2 diabetes mellitus', '2012-04-01', '2020-04-01', '44054006', 'SNOMED', 'Condition') """) # Insert hierarchical relationships as needed @@ -175,7 +176,8 @@ def test_db(): result = conn.execute("SELECT COUNT(*) FROM condition_occurrence").fetchone() if result[0] == 0: conn.execute(""" - INSERT INTO condition_occurrence (person_id, condition_concept_id, condition_start_date, condition_end_date) + INSERT INTO condition_occurrence (person_id, condition_concept_id, condition_start_date, + condition_end_date) VALUES (101, 2, '2023-01-01', '2023-01-31'), -- Patient 101 has Type 1 Diabetes (101, 3, '2023-01-01', 
'2023-02-27'), -- Patient 101 has Type 2 Diabetes @@ -215,7 +217,8 @@ def test_db(): result = conn.execute("SELECT COUNT(*) FROM visit_occurrence").fetchone() if result[0] == 0: conn.execute(""" - INSERT INTO visit_occurrence (person_id, visit_occurrence_id, visit_concept_id, visit_start_date, visit_end_date) + INSERT INTO visit_occurrence (person_id, visit_occurrence_id, visit_concept_id, visit_start_date, + visit_end_date) VALUES (108, 1, 9201, '2020-04-13', '2020-04-14'), -- Inpatient Visit (108, 2, 9201, '2020-04-16', '2020-04-27'), -- Second inpatient visit (meets criteria) @@ -239,7 +242,8 @@ def test_db(): result = conn.execute("SELECT COUNT(*) FROM procedure_occurrence").fetchone() if result[0] == 0: conn.execute(""" - INSERT INTO procedure_occurrence (person_id, procedure_occurrence_id, procedure_concept_id, procedure_date) + INSERT INTO procedure_occurrence (person_id, procedure_occurrence_id, procedure_concept_id, + procedure_date) VALUES (1, 1, 4048609, '2020-06-20'), -- Person 1: Blood test (2, 2, 4048609, '2020-06-20'), -- Person 2: Blood test @@ -255,7 +259,8 @@ def test_db(): result = conn.execute("SELECT COUNT(*) FROM drug_exposure").fetchone() if result[0] == 0: conn.execute(""" - INSERT INTO drug_exposure (person_id, drug_concept_id, drug_exposure_start_date, drug_exposure_end_date) + INSERT INTO drug_exposure (person_id, drug_concept_id, drug_exposure_start_date, + drug_exposure_end_date) VALUES (1, 4285892, '2020-06-15', '2020-06-15'), -- Person 1: Insulin 14 days after (2, 4285892, '2020-06-15', '2020-06-15'), -- Person 2: Insulin @@ -266,7 +271,6 @@ def test_db(): -- Person 5: No insulin """) - # mock configuration file bias = BIAS(config_file_path=config_file) bias.set_root_omop() diff --git a/tests/query_based/test_cohort_creation.py b/tests/query_based/test_cohort_creation.py index 8f82190..1bc3226 100644 --- a/tests/query_based/test_cohort_creation.py +++ b/tests/query_based/test_cohort_creation.py @@ -1,25 +1,23 @@ -import os import datetime import logging +import os + import pytest -from sqlalchemy.exc import SQLAlchemyError -from numpy.ma.testutils import assert_equal from biasanalyzer.models import DemographicsCriteria, TemporalEvent, TemporalEventGroup +from numpy.ma.testutils import assert_equal +from sqlalchemy.exc import SQLAlchemyError def test_cohort_yaml_validation(test_db): invalid_data = { "gender": "female", "min_birth_year": 2000, - "max_birth_year": 1999 # Invalid: less than min_birth_year + "max_birth_year": 1999, # Invalid: less than min_birth_year } with pytest.raises(ValueError): DemographicsCriteria(**invalid_data) - invalid_data = { - "event_type": "date", - "event_concept_id": "dummy" - } + invalid_data = {"event_type": "date", "event_concept_id": "dummy"} # validate date event_type must have a timestamp field with pytest.raises(ValueError): TemporalEvent(**invalid_data) @@ -27,12 +25,10 @@ def test_cohort_yaml_validation(test_db): invalid_data = { "operator": "BEFORE", "events": [ - {'event_type': 'condition_occurrence', - 'event_concept_id': 201826}, - {'event_type': 'drug_exposure', - 'event_concept_id': 4285892}, + {"event_type": "condition_occurrence", "event_concept_id": 201826}, + {"event_type": "drug_exposure", "event_concept_id": 4285892}, ], - "interval": [100, 50] + "interval": [100, 50], } # validate interval start must be smaller than interval end with pytest.raises(ValueError): @@ -54,49 +50,60 @@ def test_cohort_yaml_validation(test_db): with pytest.raises(ValueError): TemporalEventGroup(**invalid_data) + def 
test_cohort_creation_baseline(caplog, test_db): bias = test_db cohort = bias.create_cohort( "COVID-19 patient", "Cohort of young female patients", - os.path.join(os.path.dirname(__file__), '..', 'assets', 'cohort_creation', - 'test_cohort_creation_condition_occurrence_config_baseline.yaml'), - "test_user" + os.path.join( + os.path.dirname(__file__), + "..", + "assets", + "cohort_creation", + "test_cohort_creation_condition_occurrence_config_baseline.yaml", + ), + "test_user", ) # Test cohort object and methods assert cohort is not None, "Cohort creation failed" cohort_id = cohort.cohort_id - assert bias.bias_db.get_cohort_definition(cohort_id)['name'] == "COVID-19 patient" + assert bias.bias_db.get_cohort_definition(cohort_id)["name"] == "COVID-19 patient" assert bias.bias_db.get_cohort_definition(cohort_id + 1) == {} assert cohort.metadata is not None, "Cohort creation wrongly returned None metadata" - assert 'creation_info' in cohort.metadata, "Cohort creation does not contain 'creation_info' key" + assert "creation_info" in cohort.metadata, "Cohort creation does not contain 'creation_info' key" assert cohort.data is not None, "Cohort creation wrongly returned None data" caplog.clear() with caplog.at_level(logging.ERROR): - cohort.get_distributions('ethnicity') + cohort.get_distributions("ethnicity") assert "Distribution for variable 'ethnicity' is not available" in caplog.text - assert len(cohort.get_distributions('age')) == 10, "Cohort get_distribution('age') does not return 10 age_bin items" - assert len(cohort.get_distributions('gender')) == 3, ("Cohort get_distribution('gender') does not return " - "3 gender_bin items") + assert len(cohort.get_distributions("age")) == 10, "Cohort get_distribution('age') does not return 10 age_bin items" + assert len(cohort.get_distributions("gender")) == 3, ( + "Cohort get_distribution('gender') does not return 3 gender_bin items" + ) - patient_ids = set([item['subject_id'] for item in cohort.data]) + patient_ids = set([item["subject_id"] for item in cohort.data]) assert_equal(len(patient_ids), 5) assert_equal(patient_ids, {106, 108, 110, 111, 112}) # select two patients to check for cohort_start_date and cohort_end_date automatically computed - patient_106 = next(item for item in cohort.data if item['subject_id'] == 106) - patient_108 = next(item for item in cohort.data if item['subject_id'] == 108) + patient_106 = next(item for item in cohort.data if item["subject_id"] == 106) + patient_108 = next(item for item in cohort.data if item["subject_id"] == 108) # Replace dates with actual values from your test data - assert_equal(patient_106['cohort_start_date'], datetime.date(2023, 3, 1), - "Incorrect cohort_start_date for patient 106") - assert_equal(patient_106['cohort_end_date'], datetime.date(2023, 3, 15), - "Incorrect cohort_end_date for patient 106") - assert_equal(patient_108['cohort_start_date'], datetime.date(2020, 4, 10), - "Incorrect cohort_start_date for patient 108") - assert_equal(patient_108['cohort_end_date'], datetime.date(2020, 4, 27), - "Incorrect cohort_end_date for patient 108") + assert_equal( + patient_106["cohort_start_date"], datetime.date(2023, 3, 1), "Incorrect cohort_start_date for patient 106" + ) + assert_equal( + patient_106["cohort_end_date"], datetime.date(2023, 3, 15), "Incorrect cohort_end_date for patient 106" + ) + assert_equal( + patient_108["cohort_start_date"], datetime.date(2020, 4, 10), "Incorrect cohort_start_date for patient 108" + ) + assert_equal( + patient_108["cohort_end_date"], datetime.date(2020, 4, 
27), "Incorrect cohort_end_date for patient 108" + ) def test_cohort_creation_study(test_db): @@ -104,19 +111,25 @@ def test_cohort_creation_study(test_db): cohort = bias.create_cohort( "COVID-19 patient", "Cohort of young female patients with COVID-19", - os.path.join(os.path.dirname(__file__), '..', 'assets', 'cohort_creation', - 'test_cohort_creation_condition_occurrence_config_study.yaml'), - "test_user" + os.path.join( + os.path.dirname(__file__), + "..", + "assets", + "cohort_creation", + "test_cohort_creation_condition_occurrence_config_study.yaml", + ), + "test_user", ) # Test cohort object and methods assert cohort is not None, "Cohort creation failed" assert cohort.metadata is not None, "Cohort creation wrongly returned None metadata" - assert 'creation_info' in cohort.metadata, "Cohort creation does not contain 'creation_info' key" + assert "creation_info" in cohort.metadata, "Cohort creation does not contain 'creation_info' key" assert cohort.data is not None, "Cohort creation wrongly returned None data" - patient_ids = set([item['subject_id'] for item in cohort.data]) + patient_ids = set([item["subject_id"] for item in cohort.data]) assert_equal(len(patient_ids), 4) assert_equal(patient_ids, {108, 110, 111, 112}) + def test_cohort_creation_study2(caplog, test_db): bias = test_db caplog.clear() @@ -124,21 +137,27 @@ def test_cohort_creation_study2(caplog, test_db): cohort = bias.create_cohort( "COVID-19 patient", "Cohort of young female patients with no COVID-19", - os.path.join(os.path.dirname(__file__), '..', 'assets', 'cohort_creation', - 'test_cohort_creation_condition_occurrence_config_study2.yaml'), + os.path.join( + os.path.dirname(__file__), + "..", + "assets", + "cohort_creation", + "test_cohort_creation_condition_occurrence_config_study2.yaml", + ), "test_user", - delay=1 + delay=1, ) - assert 'Simulating long-running task' in caplog.text + assert "Simulating long-running task" in caplog.text # Test cohort object and methods assert cohort is not None, "Cohort creation failed" assert cohort.metadata is not None, "Cohort creation wrongly returned None metadata" - assert 'creation_info' in cohort.metadata, "Cohort creation does not contain 'creation_info' key" + assert "creation_info" in cohort.metadata, "Cohort creation does not contain 'creation_info' key" assert cohort.data is not None, "Cohort creation wrongly returned None data" - patient_ids = set([item['subject_id'] for item in cohort.data]) + patient_ids = set([item["subject_id"] for item in cohort.data]) assert_equal(len(patient_ids), 1) assert_equal(patient_ids, {106}) + def test_cohort_creation_all(caplog, test_db): bias = test_db cohort = bias.create_cohort( @@ -146,29 +165,35 @@ def test_cohort_creation_all(caplog, test_db): "Cohort of young female patients with COVID-19 who have the condition with difficulty breathing 2 to 5 days " "before a COVID diagnosis 3/15/20-12/11/20 AND have at least one emergency room visit or at least " "two inpatient visits", - os.path.join(os.path.dirname(__file__), '..', 'assets', 'cohort_creation', - 'test_cohort_creation_condition_occurrence_config.yaml'), - "test_user" + os.path.join( + os.path.dirname(__file__), + "..", + "assets", + "cohort_creation", + "test_cohort_creation_condition_occurrence_config.yaml", + ), + "test_user", ) # Test cohort object and methods assert cohort is not None, "Cohort creation failed" assert cohort.metadata is not None, "Cohort creation wrongly returned None metadata" - assert 'creation_info' in cohort.metadata, "Cohort creation does not 
contain 'creation_info' key" + assert "creation_info" in cohort.metadata, "Cohort creation does not contain 'creation_info' key" stats = cohort.get_stats() assert stats is not None, "Created cohort's stats is None" - gender_stats = cohort.get_stats(variable='gender') + gender_stats = cohort.get_stats(variable="gender") assert gender_stats is not None, "Created cohort's gender stats is None" caplog.clear() with caplog.at_level(logging.ERROR): - cohort.get_stats(variable='address') - assert 'is not available' in caplog.text + cohort.get_stats(variable="address") + assert "is not available" in caplog.text assert gender_stats is not None, "Created cohort's gender stats is None" assert cohort.data is not None, "Cohort creation wrongly returned None data" - patient_ids = set([item['subject_id'] for item in cohort.data]) - print(f'patient_ids: {patient_ids}', flush=True) + patient_ids = set([item["subject_id"] for item in cohort.data]) + print(f"patient_ids: {patient_ids}", flush=True) assert_equal(len(patient_ids), 2) assert_equal(patient_ids, {108, 110}) + def test_cohort_creation_multiple_temporary_groups_with_no_operator(test_db): bias = test_db cohort = bias.create_cohort( @@ -176,16 +201,22 @@ def test_cohort_creation_multiple_temporary_groups_with_no_operator(test_db): "Cohort of young female patients who either have COVID-19 with difficulty breathing 2 to 5 days " "before a COVID diagnosis 3/15/20-12/11/20 OR have at least one emergency room visit or at least " "two inpatient visits", - os.path.join(os.path.dirname(__file__), '..', 'assets', 'cohort_creation', - 'test_cohort_creation_multiple_temporal_groups_without_operator.yaml'), - "test_user" + os.path.join( + os.path.dirname(__file__), + "..", + "assets", + "cohort_creation", + "test_cohort_creation_multiple_temporal_groups_without_operator.yaml", + ), + "test_user", ) # Test cohort object and methods - patient_ids = set([item['subject_id'] for item in cohort.data]) - print(f'patient_ids: {patient_ids}', flush=True) + patient_ids = set([item["subject_id"] for item in cohort.data]) + print(f"patient_ids: {patient_ids}", flush=True) assert_equal(len(patient_ids), 2) assert_equal(patient_ids, {108, 110}) + def test_cohort_creation_mixed_domains(test_db): """ Test cohort creation with mixed domains (condition, drug, visit, procedure). 
@@ -196,95 +227,104 @@ def test_cohort_creation_mixed_domains(test_db): "Cohort of female patients with diabetes who had insulin prescribed 0-30 days after diagnosis " "and have at least one outpatient or emergency visit and underwent a blood test before 12/31/2020, " "with patients born after 1995 and with cardiac surgery excluded", - os.path.join(os.path.dirname(__file__), '..', 'assets', 'cohort_creation', - 'test_cohort_creation_config.yaml'), - "test_user" + os.path.join(os.path.dirname(__file__), "..", "assets", "cohort_creation", "test_cohort_creation_config.yaml"), + "test_user", ) # Test cohort object and methods assert cohort is not None, "Cohort creation failed" - print(f'metadata: {cohort.metadata}') + print(f"metadata: {cohort.metadata}") assert cohort.metadata is not None, "Cohort creation wrongly returned None metadata" - assert 'creation_info' in cohort.metadata, "Cohort creation does not contain 'creation_info' key" + assert "creation_info" in cohort.metadata, "Cohort creation does not contain 'creation_info' key" stats = cohort.get_stats() assert stats is not None, "Created cohort's stats is None" assert cohort.data is not None, "Cohort creation wrongly returned None data" - patient_ids = set([item['subject_id'] for item in cohort.data]) - print(f'patient_ids: {patient_ids}', flush=True) + patient_ids = set([item["subject_id"] for item in cohort.data]) + print(f"patient_ids: {patient_ids}", flush=True) assert_equal(len(patient_ids), 3) assert_equal(patient_ids, {1, 2, 6}) - start_dates = [item['cohort_start_date'] for item in cohort.data] + start_dates = [item["cohort_start_date"] for item in cohort.data] assert_equal(len(start_dates), 3) - assert_equal(start_dates, [datetime.date(2020, 6, 1), - datetime.date(2020, 6, 1), - datetime.date(2018, 1, 1)]) - end_dates = [item['cohort_end_date'] for item in cohort.data] + assert_equal(start_dates, [datetime.date(2020, 6, 1), datetime.date(2020, 6, 1), datetime.date(2018, 1, 1)]) + end_dates = [item["cohort_end_date"] for item in cohort.data] assert_equal(len(end_dates), 3) - assert_equal(end_dates, [datetime.date(2020, 6, 20), - datetime.date(2020, 6, 20), - datetime.date(2018, 1, 20)]) + assert_equal(end_dates, [datetime.date(2020, 6, 20), datetime.date(2020, 6, 20), datetime.date(2018, 1, 20)]) + def test_cohort_comparison(test_db): bias = test_db cohort_base = bias.create_cohort( "COVID-19 patient", "Cohort of young female patients", - os.path.join(os.path.dirname(__file__), '..', 'assets', 'cohort_creation', - 'test_cohort_creation_condition_occurrence_config_baseline.yaml'), - "test_user" + os.path.join( + os.path.dirname(__file__), + "..", + "assets", + "cohort_creation", + "test_cohort_creation_condition_occurrence_config_baseline.yaml", + ), + "test_user", ) cohort_study = bias.create_cohort( "Female diabetes patients born between 1970 and 2000", "Cohort of female patients with diabetes who had insulin prescribed 0-30 days after diagnosis " "and have at least one outpatient or emergency visit and underwent a blood test before 12/31/2020, " "with patients born after 1995 and with cardiac surgery excluded", - os.path.join(os.path.dirname(__file__), '..', 'assets', 'cohort_creation', - 'test_cohort_creation_config.yaml'), - "test_user" + os.path.join(os.path.dirname(__file__), "..", "assets", "cohort_creation", "test_cohort_creation_config.yaml"), + "test_user", ) results = bias.compare_cohorts(cohort_base.cohort_id, cohort_study.cohort_id) - assert {'gender_hellinger_distance': 0.0} in results - assert 
any('age_hellinger_distance' in r for r in results) + assert {"gender_hellinger_distance": 0.0} in results + assert any("age_hellinger_distance" in r for r in results) + def test_cohort_invalid(caplog, test_db): caplog.clear() with caplog.at_level(logging.INFO): - invalid_cohort = test_db.create_cohort('invalid_cohort', 'invalid_cohort', - 'invalid_yaml_file.yml', - 'invalid_created_by') - assert 'cohort creation configuration file does not exist' in caplog.text + invalid_cohort = test_db.create_cohort( + "invalid_cohort", "invalid_cohort", "invalid_yaml_file.yml", "invalid_created_by" + ) + assert "cohort creation configuration file does not exist" in caplog.text assert invalid_cohort is None caplog.clear() with caplog.at_level(logging.INFO): - invalid_cohort = test_db.create_cohort('invalid_cohort', 'invalid_cohort', - os.path.join(os.path.dirname(__file__), '..', 'assets', 'config', - 'test_config.yaml'), 'invalid_created_by') - assert 'configuration yaml file is not valid' in caplog.text + invalid_cohort = test_db.create_cohort( + "invalid_cohort", + "invalid_cohort", + os.path.join(os.path.dirname(__file__), "..", "assets", "config", "test_config.yaml"), + "invalid_created_by", + ) + assert "configuration yaml file is not valid" in caplog.text assert invalid_cohort is None with caplog.at_level(logging.INFO): - invalid_cohort = test_db.create_cohort('invalid_cohort', 'invalid_cohort', - 'INVALID SQL QUERY STRING', - 'invalid_created_by') - assert 'Error executing query:' in caplog.text + invalid_cohort = test_db.create_cohort( + "invalid_cohort", "invalid_cohort", "INVALID SQL QUERY STRING", "invalid_created_by" + ) + assert "Error executing query:" in caplog.text assert invalid_cohort is None + def test_create_cohort_sqlalchemy_error(monkeypatch, fresh_bias_obj): # Mock omop_db methods class MockOmopDB: def get_session(self): return self # not used after error + def execute_query(self, query): raise SQLAlchemyError("Mocked SQLAlchemy error") + def close(self): pass class MockBiasDB: def create_cohort_definition(self, *args, **kwargs): pass + def create_cohort_in_bulk(self, *args, **kwargs): pass + def close(self): pass @@ -295,6 +335,7 @@ def close(self): assert result is None + def test_cohort_creation_negative_instance(test_db): """ Test cohort creation with negative event_instance (last occurrence of a condition). 
@@ -303,25 +344,32 @@ def test_cohort_creation_negative_instance(test_db): cohort = bias.create_cohort( "Diabetes patients (last occurrence)", "Cohort of female patients born 1970-2000 with the last Type 2 diabetes diagnosis", - os.path.join(os.path.dirname(__file__), '..', 'assets', 'cohort_creation', - 'test_cohort_creation_negative_instance.yaml'), - "test_user" + os.path.join( + os.path.dirname(__file__), "..", "assets", "cohort_creation", "test_cohort_creation_negative_instance.yaml" + ), + "test_user", ) # Test cohort object and methods assert cohort is not None, "Cohort creation failed" assert cohort.data is not None, "Cohort creation returned None data" - patient_ids = set([item['subject_id'] for item in cohort.data]) + patient_ids = set([item["subject_id"] for item in cohort.data]) assert_equal(len(patient_ids), 6) # Female patients 1, 2, 3, 5 assert_equal(patient_ids, {1, 2, 3, 5, 6, 7}) # Verify dates for a specific patient (e.g., patient 1 with last diabetes diagnosis) - patient_1 = next(item for item in cohort.data if item['subject_id'] == 1) - assert_equal(patient_1['cohort_start_date'], datetime.date(2020, 6, 1), - "Incorrect cohort_start_date for patient 1 (last diabetes)") - assert_equal(patient_1['cohort_end_date'], datetime.date(2020, 6, 1), - "Incorrect cohort_end_date for patient 1 (last diabetes)") + patient_1 = next(item for item in cohort.data if item["subject_id"] == 1) + assert_equal( + patient_1["cohort_start_date"], + datetime.date(2020, 6, 1), + "Incorrect cohort_start_date for patient 1 (last diabetes)", + ) + assert_equal( + patient_1["cohort_end_date"], + datetime.date(2020, 6, 1), + "Incorrect cohort_end_date for patient 1 (last diabetes)", + ) def test_cohort_creation_offset(test_db): @@ -332,28 +380,33 @@ def test_cohort_creation_offset(test_db): cohort = bias.create_cohort( "Diabetes patients with offset", "Cohort of female patients born 1970-2000 with Type 2 diabetes diagnosis, adjusted by +180 and -730 days", - os.path.join(os.path.dirname(__file__), '..', 'assets', 'cohort_creation', - 'test_cohort_creation_offset.yaml'), - "test_user" + os.path.join(os.path.dirname(__file__), "..", "assets", "cohort_creation", "test_cohort_creation_offset.yaml"), + "test_user", ) # Test cohort object and methods assert cohort is not None, "Cohort creation failed" assert cohort.metadata is not None, "Cohort creation wrongly returned None metadata" - assert 'creation_info' in cohort.metadata, "Cohort creation does not contain 'creation_info' key" + assert "creation_info" in cohort.metadata, "Cohort creation does not contain 'creation_info' key" assert cohort.data is not None, "Cohort creation wrongly returned None data" - patient_ids = set([item['subject_id'] for item in cohort.data]) + patient_ids = set([item["subject_id"] for item in cohort.data]) assert_equal(len(patient_ids), 6) # Female patients 1, 2, 3, 5 assert_equal(patient_ids, {1, 2, 3, 5, 6, 7}) # Verify dates for a specific patient (e.g., patient 1 with offset) - patient_1 = next(item for item in cohort.data if item['subject_id'] == 1) + patient_1 = next(item for item in cohort.data if item["subject_id"] == 1) # Diabetes on 2020-06-01: -730 days = 2018-06-02, +180 days = 2020-11-28 - assert_equal(patient_1['cohort_start_date'], datetime.date(2018, 6, 2), - "Incorrect cohort_start_date for patient 1 (with -730 day offset)") - assert_equal(patient_1['cohort_end_date'], datetime.date(2020, 11, 28), - "Incorrect cohort_end_date for patient 1 (with +180 day offset)") + assert_equal( + 
patient_1["cohort_start_date"], + datetime.date(2018, 6, 2), + "Incorrect cohort_start_date for patient 1 (with -730 day offset)", + ) + assert_equal( + patient_1["cohort_end_date"], + datetime.date(2020, 11, 28), + "Incorrect cohort_end_date for patient 1 (with +180 day offset)", + ) def test_cohort_creation_negative_instance_offset(test_db): @@ -364,25 +417,36 @@ def test_cohort_creation_negative_instance_offset(test_db): cohort = bias.create_cohort( "Diabetes patients (last occurrence with offset)", "Cohort of female patients born 1970-2000 with the last Type 2 diabetes diagnosis, adjusted by +180 days", - os.path.join(os.path.dirname(__file__), '..', 'assets', 'cohort_creation', - 'test_cohort_creation_negative_instance_offset.yaml'), - "test_user" + os.path.join( + os.path.dirname(__file__), + "..", + "assets", + "cohort_creation", + "test_cohort_creation_negative_instance_offset.yaml", + ), + "test_user", ) # Test cohort object and methods assert cohort is not None, "Cohort creation failed" assert cohort.metadata is not None, "Cohort creation wrongly returned None metadata" - assert 'creation_info' in cohort.metadata, "Cohort creation does not contain 'creation_info' key" + assert "creation_info" in cohort.metadata, "Cohort creation does not contain 'creation_info' key" assert cohort.data is not None, "Cohort creation wrongly returned None data" - patient_ids = set([item['subject_id'] for item in cohort.data]) + patient_ids = set([item["subject_id"] for item in cohort.data]) assert_equal(len(patient_ids), 6) assert_equal(patient_ids, {1, 2, 3, 5, 6, 7}) # Verify dates for a specific patient (e.g., patient 1 with last diabetes and offset) - patient_1 = next(item for item in cohort.data if item['subject_id'] == 1) + patient_1 = next(item for item in cohort.data if item["subject_id"] == 1) # Last diabetes on 2020-06-01: +180 days = 2020-11-28 - assert_equal(patient_1['cohort_start_date'], datetime.date(2020, 6, 1), - "Incorrect cohort_start_date for patient 1 (last diabetes)") - assert_equal(patient_1['cohort_end_date'], datetime.date(2020, 11, 28), - "Incorrect cohort_end_date for patient 1 (last diabetes with +180 day offset)") + assert_equal( + patient_1["cohort_start_date"], + datetime.date(2020, 6, 1), + "Incorrect cohort_start_date for patient 1 (last diabetes)", + ) + assert_equal( + patient_1["cohort_end_date"], + datetime.date(2020, 11, 28), + "Incorrect cohort_end_date for patient 1 (last diabetes with +180 day offset)", + ) diff --git a/tests/query_based/test_hierarchical_prevalence.py b/tests/query_based/test_hierarchical_prevalence.py index a12d173..41ab068 100644 --- a/tests/query_based/test_hierarchical_prevalence.py +++ b/tests/query_based/test_hierarchical_prevalence.py @@ -1,6 +1,5 @@ import pytest -from functools import reduce -from biasanalyzer.concept import ConceptHierarchy, ConceptNode +from biasanalyzer.concept import ConceptHierarchy def test_cohort_concept_hierarchical_prevalence(test_db, caplog): @@ -13,47 +12,66 @@ def test_cohort_concept_hierarchical_prevalence(test_db, caplog): """ cohort = bias.create_cohort( - "Diabetes Cohort", - "Cohort of patients with diabetes-related conditions", - cohort_query, - "test_user" + "Diabetes Cohort", "Cohort of patients with diabetes-related conditions", cohort_query, "test_user" ) # Test cohort object and methods assert cohort is not None, "Cohort creation failed" # test concept_type must be one of the supported OMOP domain name with pytest.raises(ValueError): - cohort.get_concept_stats(concept_type='dummy_invalid') + 
cohort.get_concept_stats(concept_type="dummy_invalid") # test vocab must be None to use the default vocab or one of the supported OMOP vocabulary id with pytest.raises(ValueError): - cohort.get_concept_stats(vocab='dummy_invalid_vocab') + cohort.get_concept_stats(vocab="dummy_invalid_vocab") # test the cohort does not have procedure_occurrence related concepts with pytest.raises(ValueError): - cohort.get_concept_stats(concept_type='procedure_occurrence') + cohort.get_concept_stats(concept_type="procedure_occurrence") - concept_stats, _ = cohort.get_concept_stats(vocab='ICD10CM', print_concept_hierarchy=True) + concept_stats, _ = cohort.get_concept_stats(vocab="ICD10CM", print_concept_hierarchy=True) assert concept_stats is not None, "Failed to fetch concept stats" assert len(concept_stats) > 0, "No concept stats returned" # check returned data - assert not all(s['ancestor_concept_id'] == s['descendant_concept_id'] - for s in concept_stats['condition_occurrence']), \ - "Some ancestor_concept_id and descendant_concept_id should differ" + assert not all( + s["ancestor_concept_id"] == s["descendant_concept_id"] for s in concept_stats["condition_occurrence"] + ), "Some ancestor_concept_id and descendant_concept_id should differ" # Check concept prevalence for overlaps - diabetes_prevalence = next((c for c in concept_stats['condition_occurrence'] - if c['ancestor_concept_id'] == 1 and c['descendant_concept_id'] == 1), None) + diabetes_prevalence = next( + ( + c + for c in concept_stats["condition_occurrence"] + if c["ancestor_concept_id"] == 1 and c["descendant_concept_id"] == 1 + ), + None, + ) assert diabetes_prevalence is not None, "Parent diabetes concept prevalence missing" - type1_prevalence = next((c for c in concept_stats['condition_occurrence'] - if c['ancestor_concept_id'] == 2 and c['descendant_concept_id'] == 2), None) + type1_prevalence = next( + ( + c + for c in concept_stats["condition_occurrence"] + if c["ancestor_concept_id"] == 2 and c["descendant_concept_id"] == 2 + ), + None, + ) assert type1_prevalence is not None, "Child type 1 diabetes concept prevalence missing" - type2_prevalence = next((c for c in concept_stats['condition_occurrence'] - if c['ancestor_concept_id'] == 3 and c['descendant_concept_id'] == 3), None) + type2_prevalence = next( + ( + c + for c in concept_stats["condition_occurrence"] + if c["ancestor_concept_id"] == 3 and c["descendant_concept_id"] == 3 + ), + None, + ) assert type2_prevalence is not None, "Child type 2 diabetes concept prevalence missing" - print(f"type1_prevalence: {type1_prevalence['prevalence']}, type2_prevalence: {type2_prevalence['prevalence']}, " - f"diabetes_prevalence: {diabetes_prevalence['prevalence']}") - assert diabetes_prevalence['prevalence'] < type1_prevalence['prevalence'] + type2_prevalence['prevalence'], \ - ("Parent diabetes concept prevalence does not reflect overlap between type 1 and type 2 diabetes " - "children concept prevalence") + print( + f"type1_prevalence: {type1_prevalence['prevalence']}, type2_prevalence: {type2_prevalence['prevalence']}, " + f"diabetes_prevalence: {diabetes_prevalence['prevalence']}" + ) + assert diabetes_prevalence["prevalence"] < type1_prevalence["prevalence"] + type2_prevalence["prevalence"], ( + "Parent diabetes concept prevalence does not reflect overlap between type 1 and type 2 diabetes " + "children concept prevalence" + ) + def test_identifier_normalization_and_cache(): ConceptHierarchy.clear_cache() @@ -63,39 +81,60 @@ def test_identifier_normalization_and_cache(): # fake minimal 
results to build hierarchy results1 = [ - {"ancestor_concept_id": 1, "descendant_concept_id": 1, - "concept_name": "Diabetes", "concept_code": "DIA", - "count_in_cohort": 5, "prevalence": 0.5} + { + "ancestor_concept_id": 1, + "descendant_concept_id": 1, + "concept_name": "Diabetes", + "concept_code": "DIA", + "count_in_cohort": 5, + "prevalence": 0.5, + } ] results2 = [ - {"ancestor_concept_id": 1, "descendant_concept_id": 1, - "concept_name": "Diabetes2", "concept_code": "DIA", - "count_in_cohort": 15, "prevalence": 0.15} + { + "ancestor_concept_id": 1, + "descendant_concept_id": 1, + "concept_name": "Diabetes2", + "concept_code": "DIA", + "count_in_cohort": 15, + "prevalence": 0.15, + } ] - h1 = ConceptHierarchy.build_concept_hierarchy_from_results(1, 'condition_occurrence', results1) - h2 = ConceptHierarchy.build_concept_hierarchy_from_results(1, 'condition_occurrence', results2) + h1 = ConceptHierarchy.build_concept_hierarchy_from_results(1, "condition_occurrence", results1) + h2 = ConceptHierarchy.build_concept_hierarchy_from_results(1, "condition_occurrence", results2) assert h1 is h2 # cache reuse even though results2 is different from results1 assert h1.identifier == "1-condition_occurrence-0-None" - h2 = ConceptHierarchy.build_concept_hierarchy_from_results(1, 'drug_exposure', results2) - assert not h1 is h2 # cache is not used since drug_exposure concept_name is different than the cached + h2 = ConceptHierarchy.build_concept_hierarchy_from_results(1, "drug_exposure", results2) + assert h1 is not h2 # cache is not used since drug_exposure concept_name is different than the cached # condition_occurrence assert h2.identifier == "1-drug_exposure-0-None" + def test_union_and_cache_behavior(): ConceptHierarchy.clear_cache() results1 = [ - {"ancestor_concept_id": 1, "descendant_concept_id": 1, - "concept_name": "Diabetes", "concept_code": "DIA", - "count_in_cohort": 5, "prevalence": 0.5} + { + "ancestor_concept_id": 1, + "descendant_concept_id": 1, + "concept_name": "Diabetes", + "concept_code": "DIA", + "count_in_cohort": 5, + "prevalence": 0.5, + } ] results2 = [ - {"ancestor_concept_id": 2, "descendant_concept_id": 2, - "concept_name": "Hypertension", "concept_code": "HYP", - "count_in_cohort": 3, "prevalence": 0.3} + { + "ancestor_concept_id": 2, + "descendant_concept_id": 2, + "concept_name": "Hypertension", + "concept_code": "HYP", + "count_in_cohort": 3, + "prevalence": 0.3, + } ] - h1 = ConceptHierarchy.build_concept_hierarchy_from_results(1, 'condition_occurrence', results1) - h2 = ConceptHierarchy.build_concept_hierarchy_from_results(2, 'condition_occurrence', results2) + h1 = ConceptHierarchy.build_concept_hierarchy_from_results(1, "condition_occurrence", results1) + h2 = ConceptHierarchy.build_concept_hierarchy_from_results(2, "condition_occurrence", results2) assert "1-condition_occurrence-0-None" in ConceptHierarchy._graph_cache assert "2-condition_occurrence-0-None" in ConceptHierarchy._graph_cache h12 = h1.union(h2) @@ -104,17 +143,28 @@ def test_union_and_cache_behavior(): assert h21.identifier == "1-condition_occurrence-0-None+2-condition_occurrence-0-None" assert h12 is h21 + def test_traversal_and_serialization(): ConceptHierarchy.clear_cache() results = [ - {"ancestor_concept_id": 1, "descendant_concept_id": 1, - "concept_name": "Root", "concept_code": "R", - "count_in_cohort": 5, "prevalence": 0.5}, - {"ancestor_concept_id": 1, "descendant_concept_id": 2, - "concept_name": "Child", "concept_code": "C", - "count_in_cohort": 2, "prevalence": 0.2} + { + 
"ancestor_concept_id": 1, + "descendant_concept_id": 1, + "concept_name": "Root", + "concept_code": "R", + "count_in_cohort": 5, + "prevalence": 0.5, + }, + { + "ancestor_concept_id": 1, + "descendant_concept_id": 2, + "concept_name": "Child", + "concept_code": "C", + "count_in_cohort": 2, + "prevalence": 0.2, + }, ] - h = ConceptHierarchy.build_concept_hierarchy_from_results(1, 'condition_occurrence', results) + h = ConceptHierarchy.build_concept_hierarchy_from_results(1, "condition_occurrence", results) # roots roots = h.get_root_nodes() @@ -128,16 +178,12 @@ def test_traversal_and_serialization(): leaf_nodes = h.get_leaf_nodes(serialization=True) assert leaf_nodes == [ { - 'concept_id': 2, - 'concept_name': 'Child', - 'concept_code': 'C', - 'metrics': { - '1': { - 'count': 2, 'prevalence': 0.2 - } - }, - 'source_cohorts': [1], - 'parent_ids': [1] + "concept_id": 2, + "concept_name": "Child", + "concept_code": "C", + "metrics": {"1": {"count": 2, "prevalence": 0.2}}, + "source_cohorts": [1], + "parent_ids": [1], } ] @@ -152,14 +198,9 @@ def test_traversal_and_serialization(): "concept_id": 1, "concept_name": "Root", "concept_code": "R", - "metrics": { - "1": { - "count": 5, - "prevalence": 0.5 - } - }, - 'source_cohorts': [1], - "parent_ids": [] + "metrics": {"1": {"count": 5, "prevalence": 0.5}}, + "source_cohorts": [1], + "parent_ids": [], } # graph traversal @@ -180,7 +221,7 @@ def test_traversal_and_serialization(): dfs_nodes = [n.id for n in h.iter_nodes(1, order="dfs")] assert set(dfs_nodes) == {1, 2} - dfs_nodes = [n['concept_id'] for n in h.iter_nodes(1, order="dfs", serialization=True)] + dfs_nodes = [n["concept_id"] for n in h.iter_nodes(1, order="dfs", serialization=True)] assert set(dfs_nodes) == {1, 2} # serialization @@ -196,27 +237,51 @@ def test_traversal_and_serialization(): h.to_dict(111) h_dict = h.to_dict(1, include_union_metrics=True) - assert h_dict == {'hierarchy': [{ - 'concept_id': 1, 'concept_name': 'Root', 'concept_code': 'R', - 'metrics': {'union': {'count': 5, 'prevalence': 0.5}, - '1': {'count': 5, 'prevalence': 0.5}}, - 'source_cohorts': [1], - 'parent_ids': [], - 'children': [{'concept_id': 2, 'concept_name': 'Child', 'concept_code': 'C', - 'metrics': {'union': {'count': 2, 'prevalence': 0.2}, - '1': {'count': 2, 'prevalence': 0.2}}, - 'source_cohorts': [1], - 'parent_ids': [1], 'children': []}]} - ]} + assert h_dict == { + "hierarchy": [ + { + "concept_id": 1, + "concept_name": "Root", + "concept_code": "R", + "metrics": {"union": {"count": 5, "prevalence": 0.5}, "1": {"count": 5, "prevalence": 0.5}}, + "source_cohorts": [1], + "parent_ids": [], + "children": [ + { + "concept_id": 2, + "concept_name": "Child", + "concept_code": "C", + "metrics": {"union": {"count": 2, "prevalence": 0.2}, "1": {"count": 2, "prevalence": 0.2}}, + "source_cohorts": [1], + "parent_ids": [1], + "children": [], + } + ], + } + ] + } h_dict = h.to_dict() - assert h_dict == {'hierarchy': [{ - 'concept_id': 1, 'concept_name': 'Root', 'concept_code': 'R', - 'metrics': {'1': {'count': 5, 'prevalence': 0.5}}, - 'source_cohorts': [1], - 'parent_ids': [], - 'children': [{'concept_id': 2, 'concept_name': 'Child', 'concept_code': 'C', - 'metrics': {'1': {'count': 2, 'prevalence': 0.2}}, - 'source_cohorts': [1], - 'parent_ids': [1], 'children': []}]} - ]} + assert h_dict == { + "hierarchy": [ + { + "concept_id": 1, + "concept_name": "Root", + "concept_code": "R", + "metrics": {"1": {"count": 5, "prevalence": 0.5}}, + "source_cohorts": [1], + "parent_ids": [], + "children": [ + { + 
"concept_id": 2, + "concept_name": "Child", + "concept_code": "C", + "metrics": {"1": {"count": 2, "prevalence": 0.2}}, + "source_cohorts": [1], + "parent_ids": [1], + "children": [], + } + ], + } + ] + } diff --git a/tests/test_biasanalyzer_api.py b/tests/test_biasanalyzer_api.py index a47a3a5..6f42567 100644 --- a/tests/test_biasanalyzer_api.py +++ b/tests/test_biasanalyzer_api.py @@ -1,46 +1,49 @@ -import os import datetime import logging +import os + import pytest -from ipytree import Node -from biasanalyzer.concept import ConceptHierarchy from biasanalyzer import __version__ +from biasanalyzer.concept import ConceptHierarchy +from ipytree import Node def test_version(): - assert __version__ == '0.1.0' + assert __version__ == "0.1.0" + def test_set_config(caplog, fresh_bias_obj): caplog.clear() with caplog.at_level(logging.INFO): - fresh_bias_obj.set_config('') - assert 'no configuration file specified' in caplog.text + fresh_bias_obj.set_config("") + assert "no configuration file specified" in caplog.text caplog.clear() with caplog.at_level(logging.ERROR): - fresh_bias_obj.set_config('non_existent_config_file.yaml') - assert 'does not exist' in caplog.text + fresh_bias_obj.set_config("non_existent_config_file.yaml") + assert "does not exist" in caplog.text caplog.clear() with caplog.at_level(logging.ERROR): - invalid_config_file = os.path.join(os.path.dirname(__file__), 'assets', 'config', - 'test_invalid_config.yaml') + invalid_config_file = os.path.join(os.path.dirname(__file__), "assets", "config", "test_invalid_config.yaml") fresh_bias_obj.set_config(invalid_config_file) - assert 'is not valid' in caplog.text + assert "is not valid" in caplog.text + def test_set_root_omop(monkeypatch, caplog, fresh_bias_obj): caplog.clear() with caplog.at_level(logging.INFO): fresh_bias_obj.set_root_omop() - assert 'no valid configuration' in caplog.text + assert "no valid configuration" in caplog.text caplog.clear() with caplog.at_level(logging.INFO): - config_file_with_unsupported_db_type = os.path.join(os.path.dirname(__file__), 'assets', 'config', - 'test_config_unsupported_db_type.yaml') + config_file_with_unsupported_db_type = os.path.join( + os.path.dirname(__file__), "assets", "config", "test_config_unsupported_db_type.yaml" + ) fresh_bias_obj.set_config(config_file_with_unsupported_db_type) fresh_bias_obj.set_root_omop() - assert 'Unsupported database type' in caplog.text + assert "Unsupported database type" in caplog.text # Create a fake postgresql config config = { @@ -50,7 +53,7 @@ def test_set_root_omop(monkeypatch, caplog, fresh_bias_obj): "password": "testpass", "hostname": "localhost", "port": 5432, - "database": "testdb" + "database": "testdb", } } @@ -61,6 +64,7 @@ def test_set_root_omop(monkeypatch, caplog, fresh_bias_obj): class MockOMOPCDMDatabase: def __init__(self, db_url): self.db_url = db_url + def close(self): pass @@ -89,35 +93,41 @@ def close(self): assert fresh_bias_obj.bias_db is not None assert fresh_bias_obj.bias_db.omop_cdm_db_url == "postgresql://testuser:testpass@localhost:5432/testdb" + def test_set_cohort_action(caplog, fresh_bias_obj): caplog.clear() with caplog.at_level(logging.INFO): fresh_bias_obj._set_cohort_action() - assert 'valid OMOP CDM must be set' in caplog.text + assert "valid OMOP CDM must be set" in caplog.text + def test_create_cohort_with_no_action(caplog, fresh_bias_obj): caplog.clear() with caplog.at_level(logging.INFO): - fresh_bias_obj.create_cohort('test', 'test', 'test.yaml', 'test') - assert 'failed to create a valid cohort action 
object' in caplog.text + fresh_bias_obj.create_cohort("test", "test", "test.yaml", "test") + assert "failed to create a valid cohort action object" in caplog.text + def test_compare_cohort_with_no_action(caplog, fresh_bias_obj): caplog.clear() with caplog.at_level(logging.INFO): fresh_bias_obj.compare_cohorts(1, 2) - assert 'failed to create a valid cohort action object' in caplog.text + assert "failed to create a valid cohort action object" in caplog.text + def test_cohorts_concept_stats_empty_input_cohorts(caplog, fresh_bias_obj): caplog.clear() with caplog.at_level(logging.INFO): fresh_bias_obj.get_cohorts_concept_stats([]) - assert 'The input cohorts list is empty. At least one cohort id must be provided.' in caplog.text + assert "The input cohorts list is empty. At least one cohort id must be provided." in caplog.text + def test_cohorts_concept_stats_no_cohort_action(caplog, fresh_bias_obj): caplog.clear() with caplog.at_level(logging.INFO): fresh_bias_obj.get_cohorts_concept_stats([1]) - assert 'failed to get concept prevalence stats for the union of cohorts' in caplog.text + assert "failed to get concept prevalence stats for the union of cohorts" in caplog.text + def test_cohorts_union_concept_stats(test_db): ConceptHierarchy.clear_cache() @@ -146,8 +156,8 @@ def test_cohorts_union_concept_stats(test_db): print("Concept stats per cohort:\n", stats_df.to_string(index=False), flush=True) union_result = test_db.get_cohorts_concept_stats([1, 2]) - print(f'union_result: {union_result}', flush=True) - union_result['hierarchy'] = sorted(union_result['hierarchy'], key=lambda x: x['concept_id']) + print(f"union_result: {union_result}", flush=True) + union_result["hierarchy"] = sorted(union_result["hierarchy"], key=lambda x: x["concept_id"]) # NOTE: The union_result takes cohort_start_date and cohort_end_date into account # when joining cohort with condition_occurrence for inclusion/exclusion criteria. # That means counts may differ from the raw numbers above. For example: @@ -158,124 +168,169 @@ def test_cohorts_union_concept_stats(test_db): # - Concept 5 disappears entirely, because its single occurrence is outside # the cohort date window. # This explains why union_result values differ from the raw stats above. 
- assert union_result == {'hierarchy': [ - {'concept_id': 316139, 'concept_name': 'Heart failure', 'concept_code': '84114007', - 'metrics': {'1': {'count': 2, 'prevalence': 0.4}, - '2': {'count': 2, 'prevalence': 0.5}}, - 'source_cohorts': [1, 2], - 'parent_ids': [], 'children': []}, - {'concept_id': 4041664, 'concept_name': 'Difficulty breathing', 'concept_code': '230145002', - 'metrics': { - '1': {'count': 4, 'prevalence': 0.8}, - '2': {'count': 1, 'prevalence': 0.25} - }, - 'source_cohorts': [1, 2], - 'parent_ids': [], 'children': []}, - {'concept_id': 37311061, 'concept_name': 'COVID-19', 'concept_code': '840539006', - 'metrics': {'1': {'count': 4, 'prevalence': 0.8}, - '2': {'count': 4, 'prevalence': 1.0}}, - 'source_cohorts': [1, 2], - 'parent_ids': [], 'children': []}, - ]} + assert union_result == { + "hierarchy": [ + { + "concept_id": 316139, + "concept_name": "Heart failure", + "concept_code": "84114007", + "metrics": {"1": {"count": 2, "prevalence": 0.4}, "2": {"count": 2, "prevalence": 0.5}}, + "source_cohorts": [1, 2], + "parent_ids": [], + "children": [], + }, + { + "concept_id": 4041664, + "concept_name": "Difficulty breathing", + "concept_code": "230145002", + "metrics": {"1": {"count": 4, "prevalence": 0.8}, "2": {"count": 1, "prevalence": 0.25}}, + "source_cohorts": [1, 2], + "parent_ids": [], + "children": [], + }, + { + "concept_id": 37311061, + "concept_name": "COVID-19", + "concept_code": "840539006", + "metrics": {"1": {"count": 4, "prevalence": 0.8}, "2": {"count": 4, "prevalence": 1.0}}, + "source_cohorts": [1, 2], + "parent_ids": [], + "children": [], + }, + ] + } + def test_get_domains_and_vocabularies_invalid(caplog, fresh_bias_obj): caplog.clear() with caplog.at_level(logging.INFO): fresh_bias_obj.get_domains_and_vocabularies() - assert 'valid OMOP CDM must be set' in caplog.text + assert "valid OMOP CDM must be set" in caplog.text + def test_get_domains_and_vocabularies(test_db): domains_and_vocabularies = test_db.get_domains_and_vocabularies() - expected = [{'domain_id': 'Condition', 'vocabulary_id': 'ICD10CM'}, - {'domain_id': 'Condition', 'vocabulary_id': 'SNOMED'}] + expected = [ + {"domain_id": "Condition", "vocabulary_id": "ICD10CM"}, + {"domain_id": "Condition", "vocabulary_id": "SNOMED"}, + ] assert domains_and_vocabularies == expected + def test_get_concepts_no_omop_cdm(caplog, fresh_bias_obj): caplog.clear() with caplog.at_level(logging.INFO): - fresh_bias_obj.get_concepts('dummy') - assert 'valid OMOP CDM must be set' in caplog.text + fresh_bias_obj.get_concepts("dummy") + assert "valid OMOP CDM must be set" in caplog.text + def test_get_concepts_no_domain_and_vocab(caplog, test_db): caplog.clear() with caplog.at_level(logging.INFO): - test_db.get_concepts('dummy') - assert 'either domain or vocabulary must be set' in caplog.text + test_db.get_concepts("dummy") + assert "either domain or vocabulary must be set" in caplog.text + def test_get_concepts(test_db): - concepts = test_db.get_concepts('Heart failure', domain='Condition', vocabulary='SNOMED') - expected = [{'concept_id': 316139, 'concept_name': 'Heart failure', - 'valid_start_date': datetime.date(2012, 4, 1), - 'valid_end_date': datetime.date(2020, 4, 1), - 'domain_id': 'Condition', 'vocabulary_id': 'SNOMED'}] + concepts = test_db.get_concepts("Heart failure", domain="Condition", vocabulary="SNOMED") + expected = [ + { + "concept_id": 316139, + "concept_name": "Heart failure", + "valid_start_date": datetime.date(2012, 4, 1), + "valid_end_date": datetime.date(2020, 4, 1), + "domain_id": 
"Condition", + "vocabulary_id": "SNOMED", + } + ] assert concepts == expected - concepts = test_db.get_concepts('Heart failure', vocabulary='SNOMED') + concepts = test_db.get_concepts("Heart failure", vocabulary="SNOMED") assert concepts == expected - concepts = test_db.get_concepts('Heart failure', domain='Condition') - print(f'concepts: {concepts}', flush=True) + concepts = test_db.get_concepts("Heart failure", domain="Condition") + print(f"concepts: {concepts}", flush=True) assert concepts == expected + def test_get_concept_hierarchy_no_omop_cdm(caplog, fresh_bias_obj): caplog.clear() with caplog.at_level(logging.INFO): - fresh_bias_obj.get_concept_hierarchy('dummy') - assert 'valid OMOP CDM must be set' in caplog.text + fresh_bias_obj.get_concept_hierarchy("dummy") + assert "valid OMOP CDM must be set" in caplog.text + def test_get_concept_hierarchy(test_db): with pytest.raises(ValueError): - test_db.get_concept_hierarchy('not_int_str') + test_db.get_concept_hierarchy("not_int_str") hierarchy = test_db.get_concept_hierarchy(2) - print(f'hierarchy: {hierarchy}', flush=True) - expected = ({'details': {'concept_id': 2, 'concept_name': 'Type 1 Diabetes Mellitus', 'vocabulary_id': 'ICD10CM', - 'concept_code': 'E10'}, 'parents': [{'details': {'concept_id': 1, 'concept_name': - 'Diabetes Mellitus', 'vocabulary_id': 'ICD10CM', 'concept_code': 'E10-E14'}, 'parents': []}]}, - {'details': {'concept_id': 2, 'concept_name': 'Type 1 Diabetes Mellitus', 'vocabulary_id': 'ICD10CM', - 'concept_code': 'E10'}, 'children': [{'details': {'concept_id': 4, 'concept_name': - 'Diabetic Retinopathy', 'vocabulary_id': 'ICD10CM', 'concept_code': 'E10.3/E11.3'}, - 'children': []}]}) + print(f"hierarchy: {hierarchy}", flush=True) + expected = ( + { + "details": { + "concept_id": 2, + "concept_name": "Type 1 Diabetes Mellitus", + "vocabulary_id": "ICD10CM", + "concept_code": "E10", + }, + "parents": [ + { + "details": { + "concept_id": 1, + "concept_name": "Diabetes Mellitus", + "vocabulary_id": "ICD10CM", + "concept_code": "E10-E14", + }, + "parents": [], + } + ], + }, + { + "details": { + "concept_id": 2, + "concept_name": "Type 1 Diabetes Mellitus", + "vocabulary_id": "ICD10CM", + "concept_code": "E10", + }, + "children": [ + { + "details": { + "concept_id": 4, + "concept_name": "Diabetic Retinopathy", + "vocabulary_id": "ICD10CM", + "concept_code": "E10.3/E11.3", + }, + "children": [], + } + ], + }, + ) assert hierarchy == expected + def test_display_concept_tree_text_format(capsys, test_db): - sample_tree = { - "details": { - "concept_id": 123, - "concept_name": "Hypertension", - "concept_code": "I10" - } - } + sample_tree = {"details": {"concept_id": 123, "concept_name": "Hypertension", "concept_code": "I10"}} test_db.display_concept_tree(sample_tree) captured = capsys.readouterr() assert "concept tree must contain parents or children key" in captured.out - sample_tree['children'] = [{ - "details": { - "concept_id": 456, - "concept_name": "Essential Hypertension", - "concept_code": "I10.0" - }, - "children": [] - }] + sample_tree["children"] = [ + { + "details": {"concept_id": 456, "concept_name": "Essential Hypertension", "concept_code": "I10.0"}, + "children": [], + } + ] test_db.display_concept_tree(sample_tree, show_in_text_format=True) captured = capsys.readouterr() assert "Hypertension (ID: 123" in captured.out assert "Essential Hypertension (ID: 456" in captured.out + def test_display_concept_tree_widget(test_db): sample_tree = { - "details": { - "concept_id": 456, - "concept_name": "Essential 
Hypertension", - "concept_code": "I10.0" - }, - "parents": [{ - "details": { - "concept_id": 123, - "concept_name": "Hypertension", - "concept_code": "I10" - }, - "parents": [] - }] + "details": {"concept_id": 456, "concept_name": "Essential Hypertension", "concept_code": "I10.0"}, + "parents": [ + {"details": {"concept_id": 123, "concept_name": "Hypertension", "concept_code": "I10"}, "parents": []} + ], } tree_output = test_db.display_concept_tree(sample_tree, show_in_text_format=False) diff --git a/tests/test_config.py b/tests/test_config.py index 060b940..d3163d2 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,45 +1,43 @@ import os -from biasanalyzer.config import load_config, load_cohort_creation_config +from biasanalyzer.config import load_cohort_creation_config, load_config def test_load_config(): - try: - config = load_config(os.path.join(os.path.dirname(__file__), 'assets', 'config', 'test_config.yaml')) - except Exception as e: - assert False, f"load_config() raised an exception: {e}" - - assert config.get('root_omop_cdm_database') == { - 'database_type': 'duckdb', - 'username': 'test_username', - 'password': 'test_password', - 'hostname': 'test_db_hostname', - 'database': 'shared_test_db.duckdb', - 'port': 5432 + config = load_config(os.path.join(os.path.dirname(__file__), "assets", "config", "test_config.yaml")) + + assert config.get("root_omop_cdm_database") == { + "database_type": "duckdb", + "username": "test_username", + "password": "test_password", + "hostname": "test_db_hostname", + "database": "shared_test_db.duckdb", + "port": 5432, } -def test_load_cohort_creation_config(): - try: - config = load_cohort_creation_config( - os.path.join(os.path.dirname(__file__), 'assets', 'cohort_creation', - 'test_cohort_creation_condition_occurrence_config.yaml')) - except Exception as e: - assert False, f"test_load_cohort_creation_config() raised an exception: {e}" - - assert 'inclusion_criteria' in config +def test_load_cohort_creation_config(): + config = load_cohort_creation_config( + os.path.join( + os.path.dirname(__file__), + "assets", + "cohort_creation", + "test_cohort_creation_condition_occurrence_config.yaml", + ) + ) + assert "inclusion_criteria" in config # assert 'exclusion_criteria' in config - assert 'temporal_events' in config.get('inclusion_criteria') + assert "temporal_events" in config.get("inclusion_criteria") # assert 'temporal_events' in config.get('exclusion_criteria') - assert 'demographics' in config.get('inclusion_criteria') + assert "demographics" in config.get("inclusion_criteria") # assert 'demographics' in config.get('exclusion_criteria') - demographics = config.get('inclusion_criteria').get('demographics') - assert 'gender' in demographics - assert 'min_birth_year' in demographics - assert 'max_birth_year' in demographics - assert demographics['max_birth_year'] >= demographics['min_birth_year'] - assert demographics['gender'] == 'female' or demographics['gender'] == 'male' - - in_events = config.get('inclusion_criteria')['temporal_events'] - assert 'operator' in in_events[0] - assert 'events' in in_events[0] + demographics = config.get("inclusion_criteria").get("demographics") + assert "gender" in demographics + assert "min_birth_year" in demographics + assert "max_birth_year" in demographics + assert demographics["max_birth_year"] >= demographics["min_birth_year"] + assert demographics["gender"] == "female" or demographics["gender"] == "male" + + in_events = config.get("inclusion_criteria")["temporal_events"] + assert "operator" in 
in_events[0] + assert "events" in in_events[0] diff --git a/tests/test_database.py b/tests/test_database.py index 8fb1ed5..97ab8a8 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -1,6 +1,7 @@ +from unittest.mock import Mock + import duckdb import pytest -from unittest.mock import Mock from biasanalyzer.cohort_query_builder import CohortQueryBuilder from biasanalyzer.database import BiasDatabase @@ -27,6 +28,7 @@ def execute(self, query): assert "INSTALL postgres" in calls[0] assert "LOAD postgres" in calls[1] + def test_bias_db_postgres_omop_db_url(monkeypatch): # Reset singleton to get a clean instance BiasDatabase._instance = None @@ -43,21 +45,24 @@ def close(self): # Mock duckdb.connect to return our MockConn mock_connect = Mock(return_value=MockConn()) monkeypatch.setattr("duckdb.connect", mock_connect) - db = BiasDatabase(":memory:", omop_db_url="postgresql://testuser:testpass@localhost:5432/testdb") + BiasDatabase(":memory:", omop_db_url="postgresql://testuser:testpass@localhost:5432/testdb") assert len(calls) >= 3 assert any("INSTALL postgres" in call for call in calls), "INSTALL postgres must be run at BiasDatabase init" assert any("LOAD postgres" in call for call in calls), "LOAD postgres must be run at BiasDatabase init" assert any("ATTACH" in call for call in calls), "ATTACH must be run at BiasDatabase init" + def test_bias_db_invalid_omop_db_url(): BiasDatabase._instance = None - with pytest.raises(ValueError, match='Unsupported OMOP database backend'): - db = BiasDatabase(":memory:", omop_db_url='dummy_invalid_url') + with pytest.raises(ValueError, match="Unsupported OMOP database backend"): + BiasDatabase(":memory:", omop_db_url="dummy_invalid_url") + def test_create_cohort_definition_table_error_on_sequence(): BiasDatabase._instance = None db = BiasDatabase(":memory:") + class MockConn: def __init__(self): self.calls = [] @@ -76,9 +81,11 @@ def close(self): with pytest.raises(duckdb.Error, match="random error"): db._create_cohort_definition_table() + def test_create_cohort_definition_table_sequence_exists(): BiasDatabase._instance = None db = BiasDatabase(":memory:") + class MockConn: def __init__(self): self.call_count = 0 @@ -103,9 +110,11 @@ def close(self): assert db.conn.call_count >= 2 assert any("CREATE SEQUENCE" in sql for sql in db.conn.executed_sql) + def test_create_cohort_index_error(): BiasDatabase._instance = None db = BiasDatabase(":memory:") + class MockConn: def __init__(self): self.calls = [] @@ -124,9 +133,11 @@ def close(self): with pytest.raises(duckdb.Error, match="random error"): db._create_cohort_table() + def test_create_cohort_index_exists(): BiasDatabase._instance = None db = BiasDatabase(":memory:") + class MockConn: def __init__(self): self.call_count = 0 @@ -151,22 +162,24 @@ def close(self): assert db.conn.call_count >= 2 assert any("CREATE INDEX" in sql for sql in db.conn.executed_sql) + def test_get_cohort_concept_stats_handles_exception(caplog): BiasDatabase._instance = None db = BiasDatabase(":memory:") - db.omop_cdm_db_url = 'duckdb' + db.omop_cdm_db_url = "duckdb" qry_builder = CohortQueryBuilder(cohort_creation=False) with pytest.raises(ValueError): db.get_cohort_concept_stats(123, qry_builder) + def test_get_cohort_attributes_handles_exception(): BiasDatabase._instance = None db = BiasDatabase(":memory:") qry_builder = CohortQueryBuilder(cohort_creation=False) db.omop_cdm_db_url = None - result_stats = db.get_cohort_basic_stats(123, variable='age') + result_stats = db.get_cohort_basic_stats(123, variable="age") 
assert result_stats is None - result = db.get_cohort_distributions(123, 'age') + result = db.get_cohort_distributions(123, "age") assert result is None with pytest.raises(ValueError): db.get_cohort_concept_stats(123, qry_builder)