From 622c2a2804bd7f126a107035ea477a6e47b7b9eb Mon Sep 17 00:00:00 2001 From: John Wu Date: Wed, 19 Nov 2025 14:53:40 -0600 Subject: [PATCH 1/7] init commit for 1 solution --- pyhealth/datasets/base_dataset.py | 72 +++++++++++++++++++++++-------- 1 file changed, 55 insertions(+), 17 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 3390453ff..cbdcfdd3a 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -1,6 +1,7 @@ import logging import os import pickle +import time from abc import ABC from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path @@ -211,7 +212,9 @@ def load_table(self, table_name: str) -> pl.LazyFrame: def _to_lower(col_name: str) -> str: lower_name = col_name.lower() if lower_name != col_name: - logger.warning("Renaming column %s to lowercase %s", col_name, lower_name) + logger.warning( + "Renaming column %s to lowercase %s", col_name, lower_name + ) return lower_name table_cfg = self.config.tables[table_name] @@ -281,7 +284,8 @@ def _to_lower(col_name: str) -> str: # Flatten attribute columns with event_type prefix attribute_columns = [ - pl.col(attr.lower()).alias(f"{table_name}/{attr}") for attr in attribute_cols + pl.col(attr.lower()).alias(f"{table_name}/{attr}") + for attr in attribute_cols ] event_frame = df.select(base_columns + attribute_columns) @@ -332,7 +336,6 @@ def iter_patients(self, df: Optional[pl.LazyFrame] = None) -> Iterator[Patient]: if df is None: df = self.collected_global_event_df grouped = df.group_by("patient_id") - for patient_id, patient_df in grouped: patient_id = patient_id[0] yield Patient(patient_id=patient_id, data_source=patient_df) @@ -440,23 +443,58 @@ def set_task( ): samples.extend(task(patient)) else: - # multi-threading (not recommended) + # multi-threading with lazy iteration and bounded queue logger.info( - f"Generating samples for {task.task_name} with " - f"{num_workers} workers" + "Generating samples for %s with %d workers", + task.task_name, + num_workers, ) - patients = list(self.iter_patients(filtered_global_event_df)) + + logger.info("Computing total patient count...") + start_time = time.time() + total_patients = filtered_global_event_df["patient_id"].n_unique() + elapsed = time.time() - start_time + logger.info( + "n_unique() completed in %.2f seconds, found %d patients", + elapsed, + total_patients, + ) + + patients_iter = self.iter_patients(filtered_global_event_df) + max_in_flight = num_workers * 4 + with ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = [executor.submit(task, patient) for patient in patients] - for future in tqdm( - as_completed(futures), - total=len(futures), - desc=( - f"Collecting samples for {task.task_name} " - f"from {num_workers} workers" - ), - ): - samples.extend(future.result()) + in_flight = {} + + # Prime the in-flight queue + try: + for _ in range(max_in_flight): + patient = next(patients_iter) + fut = executor.submit(task, patient) + in_flight[fut] = None + except StopIteration: + pass + + with tqdm( + total=total_patients, + desc=f"Processing {task.task_name}", + ) as pbar: + while in_flight: + for fut in as_completed(list(in_flight.keys())): + in_flight.pop(fut, None) + result = fut.result() + samples.extend(result) + pbar.update(1) + + try: + next_patient = next(patients_iter) + new_fut = executor.submit(task, next_patient) + in_flight[new_fut] = None + except StopIteration: + pass + + # Re-enter as_completed with updated future set + break # Cache 
the samples if cache_dir is provided if cache_dir is not None: From 36e0b34d3e3db7cf751a9a9ec7e1774a02634800 Mon Sep 17 00:00:00 2001 From: John Wu Date: Wed, 19 Nov 2025 14:53:52 -0600 Subject: [PATCH 2/7] also change ex cache --- examples/mortality_mimic4_stagenet_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/mortality_mimic4_stagenet_v2.py b/examples/mortality_mimic4_stagenet_v2.py index b33a3d760..fd626f008 100644 --- a/examples/mortality_mimic4_stagenet_v2.py +++ b/examples/mortality_mimic4_stagenet_v2.py @@ -34,7 +34,7 @@ sample_dataset = base_dataset.set_task( MortalityPredictionStageNetMIMIC4(), num_workers=4, - cache_dir="../../mimic4_stagenet_cache", + cache_dir="../../mimic4_stagenet_cache_v2", ) print(f"Total samples: {len(sample_dataset)}") From d58158709583cd466ac3c72fdb1937884e0e3b2e Mon Sep 17 00:00:00 2001 From: John Wu Date: Thu, 20 Nov 2025 12:48:52 -0600 Subject: [PATCH 3/7] init commit --- examples/mortality_mimic4_stagenet_v2.py | 231 +++++++++++++++++- .../processors/nested_sequence_processor.py | 8 +- pyhealth/processors/sequence_processor.py | 18 +- pyhealth/processors/stagenet_processor.py | 14 +- .../mortality_prediction_stagenet_mimic4.py | 25 +- 5 files changed, 269 insertions(+), 27 deletions(-) diff --git a/examples/mortality_mimic4_stagenet_v2.py b/examples/mortality_mimic4_stagenet_v2.py index fd626f008..87a850964 100644 --- a/examples/mortality_mimic4_stagenet_v2.py +++ b/examples/mortality_mimic4_stagenet_v2.py @@ -6,18 +6,132 @@ 2. Applying the MortalityPredictionStageNetMIMIC4 task 3. Creating a SampleDataset with StageNet processors 4. Training a StageNet model +5. Testing with synthetic hold-out set (unseen codes, varying lengths) """ +import os +import random +import numpy as np from pyhealth.datasets import ( MIMIC4Dataset, get_dataloader, split_by_patient, + SampleDataset, ) +from pyhealth.datasets.utils import save_processors, load_processors from pyhealth.models import StageNet from pyhealth.tasks import MortalityPredictionStageNetMIMIC4 from pyhealth.trainer import Trainer import torch + +def generate_holdout_set( + sample_dataset: SampleDataset, num_samples: int = 10, seed: int = 42 +) -> SampleDataset: + """Generate synthetic hold-out set with unseen codes and varying lengths. + + This function creates synthetic samples to test the processor's ability to: + 1. Handle completely unseen tokens (mapped to ) + 2. 
Handle sequence lengths larger than training but within padding + + Args: + sample_dataset: Original SampleDataset with fitted processors + num_samples: Number of synthetic samples to generate + seed: Random seed for reproducibility + + Returns: + SampleDataset with synthetic samples using fitted processors + """ + random.seed(seed) + np.random.seed(seed) + + # Get the fitted processors + icd_processor = sample_dataset.input_processors["icd_codes"] + + # Get max nested length from ICD processor + max_icd_len = icd_processor._max_nested_len + padding = icd_processor._padding + + print("\n=== Hold-out Set Generation ===") + print(f"ICD max nested length: {max_icd_len}") + print(f"Padding: {padding}") + print(f"Observed max (without padding): {max_icd_len - padding}") + + synthetic_samples = [] + + for i in range(num_samples): + # Generate random number of visits (1-5) + num_visits = random.randint(1, 5) + + # Generate ICD codes with unseen tokens + icd_codes_list = [] + icd_times_list = [] + + for visit_idx in range(num_visits): + # Generate sequence length between observed_max and max_icd_len + # This tests the padding capacity + observed_max = max_icd_len - padding + seq_len = random.randint(max(1, observed_max - 2), max_icd_len - 1) + + # Generate unseen codes + visit_codes = [f"NEWCODE_{i}_{visit_idx}_{j}" for j in range(seq_len)] + icd_codes_list.append(visit_codes) + + # Generate time intervals (hours from previous visit) + if visit_idx == 0: + icd_times_list.append(0.0) + else: + icd_times_list.append(random.uniform(24.0, 720.0)) + + # Generate lab data (10-dimensional vectors) + num_lab_timestamps = random.randint(5, 15) + lab_values_list = [] + lab_times_list = [] + + for ts_idx in range(num_lab_timestamps): + # Generate 10D vector with some random values and some None + lab_vector = [] + for dim in range(10): + if random.random() < 0.8: # 80% chance of value + lab_vector.append(random.uniform(50.0, 150.0)) + else: + lab_vector.append(None) + + lab_values_list.append(lab_vector) + lab_times_list.append(random.uniform(0.0, 48.0)) + + # Create sample in the expected format (before processing) + synthetic_sample = { + "patient_id": f"HOLDOUT_PATIENT_{i}", + "icd_codes": (icd_times_list, icd_codes_list), + "labs": (lab_times_list, lab_values_list), + "mortality": random.randint(0, 1), + } + + synthetic_samples.append(synthetic_sample) + + # Create a new SampleDataset with the FITTED processors + holdout_dataset = SampleDataset( + samples=synthetic_samples, + input_schema=sample_dataset.input_schema, + output_schema=sample_dataset.output_schema, + dataset_name=f"{sample_dataset.dataset_name}_holdout", + task_name=sample_dataset.task_name, + input_processors=sample_dataset.input_processors, + output_processors=sample_dataset.output_processors, + ) + + print(f"Generated {len(holdout_dataset)} synthetic samples") + sample_seq_lens = [len(s["icd_codes"][1]) for s in synthetic_samples[:3]] + print(f"Sample ICD sequence lengths: {sample_seq_lens}") + sample_codes_per_visit = [ + [len(visit) for visit in s["icd_codes"][1]] for s in synthetic_samples[:3] + ] + print(f"Sample codes per visit: {sample_codes_per_visit}") + + return holdout_dataset + + # STEP 1: Load MIMIC-IV base dataset base_dataset = MIMIC4Dataset( ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/", @@ -30,12 +144,38 @@ ], ) -# STEP 2: Apply StageNet mortality prediction task -sample_dataset = base_dataset.set_task( - MortalityPredictionStageNetMIMIC4(), - num_workers=4, - cache_dir="../../mimic4_stagenet_cache_v2", -) 
+# STEP 2: Apply StageNet mortality prediction task with padding +# +# Processor Saving/Loading: +# - Processors are saved after the first run to avoid refitting +# - On subsequent runs, pre-fitted processors are loaded from disk +# - This ensures consistent encoding and saves computation time +# - Processors include vocabulary mappings and sequence length statistics +processor_dir = "../../output/processors/stagenet_mortality_mimic4" +cache_dir = "../../mimic4_stagenet_cache_v3" + +if os.path.exists(os.path.join(processor_dir, "input_processors.pkl")): + print("\n=== Loading Pre-fitted Processors ===") + input_processors, output_processors = load_processors(processor_dir) + + sample_dataset = base_dataset.set_task( + MortalityPredictionStageNetMIMIC4(padding=20), + num_workers=4, + cache_dir=cache_dir, + input_processors=input_processors, + output_processors=output_processors, + ) +else: + print("\n=== Fitting New Processors ===") + sample_dataset = base_dataset.set_task( + MortalityPredictionStageNetMIMIC4(padding=20), + num_workers=4, + cache_dir=cache_dir, + ) + + # Save processors for future runs + print("\n=== Saving Processors ===") + save_processors(sample_dataset, processor_dir) print(f"Total samples: {len(sample_dataset)}") print(f"Input schema: {sample_dataset.input_schema}") @@ -100,3 +240,82 @@ print("\nSample predictions:") print(f" Predicted probabilities: {output['y_prob'][:5]}") print(f" True labels: {output['y_true'][:5]}") + +# STEP 8: Test with synthetic hold-out set (unseen codes, varying lengths) +print("\n" + "=" * 60) +print("TESTING PROCESSOR ROBUSTNESS WITH SYNTHETIC HOLD-OUT SET") +print("=" * 60) + +# Generate hold-out set with fitted processors +holdout_dataset = generate_holdout_set( + sample_dataset=sample_dataset, num_samples=50, seed=42 +) + +# Create dataloader for hold-out set +holdout_loader = get_dataloader(holdout_dataset, batch_size=16, shuffle=False) + +# Inspect processed samples +print("\n=== Inspecting Processed Hold-out Samples ===") +holdout_batch = next(iter(holdout_loader)) + +print(f"Batch size: {len(holdout_batch['patient_id'])}") +print(f"ICD codes tensor shape: {holdout_batch['icd_codes'][1].shape}") +print("ICD codes sample (first patient):") +print(f" Time: {holdout_batch['icd_codes'][0][0][:5]}") +print(f" Values (indices): {holdout_batch['icd_codes'][1][0][:3]}") + +# Check for unknown tokens +icd_processor = sample_dataset.input_processors["icd_codes"] +unk_token_idx = icd_processor.code_vocab[""] +pad_token_idx = icd_processor.code_vocab[""] + +print(f"\n token index: {unk_token_idx}") +print(f" token index: {pad_token_idx}") + +# Count unknown and padding tokens in batch +icd_values = holdout_batch["icd_codes"][1] +num_unk = (icd_values == unk_token_idx).sum().item() +num_pad = (icd_values == pad_token_idx).sum().item() +total_tokens = icd_values.numel() + +print("\nToken statistics in hold-out batch:") +print(f" Total tokens: {total_tokens}") +print(f" Unknown tokens: {num_unk} ({100*num_unk/total_tokens:.1f}%)") +print(f" Padding tokens: {num_pad} ({100*num_pad/total_tokens:.1f}%)") + +# Run model inference on hold-out set +print("\n=== Model Inference on Hold-out Set ===") +with torch.no_grad(): + holdout_output = model(**holdout_batch) + +print(f"Predictions shape: {holdout_output['y_prob'].shape}") +print(f"Sample predictions: {holdout_output['y_prob'][:5]}") +print(f"True labels: {holdout_output['y_true'][:5]}") + +print("\n" + "=" * 60) +print("HOLD-OUT SET TEST COMPLETED SUCCESSFULLY!") +print("Processors handled unseen 
codes and varying lengths correctly.")
+print("=" * 60)
+
+# STEP 9: Inspect saved processors
+print("\n" + "=" * 60)
+print("PROCESSOR INFORMATION")
+print("=" * 60)
+print(f"\nProcessors saved at: {processor_dir}")
+print("\nICD Codes Processor:")
+print(f"  {icd_processor}")
+print(f"  Vocabulary size: {icd_processor.vocab_size()}")
+print(f"  <unk> token index: {icd_processor.code_vocab['<unk>']}")
+print(f"  <pad> token index: {icd_processor.code_vocab['<pad>']}")
+print(f"  Max nested length: {icd_processor._max_nested_len}")
+print(f"  Padding capacity: {icd_processor._padding}")
+
+labs_processor = sample_dataset.input_processors["labs"]
+print("\nLabs Processor:")
+print(f"  {labs_processor}")
+print(f"  Feature dimension: {labs_processor.size}")
+
+print("\nTo reuse these processors in future runs:")
+print("  1. Keep the processor_dir path the same")
+print("  2. The script will automatically load them on next run")
+print("  3. This ensures consistent encoding across experiments")
diff --git a/pyhealth/processors/nested_sequence_processor.py b/pyhealth/processors/nested_sequence_processor.py
index bf7ed2055..f89b5b127 100644
--- a/pyhealth/processors/nested_sequence_processor.py
+++ b/pyhealth/processors/nested_sequence_processor.py
@@ -45,8 +45,8 @@ class NestedSequenceProcessor(FeatureProcessor):
     """

     def __init__(self, padding: int = 0):
-        # -1 for <unk> for ease of boolean arithmetic > 0, > -1, etc.
-        self.code_vocab: Dict[Any, int] = {"<unk>": -1, "<pad>": 0}
+        # <unk> will be set to len(vocab) after fit
+        self.code_vocab: Dict[Any, int] = {"<unk>": None, "<pad>": 0}
         self._next_index = 1
         self._max_inner_len = 1  # Maximum length of inner sequences
         self._padding = padding  # Additional padding beyond observed max
@@ -82,6 +82,10 @@ def fit(self, samples: List[Dict[str, Any]], field: str) -> None:
             observed_max = max(1, max_inner_len)
             self._max_inner_len = observed_max + self._padding

+        # Set <unk> token to len(vocab) - 1 after building vocabulary
+        # (-1 because <unk> is already in vocab)
+        self.code_vocab["<unk>"] = len(self.code_vocab) - 1
+
     def process(self, value: List[List[Any]]) -> torch.Tensor:
         """Process nested sequence into padded 2D tensor.

diff --git a/pyhealth/processors/sequence_processor.py b/pyhealth/processors/sequence_processor.py
index 0816ebb33..8bb72d326 100644
--- a/pyhealth/processors/sequence_processor.py
+++ b/pyhealth/processors/sequence_processor.py
@@ -16,9 +16,8 @@ class SequenceProcessor(FeatureProcessor):
     """

     def __init__(self):
-        # -1 for <unk> for ease of boolean arithmetic > 0, > -1, etc.
-        # TODO: this can be a problem if we pass -1 into nn.Embedding
-        self.code_vocab: Dict[Any, int] = {"<unk>": -1, "<pad>": 0}
+        # <unk> will be set to len(vocab) after fit
+        self.code_vocab: Dict[Any, int] = {"<unk>": None, "<pad>": 0}
         self._next_index = 1

     def process(self, value: Any) -> torch.Tensor:
@@ -32,19 +31,20 @@ def process(self, value: Any) -> torch.Tensor:
         """
         indices = []
         for token in value:
-            if token is None: # missing values
+            if token is None:  # missing values
                 indices.append(self.code_vocab["<unk>"])
             else:
                 if token not in self.code_vocab:
                     self.code_vocab[token] = self._next_index
                     self._next_index += 1
+                    # Update <unk> token to len(vocab) - 1
+                    # (-1 because <unk> is already in vocab)
+                    self.code_vocab["<unk>"] = len(self.code_vocab) - 1
                 indices.append(self.code_vocab[token])
         return torch.tensor(indices, dtype=torch.long)
-    
+
     def size(self):
         return len(self.code_vocab)
-    
+
     def __repr__(self):
-        return (
-            f"SequenceProcessor(code_vocab_size={len(self.code_vocab)})"
-        )
+        return f"SequenceProcessor(code_vocab_size={len(self.code_vocab)})"
diff --git a/pyhealth/processors/stagenet_processor.py b/pyhealth/processors/stagenet_processor.py
index cbbafac94..29e78935c 100644
--- a/pyhealth/processors/stagenet_processor.py
+++ b/pyhealth/processors/stagenet_processor.py
@@ -24,9 +24,9 @@ class StageNetProcessor(FeatureProcessor):
     - List of lists of strings -> nested code sequences

     Args:
-        padding: Additional padding to add on top of the observed maximum nested 
+        padding: Additional padding to add on top of the observed maximum nested
            sequence length. The actual padding length will be observed_max + padding.
-            This ensures the processor can handle sequences longer than those in the 
+            This ensures the processor can handle sequences longer than those in the
            training data. Default: 0 (no extra padding).

    Returns:
        Tuple of (time_tensor, value_tensor) where time_tensor can be None

@@ -55,10 +55,12 @@ class StageNetProcessor(FeatureProcessor):
     """

     def __init__(self, padding: int = 0):
-        self.code_vocab: Dict[Any, int] = {"<unk>": -1, "<pad>": 0}
+        # <unk> will be set to len(vocab) after fit
+        self.code_vocab: Dict[Any, int] = {"<unk>": None, "<pad>": 0}
         self._next_index = 1
         self._is_nested = None  # Will be determined during fit
-        self._max_nested_len = None  # Max inner sequence length for nested codes
+        # Max inner sequence length for nested codes
+        self._max_nested_len = None
         self._padding = padding  # Additional padding beyond observed max

     def fit(self, samples: List[Dict], key: str) -> None:
@@ -116,6 +118,10 @@ def fit(self, samples: List[Dict], key: str) -> None:
             observed_max = max(1, max_inner_len)
             self._max_nested_len = observed_max + self._padding

+        # Set <unk> token to len(vocab) - 1 after building vocabulary
+        # (-1 because <unk> is already in vocab)
+        self.code_vocab["<unk>"] = len(self.code_vocab) - 1
+
     def process(
         self, value: Tuple[Optional[List], List]
     ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
diff --git a/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py b/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py
index fc9c58f7f..91e1f94cd 100644
--- a/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py
+++ b/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py
@@ -1,5 +1,5 @@
 from datetime import datetime
-from typing import Any, ClassVar, Dict, List
+from typing import Any, ClassVar, Dict, List, Tuple

 import polars as pl

@@ -24,6 +24,10 @@ class MortalityPredictionStageNetMIMIC4(BaseTask):
     - Multiple itemids per category → take first observed value
     - Missing categories → None/NaN in vector

+    Args:
+        padding: Additional padding for StageNet processor to handle
+            sequences longer than observed during training. Default: 0.
+
     Attributes:
         task_name (str): The name of the task.
         input_schema (Dict[str, str]): The schema for input data:
@@ -35,11 +39,20 @@ class MortalityPredictionStageNetMIMIC4(BaseTask):
     """

     task_name: str = "MortalityPredictionStageNetMIMIC4"
-    input_schema: Dict[str, str] = {
-        "icd_codes": "stagenet",
-        "labs": "stagenet_tensor",
-    }
-    output_schema: Dict[str, str] = {"mortality": "binary"}
+
+    def __init__(self, padding: int = 0):
+        """Initialize task with optional padding parameter.
+
+        Args:
+            padding: Additional padding for nested sequences. Default: 0.
+ """ + self.padding = padding + # Use tuple format to pass kwargs to processor + self.input_schema: Dict[str, Tuple[str, Dict[str, Any]]] = { + "icd_codes": ("stagenet", {"padding": padding}), + "labs": ("stagenet_tensor", {}), + } + self.output_schema: Dict[str, str] = {"mortality": "binary"} # Organize lab items by category # Each category will map to ONE dimension in the output vector From 58ca0fdae2e7b7c07d5c53f6518d056c1da987c0 Mon Sep 17 00:00:00 2001 From: John Wu Date: Fri, 21 Nov 2025 11:58:13 -0600 Subject: [PATCH 4/7] commit new test case and fixes --- examples/mortality_mimic4_stagenet_v2.py | 18 +- .../tutorial_stagenet_comprehensive.ipynb | 8 + pyhealth/processors/stagenet_processor.py | 6 +- tests/core/test_stagenet_processor.py | 467 ++++++++++++++++++ 4 files changed, 490 insertions(+), 9 deletions(-) create mode 100644 tests/core/test_stagenet_processor.py diff --git a/examples/mortality_mimic4_stagenet_v2.py b/examples/mortality_mimic4_stagenet_v2.py index 87a850964..0818704fa 100644 --- a/examples/mortality_mimic4_stagenet_v2.py +++ b/examples/mortality_mimic4_stagenet_v2.py @@ -50,11 +50,16 @@ def generate_holdout_set( # Get max nested length from ICD processor max_icd_len = icd_processor._max_nested_len - padding = icd_processor._padding + # Handle both old and new processor versions + padding = getattr(icd_processor, "_padding", 0) print("\n=== Hold-out Set Generation ===") + print(f"Processor attributes: {dir(icd_processor)}") + print(f"Has _padding attribute: {hasattr(icd_processor, '_padding')}") print(f"ICD max nested length: {max_icd_len}") - print(f"Padding: {padding}") + print(f"Padding (via getattr): {padding}") + if hasattr(icd_processor, "_padding"): + print(f"Padding (direct access): {icd_processor._padding}") print(f"Observed max (without padding): {max_icd_len - padding}") synthetic_samples = [] @@ -142,6 +147,7 @@ def generate_holdout_set( "procedures_icd", "labevents", ], + # dev=True, ) # STEP 2: Apply StageNet mortality prediction task with padding @@ -214,14 +220,14 @@ def generate_holdout_set( # STEP 5: Train the model trainer = Trainer( model=model, - device="cuda:5", # or "cpu" + device="cpu", # or "cpu" metrics=["pr_auc", "roc_auc", "accuracy", "f1"], ) trainer.train( train_dataloader=train_loader, val_dataloader=val_loader, - epochs=50, + epochs=1, monitor="roc_auc", optimizer_params={"lr": 1e-5}, ) @@ -304,11 +310,11 @@ def generate_holdout_set( print(f"\nProcessors saved at: {processor_dir}") print("\nICD Codes Processor:") print(f" {icd_processor}") -print(f" Vocabulary size: {icd_processor.vocab_size()}") +print(f" Vocabulary size: {icd_processor.size()}") print(f" token index: {icd_processor.code_vocab['']}") print(f" token index: {icd_processor.code_vocab['']}") print(f" Max nested length: {icd_processor._max_nested_len}") -print(f" Padding capacity: {icd_processor._padding}") +print(f" Padding capacity: {getattr(icd_processor, '_padding', 0)}") labs_processor = sample_dataset.input_processors["labs"] print("\nLabs Processor:") diff --git a/examples/tutorial_stagenet_comprehensive.ipynb b/examples/tutorial_stagenet_comprehensive.ipynb index 47015ce01..17afeb973 100644 --- a/examples/tutorial_stagenet_comprehensive.ipynb +++ b/examples/tutorial_stagenet_comprehensive.ipynb @@ -7,6 +7,14 @@ "source": [ "\n", "# Getting Started\n", + "Here, we will go over the following with StageNet across all utility modules in PyHealth:\n", + "\n", + "1. Loading the data\n", + "2. Task Processing (with padding to ensure compatibility)\n", + "3. 
ML Model Initialization \n", + "4. Model training\n", + "5. Holdout Inference on Sets of Codes Not in Vocabulary\n", + "6. Interpretability Example with DeepLift\n", "\n", "## Installation\n", "\n", diff --git a/pyhealth/processors/stagenet_processor.py b/pyhealth/processors/stagenet_processor.py index 29e78935c..462c02a1a 100644 --- a/pyhealth/processors/stagenet_processor.py +++ b/pyhealth/processors/stagenet_processor.py @@ -118,9 +118,9 @@ def fit(self, samples: List[Dict], key: str) -> None: observed_max = max(1, max_inner_len) self._max_nested_len = observed_max + self._padding - # Set token to len(vocab) - 1 after building vocabulary - # (-1 because is already in vocab) - self.code_vocab[""] = len(self.code_vocab) - 1 + # Set token to the next available index + # Since is already in the vocab dict, we use _next_index + self.code_vocab[""] = self._next_index def process( self, value: Tuple[Optional[List], List] diff --git a/tests/core/test_stagenet_processor.py b/tests/core/test_stagenet_processor.py new file mode 100644 index 000000000..b5eb55305 --- /dev/null +++ b/tests/core/test_stagenet_processor.py @@ -0,0 +1,467 @@ +""" +Unit tests for StageNet processors. + +Tests cover: +- Unknown token handling (must be len(vocab) - 1, not -1) +- Vocabulary building for flat and nested codes +- Time processing +- Padding for nested sequences +- Forward-fill for numeric values +- Edge cases (empty sequences, None values, etc.) +""" + +import unittest +import torch +import numpy as np + +from pyhealth.processors import StageNetProcessor, StageNetTensorProcessor + + +class TestStageNetProcessor(unittest.TestCase): + """Tests for StageNetProcessor (categorical codes).""" + + def test_unknown_token_index(self): + """Test that token is len(vocab) - 1, not -1.""" + processor = StageNetProcessor() + samples = [ + {"data": ([0.0, 1.0], [["A", "B"], ["C", "D", "E"]])}, + {"data": ([0.0], [["F"]])}, + ] + processor.fit(samples, "data") + + # should be len(vocab) - 1 (last index) + expected_unk_idx = len(processor.code_vocab) - 1 + self.assertEqual(processor.code_vocab[""], expected_unk_idx) + + # must be >= 0 for nn.Embedding compatibility + self.assertGreaterEqual(processor.code_vocab[""], 0) + + # should be 0 + self.assertEqual(processor.code_vocab[""], 0) + + # Verify vocab size includes both special tokens + # Vocab: , , A, B, C, D, E, F = 8 tokens + self.assertEqual(len(processor.code_vocab), 8) + self.assertEqual(processor.code_vocab[""], 7) + + def test_unknown_token_embedding_compatibility(self): + """Test that index works with nn.Embedding.""" + processor = StageNetProcessor() + samples = [{"data": ([0.0], [["A", "B"]])}] + processor.fit(samples, "data") + + # Create an embedding layer with vocab_size + vocab_size = processor.size() + embedding = torch.nn.Embedding(vocab_size, 64) + + # Process data with unknown codes + time, values = processor.process(([0.0], [["A", "UNKNOWN"]])) + + # Should not raise IndexError + try: + embedded = embedding(values) + self.assertEqual(embedded.shape, (1, 2, 64)) + except IndexError: + self.fail("nn.Embedding raised IndexError with token") + + def test_flat_codes(self): + """Test processing flat code sequences.""" + processor = StageNetProcessor() + samples = [ + {"data": ([0.0, 1.5, 2.3], ["code1", "code2", "code3"])}, + ] + processor.fit(samples, "data") + + # Check structure detection + self.assertFalse(processor._is_nested) + + # Process data + time, values = processor.process(([0.0, 1.5], ["code1", "code2"])) + + # Check shapes + 
self.assertEqual(time.shape, (2,)) + self.assertEqual(values.shape, (2,)) + self.assertEqual(values.dtype, torch.long) + + # Check values are encoded correctly + self.assertEqual(values[0].item(), processor.code_vocab["code1"]) + self.assertEqual(values[1].item(), processor.code_vocab["code2"]) + + def test_nested_codes(self): + """Test processing nested code sequences.""" + processor = StageNetProcessor(padding=0) + samples = [ + {"data": ([0.0, 1.5], [["A", "B"], ["C", "D", "E"]])}, + {"data": ([0.0], [["F"]])}, + ] + processor.fit(samples, "data") + + # Check structure detection + self.assertTrue(processor._is_nested) + + # Max inner length should be 3 (from ["C", "D", "E"]) + self.assertEqual(processor._max_nested_len, 3) + + # Process data + time, values = processor.process(([0.0, 1.5], [["A", "B"], ["C"]])) + + # Check shapes + self.assertEqual(time.shape, (2,)) + self.assertEqual(values.shape, (2, 3)) # 2 visits, padded to 3 + self.assertEqual(values.dtype, torch.long) + + # Check padding is applied + self.assertEqual(values[1, 1].item(), processor.code_vocab[""]) + self.assertEqual(values[1, 2].item(), processor.code_vocab[""]) + + def test_nested_codes_with_padding(self): + """Test nested codes with custom padding parameter.""" + processor = StageNetProcessor(padding=20) + samples = [ + {"data": ([0.0, 1.5], [["A", "B"], ["C", "D", "E"]])}, + ] + processor.fit(samples, "data") + + # Max inner length should be 3 + 20 = 23 + self.assertEqual(processor._max_nested_len, 23) + self.assertEqual(processor._padding, 20) + + # Process data + time, values = processor.process(([0.0], [["A", "B"]])) + + # Check shape includes padding + self.assertEqual(values.shape, (1, 23)) + + def test_unknown_codes_flat(self): + """Test handling of unknown codes in flat sequences.""" + processor = StageNetProcessor() + samples = [{"data": ([0.0], ["A", "B"])}] + processor.fit(samples, "data") + + # Process with unknown code + time, values = processor.process(([0.0, 1.0], ["A", "UNKNOWN"])) + + self.assertEqual(values[0].item(), processor.code_vocab["A"]) + self.assertEqual(values[1].item(), processor.code_vocab[""]) + + def test_unknown_codes_nested(self): + """Test handling of unknown codes in nested sequences.""" + processor = StageNetProcessor(padding=0) + samples = [{"data": ([0.0], [["A", "B"]])}] + processor.fit(samples, "data") + + # Process with unknown code + time, values = processor.process(([0.0], [["A", "UNKNOWN"]])) + + self.assertEqual(values[0, 0].item(), processor.code_vocab["A"]) + self.assertEqual(values[0, 1].item(), processor.code_vocab[""]) + + def test_none_codes(self): + """Test handling of None codes.""" + processor = StageNetProcessor(padding=0) + samples = [{"data": ([0.0], [["A", "B"]])}] + processor.fit(samples, "data") + + # Process with None code + time, values = processor.process(([0.0], [["A", None]])) + + self.assertEqual(values[0, 0].item(), processor.code_vocab["A"]) + self.assertEqual(values[0, 1].item(), processor.code_vocab[""]) + + def test_time_processing(self): + """Test time data processing.""" + processor = StageNetProcessor() + samples = [{"data": ([0.0, 1.5, 2.3], ["A", "B", "C"])}] + processor.fit(samples, "data") + + # Test with time data + time, values = processor.process(([0.0, 1.5], ["A", "B"])) + self.assertIsNotNone(time) + self.assertEqual(time.dtype, torch.float) + self.assertEqual(time[0].item(), 0.0) + self.assertEqual(time[1].item(), 1.5) + + def test_no_time_data(self): + """Test processing without time data.""" + processor = StageNetProcessor() + 
samples = [{"data": ([0.0], ["A", "B"])}] + processor.fit(samples, "data") + + # Process without time + time, values = processor.process((None, ["A", "B"])) + + self.assertIsNone(time) + self.assertEqual(values.shape, (2,)) + + def test_empty_codes_flat(self): + """Test processing empty code list (flat).""" + processor = StageNetProcessor() + samples = [{"data": ([0.0], ["A", "B"])}] + processor.fit(samples, "data") + + time, values = processor.process((None, [])) + + # Should return single padding token + self.assertEqual(values.shape, (1,)) + self.assertEqual(values[0].item(), processor.code_vocab[""]) + + def test_empty_codes_nested(self): + """Test processing empty nested codes.""" + processor = StageNetProcessor(padding=0) + samples = [{"data": ([0.0], [["A", "B"]])}] + processor.fit(samples, "data") + + time, values = processor.process((None, [])) + + # Should return single row of padding tokens + self.assertEqual(values.shape, (1, 2)) + self.assertEqual(values[0, 0].item(), processor.code_vocab[""]) + self.assertEqual(values[0, 1].item(), processor.code_vocab[""]) + + def test_vocab_size_method(self): + """Test vocab_size() returns correct size.""" + processor = StageNetProcessor() + samples = [ + {"data": ([0.0], [["A", "B", "C"]])}, + ] + processor.fit(samples, "data") + + # Vocab: , , A, B, C = 5 + self.assertEqual(processor.size(), 5) + self.assertEqual(len(processor.code_vocab), 5) + + def test_repr(self): + """Test string representation includes key info.""" + processor = StageNetProcessor(padding=10) + samples = [{"data": ([0.0], [["A", "B"]])}] + processor.fit(samples, "data") + + repr_str = repr(processor) + self.assertIn("StageNetProcessor", repr_str) + self.assertIn("is_nested=True", repr_str) + self.assertIn("vocab_size=4", repr_str) + self.assertIn("max_nested_len=12", repr_str) # 2 + 10 + self.assertIn("padding=10", repr_str) + + +class TestStageNetTensorProcessor(unittest.TestCase): + """Tests for StageNetTensorProcessor (numeric values).""" + + def test_flat_numerics(self): + """Test processing flat numeric sequences.""" + processor = StageNetTensorProcessor() + samples = [ + {"data": ([0.0, 1.5, 2.3], [1.0, 2.0, 3.0])}, + ] + processor.fit(samples, "data") + + # Check structure detection + self.assertFalse(processor._is_nested) + self.assertEqual(processor.size, 1) + + # Process data + time, values = processor.process(([0.0, 1.5], [1.5, 2.5])) + + # Check shapes + self.assertEqual(time.shape, (2,)) + self.assertEqual(values.shape, (2,)) + self.assertEqual(values.dtype, torch.float) + + # Check values + self.assertAlmostEqual(values[0].item(), 1.5, places=5) + self.assertAlmostEqual(values[1].item(), 2.5, places=5) + + def test_nested_numerics(self): + """Test processing nested numeric sequences (feature vectors).""" + processor = StageNetTensorProcessor() + samples = [ + {"data": ([0.0, 1.5], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, + ] + processor.fit(samples, "data") + + # Check structure detection + self.assertTrue(processor._is_nested) + self.assertEqual(processor.size, 3) # 3 features + + # Process data + time, values = processor.process( + ([0.0, 1.5], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + ) + + # Check shapes + self.assertEqual(time.shape, (2,)) + self.assertEqual(values.shape, (2, 3)) # 2 timesteps, 3 features + self.assertEqual(values.dtype, torch.float) + + def test_forward_fill_flat(self): + """Test forward-fill imputation for flat numerics.""" + processor = StageNetTensorProcessor() + samples = [{"data": ([0.0], [1.0, 2.0, 3.0])}] + processor.fit(samples, 
"data") + + # Process with None/NaN + time, values = processor.process(([0.0, 1.0, 2.0], [1.0, None, 3.0])) + + # None should be forward-filled to 1.0 + self.assertAlmostEqual(values[0].item(), 1.0, places=5) + self.assertAlmostEqual(values[1].item(), 1.0, places=5) # Forward filled + self.assertAlmostEqual(values[2].item(), 3.0, places=5) + + def test_forward_fill_nested(self): + """Test forward-fill imputation for nested numerics. + + Forward-fill works per feature dimension across timesteps: + - For each feature column, None is filled with previous timestep's value + - If no previous value exists, it becomes 0.0 + """ + processor = StageNetTensorProcessor() + samples = [{"data": ([0.0], [[1.0, 2.0, 3.0]])}] + processor.fit(samples, "data") + + # Process with None values + time, values = processor.process( + ([0.0, 1.0], [[1.0, None, 3.0], [None, 5.0, 6.0]]) + ) + + # First timestep: None at position 1 becomes 0.0 (no prior value for feature 1) + self.assertAlmostEqual(values[0, 0].item(), 1.0, places=5) + self.assertAlmostEqual(values[0, 1].item(), 0.0, places=5) # No prior + self.assertAlmostEqual(values[0, 2].item(), 3.0, places=5) + + # Second timestep: None at position 0 is forward-filled from first timestep + self.assertAlmostEqual(values[1, 0].item(), 1.0, places=5) # Forward filled + self.assertAlmostEqual(values[1, 1].item(), 5.0, places=5) + self.assertAlmostEqual(values[1, 2].item(), 6.0, places=5) + + def test_forward_fill_first_value_none(self): + """Test forward-fill when first value is None (should be 0.0).""" + processor = StageNetTensorProcessor() + samples = [{"data": ([0.0], [1.0, 2.0])}] + processor.fit(samples, "data") + + # Process with None as first value + time, values = processor.process(([0.0, 1.0], [None, 2.0])) + + # First None should become 0.0 (no prior value) + self.assertEqual(values[0].item(), 0.0) + self.assertAlmostEqual(values[1].item(), 2.0, places=5) + + def test_time_processing_tensor(self): + """Test time data processing for tensor processor.""" + processor = StageNetTensorProcessor() + samples = [{"data": ([0.0, 1.5], [[1.0, 2.0]])}] + processor.fit(samples, "data") + + # Test with time data + time, values = processor.process(([0.0, 1.5], [[1.0, 2.0], [3.0, 4.0]])) + + self.assertIsNotNone(time) + self.assertEqual(time.dtype, torch.float) + self.assertEqual(time[0].item(), 0.0) + self.assertEqual(time[1].item(), 1.5) + + def test_no_time_tensor(self): + """Test processing without time data.""" + processor = StageNetTensorProcessor() + samples = [{"data": ([0.0], [[1.0, 2.0]])}] + processor.fit(samples, "data") + + # Process without time + time, values = processor.process((None, [[1.0, 2.0], [3.0, 4.0]])) + + self.assertIsNone(time) + self.assertEqual(values.shape, (2, 2)) + + def test_repr_tensor(self): + """Test string representation for tensor processor.""" + processor = StageNetTensorProcessor() + samples = [{"data": ([0.0], [[1.0, 2.0, 3.0]])}] + processor.fit(samples, "data") + + repr_str = repr(processor) + self.assertIn("StageNetTensorProcessor", repr_str) + self.assertIn("is_nested=True", repr_str) + self.assertIn("feature_dim=3", repr_str) + + +class TestStageNetProcessorIntegration(unittest.TestCase): + """Integration tests for realistic scenarios.""" + + def test_mortality_prediction_scenario(self): + """Test realistic mortality prediction with ICD codes and labs.""" + icd_processor = StageNetProcessor(padding=20) + lab_processor = StageNetTensorProcessor() + + # Simulate patient data + icd_samples = [ + { + "icd_codes": ( + [0.0, 24.0, 
48.0], + [["D1", "D2"], ["D3", "D4", "D5"], ["D6"]], + ) + }, + ] + lab_samples = [ + { + "labs": ( + [0.0, 12.0, 24.0], + [[98.6, 120.0, 80.0], [99.1, 125.0, 85.0], [98.0, 115.0, 75.0]], + ) + }, + ] + + icd_processor.fit(icd_samples, "icd_codes") + lab_processor.fit(lab_samples, "labs") + + # Process new patient with unseen codes + icd_time, icd_values = icd_processor.process( + ([0.0, 24.0], [["D1", "NEWCODE"], ["D3"]]) + ) + + # Check unknown code handling + self.assertEqual(icd_values[0, 1].item(), icd_processor.code_vocab[""]) + + # Check padding (max is 3 + 20 = 23) + self.assertEqual(icd_values.shape[1], 23) + + # Process labs with None values + # Forward-fill works per feature column across timesteps + lab_time, lab_values = lab_processor.process( + ([0.0, 12.0], [[98.6, None, 80.0], [None, 125.0, 85.0]]) + ) + + # Check forward-fill for labs + # Feature 1: None at first timestep becomes 0.0 (no prior) + self.assertAlmostEqual(lab_values[0, 1].item(), 0.0, places=5) + # Feature 0: None at second timestep filled from first (98.6) + self.assertAlmostEqual(lab_values[1, 0].item(), 98.6, places=5) + + def test_vocab_size_for_embedding_layer(self): + """Test that vocab_size() returns correct size for nn.Embedding.""" + processor = StageNetProcessor(padding=0) + samples = [ + {"data": ([0.0], [["A", "B", "C", "D", "E"]])}, + ] + processor.fit(samples, "data") + + # Create embedding layer + vocab_size = processor.size() + embedding = torch.nn.Embedding(vocab_size, 128) + + # Process data with all codes including unknown + time, values = processor.process(([0.0], [["A", "B", "UNKNOWN"]])) + + # Should work without IndexError + # Shape is (1, max_nested_len, 128) where max_nested_len=5 + embedded = embedding(values) + self.assertEqual(embedded.shape[0], 1) # 1 visit + self.assertEqual(embedded.shape[1], 5) # Padded to max_nested_len + self.assertEqual(embedded.shape[2], 128) # Embedding dim + + # Verify max index is within bounds + max_idx = values.max().item() + self.assertLess(max_idx, vocab_size) + + +if __name__ == "__main__": + unittest.main() From 8da24baa72500eda7f85c2caf9c83d3d16919111 Mon Sep 17 00:00:00 2001 From: John Wu Date: Sun, 23 Nov 2025 04:15:34 -0600 Subject: [PATCH 5/7] new update --- examples/mortality_mimic4_stagenet_v2.py | 6 +- .../tutorial_stagenet_comprehensive.ipynb | 780 +++++++++++------- pyhealth/datasets/base_dataset.py | 65 +- 3 files changed, 511 insertions(+), 340 deletions(-) diff --git a/examples/mortality_mimic4_stagenet_v2.py b/examples/mortality_mimic4_stagenet_v2.py index 0818704fa..3fc2722a9 100644 --- a/examples/mortality_mimic4_stagenet_v2.py +++ b/examples/mortality_mimic4_stagenet_v2.py @@ -166,7 +166,7 @@ def generate_holdout_set( sample_dataset = base_dataset.set_task( MortalityPredictionStageNetMIMIC4(padding=20), - num_workers=4, + num_workers=1, cache_dir=cache_dir, input_processors=input_processors, output_processors=output_processors, @@ -175,7 +175,7 @@ def generate_holdout_set( print("\n=== Fitting New Processors ===") sample_dataset = base_dataset.set_task( MortalityPredictionStageNetMIMIC4(padding=20), - num_workers=4, + num_workers=1, cache_dir=cache_dir, ) @@ -227,7 +227,7 @@ def generate_holdout_set( trainer.train( train_dataloader=train_loader, val_dataloader=val_loader, - epochs=1, + epochs=20, monitor="roc_auc", optimizer_params={"lr": 1e-5}, ) diff --git a/examples/tutorial_stagenet_comprehensive.ipynb b/examples/tutorial_stagenet_comprehensive.ipynb index 17afeb973..3ae8082cf 100644 --- 
a/examples/tutorial_stagenet_comprehensive.ipynb +++ b/examples/tutorial_stagenet_comprehensive.ipynb @@ -21,7 +21,7 @@ "Install the latest alpha release of StageNet modernized for PyHealth:\n", "\n", "```bash\n", - "pip install pyhealth==2.0a8\n", + "pip install pyhealth==2.0a10\n", "```\n", "\n", "## Loading Data\n", @@ -41,50 +41,10 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "fd30b75b", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/johnwu3/miniconda3/envs/medical_coding_demo/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Memory usage Starting MIMIC4Dataset init: 772.0 MB\n", - "Initializing MIMIC4EHRDataset with tables: ['patients', 'admissions', 'diagnoses_icd', 'procedures_icd', 'labevents'] (dev mode: False)\n", - "Using default EHR config: /home/johnwu3/projects/PyHealth_Branch_Testing/PyHealth/pyhealth/datasets/configs/mimic4_ehr.yaml\n", - "Memory usage Before initializing mimic4_ehr: 772.0 MB\n", - "Duplicate table names in tables list. Removing duplicates.\n", - "Initializing mimic4_ehr dataset from /srv/local/data/physionet.org/files/mimiciv/2.2/ (dev mode: False)\n", - "Scanning table: procedures_icd from /srv/local/data/physionet.org/files/mimiciv/2.2/hosp/procedures_icd.csv.gz\n", - "Joining with table: /srv/local/data/physionet.org/files/mimiciv/2.2/hosp/admissions.csv.gz\n", - "Original path does not exist. Using alternative: /srv/local/data/physionet.org/files/mimiciv/2.2/hosp/admissions.csv\n", - "Scanning table: labevents from /srv/local/data/physionet.org/files/mimiciv/2.2/hosp/labevents.csv.gz\n", - "Joining with table: /srv/local/data/physionet.org/files/mimiciv/2.2/hosp/d_labitems.csv.gz\n", - "Scanning table: icustays from /srv/local/data/physionet.org/files/mimiciv/2.2/icu/icustays.csv.gz\n", - "Scanning table: patients from /srv/local/data/physionet.org/files/mimiciv/2.2/hosp/patients.csv.gz\n", - "Scanning table: diagnoses_icd from /srv/local/data/physionet.org/files/mimiciv/2.2/hosp/diagnoses_icd.csv.gz\n", - "Joining with table: /srv/local/data/physionet.org/files/mimiciv/2.2/hosp/admissions.csv.gz\n", - "Original path does not exist. Using alternative: /srv/local/data/physionet.org/files/mimiciv/2.2/hosp/admissions.csv\n", - "Scanning table: admissions from /srv/local/data/physionet.org/files/mimiciv/2.2/hosp/admissions.csv.gz\n", - "Original path does not exist. 
Using alternative: /srv/local/data/physionet.org/files/mimiciv/2.2/hosp/admissions.csv\n", - "Memory usage After initializing mimic4_ehr: 30378.1 MB\n", - "Memory usage After EHR dataset initialization: 30378.1 MB\n", - "Memory usage Before combining data: 30378.1 MB\n", - "Combining data from ehr dataset\n", - "Creating combined dataframe\n", - "Memory usage After combining data: 30378.1 MB\n", - "Memory usage Completed MIMIC4Dataset init: 30378.1 MB\n" - ] - } - ], + "outputs": [], "source": [ "\"\"\"\n", "Example of using StageNet for mortality prediction on MIMIC-IV.\n", @@ -104,7 +64,6 @@ "from pyhealth.models import StageNet\n", "from pyhealth.tasks import MortalityPredictionStageNetMIMIC4\n", "from pyhealth.trainer import Trainer\n", - "import torch\n", "\n", "# STEP 1: Load MIMIC-IV base dataset\n", "base_dataset = MIMIC4Dataset(\n", @@ -116,6 +75,7 @@ " \"procedures_icd\",\n", " \"labevents\",\n", " ],\n", + " dev=True\n", ")" ] }, @@ -192,7 +152,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "ba3e055d", "metadata": {}, "outputs": [], @@ -204,9 +164,8 @@ "from pyhealth.processors import register_processor\n", "from pyhealth.processors.base_processor import FeatureProcessor\n", "\n", - "\n", "@register_processor(\"stagenet_ex\")\n", - "class StageNetProcessorEx(FeatureProcessor):\n", + "class StageNetProcessor(FeatureProcessor):\n", " \"\"\"\n", " Feature processor for StageNet CODE inputs with coupled value/time data.\n", "\n", @@ -222,6 +181,12 @@ " - List of strings -> flat code sequences\n", " - List of lists of strings -> nested code sequences\n", "\n", + " Args:\n", + " padding: Additional padding to add on top of the observed maximum nested\n", + " sequence length. The actual padding length will be observed_max + padding.\n", + " This ensures the processor can handle sequences longer than those in the\n", + " training data. Default: 0 (no extra padding). 
Only applies to nested sequences.\n", + "\n", " Returns:\n", " Tuple of (time_tensor, value_tensor) where time_tensor can be None\n", "\n", @@ -233,10 +198,11 @@ " >>> values.shape # (3,) - sequence of code indices\n", " >>> time.shape # (3,) - time intervals\n", "\n", - " >>> # Case 2: Nested codes with time\n", + " >>> # Case 2: Nested codes with time (with custom padding for extra capacity)\n", + " >>> processor = StageNetProcessor(padding=20)\n", " >>> data = ([0.0, 1.5], [[\"A\", \"B\"], [\"C\"]])\n", " >>> time, values = processor.process(data)\n", - " >>> values.shape # (2, max_inner_len) - padded nested sequences\n", + " >>> values.shape # (2, observed_max + 20) - padded nested sequences\n", " >>> time.shape # (2,)\n", "\n", " >>> # Case 3: Codes without time\n", @@ -246,11 +212,14 @@ " >>> time # None\n", " \"\"\"\n", "\n", - " def __init__(self):\n", - " self.code_vocab: Dict[Any, int] = {\"\": -1, \"\": 0}\n", + " def __init__(self, padding: int = 0):\n", + " # will be set to len(vocab) after fit\n", + " self.code_vocab: Dict[Any, int] = {\"\": None, \"\": 0}\n", " self._next_index = 1\n", " self._is_nested = None # Will be determined during fit\n", - " self._max_nested_len = None # Max inner sequence length for nested codes\n", + " # Max inner sequence length for nested codes\n", + " self._max_nested_len = None\n", + " self._padding = padding # Additional padding beyond observed max\n", "\n", " def fit(self, samples: List[Dict], key: str) -> None:\n", " \"\"\"Build vocabulary and determine input structure.\n", @@ -301,9 +270,15 @@ " self.code_vocab[code] = self._next_index\n", " self._next_index += 1\n", "\n", - " # Store max nested length (at least 1 for empty sequences)\n", + " # Store max nested length: add user-specified padding to observed maximum\n", + " # This ensures the processor can handle sequences longer than those in training data\n", " if self._is_nested:\n", - " self._max_nested_len = max(1, max_inner_len)\n", + " observed_max = max(1, max_inner_len)\n", + " self._max_nested_len = observed_max + self._padding\n", + "\n", + " # Set token to the next available index\n", + " # Since is already in the vocab dict, we use _next_index\n", + " self.code_vocab[\"\"] = self._next_index\n", "\n", " def process(\n", " self, value: Tuple[Optional[List], List]\n", @@ -390,17 +365,19 @@ " return (\n", " f\"StageNetProcessor(is_nested={self._is_nested}, \"\n", " f\"vocab_size={len(self.code_vocab)}, \"\n", - " f\"max_nested_len={self._max_nested_len})\"\n", + " f\"max_nested_len={self._max_nested_len}, \"\n", + " f\"padding={self._padding})\"\n", " )\n", " else:\n", " return (\n", " f\"StageNetProcessor(is_nested={self._is_nested}, \"\n", - " f\"vocab_size={len(self.code_vocab)})\"\n", + " f\"vocab_size={len(self.code_vocab)}, \"\n", + " f\"padding={self._padding})\"\n", " )\n", "\n", "\n", "@register_processor(\"stagenet_tensor_ex\")\n", - "class StageNetTensorProcessorEx(FeatureProcessor):\n", + "class StageNetTensorProcessor(FeatureProcessor):\n", " \"\"\"\n", " Feature processor for StageNet NUMERIC inputs with coupled value/time data.\n", "\n", @@ -565,18 +542,18 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "a2288cdc", "metadata": {}, "outputs": [], "source": [ "from datetime import datetime\n", - "from typing import Any, ClassVar, Dict, List\n", + "from typing import Any, ClassVar, Dict, List, Tuple\n", "\n", "import polars as pl\n", "\n", "from pyhealth.tasks.base_task import BaseTask\n", - "from pyhealth.processors import 
StageNetProcessor, StageNetTensorProcessor\n", + "\n", "\n", "class MortalityPredictionStageNetMIMIC4(BaseTask):\n", " \"\"\"Task for predicting mortality using MIMIC-IV with StageNet format.\n", @@ -596,6 +573,10 @@ " - Multiple itemids per category → take first observed value\n", " - Missing categories → None/NaN in vector\n", "\n", + " Args:\n", + " padding: Additional padding for StageNet processor to handle\n", + " sequences longer than observed during training. Default: 0.\n", + "\n", " Attributes:\n", " task_name (str): The name of the task.\n", " input_schema (Dict[str, str]): The schema for input data:\n", @@ -607,11 +588,20 @@ " \"\"\"\n", "\n", " task_name: str = \"MortalityPredictionStageNetMIMIC4\"\n", - " input_schema: Dict[str, str] = {\n", - " \"icd_codes\": StageNetProcessor,\n", - " \"labs\": StageNetTensorProcessor,\n", - " }\n", - " output_schema: Dict[str, str] = {\"mortality\": \"binary\"}\n", + "\n", + " def __init__(self, padding: int = 0):\n", + " \"\"\"Initialize task with optional padding parameter.\n", + "\n", + " Args:\n", + " padding: Additional padding for nested sequences. Default: 0.\n", + " \"\"\"\n", + " self.padding = padding\n", + " # Use tuple format to pass kwargs to processor\n", + " self.input_schema: Dict[str, Tuple[str, Dict[str, Any]]] = {\n", + " \"icd_codes\": (\"stagenet\", {\"padding\": padding}),\n", + " \"labs\": (\"stagenet_tensor\", {}),\n", + " }\n", + " self.output_schema: Dict[str, str] = {\"mortality\": \"binary\"}\n", "\n", " # Organize lab items by category\n", " # Each category will map to ONE dimension in the output vector\n", @@ -853,85 +843,58 @@ "id": "38b799ce", "metadata": {}, "source": [ - "## Setting the task and caching the data for quicker use down the road\n", - "We can finally set our task and get our training set below. Notice that we save a processed version of our dataset in .parquet files in our \"cache_dir\" here. We can also define a number of works for faster parallel processing (note this can be unstable if the value is too high)." + "## Setting the task and caching the data for quicker use down the road with padding\n", + "We can finally set our task and get our training set below. Notice that we save a processed version of our dataset in .parquet files in our \"cache_dir\" here. 
We can also define a number of works for faster parallel processing (note this can be unstable if the value is too high).\n", + "\n", + "We can also save and load processors so we don't need to refit the processor again (and we can also transfer processors across different samples)" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "8e01f7ec", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Setting task MortalityPredictionStageNetMIMIC4 for mimic4 base dataset...\n", - "Loading cached samples from /home/johnwu3/projects/mimic4_stagenet_cache/MortalityPredictionStageNetMIMIC4.parquet\n", - "Loaded 137778 cached samples\n", - "Label mortality vocab: {0: 0, 1: 1}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Processing samples: 100%|██████████| 137778/137778 [00:33<00:00, 4171.51it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Generated 137778 samples for task MortalityPredictionStageNetMIMIC4\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - } - ], + "outputs": [], "source": [ - "# STEP 2: Apply StageNet mortality prediction task\n", - "sample_dataset = base_dataset.set_task(\n", - " MortalityPredictionStageNetMIMIC4(),\n", - " num_workers=4,\n", - " cache_dir=\"/home/johnwu3/projects/mimic4_stagenet_cache\",\n", - ")" + "from pyhealth.datasets.utils import save_processors, load_processors\n", + "import os \n", + "processor_dir = \"../../output/processors/stagenet_mortality_mimic4\"\n", + "cache_dir = \"../../mimic4_stagenet_cache_v3\"\n", + "\n", + "if os.path.exists(os.path.join(processor_dir, \"input_processors.pkl\")):\n", + " print(\"\\n=== Loading Pre-fitted Processors ===\")\n", + " input_processors, output_processors = load_processors(processor_dir)\n", + "\n", + " sample_dataset = base_dataset.set_task(\n", + " MortalityPredictionStageNetMIMIC4(padding=20),\n", + " num_workers=1,\n", + " cache_dir=cache_dir,\n", + " input_processors=input_processors,\n", + " output_processors=output_processors,\n", + " )\n", + "else:\n", + " print(\"\\n=== Fitting New Processors ===\")\n", + " sample_dataset = base_dataset.set_task(\n", + " MortalityPredictionStageNetMIMIC4(padding=20),\n", + " num_workers=1,\n", + " cache_dir=cache_dir,\n", + " )\n", + "\n", + " # Save processors for future runs\n", + " print(\"\\n=== Saving Processors ===\")\n", + " save_processors(sample_dataset, processor_dir)\n", + "\n", + "print(f\"Total samples: {len(sample_dataset)}\")\n", + "print(f\"Input schema: {sample_dataset.input_schema}\")\n", + "print(f\"Output schema: {sample_dataset.output_schema}\")" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "1c765bec", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Sample structure:\n", - " Patient ID: 17503482\n", - "ICD Codes: (tensor([ 0.0000, 315.3167]), tensor([[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 0, 0, 0, 0,\n", - " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", - " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", - " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", - " 0, 0, 0, 0, 0, 0, 0],\n", - " [15, 16, 3, 4, 17, 7, 18, 19, 12, 14, 20, 0, 0, 0, 0, 0, 0, 0,\n", - " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", - " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", - " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", - 
" 0, 0, 0, 0, 0, 0, 0]]))\n", - " Labs shape: 3 timesteps\n", - " Mortality: tensor([0.])\n" - ] - } - ], + "outputs": [], "source": [ "# Inspect a sample\n", "sample = sample_dataset.samples[0]\n", @@ -955,168 +918,20 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "1708dca9", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'icd_codes': pyhealth.processors.stagenet_processor.StageNetProcessor,\n", - " 'labs': pyhealth.processors.stagenet_processor.StageNetTensorProcessor}" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "sample_dataset.input_schema" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "id": "0333b99e", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Model initialized with 9337777 parameters\n", - "StageNet(\n", - " (embedding_model): EmbeddingModel(embedding_layers=ModuleDict(\n", - " (icd_codes): Embedding(36681, 128, padding_idx=0)\n", - " (labs): Linear(in_features=10, out_features=128, bias=True)\n", - " ))\n", - " (stagenet): ModuleDict(\n", - " (icd_codes): StageNetLayer(\n", - " (kernel): Linear(in_features=129, out_features=1542, bias=True)\n", - " (recurrent_kernel): Linear(in_features=385, out_features=1542, bias=True)\n", - " (nn_scale): Linear(in_features=384, out_features=64, bias=True)\n", - " (nn_rescale): Linear(in_features=64, out_features=384, bias=True)\n", - " (nn_conv): Conv1d(384, 384, kernel_size=(10,), stride=(1,))\n", - " (nn_dropconnect): Dropout(p=0.3, inplace=False)\n", - " (nn_dropconnect_r): Dropout(p=0.3, inplace=False)\n", - " (nn_dropout): Dropout(p=0.3, inplace=False)\n", - " (nn_dropres): Dropout(p=0.3, inplace=False)\n", - " )\n", - " (labs): StageNetLayer(\n", - " (kernel): Linear(in_features=129, out_features=1542, bias=True)\n", - " (recurrent_kernel): Linear(in_features=385, out_features=1542, bias=True)\n", - " (nn_scale): Linear(in_features=384, out_features=64, bias=True)\n", - " (nn_rescale): Linear(in_features=64, out_features=384, bias=True)\n", - " (nn_conv): Conv1d(384, 384, kernel_size=(10,), stride=(1,))\n", - " (nn_dropconnect): Dropout(p=0.3, inplace=False)\n", - " (nn_dropconnect_r): Dropout(p=0.3, inplace=False)\n", - " (nn_dropout): Dropout(p=0.3, inplace=False)\n", - " (nn_dropres): Dropout(p=0.3, inplace=False)\n", - " )\n", - " )\n", - " (fc): Linear(in_features=768, out_features=1, bias=True)\n", - ")\n", - "Metrics: ['pr_auc', 'roc_auc', 'accuracy', 'f1']\n", - "Device: cuda:4\n", - "\n", - "Training:\n", - "Batch size: 256\n", - "Optimizer: \n", - "Optimizer params: {'lr': 1e-05}\n", - "Weight decay: 0.0\n", - "Max grad norm: None\n", - "Val dataloader: \n", - "Monitor: roc_auc\n", - "Monitor criterion: max\n", - "Epochs: 1\n", - "\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Epoch 0 / 1: 100%|██████████| 431/431 [08:32<00:00, 1.19s/it]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "--- Train epoch-0, step-431 ---\n", - "loss: 0.4052\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n", - "Evaluation: 100%|██████████| 54/54 [00:14<00:00, 3.62it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "--- Eval epoch-0, step-431 ---\n", - "pr_auc: 0.0840\n", - "roc_auc: 0.5105\n", - "accuracy: 0.9439\n", - "f1: 0.0000\n", - "loss: 0.2540\n", - "New best roc_auc score (0.5105) at epoch-0, 
step-431\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Loaded best model\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Evaluation: 100%|██████████| 54/54 [00:15<00:00, 3.49it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Test Results:\n", - " pr_auc: 0.0817\n", - " roc_auc: 0.4831\n", - " accuracy: 0.9382\n", - " f1: 0.0000\n", - " loss: 0.2813\n", - "\n", - "Sample predictions:\n", - " Predicted probabilities: tensor([[0.0171],\n", - " [0.0088],\n", - " [0.0096],\n", - " [0.0073],\n", - " [0.0144]], device='cuda:4')\n", - " True labels: tensor([[1.],\n", - " [0.],\n", - " [0.],\n", - " [0.],\n", - " [0.]], device='cuda:4')\n" - ] - } - ], + "outputs": [], "source": [ "# STEP 3: Split dataset\n", "train_dataset, val_dataset, test_dataset = split_by_patient(\n", @@ -1172,20 +987,411 @@ "print(f\" True labels: {output['y_true'][:5]}\")" ] }, + { + "cell_type": "markdown", + "id": "e877f9cf", + "metadata": {}, + "source": [ + "## Inference On a Holdout Set Example\n", + "Below, we'll generate some pseudo samples with a bunch of unknown tokens and visit lengths beyond what's observed in the training set." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "59475dc0", + "metadata": {}, + "outputs": [], + "source": [ + "from pyhealth.datasets.base_dataset import SampleDataset\n", + "import random\n", + "import numpy as np\n", + "\n", + "def generate_holdout_set(\n", + " sample_dataset: SampleDataset, num_samples: int = 10, seed: int = 42\n", + ") -> SampleDataset:\n", + " \"\"\"Generate synthetic hold-out set with unseen codes and varying lengths.\n", + "\n", + " This function creates synthetic samples to test the processor's ability to:\n", + " 1. Handle completely unseen tokens (mapped to )\n", + " 2. 
Handle sequence lengths larger than training but within padding\n", + "\n", + " Args:\n", + " sample_dataset: Original SampleDataset with fitted processors\n", + " num_samples: Number of synthetic samples to generate\n", + " seed: Random seed for reproducibility\n", + "\n", + " Returns:\n", + " SampleDataset with synthetic samples using fitted processors\n", + " \"\"\"\n", + " random.seed(seed)\n", + " np.random.seed(seed)\n", + "\n", + " # Get the fitted processors\n", + " icd_processor = sample_dataset.input_processors[\"icd_codes\"]\n", + "\n", + " # Get max nested length from ICD processor\n", + " max_icd_len = icd_processor._max_nested_len\n", + " # Handle both old and new processor versions\n", + " padding = getattr(icd_processor, \"_padding\", 0)\n", + "\n", + " print(\"\\n=== Hold-out Set Generation ===\")\n", + " print(f\"Processor attributes: {dir(icd_processor)}\")\n", + " print(f\"Has _padding attribute: {hasattr(icd_processor, '_padding')}\")\n", + " print(f\"ICD max nested length: {max_icd_len}\")\n", + " print(f\"Padding (via getattr): {padding}\")\n", + " if hasattr(icd_processor, \"_padding\"):\n", + " print(f\"Padding (direct access): {icd_processor._padding}\")\n", + " print(f\"Observed max (without padding): {max_icd_len - padding}\")\n", + "\n", + " synthetic_samples = []\n", + "\n", + " for i in range(num_samples):\n", + " # Generate random number of visits (1-5)\n", + " num_visits = random.randint(1, 5)\n", + "\n", + " # Generate ICD codes with unseen tokens\n", + " icd_codes_list = []\n", + " icd_times_list = []\n", + "\n", + " for visit_idx in range(num_visits):\n", + " # Generate sequence length between observed_max and max_icd_len\n", + " # This tests the padding capacity\n", + " observed_max = max_icd_len - padding\n", + " seq_len = random.randint(max(1, observed_max - 2), max_icd_len - 1)\n", + "\n", + " # Generate unseen codes\n", + " visit_codes = [f\"NEWCODE_{i}_{visit_idx}_{j}\" for j in range(seq_len)]\n", + " icd_codes_list.append(visit_codes)\n", + "\n", + " # Generate time intervals (hours from previous visit)\n", + " if visit_idx == 0:\n", + " icd_times_list.append(0.0)\n", + " else:\n", + " icd_times_list.append(random.uniform(24.0, 720.0))\n", + "\n", + " # Generate lab data (10-dimensional vectors)\n", + " num_lab_timestamps = random.randint(5, 15)\n", + " lab_values_list = []\n", + " lab_times_list = []\n", + "\n", + " for ts_idx in range(num_lab_timestamps):\n", + " # Generate 10D vector with some random values and some None\n", + " lab_vector = []\n", + " for dim in range(10):\n", + " if random.random() < 0.8: # 80% chance of value\n", + " lab_vector.append(random.uniform(50.0, 150.0))\n", + " else:\n", + " lab_vector.append(None)\n", + "\n", + " lab_values_list.append(lab_vector)\n", + " lab_times_list.append(random.uniform(0.0, 48.0))\n", + "\n", + " # Create sample in the expected format (before processing)\n", + " synthetic_sample = {\n", + " \"patient_id\": f\"HOLDOUT_PATIENT_{i}\",\n", + " \"icd_codes\": (icd_times_list, icd_codes_list),\n", + " \"labs\": (lab_times_list, lab_values_list),\n", + " \"mortality\": random.randint(0, 1),\n", + " }\n", + "\n", + " synthetic_samples.append(synthetic_sample)\n", + "\n", + " # Create a new SampleDataset with the FITTED processors\n", + " holdout_dataset = SampleDataset(\n", + " samples=synthetic_samples,\n", + " input_schema=sample_dataset.input_schema,\n", + " output_schema=sample_dataset.output_schema,\n", + " dataset_name=f\"{sample_dataset.dataset_name}_holdout\",\n", + " 
task_name=sample_dataset.task_name,\n", + " input_processors=sample_dataset.input_processors,\n", + " output_processors=sample_dataset.output_processors,\n", + " )\n", + "\n", + " print(f\"Generated {len(holdout_dataset)} synthetic samples\")\n", + " sample_seq_lens = [len(s[\"icd_codes\"][1]) for s in synthetic_samples[:3]]\n", + " print(f\"Sample ICD sequence lengths: {sample_seq_lens}\")\n", + " sample_codes_per_visit = [\n", + " [len(visit) for visit in s[\"icd_codes\"][1]] for s in synthetic_samples[:3]\n", + " ]\n", + " print(f\"Sample codes per visit: {sample_codes_per_visit}\")\n", + "\n", + " return holdout_dataset\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a9d898c3", + "metadata": {}, + "outputs": [], + "source": [ + "holdout_dataset = generate_holdout_set(sample_dataset, num_samples=10, seed=42)\n", + "# Create dataloader for hold-out set\n", + "holdout_loader = get_dataloader(holdout_dataset, batch_size=16, shuffle=False)\n", + "# Inspect processed samples\n", + "print(\"\\n=== Inspecting Processed Hold-out Samples ===\")\n", + "holdout_batch = next(iter(holdout_loader))\n", + "with torch.no_grad():\n", + " holdout_output = model(**holdout_batch)" + ] + }, { "cell_type": "markdown", "id": "9a3d1b7f", "metadata": {}, "source": [ - "## Post-hoc ML processing (TBD)\n", + "## Post-hoc ML processing (Interpretability)\n", "We note that once the model's trained and evaluation metrics are derived. People may be interested in things like post-hoc interpretability or uncertainty quantification.\n", "\n", "We note that this is quite a work-in-progress for PyHealth 2.0, but the roadmap includes the following:\n", "\n", - "- Layer-wise relevance propagation (deep NN-based interpretability)\n", - "- Conformal Prediction: We do have many other UQ techniques [here](https://pyhealth.readthedocs.io/en/latest/api/calib.html)\n", + "- Integrated Gradients (deep NN-based interpretability)\n", + "- Conformal Prediction: We do have many other UQ techniques [here](https://pyhealth.readthedocs.io/en/latest/api/calib.html)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "f268b6ff", + "metadata": {}, + "outputs": [], + "source": [ + "from pyhealth.medcode import CrossMap, InnerMap\n", + "\n", + "LAB_CATEGORY_NAMES = MortalityPredictionStageNetMIMIC4.LAB_CATEGORY_NAMES\n", + "\n", + "def unravel(flat_index: int, shape: torch.Size):\n", + " coords = []\n", + " remaining = flat_index\n", + " for dim in reversed(shape):\n", + " coords.append(remaining % dim)\n", + " remaining //= dim\n", + " return list(reversed(coords))\n", + "\n", + "def decode_token(idx: int, processor, feature_key: str):\n", + " icd9cm = InnerMap.load(\"ICD9CM\")\n", + "\n", + " if processor is None or not hasattr(processor, \"code_vocab\"):\n", + " return str(idx)\n", + " reverse_vocab = {index: token for token, index in processor.code_vocab.items()}\n", + " token = reverse_vocab.get(idx, f\"\")\n", + "\n", + " if feature_key == \"icd_codes\" and token not in {\"\", \"\"}:\n", + " desc = icd9cm.lookup(token)\n", + " if desc:\n", + " return f\"{token}: {desc}\"\n", + "\n", + " return token\n", + "\n", + "\n", + "def print_top_attributions(\n", + " attributions,\n", + " batch,\n", + " processors,\n", + " top_k: int = 10,\n", + "):\n", + " for feature_key, attr in attributions.items():\n", + " attr_cpu = attr.detach().cpu()\n", + " if attr_cpu.dim() == 0 or attr_cpu.size(0) == 0:\n", + " continue\n", + "\n", + " feature_input = batch[feature_key]\n", + " if 
isinstance(feature_input, tuple):\n", + " feature_input = feature_input[1]\n", + " feature_input = feature_input.detach().cpu()\n", + "\n", + " flattened = attr_cpu[0].flatten()\n", + " if flattened.numel() == 0:\n", + " continue\n", + "\n", + " print(f\"\\nFeature: {feature_key}\")\n", + " k = min(top_k, flattened.numel())\n", + " top_values, top_indices = torch.topk(flattened.abs(), k=k)\n", + " processor = processors.get(feature_key) if processors else None\n", + " is_continuous = torch.is_floating_point(feature_input)\n", + "\n", + " for rank, (_, flat_idx) in enumerate(zip(top_values, top_indices), 1):\n", + " attribution_value = flattened[flat_idx].item()\n", + " coords = unravel(flat_idx.item(), attr_cpu[0].shape)\n", + "\n", + " if is_continuous:\n", + " actual_value = feature_input[0][tuple(coords)].item()\n", + " label = \"\"\n", + " if feature_key == \"labs\" and len(coords) >= 1:\n", + " lab_idx = coords[-1]\n", + " if lab_idx < len(LAB_CATEGORY_NAMES):\n", + " label = f\"{LAB_CATEGORY_NAMES[lab_idx]} \"\n", + " print(\n", + " f\" {rank:2d}. idx={coords} {label}value={actual_value:.4f} \"\n", + " f\"attr={attribution_value:+.6f}\"\n", + " )\n", + " else:\n", + " token_idx = int(feature_input[0][tuple(coords)].item())\n", + " token = decode_token(token_idx, processor, feature_key)\n", + " print(\n", + " f\" {rank:2d}. idx={coords} token='{token}' \"\n", + " f\"attr={attribution_value:+.6f}\"\n", + " )\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "d65d228e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Feature: icd_codes\n", + " 1. idx=[24, 48] token='' attr=-0.000597\n", + " 2. idx=[24, 62] token='' attr=-0.000513\n", + " 3. idx=[24, 41] token='' attr=+0.000425\n", + " 4. idx=[24, 32] token='' attr=-0.000415\n", + " 5. idx=[24, 12] token='' attr=+0.000386\n", + " 6. idx=[24, 50] token='' attr=+0.000382\n", + " 7. idx=[24, 42] token='' attr=-0.000380\n", + " 8. idx=[24, 28] token='' attr=+0.000370\n", + " 9. idx=[24, 57] token='' attr=-0.000350\n", + " 10. idx=[24, 38] token='' attr=-0.000348\n", + "\n", + "Feature: labs\n", + " 1. idx=[401, 5] Calcium value=0.0000 attr=+0.001794\n", + " 2. idx=[401, 3] Bicarbonate value=0.0000 attr=+0.001794\n", + " 3. idx=[401, 4] Glucose value=0.0000 attr=+0.001794\n", + " 4. idx=[401, 6] Magnesium value=0.0000 attr=+0.001794\n", + " 5. idx=[401, 1] Potassium value=0.0000 attr=+0.001794\n", + " 6. idx=[401, 7] Anion Gap value=0.0000 attr=+0.001794\n", + " 7. idx=[401, 8] Osmolality value=0.0000 attr=+0.001794\n", + " 8. idx=[401, 9] Phosphate value=0.0000 attr=+0.001794\n", + " 9. idx=[401, 2] Chloride value=0.0000 attr=+0.001794\n", + " 10. idx=[401, 0] Sodium value=0.0000 attr=+0.001794\n" + ] + } + ], + "source": [ + "from pyhealth.interpret.methods import DeepLift, IntegratedGradients\n", + "def move_batch_to_device(batch, target_device):\n", + " moved = {}\n", + " for key, value in batch.items():\n", + " if isinstance(value, torch.Tensor):\n", + " moved[key] = value.to(target_device)\n", + " elif isinstance(value, tuple):\n", + " moved[key] = tuple(v.to(target_device) for v in value)\n", + " else:\n", + " moved[key] = value\n", + " return moved\n", + "\n", + "device = torch.device(\"cpu\")\n", + "model.to(device)\n", + "ig = IntegratedGradients(model)\n", + "\n", "\n", - "For quick and dirty feature attribution, I would highly recommend something like [SHAP](https://shap.readthedocs.io/en/latest/). 
For conceptual interpretability within the embedding space, I highly recommend looking into sparse autoencoders." + "sample_batch = next(iter(test_loader))\n", + "sample_batch_device = move_batch_to_device(sample_batch, device)\n", + "\n", + "with torch.no_grad():\n", + " output = model(**sample_batch_device)\n", + " probs = output[\"y_prob\"]\n", + " preds = torch.argmax(probs, dim=-1)\n", + " label_key = model.label_key\n", + " true_label = sample_batch_device[label_key]\n", + "\n", + " print(\"\\nModel prediction for the sampled patient:\")\n", + " print(f\" True label: {int(true_label.cpu()[0].item())}\")\n", + " print(f\" Predicted class: {int(preds.cpu()[0].item())}\")\n", + " print(f\" Probabilities: {probs[0].cpu().numpy()}\")\n", + "\n", + "\n", + "attributions = ig.attribute(**sample_batch_device)\n", + "print_top_attributions(attributions, sample_batch_device, input_processors, top_k=10)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "902bb29c", + "metadata": {}, + "outputs": [], + "source": [ + "def build_random_embedding_baseline(\n", + " model: StageNet,\n", + " batch: dict,\n", + " scale: float = 0.01,\n", + " seed: int = 42,\n", + ") -> dict:\n", + " \"\"\"Construct a non-empty baseline directly in embedding space.\n", + "\n", + " DeepLIFT subtracts the baseline embedding from the actual embedding.\n", + " Using pure zeros collapses StageNet masks (all visits become padding),\n", + " so we add small random noise to keep at least one timestep active.\n", + " \"\"\"\n", + "\n", + " torch.manual_seed(seed)\n", + " feature_inputs = {}\n", + " for key in model.feature_keys:\n", + " value = batch[key]\n", + " if isinstance(value, tuple):\n", + " value = value[1]\n", + " feature_inputs[key] = value.to(model.device)\n", + "\n", + " embedded = model.embedding_model(feature_inputs)\n", + " baseline = {}\n", + " for key, emb in embedded.items():\n", + " baseline[key] = torch.randn_like(emb) * scale\n", + " return baseline\n" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "b32ef9e4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Feature: icd_codes\n", + " 1. idx=[0, 1] token='42832: Chronic diastolic heart failure' attr=+0.079825\n", + " 2. idx=[0, 6] token='V5861: Long-term (current) use of anticoagulants' attr=-0.070667\n", + " 3. idx=[0, 5] token='V4501: Cardiac pacemaker in situ' attr=-0.058043\n", + " 4. idx=[0, 10] token='370: Keratitis' attr=+0.056914\n", + " 5. idx=[2, 10] token='V4501: Cardiac pacemaker in situ' attr=-0.050888\n", + " 6. idx=[0, 7] token='4019: Unspecified essential hypertension' attr=-0.048502\n", + " 7. idx=[0, 3] token='4280: Congestive heart failure, unspecified' attr=+0.045676\n", + " 8. idx=[0, 2] token='4233: Cardiac tamponade' attr=+0.037603\n", + " 9. idx=[2, 13] token='4019: Unspecified essential hypertension' attr=-0.031371\n", + " 10. idx=[2, 5] token='4280: Congestive heart failure, unspecified' attr=-0.025716\n", + "\n", + "Feature: labs\n", + " 1. idx=[400, 5] Calcium value=0.0000 attr=+0.004160\n", + " 2. idx=[400, 3] Bicarbonate value=0.0000 attr=+0.004160\n", + " 3. idx=[400, 4] Glucose value=0.0000 attr=+0.004160\n", + " 4. idx=[400, 6] Magnesium value=0.0000 attr=+0.004160\n", + " 5. idx=[400, 1] Potassium value=0.0000 attr=+0.004160\n", + " 6. idx=[400, 7] Anion Gap value=0.0000 attr=+0.004160\n", + " 7. idx=[400, 8] Osmolality value=0.0000 attr=+0.004160\n", + " 8. idx=[400, 9] Phosphate value=0.0000 attr=+0.004160\n", + " 9. 
idx=[400, 2] Chloride value=0.0000 attr=+0.004160\n", + " 10. idx=[400, 0] Sodium value=0.0000 attr=+0.004160\n" + ] + } + ], + "source": [ + "deeplift = DeepLift(model)\n", + "\n", + "random_baseline = build_random_embedding_baseline(model, sample_batch_device)\n", + "attributions = deeplift.attribute(\n", + " baseline=random_baseline,\n", + " **sample_batch_device,\n", + ")\n", + "print_top_attributions(attributions, sample_batch_device, input_processors, top_k=10)\n" ] } ], diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index cbdcfdd3a..5da5ef0db 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -1,7 +1,6 @@ import logging import os import pickle -import time from abc import ABC from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path @@ -336,6 +335,7 @@ def iter_patients(self, df: Optional[pl.LazyFrame] = None) -> Iterator[Patient]: if df is None: df = self.collected_global_event_df grouped = df.group_by("patient_id") + for patient_id, patient_df in grouped: patient_id = patient_id[0] yield Patient(patient_id=patient_id, data_source=patient_df) @@ -443,58 +443,23 @@ def set_task( ): samples.extend(task(patient)) else: - # multi-threading with lazy iteration and bounded queue - logger.info( - "Generating samples for %s with %d workers", - task.task_name, - num_workers, - ) - - logger.info("Computing total patient count...") - start_time = time.time() - total_patients = filtered_global_event_df["patient_id"].n_unique() - elapsed = time.time() - start_time + # multi-threading (not recommended) logger.info( - "n_unique() completed in %.2f seconds, found %d patients", - elapsed, - total_patients, + f"Generating samples for {task.task_name} with " + f"{num_workers} workers" ) - - patients_iter = self.iter_patients(filtered_global_event_df) - max_in_flight = num_workers * 4 - + patients = list(self.iter_patients(filtered_global_event_df)) with ThreadPoolExecutor(max_workers=num_workers) as executor: - in_flight = {} - - # Prime the in-flight queue - try: - for _ in range(max_in_flight): - patient = next(patients_iter) - fut = executor.submit(task, patient) - in_flight[fut] = None - except StopIteration: - pass - - with tqdm( - total=total_patients, - desc=f"Processing {task.task_name}", - ) as pbar: - while in_flight: - for fut in as_completed(list(in_flight.keys())): - in_flight.pop(fut, None) - result = fut.result() - samples.extend(result) - pbar.update(1) - - try: - next_patient = next(patients_iter) - new_fut = executor.submit(task, next_patient) - in_flight[new_fut] = None - except StopIteration: - pass - - # Re-enter as_completed with updated future set - break + futures = [executor.submit(task, patient) for patient in patients] + for future in tqdm( + as_completed(futures), + total=len(futures), + desc=( + f"Collecting samples for {task.task_name} " + f"from {num_workers} workers" + ), + ): + samples.extend(future.result()) # Cache the samples if cache_dir is provided if cache_dir is not None: From b1e95f58551a2070c6a44cbdd6f923a1206cba78 Mon Sep 17 00:00:00 2001 From: John Wu Date: Sun, 23 Nov 2025 17:01:18 -0600 Subject: [PATCH 6/7] minor update to the number of workers used, turns out it does make a difference in processing speed --- examples/mortality_mimic4_stagenet_v2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/mortality_mimic4_stagenet_v2.py b/examples/mortality_mimic4_stagenet_v2.py index 3fc2722a9..a836cfa22 100644 --- 
a/examples/mortality_mimic4_stagenet_v2.py +++ b/examples/mortality_mimic4_stagenet_v2.py @@ -175,7 +175,7 @@ def generate_holdout_set( print("\n=== Fitting New Processors ===") sample_dataset = base_dataset.set_task( MortalityPredictionStageNetMIMIC4(padding=20), - num_workers=1, + num_workers=4, cache_dir=cache_dir, ) @@ -220,7 +220,7 @@ def generate_holdout_set( # STEP 5: Train the model trainer = Trainer( model=model, - device="cpu", # or "cpu" + device="cuda:2", # or "cpu" metrics=["pr_auc", "roc_auc", "accuracy", "f1"], ) From 0327e6b4b8f7d31d8c7673c320ad605c8fa11e2e Mon Sep 17 00:00:00 2001 From: John Wu Date: Sun, 23 Nov 2025 17:03:15 -0600 Subject: [PATCH 7/7] update again --- examples/mortality_mimic4_stagenet_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/mortality_mimic4_stagenet_v2.py b/examples/mortality_mimic4_stagenet_v2.py index a836cfa22..e093cded4 100644 --- a/examples/mortality_mimic4_stagenet_v2.py +++ b/examples/mortality_mimic4_stagenet_v2.py @@ -166,7 +166,7 @@ def generate_holdout_set( sample_dataset = base_dataset.set_task( MortalityPredictionStageNetMIMIC4(padding=20), - num_workers=1, + num_workers=4, cache_dir=cache_dir, input_processors=input_processors, output_processors=output_processors,
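Note on the hold-out behaviour exercised above: the synthetic patients deliberately use codes the fitted processors have never seen, and the expectation is that such codes fall back to the vocabulary's unknown-token index. A minimal toy illustration of that idea follows; it is not the actual StageNetProcessor, and the vocabulary layout and helper names are made up for the sketch.

```python
# Toy sketch of unseen-code handling against a fitted vocabulary.
# NOT the PyHealth StageNetProcessor -- only the underlying idea that
# codes absent from training map to the unknown-token index.

PAD, UNK = "<pad>", "<unk>"

def fit_vocab(training_codes):
    """Assign indices to codes seen in training; reserve 0/1 for pad/unk."""
    vocab = {PAD: 0, UNK: 1}
    for code in training_codes:
        vocab.setdefault(code, len(vocab))
    return vocab

def encode(codes, vocab):
    """Map codes to indices, sending unseen codes to the <unk> index."""
    return [vocab.get(code, vocab[UNK]) for code in codes]

vocab = fit_vocab(["4280", "4019", "42832"])
print(encode(["4280", "NEWCODE_0_0_0"], vocab))  # [2, 1]: the unseen code falls back to <unk>
```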
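On the worker-count tweaks in the last two commits: the series also experiments with, and then reverts, a bounded-submission variant of the multi-threaded sample generation. For reference, that pattern can be written standalone with nothing but `concurrent.futures`; `process_one` and `items` below are illustrative stand-ins for the task callable and the patient iterator, not PyHealth APIs.

```python
# Standalone sketch of bounded in-flight submission with a thread pool.
# Illustrative only: `process_one` stands in for the task callable and
# `items` for the patient iterator; neither name is a PyHealth API.
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait


def map_bounded(process_one, items, num_workers=4, backlog_factor=4):
    """Apply process_one to each item, keeping at most
    num_workers * backlog_factor futures in flight at once.
    Results are collected in completion order; None items are not supported
    (None is used as the iterator-exhaustion sentinel)."""
    results = []
    it = iter(items)
    max_in_flight = num_workers * backlog_factor

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        in_flight = set()

        def submit_next():
            item = next(it, None)
            if item is not None:
                in_flight.add(executor.submit(process_one, item))
                return True
            return False

        # Prime the queue up to the in-flight cap.
        while len(in_flight) < max_in_flight and submit_next():
            pass

        # Replace each completed future with a fresh submission.
        while in_flight:
            done, in_flight = wait(in_flight, return_when=FIRST_COMPLETED)
            for fut in done:
                results.append(fut.result())
                submit_next()
    return results


if __name__ == "__main__":
    print(map_bounded(lambda x: x * x, range(10), num_workers=2))
```

Compared with submitting every patient up front, this caps the backlog at roughly `num_workers * backlog_factor` queued items, trading a little extra bookkeeping for bounded memory use.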