From 06a231da3ea141b46c0ed37b1b08ca7705535b9f Mon Sep 17 00:00:00 2001 From: Sonnet Xu <59452214+sonnetx@users.noreply.github.com> Date: Fri, 23 May 2025 08:35:27 -0700 Subject: [PATCH 01/12] add information extraction --- requirements.txt | 8 + src/benchmark/README.md | 2 +- src/extract_annotations/config.py | 55 ++ .../copy_of_pharmgkb_dbds_real.py | 902 ++++++++++++++++++ src/extract_annotations/data_download.py | 22 + src/extract_annotations/ncbi_fetch.py | 65 ++ src/extract_annotations/processing.py | 119 +++ src/extract_annotations/run_extraction.py | 112 +++ src/extract_annotations/variant_matching.py | 70 ++ src/extract_annotations/visualization.py | 68 ++ 10 files changed, 1422 insertions(+), 1 deletion(-) create mode 100644 src/extract_annotations/config.py create mode 100644 src/extract_annotations/copy_of_pharmgkb_dbds_real.py create mode 100644 src/extract_annotations/data_download.py create mode 100644 src/extract_annotations/ncbi_fetch.py create mode 100644 src/extract_annotations/processing.py create mode 100644 src/extract_annotations/run_extraction.py create mode 100644 src/extract_annotations/variant_matching.py create mode 100644 src/extract_annotations/visualization.py diff --git a/requirements.txt b/requirements.txt index da8cd9b..c6c0df1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,10 @@ # SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see CONTRIBUTORS.md) # SPDX-License-Identifier: Apache-2.0 + +requests +pandas +openai +biopython +beautifulsoup4 +tqdm +matplotlib \ No newline at end of file diff --git a/src/benchmark/README.md b/src/benchmark/README.md index d12e18e..e86211d 100644 --- a/src/benchmark/README.md +++ b/src/benchmark/README.md @@ -1,4 +1,4 @@ # Benchmark ## Functions -1. Calculate the niave difference between an extracted variant and the ground truth variant on Variant Annotation ID +1. 
Calculate the naive difference between an extracted variant and the ground truth variant on Variant Annotation ID diff --git a/src/extract_annotations/config.py b/src/extract_annotations/config.py new file mode 100644 index 0000000..f7d46c7 --- /dev/null +++ b/src/extract_annotations/config.py @@ -0,0 +1,55 @@ +# config.py +# Configuration file for the variant annotation extraction process + +# URLs +CLINICAL_VARIANTS_URL = "https://api.pharmgkb.org/v1/download/file/data/clinicalVariants.zip" +VARIANT_ANNOTATIONS_URL = "https://api.pharmgkb.org/v1/download/file/data/variantAnnotations.zip" + +# File paths +VAR_DRUG_ANN_PATH = "/content/variantAnnotations/var_drug_ann.tsv" +CHECKPOINT_PATH = "/content/api_processing_checkpoint.json" +OUTPUT_CSV_PATH = "/content/merged_first100.csv" +DF_NEW_CSV_PATH = "/content/df_new.csv" +WHOLE_CSV_PATH = "/content/wholecsv.csv" + +# NCBI email +ENTREZ_EMAIL = "aron7628@gmail.com" + +# API settings +OPENAI_MODEL = "gpt-4o-2024-08-06" + +# JSON schema +SCHEMA_TEXT = ''' +{ + "type": "object", + "properties": { + "gene": {"type": "string", "description": "The specific gene related to the drug response or phenotype (e.g., CYP3A4).", "examples": ["CYP2C19", "UGT1A3"]}, + "variant/haplotypes": {"type": "string", "description": "full star allele including gene, full rsid, or full haplotype", "example": ["CYP2C19*17"]}, + "drug(s)": {"type": "string", "description": "The drug(s) that are influenced by the gene variant(s).", "examples": ["abrocitinib", "mirabegron"]}, + "phenotype category": {"type": "string", "description": "Describes the type of phenotype related to the gene-drug interaction (e.g., Metabolism/PK, toxicity).", "enum": ["Metabolism/PK", "Efficacy", "Toxicity", "Other"], "examples": ["Metabolism/PK"]}, + "significance": {"type": "string", "description": "The level of importance or statistical significance of the gene-drug interaction.", "enum": ["significant", "not significant", "not stated"], "examples": ["significant", 
"not stated"]}, + "metabolizer types": {"type": "string", "description": "Indicates the metabolizer status of the patient based on the gene variant.", "enum": ["poor", "intermediate", "extensive", "ultrarapid"], "examples": ["poor", "extensive"]}, + "specialty population": {"type": "string", "description": "Refers to specific populations where this gene-drug interaction may have different effects.", "examples": ["healthy individuals", "African American", "pediatric"], "default": "Not specified"}, + "PMID": {"type": "integer", "description": "PMID from source spreadsheet", "example": 123345} + }, + "required": ["gene", "variant/haplotyptes", "drug(s)", "phenotype category", "significance", "metabolizer types", "PMID"] +} +''' + +# System message template +SYSTEM_MESSAGE_TEMPLATE = ( + "You are tasked with extracting information from scientific articles to assist in genetic variant annotation. " + "Focus on identifying key details related to genetic variants, including but not limited to:\n" + "- Variant identifiers (e.g., rsIDs, gene names, protein changes like p.Val600Glu, or DNA changes like c.1799T>A).\n" + "- Associated genes, transcripts, and protein products.\n" + "- Contextual information such as clinical significance, population frequency, or related diseases and drugs.\n" + "- Methodologies or evidence supporting the findings (e.g., experimental results, population studies, computational predictions).\n\n" + "Your output must be in the form of an array of JSON objects adhering to the following schema:\n" + "{schema}\n\n" + "Each JSON object should include:\n" + "1. A unique variant identifier.\n" + "2. Relevant metadata (e.g., associated gene, protein change, clinical significance).\n" + "3. Contextual evidence supporting the variant's importance.\n\n" + "Ensure the extracted information is accurate and directly relevant to variant annotation. " + "When extracting, prioritize structured data, avoiding ambiguous or irrelevant information." 
+) \ No newline at end of file diff --git a/src/extract_annotations/copy_of_pharmgkb_dbds_real.py b/src/extract_annotations/copy_of_pharmgkb_dbds_real.py new file mode 100644 index 0000000..1b1e98b --- /dev/null +++ b/src/extract_annotations/copy_of_pharmgkb_dbds_real.py @@ -0,0 +1,902 @@ +''' +Aaron's original code, with minor cleaning and reformatting. +No major changes were made to the logic or structure. +''' + +import requests +import zipfile +import io +import pandas as pd +from openai import OpenAI +import time +from Bio import Entrez +import os +import json +import random +from bs4 import BeautifulSoup +import tqdm + +url = "https://api.pharmgkb.org/v1/download/file/data/clinicalVariants.zip" + +# Download the zip file +try: + response = requests.get(url, stream=True) + response.raise_for_status() # Raise an exception for bad status codes + + # Unpack the zip file + with zipfile.ZipFile(io.BytesIO(response.content)) as z: + z.extractall("clinicalVariants") + print("Successfully downloaded and unpacked the zip file.") + +except requests.exceptions.RequestException as e: + print(f"Error downloading the file: {e}") +except zipfile.BadZipFile as e: + print(f"Error unpacking the zip file: {e}") +except Exception as e: + print(f"An unexpected error occurred: {e}") + +url = "https://api.pharmgkb.org/v1/download/file/data/variantAnnotations.zip" + + +# Download the zip file +try: + response = requests.get(url, stream=True) + response.raise_for_status() # Raise an exception for bad status codes + + # Unpack the zip file + with zipfile.ZipFile(io.BytesIO(response.content)) as z: + z.extractall("variantAnnotations") + print("Successfully downloaded and unpacked the zip file.") + +except requests.exceptions.RequestException as e: + print(f"Error downloading the file: {e}") +except zipfile.BadZipFile as e: + print(f"Error unpacking the zip file: {e}") +except Exception as e: + print(f"An unexpected error occurred: {e}") + + +client = OpenAI( + api_key = 
def get_pmcid_from_pmid(pmid, retries=3):
    """Resolve a PubMed ID to its PubMed Central ID via NCBI Entrez elink.

    Retries transient failures with exponential backoff plus jitter.

    :param pmid: PubMed ID (string or int).
    :param retries: Maximum number of attempts before giving up.
    :return: PMCID string, or None when no link exists or all attempts fail.
    """
    for attempt in range(retries):
        handle = None
        try:
            handle = Entrez.elink(dbfrom="pubmed", db="pmc", id=pmid, linkname="pubmed_pmc")
            record = Entrez.read(handle)
            if record and 'LinkSetDb' in record[0] and record[0]['LinkSetDb']:
                return record[0]['LinkSetDb'][0]['Link'][0]['Id']
            print(f"No PMCID found for PMID {pmid}.")
            return None
        except Exception as e:
            print(f"An error occurred for pmid {pmid} on attempt {attempt + 1}: {e}")
            if attempt < retries - 1:
                # Backoff time increases exponentially, with jitter so
                # concurrent retries do not hammer NCBI in lockstep.
                sleep_time = (2 ** attempt) + random.uniform(0, 1)
                print(f"Retrying in {sleep_time:.2f} seconds...")
                time.sleep(sleep_time)
            else:
                return None
        finally:
            # Always release the network handle, even when Entrez.read raises
            # (the original closed it only on the success path).
            if handle is not None:
                handle.close()


def fetch_pmc_content(pmcid):
    """Fetch the full-text XML of an article from PubMed Central.

    :param pmcid: PubMed Central ID.
    :return: Raw XML string, or None on failure.
    """
    handle = None
    try:
        handle = Entrez.efetch(db="pmc", id=pmcid, rettype="full", retmode="xml")
        return handle.read()
    except Exception as e:
        print(f"An error occurred while fetching content for PMCID {pmcid}: {e}")
        return None
    finally:
        if handle is not None:
            handle.close()
def process_row(row, processed_pmids, processed_data):
    """Fetch PMC metadata and full text for one DataFrame row.

    Results are memoized in *processed_data* keyed by PMID, so duplicate
    PMIDs are fetched only once.

    :param row: Row (or mapping) with a 'PMID' field.
    :param processed_pmids: Set of PMIDs already handled (mutated in place).
    :param processed_data: PMID -> result-dict cache (mutated in place).
    :return: pd.Series with keys PMCID, Title, Content, Content_text.
    """
    pmid = str(row['PMID'])

    if pmid in processed_pmids:
        # Duplicate PMID: reuse the cached result, no network call needed.
        return pd.Series(processed_data[pmid])

    # Throttle only real requests (the original slept even for cache hits).
    time.sleep(0.4 + random.uniform(0, 0.5))

    # Single result template replaces the three near-identical dict literals.
    result = {'PMCID': None, 'Title': None, 'Content': None, 'Content_text': None}

    pmcid = get_pmcid_from_pmid(pmid)
    if pmcid:
        result['PMCID'] = pmcid
        xml_content = fetch_pmc_content(pmcid)
        if xml_content:
            # Parse the XML to extract the title and a plain-text body.
            soup = BeautifulSoup(xml_content, 'xml')
            title_tag = soup.find('article-title')
            result['Title'] = title_tag.get_text() if title_tag else "No Title Found"
            result['Content'] = xml_content
            result['Content_text'] = soup.get_text()

    processed_pmids.add(pmid)
    processed_data[pmid] = result
    return pd.Series(result)


def process_dataframe(df):
    """Apply process_row to every row of *df* with a shared PMID cache.

    :param df: DataFrame containing a 'PMID' column.
    :return: DataFrame of fetched PMCID/Title/Content columns.
    """
    # The tqdm *class* is needed here: `import tqdm` alone leaves
    # tqdm.pandas unresolvable and `tqdm(...)` uncallable.
    from tqdm import tqdm

    processed_pmids = set()
    processed_data = {}

    tqdm.pandas(desc="Processing rows")  # registers DataFrame.progress_apply
    return df.progress_apply(lambda row: process_row(row, processed_pmids, processed_data), axis=1)
def clean_enum_list(enum_list):
    """Normalize a raw enum column into a clean, deduplicated list.

    Drops NaN entries, splits comma-separated values, strips whitespace,
    discards empty strings, and deduplicates while preserving first-seen
    order. (The original `list(set(...))` made the enum ordering in the
    generated JSON schema nondeterministic across runs.)

    :param enum_list: Iterable of raw values; may contain NaN and
        comma-separated strings.
    :return: List of unique, stripped string values in first-seen order.
    """
    items = []
    for value in enum_list:
        if pd.isna(value):
            continue  # NaN entries carry no enum information
        for part in value.split(','):
            part = part.strip()
            if part:  # skip empties produced by trailing/double commas
                items.append(part)
    # dict.fromkeys deduplicates while keeping insertion order.
    return list(dict.fromkeys(items))
def fetch_test_data(df, num_rows=10):
    """Return the leading *num_rows* rows of *df* as a test subset.

    :param df: DataFrame containing the full dataset.
    :param num_rows: Number of rows to keep from the top.
    :return: DataFrame slice with the first *num_rows* rows.
    """
    return df.iloc[:num_rows]
