diff --git a/.gitignore b/.gitignore index 85a6aa6df..5f887de5e 100644 --- a/.gitignore +++ b/.gitignore @@ -130,4 +130,9 @@ leaderboard/credentials.json leaderboard/rtd_token.txt # locally pre-trained models -pyhealth/medcode/pretrained_embeddings/kg_emb/examples/pretrained_model \ No newline at end of file +pyhealth/medcode/pretrained_embeddings/kg_emb/examples/pretrained_model + +# local testing files +halo_testing/ +halo_testing_script.py +test_halo_model.slurm \ No newline at end of file diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index f0e4f53e7..2d1f4f4a0 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -32,6 +32,7 @@ def __init__(self, *args, **kwargs): from .eicu import eICUDataset from .isruc import ISRUCDataset from .medical_transcriptions import MedicalTranscriptionsDataset +from .halo_mimic3 import HALO_MIMIC3Dataset from .mimic3 import MIMIC3Dataset from .mimic4 import MIMIC4CXRDataset, MIMIC4Dataset, MIMIC4EHRDataset, MIMIC4NoteDataset from .mimicextract import MIMICExtractDataset diff --git a/pyhealth/datasets/configs/hcup_ccs_2015_definitions_benchmark.yaml b/pyhealth/datasets/configs/hcup_ccs_2015_definitions_benchmark.yaml new file mode 100644 index 000000000..cf0caddfa --- /dev/null +++ b/pyhealth/datasets/configs/hcup_ccs_2015_definitions_benchmark.yaml @@ -0,0 +1,149 @@ +"Septicemia (except in labor)": + use_in_benchmark: True + type: "acute" + id: 2 + codes: [ "0031", "0202", "0223", "0362", "0380", "0381", "03810", "03811", "03812", "03819", "0382", "0383", "03840", "03841", "03842", "03843", "03844", "03849", "0388", "0389", "0545", "449", "77181", "7907", "99591", "99592" ] + +"Diabetes mellitus without complication": + use_in_benchmark: True + type: "chronic" + id: 49 + codes: [ "24900", "25000", "25001", "7902", "79021", "79022", "79029", "7915", "7916", "V4585", "V5391", "V6546" ] + +"Diabetes mellitus with complications": + use_in_benchmark: True + type: "chronic" + id: 50 
+ codes: [ "24901", "24910", "24911", "24920", "24921", "24930", "24931", "24940", "24941", "24950", "24951", "24960", "24961", "24970", "24971", "24980", "24981", "24990", "24991", "25002", "25003", "25010", "25011", "25012", "25013", "25020", "25021", "25022", "25023", "25030", "25031", "25032", "25033", "25040", "25041", "25042", "25043", "25050", "25051", "25052", "25053", "25060", "25061", "25062", "25063", "25070", "25071", "25072", "25073", "25080", "25081", "25082", "25083", "25090", "25091", "25092", "25093" ] + +"Disorders of lipid metabolism": + use_in_benchmark: True + type: "chronic" + id: 53 + codes: [ "2720", "2721", "2722", "2723", "2724" ] + +"Fluid and electrolyte disorders": + use_in_benchmark: True + type: "acute" + id: 55 + codes: [ "2760", "2761", "2762", "2763", "2764", "2765", "27650", "27651", "27652", "2766", "27669", "2767", "2768", "2769", "9951" ] + +"Essential hypertension": + use_in_benchmark: True + type: "chronic" + id: 98 + codes: [ "4011", "4019" ] + +"Hypertension with complications and secondary hypertension": + use_in_benchmark: True + type: "chronic" + id: 99 + codes: [ "4010", "40200", "40201", "40210", "40211", "40290", "40291", "4030", "40300", "40301", "4031", "40310", "40311", "4039", "40390", "40391", "4040", "40400", "40401", "40402", "40403", "4041", "40410", "40411", "40412", "40413", "4049", "40490", "40491", "40492", "40493", "40501", "40509", "40511", "40519", "40591", "40599", "4372" ] + +"Acute myocardial infarction": + use_in_benchmark: True + type: "acute" + id: 100 + codes: [ "4100", "41000", "41001", "41002", "4101", "41010", "41011", "41012", "4102", "41020", "41021", "41022", "4103", "41030", "41031", "41032", "4104", "41040", "41041", "41042", "4105", "41050", "41051", "41052", "4106", "41060", "41061", "41062", "4107", "41070", "41071", "41072", "4108", "41080", "41081", "41082", "4109", "41090", "41091", "41092" ] + +"Coronary atherosclerosis and other heart disease": + use_in_benchmark: True + type: 
"chronic" + id: 101 + codes: [ "4110", "4111", "4118", "41181", "41189", "412", "4130", "4131", "4139", "4140", "41400", "41401", "41406", "4142", "4143", "4144", "4148", "4149", "V4581", "V4582" ] + +"Conduction disorders": + use_in_benchmark: True + type: "chronic" + id: 105 + codes: [ "4260", "42610", "42611", "42612", "42613", "4262", "4263", "4264", "42650", "42651", "42652", "42653", "42654", "4266", "4267", "42681", "42682", "42689", "4269", "V450", "V4500", "V4501", "V4502", "V4509", "V533", "V5331", "V5332", "V5339" ] + +"Cardiac dysrhythmias": + use_in_benchmark: True + type: "chronic" + id: 106 + codes: [ "4270", "4271", "4272", "42731", "42732", "42760", "42761", "42769", "42781", "42789", "4279", "7850", "7851" ] + +"Congestive heart failure; nonhypertensive": + use_in_benchmark: True + type: "acute" + id: 108 + codes: [ "39891", "4280", "4281", "42820", "42821", "42822", "42823", "42830", "42831", "42832", "42833", "42840", "42841", "42842", "42843", "4289" ] + +"Acute cerebrovascular disease": + use_in_benchmark: True + type: "acute" + id: 109 + codes: [ "34660", "34661", "34662", "34663", "430", "431", "4320", "4321", "4329", "43301", "43311", "43321", "43331", "43381", "43391", "4340", "43400", "43401", "4341", "43410", "43411", "4349", "43490", "43491", "436" ] + +"Pneumonia (except that caused by tuberculosis or sexually transmitted disease)": + use_in_benchmark: True + type: "acute" + id: 122 + codes: [ "00322", "0203", "0204", "0205", "0212", "0221", "0310", "0391", "0521", "0551", "0730", "0830", "1124", "1140", "1144", "1145", "11505", "11515", "11595", "1304", "1363", "4800", "4801", "4802", "4803", "4808", "4809", "481", "4820", "4821", "4822", "4823", "48230", "48231", "48232", "48239", "4824", "48240", "48241", "48242", "48249", "4828", "48281", "48282", "48283", "48284", "48289", "4829", "483", "4830", "4831", "4838", "4841", "4843", "4845", "4846", "4847", "4848", "485", "486", "5130", "5171" ] + +"Chronic obstructive pulmonary disease 
and bronchiectasis": + use_in_benchmark: True + type: "chronic" + id: 127 + codes: [ "490", "4910", "4911", "4912", "49120", "49121", "49122", "4918", "4919", "4920", "4928", "494", "4940", "4941", "496" ] + +"Pleurisy; pneumothorax; pulmonary collapse": + use_in_benchmark: True + type: "acute" + id: 130 + codes: [ "5100", "5109", "5110", "5111", "5118", "51189", "5119", "5120", "5128", "51281", "51282", "51283", "51284", "51289", "5180", "5181", "5182" ] + +"Respiratory failure; insufficiency; arrest (adult)": + use_in_benchmark: True + type: "acute" + id: 131 + codes: [ "5173", "5185", "51851", "51852", "51853", "51881", "51882", "51883", "51884", "7991", "V461", "V4611", "V4612", "V4613", "V4614", "V462" ] + +"Other lower respiratory disease": + use_in_benchmark: True + type: "acute" + id: 133 + codes: [ "5131", "514", "515", "5160", "5161", "5162", "5163", "51630", "51631", "51632", "51633", "51634", "51635", "51636", "51637", "5164", "5165", "51661", "51662", "51663", "51664", "51669", "5168", "5169", "5172", "5178", "5183", "5184", "51889", "5194", "5198", "5199", "7825", "78600", "78601", "78602", "78603", "78604", "78605", "78606", "78607", "78609", "7862", "7863", "78630", "78631", "78639", "7864", "78652", "7866", "7867", "7868", "7869", "7931", "79311", "79319", "7942", "V126", "V1260", "V1261", "V1269", "V426" ] + +"Other upper respiratory disease": + use_in_benchmark: True + type: "acute" + id: 134 + codes: [ "470", "4710", "4711", "4718", "4719", "4720", "4721", "4722", "4760", "4761", "4770", "4772", "4778", "4779", "4780", "4781", "47811", "47819", "47820", "47821", "47822", "47824", "47825", "47826", "47829", "47830", "47831", "47832", "47833", "47834", "4784", "4785", "4786", "47870", "47871", "47874", "47875", "47879", "4788", "4789", "5191", "51911", "51919", "5192", "5193", "7841", "78440", "78441", "78442", "78443", "78444", "78449", "7847", "7848", "7849", "78499", "7861", "V414", "V440", "V550" ] + +"Other liver diseases": + 
use_in_benchmark: True + type: "acute" + id: 151 + codes: [ "570", "5715", "5716", "5718", "5719", "5720", "5721", "5722", "5723", "5724", "5728", "5730", "5734", "5735", "5738", "5739", "7824", "7891", "7895", "78959", "7904", "7905", "7948", "V427" ] + +"Gastrointestinal hemorrhage": + use_in_benchmark: True + type: "acute" + id: 153 + codes: [ "4560", "45620", "5307", "53082", "53100", "53101", "53120", "53121", "53140", "53141", "53160", "53161", "53200", "53201", "53220", "53221", "53240", "53241", "53260", "53261", "53300", "53301", "53320", "53321", "53340", "53341", "53360", "53361", "53400", "53401", "53420", "53421", "53440", "53441", "53460", "53461", "5693", "5780", "5781", "5789" ] + +"Acute and unspecified renal failure": + use_in_benchmark: True + type: "acute" + id: 157 + codes: [ "5845", "5846", "5847", "5848", "5849", "586" ] + +"Chronic kidney disease": + use_in_benchmark: True + type: "chronic" + id: 158 + codes: [ "585", "5851", "5852", "5853", "5854", "5855", "5856", "5859", "7925", "V420", "V451", "V4511", "V4512", "V560", "V561", "V562", "V5631", "V5632", "V568" ] + +"Complications of surgical procedures or medical care": + use_in_benchmark: True + type: "acute" + id: 238 + codes: [ "27661", "27783", "27788", "2853", "28741", "3490", "3491", "34931", "41511", "4294", "4582", "45821", "45829", "5121", "5122", "5187", "5190", "51900", "51901", "51902", "51909", "53086", "53087", "53640", "53641", "53642", "53649", "53901", "53909", "53981", "53989", "5642", "5643", "5644", "5696", "56962", "56971", "56979", "5793", "59681", "78062", "78063", "78066", "9093", "99524", "9954", "99586", "9970", "99700", "99701", "99702", "99709", "9971", "9972", "9973", "99731", "99732", "99739", "9974", "99741", "99749", "9975", "99760", "99761", "99762", "99769", "99771", "99772", "99779", "9979", "99791", "99799", "9980", "99800", "99801", "99802", "99809", "9981", "99811", "99812", "99813", "9982", "9983", "99830", "99831", "99832", "99833", "9984", "9985", 
class HALO_MIMIC3Dataset:
    """Preprocess MIMIC-III admissions/diagnoses into the pickles HALO consumes.

    Constructing the object immediately runs :meth:`build_dataset`, which reads
    the ``ADMISSIONS`` and ``DIAGNOSES_ICD`` tables, groups every patient's
    admissions (ordered by ``ADMITTIME``) into a list of visits of ICD-9 codes,
    attaches a multi-hot label vector derived from the bundled HCUP CCS
    benchmark definitions, and writes the vocabulary and the
    train/validation/test splits as ``.pkl`` files.

    Attributes:
        mimic3_dir (str): Directory that contains the MIMIC-III CSV tables.
        pkl_data_dir (str): Prefix prepended *verbatim* to every output file
            name (``f"{pkl_data_dir}trainDataset.pkl"``), so it should end
            with a path separator, e.g. ``"./data/"``. The prefix form is kept
            on purpose so readers that build the same paths stay in sync.
        gzip (bool): Read ``.csv.gz`` tables when True, ``.csv`` otherwise.
    """

    def __init__(
        self,
        mimic3_dir: str = "./",
        pkl_data_dir: str = "./",
        gzip: bool = False
    ) -> None:
        self.gzip = gzip
        self.mimic3_dir = mimic3_dir
        self.pkl_data_dir = pkl_data_dir
        self.build_dataset()

    @staticmethod
    def _map_codes_to_groups(definitions: dict) -> dict:
        """Return ``{icd9_code: group_name}`` for groups flagged ``use_in_benchmark``.

        Raises:
            ValueError: If one code is listed under two different benchmark
                groups (the label assignment would be ambiguous).
        """
        code_to_group = {}
        for group, spec in definitions.items():
            if not spec['use_in_benchmark']:
                continue
            for code in spec['codes']:
                owner = code_to_group.setdefault(code, group)
                if owner != group:
                    raise ValueError(
                        f"ICD-9 code {code!r} is assigned to both {owner!r} and {group!r}"
                    )
        return code_to_group

    @staticmethod
    def _encode_patients(data: list, code_to_index: dict, code_to_group: dict, group_to_id: dict) -> None:
        """Add ``'labels'`` and replace code strings by vocabulary indices, in place.

        ``labels`` is a float vector with one slot per benchmark group, set to 1
        when any visit of the patient contains a code of that group. Each visit
        becomes a de-duplicated list of integer code indices.
        """
        for patient in data:
            label = np.zeros(len(group_to_id))
            for visit in patient['visits']:
                for code in visit:
                    if code in code_to_group:
                        label[group_to_id[code_to_group[code]]] = 1
            patient['labels'] = label
            patient['visits'] = [
                list({code_to_index[code] for code in visit})
                for visit in patient['visits']
            ]

    def _save_pickle(self, obj, file_name: str) -> None:
        """Pickle ``obj`` to ``pkl_data_dir + file_name`` and close the file."""
        with open(f"{self.pkl_data_dir}{file_name}", "wb") as handle:
            pickle.dump(obj, handle)

    def build_dataset(self) -> None:
        """Read the raw tables, build vocab/labels/splits, and write the pickles.

        Output files (all under the ``pkl_data_dir`` prefix): ``codeToIndex.pkl``,
        ``indexToCode.pkl``, ``idToLabel.pkl``, ``trainDataset.pkl``,
        ``valDataset.pkl`` and ``testDataset.pkl``.

        Raises:
            ValueError: If the benchmark definitions map one code to two groups.
        """
        import os  # local import: this module's import block does not include os

        suffix = ".csv.gz" if self.gzip else ".csv"
        # os.path.join keeps working when mimic3_dir has no trailing separator.
        admission_file = os.path.join(self.mimic3_dir, f"ADMISSIONS{suffix}")
        diagnosis_file = os.path.join(self.mimic3_dir, f"DIAGNOSES_ICD{suffix}")

        admission_df = pd.read_csv(admission_file, dtype=str)
        admission_df['ADMITTIME'] = pd.to_datetime(admission_df['ADMITTIME'])
        # Chronological order so each patient's visits are appended in time order.
        admission_df = admission_df.sort_values('ADMITTIME').reset_index(drop=True)
        diagnosis_df = pd.read_csv(diagnosis_file, dtype=str).set_index("HADM_ID")
        diagnosis_df = diagnosis_df[diagnosis_df['ICD9_CODE'].notnull()]
        diagnosis_df = diagnosis_df[['ICD9_CODE']]

        # Group once instead of scanning the (non-unique) index per admission.
        codes_by_admission = {
            hadm_id: list(set(codes))
            for hadm_id, codes in diagnosis_df.groupby(level=0)['ICD9_CODE']
        }

        patients = {}
        for row in tqdm(admission_df.itertuples(), total=admission_df.shape[0]):
            diagnoses = codes_by_admission.get(row.HADM_ID, [])
            patients.setdefault(row.SUBJECT_ID, {'visits': []})['visits'].append(diagnoses)

        # Vocabulary: every distinct code gets an index in shuffled order.
        all_codes = list({c for p in patients.values() for v in p['visits'] for c in v})
        np.random.shuffle(all_codes)
        code_to_index = {code: idx for idx, code in enumerate(all_codes)}
        print(f"VOCAB SIZE: {len(code_to_index)}")
        index_to_code = {idx: code for code, idx in code_to_index.items()}

        data = list(patients.values())

        # Resolve the bundled config next to this module, not relative to the
        # caller's working directory.
        config_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "configs",
            "hcup_ccs_2015_definitions_benchmark.yaml",
        )
        with open(config_path) as definitions_file:
            # The file holds only plain scalars and lists; safe_load is enough.
            definitions = yaml.safe_load(definitions_file)

        code_to_group = self._map_codes_to_groups(definitions)
        id_to_group = sorted(k for k in definitions if definitions[k]['use_in_benchmark'])
        group_to_id = {group: idx for idx, group in enumerate(id_to_group)}

        self._encode_patients(data, code_to_index, code_to_group, group_to_id)

        print(f"MAX LEN: {max([len(p['visits']) for p in data])}")
        print(f"AVG LEN: {np.mean([len(p['visits']) for p in data])}")
        print(f"MAX VISIT LEN: {max([len(v) for p in data for v in p['visits']])}")
        print(f"AVG VISIT LEN: {np.mean([len(v) for p in data for v in p['visits']])}")
        print(f"NUM RECORDS: {len(data)}")
        print(f"NUM LONGITUDINAL RECORDS: {len([p for p in data if len(p['visits']) > 1])}")

        # Train-Val-Test Split (80/20, then 10% of train held out for validation)
        train_dataset, test_dataset = train_test_split(data, test_size=0.2, random_state=4, shuffle=True)
        train_dataset, val_dataset = train_test_split(train_dataset, test_size=0.1, random_state=4, shuffle=True)

        # Save Everything
        print("Saving Everything")
        print(len(index_to_code))
        self._save_pickle(code_to_index, "codeToIndex.pkl")
        self._save_pickle(index_to_code, "indexToCode.pkl")
        self._save_pickle(id_to_group, "idToLabel.pkl")
        self._save_pickle(train_dataset, "trainDataset.pkl")
        self._save_pickle(val_dataset, "valDataset.pkl")
        self._save_pickle(test_dataset, "testDataset.pkl")
