Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ data/paper
data
models
wandb
wandb_logs

__pycache__/
*.pyc
Expand Down
2 changes: 1 addition & 1 deletion processed_sources.jsonl
Original file line number Diff line number Diff line change
Expand Up @@ -1106,7 +1106,7 @@
{"url": "https://oig.hhs.gov/oei/reports/oei-02-16-00570.pdf", "date_accessed": "2024-01-25", "local_path": "./data/raw/hhs_oig/oei-02-16-00570.pdf", "tags": ["hhs-oig", "regulatory-guidance"], "preprocessor": "pdf", "md5": "ea67cf9aa08f8eeb92f43178a0259de8", "local_processed_path": "./data/processed/hhs_oig/oei-02-16-00570.jsonl"}
{"url": "https://oig.hhs.gov/oas/reports/region9/91602042.pdf", "date_accessed": "2024-01-25", "local_path": "./data/raw/hhs_oig/91602042.pdf", "tags": ["hhs-oig", "regulatory-guidance"], "preprocessor": "pdf", "md5": "0956b057e0ada92e40dcc7bf02bb5cf2", "local_processed_path": "./data/processed/hhs_oig/91602042.jsonl"}
{"url": "https://oig.hhs.gov/oei/reports/oei-03-15-00180.pdf", "date_accessed": "2024-01-25", "local_path": "./data/raw/hhs_oig/oei-03-15-00180.pdf", "tags": ["hhs-oig", "regulatory-guidance"], "preprocessor": "pdf", "md5": "fa56c5c8d598b139873e22a08131b69b", "local_processed_path": "./data/processed/hhs_oig/oei-03-15-00180.jsonl"}
{"url": "https://oig.hhs.gov/reports-and-publications/portfolio/portfolio-12-12-01.pdf", "date_accessed": "2024-01-25", "local_path": "./data/raw/hhs_oig/portfolio-12-12-01.pdf", "tags": ["hhs-oig", "regulatory-guidance"], "preprocessor": "pdf", "md5": "a0b1f3c770899d1ed57ced5a477e9e94", "local_processed_path": "./data/processed/hhs_oig/portfolio-12-12-01.jsonl"}
{"url": "https://oig.hhs.gov/reports-and-publications/portfolio/portfolio-12-12-01.pdf", "date_accessed": "2024-01-25", "local_path": "./data/raw/hhs_oig/portfolio-12-12-01.pdf", "tags": ["hhs-oig"], "preprocessor": "pdf", "md5": "a0b1f3c770899d1ed57ced5a477e9e94", "local_processed_path": "./data/processed/hhs_oig/portfolio-12-12-01.jsonl"}
{"url": "https://oig.hhs.gov/documents/testimony/1119/bliss-testimony-05182023.pdf", "date_accessed": "2024-01-25", "local_path": "./data/raw/hhs_oig/bliss-testimony-05182023.pdf", "tags": ["hhs-oig", "testimony", "opinion-policy-summary"], "preprocessor": "pdf", "md5": "417d6519b93105b6a371602110bfe33c", "local_processed_path": "./data/processed/hhs_oig/bliss-testimony-05182023.jsonl"}
{"url": "https://oig.hhs.gov/documents/testimony/1118/megan-tinker-testimony-05172023.pdf", "date_accessed": "2024-01-25", "local_path": "./data/raw/hhs_oig/megan-tinker-testimony-05172023.pdf", "tags": ["hhs-oig", "testimony", "opinion-policy-summary"], "preprocessor": "pdf", "md5": "e562fcf65a825e1e5d72ed55a930b993", "local_processed_path": "./data/processed/hhs_oig/megan-tinker-testimony-05172023.jsonl"}
{"url": "https://oig.hhs.gov/documents/testimony/1112/christi-grimm-testimony-04182023.pdf", "date_accessed": "2024-01-25", "local_path": "./data/raw/hhs_oig/christi-grimm-testimony-04182023.pdf", "tags": ["hhs-oig", "testimony", "opinion-policy-summary"], "preprocessor": "pdf", "md5": "3d2f8e1e1fea7a9113ade3287acb7884", "local_processed_path": "./data/processed/hhs_oig/christi-grimm-testimony-04182023.jsonl"}
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pre-commit==3.7.1
RapidFuzz==3.13.0
requests==2.32.3
torch==2.6.0
transformers==4.48.3
transformers==4.51.3
evaluate
scikit-learn==1.5.0
selenium==4.17.2
Expand Down
299 changes: 298 additions & 1 deletion src/modeling/background_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from multiprocessing import Pool, cpu_count

