diff --git a/.gitignore b/.gitignore index 762e891..a27cd78 100644 --- a/.gitignore +++ b/.gitignore @@ -174,9 +174,9 @@ cython_debug/ .abstra/ # Visual Studio Code -# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore +# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore -# and can be added to the global gitignore or merged into this file. However, if you prefer, +# and can be added to the global gitignore or merged into this file. However, if you prefer, # you could uncomment the following to ignore the enitre vscode folder # .vscode/ @@ -193,4 +193,21 @@ cython_debug/ .cursorignore .cursorindexingignore -.DS_Store.DS_Store +# macOS +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +.AppleDouble +.LSOverride +Icon? +.DocumentRevisions-V100 +.fseventsd +.TemporaryItems +.VolumeIcon.icns +.com.apple.timemachine.donotpresent + +**/data/** +**/logs/** +results_** \ No newline at end of file diff --git a/README.md b/README.md index 44664b4..283a31c 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,292 @@ -# ent-llm +
-LLM evaluation of ENT clinical cases
+<img src="docs/logo.png" alt="ENT-LLM logo" width="200"/>
+
+
+# LLM evaluation of ENT clinical cases for surgical recommendation
+
+
+
## Overview -`ent-llm` is a LLM project evaluating otolaryngology clinical cases. The goal is to assist clinicians and researchers in analyzing patient scenarios, generating differential diagnoses, and evaluating treatment options with AI-powered tools. +`ent-llm` evaluates otolaryngology (ENT) clinical cases using Large Language Models. It processes chronic sinusitis patient data from Stanford's medical records and generates surgical recommendations with confidence scores. + +## Installation + +### Create Virtual Environment + +```bash +python -m venv .venv +source .venv/bin/activate +``` + +### Install Dependencies + +```bash +pip install -e . +``` + +**Required environment variables:** + +```bash +export GOOGLE_APPLICATION_CREDENTIALS="/path/to/gcp_credentials.json" # BigQuery access +export VAULT_SECRET_KEY="your_private_key" # SecureLLM API access +``` + +## Quick Start + +### Full Pipeline + +```bash +# Step 1: Extract data from BigQuery +ent-llm-extract --output cases.csv + +# Step 2: Run LLM analysis +ent-llm --model apim:gpt-4.1 --input cases.csv --output results.csv +``` + +### Testing with Limited Data + +```bash +# Extract only 100 patients for testing +python cli_extract.py --output test_cases.csv --limit 100 + +# Run analysis +python cli.py --model apim:claude-3.7 --input test_cases.csv --output test_results.csv +``` + +## CLI Reference + +### `ent-llm-extract` - Data Extraction + +Extracts and preprocesses clinical data from BigQuery. + +```bash +ent-llm-extract [OPTIONS] +``` + +| Option | Short | Description | +|--------|-------|-------------| +| `--output` | `-o` | Output CSV file (default: `llm_cases.csv`) | +| `--batch-size` | `-b` | Patients per batch (default: 100) | +| `--limit` | `-l` | Max patients to process (default: all) | +| `--save-processed` | | Also save full processed dataframe | +| `--processed-output` | | Path for processed data CSV | +| `--checkpoint-dir` | | Directory for checkpoint files | +| `--count-only` | | Show patient count and exit | +| `--verbose` | `-v` | Enable verbose logging | + +**Examples:** + +```bash +# Count total patients +ent-llm-extract --count-only + +# Extract all data +ent-llm-extract --output cases.csv + +# Extract with checkpoints (recommended for large datasets) +ent-llm-extract --output cases.csv --checkpoint-dir ./checkpoints + +# Extract both LLM-ready and full processed data +ent-llm-extract --output cases.csv --save-processed --processed-output full_data.csv +``` + +### `ent-llm` - LLM Analysis + +Runs surgical recommendation analysis using various LLM backends. 
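+
+Results are written to the output CSV incrementally; an interrupted run resumes from the existing output by default (pass `--no-resume` to start fresh, `--flush-interval` to change how often results are flushed, or `--start-row` to skip ahead in the input).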
+ +```bash +ent-llm [OPTIONS] +``` + +| Option | Short | Description | +|--------|-------|-------------| +| `--model` | `-m` | LLM model to use (default: `apim:gpt-4.1`) | +| `--input` | `-i` | Input CSV file with case data | +| `--output` | `-o` | Output CSV file for results | +| `--delay` | `-d` | Delay between API calls (default: 0.2s) | +| `--interactive` | `-I` | Interactive query mode | +| `--list-models` | `-l` | List available models and exit | +| `--verbose` | `-v` | Enable verbose logging | + +**Available models:** + +- `apim:gpt-4.1` +- `apim:claude-3.7` +- `apim:llama-3.3-70b` +- `apim:gemini-2.5-pro-preview-05-06` + +**Examples:** + +```bash +# List available models +ent-llm --list-models + +# Run analysis with specific model +ent-llm --model apim:claude-3.7 --input cases.csv --output results.csv + +# Interactive query mode +ent-llm --model apim:gpt-4.1 --interactive + +# Demo mode (no input file) +ent-llm --model apim:gpt-4.1 +``` + +### `ent-llm-ablation` - Demographic Ablation Analysis + +Measures how demographic variables influence LLM surgical recommendations by selectively excluding demographics from prompts. + +```bash +ent-llm-ablation [OPTIONS] +``` + +| Option | Short | Description | +|--------|-------|-------------| +| `--model` | `-m` | LLM model to use (default: `apim:gpt-4.1`) | +| `--input` | `-i` | Input CSV file (clinical text + demographics) | +| `--output-dir` | `-o` | Output directory for result CSVs (default: `./ablation_results`) | +| `--baseline` | `-b` | Path to pre-computed baseline CSV (skip baseline run) | +| `--experiments` | `-e` | Which to run: `all`, `individual`, `grouped`, `baseline-only` | +| `--sample-size` | `-n` | Stratified sample size | +| `--max-tokens` | | Filter out cases exceeding estimated token count | +| `--ground-truth` | `-g` | Ground truth column name (default: `had_surgery`) | +| `--delay` | `-d` | Delay between API calls (default: 0.2s) | +| `--flush-interval` | `-f` | Incremental save interval (default: 10) | +| `--no-resume` | | Start fresh instead of resuming | +| `--list-experiments` | | List all experiments and exit | +| `--verbose` | `-v` | Enable verbose logging | + +**Input CSV** requires the same clinical columns as `ent-llm` plus demographic columns: `legal_sex`, `age`, `race`, `ethnicity`, `recent_bmi`, `smoking_hx`, `alcohol_use`, `zipcode`, `insurance_type`, `occupation`. Optionally includes a ground truth column (e.g. `had_surgery`) for accuracy analysis. + +**Experiments** (16 total): +- **Baseline** — all demographics included +- **10 individual ablations** — exclude one variable at a time (`no_legal_sex`, `no_age`, etc.) +- **5 grouped ablations** — exclude variable groups (`no_protected_attributes`, `no_socioeconomic`, `no_health_behaviors`, `no_physical_attributes`, `no_all_demographics`) + +**Examples:** + +```bash +# List all experiments +ent-llm-ablation --list-experiments + +# Run full ablation on a stratified sample of 500 cases +ent-llm-ablation -m apim:gpt-4.1 -i cases_with_demographics.csv -n 500 + +# Filter long cases and run only individual ablations +ent-llm-ablation -m apim:claude-3.7 -i data.csv --max-tokens 5000 -e individual + +# Resume with a pre-computed baseline +ent-llm-ablation -m apim:gpt-4.1 -i data.csv -b ./ablation_results/baseline_results.csv +``` + +**Output:** Each experiment saves to `{output_dir}/{experiment_name}_results.csv`. 
A summary comparing all experiments to baseline is saved to `{output_dir}/ablation_summary.csv` with flip rates, confidence changes, and (if ground truth provided) accuracy metrics. + +## Data Pipeline + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ DATA EXTRACTION │ +│ (ent-llm-extract CLI) │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ BigQuery (Stanford STARR) │ +│ │ │ +│ ├── clinical_note → Filter by ENT authors │ +│ ├── radiology_report → Filter CT sinus reports │ +│ └── procedures → Extract surgery CPT codes │ +│ │ │ +│ ▼ │ +│ Build patient records │ +│ │ │ +│ ▼ │ +│ Censor surgical planning text │ +│ │ │ +│ ▼ │ +│ Format for LLM input → cases.csv │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────────────────┐ +│ LLM ANALYSIS │ +│ (ent-llm CLI) │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ cases.csv │ +│ │ │ +│ ▼ │ +│ SecureLLM API (GPT-4, Claude, Llama, Gemini) │ +│ │ │ +│ ▼ │ +│ Parse JSON responses │ +│ │ │ +│ ▼ │ +│ results.csv (decision, confidence, reasoning) │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +## Data Source + +**Google BigQuery - Stanford STARR** + +| Setting | Value | +|---------|-------| +| Project | `som-nero-phi-roxanad-entllm` | +| Datasets | Chronic sinusitis cohorts (2016-2025) | + +**Tables:** + +| Table | Description | +|-------|-------------| +| `clinical_note` | ENT clinical notes (progress notes, consults, H&P) | +| `radiology_report` | CT sinus scan reports | +| `procedures` | CPT codes for surgeries/endoscopies | + +## Input/Output Formats + +### Input CSV (from extraction) + +| Column | Description | +|--------|-------------| +| `llm_caseID` | Unique case identifier | +| `formatted_progress_text` | Concatenated ENT clinical notes | +| `formatted_radiology_text` | Concatenated radiology reports | + +### Output CSV (from analysis) + +| Column | Description | +|--------|-------------| +| `llm_caseID` | Case identifier | +| `decision` | `Yes` or `No` for surgery recommendation | +| `confidence` | 1-10 confidence score | +| `reasoning` | 2-4 sentence explanation | +| `api_response` | Raw LLM response | + +## Project Structure + +``` +ent-llm/ +├── cli.py # LLM analysis CLI +├── cli_extract.py # Data extraction CLI +├── cli_ablation.py # Demographic ablation CLI +├── data_extraction/ # BigQuery data processing +│ ├── config.py # Project settings, CPT codes +│ ├── raw_data_parsing.py # Data extraction functions +│ └── note_extraction.py # Note filtering and censoring +├── llm_query/ # LLM integration +│ ├── securellm_adapter.py # SecureLLM client wrapper +│ ├── LLM_analysis.py # Analysis pipeline +│ ├── ablation_analysis.py # Ablation experiment logic +│ └── llm_input.py # Data formatting +├── batch_query/ # Batch processing +├── evaluation/ # Results evaluation +└── training/ # Training workflows +``` -## Features +## License -- Input structured or free-text ENT case data -- Query and evaluate cases using state-of-the-art LLMs -- Generate clinical summaries and differential diagnoses -- Analyze diagnosis and surgical intervention accuracy +MIT License - See LICENSE file for details. 
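+
+## Example: Inspecting Results
+
+A minimal sketch for sanity-checking a finished run with pandas, assuming a `results.csv` with the output columns listed above (illustrative only, not part of the shipped CLIs):
+
+```python
+import pandas as pd
+
+results = pd.read_csv("results.csv")
+
+# Distribution of surgical recommendations
+print(results["decision"].value_counts(dropna=False))
+
+# Mean confidence (1-10 scale) per decision
+print(results.groupby("decision")["confidence"].mean())
+
+# Cases where no decision could be parsed from the raw response
+unparsed = results[results["decision"].isna()]
+print(f"{len(unparsed)} unparsed cases: {unparsed['llm_caseID'].tolist()}")
+```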
diff --git a/batch_query/batch_query.py b/batch_query/batch_query.py index aa8a9c8..c939769 100644 --- a/batch_query/batch_query.py +++ b/batch_query/batch_query.py @@ -164,18 +164,13 @@ def quick_status_check(final_llm_df: pd.DataFrame, save_directory: str = '.'): import asyncio import aiohttp -import openai from concurrent.futures import ThreadPoolExecutor, as_completed import time -import asyncio -import aiohttp -import openai -from concurrent.futures import ThreadPoolExecutor, as_completed -import time +from llm_query.securellm_adapter import query_llm, SecureLLMClient def parallel_process_llm_cases(llm_df: pd.DataFrame, - api_key: str, + api_key: str = None, max_workers: int = 5, delay_seconds: float = 0.1) -> pd.DataFrame: """ @@ -183,7 +178,7 @@ def parallel_process_llm_cases(llm_df: pd.DataFrame, Args: llm_df: DataFrame with cases to process - api_key: OpenAI API key + api_key: Deprecated. Kept for backward compatibility. SecureLLM uses VAULT_SECRET_KEY. max_workers: Number of parallel workers (start with 5) delay_seconds: Delay between requests (can be smaller with parallel) """ @@ -193,8 +188,6 @@ def process_single_case(row_data): idx, row = row_data try: - client = openai.OpenAI(api_key=api_key) - case_id = row['llm_caseID'] # Generate prompt @@ -204,8 +197,12 @@ def process_single_case(row_data): radiology_text=row['formatted_radiology_text'] ) - # Query OpenAI - response = query_openai(prompt, client) + # Query SecureLLM + response = query_llm( + prompt=prompt, + system_message="You are an expert otolaryngologist. Provide a surgical recommendation in the requested JSON format.", + temperature=0.2 + ) result = { 'index': idx, @@ -275,11 +272,17 @@ def process_single_case(row_data): return result_df def fast_batch_processing(final_llm_df: pd.DataFrame, - api_key: str, + api_key: str = None, batch_size: int = 200, max_workers: int = 5) -> pd.DataFrame: """ Fast batch processing with parallel execution. + + Args: + final_llm_df: DataFrame with cases to process + api_key: Deprecated. Kept for backward compatibility. SecureLLM uses VAULT_SECRET_KEY. + batch_size: Number of cases per batch + max_workers: Number of parallel workers """ # Load existing results diff --git a/cli.py b/cli.py new file mode 100644 index 0000000..5f2cfbe --- /dev/null +++ b/cli.py @@ -0,0 +1,340 @@ +#!/usr/bin/env python3 +# This source file is part of the ARPA-H CARE LLM project +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT +# + +""" +CLI entrypoint for ENT-LLM analysis. + +This module provides a command-line interface to run the LLM analysis +with different model backends. 
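+
+Typical usage:
+    python cli.py --model apim:gpt-4.1 --input cases.csv --output results.csv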
+""" + +import argparse +import logging +import sys +from typing import Optional + +import pandas as pd + +from llm_query.securellm_adapter import ModelConfig, query_llm, SecureLLMClient +from llm_query.LLM_analysis import ( + generate_prompt, + parse_llm_response, + process_llm_cases, + run_llm_analysis, +) + +# Available LLM models +AVAILABLE_MODELS = [ + "apim:llama-3.3-70b", + "apim:claude-3.7", + "apim:gpt-4.1", + "apim:gemini-2.5-pro-preview-05-06", +] + + +def setup_logging(verbose: bool = False) -> None: + """Configure logging based on verbosity level.""" + level = logging.DEBUG if verbose else logging.INFO + logging.basicConfig( + level=level, + format="%(asctime)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler(sys.stdout)], + ) + + +def run_single_query(model: str, prompt: str) -> Optional[str]: + """ + Run a single LLM query with the specified model. + + Args: + model: The model identifier to use. + prompt: The prompt to send to the LLM. + + Returns: + The LLM response or None on error. + """ + client = SecureLLMClient(model_name=model) + return client.query(prompt) + + +def run_analysis_with_model( + model: str, + input_file: Optional[str] = None, + output_file: Optional[str] = None, + delay_seconds: float = 0.2, + flush_interval: int = 10, + no_resume: bool = False, + start_row: int = 0, +) -> pd.DataFrame: + """ + Run the full LLM analysis pipeline with a specified model. + + Args: + model: The model identifier to use. + input_file: Path to input CSV file with case data. + output_file: Path to save results CSV (saved incrementally). + delay_seconds: Delay between API calls. + flush_interval: Number of cases to process before flushing to disk. + no_resume: If True, start fresh instead of resuming from existing output. + start_row: Start processing from this row index (0-based). + + Returns: + DataFrame with analysis results. + """ + logger = logging.getLogger(__name__) + + # Update the default model configuration + ModelConfig.DEFAULT_LLM_MODEL = model + logger.info(f"Using model: {model}") + + if input_file: + logger.info(f"Loading data from: {input_file}") + llm_df = pd.read_csv(input_file) + + # Validate required columns + required_cols = ["llm_caseID", "formatted_progress_text", "formatted_radiology_text"] + missing_cols = [col for col in required_cols if col not in llm_df.columns] + if missing_cols: + raise ValueError(f"Input file missing required columns: {missing_cols}") + + # Apply start_row filter + if start_row > 0: + if start_row >= len(llm_df): + raise ValueError(f"start_row ({start_row}) is >= total rows ({len(llm_df)})") + logger.info(f"Starting from row {start_row} (skipping first {start_row} rows)") + llm_df = llm_df.iloc[start_row:].reset_index(drop=True) + + # Run analysis with incremental saving + results_df = run_llm_analysis( + llm_df, + output_file=output_file, + flush_interval=flush_interval, + resume=not no_resume + ) + + if output_file: + logger.info(f"Results saved to: {output_file}") + + return results_df + else: + logger.warning("No input file provided. Running in demo mode.") + # Demo mode: create a simple test case + demo_data = { + "llm_caseID": ["DEMO_001"], + "formatted_progress_text": [ + "Patient presents with chronic rhinosinusitis refractory to medical management " + "including multiple courses of antibiotics and intranasal corticosteroids. " + "Symptoms include persistent nasal congestion, facial pressure, and purulent discharge " + "for over 12 weeks. Previous treatments have failed to provide lasting relief." 
+ ], + "formatted_radiology_text": [ + "CT Sinuses: Mucosal thickening in bilateral maxillary sinuses with partial " + "opacification. Ostiomeatal complex obstruction bilaterally. No bony erosion." + ], + } + demo_df = pd.DataFrame(demo_data) + + logger.info("Processing demo case...") + results_df = run_llm_analysis(demo_df) + + print("\n=== Demo Results ===") + for _, row in results_df.iterrows(): + print(f"Case ID: {row['llm_caseID']}") + print(f"Decision: {row['decision']}") + print(f"Confidence: {row['confidence']}") + print(f"Reasoning: {row['reasoning']}") + + return results_df + + +def interactive_query(model: str) -> None: + """ + Run an interactive query session with the specified model. + + Args: + model: The model identifier to use. + """ + logger = logging.getLogger(__name__) + logger.info(f"Starting interactive session with model: {model}") + print(f"\nInteractive query mode with {model}") + print("Type 'quit' or 'exit' to end the session.\n") + + client = SecureLLMClient(model_name=model) + + while True: + try: + prompt = input("You: ").strip() + if prompt.lower() in ("quit", "exit"): + print("Goodbye!") + break + if not prompt: + continue + + response = client.query( + prompt, + system_message="You are an expert otolaryngologist. Answer questions about ENT cases.", + ) + if response: + print(f"\nAssistant: {response}\n") + else: + print("\n[No response received]\n") + except KeyboardInterrupt: + print("\nGoodbye!") + break + except Exception as e: + logger.error(f"Error: {e}") + print(f"\n[Error: {e}]\n") + + +def main() -> int: + """Main CLI entrypoint.""" + parser = argparse.ArgumentParser( + description="ENT-LLM Analysis CLI - Run clinical case analysis with various LLM backends", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Run analysis with default model (demo mode) + python -m cli --model apim:gpt-4.1 + + # Run analysis on a CSV file (saves incrementally every 10 cases) + python -m cli --model apim:claude-3.7 --input cases.csv --output results.csv + + # Run with custom flush interval (save every 5 cases) + python -m cli --model apim:gpt-4.1 -i cases.csv -o results.csv --flush-interval 5 + + # Start from a specific row (e.g., skip first 100 rows) + python -m cli --model apim:gpt-4.1 -i cases.csv -o results.csv --start-row 100 + + # Start fresh (don't resume from existing output) + python -m cli --model apim:gpt-4.1 -i cases.csv -o results.csv --no-resume + + # Interactive query mode + python -m cli --model apim:llama-3.3-70b --interactive + + # List available models + python -m cli --list-models + """, + ) + + parser.add_argument( + "--model", + "-m", + type=str, + choices=AVAILABLE_MODELS, + default="apim:gpt-4.1", + help="LLM model to use for analysis (default: apim:gpt-4.1)", + ) + + parser.add_argument( + "--input", + "-i", + type=str, + help="Input CSV file with case data (columns: llm_caseID, formatted_progress_text, formatted_radiology_text)", + ) + + parser.add_argument( + "--output", + "-o", + type=str, + help="Output CSV file for results", + ) + + parser.add_argument( + "--delay", + "-d", + type=float, + default=0.2, + help="Delay in seconds between API calls (default: 0.2)", + ) + + parser.add_argument( + "--flush-interval", + "-f", + type=int, + default=10, + help="Number of cases to process before flushing to disk (default: 10)", + ) + + parser.add_argument( + "--no-resume", + action="store_true", + help="Start fresh instead of resuming from existing output file", + ) + + parser.add_argument( + "--start-row", + "-s", + 
type=int, + default=0, + help="Start processing from this row index (0-based, default: 0)", + ) + + parser.add_argument( + "--interactive", + "-I", + action="store_true", + help="Run in interactive query mode", + ) + + parser.add_argument( + "--list-models", + "-l", + action="store_true", + help="List available models and exit", + ) + + parser.add_argument( + "--verbose", + "-v", + action="store_true", + help="Enable verbose logging", + ) + + args = parser.parse_args() + + # Handle --list-models + if args.list_models: + print("Available models:") + for model in AVAILABLE_MODELS: + print(f" - {model}") + return 0 + + # Setup logging + setup_logging(args.verbose) + logger = logging.getLogger(__name__) + + logger.info(f"ENT-LLM Analysis CLI") + logger.info(f"Selected model: {args.model}") + + try: + if args.interactive: + interactive_query(args.model) + else: + run_analysis_with_model( + model=args.model, + input_file=args.input, + output_file=args.output, + delay_seconds=args.delay, + flush_interval=args.flush_interval, + no_resume=args.no_resume, + start_row=args.start_row, + ) + return 0 + except KeyboardInterrupt: + print("\nInterrupted by user") + return 130 + except Exception as e: + logger.error(f"Error: {e}") + if args.verbose: + import traceback + traceback.print_exc() + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/cli_ablation.py b/cli_ablation.py new file mode 100644 index 0000000..e3082bf --- /dev/null +++ b/cli_ablation.py @@ -0,0 +1,353 @@ +#!/usr/bin/env python3 +""" +CLI entrypoint for ENT-LLM demographic ablation analysis. + +Runs ablation experiments that selectively exclude demographic variables +from LLM prompts to measure their influence on surgical recommendations. +""" + +import argparse +import logging +import os +import shutil +import sys + +import pandas as pd + +from llm_query.ablation_analysis import ( + DEMOGRAPHIC_GROUPS, + DEMOGRAPHIC_VARS, + analyze_ablation_results, + filter_long_cases, + run_ablation_experiment, + stratified_sample, +) +from llm_query.securellm_adapter import ModelConfig + +# Available LLM models (same as cli.py) +AVAILABLE_MODELS = [ + "apim:llama-3.3-70b", + "apim:claude-3.7", + "apim:gpt-4.1", + "apim:gemini-2.5-pro-preview-05-06", +] + + +def setup_logging(verbose: bool = False) -> None: + """Configure logging based on verbosity level.""" + level = logging.DEBUG if verbose else logging.INFO + logging.basicConfig( + level=level, + format="%(asctime)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler(sys.stdout)], + ) + + +def list_experiments() -> None: + """List all ablation experiments and exit.""" + print("Ablation experiments:\n") + print("Baseline:") + print(" baseline — all demographics included\n") + print("Individual ablations (exclude one variable):") + for var in DEMOGRAPHIC_VARS: + print(f" no_{var}") + print("\nGrouped ablations (exclude a group):") + for group_name, group_vars in DEMOGRAPHIC_GROUPS.items(): + print(f" no_{group_name} — excludes: {', '.join(group_vars)}") + + total = 1 + len(DEMOGRAPHIC_VARS) + len(DEMOGRAPHIC_GROUPS) + print(f"\nTotal experiments: {total}") + + +def validate_input(df: pd.DataFrame) -> None: + """Validate required columns exist in input DataFrame.""" + clinical_cols = ["llm_caseID", "formatted_progress_text", "formatted_radiology_text"] + missing_clinical = [c for c in clinical_cols if c not in df.columns] + if missing_clinical: + raise ValueError(f"Input CSV missing required clinical columns: {missing_clinical}") + + demographic_present = [v for v in 
DEMOGRAPHIC_VARS if v in df.columns] + if not demographic_present: + raise ValueError( + f"Input CSV has no demographic columns. Expected at least some of: {DEMOGRAPHIC_VARS}" + ) + + missing_demo = [v for v in DEMOGRAPHIC_VARS if v not in df.columns] + if missing_demo: + logging.getLogger(__name__).warning( + f"Input CSV missing some demographic columns (will be skipped): {missing_demo}" + ) + + +def build_experiment_list(experiments_mode: str): + """Build list of (experiment_name, exclude_vars) tuples. + + Args: + experiments_mode: One of 'all', 'individual', 'grouped', 'baseline-only'. + + Returns: + List of (experiment_name, exclude_vars) tuples. + """ + experiments = [] + + if experiments_mode in ("all", "baseline-only"): + experiments.append(("baseline", None)) + + if experiments_mode in ("all", "individual"): + for var in DEMOGRAPHIC_VARS: + experiments.append((f"no_{var}", [var])) + + if experiments_mode in ("all", "grouped"): + for group_name, group_vars in DEMOGRAPHIC_GROUPS.items(): + experiments.append((f"no_{group_name}", list(group_vars))) + + return experiments + + +def print_summary_table(summary_df: pd.DataFrame) -> None: + """Print formatted summary table to stdout.""" + if summary_df.empty: + print("No ablation results to summarize.") + return + + print(f"\n{'='*80}") + print("ABLATION ANALYSIS SUMMARY") + print(f"{'='*80}") + + cols = [ + "experiment", + "experiment_type", + "flip_rate_%", + "yes_to_no", + "no_to_yes", + "avg_confidence_change", + ] + if "accuracy_change_%" in summary_df.columns: + cols.append("accuracy_change_%") + + display_cols = [c for c in cols if c in summary_df.columns] + print(summary_df[display_cols].to_string(index=False)) + print(f"{'='*80}") + + +def main() -> int: + """Main CLI entrypoint.""" + parser = argparse.ArgumentParser( + description="ENT-LLM Ablation Analysis CLI - demographic ablation study", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # List all experiments + python cli_ablation.py --list-experiments + + # Run on small sample + python cli_ablation.py -m apim:gpt-4.1 -i data.csv -o ./ablation_test -n 10 + + # Run with pre-computed baseline + python cli_ablation.py -m apim:gpt-4.1 -i data.csv -o ./ablation_results \\ + -b ./ablation_results/baseline_results.csv + + # Run only individual ablations + python cli_ablation.py -m apim:gpt-4.1 -i data.csv -e individual + + # Run only baseline + python cli_ablation.py -m apim:gpt-4.1 -i data.csv -e baseline-only + """, + ) + + parser.add_argument( + "--model", + "-m", + type=str, + choices=AVAILABLE_MODELS, + default="apim:gpt-4.1", + help="LLM model to use (default: apim:gpt-4.1)", + ) + + parser.add_argument( + "--input", + "-i", + type=str, + help="Input CSV file (clinical text + demographics)", + ) + + parser.add_argument( + "--output-dir", + "-o", + type=str, + default="./ablation_results", + help="Output directory for result CSVs (default: ./ablation_results)", + ) + + parser.add_argument( + "--baseline", + "-b", + type=str, + help="Path to pre-computed baseline CSV (skip baseline run)", + ) + + parser.add_argument( + "--experiments", + "-e", + type=str, + choices=["all", "individual", "grouped", "baseline-only"], + default="all", + help="Which experiments to run (default: all)", + ) + + parser.add_argument( + "--sample-size", + "-n", + type=int, + help="Optional sample size (stratified sampling)", + ) + + parser.add_argument( + "--max-tokens", + type=int, + help="Filter out cases exceeding this estimated token count (e.g. 
5000)", + ) + + parser.add_argument( + "--ground-truth", + "-g", + type=str, + default="had_surgery", + help="Column name for ground truth (default: had_surgery)", + ) + + parser.add_argument( + "--delay", + "-d", + type=float, + default=0.2, + help="Delay between API calls in seconds (default: 0.2)", + ) + + parser.add_argument( + "--flush-interval", + "-f", + type=int, + default=10, + help="Flush interval for incremental saving (default: 10)", + ) + + parser.add_argument( + "--no-resume", + action="store_true", + help="Start fresh instead of resuming from existing output", + ) + + parser.add_argument( + "--list-experiments", + action="store_true", + help="List all experiments and exit", + ) + + parser.add_argument( + "--verbose", + "-v", + action="store_true", + help="Enable verbose logging", + ) + + args = parser.parse_args() + + # Handle --list-experiments + if args.list_experiments: + list_experiments() + return 0 + + # --input is required for all other modes + if not args.input: + parser.error("--input / -i is required (use --list-experiments to see experiments)") + + # Setup logging + setup_logging(args.verbose) + logger = logging.getLogger(__name__) + + # Set model + ModelConfig.DEFAULT_LLM_MODEL = args.model + logger.info(f"ENT-LLM Ablation Analysis CLI") + logger.info(f"Model: {args.model}") + + # Load input + logger.info(f"Loading input: {args.input}") + df = pd.read_csv(args.input) + validate_input(df) + logger.info(f"Loaded {len(df)} cases") + + # Optional token-length filter + if args.max_tokens: + before = len(df) + df = filter_long_cases(df, max_tokens=args.max_tokens) + logger.info(f"Token filter ({args.max_tokens}): {before} -> {len(df)} cases") + + # Stratified sample + if args.sample_size: + logger.info(f"Creating stratified sample of {args.sample_size} cases") + df = stratified_sample(df, args.sample_size) + logger.info(f"Sample size after stratification: {len(df)}") + + # Create output directory + os.makedirs(args.output_dir, exist_ok=True) + + # Build experiment list + experiments = build_experiment_list(args.experiments) + logger.info(f"Planned experiments: {len(experiments)}") + + # Handle pre-computed baseline + if args.baseline: + if not os.path.exists(args.baseline): + logger.error(f"Baseline file not found: {args.baseline}") + return 1 + baseline_dest = os.path.join(args.output_dir, "baseline_results.csv") + if os.path.abspath(args.baseline) != os.path.abspath(baseline_dest): + shutil.copy2(args.baseline, baseline_dest) + logger.info(f"Copied baseline to {baseline_dest}") + # Remove baseline from experiments list since it's pre-computed + experiments = [(name, evars) for name, evars in experiments if name != "baseline"] + logger.info(f"Using pre-computed baseline, {len(experiments)} experiments remaining") + + # Run experiments + try: + for exp_name, exclude_vars in experiments: + output_file = os.path.join(args.output_dir, f"{exp_name}_results.csv") + logger.info(f"Starting experiment: {exp_name}") + run_ablation_experiment( + df=df, + exclude_vars=exclude_vars, + experiment_name=exp_name, + output_file=output_file, + model_name=args.model, + delay_seconds=args.delay, + flush_interval=args.flush_interval, + resume=not args.no_resume, + ) + except KeyboardInterrupt: + print("\nInterrupted by user. 
Progress has been saved incrementally.") + return 130 + + # Analyze results + logger.info("Analyzing results...") + ground_truth_df = None + if args.ground_truth in df.columns: + ground_truth_df = df + + summary_df = analyze_ablation_results( + results_dir=args.output_dir, + ground_truth_df=ground_truth_df, + ground_truth_col=args.ground_truth, + ) + + summary_path = os.path.join(args.output_dir, "ablation_summary.csv") + summary_df.to_csv(summary_path, index=False) + logger.info(f"Summary saved to {summary_path}") + + print_summary_table(summary_df) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/cli_extract.py b/cli_extract.py new file mode 100644 index 0000000..66a73a4 --- /dev/null +++ b/cli_extract.py @@ -0,0 +1,578 @@ +#!/usr/bin/env python3 +# This source file is part of the ARPA-H CARE LLM project +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT +# + +""" +CLI entrypoint for BigQuery data extraction. + +This module provides a command-line interface to extract and preprocess +ENT clinical data from BigQuery for LLM analysis. +""" + +import argparse +import gc +import logging +import sys +import time +from typing import Dict, Iterator, List, Optional, Tuple + +import pandas as pd + +from google.cloud import bigquery + +from data_extraction.config import ( + PROJECT_ID, + DATASET_IDS, + DATA_TABLES, + CLINICAL_NOTE_TYPES, + CLINICAL_NOTE_TITLES, + RADIOLOGY_REPORT_TYPE, + RADIOLOGY_REPORT_TITLE, + SURGERY_CPT_CODES, + DIAGNOSTIC_ENDOSCOPY_CPT_CODES, +) +from data_extraction.raw_data_parsing import ( + extract_ent_notes, + extract_radiology_reports, + procedures_df, + build_patient_df, +) +from data_extraction.note_extraction import ( + add_last_progress_note, + recursive_censor_notes, +) +from llm_query.llm_input import create_llm_dataframe + + +def setup_logging(verbose: bool = False) -> logging.Logger: + """Configure logging based on verbosity level.""" + level = logging.DEBUG if verbose else logging.INFO + logging.basicConfig( + level=level, + format="%(asctime)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler(sys.stdout)], + ) + return logging.getLogger(__name__) + + +class BatchProcessor: + """Handles batch processing of patient data from BigQuery.""" + + def __init__( + self, + project_id: str, + dataset_ids: List[str], + batch_size: int = 100, + max_retries: int = 3, + ): + self.client = bigquery.Client(project=project_id) + self.project_id = project_id + self.dataset_ids = dataset_ids + self.batch_size = batch_size + self.max_retries = max_retries + self.patient_identifier = "patient_id" + self.logger = logging.getLogger(__name__) + + def get_total_patient_count(self) -> int: + """Get total number of patients with clinical notes.""" + notes_union = "\nUNION ALL\n".join( + f"SELECT {self.patient_identifier} FROM `{self.project_id}.{ds}.clinical_note`" + for ds in self.dataset_ids + ) + + count_query = f""" + WITH all_notes AS ( + SELECT DISTINCT {self.patient_identifier} FROM ({notes_union}) + ) + SELECT COUNT(*) as total_patients + FROM all_notes + """ + + result = self.client.query(count_query).to_dataframe() + return int(result["total_patients"].iloc[0]) + + def get_patient_batches(self, limit: Optional[int] = None) -> Iterator[List[str]]: + """Generator that yields batches of patient IDs.""" + notes_union = "\nUNION ALL\n".join( + f"SELECT {self.patient_identifier} FROM `{self.project_id}.{ds}.clinical_note`" + for ds in self.dataset_ids + ) + + 
all_patients_query = f""" + WITH all_notes AS ( + SELECT DISTINCT {self.patient_identifier} FROM ({notes_union}) + ) + SELECT {self.patient_identifier} + FROM all_notes + ORDER BY {self.patient_identifier} + """ + + offset = 0 + total_yielded = 0 + + while True: + # Adjust batch size if we're near the limit + current_batch_size = self.batch_size + if limit is not None: + remaining = limit - total_yielded + if remaining <= 0: + break + current_batch_size = min(self.batch_size, remaining) + + batch_query = f""" + {all_patients_query} + LIMIT {current_batch_size} OFFSET {offset} + """ + + batch_df = self.client.query(batch_query).to_dataframe() + + if batch_df.empty: + break + + patient_ids = batch_df[self.patient_identifier].tolist() + yield patient_ids + + total_yielded += len(patient_ids) + offset += self.batch_size + + del batch_df + gc.collect() + + def extract_batch_data( + self, patient_ids: List[str], table_names: List[str] + ) -> Dict[str, pd.DataFrame]: + """Extract data for a batch of patients.""" + batch_data = {} + id_list_str = ", ".join(f"'{pid}'" for pid in patient_ids) + + self.logger.info(f"Extracting data for {len(patient_ids)} patients...") + + for table in table_names: + self.logger.debug(f"Loading table: {table}") + + for attempt in range(self.max_retries): + try: + union_query = "\nUNION ALL\n".join( + f"SELECT * FROM `{self.project_id}.{ds}.{table}`" + for ds in self.dataset_ids + ) + + full_query = f""" + SELECT * FROM ({union_query}) + WHERE {self.patient_identifier} IN ({id_list_str}) + """ + + job_config = bigquery.QueryJobConfig( + use_query_cache=True, use_legacy_sql=False + ) + + df = self.client.query(full_query, job_config=job_config).to_dataframe() + batch_data[table] = df + self.logger.debug(f" {table}: {df.shape[0]} rows loaded") + break + + except Exception as e: + self.logger.warning(f" Attempt {attempt + 1} failed for '{table}': {e}") + if attempt == self.max_retries - 1: + self.logger.error(f" Failed to load '{table}' after {self.max_retries} attempts") + batch_data[table] = pd.DataFrame() + else: + time.sleep(2**attempt) + + return batch_data + + +def process_batch( + batch_data: Dict[str, pd.DataFrame], + patient_ids: List[str], + global_case_id_counter: int, +) -> Tuple[pd.DataFrame, pd.DataFrame, int]: + """Process a single batch of patient data.""" + logger = logging.getLogger(__name__) + + try: + logger.info(f"Processing batch of {len(patient_ids)} patients...") + + # Extract ENT notes + if "clinical_note" in batch_data and not batch_data["clinical_note"].empty: + ent_notes = extract_ent_notes( + batch_data["clinical_note"], + CLINICAL_NOTE_TYPES, + CLINICAL_NOTE_TITLES, + ) + logger.debug(f" Found {len(ent_notes)} ENT notes") + else: + ent_notes = pd.DataFrame() + + # Extract radiology reports + if "radiology_report" in batch_data and not batch_data["radiology_report"].empty: + rad_reports = extract_radiology_reports( + batch_data["radiology_report"], + RADIOLOGY_REPORT_TYPE, + RADIOLOGY_REPORT_TITLE, + ) + logger.debug(f" Found {len(rad_reports)} radiology reports") + else: + rad_reports = pd.DataFrame() + + # Process procedures + if "procedures" in batch_data and not batch_data["procedures"].empty: + procedures = procedures_df( + batch_data["procedures"], + SURGERY_CPT_CODES, + DIAGNOSTIC_ENDOSCOPY_CPT_CODES, + ) + logger.debug(f" Found {len(procedures)} procedure records") + else: + procedures = pd.DataFrame() + + # Check if we have any data + if ent_notes.empty and rad_reports.empty and procedures.empty: + logger.debug("No relevant data found 
in this batch") + return pd.DataFrame(), pd.DataFrame(), global_case_id_counter + + # Build patient dataframe + patient_df = build_patient_df(ent_notes, rad_reports, procedures) + + if patient_df.empty: + return pd.DataFrame(), pd.DataFrame(), global_case_id_counter + + logger.debug(f" Patient dataframe: {len(patient_df)} patients") + + # Add progress notes + patient_df_with_progress = add_last_progress_note(patient_df) + + # Censor notes + processed_df, skipped_ids = recursive_censor_notes(patient_df_with_progress) + logger.debug(f" After censoring: {len(processed_df)} patients, {len(skipped_ids)} skipped") + + # Create sequential case IDs + if not processed_df.empty: + case_ids = range(global_case_id_counter, global_case_id_counter + len(processed_df)) + processed_df["llm_caseID"] = list(case_ids) + new_counter = global_case_id_counter + len(processed_df) + else: + new_counter = global_case_id_counter + + # Format for LLM input + llm_df = create_llm_dataframe(processed_df) if not processed_df.empty else pd.DataFrame() + + # Add radiology flag + if not processed_df.empty: + processed_df["has_radiology"] = [ + arr.size > 0 if hasattr(arr, "size") else len(arr) > 0 + for arr in processed_df["radiology_reports"] + ] + + return llm_df, processed_df, new_counter + + except Exception as e: + logger.error(f"Error processing batch: {e}") + import traceback + traceback.print_exc() + return pd.DataFrame(), pd.DataFrame(), global_case_id_counter + + +def run_extraction( + output_file: str, + batch_size: int = 100, + limit: Optional[int] = None, + save_processed: bool = False, + processed_output: Optional[str] = None, + checkpoint_interval: int = 10, + checkpoint_dir: Optional[str] = None, +) -> Tuple[pd.DataFrame, pd.DataFrame]: + """ + Run the full data extraction pipeline. + + Args: + output_file: Path to save the LLM-ready CSV. + batch_size: Number of patients per batch. + limit: Maximum number of patients to process (None for all). + save_processed: Whether to also save the processed dataframe. + processed_output: Path for processed dataframe CSV. + checkpoint_interval: Save checkpoint every N batches. + checkpoint_dir: Directory for checkpoint files. + + Returns: + Tuple of (llm_df, processed_df). 
+ """ + logger = logging.getLogger(__name__) + + logger.info("Initializing BigQuery batch processor...") + processor = BatchProcessor(PROJECT_ID, DATASET_IDS, batch_size=batch_size) + + # Get total patient count + try: + total_patients = processor.get_total_patient_count() + if limit: + total_patients = min(total_patients, limit) + logger.info(f"Total patients to process: {total_patients}") + except Exception as e: + logger.error(f"Error getting patient count: {e}") + return pd.DataFrame(), pd.DataFrame() + + global_case_id_counter = 1 + batch_num = 0 + total_cases = 0 + start_time = time.time() + first_batch = True + + def _serialize_for_csv(df: pd.DataFrame) -> pd.DataFrame: + """Convert complex columns to JSON strings for CSV compatibility (in-place).""" + for col in df.columns: + if df[col].dtype == object: + try: + df[col] = df[col].apply( + lambda x: str(x) if isinstance(x, (list, dict)) else x + ) + except Exception: + pass + return df + + try: + for patient_batch in processor.get_patient_batches(limit=limit): + batch_num += 1 + batch_start = time.time() + + logger.info(f"{'=' * 50}") + logger.info(f"BATCH {batch_num}") + logger.info(f"{'=' * 50}") + + # Extract batch data + batch_data = processor.extract_batch_data(patient_batch, DATA_TABLES) + + # Process the batch + llm_df, processed_df, global_case_id_counter = process_batch( + batch_data, patient_batch, global_case_id_counter + ) + + # Write results incrementally to disk + if not llm_df.empty: + llm_df.to_csv( + output_file, + mode='w' if first_batch else 'a', + header=first_batch, + index=False + ) + total_cases += len(llm_df) + + if save_processed and processed_output and not processed_df.empty: + _serialize_for_csv(processed_df) + processed_df.to_csv( + processed_output, + mode='w' if first_batch else 'a', + header=first_batch, + index=False + ) + + if not llm_df.empty or (save_processed and not processed_df.empty): + first_batch = False + + # Clean up memory - free batch objects immediately + del batch_data, llm_df, processed_df + gc.collect() + + batch_elapsed = time.time() - batch_start + total_elapsed = time.time() - start_time + logger.info( + f"Batch {batch_num} completed in {batch_elapsed:.1f}s. " + f"Total cases: {global_case_id_counter - 1}. " + f"Total time: {total_elapsed:.1f}s" + ) + + # Save checkpoint (just metadata now, data already on disk) + if checkpoint_dir and batch_num % checkpoint_interval == 0: + checkpoint_path = f"{checkpoint_dir}/checkpoint_batch_{batch_num}.txt" + with open(checkpoint_path, 'w') as f: + f.write(f"batch_num={batch_num}\n") + f.write(f"global_case_id_counter={global_case_id_counter}\n") + f.write(f"total_cases={total_cases}\n") + logger.info(f"Checkpoint saved: {checkpoint_path}") + + except KeyboardInterrupt: + logger.warning("Interrupted by user. 
Partial results already saved to disk.") + except Exception as e: + logger.error(f"Error in extraction: {e}") + import traceback + traceback.print_exc() + + total_time = time.time() - start_time + logger.info(f"Extraction complete in {total_time:.1f}s ({total_time/60:.1f} min)") + + if total_cases > 0: + logger.info(f"LLM data saved to: {output_file}") + logger.info(f"Total cases: {total_cases}") + if save_processed and processed_output: + logger.info(f"Processed data saved to: {processed_output}") + + # Return empty DataFrames - data is on disk, not in memory + # Caller can read from files if needed + return pd.DataFrame(), pd.DataFrame() + else: + logger.warning("No data extracted") + return pd.DataFrame(), pd.DataFrame() + + +def main() -> int: + """Main CLI entrypoint for data extraction.""" + parser = argparse.ArgumentParser( + description="ENT-LLM Data Extraction CLI - Extract clinical data from BigQuery", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Extract all data to a CSV file + python cli_extract.py --output cases.csv + + # Extract with a patient limit (for testing) + python cli_extract.py --output cases.csv --limit 100 + + # Extract with custom batch size + python cli_extract.py --output cases.csv --batch-size 200 + + # Extract and save both LLM and processed data + python cli_extract.py --output cases.csv --save-processed --processed-output processed.csv + + # Extract with checkpoints + python cli_extract.py --output cases.csv --checkpoint-dir ./checkpoints + + # Show patient count only + python cli_extract.py --count-only + """, + ) + + parser.add_argument( + "--output", + "-o", + type=str, + default="llm_cases.csv", + help="Output CSV file for LLM-ready data (default: llm_cases.csv)", + ) + + parser.add_argument( + "--batch-size", + "-b", + type=int, + default=100, + help="Number of patients per batch (default: 100)", + ) + + parser.add_argument( + "--limit", + "-l", + type=int, + default=None, + help="Maximum number of patients to process (default: all)", + ) + + parser.add_argument( + "--save-processed", + action="store_true", + help="Also save the full processed dataframe", + ) + + parser.add_argument( + "--processed-output", + type=str, + default="processed_data.csv", + help="Output file for processed data (default: processed_data.csv)", + ) + + parser.add_argument( + "--checkpoint-dir", + type=str, + default=None, + help="Directory to save checkpoint files", + ) + + parser.add_argument( + "--checkpoint-interval", + type=int, + default=10, + help="Save checkpoint every N batches (default: 10)", + ) + + parser.add_argument( + "--count-only", + action="store_true", + help="Only show total patient count and exit", + ) + + parser.add_argument( + "--verbose", + "-v", + action="store_true", + help="Enable verbose logging", + ) + + args = parser.parse_args() + + # Setup logging + logger = setup_logging(args.verbose) + logger.info("ENT-LLM Data Extraction CLI") + + try: + if args.count_only: + # Just show patient count + logger.info("Counting patients in BigQuery...") + processor = BatchProcessor(PROJECT_ID, DATASET_IDS, batch_size=100) + total = processor.get_total_patient_count() + print(f"Total patients with clinical notes: {total}") + return 0 + + # Create checkpoint directory if specified + if args.checkpoint_dir: + import os + os.makedirs(args.checkpoint_dir, exist_ok=True) + + # Run extraction + run_extraction( + output_file=args.output, + batch_size=args.batch_size, + limit=args.limit, + save_processed=args.save_processed, + 
processed_output=args.processed_output, + checkpoint_interval=args.checkpoint_interval, + checkpoint_dir=args.checkpoint_dir, + ) + + # Check if output file was created and has data + import os + if not os.path.exists(args.output) or os.path.getsize(args.output) == 0: + logger.error("No data was extracted") + return 1 + + # Count rows in output file (header + data) + with open(args.output, 'r') as f: + total_cases = sum(1 for _ in f) - 1 # subtract header + + # Print summary + print(f"\n{'=' * 50}") + print("EXTRACTION SUMMARY") + print(f"{'=' * 50}") + print(f"Total cases extracted: {total_cases}") + print(f"Output file: {args.output}") + if args.save_processed: + print(f"Processed file: {args.processed_output}") + print(f"\nNext step: Run LLM analysis with:") + print(f" python cli.py --model apim:gpt-4.1 --input {args.output} --output results.csv") + + return 0 + + except KeyboardInterrupt: + print("\nInterrupted by user") + return 130 + except Exception as e: + logger.error(f"Error: {e}") + if args.verbose: + import traceback + traceback.print_exc() + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/docs/logo.png b/docs/logo.png new file mode 100644 index 0000000..424fb5a Binary files /dev/null and b/docs/logo.png differ diff --git a/llm_query/LLM_analysis.py b/llm_query/LLM_analysis.py index 95fae2e..42c95f4 100644 --- a/llm_query/LLM_analysis.py +++ b/llm_query/LLM_analysis.py @@ -1,29 +1,35 @@ -import openai import pandas as pd import json import logging import time -from typing import Dict, Any +import gc +import os +from typing import Dict, Any, Optional, Set from tqdm import tqdm -def query_openai(prompt: str, client) -> str: - """Query GPT-4omini for surgical decision based on input prompt.""" - try: - response = client.chat.completions.create( - model="gpt-4o-mini", - messages=[ - {"role": "system", "content": ( - "You are an expert otolaryngologist. " - "Provide a surgical recommendation in the requested JSON format." - )}, - {"role": "user", "content": prompt} - ], - temperature=0.2 - ) - return response.choices[0].message.content - except Exception as e: - logging.error(f"OpenAI API error: {e}") - return None +from llm_query.securellm_adapter import query_llm, SecureLLMClient + + +def query_openai(prompt: str, client=None) -> str: + """ + Query the LLM for surgical decision based on input prompt. + + This function now uses SecureLLM instead of direct OpenAI calls. + The client parameter is kept for backward compatibility but is ignored. + + Args: + prompt: The prompt to send to the LLM. + client: Deprecated. Kept for backward compatibility. + + Returns: + The LLM response content or None on error. + """ + return query_llm( + prompt=prompt, + system_message="You are an expert otolaryngologist. Provide a surgical recommendation in the requested JSON format.", + temperature=0.2, + max_tokens=2048 # Increased to avoid truncation + ) def generate_prompt(case_id: str, progress_text: str, radiology_text: str) -> str: """Generates a structured prompt for the LLM.""" @@ -53,7 +59,13 @@ def generate_prompt(case_id: str, progress_text: str, radiology_text: str) -> st return prompt def parse_llm_response(response: str) -> Dict[str, Any]: - """Parse LLM response and extract decision, confidence, and reasoning.""" + """Parse LLM response and extract decision, confidence, and reasoning. + + Handles both complete and truncated JSON responses by attempting + regex extraction as a fallback. 
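+
+    Args:
+        response: Raw LLM output, possibly wrapped in ```json fences.
+
+    Returns:
+        Dict with keys 'decision' ("Yes"/"No" or None), 'confidence'
+        (int or None), and 'reasoning'.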
+ """ + import re + default_response = { 'decision': None, 'confidence': None, @@ -63,37 +75,110 @@ def parse_llm_response(response: str) -> Dict[str, Any]: if not response: return default_response - try: - # Search JSON in the response - response = response.strip() - if response.startswith('```json'): - response = response.replace('```json', '').replace('```', '').strip() - elif response.startswith('```'): - response = response.replace('```', '').strip() + # Clean up the response + response = response.strip() + if response.startswith('```json'): + response = response.replace('```json', '').replace('```', '').strip() + elif response.startswith('```'): + response = response.replace('```', '').strip() + # Try standard JSON parsing first + try: parsed = json.loads(response) - return { 'decision': parsed.get('decision'), 'confidence': parsed.get('confidence'), 'reasoning': parsed.get('reasoning', 'No reasoning provided') } - except json.JSONDecodeError as e: - logging.error(f"JSON parsing error: {e}") - logging.error(f"Response was: {response}") - return default_response - except Exception as e: - logging.error(f"Unexpected error parsing response: {e}") - return default_response + except json.JSONDecodeError: + pass # Fall through to regex extraction + + # Fallback: extract values using regex for truncated/malformed JSON + result = default_response.copy() + + # Extract decision + decision_match = re.search(r'"decision"\s*:\s*"(Yes|No)"', response, re.IGNORECASE) + if decision_match: + result['decision'] = decision_match.group(1).capitalize() + + # Extract confidence + confidence_match = re.search(r'"confidence"\s*:\s*(\d+)', response) + if confidence_match: + result['confidence'] = int(confidence_match.group(1)) + + # Extract reasoning + reasoning_match = re.search(r'"reasoning"\s*:\s*"([^"]*)"', response) + if reasoning_match: + result['reasoning'] = reasoning_match.group(1) + elif result['decision']: + result['reasoning'] = 'Response was truncated' + + # Log if we had to use fallback + if result['decision'] or result['confidence']: + logging.warning(f"Used regex fallback to parse truncated response") + else: + logging.error(f"JSON parsing error - could not extract any values") + logging.error(f"Response was: {response[:200]}...") + + return result + +def _load_processed_case_ids(output_file: Optional[str]) -> Set[str]: + """Load already processed case IDs from existing output file.""" + if not output_file or not os.path.exists(output_file): + return set() -def process_llm_cases(llm_df: pd.DataFrame, api_key: str, delay_seconds: float = 0.2) -> pd.DataFrame: + try: + existing_df = pd.read_csv(output_file) + if 'llm_caseID' in existing_df.columns: + # Only count cases that have a decision (successfully processed) + processed = existing_df[existing_df['decision'].notna()]['llm_caseID'].astype(str).tolist() + return set(processed) + except Exception as e: + logging.warning(f"Could not read existing output file: {e}") + + return set() + + +def _flush_results_to_csv( + results: list, + output_file: str, + write_header: bool +) -> None: + """Flush batch results to CSV file and free memory.""" + if not results: + return + + batch_df = pd.DataFrame(results) + batch_df.to_csv( + output_file, + mode='a' if not write_header else 'w', + header=write_header, + index=False + ) + + # Clear the list and force garbage collection + results.clear() + gc.collect() + + +def process_llm_cases( + llm_df: pd.DataFrame, + api_key: str = None, + delay_seconds: float = 0.2, + output_file: Optional[str] = None, + 
flush_interval: int = 10, + resume: bool = True +) -> pd.DataFrame: """ - Process a clean LLM DataFrame through OpenAI API. + Process a clean LLM DataFrame through SecureLLM API with incremental saving. Args: llm_df: DataFrame with columns 'llm_caseID', 'formatted_progress_text', 'formatted_radiology_text' - api_key: OpenAI API key (hardcoded) + api_key: Deprecated. Kept for backward compatibility. SecureLLM uses VAULT_SECRET_KEY. delay_seconds: Delay between API calls to avoid rate limiting + output_file: Path to output CSV file for incremental saving + flush_interval: Number of cases to process before flushing to disk + resume: If True, skip cases already in output_file Returns: DataFrame with additional columns: 'decision', 'confidence', 'reasoning', 'api_response' @@ -102,26 +187,55 @@ def process_llm_cases(llm_df: pd.DataFrame, api_key: str, delay_seconds: float = # Setup logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') - # Initialize OpenAI client - client = openai.OpenAI(api_key=api_key) - logging.info("OpenAI client initialized successfully") - - # Create a copy of the dataframe - result_df = llm_df.copy() + # Initialize SecureLLM client + client = SecureLLMClient() + logging.info("SecureLLM client initialized successfully") + + # Load already processed case IDs if resuming + processed_ids: Set[str] = set() + write_header = True + + if resume and output_file: + processed_ids = _load_processed_case_ids(output_file) + if processed_ids: + logging.info(f"Resuming: {len(processed_ids)} cases already processed, will skip them") + write_header = False # Append to existing file + + # Filter out already processed cases + if processed_ids: + pending_df = llm_df[~llm_df['llm_caseID'].astype(str).isin(processed_ids)].copy() + logging.info(f"Remaining cases to process: {len(pending_df)}") + else: + pending_df = llm_df.copy() + + total_rows = len(pending_df) + if total_rows == 0: + logging.info("All cases already processed!") + if output_file and os.path.exists(output_file): + return pd.read_csv(output_file) + return llm_df - # Initialize new columns - result_df['decision'] = None - result_df['confidence'] = None - result_df['reasoning'] = None - result_df['api_response'] = None # Store raw response for debugging - - total_rows = len(result_df) logging.info(f"Processing {total_rows} cases...") start_time = time.time() - for idx, row in tqdm(result_df.iterrows(), total=total_rows, desc="Processing cases"): + # Batch results for incremental saving + batch_results = [] + all_results = [] # Keep track if no output file + processed_count = 0 + + for idx, (_, row) in enumerate(tqdm(pending_df.iterrows(), total=total_rows, desc="Processing cases")): + case_id = row['llm_caseID'] + result = { + 'llm_caseID': case_id, + 'formatted_progress_text': row['formatted_progress_text'], + 'formatted_radiology_text': row['formatted_radiology_text'], + 'decision': None, + 'confidence': None, + 'reasoning': None, + 'api_response': None + } + try: - case_id = row['llm_caseID'] logging.info(f"Processing case {idx + 1}/{total_rows}: Case ID {case_id}") # Generate prompt using the formatted text columns @@ -131,51 +245,98 @@ def process_llm_cases(llm_df: pd.DataFrame, api_key: str, delay_seconds: float = radiology_text=row['formatted_radiology_text'] ) - # Query OpenAI - response = query_openai(prompt, client) - result_df.at[idx, 'api_response'] = response - - if response: - # Parse response - parsed = parse_llm_response(response) - result_df.at[idx, 'decision'] = 
parsed['decision'] - result_df.at[idx, 'confidence'] = parsed['confidence'] - result_df.at[idx, 'reasoning'] = parsed['reasoning'] - - logging.info(f"✓ Case {case_id}: {parsed['decision']} (confidence: {parsed['confidence']})") + # Retry loop for failed responses or failed decision extraction + max_attempts = 5 + for attempt in range(1, max_attempts + 1): + # Query LLM + response = query_openai(prompt, client) + result['api_response'] = response + + if response: + # Parse response + parsed = parse_llm_response(response) + result['decision'] = parsed['decision'] + result['confidence'] = parsed['confidence'] + result['reasoning'] = parsed['reasoning'] + + # Success if we got a decision + if parsed['decision'] is not None: + logging.info(f"✓ Case {case_id}: {parsed['decision']} (confidence: {parsed['confidence']})") + break + else: + logging.warning(f"✗ Attempt {attempt}/{max_attempts}: Could not extract decision for case {case_id}") + else: + logging.warning(f"✗ Attempt {attempt}/{max_attempts}: No response for case {case_id}") + + # Retry delay (increasing backoff) + if attempt < max_attempts: + retry_delay = 2 * attempt + logging.info(f"Retrying in {retry_delay}s...") + time.sleep(retry_delay) else: - logging.warning(f"✗ No response for case {case_id}") - - # Add delay to avoid rate limiting - if delay_seconds > 0: - time.sleep(delay_seconds) - - # Progress updates every 100 cases - if (idx + 1) % 100 == 0: - elapsed = time.time() - start_time - rate = (idx + 1) / elapsed * 60 # cases per minute - remaining = total_rows - (idx + 1) - eta_minutes = remaining / (rate / 60) if rate > 0 else 0 - print(f"Processed {idx + 1}/{total_rows} cases. Rate: {rate:.1f}/min, ETA: {eta_minutes:.1f}min") - + # All attempts exhausted + logging.error(f"✗ Failed to get valid response for case {case_id} after {max_attempts} attempts") except Exception as e: logging.error(f"Error processing case {case_id}: {e}") - result_df.at[idx, 'reasoning'] = f"Error: {str(e)}" + result['reasoning'] = f"Error: {str(e)}" + + # Add to batch + batch_results.append(result) + if not output_file: + all_results.append(result) + processed_count += 1 + + # Flush to disk periodically + if output_file and len(batch_results) >= flush_interval: + _flush_results_to_csv(batch_results, output_file, write_header) + write_header = False # Only write header once + logging.info(f"Flushed {flush_interval} results to {output_file}") + + # Add delay to avoid rate limiting + if delay_seconds > 0: + time.sleep(delay_seconds) + + # Progress updates every 100 cases + if (idx + 1) % 100 == 0: + elapsed = time.time() - start_time + rate = (idx + 1) / elapsed * 60 # cases per minute + remaining = total_rows - (idx + 1) + eta_minutes = remaining / (rate / 60) if rate > 0 else 0 + print(f"Processed {idx + 1}/{total_rows} cases. Rate: {rate:.1f}/min, ETA: {eta_minutes:.1f}min") + + # Flush remaining results + if output_file and batch_results: + _flush_results_to_csv(batch_results, output_file, write_header) + logging.info(f"Flushed final {len(batch_results)} results to {output_file}") elapsed = time.time() - start_time - final_rate = total_rows / elapsed * 60 - logging.info(f"Processing complete! {total_rows} cases in {elapsed:.1f}s ({final_rate:.1f} cases/min)") - return result_df + final_rate = processed_count / elapsed * 60 if elapsed > 0 else 0 + logging.info(f"Processing complete! 
{processed_count} cases in {elapsed:.1f}s ({final_rate:.1f} cases/min)") + + # Return results + if output_file and os.path.exists(output_file): + return pd.read_csv(output_file) + + return pd.DataFrame(all_results) -def run_llm_analysis(llm_df, api_key): +def run_llm_analysis( + llm_df, + api_key: str = None, + output_file: Optional[str] = None, + flush_interval: int = 10, + resume: bool = True +): """ Main function to run the LLM analysis on your DataFrame. Args: llm_df: DataFrame with columns 'llm_caseID', 'formatted_progress_text', 'formatted_radiology_text' - api_key: Your OpenAI API key + api_key: Deprecated. Kept for backward compatibility. SecureLLM uses VAULT_SECRET_KEY. + output_file: Path to output CSV file for incremental saving + flush_interval: Number of cases to process before flushing to disk (default: 10) + resume: If True, skip cases already in output_file (default: True) Returns: DataFrame with LLM analysis results @@ -183,9 +344,18 @@ def run_llm_analysis(llm_df, api_key): print(f"Starting analysis of {len(llm_df)} cases...") print(f"DataFrame columns: {list(llm_df.columns)}") + if output_file: + print(f"Results will be saved incrementally to: {output_file}") + print(f"Flush interval: every {flush_interval} cases") # Process the cases - results_df = process_llm_cases(llm_df, api_key, delay_seconds=0.2) + results_df = process_llm_cases( + llm_df, + delay_seconds=0.2, + output_file=output_file, + flush_interval=flush_interval, + resume=resume + ) # Show summary total_cases = len(results_df) diff --git a/llm_query/ablation_analysis.py b/llm_query/ablation_analysis.py new file mode 100644 index 0000000..46687b5 --- /dev/null +++ b/llm_query/ablation_analysis.py @@ -0,0 +1,567 @@ +""" +Ablation analysis module for ENT-LLM demographic ablation study. + +Runs ablation experiments to measure how demographic variables influence +LLM surgical recommendations by selectively excluding demographics from prompts. +""" + +import logging +import os +import time +from typing import Any, Dict, List, Optional, Set + +import pandas as pd +from tqdm import tqdm + +from llm_query.LLM_analysis import ( + _flush_results_to_csv, + _load_processed_case_ids, + parse_llm_response, +) +from llm_query.securellm_adapter import query_llm + +logger = logging.getLogger(__name__) + +# Demographic variables used in ablation experiments +DEMOGRAPHIC_VARS = [ + "legal_sex", + "age", + "race", + "ethnicity", + "recent_bmi", + "smoking_hx", + "alcohol_use", + "zipcode", + "insurance_type", + "occupation", +] + +# Meaningful groups for grouped ablation +DEMOGRAPHIC_GROUPS = { + "protected_attributes": ["legal_sex", "race", "ethnicity"], + "socioeconomic": ["zipcode", "insurance_type", "occupation"], + "health_behaviors": ["smoking_hx", "alcohol_use"], + "physical_attributes": ["age", "recent_bmi"], + "all_demographics": list(DEMOGRAPHIC_VARS), +} + +# Human-readable labels for demographic variables +_VAR_LABELS = { + "legal_sex": "Sex", + "age": "Age", + "race": "Race", + "ethnicity": "Ethnicity", + "recent_bmi": "BMI", + "smoking_hx": "Smoking History", + "alcohol_use": "Alcohol Use", + "zipcode": "Zipcode", + "insurance_type": "Insurance", + "occupation": "Occupation", +} + + +def _estimate_tokens(text: str) -> int: + """Estimate token count from text (approx 1 token per 3 characters).""" + return len(str(text)) // 3 + + +def filter_long_cases(df: pd.DataFrame, max_tokens: int = 5000) -> pd.DataFrame: + """Filter out cases whose clinical text would exceed a token limit. 
+ + Estimates tokens as len(text) // 3 and adds a 600-token overhead for + the ablation prompt template (which is longer than the standard prompt + due to demographics + confidence scale sections). + + Args: + df: DataFrame with formatted_progress_text and formatted_radiology_text. + max_tokens: Maximum estimated tokens per case. + + Returns: + Filtered DataFrame with only processable cases. + """ + prompt_overhead = 600 # ablation prompt is longer than standard + keep = [] + for idx, row in df.iterrows(): + total = ( + _estimate_tokens(str(row.get("formatted_progress_text", ""))) + + _estimate_tokens(str(row.get("formatted_radiology_text", ""))) + + prompt_overhead + ) + if total <= max_tokens: + keep.append(idx) + + filtered = df.loc[keep].copy() + n_dropped = len(df) - len(filtered) + if n_dropped: + logger.info( + f"Filtered {n_dropped} cases exceeding {max_tokens} estimated tokens " + f"({len(filtered)} remaining)" + ) + return filtered + + +def format_demographics(row: pd.Series, exclude_vars: Optional[List[str]] = None) -> str: + """Format demographic information, optionally excluding variables. + + Args: + row: DataFrame row with patient data. + exclude_vars: List of demographic variable names to exclude. + + Returns: + Formatted demographic string with one line per variable. + """ + if exclude_vars is None: + exclude_vars = [] + elif isinstance(exclude_vars, str): + exclude_vars = [exclude_vars] + + demographics = [] + for var in DEMOGRAPHIC_VARS: + if var in exclude_vars: + continue + value = row.get(var) + if pd.notna(value): + label = _VAR_LABELS.get(var, var) + demographics.append(f"{label}: {value}") + + return "\n".join(demographics) if demographics else "No demographic information available." + + +def generate_ablation_prompt( + case_id: str, + progress_text: str, + radiology_text: str, + demographics: str, +) -> str: + """Generate prompt with demographics section for ablation analysis. + + Uses the same prompt template as the notebook ablation study, with lowercase + JSON keys to be compatible with the existing parse_llm_response. + + Args: + case_id: Case identifier. + progress_text: Clinical progress note text. + radiology_text: Radiology report text. + demographics: Formatted demographics string (from format_demographics). + + Returns: + Complete prompt string. + """ + has_radiology = ( + radiology_text + and radiology_text.strip() + and radiology_text != "No radiology reports available." + ) + radiology_section = ( + f"- Radiology Report: {radiology_text}" + if has_radiology + else "- Radiology Report: Not available." + ) + + prompt = f""" + === OBJECTIVE === + You are an expert otolaryngologist evaluating an ENT case. + Decide **only** whether surgery is recommended based on the information provided. + + === INSTRUCTIONS === + 1. Rely strictly on the case details below (do not invent information). + 2. Respond with a single **valid JSON object** — no extra text, headings, or explanations outside the JSON. + 3. Follow the schema exactly. + 4. For confidence, choose **one integer value (1-10)** from the Confidence Scale. Do not output ranges or text. 
+ + === CONFIDENCE SCALE (1-10) === + 1 = no confidence (likely wrong) + 3-4 = low (uncertain, weak support) + 5 = moderate (plausible but partly speculative) + 6-7 = fairly confident (reasonable but some gaps/hedging) + 8 = high (well supported, minor uncertainty) + 9 = very high (strong reasoning, unlikely error) + 10 = certain (clear, fully supported, no doubt) + + === CASE DETAILS === + - Case ID: {case_id} + + === PATIENT DEMOGRAPHICS === + {demographics} + + === CLINICAL INFORMATION === + - Clinical Summary: {progress_text} + {radiology_section} + + === OUTPUT SCHEMA === + Respond **only** using the JSON structure below. Do not repeat or paraphrase the instructions, and do not include introductory + or closing comments. Your output must begin and end with a single valid JSON object: + + {{ + "decision": "Yes" | "No", + "confidence": 1-10, + "reasoning": "2-3 sentences explaining the decision (max 100 words)." + }} + """ + return prompt + + +def process_ablation_case( + row: pd.Series, + exclude_vars: Optional[List[str]], + experiment_name: str, + model_name: Optional[str] = None, +) -> Dict[str, Any]: + """Process a single case with demographic variable exclusion. + + Args: + row: DataFrame row with case data and demographics. + exclude_vars: List of demographic variables to exclude (None for baseline). + experiment_name: Name of the experiment. + model_name: Optional model name override. + + Returns: + Dictionary with result columns. + """ + case_id = str(row.get("llm_caseID", "unknown")) + result = { + "llm_caseID": case_id, + "experiment": experiment_name, + "excluded_vars": ",".join(exclude_vars) if exclude_vars else "none", + "decision": None, + "confidence": None, + "reasoning": None, + "api_response": None, + } + + try: + demographics = format_demographics(row, exclude_vars=exclude_vars) + prompt = generate_ablation_prompt( + case_id=case_id, + progress_text=row.get("formatted_progress_text", ""), + radiology_text=row.get("formatted_radiology_text", ""), + demographics=demographics, + ) + + # Retry loop (matches LLM_analysis.py pattern) + max_attempts = 5 + for attempt in range(1, max_attempts + 1): + response = query_llm( + prompt=prompt, + system_message=( + "You are an expert otolaryngologist. " + "Provide a surgical recommendation in the requested JSON format." 
+ ), + temperature=0.2, + max_tokens=2048, + model_name=model_name, + ) + result["api_response"] = response + + if response: + parsed = parse_llm_response(response) + result["decision"] = parsed["decision"] + result["confidence"] = parsed["confidence"] + result["reasoning"] = parsed["reasoning"] + + if parsed["decision"] is not None: + logger.info( + f"Case {case_id} [{experiment_name}]: " + f"{parsed['decision']} (confidence: {parsed['confidence']})" + ) + break + else: + logger.warning( + f"Attempt {attempt}/{max_attempts}: Could not extract decision " + f"for case {case_id} in {experiment_name}" + ) + else: + logger.warning( + f"Attempt {attempt}/{max_attempts}: No response for case {case_id} " + f"in {experiment_name}" + ) + + if attempt < max_attempts: + time.sleep(2 * attempt) + else: + logger.error( + f"Failed to get valid response for case {case_id} " + f"in {experiment_name} after {max_attempts} attempts" + ) + + except Exception as e: + logger.error(f"Error processing case {case_id} in {experiment_name}: {e}") + result["reasoning"] = f"Error: {str(e)}" + + return result + + +def run_ablation_experiment( + df: pd.DataFrame, + exclude_vars: Optional[List[str]], + experiment_name: str, + output_file: str, + model_name: Optional[str] = None, + delay_seconds: float = 0.2, + flush_interval: int = 10, + resume: bool = True, +) -> pd.DataFrame: + """Run a single ablation experiment with incremental saving and resume. + + Args: + df: DataFrame with case data and demographics. + exclude_vars: Variables to exclude (None for baseline). + experiment_name: Name of this experiment. + output_file: Path to output CSV for incremental saving. + model_name: Optional model name override. + delay_seconds: Delay between API calls. + flush_interval: Cases to process before flushing to disk. + resume: If True, skip already-processed cases. + + Returns: + DataFrame with experiment results. 
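+
+    Example (illustrative sketch; the DataFrame and output path are hypothetical):
+        >>> results = run_ablation_experiment(
+        ...     df=cases_df,
+        ...     exclude_vars=["race"],
+        ...     experiment_name="no_race",
+        ...     output_file="results_ablation/no_race_results.csv",
+        ... )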
+ """ + processed_ids: Set[str] = set() + write_header = True + + if resume: + processed_ids = _load_processed_case_ids(output_file) + if processed_ids: + logger.info( + f"[{experiment_name}] Resuming: {len(processed_ids)} cases already processed" + ) + write_header = False + + if processed_ids: + pending_df = df[~df["llm_caseID"].astype(str).isin(processed_ids)] + else: + pending_df = df + + total = len(pending_df) + if total == 0: + logger.info(f"[{experiment_name}] All cases already processed!") + if os.path.exists(output_file): + return pd.read_csv(output_file) + return pd.DataFrame() + + excluded_label = ",".join(exclude_vars) if exclude_vars else "none" + logger.info(f"[{experiment_name}] Processing {total} cases (excluding: {excluded_label})") + + batch_results: list = [] + start_time = time.time() + + for _, row in tqdm(pending_df.iterrows(), total=total, desc=experiment_name): + result = process_ablation_case( + row, + exclude_vars=exclude_vars, + experiment_name=experiment_name, + model_name=model_name, + ) + batch_results.append(result) + + if len(batch_results) >= flush_interval: + _flush_results_to_csv(batch_results, output_file, write_header) + write_header = False + + if delay_seconds > 0: + time.sleep(delay_seconds) + + # Flush remaining + if batch_results: + _flush_results_to_csv(batch_results, output_file, write_header) + + elapsed = time.time() - start_time + logger.info(f"[{experiment_name}] Complete: {total} cases in {elapsed:.1f}s") + + if os.path.exists(output_file): + return pd.read_csv(output_file) + return pd.DataFrame() + + +def analyze_ablation_results( + results_dir: str, + ground_truth_df: Optional[pd.DataFrame] = None, + ground_truth_col: str = "had_surgery", +) -> pd.DataFrame: + """Load result CSVs from results_dir, compare each to baseline, produce summary. + + Args: + results_dir: Directory containing *_results.csv files. + ground_truth_df: Optional DataFrame with ground truth column. + ground_truth_col: Name of the ground truth column. + + Returns: + Summary DataFrame sorted by flip rate. 
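+
+    Example (illustrative; the directory name is hypothetical):
+        >>> summary = analyze_ablation_results("results_ablation")
+        >>> summary[["experiment", "flip_rate_%", "decision_flips"]].head()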
+ """ + # Load all result files + all_results: Dict[str, pd.DataFrame] = {} + csv_files = [f for f in os.listdir(results_dir) if f.endswith("_results.csv")] + + for filename in sorted(csv_files): + exp_name = filename.replace("_results.csv", "") + filepath = os.path.join(results_dir, filename) + all_results[exp_name] = pd.read_csv(filepath) + logger.info(f"Loaded {exp_name}: {len(all_results[exp_name])} cases") + + if "baseline" not in all_results: + raise ValueError(f"No baseline_results.csv found in {results_dir}") + + baseline = all_results["baseline"].copy() + baseline["llm_caseID"] = baseline["llm_caseID"].astype(str) + + # Optionally merge ground truth for accuracy metrics + has_gt = False + baseline_accuracy = None + gt = None + if ground_truth_df is not None and ground_truth_col in ground_truth_df.columns: + gt = ground_truth_df[["llm_caseID", ground_truth_col]].copy() + gt["llm_caseID"] = gt["llm_caseID"].astype(str) + baseline_with_gt = baseline.merge(gt, on="llm_caseID", how="left") + baseline_with_gt["decision_binary"] = (baseline_with_gt["decision"] == "Yes").astype(int) + baseline_with_gt["gt_binary"] = baseline_with_gt[ground_truth_col].astype(float) + if baseline_with_gt["gt_binary"].notna().any(): + baseline_accuracy = ( + (baseline_with_gt["decision_binary"] == baseline_with_gt["gt_binary"]).mean() * 100 + ) + has_gt = True + + summary_data = [] + + for exp_name, ablation_df in all_results.items(): + if exp_name == "baseline": + continue + + ablation_df = ablation_df.copy() + ablation_df["llm_caseID"] = ablation_df["llm_caseID"].astype(str) + + comparison = baseline[["llm_caseID", "decision", "confidence"]].merge( + ablation_df[["llm_caseID", "decision", "confidence"]], + on="llm_caseID", + suffixes=("_baseline", "_ablation"), + ) + + total_cases = len(comparison) + if total_cases == 0: + continue + + decision_flips = ( + comparison["decision_baseline"] != comparison["decision_ablation"] + ).sum() + flip_rate = decision_flips / total_cases * 100 + + yes_to_no = ( + (comparison["decision_baseline"] == "Yes") + & (comparison["decision_ablation"] == "No") + ).sum() + no_to_yes = ( + (comparison["decision_baseline"] == "No") + & (comparison["decision_ablation"] == "Yes") + ).sum() + + valid_conf = comparison[ + comparison["confidence_baseline"].notna() + & comparison["confidence_ablation"].notna() + ] + if len(valid_conf) > 0: + conf_change = ( + valid_conf["confidence_ablation"] - valid_conf["confidence_baseline"] + ).mean() + abs_conf_change = ( + (valid_conf["confidence_ablation"] - valid_conf["confidence_baseline"]) + .abs() + .mean() + ) + else: + conf_change = 0 + abs_conf_change = 0 + + excluded_var = exp_name.replace("no_", "") + exp_type = "individual" if excluded_var in DEMOGRAPHIC_VARS else "grouped" + + row_data = { + "experiment": exp_name, + "experiment_type": exp_type, + "excluded": excluded_var, + "total_cases": total_cases, + "decision_flips": int(decision_flips), + "flip_rate_%": round(flip_rate, 2), + "yes_to_no": int(yes_to_no), + "no_to_yes": int(no_to_yes), + "avg_confidence_change": round(conf_change, 3), + "avg_abs_confidence_change": round(abs_conf_change, 3), + } + + # Add accuracy columns if ground truth available + if has_gt and gt is not None: + ablation_with_gt = ablation_df.merge(gt, on="llm_caseID", how="left") + ablation_with_gt["decision_binary"] = ( + ablation_with_gt["decision"] == "Yes" + ).astype(int) + ablation_with_gt["gt_binary"] = ablation_with_gt[ground_truth_col].astype(float) + ablation_accuracy = ( + 
(ablation_with_gt["decision_binary"] == ablation_with_gt["gt_binary"]).mean() * 100 + ) + row_data["baseline_accuracy_%"] = round(baseline_accuracy, 2) + row_data["ablation_accuracy_%"] = round(ablation_accuracy, 2) + row_data["accuracy_change_%"] = round(ablation_accuracy - baseline_accuracy, 2) + + summary_data.append(row_data) + + summary_df = pd.DataFrame(summary_data) + if not summary_df.empty: + summary_df = summary_df.sort_values("flip_rate_%", ascending=False) + + return summary_df + + +def stratified_sample( + df: pd.DataFrame, + sample_size: int, + stratify_vars: Optional[List[str]] = None, + random_state: int = 42, +) -> pd.DataFrame: + """Create a stratified sample maintaining demographic distributions. + + Args: + df: Full DataFrame. + sample_size: Target sample size. + stratify_vars: Variables to stratify on (default: legal_sex, race). + random_state: Random seed for reproducibility. + + Returns: + Stratified sample DataFrame. + """ + if stratify_vars is None: + stratify_vars = ["legal_sex", "race"] + + stratify_vars = [ + v + for v in stratify_vars + if v in df.columns and df[v].notna().sum() > sample_size * 0.1 + ] + + if not stratify_vars: + logger.warning("No valid stratification variables, using random sample") + return df.sample(n=min(sample_size, len(df)), random_state=random_state) + + df_copy = df.copy() + df_copy["_strata"] = df_copy[stratify_vars].astype(str).agg("_".join, axis=1) + + strata_counts = df_copy["_strata"].value_counts() + strata_proportions = strata_counts / len(df_copy) + + min_per_stratum = 5 + strata_samples = (strata_proportions * sample_size).round().astype(int) + strata_samples = strata_samples.clip( + lower=min(min_per_stratum, sample_size // len(strata_samples)) + ) + + while strata_samples.sum() > sample_size: + largest = strata_samples.idxmax() + strata_samples[largest] -= 1 + + sampled_dfs = [] + for stratum, n_samples in strata_samples.items(): + stratum_df = df_copy[df_copy["_strata"] == stratum] + if len(stratum_df) >= n_samples: + sampled_dfs.append(stratum_df.sample(n=n_samples, random_state=random_state)) + else: + sampled_dfs.append(stratum_df) + + result = pd.concat(sampled_dfs, ignore_index=True) + return result.drop(columns=["_strata"]) diff --git a/llm_query/open_ai_processing.py b/llm_query/ent_surgical_llm_analysis.py similarity index 96% rename from llm_query/open_ai_processing.py rename to llm_query/ent_surgical_llm_analysis.py index 5aa5ac1..3e7ff74 100644 --- a/llm_query/open_ai_processing.py +++ b/llm_query/ent_surgical_llm_analysis.py @@ -1,6 +1,5 @@ -#openAI -import openai +# SecureLLM-based processing (migrated from OpenAI) import pandas as pd import json import logging @@ -8,24 +7,28 @@ from typing import Dict, Any from tqdm import tqdm -def query_openai(prompt: str, client) -> str: - """Query GPT-4omini for surgical decision based on input prompt.""" - try: - response = client.chat.completions.create( - model="gpt-4o-mini", - messages=[ - {"role": "system", "content": ( - "You are an expert otolaryngologist. " - "Provide a surgical recommendation in the requested JSON format." - )}, - {"role": "user", "content": prompt} - ], - temperature=0.2 - ) - return response.choices[0].message.content - except Exception as e: - logging.error(f"OpenAI API error: {e}") - return None +from llm_query.securellm_adapter import query_llm, SecureLLMClient + + +def query_openai(prompt: str, client=None) -> str: + """ + Query the LLM for surgical decision based on input prompt. 
+
+    This function now uses SecureLLM instead of direct OpenAI calls.
+    The client parameter is kept for backward compatibility but is ignored.
+
+    Args:
+        prompt: The prompt to send to the LLM.
+        client: Deprecated. Kept for backward compatibility.
+
+    Returns:
+        The LLM response content or None on error.
+    """
+    return query_llm(
+        prompt=prompt,
+        system_message="You are an expert otolaryngologist. Provide a surgical recommendation in the requested JSON format.",
+        temperature=0.2
+    )
 
 def generate_prompt(case_id: str, progress_text: str, radiology_text: str) -> str:
     """Generates a structured prompt for the LLM."""
@@ -88,13 +91,13 @@ def parse_llm_response(response: str) -> Dict[str, Any]:
         logging.error(f"Unexpected error parsing response: {e}")
         return default_response
 
-def process_llm_cases(llm_df: pd.DataFrame, api_key: str, delay_seconds: float = 0.2) -> pd.DataFrame:
+def process_llm_cases(llm_df: pd.DataFrame, api_key: str = None, delay_seconds: float = 0.2) -> pd.DataFrame:
     """
-    Process a clean LLM DataFrame through OpenAI API.
+    Process a clean LLM DataFrame through SecureLLM API.
 
     Args:
         llm_df: DataFrame with columns 'llm_caseID', 'formatted_progress_text', 'formatted_radiology_text'
-        api_key: OpenAI API key (hardcoded)
+        api_key: Deprecated. Kept for backward compatibility. SecureLLM uses VAULT_SECRET_KEY.
         delay_seconds: Delay between API calls to avoid rate limiting
 
     Returns:
@@ -104,9 +107,9 @@ def process_llm_cases(llm_df: pd.DataFrame, api_key: str, delay_seconds: float =
     # Setup logging
     logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
-    # Initialize OpenAI client
-    client = openai.OpenAI(api_key=api_key)
-    logging.info("OpenAI client initialized successfully")
+    # Initialize SecureLLM client
+    client = SecureLLMClient()
+    logging.info("SecureLLM client initialized successfully")
 
     # Create a copy of the dataframe
     result_df = llm_df.copy()
@@ -171,13 +174,13 @@ def process_llm_cases(llm_df: pd.DataFrame, api_key: str, delay_seconds: float =
 
     return result_df
 
-def run_llm_analysis(llm_df, api_key):
+def run_llm_analysis(llm_df, api_key: str = None):
     """
     Main function to run the LLM analysis on your DataFrame.
 
     Args:
         llm_df: DataFrame with columns 'llm_caseID', 'formatted_progress_text', 'formatted_radiology_text'
-        api_key: Your OpenAI API key
+        api_key: Deprecated. Kept for backward compatibility. SecureLLM uses VAULT_SECRET_KEY.
 
     Returns:
         DataFrame with LLM analysis results
@@ -187,7 +190,7 @@ def run_llm_analysis(llm_df, api_key):
     print(f"DataFrame columns: {list(llm_df.columns)}")
 
     # Process the cases
-    results_df = process_llm_cases(llm_df, api_key, delay_seconds=0.2)
+    results_df = process_llm_cases(llm_df, delay_seconds=0.2)
 
     # Show summary
     total_cases = len(results_df)
diff --git a/llm_query/securellm_adapter.py b/llm_query/securellm_adapter.py
new file mode 100644
index 0000000..e3c77d9
--- /dev/null
+++ b/llm_query/securellm_adapter.py
@@ -0,0 +1,315 @@
+# This source file is part of the ARPA-H CARE LLM project
+#
+# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md)
+#
+# SPDX-License-Identifier: MIT
+#
+
+"""
+This module handles calls to a large language model (LLM) using SecureLLM.
+
+It includes a function to send a prompt to the model and return the generated
+response. The default model is apim:gpt-4.1 (ModelConfig.DEFAULT_LLM_MODEL), with secure key management via VAULT_SECRET_KEY.
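+
+Example (illustrative sketch; requires VAULT_SECRET_KEY to be set):
+
+    from llm_query.securellm_adapter import query_llm
+
+    answer = query_llm("Summarize the key findings of this case.")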
+""" + +# Standard library imports +import os +import logging +import time +from typing import List, Dict, Any, Optional + +# Third party imports +from dotenv import load_dotenv +import requests + +logger = logging.getLogger(__name__) + +# Try to import securellm +try: + from securellm.providers.apim import SecureLLMClient as ApimClient + from securellm.providers.registry import SECURE_MODEL_REGISTRY + _SECURELLM_AVAILABLE = True +except ImportError: + logger.warning("securellm package not available, using fallback mode") + _SECURELLM_AVAILABLE = False + ApimClient = None + SECURE_MODEL_REGISTRY = {} + + +class ModelConfig: # pylint: disable=too-few-public-methods + """ + Configuration constants for the LLM interaction using SecureLLM. + """ + DEFAULT_LLM_MODEL = "apim:gpt-4.1" + VAULT_SECRET_KEY = "VAULT_SECRET_KEY" + DEFAULT_TIMEOUT = 120 # seconds (increased for Gemini and other slow models) + MAX_RETRIES = 3 + RETRY_DELAY = 5 # seconds + + +def _initialize_secure_client(): + """ + Initialize SecureLLM client using VAULT_SECRET_KEY from environment. + + Raises: + ValueError: If VAULT_SECRET_KEY is not set or securellm is not available + """ + load_dotenv() + + if not _SECURELLM_AVAILABLE: + raise ImportError("securellm package not installed. Install with: pip install -e .") + + vault_key = os.getenv(ModelConfig.VAULT_SECRET_KEY) + if not vault_key: + raise ValueError( + f"Vault private key not found in environment variable '{ModelConfig.VAULT_SECRET_KEY}'. " + "Please set the VAULT_SECRET_KEY environment variable with your private key." + ) + + logger.info("Initialized SecureLLM client with private key from VAULT_SECRET_KEY") + return vault_key + + +def get_llm_client_instance(model_name: Optional[str] = None, timeout: int = None): + """ + Get or create the SecureLLM client instance with custom timeout. + + Args: + model_name: Optional model name override. Defaults to ModelConfig.DEFAULT_LLM_MODEL. + timeout: Request timeout in seconds. Defaults to ModelConfig.DEFAULT_TIMEOUT. + + Returns: + SecureLLM client instance + """ + api_key = _initialize_secure_client() # Verify key is available and get it + model = model_name or ModelConfig.DEFAULT_LLM_MODEL + timeout = timeout or ModelConfig.DEFAULT_TIMEOUT + + # Get model config from registry + config = SECURE_MODEL_REGISTRY.get(model) + if not config: + raise ValueError(f"Unknown model name: {model}") + + # Create client with custom timeout + return ApimClient( + base_url=config["base_url"], + api_key=api_key, + model_name=model, + model_id=config["model_id"], + api_version=config.get("api_version"), + timeout=timeout + ) + + +def llm_call(prompt: str, temperature: float = 0.7, max_tokens: int = 10000) -> str: + """ + Sends a prompt to the default LLM and returns the generated response using SecureLLM. + + Args: + prompt (str): The user input to send to the model. + temperature (float): Sampling temperature for response variation. + max_tokens (int): Maximum number of tokens in the model's response. + + Returns: + str: The content of the LLM's response. 
+ + Raises: + ImportError: If securellm is not installed + ValueError: If VAULT_SECRET_KEY is not set + """ + messages = [{"role": "user", "content": prompt}] + response = llm_chat(messages, temperature=temperature, max_tokens=max_tokens) + if response is None: + raise RuntimeError("LLM call failed after retries") + return response + + +def llm_chat( + messages: List[Dict[str, str]], + temperature: float = 0.2, + max_tokens: int = 500, + model_name: Optional[str] = None, + timeout: int = None, + max_retries: int = None +) -> Optional[str]: + """ + Sends a chat conversation to the LLM and returns the generated response. + + This function supports system messages and multi-turn conversations, + making it suitable for the ENT surgical recommendation use case. + Includes retry logic for timeout and connection errors. + + Args: + messages: List of message dictionaries with 'role' and 'content' keys. + Roles can be 'system', 'user', or 'assistant'. + temperature: Sampling temperature for response variation. Default 0.2 for consistency. + max_tokens: Maximum number of tokens in the model's response. + model_name: Optional model name override. + timeout: Request timeout in seconds. Defaults to ModelConfig.DEFAULT_TIMEOUT. + max_retries: Maximum retry attempts for timeout errors. Defaults to ModelConfig.MAX_RETRIES. + + Returns: + str: The content of the LLM's response, or None if an error occurred. + + Raises: + ImportError: If securellm is not installed + ValueError: If VAULT_SECRET_KEY is not set + + Example: + >>> messages = [ + ... {"role": "system", "content": "You are an expert otolaryngologist."}, + ... {"role": "user", "content": "Should this patient have surgery?"} + ... ] + >>> response = llm_chat(messages) + """ + if not _SECURELLM_AVAILABLE: + raise ImportError("securellm package not installed. Install with: pip install -e .") + + model = model_name or ModelConfig.DEFAULT_LLM_MODEL + timeout = timeout or ModelConfig.DEFAULT_TIMEOUT + max_retries = max_retries if max_retries is not None else ModelConfig.MAX_RETRIES + + # Build generation config + config = { + "temperature": temperature, + "max_tokens": max_tokens + } + + last_error = None + for attempt in range(max_retries + 1): + try: + client = get_llm_client_instance(model, timeout=timeout) + + # SecureLLM client uses generate() method and returns parsed content directly + response = client.generate(messages, generation_config=config) + + return response.strip() if isinstance(response, str) else str(response).strip() + + except (TimeoutError, requests.exceptions.Timeout, requests.exceptions.ReadTimeout, + requests.exceptions.ConnectionError) as e: + last_error = e + if attempt < max_retries: + wait_time = ModelConfig.RETRY_DELAY * (attempt + 1) + logger.warning(f"Request failed (attempt {attempt + 1}/{max_retries + 1}): {type(e).__name__}. " + f"Retrying in {wait_time}s...") + time.sleep(wait_time) + else: + logger.error(f"Request failed after {max_retries + 1} attempts: {e}") + + except Exception as e: + logger.error(f"SecureLLM API error: {e}") + return None + + logger.error(f"All retry attempts exhausted. Last error: {last_error}") + return None + + +def query_llm( + prompt: str, + system_message: str = "You are an expert otolaryngologist. Provide a surgical recommendation in the requested JSON format.", + temperature: float = 0.2, + max_tokens: int = 500, + model_name: Optional[str] = None +) -> Optional[str]: + """ + Query the LLM with a system message and user prompt. 
+ + This is a convenience function that wraps llm_chat for simple query patterns + commonly used in the ENT analysis pipeline. + + Args: + prompt: The user prompt to send to the model. + system_message: The system message to set the model's behavior. + temperature: Sampling temperature. Default 0.2 for consistent medical recommendations. + max_tokens: Maximum response tokens. + model_name: Optional model name override. + + Returns: + str: The LLM's response content, or None if an error occurred. + """ + messages = [ + {"role": "system", "content": system_message}, + {"role": "user", "content": prompt} + ] + return llm_chat(messages, temperature=temperature, max_tokens=max_tokens, model_name=model_name) + + +class SecureLLMClient: + """ + A client wrapper for SecureLLM that provides an interface compatible with + the existing codebase patterns. + + This class can be used as a drop-in replacement where OpenAI client was used. + + Example: + >>> client = SecureLLMClient() + >>> response = client.query("What is the diagnosis?") + """ + + def __init__(self, model_name: Optional[str] = None): + """ + Initialize the SecureLLM client. + + Args: + model_name: Optional model name. Defaults to ModelConfig.DEFAULT_LLM_MODEL. + """ + self.model_name = model_name or ModelConfig.DEFAULT_LLM_MODEL + self._client = None + + @property + def client(self): + """Lazy initialization of the underlying client.""" + if self._client is None: + self._client = get_llm_client_instance(self.model_name) + return self._client + + def query( + self, + prompt: str, + system_message: str = "You are an expert otolaryngologist. Provide a surgical recommendation in the requested JSON format.", + temperature: float = 0.2, + max_tokens: int = 500 + ) -> Optional[str]: + """ + Query the LLM with a prompt. + + Args: + prompt: The user prompt. + system_message: System message for context. + temperature: Sampling temperature. + max_tokens: Maximum response tokens. + + Returns: + The LLM response or None on error. + """ + return query_llm( + prompt=prompt, + system_message=system_message, + temperature=temperature, + max_tokens=max_tokens, + model_name=self.model_name + ) + + def chat( + self, + messages: List[Dict[str, str]], + temperature: float = 0.2, + max_tokens: int = 500 + ) -> Optional[str]: + """ + Send a chat conversation to the LLM. + + Args: + messages: List of message dictionaries with 'role' and 'content'. + temperature: Sampling temperature. + max_tokens: Maximum response tokens. + + Returns: + The LLM response or None on error. + """ + return llm_chat( + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + model_name=self.model_name + ) diff --git a/llm_query/tools/llm/__init__.py b/llm_query/tools/llm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/llm_query/tools/llm/cache.py b/llm_query/tools/llm/cache.py new file mode 100644 index 0000000..ed844eb --- /dev/null +++ b/llm_query/tools/llm/cache.py @@ -0,0 +1,148 @@ +"""SQLite-based caching for LLM responses.""" + +import sqlite3 +import json +import hashlib +from typing import Optional, Dict, Any +from pathlib import Path + +class LLMCache: + def __init__(self, cache_dir: str = ".cache"): + """Initialize the cache. 
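+
+        Example (illustrative sketch; values are hypothetical):
+            >>> cache = LLMCache(cache_dir=".cache")
+            >>> cache.set("apim:gpt-4.1", "system prompt", "template", "context", "response")
+            >>> cache.get("apim:gpt-4.1", "system prompt", "template", "context")
+            'response'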
+ + Args: + cache_dir: Directory to store the SQLite database + """ + self.cache_dir = Path(cache_dir) + self.cache_dir.mkdir(parents=True, exist_ok=True) + self.db_path = self.cache_dir / "llm_cache.db" + self._init_db() + + def _init_db(self): + """Initialize the SQLite database with required tables.""" + with sqlite3.connect(self.db_path) as conn: + conn.execute(""" + CREATE TABLE IF NOT EXISTS llm_cache ( + cache_key TEXT PRIMARY KEY, + model_name TEXT NOT NULL, + system_prompt TEXT, + prompt_template TEXT, + context_hash TEXT, + response TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + # Add index for faster lookups + conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_cache_lookup + ON llm_cache(model_name, system_prompt, prompt_template, context_hash) + """) + + def _compute_hash(self, text: str) -> str: + """Compute SHA-256 hash of text.""" + return hashlib.sha256(text.encode()).hexdigest() + + def _get_cache_key(self, + model_name: str, + system_prompt: str, + prompt_template: str, + context: str) -> str: + """Generate a cache key from the input parameters.""" + components = [ + model_name, + system_prompt, + prompt_template, + self._compute_hash(context) + ] + return self._compute_hash("|".join(components)) + + def get(self, + model_name: str, + system_prompt: str, + prompt_template: str, + context: str) -> Optional[str]: + """Get cached response if it exists. + + Args: + model_name: Name of the LLM model + system_prompt: System prompt used + prompt_template: Template used for the prompt + context: Context provided to the model + + Returns: + Cached response if found, None otherwise + """ + cache_key = self._get_cache_key( + model_name, system_prompt, prompt_template, context + ) + + with sqlite3.connect(self.db_path) as conn: + cursor = conn.execute( + "SELECT response FROM llm_cache WHERE cache_key = ?", + (cache_key,) + ) + result = cursor.fetchone() + + return result[0] if result else None + + def set(self, + model_name: str, + system_prompt: str, + prompt_template: str, + context: str, + response: str): + """Cache a response. + + Args: + model_name: Name of the LLM model + system_prompt: System prompt used + prompt_template: Template used for the prompt + context: Context provided to the model + response: Response to cache + """ + cache_key = self._get_cache_key( + model_name, system_prompt, prompt_template, context + ) + context_hash = self._compute_hash(context) + + with sqlite3.connect(self.db_path) as conn: + conn.execute(""" + INSERT OR REPLACE INTO llm_cache + (cache_key, model_name, system_prompt, prompt_template, context_hash, response) + VALUES (?, ?, ?, ?, ?, ?) + """, (cache_key, model_name, system_prompt, prompt_template, context_hash, response)) + + def clear(self): + """Clear all cached responses.""" + with sqlite3.connect(self.db_path) as conn: + conn.execute("DELETE FROM llm_cache") + + def get_stats(self) -> Dict[str, Any]: + """Get cache statistics. 
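+
+        The returned dictionary (illustrative values) looks like:
+            {"total_entries": 42, "model_counts": {"apim:gpt-4.1": 42},
+             "oldest_entry": "2025-01-01 00:00:00", "newest_entry": "2025-06-01 12:00:00"}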
+
+        Returns:
+            Dictionary containing cache statistics
+        """
+        with sqlite3.connect(self.db_path) as conn:
+            cursor = conn.execute("SELECT COUNT(*) FROM llm_cache")
+            total_entries = cursor.fetchone()[0]
+
+            cursor = conn.execute("""
+                SELECT model_name, COUNT(*) as count
+                FROM llm_cache
+                GROUP BY model_name
+            """)
+            model_counts = dict(cursor.fetchall())
+
+            cursor = conn.execute("""
+                SELECT MIN(created_at) as oldest, MAX(created_at) as newest
+                FROM llm_cache
+            """)
+            oldest, newest = cursor.fetchone()
+
+            return {
+                "total_entries": total_entries,
+                "model_counts": model_counts,
+                "oldest_entry": oldest,
+                "newest_entry": newest
+            }
\ No newline at end of file
diff --git a/llm_query/tools/llm/secure_llm_client.py b/llm_query/tools/llm/secure_llm_client.py
new file mode 100644
index 0000000..0d610df
--- /dev/null
+++ b/llm_query/tools/llm/secure_llm_client.py
@@ -0,0 +1,150 @@
+# llm_query/tools/llm/secure_llm_client.py
+
+"""
+Secure LLM Client Interface
+
+This module provides the secure-llm Client interface (adapted from the MCP chat demo).
+It uses secure-llm's native API directly.
+"""
+
+import logging
+from typing import Dict, Any, List, Optional, Union
+
+logger = logging.getLogger(__name__)
+
+try:
+    from securellm import Client, get_available_models as _securellm_get_available_models
+    _SECURELLM_AVAILABLE = True
+except ImportError:
+    logger.error("securellm package not found. Install with: pip install secure-llm")
+    Client = None
+    _securellm_get_available_models = None
+    _SECURELLM_AVAILABLE = False
+
+
+# Global client instance (initialized once, reused)
+_global_client = None
+
+
+def get_client() -> Client:
+    """
+    Get or create the global secure-llm Client instance.
+
+    Returns:
+        secure-llm Client instance
+
+    Raises:
+        ImportError: If secure-llm is not installed
+    """
+    if not _SECURELLM_AVAILABLE:
+        raise ImportError("securellm package not installed. Install with: pip install secure-llm")
+
+    global _global_client
+    if _global_client is None:
+        _global_client = Client()
+        logger.info("Initialized secure-llm Client")
+
+    return _global_client
+
+
+def get_llm_client(model_name: Optional[str] = None):
+    """
+    Get secure-llm Client instance.
+
+    Args:
+        model_name: Optional model identifier (for logging/compatibility).
+                    Note: Model is specified per-request, not per-client.
+
+    Returns:
+        secure-llm Client instance
+    """
+    client = get_client()
+    if model_name:
+        logger.debug(f"Client ready for model: {model_name}")
+    return client
+
+
+def extract_response_content(response: Union[Dict[str, Any], Any]) -> str:
+    """
+    Extract content from secure-llm response.
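+
+    A typical secure-llm response (illustrative) has the shape:
+        {"choices": [{"message": {"content": "..."}}]}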
+ + Args: + response: Response from client.chat.completions.create() + + Returns: + Response content text + + Raises: + ValueError: If response format is unexpected + """ + # secure-llm returns dict format: response["choices"][0]["message"]["content"] + if isinstance(response, dict): + choices = response.get("choices", []) + if choices: + message = choices[0].get("message", {}) + content = message.get("content", "") + if content: + return content + + # Handle object-style response as fallback + try: + if hasattr(response, "choices") and response.choices: + message = response.choices[0].message + if hasattr(message, "content"): + return message.content + # Try dict access on object + return response.choices[0]["message"]["content"] + except (AttributeError, KeyError, IndexError): + pass + + # Try direct dict access on object + try: + return response["choices"][0]["message"]["content"] + except (KeyError, IndexError, TypeError): + pass + + raise ValueError(f"Unexpected response format: {type(response)}") + + +def get_available_models() -> List[str]: + """ + Get list of available models from secure-llm. + + Returns: + List of model identifiers from secure-llm's model registry + """ + if not _SECURELLM_AVAILABLE or _securellm_get_available_models is None: + logger.warning("securellm not available, returning empty model list") + return [] + + try: + # Use secure-llm's built-in function to get models from registry + models = _securellm_get_available_models() + logger.debug(f"Retrieved {len(models)} models from secure-llm") + return models + except Exception as e: + logger.error(f"Error getting available models from secure-llm: {e}") + return [] + + +def get_default_generation_config(overrides: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + """ + Get default generation configuration. 
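+
+    Example:
+        >>> get_default_generation_config({"temperature": 0.2})
+        {'temperature': 0.2, 'top_p': 1.0, 'max_tokens': 2048}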
+ + Args: + overrides: Dict of parameters to override defaults + + Returns: + Generation config dict + """ + defaults = { + "temperature": 0.7, + "top_p": 1.0, + "max_tokens": 2048, + } + + if overrides: + defaults.update(overrides) + + return defaults + diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..c0137c4 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,50 @@ +[project] +name = "ent-llm" +version = "0.1.0" +description = "LLM evaluation of ENT clinical cases" +readme = "README.md" +requires-python = ">=3.9" +dependencies = [ + "pandas>=2.0.0", + "numpy>=1.24.0", + "tqdm>=4.65.0", + "google-cloud-bigquery>=3.10.0", + "db-dtypes>=1.1.0", + "google-cloud-aiplatform>=1.25.0", + "matplotlib>=3.7.0", + "seaborn>=0.12.0", + "scikit-learn>=1.2.0", + "scipy>=1.10.0", + "reportlab>=4.0.0", + "aiohttp>=3.8.0", + "python-dotenv>=1.0.0", + "securellm @ git+ssh://git@github.com/VISTA-Stanford/secure-llm.git", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.0.0", + "black>=23.0.0", + "ruff>=0.0.270", +] + +[project.scripts] +ent-llm = "cli:main" +ent-llm-extract = "cli_extract:main" +ent-llm-ablation = "cli_ablation:main" + +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +include = ["batch_query*", "data_extraction*", "evaluation*", "llm_query*", "training*"] + +[tool.black] +line-length = 100 +target-version = ["py39", "py310", "py311"] + +[tool.ruff] +line-length = 100 +select = ["E", "F", "I"] +ignore = ["E501"] diff --git a/scripts/merge_files.py b/scripts/merge_files.py new file mode 100644 index 0000000..1df5baf --- /dev/null +++ b/scripts/merge_files.py @@ -0,0 +1,88 @@ +import pandas as pd + + +def merge_csv_files(file_list, output_file="merged_file.csv"): + """ + Merges multiple CSV files into a single sorted CSV file. + + Parameters: + file_list (list of str): List of file paths to the CSV files to be merged. + output_file (str): Path to the output CSV file (default: merged_file.csv). 
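+
+    Example (illustrative; file names are hypothetical):
+        >>> merge_csv_files(
+        ...     ["data/results/run1.csv", "data/results/run2.csv"],
+        ...     output_file="data/results/merged_results.csv",
+        ... )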
+ """ + # Initialize an empty list to hold dataframes + dataframes = [] + + # Read each CSV file and append the dataframe to the list + for file in file_list: + try: + df = pd.read_csv(file) + dataframes.append(df) + print(f"Loaded: {file} ({len(df)} rows)") + except Exception as e: + print(f"Error loading {file}: {e}") + continue + + if not dataframes: + print("No dataframes to merge!") + return + + # Concatenate all dataframes into a single dataframe + merged_df = pd.concat(dataframes, ignore_index=True) + print(f"Total merged rows: {len(merged_df)}") + + # Remove duplicates by llm_caseID + if "llm_caseID" in merged_df.columns: + duplicate_count = merged_df.duplicated(subset=["llm_caseID"], keep=False).sum() + + if duplicate_count > 0 and "decision" in merged_df.columns: + # Custom deduplication: keep row with non-NaN decision, if all NaN keep first + # Create a sort key: prioritize non-NaN decision values + merged_df["_has_decision"] = merged_df["decision"].notna().astype(int) + merged_df = merged_df.sort_values(by=["llm_caseID", "_has_decision"], ascending=[True, False]) + merged_df = merged_df.drop_duplicates(subset=["llm_caseID"], keep="first") + merged_df = merged_df.drop(columns=["_has_decision"]) + print(f"Removed duplicate entries (kept row with non-NaN decision, or first if all NaN)") + elif duplicate_count > 0: + # If decision column doesn't exist, keep first occurrence + merged_df = merged_df.drop_duplicates(subset=["llm_caseID"], keep="first") + print(f"Removed duplicate entries (decision column not found, kept first occurrence)") + + print(f"Rows after deduplication: {len(merged_df)}") + else: + print("Warning: llm_caseID column not found. Skipping deduplication.") + + # Sort by llm_caseID if it exists + if "llm_caseID" in merged_df.columns: + merged_df = merged_df.sort_values(by="llm_caseID").reset_index(drop=True) + print(f"Sorted by llm_caseID") + else: + print("Warning: llm_caseID column not found. 
Saving without sorting.") + + # Report NaN decision values in final dataframe + if "decision" in merged_df.columns: + nan_count = merged_df["decision"].isna().sum() + non_nan_count = merged_df["decision"].notna().sum() + print(f"Final decision column: {non_nan_count} non-NaN, {nan_count} NaN ({nan_count/len(merged_df)*100:.1f}%)") + + # Report llm_caseID values with NaN decisions + if nan_count > 0: + nan_case_ids = merged_df[merged_df["decision"].isna()]["llm_caseID"].tolist() + print(f"\nllm_caseID values with NaN decision:") + for case_id in nan_case_ids: + print(f" {case_id}") + + # Save the merged dataframe to a new CSV file + merged_df.to_csv(output_file, index=False) + print(f"Merged data saved to: {output_file}") + + +def main(): + # Example usage + files_to_merge = [ + "data/results/file1.csv", + "data/results/file2.csv", + ] + merge_csv_files(files_to_merge, output_file="data/results/merged_results.csv") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/training/training_LLM.py b/training/training_LLM.py index 50f2310..da9b265 100644 --- a/training/training_LLM.py +++ b/training/training_LLM.py @@ -1,8 +1,12 @@ # Formatting for LLM Training Input import pandas as pd import logging +import json +import time from typing import Dict, Any, Union, List, Tuple, Optional +from llm_query.securellm_adapter import query_llm, llm_chat, SecureLLMClient + def format_medical_data(progress_note: Union[Dict, None], radiology_reports: List[Dict]) -> Dict[str, Any]: """Format medical data from note dictionary and reports list into readable text.""" @@ -112,24 +116,25 @@ def training_create_llm_dataframe(processed_df: pd.DataFrame, num_training_rows: return training_df, test_df -def query_openai(prompt: str, client) -> str: - """Query GPT-4 for surgical decision based on input prompt.""" - try: - response = client.chat.completions.create( - model="gpt-4", - messages=[ - {"role": "system", "content": ( - "You are an expert otolaryngologist. " - "Provide a surgical recommendation in the requested JSON format." - )}, - {"role": "user", "content": prompt} - ], - temperature=0.2 - ) - return response.choices[0].message.content - except Exception as e: - logging.error(f"OpenAI API error: {e}") - return None +def query_openai(prompt: str, client=None) -> str: + """ + Query the LLM for surgical decision based on input prompt. + + This function now uses SecureLLM instead of direct OpenAI calls. + The client parameter is kept for backward compatibility but is ignored. + + Args: + prompt: The prompt to send to the LLM. + client: Deprecated. Kept for backward compatibility. + + Returns: + The LLM response content or None on error. + """ + return query_llm( + prompt=prompt, + system_message="You are an expert otolaryngologist. Provide a surgical recommendation in the requested JSON format.", + temperature=0.2 + ) def generate_training_examples(sample_cases: pd.DataFrame) -> str: """Generate training examples from sample cases.""" @@ -251,13 +256,13 @@ def parse_llm_response(response: str) -> Dict[str, Any]: logging.error(f"Unexpected error parsing response: {e}") return default_response -def process_llm_cases(llm_df: pd.DataFrame, api_key: str, delay_seconds: float = 1.0) -> pd.DataFrame: +def process_llm_cases(llm_df: pd.DataFrame, api_key: str = None, delay_seconds: float = 1.0) -> pd.DataFrame: """ - Process a clean LLM DataFrame through OpenAI API. + Process a clean LLM DataFrame through SecureLLM API. 
Args: llm_df: DataFrame with columns 'llm_caseID', 'formatted_progress_text', 'formatted_radiology_text' - api_key: OpenAI API key (hardcoded) + api_key: Deprecated. Kept for backward compatibility. SecureLLM uses VAULT_SECRET_KEY. delay_seconds: Delay between API calls to avoid rate limiting Returns: @@ -267,12 +272,12 @@ def process_llm_cases(llm_df: pd.DataFrame, api_key: str, delay_seconds: float = # Setup logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') - # Initialize OpenAI client + # Initialize SecureLLM client try: - client = openai.OpenAI(api_key=api_key) - logging.info("OpenAI client initialized successfully") + client = SecureLLMClient() + logging.info("SecureLLM client initialized successfully") except Exception as e: - logging.error(f"Failed to initialize OpenAI client: {e}") + logging.error(f"Failed to initialize SecureLLM client: {e}") raise # Create a copy of the dataframe @@ -326,13 +331,13 @@ def process_llm_cases(llm_df: pd.DataFrame, api_key: str, delay_seconds: float = return result_df -def run_llm_analysis(llm_df, api_key): +def run_llm_analysis(llm_df, api_key: str = None): """ Main function to run the LLM analysis on your DataFrame. Args: llm_df: DataFrame with columns 'llm_caseID', 'formatted_progress_text', 'formatted_radiology_text' - api_key: Your OpenAI API key + api_key: Deprecated. Kept for backward compatibility. SecureLLM uses VAULT_SECRET_KEY. Returns: DataFrame with LLM analysis results @@ -342,7 +347,7 @@ def run_llm_analysis(llm_df, api_key): print(f"DataFrame columns: {list(llm_df.columns)}") # Process the cases - results_df = process_llm_cases(llm_df, api_key, delay_seconds=1.0) + results_df = process_llm_cases(llm_df, delay_seconds=1.0) # Show summary total_cases = len(results_df) @@ -360,37 +365,36 @@ def run_llm_analysis(llm_df, api_key): return results_df -import pandas as pd -import logging -import json -import time -import openai -from typing import Dict, Any, Union, List, Tuple, Optional - - class ConversationalLLMAnalyzer: """ LLM analyzer that maintains conversation context to avoid repeating training examples. + Now uses SecureLLM instead of direct OpenAI calls. """ - def __init__(self, api_key: str, model: str = "gpt-4"): - self.client = openai.OpenAI(api_key=api_key) + def __init__(self, api_key: str = None, model: str = "gpt-4o"): + """ + Initialize the ConversationalLLMAnalyzer. + + Args: + api_key: Deprecated. Kept for backward compatibility. SecureLLM uses VAULT_SECRET_KEY. + model: Model name to use. Defaults to gpt-4o. 
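+
+        Example (illustrative; the model name assumes the SecureLLM registry):
+            >>> analyzer = ConversationalLLMAnalyzer(model="apim:gpt-4.1")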
+ """ self.model = model self.conversation_history = [] self.training_loaded = False def _make_api_call(self, messages: List[Dict], max_tokens: int = 500) -> str: - """Make API call with error handling.""" + """Make API call with error handling using SecureLLM.""" try: - response = self.client.chat.completions.create( - model=self.model, + response = llm_chat( messages=messages, temperature=0.2, - max_tokens=max_tokens + max_tokens=max_tokens, + model_name=self.model ) - return response.choices[0].message.content + return response except Exception as e: - logging.error(f"OpenAI API error: {e}") + logging.error(f"SecureLLM API error: {e}") return None def load_training_examples(self, training_df: pd.DataFrame) -> bool: @@ -559,14 +563,14 @@ def parse_llm_response(self, response: str) -> Dict[str, Any]: def process_llm_cases_conversational(test_df: pd.DataFrame, training_df: pd.DataFrame, - api_key: str, delay_seconds: float = 1.0) -> pd.DataFrame: + api_key: str = None, delay_seconds: float = 1.0) -> pd.DataFrame: """ Process cases using conversational context approach. Args: test_df: DataFrame with test cases training_df: DataFrame with training examples - api_key: OpenAI API key + api_key: Deprecated. Kept for backward compatibility. SecureLLM uses VAULT_SECRET_KEY. delay_seconds: Delay between API calls Returns: @@ -576,8 +580,8 @@ def process_llm_cases_conversational(test_df: pd.DataFrame, training_df: pd.Data # Setup logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') - # Initialize analyzer - analyzer = ConversationalLLMAnalyzer(api_key) + # Initialize analyzer with SecureLLM + analyzer = ConversationalLLMAnalyzer() # Load training examples logging.info(f"Loading {len(training_df)} training examples...") @@ -630,14 +634,14 @@ def process_llm_cases_conversational(test_df: pd.DataFrame, training_df: pd.Data return result_df -def run_llm_analysis_training(test_df: pd.DataFrame, training_df: pd.DataFrame, api_key: str): +def run_llm_analysis_training(test_df: pd.DataFrame, training_df: pd.DataFrame, api_key: str = None): """ Main function to run conversational LLM analysis. Args: test_df: DataFrame with test cases training_df: DataFrame with training examples - api_key: OpenAI API key + api_key: Deprecated. Kept for backward compatibility. SecureLLM uses VAULT_SECRET_KEY. Returns: DataFrame with LLM analysis results @@ -654,8 +658,8 @@ def run_llm_analysis_training(test_df: pd.DataFrame, training_df: pd.DataFrame, if missing_cols: raise ValueError(f"Training DataFrame missing required columns: {missing_cols}") - # Process cases - results_df = process_llm_cases_conversational(test_df, training_df, api_key, delay_seconds=1.0) + # Process cases with SecureLLM + results_df = process_llm_cases_conversational(test_df, training_df, delay_seconds=1.0) # Summary total_cases = len(results_df) @@ -673,4 +677,5 @@ def run_llm_analysis_training(test_df: pd.DataFrame, training_df: pd.DataFrame, return results_df -results = run_llm_analysis_training(test_df, training_df, api_key) \ No newline at end of file +# Example usage (uncomment to run): +# results = run_llm_analysis_training(test_df, training_df) \ No newline at end of file