From 255ef1c81415f508dd1028b8e29b6785b8dc43e3 Mon Sep 17 00:00:00 2001 From: MikeG112 Date: Fri, 18 Apr 2025 18:01:34 -0400 Subject: [PATCH 01/27] Add insurance type metadata --- processed_sources.jsonl | 2 +- src/modeling/background_extraction.py | 10 +- src/modeling/train_outcome_predictor.py | 135 ++++++++++++++++++++++-- src/processors/ca_cdi.py | 2 + src/processors/ca_dmhc.py | 2 + src/processors/ny_dfs.py | 11 ++ 6 files changed, 152 insertions(+), 10 deletions(-) diff --git a/processed_sources.jsonl b/processed_sources.jsonl index c6e8d0e..0ab1d9e 100644 --- a/processed_sources.jsonl +++ b/processed_sources.jsonl @@ -1106,7 +1106,7 @@ {"url": "https://oig.hhs.gov/oei/reports/oei-02-16-00570.pdf", "date_accessed": "2024-01-25", "local_path": "./data/raw/hhs_oig/oei-02-16-00570.pdf", "tags": ["hhs-oig", "regulatory-guidance"], "preprocessor": "pdf", "md5": "ea67cf9aa08f8eeb92f43178a0259de8", "local_processed_path": "./data/processed/hhs_oig/oei-02-16-00570.jsonl"} {"url": "https://oig.hhs.gov/oas/reports/region9/91602042.pdf", "date_accessed": "2024-01-25", "local_path": "./data/raw/hhs_oig/91602042.pdf", "tags": ["hhs-oig", "regulatory-guidance"], "preprocessor": "pdf", "md5": "0956b057e0ada92e40dcc7bf02bb5cf2", "local_processed_path": "./data/processed/hhs_oig/91602042.jsonl"} {"url": "https://oig.hhs.gov/oei/reports/oei-03-15-00180.pdf", "date_accessed": "2024-01-25", "local_path": "./data/raw/hhs_oig/oei-03-15-00180.pdf", "tags": ["hhs-oig", "regulatory-guidance"], "preprocessor": "pdf", "md5": "fa56c5c8d598b139873e22a08131b69b", "local_processed_path": "./data/processed/hhs_oig/oei-03-15-00180.jsonl"} -{"url": "https://oig.hhs.gov/reports-and-publications/portfolio/portfolio-12-12-01.pdf", "date_accessed": "2024-01-25", "local_path": "./data/raw/hhs_oig/portfolio-12-12-01.pdf", "tags": ["hhs-oig", "regulatory-guidance"], "preprocessor": "pdf", "md5": "a0b1f3c770899d1ed57ced5a477e9e94", "local_processed_path": "./data/processed/hhs_oig/portfolio-12-12-01.jsonl"} +{"url": "https://oig.hhs.gov/reports-and-publications/portfolio/portfolio-12-12-01.pdf", "date_accessed": "2024-01-25", "local_path": "./data/raw/hhs_oig/portfolio-12-12-01.pdf", "tags": ["hhs-oig"], "preprocessor": "pdf", "md5": "a0b1f3c770899d1ed57ced5a477e9e94", "local_processed_path": "./data/processed/hhs_oig/portfolio-12-12-01.jsonl"} {"url": "https://oig.hhs.gov/documents/testimony/1119/bliss-testimony-05182023.pdf", "date_accessed": "2024-01-25", "local_path": "./data/raw/hhs_oig/bliss-testimony-05182023.pdf", "tags": ["hhs-oig", "testimony", "opinion-policy-summary"], "preprocessor": "pdf", "md5": "417d6519b93105b6a371602110bfe33c", "local_processed_path": "./data/processed/hhs_oig/bliss-testimony-05182023.jsonl"} {"url": "https://oig.hhs.gov/documents/testimony/1118/megan-tinker-testimony-05172023.pdf", "date_accessed": "2024-01-25", "local_path": "./data/raw/hhs_oig/megan-tinker-testimony-05172023.pdf", "tags": ["hhs-oig", "testimony", "opinion-policy-summary"], "preprocessor": "pdf", "md5": "e562fcf65a825e1e5d72ed55a930b993", "local_processed_path": "./data/processed/hhs_oig/megan-tinker-testimony-05172023.jsonl"} {"url": "https://oig.hhs.gov/documents/testimony/1112/christi-grimm-testimony-04182023.pdf", "date_accessed": "2024-01-25", "local_path": "./data/raw/hhs_oig/christi-grimm-testimony-04182023.pdf", "tags": ["hhs-oig", "testimony", "opinion-policy-summary"], "preprocessor": "pdf", "md5": "3d2f8e1e1fea7a9113ade3287acb7884", "local_processed_path": "./data/processed/hhs_oig/christi-grimm-testimony-04182023.jsonl"} diff --git a/src/modeling/background_extraction.py b/src/modeling/background_extraction.py index 542d590..c3ae903 100644 --- a/src/modeling/background_extraction.py +++ b/src/modeling/background_extraction.py @@ -210,6 +210,8 @@ def construct_recs(raw_recs: list[dict], tokenizer: AutoTokenizer, model: torch. "decision": rec["decision"], "appeal_type": rec["appeal_type"], "full_text": rec["text"], + "jurisdiction": rec.get("jurisdiction", "Unspecified"), + "insurance_type": rec.get("insurance_type", "Unspecified"), } backgrounds.append(updated_record) idx += 1 @@ -241,6 +243,8 @@ def construct_recs_batch( "appeal_type": rec["appeal_type"], "full_text": rec["text"], "sufficiency_id": sufficiency_id, + "jurisdiction": rec.get("jurisdiction", "Unspecified"), + "insurance_type": rec.get("insurance_type", "Unspecified"), } for (rec, background, sufficiency_id) in zip(batch, background_batch, sufficiency_batch) ] @@ -336,7 +340,7 @@ def combine_jsonl(directory) -> None: background_model = AutoModelForTokenClassification.from_pretrained(trained_model_path) # Load pretrained sufficiency model - pretrained_model_path = "distilbert/distilbert-base-cased" + pretrained_model_path = "distilbert/distilbert-base-uncased" background_dataset = "case-backgrounds" checkpoints_dir = f"./models/sufficiency_predictor/{background_dataset}/{pretrained_model_path}" trained_model_path = [f.path for f in os.scandir(checkpoints_dir) if f.is_dir()][ @@ -361,7 +365,7 @@ def combine_jsonl(directory) -> None: # Combine CA CDI files into single jsonl combine_jsonl("./data/processed/ca_cdi/summaries") - # Define Outcome Map / preprocessing + # Define Outcome Map / preprocessing and jurisdiction mapping # TODO: decide if this standardization belongs here, or in raw dataset storage (i.e in records we read here) extraction_targets = [ ( @@ -396,7 +400,7 @@ def combine_jsonl(directory) -> None: test_out_path = "./data/outcomes/test_backgrounds_suff.jsonl" train_subset = [] test_subset = [] - for path, outcome_map in extraction_targets: + for path, outcome_map, jurisdiction, insurance_type in extraction_targets: print(f"Processing dataset at {path}") # Get records and standardize outcome labels diff --git a/src/modeling/train_outcome_predictor.py b/src/modeling/train_outcome_predictor.py index 72dc61c..316c9e5 100644 --- a/src/modeling/train_outcome_predictor.py +++ b/src/modeling/train_outcome_predictor.py @@ -5,6 +5,7 @@ import numpy as np import scipy +import torch from datasets import Dataset, load_dataset from sklearn.metrics import ( accuracy_score, @@ -27,16 +28,24 @@ ID2LABEL = {0: "Insufficient", 1: "Upheld", 2: "Overturned"} LABEL2ID = {v: k for k, v in ID2LABEL.items()} + +# Define mappings for jurisdiction and insurance type +JURISDICTION_MAP = {"NY": 0, "CA": 1, "Unspecified": 2} +INSURANCE_TYPE_MAP = {"Commercial": 0, "Medicaid": 1, "Unspecified": 2} + OUTPUT_DIR = "./models/overturn_predictor" os.environ["TOKENIZERS_PARALLELISM"] = "false" def load_and_split( - jsonl_path: str, test_size: float = 0.2, filter_keys: list = ["decision", "text", "sufficiency_id"], seed: int = 2 + jsonl_path: str, + test_size: float = 0.2, + filter_keys: list = ["decision", "text", "sufficiency_id", "jurisdiction", "insurance_type"], + seed: int = 2, ) -> Dataset: if len(filter_keys) > 0: recs = get_records_list(jsonl_path) - recs = [{key: rec[key] for key in filter_keys} for rec in recs] + recs = [{key: rec.get(key, "Unspecified") for key in filter_keys} for rec in recs] dataset = Dataset.from_list(recs) else: dataset = load_dataset("json", data_files=jsonl_path)["train"] @@ -59,10 +68,20 @@ def construct_label(outcome: str, sufficiency_id: int, label2id: dict) -> int: return label2id[outcome] -def add_integral_ids_batch(examples, label2id: dict): +def add_integral_ids_batch(examples, label2id: dict, jurisdiction_map: dict, insurance_type_map: dict): outcomes = examples["decision"] sufficiency_ids = examples["sufficiency_id"] + + # Map jurisdiction and insurance_type to their respective IDs + jurisdictions = examples.get("jurisdiction", ["Unspecified"] * len(outcomes)) + insurance_types = examples.get("insurance_type", ["Unspecified"] * len(outcomes)) + examples["label"] = [construct_label(outcome, id, label2id) for (outcome, id) in zip(outcomes, sufficiency_ids)] + examples["jurisdiction_id"] = [jurisdiction_map.get(j, jurisdiction_map["Unspecified"]) for j in jurisdictions] + examples["insurance_type_id"] = [ + insurance_type_map.get(i, insurance_type_map["Unspecified"]) for i in insurance_types + ] + return examples @@ -155,6 +174,89 @@ def compute_metrics2(eval_pred) -> dict: return best_metrics +# Custom model class to handle the additional features +class TextClassificationWithMetadata(AutoModelForSequenceClassification): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Add embeddings for jurisdiction and insurance type (2 categories each + unspecified) + # We'll use 0 and 1 for the specific categories, and handle unspecified (2) separately + self.jurisdiction_embeddings = torch.nn.Embedding(3, 16) + self.insurance_type_embeddings = torch.nn.Embedding(3, 16) + + # Initialize the unspecified embeddings to be the average of the others + with torch.no_grad(): + # Initialize the unspecified embedding (index 2) as zeros + # It will be dynamically computed during forward pass + self.jurisdiction_embeddings.weight[2].fill_(0) + self.insurance_type_embeddings.weight[2].fill_(0) + + # Adjust the classifier to include these additional features + config = self.config + hidden_size = config.hidden_size + + # Create a new classifier with the additional features + self.classifier = torch.nn.Linear(hidden_size + 32, config.num_labels) + + def forward(self, input_ids=None, attention_mask=None, jurisdiction_id=None, insurance_type_id=None, **kwargs): + # Get the default output from parent + outputs = super().forward(input_ids=input_ids, attention_mask=attention_mask, **kwargs) + + # If we're not using the additional features, return the default outputs + if jurisdiction_id is None or insurance_type_id is None: + return outputs + + # Create masks for where the IDs are "Unspecified" (value 2) + j_unspecified_mask = jurisdiction_id == 2 + i_unspecified_mask = insurance_type_id == 2 + + # Get embeddings for the additional features + j_embeddings = self.jurisdiction_embeddings(jurisdiction_id) + i_embeddings = self.insurance_type_embeddings(insurance_type_id) + + # For unspecified jurisdiction, use the average of NY and CA embeddings + if j_unspecified_mask.any(): + # Calculate average of specific jurisdiction embeddings (indices 0 and 1) + avg_j_embedding = (self.jurisdiction_embeddings.weight[0] + self.jurisdiction_embeddings.weight[1]) / 2 + + # Apply the average embedding where jurisdiction is unspecified + j_embeddings[j_unspecified_mask] = avg_j_embedding + + # For unspecified insurance_type, use the average of Commercial and Medicaid embeddings + if i_unspecified_mask.any(): + # Calculate average of specific insurance_type embeddings (indices 0 and 1) + avg_i_embedding = (self.insurance_type_embeddings.weight[0] + self.insurance_type_embeddings.weight[1]) / 2 + + # Apply the average embedding where insurance_type is unspecified + i_embeddings[i_unspecified_mask] = avg_i_embedding + + # Concatenate with the pooled output + pooled_output = outputs.pooler_output + combined_features = torch.cat([pooled_output, j_embeddings, i_embeddings], dim=1) + + # Pass through the classifier + logits = self.classifier(combined_features) + + # Replace the logits in the outputs + outputs.logits = logits + + return outputs + + +# Custom collator to handle the additional features +class DataCollatorWithMetadata(DataCollatorWithPadding): + def __call__(self, features): + batch = super().__call__(features) + + # Add the jurisdiction and insurance type IDs to the batch + if "jurisdiction_id" in features[0]: + batch["jurisdiction_id"] = torch.tensor([f["jurisdiction_id"] for f in features]) + + if "insurance_type_id" in features[0]: + batch["insurance_type_id"] = torch.tensor([f["insurance_type_id"] for f in features]) + + return batch + + def main(config_path: str) -> None: cfg = load_config(config_path) @@ -193,17 +295,38 @@ def main(config_path: str) -> None: # Prepare dataset for training dataset = dataset.map(partial(tokenize_batch, tokenizer=tokenizer), batched=True) - dataset = dataset.map(partial(add_integral_ids_batch, label2id=LABEL2ID), batched=True) + dataset = dataset.map( + partial( + add_integral_ids_batch, + label2id=LABEL2ID, + jurisdiction_map=JURISDICTION_MAP, + insurance_type_map=INSURANCE_TYPE_MAP, + ), + batched=True, + ) + + # Use our custom data collator that handles the additional features + data_collator = DataCollatorWithMetadata(tokenizer=tokenizer) - data_collator = DataCollatorWithPadding(tokenizer=tokenizer) + # Load the base model + base_model = AutoModelForSequenceClassification.from_pretrained( + pretrained_model_key, + num_labels=3, + id2label=ID2LABEL, + label2id=LABEL2ID, + ) - model = AutoModelForSequenceClassification.from_pretrained( + # Create our custom model + model = TextClassificationWithMetadata.from_pretrained( pretrained_model_key, num_labels=3, id2label=ID2LABEL, label2id=LABEL2ID, ) + # Copy weights from base model to our custom model + model.load_state_dict(base_model.state_dict(), strict=False) + # Handle annoyance with HF / pretrained legalbert bug/user error # HF trainer complains of param data not being contiguous when loading checkpoints if base_model_name == "legal-bert-small-uncased": diff --git a/src/processors/ca_cdi.py b/src/processors/ca_cdi.py index c7d425f..07b2adf 100644 --- a/src/processors/ca_cdi.py +++ b/src/processors/ca_cdi.py @@ -63,6 +63,8 @@ def process(source_lineitem: dict, output_dirname: str) -> dict: "treatment": treatment, "decision": decision, "patient_race": patient_race, + "jurisdiction": "CA", + "insurance_type": "Commercial", } add_jsonl_line(outfile, line_data) diff --git a/src/processors/ca_dmhc.py b/src/processors/ca_dmhc.py index eb84326..7aba525 100644 --- a/src/processors/ca_dmhc.py +++ b/src/processors/ca_dmhc.py @@ -36,6 +36,8 @@ def process(source_lineitem: dict, output_dirname: str) -> dict: "decision": tuple[1], "appeal_type": tuple[2], "appeal_expedited_status": tuple[3], + "jurisdiction": "CA", + "insurance_type": "Commercial", } add_jsonl_line(outfile, line_data) diff --git a/src/processors/ny_dfs.py b/src/processors/ny_dfs.py index fa4c845..bb819e1 100644 --- a/src/processors/ny_dfs.py +++ b/src/processors/ny_dfs.py @@ -37,6 +37,15 @@ def extract_row_summaries(row) -> list[str]: axis=1, ).to_list() + coverage_type_map = { + "HMO": "Commercial", + "PPO": "Commercial", + "EPO": "Commercial", + "Self-Funded": "Commercial", + "Medicaid": "Medicaid", + "Managed Long Term Care": "Medicaid", + } + # Construct a record for each case summary # (some raw case adjudications include multiple summaries from different reviewers) for tuple in df_tuples: @@ -49,6 +58,8 @@ def extract_row_summaries(row) -> list[str]: "appeal_type": tuple[3], "diagnosis": tuple[4], "treatment": tuple[5], + "jurisdiction": "NY", + "insurance_type": coverage_type_map.get(tuple[2], "Unspecified"), } add_jsonl_line(outfile, line_data) From 9efe9560249db175d94dff128d20616934bddf5c Mon Sep 17 00:00:00 2001 From: MikeG112 Date: Sat, 19 Apr 2025 16:31:40 -0400 Subject: [PATCH 02/27] Fix predict script --- src/modeling/predict.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/modeling/predict.py b/src/modeling/predict.py index 9ef0a40..2213f14 100644 --- a/src/modeling/predict.py +++ b/src/modeling/predict.py @@ -29,12 +29,12 @@ def postprocess(self, model_outputs): LABEL2ID = {v: k for k, v in ID2LABEL.items()} # Load model and tokenizer - pretrained_model_key = "distilbert/distilbert-base-uncased" + pretrained_model_key = "distilbert/distilbert-base-cased" tokenizer = AutoTokenizer.from_pretrained(pretrained_model_key, model_max_length=512) - dataset_name = "train_backgrounds_suff" - checkpoints_dir = os.path.join(MODEL_DIR, dataset_name, pretrained_model_key) + dataset_name = "train_backgrounds_suff_augmented" + checkpoints_dir = os.path.join(MODEL_DIR, dataset_name, "distilbert") checkpoint_dirs = sorted(os.listdir(checkpoints_dir)) checkpoint_name = checkpoint_dirs[0] ckpt_path = os.path.join(checkpoints_dir, checkpoint_name) @@ -77,7 +77,7 @@ def postprocess(self, model_outputs): ) # Pytorch quantized model - model = classifier.model + model = classifier.model.to("cpu") # TODO: Fix this, this is not the right model to be quantizing via the torch or onnx ops below. model_int8 = torch.ao.quantization.quantize_dynamic( model, # the original model From ae344b2dab50919ca8588306fb5a8efcf6572240 Mon Sep 17 00:00:00 2001 From: MikeG112 Date: Sat, 19 Apr 2025 17:27:14 -0400 Subject: [PATCH 03/27] Fix extraction --- src/modeling/background_extraction.py | 295 +++++++++++++++++++++++++- 1 file changed, 293 insertions(+), 2 deletions(-) diff --git a/src/modeling/background_extraction.py b/src/modeling/background_extraction.py index c3ae903..5a89e33 100644 --- a/src/modeling/background_extraction.py +++ b/src/modeling/background_extraction.py @@ -3,6 +3,7 @@ from multiprocessing import Pool, cpu_count import torch +from datasets import Dataset from rapidfuzz import fuzz from sklearn.model_selection import train_test_split from transformers import ( @@ -11,6 +12,11 @@ AutoTokenizer, ) +from src.modeling.data_augmentation import ( + augment_sufficient_examples, + generate_unrelated_content, + rewrite_to_generic, +) from src.util import add_jsonl_lines, batcher, get_records_list @@ -323,6 +329,217 @@ def combine_jsonl(directory) -> None: return None +def create_augmented_examples( + records: list[dict], + api_type: str = "openai", + api_url: str = "https://api.openai.com/v1/chat/completions", + api_key: str | None = None, + model_name: str = "gpt-4o", + generic_rewrites_per_example: int = 1, + sufficient_augmentations_per_example: int = 1, + num_unrelated_examples: int = 100, + api_call_limit: int | None = 700, + seed: int = 42, +) -> list[dict]: + """ + Create augmented examples from background texts with optional limit on API calls. + + Args: + records: List of record dictionaries with text, decision, etc. + api_type: Type of API to use ("openai" or "llamacpp") + api_url: URL for the API endpoint + model_name: Name of the model to use + generic_rewrites_per_example: Number of generic rewrites per sufficient example + sufficient_augmentations_per_example: Number of sufficient augmentations per sufficient example + num_unrelated_examples: Number of unrelated content examples to generate + api_call_limit: Maximum number of API calls to make (if None, no limit) + seed: Random seed for reproducibility + + Returns: + List of augmented record dictionaries + """ + import random + + random.seed(seed) + + print(f"Creating augmented examples using {api_type} API at {api_url}...") + augmented_records = [] + + # Track API calls + api_calls_made = 0 + + # 1. Create dataset for augmentation functions + dataset_records = [] + for rec in records: + # Check if record has sufficiency_id + sufficiency_score = 4 # Default to sufficient for original records + if "sufficiency_id" in rec: + sufficiency_score = 4 if rec["sufficiency_id"] == 1 else 2 + + dataset_records.append({"text": rec["text"], "sufficiency_score": sufficiency_score}) + + dataset = Dataset.from_list(dataset_records) + + # If we have a limit, we need to randomly select which examples to augment + sufficient_examples = [ex for ex in dataset_records if ex["sufficiency_score"] >= 3] + num_sufficient = len(sufficient_examples) + + # Calculate how many API calls each technique would require + total_generic_calls = num_sufficient * generic_rewrites_per_example + total_sufficient_calls = num_sufficient * sufficient_augmentations_per_example + total_unrelated_calls = num_unrelated_examples + + # Calculate total potential API calls + total_potential_calls = total_generic_calls + total_sufficient_calls + total_unrelated_calls + + # If we have a limit and it's less than the potential total, adjust + if api_call_limit is not None and api_call_limit < total_potential_calls: + print(f"API call limit ({api_call_limit}) is less than potential total ({total_potential_calls})") + + # Distribute the limit evenly across the three techniques + calls_per_technique = api_call_limit // 3 + + # Calculate adjusted calls for each technique + adjusted_generic_calls = calls_per_technique + adjusted_sufficient_calls = calls_per_technique + adjusted_unrelated_calls = ( + api_call_limit - adjusted_generic_calls - adjusted_sufficient_calls + ) # Use remainder for unrelated + + # Calculate how many examples we can augment for each technique + examples_for_generic = adjusted_generic_calls // generic_rewrites_per_example + examples_for_sufficient = adjusted_sufficient_calls // sufficient_augmentations_per_example + + # Randomly select examples to augment + if examples_for_generic < num_sufficient: + examples_generic = random.sample(sufficient_examples, examples_for_generic) + else: + examples_generic = sufficient_examples + + if examples_for_sufficient < num_sufficient: + examples_sufficient = random.sample(sufficient_examples, examples_for_sufficient) + else: + examples_sufficient = sufficient_examples + + # Adjust unrelated examples count + adjusted_num_unrelated = adjusted_unrelated_calls + + print("Adjusted numbers based on API call limit:") + print(f" Generic rewrites: {examples_for_generic} examples ({adjusted_generic_calls} calls)") + print(f" Sufficient augmentations: {examples_for_sufficient} examples ({adjusted_sufficient_calls} calls)") + print(f" Unrelated content: {adjusted_num_unrelated} examples ({adjusted_unrelated_calls} calls)") + + # Create filtered datasets for limited augmentation + generic_dataset = Dataset.from_list(examples_generic) + sufficient_dataset = Dataset.from_list(examples_sufficient) + + # Update parameters + generic_params_count = adjusted_generic_calls + sufficient_params_count = adjusted_sufficient_calls + unrelated_params_count = adjusted_unrelated_calls + else: + # No limit or limit is high enough, use all examples + generic_dataset = dataset + sufficient_dataset = dataset + generic_params_count = total_generic_calls + sufficient_params_count = total_sufficient_calls + unrelated_params_count = num_unrelated_examples + + # 2. Apply generic rewrite augmentation (make sufficient examples insufficient) + if generic_params_count > 0: + print("Generating generic rewrites...") + generic_rewrite_params = { + "num_augmentations_per_example": generic_rewrites_per_example, + "api_type": api_type, + "api_url": api_url, + "api_key": api_key, + "model_name": model_name, + "seed": seed, + } + + generic_examples = rewrite_to_generic(generic_dataset, **generic_rewrite_params) + api_calls_made += len(generic_examples) + print(f"Generated {len(generic_examples)} examples through generic rewriting") + else: + generic_examples = [] + print("Skipping generic rewrites due to API call limit") + + # 3. Apply sufficient example augmentation (keep sufficient examples sufficient) + if sufficient_params_count > 0: + print("Generating sufficient augmentations...") + sufficient_augmentation_params = { + "num_augmentations_per_example": sufficient_augmentations_per_example, + "api_type": api_type, + "api_url": api_url, + "api_key": api_key, + "model_name": model_name, + "seed": seed, + } + + sufficient_examples = augment_sufficient_examples(sufficient_dataset, **sufficient_augmentation_params) + api_calls_made += len(sufficient_examples) + print(f"Generated {len(sufficient_examples)} augmented sufficient examples") + else: + sufficient_examples = [] + print("Skipping sufficient augmentations due to API call limit") + + # 4. Generate unrelated content + if unrelated_params_count > 0: + print("Generating unrelated content...") + unrelated_params = { + "num_examples": unrelated_params_count, + "api_type": api_type, + "api_url": api_url, + "api_key": api_key, + "model_name": model_name, + "seed": seed, + } + + unrelated_examples = generate_unrelated_content(**unrelated_params) + api_calls_made += len(unrelated_examples) + print(f"Generated {len(unrelated_examples)} examples with unrelated content") + else: + unrelated_examples = [] + print("Skipping unrelated content due to API call limit") + + # 5. Convert augmented texts back to record format + # For generic rewrites and sufficient augmentations, we need to find the original record + all_records_dict = {rec["text"]: rec for rec in records} + + for aug_example in generic_examples + sufficient_examples: + source_text = aug_example["source_text"] + # Find the original record + original_rec = all_records_dict.get(source_text) + if original_rec: + augmented_rec = original_rec.copy() + augmented_rec["text"] = aug_example["text"] + augmented_rec["sufficiency_id"] = 1 if aug_example["sufficiency_score"] >= 3 else 0 + augmented_rec["augmentation_type"] = aug_example["augmentation_type"] + augmented_records.append(augmented_rec) + + # For unrelated content, create new records with random decision label + decisions = ["Upheld", "Overturned"] + appeal_types = ["IMR", "DMHC", "CDI"] # Example appeal types + + for i, aug_example in enumerate(unrelated_examples): + augmented_rec = { + "text": aug_example["text"], + "decision": random.choice(decisions), # Random decision since unrelated to actual cases + "appeal_type": random.choice(appeal_types), + "full_text": aug_example["text"], # No original full text + "sufficiency_id": 0, # Always insufficient + "jurisdiction": "Unspecified", + "insurance_type": "Unspecified", + "augmentation_type": aug_example["augmentation_type"], + "id": f"unrelated_{i}", + } + augmented_records.append(augmented_rec) + + print(f"Total API calls made: {api_calls_made}") + print(f"Total augmented records created: {len(augmented_records)}") + return augmented_records + + if __name__ == "__main__": # Config applied to both models device = "cuda" @@ -340,7 +557,7 @@ def combine_jsonl(directory) -> None: background_model = AutoModelForTokenClassification.from_pretrained(trained_model_path) # Load pretrained sufficiency model - pretrained_model_path = "distilbert/distilbert-base-uncased" + pretrained_model_path = "distilbert/distilbert-base-cased" background_dataset = "case-backgrounds" checkpoints_dir = f"./models/sufficiency_predictor/{background_dataset}/{pretrained_model_path}" trained_model_path = [f.path for f in os.scandir(checkpoints_dir) if f.is_dir()][ @@ -398,9 +615,14 @@ def combine_jsonl(directory) -> None: # We will also split a train and test set for consistent experiments train_out_path = "./data/outcomes/train_backgrounds_suff.jsonl" test_out_path = "./data/outcomes/test_backgrounds_suff.jsonl" + + # New paths for augmented datasets + train_augmented_out_path = "./data/outcomes/train_backgrounds_suff_augmented.jsonl" + test_augmented_out_path = "./data/outcomes/test_backgrounds_suff_augmented.jsonl" + train_subset = [] test_subset = [] - for path, outcome_map, jurisdiction, insurance_type in extraction_targets: + for path, outcome_map in extraction_targets: print(f"Processing dataset at {path}") # Get records and standardize outcome labels @@ -437,5 +659,74 @@ def combine_jsonl(directory) -> None: train_subset.extend(train_recs) test_subset.extend(val_recs) + # Write original train and test sets add_jsonl_lines(train_out_path, train_subset) add_jsonl_lines(test_out_path, test_subset) + + print(f"Original train set: {len(train_subset)} examples") + print(f"Original test set: {len(test_subset)} examples") + + # Check for OpenAI API key in environment variable + api_key = os.environ.get("OPENAI_API_KEY") + api_type = "openai" if api_key else "llamacpp" + api_url = "https://api.openai.com/v1/chat/completions" if api_key else "http://localhost:8080/completion" + model_name = "gpt-4o" if api_key else "llama-3.1" + + print(f"Using {api_type} API for augmentation") + + train_subset = get_records_list(train_out_path) + test_subset = get_records_list(test_out_path) + + # Create augmented examples for train set + print("\n=== Creating Augmented Examples for Train Set ===") + train_augmented = create_augmented_examples( + train_subset, + api_type=api_type, + api_url=api_url, + api_key=api_key, + model_name=model_name, + generic_rewrites_per_example=2, + sufficient_augmentations_per_example=2, + num_unrelated_examples=100, + seed=42, + ) + + # Write augmented train and test sets + # Combine original and augmented examples + train_combined = train_subset + train_augmented + add_jsonl_lines(train_augmented_out_path, train_combined) + + # Create augmented examples for test set + print("\n=== Creating Augmented Examples for Test Set ===") + test_augmented = create_augmented_examples( + test_subset, + api_type=api_type, + api_url=api_url, + api_key=api_key, + model_name=model_name, + generic_rewrites_per_example=1, # Fewer augmentations for test set + sufficient_augmentations_per_example=1, + num_unrelated_examples=10, + seed=43, # Different seed for test set + ) + + # Combine original and augmented examples + test_combined = test_subset + test_augmented + + # Write augmented test set + add_jsonl_lines(test_augmented_out_path, test_combined) + + # print(f"\nFinal augmented train set: {len(train_combined)} examples") + print(f"Final augmented test set: {len(test_combined)} examples") + + # Print augmentation statistics + train_augmented_types = {} + for rec in train_augmented: + aug_type = rec.get("augmentation_type", "unknown") + if aug_type not in train_augmented_types: + train_augmented_types[aug_type] = 0 + train_augmented_types[aug_type] += 1 + + print("\nTrain set augmentation breakdown:") + for aug_type, count in train_augmented_types.items(): + print(f" {aug_type}: {count}") From d9ca1514fcbeef4d305047b49fb3243d1070cedc Mon Sep 17 00:00:00 2001 From: MikeG112 Date: Sun, 20 Apr 2025 15:40:28 -0400 Subject: [PATCH 04/27] bump transformers --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 9cbbf5d..4c83c7a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,7 +14,7 @@ pre-commit==3.7.1 RapidFuzz==3.13.0 requests==2.32.3 torch==2.6.0 -transformers==4.48.3 +transformers==4.51.3 evaluate scikit-learn==1.5.0 selenium==4.17.2 From 8f0a13f78075c2a6494838717e97de402e728cb9 Mon Sep 17 00:00:00 2001 From: MikeG112 Date: Sun, 20 Apr 2025 15:49:01 -0400 Subject: [PATCH 05/27] Update predict --- src/modeling/predict.py | 363 +++++++++++++++++++++++++++------------- 1 file changed, 249 insertions(+), 114 deletions(-) diff --git a/src/modeling/predict.py b/src/modeling/predict.py index 2213f14..816d58e 100644 --- a/src/modeling/predict.py +++ b/src/modeling/predict.py @@ -1,101 +1,253 @@ +#!/usr/bin/env python +import argparse import os -import time import numpy as np import onnxruntime import scipy import torch -from transformers import ( - AutoModelForSequenceClassification, - AutoTokenizer, - TextClassificationPipeline, - pipeline, -) +from colorama import Fore, Style, init +from transformers import AutoTokenizer +from src.modeling.train_outcome_predictor import TextClassificationWithMetadata from src.modeling.util import export_onnx_model, quantize_onnx_model +# Initialize colorama +init() + MODEL_DIR = "./models/overturn_predictor" +# Test examples with expected classifications +TEST_EXAMPLES = [ + { + "text": "Diagnosis: Broken Ribs\nTreatment: Inpatient Hospital Admission\n\nThe insurer denied inpatient hospital admission. \n\nThe patient is an adult male. He presented by ambulance to the hospital with severe back pain. The patient had fallen down a ramp and onto his back two days prior. The patient developed back pain and had pain with deep inspiration, prompting a call to 911 for an ambulance. The patient was taking ibuprofen and Tylenol for pain at home. A computed tomography (CT) scan of the patient's chest showed a right posterior minimally displaced 9th and 10th rib fractures. There was no associated intra-abdominal injury. There was atelectasis of the lung in the region of the rib fractures. Vital signs, including oxygen saturation, were normal in the emergency department triage note. The patient did not require supplemental oxygen during the hospitalization. The patient was admitted to the acute inpatient level of care for pain control, breathing treatments, and venous thromboembolism prophylaxis. The patient was seen and cleared by Physical Therapy. The patient's pain was controlled with oral analgesia and a lidocaine patch. Total time in the hospital was less than 13 hours. The acute inpatient level of care was denied coverage by the health plan as not medically necessary.", + "expected": "Upheld", + }, + { + "text": "Diagnosis: General Debility Due to Lumbar Stenosis\nTreatment: Continued Stay in Skilled Nursing Facility\n\nThe insurer denied continued stay in skilled nursing facility\n\nThe patient is an adult female with a history of general debility due to lumbar stenosis affecting her functional mobility and activities of daily living (ADLs). She has impairments of balance, mobility, and strength, with an increased risk for falls.\n\nThe patient's relevant past medical history includes obesity status post gastric sleeve times two (x2), severe knee and hip osteoarthritis, anxiety, bipolar disorder, hiatus hernia, depression, asthma, hiatus hernia/gastroesophageal reflux disease (GERD), fractured ribs, fractured ankle, sarcoidosis, and pulmonary embolism. Before admission, the patient was living with family and friends in a house, independent with activities of daily living, and with support from others. The patient was admitted to a skilled nursing facility (SNF) three months ago, requiring total dependence for most activities of daily living, and as of two months ago, the patient was non-ambulatory, requiring supervision for bed mobility, contact guard for transfers, and maximum assistance for static standing. She has limitations in completing mobility and locomotive activities due to gross weakness of the bilateral lower extremities, decreased stability and controlled mobility, increased pain, impaired coordination, and decreased aerobic capacity.", + "expected": "Overturned", + }, + { + "text": "Diagnosis: Dilated CBD, distended gallbladder\n \nTreatment: Inpatient admission, diagnostic treatment and surgery\n\nThe insurer denied the inpatient admission. The patient presented with abdominal pain. He was afebrile and the vital signs were stable. There was no abdominal rebound or guarding. The WBC count was 14.5. The bilirubin was normal. A CAT scan revealed a dilated CBD and a distended gallbladder. An MRCP revealed a CBD stone with bile duct dilatation. The patient was treated with antibiotics. He underwent an ERCP with sphincterotomy and balloon sweeps. A laparoscopic cholecystectomy was then done. The patient remained hemodynamically stable and his pain was controlled.", + "expected": "Upheld", + }, + { + "text": "This is a female patient with a medical history of severe bilateral proliferative diabetic retinopathy and diabetic macular edema. The patient underwent an injection of Lucentis in her left eye and treatment with panretinal photocoagulation without complications. It was reported that the patient had severe disease with many dot/blot hemorrhages. Documentation revealed the patient had arteriovenous (AV) crossing changes bilaterally with venous tortuosity. There were scattered dot/blot hemorrhages bilaterally to the macula and periphery and macular edema. Additionally, she was counseled on proper diet control, exercise and hypertension control. Avastin and Mvasi are the same drug - namely bevacizumab: as per lexicomp: 'humanized monoclonal antibody which binds to, and neutralizes, vascular endothelial growth factor (VEGF), preventing its association with endothelial receptors, Flt-1 and KDR. VEGF binding initiates angiogenesis (endothelial proliferation and the formation of new blood vessels). The inhibition of microvascular growth is believed to retard the growth of all tissues (including metastatic tissue).' Lucentis is ranibizumab: as per lexicomp: 'a recombinant humanized monoclonal antibody fragment which binds to and inhibits human vascular endothelial growth factor A (VEGF-A). Ranibizumab inhibits VEGF from binding to its receptors and thereby suppressing neovascularization and slowing vision loss.' The formulary, step therapy options and the requested drug act against VEGF. There is no suggestion that Avastin or Mvasi would cause physical or mental harm to the patient. There are no contraindications in the documentation that would put the patient at risk for adverse reactions. This patient has a diagnosis of maculopathy. Avastin and Mvasi have been shown to be helpful with this condition.", + "expected": "Upheld", + }, + {"text": "hello? is this relevant?", "expected": "Insufficient"}, + { + "text": "A patient is being denied wegovy for morbid obesity. The health plan states it is not medically necessary.", + "expected": "Overturned", + }, +] -class ClassificationPipeline(TextClassificationPipeline): - def postprocess(self, model_outputs): - out_logits = model_outputs["logits"] - return out_logits +def run_pytorch_model(model, tokenizer, text, jurisdiction_id=2, insurance_type_id=2): + """Run inference with the PyTorch model""" + model.eval() + tokenized = tokenizer(text, return_tensors="pt", truncation=True, padding=True) + with torch.no_grad(): + # Convert jurisdiction and insurance type to tensors + j_id = torch.tensor([jurisdiction_id]) + i_id = torch.tensor([insurance_type_id]) -if __name__ == "__main__": - # TODO: get these from model checkpoints + result = model(**tokenized, jurisdiction_id=j_id, insurance_type_id=i_id) + + probs = torch.softmax(result["logits"], dim=-1) + prob, argmax = torch.max(probs, dim=-1) + + return { + "class_id": argmax.item(), + "class_name": ID2LABEL[argmax.item()], + "confidence": prob.item(), + "probs": probs[0].tolist(), + } + + +def run_pytorch_model_no_metadata(model, tokenizer, text): + """Run inference with the PyTorch model without metadata""" + model.eval() + tokenized = tokenizer(text, return_tensors="pt", truncation=True, padding=True) + with torch.no_grad(): + try: + # Try running without metadata + result = model(**tokenized) + probs = torch.softmax(result["logits"], dim=-1) + prob, argmax = torch.max(probs, dim=-1) + + return { + "class_id": argmax.item(), + "class_name": ID2LABEL[argmax.item()], + "confidence": prob.item(), + "probs": probs[0].tolist(), + } + except Exception as e: + return {"error": str(e), "class_name": "ERROR", "confidence": 0.0, "probs": [0.0, 0.0, 0.0]} + + +def run_pytorch_quantized(model_int8, tokenizer, text, jurisdiction_id=2, insurance_type_id=2): + """Run inference with the quantized PyTorch model""" + model_int8.eval() + tokenized = tokenizer(text, return_tensors="pt", truncation=True, padding=True) + with torch.no_grad(): + # Convert jurisdiction and insurance type to tensors + j_id = torch.tensor([jurisdiction_id]) + i_id = torch.tensor([insurance_type_id]) + + result = model_int8(**tokenized, jurisdiction_id=j_id, insurance_type_id=i_id) + + probs = torch.softmax(result["logits"], dim=-1) + prob, argmax = torch.max(probs, dim=-1) + + return { + "class_id": argmax.item(), + "class_name": ID2LABEL[argmax.item()], + "confidence": prob.item(), + "probs": probs[0].tolist(), + } + + +def run_onnx(session, tokenizer, text): + """Run inference with ONNX Runtime""" + try: + inputs = tokenizer(text, return_tensors="np", truncation=True, padding=True) + outputs = session.run(output_names=["logits"], input_feed=dict(inputs)) + result = scipy.special.softmax(outputs[0], axis=-1) + + argmax = np.argmax(result[0]) + prob = result[0][argmax] + + return { + "class_id": int(argmax), + "class_name": ID2LABEL[int(argmax)], + "confidence": float(prob), + "probs": result[0].tolist(), + } + except Exception as e: + return {"error": str(e), "class_name": "ERROR", "confidence": 0.0, "probs": [0.0, 0.0, 0.0]} + + +def print_result(model_name, result, expected=None, show_probs=False): + """Print inference result with color-coding based on matching expected output""" + if "error" in result: + print(f"{model_name}: {Fore.RED}ERROR{Style.RESET_ALL} - {result['error']}") + return + + # Determine color based on expected value + color = Fore.WHITE + if expected: + if result["class_name"] == expected: + color = Fore.GREEN + else: + color = Fore.RED + + # Print result + print(f"{model_name}: {color}{result['class_name']}{Style.RESET_ALL} ({result['confidence']:.4f})") + + # Optionally show probabilities + if show_probs: + probs_str = ", ".join([f"{ID2LABEL[i]}: {p:.4f}" for i, p in enumerate(result["probs"])]) + print(f" Probabilities: {probs_str}") + + +def test_all_examples(model, model_int8, onnx_session, onnx_quant_session, tokenizer): + """Run all test examples through all model variants""" + print(f"\n{Fore.CYAN}{Style.BRIGHT}===== TESTING ALL EXAMPLES ====={Style.RESET_ALL}") + + for i, example in enumerate(TEST_EXAMPLES): + text = example["text"] + expected = example["expected"] + + # Print a shortened version of the text + print(f"\n{Fore.YELLOW}{Style.BRIGHT}Example {i+1}: '{text[:250]}...' (Expected: {expected}){Style.RESET_ALL}") + + # Run through all model variants + pt_result = run_pytorch_model(model, tokenizer, text) + pt_no_metadata_result = run_pytorch_model_no_metadata(model, tokenizer, text) + pt_quant_result = run_pytorch_quantized(model_int8, tokenizer, text) + + onnx_result = run_onnx(onnx_session, tokenizer, text) + onnx_quant_result = run_onnx(onnx_quant_session, tokenizer, text) + + # Print results + print_result("PyTorch (with metadata)", pt_result, expected) + print_result("PyTorch (no metadata)", pt_no_metadata_result, expected) + print_result("PyTorch (quantized)", pt_quant_result, expected) + print_result("ONNX", onnx_result, expected) + print_result("ONNX (quantized)", onnx_quant_result, expected) + + # Print probabilities from PyTorch model + probs_str = ", ".join([f"{ID2LABEL[i]}: {p:.4f}" for i, p in enumerate(pt_result["probs"])]) + print(f"Probabilities: {probs_str}") + + print("-" * 80) + + +def process_single_prompt(model, model_int8, onnx_session, onnx_quant_session, tokenizer, text): + """Process a single user-provided prompt through all model variants""" + print(f"\n{Fore.YELLOW}{Style.BRIGHT}Processing: '{text[:250]}...'{Style.RESET_ALL}") + + # Run through all model variants + pt_result = run_pytorch_model(model, tokenizer, text) + pt_no_metadata_result = run_pytorch_model_no_metadata(model, tokenizer, text) + pt_quant_result = run_pytorch_quantized(model_int8, tokenizer, text) + + onnx_result = run_onnx(onnx_session, tokenizer, text) + onnx_quant_result = run_onnx(onnx_quant_session, tokenizer, text) + + # Print results + print_result("PyTorch (with metadata)", pt_result, show_probs=True) + print_result("PyTorch (no metadata)", pt_no_metadata_result, show_probs=True) + print_result("PyTorch (quantized)", pt_quant_result) + print_result("ONNX", onnx_result) + print_result("ONNX (quantized)", onnx_quant_result) + + +def main(): + parser = argparse.ArgumentParser(description="Medical Classification Model Inference") + parser.add_argument("--test", action="store_true", help="Run all test examples") + parser.add_argument("--prompt", type=str, help="Text to classify") + args = parser.parse_args() + + if not args.test and not args.prompt: + parser.error("Either --test or --prompt must be specified") + + # Model labels + global ID2LABEL, LABEL2ID ID2LABEL = {0: "Insufficient", 1: "Upheld", 2: "Overturned"} LABEL2ID = {v: k for k, v in ID2LABEL.items()} # Load model and tokenizer - pretrained_model_key = "distilbert/distilbert-base-cased" - - tokenizer = AutoTokenizer.from_pretrained(pretrained_model_key, model_max_length=512) - dataset_name = "train_backgrounds_suff_augmented" checkpoints_dir = os.path.join(MODEL_DIR, dataset_name, "distilbert") checkpoint_dirs = sorted(os.listdir(checkpoints_dir)) checkpoint_name = checkpoint_dirs[0] ckpt_path = os.path.join(checkpoints_dir, checkpoint_name) - model = AutoModelForSequenceClassification.from_pretrained( - ckpt_path, num_labels=3, id2label=ID2LABEL, label2id=LABEL2ID - ) - # Upheld example - text = "Diagnosis: Broken Ribs\nTreatment: Inpatient Hospital Admission\n\nThe insurer denied inpatient hospital admission. \n\nThe patient is an adult male. He presented by ambulance to the hospital with severe back pain. The patient had fallen down a ramp and onto his back two days prior. The patient developed back pain and had pain with deep inspiration, prompting a call to 911 for an ambulance. The patient was taking ibuprofen and Tylenol for pain at home. A computed tomography (CT) scan of the patient's chest showed a right posterior minimally displaced 9th and 10th rib fractures. There was no associated intra-abdominal injury. There was atelectasis of the lung in the region of the rib fractures. Vital signs, including oxygen saturation, were normal in the emergency department triage note. The patient did not require supplemental oxygen during the hospitalization. The patient was admitted to the acute inpatient level of care for pain control, breathing treatments, and venous thromboembolism prophylaxis. The patient was seen and cleared by Physical Therapy. The patient's pain was controlled with oral analgesia and a lidocaine patch. Total time in the hospital was less than 13 hours. The acute inpatient level of care was denied coverage by the health plan as not medically necessary." - - # Overturned example - # text = "Diagnosis: General Debility Due to Lumbar Stenosis\nTreatment: Continued Stay in Skilled Nursing Facility\n\nThe insurer denied continued stay in skilled nursing facility\n\nThe patient is an adult female with a history of general debility due to lumbar stenosis affecting her functional mobility and activities of daily living (ADLs). She has impairments of balance, mobility, and strength, with an increased risk for falls.\n\nThe patient's relevant past medical history includes obesity status post gastric sleeve times two (x2), severe knee and hip osteoarthritis, anxiety, bipolar disorder, hiatus hernia, depression, asthma, hiatus hernia/gastroesophageal reflux disease (GERD), fractured ribs, fractured ankle, sarcoidosis, and pulmonary embolism. Before admission, the patient was living with family and friends in a house, independent with activities of daily living, and with support from others. The patient was admitted to a skilled nursing facility (SNF) three months ago, requiring total dependence for most activities of daily living, and as of two months ago, the patient was non-ambulatory, requiring supervision for bed mobility, contact guard for transfers, and maximum assistance for static standing. She has limitations in completing mobility and locomotive activities due to gross weakness of the bilateral lower extremities, decreased stability and controlled mobility, increased pain, impaired coordination, and decreased aerobic capacity. \n" - # text = "Diagnosis: Dilated CBD, distended gallbladder\n \nTreatment: Inpatient admission, diagnostic treatment and surgery\n\nThe insurer denied the inpatient admission. The patient presented with abdominal pain. He was afebrile and the vital signs were stable. There was no abdominal rebound or guarding. The WBC count was 14.5. The bilirubin was normal. A CAT scan revealed a dilated CBD and a distended gallbladder. An MRCP revealed a CBD stone with bile duct dilatation. The patient was treated with antibiotics. He underwent an ERCP with sphincterotomy and balloon sweeps. A laparoscopic cholecystectomy was then done. The patient remained hemodynamically stable and his pain was controlled." - # text = "This is a female patient with a medical history of severe bilateral proliferative diabetic retinopathy and diabetic macular edema. The patient underwent an injection of Lucentis in her left eye and treatment with panretinal photocoagulation without complications. It was reported that the patient had severe disease with many dot/blot hemorrhages. Documentation revealed the patient had arteriovenous (AV) crossing changes bilaterally with venous tortuosity. There were scattered dot/blot hemorrhages bilaterally to the macula and periphery and macular edema. Additionally, she was counseled on proper diet control, exercise and hypertension control. Avastin and Mvasi are the same drug - namely bevacizumab: as per lexicomp: 'humanized monoclonal antibody which binds to, and neutralizes, vascular endothelial growth factor (VEGF), preventing its association with endothelial receptors, Flt-1 and KDR. VEGF binding initiates angiogenesis (endothelial proliferation and the formation of new blood vessels). The inhibition of microvascular growth is believed to retard the growth of all tissues (including metastatic tissue).' Lucentis is ranibizumab: as per lexicomp: 'a recombinant humanized monoclonal antibody fragment which binds to and inhibits human vascular endothelial growth factor A (VEGF-A). Ranibizumab inhibits VEGF from binding to its receptors and thereby suppressing neovascularization and slowing vision loss.' The formulary, step therapy options and the requested drug act against VEGF. There is no suggestion that Avastin or Mvasi would cause physical or mental harm to the patient. There are no contraindications in the documentation that would put the patient at risk for adverse reactions. This patient has a diagnosis of maculopathy. Avastin and Mvasi have been shown to be helpful with this condition." - # tokens = tokenizer(text) - # with torch.no_grad(): - # output = model(torch.tensor(tokens["input_ids"])) - # print(output) - - # Upheld - # text = "A patient is being denied wegovy for morbid obesity. The health plan states it is not medically necessary." - - classifier = ClassificationPipeline(model=model, tokenizer=tokenizer) - classifier2 = pipeline("text-classification", model=model, tokenizer=tokenizer) - start = time.time() - result = classifier(text) - # result = classifier2(text) - end = time.time() - print("Vanilla HF pipeline:") - print(f"Logits: {result[0]}") - probs = torch.softmax(result[0], dim=-1) - print(f"Probs: {probs}") - prob, argmax = torch.max(probs, dim=-1) - print(f"Class pred: {argmax.item()}, Score: {prob.item()}") - print(f"Latency: {end-start}") + print(f"{Fore.CYAN}Loading models...{Style.RESET_ALL}") + + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained(ckpt_path, model_max_length=512) + + # Load PyTorch model + model = TextClassificationWithMetadata.from_pretrained( + ckpt_path, + num_labels=3, + id2label=ID2LABEL, + label2id=LABEL2ID, + ) print( - "Model size (MB):", + "Vanilla Pytorch Model size (MB):", round(os.path.getsize(os.path.join(ckpt_path, "model.safetensors")) / (1024 * 1024)), - "\n", ) - # Pytorch quantized model - model = classifier.model.to("cpu") - # TODO: Fix this, this is not the right model to be quantizing via the torch or onnx ops below. + # Load quantized PyTorch model model_int8 = torch.ao.quantization.quantize_dynamic( - model, # the original model - {torch.nn.Linear}, # a set of layers to dynamically quantize + model.to("cpu"), + {torch.nn.Linear}, dtype=torch.qint8, - ) # the target dtype for quantized weights - model_int8.eval() - start = time.time() - tokenized = tokenizer(text, return_tensors="pt") - with torch.no_grad(): - result = model_int8(**tokenized) - end = time.time() - print("Quantized pytorch model:") - probs = torch.softmax(result.logits, dim=-1) - print(f"Probs: {probs}") - prob, argmax = torch.max(probs, dim=-1) - print(f"Class pred: {argmax.item()}, Score: {prob.item()}") - print(f"Latency: {end-start}") + ) param_size = 0 for param in model_int8.parameters(): param_size += param.nelement() * param.element_size() @@ -104,59 +256,42 @@ def postprocess(self, model_outputs): buffer_size += buffer.nelement() * buffer.element_size() size_all_mb = (param_size + buffer_size) / (1024 * 1024) - print(f"Model size (MB): {round(size_all_mb)}\n") + print(f"Quantized Pytorch Model size (MB): {round(size_all_mb)}") - # Export onxx model and quantized version, if nonexistent - onnx_file_name = "model.onnx" - onnx_model_path = os.path.join(ckpt_path, onnx_file_name) + # Load ONNX models + onnx_model_path = os.path.join(ckpt_path, "model.onnx") quant_onnx_model_path = os.path.join(ckpt_path, "quant-model.onnx") - export_onnx_model(onnx_model_path, model, tokenizer) + + # Ensure ONNX models exist + if not os.path.exists(onnx_model_path): + print("Exporting model to ONNX format...") + export_onnx_model(onnx_model_path, model, tokenizer) + print( + "Onnx Model size (MB):", + round(os.path.getsize(onnx_model_path) / (1024 * 1024)), + ) if not os.path.exists(quant_onnx_model_path): + print("Quantizing ONNX model...") quantize_onnx_model(onnx_model_path, quant_onnx_model_path) - - print( - "ONNX full precision model size (MB):", - os.path.getsize(onnx_model_path) / (1024 * 1024), - ) - - # Try running with onnx runtime (quantized) - sess_options = onnxruntime.SessionOptions() - sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL - session = onnxruntime.InferenceSession(quant_onnx_model_path, sess_options) - start = time.time() - inputs = tokenizer(text, return_tensors="np") - outputs = session.run(output_names=["logits"], input_feed=dict(inputs)) - result = scipy.special.softmax(outputs[0], axis=-1) - end = time.time() - print("Quantized Onnx:") - print(f"Probs: {result[0]}") - argmax = np.argmax(result[0], axis=-1) - prob = result[0][argmax] - print(f"Class pred: {argmax}, Score: {prob}") - print(f"Latency: {end-start}") print( - "Model size (MB):", + "Quantized Onnx Model size (MB):", round(os.path.getsize(quant_onnx_model_path) / (1024 * 1024)), - "\n", ) - - # Try running with onnx runtime + # Create ONNX sessions sess_options = onnxruntime.SessionOptions() sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL - session = onnxruntime.InferenceSession(onnx_model_path, sess_options) - start = time.time() - inputs = tokenizer(text, return_tensors="np") - outputs = session.run(output_names=["logits"], input_feed=dict(inputs)) - result = scipy.special.softmax(outputs[0], axis=-1) - end = time.time() - print("Onnx:") - print(f"Probs: {result}") - argmax = np.argmax(result[0], axis=-1) - prob = result[0][argmax] - print(f"Class pred: {argmax}, Score: {prob}") - print(f"Latency: {end-start}") - print( - "Model size (MB):", - round(os.path.getsize(onnx_model_path) / (1024 * 1024)), - ) + onnx_session = onnxruntime.InferenceSession(onnx_model_path, sess_options) + onnx_quant_session = onnxruntime.InferenceSession(quant_onnx_model_path, sess_options) + + print(f"{Fore.GREEN}Models loaded successfully{Style.RESET_ALL}") + + # Run the appropriate mode + if args.test: + test_all_examples(model, model_int8, onnx_session, onnx_quant_session, tokenizer) + else: + process_single_prompt(model, model_int8, onnx_session, onnx_quant_session, tokenizer, args.prompt) + + +if __name__ == "__main__": + main() From 2515e1395fe141d04244c5ba9b83cb89fd0796a5 Mon Sep 17 00:00:00 2001 From: MikeG112 Date: Sun, 20 Apr 2025 16:07:23 -0400 Subject: [PATCH 06/27] Add examples --- src/modeling/predict.py | 43 +++++++++++++++++++++++++++++++---------- 1 file changed, 33 insertions(+), 10 deletions(-) diff --git a/src/modeling/predict.py b/src/modeling/predict.py index 816d58e..0dd4f0e 100644 --- a/src/modeling/predict.py +++ b/src/modeling/predict.py @@ -16,6 +16,8 @@ init() MODEL_DIR = "./models/overturn_predictor" +ID2LABEL = {0: "Insufficient", 1: "Upheld", 2: "Overturned"} +LABEL2ID = {v: k for k, v in ID2LABEL.items()} # Test examples with expected classifications TEST_EXAMPLES = [ @@ -40,6 +42,30 @@ "text": "A patient is being denied wegovy for morbid obesity. The health plan states it is not medically necessary.", "expected": "Overturned", }, + { + "text": "This patient has extensive and inoperable carcinoma of the stomach. He was started on chemotherapy with Xeloda and Oxaliplatin, because he has less nausea with Oxaliplatin than with the alternative, Cisplatin. Oxaliplatin was denied as experimental for treatment of his gastric cancer.", + "expected": "Overturned", + }, + { + "text": "This is a patient who was denied breast tomosynthesis to screen for breast cancer.", + "expected": "Overturned", + }, + { + "text": "This is a patient with Crohn's Disease who is being treated with Humira. Their health plan has denied Anser ADA blood level testing for Humira, claiming it is investigational.", + "expected": "Upheld", + }, + { + "text": "The patient is a 44-year-old female who initially presented with an abnormal screening mammogram. The patient was seen by a radiation oncologist who recommended treatment of the right chest wall and comprehensive nodal regions using proton beam radiation therapy.", + "expected": "Upheld", + }, + { + "text": "The patient is a 10-year-old female with a history of Pitt-Hopkins syndrome and associated motor planning difficulties, possible weakness in the oral area, and receptive and expressive language delays. The provider has recommended that the patient continue to receive individual speech and language therapy sessions twice a week for 60-minute sessions. The Health Insurer has denied the requested services as not medically necessary for treatment of the patient’s medical condition.", + "expected": "Overturned", + }, + { + "text": "The patient is a nine-year-old female with a history of autism spectrum disorder and a speech delay. The patient’s parent has requested reimbursement for the ABA services provided over the course of a year. The Health Insurer has denied the services at issue as not medically necessary for the treatment of the patient.", + "expected": "Overturned", + }, ] @@ -160,11 +186,13 @@ def test_all_examples(model, model_int8, onnx_session, onnx_quant_session, token expected = example["expected"] # Print a shortened version of the text - print(f"\n{Fore.YELLOW}{Style.BRIGHT}Example {i+1}: '{text[:250]}...' (Expected: {expected}){Style.RESET_ALL}") + print( + f"\n{Fore.YELLOW}{Style.BRIGHT}Example {i + 1}: '{text[:250]}...' (Expected: {expected}){Style.RESET_ALL}" + ) # Run through all model variants pt_result = run_pytorch_model(model, tokenizer, text) - pt_no_metadata_result = run_pytorch_model_no_metadata(model, tokenizer, text) + # pt_no_metadata_result = run_pytorch_model_no_metadata(model, tokenizer, text) pt_quant_result = run_pytorch_quantized(model_int8, tokenizer, text) onnx_result = run_onnx(onnx_session, tokenizer, text) @@ -172,7 +200,7 @@ def test_all_examples(model, model_int8, onnx_session, onnx_quant_session, token # Print results print_result("PyTorch (with metadata)", pt_result, expected) - print_result("PyTorch (no metadata)", pt_no_metadata_result, expected) + # print_result("PyTorch (no metadata)", pt_no_metadata_result, expected) print_result("PyTorch (quantized)", pt_quant_result, expected) print_result("ONNX", onnx_result, expected) print_result("ONNX (quantized)", onnx_quant_result, expected) @@ -205,7 +233,7 @@ def process_single_prompt(model, model_int8, onnx_session, onnx_quant_session, t def main(): - parser = argparse.ArgumentParser(description="Medical Classification Model Inference") + parser = argparse.ArgumentParser(description="Appeal Classification Model Inference") parser.add_argument("--test", action="store_true", help="Run all test examples") parser.add_argument("--prompt", type=str, help="Text to classify") args = parser.parse_args() @@ -213,11 +241,6 @@ def main(): if not args.test and not args.prompt: parser.error("Either --test or --prompt must be specified") - # Model labels - global ID2LABEL, LABEL2ID - ID2LABEL = {0: "Insufficient", 1: "Upheld", 2: "Overturned"} - LABEL2ID = {v: k for k, v in ID2LABEL.items()} - # Load model and tokenizer dataset_name = "train_backgrounds_suff_augmented" checkpoints_dir = os.path.join(MODEL_DIR, dataset_name, "distilbert") @@ -228,7 +251,7 @@ def main(): print(f"{Fore.CYAN}Loading models...{Style.RESET_ALL}") # Load tokenizer - tokenizer = AutoTokenizer.from_pretrained(ckpt_path, model_max_length=512) + tokenizer = AutoTokenizer.from_pretrained(ckpt_path) # Load PyTorch model model = TextClassificationWithMetadata.from_pretrained( From cfc8513d82680b5baa975590efd3e249602bbd5e Mon Sep 17 00:00:00 2001 From: MikeG112 Date: Sun, 20 Apr 2025 17:04:20 -0400 Subject: [PATCH 07/27] Fix onnx export --- src/modeling/predict.py | 108 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 99 insertions(+), 9 deletions(-) diff --git a/src/modeling/predict.py b/src/modeling/predict.py index 0dd4f0e..f7a168a 100644 --- a/src/modeling/predict.py +++ b/src/modeling/predict.py @@ -1,5 +1,6 @@ #!/usr/bin/env python import argparse +import copy import os import numpy as np @@ -134,13 +135,29 @@ def run_pytorch_quantized(model_int8, tokenizer, text, jurisdiction_id=2, insura } -def run_onnx(session, tokenizer, text): +def run_onnx(session, tokenizer, text, jurisdiction_id=2, insurance_type_id=2): """Run inference with ONNX Runtime""" try: + # Tokenize input text inputs = tokenizer(text, return_tensors="np", truncation=True, padding=True) - outputs = session.run(output_names=["logits"], input_feed=dict(inputs)) - result = scipy.special.softmax(outputs[0], axis=-1) + # Construct the full input feed + input_feed = dict(inputs) + + # Check if the model expects metadata inputs + input_names = [input.name for input in session.get_inputs()] + + # Add jurisdiction and insurance type if supported by the model + if "jurisdiction_id" in input_names: + input_feed["jurisdiction_id"] = np.array([jurisdiction_id], dtype=np.int64) + if "insurance_type_id" in input_names: + input_feed["insurance_type_id"] = np.array([insurance_type_id], dtype=np.int64) + + # Run inference + outputs = session.run(output_names=["logits"], input_feed=input_feed) + + # Process outputs + result = scipy.special.softmax(outputs[0], axis=-1) argmax = np.argmax(result[0]) prob = result[0][argmax] @@ -232,6 +249,65 @@ def process_single_prompt(model, model_int8, onnx_session, onnx_quant_session, t print_result("ONNX (quantized)", onnx_quant_result) +def calibrate_model(model, tokenizer, calibration_data): + """Run calibration data through the model for quantization preparation""" + print(f"{Fore.CYAN}Calibrating model with {len(calibration_data)} examples...{Style.RESET_ALL}") + model.eval() + + # Define mappings for jurisdiction and insurance type if not already in the calibration data + JURISDICTION_MAP = {"NY": 0, "CA": 1, "Unspecified": 2} + INSURANCE_TYPE_MAP = {"Commercial": 0, "Medicaid": 1, "Unspecified": 2} + + with torch.no_grad(): + for example in calibration_data: + text = example["text"] + + # Get jurisdiction and insurance type IDs, default to "Unspecified" (2) + if "jurisdiction" in example: + j_id = JURISDICTION_MAP.get(example["jurisdiction"], 2) + else: + j_id = example.get("jurisdiction_id", 2) + + if "insurance_type" in example: + i_id = INSURANCE_TYPE_MAP.get(example["insurance_type"], 2) + else: + i_id = example.get("insurance_type_id", 2) + + # Tokenize the text + tokenized = tokenizer(text, return_tensors="pt", truncation=True, padding=True) + + # Convert jurisdiction and insurance type to tensors + j_tensor = torch.tensor([j_id]) + i_tensor = torch.tensor([i_id]) + + # Run forward pass for calibration + _ = model(**tokenized, jurisdiction_id=j_tensor, insurance_type_id=i_tensor) + + return model + + +def quantize_model_with_proper_embedding_config(model, tokenizer, calibration_data=None): + """Quantize model focusing on embeddings and linear layers while avoiding layer_norm issues""" + print(f"{Fore.CYAN}Setting up quantization for embeddings and linear layers...{Style.RESET_ALL}") + + # Create a copy of the model to preserve the original + model_copy = copy.deepcopy(model) + model_copy.eval() + + # Dynamic quantization approach that properly handles both layer types + model_int8 = torch.ao.quantization.quantize_dynamic( + model_copy, + { + torch.nn.Linear, + # torch.nn.Embedding + }, # quantize both linear and embedding layers + dtype=torch.qint8, + ) + + print(f"{Fore.GREEN}Quantization complete (Linear and Embedding layers){Style.RESET_ALL}") + return model_int8 + + def main(): parser = argparse.ArgumentParser(description="Appeal Classification Model Inference") parser.add_argument("--test", action="store_true", help="Run all test examples") @@ -265,12 +341,26 @@ def main(): round(os.path.getsize(os.path.join(ckpt_path, "model.safetensors")) / (1024 * 1024)), ) - # Load quantized PyTorch model - model_int8 = torch.ao.quantization.quantize_dynamic( - model.to("cpu"), - {torch.nn.Linear}, - dtype=torch.qint8, - ) + # Prepare calibration data from test examples + calibration_data = [] + for example in TEST_EXAMPLES: + # Add jurisdiction and insurance_type variations for each example + for j_id in [0, 1, 2]: # NY, CA, Unspecified + for i_id in [0, 1, 2]: # Commercial, Medicaid, Unspecified + calibration_data.append({"text": example["text"], "jurisdiction_id": j_id, "insurance_type_id": i_id}) + + # Prepare calibration data from test examples + calibration_data = [] + for example in TEST_EXAMPLES: + # Add jurisdiction and insurance_type variations for each example + for j_id in [0, 1, 2]: # NY, CA, Unspecified + for i_id in [0, 1, 2]: # Commercial, Medicaid, Unspecified + calibration_data.append({"text": example["text"], "jurisdiction_id": j_id, "insurance_type_id": i_id}) + + # Quantize model with proper embedding configuration + model_int8 = quantize_model_with_proper_embedding_config(model.to("cpu"), tokenizer, calibration_data) + + print(f"{Fore.GREEN}Model calibration and quantization complete{Style.RESET_ALL}") param_size = 0 for param in model_int8.parameters(): param_size += param.nelement() * param.element_size() From b7da77a72138ee4c68ec799c6b0c7c5002530628 Mon Sep 17 00:00:00 2001 From: MikeG112 Date: Sun, 20 Apr 2025 17:59:32 -0400 Subject: [PATCH 08/27] Update training loop --- .../config/outcome_prediction/distilbert.yaml | 12 +- src/modeling/eval_hf_outcome_predictor.py | 83 ++++++++-- src/modeling/train_outcome_predictor.py | 143 ++++++++++-------- src/modeling/util.py | 100 +++++++++--- 4 files changed, 241 insertions(+), 97 deletions(-) diff --git a/src/modeling/config/outcome_prediction/distilbert.yaml b/src/modeling/config/outcome_prediction/distilbert.yaml index 6e2ebe3..776e4b0 100644 --- a/src/modeling/config/outcome_prediction/distilbert.yaml +++ b/src/modeling/config/outcome_prediction/distilbert.yaml @@ -8,17 +8,17 @@ hicric_pretrained: False base_model_name: "distilbert" pretrained_model_dir: "None" pretrained_hf_model_key: "distilbert/distilbert-base-uncased" -train_data_path: "./data/outcomes/train_backgrounds_suff.jsonl" +train_data_path: "./data/outcomes/train_backgrounds_suff_augmented.jsonl" # Training settings learning_rate: 8.0e-7 weight_decay: 0.01 -num_epochs: 40 -batch_size: 20 +num_epochs: 20 +batch_size: 48 dtype: "float16" # 'float32','float16' for training dtype compile: True # Whether to use torch compile # Test eval settings -test_data_path: "./data/outcomes/test_backgrounds_suff.jsonl" -checkpoint_name: "checkpoint-41008" -eval_threshold: .55 \ No newline at end of file +test_data_path: "./data/outcomes/test_backgrounds_suff_augmented.jsonl" +checkpoint_name: "checkpoint-19152" +eval_threshold: .35 \ No newline at end of file diff --git a/src/modeling/eval_hf_outcome_predictor.py b/src/modeling/eval_hf_outcome_predictor.py index 6044034..39640ca 100644 --- a/src/modeling/eval_hf_outcome_predictor.py +++ b/src/modeling/eval_hf_outcome_predictor.py @@ -18,6 +18,7 @@ TextClassificationPipeline, ) +from src.modeling.train_outcome_predictor import TextClassificationWithMetadata from src.modeling.util import load_config from src.util import get_records_list @@ -25,6 +26,10 @@ ID2LABEL = {0: "Insufficient", 1: "Upheld", 2: "Overturned"} LABEL2ID = {v: k for k, v in ID2LABEL.items()} +# Define mappings for jurisdiction and insurance type +JURISDICTION_MAP = {"NY": 0, "CA": 1, "Unspecified": 2} +INSURANCE_TYPE_MAP = {"Commercial": 0, "Medicaid": 1, "Unspecified": 2} + # TODO: centralize label construction, label consts def construct_label(outcome, sufficiency_id, label2id): @@ -131,6 +136,20 @@ def postprocess(self, model_outputs): return out_logits +# Check if the model has a custom forward method that supports metadata +def is_metadata_model(model): + """Check if model has metadata capabilities by inspecting its methods""" + # Look for the key attributes in our custom model + has_j_embeddings = hasattr(model, "jurisdiction_embeddings") + has_i_embeddings = hasattr(model, "insurance_type_embeddings") + + # Check if model's forward method signature includes the metadata params + forward_sig = model.forward.__code__.co_varnames + accepts_metadata = "jurisdiction_id" in forward_sig and "insurance_type_id" in forward_sig + + return has_j_embeddings and has_i_embeddings and accepts_metadata + + def main(config_path: str): cfg = load_config(config_path) @@ -147,26 +166,72 @@ def main(config_path: str): # Load raw dataset test_dataset = get_records_list(dataset_path) - # tokenizer = AutoTokenizer.from_pretrained(pretrained_model_key, model_max_length=512) - checkpoints_dir = os.path.join(MODEL_DIR, "train_backgrounds_suff", base_model_name) + # Set up paths + checkpoints_dir = os.path.join(MODEL_DIR, "train_backgrounds_suff_augmented", base_model_name) print("Evaluating from checkpoint: ", checkpoint_name) ckpt_path = os.path.join(checkpoints_dir, checkpoint_name) + # Load tokenizer tokenizer = AutoTokenizer.from_pretrained(ckpt_path) - model = AutoModelForSequenceClassification.from_pretrained( - ckpt_path, num_labels=3, id2label=ID2LABEL, label2id=LABEL2ID + + # Load model + # model = AutoModelForSequenceClassification.from_pretrained( + # ckpt_path, num_labels=3, id2label=ID2LABEL, label2id=LABEL2ID + # ) + model = TextClassificationWithMetadata.from_pretrained( + ckpt_path, + num_labels=3, + id2label=ID2LABEL, + label2id=LABEL2ID, ) - # Isolate records + # Check if model supports metadata + supports_metadata = is_metadata_model(model) + + # Prepare data text_records = [rec["text"] for rec in test_dataset] labels = [construct_label(rec["decision"], rec["sufficiency_id"], LABEL2ID) for rec in test_dataset] - device = "cuda" - # pipeline = ClassificationPipeline(model=model, tokenizer=tokenizer, device=device) - pipeline = ClassificationPipeline(model=model, tokenizer=tokenizer, device=device, truncation=True) + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Using device: {device}") + + # For models with metadata support, use a custom approach + if supports_metadata: + # Extract metadata + jurisdictions = [rec.get("jurisdiction", "Unspecified") for rec in test_dataset] + insurance_types = [rec.get("insurance_type", "Unspecified") for rec in test_dataset] + + # Convert metadata to tensors + j_ids = torch.tensor([JURISDICTION_MAP.get(j, JURISDICTION_MAP["Unspecified"]) for j in jurisdictions]) + i_ids = torch.tensor([INSURANCE_TYPE_MAP.get(i, INSURANCE_TYPE_MAP["Unspecified"]) for i in insurance_types]) - predictions = torch.cat(pipeline(text_records, batch_size=100)) + model.to(device) + model.eval() + # Process in batches + batch_size = 100 + all_logits = [] + + for i in range(0, len(text_records), batch_size): + end_idx = min(i + batch_size, len(text_records)) + batch_texts = text_records[i:end_idx] + batch_j_ids = j_ids[i:end_idx].to(device) + batch_i_ids = i_ids[i:end_idx].to(device) + + inputs = tokenizer(batch_texts, truncation=True, padding=True, return_tensors="pt").to(device) + + with torch.no_grad(): + outputs = model(**inputs, jurisdiction_id=batch_j_ids, insurance_type_id=batch_i_ids) + + all_logits.append(outputs["logits"].cpu()) + + predictions = torch.cat(all_logits) + else: + # For standard models, use the original pipeline + pipeline = ClassificationPipeline(model=model, tokenizer=tokenizer, device=device, truncation=True) + predictions = torch.cat(pipeline(text_records, batch_size=100)) + + # Compute metrics if threshold is not None: threshold_metrics = compute_metrics_w_threshold(predictions, labels, threshold) diff --git a/src/modeling/train_outcome_predictor.py b/src/modeling/train_outcome_predictor.py index 316c9e5..f975e8e 100644 --- a/src/modeling/train_outcome_predictor.py +++ b/src/modeling/train_outcome_predictor.py @@ -6,7 +6,7 @@ import numpy as np import scipy import torch -from datasets import Dataset, load_dataset +from datasets import Dataset, DatasetDict, load_dataset from sklearn.metrics import ( accuracy_score, f1_score, @@ -14,10 +14,12 @@ recall_score, roc_auc_score, ) +from sklearn.model_selection import train_test_split from transformers import ( AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding, + DistilBertForSequenceClassification, Trainer, TrainingArguments, ) @@ -47,10 +49,21 @@ def load_and_split( recs = get_records_list(jsonl_path) recs = [{key: rec.get(key, "Unspecified") for key in filter_keys} for rec in recs] dataset = Dataset.from_list(recs) + else: dataset = load_dataset("json", data_files=jsonl_path)["train"] - dataset = dataset.train_test_split(test_size=test_size, seed=1) - return dataset + + # Use scikit-learn train_test_split for stratified sampling + train_indices, test_indices = train_test_split( + list(range(len(dataset))), test_size=test_size, random_state=seed, shuffle=True, stratify=dataset["decision"] + ) + + # Create train and test datasets using the indices + train_dataset = dataset.select(train_indices) + test_dataset = dataset.select(test_indices) + + # Return as a DatasetDict + return DatasetDict({"train": train_dataset, "test": test_dataset}) def tokenize_batch(examples, tokenizer): @@ -119,7 +132,8 @@ def compute_metrics(eval_pred) -> dict: def compute_metrics2(eval_pred) -> dict: - predictions, labels = eval_pred + predictions = eval_pred.predictions[0] + labels = eval_pred.predictions[1] softmax_preds = scipy.special.softmax(predictions, axis=-1) @@ -174,75 +188,86 @@ def compute_metrics2(eval_pred) -> dict: return best_metrics -# Custom model class to handle the additional features -class TextClassificationWithMetadata(AutoModelForSequenceClassification): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # Add embeddings for jurisdiction and insurance type (2 categories each + unspecified) - # We'll use 0 and 1 for the specific categories, and handle unspecified (2) separately +class TextClassificationWithMetadata(DistilBertForSequenceClassification): + def __init__(self, config): + super().__init__(config) + # Add embeddings for jurisdiction and insurance type self.jurisdiction_embeddings = torch.nn.Embedding(3, 16) self.insurance_type_embeddings = torch.nn.Embedding(3, 16) - # Initialize the unspecified embeddings to be the average of the others + # Initialize the unspecified embeddings to be zeros with torch.no_grad(): - # Initialize the unspecified embedding (index 2) as zeros - # It will be dynamically computed during forward pass self.jurisdiction_embeddings.weight[2].fill_(0) self.insurance_type_embeddings.weight[2].fill_(0) - # Adjust the classifier to include these additional features - config = self.config - hidden_size = config.hidden_size + # Get the hidden size from the model's config + hidden_size = self.config.hidden_size # Create a new classifier with the additional features - self.classifier = torch.nn.Linear(hidden_size + 32, config.num_labels) - - def forward(self, input_ids=None, attention_mask=None, jurisdiction_id=None, insurance_type_id=None, **kwargs): - # Get the default output from parent - outputs = super().forward(input_ids=input_ids, attention_mask=attention_mask, **kwargs) - - # If we're not using the additional features, return the default outputs + self.final_classifier = torch.nn.Linear(hidden_size + 32, config.num_labels) + + def forward( + self, + input_ids=None, + attention_mask=None, + jurisdiction_id=None, + insurance_type_id=None, + return_dict=True, + return_loss=True, + **kwargs, + ): + # Extract labels before passing to super().forward() + labels = kwargs.pop("labels", None) + + # Get the embeddings and pooler output from the base model + base_outputs = super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + labels=None, # Important: don't pass labels yet + # **kwargs # Pass remaining kwargs + ) + + # If we're not using the additional features, use base model logits if jurisdiction_id is None or insurance_type_id is None: - return outputs - - # Create masks for where the IDs are "Unspecified" (value 2) - j_unspecified_mask = jurisdiction_id == 2 - i_unspecified_mask = insurance_type_id == 2 + logits = base_outputs.logits + else: + # Process metadata features + j_unspecified_mask = jurisdiction_id == 2 + i_unspecified_mask = insurance_type_id == 2 - # Get embeddings for the additional features - j_embeddings = self.jurisdiction_embeddings(jurisdiction_id) - i_embeddings = self.insurance_type_embeddings(insurance_type_id) + j_embeddings = self.jurisdiction_embeddings(jurisdiction_id) + i_embeddings = self.insurance_type_embeddings(insurance_type_id) - # For unspecified jurisdiction, use the average of NY and CA embeddings - if j_unspecified_mask.any(): - # Calculate average of specific jurisdiction embeddings (indices 0 and 1) - avg_j_embedding = (self.jurisdiction_embeddings.weight[0] + self.jurisdiction_embeddings.weight[1]) / 2 + if j_unspecified_mask.any(): + avg_j_embedding = (self.jurisdiction_embeddings.weight[0] + self.jurisdiction_embeddings.weight[1]) / 2 + j_embeddings[j_unspecified_mask] = avg_j_embedding - # Apply the average embedding where jurisdiction is unspecified - j_embeddings[j_unspecified_mask] = avg_j_embedding + if i_unspecified_mask.any(): + avg_i_embedding = ( + self.insurance_type_embeddings.weight[0] + self.insurance_type_embeddings.weight[1] + ) / 2 + i_embeddings[i_unspecified_mask] = avg_i_embedding - # For unspecified insurance_type, use the average of Commercial and Medicaid embeddings - if i_unspecified_mask.any(): - # Calculate average of specific insurance_type embeddings (indices 0 and 1) - avg_i_embedding = (self.insurance_type_embeddings.weight[0] + self.insurance_type_embeddings.weight[1]) / 2 + # For models without pooler_output, use the last hidden state's [CLS] token + last_hidden_state = base_outputs.hidden_states[-1] + pooled_output = last_hidden_state[:, 0] - # Apply the average embedding where insurance_type is unspecified - i_embeddings[i_unspecified_mask] = avg_i_embedding + combined_features = torch.cat([pooled_output, j_embeddings, i_embeddings], dim=1) + logits = self.final_classifier(combined_features) - # Concatenate with the pooled output - pooled_output = outputs.pooler_output - combined_features = torch.cat([pooled_output, j_embeddings, i_embeddings], dim=1) + # Update the logits in the output + results = {"logits": logits} - # Pass through the classifier - logits = self.classifier(combined_features) + if labels is not None: + loss_fct = torch.nn.CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + results["loss"] = loss + results["labels"] = labels - # Replace the logits in the outputs - outputs.logits = logits + return results - return outputs - -# Custom collator to handle the additional features class DataCollatorWithMetadata(DataCollatorWithPadding): def __call__(self, features): batch = super().__call__(features) @@ -308,7 +333,7 @@ def main(config_path: str) -> None: # Use our custom data collator that handles the additional features data_collator = DataCollatorWithMetadata(tokenizer=tokenizer) - # Load the base model + # Load the base model to determine its class base_model = AutoModelForSequenceClassification.from_pretrained( pretrained_model_key, num_labels=3, @@ -316,7 +341,7 @@ def main(config_path: str) -> None: label2id=LABEL2ID, ) - # Create our custom model + # Now instantiate your custom model correctly model = TextClassificationWithMetadata.from_pretrained( pretrained_model_key, num_labels=3, @@ -324,9 +349,6 @@ def main(config_path: str) -> None: label2id=LABEL2ID, ) - # Copy weights from base model to our custom model - model.load_state_dict(base_model.state_dict(), strict=False) - # Handle annoyance with HF / pretrained legalbert bug/user error # HF trainer complains of param data not being contiguous when loading checkpoints if base_model_name == "legal-bert-small-uncased": @@ -336,13 +358,14 @@ def main(config_path: str) -> None: checkpoints_dir = os.path.join(OUTPUT_DIR, dataset_name, outdir_name) training_args = TrainingArguments( output_dir=checkpoints_dir, + run_name=cfg["wandb_run_tag"], learning_rate=cfg["learning_rate"], per_device_train_batch_size=cfg["batch_size"], per_device_eval_batch_size=cfg["batch_size"], num_train_epochs=cfg["num_epochs"], weight_decay=cfg["weight_decay"], fp16=(cfg["dtype"] == "float16"), - evaluation_strategy="epoch", + eval_strategy="epoch", save_strategy="epoch", save_total_limit=1, load_best_model_at_end=True, @@ -357,7 +380,7 @@ def main(config_path: str) -> None: args=training_args, train_dataset=dataset["train"], eval_dataset=dataset["test"], - tokenizer=tokenizer, + processing_class=tokenizer, data_collator=data_collator, compute_metrics=compute_metrics2, ) diff --git a/src/modeling/util.py b/src/modeling/util.py index 5f8d83d..099deb7 100644 --- a/src/modeling/util.py +++ b/src/modeling/util.py @@ -6,35 +6,91 @@ from transformers import AutoModel, AutoTokenizer -def quantize_onnx_model(onnx_model_path: str, quantized_model_path: str): - quantize_dynamic(onnx_model_path, quantized_model_path, weight_type=QuantType.QInt8) - print(f"quantized model saved to:{quantized_model_path}") - return None - - def export_onnx_model(output_model_path: str, model: torch.nn.Module | AutoModel, tokenizer: AutoTokenizer): if os.path.exists(output_model_path): print(f"Warning: overwriting existing ONNX model at path {output_model_path}") - dummy_text = "test" + + # Use a more representative dummy text + dummy_text = "This is a medical test example with sufficient content to exercise the model" + # Tokenize the text dummy_input = tokenizer(dummy_text, return_tensors="pt") - torch.onnx.export( - model, - tuple([dummy_input["input_ids"], dummy_input["attention_mask"]]), - f=output_model_path, - export_params=True, - input_names=["input_ids", "attention_mask"], - output_names=["logits"], - dynamic_axes={ - "input_ids": {0: "batch_size", 1: "sequence"}, - "attention_mask": {0: "batch_size", 1: "sequence"}, - "logits": {0: "batch_size", 1: "sequence"}, - }, - do_constant_folding=True, - opset_version=17, - ) + + # Add metadata inputs for the custom model + dummy_jurisdiction_id = torch.tensor([2]) # 2 = Unspecified + dummy_insurance_type_id = torch.tensor([2]) # 2 = Unspecified + + # Check if model accepts metadata + try: + # Test if model accepts metadata parameters + with torch.no_grad(): + _ = model( + input_ids=dummy_input["input_ids"], + attention_mask=dummy_input["attention_mask"], + jurisdiction_id=dummy_jurisdiction_id, + insurance_type_id=dummy_insurance_type_id, + ) + + # If we got here, model accepts metadata + print("Exporting model with metadata support...") + torch.onnx.export( + model, + tuple( + [ + dummy_input["input_ids"], + dummy_input["attention_mask"], + dummy_jurisdiction_id, + dummy_insurance_type_id, + ] + ), + f=output_model_path, + export_params=True, + input_names=["input_ids", "attention_mask", "jurisdiction_id", "insurance_type_id"], + output_names=["logits"], + dynamic_axes={ + "input_ids": {0: "batch_size", 1: "sequence"}, + "attention_mask": {0: "batch_size", 1: "sequence"}, + "jurisdiction_id": {0: "batch_size"}, + "insurance_type_id": {0: "batch_size"}, + "logits": {0: "batch_size"}, + }, + do_constant_folding=True, + opset_version=17, + ) + except Exception as e: + print(f"Model doesn't support metadata, exporting without it: {e}") + # Export without metadata + torch.onnx.export( + model, + tuple([dummy_input["input_ids"], dummy_input["attention_mask"]]), + f=output_model_path, + export_params=True, + input_names=["input_ids", "attention_mask"], + output_names=["logits"], + dynamic_axes={ + "input_ids": {0: "batch_size", 1: "sequence"}, + "attention_mask": {0: "batch_size", 1: "sequence"}, + "logits": {0: "batch_size"}, + }, + do_constant_folding=True, + opset_version=17, + ) + print(f"Exported ONNX model to {output_model_path}.") +def quantize_onnx_model(onnx_model_path: str, quantized_model_path: str): + # Quantize the model + quantize_dynamic( + onnx_model_path, + quantized_model_path, + weight_type=QuantType.QInt8, + # Optionally exclude problematic operators + op_types_to_quantize=["MatMul", "Gemm", "Conv"], + ) + print(f"Quantized model saved to: {quantized_model_path}") + return None + + def load_config(config_path: str) -> dict: with open(config_path, "r") as stream: try: From ddb13c84064dffe9125550331eb6181759572b97 Mon Sep 17 00:00:00 2001 From: MikeG112 Date: Sun, 20 Apr 2025 18:08:46 -0400 Subject: [PATCH 09/27] Remove ifs --- src/modeling/train_outcome_predictor.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/src/modeling/train_outcome_predictor.py b/src/modeling/train_outcome_predictor.py index f975e8e..95e5e48 100644 --- a/src/modeling/train_outcome_predictor.py +++ b/src/modeling/train_outcome_predictor.py @@ -239,15 +239,22 @@ def forward( j_embeddings = self.jurisdiction_embeddings(jurisdiction_id) i_embeddings = self.insurance_type_embeddings(insurance_type_id) - if j_unspecified_mask.any(): - avg_j_embedding = (self.jurisdiction_embeddings.weight[0] + self.jurisdiction_embeddings.weight[1]) / 2 - j_embeddings[j_unspecified_mask] = avg_j_embedding - - if i_unspecified_mask.any(): - avg_i_embedding = ( - self.insurance_type_embeddings.weight[0] + self.insurance_type_embeddings.weight[1] - ) / 2 - i_embeddings[i_unspecified_mask] = avg_i_embedding + # Calculate average embeddings + avg_j_embedding = (self.jurisdiction_embeddings.weight[0] + self.jurisdiction_embeddings.weight[1]) / 2 + avg_i_embedding = (self.insurance_type_embeddings.weight[0] + self.insurance_type_embeddings.weight[1]) / 2 + + j_embeddings = torch.where( + j_unspecified_mask.unsqueeze(-1).expand_as(j_embeddings), + avg_j_embedding.expand_as(j_embeddings), + j_embeddings, + ) + + # This replaces: if i_unspecified_mask.any(): i_embeddings[i_unspecified_mask] = avg_i_embedding + i_embeddings = torch.where( + i_unspecified_mask.unsqueeze(-1).expand_as(i_embeddings), + avg_i_embedding.expand_as(i_embeddings), + i_embeddings, + ) # For models without pooler_output, use the last hidden state's [CLS] token last_hidden_state = base_outputs.hidden_states[-1] From 98b8e92b7eda83374130a50154ab26d5774c3503 Mon Sep 17 00:00:00 2001 From: MikeG112 Date: Sun, 20 Apr 2025 22:05:58 -0400 Subject: [PATCH 10/27] Remove exclusion --- src/modeling/util.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/modeling/util.py b/src/modeling/util.py index 099deb7..f391681 100644 --- a/src/modeling/util.py +++ b/src/modeling/util.py @@ -84,8 +84,6 @@ def quantize_onnx_model(onnx_model_path: str, quantized_model_path: str): onnx_model_path, quantized_model_path, weight_type=QuantType.QInt8, - # Optionally exclude problematic operators - op_types_to_quantize=["MatMul", "Gemm", "Conv"], ) print(f"Quantized model saved to: {quantized_model_path}") return None From 9d8a1ee0bcf88ae8a70350bba0093bf4151aa6f2 Mon Sep 17 00:00:00 2001 From: MikeG112 Date: Sun, 20 Apr 2025 23:00:47 -0400 Subject: [PATCH 11/27] Upate sufficiency augmentation --- .../sufficiency_classification/augmented.yaml | 9 +++++- src/modeling/data_augmentation.py | 28 ++++++++----------- 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/src/modeling/config/sufficiency_classification/augmented.yaml b/src/modeling/config/sufficiency_classification/augmented.yaml index 95031fe..53cedc5 100644 --- a/src/modeling/config/sufficiency_classification/augmented.yaml +++ b/src/modeling/config/sufficiency_classification/augmented.yaml @@ -22,4 +22,11 @@ use_saved_augmentations: false sufficient_augmentation_params: api_type: openai api_url: https://api.openai.com/v1/chat/completions - model_name: gpt-4o \ No newline at end of file + model_name: gpt-4o + num_augmentations_per_example: 2 + +generic_rewrite_params: + num_augmentations_per_example: 2 + +unrelated_params: + num_examples: 500 \ No newline at end of file diff --git a/src/modeling/data_augmentation.py b/src/modeling/data_augmentation.py index d250871..b18ee1b 100644 --- a/src/modeling/data_augmentation.py +++ b/src/modeling/data_augmentation.py @@ -85,7 +85,7 @@ def rewrite_to_generic( {"role": "user", "content": prompt}, ], "max_tokens": 500, - "temperature": 0.7, + "temperature": 1.2, } else: # llama request_data = {"prompt": f"{system_instruction}\n\n{prompt}", "temperature": 0.7, "stream": False} @@ -187,7 +187,7 @@ def generate_unrelated_content( "Make a general statement about insurance denials", "Make a statement about insurance denials for a particular type of care, but that does not describe a specific situation.", "Write some generic unrelated user input (hello, how are you, etc.). Nothing inappropriate.", - "Write some gibberish that might be input accidentally by a user, or sent inadverdently, cut off sentences, etc.", + "Write some gibberish that might be input accidentally by a user, or sent inadvertently, cut off sentences, etc.", ] # Calculate how many examples to generate per category @@ -219,12 +219,12 @@ def generate_unrelated_content( {"role": "user", "content": category}, ], "max_tokens": 100, - "temperature": 0.8, + "temperature": 1.2, } else: # LLaMa.cpp request_data = { "prompt": f"{system_instruction}\n\n{category}", - "temperature": 0.8, + "temperature": 1.2, "max_tokens": 100, "stream": False, } @@ -243,8 +243,8 @@ def generate_unrelated_content( new_text = response_json["content"].strip() # Clean up and limit length - if len(new_text) > 200: - new_text = new_text[:200] + if len(new_text) > 512: + new_text = new_text[:512] # Add the example with metadata all_generations.append( @@ -331,7 +331,7 @@ def augment_sufficient_examples( ) clinical_prompt = "Rewrite the following description using more clinical and technical medical terminology, maintaining all the key details. {}" paraphrase_prompt = "Rewrite the following description in different words while preserving the exact same meaning and all key details. Make minimal changes: {}" - details_prompt = "Rewrite the following description, adding a few more specific details about the condition and treatment, while preserving the core information. Make minimal changes: {}" + details_prompt = "Rewrite the following description, adding a few more specific details about the condition and treatment, but removing none. Make minimal changes: {}" # System instruction for models system_instruction = "You are a helpful assistant that rewrites healthcare denial descriptions while preserving their key information." @@ -353,21 +353,17 @@ def augment_sufficient_examples( continue # Delete 1-3 random words (but not too many) - num_to_delete = min(random.randint(1, 3), len(words) // 5) + num_to_delete = 1 indices_to_delete = random.sample(range(len(words)), num_to_delete) new_text = " ".join([w for i, w in enumerate(words) if i not in indices_to_delete]) - # Slightly reduce sufficiency score but keep it sufficient (>= 3) - new_score = max(3, sufficiency_score - 1) if random.random() < 0.3 else sufficiency_score - # Add the augmented example with metadata augmented_examples.append( { "text": new_text, - "sufficiency_score": new_score, + "sufficiency_score": sufficiency_score, "source_text": text, - "source_score": sufficiency_score, "augmentation_type": "sufficient_word_deletion", } ) @@ -398,13 +394,13 @@ def augment_sufficient_examples( {"role": "user", "content": prompt}, ], "max_tokens": 500, - "temperature": 0.7, + "temperature": 1.2, } else: # LLaMa.cpp # LLaMa.cpp format (depends on your server implementation) request_data = { "prompt": f"{system_instruction}\n\n{prompt}", - "temperature": 0.7, + "temperature": 1.2, "max_tokens": 500, "stop": ["\n\n", "###"], # Common stop sequences "stream": False, @@ -543,7 +539,7 @@ def process_and_save_augmentations( train_records = [dataset_records[i] for i in train_indices] test_records = [dataset_records[i] for i in test_indices] - # Apply augmentations only to training data only + # Apply augmentations to training data only train_dataset = Dataset.from_list(train_records) test_dataset = Dataset.from_list(test_records) augmented_train_records = [] From 2ddc09afb73dc01e55149e7f38e05c1682db03fc Mon Sep 17 00:00:00 2001 From: MikeG112 Date: Sun, 20 Apr 2025 23:04:58 -0400 Subject: [PATCH 12/27] Add API key --- src/modeling/data_augmentation.py | 4 ++-- src/modeling/train_sufficiency_classifier.py | 19 +++++-------------- 2 files changed, 7 insertions(+), 16 deletions(-) diff --git a/src/modeling/data_augmentation.py b/src/modeling/data_augmentation.py index b18ee1b..2cf8696 100644 --- a/src/modeling/data_augmentation.py +++ b/src/modeling/data_augmentation.py @@ -280,8 +280,8 @@ def generate_unrelated_content( def augment_sufficient_examples( dataset: Dataset, num_augmentations_per_example: int = 1, - api_type: str = "llamacpp", # "openai" or "llamacpp" - api_url: str = "http://localhost:8080/completion", + api_type: str = "openai", # "openai" or "llamacpp" + api_url: str = "https://api.openai.com/v1/chat/completions", api_key: str | None = None, model_name: str = "localllama", seed: int = 42, diff --git a/src/modeling/train_sufficiency_classifier.py b/src/modeling/train_sufficiency_classifier.py index e174f76..f34c079 100644 --- a/src/modeling/train_sufficiency_classifier.py +++ b/src/modeling/train_sufficiency_classifier.py @@ -223,21 +223,12 @@ def main(config_path: str) -> None: elif use_data_augmentation: generic_rewrite_params = cfg.get( "generic_rewrite_params", - {"num_augmentations_per_example": 1, "seed": 1, "api_key": os.environ.get("OPENAI_API_KEY")}, - ) - unrelated_params = cfg.get( - "unrelated_params", {"num_examples": 500, "seed": 1, "api_key": os.environ.get("OPENAI_API_KEY")} - ) - sufficient_augmentation_params = cfg.get( - "sufficient_augmentation_params", - { - "num_augmentations_per_example": 1, - "api_type": "llamacpp", - "model_name": "Meta-Llama-3-8B-Instruct.Q4_K_M.gguf", - "seed": 1, - "api_key": os.environ.get("OPENAI_API_KEY"), - }, ) + generic_rewrite_params["api_key"] = os.environ.get("OPENAI_API_KEY") + unrelated_params = cfg.get("unrelated_params") + unrelated_params["api_key"] = os.environ.get("OPENAI_API_KEY") + sufficient_augmentation_params = cfg.get("sufficient_augmentation_params") + sufficient_augmentation_params["api_key"] = os.environ.get("OPENAI_API_KEY") # Set output paths based on save_augmentations flag output_train_path = augmented_train_path if save_augmentations else None From 96cbd653a2e810445a9d39461cf1db6d0727bb27 Mon Sep 17 00:00:00 2001 From: MikeG112 Date: Mon, 21 Apr 2025 00:18:10 -0400 Subject: [PATCH 13/27] Increase call limit --- src/modeling/background_extraction.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/modeling/background_extraction.py b/src/modeling/background_extraction.py index 5a89e33..69f89a5 100644 --- a/src/modeling/background_extraction.py +++ b/src/modeling/background_extraction.py @@ -685,9 +685,10 @@ def create_augmented_examples( api_url=api_url, api_key=api_key, model_name=model_name, - generic_rewrites_per_example=2, - sufficient_augmentations_per_example=2, - num_unrelated_examples=100, + generic_rewrites_per_example=1, + sufficient_augmentations_per_example=1, + num_unrelated_examples=1000, + api_call_limit=6000, seed=42, ) @@ -706,7 +707,8 @@ def create_augmented_examples( model_name=model_name, generic_rewrites_per_example=1, # Fewer augmentations for test set sufficient_augmentations_per_example=1, - num_unrelated_examples=10, + num_unrelated_examples=1000, + api_call_limit=6000, seed=43, # Different seed for test set ) From 704371af76171f630a543df5f48692730741af71 Mon Sep 17 00:00:00 2001 From: MikeG112 Date: Mon, 21 Apr 2025 15:58:43 -0400 Subject: [PATCH 14/27] Add logit printing --- src/modeling/predict.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/modeling/predict.py b/src/modeling/predict.py index f7a168a..8a2d3c5 100644 --- a/src/modeling/predict.py +++ b/src/modeling/predict.py @@ -83,12 +83,12 @@ def run_pytorch_model(model, tokenizer, text, jurisdiction_id=2, insurance_type_ probs = torch.softmax(result["logits"], dim=-1) prob, argmax = torch.max(probs, dim=-1) - return { "class_id": argmax.item(), "class_name": ID2LABEL[argmax.item()], "confidence": prob.item(), "probs": probs[0].tolist(), + "logits": result["logits"][0].tolist(), } @@ -108,6 +108,7 @@ def run_pytorch_model_no_metadata(model, tokenizer, text): "class_name": ID2LABEL[argmax.item()], "confidence": prob.item(), "probs": probs[0].tolist(), + "logits": result["logits"][0].tolist(), } except Exception as e: return {"error": str(e), "class_name": "ERROR", "confidence": 0.0, "probs": [0.0, 0.0, 0.0]} @@ -132,6 +133,7 @@ def run_pytorch_quantized(model_int8, tokenizer, text, jurisdiction_id=2, insura "class_name": ID2LABEL[argmax.item()], "confidence": prob.item(), "probs": probs[0].tolist(), + "logits": result["logits"][0].tolist(), } @@ -160,18 +162,18 @@ def run_onnx(session, tokenizer, text, jurisdiction_id=2, insurance_type_id=2): result = scipy.special.softmax(outputs[0], axis=-1) argmax = np.argmax(result[0]) prob = result[0][argmax] - return { "class_id": int(argmax), "class_name": ID2LABEL[int(argmax)], "confidence": float(prob), "probs": result[0].tolist(), + "logits": outputs[0][0], } except Exception as e: return {"error": str(e), "class_name": "ERROR", "confidence": 0.0, "probs": [0.0, 0.0, 0.0]} -def print_result(model_name, result, expected=None, show_probs=False): +def print_result(model_name, result, expected=None, show_probs=True, show_logits=True): """Print inference result with color-coding based on matching expected output""" if "error" in result: print(f"{model_name}: {Fore.RED}ERROR{Style.RESET_ALL} - {result['error']}") @@ -188,6 +190,11 @@ def print_result(model_name, result, expected=None, show_probs=False): # Print result print(f"{model_name}: {color}{result['class_name']}{Style.RESET_ALL} ({result['confidence']:.4f})") + # Optionally show logits + if show_logits: + logits_str = ", ".join([f"{ID2LABEL[i]}: {p:.4f}" for i, p in enumerate(result["logits"])]) + print(f" Logits: {logits_str}") + # Optionally show probabilities if show_probs: probs_str = ", ".join([f"{ID2LABEL[i]}: {p:.4f}" for i, p in enumerate(result["probs"])]) @@ -222,10 +229,6 @@ def test_all_examples(model, model_int8, onnx_session, onnx_quant_session, token print_result("ONNX", onnx_result, expected) print_result("ONNX (quantized)", onnx_quant_result, expected) - # Print probabilities from PyTorch model - probs_str = ", ".join([f"{ID2LABEL[i]}: {p:.4f}" for i, p in enumerate(pt_result["probs"])]) - print(f"Probabilities: {probs_str}") - print("-" * 80) From 4c671f1bb2fe0d2572fa755ba61c2edbb441c150 Mon Sep 17 00:00:00 2001 From: MikeG112 Date: Mon, 21 Apr 2025 15:59:10 -0400 Subject: [PATCH 15/27] Update config --- src/modeling/config/outcome_prediction/distilbert.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/modeling/config/outcome_prediction/distilbert.yaml b/src/modeling/config/outcome_prediction/distilbert.yaml index 776e4b0..4fce9db 100644 --- a/src/modeling/config/outcome_prediction/distilbert.yaml +++ b/src/modeling/config/outcome_prediction/distilbert.yaml @@ -19,6 +19,6 @@ dtype: "float16" # 'float32','float16' for training dtype compile: True # Whether to use torch compile # Test eval settings -test_data_path: "./data/outcomes/test_backgrounds_suff_augmented.jsonl" -checkpoint_name: "checkpoint-19152" -eval_threshold: .35 \ No newline at end of file +test_data_path: "./data/outcomes/test_backgrounds_suff.jsonl" +checkpoint_name: "checkpoint-23040" +eval_threshold: .5 \ No newline at end of file From e91409c500b6111cb283f96e3b11a1b246f1bdfb Mon Sep 17 00:00:00 2001 From: MikeG112 Date: Thu, 24 Apr 2025 16:51:35 -0400 Subject: [PATCH 16/27] Update configs --- .../config/outcome_prediction/clinicalbert.yaml | 4 ++-- .../config/outcome_prediction/legalbert_small.yaml | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/modeling/config/outcome_prediction/clinicalbert.yaml b/src/modeling/config/outcome_prediction/clinicalbert.yaml index 1388b3b..bd95a1f 100644 --- a/src/modeling/config/outcome_prediction/clinicalbert.yaml +++ b/src/modeling/config/outcome_prediction/clinicalbert.yaml @@ -8,7 +8,7 @@ hicric_pretrained: False base_model_name: "clinicalbert" pretrained_model_dir: "None" pretrained_hf_model_key: "medicalai/ClinicalBERT" -train_data_path: "./data/outcomes/train_backgrounds_suff.jsonl" +train_data_path: "./data/outcomes/train_backgrounds_suff_augmented.jsonl" # Training settings learning_rate: 8.0e-7 @@ -20,5 +20,5 @@ compile: True # Whether to use torch compile # Test eval settings test_data_path: "./data/outcomes/test_backgrounds_suff.jsonl" -checkpoint_name: "checkpoint-46134" +checkpoint_name: checkpoint-49770 eval_threshold: .55 \ No newline at end of file diff --git a/src/modeling/config/outcome_prediction/legalbert_small.yaml b/src/modeling/config/outcome_prediction/legalbert_small.yaml index 493fb8c..62ab992 100644 --- a/src/modeling/config/outcome_prediction/legalbert_small.yaml +++ b/src/modeling/config/outcome_prediction/legalbert_small.yaml @@ -8,17 +8,17 @@ hicric_pretrained: False base_model_name: "legal-bert-small-uncased" pretrained_model_dir: "None" pretrained_hf_model_key: "nlpaueb/legal-bert-small-uncased" -train_data_path: "./data/outcomes/train_backgrounds_suff.jsonl" +train_data_path: "./data/outcomes/train_backgrounds_suff_augmented.jsonl" # Training settings learning_rate: 8.0e-7 weight_decay: 0.01 -num_epochs: 40 -batch_size: 20 +num_epochs: 20 +batch_size: 48 dtype: "float16" # 'float32','float16' for training dtype compile: True # Whether to use torch compile # Test eval settings test_data_path: "./data/outcomes/test_backgrounds_suff.jsonl" -checkpoint_name: "checkpoint-82016" -eval_threshold: .8 \ No newline at end of file +checkpoint_name: "checkpoint-23040" +eval_threshold: .45 \ No newline at end of file From 84a35c19d0b6ecbda27bb43124d4168da848df4a Mon Sep 17 00:00:00 2001 From: MikeG112 Date: Thu, 24 Apr 2025 21:10:54 -0400 Subject: [PATCH 17/27] Record model size --- src/modeling/eval_hf_outcome_predictor.py | 76 ++++++++++++++++++++++- 1 file changed, 74 insertions(+), 2 deletions(-) diff --git a/src/modeling/eval_hf_outcome_predictor.py b/src/modeling/eval_hf_outcome_predictor.py index 39640ca..eb7a52a 100644 --- a/src/modeling/eval_hf_outcome_predictor.py +++ b/src/modeling/eval_hf_outcome_predictor.py @@ -150,6 +150,20 @@ def is_metadata_model(model): return has_j_embeddings and has_i_embeddings and accepts_metadata +def count_parameters(model): + """Count the number of trainable parameters in the model and calculate size in MB""" + param_count = 0 + size_in_bytes = 0 + + for p in model.parameters(): + if p.requires_grad: + param_count += p.numel() + size_in_bytes += p.numel() * p.element_size() + + size_in_mb = size_in_bytes / (1024 * 1024) + return param_count, size_in_mb + + def main(config_path: str): cfg = load_config(config_path) @@ -185,6 +199,10 @@ def main(config_path: str): label2id=LABEL2ID, ) + # Count model parameters + param_count, size_in_mb = count_parameters(model) + print(f"Model size: {param_count:,} parameters ({size_in_mb:.2f} MB)") + # Check if model supports metadata supports_metadata = is_metadata_model(model) @@ -226,26 +244,80 @@ def main(config_path: str): all_logits.append(outputs["logits"].cpu()) predictions = torch.cat(all_logits) + + # Now create predictions with unspecified metadata for all examples + print("Computing metrics with unspecified metadata for all examples...") + unspecified_j_ids = torch.full((len(text_records),), JURISDICTION_MAP["Unspecified"], dtype=torch.long) + unspecified_i_ids = torch.full((len(text_records),), INSURANCE_TYPE_MAP["Unspecified"], dtype=torch.long) + + all_unspecified_logits = [] + + for i in range(0, len(text_records), batch_size): + end_idx = min(i + batch_size, len(text_records)) + batch_texts = text_records[i:end_idx] + batch_j_ids = unspecified_j_ids[i:end_idx].to(device) + batch_i_ids = unspecified_i_ids[i:end_idx].to(device) + + inputs = tokenizer(batch_texts, truncation=True, padding=True, return_tensors="pt").to(device) + + with torch.no_grad(): + outputs = model(**inputs, jurisdiction_id=batch_j_ids, insurance_type_id=batch_i_ids) + + all_unspecified_logits.append(outputs["logits"].cpu()) + + unspecified_predictions = torch.cat(all_unspecified_logits) else: # For standard models, use the original pipeline pipeline = ClassificationPipeline(model=model, tokenizer=tokenizer, device=device, truncation=True) predictions = torch.cat(pipeline(text_records, batch_size=100)) + # No metadata to ignore for standard models, so unspecified predictions are the same + unspecified_predictions = predictions - # Compute metrics + # Compute standard metrics + metrics = compute_metrics(predictions, labels) + metrics["param_count"] = param_count + metrics["model_size_mb"] = size_in_mb + + # Add model size to metrics + metrics["param_count"] = param_count + + # Compute metrics with unspecified metadata + unspecified_metrics = compute_metrics(unspecified_predictions, labels) + # Add model size to unspecified metrics + unspecified_metrics["param_count"] = param_count + + # Compute threshold metrics if specified if threshold is not None: threshold_metrics = compute_metrics_w_threshold(predictions, labels, threshold) + threshold_metrics["param_count"] = param_count + threshold_metrics["model_size_mb"] = size_in_mb - metrics = compute_metrics(predictions, labels) + unspecified_threshold_metrics = compute_metrics_w_threshold(unspecified_predictions, labels, threshold) + unspecified_threshold_metrics["param_count"] = param_count + unspecified_threshold_metrics["model_size_mb"] = size_in_mb # Print and write metrics + print("Standard metrics:") print(metrics) with open(os.path.join(ckpt_path, "test_metrics.json"), "w") as f: json.dump(metrics, f) + + print("\nMetrics with unspecified metadata:") + print(unspecified_metrics) + with open(os.path.join(ckpt_path, "test_metrics_unspecified.json"), "w") as f: + json.dump(unspecified_metrics, f) + if threshold is not None: + print("\nStandard threshold metrics:") print(threshold_metrics) with open(os.path.join(ckpt_path, "test_metrics_w_threshold.json"), "w") as f: json.dump(threshold_metrics, f) + print("\nThreshold metrics with unspecified metadata:") + print(unspecified_threshold_metrics) + with open(os.path.join(ckpt_path, "test_metrics_w_threshold_unspecified.json"), "w") as f: + json.dump(unspecified_threshold_metrics, f) + return None From f831c0c34133f85bd8d0627a203e16aaa62a7ca9 Mon Sep 17 00:00:00 2001 From: MikeG112 Date: Fri, 25 Apr 2025 15:13:18 -0400 Subject: [PATCH 18/27] Update eval scripts --- .../outcome_prediction/legalbert_small.yaml | 4 +- src/modeling/openai_batch_call.py | 143 ++++++++++--- src/modeling/openai_eval.py | 189 +++++++++++++++--- 3 files changed, 285 insertions(+), 51 deletions(-) diff --git a/src/modeling/config/outcome_prediction/legalbert_small.yaml b/src/modeling/config/outcome_prediction/legalbert_small.yaml index 62ab992..1130517 100644 --- a/src/modeling/config/outcome_prediction/legalbert_small.yaml +++ b/src/modeling/config/outcome_prediction/legalbert_small.yaml @@ -20,5 +20,5 @@ compile: True # Whether to use torch compile # Test eval settings test_data_path: "./data/outcomes/test_backgrounds_suff.jsonl" -checkpoint_name: "checkpoint-23040" -eval_threshold: .45 \ No newline at end of file +checkpoint_name: "checkpoint-20736" +eval_threshold: .5 \ No newline at end of file diff --git a/src/modeling/openai_batch_call.py b/src/modeling/openai_batch_call.py index 4269993..3583b77 100644 --- a/src/modeling/openai_batch_call.py +++ b/src/modeling/openai_batch_call.py @@ -1,3 +1,5 @@ +import argparse +import json import os import time @@ -11,11 +13,11 @@ MODEL_KEY = "gpt-4o-mini-2024-07-18" OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", None) SYSTEM_MESSAGE = """You are an expert in U.S. health law and health policy, as well as a medical expert. In what follows, I will provide a description of a case in which a patient has submitted an appeal of a decision by their health insurer to deny a claim submitted on their behalf. You must predict whether an independent reviewer will overturn the denial, or uphold the denial when reviewing the appeal. If they would overturn it, your decision would be "Overturned". If they would not, your decision would be "Upheld". If there is insufficient information in the context to predict this, and it could go either way depending on more details, your decision would be "Insufficient". -Very short cases are often insufficient, as are cases which describe a treatment or service without saying what it is for. Sufficiency does not mean sufficient-without-a-doubt, it just means sufficient to make a good estimated guess about which -way the review will go. Most cases I present to you will have sufficient information, by my subjective standards. + +Very short cases are often insufficient, as are cases which describe a treatment or service without saying what it is for. Sufficiency does not mean sufficient-without-a-doubt, it just means sufficient to make a good estimated guess about which way the review will go. Most cases I present to you will have sufficient information, by my subjective standards. You must reply with json of the following form: -{"decision": "Overturned", "probability": .82} +{"decision": "Overturned", "probability": 0.82} with the decision, and the associated probability that that decision is correct. The possible decision classes are "Insufficient", "Upheld", and "Overturned". @@ -23,10 +25,10 @@ Here are two examples: Prompt: "An enrollee has requested Zepatier for treatment of her hepatitis C." -Desired Output: {"decision": "Overturned", "probability": .75} +Desired Output: {"decision": "Overturned", "probability": 0.75} Prompt: "An enrollee has requested emergency services provided on an emergent or urgent basis for treatment of her medical condition." -Desired Output: {"decision": "Insufficient", "probability": .99} +Desired Output: {"decision": "Insufficient", "probability": 0.99} """ @@ -84,7 +86,20 @@ def construct_answer_batch(recs: list[dict], output_path: str) -> None: outcome = rec["decision"] sufficiency_id = rec["sufficiency_id"] - answers.append({"custom_id": f"{custom_id}", "decision": construct_label(outcome, sufficiency_id)}) + # Get metadata fields (default to "Unspecified" if not present) + jurisdiction = rec.get("jurisdiction", "Unspecified") + insurance_type = rec.get("insurance_type", "Unspecified") + + # Include all fields in the answer + answers.append( + { + "custom_id": f"{custom_id}", + "decision": construct_label(outcome, sufficiency_id), + "sufficiency_id": sufficiency_id, + "jurisdiction": jurisdiction, + "insurance_type": insurance_type, + } + ) add_jsonl_lines(output_path, answers) @@ -104,7 +119,7 @@ def construct_request_batch(recs: list[dict], model_key: str, output_path: str) return None -def prepare_batches() -> None: +def prepare_batches() -> tuple: recs = get_records_list(os.path.join("./data/outcomes/", OUTCOMES_DATASET + ".jsonl")) output_path = f"./data/provider_annotated_outcomes/openai/{OUTCOMES_DATASET}/hicric_eval_request_answers.jsonl" if os.path.exists(output_path): @@ -122,7 +137,7 @@ def prepare_batches() -> None: # Full batch construct_request_batch(recs, MODEL_KEY, output_path) split_batch(single_batch_path=output_path, subbatch_dir=subbatch_dir) - return subbatch_dir + return subbatch_dir, recs def split_batch(single_batch_path, subbatch_dir, subbatch_size=1500) -> None: @@ -136,7 +151,7 @@ def split_batch(single_batch_path, subbatch_dir, subbatch_size=1500) -> None: return None -def batch_call(filepaths: list[str], api_key: str | None = OPENAI_API_KEY) -> None: +def batch_call(filepaths: list[str], api_key: str | None = OPENAI_API_KEY) -> list[str]: """Submit a batch completion call for each batch file in filepaths.""" if not api_key: raise Exception("You need to export an Open AI API key as an env var.") @@ -174,7 +189,7 @@ def batch_call(filepaths: list[str], api_key: str | None = OPENAI_API_KEY) -> No return batch_request_ids -def download_response(batch_id: str, subbatch_dir: str, api_key: str | None = OPENAI_API_KEY) -> None: +def download_response(batch_id: str, subbatch_dir: str, api_key: str | None = OPENAI_API_KEY) -> str: client = OpenAI(api_key=api_key) status_meta = client.batches.retrieve(batch_id) download_path = os.path.join(subbatch_dir, f"response_{batch_id}.jsonl") @@ -183,7 +198,7 @@ def download_response(batch_id: str, subbatch_dir: str, api_key: str | None = OP def poll_and_download( - batch_request_ids: list[str], subbatch_dir: str, poll_sleep_mins: int = 0.1, api_key: str | None = OPENAI_API_KEY + batch_request_ids: list[str], subbatch_dir: str, poll_sleep_mins: float = 0.1, api_key: str | None = OPENAI_API_KEY ) -> list[str]: client = OpenAI(api_key=api_key) @@ -217,19 +232,95 @@ def merge_jsonl(paths: list[str], output_path): return None +def synchronous_call(records: list[dict], api_key: str | None = OPENAI_API_KEY) -> list[dict]: + """Make synchronous API calls one at a time for each record.""" + if not api_key: + raise Exception("You need to export an Open AI API key as an env var.") + + client = OpenAI(api_key=api_key) + results = [] + + print(f"Processing {len(records)} records synchronously...") + for idx, rec in enumerate(records): + case_description = rec["text"] + custom_id = idx + + print(f"Processing record {idx+1}/{len(records)}") + try: + response = client.chat.completions.create( + model=MODEL_KEY, + messages=[{"role": "system", "content": SYSTEM_MESSAGE}, {"role": "user", "content": case_description}], + ) + + # Get the response content + completion_text = response.choices[0].message.content + + # Parse JSON response + try: + completion_json = json.loads(completion_text) + except json.JSONDecodeError: + print(f"Failed to parse JSON for record {idx}: {completion_text}") + completion_json = {"decision": "Error", "probability": 0} + + # Create result record + result = { + "custom_id": str(custom_id), + "request": { + "body": { + "messages": [ + {"role": "system", "content": SYSTEM_MESSAGE}, + {"role": "user", "content": case_description}, + ] + } + }, + "response": {"choices": [{"message": {"content": completion_text, "role": "assistant"}}]}, + "parsed_response": completion_json, + } + + results.append(result) + + except Exception as e: + print(f"Error processing record {idx}: {str(e)}") + # Add error record + results.append({"custom_id": str(custom_id), "error": str(e)}) + + return results + + if __name__ == "__main__": - subbatch_dir = prepare_batches() - filenames = [os.path.join(subbatch_dir, filename) for filename in os.listdir(subbatch_dir)] - - # Can only upload small number of files at a time due to enque limit - all_output_files = [] - batch_size = 1 - for start_idx in range(0, len(filenames), batch_size): - enqueued_files = filenames[start_idx : start_idx + batch_size] - batch_request_ids = batch_call(enqueued_files) - output_files = poll_and_download(batch_request_ids, subbatch_dir) - all_output_files.extend(output_files) - - # Merge to single response - output_path = f"./data/provider_annotated_outcomes/openai/{OUTCOMES_DATASET}/hicric_eval_response_{MODEL_KEY}.jsonl" - merge_jsonl(all_output_files, output_path) + parser = argparse.ArgumentParser(description="Run OpenAI API calls for health insurance claim evaluations") + parser.add_argument("--synchronous", action="store_true", help="Use synchronous API calls instead of batch") + args = parser.parse_args() + + subbatch_dir, records = prepare_batches() + + if args.synchronous: + # Synchronous mode + print("Running in synchronous mode...") + results = synchronous_call(records) + + # Save results + output_path = ( + f"./data/provider_annotated_outcomes/openai/{OUTCOMES_DATASET}/hicric_eval_response_sync_{MODEL_KEY}.jsonl" + ) + add_jsonl_lines(output_path, results) + print(f"Synchronous results saved to {output_path}") + else: + # Batch mode (original behavior) + filenames = [os.path.join(subbatch_dir, filename) for filename in os.listdir(subbatch_dir)] + + # Can only upload small number of files at a time due to enque limit + all_output_files = [] + batch_size = 1 + for start_idx in range(0, len(filenames), batch_size): + enqueued_files = filenames[start_idx : start_idx + batch_size] + batch_request_ids = batch_call(enqueued_files) + output_files = poll_and_download(batch_request_ids, subbatch_dir) + all_output_files.extend(output_files) + + # Merge to single response + output_path = ( + f"./data/provider_annotated_outcomes/openai/{OUTCOMES_DATASET}/hicric_eval_response_{MODEL_KEY}.jsonl" + ) + merge_jsonl(all_output_files, output_path) + print(f"Batch results saved to {output_path}") diff --git a/src/modeling/openai_eval.py b/src/modeling/openai_eval.py index 404341f..dc0c0ec 100644 --- a/src/modeling/openai_eval.py +++ b/src/modeling/openai_eval.py @@ -2,14 +2,41 @@ import os import numpy as np -from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score +from sklearn.metrics import ( + accuracy_score, + f1_score, + precision_score, + recall_score, + roc_auc_score, +) from src.util import get_records_list MODEL_KEY = "gpt-4o-mini-2024-07-18" +# Define mappings for jurisdiction and insurance type +JURISDICTION_MAP = {"NY": 0, "CA": 1, "Unspecified": 2} +INSURANCE_TYPE_MAP = {"Commercial": 0, "Medicaid": 1, "Unspecified": 2} + def compute_metrics(predictions: np.ndarray, labels: np.ndarray) -> dict: + # Check if we have enough unique classes for ROC AUC + try: + # Convert predictions to one-hot format for ROC AUC if they're not already + if len(predictions.shape) == 1: # If predictions are just class labels + n_classes = len(np.unique(labels)) + one_hot_preds = np.zeros((len(predictions), n_classes)) + for i, pred in enumerate(predictions): + if pred < n_classes: # Ensure valid index + one_hot_preds[i, pred] = 1 + roc_auc = roc_auc_score(labels, one_hot_preds, multi_class="ovr") + else: + # Assuming predictions are already probabilities/scores + roc_auc = roc_auc_score(labels, predictions, multi_class="ovr") + except ValueError: + # Handle case when there are classes with no predictions + roc_auc = float("nan") + acc = accuracy_score(labels, predictions) # Macro @@ -28,6 +55,7 @@ def compute_metrics(predictions: np.ndarray, labels: np.ndarray) -> dict: f1 = f1_score(labels, predictions, average=None, zero_division=np.nan) return { + "metric_fn_name": "compute_metrics", "accuracy": acc, "macro_precision": macro_prec.tolist(), "macro_recall": macro_rec.tolist(), @@ -38,6 +66,70 @@ def compute_metrics(predictions: np.ndarray, labels: np.ndarray) -> dict: "class_precision": prec.tolist(), "class_recall": rec.tolist(), "class_f1": f1.tolist(), + "roc_auc": roc_auc, + } + + +def compute_metrics_w_threshold(predictions: np.ndarray, labels: np.ndarray, threshold: float) -> dict: + # Convert predictions to probabilities if they're not already + if len(predictions.shape) == 1: # If predictions are just class labels + n_classes = max(3, np.max(predictions) + 1) # Ensure we have at least 3 classes + probs = np.zeros((len(predictions), n_classes)) + for i, pred in enumerate(predictions): + if pred < n_classes: # Ensure valid index + probs[i, pred] = 1 + else: + # Assuming predictions are already probabilities/scores + probs = predictions + + try: + roc_auc = roc_auc_score(labels, probs, multi_class="ovr") + except ValueError: + roc_auc = float("nan") + + # Apply threshold logic as in original implementation + pred_labels = np.full_like(labels, -1) # Initialize with -1 for unassigned classes + + # Assign class 0 where the probability meets or exceeds the threshold + class_0_mask = probs[:, 0] >= threshold + pred_labels[class_0_mask] = 0 + + # If probability does not exceed threshold, take argmax of upheld/overturned class options + remaining_mask = pred_labels == -1 + remaining_preds = probs[remaining_mask, 1:] # Exclude class 0 + pred_labels[remaining_mask] = np.argmax(remaining_preds, axis=1) + 1 # + 1 to adjust for class index shift + + # Compute metrics + acc = accuracy_score(labels, pred_labels) + + # Macro + macro_prec = precision_score(labels, pred_labels, average="macro") + macro_rec = recall_score(labels, pred_labels, average="macro") + macro_f1 = f1_score(labels, pred_labels, average="macro", zero_division=np.nan) + + # Micro + micro_prec = precision_score(labels, pred_labels, average="micro") + micro_rec = recall_score(labels, pred_labels, average="micro") + micro_f1 = f1_score(labels, pred_labels, average="micro", zero_division=np.nan) + + # Class-level + prec = precision_score(labels, pred_labels, average=None) + rec = recall_score(labels, pred_labels, average=None) + f1 = f1_score(labels, pred_labels, average=None, zero_division=np.nan) + + return { + "metric_fn_name": "compute_metrics_w_threshold", + "accuracy": acc, + "macro_precision": macro_prec.tolist(), + "macro_recall": macro_rec.tolist(), + "macro_f1": macro_f1.tolist(), + "micro_precision": micro_prec.tolist(), + "micro_recall": micro_rec.tolist(), + "micro_f1": micro_f1.tolist(), + "class_precision": prec.tolist(), + "class_recall": rec.tolist(), + "class_f1": f1.tolist(), + "roc_auc": roc_auc, } @@ -46,57 +138,108 @@ def merge_pred_gt(response_path: str, answers_path: str) -> dict: gt = get_records_list(answers_path) responses = get_records_list(response_path) + # Extract ground truth and metadata for rec in gt: - merged[rec["custom_id"]] = {"answer": rec["decision"]} - + # Store decision and metadata in merged dict + merged[rec["custom_id"]] = { + "answer": rec["decision"], + "jurisdiction": rec.get("jurisdiction", "Unspecified"), + "insurance_type": rec.get("insurance_type", "Unspecified"), + "sufficiency_id": rec.get("sufficiency_id", 1), # Default to sufficient if not specified + } + + # Extract predictions for rec in responses: - resp = rec["response"]["body"]["choices"][0]["message"]["content"] - try: - pred = json.loads(resp.replace(".", "0.")) - except json.JSONDecodeError: - print(rec["custom_id"]) - print(type(resp)) - print(resp) - pred = None - merged[rec["custom_id"]]["pred"] = pred + pred = rec["parsed_response"] + + # Add prediction to the merged dict if the ID exists + if rec["custom_id"] in merged: + merged[rec["custom_id"]]["pred"] = pred return merged -def eval_preds() -> dict: - pass +def construct_label(outcome, sufficiency_id, label_map): + """Construct 3-class label from outcome, if sufficient background.""" + if sufficiency_id == 0: + return 0 + else: + return label_map.get(outcome, 0) # Default to 0 if outcome not in map -if __name__ == "__main__": +def eval_preds(threshold=None) -> dict: outcomes_dataset = "test_backgrounds_suff" answers_path = f"./data/provider_annotated_outcomes/openai/{outcomes_dataset}/hicric_eval_request_answers.jsonl" response_path = ( f"./data/provider_annotated_outcomes/openai/{outcomes_dataset}/hicric_eval_response_{MODEL_KEY}.jsonl" ) - merged = merge_pred_gt(response_path, answers_path) + merged = merge_pred_gt(response_path, answers_path) out_map = {"Insufficient": 0, "Upheld": 1, "Overturned": 2} preds = [] labels = [] + jurisdiction_ids = [] + insurance_type_ids = [] + # Favorable evaluation: don't count non-answers against GPT, to be generous for id, rec in merged.items(): print(rec) - if rec["pred"] is None: + if "pred" not in rec or rec["pred"] is None: continue if rec["pred"].get("decision") is None: continue if rec["pred"].get("decision") not in out_map.keys(): continue else: + # Extract prediction preds.append(out_map[rec["pred"]["decision"]]) - labels.append(out_map[rec["answer"]]) + + # Extract label with sufficiency consideration + sufficiency_id = rec.get("sufficiency_id", 1) # Default to sufficient if not specified + labels.append(construct_label(rec["answer"], sufficiency_id, out_map)) + + # Extract metadata + j = rec.get("jurisdiction", "Unspecified") + i = rec.get("insurance_type", "Unspecified") + jurisdiction_ids.append(JURISDICTION_MAP.get(j, JURISDICTION_MAP["Unspecified"])) + insurance_type_ids.append(INSURANCE_TYPE_MAP.get(i, INSURANCE_TYPE_MAP["Unspecified"])) + if len(preds) < len(merged): print(f"Model failed to produce valid json for {len(merged) - len(preds)} values.") - metrics = compute_metrics(preds, labels) + + # Convert to numpy arrays for metrics computation + preds_np = np.array(preds) + labels_np = np.array(labels) + + # Compute standard metrics + metrics = compute_metrics(preds_np, labels_np) + + # Add metadata distribution info to metrics + metrics["jurisdiction_distribution"] = {j: jurisdiction_ids.count(JURISDICTION_MAP[j]) for j in JURISDICTION_MAP} + metrics["insurance_type_distribution"] = { + i: insurance_type_ids.count(INSURANCE_TYPE_MAP[i]) for i in INSURANCE_TYPE_MAP + } + + # Compute threshold metrics if threshold is provided + if threshold is not None: + threshold_metrics = compute_metrics_w_threshold(preds_np, labels_np, threshold) + + # Print and save metrics print(metrics) - with open( - os.path.join(f"./data/provider_annotated_outcomes/openai/{outcomes_dataset}", f"{MODEL_KEY}_test_metrics.json"), - "w", - ) as f: + output_dir = f"./data/provider_annotated_outcomes/openai/{outcomes_dataset}" + with open(os.path.join(output_dir, f"{MODEL_KEY}_test_metrics.json"), "w") as f: json.dump(metrics, f) + + if threshold is not None: + print(threshold_metrics) + with open(os.path.join(output_dir, f"{MODEL_KEY}_test_metrics_w_threshold.json"), "w") as f: + json.dump(threshold_metrics, f) + + return metrics + + +if __name__ == "__main__": + # Optionally specify a threshold for the additional metrics + threshold = 0.5 # Set to None if you don't want threshold-based metrics + eval_preds(threshold) From 4840e5f6b67e8928a091697f0cd2ea0cc2225104 Mon Sep 17 00:00:00 2001 From: MikeG112 Date: Fri, 25 Apr 2025 15:18:44 -0400 Subject: [PATCH 19/27] Adjust metadata model --- src/modeling/train_outcome_predictor.py | 100 ++++++++++-------------- 1 file changed, 43 insertions(+), 57 deletions(-) diff --git a/src/modeling/train_outcome_predictor.py b/src/modeling/train_outcome_predictor.py index 95e5e48..37ab9e4 100644 --- a/src/modeling/train_outcome_predictor.py +++ b/src/modeling/train_outcome_predictor.py @@ -18,6 +18,7 @@ from transformers import ( AutoModelForSequenceClassification, AutoTokenizer, + BertForSequenceClassification, DataCollatorWithPadding, DistilBertForSequenceClassification, Trainer, @@ -188,22 +189,16 @@ def compute_metrics2(eval_pred) -> dict: return best_metrics -class TextClassificationWithMetadata(DistilBertForSequenceClassification): +class TextClassificationWithMetadata(BertForSequenceClassification): def __init__(self, config): super().__init__(config) - # Add embeddings for jurisdiction and insurance type - self.jurisdiction_embeddings = torch.nn.Embedding(3, 16) - self.insurance_type_embeddings = torch.nn.Embedding(3, 16) + # Only create embeddings for the actual classes (NY, CA) and (Commercial, Medicaid) + # The metadata is optional, and unspecified inputs get the average embeddings + self.jurisdiction_embeddings = torch.nn.Embedding(2, 16) + self.insurance_type_embeddings = torch.nn.Embedding(2, 16) - # Initialize the unspecified embeddings to be zeros - with torch.no_grad(): - self.jurisdiction_embeddings.weight[2].fill_(0) - self.insurance_type_embeddings.weight[2].fill_(0) - - # Get the hidden size from the model's config hidden_size = self.config.hidden_size - # Create a new classifier with the additional features self.final_classifier = torch.nn.Linear(hidden_size + 32, config.num_labels) def forward( @@ -216,52 +211,51 @@ def forward( return_loss=True, **kwargs, ): - # Extract labels before passing to super().forward() labels = kwargs.pop("labels", None) - # Get the embeddings and pooler output from the base model base_outputs = super().forward( input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True, - labels=None, # Important: don't pass labels yet - # **kwargs # Pass remaining kwargs + labels=None, # Don't pass labels yet + ) + + # Get the last hidden state's [CLS] token + last_hidden_state = base_outputs.hidden_states[-1] + pooled_output = last_hidden_state[:, 0] + + # Create masks for unspecified values (which will get assigned idx n for n specified classes) + j_unspecified_mask = jurisdiction_id == 2 + i_unspecified_mask = insurance_type_id == 2 + + # Convert invalid embedding indices to last embedding indices, temporarily + j_ids_valid = torch.clamp(jurisdiction_id, 0, 1) + i_ids_valid = torch.clamp(insurance_type_id, 0, 1) + + # Get actual embeddings (and embeddings for temp values) + j_embeddings = self.jurisdiction_embeddings(j_ids_valid) + i_embeddings = self.insurance_type_embeddings(i_ids_valid) + + # Calculate average embeddings for each metadata type + avg_j_embedding = (self.jurisdiction_embeddings.weight[0] + self.jurisdiction_embeddings.weight[1]) / 2 + avg_i_embedding = (self.insurance_type_embeddings.weight[0] + self.insurance_type_embeddings.weight[1]) / 2 + + # Replace unspecified temp embedding values with averages + j_embeddings = torch.where( + j_unspecified_mask.unsqueeze(-1).expand_as(j_embeddings), + avg_j_embedding.expand_as(j_embeddings), + j_embeddings, ) - # If we're not using the additional features, use base model logits - if jurisdiction_id is None or insurance_type_id is None: - logits = base_outputs.logits - else: - # Process metadata features - j_unspecified_mask = jurisdiction_id == 2 - i_unspecified_mask = insurance_type_id == 2 - - j_embeddings = self.jurisdiction_embeddings(jurisdiction_id) - i_embeddings = self.insurance_type_embeddings(insurance_type_id) - - # Calculate average embeddings - avg_j_embedding = (self.jurisdiction_embeddings.weight[0] + self.jurisdiction_embeddings.weight[1]) / 2 - avg_i_embedding = (self.insurance_type_embeddings.weight[0] + self.insurance_type_embeddings.weight[1]) / 2 - - j_embeddings = torch.where( - j_unspecified_mask.unsqueeze(-1).expand_as(j_embeddings), - avg_j_embedding.expand_as(j_embeddings), - j_embeddings, - ) - - # This replaces: if i_unspecified_mask.any(): i_embeddings[i_unspecified_mask] = avg_i_embedding - i_embeddings = torch.where( - i_unspecified_mask.unsqueeze(-1).expand_as(i_embeddings), - avg_i_embedding.expand_as(i_embeddings), - i_embeddings, - ) - - # For models without pooler_output, use the last hidden state's [CLS] token - last_hidden_state = base_outputs.hidden_states[-1] - pooled_output = last_hidden_state[:, 0] - - combined_features = torch.cat([pooled_output, j_embeddings, i_embeddings], dim=1) - logits = self.final_classifier(combined_features) + i_embeddings = torch.where( + i_unspecified_mask.unsqueeze(-1).expand_as(i_embeddings), + avg_i_embedding.expand_as(i_embeddings), + i_embeddings, + ) + + # Combine features and get logits + combined_features = torch.cat([pooled_output, j_embeddings, i_embeddings], dim=1) + logits = self.final_classifier(combined_features) # Update the logits in the output results = {"logits": logits} @@ -340,14 +334,6 @@ def main(config_path: str) -> None: # Use our custom data collator that handles the additional features data_collator = DataCollatorWithMetadata(tokenizer=tokenizer) - # Load the base model to determine its class - base_model = AutoModelForSequenceClassification.from_pretrained( - pretrained_model_key, - num_labels=3, - id2label=ID2LABEL, - label2id=LABEL2ID, - ) - # Now instantiate your custom model correctly model = TextClassificationWithMetadata.from_pretrained( pretrained_model_key, From b9e80b7b07400d0396b2e822ceca2e7d048e3664 Mon Sep 17 00:00:00 2001 From: MikeG112 Date: Sat, 26 Apr 2025 20:37:55 -0400 Subject: [PATCH 20/27] Update metdata model{ --- src/modeling/eval_hf_outcome_predictor.py | 88 ++++++++ src/modeling/train_outcome_predictor.py | 238 ++++++++++++++++++---- 2 files changed, 283 insertions(+), 43 deletions(-) diff --git a/src/modeling/eval_hf_outcome_predictor.py b/src/modeling/eval_hf_outcome_predictor.py index eb7a52a..bc75be9 100644 --- a/src/modeling/eval_hf_outcome_predictor.py +++ b/src/modeling/eval_hf_outcome_predictor.py @@ -5,6 +5,7 @@ import numpy as np import scipy import torch +from safetensors import safe_open from sklearn.metrics import ( accuracy_score, f1_score, @@ -40,6 +41,91 @@ def construct_label(outcome, sufficiency_id, label2id): return label2id[outcome] +def inspect_model_metadata(checkpoint_path): + """Load a saved model in safetensors format and examine its metadata components.""" + # Path to the safetensors file + safetensors_path = os.path.join(checkpoint_path, "model.safetensors") + if not os.path.exists(safetensors_path): + # Try alternate naming patterns + potential_paths = [ + os.path.join(checkpoint_path, "pytorch_model.safetensors"), + os.path.join(checkpoint_path, "model.safetensors"), + ] + for path in potential_paths: + if os.path.exists(path): + safetensors_path = path + break + + print(f"Loading model from {safetensors_path}") + + if not os.path.exists(safetensors_path): + print(f"No safetensors file found at {safetensors_path}") + return None + + # Load the model using safetensors + tensors = {} + with safe_open(safetensors_path, framework="pt", device="cpu") as f: + for key in f.keys(): + tensors[key] = f.get_tensor(key) + + # Check if metadata_scale exists and print its value + if "metadata_scale" in tensors: + metadata_scale = tensors["metadata_scale"].item() + print(f"Metadata scale: {metadata_scale}") + else: + print("metadata_scale not found in state dict") + # Try looking for it with model prefix + for key in tensors: + if key.endswith("metadata_scale"): + metadata_scale = tensors[key].item() + print(f"Found metadata scale as {key}: {metadata_scale}") + + if "attention_scale" in tensors: + attention_scale = tensors["attention_scale"].item() + print(f"Attention scale: {attention_scale}") + else: + # Try looking for it with model prefix + for key in tensors: + if key.endswith("attention_scale"): + attention_scale = tensors[key].item() + print(f"Found attention scale as {key}: {attention_scale}") + + # Find embedding keys + j_embedding_key = None + i_embedding_key = None + for key in tensors: + if "jurisdiction_embeddings.weight" in key: + j_embedding_key = key + if "insurance_type_embeddings.weight" in key: + i_embedding_key = key + + # Check embedding distances + if j_embedding_key and i_embedding_key: + j_embeddings = tensors[j_embedding_key].numpy() + i_embeddings = tensors[i_embedding_key].numpy() + + print(f"Jurisdiction embedding shape: {j_embeddings.shape}") + print(f"Insurance type embedding shape: {i_embeddings.shape}") + + # Calculate distances between embeddings + if j_embeddings.shape[0] >= 2: + j_distance = np.linalg.norm(j_embeddings[0] - j_embeddings[1]) + print(f"Distance between jurisdiction embeddings (0 vs 1): {j_distance}") + + if i_embeddings.shape[0] >= 2: + i_distance = np.linalg.norm(i_embeddings[0] - i_embeddings[1]) + print(f"Distance between insurance type embeddings (0 vs 1): {i_distance}") + else: + print("Embedding weights not found in tensors") + print("Available keys:", list(tensors.keys())) + + return { + "metadata_scale": metadata_scale if "metadata_scale" in tensors else None, + "j_embeddings": j_embeddings if j_embedding_key else None, + "i_embeddings": i_embeddings if i_embedding_key else None, + } + + def compute_metrics(predictions: np.ndarray, labels: np.ndarray) -> dict: roc_auc = roc_auc_score(labels, scipy.special.softmax(predictions, axis=-1), multi_class="ovr") @@ -185,6 +271,8 @@ def main(config_path: str): print("Evaluating from checkpoint: ", checkpoint_name) ckpt_path = os.path.join(checkpoints_dir, checkpoint_name) + _results = inspect_model_metadata(ckpt_path) + # Load tokenizer tokenizer = AutoTokenizer.from_pretrained(ckpt_path) diff --git a/src/modeling/train_outcome_predictor.py b/src/modeling/train_outcome_predictor.py index 37ab9e4..559c5e8 100644 --- a/src/modeling/train_outcome_predictor.py +++ b/src/modeling/train_outcome_predictor.py @@ -6,6 +6,7 @@ import numpy as np import scipy import torch +import torch.nn as nn from datasets import Dataset, DatasetDict, load_dataset from sklearn.metrics import ( accuracy_score, @@ -24,6 +25,7 @@ Trainer, TrainingArguments, ) +from transformers.modeling_outputs import SequenceClassifierOutput import wandb from src.modeling.util import export_onnx_model, load_config, quantize_onnx_model @@ -133,8 +135,7 @@ def compute_metrics(eval_pred) -> dict: def compute_metrics2(eval_pred) -> dict: - predictions = eval_pred.predictions[0] - labels = eval_pred.predictions[1] + predictions, labels = eval_pred softmax_preds = scipy.special.softmax(predictions, axis=-1) @@ -189,17 +190,116 @@ def compute_metrics2(eval_pred) -> dict: return best_metrics -class TextClassificationWithMetadata(BertForSequenceClassification): +class TextClassificationWithMetadata(DistilBertForSequenceClassification): def __init__(self, config): super().__init__(config) - # Only create embeddings for the actual classes (NY, CA) and (Commercial, Medicaid) - # The metadata is optional, and unspecified inputs get the average embeddings - self.jurisdiction_embeddings = torch.nn.Embedding(2, 16) - self.insurance_type_embeddings = torch.nn.Embedding(2, 16) + self.hidden_size = config.hidden_size + + # Metadata embeddings + self.jurisdiction_embeddings = nn.Embedding(2, self.hidden_size // 4) + self.insurance_type_embeddings = nn.Embedding(2, self.hidden_size // 4) + + # Fusion projection + self.metadata_projection = nn.Sequential( + nn.Linear(self.hidden_size // 2, self.hidden_size), + nn.LayerNorm(self.hidden_size), + nn.GELU(), + nn.Dropout(0.1), + ) + + # Manual cross-attention components + self.num_heads = 8 + self.head_dim = self.hidden_size // self.num_heads + + # Query, Key, Value projections + self.q_proj = nn.Linear(self.hidden_size, self.hidden_size) + self.k_proj = nn.Linear(self.hidden_size, self.hidden_size) + self.v_proj = nn.Linear(self.hidden_size, self.hidden_size) + self.out_proj = nn.Linear(self.hidden_size, self.hidden_size) + + self.dropout = nn.Dropout(0.1) + + self.metadata_norm = nn.LayerNorm(self.hidden_size) + self.sequence_norm = nn.LayerNorm(self.hidden_size) + + # Save the original num_labels + self.num_labels = ( + self.classifier.out_features if hasattr(self.classifier, "out_features") else config.num_labels + ) + + # Now replace the classifier + self.enhanced_classifier = nn.Sequential( + nn.Linear(self.hidden_size, self.hidden_size), + nn.LayerNorm(self.hidden_size), + nn.GELU(), + nn.Dropout(0.1), + nn.Linear(self.hidden_size, self.num_labels), + ) + + def manual_cross_attention(self, query, key, value, key_padding_mask=None): + # Same implementation, but with explicit cleanup + batch_size = query.size(1) + seq_length = key.size(0) + + # Project query, key, value + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + + # Reshape for multi-head attention + q = q.view(1, batch_size, self.num_heads, self.head_dim).permute(2, 1, 0, 3) + k = k.view(seq_length, batch_size, self.num_heads, self.head_dim).permute(2, 1, 0, 3) + v = v.view(seq_length, batch_size, self.num_heads, self.head_dim).permute(2, 1, 0, 3) + + # Calculate attention scores + scaling = float(self.head_dim) ** -0.5 + q = q * scaling + attn_scores = torch.matmul(q, k.transpose(-2, -1)) - hidden_size = self.config.hidden_size + # Apply attention mask if provided + if key_padding_mask is not None: + attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(1) + attn_mask = attn_mask.transpose(0, 1) + attn_mask = attn_mask.expand(self.num_heads, -1, 1, -1) + attn_scores = attn_scores.masked_fill(attn_mask, -10000.0) - self.final_classifier = torch.nn.Linear(hidden_size + 32, config.num_labels) + # Get max values for numerical stability + attn_scores_max, _ = torch.max(attn_scores, dim=-1, keepdim=True) + attn_scores = attn_scores - attn_scores_max + + # Apply softmax + attn_weights = torch.exp(torch.clamp(attn_scores, min=-10000.0, max=100.0)) + attn_sum = torch.sum(attn_weights, dim=-1, keepdim=True) + 1e-6 + attn_weights = attn_weights / attn_sum + + # Apply dropout + attn_weights = self.dropout(attn_weights) + + # Clean up intermediates + del attn_scores, attn_scores_max, attn_sum + if key_padding_mask is not None: + del attn_mask + + # Apply attention weights to values + context = torch.matmul(attn_weights, v) + + # Clean up more intermediates + del attn_weights + + # Reshape back + context = context.permute(2, 1, 0, 3) + context = context.reshape(1, batch_size, self.hidden_size) + + # Clean up + del q, k, v + + # Apply output projection + result = self.out_proj(context) + + # Final cleanup + del context + + return result def forward( self, @@ -207,66 +307,92 @@ def forward( attention_mask=None, jurisdiction_id=None, insurance_type_id=None, - return_dict=True, - return_loss=True, - **kwargs, + labels=None, + return_dict=None, + token_type_ids=None, ): - labels = kwargs.pop("labels", None) + # Validate that required parameters are provided, use index 2 for optional + if jurisdiction_id is None or insurance_type_id is None: + raise ValueError("jurisdiction_id and insurance_type_id must be provided") - base_outputs = super().forward( - input_ids=input_ids, - attention_mask=attention_mask, - output_hidden_states=True, - labels=None, # Don't pass labels yet - ) - - # Get the last hidden state's [CLS] token - last_hidden_state = base_outputs.hidden_states[-1] - pooled_output = last_hidden_state[:, 0] - - # Create masks for unspecified values (which will get assigned idx n for n specified classes) + # Process metadata j_unspecified_mask = jurisdiction_id == 2 i_unspecified_mask = insurance_type_id == 2 - # Convert invalid embedding indices to last embedding indices, temporarily + # Temporarily assign unspecified metadata classes -> 1 j_ids_valid = torch.clamp(jurisdiction_id, 0, 1) i_ids_valid = torch.clamp(insurance_type_id, 0, 1) - # Get actual embeddings (and embeddings for temp values) j_embeddings = self.jurisdiction_embeddings(j_ids_valid) i_embeddings = self.insurance_type_embeddings(i_ids_valid) - # Calculate average embeddings for each metadata type - avg_j_embedding = (self.jurisdiction_embeddings.weight[0] + self.jurisdiction_embeddings.weight[1]) / 2 - avg_i_embedding = (self.insurance_type_embeddings.weight[0] + self.insurance_type_embeddings.weight[1]) / 2 + avg_j_embedding = self.jurisdiction_embeddings.weight.mean(dim=0) + avg_i_embedding = self.insurance_type_embeddings.weight.mean(dim=0) - # Replace unspecified temp embedding values with averages + # Assign unspecified classes the average embeddings j_embeddings = torch.where( j_unspecified_mask.unsqueeze(-1).expand_as(j_embeddings), avg_j_embedding.expand_as(j_embeddings), j_embeddings, ) - i_embeddings = torch.where( i_unspecified_mask.unsqueeze(-1).expand_as(i_embeddings), avg_i_embedding.expand_as(i_embeddings), i_embeddings, ) - # Combine features and get logits - combined_features = torch.cat([pooled_output, j_embeddings, i_embeddings], dim=1) - logits = self.final_classifier(combined_features) + # Combine metadata and project + combined_metadata = torch.cat([j_embeddings, i_embeddings], dim=1) + metadata_features = self.metadata_projection(combined_metadata) + + # Call the parent model but bypass its classification head + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + base_outputs = super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, # We need the hidden states + return_dict=True, + token_type_ids=token_type_ids, + labels=None, # Don't pass labels yet + ) + + sequence_output = base_outputs.hidden_states[-1] - # Update the logits in the output - results = {"logits": logits} + # Use cross-attention for fusion + metadata_features = self.metadata_norm(metadata_features) + sequence_output = self.sequence_norm(sequence_output) + # Reshape for attention + metadata_features = metadata_features.unsqueeze(0) # [1, batch, hidden_dim] + sequence_output_t = sequence_output.transpose(0, 1) # [seq_len, batch, hidden_dim] + + key_padding_mask = attention_mask == 0 + + fused_features = self.manual_cross_attention( + query=metadata_features, key=sequence_output_t, value=sequence_output_t, key_padding_mask=key_padding_mask + ) + + # Combine with CLS token + fused_features = fused_features.squeeze(0) + sequence_output[:, 0] + + logits = self.enhanced_classifier(fused_features) + + # Calculate loss if labels provided + loss = None if labels is not None: - loss_fct = torch.nn.CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) - results["loss"] = loss - results["labels"] = labels + loss = torch.nn.functional.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1)) - return results + if return_dict: + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=base_outputs.hidden_states, + attentions=base_outputs.attentions, + ) + else: + output = (logits,) + base_outputs[2:] + return ((loss,) + output) if loss is not None else output class DataCollatorWithMetadata(DataCollatorWithPadding): @@ -283,6 +409,31 @@ def __call__(self, features): return batch +class CPUEvalTrainer(Trainer): + def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None): + has_labels = all(inputs.get(k) is not None for k in self.label_names) + inputs = self._prepare_inputs(inputs) + + with torch.no_grad(): + outputs = model(**inputs) + logits = outputs.logits + + # Move logits to CPU immediately to save GPU memory + logits = logits.detach().cpu() + + labels = None + if has_labels: + labels = tuple(inputs.get(name).detach().cpu() for name in self.label_names) + if len(labels) == 1: + labels = labels[0] + + loss = None + if has_labels and outputs.loss is not None: + loss = outputs.loss.detach().cpu() + + return (loss, logits, labels) + + def main(config_path: str) -> None: cfg = load_config(config_path) @@ -355,6 +506,7 @@ def main(config_path: str) -> None: learning_rate=cfg["learning_rate"], per_device_train_batch_size=cfg["batch_size"], per_device_eval_batch_size=cfg["batch_size"], + eval_accumulation_steps=16, num_train_epochs=cfg["num_epochs"], weight_decay=cfg["weight_decay"], fp16=(cfg["dtype"] == "float16"), @@ -368,7 +520,7 @@ def main(config_path: str) -> None: dataloader_pin_memory=True, ) - trainer = Trainer( + trainer = CPUEvalTrainer( model=model, args=training_args, train_dataset=dataset["train"], From 1adbd2d1b803bea611a4629b0e2276dd3cc4bdc3 Mon Sep 17 00:00:00 2001 From: MikeG112 Date: Sat, 26 Apr 2025 20:38:12 -0400 Subject: [PATCH 21/27] Update legalbert small config --- src/modeling/config/outcome_prediction/legalbert_small.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/modeling/config/outcome_prediction/legalbert_small.yaml b/src/modeling/config/outcome_prediction/legalbert_small.yaml index 1130517..393a628 100644 --- a/src/modeling/config/outcome_prediction/legalbert_small.yaml +++ b/src/modeling/config/outcome_prediction/legalbert_small.yaml @@ -21,4 +21,4 @@ compile: True # Whether to use torch compile # Test eval settings test_data_path: "./data/outcomes/test_backgrounds_suff.jsonl" checkpoint_name: "checkpoint-20736" -eval_threshold: .5 \ No newline at end of file +eval_threshold: .55 \ No newline at end of file From a7505dca716c7b9ebb48f817b02f4afe14a3fb18 Mon Sep 17 00:00:00 2001 From: MikeG112 Date: Sat, 26 Apr 2025 20:38:57 -0400 Subject: [PATCH 22/27] Update per channel / reduce --- src/modeling/util.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/modeling/util.py b/src/modeling/util.py index f391681..bf60d65 100644 --- a/src/modeling/util.py +++ b/src/modeling/util.py @@ -3,6 +3,7 @@ import torch import yaml from onnxruntime.quantization import QuantType, quantize_dynamic +from onnxruntime.quantization.shape_inference import quant_pre_process from transformers import AutoModel, AutoTokenizer @@ -79,10 +80,17 @@ def export_onnx_model(output_model_path: str, model: torch.nn.Module | AutoModel def quantize_onnx_model(onnx_model_path: str, quantized_model_path: str): + # print("Preprocessing ONNX model before quantization...") + # pre_processed_model_path = onnx_model_path + ".pre_processed.onnx" + + # quant_pre_process(onnx_model_path, pre_processed_model_path) + # Quantize the model quantize_dynamic( onnx_model_path, quantized_model_path, + per_channel=False, + reduce_range=False, weight_type=QuantType.QInt8, ) print(f"Quantized model saved to: {quantized_model_path}") From ede7f19e29e5c59c624bf65bf85ec4f845bb2e87 Mon Sep 17 00:00:00 2001 From: MikeG112 Date: Sun, 27 Apr 2025 17:54:40 -0400 Subject: [PATCH 23/27] Support bert/distilbert --- src/modeling/train_outcome_predictor.py | 241 +++++++++++++++++++++++- 1 file changed, 238 insertions(+), 3 deletions(-) diff --git a/src/modeling/train_outcome_predictor.py b/src/modeling/train_outcome_predictor.py index 559c5e8..38ca399 100644 --- a/src/modeling/train_outcome_predictor.py +++ b/src/modeling/train_outcome_predictor.py @@ -395,6 +395,213 @@ def forward( return ((loss,) + output) if loss is not None else output +# Nearly identical but inherits from DistilBertForSequenceClassification +# and doesn't use token_type_ids +class DistilBertTextClassificationWithMetadata(DistilBertForSequenceClassification): + def __init__(self, config): + super().__init__(config) + self.hidden_size = config.hidden_size + + # Metadata embeddings + self.jurisdiction_embeddings = nn.Embedding(2, self.hidden_size // 4) + self.insurance_type_embeddings = nn.Embedding(2, self.hidden_size // 4) + + # Fusion projection + self.metadata_projection = nn.Sequential( + nn.Linear(self.hidden_size // 2, self.hidden_size), + nn.LayerNorm(self.hidden_size), + nn.GELU(), + nn.Dropout(0.1), + ) + + # Manual cross-attention components + self.num_heads = 8 + self.head_dim = self.hidden_size // self.num_heads + + # Query, Key, Value projections + self.q_proj = nn.Linear(self.hidden_size, self.hidden_size) + self.k_proj = nn.Linear(self.hidden_size, self.hidden_size) + self.v_proj = nn.Linear(self.hidden_size, self.hidden_size) + self.out_proj = nn.Linear(self.hidden_size, self.hidden_size) + + self.dropout = nn.Dropout(0.1) + + self.metadata_norm = nn.LayerNorm(self.hidden_size) + self.sequence_norm = nn.LayerNorm(self.hidden_size) + + # Save the original num_labels + self.num_labels = ( + self.classifier.out_features if hasattr(self.classifier, "out_features") else config.num_labels + ) + + # Now replace the classifier + self.enhanced_classifier = nn.Sequential( + nn.Linear(self.hidden_size, self.hidden_size), + nn.LayerNorm(self.hidden_size), + nn.GELU(), + nn.Dropout(0.1), + nn.Linear(self.hidden_size, self.num_labels), + ) + + def manual_cross_attention(self, query, key, value, key_padding_mask=None): + # Same implementation, but with explicit cleanup + batch_size = query.size(1) + seq_length = key.size(0) + + # Project query, key, value + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + + # Reshape for multi-head attention + q = q.view(1, batch_size, self.num_heads, self.head_dim).permute(2, 1, 0, 3) + k = k.view(seq_length, batch_size, self.num_heads, self.head_dim).permute(2, 1, 0, 3) + v = v.view(seq_length, batch_size, self.num_heads, self.head_dim).permute(2, 1, 0, 3) + + # Calculate attention scores + scaling = float(self.head_dim) ** -0.5 + q = q * scaling + attn_scores = torch.matmul(q, k.transpose(-2, -1)) + + # Apply attention mask if provided + if key_padding_mask is not None: + attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(1) + attn_mask = attn_mask.transpose(0, 1) + attn_mask = attn_mask.expand(self.num_heads, -1, 1, -1) + attn_scores = attn_scores.masked_fill(attn_mask, -10000.0) + + # Get max values for numerical stability + attn_scores_max, _ = torch.max(attn_scores, dim=-1, keepdim=True) + attn_scores = attn_scores - attn_scores_max + + # Apply softmax + attn_weights = torch.exp(torch.clamp(attn_scores, min=-10000.0, max=100.0)) + attn_sum = torch.sum(attn_weights, dim=-1, keepdim=True) + 1e-6 + attn_weights = attn_weights / attn_sum + + # Apply dropout + attn_weights = self.dropout(attn_weights) + + # Clean up intermediates + del attn_scores, attn_scores_max, attn_sum + if key_padding_mask is not None: + del attn_mask + + # Apply attention weights to values + context = torch.matmul(attn_weights, v) + + # Clean up more intermediates + del attn_weights + + # Reshape back + context = context.permute(2, 1, 0, 3) + context = context.reshape(1, batch_size, self.hidden_size) + + # Clean up + del q, k, v + + # Apply output projection + result = self.out_proj(context) + + # Final cleanup + del context + + return result + + def forward( + self, + input_ids=None, + attention_mask=None, + jurisdiction_id=None, + insurance_type_id=None, + labels=None, + return_dict=None, + token_type_ids=None, # This parameter is left for API compatibility but not used + ): + # Validate that required parameters are provided, use index 2 for optional + if jurisdiction_id is None or insurance_type_id is None: + raise ValueError("jurisdiction_id and insurance_type_id must be provided") + + # Process metadata + j_unspecified_mask = jurisdiction_id == 2 + i_unspecified_mask = insurance_type_id == 2 + + # Temporarily assign unspecified metadata classes -> 1 + j_ids_valid = torch.clamp(jurisdiction_id, 0, 1) + i_ids_valid = torch.clamp(insurance_type_id, 0, 1) + + j_embeddings = self.jurisdiction_embeddings(j_ids_valid) + i_embeddings = self.insurance_type_embeddings(i_ids_valid) + + avg_j_embedding = self.jurisdiction_embeddings.weight.mean(dim=0) + avg_i_embedding = self.insurance_type_embeddings.weight.mean(dim=0) + + # Assign unspecified classes the average embeddings + j_embeddings = torch.where( + j_unspecified_mask.unsqueeze(-1).expand_as(j_embeddings), + avg_j_embedding.expand_as(j_embeddings), + j_embeddings, + ) + i_embeddings = torch.where( + i_unspecified_mask.unsqueeze(-1).expand_as(i_embeddings), + avg_i_embedding.expand_as(i_embeddings), + i_embeddings, + ) + + # Combine metadata and project + combined_metadata = torch.cat([j_embeddings, i_embeddings], dim=1) + metadata_features = self.metadata_projection(combined_metadata) + + # Call the parent model but bypass its classification head + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + base_outputs = super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + # No token_type_ids for DistilBERT + output_hidden_states=True, # We need the hidden states + return_dict=True, + labels=None, # Don't pass labels yet + ) + + sequence_output = base_outputs.hidden_states[-1] + + # Use cross-attention for fusion + metadata_features = self.metadata_norm(metadata_features) + sequence_output = self.sequence_norm(sequence_output) + + # Reshape for attention + metadata_features = metadata_features.unsqueeze(0) # [1, batch, hidden_dim] + sequence_output_t = sequence_output.transpose(0, 1) # [seq_len, batch, hidden_dim] + + key_padding_mask = attention_mask == 0 + + fused_features = self.manual_cross_attention( + query=metadata_features, key=sequence_output_t, value=sequence_output_t, key_padding_mask=key_padding_mask + ) + + # Combine with CLS token + fused_features = fused_features.squeeze(0) + sequence_output[:, 0] + + logits = self.enhanced_classifier(fused_features) + + # Calculate loss if labels provided + loss = None + if labels is not None: + loss = torch.nn.functional.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1)) + + if return_dict: + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=base_outputs.hidden_states, + attentions=base_outputs.attentions, + ) + else: + output = (logits,) + base_outputs[2:] + return ((loss,) + output) if loss is not None else output + + class DataCollatorWithMetadata(DataCollatorWithPadding): def __call__(self, features): batch = super().__call__(features) @@ -434,6 +641,34 @@ def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None) return (loss, logits, labels) +def create_metadata_model(pretrained_model_key, base_model_name, num_labels, id2label, label2id): + # Map of base model names to model classes + MODEL_CLASS_MAP = { + # BERT-based models + "legal-bert-small-uncased": TextClassificationWithMetadata, + # DistilBERT-based models + "distilbert": DistilBertTextClassificationWithMetadata, + "clinicalbert": DistilBertTextClassificationWithMetadata, + } + + # Get model class based on base_model_name + # This will use TextClassificationWithMetadata by default for unrecognized models + model_class = MODEL_CLASS_MAP.get( + base_model_name, + DistilBertTextClassificationWithMetadata + if "distilbert" in base_model_name.lower() + else TextClassificationWithMetadata, + ) + + # Create and return the appropriate model + return model_class.from_pretrained( + pretrained_model_key, + num_labels=num_labels, + id2label=id2label, + label2id=label2id, + ) + + def main(config_path: str) -> None: cfg = load_config(config_path) @@ -485,9 +720,9 @@ def main(config_path: str) -> None: # Use our custom data collator that handles the additional features data_collator = DataCollatorWithMetadata(tokenizer=tokenizer) - # Now instantiate your custom model correctly - model = TextClassificationWithMetadata.from_pretrained( - pretrained_model_key, + model = create_metadata_model( + pretrained_model_key=pretrained_model_key, + base_model_name=base_model_name, num_labels=3, id2label=ID2LABEL, label2id=LABEL2ID, From 6d9e4dd4c8a86bfd4e6015c218d61ddef218076e Mon Sep 17 00:00:00 2001 From: MikeG112 Date: Sun, 27 Apr 2025 17:54:59 -0400 Subject: [PATCH 24/27] Update eval configs --- src/modeling/config/outcome_prediction/clinicalbert.yaml | 6 +++--- src/modeling/config/outcome_prediction/distilbert.yaml | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/modeling/config/outcome_prediction/clinicalbert.yaml b/src/modeling/config/outcome_prediction/clinicalbert.yaml index bd95a1f..3b538c7 100644 --- a/src/modeling/config/outcome_prediction/clinicalbert.yaml +++ b/src/modeling/config/outcome_prediction/clinicalbert.yaml @@ -14,11 +14,11 @@ train_data_path: "./data/outcomes/train_backgrounds_suff_augmented.jsonl" learning_rate: 8.0e-7 weight_decay: 0.01 num_epochs: 40 -batch_size: 20 +batch_size: 32 dtype: "float16" # 'float32','float16' for training dtype compile: True # Whether to use torch compile # Test eval settings test_data_path: "./data/outcomes/test_backgrounds_suff.jsonl" -checkpoint_name: checkpoint-49770 -eval_threshold: .55 \ No newline at end of file +checkpoint_name: checkpoint-36288 +eval_threshold: .65 \ No newline at end of file diff --git a/src/modeling/config/outcome_prediction/distilbert.yaml b/src/modeling/config/outcome_prediction/distilbert.yaml index 4fce9db..cbc46b8 100644 --- a/src/modeling/config/outcome_prediction/distilbert.yaml +++ b/src/modeling/config/outcome_prediction/distilbert.yaml @@ -21,4 +21,4 @@ compile: True # Whether to use torch compile # Test eval settings test_data_path: "./data/outcomes/test_backgrounds_suff.jsonl" checkpoint_name: "checkpoint-23040" -eval_threshold: .5 \ No newline at end of file +eval_threshold: .7 \ No newline at end of file From a80413fc2e2b6f872fedd5b86977af1ab41d4095 Mon Sep 17 00:00:00 2001 From: MikeG112 Date: Sun, 27 Apr 2025 17:55:14 -0400 Subject: [PATCH 25/27] Use factory function --- src/modeling/eval_hf_outcome_predictor.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/modeling/eval_hf_outcome_predictor.py b/src/modeling/eval_hf_outcome_predictor.py index bc75be9..f44239f 100644 --- a/src/modeling/eval_hf_outcome_predictor.py +++ b/src/modeling/eval_hf_outcome_predictor.py @@ -14,12 +14,11 @@ roc_auc_score, ) from transformers import ( - AutoModelForSequenceClassification, AutoTokenizer, TextClassificationPipeline, ) -from src.modeling.train_outcome_predictor import TextClassificationWithMetadata +from src.modeling.train_outcome_predictor import create_metadata_model from src.modeling.util import load_config from src.util import get_records_list @@ -276,12 +275,9 @@ def main(config_path: str): # Load tokenizer tokenizer = AutoTokenizer.from_pretrained(ckpt_path) - # Load model - # model = AutoModelForSequenceClassification.from_pretrained( - # ckpt_path, num_labels=3, id2label=ID2LABEL, label2id=LABEL2ID - # ) - model = TextClassificationWithMetadata.from_pretrained( - ckpt_path, + model = create_metadata_model( + pretrained_model_key=ckpt_path, + base_model_name=base_model_name, num_labels=3, id2label=ID2LABEL, label2id=LABEL2ID, From 6b9af8af80ed4b584473dc8681f9efaa7d97c6c4 Mon Sep 17 00:00:00 2001 From: MikeG112 Date: Sun, 27 Apr 2025 18:10:16 -0400 Subject: [PATCH 26/27] Fix imports --- src/modeling/pretrain.py | 2 +- src/modeling/train_background_token_classification.py | 2 +- src/modeling/train_outcome_predictor.py | 2 +- src/modeling/train_sufficiency_classifier.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/modeling/pretrain.py b/src/modeling/pretrain.py index 0101919..d32a37b 100644 --- a/src/modeling/pretrain.py +++ b/src/modeling/pretrain.py @@ -9,6 +9,7 @@ import datasets import numpy as np +import wandb from datasets import Dataset from transformers import ( AutoModelForMaskedLM, @@ -18,7 +19,6 @@ default_data_collator, ) -import wandb from src.modeling.util import load_config os.environ["TOKENIZERS_PARALLELISM"] = "false" diff --git a/src/modeling/train_background_token_classification.py b/src/modeling/train_background_token_classification.py index 0f8bdd8..31b0302 100644 --- a/src/modeling/train_background_token_classification.py +++ b/src/modeling/train_background_token_classification.py @@ -6,6 +6,7 @@ import evaluate import numpy as np +import wandb from datasets import load_dataset from transformers import ( AutoModelForTokenClassification, @@ -16,7 +17,6 @@ TrainingArguments, ) -import wandb from src.modeling.util import load_config NO_CLASS_ENCODING = 0 diff --git a/src/modeling/train_outcome_predictor.py b/src/modeling/train_outcome_predictor.py index 38ca399..14a4b51 100644 --- a/src/modeling/train_outcome_predictor.py +++ b/src/modeling/train_outcome_predictor.py @@ -7,6 +7,7 @@ import scipy import torch import torch.nn as nn +import wandb from datasets import Dataset, DatasetDict, load_dataset from sklearn.metrics import ( accuracy_score, @@ -27,7 +28,6 @@ ) from transformers.modeling_outputs import SequenceClassifierOutput -import wandb from src.modeling.util import export_onnx_model, load_config, quantize_onnx_model from src.util import get_records_list diff --git a/src/modeling/train_sufficiency_classifier.py b/src/modeling/train_sufficiency_classifier.py index f34c079..c4bf976 100644 --- a/src/modeling/train_sufficiency_classifier.py +++ b/src/modeling/train_sufficiency_classifier.py @@ -7,6 +7,7 @@ import numpy as np import scipy import torch +import wandb from datasets import Dataset from sklearn.metrics import ( accuracy_score, @@ -23,7 +24,6 @@ TrainingArguments, ) -import wandb from src.modeling.data_augmentation import ( load_augmented_dataset, process_and_save_augmentations, From 2321660b3b7b9d1785668ede5177c724d152f030 Mon Sep 17 00:00:00 2001 From: MikeG112 Date: Sun, 27 Apr 2025 18:10:28 -0400 Subject: [PATCH 27/27] ignore wandb logs --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index f2ff7f1..123b1cc 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ data/paper data models wandb +wandb_logs __pycache__/ *.pyc