import torch
from datasets import Dataset
from rapidfuzz import fuzz
from sklearn.model_selection import train_test_split
from transformers import (
Expand All @@ -11,6 +12,11 @@
AutoTokenizer,
)

from src.modeling.data_augmentation import (
augment_sufficient_examples,
generate_unrelated_content,
rewrite_to_generic,
)
from src.util import add_jsonl_lines, batcher, get_records_list


Expand Down Expand Up @@ -210,6 +216,8 @@ def construct_recs(raw_recs: list[dict], tokenizer: AutoTokenizer, model: torch.
"decision": rec["decision"],
"appeal_type": rec["appeal_type"],
"full_text": rec["text"],
"jurisdiction": rec.get("jurisdiction", "Unspecified"),
"insurance_type": rec.get("insurance_type", "Unspecified"),
}
backgrounds.append(updated_record)
idx += 1
Expand Down Expand Up @@ -241,6 +249,8 @@ def construct_recs_batch(
"appeal_type": rec["appeal_type"],
"full_text": rec["text"],
"sufficiency_id": sufficiency_id,
"jurisdiction": rec.get("jurisdiction", "Unspecified"),
"insurance_type": rec.get("insurance_type", "Unspecified"),
}
for (rec, background, sufficiency_id) in zip(batch, background_batch, sufficiency_batch)
]
Expand Down Expand Up @@ -319,6 +329,217 @@ def combine_jsonl(directory) -> None:
return None


