Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions examples/memtest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# %%
import os
import threading
import time

import psutil

# Largest resident set size (bytes) observed so far; updated by track_mem().
PEAK_MEM_USAGE = 0
# Handle on the current process, used to query its memory usage.
SELF_PROC = psutil.Process(os.getpid())

def track_mem():
    """Poll this process's RSS forever and keep the largest value seen.

    Writes the running maximum (in bytes) to the module-level
    PEAK_MEM_USAGE, sampling every 0.1 seconds. The loop never returns,
    so this is meant to be run in a background thread.
    """
    global PEAK_MEM_USAGE
    while True:
        rss = SELF_PROC.memory_info().rss
        PEAK_MEM_USAGE = max(rss, PEAK_MEM_USAGE)
        time.sleep(0.1)

# Take one sample synchronously before starting the tracker: the daemon
# thread may not have run yet when the report below is printed, which would
# otherwise show the initial value of 0 instead of the real starting usage.
PEAK_MEM_USAGE = SELF_PROC.memory_info().rss
threading.Thread(target=track_mem, daemon=True).start()
# Same ".3f" formatting as the later "[MEM] __init__" report.
print(f"[MEM] start={PEAK_MEM_USAGE / (1024**3):.3f} GB")

# %%
from pyhealth.datasets import MIMIC4Dataset
# Local path to the dataset files (machine-specific; adjust before running).
DATASET_DIR = "/home/logic/physionet.org/files/mimiciv/3.1"

# EHR tables requested from the dataset loader.
EHR_TABLES = [
    "patients",
    "admissions",
    "diagnoses_icd",
    "procedures_icd",
    "prescriptions",
    "labevents",
]

dataset = MIMIC4Dataset(ehr_root=DATASET_DIR, ehr_tables=EHR_TABLES)

# Report the peak RSS recorded by the background tracker up to this point.
print(f"[MEM] __init__={PEAK_MEM_USAGE / (1024**3):.3f} GB")
# %%
20 changes: 9 additions & 11 deletions pyhealth/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,12 @@ def load_table(self, table_name: str) -> pl.LazyFrame:
if table_name not in self.config.tables:
raise ValueError(f"Table {table_name} not found in config")

def _to_lower(col_name: str) -> str:
lower_name = col_name.lower()
if lower_name != col_name:
logger.warning("Renaming column %s to lowercase %s", col_name, lower_name)
return lower_name

table_cfg = self.config.tables[table_name]
csv_path = f"{self.root}/{table_cfg.file_path}"
csv_path = clean_path(csv_path)
Expand All @@ -216,10 +222,7 @@ def load_table(self, table_name: str) -> pl.LazyFrame:
df = scan_csv_gz_or_csv_tsv(csv_path)

# Convert column names to lowercase before calling preprocess_func
col_names = df.collect_schema().names()
if any(col != col.lower() for col in col_names):
logger.warning("Some column names were converted to lowercase")
df = df.with_columns([pl.col(col).alias(col.lower()) for col in col_names])
df = df.rename(_to_lower)

# Check if there is a preprocessing function for this table
preprocess_func = getattr(self, f"preprocess_{table_name}", None)
Expand All @@ -235,12 +238,7 @@ def load_table(self, table_name: str) -> pl.LazyFrame:
other_csv_path = clean_path(other_csv_path)
logger.info(f"Joining with table: {other_csv_path}")
join_df = scan_csv_gz_or_csv_tsv(other_csv_path)
join_df = join_df.with_columns(
[
pl.col(col).alias(col.lower())
for col in join_df.collect_schema().names()
]
)
join_df = join_df.rename(_to_lower)
join_key = join_cfg.on
columns = join_cfg.columns
how = join_cfg.how
Expand Down Expand Up @@ -283,7 +281,7 @@ def load_table(self, table_name: str) -> pl.LazyFrame:

# Flatten attribute columns with event_type prefix
attribute_columns = [
pl.col(attr).alias(f"{table_name}/{attr}") for attr in attribute_cols
pl.col(attr.lower()).alias(f"{table_name}/{attr}") for attr in attribute_cols
]

event_frame = df.select(base_columns + attribute_columns)
Expand Down