diff --git a/.gitignore b/.gitignore
index 762e891..a27cd78 100644
--- a/.gitignore
+++ b/.gitignore
@@ -174,9 +174,9 @@ cython_debug/
.abstra/
# Visual Studio Code
-# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
+# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
-# and can be added to the global gitignore or merged into this file. However, if you prefer,
+# and can be added to the global gitignore or merged into this file. However, if you prefer,
# you could uncomment the following to ignore the enitre vscode folder
# .vscode/
@@ -193,4 +193,21 @@ cython_debug/
.cursorignore
.cursorindexingignore
-.DS_Store.DS_Store
+# macOS
+.DS_Store
+.DS_Store?
+._*
+.Spotlight-V100
+.Trashes
+.AppleDouble
+.LSOverride
+Icon?
+.DocumentRevisions-V100
+.fseventsd
+.TemporaryItems
+.VolumeIcon.icns
+.com.apple.timemachine.donotpresent
+
+**/data/**
+**/logs/**
+results_*
\ No newline at end of file
diff --git a/README.md b/README.md
index 44664b4..283a31c 100644
--- a/README.md
+++ b/README.md
@@ -1,14 +1,292 @@
-# ent-llm
+
-LLM evaluation of ENT clinical cases
+

+
+
+# LLM evaluation of ENT clinical cases for surgical recommendation
+
+
+
+
## Overview
-`ent-llm` is a LLM project evaluating otolaryngology clinical cases. The goal is to assist clinicians and researchers in analyzing patient scenarios, generating differential diagnoses, and evaluating treatment options with AI-powered tools.
+`ent-llm` evaluates otolaryngology (ENT) clinical cases using Large Language Models. It processes chronic sinusitis patient data from Stanford's medical records and generates surgical recommendations with confidence scores.
+
+## Installation
+
+### Create Virtual Environment
+
+```bash
+python -m venv .venv
+source .venv/bin/activate
+```
+
+### Install Dependencies
+
+```bash
+pip install -e .
+```
+
+**Required environment variables:**
+
+```bash
+export GOOGLE_APPLICATION_CREDENTIALS="/path/to/gcp_credentials.json" # BigQuery access
+export VAULT_SECRET_KEY="your_private_key" # SecureLLM API access
+```
+
+## Quick Start
+
+### Full Pipeline
+
+```bash
+# Step 1: Extract data from BigQuery
+ent-llm-extract --output cases.csv
+
+# Step 2: Run LLM analysis
+ent-llm --model apim:gpt-4.1 --input cases.csv --output results.csv
+```
+
+### Testing with Limited Data
+
+```bash
+# Extract only 100 patients for testing
+python cli_extract.py --output test_cases.csv --limit 100
+
+# Run analysis
+python cli.py --model apim:claude-3.7 --input test_cases.csv --output test_results.csv
+```
+
+## CLI Reference
+
+### `ent-llm-extract` - Data Extraction
+
+Extracts and preprocesses clinical data from BigQuery.
+
+```bash
+ent-llm-extract [OPTIONS]
+```
+
+| Option | Short | Description |
+|--------|-------|-------------|
+| `--output` | `-o` | Output CSV file (default: `llm_cases.csv`) |
+| `--batch-size` | `-b` | Patients per batch (default: 100) |
+| `--limit` | `-l` | Max patients to process (default: all) |
+| `--save-processed` | | Also save full processed dataframe |
+| `--processed-output` | | Path for processed data CSV |
+| `--checkpoint-dir` | | Directory for checkpoint files |
+| `--count-only` | | Show patient count and exit |
+| `--verbose` | `-v` | Enable verbose logging |
+
+**Examples:**
+
+```bash
+# Count total patients
+ent-llm-extract --count-only
+
+# Extract all data
+ent-llm-extract --output cases.csv
+
+# Extract with checkpoints (recommended for large datasets)
+ent-llm-extract --output cases.csv --checkpoint-dir ./checkpoints
+
+# Extract both LLM-ready and full processed data
+ent-llm-extract --output cases.csv --save-processed --processed-output full_data.csv
+```
+
+### `ent-llm` - LLM Analysis
+
+Runs surgical recommendation analysis using various LLM backends.
+
+```bash
+ent-llm [OPTIONS]
+```
+
+| Option | Short | Description |
+|--------|-------|-------------|
+| `--model` | `-m` | LLM model to use (default: `apim:gpt-4.1`) |
+| `--input` | `-i` | Input CSV file with case data |
+| `--output` | `-o` | Output CSV file for results |
+| `--delay` | `-d` | Delay between API calls (default: 0.2s) |
+| `--interactive` | `-I` | Interactive query mode |
+| `--list-models` | `-l` | List available models and exit |
+| `--verbose` | `-v` | Enable verbose logging |
+
+**Available models:**
+
+- `apim:gpt-4.1`
+- `apim:claude-3.7`
+- `apim:llama-3.3-70b`
+- `apim:gemini-2.5-pro-preview-05-06`
+
+**Examples:**
+
+```bash
+# List available models
+ent-llm --list-models
+
+# Run analysis with specific model
+ent-llm --model apim:claude-3.7 --input cases.csv --output results.csv
+
+# Interactive query mode
+ent-llm --model apim:gpt-4.1 --interactive
+
+# Demo mode (no input file)
+ent-llm --model apim:gpt-4.1
+```
+
+### `ent-llm-ablation` - Demographic Ablation Analysis
+
+Measures how demographic variables influence LLM surgical recommendations by selectively excluding demographics from prompts.
+
+```bash
+ent-llm-ablation [OPTIONS]
+```
+
+| Option | Short | Description |
+|--------|-------|-------------|
+| `--model` | `-m` | LLM model to use (default: `apim:gpt-4.1`) |
+| `--input` | `-i` | Input CSV file (clinical text + demographics) |
+| `--output-dir` | `-o` | Output directory for result CSVs (default: `./ablation_results`) |
+| `--baseline` | `-b` | Path to pre-computed baseline CSV (skip baseline run) |
+| `--experiments` | `-e` | Which to run: `all`, `individual`, `grouped`, `baseline-only` |
+| `--sample-size` | `-n` | Stratified sample size |
+| `--max-tokens` | | Filter out cases exceeding estimated token count |
+| `--ground-truth` | `-g` | Ground truth column name (default: `had_surgery`) |
+| `--delay` | `-d` | Delay between API calls (default: 0.2s) |
+| `--flush-interval` | `-f` | Incremental save interval (default: 10) |
+| `--no-resume` | | Start fresh instead of resuming |
+| `--list-experiments` | | List all experiments and exit |
+| `--verbose` | `-v` | Enable verbose logging |
+
+**Input CSV** requires the same clinical columns as `ent-llm` plus demographic columns: `legal_sex`, `age`, `race`, `ethnicity`, `recent_bmi`, `smoking_hx`, `alcohol_use`, `zipcode`, `insurance_type`, `occupation`. Optionally includes a ground truth column (e.g. `had_surgery`) for accuracy analysis.
+
+**Experiments** (16 total):
+- **Baseline** — all demographics included
+- **10 individual ablations** — exclude one variable at a time (`no_legal_sex`, `no_age`, etc.)
+- **5 grouped ablations** — exclude variable groups (`no_protected_attributes`, `no_socioeconomic`, `no_health_behaviors`, `no_physical_attributes`, `no_all_demographics`)
+
+**Examples:**
+
+```bash
+# List all experiments
+ent-llm-ablation --list-experiments
+
+# Run full ablation on a stratified sample of 500 cases
+ent-llm-ablation -m apim:gpt-4.1 -i cases_with_demographics.csv -n 500
+
+# Filter long cases and run only individual ablations
+ent-llm-ablation -m apim:claude-3.7 -i data.csv --max-tokens 5000 -e individual
+
+# Resume with a pre-computed baseline
+ent-llm-ablation -m apim:gpt-4.1 -i data.csv -b ./ablation_results/baseline_results.csv
+```
+
+**Output:** Each experiment saves to `{output_dir}/{experiment_name}_results.csv`. A summary comparing all experiments to baseline is saved to `{output_dir}/ablation_summary.csv` with flip rates, confidence changes, and (if ground truth provided) accuracy metrics.
+
+## Data Pipeline
+
+```
+┌─────────────────────────────────────────────────────────────────────────────┐
+│ DATA EXTRACTION │
+│ (ent-llm-extract CLI) │
+├─────────────────────────────────────────────────────────────────────────────┤
+│ │
+│ BigQuery (Stanford STARR) │
+│ │ │
+│ ├── clinical_note → Filter by ENT authors │
+│ ├── radiology_report → Filter CT sinus reports │
+│ └── procedures → Extract surgery CPT codes │
+│ │ │
+│ ▼ │
+│ Build patient records │
+│ │ │
+│ ▼ │
+│ Censor surgical planning text │
+│ │ │
+│ ▼ │
+│ Format for LLM input → cases.csv │
+│ │
+└─────────────────────────────────────────────────────────────────────────────┘
+ │
+ ▼
+┌─────────────────────────────────────────────────────────────────────────────┐
+│ LLM ANALYSIS │
+│ (ent-llm CLI) │
+├─────────────────────────────────────────────────────────────────────────────┤
+│ │
+│ cases.csv │
+│ │ │
+│ ▼ │
+│ SecureLLM API (GPT-4, Claude, Llama, Gemini) │
+│ │ │
+│ ▼ │
+│ Parse JSON responses │
+│ │ │
+│ ▼ │
+│ results.csv (decision, confidence, reasoning) │
+│ │
+└─────────────────────────────────────────────────────────────────────────────┘
+```
+
+## Data Source
+
+**Google BigQuery - Stanford STARR**
+
+| Setting | Value |
+|---------|-------|
+| Project | `som-nero-phi-roxanad-entllm` |
+| Datasets | Chronic sinusitis cohorts (2016-2025) |
+
+**Tables:**
+
+| Table | Description |
+|-------|-------------|
+| `clinical_note` | ENT clinical notes (progress notes, consults, H&P) |
+| `radiology_report` | CT sinus scan reports |
+| `procedures` | CPT codes for surgeries/endoscopies |
+
+## Input/Output Formats
+
+### Input CSV (from extraction)
+
+| Column | Description |
+|--------|-------------|
+| `llm_caseID` | Unique case identifier |
+| `formatted_progress_text` | Concatenated ENT clinical notes |
+| `formatted_radiology_text` | Concatenated radiology reports |
+
+### Output CSV (from analysis)
+
+| Column | Description |
+|--------|-------------|
+| `llm_caseID` | Case identifier |
+| `decision` | `Yes` or `No` for surgery recommendation |
+| `confidence` | 1-10 confidence score |
+| `reasoning` | 2-4 sentence explanation |
+| `api_response` | Raw LLM response |
+
+## Project Structure
+
+```
+ent-llm/
+├── cli.py # LLM analysis CLI
+├── cli_extract.py # Data extraction CLI
+├── cli_ablation.py # Demographic ablation CLI
+├── data_extraction/ # BigQuery data processing
+│ ├── config.py # Project settings, CPT codes
+│ ├── raw_data_parsing.py # Data extraction functions
+│ └── note_extraction.py # Note filtering and censoring
+├── llm_query/ # LLM integration
+│ ├── securellm_adapter.py # SecureLLM client wrapper
+│ ├── LLM_analysis.py # Analysis pipeline
+│ ├── ablation_analysis.py # Ablation experiment logic
+│ └── llm_input.py # Data formatting
+├── batch_query/ # Batch processing
+├── evaluation/ # Results evaluation
+└── training/ # Training workflows
+```
-## Features
+## License
-- Input structured or free-text ENT case data
-- Query and evaluate cases using state-of-the-art LLMs
-- Generate clinical summaries and differential diagnoses
-- Analyze diagnosis and surgical intervention accuracy
+MIT License - See LICENSE file for details.
diff --git a/batch_query/batch_query.py b/batch_query/batch_query.py
index aa8a9c8..c939769 100644
--- a/batch_query/batch_query.py
+++ b/batch_query/batch_query.py
@@ -164,18 +164,13 @@ def quick_status_check(final_llm_df: pd.DataFrame, save_directory: str = '.'):
import asyncio
import aiohttp
-import openai
from concurrent.futures import ThreadPoolExecutor, as_completed
import time
-import asyncio
-import aiohttp
-import openai
-from concurrent.futures import ThreadPoolExecutor, as_completed
-import time
+from llm_query.securellm_adapter import query_llm, SecureLLMClient
def parallel_process_llm_cases(llm_df: pd.DataFrame,
- api_key: str,
+ api_key: str = None,
max_workers: int = 5,
delay_seconds: float = 0.1) -> pd.DataFrame:
"""
@@ -183,7 +178,7 @@ def parallel_process_llm_cases(llm_df: pd.DataFrame,
Args:
llm_df: DataFrame with cases to process
- api_key: OpenAI API key
+ api_key: Deprecated. Kept for backward compatibility. SecureLLM uses VAULT_SECRET_KEY.
max_workers: Number of parallel workers (start with 5)
delay_seconds: Delay between requests (can be smaller with parallel)
"""
@@ -193,8 +188,6 @@ def process_single_case(row_data):
idx, row = row_data
try:
- client = openai.OpenAI(api_key=api_key)
-
case_id = row['llm_caseID']
# Generate prompt
@@ -204,8 +197,12 @@ def process_single_case(row_data):
radiology_text=row['formatted_radiology_text']
)
- # Query OpenAI
- response = query_openai(prompt, client)
+ # Query SecureLLM
+ response = query_llm(
+ prompt=prompt,
+ system_message="You are an expert otolaryngologist. Provide a surgical recommendation in the requested JSON format.",
+ temperature=0.2
+ )
result = {
'index': idx,
@@ -275,11 +272,17 @@ def process_single_case(row_data):
return result_df
def fast_batch_processing(final_llm_df: pd.DataFrame,
- api_key: str,
+ api_key: str = None,
batch_size: int = 200,
max_workers: int = 5) -> pd.DataFrame:
"""
Fast batch processing with parallel execution.
+
+ Args:
+ final_llm_df: DataFrame with cases to process
+ api_key: Deprecated. Kept for backward compatibility. SecureLLM uses VAULT_SECRET_KEY.
+ batch_size: Number of cases per batch
+ max_workers: Number of parallel workers
"""
# Load existing results
diff --git a/cli.py b/cli.py
new file mode 100644
index 0000000..5f2cfbe
--- /dev/null
+++ b/cli.py
@@ -0,0 +1,340 @@
+#!/usr/bin/env python3
+# This source file is part of the ARPA-H CARE LLM project
+#
+# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md)
+#
+# SPDX-License-Identifier: MIT
+#
+
+"""
+CLI entrypoint for ENT-LLM analysis.
+
+This module provides a command-line interface to run the LLM analysis
+with different model backends.
+"""
+
+import argparse
+import logging
+import sys
+from typing import Optional
+
+import pandas as pd
+
+from llm_query.securellm_adapter import ModelConfig, query_llm, SecureLLMClient
+from llm_query.LLM_analysis import (
+ generate_prompt,
+ parse_llm_response,
+ process_llm_cases,
+ run_llm_analysis,
+)
+
+# Available LLM models
+AVAILABLE_MODELS = [
+ "apim:llama-3.3-70b",
+ "apim:claude-3.7",
+ "apim:gpt-4.1",
+ "apim:gemini-2.5-pro-preview-05-06",
+]
+
+
+def setup_logging(verbose: bool = False) -> None:
+ """Configure logging based on verbosity level."""
+ level = logging.DEBUG if verbose else logging.INFO
+ logging.basicConfig(
+ level=level,
+ format="%(asctime)s - %(levelname)s - %(message)s",
+ handlers=[logging.StreamHandler(sys.stdout)],
+ )
+
+
+def run_single_query(model: str, prompt: str) -> Optional[str]:
+ """
+ Run a single LLM query with the specified model.
+
+ Args:
+ model: The model identifier to use.
+ prompt: The prompt to send to the LLM.
+
+ Returns:
+ The LLM response or None on error.
+ """
+ client = SecureLLMClient(model_name=model)
+ return client.query(prompt)
+
+
+def run_analysis_with_model(
+ model: str,
+ input_file: Optional[str] = None,
+ output_file: Optional[str] = None,
+ delay_seconds: float = 0.2,
+ flush_interval: int = 10,
+ no_resume: bool = False,
+ start_row: int = 0,
+) -> pd.DataFrame:
+ """
+ Run the full LLM analysis pipeline with a specified model.
+
+ Args:
+ model: The model identifier to use.
+ input_file: Path to input CSV file with case data.
+ output_file: Path to save results CSV (saved incrementally).
+ delay_seconds: Delay between API calls.
+ flush_interval: Number of cases to process before flushing to disk.
+ no_resume: If True, start fresh instead of resuming from existing output.
+ start_row: Start processing from this row index (0-based).
+
+ Returns:
+ DataFrame with analysis results.
+ """
+ logger = logging.getLogger(__name__)
+
+ # Update the default model configuration
+ ModelConfig.DEFAULT_LLM_MODEL = model
+ logger.info(f"Using model: {model}")
+
+ if input_file:
+ logger.info(f"Loading data from: {input_file}")
+ llm_df = pd.read_csv(input_file)
+
+ # Validate required columns
+ required_cols = ["llm_caseID", "formatted_progress_text", "formatted_radiology_text"]
+ missing_cols = [col for col in required_cols if col not in llm_df.columns]
+ if missing_cols:
+ raise ValueError(f"Input file missing required columns: {missing_cols}")
+
+ # Apply start_row filter
+ if start_row > 0:
+ if start_row >= len(llm_df):
+ raise ValueError(f"start_row ({start_row}) is >= total rows ({len(llm_df)})")
+ logger.info(f"Starting from row {start_row} (skipping first {start_row} rows)")
+ llm_df = llm_df.iloc[start_row:].reset_index(drop=True)
+
+ # Run analysis with incremental saving
+ results_df = run_llm_analysis(
+ llm_df,
+ output_file=output_file,
+ flush_interval=flush_interval,
+ resume=not no_resume
+ )
+
+ if output_file:
+ logger.info(f"Results saved to: {output_file}")
+
+ return results_df
+ else:
+ logger.warning("No input file provided. Running in demo mode.")
+ # Demo mode: create a simple test case
+ demo_data = {
+ "llm_caseID": ["DEMO_001"],
+ "formatted_progress_text": [
+ "Patient presents with chronic rhinosinusitis refractory to medical management "
+ "including multiple courses of antibiotics and intranasal corticosteroids. "
+ "Symptoms include persistent nasal congestion, facial pressure, and purulent discharge "
+ "for over 12 weeks. Previous treatments have failed to provide lasting relief."
+ ],
+ "formatted_radiology_text": [
+ "CT Sinuses: Mucosal thickening in bilateral maxillary sinuses with partial "
+ "opacification. Ostiomeatal complex obstruction bilaterally. No bony erosion."
+ ],
+ }
+ demo_df = pd.DataFrame(demo_data)
+
+ logger.info("Processing demo case...")
+ results_df = run_llm_analysis(demo_df)
+
+ print("\n=== Demo Results ===")
+ for _, row in results_df.iterrows():
+ print(f"Case ID: {row['llm_caseID']}")
+ print(f"Decision: {row['decision']}")
+ print(f"Confidence: {row['confidence']}")
+ print(f"Reasoning: {row['reasoning']}")
+
+ return results_df
+
+
+def interactive_query(model: str) -> None:
+ """
+ Run an interactive query session with the specified model.
+
+ Args:
+ model: The model identifier to use.
+ """
+ logger = logging.getLogger(__name__)
+ logger.info(f"Starting interactive session with model: {model}")
+ print(f"\nInteractive query mode with {model}")
+ print("Type 'quit' or 'exit' to end the session.\n")
+
+ client = SecureLLMClient(model_name=model)
+
+ while True:
+ try:
+ prompt = input("You: ").strip()
+ if prompt.lower() in ("quit", "exit"):
+ print("Goodbye!")
+ break
+ if not prompt:
+ continue
+
+ response = client.query(
+ prompt,
+ system_message="You are an expert otolaryngologist. Answer questions about ENT cases.",
+ )
+ if response:
+ print(f"\nAssistant: {response}\n")
+ else:
+ print("\n[No response received]\n")
+ except KeyboardInterrupt:
+ print("\nGoodbye!")
+ break
+ except Exception as e:
+ logger.error(f"Error: {e}")
+ print(f"\n[Error: {e}]\n")
+
+
+def main() -> int:
+ """Main CLI entrypoint."""
+ parser = argparse.ArgumentParser(
+ description="ENT-LLM Analysis CLI - Run clinical case analysis with various LLM backends",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ epilog="""
+Examples:
+ # Run analysis with default model (demo mode)
+ python -m cli --model apim:gpt-4.1
+
+ # Run analysis on a CSV file (saves incrementally every 10 cases)
+ python -m cli --model apim:claude-3.7 --input cases.csv --output results.csv
+
+ # Run with custom flush interval (save every 5 cases)
+ python -m cli --model apim:gpt-4.1 -i cases.csv -o results.csv --flush-interval 5
+
+ # Start from a specific row (e.g., skip first 100 rows)
+ python -m cli --model apim:gpt-4.1 -i cases.csv -o results.csv --start-row 100
+
+ # Start fresh (don't resume from existing output)
+ python -m cli --model apim:gpt-4.1 -i cases.csv -o results.csv --no-resume
+
+ # Interactive query mode
+ python -m cli --model apim:llama-3.3-70b --interactive
+
+ # List available models
+ python -m cli --list-models
+ """,
+ )
+
+ parser.add_argument(
+ "--model",
+ "-m",
+ type=str,
+ choices=AVAILABLE_MODELS,
+ default="apim:gpt-4.1",
+ help="LLM model to use for analysis (default: apim:gpt-4.1)",
+ )
+
+ parser.add_argument(
+ "--input",
+ "-i",
+ type=str,
+ help="Input CSV file with case data (columns: llm_caseID, formatted_progress_text, formatted_radiology_text)",
+ )
+
+ parser.add_argument(
+ "--output",
+ "-o",
+ type=str,
+ help="Output CSV file for results",
+ )
+
+ parser.add_argument(
+ "--delay",
+ "-d",
+ type=float,
+ default=0.2,
+ help="Delay in seconds between API calls (default: 0.2)",
+ )
+
+ parser.add_argument(
+ "--flush-interval",
+ "-f",
+ type=int,
+ default=10,
+ help="Number of cases to process before flushing to disk (default: 10)",
+ )
+
+ parser.add_argument(
+ "--no-resume",
+ action="store_true",
+ help="Start fresh instead of resuming from existing output file",
+ )
+
+ parser.add_argument(
+ "--start-row",
+ "-s",
+ type=int,
+ default=0,
+ help="Start processing from this row index (0-based, default: 0)",
+ )
+
+ parser.add_argument(
+ "--interactive",
+ "-I",
+ action="store_true",
+ help="Run in interactive query mode",
+ )
+
+ parser.add_argument(
+ "--list-models",
+ "-l",
+ action="store_true",
+ help="List available models and exit",
+ )
+
+ parser.add_argument(
+ "--verbose",
+ "-v",
+ action="store_true",
+ help="Enable verbose logging",
+ )
+
+ args = parser.parse_args()
+
+ # Handle --list-models
+ if args.list_models:
+ print("Available models:")
+ for model in AVAILABLE_MODELS:
+ print(f" - {model}")
+ return 0
+
+ # Setup logging
+ setup_logging(args.verbose)
+ logger = logging.getLogger(__name__)
+
+ logger.info(f"ENT-LLM Analysis CLI")
+ logger.info(f"Selected model: {args.model}")
+
+ try:
+ if args.interactive:
+ interactive_query(args.model)
+ else:
+ run_analysis_with_model(
+ model=args.model,
+ input_file=args.input,
+ output_file=args.output,
+ delay_seconds=args.delay,
+ flush_interval=args.flush_interval,
+ no_resume=args.no_resume,
+ start_row=args.start_row,
+ )
+ return 0
+ except KeyboardInterrupt:
+ print("\nInterrupted by user")
+ return 130
+ except Exception as e:
+ logger.error(f"Error: {e}")
+ if args.verbose:
+ import traceback
+ traceback.print_exc()
+ return 1
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/cli_ablation.py b/cli_ablation.py
new file mode 100644
index 0000000..e3082bf
--- /dev/null
+++ b/cli_ablation.py
@@ -0,0 +1,353 @@
+#!/usr/bin/env python3
+"""
+CLI entrypoint for ENT-LLM demographic ablation analysis.
+
+Runs ablation experiments that selectively exclude demographic variables
+from LLM prompts to measure their influence on surgical recommendations.
+"""
+
+import argparse
+import logging
+import os
+import shutil
+import sys
+
+import pandas as pd
+
+from llm_query.ablation_analysis import (
+ DEMOGRAPHIC_GROUPS,
+ DEMOGRAPHIC_VARS,
+ analyze_ablation_results,
+ filter_long_cases,
+ run_ablation_experiment,
+ stratified_sample,
+)
+from llm_query.securellm_adapter import ModelConfig
+
+# Available LLM models (same as cli.py)
+AVAILABLE_MODELS = [
+ "apim:llama-3.3-70b",
+ "apim:claude-3.7",
+ "apim:gpt-4.1",
+ "apim:gemini-2.5-pro-preview-05-06",
+]
+
+
+def setup_logging(verbose: bool = False) -> None:
+ """Configure logging based on verbosity level."""
+ level = logging.DEBUG if verbose else logging.INFO
+ logging.basicConfig(
+ level=level,
+ format="%(asctime)s - %(levelname)s - %(message)s",
+ handlers=[logging.StreamHandler(sys.stdout)],
+ )
+
+
+def list_experiments() -> None:
+ """List all ablation experiments and exit."""
+ print("Ablation experiments:\n")
+ print("Baseline:")
+ print(" baseline — all demographics included\n")
+ print("Individual ablations (exclude one variable):")
+ for var in DEMOGRAPHIC_VARS:
+ print(f" no_{var}")
+ print("\nGrouped ablations (exclude a group):")
+ for group_name, group_vars in DEMOGRAPHIC_GROUPS.items():
+ print(f" no_{group_name} — excludes: {', '.join(group_vars)}")
+
+ total = 1 + len(DEMOGRAPHIC_VARS) + len(DEMOGRAPHIC_GROUPS)
+ print(f"\nTotal experiments: {total}")
+
+
+def validate_input(df: pd.DataFrame) -> None:
+ """Validate required columns exist in input DataFrame."""
+ clinical_cols = ["llm_caseID", "formatted_progress_text", "formatted_radiology_text"]
+ missing_clinical = [c for c in clinical_cols if c not in df.columns]
+ if missing_clinical:
+ raise ValueError(f"Input CSV missing required clinical columns: {missing_clinical}")
+
+ demographic_present = [v for v in DEMOGRAPHIC_VARS if v in df.columns]
+ if not demographic_present:
+ raise ValueError(
+ f"Input CSV has no demographic columns. Expected at least some of: {DEMOGRAPHIC_VARS}"
+ )
+
+ missing_demo = [v for v in DEMOGRAPHIC_VARS if v not in df.columns]
+ if missing_demo:
+ logging.getLogger(__name__).warning(
+ f"Input CSV missing some demographic columns (will be skipped): {missing_demo}"
+ )
+
+
+def build_experiment_list(experiments_mode: str):
+ """Build list of (experiment_name, exclude_vars) tuples.
+
+ Args:
+ experiments_mode: One of 'all', 'individual', 'grouped', 'baseline-only'.
+
+ Returns:
+ List of (experiment_name, exclude_vars) tuples.
+ """
+ experiments = []
+
+ if experiments_mode in ("all", "baseline-only"):
+ experiments.append(("baseline", None))
+
+ if experiments_mode in ("all", "individual"):
+ for var in DEMOGRAPHIC_VARS:
+ experiments.append((f"no_{var}", [var]))
+
+ if experiments_mode in ("all", "grouped"):
+ for group_name, group_vars in DEMOGRAPHIC_GROUPS.items():
+ experiments.append((f"no_{group_name}", list(group_vars)))
+
+ return experiments
+
+
+def print_summary_table(summary_df: pd.DataFrame) -> None:
+ """Print formatted summary table to stdout."""
+ if summary_df.empty:
+ print("No ablation results to summarize.")
+ return
+
+ print(f"\n{'='*80}")
+ print("ABLATION ANALYSIS SUMMARY")
+ print(f"{'='*80}")
+
+ cols = [
+ "experiment",
+ "experiment_type",
+ "flip_rate_%",
+ "yes_to_no",
+ "no_to_yes",
+ "avg_confidence_change",
+ ]
+ if "accuracy_change_%" in summary_df.columns:
+ cols.append("accuracy_change_%")
+
+ display_cols = [c for c in cols if c in summary_df.columns]
+ print(summary_df[display_cols].to_string(index=False))
+ print(f"{'='*80}")
+
+
+def main() -> int:
+ """Main CLI entrypoint."""
+ parser = argparse.ArgumentParser(
+ description="ENT-LLM Ablation Analysis CLI - demographic ablation study",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ epilog="""
+Examples:
+ # List all experiments
+ python cli_ablation.py --list-experiments
+
+ # Run on small sample
+ python cli_ablation.py -m apim:gpt-4.1 -i data.csv -o ./ablation_test -n 10
+
+ # Run with pre-computed baseline
+ python cli_ablation.py -m apim:gpt-4.1 -i data.csv -o ./ablation_results \\
+ -b ./ablation_results/baseline_results.csv
+
+ # Run only individual ablations
+ python cli_ablation.py -m apim:gpt-4.1 -i data.csv -e individual
+
+ # Run only baseline
+ python cli_ablation.py -m apim:gpt-4.1 -i data.csv -e baseline-only
+ """,
+ )
+
+ parser.add_argument(
+ "--model",
+ "-m",
+ type=str,
+ choices=AVAILABLE_MODELS,
+ default="apim:gpt-4.1",
+ help="LLM model to use (default: apim:gpt-4.1)",
+ )
+
+ parser.add_argument(
+ "--input",
+ "-i",
+ type=str,
+ help="Input CSV file (clinical text + demographics)",
+ )
+
+ parser.add_argument(
+ "--output-dir",
+ "-o",
+ type=str,
+ default="./ablation_results",
+ help="Output directory for result CSVs (default: ./ablation_results)",
+ )
+
+ parser.add_argument(
+ "--baseline",
+ "-b",
+ type=str,
+ help="Path to pre-computed baseline CSV (skip baseline run)",
+ )
+
+ parser.add_argument(
+ "--experiments",
+ "-e",
+ type=str,
+ choices=["all", "individual", "grouped", "baseline-only"],
+ default="all",
+ help="Which experiments to run (default: all)",
+ )
+
+ parser.add_argument(
+ "--sample-size",
+ "-n",
+ type=int,
+ help="Optional sample size (stratified sampling)",
+ )
+
+ parser.add_argument(
+ "--max-tokens",
+ type=int,
+ help="Filter out cases exceeding this estimated token count (e.g. 5000)",
+ )
+
+ parser.add_argument(
+ "--ground-truth",
+ "-g",
+ type=str,
+ default="had_surgery",
+ help="Column name for ground truth (default: had_surgery)",
+ )
+
+ parser.add_argument(
+ "--delay",
+ "-d",
+ type=float,
+ default=0.2,
+ help="Delay between API calls in seconds (default: 0.2)",
+ )
+
+ parser.add_argument(
+ "--flush-interval",
+ "-f",
+ type=int,
+ default=10,
+ help="Flush interval for incremental saving (default: 10)",
+ )
+
+ parser.add_argument(
+ "--no-resume",
+ action="store_true",
+ help="Start fresh instead of resuming from existing output",
+ )
+
+ parser.add_argument(
+ "--list-experiments",
+ action="store_true",
+ help="List all experiments and exit",
+ )
+
+ parser.add_argument(
+ "--verbose",
+ "-v",
+ action="store_true",
+ help="Enable verbose logging",
+ )
+
+ args = parser.parse_args()
+
+ # Handle --list-experiments
+ if args.list_experiments:
+ list_experiments()
+ return 0
+
+ # --input is required for all other modes
+ if not args.input:
+ parser.error("--input / -i is required (use --list-experiments to see experiments)")
+
+ # Setup logging
+ setup_logging(args.verbose)
+ logger = logging.getLogger(__name__)
+
+ # Set model
+ ModelConfig.DEFAULT_LLM_MODEL = args.model
+ logger.info(f"ENT-LLM Ablation Analysis CLI")
+ logger.info(f"Model: {args.model}")
+
+ # Load input
+ logger.info(f"Loading input: {args.input}")
+ df = pd.read_csv(args.input)
+ validate_input(df)
+ logger.info(f"Loaded {len(df)} cases")
+
+ # Optional token-length filter
+ if args.max_tokens:
+ before = len(df)
+ df = filter_long_cases(df, max_tokens=args.max_tokens)
+ logger.info(f"Token filter ({args.max_tokens}): {before} -> {len(df)} cases")
+
+ # Stratified sample
+ if args.sample_size:
+ logger.info(f"Creating stratified sample of {args.sample_size} cases")
+ df = stratified_sample(df, args.sample_size)
+ logger.info(f"Sample size after stratification: {len(df)}")
+
+ # Create output directory
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ # Build experiment list
+ experiments = build_experiment_list(args.experiments)
+ logger.info(f"Planned experiments: {len(experiments)}")
+
+ # Handle pre-computed baseline
+ if args.baseline:
+ if not os.path.exists(args.baseline):
+ logger.error(f"Baseline file not found: {args.baseline}")
+ return 1
+ baseline_dest = os.path.join(args.output_dir, "baseline_results.csv")
+ if os.path.abspath(args.baseline) != os.path.abspath(baseline_dest):
+ shutil.copy2(args.baseline, baseline_dest)
+ logger.info(f"Copied baseline to {baseline_dest}")
+ # Remove baseline from experiments list since it's pre-computed
+ experiments = [(name, evars) for name, evars in experiments if name != "baseline"]
+ logger.info(f"Using pre-computed baseline, {len(experiments)} experiments remaining")
+
+ # Run experiments
+ try:
+ for exp_name, exclude_vars in experiments:
+ output_file = os.path.join(args.output_dir, f"{exp_name}_results.csv")
+ logger.info(f"Starting experiment: {exp_name}")
+ run_ablation_experiment(
+ df=df,
+ exclude_vars=exclude_vars,
+ experiment_name=exp_name,
+ output_file=output_file,
+ model_name=args.model,
+ delay_seconds=args.delay,
+ flush_interval=args.flush_interval,
+ resume=not args.no_resume,
+ )
+ except KeyboardInterrupt:
+ print("\nInterrupted by user. Progress has been saved incrementally.")
+ return 130
+
+ # Analyze results
+ logger.info("Analyzing results...")
+ ground_truth_df = None
+ if args.ground_truth in df.columns:
+ ground_truth_df = df
+
+ summary_df = analyze_ablation_results(
+ results_dir=args.output_dir,
+ ground_truth_df=ground_truth_df,
+ ground_truth_col=args.ground_truth,
+ )
+
+ summary_path = os.path.join(args.output_dir, "ablation_summary.csv")
+ summary_df.to_csv(summary_path, index=False)
+ logger.info(f"Summary saved to {summary_path}")
+
+ print_summary_table(summary_df)
+
+ return 0
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/cli_extract.py b/cli_extract.py
new file mode 100644
index 0000000..66a73a4
--- /dev/null
+++ b/cli_extract.py
@@ -0,0 +1,578 @@
+#!/usr/bin/env python3
+# This source file is part of the ARPA-H CARE LLM project
+#
+# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md)
+#
+# SPDX-License-Identifier: MIT
+#
+
+"""
+CLI entrypoint for BigQuery data extraction.
+
+This module provides a command-line interface to extract and preprocess
+ENT clinical data from BigQuery for LLM analysis.
+"""
+
+import argparse
+import gc
+import logging
+import sys
+import time
+from typing import Dict, Iterator, List, Optional, Tuple
+
+import pandas as pd
+
+from google.cloud import bigquery
+
+from data_extraction.config import (
+ PROJECT_ID,
+ DATASET_IDS,
+ DATA_TABLES,
+ CLINICAL_NOTE_TYPES,
+ CLINICAL_NOTE_TITLES,
+ RADIOLOGY_REPORT_TYPE,
+ RADIOLOGY_REPORT_TITLE,
+ SURGERY_CPT_CODES,
+ DIAGNOSTIC_ENDOSCOPY_CPT_CODES,
+)
+from data_extraction.raw_data_parsing import (
+ extract_ent_notes,
+ extract_radiology_reports,
+ procedures_df,
+ build_patient_df,
+)
+from data_extraction.note_extraction import (
+ add_last_progress_note,
+ recursive_censor_notes,
+)
+from llm_query.llm_input import create_llm_dataframe
+
+
+def setup_logging(verbose: bool = False) -> logging.Logger:
+    """Configure logging based on verbosity level.
+
+    Args:
+        verbose: If True, log at DEBUG level; otherwise INFO.
+
+    Returns:
+        A module-level logger; root logging is routed to stdout.
+    """
+    level = logging.DEBUG if verbose else logging.INFO
+    logging.basicConfig(
+        level=level,
+        format="%(asctime)s - %(levelname)s - %(message)s",
+        handlers=[logging.StreamHandler(sys.stdout)],
+    )
+    return logging.getLogger(__name__)
+
+
+class BatchProcessor:
+    """Handles batch processing of patient data from BigQuery.
+
+    Streams distinct patient IDs across all configured datasets in
+    fixed-size batches, then pulls per-table rows for each batch with
+    retry/backoff on transient query failures.
+    """
+
+    def __init__(
+        self,
+        project_id: str,
+        dataset_ids: List[str],
+        batch_size: int = 100,
+        max_retries: int = 3,
+    ):
+        # One client instance is reused for every query this object issues.
+        self.client = bigquery.Client(project=project_id)
+        self.project_id = project_id
+        self.dataset_ids = dataset_ids
+        self.batch_size = batch_size
+        self.max_retries = max_retries
+        # Column name identifying patients in every table we query.
+        self.patient_identifier = "patient_id"
+        self.logger = logging.getLogger(__name__)
+
+    def get_total_patient_count(self) -> int:
+        """Get total number of patients with clinical notes.
+
+        Returns:
+            Count of DISTINCT patient IDs across all configured datasets.
+        """
+        # UNION ALL across datasets; the DISTINCT below de-duplicates
+        # patients that appear in more than one dataset.
+        notes_union = "\nUNION ALL\n".join(
+            f"SELECT {self.patient_identifier} FROM `{self.project_id}.{ds}.clinical_note`"
+            for ds in self.dataset_ids
+        )
+
+        count_query = f"""
+        WITH all_notes AS (
+            SELECT DISTINCT {self.patient_identifier} FROM ({notes_union})
+        )
+        SELECT COUNT(*) as total_patients
+        FROM all_notes
+        """
+
+        result = self.client.query(count_query).to_dataframe()
+        return int(result["total_patients"].iloc[0])
+
+    def get_patient_batches(self, limit: Optional[int] = None) -> Iterator[List[str]]:
+        """Generator that yields batches of patient IDs.
+
+        Args:
+            limit: Optional cap on the total number of patient IDs yielded
+                across all batches (None yields everything).
+
+        Yields:
+            Lists of patient IDs, ordered, each at most ``batch_size`` long
+            (the final batch may be shorter when ``limit`` is set).
+        """
+        notes_union = "\nUNION ALL\n".join(
+            f"SELECT {self.patient_identifier} FROM `{self.project_id}.{ds}.clinical_note`"
+            for ds in self.dataset_ids
+        )
+
+        # Deterministic ORDER BY makes LIMIT/OFFSET paging stable.
+        all_patients_query = f"""
+        WITH all_notes AS (
+            SELECT DISTINCT {self.patient_identifier} FROM ({notes_union})
+        )
+        SELECT {self.patient_identifier}
+        FROM all_notes
+        ORDER BY {self.patient_identifier}
+        """
+
+        offset = 0
+        total_yielded = 0
+
+        while True:
+            # Adjust batch size if we're near the limit
+            current_batch_size = self.batch_size
+            if limit is not None:
+                remaining = limit - total_yielded
+                if remaining <= 0:
+                    break
+                current_batch_size = min(self.batch_size, remaining)
+
+            batch_query = f"""
+            {all_patients_query}
+            LIMIT {current_batch_size} OFFSET {offset}
+            """
+
+            batch_df = self.client.query(batch_query).to_dataframe()
+
+            if batch_df.empty:
+                break
+
+            patient_ids = batch_df[self.patient_identifier].tolist()
+            yield patient_ids
+
+            total_yielded += len(patient_ids)
+            # NOTE: offset advances by the full batch_size. When the limit
+            # shrinks current_batch_size this over-advances, but the limit
+            # check above breaks the loop first, so no patients are skipped.
+            offset += self.batch_size
+
+            del batch_df
+            gc.collect()
+
+    def extract_batch_data(
+        self, patient_ids: List[str], table_names: List[str]
+    ) -> Dict[str, pd.DataFrame]:
+        """Extract data for a batch of patients.
+
+        Args:
+            patient_ids: Patient IDs to fetch rows for.
+            table_names: Table names to query in each dataset.
+
+        Returns:
+            Mapping of table name -> dataframe of matching rows; a table
+            that fails all retries maps to an empty dataframe.
+        """
+        batch_data = {}
+        # NOTE(review): IDs are inlined into the SQL string. They originate
+        # from BigQuery itself (get_patient_batches), not user input, but
+        # query parameters (ArrayQueryParameter) would be safer — confirm.
+        id_list_str = ", ".join(f"'{pid}'" for pid in patient_ids)
+
+        self.logger.info(f"Extracting data for {len(patient_ids)} patients...")
+
+        for table in table_names:
+            self.logger.debug(f"Loading table: {table}")
+
+            # Retry with exponential backoff (1s, 2s, 4s, ...) on failure.
+            for attempt in range(self.max_retries):
+                try:
+                    union_query = "\nUNION ALL\n".join(
+                        f"SELECT * FROM `{self.project_id}.{ds}.{table}`"
+                        for ds in self.dataset_ids
+                    )
+
+                    full_query = f"""
+                    SELECT * FROM ({union_query})
+                    WHERE {self.patient_identifier} IN ({id_list_str})
+                    """
+
+                    job_config = bigquery.QueryJobConfig(
+                        use_query_cache=True, use_legacy_sql=False
+                    )
+
+                    df = self.client.query(full_query, job_config=job_config).to_dataframe()
+                    batch_data[table] = df
+                    self.logger.debug(f"  {table}: {df.shape[0]} rows loaded")
+                    break
+
+                except Exception as e:
+                    self.logger.warning(f"  Attempt {attempt + 1} failed for '{table}': {e}")
+                    if attempt == self.max_retries - 1:
+                        self.logger.error(f"  Failed to load '{table}' after {self.max_retries} attempts")
+                        # Fail soft: downstream code treats empty frames as "no data".
+                        batch_data[table] = pd.DataFrame()
+                    else:
+                        time.sleep(2**attempt)
+
+        return batch_data
+
+
+def process_batch(
+    batch_data: Dict[str, pd.DataFrame],
+    patient_ids: List[str],
+    global_case_id_counter: int,
+) -> Tuple[pd.DataFrame, pd.DataFrame, int]:
+    """Process a single batch of patient data.
+
+    Args:
+        batch_data: Mapping of table name -> dataframe for this batch.
+        patient_ids: Patient IDs in this batch (used for logging only).
+        global_case_id_counter: Next sequential case ID to assign.
+
+    Returns:
+        Tuple of (llm_df, processed_df, updated_counter). On error or when
+        the batch yields no relevant data, empty dataframes and the
+        unchanged counter are returned so the caller can continue.
+    """
+    logger = logging.getLogger(__name__)
+
+    try:
+        logger.info(f"Processing batch of {len(patient_ids)} patients...")
+
+        # Extract ENT notes
+        if "clinical_note" in batch_data and not batch_data["clinical_note"].empty:
+            ent_notes = extract_ent_notes(
+                batch_data["clinical_note"],
+                CLINICAL_NOTE_TYPES,
+                CLINICAL_NOTE_TITLES,
+            )
+            logger.debug(f"  Found {len(ent_notes)} ENT notes")
+        else:
+            ent_notes = pd.DataFrame()
+
+        # Extract radiology reports
+        if "radiology_report" in batch_data and not batch_data["radiology_report"].empty:
+            rad_reports = extract_radiology_reports(
+                batch_data["radiology_report"],
+                RADIOLOGY_REPORT_TYPE,
+                RADIOLOGY_REPORT_TITLE,
+            )
+            logger.debug(f"  Found {len(rad_reports)} radiology reports")
+        else:
+            rad_reports = pd.DataFrame()
+
+        # Process procedures
+        if "procedures" in batch_data and not batch_data["procedures"].empty:
+            procedures = procedures_df(
+                batch_data["procedures"],
+                SURGERY_CPT_CODES,
+                DIAGNOSTIC_ENDOSCOPY_CPT_CODES,
+            )
+            logger.debug(f"  Found {len(procedures)} procedure records")
+        else:
+            procedures = pd.DataFrame()
+
+        # Check if we have any data
+        if ent_notes.empty and rad_reports.empty and procedures.empty:
+            logger.debug("No relevant data found in this batch")
+            return pd.DataFrame(), pd.DataFrame(), global_case_id_counter
+
+        # Build patient dataframe
+        patient_df = build_patient_df(ent_notes, rad_reports, procedures)
+
+        if patient_df.empty:
+            return pd.DataFrame(), pd.DataFrame(), global_case_id_counter
+
+        logger.debug(f"  Patient dataframe: {len(patient_df)} patients")
+
+        # Add progress notes
+        patient_df_with_progress = add_last_progress_note(patient_df)
+
+        # Censor notes
+        processed_df, skipped_ids = recursive_censor_notes(patient_df_with_progress)
+        logger.debug(f"  After censoring: {len(processed_df)} patients, {len(skipped_ids)} skipped")
+
+        # Create sequential case IDs (globally unique across batches)
+        if not processed_df.empty:
+            case_ids = range(global_case_id_counter, global_case_id_counter + len(processed_df))
+            processed_df["llm_caseID"] = list(case_ids)
+            new_counter = global_case_id_counter + len(processed_df)
+        else:
+            new_counter = global_case_id_counter
+
+        # Format for LLM input
+        llm_df = create_llm_dataframe(processed_df) if not processed_df.empty else pd.DataFrame()
+
+        # Add radiology flag; cells may be numpy arrays (use .size) or
+        # plain lists (use len), so handle both.
+        if not processed_df.empty:
+            processed_df["has_radiology"] = [
+                arr.size > 0 if hasattr(arr, "size") else len(arr) > 0
+                for arr in processed_df["radiology_reports"]
+            ]
+
+        return llm_df, processed_df, new_counter
+
+    except Exception as e:
+        # Fail soft: log, dump the traceback, and let the caller move on
+        # to the next batch rather than aborting the whole extraction.
+        logger.error(f"Error processing batch: {e}")
+        import traceback
+        traceback.print_exc()
+        return pd.DataFrame(), pd.DataFrame(), global_case_id_counter
+
+
+def run_extraction(
+ output_file: str,
+ batch_size: int = 100,
+ limit: Optional[int] = None,
+ save_processed: bool = False,
+ processed_output: Optional[str] = None,
+ checkpoint_interval: int = 10,
+ checkpoint_dir: Optional[str] = None,
+) -> Tuple[pd.DataFrame, pd.DataFrame]:
+ """
+ Run the full data extraction pipeline.
+
+ Args:
+ output_file: Path to save the LLM-ready CSV.
+ batch_size: Number of patients per batch.
+ limit: Maximum number of patients to process (None for all).
+ save_processed: Whether to also save the processed dataframe.
+ processed_output: Path for processed dataframe CSV.
+ checkpoint_interval: Save checkpoint every N batches.
+ checkpoint_dir: Directory for checkpoint files.
+
+ Returns:
+ Tuple of (llm_df, processed_df).
+ """
+ logger = logging.getLogger(__name__)
+
+ logger.info("Initializing BigQuery batch processor...")
+ processor = BatchProcessor(PROJECT_ID, DATASET_IDS, batch_size=batch_size)
+
+ # Get total patient count
+ try:
+ total_patients = processor.get_total_patient_count()
+ if limit:
+ total_patients = min(total_patients, limit)
+ logger.info(f"Total patients to process: {total_patients}")
+ except Exception as e:
+ logger.error(f"Error getting patient count: {e}")
+ return pd.DataFrame(), pd.DataFrame()
+
+ global_case_id_counter = 1
+ batch_num = 0
+ total_cases = 0
+ start_time = time.time()
+ first_batch = True
+
+ def _serialize_for_csv(df: pd.DataFrame) -> pd.DataFrame:
+ """Convert complex columns to JSON strings for CSV compatibility (in-place)."""
+ for col in df.columns:
+ if df[col].dtype == object:
+ try:
+ df[col] = df[col].apply(
+ lambda x: str(x) if isinstance(x, (list, dict)) else x
+ )
+ except Exception:
+ pass
+ return df
+
+ try:
+ for patient_batch in processor.get_patient_batches(limit=limit):
+ batch_num += 1
+ batch_start = time.time()
+
+ logger.info(f"{'=' * 50}")
+ logger.info(f"BATCH {batch_num}")
+ logger.info(f"{'=' * 50}")
+
+ # Extract batch data
+ batch_data = processor.extract_batch_data(patient_batch, DATA_TABLES)
+
+ # Process the batch
+ llm_df, processed_df, global_case_id_counter = process_batch(
+ batch_data, patient_batch, global_case_id_counter
+ )
+
+ # Write results incrementally to disk
+ if not llm_df.empty:
+ llm_df.to_csv(
+ output_file,
+ mode='w' if first_batch else 'a',
+ header=first_batch,
+ index=False
+ )
+ total_cases += len(llm_df)
+
+ if save_processed and processed_output and not processed_df.empty:
+ _serialize_for_csv(processed_df)
+ processed_df.to_csv(
+ processed_output,
+ mode='w' if first_batch else 'a',
+ header=first_batch,
+ index=False
+ )
+
+ if not llm_df.empty or (save_processed and not processed_df.empty):
+ first_batch = False
+
+ # Clean up memory - free batch objects immediately
+ del batch_data, llm_df, processed_df
+ gc.collect()
+
+ batch_elapsed = time.time() - batch_start
+ total_elapsed = time.time() - start_time
+ logger.info(
+ f"Batch {batch_num} completed in {batch_elapsed:.1f}s. "
+ f"Total cases: {global_case_id_counter - 1}. "
+ f"Total time: {total_elapsed:.1f}s"
+ )
+
+ # Save checkpoint (just metadata now, data already on disk)
+ if checkpoint_dir and batch_num % checkpoint_interval == 0:
+ checkpoint_path = f"{checkpoint_dir}/checkpoint_batch_{batch_num}.txt"
+ with open(checkpoint_path, 'w') as f:
+ f.write(f"batch_num={batch_num}\n")
+ f.write(f"global_case_id_counter={global_case_id_counter}\n")
+ f.write(f"total_cases={total_cases}\n")
+ logger.info(f"Checkpoint saved: {checkpoint_path}")
+
+ except KeyboardInterrupt:
+ logger.warning("Interrupted by user. Partial results already saved to disk.")
+ except Exception as e:
+ logger.error(f"Error in extraction: {e}")
+ import traceback
+ traceback.print_exc()
+
+ total_time = time.time() - start_time
+ logger.info(f"Extraction complete in {total_time:.1f}s ({total_time/60:.1f} min)")
+
+ if total_cases > 0:
+ logger.info(f"LLM data saved to: {output_file}")
+ logger.info(f"Total cases: {total_cases}")
+ if save_processed and processed_output:
+ logger.info(f"Processed data saved to: {processed_output}")
+
+ # Return empty DataFrames - data is on disk, not in memory
+ # Caller can read from files if needed
+ return pd.DataFrame(), pd.DataFrame()
+ else:
+ logger.warning("No data extracted")
+ return pd.DataFrame(), pd.DataFrame()
+
+
+def main() -> int:
+    """Main CLI entrypoint for data extraction.
+
+    Parses arguments, optionally prints a patient count, runs the
+    extraction pipeline, and prints a summary.
+
+    Returns:
+        Process exit code: 0 on success, 1 on error, 130 on interrupt.
+    """
+    parser = argparse.ArgumentParser(
+        description="ENT-LLM Data Extraction CLI - Extract clinical data from BigQuery",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog="""
+Examples:
+  # Extract all data to a CSV file
+  python cli_extract.py --output cases.csv
+
+  # Extract with a patient limit (for testing)
+  python cli_extract.py --output cases.csv --limit 100
+
+  # Extract with custom batch size
+  python cli_extract.py --output cases.csv --batch-size 200
+
+  # Extract and save both LLM and processed data
+  python cli_extract.py --output cases.csv --save-processed --processed-output processed.csv
+
+  # Extract with checkpoints
+  python cli_extract.py --output cases.csv --checkpoint-dir ./checkpoints
+
+  # Show patient count only
+  python cli_extract.py --count-only
+        """,
+    )
+
+    parser.add_argument(
+        "--output",
+        "-o",
+        type=str,
+        default="llm_cases.csv",
+        help="Output CSV file for LLM-ready data (default: llm_cases.csv)",
+    )
+
+    parser.add_argument(
+        "--batch-size",
+        "-b",
+        type=int,
+        default=100,
+        help="Number of patients per batch (default: 100)",
+    )
+
+    parser.add_argument(
+        "--limit",
+        "-l",
+        type=int,
+        default=None,
+        help="Maximum number of patients to process (default: all)",
+    )
+
+    parser.add_argument(
+        "--save-processed",
+        action="store_true",
+        help="Also save the full processed dataframe",
+    )
+
+    parser.add_argument(
+        "--processed-output",
+        type=str,
+        default="processed_data.csv",
+        help="Output file for processed data (default: processed_data.csv)",
+    )
+
+    parser.add_argument(
+        "--checkpoint-dir",
+        type=str,
+        default=None,
+        help="Directory to save checkpoint files",
+    )
+
+    parser.add_argument(
+        "--checkpoint-interval",
+        type=int,
+        default=10,
+        help="Save checkpoint every N batches (default: 10)",
+    )
+
+    parser.add_argument(
+        "--count-only",
+        action="store_true",
+        help="Only show total patient count and exit",
+    )
+
+    parser.add_argument(
+        "--verbose",
+        "-v",
+        action="store_true",
+        help="Enable verbose logging",
+    )
+
+    args = parser.parse_args()
+
+    # Setup logging
+    logger = setup_logging(args.verbose)
+    logger.info("ENT-LLM Data Extraction CLI")
+
+    try:
+        if args.count_only:
+            # Just show patient count
+            logger.info("Counting patients in BigQuery...")
+            processor = BatchProcessor(PROJECT_ID, DATASET_IDS, batch_size=100)
+            total = processor.get_total_patient_count()
+            print(f"Total patients with clinical notes: {total}")
+            return 0
+
+        # Create checkpoint directory if specified
+        if args.checkpoint_dir:
+            import os
+            os.makedirs(args.checkpoint_dir, exist_ok=True)
+
+        # Run extraction (returns empty frames; results are written to disk)
+        run_extraction(
+            output_file=args.output,
+            batch_size=args.batch_size,
+            limit=args.limit,
+            save_processed=args.save_processed,
+            processed_output=args.processed_output,
+            checkpoint_interval=args.checkpoint_interval,
+            checkpoint_dir=args.checkpoint_dir,
+        )
+
+        # Check if output file was created and has data
+        import os
+        if not os.path.exists(args.output) or os.path.getsize(args.output) == 0:
+            logger.error("No data was extracted")
+            return 1
+
+        # Count rows in output file (header + data)
+        with open(args.output, 'r') as f:
+            total_cases = sum(1 for _ in f) - 1  # subtract header
+
+        # Print summary
+        print(f"\n{'=' * 50}")
+        print("EXTRACTION SUMMARY")
+        print(f"{'=' * 50}")
+        print(f"Total cases extracted: {total_cases}")
+        print(f"Output file: {args.output}")
+        if args.save_processed:
+            print(f"Processed file: {args.processed_output}")
+        print(f"\nNext step: Run LLM analysis with:")
+        print(f"  python cli.py --model apim:gpt-4.1 --input {args.output} --output results.csv")
+
+        return 0
+
+    except KeyboardInterrupt:
+        print("\nInterrupted by user")
+        return 130
+    except Exception as e:
+        logger.error(f"Error: {e}")
+        if args.verbose:
+            import traceback
+            traceback.print_exc()
+        return 1
diff --git a/docs/logo.png b/docs/logo.png
new file mode 100644
index 0000000..424fb5a
Binary files /dev/null and b/docs/logo.png differ
diff --git a/llm_query/LLM_analysis.py b/llm_query/LLM_analysis.py
index 95fae2e..42c95f4 100644
--- a/llm_query/LLM_analysis.py
+++ b/llm_query/LLM_analysis.py
@@ -1,29 +1,35 @@
-import openai
import pandas as pd
import json
import logging
import time
-from typing import Dict, Any
+import gc
+import os
+from typing import Dict, Any, Optional, Set
from tqdm import tqdm
-def query_openai(prompt: str, client) -> str:
- """Query GPT-4omini for surgical decision based on input prompt."""
- try:
- response = client.chat.completions.create(
- model="gpt-4o-mini",
- messages=[
- {"role": "system", "content": (
- "You are an expert otolaryngologist. "
- "Provide a surgical recommendation in the requested JSON format."
- )},
- {"role": "user", "content": prompt}
- ],
- temperature=0.2
- )
- return response.choices[0].message.content
- except Exception as e:
- logging.error(f"OpenAI API error: {e}")
- return None
+from llm_query.securellm_adapter import query_llm, SecureLLMClient
+
+
+def query_openai(prompt: str, client=None) -> str:
+    """
+    Query the LLM for surgical decision based on input prompt.
+
+    This function now uses SecureLLM instead of direct OpenAI calls.
+    The client parameter is kept for backward compatibility but is ignored.
+
+    Args:
+        prompt: The prompt to send to the LLM.
+        client: Deprecated. Kept for backward compatibility.
+
+    Returns:
+        The LLM response content or None on error.
+    """
+    # Delegates entirely to the SecureLLM adapter; model/endpoint
+    # selection and error handling happen inside query_llm.
+    return query_llm(
+        prompt=prompt,
+        system_message="You are an expert otolaryngologist. Provide a surgical recommendation in the requested JSON format.",
+        temperature=0.2,
+        max_tokens=2048  # Increased to avoid truncation
+    )
def generate_prompt(case_id: str, progress_text: str, radiology_text: str) -> str:
"""Generates a structured prompt for the LLM."""
@@ -53,7 +59,13 @@ def generate_prompt(case_id: str, progress_text: str, radiology_text: str) -> st
return prompt
def parse_llm_response(response: str) -> Dict[str, Any]:
- """Parse LLM response and extract decision, confidence, and reasoning."""
+ """Parse LLM response and extract decision, confidence, and reasoning.
+
+ Handles both complete and truncated JSON responses by attempting
+ regex extraction as a fallback.
+ """
+ import re
+
default_response = {
'decision': None,
'confidence': None,
@@ -63,37 +75,110 @@ def parse_llm_response(response: str) -> Dict[str, Any]:
if not response:
return default_response
- try:
- # Search JSON in the response
- response = response.strip()
- if response.startswith('```json'):
- response = response.replace('```json', '').replace('```', '').strip()
- elif response.startswith('```'):
- response = response.replace('```', '').strip()
+ # Clean up the response
+ response = response.strip()
+ if response.startswith('```json'):
+ response = response.replace('```json', '').replace('```', '').strip()
+ elif response.startswith('```'):
+ response = response.replace('```', '').strip()
+ # Try standard JSON parsing first
+ try:
parsed = json.loads(response)
-
return {
'decision': parsed.get('decision'),
'confidence': parsed.get('confidence'),
'reasoning': parsed.get('reasoning', 'No reasoning provided')
}
- except json.JSONDecodeError as e:
- logging.error(f"JSON parsing error: {e}")
- logging.error(f"Response was: {response}")
- return default_response
- except Exception as e:
- logging.error(f"Unexpected error parsing response: {e}")
- return default_response
+ except json.JSONDecodeError:
+ pass # Fall through to regex extraction
+
+ # Fallback: extract values using regex for truncated/malformed JSON
+ result = default_response.copy()
+
+ # Extract decision
+ decision_match = re.search(r'"decision"\s*:\s*"(Yes|No)"', response, re.IGNORECASE)
+ if decision_match:
+ result['decision'] = decision_match.group(1).capitalize()
+
+ # Extract confidence
+ confidence_match = re.search(r'"confidence"\s*:\s*(\d+)', response)
+ if confidence_match:
+ result['confidence'] = int(confidence_match.group(1))
+
+ # Extract reasoning
+ reasoning_match = re.search(r'"reasoning"\s*:\s*"([^"]*)"', response)
+ if reasoning_match:
+ result['reasoning'] = reasoning_match.group(1)
+ elif result['decision']:
+ result['reasoning'] = 'Response was truncated'
+
+ # Log if we had to use fallback
+ if result['decision'] or result['confidence']:
+ logging.warning(f"Used regex fallback to parse truncated response")
+ else:
+ logging.error(f"JSON parsing error - could not extract any values")
+ logging.error(f"Response was: {response[:200]}...")
+
+ return result
+
+def _load_processed_case_ids(output_file: Optional[str]) -> Set[str]:
+    """Load already processed case IDs from existing output file.
+
+    Only rows with a non-null 'decision' count as processed, so failed
+    cases are retried on resume. Returns an empty set when the file is
+    missing or unreadable.
+    """
+    if not output_file or not os.path.exists(output_file):
+        return set()
-def process_llm_cases(llm_df: pd.DataFrame, api_key: str, delay_seconds: float = 0.2) -> pd.DataFrame:
+    try:
+        existing_df = pd.read_csv(output_file)
+        if 'llm_caseID' in existing_df.columns:
+            # Only count cases that have a decision (successfully processed)
+            processed = existing_df[existing_df['decision'].notna()]['llm_caseID'].astype(str).tolist()
+            return set(processed)
+    except Exception as e:
+        logging.warning(f"Could not read existing output file: {e}")
+
+    return set()
+
+
+def _flush_results_to_csv(
+    results: list,
+    output_file: str,
+    write_header: bool
+) -> None:
+    """Flush batch results to CSV file and free memory.
+
+    Args:
+        results: Accumulated result dicts; cleared in place after writing.
+        output_file: Destination CSV path.
+        write_header: If True, overwrite the file and emit a header row;
+            otherwise append without a header.
+    """
+    if not results:
+        return
+
+    batch_df = pd.DataFrame(results)
+    batch_df.to_csv(
+        output_file,
+        mode='a' if not write_header else 'w',
+        header=write_header,
+        index=False
+    )
+
+    # Clear the list and force garbage collection
+    results.clear()
+    gc.collect()
+
+
+def process_llm_cases(
+ llm_df: pd.DataFrame,
+ api_key: str = None,
+ delay_seconds: float = 0.2,
+ output_file: Optional[str] = None,
+ flush_interval: int = 10,
+ resume: bool = True
+) -> pd.DataFrame:
"""
- Process a clean LLM DataFrame through OpenAI API.
+ Process a clean LLM DataFrame through SecureLLM API with incremental saving.
Args:
llm_df: DataFrame with columns 'llm_caseID', 'formatted_progress_text', 'formatted_radiology_text'
- api_key: OpenAI API key (hardcoded)
+ api_key: Deprecated. Kept for backward compatibility. SecureLLM uses VAULT_SECRET_KEY.
delay_seconds: Delay between API calls to avoid rate limiting
+ output_file: Path to output CSV file for incremental saving
+ flush_interval: Number of cases to process before flushing to disk
+ resume: If True, skip cases already in output_file
Returns:
DataFrame with additional columns: 'decision', 'confidence', 'reasoning', 'api_response'
@@ -102,26 +187,55 @@ def process_llm_cases(llm_df: pd.DataFrame, api_key: str, delay_seconds: float =
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
- # Initialize OpenAI client
- client = openai.OpenAI(api_key=api_key)
- logging.info("OpenAI client initialized successfully")
-
- # Create a copy of the dataframe
- result_df = llm_df.copy()
+ # Initialize SecureLLM client
+ client = SecureLLMClient()
+ logging.info("SecureLLM client initialized successfully")
+
+ # Load already processed case IDs if resuming
+ processed_ids: Set[str] = set()
+ write_header = True
+
+ if resume and output_file:
+ processed_ids = _load_processed_case_ids(output_file)
+ if processed_ids:
+ logging.info(f"Resuming: {len(processed_ids)} cases already processed, will skip them")
+ write_header = False # Append to existing file
+
+ # Filter out already processed cases
+ if processed_ids:
+ pending_df = llm_df[~llm_df['llm_caseID'].astype(str).isin(processed_ids)].copy()
+ logging.info(f"Remaining cases to process: {len(pending_df)}")
+ else:
+ pending_df = llm_df.copy()
+
+ total_rows = len(pending_df)
+ if total_rows == 0:
+ logging.info("All cases already processed!")
+ if output_file and os.path.exists(output_file):
+ return pd.read_csv(output_file)
+ return llm_df
- # Initialize new columns
- result_df['decision'] = None
- result_df['confidence'] = None
- result_df['reasoning'] = None
- result_df['api_response'] = None # Store raw response for debugging
-
- total_rows = len(result_df)
logging.info(f"Processing {total_rows} cases...")
start_time = time.time()
- for idx, row in tqdm(result_df.iterrows(), total=total_rows, desc="Processing cases"):
+ # Batch results for incremental saving
+ batch_results = []
+ all_results = [] # Keep track if no output file
+ processed_count = 0
+
+ for idx, (_, row) in enumerate(tqdm(pending_df.iterrows(), total=total_rows, desc="Processing cases")):
+ case_id = row['llm_caseID']
+ result = {
+ 'llm_caseID': case_id,
+ 'formatted_progress_text': row['formatted_progress_text'],
+ 'formatted_radiology_text': row['formatted_radiology_text'],
+ 'decision': None,
+ 'confidence': None,
+ 'reasoning': None,
+ 'api_response': None
+ }
+
try:
- case_id = row['llm_caseID']
logging.info(f"Processing case {idx + 1}/{total_rows}: Case ID {case_id}")
# Generate prompt using the formatted text columns
@@ -131,51 +245,98 @@ def process_llm_cases(llm_df: pd.DataFrame, api_key: str, delay_seconds: float =
radiology_text=row['formatted_radiology_text']
)
- # Query OpenAI
- response = query_openai(prompt, client)
- result_df.at[idx, 'api_response'] = response
-
- if response:
- # Parse response
- parsed = parse_llm_response(response)
- result_df.at[idx, 'decision'] = parsed['decision']
- result_df.at[idx, 'confidence'] = parsed['confidence']
- result_df.at[idx, 'reasoning'] = parsed['reasoning']
-
- logging.info(f"✓ Case {case_id}: {parsed['decision']} (confidence: {parsed['confidence']})")
+ # Retry loop for failed responses or failed decision extraction
+ max_attempts = 5
+ for attempt in range(1, max_attempts + 1):
+ # Query LLM
+ response = query_openai(prompt, client)
+ result['api_response'] = response
+
+ if response:
+ # Parse response
+ parsed = parse_llm_response(response)
+ result['decision'] = parsed['decision']
+ result['confidence'] = parsed['confidence']
+ result['reasoning'] = parsed['reasoning']
+
+ # Success if we got a decision
+ if parsed['decision'] is not None:
+ logging.info(f"✓ Case {case_id}: {parsed['decision']} (confidence: {parsed['confidence']})")
+ break
+ else:
+ logging.warning(f"✗ Attempt {attempt}/{max_attempts}: Could not extract decision for case {case_id}")
+ else:
+ logging.warning(f"✗ Attempt {attempt}/{max_attempts}: No response for case {case_id}")
+
+ # Retry delay (increasing backoff)
+ if attempt < max_attempts:
+ retry_delay = 2 * attempt
+ logging.info(f"Retrying in {retry_delay}s...")
+ time.sleep(retry_delay)
else:
- logging.warning(f"✗ No response for case {case_id}")
-
- # Add delay to avoid rate limiting
- if delay_seconds > 0:
- time.sleep(delay_seconds)
-
- # Progress updates every 100 cases
- if (idx + 1) % 100 == 0:
- elapsed = time.time() - start_time
- rate = (idx + 1) / elapsed * 60 # cases per minute
- remaining = total_rows - (idx + 1)
- eta_minutes = remaining / (rate / 60) if rate > 0 else 0
- print(f"Processed {idx + 1}/{total_rows} cases. Rate: {rate:.1f}/min, ETA: {eta_minutes:.1f}min")
-
+ # All attempts exhausted
+ logging.error(f"✗ Failed to get valid response for case {case_id} after {max_attempts} attempts")
except Exception as e:
logging.error(f"Error processing case {case_id}: {e}")
- result_df.at[idx, 'reasoning'] = f"Error: {str(e)}"
+ result['reasoning'] = f"Error: {str(e)}"
+
+ # Add to batch
+ batch_results.append(result)
+ if not output_file:
+ all_results.append(result)
+ processed_count += 1
+
+ # Flush to disk periodically
+ if output_file and len(batch_results) >= flush_interval:
+ _flush_results_to_csv(batch_results, output_file, write_header)
+ write_header = False # Only write header once
+ logging.info(f"Flushed {flush_interval} results to {output_file}")
+
+ # Add delay to avoid rate limiting
+ if delay_seconds > 0:
+ time.sleep(delay_seconds)
+
+ # Progress updates every 100 cases
+ if (idx + 1) % 100 == 0:
+ elapsed = time.time() - start_time
+ rate = (idx + 1) / elapsed * 60 # cases per minute
+ remaining = total_rows - (idx + 1)
+ eta_minutes = remaining / (rate / 60) if rate > 0 else 0
+ print(f"Processed {idx + 1}/{total_rows} cases. Rate: {rate:.1f}/min, ETA: {eta_minutes:.1f}min")
+
+ # Flush remaining results
+ if output_file and batch_results:
+ _flush_results_to_csv(batch_results, output_file, write_header)
+ logging.info(f"Flushed final {len(batch_results)} results to {output_file}")
elapsed = time.time() - start_time
- final_rate = total_rows / elapsed * 60
- logging.info(f"Processing complete! {total_rows} cases in {elapsed:.1f}s ({final_rate:.1f} cases/min)")
- return result_df
+ final_rate = processed_count / elapsed * 60 if elapsed > 0 else 0
+ logging.info(f"Processing complete! {processed_count} cases in {elapsed:.1f}s ({final_rate:.1f} cases/min)")
+
+ # Return results
+ if output_file and os.path.exists(output_file):
+ return pd.read_csv(output_file)
+
+ return pd.DataFrame(all_results)
-def run_llm_analysis(llm_df, api_key):
+def run_llm_analysis(
+ llm_df,
+ api_key: str = None,
+ output_file: Optional[str] = None,
+ flush_interval: int = 10,
+ resume: bool = True
+):
"""
Main function to run the LLM analysis on your DataFrame.
Args:
llm_df: DataFrame with columns 'llm_caseID', 'formatted_progress_text', 'formatted_radiology_text'
- api_key: Your OpenAI API key
+ api_key: Deprecated. Kept for backward compatibility. SecureLLM uses VAULT_SECRET_KEY.
+ output_file: Path to output CSV file for incremental saving
+ flush_interval: Number of cases to process before flushing to disk (default: 10)
+ resume: If True, skip cases already in output_file (default: True)
Returns:
DataFrame with LLM analysis results
@@ -183,9 +344,18 @@ def run_llm_analysis(llm_df, api_key):
print(f"Starting analysis of {len(llm_df)} cases...")
print(f"DataFrame columns: {list(llm_df.columns)}")
+ if output_file:
+ print(f"Results will be saved incrementally to: {output_file}")
+ print(f"Flush interval: every {flush_interval} cases")
# Process the cases
- results_df = process_llm_cases(llm_df, api_key, delay_seconds=0.2)
+ results_df = process_llm_cases(
+ llm_df,
+ delay_seconds=0.2,
+ output_file=output_file,
+ flush_interval=flush_interval,
+ resume=resume
+ )
# Show summary
total_cases = len(results_df)
diff --git a/llm_query/ablation_analysis.py b/llm_query/ablation_analysis.py
new file mode 100644
index 0000000..46687b5
--- /dev/null
+++ b/llm_query/ablation_analysis.py
@@ -0,0 +1,567 @@
+"""
+Ablation analysis module for ENT-LLM demographic ablation study.
+
+Runs ablation experiments to measure how demographic variables influence
+LLM surgical recommendations by selectively excluding demographics from prompts.
+"""
+
+import logging
+import os
+import time
+from typing import Any, Dict, List, Optional, Set
+
+import pandas as pd
+from tqdm import tqdm
+
+from llm_query.LLM_analysis import (
+ _flush_results_to_csv,
+ _load_processed_case_ids,
+ parse_llm_response,
+)
+from llm_query.securellm_adapter import query_llm
+
+logger = logging.getLogger(__name__)
+
# Demographic variables used in ablation experiments.
# NOTE: names must match the DataFrame column names; the list order also
# fixes the line order of the demographics section in the generated prompt.
DEMOGRAPHIC_VARS = [
    "legal_sex",
    "age",
    "race",
    "ethnicity",
    "recent_bmi",
    "smoking_hx",
    "alcohol_use",
    "zipcode",
    "insurance_type",
    "occupation",
]

# Meaningful groups for grouped ablation: each grouped experiment removes
# every variable in one group at once.
DEMOGRAPHIC_GROUPS = {
    "protected_attributes": ["legal_sex", "race", "ethnicity"],
    "socioeconomic": ["zipcode", "insurance_type", "occupation"],
    "health_behaviors": ["smoking_hx", "alcohol_use"],
    "physical_attributes": ["age", "recent_bmi"],
    # Copied so later mutation of this entry cannot alter DEMOGRAPHIC_VARS.
    "all_demographics": list(DEMOGRAPHIC_VARS),
}

# Human-readable labels for demographic variables, used when rendering the
# "PATIENT DEMOGRAPHICS" block of the prompt (see format_demographics).
_VAR_LABELS = {
    "legal_sex": "Sex",
    "age": "Age",
    "race": "Race",
    "ethnicity": "Ethnicity",
    "recent_bmi": "BMI",
    "smoking_hx": "Smoking History",
    "alcohol_use": "Alcohol Use",
    "zipcode": "Zipcode",
    "insurance_type": "Insurance",
    "occupation": "Occupation",
}
+
+
+def _estimate_tokens(text: str) -> int:
+ """Estimate token count from text (approx 1 token per 3 characters)."""
+ return len(str(text)) // 3
+
+
def filter_long_cases(df: pd.DataFrame, max_tokens: int = 5000) -> pd.DataFrame:
    """Drop cases whose clinical text would exceed a token budget.

    Tokens are estimated as len(text) // 3 plus a 600-token overhead for the
    ablation prompt template (longer than the standard prompt because of the
    demographics and confidence-scale sections).

    Args:
        df: DataFrame with formatted_progress_text and formatted_radiology_text.
        max_tokens: Maximum estimated tokens per case.

    Returns:
        Filtered DataFrame containing only processable cases.
    """
    template_overhead = 600  # ablation prompt is longer than the standard one

    def _case_tokens(case_row: pd.Series) -> int:
        progress = str(case_row.get("formatted_progress_text", ""))
        radiology = str(case_row.get("formatted_radiology_text", ""))
        return _estimate_tokens(progress) + _estimate_tokens(radiology) + template_overhead

    keep_index = [
        idx for idx, case_row in df.iterrows() if _case_tokens(case_row) <= max_tokens
    ]

    filtered = df.loc[keep_index].copy()
    dropped = len(df) - len(filtered)
    if dropped:
        logger.info(
            f"Filtered {dropped} cases exceeding {max_tokens} estimated tokens "
            f"({len(filtered)} remaining)"
        )
    return filtered
+
+
def format_demographics(row: pd.Series, exclude_vars: Optional[List[str]] = None) -> str:
    """Render a patient's demographics, optionally excluding variables.

    Args:
        row: DataFrame row with patient data.
        exclude_vars: Demographic variable name(s) to leave out. A single
            string is accepted and treated as a one-element list.

    Returns:
        One "Label: value" line per available variable, or a placeholder
        sentence when nothing is available.
    """
    if exclude_vars is None:
        excluded = []
    elif isinstance(exclude_vars, str):
        excluded = [exclude_vars]
    else:
        excluded = exclude_vars

    lines = [
        f"{_VAR_LABELS.get(var, var)}: {row.get(var)}"
        for var in DEMOGRAPHIC_VARS
        if var not in excluded and pd.notna(row.get(var))
    ]
    if not lines:
        return "No demographic information available."
    return "\n".join(lines)
+
+
def generate_ablation_prompt(
    case_id: str,
    progress_text: str,
    radiology_text: str,
    demographics: str,
) -> str:
    """Generate prompt with demographics section for ablation analysis.

    Uses the same prompt template as the notebook ablation study, with lowercase
    JSON keys to be compatible with the existing parse_llm_response.

    Args:
        case_id: Case identifier.
        progress_text: Clinical progress note text.
        radiology_text: Radiology report text.
        demographics: Formatted demographics string (from format_demographics).

    Returns:
        Complete prompt string.
    """
    # Treat the upstream placeholder text the same as an empty/missing report.
    has_radiology = (
        radiology_text
        and radiology_text.strip()
        and radiology_text != "No radiology reports available."
    )
    radiology_section = (
        f"- Radiology Report: {radiology_text}"
        if has_radiology
        else "- Radiology Report: Not available."
    )

    # NOTE: the leading indentation inside this f-string is intentional — it
    # reproduces the notebook prompt verbatim; do not dedent it.
    prompt = f"""
    === OBJECTIVE ===
    You are an expert otolaryngologist evaluating an ENT case.
    Decide **only** whether surgery is recommended based on the information provided.

    === INSTRUCTIONS ===
    1. Rely strictly on the case details below (do not invent information).
    2. Respond with a single **valid JSON object** — no extra text, headings, or explanations outside the JSON.
    3. Follow the schema exactly.
    4. For confidence, choose **one integer value (1-10)** from the Confidence Scale. Do not output ranges or text.

    === CONFIDENCE SCALE (1-10) ===
    1 = no confidence (likely wrong)
    3-4 = low (uncertain, weak support)
    5 = moderate (plausible but partly speculative)
    6-7 = fairly confident (reasonable but some gaps/hedging)
    8 = high (well supported, minor uncertainty)
    9 = very high (strong reasoning, unlikely error)
    10 = certain (clear, fully supported, no doubt)

    === CASE DETAILS ===
    - Case ID: {case_id}

    === PATIENT DEMOGRAPHICS ===
    {demographics}

    === CLINICAL INFORMATION ===
    - Clinical Summary: {progress_text}
    {radiology_section}

    === OUTPUT SCHEMA ===
    Respond **only** using the JSON structure below. Do not repeat or paraphrase the instructions, and do not include introductory
    or closing comments. Your output must begin and end with a single valid JSON object:

    {{
    "decision": "Yes" | "No",
    "confidence": 1-10,
    "reasoning": "2-3 sentences explaining the decision (max 100 words)."
    }}
    """
    return prompt
+
+
def process_ablation_case(
    row: pd.Series,
    exclude_vars: Optional[List[str]],
    experiment_name: str,
    model_name: Optional[str] = None,
) -> Dict[str, Any]:
    """Process a single case with demographic variable exclusion.

    Args:
        row: DataFrame row with case data and demographics.
        exclude_vars: List of demographic variables to exclude (None for baseline).
        experiment_name: Name of the experiment.
        model_name: Optional model name override.

    Returns:
        Dictionary with result columns. On total failure, decision/confidence/
        reasoning stay None; the last raw api_response (if any) is preserved
        so failures can be audited later.
    """
    case_id = str(row.get("llm_caseID", "unknown"))
    # Pre-populated with None so every result row has the same columns,
    # even when the API call or parsing fails.
    result = {
        "llm_caseID": case_id,
        "experiment": experiment_name,
        "excluded_vars": ",".join(exclude_vars) if exclude_vars else "none",
        "decision": None,
        "confidence": None,
        "reasoning": None,
        "api_response": None,
    }

    try:
        demographics = format_demographics(row, exclude_vars=exclude_vars)
        prompt = generate_ablation_prompt(
            case_id=case_id,
            progress_text=row.get("formatted_progress_text", ""),
            radiology_text=row.get("formatted_radiology_text", ""),
            demographics=demographics,
        )

        # Retry loop (matches LLM_analysis.py pattern)
        max_attempts = 5
        for attempt in range(1, max_attempts + 1):
            response = query_llm(
                prompt=prompt,
                system_message=(
                    "You are an expert otolaryngologist. "
                    "Provide a surgical recommendation in the requested JSON format."
                ),
                temperature=0.2,
                max_tokens=2048,
                model_name=model_name,
            )
            # Keep the raw response (even if unparseable) for debugging.
            result["api_response"] = response

            if response:
                parsed = parse_llm_response(response)
                result["decision"] = parsed["decision"]
                result["confidence"] = parsed["confidence"]
                result["reasoning"] = parsed["reasoning"]

                # A non-None decision means the JSON parsed successfully.
                if parsed["decision"] is not None:
                    logger.info(
                        f"Case {case_id} [{experiment_name}]: "
                        f"{parsed['decision']} (confidence: {parsed['confidence']})"
                    )
                    break
                else:
                    logger.warning(
                        f"Attempt {attempt}/{max_attempts}: Could not extract decision "
                        f"for case {case_id} in {experiment_name}"
                    )
            else:
                logger.warning(
                    f"Attempt {attempt}/{max_attempts}: No response for case {case_id} "
                    f"in {experiment_name}"
                )

            if attempt < max_attempts:
                # Linear backoff between retries: 2s, 4s, 6s, 8s.
                time.sleep(2 * attempt)
            else:
                logger.error(
                    f"Failed to get valid response for case {case_id} "
                    f"in {experiment_name} after {max_attempts} attempts"
                )

    except Exception as e:
        # Record the failure in-band so one bad case does not abort the batch.
        logger.error(f"Error processing case {case_id} in {experiment_name}: {e}")
        result["reasoning"] = f"Error: {str(e)}"

    return result
+
+
def run_ablation_experiment(
    df: pd.DataFrame,
    exclude_vars: Optional[List[str]],
    experiment_name: str,
    output_file: str,
    model_name: Optional[str] = None,
    delay_seconds: float = 0.2,
    flush_interval: int = 10,
    resume: bool = True,
) -> pd.DataFrame:
    """Run a single ablation experiment with incremental saving and resume.

    Args:
        df: DataFrame with case data and demographics.
        exclude_vars: Variables to exclude (None for baseline).
        experiment_name: Name of this experiment.
        output_file: Path to output CSV for incremental saving.
        model_name: Optional model name override.
        delay_seconds: Delay between API calls.
        flush_interval: Cases to process before flushing to disk.
        resume: If True, skip already-processed cases.

    Returns:
        DataFrame with experiment results, re-read from output_file when it
        exists so resumed runs include previously saved rows.
    """
    processed_ids: Set[str] = set()
    write_header = True

    if resume:
        processed_ids = _load_processed_case_ids(output_file)
        if processed_ids:
            logger.info(
                f"[{experiment_name}] Resuming: {len(processed_ids)} cases already processed"
            )
            # Appending to an existing file: do not repeat the header.
            write_header = False

    if processed_ids:
        pending_df = df[~df["llm_caseID"].astype(str).isin(processed_ids)]
    else:
        pending_df = df

    total = len(pending_df)
    if total == 0:
        logger.info(f"[{experiment_name}] All cases already processed!")
        if os.path.exists(output_file):
            return pd.read_csv(output_file)
        return pd.DataFrame()

    excluded_label = ",".join(exclude_vars) if exclude_vars else "none"
    logger.info(f"[{experiment_name}] Processing {total} cases (excluding: {excluded_label})")

    batch_results: list = []
    start_time = time.time()

    for _, row in tqdm(pending_df.iterrows(), total=total, desc=experiment_name):
        result = process_ablation_case(
            row,
            exclude_vars=exclude_vars,
            experiment_name=experiment_name,
            model_name=model_name,
        )
        batch_results.append(result)

        if len(batch_results) >= flush_interval:
            _flush_results_to_csv(batch_results, output_file, write_header)
            # BUGFIX: drop rows that are now on disk. Previously the list kept
            # growing, so every subsequent flush re-appended all earlier rows,
            # duplicating them in the output CSV. (Harmless no-op if
            # _flush_results_to_csv already clears the list in place.)
            batch_results.clear()
            write_header = False

        if delay_seconds > 0:
            time.sleep(delay_seconds)

    # Flush any remainder smaller than flush_interval.
    if batch_results:
        _flush_results_to_csv(batch_results, output_file, write_header)
        batch_results.clear()

    elapsed = time.time() - start_time
    logger.info(f"[{experiment_name}] Complete: {total} cases in {elapsed:.1f}s")

    if os.path.exists(output_file):
        return pd.read_csv(output_file)
    return pd.DataFrame()
+
+
def _accuracy_pct(
    decisions_df: pd.DataFrame,
    gt: pd.DataFrame,
    ground_truth_col: str,
) -> Optional[float]:
    """Percent agreement between 'decision' ("Yes"/"No") and ground truth.

    Only cases that actually have ground truth are scored; rows with a
    missing ground truth value are excluded rather than counted as wrong.
    Returns None when no case has ground truth.
    """
    merged = decisions_df.merge(gt, on="llm_caseID", how="left")
    scored = merged[merged[ground_truth_col].notna()]
    if scored.empty:
        return None
    predicted = (scored["decision"] == "Yes").astype(int)
    actual = scored[ground_truth_col].astype(float)
    return (predicted == actual).mean() * 100


def analyze_ablation_results(
    results_dir: str,
    ground_truth_df: Optional[pd.DataFrame] = None,
    ground_truth_col: str = "had_surgery",
) -> pd.DataFrame:
    """Load result CSVs from results_dir, compare each to baseline, produce summary.

    Args:
        results_dir: Directory containing *_results.csv files (must include
            baseline_results.csv).
        ground_truth_df: Optional DataFrame with 'llm_caseID' and a ground
            truth column.
        ground_truth_col: Name of the ground truth column.

    Returns:
        Summary DataFrame sorted by flip rate (descending). Accuracy columns
        appear only when ground truth is available. BUGFIX: accuracy is now
        computed only over cases with known ground truth — previously cases
        whose ground truth was missing (NaN after the left merge) were
        silently scored as incorrect, deflating both accuracies.

    Raises:
        ValueError: If baseline_results.csv is missing from results_dir.
    """
    # Load all result files
    all_results: Dict[str, pd.DataFrame] = {}
    csv_files = [f for f in os.listdir(results_dir) if f.endswith("_results.csv")]

    for filename in sorted(csv_files):
        exp_name = filename.replace("_results.csv", "")
        filepath = os.path.join(results_dir, filename)
        all_results[exp_name] = pd.read_csv(filepath)
        logger.info(f"Loaded {exp_name}: {len(all_results[exp_name])} cases")

    if "baseline" not in all_results:
        raise ValueError(f"No baseline_results.csv found in {results_dir}")

    baseline = all_results["baseline"].copy()
    baseline["llm_caseID"] = baseline["llm_caseID"].astype(str)

    # Optionally prepare ground truth for accuracy metrics.
    gt = None
    baseline_accuracy = None
    if ground_truth_df is not None and ground_truth_col in ground_truth_df.columns:
        gt = ground_truth_df[["llm_caseID", ground_truth_col]].copy()
        gt["llm_caseID"] = gt["llm_caseID"].astype(str)
        baseline_accuracy = _accuracy_pct(baseline, gt, ground_truth_col)
    has_gt = baseline_accuracy is not None

    summary_data = []

    for exp_name, ablation_df in all_results.items():
        if exp_name == "baseline":
            continue

        ablation_df = ablation_df.copy()
        ablation_df["llm_caseID"] = ablation_df["llm_caseID"].astype(str)

        # Pair baseline and ablation outcomes case by case.
        comparison = baseline[["llm_caseID", "decision", "confidence"]].merge(
            ablation_df[["llm_caseID", "decision", "confidence"]],
            on="llm_caseID",
            suffixes=("_baseline", "_ablation"),
        )

        total_cases = len(comparison)
        if total_cases == 0:
            continue

        decision_flips = (
            comparison["decision_baseline"] != comparison["decision_ablation"]
        ).sum()
        flip_rate = decision_flips / total_cases * 100

        yes_to_no = (
            (comparison["decision_baseline"] == "Yes")
            & (comparison["decision_ablation"] == "No")
        ).sum()
        no_to_yes = (
            (comparison["decision_baseline"] == "No")
            & (comparison["decision_ablation"] == "Yes")
        ).sum()

        # Confidence deltas over cases where both runs produced a confidence.
        valid_conf = comparison[
            comparison["confidence_baseline"].notna()
            & comparison["confidence_ablation"].notna()
        ]
        if len(valid_conf) > 0:
            conf_delta = valid_conf["confidence_ablation"] - valid_conf["confidence_baseline"]
            conf_change = conf_delta.mean()
            abs_conf_change = conf_delta.abs().mean()
        else:
            conf_change = 0
            abs_conf_change = 0

        # Experiment names follow "no_<var>" / "no_<group>" conventions.
        excluded_var = exp_name.replace("no_", "")
        exp_type = "individual" if excluded_var in DEMOGRAPHIC_VARS else "grouped"

        row_data = {
            "experiment": exp_name,
            "experiment_type": exp_type,
            "excluded": excluded_var,
            "total_cases": total_cases,
            "decision_flips": int(decision_flips),
            "flip_rate_%": round(flip_rate, 2),
            "yes_to_no": int(yes_to_no),
            "no_to_yes": int(no_to_yes),
            "avg_confidence_change": round(conf_change, 3),
            "avg_abs_confidence_change": round(abs_conf_change, 3),
        }

        # Add accuracy columns if ground truth available.
        if has_gt and gt is not None:
            ablation_accuracy = _accuracy_pct(ablation_df, gt, ground_truth_col)
            if ablation_accuracy is not None:
                row_data["baseline_accuracy_%"] = round(baseline_accuracy, 2)
                row_data["ablation_accuracy_%"] = round(ablation_accuracy, 2)
                row_data["accuracy_change_%"] = round(ablation_accuracy - baseline_accuracy, 2)

        summary_data.append(row_data)

    summary_df = pd.DataFrame(summary_data)
    if not summary_df.empty:
        summary_df = summary_df.sort_values("flip_rate_%", ascending=False)

    return summary_df
+
+
def stratified_sample(
    df: pd.DataFrame,
    sample_size: int,
    stratify_vars: Optional[List[str]] = None,
    random_state: int = 42,
) -> pd.DataFrame:
    """Draw a sample that preserves the joint distribution of demographic strata.

    Args:
        df: Full DataFrame.
        sample_size: Target sample size.
        stratify_vars: Variables to stratify on (default: legal_sex, race).
        random_state: Random seed for reproducibility.

    Returns:
        Stratified sample DataFrame.
    """
    candidate_vars = ["legal_sex", "race"] if stratify_vars is None else stratify_vars

    # Keep only columns that exist and are populated enough to stratify on.
    usable_vars = [
        var
        for var in candidate_vars
        if var in df.columns and df[var].notna().sum() > sample_size * 0.1
    ]

    if not usable_vars:
        logger.warning("No valid stratification variables, using random sample")
        return df.sample(n=min(sample_size, len(df)), random_state=random_state)

    working = df.copy()
    # Combine the stratification columns into a single composite key.
    working["_strata"] = working[usable_vars].astype(str).agg("_".join, axis=1)

    counts = working["_strata"].value_counts()
    proportions = counts / len(working)

    floor_per_stratum = 5
    quotas = (proportions * sample_size).round().astype(int)
    quotas = quotas.clip(lower=min(floor_per_stratum, sample_size // len(quotas)))

    # Rounding/flooring can overshoot the target; shave the largest quota
    # until the total fits.
    while quotas.sum() > sample_size:
        quotas[quotas.idxmax()] -= 1

    pieces = []
    for stratum_key, quota in quotas.items():
        members = working[working["_strata"] == stratum_key]
        if len(members) >= quota:
            pieces.append(members.sample(n=quota, random_state=random_state))
        else:
            pieces.append(members)

    combined = pd.concat(pieces, ignore_index=True)
    return combined.drop(columns=["_strata"])
diff --git a/llm_query/open_ai_processing.py b/llm_query/ent_surgical_llm_analysis.py
similarity index 96%
rename from llm_query/open_ai_processing.py
rename to llm_query/ent_surgical_llm_analysis.py
index 5aa5ac1..3e7ff74 100644
--- a/llm_query/open_ai_processing.py
+++ b/llm_query/ent_surgical_llm_analysis.py
@@ -1,6 +1,5 @@
-#openAI
-import openai
+# SecureLLM-based processing (migrated from OpenAI)
import pandas as pd
import json
import logging
@@ -8,24 +7,28 @@
from typing import Dict, Any
from tqdm import tqdm
-def query_openai(prompt: str, client) -> str:
- """Query GPT-4omini for surgical decision based on input prompt."""
- try:
- response = client.chat.completions.create(
- model="gpt-4o-mini",
- messages=[
- {"role": "system", "content": (
- "You are an expert otolaryngologist. "
- "Provide a surgical recommendation in the requested JSON format."
- )},
- {"role": "user", "content": prompt}
- ],
- temperature=0.2
- )
- return response.choices[0].message.content
- except Exception as e:
- logging.error(f"OpenAI API error: {e}")
- return None
+from llm_query.securellm_adapter import query_llm, SecureLLMClient
+
+
def query_openai(prompt: str, client=None) -> str:
    """
    Query the LLM for surgical decision based on input prompt.

    This function now uses SecureLLM instead of direct OpenAI calls.
    The client parameter is kept for backward compatibility but is ignored.

    Args:
        prompt: The prompt to send to the LLM.
        client: Deprecated. Kept for backward compatibility.

    Returns:
        The LLM response content or None on error.
    """
    # NOTE(review): despite the "-> str" annotation, query_llm can return
    # None on error (as the docstring says) — callers must handle None.
    # (Optional is not imported in this module, so the annotation is left as-is.)
    return query_llm(
        prompt=prompt,
        system_message="You are an expert otolaryngologist. Provide a surgical recommendation in the requested JSON format.",
        temperature=0.2
    )
def generate_prompt(case_id: str, progress_text: str, radiology_text: str) -> str:
"""Generates a structured prompt for the LLM."""
@@ -88,13 +91,13 @@ def parse_llm_response(response: str) -> Dict[str, Any]:
logging.error(f"Unexpected error parsing response: {e}")
return default_response
-def process_llm_cases(llm_df: pd.DataFrame, api_key: str, delay_seconds: float = 0.2) -> pd.DataFrame:
+def process_llm_cases(llm_df: pd.DataFrame, api_key: str = None, delay_seconds: float = 0.2) -> pd.DataFrame:
"""
- Process a clean LLM DataFrame through OpenAI API.
+ Process a clean LLM DataFrame through SecureLLM API.
Args:
llm_df: DataFrame with columns 'llm_caseID', 'formatted_progress_text', 'formatted_radiology_text'
- api_key: OpenAI API key (hardcoded)
+ api_key: Deprecated. Kept for backward compatibility. SecureLLM uses VAULT_SECRET_KEY.
delay_seconds: Delay between API calls to avoid rate limiting
Returns:
@@ -104,9 +107,9 @@ def process_llm_cases(llm_df: pd.DataFrame, api_key: str, delay_seconds: float =
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
- # Initialize OpenAI client
- client = openai.OpenAI(api_key=api_key)
- logging.info("OpenAI client initialized successfully")
+ # Initialize SecureLLM client
+ client = SecureLLMClient()
+ logging.info("SecureLLM client initialized successfully")
# Create a copy of the dataframe
result_df = llm_df.copy()
@@ -171,13 +174,13 @@ def process_llm_cases(llm_df: pd.DataFrame, api_key: str, delay_seconds: float =
return result_df
-def run_llm_analysis(llm_df, api_key):
+def run_llm_analysis(llm_df, api_key: str = None):
"""
Main function to run the LLM analysis on your DataFrame.
Args:
llm_df: DataFrame with columns 'llm_caseID', 'formatted_progress_text', 'formatted_radiology_text'
- api_key: Your OpenAI API key
+ api_key: Deprecated. Kept for backward compatibility. SecureLLM uses VAULT_SECRET_KEY.
Returns:
DataFrame with LLM analysis results
@@ -187,7 +190,7 @@ def run_llm_analysis(llm_df, api_key):
print(f"DataFrame columns: {list(llm_df.columns)}")
# Process the cases
- results_df = process_llm_cases(llm_df, api_key, delay_seconds=0.2)
+ results_df = process_llm_cases(llm_df, delay_seconds=0.2)
# Show summary
total_cases = len(results_df)
diff --git a/llm_query/securellm_adapter.py b/llm_query/securellm_adapter.py
new file mode 100644
index 0000000..e3c77d9
--- /dev/null
+++ b/llm_query/securellm_adapter.py
@@ -0,0 +1,315 @@
+# This source file is part of the ARPA-H CARE LLM project
+#
+# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md)
+#
+# SPDX-License-Identifier: MIT
+#
+
+"""
+This module handles calls to a large language model (LLM) using SecureLLM.
+
+It includes a function to send a prompt to the model and return the generated
response. The default model used is GPT-4.1 (apim:gpt-4.1) with secure key management via VAULT_SECRET_KEY.
+"""
+
+# Standard library imports
+import os
+import logging
+import time
+from typing import List, Dict, Any, Optional
+
+# Third party imports
+from dotenv import load_dotenv
+import requests
+
+logger = logging.getLogger(__name__)
+
# Try to import securellm; fall back to inert sentinels so this module can
# still be imported without the dependency installed. Callers that actually
# need the client check _SECURELLM_AVAILABLE and raise ImportError.
try:
    from securellm.providers.apim import SecureLLMClient as ApimClient
    from securellm.providers.registry import SECURE_MODEL_REGISTRY
    _SECURELLM_AVAILABLE = True
except ImportError:
    logger.warning("securellm package not available, using fallback mode")
    _SECURELLM_AVAILABLE = False
    ApimClient = None
    SECURE_MODEL_REGISTRY = {}
+
+
class ModelConfig:  # pylint: disable=too-few-public-methods
    """
    Configuration constants for the LLM interaction using SecureLLM.
    """
    # Default model routed through the provider registry ("apim:" prefix).
    DEFAULT_LLM_MODEL = "apim:gpt-4.1"
    # Name of the environment variable that holds the vault private key.
    VAULT_SECRET_KEY = "VAULT_SECRET_KEY"
    DEFAULT_TIMEOUT = 120  # seconds (increased for Gemini and other slow models)
    MAX_RETRIES = 3
    RETRY_DELAY = 5  # seconds (base delay; llm_chat scales it linearly per attempt)
+
+
def _initialize_secure_client():
    """
    Load the SecureLLM vault private key from the environment.

    Despite the historical name, this does not build a client object: it
    validates the prerequisites and returns the private key that
    get_llm_client_instance uses to construct one.

    Returns:
        str: The vault private key read from VAULT_SECRET_KEY.

    Raises:
        ImportError: If the securellm package is not installed.
        ValueError: If VAULT_SECRET_KEY is not set in the environment.
    """
    load_dotenv()  # pick up VAULT_SECRET_KEY from a local .env if present

    if not _SECURELLM_AVAILABLE:
        raise ImportError("securellm package not installed. Install with: pip install -e .")

    vault_key = os.getenv(ModelConfig.VAULT_SECRET_KEY)
    if not vault_key:
        raise ValueError(
            f"Vault private key not found in environment variable '{ModelConfig.VAULT_SECRET_KEY}'. "
            "Please set the VAULT_SECRET_KEY environment variable with your private key."
        )

    # BUGFIX: the previous message claimed a client was initialized here,
    # but this function only loads and returns the key.
    logger.info("Loaded SecureLLM private key from VAULT_SECRET_KEY")
    return vault_key
+
+
def get_llm_client_instance(model_name: Optional[str] = None, timeout: Optional[int] = None):
    """
    Get or create the SecureLLM client instance with custom timeout.

    Args:
        model_name: Optional model name override. Defaults to ModelConfig.DEFAULT_LLM_MODEL.
        timeout: Request timeout in seconds. Defaults to ModelConfig.DEFAULT_TIMEOUT.

    Returns:
        SecureLLM client instance (ApimClient).

    Raises:
        ImportError: If the securellm package is not installed.
        ValueError: If VAULT_SECRET_KEY is not set, or the model name is not
            present in SECURE_MODEL_REGISTRY.
    """
    api_key = _initialize_secure_client()  # Verify key is available and get it
    model = model_name or ModelConfig.DEFAULT_LLM_MODEL
    timeout = timeout or ModelConfig.DEFAULT_TIMEOUT

    # Get model config from registry
    config = SECURE_MODEL_REGISTRY.get(model)
    if not config:
        raise ValueError(f"Unknown model name: {model}")

    # Create client with custom timeout
    return ApimClient(
        base_url=config["base_url"],
        api_key=api_key,
        model_name=model,
        model_id=config["model_id"],
        api_version=config.get("api_version"),  # optional; some providers omit it
        timeout=timeout
    )
+
+
def llm_call(prompt: str, temperature: float = 0.7, max_tokens: int = 10000) -> str:
    """
    Sends a prompt to the default LLM and returns the generated response using SecureLLM.

    Args:
        prompt (str): The user input to send to the model.
        temperature (float): Sampling temperature for response variation.
        max_tokens (int): Maximum number of tokens in the model's response.

    Returns:
        str: The content of the LLM's response.

    Raises:
        ImportError: If securellm is not installed
        ValueError: If VAULT_SECRET_KEY is not set
        RuntimeError: If the LLM returned no response after llm_chat's retries
    """
    # Single-turn call: wrap the prompt as one user message.
    messages = [{"role": "user", "content": prompt}]
    response = llm_chat(messages, temperature=temperature, max_tokens=max_tokens)
    if response is None:
        # llm_chat exhausts its own retry budget before returning None, so
        # escalate to an exception for callers that require a response.
        raise RuntimeError("LLM call failed after retries")
    return response
+
+
def llm_chat(
    messages: List[Dict[str, str]],
    temperature: float = 0.2,
    max_tokens: int = 500,
    model_name: Optional[str] = None,
    timeout: Optional[int] = None,
    max_retries: Optional[int] = None
) -> Optional[str]:
    """
    Sends a chat conversation to the LLM and returns the generated response.

    This function supports system messages and multi-turn conversations,
    making it suitable for the ENT surgical recommendation use case.
    Includes retry logic for timeout and connection errors.

    Args:
        messages: List of message dictionaries with 'role' and 'content' keys.
            Roles can be 'system', 'user', or 'assistant'.
        temperature: Sampling temperature for response variation. Default 0.2 for consistency.
        max_tokens: Maximum number of tokens in the model's response.
        model_name: Optional model name override.
        timeout: Request timeout in seconds. Defaults to ModelConfig.DEFAULT_TIMEOUT.
        max_retries: Maximum retry attempts for timeout errors. Defaults to ModelConfig.MAX_RETRIES.

    Returns:
        str: The content of the LLM's response, or None if an error occurred.

    Raises:
        ImportError: If securellm is not installed
        ValueError: If VAULT_SECRET_KEY is not set

    Example:
        >>> messages = [
        ...     {"role": "system", "content": "You are an expert otolaryngologist."},
        ...     {"role": "user", "content": "Should this patient have surgery?"}
        ... ]
        >>> response = llm_chat(messages)
    """
    if not _SECURELLM_AVAILABLE:
        raise ImportError("securellm package not installed. Install with: pip install -e .")

    model = model_name or ModelConfig.DEFAULT_LLM_MODEL
    timeout = timeout or ModelConfig.DEFAULT_TIMEOUT
    max_retries = max_retries if max_retries is not None else ModelConfig.MAX_RETRIES

    # Build generation config
    config = {
        "temperature": temperature,
        "max_tokens": max_tokens
    }

    last_error = None
    for attempt in range(max_retries + 1):
        try:
            # NOTE: a fresh client is constructed on every attempt (and every
            # call); via _initialize_secure_client this re-reads the
            # environment/.env each time.
            client = get_llm_client_instance(model, timeout=timeout)

            # SecureLLM client uses generate() method and returns parsed content directly
            response = client.generate(messages, generation_config=config)

            return response.strip() if isinstance(response, str) else str(response).strip()

        except (TimeoutError, requests.exceptions.Timeout, requests.exceptions.ReadTimeout,
                requests.exceptions.ConnectionError) as e:
            # Transient network failures: retry with a linearly growing delay
            # (RETRY_DELAY, 2*RETRY_DELAY, ...).
            last_error = e
            if attempt < max_retries:
                wait_time = ModelConfig.RETRY_DELAY * (attempt + 1)
                logger.warning(f"Request failed (attempt {attempt + 1}/{max_retries + 1}): {type(e).__name__}. "
                               f"Retrying in {wait_time}s...")
                time.sleep(wait_time)
            else:
                logger.error(f"Request failed after {max_retries + 1} attempts: {e}")

        except Exception as e:
            # Non-transient errors (unknown model, auth, provider errors) are
            # not retried; the documented contract is to signal with None.
            logger.error(f"SecureLLM API error: {e}")
            return None

    logger.error(f"All retry attempts exhausted. Last error: {last_error}")
    return None
+
+
def query_llm(
    prompt: str,
    system_message: str = "You are an expert otolaryngologist. Provide a surgical recommendation in the requested JSON format.",
    temperature: float = 0.2,
    max_tokens: int = 500,
    model_name: Optional[str] = None
) -> Optional[str]:
    """
    Send a single system+user exchange to the LLM.

    Thin convenience wrapper around llm_chat for the one-shot query pattern
    used throughout the ENT analysis pipeline.

    Args:
        prompt: The user prompt to send to the model.
        system_message: The system message to set the model's behavior.
        temperature: Sampling temperature. Default 0.2 for consistent medical recommendations.
        max_tokens: Maximum response tokens.
        model_name: Optional model name override.

    Returns:
        str: The LLM's response content, or None if an error occurred.
    """
    conversation = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": prompt},
    ]
    return llm_chat(
        conversation,
        temperature=temperature,
        max_tokens=max_tokens,
        model_name=model_name,
    )
+
+
class SecureLLMClient:
    """
    Drop-in wrapper around SecureLLM that mirrors the call patterns the
    codebase previously used with the OpenAI client.

    Example:
        >>> client = SecureLLMClient()
        >>> response = client.query("What is the diagnosis?")
    """

    def __init__(self, model_name: Optional[str] = None):
        """
        Initialize the SecureLLM client.

        Args:
            model_name: Optional model name. Defaults to ModelConfig.DEFAULT_LLM_MODEL.
        """
        self.model_name = model_name if model_name else ModelConfig.DEFAULT_LLM_MODEL
        self._client = None

    @property
    def client(self):
        """Underlying provider client, created lazily on first access."""
        if self._client is None:
            self._client = get_llm_client_instance(self.model_name)
        return self._client

    def query(
        self,
        prompt: str,
        system_message: str = "You are an expert otolaryngologist. Provide a surgical recommendation in the requested JSON format.",
        temperature: float = 0.2,
        max_tokens: int = 500
    ) -> Optional[str]:
        """
        Ask the LLM a single question.

        Args:
            prompt: The user prompt.
            system_message: System message for context.
            temperature: Sampling temperature.
            max_tokens: Maximum response tokens.

        Returns:
            The LLM response or None on error.
        """
        return query_llm(
            prompt=prompt,
            system_message=system_message,
            temperature=temperature,
            max_tokens=max_tokens,
            model_name=self.model_name,
        )

    def chat(
        self,
        messages: List[Dict[str, str]],
        temperature: float = 0.2,
        max_tokens: int = 500
    ) -> Optional[str]:
        """
        Send a multi-turn conversation to the LLM.

        Args:
            messages: List of message dictionaries with 'role' and 'content'.
            temperature: Sampling temperature.
            max_tokens: Maximum response tokens.

        Returns:
            The LLM response or None on error.
        """
        return llm_chat(
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            model_name=self.model_name,
        )
diff --git a/llm_query/tools/llm/__init__.py b/llm_query/tools/llm/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/llm_query/tools/llm/cache.py b/llm_query/tools/llm/cache.py
new file mode 100644
index 0000000..ed844eb
--- /dev/null
+++ b/llm_query/tools/llm/cache.py
@@ -0,0 +1,148 @@
+"""SQLite-based caching for LLM responses."""
+
+import hashlib
+import json
+import sqlite3
+from contextlib import closing
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+class LLMCache:
+ def __init__(self, cache_dir: str = ".cache"):
+ """Initialize the cache.
+
+ Args:
+ cache_dir: Directory to store the SQLite database
+ """
+ self.cache_dir = Path(cache_dir)
+ self.cache_dir.mkdir(parents=True, exist_ok=True)
+ self.db_path = self.cache_dir / "llm_cache.db"
+ self._init_db()
+
+ def _init_db(self):
+ """Initialize the SQLite database with required tables."""
+ with sqlite3.connect(self.db_path) as conn:
+ conn.execute("""
+ CREATE TABLE IF NOT EXISTS llm_cache (
+ cache_key TEXT PRIMARY KEY,
+ model_name TEXT NOT NULL,
+ system_prompt TEXT,
+ prompt_template TEXT,
+ context_hash TEXT,
+ response TEXT NOT NULL,
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+ )
+ """)
+ # Add index for faster lookups
+ conn.execute("""
+ CREATE INDEX IF NOT EXISTS idx_cache_lookup
+ ON llm_cache(model_name, system_prompt, prompt_template, context_hash)
+ """)
+
+ def _compute_hash(self, text: str) -> str:
+ """Compute SHA-256 hash of text."""
+ return hashlib.sha256(text.encode()).hexdigest()
+
+ def _get_cache_key(self,
+ model_name: str,
+ system_prompt: str,
+ prompt_template: str,
+ context: str) -> str:
+ """Generate a cache key from the input parameters."""
+ components = [
+ model_name,
+ system_prompt,
+ prompt_template,
+ self._compute_hash(context)
+ ]
+ return self._compute_hash("|".join(components))
+
+ def get(self,
+ model_name: str,
+ system_prompt: str,
+ prompt_template: str,
+ context: str) -> Optional[str]:
+ """Get cached response if it exists.
+
+ Args:
+ model_name: Name of the LLM model
+ system_prompt: System prompt used
+ prompt_template: Template used for the prompt
+ context: Context provided to the model
+
+ Returns:
+ Cached response if found, None otherwise
+ """
+ cache_key = self._get_cache_key(
+ model_name, system_prompt, prompt_template, context
+ )
+
+ with sqlite3.connect(self.db_path) as conn:
+ cursor = conn.execute(
+ "SELECT response FROM llm_cache WHERE cache_key = ?",
+ (cache_key,)
+ )
+ result = cursor.fetchone()
+
+ return result[0] if result else None
+
+ def set(self,
+ model_name: str,
+ system_prompt: str,
+ prompt_template: str,
+ context: str,
+ response: str):
+ """Cache a response.
+
+ Args:
+ model_name: Name of the LLM model
+ system_prompt: System prompt used
+ prompt_template: Template used for the prompt
+ context: Context provided to the model
+ response: Response to cache
+ """
+ cache_key = self._get_cache_key(
+ model_name, system_prompt, prompt_template, context
+ )
+ context_hash = self._compute_hash(context)
+
+ with sqlite3.connect(self.db_path) as conn:
+ conn.execute("""
+ INSERT OR REPLACE INTO llm_cache
+ (cache_key, model_name, system_prompt, prompt_template, context_hash, response)
+ VALUES (?, ?, ?, ?, ?, ?)
+ """, (cache_key, model_name, system_prompt, prompt_template, context_hash, response))
+
+ def clear(self):
+ """Clear all cached responses."""
+ with sqlite3.connect(self.db_path) as conn:
+ conn.execute("DELETE FROM llm_cache")
+
+ def get_stats(self) -> Dict[str, Any]:
+ """Get cache statistics.
+
+ Returns:
+ Dictionary containing cache statistics
+ """
+ with sqlite3.connect(self.db_path) as conn:
+ cursor = conn.execute("SELECT COUNT(*) FROM llm_cache")
+ total_entries = cursor.fetchone()[0]
+
+ cursor = conn.execute("""
+ SELECT model_name, COUNT(*) as count
+ FROM llm_cache
+ GROUP BY model_name
+ """)
+ model_counts = dict(cursor.fetchall())
+
+ cursor = conn.execute("""
+ SELECT MIN(created_at) as oldest, MAX(created_at) as newest
+ FROM llm_cache
+ """)
+ oldest, newest = cursor.fetchone()
+
+ return {
+ "total_entries": total_entries,
+ "model_counts": model_counts,
+ "oldest_entry": oldest,
+ "newest_entry": newest
+ }
\ No newline at end of file
diff --git a/llm_query/tools/llm/secure_llm_client.py b/llm_query/tools/llm/secure_llm_client.py
new file mode 100644
index 0000000..0d610df
--- /dev/null
+++ b/llm_query/tools/llm/secure_llm_client.py
@@ -0,0 +1,150 @@
+# llm_query/tools/llm/secure_llm_client.py
+
+"""
+Secure LLM Client Interface
+
+This module provides the secure-llm Client interface for the ent-llm project.
+Uses secure-llm's native API directly.
+"""
+
+import logging
+from typing import Dict, Any, List, Optional, Union
+
+logger = logging.getLogger(__name__)
+
+try:
+ from securellm import Client, get_available_models as _securellm_get_available_models
+ _SECURELLM_AVAILABLE = True
+except ImportError:
+ logger.error("securellm package not found. Install with: pip install secure-llm")
+ Client = None
+ _securellm_get_available_models = None
+ _SECURELLM_AVAILABLE = False
+
+
+# Global client instance (initialized once, reused)
+_global_client = None
+
+
+def get_client() -> Client:
+    """
+    Get or create the global secure-llm Client instance.
+
+    Returns:
+        secure-llm Client instance
+
+    Raises:
+        ImportError: If secure-llm is not installed
+    """
+    if not _SECURELLM_AVAILABLE:
+        raise ImportError("securellm package not installed. Install with: pip install secure-llm")
+
+    # NOTE(review): this check-then-create is not guarded by a lock, so two
+    # threads could each build a Client on first use — confirm initialization
+    # happens on a single thread if that matters.
+    global _global_client
+    if _global_client is None:
+        _global_client = Client()
+        logger.info("Initialized secure-llm Client")
+
+    return _global_client
+
+def get_llm_client(model_name: Optional[str] = None):
+ """
+ Get secure-llm Client instance.
+
+ Args:
+ model_name: Optional model identifier (for logging/compatibility).
+ Note: Model is specified per-request, not per-client.
+
+ Returns:
+ secure-llm Client instance
+ """
+ client = get_client()
+ if model_name:
+ logger.debug(f"Client ready for model: {model_name}")
+ return client
+
+
+def extract_response_content(response: Union[Dict[str, Any], Any]) -> str:
+ """
+ Extract content from secure-llm response.
+
+ Args:
+ response: Response from client.chat.completions.create()
+
+ Returns:
+ Response content text
+
+ Raises:
+ ValueError: If response format is unexpected
+ """
+ # secure-llm returns dict format: response["choices"][0]["message"]["content"]
+ if isinstance(response, dict):
+ choices = response.get("choices", [])
+ if choices:
+ message = choices[0].get("message", {})
+ content = message.get("content", "")
+ if content:
+ return content
+
+ # Handle object-style response as fallback
+ try:
+ if hasattr(response, "choices") and response.choices:
+ message = response.choices[0].message
+ if hasattr(message, "content"):
+ return message.content
+ # Try dict access on object
+ return response.choices[0]["message"]["content"]
+ except (AttributeError, KeyError, IndexError):
+ pass
+
+ # Try direct dict access on object
+ try:
+ return response["choices"][0]["message"]["content"]
+ except (KeyError, IndexError, TypeError):
+ pass
+
+ raise ValueError(f"Unexpected response format: {type(response)}")
+
+
+def get_available_models() -> List[str]:
+    """
+    Get list of available models from secure-llm.
+
+    Returns:
+        List of model identifiers from secure-llm's model registry, or an
+        empty list when securellm is unavailable or the lookup fails.
+    """
+    if not _SECURELLM_AVAILABLE or _securellm_get_available_models is None:
+        logger.warning("securellm not available, returning empty model list")
+        return []
+
+    try:
+        # Use secure-llm's built-in function to get models from registry
+        models = _securellm_get_available_models()
+        logger.debug(f"Retrieved {len(models)} models from secure-llm")
+        return models
+    # Broad catch is deliberate: model discovery is best-effort and callers
+    # can proceed with an empty list.
+    except Exception as e:
+        logger.error(f"Error getting available models from secure-llm: {e}")
+        return []
+
+
+def get_default_generation_config(overrides: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
+ """
+ Get default generation configuration.
+
+ Args:
+ overrides: Dict of parameters to override defaults
+
+ Returns:
+ Generation config dict
+ """
+ defaults = {
+ "temperature": 0.7,
+ "top_p": 1.0,
+ "max_tokens": 2048,
+ }
+
+ if overrides:
+ defaults.update(overrides)
+
+ return defaults
+
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..c0137c4
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,50 @@
+[project]
+name = "ent-llm"
+version = "0.1.0"
+description = "LLM evaluation of ENT clinical cases"
+readme = "README.md"
+requires-python = ">=3.9"
+dependencies = [
+    "pandas>=2.0.0",
+    "numpy>=1.24.0",
+    "tqdm>=4.65.0",
+    "google-cloud-bigquery>=3.10.0",
+    "db-dtypes>=1.1.0",
+    "google-cloud-aiplatform>=1.25.0",
+    "matplotlib>=3.7.0",
+    "seaborn>=0.12.0",
+    "scikit-learn>=1.2.0",
+    "scipy>=1.10.0",
+    "reportlab>=4.0.0",
+    "aiohttp>=3.8.0",
+    "python-dotenv>=1.0.0",
+    "securellm @ git+ssh://git@github.com/VISTA-Stanford/secure-llm.git",
+]
+
+[project.optional-dependencies]
+dev = [
+    "pytest>=7.0.0",
+    "black>=23.0.0",
+    "ruff>=0.0.270",
+]
+
+[project.scripts]
+ent-llm = "cli:main"
+ent-llm-extract = "cli_extract:main"
+ent-llm-ablation = "cli_ablation:main"
+
+[build-system]
+requires = ["setuptools>=61.0", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[tool.setuptools.packages.find]
+include = ["batch_query*", "data_extraction*", "evaluation*", "llm_query*", "training*"]
+
+[tool.black]
+line-length = 100
+target-version = ["py39", "py310", "py311"]
+
+[tool.ruff]
+line-length = 100
+
+# Lint rule selection belongs under [tool.ruff.lint]; top-level `select` /
+# `ignore` keys are deprecated in current Ruff releases.
+[tool.ruff.lint]
+select = ["E", "F", "I"]
+ignore = ["E501"]
diff --git a/scripts/merge_files.py b/scripts/merge_files.py
new file mode 100644
index 0000000..1df5baf
--- /dev/null
+++ b/scripts/merge_files.py
@@ -0,0 +1,88 @@
+import pandas as pd
+
+
+def merge_csv_files(file_list, output_file="merged_file.csv"):
+ """
+ Merges multiple CSV files into a single sorted CSV file.
+
+ Parameters:
+ file_list (list of str): List of file paths to the CSV files to be merged.
+ output_file (str): Path to the output CSV file (default: merged_file.csv).
+ """
+ # Initialize an empty list to hold dataframes
+ dataframes = []
+
+ # Read each CSV file and append the dataframe to the list
+ for file in file_list:
+ try:
+ df = pd.read_csv(file)
+ dataframes.append(df)
+ print(f"Loaded: {file} ({len(df)} rows)")
+ except Exception as e:
+ print(f"Error loading {file}: {e}")
+ continue
+
+ if not dataframes:
+ print("No dataframes to merge!")
+ return
+
+ # Concatenate all dataframes into a single dataframe
+ merged_df = pd.concat(dataframes, ignore_index=True)
+ print(f"Total merged rows: {len(merged_df)}")
+
+ # Remove duplicates by llm_caseID
+ if "llm_caseID" in merged_df.columns:
+ duplicate_count = merged_df.duplicated(subset=["llm_caseID"], keep=False).sum()
+
+ if duplicate_count > 0 and "decision" in merged_df.columns:
+ # Custom deduplication: keep row with non-NaN decision, if all NaN keep first
+ # Create a sort key: prioritize non-NaN decision values
+ merged_df["_has_decision"] = merged_df["decision"].notna().astype(int)
+ merged_df = merged_df.sort_values(by=["llm_caseID", "_has_decision"], ascending=[True, False])
+ merged_df = merged_df.drop_duplicates(subset=["llm_caseID"], keep="first")
+ merged_df = merged_df.drop(columns=["_has_decision"])
+ print(f"Removed duplicate entries (kept row with non-NaN decision, or first if all NaN)")
+ elif duplicate_count > 0:
+ # If decision column doesn't exist, keep first occurrence
+ merged_df = merged_df.drop_duplicates(subset=["llm_caseID"], keep="first")
+ print(f"Removed duplicate entries (decision column not found, kept first occurrence)")
+
+ print(f"Rows after deduplication: {len(merged_df)}")
+ else:
+ print("Warning: llm_caseID column not found. Skipping deduplication.")
+
+ # Sort by llm_caseID if it exists
+ if "llm_caseID" in merged_df.columns:
+ merged_df = merged_df.sort_values(by="llm_caseID").reset_index(drop=True)
+ print(f"Sorted by llm_caseID")
+ else:
+ print("Warning: llm_caseID column not found. Saving without sorting.")
+
+ # Report NaN decision values in final dataframe
+ if "decision" in merged_df.columns:
+ nan_count = merged_df["decision"].isna().sum()
+ non_nan_count = merged_df["decision"].notna().sum()
+ print(f"Final decision column: {non_nan_count} non-NaN, {nan_count} NaN ({nan_count/len(merged_df)*100:.1f}%)")
+
+ # Report llm_caseID values with NaN decisions
+ if nan_count > 0:
+ nan_case_ids = merged_df[merged_df["decision"].isna()]["llm_caseID"].tolist()
+ print(f"\nllm_caseID values with NaN decision:")
+ for case_id in nan_case_ids:
+ print(f" {case_id}")
+
+ # Save the merged dataframe to a new CSV file
+ merged_df.to_csv(output_file, index=False)
+ print(f"Merged data saved to: {output_file}")
+
+
+def main():
+    """Example driver: merge two result CSVs into one deduplicated file.
+
+    NOTE(review): the paths below are placeholders; if they do not exist
+    relative to the working directory, merge_csv_files only prints load
+    errors and exits without writing output.
+    """
+    # Example usage
+    files_to_merge = [
+        "data/results/file1.csv",
+        "data/results/file2.csv",
+    ]
+    merge_csv_files(files_to_merge, output_file="data/results/merged_results.csv")
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/training/training_LLM.py b/training/training_LLM.py
index 50f2310..da9b265 100644
--- a/training/training_LLM.py
+++ b/training/training_LLM.py
@@ -1,8 +1,12 @@
# Formatting for LLM Training Input
import pandas as pd
import logging
+import json
+import time
from typing import Dict, Any, Union, List, Tuple, Optional
+from llm_query.securellm_adapter import query_llm, llm_chat, SecureLLMClient
+
def format_medical_data(progress_note: Union[Dict, None], radiology_reports: List[Dict]) -> Dict[str, Any]:
"""Format medical data from note dictionary and reports list into readable text."""
@@ -112,24 +116,25 @@ def training_create_llm_dataframe(processed_df: pd.DataFrame, num_training_rows:
return training_df, test_df
-def query_openai(prompt: str, client) -> str:
- """Query GPT-4 for surgical decision based on input prompt."""
- try:
- response = client.chat.completions.create(
- model="gpt-4",
- messages=[
- {"role": "system", "content": (
- "You are an expert otolaryngologist. "
- "Provide a surgical recommendation in the requested JSON format."
- )},
- {"role": "user", "content": prompt}
- ],
- temperature=0.2
- )
- return response.choices[0].message.content
- except Exception as e:
- logging.error(f"OpenAI API error: {e}")
- return None
+def query_openai(prompt: str, client=None) -> Optional[str]:
+    """
+    Query the LLM for surgical decision based on input prompt.
+
+    This function now uses SecureLLM instead of direct OpenAI calls.
+    The client parameter is kept for backward compatibility but is ignored.
+
+    Args:
+        prompt: The prompt to send to the LLM.
+        client: Deprecated. Kept for backward compatibility.
+
+    Returns:
+        The LLM response content or None on error.
+    """
+    # NOTE(review): unlike SecureLLMClient.query (which caps max_tokens at
+    # 500), no max_tokens is passed here — confirm query_llm's default fits.
+    return query_llm(
+        prompt=prompt,
+        system_message="You are an expert otolaryngologist. Provide a surgical recommendation in the requested JSON format.",
+        temperature=0.2
+    )
def generate_training_examples(sample_cases: pd.DataFrame) -> str:
"""Generate training examples from sample cases."""
@@ -251,13 +256,13 @@ def parse_llm_response(response: str) -> Dict[str, Any]:
logging.error(f"Unexpected error parsing response: {e}")
return default_response
-def process_llm_cases(llm_df: pd.DataFrame, api_key: str, delay_seconds: float = 1.0) -> pd.DataFrame:
+def process_llm_cases(llm_df: pd.DataFrame, api_key: str = None, delay_seconds: float = 1.0) -> pd.DataFrame:
"""
- Process a clean LLM DataFrame through OpenAI API.
+ Process a clean LLM DataFrame through SecureLLM API.
Args:
llm_df: DataFrame with columns 'llm_caseID', 'formatted_progress_text', 'formatted_radiology_text'
- api_key: OpenAI API key (hardcoded)
+ api_key: Deprecated. Kept for backward compatibility. SecureLLM uses VAULT_SECRET_KEY.
delay_seconds: Delay between API calls to avoid rate limiting
Returns:
@@ -267,12 +272,12 @@ def process_llm_cases(llm_df: pd.DataFrame, api_key: str, delay_seconds: float =
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
- # Initialize OpenAI client
+ # Initialize SecureLLM client
try:
- client = openai.OpenAI(api_key=api_key)
- logging.info("OpenAI client initialized successfully")
+ client = SecureLLMClient()
+ logging.info("SecureLLM client initialized successfully")
except Exception as e:
- logging.error(f"Failed to initialize OpenAI client: {e}")
+ logging.error(f"Failed to initialize SecureLLM client: {e}")
raise
# Create a copy of the dataframe
@@ -326,13 +331,13 @@ def process_llm_cases(llm_df: pd.DataFrame, api_key: str, delay_seconds: float =
return result_df
-def run_llm_analysis(llm_df, api_key):
+def run_llm_analysis(llm_df, api_key: str = None):
"""
Main function to run the LLM analysis on your DataFrame.
Args:
llm_df: DataFrame with columns 'llm_caseID', 'formatted_progress_text', 'formatted_radiology_text'
- api_key: Your OpenAI API key
+ api_key: Deprecated. Kept for backward compatibility. SecureLLM uses VAULT_SECRET_KEY.
Returns:
DataFrame with LLM analysis results
@@ -342,7 +347,7 @@ def run_llm_analysis(llm_df, api_key):
print(f"DataFrame columns: {list(llm_df.columns)}")
# Process the cases
- results_df = process_llm_cases(llm_df, api_key, delay_seconds=1.0)
+ results_df = process_llm_cases(llm_df, delay_seconds=1.0)
# Show summary
total_cases = len(results_df)
@@ -360,37 +365,36 @@ def run_llm_analysis(llm_df, api_key):
return results_df
-import pandas as pd
-import logging
-import json
-import time
-import openai
-from typing import Dict, Any, Union, List, Tuple, Optional
-
-
class ConversationalLLMAnalyzer:
"""
LLM analyzer that maintains conversation context to avoid repeating training examples.
+ Now uses SecureLLM instead of direct OpenAI calls.
"""
- def __init__(self, api_key: str, model: str = "gpt-4"):
- self.client = openai.OpenAI(api_key=api_key)
+ def __init__(self, api_key: str = None, model: str = "gpt-4o"):
+ """
+ Initialize the ConversationalLLMAnalyzer.
+
+ Args:
+ api_key: Deprecated. Kept for backward compatibility. SecureLLM uses VAULT_SECRET_KEY.
+ model: Model name to use. Defaults to gpt-4o.
+ """
self.model = model
self.conversation_history = []
self.training_loaded = False
def _make_api_call(self, messages: List[Dict], max_tokens: int = 500) -> str:
- """Make API call with error handling."""
+ """Make API call with error handling using SecureLLM."""
try:
- response = self.client.chat.completions.create(
- model=self.model,
+ response = llm_chat(
messages=messages,
temperature=0.2,
- max_tokens=max_tokens
+ max_tokens=max_tokens,
+ model_name=self.model
)
- return response.choices[0].message.content
+ return response
except Exception as e:
- logging.error(f"OpenAI API error: {e}")
+ logging.error(f"SecureLLM API error: {e}")
return None
def load_training_examples(self, training_df: pd.DataFrame) -> bool:
@@ -559,14 +563,14 @@ def parse_llm_response(self, response: str) -> Dict[str, Any]:
def process_llm_cases_conversational(test_df: pd.DataFrame, training_df: pd.DataFrame,
- api_key: str, delay_seconds: float = 1.0) -> pd.DataFrame:
+ api_key: str = None, delay_seconds: float = 1.0) -> pd.DataFrame:
"""
Process cases using conversational context approach.
Args:
test_df: DataFrame with test cases
training_df: DataFrame with training examples
- api_key: OpenAI API key
+ api_key: Deprecated. Kept for backward compatibility. SecureLLM uses VAULT_SECRET_KEY.
delay_seconds: Delay between API calls
Returns:
@@ -576,8 +580,8 @@ def process_llm_cases_conversational(test_df: pd.DataFrame, training_df: pd.Data
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
- # Initialize analyzer
- analyzer = ConversationalLLMAnalyzer(api_key)
+ # Initialize analyzer with SecureLLM
+ analyzer = ConversationalLLMAnalyzer()
# Load training examples
logging.info(f"Loading {len(training_df)} training examples...")
@@ -630,14 +634,14 @@ def process_llm_cases_conversational(test_df: pd.DataFrame, training_df: pd.Data
return result_df
-def run_llm_analysis_training(test_df: pd.DataFrame, training_df: pd.DataFrame, api_key: str):
+def run_llm_analysis_training(test_df: pd.DataFrame, training_df: pd.DataFrame, api_key: str = None):
"""
Main function to run conversational LLM analysis.
Args:
test_df: DataFrame with test cases
training_df: DataFrame with training examples
- api_key: OpenAI API key
+ api_key: Deprecated. Kept for backward compatibility. SecureLLM uses VAULT_SECRET_KEY.
Returns:
DataFrame with LLM analysis results
@@ -654,8 +658,8 @@ def run_llm_analysis_training(test_df: pd.DataFrame, training_df: pd.DataFrame,
if missing_cols:
raise ValueError(f"Training DataFrame missing required columns: {missing_cols}")
- # Process cases
- results_df = process_llm_cases_conversational(test_df, training_df, api_key, delay_seconds=1.0)
+ # Process cases with SecureLLM
+ results_df = process_llm_cases_conversational(test_df, training_df, delay_seconds=1.0)
# Summary
total_cases = len(results_df)
@@ -673,4 +677,5 @@ def run_llm_analysis_training(test_df: pd.DataFrame, training_df: pd.DataFrame,
return results_df
-results = run_llm_analysis_training(test_df, training_df, api_key)
\ No newline at end of file
+# Example usage (uncomment to run):
+# results = run_llm_analysis_training(test_df, training_df)
\ No newline at end of file