class HALO:
    """Train, evaluate and sample from a HALO model on a pickled EHR dataset.

    The constructor seeds the RNGs, picks the device, and loads the pickles
    written under ``dataset.pkl_data_dir`` (train/val/test splits plus the
    index-to-code and id-to-label tables). Test visits are filtered so they
    only contain codes that occur in the training split.

    All directory arguments (``dataset.pkl_data_dir``, ``save_dir``,
    ``testing_results_dir``, ``pkl_save_dir``) are used as *prefixes* that are
    concatenated with the file name, so they should end with a path separator.

    Args:
        dataset: Object exposing ``pkl_data_dir`` (the pickle prefix).
        config: Hyper-parameters (vocabulary sizes, ``n_ctx``, ``batch_size``,
            ``sample_batch_size``, ``epoch``, ``lr``, ``pos_loss_weight``).
        save_dir: Prefix for the ``halo_model`` checkpoint file.
        train_on_init: Run :meth:`train` at the end of construction.
    """

    def __init__(
        self,
        dataset: "HALO_MIMIC3Dataset",
        config: "HALOConfig",
        save_dir: str = "./save/",
        train_on_init: bool = True
    ) -> None:
        SEED = 4
        random.seed(SEED)
        np.random.seed(SEED)
        torch.manual_seed(SEED)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(SEED)

        self.config = config
        self.dataset = dataset
        self.save_dir = save_dir

        # Single-process only: the former distributed branch was unreachable.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.n_gpu = torch.cuda.device_count()

        # Created by train(), or lazily by test()/synthesize_dataset().
        self.model = None
        self.optimizer = None

        self.train_ehr_dataset = self._load_pickle('trainDataset.pkl')
        self.val_ehr_dataset = self._load_pickle('valDataset.pkl')
        self.index_to_code = self._load_pickle('indexToCode.pkl')
        self.id_to_label = self._load_pickle('idToLabel.pkl')
        test_ehr_dataset = self._load_pickle('testDataset.pkl')

        # Drop test codes never seen in training so the model is not scored on them.
        train_c = {c for p in self.train_ehr_dataset for v in p['visits'] for c in v}
        self.test_ehr_dataset = [
            {'labels': p['labels'], 'visits': [[c for c in v if c in train_c] for v in p['visits']]}
            for p in test_ehr_dataset
        ]

        if train_on_init:
            self.train()

    # ------------------------------------------------------------------ helpers

    def _load_pickle(self, file_name: str):
        """Unpickle ``dataset.pkl_data_dir + file_name`` and close the file.

        NOTE: pickle executes arbitrary code on load; only point this at files
        produced by the local preprocessing step.
        """
        with open(f"{self.dataset.pkl_data_dir}{file_name}", 'rb') as handle:
            return pickle.load(handle)

    @staticmethod
    def _ensure_parent_dir(path: str) -> None:
        """Create the directory part of ``path`` if there is one and it is missing."""
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    @staticmethod
    def _dump_pickle(obj, path: str) -> None:
        """Pickle ``obj`` to ``path``, creating the parent directory first."""
        HALO._ensure_parent_dir(path)
        with open(path, 'wb') as handle:
            pickle.dump(obj, handle)

    def _checkpoint_path(self) -> str:
        """Return the checkpoint file path (``save_dir`` used as a prefix)."""
        return f"{self.save_dir}halo_model"

    def _init_model(self) -> None:
        """(Re)build model and optimizer, resuming from the checkpoint if present."""
        self.model = HALOModel(self.config).to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config.lr)
        if os.path.exists(self._checkpoint_path()):
            checkpoint = torch.load(self._checkpoint_path(), map_location=torch.device(self.device))
            self.model.load_state_dict(checkpoint['model'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])

    def _get_batch(self, loc, batch_size, mode):
        """Build the multi-hot input and loss mask for one slice of a split.

        EHR data is a list of patients ``{'visits': [[code, ...], ...],
        'labels': binary vector}``. Returns ``(batch_ehr, batch_mask)`` as
        numpy arrays of shape ``(B, n_ctx, total_vocab_size)`` and
        ``(B, n_ctx - 1, 1)``. Layout along the ``n_ctx`` axis: position 0 is
        the start token, position 1 the label vector, positions 2.. the visits.

        ``mode`` is ``'train'``, ``'valid'``, or anything else for the test split.
        """
        if mode == 'train':
            ehr = self.train_ehr_dataset[loc:loc+batch_size]
        elif mode == 'valid':
            ehr = self.val_ehr_dataset[loc:loc+batch_size]
        else:
            ehr = self.test_ehr_dataset[loc:loc+batch_size]

        cfg = self.config
        label_start = cfg.code_vocab_size
        start_token = cfg.code_vocab_size + cfg.label_vocab_size
        end_token = start_token + 1
        pad_token = start_token + 2

        batch_ehr = np.zeros((len(ehr), cfg.n_ctx, cfg.total_vocab_size))
        batch_mask = np.zeros((len(ehr), cfg.n_ctx, 1))
        for i, p in enumerate(ehr):
            visits = p['visits']
            # NOTE(review): a patient with more than n_ctx - 2 visits would
            # index past the context window here — confirm upstream limits.
            for j, v in enumerate(visits):
                batch_ehr[i,j+2][v] = 1
                batch_mask[i,j+2] = 1
            batch_ehr[i,1,label_start:start_token] = np.array(p['labels'])  # Set the patient labels
            batch_ehr[i,len(visits)+1,end_token] = 1  # Set the final visit to have the end token
            batch_ehr[i,len(visits)+2:,pad_token] = 1  # Set the rest to the padded visit token

        batch_mask[:,1] = 1  # Set the mask to cover the labels
        batch_ehr[:,0,start_token] = 1  # Set the first visits to be the start token
        batch_mask = batch_mask[:,1:,:]  # Shift the mask to match the shifted labels and predictions the model will return
        return batch_ehr, batch_mask

    def _to_device(self, array):
        """Convert a numpy batch to a float32 tensor on ``self.device``."""
        return torch.tensor(array, dtype=torch.float32).to(self.device)

    def _validation_loss(self):
        """Return the mean model loss over the whole validation split."""
        self.model.eval()
        with torch.no_grad():
            val_l = []
            for v_i in range(0, len(self.val_ehr_dataset), self.config.batch_size):
                batch_ehr, batch_mask = self._get_batch(v_i, self.config.batch_size, 'valid')
                batch_ehr = self._to_device(batch_ehr)
                batch_mask = self._to_device(batch_mask)
                val_loss, _, _ = self.model(batch_ehr, position_ids=None, ehr_labels=batch_ehr, ehr_masks=batch_mask, pos_loss_weight=self.config.pos_loss_weight)
                val_l.append((val_loss).cpu().detach().numpy())
        return np.mean(val_l)

    @staticmethod
    def _conf_mat(x, y):
        """Return ``[[tn, fp], [fn, tp]]`` for boolean truth ``x`` and prediction ``y``."""
        totaltrue = np.sum(x)
        totalfalse = len(x) - totaltrue
        truepos, totalpos = np.sum(x & y), np.sum(y)
        falsepos = totalpos - truepos
        return np.array([[totalfalse - falsepos, falsepos],  # true negatives, false positives
                         [totaltrue - truepos, truepos]])  # false negatives, true positives

    def _sample_sequence(self, model, length, context, batch_size, device='cuda', sample=True):
        """Autoregressively sample ``batch_size`` records starting from ``context``.

        Stops early once every record in the batch contains the end token.
        Returns the sampled records as a numpy array.
        """
        end_token = self.config.code_vocab_size + self.config.label_vocab_size + 1
        empty = torch.zeros((1,1,self.config.total_vocab_size), device=device, dtype=torch.float32).repeat(batch_size, 1, 1)
        context = torch.tensor(context, device=device, dtype=torch.float32).unsqueeze(0).repeat(batch_size, 1)
        prev = context.unsqueeze(1)
        with torch.no_grad():
            for _ in range(length-1):
                prev = model.sample(torch.cat((prev,empty), dim=1), sample)
                if torch.sum(torch.sum(prev[:,:,end_token], dim=1).bool().int(), dim=0).item() == batch_size:
                    break
        return prev.cpu().detach().numpy()

    # --------------------------------------------------------------- public API

    def train(self) -> None:
        """Train the model, checkpointing whenever validation loss improves.

        A fresh model/optimizer is built on every call and resumed from the
        checkpoint under ``save_dir`` if one exists. Validation runs every
        500 batches within an epoch (never at batch 0).
        """
        self._init_model()

        global_loss = 1e10
        for e in tqdm(range(self.config.epoch)):
            np.random.shuffle(self.train_ehr_dataset)
            for i in range(0, len(self.train_ehr_dataset), self.config.batch_size):
                self.model.train()

                batch_ehr, batch_mask = self._get_batch(i, self.config.batch_size, 'train')
                batch_ehr = self._to_device(batch_ehr)
                batch_mask = self._to_device(batch_mask)

                self.optimizer.zero_grad()
                loss, _, _ = self.model(batch_ehr, position_ids=None, ehr_labels=batch_ehr, ehr_masks=batch_mask, pos_loss_weight=self.config.pos_loss_weight)
                loss.backward()
                self.optimizer.step()

                if i % (500*self.config.batch_size) != 0:
                    continue
                # NOTE(review): the "* 8" scaling of the logged loss is kept as
                # found; its origin is not visible here — confirm or drop.
                print("Epoch %d, Iter %d: Training Loss:%.6f"%(e, i, loss * 8))
                if i == 0:
                    continue

                # NOTE(review): with fewer than 501 batches per epoch this point
                # is never reached, so no checkpoint is written — confirm intent.
                cur_val_loss = self._validation_loss()
                print("Epoch %d Validation Loss:%.7f"%(e, cur_val_loss))
                if cur_val_loss < global_loss:
                    global_loss = cur_val_loss
                    state = {
                        'model': self.model.state_dict(),
                        'optimizer': self.optimizer.state_dict(),
                        'iteration': i
                    }
                    # Make sure the checkpoint directory exists before saving.
                    self._ensure_parent_dir(self._checkpoint_path())
                    torch.save(state, self._checkpoint_path())
                    print('\n------------ Save best model ------------\n')

    def test(self, testing_results_dir: str = "./results/testing_stats/") -> None:
        """Score the checkpointed model on the test split and pickle the metrics.

        Writes ``HALO_intermediate_results.pkl`` after every batch and
        ``HALO_Metrics.pkl`` at the end, both under ``testing_results_dir``
        (used as a prefix). Works whether or not :meth:`train` ran first.

        Raises:
            FileNotFoundError: If no checkpoint exists under ``save_dir``.
        """
        if self.model is None:
            self._init_model()
        checkpoint = torch.load(self._checkpoint_path(), map_location=torch.device(self.device))
        self.model.load_state_dict(checkpoint['model'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])

        confusion_matrix = None
        probability_list = []
        loss_list = []
        n_visits = 0
        n_pos_codes = 0
        n_total_codes = 0
        test_batch_size = 2*self.config.batch_size
        self.model.eval()
        with torch.no_grad():
            for v_i in tqdm(range(0, len(self.test_ehr_dataset), test_batch_size)):
                # Get batch inputs
                batch_ehr, batch_mask = self._get_batch(v_i, test_batch_size, 'test')
                batch_ehr = self._to_device(batch_ehr)
                batch_mask = self._to_device(batch_mask)

                # Get batch outputs
                test_loss, predictions, labels = self.model(batch_ehr, position_ids=None, ehr_labels=batch_ehr, ehr_masks=batch_mask, pos_loss_weight=self.config.pos_loss_weight)
                # NOTE(review): squeeze() also removes the batch axis when the
                # final batch holds a single patient, which would break the
                # transposes below — confirm the model's output shape and fix.
                batch_mask_array = batch_mask.squeeze().cpu().detach().numpy()
                rounded_preds = np.around(predictions.squeeze().cpu().detach().numpy()).transpose((2,0,1))
                rounded_preds = rounded_preds + batch_mask_array - 1  # Setting the masked visits to be -1 to be ignored by the confusion matrix
                rounded_preds = rounded_preds.flatten()
                true_values = labels.squeeze().cpu().detach().numpy().transpose((2,0,1))
                true_values = true_values + batch_mask_array - 1  # Setting the masked visits to be -1 to be ignored by the confusion matrix
                true_values = true_values.flatten()

                # Append test loss
                loss_list.append(test_loss.cpu().detach().numpy())

                # Add number of visits and codes (multiply in Python floats so
                # large counts are not rounded by float32 tensor arithmetic)
                batch_visits = torch.sum(batch_mask).cpu().item()
                batch_codes = batch_visits * self.config.total_vocab_size
                n_visits += batch_visits
                n_pos_codes += torch.sum(labels).cpu().item()
                n_total_codes += batch_codes

                # Add confusion matrix
                batch_cmatrix = self._conf_mat(true_values == 1, rounded_preds == 1)
                batch_cmatrix[0][0] = batch_codes - batch_cmatrix[0][1] - batch_cmatrix[1][0] - batch_cmatrix[1][1]  # Remove the masked values
                confusion_matrix = batch_cmatrix if confusion_matrix is None else confusion_matrix + batch_cmatrix

                # Calculate and add probabilities
                # Note that the masked codes will have probability 1 and be ignored
                label_probs = torch.abs(labels - 1.0 + predictions)
                log_prob = torch.sum(torch.log(label_probs)).cpu().item()
                probability_list.append(log_prob)

                # Save intermediate values in case of error
                intermediate = {}
                intermediate["Losses"] = loss_list
                intermediate["Confusion Matrix"] = confusion_matrix
                intermediate["Probabilities"] = probability_list
                intermediate["Num Visits"] = n_visits
                intermediate["Num Positive Codes"] = n_pos_codes
                intermediate["Num Total Codes"] = n_total_codes
                self._dump_pickle(intermediate, f"{testing_results_dir}HALO_intermediate_results.pkl")

        # Extract, save, and display test metrics
        avg_loss = np.nanmean(loss_list)
        tn, fp, fn, tp = confusion_matrix.ravel()
        acc = (tn + tp)/(tn+fp+fn+tp)
        prc = tp/(tp+fp)
        rec = tp/(tp+fn)
        f1 = (2 * prc * rec)/(prc + rec)
        log_probability = np.sum(probability_list)
        pp_visit = np.exp(-log_probability/n_visits)
        pp_positive = np.exp(-log_probability/n_pos_codes)
        pp_possible = np.exp(-log_probability/n_total_codes)

        metrics_dict = {}
        metrics_dict['Test Loss'] = avg_loss
        metrics_dict['Confusion Matrix'] = confusion_matrix
        metrics_dict['Accuracy'] = acc
        metrics_dict['Precision'] = prc
        metrics_dict['Recall'] = rec
        metrics_dict['F1 Score'] = f1
        metrics_dict['Test Log Probability'] = log_probability
        metrics_dict['Perplexity Per Visit'] = pp_visit
        metrics_dict['Perplexity Per Positive Code'] = pp_positive
        metrics_dict['Perplexity Per Possible Code'] = pp_possible
        self._dump_pickle(metrics_dict, f"{testing_results_dir}HALO_Metrics.pkl")

    def convert_ehr(self, ehrs, index_to_code=None):
        """Decode multi-hot records back into ``{'visits', 'labels'}`` dicts.

        Args:
            ehrs: Sequence of 2-D arrays laid out like the output of the batch
                builder (row 1 = labels, rows 2.. = visits).
            index_to_code: Optional mapping from vocabulary index to a readable
                code. When given, visit codes are looked up by index and label
                names by ``index + code_vocab_size``; when omitted, visits hold
                raw indices and ``labels`` is the raw label slice.

        Returns:
            One dict per record. Empty visits are skipped and decoding stops
            at the visit that carries the end token.
        """
        code_end = self.config.code_vocab_size
        label_end = code_end + self.config.label_vocab_size
        end_token = label_end + 1

        ehr_outputs = []
        for ehr in ehrs:
            ehr_output = []
            labels_output = ehr[1][code_end:label_end]
            if index_to_code is not None:
                labels_output = [index_to_code[idx + code_end] for idx in np.nonzero(labels_output)[0]]
            for j in range(2, len(ehr)):
                visit_output = []
                end = False
                for idx in np.nonzero(ehr[j])[0]:
                    if idx < code_end:
                        visit_output.append(index_to_code[idx] if index_to_code is not None else idx)
                    elif idx == end_token:
                        end = True
                if visit_output != []:
                    ehr_output.append(visit_output)
                if end:
                    break
            ehr_outputs.append({'visits': ehr_output, 'labels': labels_output})
        return ehr_outputs

    def synthesize_dataset(self, pkl_save_dir: str = "./results/datasets/") -> None:
        """Sample as many synthetic records as there are training patients.

        The decoded records (raw indices, no code names) are pickled to
        ``pkl_save_dir + "haloDataset.pkl"``. If no model is in memory yet, one
        is built and resumed from the checkpoint when available.
        """
        if self.model is None:
            self._init_model()

        synthetic_ehr_dataset = []
        stoken = np.zeros(self.config.total_vocab_size)
        stoken[self.config.code_vocab_size+self.config.label_vocab_size] = 1
        for i in tqdm(range(0, len(self.train_ehr_dataset), self.config.sample_batch_size)):
            bs = min([len(self.train_ehr_dataset)-i, self.config.sample_batch_size])
            batch_synthetic_ehrs = self._sample_sequence(self.model, self.config.n_ctx, stoken, batch_size=bs, device=self.device, sample=True)
            batch_synthetic_ehrs = self.convert_ehr(batch_synthetic_ehrs)
            synthetic_ehr_dataset += batch_synthetic_ehrs

        self._dump_pickle(synthetic_ehr_dataset, f"{pkl_save_dir}haloDataset.pkl")
