diff --git a/README.md b/README.md
index 258a9cf..650f965 100644
--- a/README.md
+++ b/README.md
@@ -153,8 +153,12 @@ accelerate launch \
## Support for MEDS dataset
-We officially support to process [MEDS](https://github.com/Medical-Event-Data-Standard/meds/releases/tag/0.3.0) dataset (currently, MEDS v0.3) with a cohort defined by [ACES](https://github.com/justin13601/ACES), only for the REMed model.
-It consists of 4 steps in total, each of which can be run by Python or shell scripts that are prepared in [`scripts/meds/`](scripts/meds/) directory.
+
+> [!Caution]
+> We are currently investigating the cause of the lower performance observed for the REMed model on the MEDS dataset. Please note that experiments with the REMed model may be unstable for now.
+
+We officially support processing the [MEDS](https://github.com/Medical-Event-Data-Standard/meds) dataset with a cohort defined by [ACES](https://github.com/justin13601/ACES) for the GenHPF and REMed models in this repository.
+It consists of several steps, each of which can be run with the Python or shell scripts provided in the [`scripts/meds/`](scripts/meds/) directory.
For more detailed information, please follow the instructions below.
Note that all the following commands should be run in the root directory of the repository, not in `scripts/meds/` or any other sub-directories.
Additionally, the following scripts assume your dataset is split into `"train"`, `"tuning"`, and `"held_out"` subsets for training, validation, and test, respectively. If this doesn't apply to your case, you can override the subset names with these command line arguments: `--train_subset`, `--valid_subset`, and `--test_subset`. For example, if you need to process only the train subset, you can specify it by adding `--train_subset="train" --valid_subset="" --test_subset=""`.
@@ -167,13 +171,18 @@ Additionally, the following scripts assume your dataset is split into `"train"`,
```shell script
$ python scripts/meds/process_meds.py $MEDS_PATH \
--cohort $ACES_COHORT_PATH \
+ --metadata_dir $METADATA_DIR \
--output_dir $PROCESSED_MEDS_DIR \
+ --birth_code $BIRTH_CODE \
--rebase \
--workers $NUM_WORKERS
```
* `$MEDS_PATH`: path to the MEDS dataset to be processed. It can be a directory or an exact file path with the file extension (only `.csv` or `.parquet` allowed). If provided with a directory, it tries to scan all `*.csv` or `*.parquet` files contained in the directory recursively.
+ * `$METADATA_DIR`: path to the metadata directory for the input MEDS dataset, expected to contain `codes.parquet`. This is used to retrieve descriptions for codes in MEDS events and convert each code to the retrieved description. Note that if a code has no description in `codes.parquet`, it will be treated as plain text and the event will be processed as-is.
* `$ACES_COHORT_PATH`: path to the defined cohort, which must be a result of [ACES](https://github.com/justin13601/ACES). It can be a directory or an exact file path that has the same file extension as the MEDS dataset to be processed. The file structure of this cohort directory should be the same as that of the provided MEDS dataset directory (`$MEDS_PATH`) so that each cohort can be matched to its corresponding shard data.
* `$PROCESSED_MEDS_DIR`: directory to save processed outputs.
+ * Enabling `--rebase` will clear and recreate this directory. If you don't want this behavior, omit the argument.
+ * `$BIRTH_CODE`: string code for the birth event in the dataset, set to `"MEDS_BIRTH"` by default.
* `$NUM_WORKERS`: number of parallel workers used to multi-process the script.
* **NOTE: If you encounter this error:** _"polars' maximum length reached. consider installing 'polars-u64-idx'"_, **please consider using more workers or doing `pip install polars-u64-idx`.**
* As a result of this script, you will get `.h5` and `.tsv` files with the following respective structures:
@@ -219,35 +228,68 @@ Additionally, the following scripts assume your dataset is split into `"train"`,
- Pretrain event encoder
-
-* This stage pretrains event encoder (e.g., GenHPF) using a random event sequence with a length of `max_seq_len` (by default, set to `128`) every epoch for each cohort sample.
-* After completing the pretraining, we should encode all the events in the dataset and cache them to reuse in the following stage.
-* For a shell script to run this, see [`./scripts/meds/pretrain.sh`](./scripts/meds/pretrain.sh).
-* For Python, please run:
- ```shell script
- accelerate launch \
- --config_file config/single.json \
- --num_processes 1 \
- --gpu_ids $GPU_ID \
- main.py \
- --src_data meds \
- --input_path $PROCESSED_MEDS_DIR \
- --save_dir $PRETRAIN_SAVE_DIR \
- --pred_targets meds_single_task \
- --train_type short \
- --lr 5e-5 \
- --random_sample \
- # if you want to log using wandb
- --wandb \
- --wandb_entity_name $wandb_entity_name \
- --wandb_project_name $wandb_project_name
- ```
- * `$PROCESSED_MEDS_DIR`: directory containing processed MEDS data, expected to contain `*.h5` and `*.tsv` files.
- * `$PRETRAIN_SAVE_DIR`: output directory to save the checkpoint for the pretrained event encoder.
- * `$GPU_ID`: GPU index to be used for training the model.
+ Pretrain GenHPF
+
+* After pre-processing the dataset, we can train GenHPF in two ways: 1) following the original GenHPF setup (e.g., loading the last 256 events for each sample), or 2) following the REMed setup (e.g., randomly sampling 256 events for each sample) to further train the model with an additional retriever module.
+* **Following the original GenHPF setup**:
+ * For a shell script to run this, see [`./scripts/meds/train_genhpf.sh`](./scripts/meds/train_genhpf.sh).
+ * This script runs the following Python command:
+ ```shell script
+ accelerate launch \
+ --config_file config/config.json \
+ --num_processes $NUM_PROCESSES \
+ --gpu_ids $GPU_IDS \
+ main.py \
+ --src_data meds \
+ --input_path $PROCESSED_MEDS_DIR \
+ --save_dir $SAVE_DIR \
+ --pred_targets meds_single_task \
+ --train_type short \
+ --lr 5e-5 \
+ --n_agg_layers 4 \
+ --pred_dim 128 \
+ --batch_size 64 \
+ --max_seq_len 512 \
+ --dropout 0.3 \
+ --seed 2020 \
+ --patience 5 \
+ # if you want to log using wandb
+ --wandb \
+ --wandb_entity_name $wandb_entity_name \
+ --wandb_project_name $wandb_project_name
+ ```
+ * `$PROCESSED_MEDS_DIR`: directory containing processed MEDS data, expected to contain `*.h5` and `*.tsv` files.
+ * `$SAVE_DIR`: output directory to save the trained model checkpoint.
+ * `$NUM_PROCESSES`: number of parallel processes.
+ * `$GPU_IDS`: comma-separated list of GPU indices (e.g., `0` or `0,1`) to be used for training the model.
+ * Checkpoints will be saved to `$SAVE_DIR/` in a subdirectory whose name includes the random seed. For example, if `$SAVE_DIR=/workspace/genhpf` and `--seed=2020`, the checkpoint will be saved in `/workspace/genhpf/genhpf_2020`.
+ * **To get the final prediction results, jump to the last step.**
+* **Following the REMed setup**:
+ * For a shell script to run this, see [`./scripts/meds/pretrain_genhpf.sh`](./scripts/meds/pretrain_genhpf.sh).
+ * This script runs the following Python command:
+ ```shell script
+ accelerate launch \
+ --config_file config/config.json \
+ --num_processes $NUM_PROCESSES \
+ --gpu_ids $GPU_IDS \
+ main.py \
+ --src_data meds \
+ --input_path $PROCESSED_MEDS_DIR \
+ --save_dir $SAVE_DIR \
+ --pred_targets meds_single_task \
+ --train_type short \
+ --lr 5e-5 \
+ --batch_size 32 \
+ --random_sample \
+ --seed 2020 \
+ --patience 5 \
+ # if you want to log using wandb
+ --wandb \
+ --wandb_entity_name $wandb_entity_name \
+ --wandb_project_name $wandb_project_name
+ ```
* It will pretrain the event encoder using the processed MEDS data, which will later be used to encode all events present in the MEDS data for the REMed model.
- * Checkpoint for the pretrained event encoder will be saved to `$PRETRAIN_SAVE_DIR/${EXPERIMENT_NAME}` directory, where `${EXPERIMENT_NAME}` is a 32-length hexadecimal string generated automatically for each unique experiment.
+ * After completing the pretraining, we should encode all the events in the dataset and cache them for reuse in the following stage.
@@ -255,30 +297,47 @@ Additionally, the following scripts assume your dataset is split into `"train"`,
Encode all events present in the input MEDS data, and cache them
* In this stage, we encode all events present in the input MEDS data, and cache them, which will be input data for the REMed model.
-* For a shell script to run this, see [`./scripts/meds/encode_events.sh`](./scripts/meds/encode_events.sh).
-* For Python, please run:
+* To do this, we first extract all the unique events that exist in the MEDS dataset with the following Python command:
```shell script
- accelerate launch \
- --config_file config/single.json \
- --num_processes 1 \
- --gpu_ids="$GPU_ID" \
- main.py \
- --src_data meds \
- --input_path $PROCESSED_MEDS_DIR \
- --save_dir $ENCODED_MEDS_DIR \
- --pred_targets meds_single_task \
- --train_type short \
- --random_sample \
- --encode_events \
- --encode_only \
- --resume_name $PRETRAINED_CHECKPOINT_DIR
+ python ./scripts/meds/extract_unique_events.py \
+ $PROCESSED_MEDS_DIR \
+ --output_dir $UNIQUE_EVENTS_DIR \
+ --workers $NUM_WORKERS
+ ```
+ * This script will save the unique events to `$UNIQUE_EVENTS_DIR` (under a `unique_events/` subdirectory).
+ * **Note that you only need to run this script once per dataset.**
+* Then, we can encode the unique events using the pretrained event encoder (i.e., GenHPF):
+ * For a shell script, see [`./scripts/meds/encode_events.sh`](./scripts/meds/encode_events.sh).
+ * This script runs the following Python command:
+ ```shell script
+ accelerate launch \
+ --config_file config/single.json \
+ --num_processes 1 \
+ --gpu_ids $GPU_ID \
+ main.py \
+ --src_data meds \
+ --input_path null \
+ --unique_events_path $UNIQUE_EVENTS_DIR \
+ --save_dir $ENCODED_EVENTS_DIR \
+ --pred_targets meds_single_task \
+ --train_type short \
+ --batch_size 8192 \
+ --encode_events \
+ --encode_only \
+ --resume_name $PRETRAINED_CHECKPOINT_DIR
+ ```
+ * `$PRETRAINED_CHECKPOINT_DIR`: directory containing the checkpoint (`checkpoint_best.pt`) for the pretrained event encoder.
+ * This script will generate `event_to_vec.pkl` in `$ENCODED_EVENTS_DIR`, a look-up table that maps each event to the embedding vector produced by the pretrained event encoder.
+* Finally, we will encode all the samples in the MEDS dataset using the look-up table (`event_to_vec.pkl`).
+ ```shell script
+ python ./scripts/meds/map_events_to_vec.py \
+ $PROCESSED_MEDS_DIR \
+ --map_dir $ENCODED_EVENTS_DIR \
+ --output_dir $ENCODED_MEDS_DIR \
+ --workers $NUM_WORKERS
```
- * `$PROCESSED_MEDS_DIR`: directory containing processed MEDS data, expected to contain `*.h5` and `*.tsv` files.
- * `$ENCODED_MEDS_DIR`: output directory to save the encoded data where the file names will be `*_encoded.h5`.
- * `$GPU_ID`: GPU index to be used for running the model.
- * `$PRETRAINED_CHECKPOINT_DIR`: directory containing checkpoint for the pretrained event encoder, expected to be `$PRETRAIN_SAVE_DIR/${EXPERIMENT_NAME}` containing `checkpoint_best.pt`.
- * It will encode all events present in the processed meds data (`*.h5`) located in `$PROCESSED_MEDS_DIR`, and save the results into `ENCODED_MEDS_DIR/*_encoded.h5`.
- * Note that it requires large empty disk space (>200G) to save all the encoded events to the storage. This process will take about 3 hours (for ~7500 steps).
+ * This script will encode all events present in the processed MEDS data (`*.h5`) located in `$PROCESSED_MEDS_DIR`, and save the results into `$ENCODED_MEDS_DIR/*_encoded.h5` (see the inspection sketch below).
+ * **Note that it requires a large amount of free disk space to save all the encoded events to storage.**
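+* To sanity-check the cached encodings, here is a minimal sketch for inspecting one encoded shard (the path is illustrative and assumes the `"train"` subset with shard id `0`):
+ ```python
+ import h5pickle
+
+ # each subject group stores the encoded event vectors, the event times, and the label
+ with h5pickle.File("/path/to/encoded_meds/train/train_0_encoded.h5", "r") as f:
+     for sbj_id in f["ehr"]:
+         vectors = f["ehr"][sbj_id]["encoded"][:]  # (num_events, hidden_dim)
+         times = f["ehr"][sbj_id]["time"][:]
+         label = f["ehr"][sbj_id]["label"][()]
+ ```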
@@ -287,13 +346,13 @@ Additionally, the following scripts assume your dataset is split into `"train"`,
* In this stage, we finally train the REMed model using the encoded MEDS data.
* After training ends, it will save the best checkpoint for the trained REMed model.
-* For a shell script to run this, see [`./scripts/meds/train.sh`](./scripts/meds/train.sh).
-* For Python, please run:
+* For a shell script to run this, see [`./scripts/meds/train_remed.sh`](./scripts/meds/train_remed.sh).
+* This script runs the following Python command:
```shell script
accelerate launch \
- --config_file config/single.json \
- --num_processes 1 \
- --gpu_ids $GPU_ID \
+ --config_file config/config.json \
+ --num_processes $NUM_PROCESSES \
+ --gpu_ids $GPU_IDS \
main.py \
--src_data meds \
--input_path $ENCODED_MEDS_DIR \
@@ -301,9 +360,11 @@ Additionally, the following scripts assume your dataset is split into `"train"`,
--pred_targets meds_single_task \
--train_type remed \
--lr 1e-5 \
+ --batch_size 32 \
--scorer \
--scorer_use_time \
--max_seq_len 200000 \
+ --max_retrieve_len 512 \
# if you want to log using wandb
--wandb \
--wandb_entity_name $wandb_entity_name \
@@ -311,16 +372,15 @@ Additionally, the following scripts assume your dataset is split into `"train"`,
```
* `$ENCODED_MEDS_DIR`: directory containing encoded MEDS data, expected to contain `*_encoded.h5` files.
* `$REMED_SAVE_DIR`: output directory to save the REMed model checkpoint.
- * `$GPU_ID`: GPU index to be used for running the model.
- Generate predicted results to the test cohort dataframe for a given task using trained REMed model
+ Generate predicted results to the test cohort dataframe for a given task using the trained model (GenHPF or REMed)
-* In this final stage, we load the trained REMed model to do prediction on the test cohort for a given task, and generate the predicted results as two additional columns, `predicted_label` and `predicted_prob`, to the test cohort dataframe.
+* In this final stage, we load the trained model to make predictions on the test cohort for a given task, and append the predicted results as two additional columns, `predicted_label` and `predicted_prob`, to the test cohort dataframe.
* For a shell script to run this, see [`./scripts/meds/predict.sh`](./scripts/meds/predict.sh).
-* For Python, please run:
+* This script runs the following Python command:
```shell script
accelerate launch \
--config_file config/single.json \
@@ -328,21 +388,31 @@ Additionally, the following scripts assume your dataset is split into `"train"`,
--gpu_ids $GPU_ID \
main.py \
--src_data meds \
- --input_path $ENCODED_MEDS_DIR \
+ --input_path $MEDS_DATA_DIR \
--save_dir $SAVE_DIR \
--pred_targets meds_single_task \
- --train_type remed \
- --scorer \
- --scorer_use_time \
+ --train_type $REMED_OR_SHORT \
--test_only \
--test_cohort $ACES_TEST_COHORT_DIR \
--resume_name $CHECKPOINT_DIR
+ # enable the following arguments for GenHPF model
+ # --n_agg_layers 4 \
+ # --pred_dim 128 \
+ # --max_seq_len 512 \
+ # --dropout 0.3 \
+
+ # enable the following arguments for REMed model
+ # --scorer \
+ # --scorer_use_time \
+ # --max_seq_len 200000 \
+ # --max_retrieve_len 512 \
```
- * `$ENCODED_MEDS_DIR`: directory containing encoded MEDS data, expected to contain `*_encoded.h5` files.
- * `$SAVE_DIR`: output directory to save the predicted results, which will be `$test_subset.parquet`. the results will be saved to `${SAVE_DIR}/${EXPERIMENT_NAME}` directory. this result file has the same rows with the test cohort dataframe provided with `$ACES_TEST_COHORT_DIR`, but has two additional columns: `predicted_label` and `predicted_prob`
- * `$GPU_ID`: GPU index to be used for running the model.
- * `$ACES_TEST_COHORT_DIR`: directory containing test cohorts generated from ACES, expected to contain `*.parquet` files.
- * `$CHECKPOINT_DIR`: directory containing checkpoint for the trained REMed model, expected to be `$REMED_SAVE_DIR/${EXPERIMENT_NAME}`
+ * `$REMED_OR_SHORT`: `"remed"` for REMed model, `"short"` for GenHPF model.
+ * `$MEDS_DATA_DIR`: directory containing MEDS data; specifically, the same as `$PROCESSED_MEDS_DIR` for GenHPF, or `$ENCODED_MEDS_DIR` for REMed.
+ * `$SAVE_DIR`: output directory to save the predicted results as `$test_subset.parquet` (e.g., `held_out.parquet`). This result file has the same rows as the test cohort dataframe provided with `$ACES_TEST_COHORT_DIR`, but with two additional columns: `predicted_label` and `predicted_prob`.
+ * `$ACES_TEST_COHORT_DIR`: directory containing **test (held_out)** cohorts generated from ACES, expected to contain `*.parquet` files.
+ * `$CHECKPOINT_DIR`: directory containing the checkpoint (`checkpoint_best.pt`) for the trained model.
+ * **Note that this script currently doesn't support parallel processing. Please use a single GPU only.**
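+* Once the prediction finishes, here is a minimal sketch for loading the resulting dataframe (the path is illustrative and assumes the default `held_out` test subset):
+ ```python
+ import polars as pl
+
+ # same rows as the ACES test cohort, with the two prediction columns appended
+ preds = pl.read_parquet("/path/to/save_dir/held_out.parquet")
+ print(preds.head())
+ ```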
diff --git a/main.py b/main.py
index b71a8e4..e76630d 100644
--- a/main.py
+++ b/main.py
@@ -35,38 +35,46 @@ def get_parser():
"--src_data",
type=str,
choices=["eicu", "mimiciv", "umcdb", "hirid", "meds"],
- default="mimiciv"
+ default="mimiciv",
)
parser.add_argument(
"--train_subset",
type=str,
default="train",
help="file name without extension to load data for the training. only used when"
- "`--src_data` is set to `'meds'`."
+ "`--src_data` is set to `'meds'`.",
)
parser.add_argument(
"--valid_subset",
type=str,
default="tuning",
help="file name without extension to load data for the validation. only used when"
- "`--src_data` is set to `'meds'`."
+ "`--src_data` is set to `'meds'`.",
)
parser.add_argument(
"--test_subset",
type=str,
default="held_out",
help="file name without extension to load data for the test. only used when `--src_data` "
- "is set to `'meds'`."
+ "is set to `'meds'`.",
)
+ parser.add_argument(
+ "--unique_events_path",
+ type=str,
+ default=None,
+ help="path to the directory containing the extracted unique events (unique_events_*.h5, typically "
+ "under a unique_events/ subdirectory). only used when `--src_data` is set to `'meds'`.",
+ )
+
parser.add_argument(
"--test_cohort",
type=str,
default=None,
help="path to the test cohort, which must be a result of ACES. it can be either of "
- "directory or the exact file path that has .parquet file extension. if provided with "
- "directory, it tries to load `${test_subset}`/*.parquet files in the directory. "
- "note that the set of patient ids in this cohort should be matched with that in the "
- "test dataset"
+ "directory or the exact file path that has .parquet file extension. if provided with "
+ "directory, it tries to load `${test_subset}`/*.parquet files in the directory. "
+ "note that the set of patient ids in this cohort should be matched with that in the "
+ "test dataset",
)
parser.add_argument(
@@ -76,11 +84,12 @@ def get_parser():
"readmission",
"los_7",
"los_14",
- "mortality_1",
+ "mortality" "mortality_1",
"mortality_2",
"mortality_3",
"mortality_7",
"mortality_14",
+ "mortality",
"diagnosis",
"creatinine_1",
"creatinine_2",
@@ -100,7 +109,7 @@ def get_parser():
"sodium_1",
"sodium_2",
"sodium_3",
- "meds_single_task"
+ "meds_single_task",
],
default=[
"readmission",
@@ -197,7 +206,9 @@ def get_parser():
parser.add_argument("--debug", action="store_true")
parser.add_argument("--log_loss", action="store_true")
# Wandb
- parser.add_argument("--wandb", action="store_true", help="whether to log using wandb")
+ parser.add_argument(
+ "--wandb", action="store_true", help="whether to log using wandb"
+ )
parser.add_argument("--wandb_entity_name", type=str)
parser.add_argument("--wandb_project_name", type=str, default="REMed")
parser.add_argument("--pretrained", type=str, default=None)
diff --git a/scripts/meds/encode_events.sh b/scripts/meds/encode_events.sh
index e6d3a3a..a2e1eff 100644
--- a/scripts/meds/encode_events.sh
+++ b/scripts/meds/encode_events.sh
@@ -2,15 +2,15 @@
# Function to display help message
function display_help() {
- echo "Usage: $0 "
+ echo "Usage: $0 "
echo
echo "This script encodes all events present in a MEDS cohort and caches them, which will"
echo "be the input data for the REMed model."
echo
echo "Arguments:"
- echo " PROCESSED_MEDS_DIR Directory containing processed MEDS data, expected to contain *.h5 and *.tsv files."
- echo " SAVE_DIR Output directory to save the encoded data as *_encoded.h5."
echo " GPU_ID GPU index to be used for training the model."
+ echo " UNIQUE_EVENTS_DIR directory containing the unique events to be encoded."
+ echo " SAVE_DIR Output directory to save the encoded unique events."
echo " PRETRAINED_CHECKPOINT_DIR Directory containing checkpoint for the pretrained event encoder, expected to contain checkpoint_best.pt."
echo
echo "Options:"
@@ -24,9 +24,9 @@ if [ "$#" -lt 4 ]; then
display_help
fi
-PROCESSED_MEDS_DIR="$1"
-SAVE_DIR="$2"
-GPU_ID="$3"
+GPU_ID="$1"
+UNIQUE_EVENTS_DIR="$2"
+SAVE_DIR="$3"
PRETRAINED_CHECKPOINT_DIR="$4"
accelerate launch \
@@ -35,11 +35,12 @@ accelerate launch \
--gpu_ids="$GPU_ID" \
main.py \
--src_data meds \
- --input_path "$PROCESSED_MEDS_DIR" \
+ --input_path null \
+ --unique_events_path "$UNIQUE_EVENTS_DIR" \
--save_dir "$SAVE_DIR" \
--pred_targets meds_single_task \
--train_type short \
- --random_sample \
+ --batch_size 8192 \
--encode_events \
--encode_only \
- --resume_name "$PRETRAINED_CHECKPOINT_DIR"
\ No newline at end of file
+ --resume_name "$PRETRAINED_CHECKPOINT_DIR"
diff --git a/scripts/meds/extract_unique_events.py b/scripts/meds/extract_unique_events.py
new file mode 100644
index 0000000..809f509
--- /dev/null
+++ b/scripts/meds/extract_unique_events.py
@@ -0,0 +1,137 @@
+import glob
+import logging
+import math
+import multiprocessing
+import os
+import shutil
+import sys
+from argparse import ArgumentParser
+from typing import List
+
+import h5pickle
+import numpy as np
+from tqdm import tqdm
+
+logging.basicConfig(
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
+ stream=sys.stdout,
+)
+logger = logging.getLogger(__name__)
+
+
+def get_parser():
+ parser = ArgumentParser()
+ parser.add_argument(
+ "root",
+ help="path to the **processed** MEDS dataset containing subdirectories for each split. "
+ "it will try to scan all **/*.h5 files existed in this directory and process them.",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="outputs",
+ help="directory to save the processed outputs.",
+ )
+ parser.add_argument(
+ "--workers",
+ metavar="N",
+ default=1,
+ type=int,
+ help="number of parallel workers.",
+ )
+ parser.add_argument(
+ "--n_events_per_shard",
+ metavar="N",
+ default=1000000,
+ type=int,
+ help="number of events included for each shard",
+ )
+
+ return parser
+
+
+def main(args):
+ filelist = glob.glob(os.path.join(args.root, "**/*.h5"), recursive=True)
+ files = [h5pickle.File(fname) for fname in filelist]
+
+ if args.workers <= 1:
+ unique_events = _extract_unique_events(files)
+ else:
+ n = args.workers
+ files_chunks = [files[i::n] for i in range(n)]
+
+ pool = multiprocessing.get_context("spawn").Pool(processes=args.workers)
+ unique_events_gathered = pool.map(_extract_unique_events, files_chunks)
+ pool.close()
+ pool.join()
+
+ logger.info("Gathering and reducing local unique events...")
+ unique_events = np.concatenate(unique_events_gathered)
+ unique_events = np.unique(unique_events, axis=0)
+ logger.info("Done!")
+
+ # rebase the output directory
+ if os.path.exists(os.path.join(args.output_dir, "unique_events")):
+ shutil.rmtree(os.path.join(args.output_dir, "unique_events"))
+ os.makedirs(os.path.join(args.output_dir, "unique_events"))
+
+ num_shards = math.ceil(len(unique_events) / args.n_events_per_shard)
+ for shard_id in range(num_shards):
+ start = shard_id * args.n_events_per_shard
+ end = min((shard_id + 1) * args.n_events_per_shard, len(unique_events))
+ sharded_unique_events = unique_events[start:end]
+ with h5pickle.File(
+ os.path.join(
+ args.output_dir, "unique_events", f"unique_events_{shard_id}.h5"
+ ),
+ "w",
+ ) as f:
+ for i, event_tuple in tqdm(
+ enumerate(sharded_unique_events), total=len(sharded_unique_events)
+ ):
+ idx = str(shard_id * args.n_events_per_shard + i)
+ data = f.create_group(idx)
+
+ sources = np.stack(
+ [
+ np.array(event_tuple[0]),
+ np.array(event_tuple[1]),
+ np.array(event_tuple[2]),
+ ]
+ )
+ data.create_dataset(
+ "sources", data=sources, dtype="i2", compression="lzf", shuffle=True
+ )
+
+
+def _extract_unique_events(files: List[h5pickle.File]):
+ unique_events = []
+ pbar = tqdm(files, total=len(files))
+ for f in pbar:
+ pbar.set_description(f.filename)
+ for sbj_id in f["ehr"]:
+ input_ids = f["ehr"][sbj_id]["hi"][:, 0]
+ type_ids = f["ehr"][sbj_id]["hi"][:, 1]
+ dpe_ids = f["ehr"][sbj_id]["hi"][:, 2]
+
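+ # each event is kept as a (input_ids, type_ids, dpe_ids) tuple triple so that
+ # np.unique can deduplicate identical events across subjects and shards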
+ event_tokens = [
+ (
+ tuple(input_id),
+ tuple(type_id),
+ tuple(dpe_id),
+ )
+ for input_id, type_id, dpe_id in zip(input_ids, type_ids, dpe_ids)
+ ]
+ event_tokens = list(np.unique(event_tokens, axis=0))
+ unique_events.extend(event_tokens)
+ unique_events = list(np.unique(unique_events, axis=0))
+
+ return unique_events
+
+
+if __name__ == "__main__":
+ parser = get_parser()
+ args = parser.parse_args()
+ main(args)
diff --git a/scripts/meds/map_events_to_vec.py b/scripts/meds/map_events_to_vec.py
new file mode 100644
index 0000000..211a9be
--- /dev/null
+++ b/scripts/meds/map_events_to_vec.py
@@ -0,0 +1,133 @@
+import functools
+import glob
+import logging
+import multiprocessing
+import os
+import pickle
+import shutil
+import sys
+from argparse import ArgumentParser
+from typing import List
+
+import h5pickle
+import numpy as np
+from tqdm import tqdm
+
+logging.basicConfig(
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
+ stream=sys.stdout,
+)
+logger = logging.getLogger(__name__)
+
+global_event_to_vec = None
+
+
+def init_worker(event_to_vec):
+ global global_event_to_vec
+ global_event_to_vec = event_to_vec
+ logger.info(
+ f"Process {multiprocessing.current_process().name} initialized with data"
+ )
+
+
+def get_parser():
+ parser = ArgumentParser()
+ parser.add_argument(
+ "root",
+ help="path to the **processed** MEDS dataset containing subdirectories for each split. "
+ "it will try to scan all **/*.h5 files existed in this directory except for "
+ "unique_events_*.h5 and process them.",
+ )
+ parser.add_argument(
+ "--map_dir",
+ help="path to the directory containing **`event_to_vec.pkl`**, which is generated by "
+ "running the model with `--encode_events=True` and `--encode_only=True`",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="outputs",
+ help="directory to save the processed outputs.",
+ )
+ parser.add_argument(
+ "--workers",
+ metavar="N",
+ default=1,
+ type=int,
+ help="number of parallel workers.",
+ )
+
+ return parser
+
+
+def main(args):
+ filelist = glob.glob(os.path.join(args.root, "**/*.h5"), recursive=True)
+ files = [h5pickle.File(fname) for fname in filelist if "unique_events" not in fname]
+ logger.info("Reading mapping dictionary for event vectors...")
+ with open(os.path.join(args.map_dir, "event_to_vec.pkl"), "rb") as f:
+ event_to_vec = pickle.load(f)
+
+ subdirs = [
+ os.path.relpath(os.path.dirname(p), os.path.abspath(args.root))
+ for p in filelist
+ if "unique_events" not in p
+ ]
+ output_dirs = [os.path.join(args.output_dir, subdir) for subdir in subdirs]
+ for subdir in np.unique(subdirs):
+ if os.path.exists(os.path.join(args.output_dir, subdir)):
+ shutil.rmtree(os.path.join(args.output_dir, subdir))
+ os.makedirs(os.path.join(args.output_dir, subdir))
+
+ logger.info("Mapping events to their representation vectors...")
+ if args.workers <= 1:
+ # initialize the module-level look-up table used by _map_events_to_vec
+ init_worker(event_to_vec)
+ _map_events_to_vec(output_dirs, files)
+ else:
+ n = args.workers
+ files_chunks = [files[i::n] for i in range(n)]
+ output_dir_chunks = [output_dirs[i::n] for i in range(n)]
+
+ pool = multiprocessing.Pool(
+ processes=args.workers, initializer=init_worker, initargs=(event_to_vec,)
+ )
+ pool.starmap(_map_events_to_vec, zip(output_dir_chunks, files_chunks))
+ pool.close()
+ pool.join()
+
+
+def _map_events_to_vec(output_dirs: List[str], files: List[h5pickle.File]):
+ assert len(files) == len(output_dirs)
+ for f, output_dir in zip(files, output_dirs):
+ output_name = os.path.basename(f.filename).split(".h5")[0] + "_encoded.h5"
+
+ with h5pickle.File(os.path.join(output_dir, output_name), "w") as output_f:
+ output_f.create_group("ehr")
+ for sbj_id in tqdm(f["ehr"], total=len(f["ehr"]), desc=output_name):
+ input_ids = f["ehr"][sbj_id]["hi"][:, 0] # S, 128
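+ # keys of the event_to_vec look-up table are tuples of the non-zero input token ids of each event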
+ event_tuples = [tuple(event[event != 0]) for event in input_ids]
+ event_vectors = np.array(
+ [global_event_to_vec[event] for event in event_tuples]
+ )
+ output_f["ehr"].create_group(sbj_id)
+ output_f["ehr"][sbj_id].create_dataset(
+ "encoded",
+ data=event_vectors,
+ dtype="i2",
+ compression="lzf",
+ shuffle=True,
+ chunks=event_vectors.shape,
+ )
+ output_f["ehr"][sbj_id].create_dataset(
+ "time",
+ data=f["ehr"][sbj_id]["time"][()],
+ )
+ output_f["ehr"][sbj_id].create_dataset(
+ "label", data=f["ehr"][sbj_id]["label"][()]
+ )
+
+
+if __name__ == "__main__":
+ parser = get_parser()
+ args = parser.parse_args()
+ main(args)
diff --git a/scripts/meds/predict.sh b/scripts/meds/predict.sh
index c37afc7..78e6303 100644
--- a/scripts/meds/predict.sh
+++ b/scripts/meds/predict.sh
@@ -2,15 +2,16 @@
# Function to display help message
function display_help() {
- echo "Usage: $0 "
+ echo "Usage: $0 "
echo
echo "This script produces predicted labels and their probabilities for a given task and its"
echo "cohort."
echo
echo "Arguments:"
- echo " ENCODED_MEDS_DIR Directory containing encoded MEDS data, expected to contain *_encoded.h5 files"
- echo " SAVE_DIR Output directory to save the predicted results."
echo " GPU_ID GPU index to be used for training the model."
+ echo " MEDS_DATA_DIR Directory containing MEDS data. Same with `$PROCESSED_MEDS_DIR` for GenHPF, or `$ENCODED_MEDS_DIR` for REMed"
+ echo " SAVE_DIR Output directory to save the predicted results."
+ echo " REMED_OR_SHORT String indicator for whether to test REMed model ('remed') or GenHPF model ('short)"
echo " ACES_TEST_COHORT_DIR Directory containing test cohorts generated from ACES, expected to contain *.parquet files."
echo " CHECKPOINT_DIR Directory containing checkpoint for the trained REMed model, expected to contain checkpoint_best.pt."
echo
@@ -25,11 +26,12 @@ if [ "$#" -lt 5 ]; then
display_help
fi
-ENCODED_MEDS_DIR="$1"
-SAVE_DIR="$2"
-GPU_ID="$3"
-ACES_TEST_COHORT_DIR="$4"
-CHECKPOINT_DIR="$5"
+GPU_ID="$1"
+MEDS_DATA_DIR="$2"
+SAVE_DIR="$3"
+REMED_OR_SHORT="$4"
+ACES_TEST_COHORT_DIR="$5"
+CHECKPOINT_DIR="$6"
accelerate launch \
--config_file config/single.json \
@@ -37,12 +39,21 @@ accelerate launch \
--gpu_ids $GPU_ID \
main.py \
--src_data meds \
- --input_path "$ENCODED_MEDS_DIR" \
+ --input_path "$MEDS_DATA_DIR" \
--save_dir "$SAVE_DIR" \
--pred_targets meds_single_task \
- --train_type remed \
- --scorer \
- --scorer_use_time \
+ --train_type "$REMED_OR_SHORT" \
--test_only \
- --test_cohort "$ACES_TEST_COHORT_DIR" \
- --resume_name "$CHECKPOINT_DIR"
\ No newline at end of file
+ --test_cohort "$ACES_TEST_COHORT_DIR" \
+ --resume_name "$CHECKPOINT_DIR"
+ # enable the following arguments for GenHPF model
+ # --n_agg_layers 4 \
+ # --pred_dim 128 \
+ # --max_seq_len 512 \
+ # --dropout 0.3 \
+
+ # enable the following arguments for REMed model
+ # --scorer \
+ # --scorer_use_time \
+ # --max_seq_len 200000 \
+ # --max_retrieve_len 512 \
diff --git a/scripts/meds/pretrain.sh b/scripts/meds/pretrain_genhpf.sh
similarity index 64%
rename from scripts/meds/pretrain.sh
rename to scripts/meds/pretrain_genhpf.sh
index de6b158..be78053 100644
--- a/scripts/meds/pretrain.sh
+++ b/scripts/meds/pretrain_genhpf.sh
@@ -2,15 +2,16 @@
# Function to display help message
function display_help() {
- echo "Usage: $0 "
+ echo "Usage: $0 "
echo
echo "This script pretrains event encoder using a MEDS cohort, which will be used to encode"
echo "all events present in the MEDS cohort for the REMed model later."
echo
echo "Arguments:"
+ echo " NUM_PROCESSES Number of parallel processes"
+ echo " GPU_IDS GPU index to be used for training the model."
echo " PROCESSED_MEDS_DIR Directory containing processed MEDS data, expected to contain *.h5 and *.tsv files."
echo " SAVE_DIR Output directory to save the checkpoint for the pretrained event encoder."
- echo " GPU_ID GPU index to be used for training the model."
echo
echo "Options:"
echo " -h, --help Display this help message and exit."
@@ -23,15 +24,15 @@ if [ "$#" -lt 3 ]; then
display_help
fi
-
-PROCESSED_MEDS_DIR="$1"
-SAVE_DIR="$2"
-GPU_ID="$3"
+NUM_PROCESSES="$1"
+GPU_IDS="$2"
+PROCESSED_MEDS_DIR="$3"
+SAVE_DIR="$4"
accelerate launch \
- --config_file config/single.json \
- --num_processes 1 \
- --gpu_ids="$GPU_ID" \
+ --config_file config/config.json \
+ --num_processes $NUM_PROCESSES \
+ --gpu_ids="$GPU_IDS" \
main.py \
--src_data meds \
--input_path "$PROCESSED_MEDS_DIR" \
@@ -39,4 +40,10 @@ accelerate launch \
--pred_targets meds_single_task \
--train_type short \
--lr 5e-5 \
- --random_sample
\ No newline at end of file
+ --batch_size 32 \
+ --random_sample \
+ --seed 2020 \
+ --patience 5 \
+ # --wandb \
+ # --wandb_project_name ??? \
+ # --wandb_entity_name ???
diff --git a/scripts/meds/process_meds.py b/scripts/meds/process_meds.py
index 65f6e7c..ff8f026 100644
--- a/scripts/meds/process_meds.py
+++ b/scripts/meds/process_meds.py
@@ -5,7 +5,7 @@
import os
import re
import shutil
-import time
+import warnings
from argparse import ArgumentParser
from bisect import bisect_left, bisect_right
from datetime import datetime
@@ -18,6 +18,9 @@
from tqdm import tqdm
from transformers import AutoTokenizer
+pool_manager = multiprocessing.Manager()
+warned_codes = pool_manager.list()
+
def find_boundary_between(tuples_list, start, end):
starts = [s for s, e in tuples_list]
@@ -44,6 +47,13 @@ def get_parser():
help="path to metadata directory for the input MEDS dataset, which contains codes.parquet",
)
+ parser.add_argument(
+ "--birth_code",
+ type=str,
+ default="MEDS_BIRTH",
+ help="string code for the birth event in the dataset."
+ )
+
parser.add_argument(
"--cohort",
type=str,
@@ -78,6 +88,13 @@ def get_parser():
help="number of parallel workers.",
)
+ parser.add_argument(
+ "--mimic_dir",
+ default=None,
+ help="path to directory for MIMIC-IV database containing hosp/ and icu/ as a subdirectory. "
+ "this is used for addressing missing descriptions in the metadata for MIMIC-IV codes.",
+ )
+
return parser
@@ -85,6 +102,7 @@ def main(args):
root_path = Path(args.root)
output_dir = Path(args.output_dir)
metadata_dir = Path(args.metadata_dir)
+ mimic_dir = Path(args.mimic_dir) if args.mimic_dir is not None else None
if not output_dir.exists():
output_dir.mkdir()
@@ -115,6 +133,24 @@ def main(args):
codes_metadata = pl.read_parquet(metadata_dir / "codes.parquet").to_pandas()
codes_metadata = codes_metadata.set_index("code")["description"].to_dict()
+ # do not allow to use static events or birth event
+ birth_code = args.birth_code
+ # if birth_code not in codes_metadata:
+ # print(
+ # f'"{birth_code}" is not found in the codes metadata, which may lead to '
+ # "unexpected results since we currently exclude this event from the input data. "
+ # )
+
+ if mimic_dir is not None:
+ d_items = pd.read_csv(mimic_dir / "icu" / "d_items.csv.gz")
+ d_items["itemid"] = d_items["itemid"].astype("str")
+ d_items = d_items.set_index("itemid")["label"].to_dict()
+ d_labitems = pd.read_csv(mimic_dir / "hosp" / "d_labitems.csv.gz")
+ d_labitems["itemid"] = d_labitems["itemid"].astype("str")
+ d_labitems = d_labitems.set_index("itemid")["label"].to_dict()
+ else:
+ d_items = None
+ d_labitems = None
progress_bar = tqdm(data_paths, total=len(data_paths))
for data_path in progress_bar:
@@ -129,15 +165,6 @@ def main(args):
else:
raise ValueError(f"Unsupported file format: {data_path.suffix}")
- # do not allow to use static events or birth event
- birth_code = (
- "MEDS_BIRTH" # NOTE can we assume code for "birth" is always "MEDS_BIRTH"?
- )
- if birth_code not in codes_metadata:
- print(
- f'"{birth_code}" is not found in the codes metadata, which may lead to '
- "unexpected results since we currently exclude this event from the input data. "
- )
data = data.with_columns(
pl.when(pl.col("code") == birth_code)
.then(None)
@@ -156,17 +183,21 @@ def main(args):
raise ValueError(f"Unsupported file format: {cohort_path.suffix}")
cohort = cohort.drop_nulls(label_col_name)
+ cohort = cohort.unique()
cohort = cohort.select(
- [pl.col("subject_id"),
+ [
+ pl.col("subject_id"),
pl.col(label_col_name),
# pl.col("input.end_summary").struct.field("timestamp_at_start").alias("starttime"),
pl.col("prediction_time").alias("endtime"),
]
)
- cohort = cohort.group_by(
- "subject_id", maintain_order=True
- ).agg(pl.col(["endtime", label_col_name])).collect() # omitted "starttime"
+ cohort = (
+ cohort.group_by("subject_id", maintain_order=True)
+ .agg(pl.col(["endtime", label_col_name]))
+ .collect()
+ ) # omitted "starttime"
cohort_dict = {
x["subject_id"]: {
# "starttime": x["starttime"],
@@ -208,12 +239,17 @@ def extract_cohort(row):
return {"cohort_end": None, "cohort_label": None}
data = data.group_by(["subject_id", "time"], maintain_order=True).agg(pl.all())
- data = data.with_columns(
- pl.struct(["subject_id", "time"])
- .map_elements(
- extract_cohort,
- return_dtype=pl.Struct(
- {"cohort_end": pl.List(pl.Datetime()), "cohort_label": pl.List(pl.Boolean)}
+ data = (
+ data.with_columns(
+ pl.struct(["subject_id", "time"])
+ .map_elements(
+ extract_cohort,
+ return_dtype=pl.Struct(
+ {
+ "cohort_end": pl.List(pl.Datetime()),
+ "cohort_label": pl.List(pl.Boolean),
+ }
+ ),
)
.alias("cohort_criteria")
)
@@ -245,7 +281,12 @@ def extract_cohort(row):
manifest_f.write("subject_id\tnum_events\tshard_id\n")
must_have_columns = [
- "subject_id", "cohort_end", "cohort_label", "time", "code", "numeric_value"
+ "subject_id",
+ "cohort_end",
+ "cohort_label",
+ "time",
+ "code",
+ "numeric_value",
]
rest_of_columns = [x for x in data.columns if x not in must_have_columns]
column_name_idcs = {col: i for i, col in enumerate(data.columns)}
@@ -259,6 +300,9 @@ def extract_cohort(row):
output_dir,
output_name,
args.workers,
+ d_items,
+ d_labitems,
+ warned_codes,
)
# meds --> remed
@@ -268,11 +312,12 @@ def extract_cohort(row):
del data
else:
subject_ids = data["subject_id"].unique().to_list()
- chunksize = math.ceil(len(subject_ids) / args.workers)
+ n = args.workers
+ subject_id_chunks = [subject_ids[i::n] for i in range(n)]
data_chunks = []
- for i in range(0, len(subject_ids), chunksize):
+ for subject_id_chunk in subject_id_chunks:
data_chunks.append(
- data.filter(pl.col("subject_id").is_in(subject_ids[i:i+chunksize]))
+ data.filter(pl.col("subject_id").is_in(subject_id_chunk))
)
del data
pool = multiprocessing.get_context("spawn").Pool(processes=args.workers)
@@ -280,6 +325,8 @@ def extract_cohort(row):
length_per_subject_gathered = pool.map(
meds_to_remed_partial, data_chunks
)
+ pool.close()
+ pool.join()
del data_chunks
if len(length_per_subject_gathered) != args.workers:
@@ -302,8 +349,13 @@ def meds_to_remed(
output_dir,
output_name,
num_shards,
- df_chunk
+ d_items,
+ d_labitems,
+ warned_codes,
+ df_chunk,
):
+ code_matching_pattern = re.compile(r"\d+")
+
def meds_to_remed_unit(row):
events = []
digit_offsets = []
@@ -313,13 +365,52 @@ def meds_to_remed_unit(row):
digit_offset = []
col_name_offset = []
for col_name in ["code", "numeric_value"] + rest_of_columns:
+ # do not process something like "icustay_id" or "hadm_id"
+ if "id" in col_name:
+ continue
+
col_event = row[column_name_idcs[col_name]][event_index]
if col_event is not None:
col_event = str(col_event)
if col_name == "code":
- if col_event in codes_metadata and codes_metadata[col_event] != "":
+ if (
+ col_event in codes_metadata
+ and codes_metadata[col_event] != ""
+ ):
col_event = codes_metadata[col_event]
- elif not "id" in col_name:
+ else:
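+ # fall back for codes without a metadata description: if any "//"-separated
+ # segment is purely numeric, try to resolve it as a MIMIC-IV itemid via
+ # d_items/d_labitems; unresolved codes are kept as plain text with a one-time warning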
+ do_break = False
+ items = col_event.split("//")
+ is_code = [
+ bool(code_matching_pattern.fullmatch(item))
+ for item in items
+ ]
+ if True in is_code:
+ if d_items is not None and d_labitems is not None:
+ code_idx = is_code.index(True)
+ code = items[code_idx]
+
+ if code in d_items:
+ desc = d_items[code]
+ elif code in d_labitems:
+ desc = d_labitems[code]
+ else:
+ do_break = True
+
+ if not do_break:
+ items[code_idx] = desc
+ col_event = "//".join(items)
+ else:
+ do_break = True
+
+ if do_break and col_event not in warned_codes:
+ warned_codes.append(col_event)
+ warnings.warn(
+ "The dataset contains some codes that are not specified in "
+ "the codes metadata, which may not be intended. Note that we "
+ f"process this code as it is for now: {col_event}."
+ )
+ else:
col_event = re.sub(
r"\d*\.\d+",
lambda x: str(round(float(x.group(0)), 4)),
@@ -457,22 +548,24 @@ def meds_to_remed_unit(row):
)
events_data = np.concatenate(events_data)
- df_chunk = df_chunk.select(
- ["subject_id", "cohort_end", "cohort_label", "time"]
- )
+ df_chunk = df_chunk.select(["subject_id", "cohort_end", "cohort_label", "time"])
df_chunk = df_chunk.insert_column(4, data_index)
df_chunk = df_chunk.explode(["cohort_end", "cohort_label"])
df_chunk = df_chunk.group_by(
# ["subject_id", "cohort_start", "cohort_end", "cohort_label"], maintain_order=True
- ["subject_id", "cohort_end", "cohort_label"], maintain_order=True
+ ["subject_id", "cohort_end", "cohort_label"],
+ maintain_order=True,
).agg(pl.all())
+ df_chunk = df_chunk.sort(by=["subject_id", "cohort_end"])
# regard {subject_id} as {cohort_id}: {subject_id}_{cohort_number}
df_chunk = df_chunk.with_columns(
pl.col("subject_id").cum_count().over("subject_id").alias("suffix")
)
df_chunk = df_chunk.with_columns(
- (pl.col("subject_id").cast(str) + "_" + pl.col("suffix").cast(str)).alias("subject_id")
+ (pl.col("subject_id").cast(str) + "_" + pl.col("suffix").cast(str)).alias(
+ "subject_id"
+ )
)
# data = data.drop("suffix", "cohort_start", "cohort_end")
df_chunk = df_chunk.drop("suffix", "cohort_end")
diff --git a/scripts/meds/train_genhpf.sh b/scripts/meds/train_genhpf.sh
new file mode 100644
index 0000000..fe30223
--- /dev/null
+++ b/scripts/meds/train_genhpf.sh
@@ -0,0 +1,53 @@
+#!/bin/bash
+
+# Function to display help message
+function display_help() {
+ echo "Usage: $0 "
+ echo
+ echo "This script pretrains event encoder using a MEDS cohort, which will be used to encode"
+ echo "all events present in the MEDS cohort for the REMed model later."
+ echo
+ echo "Arguments:"
+ echo " NUM_PROCESSES Number of parallel processes"
+ echo " GPU_IDS GPU indices to be used for training the model."
+ echo " PROCESSED_MEDS_DIR Directory containing processed MEDS data, expected to contain *.h5 and *.tsv files."
+ echo " SAVE_DIR Output directory to save the checkpoint for the pretrained event encoder."
+ echo
+ echo "Options:"
+ echo " -h, --help Display this help message and exit."
+ exit 1
+}
+
+# Check for mandatory parameters
+if [ "$#" -lt 3 ]; then
+ echo "Error: Incorrect number of arguments provided."
+ display_help
+fi
+
+
+NUM_PROCESSES="$1"
+GPU_IDS="$2"
+PROCESSED_MEDS_DIR="$3"
+SAVE_DIR="$4"
+
+accelerate launch \
+ --config_file config/config.json \
+ --num_processes $NUM_PROCESSES \
+ --gpu_ids="$GPU_IDS" \
+ main.py \
+ --src_data meds \
+ --input_path "$PROCESSED_MEDS_DIR" \
+ --save_dir "$SAVE_DIR" \
+ --pred_targets meds_single_task \
+ --train_type short \
+ --lr 5e-5 \
+ --n_agg_layers 4 \
+ --pred_dim 128 \
+ --batch_size 64 \
+ --max_seq_len 512 \
+ --dropout 0.3 \
+ --seed 2020 \
+ --patience 5 \
+ # --wandb \
+ # --wandb_project_name ??? \
+ # --wandb_entity_name ???
diff --git a/scripts/meds/train.sh b/scripts/meds/train_remed.sh
similarity index 62%
rename from scripts/meds/train.sh
rename to scripts/meds/train_remed.sh
index bd90c7a..ad0a60f 100644
--- a/scripts/meds/train.sh
+++ b/scripts/meds/train_remed.sh
@@ -2,15 +2,16 @@
# Function to display help message
function display_help() {
- echo "Usage: $0 "
+ echo "Usage: $0 "
echo
echo "This script encodes all the events present in a MEDS cohort and caches them, which will"
echo "be the input data for the REMed model."
echo
echo "Arguments:"
+ echo " NUM_PROCESSES Number of parallel processes"
+ echo " GPU_IDS GPU index to be used for training the model."
echo " ENCODED_MEDS_DIR Directory containing encoded MEDS data, expected to contain *_encoded.h5 files"
echo " SAVE_DIR Output directory to save the model checkpoint"
- echo " GPU_ID GPU index to be used for training the model."
echo
echo "Options:"
echo " -h, --help Display this help message and exit."
@@ -23,21 +24,27 @@ if [ "$#" -lt 3 ]; then
display_help
fi
-ENCODED_MEDS_DIR="$1"
-SAVE_DIR="$2"
-GPU_ID="$3"
+NUM_PROCESSES="$1"
+GPU_IDS="$2"
+ENCODED_MEDS_DIR="$3"
+SAVE_DIR="$4"
accelerate launch \
- --config_file config/single.json \
- --num_processes 1 \
- --gpu_ids $GPU_ID \
+ --config_file config/config.json \
+ --num_processes $NUM_PROCESSES \
+ --gpu_ids $GPU_IDS \
main.py \
--src_data meds \
--input_path "$ENCODED_MEDS_DIR" \
--save_dir "$SAVE_DIR" \
--pred_targets meds_single_task \
--train_type remed \
--lr 1e-5 \
+ --batch_size 32 \
--scorer \
--scorer_use_time \
- --max_seq_len 200000
\ No newline at end of file
+ --max_seq_len 200000 \
+ --max_retrieve_len 512 \
+ # --wandb \
+ # --wandb_project_name ??? \
+ # --wandb_entity_name ???
diff --git a/src/dataset.py b/src/dataset.py
index f58ad42..bd8ebb1 100644
--- a/src/dataset.py
+++ b/src/dataset.py
@@ -1,3 +1,4 @@
+import glob
import math
import os
import random
@@ -71,6 +72,7 @@ def collate_fn(self, out):
else:
padded = pad_sequence([i[k] for i in out], batch_first=True)
ret[k] = pad(padded, (0, 0, 0, padding_to - padded.shape[1]))
+
return ret
def get_labels(self, data):
@@ -220,6 +222,7 @@ def __init__(self, args, split, data_path, *pargs, **kwargs):
self.manifest = pd.read_csv(
os.path.join(data_path, split + ".tsv"), delimiter="\t"
).set_index("subject_id")
+ self.subject_ids = self.manifest.index.tolist()
unique_shard_ids = self.manifest["shard_id"].unique()
self.data = {}
@@ -247,7 +250,7 @@ def collate_fn(self, samples):
ret[k]["meds_single_task"] = torch.FloatTensor(
torch.stack([s["label"] for s in samples])
)
- elif k in ["subject_id", "index"]: # for MEDSForReprGen
+ elif k == "subject_id":
ret[k] = np.array([s[k] for s in samples])
else:
padded = pad_sequence([s[k] for s in samples], batch_first=True)
@@ -264,64 +267,77 @@ def __getitem__(self, idx):
# assume it is a scalar value for a binary classification task
label = torch.tensor([data["label"][()]]).float()
- if self.args.max_seq_len < input.shape[0]:
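+ # consider only the most recent `max_num_events` events: rebase times to that
+ # window before randomly sampling (or truncating to) `max_seq_len` events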
+ max_num_events = 300000
+ if self.args.max_seq_len < len(input):
+ length = len(input)
+ if length > max_num_events:
+ times = times - times[-max_num_events]
+
if self.args.random_sample:
indices = random.sample(
- range(0, input.shape[0]), self.args.max_seq_len
+ range(max(0, length - max_num_events), length),
+ self.args.max_seq_len,
)
+ indices.sort()
input = input[indices, :, :]
times = times[indices]
else:
input = input[-self.args.max_seq_len:, :, :]
times = times[-self.args.max_seq_len:]
+ times = times - times[0]
return {
"input_ids": torch.LongTensor(input[:, 0, :]),
"type_ids": torch.LongTensor(input[:, 1, :]),
"dpe_ids": torch.LongTensor(input[:, 2, :]),
"times": torch.IntTensor(times),
- "label": label
+ "label": label,
+ "subject_id": subject_id,
}
-class MEDSForReprGen(MEDSDataset):
- def __init__(self, args, split, data_path, *pargs, **kwargs):
- super().__init__(args, split, data_path, *pargs, **kwargs)
+class MEDSForReprGen(Dataset):
+ def __init__(self, args, data_path):
+ super().__init__()
- self.manifest["num_samples"] = self.manifest["num_events"].map(
- lambda x: math.ceil(x / args.max_seq_len)
- )
- self.manifest["last_sample_index"] = np.cumsum(self.manifest["num_samples"])
+ self.args = args
+
+ self.data = {}
+ if not data_path.endswith("unique_events"):
+ data_path = os.path.join(data_path, "unique_events")
+ for fname in glob.glob(os.path.join(data_path, "*.h5")):
+ shard_id = int(os.path.splitext(fname)[0].split("_")[-1])
+ self.data[shard_id] = h5pickle.File(
+ os.path.join(data_path, f"unique_events_{shard_id}.h5")
+ )
+ self.manifest = {}
+ for shard_id, data in self.data.items():
+ keys = data.keys()
+ shard_manifest = {k: shard_id for k in keys}
+ self.manifest.update(**shard_manifest)
+ self.keys = list(self.manifest.keys())
def __len__(self):
- return self.manifest["last_sample_index"].max()
+ return len(self.manifest)
+
+ def collate_fn(self, samples):
+ input_ids = torch.stack([s["input_ids"] for s in samples])
+ type_ids = torch.stack([s["type_ids"] for s in samples])
+ dpe_ids = torch.stack([s["dpe_ids"] for s in samples])
+
+ ret = {"input_ids": input_ids, "type_ids": type_ids, "dpe_ids": dpe_ids}
+
+ return ret
def __getitem__(self, idx):
- patient_index = self.manifest["last_sample_index"].searchsorted(
- idx, side="right"
- )
- subject_id = self.manifest.index[patient_index]
- shard_id = self.manifest.iloc[patient_index]["shard_id"]
- prev_idx = (
- 0
- if patient_index == 0
- else (self.manifest["last_sample_index"].iloc[patient_index - 1])
- )
- data = self.data[shard_id][str(subject_id)]
+ key = self.keys[idx]
+ shard_id = self.manifest[key]
+ data = self.data[shard_id][key]["sources"] # (3, 128)
- input = data["hi"]
- sample_idx_in_patient = idx - prev_idx
- start = self.args.max_seq_len * sample_idx_in_patient
- end = self.args.max_seq_len * (sample_idx_in_patient + 1)
- label = torch.tensor([data["label"][()]]).float()
return {
- "input_ids": torch.LongTensor(input[:, 0, :][start:end]),
- "type_ids": torch.LongTensor(input[:, 1, :][start:end]),
- "dpe_ids": torch.LongTensor(input[:, 2, :][start:end]),
- "times": torch.IntTensor(data["time"][start:end]),
- "subject_id": subject_id,
- "index": sample_idx_in_patient,
- "label": label,
+ "input_ids": torch.LongTensor(data[0, :]),
+ "type_ids": torch.LongTensor(data[1, :]),
+ "dpe_ids": torch.LongTensor(data[2, :]),
}
@@ -331,21 +347,31 @@ def __init__(self, args, split, data_path, *pargs, **kwargs):
self.args = args
- if not split.endswith("_encoded"):
- split = split + "_encoded"
+ self.data = {}
+ for fname in glob.glob(os.path.join(data_path, split, "*_encoded.h5")):
+ shard_id = int(os.path.splitext(fname)[0].split("_")[-2])
+ self.data[shard_id] = h5pickle.File(
+ os.path.join(data_path, split, split + f"_{shard_id}_encoded.h5")
+ )["ehr"]
+ self.manifest = {}
+ for shard_id, data in self.data.items():
+ keys = data.keys()
+ shard_manifest = {k: shard_id for k in keys}
+ self.manifest.update(**shard_manifest)
+ self.keys = list(self.manifest.keys())
- self.data = h5pickle.File(os.path.join(data_path, split + ".h5"))["ehr"]
- self.manifest = list(self.data.keys())
+ self.subject_ids = list(self.manifest.keys())
def __len__(self):
- return len(self.data)
+ return len(self.manifest)
def collate_fn(self, samples):
ret = dict()
max_sample_len = max([s["times"].shape[0] for s in samples])
- padding_to = min(
- 2 ** math.ceil(math.log(max_sample_len, 2)), self.args.max_seq_len
- )
+ # padding_to = min(
+ # 2 ** math.ceil(math.log(max_sample_len, 2)), self.args.max_seq_len
+ # )
+ padding_to = max_sample_len
for k, v in samples[0].items():
if k == "times":
@@ -369,7 +395,10 @@ def collate_fn(self, samples):
return ret
def __getitem__(self, idx):
- data = self.data[self.manifest[idx]]
+ subject_id = self.keys[idx]
+ shard_id = self.manifest[subject_id]
+ data = self.data[shard_id][subject_id]
+
encoded = data["encoded"][:]
times = data["time"][:]
@@ -377,14 +406,19 @@ def __getitem__(self, idx):
repr = torch.FloatTensor(encoded)
times = torch.IntTensor(times)
- times = max(times) - times
label = torch.tensor([data["label"][()]]).float()
+ if len(repr) > self.args.max_seq_len:
+ repr = repr[-self.args.max_seq_len :, :]
+ times = times[-self.args.max_seq_len :]
+ # inverse times
+ times = max(times) - times
+
return {
"repr": repr,
"times": times,
"label": label,
- "subject_id": self.manifest[idx]
+ "subject_id": subject_id,
}
diff --git a/src/models/eventencoder.py b/src/models/eventencoder.py
index ecf4eec..6c859f2 100644
--- a/src/models/eventencoder.py
+++ b/src/models/eventencoder.py
@@ -31,6 +31,12 @@ def __init__(self, args):
)
def forward(self, all_codes_embs, input_ids, **kwargs):
+ # all_codes_embs: (B * S, L, Hidden) -- (16 * 512, 128, 512)
+ # input_ids: (B, S, L) -- (16, 512, 128)
+ if input_ids.ndim == 2:
+ assert input_ids.size(0) == all_codes_embs.size(0)
+ input_ids = input_ids.unsqueeze(1) # (B, L) -> (B, 1, L)
+
B, S, L = input_ids.shape
# All-padding col -> cause nan output -> unmask it (and multiply 0 to the results)
src_pad_mask = (input_ids.reshape(-1, L).eq(0)) ^ (
diff --git a/src/models/model.py b/src/models/model.py
index 6b25829..3e9b4c0 100644
--- a/src/models/model.py
+++ b/src/models/model.py
@@ -478,16 +478,16 @@ def __init__(self, args):
self.time_encoder = FlattenTimeEncoding(args)
self.layer_norm = nn.LayerNorm(args.pred_dim, eps=1e-12)
- def forward(self, input_ids, type_ids, dpe_ids, times, **kwargs):
- B, S = input_ids.shape[0], input_ids.shape[1]
-
+ def forward(self, input_ids, type_ids, dpe_ids, times=None, **kwargs):
x = self.input_ids_embedding(input_ids)
x += self.type_ids_embedding(type_ids)
x += self.dpe_ids_embedding(dpe_ids)
if "flatten" in self.args.train_type: # (B, S, E) -> (B, S, E)
x = self.time_encoder(x, times, **kwargs)
- else: # (B, S, W, E) -> (B*S, W, E)
- x = x.view(B * S, -1, self.args.pred_dim)
+ else:
+ if input_ids.ndim == 3: # x: (B, S, W, E) -> (B*S, W, E)
+ B, S, W = input_ids.shape
+ x = x.view(B * S, -1, self.args.pred_dim)
x = self.pos_encoder(x)
x = self.layer_norm(x)
return x
diff --git a/src/trainer/base.py b/src/trainer/base.py
index 60a5e22..2565a11 100644
--- a/src/trainer/base.py
+++ b/src/trainer/base.py
@@ -1,13 +1,15 @@
import heapq
import logging
import os
+import pickle
import uuid
from contextlib import nullcontext
+from datetime import timedelta
from shutil import rmtree
import polars as pl
import torch
-from accelerate import Accelerator
+from accelerate import Accelerator, InitProcessGroupKwargs
from accelerate.logging import get_logger
from accelerate.utils import broadcast, set_seed
from h5pickle import File
@@ -38,16 +40,26 @@ def __init__(self, args):
if self.test_cohort is not None:
# make subject_id to be {subject_id}_{cohort_number} to prevent duplicated ids
if os.path.isdir(self.test_cohort):
- test_cohort = pl.read_parquet(
- os.path.join(self.test_cohort, "*.parquet")
- )
+ if os.path.basename(self.test_cohort) != self.test_subset:
+ test_cohort = pl.read_parquet(
+ os.path.join(self.test_cohort, self.test_subset, "*.parquet")
+ )
+ else:
+ test_cohort = pl.read_parquet(
+ os.path.join(self.test_cohort, "*.parquet")
+ )
else:
test_cohort = pl.read_parquet(self.test_cohort)
+ test_cohort = test_cohort.sort(by=["subject_id", "prediction_time"])
test_cohort = test_cohort.with_columns(
pl.col("subject_id").cum_count().over("subject_id").alias("suffix")
)
test_cohort = test_cohort.with_columns(
- (pl.col("subject_id").cast(str) + "_" + pl.col("suffix").cast(str)).alias("subject_id")
+ (
+ pl.col("subject_id").cast(str)
+ + "_"
+ + pl.col("suffix").cast(str)
+ ).alias("subject_id")
)
test_cohort = test_cohort.drop("suffix")
self.test_cohort = test_cohort
@@ -64,13 +76,26 @@ def __init__(self, args):
self.log = None if self.args.debug or not self.args.wandb else "wandb"
def run(self):
+ ipg_handler = InitProcessGroupKwargs(timeout=timedelta(hours=24))
self.accelerator = Accelerator(
- log_with=self.log, split_batches=True, mixed_precision="bf16"
+ kwargs_handlers=[ipg_handler],
+ log_with=self.log,
+ split_batches=True,
+ mixed_precision="bf16",
)
self.args.local_batch_size = (
self.args.batch_size // self.accelerator.num_processes
)
- if self.args.resume_name:
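+ # for MEDS, derive the experiment name from the resume directory (test)
+ # or the save directory plus seed (train) instead of a random hex id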
+ if self.args.src_data == "meds":
+ if self.args.test_only:
+ if self.args.resume_name.endswith("/"):
+ self.args.resume_name = self.args.resume_name[:-1]
+ self.args.exp_name = os.path.basename(self.args.resume_name)
+ else:
+ if self.args.save_dir.endswith("/"):
+ self.args.save_dir = self.args.save_dir[:-1]
+ self.args.exp_name = os.path.basename(self.args.save_dir) + "_" + str(self.args.seed)
+ elif self.args.resume_name:
self.args.exp_name = self.args.resume_name
else:
self.args.exp_name = f"{uuid.uuid4().hex}_{self.args.seed}"
@@ -95,7 +120,7 @@ def run(self):
exp_encoded = broadcast(exp_encoded.to(self.accelerator.device))
self.args.exp_name = "".join([chr(int(i)) for i in exp_encoded])
- if not self.args.encode_only:
+ if not self.args.encode_only and not self.args.test_only:
os.makedirs(
os.path.join(self.args.save_dir, self.args.exp_name), exist_ok=True
)
@@ -120,6 +145,13 @@ def run(self):
if self.args.encode_events or self.args.encode_only:
if self.args.src_data == "meds":
+ assert self.args.encode_events and self.args.encode_only, (
+ "encoding MEDS dataset should be run with both the `self.args.encode_events` "
+ "and `self.args.encode_only` being True."
+ )
+ assert (
+ self.args.unique_events_path is not None
+ ), "`--unique_events_path` shuold be provided to encode MEDS dataset."
self.encode_events_meds()
else:
self.encode_events()
@@ -148,6 +180,9 @@ def train(self):
valid_loader = self.dataloader_set(self.valid_subset)
model = self.architecture(self.args)
+ assert (
+ self.args.pretrained is None or self.args.resume_name is None
+ ), "--pretrained and --resume_name should not be provided together"
if self.args.pretrained and not self.args.no_pretrained_checkpoint:
if self.args.src_data == "meds":
pretrained_path = os.path.join(
@@ -158,6 +193,9 @@ def train(self):
self.args.save_dir, self.args.pretrained, "checkpoint_best.pt"
)
model = load_model(pretrained_path, model)
+ elif self.args.src_data == "meds" and self.args.resume_name is not None:
+ resume_path = os.path.join(self.args.resume_name, "checkpoint_last.pt")
+ model = load_model(resume_path, model)
if self.args.enable_fsdp:
model = self.accelerator.prepare(model)
@@ -277,8 +315,7 @@ def test(self):
.map_elements(lambda x: x.split("_")[0], return_dtype=pl.String)
.cast(int)
)
- exp_name = os.path.basename(self.args.exp_name)
- save_dir = os.path.join(self.args.save_dir, exp_name)
+ save_dir = self.args.save_dir
save_path = os.path.join(save_dir, f"{self.test_subset}.parquet")
if not os.path.exists(save_dir):
os.makedirs(save_dir)
@@ -295,7 +332,7 @@ def test(self):
return metric_dict
def dataloader_set(self, split):
- if split is None:
+ if self.args.src_data == "meds" and split is None:
return None
dataset = self.dataset(self.args, split, self.data, self.df)
return DataLoader(
@@ -303,7 +340,7 @@ def dataloader_set(self, split):
batch_size=self.args.batch_size,
shuffle=False,
num_workers=8,
- pin_memory=True,
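+            # assumption: pinned host memory is disabled to reduce page-locked memory pressure with 8 workers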
+ pin_memory=False,
collate_fn=dataset.collate_fn,
persistent_workers=True,
)
@@ -322,8 +359,34 @@ def epoch(self, split, data_loader, n_epoch=0):
t = tqdm(data_loader, desc=f"{split} epoch {n_epoch}")
else:
t = data_loader
+
+ do_output_cohort = False
+ if self.args.src_data == "meds" and (
+ split == self.test_subset and self.test_cohort is not None
+ ):
+ if self.accelerator.num_processes == 1:
+ # check if test cohort is valid
+ assert set(data_loader.dataset.subject_ids) == set(self.test_cohort["subject_id"]), (
+ "a set of patient ids in the test cohort should equal to that in the test dataset"
+ )
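+                # accumulate per-sample probabilities to join back onto the ACES test cohort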
+ predicted_cohort = {"subject_id": [], "predicted_boolean_probability": []}
+ do_output_cohort = True
+ else:
+ logger.warning(
+ "not yet implemented to output predicted labels and probs with "
+ "--test_cohort in multi-processing environment. please run with "
+ "--num_processes=1 in accelerate launch."
+ )
+
for sample in t:
output, reprs = self.model(**sample)
+            # MEDS: collect per-sample probabilities for the predicted cohort output
+ if do_output_cohort:
+ predicted_cohort["subject_id"].extend(sample["subject_id"].tolist())
+ predicted_cohort["predicted_boolean_probability"].extend(
+ output["pred"]["meds_single_task"].view(-1).tolist()
+ )
+
loss, logging_outputs = self.criterion(output, reprs)
if split == self.train_subset:
self.optimizer.zero_grad(set_to_none=True)
@@ -338,11 +401,26 @@ def epoch(self, split, data_loader, n_epoch=0):
):
self.accelerator.log({f"{split}_loss": loss})
+ if do_output_cohort:
+ predicted_cohort = pl.DataFrame(predicted_cohort)
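+            # threshold probabilities at 0.5 to derive boolean predictions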
+ predicted_cohort = predicted_cohort.with_columns(
+ (pl.col("predicted_boolean_probability") > 0.5).alias("predicted_boolean_value")
+ )
+ self.test_cohort = self.test_cohort.join(predicted_cohort, on="subject_id", how="left")
+ self.test_cohort = self.test_cohort.select(
+ [
+ pl.col("subject_id"),
+ pl.col("prediction_time"),
+ pl.col("boolean_value"),
+ pl.col("predicted_boolean_value"),
+ pl.col("predicted_boolean_probability"),
+ ]
+ )
+
metrics = self.metric.get_metrics()
log_dict = log_from_dict(metrics, split, n_epoch)
- if self.log is None:
- print(log_dict)
- else:
+ logger.info(log_dict)
+ if self.log is not None:
self.accelerator.log(log_dict)
return metrics
@@ -356,8 +434,10 @@ def encode_events(self):
model = self.architecture(self.args)
model = load_model(best_model_path, model)
self.model = self.accelerator.prepare(model)
- self.args.max_seq_len = 1024
- self.args.batch_size = 8
+            self.args.max_seq_len = 512  # reduced from 1024
+            self.args.batch_size = 16  # increased from 8
else:
self.args.max_seq_len = 512
self.args.batch_size = 8
@@ -372,7 +452,7 @@ def _get_hdf5_path(i):
postfix = "" if i == -1 else "_" + str(i)
return os.path.join(
self.args.save_dir,
- self.args.exp_name,
+                # NOTE: exp_name is intentionally dropped so encoded files sit directly under save_dir
f"{self.args.src_data}_encoded{postfix}.h5",
)
@@ -386,23 +466,19 @@ def _get_hdf5_path(i):
for k in self.data["ehr"].keys():
k = str(k)
stay_g = encoded.create_group(k)
- stay_g.create_dataset(
- "encoded",
- shape=(len(self.data["ehr"][k]["time"]), self.args.pred_dim),
- dtype="i2",
- compression="lzf",
- shuffle=True,
- chunks=(len(self.data["ehr"][k]["time"]), self.args.pred_dim),
- )
+ stay_g.create_dataset("time", data=self.data["ehr"][k]["time"][()])
+ stay_g.attrs.update(self.data["ehr"][k].attrs)
self.accelerator.wait_for_everyone()
with torch.no_grad():
loader = enumerate(dataloader)
if self.accelerator.is_main_process:
- loader = tqdm(loader)
+ loader = tqdm(loader, total=len(dataloader))
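+            # buffer: stay_id -> min-heap of (chunk_index, chunk_repr), so chunks concatenate in order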
+ buffer = {}
for i, batch in loader:
self.step = i
all_codes_embs = self.model.input2emb_model(**batch)
+            # e.g. input_ids (B, S, W) = (16, 512, 128) -> embeddings (B, S, W, E) -> (B*S, W, E)
reprs = self.model.eventencoder_model(
all_codes_embs, **batch
@@ -411,26 +487,38 @@ def _get_hdf5_path(i):
stay_ids = batch["stay_id"].cpu().numpy().reshape(-1)
indices = batch["index"].cpu().numpy().reshape(-1)
for repr, stay_id, index in zip(reprs, stay_ids, indices):
- stay_id, index = int(stay_id), int(index)
start = self.args.max_seq_len * index
end = start + self.args.max_seq_len
- max_len = encoded[str(stay_id)]["encoded"].shape[0]
+ max_len = dataloader.dataset.df["time"].loc[stay_id]
if end > max_len:
repr = repr[: max_len - start, :]
end = max_len
- encoded[str(stay_id)]["encoded"][start:end, :] = repr
+ if stay_id not in buffer:
+ buffer[stay_id] = []
+ heapq.heappush(buffer[stay_id], (index, repr))
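+                # flush stays whose chunks are all buffered, every 100 batches and on the last batch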
+                if ((i + 1) % 100 == 0) or ((i + 1) == len(dataloader)):  # use len(dataloader): the bare enumerate has no len()
+ for stay_id in list(buffer.keys()):
+ items = buffer[stay_id]
+ num_events = dataloader.dataset.df["time"].loc[stay_id]
+ num_samples = dataloader.dataset.df["num_sample_per_pat"].loc[
+ stay_id
+ ]
+ if len(items) == num_samples:
+ data = np.concatenate([x[1] for x in items])
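+                        # assumption: the "i2" dataset stores bfloat16 bit patterns reinterpreted as int16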
+ encoded[str(stay_id)].create_dataset(
+ "encoded",
+ data=data,
+ dtype="i2",
+ compression="lzf",
+ shuffle=True,
+ chunks=(num_events, self.args.pred_dim),
+ )
+ del buffer[stay_id]
f.close()
self.accelerator.wait_for_everyone()
if self.accelerator.num_processes == 1:
os.rename(hdf5_path, _get_hdf5_path(-1))
- main_file = File(_get_hdf5_path(-1), "r+")
- for k in tqdm(self.data["ehr"].keys()):
- main_file["ehr"][k].create_dataset(
- "time", data=self.data["ehr"][k]["time"][()]
- )
- main_file["ehr"][k].attrs.update(self.data["ehr"][k].attrs)
- main_file.close()
else:
if self.accelerator.is_main_process:
main_file = File(_get_hdf5_path(-1), "w")
@@ -470,138 +558,87 @@ def _get_hdf5_path(i):
# to add compatibility with meds dataset
def encode_events_meds(self):
- if self.args.encode_only:
- best_model_path = os.path.join(self.args.resume_name, "checkpoint_best.pt")
- else:
- best_model_path = os.path.join(
- self.args.save_dir, self.args.exp_name, "checkpoint_best.pt"
- )
- if self.args.train_type != "bioclinicalbert_encode":
- model = self.architecture(self.args)
- model = load_model(best_model_path, model)
- self.model = self.accelerator.prepare(model)
- self.args.max_seq_len = 1024
- self.args.batch_size = 8
- else:
- self.args.max_seq_len = 512
- self.args.batch_size = 8
+ best_model_path = os.path.join(self.args.resume_name, "checkpoint_best.pt")
+ model = self.architecture(self.args)
+ model = load_model(best_model_path, model)
+ self.model = self.accelerator.prepare(model)
self.model.eval()
- self.dataset = MEDSForReprGen
-
- for split in [self.train_subset, self.valid_subset, self.test_subset]:
- if split is None:
- continue
- logger.info(f"Start Encoding for {split} split")
+        logger.info(
+            "Start generating representation vectors for each unique event in the MEDS dataset"
+        )
+ dataset = MEDSForReprGen(self.args, self.args.unique_events_path)
+ dataloader = DataLoader(
+ dataset,
+ batch_size=self.args.batch_size,
+ shuffle=False,
+ num_workers=8,
+ pin_memory=False,
+ collate_fn=dataset.collate_fn,
+ persistent_workers=True,
+ )
+ dataloader = self.accelerator.prepare(dataloader)
- dataloader = self.dataloader_set(split)
- dataloader = self.accelerator.prepare(dataloader)
+ event_to_vec = {}
+ with torch.no_grad():
+ loader = enumerate(dataloader)
+ loader = tqdm(
+ loader,
+ total=len(dataloader),
+ desc=str(self.accelerator.local_process_index),
+ )
+ for i, batch in loader:
+ self.step = i
- def _get_hdf5_path(i):
- postfix = "" if i == -1 else "_" + str(i)
- return os.path.join(
- self.args.save_dir,
- f"{split}_encoded{postfix}.h5",
+ embedded = self.accelerator.unwrap_model(self.model).input2emb_model(
+ **batch
)
+            event_vectors = self.accelerator.unwrap_model(self.model).eventencoder_model(
+                embedded, **batch
+            ).squeeze(1)  # (B, 1, E) -> (B, E)
+ event_vectors = event_vectors.cpu().bfloat16().view(torch.int16).numpy()
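+            # bfloat16 bit patterns are stored as int16 to halve storage; a sketch to
+            # recover floats later (assuming this layout):
+            #   torch.from_numpy(vec).view(torch.bfloat16).float()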
- hdf5_path = _get_hdf5_path(self.accelerator.local_process_index)
- logger.info("Writing metadata to HDF5")
+ input_ids = batch["input_ids"].cpu().numpy() # (B, 128)
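+            # key each event by its non-padding token ids so identical events map to
+            # a single representation vector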
+ for j, event in enumerate(input_ids):
+ event_tuple = tuple(event[event != 0])
+ event_to_vec[event_tuple] = event_vectors[j]
- with File(hdf5_path, "w") as f:
- f.create_group("ehr")
- encoded = f["ehr"]
+ def get_local_path(i):
+ postfix = "" if i == -1 else "_" + str(i)
+ return os.path.join(self.args.save_dir, f"event_to_vec{postfix}.pkl")
- for subject_id, metadata in dataloader.dataset.manifest.iterrows():
- stay_g = encoded.create_group(subject_id)
- stay_g.create_dataset(
- "time",
- data=dataloader.dataset.data[metadata["shard_id"]][subject_id]["time"][()],
- dtype="i"
- )
- stay_g.create_dataset(
- "label",
- data=dataloader.dataset.data[metadata["shard_id"]][subject_id]["label"][()],
- )
- self.accelerator.wait_for_everyone()
-
- with torch.no_grad():
- loader = enumerate(dataloader)
- if self.accelerator.is_main_process:
- loader = tqdm(loader, total=len(dataloader))
- buffer = {}
- for i, batch in loader:
- self.step = i
- all_codes_embs = self.model.input2emb_model(**batch)
-
- reprs = self.model.eventencoder_model(
- all_codes_embs, **batch
- ) # B, S, E
- reprs = reprs.cpu().bfloat16().view(torch.int16).numpy()
- subject_ids = batch["subject_id"].reshape(-1)
- indices = batch["index"].reshape(-1)
- for repr, subject_id, index, in zip(reprs, subject_ids, indices):
- start = self.args.max_seq_len * index
- end = start + self.args.max_seq_len
- max_len = dataloader.dataset.manifest.loc[subject_id]["num_events"]
- if end > max_len:
- repr = repr[: max_len - start, :]
- end = max_len
- if subject_id not in buffer:
- buffer[subject_id] = []
- heapq.heappush(buffer[subject_id], (index, repr))
- if ((i + 1) % 100 == 0) or ((i + 1) == len(loader)):
- # flush buffer if applicable
- for subject_id in list(buffer.keys()):
- items = buffer[subject_id]
- metadata = dataloader.dataset.manifest.loc[subject_id]
- if len(items) == metadata["num_samples"]:
- data = np.concatenate([x[1] for x in items])
- encoded[subject_id].create_dataset(
- "encoded",
- data=data,
- dtype="i2",
- compression="lzf",
- shuffle=True,
- chunks=(
- metadata["num_events"],
- self.args.pred_dim,
- ),
- )
- del buffer[subject_id]
+ local_path = get_local_path(self.accelerator.local_process_index)
+ logger.info("Saving the resulted vector maps...")
+ with open(local_path, "wb") as f:
+ pickle.dump(event_to_vec, f)
+ logger.info("Done!")
+ self.accelerator.wait_for_everyone()
- self.accelerator.wait_for_everyone()
- if self.accelerator.num_processes == 1:
- os.rename(hdf5_path, _get_hdf5_path(-1))
- else:
- if self.accelerator.is_main_process:
- main_file = File(_get_hdf5_path(-1), "w")
- main_file.create_group("ehr")
- files = [
- File(_get_hdf5_path(i), "r")
- for i in range(self.accelerator.num_processes)
- ]
- for k in tqdm(dataloader.dataset.manifest.index):
- # Chunkwise sum, but may be duplicated chunks
- encodeds = (
- np.stack(
- [f["ehr"][k]["encoded"][()] for f in files], axis=0
- )
- .astype(np.uint16)
- .max(axis=0)
- .astype(np.int16)
- )
- main_file["ehr"].create_group(k)
- main_file["ehr"][k].create_dataset(
- "encoded",
- data=encodeds,
- dtype="i2",
- compression="lzf",
- shuffle=True,
- chunks=encodeds.shape,
- )
- for f in files:
- f.close()
- main_file.close()
- for i in range(self.accelerator.num_processes):
- os.remove(_get_hdf5_path(i))
- self.accelerator.wait_for_everyone()
+ if self.accelerator.num_processes == 1:
+ if os.path.exists(get_local_path(-1)):
+ os.remove(get_local_path(-1))
+ os.rename(local_path, get_local_path(-1))
+ else:
+ if self.accelerator.is_main_process:
+ main_dict = {}
+ local_dicts = []
+ for i in range(self.accelerator.num_processes):
+ with open(get_local_path(i), "rb") as local_f:
+ local_dicts.append(pickle.load(local_f))
+ logger.info("Gathering and summarizing local vector maps...")
+ for local_dict in local_dicts:
+                    # NOTE: the dict union operator (|) requires Python >= 3.9
+ main_dict = main_dict | local_dict
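+                    # (on older Pythons, main_dict.update(local_dict) is equivalent)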
+
+ if os.path.exists(get_local_path(-1)):
+ os.remove(get_local_path(-1))
+ with open(get_local_path(-1), "wb") as main_f:
+ pickle.dump(main_dict, main_f)
+
+ for i in range(self.accelerator.num_processes):
+ os.remove(get_local_path(i))
+ self.accelerator.wait_for_everyone()
diff --git a/src/trainer/remed.py b/src/trainer/remed.py
index 5b1c563..eb93cf2 100644
--- a/src/trainer/remed.py
+++ b/src/trainer/remed.py
@@ -75,10 +75,10 @@ def step(sample):
):
if self.accelerator.num_processes == 1:
# check if test cohort is valid
- assert set(data_loader.dataset.manifest) == set(self.test_cohort["subject_id"]), (
- "a set of patient ids in the test cohort should equal to that in the test dataset"
- )
- predicted_cohort = {"subject_id": [], "boolean_prediction": []}
+ assert set(data_loader.dataset.manifest) == set(
+ self.test_cohort["subject_id"]
+ ), "a set of patient ids in the test cohort should equal to that in the test dataset"
+ predicted_cohort = {"subject_id": [], "predicted_boolean_probability": []}
do_output_cohort = True
else:
logger.warning(
@@ -89,22 +89,36 @@ def step(sample):
for sample in t:
if split == self.train_subset:
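+                # unwrap the accelerate (e.g., DDP) wrapper to reach REMed's custom set_mode()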
- self.model.set_mode("scorer")
+ self.accelerator.unwrap_model(self.model).set_mode("scorer")
net_output, logging_output = step(sample)
- self.model.set_mode("predictor")
+ self.accelerator.unwrap_model(self.model).set_mode("predictor")
net_output, logging_output = step(sample)
# meds -- output
if do_output_cohort:
predicted_cohort["subject_id"].extend(sample["subject_id"].tolist())
- predicted_cohort["boolean_prediction"].extend(
- net_output['pred']['meds_single_task'].view(-1).tolist()
+ predicted_cohort["predicted_boolean_probability"].extend(
+ net_output["pred"]["meds_single_task"].view(-1).tolist()
)
if do_output_cohort:
predicted_cohort = pl.DataFrame(predicted_cohort)
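+            # as in base.py: threshold at 0.5 and align columns with the cohort schema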
- self.test_cohort = self.test_cohort.join(predicted_cohort, on="subject_id", how="left")
+ predicted_cohort = predicted_cohort.with_columns(
+ (pl.col("predicted_boolean_probability") > 0.5).alias("predicted_boolean_value")
+ )
+ self.test_cohort = self.test_cohort.join(
+ predicted_cohort, on="subject_id", how="left"
+ )
+ self.test_cohort = self.test_cohort.select(
+ [
+ pl.col("subject_id"),
+ pl.col("prediction_time"),
+ pl.col("boolean_value"),
+ pl.col("predicted_boolean_value"),
+ pl.col("predicted_boolean_probability"),
+ ]
+ )
metrics = self.metric.get_metrics()
log_dict = log_from_dict(metrics, split, n_epoch)