" + "Focus on identifying key details related to genetic variants, including but not limited to:\n" + "- Variant identifiers (e.g., rsIDs, gene names, protein changes like p.Val600Glu, or DNA changes like c.1799T>A).\n" + "- Associated genes, transcripts, and protein products.\n" + "- Contextual information such as clinical significance, population frequency, or related diseases and drugs.\n" + "- Methodologies or evidence supporting the findings (e.g., experimental results, population studies, computational predictions).\n\n" + "Your output must be in the form of an array of JSON objects adhering to the following schema:\n" + f"{schema_text}\n\n" + "Each JSON object should include:\n" + "1. A unique variant identifier.\n" + "2. Relevant metadata (e.g., associated gene, protein change, clinical significance).\n" + "3. Contextual evidence supporting the variant's importance.\n\n" + "Ensure the extracted information is accurate and directly relevant to variant annotation. " + "When extracting, prioritize structured data, avoiding ambiguous or irrelevant information." + ) + + return [ + {"role": "system", "content": system_message}, + {"role": "user", "content": content_text} + ] + +def call_api(client, messages, schema): + """ + Make an API call using the provided messages and schema. + :param client: API client object. + :param messages: List of messages to send to the API. + :param schema: JSON schema definition. + :return: Parsed JSON response. + """ + response = client.chat.completions.create( + model="gpt-4o-2024-08-06", + messages=messages, + response_format={ + "type": "json_schema", + "json_schema": { + "name": "gene_array_response", + "schema": schema, + "strict": True + } + } + ) + return json.loads(response.choices[0].message.content) + +def load_checkpoint(checkpoint_path): + """ + Load checkpoint data (processed PMIDs and saved results). + :param checkpoint_path: Path to the checkpoint file. + :return: Tuple of processed PMIDs set and saved results list. 
+ """ + if os.path.exists(checkpoint_path): + with open(checkpoint_path, 'r') as f: + checkpoint = json.load(f) + return set(checkpoint.get('processed_pmids', [])), checkpoint.get('results', []) + return set(), [] + +def save_checkpoint(checkpoint_path, processed_pmids, results): + """ + Save checkpoint data (processed PMIDs and results). + :param checkpoint_path: Path to the checkpoint file. + :param processed_pmids: Set of processed PMIDs. + :param results: List of saved results. + """ + with open(checkpoint_path, 'w') as f: + json.dump({"processed_pmids": list(processed_pmids), "results": results}, f) + +def process_responses(df, client, schema_text, schema, checkpoint_path, custom_template=None): + """ + Process rows from the DataFrame and fetch API responses, saving dynamically. + :param df: DataFrame containing the rows to process. + :param client: API client object. + :param schema_text: Schema description for the API call. + :param schema: JSON schema definition. + :param checkpoint_path: Path to the checkpoint file. + :param custom_template: Optional custom template string for the system message. + :return: List of flattened JSON objects with additional metadata. 
+ """ + processed_pmids, results = load_checkpoint(checkpoint_path) + + for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing rows"): + pmid = row['PMID'] + + if pmid in processed_pmids: + continue # Skip already-processed PMIDs + + content_text = row['Content_text'] + try: + messages = create_messages(content_text, schema_text, custom_template) + extracted_data = call_api(client, messages, schema) + for gene_info in extracted_data.get('genes', []): + gene_info['PMID'] = pmid + results.append(gene_info) + + processed_pmids.add(pmid) # Mark as processed + except Exception as e: + print(f"Error processing PMID {pmid}: {e}") + + # Save progress after each processed row + save_checkpoint(checkpoint_path, processed_pmids, results) + + return results + +checkpoint_file = "/content/api_processing_checkpoint.json" +flattened_results = process_responses(df_new, client, schema_text, schema, checkpoint_file) + +# Convert to DataFrame +flattened_df = pd.DataFrame(flattened_results) + +# Print the resulting DataFrame +print(flattened_df) + +def loadCSV(path="/content/first_5000_var_drug.csv"): + return pd.read_csv(path,index_col=[0]) + +df_new = loadCSV() + +groups = df_new.groupby("PMID") + +maxV = [0,""] +lengths = set() +for i, (name, group) in enumerate(groups): + lengths.add(len(group)) + maxV[0] = max(len(group),maxV[0]) + maxV[1] = name + +print(maxV) +print(lengths) + +#load the outputs from chatgpt +outputsDF = pd.read_csv("/content/outputs_per_pmid.csv") +outputGroups = outputsDF.groupby("PMID") + +print(outputsDF.columns) +print(df_new.columns) + +column_mapping = { + 'gene': 'Gene', + 'variant/haplotypes': 'Variant/Haplotypes', + 'drug(s)': 'Drug(s)', + 'phenotype category': 'Phenotype Category', + 'significance': 'Significance', + 'metabolizer types': 'Metabolizer types', + 'specialty population': 'Specialty Population', + 'PMID': 'PMID' # Keep this for grouping +} + +# Add a new column to flattened_df to indicate whether a match was found and to store 
the Variant Annotation ID +flattened_df['match'] = False +flattened_df['Variant Annotation ID'] = None + +# Iterate through each row in flattened_df +for idx, row in flattened_df.iterrows(): + pmid = row['PMID'] + gene = row['gene'] + + # Find all rows in df_new where PMID matches + matching_pmid_rows = df_new[df_new['PMID'] == pmid] + + # Check if there is a gene match in the filtered df_new + gene_match_row = matching_pmid_rows[matching_pmid_rows['Gene'] == gene] + + if not gene_match_row.empty: + # If there's a match, attach the Variant Annotation ID to flattened_df + flattened_df.at[idx, 'Variant Annotation ID'] = gene_match_row.iloc[0]['Variant Annotation ID'] + flattened_df.at[idx, 'match'] = True # Mark as matched + else: + flattened_df.at[idx, 'match'] = False # No match found + +# Output the result +print(flattened_df) + +flattened_df = pd.read_csv("/content/outputs_per_pmid_matched.csv") + +# Perform a merge on 'Variant Annotation ID' +merged_df = pd.merge(flattened_df, df_new, on='Variant Annotation ID', how='inner', suffixes=('_flattened', '_df_new')) + +merged_df = merged_df.rename(columns={'Variant Annotation ID_flattened': 'Variant Annotation ID'}) + +# Output the final DataFrame + +merged_df.to_csv('/content/merged_first100.csv', index=False) + +df_new.to_csv('/content/df_new.csv', index=False) + +df_new_one = df_new +outputs_first100_one = flattened_df + +# Align the datasets based only on PMID, which is guaranteed to match +df_aligned_pmid = pd.merge(df_new_one, outputs_first100_one, how='inner', left_on='PMID', right_on='PMID', suffixes=('_truth', '_output')) + +# Compare key fields: 'Gene', 'Drug(s)', and 'Phenotype Category' +df_aligned_pmid['gene_match'] = df_aligned_pmid['Gene'] == df_aligned_pmid['gene'] +df_aligned_pmid['drug_match'] = df_aligned_pmid['Drug(s)'] == df_aligned_pmid['drug(s)'] +df_aligned_pmid['phenotype_match'] = df_aligned_pmid['Phenotype Category'] == df_aligned_pmid['phenotype category'] + +# Calculate the match 
import re  # used by preprocess_variants; missing from the script's imports (NameError at runtime)


class SimplifiedVariantMatcher:
    """Match ground-truth variant strings against model-predicted ones."""

    @staticmethod
    def split_variants(variant_string):
        """Split a variant string on ',' and '/' into stripped components.

        :param variant_string: Raw variant field; NaN yields an empty set.
        :return: Set of individual variant tokens.
        """
        if pd.isna(variant_string):
            return set()
        return set(map(str.strip, variant_string.replace('/', ',').split(',')))

    @staticmethod
    def preprocess_variants(variant_string, gene=None):
        """Normalize variants: attach star alleles to *gene*, extract rsIDs.

        :param variant_string: Raw variant field.
        :param gene: Gene symbol used to qualify bare star alleles.
        :return: Set of normalized variant identifiers.
        """
        variants = SimplifiedVariantMatcher.split_variants(variant_string)
        processed_variants = set()

        for variant in variants:
            variant = variant.strip()
            # Pull out any rsIDs embedded in the token.
            if 'rs' in variant:
                processed_variants.update(re.findall(r'rs\d+', variant))

            if '*' in variant and '-' in variant and gene:
                # Star allele plus extra SNP info, e.g. "*2-c.1A>G":
                # record both the gene-qualified allele and the mutation.
                star_allele, mutation = variant.split('-', 1)
                processed_variants.add(f"{gene}{star_allele.strip()}")
                processed_variants.add(mutation.strip())
            elif '*' in variant and gene:
                # Bare star allele: prefix with the gene symbol.
                processed_variants.add(f"{gene}{variant.strip()}")
            elif '>' in variant:
                # SNP notation such as "c.1A>G" is kept as-is.
                processed_variants.add(variant.strip())
            else:
                # Anything else passes through unchanged.
                processed_variants.add(variant.strip())

        return processed_variants

    def match_row(self, row):
        """Classify one aligned row as 'Exact Match', 'Partial Match' or 'No Match'.

        Equality must be tested before overlap: the original checked the
        intersection first, so identical non-empty sets were misreported
        as partial matches and 'Exact Match' was unreachable.

        :param row: Mapping with truth/output variant and gene fields.
        :return: Match category string.
        """
        truth_variants = self.preprocess_variants(row['Variant/Haplotypes_truth'], gene=row['Gene_truth'])
        predicted_variants = self.preprocess_variants(row['variant/haplotypes_output'], gene=row['Gene_output'])

        if predicted_variants == truth_variants:
            return 'Exact Match'
        if predicted_variants & truth_variants:
            return 'Partial Match'
        return 'No Match'
df_aligned_pmid['Significance'] == + df_aligned_pmid['significance'].map({ + 'significant': 'yes', + 'not significant': 'no', + 'not stated': 'not stated' + }) +) +df_aligned_pmid['metabolizer_match'] = df_aligned_pmid['Metabolizer types'] == df_aligned_pmid['metabolizer types'] +df_aligned_pmid['population_match'] = df_aligned_pmid['Specialty Population'] == df_aligned_pmid['specialty population'] + + +# Calculate match statistics for each field +match_stats = { + 'gene_match_rate': df_aligned_pmid['gene_match'].mean() * 100, + 'drug_match_rate': df_aligned_pmid['drug_match'].mean() * 100, + 'phenotype_match_rate': df_aligned_pmid['phenotype_match'].mean() * 100, + 'variant_match_rate': (df_aligned_pmid['variant_match'] == 'Exact Match').mean() * 100, + 'partial_variant_match_rate': (df_aligned_pmid['variant_match'] == 'Partial Match').mean() * 100, + 'exact_match_rate': df_aligned_pmid[['gene_match', 'drug_match', 'phenotype_match']].all(axis=1).mean() * 100, + 'partial_match_rate': df_aligned_pmid[['gene_match', 'drug_match', 'phenotype_match']].any(axis=1).mean() * 100, + 'mismatch_rate': (~df_aligned_pmid[['gene_match', 'drug_match', 'phenotype_match', 'variant_match']].any(axis=1)).mean() * 100, + 'significance_match_rate': df_aligned_pmid['significance_match'].mean() * 100, + 'metabolizer_match_rate': df_aligned_pmid['metabolizer_match'].mean() * 100, + 'population_match_rate': df_aligned_pmid['population_match'].mean() * 100, +} + +# Display match statistics +print("Match Statistics:") +for key, value in match_stats.items(): + print(f"{key}: {value:.2f}%") + + +# Assuming 'flattened_df' is your DataFrame and 'PMID' is the column name +flattened_df = flattened_df.dropna(subset=['PMID']) + + +# Update column names in comparisons based on the actual merged dataframe +gene_matches = [] +variant_matches = [] +drug_matches = [] +# Grouping annotations per PMID and creating sets for genes, variants, and drugs + +# Function to normalize and split values, handling 
def normalize_split(value):
    """Lower-case *value* and split it on ',' / ';' into a set of tokens.

    Missing values (NaN) yield an empty set; non-string inputs are
    stringified before splitting.

    :param value: Raw cell value from an annotation column.
    :return: Set of lower-cased, stripped tokens.
    """
    if pd.isna(value):
        return set()
    lowered = str(value).lower().replace(';', ',')
    return {token.strip() for token in lowered.split(',')}
Bar plot to visualize match rates for Gene, Drug, and Phenotype +match_fields = ['Gene', 'Drug', 'Phenotype','Signficance',"Variant"] + +match_rates = [match_stats['gene_match_rate'], match_stats['drug_match_rate'], match_stats['phenotype_match_rate'],match_stats['significance_match_rate'],match_stats['variant_match_rate']] + +plt.figure(figsize=(8, 6)) +plt.bar(match_fields, match_rates, color = ['#004B8D', '#175E54', '#8C1515', '#F58025','#5D4B3C'] +) +plt.title('Exact Match Rates by Category', fontweight='bold', fontsize=14) + +plt.ylabel('Match Rate (%)',fontweight="bold") +plt.xlabel('Category',fontweight='bold') +plt.ylim(0, 100) # Ensure the y-axis goes from 0 to 100 +plt.tight_layout() +plt.show() + + +# Data for Partial Match (outer pie chart) +sizes_partial = [match_stats['partial_match_rate'], 100 - match_stats['partial_match_rate']] +colors_partial = ['#175E54', 'none'] # Stanford Green for relevant portion + +# Data for Exact Match (inner pie chart) +sizes_exact = [match_stats['exact_match_rate'], 100 - match_stats['exact_match_rate']] +colors_exact = ['#8C1515', 'none'] # Cardinal Red for relevant portion + +# Create the figure +plt.figure(figsize=(8, 8)) + +# Plot the larger pie chart (Partial Match) +plt.pie( + sizes_partial, + colors=colors_partial, + startangle=90, + radius=1.0, + wedgeprops={'linewidth': 1, 'edgecolor': 'white'}, + labels=None # No direct labels, handled in legend +) + +# Plot the smaller pie chart (Exact Match) on top +plt.pie( + sizes_exact, + colors=colors_exact, + startangle=90, + radius=0.7, + wedgeprops={'linewidth': 1, 'edgecolor': 'white'}, + labels=None # No direct labels, handled in legend +) + +# Add legend with appropriate colors +legend_labels = [ + f"Partial Match ({round(match_stats['partial_match_rate'],2)}%)", + f"Exact Match ({round(match_stats['exact_match_rate'],3)}%)" +] +legend_colors = ['#175E54', '#8C1515'] # Match the colors used in the chart + +plt.legend( + handles=[ + plt.Line2D([0], [0], 
color=legend_colors[0], lw=6), # Partial Match color + plt.Line2D([0], [0], color=legend_colors[1], lw=6) # Exact Match color + ], + labels=legend_labels, + loc='upper right', + fontsize=10, + frameon=True, + title="Match Types", + title_fontsize=12 +) + +# Add a title +plt.title('Exact vs Partial Match Rates', fontweight = "bold",fontsize=14, pad=20) + +# Adjust layout for better spacing +plt.tight_layout() +plt.show() + +import matplotlib.pyplot as plt + +# Match rates grouped by PMID (these variables should be defined in your data) +# Example: average_gene_match_rate, average_variant_match_rate, average_drug_match_rate +categories = ['Gene', 'Drug','Variant'] +match_rates = [average_gene_match_rate*100, average_drug_match_rate*100,average_variant_match_rate*100,] + +# Stanford-inspired colors for bars +colors = ['#8C1515', '#175E54', '#F58025'] # Cardinal Red, Stanford Green, Stanford Orange + +# Create the bar chart +plt.figure(figsize=(10, 6)) +bars = plt.bar(categories, match_rates, color=colors, edgecolor='black', linewidth=1.2) + +# Add values on top of bars +for bar, rate in zip(bars, match_rates): + plt.text( + bar.get_x() + bar.get_width() / 2, + bar.get_height() + 1, # Adjust height for the text + f'{rate:.1f}%', # Rounded percentage + ha='center', + fontweight="bold", + fontsize=10, + color='black' + ) + +# Title and labels +plt.title('Match Rates by Category (Grouped by PMID)',fontweight="bold", fontsize=14, pad=20) +plt.ylabel('Match Rate (%)',fontweight="bold",fontsize=12) +plt.xlabel('Categories',fontweight="bold", fontsize=12) + +# Adjust ticks for readability +plt.xticks(fontsize=10) +plt.yticks(fontsize=10) + +# Set y-axis limits +plt.ylim(0, 100) # Assuming match rates are percentages + +# Add grid for better readability + + +# Show the plot +plt.tight_layout() +plt.show() + +# Load the dataset (replace the path with the correct file location in Colab) +wholecsv = pd.read_csv('/content/wholecsv.csv') + +# Summarizing match statistics based on 
the pre-marked columns +match_columns = [ + 'Match metabolizer', + 'Match significance', + 'Match all drug', + 'Match Any Drug', + 'Match gene', + 'Match phenotype', + 'Match population' +] + +# Calculating the percentage of matches for each attribute +match_stats_new = wholecsv[match_columns].mean() * 100 + +# Adjusting the color scheme to be more neutral and less bright +plt.figure(figsize=(10, 6)) +plt.bar(match_stats_new.index, match_stats_new.values, color=['#2E8B57', '#4682B4', '#6A5ACD', '#D2691E', '#556B2F', '#8B4513', '#2F4F4F']) +plt.title('Match Statistics for Different Attributes', fontsize=16) +plt.ylabel('Match Percentage (%)', fontsize=12) +plt.xlabel('Attributes', fontsize=12) +plt.xticks(rotation=45) +plt.ylim(0, 100) +plt.tight_layout() +plt.show() + + +# Creating the table summarizing match statistics for the poster +table_data = match_stats_new.reset_index() +table_data.columns = ['Attribute', 'Match Percentage'] + +# If you want to print the table to the console (optional): +print(table_data) \ No newline at end of file diff --git a/src/extract_annotations/data_download.py b/src/extract_annotations/data_download.py new file mode 100644 index 0000000..ae6c81d --- /dev/null +++ b/src/extract_annotations/data_download.py @@ -0,0 +1,22 @@ +# data_download.py +# This script downloads and extracts clinical variant and variant annotation data from specified URLs. 
+
+import requests
+import zipfile
+import io
+from config import CLINICAL_VARIANTS_URL, VARIANT_ANNOTATIONS_URL
+
+def download_and_extract_zip(url, extract_path):
+    """Download and extract a ZIP file from the given URL."""
+    try:
+        response = requests.get(url, stream=True)
+        response.raise_for_status()
+        with zipfile.ZipFile(io.BytesIO(response.content)) as z:
+            z.extractall(extract_path)
+        print(f"Successfully downloaded and unpacked {url} to {extract_path}")
+    except requests.exceptions.RequestException as e:
+        print(f"Error downloading the file from {url}: {e}")
+    except zipfile.BadZipFile as e:
+        print(f"Error unpacking the zip file: {e}")
+    except Exception as e:
+        print(f"An unexpected error occurred: {e}")
\ No newline at end of file
diff --git a/src/extract_annotations/ncbi_fetch.py b/src/extract_annotations/ncbi_fetch.py
new file mode 100644
index 0000000..6183d5c
--- /dev/null
+++ b/src/extract_annotations/ncbi_fetch.py
@@ -0,0 +1,65 @@
+# ncbi_fetch.py
+from Bio import Entrez
+import time, random
+import pandas as pd
+from bs4 import BeautifulSoup
+from config import ENTREZ_EMAIL
+
+def setup_entrez():
+    """Configure Entrez with email."""
+    Entrez.email = ENTREZ_EMAIL
+
+def get_pmcid_from_pmid(pmid, retries=3):
+    """Get PMCID from PMID with retry mechanism."""
+    for attempt in range(retries):
+        try:
+            handle = Entrez.elink(dbfrom="pubmed", db="pmc", id=pmid, linkname="pubmed_pmc")
+            record = Entrez.read(handle)
+            handle.close()
+            if record and 'LinkSetDb' in record[0] and record[0]['LinkSetDb']:
+                return record[0]['LinkSetDb'][0]['Link'][0]['Id']
+            print(f"No PMCID found for PMID {pmid}.")
+            return None
+        except Exception as e:
+            print(f"Error for PMID {pmid} on attempt {attempt + 1}: {e}")
+            if attempt < retries - 1:
+                sleep_time = (2 ** attempt) + random.uniform(0, 1)
+                print(f"Retrying in {sleep_time:.2f} seconds...")
+                time.sleep(sleep_time)
+            else:
+                return None
+
+def fetch_pmc_content(pmcid):
+    """Fetch PMC content using PMCID."""
+    try:
+        handle = 
Entrez.efetch(db="pmc", id=pmcid, rettype="full", retmode="xml")
+        record = handle.read()
+        handle.close()
+        return record
+    except Exception as e:
+        print(f"Error fetching content for PMCID {pmcid}: {e}")
+        return None
+
+def process_row(row, processed_pmids, processed_data):
+    """Process a single DataFrame row to fetch PMCID and content."""
+    time.sleep(0.4 + random.uniform(0, 0.5))
+    pmid = str(row['PMID'])
+
+    if pmid in processed_pmids:
+        return pd.Series(processed_data[pmid])
+
+    pmcid = get_pmcid_from_pmid(pmid)
+    result = {'PMCID': None, 'Title': None, 'Content': None, 'Content_text': None}
+
+    if pmcid:
+        xml_content = fetch_pmc_content(pmcid)
+        if xml_content:
+            soup = BeautifulSoup(xml_content, 'xml')
+            title_tag = soup.find('article-title')
+            title = title_tag.get_text() if title_tag else "No Title Found"
+            clean_text = soup.get_text()
+            result = {'PMCID': pmcid, 'Title': title, 'Content': xml_content, 'Content_text': clean_text}
+
+    processed_pmids.add(pmid)
+    processed_data[pmid] = result
+    return pd.Series(result)
\ No newline at end of file
diff --git a/src/extract_annotations/processing.py b/src/extract_annotations/processing.py
new file mode 100644
index 0000000..48d5384
--- /dev/null
+++ b/src/extract_annotations/processing.py
@@ -0,0 +1,119 @@
+# processing.py
+import pandas as pd
+import os
+from ncbi_fetch import process_row
+import json
+from openai import OpenAI
+from tqdm import tqdm
+from config import SCHEMA_TEXT, SYSTEM_MESSAGE_TEMPLATE, OPENAI_MODEL
+
+def clean_enum_list(enum_list):
+    """Clean and normalize enumeration lists."""
+    cleaned_list = [x for x in enum_list if pd.notna(x)]
+    split_list = [item.strip() for sublist in cleaned_list for item in sublist.split(',')]
+    return list(set(split_list))
+
+def load_and_prepare_data(file_path):
+    """Load and prepare the variant annotation DataFrame."""
+    df = pd.read_csv(file_path, sep='\t')
+    phenotype_category_enum = clean_enum_list(df['Phenotype Category'].unique().tolist())
+    
significance_enum = clean_enum_list(df['Significance'].unique().tolist()) + metabolizer_types_enum = clean_enum_list(df['Metabolizer types'].unique().tolist()) + specialty_population_enum = clean_enum_list(df['Population types'].unique().tolist()) + return df, { + 'phenotype_category': phenotype_category_enum, + 'significance': significance_enum, + 'metabolizer_types': metabolizer_types_enum, + 'specialty_population': specialty_population_enum + } + +def process_dataframe(df, num_rows=None): + """Process DataFrame rows to fetch NCBI data.""" + if num_rows: + df = df.head(num_rows) + processed_pmids = set() + processed_data = {} + tqdm.pandas(desc="Processing rows") + result_df = df.progress_apply(lambda row: process_row(row, processed_pmids, processed_data), axis=1) + return pd.concat([df, result_df], axis=1) + +def create_schema(enum_values): + """Create JSON schema for API calls.""" + return { + "type": "object", + "properties": { + "genes": { + "type": "array", + "items": { + "type": "object", + "properties": { + "gene": {"type": "string"}, + "variant": {"type": "string", "description": "can be in the form of a star allele, rsid"}, + "drug(s)": {"type": "string"}, + "phenotype category": {"type": "string", "enum": enum_values['phenotype_category']}, + "significance": {"type": "string", "enum": ["significant", "not significant", "not stated"]}, + "metabolizer types": {"type": "string", "enum": enum_values['metabolizer_types']}, + "specialty population": {"type": "string"} + }, + "required": ["gene", "variant", "drug(s)", "phenotype category", "significance", "metabolizer types", "specialty population"], + "additionalProperties": False + } + } + }, + "required": ["genes"], + "additionalProperties": False + } + +def create_messages(content_text, schema_text, custom_template=None): + """Create API messages from templates.""" + system_message = custom_template or SYSTEM_MESSAGE_TEMPLATE + system_message = system_message.format(schema=schema_text) + return [ + 
{"role": "system", "content": system_message}, + {"role": "user", "content": content_text} + ] + +def call_api(client, messages, schema): + """Make an API call using the provided messages and schema.""" + response = client.chat.completions.create( + model=OPENAI_MODEL, + messages=messages, + response_format={ + "type": "json_schema", + "json_schema": {"name": "gene_array_response", "schema": schema, "strict": True} + } + ) + return json.loads(response.choices[0].message.content) + +def load_checkpoint(checkpoint_path): + """Load checkpoint data.""" + if os.path.exists(checkpoint_path): + with open(checkpoint_path, 'r') as f: + checkpoint = json.load(f) + return set(checkpoint.get('processed_pmids', [])), checkpoint.get('results', []) + return set(), [] + +def save_checkpoint(checkpoint_path, processed_pmids, results): + """Save checkpoint data.""" + with open(checkpoint_path, 'w') as f: + json.dump({"processed_pmids": list(processed_pmids), "results": results}, f) + +def process_responses(df, client, schema_text, schema, checkpoint_path, custom_template=None): + """Process DataFrame rows and fetch API responses.""" + processed_pmids, results = load_checkpoint(checkpoint_path) + for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing API responses"): + pmid = row['PMID'] + if pmid in processed_pmids: + continue + content_text = row['Content_text'] + try: + messages = create_messages(content_text, schema_text, custom_template) + extracted_data = call_api(client, messages, schema) + for gene_info in extracted_data.get('genes', []): + gene_info['PMID'] = pmid + results.append(gene_info) + processed_pmids.add(pmid) + save_checkpoint(checkpoint_path, processed_pmids, results) + except Exception as e: + print(f"Error processing PMID {pmid}: {e}") + return results \ No newline at end of file diff --git a/src/extract_annotations/run_extraction.py b/src/extract_annotations/run_extraction.py new file mode 100644 index 0000000..67711d6 --- /dev/null +++ 
b/src/extract_annotations/run_extraction.py
@@ -0,0 +1,112 @@
+# main.py
+import pandas as pd
+from openai import OpenAI
+import os
+from config import (
+    CLINICAL_VARIANTS_URL, VARIANT_ANNOTATIONS_URL, VAR_DRUG_ANN_PATH,
+    CHECKPOINT_PATH, OUTPUT_CSV_PATH, DF_NEW_CSV_PATH, WHOLE_CSV_PATH, SCHEMA_TEXT
+)
+from data_download import download_and_extract_zip
+from processing import load_and_prepare_data, process_dataframe
+from processing import create_schema, process_responses
+from variant_matching import align_and_compare_datasets
+from visualization import plot_match_rates, plot_pie_charts, plot_grouped_match_rates, plot_attribute_match_rates
+from ncbi_fetch import setup_entrez
+
+def main():
+    # Initialize OpenAI client
+    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
+
+    # Download data
+    download_and_extract_zip(CLINICAL_VARIANTS_URL, "clinicalVariants")
+    download_and_extract_zip(VARIANT_ANNOTATIONS_URL, "variantAnnotations")
+
+    # Load and prepare data
+    df_var_drug_ann, enum_values = load_and_prepare_data(VAR_DRUG_ANN_PATH)
+
+    # Process initial DataFrame
+    test_df = process_dataframe(df_var_drug_ann, num_rows=5)
+    test_df.dropna(subset=['Content'], inplace=True)
+    test_df.reset_index(drop=True, inplace=True)
+
+    # Create schema
+    schema = create_schema(enum_values)
+
+    # Process API responses
+    flattened_results = process_responses(test_df, client, SCHEMA_TEXT, schema, CHECKPOINT_PATH)
+    flattened_df = pd.DataFrame(flattened_results)
+
+    # Align and compare datasets
+    df_aligned = align_and_compare_datasets(test_df, flattened_df)
+
+    # Calculate match statistics
+    match_stats = {
+        'gene_match_rate': df_aligned['gene_match'].mean() * 100,
+        'drug_match_rate': df_aligned['drug_match'].mean() * 100,
+        'phenotype_match_rate': df_aligned['phenotype_match'].mean() * 100,
+        'variant_match_rate': (df_aligned['variant_match'] == 'Exact Match').mean() * 100,
+        'partial_variant_match_rate': (df_aligned['variant_match'] == 'Partial Match').mean() * 
100,
+        'exact_match_rate': df_aligned[['gene_match', 'drug_match', 'phenotype_match']].all(axis=1).mean() * 100,
+        'partial_match_rate': df_aligned[['gene_match', 'drug_match', 'phenotype_match']].any(axis=1).mean() * 100,
+        'mismatch_rate': (~df_aligned[['gene_match', 'drug_match', 'phenotype_match', 'variant_match']].any(axis=1)).mean() * 100,
+        'significance_match_rate': df_aligned['significance_match'].mean() * 100,
+        'metabolizer_match_rate': df_aligned['metabolizer_match'].mean() * 100,
+        'population_match_rate': df_aligned['population_match'].mean() * 100,
+    }
+
+    # Grouped match rates
+    def normalize_split(value):
+        if pd.isna(value):
+            return set()
+        return set(map(str.strip, str(value).lower().replace(';', ',').split(',')))
+
+    grouped_outputs = flattened_df.groupby('PMID').agg({
+        'gene': lambda x: set().union(*x.apply(normalize_split)),
+        'variant/haplotypes': lambda x: set().union(*x.apply(normalize_split)),
+        'drug(s)': lambda x: set().union(*x.apply(normalize_split))
+    }).reset_index()
+
+    grouped_drug = test_df.groupby('PMID').agg({
+        'Gene': lambda x: set().union(*x.apply(normalize_split)),
+        'Variant/Haplotypes': lambda x: set().union(*x.apply(normalize_split)),
+        'Drug(s)': lambda x: set().union(*x.apply(normalize_split))
+    }).reset_index()
+
+    merged_grouped_df = pd.merge(grouped_outputs, grouped_drug, on='PMID', suffixes=('_output', '_drug'), how='inner')
+    gene_matches = []
+    variant_matches = []
+    drug_matches = []
+    for _, row in merged_grouped_df.iterrows():
+        gene_matches.append(len(row['gene'].intersection(row['Gene'])) / len(row['Gene']) if len(row['Gene']) > 0 else 0)
+        variant_matches.append(len(row['variant/haplotypes'].intersection(row['Variant/Haplotypes'])) / len(row['Variant/Haplotypes']) if len(row['Variant/Haplotypes']) > 0 else 0)
+        drug_matches.append(len(row['drug(s)'].intersection(row['Drug(s)'])) / len(row['Drug(s)']) if len(row['Drug(s)']) > 0 else 0)
+
+    average_gene_match_rate = sum(gene_matches) / len(gene_matches)
+    
average_variant_match_rate = sum(variant_matches) / len(variant_matches) + average_drug_match_rate = sum(drug_matches) / len(drug_matches) + + # Save outputs + test_df.to_csv(DF_NEW_CSV_PATH, index=False) + flattened_df.to_csv(OUTPUT_CSV_PATH, index=False) + + # Visualizations + plot_match_rates(match_stats) + plot_pie_charts(match_stats) + plot_grouped_match_rates(average_gene_match_rate, average_drug_match_rate, average_variant_match_rate) + wholecsv = pd.read_csv(WHOLE_CSV_PATH) + table_data = plot_attribute_match_rates(wholecsv) + + # Print results + print("Match Statistics:") + for key, value in match_stats.items(): + print(f"{key}: {value:.2f}%") + print("\nGrouped Match Rates:") + print(f"Average Gene Match Rate: {average_gene_match_rate*100:.2f}%") + print(f"Average Variant Match Rate: {average_variant_match_rate*100:.2f}%") + print(f"Average Drug Match Rate: {average_drug_match_rate*100:.2f}%") + print("\nAttribute Match Table:") + print(table_data) + +if __name__ == "__main__": + setup_entrez() + main() \ No newline at end of file diff --git a/src/extract_annotations/variant_matching.py b/src/extract_annotations/variant_matching.py new file mode 100644 index 0000000..f419e6c --- /dev/null +++ b/src/extract_annotations/variant_matching.py @@ -0,0 +1,70 @@ +# variant_matching.py +import pandas as pd +import re + +class SimplifiedVariantMatcher: + @staticmethod + def split_variants(variant_string): + """Split variant strings into individual components.""" + if pd.isna(variant_string): + return set() + return set(map(str.strip, variant_string.replace('/', ',').split(','))) + + @staticmethod + def preprocess_variants(variant_string, gene=None): + """Preprocess variant strings.""" + variants = SimplifiedVariantMatcher.split_variants(variant_string) + processed_variants = set() + for variant in variants: + variant = variant.strip() + if 'rs' in variant: + rs_match = re.findall(r'rs\d+', variant) + processed_variants.update(rs_match) + elif '*' in variant and '-' 
in variant and gene: + star_allele, mutation = variant.split('-', 1) + processed_variants.add(f"{gene}{star_allele.strip()}") + processed_variants.add(mutation.strip()) + elif '*' in variant and gene: + processed_variants.add(f"{gene}{variant.strip()}") + elif '>' in variant: + processed_variants.add(variant.strip()) + else: + processed_variants.add(variant.strip()) + return processed_variants + + def match_row(self, row): + """Match ground truth and predicted variants for a single row.""" + truth_variants = self.preprocess_variants(row['Variant/Haplotypes_truth'], gene=row['Gene_truth']) + predicted_variants = self.preprocess_variants(row['variant/haplotypes_output'], gene=row['Gene_output']) + if predicted_variants == truth_variants: + return 'Exact Match' + if predicted_variants.intersection(truth_variants): + return 'Partial Match' + return 'No Match' + +def align_and_compare_datasets(df_new, flattened_df): + """Align and compare datasets based on PMID.""" + df_new = df_new.rename(columns={ + 'Gene': 'Gene_truth', + 'Drug(s)': 'Drug(s)_truth', + 'Phenotype Category': 'Phenotype Category_truth', + 'Variant/Haplotypes': 'Variant/Haplotypes_truth' + }) + flattened_df = flattened_df.rename(columns={ + 'gene': 'Gene_output', + 'drug(s)': 'Drug(s)_output', + 'phenotype category': 'Phenotype Category_output', + 'variant/haplotypes': 'variant/haplotypes_output' + }) + df_aligned = pd.merge(df_new, flattened_df, how='inner', on='PMID') + matcher = SimplifiedVariantMatcher() + df_aligned['variant_match'] = df_aligned.apply(lambda row: matcher.match_row(row), axis=1) + df_aligned['gene_match'] = df_aligned['Gene_truth'] == df_aligned['Gene_output'] + df_aligned['drug_match'] = df_aligned['Drug(s)_truth'] == df_aligned['Drug(s)_output'] + df_aligned['phenotype_match'] = df_aligned['Phenotype Category_truth'] == df_aligned['Phenotype Category_output'] + df_aligned['significance_match'] = df_aligned['Significance'] == df_aligned['significance'].map({ + 'significant': 'yes', 
'not significant': 'no', 'not stated': 'not stated' + }) + df_aligned['metabolizer_match'] = df_aligned['Metabolizer types'] == df_aligned['metabolizer types'] + df_aligned['population_match'] = df_aligned['Specialty Population'] == df_aligned['specialty population'] + return df_aligned \ No newline at end of file diff --git a/src/extract_annotations/visualization.py b/src/extract_annotations/visualization.py new file mode 100644 index 0000000..57a076b --- /dev/null +++ b/src/extract_annotations/visualization.py @@ -0,0 +1,68 @@ +# visualization.py +import matplotlib.pyplot as plt + +def plot_match_rates(match_stats): + """Plot match rates for different categories.""" + match_fields = ['Gene', 'Drug', 'Phenotype', 'Significance', 'Variant'] + match_rates = [ + match_stats['gene_match_rate'], + match_stats['drug_match_rate'], + match_stats['phenotype_match_rate'], + match_stats['significance_match_rate'], + match_stats['variant_match_rate'] + ] + plt.figure(figsize=(8, 6)) + plt.bar(match_fields, match_rates, color=['#004B8D', '#175E54', '#8C1515', '#F58025', '#5D4B3C']) + plt.title('Exact Match Rates by Category', fontweight='bold', fontsize=14) + plt.ylabel('Match Rate (%)', fontweight="bold") + plt.xlabel('Category', fontweight='bold') + plt.ylim(0, 100) + plt.tight_layout() + plt.show() + +def plot_pie_charts(match_stats): + """Plot pie charts for exact and partial match rates.""" + sizes_partial = [match_stats['partial_match_rate'], 100 - match_stats['partial_match_rate']] + colors_partial = ['#175E54', 'none'] + sizes_exact = [match_stats['exact_match_rate'], 100 - match_stats['exact_match_rate']] + colors_exact = ['#8C1515', 'none'] + plt.figure(figsize=(8, 8)) + plt.pie(sizes_partial, colors=colors_partial, startangle=90, radius=1.0, wedgeprops={'linewidth': 1, 'edgecolor': 'white'}) + plt.pie(sizes_exact, colors=colors_exact, startangle=90, radius=0.7, wedgeprops={'linewidth': 1, 'edgecolor': 'white'}) + legend_labels = [f"Partial Match 
({round(match_stats['partial_match_rate'], 2)}%)", f"Exact Match ({round(match_stats['exact_match_rate'], 3)}%)"] + legend_colors = ['#175E54', '#8C1515'] + plt.legend(handles=[plt.Line2D([0], [0], color=c, lw=6) for c in legend_colors], labels=legend_labels, loc='upper right', fontsize=10, frameon=True, title="Match Types", title_fontsize=12) + plt.title('Exact vs Partial Match Rates', fontweight="bold", fontsize=14, pad=20) + plt.tight_layout() + plt.show() + +def plot_grouped_match_rates(average_gene_match_rate, average_drug_match_rate, average_variant_match_rate): + """Plot match rates grouped by PMID.""" + categories = ['Gene', 'Drug', 'Variant'] + match_rates = [average_gene_match_rate * 100, average_drug_match_rate * 100, average_variant_match_rate * 100] + colors = ['#8C1515', '#175E54', '#F58025'] + plt.figure(figsize=(10, 6)) + bars = plt.bar(categories, match_rates, color=colors, edgecolor='black', linewidth=1.2) + for bar, rate in zip(bars, match_rates): + plt.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 1, f'{rate:.1f}%', ha='center', fontweight="bold", fontsize=10, color='black') + plt.title('Match Rates by Category (Grouped by PMID)', fontweight="bold", fontsize=14, pad=20) + plt.ylabel('Match Rate (%)', fontweight="bold", fontsize=12) + plt.xlabel('Categories', fontweight="bold", fontsize=12) + plt.ylim(0, 100) + plt.tight_layout() + plt.show() + +def plot_attribute_match_rates(wholecsv): + """Plot match statistics for different attributes.""" + match_columns = ['Match metabolizer', 'Match significance', 'Match all drug', 'Match Any Drug', 'Match gene', 'Match phenotype', 'Match population'] + match_stats_new = wholecsv[match_columns].mean() * 100 + plt.figure(figsize=(10, 6)) + plt.bar(match_stats_new.index, match_stats_new.values, color=['#2E8B57', '#4682B4', '#6A5ACD', '#D2691E', '#556B2F', '#8B4513', '#2F4F4F']) + plt.title('Match Statistics for Different Attributes', fontsize=16) + plt.ylabel('Match Percentage (%)', fontsize=12) + 
plt.xlabel('Attributes', fontsize=12) + plt.xticks(rotation=45) + plt.ylim(0, 100) + plt.tight_layout() + plt.show() + return match_stats_new.reset_index().rename(columns={'index': 'Attribute', 0: 'Match Percentage'}) \ No newline at end of file From 52cfad3474025883daba67fc9feaa011ce43023f Mon Sep 17 00:00:00 2001 From: Sonnet Xu <59452214+sonnetx@users.noreply.github.com> Date: Fri, 23 May 2025 10:44:57 -0700 Subject: [PATCH 02/12] updates variant extraction --- src/{extract_annotations => variant_extraction}/config.py | 0 src/{extract_annotations => variant_extraction}/data_download.py | 0 src/{extract_annotations => variant_extraction}/ncbi_fetch.py | 0 .../original_ves.py} | 0 src/{extract_annotations => variant_extraction}/processing.py | 0 .../run_variant_extraction.py} | 0 .../variant_matching.py | 0 src/{extract_annotations => variant_extraction}/visualization.py | 0 8 files changed, 0 insertions(+), 0 deletions(-) rename src/{extract_annotations => variant_extraction}/config.py (100%) rename src/{extract_annotations => variant_extraction}/data_download.py (100%) rename src/{extract_annotations => variant_extraction}/ncbi_fetch.py (100%) rename src/{extract_annotations/copy_of_pharmgkb_dbds_real.py => variant_extraction/original_ves.py} (100%) rename src/{extract_annotations => variant_extraction}/processing.py (100%) rename src/{extract_annotations/run_extraction.py => variant_extraction/run_variant_extraction.py} (100%) rename src/{extract_annotations => variant_extraction}/variant_matching.py (100%) rename src/{extract_annotations => variant_extraction}/visualization.py (100%) diff --git a/src/extract_annotations/config.py b/src/variant_extraction/config.py similarity index 100% rename from src/extract_annotations/config.py rename to src/variant_extraction/config.py diff --git a/src/extract_annotations/data_download.py b/src/variant_extraction/data_download.py similarity index 100% rename from src/extract_annotations/data_download.py rename to 
src/variant_extraction/data_download.py diff --git a/src/extract_annotations/ncbi_fetch.py b/src/variant_extraction/ncbi_fetch.py similarity index 100% rename from src/extract_annotations/ncbi_fetch.py rename to src/variant_extraction/ncbi_fetch.py diff --git a/src/extract_annotations/copy_of_pharmgkb_dbds_real.py b/src/variant_extraction/original_ves.py similarity index 100% rename from src/extract_annotations/copy_of_pharmgkb_dbds_real.py rename to src/variant_extraction/original_ves.py diff --git a/src/extract_annotations/processing.py b/src/variant_extraction/processing.py similarity index 100% rename from src/extract_annotations/processing.py rename to src/variant_extraction/processing.py diff --git a/src/extract_annotations/run_extraction.py b/src/variant_extraction/run_variant_extraction.py similarity index 100% rename from src/extract_annotations/run_extraction.py rename to src/variant_extraction/run_variant_extraction.py diff --git a/src/extract_annotations/variant_matching.py b/src/variant_extraction/variant_matching.py similarity index 100% rename from src/extract_annotations/variant_matching.py rename to src/variant_extraction/variant_matching.py diff --git a/src/extract_annotations/visualization.py b/src/variant_extraction/visualization.py similarity index 100% rename from src/extract_annotations/visualization.py rename to src/variant_extraction/visualization.py From df0d02aaa5f8b424f502e5635e5054b1930c9d0d Mon Sep 17 00:00:00 2001 From: Shlok Natarajan Date: Fri, 23 May 2025 16:14:20 -0700 Subject: [PATCH 03/12] docs: updated setup readme --- README.MD | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/README.MD b/README.MD index 97f5c97..6cbc6bf 100644 --- a/README.MD +++ b/README.MD @@ -15,6 +15,18 @@ Output: Score 3. System for extracting drug related variants annotations from an article. Associations in which the variant affects a drug dose, response, metabolism, etc. 4. 
Continously fetch new pharmacogenomic articles +## Setup +To get started, you need to sources of data locally: +1. The annotations for the articles (data/variantAnnotations/var_drug_ann.tsv) +2. the articles themselves (data/articles) +These can be populated using the following commands: +``` +pixi run download-variants +pixi run update-download-map +pixi run download-articles +``` +The download-articles step takes the longest and can be skipped by unzipping data/articles.zip, creating a list of XMLs at the directory data/articles. + ## Description This repository contains Python scripts for running and building a Pharmacogenomic Agentic system to annotate and label genetic variants based on their phenotypical associations from journal articles. From bb7d147b7eae14a51d4515a265374b4860bce0fa Mon Sep 17 00:00:00 2001 From: Shlok Natarajan Date: Fri, 23 May 2025 16:15:05 -0700 Subject: [PATCH 04/12] fix: typo --- README.MD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.MD b/README.MD index 6cbc6bf..88f1e96 100644 --- a/README.MD +++ b/README.MD @@ -16,7 +16,7 @@ Output: Score 4. Continously fetch new pharmacogenomic articles ## Setup -To get started, you need to sources of data locally: +To get started, you need two sources of data locally: 1. The annotations for the articles (data/variantAnnotations/var_drug_ann.tsv) 2. 
the articles themselves (data/articles) These can be populated using the following commands: From 0c1ee9e5469c459c05e385038be941652947c657 Mon Sep 17 00:00:00 2001 From: Shlok Natarajan Date: Fri, 23 May 2025 16:18:23 -0700 Subject: [PATCH 05/12] docs: readme update --- README.MD | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.MD b/README.MD index 88f1e96..c05e67c 100644 --- a/README.MD +++ b/README.MD @@ -26,6 +26,8 @@ pixi run update-download-map pixi run download-articles ``` The download-articles step takes the longest and can be skipped by unzipping data/articles.zip, creating a list of XMLs at the directory data/articles. +If you are running download-articles, make sure to create a .env at the root with your email using the format +NCBI_EMAIL=YOUR_EMAIL@SCHOOL.EDU ## Description From bbf9e913bc83a83090873829dbff877f8b05f175 Mon Sep 17 00:00:00 2001 From: Shlok Natarajan Date: Fri, 23 May 2025 16:30:41 -0700 Subject: [PATCH 06/12] feat: moved download pipeline to separate file --- pixi.toml | 2 +- .../download_annotations_pipeline.py | 23 +++++++++++++++++++ src/load_variants/load_clinical_variants.py | 22 ------------------ 3 files changed, 24 insertions(+), 23 deletions(-) create mode 100644 src/load_variants/download_annotations_pipeline.py diff --git a/pixi.toml b/pixi.toml index 4bb20c4..8811cb5 100644 --- a/pixi.toml +++ b/pixi.toml @@ -12,7 +12,7 @@ platforms = ["osx-arm64"] version = "0.1.0" [tasks] -download-variants = "python -m src.load_variants.load_clinical_variants" +download-variants = "python -m src.load_variants.download_annotations_pipeline" update-download-map = "python -c 'from src.fetch_articles.article_downloader import update_downloaded_pmcids; update_downloaded_pmcids()'" download-articles = "python -m src.fetch_articles.article_downloader" diff --git a/src/load_variants/download_annotations_pipeline.py b/src/load_variants/download_annotations_pipeline.py new file mode 100644 index 0000000..0d9e1fa --- /dev/null +++ 
b/src/load_variants/download_annotations_pipeline.py @@ -0,0 +1,23 @@ +from loguru import logger +from src.load_variants.load_clinical_variants import download_and_extract_variant_annotations, load_raw_annotations, get_pmid_list + +def variant_annotations_pipeline(): + """ + Loads the variant annotations tsv file and saves the unique PMIDs to a json file. + """ + # Download and extract the variant annotations + logger.info("Downloading and extracting variant annotations...") + download_and_extract_variant_annotations() + + # Load the variant annotations + logger.info("Loading variant annotations...") + df = load_raw_annotations() + + # Get the PMIDs + logger.info("Getting PMIDs...") + pmid_list = get_pmid_list() + logger.info(f"Number of unique PMIDs: {len(pmid_list)}") + + +if __name__ == "__main__": + variant_annotations_pipeline() \ No newline at end of file diff --git a/src/load_variants/load_clinical_variants.py b/src/load_variants/load_clinical_variants.py index c43b340..e8bd36c 100644 --- a/src/load_variants/load_clinical_variants.py +++ b/src/load_variants/load_clinical_variants.py @@ -112,25 +112,3 @@ def get_pmid_list(override: bool = False) -> list: with open(pmid_list_path, "w") as f: json.dump(pmid_list, f) return pmid_list - - -def variant_annotations_pipeline(): - """ - Loads the variant annotations tsv file and saves the unique PMIDs to a json file. 
- """ - # Download and extract the variant annotations - logger.info("Downloading and extracting variant annotations...") - download_and_extract_variant_annotations() - - # Load the variant annotations - logger.info("Loading variant annotations...") - df = load_raw_variant_annotations() - - # Get the PMIDs - logger.info("Getting PMIDs...") - pmid_list = get_pmid_list() - logger.info(f"Number of unique PMIDs: {len(pmid_list)}") - - -if __name__ == "__main__": - variant_annotations_pipeline() From a0bc1bd02f00765645b5f0e1199d1f3bcfcb9aa2 Mon Sep 17 00:00:00 2001 From: Sonnet Xu <59452214+sonnetx@users.noreply.github.com> Date: Tue, 27 May 2025 21:39:33 -0700 Subject: [PATCH 07/12] delete old code --- src/variant_extraction/data_download.py | 22 - src/variant_extraction/original_ves.py | 902 ------------------------ 2 files changed, 924 deletions(-) delete mode 100644 src/variant_extraction/data_download.py delete mode 100644 src/variant_extraction/original_ves.py diff --git a/src/variant_extraction/data_download.py b/src/variant_extraction/data_download.py deleted file mode 100644 index ae6c81d..0000000 --- a/src/variant_extraction/data_download.py +++ /dev/null @@ -1,22 +0,0 @@ -# data_download.py -# This script downloads and extracts clinical variant and variant annotation data from specified URLs. 
- -import requests -import zipfile -import io -from config import CLINICAL_VARIANTS_URL, VARIANT_ANNOTATIONS_URL - -def download_and_extract_zip(url, extract_path): - """Download and extract a ZIP file from the given URL.""" - try: - response = requests.get(url, stream=True) - response.raise_for_status() - with zipfile.ZipFile(io.BytesIO(response.content)) as z: - z.extractall(extract_path) - print(f"Successfully downloaded and unpacked {url} to {extract_path}") - except requests.exceptions.RequestException as e: - print(f"Error downloading the file from {url}: {e}") - except zipfile.BadZipFile as e: - print(f"Error unpacking the zip file: {e}") - except Exception as e: - print(f"An unexpected error occurred: {e}") \ No newline at end of file diff --git a/src/variant_extraction/original_ves.py b/src/variant_extraction/original_ves.py deleted file mode 100644 index 1b1e98b..0000000 --- a/src/variant_extraction/original_ves.py +++ /dev/null @@ -1,902 +0,0 @@ -''' -Aaron's original code, with minor cleaning and reformatting. -No major changes were made to the logic or structure. 
-''' - -import requests -import zipfile -import io -import pandas as pd -from openai import OpenAI -import time -from Bio import Entrez -import os -import json -import random -from bs4 import BeautifulSoup -import tqdm - -url = "https://api.pharmgkb.org/v1/download/file/data/clinicalVariants.zip" - -# Download the zip file -try: - response = requests.get(url, stream=True) - response.raise_for_status() # Raise an exception for bad status codes - - # Unpack the zip file - with zipfile.ZipFile(io.BytesIO(response.content)) as z: - z.extractall("clinicalVariants") - print("Successfully downloaded and unpacked the zip file.") - -except requests.exceptions.RequestException as e: - print(f"Error downloading the file: {e}") -except zipfile.BadZipFile as e: - print(f"Error unpacking the zip file: {e}") -except Exception as e: - print(f"An unexpected error occurred: {e}") - -url = "https://api.pharmgkb.org/v1/download/file/data/variantAnnotations.zip" - - -# Download the zip file -try: - response = requests.get(url, stream=True) - response.raise_for_status() # Raise an exception for bad status codes - - # Unpack the zip file - with zipfile.ZipFile(io.BytesIO(response.content)) as z: - z.extractall("variantAnnotations") - print("Successfully downloaded and unpacked the zip file.") - -except requests.exceptions.RequestException as e: - print(f"Error downloading the file: {e}") -except zipfile.BadZipFile as e: - print(f"Error unpacking the zip file: {e}") -except Exception as e: - print(f"An unexpected error occurred: {e}") - - -client = OpenAI( - api_key = os.getenv("OPENAI_API_KEY"), -) - -## Loading the data - -var_drug_ann = "/content/variantAnnotations/var_drug_ann.tsv" -df_var_drug_ann = pd.read_csv(var_drug_ann, sep='\t') - -## Get unique values -phenotype_category_enum = df_var_drug_ann['Phenotype Category'].unique().tolist() -significance_enum = df_var_drug_ann['Significance'].unique().tolist() -metabolizer_types_enum = df_var_drug_ann['Metabolizer 
types'].unique().tolist() -specialty_population_enum = df_var_drug_ann['Population types'].unique().tolist() - -## Get PMCID - -# Email for NCBI -Entrez.email = "aron7628@gmail.com" - -# Step 1: Function to get PMCID from PMID -def get_pmcid_from_pmid(pmid, retries=3): - for attempt in range(retries): - try: - handle = Entrez.elink(dbfrom="pubmed", db="pmc", id=pmid, linkname="pubmed_pmc") - record = Entrez.read(handle) - handle.close() - if record and 'LinkSetDb' in record[0] and record[0]['LinkSetDb']: - pmcid = record[0]['LinkSetDb'][0]['Link'][0]['Id'] - return pmcid - else: - print(f"No PMCID found for PMID {pmid}.") - return None - except Exception as e: - print(f"An error occurred for pmid {pmid} on attempt {attempt + 1}: {e}") - if attempt < retries - 1: - # Backoff time increases exponentially with jitter - sleep_time = (2 ** attempt) + random.uniform(0, 1) - print(f"Retrying in {sleep_time:.2f} seconds...") - time.sleep(sleep_time) - else: - return None - -# Step 2: Function to fetch content using PMCID -def fetch_pmc_content(pmcid): - try: - handle = Entrez.efetch(db="pmc", id=pmcid, rettype="full", retmode="xml") - record = handle.read() - handle.close() - return record - except Exception as e: - print(f"An error occurred while fetching content for PMCID {pmcid}: {e}") - return None - -# Function to process each row in the DataFrame -def process_row(row, processed_pmids, processed_data): - time.sleep(0.4 + random.uniform(0, 0.5)) # Introduce delay to avoid throttling - pmid = str(row['PMID']) - - if pmid in processed_pmids: - # Use the previously processed data for duplicate PMIDs - return pd.Series(processed_data[pmid]) - - # Step 1: Get PMCID from PMID - pmcid = get_pmcid_from_pmid(pmid) - - if pmcid: - # Step 2: Fetch PMC content using the new fetch_pmc_content function - xml_content = fetch_pmc_content(pmcid) - - if xml_content: - # Step 3: Parse the XML content to extract text and title using BeautifulSoup - soup = BeautifulSoup(xml_content, 'xml') 
- - # Extract the article title - title_tag = soup.find('article-title') - title = title_tag.get_text() if title_tag else "No Title Found" - - # Extract the full text of the article - clean_text = soup.get_text() - - # Save processed data for this PMID - processed_pmids.add(pmid) - processed_data[pmid] = { - 'PMCID': pmcid, - 'Title': title, - 'Content': xml_content, - 'Content_text': clean_text, - } - else: - # Save processed data for failed PMC content fetch - processed_pmids.add(pmid) - processed_data[pmid] = { - 'PMCID': pmcid, - 'Title': None, - 'Content': None, - 'Content_text': None, - } - else: - # Save processed data for invalid PMIDs - processed_pmids.add(pmid) - processed_data[pmid] = { - 'PMCID': None, - 'Title': None, - 'Content': None, - 'Content_text': None, - } - - # Return the processed data for the current row - return pd.Series(processed_data[pmid]) - -# Wrapper function to handle processed PMIDs and data -def process_dataframe(df): - processed_pmids = set() - processed_data = {} - - tqdm.pandas(desc="Processing rows") # Initialize tqdm for Pandas - return df.progress_apply(lambda row: process_row(row, processed_pmids, processed_data), axis=1) - -how_many = 5 -testdf = df_var_drug_ann[:how_many] -result_df = process_dataframe(testdf) - -# Combine the results back with the original DataFrame -df_new = pd.concat([testdf, result_df], axis=1) -df_new.to_csv(f"first_{how_many}_var_drug.csv") - -df_new.dropna(subset=['Content'], inplace=True) -df_new.reset_index(drop=True, inplace=True) -print(df_new) - -len(df_new['PMID'].unique()) - -schema_text = ''' -{ - "type": "object", - "properties": { - "gene": { - "type": "string", - "description": "The specific gene related to the drug response or phenotype (e.g., CYP3A4).", - "examples": ["CYP2C19", "UGT1A3"] - }, - "variant/haplotypes": { - "type": "string", - "description": "full star allele including gene, full rsid, or full haplotype" - "example":["CYP2C19*17"] - }, - "drug(s)": { - "type": "string", - 
"description": "The drug(s) that are influenced by the gene variant(s).", - "examples": ["abrocitinib", "mirabegron"] - }, - "phenotype category": { - "type": "string", - "description": "Describes the type of phenotype related to the gene-drug interaction (e.g., Metabolism/PK, toxicity).", - "enum": ["Metabolism/PK", "Efficacy", "Toxicity", "Other"], - "examples": ["Metabolism/PK"] - }, - "significance": { - "type": "string", - "description": "The level of importance or statistical significance of the gene-drug interaction (e.g., 'significant', 'not significant').", - "enum": ["significant", "not significant", "not stated"], - "examples": ["significant", "not stated"] - }, - "metabolizer types": { - "type": "string", - "description": "Indicates the metabolizer status of the patient based on the gene variant (e.g., poor metabolizer, extensive metabolizer).", - "enum": ["poor", "intermediate", "extensive", "ultrarapid"], - "examples": ["poor", "extensive"] - }, - "specialty population": { - "type": "string", - "description": "Refers to specific populations where this gene-drug interaction may have different effects (e.g., ethnic groups, pediatric populations).", - "examples": ["healthy individuals", "African American", "pediatric"], - "default": "Not specified" - }, - "PMID": { - "type": "integer", # Changed from 'int' to 'integer' - "description": "PMID from source spreadsheet", - "example": 123345 - } - }, - "required": ["gene","variant/haplotyptes", "drug(s)", "phenotype category", "significance", "metabolizer types", "PMID"] -} -''' - -def clean_enum_list(enum_list): - # Remove NaN - cleaned_list = [x for x in enum_list if pd.notna(x)] - # Split comma-separated strings and flatten the list - split_list = [item.strip() for sublist in cleaned_list for item in sublist.split(',')] - return list(set(split_list)) # Ensure uniqueness - -# Clean phenotype_category_enum and metabolizer_types_enum -phenotype_category_enum = clean_enum_list(phenotype_category_enum) 
-metabolizer_types_enum = clean_enum_list(metabolizer_types_enum) - - -schema = { - "type": "object", # The root must be an object, not an array - "properties": { - "genes": { # The "genes" property will contain an array of gene objects - "type": "array", - "items": { - "type": "object", - "properties": { - "gene": { - "type": "string" - }, - "variant": { - "type": "string", - "description": "can be in the form of a star allele, rsid" - }, - "drug(s)": { - "type": "string" - }, - - "phenotype category": { - "type": "string", - "enum": phenotype_category_enum - }, - "significance": { - "type": "string", - "enum": ["significant", "not significant", "not stated"] - }, - "metabolizer types": { - "type": "string", - "enum": metabolizer_types_enum - }, - "specialty population": { - "type": "string" - } - - - }, - "required": ["gene","variant","drug(s)", "phenotype category", "significance", "metabolizer types", "specialty population"], - "additionalProperties": False - } - } - }, - "required": ["genes"], # The root object must have a "genes" array - "additionalProperties": False -} - -# Initialize a list to hold the flattened JSON results -flattened_results = [] - -# Loop through each row in df_new -for index, row in df_new.iterrows(): - content_text = row['Content_text'] - pmid = row['PMID'] - - # Make the API call for each row - response = client.chat.completions.create( - model="gpt-4o-2024-08-06", - messages=[ - {"role": "system", "content": f"Extract multiple gene-related information from the text and return them as an array of JSON objects with example schema{schema_text}"}, - {"role": "user", "content": content_text} # Assuming Content contains the text you're processing - ], - response_format={ - "type": "json_schema", - "json_schema": { - "name": "gene_array_response", - "schema": schema, - "strict": True - } - } - ) - - # Extract and load the JSON response - extracted_data = json.loads(response.choices[0].message.content) - - # Flatten the JSON output and 
attach PMID - for gene_info in extracted_data.get('genes', []): - gene_info['PMID'] = pmid # Attach the PMID to each gene entry - flattened_results.append(gene_info) # Add to the list of flattened results - -# Convert flattened results to a DataFrame -flattened_df = pd.DataFrame(flattened_results) - -print(flattened_df) - -# this fetches data. -def fetch_test_data(df, num_rows=10): - """ - Fetch a subset of data for testing. - :param df: DataFrame containing the full dataset. - :param num_rows: Number of rows to sample for testing. - :return: DataFrame with sampled rows. - """ - return df.head(num_rows) - -def create_messages(content_text, schema_text, custom_template=None): - """ - Create API messages from templates. - :param content_text: Text content to process. - :param schema_text: Schema description for the API call. - :param custom_template: Optional custom template string for the system message. - :return: List of messages for the API. - """ - if custom_template: - system_message = custom_template.format(schema=schema_text, content=content_text) - else: - system_message = ( - "You are tasked with extracting information from scientific articles to assist in genetic variant annotation. " - "Focus on identifying key details related to genetic variants, including but not limited to:\n" - "- Variant identifiers (e.g., rsIDs, gene names, protein changes like p.Val600Glu, or DNA changes like c.1799T>A).\n" - "- Associated genes, transcripts, and protein products.\n" - "- Contextual information such as clinical significance, population frequency, or related diseases and drugs.\n" - "- Methodologies or evidence supporting the findings (e.g., experimental results, population studies, computational predictions).\n\n" - "Your output must be in the form of an array of JSON objects adhering to the following schema:\n" - f"{schema_text}\n\n" - "Each JSON object should include:\n" - "1. A unique variant identifier.\n" - "2. 
Relevant metadata (e.g., associated gene, protein change, clinical significance).\n" - "3. Contextual evidence supporting the variant's importance.\n\n" - "Ensure the extracted information is accurate and directly relevant to variant annotation. " - "When extracting, prioritize structured data, avoiding ambiguous or irrelevant information." - ) - - return [ - {"role": "system", "content": system_message}, - {"role": "user", "content": content_text} - ] - -def call_api(client, messages, schema): - """ - Make an API call using the provided messages and schema. - :param client: API client object. - :param messages: List of messages to send to the API. - :param schema: JSON schema definition. - :return: Parsed JSON response. - """ - response = client.chat.completions.create( - model="gpt-4o-2024-08-06", - messages=messages, - response_format={ - "type": "json_schema", - "json_schema": { - "name": "gene_array_response", - "schema": schema, - "strict": True - } - } - ) - return json.loads(response.choices[0].message.content) - -def load_checkpoint(checkpoint_path): - """ - Load checkpoint data (processed PMIDs and saved results). - :param checkpoint_path: Path to the checkpoint file. - :return: Tuple of processed PMIDs set and saved results list. - """ - if os.path.exists(checkpoint_path): - with open(checkpoint_path, 'r') as f: - checkpoint = json.load(f) - return set(checkpoint.get('processed_pmids', [])), checkpoint.get('results', []) - return set(), [] - -def save_checkpoint(checkpoint_path, processed_pmids, results): - """ - Save checkpoint data (processed PMIDs and results). - :param checkpoint_path: Path to the checkpoint file. - :param processed_pmids: Set of processed PMIDs. - :param results: List of saved results. 
- """ - with open(checkpoint_path, 'w') as f: - json.dump({"processed_pmids": list(processed_pmids), "results": results}, f) - -def process_responses(df, client, schema_text, schema, checkpoint_path, custom_template=None): - """ - Process rows from the DataFrame and fetch API responses, saving dynamically. - :param df: DataFrame containing the rows to process. - :param client: API client object. - :param schema_text: Schema description for the API call. - :param schema: JSON schema definition. - :param checkpoint_path: Path to the checkpoint file. - :param custom_template: Optional custom template string for the system message. - :return: List of flattened JSON objects with additional metadata. - """ - processed_pmids, results = load_checkpoint(checkpoint_path) - - for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing rows"): - pmid = row['PMID'] - - if pmid in processed_pmids: - continue # Skip already-processed PMIDs - - content_text = row['Content_text'] - try: - messages = create_messages(content_text, schema_text, custom_template) - extracted_data = call_api(client, messages, schema) - for gene_info in extracted_data.get('genes', []): - gene_info['PMID'] = pmid - results.append(gene_info) - - processed_pmids.add(pmid) # Mark as processed - except Exception as e: - print(f"Error processing PMID {pmid}: {e}") - - # Save progress after each processed row - save_checkpoint(checkpoint_path, processed_pmids, results) - - return results - -checkpoint_file = "/content/api_processing_checkpoint.json" -flattened_results = process_responses(df_new, client, schema_text, schema, checkpoint_file) - -# Convert to DataFrame -flattened_df = pd.DataFrame(flattened_results) - -# Print the resulting DataFrame -print(flattened_df) - -def loadCSV(path="/content/first_5000_var_drug.csv"): - return pd.read_csv(path,index_col=[0]) - -df_new = loadCSV() - -groups = df_new.groupby("PMID") - -maxV = [0,""] -lengths = set() -for i, (name, group) in enumerate(groups): - 
lengths.add(len(group)) - maxV[0] = max(len(group),maxV[0]) - maxV[1] = name - -print(maxV) -print(lengths) - -#load the outputs from chatgpt -outputsDF = pd.read_csv("/content/outputs_per_pmid.csv") -outputGroups = outputsDF.groupby("PMID") - -print(outputsDF.columns) -print(df_new.columns) - -column_mapping = { - 'gene': 'Gene', - 'variant/haplotypes': 'Variant/Haplotypes', - 'drug(s)': 'Drug(s)', - 'phenotype category': 'Phenotype Category', - 'significance': 'Significance', - 'metabolizer types': 'Metabolizer types', - 'specialty population': 'Specialty Population', - 'PMID': 'PMID' # Keep this for grouping -} - -# Add a new column to flattened_df to indicate whether a match was found and to store the Variant Annotation ID -flattened_df['match'] = False -flattened_df['Variant Annotation ID'] = None - -# Iterate through each row in flattened_df -for idx, row in flattened_df.iterrows(): - pmid = row['PMID'] - gene = row['gene'] - - # Find all rows in df_new where PMID matches - matching_pmid_rows = df_new[df_new['PMID'] == pmid] - - # Check if there is a gene match in the filtered df_new - gene_match_row = matching_pmid_rows[matching_pmid_rows['Gene'] == gene] - - if not gene_match_row.empty: - # If there's a match, attach the Variant Annotation ID to flattened_df - flattened_df.at[idx, 'Variant Annotation ID'] = gene_match_row.iloc[0]['Variant Annotation ID'] - flattened_df.at[idx, 'match'] = True # Mark as matched - else: - flattened_df.at[idx, 'match'] = False # No match found - -# Output the result -print(flattened_df) - -flattened_df = pd.read_csv("/content/outputs_per_pmid_matched.csv") - -# Perform a merge on 'Variant Annotation ID' -merged_df = pd.merge(flattened_df, df_new, on='Variant Annotation ID', how='inner', suffixes=('_flattened', '_df_new')) - -merged_df = merged_df.rename(columns={'Variant Annotation ID_flattened': 'Variant Annotation ID'}) - -# Output the final DataFrame - -merged_df.to_csv('/content/merged_first100.csv', index=False) - 
-df_new.to_csv('/content/df_new.csv', index=False) - -df_new_one = df_new -outputs_first100_one = flattened_df - -# Align the datasets based only on PMID, which is guaranteed to match -df_aligned_pmid = pd.merge(df_new_one, outputs_first100_one, how='inner', left_on='PMID', right_on='PMID', suffixes=('_truth', '_output')) - -# Compare key fields: 'Gene', 'Drug(s)', and 'Phenotype Category' -df_aligned_pmid['gene_match'] = df_aligned_pmid['Gene'] == df_aligned_pmid['gene'] -df_aligned_pmid['drug_match'] = df_aligned_pmid['Drug(s)'] == df_aligned_pmid['drug(s)'] -df_aligned_pmid['phenotype_match'] = df_aligned_pmid['Phenotype Category'] == df_aligned_pmid['phenotype category'] - -# Calculate the match statistics for each field -match_stats = { - 'gene_match_rate': df_aligned_pmid['gene_match'].mean() * 100, - 'drug_match_rate': df_aligned_pmid['drug_match'].mean() * 100, - 'phenotype_match_rate': df_aligned_pmid['phenotype_match'].mean() * 100, - 'exact_match_rate': df_aligned_pmid[['gene_match', 'drug_match', 'phenotype_match']].all(axis=1).mean() * 100, - 'partial_match_rate': df_aligned_pmid[['gene_match', 'drug_match', 'phenotype_match']].any(axis=1).mean() * 100, - 'mismatch_rate': (~df_aligned_pmid[['gene_match', 'drug_match', 'phenotype_match']].any(axis=1)).mean() * 100 -} - -# Display match statistics -print("Match Statistics:") -for key, value in match_stats.items(): - print(f"{key}: {value:.2f}%") - -pmidGroups=df_aligned_pmid.groupby("PMID") - -len(flattened_df["PMID"].unique()) -len(df_new[df_new["PMID"].isin(flattened_df["PMID"].unique())]) - -# Simplified VariantMatcher class -class SimplifiedVariantMatcher: - @staticmethod - def split_variants(variant_string): - """Split variant strings into individual components.""" - if pd.isna(variant_string): - return set() - return set(map(str.strip, variant_string.replace('/', ',').split(','))) - - @staticmethod - def preprocess_variants(variant_string, gene=None): - """Attach star alleles to gene and handle 
complex variant notations.""" - variants = SimplifiedVariantMatcher.split_variants(variant_string) - processed_variants = set() - - for variant in variants: - variant = variant.strip() - # Handle rsIDs - if 'rs' in variant: - rs_match = re.findall(r'rs\d+', variant) - processed_variants.update(rs_match) - - # Handle star alleles with additional SNP information - if '*' in variant and '-' in variant and gene: - star_allele, mutation = variant.split('-', 1) - processed_variants.add(f"{gene}{star_allele.strip()}") - processed_variants.add(mutation.strip()) - - # Handle simple star alleles - elif '*' in variant and gene: - processed_variants.add(f"{gene}{variant.strip()}") - - # Handle SNP notations directly - elif '>' in variant: - processed_variants.add(variant.strip()) - - # Add any remaining variants as is - else: - processed_variants.add(variant.strip()) - - return processed_variants - - def match_row(self, row): - """Match ground truth and predicted variants for a single row.""" - truth_variants = self.preprocess_variants(row['Variant/Haplotypes_truth'], gene=row['Gene_truth']) - predicted_variants = self.preprocess_variants(row['variant/haplotypes_output'], gene=row['Gene_output']) - - # Perform matching - if predicted_variants.intersection(truth_variants): - return 'Partial Match' - if predicted_variants == truth_variants: - return 'Exact Match' - return 'No Match' - - -# Manually add suffixes to overlapping columns -df_new_one = df_new_one.rename(columns={ - 'Gene': 'Gene_truth', - 'Drug(s)': 'Drug(s)_truth', - 'Phenotype Category': 'Phenotype Category_truth', - 'Variant/Haplotypes': 'Variant/Haplotypes_truth' -}) - -outputs_first100_one = outputs_first100_one.rename(columns={ - 'gene': 'Gene_output', - 'drug(s)': 'Drug(s)_output', - 'phenotype category': 'Phenotype Category_output', - 'variant/haplotypes': 'variant/haplotypes_output' -}) - -# Merge the datasets based on PMID -df_aligned_pmid = pd.merge( - df_new_one, - outputs_first100_one, - how='inner', - 
on='PMID' -) - -# Initialize SimplifiedVariantMatcher -matcher = SimplifiedVariantMatcher() - -# Add the variant_match column to the aligned dataset -df_aligned_pmid['variant_match'] = df_aligned_pmid.apply(lambda row: matcher.match_row(row), axis=1) - -# Compare key fields: 'Gene', 'Drug(s)', and 'Phenotype Category' -df_aligned_pmid['gene_match'] = df_aligned_pmid['Gene_truth'] == df_aligned_pmid['Gene_output'] -df_aligned_pmid['drug_match'] = df_aligned_pmid['Drug(s)_truth'] == df_aligned_pmid['Drug(s)_output'] -df_aligned_pmid['phenotype_match'] = df_aligned_pmid['Phenotype Category_truth'] == df_aligned_pmid['Phenotype Category_output'] -df_aligned_pmid['significance_match'] = ( - df_aligned_pmid['Significance'] == - df_aligned_pmid['significance'].map({ - 'significant': 'yes', - 'not significant': 'no', - 'not stated': 'not stated' - }) -) -df_aligned_pmid['metabolizer_match'] = df_aligned_pmid['Metabolizer types'] == df_aligned_pmid['metabolizer types'] -df_aligned_pmid['population_match'] = df_aligned_pmid['Specialty Population'] == df_aligned_pmid['specialty population'] - - -# Calculate match statistics for each field -match_stats = { - 'gene_match_rate': df_aligned_pmid['gene_match'].mean() * 100, - 'drug_match_rate': df_aligned_pmid['drug_match'].mean() * 100, - 'phenotype_match_rate': df_aligned_pmid['phenotype_match'].mean() * 100, - 'variant_match_rate': (df_aligned_pmid['variant_match'] == 'Exact Match').mean() * 100, - 'partial_variant_match_rate': (df_aligned_pmid['variant_match'] == 'Partial Match').mean() * 100, - 'exact_match_rate': df_aligned_pmid[['gene_match', 'drug_match', 'phenotype_match']].all(axis=1).mean() * 100, - 'partial_match_rate': df_aligned_pmid[['gene_match', 'drug_match', 'phenotype_match']].any(axis=1).mean() * 100, - 'mismatch_rate': (~df_aligned_pmid[['gene_match', 'drug_match', 'phenotype_match', 'variant_match']].any(axis=1)).mean() * 100, - 'significance_match_rate': df_aligned_pmid['significance_match'].mean() * 100, - 
'metabolizer_match_rate': df_aligned_pmid['metabolizer_match'].mean() * 100, - 'population_match_rate': df_aligned_pmid['population_match'].mean() * 100, -} - -# Display match statistics -print("Match Statistics:") -for key, value in match_stats.items(): - print(f"{key}: {value:.2f}%") - - -# Assuming 'flattened_df' is your DataFrame and 'PMID' is the column name -flattened_df = flattened_df.dropna(subset=['PMID']) - - -# Update column names in comparisons based on the actual merged dataframe -gene_matches = [] -variant_matches = [] -drug_matches = [] -# Grouping annotations per PMID and creating sets for genes, variants, and drugs - -# Function to normalize and split values, handling missing data -def normalize_split(value): - if pd.isna(value): - return set() - return set(map(str.strip, str(value).lower().replace(';', ',').split(','))) - -# Group by PMID and create sets for each component -grouped_outputs = flattened_df.groupby('PMID').agg({ - 'gene': lambda x: set().union(*x.apply(normalize_split)), - 'variant/haplotypes': lambda x: set().union(*x.apply(normalize_split)), - 'drug(s)': lambda x: set().union(*x.apply(normalize_split)) -}).reset_index() - -grouped_drug = df_new.groupby('PMID').agg({ - 'Gene': lambda x: set().union(*x.apply(normalize_split)), - 'Variant/Haplotypes': lambda x: set().union(*x.apply(normalize_split)), - 'Drug(s)': lambda x: set().union(*x.apply(normalize_split)) -}).reset_index() - -# Merge the grouped dataframes on PMID -merged_grouped_df = pd.merge(grouped_outputs, grouped_drug, on='PMID', suffixes=('_output', '_drug'), how='inner') - - -# Compare sets for each PMID -for _, row in merged_grouped_df.iterrows(): - gene_matches.append(len(row['gene'].intersection(row['Gene'])) / len(row['Gene']) if len(row['Gene']) > 0 else 0) - variant_matches.append(len(row['variant/haplotypes'].intersection(row['Variant/Haplotypes'])) / len(row['Variant/Haplotypes']) if len(row['Variant/Haplotypes']) > 0 else 0) - 
drug_matches.append(len(row['drug(s)'].intersection(row['Drug(s)'])) / len(row['Drug(s)']) if len(row['Drug(s)']) > 0 else 0) - -# Calculate the average match rates -average_gene_match_rate = sum(gene_matches) / len(gene_matches) -average_variant_match_rate = sum(variant_matches) / len(variant_matches) -average_drug_match_rate = sum(drug_matches) / len(drug_matches) - -# Display average match rates -(average_gene_match_rate, average_variant_match_rate, average_drug_match_rate) - -average_gene_match_rate, average_variant_match_rate, average_drug_match_rate - -import matplotlib.pyplot as plt -# 1. Bar plot to visualize match rates for Gene, Drug, and Phenotype -match_fields = ['Gene', 'Drug', 'Phenotype','Signficance',"Variant"] - -match_rates = [match_stats['gene_match_rate'], match_stats['drug_match_rate'], match_stats['phenotype_match_rate'],match_stats['significance_match_rate'],match_stats['variant_match_rate']] - -plt.figure(figsize=(8, 6)) -plt.bar(match_fields, match_rates, color = ['#004B8D', '#175E54', '#8C1515', '#F58025','#5D4B3C'] -) -plt.title('Exact Match Rates by Category', fontweight='bold', fontsize=14) - -plt.ylabel('Match Rate (%)',fontweight="bold") -plt.xlabel('Category',fontweight='bold') -plt.ylim(0, 100) # Ensure the y-axis goes from 0 to 100 -plt.tight_layout() -plt.show() - - -# Data for Partial Match (outer pie chart) -sizes_partial = [match_stats['partial_match_rate'], 100 - match_stats['partial_match_rate']] -colors_partial = ['#175E54', 'none'] # Stanford Green for relevant portion - -# Data for Exact Match (inner pie chart) -sizes_exact = [match_stats['exact_match_rate'], 100 - match_stats['exact_match_rate']] -colors_exact = ['#8C1515', 'none'] # Cardinal Red for relevant portion - -# Create the figure -plt.figure(figsize=(8, 8)) - -# Plot the larger pie chart (Partial Match) -plt.pie( - sizes_partial, - colors=colors_partial, - startangle=90, - radius=1.0, - wedgeprops={'linewidth': 1, 'edgecolor': 'white'}, - labels=None # No direct 
labels, handled in legend -) - -# Plot the smaller pie chart (Exact Match) on top -plt.pie( - sizes_exact, - colors=colors_exact, - startangle=90, - radius=0.7, - wedgeprops={'linewidth': 1, 'edgecolor': 'white'}, - labels=None # No direct labels, handled in legend -) - -# Add legend with appropriate colors -legend_labels = [ - f"Partial Match ({round(match_stats['partial_match_rate'],2)}%)", - f"Exact Match ({round(match_stats['exact_match_rate'],3)}%)" -] -legend_colors = ['#175E54', '#8C1515'] # Match the colors used in the chart - -plt.legend( - handles=[ - plt.Line2D([0], [0], color=legend_colors[0], lw=6), # Partial Match color - plt.Line2D([0], [0], color=legend_colors[1], lw=6) # Exact Match color - ], - labels=legend_labels, - loc='upper right', - fontsize=10, - frameon=True, - title="Match Types", - title_fontsize=12 -) - -# Add a title -plt.title('Exact vs Partial Match Rates', fontweight = "bold",fontsize=14, pad=20) - -# Adjust layout for better spacing -plt.tight_layout() -plt.show() - -import matplotlib.pyplot as plt - -# Match rates grouped by PMID (these variables should be defined in your data) -# Example: average_gene_match_rate, average_variant_match_rate, average_drug_match_rate -categories = ['Gene', 'Drug','Variant'] -match_rates = [average_gene_match_rate*100, average_drug_match_rate*100,average_variant_match_rate*100,] - -# Stanford-inspired colors for bars -colors = ['#8C1515', '#175E54', '#F58025'] # Cardinal Red, Stanford Green, Stanford Orange - -# Create the bar chart -plt.figure(figsize=(10, 6)) -bars = plt.bar(categories, match_rates, color=colors, edgecolor='black', linewidth=1.2) - -# Add values on top of bars -for bar, rate in zip(bars, match_rates): - plt.text( - bar.get_x() + bar.get_width() / 2, - bar.get_height() + 1, # Adjust height for the text - f'{rate:.1f}%', # Rounded percentage - ha='center', - fontweight="bold", - fontsize=10, - color='black' - ) - -# Title and labels -plt.title('Match Rates by Category (Grouped by 
PMID)',fontweight="bold", fontsize=14, pad=20) -plt.ylabel('Match Rate (%)',fontweight="bold",fontsize=12) -plt.xlabel('Categories',fontweight="bold", fontsize=12) - -# Adjust ticks for readability -plt.xticks(fontsize=10) -plt.yticks(fontsize=10) - -# Set y-axis limits -plt.ylim(0, 100) # Assuming match rates are percentages - -# Add grid for better readability - - -# Show the plot -plt.tight_layout() -plt.show() - -# Load the dataset (replace the path with the correct file location in Colab) -wholecsv = pd.read_csv('/content/wholecsv.csv') - -# Summarizing match statistics based on the pre-marked columns -match_columns = [ - 'Match metabolizer', - 'Match significance', - 'Match all drug', - 'Match Any Drug', - 'Match gene', - 'Match phenotype', - 'Match population' -] - -# Calculating the percentage of matches for each attribute -match_stats_new = wholecsv[match_columns].mean() * 100 - -# Adjusting the color scheme to be more neutral and less bright -plt.figure(figsize=(10, 6)) -plt.bar(match_stats_new.index, match_stats_new.values, color=['#2E8B57', '#4682B4', '#6A5ACD', '#D2691E', '#556B2F', '#8B4513', '#2F4F4F']) -plt.title('Match Statistics for Different Attributes', fontsize=16) -plt.ylabel('Match Percentage (%)', fontsize=12) -plt.xlabel('Attributes', fontsize=12) -plt.xticks(rotation=45) -plt.ylim(0, 100) -plt.tight_layout() -plt.show() - - -# Creating the table summarizing match statistics for the poster -table_data = match_stats_new.reset_index() -table_data.columns = ['Attribute', 'Match Percentage'] - -# If you want to print the table to the console (optional): -print(table_data) \ No newline at end of file From 9c3c1cc3b2ceeccdc32423001395d9f31256c1a8 Mon Sep 17 00:00:00 2001 From: Sonnet Xu <59452214+sonnetx@users.noreply.github.com> Date: Tue, 27 May 2025 21:48:27 -0700 Subject: [PATCH 08/12] use load_variants module --- src/variant_extraction/run_variant_extraction.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git 
a/src/variant_extraction/run_variant_extraction.py b/src/variant_extraction/run_variant_extraction.py index 67711d6..b9215ae 100644 --- a/src/variant_extraction/run_variant_extraction.py +++ b/src/variant_extraction/run_variant_extraction.py @@ -3,12 +3,12 @@ from openai import OpenAI import os from config import ( - CLINICAL_VARIANTS_URL, VARIANT_ANNOTATIONS_URL, VAR_DRUG_ANN_PATH, - CHECKPOINT_PATH, OUTPUT_CSV_PATH, DF_NEW_CSV_PATH, WHOLE_CSV_PATH, SCHEMA_TEXT + VAR_DRUG_ANN_PATH, CHECKPOINT_PATH, OUTPUT_CSV_PATH, + DF_NEW_CSV_PATH, WHOLE_CSV_PATH, SCHEMA_TEXT ) -from data_download import download_and_extract_zip -from data_processing import load_and_prepare_data, process_dataframe -from api_processing import create_schema, process_responses + +from ..load_variants import download_annotations_pipeline, load_clinical_variants +from processing import load_and_prepare_data, process_dataframe, create_schema, process_responses from variant_matching import align_and_compare_datasets from visualization import plot_match_rates, plot_pie_charts, plot_grouped_match_rates, plot_attribute_match_rates from ncbi_fetch import setup_entrez @@ -18,8 +18,8 @@ def main(): client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) # Download data - download_and_extract_zip(CLINICAL_VARIANTS_URL, "clinicalVariants") - download_and_extract_zip(VARIANT_ANNOTATIONS_URL, "variantAnnotations") + load_clinical_variants() + download_annotations_pipeline() # Load and prepare data df_var_drug_ann, enum_values = load_and_prepare_data(VAR_DRUG_ANN_PATH) From c90bc460da75f160b008b57dae390098d3e2a05b Mon Sep 17 00:00:00 2001 From: Sonnet Xu <59452214+sonnetx@users.noreply.github.com> Date: Wed, 28 May 2025 08:23:24 -0700 Subject: [PATCH 09/12] create readme --- src/variant_extraction/README.md | 211 +++++++++++++++++++++++++++++++ 1 file changed, 211 insertions(+) create mode 100644 src/variant_extraction/README.md diff --git a/src/variant_extraction/README.md b/src/variant_extraction/README.md new 
file mode 100644 index 0000000..19921c7 --- /dev/null +++ b/src/variant_extraction/README.md @@ -0,0 +1,211 @@ +# Variant Extraction Module + +This is organized into the following Python modules, each handling a specific aspect of the workflow: + +config.py: Stores configuration variables, such as URLs, file paths, and API settings. +ncbi_fetch.py: Manages fetching PMCID and content from NCBI using the Entrez API. +processing.py: Loads and processes the variant annotation dataset, including enumeration cleaning and DataFrame processing. Also interacts with the OpenAI API to extract structured genetic variant data from publication content. +variant_matching.py: Compares extracted data with ground truth for accuracy evaluation. +visualization.py: Generates visualizations to summarize match rates and analysis results. +run_variant_extraction.py: Orchestrates the entire workflow, integrating all modules. + +## Methods +Below is a detailed description of the methods implemented in each module. +config.py +This module centralizes configuration settings to avoid hardcoding values in the codebase. + +Variables: +URLs for downloading PharmGKB data (CLINICAL_VARIANTS_URL, VARIANT_ANNOTATIONS_URL). +File paths for input and output data (VAR_DRUG_ANN_PATH, CHECKPOINT_PATH, OUTPUT_CSV_PATH, DF_NEW_CSV_PATH, WHOLE_CSV_PATH). +NCBI Entrez email (ENTREZ_EMAIL) for API compliance. +OpenAI model name (OPENAI_MODEL) and JSON schema (SCHEMA_TEXT) for structured API responses. +System message template (SYSTEM_MESSAGE_TEMPLATE) for API prompts. + + +## ncbi_fetch.py +This module interacts with the NCBI Entrez API to fetch publication metadata and content. + +`setup_entrez()`: + +Configures the Entrez API with the email address specified in config.py to comply with NCBI requirements. + + +`get_pmcid_from_pmid(pmid, retries=3)`: + +Queries the Entrez API to retrieve the PMCID associated with a given PMID. 
+Implements a retry mechanism with exponential backoff (with jitter) to handle transient errors. +Returns the PMCID if found, otherwise None. + + +`fetch_pmc_content(pmcid)`: + +Fetches the full XML content of a publication from the PMC database using the PMCID. +Returns the raw XML content or None if an error occurs. + + +`process_row(row, processed_pmids, processed_data)`: + +Processes a single DataFrame row to fetch PMCID and content. +Introduces a random delay (0.4–0.9 seconds) to avoid API throttling. +Checks for previously processed PMIDs to avoid redundant API calls. +Parses XML content using BeautifulSoup to extract the article title and full text. +Stores results in a dictionary and updates processed_pmids and processed_data for caching. + + + +## data_processing.py +This module handles loading and initial processing of the variant annotation dataset. + +`clean_enum_list(enum_list)`: + +Cleans and normalizes enumeration lists by removing NaN values, splitting comma-separated strings, and ensuring uniqueness. +Used to prepare valid enumeration values for the JSON schema. + + +`load_and_prepare_data(file_path)`: + +Loads the variant annotation TSV file into a pandas DataFrame. +Extracts unique values for Phenotype Category, Significance, Metabolizer types, and Population types to create enumeration lists. +Returns the DataFrame and a dictionary of cleaned enumeration values. + + +`process_dataframe(df, num_rows=None)`: + +Processes a subset of the DataFrame (limited to num_rows if specified) to fetch NCBI data for each row. +Uses tqdm for progress tracking during row processing. +Calls process_row from ncbi_fetch.py to fetch PMCID and content. +Combines the original DataFrame with fetched data and returns the result. + + + +## api_processing.py +This module handles interactions with the OpenAI API to extract structured genetic variant data. + +`create_schema(enum_values)`: + +Creates a JSON schema for API responses based on the provided enumeration values. 
+Defines a structure for an array of gene objects with fields like gene, variant, drug(s), and others, enforcing strict validation. + + +`create_messages(content_text, schema_text, custom_template=None)`: + +Generates API messages with a system prompt (using SYSTEM_MESSAGE_TEMPLATE or a custom template) and user content. +The system prompt instructs the API to extract genetic variant information in the specified schema format. + + +`call_api(client, messages, schema)`: + +Makes an API call to the OpenAI model (gpt-4o-2024-08-06) with the provided messages and schema. +Returns the parsed JSON response containing extracted gene data. + + +`load_checkpoint(checkpoint_path)`: + +Loads previously processed PMIDs and results from a checkpoint file to avoid redundant API calls. +Returns a set of processed PMIDs and a list of results. + + +`save_checkpoint(checkpoint_path, processed_pmids, results)`: + +Saves processed PMIDs and results to a checkpoint file for persistence. +Ensures progress is saved after each processed row to handle interruptions. + + +`process_responses(df, client, schema_text, schema, checkpoint_path, custom_template=None)`: + +Iterates through the DataFrame to process each row’s Content_text using the OpenAI API. +Skips previously processed PMIDs based on checkpoint data. +Saves results and updates the checkpoint after each row to ensure progress is not lost. +Returns a list of flattened JSON objects with extracted gene data and associated PMIDs. + + + +## variant_matching.py +This module compares extracted data with ground truth to evaluate accuracy. + +`SimplifiedVariantMatcher`: + +`split_variants(variant_string)`: +Splits variant strings into individual components, handling delimiters like commas and slashes. + + +`preprocess_variants(variant_string, gene=None)`: +Preprocesses variant strings to handle rsIDs, star alleles, and SNP notations. +Attaches gene names to star alleles and processes complex notations (e.g., CYP2C19*2-1234G>A). 
+ + +`match_row(row)`: +Compares ground truth and predicted variants for a single row. +Returns Exact Match, Partial Match, or No Match based on set intersections. + + + + +`align_and_compare_datasets(df_new, flattened_df)`: + +Renames columns in input DataFrames to distinguish ground truth (_truth) and predicted (_output) data. +Merges DataFrames on PMID using an inner join. +Applies variant matching using SimplifiedVariantMatcher and compares other fields (gene, drug(s), phenotype category, significance, metabolizer types, specialty population). +Returns a DataFrame with match indicators for each field. + + + +## visualization.py +This module generates visualizations to summarize match rates and analysis results. + +`plot_match_rates(match_stats)`: + +Creates a bar plot of exact match rates for Gene, Drug, Phenotype, Significance, and Variant categories. +Uses a professional color scheme and ensures readability with appropriate labels and limits. + + +`plot_pie_charts(match_stats)`: + +Generates nested pie charts showing partial_match_rate (outer) and exact_match_rate (inner). +Includes a legend with percentage values for clarity. + + +`plot_grouped_match_rates(average_gene_match_rate, average_drug_match_rate, average_variant_match_rate)`: + +Plots a bar chart of match rates for Gene, Drug, and Variant categories, calculated by grouping data by PMID. +Adds percentage labels above bars for clarity. + + +`plot_attribute_match_rates(wholecsv)`: + +Creates a bar plot of match percentages for attributes like Match metabolizer, Match significance, etc., from the wholecsv dataset. +Returns a DataFrame summarizing the match statistics for inclusion in reports or posters. + + + +## run_variant_extraction.py +This module orchestrates the entire workflow. + +`main()`: +Initializes the OpenAI client with the API key from the environment. +Downloads and extracts PharmGKB data using data_download.download_and_extract_zip. 
+Loads and prepares the variant annotation dataset using data_processing.load_and_prepare_data. +Processes a subset of the DataFrame (e.g., 5 rows) to fetch NCBI data using data_processing.process_dataframe. +Creates a JSON schema using api_processing.create_schema. +Processes API responses to extract gene data using api_processing.process_responses. +Aligns and compares datasets using variant_matching.align_and_compare_datasets. +Calculates match statistics for various fields and grouped match rates by PMID. +Saves output DataFrames to CSV files (DF_NEW_CSV_PATH, OUTPUT_CSV_PATH). +Generates visualizations using visualization module functions. +Prints match statistics and attribute match table to the console. + + + +## Setup and Running + +### Install Dependencies: +`pip install -r requirements.txt` + + +### Set Environment Variable: + +Set the OPENAI_API_KEY environment variable with your OpenAI API key. + + +## Run the Project: +`python main.py` \ No newline at end of file From 8fdd19d2fe91f21f584011811c944770b4bb6b52 Mon Sep 17 00:00:00 2001 From: Sonnet Xu <59452214+sonnetx@users.noreply.github.com> Date: Wed, 28 May 2025 08:33:16 -0700 Subject: [PATCH 10/12] update readme.md --- src/variant_extraction/README.md | 67 +++++++++++++------------------- 1 file changed, 26 insertions(+), 41 deletions(-) diff --git a/src/variant_extraction/README.md b/src/variant_extraction/README.md index 19921c7..e7ac0b4 100644 --- a/src/variant_extraction/README.md +++ b/src/variant_extraction/README.md @@ -2,24 +2,22 @@ This is organized into the following Python modules, each handling a specific aspect of the workflow: -config.py: Stores configuration variables, such as URLs, file paths, and API settings. -ncbi_fetch.py: Manages fetching PMCID and content from NCBI using the Entrez API. -processing.py: Loads and processes the variant annotation dataset, including enumeration cleaning and DataFrame processing. 
Also interacts with the OpenAI API to extract structured genetic variant data from publication content. -variant_matching.py: Compares extracted data with ground truth for accuracy evaluation. -visualization.py: Generates visualizations to summarize match rates and analysis results. -run_variant_extraction.py: Orchestrates the entire workflow, integrating all modules. - -## Methods -Below is a detailed description of the methods implemented in each module. -config.py +- **config.py**: Stores configuration variables, such as URLs, file paths, and API settings. +- **ncbi_fetch.py**: Manages fetching PMCID and content from NCBI using the Entrez API. +- **processing.py**: Loads and processes the variant annotation dataset, including enumeration cleaning and DataFrame processing. Also interacts with the OpenAI API to extract structured genetic variant data from publication content. +- **variant_matching.py**: Compares extracted data with ground truth for accuracy evaluation. +- **visualization.py**: Generates visualizations to summarize match rates and analysis results. +- **run_variant_extraction.py**: Orchestrates the entire workflow, integrating all modules. + +## config.py This module centralizes configuration settings to avoid hardcoding values in the codebase. -Variables: -URLs for downloading PharmGKB data (CLINICAL_VARIANTS_URL, VARIANT_ANNOTATIONS_URL). -File paths for input and output data (VAR_DRUG_ANN_PATH, CHECKPOINT_PATH, OUTPUT_CSV_PATH, DF_NEW_CSV_PATH, WHOLE_CSV_PATH). -NCBI Entrez email (ENTREZ_EMAIL) for API compliance. -OpenAI model name (OPENAI_MODEL) and JSON schema (SCHEMA_TEXT) for structured API responses. -System message template (SYSTEM_MESSAGE_TEMPLATE) for API prompts. +**Variables**: +- URLs for downloading PharmGKB data (CLINICAL_VARIANTS_URL, VARIANT_ANNOTATIONS_URL). +- File paths for input and output data (VAR_DRUG_ANN_PATH, CHECKPOINT_PATH, OUTPUT_CSV_PATH, DF_NEW_CSV_PATH, WHOLE_CSV_PATH). 
+- NCBI Entrez email (ENTREZ_EMAIL) for API compliance. +- OpenAI model name (OPENAI_MODEL) and JSON schema (SCHEMA_TEXT) for structured API responses. +- System message template (SYSTEM_MESSAGE_TEMPLATE) for API prompts. ## ncbi_fetch.py @@ -78,7 +76,7 @@ Combines the original DataFrame with fetched data and returns the result. -## api_processing.py +## processing.py This module handles interactions with the OpenAI API to extract structured genetic variant data. `create_schema(enum_values)`: @@ -182,30 +180,17 @@ Returns a DataFrame summarizing the match statistics for inclusion in reports or This module orchestrates the entire workflow. `main()`: -Initializes the OpenAI client with the API key from the environment. -Downloads and extracts PharmGKB data using data_download.download_and_extract_zip. -Loads and prepares the variant annotation dataset using data_processing.load_and_prepare_data. -Processes a subset of the DataFrame (e.g., 5 rows) to fetch NCBI data using data_processing.process_dataframe. -Creates a JSON schema using api_processing.create_schema. -Processes API responses to extract gene data using api_processing.process_responses. -Aligns and compares datasets using variant_matching.align_and_compare_datasets. -Calculates match statistics for various fields and grouped match rates by PMID. -Saves output DataFrames to CSV files (DF_NEW_CSV_PATH, OUTPUT_CSV_PATH). -Generates visualizations using visualization module functions. -Prints match statistics and attribute match table to the console. - - - -## Setup and Running - -### Install Dependencies: -`pip install -r requirements.txt` - - -### Set Environment Variable: - -Set the OPENAI_API_KEY environment variable with your OpenAI API key. - +- Initializes the OpenAI client with the API key from the environment. +- Downloads and extracts PharmGKB data using the load_variants module (load_clinical_variants, download_annotations_pipeline). +- Loads and prepares the variant annotation dataset using processing.load_and_prepare_data.
+- Processes a subset of the DataFrame (e.g., 5 rows) to fetch NCBI data using processing.process_dataframe. +- Creates a JSON schema using processing.create_schema. +- Processes API responses to extract gene data using processing.process_responses. +- Aligns and compares datasets using variant_matching.align_and_compare_datasets. +- Calculates match statistics for various fields and grouped match rates by PMID. +- Saves output DataFrames to CSV files (DF_NEW_CSV_PATH, OUTPUT_CSV_PATH). +- Generates visualizations using visualization module functions. +- Prints match statistics and attribute match table to the console. ## Run the Project: `python main.py` \ No newline at end of file From b43d93b21b9794610d88a1701aa2fb22f8e42912 Mon Sep 17 00:00:00 2001 From: Sonnet Xu <59452214+sonnetx@users.noreply.github.com> Date: Wed, 28 May 2025 08:34:59 -0700 Subject: [PATCH 11/12] update readme.md --- src/variant_extraction/README.md | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/src/variant_extraction/README.md b/src/variant_extraction/README.md index e7ac0b4..e1e309c 100644 --- a/src/variant_extraction/README.md +++ b/src/variant_extraction/README.md @@ -49,10 +49,8 @@ Checks for previously processed PMIDs to avoid redundant API calls. Parses XML content using BeautifulSoup to extract the article title and full text. Stores results in a dictionary and updates processed_pmids and processed_data for caching. - - -## data_processing.py -This module handles loading and initial processing of the variant annotation dataset. +## processing.py +This module handles interactions with the OpenAI API to extract structured genetic variant data. @@ -74,11 +72,6 @@ Uses tqdm for progress tracking during row processing. Calls process_row from ncbi_fetch.py to fetch PMCID and content. Combines the original DataFrame with fetched data and returns the result.
- - -## processing.py -This module handles interactions with the OpenAI API to extract structured genetic variant data. - `create_schema(enum_values)`: Creates a JSON schema for API responses based on the provided enumeration values. @@ -137,8 +130,6 @@ Compares ground truth and predicted variants for a single row. Returns Exact Match, Partial Match, or No Match based on set intersections. - - `align_and_compare_datasets(df_new, flattened_df)`: Renames columns in input DataFrames to distinguish ground truth (_truth) and predicted (_output) data. From 7f88ff9fb99444d09099f0d40a47481c588d0dcc Mon Sep 17 00:00:00 2001 From: Sonnet Xu <59452214+sonnetx@users.noreply.github.com> Date: Thu, 29 May 2025 14:00:18 -0700 Subject: [PATCH 12/12] clean and fix --- requirements.txt | 4 +- src/fetch_articles/article_downloader.py | 2 + src/fetch_articles/pmcid_converter.py | 4 -- src/load_variants/__init__.py | 3 +- .../download_annotations_pipeline.py | 4 +- src/variant_extraction/README.md | 41 +----------- src/variant_extraction/config.py | 10 +-- src/variant_extraction/ncbi_fetch.py | 65 ------------------- src/variant_extraction/processing.py | 15 +---- .../run_variant_extraction.py | 32 ++++----- 10 files changed, 29 insertions(+), 151 deletions(-) delete mode 100644 src/variant_extraction/ncbi_fetch.py diff --git a/requirements.txt b/requirements.txt index c6c0df1..8365210 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,6 @@ openai biopython beautifulsoup4 tqdm -matplotlib \ No newline at end of file +matplotlib +loguru +python-dotenv \ No newline at end of file diff --git a/src/fetch_articles/article_downloader.py b/src/fetch_articles/article_downloader.py index 543f01e..ee34c46 100644 --- a/src/fetch_articles/article_downloader.py +++ b/src/fetch_articles/article_downloader.py @@ -1,11 +1,13 @@ from loguru import logger from src.fetch_articles.pmcid_converter import get_unique_pmcids from src.utils.file_paths import get_project_root +from
src.variant_extraction.config import ENTREZ_EMAIL from Bio import Entrez import os import json from tqdm import tqdm +Entrez.email = ENTREZ_EMAIL def fetch_pmc_content(pmcid): """ diff --git a/src/fetch_articles/pmcid_converter.py b/src/fetch_articles/pmcid_converter.py index 417ec57..1f6ae3a 100644 --- a/src/fetch_articles/pmcid_converter.py +++ b/src/fetch_articles/pmcid_converter.py @@ -13,10 +13,6 @@ # Email for NCBI Entrez.email = os.getenv("NCBI_EMAIL") -# Step 1: Function to get PMCID from PMID -import requests -from loguru import logger - import requests import time from loguru import logger diff --git a/src/load_variants/__init__.py b/src/load_variants/__init__.py index 6c56850..1470c6d 100644 --- a/src/load_variants/__init__.py +++ b/src/load_variants/__init__.py @@ -1,5 +1,4 @@ -from .load_clinical_variants import ( +from src.load_variants.load_clinical_variants import ( load_raw_variant_annotations, get_pmid_list, - variant_annotations_pipeline, ) diff --git a/src/load_variants/download_annotations_pipeline.py b/src/load_variants/download_annotations_pipeline.py index 0d9e1fa..9cfeef3 100644 --- a/src/load_variants/download_annotations_pipeline.py +++ b/src/load_variants/download_annotations_pipeline.py @@ -1,5 +1,5 @@ from loguru import logger -from src.load_variants.load_clinical_variants import download_and_extract_variant_annotations, load_raw_annotations, get_pmid_list +from src.load_variants.load_clinical_variants import download_and_extract_variant_annotations, load_raw_variant_annotations, get_pmid_list def variant_annotations_pipeline(): """ @@ -11,7 +11,7 @@ def variant_annotations_pipeline(): # Load the variant annotations logger.info("Loading variant annotations...") - df = load_raw_annotations() + df = load_raw_variant_annotations() # Get the PMIDs logger.info("Getting PMIDs...") diff --git a/src/variant_extraction/README.md b/src/variant_extraction/README.md index e1e309c..c77cb2f 100644 --- a/src/variant_extraction/README.md +++ 
b/src/variant_extraction/README.md @@ -19,36 +19,6 @@ This module centralizes configuration settings to avoid hardcoding values in the - OpenAI model name (OPENAI_MODEL) and JSON schema (SCHEMA_TEXT) for structured API responses. - System message template (SYSTEM_MESSAGE_TEMPLATE) for API prompts. - -## ncbi_fetch.py -This module interacts with the NCBI Entrez API to fetch publication metadata and content. - -`setup_entrez()`: - -Configures the Entrez API with the email address specified in config.py to comply with NCBI requirements. - - -`get_pmcid_from_pmid(pmid, retries=3)`: - -Queries the Entrez API to retrieve the PMCID associated with a given PMID. -Implements a retry mechanism with exponential backoff (with jitter) to handle transient errors. -Returns the PMCID if found, otherwise None. - - -`fetch_pmc_content(pmcid)`: - -Fetches the full XML content of a publication from the PMC database using the PMCID. -Returns the raw XML content or None if an error occurs. - - -`process_row(row, processed_pmids, processed_data)`: - -Processes a single DataFrame row to fetch PMCID and content. -Introduces a random delay (0.4–0.9 seconds) to avoid API throttling. -Checks for previously processed PMIDs to avoid redundant API calls. -Parses XML content using BeautifulSoup to extract the article title and full text. -Stores results in a dictionary and updates processed_pmids and processed_data for caching. - ## processing.py This module handles interactions with the OpenAI API to extract structured genetic variant data. @@ -65,13 +35,6 @@ Extracts unique values for Phenotype Category, Significance, Metabolizer types, Returns the DataFrame and a dictionary of cleaned enumeration values. -`process_dataframe(df, num_rows=None)`: - -Processes a subset of the DataFrame (limited to num_rows if specified) to fetch NCBI data for each row. -Uses tqdm for progress tracking during row processing. -Calls process_row from ncbi_fetch.py to fetch PMCID and content. 
-Combines the original DataFrame with fetched data and returns the result. - `create_schema(enum_values)`: Creates a JSON schema for API responses based on the provided enumeration values. @@ -183,5 +146,5 @@ This module orchestrates the entire workflow. - Generates visualizations using visualization module functions. - Prints match statistics and attribute match table to the console. -## Run the Project: -`python main.py` \ No newline at end of file +## Run the variant extraction: +`python -m src.variant_extraction.run_variant_extraction` \ No newline at end of file diff --git a/src/variant_extraction/config.py b/src/variant_extraction/config.py index f7d46c7..44c538d 100644 --- a/src/variant_extraction/config.py +++ b/src/variant_extraction/config.py @@ -6,11 +6,11 @@ VARIANT_ANNOTATIONS_URL = "https://api.pharmgkb.org/v1/download/file/data/variantAnnotations.zip" # File paths -VAR_DRUG_ANN_PATH = "/content/variantAnnotations/var_drug_ann.tsv" -CHECKPOINT_PATH = "/content/api_processing_checkpoint.json" -OUTPUT_CSV_PATH = "/content/merged_first100.csv" -DF_NEW_CSV_PATH = "/content/df_new.csv" -WHOLE_CSV_PATH = "/content/wholecsv.csv" +VAR_DRUG_ANN_PATH = "./data/variantAnnotations/var_drug_ann.tsv" +CHECKPOINT_PATH = "./data/api_processing_checkpoint.json" +OUTPUT_CSV_PATH = "./data/variant_extraction/merged.csv" +DF_NEW_CSV_PATH = "./data/variant_extraction/df_new.csv" +WHOLE_CSV_PATH = "./data/variant_extraction/wholecsv.csv" # NCBI email ENTREZ_EMAIL = "aron7628@gmail.com" diff --git a/src/variant_extraction/ncbi_fetch.py b/src/variant_extraction/ncbi_fetch.py deleted file mode 100644 index 6183d5c..0000000 --- a/src/variant_extraction/ncbi_fetch.py +++ /dev/null @@ -1,65 +0,0 @@ -# ncbi_fetch.py -from Bio import Entrez -import time -import random -from bs4 import BeautifulSoup -from config import ENTREZ_EMAIL - -def setup_entrez(): - """Configure Entrez with email.""" - Entrez.email = ENTREZ_EMAIL - -def get_pmcid_from_pmid(pmid, retries=3): - """Get PMCID 
from PMID with retry mechanism.""" - for attempt in range(retries): - try: - handle = Entrez.elink(dbfrom="pubmed", db="pmc", id=pmid, linkname="pubmed_pmc") - record = Entrez.read(handle) - handle.close() - if record and 'LinkSetDb' in record[0] and record[0]['LinkSetDb']: - return record[0]['LinkSetDb'][0]['Link'][0]['Id'] - print(f"No PMCID found for PMID {pmid}.") - return None - except Exception as e: - print(f"Error for PMID {pmid} on attempt {attempt + 1}: {e}") - if attempt < retries - 1: - sleep_time = (2 ** attempt) + random.uniform(0, 1) - print(f"Retrying in {sleep_time:.2f} seconds...") - time.sleep(sleep_time) - else: - return None - -def fetch_pmc_content(pmcid): - """Fetch PMC content using PMCID.""" - try: - handle = Entrez.efetch(db="pmc", id=pmcid, rettype="full", retmode="xml") - record = handle.read() - handle.close() - return record - except Exception as e: - print(f"Error fetching content for PMCID {pmcid}: {e}") - return None - -def process_row(row, processed_pmids, processed_data): - """Process a single DataFrame row to fetch PMCID and content.""" - time.sleep(0.4 + random.uniform(0, 0.5)) - pmid = str(row['PMID']) - - if pmid in processed_pmids: - return pd.Series(processed_data[pmid]) - - pmcid = get_pmcid_from_pmid(pmid) - result = {'PMCID': None, 'Title': None, 'Content': None, 'Content_text': None} - - if pmcid: - xml_content = fetch_pmc_content(pmcid) - if xml_content: - soup = BeautifulSoup(xml_content, 'xml') - title_tag = soup.find('article-title') - title = title_tag.get_text() if title_tag else "No Title Found" - clean_text = soup.get_text() - result = {'PMCID': pmcid, 'Title': title, 'Content': xml_content, 'Content_text': clean_text} - - processed_pmids.add(pmid) - processed_data[pmid] = result - return pd.Series(result) \ No newline at end of file diff --git a/src/variant_extraction/processing.py b/src/variant_extraction/processing.py index 48d5384..b81a65a 100644 --- a/src/variant_extraction/processing.py +++ 
b/src/variant_extraction/processing.py @@ -1,11 +1,11 @@ # processing.py +import os import pandas as pd import tqdm -from ncbi_fetch import process_row import json from openai import OpenAI from tqdm import tqdm -from config import SCHEMA_TEXT, SYSTEM_MESSAGE_TEMPLATE, OPENAI_MODEL +from src.variant_extraction.config import SCHEMA_TEXT, SYSTEM_MESSAGE_TEMPLATE, OPENAI_MODEL def clean_enum_list(enum_list): """Clean and normalize enumeration lists.""" @@ -27,16 +27,6 @@ def load_and_prepare_data(file_path): 'specialty_population': specialty_population_enum } -def process_dataframe(df, num_rows=None): - """Process DataFrame rows to fetch NCBI data.""" - if num_rows: - df = df.head(num_rows) - processed_pmids = set() - processed_data = {} - tqdm.pandas(desc="Processing rows") - result_df = df.progress_apply(lambda row: process_row(row, processed_pmids, processed_data), axis=1) - return pd.concat([df, result_df], axis=1) - def create_schema(enum_values): """Create JSON schema for API calls.""" return { @@ -105,6 +95,7 @@ def process_responses(df, client, schema_text, schema, checkpoint_path, custom_t pmid = row['PMID'] if pmid in processed_pmids: continue + print(row) content_text = row['Content_text'] try: messages = create_messages(content_text, schema_text, custom_template) diff --git a/src/variant_extraction/run_variant_extraction.py b/src/variant_extraction/run_variant_extraction.py index b9215ae..e93032d 100644 --- a/src/variant_extraction/run_variant_extraction.py +++ b/src/variant_extraction/run_variant_extraction.py @@ -2,42 +2,33 @@ import pandas as pd from openai import OpenAI import os -from config import ( +import sys + +from src.variant_extraction.config import ( VAR_DRUG_ANN_PATH, CHECKPOINT_PATH, OUTPUT_CSV_PATH, DF_NEW_CSV_PATH, WHOLE_CSV_PATH, SCHEMA_TEXT ) -from ..load_variants import download_annotations_pipeline, load_clinical_variants -from processing import load_and_prepare_data, process_dataframe, create_schema, process_responses -from 
variant_matching import align_and_compare_datasets -from visualization import plot_match_rates, plot_pie_charts, plot_grouped_match_rates, plot_attribute_match_rates -from ncbi_fetch import setup_entrez +from src.variant_extraction.processing import load_and_prepare_data, create_schema, process_responses +from src.variant_extraction.variant_matching import align_and_compare_datasets +from src.variant_extraction.visualization import plot_match_rates, plot_pie_charts, plot_grouped_match_rates, plot_attribute_match_rates def main(): # Initialize OpenAI client client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) - # Download data - load_clinical_variants() - download_annotations_pipeline() - # Load and prepare data df_var_drug_ann, enum_values = load_and_prepare_data(VAR_DRUG_ANN_PATH) - # Process initial DataFrame - test_df = process_dataframe(df_var_drug_ann, num_rows=5) - test_df.dropna(subset=['Content'], inplace=True) - test_df.reset_index(drop=True, inplace=True) - # Create schema schema = create_schema(enum_values) # Process API responses - flattened_results = process_responses(test_df, client, SCHEMA_TEXT, schema, CHECKPOINT_PATH) + flattened_results = process_responses(df_var_drug_ann, client, SCHEMA_TEXT, schema, CHECKPOINT_PATH) flattened_df = pd.DataFrame(flattened_results) # Align and compare datasets - df_aligned = align_and_compare_datasets(test_df, flattened_df) + df_aligned = align_and_compare_datasets(df_var_drug_ann, flattened_df) # Calculate match statistics match_stats = { @@ -66,7 +57,7 @@ def normalize_split(value): 'drug(s)': lambda x: set().union(*x.apply(normalize_split)) }).reset_index() - grouped_drug = test_df.groupby('PMID').agg({ + grouped_drug = df_var_drug_ann.groupby('PMID').agg({ 'Gene': lambda x: set().union(*x.apply(normalize_split)), 'Variant/Haplotypes': lambda x: set().union(*x.apply(normalize_split)), 'Drug(s)': lambda x: set().union(*x.apply(normalize_split)) @@ -86,14 +77,14 @@ def normalize_split(value): 
     average_drug_match_rate = sum(drug_matches) / len(drug_matches)
 
     # Save outputs
-    test_df.to_csv(DF_NEW_CSV_PATH, index=False)
+    df_var_drug_ann.to_csv(DF_NEW_CSV_PATH, index=False)
     flattened_df.to_csv(OUTPUT_CSV_PATH, index=False)
 
     # Visualizations
     plot_match_rates(match_stats)
     plot_pie_charts(match_stats)
     plot_grouped_match_rates(average_gene_match_rate, average_drug_match_rate, average_variant_match_rate)
 
     wholecsv = pd.read_csv(WHOLE_CSV_PATH)
     table_data = plot_attribute_match_rates(wholecsv)
 
@@ -108,5 +99,4 @@ def normalize_split(value):
     print(table_data)
 
 if __name__ == "__main__":
-    setup_entrez()
     main()
\ No newline at end of file