diff --git a/.gitignore b/.gitignore index f2ff7f1..123b1cc 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ data/paper data models wandb +wandb_logs __pycache__/ *.pyc diff --git a/processed_sources.jsonl b/processed_sources.jsonl index c6e8d0e..0ab1d9e 100644 --- a/processed_sources.jsonl +++ b/processed_sources.jsonl @@ -1106,7 +1106,7 @@ {"url": "https://oig.hhs.gov/oei/reports/oei-02-16-00570.pdf", "date_accessed": "2024-01-25", "local_path": "./data/raw/hhs_oig/oei-02-16-00570.pdf", "tags": ["hhs-oig", "regulatory-guidance"], "preprocessor": "pdf", "md5": "ea67cf9aa08f8eeb92f43178a0259de8", "local_processed_path": "./data/processed/hhs_oig/oei-02-16-00570.jsonl"} {"url": "https://oig.hhs.gov/oas/reports/region9/91602042.pdf", "date_accessed": "2024-01-25", "local_path": "./data/raw/hhs_oig/91602042.pdf", "tags": ["hhs-oig", "regulatory-guidance"], "preprocessor": "pdf", "md5": "0956b057e0ada92e40dcc7bf02bb5cf2", "local_processed_path": "./data/processed/hhs_oig/91602042.jsonl"} {"url": "https://oig.hhs.gov/oei/reports/oei-03-15-00180.pdf", "date_accessed": "2024-01-25", "local_path": "./data/raw/hhs_oig/oei-03-15-00180.pdf", "tags": ["hhs-oig", "regulatory-guidance"], "preprocessor": "pdf", "md5": "fa56c5c8d598b139873e22a08131b69b", "local_processed_path": "./data/processed/hhs_oig/oei-03-15-00180.jsonl"} -{"url": "https://oig.hhs.gov/reports-and-publications/portfolio/portfolio-12-12-01.pdf", "date_accessed": "2024-01-25", "local_path": "./data/raw/hhs_oig/portfolio-12-12-01.pdf", "tags": ["hhs-oig", "regulatory-guidance"], "preprocessor": "pdf", "md5": "a0b1f3c770899d1ed57ced5a477e9e94", "local_processed_path": "./data/processed/hhs_oig/portfolio-12-12-01.jsonl"} +{"url": "https://oig.hhs.gov/reports-and-publications/portfolio/portfolio-12-12-01.pdf", "date_accessed": "2024-01-25", "local_path": "./data/raw/hhs_oig/portfolio-12-12-01.pdf", "tags": ["hhs-oig"], "preprocessor": "pdf", "md5": "a0b1f3c770899d1ed57ced5a477e9e94", "local_processed_path": "./data/processed/hhs_oig/portfolio-12-12-01.jsonl"} {"url": "https://oig.hhs.gov/documents/testimony/1119/bliss-testimony-05182023.pdf", "date_accessed": "2024-01-25", "local_path": "./data/raw/hhs_oig/bliss-testimony-05182023.pdf", "tags": ["hhs-oig", "testimony", "opinion-policy-summary"], "preprocessor": "pdf", "md5": "417d6519b93105b6a371602110bfe33c", "local_processed_path": "./data/processed/hhs_oig/bliss-testimony-05182023.jsonl"} {"url": "https://oig.hhs.gov/documents/testimony/1118/megan-tinker-testimony-05172023.pdf", "date_accessed": "2024-01-25", "local_path": "./data/raw/hhs_oig/megan-tinker-testimony-05172023.pdf", "tags": ["hhs-oig", "testimony", "opinion-policy-summary"], "preprocessor": "pdf", "md5": "e562fcf65a825e1e5d72ed55a930b993", "local_processed_path": "./data/processed/hhs_oig/megan-tinker-testimony-05172023.jsonl"} {"url": "https://oig.hhs.gov/documents/testimony/1112/christi-grimm-testimony-04182023.pdf", "date_accessed": "2024-01-25", "local_path": "./data/raw/hhs_oig/christi-grimm-testimony-04182023.pdf", "tags": ["hhs-oig", "testimony", "opinion-policy-summary"], "preprocessor": "pdf", "md5": "3d2f8e1e1fea7a9113ade3287acb7884", "local_processed_path": "./data/processed/hhs_oig/christi-grimm-testimony-04182023.jsonl"} diff --git a/requirements.txt b/requirements.txt index 9cbbf5d..4c83c7a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,7 +14,7 @@ pre-commit==3.7.1 RapidFuzz==3.13.0 requests==2.32.3 torch==2.6.0 -transformers==4.48.3 +transformers==4.51.3 evaluate scikit-learn==1.5.0 selenium==4.17.2 diff --git a/src/modeling/background_extraction.py b/src/modeling/background_extraction.py index 542d590..69f89a5 100644 --- a/src/modeling/background_extraction.py +++ b/src/modeling/background_extraction.py @@ -3,6 +3,7 @@ from multiprocessing import Pool, cpu_count import torch +from datasets import Dataset from rapidfuzz import fuzz from sklearn.model_selection import train_test_split from transformers import ( @@ -11,6 +12,11 @@ AutoTokenizer, ) +from src.modeling.data_augmentation import ( + augment_sufficient_examples, + generate_unrelated_content, + rewrite_to_generic, +) from src.util import add_jsonl_lines, batcher, get_records_list @@ -210,6 +216,8 @@ def construct_recs(raw_recs: list[dict], tokenizer: AutoTokenizer, model: torch. "decision": rec["decision"], "appeal_type": rec["appeal_type"], "full_text": rec["text"], + "jurisdiction": rec.get("jurisdiction", "Unspecified"), + "insurance_type": rec.get("insurance_type", "Unspecified"), } backgrounds.append(updated_record) idx += 1 @@ -241,6 +249,8 @@ def construct_recs_batch( "appeal_type": rec["appeal_type"], "full_text": rec["text"], "sufficiency_id": sufficiency_id, + "jurisdiction": rec.get("jurisdiction", "Unspecified"), + "insurance_type": rec.get("insurance_type", "Unspecified"), } for (rec, background, sufficiency_id) in zip(batch, background_batch, sufficiency_batch) ] @@ -319,6 +329,217 @@ def combine_jsonl(directory) -> None: return None +def create_augmented_examples( + records: list[dict], + api_type: str = "openai", + api_url: str = "https://api.openai.com/v1/chat/completions", + api_key: str | None = None, + model_name: str = "gpt-4o", + generic_rewrites_per_example: int = 1, + sufficient_augmentations_per_example: int = 1, + num_unrelated_examples: int = 100, + api_call_limit: int | None = 700, + seed: int = 42, +) -> list[dict]: + """ + Create augmented examples from background texts with optional limit on API calls. + + Args: + records: List of record dictionaries with text, decision, etc. + api_type: Type of API to use ("openai" or "llamacpp") + api_url: URL for the API endpoint + model_name: Name of the model to use + generic_rewrites_per_example: Number of generic rewrites per sufficient example + sufficient_augmentations_per_example: Number of sufficient augmentations per sufficient example + num_unrelated_examples: Number of unrelated content examples to generate + api_call_limit: Maximum number of API calls to make (if None, no limit) + seed: Random seed for reproducibility + + Returns: + List of augmented record dictionaries + """ + import random + + random.seed(seed) + + print(f"Creating augmented examples using {api_type} API at {api_url}...") + augmented_records = [] + + # Track API calls + api_calls_made = 0 + + # 1. Create dataset for augmentation functions + dataset_records = [] + for rec in records: + # Check if record has sufficiency_id + sufficiency_score = 4 # Default to sufficient for original records + if "sufficiency_id" in rec: + sufficiency_score = 4 if rec["sufficiency_id"] == 1 else 2 + + dataset_records.append({"text": rec["text"], "sufficiency_score": sufficiency_score}) + + dataset = Dataset.from_list(dataset_records) + + # If we have a limit, we need to randomly select which examples to augment + sufficient_examples = [ex for ex in dataset_records if ex["sufficiency_score"] >= 3] + num_sufficient = len(sufficient_examples) + + # Calculate how many API calls each technique would require + total_generic_calls = num_sufficient * generic_rewrites_per_example + total_sufficient_calls = num_sufficient * sufficient_augmentations_per_example + total_unrelated_calls = num_unrelated_examples + + # Calculate total potential API calls + total_potential_calls = total_generic_calls + total_sufficient_calls + total_unrelated_calls + + # If we have a limit and it's less than the potential total, adjust + if api_call_limit is not None and api_call_limit < total_potential_calls: + print(f"API call limit ({api_call_limit}) is less than potential total ({total_potential_calls})") + + # Distribute the limit evenly across the three techniques + calls_per_technique = api_call_limit // 3 + + # Calculate adjusted calls for each technique + adjusted_generic_calls = calls_per_technique + adjusted_sufficient_calls = calls_per_technique + adjusted_unrelated_calls = ( + api_call_limit - adjusted_generic_calls - adjusted_sufficient_calls + ) # Use remainder for unrelated + + # Calculate how many examples we can augment for each technique + examples_for_generic = adjusted_generic_calls // generic_rewrites_per_example + examples_for_sufficient = adjusted_sufficient_calls // sufficient_augmentations_per_example + + # Randomly select examples to augment + if examples_for_generic < num_sufficient: + examples_generic = random.sample(sufficient_examples, examples_for_generic) + else: + examples_generic = sufficient_examples + + if examples_for_sufficient < num_sufficient: + examples_sufficient = random.sample(sufficient_examples, examples_for_sufficient) + else: + examples_sufficient = sufficient_examples + + # Adjust unrelated examples count + adjusted_num_unrelated = adjusted_unrelated_calls + + print("Adjusted numbers based on API call limit:") + print(f" Generic rewrites: {examples_for_generic} examples ({adjusted_generic_calls} calls)") + print(f" Sufficient augmentations: {examples_for_sufficient} examples ({adjusted_sufficient_calls} calls)") + print(f" Unrelated content: {adjusted_num_unrelated} examples ({adjusted_unrelated_calls} calls)") + + # Create filtered datasets for limited augmentation + generic_dataset = Dataset.from_list(examples_generic) + sufficient_dataset = Dataset.from_list(examples_sufficient) + + # Update parameters + generic_params_count = adjusted_generic_calls + sufficient_params_count = adjusted_sufficient_calls + unrelated_params_count = adjusted_unrelated_calls + else: + # No limit or limit is high enough, use all examples + generic_dataset = dataset + sufficient_dataset = dataset + generic_params_count = total_generic_calls + sufficient_params_count = total_sufficient_calls + unrelated_params_count = num_unrelated_examples + + # 2. Apply generic rewrite augmentation (make sufficient examples insufficient) + if generic_params_count > 0: + print("Generating generic rewrites...") + generic_rewrite_params = { + "num_augmentations_per_example": generic_rewrites_per_example, + "api_type": api_type, + "api_url": api_url, + "api_key": api_key, + "model_name": model_name, + "seed": seed, + } + + generic_examples = rewrite_to_generic(generic_dataset, **generic_rewrite_params) + api_calls_made += len(generic_examples) + print(f"Generated {len(generic_examples)} examples through generic rewriting") + else: + generic_examples = [] + print("Skipping generic rewrites due to API call limit") + + # 3. Apply sufficient example augmentation (keep sufficient examples sufficient) + if sufficient_params_count > 0: + print("Generating sufficient augmentations...") + sufficient_augmentation_params = { + "num_augmentations_per_example": sufficient_augmentations_per_example, + "api_type": api_type, + "api_url": api_url, + "api_key": api_key, + "model_name": model_name, + "seed": seed, + } + + sufficient_examples = augment_sufficient_examples(sufficient_dataset, **sufficient_augmentation_params) + api_calls_made += len(sufficient_examples) + print(f"Generated {len(sufficient_examples)} augmented sufficient examples") + else: + sufficient_examples = [] + print("Skipping sufficient augmentations due to API call limit") + + # 4. Generate unrelated content + if unrelated_params_count > 0: + print("Generating unrelated content...") + unrelated_params = { + "num_examples": unrelated_params_count, + "api_type": api_type, + "api_url": api_url, + "api_key": api_key, + "model_name": model_name, + "seed": seed, + } + + unrelated_examples = generate_unrelated_content(**unrelated_params) + api_calls_made += len(unrelated_examples) + print(f"Generated {len(unrelated_examples)} examples with unrelated content") + else: + unrelated_examples = [] + print("Skipping unrelated content due to API call limit") + + # 5. Convert augmented texts back to record format + # For generic rewrites and sufficient augmentations, we need to find the original record + all_records_dict = {rec["text"]: rec for rec in records} + + for aug_example in generic_examples + sufficient_examples: + source_text = aug_example["source_text"] + # Find the original record + original_rec = all_records_dict.get(source_text) + if original_rec: + augmented_rec = original_rec.copy() + augmented_rec["text"] = aug_example["text"] + augmented_rec["sufficiency_id"] = 1 if aug_example["sufficiency_score"] >= 3 else 0 + augmented_rec["augmentation_type"] = aug_example["augmentation_type"] + augmented_records.append(augmented_rec) + + # For unrelated content, create new records with random decision label + decisions = ["Upheld", "Overturned"] + appeal_types = ["IMR", "DMHC", "CDI"] # Example appeal types + + for i, aug_example in enumerate(unrelated_examples): + augmented_rec = { + "text": aug_example["text"], + "decision": random.choice(decisions), # Random decision since unrelated to actual cases + "appeal_type": random.choice(appeal_types), + "full_text": aug_example["text"], # No original full text + "sufficiency_id": 0, # Always insufficient + "jurisdiction": "Unspecified", + "insurance_type": "Unspecified", + "augmentation_type": aug_example["augmentation_type"], + "id": f"unrelated_{i}", + } + augmented_records.append(augmented_rec) + + print(f"Total API calls made: {api_calls_made}") + print(f"Total augmented records created: {len(augmented_records)}") + return augmented_records + + if __name__ == "__main__": # Config applied to both models device = "cuda" @@ -361,7 +582,7 @@ def combine_jsonl(directory) -> None: # Combine CA CDI files into single jsonl combine_jsonl("./data/processed/ca_cdi/summaries") - # Define Outcome Map / preprocessing + # Define Outcome Map / preprocessing and jurisdiction mapping # TODO: decide if this standardization belongs here, or in raw dataset storage (i.e in records we read here) extraction_targets = [ ( @@ -394,6 +615,11 @@ def combine_jsonl(directory) -> None: # We will also split a train and test set for consistent experiments train_out_path = "./data/outcomes/train_backgrounds_suff.jsonl" test_out_path = "./data/outcomes/test_backgrounds_suff.jsonl" + + # New paths for augmented datasets + train_augmented_out_path = "./data/outcomes/train_backgrounds_suff_augmented.jsonl" + test_augmented_out_path = "./data/outcomes/test_backgrounds_suff_augmented.jsonl" + train_subset = [] test_subset = [] for path, outcome_map in extraction_targets: @@ -433,5 +659,76 @@ def combine_jsonl(directory) -> None: train_subset.extend(train_recs) test_subset.extend(val_recs) + # Write original train and test sets add_jsonl_lines(train_out_path, train_subset) add_jsonl_lines(test_out_path, test_subset) + + print(f"Original train set: {len(train_subset)} examples") + print(f"Original test set: {len(test_subset)} examples") + + # Check for OpenAI API key in environment variable + api_key = os.environ.get("OPENAI_API_KEY") + api_type = "openai" if api_key else "llamacpp" + api_url = "https://api.openai.com/v1/chat/completions" if api_key else "http://localhost:8080/completion" + model_name = "gpt-4o" if api_key else "llama-3.1" + + print(f"Using {api_type} API for augmentation") + + train_subset = get_records_list(train_out_path) + test_subset = get_records_list(test_out_path) + + # Create augmented examples for train set + print("\n=== Creating Augmented Examples for Train Set ===") + train_augmented = create_augmented_examples( + train_subset, + api_type=api_type, + api_url=api_url, + api_key=api_key, + model_name=model_name, + generic_rewrites_per_example=1, + sufficient_augmentations_per_example=1, + num_unrelated_examples=1000, + api_call_limit=6000, + seed=42, + ) + + # Write augmented train and test sets + # Combine original and augmented examples + train_combined = train_subset + train_augmented + add_jsonl_lines(train_augmented_out_path, train_combined) + + # Create augmented examples for test set + print("\n=== Creating Augmented Examples for Test Set ===") + test_augmented = create_augmented_examples( + test_subset, + api_type=api_type, + api_url=api_url, + api_key=api_key, + model_name=model_name, + generic_rewrites_per_example=1, # Fewer augmentations for test set + sufficient_augmentations_per_example=1, + num_unrelated_examples=1000, + api_call_limit=6000, + seed=43, # Different seed for test set + ) + + # Combine original and augmented examples + test_combined = test_subset + test_augmented + + # Write augmented test set + add_jsonl_lines(test_augmented_out_path, test_combined) + + # print(f"\nFinal augmented train set: {len(train_combined)} examples") + print(f"Final augmented test set: {len(test_combined)} examples") + + # Print augmentation statistics + train_augmented_types = {} + for rec in train_augmented: + aug_type = rec.get("augmentation_type", "unknown") + if aug_type not in train_augmented_types: + train_augmented_types[aug_type] = 0 + train_augmented_types[aug_type] += 1 + + print("\nTrain set augmentation breakdown:") + for aug_type, count in train_augmented_types.items(): + print(f" {aug_type}: {count}") diff --git a/src/modeling/config/outcome_prediction/clinicalbert.yaml b/src/modeling/config/outcome_prediction/clinicalbert.yaml index 1388b3b..3b538c7 100644 --- a/src/modeling/config/outcome_prediction/clinicalbert.yaml +++ b/src/modeling/config/outcome_prediction/clinicalbert.yaml @@ -8,17 +8,17 @@ hicric_pretrained: False base_model_name: "clinicalbert" pretrained_model_dir: "None" pretrained_hf_model_key: "medicalai/ClinicalBERT" -train_data_path: "./data/outcomes/train_backgrounds_suff.jsonl" +train_data_path: "./data/outcomes/train_backgrounds_suff_augmented.jsonl" # Training settings learning_rate: 8.0e-7 weight_decay: 0.01 num_epochs: 40 -batch_size: 20 +batch_size: 32 dtype: "float16" # 'float32','float16' for training dtype compile: True # Whether to use torch compile # Test eval settings test_data_path: "./data/outcomes/test_backgrounds_suff.jsonl" -checkpoint_name: "checkpoint-46134" -eval_threshold: .55 \ No newline at end of file +checkpoint_name: checkpoint-36288 +eval_threshold: .65 \ No newline at end of file diff --git a/src/modeling/config/outcome_prediction/distilbert.yaml b/src/modeling/config/outcome_prediction/distilbert.yaml index 6e2ebe3..cbc46b8 100644 --- a/src/modeling/config/outcome_prediction/distilbert.yaml +++ b/src/modeling/config/outcome_prediction/distilbert.yaml @@ -8,17 +8,17 @@ hicric_pretrained: False base_model_name: "distilbert" pretrained_model_dir: "None" pretrained_hf_model_key: "distilbert/distilbert-base-uncased" -train_data_path: "./data/outcomes/train_backgrounds_suff.jsonl" +train_data_path: "./data/outcomes/train_backgrounds_suff_augmented.jsonl" # Training settings learning_rate: 8.0e-7 weight_decay: 0.01 -num_epochs: 40 -batch_size: 20 +num_epochs: 20 +batch_size: 48 dtype: "float16" # 'float32','float16' for training dtype compile: True # Whether to use torch compile # Test eval settings test_data_path: "./data/outcomes/test_backgrounds_suff.jsonl" -checkpoint_name: "checkpoint-41008" -eval_threshold: .55 \ No newline at end of file +checkpoint_name: "checkpoint-23040" +eval_threshold: .7 \ No newline at end of file diff --git a/src/modeling/config/outcome_prediction/legalbert_small.yaml b/src/modeling/config/outcome_prediction/legalbert_small.yaml index 493fb8c..393a628 100644 --- a/src/modeling/config/outcome_prediction/legalbert_small.yaml +++ b/src/modeling/config/outcome_prediction/legalbert_small.yaml @@ -8,17 +8,17 @@ hicric_pretrained: False base_model_name: "legal-bert-small-uncased" pretrained_model_dir: "None" pretrained_hf_model_key: "nlpaueb/legal-bert-small-uncased" -train_data_path: "./data/outcomes/train_backgrounds_suff.jsonl" +train_data_path: "./data/outcomes/train_backgrounds_suff_augmented.jsonl" # Training settings learning_rate: 8.0e-7 weight_decay: 0.01 -num_epochs: 40 -batch_size: 20 +num_epochs: 20 +batch_size: 48 dtype: "float16" # 'float32','float16' for training dtype compile: True # Whether to use torch compile # Test eval settings test_data_path: "./data/outcomes/test_backgrounds_suff.jsonl" -checkpoint_name: "checkpoint-82016" -eval_threshold: .8 \ No newline at end of file +checkpoint_name: "checkpoint-20736" +eval_threshold: .55 \ No newline at end of file diff --git a/src/modeling/config/sufficiency_classification/augmented.yaml b/src/modeling/config/sufficiency_classification/augmented.yaml index 95031fe..53cedc5 100644 --- a/src/modeling/config/sufficiency_classification/augmented.yaml +++ b/src/modeling/config/sufficiency_classification/augmented.yaml @@ -22,4 +22,11 @@ use_saved_augmentations: false sufficient_augmentation_params: api_type: openai api_url: https://api.openai.com/v1/chat/completions - model_name: gpt-4o \ No newline at end of file + model_name: gpt-4o + num_augmentations_per_example: 2 + +generic_rewrite_params: + num_augmentations_per_example: 2 + +unrelated_params: + num_examples: 500 \ No newline at end of file diff --git a/src/modeling/data_augmentation.py b/src/modeling/data_augmentation.py index d250871..2cf8696 100644 --- a/src/modeling/data_augmentation.py +++ b/src/modeling/data_augmentation.py @@ -85,7 +85,7 @@ def rewrite_to_generic( {"role": "user", "content": prompt}, ], "max_tokens": 500, - "temperature": 0.7, + "temperature": 1.2, } else: # llama request_data = {"prompt": f"{system_instruction}\n\n{prompt}", "temperature": 0.7, "stream": False} @@ -187,7 +187,7 @@ def generate_unrelated_content( "Make a general statement about insurance denials", "Make a statement about insurance denials for a particular type of care, but that does not describe a specific situation.", "Write some generic unrelated user input (hello, how are you, etc.). Nothing inappropriate.", - "Write some gibberish that might be input accidentally by a user, or sent inadverdently, cut off sentences, etc.", + "Write some gibberish that might be input accidentally by a user, or sent inadvertently, cut off sentences, etc.", ] # Calculate how many examples to generate per category @@ -219,12 +219,12 @@ def generate_unrelated_content( {"role": "user", "content": category}, ], "max_tokens": 100, - "temperature": 0.8, + "temperature": 1.2, } else: # LLaMa.cpp request_data = { "prompt": f"{system_instruction}\n\n{category}", - "temperature": 0.8, + "temperature": 1.2, "max_tokens": 100, "stream": False, } @@ -243,8 +243,8 @@ def generate_unrelated_content( new_text = response_json["content"].strip() # Clean up and limit length - if len(new_text) > 200: - new_text = new_text[:200] + if len(new_text) > 512: + new_text = new_text[:512] # Add the example with metadata all_generations.append( @@ -280,8 +280,8 @@ def generate_unrelated_content( def augment_sufficient_examples( dataset: Dataset, num_augmentations_per_example: int = 1, - api_type: str = "llamacpp", # "openai" or "llamacpp" - api_url: str = "http://localhost:8080/completion", + api_type: str = "openai", # "openai" or "llamacpp" + api_url: str = "https://api.openai.com/v1/chat/completions", api_key: str | None = None, model_name: str = "localllama", seed: int = 42, @@ -331,7 +331,7 @@ def augment_sufficient_examples( ) clinical_prompt = "Rewrite the following description using more clinical and technical medical terminology, maintaining all the key details. {}" paraphrase_prompt = "Rewrite the following description in different words while preserving the exact same meaning and all key details. Make minimal changes: {}" - details_prompt = "Rewrite the following description, adding a few more specific details about the condition and treatment, while preserving the core information. Make minimal changes: {}" + details_prompt = "Rewrite the following description, adding a few more specific details about the condition and treatment, but removing none. Make minimal changes: {}" # System instruction for models system_instruction = "You are a helpful assistant that rewrites healthcare denial descriptions while preserving their key information." @@ -353,21 +353,17 @@ def augment_sufficient_examples( continue # Delete 1-3 random words (but not too many) - num_to_delete = min(random.randint(1, 3), len(words) // 5) + num_to_delete = 1 indices_to_delete = random.sample(range(len(words)), num_to_delete) new_text = " ".join([w for i, w in enumerate(words) if i not in indices_to_delete]) - # Slightly reduce sufficiency score but keep it sufficient (>= 3) - new_score = max(3, sufficiency_score - 1) if random.random() < 0.3 else sufficiency_score - # Add the augmented example with metadata augmented_examples.append( { "text": new_text, - "sufficiency_score": new_score, + "sufficiency_score": sufficiency_score, "source_text": text, - "source_score": sufficiency_score, "augmentation_type": "sufficient_word_deletion", } ) @@ -398,13 +394,13 @@ def augment_sufficient_examples( {"role": "user", "content": prompt}, ], "max_tokens": 500, - "temperature": 0.7, + "temperature": 1.2, } else: # LLaMa.cpp # LLaMa.cpp format (depends on your server implementation) request_data = { "prompt": f"{system_instruction}\n\n{prompt}", - "temperature": 0.7, + "temperature": 1.2, "max_tokens": 500, "stop": ["\n\n", "###"], # Common stop sequences "stream": False, @@ -543,7 +539,7 @@ def process_and_save_augmentations( train_records = [dataset_records[i] for i in train_indices] test_records = [dataset_records[i] for i in test_indices] - # Apply augmentations only to training data only + # Apply augmentations to training data only train_dataset = Dataset.from_list(train_records) test_dataset = Dataset.from_list(test_records) augmented_train_records = [] diff --git a/src/modeling/eval_hf_outcome_predictor.py b/src/modeling/eval_hf_outcome_predictor.py index 6044034..f44239f 100644 --- a/src/modeling/eval_hf_outcome_predictor.py +++ b/src/modeling/eval_hf_outcome_predictor.py @@ -5,6 +5,7 @@ import numpy as np import scipy import torch +from safetensors import safe_open from sklearn.metrics import ( accuracy_score, f1_score, @@ -13,11 +14,11 @@ roc_auc_score, ) from transformers import ( - AutoModelForSequenceClassification, AutoTokenizer, TextClassificationPipeline, ) +from src.modeling.train_outcome_predictor import create_metadata_model from src.modeling.util import load_config from src.util import get_records_list @@ -25,6 +26,10 @@ ID2LABEL = {0: "Insufficient", 1: "Upheld", 2: "Overturned"} LABEL2ID = {v: k for k, v in ID2LABEL.items()} +# Define mappings for jurisdiction and insurance type +JURISDICTION_MAP = {"NY": 0, "CA": 1, "Unspecified": 2} +INSURANCE_TYPE_MAP = {"Commercial": 0, "Medicaid": 1, "Unspecified": 2} + # TODO: centralize label construction, label consts def construct_label(outcome, sufficiency_id, label2id): @@ -35,6 +40,91 @@ def construct_label(outcome, sufficiency_id, label2id): return label2id[outcome] +def inspect_model_metadata(checkpoint_path): + """Load a saved model in safetensors format and examine its metadata components.""" + # Path to the safetensors file + safetensors_path = os.path.join(checkpoint_path, "model.safetensors") + if not os.path.exists(safetensors_path): + # Try alternate naming patterns + potential_paths = [ + os.path.join(checkpoint_path, "pytorch_model.safetensors"), + os.path.join(checkpoint_path, "model.safetensors"), + ] + for path in potential_paths: + if os.path.exists(path): + safetensors_path = path + break + + print(f"Loading model from {safetensors_path}") + + if not os.path.exists(safetensors_path): + print(f"No safetensors file found at {safetensors_path}") + return None + + # Load the model using safetensors + tensors = {} + with safe_open(safetensors_path, framework="pt", device="cpu") as f: + for key in f.keys(): + tensors[key] = f.get_tensor(key) + + # Check if metadata_scale exists and print its value + if "metadata_scale" in tensors: + metadata_scale = tensors["metadata_scale"].item() + print(f"Metadata scale: {metadata_scale}") + else: + print("metadata_scale not found in state dict") + # Try looking for it with model prefix + for key in tensors: + if key.endswith("metadata_scale"): + metadata_scale = tensors[key].item() + print(f"Found metadata scale as {key}: {metadata_scale}") + + if "attention_scale" in tensors: + attention_scale = tensors["attention_scale"].item() + print(f"Attention scale: {attention_scale}") + else: + # Try looking for it with model prefix + for key in tensors: + if key.endswith("attention_scale"): + attention_scale = tensors[key].item() + print(f"Found attention scale as {key}: {attention_scale}") + + # Find embedding keys + j_embedding_key = None + i_embedding_key = None + for key in tensors: + if "jurisdiction_embeddings.weight" in key: + j_embedding_key = key + if "insurance_type_embeddings.weight" in key: + i_embedding_key = key + + # Check embedding distances + if j_embedding_key and i_embedding_key: + j_embeddings = tensors[j_embedding_key].numpy() + i_embeddings = tensors[i_embedding_key].numpy() + + print(f"Jurisdiction embedding shape: {j_embeddings.shape}") + print(f"Insurance type embedding shape: {i_embeddings.shape}") + + # Calculate distances between embeddings + if j_embeddings.shape[0] >= 2: + j_distance = np.linalg.norm(j_embeddings[0] - j_embeddings[1]) + print(f"Distance between jurisdiction embeddings (0 vs 1): {j_distance}") + + if i_embeddings.shape[0] >= 2: + i_distance = np.linalg.norm(i_embeddings[0] - i_embeddings[1]) + print(f"Distance between insurance type embeddings (0 vs 1): {i_distance}") + else: + print("Embedding weights not found in tensors") + print("Available keys:", list(tensors.keys())) + + return { + "metadata_scale": metadata_scale if "metadata_scale" in tensors else None, + "j_embeddings": j_embeddings if j_embedding_key else None, + "i_embeddings": i_embeddings if i_embedding_key else None, + } + + def compute_metrics(predictions: np.ndarray, labels: np.ndarray) -> dict: roc_auc = roc_auc_score(labels, scipy.special.softmax(predictions, axis=-1), multi_class="ovr") @@ -131,6 +221,34 @@ def postprocess(self, model_outputs): return out_logits +# Check if the model has a custom forward method that supports metadata +def is_metadata_model(model): + """Check if model has metadata capabilities by inspecting its methods""" + # Look for the key attributes in our custom model + has_j_embeddings = hasattr(model, "jurisdiction_embeddings") + has_i_embeddings = hasattr(model, "insurance_type_embeddings") + + # Check if model's forward method signature includes the metadata params + forward_sig = model.forward.__code__.co_varnames + accepts_metadata = "jurisdiction_id" in forward_sig and "insurance_type_id" in forward_sig + + return has_j_embeddings and has_i_embeddings and accepts_metadata + + +def count_parameters(model): + """Count the number of trainable parameters in the model and calculate size in MB""" + param_count = 0 + size_in_bytes = 0 + + for p in model.parameters(): + if p.requires_grad: + param_count += p.numel() + size_in_bytes += p.numel() * p.element_size() + + size_in_mb = size_in_bytes / (1024 * 1024) + return param_count, size_in_mb + + def main(config_path: str): cfg = load_config(config_path) @@ -147,40 +265,143 @@ def main(config_path: str): # Load raw dataset test_dataset = get_records_list(dataset_path) - # tokenizer = AutoTokenizer.from_pretrained(pretrained_model_key, model_max_length=512) - checkpoints_dir = os.path.join(MODEL_DIR, "train_backgrounds_suff", base_model_name) + # Set up paths + checkpoints_dir = os.path.join(MODEL_DIR, "train_backgrounds_suff_augmented", base_model_name) print("Evaluating from checkpoint: ", checkpoint_name) ckpt_path = os.path.join(checkpoints_dir, checkpoint_name) + _results = inspect_model_metadata(ckpt_path) + + # Load tokenizer tokenizer = AutoTokenizer.from_pretrained(ckpt_path) - model = AutoModelForSequenceClassification.from_pretrained( - ckpt_path, num_labels=3, id2label=ID2LABEL, label2id=LABEL2ID + + model = create_metadata_model( + pretrained_model_key=ckpt_path, + base_model_name=base_model_name, + num_labels=3, + id2label=ID2LABEL, + label2id=LABEL2ID, ) - # Isolate records + # Count model parameters + param_count, size_in_mb = count_parameters(model) + print(f"Model size: {param_count:,} parameters ({size_in_mb:.2f} MB)") + + # Check if model supports metadata + supports_metadata = is_metadata_model(model) + + # Prepare data text_records = [rec["text"] for rec in test_dataset] labels = [construct_label(rec["decision"], rec["sufficiency_id"], LABEL2ID) for rec in test_dataset] - device = "cuda" - # pipeline = ClassificationPipeline(model=model, tokenizer=tokenizer, device=device) - pipeline = ClassificationPipeline(model=model, tokenizer=tokenizer, device=device, truncation=True) + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Using device: {device}") + + # For models with metadata support, use a custom approach + if supports_metadata: + # Extract metadata + jurisdictions = [rec.get("jurisdiction", "Unspecified") for rec in test_dataset] + insurance_types = [rec.get("insurance_type", "Unspecified") for rec in test_dataset] + + # Convert metadata to tensors + j_ids = torch.tensor([JURISDICTION_MAP.get(j, JURISDICTION_MAP["Unspecified"]) for j in jurisdictions]) + i_ids = torch.tensor([INSURANCE_TYPE_MAP.get(i, INSURANCE_TYPE_MAP["Unspecified"]) for i in insurance_types]) - predictions = torch.cat(pipeline(text_records, batch_size=100)) + model.to(device) + model.eval() + # Process in batches + batch_size = 100 + all_logits = [] + + for i in range(0, len(text_records), batch_size): + end_idx = min(i + batch_size, len(text_records)) + batch_texts = text_records[i:end_idx] + batch_j_ids = j_ids[i:end_idx].to(device) + batch_i_ids = i_ids[i:end_idx].to(device) + + inputs = tokenizer(batch_texts, truncation=True, padding=True, return_tensors="pt").to(device) + + with torch.no_grad(): + outputs = model(**inputs, jurisdiction_id=batch_j_ids, insurance_type_id=batch_i_ids) + + all_logits.append(outputs["logits"].cpu()) + + predictions = torch.cat(all_logits) + + # Now create predictions with unspecified metadata for all examples + print("Computing metrics with unspecified metadata for all examples...") + unspecified_j_ids = torch.full((len(text_records),), JURISDICTION_MAP["Unspecified"], dtype=torch.long) + unspecified_i_ids = torch.full((len(text_records),), INSURANCE_TYPE_MAP["Unspecified"], dtype=torch.long) + + all_unspecified_logits = [] + + for i in range(0, len(text_records), batch_size): + end_idx = min(i + batch_size, len(text_records)) + batch_texts = text_records[i:end_idx] + batch_j_ids = unspecified_j_ids[i:end_idx].to(device) + batch_i_ids = unspecified_i_ids[i:end_idx].to(device) + + inputs = tokenizer(batch_texts, truncation=True, padding=True, return_tensors="pt").to(device) + + with torch.no_grad(): + outputs = model(**inputs, jurisdiction_id=batch_j_ids, insurance_type_id=batch_i_ids) + + all_unspecified_logits.append(outputs["logits"].cpu()) + + unspecified_predictions = torch.cat(all_unspecified_logits) + else: + # For standard models, use the original pipeline + pipeline = ClassificationPipeline(model=model, tokenizer=tokenizer, device=device, truncation=True) + predictions = torch.cat(pipeline(text_records, batch_size=100)) + # No metadata to ignore for standard models, so unspecified predictions are the same + unspecified_predictions = predictions + + # Compute standard metrics + metrics = compute_metrics(predictions, labels) + metrics["param_count"] = param_count + metrics["model_size_mb"] = size_in_mb + + # Add model size to metrics + metrics["param_count"] = param_count + + # Compute metrics with unspecified metadata + unspecified_metrics = compute_metrics(unspecified_predictions, labels) + # Add model size to unspecified metrics + unspecified_metrics["param_count"] = param_count + + # Compute threshold metrics if specified if threshold is not None: threshold_metrics = compute_metrics_w_threshold(predictions, labels, threshold) + threshold_metrics["param_count"] = param_count + threshold_metrics["model_size_mb"] = size_in_mb - metrics = compute_metrics(predictions, labels) + unspecified_threshold_metrics = compute_metrics_w_threshold(unspecified_predictions, labels, threshold) + unspecified_threshold_metrics["param_count"] = param_count + unspecified_threshold_metrics["model_size_mb"] = size_in_mb # Print and write metrics + print("Standard metrics:") print(metrics) with open(os.path.join(ckpt_path, "test_metrics.json"), "w") as f: json.dump(metrics, f) + + print("\nMetrics with unspecified metadata:") + print(unspecified_metrics) + with open(os.path.join(ckpt_path, "test_metrics_unspecified.json"), "w") as f: + json.dump(unspecified_metrics, f) + if threshold is not None: + print("\nStandard threshold metrics:") print(threshold_metrics) with open(os.path.join(ckpt_path, "test_metrics_w_threshold.json"), "w") as f: json.dump(threshold_metrics, f) + print("\nThreshold metrics with unspecified metadata:") + print(unspecified_threshold_metrics) + with open(os.path.join(ckpt_path, "test_metrics_w_threshold_unspecified.json"), "w") as f: + json.dump(unspecified_threshold_metrics, f) + return None diff --git a/src/modeling/openai_batch_call.py b/src/modeling/openai_batch_call.py index 4269993..3583b77 100644 --- a/src/modeling/openai_batch_call.py +++ b/src/modeling/openai_batch_call.py @@ -1,3 +1,5 @@ +import argparse +import json import os import time @@ -11,11 +13,11 @@ MODEL_KEY = "gpt-4o-mini-2024-07-18" OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", None) SYSTEM_MESSAGE = """You are an expert in U.S. health law and health policy, as well as a medical expert. In what follows, I will provide a description of a case in which a patient has submitted an appeal of a decision by their health insurer to deny a claim submitted on their behalf. You must predict whether an independent reviewer will overturn the denial, or uphold the denial when reviewing the appeal. If they would overturn it, your decision would be "Overturned". If they would not, your decision would be "Upheld". If there is insufficient information in the context to predict this, and it could go either way depending on more details, your decision would be "Insufficient". -Very short cases are often insufficient, as are cases which describe a treatment or service without saying what it is for. Sufficiency does not mean sufficient-without-a-doubt, it just means sufficient to make a good estimated guess about which -way the review will go. Most cases I present to you will have sufficient information, by my subjective standards. + +Very short cases are often insufficient, as are cases which describe a treatment or service without saying what it is for. Sufficiency does not mean sufficient-without-a-doubt, it just means sufficient to make a good estimated guess about which way the review will go. Most cases I present to you will have sufficient information, by my subjective standards. You must reply with json of the following form: -{"decision": "Overturned", "probability": .82} +{"decision": "Overturned", "probability": 0.82} with the decision, and the associated probability that that decision is correct. The possible decision classes are "Insufficient", "Upheld", and "Overturned". @@ -23,10 +25,10 @@ Here are two examples: Prompt: "An enrollee has requested Zepatier for treatment of her hepatitis C." -Desired Output: {"decision": "Overturned", "probability": .75} +Desired Output: {"decision": "Overturned", "probability": 0.75} Prompt: "An enrollee has requested emergency services provided on an emergent or urgent basis for treatment of her medical condition." -Desired Output: {"decision": "Insufficient", "probability": .99} +Desired Output: {"decision": "Insufficient", "probability": 0.99} """ @@ -84,7 +86,20 @@ def construct_answer_batch(recs: list[dict], output_path: str) -> None: outcome = rec["decision"] sufficiency_id = rec["sufficiency_id"] - answers.append({"custom_id": f"{custom_id}", "decision": construct_label(outcome, sufficiency_id)}) + # Get metadata fields (default to "Unspecified" if not present) + jurisdiction = rec.get("jurisdiction", "Unspecified") + insurance_type = rec.get("insurance_type", "Unspecified") + + # Include all fields in the answer + answers.append( + { + "custom_id": f"{custom_id}", + "decision": construct_label(outcome, sufficiency_id), + "sufficiency_id": sufficiency_id, + "jurisdiction": jurisdiction, + "insurance_type": insurance_type, + } + ) add_jsonl_lines(output_path, answers) @@ -104,7 +119,7 @@ def construct_request_batch(recs: list[dict], model_key: str, output_path: str) return None -def prepare_batches() -> None: +def prepare_batches() -> tuple: recs = get_records_list(os.path.join("./data/outcomes/", OUTCOMES_DATASET + ".jsonl")) output_path = f"./data/provider_annotated_outcomes/openai/{OUTCOMES_DATASET}/hicric_eval_request_answers.jsonl" if os.path.exists(output_path): @@ -122,7 +137,7 @@ def prepare_batches() -> None: # Full batch construct_request_batch(recs, MODEL_KEY, output_path) split_batch(single_batch_path=output_path, subbatch_dir=subbatch_dir) - return subbatch_dir + return subbatch_dir, recs def split_batch(single_batch_path, subbatch_dir, subbatch_size=1500) -> None: @@ -136,7 +151,7 @@ def split_batch(single_batch_path, subbatch_dir, subbatch_size=1500) -> None: return None -def batch_call(filepaths: list[str], api_key: str | None = OPENAI_API_KEY) -> None: +def batch_call(filepaths: list[str], api_key: str | None = OPENAI_API_KEY) -> list[str]: """Submit a batch completion call for each batch file in filepaths.""" if not api_key: raise Exception("You need to export an Open AI API key as an env var.") @@ -174,7 +189,7 @@ def batch_call(filepaths: list[str], api_key: str | None = OPENAI_API_KEY) -> No return batch_request_ids -def download_response(batch_id: str, subbatch_dir: str, api_key: str | None = OPENAI_API_KEY) -> None: +def download_response(batch_id: str, subbatch_dir: str, api_key: str | None = OPENAI_API_KEY) -> str: client = OpenAI(api_key=api_key) status_meta = client.batches.retrieve(batch_id) download_path = os.path.join(subbatch_dir, f"response_{batch_id}.jsonl") @@ -183,7 +198,7 @@ def download_response(batch_id: str, subbatch_dir: str, api_key: str | None = OP def poll_and_download( - batch_request_ids: list[str], subbatch_dir: str, poll_sleep_mins: int = 0.1, api_key: str | None = OPENAI_API_KEY + batch_request_ids: list[str], subbatch_dir: str, poll_sleep_mins: float = 0.1, api_key: str | None = OPENAI_API_KEY ) -> list[str]: client = OpenAI(api_key=api_key) @@ -217,19 +232,95 @@ def merge_jsonl(paths: list[str], output_path): return None +def synchronous_call(records: list[dict], api_key: str | None = OPENAI_API_KEY) -> list[dict]: + """Make synchronous API calls one at a time for each record.""" + if not api_key: + raise Exception("You need to export an Open AI API key as an env var.") + + client = OpenAI(api_key=api_key) + results = [] + + print(f"Processing {len(records)} records synchronously...") + for idx, rec in enumerate(records): + case_description = rec["text"] + custom_id = idx + + print(f"Processing record {idx+1}/{len(records)}") + try: + response = client.chat.completions.create( + model=MODEL_KEY, + messages=[{"role": "system", "content": SYSTEM_MESSAGE}, {"role": "user", "content": case_description}], + ) + + # Get the response content + completion_text = response.choices[0].message.content + + # Parse JSON response + try: + completion_json = json.loads(completion_text) + except json.JSONDecodeError: + print(f"Failed to parse JSON for record {idx}: {completion_text}") + completion_json = {"decision": "Error", "probability": 0} + + # Create result record + result = { + "custom_id": str(custom_id), + "request": { + "body": { + "messages": [ + {"role": "system", "content": SYSTEM_MESSAGE}, + {"role": "user", "content": case_description}, + ] + } + }, + "response": {"choices": [{"message": {"content": completion_text, "role": "assistant"}}]}, + "parsed_response": completion_json, + } + + results.append(result) + + except Exception as e: + print(f"Error processing record {idx}: {str(e)}") + # Add error record + results.append({"custom_id": str(custom_id), "error": str(e)}) + + return results + + if __name__ == "__main__": - subbatch_dir = prepare_batches() - filenames = [os.path.join(subbatch_dir, filename) for filename in os.listdir(subbatch_dir)] - - # Can only upload small number of files at a time due to enque limit - all_output_files = [] - batch_size = 1 - for start_idx in range(0, len(filenames), batch_size): - enqueued_files = filenames[start_idx : start_idx + batch_size] - batch_request_ids = batch_call(enqueued_files) - output_files = poll_and_download(batch_request_ids, subbatch_dir) - all_output_files.extend(output_files) - - # Merge to single response - output_path = f"./data/provider_annotated_outcomes/openai/{OUTCOMES_DATASET}/hicric_eval_response_{MODEL_KEY}.jsonl" - merge_jsonl(all_output_files, output_path) + parser = argparse.ArgumentParser(description="Run OpenAI API calls for health insurance claim evaluations") + parser.add_argument("--synchronous", action="store_true", help="Use synchronous API calls instead of batch") + args = parser.parse_args() + + subbatch_dir, records = prepare_batches() + + if args.synchronous: + # Synchronous mode + print("Running in synchronous mode...") + results = synchronous_call(records) + + # Save results + output_path = ( + f"./data/provider_annotated_outcomes/openai/{OUTCOMES_DATASET}/hicric_eval_response_sync_{MODEL_KEY}.jsonl" + ) + add_jsonl_lines(output_path, results) + print(f"Synchronous results saved to {output_path}") + else: + # Batch mode (original behavior) + filenames = [os.path.join(subbatch_dir, filename) for filename in os.listdir(subbatch_dir)] + + # Can only upload small number of files at a time due to enque limit + all_output_files = [] + batch_size = 1 + for start_idx in range(0, len(filenames), batch_size): + enqueued_files = filenames[start_idx : start_idx + batch_size] + batch_request_ids = batch_call(enqueued_files) + output_files = poll_and_download(batch_request_ids, subbatch_dir) + all_output_files.extend(output_files) + + # Merge to single response + output_path = ( + f"./data/provider_annotated_outcomes/openai/{OUTCOMES_DATASET}/hicric_eval_response_{MODEL_KEY}.jsonl" + ) + merge_jsonl(all_output_files, output_path) + print(f"Batch results saved to {output_path}") diff --git a/src/modeling/openai_eval.py b/src/modeling/openai_eval.py index 404341f..dc0c0ec 100644 --- a/src/modeling/openai_eval.py +++ b/src/modeling/openai_eval.py @@ -2,14 +2,41 @@ import os import numpy as np -from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score +from sklearn.metrics import ( + accuracy_score, + f1_score, + precision_score, + recall_score, + roc_auc_score, +) from src.util import get_records_list MODEL_KEY = "gpt-4o-mini-2024-07-18" +# Define mappings for jurisdiction and insurance type +JURISDICTION_MAP = {"NY": 0, "CA": 1, "Unspecified": 2} +INSURANCE_TYPE_MAP = {"Commercial": 0, "Medicaid": 1, "Unspecified": 2} + def compute_metrics(predictions: np.ndarray, labels: np.ndarray) -> dict: + # Check if we have enough unique classes for ROC AUC + try: + # Convert predictions to one-hot format for ROC AUC if they're not already + if len(predictions.shape) == 1: # If predictions are just class labels + n_classes = len(np.unique(labels)) + one_hot_preds = np.zeros((len(predictions), n_classes)) + for i, pred in enumerate(predictions): + if pred < n_classes: # Ensure valid index + one_hot_preds[i, pred] = 1 + roc_auc = roc_auc_score(labels, one_hot_preds, multi_class="ovr") + else: + # Assuming predictions are already probabilities/scores + roc_auc = roc_auc_score(labels, predictions, multi_class="ovr") + except ValueError: + # Handle case when there are classes with no predictions + roc_auc = float("nan") + acc = accuracy_score(labels, predictions) # Macro @@ -28,6 +55,7 @@ def compute_metrics(predictions: np.ndarray, labels: np.ndarray) -> dict: f1 = f1_score(labels, predictions, average=None, zero_division=np.nan) return { + "metric_fn_name": "compute_metrics", "accuracy": acc, "macro_precision": macro_prec.tolist(), "macro_recall": macro_rec.tolist(), @@ -38,6 +66,70 @@ def compute_metrics(predictions: np.ndarray, labels: np.ndarray) -> dict: "class_precision": prec.tolist(), "class_recall": rec.tolist(), "class_f1": f1.tolist(), + "roc_auc": roc_auc, + } + + +def compute_metrics_w_threshold(predictions: np.ndarray, labels: np.ndarray, threshold: float) -> dict: + # Convert predictions to probabilities if they're not already + if len(predictions.shape) == 1: # If predictions are just class labels + n_classes = max(3, np.max(predictions) + 1) # Ensure we have at least 3 classes + probs = np.zeros((len(predictions), n_classes)) + for i, pred in enumerate(predictions): + if pred < n_classes: # Ensure valid index + probs[i, pred] = 1 + else: + # Assuming predictions are already probabilities/scores + probs = predictions + + try: + roc_auc = roc_auc_score(labels, probs, multi_class="ovr") + except ValueError: + roc_auc = float("nan") + + # Apply threshold logic as in original implementation + pred_labels = np.full_like(labels, -1) # Initialize with -1 for unassigned classes + + # Assign class 0 where the probability meets or exceeds the threshold + class_0_mask = probs[:, 0] >= threshold + pred_labels[class_0_mask] = 0 + + # If probability does not exceed threshold, take argmax of upheld/overturned class options + remaining_mask = pred_labels == -1 + remaining_preds = probs[remaining_mask, 1:] # Exclude class 0 + pred_labels[remaining_mask] = np.argmax(remaining_preds, axis=1) + 1 # + 1 to adjust for class index shift + + # Compute metrics + acc = accuracy_score(labels, pred_labels) + + # Macro + macro_prec = precision_score(labels, pred_labels, average="macro") + macro_rec = recall_score(labels, pred_labels, average="macro") + macro_f1 = f1_score(labels, pred_labels, average="macro", zero_division=np.nan) + + # Micro + micro_prec = precision_score(labels, pred_labels, average="micro") + micro_rec = recall_score(labels, pred_labels, average="micro") + micro_f1 = f1_score(labels, pred_labels, average="micro", zero_division=np.nan) + + # Class-level + prec = precision_score(labels, pred_labels, average=None) + rec = recall_score(labels, pred_labels, average=None) + f1 = f1_score(labels, pred_labels, average=None, zero_division=np.nan) + + return { + "metric_fn_name": "compute_metrics_w_threshold", + "accuracy": acc, + "macro_precision": macro_prec.tolist(), + "macro_recall": macro_rec.tolist(), + "macro_f1": macro_f1.tolist(), + "micro_precision": micro_prec.tolist(), + "micro_recall": micro_rec.tolist(), + "micro_f1": micro_f1.tolist(), + "class_precision": prec.tolist(), + "class_recall": rec.tolist(), + "class_f1": f1.tolist(), + "roc_auc": roc_auc, } @@ -46,57 +138,108 @@ def merge_pred_gt(response_path: str, answers_path: str) -> dict: gt = get_records_list(answers_path) responses = get_records_list(response_path) + # Extract ground truth and metadata for rec in gt: - merged[rec["custom_id"]] = {"answer": rec["decision"]} - + # Store decision and metadata in merged dict + merged[rec["custom_id"]] = { + "answer": rec["decision"], + "jurisdiction": rec.get("jurisdiction", "Unspecified"), + "insurance_type": rec.get("insurance_type", "Unspecified"), + "sufficiency_id": rec.get("sufficiency_id", 1), # Default to sufficient if not specified + } + + # Extract predictions for rec in responses: - resp = rec["response"]["body"]["choices"][0]["message"]["content"] - try: - pred = json.loads(resp.replace(".", "0.")) - except json.JSONDecodeError: - print(rec["custom_id"]) - print(type(resp)) - print(resp) - pred = None - merged[rec["custom_id"]]["pred"] = pred + pred = rec["parsed_response"] + + # Add prediction to the merged dict if the ID exists + if rec["custom_id"] in merged: + merged[rec["custom_id"]]["pred"] = pred return merged -def eval_preds() -> dict: - pass +def construct_label(outcome, sufficiency_id, label_map): + """Construct 3-class label from outcome, if sufficient background.""" + if sufficiency_id == 0: + return 0 + else: + return label_map.get(outcome, 0) # Default to 0 if outcome not in map -if __name__ == "__main__": +def eval_preds(threshold=None) -> dict: outcomes_dataset = "test_backgrounds_suff" answers_path = f"./data/provider_annotated_outcomes/openai/{outcomes_dataset}/hicric_eval_request_answers.jsonl" response_path = ( f"./data/provider_annotated_outcomes/openai/{outcomes_dataset}/hicric_eval_response_{MODEL_KEY}.jsonl" ) - merged = merge_pred_gt(response_path, answers_path) + merged = merge_pred_gt(response_path, answers_path) out_map = {"Insufficient": 0, "Upheld": 1, "Overturned": 2} preds = [] labels = [] + jurisdiction_ids = [] + insurance_type_ids = [] + # Favorable evaluation: don't count non-answers against GPT, to be generous for id, rec in merged.items(): print(rec) - if rec["pred"] is None: + if "pred" not in rec or rec["pred"] is None: continue if rec["pred"].get("decision") is None: continue if rec["pred"].get("decision") not in out_map.keys(): continue else: + # Extract prediction preds.append(out_map[rec["pred"]["decision"]]) - labels.append(out_map[rec["answer"]]) + + # Extract label with sufficiency consideration + sufficiency_id = rec.get("sufficiency_id", 1) # Default to sufficient if not specified + labels.append(construct_label(rec["answer"], sufficiency_id, out_map)) + + # Extract metadata + j = rec.get("jurisdiction", "Unspecified") + i = rec.get("insurance_type", "Unspecified") + jurisdiction_ids.append(JURISDICTION_MAP.get(j, JURISDICTION_MAP["Unspecified"])) + insurance_type_ids.append(INSURANCE_TYPE_MAP.get(i, INSURANCE_TYPE_MAP["Unspecified"])) + if len(preds) < len(merged): print(f"Model failed to produce valid json for {len(merged) - len(preds)} values.") - metrics = compute_metrics(preds, labels) + + # Convert to numpy arrays for metrics computation + preds_np = np.array(preds) + labels_np = np.array(labels) + + # Compute standard metrics + metrics = compute_metrics(preds_np, labels_np) + + # Add metadata distribution info to metrics + metrics["jurisdiction_distribution"] = {j: jurisdiction_ids.count(JURISDICTION_MAP[j]) for j in JURISDICTION_MAP} + metrics["insurance_type_distribution"] = { + i: insurance_type_ids.count(INSURANCE_TYPE_MAP[i]) for i in INSURANCE_TYPE_MAP + } + + # Compute threshold metrics if threshold is provided + if threshold is not None: + threshold_metrics = compute_metrics_w_threshold(preds_np, labels_np, threshold) + + # Print and save metrics print(metrics) - with open( - os.path.join(f"./data/provider_annotated_outcomes/openai/{outcomes_dataset}", f"{MODEL_KEY}_test_metrics.json"), - "w", - ) as f: + output_dir = f"./data/provider_annotated_outcomes/openai/{outcomes_dataset}" + with open(os.path.join(output_dir, f"{MODEL_KEY}_test_metrics.json"), "w") as f: json.dump(metrics, f) + + if threshold is not None: + print(threshold_metrics) + with open(os.path.join(output_dir, f"{MODEL_KEY}_test_metrics_w_threshold.json"), "w") as f: + json.dump(threshold_metrics, f) + + return metrics + + +if __name__ == "__main__": + # Optionally specify a threshold for the additional metrics + threshold = 0.5 # Set to None if you don't want threshold-based metrics + eval_preds(threshold) diff --git a/src/modeling/predict.py b/src/modeling/predict.py index 9ef0a40..8a2d3c5 100644 --- a/src/modeling/predict.py +++ b/src/modeling/predict.py @@ -1,101 +1,369 @@ +#!/usr/bin/env python +import argparse +import copy import os -import time import numpy as np import onnxruntime import scipy import torch -from transformers import ( - AutoModelForSequenceClassification, - AutoTokenizer, - TextClassificationPipeline, - pipeline, -) +from colorama import Fore, Style, init +from transformers import AutoTokenizer +from src.modeling.train_outcome_predictor import TextClassificationWithMetadata from src.modeling.util import export_onnx_model, quantize_onnx_model +# Initialize colorama +init() + MODEL_DIR = "./models/overturn_predictor" +ID2LABEL = {0: "Insufficient", 1: "Upheld", 2: "Overturned"} +LABEL2ID = {v: k for k, v in ID2LABEL.items()} +# Test examples with expected classifications +TEST_EXAMPLES = [ + { + "text": "Diagnosis: Broken Ribs\nTreatment: Inpatient Hospital Admission\n\nThe insurer denied inpatient hospital admission. \n\nThe patient is an adult male. He presented by ambulance to the hospital with severe back pain. The patient had fallen down a ramp and onto his back two days prior. The patient developed back pain and had pain with deep inspiration, prompting a call to 911 for an ambulance. The patient was taking ibuprofen and Tylenol for pain at home. A computed tomography (CT) scan of the patient's chest showed a right posterior minimally displaced 9th and 10th rib fractures. There was no associated intra-abdominal injury. There was atelectasis of the lung in the region of the rib fractures. Vital signs, including oxygen saturation, were normal in the emergency department triage note. The patient did not require supplemental oxygen during the hospitalization. The patient was admitted to the acute inpatient level of care for pain control, breathing treatments, and venous thromboembolism prophylaxis. The patient was seen and cleared by Physical Therapy. The patient's pain was controlled with oral analgesia and a lidocaine patch. Total time in the hospital was less than 13 hours. The acute inpatient level of care was denied coverage by the health plan as not medically necessary.", + "expected": "Upheld", + }, + { + "text": "Diagnosis: General Debility Due to Lumbar Stenosis\nTreatment: Continued Stay in Skilled Nursing Facility\n\nThe insurer denied continued stay in skilled nursing facility\n\nThe patient is an adult female with a history of general debility due to lumbar stenosis affecting her functional mobility and activities of daily living (ADLs). She has impairments of balance, mobility, and strength, with an increased risk for falls.\n\nThe patient's relevant past medical history includes obesity status post gastric sleeve times two (x2), severe knee and hip osteoarthritis, anxiety, bipolar disorder, hiatus hernia, depression, asthma, hiatus hernia/gastroesophageal reflux disease (GERD), fractured ribs, fractured ankle, sarcoidosis, and pulmonary embolism. Before admission, the patient was living with family and friends in a house, independent with activities of daily living, and with support from others. The patient was admitted to a skilled nursing facility (SNF) three months ago, requiring total dependence for most activities of daily living, and as of two months ago, the patient was non-ambulatory, requiring supervision for bed mobility, contact guard for transfers, and maximum assistance for static standing. She has limitations in completing mobility and locomotive activities due to gross weakness of the bilateral lower extremities, decreased stability and controlled mobility, increased pain, impaired coordination, and decreased aerobic capacity.", + "expected": "Overturned", + }, + { + "text": "Diagnosis: Dilated CBD, distended gallbladder\n \nTreatment: Inpatient admission, diagnostic treatment and surgery\n\nThe insurer denied the inpatient admission. The patient presented with abdominal pain. He was afebrile and the vital signs were stable. There was no abdominal rebound or guarding. The WBC count was 14.5. The bilirubin was normal. A CAT scan revealed a dilated CBD and a distended gallbladder. An MRCP revealed a CBD stone with bile duct dilatation. The patient was treated with antibiotics. He underwent an ERCP with sphincterotomy and balloon sweeps. A laparoscopic cholecystectomy was then done. The patient remained hemodynamically stable and his pain was controlled.", + "expected": "Upheld", + }, + { + "text": "This is a female patient with a medical history of severe bilateral proliferative diabetic retinopathy and diabetic macular edema. The patient underwent an injection of Lucentis in her left eye and treatment with panretinal photocoagulation without complications. It was reported that the patient had severe disease with many dot/blot hemorrhages. Documentation revealed the patient had arteriovenous (AV) crossing changes bilaterally with venous tortuosity. There were scattered dot/blot hemorrhages bilaterally to the macula and periphery and macular edema. Additionally, she was counseled on proper diet control, exercise and hypertension control. Avastin and Mvasi are the same drug - namely bevacizumab: as per lexicomp: 'humanized monoclonal antibody which binds to, and neutralizes, vascular endothelial growth factor (VEGF), preventing its association with endothelial receptors, Flt-1 and KDR. VEGF binding initiates angiogenesis (endothelial proliferation and the formation of new blood vessels). The inhibition of microvascular growth is believed to retard the growth of all tissues (including metastatic tissue).' Lucentis is ranibizumab: as per lexicomp: 'a recombinant humanized monoclonal antibody fragment which binds to and inhibits human vascular endothelial growth factor A (VEGF-A). Ranibizumab inhibits VEGF from binding to its receptors and thereby suppressing neovascularization and slowing vision loss.' The formulary, step therapy options and the requested drug act against VEGF. There is no suggestion that Avastin or Mvasi would cause physical or mental harm to the patient. There are no contraindications in the documentation that would put the patient at risk for adverse reactions. This patient has a diagnosis of maculopathy. Avastin and Mvasi have been shown to be helpful with this condition.", + "expected": "Upheld", + }, + {"text": "hello? is this relevant?", "expected": "Insufficient"}, + { + "text": "A patient is being denied wegovy for morbid obesity. The health plan states it is not medically necessary.", + "expected": "Overturned", + }, + { + "text": "This patient has extensive and inoperable carcinoma of the stomach. He was started on chemotherapy with Xeloda and Oxaliplatin, because he has less nausea with Oxaliplatin than with the alternative, Cisplatin. Oxaliplatin was denied as experimental for treatment of his gastric cancer.", + "expected": "Overturned", + }, + { + "text": "This is a patient who was denied breast tomosynthesis to screen for breast cancer.", + "expected": "Overturned", + }, + { + "text": "This is a patient with Crohn's Disease who is being treated with Humira. Their health plan has denied Anser ADA blood level testing for Humira, claiming it is investigational.", + "expected": "Upheld", + }, + { + "text": "The patient is a 44-year-old female who initially presented with an abnormal screening mammogram. The patient was seen by a radiation oncologist who recommended treatment of the right chest wall and comprehensive nodal regions using proton beam radiation therapy.", + "expected": "Upheld", + }, + { + "text": "The patient is a 10-year-old female with a history of Pitt-Hopkins syndrome and associated motor planning difficulties, possible weakness in the oral area, and receptive and expressive language delays. The provider has recommended that the patient continue to receive individual speech and language therapy sessions twice a week for 60-minute sessions. The Health Insurer has denied the requested services as not medically necessary for treatment of the patient’s medical condition.", + "expected": "Overturned", + }, + { + "text": "The patient is a nine-year-old female with a history of autism spectrum disorder and a speech delay. The patient’s parent has requested reimbursement for the ABA services provided over the course of a year. The Health Insurer has denied the services at issue as not medically necessary for the treatment of the patient.", + "expected": "Overturned", + }, +] -class ClassificationPipeline(TextClassificationPipeline): - def postprocess(self, model_outputs): - out_logits = model_outputs["logits"] - return out_logits +def run_pytorch_model(model, tokenizer, text, jurisdiction_id=2, insurance_type_id=2): + """Run inference with the PyTorch model""" + model.eval() + tokenized = tokenizer(text, return_tensors="pt", truncation=True, padding=True) + with torch.no_grad(): + # Convert jurisdiction and insurance type to tensors + j_id = torch.tensor([jurisdiction_id]) + i_id = torch.tensor([insurance_type_id]) -if __name__ == "__main__": - # TODO: get these from model checkpoints - ID2LABEL = {0: "Insufficient", 1: "Upheld", 2: "Overturned"} - LABEL2ID = {v: k for k, v in ID2LABEL.items()} + result = model(**tokenized, jurisdiction_id=j_id, insurance_type_id=i_id) - # Load model and tokenizer - pretrained_model_key = "distilbert/distilbert-base-uncased" + probs = torch.softmax(result["logits"], dim=-1) + prob, argmax = torch.max(probs, dim=-1) + return { + "class_id": argmax.item(), + "class_name": ID2LABEL[argmax.item()], + "confidence": prob.item(), + "probs": probs[0].tolist(), + "logits": result["logits"][0].tolist(), + } + + +def run_pytorch_model_no_metadata(model, tokenizer, text): + """Run inference with the PyTorch model without metadata""" + model.eval() + tokenized = tokenizer(text, return_tensors="pt", truncation=True, padding=True) + with torch.no_grad(): + try: + # Try running without metadata + result = model(**tokenized) + probs = torch.softmax(result["logits"], dim=-1) + prob, argmax = torch.max(probs, dim=-1) + + return { + "class_id": argmax.item(), + "class_name": ID2LABEL[argmax.item()], + "confidence": prob.item(), + "probs": probs[0].tolist(), + "logits": result["logits"][0].tolist(), + } + except Exception as e: + return {"error": str(e), "class_name": "ERROR", "confidence": 0.0, "probs": [0.0, 0.0, 0.0]} + + +def run_pytorch_quantized(model_int8, tokenizer, text, jurisdiction_id=2, insurance_type_id=2): + """Run inference with the quantized PyTorch model""" + model_int8.eval() + tokenized = tokenizer(text, return_tensors="pt", truncation=True, padding=True) + with torch.no_grad(): + # Convert jurisdiction and insurance type to tensors + j_id = torch.tensor([jurisdiction_id]) + i_id = torch.tensor([insurance_type_id]) + + result = model_int8(**tokenized, jurisdiction_id=j_id, insurance_type_id=i_id) + + probs = torch.softmax(result["logits"], dim=-1) + prob, argmax = torch.max(probs, dim=-1) + + return { + "class_id": argmax.item(), + "class_name": ID2LABEL[argmax.item()], + "confidence": prob.item(), + "probs": probs[0].tolist(), + "logits": result["logits"][0].tolist(), + } + + +def run_onnx(session, tokenizer, text, jurisdiction_id=2, insurance_type_id=2): + """Run inference with ONNX Runtime""" + try: + # Tokenize input text + inputs = tokenizer(text, return_tensors="np", truncation=True, padding=True) + + # Construct the full input feed + input_feed = dict(inputs) + + # Check if the model expects metadata inputs + input_names = [input.name for input in session.get_inputs()] + + # Add jurisdiction and insurance type if supported by the model + if "jurisdiction_id" in input_names: + input_feed["jurisdiction_id"] = np.array([jurisdiction_id], dtype=np.int64) + if "insurance_type_id" in input_names: + input_feed["insurance_type_id"] = np.array([insurance_type_id], dtype=np.int64) + + # Run inference + outputs = session.run(output_names=["logits"], input_feed=input_feed) + + # Process outputs + result = scipy.special.softmax(outputs[0], axis=-1) + argmax = np.argmax(result[0]) + prob = result[0][argmax] + return { + "class_id": int(argmax), + "class_name": ID2LABEL[int(argmax)], + "confidence": float(prob), + "probs": result[0].tolist(), + "logits": outputs[0][0], + } + except Exception as e: + return {"error": str(e), "class_name": "ERROR", "confidence": 0.0, "probs": [0.0, 0.0, 0.0]} + + +def print_result(model_name, result, expected=None, show_probs=True, show_logits=True): + """Print inference result with color-coding based on matching expected output""" + if "error" in result: + print(f"{model_name}: {Fore.RED}ERROR{Style.RESET_ALL} - {result['error']}") + return + + # Determine color based on expected value + color = Fore.WHITE + if expected: + if result["class_name"] == expected: + color = Fore.GREEN + else: + color = Fore.RED + + # Print result + print(f"{model_name}: {color}{result['class_name']}{Style.RESET_ALL} ({result['confidence']:.4f})") + + # Optionally show logits + if show_logits: + logits_str = ", ".join([f"{ID2LABEL[i]}: {p:.4f}" for i, p in enumerate(result["logits"])]) + print(f" Logits: {logits_str}") + + # Optionally show probabilities + if show_probs: + probs_str = ", ".join([f"{ID2LABEL[i]}: {p:.4f}" for i, p in enumerate(result["probs"])]) + print(f" Probabilities: {probs_str}") + + +def test_all_examples(model, model_int8, onnx_session, onnx_quant_session, tokenizer): + """Run all test examples through all model variants""" + print(f"\n{Fore.CYAN}{Style.BRIGHT}===== TESTING ALL EXAMPLES ====={Style.RESET_ALL}") + + for i, example in enumerate(TEST_EXAMPLES): + text = example["text"] + expected = example["expected"] + + # Print a shortened version of the text + print( + f"\n{Fore.YELLOW}{Style.BRIGHT}Example {i + 1}: '{text[:250]}...' (Expected: {expected}){Style.RESET_ALL}" + ) + + # Run through all model variants + pt_result = run_pytorch_model(model, tokenizer, text) + # pt_no_metadata_result = run_pytorch_model_no_metadata(model, tokenizer, text) + pt_quant_result = run_pytorch_quantized(model_int8, tokenizer, text) + + onnx_result = run_onnx(onnx_session, tokenizer, text) + onnx_quant_result = run_onnx(onnx_quant_session, tokenizer, text) + + # Print results + print_result("PyTorch (with metadata)", pt_result, expected) + # print_result("PyTorch (no metadata)", pt_no_metadata_result, expected) + print_result("PyTorch (quantized)", pt_quant_result, expected) + print_result("ONNX", onnx_result, expected) + print_result("ONNX (quantized)", onnx_quant_result, expected) + + print("-" * 80) + + +def process_single_prompt(model, model_int8, onnx_session, onnx_quant_session, tokenizer, text): + """Process a single user-provided prompt through all model variants""" + print(f"\n{Fore.YELLOW}{Style.BRIGHT}Processing: '{text[:250]}...'{Style.RESET_ALL}") + + # Run through all model variants + pt_result = run_pytorch_model(model, tokenizer, text) + pt_no_metadata_result = run_pytorch_model_no_metadata(model, tokenizer, text) + pt_quant_result = run_pytorch_quantized(model_int8, tokenizer, text) + + onnx_result = run_onnx(onnx_session, tokenizer, text) + onnx_quant_result = run_onnx(onnx_quant_session, tokenizer, text) + + # Print results + print_result("PyTorch (with metadata)", pt_result, show_probs=True) + print_result("PyTorch (no metadata)", pt_no_metadata_result, show_probs=True) + print_result("PyTorch (quantized)", pt_quant_result) + print_result("ONNX", onnx_result) + print_result("ONNX (quantized)", onnx_quant_result) + + +def calibrate_model(model, tokenizer, calibration_data): + """Run calibration data through the model for quantization preparation""" + print(f"{Fore.CYAN}Calibrating model with {len(calibration_data)} examples...{Style.RESET_ALL}") + model.eval() + + # Define mappings for jurisdiction and insurance type if not already in the calibration data + JURISDICTION_MAP = {"NY": 0, "CA": 1, "Unspecified": 2} + INSURANCE_TYPE_MAP = {"Commercial": 0, "Medicaid": 1, "Unspecified": 2} + + with torch.no_grad(): + for example in calibration_data: + text = example["text"] - tokenizer = AutoTokenizer.from_pretrained(pretrained_model_key, model_max_length=512) + # Get jurisdiction and insurance type IDs, default to "Unspecified" (2) + if "jurisdiction" in example: + j_id = JURISDICTION_MAP.get(example["jurisdiction"], 2) + else: + j_id = example.get("jurisdiction_id", 2) - dataset_name = "train_backgrounds_suff" - checkpoints_dir = os.path.join(MODEL_DIR, dataset_name, pretrained_model_key) + if "insurance_type" in example: + i_id = INSURANCE_TYPE_MAP.get(example["insurance_type"], 2) + else: + i_id = example.get("insurance_type_id", 2) + + # Tokenize the text + tokenized = tokenizer(text, return_tensors="pt", truncation=True, padding=True) + + # Convert jurisdiction and insurance type to tensors + j_tensor = torch.tensor([j_id]) + i_tensor = torch.tensor([i_id]) + + # Run forward pass for calibration + _ = model(**tokenized, jurisdiction_id=j_tensor, insurance_type_id=i_tensor) + + return model + + +def quantize_model_with_proper_embedding_config(model, tokenizer, calibration_data=None): + """Quantize model focusing on embeddings and linear layers while avoiding layer_norm issues""" + print(f"{Fore.CYAN}Setting up quantization for embeddings and linear layers...{Style.RESET_ALL}") + + # Create a copy of the model to preserve the original + model_copy = copy.deepcopy(model) + model_copy.eval() + + # Dynamic quantization approach that properly handles both layer types + model_int8 = torch.ao.quantization.quantize_dynamic( + model_copy, + { + torch.nn.Linear, + # torch.nn.Embedding + }, # quantize both linear and embedding layers + dtype=torch.qint8, + ) + + print(f"{Fore.GREEN}Quantization complete (Linear and Embedding layers){Style.RESET_ALL}") + return model_int8 + + +def main(): + parser = argparse.ArgumentParser(description="Appeal Classification Model Inference") + parser.add_argument("--test", action="store_true", help="Run all test examples") + parser.add_argument("--prompt", type=str, help="Text to classify") + args = parser.parse_args() + + if not args.test and not args.prompt: + parser.error("Either --test or --prompt must be specified") + + # Load model and tokenizer + dataset_name = "train_backgrounds_suff_augmented" + checkpoints_dir = os.path.join(MODEL_DIR, dataset_name, "distilbert") checkpoint_dirs = sorted(os.listdir(checkpoints_dir)) checkpoint_name = checkpoint_dirs[0] ckpt_path = os.path.join(checkpoints_dir, checkpoint_name) - model = AutoModelForSequenceClassification.from_pretrained( - ckpt_path, num_labels=3, id2label=ID2LABEL, label2id=LABEL2ID - ) - # Upheld example - text = "Diagnosis: Broken Ribs\nTreatment: Inpatient Hospital Admission\n\nThe insurer denied inpatient hospital admission. \n\nThe patient is an adult male. He presented by ambulance to the hospital with severe back pain. The patient had fallen down a ramp and onto his back two days prior. The patient developed back pain and had pain with deep inspiration, prompting a call to 911 for an ambulance. The patient was taking ibuprofen and Tylenol for pain at home. A computed tomography (CT) scan of the patient's chest showed a right posterior minimally displaced 9th and 10th rib fractures. There was no associated intra-abdominal injury. There was atelectasis of the lung in the region of the rib fractures. Vital signs, including oxygen saturation, were normal in the emergency department triage note. The patient did not require supplemental oxygen during the hospitalization. The patient was admitted to the acute inpatient level of care for pain control, breathing treatments, and venous thromboembolism prophylaxis. The patient was seen and cleared by Physical Therapy. The patient's pain was controlled with oral analgesia and a lidocaine patch. Total time in the hospital was less than 13 hours. The acute inpatient level of care was denied coverage by the health plan as not medically necessary." - - # Overturned example - # text = "Diagnosis: General Debility Due to Lumbar Stenosis\nTreatment: Continued Stay in Skilled Nursing Facility\n\nThe insurer denied continued stay in skilled nursing facility\n\nThe patient is an adult female with a history of general debility due to lumbar stenosis affecting her functional mobility and activities of daily living (ADLs). She has impairments of balance, mobility, and strength, with an increased risk for falls.\n\nThe patient's relevant past medical history includes obesity status post gastric sleeve times two (x2), severe knee and hip osteoarthritis, anxiety, bipolar disorder, hiatus hernia, depression, asthma, hiatus hernia/gastroesophageal reflux disease (GERD), fractured ribs, fractured ankle, sarcoidosis, and pulmonary embolism. Before admission, the patient was living with family and friends in a house, independent with activities of daily living, and with support from others. The patient was admitted to a skilled nursing facility (SNF) three months ago, requiring total dependence for most activities of daily living, and as of two months ago, the patient was non-ambulatory, requiring supervision for bed mobility, contact guard for transfers, and maximum assistance for static standing. She has limitations in completing mobility and locomotive activities due to gross weakness of the bilateral lower extremities, decreased stability and controlled mobility, increased pain, impaired coordination, and decreased aerobic capacity. \n" - # text = "Diagnosis: Dilated CBD, distended gallbladder\n \nTreatment: Inpatient admission, diagnostic treatment and surgery\n\nThe insurer denied the inpatient admission. The patient presented with abdominal pain. He was afebrile and the vital signs were stable. There was no abdominal rebound or guarding. The WBC count was 14.5. The bilirubin was normal. A CAT scan revealed a dilated CBD and a distended gallbladder. An MRCP revealed a CBD stone with bile duct dilatation. The patient was treated with antibiotics. He underwent an ERCP with sphincterotomy and balloon sweeps. A laparoscopic cholecystectomy was then done. The patient remained hemodynamically stable and his pain was controlled." - # text = "This is a female patient with a medical history of severe bilateral proliferative diabetic retinopathy and diabetic macular edema. The patient underwent an injection of Lucentis in her left eye and treatment with panretinal photocoagulation without complications. It was reported that the patient had severe disease with many dot/blot hemorrhages. Documentation revealed the patient had arteriovenous (AV) crossing changes bilaterally with venous tortuosity. There were scattered dot/blot hemorrhages bilaterally to the macula and periphery and macular edema. Additionally, she was counseled on proper diet control, exercise and hypertension control. Avastin and Mvasi are the same drug - namely bevacizumab: as per lexicomp: 'humanized monoclonal antibody which binds to, and neutralizes, vascular endothelial growth factor (VEGF), preventing its association with endothelial receptors, Flt-1 and KDR. VEGF binding initiates angiogenesis (endothelial proliferation and the formation of new blood vessels). The inhibition of microvascular growth is believed to retard the growth of all tissues (including metastatic tissue).' Lucentis is ranibizumab: as per lexicomp: 'a recombinant humanized monoclonal antibody fragment which binds to and inhibits human vascular endothelial growth factor A (VEGF-A). Ranibizumab inhibits VEGF from binding to its receptors and thereby suppressing neovascularization and slowing vision loss.' The formulary, step therapy options and the requested drug act against VEGF. There is no suggestion that Avastin or Mvasi would cause physical or mental harm to the patient. There are no contraindications in the documentation that would put the patient at risk for adverse reactions. This patient has a diagnosis of maculopathy. Avastin and Mvasi have been shown to be helpful with this condition." - # tokens = tokenizer(text) - # with torch.no_grad(): - # output = model(torch.tensor(tokens["input_ids"])) - # print(output) - - # Upheld - # text = "A patient is being denied wegovy for morbid obesity. The health plan states it is not medically necessary." - - classifier = ClassificationPipeline(model=model, tokenizer=tokenizer) - classifier2 = pipeline("text-classification", model=model, tokenizer=tokenizer) - start = time.time() - result = classifier(text) - # result = classifier2(text) - end = time.time() - print("Vanilla HF pipeline:") - print(f"Logits: {result[0]}") - probs = torch.softmax(result[0], dim=-1) - print(f"Probs: {probs}") - prob, argmax = torch.max(probs, dim=-1) - print(f"Class pred: {argmax.item()}, Score: {prob.item()}") - print(f"Latency: {end-start}") + print(f"{Fore.CYAN}Loading models...{Style.RESET_ALL}") + + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained(ckpt_path) + + # Load PyTorch model + model = TextClassificationWithMetadata.from_pretrained( + ckpt_path, + num_labels=3, + id2label=ID2LABEL, + label2id=LABEL2ID, + ) print( - "Model size (MB):", + "Vanilla Pytorch Model size (MB):", round(os.path.getsize(os.path.join(ckpt_path, "model.safetensors")) / (1024 * 1024)), - "\n", ) - # Pytorch quantized model - model = classifier.model - # TODO: Fix this, this is not the right model to be quantizing via the torch or onnx ops below. - model_int8 = torch.ao.quantization.quantize_dynamic( - model, # the original model - {torch.nn.Linear}, # a set of layers to dynamically quantize - dtype=torch.qint8, - ) # the target dtype for quantized weights - model_int8.eval() - start = time.time() - tokenized = tokenizer(text, return_tensors="pt") - with torch.no_grad(): - result = model_int8(**tokenized) - end = time.time() - print("Quantized pytorch model:") - probs = torch.softmax(result.logits, dim=-1) - print(f"Probs: {probs}") - prob, argmax = torch.max(probs, dim=-1) - print(f"Class pred: {argmax.item()}, Score: {prob.item()}") - print(f"Latency: {end-start}") + # Prepare calibration data from test examples + calibration_data = [] + for example in TEST_EXAMPLES: + # Add jurisdiction and insurance_type variations for each example + for j_id in [0, 1, 2]: # NY, CA, Unspecified + for i_id in [0, 1, 2]: # Commercial, Medicaid, Unspecified + calibration_data.append({"text": example["text"], "jurisdiction_id": j_id, "insurance_type_id": i_id}) + + # Prepare calibration data from test examples + calibration_data = [] + for example in TEST_EXAMPLES: + # Add jurisdiction and insurance_type variations for each example + for j_id in [0, 1, 2]: # NY, CA, Unspecified + for i_id in [0, 1, 2]: # Commercial, Medicaid, Unspecified + calibration_data.append({"text": example["text"], "jurisdiction_id": j_id, "insurance_type_id": i_id}) + + # Quantize model with proper embedding configuration + model_int8 = quantize_model_with_proper_embedding_config(model.to("cpu"), tokenizer, calibration_data) + + print(f"{Fore.GREEN}Model calibration and quantization complete{Style.RESET_ALL}") param_size = 0 for param in model_int8.parameters(): param_size += param.nelement() * param.element_size() @@ -104,59 +372,42 @@ def postprocess(self, model_outputs): buffer_size += buffer.nelement() * buffer.element_size() size_all_mb = (param_size + buffer_size) / (1024 * 1024) - print(f"Model size (MB): {round(size_all_mb)}\n") + print(f"Quantized Pytorch Model size (MB): {round(size_all_mb)}") - # Export onxx model and quantized version, if nonexistent - onnx_file_name = "model.onnx" - onnx_model_path = os.path.join(ckpt_path, onnx_file_name) + # Load ONNX models + onnx_model_path = os.path.join(ckpt_path, "model.onnx") quant_onnx_model_path = os.path.join(ckpt_path, "quant-model.onnx") - export_onnx_model(onnx_model_path, model, tokenizer) + + # Ensure ONNX models exist + if not os.path.exists(onnx_model_path): + print("Exporting model to ONNX format...") + export_onnx_model(onnx_model_path, model, tokenizer) + print( + "Onnx Model size (MB):", + round(os.path.getsize(onnx_model_path) / (1024 * 1024)), + ) if not os.path.exists(quant_onnx_model_path): + print("Quantizing ONNX model...") quantize_onnx_model(onnx_model_path, quant_onnx_model_path) - - print( - "ONNX full precision model size (MB):", - os.path.getsize(onnx_model_path) / (1024 * 1024), - ) - - # Try running with onnx runtime (quantized) - sess_options = onnxruntime.SessionOptions() - sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL - session = onnxruntime.InferenceSession(quant_onnx_model_path, sess_options) - start = time.time() - inputs = tokenizer(text, return_tensors="np") - outputs = session.run(output_names=["logits"], input_feed=dict(inputs)) - result = scipy.special.softmax(outputs[0], axis=-1) - end = time.time() - print("Quantized Onnx:") - print(f"Probs: {result[0]}") - argmax = np.argmax(result[0], axis=-1) - prob = result[0][argmax] - print(f"Class pred: {argmax}, Score: {prob}") - print(f"Latency: {end-start}") print( - "Model size (MB):", + "Quantized Onnx Model size (MB):", round(os.path.getsize(quant_onnx_model_path) / (1024 * 1024)), - "\n", ) - - # Try running with onnx runtime + # Create ONNX sessions sess_options = onnxruntime.SessionOptions() sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL - session = onnxruntime.InferenceSession(onnx_model_path, sess_options) - start = time.time() - inputs = tokenizer(text, return_tensors="np") - outputs = session.run(output_names=["logits"], input_feed=dict(inputs)) - result = scipy.special.softmax(outputs[0], axis=-1) - end = time.time() - print("Onnx:") - print(f"Probs: {result}") - argmax = np.argmax(result[0], axis=-1) - prob = result[0][argmax] - print(f"Class pred: {argmax}, Score: {prob}") - print(f"Latency: {end-start}") - print( - "Model size (MB):", - round(os.path.getsize(onnx_model_path) / (1024 * 1024)), - ) + onnx_session = onnxruntime.InferenceSession(onnx_model_path, sess_options) + onnx_quant_session = onnxruntime.InferenceSession(quant_onnx_model_path, sess_options) + + print(f"{Fore.GREEN}Models loaded successfully{Style.RESET_ALL}") + + # Run the appropriate mode + if args.test: + test_all_examples(model, model_int8, onnx_session, onnx_quant_session, tokenizer) + else: + process_single_prompt(model, model_int8, onnx_session, onnx_quant_session, tokenizer, args.prompt) + + +if __name__ == "__main__": + main() diff --git a/src/modeling/pretrain.py b/src/modeling/pretrain.py index 0101919..d32a37b 100644 --- a/src/modeling/pretrain.py +++ b/src/modeling/pretrain.py @@ -9,6 +9,7 @@ import datasets import numpy as np +import wandb from datasets import Dataset from transformers import ( AutoModelForMaskedLM, @@ -18,7 +19,6 @@ default_data_collator, ) -import wandb from src.modeling.util import load_config os.environ["TOKENIZERS_PARALLELISM"] = "false" diff --git a/src/modeling/train_background_token_classification.py b/src/modeling/train_background_token_classification.py index 0f8bdd8..31b0302 100644 --- a/src/modeling/train_background_token_classification.py +++ b/src/modeling/train_background_token_classification.py @@ -6,6 +6,7 @@ import evaluate import numpy as np +import wandb from datasets import load_dataset from transformers import ( AutoModelForTokenClassification, @@ -16,7 +17,6 @@ TrainingArguments, ) -import wandb from src.modeling.util import load_config NO_CLASS_ENCODING = 0 diff --git a/src/modeling/train_outcome_predictor.py b/src/modeling/train_outcome_predictor.py index 72dc61c..14a4b51 100644 --- a/src/modeling/train_outcome_predictor.py +++ b/src/modeling/train_outcome_predictor.py @@ -5,7 +5,10 @@ import numpy as np import scipy -from datasets import Dataset, load_dataset +import torch +import torch.nn as nn +import wandb +from datasets import Dataset, DatasetDict, load_dataset from sklearn.metrics import ( accuracy_score, f1_score, @@ -13,35 +16,57 @@ recall_score, roc_auc_score, ) +from sklearn.model_selection import train_test_split from transformers import ( AutoModelForSequenceClassification, AutoTokenizer, + BertForSequenceClassification, DataCollatorWithPadding, + DistilBertForSequenceClassification, Trainer, TrainingArguments, ) +from transformers.modeling_outputs import SequenceClassifierOutput -import wandb from src.modeling.util import export_onnx_model, load_config, quantize_onnx_model from src.util import get_records_list ID2LABEL = {0: "Insufficient", 1: "Upheld", 2: "Overturned"} LABEL2ID = {v: k for k, v in ID2LABEL.items()} + +# Define mappings for jurisdiction and insurance type +JURISDICTION_MAP = {"NY": 0, "CA": 1, "Unspecified": 2} +INSURANCE_TYPE_MAP = {"Commercial": 0, "Medicaid": 1, "Unspecified": 2} + OUTPUT_DIR = "./models/overturn_predictor" os.environ["TOKENIZERS_PARALLELISM"] = "false" def load_and_split( - jsonl_path: str, test_size: float = 0.2, filter_keys: list = ["decision", "text", "sufficiency_id"], seed: int = 2 + jsonl_path: str, + test_size: float = 0.2, + filter_keys: list = ["decision", "text", "sufficiency_id", "jurisdiction", "insurance_type"], + seed: int = 2, ) -> Dataset: if len(filter_keys) > 0: recs = get_records_list(jsonl_path) - recs = [{key: rec[key] for key in filter_keys} for rec in recs] + recs = [{key: rec.get(key, "Unspecified") for key in filter_keys} for rec in recs] dataset = Dataset.from_list(recs) + else: dataset = load_dataset("json", data_files=jsonl_path)["train"] - dataset = dataset.train_test_split(test_size=test_size, seed=1) - return dataset + + # Use scikit-learn train_test_split for stratified sampling + train_indices, test_indices = train_test_split( + list(range(len(dataset))), test_size=test_size, random_state=seed, shuffle=True, stratify=dataset["decision"] + ) + + # Create train and test datasets using the indices + train_dataset = dataset.select(train_indices) + test_dataset = dataset.select(test_indices) + + # Return as a DatasetDict + return DatasetDict({"train": train_dataset, "test": test_dataset}) def tokenize_batch(examples, tokenizer): @@ -59,10 +84,20 @@ def construct_label(outcome: str, sufficiency_id: int, label2id: dict) -> int: return label2id[outcome] -def add_integral_ids_batch(examples, label2id: dict): +def add_integral_ids_batch(examples, label2id: dict, jurisdiction_map: dict, insurance_type_map: dict): outcomes = examples["decision"] sufficiency_ids = examples["sufficiency_id"] + + # Map jurisdiction and insurance_type to their respective IDs + jurisdictions = examples.get("jurisdiction", ["Unspecified"] * len(outcomes)) + insurance_types = examples.get("insurance_type", ["Unspecified"] * len(outcomes)) + examples["label"] = [construct_label(outcome, id, label2id) for (outcome, id) in zip(outcomes, sufficiency_ids)] + examples["jurisdiction_id"] = [jurisdiction_map.get(j, jurisdiction_map["Unspecified"]) for j in jurisdictions] + examples["insurance_type_id"] = [ + insurance_type_map.get(i, insurance_type_map["Unspecified"]) for i in insurance_types + ] + return examples @@ -155,6 +190,485 @@ def compute_metrics2(eval_pred) -> dict: return best_metrics +class TextClassificationWithMetadata(DistilBertForSequenceClassification): + def __init__(self, config): + super().__init__(config) + self.hidden_size = config.hidden_size + + # Metadata embeddings + self.jurisdiction_embeddings = nn.Embedding(2, self.hidden_size // 4) + self.insurance_type_embeddings = nn.Embedding(2, self.hidden_size // 4) + + # Fusion projection + self.metadata_projection = nn.Sequential( + nn.Linear(self.hidden_size // 2, self.hidden_size), + nn.LayerNorm(self.hidden_size), + nn.GELU(), + nn.Dropout(0.1), + ) + + # Manual cross-attention components + self.num_heads = 8 + self.head_dim = self.hidden_size // self.num_heads + + # Query, Key, Value projections + self.q_proj = nn.Linear(self.hidden_size, self.hidden_size) + self.k_proj = nn.Linear(self.hidden_size, self.hidden_size) + self.v_proj = nn.Linear(self.hidden_size, self.hidden_size) + self.out_proj = nn.Linear(self.hidden_size, self.hidden_size) + + self.dropout = nn.Dropout(0.1) + + self.metadata_norm = nn.LayerNorm(self.hidden_size) + self.sequence_norm = nn.LayerNorm(self.hidden_size) + + # Save the original num_labels + self.num_labels = ( + self.classifier.out_features if hasattr(self.classifier, "out_features") else config.num_labels + ) + + # Now replace the classifier + self.enhanced_classifier = nn.Sequential( + nn.Linear(self.hidden_size, self.hidden_size), + nn.LayerNorm(self.hidden_size), + nn.GELU(), + nn.Dropout(0.1), + nn.Linear(self.hidden_size, self.num_labels), + ) + + def manual_cross_attention(self, query, key, value, key_padding_mask=None): + # Same implementation, but with explicit cleanup + batch_size = query.size(1) + seq_length = key.size(0) + + # Project query, key, value + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + + # Reshape for multi-head attention + q = q.view(1, batch_size, self.num_heads, self.head_dim).permute(2, 1, 0, 3) + k = k.view(seq_length, batch_size, self.num_heads, self.head_dim).permute(2, 1, 0, 3) + v = v.view(seq_length, batch_size, self.num_heads, self.head_dim).permute(2, 1, 0, 3) + + # Calculate attention scores + scaling = float(self.head_dim) ** -0.5 + q = q * scaling + attn_scores = torch.matmul(q, k.transpose(-2, -1)) + + # Apply attention mask if provided + if key_padding_mask is not None: + attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(1) + attn_mask = attn_mask.transpose(0, 1) + attn_mask = attn_mask.expand(self.num_heads, -1, 1, -1) + attn_scores = attn_scores.masked_fill(attn_mask, -10000.0) + + # Get max values for numerical stability + attn_scores_max, _ = torch.max(attn_scores, dim=-1, keepdim=True) + attn_scores = attn_scores - attn_scores_max + + # Apply softmax + attn_weights = torch.exp(torch.clamp(attn_scores, min=-10000.0, max=100.0)) + attn_sum = torch.sum(attn_weights, dim=-1, keepdim=True) + 1e-6 + attn_weights = attn_weights / attn_sum + + # Apply dropout + attn_weights = self.dropout(attn_weights) + + # Clean up intermediates + del attn_scores, attn_scores_max, attn_sum + if key_padding_mask is not None: + del attn_mask + + # Apply attention weights to values + context = torch.matmul(attn_weights, v) + + # Clean up more intermediates + del attn_weights + + # Reshape back + context = context.permute(2, 1, 0, 3) + context = context.reshape(1, batch_size, self.hidden_size) + + # Clean up + del q, k, v + + # Apply output projection + result = self.out_proj(context) + + # Final cleanup + del context + + return result + + def forward( + self, + input_ids=None, + attention_mask=None, + jurisdiction_id=None, + insurance_type_id=None, + labels=None, + return_dict=None, + token_type_ids=None, + ): + # Validate that required parameters are provided, use index 2 for optional + if jurisdiction_id is None or insurance_type_id is None: + raise ValueError("jurisdiction_id and insurance_type_id must be provided") + + # Process metadata + j_unspecified_mask = jurisdiction_id == 2 + i_unspecified_mask = insurance_type_id == 2 + + # Temporarily assign unspecified metadata classes -> 1 + j_ids_valid = torch.clamp(jurisdiction_id, 0, 1) + i_ids_valid = torch.clamp(insurance_type_id, 0, 1) + + j_embeddings = self.jurisdiction_embeddings(j_ids_valid) + i_embeddings = self.insurance_type_embeddings(i_ids_valid) + + avg_j_embedding = self.jurisdiction_embeddings.weight.mean(dim=0) + avg_i_embedding = self.insurance_type_embeddings.weight.mean(dim=0) + + # Assign unspecified classes the average embeddings + j_embeddings = torch.where( + j_unspecified_mask.unsqueeze(-1).expand_as(j_embeddings), + avg_j_embedding.expand_as(j_embeddings), + j_embeddings, + ) + i_embeddings = torch.where( + i_unspecified_mask.unsqueeze(-1).expand_as(i_embeddings), + avg_i_embedding.expand_as(i_embeddings), + i_embeddings, + ) + + # Combine metadata and project + combined_metadata = torch.cat([j_embeddings, i_embeddings], dim=1) + metadata_features = self.metadata_projection(combined_metadata) + + # Call the parent model but bypass its classification head + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + base_outputs = super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, # We need the hidden states + return_dict=True, + token_type_ids=token_type_ids, + labels=None, # Don't pass labels yet + ) + + sequence_output = base_outputs.hidden_states[-1] + + # Use cross-attention for fusion + metadata_features = self.metadata_norm(metadata_features) + sequence_output = self.sequence_norm(sequence_output) + + # Reshape for attention + metadata_features = metadata_features.unsqueeze(0) # [1, batch, hidden_dim] + sequence_output_t = sequence_output.transpose(0, 1) # [seq_len, batch, hidden_dim] + + key_padding_mask = attention_mask == 0 + + fused_features = self.manual_cross_attention( + query=metadata_features, key=sequence_output_t, value=sequence_output_t, key_padding_mask=key_padding_mask + ) + + # Combine with CLS token + fused_features = fused_features.squeeze(0) + sequence_output[:, 0] + + logits = self.enhanced_classifier(fused_features) + + # Calculate loss if labels provided + loss = None + if labels is not None: + loss = torch.nn.functional.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1)) + + if return_dict: + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=base_outputs.hidden_states, + attentions=base_outputs.attentions, + ) + else: + output = (logits,) + base_outputs[2:] + return ((loss,) + output) if loss is not None else output + + +# Nearly identical but inherits from DistilBertForSequenceClassification +# and doesn't use token_type_ids +class DistilBertTextClassificationWithMetadata(DistilBertForSequenceClassification): + def __init__(self, config): + super().__init__(config) + self.hidden_size = config.hidden_size + + # Metadata embeddings + self.jurisdiction_embeddings = nn.Embedding(2, self.hidden_size // 4) + self.insurance_type_embeddings = nn.Embedding(2, self.hidden_size // 4) + + # Fusion projection + self.metadata_projection = nn.Sequential( + nn.Linear(self.hidden_size // 2, self.hidden_size), + nn.LayerNorm(self.hidden_size), + nn.GELU(), + nn.Dropout(0.1), + ) + + # Manual cross-attention components + self.num_heads = 8 + self.head_dim = self.hidden_size // self.num_heads + + # Query, Key, Value projections + self.q_proj = nn.Linear(self.hidden_size, self.hidden_size) + self.k_proj = nn.Linear(self.hidden_size, self.hidden_size) + self.v_proj = nn.Linear(self.hidden_size, self.hidden_size) + self.out_proj = nn.Linear(self.hidden_size, self.hidden_size) + + self.dropout = nn.Dropout(0.1) + + self.metadata_norm = nn.LayerNorm(self.hidden_size) + self.sequence_norm = nn.LayerNorm(self.hidden_size) + + # Save the original num_labels + self.num_labels = ( + self.classifier.out_features if hasattr(self.classifier, "out_features") else config.num_labels + ) + + # Now replace the classifier + self.enhanced_classifier = nn.Sequential( + nn.Linear(self.hidden_size, self.hidden_size), + nn.LayerNorm(self.hidden_size), + nn.GELU(), + nn.Dropout(0.1), + nn.Linear(self.hidden_size, self.num_labels), + ) + + def manual_cross_attention(self, query, key, value, key_padding_mask=None): + # Same implementation, but with explicit cleanup + batch_size = query.size(1) + seq_length = key.size(0) + + # Project query, key, value + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + + # Reshape for multi-head attention + q = q.view(1, batch_size, self.num_heads, self.head_dim).permute(2, 1, 0, 3) + k = k.view(seq_length, batch_size, self.num_heads, self.head_dim).permute(2, 1, 0, 3) + v = v.view(seq_length, batch_size, self.num_heads, self.head_dim).permute(2, 1, 0, 3) + + # Calculate attention scores + scaling = float(self.head_dim) ** -0.5 + q = q * scaling + attn_scores = torch.matmul(q, k.transpose(-2, -1)) + + # Apply attention mask if provided + if key_padding_mask is not None: + attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(1) + attn_mask = attn_mask.transpose(0, 1) + attn_mask = attn_mask.expand(self.num_heads, -1, 1, -1) + attn_scores = attn_scores.masked_fill(attn_mask, -10000.0) + + # Get max values for numerical stability + attn_scores_max, _ = torch.max(attn_scores, dim=-1, keepdim=True) + attn_scores = attn_scores - attn_scores_max + + # Apply softmax + attn_weights = torch.exp(torch.clamp(attn_scores, min=-10000.0, max=100.0)) + attn_sum = torch.sum(attn_weights, dim=-1, keepdim=True) + 1e-6 + attn_weights = attn_weights / attn_sum + + # Apply dropout + attn_weights = self.dropout(attn_weights) + + # Clean up intermediates + del attn_scores, attn_scores_max, attn_sum + if key_padding_mask is not None: + del attn_mask + + # Apply attention weights to values + context = torch.matmul(attn_weights, v) + + # Clean up more intermediates + del attn_weights + + # Reshape back + context = context.permute(2, 1, 0, 3) + context = context.reshape(1, batch_size, self.hidden_size) + + # Clean up + del q, k, v + + # Apply output projection + result = self.out_proj(context) + + # Final cleanup + del context + + return result + + def forward( + self, + input_ids=None, + attention_mask=None, + jurisdiction_id=None, + insurance_type_id=None, + labels=None, + return_dict=None, + token_type_ids=None, # This parameter is left for API compatibility but not used + ): + # Validate that required parameters are provided, use index 2 for optional + if jurisdiction_id is None or insurance_type_id is None: + raise ValueError("jurisdiction_id and insurance_type_id must be provided") + + # Process metadata + j_unspecified_mask = jurisdiction_id == 2 + i_unspecified_mask = insurance_type_id == 2 + + # Temporarily assign unspecified metadata classes -> 1 + j_ids_valid = torch.clamp(jurisdiction_id, 0, 1) + i_ids_valid = torch.clamp(insurance_type_id, 0, 1) + + j_embeddings = self.jurisdiction_embeddings(j_ids_valid) + i_embeddings = self.insurance_type_embeddings(i_ids_valid) + + avg_j_embedding = self.jurisdiction_embeddings.weight.mean(dim=0) + avg_i_embedding = self.insurance_type_embeddings.weight.mean(dim=0) + + # Assign unspecified classes the average embeddings + j_embeddings = torch.where( + j_unspecified_mask.unsqueeze(-1).expand_as(j_embeddings), + avg_j_embedding.expand_as(j_embeddings), + j_embeddings, + ) + i_embeddings = torch.where( + i_unspecified_mask.unsqueeze(-1).expand_as(i_embeddings), + avg_i_embedding.expand_as(i_embeddings), + i_embeddings, + ) + + # Combine metadata and project + combined_metadata = torch.cat([j_embeddings, i_embeddings], dim=1) + metadata_features = self.metadata_projection(combined_metadata) + + # Call the parent model but bypass its classification head + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + base_outputs = super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + # No token_type_ids for DistilBERT + output_hidden_states=True, # We need the hidden states + return_dict=True, + labels=None, # Don't pass labels yet + ) + + sequence_output = base_outputs.hidden_states[-1] + + # Use cross-attention for fusion + metadata_features = self.metadata_norm(metadata_features) + sequence_output = self.sequence_norm(sequence_output) + + # Reshape for attention + metadata_features = metadata_features.unsqueeze(0) # [1, batch, hidden_dim] + sequence_output_t = sequence_output.transpose(0, 1) # [seq_len, batch, hidden_dim] + + key_padding_mask = attention_mask == 0 + + fused_features = self.manual_cross_attention( + query=metadata_features, key=sequence_output_t, value=sequence_output_t, key_padding_mask=key_padding_mask + ) + + # Combine with CLS token + fused_features = fused_features.squeeze(0) + sequence_output[:, 0] + + logits = self.enhanced_classifier(fused_features) + + # Calculate loss if labels provided + loss = None + if labels is not None: + loss = torch.nn.functional.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1)) + + if return_dict: + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=base_outputs.hidden_states, + attentions=base_outputs.attentions, + ) + else: + output = (logits,) + base_outputs[2:] + return ((loss,) + output) if loss is not None else output + + +class DataCollatorWithMetadata(DataCollatorWithPadding): + def __call__(self, features): + batch = super().__call__(features) + + # Add the jurisdiction and insurance type IDs to the batch + if "jurisdiction_id" in features[0]: + batch["jurisdiction_id"] = torch.tensor([f["jurisdiction_id"] for f in features]) + + if "insurance_type_id" in features[0]: + batch["insurance_type_id"] = torch.tensor([f["insurance_type_id"] for f in features]) + + return batch + + +class CPUEvalTrainer(Trainer): + def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None): + has_labels = all(inputs.get(k) is not None for k in self.label_names) + inputs = self._prepare_inputs(inputs) + + with torch.no_grad(): + outputs = model(**inputs) + logits = outputs.logits + + # Move logits to CPU immediately to save GPU memory + logits = logits.detach().cpu() + + labels = None + if has_labels: + labels = tuple(inputs.get(name).detach().cpu() for name in self.label_names) + if len(labels) == 1: + labels = labels[0] + + loss = None + if has_labels and outputs.loss is not None: + loss = outputs.loss.detach().cpu() + + return (loss, logits, labels) + + +def create_metadata_model(pretrained_model_key, base_model_name, num_labels, id2label, label2id): + # Map of base model names to model classes + MODEL_CLASS_MAP = { + # BERT-based models + "legal-bert-small-uncased": TextClassificationWithMetadata, + # DistilBERT-based models + "distilbert": DistilBertTextClassificationWithMetadata, + "clinicalbert": DistilBertTextClassificationWithMetadata, + } + + # Get model class based on base_model_name + # This will use TextClassificationWithMetadata by default for unrecognized models + model_class = MODEL_CLASS_MAP.get( + base_model_name, + DistilBertTextClassificationWithMetadata + if "distilbert" in base_model_name.lower() + else TextClassificationWithMetadata, + ) + + # Create and return the appropriate model + return model_class.from_pretrained( + pretrained_model_key, + num_labels=num_labels, + id2label=id2label, + label2id=label2id, + ) + + def main(config_path: str) -> None: cfg = load_config(config_path) @@ -193,12 +707,22 @@ def main(config_path: str) -> None: # Prepare dataset for training dataset = dataset.map(partial(tokenize_batch, tokenizer=tokenizer), batched=True) - dataset = dataset.map(partial(add_integral_ids_batch, label2id=LABEL2ID), batched=True) + dataset = dataset.map( + partial( + add_integral_ids_batch, + label2id=LABEL2ID, + jurisdiction_map=JURISDICTION_MAP, + insurance_type_map=INSURANCE_TYPE_MAP, + ), + batched=True, + ) - data_collator = DataCollatorWithPadding(tokenizer=tokenizer) + # Use our custom data collator that handles the additional features + data_collator = DataCollatorWithMetadata(tokenizer=tokenizer) - model = AutoModelForSequenceClassification.from_pretrained( - pretrained_model_key, + model = create_metadata_model( + pretrained_model_key=pretrained_model_key, + base_model_name=base_model_name, num_labels=3, id2label=ID2LABEL, label2id=LABEL2ID, @@ -213,13 +737,15 @@ def main(config_path: str) -> None: checkpoints_dir = os.path.join(OUTPUT_DIR, dataset_name, outdir_name) training_args = TrainingArguments( output_dir=checkpoints_dir, + run_name=cfg["wandb_run_tag"], learning_rate=cfg["learning_rate"], per_device_train_batch_size=cfg["batch_size"], per_device_eval_batch_size=cfg["batch_size"], + eval_accumulation_steps=16, num_train_epochs=cfg["num_epochs"], weight_decay=cfg["weight_decay"], fp16=(cfg["dtype"] == "float16"), - evaluation_strategy="epoch", + eval_strategy="epoch", save_strategy="epoch", save_total_limit=1, load_best_model_at_end=True, @@ -229,12 +755,12 @@ def main(config_path: str) -> None: dataloader_pin_memory=True, ) - trainer = Trainer( + trainer = CPUEvalTrainer( model=model, args=training_args, train_dataset=dataset["train"], eval_dataset=dataset["test"], - tokenizer=tokenizer, + processing_class=tokenizer, data_collator=data_collator, compute_metrics=compute_metrics2, ) diff --git a/src/modeling/train_sufficiency_classifier.py b/src/modeling/train_sufficiency_classifier.py index e174f76..c4bf976 100644 --- a/src/modeling/train_sufficiency_classifier.py +++ b/src/modeling/train_sufficiency_classifier.py @@ -7,6 +7,7 @@ import numpy as np import scipy import torch +import wandb from datasets import Dataset from sklearn.metrics import ( accuracy_score, @@ -23,7 +24,6 @@ TrainingArguments, ) -import wandb from src.modeling.data_augmentation import ( load_augmented_dataset, process_and_save_augmentations, @@ -223,21 +223,12 @@ def main(config_path: str) -> None: elif use_data_augmentation: generic_rewrite_params = cfg.get( "generic_rewrite_params", - {"num_augmentations_per_example": 1, "seed": 1, "api_key": os.environ.get("OPENAI_API_KEY")}, - ) - unrelated_params = cfg.get( - "unrelated_params", {"num_examples": 500, "seed": 1, "api_key": os.environ.get("OPENAI_API_KEY")} - ) - sufficient_augmentation_params = cfg.get( - "sufficient_augmentation_params", - { - "num_augmentations_per_example": 1, - "api_type": "llamacpp", - "model_name": "Meta-Llama-3-8B-Instruct.Q4_K_M.gguf", - "seed": 1, - "api_key": os.environ.get("OPENAI_API_KEY"), - }, ) + generic_rewrite_params["api_key"] = os.environ.get("OPENAI_API_KEY") + unrelated_params = cfg.get("unrelated_params") + unrelated_params["api_key"] = os.environ.get("OPENAI_API_KEY") + sufficient_augmentation_params = cfg.get("sufficient_augmentation_params") + sufficient_augmentation_params["api_key"] = os.environ.get("OPENAI_API_KEY") # Set output paths based on save_augmentations flag output_train_path = augmented_train_path if save_augmentations else None diff --git a/src/modeling/util.py b/src/modeling/util.py index 5f8d83d..bf60d65 100644 --- a/src/modeling/util.py +++ b/src/modeling/util.py @@ -3,38 +3,100 @@ import torch import yaml from onnxruntime.quantization import QuantType, quantize_dynamic +from onnxruntime.quantization.shape_inference import quant_pre_process from transformers import AutoModel, AutoTokenizer -def quantize_onnx_model(onnx_model_path: str, quantized_model_path: str): - quantize_dynamic(onnx_model_path, quantized_model_path, weight_type=QuantType.QInt8) - print(f"quantized model saved to:{quantized_model_path}") - return None - - def export_onnx_model(output_model_path: str, model: torch.nn.Module | AutoModel, tokenizer: AutoTokenizer): if os.path.exists(output_model_path): print(f"Warning: overwriting existing ONNX model at path {output_model_path}") - dummy_text = "test" + + # Use a more representative dummy text + dummy_text = "This is a medical test example with sufficient content to exercise the model" + # Tokenize the text dummy_input = tokenizer(dummy_text, return_tensors="pt") - torch.onnx.export( - model, - tuple([dummy_input["input_ids"], dummy_input["attention_mask"]]), - f=output_model_path, - export_params=True, - input_names=["input_ids", "attention_mask"], - output_names=["logits"], - dynamic_axes={ - "input_ids": {0: "batch_size", 1: "sequence"}, - "attention_mask": {0: "batch_size", 1: "sequence"}, - "logits": {0: "batch_size", 1: "sequence"}, - }, - do_constant_folding=True, - opset_version=17, - ) + + # Add metadata inputs for the custom model + dummy_jurisdiction_id = torch.tensor([2]) # 2 = Unspecified + dummy_insurance_type_id = torch.tensor([2]) # 2 = Unspecified + + # Check if model accepts metadata + try: + # Test if model accepts metadata parameters + with torch.no_grad(): + _ = model( + input_ids=dummy_input["input_ids"], + attention_mask=dummy_input["attention_mask"], + jurisdiction_id=dummy_jurisdiction_id, + insurance_type_id=dummy_insurance_type_id, + ) + + # If we got here, model accepts metadata + print("Exporting model with metadata support...") + torch.onnx.export( + model, + tuple( + [ + dummy_input["input_ids"], + dummy_input["attention_mask"], + dummy_jurisdiction_id, + dummy_insurance_type_id, + ] + ), + f=output_model_path, + export_params=True, + input_names=["input_ids", "attention_mask", "jurisdiction_id", "insurance_type_id"], + output_names=["logits"], + dynamic_axes={ + "input_ids": {0: "batch_size", 1: "sequence"}, + "attention_mask": {0: "batch_size", 1: "sequence"}, + "jurisdiction_id": {0: "batch_size"}, + "insurance_type_id": {0: "batch_size"}, + "logits": {0: "batch_size"}, + }, + do_constant_folding=True, + opset_version=17, + ) + except Exception as e: + print(f"Model doesn't support metadata, exporting without it: {e}") + # Export without metadata + torch.onnx.export( + model, + tuple([dummy_input["input_ids"], dummy_input["attention_mask"]]), + f=output_model_path, + export_params=True, + input_names=["input_ids", "attention_mask"], + output_names=["logits"], + dynamic_axes={ + "input_ids": {0: "batch_size", 1: "sequence"}, + "attention_mask": {0: "batch_size", 1: "sequence"}, + "logits": {0: "batch_size"}, + }, + do_constant_folding=True, + opset_version=17, + ) + print(f"Exported ONNX model to {output_model_path}.") +def quantize_onnx_model(onnx_model_path: str, quantized_model_path: str): + # print("Preprocessing ONNX model before quantization...") + # pre_processed_model_path = onnx_model_path + ".pre_processed.onnx" + + # quant_pre_process(onnx_model_path, pre_processed_model_path) + + # Quantize the model + quantize_dynamic( + onnx_model_path, + quantized_model_path, + per_channel=False, + reduce_range=False, + weight_type=QuantType.QInt8, + ) + print(f"Quantized model saved to: {quantized_model_path}") + return None + + def load_config(config_path: str) -> dict: with open(config_path, "r") as stream: try: diff --git a/src/processors/ca_cdi.py b/src/processors/ca_cdi.py index c7d425f..07b2adf 100644 --- a/src/processors/ca_cdi.py +++ b/src/processors/ca_cdi.py @@ -63,6 +63,8 @@ def process(source_lineitem: dict, output_dirname: str) -> dict: "treatment": treatment, "decision": decision, "patient_race": patient_race, + "jurisdiction": "CA", + "insurance_type": "Commercial", } add_jsonl_line(outfile, line_data) diff --git a/src/processors/ca_dmhc.py b/src/processors/ca_dmhc.py index eb84326..7aba525 100644 --- a/src/processors/ca_dmhc.py +++ b/src/processors/ca_dmhc.py @@ -36,6 +36,8 @@ def process(source_lineitem: dict, output_dirname: str) -> dict: "decision": tuple[1], "appeal_type": tuple[2], "appeal_expedited_status": tuple[3], + "jurisdiction": "CA", + "insurance_type": "Commercial", } add_jsonl_line(outfile, line_data) diff --git a/src/processors/ny_dfs.py b/src/processors/ny_dfs.py index fa4c845..bb819e1 100644 --- a/src/processors/ny_dfs.py +++ b/src/processors/ny_dfs.py @@ -37,6 +37,15 @@ def extract_row_summaries(row) -> list[str]: axis=1, ).to_list() + coverage_type_map = { + "HMO": "Commercial", + "PPO": "Commercial", + "EPO": "Commercial", + "Self-Funded": "Commercial", + "Medicaid": "Medicaid", + "Managed Long Term Care": "Medicaid", + } + # Construct a record for each case summary # (some raw case adjudications include multiple summaries from different reviewers) for tuple in df_tuples: @@ -49,6 +58,8 @@ def extract_row_summaries(row) -> list[str]: "appeal_type": tuple[3], "diagnosis": tuple[4], "treatment": tuple[5], + "jurisdiction": "NY", + "insurance_type": coverage_type_map.get(tuple[2], "Unspecified"), } add_jsonl_line(outfile, line_data)