From a20ef77ca4f5c3560b64e3e75089653c69d2c2fa Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 19 Nov 2025 03:28:37 -0500 Subject: [PATCH 1/3] Add script to test memory usage of dataset Co-authored-by: John Wu <54558896+jhnwu3@users.noreply.github.com> --- examples/memtest.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 examples/memtest.py diff --git a/examples/memtest.py b/examples/memtest.py new file mode 100644 index 000000000..8a63090e8 --- /dev/null +++ b/examples/memtest.py @@ -0,0 +1,32 @@ +# %% +import psutil, os, time, threading +PEAK_MEM_USAGE = 0 +SELF_PROC = psutil.Process(os.getpid()) + +def track_mem(): + global PEAK_MEM_USAGE + while True: + m = SELF_PROC.memory_info().rss + if m > PEAK_MEM_USAGE: + PEAK_MEM_USAGE = m + time.sleep(0.1) + +threading.Thread(target=track_mem, daemon=True).start() +print(f"[MEM] start={PEAK_MEM_USAGE / (1024**3)} GB") + +# %% +from pyhealth.datasets import MIMIC4Dataset +DATASET_DIR = "/home/logic/physionet.org/files/mimiciv/3.1" +dataset = MIMIC4Dataset( + ehr_root=DATASET_DIR, + ehr_tables=[ + "patients", + "admissions", + "diagnoses_icd", + "procedures_icd", + "prescriptions", + "labevents", + ], +) +print(f"[MEM] __init__={PEAK_MEM_USAGE / (1024**3):.3f} GB") +# %% From bf2a86d5675a0dcb96ab642a6fd582f1f84c7a15 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 19 Nov 2025 03:35:04 -0500 Subject: [PATCH 2/3] Remove collect_schema usage during __init__ Co-authored-by: John Wu <54558896+jhnwu3@users.noreply.github.com> --- pyhealth/datasets/base_dataset.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 43948c828..d12466a59 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -208,6 +208,12 @@ def load_table(self, table_name: str) -> pl.LazyFrame: if table_name not in self.config.tables: raise ValueError(f"Table {table_name} not found in config") + def _to_lower(col_name: str) -> str: + lower_name = col_name.lower() + if lower_name != col_name: + logger.warning("Renaming column %s to lowercase %s", col_name, lower_name) + return lower_name + table_cfg = self.config.tables[table_name] csv_path = f"{self.root}/{table_cfg.file_path}" csv_path = clean_path(csv_path) @@ -216,10 +222,7 @@ def load_table(self, table_name: str) -> pl.LazyFrame: df = scan_csv_gz_or_csv_tsv(csv_path) # Convert column names to lowercase before calling preprocess_func - col_names = df.collect_schema().names() - if any(col != col.lower() for col in col_names): - logger.warning("Some column names were converted to lowercase") - df = df.with_columns([pl.col(col).alias(col.lower()) for col in col_names]) + df = df.rename(_to_lower) # Check if there is a preprocessing function for this table preprocess_func = getattr(self, f"preprocess_{table_name}", None) @@ -235,12 +238,7 @@ def load_table(self, table_name: str) -> pl.LazyFrame: other_csv_path = clean_path(other_csv_path) logger.info(f"Joining with table: {other_csv_path}") join_df = scan_csv_gz_or_csv_tsv(other_csv_path) - join_df = join_df.with_columns( - [ - pl.col(col).alias(col.lower()) - for col in join_df.collect_schema().names() - ] - ) + join_df = join_df.rename(_to_lower) join_key = join_cfg.on columns = join_cfg.columns how = join_cfg.how From bcb94e26669f45e08c4549fb00e72166500d6f39 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 19 Nov 2025 04:17:19 -0500 Subject: [PATCH 3/3] Fix some code still accessing upper case column name Co-authored-by: John Wu <54558896+jhnwu3@users.noreply.github.com> --- pyhealth/datasets/base_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index d12466a59..3390453ff 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -281,7 +281,7 @@ def _to_lower(col_name: str) -> str: # Flatten attribute columns with event_type prefix attribute_columns = [ - pl.col(attr).alias(f"{table_name}/{attr}") for attr in attribute_cols + pl.col(attr.lower()).alias(f"{table_name}/{attr}") for attr in attribute_cols ] event_frame = df.select(base_columns + attribute_columns)