210 changes: 140 additions & 70 deletions README.md

Large diffs are not rendered by default.

33 changes: 22 additions & 11 deletions main.py
@@ -35,38 +35,46 @@ def get_parser():
         "--src_data",
         type=str,
         choices=["eicu", "mimiciv", "umcdb", "hirid", "meds"],
-        default="mimiciv"
+        default="mimiciv",
     )
     parser.add_argument(
         "--train_subset",
         type=str,
         default="train",
         help="file name without extension to load data for the training. only used when"
-        "`--src_data` is set to `'meds'`."
+        "`--src_data` is set to `'meds'`.",
     )
     parser.add_argument(
         "--valid_subset",
         type=str,
         default="tuning",
         help="file name without extension to load data for the validation. only used when"
-        "`--src_data` is set to `'meds'`."
+        "`--src_data` is set to `'meds'`.",
     )
     parser.add_argument(
         "--test_subset",
         type=str,
         default="held_out",
         help="file name without extension to load data for the test. only used when `--src_data` "
-        "is set to `'meds'`."
+        "is set to `'meds'`.",
     )
+    parser.add_argument(
+        "--unique_events_path",
+        type=str,
+        default=None,
+        help="path to directory containing `unique_events.h5` to encode events in MEDS dataset. "
+        "only used when `--src_data` is set to `'meds'`.",
+    )
+
     parser.add_argument(
         "--test_cohort",
         type=str,
         default=None,
         help="path to the test cohort, which must be a result of ACES. it can be either "
-        "a directory or the exact file path that has a .parquet file extension. if provided with "
-        "a directory, it tries to load `${test_subset}`/*.parquet files in the directory. "
-        "note that the set of patient ids in this cohort should match that in the "
-        "test dataset"
+        "a directory or the exact file path that has a .parquet file extension. if provided with "
+        "a directory, it tries to load `${test_subset}`/*.parquet files in the directory. "
+        "note that the set of patient ids in this cohort should match that in the "
+        "test dataset",
     )
 
     parser.add_argument(
@@ -76,11 +84,12 @@ def get_parser():
             "readmission",
             "los_7",
             "los_14",
-            "mortality" "mortality_1",
+            "mortality_1",
             "mortality_2",
             "mortality_3",
             "mortality_7",
             "mortality_14",
+            "mortality",
             "diagnosis",
             "creatinine_1",
             "creatinine_2",
@@ -100,7 +109,7 @@ def get_parser():
             "sodium_1",
             "sodium_2",
             "sodium_3",
-            "meds_single_task"
+            "meds_single_task",
         ],
         default=[
             "readmission",
@@ -197,7 +206,9 @@ def get_parser():
     parser.add_argument("--debug", action="store_true")
     parser.add_argument("--log_loss", action="store_true")
     # Wandb
-    parser.add_argument("--wandb", action="store_true", help="whether to log using wandb")
+    parser.add_argument(
+        "--wandb", action="store_true", help="whether to log using wandb"
+    )
     parser.add_argument("--wandb_entity_name", type=str)
     parser.add_argument("--wandb_project_name", type=str, default="REMed")
     parser.add_argument("--pretrained", type=str, default=None)
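A note on the `--test_cohort` flag above: it accepts either an exact .parquet file path or a directory, in which case `${test_subset}`/*.parquet files inside it are loaded. Hypothetical examples of both forms (the paths are placeholders, not paths shipped with the repository):

    # exact file path
    --test_cohort /data/aces_results/held_out/0.parquet

    # directory; resolves to ${test_subset}/*.parquet inside it
    --test_cohort /data/aces_results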
19 changes: 10 additions & 9 deletions scripts/meds/encode_events.sh
@@ -2,15 +2,15 @@
 
 # Function to display help message
 function display_help() {
-    echo "Usage: $0 <PROCESSED_MEDS_DIR> <SAVE_DIR> <GPU_ID> <PRETRAINED_CHECKPOINT_DIR>"
+    echo "Usage: $0 <GPU_ID> <UNIQUE_EVENTS_DIR> <SAVE_DIR> <PRETRAINED_CHECKPOINT_DIR>"
     echo
     echo "This script encodes all events present in a MEDS cohort and caches them, which will"
     echo "be the input data for the REMed model."
     echo
     echo "Arguments:"
-    echo "  PROCESSED_MEDS_DIR         Directory containing processed MEDS data, expected to contain *.h5 and *.tsv files."
-    echo "  SAVE_DIR                   Output directory to save the encoded data as *_encoded.h5."
     echo "  GPU_ID                     GPU index to be used for training the model."
+    echo "  UNIQUE_EVENTS_DIR          Directory containing the unique events to be encoded."
+    echo "  SAVE_DIR                   Output directory to save the encoded unique events."
     echo "  PRETRAINED_CHECKPOINT_DIR  Directory containing checkpoint for the pretrained event encoder, expected to contain checkpoint_best.pt."
     echo
     echo "Options:"
@@ -24,9 +24,9 @@ if [ "$#" -lt 4 ]; then
     display_help
 fi
 
-PROCESSED_MEDS_DIR="$1"
-SAVE_DIR="$2"
-GPU_ID="$3"
+GPU_ID="$1"
+UNIQUE_EVENTS_DIR="$2"
+SAVE_DIR="$3"
 PRETRAINED_CHECKPOINT_DIR="$4"
 
 accelerate launch \
@@ -35,11 +35,12 @@ accelerate launch \
     --gpu_ids="$GPU_ID" \
     main.py \
     --src_data meds \
-    --input_path "$PROCESSED_MEDS_DIR" \
+    --input_path null \
+    --unique_events_path "$UNIQUE_EVENTS_DIR" \
     --save_dir "$SAVE_DIR" \
     --pred_targets meds_single_task \
     --train_type short \
     --random_sample \
     --batch_size 8192 \
-    --encode_events \
-    --resume_name "$PRETRAINED_CHECKPOINT_DIR"
+    --encode_only \
+    --resume_name "$PRETRAINED_CHECKPOINT_DIR"
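An example invocation reflecting the new argument order might look like this (all paths are hypothetical placeholders):

    bash scripts/meds/encode_events.sh 0 outputs/unique_events outputs/encoded /checkpoints/remed_pretrained

Note the reorder relative to the previous version: GPU_ID is now the first positional argument.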
137 changes: 137 additions & 0 deletions scripts/meds/extract_unique_events.py
@@ -0,0 +1,137 @@
import glob
import logging
import math
import multiprocessing
import os
import shutil
import sys
from argparse import ArgumentParser
from typing import List

import h5pickle
import numpy as np
from tqdm import tqdm

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger(__name__)


def get_parser():
    parser = ArgumentParser()
    parser.add_argument(
        "root",
        help="path to the **processed** MEDS dataset containing subdirectories for each split. "
        "all **/*.h5 files that exist in this directory will be scanned and processed.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="outputs",
        help="directory to save the processed outputs.",
    )
    parser.add_argument(
        "--workers",
        metavar="N",
        default=1,
        type=int,
        help="number of parallel workers.",
    )
    parser.add_argument(
        "--n_events_per_shard",
        metavar="N",
        default=1000000,
        type=int,
        help="number of events to include in each shard.",
    )

    return parser


def main(args):
    # recursive=True is required for the ** pattern to descend into the split subdirectories
    filelist = glob.glob(os.path.join(args.root, "**/*.h5"), recursive=True)
    files = [h5pickle.File(fname) for fname in filelist]

    if args.workers <= 1:
        unique_events = _extract_unique_events(files)
    else:
        n = args.workers
        # distribute files across workers round-robin so each gets a similar share
        files_chunks = [files[i::n] for i in range(n)]

        pool = multiprocessing.get_context("spawn").Pool(processes=args.workers)
        unique_events_gathered = pool.map(_extract_unique_events, files_chunks)
        pool.close()
        pool.join()

        logger.info("Gathering and reducing local unique events...")
        unique_events = np.concatenate(unique_events_gathered)
        unique_events = np.unique(unique_events, axis=0)
        logger.info("Done!")

    # recreate the output directory so stale shards from previous runs are removed
    if os.path.exists(os.path.join(args.output_dir, "unique_events")):
        shutil.rmtree(os.path.join(args.output_dir, "unique_events"))
    os.makedirs(os.path.join(args.output_dir, "unique_events"))

    num_shards = math.ceil(len(unique_events) / args.n_events_per_shard)
    for shard_id in range(num_shards):
        start = shard_id * args.n_events_per_shard
        end = min((shard_id + 1) * args.n_events_per_shard, len(unique_events))
        sharded_unique_events = unique_events[start:end]
        with h5pickle.File(
            os.path.join(
                args.output_dir, "unique_events", f"unique_events_{shard_id}.h5"
            ),
            "w",
        ) as f:
            for i, event_tuple in tqdm(
                enumerate(sharded_unique_events), total=len(sharded_unique_events)
            ):
                # global index of this event across all shards
                idx = str(shard_id * args.n_events_per_shard + i)
                data = f.create_group(idx)

                # each event is stored as a stacked array of its input/type/dpe id rows
                sources = np.stack(
                    [
                        np.array(event_tuple[0]),
                        np.array(event_tuple[1]),
                        np.array(event_tuple[2]),
                    ]
                )
                data.create_dataset(
                    "sources", data=sources, dtype="i2", compression="lzf", shuffle=True
                )


def _extract_unique_events(files: List[h5pickle.File]):
    """Collect deduplicated (input_ids, type_ids, dpe_ids) event tuples from the given files."""
    unique_events = []
    pbar = tqdm(files, total=len(files))
    for f in pbar:
        pbar.set_description(f.filename)
        for sbj_id in f["ehr"]:
            input_ids = f["ehr"][sbj_id]["hi"][:, 0]
            type_ids = f["ehr"][sbj_id]["hi"][:, 1]
            dpe_ids = f["ehr"][sbj_id]["hi"][:, 2]

            event_tokens = [
                (
                    tuple(input_id),
                    tuple(type_id),
                    tuple(dpe_id),
                )
                for input_id, type_id, dpe_id in zip(input_ids, type_ids, dpe_ids)
            ]
            event_tokens = list(np.unique(event_tokens, axis=0))
            unique_events.extend(event_tokens)
        # deduplicate after each file to keep the in-memory list small
        unique_events = list(np.unique(unique_events, axis=0))

    return unique_events


if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()
    main(args)
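For context, extracting the unique events for a processed MEDS dataset might look like this (paths hypothetical); the resulting outputs/unique_events directory is then what scripts/meds/encode_events.sh above expects as UNIQUE_EVENTS_DIR:

    python scripts/meds/extract_unique_events.py /data/meds_processed --output_dir outputs --workers 8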