diff --git a/examples/synthetic_data_generation_mimic3_medgan.py b/examples/synthetic_data_generation_mimic3_medgan.py
new file mode 100644
index 000000000..f9a5c2e59
--- /dev/null
+++ b/examples/synthetic_data_generation_mimic3_medgan.py
@@ -0,0 +1,390 @@
+"""
+Synthetic data generation using MedGAN on MIMIC-III data.
+
+This example demonstrates how to train MedGAN to generate synthetic ICD-9 matrices
+from MIMIC-III data, following PyHealth conventions.
+"""
+
+import argparse
+import os
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+
+from pyhealth.datasets import MIMIC3Dataset
+from pyhealth.datasets.icd9_matrix import create_icd9_matrix, ICD9MatrixDataset
+from pyhealth.models.generators.medgan import MedGAN
+
+"""
+Example usage:
+python examples/synthetic_data_generation_mimic3_medgan.py --autoencoder_epochs 5 --gan_epochs 10 --batch_size 16
+"""
+
+
+def train_medgan(model, dataloader, n_epochs, device, save_dir, lr=0.001, weight_decay=0.0001, b1=0.5, b2=0.9):
+    """
+    Train a MedGAN model using the original synthEHRella training procedure.
+
+    Args:
+        model: MedGAN model
+        dataloader: DataLoader for training data
+        n_epochs: Number of training epochs
+        device: Device to train on
+        save_dir: Directory to save checkpoints
+        lr: Learning rate
+        weight_decay: Weight decay for regularization
+        b1: Beta1 for Adam optimizer
+        b2: Beta2 for Adam optimizer
+
+    Returns:
+        loss_history: Dictionary containing loss history
+    """
+
+    def generator_loss(y_fake):
+        """Original synthEHRella generator loss."""
+        # standard GAN generator loss - push fake samples to be classified as real
+        return -torch.mean(torch.log(y_fake + 1e-12))
+
+    def discriminator_loss(outputs, labels):
+        """Original synthEHRella discriminator loss."""
+        loss = -torch.mean(labels * torch.log(outputs + 1e-12)) \
+               - torch.mean((1 - labels) * torch.log(1. - outputs + 1e-12))
+        return loss
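+
+    # Note: aside from the 1e-12 clamp (which guards against log(0)), these are
+    # the standard binary cross-entropy GAN losses. A minimal equivalence sketch,
+    # for clarity only -- not used in training:
+    #
+    #   import torch.nn.functional as F
+    #   outputs, labels = torch.tensor([0.9, 0.2]), torch.ones(2)
+    #   discriminator_loss(outputs, labels)  # ~= F.binary_cross_entropy(outputs, labels)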
+
+    optimizer_g = torch.optim.Adam([
+        {'params': model.generator.parameters()},
+        {'params': model.autoencoder.decoder.parameters(), 'lr': lr * 0.1}
+    ], lr=lr, betas=(b1, b2), weight_decay=weight_decay)
+
+    optimizer_d = torch.optim.Adam(model.discriminator.parameters(),
+                                   lr=lr * 0.1, betas=(b1, b2), weight_decay=weight_decay)
+
+    os.makedirs(save_dir, exist_ok=True)  # checkpoints are written here below
+
+    g_losses = []
+    d_losses = []
+
+    print("="*60)
+    print("Epoch | D_loss | G_loss | Progress")
+    print("="*60)
+
+    for epoch in range(n_epochs):
+        epoch_g_loss = 0.0
+        epoch_d_loss = 0.0
+        num_batches = 0
+
+        for i, real_data in enumerate(dataloader):
+            real_data = real_data.to(device)
+            batch_size = real_data.size(0)
+
+            valid = torch.ones(batch_size).to(device)   # 1D tensor of real labels
+            fake = torch.zeros(batch_size).to(device)   # 1D tensor of fake labels
+
+            z = torch.randn(batch_size, model.latent_dim).to(device)
+
+            # ---------------------
+            # Train Generator
+            # ---------------------
+
+            # Freeze discriminator parameters so the generator update cannot touch them
+            for p in model.discriminator.parameters():
+                p.requires_grad = False
+
+            # generate fake samples (latent vectors decoded into data space)
+            fake_samples = model.generator(z)
+            fake_samples = model.autoencoder.decode(fake_samples)
+
+            # generator loss using the original medgan loss function
+            fake_output = model.discriminator(fake_samples).view(-1)
+            g_loss = generator_loss(fake_output)
+
+            optimizer_g.zero_grad()
+            g_loss.backward()
+            optimizer_g.step()
+
+            # ---------------------
+            # Train Discriminator
+            # ---------------------
+
+            # Re-enable discriminator gradients
+            for p in model.discriminator.parameters():
+                p.requires_grad = True
+
+            optimizer_d.zero_grad()
+
+            # Real samples
+            real_output = model.discriminator(real_data).view(-1)
+            real_loss = discriminator_loss(real_output, valid)
+            real_loss.backward()
+
+            # Fake samples (detached so no gradient flows into the generator)
+            fake_output = model.discriminator(fake_samples.detach()).view(-1)
+            fake_loss = discriminator_loss(fake_output, fake)
+            fake_loss.backward()
+
+            # Total discriminator loss (for logging)
+            d_loss = (real_loss + fake_loss) / 2
+
+            optimizer_d.step()
+
+            # Track losses
+            epoch_g_loss += g_loss.item()
+            epoch_d_loss += d_loss.item()
+            num_batches += 1
+
+        # calculate average losses
+        avg_g_loss = epoch_g_loss / num_batches
+        avg_d_loss = epoch_d_loss / num_batches
+
+        # store losses for tracking
+        g_losses.append(avg_g_loss)
+        d_losses.append(avg_d_loss)
+
+        progress = (epoch + 1) / n_epochs * 100
+        print(f"{epoch+1:5d} | {avg_d_loss:.4f} | {avg_g_loss:.4f} | {progress:5.1f}%")
+
+        # save a checkpoint every 50 epochs
+        if (epoch + 1) % 50 == 0:
+            checkpoint_path = os.path.join(save_dir, f"medgan_epoch_{epoch+1}.pth")
+            torch.save({
+                'epoch': epoch + 1,
+                'generator_state_dict': model.generator.state_dict(),
+                'discriminator_state_dict': model.discriminator.state_dict(),
+                'autoencoder_state_dict': model.autoencoder.state_dict(),
+                'optimizer_g_state_dict': optimizer_g.state_dict(),
+                'optimizer_d_state_dict': optimizer_d.state_dict(),
+                'g_losses': g_losses,
+                'd_losses': d_losses,
+            }, checkpoint_path)
+            print(f"Checkpoint saved to {checkpoint_path}")
+
+    print("="*60)
+    print("GAN Training Completed!")
+    print(f"Final G_loss: {g_losses[-1]:.4f}")
+    print(f"Final D_loss: {d_losses[-1]:.4f}")
+
+    # save loss history (reload later with np.load(..., allow_pickle=True).item())
+    loss_history = {
+        'g_losses': g_losses,
+        'd_losses': d_losses,
+    }
+    np.save(os.path.join(save_dir, "loss_history.npy"), loss_history)
+
+    return loss_history
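+
+# A minimal standalone sketch of driving train_medgan outside of main() below.
+# The matrix path is hypothetical; any (num_patients, num_codes) binary float32
+# array works:
+#
+#   matrix = np.load("icd9_matrix.npy")  # hypothetical path
+#   model = MedGAN.from_binary_matrix(matrix)
+#   loader = DataLoader(ICD9MatrixDataset(matrix), batch_size=32, shuffle=True)
+#   model.pretrain_autoencoder(loader, epochs=5, device=torch.device("cpu"))
+#   train_medgan(model, loader, n_epochs=10, device=torch.device("cpu"), save_dir="./ckpts")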
MIMIC-III data") + parser.add_argument("--output_path", type=str, default="./medgan_results", help="Output directory") + parser.add_argument("--autoencoder_epochs", type=int, default=100, help="Autoencoder pretraining epochs") + parser.add_argument("--gan_epochs", type=int, default=1000, help="GAN training epochs") + parser.add_argument("--latent_dim", type=int, default=128, help="Latent dimension") + parser.add_argument("--hidden_dim", type=int, default=128, help="Hidden dimension") + parser.add_argument("--batch_size", type=int, default=128, help="Batch size") + parser.add_argument("--lr", type=float, default=0.001, help="adam: learning rate") + parser.add_argument("--weight_decay", type=float, default=0.0001, help="l2 regularization") + parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient") + parser.add_argument("--b2", type=float, default=0.9, help="adam: decay of second order momentum of gradient") + parser.add_argument("--save_dir", type=str, default="medgan_results", help="directory to save results") + args = parser.parse_args() + + # setup + os.makedirs(args.output_path, exist_ok=True) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + # load MIMIC-III data + print("Loading MIMIC-III data") + dataset = MIMIC3Dataset(root=args.data_path, tables=["DIAGNOSES_ICD"]) + + # create ICD-9 matrix using utility function + print("Creating ICD-9 matrix") + icd9_matrix, icd9_types = create_icd9_matrix(dataset, args.output_path) + print(f"ICD-9 matrix shape: {icd9_matrix.shape}") + + + # initialize MedGAN model + print("Initializing MedGAN model...") + model = MedGAN.from_binary_matrix( + binary_matrix=icd9_matrix, + latent_dim=args.latent_dim, + autoencoder_hidden_dim=args.hidden_dim, + discriminator_hidden_dim=args.hidden_dim, + minibatch_averaging=True + ) + + # device stuff + model = model.to(device) + model.autoencoder = model.autoencoder.to(device) + model.generator = model.generator.to(device) + model.discriminator = model.discriminator.to(device) + + # make a dataloader + print("Creating dataloader...") + icd9_matrix_dataset = ICD9MatrixDataset(icd9_matrix) + dataloader = DataLoader( + icd9_matrix_dataset, + batch_size=args.batch_size, + shuffle=True + ) + + # autoencoder pretraining + print("Pretraining autoencoder...") + autoencoder_losses = model.pretrain_autoencoder( + dataloader=dataloader, + epochs=args.autoencoder_epochs, + lr=args.lr, + device=device + ) + + # train GAN + print("Training GAN...") + gan_loss_history = train_medgan( + model=model, + dataloader=dataloader, + n_epochs=args.gan_epochs, + device=device, + save_dir=args.save_dir, + lr=args.lr, + weight_decay=args.weight_decay, + b1=args.b1, + b2=args.b2 + ) + + # generate synthetic data + print("Generating synthetic data...") + with torch.no_grad(): + synthetic_data = model.generate(1000, device) + binary_data = model.sample_transform(synthetic_data, threshold=0.5) + + synthetic_matrix = binary_data.cpu().numpy() + + # save + print("Saving results...") + torch.save({ + 'model_config': { + 'latent_dim': args.latent_dim, + 'hidden_dim': args.hidden_dim, + 'autoencoder_hidden_dim': args.hidden_dim, + 'discriminator_hidden_dim': args.hidden_dim, + 'input_dim': icd9_matrix.shape[1], + }, + 'generator_state_dict': model.generator.state_dict(), + 'discriminator_state_dict': model.discriminator.state_dict(), + 'autoencoder_state_dict': model.autoencoder.state_dict(), + }, os.path.join(args.output_path, 
"medgan_final.pth")) + + np.save(os.path.join(args.output_path, "synthetic_binary_matrix.npy"), synthetic_matrix) + + # save loss histories + loss_history = { + 'autoencoder_losses': autoencoder_losses, + 'gan_losses': gan_loss_history, + } + np.save(os.path.join(args.output_path, "loss_history.npy"), loss_history) + + # print final stats + print("\n" + "="*50) + print("TRAINING COMPLETED") + print("="*50) + print(f"Real data shape: {icd9_matrix.shape}") + print(f"Real data mean activation: {icd9_matrix.mean():.4f}") + print(f"Real data sparsity: {(icd9_matrix == 0).mean():.4f}") + print(f"Synthetic data shape: {synthetic_matrix.shape}") + print(f"Synthetic data mean activation: {synthetic_matrix.mean():.4f}") + print(f"Synthetic data sparsity: {(synthetic_matrix == 0).mean():.4f}") + print(f"Results saved to: {args.output_path}") + print("="*50) + + print("\nGenerated synthetic data in original MIMIC3 ICD-9 format.") + + +if __name__ == "__main__": + main() + +""" +Slurm script example: +#!/bin/bash +#SBATCH --account=jalenj4-ic +#SBATCH --job-name=medgan_pyhealth +#SBATCH --output=logs/medgan_pyhealth_%j.out +#SBATCH --error=logs/medgan_pyhealth_%j.err +#SBATCH --partition=IllinoisComputes-GPU # Change to appropriate partition +#SBATCH --gres=gpu:1 # Request 1 GPU +#SBATCH --cpus-per-task=4 +#SBATCH --mem=32G +#SBATCH --time=12:00:00 + +# Change to the directory where you submitted the job +cd "$SLURM_SUBMIT_DIR" +source pyhealth/bin/activate +export PYTHONPATH=/u/jalenj4/PyHealth/PyHealth:$PYTHONPATH + +# Print useful Slurm environment variables for debugging +echo "SLURM_JOB_ID: $SLURM_JOB_ID" +echo "SLURM_JOB_NODELIST: $SLURM_JOB_NODELIST" +echo "SLURM_NTASKS: $SLURM_NTASKS" +echo "SLURM_CPUS_ON_NODE: $SLURM_CPUS_ON_NODE" +echo "SLURM_GPUS_ON_NODE: $SLURM_GPUS_ON_NODE" +echo "SLURM_GPUS: $SLURM_GPUS" +echo "CUDA_VISIBLE_DEVICES: $CUDA_VISIBLE_DEVICES" + +# Optional: check what GPU(s) is/are actually visible +echo "Running nvidia-smi to confirm GPU availability:" +nvidia-smi + +# Load modules or activate environment +# module load python/3.10 +# module load cuda/11.7 +# conda activate pyhealth + +# Create output directories +mkdir -p logs +mkdir -p medgan_results + +# Set parameters (matching original synthEHRella defaults) +export AUTOENCODER_EPOCHS=100 +export GAN_EPOCHS=1000 +export BATCH_SIZE=128 +export LATENT_DIM=128 +export HIDDEN_DIM=128 +export NUM_SAMPLES=1000 +export LEARNING_RATE=0.001 +export WEIGHT_DECAY=0.0001 +export BETA1=0.5 +export BETA2=0.9 + +echo "Starting PyHealth MedGAN training with parameters:" +echo " Autoencoder epochs: $AUTOENCODER_EPOCHS" +echo " GAN epochs: $GAN_EPOCHS" +echo " Batch size: $BATCH_SIZE" +echo " Latent dimension: $LATENT_DIM" +echo " Hidden dimension: $HIDDEN_DIM" +echo " Number of synthetic samples: $NUM_SAMPLES" +echo " Learning rate: $LEARNING_RATE" +echo " Weight decay: $WEIGHT_DECAY" +echo " Beta1: $BETA1" +echo " Beta2: $BETA2" + +# Run the comprehensive PyHealth MedGAN script +python examples/synthetic_data_generation_mimic3_medgan.py \ + --data_path ./data_files \ + --output_path ./medgan_results \ + --autoencoder_epochs $AUTOENCODER_EPOCHS \ + --gan_epochs $GAN_EPOCHS \ + --batch_size $BATCH_SIZE \ + --latent_dim $LATENT_DIM \ + --hidden_dim $HIDDEN_DIM \ + --lr $LEARNING_RATE \ + --weight_decay $WEIGHT_DECAY \ + --b1 $BETA1 \ + --b2 $BETA2 \ + +echo "PyHealth MedGAN training completed!" 
+echo "Results saved to: ./medgan_results/" +echo "Check the following files:" +echo " - synthetic_binary_matrix.npy: Synthetic data in original MIMIC3 ICD-9 format" +echo " - medgan_final.pth: Trained model" +echo " - loss_history.npy: Training loss history" +""" \ No newline at end of file diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index f0e4f53e7..d2a107008 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -36,6 +36,7 @@ def __init__(self, *args, **kwargs): from .mimic4 import MIMIC4CXRDataset, MIMIC4Dataset, MIMIC4EHRDataset, MIMIC4NoteDataset from .mimicextract import MIMICExtractDataset from .omop import OMOPDataset +from .icd9_matrix import create_icd9_matrix, ICD9MatrixDataset from .sample_dataset import SampleDataset from .shhs import SHHSDataset from .sleepedf import SleepEDFDataset diff --git a/pyhealth/datasets/icd9_matrix.py b/pyhealth/datasets/icd9_matrix.py new file mode 100644 index 000000000..a949cb58c --- /dev/null +++ b/pyhealth/datasets/icd9_matrix.py @@ -0,0 +1,101 @@ +""" +Simple ICD-9 matrix utilities for MIMIC-III data. +No conversions - keeps original ICD-9 format. +""" + +import os +import numpy as np +import torch +from typing import Dict, Tuple +from torch.utils.data import Dataset + +from .base_dataset import BaseDataset + + +def convert_to_3digit_icd9(dxStr): + """Convert ICD-9 to 3-digit format""" + if dxStr.startswith('E'): + if len(dxStr) > 4: + return dxStr[:4] + else: + return dxStr + else: + if len(dxStr) > 3: + return dxStr[:3] + else: + return dxStr + + +def create_icd9_matrix(dataset: BaseDataset, output_path: str = None) -> Tuple[np.ndarray, Dict[str, int]]: + """ + Create ICD-9 binary matrix from MIMIC3Dataset. + + Args: + dataset: MIMIC3Dataset instance + output_path: Optional path to save matrix + + Returns: + Tuple of (matrix, icd9_code_to_index_mapping) + """ + print("Processing ICD-9 codes...") + + # Collect all ICD codes from patients + icd_codes = set() + patient_icd_map = {} + + for patient in dataset.iter_patients(): + patient_codes = set() + + # Get events from the diagnoses_icd table + events = patient.get_events(event_type="diagnoses_icd") + + for event in events: + # Check if this is an ICD-9 diagnosis + if "icd9_code" in event.attr_dict and event.attr_dict["icd9_code"]: + # Use 3-digit truncation + code = convert_to_3digit_icd9(event.attr_dict["icd9_code"]) + icd_codes.add(code) + patient_codes.add(code) + + if patient_codes: + patient_icd_map[patient.patient_id] = patient_codes + + # Create ICD-9 matrix + icd_codes_list = sorted(list(icd_codes)) + icd9_types = {code: idx for idx, code in enumerate(icd_codes_list)} + + num_patients = len(patient_icd_map) + num_codes = len(icd9_types) + + print(f"Creating ICD-9 matrix: {num_patients} patients x {num_codes} codes") + + icd9_matrix = np.zeros((num_patients, num_codes), dtype=np.float32) + + for i, (patient_id, codes) in enumerate(patient_icd_map.items()): + for code in codes: + if code in icd9_types: + icd9_matrix[i, icd9_types[code]] = 1.0 + + # Save matrix if output path provided + if output_path: + os.makedirs(output_path, exist_ok=True) + np.save(os.path.join(output_path, "icd9_matrix.npy"), icd9_matrix) + print(f"Saved ICD-9 matrix to {output_path}/icd9_matrix.npy") + + print(f"Final matrix shape: {icd9_matrix.shape}") + print(f"Sparsity: {1.0 - np.mean(icd9_matrix):.3f}") + + return icd9_matrix, icd9_types + + +class ICD9MatrixDataset(Dataset): + """Simple dataset wrapper for ICD-9 matrix""" + + def __init__(self, 
+
+
+def create_icd9_matrix(dataset: BaseDataset, output_path: Optional[str] = None) -> Tuple[np.ndarray, Dict[str, int]]:
+    """
+    Create an ICD-9 binary matrix from a MIMIC3Dataset.
+
+    Args:
+        dataset: MIMIC3Dataset instance
+        output_path: Optional path to save the matrix
+
+    Returns:
+        Tuple of (matrix, icd9_code_to_index_mapping)
+    """
+    print("Processing ICD-9 codes...")
+
+    # Collect all ICD codes from patients
+    icd_codes = set()
+    patient_icd_map = {}
+
+    for patient in dataset.iter_patients():
+        patient_codes = set()
+
+        # Get events from the diagnoses_icd table
+        events = patient.get_events(event_type="diagnoses_icd")
+
+        for event in events:
+            # Check if this is an ICD-9 diagnosis
+            if "icd9_code" in event.attr_dict and event.attr_dict["icd9_code"]:
+                # Use 3-digit truncation
+                code = convert_to_3digit_icd9(event.attr_dict["icd9_code"])
+                icd_codes.add(code)
+                patient_codes.add(code)
+
+        if patient_codes:
+            patient_icd_map[patient.patient_id] = patient_codes
+
+    # Create ICD-9 matrix
+    icd_codes_list = sorted(icd_codes)
+    icd9_types = {code: idx for idx, code in enumerate(icd_codes_list)}
+
+    num_patients = len(patient_icd_map)
+    num_codes = len(icd9_types)
+
+    print(f"Creating ICD-9 matrix: {num_patients} patients x {num_codes} codes")
+
+    icd9_matrix = np.zeros((num_patients, num_codes), dtype=np.float32)
+
+    for i, (patient_id, codes) in enumerate(patient_icd_map.items()):
+        for code in codes:
+            if code in icd9_types:
+                icd9_matrix[i, icd9_types[code]] = 1.0
+
+    # Save the matrix if an output path is provided
+    if output_path:
+        os.makedirs(output_path, exist_ok=True)
+        np.save(os.path.join(output_path, "icd9_matrix.npy"), icd9_matrix)
+        print(f"Saved ICD-9 matrix to {output_path}/icd9_matrix.npy")
+
+    print(f"Final matrix shape: {icd9_matrix.shape}")
+    print(f"Sparsity: {1.0 - np.mean(icd9_matrix):.3f}")
+
+    return icd9_matrix, icd9_types
+
+
+class ICD9MatrixDataset(Dataset):
+    """Simple dataset wrapper around an ICD-9 binary matrix."""
+
+    def __init__(self, icd9_matrix: np.ndarray):
+        self.icd9_matrix = icd9_matrix
+
+    def __len__(self):
+        return self.icd9_matrix.shape[0]
+
+    def __getitem__(self, idx):
+        return torch.tensor(self.icd9_matrix[idx], dtype=torch.float32)
\ No newline at end of file
diff --git a/pyhealth/models/generators/__init__.py b/pyhealth/models/generators/__init__.py
new file mode 100644
index 000000000..a5f63a7ba
--- /dev/null
+++ b/pyhealth/models/generators/__init__.py
@@ -0,0 +1 @@
+from .medgan import MedGAN
\ No newline at end of file
diff --git a/pyhealth/models/generators/halo.py b/pyhealth/models/generators/halo.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pyhealth/models/generators/medgan.py b/pyhealth/models/generators/medgan.py
new file mode 100644
index 000000000..c81859854
--- /dev/null
+++ b/pyhealth/models/generators/medgan.py
@@ -0,0 +1,358 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Dict, List
+from torch.utils.data import DataLoader
+
+from pyhealth.models import BaseModel
+
+
+class MedGANAutoencoder(nn.Module):
+    """simple autoencoder for pretraining"""
+
+    def __init__(self, input_dim: int, hidden_dim: int = 128):
+        super().__init__()
+        self.encoder = nn.Sequential(
+            nn.Linear(input_dim, hidden_dim),
+            nn.Tanh()
+        )
+        self.decoder = nn.Sequential(
+            nn.Linear(hidden_dim, input_dim),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        encoded = self.encoder(x)
+        decoded = self.decoder(encoded)
+        return decoded
+
+    def encode(self, x):
+        return self.encoder(x)
+
+    def decode(self, x):
+        return self.decoder(x)
+
+
+class MedGANGenerator(nn.Module):
+    """generator with residual connections
+
+    Note: the skip additions require latent_dim == hidden_dim; inputs of any
+    other width would not be shape-compatible with the residual sums.
+    """
+
+    def __init__(self, latent_dim: int = 128, hidden_dim: int = 128):
+        super().__init__()
+        assert latent_dim == hidden_dim, "residual connections require latent_dim == hidden_dim"
+        self.linear1 = nn.Linear(latent_dim, hidden_dim)
+        self.bn1 = nn.BatchNorm1d(hidden_dim, eps=0.001, momentum=0.01)
+        self.activation1 = nn.ReLU()
+
+        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
+        self.bn2 = nn.BatchNorm1d(hidden_dim, eps=0.001, momentum=0.01)
+        self.activation2 = nn.Tanh()
+
+    def forward(self, x):
+        # residual block 1
+        residual = x
+        out = self.activation1(self.bn1(self.linear1(x)))
+        out1 = out + residual
+
+        # residual block 2
+        residual = out1
+        out = self.activation2(self.bn2(self.linear2(out1)))
+        out2 = out + residual
+
+        return out2
+
+
+class MedGANDiscriminator(nn.Module):
+    """discriminator with minibatch averaging"""
+
+    def __init__(self, input_dim: int, hidden_dim: int = 256, minibatch_averaging: bool = True):
+        super().__init__()
+        self.minibatch_averaging = minibatch_averaging
+        model_input_dim = input_dim * 2 if minibatch_averaging else input_dim
+
+        self.model = nn.Sequential(
+            nn.Linear(model_input_dim, hidden_dim),
+            nn.ReLU(),
+            nn.Linear(hidden_dim, hidden_dim // 2),
+            nn.ReLU(),
+            nn.Linear(hidden_dim // 2, 1),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        if self.minibatch_averaging:
+            # append the batch mean to every sample
+            x_mean = torch.mean(x, dim=0).repeat(x.shape[0], 1)
+            x = torch.cat((x, x_mean), dim=1)
+        return self.model(x)
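+
+# Minibatch averaging appends the batch-mean feature vector to every sample so
+# the discriminator can detect mode collapse (identical fakes make the batch
+# mean a giveaway). A shape sketch:
+#
+#   d = MedGANDiscriminator(input_dim=4, hidden_dim=8)
+#   x = torch.rand(16, 4)   # batch of 16 samples
+#   d(x).shape              # -> torch.Size([16, 1]); internally x becomes (16, 8)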
+
+
+class MedGAN(BaseModel):
+    """MedGAN for binary matrix generation"""
+
+    def __init__(
+        self,
+        dataset,
+        feature_keys: List[str],
+        label_key: str,
+        mode: str = "generation",
+        latent_dim: int = 128,
+        hidden_dim: int = 128,
+        autoencoder_hidden_dim: int = 128,
+        discriminator_hidden_dim: int = 256,
+        minibatch_averaging: bool = True,
+        **kwargs
+    ):
+        # dummy wrapper for BaseModel compatibility
+        class DummyWrapper:
+            def __init__(self, dataset, feature_keys, label_key):
+                self.dataset = dataset
+                self.input_schema = {key: "multilabel" for key in feature_keys}
+                self.output_schema = {label_key: "multilabel"}
+                self.input_processors = {}
+                self.output_processors = {}
+
+        wrapped_dataset = DummyWrapper(dataset, feature_keys, label_key)
+        super().__init__(dataset=wrapped_dataset)
+
+        # stored explicitly; _extract_features_from_batch relies on these
+        self.feature_keys = feature_keys
+        self.label_key = label_key
+
+        self.latent_dim = latent_dim
+        self.hidden_dim = hidden_dim
+        self.minibatch_averaging = minibatch_averaging
+
+        # build vocab (simplified)
+        self.global_vocab = self._build_global_vocab(dataset, feature_keys)
+        self.input_dim = len(self.global_vocab)
+
+        # init components; the generator works in the autoencoder's latent
+        # space, so its width must match autoencoder_hidden_dim
+        self.autoencoder = MedGANAutoencoder(input_dim=self.input_dim, hidden_dim=autoencoder_hidden_dim)
+        self.generator = MedGANGenerator(latent_dim=latent_dim, hidden_dim=autoencoder_hidden_dim)
+        self.discriminator = MedGANDiscriminator(
+            input_dim=self.input_dim,
+            hidden_dim=discriminator_hidden_dim,
+            minibatch_averaging=minibatch_averaging
+        )
+
+        self._init_weights()
+
+    @classmethod
+    def from_binary_matrix(
+        cls,
+        binary_matrix: np.ndarray,
+        latent_dim: int = 128,
+        hidden_dim: int = 128,
+        autoencoder_hidden_dim: int = 128,
+        discriminator_hidden_dim: int = 256,
+        minibatch_averaging: bool = True,
+        **kwargs
+    ):
+        """create a MedGAN model from a binary matrix (ICD-9, etc.)"""
+        class MatrixWrapper:
+            def __init__(self, matrix):
+                self.matrix = matrix
+                self.input_processors = {}
+                self.output_processors = {}
+
+            def __len__(self):
+                return self.matrix.shape[0]
+
+            def __getitem__(self, idx):
+                return {"binary_vector": torch.tensor(self.matrix[idx], dtype=torch.float32)}
+
+            def iter_patients(self):
+                """iterate over patients"""
+                for i in range(len(self)):
+                    yield type('Patient', (), {
+                        'binary_vector': self.matrix[i],
+                        'patient_id': f'patient_{i}'
+                    })()
+
+        dummy_dataset = MatrixWrapper(binary_matrix)
+
+        model = cls(
+            dataset=dummy_dataset,
+            feature_keys=["binary_vector"],
+            label_key="binary_vector",
+            latent_dim=latent_dim,
+            hidden_dim=hidden_dim,
+            autoencoder_hidden_dim=autoencoder_hidden_dim,
+            discriminator_hidden_dim=discriminator_hidden_dim,
+            minibatch_averaging=minibatch_averaging,
+            **kwargs
+        )
+
+        # override the input dimension with the true matrix width
+        model.input_dim = binary_matrix.shape[1]
+
+        # reinitialize components with the correct dimensions
+        model.autoencoder = MedGANAutoencoder(input_dim=model.input_dim, hidden_dim=autoencoder_hidden_dim)
+        model.generator = MedGANGenerator(latent_dim=latent_dim, hidden_dim=autoencoder_hidden_dim)
+        model.discriminator = MedGANDiscriminator(
+            input_dim=model.input_dim,
+            hidden_dim=discriminator_hidden_dim,
+            minibatch_averaging=minibatch_averaging
+        )
+
+        # move the rebuilt components to the same device as the model
+        device = next(model.parameters()).device
+        model.autoencoder = model.autoencoder.to(device)
+        model.generator = model.generator.to(device)
+        model.discriminator = model.discriminator.to(device)
+
+        # override feature extraction
+        def extract_features(batch_data, device):
+            return batch_data["binary_vector"].to(device)
+
+        model._extract_features_from_batch = extract_features
+
+        return model
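+
+    # A minimal usage sketch (toy matrix; real inputs come from create_icd9_matrix):
+    #
+    #   matrix = np.random.binomial(1, 0.1, size=(100, 50)).astype(np.float32)
+    #   model = MedGAN.from_binary_matrix(matrix, latent_dim=32,
+    #                                     autoencoder_hidden_dim=32)
+    #   model.generate(5, torch.device("cpu")).shape  # -> torch.Size([5, 50])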
+
+    def _build_global_vocab(self, dataset, feature_keys: List[str]) -> List[str]:
+        """build vocab from dataset (simplified)"""
+        vocab = set()
+        for patient in dataset.iter_patients():
+            for feature_key in feature_keys:
+                if hasattr(patient, feature_key):
+                    feature_values = getattr(patient, feature_key)
+                    if isinstance(feature_values, list):
+                        vocab.update(feature_values)
+                    elif isinstance(feature_values, str):
+                        vocab.add(feature_values)
+        return sorted(vocab)
+
+    def _init_weights(self):
+        """init weights"""
+        def weights_init(m):
+            if isinstance(m, nn.Linear):
+                nn.init.xavier_uniform_(m.weight)
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm1d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+        self.autoencoder.apply(weights_init)
+        self.generator.apply(weights_init)
+        self.discriminator.apply(weights_init)
+
+    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
+        """forward pass"""
+        features = self._extract_features_from_batch(kwargs, self.device)
+        noise = torch.randn(features.shape[0], self.latent_dim, device=self.device)
+        # decode the latent samples so fake_samples lives in the same data
+        # space as the real features
+        fake_samples = self.autoencoder.decode(self.generator(noise))
+        return {"real_features": features, "fake_samples": fake_samples}
+
+    def generate(self, n_samples: int, device: torch.device = None) -> torch.Tensor:
+        """generate synthetic samples"""
+        if device is None:
+            device = self.device
+
+        self.generator.eval()
+        self.autoencoder.eval()
+        with torch.no_grad():
+            noise = torch.randn(n_samples, self.latent_dim, device=device)
+            generated = self.generator(noise)
+            # use the autoencoder decoder to get the final output
+            generated = self.autoencoder.decode(generated)
+
+        return generated
+
+    def discriminate(self, x: torch.Tensor) -> torch.Tensor:
+        """discriminate real vs fake"""
+        return self.discriminator(x)
+
+    def pretrain_autoencoder(self, dataloader: DataLoader, epochs: int = 100, lr: float = 0.001, device: torch.device = None):
+        """pretrain autoencoder with detailed loss tracking"""
+        if device is None:
+            device = self.device
+
+        # Ensure the autoencoder is on the correct device
+        self.autoencoder = self.autoencoder.to(device)
+
+        print("Pretraining Autoencoder...")
+        print("="*50)
+        print("Epoch | A_loss | Progress")
+        print("="*50)
+
+        optimizer = torch.optim.Adam(self.autoencoder.parameters(), lr=lr)
+        criterion = nn.BCELoss()
+
+        # Track losses for plotting
+        a_losses = []
+
+        self.autoencoder.train()
+
+        for epoch in range(epochs):
+            total_loss = 0
+            num_batches = 0
+
+            for batch in dataloader:
+                # handle both tensor and dict inputs
+                if isinstance(batch, torch.Tensor):
+                    features = batch.to(device)
+                else:
+                    features = self._extract_features_from_batch(batch, device)
+
+                reconstructed = self.autoencoder(features)
+                loss = criterion(reconstructed, features)
+
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+
+                total_loss += loss.item()
+                num_batches += 1
+
+            avg_loss = total_loss / num_batches
+            a_losses.append(avg_loss)
+
+            # Print progress every epoch for short runs, every 10 epochs for longer runs
+            print_freq = 1 if epochs <= 50 else 10
+            if (epoch + 1) % print_freq == 0 or epoch == 0 or epoch == epochs - 1:
+                progress = (epoch + 1) / epochs * 100
+                print(f"{epoch+1:5d} | {avg_loss:.4f} | {progress:5.1f}%")
+
+        print("="*50)
+        print("Autoencoder Pretraining Completed!")
+        print(f"Final A_loss: {a_losses[-1]:.4f}")
+
+        return a_losses
+
+    def _extract_features_from_batch(self, batch_data, device: torch.device) -> torch.Tensor:
+        """extract features from a batch"""
+        features = []
+        for feature_key in self.feature_keys:
+            if feature_key in batch_data:
+                features.append(batch_data[feature_key])
+
+        if len(features) == 1:
+            return features[0].to(device)
+        else:
+            return torch.cat(features, dim=1).to(device)
+
+    def sample_transform(self, samples: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
+        """convert probabilities to binary values using a threshold"""
+        return (samples > threshold).float()
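+
+    # Typical generation flow (a sketch; threshold=0.5 simply rounds the
+    # sigmoid outputs to 0/1):
+    #
+    #   probs = model.generate(1000, torch.device("cpu"))  # (1000, input_dim) in [0, 1]
+    #   binary = model.sample_transform(probs)             # 0/1 float tensor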
+
+    def train_step(self, batch, optimizer_g, optimizer_d, optimizer_ae=None):
+        """single training step (optimizer_ae is accepted for API symmetry but unused)"""
+        real_features = self._extract_features_from_batch(batch, self.device)
+
+        # train discriminator
+        optimizer_d.zero_grad()
+        noise = torch.randn(real_features.shape[0], self.latent_dim, device=self.device)
+        # decode the latent samples so they match the discriminator's input_dim
+        fake_samples = self.autoencoder.decode(self.generator(noise))
+
+        real_predictions = self.discriminator(real_features)
+        fake_predictions = self.discriminator(fake_samples.detach())
+
+        d_loss = F.binary_cross_entropy(real_predictions, torch.ones_like(real_predictions)) + \
+                 F.binary_cross_entropy(fake_predictions, torch.zeros_like(fake_predictions))
+        d_loss.backward()
+        optimizer_d.step()
+
+        # train generator (stray discriminator grads from this backward pass are
+        # cleared by optimizer_d.zero_grad() at the start of the next step)
+        optimizer_g.zero_grad()
+        fake_predictions = self.discriminator(fake_samples)
+        g_loss = F.binary_cross_entropy(fake_predictions, torch.ones_like(fake_predictions))
+        g_loss.backward()
+        optimizer_g.step()
+
+        return {"d_loss": d_loss.item(), "g_loss": g_loss.item()}
\ No newline at end of file
diff --git a/pyhealth/models/generators/promptehr.py b/pyhealth/models/generators/promptehr.py
new file mode 100644
index 000000000..e69de29bb