diff --git a/examples/memtest.py b/examples/memtest.py new file mode 100644 index 000000000..8a63090e8 --- /dev/null +++ b/examples/memtest.py @@ -0,0 +1,32 @@ +# %% +import psutil, os, time, threading +PEAK_MEM_USAGE = 0 +SELF_PROC = psutil.Process(os.getpid()) + +def track_mem(): + global PEAK_MEM_USAGE + while True: + m = SELF_PROC.memory_info().rss + if m > PEAK_MEM_USAGE: + PEAK_MEM_USAGE = m + time.sleep(0.1) + +threading.Thread(target=track_mem, daemon=True).start() +print(f"[MEM] start={PEAK_MEM_USAGE / (1024**3)} GB") + +# %% +from pyhealth.datasets import MIMIC4Dataset +DATASET_DIR = "/home/logic/physionet.org/files/mimiciv/3.1" +dataset = MIMIC4Dataset( + ehr_root=DATASET_DIR, + ehr_tables=[ + "patients", + "admissions", + "diagnoses_icd", + "procedures_icd", + "prescriptions", + "labevents", + ], +) +print(f"[MEM] __init__={PEAK_MEM_USAGE / (1024**3):.3f} GB") +# %% diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 43948c828..3390453ff 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -208,6 +208,12 @@ def load_table(self, table_name: str) -> pl.LazyFrame: if table_name not in self.config.tables: raise ValueError(f"Table {table_name} not found in config") + def _to_lower(col_name: str) -> str: + lower_name = col_name.lower() + if lower_name != col_name: + logger.warning("Renaming column %s to lowercase %s", col_name, lower_name) + return lower_name + table_cfg = self.config.tables[table_name] csv_path = f"{self.root}/{table_cfg.file_path}" csv_path = clean_path(csv_path) @@ -216,10 +222,7 @@ def load_table(self, table_name: str) -> pl.LazyFrame: df = scan_csv_gz_or_csv_tsv(csv_path) # Convert column names to lowercase before calling preprocess_func - col_names = df.collect_schema().names() - if any(col != col.lower() for col in col_names): - logger.warning("Some column names were converted to lowercase") - df = df.with_columns([pl.col(col).alias(col.lower()) for col in col_names]) + df = df.rename(_to_lower) # Check if there is a preprocessing function for this table preprocess_func = getattr(self, f"preprocess_{table_name}", None) @@ -235,12 +238,7 @@ def load_table(self, table_name: str) -> pl.LazyFrame: other_csv_path = clean_path(other_csv_path) logger.info(f"Joining with table: {other_csv_path}") join_df = scan_csv_gz_or_csv_tsv(other_csv_path) - join_df = join_df.with_columns( - [ - pl.col(col).alias(col.lower()) - for col in join_df.collect_schema().names() - ] - ) + join_df = join_df.rename(_to_lower) join_key = join_cfg.on columns = join_cfg.columns how = join_cfg.how @@ -283,7 +281,7 @@ def load_table(self, table_name: str) -> pl.LazyFrame: # Flatten attribute columns with event_type prefix attribute_columns = [ - pl.col(attr).alias(f"{table_name}/{attr}") for attr in attribute_cols + pl.col(attr.lower()).alias(f"{table_name}/{attr}") for attr in attribute_cols ] event_frame = df.select(base_columns + attribute_columns)