def create_augmented_examples(
    records: list[dict],
    api_type: str = "openai",
    api_url: str = "https://api.openai.com/v1/chat/completions",
    api_key: str | None = None,
    model_name: str = "gpt-4o",
    generic_rewrites_per_example: int = 1,
    sufficient_augmentations_per_example: int = 1,
    num_unrelated_examples: int = 100,
    api_call_limit: int | None = 700,
    seed: int = 42,
) -> list[dict]:
    """
    Create augmented examples from background texts with optional limit on API calls.

    Three augmentation techniques are applied:
      1. Generic rewrites of sufficient examples (turns them insufficient).
      2. Augmentations of sufficient examples that stay sufficient.
      3. Freshly generated unrelated content (always labeled insufficient).

    Args:
        records: List of record dictionaries with text, decision, etc.
        api_type: Type of API to use ("openai" or "llamacpp")
        api_url: URL for the API endpoint
        api_key: API key forwarded to the augmentation helpers (None for local APIs)
        model_name: Name of the model to use
        generic_rewrites_per_example: Number of generic rewrites per sufficient example
        sufficient_augmentations_per_example: Number of sufficient augmentations per sufficient example
        num_unrelated_examples: Number of unrelated content examples to generate
        api_call_limit: Maximum number of API calls to make (if None, no limit)
        seed: Random seed for reproducibility

    Returns:
        List of augmented record dictionaries
    """
    import random

    random.seed(seed)

    print(f"Creating augmented examples using {api_type} API at {api_url}...")
    augmented_records = []

    # Track API calls. NOTE(review): this is inferred from the number of examples
    # each helper returns, which assumes one API call per generated example —
    # confirm against the helpers in src.modeling.data_augmentation.
    api_calls_made = 0

    # 1. Create dataset for augmentation functions.
    # Records carrying a sufficiency_id are mapped onto the 1-5 sufficiency scale
    # (1 -> 4 "sufficient", 0 -> 2 "insufficient"); records without one are
    # assumed sufficient (score 4).
    dataset_records = []
    for rec in records:
        sufficiency_score = 4  # Default to sufficient for original records
        if "sufficiency_id" in rec:
            sufficiency_score = 4 if rec["sufficiency_id"] == 1 else 2

        dataset_records.append({"text": rec["text"], "sufficiency_score": sufficiency_score})

    dataset = Dataset.from_list(dataset_records)

    # Only examples scored >= 3 (i.e. sufficient) are eligible for rewriting/augmentation.
    # Keep this under a distinct name: the *outputs* of the augmentation step below
    # get their own names, so sources are never shadowed.
    sufficient_source_examples = [ex for ex in dataset_records if ex["sufficiency_score"] >= 3]
    num_sufficient = len(sufficient_source_examples)

    # Calculate how many API calls each technique would require
    total_generic_calls = num_sufficient * generic_rewrites_per_example
    total_sufficient_calls = num_sufficient * sufficient_augmentations_per_example
    total_unrelated_calls = num_unrelated_examples

    # Calculate total potential API calls
    total_potential_calls = total_generic_calls + total_sufficient_calls + total_unrelated_calls

    # If we have a limit and it's less than the potential total, adjust
    if api_call_limit is not None and api_call_limit < total_potential_calls:
        print(f"API call limit ({api_call_limit}) is less than potential total ({total_potential_calls})")

        # Distribute the limit evenly across the three techniques; any remainder
        # from the integer division goes to the unrelated-content budget.
        calls_per_technique = api_call_limit // 3

        adjusted_generic_calls = calls_per_technique
        adjusted_sufficient_calls = calls_per_technique
        adjusted_unrelated_calls = (
            api_call_limit - adjusted_generic_calls - adjusted_sufficient_calls
        )  # Use remainder for unrelated

        # Calculate how many examples we can augment for each technique
        examples_for_generic = adjusted_generic_calls // generic_rewrites_per_example
        examples_for_sufficient = adjusted_sufficient_calls // sufficient_augmentations_per_example

        # Randomly select which sufficient examples get augmented under the budget
        if examples_for_generic < num_sufficient:
            examples_generic = random.sample(sufficient_source_examples, examples_for_generic)
        else:
            examples_generic = sufficient_source_examples

        if examples_for_sufficient < num_sufficient:
            examples_sufficient = random.sample(sufficient_source_examples, examples_for_sufficient)
        else:
            examples_sufficient = sufficient_source_examples

        # Adjust unrelated examples count
        adjusted_num_unrelated = adjusted_unrelated_calls

        print("Adjusted numbers based on API call limit:")
        print(f"  Generic rewrites: {examples_for_generic} examples ({adjusted_generic_calls} calls)")
        print(f"  Sufficient augmentations: {examples_for_sufficient} examples ({adjusted_sufficient_calls} calls)")
        print(f"  Unrelated content: {adjusted_num_unrelated} examples ({adjusted_unrelated_calls} calls)")

        # Create filtered datasets for limited augmentation
        generic_dataset = Dataset.from_list(examples_generic)
        sufficient_dataset = Dataset.from_list(examples_sufficient)

        # Update parameters
        generic_params_count = adjusted_generic_calls
        sufficient_params_count = adjusted_sufficient_calls
        unrelated_params_count = adjusted_unrelated_calls
    else:
        # No limit or limit is high enough, use all examples.
        # NOTE(review): the unlimited branch passes the *full* dataset (including
        # insufficient examples) to both augmenters, while the limited branch
        # passes only sufficient examples — presumably the augmenters filter
        # internally; verify against src.modeling.data_augmentation.
        generic_dataset = dataset
        sufficient_dataset = dataset
        generic_params_count = total_generic_calls
        sufficient_params_count = total_sufficient_calls
        unrelated_params_count = num_unrelated_examples

    # 2. Apply generic rewrite augmentation (make sufficient examples insufficient)
    if generic_params_count > 0:
        print("Generating generic rewrites...")
        generic_rewrite_params = {
            "num_augmentations_per_example": generic_rewrites_per_example,
            "api_type": api_type,
            "api_url": api_url,
            "api_key": api_key,
            "model_name": model_name,
            "seed": seed,
        }

        generic_examples = rewrite_to_generic(generic_dataset, **generic_rewrite_params)
        api_calls_made += len(generic_examples)
        print(f"Generated {len(generic_examples)} examples through generic rewriting")
    else:
        generic_examples = []
        print("Skipping generic rewrites due to API call limit")

    # 3. Apply sufficient example augmentation (keep sufficient examples sufficient).
    # Result goes into sufficient_aug_examples — deliberately NOT reusing the name
    # of the source list above.
    if sufficient_params_count > 0:
        print("Generating sufficient augmentations...")
        sufficient_augmentation_params = {
            "num_augmentations_per_example": sufficient_augmentations_per_example,
            "api_type": api_type,
            "api_url": api_url,
            "api_key": api_key,
            "model_name": model_name,
            "seed": seed,
        }

        sufficient_aug_examples = augment_sufficient_examples(sufficient_dataset, **sufficient_augmentation_params)
        api_calls_made += len(sufficient_aug_examples)
        print(f"Generated {len(sufficient_aug_examples)} augmented sufficient examples")
    else:
        sufficient_aug_examples = []
        print("Skipping sufficient augmentations due to API call limit")

    # 4. Generate unrelated content
    if unrelated_params_count > 0:
        print("Generating unrelated content...")
        unrelated_params = {
            "num_examples": unrelated_params_count,
            "api_type": api_type,
            "api_url": api_url,
            "api_key": api_key,
            "model_name": model_name,
            "seed": seed,
        }

        unrelated_examples = generate_unrelated_content(**unrelated_params)
        api_calls_made += len(unrelated_examples)
        print(f"Generated {len(unrelated_examples)} examples with unrelated content")
    else:
        unrelated_examples = []
        print("Skipping unrelated content due to API call limit")

    # 5. Convert augmented texts back to record format.
    # For generic rewrites and sufficient augmentations, map each augmented text
    # back to its original record via the source text. NOTE(review): if two
    # source records share identical text, only the last one is kept in this map.
    all_records_dict = {rec["text"]: rec for rec in records}

    for aug_example in generic_examples + sufficient_aug_examples:
        source_text = aug_example["source_text"]
        original_rec = all_records_dict.get(source_text)
        if original_rec:
            augmented_rec = original_rec.copy()
            augmented_rec["text"] = aug_example["text"]
            # Map the 1-5 sufficiency score back to the binary sufficiency_id
            augmented_rec["sufficiency_id"] = 1 if aug_example["sufficiency_score"] >= 3 else 0
            augmented_rec["augmentation_type"] = aug_example["augmentation_type"]
            augmented_records.append(augmented_rec)

    # For unrelated content, create new records with random decision label
    decisions = ["Upheld", "Overturned"]
    appeal_types = ["IMR", "DMHC", "CDI"]  # Example appeal types

    for i, aug_example in enumerate(unrelated_examples):
        augmented_rec = {
            "text": aug_example["text"],
            "decision": random.choice(decisions),  # Random decision since unrelated to actual cases
            "appeal_type": random.choice(appeal_types),
            "full_text": aug_example["text"],  # No original full text
            "sufficiency_id": 0,  # Always insufficient
            "jurisdiction": "Unspecified",
            "insurance_type": "Unspecified",
            "augmentation_type": aug_example["augmentation_type"],
            "id": f"unrelated_{i}",
        }
        augmented_records.append(augmented_rec)

    print(f"Total API calls made: {api_calls_made}")
    print(f"Total augmented records created: {len(augmented_records)}")
    return augmented_records


