diff --git a/README.MD b/README.MD index 97f5c97..c05e67c 100644 --- a/README.MD +++ b/README.MD @@ -15,6 +15,20 @@ Output: Score 3. System for extracting drug related variants annotations from an article. Associations in which the variant affects a drug dose, response, metabolism, etc. 4. Continously fetch new pharmacogenomic articles +## Setup +To get started, you need two sources of data locally: +1. The annotations for the articles (data/variantAnnotations/var_drug_ann.tsv) +2. the articles themselves (data/articles) +These can be populated using the following commands: +``` +pixi run download-variants +pixi run update-download-map +pixi run download-articles +``` +The download-articles step takes the longest and can be skipped by unzipping data/articles.zip, creating a list of XMLs at the directory data/articles. +If you are running download-articles, make sure to create a .env at the root with your email using the format +NCBI_EMAIL=YOUR_EMAIL@SCHOOL.EDU + ## Description This repository contains Python scripts for running and building a Pharmacogenomic Agentic system to annotate and label genetic variants based on their phenotypical associations from journal articles. 
diff --git a/pixi.toml b/pixi.toml index 4bb20c4..8811cb5 100644 --- a/pixi.toml +++ b/pixi.toml @@ -12,7 +12,7 @@ platforms = ["osx-arm64"] version = "0.1.0" [tasks] -download-variants = "python -m src.load_variants.load_clinical_variants" +download-variants = "python -m src.load_variants.download_annotations_pipeline" update-download-map = "python -c 'from src.fetch_articles.article_downloader import update_downloaded_pmcids; update_downloaded_pmcids()'" download-articles = "python -m src.fetch_articles.article_downloader" diff --git a/requirements.txt b/requirements.txt index da8cd9b..8365210 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,12 @@ # SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see CONTRIBUTORS.md) # SPDX-License-Identifier: Apache-2.0 + +requests +pandas +openai +biopython +beautifulsoup4 +tqdm +matplotlib +loguru +python-dotenv \ No newline at end of file diff --git a/src/benchmark/README.md b/src/benchmark/README.md index d12e18e..e86211d 100644 --- a/src/benchmark/README.md +++ b/src/benchmark/README.md @@ -1,4 +1,4 @@ # Benchmark ## Functions -1. Calculate the niave difference between an extracted variant and the ground truth variant on Variant Annotation ID +1. 
Calculate the naive difference between an extracted variant and the ground truth variant on Variant Annotation ID diff --git a/src/fetch_articles/article_downloader.py b/src/fetch_articles/article_downloader.py index 543f01e..ee34c46 100644 --- a/src/fetch_articles/article_downloader.py +++ b/src/fetch_articles/article_downloader.py @@ -1,11 +1,13 @@ from loguru import logger from src.fetch_articles.pmcid_converter import get_unique_pmcids from src.utils.file_paths import get_project_root +from src.variant_extraction.config import ENTREZ_EMAIL from Bio import Entrez import os import json from tqdm import tqdm +Entrez.email = ENTREZ_EMAIL def fetch_pmc_content(pmcid): """ diff --git a/src/fetch_articles/pmcid_converter.py b/src/fetch_articles/pmcid_converter.py index 417ec57..1f6ae3a 100644 --- a/src/fetch_articles/pmcid_converter.py +++ b/src/fetch_articles/pmcid_converter.py @@ -13,10 +13,6 @@ # Email for NCBI Entrez.email = os.getenv("NCBI_EMAIL") -# Step 1: Function to get PMCID from PMID -import requests -from loguru import logger - import requests import time from loguru import logger diff --git a/src/load_variants/__init__.py b/src/load_variants/__init__.py index 6c56850..1470c6d 100644 --- a/src/load_variants/__init__.py +++ b/src/load_variants/__init__.py @@ -1,5 +1,4 @@ -from .load_clinical_variants import ( +from src.load_variants.load_clinical_variants import ( load_raw_variant_annotations, get_pmid_list, - variant_annotations_pipeline, ) diff --git a/src/load_variants/download_annotations_pipeline.py b/src/load_variants/download_annotations_pipeline.py new file mode 100644 index 0000000..9cfeef3 --- /dev/null +++ b/src/load_variants/download_annotations_pipeline.py @@ -0,0 +1,23 @@ +from loguru import logger +from src.load_variants.load_clinical_variants import download_and_extract_variant_annotations, load_raw_variant_annotations, get_pmid_list + +def variant_annotations_pipeline(): + """ + Loads the variant annotations tsv file and saves the unique 
PMIDs to a json file. + """ + # Download and extract the variant annotations + logger.info("Downloading and extracting variant annotations...") + download_and_extract_variant_annotations() + + # Load the variant annotations + logger.info("Loading variant annotations...") + df = load_raw_variant_annotations() + + # Get the PMIDs + logger.info("Getting PMIDs...") + pmid_list = get_pmid_list() + logger.info(f"Number of unique PMIDs: {len(pmid_list)}") + + +if __name__ == "__main__": + variant_annotations_pipeline() \ No newline at end of file diff --git a/src/load_variants/load_clinical_variants.py b/src/load_variants/load_clinical_variants.py index c43b340..e8bd36c 100644 --- a/src/load_variants/load_clinical_variants.py +++ b/src/load_variants/load_clinical_variants.py @@ -112,25 +112,3 @@ def get_pmid_list(override: bool = False) -> list: with open(pmid_list_path, "w") as f: json.dump(pmid_list, f) return pmid_list - - -def variant_annotations_pipeline(): - """ - Loads the variant annotations tsv file and saves the unique PMIDs to a json file. - """ - # Download and extract the variant annotations - logger.info("Downloading and extracting variant annotations...") - download_and_extract_variant_annotations() - - # Load the variant annotations - logger.info("Loading variant annotations...") - df = load_raw_variant_annotations() - - # Get the PMIDs - logger.info("Getting PMIDs...") - pmid_list = get_pmid_list() - logger.info(f"Number of unique PMIDs: {len(pmid_list)}") - - -if __name__ == "__main__": - variant_annotations_pipeline() diff --git a/src/variant_extraction/README.md b/src/variant_extraction/README.md new file mode 100644 index 0000000..c77cb2f --- /dev/null +++ b/src/variant_extraction/README.md @@ -0,0 +1,150 @@ +# Variant Extraction Module + +This is organized into the following Python modules, each handling a specific aspect of the workflow: + +- **config.py**: Stores configuration variables, such as URLs, file paths, and API settings. 
+- **ncbi_fetch.py**: Manages fetching PMCID and content from NCBI using the Entrez API. +- **processing.py**: Loads and processes the variant annotation dataset, including enumeration cleaning and DataFrame processing. Also interacts with the OpenAI API to extract structured genetic variant data from publication content. +- **variant_matching.py**: Compares extracted data with ground truth for accuracy evaluation. +- **visualization.py**: Generates visualizations to summarize match rates and analysis results. +- **run_variant_extraction.py**: Orchestrates the entire workflow, integrating all modules. + +## config.py +This module centralizes configuration settings to avoid hardcoding values in the codebase. + +**Variables**: +- URLs for downloading PharmGKB data (CLINICAL_VARIANTS_URL, VARIANT_ANNOTATIONS_URL). +- File paths for input and output data (VAR_DRUG_ANN_PATH, CHECKPOINT_PATH, OUTPUT_CSV_PATH, DF_NEW_CSV_PATH, WHOLE_CSV_PATH). +- NCBI Entrez email (ENTREZ_EMAIL) for API compliance. +- OpenAI model name (OPENAI_MODEL) and JSON schema (SCHEMA_TEXT) for structured API responses. +- System message template (SYSTEM_MESSAGE_TEMPLATE) for API prompts. + +## processing.py +This module handles interactions with the OpenAI API to extract structured genetic variant data. + +`clean_enum_list(enum_list)`: + +Cleans and normalizes enumeration lists by removing NaN values, splitting comma-separated strings, and ensuring uniqueness. +Used to prepare valid enumeration values for the JSON schema. + + +`load_and_prepare_data(file_path)`: + +Loads the variant annotation TSV file into a pandas DataFrame. +Extracts unique values for Phenotype Category, Significance, Metabolizer types, and Population types to create enumeration lists. +Returns the DataFrame and a dictionary of cleaned enumeration values. + + +`create_schema(enum_values)`: + +Creates a JSON schema for API responses based on the provided enumeration values. 
+Defines a structure for an array of gene objects with fields like gene, variant, drug(s), and others, enforcing strict validation. + + +`create_messages(content_text, schema_text, custom_template=None)`: + +Generates API messages with a system prompt (using SYSTEM_MESSAGE_TEMPLATE or a custom template) and user content. +The system prompt instructs the API to extract genetic variant information in the specified schema format. + + +`call_api(client, messages, schema)`: + +Makes an API call to the OpenAI model (gpt-4o-2024-08-06) with the provided messages and schema. +Returns the parsed JSON response containing extracted gene data. + + +`load_checkpoint(checkpoint_path)`: + +Loads previously processed PMIDs and results from a checkpoint file to avoid redundant API calls. +Returns a set of processed PMIDs and a list of results. + + +`save_checkpoint(checkpoint_path, processed_pmids, results)`: + +Saves processed PMIDs and results to a checkpoint file for persistence. +Ensures progress is saved after each processed row to handle interruptions. + + +`process_responses(df, client, schema_text, schema, checkpoint_path, custom_template=None)`: + +Iterates through the DataFrame to process each row’s Content_text using the OpenAI API. +Skips previously processed PMIDs based on checkpoint data. +Saves results and updates the checkpoint after each row to ensure progress is not lost. +Returns a list of flattened JSON objects with extracted gene data and associated PMIDs. + + + +## variant_matching.py +This module compares extracted data with ground truth to evaluate accuracy. + +`SimplifiedVariantMatcher`: + +`split_variants(variant_string)`: +Splits variant strings into individual components, handling delimiters like commas and slashes. + + +`preprocess_variants(variant_string, gene=None)`: +Preprocesses variant strings to handle rsIDs, star alleles, and SNP notations. +Attaches gene names to star alleles and processes complex notations (e.g., CYP2C19*2-1234G>A). 
+ + +`match_row(row)`: +Compares ground truth and predicted variants for a single row. +Returns Exact Match, Partial Match, or No Match based on set intersections. + + +`align_and_compare_datasets(df_new, flattened_df)`: + +Renames columns in input DataFrames to distinguish ground truth (_truth) and predicted (_output) data. +Merges DataFrames on PMID using an inner join. +Applies variant matching using SimplifiedVariantMatcher and compares other fields (gene, drug(s), phenotype category, significance, metabolizer types, specialty population). +Returns a DataFrame with match indicators for each field. + + + +## visualization.py +This module generates visualizations to summarize match rates and analysis results. + +`plot_match_rates(match_stats)`: + +Creates a bar plot of exact match rates for Gene, Drug, Phenotype, Significance, and Variant categories. +Uses a professional color scheme and ensures readability with appropriate labels and limits. + + +`plot_pie_charts(match_stats)`: + +Generates nested pie charts showing partial_match_rate (outer) and exact_match_rate (inner). +Includes a legend with percentage values for clarity. + + +`plot_grouped_match_rates(average_gene_match_rate, average_drug_match_rate, average_variant_match_rate)`: + +Plots a bar chart of match rates for Gene, Drug, and Variant categories, calculated by grouping data by PMID. +Adds percentage labels above bars for clarity. + + +`plot_attribute_match_rates(wholecsv)`: + +Creates a bar plot of match percentages for attributes like Match metabolizer, Match significance, etc., from the wholecsv dataset. +Returns a DataFrame summarizing the match statistics for inclusion in reports or posters. + + + +## run_variant_extraction.py +This module orchestrates the entire workflow. + +`main()`: +- Initializes the OpenAI client with the API key from the environment. +- Downloads and extracts PharmGKB data using data_download.download_and_extract_zip. 
+- Loads and prepares the variant annotation dataset using data_processing.load_and_prepare_data. +- Processes a subset of the DataFrame (e.g., 5 rows) to fetch NCBI data using data_processing.process_dataframe. +- Creates a JSON schema using processing.create_schema. +- Processes API responses to extract gene data using processing.process_responses. +- Aligns and compares datasets using variant_matching.align_and_compare_datasets. +- Calculates match statistics for various fields and grouped match rates by PMID. +- Saves output DataFrames to CSV files (DF_NEW_CSV_PATH, OUTPUT_CSV_PATH). +- Generates visualizations using visualization module functions. +- Prints match statistics and attribute match table to the console. + +## Run the variant extraction: +`python -m src.variant_extraction.run_variant_extraction` \ No newline at end of file diff --git a/src/variant_extraction/config.py b/src/variant_extraction/config.py new file mode 100644 index 0000000..44c538d --- /dev/null +++ b/src/variant_extraction/config.py @@ -0,0 +1,55 @@ +# config.py +# Configuration file for the variant annotation extraction process + +# URLs +CLINICAL_VARIANTS_URL = "https://api.pharmgkb.org/v1/download/file/data/clinicalVariants.zip" +VARIANT_ANNOTATIONS_URL = "https://api.pharmgkb.org/v1/download/file/data/variantAnnotations.zip" + +# File paths +VAR_DRUG_ANN_PATH = "./data/variantAnnotations/var_drug_ann.tsv" +CHECKPOINT_PATH = "./data/api_processing_checkpoint.json" +OUTPUT_CSV_PATH = "./data/variant_extraction/merged.csv" +DF_NEW_CSV_PATH = "./data/variant_extraction/df_new.csv" +WHOLE_CSV_PATH = "./data/variant_extraction/wholecsv.csv" + +# NCBI email +ENTREZ_EMAIL = "aron7628@gmail.com" + +# API settings +OPENAI_MODEL = "gpt-4o-2024-08-06" + +# JSON schema +SCHEMA_TEXT = ''' +{ + "type": "object", + "properties": { + "gene": {"type": "string", "description": "The specific gene related to the drug response or phenotype (e.g., CYP3A4).", "examples": ["CYP2C19", "UGT1A3"]}, + 
    "variant/haplotypes": {"type": "string", "description": "full star allele including gene, full rsid, or full haplotype", "example": ["CYP2C19*17"]}, + "drug(s)": {"type": "string", "description": "The drug(s) that are influenced by the gene variant(s).", "examples": ["abrocitinib", "mirabegron"]}, + "phenotype category": {"type": "string", "description": "Describes the type of phenotype related to the gene-drug interaction (e.g., Metabolism/PK, toxicity).", "enum": ["Metabolism/PK", "Efficacy", "Toxicity", "Other"], "examples": ["Metabolism/PK"]}, + "significance": {"type": "string", "description": "The level of importance or statistical significance of the gene-drug interaction.", "enum": ["significant", "not significant", "not stated"], "examples": ["significant", "not stated"]}, + "metabolizer types": {"type": "string", "description": "Indicates the metabolizer status of the patient based on the gene variant.", "enum": ["poor", "intermediate", "extensive", "ultrarapid"], "examples": ["poor", "extensive"]}, + "specialty population": {"type": "string", "description": "Refers to specific populations where this gene-drug interaction may have different effects.", "examples": ["healthy individuals", "African American", "pediatric"], "default": "Not specified"}, + "PMID": {"type": "integer", "description": "PMID from source spreadsheet", "example": 123345} + }, + "required": ["gene", "variant/haplotypes", "drug(s)", "phenotype category", "significance", "metabolizer types", "PMID"] +} +''' + +# System message template +SYSTEM_MESSAGE_TEMPLATE = ( + "You are tasked with extracting information from scientific articles to assist in genetic variant annotation. 
" + "Focus on identifying key details related to genetic variants, including but not limited to:\n" + "- Variant identifiers (e.g., rsIDs, gene names, protein changes like p.Val600Glu, or DNA changes like c.1799T>A).\n" + "- Associated genes, transcripts, and protein products.\n" + "- Contextual information such as clinical significance, population frequency, or related diseases and drugs.\n" + "- Methodologies or evidence supporting the findings (e.g., experimental results, population studies, computational predictions).\n\n" + "Your output must be in the form of an array of JSON objects adhering to the following schema:\n" + "{schema}\n\n" + "Each JSON object should include:\n" + "1. A unique variant identifier.\n" + "2. Relevant metadata (e.g., associated gene, protein change, clinical significance).\n" + "3. Contextual evidence supporting the variant's importance.\n\n" + "Ensure the extracted information is accurate and directly relevant to variant annotation. " + "When extracting, prioritize structured data, avoiding ambiguous or irrelevant information." 
+) \ No newline at end of file diff --git a/src/variant_extraction/processing.py b/src/variant_extraction/processing.py new file mode 100644 index 0000000..b81a65a --- /dev/null +++ b/src/variant_extraction/processing.py @@ -0,0 +1,110 @@ +# processing.py +import os +import pandas as pd +import tqdm +import json +from openai import OpenAI +from tqdm import tqdm +from src.variant_extraction.config import SCHEMA_TEXT, SYSTEM_MESSAGE_TEMPLATE, OPENAI_MODEL + +def clean_enum_list(enum_list): + """Clean and normalize enumeration lists.""" + cleaned_list = [x for x in enum_list if pd.notna(x)] + split_list = [item.strip() for sublist in cleaned_list for item in sublist.split(',')] + return list(set(split_list)) + +def load_and_prepare_data(file_path): + """Load and prepare the variant annotation DataFrame.""" + df = pd.read_csv(file_path, sep='\t') + phenotype_category_enum = clean_enum_list(df['Phenotype Category'].unique().tolist()) + significance_enum = clean_enum_list(df['Significance'].unique().tolist()) + metabolizer_types_enum = clean_enum_list(df['Metabolizer types'].unique().tolist()) + specialty_population_enum = clean_enum_list(df['Population types'].unique().tolist()) + return df, { + 'phenotype_category': phenotype_category_enum, + 'significance': significance_enum, + 'metabolizer_types': metabolizer_types_enum, + 'specialty_population': specialty_population_enum + } + +def create_schema(enum_values): + """Create JSON schema for API calls.""" + return { + "type": "object", + "properties": { + "genes": { + "type": "array", + "items": { + "type": "object", + "properties": { + "gene": {"type": "string"}, + "variant": {"type": "string", "description": "can be in the form of a star allele, rsid"}, + "drug(s)": {"type": "string"}, + "phenotype category": {"type": "string", "enum": enum_values['phenotype_category']}, + "significance": {"type": "string", "enum": ["significant", "not significant", "not stated"]}, + "metabolizer types": {"type": "string", "enum": 
enum_values['metabolizer_types']}, + "specialty population": {"type": "string"} + }, + "required": ["gene", "variant", "drug(s)", "phenotype category", "significance", "metabolizer types", "specialty population"], + "additionalProperties": False + } + } + }, + "required": ["genes"], + "additionalProperties": False + } + +def create_messages(content_text, schema_text, custom_template=None): + """Create API messages from templates.""" + system_message = custom_template or SYSTEM_MESSAGE_TEMPLATE + system_message = system_message.format(schema=schema_text) + return [ + {"role": "system", "content": system_message}, + {"role": "user", "content": content_text} + ] + +def call_api(client, messages, schema): + """Make an API call using the provided messages and schema.""" + response = client.chat.completions.create( + model=OPENAI_MODEL, + messages=messages, + response_format={ + "type": "json_schema", + "json_schema": {"name": "gene_array_response", "schema": schema, "strict": True} + } + ) + return json.loads(response.choices[0].message.content) + +def load_checkpoint(checkpoint_path): + """Load checkpoint data.""" + if os.path.exists(checkpoint_path): + with open(checkpoint_path, 'r') as f: + checkpoint = json.load(f) + return set(checkpoint.get('processed_pmids', [])), checkpoint.get('results', []) + return set(), [] + +def save_checkpoint(checkpoint_path, processed_pmids, results): + """Save checkpoint data.""" + with open(checkpoint_path, 'w') as f: + json.dump({"processed_pmids": list(processed_pmids), "results": results}, f) + +def process_responses(df, client, schema_text, schema, checkpoint_path, custom_template=None): + """Process DataFrame rows and fetch API responses.""" + processed_pmids, results = load_checkpoint(checkpoint_path) + for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing API responses"): + pmid = row['PMID'] + if pmid in processed_pmids: + continue + print(row) + content_text = row['Content_text'] + try: + messages = 
create_messages(content_text, schema_text, custom_template) + extracted_data = call_api(client, messages, schema) + for gene_info in extracted_data.get('genes', []): + gene_info['PMID'] = pmid + results.append(gene_info) + processed_pmids.add(pmid) + save_checkpoint(checkpoint_path, processed_pmids, results) + except Exception as e: + print(f"Error processing PMID {pmid}: {e}") + return results \ No newline at end of file diff --git a/src/variant_extraction/run_variant_extraction.py b/src/variant_extraction/run_variant_extraction.py new file mode 100644 index 0000000..e93032d --- /dev/null +++ b/src/variant_extraction/run_variant_extraction.py @@ -0,0 +1,102 @@ +# main.py +import pandas as pd +from openai import OpenAI +import os +import sys + +from src.variant_extraction.config import ( + VAR_DRUG_ANN_PATH, CHECKPOINT_PATH, OUTPUT_CSV_PATH, + DF_NEW_CSV_PATH, WHOLE_CSV_PATH, SCHEMA_TEXT +) + +from src.variant_extraction.processing import load_and_prepare_data, create_schema, process_responses +from src.variant_extraction.variant_matching import align_and_compare_datasets +from src.variant_extraction.visualization import plot_match_rates, plot_pie_charts, plot_grouped_match_rates, plot_attribute_match_rates + +def main(): + # Initialize OpenAI client + client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) + + # Load and prepare data + df_var_drug_ann, enum_values = load_and_prepare_data(VAR_DRUG_ANN_PATH) + + # Create schema + schema = create_schema(enum_values) + + # Process API responses + flattened_results = process_responses(df_var_drug_ann, client, SCHEMA_TEXT, schema, CHECKPOINT_PATH) + flattened_df = pd.DataFrame(flattened_results) + + # Align and compare datasets + df_aligned = align_and_compare_datasets(df_var_drug_ann, flattened_df) + + # Calculate match statistics + match_stats = { + 'gene_match_rate': df_aligned['gene_match'].mean() * 100, + 'drug_match_rate': df_aligned['drug_match'].mean() * 100, + 'phenotype_match_rate': 
df_aligned['phenotype_match'].mean() * 100, + 'variant_match_rate': (df_aligned['variant_match'] == 'Exact Match').mean() * 100, + 'partial_variant_match_rate': (df_aligned['variant_match'] == 'Partial Match').mean() * 100, + 'exact_match_rate': df_aligned[['gene_match', 'drug_match', 'phenotype_match']].all(axis=1).mean() * 100, + 'partial_match_rate': df_aligned[['gene_match', 'drug_match', 'phenotype_match']].any(axis=1).mean() * 100, + 'mismatch_rate': (~df_aligned[['gene_match', 'drug_match', 'phenotype_match', 'variant_match']].any(axis=1)).mean() * 100, + 'significance_match_rate': df_aligned['significance_match'].mean() * 100, + 'metabolizer_match_rate': df_aligned['metabolizer_match'].mean() * 100, + 'population_match_rate': df_aligned['population_match'].mean() * 100, + } + + # Grouped match rates + def normalize_split(value): + if pd.isna(value): + return set() + return set(map(str.strip, str(value).lower().replace(';', ',').split(','))) + + grouped_outputs = flattened_df.groupby('PMID').agg({ + 'gene': lambda x: set().union(*x.apply(normalize_split)), + 'variant/haplotypes': lambda x: set().union(*x.apply(normalize_split)), + 'drug(s)': lambda x: set().union(*x.apply(normalize_split)) + }).reset_index() + + grouped_drug = df_var_drug_ann.groupby('PMID').agg({ + 'Gene': lambda x: set().union(*x.apply(normalize_split)), + 'Variant/Haplotypes': lambda x: set().union(*x.apply(normalize_split)), + 'Drug(s)': lambda x: set().union(*x.apply(normalize_split)) + }).reset_index() + + merged_grouped_df = pd.merge(grouped_outputs, grouped_drug, on='PMID', suffixes=('_output', '_drug'), how='inner') + gene_matches = [] + variant_matches = [] + drug_matches = [] + for _, row in merged_grouped_df.iterrows(): + gene_matches.append(len(row['gene'].intersection(row['Gene'])) / len(row['Gene']) if len(row['Gene']) > 0 else 0) + variant_matches.append(len(row['variant/haplotypes'].intersection(row['Variant/Haplotypes'])) / len(row['Variant/Haplotypes']) if 
len(row['Variant/Haplotypes']) > 0 else 0) + drug_matches.append(len(row['drug(s)'].intersection(row['Drug(s)'])) / len(row['Drug(s)']) if len(row['Drug(s)']) > 0 else 0) + + average_gene_match_rate = sum(gene_matches) / len(gene_matches) + average_variant_match_rate = sum(variant_matches) / len(variant_matches) + average_drug_match_rate = sum(drug_matches) / len(drug_matches) + + # Save outputs + df_var_drug_ann.to_csv(DF_NEW_CSV_PATH, index=False) + flattened_df.to_csv(OUTPUT_CSV_PATH, index=False) + + # Visualizations + plot_match_rates(match_stats) + plot_pie_charts(match_stats) + plot_grouped_match_rates(average_gene_match_rate, average_drug_match_rate, average_variant_match_rate) + + table_data = plot_attribute_match_rates(pd.read_csv(WHOLE_CSV_PATH)) + + # Print results + print("Match Statistics:") + for key, value in match_stats.items(): + print(f"{key}: {value:.2f}%") + print("\nGrouped Match Rates:") + print(f"Average Gene Match Rate: {average_gene_match_rate*100:.2f}%") + print(f"Average Variant Match Rate: {average_variant_match_rate*100:.2f}%") + print(f"Average Drug Match Rate: {average_drug_match_rate*100:.2f}%") + print("\nAttribute Match Table:") + print(table_data) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/variant_extraction/variant_matching.py b/src/variant_extraction/variant_matching.py new file mode 100644 index 0000000..f419e6c --- /dev/null +++ b/src/variant_extraction/variant_matching.py @@ -0,0 +1,70 @@ +# variant_matching.py +import pandas as pd +import re + +class SimplifiedVariantMatcher: + @staticmethod + def split_variants(variant_string): + """Split variant strings into individual components.""" + if pd.isna(variant_string): + return set() + return set(map(str.strip, variant_string.replace('/', ',').split(','))) + + @staticmethod + def preprocess_variants(variant_string, gene=None): + """Preprocess variant strings.""" + variants = SimplifiedVariantMatcher.split_variants(variant_string) + processed_variants = 
set() + for variant in variants: + variant = variant.strip() + if 'rs' in variant: + rs_match = re.findall(r'rs\d+', variant) + processed_variants.update(rs_match) + elif '*' in variant and '-' in variant and gene: + star_allele, mutation = variant.split('-', 1) + processed_variants.add(f"{gene}{star_allele.strip()}") + processed_variants.add(mutation.strip()) + elif '*' in variant and gene: + processed_variants.add(f"{gene}{variant.strip()}") + elif '>' in variant: + processed_variants.add(variant.strip()) + else: + processed_variants.add(variant.strip()) + return processed_variants + + def match_row(self, row): + """Match ground truth and predicted variants for a single row.""" + truth_variants = self.preprocess_variants(row['Variant/Haplotypes_truth'], gene=row['Gene_truth']) + predicted_variants = self.preprocess_variants(row['variant/haplotypes_output'], gene=row['Gene_output']) + if predicted_variants == truth_variants: + return 'Exact Match' + if predicted_variants.intersection(truth_variants): + return 'Partial Match' + return 'No Match' + +def align_and_compare_datasets(df_new, flattened_df): + """Align and compare datasets based on PMID.""" + df_new = df_new.rename(columns={ + 'Gene': 'Gene_truth', + 'Drug(s)': 'Drug(s)_truth', + 'Phenotype Category': 'Phenotype Category_truth', + 'Variant/Haplotypes': 'Variant/Haplotypes_truth' + }) + flattened_df = flattened_df.rename(columns={ + 'gene': 'Gene_output', + 'drug(s)': 'Drug(s)_output', + 'phenotype category': 'Phenotype Category_output', + 'variant/haplotypes': 'variant/haplotypes_output' + }) + df_aligned = pd.merge(df_new, flattened_df, how='inner', on='PMID') + matcher = SimplifiedVariantMatcher() + df_aligned['variant_match'] = df_aligned.apply(lambda row: matcher.match_row(row), axis=1) + df_aligned['gene_match'] = df_aligned['Gene_truth'] == df_aligned['Gene_output'] + df_aligned['drug_match'] = df_aligned['Drug(s)_truth'] == df_aligned['Drug(s)_output'] + df_aligned['phenotype_match'] = 
df_aligned['Phenotype Category_truth'] == df_aligned['Phenotype Category_output'] + df_aligned['significance_match'] = df_aligned['Significance'] == df_aligned['significance'].map({ + 'significant': 'yes', 'not significant': 'no', 'not stated': 'not stated' + }) + df_aligned['metabolizer_match'] = df_aligned['Metabolizer types'] == df_aligned['metabolizer types'] + df_aligned['population_match'] = df_aligned['Specialty Population'] == df_aligned['specialty population'] + return df_aligned \ No newline at end of file diff --git a/src/variant_extraction/visualization.py b/src/variant_extraction/visualization.py new file mode 100644 index 0000000..57a076b --- /dev/null +++ b/src/variant_extraction/visualization.py @@ -0,0 +1,68 @@ +# visualization.py +import matplotlib.pyplot as plt + +def plot_match_rates(match_stats): + """Plot match rates for different categories.""" + match_fields = ['Gene', 'Drug', 'Phenotype', 'Significance', 'Variant'] + match_rates = [ + match_stats['gene_match_rate'], + match_stats['drug_match_rate'], + match_stats['phenotype_match_rate'], + match_stats['significance_match_rate'], + match_stats['variant_match_rate'] + ] + plt.figure(figsize=(8, 6)) + plt.bar(match_fields, match_rates, color=['#004B8D', '#175E54', '#8C1515', '#F58025', '#5D4B3C']) + plt.title('Exact Match Rates by Category', fontweight='bold', fontsize=14) + plt.ylabel('Match Rate (%)', fontweight="bold") + plt.xlabel('Category', fontweight='bold') + plt.ylim(0, 100) + plt.tight_layout() + plt.show() + +def plot_pie_charts(match_stats): + """Plot pie charts for exact and partial match rates.""" + sizes_partial = [match_stats['partial_match_rate'], 100 - match_stats['partial_match_rate']] + colors_partial = ['#175E54', 'none'] + sizes_exact = [match_stats['exact_match_rate'], 100 - match_stats['exact_match_rate']] + colors_exact = ['#8C1515', 'none'] + plt.figure(figsize=(8, 8)) + plt.pie(sizes_partial, colors=colors_partial, startangle=90, radius=1.0, wedgeprops={'linewidth': 
1, 'edgecolor': 'white'}) + plt.pie(sizes_exact, colors=colors_exact, startangle=90, radius=0.7, wedgeprops={'linewidth': 1, 'edgecolor': 'white'}) + legend_labels = [f"Partial Match ({round(match_stats['partial_match_rate'], 2)}%)", f"Exact Match ({round(match_stats['exact_match_rate'], 3)}%)"] + legend_colors = ['#175E54', '#8C1515'] + plt.legend(handles=[plt.Line2D([0], [0], color=c, lw=6) for c in legend_colors], labels=legend_labels, loc='upper right', fontsize=10, frameon=True, title="Match Types", title_fontsize=12) + plt.title('Exact vs Partial Match Rates', fontweight="bold", fontsize=14, pad=20) + plt.tight_layout() + plt.show() + +def plot_grouped_match_rates(average_gene_match_rate, average_drug_match_rate, average_variant_match_rate): + """Plot match rates grouped by PMID.""" + categories = ['Gene', 'Drug', 'Variant'] + match_rates = [average_gene_match_rate * 100, average_drug_match_rate * 100, average_variant_match_rate * 100] + colors = ['#8C1515', '#175E54', '#F58025'] + plt.figure(figsize=(10, 6)) + bars = plt.bar(categories, match_rates, color=colors, edgecolor='black', linewidth=1.2) + for bar, rate in zip(bars, match_rates): + plt.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 1, f'{rate:.1f}%', ha='center', fontweight="bold", fontsize=10, color='black') + plt.title('Match Rates by Category (Grouped by PMID)', fontweight="bold", fontsize=14, pad=20) + plt.ylabel('Match Rate (%)', fontweight="bold", fontsize=12) + plt.xlabel('Categories', fontweight="bold", fontsize=12) + plt.ylim(0, 100) + plt.tight_layout() + plt.show() + +def plot_attribute_match_rates(wholecsv): + """Plot match statistics for different attributes.""" + match_columns = ['Match metabolizer', 'Match significance', 'Match all drug', 'Match Any Drug', 'Match gene', 'Match phenotype', 'Match population'] + match_stats_new = wholecsv[match_columns].mean() * 100 + plt.figure(figsize=(10, 6)) + plt.bar(match_stats_new.index, match_stats_new.values, color=['#2E8B57', 
'#4682B4', '#6A5ACD', '#D2691E', '#556B2F', '#8B4513', '#2F4F4F']) + plt.title('Match Statistics for Different Attributes', fontsize=16) + plt.ylabel('Match Percentage (%)', fontsize=12) + plt.xlabel('Attributes', fontsize=12) + plt.xticks(rotation=45) + plt.ylim(0, 100) + plt.tight_layout() + plt.show() + return match_stats_new.reset_index().rename(columns={'index': 'Attribute', 0: 'Match Percentage'}) \ No newline at end of file