initializer_range + self.batch_size = batch_size + self.sample_batch_size = sample_batch_size + self.epoch = epoch + self.pos_loss_weight = pos_loss_weight + self.lr = lr \ No newline at end of file diff --git a/pyhealth/models/generators/halo_resources/halo_model.py b/pyhealth/models/generators/halo_resources/halo_model.py new file mode 100644 index 000000000..a25a9fe04 --- /dev/null +++ b/pyhealth/models/generators/halo_resources/halo_model.py @@ -0,0 +1,237 @@ +''' + code by Brandon Theodorou + Original GPT-2 Paper and repository here: https://github.com/openai/gpt-2 + Original GPT-2 Pytorch Model: https://github.com/huggingface/pytorch-pretrained-BERT + GPT-2 Pytorch Model Derived From: https://github.com/graykode/gpt-2-Pytorch +''' +import copy +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +def gelu(x): + return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + +class LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-12): + """Construct a layernorm module in the TF style (epsilon inside the square root).""" + super(LayerNorm, self).__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.bias = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + + def forward(self, x): + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + return self.weight * x + self.bias + +class Conv1D(nn.Module): + def __init__(self, nf, nx): + super(Conv1D, self).__init__() + self.nf = nf + w = torch.empty(nx, nf) + nn.init.normal_(w, std=0.02) + self.weight = nn.Parameter(w) + self.bias = nn.Parameter(torch.zeros(nf)) + + def forward(self, x): + size_out = x.size()[:-1] + (self.nf,) + x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) + x = x.view(*size_out) + return x + +class Attention(nn.Module): + def __init__(self, nx, n_ctx, config, scale=False): + super(Attention, 
self).__init__() + n_state = nx # in Attention: n_state=768 (nx=n_embd) + assert n_state % config.n_head == 0 + self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)) + self.n_head = config.n_head + self.split_size = n_state + self.scale = scale + self.c_attn = Conv1D(n_state * 3, nx) + self.c_proj = Conv1D(n_state, nx) + + def _attn(self, q, k, v): + w = torch.matmul(q, k) + if self.scale: + w = w / math.sqrt(v.size(-1)) + nd, ns = w.size(-2), w.size(-1) + b = self.bias[:, :, ns-nd:ns, :ns] + w = w * b - 1e10 * (1 - b) + w = nn.Softmax(dim=-1)(w) + return torch.matmul(w, v) + + def merge_heads(self, x): + x = x.permute(0, 2, 1, 3).contiguous() + new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),) + return x.view(*new_x_shape) + + def split_heads(self, x, k=False): + new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head) + x = x.view(*new_x_shape) + if k: + return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length) + else: + return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + + def forward(self, x, layer_past=None): + x = self.c_attn(x) + query, key, value = x.split(self.split_size, dim=2) + query = self.split_heads(query) + key = self.split_heads(key, k=True) + value = self.split_heads(value) + if layer_past is not None: + past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below + key = torch.cat((past_key, key), dim=-1) + value = torch.cat((past_value, value), dim=-2) + present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking + a = self._attn(query, key, value) + a = self.merge_heads(a) + a = self.c_proj(a) + return a, present + +class MLP(nn.Module): + def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd) + super(MLP, self).__init__() + nx = config.n_embd + self.c_fc = Conv1D(n_state, nx) + self.c_proj = Conv1D(nx, n_state) + self.act = gelu + + def forward(self, x): + h = 
self.act(self.c_fc(x)) + h2 = self.c_proj(h) + return h2 + +class Block(nn.Module): + def __init__(self, n_ctx, config, scale=False): + super(Block, self).__init__() + nx = config.n_embd + self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon) + self.attn = Attention(nx, n_ctx, config, scale) + self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon) + self.mlp = MLP(4 * nx, config) + + def forward(self, x, layer_past=None): + a, present = self.attn(self.ln_1(x), layer_past=layer_past) + x = x + a + m = self.mlp(self.ln_2(x)) + x = x + m + return x, present + +class CoarseTransformerModel(nn.Module): + def __init__(self, config): + super(CoarseTransformerModel, self).__init__() + self.n_layer = config.n_layer + self.n_embd = config.n_embd + self.n_vocab = config.total_vocab_size + + self.vis_embed_mat = nn.Linear(config.total_vocab_size, config.n_embd, bias=False) + self.pos_embed_mat = nn.Embedding(config.n_positions, config.n_embd) + block = Block(config.n_ctx, config, scale=True) + self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)]) + self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) + + def forward(self, input_visits, position_ids=None, past=None): + if past is None: + past_length = 0 + past = [None] * len(self.h) + else: + past_length = past[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange(past_length, input_visits.size(1) + past_length, dtype=torch.long, + device=input_visits.device) + position_ids = position_ids.unsqueeze(0).expand(input_visits.size(0), input_visits.size(1)) + + inputs_embeds = self.vis_embed_mat(input_visits) + position_embeds = self.pos_embed_mat(position_ids) + hidden_states = inputs_embeds + position_embeds + for block, layer_past in zip(self.h, past): + hidden_states, _ = block(hidden_states, layer_past) + hidden_states = self.ln_f(hidden_states) + return hidden_states + +class AutoregressiveLinear(nn.Linear): + """ same as Linear except has a configurable mask on the 
weights """ + def __init__(self, in_features, out_features, bias=True): + super().__init__(in_features, out_features, bias) + self.register_buffer('mask', torch.tril(torch.ones(in_features, out_features)).int()) + + def forward(self, input): + return F.linear(input, self.mask * self.weight, self.bias) + +class FineAutoregressiveHead(nn.Module): + def __init__(self, config): + super(FineAutoregressiveHead, self).__init__() + self.auto1 = AutoregressiveLinear(config.n_embd + config.total_vocab_size, config.n_embd + config.total_vocab_size) + self.auto2 = AutoregressiveLinear(config.n_embd + config.total_vocab_size, config.n_embd + config.total_vocab_size) + self.n_embd = config.n_embd + self.tot_vocab = config.total_vocab_size + + def forward(self, history, input_visits): + history = history[:,:-1,:] + input_visits = input_visits[:,1:,:] + code_logits = self.auto2(torch.relu(self.auto1(torch.cat((history, input_visits), dim=2))))[:,:,self.n_embd-1:-1] + return code_logits + + def sample(self, history, input_visits): + history = history[:,:-1,:] + input_visits = input_visits[:,1:,:] + currVisit = torch.cat((history, input_visits), dim=2)[:,-1,:].unsqueeze(1) + code_logits = self.auto2(torch.relu(self.auto1(currVisit)))[:,:,self.n_embd-1:-1] + return code_logits + +class HALOModel(nn.Module): + def __init__(self, config): + super(HALOModel, self).__init__() + self.transformer = CoarseTransformerModel(config) + self.ehr_head = FineAutoregressiveHead(config) + + def forward(self, input_visits, position_ids=None, ehr_labels=None, ehr_masks=None, past=None, pos_loss_weight=None): + hidden_states = self.transformer(input_visits, position_ids, past) + code_logits = self.ehr_head(hidden_states, input_visits) + sig = nn.Sigmoid() + code_probs = sig(code_logits) + if ehr_labels is not None: + shift_labels = ehr_labels[..., 1:, :].contiguous() + loss_weights = None + if pos_loss_weight is not None: + loss_weights = torch.ones(code_probs.shape, device=code_probs.device) + 
loss_weights = loss_weights + (pos_loss_weight-1) * shift_labels + if ehr_masks is not None: + code_probs = code_probs * ehr_masks + shift_labels = shift_labels * ehr_masks + if pos_loss_weight is not None: + loss_weights = loss_weights * ehr_masks + + bce = nn.BCELoss(weight=loss_weights) + loss = bce(code_probs, shift_labels) + return loss, code_probs, shift_labels + + return code_probs + + def sample(self, input_visits, random=True): + sig = nn.Sigmoid() + hidden_states = self.transformer(input_visits) + i = 0 + while i < self.ehr_head.tot_vocab: + next_logits = self.ehr_head.sample(hidden_states, input_visits) + next_probs = sig(next_logits) + if random: + visit = torch.bernoulli(next_probs) + else: + visit = torch.round(next_probs) + + remaining_visit = visit[:,0,i:] + nonzero = torch.nonzero(remaining_visit, as_tuple=True)[1] + if nonzero.numel() == 0: + break + + first_nonzero = nonzero.min() + input_visits[:,-1,i + first_nonzero] = visit[:,0,i + first_nonzero] + i = i + first_nonzero + 1 + + return input_visits \ No newline at end of file diff --git a/pyhealth/models/generators/medgan.py b/pyhealth/models/generators/medgan.py new file mode 100644 index 000000000..e69de29bb diff --git a/pyhealth/models/generators/promptehr.py b/pyhealth/models/generators/promptehr.py new file mode 100644 index 000000000..e69de29bb