From aa88d2b219049c7e3a7e9c076b8a6c3df896fd13 Mon Sep 17 00:00:00 2001 From: Jwoo5 Date: Wed, 2 Oct 2024 16:12:33 +0900 Subject: [PATCH 1/8] Updates for MEDS processing --- README.md | 7 + main.py | 9 + scripts/meds/encode_events.sh | 11 +- scripts/meds/extract_unique_events.py | 128 ++++++++++++ scripts/meds/map_events_to_vec.py | 127 ++++++++++++ scripts/meds/predict.sh | 13 +- scripts/meds/process_meds.py | 107 ++++++++-- src/dataset.py | 124 +++++++---- src/models/eventencoder.py | 6 + src/models/model.py | 10 +- src/trainer/base.py | 288 ++++++++++++-------------- src/trainer/remed.py | 14 +- 12 files changed, 600 insertions(+), 244 deletions(-) create mode 100644 scripts/meds/extract_unique_events.py create mode 100644 scripts/meds/map_events_to_vec.py diff --git a/README.md b/README.md index 258a9cf..c3782ab 100644 --- a/README.md +++ b/README.md @@ -153,6 +153,10 @@ accelerate launch \ ## Support for MEDS dataset + +> [!Caution] +> This instruction is still a work in progress and may not be aligned with the most recent updates. + We officially support processing the [MEDS](https://github.com/Medical-Event-Data-Standard/meds/releases/tag/0.3.0) dataset (currently, MEDS v0.3) with a cohort defined by [ACES](https://github.com/justin13601/ACES), only for the REMed model. It consists of 4 steps in total, each of which can be run by Python or shell scripts that are prepared in the [`scripts/meds/`](scripts/meds/) directory. For more detailed information, please follow the instructions below. 
@@ -304,6 +308,7 @@ Additionally, the following scripts assume your dataset is split into `"train"`, --scorer \ --scorer_use_time \ --max_seq_len 200000 \ + --max_retrieve_len 512 \ # if you want to log using wandb --wandb \ --wandb_entity_name $wandb_entity_name \ @@ -334,6 +339,8 @@ Additionally, the following scripts assume your dataset is split into `"train"`, --train_type remed \ --scorer \ --scorer_use_time \ + --max_seq_len 200000 \ + --max_retrieve_len 512 \ --test_only \ --test_cohort $ACES_TEST_COHORT_DIR \ --resume_name $CHECKPOINT_DIR diff --git a/main.py b/main.py index b71a8e4..fb8e71d 100644 --- a/main.py +++ b/main.py @@ -58,6 +58,14 @@ def get_parser(): help="file name without extension to load data for the test. only used when `--src_data` " "is set to `'meds'`." ) + parser.add_argument( + "--unique_events_path", + type=str, + default=None, + help="path to directory containing `unique_events.h5` to encode events in MEDS dataset. " + "only used when `--src_data` is set to `'meds'`" + ) + parser.add_argument( "--test_cohort", type=str, @@ -81,6 +89,7 @@ def get_parser(): "mortality_3", "mortality_7", "mortality_14", + "mortality", "diagnosis", "creatinine_1", "creatinine_2", diff --git a/scripts/meds/encode_events.sh b/scripts/meds/encode_events.sh index e6d3a3a..3340bae 100644 --- a/scripts/meds/encode_events.sh +++ b/scripts/meds/encode_events.sh @@ -24,10 +24,10 @@ if [ "$#" -lt 4 ]; then display_help fi -PROCESSED_MEDS_DIR="$1" +UNIQUE_EVENTS_DIR="$1" SAVE_DIR="$2" -GPU_ID="$3" -PRETRAINED_CHECKPOINT_DIR="$4" +PRETRAINED_CHECKPOINT_DIR="$3" +GPU_ID="$4" accelerate launch \ --config_file config/single.json \ @@ -35,11 +35,12 @@ accelerate launch \ --gpu_ids="$GPU_ID" \ main.py \ --src_data meds \ - --input_path "$PROCESSED_MEDS_DIR" \ + --input_path null \ + --unique_events_path "$UNIQUE_EVENTS_DIR" \ --save_dir "$SAVE_DIR" \ --pred_targets meds_single_task \ --train_type short \ - --random_sample \ + --batch_size 8192 \ --encode_events \ 
--encode_only \ --resume_name "$PRETRAINED_CHECKPOINT_DIR" \ No newline at end of file diff --git a/scripts/meds/extract_unique_events.py b/scripts/meds/extract_unique_events.py new file mode 100644 index 0000000..623b95b --- /dev/null +++ b/scripts/meds/extract_unique_events.py @@ -0,0 +1,128 @@ +import os +import sys +import glob +import math +import shutil +from typing import List +import multiprocessing +import h5pickle +import numpy as np +import logging +from argparse import ArgumentParser +from tqdm import tqdm + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout +) +logger = logging.getLogger(__name__) + +def get_parser(): + parser = ArgumentParser() + parser.add_argument( + "root", + help="path to the **processed** MEDS dataset containing subdirectories for each split. " + "it will try to scan all **/*.h5 files existed in this directory and process them." + ) + parser.add_argument( + "--output_dir", + type=str, + default="outputs", + help="directory to save the processed outputs.", + ) + parser.add_argument( + "--workers", + metavar="N", + default=1, + type=int, + help="number of parallel workers." 
+ ) + parser.add_argument( + "--n_events_per_shard", + metavar="N", + default=1000000, + type=int, + help="number of events included for each shard" + ) + + return parser + +def main(args): + filelist = glob.glob(os.path.join(args.root, "**/*.h5")) + files = [h5pickle.File(fname) for fname in filelist] + + if args.workers <= 1: + unique_events = _extract_unique_events(files) + else: + n = args.workers + files_chunks = [files[i::n] for i in range(n)] + + pool = multiprocessing.get_context("spawn").Pool(processes=args.workers) + unique_events_gathered = pool.map(_extract_unique_events, files_chunks) + pool.close() + pool.join() + + logger.info("Gathering and reducing local unique events...") + unique_events = np.concatenate(unique_events_gathered) + unique_events = np.unique(unique_events, axis=0) + logger.info("Done!") + + # rebase the output directory + if os.path.exists(os.path.join(args.output_dir, "unique_events")): + shutil.rmtree(os.path.join(args.output_dir, "unique_events")) + os.makedirs(os.path.join(args.output_dir, "unique_events")) + + num_shards = math.ceil(len(unique_events) / args.n_events_per_shard) + for shard_id in range(num_shards): + start = shard_id * args.n_events_per_shard + end = min((shard_id + 1) * args.n_events_per_shard, len(unique_events)) + sharded_unique_events = unique_events[start:end] + with h5pickle.File( + os.path.join(args.output_dir, "unique_events", f"unique_events_{shard_id}.h5"), "w" + ) as f: + for i, event_tuple in tqdm(enumerate(sharded_unique_events), total=len(sharded_unique_events)): + idx = str(shard_id * args.n_events_per_shard + i) + data = f.create_group(idx) + + sources = np.stack([ + np.array(event_tuple[0]), + np.array(event_tuple[1]), + np.array(event_tuple[2]) + ]) + data.create_dataset( + "sources", + data=sources, + dtype="i2", + compression="lzf", + shuffle=True + ) + +def _extract_unique_events(files: List[h5pickle.File]): + unique_events = [] + pbar = tqdm(files, total=len(files)) + for f in pbar: + 
pbar.set_description(f.filename) + for sbj_id in f["ehr"]: + input_ids = f["ehr"][sbj_id]["hi"][:, 0] + type_ids = f["ehr"][sbj_id]["hi"][:, 1] + dpe_ids = f["ehr"][sbj_id]["hi"][:, 2] + + event_tokens = [ + ( + tuple(input_id), + tuple(type_id), + tuple(dpe_id), + ) for input_id, type_id, dpe_id in zip(input_ids, type_ids, dpe_ids) + ] + event_tokens = list(np.unique(event_tokens, axis=0)) + unique_events.extend(event_tokens) + unique_events = list(np.unique(unique_events, axis=0)) + + return unique_events + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/scripts/meds/map_events_to_vec.py b/scripts/meds/map_events_to_vec.py new file mode 100644 index 0000000..59bcfa2 --- /dev/null +++ b/scripts/meds/map_events_to_vec.py @@ -0,0 +1,127 @@ +import os +import sys +import functools +import glob +import logging +import shutil +from typing import List +import multiprocessing +import pickle +import h5pickle +import numpy as np +from argparse import ArgumentParser +from tqdm import tqdm + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout +) +logger = logging.getLogger(__name__) + +global_event_to_vec = None + +def init_worker(event_to_vec): + global global_event_to_vec + global_event_to_vec = event_to_vec + logger.info(f"Process {multiprocessing.current_process().name} initialized with data") + +def get_parser(): + parser = ArgumentParser() + parser.add_argument( + "root", + help="path to the **processed** MEDS dataset containing subdirectories for each split. " + "it will try to scan all **/*.h5 files existed in this directory except for " + "unique_events_*.h5 and process them." 
+ ) + parser.add_argument( + "--map_dir", + help="path to the directory containing **`event_to_vec.pkl`**, which is generated by " + "running the model with `--encode_events=True` and `--encode_only=True`" + ) + parser.add_argument( + "--output_dir", + type=str, + default="outputs", + help="directory to save the processed outputs.", + ) + parser.add_argument( + "--workers", + metavar="N", + default=1, + type=int, + help="number of parallel workers." + ) + + return parser + +def main(args): + filelist = glob.glob(os.path.join(args.root, "**/*.h5")) + files = [h5pickle.File(fname) for fname in filelist if "unique_events" not in fname] + logger.info("Reading mapping dictionary for event vectors...") + with open(os.path.join(args.map_dir, "event_to_vec.pkl"), "rb") as f: + event_to_vec = pickle.load(f) + + subdirs = [ + os.path.relpath(os.path.dirname(p), os.path.abspath(args.root)) + for p in filelist + ] + output_dirs = [ + os.path.join(args.output_dir, subdir) for subdir in subdirs + ] + for subdir in np.unique(subdirs): + if os.path.exists(os.path.join(args.output_dir, subdir)): + shutil.rmtree(os.path.join(args.output_dir, subdir)) + os.makedirs(os.path.join(args.output_dir, subdir)) + + logger.info("Mapping events to their representation vectors...") + if args.workers <= 1: + _map_events_to_vec(event_to_vec, output_dirs, files) + else: + n = args.workers + files_chunks = [files[i::n] for i in range(n)] + output_dir_chunks = [output_dirs[i::n] for i in range(n)] + + pool = multiprocessing.Pool( + processes=args.workers, initializer=init_worker, initargs=(event_to_vec,) + ) + pool.starmap(_map_events_to_vec, zip(output_dir_chunks, files_chunks)) + pool.close() + pool.join() + +def _map_events_to_vec(output_dirs: List[str], files: List[h5pickle.File]): + assert len(files) == len(output_dirs) + for f, output_dir in zip(files, output_dirs): + output_name = os.path.basename(f.filename).split(".h5")[0] + "_encoded.h5" + + with h5pickle.File(os.path.join(output_dir, 
output_name), "w") as output_f: + output_f.create_group("ehr") + for sbj_id in tqdm(f["ehr"], total=len(f["ehr"]), desc=output_name): + input_ids = f["ehr"][sbj_id]["hi"][:, 0] # S, 128 + event_tuples = [ + tuple(event[event != 0]) for event in input_ids + ] + event_vectors = np.array([ + global_event_to_vec[event] for event in event_tuples + ]) + output_f["ehr"].create_group(sbj_id) + output_f["ehr"][sbj_id].create_dataset( + "encoded", + data=event_vectors, + dtype="i2", + compression="lzf", + shuffle=True, + chunks=event_vectors.shape + ) + output_f["ehr"][sbj_id].create_dataset( + "time", data=f["ehr"][sbj_id]["time"][()], + ) + output_f["ehr"][sbj_id].create_dataset( + "label", data=f["ehr"][sbj_id]["label"][()] + ) + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/scripts/meds/predict.sh b/scripts/meds/predict.sh index c37afc7..5ce0264 100644 --- a/scripts/meds/predict.sh +++ b/scripts/meds/predict.sh @@ -2,7 +2,7 @@ # Function to display help message function display_help() { - echo "Usage: $0 " + echo "Usage: $0 " echo echo "This script produces predicted labels and their probabilities for a given task and its" echo "cohort." @@ -10,9 +10,9 @@ function display_help() { echo "Arguments:" echo " ENCODED_MEDS_DIR Directory containing encoded MEDS data, expected to contain *_encoded.h5 files" echo " SAVE_DIR Output directory to save the predicted results." - echo " GPU_ID GPU index to be used for training the model." echo " ACES_TEST_COHORT_DIR Directory containing test cohorts generated from ACES, expected to contain *.parquet files." echo " CHECKPOINT_DIR Directory containing checkpoint for the trained REMed model, expected to contain checkpoint_best.pt." + echo " GPU_ID GPU index to be used for training the model." echo echo "Options:" echo " -h, --help Display this help message and exit." 
@@ -27,9 +27,9 @@ fi ENCODED_MEDS_DIR="$1" SAVE_DIR="$2" -GPU_ID="$3" -ACES_TEST_COHORT_DIR="$4" -CHECKPOINT_DIR="$5" +ACES_TEST_COHORT_DIR="$3" +CHECKPOINT_DIR="$4" +GPU_ID="$5" accelerate launch \ --config_file config/single.json \ @@ -40,9 +40,12 @@ accelerate launch \ --input_path "$ENCODED_MEDS_DIR" \ --save_dir "$SAVE_DIR" \ --pred_targets meds_single_task \ + --pred_time 24 \ --train_type remed \ --scorer \ --scorer_use_time \ + --max_seq_len 200000 \ + --max_retrieve_len 512 \ --test_only \ --test_cohort "$ACES_TEST_COHORT_DIR" \ --resume_name "$CHECKPOINT_DIR" \ No newline at end of file diff --git a/scripts/meds/process_meds.py b/scripts/meds/process_meds.py index 65f6e7c..cd83a37 100644 --- a/scripts/meds/process_meds.py +++ b/scripts/meds/process_meds.py @@ -5,7 +5,8 @@ import os import re import shutil -import time + +import warnings from argparse import ArgumentParser from bisect import bisect_left, bisect_right from datetime import datetime @@ -18,6 +19,8 @@ from tqdm import tqdm from transformers import AutoTokenizer +pool_manager = multiprocessing.Manager() +warned_codes = pool_manager.list() def find_boundary_between(tuples_list, start, end): starts = [s for s, e in tuples_list] @@ -29,7 +32,6 @@ def find_boundary_between(tuples_list, start, end): return start_index, end_index - def get_parser(): parser = ArgumentParser() parser.add_argument( @@ -78,13 +80,20 @@ def get_parser(): help="number of parallel workers.", ) - return parser + parser.add_argument( + "--mimic_dir", + default=None, + help="path to directory for MIMIC-IV database containing hosp/ and icu/ as a subdirectory. " + "this is used for addressing missing descriptions in the metadata for MIMIC-IV codes." 
+ ) + return parser def main(args): root_path = Path(args.root) output_dir = Path(args.output_dir) metadata_dir = Path(args.metadata_dir) + mimic_dir = Path(args.mimic_dir) if args.mimic_dir is not None else None if not output_dir.exists(): output_dir.mkdir() @@ -115,6 +124,26 @@ def main(args): codes_metadata = pl.read_parquet(metadata_dir / "codes.parquet").to_pandas() codes_metadata = codes_metadata.set_index("code")["description"].to_dict() + # do not allow to use static events or birth event + birth_code = ( + "MEDS_BIRTH" # NOTE can we assume code for "birth" is always "MEDS_BIRTH"? + ) + if birth_code not in codes_metadata: + print( + f'"{birth_code}" is not found in the codes metadata, which may lead to ' + "unexpected results since we currently exclude this event from the input data. " + ) + + if mimic_dir is not None: + d_items = pd.read_csv(mimic_dir / "icu" / "d_items.csv.gz") + d_items["itemid"] = d_items["itemid"].astype("str") + d_items = d_items.set_index("itemid")["label"].to_dict() + d_labitems = pd.read_csv(mimic_dir / "hosp" / "d_labitems.csv.gz") + d_labitems["itemid"] = d_labitems["itemid"].astype("str") + d_labitems = d_labitems.set_index("itemid")["label"].to_dict() + else: + d_items = None + d_labitems = None progress_bar = tqdm(data_paths, total=len(data_paths)) for data_path in progress_bar: @@ -129,15 +158,6 @@ def main(args): else: raise ValueError(f"Unsupported file format: {data_path.suffix}") - # do not allow to use static events or birth event - birth_code = ( - "MEDS_BIRTH" # NOTE can we assume code for "birth" is always "MEDS_BIRTH"? - ) - if birth_code not in codes_metadata: - print( - f'"{birth_code}" is not found in the codes metadata, which may lead to ' - "unexpected results since we currently exclude this event from the input data. 
" - ) data = data.with_columns( pl.when(pl.col("code") == birth_code) .then(None) @@ -156,6 +176,7 @@ def main(args): raise ValueError(f"Unsupported file format: {cohort_path.suffix}") cohort = cohort.drop_nulls(label_col_name) + cohort = cohort.unique() cohort = cohort.select( [pl.col("subject_id"), @@ -215,11 +236,9 @@ def extract_cohort(row): return_dtype=pl.Struct( {"cohort_end": pl.List(pl.Datetime()), "cohort_label": pl.List(pl.Boolean)} ) - .alias("cohort_criteria") ) - .unnest("cohort_criteria") - .collect() - ) + .alias("cohort_criteria") + ).unnest("cohort_criteria").collect() data = data.drop_nulls("cohort_label") @@ -259,6 +278,9 @@ def extract_cohort(row): output_dir, output_name, args.workers, + d_items, + d_labitems, + warned_codes, ) # meds --> remed @@ -268,11 +290,12 @@ def extract_cohort(row): del data else: subject_ids = data["subject_id"].unique().to_list() - chunksize = math.ceil(len(subject_ids) / args.workers) + n = args.workers + subject_id_chunks = [subject_ids[i::n] for i in range(n)] data_chunks = [] - for i in range(0, len(subject_ids), chunksize): + for subject_id_chunk in subject_id_chunks: data_chunks.append( - data.filter(pl.col("subject_id").is_in(subject_ids[i:i+chunksize])) + data.filter(pl.col("subject_id").is_in(subject_id_chunk)) ) del data pool = multiprocessing.get_context("spawn").Pool(processes=args.workers) @@ -280,6 +303,8 @@ def extract_cohort(row): length_per_subject_gathered = pool.map( meds_to_remed_partial, data_chunks ) + pool.close() + pool.join() del data_chunks if len(length_per_subject_gathered) != args.workers: @@ -293,7 +318,6 @@ def extract_cohort(row): for subject_id, (length, shard_id) in length_per_subject.items(): manifest_f.write(f"{subject_id}\t{length}\t{shard_id}\n") - def meds_to_remed( tokenizer, rest_of_columns, @@ -302,8 +326,13 @@ def meds_to_remed( output_dir, output_name, num_shards, + d_items, + d_labitems, + warned_codes, df_chunk ): + code_matching_pattern = re.compile(r"\d+") + def 
meds_to_remed_unit(row): events = [] digit_offsets = [] @@ -313,13 +342,48 @@ def meds_to_remed_unit(row): digit_offset = [] col_name_offset = [] for col_name in ["code", "numeric_value"] + rest_of_columns: + # do not process something like "icustay_id" or "hadm_id" + if "id" in col_name: + continue + col_event = row[column_name_idcs[col_name]][event_index] if col_event is not None: col_event = str(col_event) if col_name == "code": if col_event in codes_metadata and codes_metadata[col_event] != "": col_event = codes_metadata[col_event] - elif not "id" in col_name: + else: + do_break = False + items = col_event.split("//") + is_code = [ + bool(code_matching_pattern.fullmatch(item)) for item in items + ] + if True in is_code: + if d_items is not None and d_labitems is not None: + code_idx = is_code.index(True) + code = items[code_idx] + + if code in d_items: + desc = d_items[code] + elif code in d_labitems: + desc = d_labitems[code] + else: + do_break = True + + if not do_break: + items[code_idx] = desc + col_event = "//".join(items) + else: + do_break = True + + if do_break and col_event not in warned_codes: + warned_codes.append(col_event) + warnings.warn( + "The dataset contains some codes that are not specified in " + "the codes metadata, which may not be intended. Note that we " + f"process this code as it is for now: {col_event}." 
+ ) + else: col_event = re.sub( r"\d*\.\d+", lambda x: str(round(float(x.group(0)), 4)), @@ -515,7 +579,6 @@ def meds_to_remed_unit(row): return length_per_subject - if __name__ == "__main__": parser = get_parser() args = parser.parse_args() diff --git a/src/dataset.py b/src/dataset.py index f58ad42..e80d3d7 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -1,3 +1,4 @@ +import glob import math import os import random @@ -71,6 +72,7 @@ def collate_fn(self, out): else: padded = pad_sequence([i[k] for i in out], batch_first=True) ret[k] = pad(padded, (0, 0, 0, padding_to - padded.shape[1])) + return ret def get_labels(self, data): @@ -264,11 +266,18 @@ def __getitem__(self, idx): # assume it is a scalar value for a binary classification task label = torch.tensor([data["label"][()]]).float() - if self.args.max_seq_len < input.shape[0]: + #XXX + max_num_events = 300000 + if self.args.max_seq_len < len(input): + length = len(input) + if length > max_num_events: + times = times - times[-max_num_events] + if self.args.random_sample: indices = random.sample( - range(0, input.shape[0]), self.args.max_seq_len + range(max(0, length - max_num_events), length), self.args.max_seq_len ) + indices.sort() input = input[indices, :, :] times = times[indices] else: @@ -284,44 +293,52 @@ def __getitem__(self, idx): } -class MEDSForReprGen(MEDSDataset): - def __init__(self, args, split, data_path, *pargs, **kwargs): - super().__init__(args, split, data_path, *pargs, **kwargs) +class MEDSForReprGen(Dataset): + def __init__(self, args, data_path): + super().__init__() - self.manifest["num_samples"] = self.manifest["num_events"].map( - lambda x: math.ceil(x / args.max_seq_len) - ) - self.manifest["last_sample_index"] = np.cumsum(self.manifest["num_samples"]) + self.args = args + + self.data = {} + if not data_path.endswith("unique_events"): + data_path = os.path.join(data_path, "unique_events") + for fname in glob.glob(os.path.join(data_path, "*.h5")): + shard_id = 
int(os.path.splitext(fname)[0].split("_")[-1]) + self.data[shard_id] = h5pickle.File( + os.path.join(data_path, f"unique_events_{shard_id}.h5") + ) + self.manifest = {} + for shard_id, data in self.data.items(): + keys = data.keys() + shard_manifest = {k: shard_id for k in keys} + self.manifest |= shard_manifest + self.keys = list(self.manifest.keys()) def __len__(self): - return self.manifest["last_sample_index"].max() + return len(self.manifest) + + def collate_fn(self, samples): + input_ids = torch.stack([s["input_ids"] for s in samples]) + type_ids = torch.stack([s["type_ids"] for s in samples]) + dpe_ids = torch.stack([s["dpe_ids"] for s in samples]) + + ret = { + "input_ids": input_ids, + "type_ids": type_ids, + "dpe_ids": dpe_ids + } + + return ret def __getitem__(self, idx): - patient_index = self.manifest["last_sample_index"].searchsorted( - idx, side="right" - ) - subject_id = self.manifest.index[patient_index] - shard_id = self.manifest.iloc[patient_index]["shard_id"] - prev_idx = ( - 0 - if patient_index == 0 - else (self.manifest["last_sample_index"].iloc[patient_index - 1]) - ) - data = self.data[shard_id][str(subject_id)] + key = self.keys[idx] + shard_id = self.manifest[key] + data = self.data[shard_id][key]["sources"] # (3, 128) - input = data["hi"] - sample_idx_in_patient = idx - prev_idx - start = self.args.max_seq_len * sample_idx_in_patient - end = self.args.max_seq_len * (sample_idx_in_patient + 1) - label = torch.tensor([data["label"][()]]).float() return { - "input_ids": torch.LongTensor(input[:, 0, :][start:end]), - "type_ids": torch.LongTensor(input[:, 1, :][start:end]), - "dpe_ids": torch.LongTensor(input[:, 2, :][start:end]), - "times": torch.IntTensor(data["time"][start:end]), - "subject_id": subject_id, - "index": sample_idx_in_patient, - "label": label, + "input_ids": torch.LongTensor(data[0, :]), + "type_ids": torch.LongTensor(data[1, :]), + "dpe_ids": torch.LongTensor(data[2, :]), } @@ -331,21 +348,30 @@ def __init__(self, args, 
split, data_path, *pargs, **kwargs): self.args = args - if not split.endswith("_encoded"): - split = split + "_encoded" - - self.data = h5pickle.File(os.path.join(data_path, split + ".h5"))["ehr"] - self.manifest = list(self.data.keys()) + self.data = {} + #TODO ..? + for fname in glob.glob(os.path.join(data_path, split, f"*_encoded.h5")): + shard_id = int(os.path.splitext(fname)[0].split("_")[-2]) + self.data[shard_id] = h5pickle.File( + os.path.join(data_path, split, split + f"_{shard_id}_encoded.h5") + )["ehr"] + self.manifest = {} + for shard_id, data in self.data.items(): + keys = data.keys() + shard_manifest = {k: shard_id for k in keys} + self.manifest |= shard_manifest + self.keys = list(self.manifest.keys()) def __len__(self): - return len(self.data) + return len(self.manifest) def collate_fn(self, samples): ret = dict() max_sample_len = max([s["times"].shape[0] for s in samples]) - padding_to = min( - 2 ** math.ceil(math.log(max_sample_len, 2)), self.args.max_seq_len - ) + # padding_to = min( + # 2 ** math.ceil(math.log(max_sample_len, 2)), self.args.max_seq_len + # ) + padding_to = max_sample_len for k, v in samples[0].items(): if k == "times": @@ -369,7 +395,10 @@ def collate_fn(self, samples): return ret def __getitem__(self, idx): - data = self.data[self.manifest[idx]] + subject_id = self.keys[idx] + shard_id = self.manifest[subject_id] + data = self.data[shard_id][subject_id] + encoded = data["encoded"][:] times = data["time"][:] @@ -377,14 +406,19 @@ def __getitem__(self, idx): repr = torch.FloatTensor(encoded) times = torch.IntTensor(times) - times = max(times) - times label = torch.tensor([data["label"][()]]).float() + if len(repr) > self.args.max_seq_len: + repr = repr[-self.args.max_seq_len:, :] + times = times[-self.args.max_seq_len:] + # inverse times + times = max(times) - times + return { "repr": repr, "times": times, "label": label, - "subject_id": self.manifest[idx] + "subject_id": subject_id, } diff --git a/src/models/eventencoder.py 
b/src/models/eventencoder.py index ecf4eec..4576d15 100644 --- a/src/models/eventencoder.py +++ b/src/models/eventencoder.py @@ -31,6 +31,12 @@ def __init__(self, args): ) def forward(self, all_codes_embs, input_ids, **kwargs): + # all_codes_embs: (B * S, L, Hidden) -- (16 * 512, 128, 512) + # input_ids: (B, S, L) -- (16, 512, 128) + if input_ids.ndim == 2: + assert input_ids.size(0) == all_codes_embs.size(0) + input_ids = input_ids.unsqueeze(1) # (B, L) -> (B, 1, L) + B, S, L = input_ids.shape # All-padding col -> cause nan output -> unmask it (and multiply 0 to the results) src_pad_mask = (input_ids.reshape(-1, L).eq(0)) ^ ( diff --git a/src/models/model.py b/src/models/model.py index 6b25829..02d8e44 100644 --- a/src/models/model.py +++ b/src/models/model.py @@ -478,16 +478,16 @@ def __init__(self, args): self.time_encoder = FlattenTimeEncoding(args) self.layer_norm = nn.LayerNorm(args.pred_dim, eps=1e-12) - def forward(self, input_ids, type_ids, dpe_ids, times, **kwargs): - B, S = input_ids.shape[0], input_ids.shape[1] - + def forward(self, input_ids, type_ids, dpe_ids, times=None, **kwargs): x = self.input_ids_embedding(input_ids) x += self.type_ids_embedding(type_ids) x += self.dpe_ids_embedding(dpe_ids) if "flatten" in self.args.train_type: # (B, S, E) -> (B, S, E) x = self.time_encoder(x, times, **kwargs) - else: # (B, S, W, E) -> (B*S, W, E) - x = x.view(B * S, -1, self.args.pred_dim) + else: + if input_ids.ndim == 3: # x: (B, S, W, E) -> (B*S, W, E) + B, S, W = input_ids.shape + x = x.view(B * S, -1, self.args.pred_dim) x = self.pos_encoder(x) x = self.layer_norm(x) return x diff --git a/src/trainer/base.py b/src/trainer/base.py index 60a5e22..811aceb 100644 --- a/src/trainer/base.py +++ b/src/trainer/base.py @@ -2,12 +2,15 @@ import logging import os import uuid +import pickle from contextlib import nullcontext from shutil import rmtree +from datetime import timedelta import polars as pl import torch from accelerate import Accelerator +from accelerate 
import InitProcessGroupKwargs from accelerate.logging import get_logger from accelerate.utils import broadcast, set_seed from h5pickle import File @@ -39,7 +42,7 @@ def __init__(self, args): # make subject_id to be {subject_id}_{cohort_number} to prevent duplicated ids if os.path.isdir(self.test_cohort): test_cohort = pl.read_parquet( - os.path.join(self.test_cohort, "*.parquet") + os.path.join(self.test_cohort, self.test_subset, "*.parquet") ) else: test_cohort = pl.read_parquet(self.test_cohort) @@ -64,13 +67,18 @@ def __init__(self, args): self.log = None if self.args.debug or not self.args.wandb else "wandb" def run(self): + ipg_handler = InitProcessGroupKwargs(timeout=timedelta(hours=24)) self.accelerator = Accelerator( - log_with=self.log, split_batches=True, mixed_precision="bf16" + kwargs_handlers=[ipg_handler], log_with=self.log, split_batches=True, mixed_precision="bf16" ) self.args.local_batch_size = ( self.args.batch_size // self.accelerator.num_processes ) - if self.args.resume_name: + if self.args.src_data == "meds": + if self.args.save_dir.endswith("/"): + self.args.save_dir = self.args.save_dir[:-1] + self.args.exp_name = os.path.basename(self.args.save_dir) + "_" + str(self.args.seed) + elif self.args.resume_name: self.args.exp_name = self.args.resume_name else: self.args.exp_name = f"{uuid.uuid4().hex}_{self.args.seed}" @@ -120,6 +128,13 @@ def run(self): if self.args.encode_events or self.args.encode_only: if self.args.src_data == "meds": + assert self.args.encode_events and self.args.encode_only, ( + "encoding MEDS dataset should be run with both the `self.args.encode_events` " + "and `self.args.encode_only` being True." + ) + assert self.args.unique_events_path is not None, ( + "`--unique_events_path` shuold be provided to encode MEDS dataset." 
+ ) self.encode_events_meds() else: self.encode_events() @@ -148,6 +163,9 @@ def train(self): valid_loader = self.dataloader_set(self.valid_subset) model = self.architecture(self.args) + assert self.args.pretrained is None or self.args.resume_name is None, ( + "--pretrained and --resume_name should not be provided together" + ) if self.args.pretrained and not self.args.no_pretrained_checkpoint: if self.args.src_data == "meds": pretrained_path = os.path.join( @@ -158,6 +176,9 @@ def train(self): self.args.save_dir, self.args.pretrained, "checkpoint_best.pt" ) model = load_model(pretrained_path, model) + elif self.args.src_data == "meds" and self.args.resume_name is not None: + resume_path = os.path.join(self.args.resume_name, "checkpoint_last.pt") + model = load_model(resume_path, model) if self.args.enable_fsdp: model = self.accelerator.prepare(model) @@ -277,8 +298,7 @@ def test(self): .map_elements(lambda x: x.split("_")[0], return_dtype=pl.String) .cast(int) ) - exp_name = os.path.basename(self.args.exp_name) - save_dir = os.path.join(self.args.save_dir, exp_name) + save_dir = self.args.save_dir save_path = os.path.join(save_dir, f"{self.test_subset}.parquet") if not os.path.exists(save_dir): os.makedirs(save_dir) @@ -295,7 +315,7 @@ def test(self): return metric_dict def dataloader_set(self, split): - if split is None: + if self.args.src_data == "meds" and split is None: return None dataset = self.dataset(self.args, split, self.data, self.df) return DataLoader( @@ -303,7 +323,7 @@ def dataloader_set(self, split): batch_size=self.args.batch_size, shuffle=False, num_workers=8, - pin_memory=True, + pin_memory=False, collate_fn=dataset.collate_fn, persistent_workers=True, ) @@ -340,9 +360,8 @@ def epoch(self, split, data_loader, n_epoch=0): metrics = self.metric.get_metrics() log_dict = log_from_dict(metrics, split, n_epoch) - if self.log is None: - print(log_dict) - else: + logger.info(log_dict) + if self.log is not None: self.accelerator.log(log_dict) return 
metrics @@ -356,8 +375,10 @@ def encode_events(self): model = self.architecture(self.args) model = load_model(best_model_path, model) self.model = self.accelerator.prepare(model) - self.args.max_seq_len = 1024 - self.args.batch_size = 8 + # self.args.max_seq_len = 1024 + # self.args.batch_size = 8 + self.args.max_seq_len = 512 + self.args.batch_size = 16 else: self.args.max_seq_len = 512 self.args.batch_size = 8 @@ -372,7 +393,7 @@ def _get_hdf5_path(i): postfix = "" if i == -1 else "_" + str(i) return os.path.join( self.args.save_dir, - self.args.exp_name, + # self.args.exp_name, f"{self.args.src_data}_encoded{postfix}.h5", ) @@ -387,22 +408,20 @@ def _get_hdf5_path(i): k = str(k) stay_g = encoded.create_group(k) stay_g.create_dataset( - "encoded", - shape=(len(self.data["ehr"][k]["time"]), self.args.pred_dim), - dtype="i2", - compression="lzf", - shuffle=True, - chunks=(len(self.data["ehr"][k]["time"]), self.args.pred_dim), + "time", data=self.data["ehr"][k]["time"][()] ) + stay_g.attrs.update(self.data["ehr"][k].attrs) self.accelerator.wait_for_everyone() with torch.no_grad(): loader = enumerate(dataloader) if self.accelerator.is_main_process: - loader = tqdm(loader) + loader = tqdm(loader, total=len(dataloader)) + buffer = {} for i, batch in loader: self.step = i all_codes_embs = self.model.input2emb_model(**batch) + # (16, 512, 128) -> (16, 512, 128, 512) -> (16 * 512, 128, 512) reprs = self.model.eventencoder_model( all_codes_embs, **batch @@ -411,26 +430,36 @@ def _get_hdf5_path(i): stay_ids = batch["stay_id"].cpu().numpy().reshape(-1) indices = batch["index"].cpu().numpy().reshape(-1) for repr, stay_id, index in zip(reprs, stay_ids, indices): - stay_id, index = int(stay_id), int(index) start = self.args.max_seq_len * index end = start + self.args.max_seq_len - max_len = encoded[str(stay_id)]["encoded"].shape[0] + max_len = dataloader.dataset.df["time"].loc[stay_id] if end > max_len: repr = repr[: max_len - start, :] end = max_len - 
encoded[str(stay_id)]["encoded"][start:end, :] = repr + if stay_id not in buffer: + buffer[stay_id] = [] + heapq.heappush(buffer[stay_id], (index, repr)) + if ((i + 1) % 100 == 0) or ((i + 1) == len(loader)): + for stay_id in list(buffer.keys()): + items = buffer[stay_id] + num_events = dataloader.dataset.df["time"].loc[stay_id] + num_samples = dataloader.dataset.df["num_sample_per_pat"].loc[stay_id] + if len(items) == num_samples: + data = np.concatenate([x[1] for x in items]) + encoded[str(stay_id)].create_dataset( + "encoded", + data=data, + dtype="i2", + compression="lzf", + shuffle=True, + chunks=(num_events, self.args.pred_dim) + ) + del buffer[stay_id] f.close() self.accelerator.wait_for_everyone() if self.accelerator.num_processes == 1: os.rename(hdf5_path, _get_hdf5_path(-1)) - main_file = File(_get_hdf5_path(-1), "r+") - for k in tqdm(self.data["ehr"].keys()): - main_file["ehr"][k].create_dataset( - "time", data=self.data["ehr"][k]["time"][()] - ) - main_file["ehr"][k].attrs.update(self.data["ehr"][k].attrs) - main_file.close() else: if self.accelerator.is_main_process: main_file = File(_get_hdf5_path(-1), "w") @@ -470,138 +499,79 @@ def _get_hdf5_path(i): # to add compatibility with meds dataset def encode_events_meds(self): - if self.args.encode_only: - best_model_path = os.path.join(self.args.resume_name, "checkpoint_best.pt") - else: - best_model_path = os.path.join( - self.args.save_dir, self.args.exp_name, "checkpoint_best.pt" - ) - if self.args.train_type != "bioclinicalbert_encode": - model = self.architecture(self.args) - model = load_model(best_model_path, model) - self.model = self.accelerator.prepare(model) - self.args.max_seq_len = 1024 - self.args.batch_size = 8 - else: - self.args.max_seq_len = 512 - self.args.batch_size = 8 + best_model_path = os.path.join(self.args.resume_name, "checkpoint_best.pt") + model = self.architecture(self.args) + model = load_model(best_model_path, model) + self.model = self.accelerator.prepare(model) 
self.model.eval() - self.dataset = MEDSForReprGen - - for split in [self.train_subset, self.valid_subset, self.test_subset]: - if split is None: - continue - logger.info(f"Start Encoding for {split} split") - - dataloader = self.dataloader_set(split) - dataloader = self.accelerator.prepare(dataloader) - - def _get_hdf5_path(i): - postfix = "" if i == -1 else "_" + str(i) - return os.path.join( - self.args.save_dir, - f"{split}_encoded{postfix}.h5", - ) + logger.info( + f"Start to generate representation vectors for each of unique events in MEDS dataset" + ) + dataset = MEDSForReprGen(self.args, self.args.unique_events_path) + dataloader = DataLoader( + dataset, + batch_size=self.args.batch_size, + shuffle=False, + num_workers=8, + pin_memory=False, + collate_fn=dataset.collate_fn, + persistent_workers=True, + ) + dataloader = self.accelerator.prepare(dataloader) - hdf5_path = _get_hdf5_path(self.accelerator.local_process_index) - logger.info("Writing metadata to HDF5") + event_to_vec = {} + with torch.no_grad(): + loader = enumerate(dataloader) + loader = tqdm( + loader, total=len(dataloader), desc=str(self.accelerator.local_process_index) + ) + for i, batch in loader: + self.step = i + + embedded = self.accelerator.unwrap_model(self.model).input2emb_model(**batch) + event_vectors = ( + self.accelerator.unwrap_model(self.model).eventencoder_model(embedded, **batch) + ).squeeze(1) # (B, 1, E) -> (B, E) + event_vectors = event_vectors.cpu().bfloat16().view(torch.int16).numpy() + + input_ids = batch["input_ids"].cpu().numpy() # (B, 128) + for j, event in enumerate(input_ids): + event_tuple = tuple(event[event != 0]) + event_to_vec[event_tuple] = event_vectors[j] + + def get_local_path(i): + postfix = "" if i == -1 else "_" + str(i) + return os.path.join(self.args.save_dir, f"event_to_vec{postfix}.pkl") - with File(hdf5_path, "w") as f: - f.create_group("ehr") - encoded = f["ehr"] + local_path = get_local_path(self.accelerator.local_process_index) + logger.info("Saving 
the resulted vector maps...") + with open(local_path, "wb") as f: + pickle.dump(event_to_vec, f) + logger.info("Done!") + self.accelerator.wait_for_everyone() - for subject_id, metadata in dataloader.dataset.manifest.iterrows(): - stay_g = encoded.create_group(subject_id) - stay_g.create_dataset( - "time", - data=dataloader.dataset.data[metadata["shard_id"]][subject_id]["time"][()], - dtype="i" - ) - stay_g.create_dataset( - "label", - data=dataloader.dataset.data[metadata["shard_id"]][subject_id]["label"][()], - ) - self.accelerator.wait_for_everyone() - - with torch.no_grad(): - loader = enumerate(dataloader) - if self.accelerator.is_main_process: - loader = tqdm(loader, total=len(dataloader)) - buffer = {} - for i, batch in loader: - self.step = i - all_codes_embs = self.model.input2emb_model(**batch) - - reprs = self.model.eventencoder_model( - all_codes_embs, **batch - ) # B, S, E - reprs = reprs.cpu().bfloat16().view(torch.int16).numpy() - subject_ids = batch["subject_id"].reshape(-1) - indices = batch["index"].reshape(-1) - for repr, subject_id, index, in zip(reprs, subject_ids, indices): - start = self.args.max_seq_len * index - end = start + self.args.max_seq_len - max_len = dataloader.dataset.manifest.loc[subject_id]["num_events"] - if end > max_len: - repr = repr[: max_len - start, :] - end = max_len - if subject_id not in buffer: - buffer[subject_id] = [] - heapq.heappush(buffer[subject_id], (index, repr)) - if ((i + 1) % 100 == 0) or ((i + 1) == len(loader)): - # flush buffer if applicable - for subject_id in list(buffer.keys()): - items = buffer[subject_id] - metadata = dataloader.dataset.manifest.loc[subject_id] - if len(items) == metadata["num_samples"]: - data = np.concatenate([x[1] for x in items]) - encoded[subject_id].create_dataset( - "encoded", - data=data, - dtype="i2", - compression="lzf", - shuffle=True, - chunks=( - metadata["num_events"], - self.args.pred_dim, - ), - ) - del buffer[subject_id] + if self.accelerator.num_processes == 1: + 
if os.path.exists(get_local_path(-1)): + os.remove(get_local_path(-1)) + os.rename(local_path, get_local_path(-1)) + else: + if self.accelerator.is_main_process: + main_dict = {} + local_dicts = [] + for i in range(self.accelerator.num_processes): + with open(get_local_path(i), "rb") as local_f: + local_dicts.append(pickle.load(local_f)) + logger.info("Gathering and summarizing local vector maps...") + for local_dict in local_dicts: + # NOTE only work in python >= 3.9.0 + main_dict = main_dict | local_dict + + if os.path.exists(get_local_path(-1)): + os.remove(get_local_path(-1)) + with open(get_local_path(-1), "wb") as main_f: + pickle.dump(main_dict, main_f) - self.accelerator.wait_for_everyone() - if self.accelerator.num_processes == 1: - os.rename(hdf5_path, _get_hdf5_path(-1)) - else: - if self.accelerator.is_main_process: - main_file = File(_get_hdf5_path(-1), "w") - main_file.create_group("ehr") - files = [ - File(_get_hdf5_path(i), "r") - for i in range(self.accelerator.num_processes) - ] - for k in tqdm(dataloader.dataset.manifest.index): - # Chunkwise sum, but may be duplicated chunks - encodeds = ( - np.stack( - [f["ehr"][k]["encoded"][()] for f in files], axis=0 - ) - .astype(np.uint16) - .max(axis=0) - .astype(np.int16) - ) - main_file["ehr"].create_group(k) - main_file["ehr"][k].create_dataset( - "encoded", - data=encodeds, - dtype="i2", - compression="lzf", - shuffle=True, - chunks=encodeds.shape, - ) - for f in files: - f.close() - main_file.close() - for i in range(self.accelerator.num_processes): - os.remove(_get_hdf5_path(i)) - self.accelerator.wait_for_everyone() + for i in range(self.accelerator.num_processes): + os.remove(get_local_path(i)) + self.accelerator.wait_for_everyone() diff --git a/src/trainer/remed.py b/src/trainer/remed.py index 5b1c563..0519fc0 100644 --- a/src/trainer/remed.py +++ b/src/trainer/remed.py @@ -89,10 +89,10 @@ def step(sample): for sample in t: if split == self.train_subset: - self.model.set_mode("scorer") + 
self.accelerator.unwrap_model(self.model).set_mode("scorer") net_output, logging_output = step(sample) - self.model.set_mode("predictor") + self.accelerator.unwrap_model(self.model).set_mode("predictor") net_output, logging_output = step(sample) # meds -- output @@ -105,6 +105,14 @@ def step(sample): if do_output_cohort: predicted_cohort = pl.DataFrame(predicted_cohort) self.test_cohort = self.test_cohort.join(predicted_cohort, on="subject_id", how="left") + self.test_cohort = self.test_cohort.select( + [ + pl.col("boolean_prediction"), + pl.col("subject_id"), + pl.col("prediction_time"), + pl.col("boolean_value") + ] + ) metrics = self.metric.get_metrics() log_dict = log_from_dict(metrics, split, n_epoch) @@ -112,4 +120,4 @@ def step(sample): print(log_dict) else: self.accelerator.log(log_dict) - return metrics + return metrics \ No newline at end of file From b085187d3c0fa187e4e487624aa9a06fc1686bfa Mon Sep 17 00:00:00 2001 From: Jwoo5 Date: Thu, 3 Oct 2024 15:38:41 +0900 Subject: [PATCH 2/8] Explicitly exclude path including `"unique_events"` --- scripts/meds/map_events_to_vec.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/meds/map_events_to_vec.py b/scripts/meds/map_events_to_vec.py index 59bcfa2..3a942d5 100644 --- a/scripts/meds/map_events_to_vec.py +++ b/scripts/meds/map_events_to_vec.py @@ -65,7 +65,7 @@ def main(args): subdirs = [ os.path.relpath(os.path.dirname(p), os.path.abspath(args.root)) - for p in filelist + for p in filelist if "unique_events" not in p ] output_dirs = [ os.path.join(args.output_dir, subdir) for subdir in subdirs From 78170e86faa4040f2364e6a9e79b381c001392b0 Mon Sep 17 00:00:00 2001 From: Junu Kim Date: Fri, 4 Oct 2024 15:53:17 +0900 Subject: [PATCH 3/8] Formatting + Python3.8 --- scripts/meds/extract_unique_events.py | 55 ++++++++++-------- scripts/meds/map_events_to_vec.py | 56 ++++++++++-------- scripts/meds/process_meds.py | 82 +++++++++++++++++---------- src/dataset.py | 32 +++++------ 
src/models/eventencoder.py | 2 +- src/models/model.py | 4 +- src/trainer/base.py | 66 +++++++++++++-------- src/trainer/remed.py | 16 +++--- 8 files changed, 183 insertions(+), 130 deletions(-) diff --git a/scripts/meds/extract_unique_events.py b/scripts/meds/extract_unique_events.py index 623b95b..809f509 100644 --- a/scripts/meds/extract_unique_events.py +++ b/scripts/meds/extract_unique_events.py @@ -1,30 +1,32 @@ -import os -import sys import glob +import logging import math +import multiprocessing +import os import shutil +import sys +from argparse import ArgumentParser from typing import List -import multiprocessing + import h5pickle import numpy as np -import logging -from argparse import ArgumentParser from tqdm import tqdm logging.basicConfig( format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=os.environ.get("LOGLEVEL", "INFO").upper(), - stream=sys.stdout + stream=sys.stdout, ) logger = logging.getLogger(__name__) + def get_parser(): parser = ArgumentParser() parser.add_argument( "root", help="path to the **processed** MEDS dataset containing subdirectories for each split. " - "it will try to scan all **/*.h5 files existed in this directory and process them." + "it will try to scan all **/*.h5 files existed in this directory and process them.", ) parser.add_argument( "--output_dir", @@ -37,18 +39,19 @@ def get_parser(): metavar="N", default=1, type=int, - help="number of parallel workers." 
+ help="number of parallel workers.", ) parser.add_argument( "--n_events_per_shard", metavar="N", default=1000000, type=int, - help="number of events included for each shard" + help="number of events included for each shard", ) return parser + def main(args): filelist = glob.glob(os.path.join(args.root, "**/*.h5")) files = [h5pickle.File(fname) for fname in filelist] @@ -80,25 +83,29 @@ def main(args): end = min((shard_id + 1) * args.n_events_per_shard, len(unique_events)) sharded_unique_events = unique_events[start:end] with h5pickle.File( - os.path.join(args.output_dir, "unique_events", f"unique_events_{shard_id}.h5"), "w" + os.path.join( + args.output_dir, "unique_events", f"unique_events_{shard_id}.h5" + ), + "w", ) as f: - for i, event_tuple in tqdm(enumerate(sharded_unique_events), total=len(sharded_unique_events)): + for i, event_tuple in tqdm( + enumerate(sharded_unique_events), total=len(sharded_unique_events) + ): idx = str(shard_id * args.n_events_per_shard + i) data = f.create_group(idx) - sources = np.stack([ - np.array(event_tuple[0]), - np.array(event_tuple[1]), - np.array(event_tuple[2]) - ]) + sources = np.stack( + [ + np.array(event_tuple[0]), + np.array(event_tuple[1]), + np.array(event_tuple[2]), + ] + ) data.create_dataset( - "sources", - data=sources, - dtype="i2", - compression="lzf", - shuffle=True + "sources", data=sources, dtype="i2", compression="lzf", shuffle=True ) + def _extract_unique_events(files: List[h5pickle.File]): unique_events = [] pbar = tqdm(files, total=len(files)) @@ -114,7 +121,8 @@ def _extract_unique_events(files: List[h5pickle.File]): tuple(input_id), tuple(type_id), tuple(dpe_id), - ) for input_id, type_id, dpe_id in zip(input_ids, type_ids, dpe_ids) + ) + for input_id, type_id, dpe_id in zip(input_ids, type_ids, dpe_ids) ] event_tokens = list(np.unique(event_tokens, axis=0)) unique_events.extend(event_tokens) @@ -122,7 +130,8 @@ def _extract_unique_events(files: List[h5pickle.File]): return unique_events + if __name__ 
== "__main__": parser = get_parser() args = parser.parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/scripts/meds/map_events_to_vec.py b/scripts/meds/map_events_to_vec.py index 3a942d5..211a9be 100644 --- a/scripts/meds/map_events_to_vec.py +++ b/scripts/meds/map_events_to_vec.py @@ -1,44 +1,49 @@ -import os -import sys import functools import glob import logging -import shutil -from typing import List import multiprocessing +import os import pickle +import shutil +import sys +from argparse import ArgumentParser +from typing import List + import h5pickle import numpy as np -from argparse import ArgumentParser from tqdm import tqdm logging.basicConfig( format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=os.environ.get("LOGLEVEL", "INFO").upper(), - stream=sys.stdout + stream=sys.stdout, ) logger = logging.getLogger(__name__) global_event_to_vec = None + def init_worker(event_to_vec): global global_event_to_vec global_event_to_vec = event_to_vec - logger.info(f"Process {multiprocessing.current_process().name} initialized with data") + logger.info( + f"Process {multiprocessing.current_process().name} initialized with data" + ) + def get_parser(): parser = ArgumentParser() parser.add_argument( "root", help="path to the **processed** MEDS dataset containing subdirectories for each split. " - "it will try to scan all **/*.h5 files existed in this directory except for " - "unique_events_*.h5 and process them." 
+ "it will try to scan all **/*.h5 files existed in this directory except for " + "unique_events_*.h5 and process them.", ) parser.add_argument( "--map_dir", help="path to the directory containing **`event_to_vec.pkl`**, which is generated by " - "running the model with `--encode_events=True` and `--encode_only=True`" + "running the model with `--encode_events=True` and `--encode_only=True`", ) parser.add_argument( "--output_dir", @@ -51,11 +56,12 @@ def get_parser(): metavar="N", default=1, type=int, - help="number of parallel workers." + help="number of parallel workers.", ) return parser + def main(args): filelist = glob.glob(os.path.join(args.root, "**/*.h5")) files = [h5pickle.File(fname) for fname in filelist if "unique_events" not in fname] @@ -65,11 +71,10 @@ def main(args): subdirs = [ os.path.relpath(os.path.dirname(p), os.path.abspath(args.root)) - for p in filelist if "unique_events" not in p - ] - output_dirs = [ - os.path.join(args.output_dir, subdir) for subdir in subdirs + for p in filelist + if "unique_events" not in p ] + output_dirs = [os.path.join(args.output_dir, subdir) for subdir in subdirs] for subdir in np.unique(subdirs): if os.path.exists(os.path.join(args.output_dir, subdir)): shutil.rmtree(os.path.join(args.output_dir, subdir)) @@ -90,6 +95,7 @@ def main(args): pool.close() pool.join() + def _map_events_to_vec(output_dirs: List[str], files: List[h5pickle.File]): assert len(files) == len(output_dirs) for f, output_dir in zip(files, output_dirs): @@ -98,13 +104,11 @@ def _map_events_to_vec(output_dirs: List[str], files: List[h5pickle.File]): with h5pickle.File(os.path.join(output_dir, output_name), "w") as output_f: output_f.create_group("ehr") for sbj_id in tqdm(f["ehr"], total=len(f["ehr"]), desc=output_name): - input_ids = f["ehr"][sbj_id]["hi"][:, 0] # S, 128 - event_tuples = [ - tuple(event[event != 0]) for event in input_ids - ] - event_vectors = np.array([ - global_event_to_vec[event] for event in event_tuples - ]) + input_ids = 
f["ehr"][sbj_id]["hi"][:, 0] # S, 128 + event_tuples = [tuple(event[event != 0]) for event in input_ids] + event_vectors = np.array( + [global_event_to_vec[event] for event in event_tuples] + ) output_f["ehr"].create_group(sbj_id) output_f["ehr"][sbj_id].create_dataset( "encoded", @@ -112,16 +116,18 @@ def _map_events_to_vec(output_dirs: List[str], files: List[h5pickle.File]): dtype="i2", compression="lzf", shuffle=True, - chunks=event_vectors.shape + chunks=event_vectors.shape, ) output_f["ehr"][sbj_id].create_dataset( - "time", data=f["ehr"][sbj_id]["time"][()], + "time", + data=f["ehr"][sbj_id]["time"][()], ) output_f["ehr"][sbj_id].create_dataset( "label", data=f["ehr"][sbj_id]["label"][()] ) + if __name__ == "__main__": parser = get_parser() args = parser.parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/scripts/meds/process_meds.py b/scripts/meds/process_meds.py index cd83a37..d31b327 100644 --- a/scripts/meds/process_meds.py +++ b/scripts/meds/process_meds.py @@ -5,7 +5,6 @@ import os import re import shutil - import warnings from argparse import ArgumentParser from bisect import bisect_left, bisect_right @@ -22,6 +21,7 @@ pool_manager = multiprocessing.Manager() warned_codes = pool_manager.list() + def find_boundary_between(tuples_list, start, end): starts = [s for s, e in tuples_list] ends = [e for s, e in tuples_list] @@ -32,6 +32,7 @@ def find_boundary_between(tuples_list, start, end): return start_index, end_index + def get_parser(): parser = ArgumentParser() parser.add_argument( @@ -84,11 +85,12 @@ def get_parser(): "--mimic_dir", default=None, help="path to directory for MIMIC-IV database containing hosp/ and icu/ as a subdirectory. " - "this is used for addressing missing descriptions in the metadata for MIMIC-IV codes." 
+ "this is used for addressing missing descriptions in the metadata for MIMIC-IV codes.", ) return parser + def main(args): root_path = Path(args.root) output_dir = Path(args.output_dir) @@ -179,15 +181,18 @@ def main(args): cohort = cohort.unique() cohort = cohort.select( - [pl.col("subject_id"), + [ + pl.col("subject_id"), pl.col(label_col_name), # pl.col("input.end_summary").struct.field("timestamp_at_start").alias("starttime"), pl.col("prediction_time").alias("endtime"), ] ) - cohort = cohort.group_by( - "subject_id", maintain_order=True - ).agg(pl.col(["endtime", label_col_name])).collect() # omitted "starttime" + cohort = ( + cohort.group_by("subject_id", maintain_order=True) + .agg(pl.col(["endtime", label_col_name])) + .collect() + ) # omitted "starttime" cohort_dict = { x["subject_id"]: { # "starttime": x["starttime"], @@ -229,16 +234,23 @@ def extract_cohort(row): return {"cohort_end": None, "cohort_label": None} data = data.group_by(["subject_id", "time"], maintain_order=True).agg(pl.all()) - data = data.with_columns( - pl.struct(["subject_id", "time"]) - .map_elements( - extract_cohort, - return_dtype=pl.Struct( - {"cohort_end": pl.List(pl.Datetime()), "cohort_label": pl.List(pl.Boolean)} + data = ( + data.with_columns( + pl.struct(["subject_id", "time"]) + .map_elements( + extract_cohort, + return_dtype=pl.Struct( + { + "cohort_end": pl.List(pl.Datetime()), + "cohort_label": pl.List(pl.Boolean), + } + ), ) + .alias("cohort_criteria") ) - .alias("cohort_criteria") - ).unnest("cohort_criteria").collect() + .unnest("cohort_criteria") + .collect() + ) data = data.drop_nulls("cohort_label") @@ -264,7 +276,12 @@ def extract_cohort(row): manifest_f.write("subject_id\tnum_events\tshard_id\n") must_have_columns = [ - "subject_id", "cohort_end", "cohort_label", "time", "code", "numeric_value" + "subject_id", + "cohort_end", + "cohort_label", + "time", + "code", + "numeric_value", ] rest_of_columns = [x for x in data.columns if x not in must_have_columns] 
column_name_idcs = {col: i for i, col in enumerate(data.columns)} @@ -318,6 +335,7 @@ def extract_cohort(row): for subject_id, (length, shard_id) in length_per_subject.items(): manifest_f.write(f"{subject_id}\t{length}\t{shard_id}\n") + def meds_to_remed( tokenizer, rest_of_columns, @@ -329,7 +347,7 @@ def meds_to_remed( d_items, d_labitems, warned_codes, - df_chunk + df_chunk, ): code_matching_pattern = re.compile(r"\d+") @@ -350,13 +368,17 @@ def meds_to_remed_unit(row): if col_event is not None: col_event = str(col_event) if col_name == "code": - if col_event in codes_metadata and codes_metadata[col_event] != "": + if ( + col_event in codes_metadata + and codes_metadata[col_event] != "" + ): col_event = codes_metadata[col_event] else: do_break = False items = col_event.split("//") is_code = [ - bool(code_matching_pattern.fullmatch(item)) for item in items + bool(code_matching_pattern.fullmatch(item)) + for item in items ] if True in is_code: if d_items is not None and d_labitems is not None: @@ -377,12 +399,12 @@ def meds_to_remed_unit(row): do_break = True if do_break and col_event not in warned_codes: - warned_codes.append(col_event) - warnings.warn( - "The dataset contains some codes that are not specified in " - "the codes metadata, which may not be intended. Note that we " - f"process this code as it is for now: {col_event}." - ) + warned_codes.append(col_event) + warnings.warn( + "The dataset contains some codes that are not specified in " + "the codes metadata, which may not be intended. Note that we " + f"process this code as it is for now: {col_event}." 
+ ) else: col_event = re.sub( r"\d*\.\d+", @@ -521,14 +543,13 @@ def meds_to_remed_unit(row): ) events_data = np.concatenate(events_data) - df_chunk = df_chunk.select( - ["subject_id", "cohort_end", "cohort_label", "time"] - ) + df_chunk = df_chunk.select(["subject_id", "cohort_end", "cohort_label", "time"]) df_chunk = df_chunk.insert_column(4, data_index) df_chunk = df_chunk.explode(["cohort_end", "cohort_label"]) df_chunk = df_chunk.group_by( # ["subject_id", "cohort_start", "cohort_end", "cohort_label"], maintain_order=True - ["subject_id", "cohort_end", "cohort_label"], maintain_order=True + ["subject_id", "cohort_end", "cohort_label"], + maintain_order=True, ).agg(pl.all()) # regard {subject_id} as {cohort_id}: {subject_id}_{cohort_number} @@ -536,7 +557,9 @@ def meds_to_remed_unit(row): pl.col("subject_id").cum_count().over("subject_id").alias("suffix") ) df_chunk = df_chunk.with_columns( - (pl.col("subject_id").cast(str) + "_" + pl.col("suffix").cast(str)).alias("subject_id") + (pl.col("subject_id").cast(str) + "_" + pl.col("suffix").cast(str)).alias( + "subject_id" + ) ) # data = data.drop("suffix", "cohort_start", "cohort_end") df_chunk = df_chunk.drop("suffix", "cohort_end") @@ -579,6 +602,7 @@ def meds_to_remed_unit(row): return length_per_subject + if __name__ == "__main__": parser = get_parser() args = parser.parse_args() diff --git a/src/dataset.py b/src/dataset.py index e80d3d7..311f87d 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -72,7 +72,7 @@ def collate_fn(self, out): else: padded = pad_sequence([i[k] for i in out], batch_first=True) ret[k] = pad(padded, (0, 0, 0, padding_to - padded.shape[1])) - + return ret def get_labels(self, data): @@ -249,7 +249,7 @@ def collate_fn(self, samples): ret[k]["meds_single_task"] = torch.FloatTensor( torch.stack([s["label"] for s in samples]) ) - elif k in ["subject_id", "index"]: # for MEDSForReprGen + elif k in ["subject_id", "index"]: # for MEDSForReprGen ret[k] = np.array([s[k] for s in samples]) else: 
padded = pad_sequence([s[k] for s in samples], batch_first=True) @@ -266,7 +266,7 @@ def __getitem__(self, idx): # assume it is a scalar value for a binary classification task label = torch.tensor([data["label"][()]]).float() - #XXX + # XXX max_num_events = 300000 if self.args.max_seq_len < len(input): length = len(input) @@ -275,21 +275,22 @@ def __getitem__(self, idx): if self.args.random_sample: indices = random.sample( - range(max(0, length - max_num_events), length), self.args.max_seq_len + range(max(0, length - max_num_events), length), + self.args.max_seq_len, ) indices.sort() input = input[indices, :, :] times = times[indices] else: - input = input[-self.args.max_seq_len:, :, :] - times = times[-self.args.max_seq_len:] + input = input[-self.args.max_seq_len :, :, :] + times = times[-self.args.max_seq_len :] return { "input_ids": torch.LongTensor(input[:, 0, :]), "type_ids": torch.LongTensor(input[:, 1, :]), "dpe_ids": torch.LongTensor(input[:, 2, :]), "times": torch.IntTensor(times), - "label": label + "label": label, } @@ -311,7 +312,7 @@ def __init__(self, args, data_path): for shard_id, data in self.data.items(): keys = data.keys() shard_manifest = {k: shard_id for k in keys} - self.manifest |= shard_manifest + self.manifest.update(**shard_manifest) self.keys = list(self.manifest.keys()) def __len__(self): @@ -322,18 +323,14 @@ def collate_fn(self, samples): type_ids = torch.stack([s["type_ids"] for s in samples]) dpe_ids = torch.stack([s["dpe_ids"] for s in samples]) - ret = { - "input_ids": input_ids, - "type_ids": type_ids, - "dpe_ids": dpe_ids - } + ret = {"input_ids": input_ids, "type_ids": type_ids, "dpe_ids": dpe_ids} return ret def __getitem__(self, idx): key = self.keys[idx] shard_id = self.manifest[key] - data = self.data[shard_id][key]["sources"] # (3, 128) + data = self.data[shard_id][key]["sources"] # (3, 128) return { "input_ids": torch.LongTensor(data[0, :]), @@ -349,7 +346,6 @@ def __init__(self, args, split, data_path, *pargs, **kwargs): 
self.args = args self.data = {} - #TODO ..? for fname in glob.glob(os.path.join(data_path, split, f"*_encoded.h5")): shard_id = int(os.path.splitext(fname)[0].split("_")[-2]) self.data[shard_id] = h5pickle.File( @@ -359,7 +355,7 @@ def __init__(self, args, split, data_path, *pargs, **kwargs): for shard_id, data in self.data.items(): keys = data.keys() shard_manifest = {k: shard_id for k in keys} - self.manifest |= shard_manifest + self.manifest.update(**shard_manifest) self.keys = list(self.manifest.keys()) def __len__(self): @@ -409,8 +405,8 @@ def __getitem__(self, idx): label = torch.tensor([data["label"][()]]).float() if len(repr) > self.args.max_seq_len: - repr = repr[-self.args.max_seq_len:, :] - times = times[-self.args.max_seq_len:] + repr = repr[-self.args.max_seq_len :, :] + times = times[-self.args.max_seq_len :] # inverse times times = max(times) - times diff --git a/src/models/eventencoder.py b/src/models/eventencoder.py index 4576d15..6c859f2 100644 --- a/src/models/eventencoder.py +++ b/src/models/eventencoder.py @@ -35,7 +35,7 @@ def forward(self, all_codes_embs, input_ids, **kwargs): # input_ids: (B, S, L) -- (16, 512, 128) if input_ids.ndim == 2: assert input_ids.size(0) == all_codes_embs.size(0) - input_ids = input_ids.unsqueeze(1) # (B, L) -> (B, 1, L) + input_ids = input_ids.unsqueeze(1) # (B, L) -> (B, 1, L) B, S, L = input_ids.shape # All-padding col -> cause nan output -> unmask it (and multiply 0 to the results) diff --git a/src/models/model.py b/src/models/model.py index 02d8e44..3e9b4c0 100644 --- a/src/models/model.py +++ b/src/models/model.py @@ -484,8 +484,8 @@ def forward(self, input_ids, type_ids, dpe_ids, times=None, **kwargs): x += self.dpe_ids_embedding(dpe_ids) if "flatten" in self.args.train_type: # (B, S, E) -> (B, S, E) x = self.time_encoder(x, times, **kwargs) - else: - if input_ids.ndim == 3: # x: (B, S, W, E) -> (B*S, W, E) + else: + if input_ids.ndim == 3: # x: (B, S, W, E) -> (B*S, W, E) B, S, W = input_ids.shape x = 
x.view(B * S, -1, self.args.pred_dim) x = self.pos_encoder(x) diff --git a/src/trainer/base.py b/src/trainer/base.py index 811aceb..545a63d 100644 --- a/src/trainer/base.py +++ b/src/trainer/base.py @@ -1,16 +1,15 @@ import heapq import logging import os -import uuid import pickle +import uuid from contextlib import nullcontext -from shutil import rmtree from datetime import timedelta +from shutil import rmtree import polars as pl import torch -from accelerate import Accelerator -from accelerate import InitProcessGroupKwargs +from accelerate import Accelerator, InitProcessGroupKwargs from accelerate.logging import get_logger from accelerate.utils import broadcast, set_seed from h5pickle import File @@ -50,7 +49,11 @@ def __init__(self, args): pl.col("subject_id").cum_count().over("subject_id").alias("suffix") ) test_cohort = test_cohort.with_columns( - (pl.col("subject_id").cast(str) + "_" + pl.col("suffix").cast(str)).alias("subject_id") + ( + pl.col("subject_id").cast(str) + + "_" + + pl.col("suffix").cast(str) + ).alias("subject_id") ) test_cohort = test_cohort.drop("suffix") self.test_cohort = test_cohort @@ -69,7 +72,10 @@ def __init__(self, args): def run(self): ipg_handler = InitProcessGroupKwargs(timeout=timedelta(hours=24)) self.accelerator = Accelerator( - kwargs_handlers=[ipg_handler], log_with=self.log, split_batches=True, mixed_precision="bf16" + kwargs_handlers=[ipg_handler], + log_with=self.log, + split_batches=True, + mixed_precision="bf16", ) self.args.local_batch_size = ( self.args.batch_size // self.accelerator.num_processes @@ -77,7 +83,9 @@ def run(self): if self.args.src_data == "meds": if self.args.save_dir.endswith("/"): self.args.save_dir = self.args.save_dir[:-1] - self.args.exp_name = os.path.basename(self.args.save_dir) + "_" + str(self.args.seed) + self.args.exp_name = ( + os.path.basename(self.args.save_dir) + "_" + str(self.args.seed) + ) elif self.args.resume_name: self.args.exp_name = self.args.resume_name else: @@ -132,9 +140,9 @@ 
def run(self): "encoding MEDS dataset should be run with both the `self.args.encode_events` " "and `self.args.encode_only` being True." ) - assert self.args.unique_events_path is not None, ( - "`--unique_events_path` shuold be provided to encode MEDS dataset." - ) + assert ( + self.args.unique_events_path is not None + ), "`--unique_events_path` shuold be provided to encode MEDS dataset." self.encode_events_meds() else: self.encode_events() @@ -163,9 +171,9 @@ def train(self): valid_loader = self.dataloader_set(self.valid_subset) model = self.architecture(self.args) - assert self.args.pretrained is None or self.args.resume_name is None, ( - "--pretrained and --resume_name should not be provided together" - ) + assert ( + self.args.pretrained is None or self.args.resume_name is None + ), "--pretrained and --resume_name should not be provided together" if self.args.pretrained and not self.args.no_pretrained_checkpoint: if self.args.src_data == "meds": pretrained_path = os.path.join( @@ -407,9 +415,7 @@ def _get_hdf5_path(i): for k in self.data["ehr"].keys(): k = str(k) stay_g = encoded.create_group(k) - stay_g.create_dataset( - "time", data=self.data["ehr"][k]["time"][()] - ) + stay_g.create_dataset("time", data=self.data["ehr"][k]["time"][()]) stay_g.attrs.update(self.data["ehr"][k].attrs) self.accelerator.wait_for_everyone() @@ -443,7 +449,9 @@ def _get_hdf5_path(i): for stay_id in list(buffer.keys()): items = buffer[stay_id] num_events = dataloader.dataset.df["time"].loc[stay_id] - num_samples = dataloader.dataset.df["num_sample_per_pat"].loc[stay_id] + num_samples = dataloader.dataset.df["num_sample_per_pat"].loc[ + stay_id + ] if len(items) == num_samples: data = np.concatenate([x[1] for x in items]) encoded[str(stay_id)].create_dataset( @@ -452,7 +460,7 @@ def _get_hdf5_path(i): dtype="i2", compression="lzf", shuffle=True, - chunks=(num_events, self.args.pred_dim) + chunks=(num_events, self.args.pred_dim), ) del buffer[stay_id] f.close() @@ -524,18 +532,26 @@ 
def encode_events_meds(self): with torch.no_grad(): loader = enumerate(dataloader) loader = tqdm( - loader, total=len(dataloader), desc=str(self.accelerator.local_process_index) + loader, + total=len(dataloader), + desc=str(self.accelerator.local_process_index), ) for i, batch in loader: self.step = i - - embedded = self.accelerator.unwrap_model(self.model).input2emb_model(**batch) + + embedded = self.accelerator.unwrap_model(self.model).input2emb_model( + **batch + ) event_vectors = ( - self.accelerator.unwrap_model(self.model).eventencoder_model(embedded, **batch) - ).squeeze(1) # (B, 1, E) -> (B, E) + self.accelerator.unwrap_model(self.model).eventencoder_model( + embedded, **batch + ) + ).squeeze( + 1 + ) # (B, 1, E) -> (B, E) event_vectors = event_vectors.cpu().bfloat16().view(torch.int16).numpy() - input_ids = batch["input_ids"].cpu().numpy() # (B, 128) + input_ids = batch["input_ids"].cpu().numpy() # (B, 128) for j, event in enumerate(input_ids): event_tuple = tuple(event[event != 0]) event_to_vec[event_tuple] = event_vectors[j] @@ -566,7 +582,7 @@ def get_local_path(i): for local_dict in local_dicts: # NOTE only work in python >= 3.9.0 main_dict = main_dict | local_dict - + if os.path.exists(get_local_path(-1)): os.remove(get_local_path(-1)) with open(get_local_path(-1), "wb") as main_f: diff --git a/src/trainer/remed.py b/src/trainer/remed.py index 0519fc0..0cd6f4b 100644 --- a/src/trainer/remed.py +++ b/src/trainer/remed.py @@ -75,9 +75,9 @@ def step(sample): ): if self.accelerator.num_processes == 1: # check if test cohort is valid - assert set(data_loader.dataset.manifest) == set(self.test_cohort["subject_id"]), ( - "a set of patient ids in the test cohort should equal to that in the test dataset" - ) + assert set(data_loader.dataset.manifest) == set( + self.test_cohort["subject_id"] + ), "a set of patient ids in the test cohort should equal to that in the test dataset" predicted_cohort = {"subject_id": [], "boolean_prediction": []} do_output_cohort = 
True else: @@ -99,18 +99,20 @@ def step(sample): if do_output_cohort: predicted_cohort["subject_id"].extend(sample["subject_id"].tolist()) predicted_cohort["boolean_prediction"].extend( - net_output['pred']['meds_single_task'].view(-1).tolist() + net_output["pred"]["meds_single_task"].view(-1).tolist() ) if do_output_cohort: predicted_cohort = pl.DataFrame(predicted_cohort) - self.test_cohort = self.test_cohort.join(predicted_cohort, on="subject_id", how="left") + self.test_cohort = self.test_cohort.join( + predicted_cohort, on="subject_id", how="left" + ) self.test_cohort = self.test_cohort.select( [ pl.col("boolean_prediction"), pl.col("subject_id"), pl.col("prediction_time"), - pl.col("boolean_value") + pl.col("boolean_value"), ] ) @@ -120,4 +122,4 @@ def step(sample): print(log_dict) else: self.accelerator.log(log_dict) - return metrics \ No newline at end of file + return metrics From b0c7069abc91452633a23bd26fa324aef5e10844 Mon Sep 17 00:00:00 2001 From: Junu Kim Date: Mon, 7 Oct 2024 15:20:12 +0900 Subject: [PATCH 4/8] main.py formatting --- main.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/main.py b/main.py index fb8e71d..e76630d 100644 --- a/main.py +++ b/main.py @@ -35,35 +35,35 @@ def get_parser(): "--src_data", type=str, choices=["eicu", "mimiciv", "umcdb", "hirid", "meds"], - default="mimiciv" + default="mimiciv", ) parser.add_argument( "--train_subset", type=str, default="train", help="file name without extension to load data for the training. only used when" - "`--src_data` is set to `'meds'`." + "`--src_data` is set to `'meds'`.", ) parser.add_argument( "--valid_subset", type=str, default="tuning", help="file name without extension to load data for the validation. only used when" - "`--src_data` is set to `'meds'`." + "`--src_data` is set to `'meds'`.", ) parser.add_argument( "--test_subset", type=str, default="held_out", help="file name without extension to load data for the test. 
only used when `--src_data` "
-        "is set to `'meds'`."
+        "is set to `'meds'`.",
     )
     parser.add_argument(
         "--unique_events_path",
         type=str,
         default=None,
         help="path to directory containing `unique_events.h5` to encode events in MEDS dataset. "
-        "only used when `--src_data` is set to `'meds'`"
+        "only used when `--src_data` is set to `'meds'`",
     )
 
     parser.add_argument(
@@ -71,10 +71,10 @@ def get_parser():
         type=str,
         default=None,
         help="path to the test cohort, which must be a result of ACES. it can be either of "
-        "directory or the exact file path that has .parquet file extension. if provided with "
-        "directory, it tries to load `${test_subset}`/*.parquet files in the directory. "
-        "note that the set of patient ids in this cohort should be matched with that in the "
-        "test dataset"
+        "directory or the exact file path that has .parquet file extension. if provided with "
+        "directory, it tries to load `${test_subset}`/*.parquet files in the directory. "
+        "note that the set of patient ids in this cohort should be matched with that in the "
+        "test dataset",
     )
 
     parser.add_argument(
@@ -84,7 +84,7 @@ def get_parser():
             "readmission",
             "los_7",
             "los_14",
-            "mortality_1",
+            "mortality", "mortality_1",
             "mortality_2",
             "mortality_3",
             "mortality_7",
@@ -109,7 +109,7 @@ def get_parser():
             "sodium_1",
             "sodium_2",
             "sodium_3",
-            "meds_single_task"
+            "meds_single_task",
         ],
         default=[
             "readmission",
@@ -206,7 +206,9 @@ def get_parser():
     parser.add_argument("--debug", action="store_true")
     parser.add_argument("--log_loss", action="store_true")
     # Wandb
-    parser.add_argument("--wandb", action="store_true", help="whether to log using wandb")
+    parser.add_argument(
+        "--wandb", action="store_true", help="whether to log using wandb"
+    )
     parser.add_argument("--wandb_entity_name", type=str)
    parser.add_argument("--wandb_project_name", type=str, default="REMed")
     parser.add_argument("--pretrained", type=str, default=None)

From c0120bab97291e548d99fed7d08f96038dc11b7e Mon Sep 17 00:00:00 2001
From: 
Jwoo5 Date: Wed, 30 Oct 2024 21:47:21 +0900 Subject: [PATCH 5/8] Update guidelines for supporting MEDS dataset with GenHPF model --- README.md | 203 ++++++++++++------ scripts/meds/encode_events.sh | 16 +- scripts/meds/predict.sh | 42 ++-- .../meds/{pretrain.sh => pretrain_genhpf.sh} | 27 ++- scripts/meds/process_meds.py | 22 +- scripts/meds/train_genhpf.sh | 53 +++++ scripts/meds/{train.sh => train_remed.sh} | 27 ++- src/dataset.py | 11 +- src/trainer/base.py | 63 +++++- 9 files changed, 329 insertions(+), 135 deletions(-) rename scripts/meds/{pretrain.sh => pretrain_genhpf.sh} (64%) create mode 100644 scripts/meds/train_genhpf.sh rename scripts/meds/{train.sh => train_remed.sh} (62%) diff --git a/README.md b/README.md index c3782ab..779a901 100644 --- a/README.md +++ b/README.md @@ -155,10 +155,10 @@ accelerate launch \ ## Support for MEDS dataset > [!Caution] -> This instruction is still under progress, which may not be aligned with the recent updates. +> We are currently investigating the cause of lower performances seen in REMed model with the MEDS dataset. So please note that experiments with REMed model could be unstable for now. -We officially support to process [MEDS](https://github.com/Medical-Event-Data-Standard/meds/releases/tag/0.3.0) dataset (currently, MEDS v0.3) with a cohort defined by [ACES](https://github.com/justin13601/ACES), only for the REMed model. -It consists of 4 steps in total, each of which can be run by Python or shell scripts that are prepared in [`scripts/meds/`](scripts/meds/) directory. +We officially support to process [MEDS](https://github.com/Medical-Event-Data-Standard/meds) dataset with a cohort defined by [ACES](https://github.com/justin13601/ACES) for GenHPF and REMed model in this repository. +It consists of several steps, each of which can be run by Python or shell scripts that are prepared in [`scripts/meds/`](scripts/meds/) directory. For more detailed information, please follow the instructions below. 
Note that all the following commands should be run in the root directory of the repository, not in `scripts/meds/` or any other sub-directories. Additionally, the following scripts assume your dataset is split into `"train"`, `"tuning"`, and `"held_out"` subsets for training, validation, and test, respecitvely. If it doesn't apply to your case, you can modify them by adding these command line arguments: `--train_subset`, `--valid_subset`, and `--test_subset`. For example, if you need to process only the train subset, you can specify it by adding `--train_subset="train" --valid_subset="" --test_subset=""`. @@ -172,12 +172,15 @@ Additionally, the following scripts assume your dataset is split into `"train"`, $ python scripts/meds/process_meds.py $MEDS_PATH \ --cohort $ACES_COHORT_PATH \ --output_dir $PROCESSED_MEDS_DIR \ + --birth_code $BIRTH_CODE \ --rebase \ --workers $NUM_WORKERS ``` * `$MEDS_PATH`: path to MEDS dataset to be processed. It can be a directory or the exact file path with the file exenstion (only `.csv` or `.parquet` allowed). If provided with directory, it tries to scan all `*.csv` or `*.parquet` files contained in the directory recursively. * `$ACES_COHORT_PATH`: path to the defined cohort, which must be a result of [ACES](https://github.com/justin13601/ACES). It can be a directory or the exact file path that has the same file extension with the MEDS dataset to be processed. The file structure of this cohort directory should be the same with the provided MEDS dataset directory (`$MEDS_PATH`) to match each cohort to its corresponding shard data. * `$PROCESSED_MEDS_DIR`: directory to save processed outputs. + * Enabling `--rebase` will renew this directory. If you don't want, please disable this argument. + * `$BIRTH_CODE`: string code for the birth event in the dataset, set to `"MEDS_BIRTH"` by default. * `$NUM_WORKERS`: number of parallel workes to multi-process the script. * **NOTE: If you encounter this error:** _"polars' maximum length reached. 
consider installing 'polars-u64-idx'"_, **please consider using more workers or doing `pip install polars-u64-idx`.** * As a result of this script, you will have .h5 and .tsv files that has a following respective structure: @@ -223,35 +226,68 @@ Additionally, the following scripts assume your dataset is split into `"train"`,
- Pretrain event encoder - -* This stage pretrains event encoder (e.g., GenHPF) using a random event sequence with a length of `max_seq_len` (by default, set to `128`) every epoch for each cohort sample. -* After completing the pretraining, we should encode all the events in the dataset and cache them to reuse in the following stage. -* For a shell script to run this, see [`./scripts/meds/pretrain.sh`](./scripts/meds/pretrain.sh). -* For Python, please run: - ```shell script - accelerate launch \ - --config_file config/single.json \ - --num_processes 1 \ - --gpu_ids $GPU_ID \ - main.py \ - --src_data meds \ - --input_path $PROCESSED_MEDS_DIR \ - --save_dir $PRETRAIN_SAVE_DIR \ - --pred_targets meds_single_task \ - --train_type short \ - --lr 5e-5 \ - --random_sample \ - # if you want to log using wandb - --wandb \ - --wandb_entity_name $wandb_entity_name \ - --wandb_project_name $wandb_project_name - ``` - * `$PROCESSED_MEDS_DIR`: directory containing processed MEDS data, expected to contain `*.h5` and `*.tsv` files. - * `$PRETRAIN_SAVE_DIR`: output directory to save the checkpoint for the pretrained event encoder. - * `$GPU_ID`: GPU index to be used for training the model. + Pretrain GenHPF + +* After pre-processing the dataset, we can train GenHPF in two directions: 1) following the original GenHPF setup (e.g., loading the last 256 events for each sample), or 2) following the REMed setup (e.g., random sampling 256 events for each sample) to further train the model with an additional retriever module. +* **Following the original GenHPF setup**: + * For a shell script to run this, see [`./scripts/meds/train_genhpf.sh`](./scripts/meds/train_genhpf.sh). 
+  * This script runs the following Python command:
+    ```shell script
+    accelerate launch \
+        --config_file config/config.json \
+        --num_processes $NUM_PROCESSES \
+        --gpu_ids $GPU_IDS \
+        main.py \
+        --src_data meds \
+        --input_path $PROCESSED_MEDS_DIR \
+        --save_dir $SAVE_DIR \
+        --pred_targets meds_single_task \
+        --train_type short \
+        --lr 5e-5 \
+        --n_agg_layers 4 \
+        --pred_dim 128 \
+        --batch_size 64 \
+        --max_seq_len 512 \
+        --dropout 0.3 \
+        --seed 2020 \
+        --patience 5 \
+        # if you want to log using wandb
+        --wandb \
+        --wandb_entity_name $wandb_entity_name \
+        --wandb_project_name $wandb_project_name
+    ```
+    * `$PROCESSED_MEDS_DIR`: directory containing processed MEDS data, expected to contain `*.h5` and `*.tsv` files.
+    * `$PRETRAIN_SAVE_DIR`: output directory to save the checkpoint for the pretrained event encoder.
+    * `$NUM_PROCESSES`: number of parallel processes.
+    * `$GPU_IDS`: comma separated list indicating GPU indices (e.g., `0` or `0,1`) to be used for training the model.
+    * Checkpoint will be saved to `$SAVE_DIR/` with a subdirectory indicating the random seed. For example, if `$SAVE_DIR=/workspace/genhpf` and `--seed=2020`, then the checkpoint will be saved in `workspace/genhpf/genhpf_2020`.
+    * **To get the final prediction results, jump to the last step.**
+* **Following the REMed setup**:
+  * For a shell script to run this, see [`./scripts/meds/pretrain_genhpf.sh`](./scripts/meds/pretrain_genhpf.sh). 
+ * This script runs the following Python command: + ```shell script + accelerate launch \ + --config_file config/config.json \ + --num_processes $NUM_PROCESSES \ + --gpu_ids $GPU_IDS \ + main.py \ + --src_data meds \ + --input_path $PROCESSED_MEDS_DIR \ + --save_dir $SAVE_DIR \ + --pred_targets meds_single_task \ + --train_type short \ + --lr 5e-5 \ + --batch_size 32 \ + --random_sample \ + --seed 2020 \ + --patience 5 \ + # if you want to log using wandb + --wandb \ + --wandb_entity_name $wandb_entity_name \ + --wandb_project_name $wandb_project_name + ``` * It will pretrain event encoder using the processed MEDS data, which will be used to encode all events present in the MEDS data for the REMed model later. - * Checkpoint for the pretrained event encoder will be saved to `$PRETRAIN_SAVE_DIR/${EXPERIMENT_NAME}` directory, where `${EXPERIMENT_NAME}` is a 32-length hexadecimal string generated automatically for each unique experiment. + * After completing the pretraining, we should encode all the events in the dataset and cache them to reuse in the following stage.
@@ -259,30 +295,47 @@ Additionally, the following scripts assume your dataset is split into `"train"`, Encode all events present in the input MEDS data, and cache them * In this stage, we encode all events present in the input MEDS data, and cache them, which will be input data for the REMed model. -* For a shell script to run this, see [`./scripts/meds/encode_events.sh`](./scripts/meds/encode_events.sh). -* For Python, please run: +* To do this, we should firstly extract all the unique events existed in the MEDS dataset by the following Python command: ```shell script - accelerate launch \ - --config_file config/single.json \ - --num_processes 1 \ - --gpu_ids="$GPU_ID" \ - main.py \ - --src_data meds \ - --input_path $PROCESSED_MEDS_DIR \ - --save_dir $ENCODED_MEDS_DIR \ - --pred_targets meds_single_task \ - --train_type short \ - --random_sample \ - --encode_events \ - --encode_only \ - --resume_name $PRETRAINED_CHECKPOINT_DIR + python ./scripts/meds/extract_unique_events.py \ + $PROCESSED_MEDS_DIR \ + --output_dir $UNIQUE_EVENTS_DIR \ + --workers $NUM_WORKERS + ``` + * This script will save the unique events to `$UNIQUE_EVENTS_DIR`. + * **Note that you don't need to run this script multiple times per dataset; just run only once per dataset.** +* Then, we can encode unique events using the pre-traiend event encoder (i.e., GenHPF) by: + * For a shell script, see [`./scripts/meds/encode_events.sh`](./scripts/meds/encode_events.sh). 
+ * This script runs the following Python command: + ```shell script + accelerate launch \ + --config_file config/single.json \ + --num_processes 1 \ + --gpu_ids $GPU_ID \ + main.py \ + --src_data meds \ + --input_path null \ + --unique_events_path $UNIQUE_EVENTS_DIR \ + --save_dir $ENCODED_EVENTS_DIR \ + --pred_targets meds_single_task \ + --train_type short \ + --batch_size 8192 \ + --encode_events \ + --encode_only \ + --resume_name $PRETRAINED_CHECKPOINT_DIR + ``` + * `$PRETRAINED_CHECKPOINT_DIR`: directory containing checkpoint for the pretrained event encoder containing `checkpoint_best.pt`. + * This script will generate `event_to_vec.pkl` to `$ENCODED_EVENTS_DIR`, which is a look-up table to encode each event to its embedding vector by the pretrained event encoder. +* Finally, we will encode all the samples in the MEDS dataset using the look-up table (`event_to_vec.pkl`). + ```shell script + python ./scripts/meds/map_events_to_vec.py \ + $PROCESSED_MEDS_DIR \ + --map_dir $ENCODED_EVENTS_DIR \ + --output_dir $ENCODED_MEDS_DIR \ + --workers $NUM_WORKERS ``` - * `$PROCESSED_MEDS_DIR`: directory containing processed MEDS data, expected to contain `*.h5` and `*.tsv` files. - * `$ENCODED_MEDS_DIR`: output directory to save the encoded data where the file names will be `*_encoded.h5`. - * `$GPU_ID`: GPU index to be used for running the model. - * `$PRETRAINED_CHECKPOINT_DIR`: directory containing checkpoint for the pretrained event encoder, expected to be `$PRETRAIN_SAVE_DIR/${EXPERIMENT_NAME}` containing `checkpoint_best.pt`. - * It will encode all events present in the processed meds data (`*.h5`) located in `$PROCESSED_MEDS_DIR`, and save the results into `ENCODED_MEDS_DIR/*_encoded.h5`. - * Note that it requires large empty disk space (>200G) to save all the encoded events to the storage. This process will take about 3 hours (for ~7500 steps). 
+ * This script will encode all events present in the processed meds data (`*.h5`) located in `$PROCESSED_MEDS_DIR`, and save the results into `$ENCODED_MEDS_DIR/*_encoded.h5`. + * **Note that it requires large empty disk space to save all the encoded events to the storage.** @@ -291,13 +344,13 @@ Additionally, the following scripts assume your dataset is split into `"train"`, * In this stage, we finally train the REMed model using the encoded MEDS data. * After training ends, it will save the best checkpoint for the trained REMed model. -* For a shell script to run this, see [`./scripts/meds/train.sh`](./scripts/meds/train.sh). -* For Python, please run: +* For a shell script to run this, see [`./scripts/meds/train_remed.sh`](./scripts/meds/train_remed.sh). +* This script runs the following Python command: ```shell script accelerate launch \ - --config_file config/single.json \ - --num_processes 1 \ - --gpu_ids $GPU_ID \ + --config_file config/config.json \ + --num_processes $NUM_PROCESSES \ + --gpu_ids $GPU_IDS \ main.py \ --src_data meds \ --input_path $ENCODED_MEDS_DIR \ @@ -305,6 +358,7 @@ Additionally, the following scripts assume your dataset is split into `"train"`, --pred_targets meds_single_task \ --train_type remed \ --lr 1e-5 \ + --batch_size 32 \ --scorer \ --scorer_use_time \ --max_seq_len 200000 \ @@ -316,16 +370,15 @@ Additionally, the following scripts assume your dataset is split into `"train"`, ``` * `$ENCODED_MEDS_DIR`: directory containing encoded MEDS data, expected to contain `*_encoded.h5` files. * `$REMED_SAVE_DIR`: output directory to save the REMed model checkpoint. - * `$GPU_ID`: GPU index to be used for running the model.
- Generate predicted results to the test cohort dataframe for a given task using trained REMed model + Generate predicted results to the test cohort dataframe for a given task using trained model (GenHPF or REMed) -* In this final stage, we load the trained REMed model to do prediction on the test cohort for a given task, and generate the predicted results as two additional columns, `predicted_label` and `predicted_prob`, to the test cohort dataframe. +* In this final stage, we load the trained model to do prediction on the test cohort for a given task, and generate the predicted results as two additional columns, `predicted_label` and `predicted_prob`, to the test cohort dataframe. * For a shell script to run this, see [`./scripts/meds/predict.sh`](./scripts/meds/predict.sh). -* For Python, please run: +* This script runs the following Python command: ```shell script accelerate launch \ --config_file config/single.json \ @@ -336,20 +389,28 @@ Additionally, the following scripts assume your dataset is split into `"train"`, --input_path $ENCODED_MEDS_DIR \ --save_dir $SAVE_DIR \ --pred_targets meds_single_task \ - --train_type remed \ - --scorer \ - --scorer_use_time \ - --max_seq_len 200000 \ - --max_retrieve_len 512 \ + --train_type $REMED_OR_SHORT \ --test_only \ --test_cohort $ACES_TEST_COHORT_DIR \ --resume_name $CHECKPOINT_DIR + # enable the following arguments for GenHPF model + # --n_agg_layers 4 \ + # --pred_dim 128 \ + # --max_seq_len 512 \ + # --dropout 0.3 \ + + # enable the following arguments for REMed model + # --scorer \ + # --scorer_use_time \ + # --max_seq_len 200000 \ + # --max_retrieve_len 512 \ ``` + * `$REMED_OR_SHORT`: `"remed"` for REMed model, `"short"` for GenHPF model. * `$ENCODED_MEDS_DIR`: directory containing encoded MEDS data, expected to contain `*_encoded.h5` files. - * `$SAVE_DIR`: output directory to save the predicted results, which will be `$test_subset.parquet`. 
the results will be saved to `${SAVE_DIR}/${EXPERIMENT_NAME}` directory. this result file has the same rows with the test cohort dataframe provided with `$ACES_TEST_COHORT_DIR`, but has two additional columns: `predicted_label` and `predicted_prob` - * `$GPU_ID`: GPU index to be used for running the model. - * `$ACES_TEST_COHORT_DIR`: directory containing test cohorts generated from ACES, expected to contain `*.parquet` files. - * `$CHECKPOINT_DIR`: directory containing checkpoint for the trained REMed model, expected to be `$REMED_SAVE_DIR/${EXPERIMENT_NAME}` + * `$SAVE_DIR`: output directory to save the predicted results as `$test_subset.parquet` (e.g., `held_out.parquet`). This result file has the same rows with the test cohort dataframe provided with `$ACES_TEST_COHORT_DIR`, but has two additional columns: `predicted_label` and `predicted_prob`. + * `$ACES_TEST_COHORT_DIR`: directory containing **test (held_out)** cohorts generated from ACES, expected to contain `*.parquet` files. + * `$CHECKPOINT_DIR`: directory containing checkpoint for the trained REMed model containing `checkpoint_best.pt`. + * **Note that this script doesn't support parallel processing currently. Please use single GPU only.**
diff --git a/scripts/meds/encode_events.sh b/scripts/meds/encode_events.sh index 3340bae..a2e1eff 100644 --- a/scripts/meds/encode_events.sh +++ b/scripts/meds/encode_events.sh @@ -2,15 +2,15 @@ # Function to display help message function display_help() { - echo "Usage: $0 " + echo "Usage: $0 " echo echo "This script encodes all events present in a MEDS cohort and caches them, which will" echo "be the input data for the REMed model." echo echo "Arguments:" - echo " PROCESSED_MEDS_DIR Directory containing processed MEDS data, expected to contain *.h5 and *.tsv files." - echo " SAVE_DIR Output directory to save the encoded data as *_encoded.h5." echo " GPU_ID GPU index to be used for training the model." + echo " UNIQUE_EVENTS_DIR directory containing the unique events to be encoded." + echo " SAVE_DIR Output directory to save the encoded unique events." echo " PRETRAINED_CHECKPOINT_DIR Directory containing checkpoint for the pretrained event encoder, expected to contain checkpoint_best.pt." echo echo "Options:" @@ -24,10 +24,10 @@ if [ "$#" -lt 4 ]; then display_help fi -UNIQUE_EVENTS_DIR="$1" -SAVE_DIR="$2" -PRETRAINED_CHECKPOINT_DIR="$3" -GPU_ID="$4" +GPU_ID="$1" +UNIQUE_EVENTS_DIR="$2" +SAVE_DIR="$3" +PRETRAINED_CHECKPOINT_DIR="$4" accelerate launch \ --config_file config/single.json \ @@ -43,4 +43,4 @@ accelerate launch \ --batch_size 8192 \ --encode_events \ --encode_only \ - --resume_name "$PRETRAINED_CHECKPOINT_DIR" \ No newline at end of file + --resume_name "$PRETRAINED_CHECKPOINT_DIR" diff --git a/scripts/meds/predict.sh b/scripts/meds/predict.sh index 5ce0264..a536728 100644 --- a/scripts/meds/predict.sh +++ b/scripts/meds/predict.sh @@ -2,17 +2,18 @@ # Function to display help message function display_help() { - echo "Usage: $0 " + echo "Usage: $0 " echo echo "This script produces predicted labels and their probabilities for a given task and its" echo "cohort." echo echo "Arguments:" + echo " GPU_ID GPU index to be used for training the model." 
echo "  ENCODED_MEDS_DIR           Directory containing encoded MEDS data, expected to contain *_encoded.h5 files"
     echo "  SAVE_DIR                   Output directory to save the predicted results."
+    echo "  REMED_OR_SHORT             String indicator for whether to test REMed model ('remed') or GenHPF model ('short')"
     echo "  ACES_TEST_COHORT_DIR       Directory containing test cohorts generated from ACES, expected to contain *.parquet files."
     echo "  CHECKPOINT_DIR             Directory containing checkpoint for the trained REMed model, expected to contain checkpoint_best.pt."
-    echo "  GPU_ID                     GPU index to be used for training the model."
     echo
     echo "Options:"
     echo "  -h, --help    Display this help message and exit."
@@ -25,11 +26,12 @@ if [ "$#" -lt 5 ]; then
     display_help
 fi
-ENCODED_MEDS_DIR="$1"
-SAVE_DIR="$2"
-ACES_TEST_COHORT_DIR="$3"
-CHECKPOINT_DIR="$4"
-GPU_ID="$5"
+GPU_ID="$1"
+ENCODED_MEDS_DIR="$2"
+SAVE_DIR="$3"
+REMED_OR_SHORT="$4"
+ACES_TEST_COHORT_DIR="$5"
+CHECKPOINT_DIR="$6"
 
 accelerate launch \
     --config_file config/single.json \
@@ -37,15 +39,21 @@ accelerate launch \
     --gpu_ids $GPU_ID \
     main.py \
     --src_data meds \
-    --input_path "$ENCODED_MEDS_DIR" \
-    --save_dir "$SAVE_DIR" \
+    --input_path $ENCODED_MEDS_DIR \
+    --save_dir $SAVE_DIR \
     --pred_targets meds_single_task \
-    --pred_time 24 \
-    --train_type remed \
-    --scorer \
-    --scorer_use_time \
-    --max_seq_len 200000 \
-    --max_retrieve_len 512 \
+    --train_type $REMED_OR_SHORT \
     --test_only \
-    --test_cohort "$ACES_TEST_COHORT_DIR" \
-    --resume_name "$CHECKPOINT_DIR"
\ No newline at end of file
+    --test_cohort $ACES_TEST_COHORT_DIR \
+    --resume_name $CHECKPOINT_DIR
+    # enable the following arguments for GenHPF model
+    # --n_agg_layers 4 \
+    # --pred_dim 128 \
+    # --max_seq_len 512 \
+    # --dropout 0.3 \
+
+    # enable the following arguments for REMed model
+    # --scorer \
+    # --scorer_use_time \
+    # --max_seq_len 200000 \
+    # --max_retrieve_len 512 \
diff --git a/scripts/meds/pretrain.sh b/scripts/meds/pretrain_genhpf.sh
similarity index 64%
rename from 
scripts/meds/pretrain.sh rename to scripts/meds/pretrain_genhpf.sh index de6b158..be78053 100644 --- a/scripts/meds/pretrain.sh +++ b/scripts/meds/pretrain_genhpf.sh @@ -2,15 +2,16 @@ # Function to display help message function display_help() { - echo "Usage: $0 " + echo "Usage: $0 " echo echo "This script pretrains event encoder using a MEDS cohort, which will be used to encode" echo "all events present in the MEDS cohort for the REMed model later." echo echo "Arguments:" + echo " NUM_PROCESSES Number of parallel processes" + echo " GPU_IDS GPU index to be used for training the model." echo " PROCESSED_MEDS_DIR Directory containing processed MEDS data, expected to contain *.h5 and *.tsv files." echo " SAVE_DIR Output directory to save the checkpoint for the pretrained event encoder." - echo " GPU_ID GPU index to be used for training the model." echo echo "Options:" echo " -h, --help Display this help message and exit." @@ -23,15 +24,15 @@ if [ "$#" -lt 3 ]; then display_help fi - -PROCESSED_MEDS_DIR="$1" -SAVE_DIR="$2" -GPU_ID="$3" +NUM_PROCESSES="$1" +GPU_IDS="$2" +PROCESSED_MEDS_DIR="$3" +SAVE_DIR="$4" accelerate launch \ - --config_file config/single.json \ - --num_processes 1 \ - --gpu_ids="$GPU_ID" \ + --config_file config/config.json \ + --num_processes $NUM_PROCESSES \ + --gpu_ids="$GPU_IDS" \ main.py \ --src_data meds \ --input_path "$PROCESSED_MEDS_DIR" \ @@ -39,4 +40,10 @@ accelerate launch \ --pred_targets meds_single_task \ --train_type short \ --lr 5e-5 \ - --random_sample \ No newline at end of file + --batch_size 32 \ + --random_sample \ + --seed 2020 \ + --patience 5 \ + # --wandb \ + # --wandb_project_name ??? \ + # --wandb_entity_name ??? 
diff --git a/scripts/meds/process_meds.py b/scripts/meds/process_meds.py index cd83a37..2a3b8cf 100644 --- a/scripts/meds/process_meds.py +++ b/scripts/meds/process_meds.py @@ -46,6 +46,13 @@ def get_parser(): help="path to metadata directory for the input MEDS dataset, which contains codes.parquet", ) + parser.add_argument( + "--birth_code", + type=str, + default="MEDS_BIRTH", + help="string code for the birth event in the dataset." + ) + parser.add_argument( "--cohort", type=str, @@ -125,14 +132,12 @@ def main(args): codes_metadata = pl.read_parquet(metadata_dir / "codes.parquet").to_pandas() codes_metadata = codes_metadata.set_index("code")["description"].to_dict() # do not allow to use static events or birth event - birth_code = ( - "MEDS_BIRTH" # NOTE can we assume code for "birth" is always "MEDS_BIRTH"? - ) - if birth_code not in codes_metadata: - print( - f'"{birth_code}" is not found in the codes metadata, which may lead to ' - "unexpected results since we currently exclude this event from the input data. " - ) + birth_code = args.birth_code + # if birth_code not in codes_metadata: + # print( + # f'"{birth_code}" is not found in the codes metadata, which may lead to ' + # "unexpected results since we currently exclude this event from the input data. 
" + # ) if mimic_dir is not None: d_items = pd.read_csv(mimic_dir / "icu" / "d_items.csv.gz") @@ -531,6 +536,7 @@ def meds_to_remed_unit(row): ["subject_id", "cohort_end", "cohort_label"], maintain_order=True ).agg(pl.all()) + df_chunk = df_chunk.sort(by=['subject_id', 'cohort_end']) # regard {subject_id} as {cohort_id}: {subject_id}_{cohort_number} df_chunk = df_chunk.with_columns( pl.col("subject_id").cum_count().over("subject_id").alias("suffix") diff --git a/scripts/meds/train_genhpf.sh b/scripts/meds/train_genhpf.sh new file mode 100644 index 0000000..fe30223 --- /dev/null +++ b/scripts/meds/train_genhpf.sh @@ -0,0 +1,53 @@ +#!/bin/bash + +# Function to display help message +function display_help() { + echo "Usage: $0 " + echo + echo "This script pretrains event encoder using a MEDS cohort, which will be used to encode" + echo "all events present in the MEDS cohort for the REMed model later." + echo + echo "Arguments:" + echo " NUM_PROCESSES Number of parallel processes" + echo " GPU_IDS GPU indices to be used for training the model." + echo " PROCESSED_MEDS_DIR Directory containing processed MEDS data, expected to contain *.h5 and *.tsv files." + echo " SAVE_DIR Output directory to save the checkpoint for the pretrained event encoder." + echo + echo "Options:" + echo " -h, --help Display this help message and exit." + exit 1 +} + +# Check for mandatory parameters +if [ "$#" -lt 3 ]; then + echo "Error: Incorrect number of arguments provided." 
+ display_help +fi + + +NUM_PROCESSES="$1" +GPU_IDS="$2" +PROCESSED_MEDS_DIR="$3" +SAVE_DIR="$4" + +accelerate launch \ + --config_file config/config.json \ + --num_processes $NUM_PROCESSES \ + --gpu_ids="$GPU_IDS" \ + main.py \ + --src_data meds \ + --input_path "$PROCESSED_MEDS_DIR" \ + --save_dir "$SAVE_DIR" \ + --pred_targets meds_single_task \ + --train_type short \ + --lr 5e-5 \ + --n_agg_layers 4 \ + --pred_dim 128 \ + --batch_size 64 \ + --max_seq_len 512 \ + --dropout 0.3 \ + --seed 2020 \ + --patience 5 \ + # --wandb \ + # --wandb_project_name ??? \ + # --wandb_entity_name ??? diff --git a/scripts/meds/train.sh b/scripts/meds/train_remed.sh similarity index 62% rename from scripts/meds/train.sh rename to scripts/meds/train_remed.sh index bd90c7a..ad0a60f 100644 --- a/scripts/meds/train.sh +++ b/scripts/meds/train_remed.sh @@ -2,15 +2,16 @@ # Function to display help message function display_help() { - echo "Usage: $0 " + echo "Usage: $0 " echo echo "This script encodes all the events present in a MEDS cohort and caches them, which will" echo "be the input data for the REMed model." echo echo "Arguments:" + echo " NUM_PROCESSES Number of parallel processes" + echo " GPU_IDS GPU index to be used for training the model." echo " ENCODED_MEDS_DIR Directory containing encoded MEDS data, expected to contain *_encoded.h5 files" echo " SAVE_DIR Output directory to save the model checkpoint" - echo " GPU_ID GPU index to be used for training the model." echo echo "Options:" echo " -h, --help Display this help message and exit." 
@@ -23,21 +24,27 @@ if [ "$#" -lt 3 ]; then display_help fi -ENCODED_MEDS_DIR="$1" -SAVE_DIR="$2" -GPU_ID="$3" +NUM_PROCESSES="$1" +GPU_IDS="$2" +ENCODED_MEDS_DIR="$3" +SAVE_DIR="$4" accelerate launch \ - --config_file config/single.json \ - --num_processes 1 \ - --gpu_ids $GPU_ID \ + --config_file config/config.json \ + --num_processes $NUM_PROCESSES \ + --gpu_ids $GPU_IDS \ main.py \ --src_data meds \ --input_path "$ENCODED_MEDS_DIR" \ - --save_dir "$SAVE_DIR" \ + --save_dir "$SAVE_DIR" \ --pred_targets meds_single_task \ --train_type remed \ --lr 1e-5 \ + --batch_size 32 \ --scorer \ --scorer_use_time \ - --max_seq_len 200000 \ No newline at end of file + --max_seq_len 200000 \ + --max_retrieve_len 512 \ + # --wandb \ + # --wandb_project_name ??? \ + # --wandb_entity_name ??? diff --git a/src/dataset.py b/src/dataset.py index e80d3d7..e3a22c4 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -222,6 +222,7 @@ def __init__(self, args, split, data_path, *pargs, **kwargs): self.manifest = pd.read_csv( os.path.join(data_path, split + ".tsv"), delimiter="\t" ).set_index("subject_id") + self.subject_ids = self.manifest.index.tolist() unique_shard_ids = self.manifest["shard_id"].unique() self.data = {} @@ -249,7 +250,7 @@ def collate_fn(self, samples): ret[k]["meds_single_task"] = torch.FloatTensor( torch.stack([s["label"] for s in samples]) ) - elif k in ["subject_id", "index"]: # for MEDSForReprGen + elif k == "subject_id": ret[k] = np.array([s[k] for s in samples]) else: padded = pad_sequence([s[k] for s in samples], batch_first=True) @@ -266,7 +267,6 @@ def __getitem__(self, idx): # assume it is a scalar value for a binary classification task label = torch.tensor([data["label"][()]]).float() - #XXX max_num_events = 300000 if self.args.max_seq_len < len(input): length = len(input) @@ -283,13 +283,15 @@ else: input = input[-self.args.max_seq_len:, :, :] times = times[-self.args.max_seq_len:] + times = times - times[0] return {
"input_ids": torch.LongTensor(input[:, 0, :]), "type_ids": torch.LongTensor(input[:, 1, :]), "dpe_ids": torch.LongTensor(input[:, 2, :]), "times": torch.IntTensor(times), - "label": label + "label": label, + "subject_id": subject_id, } @@ -349,7 +351,6 @@ def __init__(self, args, split, data_path, *pargs, **kwargs): self.args = args self.data = {} - #TODO ..? for fname in glob.glob(os.path.join(data_path, split, f"*_encoded.h5")): shard_id = int(os.path.splitext(fname)[0].split("_")[-2]) self.data[shard_id] = h5pickle.File( @@ -362,6 +363,8 @@ def __init__(self, args, split, data_path, *pargs, **kwargs): self.manifest |= shard_manifest self.keys = list(self.manifest.keys()) + self.subject_ids = list(self.manifest.keys()) + def __len__(self): return len(self.manifest) diff --git a/src/trainer/base.py b/src/trainer/base.py index 811aceb..c53e882 100644 --- a/src/trainer/base.py +++ b/src/trainer/base.py @@ -41,11 +41,17 @@ def __init__(self, args): if self.test_cohort is not None: # make subject_id to be {subject_id}_{cohort_number} to prevent duplicated ids if os.path.isdir(self.test_cohort): - test_cohort = pl.read_parquet( - os.path.join(self.test_cohort, self.test_subset, "*.parquet") - ) + if os.path.basename(self.test_cohort) != self.test_subset: + test_cohort = pl.read_parquet( + os.path.join(self.test_cohort, self.test_subset, "*.parquet") + ) + else: + test_cohort = pl.read_parquet( + os.path.join(self.test_cohort, "*.parquet") + ) else: test_cohort = pl.read_parquet(self.test_cohort) + test_cohort = test_cohort.sort(by=["subject_id", "prediction_time"]) test_cohort = test_cohort.with_columns( pl.col("subject_id").cum_count().over("subject_id").alias("suffix") ) @@ -75,9 +81,14 @@ def run(self): self.args.batch_size // self.accelerator.num_processes ) if self.args.src_data == "meds": - if self.args.save_dir.endswith("/"): - self.args.save_dir = self.args.save_dir[:-1] - self.args.exp_name = os.path.basename(self.args.save_dir) + "_" + str(self.args.seed) + 
if self.args.test_only: + if self.args.resume_name.endswith("/"): + self.args.resume_name = self.args.resume_name[:-1] + self.args.exp_name = os.path.basename(self.args.resume_name) + else: + if self.args.save_dir.endswith("/"): + self.args.save_dir = self.args.save_dir[:-1] + self.args.exp_name = os.path.basename(self.args.save_dir) + "_" + str(self.args.seed) elif self.args.resume_name: self.args.exp_name = self.args.resume_name else: @@ -103,7 +114,7 @@ def run(self): exp_encoded = broadcast(exp_encoded.to(self.accelerator.device)) self.args.exp_name = "".join([chr(int(i)) for i in exp_encoded]) - if not self.args.encode_only: + if not self.args.encode_only and not self.args.test_only: os.makedirs( os.path.join(self.args.save_dir, self.args.exp_name), exist_ok=True ) @@ -342,8 +353,34 @@ def epoch(self, split, data_loader, n_epoch=0): t = tqdm(data_loader, desc=f"{split} epoch {n_epoch}") else: t = data_loader + + do_output_cohort = False + if self.args.src_data == "meds" and ( + split == self.test_subset and self.test_cohort is not None + ): + if self.accelerator.num_processes == 1: + # check if test cohort is valid + assert set(data_loader.dataset.subject_ids) == set(self.test_cohort["subject_id"]), ( + "a set of patient ids in the test cohort should equal to that in the test dataset" + ) + predicted_cohort = {"subject_id": [], "boolean_prediction": []} + do_output_cohort = True + else: + logger.warning( + "not yet implemented to output predicted labels and probs with " + "--test_cohort in multi-processing environment. please run with " + "--num_processes=1 in accelerate launch." 
+ ) + for sample in t: output, reprs = self.model(**sample) + # meds -- output + if do_output_cohort: + predicted_cohort["subject_id"].extend(sample["subject_id"].tolist()) + predicted_cohort["boolean_prediction"].extend( + output["pred"]["meds_single_task"].view(-1).tolist() + ) + loss, logging_outputs = self.criterion(output, reprs) if split == self.train_subset: self.optimizer.zero_grad(set_to_none=True) @@ -358,6 +395,18 @@ def epoch(self, split, data_loader, n_epoch=0): ): self.accelerator.log({f"{split}_loss": loss}) + if do_output_cohort: + predicted_cohort = pl.DataFrame(predicted_cohort) + self.test_cohort = self.test_cohort.join(predicted_cohort, on="subject_id", how="left") + self.test_cohort = self.test_cohort.select( + [ + pl.col("boolean_prediction"), + pl.col("subject_id"), + pl.col("prediction_time"), + pl.col("boolean_value") + ] + ) + metrics = self.metric.get_metrics() log_dict = log_from_dict(metrics, split, n_epoch) logger.info(log_dict) From 47742f4e6c86ee5d74352fb8b3159ec1a2807d5e Mon Sep 17 00:00:00 2001 From: Jwoo5 Date: Thu, 14 Nov 2024 11:33:29 +0900 Subject: [PATCH 6/8] Update readme --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 779a901..97587eb 100644 --- a/README.md +++ b/README.md @@ -171,12 +171,14 @@ Additionally, the following scripts assume your dataset is split into `"train"`, ```shell script $ python scripts/meds/process_meds.py $MEDS_PATH \ --cohort $ACES_COHORT_PATH \ + --metadata_dir $METADATA_DIR \ --output_dir $PROCESSED_MEDS_DIR \ --birth_code $BIRTH_CODE \ --rebase \ --workers $NUM_WORKERS ``` * `$MEDS_PATH`: path to MEDS dataset to be processed. It can be a directory or the exact file path with the file exenstion (only `.csv` or `.parquet` allowed). If provided with directory, it tries to scan all `*.csv` or `*.parquet` files contained in the directory recursively. 
+ * `$METADATA_DIR`: path to the metadata directory for the input MEDS dataset, expected to contain `codes.parquet`. This is used to retrieve descriptions for codes in MEDS events and convert each code to the retrieved description. Note that if a code has no specific description in `codes.parquet`, it will just treat that code as a plain text and process the event as it is. * `$ACES_COHORT_PATH`: path to the defined cohort, which must be a result of [ACES](https://github.com/justin13601/ACES). It can be a directory or the exact file path that has the same file extension with the MEDS dataset to be processed. The file structure of this cohort directory should be the same with the provided MEDS dataset directory (`$MEDS_PATH`) to match each cohort to its corresponding shard data. * `$PROCESSED_MEDS_DIR`: directory to save processed outputs. * Enabling `--rebase` will renew this directory. If you don't want, please disable this argument. From 4eec590060b26049e80301828c1428d02e7db52a Mon Sep 17 00:00:00 2001 From: Jwoo5 Date: Wed, 11 Dec 2024 16:02:59 +0900 Subject: [PATCH 7/8] Update readme --- README.md | 4 ++-- scripts/meds/predict.sh | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 97587eb..650f965 100644 --- a/README.md +++ b/README.md @@ -388,7 +388,7 @@ Additionally, the following scripts assume your dataset is split into `"train"`, --gpu_ids $GPU_ID \ main.py \ --src_data meds \ - --input_path $ENCODED_MEDS_DIR \ + --input_path $MEDS_DATA_DIR \ --save_dir $SAVE_DIR \ --pred_targets meds_single_task \ --train_type $REMED_OR_SHORT \ @@ -408,7 +408,7 @@ Additionally, the following scripts assume your dataset is split into `"train"`, # --max_retrieve_len 512 \ ``` * `$REMED_OR_SHORT`: `"remed"` for REMed model, `"short"` for GenHPF model. - * `$ENCODED_MEDS_DIR`: directory containing encoded MEDS data, expected to contain `*_encoded.h5` files. + * `$MEDS_DATA_DIR`: directory containing MEDS data. 
Specifically, same with `$PROCESSED_MEDS_DIR` for GenHPF, or `$ENCODED_MEDS_DIR` for REMed. * `$SAVE_DIR`: output directory to save the predicted results as `$test_subset.parquet` (e.g., `held_out.parquet`). This result file has the same rows with the test cohort dataframe provided with `$ACES_TEST_COHORT_DIR`, but has two additional columns: `predicted_label` and `predicted_prob`. * `$ACES_TEST_COHORT_DIR`: directory containing **test (held_out)** cohorts generated from ACES, expected to contain `*.parquet` files. * `$CHECKPOINT_DIR`: directory containing checkpoint for the trained REMed model containing `checkpoint_best.pt`. diff --git a/scripts/meds/predict.sh b/scripts/meds/predict.sh index a536728..78e6303 100644 --- a/scripts/meds/predict.sh +++ b/scripts/meds/predict.sh @@ -9,7 +9,7 @@ function display_help() { echo echo "Arguments:" echo " GPU_ID GPU index to be used for training the model." - echo " ENCODED_MEDS_DIR Directory containing encoded MEDS data, expected to contain *_encoded.h5 files" + echo " MEDS_DATA_DIR Directory containing MEDS data. Same with `$PROCESSED_MEDS_DIR` for GenHPF, or `$ENCODED_MEDS_DIR` for REMed" echo " SAVE_DIR Output directory to save the predicted results." echo " REMED_OR_SHORT String indicator for whether to test REMed model ('remed') or GenHPF model ('short)" echo " ACES_TEST_COHORT_DIR Directory containing test cohorts generated from ACES, expected to contain *.parquet files." 
@@ -27,7 +27,7 @@ if [ "$#" -lt 5 ]; then fi GPU_ID="$1" -ENCODED_MEDS_DIR="$2" +MEDS_DATA_DIR="$2" SAVE_DIR="$3" REMED_OR_SHORT="$4" ACES_TEST_COHORT_DIR="$5" @@ -39,7 +39,7 @@ accelerate launch \ --gpu_ids $GPU_ID \ main.py \ --src_data meds \ - --input_path $ENCODED_MEDS_DIR \ + --input_path $MEDS_DATA_DIR \ --save_dir $SAVE_DIR \ --pred_targets meds_single_task \ --train_type $REMED_OR_SHORT \ From de5846e0910edd43f996d252712b76c5fecaae4e Mon Sep 17 00:00:00 2001 From: Jwoo5 Date: Wed, 11 Dec 2024 16:03:25 +0900 Subject: [PATCH 8/8] update prediction results to be compatible with meds-evaluation --- src/trainer/base.py | 12 ++++++++---- src/trainer/remed.py | 10 +++++++--- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/src/trainer/base.py b/src/trainer/base.py index f4d5b88..2565a11 100644 --- a/src/trainer/base.py +++ b/src/trainer/base.py @@ -369,7 +369,7 @@ def epoch(self, split, data_loader, n_epoch=0): assert set(data_loader.dataset.subject_ids) == set(self.test_cohort["subject_id"]), ( "a set of patient ids in the test cohort should equal to that in the test dataset" ) - predicted_cohort = {"subject_id": [], "boolean_prediction": []} + predicted_cohort = {"subject_id": [], "predicted_boolean_probability": []} do_output_cohort = True else: logger.warning( @@ -383,7 +383,7 @@ def epoch(self, split, data_loader, n_epoch=0): # meds -- output if do_output_cohort: predicted_cohort["subject_id"].extend(sample["subject_id"].tolist()) - predicted_cohort["boolean_prediction"].extend( + predicted_cohort["predicted_boolean_probability"].extend( output["pred"]["meds_single_task"].view(-1).tolist() ) @@ -403,13 +403,17 @@ def epoch(self, split, data_loader, n_epoch=0): if do_output_cohort: predicted_cohort = pl.DataFrame(predicted_cohort) + predicted_cohort = predicted_cohort.with_columns( + (pl.col("predicted_boolean_probability") > 0.5).alias("predicted_boolean_value") + ) self.test_cohort = self.test_cohort.join(predicted_cohort, on="subject_id", 
how="left") self.test_cohort = self.test_cohort.select( [ - pl.col("boolean_prediction"), pl.col("subject_id"), pl.col("prediction_time"), - pl.col("boolean_value") + pl.col("boolean_value"), + pl.col("predicted_boolean_value"), + pl.col("predicted_boolean_probability"), ] ) diff --git a/src/trainer/remed.py b/src/trainer/remed.py index 0cd6f4b..eb93cf2 100644 --- a/src/trainer/remed.py +++ b/src/trainer/remed.py @@ -78,7 +78,7 @@ def step(sample): assert set(data_loader.dataset.manifest) == set( self.test_cohort["subject_id"] ), "a set of patient ids in the test cohort should equal to that in the test dataset" - predicted_cohort = {"subject_id": [], "boolean_prediction": []} + predicted_cohort = {"subject_id": [], "predicted_boolean_probability": []} do_output_cohort = True else: logger.warning( @@ -98,21 +98,25 @@ def step(sample): # meds -- output if do_output_cohort: predicted_cohort["subject_id"].extend(sample["subject_id"].tolist()) - predicted_cohort["boolean_prediction"].extend( + predicted_cohort["predicted_boolean_probability"].extend( net_output["pred"]["meds_single_task"].view(-1).tolist() ) if do_output_cohort: predicted_cohort = pl.DataFrame(predicted_cohort) + predicted_cohort = predicted_cohort.with_columns( + (pl.col("predicted_boolean_probability") > 0.5).alias("predicted_boolean_value") + ) self.test_cohort = self.test_cohort.join( predicted_cohort, on="subject_id", how="left" ) self.test_cohort = self.test_cohort.select( [ - pl.col("boolean_prediction"), pl.col("subject_id"), pl.col("prediction_time"), pl.col("boolean_value"), + pl.col("predicted_boolean_value"), + pl.col("predicted_boolean_probability"), ] )