diff --git a/examples/benchmark_streaming.py b/examples/benchmark_streaming.py new file mode 100644 index 000000000..5175a6a41 --- /dev/null +++ b/examples/benchmark_streaming.py @@ -0,0 +1,326 @@ +"""Benchmark PyHealth Dataset Processing Performance + +This script benchmarks PyHealth dataset processing using the StageNet mortality +prediction task on MIMIC-IV. + +You can benchmark either streaming mode or normal mode: +- Streaming mode: Memory-efficient disk-backed processing (use --stream flag) +- Normal mode: Traditional in-memory processing (default) + +Measures: +- Processing time +- Peak memory usage +- Sample throughput + +Usage: + # Benchmark normal mode + python benchmark_streaming.py + + # Benchmark streaming mode + python benchmark_streaming.py --stream + + # Benchmark streaming mode with batch_size=1000 + python benchmark_streaming.py --stream --batch_size 1000 + + # Dev mode with 5000 patients + python benchmark_streaming.py --stream --dev --dev_max_patients 5000 + + # Use cached processors to skip fitting + python benchmark_streaming.py --stream --use_cached_processors +""" + +import argparse +import time +from pathlib import Path +from typing import Dict, Any + +import psutil +from torch.utils.data import DataLoader + +from pyhealth.datasets import ( + MIMIC4Dataset, + load_processors, + save_processors, +) +from pyhealth.tasks import MortalityPredictionStageNetMIMIC4 + + +def format_bytes(bytes_val: int) -> str: + """Format bytes to human-readable string.""" + for unit in ["B", "KB", "MB", "GB"]: + if bytes_val < 1024.0: + return f"{bytes_val:.2f} {unit}" + bytes_val /= 1024.0 + return f"{bytes_val:.2f} TB" + + +def benchmark_mode( + mode_name: str, + ehr_root: str, + cache_dir: str, + stream: bool, + num_workers: int = 0, + batch_size: int = 100, + use_cached_processors: bool = False, + processor_dir: str = None, + dev: bool = False, + dev_max_patients: int = 1000, +) -> Dict[str, Any]: + """Benchmark a single mode (streaming or normal). 
+ + Args: + mode_name: Name for display ("Streaming" or "Normal") + ehr_root: Root directory of MIMIC-IV dataset + cache_dir: Directory for cache + stream: Whether to use streaming mode + num_workers: Number of worker processes for parallel processing + batch_size: Number of patients to process per batch (streaming only) + use_cached_processors: Whether to load pre-fitted processors + processor_dir: Directory containing cached processors + dev: Whether to enable dev mode (limit patients) + dev_max_patients: Maximum number of patients in dev mode + + Returns: + Dictionary containing benchmark metrics + """ + print(f"\n{'='*70}") + print( + f"{mode_name} Mode " + f"({num_workers} workers, " + f"batch_size={batch_size if stream else 'N/A'}, " + f"cached_procs={use_cached_processors})" + ) + print(f"{'='*70}\n") + + ehr_tables = [ + "patients", + "admissions", + "diagnoses_icd", + "procedures_icd", + "labevents", + ] + + task = MortalityPredictionStageNetMIMIC4() + + # Get process for memory tracking + process = psutil.Process() + + # Start tracking + baseline_memory = process.memory_info().rss + total_start = time.time() + + # Phase 1: Initialize dataset + print(f"[1/2] Initializing MIMIC4Dataset (stream={stream})...") + init_start = time.time() + init_memory_before = process.memory_info().rss + + dataset = MIMIC4Dataset( + ehr_root=ehr_root, + ehr_tables=ehr_tables, + stream=stream, + cache_dir=cache_dir, + dev=dev, + dev_max_patients=dev_max_patients, + ) + + init_time = time.time() - init_start + init_memory_after = process.memory_info().rss + init_memory_delta = init_memory_after - init_memory_before + print(f" Time: {init_time:.2f}s") + print(f" Memory Delta: {format_bytes(init_memory_delta)}") + + # Phase 2: Apply task + print(f"[2/2] Generating samples with {task.task_name}...") + task_start = time.time() + task_memory_before = process.memory_info().rss + + # Load processors if requested + input_processors = None + output_processors = None + + if use_cached_processors and processor_dir: + processor_dir_path = Path(processor_dir) + input_procs_file = processor_dir_path / "input_processors.pkl" + output_procs_file = processor_dir_path / "output_processors.pkl" + + if input_procs_file.exists() and output_procs_file.exists(): + print(f" Loading processors from {processor_dir}...") + load_start = time.time() + input_processors, output_processors = load_processors(processor_dir) + load_time = time.time() - load_start + print(f" Processors loaded in {load_time:.2f}s") + print(f" ✓ Skipping processor fitting!") + else: + print(f" WARNING: Processor files not found in {processor_dir}") + print(f" Will create new processors") + + sample_dataset = dataset.set_task( + task, + cache_dir=cache_dir, + num_workers=num_workers, + batch_size=batch_size, + input_processors=input_processors, + output_processors=output_processors, + ) + + # Save processors if they were newly created + if use_cached_processors and processor_dir: + if input_processors is None and output_processors is None: + print(f" Saving processors to {processor_dir}...") + save_start = time.time() + save_processors(sample_dataset, processor_dir) + save_time = time.time() - save_start + print(f" Processors saved in {save_time:.2f}s") + + task_time = time.time() - task_start + task_memory_after = process.memory_info().rss + task_memory_delta = task_memory_after - task_memory_before + print(f" Time: {task_time:.2f}s") + print(f" Memory Delta: {format_bytes(task_memory_delta)}") + + # Get final stats + total_time = time.time() - total_start 
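+    # Note: memory_info().rss is the process's *current* resident set size, not
+    # a true peak, so the "Peak Memory" figure in the final summary reflects
+    # end-of-run RSS. Sampling RSS periodically (or using tracemalloc for
+    # Python-level allocations) would be needed to capture an actual peak.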
+ final_memory = process.memory_info().rss + total_memory_delta = final_memory - baseline_memory + + num_samples = len(sample_dataset) + + # Print summary + print(f"\n{'='*70}") + print(f"{mode_name} Mode Results") + print(f"{'='*70}") + print(f"Total Time: {total_time:.2f}s") + print(f" - Init: {init_time:.2f}s") + print(f" - Task/Samples: {task_time:.2f}s") + print(f"Total Memory: {format_bytes(final_memory)}") + print(f"Memory Delta: {format_bytes(total_memory_delta)}") + print(f" - Init Delta: {format_bytes(init_memory_delta)}") + print(f" - Task Delta: {format_bytes(task_memory_delta)}") + print(f"Samples: {num_samples}") + if total_time > 0: + throughput = num_samples / total_time + print(f"Throughput: {throughput:.2f} samples/sec") + print(f"{'='*70}\n") + + return { + "mode": mode_name, + "total_time": total_time, + "init_time": init_time, + "task_time": task_time, + "total_memory": final_memory, + "memory_delta": total_memory_delta, + "init_memory_delta": init_memory_delta, + "task_memory_delta": task_memory_delta, + "num_samples": num_samples, + "throughput": num_samples / total_time if total_time > 0 else 0, + } + + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark PyHealth dataset processing" + ) + parser.add_argument( + "--ehr_root", + type=str, + default="/srv/local/data/physionet.org/files/mimiciv/2.2/", + help="Root directory of MIMIC-IV dataset", + ) + parser.add_argument( + "--cache_dir", + type=str, + default="../benchmark_cache", + help="Directory for cache files", + ) + parser.add_argument( + "--stream", + action="store_true", + help="Use streaming mode (default: normal mode)", + ) + parser.add_argument( + "--num_workers", + type=int, + default=4, + help="Number of worker processes for parallel processing", + ) + parser.add_argument( + "--batch_size", + type=int, + default=100, + help="Number of patients to process per batch in streaming mode", + ) + parser.add_argument( + "--use_cached_processors", + action="store_true", + help="Load pre-fitted processors instead of fitting from scratch", + ) + parser.add_argument( + "--processor_dir", + type=str, + default="../output/processors/stagenet_mortality_mimic4_benchmark_10k", + help="Directory to load/save cached processors", + ) + parser.add_argument( + "--dev", + action="store_true", + default=False, + help="Enable dev mode to limit number of patients (default: False)", + ) + parser.add_argument( + "--dev_max_patients", + type=int, + default=1000, + help="Maximum number of patients in dev mode (default: 1000)", + ) + + args = parser.parse_args() + + mode_name = "STREAMING" if args.stream else "NORMAL" + + print("\n" + "=" * 70) + print(f"PyHealth {mode_name} Mode Benchmark") + print("=" * 70) + print( + f"Dataset: MIMIC-IV (Dev Mode: {args.dev}, Max Patients: {args.dev_max_patients})" + ) + print("Task: StageNet Mortality Prediction") + print(f"Mode: {mode_name}") + print(f"EHR Root: {args.ehr_root}") + print(f"Cache Dir: {args.cache_dir}") + print(f"Workers: {args.num_workers}") + if args.stream: + print(f"Batch Size: {args.batch_size}") + print(f"Use Cached Processors: {args.use_cached_processors}") + if args.use_cached_processors: + print(f"Processor Dir: {args.processor_dir}") + print("=" * 70) + + # Benchmark the selected mode + results = benchmark_mode( + mode_name=mode_name, + ehr_root=args.ehr_root, + cache_dir=args.cache_dir, + stream=args.stream, + num_workers=args.num_workers, + batch_size=args.batch_size, + use_cached_processors=args.use_cached_processors, + 
processor_dir=args.processor_dir, + dev=args.dev, + dev_max_patients=args.dev_max_patients, + ) + + # Print final summary + print("\n" + "=" * 70) + print("FINAL SUMMARY") + print("=" * 70) + print(f"Mode: {mode_name}") + print(f"Total Time: {results['total_time']:.2f}s") + print(f"Peak Memory: {format_bytes(results['total_memory'])}") + print(f"Memory Delta: {format_bytes(results['memory_delta'])}") + print(f"Samples Generated: {results['num_samples']}") + print(f"Throughput: {results['throughput']:.2f} samples/sec") + print("=" * 70 + "\n") + + +if __name__ == "__main__": + main() diff --git a/examples/mortality_mimic4_stagenet_streaming.py b/examples/mortality_mimic4_stagenet_streaming.py new file mode 100644 index 000000000..106ea8cc3 --- /dev/null +++ b/examples/mortality_mimic4_stagenet_streaming.py @@ -0,0 +1,252 @@ +""" +Example of using StageNet for mortality prediction on MIMIC-IV with +STREAMING MODE. + +This example demonstrates the new streaming mode for memory-efficient +training: +1. Loading MIMIC-IV data in streaming mode (stream=True) +2. Applying the MortalityPredictionStageNetMIMIC4 task +3. Creating an IterableSampleDataset with disk-backed storage +4. Training a StageNet model with streaming data + +Key differences from non-streaming mode: +- stream=True: Data loaded from disk on-demand +- IterableSampleDataset: Samples not stored in memory +- No random shuffling (sequential iteration only) +- Much lower memory footprint +- Ideal for large datasets (>100k samples) + +Note: IterableDataset is fully compatible with PyTorch DataLoader +and Trainer! +""" + +from pyhealth.datasets import ( + MIMIC4Dataset, + get_dataloader, + split_by_patient_stream, +) +from pyhealth.models import StageNet +from pyhealth.tasks import MortalityPredictionStageNetMIMIC4 +from pyhealth.trainer import Trainer +import torch +import psutil +import os + + +def get_memory_usage(): + """Get current memory usage in MB.""" + process = psutil.Process(os.getpid()) + return process.memory_info().rss / 1024 / 1024 + + +def print_memory_stats(stage_name): + """Print current memory usage statistics.""" + mem_mb = get_memory_usage() + print(f"[Memory] {stage_name}: {mem_mb:.1f} MB") + return mem_mb + + +# Track memory at start +initial_memory = print_memory_stats("Initial") + +# STEP 1: Load MIMIC-IV base dataset in STREAMING MODE +print("=" * 70) +print("STREAMING MODE: Memory-Efficient Dataset Processing") +print("=" * 70) + +base_dataset = MIMIC4Dataset( + ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/", + ehr_tables=[ + "patients", + "admissions", + "diagnoses_icd", + "procedures_icd", + "labevents", + ], + stream=True, # ⭐ Enable streaming mode for memory efficiency + cache_dir="../mimic4_streaming_cache", # Disk-backed cache + # dev=True, # Set to True for quick testing with limited patients + # dev_max_patients=10000, # Only used if dev=True +) + +print("Dataset mode: STREAMING (disk-backed)") +print(f"Cache directory: {base_dataset.cache_dir}") +print_memory_stats("After loading base dataset") + +# STEP 2: Apply StageNet mortality prediction task in streaming mode +print("\n" + "=" * 70) +print("Applying Task with Streaming Sample Generation") +print("=" * 70) + +sample_dataset = base_dataset.set_task( + MortalityPredictionStageNetMIMIC4(), + batch_size=1000, # ⭐ Process patients in batches for I/O efficiency + cache_dir="../mimic4_stagenet_streaming_cache", +) + +print(f"Dataset type: {type(sample_dataset).__name__}") +print(f"Total samples: {len(sample_dataset)}") +print(f"Input schema: 
{sample_dataset.input_schema}") +print(f"Output schema: {sample_dataset.output_schema}") +print_memory_stats("After applying task (samples on disk)") + +# Inspect a sample (IterableDataset requires iteration) +print("\nSample structure (from first sample):") +for i, sample in enumerate(sample_dataset): + print(f" Patient ID: {sample['patient_id']}") + print(f" ICD Codes: {sample['icd_codes']}") + print(f" Labs shape: {len(sample['labs'][0])} timesteps") + print(f" Mortality: {sample['mortality']}") + if i == 0: # Only show first sample + break + +# STEP 3: Create train/val/test splits using patient ID filtering +print("\n" + "=" * 70) +print("Creating Train/Val/Test Splits with Filtering") +print("=" * 70) + +# ⭐ Use get_patient_ids() to get patients with samples! +# This method reads from the cache and returns only patients that have +# valid samples after task processing (e.g., excluding patients without +# mortality outcomes). +patients_with_samples = sample_dataset.get_patient_ids() + +# Optional: Show how many patients were excluded by task processing +base_patient_count = len(base_dataset.patient_ids) +sample_patient_count = len(patients_with_samples) +if sample_patient_count < base_patient_count: + excluded_count = base_patient_count - sample_patient_count + print( + f"Note: {excluded_count} patients excluded by task processing " + f"(no valid outcomes)" + ) + +# Split patient IDs into train/val/test sets +train_patient_ids, val_patient_ids, test_patient_ids = split_by_patient_stream( + patients_with_samples, # ← Use the property! + ratios=[0.8, 0.1, 0.1], # 80% train, 10% val, 10% test + seed=42, +) + +print(f"Total patients with samples: {len(patients_with_samples)}") +print(f"Train patients: {len(train_patient_ids)}") +print(f"Val patients: {len(val_patient_ids)}") +print(f"Test patients: {len(test_patient_ids)}") + +# Create filtered views of the dataset using predicate pushdown +# ⭐ Physical splits: Each split gets its own parquet file for memory efficiency! 
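+# Under the hood (see IterableSampleDataset._create_physical_split in this PR),
+# each call below scans the main sample parquet, filters rows to the given
+# patient_ids, and writes the result to a per-split parquet via sink_parquet(),
+# so every split can be streamed independently of the full cache.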
+train_dataset = sample_dataset.filter_by_patients(train_patient_ids, split_name="train") +val_dataset = sample_dataset.filter_by_patients(val_patient_ids, split_name="val") +test_dataset = sample_dataset.filter_by_patients(test_patient_ids, split_name="test") + +print("\nPhysical splits created:") +print("✓ Each split has its own parquet file") +print("✓ No in-memory filtering needed during training") +print("✓ Maximum memory efficiency for large datasets") +print_memory_stats("After creating filtered datasets") + +# Create dataloaders +# ⭐ Note: shuffle=False for IterableDataset (no random access) +train_loader = get_dataloader(train_dataset, batch_size=256, shuffle=False) +val_loader = get_dataloader(val_dataset, batch_size=256, shuffle=False) +test_loader = get_dataloader(test_dataset, batch_size=256, shuffle=False) + +print("\nDataLoaders created (batch_size=256)") +print("Note: IterableDataset uses sequential iteration (no shuffling)") +print_memory_stats("After creating dataloaders") + +# STEP 4: Initialize StageNet model +print("\n" + "=" * 70) +print("Initializing StageNet Model") +print("=" * 70) + +model = StageNet( + dataset=sample_dataset, + embedding_dim=128, + chunk_size=128, + levels=3, + dropout=0.3, +) + +num_params = sum(p.numel() for p in model.parameters()) +print(f"Model initialized with {num_params:,} parameters") +print_memory_stats("After model initialization") + +# STEP 5: Train the model with streaming data +print("\n" + "=" * 70) +print("Training with Streaming Data") +print("=" * 70) + +trainer = Trainer( + model=model, + device="cuda:5", # or "cpu" + metrics=["pr_auc", "roc_auc", "accuracy", "f1"], +) + +# ⭐ Training works the same way with IterableDataset! +# PyTorch's DataLoader handles IterableDataset automatically +trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, # Now we have validation! + epochs=2, # Fewer epochs for demo + monitor="roc_auc", + optimizer_params={"lr": 1e-5}, +) + +print("\n" + "=" * 70) +print("Training Complete!") +print("=" * 70) +final_memory = print_memory_stats("After training") +print(f"Total memory increase: {final_memory - initial_memory:.1f} MB") + +# STEP 6: Evaluate on test set +print("\n" + "=" * 70) +print("Evaluating on Test Set") +print("=" * 70) + +test_results = trainer.evaluate(test_loader) +print(f"Test Results: {test_results}") + +print("\n" + "=" * 70) +print("Streaming Mode Workflow Complete!") +print("=" * 70) +print("Workflow summary:") +print("1. Split patient IDs into train/val/test sets") +print("2. Create filtered views using filter_by_patients()") +print("3. All splits share the same cache (no regeneration)") +print("4. Polars predicate pushdown for efficient filtering") +print("5. 
Train with validation and test evaluation") + +# STEP 7: Show memory benefits +print("\n" + "=" * 70) +print("Memory Benefits of Streaming Mode") +print("=" * 70) +print("✓ Samples stored on disk (not in RAM)") +print("✓ Only active batch loaded in memory") +print("✓ Memory usage independent of dataset size") +print("✓ Ideal for datasets >100k samples") +print("✓ Enables training on massive datasets with limited RAM") +print("\nMemory Usage Summary:") +print(f" Initial: {initial_memory:.1f} MB") +print(f" Final: {final_memory:.1f} MB") +print(f" Increase: {final_memory - initial_memory:.1f} MB") +print( + f" Note: Most memory used by model parameters " + f"({num_params * 4 / 1024 / 1024:.1f} MB)" +) + +# STEP 8: Inspect model predictions +print("\n" + "=" * 70) +print("Sample Predictions") +print("=" * 70) + +sample_batch = next(iter(train_loader)) +with torch.no_grad(): + output = model(**sample_batch) + +print(f"Predicted probabilities (first 5): {output['y_prob'][:5]}") +print(f"True labels (first 5): {output['y_true'][:5]}") + +print("\n" + "=" * 70) +print("Streaming Mode Demo Complete!") +print("=" * 70) diff --git a/pyhealth/data/data.py b/pyhealth/data/data.py index 2a6d3a45c..5f77ad2e7 100644 --- a/pyhealth/data/data.py +++ b/pyhealth/data/data.py @@ -123,19 +123,44 @@ class Patient: event_type_partitions (Dict[str, pl.DataFrame]): Dictionary mapping event types to their respective DataFrame partitions. """ - def __init__(self, patient_id: str, data_source: pl.DataFrame) -> None: + def __init__( + self, patient_id: str, data_source: pl.DataFrame, lazy_partition: bool = False + ) -> None: """ Initialize a Patient instance. Args: patient_id (str): Unique patient identifier. data_source (pl.DataFrame): DataFrame containing all events. + lazy_partition (bool): If True, delay partitioning until needed (memory optimization for streaming). + Default is False to maintain backward compatibility. """ self.patient_id = patient_id self.data_source = data_source.sort("timestamp") - self.event_type_partitions = self.data_source.partition_by("event_type", maintain_order=True, as_dict=True) + self._lazy_partition = lazy_partition - def _filter_by_time_range_regular(self, df: pl.DataFrame, start: Optional[datetime], end: Optional[datetime]) -> pl.DataFrame: + if lazy_partition: + # Streaming mode: delay partition_by to save memory + self._event_type_partitions = None + else: + # Normal mode: pre-compute partitions for fast access (original behavior) + self._event_type_partitions = self.data_source.partition_by( + "event_type", maintain_order=True, as_dict=True + ) + + @property + def event_type_partitions(self) -> Dict[tuple, pl.DataFrame]: + """Get event type partitions, computing lazily if needed.""" + if self._event_type_partitions is None: + # Lazy computation for streaming mode + self._event_type_partitions = self.data_source.partition_by( + "event_type", maintain_order=True, as_dict=True + ) + return self._event_type_partitions + + def _filter_by_time_range_regular( + self, df: pl.DataFrame, start: Optional[datetime], end: Optional[datetime] + ) -> pl.DataFrame: """Regular filtering by time. 
Time complexity: O(n).""" if start is not None: df = df.filter(pl.col("timestamp") >= start) @@ -143,7 +168,9 @@ def _filter_by_time_range_regular(self, df: pl.DataFrame, start: Optional[dateti df = df.filter(pl.col("timestamp") <= end) return df - def _filter_by_time_range_fast(self, df: pl.DataFrame, start: Optional[datetime], end: Optional[datetime]) -> pl.DataFrame: + def _filter_by_time_range_fast( + self, df: pl.DataFrame, start: Optional[datetime], end: Optional[datetime] + ) -> pl.DataFrame: """Fast filtering by time using binary search on sorted timestamps. Time complexity: O(log n).""" if start is None and end is None: return df @@ -157,13 +184,17 @@ def _filter_by_time_range_fast(self, df: pl.DataFrame, start: Optional[datetime] end_idx = np.searchsorted(ts_col, end, side="right") return df.slice(start_idx, end_idx - start_idx) - def _filter_by_event_type_regular(self, df: pl.DataFrame, event_type: Optional[str]) -> pl.DataFrame: + def _filter_by_event_type_regular( + self, df: pl.DataFrame, event_type: Optional[str] + ) -> pl.DataFrame: """Regular filtering by event type. Time complexity: O(n).""" if event_type: df = df.filter(pl.col("event_type") == event_type) return df - def _filter_by_event_type_fast(self, df: pl.DataFrame, event_type: Optional[str]) -> pl.DataFrame: + def _filter_by_event_type_fast( + self, df: pl.DataFrame, event_type: Optional[str] + ) -> pl.DataFrame: """Fast filtering by event type using pre-built event type index. Time complexity: O(1).""" if event_type: return self.event_type_partitions.get((event_type,), df[:0]) @@ -203,7 +234,9 @@ def get_events( # df = self._filter_by_time_range_regular(df, start, end) if filters: - assert event_type is not None, "event_type must be provided if filters are provided" + assert ( + event_type is not None + ), "event_type must be provided if filters are provided" else: filters = [] exprs = [] diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 91a8da937..0f0db9bfb 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -60,12 +60,14 @@ def __init__(self, *args, **kwargs): from .mimicextract import MIMICExtractDataset from .omop import OMOPDataset from .sample_dataset import SampleDataset +from .iterable_sample_dataset import IterableSampleDataset from .shhs import SHHSDataset from .sleepedf import SleepEDFDataset from .bmd_hs import BMDHSDataset from .splitter import ( split_by_patient, split_by_patient_conformal, + split_by_patient_stream, split_by_sample, split_by_sample_conformal, split_by_visit, diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 43948c828..2e52a1856 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -1,22 +1,25 @@ import logging import os -import pickle from abc import ABC -from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path -from typing import Dict, Iterator, List, Optional +from typing import Dict, Iterator, List, Optional, Union from urllib.parse import urlparse, urlunparse import polars as pl import requests -from tqdm import tqdm from ..data import Patient from ..tasks import BaseTask from ..processors.base_processor import FeatureProcessor from .configs import load_yaml_config -from .sample_dataset import SampleDataset -from .utils import _convert_for_cache, _restore_from_cache +from .processing import ( + build_patient_cache, + iter_patients_streaming, + set_task_normal, + set_task_streaming, + setup_streaming_cache, + 
_create_patients_from_dataframe, +) logger = logging.getLogger(__name__) @@ -103,7 +106,8 @@ class BaseDataset(ABC): dataset_name (str): Name of the dataset. config (dict): Configuration loaded from a YAML file. global_event_df (pl.LazyFrame): The global event data frame. - dev (bool): Whether to enable dev mode (limit to 1000 patients). + dev (bool): Whether to enable dev mode (limit patients). + dev_max_patients (int): Max patients in dev mode (default 1000). """ def __init__( @@ -113,6 +117,9 @@ def __init__( dataset_name: Optional[str] = None, config_path: Optional[str] = None, dev: bool = False, + dev_max_patients: int = 1000, + stream: bool = False, + cache_dir: Optional[str] = None, ): """Initializes the BaseDataset. @@ -121,7 +128,14 @@ def __init__( tables (List[str]): List of table names to load. dataset_name (Optional[str]): Name of the dataset. Defaults to class name. config_path (Optional[str]): Path to the configuration YAML file. - dev (bool): Whether to run in dev mode (limits to 1000 patients). + dev (bool): Whether to run in dev mode (limits number of patients). + dev_max_patients (int): Maximum number of patients in dev mode. + Default is 1000. Only used when dev=True. + stream (bool): Whether to enable streaming mode for memory efficiency. + When True, data is loaded from disk on-demand rather than kept in memory. + Default is False for backward compatibility. + cache_dir (Optional[str]): Directory for streaming cache. If None, uses + {root}/.pyhealth_cache. Only used when stream=True. """ if len(set(tables)) != len(tables): logger.warning("Duplicate table names in tables list. Removing duplicates.") @@ -131,6 +145,18 @@ def __init__( self.dataset_name = dataset_name or self.__class__.__name__ self.config = load_yaml_config(config_path) self.dev = dev + self.dev_max_patients = dev_max_patients + self.stream = stream + + # Setup cache directory + if cache_dir is None: + self.cache_dir = Path(root) / ".pyhealth_cache" + else: + self.cache_dir = Path(cache_dir) + + if self.stream: + logger.info(f"Stream mode enabled - using disk cache at {self.cache_dir}") + setup_streaming_cache(self) logger.info( f"Initializing {self.dataset_name} dataset from {self.root} (dev mode: {self.dev})" @@ -142,13 +168,57 @@ def __init__( self._collected_global_event_df = None self._unique_patient_ids = None + # Streaming-specific attributes + # Initialize to None, will be set by setup_streaming_cache if needed + self._patient_cache_path = None + self._patient_index_path = None + self._patient_index = None + + def _setup_streaming_cache(self) -> None: + """Setup disk-backed cache directory structure for streaming mode. + + Delegates to processing.streaming_mode.setup_streaming_cache(). + Kept as a method for backward compatibility. + """ + setup_streaming_cache(self) + + def _build_patient_cache( + self, + filtered_df: Optional[pl.LazyFrame] = None, + force_rebuild: bool = False, + ) -> None: + """Build disk-backed patient cache with efficient indexing. + + Delegates to processing.streaming_mode.build_patient_cache(). + Kept as a method for backward compatibility. + + Args: + filtered_df: Pre-filtered LazyFrame (e.g., from task.pre_filter()). + If None, uses self.global_event_df. + force_rebuild: If True, rebuild cache even if it exists. + """ + build_patient_cache(self, filtered_df, force_rebuild) + @property def collected_global_event_df(self) -> pl.DataFrame: """Collects and returns the global event data frame. 
+ WARNING: This property is NOT available in stream mode as it would + load the entire dataset into memory, defeating the purpose of streaming. + Returns: pl.DataFrame: The collected global event data frame. + + Raises: + RuntimeError: If called in stream mode. """ + if self.stream: + raise RuntimeError( + "collected_global_event_df is not available in stream mode " + "as it would load the entire dataset into memory. " + "Use iter_patients() or get_patient() for memory-efficient access." + ) + if self._collected_global_event_df is None: logger.info("Collecting global event dataframe...") @@ -157,8 +227,14 @@ def collected_global_event_df(self) -> pl.DataFrame: # TODO: dev doesn't seem to improve the speed / memory usage if self.dev: # Limit the number of patients in dev mode - logger.info("Dev mode enabled: limiting to 1000 patients") - limited_patients = df.select(pl.col("patient_id")).unique().limit(1000) + logger.info( + f"Dev mode enabled: limiting to {self.dev_max_patients} patients" + ) + limited_patients = ( + df.select(pl.col("patient_id")) + .unique() + .limit(self.dev_max_patients) + ) df = df.join(limited_patients, on="patient_id", how="inner") self._collected_global_event_df = df.collect() @@ -290,26 +366,73 @@ def load_table(self, table_name: str) -> pl.LazyFrame: return event_frame + @property + def patient_ids(self) -> List[str]: + """Returns a list of patient IDs, limited in dev mode. + + This is the primary property for getting patient IDs. In dev mode, + it automatically limits to dev_max_patients for faster testing. + + Returns: + List[str]: List of patient IDs (limited to dev_max_patients if dev=True) + """ + full_patient_ids = self.unique_patient_ids + + # Limit patients in dev mode + if self.dev and len(full_patient_ids) > self.dev_max_patients: + logger.info( + f"Dev mode: limiting from {len(full_patient_ids)} " + f"to {self.dev_max_patients} patients" + ) + return full_patient_ids[: self.dev_max_patients] + + return full_patient_ids + @property def unique_patient_ids(self) -> List[str]: - """Returns a list of unique patient IDs. + """Returns the full list of unique patient IDs (ignores dev mode). + + This property always returns ALL patient IDs regardless of dev mode. + For dev-mode-aware access, use the `patient_ids` property instead. Returns: - List[str]: List of unique patient IDs. + List[str]: Complete list of unique patient IDs """ if self._unique_patient_ids is None: - self._unique_patient_ids = ( - self.collected_global_event_df.select("patient_id") - .unique() - .to_series() - .to_list() - ) + if self.stream: + # Streaming mode: Get patient IDs from patient index + # Ensure paths are set up (should be done in __init__) + if self._patient_cache_path is None: + setup_streaming_cache(self) + + # Build cache only if either file is missing + if not ( + self._patient_cache_path.exists() + and self._patient_index_path.exists() + ): + self._build_patient_cache() + + patient_index = pl.scan_parquet(self._patient_index_path).collect( + streaming=True + ) + self._unique_patient_ids = patient_index["patient_id"].to_list() + else: + # Normal mode: Get from collected DataFrame + self._unique_patient_ids = ( + self.collected_global_event_df.select("patient_id") + .unique() + .to_series() + .to_list() + ) logger.info(f"Found {len(self._unique_patient_ids)} unique patient IDs") return self._unique_patient_ids def get_patient(self, patient_id: str) -> Patient: """Retrieves a Patient object for the given patient ID. + In streaming mode, loads the patient from disk cache. 
+ In normal mode, filters from the collected DataFrame. + Args: patient_id (str): The ID of the patient to retrieve. @@ -322,22 +445,122 @@ def get_patient(self, patient_id: str) -> Patient: assert ( patient_id in self.unique_patient_ids ), f"Patient {patient_id} not found in dataset" - df = self.collected_global_event_df.filter(pl.col("patient_id") == patient_id) - return Patient(patient_id=patient_id, data_source=df) - def iter_patients(self, df: Optional[pl.LazyFrame] = None) -> Iterator[Patient]: + if self.stream: + # Streaming mode: Load patient from disk cache + if not self._patient_cache_path.exists(): + self._build_patient_cache() + + patient_df = ( + pl.scan_parquet(self._patient_cache_path) + .filter(pl.col("patient_id") == patient_id) + .collect(streaming=True) + ) + return Patient(patient_id=patient_id, data_source=patient_df) + else: + # Normal mode: Filter from collected DataFrame + df = self.collected_global_event_df.filter( + pl.col("patient_id") == patient_id + ) + return Patient(patient_id=patient_id, data_source=df) + + def iter_patients( + self, + df: Optional[pl.DataFrame] = None, + patient_ids: Optional[List[str]] = None, + batch_size: Optional[int] = None, + ) -> Iterator[Union[Patient, List[Patient]]]: """Yields Patient objects for each unique patient in the dataset. + Automatically uses streaming iteration when stream=True and df is None. + In normal mode, loads data into memory and iterates. + + Args: + df (Optional[pl.DataFrame]): Optional pre-filtered DataFrame. + If None, behavior depends on stream mode: + - stream=False: Uses collected_global_event_df (loads to memory) + - stream=True: Uses disk-backed streaming iteration + patient_ids (Optional[List[str]]): Optional list of specific patient IDs + to iterate over. Only used in streaming mode when df is None. + batch_size (Optional[int]): If specified, yields batches of Patient objects + instead of individual patients. This is much more efficient for I/O as it + loads multiple patients in a single query. Recommended for streaming mode. + If None (default), yields individual Patient objects. + Yields: - Iterator[Patient]: An iterator over Patient objects. + Union[Patient, List[Patient]]: Individual Patient objects (if batch_size=None) + or batches of Patient objects (if batch_size is specified). + + Example: + >>> # Individual patients (default) + >>> for patient in dataset.iter_patients(): + ... print(patient.patient_id) + + >>> # Batches of 100 patients (much faster for streaming) + >>> for patient_batch in dataset.iter_patients(batch_size=100): + ... 
print(f"Processing {len(patient_batch)} patients") """ + # Determine effective batch size (1 for single-patient mode) + effective_batch_size = batch_size if batch_size is not None else 1 + + # Batch mode - yield batches of patients (or single patients if batch_size=1) if df is None: - df = self.collected_global_event_df - grouped = df.group_by("patient_id") + if self.stream: + # Streaming mode: Delegate to processing module + # This handles both single (batch_size=1) and batch modes + yield from iter_patients_streaming( + self, patient_ids, effective_batch_size + ) + else: + # Normal mode: Batch from in-memory DataFrame + df = self.collected_global_event_df + + # Filter to specific patients if requested + if patient_ids is not None: + df = df.filter(pl.col("patient_id").is_in(patient_ids)) + + all_patient_ids = df.select("patient_id").unique().to_series().to_list() + + # Process in batches + for i in range(0, len(all_patient_ids), effective_batch_size): + batch_patient_ids = all_patient_ids[i : i + effective_batch_size] + + batch_df = df.filter(pl.col("patient_id").is_in(batch_patient_ids)) - for patient_id, patient_df in grouped: - patient_id = patient_id[0] - yield Patient(patient_id=patient_id, data_source=patient_df) + # Convert DataFrame to Patient objects + batch_patients = _create_patients_from_dataframe(batch_df) + + # Yield single patient or batch depending on batch_size + if batch_size is None: + # Single patient mode + if batch_patients: + yield batch_patients[0] + else: + # Batch mode + yield batch_patients + else: + # DataFrame provided: Batch iteration + if patient_ids is not None: + df = df.filter(pl.col("patient_id").is_in(patient_ids)) + + all_patient_ids = df.select("patient_id").unique().to_series().to_list() + + for i in range(0, len(all_patient_ids), effective_batch_size): + batch_patient_ids = all_patient_ids[i : i + effective_batch_size] + + batch_df = df.filter(pl.col("patient_id").is_in(batch_patient_ids)) + + # Convert DataFrame to Patient objects + batch_patients = _create_patients_from_dataframe(batch_df) + + # Yield single patient or batch depending on batch_size + if batch_size is None: + # Single patient mode + if batch_patients: + yield batch_patients[0] + else: + # Batch mode + yield batch_patients def stats(self) -> None: """Prints statistics about the dataset.""" @@ -364,30 +587,36 @@ def set_task( cache_format: str = "parquet", input_processors: Optional[Dict[str, FeatureProcessor]] = None, output_processors: Optional[Dict[str, FeatureProcessor]] = None, - ) -> SampleDataset: + batch_size: Optional[int] = None, + ): """Processes the base dataset to generate the task-specific sample dataset. Args: task (Optional[BaseTask]): The task to set. Uses default task if None. num_workers (int): Number of workers for multi-threading. Default is 1. - This is because the task function is usually CPU-bound. And using - multi-threading may not speed up the task function. + Only used in non-streaming mode. cache_dir (Optional[str]): Directory to cache processed samples. Default is None (no caching). cache_format (str): Format for caching ('parquet' or 'pickle'). Default is 'parquet'. input_processors (Optional[Dict[str, FeatureProcessor]]): - Pre-fitted input processors. If provided, these will be used - instead of creating new ones from task's input_schema. Defaults to None. + Pre-fitted input processors. output_processors (Optional[Dict[str, FeatureProcessor]]): - Pre-fitted output processors. 
If provided, these will be used - instead of creating new ones from task's output_schema. Defaults to None. + Pre-fitted output processors. + batch_size (Optional[int]): Number of patients to process per batch. + Required in streaming mode. Larger batches = better I/O efficiency + but more memory. Typical values: 50-500. Default is None. + - Streaming mode: Required, raises error if not provided. + - Normal mode: Ignored (not used). Returns: - SampleDataset: The generated sample dataset. + Union[SampleDataset, IterableSampleDataset]: The generated sample dataset. + Returns SampleDataset in non-streaming mode. + Returns IterableSampleDataset in streaming mode. Raises: AssertionError: If no default task is found and task is None. + ValueError: If streaming mode is enabled but batch_size is not provided. """ if task is None: assert self.default_task is not None, "No default tasks found" @@ -397,102 +626,31 @@ def set_task( f"Setting task {task.task_name} for {self.dataset_name} base dataset..." ) - # Check for cached data if cache_dir is provided - samples = None - if cache_dir is not None: - cache_filename = f"{task.task_name}.{cache_format}" - cache_path = Path(cache_dir) / cache_filename - if cache_path.exists(): - logger.info(f"Loading cached samples from {cache_path}") - try: - if cache_format == "parquet": - # Load samples from parquet file - cached_df = pl.read_parquet(cache_path) - samples = [ - _restore_from_cache(row) for row in cached_df.to_dicts() - ] - elif cache_format == "pickle": - # Load samples from pickle file - with open(cache_path, "rb") as f: - samples = pickle.load(f) - else: - msg = f"Unsupported cache format: {cache_format}" - raise ValueError(msg) - logger.info(f"Loaded {len(samples)} cached samples") - except Exception as e: - logger.warning( - "Failed to load cached data: %s. Regenerating...", - e, - ) - samples = None - - # Generate samples if not loaded from cache - if samples is None: - logger.info(f"Generating samples with {num_workers} worker(s)...") - filtered_global_event_df = task.pre_filter(self.collected_global_event_df) - samples = [] - - if num_workers == 1: - # single-threading (by default) - for patient in tqdm( - self.iter_patients(filtered_global_event_df), - total=filtered_global_event_df["patient_id"].n_unique(), - desc=(f"Generating samples for {task.task_name} " "with 1 worker"), - smoothing=0, - ): - samples.extend(task(patient)) - else: - # multi-threading (not recommended) - logger.info( - f"Generating samples for {task.task_name} with " - f"{num_workers} workers" + # Delegate to appropriate processing mode + if self.stream: + # Streaming mode: memory-efficient disk-backed processing + if batch_size is None: + raise ValueError( + "batch_size is required for streaming mode. " + "Typical values: 50-500 patients per batch. 
" + "Example: dataset.set_task(task, batch_size=100)" ) - patients = list(self.iter_patients(filtered_global_event_df)) - with ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = [executor.submit(task, patient) for patient in patients] - for future in tqdm( - as_completed(futures), - total=len(futures), - desc=( - f"Collecting samples for {task.task_name} " - f"from {num_workers} workers" - ), - ): - samples.extend(future.result()) - - # Cache the samples if cache_dir is provided - if cache_dir is not None: - cache_path = Path(cache_dir) / cache_filename - cache_path.parent.mkdir(parents=True, exist_ok=True) - logger.info(f"Caching samples to {cache_path}") - try: - if cache_format == "parquet": - # Save samples as parquet file - samples_for_cache = [ - _convert_for_cache(sample) for sample in samples - ] - samples_df = pl.DataFrame(samples_for_cache) - samples_df.write_parquet(cache_path) - elif cache_format == "pickle": - # Save samples as pickle file - with open(cache_path, "wb") as f: - pickle.dump(samples, f) - else: - msg = f"Unsupported cache format: {cache_format}" - raise ValueError(msg) - logger.info(f"Successfully cached {len(samples)} samples") - except Exception as e: - logger.warning(f"Failed to cache samples: {e}") - - sample_dataset = SampleDataset( - samples, - input_schema=task.input_schema, - output_schema=task.output_schema, - dataset_name=self.dataset_name, - task_name=task, - input_processors=input_processors, - output_processors=output_processors, - ) - - logger.info(f"Generated {len(samples)} samples for task {task.task_name}") - return sample_dataset + return set_task_streaming( + dataset=self, + task=task, + batch_size=batch_size, + cache_dir=cache_dir, + input_processors=input_processors, + output_processors=output_processors, + ) + else: + # Normal mode: traditional in-memory processing + return set_task_normal( + dataset=self, + task=task, + num_workers=num_workers, + cache_dir=cache_dir, + cache_format=cache_format, + input_processors=input_processors, + output_processors=output_processors, + ) diff --git a/pyhealth/datasets/iterable_sample_dataset.py b/pyhealth/datasets/iterable_sample_dataset.py new file mode 100644 index 000000000..b9269e7af --- /dev/null +++ b/pyhealth/datasets/iterable_sample_dataset.py @@ -0,0 +1,748 @@ +from typing import Dict, List, Optional, Union, Type, Iterator +import inspect +import json +import logging +from pathlib import Path + +from torch.utils.data import IterableDataset +import torch +import polars as pl + +from ..processors import get_processor +from ..processors.base_processor import FeatureProcessor +from .utils import _convert_for_cache, deserialize_sample_from_parquet + +logger = logging.getLogger(__name__) + + +class IterableSampleDataset(IterableDataset): + """Iterable sample dataset class for streaming mode. + + This class provides memory-efficient iteration over samples stored on disk. + It is designed for streaming mode and is the recommended approach for + large datasets that don't fit in memory. + + Key differences from SampleDataset: + - Inherits from IterableDataset (not Dataset) + - No __getitem__ support (iteration only) + - Samples stored on disk, loaded in batches + - Memory usage independent of dataset size + + Attributes: + input_schema (Dict[str, Union[str, Type[FeatureProcessor]]]): + Schema for input data. + output_schema (Dict[str, Union[str, Type[FeatureProcessor]]]): + Schema for output data. + dataset_name (Optional[str]): Name of the dataset. 
+ task_name (Optional[str]): Name of the task. + cache_dir (Path): Directory for disk-backed cache. + """ + + def __init__( + self, + input_schema: Dict[str, Union[str, Type[FeatureProcessor]]], + output_schema: Dict[str, Union[str, Type[FeatureProcessor]]], + dataset_name: Optional[str] = None, + task_name: Optional[str] = None, + input_processors: Optional[Dict[str, FeatureProcessor]] = None, + output_processors: Optional[Dict[str, FeatureProcessor]] = None, + cache_dir: Optional[str] = None, + dev: bool = False, + dev_max_patients: int = 1000, + patient_ids: Optional[List[str]] = None, + io_batch_size: int = 1000, + ) -> None: + """Initializes the IterableSampleDataset. + + Args: + input_schema: Schema for input data. + output_schema: Schema for output data. + dataset_name: Name of the dataset. + task_name: Name of the task. + input_processors: Pre-fitted input processors. + output_processors: Pre-fitted output processors. + cache_dir: Directory for disk-backed cache. + dev: Whether dev mode is enabled (for separate caching). + dev_max_patients: Max patients for dev mode (used in cache naming). + patient_ids: Optional list of patient IDs to filter to. + If provided, only samples from these patients will be yielded. + If None, will be computed from cache after finalization. + Uses Polars predicate pushdown for efficient filtering. + Useful for train/val/test splits. + io_batch_size: Number of samples to read from disk per I/O operation. + Larger values = fewer I/O ops but more memory. Default 1000. + Recommended: 500-2000. Should be >= DataLoader batch_size. + """ + self.input_schema = input_schema + self.output_schema = output_schema + self.dataset_name = dataset_name or "" + self.task_name = task_name or "" + self.dev = dev + self.dev_max_patients = dev_max_patients + self.patient_ids = patient_ids # Can be None initially + self.io_batch_size = io_batch_size # I/O batch size for reading from disk + + # Processor dictionaries + self.input_processors = input_processors or {} + self.output_processors = output_processors or {} + + # Setup streaming storage + self.cache_dir = Path(cache_dir) if cache_dir else Path(".pyhealth_cache") + self._setup_streaming_storage() + + def _setup_streaming_storage(self) -> None: + """Setup disk-backed sample storage for streaming mode.""" + self.cache_dir.mkdir(parents=True, exist_ok=True) + + # Add dev suffix to separate dev and full caches + if self.dev: + suffix = f"_dev_{self.dev_max_patients}" + else: + suffix = "" + + self._sample_cache_path = ( + self.cache_dir + / f"{self.dataset_name}_{self.task_name}_samples{suffix}.parquet" + ) + + # Directory for storing individual batch files before combining + self._sample_batch_dir = ( + self.cache_dir / f"{self.dataset_name}_{self.task_name}_batches{suffix}" + ) + + # Track number of samples for __len__ + self._num_samples = 0 + self._samples_finalized = False + self._batch_counter = 0 # Track number of batch files + + logger.info(f"Streaming sample cache: {self._sample_cache_path}") + + def add_samples_streaming(self, samples: List[Dict]) -> None: + """Add samples to disk-backed storage in streaming mode. + + This method writes samples to parquet file incrementally, + allowing processing of datasets larger than memory. 
+ + Args: + samples: List of sample dictionaries to add + + Raises: + RuntimeError: If called after finalize_samples() + """ + if self._samples_finalized: + raise RuntimeError( + "Cannot add more samples after finalize_samples() has been called" + ) + + if not samples: + return # Nothing to add + + # Convert samples for cache-friendly storage + converted_samples = [_convert_for_cache(s) for s in samples] + + # Serialize complex nested structures to JSON strings for Parquet storage + # This avoids Polars type inference issues with mixed nested types + serialized_samples = [] + for sample in converted_samples: + serialized_sample = {} + for key, value in sample.items(): + if isinstance(value, dict) and "__stagenet_cache__" in value: + # Serialize StageNet cache dicts to JSON strings + serialized_sample[key] = json.dumps(value) + else: + serialized_sample[key] = value + serialized_samples.append(serialized_sample) + + # DEBUG: Inspect first converted sample to understand structure + if self._num_samples == 0 and len(serialized_samples) > 0: + logger.info("DEBUG: Inspecting first serialized sample structure:") + first_sample = serialized_samples[0] + for key, value in first_sample.items(): + value_type = type(value).__name__ + if isinstance(value, str) and len(value) > 100: + logger.info(f" {key}: {value_type} (JSON, len={len(value)})") + else: + logger.info(f" {key}: {value_type} = {value}") + + # Convert samples to DataFrame + try: + sample_df = pl.DataFrame(serialized_samples) + except Exception as e: + logger.error(f"Failed to create DataFrame from samples: {e}") + logger.error("Sample structure causing issue:") + for key in serialized_samples[0].keys(): + values = [s[key] for s in serialized_samples[:3]] + logger.error(f" {key}: {[type(v).__name__ for v in values]}") + raise + + # Write batch to individual file (safer than appending) + # Create batch directory if needed + self._sample_batch_dir.mkdir(parents=True, exist_ok=True) + + # Write this batch to a separate file + batch_file = self._sample_batch_dir / f"batch_{self._batch_counter:06d}.parquet" + sample_df.write_parquet(batch_file, compression="zstd") + + self._batch_counter += 1 + self._num_samples += len(samples) + logger.debug(f"Added {len(samples)} samples (total: {self._num_samples})") + + def finalize_samples(self) -> None: + """Finalize sample writing and prepare for reading. + + Call this after all samples have been added via add_samples_streaming(). + This combines all batch files into a single parquet file. 
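+
+        Example (illustrative sketch; in practice these calls are driven by
+        ``set_task_streaming``):
+
+            >>> for batch in base_dataset.iter_patients(batch_size=100):
+            ...     samples = [s for patient in batch for s in task(patient)]
+            ...     sample_dataset.add_samples_streaming(samples)
+            >>> sample_dataset.finalize_samples()
+            >>> sample_dataset.build_streaming()  # fit processors on cached data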
+ """ + if self._samples_finalized: + logger.warning("finalize_samples() called multiple times") + return + + # Combine all batch files into final cache file + if self._batch_counter > 0: + logger.info(f"Combining {self._batch_counter} batch files...") + + # Read all batch files and concatenate using streaming + batch_files = sorted(self._sample_batch_dir.glob("batch_*.parquet")) + + if len(batch_files) == 1: + # Only one batch - just rename it + import shutil + + shutil.move(str(batch_files[0]), str(self._sample_cache_path)) + else: + # Multiple batches - concatenate using scan for memory efficiency + lazy_frames = [pl.scan_parquet(f) for f in batch_files] + combined = pl.concat(lazy_frames, how="diagonal") + + # Write final file using streaming + combined.sink_parquet(self._sample_cache_path, compression="zstd") + + # Clean up batch files with more robust error handling + import shutil + import time + + try: + shutil.rmtree(self._sample_batch_dir) + except OSError as e: + # If deletion fails, try again after a brief pause + # (files might still be held by OS) + logger.warning( + f"Failed to delete batch dir on first try: {e}. Retrying..." + ) + time.sleep(0.5) + try: + shutil.rmtree(self._sample_batch_dir) + except OSError as e: + # If it still fails, just log a warning and continue + # The directory will be overwritten on next run anyway + logger.warning( + f"Could not delete batch directory " + f"{self._sample_batch_dir}: {e}" + ) + + logger.info(f"Combined into {self._sample_cache_path}") + + self._samples_finalized = True + logger.info(f"Finalized {self._num_samples} samples in streaming mode") + + def build_streaming(self) -> None: + """Build processors in streaming mode. + + Strategy: Read samples in batches to fit processors without + loading everything to memory. Samples remain in cache as raw data + and are processed on-the-fly during iteration. + + This method requires that samples have been finalized. + """ + if not self._samples_finalized: + raise RuntimeError("Must call finalize_samples() before build_streaming()") + + logger.info("Building processors in streaming mode...") + + # Step 1: Create processor instances (only for missing ones) + # Track which processors need fitting + processors_to_fit = {} + + if not self.input_processors: + for k, v in self.input_schema.items(): + processor = self._get_processor_instance(v) + self.input_processors[k] = processor + processors_to_fit[k] = processor + if not self.output_processors: + for k, v in self.output_schema.items(): + processor = self._get_processor_instance(v) + self.output_processors[k] = processor + processors_to_fit[k] = processor + + # Step 2: Fit processors by reading samples in batches + # Only fit processors that were just created (not pre-fitted ones) + if processors_to_fit: + logger.info(f"Fitting {len(processors_to_fit)} processors on samples...") + batch_size = 1000 + + # Warn about large batch size and Polars streaming issues + if batch_size > 200: + logger.warning( + f"Using batch_size={batch_size} for processor fitting. " + f"Note: Polars streaming is disabled for slice operations " + f"to avoid race conditions in async parquet reader. " + f"For better performance, consider batch_size <= 200." 
+ ) + + lf = pl.scan_parquet(self._sample_cache_path) + + # Check if we have enough memory to load all samples + # For now, use streaming fit if more than 100k samples + use_streaming_fit = self._num_samples > 100_000 + + if use_streaming_fit: + # Streaming fit: Process in batches without accumulating + logger.info( + f"Using streaming fit for {self._num_samples} samples " + f"(memory-efficient mode)" + ) + num_batches = (self._num_samples + batch_size - 1) // batch_size + + for i in range(num_batches): + batch = lf.slice(i * batch_size, batch_size).collect( + streaming=False + ) + batch_samples = batch.to_dicts() + restored_samples = [ + deserialize_sample_from_parquet(s) for s in batch_samples + ] + + # Fit incrementally on each batch + for key, processor in processors_to_fit.items(): + processor.fit(restored_samples, key, stream=True) + + # Free batch memory immediately + del restored_samples + + # Finalize all processors + for key, processor in processors_to_fit.items(): + if hasattr(processor, "finalize_fit"): + processor.finalize_fit() + else: + # Non-streaming fit: Load all samples then fit once + logger.info("Loading all samples for processor fitting...") + all_restored_samples = [] + num_batches = (self._num_samples + batch_size - 1) // batch_size + for i in range(num_batches): + batch = lf.slice(i * batch_size, batch_size).collect( + streaming=False + ) + batch_samples = batch.to_dicts() + restored_samples = [ + deserialize_sample_from_parquet(s) for s in batch_samples + ] + all_restored_samples.extend(restored_samples) + + # Fit each processor once on all samples + for key, processor in processors_to_fit.items(): + logger.debug(f"Fitting processor for key: {key}") + processor.fit(all_restored_samples, key, stream=False) + + # Clean up to free memory + del all_restored_samples + else: + logger.info("Using pre-fitted processors (skipping fit)") + + logger.info( + "Processors ready! Processing will happen on-the-fly during iteration." + ) + + def _get_processor_instance( + self, processor_spec: Union[str, Type[FeatureProcessor]] + ) -> FeatureProcessor: + """Get processor instance from schema value. + + Args: + processor_spec: Either a string alias or a processor class + + Returns: + Processor instance + """ + if isinstance(processor_spec, str): + return get_processor(processor_spec)() + elif inspect.isclass(processor_spec) and issubclass( + processor_spec, FeatureProcessor + ): + return processor_spec() + else: + raise ValueError( + f"Processor spec must be either a string alias or a " + f"FeatureProcessor class, got {type(processor_spec)}" + ) + + def filter_by_patients( + self, + patient_ids: List[str], + split_name: Optional[str] = None, + ) -> "IterableSampleDataset": + """Create a filtered dataset by physically splitting samples into a new cache file. + + This creates a separate parquet file containing only samples from the + specified patients. This is much more memory-efficient than in-memory + filtering for large datasets, as each split can be streamed independently + without loading the full dataset. + + Args: + patient_ids: List of patient IDs to include in this split + split_name: Optional name for this split (e.g., 'train', 'val', 'test'). + Used for naming the cache file. If None, generates a hash-based name. + + Returns: + New IterableSampleDataset with its own cache file containing only + the filtered samples + + Example: + >>> # Split patient IDs + >>> train_ids, val_ids, test_ids = split_by_patient_stream( + ... 
full_dataset.get_patient_ids(), [0.8, 0.1, 0.1] + ... ) + >>> + >>> # Create physical splits (each gets its own cache file) + >>> train_ds = full_dataset.filter_by_patients(train_ids, split_name='train') + >>> val_ds = full_dataset.filter_by_patients(val_ids, split_name='val') + >>> test_ds = full_dataset.filter_by_patients(test_ids, split_name='test') + >>> + >>> # Each dataset streams from its own file - maximum memory efficiency! + """ + # Create new instance with same config + filtered_ds = IterableSampleDataset( + input_schema=self.input_schema, + output_schema=self.output_schema, + dataset_name=self.dataset_name, + task_name=self.task_name, + input_processors=self.input_processors, + output_processors=self.output_processors, + cache_dir=str(self.cache_dir), + dev=self.dev, + dev_max_patients=self.dev_max_patients, + patient_ids=None, # Physical split - no in-memory filtering + io_batch_size=self.io_batch_size, # Inherit I/O batch size + ) + + # Create the physical split cache file + self._create_physical_split(filtered_ds, patient_ids, split_name) + + return filtered_ds + + def _create_physical_split( + self, + filtered_ds: "IterableSampleDataset", + patient_ids: List[str], + split_name: Optional[str], + ) -> None: + """Create a physical split by writing filtered samples to a new parquet file. + + This is much more memory-efficient than using is_in() filtering for large datasets. + + Args: + filtered_ds: The new filtered dataset instance to configure + patient_ids: List of patient IDs to include in this split + split_name: Name for this split (e.g., 'train', 'val', 'test') + """ + # Generate split cache path + if split_name: + split_suffix = f"_{split_name}" + else: + # Use hash of patient IDs if no name provided + import hashlib + + id_hash = hashlib.md5("".join(sorted(patient_ids)).encode()).hexdigest()[:8] + split_suffix = f"_split_{id_hash}" + + dev_suffix = f"_dev_{self.dev_max_patients}" if self.dev else "" + split_cache_path = ( + self.cache_dir + / f"{self.dataset_name}_{self.task_name}_samples{split_suffix}{dev_suffix}.parquet" + ) + + # Check if split already exists + if split_cache_path.exists(): + logger.info(f"Using existing split cache: {split_cache_path}") + else: + # Create the split by filtering and writing + logger.info( + f"Creating physical split with {len(patient_ids)} patients " + f"→ {split_cache_path}" + ) + + # Read from main cache, filter, and write to split cache + lf = pl.scan_parquet(self._sample_cache_path) + filtered_lf = lf.filter(pl.col("patient_id").is_in(patient_ids)) + + # Write the filtered data (this materializes it once) + filtered_lf.sink_parquet(split_cache_path, compression="zstd") + + logger.info(f"Physical split created: {split_cache_path}") + + # Configure the filtered dataset to use the split cache + filtered_ds._sample_cache_path = split_cache_path + filtered_ds._sample_batch_dir = self._sample_batch_dir # Not used for splits + filtered_ds._samples_finalized = True + filtered_ds._batch_counter = 0 + + # Compute num_samples from the split cache + lf = pl.scan_parquet(split_cache_path) + filtered_ds._num_samples = lf.select(pl.len()).collect(streaming=True).item() + + logger.info( + f"Split '{split_name or 'unnamed'}' has {filtered_ds._num_samples} samples" + ) + + def __iter__(self) -> Iterator[Dict]: + """Iterate over samples efficiently. + + This is the main method for accessing samples in streaming mode. + Samples are read from disk in batches, deserialized, and processed + on-the-fly with fitted processors. 
+ + If filter_patient_ids is set, only samples from those patients are + yielded. Polars' predicate pushdown optimizes this filtering at the + Parquet level for efficiency. + + Yields: + Sample dictionaries with processed features (as tensors) + + Example: + >>> for sample in dataset: + ... train_on_sample(sample) + """ + if not self._samples_finalized: + raise RuntimeError("Cannot iterate before finalize_samples()") + + # Get worker info for distributed training + worker_info = torch.utils.data.get_worker_info() + + # Use configured I/O batch size for reading from disk + # This is the number of samples read per disk I/O operation + # The DataLoader will then batch these samples for training + batch_size = self.io_batch_size + + # Warn about large batch sizes that may trigger Polars streaming bugs + if batch_size > 2000: + logger.warning( + f"io_batch_size={batch_size} is very large and may cause " + f"memory issues. Consider using io_batch_size <= 2000 for " + f"better stability." + ) + + lf = pl.scan_parquet(self._sample_cache_path) + + # Note: With physical splits, patient_ids should be None + # If somehow patient_ids is set, we still support in-memory filtering + # (but this is not recommended for large datasets) + if self.patient_ids is not None: + logger.warning( + "In-memory patient ID filtering detected! " + "For better memory efficiency, use filter_by_patients() " + "to create physical splits instead." + ) + lf = lf.filter(pl.col("patient_id").is_in(self.patient_ids)) + filtered_count = lf.select(pl.count()).collect(streaming=True).item() + num_samples_to_iterate = filtered_count + else: + num_samples_to_iterate = self._num_samples + + if worker_info is None: + # Single worker - iterate all samples + if self.patient_ids is not None: + # With filtering, iterate based on actual filtered count + num_batches = (num_samples_to_iterate + batch_size - 1) // batch_size + logger.debug( + f"Iterating {num_batches} batches " + f"({num_samples_to_iterate} samples)" + ) + for batch_idx in range(num_batches): + offset = batch_idx * batch_size + # Collect a batch using slice on the filtered LazyFrame + batch = lf.slice(offset, batch_size).collect(streaming=False) + + # If batch is empty, we've processed all filtered samples + if len(batch) == 0: + logger.warning( + f"Empty batch at offset {offset}, stopping iteration" + ) + break + + for sample in batch.to_dicts(): + # Deserialize from parquet format + restored_sample = deserialize_sample_from_parquet(sample) + + # Process with fitted processors + for k, v in restored_sample.items(): + if k in self.input_processors: + restored_sample[k] = self.input_processors[k].process(v) + elif k in self.output_processors: + restored_sample[k] = self.output_processors[k].process( + v + ) + + yield restored_sample + else: + # No filtering - use efficient offset-based iteration + num_batches = (self._num_samples + batch_size - 1) // batch_size + for batch_idx in range(num_batches): + offset = batch_idx * batch_size + length = min(batch_size, self._num_samples - offset) + # Disable streaming for slice operations due to Polars bug: + # Large batches (>200) trigger race conditions in async + # parquet reader causing "range end index out of bounds" errors. + # Using streaming=False forces synchronous reads, avoiding bug. 
+ batch = lf.slice(offset, length).collect(streaming=False) + for sample in batch.to_dicts(): + # Deserialize from parquet format + restored_sample = deserialize_sample_from_parquet(sample) + + # Process with fitted processors + for k, v in restored_sample.items(): + if k in self.input_processors: + restored_sample[k] = self.input_processors[k].process(v) + elif k in self.output_processors: + restored_sample[k] = self.output_processors[k].process( + v + ) + + yield restored_sample + else: + # Multiple workers - partition samples + worker_id = worker_info.id + num_workers = worker_info.num_workers + + if self.patient_ids is not None: + # With filtering, iterate in batches and partition + num_batches = (num_samples_to_iterate + batch_size - 1) // batch_size + for batch_idx in range(num_batches): + if batch_idx % num_workers == worker_id: + offset = batch_idx * batch_size + batch = lf.slice(offset, batch_size).collect(streaming=False) + + # If batch is empty, we've processed all filtered samples + if len(batch) == 0: + break + + for sample in batch.to_dicts(): + # Deserialize from parquet format + restored_sample = deserialize_sample_from_parquet(sample) + + # Process with fitted processors + for k, v in restored_sample.items(): + if k in self.input_processors: + restored_sample[k] = self.input_processors[ + k + ].process(v) + elif k in self.output_processors: + restored_sample[k] = self.output_processors[ + k + ].process(v) + + yield restored_sample + else: + # No filtering - each worker processes every nth batch + num_batches = (self._num_samples + batch_size - 1) // batch_size + for batch_idx in range(num_batches): + if batch_idx % num_workers == worker_id: + offset = batch_idx * batch_size + length = min(batch_size, self._num_samples - offset) + # Disable streaming for slice operations (see comment above) + batch = lf.slice(offset, length).collect(streaming=False) + for sample in batch.to_dicts(): + # Deserialize from parquet format + restored_sample = deserialize_sample_from_parquet(sample) + + # Process with fitted processors + for k, v in restored_sample.items(): + if k in self.input_processors: + restored_sample[k] = self.input_processors[ + k + ].process(v) + elif k in self.output_processors: + restored_sample[k] = self.output_processors[ + k + ].process(v) + + yield restored_sample + + def get_patient_ids(self) -> List[str]: + """Get unique patient IDs for this dataset. + + For filtered datasets, returns the filtered patient list. + For unfiltered datasets, reads unique patient IDs from cache. + + Returns: + List[str]: List of unique patient IDs + + Example: + >>> # Get patients with samples for splitting + >>> all_patient_ids = sample_dataset.get_patient_ids() + >>> train_ids, val_ids, test_ids = split_by_patient_stream( + ... all_patient_ids, ratios=[0.8, 0.1, 0.1] + ... ) + """ + # If already set (filtered dataset), return it + if self.patient_ids is not None: + return self.patient_ids + + # Otherwise, read from cache + if not self._samples_finalized: + raise RuntimeError( + "Cannot get patient_ids before samples are finalized. " + "Call finalize_samples() first." + ) + + # Read unique patient IDs from cache using streaming + self.patient_ids = ( + pl.scan_parquet(self._sample_cache_path) + .select("patient_id") + .unique() + .collect(streaming=True)["patient_id"] + .to_list() + ) + + return self.patient_ids + + def __len__(self) -> int: + """Returns the number of samples in the dataset. + + For filtered datasets, computes the actual filtered sample count. 
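The iteration logic above is what makes the dataset usable directly with a PyTorch DataLoader. A hedged usage sketch follows: sample_dataset stands for the IterableSampleDataset returned by set_task in streaming mode, and collate_fn_dict is the helper defined in pyhealth/datasets/utils.py (both assumed to be in scope).

from torch.utils.data import DataLoader

from pyhealth.datasets.utils import collate_fn_dict

# sample_dataset: the IterableSampleDataset produced by set_task(..., stream=True)
loader = DataLoader(
    sample_dataset,
    batch_size=64,               # training batch size, independent of io_batch_size
    num_workers=2,               # each worker takes every 2nd I/O batch via get_worker_info()
    collate_fn=collate_fn_dict,  # collates a list of sample dicts into a dict of lists
)

for batch in loader:
    pass  # feed the collated batch to a model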
+ The count is cached to avoid repeated expensive queries. + + Returns: + int: The number of samples. + """ + # If we have a cached count, use it + if self._num_samples is not None: + return self._num_samples + + # Need to compute the count + if self.patient_ids is not None: + # Filtered dataset - count samples for these patients + lf = pl.scan_parquet(self._sample_cache_path) + lf = lf.filter(pl.col("patient_id").is_in(self.patient_ids)) + self._num_samples = lf.select(pl.len()).collect(streaming=True).item() + logger.info( + f"Filtered dataset has {self._num_samples} samples " + f"for {len(self.patient_ids)} patients" + ) + else: + # Unfiltered dataset - should have been set during finalization + # This shouldn't happen, but fall back to counting + logger.warning( + "Computing sample count from cache " + "(should have been set during finalization)" + ) + lf = pl.scan_parquet(self._sample_cache_path) + self._num_samples = lf.select(pl.len()).collect(streaming=True).item() + + return self._num_samples + + def __str__(self) -> str: + """Returns a string representation of the dataset. + + Returns: + str: A string with the dataset and task names. + """ + return f"Iterable sample dataset {self.dataset_name} {self.task_name}" diff --git a/pyhealth/datasets/mimic4.py b/pyhealth/datasets/mimic4.py index 05321dedb..23b6bc8f1 100644 --- a/pyhealth/datasets/mimic4.py +++ b/pyhealth/datasets/mimic4.py @@ -8,6 +8,7 @@ try: import psutil + HAS_PSUTIL = True except ImportError: HAS_PSUTIL = False @@ -39,7 +40,7 @@ class MIMIC4EHRDataset(BaseDataset): tables (List[str]): A list of tables to be included in the dataset. dataset_name (Optional[str]): The name of the dataset. config_path (Optional[str]): The path to the configuration file. - """ + """ def __init__( self, @@ -47,10 +48,12 @@ def __init__( tables: List[str], dataset_name: str = "mimic4_ehr", config_path: Optional[str] = None, - **kwargs + **kwargs, ): if config_path is None: - config_path = os.path.join(os.path.dirname(__file__), "configs", "mimic4_ehr.yaml") + config_path = os.path.join( + os.path.dirname(__file__), "configs", "mimic4_ehr.yaml" + ) logger.info(f"Using default EHR config: {config_path}") log_memory_usage(f"Before initializing {dataset_name}") @@ -61,7 +64,7 @@ def __init__( tables=tables, dataset_name=dataset_name, config_path=config_path, - **kwargs + **kwargs, ) log_memory_usage(f"After initializing {dataset_name}") @@ -86,10 +89,12 @@ def __init__( tables: List[str], dataset_name: str = "mimic4_note", config_path: Optional[str] = None, - **kwargs + **kwargs, ): if config_path is None: - config_path = os.path.join(os.path.dirname(__file__), "configs", "mimic4_note.yaml") + config_path = os.path.join( + os.path.dirname(__file__), "configs", "mimic4_note.yaml" + ) logger.info(f"Using default note config: {config_path}") if "discharge" in tables: warnings.warn( @@ -109,7 +114,7 @@ def __init__( tables=tables, dataset_name=dataset_name, config_path=config_path, - **kwargs + **kwargs, ) log_memory_usage(f"After initializing {dataset_name}") @@ -134,10 +139,12 @@ def __init__( tables: List[str], dataset_name: str = "mimic4_cxr", config_path: Optional[str] = None, - **kwargs + **kwargs, ): if config_path is None: - config_path = os.path.join(os.path.dirname(__file__), "configs", "mimic4_cxr.yaml") + config_path = os.path.join( + os.path.dirname(__file__), "configs", "mimic4_cxr.yaml" + ) logger.info(f"Using default CXR config: {config_path}") self.prepare_metadata(root) log_memory_usage(f"Before initializing {dataset_name}") @@ 
-146,12 +153,14 @@ def __init__( tables=tables, dataset_name=dataset_name, config_path=config_path, - **kwargs + **kwargs, ) log_memory_usage(f"After initializing {dataset_name}") def prepare_metadata(self, root: str) -> None: - metadata = pd.read_csv(os.path.join(root, "mimic-cxr-2.0.0-metadata.csv.gz"), dtype=str) + metadata = pd.read_csv( + os.path.join(root, "mimic-cxr-2.0.0-metadata.csv.gz"), dtype=str + ) def process_studytime(x): # reformat studytime to be 6 digits (e.g. 123.002 -> 000123 which is 12:30:00) @@ -160,6 +169,7 @@ def process_studytime(x): return f"{int(x):06d}" except Exception: return x + metadata["StudyTime"] = metadata["StudyTime"].apply(process_studytime) def process_image_path(x): @@ -168,10 +178,15 @@ def process_image_path(x): folder = subject_id[:3] study_id = "s" + x["study_id"] dicom_id = x["dicom_id"] - return os.path.join(root, "files", folder, subject_id, study_id, f"{dicom_id}.jpg") + return os.path.join( + root, "files", folder, subject_id, study_id, f"{dicom_id}.jpg" + ) + metadata["image_path"] = metadata.apply(process_image_path, axis=1) - metadata.to_csv(os.path.join(root, "mimic-cxr-2.0.0-metadata-pyhealth.csv"), index=False) + metadata.to_csv( + os.path.join(root, "mimic-cxr-2.0.0-metadata-pyhealth.csv"), index=False + ) return @@ -184,6 +199,39 @@ class MIMIC4Dataset(BaseDataset): - Clinical notes (discharge summaries, radiology reports) - Chest X-rays (images and metadata) + You can use any combination of data sources. Patient IDs are determined + by priority order: + 1. EHR dataset (if ehr_root provided) - primary source + 2. Note dataset (if note_root provided and no EHR) + 3. CXR dataset (if cxr_root provided and no EHR/Note) + + When using multiple data sources in streaming mode, all sub-datasets are + automatically synchronized to use the same patient cohort from the primary + source. + + Examples: + # Use all three modalities (EHR determines patient cohort) + dataset = MIMIC4Dataset( + ehr_root="/path/to/ehr", + note_root="/path/to/notes", + cxr_root="/path/to/cxr", + stream=True + ) + + # Use only chest X-rays (CXR determines patient cohort) + dataset = MIMIC4Dataset( + cxr_root="/path/to/cxr", + cxr_tables=["metadata", "chexpert"], + stream=True + ) + + # Use only clinical notes (Note determines patient cohort) + dataset = MIMIC4Dataset( + note_root="/path/to/notes", + note_tables=["discharge"], + stream=True + ) + Args: ehr_root: Root directory for MIMIC-IV EHR data note_root: Root directory for MIMIC-IV notes data @@ -195,7 +243,10 @@ class MIMIC4Dataset(BaseDataset): note_config_path: Path to the note config file cxr_config_path: Path to the CXR config file dataset_name: Name of the dataset - dev: Whether to enable dev mode (limit to 1000 patients) + dev: Whether to enable dev mode (limit patients) + dev_max_patients: Maximum number of patients in dev mode (default 1000) + stream: Whether to enable streaming mode for memory efficiency + cache_dir: Directory for streaming cache """ def __init__( @@ -211,23 +262,55 @@ def __init__( cxr_config_path: Optional[str] = None, dataset_name: str = "mimic4", dev: bool = False, + dev_max_patients: int = 1000, + stream: bool = False, + cache_dir: Optional[str] = None, ): log_memory_usage("Starting MIMIC4Dataset init") - # Initialize child datasets - self.dataset_name = dataset_name - self.sub_datasets = {} - self.root = None - self.tables = None - self.config = None - # Dev flag is only used in the MIMIC4Dataset class - # to ensure the same set of patients are used for all sub-datasets. 
- self.dev = dev - # We need at least one root directory if not any([ehr_root, note_root, cxr_root]): raise ValueError("At least one root directory must be provided") + # Initialize base class attributes for streaming mode + # MIMIC4Dataset doesn't follow the normal BaseDataset pattern + # (no single root/tables/config), so we initialize the attributes + # that BaseDataset would normally set + self.dataset_name = dataset_name + self.dev = dev + self.dev_max_patients = dev_max_patients + self.stream = stream + + # Handle cache_dir (convert to Path if needed) + if cache_dir is None: + from pathlib import Path + + cache_dir = Path.home() / ".cache" / "pyhealth" / dataset_name + elif isinstance(cache_dir, str): + from pathlib import Path + + cache_dir = Path(cache_dir) + self.cache_dir = cache_dir + + # Initialize streaming attributes (normally done in BaseDataset.__init__) + self._collected_global_event_df = None + self._unique_patient_ids = None + self._patient_cache_path = None + self._patient_index_path = None + self._patient_index = None + + # Setup streaming cache if enabled + if self.stream: + from .processing.streaming import setup_streaming_cache + + setup_streaming_cache(self) + + # MIMIC4-specific attributes + self.sub_datasets = {} + self.root = None # Composite dataset has no single root + self.tables = None # No single tables list + self.config = None # No single config + # Initialize empty lists if None provided ehr_tables = ehr_tables or [] note_tables = note_tables or [] @@ -235,31 +318,49 @@ def __init__( # Initialize EHR dataset if root is provided if ehr_root: - logger.info(f"Initializing MIMIC4EHRDataset with tables: {ehr_tables} (dev mode: {dev})") + logger.info( + f"Initializing MIMIC4EHRDataset with tables: {ehr_tables} (dev mode: {dev})" + ) self.sub_datasets["ehr"] = MIMIC4EHRDataset( root=ehr_root, tables=ehr_tables, config_path=ehr_config_path, + dev=dev, + dev_max_patients=dev_max_patients, + stream=stream, + cache_dir=cache_dir, ) log_memory_usage("After EHR dataset initialization") # Initialize Notes dataset if root is provided if note_root is not None and note_tables: - logger.info(f"Initializing MIMIC4NoteDataset with tables: {note_tables} (dev mode: {dev})") + logger.info( + f"Initializing MIMIC4NoteDataset with tables: {note_tables} (dev mode: {dev})" + ) self.sub_datasets["note"] = MIMIC4NoteDataset( root=note_root, tables=note_tables, config_path=note_config_path, + dev=dev, + dev_max_patients=dev_max_patients, + stream=stream, + cache_dir=cache_dir, ) log_memory_usage("After Note dataset initialization") # Initialize CXR dataset if root is provided if cxr_root is not None: - logger.info(f"Initializing MIMIC4CXRDataset with tables: {cxr_tables} (dev mode: {dev})") + logger.info( + f"Initializing MIMIC4CXRDataset with tables: {cxr_tables} (dev mode: {dev})" + ) self.sub_datasets["cxr"] = MIMIC4CXRDataset( root=cxr_root, tables=cxr_tables, config_path=cxr_config_path, + dev=dev, + dev_max_patients=dev_max_patients, + stream=stream, + cache_dir=cache_dir, ) log_memory_usage("After CXR dataset initialization") @@ -268,9 +369,24 @@ def __init__( self.global_event_df = self._combine_data() log_memory_usage("After combining data") - # Cache attributes - self._collected_global_event_df = None - self._unique_patient_ids = None + # CRITICAL: Trigger patient synchronization befosre cache building + # This ensures sub-datasets have synchronized patient IDs from the parent + # BEFORE their patient caches are built in set_task_streaming(). 
+ # Without this, sub-dataset caches are built with ALL patients (unsynchronized), + # causing a mismatch with the parent's patient_ids. + # + # In streaming mode, we MUST trigger this early to ensure: + # 1. parent._unique_patient_ids is set + # 2. sub-dataset._unique_patient_ids is synchronized to parent + # 3. When set_task() calls build_patient_cache(), it uses the synchronized IDs + if self.stream: + logger.info("Pre-computing patient IDs for streaming mode...") + # Access unique_patient_ids to trigger synchronization + _ = self.unique_patient_ids + logger.info( + f"Initialized with {len(self._unique_patient_ids)} patients " + f"from {len(self.sub_datasets)} sub-dataset(s)" + ) log_memory_usage("Completed MIMIC4Dataset init") @@ -294,3 +410,78 @@ def _combine_data(self) -> pl.LazyFrame: return frames[0] else: return pl.concat(frames, how="diagonal") + + @property + def unique_patient_ids(self) -> List[str]: + """ + Get unique patient IDs from the dataset. + + Patient ID determination logic: + 1. If EHR is present: Use EHR patient IDs (EHR is the linking key) + 2. If only Notes + CXR: Use intersection (only patients in both) + 3. If single source: Use all patients from that source + + When EHR is present, it takes precedence because it contains the + core patient demographics and links to all other modalities. + + Returns: + List[str]: List of unique patient IDs + """ + # Cache the patient IDs if not already computed + if self._unique_patient_ids is None: + if len(self.sub_datasets) == 0: + raise ValueError( + "MIMIC4Dataset has no sub-datasets. At least one of " + "ehr_root, note_root, or cxr_root must be provided." + ) + + # Get patient IDs from all sub-datasets + all_patient_id_sets = {} + for dataset_name, dataset in self.sub_datasets.items(): + patient_ids = dataset.unique_patient_ids + all_patient_id_sets[dataset_name] = set(patient_ids) + logger.info(f"{dataset_name} dataset has {len(patient_ids)} patients") + + # Strategy 1: EHR takes precedence (it's the linking key) + if "ehr" in self.sub_datasets: + self._unique_patient_ids = list(all_patient_id_sets["ehr"]) + logger.info( + f"Using {len(self._unique_patient_ids)} patients " + f"from EHR dataset (EHR is primary source)" + ) + + # In streaming mode, sync other datasets to use EHR patients + # by directly setting their _unique_patient_ids + if self.stream and len(self.sub_datasets) > 1: + logger.info("Synchronizing Notes/CXR to use EHR patient set") + for dataset_name, dataset in self.sub_datasets.items(): + if dataset_name != "ehr": + # Directly set patient IDs (no separate variable) + dataset._unique_patient_ids = self._unique_patient_ids + logger.debug(f"Synchronized {dataset_name} with EHR") + + # Strategy 2: No EHR, single source - use all patients + elif len(self.sub_datasets) == 1: + self._unique_patient_ids = list(list(all_patient_id_sets.values())[0]) + logger.info( + f"Using all {len(self._unique_patient_ids)} patients " + f"from single data source" + ) + + # Strategy 3: No EHR, multiple sources - use intersection + else: + common_patients = set.intersection(*all_patient_id_sets.values()) + self._unique_patient_ids = sorted(list(common_patients)) + + logger.info( + f"No EHR dataset - using intersection of " + f"{len(self.sub_datasets)} sources: " + f"{len(self._unique_patient_ids)} patients" + ) + + # In streaming mode, sync all datasets to common set + if self.stream: + for dataset_name, dataset in self.sub_datasets.items(): + dataset._unique_patient_ids = self._unique_patient_ids + + return self._unique_patient_ids 
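The cohort rule implemented by unique_patient_ids can be restated on plain sets; the patient IDs below are invented for illustration.

# EHR wins if present; a single source is taken as-is; otherwise intersect.
ehr = {"p1", "p2", "p3"}
note = {"p2", "p3", "p4"}
cxr = {"p3", "p5"}

sources = {"ehr": ehr, "note": note, "cxr": cxr}

if "ehr" in sources:                        # EHR is the linking key
    cohort = sorted(sources["ehr"])
elif len(sources) == 1:                     # single modality: use all its patients
    cohort = sorted(next(iter(sources.values())))
else:                                       # no EHR: only patients in every source
    cohort = sorted(set.intersection(*sources.values()))

print(cohort)  # ['p1', 'p2', 'p3'] because the EHR set takes precedence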
diff --git a/pyhealth/datasets/processing/__init__.py b/pyhealth/datasets/processing/__init__.py new file mode 100644 index 000000000..c3ca16dd2 --- /dev/null +++ b/pyhealth/datasets/processing/__init__.py @@ -0,0 +1,27 @@ +"""Dataset processing module. + +This module provides different processing modes for PyHealth datasets: +- normal: Traditional in-memory processing for smaller datasets +- streaming: Disk-backed streaming processing for large datasets + +These are internal implementation details and should not be imported directly +by users. The public API is through BaseDataset.set_task(). +""" + +from .normal import set_task_normal +from .streaming import ( + build_patient_cache, + iter_patients_streaming, + set_task_streaming, + setup_streaming_cache, + _create_patients_from_dataframe, +) + +__all__ = [ + "set_task_normal", + "set_task_streaming", + "setup_streaming_cache", + "build_patient_cache", + "iter_patients_streaming", + "_create_patients_from_dataframe", +] diff --git a/pyhealth/datasets/processing/normal.py b/pyhealth/datasets/processing/normal.py new file mode 100644 index 000000000..9963cab3b --- /dev/null +++ b/pyhealth/datasets/processing/normal.py @@ -0,0 +1,275 @@ +"""Normal (in-memory) mode implementation for BaseDataset. + +This module contains the implementation for processing datasets in normal mode, +where all data is loaded into memory. This is the traditional PyHealth approach +suitable for smaller datasets that fit in memory. +""" + +import json +import logging +import pickle +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path +from typing import Dict, List, Optional + +import polars as pl +from tqdm import tqdm + +from ...processors.base_processor import FeatureProcessor +from ...tasks import BaseTask +from ..sample_dataset import SampleDataset +from ..utils import _convert_for_cache, save_processors, load_processors + +logger = logging.getLogger(__name__) + + +def load_cached_samples_normal( + cache_path: Path, cache_format: str +) -> Optional[List[Dict]]: + """Load cached samples from disk (normal mode). + + Args: + cache_path: Path to cached samples file + cache_format: Format of cache ('parquet' or 'pickle') + + Returns: + List of sample dictionaries if successful, None if loading failed + """ + if not cache_path.exists(): + return None + + logger.info(f"Loading cached samples from {cache_path}") + try: + if cache_format == "parquet": + from ..utils import deserialize_sample_from_parquet + + cached_df = pl.read_parquet(cache_path) + samples = [ + deserialize_sample_from_parquet(row) for row in cached_df.to_dicts() + ] + elif cache_format == "pickle": + with open(cache_path, "rb") as f: + samples = pickle.load(f) + else: + raise ValueError(f"Unsupported cache format: {cache_format}") + + logger.info(f"Loaded {len(samples)} cached samples") + return samples + except Exception as e: + logger.warning("Failed to load cached data: %s. Regenerating...", e) + return None + + +def generate_samples_normal( + dataset, + task: BaseTask, + num_workers: int = 1, +) -> List[Dict]: + """Generate samples in normal (in-memory) mode. 
+ + Args: + dataset: The BaseDataset instance + task: The task to generate samples for + num_workers: Number of worker threads (1 = single-threaded) + + Returns: + List of generated sample dictionaries + """ + logger.info(f"Generating samples with {num_workers} worker(s)...") + filtered_event_df = task.pre_filter(dataset.collected_global_event_df) + samples = [] + + if num_workers == 1: + # Single-threading (default and recommended) + for patient in tqdm( + dataset.iter_patients(filtered_event_df), + total=filtered_event_df["patient_id"].n_unique(), + desc=f"Generating samples for {task.task_name} with 1 worker", + smoothing=0, + ): + samples.extend(task(patient)) + else: + # Multi-threading (not recommended but available) + logger.info( + f"Generating samples for {task.task_name} " f"with {num_workers} workers" + ) + # Load all patients first (with progress bar) + patients = list( + tqdm( + dataset.iter_patients(filtered_event_df), + total=filtered_event_df["patient_id"].n_unique(), + desc=f"Loading patients for {task.task_name}", + smoothing=0, + ) + ) + + # Process patients in parallel (with progress bar) + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = [executor.submit(task, patient) for patient in patients] + for future in tqdm( + as_completed(futures), + total=len(futures), + desc=( + f"Collecting samples for {task.task_name} " + f"from {num_workers} workers" + ), + ): + samples.extend(future.result()) + + return samples + + +def cache_samples_normal( + samples: List[Dict], cache_path: Path, cache_format: str +) -> None: + """Cache samples to disk (normal mode). + + Args: + samples: List of sample dictionaries to cache + cache_path: Path where to save the cache + cache_format: Format to use ('parquet' or 'pickle') + """ + cache_path.parent.mkdir(parents=True, exist_ok=True) + logger.info(f"Caching samples to {cache_path}") + + try: + if cache_format == "parquet": + # Convert samples to cache-friendly format + samples_for_cache = [_convert_for_cache(sample) for sample in samples] + + # Serialize nested dicts to JSON for parquet compatibility + # Avoids type inference issues like: + # "failed to determine supertype of list[f64] and list[list[str]]" + serialized_samples = [] + for sample in samples_for_cache: + serialized_sample = {} + for key, value in sample.items(): + if isinstance(value, dict) and "__stagenet_cache__" in value: + # Serialize StageNet cache dicts to JSON strings + serialized_sample[key] = json.dumps(value) + else: + serialized_sample[key] = value + serialized_samples.append(serialized_sample) + + samples_df = pl.DataFrame(serialized_samples) + samples_df.write_parquet(cache_path) + elif cache_format == "pickle": + # Save samples as pickle file + with open(cache_path, "wb") as f: + pickle.dump(samples, f) + else: + # Don't raise – just warn and skip caching + logger.warning( + "Unsupported cache format '%s'. Skipping caching.", + cache_format, + ) + return + + logger.info(f"Successfully cached {len(samples)} samples") + except Exception as e: + logger.warning(f"Failed to cache samples: {e}") + + +def set_task_normal( + dataset, + task: BaseTask, + num_workers: int, + cache_dir: Optional[str], + cache_format: str, + input_processors: Optional[Dict[str, FeatureProcessor]], + output_processors: Optional[Dict[str, FeatureProcessor]], +) -> SampleDataset: + """Execute set_task in normal (in-memory) mode. + + This is the traditional PyHealth approach where all data is loaded into + memory. 
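The JSON round-trip used by cache_samples_normal above can be sketched independently; the field name and values here are invented. Nested dicts are JSON-encoded because Polars cannot infer a single supertype for heterogeneous nested values, while a string column always round-trips cleanly.

import json

import polars as pl

samples = [
    {"patient_id": "p1", "codes": {"__stagenet_cache__": True, "values": [1.0, 2.0]}},
    {"patient_id": "p2", "codes": {"__stagenet_cache__": True, "values": [3.0]}},
]

# Encode dict-valued fields as JSON strings before building the DataFrame.
encoded = [
    {k: json.dumps(v) if isinstance(v, dict) else v for k, v in s.items()}
    for s in samples
]
pl.DataFrame(encoded).write_parquet("samples.parquet")  # placeholder path

# Reading back reverses the encoding, mirroring deserialize_sample_from_parquet.
row = pl.read_parquet("samples.parquet").to_dicts()[0]
row["codes"] = json.loads(row["codes"])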
Suitable for datasets that fit in memory (<100k samples typically). + + Implements smart caching: + 1. Checks if sample cache exists + 2. Checks if processors are cached + 3. Only regenerates what's missing + + Args: + dataset: The BaseDataset instance + task: The task to execute + num_workers: Number of worker threads for parallel processing + cache_dir: Directory for caching processed samples + cache_format: Format for cache files ('parquet' or 'pickle') + input_processors: Pre-fitted input processors + output_processors: Pre-fitted output processors + + Returns: + SampleDataset with processed samples loaded in memory + """ + # Determine cache filename (include dev params to avoid conflicts) + cache_filename = None + cache_path = None + if cache_dir is not None: + if dataset.dev: + cache_filename = ( + f"{task.task_name}_dev_{dataset.dev_max_patients}" f".{cache_format}" + ) + else: + cache_filename = f"{task.task_name}.{cache_format}" + cache_path = Path(cache_dir) / cache_filename + + # Check if processors are cached + processor_cache_dir = None + processors_cached = False + if cache_dir is not None: + processor_cache_dir = Path(cache_dir) / "processors" + processors_cached = ( + processor_cache_dir / "input_processors.pkl" + ).exists() and (processor_cache_dir / "output_processors.pkl").exists() + + # Load cached processors if available and not provided + if processors_cached and input_processors is None and output_processors is None: + logger.info(f"Loading cached processors from {processor_cache_dir}") + try: + input_processors, output_processors = load_processors( + str(processor_cache_dir) + ) + except Exception as e: + logger.warning(f"Failed to load cached processors: {e}. Will rebuild.") + processors_cached = False + input_processors = None + output_processors = None + + # Try to load from cache + samples = None + if cache_path is not None: + samples = load_cached_samples_normal(cache_path, cache_format) + + # Generate samples if not cached + if samples is None: + logger.info("No sample cache found. Generating samples...") + samples = generate_samples_normal(dataset, task, num_workers) + + # Cache the generated samples + if cache_path is not None: + cache_samples_normal(samples, cache_path, cache_format) + else: + logger.info(f"Using {len(samples)} cached samples from {cache_path}") + + # Create and return SampleDataset + sample_dataset = SampleDataset( + samples, + input_schema=task.input_schema, + output_schema=task.output_schema, + dataset_name=dataset.dataset_name, + task_name=task.task_name, + input_processors=input_processors, + output_processors=output_processors, + ) + + # Save processors if they were just built (and not loaded from cache) + if ( + cache_dir is not None + and not processors_cached + and (input_processors is None or output_processors is None) + ): + logger.info(f"Saving processors to {processor_cache_dir}") + save_processors(sample_dataset, str(processor_cache_dir)) + + logger.info(f"Dataset ready with {len(samples)} samples") + return sample_dataset diff --git a/pyhealth/datasets/processing/streaming.py b/pyhealth/datasets/processing/streaming.py new file mode 100644 index 000000000..f53df1666 --- /dev/null +++ b/pyhealth/datasets/processing/streaming.py @@ -0,0 +1,455 @@ +"""Streaming mode implementation for BaseDataset. + +This module contains the implementation for processing datasets in streaming mode, +where data is processed in batches and stored on disk. This enables memory-efficient +processing of large datasets that don't fit in memory. 
+ +All streaming-specific logic is centralized here, including: +- Cache setup and management +- Patient cache building with disk-backed storage +- Streaming iteration over patients +- Task processing in streaming mode +""" + +import logging +from pathlib import Path +from typing import Dict, Iterator, List, Optional, Union + +import polars as pl +from tqdm import tqdm + +from ...data import Patient +from ...processors.base_processor import FeatureProcessor +from ...tasks import BaseTask +from ..iterable_sample_dataset import IterableSampleDataset +from ..utils import save_processors, load_processors + +logger = logging.getLogger(__name__) + + +# ============================================================================ +# Helper Functions +# ============================================================================ + + +def _create_patients_from_dataframe( + df: pl.DataFrame, lazy_partition: bool = False +) -> List[Patient]: + """Convert a DataFrame with multiple patients into Patient objects. + + Args: + df: DataFrame containing events for one or more patients + lazy_partition: Whether to enable lazy partitioning for streaming + + Returns: + List of Patient objects + """ + batch_patients = [] + grouped = df.group_by("patient_id") + for patient_id, patient_df in grouped: + patient_id = patient_id[0] + batch_patients.append( + Patient( + patient_id=patient_id, + data_source=patient_df, + lazy_partition=lazy_partition, + ) + ) + return batch_patients + + +# ============================================================================ +# Cache Setup and Management +# ============================================================================ + + +def setup_streaming_cache(dataset) -> None: + """Setup disk-backed cache directory structure for streaming mode. + + Creates cache directory and defines paths for patient cache and index. + Called during BaseDataset.__init__ when stream=True. + + Dev mode uses separate cache files to avoid conflicts with full dataset. + Cache filenames include dev_max_patients to support different dev sizes. + + Args: + dataset: The BaseDataset instance with cache_dir, dataset_name, + dev, and dev_max_patients attributes + """ + dataset.cache_dir.mkdir(parents=True, exist_ok=True) + + # Add dev suffix with patient count to cache paths + # This ensures different dev configurations use different cache files + if dataset.dev: + suffix = f"_dev_{dataset.dev_max_patients}" + else: + suffix = "" + + # Define cache file paths + dataset._patient_cache_path = ( + dataset.cache_dir / f"{dataset.dataset_name}_patients{suffix}.parquet" + ) + dataset._patient_index_path = ( + dataset.cache_dir / f"{dataset.dataset_name}_patient_index{suffix}.parquet" + ) + + logger.info(f"Streaming cache directory: {dataset.cache_dir}") + + +def build_patient_cache( + dataset, + filtered_df: Optional[pl.LazyFrame] = None, + force_rebuild: bool = False, +) -> None: + """Build disk-backed patient cache with efficient indexing. + + This method uses Polars' native streaming execution via sink_parquet to process + the data without loading everything into memory. According to the Polars docs + (https://docs.pola.rs/user-guide/concepts/streaming/), sink_parquet automatically + uses the streaming engine to process data in batches. + + Args: + dataset: The BaseDataset instance + filtered_df: Pre-filtered LazyFrame (e.g., from task.pre_filter()). + If None, uses dataset.global_event_df. + force_rebuild: If True, rebuild cache even if it exists. + Default is False. 
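For concreteness, the cache paths produced by setup_streaming_cache above look like the following; the cache directory and dataset name are placeholders.

from pathlib import Path

cache_dir = Path("~/.cache/pyhealth/mimic4").expanduser()  # placeholder location
dataset_name = "mimic4_ehr"
dev, dev_max_patients = True, 1000

# Dev runs get their own suffix so they never collide with the full cache.
suffix = f"_dev_{dev_max_patients}" if dev else ""
patient_cache = cache_dir / f"{dataset_name}_patients{suffix}.parquet"
patient_index = cache_dir / f"{dataset_name}_patient_index{suffix}.parquet"
# -> mimic4_ehr_patients_dev_1000.parquet and mimic4_ehr_patient_index_dev_1000.parquet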
+ + Implementation Notes: + - Uses Polars' sink_parquet for automatic streaming execution + - Sorts by patient_id for efficient patient-level reads + - Creates index for O(1) patient lookup + - Row group size tuned for typical patient size (~100 events) + """ + # Check if both cache and index exist + cache_exists = dataset._patient_cache_path.exists() + index_exists = dataset._patient_index_path.exists() + + if cache_exists and index_exists and not force_rebuild: + logger.info(f"Using existing patient cache: {dataset._patient_cache_path}") + return + + logger.info("Building patient cache using Polars streaming engine...") + + # Use filtered_df if provided, otherwise use global_event_df + df = filtered_df if filtered_df is not None else dataset.global_event_df + + # Filter to synchronized patient IDs if set. + # This applies to: + # 1. Sub-datasets (Notes/CXR) - filtered to parent's patient set + # 2. Composite datasets (MIMIC4) - filtered to synchronized patient set + # 3. Any dataset with manually set _unique_patient_ids + if dataset._unique_patient_ids is not None: + logger.info( + f"Filtering to {len(dataset._unique_patient_ids)} " f"synchronized patients" + ) + reference_patients = pl.DataFrame( + {"patient_id": dataset._unique_patient_ids} + ).lazy() + df = df.join(reference_patients, on="patient_id", how="inner") + + # Apply dev mode filtering at the LazyFrame level + if dataset.dev: + logger.info( + f"Dev mode enabled: limiting to {dataset.dev_max_patients} patients" + ) + limited_patients = ( + df.select(pl.col("patient_id")).unique().limit(dataset.dev_max_patients) + ) + df = df.join(limited_patients, on="patient_id", how="inner") + + # CRITICAL: Sort by patient_id for efficient patient-level access + # This enables: + # 1. Efficient patient-level reads (via row group filtering) + # 2. Polars can use merge joins on subsequent operations + # 3. 
Better compression (similar data grouped together) + df = df.sort("patient_id", "timestamp") + + # Use sink_parquet for memory-efficient writing + # According to https://www.rhosignal.com/posts/streaming-in-polars/, + # sink_parquet automatically uses Polars' streaming engine and never + # loads the full dataset into memory + df.sink_parquet( + dataset._patient_cache_path, + # Row group size tuned for patient-level access + # Assuming ~100 events per patient, 10000 events ≈ 100 patients per row group + row_group_size=10000, + compression="zstd", # Good balance of compression ratio and speed + statistics=True, # Enable statistics for better predicate pushdown + ) + + # Build patient index for fast lookups using streaming + logger.info("Building patient index with streaming...") + patient_index = ( + pl.scan_parquet(dataset._patient_cache_path) + .group_by("patient_id") + .agg( + [ + pl.count().alias("event_count"), + pl.first("timestamp").alias("first_timestamp"), + pl.last("timestamp").alias("last_timestamp"), + ] + ) + .sort("patient_id") + ) + # sink_parquet uses streaming automatically + patient_index.sink_parquet(dataset._patient_index_path) + + # Load index with streaming for verification + dataset._patient_index = pl.scan_parquet(dataset._patient_index_path).collect( + streaming=True + ) + + cache_size_mb = dataset._patient_cache_path.stat().st_size / 1e6 + logger.info(f"Patient cache built: {cache_size_mb:.2f} MB") + + +# ============================================================================ +# Streaming Patient Iteration +# ============================================================================ + + +def iter_patients_streaming( + dataset, + patient_ids: Optional[List[str]], + batch_size: int = 1, +) -> Iterator[Union[Patient, List[Patient]]]: + """Iterate over patients in streaming mode. + + Loads patients from disk cache in batches for I/O efficiency. + When batch_size=1, yields individual patients. When batch_size>1, + yields lists of patients. + + Args: + dataset: The BaseDataset instance + patient_ids: Optional list of specific patient IDs to iterate over. + If None, iterates over all patients in the cache. + batch_size: Number of patients to load per disk query. Default is 1. + Larger batches = better I/O efficiency but more memory. + Use batch_size=1 for single-patient iteration. + + Yields: + Patient objects (if batch_size=1) or lists of Patient objects (if batch_size>1) + + Example: + >>> # Single patients (batch_size=1) + >>> for patient in iter_patients_streaming(dataset, None, batch_size=1): + ... print(patient.patient_id) + + >>> # Batches of 100 (much more efficient!) + >>> for batch in iter_patients_streaming(dataset, None, batch_size=100): + ... 
print(f"Processing {len(batch)} patients") + """ + # Ensure cache exists + if not dataset._patient_cache_path.exists(): + build_patient_cache(dataset) + + # Load patient index + if dataset._patient_index is None: + dataset._patient_index = pl.scan_parquet(dataset._patient_index_path).collect( + streaming=True + ) + + patient_index_df = dataset._patient_index + + # Filter to specific patients if requested + if patient_ids is not None: + patient_index_df = patient_index_df.filter( + pl.col("patient_id").is_in(patient_ids) + ) + + patient_list = patient_index_df["patient_id"].to_list() + + # Process in batches (even if batch_size=1, this is still efficient) + for i in range(0, len(patient_list), batch_size): + batch_patient_ids = patient_list[i : i + batch_size] + + # Load entire batch in one disk query (efficient even for single patient) + batch_df = ( + pl.scan_parquet(dataset._patient_cache_path) + .filter(pl.col("patient_id").is_in(batch_patient_ids)) + .collect(streaming=True) + ) + + # Convert DataFrame to Patient objects + batch_patients = _create_patients_from_dataframe(batch_df, lazy_partition=True) + + # Yield batch or individual patient depending on batch_size + if batch_size == 1: + # Single patient mode - yield individual patient + if batch_patients: + yield batch_patients[0] + else: + # Batch mode - yield list of patients + yield batch_patients + + # Explicitly clear batch to help garbage collection + del batch_df + + +# ============================================================================ +# Task Processing in Streaming Mode +# ============================================================================ +# ============================================================================ + + +def set_task_streaming( + dataset, + task: BaseTask, + batch_size: int, + cache_dir: Optional[str], + input_processors: Optional[Dict[str, FeatureProcessor]], + output_processors: Optional[Dict[str, FeatureProcessor]], +) -> IterableSampleDataset: + """Execute set_task in streaming mode. + + This mode processes patients in batches and writes samples to disk, + enabling memory-efficient processing of large datasets (>100k samples). + + Implements smart caching: + 1. Checks if patient cache exists + 2. Checks if sample cache exists + 3. Checks if processors are cached + 4. Only rebuilds what's missing + + Args: + dataset: The BaseDataset instance + task: The task to execute + batch_size: Number of patients to process per batch + cache_dir: Directory for caching processed samples + input_processors: Pre-fitted input processors + output_processors: Pre-fitted output processors + + Returns: + IterableSampleDataset with samples stored on disk + """ + # Apply task's pre_filter on LazyFrame (no data loaded yet!) 
+ filtered_lazy_df = task.pre_filter(dataset.global_event_df) + + # Build patient cache if not exists (lazy execution with sink_parquet) + if not dataset._patient_cache_path.exists(): + logger.info("Building patient cache...") + build_patient_cache(dataset, filtered_lazy_df) + else: + logger.info(f"Using existing patient cache: {dataset._patient_cache_path}") + + # Create streaming sample dataset + # Use batch_size as io_batch_size for efficient I/O + # This makes disk reads align with sample generation batch size + sample_dataset = IterableSampleDataset( + input_schema=task.input_schema, + output_schema=task.output_schema, + dataset_name=dataset.dataset_name, + task_name=task.task_name, + input_processors=input_processors, + output_processors=output_processors, + cache_dir=cache_dir or str(dataset.cache_dir), + dev=dataset.dev, + dev_max_patients=dataset.dev_max_patients, + io_batch_size=max(batch_size, 128), # At least 128 for efficient I/O + ) + + # Check if sample cache already exists + sample_cache_exists = sample_dataset._sample_cache_path.exists() + + # Check if processors are cached + processor_cache_dir = Path(cache_dir or dataset.cache_dir) / "processors" + processors_cached = (processor_cache_dir / "input_processors.pkl").exists() and ( + processor_cache_dir / "output_processors.pkl" + ).exists() + + # Load cached processors if available and not provided + if processors_cached and input_processors is None and output_processors is None: + logger.info(f"Loading cached processors from {processor_cache_dir}") + try: + input_processors, output_processors = load_processors( + str(processor_cache_dir) + ) + sample_dataset.input_processors = input_processors + sample_dataset.output_processors = output_processors + except Exception as e: + logger.warning(f"Failed to load cached processors: {e}. 
Will rebuild.") + processors_cached = False + + if sample_cache_exists: + # Sample cache exists - just load it + logger.info(f"Using existing sample cache: {sample_dataset._sample_cache_path}") + + # Mark as finalized and load sample count + sample_dataset._samples_finalized = True + sample_count_df = ( + pl.scan_parquet(sample_dataset._sample_cache_path) + .select(pl.count().alias("count")) + .collect(streaming=True) + ) + sample_dataset._num_samples = sample_count_df["count"][0] + + # Build processors if not already provided/cached + if not processors_cached and ( + input_processors is None or output_processors is None + ): + logger.info("Sample cache exists, but processors need to be built") + sample_dataset.build_streaming() + + # Save processors for future use + logger.info(f"Saving processors to {processor_cache_dir}") + save_processors(sample_dataset, str(processor_cache_dir)) + else: + logger.info("Using cached/provided processors") + else: + # Sample cache doesn't exist - generate samples + logger.info("Generating samples in streaming mode...") + + # Process patients in batches and write samples to disk + write_batch_samples = [] + write_batch_size = 500 # Write every 500 samples + + # Get total patient count for progress bar + patient_index = pl.scan_parquet(dataset._patient_index_path).collect( + streaming=True + ) + total_patients = len(patient_index) + + logger.info( + f"Processing patients in batches of {batch_size} " + f"for better I/O efficiency" + ) + + # Use streaming batch iteration for better performance + for patient_batch in tqdm( + iter_patients_streaming(dataset, None, batch_size), + total=(total_patients + batch_size - 1) // batch_size, + desc=f"Generating samples for {task.task_name}", + unit="batch", + ): + # Process all patients in the batch + for patient in patient_batch: + patient_samples = task(patient) + write_batch_samples.extend(patient_samples) + + # Write to disk when batch gets large enough + if len(write_batch_samples) >= write_batch_size: + sample_dataset.add_samples_streaming(write_batch_samples) + write_batch_samples = [] + + # Explicitly clear patient_batch to free memory immediately + # This helps garbage collector reclaim Patient DataFrames + del patient_batch + + # Write remaining samples + if write_batch_samples: + sample_dataset.add_samples_streaming(write_batch_samples) + + # Finalize sample cache + sample_dataset.finalize_samples() + + # Build processors (must happen after all samples written) + sample_dataset.build_streaming() + + # Save processors for future use + logger.info(f"Saving processors to {processor_cache_dir}") + save_processors(sample_dataset, str(processor_cache_dir)) + + logger.info(f"Dataset ready with {len(sample_dataset)} samples") + + return sample_dataset diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index 54c40420c..326c2a315 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -1,5 +1,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union, Type import inspect +import logging from torch.utils.data import Dataset from tqdm import tqdm @@ -7,6 +8,8 @@ from ..processors import get_processor from ..processors.base_processor import FeatureProcessor +logger = logging.getLogger(__name__) + class SampleDataset(Dataset): """Sample dataset class for handling and processing data samples. 
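set_task_streaming above only rebuilds what is missing on a rerun. The three artifacts it checks can be sketched as follows; the file names are illustrative, following the naming used elsewhere in this patch.

from pathlib import Path

cache = Path("cache")  # placeholder cache directory
patient_cache = cache / "mimic4_patients.parquet"
sample_cache = cache / "mimic4_mortality_samples.parquet"
proc_dir = cache / "processors"

processors_cached = (proc_dir / "input_processors.pkl").exists() and (
    proc_dir / "output_processors.pkl"
).exists()

# Only the missing pieces trigger work: patient cache build, sample
# generation, or processor fitting, in that order.
needs_patient_cache = not patient_cache.exists()
needs_samples = not sample_cache.exists()
needs_processor_fit = not processors_cached
print(needs_patient_cache, needs_samples, needs_processor_fit)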
diff --git a/pyhealth/datasets/splitter.py b/pyhealth/datasets/splitter.py index c70df5660..1207b3a35 100644 --- a/pyhealth/datasets/splitter.py +++ b/pyhealth/datasets/splitter.py @@ -10,6 +10,117 @@ # TODO: add more splitting methods +def split_by_patient_stream( + patient_ids: List[str], + ratios: List[float], + seed: int = 42, +) -> Tuple[List[str], ...]: + """Split patient IDs into train/val/test or other proportions. + + This function provides deterministic patient-level splitting by operating + on patient ID lists rather than SampleDataset objects. This is ideal for: + - Streaming mode datasets where you filter patients before task application + - Pre-computing splits to save for reproducibility + - Creating custom patient-level cross-validation folds + + Unlike `split_by_patient` which operates on `SampleDataset` objects and + returns `Subset` objects, this function operates on raw patient ID lists. + + Args: + patient_ids: List of all patient IDs to split + ratios: List of floats that sum to 1.0 specifying split proportions. + Common patterns: + - [0.8, 0.2] for train/test + - [0.8, 0.1, 0.1] for train/val/test + - [0.7, 0.15, 0.15] for larger validation/test sets + seed: Random seed for reproducibility (default: 42) + + Returns: + Tuple of patient ID lists, one per ratio specified. + Length matches len(ratios). + + Raises: + AssertionError: If ratios don't sum to 1.0 (within 1e-6 tolerance) + + Examples: + >>> # Standard train/val/test split + >>> from pyhealth.datasets import split_by_patient_stream + >>> patient_ids = ["patient-1", "patient-2", ..., "patient-100"] + >>> train, val, test = split_by_patient_stream( + ... patient_ids, [0.8, 0.1, 0.1] + ... ) + >>> len(train), len(val), len(test) + (80, 10, 10) + + >>> # Use with streaming datasets + >>> base_dataset = MIMIC4Dataset(..., stream=True) + >>> all_ids = base_dataset.patient_ids + >>> train_ids, val_ids, test_ids = split_by_patient_stream( + ... all_ids, [0.8, 0.1, 0.1] + ... ) + >>> # Then filter when creating sample datasets + >>> train_samples = base_dataset.set_task( + ... task, patient_ids=train_ids # Filter to train patients + ... 
) + + Note: + Patient-level splitting is essential in medical ML to prevent: + - Data leakage from multiple visits of same patient + - Optimistically biased performance estimates + - Models that memorize patient-specific patterns + + See Also: + - `split_by_patient`: Splits SampleDataset objects into Subset objects + - `split_by_visit`: Splits by samples/visits + """ + import random + + # Validation + assert isinstance(patient_ids, list), "patient_ids must be a list" + assert isinstance(ratios, list), "ratios must be a list" + assert len(ratios) >= 2, "Must provide at least 2 ratios for splitting" + assert abs(sum(ratios) - 1.0) < 1e-6, f"Ratios must sum to 1.0, got {sum(ratios)}" + assert all(r > 0 for r in ratios), "All ratios must be positive" + + # Shuffle patient IDs deterministically + random.seed(seed) + shuffled_ids = patient_ids.copy() + random.shuffle(shuffled_ids) + + # Calculate split indices + n_total = len(shuffled_ids) + splits = [] + start_idx = 0 + + for i, ratio in enumerate(ratios[:-1]): + # Calculate size for this split + split_size = int(n_total * ratio) + end_idx = start_idx + split_size + + splits.append(shuffled_ids[start_idx:end_idx]) + start_idx = end_idx + + # Last split gets all remaining patients (handles rounding) + splits.append(shuffled_ids[start_idx:]) + + # Print summary + split_names = ( + ["train", "val", "test"] + if len(splits) == 3 + else ( + ["train", "test"] + if len(splits) == 2 + else [f"split_{i+1}" for i in range(len(splits))] + ) + ) + print(f"Split {n_total} patients by patient ID (seed={seed}):") + for name, split in zip(split_names, splits): + pct = len(split) / n_total * 100 + print(f" {name.capitalize():8s}: " f"{len(split):6d} patients ({pct:5.1f}%)") + + return tuple(splits) + + def split_by_visit( dataset: SampleDataset, ratios: Union[Tuple[float, float, float], List[float]], diff --git a/pyhealth/datasets/utils.py b/pyhealth/datasets/utils.py index 63ca4152a..32e94fd46 100644 --- a/pyhealth/datasets/utils.py +++ b/pyhealth/datasets/utils.py @@ -231,6 +231,40 @@ def _restore_from_cache(sample: Dict[str, Any]) -> Dict[str, Any]: return restored +def deserialize_sample_from_parquet(sample: Dict[str, Any]) -> Dict[str, Any]: + """Deserialize a sample loaded from Parquet cache. + + This function handles both JSON deserialization (for fields that were + JSON-encoded to work around Parquet type limitations) and restoration + of temporal tuples from their cache representation. + + Args: + sample: Dictionary representing a sample loaded from Parquet + + Returns: + Dict[str, Any]: Deserialized sample with temporal tuples restored + + Example: + >>> batch = pl.scan_parquet(cache_path).collect().to_dicts() + >>> restored = [deserialize_sample_from_parquet(s) for s in batch] + """ + import json + + # Step 1: Deserialize JSON-encoded fields back to dicts + for key, value in list(sample.items()): + if isinstance(value, str): + try: + decoded = json.loads(value) + if isinstance(decoded, dict) and "__stagenet_cache__" in decoded: + sample[key] = decoded + except (json.JSONDecodeError, TypeError): + # Not JSON or not a cache dict, keep as is + pass + + # Step 2: Restore temporal tuples from cache format + return _restore_from_cache(sample) + + def collate_fn_dict(batch: List[dict]) -> dict: """Collates a batch of data into a dictionary of lists. 
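One detail of split_by_patient_stream worth spelling out: only the last split absorbs rounding remainders, so the split sizes always sum to the total. A small worked example with synthetic IDs:

# 103 patients split with ratios [0.8, 0.1, 0.1]
ids = [f"p{i}" for i in range(103)]
sizes = [int(len(ids) * r) for r in [0.8, 0.1]]  # [82, 10]
last = len(ids) - sum(sizes)                      # 11 patients remain
print(sizes + [last])                             # [82, 10, 11]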
diff --git a/pyhealth/processors/base_processor.py b/pyhealth/processors/base_processor.py index 050cb5357..2d4df0d9c 100644 --- a/pyhealth/processors/base_processor.py +++ b/pyhealth/processors/base_processor.py @@ -33,11 +33,26 @@ class FeatureProcessor(Processor): Example: Tokenization, image loading, normalization. """ - def fit(self, samples: List[Dict[str, Any]], field: str) -> None: + def fit( + self, samples: List[Dict[str, Any]], field: str, stream: bool = False + ) -> None: """Fit the processor to the samples. Args: - samples: List of sample dictionaries. + samples: List of sample dictionaries (all samples or a batch). + field: Field name to process. + stream: If True, accumulate statistics incrementally across + multiple fit() calls. If False (default), fit on complete + dataset in single call. Default maintains backward compatibility. + """ + pass + + def finalize_fit(self) -> None: + """Finalize fitting after all batches in streaming mode. + + Called after all fit(stream=True) calls to perform validation + or compute final statistics that require complete dataset view. + Optional - only implement if needed. """ pass diff --git a/pyhealth/processors/label_processor.py b/pyhealth/processors/label_processor.py index ad2df1897..6d251b2b6 100644 --- a/pyhealth/processors/label_processor.py +++ b/pyhealth/processors/label_processor.py @@ -18,20 +18,48 @@ class BinaryLabelProcessor(FeatureProcessor): def __init__(self): super().__init__() self.label_vocab: Dict[Any, int] = {} - - def fit(self, samples: List[Dict[str, Any]], field: str) -> None: - all_labels = set([sample[field] for sample in samples]) - if len(all_labels) != 2: - raise ValueError(f"Expected 2 unique labels, got {len(all_labels)}") - if all_labels == {0, 1}: + self._all_labels = set() # For streaming mode + + def fit( + self, samples: List[Dict[str, Any]], field: str, stream: bool = False + ) -> None: + if not stream: + # Non-streaming mode: original behavior (backward compatible) + all_labels = set([sample[field] for sample in samples]) + if len(all_labels) != 2: + raise ValueError(f"Expected 2 unique labels, got {len(all_labels)}") + if all_labels == {0, 1}: + self.label_vocab = {0: 0, 1: 1} + elif all_labels == {False, True}: + self.label_vocab = {False: 0, True: 1} + else: + all_labels = list(all_labels) + all_labels.sort() + self.label_vocab = {label: i for i, label in enumerate(all_labels)} + logger.info(f"Label {field} vocab: {self.label_vocab}") + else: + # Streaming mode: accumulate labels across batches + for sample in samples: + label = sample[field] + # Convert tensor to Python value if needed + if hasattr(label, "item"): + label = label.item() + self._all_labels.add(label) + + def finalize_fit(self) -> None: + """Finalize vocab after all streaming batches.""" + if len(self._all_labels) != 2: + raise ValueError(f"Expected 2 unique labels, got {len(self._all_labels)}") + if self._all_labels == {0, 1}: self.label_vocab = {0: 0, 1: 1} - elif all_labels == {False, True}: + elif self._all_labels == {False, True}: self.label_vocab = {False: 0, True: 1} else: - all_labels = list(all_labels) + all_labels = list(self._all_labels) all_labels.sort() self.label_vocab = {label: i for i, label in enumerate(all_labels)} - logger.info(f"Label {field} vocab: {self.label_vocab}") + logger.info(f"Label mortality vocab: {self.label_vocab}") + self._all_labels = set() # Clear temporary storage def process(self, value: Any) -> torch.Tensor: index = self.label_vocab[value] @@ -53,22 +81,47 @@ class 
MultiClassLabelProcessor(FeatureProcessor): def __init__(self): super().__init__() self.label_vocab: Dict[Any, int] = {} - - def fit(self, samples: List[Dict[str, Any]], field: str) -> None: - all_labels = set([sample[field] for sample in samples]) - num_classes = len(all_labels) - if all_labels == set(range(num_classes)): + self._all_labels = set() # For streaming mode + + def fit( + self, samples: List[Dict[str, Any]], field: str, stream: bool = False + ) -> None: + if not stream: + # Non-streaming mode: original behavior (backward compatible) + all_labels = set([sample[field] for sample in samples]) + num_classes = len(all_labels) + if all_labels == set(range(num_classes)): + self.label_vocab = {i: i for i in range(num_classes)} + else: + all_labels = list(all_labels) + all_labels.sort() + self.label_vocab = {label: i for i, label in enumerate(all_labels)} + logger.info(f"Label {field} vocab: {self.label_vocab}") + else: + # Streaming mode: accumulate labels across batches + for sample in samples: + label = sample[field] + # Convert tensor to Python value if needed + if hasattr(label, "item"): + label = label.item() + self._all_labels.add(label) + + def finalize_fit(self) -> None: + """Finalize vocab after all streaming batches.""" + num_classes = len(self._all_labels) + if self._all_labels == set(range(num_classes)): self.label_vocab = {i: i for i in range(num_classes)} else: - all_labels = list(all_labels) + all_labels = list(self._all_labels) all_labels.sort() self.label_vocab = {label: i for i, label in enumerate(all_labels)} - logger.info(f"Label {field} vocab: {self.label_vocab}") + logger.info(f"Label vocab: {self.label_vocab}") + self._all_labels = set() # Clear temporary storage def process(self, value: Any) -> torch.Tensor: index = self.label_vocab[value] return torch.tensor(index, dtype=torch.long) - + def size(self): return len(self.label_vocab) @@ -88,20 +141,45 @@ class MultiLabelProcessor(FeatureProcessor): def __init__(self): super().__init__() self.label_vocab: Dict[Any, int] = {} - - def fit(self, samples: List[Dict[str, Any]], field: str) -> None: - all_labels = set() - for sample in samples: - for label in sample[field]: - all_labels.add(label) - num_classes = len(all_labels) - if all_labels == set(range(num_classes)): + self._all_labels = set() # For streaming mode + + def fit( + self, samples: List[Dict[str, Any]], field: str, stream: bool = False + ) -> None: + if not stream: + # Non-streaming mode: original behavior (backward compatible) + all_labels = set() + for sample in samples: + for label in sample[field]: + all_labels.add(label) + num_classes = len(all_labels) + if all_labels == set(range(num_classes)): + self.label_vocab = {i: i for i in range(num_classes)} + else: + all_labels = list(all_labels) + all_labels.sort() + self.label_vocab = {label: i for i, label in enumerate(all_labels)} + logger.info(f"Label {field} vocab: {self.label_vocab}") + else: + # Streaming mode: accumulate labels across batches + for sample in samples: + for label in sample[field]: + # Convert tensor to Python value if needed + if hasattr(label, "item"): + label = label.item() + self._all_labels.add(label) + + def finalize_fit(self) -> None: + """Finalize vocab after all streaming batches.""" + num_classes = len(self._all_labels) + if self._all_labels == set(range(num_classes)): self.label_vocab = {i: i for i in range(num_classes)} else: - all_labels = list(all_labels) + all_labels = list(self._all_labels) all_labels.sort() self.label_vocab = {label: i for i, label in 
enumerate(all_labels)} - logger.info(f"Label {field} vocab: {self.label_vocab}") + logger.info(f"Label vocab: {self.label_vocab}") + self._all_labels = set() # Clear temporary storage def process(self, value: Any) -> torch.Tensor: if not isinstance(value, list): @@ -127,7 +205,7 @@ class RegressionLabelProcessor(FeatureProcessor): def process(self, value: Any) -> torch.Tensor: return torch.tensor([float(value)], dtype=torch.float32) - + def size(self): return 1 diff --git a/pyhealth/processors/nested_sequence_processor.py b/pyhealth/processors/nested_sequence_processor.py index bf7ed2055..40e31dea7 100644 --- a/pyhealth/processors/nested_sequence_processor.py +++ b/pyhealth/processors/nested_sequence_processor.py @@ -49,15 +49,27 @@ def __init__(self, padding: int = 0): self.code_vocab: Dict[Any, int] = {"": -1, "": 0} self._next_index = 1 self._max_inner_len = 1 # Maximum length of inner sequences + # For streaming mode + self._stream_max_inner_len = 0 self._padding = padding # Additional padding beyond observed max - def fit(self, samples: List[Dict[str, Any]], field: str) -> None: + def fit( + self, samples: List[Dict[str, Any]], field: str, stream: bool = False + ) -> None: """Build vocabulary and determine maximum inner sequence length. Args: samples: List of sample dictionaries. field: The field name containing nested sequences. + stream: If True, accumulate vocab across batches. """ + if not stream: + self._fit_non_streaming(samples, field) + else: + self._fit_streaming_batch(samples, field) + + def _fit_non_streaming(self, samples: List[Dict[str, Any]], field: str) -> None: + """Original fit logic (backward compatible).""" max_inner_len = 0 for sample in samples: @@ -82,6 +94,29 @@ def fit(self, samples: List[Dict[str, Any]], field: str) -> None: observed_max = max(1, max_inner_len) self._max_inner_len = observed_max + self._padding + def _fit_streaming_batch(self, samples: List[Dict[str, Any]], field: str) -> None: + """Accumulate vocab from streaming batch.""" + for sample in samples: + if field in sample and sample[field] is not None: + nested_seq = sample[field] + + if isinstance(nested_seq, list): + for inner_seq in nested_seq: + if isinstance(inner_seq, list): + self._stream_max_inner_len = max( + self._stream_max_inner_len, len(inner_seq) + ) + + for code in inner_seq: + if code is not None and code not in self.code_vocab: + self.code_vocab[code] = self._next_index + self._next_index += 1 + + def finalize_fit(self) -> None: + """Finalize vocab after all streaming batches.""" + self._max_inner_len = max(1, self._stream_max_inner_len) + self._stream_max_inner_len = 0 # Clear temporary storage + def process(self, value: List[List[Any]]) -> torch.Tensor: """Process nested sequence into padded 2D tensor. @@ -181,15 +216,26 @@ class NestedFloatsProcessor(FeatureProcessor): def __init__(self, forward_fill: bool = True, padding: int = 0): self._max_inner_len = 1 # Maximum length of inner sequences self.forward_fill = forward_fill + self._stream_max_inner_len = 0 self._padding = padding # Additional padding beyond observed max - def fit(self, samples: List[Dict[str, Any]], field: str) -> None: + def fit( + self, samples: List[Dict[str, Any]], field: str, stream: bool = False + ) -> None: """Determine maximum inner sequence length. Args: samples: List of sample dictionaries. field: The field name containing nested sequences. + stream: If True, accumulate max length across batches. 
""" + if not stream: + self._fit_non_streaming(samples, field) + else: + self._fit_streaming_batch(samples, field) + + def _fit_non_streaming(self, samples: List[Dict[str, Any]], field: str) -> None: + """Original fit logic (backward compatible).""" max_inner_len = 0 for sample in samples: @@ -208,6 +254,24 @@ def fit(self, samples: List[Dict[str, Any]], field: str) -> None: observed_max = max(1, max_inner_len) self._max_inner_len = observed_max + self._padding + def _fit_streaming_batch(self, samples: List[Dict[str, Any]], field: str) -> None: + """Accumulate max length from streaming batch.""" + for sample in samples: + if field in sample and sample[field] is not None: + nested_seq = sample[field] + + if isinstance(nested_seq, list): + for inner_seq in nested_seq: + if isinstance(inner_seq, list): + self._stream_max_inner_len = max( + self._stream_max_inner_len, len(inner_seq) + ) + + def finalize_fit(self) -> None: + """Finalize max length after all streaming batches.""" + self._max_inner_len = max(1, self._stream_max_inner_len) + self._stream_max_inner_len = 0 # Clear temporary storage + def process(self, value: List[List[float]]) -> torch.Tensor: """Process nested numerical sequence with optional forward fill. diff --git a/pyhealth/processors/stagenet_processor.py b/pyhealth/processors/stagenet_processor.py index cbbafac94..2deb593af 100644 --- a/pyhealth/processors/stagenet_processor.py +++ b/pyhealth/processors/stagenet_processor.py @@ -59,15 +59,27 @@ def __init__(self, padding: int = 0): self._next_index = 1 self._is_nested = None # Will be determined during fit self._max_nested_len = None # Max inner sequence length for nested codes + # For streaming mode + self._stream_max_inner_len = 0 self._padding = padding # Additional padding beyond observed max - def fit(self, samples: List[Dict], key: str) -> None: + def fit(self, samples: List[Dict], key: str, stream: bool = False) -> None: """Build vocabulary and determine input structure. 
Args: samples: List of sample dictionaries key: The key in samples that contains tuple (time, values) + stream: If True, accumulate vocab across batches """ + if not stream: + # Non-streaming mode: original behavior (backward compatible) + self._fit_non_streaming(samples, key) + else: + # Streaming mode: accumulate vocab across batches + self._fit_streaming_batch(samples, key) + + def _fit_non_streaming(self, samples: List[Dict], key: str) -> None: + """Original fit logic (backward compatible).""" # Examine first non-None sample to determine structure for sample in samples: if key in sample and sample[key] is not None: @@ -116,6 +128,48 @@ def fit(self, samples: List[Dict], key: str) -> None: observed_max = max(1, max_inner_len) self._max_nested_len = observed_max + self._padding + def _fit_streaming_batch(self, samples: List[Dict], key: str) -> None: + """Accumulate vocab from streaming batch.""" + # Determine structure from first batch if not set + if self._is_nested is None: + for sample in samples: + if key in sample and sample[key] is not None: + time_data, value_data = sample[key] + if isinstance(value_data, list) and len(value_data) > 0: + first_elem = value_data[0] + if isinstance(first_elem, str): + self._is_nested = False + elif isinstance(first_elem, list): + if len(first_elem) > 0 and isinstance(first_elem[0], str): + self._is_nested = True + break + + # Accumulate vocab and track max lengths + for sample in samples: + if key in sample and sample[key] is not None: + time_data, value_data = sample[key] + + if self._is_nested: + for inner_list in value_data: + self._stream_max_inner_len = max( + self._stream_max_inner_len, len(inner_list) + ) + for code in inner_list: + if code is not None and code not in self.code_vocab: + self.code_vocab[code] = self._next_index + self._next_index += 1 + else: + for code in value_data: + if code is not None and code not in self.code_vocab: + self.code_vocab[code] = self._next_index + self._next_index += 1 + + def finalize_fit(self) -> None: + """Finalize vocab after all streaming batches.""" + if self._is_nested: + self._max_nested_len = max(1, self._stream_max_inner_len) + self._stream_max_inner_len = 0 # Clear temporary storage + def process( self, value: Tuple[Optional[List], List] ) -> Tuple[Optional[torch.Tensor], torch.Tensor]: @@ -256,13 +310,19 @@ def __init__(self): self._size = None # Feature dimension (set during fit) self._is_nested = None - def fit(self, samples: List[Dict], key: str) -> None: + def fit(self, samples: List[Dict], key: str, stream: bool = False) -> None: """Determine input structure. 
Args: samples: List of sample dictionaries key: The key in samples that contains tuple (time, values) + stream: If True, this is a streaming batch (no-op for this processor) """ + # Structure detection doesn't need streaming mode + # because we only need to examine first sample + if self._is_nested is not None: + return # Already determined in previous batch + # Examine first non-None sample to determine structure for sample in samples: if key in sample and sample[key] is not None: @@ -285,6 +345,10 @@ def fit(self, samples: List[Dict], key: str) -> None: self._size = len(first_elem) break + def finalize_fit(self) -> None: + """No-op for this processor (structure determined in first batch).""" + pass + def process( self, value: Tuple[Optional[List], List] ) -> Tuple[Optional[torch.Tensor], torch.Tensor]: diff --git a/tests/core/test_caching.py b/tests/core/test_caching.py index b9b355bce..0805b3e2e 100644 --- a/tests/core/test_caching.py +++ b/tests/core/test_caching.py @@ -44,8 +44,10 @@ def __init__(self): # Initialize without calling parent __init__ to avoid file dependencies self.dataset_name = "TestDataset" self.dev = False + self.stream = False # Required by BaseDataset.set_task() - # Create realistic test data with patient_id, test_attribute, and test_label + # Create realistic test data with patient_id, test_attribute, + # and test_label self._collected_global_event_df = pl.DataFrame( { "patient_id": ["1", "2", "1", "2"], @@ -104,7 +106,16 @@ def test_set_task_signature(self): sig = inspect.signature(BaseDataset.set_task) params = list(sig.parameters.keys()) - expected_params = ["self", "task", "num_workers", "cache_dir", "cache_format", "input_processors", "output_processors"] + expected_params = [ + "self", + "task", + "num_workers", + "cache_dir", + "cache_format", + "input_processors", + "output_processors", + "batch_size", + ] self.assertEqual(params, expected_params) # Check default values @@ -114,6 +125,8 @@ def test_set_task_signature(self): self.assertEqual(sig.parameters["cache_format"].default, "parquet") self.assertEqual(sig.parameters["input_processors"].default, None) self.assertEqual(sig.parameters["output_processors"].default, None) + self.assertEqual(sig.parameters["batch_size"].default, None) + self.assertEqual(sig.parameters["output_processors"].default, None) def test_set_task_no_caching(self): """Test set_task without caching (cache_dir=None).""" diff --git a/tests/core/test_stream_processors.py b/tests/core/test_stream_processors.py new file mode 100644 index 000000000..f9fe22e69 --- /dev/null +++ b/tests/core/test_stream_processors.py @@ -0,0 +1,783 @@ +import unittest +import torch + +from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.processors import ( + BinaryLabelProcessor, + MultiClassLabelProcessor, + MultiLabelProcessor, + StageNetProcessor, + StageNetTensorProcessor, + NestedSequenceProcessor, + NestedFloatsProcessor, +) + + +class TestStreamProcessors(unittest.TestCase): + """Test cases for streaming fit functionality across all processors.""" + + def test_binary_label_processor_streaming(self): + """Test BinaryLabelProcessor with streaming mode.""" + # Create batches of samples + batch1 = [ + {"label": 0}, + {"label": 1}, + {"label": 0}, + ] + batch2 = [ + {"label": 1}, + {"label": 0}, + ] + batch3 = [ + {"label": 1}, + ] + + # Non-streaming mode (baseline) + processor_baseline = BinaryLabelProcessor() + all_samples = batch1 + batch2 + batch3 + processor_baseline.fit(all_samples, "label", stream=False) + + # Streaming mode + 
processor_streaming = BinaryLabelProcessor() + processor_streaming.fit(batch1, "label", stream=True) + processor_streaming.fit(batch2, "label", stream=True) + processor_streaming.fit(batch3, "label", stream=True) + processor_streaming.finalize_fit() + + # Verify vocabs are identical + self.assertEqual( + processor_baseline.label_vocab, + processor_streaming.label_vocab, + ) + + # Test processing works correctly + self.assertTrue( + torch.equal( + processor_baseline.process(0), + processor_streaming.process(0), + ) + ) + self.assertTrue( + torch.equal( + processor_baseline.process(1), + processor_streaming.process(1), + ) + ) + + def test_multiclass_label_processor_streaming(self): + """Test MultiClassLabelProcessor with streaming mode.""" + batch1 = [ + {"label": "class_a"}, + {"label": "class_b"}, + ] + batch2 = [ + {"label": "class_c"}, + {"label": "class_a"}, + ] + batch3 = [ + {"label": "class_d"}, + {"label": "class_b"}, + ] + + # Non-streaming mode + processor_baseline = MultiClassLabelProcessor() + all_samples = batch1 + batch2 + batch3 + processor_baseline.fit(all_samples, "label", stream=False) + + # Streaming mode + processor_streaming = MultiClassLabelProcessor() + processor_streaming.fit(batch1, "label", stream=True) + processor_streaming.fit(batch2, "label", stream=True) + processor_streaming.fit(batch3, "label", stream=True) + processor_streaming.finalize_fit() + + # Verify vocabs are identical + self.assertEqual( + processor_baseline.label_vocab, + processor_streaming.label_vocab, + ) + self.assertEqual(processor_baseline.size(), 4) + self.assertEqual(processor_streaming.size(), 4) + + # Test processing + for label in ["class_a", "class_b", "class_c", "class_d"]: + self.assertTrue( + torch.equal( + processor_baseline.process(label), + processor_streaming.process(label), + ) + ) + + def test_multilabel_processor_streaming(self): + """Test MultiLabelProcessor with streaming mode.""" + batch1 = [ + {"tags": ["tag1", "tag2"]}, + {"tags": ["tag2", "tag3"]}, + ] + batch2 = [ + {"tags": ["tag1", "tag4"]}, + {"tags": ["tag3", "tag5"]}, + ] + batch3 = [ + {"tags": ["tag2", "tag5", "tag6"]}, + ] + + # Non-streaming mode + processor_baseline = MultiLabelProcessor() + all_samples = batch1 + batch2 + batch3 + processor_baseline.fit(all_samples, "tags", stream=False) + + # Streaming mode + processor_streaming = MultiLabelProcessor() + processor_streaming.fit(batch1, "tags", stream=True) + processor_streaming.fit(batch2, "tags", stream=True) + processor_streaming.fit(batch3, "tags", stream=True) + processor_streaming.finalize_fit() + + # Verify vocabs are identical + self.assertEqual( + processor_baseline.label_vocab, + processor_streaming.label_vocab, + ) + self.assertEqual(processor_baseline.size(), 6) + self.assertEqual(processor_streaming.size(), 6) + + # Test processing + test_tags = ["tag1", "tag3", "tag5"] + result_baseline = processor_baseline.process(test_tags) + result_streaming = processor_streaming.process(test_tags) + self.assertTrue(torch.equal(result_baseline, result_streaming)) + + def test_stagenet_processor_streaming(self): + """Test StageNetProcessor with streaming mode.""" + # Flat codes + batch1 = [ + {"codes": ([0.0, 1.0], ["code1", "code2"])}, + {"codes": ([0.0, 1.5], ["code2", "code3"])}, + ] + batch2 = [ + {"codes": ([0.0], ["code4"])}, + {"codes": ([0.0, 0.5, 1.0], ["code1", "code5", "code6"])}, + ] + + # Non-streaming mode + processor_baseline = StageNetProcessor() + all_samples = batch1 + batch2 + processor_baseline.fit(all_samples, "codes", stream=False) + + # 
Streaming mode + processor_streaming = StageNetProcessor() + processor_streaming.fit(batch1, "codes", stream=True) + processor_streaming.fit(batch2, "codes", stream=True) + processor_streaming.finalize_fit() + + # Verify vocabs are identical + self.assertEqual( + processor_baseline.code_vocab, + processor_streaming.code_vocab, + ) + self.assertEqual(processor_baseline._is_nested, False) + self.assertEqual(processor_streaming._is_nested, False) + + # Test processing + test_data = ([0.0, 1.0], ["code1", "code2"]) + result_baseline = processor_baseline.process(test_data) + result_streaming = processor_streaming.process(test_data) + + # Results are tuples of (time, values) + time_b, values_b = result_baseline + time_s, values_s = result_streaming + + if time_b is not None and time_s is not None: + self.assertTrue(torch.equal(time_b, time_s)) + self.assertTrue(torch.equal(values_b, values_s)) + + def test_stagenet_processor_nested_streaming(self): + """Test StageNetProcessor with nested codes in streaming mode.""" + # Nested codes + batch1 = [ + {"procs": ([0.0], [["A01", "A02"], ["B01"]])}, + {"procs": ([0.0, 1.0], [["A03"], ["B02", "C01"]])}, + ] + batch2 = [ + {"procs": ([0.0, 1.0], [["A01", "A04"], ["C02"]])}, + {"procs": ([0.0], [["D01", "D02", "D03"]])}, + ] + + # Non-streaming mode + processor_baseline = StageNetProcessor() + all_samples = batch1 + batch2 + processor_baseline.fit(all_samples, "procs", stream=False) + + # Streaming mode + processor_streaming = StageNetProcessor() + processor_streaming.fit(batch1, "procs", stream=True) + processor_streaming.fit(batch2, "procs", stream=True) + processor_streaming.finalize_fit() + + # Verify vocabs and structure + self.assertEqual( + processor_baseline.code_vocab, + processor_streaming.code_vocab, + ) + self.assertEqual(processor_baseline._is_nested, True) + self.assertEqual(processor_streaming._is_nested, True) + self.assertEqual( + processor_baseline._max_nested_len, + processor_streaming._max_nested_len, + ) + # Max nested length should be 3 (from ["D01", "D02", "D03"]) + self.assertEqual(processor_baseline._max_nested_len, 3) + + def test_stagenet_tensor_processor_streaming(self): + """Test StageNetTensorProcessor with streaming mode.""" + batch1 = [ + {"values": (None, [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, + {"values": (None, [[7.0, 8.0, 9.0]])}, + ] + batch2 = [ + {"values": ([0.0, 1.0], [[1.5, 2.5, 3.5], [4.5, 5.5, 6.5]])}, + ] + + # Non-streaming mode + processor_baseline = StageNetTensorProcessor() + all_samples = batch1 + batch2 + processor_baseline.fit(all_samples, "values", stream=False) + + # Streaming mode + processor_streaming = StageNetTensorProcessor() + processor_streaming.fit(batch1, "values", stream=True) + processor_streaming.fit(batch2, "values", stream=True) + processor_streaming.finalize_fit() + + # Verify structure detection + self.assertEqual(processor_baseline._is_nested, True) + self.assertEqual(processor_streaming._is_nested, True) + self.assertEqual(processor_baseline._size, 3) + self.assertEqual(processor_streaming._size, 3) + + def test_nested_sequence_processor_streaming(self): + """Test NestedSequenceProcessor with streaming mode.""" + batch1 = [ + {"codes": [["A", "B"], ["C", "D", "E"]]}, + {"codes": [["F"]]}, + ] + batch2 = [ + {"codes": [["A", "G"], ["H", "I"]]}, + {"codes": [["B", "C"], ["J", "K", "L", "M"]]}, + ] + + # Non-streaming mode + processor_baseline = NestedSequenceProcessor() + all_samples = batch1 + batch2 + processor_baseline.fit(all_samples, "codes", stream=False) + + # Streaming mode + 
processor_streaming = NestedSequenceProcessor() + processor_streaming.fit(batch1, "codes", stream=True) + processor_streaming.fit(batch2, "codes", stream=True) + processor_streaming.finalize_fit() + + # Verify vocabs are identical + self.assertEqual( + processor_baseline.code_vocab, + processor_streaming.code_vocab, + ) + # Max inner length should be 4 (from ["J", "K", "L", "M"]) + self.assertEqual(processor_baseline._max_inner_len, 4) + self.assertEqual(processor_streaming._max_inner_len, 4) + + # Test processing + test_data = [["A", "B"], ["C"]] + result_baseline = processor_baseline.process(test_data) + result_streaming = processor_streaming.process(test_data) + self.assertTrue(torch.equal(result_baseline, result_streaming)) + + def test_nested_floats_processor_streaming(self): + """Test NestedFloatsProcessor with streaming mode.""" + batch1 = [ + {"values": [[1.0, 2.0], [3.0, 4.0, 5.0]]}, + {"values": [[6.0]]}, + ] + batch2 = [ + {"values": [[7.0, 8.0], [9.0, 10.0]]}, + {"values": [[11.0, 12.0, 13.0, 14.0]]}, + ] + + # Non-streaming mode + processor_baseline = NestedFloatsProcessor() + all_samples = batch1 + batch2 + processor_baseline.fit(all_samples, "values", stream=False) + + # Streaming mode + processor_streaming = NestedFloatsProcessor() + processor_streaming.fit(batch1, "values", stream=True) + processor_streaming.fit(batch2, "values", stream=True) + processor_streaming.finalize_fit() + + # Max inner length should be 4 (from [11.0, 12.0, 13.0, 14.0]) + self.assertEqual(processor_baseline._max_inner_len, 4) + self.assertEqual(processor_streaming._max_inner_len, 4) + + # Test processing + test_data = [[1.0, 2.0], [3.0]] + result_baseline = processor_baseline.process(test_data) + result_streaming = processor_streaming.process(test_data) + self.assertTrue(torch.equal(result_baseline, result_streaming)) + + def test_streaming_with_empty_batches(self): + """Test streaming mode handles empty batches gracefully.""" + batch1 = [{"label": 0}, {"label": 1}] + batch2 = [] # Empty batch + batch3 = [{"label": 0}] + + processor = BinaryLabelProcessor() + processor.fit(batch1, "label", stream=True) + processor.fit(batch2, "label", stream=True) # Should not crash + processor.fit(batch3, "label", stream=True) + processor.finalize_fit() + + self.assertEqual(len(processor.label_vocab), 2) + + def test_streaming_integration_with_dataset_nonstreaming(self): + """Test dataset creation with non-streaming processor fit.""" + # Create a small dataset (< 100k samples, non-streaming mode) + samples = [] + for i in range(50): + samples.append( + { + "patient_id": f"patient-{i}", + "visit_id": f"visit-{i}", + "conditions": ["cond-33", "cond-86", "cond-80"], + "procedures": [1.0, 2.0, 3.5, 4.0], + "label": i % 2, + } + ) + + input_schema = { + "conditions": "sequence", + "procedures": "tensor", + } + output_schema = {"label": "binary"} + + # Create dataset - should use non-streaming fit (< 100k samples) + dataset = SampleDataset( + samples=samples, + input_schema=input_schema, + output_schema=output_schema, + dataset_name="test_nonstreaming_integration", + ) + + # Verify dataset created successfully + self.assertEqual(len(dataset), 50) + self.assertIn("conditions", dataset[0]) + self.assertIn("procedures", dataset[0]) + self.assertIn("label", dataset[0]) + + # Test dataloader works + train_loader = get_dataloader(dataset, batch_size=10, shuffle=False) + data_batch = next(iter(train_loader)) + + self.assertEqual(data_batch["label"].shape[0], 10) + self.assertIsInstance(data_batch["conditions"], torch.Tensor) + 
self.assertIsInstance(data_batch["procedures"], torch.Tensor) + + def test_manual_streaming_vs_nonstreaming_dataset(self): + """Test that streaming and non-streaming produce identical results.""" + # Create mock samples similar to test_mlp.py + samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "conditions": ["cond-33", "cond-86", "cond-80", "cond-12"], + "label": 0, + }, + { + "patient_id": "patient-1", + "visit_id": "visit-1", + "conditions": ["cond-33", "cond-86"], + "label": 1, + }, + { + "patient_id": "patient-2", + "visit_id": "visit-2", + "conditions": ["cond-80", "cond-12", "cond-99"], + "label": 0, + }, + { + "patient_id": "patient-3", + "visit_id": "visit-3", + "conditions": ["cond-86", "cond-99", "cond-33"], + "label": 1, + }, + ] + + input_schema = {"conditions": "sequence"} + output_schema = {"label": "binary"} + + # Create two separate datasets to ensure clean state + # Dataset 1: Normal non-streaming mode + dataset_normal = SampleDataset( + samples=samples.copy(), + input_schema=input_schema, + output_schema=output_schema, + dataset_name="test_normal_vs_stream", + ) + + # Get the fitted label processor from normal dataset + label_proc_normal = dataset_normal.output_processors["label"] + + # Manually create and fit a streaming version with fresh processor + from pyhealth.processors import BinaryLabelProcessor + + label_proc_streaming = BinaryLabelProcessor() + + # Debug: Print object IDs to verify they're different instances + print(f"\nNormal processor id: {id(label_proc_normal)}") + print(f"Streaming processor id: {id(label_proc_streaming)}") + print(f"Initial _all_labels: {label_proc_streaming._all_labels}") + + # Debug: Check initial state + self.assertEqual(len(label_proc_streaming._all_labels), 0) + + # Simulate streaming fit in batches + batch1 = samples[:2] + batch2 = samples[2:] + + label_proc_streaming.fit(batch1, "label", stream=True) + print(f"After batch1, _all_labels: {label_proc_streaming._all_labels}") + # After first batch, should have 2 labels (0 and 1) + self.assertEqual(len(label_proc_streaming._all_labels), 2) + + label_proc_streaming.fit(batch2, "label", stream=True) + print(f"After batch2, _all_labels: {label_proc_streaming._all_labels}") + # After second batch, still should have 2 labels (0 and 1) + self.assertEqual(len(label_proc_streaming._all_labels), 2) + + label_proc_streaming.finalize_fit() + + # Verify vocabs match + self.assertEqual( + label_proc_normal.label_vocab, + label_proc_streaming.label_vocab, + ) + + # Test that both process values the same way + for label in [0, 1]: + result_normal = label_proc_normal.process(label) + result_streaming = label_proc_streaming.process(label) + self.assertTrue(torch.equal(result_normal, result_streaming)) + + def test_stagenet_integration_with_dataset(self): + """Test StageNet processors with SampleDataset.""" + samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "codes": ([0.0, 2.0], ["code1", "code2"]), + "label": 0, + }, + { + "patient_id": "patient-1", + "visit_id": "visit-1", + "codes": ([0.0, 1.0, 2.0], ["code2", "code3", "code4"]), + "label": 1, + }, + { + "patient_id": "patient-2", + "visit_id": "visit-2", + "codes": ([0.0], ["code1"]), + "label": 0, + }, + ] + + input_schema = {"codes": "stagenet"} + output_schema = {"label": "binary"} + + # Create dataset + dataset = SampleDataset( + samples=samples, + input_schema=input_schema, + output_schema=output_schema, + dataset_name="test_stagenet_integration", + ) + + # Verify dataset created successfully + 
self.assertEqual(len(dataset), 3) + + # Check sample structure + sample = dataset[0] + self.assertIn("codes", sample) + self.assertIn("label", sample) + + # codes should be tuple (time, values) + self.assertIsInstance(sample["codes"], tuple) + time, values = sample["codes"] + self.assertIsInstance(values, torch.Tensor) + + # Test dataloader works + train_loader = get_dataloader(dataset, batch_size=2, shuffle=False) + data_batch = next(iter(train_loader)) + + self.assertEqual(data_batch["label"].shape[0], 2) + self.assertIsInstance(data_batch["codes"], tuple) + + def test_binary_label_validation_streaming(self): + """Test that binary label validation works in streaming mode.""" + # Create batches that will result in 3 unique labels (invalid) + batch1 = [{"label": 0}, {"label": 1}] + batch2 = [{"label": 2}] # Third label - should fail + + processor = BinaryLabelProcessor() + processor.fit(batch1, "label", stream=True) + processor.fit(batch2, "label", stream=True) + + # Should raise ValueError during finalize_fit + with self.assertRaises(ValueError) as context: + processor.finalize_fit() + + self.assertIn("Expected 2 unique labels, got 3", str(context.exception)) + + def test_streaming_preserves_backward_compatibility(self): + """Test that default stream=False maintains backward compatibility.""" + samples = [ + {"label": 0}, + {"label": 1}, + {"label": 0}, + {"label": 1}, + ] + + # Old API (implicit stream=False) + processor_old = BinaryLabelProcessor() + processor_old.fit(samples, "label") + + # New API (explicit stream=False) + processor_new = BinaryLabelProcessor() + processor_new.fit(samples, "label", stream=False) + + # Should be identical + self.assertEqual(processor_old.label_vocab, processor_new.label_vocab) + + # Test processing works the same + self.assertTrue( + torch.equal( + processor_old.process(0), + processor_new.process(0), + ) + ) + + def test_processors_output_tensor_types_streaming(self): + """Test that streaming processors produce correct tensor types and shapes.""" + + # Test BinaryLabelProcessor + binary_samples = [{"label": 0}, {"label": 1}] + binary_proc = BinaryLabelProcessor() + binary_proc.fit(binary_samples, "label", stream=True) + binary_proc.finalize_fit() + + result = binary_proc.process(1) + self.assertIsInstance(result, torch.Tensor, "BinaryLabel should return tensor") + self.assertEqual(result.dtype, torch.float32, "BinaryLabel should be float32") + self.assertEqual(result.shape, torch.Size([1]), "BinaryLabel should be [1]") + self.assertEqual(result.item(), 1.0, "BinaryLabel value should match") + + # Test MultiClassLabelProcessor + multiclass_samples = [ + {"label": "class_a"}, + {"label": "class_b"}, + {"label": "class_c"}, + ] + multiclass_proc = MultiClassLabelProcessor() + multiclass_proc.fit(multiclass_samples, "label", stream=True) + multiclass_proc.finalize_fit() + + result = multiclass_proc.process("class_b") + self.assertIsInstance(result, torch.Tensor, "MultiClass should return tensor") + self.assertEqual(result.dtype, torch.long, "MultiClass should be long tensor") + self.assertEqual(result.shape, torch.Size([]), "MultiClass should be scalar") + + # Test MultiLabelProcessor + multilabel_samples = [ + {"tags": ["tag1", "tag2"]}, + {"tags": ["tag2", "tag3", "tag4"]}, + ] + multilabel_proc = MultiLabelProcessor() + multilabel_proc.fit(multilabel_samples, "tags", stream=True) + multilabel_proc.finalize_fit() + + result = multilabel_proc.process(["tag1", "tag3"]) + self.assertIsInstance(result, torch.Tensor, "MultiLabel should return tensor") + 
self.assertEqual(result.dtype, torch.float, "MultiLabel should be float tensor") + self.assertEqual(result.shape[0], 4, "MultiLabel should have size of vocab") + # Should be one-hot encoded + self.assertEqual( + result.sum().item(), 2.0, "MultiLabel should have 2 active labels" + ) + + # Test StageNetProcessor (flat) + stagenet_samples = [ + {"codes": ([0.0, 1.0], ["code1", "code2"])}, + {"codes": ([0.0, 2.0], ["code3", "code4"])}, + ] + stagenet_proc = StageNetProcessor() + stagenet_proc.fit(stagenet_samples, "codes", stream=True) + stagenet_proc.finalize_fit() + + time, values = stagenet_proc.process(([0.0, 1.0], ["code1", "code2"])) + self.assertIsInstance(time, torch.Tensor, "StageNet time should be tensor") + self.assertIsInstance(values, torch.Tensor, "StageNet values should be tensor") + self.assertEqual(time.dtype, torch.float, "StageNet time should be float") + self.assertEqual(values.dtype, torch.long, "StageNet values should be long") + self.assertEqual( + time.shape, torch.Size([2]), "StageNet time should match input" + ) + self.assertEqual( + values.shape, torch.Size([2]), "StageNet values should match input" + ) + + # Test StageNetTensorProcessor + stagenet_tensor_samples = [ + {"values": (None, [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, + {"values": (None, [[7.0, 8.0, 9.0]])}, + ] + stagenet_tensor_proc = StageNetTensorProcessor() + stagenet_tensor_proc.fit(stagenet_tensor_samples, "values", stream=True) + stagenet_tensor_proc.finalize_fit() + + time, values = stagenet_tensor_proc.process((None, [[1.0, 2.0, 3.0]])) + self.assertIsNone(time, "StageNetTensor time can be None") + self.assertIsInstance( + values, torch.Tensor, "StageNetTensor values should be tensor" + ) + self.assertEqual(values.dtype, torch.float, "StageNetTensor should be float") + self.assertEqual( + values.shape, + torch.Size([1, 3]), + "StageNetTensor shape should be [visits, features]", + ) + + # Test NestedSequenceProcessor + nested_seq_samples = [ + {"codes": [["A", "B"], ["C", "D", "E"]]}, + {"codes": [["F", "G"]]}, + ] + nested_seq_proc = NestedSequenceProcessor() + nested_seq_proc.fit(nested_seq_samples, "codes", stream=True) + nested_seq_proc.finalize_fit() + + result = nested_seq_proc.process([["A", "C"], ["B"]]) + self.assertIsInstance( + result, torch.Tensor, "NestedSequence should return tensor" + ) + self.assertEqual( + result.dtype, torch.long, "NestedSequence should be long tensor" + ) + self.assertEqual( + len(result.shape), 2, "NestedSequence should be 2D [visits, codes]" + ) + + # Test NestedFloatsProcessor + nested_floats_samples = [ + {"values": [[1.0, 2.0], [3.0, 4.0, 5.0]]}, + {"values": [[6.0]]}, + ] + nested_floats_proc = NestedFloatsProcessor() + nested_floats_proc.fit(nested_floats_samples, "values", stream=True) + nested_floats_proc.finalize_fit() + + result = nested_floats_proc.process([[1.0, 2.0], [3.0]]) + self.assertIsInstance(result, torch.Tensor, "NestedFloats should return tensor") + self.assertEqual( + result.dtype, torch.float, "NestedFloats should be float tensor" + ) + self.assertEqual( + len(result.shape), 2, "NestedFloats should be 2D [visits, values]" + ) + + def test_processors_output_tensor_types_nonstreaming(self): + """Test that non-streaming processors produce correct tensor types (baseline).""" + + # Test BinaryLabelProcessor + binary_samples = [{"label": 0}, {"label": 1}] + binary_proc = BinaryLabelProcessor() + binary_proc.fit(binary_samples, "label", stream=False) + + result = binary_proc.process(1) + self.assertIsInstance(result, torch.Tensor, "BinaryLabel 
should return tensor") + self.assertEqual(result.dtype, torch.float32, "BinaryLabel should be float32") + + # Test MultiClassLabelProcessor + multiclass_samples = [{"label": "class_a"}, {"label": "class_b"}] + multiclass_proc = MultiClassLabelProcessor() + multiclass_proc.fit(multiclass_samples, "label", stream=False) + + result = multiclass_proc.process("class_b") + self.assertIsInstance(result, torch.Tensor, "MultiClass should return tensor") + self.assertEqual(result.dtype, torch.long, "MultiClass should be long tensor") + + # Test MultiLabelProcessor + multilabel_samples = [{"tags": ["tag1", "tag2"]}, {"tags": ["tag3"]}] + multilabel_proc = MultiLabelProcessor() + multilabel_proc.fit(multilabel_samples, "tags", stream=False) + + result = multilabel_proc.process(["tag1", "tag3"]) + self.assertIsInstance(result, torch.Tensor, "MultiLabel should return tensor") + self.assertEqual(result.dtype, torch.float, "MultiLabel should be float tensor") + + def test_streaming_vs_nonstreaming_tensor_equality(self): + """Test that streaming and non-streaming produce identical tensors.""" + + # BinaryLabelProcessor + samples_binary = [{"label": 0}, {"label": 1}, {"label": 0}] + + proc_stream = BinaryLabelProcessor() + proc_stream.fit(samples_binary[:2], "label", stream=True) + proc_stream.fit(samples_binary[2:], "label", stream=True) + proc_stream.finalize_fit() + + proc_normal = BinaryLabelProcessor() + proc_normal.fit(samples_binary, "label", stream=False) + + for label in [0, 1]: + result_stream = proc_stream.process(label) + result_normal = proc_normal.process(label) + self.assertTrue( + torch.equal(result_stream, result_normal), + f"Tensors should be identical for label {label}", + ) + self.assertEqual( + result_stream.dtype, result_normal.dtype, "Tensor dtypes should match" + ) + self.assertEqual( + result_stream.shape, result_normal.shape, "Tensor shapes should match" + ) + + # MultiLabelProcessor + samples_multilabel = [ + {"tags": ["tag1", "tag2"]}, + {"tags": ["tag2", "tag3"]}, + {"tags": ["tag4"]}, + ] + + proc_multi_stream = MultiLabelProcessor() + proc_multi_stream.fit(samples_multilabel[:2], "tags", stream=True) + proc_multi_stream.fit(samples_multilabel[2:], "tags", stream=True) + proc_multi_stream.finalize_fit() + + proc_multi_normal = MultiLabelProcessor() + proc_multi_normal.fit(samples_multilabel, "tags", stream=False) + + test_tags = ["tag1", "tag3"] + result_stream = proc_multi_stream.process(test_tags) + result_normal = proc_multi_normal.process(test_tags) + + self.assertTrue( + torch.equal(result_stream, result_normal), + "MultiLabel tensors should be identical", + ) + self.assertEqual(result_stream.dtype, result_normal.dtype) + self.assertEqual(result_stream.shape, result_normal.shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/test_streaming_cache.py b/tests/core/test_streaming_cache.py new file mode 100644 index 000000000..329d41703 --- /dev/null +++ b/tests/core/test_streaming_cache.py @@ -0,0 +1,378 @@ +"""Tests for streaming mode caching functionality. + +This module tests that caching works correctly at all levels: +1. Patient cache +2. Sample cache +3. Processor cache +4. 
Physical split cache +""" + +import unittest +import tempfile +import shutil +from pathlib import Path +from unittest.mock import Mock, patch, MagicMock +import polars as pl + +from pyhealth.datasets.iterable_sample_dataset import IterableSampleDataset + + +class TestStreamingCache(unittest.TestCase): + """Test streaming mode caching at all levels.""" + + def setUp(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + self.cache_dir = Path(self.temp_dir) + + def tearDown(self): + """Clean up temporary files.""" + if Path(self.temp_dir).exists(): + shutil.rmtree(self.temp_dir) + + def test_sample_cache_check_exists(self): + """Test that sample cache existence is checked before generating.""" + # Create a mock sample cache file + dataset = IterableSampleDataset( + input_schema={"feature": "sequence"}, + output_schema={"label": "label"}, + dataset_name="TestDataset", + task_name="TestTask", + cache_dir=str(self.cache_dir), + ) + + # Create fake cache file + sample_cache_path = dataset._sample_cache_path + sample_cache_path.parent.mkdir(parents=True, exist_ok=True) + + # Write some sample data + sample_df = pl.DataFrame( + { + "patient_id": ["P1", "P2"], + "feature": [[1, 2, 3], [4, 5, 6]], + "label": [0, 1], + } + ) + sample_df.write_parquet(sample_cache_path) + + # Check that cache exists + self.assertTrue(sample_cache_path.exists()) + print(f"✓ Sample cache exists check passed: {sample_cache_path}") + + def test_processor_cache_paths(self): + """Test that processor cache paths are constructed correctly.""" + processor_cache_dir = self.cache_dir / "processors" + + input_processors_path = processor_cache_dir / "input_processors.pkl" + output_processors_path = processor_cache_dir / "output_processors.pkl" + + # Check path construction + self.assertEqual(input_processors_path.name, "input_processors.pkl") + self.assertEqual(output_processors_path.name, "output_processors.pkl") + print("✓ Processor cache path construction correct") + + def test_sample_cache_loading(self): + """Test that sample cache can be loaded correctly.""" + dataset = IterableSampleDataset( + input_schema={"feature": "sequence"}, + output_schema={"label": "label"}, + dataset_name="TestDataset", + task_name="TestTask", + cache_dir=str(self.cache_dir), + ) + + # Create and write sample cache + sample_df = pl.DataFrame( + { + "patient_id": ["P1", "P2", "P3"], + "feature": [[1, 2], [3, 4], [5, 6]], + "label": [0, 1, 0], + } + ) + sample_df.write_parquet(dataset._sample_cache_path) + + # Simulate loading cache + dataset._samples_finalized = True + sample_count_df = ( + pl.scan_parquet(dataset._sample_cache_path) + .select(pl.count().alias("count")) + .collect(streaming=True) + ) + dataset._num_samples = sample_count_df["count"][0] + + self.assertEqual(dataset._num_samples, 3) + self.assertTrue(dataset._samples_finalized) + print("✓ Sample cache loading works correctly") + + def test_physical_split_cache_check(self): + """Test that physical split cache check works.""" + dataset = IterableSampleDataset( + input_schema={"feature": "sequence"}, + output_schema={"label": "label"}, + dataset_name="TestDataset", + task_name="TestTask", + cache_dir=str(self.cache_dir), + ) + + # Create main cache + sample_df = pl.DataFrame( + { + "patient_id": ["P1", "P2", "P3"], + "feature": [[1, 2], [3, 4], [5, 6]], + "label": [0, 1, 0], + } + ) + sample_df.write_parquet(dataset._sample_cache_path) + dataset._samples_finalized = True + dataset._num_samples = 3 + + # Create a split cache file + split_cache_path = self.cache_dir / 
"TestDataset_TestTask_samples_train.parquet" + train_df = sample_df.filter(pl.col("patient_id").is_in(["P1", "P2"])) + train_df.write_parquet(split_cache_path) + + # Check that split cache exists + self.assertTrue(split_cache_path.exists()) + + # Verify split has correct data + loaded_split = pl.read_parquet(split_cache_path) + self.assertEqual(len(loaded_split), 2) + print("✓ Physical split cache check works") + + def test_cache_hierarchy(self): + """Test the complete cache hierarchy structure.""" + # Expected cache structure: + # cache_dir/ + # ├── {dataset}_{task}_samples.parquet (main cache) + # ├── {dataset}_{task}_samples_train.parquet (split) + # ├── {dataset}_{task}_samples_val.parquet (split) + # └── processors/ + # ├── input_processors.pkl + # └── output_processors.pkl + + dataset = IterableSampleDataset( + input_schema={"feature": "sequence"}, + output_schema={"label": "label"}, + dataset_name="MyDataset", + task_name="MyTask", + cache_dir=str(self.cache_dir), + ) + + # Create main cache + main_cache = dataset._sample_cache_path + self.assertEqual(main_cache.name, "MyDataset_MyTask_samples.parquet") + + # Check expected split paths + train_split = self.cache_dir / "MyDataset_MyTask_samples_train.parquet" + val_split = self.cache_dir / "MyDataset_MyTask_samples_val.parquet" + + self.assertEqual(train_split.name, "MyDataset_MyTask_samples_train.parquet") + self.assertEqual(val_split.name, "MyDataset_MyTask_samples_val.parquet") + + # Check processor directory + processor_dir = self.cache_dir / "processors" + self.assertEqual(processor_dir.name, "processors") + + print("✓ Cache hierarchy structure correct") + + def test_dev_mode_cache_separation(self): + """Test that dev mode uses separate cache files.""" + # Dev mode dataset + dev_dataset = IterableSampleDataset( + input_schema={"feature": "sequence"}, + output_schema={"label": "label"}, + dataset_name="TestDataset", + task_name="TestTask", + cache_dir=str(self.cache_dir), + dev=True, + dev_max_patients=100, + ) + + # Full mode dataset + full_dataset = IterableSampleDataset( + input_schema={"feature": "sequence"}, + output_schema={"label": "label"}, + dataset_name="TestDataset", + task_name="TestTask", + cache_dir=str(self.cache_dir), + dev=False, + ) + + # Check that cache paths are different + self.assertNotEqual( + dev_dataset._sample_cache_path, full_dataset._sample_cache_path + ) + + # Check dev cache has suffix + self.assertIn("_dev_100", str(dev_dataset._sample_cache_path)) + self.assertNotIn("_dev", str(full_dataset._sample_cache_path)) + + print("✓ Dev mode cache separation works") + + def test_cache_reuse_scenario(self): + """Test complete cache reuse scenario (simulated).""" + dataset = IterableSampleDataset( + input_schema={"feature": "sequence"}, + output_schema={"label": "label"}, + dataset_name="TestDataset", + task_name="TestTask", + cache_dir=str(self.cache_dir), + ) + + # First run - create cache + sample_df = pl.DataFrame( + {"patient_id": ["P1", "P2"], "feature": [[1, 2], [3, 4]], "label": [0, 1]} + ) + sample_df.write_parquet(dataset._sample_cache_path) + + # Second run - check cache exists + cache_exists = dataset._sample_cache_path.exists() + self.assertTrue(cache_exists) + + if cache_exists: + # Simulate loading from cache + dataset._samples_finalized = True + sample_count_df = ( + pl.scan_parquet(dataset._sample_cache_path) + .select(pl.count().alias("count")) + .collect(streaming=True) + ) + dataset._num_samples = sample_count_df["count"][0] + + self.assertEqual(dataset._num_samples, 2) + print("✓ Cache 
reuse scenario works correctly") + + def test_split_cache_reuse(self): + """Test that split caches are reused when they exist.""" + dataset = IterableSampleDataset( + input_schema={"feature": "sequence"}, + output_schema={"label": "label"}, + dataset_name="TestDataset", + task_name="TestTask", + cache_dir=str(self.cache_dir), + ) + + # Create main cache + sample_df = pl.DataFrame( + { + "patient_id": ["P1", "P2", "P3", "P4"], + "feature": [[1], [2], [3], [4]], + "label": [0, 1, 0, 1], + } + ) + sample_df.write_parquet(dataset._sample_cache_path) + dataset._samples_finalized = True + dataset._num_samples = 4 + + # Create train split cache (first time) + train_split_path = self.cache_dir / "TestDataset_TestTask_samples_train.parquet" + train_df = sample_df.filter(pl.col("patient_id").is_in(["P1", "P2", "P3"])) + train_df.write_parquet(train_split_path) + + # Check split exists + self.assertTrue(train_split_path.exists()) + + # Second time - split should be reused + split_exists = train_split_path.exists() + self.assertTrue(split_exists) + + if split_exists: + # Load from existing split + loaded_split = pl.read_parquet(train_split_path) + self.assertEqual(len(loaded_split), 3) + + print("✓ Split cache reuse works correctly") + + def test_cache_invalidation_on_different_config(self): + """Test that different dev configs use different caches.""" + # Dataset with dev=True, dev_max_patients=100 + dataset1 = IterableSampleDataset( + input_schema={"feature": "sequence"}, + output_schema={"label": "label"}, + dataset_name="TestDataset", + task_name="TestTask", + cache_dir=str(self.cache_dir), + dev=True, + dev_max_patients=100, + ) + + # Dataset with dev=True, dev_max_patients=500 + dataset2 = IterableSampleDataset( + input_schema={"feature": "sequence"}, + output_schema={"label": "label"}, + dataset_name="TestDataset", + task_name="TestTask", + cache_dir=str(self.cache_dir), + dev=True, + dev_max_patients=500, + ) + + # Cache paths should be different + self.assertNotEqual(dataset1._sample_cache_path, dataset2._sample_cache_path) + + self.assertIn("_dev_100", str(dataset1._sample_cache_path)) + self.assertIn("_dev_500", str(dataset2._sample_cache_path)) + + print("✓ Different dev configs use different caches") + + +class TestProcessorCaching(unittest.TestCase): + """Test processor caching functionality.""" + + def setUp(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + self.cache_dir = Path(self.temp_dir) + + def tearDown(self): + """Clean up temporary files.""" + if Path(self.temp_dir).exists(): + shutil.rmtree(self.temp_dir) + + def test_processor_cache_directory_creation(self): + """Test that processor cache directory is created correctly.""" + processor_dir = self.cache_dir / "processors" + processor_dir.mkdir(parents=True, exist_ok=True) + + self.assertTrue(processor_dir.exists()) + self.assertTrue(processor_dir.is_dir()) + print("✓ Processor cache directory created") + + def test_processor_file_paths(self): + """Test processor file path construction.""" + processor_dir = self.cache_dir / "processors" + + input_path = processor_dir / "input_processors.pkl" + output_path = processor_dir / "output_processors.pkl" + + self.assertEqual(input_path.suffix, ".pkl") + self.assertEqual(output_path.suffix, ".pkl") + print("✓ Processor file paths correct") + + def test_processor_cache_check(self): + """Test checking if processors are cached.""" + processor_dir = self.cache_dir / "processors" + processor_dir.mkdir(parents=True, exist_ok=True) + + input_path = processor_dir / 
"input_processors.pkl" + output_path = processor_dir / "output_processors.pkl" + + # Initially, processors should not exist + processors_cached = input_path.exists() and output_path.exists() + self.assertFalse(processors_cached) + + # Create dummy files + input_path.touch() + output_path.touch() + + # Now processors should be cached + processors_cached = input_path.exists() and output_path.exists() + self.assertTrue(processors_cached) + print("✓ Processor cache check works") + + +if __name__ == "__main__": + print("=" * 70) + print("STREAMING CACHE TESTS") + print("=" * 70) + unittest.main(verbosity=2) diff --git a/tests/core/test_streaming_mode.py b/tests/core/test_streaming_mode.py new file mode 100644 index 000000000..2c232f88e --- /dev/null +++ b/tests/core/test_streaming_mode.py @@ -0,0 +1,521 @@ +"""Fast unit tests for streaming mode functionality. + +These are "baby tests" designed to run quickly to verify streaming mechanics work correctly. +They use small synthetic datasets and mocks to minimize runtime. +""" + +import unittest +import tempfile +from pathlib import Path +from unittest.mock import Mock +import polars as pl + +print("=" * 70) +print("STREAMING MODE TESTS - Starting test execution") +print("=" * 70) + +try: + from pyhealth.datasets.base_dataset import BaseDataset + + print("✓ Successfully imported BaseDataset") +except ImportError as e: + print(f"✗ Failed to import BaseDataset: {e}") + raise + +try: + from pyhealth.datasets.iterable_sample_dataset import IterableSampleDataset + + print("✓ Successfully imported IterableSampleDataset") +except ImportError as e: + print(f"✗ Failed to import IterableSampleDataset: {e}") + print("\nNote: You may need to install PyHealth in development mode:") + print(" cd PyHealth && pip install -e .") + raise + +print("=" * 70) +print() + + +class TestStreamingMode(unittest.TestCase): + """Fast unit tests for streaming mode functionality.""" + + def setUp(self): + """Set up test data for each test.""" + print("\n[SETUP] Creating test fixtures...") + # Create temporary directory for caching + self.temp_dir = tempfile.mkdtemp() + self.temp_cache_dir = Path(self.temp_dir) + print(f" Cache dir: {self.temp_dir}") + + # Create small synthetic data (3 patients, 5 events) + self.mock_data = pl.DataFrame( + { + "patient_id": ["P1", "P1", "P2", "P2", "P3"], + "event_type": [ + "diagnosis", + "medication", + "diagnosis", + "procedure", + "diagnosis", + ], + "timestamp": pl.Series( + [ + "2020-01-01", + "2020-01-02", + "2020-01-01", + "2020-01-03", + "2020-01-01", + ] + ).str.strptime(pl.Datetime, format="%Y-%m-%d"), + "diagnosis/code": ["D001", None, "D002", None, "D003"], + "medication/name": [None, "M001", None, None, None], + "procedure/type": [None, None, None, "P001", None], + } + ) + print(f" Mock dataset shape: {self.mock_data.shape}") + + def tearDown(self): + """Clean up temporary files after each test.""" + print("[TEARDOWN] Cleaning up...") + import shutil + + if self.temp_cache_dir.exists(): + shutil.rmtree(self.temp_cache_dir) + print(" ✓ Cleaned up temporary cache directory") + + def test_streaming_mode_initialization(self): + """Test that streaming mode can be initialized (fast - no data loading).""" + print("\n[TEST] test_streaming_mode_initialization") + # Use mock to avoid actual dataset loading + dataset_mock = Mock(spec=BaseDataset) + dataset_mock.stream = True + dataset_mock.cache_dir = self.temp_cache_dir + + self.assertTrue(dataset_mock.stream) + self.assertEqual(dataset_mock.cache_dir, self.temp_cache_dir) + print(" ✓ 
Streaming mode initialization test passed") + + def test_patient_cache_creation(self): + """Test that patient cache is created (fast - minimal data).""" + print("\n[TEST] test_patient_cache_creation") + + # Write small test data to cache + cache_path = self.temp_cache_dir / "test_patients.parquet" + print(f" Writing {len(self.mock_data)} rows to {cache_path.name}...") + self.mock_data.write_parquet(cache_path) + + self.assertTrue(cache_path.exists()) + print(f" ✓ Cache file created: {cache_path.exists()}") + + # Verify data can be read back + loaded = pl.read_parquet(cache_path) + self.assertEqual(len(loaded), len(self.mock_data)) + self.assertEqual(loaded.schema, self.mock_data.schema) + print(f" ✓ Cache data verified: {len(loaded)} rows read back") + + def test_patient_cache_sorted_by_patient_id(self): + """Test that patient cache is sorted by patient_id for efficient access.""" + print("\n[TEST] test_patient_cache_sorted_by_patient_id") + + # Sort data by patient_id (as done in _build_patient_cache) + sorted_data = self.mock_data.sort("patient_id", "timestamp") + + cache_path = self.temp_cache_dir / "test_patients_sorted.parquet" + sorted_data.write_parquet(cache_path) + + # Verify data is sorted + loaded = pl.read_parquet(cache_path) + patient_ids = loaded["patient_id"].to_list() + self.assertEqual(patient_ids, sorted(patient_ids)) + print(" ✓ Patient cache is sorted correctly") + + def test_patient_index_creation(self): + """Test that patient index is created correctly.""" + print("\n[TEST] test_patient_index_creation") + + # Create patient cache + cache_path = self.temp_cache_dir / "test_patients.parquet" + self.mock_data.write_parquet(cache_path) + + # Build patient index (as done in _build_patient_cache) + patient_index = ( + pl.scan_parquet(cache_path) + .group_by("patient_id") + .agg( + [ + pl.count().alias("event_count"), + pl.first("timestamp").alias("first_timestamp"), + pl.last("timestamp").alias("last_timestamp"), + ] + ) + .sort("patient_id") + .collect() + ) + + # Verify index + self.assertEqual(len(patient_index), 3) # P1, P2, P3 + self.assertIn("event_count", patient_index.columns) + self.assertIn("first_timestamp", patient_index.columns) + self.assertIn("last_timestamp", patient_index.columns) + + # Verify counts + p1_count = patient_index.filter(pl.col("patient_id") == "P1")["event_count"][0] + self.assertEqual(p1_count, 2) # P1 has 2 events + print(" ✓ Patient index created correctly") + + def test_sample_storage_streaming(self): + """Test that samples can be written incrementally (fast).""" + print("\n[TEST] test_sample_storage_streaming") + # Create small sample batch + samples = [ + {"patient_id": "P1", "label": 0, "value": 100}, + {"patient_id": "P2", "label": 1, "value": 200}, + ] + + sample_df = pl.DataFrame(samples) + cache_path = self.temp_cache_dir / "test_samples.parquet" + sample_df.write_parquet(cache_path) + + self.assertTrue(cache_path.exists()) + + # Verify samples can be read + loaded_samples = pl.read_parquet(cache_path).to_dicts() + self.assertEqual(len(loaded_samples), 2) + self.assertEqual(loaded_samples[0]["patient_id"], "P1") + self.assertEqual(loaded_samples[1]["patient_id"], "P2") + print(" ✓ Sample storage works correctly") + + def test_iterable_dataset_length(self): + """Test that IterableSampleDataset reports correct length (fast).""" + print("\n[TEST] test_iterable_dataset_length") + # Create small sample dataset + samples = pl.DataFrame( + { + "patient_id": ["P1", "P2", "P3"], + "label": [0, 1, 0], + } + ) + cache_path = 
self.temp_cache_dir / "samples.parquet" + samples.write_parquet(cache_path) + + # Mock IterableSampleDataset + dataset_mock = Mock(spec=IterableSampleDataset) + dataset_mock._num_samples = len(samples) + dataset_mock.__len__ = lambda self: dataset_mock._num_samples + + self.assertEqual(len(dataset_mock), 3) + print(" ✓ IterableSampleDataset length works correctly") + + def test_batch_iteration(self): + """Test batch reading from parquet (fast - 10 samples).""" + print("\n[TEST] test_batch_iteration") + # Create test samples + samples = pl.DataFrame( + { + "idx": range(10), + "value": range(100, 110), + } + ) + cache_path = self.temp_cache_dir / "samples.parquet" + samples.write_parquet(cache_path) + + # Test batch reading + batch_size = 3 + lf = pl.scan_parquet(cache_path) + + batches = [] + num_samples = 10 + num_batches = (num_samples + batch_size - 1) // batch_size + for i in range(num_batches): + offset = i * batch_size + length = min(batch_size, num_samples - offset) + batch = lf.slice(offset, length).collect() + batches.append(batch) + + # Verify we got all samples + total_samples = sum(len(b) for b in batches) + self.assertEqual(total_samples, 10) + + # Verify batch sizes + self.assertEqual(len(batches[0]), 3) # First batch + self.assertEqual(len(batches[1]), 3) # Second batch + self.assertEqual(len(batches[2]), 3) # Third batch + self.assertEqual(len(batches[3]), 1) # Last batch (remainder) + print(" ✓ Batch iteration works correctly") + + def test_streaming_mode_error_messages(self): + """Test that appropriate errors are raised in stream mode (fast).""" + print("\n[TEST] test_streaming_mode_error_messages") + dataset_mock = Mock(spec=BaseDataset) + dataset_mock.stream = True + + # Test collected_global_event_df error + def raise_runtime_error(): + raise RuntimeError( + "collected_global_event_df is not available in stream mode " + "as it would load the entire dataset into memory. " + "Use iter_patients_streaming() for memory-efficient patient iteration." + ) + + type(dataset_mock).collected_global_event_df = property( + lambda self: raise_runtime_error() + ) + + with self.assertRaises(RuntimeError) as context: + _ = dataset_mock.collected_global_event_df + + self.assertIn("not available in stream mode", str(context.exception)) + print(" ✓ Stream mode error messages work correctly") + + def test_iter_patients_error_in_stream_mode(self): + """Test that iter_patients raises error when called without df in stream mode.""" + print("\n[TEST] test_iter_patients_error_in_stream_mode") + dataset_mock = Mock(spec=BaseDataset) + dataset_mock.stream = True + + def mock_iter_patients(df=None): + if df is None and dataset_mock.stream: + raise RuntimeError( + "iter_patients() requires collected DataFrame which is not " + "available in stream mode. Use iter_patients_streaming() instead." 
+ ) + + dataset_mock.iter_patients = mock_iter_patients + + with self.assertRaises(RuntimeError) as context: + dataset_mock.iter_patients() + + self.assertIn("Use iter_patients_streaming", str(context.exception)) + print(" ✓ iter_patients error in stream mode works correctly") + + def test_iterable_sample_dataset_initialization(self): + """Test IterableSampleDataset can be initialized correctly.""" + print("\n[TEST] test_iterable_sample_dataset_initialization") + print(" Creating IterableSampleDataset...") + dataset = IterableSampleDataset( + input_schema={"feature": "sequence"}, + output_schema={"label": "label"}, + dataset_name="TestDataset", + task_name="TestTask", + cache_dir=str(self.temp_cache_dir), + ) + + self.assertEqual(dataset.dataset_name, "TestDataset") + self.assertEqual(dataset.task_name, "TestTask") + self.assertEqual(dataset._num_samples, 0) + self.assertFalse(dataset._samples_finalized) + print(f" ✓ IterableSampleDataset initialized successfully") + print(f" - Dataset name: {dataset.dataset_name}") + print(f" - Task name: {dataset.task_name}") + print(f" - Initial samples: {dataset._num_samples}") + + def test_iterable_sample_dataset_add_samples(self): + """Test adding samples to IterableSampleDataset.""" + print("\n[TEST] test_iterable_sample_dataset_add_samples") + dataset = IterableSampleDataset( + input_schema={"feature": "sequence"}, + output_schema={"label": "label"}, + dataset_name="TestDataset", + task_name="TestTask", + cache_dir=str(self.temp_cache_dir), + ) + + # Add first batch + samples1 = [ + {"patient_id": "P1", "feature": [1, 2, 3], "label": 0}, + {"patient_id": "P2", "feature": [4, 5, 6], "label": 1}, + ] + print(f" Adding batch 1: {len(samples1)} samples...") + dataset.add_samples_streaming(samples1) + + self.assertEqual(dataset._num_samples, 2) + print(f" ✓ Batch 1 added. Total samples: {dataset._num_samples}") + + # Add second batch + samples2 = [ + {"patient_id": "P3", "feature": [7, 8, 9], "label": 0}, + ] + print(f" Adding batch 2: {len(samples2)} samples...") + dataset.add_samples_streaming(samples2) + + self.assertEqual(dataset._num_samples, 3) + print(f" ✓ Batch 2 added. 
Total samples: {dataset._num_samples}") + + def test_iterable_sample_dataset_finalize(self): + """Test finalizing samples in IterableSampleDataset.""" + print("\n[TEST] test_iterable_sample_dataset_finalize") + dataset = IterableSampleDataset( + input_schema={"feature": "sequence"}, + output_schema={"label": "label"}, + dataset_name="TestDataset", + task_name="TestTask", + cache_dir=str(self.temp_cache_dir), + ) + + samples = [ + {"patient_id": "P1", "feature": [1, 2, 3], "label": 0}, + ] + dataset.add_samples_streaming(samples) + dataset.finalize_samples() + + self.assertTrue(dataset._samples_finalized) + + # Should raise error if trying to add after finalize + with self.assertRaises(RuntimeError) as context: + dataset.add_samples_streaming(samples) + + self.assertIn("Cannot add more samples", str(context.exception)) + print(" ✓ Sample finalization works correctly") + + def test_cache_path_construction(self): + """Test that cache paths are constructed correctly.""" + print("\n[TEST] test_cache_path_construction") + dataset = IterableSampleDataset( + input_schema={"feature": "sequence"}, + output_schema={"label": "label"}, + dataset_name="MyDataset", + task_name="MyTask", + cache_dir=str(self.temp_cache_dir), + ) + + expected_path = self.temp_cache_dir / "MyDataset_MyTask_samples.parquet" + self.assertEqual(dataset._sample_cache_path, expected_path) + print(" ✓ Cache path construction works correctly") + + def test_patient_filtering_from_cache(self): + """Test filtering specific patients from cache (fast).""" + print("\n[TEST] test_patient_filtering_from_cache") + + # Create patient cache + cache_path = self.temp_cache_dir / "test_patients.parquet" + self.mock_data.write_parquet(cache_path) + + # Filter for specific patient + patient_df = ( + pl.scan_parquet(cache_path).filter(pl.col("patient_id") == "P1").collect() + ) + + self.assertEqual(len(patient_df), 2) # P1 has 2 events + self.assertTrue(all(patient_df["patient_id"] == "P1")) + print(" ✓ Patient filtering from cache works correctly") + + def test_multiple_patient_filtering(self): + """Test filtering multiple patients from cache.""" + print("\n[TEST] test_multiple_patient_filtering") + + # Create patient cache + cache_path = self.temp_cache_dir / "test_patients.parquet" + self.mock_data.write_parquet(cache_path) + + # Filter for multiple patients + patient_ids = ["P1", "P3"] + filtered_df = ( + pl.scan_parquet(cache_path) + .filter(pl.col("patient_id").is_in(patient_ids)) + .collect() + ) + + self.assertEqual(len(filtered_df), 3) # P1 has 2 events, P3 has 1 event + self.assertEqual( + set(filtered_df["patient_id"].unique().to_list()), {"P1", "P3"} + ) + print(" ✓ Multiple patient filtering works correctly") + + def test_iter_patients_with_batch_size(self): + """Test iter_patients with batch_size parameter.""" + print("\n[TEST] test_iter_patients_with_batch_size") + + # Create patient cache + cache_path = self.temp_cache_dir / "test_patients.parquet" + self.mock_data.write_parquet(cache_path) + + # Create patient index + patient_index = ( + pl.scan_parquet(cache_path) + .group_by("patient_id") + .agg( + [ + pl.count().alias("event_count"), + pl.first("timestamp").alias("first_timestamp"), + pl.last("timestamp").alias("last_timestamp"), + ] + ) + .sort("patient_id") + .collect() + ) + index_path = self.temp_cache_dir / "test_patient_index.parquet" + patient_index.write_parquet(index_path) + + # Mock dataset with streaming mode + dataset_mock = Mock(spec=BaseDataset) + dataset_mock.stream = True + dataset_mock._patient_cache_path = 
cache_path
+        dataset_mock._patient_index_path = index_path
+        dataset_mock._patient_index = None
+
+        # Rather than importing BaseDataset.iter_patients here, verify the
+        # batch-reading logic it relies on directly.
+        batch_size = 2
+        lf = pl.scan_parquet(cache_path)
+        all_patient_ids = patient_index["patient_id"].to_list()
+
+        batches_collected = []
+        for i in range(0, len(all_patient_ids), batch_size):
+            batch_patient_ids = all_patient_ids[i : i + batch_size]
+            batch_df = lf.filter(
+                pl.col("patient_id").is_in(batch_patient_ids)
+            ).collect()
+            batches_collected.append(batch_df)
+
+        # Verify we got 2 batches (3 patients with batch_size=2)
+        self.assertEqual(len(batches_collected), 2)
+        # First batch should have P1 and P2 data
+        self.assertGreater(len(batches_collected[0]), 0)
+        # Second batch should have P3 data
+        self.assertGreater(len(batches_collected[1]), 0)
+        print(" ✓ iter_patients with batch_size works correctly")
+
+    def test_iter_patients_batch_vs_individual(self):
+        """Test that batch mode yields lists while individual mode yields single patients."""
+        print("\n[TEST] test_iter_patients_batch_vs_individual")
+
+        # Create patient cache
+        cache_path = self.temp_cache_dir / "test_patients.parquet"
+        self.mock_data.write_parquet(cache_path)
+
+        # Test batch mode (batch_size=2)
+        batch_size = 2
+        lf = pl.scan_parquet(cache_path)
+        patient_ids = self.mock_data["patient_id"].unique().to_list()
+
+        batches = []
+        for i in range(0, len(patient_ids), batch_size):
+            batch_patient_ids = patient_ids[i : i + batch_size]
+            batch_df = lf.filter(
+                pl.col("patient_id").is_in(batch_patient_ids)
+            ).collect()
+            # In the actual implementation, this would be a list of Patient objects
+            batches.append(batch_df)
+
+        # Should get 2 batches for 3 patients
+        self.assertEqual(len(batches), 2)
+
+        # Test individual mode (no batch_size)
+        individual_patients = []
+        for patient_id in patient_ids:
+            patient_df = lf.filter(pl.col("patient_id") == patient_id).collect()
+            individual_patients.append(patient_df)
+
+        # Should get 3 individual patients
+        self.assertEqual(len(individual_patients), 3)
+        print(" ✓ Batch vs individual iteration works correctly")
+
+
+print("\n" + "=" * 70)
+print("STREAMING MODE TESTS - Module loaded successfully")
+print("All test functions defined and ready to run")
+print("=" * 70)
+
+
+if __name__ == "__main__":
+    unittest.main()
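
The batch iteration and per-patient filtering these tests verify reduce to a small polars pattern that can be run on its own. Below is a minimal standalone sketch, assuming only that polars is installed; the temporary directory, file name, and toy frame are illustrative, and it does not touch PyHealth's IterableSampleDataset or BaseDataset.

import tempfile
from pathlib import Path

import polars as pl

# Write a small sample cache, mimicking what streaming mode persists to parquet.
tmp_dir = Path(tempfile.mkdtemp())
cache_path = tmp_dir / "samples.parquet"
rows = {
    "patient_id": ["P1", "P1", "P2", "P3"],
    "value": [1, 2, 3, 4],
}
pl.DataFrame(rows).write_parquet(cache_path)

# Lazily scan the cache: nothing is loaded until .collect() is called.
lf = pl.scan_parquet(cache_path)

# Fixed-size sample batches via slice(offset, length), as in test_batch_iteration.
num_samples = len(rows["patient_id"])
batch_size = 3
for offset in range(0, num_samples, batch_size):
    batch = lf.slice(offset, min(batch_size, num_samples - offset)).collect()
    print(f"batch with {len(batch)} rows")

# Per-patient reads via a filter on patient_id, as in the patient-filtering tests.
p1_events = lf.filter(pl.col("patient_id") == "P1").collect()
print(f"P1 has {len(p1_events)} events")

Because the scan is lazy, the slice and the filter can be pushed down to the parquet reader, so the whole cache never has to be materialized at once; that is the property the streaming-mode tests above check indirectly.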