diff --git a/examples/promptehr_mimic3_synthetic_generation.py b/examples/promptehr_mimic3_synthetic_generation.py
new file mode 100644
index 000000000..43ba6c467
--- /dev/null
+++ b/examples/promptehr_mimic3_synthetic_generation.py
@@ -0,0 +1,726 @@
+"""
+PromptEHR MIMIC-III Synthetic Data Generation Pipeline
+
+This script implements the complete pipeline for generating synthetic MIMIC-III EHR data
+using our restored PyHealth PromptEHR implementation, following the synthEHRella approach
+but with full training capability.
+
+Pipeline:
+1. MIMIC-III Data Preprocessing (PyHealth)
+2. PromptEHR Training (Restored Training Pipeline)
+3. Synthetic Data Generation
+4. Format Conversion (Compatible with MedGAN/CorGAN evaluations)
+
+Usage:
+ # Full pipeline (training + generation)
+ python promptehr_mimic3_synthetic_generation.py --mode train_and_generate --mimic_root ./data_files --output_dir ./promptehr_synthetic
+
+ # Generation only (using pretrained model)
+ python promptehr_mimic3_synthetic_generation.py --mode generate_only --model_path ./trained_promptehr --output_dir ./promptehr_synthetic
+
+ # Preprocessing only
+ python promptehr_mimic3_synthetic_generation.py --mode preprocess_only --mimic_root ./data_files --output_dir ./promptehr_preprocessed
+"""
+
+import os
+import sys
+import argparse
+import pickle
+import numpy as np
+import pandas as pd
+import torch
+import warnings
+from pathlib import Path
+from collections import defaultdict, Counter
+from typing import Dict, List, Tuple, Any, Optional
+import json
+import hashlib
+from tqdm import tqdm
+
+# Make the local PyHealth checkout importable when running this script from examples/
+sys.path.append(str(Path(__file__).resolve().parent.parent))
+from pyhealth.models.generators.promptehr import PromptEHR
+from pyhealth.datasets import SampleDataset, MIMIC3Dataset
+from pyhealth.tasks import BaseTask
+
+warnings.filterwarnings('ignore')
+
+class MIMIC3PromptEHRTask(BaseTask):
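+    """PyHealth task that converts MIMIC-III patients into PromptEHR-style samples.
+
+    Each returned sample packs the patient's diagnosis/procedure/medication codes into
+    a visit-sequence dict ('v') plus a vector of numerical baseline features ('x').
+    """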
+
+ def __init__(
+ self,
+ max_visits_per_patient: int = 20,
+ min_visits_per_patient: int = 2,
+ include_procedures: bool = True,
+ include_medications: bool = True,
+ code_vocab_threshold: int = 5,
+ convert_to_3digit_icd9: bool = True
+ ):
+ super().__init__()
+ self.task_name = "MIMIC3_PromptEHR"
+ self.input_schema = {}
+ self.output_schema = {}
+ self.max_visits_per_patient = max_visits_per_patient
+ self.min_visits_per_patient = min_visits_per_patient
+ self.include_procedures = include_procedures
+ self.include_medications = include_medications
+ self.code_vocab_threshold = code_vocab_threshold
+ self.convert_to_3digit_icd9 = convert_to_3digit_icd9
+
+ def _convert_to_3digit_icd9(self, dx_str: str) -> str:
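+        """Collapse a full ICD-9 code to a coarse 3-digit bucket.
+
+        E-codes are mapped into the 800-999 range and V-codes into the 700-799 range;
+        all other codes are truncated to their first three characters.
+        """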
+ if dx_str.startswith('E'):
+ if len(dx_str) > 1:
+ numeric_part = dx_str[1:] # Remove 'E' prefix
+ if len(numeric_part) > 3:
+ numeric_part = numeric_part[:3]
+ try:
+ num = int(numeric_part)
+ return str(800 + (num % 200))
+                except ValueError:
+ return '800'
+ else:
+ return '800'
+ elif dx_str.startswith('V'):
+ if len(dx_str) > 1:
+ numeric_part = dx_str[1:] # Remove 'V' prefix
+ if len(numeric_part) > 2:
+ numeric_part = numeric_part[:2]
+ try:
+ num = int(numeric_part)
+ return str(700 + (num % 100))
+                except ValueError:
+ return '700'
+ else:
+ return '700'
+ else:
+ if len(dx_str) > 3:
+ return dx_str[:3]
+ else:
+ return dx_str
+
+ def __call__(self, patient) -> List[Dict]:
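+        """Build a single PromptEHR sample for one patient.
+
+        Patients with fewer than `min_visits_per_patient` admissions are skipped;
+        for the rest, all extracted codes are attached to the first visit and the
+        remaining visits are kept as empty placeholders.
+        """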
+
+ # Get patient admissions
+ admissions = patient.get_events(event_type="admissions")
+
+ if len(admissions) < self.min_visits_per_patient:
+ return []
+
+ all_diagnoses = patient.get_events(event_type="diagnoses_icd")
+ all_procedures = patient.get_events(event_type="procedures_icd") if self.include_procedures else []
+ all_medications = patient.get_events(event_type="prescriptions") if self.include_medications else []
+
+ diag_codes = []
+ for diagnosis in all_diagnoses:
+ icd9_code = diagnosis.attr_dict.get('icd9_code')
+ if icd9_code:
+ code = str(icd9_code).strip()
+ if self.convert_to_3digit_icd9:
+ code = self._convert_to_3digit_icd9(code)
+ diag_codes.append(code)
+
+ proc_codes = []
+ if self.include_procedures:
+ for procedure in all_procedures:
+ icd9_code = procedure.attr_dict.get('icd9_code')
+ if icd9_code:
+ proc_codes.append(str(icd9_code).strip())
+
+ med_codes = []
+ if self.include_medications:
+ for prescription in all_medications:
+ drug = prescription.attr_dict.get('drug')
+ if drug:
+ drug_name = str(drug).strip()
+                # Hash the free-text drug name into a stable numeric pseudo-code
+                drug_hash = hashlib.md5(drug_name.encode()).hexdigest()
+                drug_id = int(drug_hash[:8], 16) % 100000
+ med_codes.append(str(drug_id))
+
+ if not diag_codes and not proc_codes and not med_codes:
+ return []
+
+ num_visits = min(len(admissions), self.max_visits_per_patient)
+ visits = []
+
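+        # Simplification: attach all patient-level codes to the first visit and keep the
+        # remaining visits as empty placeholders to preserve the visit count.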
+ for i in range(num_visits):
+ if i == 0:
+ visit_data = {
+ 'diag': diag_codes,
+ 'proc': proc_codes,
+ 'med': med_codes
+ }
+ else:
+ visit_data = {
+ 'diag': [],
+ 'proc': [],
+ 'med': []
+ }
+ visits.append(visit_data)
+
+ diag_visits = [visit.get('diag', []) for visit in visits]
+ proc_visits = [visit.get('proc', []) for visit in visits] if self.include_procedures else [[] for _ in visits]
+ med_visits = [visit.get('med', []) for visit in visits] if self.include_medications else [[] for _ in visits]
+
+ baseline_features = self._extract_baseline_features(patient, admissions[0])
+
+ sample = {
+ 'patient_id': patient.patient_id,
+ 'v': {
+ 'diag': diag_visits,
+ 'proc': proc_visits,
+ 'med': med_visits
+ },
+ 'x': baseline_features,
+ 'num_visits': num_visits
+ }
+
+ return [sample]
+
+ def _process_visit(self, patient, admission) -> Dict[str, List[str]]:
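+        """Collect the diagnosis/procedure/medication codes recorded for a single
+        admission, matching events to the admission via `hadm_id`."""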
+
+ visit_codes = {'diag': [], 'proc': [], 'med': []}
+
+ from datetime import timedelta
+ start_time = admission.timestamp
+ discharge_time = admission.attr_dict.get('dischtime')
+ if discharge_time:
+ try:
+ from datetime import datetime
+ if isinstance(discharge_time, str):
+ end_time = datetime.strptime(discharge_time, '%Y-%m-%d %H:%M:%S')
+ else:
+ end_time = discharge_time
+ end_time = end_time + timedelta(hours=24)
+            except (ValueError, TypeError):
+ end_time = admission.timestamp + timedelta(days=30)
+ else:
+ end_time = admission.timestamp + timedelta(days=30)
+
+        admission_id = admission.attr_dict.get('hadm_id')
+
+        try:
+            all_diagnoses = patient.get_events(event_type="diagnoses_icd")
+            diagnoses = []
+            if admission_id:
+                for diag in all_diagnoses:
+                    if diag.attr_dict.get('hadm_id') == admission_id:
+                        diagnoses.append(diag)
+            for diagnosis in diagnoses:
+                icd9_code = diagnosis.attr_dict.get('icd9_code')
+                if icd9_code:
+                    code = str(icd9_code).strip()
+                    if self.convert_to_3digit_icd9:
+                        code = self._convert_to_3digit_icd9(code)
+                    visit_codes['diag'].append(code)
+        except Exception:
+            pass
+
+ if self.include_procedures:
+ try:
+ all_procedures = patient.get_events(event_type="procedures_icd")
+ procedures = []
+ if admission_id:
+ for proc in all_procedures:
+ if proc.attr_dict.get('hadm_id') == admission_id:
+ procedures.append(proc)
+ for procedure in procedures:
+ icd9_code = procedure.attr_dict.get('icd9_code')
+ if icd9_code:
+ visit_codes['proc'].append(str(icd9_code).strip())
+ except Exception:
+ pass
+
+ if self.include_medications:
+ try:
+ all_prescriptions = patient.get_events(event_type="prescriptions")
+ prescriptions = []
+ if admission_id:
+ for pres in all_prescriptions:
+ if pres.attr_dict.get('hadm_id') == admission_id:
+ prescriptions.append(pres)
+ for prescription in prescriptions:
+ drug = prescription.attr_dict.get('drug')
+ if drug:
+ visit_codes['med'].append(str(drug).strip())
+ except Exception:
+ pass
+
+ for code_type in visit_codes:
+ visit_codes[code_type] = list(dict.fromkeys(visit_codes[code_type]))
+
+ return visit_codes
+
+ def _extract_baseline_features(self, patient, first_admission) -> List[float]:
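+        """Build a fixed-length vector of 10 numerical baseline features
+        (age, gender, admission type, insurance, four ethnicity flags, marital status,
+        language), each scaled to [0, 1]; missing attributes fall back to defaults."""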
+
+ features = []
+
+ age = first_admission.attr_dict.get('age', 65.0)
+ if age is None:
+ age = 65.0
+ features.append(min(float(age) / 100.0, 1.0))
+
+ gender = first_admission.attr_dict.get('gender', 'F')
+ if gender is None:
+ gender = 'F'
+ features.append(1.0 if str(gender).upper() == 'M' else 0.0)
+
+ admission_type = first_admission.attr_dict.get('admission_type', '').upper()
+ if 'EMERGENCY' in admission_type:
+ adm_type_val = 1.0
+ elif 'ELECTIVE' in admission_type:
+ adm_type_val = 0.5
+ elif 'URGENT' in admission_type:
+ adm_type_val = 0.75
+ else:
+ adm_type_val = 0.25
+ features.append(adm_type_val)
+
+ insurance = str(first_admission.attr_dict.get('insurance', '')).upper()
+ if 'MEDICARE' in insurance or 'MEDICAID' in insurance:
+ ins_val = 1.0
+ elif 'PRIVATE' in insurance:
+ ins_val = 0.5
+ elif 'SELF' in insurance:
+ ins_val = 0.25
+ else:
+ ins_val = 0.0
+ features.append(ins_val)
+
+ ethnicity = str(first_admission.attr_dict.get('ethnicity', '')).upper()
+ features.append(1.0 if 'WHITE' in ethnicity else 0.0)
+ features.append(1.0 if 'BLACK' in ethnicity or 'AFRICAN' in ethnicity else 0.0)
+ features.append(1.0 if 'HISPANIC' in ethnicity or 'LATINO' in ethnicity else 0.0)
+ features.append(1.0 if 'ASIAN' in ethnicity else 0.0)
+
+ marital = str(first_admission.attr_dict.get('marital_status', '')).upper()
+ features.append(1.0 if 'MARRIED' in marital else 0.0)
+
+ language = str(first_admission.attr_dict.get('language', '')).upper()
+ features.append(1.0 if language == 'ENGL' or language == 'ENGLISH' or language == '' else 0.0)
+
+ return features
+
+def preprocess_mimic3_data(mimic_root: str, output_dir: str, args) -> str:
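+    """Load the MIMIC-III tables, apply the PromptEHR task, build the code vocabulary,
+    split into train/validation sets, and save everything under `output_dir`.
+    Returns the output directory path, or None on failure."""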
+ print("Preprocessing MIMIC-III data...")
+
+ output_path = Path(output_dir)
+ output_path.mkdir(parents=True, exist_ok=True)
+
+ print(f"Loading MIMIC-III data from {mimic_root}")
+ tables = ["ADMISSIONS", "DIAGNOSES_ICD"]
+ if args.include_procedures:
+ tables.append("PROCEDURES_ICD")
+ if args.include_medications:
+ tables.append("PRESCRIPTIONS")
+
+ try:
+ dataset = MIMIC3Dataset(root=mimic_root, tables=tables)
+ print(f"Loaded {len(tables)} tables")
+ except Exception as e:
+ print(f"Failed to load dataset: {e}")
+ return None
+
+ print("Applying preprocessing...")
+ task = MIMIC3PromptEHRTask(
+ max_visits_per_patient=args.max_visits,
+ min_visits_per_patient=args.min_visits,
+ include_procedures=args.include_procedures,
+ include_medications=args.include_medications,
+ code_vocab_threshold=args.code_vocab_threshold
+ )
+
+ sample_dataset = dataset.set_task(task)
+
+ if len(sample_dataset.samples) == 0:
+ print("No samples generated")
+ return None
+
+ print(f"Processed {len(sample_dataset.samples)} patients")
+
+ print("Building vocabulary...")
+ vocab_stats = defaultdict(Counter)
+ for sample in sample_dataset.samples:
+ for code_type, visits in sample['v'].items():
+ for visit_codes in visits:
+ vocab_stats[code_type].update(visit_codes)
+
+ filtered_vocab = {}
+ for code_type, counter in vocab_stats.items():
+ filtered_vocab[code_type] = [code for code, count in counter.items() if count >= args.code_vocab_threshold]
+ print(f" {code_type}: {len(filtered_vocab[code_type])} codes (min_freq={args.code_vocab_threshold})")
+
+ print("Splitting data")
+ np.random.seed(42)
+ indices = np.random.permutation(len(sample_dataset.samples))
+ split_idx = int(len(sample_dataset.samples) * args.train_ratio)
+
+ train_indices = indices[:split_idx]
+ val_indices = indices[split_idx:]
+
+ train_samples = [sample_dataset.samples[i] for i in train_indices]
+ val_samples = [sample_dataset.samples[i] for i in val_indices]
+
+ print(f"Split: {len(train_samples)} train, {len(val_samples)} validation")
+
+ print("Saving data...")
+
+ with open(output_path / "train_samples.pkl", "wb") as f:
+ pickle.dump(train_samples, f)
+
+ with open(output_path / "val_samples.pkl", "wb") as f:
+ pickle.dump(val_samples, f)
+
+ with open(output_path / "vocabulary.pkl", "wb") as f:
+ pickle.dump(filtered_vocab, f)
+
+ metadata = {
+ 'total_patients': len(sample_dataset.samples),
+ 'train_patients': len(train_samples),
+ 'val_patients': len(val_samples),
+ 'vocabulary_sizes': {k: len(v) for k, v in filtered_vocab.items()},
+ 'code_types': list(filtered_vocab.keys()),
+ 'preprocessing_args': vars(args)
+ }
+
+ with open(output_path / "metadata.json", "w") as f:
+ json.dump(metadata, f, indent=2)
+
+ print(f"Saved to {output_path}")
+ print(f"Total patients: {metadata['total_patients']}")
+ print(f"Vocabulary sizes: {metadata['vocabulary_sizes']}")
+
+ return str(output_path)
+
+def create_promptehr_dataset(samples: List[Dict]) -> SampleDataset:
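+    """Wrap raw samples in a PyHealth SampleDataset and attach the minimal metadata
+    (visit mode, vocabulary placeholder, max visit count) expected by PromptEHR."""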
+
+ dataset = SampleDataset(
+ samples=samples,
+ input_schema={"v": "raw", "x": "raw"},
+ output_schema={}
+ )
+
+ dataset.metadata = {
+ 'visit': {'mode': 'dense'},
+ 'voc': {},
+ 'max_visit': max(s['num_visits'] for s in samples)
+ }
+
+ return dataset
+
+def train_promptehr_model(preprocess_dir: str, output_dir: str, args) -> str:
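+    """Load the preprocessed samples, fit a PromptEHR model on them, and save the
+    trained model. Returns the saved model path, or None if training fails."""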
+ print("Training PromptEHR model...")
+
+ output_path = Path(output_dir)
+ output_path.mkdir(parents=True, exist_ok=True)
+
+ preprocess_path = Path(preprocess_dir)
+
+ print("Loading data...")
+ with open(preprocess_path / "train_samples.pkl", "rb") as f:
+ train_samples = pickle.load(f)
+
+ with open(preprocess_path / "val_samples.pkl", "rb") as f:
+ val_samples = pickle.load(f)
+
+ with open(preprocess_path / "metadata.json", "r") as f:
+ metadata = json.load(f)
+
+ print(f"Loaded {len(train_samples)} train, {len(val_samples)} val samples")
+
+ train_dataset = create_promptehr_dataset(train_samples)
+ val_dataset = create_promptehr_dataset(val_samples)
+
+ print("Initializing model...")
+ n_features = len(train_samples[0]['x']) if train_samples[0]['x'] else 0
+ model = PromptEHR(
+ code_type=metadata['code_types'],
+ n_num_feature=n_features,
+ cat_cardinalities=None,
+ epoch=args.epochs,
+ batch_size=args.batch_size,
+ eval_batch_size=args.eval_batch_size,
+ learning_rate=args.learning_rate,
+ output_dir=str(output_path / "training_logs"),
+ device=args.device
+ )
+
+ print(f"Code types: {model.config['code_type']}")
+ print(f"Epochs: {model.config['epoch']}")
+ print(f"Batch size: {model.config['batch_size']}")
+
+ print("Starting training...")
+ try:
+ model.fit(train_data=train_dataset, val_data=val_dataset)
+ print("Training completed")
+ except Exception as e:
+ print(f"Training failed: {e}")
+ return None
+
+ model_path = output_path / "trained_model"
+ print(f"Saving model to {model_path}")
+ model.save_model(str(model_path))
+
+ print(f"Model saved to {model_path}")
+ return str(model_path)
+
+def generate_synthetic_data(model_path: str, output_dir: str, args) -> str:
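+    """Load a trained PromptEHR model, generate synthetic patients from a small set of
+    seed records, and convert the output to a binary diagnosis matrix plus CSV summaries."""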
+ print("Generating synthetic data...")
+
+ output_path = Path(output_dir)
+ output_path.mkdir(parents=True, exist_ok=True)
+
+ print(f"Loading model from {model_path}")
+ try:
+ model = PromptEHR()
+ model.load_model(model_path)
+ print("Model loaded")
+ except Exception as e:
+ print(f"Failed to load model: {e}")
+ return None
+
+ print("Creating seed data...")
+ seed_samples = []
+ for i in range(min(args.n_seed_samples, 100)):
+        seed_sample = {
+            'patient_id': f'seed_{i}',
+            'v': {
+                'diag': [['401', '250'], ['414', '428']],  # Common diagnosis patterns
+                'proc': [[], []],
+                'med': [[], []]
+            },
+            # 10 baseline features in [0, 1], matching _extract_baseline_features
+            'x': np.random.rand(10).tolist(),
+            'num_visits': 2  # required by create_promptehr_dataset metadata
+        }
+ seed_samples.append(seed_sample)
+
+ seed_dataset = create_promptehr_dataset(seed_samples)
+
+ print(f"Generating {args.n_synthetic} samples...")
+ try:
+ synthetic_results = model.predict(
+ test_data=seed_dataset,
+ n=args.n_synthetic,
+ n_per_sample=args.n_per_sample,
+ sample_config={'temperature': args.temperature},
+ verbose=True
+ )
+ print("Generation completed")
+ except Exception as e:
+ print(f"Generation failed: {e}")
+ return None
+
+ raw_output_path = output_path / "promptehr_synthetic_raw.pkl"
+ with open(raw_output_path, "wb") as f:
+ pickle.dump(synthetic_results, f)
+
+ print(f"Raw data saved to {raw_output_path}")
+
+ print("Converting to binary matrix...")
+ binary_matrix = convert_to_binary_matrix(synthetic_results, output_path)
+
+ if binary_matrix is not None:
+ print(f"Binary matrix saved with shape: {binary_matrix.shape}")
+ return str(output_path)
+ else:
+ return None
+
+def convert_to_binary_matrix(synthetic_results: Dict, output_dir: Path) -> Optional[np.ndarray]:
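+    """Convert generated visit sequences into a patient-by-diagnosis binary matrix
+    (MedGAN/CorGAN-compatible) and save the matrix, code mapping, stats, and CSV views."""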
+
+ print("Extracting diagnosis codes...")
+ all_diag_codes = set()
+ patient_diagnoses = []
+
+ for i, patient_visits in enumerate(synthetic_results['visit']):
+ patient_diags = set()
+ for visit in patient_visits:
+ if visit and len(visit) > 0 and len(visit[0]) > 0:
+ for diag_code in visit[0]:
+ if isinstance(diag_code, (int, str)):
+ code_str = str(diag_code)
+ if code_str.startswith('E'):
+ if len(code_str) > 4:
+ code_str = code_str[:4]
+ else:
+ if len(code_str) > 3:
+ code_str = code_str[:3]
+
+ patient_diags.add(f'D_{code_str}')
+ all_diag_codes.add(f'D_{code_str}')
+
+ patient_diagnoses.append(patient_diags)
+
+ if not all_diag_codes:
+ print("No diagnosis codes found")
+ return None
+
+ sorted_codes = sorted(list(all_diag_codes))
+ code_to_idx = {code: idx for idx, code in enumerate(sorted_codes)}
+
+ print(f"Found {len(sorted_codes)} unique diagnosis codes")
+
+ n_patients = len(patient_diagnoses)
+ n_features = len(sorted_codes)
+ binary_matrix = np.zeros((n_patients, n_features), dtype=np.float32)
+
+ for i, patient_diags in enumerate(patient_diagnoses):
+ for diag_code in patient_diags:
+ if diag_code in code_to_idx:
+ binary_matrix[i, code_to_idx[diag_code]] = 1.0
+
+ matrix_path = output_dir / "promptehr_synthetic_binary.npy"
+ np.save(matrix_path, binary_matrix)
+
+ mapping_path = output_dir / "code_mapping.pkl"
+ with open(mapping_path, "wb") as f:
+ pickle.dump({
+ 'code_to_idx': code_to_idx,
+ 'idx_to_code': {v: k for k, v in code_to_idx.items()},
+ 'sorted_codes': sorted_codes
+ }, f)
+
+ stats = {
+ 'n_patients': int(n_patients),
+ 'n_features': int(n_features),
+ 'sparsity': float(1.0 - (np.count_nonzero(binary_matrix) / binary_matrix.size)),
+ 'avg_codes_per_patient': float(np.mean(np.sum(binary_matrix, axis=1))),
+ 'total_unique_codes': len(sorted_codes),
+ 'timestamp': pd.Timestamp.now().isoformat(),
+ 'generation_method': 'PromptEHR'
+ }
+
+ stats_path = output_dir / "generation_stats.json"
+ with open(stats_path, "w") as f:
+ json.dump(stats, f, indent=2)
+
+ print("Creating CSV output...")
+
+ patient_data = []
+ for i, patient_diags in enumerate(patient_diagnoses):
+ patient_data.append({
+ 'patient_id': f'synthetic_{i:06d}',
+ 'num_diagnosis_codes': len(patient_diags),
+ 'diagnosis_codes': ';'.join(sorted(list(patient_diags))) if patient_diags else '',
+ 'generation_timestamp': pd.Timestamp.now().isoformat(),
+ 'generation_method': 'PromptEHR'
+ })
+
+ patient_df = pd.DataFrame(patient_data)
+ patient_csv_path = output_dir / "synthetic_patients_summary.csv"
+ patient_df.to_csv(patient_csv_path, index=False)
+ print(f"Patient summary saved to {patient_csv_path}")
+
+ code_freq_data = []
+ for code in sorted_codes:
+ freq = np.sum(binary_matrix[:, code_to_idx[code]])
+ code_freq_data.append({
+ 'diagnosis_code': code,
+ 'frequency': int(freq),
+ 'prevalence': freq / n_patients,
+ 'code_type': 'ICD9_diagnosis'
+ })
+
+ freq_df = pd.DataFrame(code_freq_data)
+ freq_csv_path = output_dir / "synthetic_code_frequencies.csv"
+ freq_df.to_csv(freq_csv_path, index=False)
+ print(f"Code frequencies saved to {freq_csv_path}")
+
+ print("Creating sparse CSV...")
+ sparse_data = []
+ for i in range(n_patients):
+ for j in range(n_features):
+ if binary_matrix[i, j] == 1:
+ sparse_data.append({
+ 'patient_id': f'synthetic_{i:06d}',
+ 'diagnosis_code': sorted_codes[j],
+ 'present': 1
+ })
+
+ if sparse_data:
+ sparse_df = pd.DataFrame(sparse_data)
+ sparse_csv_path = output_dir / "synthetic_patient_diagnoses_sparse.csv"
+ sparse_df.to_csv(sparse_csv_path, index=False)
+ print(f"Sparse matrix saved to {sparse_csv_path}")
+
+ print(f"Shape: {binary_matrix.shape}")
+ print(f"Sparsity: {stats['sparsity']:.3f}")
+ print(f"Avg codes per patient: {stats['avg_codes_per_patient']:.1f}")
+
+ return binary_matrix
+
+def main():
+ parser = argparse.ArgumentParser(description="PromptEHR MIMIC-III Synthetic Data Generation")
+
+ parser.add_argument("--mode", type=str, choices=['train_and_generate', 'generate_only', 'preprocess_only'],
+ default='train_and_generate', help="Pipeline mode")
+
+ parser.add_argument("--mimic_root", type=str, default="./data_files", help="MIMIC-III root directory")
+ parser.add_argument("--output_dir", type=str, default="./promptehr_synthetic", help="Output directory")
+ parser.add_argument("--model_path", type=str, help="Path to trained model (for generate_only mode)")
+ parser.add_argument("--preprocess_dir", type=str, help="Path to preprocessed data")
+
+ parser.add_argument("--max_visits", type=int, default=20, help="Max visits per patient")
+ parser.add_argument("--min_visits", type=int, default=2, help="Min visits per patient")
+ parser.add_argument("--include_procedures", action="store_true", help="Include procedure codes")
+ parser.add_argument("--include_medications", action="store_true", help="Include medication codes")
+ parser.add_argument("--code_vocab_threshold", type=int, default=5, help="Minimum code frequency")
+ parser.add_argument("--train_ratio", type=float, default=0.8, help="Training data ratio")
+
+ parser.add_argument("--epochs", type=int, default=10, help="Training epochs")
+ parser.add_argument("--batch_size", type=int, default=8, help="Training batch size")
+ parser.add_argument("--eval_batch_size", type=int, default=8, help="Evaluation batch size")
+ parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate")
+ parser.add_argument("--device", type=str, default="cuda", help="Training device")
+
+ parser.add_argument("--n_synthetic", type=int, default=10000, help="Number of synthetic samples")
+ parser.add_argument("--n_per_sample", type=int, default=1, help="Samples per seed patient")
+ parser.add_argument("--n_seed_samples", type=int, default=100, help="Number of seed samples")
+ parser.add_argument("--temperature", type=float, default=1.0, help="Generation temperature")
+
+ args = parser.parse_args()
+
+ print(f"Mode: {args.mode}")
+ print(f"Output: {args.output_dir}")
+
+ if args.mode == 'preprocess_only':
+ preprocess_dir = preprocess_mimic3_data(args.mimic_root, args.output_dir, args)
+ if preprocess_dir:
+ print(f"Preprocessing completed")
+ print(f"Data saved to: {preprocess_dir}")
+
+ elif args.mode == 'generate_only':
+ if not args.model_path:
+ print("--model_path required for generate_only mode")
+ return
+
+ synthetic_dir = generate_synthetic_data(args.model_path, args.output_dir, args)
+ if synthetic_dir:
+ print(f"Generation completed")
+ print(f"Data saved to: {synthetic_dir}")
+
+ elif args.mode == 'train_and_generate':
+
+ if args.preprocess_dir:
+ preprocess_dir = args.preprocess_dir
+ print(f"Using existing preprocessed data: {preprocess_dir}")
+ else:
+ preprocess_dir = preprocess_mimic3_data(args.mimic_root,
+ str(Path(args.output_dir) / "preprocessed"), args)
+ if not preprocess_dir:
+ print("Preprocessing failed")
+ return
+
+ model_path = train_promptehr_model(preprocess_dir,
+ str(Path(args.output_dir) / "model"), args)
+ if not model_path:
+ print("Training failed")
+ return
+
+ synthetic_dir = generate_synthetic_data(model_path,
+ str(Path(args.output_dir) / "synthetic"), args)
+ if synthetic_dir:
+ print(f"Pipeline completed")
+ print(f"Data saved to: {synthetic_dir}")
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/pyhealth/models/generators/__init__.py b/pyhealth/models/generators/__init__.py
new file mode 100644
index 000000000..0b650a1ad
--- /dev/null
+++ b/pyhealth/models/generators/__init__.py
@@ -0,0 +1,3 @@
+from .promptehr import PromptEHR
+
+__all__ = ["PromptEHR"]
\ No newline at end of file
diff --git a/pyhealth/models/generators/halo.py b/pyhealth/models/generators/halo.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pyhealth/models/generators/medgan.py b/pyhealth/models/generators/medgan.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pyhealth/models/generators/promptehr.py b/pyhealth/models/generators/promptehr.py
new file mode 100644
index 000000000..9650c723b
--- /dev/null
+++ b/pyhealth/models/generators/promptehr.py
@@ -0,0 +1,2669 @@
+'''
+User interface for the PromptEHR model.
+Adapted from the original PromptEHR implementation for PyHealth integration.
+Preserves the original BART-based architecture with conditional prompts.
+'''
+import os
+import pdb
+import json
+import math
+import glob
+import random
+import copy
+import time
+from collections import defaultdict
+import warnings
+
+import pickle
+import torch
+from torch import nn, Tensor
+from torch.utils.data.dataloader import DataLoader
+import numpy as np
+from tqdm import tqdm
+from typing import Any, Dict, List, Optional, Tuple, Union
+from dataclasses import dataclass
+
+# Added from original PromptEHR for training support
+from transformers import TrainingArguments, Trainer
+from torch.nn.utils.rnn import pad_sequence
+from transformers.data.data_collator import InputDataClass
+from transformers.trainer_pt_utils import (
+ nested_detach, nested_concat, nested_truncate, nested_numpify, find_batch_size
+)
+from transformers.trainer_utils import has_length, denumpify_detensorize, EvalLoopOutput, EvalPrediction
+from transformers.trainer_pt_utils import IterableDatasetShard
+
+from transformers import BartTokenizer, BartConfig
+from transformers.generation.utils import GenerationMixin
+from transformers.models.bart.modeling_bart import BartModel, BartPretrainedModel, BartEncoder, BartDecoder
+from transformers.models.bart.modeling_bart import shift_tokens_right, BartLearnedPositionalEmbedding as TransformersBartLearnedPositionalEmbedding
+from transformers.modeling_outputs import BaseModelOutput, Seq2SeqModelOutput, BaseModelOutputWithPastAndCrossAttentions
+from transformers.file_utils import ModelOutput
+from tokenizers import Tokenizer
+from tokenizers.pre_tokenizers import Whitespace
+from tokenizers.models import WordLevel
+
+
+# Constants from original PromptEHR implementation.
+# 'tbd' is a placeholder: the actual code types and their special tokens are set at
+# runtime via `DataTokenizer.update_special_token_config` / `extend_vocab`.
+CODE_TYPES = ['tbd']
+SPECIAL_TOKEN_DICT = {'tbd':['','']}
+UNKNOWN_TOKEN = ''
+MODEL_MAX_LENGTH = 512
+EPS = 1e-16
+
+# Additional constants from original constants.py
+PRETRAINED_MODEL_URL = 'https://storage.googleapis.com/pytrial/promptEHR_pretrained.zip'
+SYNTHETIC_DATA_URL = 'https://github.com/RyanWangZf/PromptEHR/raw/main/demo_data/synthetic_ehr/data.pkl'
+
+# a name mapping from the original promptehr config to the training_args
+config_to_train_args = {
+ 'epoch': 'num_train_epochs',
+ 'num_worker': 'dataloader_num_workers',
+ 'batch_size': 'per_device_train_batch_size',
+ 'eval_batch_size': 'per_device_eval_batch_size',
+ 'eval_step': 'eval_steps',
+}
+
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+ inverted_mask = 1.0 - expanded_mask
+ return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+
+def _all_or_none(values):
+ return all(x is None for x in values) or all(x is not None for x in values)
+
+class NumericalConditionalPrompt(nn.Module):
+ '''Embedding for conditional prompts based on numerical input patient features,
+ take reparametrization trick.
+
+ Parameters
+ ----------
+ n_feature: number of input features.
+ d_model: dimension of output embeddings.
+ d_hidden: dimension of intermediate embeddings for reparametrization.
+ '''
+ def __init__(self, n_feature, d_model, d_hidden) -> None:
+ super().__init__()
+ self.weight = nn.init.xavier_uniform_(nn.Parameter(Tensor(n_feature, d_hidden)))
+ self.bias = nn.init.xavier_uniform_(nn.Parameter(Tensor(n_feature, d_hidden)))
+ self.proj = nn.Linear(d_hidden, d_model, bias=False)
+
+ def forward(self, x):
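+        # x: (bs, n_feature) -> per-feature affine map to (bs, n_feature, d_hidden),
+        # then a shared linear projection to (bs, n_feature, d_model)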
+ # Ensure weight and bias are on the same device as input
+ device = x.device
+ weight = self.weight.to(device)
+ bias = self.bias.to(device)
+
+ x = weight[None] * x[..., None]
+ x = x + bias[None]
+
+ # Ensure projection layer is on the same device as input
+ self.proj = self.proj.to(device)
+ x = self.proj(x)
+ return x
+
+class CategoricalConditionalPrompt(nn.Module):
+ '''Embedding for conditional prompts based on categorical input patient features,
+ take reparametrization trick.
+
+ Parameters
+ ----------
+    cardinalities: the number of distinct values for each feature, e.g., [2, 3, 5] means the first feature has 2 possible categories, and so on.
+ d_model: the output embedding dimension.
+ d_hidden: the intermediate layer dimension for reparameterization.
+ '''
+ def __init__(self,
+ cardinalities,
+ d_model,
+ d_hidden
+ ) -> None:
+ super().__init__()
+ assert cardinalities, 'cardinalities must be non-empty'
+ category_offsets = torch.tensor([0] + cardinalities[:-1]).cumsum(0)
+ self.register_buffer('category_offsets', category_offsets, persistent=False)
+ self.embeddings = nn.Embedding(sum(cardinalities), d_hidden)
+ self.bias = nn.init.xavier_uniform_(nn.Parameter(Tensor(len(cardinalities),d_hidden)))
+ self.proj = nn.Linear(d_hidden, d_model, bias=False)
+
+ def forward(self, x):
+ # Ensure category_offsets and bias are on the same device as input
+ device = x.device
+ category_offsets = self.category_offsets.to(device)
+ bias = self.bias.to(device)
+
+ # Ensure embeddings and projection layer are on the same device as input
+ self.embeddings = self.embeddings.to(device)
+ self.proj = self.proj.to(device)
+
+ x = self.embeddings(x + category_offsets[None])
+ x = x + bias[None]
+ x = self.proj(x)
+ return x
+
+class ConditionalPrompt(nn.Module):
+ '''Provide conditional prompt embedding for both categorical and numerical features.
+
+ Parameters
+ ----------
+ n_num_feature: number of input numerical features.
+ cat_cardinalities: a list of unique numbers of each feature.
+ d_model: the output dimension.
+ d_hidden: the intermediate layer dimension for reparametrization.
+ '''
+ def __init__(self,
+ n_num_feature=None,
+ cat_cardinalities=None,
+ d_model=None,
+ d_hidden=None,
+ ) -> None:
+ super().__init__()
+ if n_num_feature is not None:
+ assert isinstance(n_num_feature, int), 'the passed `n_num_feature` to `promptehr` must be an integer, {} with type {} found.'.format(n_num_feature, type(n_num_feature))
+ assert n_num_feature >= 0, 'n_num_feature must be non-negative'
+ assert (n_num_feature or cat_cardinalities), 'at least one of n_num_feature or cat_cardinalities must be positive/non-empty'
+ self.num_tokenizer = (
+ NumericalConditionalPrompt(
+ n_feature=n_num_feature,
+ d_model=d_model,
+ d_hidden=d_hidden,
+ )
+ if n_num_feature
+ else None
+ )
+ self.cat_tokenizer = (
+ CategoricalConditionalPrompt(
+ cat_cardinalities,
+ d_model=d_model,
+ d_hidden=d_hidden,
+ )
+ if cat_cardinalities
+ else None
+ )
+
+ def forward(self, x_num=None, x_cat=None):
+ '''Perform the forward pass to encode features into prompt context vectors.
+
+ Parameters
+ ----------
+        x_num: continuous features. Must be provided if :code:`n_num_feature > 0` was passed.
+        x_cat: categorical features. Must be provided if a non-empty :code:`cat_cardinalities` was passed.
+        '''
+        assert (
+            x_num is not None or x_cat is not None
+        ), 'At least one of x_num and x_cat must be provided'
+ assert _all_or_none(
+ [self.num_tokenizer, x_num]
+ ), 'If self.num_tokenizer is (not) None, then x_num must (not) be None'
+ assert _all_or_none(
+ [self.cat_tokenizer, x_cat]
+ ), 'If self.cat_tokenizer is (not) None, then x_cat must (not) be None'
+ x = []
+ if self.num_tokenizer is not None:
+ x.append(self.num_tokenizer(x_num))
+ if self.cat_tokenizer is not None:
+ x.append(self.cat_tokenizer(x_cat))
+ return x[0] if len(x) == 1 else torch.cat(x, dim=1)
+
+
+class BartLearnedPositionalEmbedding(nn.Embedding):
+ """
+ This module learns positional embeddings up to a fixed maximum size.
+ """
+
+ def __init__(self, num_embeddings: int, embedding_dim: int):
+ # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
+ # and adjust num_embeddings appropriately. Other models don't have this hack
+ self.offset = 2
+ super().__init__(num_embeddings + self.offset, embedding_dim)
+
+ def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
+ """`input_ids_shape` is expected to be [bsz x seqlen]."""
+ bsz, seq_len = input_ids_shape[:2]
+ # Handle the case where seq_len might be a tensor or torch.Size element
+ if torch.is_tensor(seq_len):
+ if seq_len.numel() == 1:
+ seq_len = seq_len.item()
+ else:
+ # If it's a multi-element tensor, take the first element or max
+ seq_len = seq_len.max().item() if seq_len.numel() > 0 else 1
+ else:
+ seq_len = int(seq_len)
+ positions = torch.arange(
+ past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
+ )
+ positions = positions + self.offset
+ positions = torch.minimum(positions, torch.ones_like(positions).to(positions.device)*1024)
+ res = super().forward(positions)
+ return res
+
+
+class PromptBartEncoder(BartEncoder):
+ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
+ super().__init__(config, embed_tokens)
+ embed_dim = config.d_model
+ self.embed_positions = BartLearnedPositionalEmbedding(
+ config.max_position_embeddings,
+ embed_dim,
+ )
+ # Add missing embed_scale (standard BART uses sqrt of d_model if scale_embedding is True)
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+
+ def forward(self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_prompt_embeds: Optional[torch.FloatTensor]=None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+        '''Run the encoder forward pass.
+        Parameters
+        ----------
+        inputs_prompt_embeds: Embedded conditional prompt embeddings prepended to the input embeddings.
+        '''
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+
+ if inputs_prompt_embeds is not None:
+ # concatenate prompt embeddings in front of the input embeds
+ # modify input_shape and attention_mask at the same time
+ inputs_embeds = torch.cat([inputs_prompt_embeds, inputs_embeds], dim=1)
+ input_shape = inputs_embeds.size()[:-1]
+ if attention_mask is not None:
+ add_att_mask = torch.ones(inputs_prompt_embeds.shape[:-1]).to(attention_mask.device)
+ attention_mask = torch.cat([add_att_mask, attention_mask], dim=1)
+
+ embed_pos = self.embed_positions(input_shape)
+ hidden_states = inputs_embeds + embed_pos
+ hidden_states = self.layernorm_embedding(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ # expand attention_mask
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ # check if head_mask has a correct number of layers specified if desired
+ if head_mask is not None:
+ if head_mask.size()[0] != (len(self.layers)):
+ raise ValueError(
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ )
+
+ for idx, encoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ dropout_probability = random.uniform(0, 1)
+ if self.training and (dropout_probability < self.layerdrop): # skip the layer
+ layer_outputs = (None, None)
+ else:
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(encoder_layer),
+ hidden_states,
+ attention_mask,
+ (head_mask[idx] if head_mask is not None else None),
+ )
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+class PromptBartDecoder(BartDecoder):
+ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
+ super().__init__(config, embed_tokens)
+ self.embed_positions = BartLearnedPositionalEmbedding(
+ config.max_position_embeddings,
+ config.d_model,
+ )
+ # Add missing embed_scale (standard BART uses sqrt of d_model if scale_embedding is True)
+ self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+
+ def forward(self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_prompt_embeds: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
+ ):
+        '''Run the decoder forward pass.
+
+ Parameters
+ ----------
+ inputs_prompt_embeds: the embeddings of conditional prompts for the decoder.
+
+ '''
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ # past_key_values_length
+ past_key_values_length = int(past_key_values[0][0].shape[2]) if past_key_values is not None else 0
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+
+ if inputs_prompt_embeds is not None:
+ # concatenate prompt embeddings in front of the input embeds
+ # modify input_shape and attention_mask at the same time
+ inputs_embeds = torch.cat([inputs_prompt_embeds, inputs_embeds], dim=1)
+ input_shape = inputs_embeds.size()[:-1]
+ if attention_mask is not None:
+ add_att_mask = torch.ones(inputs_prompt_embeds.shape[:-1]).to(attention_mask.device)
+ attention_mask = torch.cat([add_att_mask, attention_mask], dim=1)
+
+ # Handle different transformers versions - method was renamed
+ if hasattr(self, '_prepare_decoder_attention_mask'):
+ attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
+ )
+ elif hasattr(self, 'create_extended_attention_mask_for_decoder'):
+ if attention_mask is not None:
+ attention_mask = self.create_extended_attention_mask_for_decoder(
+ input_shape, attention_mask, past_key_values_length
+ )
+ else:
+ # Fallback for newer transformers versions
+ pass
+
+ # expand encoder attention mask
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ if inputs_prompt_embeds is not None:
+ # adjust for input prompt embeddings
+ add_att_mask = torch.ones(inputs_prompt_embeds.shape[:-1]).to(encoder_attention_mask.device)
+ encoder_attention_mask = torch.cat([add_att_mask, encoder_attention_mask], dim=1)
+
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+
+ # embed positions
+ dummy_input_ids = torch.zeros(input_shape, dtype=torch.long, device=inputs_embeds.device)
+
+ # Debug and bounds check
+ seq_len = input_shape[-1]
+ max_pos_embeddings = self.embed_positions.num_embeddings
+ max_allowed_pos = max_pos_embeddings - 10 # Conservative buffer
+
+ # Calculate safe sequence length considering past key values
+ max_safe_seq_len = max(1, max_allowed_pos - past_key_values_length)
+
+ # Ensure we don't truncate to less than 1
+ if seq_len > max_safe_seq_len and max_safe_seq_len > 0:
+ truncated_seq_len = max_safe_seq_len
+ truncated_shape = input_shape[:-1] + (truncated_seq_len,)
+ dummy_input_ids = torch.zeros(truncated_shape, dtype=torch.long, device=inputs_embeds.device)
+ inputs_embeds = inputs_embeds[:, :truncated_seq_len, :]
+ # Update input_shape for attention mask compatibility
+ input_shape = truncated_shape
+ else:
+ # Don't truncate if it would make sequence too short
+ truncated_seq_len = seq_len
+ dummy_input_ids = torch.zeros(input_shape, dtype=torch.long, device=inputs_embeds.device)
+
+ positions = self.embed_positions(dummy_input_ids.shape, past_key_values_length)
+
+ # Ensure both tensors are on the same device before addition
+ if positions.device != inputs_embeds.device:
+ try:
+ positions = positions.to(inputs_embeds.device)
+ except RuntimeError:
+ # If device transfer fails, move inputs_embeds to positions device
+ inputs_embeds = inputs_embeds.to(positions.device)
+
+ hidden_states = inputs_embeds + positions
+ hidden_states = self.layernorm_embedding(hidden_states)
+
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+ next_decoder_cache = () if use_cache else None
+
+ # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+ for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+ if attn_mask is not None:
+ if attn_mask.size()[0] != (len(self.layers)):
+ raise ValueError(
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ )
+
+ for idx, decoder_layer in enumerate(self.layers):
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+ dropout_probability = random.uniform(0, 1)
+ if self.training and (dropout_probability < self.layerdrop):
+ continue
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ if use_cache:
+                    warnings.warn(
+                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                    )
+ use_cache = False
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, output_attentions, use_cache)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(decoder_layer),
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ head_mask[idx] if head_mask is not None else None,
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
+ None,
+ )
+ else: # testing/generating
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ cross_attn_layer_head_mask=(
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
+ ),
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if encoder_hidden_states is not None:
+ all_cross_attentions += (layer_outputs[2],)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attentions,
+ )
+
+class PromptBartModel(BartModel):
+    '''A subclass of BartModel that uses additional conditional prompts for controllable EHR generation.
+    '''
+ def __init__(self, config: BartConfig):
+ super().__init__(config)
+ # Store config for later access
+ self.config = config
+
+ padding_idx, vocab_size = config.pad_token_id, config.vocab_size
+ self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
+
+ self.encoder = PromptBartEncoder(config, self.shared)
+ self.decoder = PromptBartDecoder(config, self.shared)
+
+ # build encoder & decoder prompts
+ n_num_feature = config.n_num_feature
+ cat_cardinalities = config.cat_cardinalities
+ if n_num_feature is not None or cat_cardinalities is not None:
+ self.encoder_conditional_prompt = ConditionalPrompt(n_num_feature=n_num_feature,
+ cat_cardinalities=cat_cardinalities,
+ d_model=config.d_model,
+ d_hidden=config.d_prompt_hidden)
+ self.decoder_conditional_prompt = ConditionalPrompt(n_num_feature=n_num_feature,
+ cat_cardinalities=cat_cardinalities,
+ d_model=config.d_model,
+ d_hidden=config.d_prompt_hidden)
+ else:
+ # fix when no baseline feature is provided.
+ warnings.warn('No numerical or categorical baseline features are provided, `ConditionalPrompt` is not used in the model.')
+ self.encoder_conditional_prompt = None
+ self.decoder_conditional_prompt = None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def forward(self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ x_num: Optional[torch.FloatTensor] = None,
+ x_cat: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ decoder_head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+ '''Make the forward pass to encode inputs with Bart model.
+
+ Parameters
+ ----------
+ x_num: the input numerical features, shape (bs, num_feat)
+ x_cat: the input categorical features, shape (bs, num_cat)
+ '''
+ # different to other models, Bart automatically creates decoder_input_ids from
+ # input_ids if no decoder_input_ids are provided
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
+ if input_ids is None:
+ raise ValueError(
+ "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
+ "passed, `input_ids` cannot be `None`. Please pass either "
+ "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
+ )
+
+ decoder_input_ids = shift_tokens_right(
+ input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
+ )
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if encoder_outputs is None:
+ if x_num is not None or x_cat is not None:
+ if self.encoder_conditional_prompt is None:
+                    warnings.warn('Detected baseline features in the input data, but `ConditionalPrompt` was not built '
+                                  'because no numerical or categorical baseline features were provided at model initialization. '
+                                  'Consider setting `config.n_num_feature` or `config.cat_cardinalities` when initializing the model.')
+ prompt_embeds = None
+ else:
+ prompt_embeds = self.encoder_conditional_prompt(x_num=x_num, x_cat=x_cat)
+ else:
+ prompt_embeds = None
+
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ inputs_prompt_embeds=prompt_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+ encoder_outputs = BaseModelOutput(
+ last_hidden_state=encoder_outputs[0],
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+ )
+
+ # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
+ if x_num is not None or x_cat is not None:
+ if self.decoder_conditional_prompt is None:
+                warnings.warn('Detected baseline features in the input data, but `ConditionalPrompt` was not built '
+                              'because no numerical or categorical baseline features were provided at model initialization. '
+                              'Consider setting `config.n_num_feature` or `config.cat_cardinalities` when initializing the model.')
+ decoder_prompt_embeds = None
+ else:
+ decoder_prompt_embeds = self.decoder_conditional_prompt(x_num=x_num, x_cat=x_cat)
+ else:
+ decoder_prompt_embeds = None
+
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ encoder_hidden_states=encoder_outputs[0],
+ encoder_attention_mask=attention_mask,
+ head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=decoder_inputs_embeds,
+ inputs_prompt_embeds=decoder_prompt_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if not return_dict:
+ return decoder_outputs + encoder_outputs
+
+ return Seq2SeqModelOutput(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
+def EHRBartConfig(data_tokenizer, model_tokenizer, **kwargs):
+ '''Build the config used for building the promptBart model.
+ '''
+ bart_config = BartConfig.from_pretrained('facebook/bart-base')
+ # Store the num_tokens dict separately so we can access it later
+ num_tokens_dict = model_tokenizer.get_num_tokens
+ kwargs.update(num_tokens_dict)
+ bart_config.__dict__['_num_tokens_dict'] = num_tokens_dict # Store it with a special key
+ kwargs['data_tokenizer_num_vocab'] = len(data_tokenizer)
+
+ # CRITICAL FIX: Update vocab_size to match the extended tokenizer vocabulary
+ # The data_tokenizer has been extended with medical codes, so we need to update
+ # the BART config to match this larger vocabulary size
+ original_vocab_size = bart_config.vocab_size
+ extended_vocab_size = len(data_tokenizer)
+ bart_config.vocab_size = extended_vocab_size
+ print(f"Updated BART config vocab_size from {original_vocab_size} to {extended_vocab_size}")
+
+ if 'd_prompt_hidden' not in kwargs:
+ kwargs['d_prompt_hidden'] = 128
+ if 'n_num_feature' not in kwargs:
+ kwargs['n_num_feature'] = None
+ if 'cat_cardinalities' not in kwargs:
+ kwargs['cat_cardinalities'] = None
+ bart_config.__dict__.update(kwargs)
+
+ # specify bos, eos token id
+ bart_config.__dict__['decoder_start_token_id'] = 0
+ bart_config.__dict__['bos_token_id'] = 0
+ bart_config.__dict__['eos_token_id'] = 1
+ bart_config.__dict__['forced_eos_token_id'] = 1
+ return bart_config
+
+class DataTokenizer(BartTokenizer):
+    r'''Tokenizer for the raw input EHR records (extends BartTokenizer with code-type vocabularies).
+    '''
+ new_token_type_list = CODE_TYPES
+ special_token_dict = SPECIAL_TOKEN_DICT
+ code_vocab = defaultdict(list)
+
+ def add_token_to_code_vocab(self, tokens, code):
+ # Only add tokens that aren't already in the tokenizer vocabulary
+ new_tokens = [token for token in tokens if token not in self.get_vocab()]
+ if new_tokens:
+ self.add_tokens(new_tokens)
+
+ if code not in self.code_vocab:
+ self.code_vocab[code] = np.array(tokens)
+ else:
+ origin_tokens = self.code_vocab[code]
+ new_tokens = np.array(tokens)
+ self.code_vocab[code] = np.unique(np.concatenate([origin_tokens, new_tokens]))
+
+ def update_special_token_config(self, code_types):
+ self.new_token_type_list = code_types
+ self.special_token_dict = {}
+ special_token_list = []
+ for code_type in code_types:
+ l = [f'<{code_type}>', f'{code_type}>']
+ self.special_token_dict[code_type] = l
+ special_token_list.extend(l)
+ self.add_tokens(special_token_list)
+
+ def extend_vocab(self, token_dict):
+ '''
+ Parameters:
+ ----------
+ token_dict: dict
+ key: code type, value: a list of tokens.
+ '''
+ for key in token_dict.keys():
+ self.code_vocab[key] = np.array(token_dict[key])
+ self.add_tokens(token_dict[key])
+
+ def extend_vocab_from_dir(self, data_dir):
+ # add new tokens from the data dir
+ for key in self.new_token_type_list:
+ filename = os.path.join(data_dir,'{}_token_list.txt'.format(key))
+ with open(filename, 'r', encoding='utf-8') as f:
+ token_list = [line.strip() for line in f.readlines()]
+ self.code_vocab[key] = np.array(token_list)
+ self.add_tokens(token_list)
+
+ # add special tokens indicating different modality
+ for key, value in self.special_token_dict.items():
+ self.add_tokens(value, special_tokens=True)
+
+class ModelTokenizer:
+ r'''construct an EHR tokenizer that converts tokenized indices to code-specific token indices.
+ '''
+ def __init__(self, tokenizer: DataTokenizer):
+ # map_token = lambda x: str(tokenizer(x).input_ids[1])
+ org_vocab = tokenizer.get_vocab()
+ tokenizer_dict = {}
+ num_token_dict = {}
+ label_offset = 1 # Default offset for special tokens (UNKNOWN_TOKEN = 0, so offset starts at 1)
+
+ for key, value in tokenizer.code_vocab.items():
+ vocab = defaultdict(int)
+ vocab[UNKNOWN_TOKEN] = 0
+ for i,token in enumerate(tokenizer.special_token_dict[key]):
+ vocab[str(org_vocab[token])] = i+1
+ offset = len(vocab)
+ label_offset = offset # Update with the last computed offset
+
+ for i, token in enumerate(value): # str token = 'diag_xxx'
+ # fix: if token has more than one '_', e.g., 'diag_t_a_b_100', will only take the last '100' as the index.
+ # _, index = token.split('_')
+ indexes = token.split('_')
+ try:
+ index = int(indexes[-1])
+                except ValueError:
+                    raise ValueError(f"Token {token} is not a valid token: it should be split by '_' and the last part should be a number, e.g., 'diag_100'.")
+ vocab[str(org_vocab[token])] = index + offset
+
+ # new tokenizer
+ specific_tokenizer = Tokenizer(WordLevel(vocab=vocab, unk_token=UNKNOWN_TOKEN))
+ specific_tokenizer.pre_tokenizer = Whitespace()
+
+ # num_token_dict is decided by the max index instead of number of tokens
+ num_token_dict[key] = (max(vocab.values())+1) - offset
+ tokenizer_dict[key] = specific_tokenizer
+
+ # each code type has its own tokenizer corresponding to specific LM heads
+ self.tokenizer_dict = tokenizer_dict
+ self.num_token_dict = num_token_dict
+ self.label_offset = label_offset
+
+
+ def encode(self, input_ids, code_type):
+ if len(input_ids.shape) > 1: # a batch
+ ids = self.encode_batch(input_ids, code_type)
+ else:
+ ids = self.tokenizer_dict[code_type].encode(input_ids.cpu().numpy().astype(str), is_pretokenized=True).ids
+ ids = torch.tensor(ids, device=input_ids.device)
+
+ return ids
+
+ def encode_batch(self, input_ids, code_type):
+ ids_list = self.tokenizer_dict[code_type].encode_batch(input_ids.cpu().numpy().astype(str).tolist(), is_pretokenized=True)
+
+ ids = torch.tensor([x.ids for x in ids_list], device=input_ids.device)
+
+ return ids
+
+ @property
+ def get_num_tokens(self):
+ return self.num_token_dict
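+
+    # Illustrative sketch: ModelTokenizer remaps global BART token ids to compact,
+    # code-type-specific label ids so each LM head only covers its own vocabulary.
+    #   model_tok = ModelTokenizer(data_tok)                  # data_tok: a fitted DataTokenizer
+    #   label_ids = model_tok.encode(bart_input_ids, 'diag')  # tensor of per-code-type indices
+    #   model_tok.get_num_tokens                              # e.g., {'diag': 1083, ...} (hypothetical)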
+
+@dataclass
+class EHRBartOutput(ModelOutput):
+ """
+ Base class for sequence-to-sequence language models outputs.
+
+ Args:
+ loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
+ Language modeling loss.
+ logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
+ Tuple of :obj:`tuple(torch.FloatTensor)` of length :obj:`config.n_layers`, with each tuple having 2 tensors
+ of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+ shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding.
+ decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+ decoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`tuple(torch.FloatTensor)` (one for each layer) of shape :obj:`(batch_size, num_heads,
+ sequence_length, sequence_length)`.
+
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`tuple(torch.FloatTensor)` (one for each layer) of shape :obj:`(batch_size, num_heads,
+ sequence_length, sequence_length)`.
+
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ encoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
+ encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+ encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`tuple(torch.FloatTensor)` (one for each layer) of shape :obj:`(batch_size, num_heads,
+ sequence_length, sequence_length)`.
+
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ perplexity:
+ perplexity calculated when the label mask is given.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+ decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
+ cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
+ perplexity: Optional[torch.FloatTensor] = None
+
+class BartForEHRSimulation(BartPretrainedModel, GenerationMixin):
+ '''BART model for EHR sequence simulation.
+ Extend the BartPretrainedModel to support code-specific output and conditional prompt.
+ '''
+ base_model_prefix = "model"
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
+
+ def __init__(self, config: BartConfig):
+ super().__init__(config)
+ self.model = PromptBartModel(config)
+        # bias over the shared BART vocabulary, kept for compatibility with the base model
+        self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
+        # build a separate LM head for each code type
+ self.lm_head_list = nn.ModuleDict()
+ # get_num_tokens was stored in config during EHRBartConfig
+ if hasattr(config, '_num_tokens_dict'):
+ num_tokens_dict = config._num_tokens_dict
+ else:
+ # Fallback: try to find the token counts from config attributes
+ num_tokens_dict = {}
+ standard_attrs = set(dir(BartConfig()))
+ for attr_name in dir(config):
+ if (attr_name not in standard_attrs and
+ not attr_name.startswith('_') and
+ hasattr(config, attr_name)):
+ attr_value = getattr(config, attr_name)
+ if isinstance(attr_value, int) and attr_value > 0:
+ num_tokens_dict[attr_name] = attr_value
+
+ for code_type in num_tokens_dict.keys():
+ lm_head = nn.Linear(config.d_model, num_tokens_dict[code_type], bias=False)
+ self.lm_head_list[code_type] = lm_head
+
+ # Create a main lm_head for compatibility with transformers library
+ # Use the vocabulary size from the shared embeddings
+ self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_encoder(self):
+ return self.model.encoder
+
+ def get_decoder(self):
+ return self.model.decoder
+
+ def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
+ new_embeddings = super().resize_token_embeddings(new_num_tokens)
+ self._resize_final_logits_bias(new_num_tokens)
+ return new_embeddings
+
+ def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
+ old_num_tokens = self.final_logits_bias.shape[-1]
+ if new_num_tokens <= old_num_tokens:
+ new_bias = self.final_logits_bias[:, :new_num_tokens]
+ else:
+ extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
+ new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
+ self.register_buffer("final_logits_bias", new_bias)
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def _compute_loss(self,
+ sequence_output,
+ target_label,
+ target_mask,
+ code_type,
+ return_logits=False,
+ **kwargs
+ ):
+ '''Compute the cross entropy loss for given code type.
+ '''
+ lm_head = self.lm_head_list[code_type]
+
+ # Ensure lm_head is on the same device as sequence_output
+ if sequence_output is not None:
+ device = sequence_output.device
+ lm_head = lm_head.to(device)
+ self.lm_head_list[code_type] = lm_head
+
+ # if return logits only, does not compute loss
+ if target_label is None:
+ lm_logits = lm_head(sequence_output)
+ return lm_logits
+
+ # compute loss per code type
+ lm_logits = lm_head(sequence_output)
+
+ # Handle empty tensor case
+ if lm_logits.numel() == 0 or target_label.numel() == 0:
+ # Return zero loss for empty tensors
+ return torch.tensor(0.0, device=lm_logits.device, requires_grad=True)
+
+ loss_fct = nn.CrossEntropyLoss(reduction='none')
+
+ # Ensure lm_logits and target_label have the same sequence length
+ min_seq_len = min(lm_logits.size(1), target_label.size(1))
+ lm_logits_trimmed = lm_logits[:, :min_seq_len, :]
+ target_label_trimmed = target_label[:, :min_seq_len]
+
+ # Clamp target labels to valid vocabulary range to prevent CUDA assertion errors
+ vocab_size = lm_logits_trimmed.size(-1)
+ target_label_trimmed = torch.clamp(target_label_trimmed, min=0, max=vocab_size-1)
+
+ masked_lm_loss = loss_fct(lm_logits_trimmed.reshape(-1, lm_logits_trimmed.size(-1)), target_label_trimmed.reshape(-1))
+ # mask out the loss for non-active predictions
+ masked_lm_loss = masked_lm_loss.reshape(lm_logits_trimmed.size(0), lm_logits_trimmed.size(1))
+ target_mask_trimmed = target_mask[:, :min_seq_len] if target_mask is not None else None
+ if target_mask_trimmed is not None:
+ masked_lm_loss = masked_lm_loss * target_mask_trimmed
+
+ loss = masked_lm_loss.sum() / (target_mask_trimmed.sum() if target_mask_trimmed is not None else 1)
+
+ if return_logits:
+ return loss, lm_logits
+ else:
+ return loss
+
+ def _get_perplexity(self,
+ sequence_output,
+ target_label,
+ target_mask,
+ code_type,
+ **kwargs
+ ):
+ '''compute perplexity.
+ '''
+ with torch.no_grad():
+ lm_head = self.lm_head_list[code_type]
+
+ # Ensure lm_head is on the same device as sequence_output
+ if sequence_output is not None:
+ device = sequence_output.device
+ lm_head = lm_head.to(device)
+ self.lm_head_list[code_type] = lm_head
+
+ lm_logits = lm_head(sequence_output)
+ lm_logits = torch.nn.functional.log_softmax(lm_logits, dim=-1)
+
+ # Ensure target_label indices are within vocabulary bounds
+ vocab_size = lm_logits.size(-1)
+ target_label_clamped = torch.clamp(target_label, min=0, max=vocab_size-1)
+
+ # (bs, seq_len)
+ picked_logits = torch.gather(lm_logits, 2, target_label_clamped.unsqueeze(-1)).squeeze(-1)
+ picked_logits = picked_logits * target_mask
+ sum_picked_logits = picked_logits.sum(dim=-1) # (bs,)
+ sum_target_mask = target_mask.sum(dim=-1) + EPS # (bs,)
+ perplexity = torch.exp(-sum_picked_logits / sum_target_mask)
+ return perplexity
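+
+    # Perplexity sketch: with per-token log-probabilities logp_t and label mask m_t,
+    #   ppl = exp( -(sum_t m_t * logp_t) / (sum_t m_t + EPS) )
+    # e.g., two masked tokens with log-probs [-1.0, -3.0] give exp(2.0) ≈ 7.39.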
+
+ def forward(self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ x_num: Optional[torch.FloatTensor] = None,
+ x_cat: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ decoder_head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ label_mask: Optional[torch.LongTensor] = None,
+ code_type: Optional[str] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ return_perplexity: Optional[bool] = None,
+ **kwargs,
+ ):
+ r"""
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Labels for computing the masked language modeling loss. Indices should either be in ``[0, ...,
+ config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
+
+ Returns:
+ """
+ # Ensure all model components are on the same device as the input
+ if input_ids is not None:
+ device = input_ids.device
+ if hasattr(self.model, 'device') and self.model.device != device:
+ self.model = self.model.to(device)
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if labels is not None:
+ use_cache = False
+
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ x_num=x_num,
+ x_cat=x_cat,
+ encoder_outputs=encoder_outputs,
+ decoder_attention_mask=decoder_attention_mask,
+ head_mask=head_mask,
+ decoder_head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ decoder_inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ if self.config.tie_word_embeddings:
+ # Rescale output before projecting on vocab
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+ sequence_output = sequence_output * (self.model.shared.embedding_dim ** -0.5)
+
+ # compute loss / prediction
+ loss = None
+ perplexity = None
+ if labels is not None and code_type is not None:
+ # make sure label_mask exists when computing the loss
+ assert label_mask is not None
+ loss = self._compute_loss(sequence_output, labels, label_mask, code_type)
+ if return_perplexity:
+ perplexity = self._get_perplexity(sequence_output, labels, label_mask, code_type)
+
+ logits = {}
+ # compute all types logits when not training
+ if labels is None:
+ for code_type in self.lm_head_list.keys():
+ logits[code_type] = self._compute_loss(sequence_output, None, None, code_type)
+ else:
+ # only return interested logits at training stage
+ logits[code_type] = self._compute_loss(sequence_output, None, None, code_type)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return EHRBartOutput(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ decoder_hidden_states=outputs.decoder_hidden_states,
+ decoder_attentions=outputs.decoder_attentions,
+ cross_attentions=outputs.cross_attentions,
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+ encoder_hidden_states=outputs.encoder_hidden_states,
+ encoder_attentions=outputs.encoder_attentions,
+ perplexity=perplexity,
+ )
+
+ def prepare_inputs_for_generation(self,
+ decoder_input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ x_num=None,
+ x_cat=None,
+ head_mask=None,
+ decoder_head_mask=None,
+ cross_attn_head_mask=None,
+ use_cache=None,
+ encoder_outputs=None,
+ **kwargs,
+ ):
+ # cut decoder_input_ids if past is used
+ if past_key_values is not None:
+ decoder_input_ids = decoder_input_ids[:, -1:]
+
+ return {
+ "input_ids": None, # encoder_outputs is defined. input_ids not needed
+ "encoder_outputs": encoder_outputs,
+ "past_key_values": past_key_values,
+ "decoder_input_ids": decoder_input_ids,
+ "attention_mask": attention_mask,
+ "x_num": x_num,
+ "x_cat": x_cat,
+ "head_mask": head_mask,
+ "decoder_head_mask": decoder_head_mask,
+ "cross_attn_head_mask": cross_attn_head_mask,
+ "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
+ }
+
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+ return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
+
+ @staticmethod
+ def _reorder_cache(past_key_values, beam_idx):
+ reordered_past = ()
+ for layer_past in past_key_values:
+ # cached cross_attention states don't have to be reordered -> they are always the same
+ reordered_past += (
+ tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+ )
+ return reordered_past
+
+# Full MimicDataCollator from original PromptEHR implementation
+class FullMimicDataCollator:
+    '''Data collator for training and evaluating the EHR-BART model.
+    A batch must contain features for either all samples or none,
+    otherwise an error is raised.
+ '''
+ __code_type_list__ = CODE_TYPES
+ __special_token_dict__ = SPECIAL_TOKEN_DICT
+ __del_or_rep__ = ['rep', 'del']
+
+ def __init__(self,
+ tokenizer,
+ code_types,
+ n_num_feature,
+ mlm_prob=0.15,
+ lambda_poisson=3.0,
+ del_prob=0.15,
+ max_train_batch_size=16,
+ drop_feature=False,
+ mode='train',
+ eval_code_type=None,
+ eval_ppl_type='span'
+ ):
+        '''mlm_prob: probability of masking a token
+        lambda_poisson: Poisson rate for the span-infilling mask length
+        del_prob: probability of deleting a token
+        max_train_batch_size: cap on sampled sequences per batch to avoid OOM, since each patient expands into a batch of sequences
+ '''
+ # update code_types
+ self.__code_type_list__ = code_types
+ self.__special_token_dict__ = {}
+ for code in code_types: self.__special_token_dict__[code] = [f'<{code}>', f'{code}>']
+
+ self.mlm_prob = mlm_prob
+ self.tokenizer = tokenizer
+ self.tokenizer.model_max_length = MODEL_MAX_LENGTH
+ self.mlm_probability = mlm_prob
+ self.lambda_poisson = lambda_poisson
+ self.del_probability = del_prob
+ self.max_train_batch_size = max_train_batch_size # sample batch to avoid OOM
+ self.eval_code_type = eval_code_type if eval_code_type is not None else (code_types[0] if code_types else None)
+ self.eval_ppl_type = eval_ppl_type
+ self.drop_feature = drop_feature
+ self.n_num_feature = n_num_feature
+
+ assert mode in ['train', 'val', 'test']
+ self.mode = mode
+ self.is_training = (mode == 'train')
+ self.is_testing = (mode == 'test')
+
+ def __getstate__(self):
+ """Custom pickling to ensure eval attributes are preserved in multiprocessing"""
+ state = self.__dict__.copy()
+ return state
+
+ def __setstate__(self, state):
+ """Custom unpickling to restore eval attributes in multiprocessing"""
+ self.__dict__.update(state)
+ # Ensure eval attributes exist after unpickling
+ if not hasattr(self, 'eval_code_type'):
+ self.eval_code_type = self.__code_type_list__[0] if hasattr(self, '__code_type_list__') and self.__code_type_list__ else None
+ if not hasattr(self, 'eval_ppl_type'):
+ self.eval_ppl_type = 'span'
+ # Ensure mode attributes exist after unpickling
+ if not hasattr(self, 'mode'):
+ self.mode = 'val' # Default to validation mode
+ if not hasattr(self, 'is_training'):
+ self.is_training = (self.mode == 'train')
+ if not hasattr(self, 'is_testing'):
+ self.is_testing = (self.mode == 'test')
+
+ def __call__(self, samples: List[InputDataClass]) -> Dict[str, Any]:
+ # samples format
+ # [{'pid': 'x_num':[], 'x_cat':[], 'diagnosis':[[],[],[],...], 'procedure': [[],[]...], 'drug':[[],[],...] }]
+ def _seq_patient_to_promptehr(samples):
+ post_samples = []
+ for sample in samples:
+ post_sample = {}
+ visit = sample['v']
+ post_sample.update(visit)
+ if ('x' in sample) and (self.n_num_feature is not None):
+ if not isinstance(sample['x'], list):
+ sample['x'] = sample['x'].tolist()
+ post_sample['x_num'] = sample['x'][:self.n_num_feature]
+ # Only add x_cat if there are categorical features remaining
+ remaining_features = sample['x'][self.n_num_feature:]
+ if remaining_features: # Only add if non-empty
+ post_sample['x_cat'] = remaining_features
+ post_samples.append(post_sample)
+ return post_samples
+
+ samples = _seq_patient_to_promptehr(samples)
+
+ if self.is_training:
+ batch = self.call_train(samples)
+ elif self.is_testing:
+ batch = self.call_test(samples)
+ else:
+ batch = self.call_val(samples)
+ return batch
+
+ def call_train(self, samples: List[InputDataClass]) -> Dict[str, Any]:
+ '''label mask should not be used during training.
+ '''
+ batch = defaultdict(list)
+
+ # randomly pick one of code types for prediction, keep the same for this batch
+ code_type = random.sample(self.__code_type_list__, 1)[0]
+ batch['code_type'] = code_type
+
+ for sample in samples:
+ num_adm = len(sample[code_type])
+
+            # accumulated while enumerating all admissions
+ input_str_all = []
+ label_str_all = []
+ num_token_all = []
+
+ # cope with too long labtest codes
+ # start from the offset if the labtest is too long
+ adm = 0
+ while adm < num_adm:
+ span_str_list = [] # input ids
+ span_label_str_list = [] # label ids
+ num_token_this_adm = 0
+
+ # shuffle the code order
+ code_list = list(sample.keys())
+ random.shuffle(code_list)
+                for code in code_list:
+ if code in ['pid','x_num','x_cat']: continue
+
+ span = sample[code][adm]
+
+ if len(span) == 0: continue
+
+ # restrict the num of tokens in each span
+ span = random.sample(span, min(20, len(span)))
+
+ # translate span to code_span
+ span = self._process_span(span, code)
+
+ span_str = self._pad_special_token_head_tail(' '.join(span), code)
+ span_label_str_list.append(span_str)
+ num_token_this_adm += len(span) + 2
+
+ if code == code_type:
+ # do mask infilling / mask
+ infill_span, _, _ = self.mask_infill([span])
+ span_str = self._pad_special_token_head_tail(' '.join(infill_span[0]), code)
+ span_str_list.append(span_str)
+ else:
+ if self.__del_or_rep__[random.randint(0,1)] == 'rep':
+ rep_del_span = self.rep_token([span], code)
+ else:
+ rep_del_span = self.del_token([span])
+
+ span_str = self._pad_special_token_head_tail(' '.join(rep_del_span[0]), code)
+ span_str_list.append(span_str)
+
+                # separate admissions with the BART end-of-sequence token
+                span_str_list.append('</s>')
+                span_label_str_list.append('</s>')
+ num_token_this_adm += 1
+
+ input_str_all.append(' '.join(span_str_list))
+ label_str_all.append(' '.join(span_label_str_list))
+ num_token_all.append(num_token_this_adm)
+
+ # check break condition
+ if len(input_str_all) >= self.max_train_batch_size:
+ break # do not sample too many examples to avoid OOM
+
+ if adm < num_adm - 1:
+ total_token_next_adm = sum(num_token_all) + len(sample[code_type][adm+1]) + 10
+ if total_token_next_adm >= self.tokenizer.model_max_length - 10:
+                        break # stop before the next admission would exceed the model max length
+ adm += 1
+
+ # tokenization
+ batch['input_ids'].extend(self.tokenizer(input_str_all, return_tensors='pt', padding=True, truncation=True, max_length=MODEL_MAX_LENGTH)['input_ids'])
+ batch['labels'].extend(self.tokenizer(label_str_all, return_tensors='pt', padding=True, truncation=True, max_length=MODEL_MAX_LENGTH)['input_ids'])
+ if 'x_num' in sample:
+ if not self.drop_feature:
+ batch['x_num'].extend([torch.tensor(sample['x_num'], dtype=torch.float32)] * len(input_str_all))
+ if 'x_cat' in sample:
+ if not self.drop_feature:
+ batch['x_cat'].extend([torch.tensor(sample['x_cat'], dtype=torch.long)] * len(input_str_all))
+
+ # padding
+ batch['input_ids'] = pad_sequence(batch['input_ids'], batch_first=True)
+ batch['labels'] = pad_sequence(batch['labels'], batch_first=True)
+ batch['attention_mask'] = (batch['input_ids'] != self.tokenizer.pad_token_id).float()
+ batch['label_mask'] = (batch['labels'] != self.tokenizer.pad_token_id).float()
+
+ if 'x_num' in batch:
+ batch['x_num'] = torch.stack(batch['x_num'])
+ if 'x_cat' in batch:
+ batch['x_cat'] = torch.stack(batch['x_cat'])
+
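+        # Resulting batch (illustrative): 'input_ids', 'labels', 'attention_mask' and
+        # 'label_mask' are (batch, seq_len) tensors, 'code_type' is the sampled code
+        # type string, and 'x_num'/'x_cat' are stacked patient features when present.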
+ return dict(batch)
+
+ def call_val(self, samples: List[InputDataClass]) -> Dict[str, Any]:
+ return self.call_test(samples)
+
+ def call_test(self, samples: List[InputDataClass]) -> Dict[str, Any]:
+        '''compute the perplexity for each code type.
+ '''
+ assert self.eval_code_type is not None
+ code_type = self.eval_code_type
+ assert self.eval_ppl_type is not None
+ ppl_type = self.eval_ppl_type
+
+ batch = defaultdict(list)
+ batch['code_type'] = code_type
+
+ for sample in samples:
+ num_adm = len(sample[code_type])
+
+            # accumulated while enumerating all admissions
+ input_str_all = []
+ label_str_all = []
+
+ # cope with too long labtest codes
+ # start from the offset if the labtest is too long
+ adm = 0
+ while adm < num_adm:
+ span_str_list = [] # input ids
+ span_label_str_list = [] # label ids
+
+ for code in sample.keys():
+ if code in ['pid','x_num','x_cat']: continue
+
+ span = sample[code][adm]
+
+ if len(span) == 0: continue
+
+ # translate span to code_span
+ span = self._process_span(span, code)
+
+ span_str = self._pad_special_token_head_tail(' '.join(span), code)
+ span_label_str_list.append(span_str)
+
+ if code == code_type:
+ if ppl_type == 'spl': # single prediction loss
+ # do mask infilling / mask
+ infill_span, _, _ = self.mask_infill([span])
+ span_str = self._pad_special_token_head_tail(' '.join(infill_span[0]), code)
+ elif ppl_type == 'tpl': # teacher forcing loss
+ span_str = self._pad_special_token_head_tail(' '.join(span), code)
+ span_str_list.append(span_str)
+ else:
+ span_str = self._pad_special_token_head_tail(' '.join(span), code)
+ span_str_list.append(span_str)
+
+                # separate admissions with the BART end-of-sequence token
+                span_str_list.append('</s>')
+                span_label_str_list.append('</s>')
+
+ input_str_all.append(' '.join(span_str_list))
+ label_str_all.append(' '.join(span_label_str_list))
+ adm += 1
+
+ # tokenization
+ batch['input_ids'].extend(self.tokenizer(input_str_all, return_tensors='pt', padding=True, truncation=True, max_length=MODEL_MAX_LENGTH)['input_ids'])
+ batch['labels'].extend(self.tokenizer(label_str_all, return_tensors='pt', padding=True, truncation=True, max_length=MODEL_MAX_LENGTH)['input_ids'])
+ if 'x_num' in sample:
+ batch['x_num'].extend([torch.tensor(sample['x_num'], dtype=torch.float32)] * len(input_str_all))
+ if 'x_cat' in sample:
+ batch['x_cat'].extend([torch.tensor(sample['x_cat'], dtype=torch.long)] * len(input_str_all))
+
+ # padding
+ batch['input_ids'] = pad_sequence(batch['input_ids'], batch_first=True)
+ batch['labels'] = pad_sequence(batch['labels'], batch_first=True)
+ batch['attention_mask'] = (batch['input_ids'] != self.tokenizer.pad_token_id).float()
+ batch['label_mask'] = (batch['labels'] != self.tokenizer.pad_token_id).float()
+
+ if 'x_num' in batch:
+ batch['x_num'] = torch.stack(batch['x_num'])
+ if 'x_cat' in batch:
+ batch['x_cat'] = torch.stack(batch['x_cat'])
+
+ return dict(batch)
+
+ def set_eval_code_type(self, code_type):
+ self.eval_code_type = code_type
+
+ def set_eval_ppl_type(self, ppl_type):
+ self.eval_ppl_type = ppl_type
+
+ def _process_span(self, span, code):
+ return [code+'_'+str(s) for s in span]
+
+ def _pad_special_token_head_tail(self, span_str, code):
+        head_tag = self.__special_token_dict__[code][0]  # e.g., '<diag>'
+        tail_tag = self.__special_token_dict__[code][1]  # e.g., 'diag>'
+ return head_tag + ' ' + span_str + ' ' + tail_tag
+
+ def mask_infill(self, spans):
+        '''mask a contiguous span of tokens and replace it with '<mask>' tokens
+ '''
+ results = []
+ org_tokens = []
+ labels = []
+ for span in spans:
+ num_to_mask = max(1, int(self.mlm_probability * len(span)))
+
+ if num_to_mask == len(span):
+ num_to_mask = len(span) - 1
+
+ # randomly decide the mask length
+ mask_length = np.random.poisson(self.lambda_poisson)
+ mask_length = min(mask_length, num_to_mask)
+ mask_length = max(mask_length, 1)
+
+ # randomly decide the start position to mask
+ start_pos = random.randint(0, len(span) - mask_length)
+
+ new_span = span.copy()
+            # replace the selected tokens with the '<mask>' token
+            new_span[start_pos:start_pos+mask_length] = ['<mask>'] * mask_length
+ results.append(new_span)
+ org_tokens.append(span[start_pos:start_pos+mask_length])
+ labels.append([start_pos, start_pos+mask_length])
+ return results, org_tokens, labels
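+
+    # Corruption sketch (hypothetical span, mlm_prob=0.15): with a 3-token span,
+    # one position is masked, e.g.
+    #   ['diag_401', 'diag_250', 'diag_272'] -> ['diag_401', '<mask>', 'diag_272']
+    # while the label side keeps the original, uncorrupted span.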
+
+ def rep_token(self, spans, code):
+ '''replace some tokens to the same modality randomly
+ '''
+ results = []
+ for span in spans:
+ num_to_rep = max(1, int(self.mlm_probability * len(span)))
+ rep_idx = random.sample(range(len(span)), num_to_rep)
+ new_span = span.copy()
+ for idx in rep_idx:
+ # randomly pick tokens from the same code vocab
+ rep_tokens = self.tokenizer.code_vocab[code]
+ rep_token_str = random.sample(rep_tokens.tolist(), 1)[0]
+ new_span[idx] = rep_token_str
+ results.append(new_span)
+ return results
+
+ def del_token(self, spans):
+ '''delete some tokens for data corruption
+ '''
+ results = []
+ for span in spans:
+ num_to_del = max(1, int(self.del_probability * len(span)))
+ if num_to_del == len(span): num_to_del = len(span) - 1
+ del_idx = random.sample(range(len(span)), num_to_del)
+ new_span = [span[i] for i in range(len(span)) if i not in del_idx]
+ results.append(new_span)
+ return results
+
+class MimicDataCollator:
+ '''Data collator with masking for MIMIC data.
+ '''
+ def __init__(self, tokenizer, model_tokenizer, mlm_probability=0.15, **kwargs):
+ self.tokenizer = tokenizer
+ self.model_tokenizer = model_tokenizer
+ self.mlm_probability = mlm_probability
+ for k, v in kwargs.items():
+ setattr(self, k, v)
+
+ def __call__(self, examples):
+ # Handle dict or list of dicts
+ if isinstance(examples[0], dict):
+ batch = self.tokenizer.pad(examples, return_tensors="pt")
+ else:
+ batch = self.tokenizer.pad(
+ {"input_ids": examples}, return_tensors="pt"
+ )
+
+ # If special token mask has been preprocessed, pop it from the dict.
+ special_tokens_mask = batch.pop("special_tokens_mask", None)
+ if special_tokens_mask is None:
+ special_tokens_mask = [
+ self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
+ for val in batch["input_ids"].tolist()
+ ]
+ special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
+ else:
+ special_tokens_mask = special_tokens_mask.bool()
+
+ batch["input_ids"], batch["labels"] = self.torch_mask_tokens(
+ batch["input_ids"], special_tokens_mask=special_tokens_mask
+ )
+ return batch
+
+ def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
+ """
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
+ """
+
+ labels = inputs.clone()
+ # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
+ probability_matrix = torch.full(labels.shape, self.mlm_probability)
+ if special_tokens_mask is None:
+ special_tokens_mask = [
+ self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
+ ]
+ special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
+ else:
+ special_tokens_mask = special_tokens_mask.bool()
+
+ probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
+ masked_indices = torch.bernoulli(probability_matrix).bool()
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
+
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
+ indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
+ inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
+
+ # 10% of the time, we replace masked input tokens with random word
+ indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
+ vocab_size = len(self.tokenizer.get_vocab()) if hasattr(self.tokenizer, 'get_vocab') else len(self.tokenizer)
+ random_words = torch.randint(vocab_size, labels.shape, dtype=torch.long)
+ inputs[indices_random] = random_words[indices_random]
+
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
+ return inputs, labels
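+
+    # Masking sketch (illustrative): with mlm_probability=0.15 over a 20-token
+    # sequence, ~3 positions are selected; ~80% of those become the tokenizer's
+    # mask token, ~10% a random vocabulary token, and ~10% keep the original token,
+    # while `labels` is -100 everywhere except the selected positions.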
+
+# PromptEHRTrainer from original implementation
+class PromptEHRTrainer(Trainer):
+ def __init__(self,
+ model= None,
+ args = None,
+ data_collator=None,
+ train_dataset=None,
+ eval_dataset=None,
+ val_data_collator=None,
+ ):
+ super().__init__(model, args, data_collator, train_dataset, eval_dataset)
+ self.val_data_collator = val_data_collator if val_data_collator is not None else self.data_collator
+
+ def get_train_dataloader(self) -> DataLoader:
+ """
+ Returns the training [`~torch.utils.data.DataLoader`].
+ Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
+ training if necessary) otherwise.
+ Subclass and override this method if you want to inject some custom behavior.
+ """
+ if self.train_dataset is None:
+ raise ValueError("Trainer: training requires a train_dataset.")
+
+ train_dataset = self.train_dataset
+ data_collator = self.data_collator
+ train_sampler = self._get_train_sampler()
+
+ return DataLoader(
+ train_dataset,
+ batch_size=self._train_batch_size,
+ sampler=train_sampler,
+ collate_fn=data_collator,
+ drop_last=self.args.dataloader_drop_last,
+ num_workers=self.args.dataloader_num_workers,
+ pin_memory=self.args.dataloader_pin_memory,
+ )
+
+ def get_eval_dataloader(self, eval_dataset, code_type):
+ """
+ Returns the evaluation [`~torch.utils.data.DataLoader`].
+ Subclass and override this method if you want to inject some custom behavior.
+ Args:
+ eval_dataset (`torch.utils.data.Dataset`, *optional*):
+ If provided, will override `self.eval_dataset`. If it is an `datasets.Dataset`, columns not accepted by
+ the `model.forward()` method are automatically removed. It must implement `__len__`.
+ """
+ if eval_dataset is None and self.eval_dataset is None:
+ raise ValueError("Trainer: evaluation requires an eval_dataset.")
+ eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
+ self.val_data_collator.set_eval_code_type(code_type) # set evaluation for this code
+ return DataLoader(
+ eval_dataset,
+ batch_size=self.args.eval_batch_size,
+ collate_fn=self.val_data_collator,
+ num_workers=self.args.dataloader_num_workers,
+ pin_memory=self.args.dataloader_pin_memory,
+ shuffle=False,
+ drop_last=False,
+ )
+
+ def evaluate(
+ self,
+ eval_dataset=None,
+ ignore_keys=['encoder_last_hidden_state', 'past_key_values'],
+ metric_key_prefix: str = "eval",
+ ):
+ """
+ Run evaluation and returns metrics.
+
+ The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
+ (pass it to the init :obj:`compute_metrics` argument).
+
+ You can also subclass and override this method to inject custom behavior.
+
+ Args:
+ eval_dataset (:obj:`Dataset`, `optional`):
+ Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`,
+ columns not accepted by the ``model.forward()`` method are automatically removed. It must implement the
+ :obj:`__len__` method.
+            ignore_keys (:obj:`List[str]`, `optional`):
+ A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+ gathering predictions.
+ metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
+ An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
+ "eval_bleu" if the prefix is "eval" (default)
+
+ Returns:
+ A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
+ dictionary also contains the epoch number which comes from the training state.
+ """
+ # Use validation collator for proper evaluation setup
+ eval_collator = self.val_data_collator if hasattr(self, 'val_data_collator') else self.data_collator
+ eval_dataloader = self.get_eval_dataloader(eval_dataset, eval_collator.eval_code_type if hasattr(eval_collator, 'eval_code_type') else None)
+ start_time = time.time()
+
+ # Run evaluation loop
+ eval_loss = 0.0
+ nb_eval_steps = 0
+ ppl_lists = {}
+
+ # Initialize perplexity lists for each code type using validation collator
+ if hasattr(eval_collator, 'eval_code_type') and eval_collator.eval_code_type:
+ code_type = eval_collator.eval_code_type
+ ppl_lists[code_type] = []
+
+ self.model.eval()
+ for batch in eval_dataloader:
+ with torch.no_grad():
+ # Explicitly request perplexity computation during evaluation
+ batch['return_perplexity'] = True
+ outputs = self.model(**batch)
+ if hasattr(outputs, 'loss') and outputs.loss is not None:
+ eval_loss += outputs.loss.mean().item()
+
+ # Collect perplexity if available
+ if hasattr(outputs, 'perplexity') and outputs.perplexity is not None:
+ code_type = eval_collator.eval_code_type
+ if code_type and code_type in ppl_lists:
+ batch_ppl = outputs.perplexity.cpu().flatten().tolist()
+ ppl_lists[code_type].extend(batch_ppl)
+
+ nb_eval_steps += 1
+
+ eval_loss = eval_loss / nb_eval_steps if nb_eval_steps > 0 else 0.0
+
+ metrics = {
+ f"{metric_key_prefix}_loss": eval_loss,
+ f"{metric_key_prefix}_runtime": time.time() - start_time,
+ f"{metric_key_prefix}_samples": len(eval_dataloader.dataset) if hasattr(eval_dataloader.dataset, '__len__') else 0,
+ }
+
+ # Add perplexity metrics
+ for code_type, ppl_list in ppl_lists.items():
+ if ppl_list:
+ ppl_ar = np.array(ppl_list)
+ metrics[f"{metric_key_prefix}_ppl_{code_type}"] = float(np.median(ppl_ar))
+
+ return metrics
+
+# Evaluator from original implementation
+class Evaluator:
+ def __init__(self, model, dataset, collate_fn, device=None):
+ self.model = model
+ self.dataset = dataset
+ self.collate_fn = collate_fn
+ self.device = 'cpu' if device is None else device
+
+ def evaluate(self, code_type, ppl_type, eval_batch_size):
+ mimic_val_dataset = self.dataset
+ mimic_val_collator = self.collate_fn
+ mimic_val_collator.set_eval_code_type(code_type)
+ mimic_val_collator.set_eval_ppl_type(ppl_type)
+ dataloader = DataLoader(mimic_val_dataset,
+ batch_size=eval_batch_size,
+ num_workers=0,
+ drop_last=False,
+ collate_fn=mimic_val_collator,
+ shuffle=False,
+ pin_memory=False)
+
+ ppl_list = []
+ for batch in dataloader:
+ if batch is not None:
+ batch = self._prepare_inputs(batch)
+ with torch.no_grad():
+ outputs = self.model(**batch)
+ batch_ppl = outputs.perplexity
+ batch_ppl = batch_ppl.cpu().flatten().tolist()
+ ppl_list.extend(batch_ppl)
+ ppl_ar = np.array(ppl_list)
+ return np.median(ppl_ar)
+
+ def _prepare_inputs(self, data):
+ return type(data)(**{k: self._prepare_input(v) for k, v in data.items()})
+
+ def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
+ """
+ Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.
+ """
+ if isinstance(data, (tuple, list)):
+ return type(data)(self._prepare_input(v) for v in data)
+ elif isinstance(data, torch.Tensor):
+ kwargs = dict(device=self.device)
+ return data.to(**kwargs)
+ return data
+
+class PromptEHR(nn.Module):
+ '''
+ Initialize a PromptEHR model to leverage language models to simulate sequential patient EHR data.
+ Adapted from original PromptEHR implementation for PyHealth integration.
+ Preserves original BART-based architecture with conditional prompts.
+
+ Parameters:
+ -----------
+ dataset: SampleDataset
+ PyHealth dataset containing patient records.
+
+ code_type: list[str]
+ A list of code types that the model will learn and generate.
+ For example, `code_type=['diag','prod','med']`.
+
+ token_dict: dict[list]
+ A dictionary of new tokens (code events, e.g., ICD code) that the model needs to learn and generate.
+
+ n_num_feature: int (default=None)
+ Number of numerical patient baseline features. Notice that it assumes that the input
+ baseline features are `ALWAYS` numerical feature first. That is to say,
+ the input baseline feature = [num1, num2, .., num_n, cat1, cat2,...].
+ If not specified, the model will never include baseline features
+ for conditional generation!
+
+ cat_cardinalities: list[int]
+ The number of categories for each categorical patient baseline features.
+ The input baseline feature = [num1, num2, .., num_n, cat1, cat2,...].
+
+ device: str or list[int]
+        Should be a str like `cuda:0` or `cpu`, otherwise a list of GPU ids.
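+
+    Example:
+    --------
+    A minimal usage sketch (hypothetical code values and variable names)::
+
+        model = PromptEHR(code_type=['diag', 'prod', 'med'],
+                          token_dict={'diag': ['diag_401', 'diag_250']},
+                          epoch=1, device='cpu')
+        model.fit(train_seqdata)                        # SequencePatient-style data
+        synthetic = model.predict(train_seqdata, n_per_sample=1)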
+ '''
+ sample_config = {
+ 'num_beams': 1, # >1: beam_sample; =1: sample_gen
+ 'no_repeat_ngram_size': 1,
+ 'do_sample': True,
+ 'num_return_sequences': 1,
+ 'code_type': 'diagnosis',
+ 'top_k': 1,
+ 'temperature': 1.0,
+ 'max_length': 6,
+ }
+
+ def __init__(self,
+ code_type=None,
+ n_num_feature=None,
+ cat_cardinalities=None,
+ token_dict=None,
+ epoch=50,
+ batch_size=16,
+ eval_batch_size=16,
+ eval_step=1000,
+ learning_rate=5e-5,
+ weight_decay=1e-4,
+ num_worker=8,
+ output_dir='./promptEHR_logs',
+ device='cuda:0',
+ seed=123,
+ **kwargs
+ ) -> None:
+ super().__init__()
+
+ # Initialize tokenizers from original implementation
+ self.data_tokenizer = DataTokenizer.from_pretrained('facebook/bart-base')
+
+ # will extend vocab after pass training data
+ if code_type is not None:
+ self.data_tokenizer.update_special_token_config(code_types=code_type)
+ if token_dict is not None:
+ self.data_tokenizer.extend_vocab(token_dict)
+
+ self.model_tokenizer = None # Will be created during fit() like in original
+
+ # Debug: Print vocabulary sizes
+ bart_vocab_size = len(self.data_tokenizer) # Use data_tokenizer length, not model_tokenizer
+ print(f"BART model vocabulary size: {bart_vocab_size}")
+        for ct in (code_type or []):
+ # Use model_tokenizer.tokenizer_dict since that has the get_vocab() method
+ if self.model_tokenizer and ct in self.model_tokenizer.tokenizer_dict:
+ data_vocab_size = len(self.model_tokenizer.tokenizer_dict[ct].get_vocab())
+ print(f"{ct} data tokenizer vocab size: {data_vocab_size}")
+ if data_vocab_size > bart_vocab_size:
+ print(f"WARNING: {ct} vocab ({data_vocab_size}) exceeds BART vocab ({bart_vocab_size})")
+ else:
+ print(f"{ct}: tokenizer will be built during fit()")
+
+ self.config = {
+ 'code_type': code_type,
+ 'n_num_feature':n_num_feature,
+ 'cat_cardinalities':cat_cardinalities,
+ 'epoch':epoch,
+ 'batch_size':batch_size,
+ 'eval_batch_size':eval_batch_size,
+ 'eval_step':eval_step,
+ 'learning_rate':learning_rate,
+ 'weight_decay':weight_decay,
+ }
+ self.device_name = device
+ if isinstance(device, list):
+ self._set_visible_device(device=device)
+
+ # Add training arguments from original implementation (deferred to avoid accelerate dependency)
+ self.training_args = None
+ self._training_config = {
+ 'per_device_train_batch_size': batch_size,
+ 'per_device_eval_batch_size': eval_batch_size,
+ 'gradient_accumulation_steps': 1,
+ 'learning_rate': learning_rate,
+ 'weight_decay': weight_decay,
+ 'output_dir': output_dir,
+ 'num_train_epochs': epoch,
+ 'save_steps': eval_step,
+ 'eval_steps': eval_step,
+ 'warmup_ratio': 0.06,
+ 'max_grad_norm': 0.5,
+ 'save_total_limit': 5,
+ 'logging_steps': eval_step,
+ 'dataloader_num_workers': num_worker,
+ 'dataloader_pin_memory': True,
+ 'eval_strategy': 'steps',
+ 'metric_for_best_model': f'eval_ppl_{code_type[0]}' if code_type is not None else None,
+ 'greater_is_better': False,
+ 'eval_accumulation_steps': 10,
+ 'load_best_model_at_end': True,
+ 'logging_dir': output_dir,
+ 'overwrite_output_dir': True,
+ 'seed': seed,
+ 'no_cuda': True if self.device_name == 'cpu' else False,
+ }
+
+ # avoid dead clock when taking multiple workers for dataloaders
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
+
+ # Model will be built during fit() like in original
+ self.model = None
+
+
+ def predict(self, test_data, n_per_sample=None, n=None, sample_config=None, verbose=None):
+ '''
+ Generate synthetic records based on input real patient seq data.
+
+ Parameters
+ ----------
+ test_data: SequencePatient
+ A `SequencePatient` contains patient records where 'v' corresponds to
+ visit sequence of different events.
+
+ n: int
+ How many samples in total will be generated.
+
+ n_per_sample: int
+            How many samples to generate for each individual.
+
+ sample_config: dict
+ Configuration for sampling synthetic records, key parameters:
+            'num_beams': Number of beams in beam search; if set to `1`, beam search is deactivated;
+            'top_k': Sample from the top k candidates.
+            'temperature': Temperature to make the sampling distribution flatter or sharper.
+
+ verbose: bool
+ If print the progress bar or not.
+
+ Returns
+ -------
+ Synthetic patient records in `SequencePatient` format.
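+
+        Example
+        -------
+        A hypothetical call (values are illustrative)::
+
+            syn = model.predict(test_data, n_per_sample=2,
+                                sample_config={'top_k': 10, 'temperature': 1.2})
+            syn['visit'][0]   # synthetic visit sequence of the first generated record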
+ '''
+ if n is not None: assert isinstance(n, int), 'Input `n` should be integer.'
+ if n_per_sample is not None: assert isinstance(n_per_sample, int), 'Input `n_per_sample` should be integer.'
+        assert (n_per_sample is not None) or (n is not None), 'Either `n` or `n_per_sample` should be provided to generate.'
+ assert isinstance(self.model, BartForEHRSimulation), 'Model not found! Please fit the model or load the model from pretrained checkpoint first.'
+
+ n, n_per_sample = self._compute_n_per_sample(len(test_data), n, n_per_sample)
+
+ if sample_config is not None:
+ self.sample_config.update(sample_config)
+ print('### Sampling Config ###')
+ print(self.sample_config)
+
+ # get test data loader
+ test_dataloader = self._get_test_dataloader(test_data)
+
+ # make generation
+ outputs = self._predict_on_dataloader(test_dataloader, n, n_per_sample, verbose=verbose)
+
+ # formulate outputs to standard sequencepatient data
+ # need 'visit', 'order', 'feature', 'n_num_feature', 'cat_cardinalties'
+ visits, features, labels = [], [], []
+ for output in outputs:
+ code_types = [c for c in self.config['code_type'] if c in output]
+ num_visit = len(output[code_types[0]])
+ visit, feature = [], []
+            for visit_idx in range(num_visit):
+                visit_ = [output[code][visit_idx] for code in code_types]
+ visit.append(visit_)
+ visits.append(visit)
+ if 'x_num' in output:
+ feature.extend(output['x_num'])
+ if 'x_cat' in output:
+ feature.extend(output['x_cat'])
+ if len(feature) > 0:
+ features.append(feature)
+ if 'y' in output: labels.append(output['y'])
+
+ if len(features) > 0:
+ features = np.stack(features, 0)
+ else:
+ features = None
+
+ return_res = {
+ 'visit':visits,
+ 'feature':features,
+ 'order':self.config['code_type'],
+ 'n_num_feature':self.config['n_num_feature'],
+ 'cat_cardinalties':self.config['cat_cardinalities'],
+ 'y':labels,
+ 'voc': test_data.metadata['voc'],
+ }
+ return return_res
+
+ # fit() method from original PromptEHR implementation
+ def fit(self, train_data, val_data=None):
+ '''
+ Fit PromptEHR model on the input training EHR data.
+
+ Parameters
+ ----------
+ train_data: SequencePatient
+ A `SequencePatient` contains patient records where 'v' corresponds to
+ visit sequence of different events.
+
+        val_data: SequencePatient
+ A `SequencePatient` contains patient records where 'v' corresponds to
+ visit sequence of different events.
+ '''
+ # create tokenizers based on the input data
+ self._create_tokenizers(train_data)
+
+ # can only build model after fit
+ self._build_model()
+
+ # start training
+ self._fit(train_data=train_data,val_data=val_data)
+
+ def save_model(self, output_dir):
+ '''
+ Save the learned simulation model to the disk.
+
+ Parameters
+ ----------
+ output_dir: str
+ The dir to save the learned model.
+ '''
+ make_dir_if_not_exist(output_dir)
+ self._save_config(config=self.config, output_dir=output_dir)
+ self._save_checkpoint(output_dir=output_dir)
+ print('Save the trained model to:', output_dir)
+
+ def from_pretrained(self, input_dir='./simulation/pretrained_promptEHR'):
+ '''
+ Load pretrained PromptEHR model and make patient EHRs generation.
+ Pretrained model was learned from MIMIC-III patient sequence data.
+ '''
+ if input_dir is None:
+ input_dir = './simulation/pretrained_promptEHR'
+
+ if not os.path.exists(input_dir):
+ os.makedirs(input_dir)
+ url = PRETRAINED_MODEL_URL
+ download_pretrained(url, input_dir)
+ print(f'Download pretrained PromptEHR model, save to {input_dir}.')
+
+ print('Load pretrained PromptEHR model from', input_dir)
+ self.load_model(input_dir)
+
+ def load_model(self, checkpoint):
+ '''
+ Load model and the pre-encoded trial embeddings from the given
+ checkpoint dir.
+
+ Parameters
+ ----------
+ checkpoint: str
+ The input dir that stores the pretrained model.
+ '''
+ checkpoint_filename = check_checkpoint_file(checkpoint)
+ config_filename = check_model_config_file(checkpoint)
+ data_tokenizer_file, model_tokenizer_file = check_tokenizer_file(checkpoint)
+
+ # load config
+ self.config = self._load_config(config_filename)
+
+ # load data tokenizer and model tokenizer
+ self._load_tokenizer(data_tokenizer_file, model_tokenizer_file)
+
+ # load configuration
+ self.configuration = EHRBartConfig(self.data_tokenizer, self.model_tokenizer, n_num_feature=self.config['n_num_feature'], cat_cardinalities=self.config['cat_cardinalities'])
+ self.configuration.from_pretrained(checkpoint)
+
+ # build model
+ self._build_model()
+
+ # load checkpoint
+ state_dict = torch.load(checkpoint_filename, map_location='cpu')
+ self.load_state_dict(state_dict, strict=True)
+ print('Load the pre-trained model from:', checkpoint)
+
+ def _build_model(self):
+ """Build the BartForEHRSimulation model using the current configuration."""
+ self.model = BartForEHRSimulation(self.configuration)
+ self._setup_device()
+
+ def _setup_device(self):
+ # check if cuda is available using torch
+ if not torch.cuda.is_available():
+ warnings.warn('No GPU found, using CPU instead.')
+ self.device_name = 'cpu'
+
+ if isinstance(self.device_name, list):
+ self._set_visible_device(self.device_name)
+ self.model.cuda()
+ elif 'cuda' in self.device_name:
+ self.model.cuda()
+ else:
+ # on cpu
+ self._set_visible_device([])
+ self.model.cpu()
+
+ def _set_visible_device(self, device):
+ if len(device) > 0:
+ os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(d) for d in device])
+ else:
+ os.environ['CUDA_VISIBLE_DEVICES'] = ''
+
+ def _compute_n_per_sample(self, n_test_sample, n=None, n_per_sample=None):
+ if n_per_sample is not None:
+ n_total = n_test_sample*n_per_sample
+ if n is not None:
+ n_total = min(n_total, n)
+ return n_total, n_per_sample
+ else:
+ return n, math.ceil(n / n_test_sample)
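+
+    # e.g., 100 test patients with n=250 and no n_per_sample gives
+    # n_per_sample = ceil(250 / 100) = 3 synthetic records per real patient (illustrative).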
+
+ def _get_test_dataloader(self, dataset):
+ def _seq_patient_to_promptehr(samples):
+ post_samples = []
+ for sample in samples:
+ post_sample = {}
+ visit = sample['v']
+ post_sample.update(visit)
+
+ if ('x' in sample) and (self.config['n_num_feature'] is not None):
+ if not isinstance(sample['x'], list):
+ sample['x'] = sample['x'].tolist()
+ post_sample['x_num'] = torch.tensor(sample['x'][:self.config['n_num_feature']])
+ post_sample['x_cat'] = torch.tensor(sample['x'][self.config['n_num_feature']:], dtype=int)
+
+ if 'y' in sample:
+ post_sample['y'] = sample['y']
+
+ post_samples.append(post_sample)
+ return post_samples
+
+ dataloader = DataLoader(dataset,
+ batch_size=1, # one patient once
+ drop_last=False,
+ num_workers=0,
+ pin_memory=False,
+ shuffle=False,
+ collate_fn=_seq_patient_to_promptehr,
+ )
+ return dataloader
+
+ def _save_config(self, config, output_dir=None):
+ temp_path = os.path.join(output_dir, 'promptehr_config.json')
+ with open(temp_path, 'w', encoding='utf-8') as f:
+ f.write(
+ json.dumps(config, indent=4)
+ )
+
+ # save the data tokenizer and model tokenizer of the model
+ temp_path = os.path.join(output_dir, 'data_tokenizer.pkl')
+ with open(temp_path, 'wb') as f:
+ pickle.dump(self.data_tokenizer, f)
+
+ temp_path = os.path.join(output_dir, 'model_tokenizer.pkl')
+ with open(temp_path, 'wb') as f:
+ pickle.dump(self.model_tokenizer, f)
+
+ # save configuration
+ self.configuration.save_pretrained(output_dir)
+
+ def _load_tokenizer(self, data_tokenizer_file, model_tokenizer_file):
+ with open(data_tokenizer_file, 'rb') as f:
+ self.data_tokenizer = pickle.load(f)
+        self.data_tokenizer._in_target_context_manager = False  # fix bug when upgrading transformers to 4.23
+
+ with open(model_tokenizer_file, 'rb') as f:
+ self.model_tokenizer = pickle.load(f)
+
+ def _load_config(self, filename):
+ with open(filename, 'r') as f:
+ config = json.load(f)
+ return config
+
+ def _save_checkpoint(self,
+ epoch_id=0,
+ is_best=False,
+ output_dir=None,
+ filename='checkpoint.pth.tar'):
+
+ if epoch_id < 1:
+ filepath = os.path.join(output_dir, 'latest.' + filename)
+ elif is_best:
+ filepath = os.path.join(output_dir, 'best.' + filename)
+ else:
+ filepath = os.path.join(output_dir, str(epoch_id) + '.' + filename)
+
+ # save statedict
+ state_dict = self.state_dict()
+ torch.save(state_dict, filepath)
+
+ def _predict_on_dataloader(self, dataloader, n, n_per_sample, verbose=None):
+ total_number = 0
+ data_iterator = iter(dataloader)
+
+ if verbose:
+ pbar = tqdm(total=n)
+
+ new_record_list = []
+ while total_number < n:
+ try:
+ data = next(data_iterator)
+            except StopIteration:
+                # restart the dataloader once it is exhausted
+                data_iterator = iter(dataloader)
+ data = next(data_iterator)
+ data = data[0] # batch size is 1 when doing generation
+
+ # to device
+ device = 'cpu' if self.device_name == 'cpu' else 'cuda:0'
+ if 'x_num' in data: data['x_num'] = data['x_num'].to(device)
+ if 'x_cat' in data: data['x_cat'] = data['x_cat'].to(device)
+
+ inputs = self._prepare_input_for_generation(data)
+
+ # start generation
+ for _ in range(n_per_sample):
+ new_record = self._generation_loop(data, inputs)
+ if 'x_cat' in data:
+ new_record.update({
+ 'x_cat':data['x_cat'].cpu().numpy().tolist(),
+ })
+
+ if 'x_num' in data:
+ new_record.update({
+ 'x_num':data['x_num'].cpu().numpy().tolist(),
+ })
+
+ # add more features to new_record
+ for k,v in data.items():
+ if k not in new_record:
+ new_record[k] = v
+ new_record_list.append(new_record)
+
+ total_number += n_per_sample
+ if verbose:
+ pbar.update(n_per_sample)
+
+ if verbose:
+ pbar.close()
+ return new_record_list
+
+ def _prepare_input_for_generation(self, data):
+ def _process_span(span, code):
+ return [code+'_'+str(s) for s in span]
+
+ def _to_device(x, device):
+ for k,v in x.items():
+ x[k] = v.to(device)
+ return x
+
+ tokenizer = self.data_tokenizer
+ code_type = [k for k in data.keys() if k in self.config['code_type']]
+ num_visit = len(data[code_type[0]])
+
+ # init codes
+ init_code = random.sample(data[code_type[0]][0], 1)
+ init_code_str = _process_span(init_code, code_type[0])
+ init_codes = tokenizer(init_code_str, return_tensors='pt', add_special_tokens=False)
+ bos = torch.tensor([tokenizer.bos_token_id], device=self.model.device)
+ code_prompt_idx = tokenizer.encode(tokenizer.special_token_dict[code_type[0]], add_special_tokens=False, return_tensors='pt')
+ init_input_ids = torch.cat([bos[:,None],code_prompt_idx[:,0,None],init_codes['input_ids']], dim=-1)
+ init_input_ids = _to_device({'input_ids':init_input_ids}, self.model.device)['input_ids']
+ input_ids = init_input_ids.clone()
+ return {'input_ids':input_ids, 'init_input_ids':init_input_ids, 'num_visit':num_visit, 'init_code':init_code}
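+
+    # Prompt sketch (hypothetical code): the initial generation context is roughly
+    #   [<s>, <diag>, diag_401]
+    # i.e. the BOS token, the head tag of the first code type, and one randomly
+    # sampled code from the patient's first visit.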
+
+ def _generation_loop(self, data, inputs):
+ new_record = defaultdict(list)
+ tokenizer = self.data_tokenizer
+ special_token_dict = self.data_tokenizer.special_token_dict
+ sample_gen_kwargs = self.sample_config.copy()
+
+ input_ids_list = []
+ num_visit_code_list = []
+ first_code_flag = True
+
+ input_ids = inputs['input_ids']
+ for visit in range(inputs['num_visit']):
+ this_visit_ids_list = []
+ for code in self.config['code_type']:
+ target_list = data[code][visit]
+ sample_gen_kwargs['code_type'] = code
+ num_code = len(target_list)
+ if num_code > 20:
+ num_code = min(num_code, 20)
+ target_list = np.random.choice(target_list, num_code, replace=False).tolist()
+
+ # random select part of codes from target list
+ target_ar = np.array(target_list)
+ sub_code = target_ar[np.random.binomial(1, 0.5, num_code).astype(bool)]
+ code_prompt_idx = [special_token_dict[code][0]] + sub_code.tolist() + [special_token_dict[code][1]]
+ code_prompt_idx = tokenizer.encode(code_prompt_idx, add_special_tokens=False, return_tensors='pt')
+ code_prompt_idx = code_prompt_idx.to(self.model.device)
+
+ if num_code == 0:
+ if first_code_flag:
+ new_next_tokens = code_prompt_idx[:,-1,None]
+ first_code_flag = False
+ else:
+ new_next_tokens = code_prompt_idx
+
+ this_visit_ids_list.append(new_next_tokens)
+ input_ids = torch.cat([input_ids, new_next_tokens], dim=-1)
+ new_record[code].append([])
+
+ else:
+ sample_gen_kwargs['max_length'] = num_code+2
+
+ # do conditional generation
+ if 'x_cat' in data:
+ sample_gen_kwargs['x_cat'] = data['x_cat']
+ if 'x_num' in data:
+ sample_gen_kwargs['x_num'] = data['x_num']
+
+ new_next_tokens = self.model.generate(input_ids, **sample_gen_kwargs)
+
+ # randomly pick / rm sub code overlap
+ new_next_tokens = new_next_tokens[:,1:-1]
+ new_next_tokens = np.setdiff1d(new_next_tokens[0].cpu(), code_prompt_idx[0].cpu())
+
+                    # sample indices for the remaining codes; fall back to sampling with
+                    # replacement when the model generated fewer new codes than needed
+                    n_needed = num_code - len(sub_code)
+                    if len(new_next_tokens) == 0 or n_needed <= 0:
+                        new_sub_idxs = np.array([], dtype=int)
+                    elif n_needed > len(new_next_tokens):
+                        new_sub_idxs = np.unique(np.random.choice(np.arange(len(new_next_tokens)), n_needed, replace=True))
+                    else:
+                        new_sub_idxs = np.unique(np.random.choice(np.arange(len(new_next_tokens)), n_needed, replace=False))
+                    new_next_tokens = torch.tensor(new_next_tokens[None, new_sub_idxs]).to(code_prompt_idx.device)
+
+ # append to the synthetic record dict
+ code_str_list = tokenizer.batch_decode(new_next_tokens)[0]
+
+ # remove special tokens ahead of original code event
+ # e.g., `diag_384` -> `384`
+ code_str_list = code_str_list.replace(code+'_','')
+ code_str_list = code_str_list.split()
+ code_str_list = [int(c) for c in code_str_list+sub_code.tolist()]
+ new_record[code].append(list(set(code_str_list)))
+
+ if first_code_flag:
+ new_next_tokens = torch.cat([new_next_tokens, code_prompt_idx[:,1:]], dim=-1)
+ first_code_flag = False
+                    else:
+                        # wrap the generated codes with the modality prompt tokens
+                        new_next_tokens = torch.cat([code_prompt_idx[:,:-1], new_next_tokens, code_prompt_idx[:,-1,None]], dim=-1)
+
+ if visit > 1:
+ # check input length
+ cur_len = input_ids.shape[1] + new_next_tokens.shape[1]
+ while cur_len >= tokenizer.model_max_length:
+                        print(f'{cur_len} reaches model max length {tokenizer.model_max_length}, truncating the earliest visit.')
+ input_ids_list = input_ids_list[1:]
+ num_visit_code_list = num_visit_code_list[1:]
+ input_ids = torch.cat(input_ids_list,dim=-1)
+ cur_len = input_ids.shape[1] + new_next_tokens.shape[1]
+
+ # concat
+ this_visit_ids_list.append(new_next_tokens)
+ input_ids = torch.cat([input_ids, new_next_tokens], dim=-1)
+
+ # after one visit, add eos token id
+ eos = torch.tensor([tokenizer.eos_token_id], device=self.model.device)
+ input_ids = torch.cat([input_ids, eos[:,None]], dim=-1)
+ this_visit_ids = torch.cat(this_visit_ids_list, dim=-1)
+ this_visit_ids = torch.cat([this_visit_ids, eos[:,None]], dim=-1)
+ if visit == 0: this_visit_ids = torch.cat([inputs['init_input_ids'], this_visit_ids], dim=-1)
+ num_visit_code_list.append(this_visit_ids.shape[-1])
+ input_ids_list.append(this_visit_ids)
+
+ # add init code
+ new_record[self.config['code_type'][0]][0] += inputs['init_code']
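+        # Sketch of the returned record (illustrative only; code-type keys and values
+        # depend on self.config['code_type'] and the generated tokens):
+        #   {'diag': [[401, 250], [38, 585]], 'med': [[12], []], ...}  # one code list per visit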
+ return new_record
+
+ # _create_tokenizers() method from original PromptEHR implementation
+ def _create_tokenizers(self, train_data):
+ # update data_tokenizer first
+        def _collate_fn(inputs):
+            # flatten each record's visits into unique, code-type-prefixed strings
+            # so they can be added to the data tokenizer's code vocabulary
+            outputs = defaultdict(list)
+            for record in inputs:
+                visit = record['v']
+                for k,v in visit.items():
+                    code_list = sum(v,[])
+                    code_list = [k+'_'+str(c) for c in list(set(code_list))]
+                    outputs[k].extend(code_list)
+            return outputs
+ dataloader = DataLoader(train_data, collate_fn=_collate_fn, batch_size=512, shuffle=False)
+ for batch in dataloader:
+ for k,v in batch.items():
+ unq_codes = list(set(v))
+ self.data_tokenizer.add_token_to_code_vocab(unq_codes, k)
+
+ self.model_tokenizer = ModelTokenizer(self.data_tokenizer)
+ self.configuration = EHRBartConfig(self.data_tokenizer, self.model_tokenizer, n_num_feature=self.config['n_num_feature'], cat_cardinalities=self.config['cat_cardinalities'])
+ self.data_tokenizer.update_special_token_config(code_types=self.config['code_type'])
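+        # Vocabulary convention used above (assuming a 'diag' code type): a visit
+        # {'diag': [[250, 401]]} contributes the strings 'diag_250' and 'diag_401'
+        # to the diagnosis code vocabulary of the data tokenizer.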
+
+ def _build_model(self):
+ """Build the BartForEHRSimulation model using the current configuration."""
+ self.model = BartForEHRSimulation(self.configuration)
+ self._setup_device()
+
+ def _fit(self, train_data, val_data):
+ # Create TrainingArguments only when training is actually called
+ if self.training_args is None:
+ self.training_args = TrainingArguments(**self._training_config)
+
+ mimic_train_collator = FullMimicDataCollator(self.data_tokenizer,
+ code_types=self.config['code_type'],
+ n_num_feature=self.config['n_num_feature'],
+ max_train_batch_size=self.config['batch_size'], mode='train')
+
+ # Create validation collator with eval parameters set during initialization
+ eval_code_type = self.config['code_type'][0] if self.config['code_type'] else None
+ mimic_val_collator = FullMimicDataCollator(self.data_tokenizer,
+ code_types=self.config['code_type'],
+ n_num_feature=self.config['n_num_feature'],
+ mode='val',
+ eval_code_type=eval_code_type,
+ eval_ppl_type='span')
+
+ trainer = PromptEHRTrainer(
+ model=self.model,
+ args=self.training_args,
+ train_dataset=train_data,
+ data_collator=mimic_train_collator,
+ eval_dataset=val_data,
+ val_data_collator=mimic_val_collator,
+ )
+ try:
+ trainer.train()
+ except Exception as e:
+ print(f"Training failed with error: {e}")
+ import traceback
+ traceback.print_exc()
+            raise
+
+ # evaluate() method from original PromptEHR implementation
+ def evaluate(self, test_data):
+        '''
+        Evaluate the trained PromptEHR model on the input data by computing the
+        perplexity for each type of code.
+
+        Parameters
+        ----------
+        test_data: PatientSequence
+            Standard sequential patient records in `PatientSequence` format.
+        '''
+ self.model.eval()
+ self.eval()
+
+ collator = FullMimicDataCollator(
+ self.data_tokenizer,
+ code_types=self.config['code_type'],
+ n_num_feature=self.config['n_num_feature'],
+ mode='test',
+ drop_feature=False
+ )
+
+ evaluator = Evaluator(
+ self.model,
+ test_data,
+ collator,
+ device='cpu' if self.device_name == 'cpu' else 'cuda:0',
+ )
+
+        code_types = self.config['code_type']
+        ppl_types = ['tpl','spl']  # temporal and spatial perplexity
+ for code_type in code_types:
+ for ppl_type in ppl_types:
+ ppl = evaluator.evaluate(code_type, ppl_type, eval_batch_size=self.config['eval_batch_size'])
+ print(f'code: {code_type}, ppl_type: {ppl_type}, value: {ppl}')
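+        # Minimal usage sketch (the fit() wrapper and the variable names are
+        # illustrative assumptions, not part of this diff):
+        #   model.fit(train_data, val_data)
+        #   model.evaluate(test_data)   # prints tpl/spl perplexity per code type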
+
+ # update_config() method from original PromptEHR implementation
+ def update_config(self, config):
+ '''
+ Update the configuration of the model.
+
+ Parameters
+ ----------
+ config: dict
+ The configuration of the model.
+ Refer to the `config` in `__init__` for more details.
+ '''
+ self.config.update(config)
+
+ # update training args
+ train_args = copy.deepcopy(config)
+ for k, v in config.items():
+ if k in config_to_train_args:
+ train_args[config_to_train_args[k]] = v
+ train_args.pop(k)
+
+ for k,v in train_args.items():
+ if hasattr(self.training_args, k):
+ setattr(self.training_args, k, v)
+
+        # important when you train the model with different datasets
+        code_type = self.config['code_type']
+        if self.training_args is not None:
+            self.training_args.metric_for_best_model = \
+                f'eval_ppl_{code_type[0]}' if code_type is not None else None
+
+ print('### Model Config ###')
+ print(self.config)
+
+ print('### Training Args ###')
+ print(self.training_args)
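+        # Example (both keys already appear in self.config elsewhere in this class):
+        #   model.update_config({'batch_size': 16, 'eval_batch_size': 32})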
+
+# Utility functions from original implementation
+def download_pretrained(url, output_dir):
+    import wget
+    import zipfile
+    filename = wget.download(url=url, out=output_dir)
+    with zipfile.ZipFile(filename, 'r') as zipf:
+        zipf.extractall(output_dir)
+
+def make_dir_if_not_exist(path):
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+def check_checkpoint_file(input_dir, suffix='pth.tar'):
+    '''
+    Check whether `input_dir` is a directory or a path to the checkpoint file.
+    If it is a directory, find the only checkpoint file with the given suffix under it.
+
+    Parameters
+    ----------
+    input_dir: str
+        The input path to the pretrained model.
+    suffix: 'pth.tar' or 'model'
+        The checkpoint file suffix;
+        If 'pth.tar', the saved model is a torch model.
+        If 'model', the saved model is a scikit-learn based model.
+    '''
+ suffix = '.' + suffix
+ if input_dir.endswith(suffix):
+ return input_dir
+
+ ckpt_list = glob.glob(os.path.join(input_dir, '*'+suffix))
+    assert len(ckpt_list) <= 1, f'Found more than one checkpoint under the dir {input_dir}, please specify the one to load.'
+    assert len(ckpt_list) > 0, f'Did not find any checkpoint under the dir {input_dir}.'
+ return ckpt_list[0]
+
+def check_model_config_file(input_dir):
+    '''
+    Check whether `input_dir` is a directory or a path to the model config `.json` file.
+    If it is a directory, look for `promptehr_config.json` under it, otherwise fall back
+    to the only `.json` file found.
+
+    Parameters
+    ----------
+    input_dir: str
+        The input path to the pretrained model.
+    '''
+ if input_dir.endswith('.json'):
+ return input_dir
+
+ if not os.path.isdir(input_dir):
+ # if the input_dir is the given checkpoint model path,
+ # we need to find the config file under the same dir.
+ input_dir = os.path.dirname(input_dir)
+
+ ckpt_list = glob.glob(os.path.join(input_dir, '*.json'))
+
+ if len(ckpt_list) == 0:
+ return None
+
+    # look for promptehr_config.json under this input_dir
+    model_config_name = [config for config in ckpt_list if 'promptehr_config.json' in config]
+    if len(model_config_name) == 1:
+        return model_config_name[0]
+
+    # if no promptehr_config.json is found, fall back to the only .json file.
+    assert len(ckpt_list) <= 1, f'Found more than one config .json file under the dir {input_dir}.'
+ return ckpt_list[0]
+
+def check_tokenizer_file(input_dir):
+ return os.path.join(input_dir,'data_tokenizer.pkl'), os.path.join(input_dir,'model_tokenizer.pkl')
\ No newline at end of file
diff --git a/run_promptehr_synthetic_generation.slurm b/run_promptehr_synthetic_generation.slurm
new file mode 100644
index 000000000..5c90c3c1d
--- /dev/null
+++ b/run_promptehr_synthetic_generation.slurm
@@ -0,0 +1,95 @@
+#!/bin/bash
+#SBATCH --account=jalenj4-ic
+#SBATCH --job-name=promptehr_synthetic
+#SBATCH --output=logs/promptehr_synthetic_%j.out
+#SBATCH --error=logs/promptehr_synthetic_%j.err
+#SBATCH --partition=IllinoisComputes-GPU
+#SBATCH --gres=gpu:1
+#SBATCH --cpus-per-task=8
+#SBATCH --mem=64G
+#SBATCH --time=24:00:00
+
+cd "$SLURM_SUBMIT_DIR"
+
+cd /u/jalenj4/PyHealth
+source pyhealth/bin/activate
+export PYTHONPATH=/u/jalenj4/PyHealth:$PYTHONPATH
+
+export CUDA_LAUNCH_BLOCKING=1
+export TORCH_USE_CUDA_DSA=1
+mkdir -p logs promptehr_synthetic_output
+
+pip install 'accelerate>=0.26.0' --quiet
+
+echo "Starting PromptEHR pipeline"
+export TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
+export OUTPUT_DIR="./promptehr_synthetic_output_${TIMESTAMP}"
+export MIMIC_ROOT="./data_files"
+export MODE="train_and_generate"
+export N_SYNTHETIC="1000"
+export EPOCHS="1"
+export BATCH_SIZE="8"
+export LEARNING_RATE="1e-4"
+export MAX_VISITS="10"
+export MIN_VISITS="2"
+
+
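+# To run a different mode, edit the exports above; generate_only additionally
+# requires a trained model path, e.g. (path shown is illustrative):
+#   export MODE="generate_only"
+#   export MODEL_PATH="./trained_promptehr"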
+if [ "$MODE" = "preprocess_only" ]; then
+ python examples/promptehr_mimic3_synthetic_generation.py \
+ --mode preprocess_only \
+ --mimic_root $MIMIC_ROOT \
+ --output_dir $OUTPUT_DIR \
+ --max_visits $MAX_VISITS \
+ --min_visits $MIN_VISITS \
+ --code_vocab_threshold 5 \
+ --train_ratio 0.8
+
+elif [ "$MODE" = "generate_only" ]; then
+ if [ -z "$MODEL_PATH" ]; then
+ echo "ERROR: MODEL_PATH required for generate_only mode"
+ exit 1
+ fi
+ python examples/promptehr_mimic3_synthetic_generation.py \
+ --mode generate_only \
+ --model_path $MODEL_PATH \
+ --output_dir $OUTPUT_DIR \
+ --n_synthetic $N_SYNTHETIC \
+ --temperature 1.0
+
+else
+ python examples/promptehr_mimic3_synthetic_generation.py \
+ --mode train_and_generate \
+ --mimic_root $MIMIC_ROOT \
+ --output_dir $OUTPUT_DIR \
+ --max_visits $MAX_VISITS \
+ --min_visits $MIN_VISITS \
+ --epochs $EPOCHS \
+ --batch_size $BATCH_SIZE \
+ --learning_rate $LEARNING_RATE \
+ --device cuda \
+ --n_synthetic $N_SYNTHETIC \
+ --temperature 1.0 \
+ --code_vocab_threshold 5 \
+ --train_ratio 0.8 \
+ --include_procedures \
+ --include_medications
+fi
+
+if [ "$MODE" = "preprocess_only" ]; then
+ echo "Preprocessing completed"
+ echo "Data in: $OUTPUT_DIR"
+elif [ "$MODE" = "generate_only" ]; then
+ echo "Generation completed"
+ echo "Files in: $OUTPUT_DIR"
+else
+ echo "Pipeline completed"
+ echo "Output: $OUTPUT_DIR"
+ echo "Files:"
+ echo " $OUTPUT_DIR/synthetic/promptehr_synthetic_binary.npy"
+ echo " $OUTPUT_DIR/synthetic/promptehr_synthetic_raw.pkl"
+ echo " $OUTPUT_DIR/synthetic/code_mapping.pkl"
+ echo " $OUTPUT_DIR/synthetic/generation_stats.json"
+ echo " $OUTPUT_DIR/synthetic/synthetic_patients_summary.csv"
+ echo " $OUTPUT_DIR/synthetic/synthetic_code_frequencies.csv"
+ echo " $OUTPUT_DIR/synthetic/synthetic_patient_diagnoses_sparse.csv"
+fi
\ No newline at end of file