if __name__ == "__main__":
# Config applied to both models
device = "cuda"
Expand Down Expand Up @@ -361,7 +582,7 @@ def combine_jsonl(directory) -> None:
# Combine CA CDI files into single jsonl
combine_jsonl("./data/processed/ca_cdi/summaries")

# Define Outcome Map / preprocessing
# Define Outcome Map / preprocessing and jurisdiction mapping
# TODO: decide if this standardization belongs here, or in raw dataset storage (i.e in records we read here)
extraction_targets = [
(
Expand Down Expand Up @@ -394,6 +615,11 @@ def combine_jsonl(directory) -> None:
# We will also split a train and test set for consistent experiments
train_out_path = "./data/outcomes/train_backgrounds_suff.jsonl"
test_out_path = "./data/outcomes/test_backgrounds_suff.jsonl"

# New paths for augmented datasets
train_augmented_out_path = "./data/outcomes/train_backgrounds_suff_augmented.jsonl"
test_augmented_out_path = "./data/outcomes/test_backgrounds_suff_augmented.jsonl"

train_subset = []
test_subset = []
for path, outcome_map in extraction_targets:
Expand Down Expand Up @@ -433,5 +659,76 @@ def combine_jsonl(directory) -> None:
train_subset.extend(train_recs)
test_subset.extend(val_recs)

# Write original train and test sets
add_jsonl_lines(train_out_path, train_subset)
add_jsonl_lines(test_out_path, test_subset)

print(f"Original train set: {len(train_subset)} examples")
print(f"Original test set: {len(test_subset)} examples")

# Check for OpenAI API key in environment variable
api_key = os.environ.get("OPENAI_API_KEY")
api_type = "openai" if api_key else "llamacpp"
api_url = "https://api.openai.com/v1/chat/completions" if api_key else "http://localhost:8080/completion"
model_name = "gpt-4o" if api_key else "llama-3.1"

print(f"Using {api_type} API for augmentation")

train_subset = get_records_list(train_out_path)
test_subset = get_records_list(test_out_path)

# Create augmented examples for train set
print("\n=== Creating Augmented Examples for Train Set ===")
train_augmented = create_augmented_examples(
train_subset,
api_type=api_type,
api_url=api_url,
api_key=api_key,
model_name=model_name,
generic_rewrites_per_example=1,
sufficient_augmentations_per_example=1,
num_unrelated_examples=1000,
api_call_limit=6000,
seed=42,
)

# Write augmented train and test sets
# Combine original and augmented examples
train_combined = train_subset + train_augmented
add_jsonl_lines(train_augmented_out_path, train_combined)

# Create augmented examples for test set
print("\n=== Creating Augmented Examples for Test Set ===")
test_augmented = create_augmented_examples(
test_subset,
api_type=api_type,
api_url=api_url,
api_key=api_key,
model_name=model_name,
generic_rewrites_per_example=1, # Fewer augmentations for test set
sufficient_augmentations_per_example=1,
num_unrelated_examples=1000,
api_call_limit=6000,
seed=43, # Different seed for test set
)

# Combine original and augmented examples
test_combined = test_subset + test_augmented

# Write augmented test set
add_jsonl_lines(test_augmented_out_path, test_combined)

# print(f"\nFinal augmented train set: {len(train_combined)} examples")
print(f"Final augmented test set: {len(test_combined)} examples")

# Print augmentation statistics
train_augmented_types = {}
for rec in train_augmented:
aug_type = rec.get("augmentation_type", "unknown")
if aug_type not in train_augmented_types:
train_augmented_types[aug_type] = 0
train_augmented_types[aug_type] += 1

print("\nTrain set augmentation breakdown:")
for aug_type, count in train_augmented_types.items():
print(f" {aug_type}: {count}")
Loading