diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 048e2275..6c857c4b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -40,4 +40,4 @@ jobs: # - name: Setup package # run: pip install -e . # - name: Test command line tool - # run: python -m icu_benchmarks.run --help \ No newline at end of file + # run: python -m icu_benchmarks.run --help diff --git a/README.md b/README.md index 808a0764..6f163450 100644 --- a/README.md +++ b/README.md @@ -199,6 +199,7 @@ icu-benchmarks \ > For a list of available flags, run `icu-benchmarks train -h`. > Run with `PYTORCH_ENABLE_MPS_FALLBACK=1` on Macs with Metal Performance Shaders. +> Can set conda environment variable by running `conda env config vars set PYTORCH_ENABLE_MPS_FALLBACK=1` [//]: # (> Please note that, for Windows based systems, paths need to be formatted differently, e.g: ` r"\..\data\mortality_seq\hirid"`.) > For Windows based systems, the next line character (\\) needs to be replaced by (^) (Command Prompt) or (`) (Powershell) diff --git a/configs/prediction_models/DeepARpytorch.gin b/configs/prediction_models/DeepARpytorch.gin new file mode 100644 index 00000000..e7dbf54b --- /dev/null +++ b/configs/prediction_models/DeepARpytorch.gin @@ -0,0 +1,143 @@ +# Settings for DeepAR model. 
+ +# Common settings for DL models +include "configs/prediction_models/common/DLCommon.gin" + +# Optimizer params +train_common.model = @DeepARpytorch + +optimizer/hyperparameter.class_to_tune = @Adam +optimizer/hyperparameter.weight_decay = 1e-6 +optimizer/hyperparameter.lr = 0.00046 +#(1e-5, 1e-3) +# Model params + +model/hyperparameter.class_to_tune = @DeepARpytorch +model/hyperparameter.hidden = 116 +#(4, 120, "log-uniform", 2) +model/hyperparameter.rnn_layers=1 + +model/hyperparameter.num_classes = %NUM_CLASSES +model/hyperparameter.cell_type='LSTM' +model/hyperparameter.dropout = (0.286968842146375, 0.29) +#model/hyperparameter.lr_scheduler = "exponential" + +train_common/hyperparameter.class_to_tune = @train_common +train_common/hyperparameter.batch_size=256 +train_common/hyperparameter.gradient_clip_val=0.01 +#(0,0.01, 1.0, 100.0) +# Dataset params +PredictionDatasetpytorch.max_encoder_length = 24 +PredictionDatasetpytorch.max_prediction_length = 1 +PredictionDatasetpytorch.time_varying_known_reals=[] +PredictionDatasetpytorch.add_relative_time_idx=False +PredictionDatasetpytorch.target=[ + "alb", + "alp", + "alt", + "ast", + "be", + "bicar", + "bili", + "bili_dir", + "bnd", + "bun", + "ca", + "cai", + "ck", + "ckmb", + "cl", + "crea", + "crp", + "dbp", + "fgn", + "fio2", + "glu", + "hgb", + "hr", + "inr_pt", + "k", + "lact", + "lymph", + "map", + "mch", + "mchc", + "mcv", + "methb", + "mg", + "na", + "neut", + "o2sat", + "pco2", + "ph", + "phos", + "plt", + "po2", + "ptt", + "resp", + "sbp", + "temp", + "tnt", + "urine", + "wbc", + "label", + + ] + +PredictionDatasetpytorch.time_varying_unknown_reals=[ + "alb", + "alp", + "alt", + "ast", + "be", + "bicar", + "bili", + "bili_dir", + "bnd", + "bun", + "ca", + "cai", + "ck", + "ckmb", + "cl", + "crea", + "crp", + "dbp", + "fgn", + "fio2", + "glu", + "hgb", + "hr", + "inr_pt", + "k", + "lact", + "lymph", + "map", + "mch", + "mchc", + "mcv", + "methb", + "mg", + "na", + "neut", + "o2sat", + "pco2", + "ph", + 
"phos", + "plt", + "po2", + "ptt", + "resp", + "sbp", + "temp", + "tnt", + "urine", + "wbc", + "label", + + ] +PredictionDatasetpytorch.time_varying_unknown_categoricals=[] +PredictionDatasetpytorch.lagged_variables=[] +PredictionDatasetpytorch.targetnormalizer='multi' + + \ No newline at end of file diff --git a/configs/prediction_models/RNNpytorch.gin b/configs/prediction_models/RNNpytorch.gin new file mode 100644 index 00000000..8f8132e0 --- /dev/null +++ b/configs/prediction_models/RNNpytorch.gin @@ -0,0 +1,142 @@ +# Settings for RNN model. + +# Common settings for DL models +include "configs/prediction_models/common/DLCommon.gin" + +# Optimizer params +train_common.model = @RNNpytorch + +optimizer/hyperparameter.class_to_tune = @Adam +optimizer/hyperparameter.weight_decay = 1e-6 +optimizer/hyperparameter.lr = 0.00041 + +# Model params +model/hyperparameter.class_to_tune = @RNNpytorch +model/hyperparameter.hidden = 214 +#(2, 64, "log-uniform", 2) +model/hyperparameter.rnn_layers=1 +#(1,3) +model/hyperparameter.num_classes = %NUM_CLASSES +model/hyperparameter.cell_type='LSTM' +model/hyperparameter.dropout = (0.244, 0.2441) +#(0.0, 0.4) +#model/hyperparameter.lr_scheduler = "exponential" + +train_common/hyperparameter.class_to_tune = @train_common +train_common/hyperparameter.batch_size=256 +#(32,64,128,256,512) +train_common/hyperparameter.gradient_clip_val=100.0 +#(0,0.01, 1.0, 100.0) +# Dataset params +PredictionDatasetpytorch.max_encoder_length = 24 +PredictionDatasetpytorch.max_prediction_length = 1 +PredictionDatasetpytorch.time_varying_known_reals=[ ] +PredictionDatasetpytorch.add_relative_time_idx=False +PredictionDatasetpytorch.target=[ + "alb", + "alp", + "alt", + "ast", + "be", + "bicar", + "bili", + "bili_dir", + "bnd", + "bun", + "ca", + "cai", + "ck", + "ckmb", + "cl", + "crea", + "crp", + "dbp", + "fgn", + "fio2", + "glu", + "hgb", + "hr", + "inr_pt", + "k", + "lact", + "lymph", + "map", + "mch", + "mchc", + "mcv", + "methb", + "mg", + "na", + 
"neut", + "o2sat", + "pco2", + "ph", + "phos", + "plt", + "po2", + "ptt", + "resp", + "sbp", + "temp", + "tnt", + "urine", + "wbc", + "label", + + ] + +PredictionDatasetpytorch.time_varying_unknown_reals=[ + "alb", + "alp", + "alt", + "ast", + "be", + "bicar", + "bili", + "bili_dir", + "bnd", + "bun", + "ca", + "cai", + "ck", + "ckmb", + "cl", + "crea", + "crp", + "dbp", + "fgn", + "fio2", + "glu", + "hgb", + "hr", + "inr_pt", + "k", + "lact", + "lymph", + "map", + "mch", + "mchc", + "mcv", + "methb", + "mg", + "na", + "neut", + "o2sat", + "pco2", + "ph", + "phos", + "plt", + "po2", + "ptt", + "resp", + "sbp", + "temp", + "tnt", + "urine", + "wbc", + "label", + + ] +PredictionDatasetpytorch.time_varying_unknown_categoricals=[] +PredictionDatasetpytorch.lagged_variables=[] +PredictionDatasetpytorch.targetnormalizer='multi' \ No newline at end of file diff --git a/configs/prediction_models/TFT.gin b/configs/prediction_models/TFT.gin new file mode 100644 index 00000000..4a0a4e82 --- /dev/null +++ b/configs/prediction_models/TFT.gin @@ -0,0 +1,166 @@ +# Settings for TFT model. 
+ +# Common settings for DL models +include "configs/prediction_models/common/DLCommon.gin" + +# Optimizer params +train_common.model = @TFT +train_common/hyperparameter.class_to_tune = @train_common + + +optimizer/hyperparameter.class_to_tune = @Adam +optimizer/hyperparameter.weight_decay = 1e-6 +optimizer/hyperparameter.lr = (0.0001, 0.001,0.01) +#(0.0001, 0.001,0.01) +# Encoder params +model/hyperparameter.class_to_tune = @TFT +model/hyperparameter.encoder_length = 24 +model/hyperparameter.hidden = 32 +#(10,20,40,80,160,240,320) +model/hyperparameter.num_classes = %NUM_CLASSES +model/hyperparameter.dropout = (0.37, 0.38) +model/hyperparameter.dropout_att = 0.0589 +model/hyperparameter.n_heads =2 +#(1,2,4) +model/hyperparameter.example_length=25 + +#TFT parameters + +TFT.temporal_target_size=1 +TFT.quantiles=[0.1,0.5,0.9] + +train_common/hyperparameter.batch_size=64 +#(32,64,128,256,512) + +train_common/hyperparameter.gradient_clip_val=0.01 +#(0,0.01, 1.0, 100.0) + + +#Vars types + +TFT.vars_type = { + "alb": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "alp": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "alt": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "ast": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "be": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "bicar": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "bili": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "bili_dir": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "bnd": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "bun": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "ca": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "cai": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "ck": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "ckmb": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "cl": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "crea": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "crp": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "dbp": 
[%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "fgn": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "fio2": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "glu": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "hgb": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "hr": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "inr_pt": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "k": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "lact": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "lymph": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "map": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "mch": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "mchc": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "mcv": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "methb": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "mg": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "na": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "neut": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "o2sat": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "pco2": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "ph": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "phos": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "plt": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "po2": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "ptt": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "resp": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "sbp": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "temp": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "tnt": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "urine": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "wbc": [%DataTypes.CONTINUOUS, %InputTypes.OBSERVED], + "MissingIndicator_1": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_2": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_3": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_4": 
[%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_5": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_6": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_7": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_8": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_9": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_10": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_11": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_12": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_13": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_14": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_15": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_16": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_17": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_18": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_19": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_20": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_21": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_22": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_23": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_24": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_25": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_26": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_27": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_28": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_29": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_30": [%DataTypes.CATEGORICAL, 
%InputTypes.OBSERVED, 2], + "MissingIndicator_31": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_32": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_33": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_34": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_35": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_36": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_37": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_38": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_39": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_40": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_41": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_42": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_43": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_44": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_45": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_46": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_47": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "MissingIndicator_48": [%DataTypes.CATEGORICAL, %InputTypes.OBSERVED, 2], + "age": [%DataTypes.CONTINUOUS, %InputTypes.STATIC], + "sex": [%DataTypes.CATEGORICAL, %InputTypes.STATIC,2], + "height": [%DataTypes.CONTINUOUS, %InputTypes.STATIC], + "weight": [%DataTypes.CONTINUOUS, %InputTypes.STATIC] + } + + +TFT.vars= ['alb', 'alp', 'alt', 'ast', 'be', 'bicar', 'bili', 'bili_dir', 'bnd', + 'bun', 'ca', 'cai', 'ck', 'ckmb', 'cl', 'crea', 'crp', 'dbp', 'fgn', + 'fio2', 'glu', 'hgb', 'hr', 'inr_pt', 'k', 'lact', 'lymph', 'map', + 'mch', 'mchc', 'mcv', 'methb', 'mg', 'na', 'neut', 'o2sat', 'pco2', + 'ph', 'phos', 'plt', 'po2', 'ptt', 'resp', 'sbp', 'temp', 'tnt', + 'urine', 
'wbc', 'MissingIndicator_1', 'MissingIndicator_2', + 'MissingIndicator_3', 'MissingIndicator_4', 'MissingIndicator_5', + 'MissingIndicator_6', 'MissingIndicator_7', 'MissingIndicator_8', + 'MissingIndicator_9', 'MissingIndicator_10', 'MissingIndicator_11', + 'MissingIndicator_12', 'MissingIndicator_13', 'MissingIndicator_14', + 'MissingIndicator_15', 'MissingIndicator_16', 'MissingIndicator_17', + 'MissingIndicator_18', 'MissingIndicator_19', 'MissingIndicator_20', + 'MissingIndicator_21', 'MissingIndicator_22', 'MissingIndicator_23', + 'MissingIndicator_24', 'MissingIndicator_25', 'MissingIndicator_26', + 'MissingIndicator_27', 'MissingIndicator_28', 'MissingIndicator_29', + 'MissingIndicator_30', 'MissingIndicator_31', 'MissingIndicator_32', + 'MissingIndicator_33', 'MissingIndicator_34', 'MissingIndicator_35', + 'MissingIndicator_36', 'MissingIndicator_37', 'MissingIndicator_38', + 'MissingIndicator_39', 'MissingIndicator_40', 'MissingIndicator_41', + 'MissingIndicator_42', 'MissingIndicator_43', 'MissingIndicator_44', + 'MissingIndicator_45', 'MissingIndicator_46', 'MissingIndicator_47', + 'MissingIndicator_48', 'age', 'sex', 'height', 'weight'] diff --git a/configs/prediction_models/TFTpytorch.gin b/configs/prediction_models/TFTpytorch.gin new file mode 100644 index 00000000..37b2dd1a --- /dev/null +++ b/configs/prediction_models/TFTpytorch.gin @@ -0,0 +1,92 @@ +# Settings for TFT model. 
+ +# Common settings for DL models +include "configs/prediction_models/common/DLCommon.gin" + +# Optimizer params +train_common.model = @TFTpytorch + +optimizer/hyperparameter.class_to_tune = @Adam +optimizer/hyperparameter.weight_decay = 1e-6 +optimizer/hyperparameter.lr = 0.001 +#(0.0001, 0.001,0.01) + +# Model params +model/hyperparameter.class_to_tune = @TFTpytorch +model/hyperparameter.hidden = 32 +#(10,20,40,80,160,240,320) +model/hyperparameter.num_classes = %NUM_CLASSES +model/hyperparameter.dropout =(0.37, 0.38) +model/hyperparameter.dropout_att = 0.0589 +#(0.1, 0.9) +model/hyperparameter.n_heads =2 + +#dataloader + +train_common/hyperparameter.class_to_tune = @train_common +train_common/hyperparameter.batch_size=64 +#(32,64,128,256,512) + +train_common/hyperparameter.gradient_clip_val=0.01 +#(0,0.01, 1.0, 100.0) + + + + +# Dataset params +PredictionDatasetpytorch.max_encoder_length = 24 +PredictionDatasetpytorch.max_prediction_length = 1 +PredictionDatasetpytorch.target="label" +PredictionDatasetpytorch.time_varying_known_reals=["time_idx"] +PredictionDatasetpytorch.add_relative_time_idx=False +PredictionDatasetpytorch.time_varying_unknown_categoricals=[] +PredictionDatasetpytorch.time_varying_unknown_reals=["alb", + "alp", + "alt", + "ast", + "be", + "bicar", + "bili", + "bili_dir", + "bnd", + "bun", + "ca", + "cai", + "ck", + "ckmb", + "cl", + "crea", + "crp", + "dbp", + "fgn", + "fio2", + "glu", + "hgb", + "hr", + "inr_pt", + "k", + "lact", + "lymph", + "map", + "mch", + "mchc", + "mcv", + "methb", + "mg", + "na", + "neut", + "o2sat", + "pco2", + "ph", + "phos", + "plt", + "po2", + "ptt", + "resp", + "sbp", + "temp", + "tnt", + "urine", + "wbc",] +PredictionDatasetpytorch.lagged_variables=[] +PredictionDatasetpytorch.targetnormalizer='single' \ No newline at end of file diff --git a/configs/prediction_models/common/DLCommon.gin b/configs/prediction_models/common/DLCommon.gin index c220e6ab..6b76c60d 100644 --- 
a/configs/prediction_models/common/DLCommon.gin +++ b/configs/prediction_models/common/DLCommon.gin @@ -5,6 +5,8 @@ import gin.torch.external_configurables import icu_benchmarks.models.wrappers import icu_benchmarks.models.dl_models import icu_benchmarks.models.utils +import icu_benchmarks.models.layers +import icu_benchmarks.data.loader # Do not generate features from dynamic data base_classification_preprocessor.generate_features = False @@ -13,9 +15,11 @@ base_regression_preprocessor.generate_features = False # Train params train_common.optimizer = @Adam train_common.epochs = 1000 -train_common.batch_size = 64 -train_common.patience = 10 +#train_common.batch_size = 10 +train_common.patience = 20 train_common.min_delta = 1e-4 + + # Hyperparameter tuning settings include "configs/prediction_models/common/DLTuning.gin" \ No newline at end of file diff --git a/configs/prediction_models/common/DLTuning.gin b/configs/prediction_models/common/DLTuning.gin index b4d13e12..04fb8ff6 100644 --- a/configs/prediction_models/common/DLTuning.gin +++ b/configs/prediction_models/common/DLTuning.gin @@ -1,5 +1,5 @@ # Hyperparameter tuner settings for Deep Learning. 
-tune_hyperparameters.scopes = ["model", "optimizer"] +tune_hyperparameters.scopes = ["model", "optimizer","train_common"] tune_hyperparameters.n_initial_points = 5 tune_hyperparameters.n_calls = 30 tune_hyperparameters.folds_to_tune_on = 2 \ No newline at end of file diff --git a/configs/tasks/BinaryClassification.gin b/configs/tasks/BinaryClassification.gin index 492a12eb..f3a7790d 100644 --- a/configs/tasks/BinaryClassification.gin +++ b/configs/tasks/BinaryClassification.gin @@ -16,7 +16,7 @@ train_common.ram_cache = True # DEEP LEARNING DLPredictionWrapper.loss = @cross_entropy - +DLPredictionPytorchForecastingWrapper.loss= @cross_entropy # SELECTING PREPROCESSOR preprocess.preprocessor = @base_classification_preprocessor preprocess.vars = %vars diff --git a/configs/tasks/Regression.gin b/configs/tasks/Regression.gin index 5cf3f8d9..1234f707 100644 --- a/configs/tasks/Regression.gin +++ b/configs/tasks/Regression.gin @@ -16,6 +16,7 @@ train_common.ram_cache = True # LOSS FUNCTION DLPredictionWrapper.loss = @mse_loss +DLPredictionPytorchForecastingWrapper.loss = @mse_loss MLWrapper.loss = @mean_squared_error # SELECTING PREPROCESSOR @@ -25,7 +26,7 @@ preprocess.use_static = True # SPECIFYING REGRESSION OUTCOME SCALING base_regression_preprocessor.outcome_min = 0 -base_regression_preprocessor.outcome_max = 15 +base_regression_preprocessor.outcome_max = 168 # SELECTING DATASET PredictionDataset.vars = %vars diff --git a/configs/tasks/common/Imports.gin b/configs/tasks/common/Imports.gin index ea448476..b3da4a80 100644 --- a/configs/tasks/common/Imports.gin +++ b/configs/tasks/common/Imports.gin @@ -3,4 +3,5 @@ import icu_benchmarks.data.split_process_data import icu_benchmarks.data.loader import icu_benchmarks.models.wrappers import icu_benchmarks.models.dl_models -import icu_benchmarks.models.ml_models \ No newline at end of file +import icu_benchmarks.models.ml_models +import icu_benchmarks.models.metrics \ No newline at end of file diff --git a/environment.yml 
b/environment.yml index 1124300f..d081b46d 100644 --- a/environment.yml +++ b/environment.yml @@ -13,7 +13,7 @@ dependencies: - gin-config=0.5.0 - ignite=0.4.11 - pytorch=2.0.1 - - pytorch-cuda=11.8 + # - pytorch-cuda=11.8 - lightgbm=3.3.5 - numpy=1.24.3 - pandas=2.0.0 @@ -29,10 +29,13 @@ dependencies: - einops=0.6.1 - hydra-core=1.3 - pip: + - quantus==0.5.3 - recipies==0.1.3 # Fixed version because of NumPy incompatibility and stale development status. - scikit-optimize-fix==0.9.1 - hydra-submitit-launcher==1.2.0 + - git+https://github.com/youssefmecky96/pytorchforecasting + - git+https://github.com/youssefmecky96/captum # Note: versioning of Pytorch might be dependent on compatible CUDA version. # Please check yourself if your Pytorch installation supports cuda (for gpu acceleration) diff --git a/experiments/benchmark_regression.yml b/experiments/benchmark_regression.yml index 8aa8d13e..1f9176f3 100644 --- a/experiments/benchmark_regression.yml +++ b/experiments/benchmark_regression.yml @@ -21,10 +21,10 @@ parameters: - ../data/los/hirid - ../data/los/eicu - ../data/los/aumc - - ../data/kf/miiv - - ../data/kf/hirid - - ../data/kf/eicu - - ../data/kf/aumc + - ../data/kidney_function/miiv + - ../data/kidney_function/hirid + - ../data/kidney_function/eicu + - ../data/kidney_function/aumc model: values: - ElasticNet diff --git a/experiments/demo_benchmark_classification.yml b/experiments/demo_benchmark_classification.yml index 32516df2..341a40dd 100644 --- a/experiments/demo_benchmark_classification.yml +++ b/experiments/demo_benchmark_classification.yml @@ -11,7 +11,7 @@ command: - --tune - --wandb-sweep - -gc - - -lc + - -lc method: grid name: yaib_demo_classification_benchmark parameters: @@ -19,7 +19,7 @@ parameters: values: - demo_data/mortality24/eicu_demo - demo_data/mortality24/mimic_demo - - demo_data/aki/eicu_demo + #fails for some reason - demo_data/aki/eicu_demo - demo_data/aki/mimic_demo - demo_data/sepsis/eicu_demo - demo_data/sepsis/mimic_demo @@ 
-31,6 +31,8 @@ parameters: - LSTM - TCN - Transformer + - TFT + # - TFTpytorch seed: values: - 1111 diff --git a/experiments/demo_benchmark_regression.yml b/experiments/demo_benchmark_regression.yml index 3b678371..57f380b8 100644 --- a/experiments/demo_benchmark_regression.yml +++ b/experiments/demo_benchmark_regression.yml @@ -19,8 +19,8 @@ parameters: values: - demo_data/los/eicu_demo - demo_data/los/mimic_demo - - demo_data/kf/eicu_demo - - demo_data/kf/mimic_demo + # - demo_data/kf/eicu_demo + # - demo_data/kf/mimic_demo model: values: - ElasticNet @@ -28,7 +28,7 @@ parameters: - GRU - LSTM - TCN - - Transformer + - TFT seed: values: - 1111 diff --git a/icu_benchmarks/cross_validation.py b/icu_benchmarks/cross_validation.py index 95e44e1a..ffa11b2a 100644 --- a/icu_benchmarks/cross_validation.py +++ b/icu_benchmarks/cross_validation.py @@ -37,7 +37,12 @@ def execute_repeated_cv( cpu: bool = False, verbose: bool = False, wandb: bool = False, - complete_train: bool = False + complete_train: bool = False, + explain: bool = False, + pytorch_forecasting: bool = False, + XAI_metric: bool = False, + random_labels: bool = False, + random_model_dir: str = None, ) -> float: """Preprocesses data and trains a model for each fold. 
@@ -101,7 +106,7 @@ def execute_repeated_cv( fold_index=fold_index, pretrained_imputation_model=pretrained_imputation_model, runmode=mode, - complete_train=complete_train + complete_train=complete_train, ) preprocess_time = datetime.now() - start_time @@ -118,15 +123,23 @@ def execute_repeated_cv( cpu=cpu, verbose=verbose, use_wandb=wandb, - train_only=complete_train + train_only=complete_train, + explain=explain, + pytorch_forecasting=pytorch_forecasting, + XAI_metric=XAI_metric, + random_labels=random_labels, + random_model_dir=random_model_dir, ) - train_time = datetime.now() - start_time + train_time = datetime.now() - start_time log_full_line( f"FINISHED FOLD {fold_index}| PREPROCESSING DURATION {preprocess_time}| PROCEDURE DURATION {train_time}", level=logging.INFO, ) - durations = {"preprocessing_duration": preprocess_time, "train_duration": train_time} + durations = { + "preprocessing_duration": preprocess_time, + "train_duration": train_time, + } with open(repetition_fold_dir / "durations.json", "w") as f: json.dump(durations, f, cls=JsonResultLoggingEncoder) @@ -134,6 +147,11 @@ def execute_repeated_cv( wandb_log({"Iteration": repetition * cv_folds_to_train + fold_index}) if repetition * cv_folds_to_train + fold_index > 1: aggregate_results(log_dir) - log_full_line(f"FINISHED CV REPETITION {repetition}", level=logging.INFO, char="=", num_newlines=3) + log_full_line( + f"FINISHED CV REPETITION {repetition}", + level=logging.INFO, + char="=", + num_newlines=3, + ) return agg_loss / (cv_repetitions_to_train * cv_folds_to_train) diff --git a/icu_benchmarks/data/loader.py b/icu_benchmarks/data/loader.py index 3c7a9280..a0eafc38 100644 --- a/icu_benchmarks/data/loader.py +++ b/icu_benchmarks/data/loader.py @@ -1,15 +1,21 @@ from typing import List from pandas import DataFrame +import pandas as pd import gin import numpy as np -from torch import Tensor, cat, from_numpy, float32 +from torch import Tensor, cat, from_numpy, float32, randn_like from torch.utils.data 
import Dataset import logging -from typing import Dict, Tuple - +from typing import Dict, Tuple, Union from icu_benchmarks.imputation.amputations import ampute_data from .constants import DataSegment as Segment from .constants import DataSplit as Split +from pytorch_forecasting import ( + TimeSeriesDataSet, + GroupNormalizer, + MultiNormalizer, + EncoderNormalizer, +) class CommonDataset(Dataset): @@ -33,7 +39,10 @@ def __init__( self.vars = vars self.grouping_df = data[split][grouping_segment].set_index(self.vars["GROUP"]) self.features_df = ( - data[split][Segment.features].set_index(self.vars["GROUP"]).drop(labels=self.vars["SEQUENCE"], axis=1) + # drops time coulmn and sets index to stay_id + data[split][Segment.features] + .set_index(self.vars["GROUP"]) + .drop(labels=self.vars["SEQUENCE"], axis=1) ) # calculate basic info for the data @@ -93,10 +102,8 @@ def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor]: """ if self._cached_dataset is not None: return self._cached_dataset[idx] - pad_value = 0.0 - stay_id = self.outcome_df.index.unique()[idx] # [self.vars["GROUP"]] - + stay_id = self.outcome_df.index.unique()[idx] # slice to make sure to always return a DF window = self.features_df.loc[stay_id:stay_id].to_numpy() labels = self.outcome_df.loc[stay_id:stay_id][self.vars["LABEL"]].to_numpy(dtype=float) @@ -160,6 +167,142 @@ def to_tensor(self): return from_numpy(data), from_numpy(labels) + +@gin.configurable("PredictionDatasetTFT") +class PredictionDatasetTFT(PredictionDataset): + """ + Subclass of prediction dataset for TFT as we need to define if variables are cont,static,known or observed. + We also need to feed the model the variables in a specific order + Args: + ram_cache (bool, optional): Whether the complete dataset should be stored in ram. Defaults to True. 
+ """ + + def __init__(self, *args, ram_cache: bool = True, **kwargs): + super().__init__(*args, ram_cache=True, **kwargs) + + def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor]: + + """ + Function to sample from the data split of choice. Used for TFT. + The data needs to be given to the model in the following order + [static categorical,static contious,known catergorical,known continous, + observed categorical, observed continous,target ,id] + Args: + idx: A specific row index to sample. + Returns: + A sample from the data, consisting of data, labels and padding mask. + """ + + if self._cached_dataset is not None: + return self._cached_dataset[idx] + + pad_value = 0.0 + stay_id = self.outcome_df.index.unique()[idx] + + # We need to be sure that tensors are returned in the correct order to be processed correclty by tft + tensors = [[] for _ in range(8)] + for var in self.features_df.columns: + if var == "sex": + tensors[0].append(self.features_df.loc[stay_id:stay_id][var].to_numpy()) + elif var == "age" or var == "height" or var == "weight": + tensors[1].append(self.features_df.loc[stay_id:stay_id][var].to_numpy()) + elif "MissingIndicator" in var: + tensors[4].append(self.features_df.loc[stay_id:stay_id][var].to_numpy()) + else: + tensors[5].append(self.features_df.loc[stay_id:stay_id][var].to_numpy()) + + tensors[6].extend( + self.outcome_df.loc[stay_id:stay_id][self.vars["LABEL"]].to_numpy( + dtype=float + ) + ) + tensors[7].append(np.asarray([stay_id])) + window_shape0 = np.shape(tensors[0])[1] + + if len(tensors[6]) == 1: + # only one label per stay, align with window + tensors[6] = np.concatenate( + [np.empty(window_shape0 - 1) * np.nan, tensors[6]], axis=0 + ) + + length_diff = self.maxlen - window_shape0 + pad_mask = np.ones(window_shape0) + # Padding the array to fulfill size requirement + + if length_diff > 0: + # window shorter than the longest window in dataset, pad to same length + tensors[0] = np.concatenate( + [ + tensors[0], + np.ones( 
+ (np.shape(tensors[0])[0], self.maxlen - np.shape(tensors[0])[1]) + ) + * pad_value, + ], + axis=1, + ) + tensors[1] = np.concatenate( + [ + tensors[1], + np.ones( + (np.shape(tensors[1])[0], self.maxlen - np.shape(tensors[1])[1]) + ) + * pad_value, + ], + axis=1, + ) + tensors[4] = np.concatenate( + [ + tensors[4], + np.ones( + (np.shape(tensors[4])[0], self.maxlen - np.shape(tensors[4])[1]) + ) + * pad_value, + ], + axis=1, + ) + tensors[5] = np.concatenate( + [ + tensors[5], + np.ones( + (np.shape(tensors[5])[0], self.maxlen - np.shape(tensors[5])[1]) + ) + * pad_value, + ], + axis=1, + ) + + tensors[6] = np.concatenate( + [ + tensors[6], + np.ones(self.maxlen - np.shape(tensors[6])[0]) * pad_value, + ], + axis=0, + ) + pad_mask = np.concatenate([pad_mask, np.zeros(length_diff)], axis=0) + tensors[7] = np.concatenate( + [ + tensors[7], + np.ones( + (np.shape(tensors[7])[0], self.maxlen - np.shape(tensors[7])[1]) + ) + * stay_id, + ], + axis=1, + ) # should be done regardless of length_diff + not_labeled = np.argwhere(np.isnan(tensors[6])) + if len(not_labeled) > 0: + tensors[6][not_labeled] = -1 + pad_mask[not_labeled] = 0 + tensors[6] = [tensors[6]] + pad_mask = pad_mask.astype(bool) + + tensors = (from_numpy(np.array(tensor)).to(float32) for tensor in tensors) + tensors = [stack((x,), dim=-1) if x.numel() > 0 else empty(0) for x in tensors] + return OrderedDict(zip(Features.FEAT_NAMES, tensors)), from_numpy(pad_mask) + + + @gin.configurable("ImputationDataset") class ImputationDataset(CommonDataset): """Subclass of Common Dataset that contains data for imputation models.""" @@ -291,3 +434,118 @@ def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor]: window = self.dyn_df.loc[stay_id:stay_id, :] return from_numpy(window.values).to(float32) + + +@gin.configurable("PredictionDatasetpytorch") +class PredictionDatasetpytorch(TimeSeriesDataSet): + """Subclass of timeseries dataset works with pyotrch forecasting library . 
+ + Args: + data (DataFrame): dict of the different splits of the data + split: Either 'train','val' or 'test' + max_prediction_length: maximum number of time steps to predict, + max_encoder_length: maximum length of input sequence to give the model, + ram_cache (bool, optional): wether the dataset should be stored in ram. Defaults to True. + """ + + def __init__( + self, + data: dict, + split: str, + time_varying_unknown_reals: List[str], + target: Union[str, List[str]], + time_varying_known_reals: List[str], + time_varying_unknown_categoricals: List[str], + lagged_variables: List[str], + *args, + target_normalizer: str = "", + ram_cache: bool = False, + add_relative_time_idx: bool = False, + name: str = "", + max_prediction_length: int = 24, + max_encoder_length: int = 24, + **kwargs, + ): + data[split]["FEATURES"]["time_idx"] = ((data[split]["FEATURES"]["time"] / pd.Timedelta(seconds=3600))).astype( + int + ) # create an incremental column indicating the time step(required by constructor) + data = data.get(split) # get split + labels = data["OUTCOME"] + + features = data["FEATURES"] + self.name = name + self.data = pd.merge(labels, features, on=["stay_id", "time"]) + if len(lagged_variables) > 0: + if self.data["label"].dtype == "bool": + self.data["label"] = self.data["label"].astype(float) + columns_to_lag = lagged_variables + grouped = self.data.sort_values("time_idx").groupby("stay_id") + for lag in range(1, max_encoder_length + 1): + for column in columns_to_lag: + # Create a new column with lagged values + self.data[f"{column}_lag_{lag}"] = grouped[column].shift(lag, fill_value=0) + + self.split = split + self.args = args + self.ram_cache = ram_cache + self.kwargs = kwargs + self.column_names = features.columns + if target_normalizer == "multi": + target_normalizer = MultiNormalizer( + [EncoderNormalizer(transformation="relu") for _ in range(len(target) - 1)] + + [GroupNormalizer(groups=["stay_id"], transformation="relu")] + ) + else: + target_normalizer = 
GroupNormalizer(groups=["stay_id"], transformation="relu") + super().__init__( + data=self.data, + time_idx="time_idx", + target=target, + group_ids=["stay_id"], + min_encoder_length=max_encoder_length, + max_encoder_length=max_encoder_length, + min_prediction_length=max_prediction_length, + max_prediction_length=max_prediction_length, + static_categoricals=[], + static_reals=["height", "weight", "age", "sex"], + time_varying_known_categoricals=[], + time_varying_known_reals=time_varying_known_reals, + time_varying_unknown_categoricals=time_varying_unknown_categoricals, + time_varying_unknown_reals=time_varying_unknown_reals, + add_relative_time_idx=add_relative_time_idx, + # add_target_scales=True, + # add_encoder_length=True, + predict_mode=True, + target_normalizer=GroupNormalizer(groups=["stay_id"], transformation="relu"), + ) + + def get_balance(self) -> list: + """Return the weight balance for the split of interest. + + Returns: + Weights for each label. + """ + if len(self.data["target"]) == 1: + counts = self.data["target"][0].unique(return_counts=True) + else: + counts = self.data["target"][-1].unique(return_counts=True) + + return list((1 / counts[1]) * counts[1].sum() / counts[0].shape[0]) + + def get_feature_names(self): + return self.column_names + + def randomize_labels(self, num_classes=None, min=None, max=None): + if num_classes == 1: + random_target = np.random.uniform( + self.data["target"][0].min(), + self.data["target"][0].max(), + size=len(self.data["target"][0]), + ) + else: + random_target = np.random.randint(num_classes, size=len(self.data["target"][0])) + self.data["target"][0] = Tensor(random_target) + + def add_noise(self, num_classes=None, min=None, max=None): + noise = randn_like(self.data["reals"]) * 0.01 + self.data["reals"] += noise diff --git a/icu_benchmarks/data/pooling.py b/icu_benchmarks/data/pooling.py index 7eeb2949..48689d35 100644 --- a/icu_benchmarks/data/pooling.py +++ b/icu_benchmarks/data/pooling.py @@ -16,16 +16,17 @@ 
class PooledDataset: class PooledData: - def __init__(self, - data_dir, - vars, - datasets, - file_names, - shuffle=False, - stratify=None, - runmode=RunMode.classification, - save_test=True, - ): + def __init__( + self, + data_dir, + vars, + datasets, + file_names, + shuffle=False, + stratify=None, + runmode=RunMode.classification, + save_test=True, + ): """ Generate pooled data from existing datasets. Args: @@ -48,10 +49,10 @@ def __init__(self, self.save_test = save_test def generate( - self, - datasets, - samples=10000, - seed=42, + self, + datasets, + samples=10000, + seed=42, ): """ Generate pooled data from existing datasets. @@ -65,8 +66,8 @@ def generate( if folder.is_dir(): if folder.name in datasets: data[folder.name] = { - f: pq.read_table(folder / self.file_names[f]).to_pandas(self_destruct=True) for f in - self.file_names.keys() + f: pq.read_table(folder / self.file_names[f]).to_pandas(self_destruct=True) + for f in self.file_names.keys() } data = self._pool_datasets( datasets=data, @@ -101,15 +102,15 @@ def _save_pooled_data(self, data_dir, data, datasets, file_names, samples=10000) logging.info(f"Saved pooled data at {save_dir}") def _pool_datasets( - self, - datasets={}, - samples=10000, - vars=[], - seed=42, - shuffle=True, - runmode=RunMode.classification, - data_dir=Path("data"), - save_test=True, + self, + datasets={}, + samples=10000, + vars=[], + seed=42, + shuffle=True, + runmode=RunMode.classification, + data_dir=Path("data"), + save_test=True, ): """ Pool datasets into a single dataset. 
@@ -144,8 +145,9 @@ def _pool_datasets( # If we have more outcomes than stays, check max label value per stay id labels = outcome.groupby(id).max()[vars[Var.label]].reset_index(drop=True) # if pd.Series(outcome[id].unique()) is outcome[id]): - selected_stays = train_test_split(stays, stratify=labels, shuffle=shuffle, random_state=seed, - train_size=samples) + selected_stays = train_test_split( + stays, stratify=labels, shuffle=shuffle, random_state=seed, train_size=samples + ) else: selected_stays = train_test_split(stays, shuffle=shuffle, random_state=seed, train_size=samples) # Select only stays that are in the selected_stays diff --git a/icu_benchmarks/data/preprocessor.py b/icu_benchmarks/data/preprocessor.py index 69b90f5b..486ab0e8 100644 --- a/icu_benchmarks/data/preprocessor.py +++ b/icu_benchmarks/data/preprocessor.py @@ -205,8 +205,11 @@ def apply(self, data, vars): Returns: Preprocessed data. """ - for split in [Split.train, Split.val, Split.test]: - data = self._process_outcome(data, vars, split) + + self.outcome_max = data["train"]["OUTCOME"]["label"].max() + self.outcome_min = data["train"]["OUTCOME"]["label"].min() + # for split in [Split.train, Split.val, Split.test]: + # data = self._process_outcome(data, vars, split) data = super().apply(data, vars) return data diff --git a/icu_benchmarks/imputation/diffwave.py b/icu_benchmarks/imputation/diffwave.py index 147eeb37..437303ed 100644 --- a/icu_benchmarks/imputation/diffwave.py +++ b/icu_benchmarks/imputation/diffwave.py @@ -61,7 +61,9 @@ def __init__( ) self.final_conv = nn.Sequential( - Conv(skip_channels, skip_channels, kernel_size=1), nn.ReLU(), ZeroConv1d(skip_channels, out_channels) + Conv(skip_channels, skip_channels, kernel_size=1), + nn.ReLU(), + ZeroConv1d(skip_channels, out_channels), ) self.diffusion_parameters = calc_diffusion_hyperparams(diffusion_time_steps, beta_0, beta_T) @@ -93,7 +95,10 @@ def step_fn(self, batch, step_prefix=""): observed_mask = 1 - amputation_mask.float() if 
step_prefix in ["train", "val"]: - T, Alpha_bar = self.hparams.diffusion_time_steps, self.diffusion_parameters["Alpha_bar"] + T, Alpha_bar = ( + self.hparams.diffusion_time_steps, + self.diffusion_parameters["Alpha_bar"], + ) B, C, L = amputated_data.shape # B is batchsize, C=1, L is audio length diffusion_steps = torch.randint(T, size=(B, 1, 1)).to(self.device) # randomly sample diffusion steps from 1~T @@ -121,7 +126,12 @@ def step_fn(self, batch, step_prefix=""): amputated_data[target_missingness > 0] = target[target_missingness > 0] loss = self.loss(amputated_data, target) for metric in self.metrics[step_prefix].values(): - metric.update((torch.flatten(amputated_data, start_dim=1).clone(), torch.flatten(target, start_dim=1).clone())) + metric.update( + ( + torch.flatten(amputated_data, start_dim=1).clone(), + torch.flatten(target, start_dim=1).clone(), + ) + ) self.log(f"{step_prefix}/loss", loss.item(), prog_bar=True) return loss @@ -230,7 +240,13 @@ def calc_diffusion_hyperparams(diffusion_time_steps, beta_0, beta_T): Sigma = torch.sqrt(Beta_tilde) # \sigma_t^2 = \tilde{\beta}_t _dh = {} - _dh["T"], _dh["Beta"], _dh["Alpha"], _dh["Alpha_bar"], _dh["Sigma"] = diffusion_time_steps, Beta, Alpha, Alpha_bar, Sigma + _dh["T"], _dh["Beta"], _dh["Alpha"], _dh["Alpha_bar"], _dh["Sigma"] = ( + diffusion_time_steps, + Beta, + Alpha, + Alpha_bar, + Sigma, + ) diffusion_hyperparams = _dh return diffusion_hyperparams @@ -239,7 +255,13 @@ class Conv(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1): super(Conv, self).__init__() self.padding = dilation * (kernel_size - 1) // 2 - self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding=self.padding) + self.conv = nn.Conv1d( + in_channels, + out_channels, + kernel_size, + dilation=dilation, + padding=self.padding, + ) self.conv = nn.utils.weight_norm(self.conv) nn.init.kaiming_normal_(self.conv.weight) @@ -261,7 +283,14 @@ def forward(self, x): class 
Residual_block(nn.Module): - def __init__(self, res_channels, skip_channels, dilation, diffusion_step_embed_dim_out, in_channels): + def __init__( + self, + res_channels, + skip_channels, + dilation, + diffusion_step_embed_dim_out, + in_channels, + ): super(Residual_block, self).__init__() self.res_channels = res_channels @@ -301,7 +330,7 @@ def forward(self, input_data): cond = self.cond_conv(cond) h += cond - out = torch.tanh(h[:, :self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels:, :]) + out = torch.tanh(h[:, : self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels:, :]) res = self.res_conv(out) assert x.shape == res.shape diff --git a/icu_benchmarks/imputation/layers/s4layer.py b/icu_benchmarks/imputation/layers/s4layer.py index 0691c7ef..fd8a53a6 100644 --- a/icu_benchmarks/imputation/layers/s4layer.py +++ b/icu_benchmarks/imputation/layers/s4layer.py @@ -92,6 +92,7 @@ def _resolve_conj(x): def _resolve_conj(x): return x.conj() + """ simple nn.Module components """ @@ -162,16 +163,16 @@ def forward(self, x): def LinearActivation( - d_input, - d_output, - bias=True, - zero_bias_init=False, - transposed=False, - initializer=None, - activation=None, - activate=False, # Apply activation as part of this module - weight_norm=False, - **kwargs, + d_input, + d_output, + bias=True, + zero_bias_init=False, + transposed=False, + initializer=None, + activation=None, + activate=False, # Apply activation as part of this module + weight_norm=False, + **kwargs, ): """Returns a linear nn.Module with control over axes order, initialization, and activation""" @@ -385,7 +386,7 @@ def rank_correction(measure, N, rank=1, dtype=torch.float): P = torch.stack([P0, P1], dim=0) # (2 N) elif measure == "lagt": assert rank >= 1 - P = 0.5 ** 0.5 * torch.ones(1, N, dtype=dtype) + P = 0.5**0.5 * torch.ones(1, N, dtype=dtype) elif measure == "fourier": P = torch.ones(N, dtype=dtype) # (N) P0 = P.clone() @@ -509,18 +510,18 @@ def _omega(self, L, dtype, device, cache=True): 
return omega, z def __init__( - self, - L, - w, - P, - B, - C, - log_dt, - hurwitz=False, - trainable=None, - lr=None, - tie_state=False, - length_correction=True, + self, + L, + w, + P, + B, + C, + log_dt, + hurwitz=False, + trainable=None, + lr=None, + tie_state=False, + length_correction=True, ): """ L: Maximum length; this module computes an SSM kernel of length L @@ -689,10 +690,10 @@ def forward(self, state=None, rate=1.0, L=None): r11 = r[-self.rank:, -self.rank:, :, :] det = (1 + r11[:1, :1, :, :]) * (1 + r11[1:, 1:, :, :]) - r11[:1, 1:, :, :] * r11[1:, :1, :, :] s = ( - r01[:, :1, :, :] * (1 + r11[1:, 1:, :, :]) * r10[:1, :, :, :] - + r01[:, 1:, :, :] * (1 + r11[:1, :1, :, :]) * r10[1:, :, :, :] - - r01[:, :1, :, :] * (r11[:1, 1:, :, :]) * r10[1:, :, :, :] - - r01[:, 1:, :, :] * (r11[1:, :1, :, :]) * r10[:1, :, :, :] + r01[:, :1, :, :] * (1 + r11[1:, 1:, :, :]) * r10[:1, :, :, :] + + r01[:, 1:, :, :] * (1 + r11[:1, :1, :, :]) * r10[1:, :, :, :] + - r01[:, :1, :, :] * (r11[:1, 1:, :, :]) * r10[1:, :, :, :] + - r01[:, 1:, :, :] * (r11[1:, :1, :, :]) * r10[:1, :, :, :] ) s = s / det k_f = r00 - s @@ -737,7 +738,7 @@ def _setup_linear(self): dt = torch.exp(self.log_dt) D = (2.0 / dt.unsqueeze(-1) - w).reciprocal() # (H, N) R = ( - torch.eye(self.rank, dtype=w.dtype, device=w.device) + 2 * contract("r h n, h n, s h n -> h r s", Q, D, P).real + torch.eye(self.rank, dtype=w.dtype, device=w.device) + 2 * contract("r h n, h n, s h n -> h r s", Q, D, P).real ) # (H r r) Q_D = rearrange(Q * D, "r h n -> h r n") R = torch.linalg.solve(R.to(Q_D), Q_D) # (H r N) @@ -778,8 +779,8 @@ def _step_state_linear(self, u=None, state=None): def contract_fn(p, x, y): return contract("r h n, r h m, ... h m -> ... 
h n", _conj(p), _conj(x), _conj(y))[ - ..., : self.N - ] # inner outer product + ..., : self.N + ] # inner outer product else: assert state.size(-1) == 2 * self.N @@ -940,24 +941,24 @@ class HippoSSKernel(nn.Module): """ def __init__( - self, - H, - N=64, - L=1, - measure="legs", - rank=1, - channels=1, # 1-dim to C-dim map; can think of C as having separate "heads" - dt_min=0.001, - dt_max=0.1, - trainable=None, # Dictionary of options to train various HiPPO parameters - lr=None, # Hook to set LR of hippo parameters differently - length_correction=True, # Multiply by I-A|^L after initialization; can be turned off for initialization speed - hurwitz=False, - tie_state=False, # Tie parameters of HiPPO ODE across the H features - precision=1, # 1 (single) or 2 (double) for the kernel - resample=False, # If given inputs of different lengths, adjust the sampling rate. - # Note that L should always be provided in this case, as it assumes that L is the true underlying - # length of the continuous signal + self, + H, + N=64, + L=1, + measure="legs", + rank=1, + channels=1, # 1-dim to C-dim map; can think of C as having separate "heads" + dt_min=0.001, + dt_max=0.1, + trainable=None, # Dictionary of options to train various HiPPO parameters + lr=None, # Hook to set LR of hippo parameters differently + length_correction=True, # Multiply by I-A|^L after initialization; can be turned off for initialization speed + hurwitz=False, + tie_state=False, # Tie parameters of HiPPO ODE across the H features + precision=1, # 1 (single) or 2 (double) for the kernel + resample=False, # If given inputs of different lengths, adjust the sampling rate. 
+ # Note that L should always be provided in this case, as it assumes that L is the true underlying + # length of the continuous signal ): super().__init__() self.N = N @@ -1007,23 +1008,23 @@ def get_torch_trans(heads=8, layers=1, channels=64): class S4(nn.Module): def __init__( - self, - d_model, - d_state=64, - l_max=1, # Maximum length of sequence. Fine if not provided: the kernel will keep doubling in length until longer - # than sequence. However, this can be marginally slower if the true length is not a power of 2 - channels=1, # maps 1-dim to C-dim - bidirectional=False, - # Arguments for FF - activation="gelu", # activation in between SS and FF - postact=None, # activation after FF - initializer=None, # initializer on FF - weight_norm=False, # weight normalization on FF - hyper_act=None, # Use a "hypernetwork" multiplication - dropout=0.0, - transposed=True, # axis ordering (B, L, D) or (B, D, L) - # SSM Kernel arguments - **kernel_args, + self, + d_model, + d_state=64, + l_max=1, # Maximum length of sequence. Fine if not provided: the kernel will keep doubling in length until longer + # than sequence. 
However, this can be marginally slower if the true length is not a power of 2 + channels=1, # maps 1-dim to C-dim + bidirectional=False, + # Arguments for FF + activation="gelu", # activation in between SS and FF + postact=None, # activation after FF + initializer=None, # initializer on FF + weight_norm=False, # weight normalization on FF + hyper_act=None, # Use a "hypernetwork" multiplication + dropout=0.0, + transposed=True, # axis ordering (B, L, D) or (B, D, L) + # SSM Kernel arguments + **kernel_args, ): """ d_state: the dimension of the state, also denoted by N diff --git a/icu_benchmarks/imputation/sssds4.py b/icu_benchmarks/imputation/sssds4.py index ce6e8c0d..205affa4 100644 --- a/icu_benchmarks/imputation/sssds4.py +++ b/icu_benchmarks/imputation/sssds4.py @@ -298,7 +298,7 @@ def forward(self, input_data): h = self.S42(h.permute(2, 0, 1)).permute(1, 2, 0) - out = torch.tanh(h[:, :self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels:, :]) + out = torch.tanh(h[:, : self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels:, :]) res = self.res_conv(out) assert x.shape == res.shape @@ -414,7 +414,8 @@ def calc_diffusion_hyperparams(diffusion_time_steps, beta_0, beta_T): Beta_tilde = Beta + 0 for t in range(1, diffusion_time_steps): Alpha_bar[t] *= Alpha_bar[t - 1] # \bar{\alpha}_t = \prod_{s=1}^t \alpha_s - Beta_tilde[t] *= (1 - Alpha_bar[t - 1]) / (1 - Alpha_bar[t]) # \tilde{\beta}_t = \beta_t * (1-\bar{\alpha}_{t-1}) + # \tilde{\beta}_t = \beta_t * (1-\bar{\alpha}_{t-1}) + Beta_tilde[t] *= (1 - Alpha_bar[t - 1]) / (1 - Alpha_bar[t]) # / (1-\bar{\alpha}_t) Sigma = torch.sqrt(Beta_tilde) # \sigma_t^2 = \tilde{\beta}_t diff --git a/icu_benchmarks/models/constants.py b/icu_benchmarks/models/constants.py index 45af8271..752fd8ba 100644 --- a/icu_benchmarks/models/constants.py +++ b/icu_benchmarks/models/constants.py @@ -1,4 +1,9 @@ -from ignite.contrib.metrics import AveragePrecision, ROC_AUC, RocCurve, PrecisionRecallCurve +from 
ignite.contrib.metrics import ( + AveragePrecision, + ROC_AUC, + RocCurve, + PrecisionRecallCurve, +) from ignite.metrics import Accuracy, RootMeanSquaredError from sklearn.calibration import calibration_curve from sklearn.metrics import ( @@ -19,7 +24,7 @@ CalibrationError, F1Score, ) -from enum import Enum +from enum import Enum,IntEnum from icu_benchmarks.models.custom_metrics import ( CalibrationCurve, BalancedAccuracy, @@ -27,6 +32,7 @@ JSD, BinaryFairnessWrapper, ) +import gin class MLMetrics: @@ -95,3 +101,21 @@ class ImputationInit(str, Enum): XAVIER = "xavier" KAIMING = "kaiming" ORTHOGONAL = "orthogonal" + + +@gin.constants_from_enum +class DataTypes(Enum): + """Defines numerical types of each column.""" + CONTINUOUS = 0 + CATEGORICAL = 1 + DATE = 2 + STR = 3 +@gin.constants_from_enum +class InputTypes(IntEnum): + """Defines input types of each column.""" + TARGET = 0 + OBSERVED = 1 + KNOWN = 2 + STATIC = 3 + ID = 4 # Single column used as an entity identifier + TIME = 5 # Single column exclusively used as a time index \ No newline at end of file diff --git a/icu_benchmarks/models/custom_metrics.py b/icu_benchmarks/models/custom_metrics.py index ddb5d37e..0cff08c0 100644 --- a/icu_benchmarks/models/custom_metrics.py +++ b/icu_benchmarks/models/custom_metrics.py @@ -6,6 +6,7 @@ from sklearn.calibration import calibration_curve from scipy.spatial.distance import jensenshannon from torchmetrics.classification import BinaryFairness +from quantus.functions.similarity_func import correlation_spearman, cosine """" This file contains custom metrics that can be added to YAIB. 
@@ -32,7 +33,9 @@ def accuracy(output, target, topk=(1,)): class BalancedAccuracy(EpochMetric): def __init__(self, output_transform: Callable = lambda x: x, check_compute_fn: bool = False) -> None: super(BalancedAccuracy, self).__init__( - self.balanced_accuracy_compute, output_transform=output_transform, check_compute_fn=check_compute_fn + self.balanced_accuracy_compute, + output_transform=output_transform, + check_compute_fn=check_compute_fn, ) def balanced_accuracy_compute(y_preds: torch.Tensor, y_targets: torch.Tensor) -> float: @@ -44,7 +47,9 @@ def balanced_accuracy_compute(y_preds: torch.Tensor, y_targets: torch.Tensor) -> class CalibrationCurve(EpochMetric): def __init__(self, output_transform: Callable = lambda x: x, check_compute_fn: bool = False) -> None: super(CalibrationCurve, self).__init__( - self.ece_curve_compute_fn, output_transform=output_transform, check_compute_fn=check_compute_fn + self.ece_curve_compute_fn, + output_transform=output_transform, + check_compute_fn=check_compute_fn, ) def ece_curve_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor, n_bins=10) -> float: @@ -69,6 +74,7 @@ def __init__( def mae_with_invert_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor, invert_fn=Callable) -> float: y_true = invert_fn(y_targets.numpy().reshape(-1, 1))[:, 0] y_pred = invert_fn(y_preds.numpy().reshape(-1, 1))[:, 0] + return mean_absolute_error(y_true, y_pred) @@ -130,3 +136,423 @@ def feature_helper(self, trainer, step_prefix): else: feature_names = trainer.test_dataloaders.dataset.features return feature_names + + +# XAI Metrics + + +class Faithfulness(EpochMetric): + def __init__(self, output_transform: Callable = lambda x: x, check_compute_fn: bool = False, *args, **kwargs) -> None: + super().__init__(output_transform, check_compute_fn, *args, **kwargs) + + def add_noise(self, x, indices, time_step, feature, feature_timestep): + noise = torch.randn_like(x["encoder_cont"]) + if time_step: + idx0, idx1 = np.meshgrid(indices[0], 
indices[1], indexing="ij") + + with torch.no_grad(): + x["encoder_cont"][idx0, idx1, :] += noise[idx0, idx1, :] + + elif feature: + idx0, idx1 = np.meshgrid(indices[0], indices[1], indexing="ij") + + with torch.no_grad(): + x["encoder_cont"][idx0, :, idx1] += noise[idx0, :, idx1] + + elif feature_timestep: + idx0, idx1, idx2 = np.meshgrid(indices[0], indices[1], indices[2], indexing="ij") + + with torch.no_grad(): + x["encoder_cont"][idx0, idx1, idx2] += noise[idx0, idx1, idx2] + return x + + def apply_baseline(self, x, indices, time_step, feature, feature_timestep): + mask = torch.ones_like(x["encoder_cont"]) + if time_step: + ( + idx0, + idx1, + ) = np.meshgrid(indices[0], indices[1], indexing="ij") + + mask[idx0, idx1, :] -= mask[idx0, idx1, :] + elif feature: + ( + idx0, + idx1, + ) = np.meshgrid(indices[0], indices[1], indexing="ij") + + mask[idx0, :, idx1] -= mask[idx0, :, idx1] + + elif feature_timestep: + idx0, idx1, idx2 = np.meshgrid(indices[0], indices[1], indices[2], indexing="ij") + + mask[idx0, idx1, idx2] -= mask[idx0, idx1, idx2] + + with torch.no_grad(): + x["encoder_cont"] *= mask + return x + + def update( + self, + x, + attribution, + model, + similarity_func=None, + nr_runs=100, + pertrub=None, + subset_size=3, + feature=False, + time_step=False, + feature_timestep=False, + device="cuda", + ): + """ + Calculates faithfulness scores for captum attributions + + Args: + - x:Batch input + -attribution: attribution generated by captum, + - similarity_func:function to determine similarity between sum of attributions and difference in prediction + - nr_runs: How many times to repeat the experiment, + - pertrub: What change to do to the input, + - subset_size: The size of the subset of featrues to alter , + - feature: Determines if to calcualte faithfulness of feature attributions, + - time_step: Determines if to calcualte faithfulness of timesteps attributions, + - feature_timestep: Determines if to calcualte faithfulness of featrues per timesteps 
attributions, + Returns: + score: similarity score between sum of attributions and difference in prediction averaged over nr_runs + + Implementation of faithfulness correlation by Bhatt et al., 2020. + + The Faithfulness Correlation metric intend to capture an explanation's relative faithfulness + (or 'fidelity') with respect to the model behaviour. + + Faithfulness correlation scores shows to what extent the predicted logits of each modified test point and + the average explanation attribution for only the subset of features are (linearly) correlated, taking the + average over multiple runs and test samples. The metric returns one float per input-attribution pair that + ranges between -1 and 1, where higher scores are better. + + For each test sample, |S| features are randomly selected and replace them with baseline values (zero baseline + or average of set). Thereafter, Pearson’s correlation coefficient between the predicted logits of each modified + test point and the average explanation attribution for only the subset of features is calculated. Results is + average over multiple runs and several test samples. + This code is adapted from the quantus libray to suit our use case + + References: + 1) Umang Bhatt et al.: "Evaluating and aggregating feature-based model + explanations." IJCAI (2020): 3016-3022. + 2)Hedström, Anna, et al. "Quantus: An explainable ai toolkit for + responsible evaluation of neural network explanations and beyond." + Journal of Machine Learning Research 24.34 (2023): 1-11. 
+ """ + + # Assuming 'attribution' is already a GPU tensor + attribution = torch.tensor(attribution).to(device) + # Other initializations + if similarity_func is None: + similarity_func = correlation_spearman + if pertrub is None: + pertrub = "baseline" + similarities = [] + + # Assuming this is a method to prepare your data + + y_pred = model(model.prep_data(x)).detach() # Keep on GPU + pred_deltas = [] + att_sums = [] + + for i_ix in range(nr_runs): + if time_step: + timesteps_idx = np.random.choice(24, subset_size, replace=False) + patient_idx = np.random.choice(64, 1, replace=False) + a_ix = [patient_idx, timesteps_idx] + + elif feature: + feature_idx = np.random.choice(53, subset_size, replace=False) + patient_idx = np.random.choice(64, 1, replace=False) + a_ix = [patient_idx, feature_idx] + elif feature_timestep: + timesteps_idx = np.random.choice(24, subset_size[0], replace=False) + feature_idx = np.random.choice(53, subset_size[1], replace=False) + patient_idx = np.random.choice(64, 1, replace=False) + a_ix = [patient_idx, timesteps_idx, feature_idx] + + # Apply perturbation + if pertrub == "Noise": + x = self.add_noise(x, a_ix, time_step, feature, feature_timestep) + elif pertrub == "baseline": + x = self.apply_baseline(x, a_ix, time_step, feature, feature_timestep) + + # Predict on perturbed input and calculate deltas + y_pred_perturb = (model(model.prep_data(x))).detach() # Keep on GPU + + if time_step: + if attribution.size() == torch.Size([24]): + att_sums.append((attribution[timesteps_idx]).sum()) + else: + att_sums.append((attribution[patient_idx, :][:, timesteps_idx]).sum()) + elif feature: + if len(attribution) == 53: + att_sums.append((attribution[feature_idx]).sum()) + else: + att_sums.append((attribution[patient_idx, :][:, feature_idx]).sum()) + elif feature_timestep: + att_sums.append((attribution[patient_idx, :, :][:, timesteps_idx, :][:, :, feature_idx]).sum()) + + pred_deltas.append((y_pred - y_pred_perturb)[patient_idx].item()) + # Convert 
to CPU for numpy operations + + pred_deltas_cpu = torch.tensor(pred_deltas).cpu().numpy() + att_sums_cpu = torch.tensor(att_sums).cpu().numpy() + + similarities.append(similarity_func(pred_deltas_cpu, att_sums_cpu)) + + score = np.nanmean(similarities) + return score + + +class Stability(EpochMetric): + def __init__(self, output_transform: Callable = lambda x: x, check_compute_fn: bool = False, *args, **kwargs) -> None: + super().__init__(output_transform, check_compute_fn, *args, **kwargs) + + def update(self, x, attribution, model, explain_method, dataloader=None, thershold=0.5, device="cuda", **kwargs): + """ + Args: + - x:Batch input + -attribution: attribution + - explain_method:function to generate explantations + - method_name: Name of the explantation + - dataloader:In case of using Attention as the explain method need to pass the dataloader instead of the batch. + + + Returns: + RIS : relative distance between the explantation and the input + ROS: relative distance between the explantation and the output + + + References: + 1) `https://arxiv.org/pdf/2203.06877.pdf + 2)Hedström, Anna, et al. "Quantus: An explainable ai toolkit for + responsible evaluation of neural network explanations and beyond." + Journal of Machine Learning Research 24.34 (2023): 1-11. + + """ + + def relative_stability_objective( + x, + xs, + e_x, + e_xs, + close_indices, + eps_min=0.0001, + input=False, + attention=False, + device="cuda", + ) -> torch.Tensor: + """ + Computes relative input and output stabilities maximization objective + as defined here :ref:`https://arxiv.org/pdf/2203.06877.pdf` by the authors. + + Args: + + x: Input tensor + xs: perturbed tensor. + e_x: Explanations for x. + e_xs: Explanations for xs. + eps_min:Value to avoid division by zero if needed + input:Boolean to indicate if this is an input or an output + device: the device to keep the tensors on + + Returns: + + ris_obj: Tensor + RIS maximization objective. 
+ """ + + # Function to convert inputs to tensors if they are numpy arrays + def to_tensor(input_array): + if isinstance(input_array, np.ndarray): + return torch.index_select(torch.tensor(input_array).to(device), 0, close_indices) + + return torch.index_select(input_array.to(device), 0, close_indices) + + # Convert all inputs to tensors and move to GPU + if attention: + x, xs = map(to_tensor, [x, xs]) + else: + x, xs, e_x, e_xs = map(to_tensor, [x, xs, e_x, e_xs]) + + if input: + num_dim = x.ndim + else: + num_dim = e_x.ndim + + if num_dim == 3: + + def norm_function(arr): + return torch.norm(arr, dim=(-1, -2)) + + elif num_dim == 2: + + def norm_function(arr): + return torch.norm(arr, dim=-1) + + else: + + def norm_function(arr): + return torch.norm(arr) + + nominator = (e_x - e_xs) / (e_x + (e_x == 0) * eps_min) + nominator = norm_function(nominator) + + if input: + denominator = x - xs + denominator /= x + (x == 0) * eps_min + denominator = norm_function(denominator) + denominator += (denominator == 0) * eps_min + else: + denominator = torch.squeeze(x) - torch.squeeze(xs) + denominator = torch.norm(denominator, dim=-1) + denominator += (denominator == 0) * eps_min + + return nominator / denominator + + if explain_method == "Attention": + y_pred = model.model.predict(dataloader) + x_original = dataloader.dataset.data["reals"].clone() + + dataloader.dataset.add_noise() + x_perturb = dataloader.dataset.data["reals"].clone() + y_pred_perturb = model.model.predict(dataloader) + Attention_weights = model.interpertations(dataloader) + att_perturb = Attention_weights["attention"] + # Calculate the absolute difference + difference = torch.abs(y_pred_perturb - y_pred) + + # Find where the difference is less than or equal to a thershold + close_indices = torch.nonzero(difference <= thershold).squeeze()[:, 0].to(device) + + RIS = relative_stability_objective( + x_original.detach(), + x_perturb.detach(), + attribution, + att_perturb, + close_indices=close_indices, + 
input=True, + attention=True, + ) + ROS = relative_stability_objective( + y_pred, + y_pred_perturb, + attribution, + att_perturb, + close_indices=close_indices, + input=False, + attention=True, + ) + + else: + y_pred = model(model.prep_data(x)).detach() + x_original = x["encoder_cont"].detach().clone() + + with torch.no_grad(): + noise = torch.randn_like(x["encoder_cont"]) * 0.01 + x["encoder_cont"] += noise + y_pred_perturb = model(model.prep_data(x)).detach() + if explain_method == "Random": + att_perturb = np.random.normal(size=[64, 24, 53]) + + else: + att_perturb, features_attrs, timestep_attrs = model.explantation2(x, explain_method) + + # + # Calculate the absolute difference + difference = torch.abs(y_pred_perturb - y_pred) + + # Find where the difference is less than or equal to a thershold + close_indices = torch.nonzero(difference <= thershold).squeeze()[:, 0].to(device) + + RIS = relative_stability_objective( + x_original.detach(), + x["encoder_cont"].detach(), + attribution, + att_perturb, + close_indices=close_indices, + input=True, + ) + ROS = relative_stability_objective( + y_pred, + y_pred_perturb, + attribution, + att_perturb, + close_indices=close_indices, + input=False, + ) + + return np.max(RIS.cpu().numpy()).astype(np.float64), np.max(ROS.cpu().numpy()).astype(np.float64) + + +class Randomization(EpochMetric): + def __init__(self, output_transform: Callable = lambda x: x, check_compute_fn: bool = False, *args, **kwargs) -> None: + super().__init__(output_transform, check_compute_fn, *args, **kwargs) + + def update(self, x, attribution, model, explain_method, random_model, similarity_func=cosine, dataloader=None, **kwargs): + """ + + Args: + - x:Batch input + -attribution: attribution + - explain_method:function to generate explantations + - random_model: Reference to model trained on random labels + - similarity_func: Function to measure similiarity + - dataloader:In case of using Attention as the explain method need to pass the dataloader 
instead of the batch , + - method_name: Name of the explantation + + Returns: + score: similarity score between attributions of model trained on random data and model trained on real data + + Implementation of the Random Logit Metric by Sixt et al., 2020. + + The Random Logit Metric computes the distance between the original explanation and a reference explanation of + a randomly chosen non-target class. + This code is adapted from the quantus libray to suit our use case + + References: + 1) Leon Sixt et al.: "When Explanations Lie: Why Many Modified BP + Attributions Fail." ICML (2020): 9046-9057. + 2)Hedström, Anna, et al. "Quantus: An explainable ai toolkit for + responsible evaluation of neural network explanations and beyond." + Journal of Machine Learning Research 24.34 (2023): 1-11. + + """ + + if explain_method == "Attention": + Attention_weights = random_model.interpertations(dataloader) + attribution = attribution.cpu().numpy() + min_val = np.min(attribution) + max_val = np.max(attribution) + + attribution = (attribution - min_val) / (max_val - min_val) + random_attr = Attention_weights["attention"].cpu().numpy() + min_val = np.min(random_attr) + max_val = np.max(random_attr) + random_attr = (random_attr - min_val) / (max_val - min_val) + score = similarity_func(random_attr, attribution) + elif explain_method == "Random": + score = similarity_func(np.random.normal(size=[64, 24, 53]).flatten(), attribution.flatten()) + else: + data, baselines = model.prep_data_captum(x) + + random_attr, features_attrs, timestep_attrs = model.explantation2(x, explain_method) + + attribution = attribution.flatten() + min_val = np.min(attribution) + max_val = np.max(attribution) + attribution = (attribution - min_val) / (max_val - min_val) + random_attr = random_attr.flatten() + min_val = np.min(random_attr) + max_val = np.max(random_attr) + random_attr = (random_attr - min_val) / (max_val - min_val) + + score = similarity_func(random_attr, attribution) + return score diff 
--git a/icu_benchmarks/models/dl_models.py b/icu_benchmarks/models/dl_models.py index 0fb1b0d2..68b6d2ae 100644 --- a/icu_benchmarks/models/dl_models.py +++ b/icu_benchmarks/models/dl_models.py @@ -2,9 +2,33 @@ from numbers import Integral import numpy as np import torch.nn as nn +from typing import Dict from icu_benchmarks.contants import RunMode -from icu_benchmarks.models.layers import TransformerBlock, LocalBlock, TemporalBlock, PositionalEncoding -from icu_benchmarks.models.wrappers import DLPredictionWrapper +from icu_benchmarks.models.constants import InputTypes,DataTypes +from icu_benchmarks.models.layers import ( + TransformerBlock, + LocalBlock, + TemporalBlock, + PositionalEncoding, + MaybeLayerNorm, + GLU, + GRN, + TFTEmbedding, + LazyEmbedding, + VariableSelectionNetwork, + StaticCovariateEncoder, + InterpretableMultiHeadAttention, + TFTBack +) +import matplotlib.pyplot as plt +from icu_benchmarks.models.wrappers import ( + DLPredictionWrapper, + DLPredictionPytorchForecastingWrapper, +) +from torch import Tensor,cat,stack,from_numpy,zeros +from pytorch_forecasting import TemporalFusionTransformer, RecurrentNetwork, DeepAR +from pytorch_forecasting.metrics import QuantileLoss +from collections import OrderedDict @gin.configurable @@ -15,7 +39,12 @@ class RNNet(DLPredictionWrapper): def __init__(self, input_size, hidden_dim, layer_dim, num_classes, *args, **kwargs): super().__init__( - input_size=input_size, hidden_dim=hidden_dim, layer_dim=layer_dim, num_classes=num_classes, *args, **kwargs + input_size=input_size, + hidden_dim=hidden_dim, + layer_dim=layer_dim, + num_classes=num_classes, + *args, + **kwargs, ) self.hidden_dim = hidden_dim self.layer_dim = layer_dim @@ -41,7 +70,12 @@ class LSTMNet(DLPredictionWrapper): def __init__(self, input_size, hidden_dim, layer_dim, num_classes, *args, **kwargs): super().__init__( - input_size=input_size, hidden_dim=hidden_dim, layer_dim=layer_dim, num_classes=num_classes, *args, **kwargs + 
input_size=input_size, + hidden_dim=hidden_dim, + layer_dim=layer_dim, + num_classes=num_classes, + *args, + **kwargs, ) self.hidden_dim = hidden_dim self.layer_dim = layer_dim @@ -68,7 +102,12 @@ class GRUNet(DLPredictionWrapper): def __init__(self, input_size, hidden_dim, layer_dim, num_classes, *args, **kwargs): super().__init__( - input_size=input_size, hidden_dim=hidden_dim, layer_dim=layer_dim, num_classes=num_classes, *args, **kwargs + input_size=input_size, + hidden_dim=hidden_dim, + layer_dim=layer_dim, + num_classes=num_classes, + *args, + **kwargs, ) self.hidden_dim = hidden_dim self.layer_dim = layer_dim @@ -235,7 +274,17 @@ class TemporalConvNet(DLPredictionWrapper): _supported_run_modes = [RunMode.classification, RunMode.regression] - def __init__(self, input_size, num_channels, num_classes, *args, max_seq_length=0, kernel_size=2, dropout=0.0, **kwargs): + def __init__( + self, + input_size, + num_channels, + num_classes, + *args, + max_seq_length=0, + kernel_size=2, + dropout=0.0, + **kwargs, + ): super().__init__( input_size=input_size, num_channels=num_channels, @@ -280,3 +329,433 @@ def forward(self, x): o = o.permute(0, 2, 1) # Permute to channel last pred = self.logit(o) return pred + + + +@gin.configurable +class TFT(DLPredictionWrapper): + """ + Implementation of https://arxiv.org/abs/1912.09363 + from https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Forecasting/TFT + """ + + _supported_run_modes = [RunMode.classification, RunMode.regression] + + def __init__( + self, + num_classes, + encoder_length, # determines interval to use for prediction + hidden, + dropout, + n_heads, + dropout_att, + example_length, # determines interval to predict + *args, + quantiles :list = gin.REQUIRED, # quantiles to produce + # number of target variables + vars_type:dict =gin.REQUIRED, + vars: Dict[str, str] = gin.REQUIRED, + temporal_target_size:int=gin.REQUIRED, + **kwargs, + ): + self.vars_type=vars_type + self.vars=vars + 
static_categorical_inp_size=[] # number of catergories + temporal_known_categorical_inp_size=[] + temporal_observed_categorical_inp_size=[] # number of categories in each category of observed variables + static_continuous_inp_size=0 # number of static coutinous variables + temporal_known_continuous_inp_size=0 + temporal_observed_continuous_inp_size=0 + #Infering # of varaibles in each category based on the gin input + for value in self.vars_type.values(): + if value==[DataTypes.CONTINUOUS, InputTypes.OBSERVED]: + temporal_observed_continuous_inp_size+=1 + elif value[0:2]==[DataTypes.CATEGORICAL, InputTypes.OBSERVED]:#categoral variables need to define also # of categories + temporal_observed_categorical_inp_size.append(value[2]) + elif value==[DataTypes.CONTINUOUS, InputTypes.STATIC]: + static_continuous_inp_size+=1 + elif value[0:2]==[DataTypes.CATEGORICAL, InputTypes.STATIC]: + static_categorical_inp_size.append(value[2]) + elif value==[DataTypes.CONTINUOUS, InputTypes.KNOWN]: + temporal_known_continuous_inp_size+=1 + elif value[0:2]==[DataTypes.CATEGORICAL, InputTypes.KNOWN]: + temporal_known_categorical_inp_size.append(value[2]) + else: + print('incorrect datatype') + + + + + # derived variables + num_static_vars = len(static_categorical_inp_size) + static_continuous_inp_size + num_future_vars = ( + len(temporal_known_categorical_inp_size) + + temporal_known_continuous_inp_size + ) + num_historic_vars = sum( + [ + num_future_vars, + temporal_observed_continuous_inp_size, + temporal_target_size, + len(temporal_observed_categorical_inp_size), + ] + ) + + super().__init__( + num_classes=num_classes, + encoder_length=encoder_length, + hidden=hidden, + n_heads=n_heads, + dropout_att=dropout_att, + example_length=example_length, + quantiles=quantiles, + num_static_vars=num_static_vars, + num_future_vars=num_future_vars, + num_historic_vars=num_historic_vars, + *args, + **kwargs, + ) + + self.encoder_length = encoder_length # this determines from how distant past we 
want to use data from + + self.embedding = LazyEmbedding( + static_categorical_inp_size, + temporal_known_categorical_inp_size, + temporal_observed_categorical_inp_size, + static_continuous_inp_size, + temporal_known_continuous_inp_size, + temporal_observed_continuous_inp_size, + temporal_target_size, + hidden, + ) # embeddings for all variables + + self.static_encoder = StaticCovariateEncoder( + num_static_vars, hidden, dropout + ) # encoding for static variables + self.TFTpart = TFTBack( + encoder_length, + num_historic_vars, + hidden, + dropout, + num_future_vars, + n_heads, + dropout_att, + example_length, + quantiles, + ) # The main part of the TFT + self.logit = nn.Linear( + len(quantiles), num_classes + ) # Linear layer on top to output to the number of classes and allow modification by predictionwrapper + + def forward(self, x) -> Tensor: + #Prep data to be in format model expects + + tensors = [[] for _ in range(8)] + i=0 + nan_array = from_numpy(np.full_like(x[:,:, 0].cpu(), -1)).to(x[:,:, 0].device)#target is nan in the input + + for var in self.vars: + + + + if self.vars_type[var][0:2] == [DataTypes.CATEGORICAL, InputTypes.STATIC]: + tensors[0].append(x[:, :, i]) + elif self.vars_type[var] == [DataTypes.CONTINUOUS, InputTypes.STATIC]: + tensors[1].append(x[:, :, i]) + elif self.vars_type[var][0:2] == [DataTypes.CATEGORICAL, InputTypes.KNOWN]: + tensors[2].append(x[:, :, i]) + elif self.vars_type[var] == [DataTypes.CONTINUOUS, InputTypes.KNOWN]: + tensors[3].append(x[:, :, i]) + elif self.vars_type[var][0:2] == [DataTypes.CATEGORICAL, InputTypes.OBSERVED]: + tensors[4].append(x[:, :, i]) + elif self.vars_type[var] == [DataTypes.CONTINUOUS, InputTypes.OBSERVED]: + tensors[5].append(x[:, :, i]) + + i+=1 + + tensors[6].append( + nan_array#target needs to be there + ) + + + tensors = [stack(x, dim=-1) if x else zeros(0) for x in tensors] + FEAT_NAMES = ['s_cat' , 's_cont' , 'k_cat' , 'k_cont' , 'o_cat' , 'o_cont' , 'target'] + s_inp, t_known_inp, 
t_observed_inp, t_observed_tgt = self.embedding(OrderedDict(zip(FEAT_NAMES, tensors))) + # Static context + cs, ce, ch, cc = self.static_encoder(s_inp) + + ch, cc = ch.unsqueeze(0), cc.unsqueeze(0) # lstm initial states + # Temporal input + + _historical_inputs = [] + + # Check for t_observed_inp + if t_observed_inp is not None: + _historical_inputs.append(t_observed_inp[:, : self.encoder_length, :]) + + # Check for t_known_inp + if t_known_inp is not None: + _historical_inputs.append(t_known_inp[:, : self.encoder_length, :]) + # Check for t_observed_tgt + if t_observed_tgt is not None: + _historical_inputs.append(t_observed_tgt[:, : self.encoder_length, :]) + historical_inputs = cat(_historical_inputs, dim=-2) + future_inputs = Tensor() + if t_known_inp is not None: + future_inputs = t_known_inp[:, self.encoder_length :] + + o = self.TFTpart( + historical_inputs, + cs, + ch, + cc, + ce, + future_inputs.to(historical_inputs.device), + ) + pred = self.logit(o) + return pred + + + +@gin.configurable +class TFTpytorch(DLPredictionPytorchForecastingWrapper): + """ + Implementation of https://arxiv.org/abs/1912.09363 from pytorch forecasting + """ + + supported_run_modes = [RunMode.classification, RunMode.regression] + + def __init__( + self, + dataset, + hidden, + dropout, + n_heads, + dropout_att, + optimizer, + num_classes, + *args, + **kwargs, + ): + super().__init__(optimizer=optimizer, pytorch_forecasting=True, *args, **kwargs) + self.dataset = dataset + self.model = TemporalFusionTransformer.from_dataset( + dataset=dataset, + hidden_size=hidden, + dropout=dropout, + attention_head_size=n_heads, + optimizer=optimizer, + loss=QuantileLoss(), + hidden_continuous_size=hidden, + ) + self.num_classes = num_classes + self.logit = nn.Linear(7, num_classes) + + def forward( + self, + tuple_x: tuple, + ) -> Dict[str, Tensor]: + x_dict = { + "encoder_cat": tuple_x[0], + "encoder_cont": tuple_x[1], + "encoder_target": tuple_x[2], + "encoder_lengths": tuple_x[3], + 
"decoder_cat": tuple_x[4], + "decoder_cont": tuple_x[5], + "decoder_target": tuple_x[6], + "decoder_lengths": tuple_x[7], + "decoder_time_idx": tuple_x[8], + "groups": tuple_x[9], + "target_scale": tuple_x[10], + } + out = self.model(x_dict) + pred = self.logit(out["prediction"]) + return pred + + def actual_vs_predictions_plot(self, dataloader): + predictions = self.model.predict(dataloader, return_x=True) + predictions_vs_actuals = self.model.calculate_prediction_actual_by_variable(predictions.x, predictions.output) + self.model.plot_prediction_actual_by_variable(predictions_vs_actuals) + return predictions_vs_actuals + + def interpertations(self, dataloader, log_dir=".", plot=False): + raw_predictions = self.model.predict(dataloader, return_x=True, mode="raw") + interpretation = self.model.interpret_output(raw_predictions.output, reduction="mean") + if plot: + figs = self.model.plot_interpretation(interpretation) + for key, fig in figs.items(): + fig.savefig(log_dir / f"interpretation_{key}.png", bbox_inches="tight") + + self.model = self.model.to(self.device) + + return interpretation + + def predict_dependency(self, dataloader, variable, log_dir): + dependency = self.model.predict_dependency( + dataloader.dataset, + variable, + np.linspace(0, 30, 30), + show_progress_bar=True, + mode="dataframe", + ) + # plotting median and 25% and 75% percentile + agg_dependency = dependency.groupby(variable).normalized_prediction.agg( + median="median", + q25=lambda x: x.quantile(0.25), + q75=lambda x: x.quantile(0.75), + ) + ax = agg_dependency.plot(y="median") + ax.fill_between(agg_dependency.index, agg_dependency.q25, agg_dependency.q75, alpha=0.3) + plt.savefig(log_dir / "dependecy.png", bbox_inches="tight") + return dependency + + def forward_captum( + self, + encoder_cat: Tensor, + encoder_cont: Tensor, + encoder_target: Tensor, + encoder_lengths: Tensor, + decoder_cat: Tensor, + decoder_cont: Tensor, + decoder_target: Tensor, + decoder_lengths: Tensor, + 
decoder_time_idx: Tensor, + groups: Tensor, + target_scale: Tensor, + ): + tuple_x = ( + encoder_cat, + encoder_cont, + encoder_target, + encoder_lengths, + decoder_cat, + decoder_cont, + decoder_target, + decoder_lengths, + decoder_time_idx, + groups, + target_scale, + ) + + return self.forward(tuple_x) + + +@gin.configurable +class RNNpytorch(DLPredictionPytorchForecastingWrapper): + """ + Implementation of RNN from pytorch forecasting + """ + + supported_run_modes = [RunMode.classification, RunMode.regression] + + def __init__( + self, + dataset, + hidden, + dropout, + optimizer, + num_classes, + cell_type, + rnn_layers, + batch_size, + *args, + **kwargs, + ): + super().__init__(optimizer=optimizer, pytorch_forecasting=True, *args, **kwargs) + + self.model = RecurrentNetwork.from_dataset( + cell_type=cell_type, + rnn_layers=rnn_layers, + dataset=dataset, + hidden_size=hidden, + dropout=dropout, + optimizer=optimizer, + ) + self.num_classes = num_classes + self.logit = nn.Linear(1, num_classes) + + def forward( + self, + tuple_x: tuple, + ) -> Dict[str, Tensor]: + x_dict = { + "encoder_cat": tuple_x[0], + "encoder_cont": tuple_x[1], + "encoder_target": tuple_x[2], + "encoder_lengths": tuple_x[3], + "decoder_cat": tuple_x[4], + "decoder_cont": tuple_x[5], + "decoder_target": tuple_x[6], + "decoder_lengths": tuple_x[7], + "decoder_time_idx": tuple_x[8], + "groups": tuple_x[9], + "target_scale": tuple_x[10], + } + if self.num_classes == 1: + x_dict["encoder_cont"][:, :, -1] = 0.0 + + out = self.model(x_dict) + + pred = self.logit(out["prediction"]) + + return pred + + +@gin.configurable +class DeepARpytorch(DLPredictionPytorchForecastingWrapper): + """ + Implementation of RNN from pytorch forecasting + """ + + supported_run_modes = [RunMode.classification, RunMode.regression] + + def __init__( + self, + dataset, + hidden, + dropout, + optimizer, + num_classes, + cell_type, + rnn_layers, + batch_size, + *args, + **kwargs, + ): + super().__init__(optimizer=optimizer, 
pytorch_forecasting=True, *args, **kwargs) + + self.model = DeepAR.from_dataset( + cell_type=cell_type, + rnn_layers=rnn_layers, + dataset=dataset, + hidden_size=hidden, + dropout=dropout, + optimizer=optimizer, + ) + self.num_classes = num_classes + self.logit = nn.Linear(4, num_classes) + + def forward( + self, + tuple_x: tuple, + ) -> Dict[str, Tensor]: + x_dict = { + "encoder_cat": tuple_x[0], + "encoder_cont": tuple_x[1], + "encoder_target": tuple_x[2], + "encoder_lengths": tuple_x[3], + "decoder_cat": tuple_x[4], + "decoder_cont": tuple_x[5], + "decoder_target": tuple_x[6], + "decoder_lengths": tuple_x[7], + "decoder_time_idx": tuple_x[8], + "groups": tuple_x[9], + "target_scale": tuple_x[10], + } + if self.num_classes == 1: + x_dict["encoder_cont"][:, :, -1] = 0.0 + out = self.model(x_dict) + + pred = self.logit(out["prediction"]) + + return pred diff --git a/icu_benchmarks/models/layers.py b/icu_benchmarks/models/layers.py index c08623bd..0ea1bb02 100644 --- a/icu_benchmarks/models/layers.py +++ b/icu_benchmarks/models/layers.py @@ -4,7 +4,11 @@ import torch.nn as nn import torch.nn.functional as F from torch.nn.utils import weight_norm - +from torch import Tensor +from torch.nn.parameter import UninitializedParameter +from typing import Dict, Tuple, Optional +from torch.nn import LayerNorm +from collections import OrderedDict @gin.configurable("masking") def parallel_recomb(q_t, kv_t, att_type="all", local_context=3, bin_size=None): @@ -314,3 +318,525 @@ def forward(self, x): out = self.net(x) res = x if self.downsample is None else self.downsample(x) return self.relu(out + res) + + +class MaybeLayerNorm(nn.Module): + """ + Implements layer normalization or identity function depending on output_size + """ + + def __init__(self, output_size, hidden, eps): + super().__init__() + if output_size and output_size == 1: + self.ln = nn.Identity() + else: + self.ln = LayerNorm(output_size if output_size else hidden, eps=eps) + + def forward(self, x): + return 
self.ln(x) + + +class GLU(nn.Module): + """ + Gated Linear Unit consists of a linear layer followed by a GLU where input is split in half along dim to form a and b + GLU(a,b)=a ⊗ σ(b)where σ is signmoid activation and ⊗ is element-wise product + """ + + def __init__(self, hidden, output_size): + super().__init__() + self.lin = nn.Linear(hidden, output_size * 2) + + def forward(self, x: Tensor) -> Tensor: + x = self.lin(x) + x = F.glu(x) + return x + + +class GRN(nn.Module): + """ + Gated Residual Network consists of a maybe normalization layer -->linear --> ELU -->linear-->GLU + in addition to a residual connection + """ + + def __init__( + self, + input_size, + hidden, + output_size=None, + context_hidden=None, + dropout=0.0, + ): + super().__init__() + + self.layer_norm = MaybeLayerNorm(output_size, hidden, eps=1e-3) + + self.lin_a = nn.Linear(input_size, hidden) + + if context_hidden is not None: + self.lin_c = nn.Linear(context_hidden, hidden, bias=False) + else: + self.lin_c = nn.Identity() + self.lin_i = nn.Linear(hidden, hidden) + self.glu = GLU(hidden, output_size if output_size else hidden) + self.dropout = nn.Dropout(dropout) + self.out_proj = nn.Linear(input_size, output_size) if output_size else None + + def forward(self, a: Tensor, c: Optional[Tensor] = None): + x = self.lin_a(a) + + if c is not None: + x = x + self.lin_c(c).unsqueeze(1) + x = F.elu(x) + x = self.lin_i(x) + x = self.dropout(x) + x = self.glu(x) + y = a if self.out_proj is None else self.out_proj(a) + x = x + y + + return self.layer_norm(x) + + +# @torch.jit.script #Currently broken with autocast +def fused_pointwise_linear_v1(x, a, b): + out = torch.mul(x.unsqueeze(-1), a) + out = out + b + return out + + +# @torch.jit.script +def fused_pointwise_linear_v2(x, a, b): + out = x.unsqueeze(3) * a + out = out + b + return out + + +class TFTEmbedding(nn.Module): + def __init__( + self, + static_categorical_inp_size, + temporal_known_categorical_inp_size, + 
temporal_observed_categorical_inp_size, + static_continuous_inp_size, + temporal_known_continuous_inp_size, + temporal_observed_continuous_inp_size, + temporal_target_size, + hidden, + initialize_cont_params=True, + ): + # initialize_cont_params=False prevents form initializing parameters inside this class + # so they can be lazily initialized in LazyEmbedding module + super().__init__() + # these are basically number of varaibales that falls under each category + self.s_cat_inp_size = static_categorical_inp_size + self.t_cat_k_inp_size = temporal_known_categorical_inp_size + self.t_cat_o_inp_size = temporal_observed_categorical_inp_size + self.s_cont_inp_size = static_continuous_inp_size + self.t_cont_k_inp_size = temporal_known_continuous_inp_size + self.t_cont_o_inp_size = temporal_observed_continuous_inp_size + self.t_tgt_size = temporal_target_size + + self.hidden = hidden + + # There are 7 types of input: + # 1. Static categorical + # 2. Static continuous + # 3. Temporal known a priori categorical + # 4. Temporal known a priori continuous + # 5. Temporal observed categorical + # 6. Temporal observed continuous + # 7. 
Temporal observed targets (time series obseved so far) + self.s_cat_embed = ( + nn.ModuleList([nn.Embedding(n, self.hidden) for n in self.s_cat_inp_size]) if self.s_cat_inp_size else None + ) + self.t_cat_k_embed = ( + nn.ModuleList([nn.Embedding(n, self.hidden) for n in self.t_cat_k_inp_size]) if self.t_cat_k_inp_size else None + ) + self.t_cat_o_embed = ( + nn.ModuleList([nn.Embedding(n, self.hidden) for n in self.t_cat_o_inp_size]) if self.t_cat_o_inp_size else None + ) + + if initialize_cont_params: + self.s_cont_embedding_vectors = ( + nn.Parameter(torch.Tensor(self.s_cont_inp_size, self.hidden)) if self.s_cont_inp_size else None + ) + self.t_cont_k_embedding_vectors = ( + nn.Parameter(torch.Tensor(self.t_cont_k_inp_size, self.hidden)) if self.t_cont_k_inp_size else None + ) + self.t_cont_o_embedding_vectors = ( + nn.Parameter(torch.Tensor(self.t_cont_o_inp_size, self.hidden)) if self.t_cont_o_inp_size else None + ) + self.t_tgt_embedding_vectors = nn.Parameter(torch.Tensor(self.t_tgt_size, self.hidden)) + self.s_cont_embedding_bias = ( + nn.Parameter(torch.zeros(self.s_cont_inp_size, self.hidden)) if self.s_cont_inp_size else None + ) + self.t_cont_k_embedding_bias = ( + nn.Parameter(torch.zeros(self.t_cont_k_inp_size, self.hidden)) if self.t_cont_k_inp_size else None + ) + self.t_cont_o_embedding_bias = ( + nn.Parameter(torch.zeros(self.t_cont_o_inp_size, self.hidden)) if self.t_cont_o_inp_size else None + ) + self.t_tgt_embedding_bias = nn.Parameter(torch.zeros(self.t_tgt_size, self.hidden)) + + self.reset_parameters() + + def reset_parameters(self): + """' + embeddings are initilitized using xavier's method and biases are initlitized with zeros + """ + if self.s_cont_embedding_vectors is not None: + torch.nn.init.xavier_normal_(self.s_cont_embedding_vectors) + torch.nn.init.zeros_(self.s_cont_embedding_bias) + if self.t_cont_k_embedding_vectors is not None: + torch.nn.init.xavier_normal_(self.t_cont_k_embedding_vectors) + 
torch.nn.init.zeros_(self.t_cont_k_embedding_bias) + if self.t_cont_o_embedding_vectors is not None: + torch.nn.init.xavier_normal_(self.t_cont_o_embedding_vectors) + torch.nn.init.zeros_(self.t_cont_o_embedding_bias) + if self.t_tgt_embedding_vectors is not None: + torch.nn.init.xavier_normal_(self.t_tgt_embedding_vectors) + torch.nn.init.zeros_(self.t_tgt_embedding_bias) + if self.s_cat_embed is not None: + for module in self.s_cat_embed: + module.reset_parameters() + if self.t_cat_k_embed is not None: + for module in self.t_cat_k_embed: + module.reset_parameters() + if self.t_cat_o_embed is not None: + for module in self.t_cat_o_embed: + module.reset_parameters() + + def _apply_embedding( + self, + cat: Optional[Tensor], + cont: Optional[Tensor], + cat_emb: Optional[nn.ModuleList], + cont_emb: Tensor, + cont_bias: Tensor, + ) -> Tuple[Optional[Tensor], Optional[Tensor]]: + """ print("Input cat:") + print(cat) + print("Shape:", cat.shape) + print("Contains NaNs:", torch.isnan(cat).any()) + if cat is not None and cat.size(0) > 0: + e_cat = [] + for i, embed in enumerate(cat_emb): + indices = cat[..., i].int() + embedded_values = embed(indices) + print(f"Output of embedding layer {i}:") + print(embedded_values) + print("Contains NaNs:", torch.isnan(embedded_values).any()) + e_cat.append(embedded_values) + e_cat = torch.stack(e_cat, dim=-2) + print(1,e_cat) """ + e_cat = ( + torch.stack([embed(cat[..., i].int()) for i, embed in enumerate(cat_emb)], dim=-2) + if (cat is not None) and (cat.size()[0] > 0) + else None + ) + #print(2,e_cat) + + if (cont is not None) and (cont.size()[0] > 0): + + + e_cont = torch.mul(cont.unsqueeze(-1), cont_emb) + e_cont = e_cont + cont_bias + + else: + e_cont = None + if e_cat is not None and e_cont is not None: + return torch.cat([e_cat, e_cont], dim=-2) + elif e_cat is not None: + return e_cat + elif e_cont is not None: + return e_cont + else: + return None + + def forward(self, x: Dict[str, Tensor]): + # temporal/static 
categorical/continuous known/observed input + s_cat_inp = x.get("s_cat", None) + s_cont_inp = x.get("s_cont", None) + t_cat_k_inp = x.get("k_cat", None) + t_cont_k_inp = x.get("k_cont", None) + t_cat_o_inp = x.get("o_cat", None) + t_cont_o_inp = x.get("o_cont", None) + t_tgt_obs = x.get("target") # Has to be present + # Static inputs are expected to be equal for all timesteps + # For memory efficiency there is no assert statement + + + s_cat_inp = s_cat_inp[:, 0, :] if s_cat_inp is not None else None + s_cont_inp = s_cont_inp[:, 0, :] if s_cont_inp is not None else None + s_inp = self._apply_embedding( + s_cat_inp, s_cont_inp, self.s_cat_embed, self.s_cont_embedding_vectors, self.s_cont_embedding_bias + ) + + + t_known_inp = self._apply_embedding( + t_cat_k_inp, t_cont_k_inp, self.t_cat_k_embed, self.t_cont_k_embedding_vectors, self.t_cont_k_embedding_bias + ) + + t_observed_inp = self._apply_embedding( + t_cat_o_inp, t_cont_o_inp, self.t_cat_o_embed, self.t_cont_o_embedding_vectors, self.t_cont_o_embedding_bias + ) + + # Temporal observed targets + t_observed_tgt = torch.matmul(t_tgt_obs.unsqueeze(3).unsqueeze(4), self.t_tgt_embedding_vectors.unsqueeze(1)).squeeze( + 3 + ) + + t_observed_tgt = t_observed_tgt + self.t_tgt_embedding_bias + + return s_inp, t_known_inp, t_observed_inp, t_observed_tgt + + +class LazyEmbedding(nn.modules.lazy.LazyModuleMixin, TFTEmbedding): + cls_to_become = TFTEmbedding + + def __init__( + self, + static_categorical_inp_size, + temporal_known_categorical_inp_size, + temporal_observed_categorical_inp_size, + static_continuous_inp_size, + temporal_known_continuous_inp_size, + temporal_observed_continuous_inp_size, + temporal_target_size, + hidden, + ): + super().__init__( + static_categorical_inp_size, + temporal_known_categorical_inp_size, + temporal_observed_categorical_inp_size, + static_continuous_inp_size, + temporal_known_continuous_inp_size, + temporal_observed_continuous_inp_size, + temporal_target_size, + hidden, + 
initialize_cont_params=False, + ) + if static_continuous_inp_size: + self.s_cont_embedding_vectors = UninitializedParameter() + self.s_cont_embedding_bias = UninitializedParameter() + else: + self.s_cont_embedding_vectors = None + self.s_cont_embedding_bias = None + if temporal_known_continuous_inp_size: + self.t_cont_k_embedding_vectors = UninitializedParameter() + self.t_cont_k_embedding_bias = UninitializedParameter() + else: + self.t_cont_k_embedding_vectors = None + self.t_cont_k_embedding_bias = None + + if temporal_observed_continuous_inp_size: + self.t_cont_o_embedding_vectors = UninitializedParameter() + self.t_cont_o_embedding_bias = UninitializedParameter() + else: + self.t_cont_o_embedding_vectors = None + self.t_cont_o_embedding_bias = None + self.t_tgt_embedding_vectors = UninitializedParameter() + self.t_tgt_embedding_bias = UninitializedParameter() + + def initialize_parameters(self, x): + + if self.has_uninitialized_params(): + s_cont_inp = x.get("s_cont", None) + t_cont_k_inp = x.get("k_cont", None) + t_cont_o_inp = x.get("o_cont", None) + t_tgt_obs = x["target"] # Has to be present + if (s_cont_inp is not None) and (s_cont_inp.size()[0] > 0): + self.s_cont_embedding_vectors.materialize((s_cont_inp.shape[-1], self.hidden)) + self.s_cont_embedding_bias.materialize((s_cont_inp.shape[-1], self.hidden)) + + if (t_cont_k_inp is not None) and (t_cont_k_inp.size()[0] > 0): + self.t_cont_k_embedding_vectors.materialize((t_cont_k_inp.shape[-1], self.hidden)) + self.t_cont_k_embedding_bias.materialize((t_cont_k_inp.shape[-1], self.hidden)) + + if (t_cont_o_inp) is not None and (t_cont_o_inp.size()[0] > 0): + self.t_cont_o_embedding_vectors.materialize((t_cont_o_inp.shape[-1], self.hidden)) + self.t_cont_o_embedding_bias.materialize((t_cont_o_inp.shape[-1], self.hidden)) + + self.t_tgt_embedding_vectors.materialize((t_tgt_obs.shape[-1], self.hidden)) + self.t_tgt_embedding_bias.materialize((t_tgt_obs.shape[-1], self.hidden)) + + self.reset_parameters() + + 
+class VariableSelectionNetwork(nn.Module): + """ + Learns to select important netowrks consists of GRNs with one GRN for variable weights + and the others for input embedding + """ + + def __init__(self, hidden, dropout, num_inputs): + super().__init__() + self.hidden = hidden + self.joint_grn = GRN(hidden * num_inputs, hidden, output_size=num_inputs, context_hidden=hidden) + self.var_grns = nn.ModuleList([GRN(hidden, hidden, dropout=dropout) for _ in range(num_inputs)]) + + def forward(self, x: Tensor, context: Optional[Tensor] = None): + if x.numel() == 0: # Check if x is an empty tensor + batch_size = context.size(0) if context is not None else 1 + variable_ctx = torch.zeros(batch_size, 1, self.hidden, device=x.device) + sparse_weights = torch.ones(batch_size, 1, self.hidden, device=x.device) + return variable_ctx, sparse_weights + + Xi = torch.flatten(x, start_dim=-2) + grn_outputs = self.joint_grn(Xi, c=context) + sparse_weights = F.softmax(grn_outputs, dim=-1) + transformed_embed_list = [m(x[..., i, :]) for i, m in enumerate(self.var_grns)] + transformed_embed = torch.stack(transformed_embed_list, dim=-1) + variable_ctx = torch.matmul(transformed_embed, sparse_weights.unsqueeze(-1)).squeeze(-1) + + return variable_ctx, sparse_weights + + +class StaticCovariateEncoder(nn.Module): + """ + Network to produce 4 contexts vectors to enrich static variables + Vriable selection Network --> GRNs + """ + + def __init__(self, num_static_vars, hidden, dropout): + super().__init__() + self.vsn = VariableSelectionNetwork(hidden, dropout, num_static_vars) + self.context_grns = nn.ModuleList([GRN(hidden, hidden, dropout=dropout) for _ in range(4)]) + + def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + variable_ctx, sparse_weights = self.vsn(x) + + # Context vectors: + # variable selection context + # enrichment context + # state_c context + # state_h context + cs, ce, ch, cc = [m(variable_ctx) for m in self.context_grns] + + return cs, ce, ch, cc + + 
+class InterpretableMultiHeadAttention(nn.Module): + """ + Multi-head attention different as it outputs the attention_probability and it combines the attention weights instead of + concating them different from the one implemented already in YAIB + """ + + def __init__(self, n_head, hidden, dropout_att, dropout, example_length): + super().__init__() + self.n_head = n_head + assert hidden % n_head == 0 + self.d_head = hidden // n_head + self.qkv_linears = nn.Linear(hidden, (2 * n_head + 1) * self.d_head, bias=False) + self.out_proj = nn.Linear(self.d_head, hidden, bias=False) + self.dropout_att = nn.Dropout(dropout_att) + self.out_dropout = nn.Dropout(dropout) + self.scale = self.d_head**-0.5 + self.register_buffer("_mask", torch.triu(torch.full((example_length, example_length), float("-inf")), 1).unsqueeze(0)) + + def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: + bs, t, h_size = x.shape + qkv = self.qkv_linears(x) + q, k, v = qkv.split((self.n_head * self.d_head, self.n_head * self.d_head, self.d_head), dim=-1) + q = q.view(bs, t, self.n_head, self.d_head) + k = k.view(bs, t, self.n_head, self.d_head) + v = v.view(bs, t, self.d_head) + + # attn_score = torch.einsum('bind,bjnd->bnij', q, k) + attn_score = torch.matmul(q.permute((0, 2, 1, 3)), k.permute((0, 2, 3, 1))) + attn_score.mul_(self.scale) + attn_score = attn_score + self._mask + + attn_prob = F.softmax(attn_score, dim=3) + attn_prob = self.dropout_att(attn_prob) + + # attn_vec = torch.einsum('bnij,bjd->bnid', attn_prob, v) + attn_vec = torch.matmul(attn_prob, v.unsqueeze(1)) + m_attn_vec = torch.mean(attn_vec, dim=1) + out = self.out_proj(m_attn_vec) + out = self.out_dropout(out) + + return out, attn_prob + + +class TFTBack(nn.Module): + """ + Big part of TFT architecture consists of static enrichment followed by mutli-head self-attention then + position wise feed forward followed by a gate and a dense layer + GRNs-->multi-head attention-->GRNs-->GLU-->Linear-->output + """ + + def __init__( + self, + 
encoder_length, + num_historic_vars, + hidden, + dropout, + num_future_vars, + n_head, + dropout_att, + example_length, + quantiles, + ): + super().__init__() + + self.encoder_length = encoder_length + self.history_vsn = VariableSelectionNetwork(hidden, dropout, num_historic_vars) + self.history_encoder = nn.LSTM(hidden, hidden, batch_first=True) + self.future_vsn = VariableSelectionNetwork(hidden, dropout, num_future_vars) + self.future_encoder = nn.LSTM(hidden, hidden, batch_first=True) + + self.input_gate = GLU(hidden, hidden) + self.input_gate_ln = LayerNorm(hidden, eps=1e-3) + + self.enrichment_grn = GRN(hidden, hidden, context_hidden=hidden, dropout=dropout) + self.attention = InterpretableMultiHeadAttention(n_head, hidden, dropout_att, dropout, example_length) + self.attention_gate = GLU(hidden, hidden) + self.attention_ln = LayerNorm(hidden, eps=1e-3) + + self.positionwise_grn = GRN(hidden, hidden, dropout=dropout) + + self.decoder_gate = GLU(hidden, hidden) + self.decoder_ln = LayerNorm(hidden, eps=1e-3) + + self.quantile_proj = nn.Linear(hidden, len(quantiles)) + + def forward(self, historical_inputs, cs, ch, cc, ce, future_inputs): + historical_features, _ = self.history_vsn(historical_inputs, cs) + history, state = self.history_encoder(historical_features, (ch, cc)) + future_features, _ = self.future_vsn(future_inputs, cs) + + future, _ = self.future_encoder(future_features, state) + + # skip connection + input_embedding = torch.cat([historical_features, future_features], dim=1) + temporal_features = torch.cat([history, future], dim=1) + temporal_features = self.input_gate(temporal_features) + temporal_features = temporal_features + input_embedding + temporal_features = self.input_gate_ln(temporal_features) + + # Static enrichment + enriched = self.enrichment_grn(temporal_features, c=ce) + + # Temporal self attention + x, attn_prob = self.attention(enriched) + + # Don't compute historical quantiles + x = x[:, self.encoder_length:, :] + temporal_features 
= temporal_features[:, self.encoder_length:, :] + enriched = enriched[:, self.encoder_length:, :] + + x = self.attention_gate(x) + x = x + enriched + x = self.attention_ln(x) + + # Position-wise feed-forward + x = self.positionwise_grn(x) + + # Final skip connection + x = self.decoder_gate(x) + x = x + temporal_features + x = self.decoder_ln(x) + + out = self.quantile_proj(x) + + return out diff --git a/icu_benchmarks/models/metrics.py b/icu_benchmarks/models/metrics.py new file mode 100644 index 00000000..738c5523 --- /dev/null +++ b/icu_benchmarks/models/metrics.py @@ -0,0 +1,97 @@ +import torch +from typing import Callable +import numpy as np +from ignite.metrics import EpochMetric +from sklearn.metrics import balanced_accuracy_score, mean_absolute_error +from sklearn.calibration import calibration_curve +from scipy.spatial.distance import jensenshannon + + +"""" +This file contains metrics that are not available in ignite.metrics. Specifically, it adds transformation capabilities to some +metrics. 
+""" + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def balanced_accuracy_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor) -> float: + y_true = y_targets.numpy() + y_pred = np.argmax(y_preds.numpy(), axis=-1) + return balanced_accuracy_score(y_true, y_pred) + + +def ece_curve_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor) -> float: + y_true = y_targets.numpy() + y_pred = y_preds.numpy() + return calibration_curve(y_true, y_pred, n_bins=10) + + +def mae_with_invert_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor, invert_fn=Callable) -> float: + y_true = invert_fn(y_targets.numpy().reshape(-1, 1))[:, 0] + y_pred = invert_fn(y_preds.numpy().reshape(-1, 1))[:, 0] + return mean_absolute_error(y_true, y_pred) + + +def JSD_fn(y_preds: torch.Tensor, y_targets: torch.Tensor): + return jensenshannon(abs(y_preds).flatten(), abs(y_targets).flatten()) ** 2 + + +class BalancedAccuracy(EpochMetric): + def __init__(self, output_transform: Callable = lambda x: x, check_compute_fn: bool = False) -> None: + super(BalancedAccuracy, self).__init__( + balanced_accuracy_compute_fn, + output_transform=output_transform, + check_compute_fn=check_compute_fn, + ) + + +class CalibrationCurve(EpochMetric): + def __init__(self, output_transform: Callable = lambda x: x, check_compute_fn: bool = False) -> None: + super(CalibrationCurve, self).__init__( + ece_curve_compute_fn, + output_transform=output_transform, + check_compute_fn=check_compute_fn, + ) + + +class MAE(EpochMetric): + def __init__( + self, + output_transform: Callable 
= lambda x: x, + check_compute_fn: bool = False, + invert_transform: Callable = lambda x: x, + ) -> None: + super(MAE, self).__init__( + lambda x, y: mae_with_invert_compute_fn(x, y, invert_transform), + output_transform=output_transform, + check_compute_fn=check_compute_fn, + ) + + +class JSD(EpochMetric): + def __init__( + self, + output_transform: Callable = lambda x: x, + check_compute_fn: bool = False, + ) -> None: + super(JSD, self).__init__( + lambda x, y: JSD_fn(x, y), + output_transform=output_transform, + check_compute_fn=check_compute_fn, + ) diff --git a/icu_benchmarks/models/train.py b/icu_benchmarks/models/train.py index db7aabda..86677269 100644 --- a/icu_benchmarks/models/train.py +++ b/icu_benchmarks/models/train.py @@ -2,18 +2,32 @@ import gin import torch import logging +import json import pandas as pd from joblib import load from torch.optim import Adam from torch.utils.data import DataLoader from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar, LearningRateMonitor +from pytorch_lightning.callbacks import ( + EarlyStopping, + ModelCheckpoint, + TQDMProgressBar, + LearningRateMonitor, +) from pathlib import Path -from icu_benchmarks.data.loader import PredictionDataset, ImputationDataset +from icu_benchmarks.data.loader import ( + PredictionDataset, + ImputationDataset, + PredictionDatasetpytorch, +) + + from icu_benchmarks.models.utils import save_config_file, JSONMetricsLogger from icu_benchmarks.contants import RunMode from icu_benchmarks.data.constants import DataSplit as Split +from captum.attr import IntegratedGradients, Saliency, FeatureAblation, Lime + cpu_core_count = len(os.sched_getaffinity(0)) if hasattr(os, "sched_getaffinity") else os.cpu_count() @@ -41,15 +55,25 @@ def train_common( epochs=1000, patience=20, min_delta=1e-5, + gradient_clip_val=0, test_on: str = Split.test, 
dataset_names=None, use_wandb: bool = False, cpu: bool = False, - verbose=False, + verbose=True, ram_cache=False, pl_model=True, train_only=False, - num_workers: int = min(cpu_core_count, torch.cuda.device_count() * 8 * int(torch.cuda.is_available()), 32), + num_workers: int = min( + cpu_core_count, + torch.cuda.device_count() * 8 * int(torch.cuda.is_available()), + 32, + ), + explain: bool = False, + pytorch_forecasting: bool = False, + XAI_metric: bool = False, + random_labels: bool = False, + random_model_dir: str = None, ): """Common wrapper to train all benchmarked models. @@ -77,9 +101,12 @@ def train_common( pl_model: Loading a pytorch lightning model. num_workers: Number of workers to use for data loading. """ - logging.info(f"Training model: {model.__name__}.") - dataset_class = ImputationDataset if mode == RunMode.imputation else PredictionDataset + + # choose dataset_class based on the model + dataset_class = ImputationDataset if mode == RunMode.imputation else (PredictionDatasetpytorch if pytorch_forecasting else PredictionDataset) + + logging.info(f"Logging to directory: {log_dir}.") save_config_file(log_dir) # We save the operative config before and also after training @@ -88,37 +115,34 @@ def train_common( val_dataset = dataset_class(data, split=Split.val, ram_cache=ram_cache, name=dataset_names["val"]) train_dataset, val_dataset = assure_minimum_length(train_dataset), assure_minimum_length(val_dataset) batch_size = min(batch_size, len(train_dataset), len(val_dataset)) + test_dataset = dataset_class(data, split=test_on, name=dataset_names["test"]) + test_dataset = assure_minimum_length(test_dataset) if not eval_only: logging.info( f"Training on {train_dataset.name} with {len(train_dataset)} samples and validating on {val_dataset.name} with" f" {len(val_dataset)} samples." 
) + batch_size = int(batch_size) logging.info(f"Using {num_workers} workers for data loading.") - - train_loader = DataLoader( - train_dataset, + train_loader, val_loader, test_loader, model = prepare_data_loaders( + model=model, + train_dataset=train_dataset, + val_dataset=val_dataset, + test_dataset=test_dataset, batch_size=batch_size, - shuffle=True, num_workers=num_workers, pin_memory=True, drop_last=True, + pytorch_forecasting=pytorch_forecasting, + load_weights=load_weights, + source_dir=source_dir, + pl_model=pl_model, + optimizer=optimizer, + epochs=epochs, + mode=mode, + random_labels=random_labels, ) - val_loader = DataLoader( - val_dataset, - batch_size=batch_size, - shuffle=False, - num_workers=num_workers, - pin_memory=True, - drop_last=True, - ) - - data_shape = next(iter(train_loader))[0].shape - - if load_weights: - model = load_model(model, source_dir, pl_model=pl_model) - else: - model = model(optimizer=optimizer, input_size=data_shape, epochs=epochs, run_mode=mode) model.set_weight(weight, train_dataset) model.set_trained_columns(train_dataset.get_feature_names()) @@ -126,7 +150,13 @@ def train_common( if use_wandb: loggers.append(WandbLogger(save_dir=log_dir)) callbacks = [ - EarlyStopping(monitor="val/loss", min_delta=min_delta, patience=patience, strict=False, verbose=verbose), + EarlyStopping( + monitor="val/loss", + min_delta=min_delta, + patience=patience, + strict=False, + verbose=verbose, + ), ModelCheckpoint(log_dir, filename="model", save_top_k=1, save_last=True), LearningRateMonitor(logging_interval="step"), ] @@ -147,10 +177,12 @@ def train_common( logger=loggers, num_sanity_val_steps=-1, log_every_n_steps=5, + gradient_clip_val=gradient_clip_val ) if not eval_only: if model.requires_backprop: logging.info("Training DL model.") + trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader) logging.info("Training complete.") else: @@ -162,21 +194,144 @@ def train_common( logging.info("Finished training full model.") 
save_config_file(log_dir) return 0 - test_dataset = dataset_class(data, split=test_on, name=dataset_names["test"]) - test_dataset = assure_minimum_length(test_dataset) - logging.info(f"Testing on {test_dataset.name} with {len(test_dataset)} samples.") - test_loader = ( - DataLoader( - test_dataset, - batch_size=min(batch_size * 4, len(test_dataset)), - shuffle=False, - num_workers=num_workers, - pin_memory=True, - drop_last=True, + + if explain: + path = Path(random_model_dir) + random_model = load_model( + model, + source_dir=path, + pl_model=pl_model, + train_dataset=train_dataset, + optimizer=optimizer, ) - if model.requires_backprop - else DataLoader([test_dataset.to_tensor()], batch_size=1) - ) + + XAI_dict = {} # dictrionary to log attributions metrics + + # choose which methods to get attributions + methods = { + "G": Saliency, + "L": Lime, + "IG": IntegratedGradients, + "FA": FeatureAblation, + "R": "Random", + "Att": "Attention", + } + for key, item in methods.items(): + # If conditions needed here as different explantations require different inputs + if key == "IG": + ( + all_attrs, + features_attrs, + timestep_attrs, + ts_v_score, + ts_score, + v_score, + r_score, + st_i_score, + st_o_score, + ) = model.explantation( + dataloader=test_loader, + method=item, + log_dir=log_dir, + plot=True, + n_steps=50, + XAI_metric=XAI_metric, + random_model=random_model, + ) + elif key == "L" or key == "FA": + """for Lime and feature ablation we need to define + what is a feature we define each variable + per timestep as a feature""" + shapes = [ + torch.Size([64, 24, 0]), + torch.Size([64, 24, 53]), + torch.Size([64, 24]), + torch.Size([64]), + torch.Size([64, 1, 0]), + torch.Size([64, 1, 53]), + torch.Size([64, 1]), + torch.Size([64]), + torch.Size([64, 1]), + torch.Size([64, 1]), + torch.Size([64, 2]), + ] + + # Create a feature mask for the second tensor that includes both features and timesteps + num_timesteps = shapes[1][1] + num_features = shapes[1][2] + 
feature_mask_second = torch.arange(num_timesteps * num_features).reshape(num_timesteps, num_features) + feature_mask_second = feature_mask_second.unsqueeze(0).repeat(shapes[1][0], 1, 1) + # Create a tuple of masks + feature_masks = tuple( + [create_default_mask(shape) if i != 1 else feature_mask_second for i, shape in enumerate(shapes)] + ) + ( + all_attrs, + features_attrs, + timestep_attrs, + ts_v_score, + ts_score, + v_score, + r_score, + st_i_score, + st_o_score, + ) = model.explantation( + dataloader=test_loader, + method=item, + log_dir=log_dir, + plot=True, + feature_mask=feature_masks, + return_input_shape=True, + XAI_metric=XAI_metric, + random_model=random_model, + ) + + else: + ( + all_attrs, + features_attrs, + timestep_attrs, + ts_v_score, + ts_score, + v_score, + r_score, + st_i_score, + st_o_score, + ) = model.explantation( + dataloader=test_loader, + method=item, + log_dir=log_dir, + plot=True, + XAI_metric=XAI_metric, + random_model=random_model, + ) + + if XAI_metric: + # logging metric scores + print("{} Attributions Faithfulness Timesteps ".format(key), ts_score) + XAI_dict["{}_Faith Timesteps".format(key)] = ts_score + print("{}_ROS ".format(key), st_o_score) + XAI_dict["{}_ROS".format(key)] = st_o_score + print("{}_RIS ".format(key), st_i_score) + XAI_dict["{}_RIS".format(key)] = st_i_score + + print("{} Attributions faithfulness featrues ".format(key), v_score) + XAI_dict["{}_Faith Features".format(key)] = v_score + + print( + "{}_Attributions Faithfulness Variable Per Timestep ".format(key), + ts_v_score, + ) + XAI_dict["{}_Faith Variable Per Timestep".format(key)] = ts_v_score + print("{}_Data Randomization Distance ".format(key), r_score) + XAI_dict["{}_Data Randomization Distance".format(key)] = r_score + + # Path to the JSON file in log_dir + json_file_path = f"{log_dir}/XAI_metrics.json" + + # Write the dictionary to a JSON file + with open(json_file_path, "w") as json_file: + json.dump(XAI_dict, json_file) model.set_weight("balanced", 
train_dataset) test_loss = trainer.test(model, dataloaders=test_loader, verbose=verbose)[0]["test/loss"] @@ -184,7 +339,146 @@ def train_common( return test_loss -def load_model(model, source_dir, pl_model=True): +def prepare_data_loaders( + model, + train_dataset, + val_dataset, + test_dataset, + batch_size, + num_workers, + pin_memory, + drop_last=True, + shuffle_train=True, + pytorch_forecasting=False, + load_weights=False, + source_dir=None, + pl_model=None, + optimizer=None, + epochs=None, + mode=None, + random_labels=False, +): + """ + Prepares PyTorch data loaders based on the provided datasets and configuration. + + Args: + train_dataset: Training dataset. + val_dataset: Validation dataset. + test_dataset: Test dataset. + batch_size: Batch size for data loaders. + num_workers: Number of worker processes for data loading. + pin_memory: Whether to use pin_memory for faster data transfer to GPU. + drop_last: Whether to drop the last incomplete batch. + shuffle_train: Whether to shuffle the training data loader. + load_weights: Whether to load weights from a pre-trained model. + source_dir: Directory to load weights from. + pl_model: PyTorch Lightning model (used for loading weights). + optimizer: Optimizer for the model. + epochs: Number of training epochs. + mode: Run mode for the model. + random_labels: Whether to randomize labels for the datasets. + + Returns: + tuple: Tuple containing train_loader, val_loader, and test_loader. 
+ """ + if pytorch_forecasting: + train_loader = train_dataset.to_dataloader( + train=True, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + drop_last=drop_last, + shuffle=shuffle_train, + ) + val_loader = val_dataset.to_dataloader( + train=False, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + drop_last=drop_last, + ) + test_loader = test_dataset.to_dataloader( + train=False, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + drop_last=drop_last, + shuffle=False, + ) + if load_weights: + model = load_model( + model, + source_dir, + pl_model=pl_model, + train_dataset=train_dataset, + optimizer=optimizer, + ) + else: + model = model( + train_dataset, + optimizer=optimizer, + epochs=epochs, + run_mode=mode, + batch_size=batch_size, + ) + if random_labels: + train_dataset.randomize_labels(num_classes=model.num_classes) + val_dataset.randomize_labels(num_classes=model.num_classes) + test_dataset.randomize_labels(num_classes=model.num_classes) + + else: + train_loader = DataLoader( + train_dataset, + batch_size=batch_size, + shuffle=shuffle_train, + num_workers=num_workers, + pin_memory=pin_memory, + drop_last=drop_last, + ) + val_loader = DataLoader( + val_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=pin_memory, + drop_last=drop_last, + ) + + test_loader = ( + DataLoader( + test_dataset, + batch_size=min(batch_size * 4, len(test_dataset)), + shuffle=False, + num_workers=num_workers, + pin_memory=pin_memory, + drop_last=drop_last, + ) + if model.requires_backprop + else DataLoader([test_dataset.to_tensor()], batch_size=1) + ) + + data_shape = next(iter(train_loader))[0].shape + if load_weights: + model = load_model(model, source_dir, pl_model=pl_model) + else: + model = model(optimizer=optimizer, input_size=data_shape, epochs=epochs, run_mode=mode) + + return train_loader, val_loader, test_loader, model + + +def 
create_default_mask(shape): + if len(shape) == 3: + return torch.zeros(shape[0], shape[1], max(1, shape[2]), dtype=torch.int32) + elif len(shape) == 2: + return torch.zeros(shape[0], max(1, shape[1]), dtype=torch.int32) + else: # len(shape) == 1 + return torch.zeros(shape[0], dtype=torch.int32) + + +def load_model(model, source_dir, pl_model=True, train_dataset=None, optimizer=None): + if source_dir is None: + return None + if source_dir.exists(): if model.requires_backprop: if (source_dir / "model.ckpt").exists(): @@ -196,7 +490,11 @@ def load_model(model, source_dir, pl_model=True): else: return Exception(f"No weights to load at path : {source_dir}") if pl_model: - model = model.load_from_checkpoint(model_path) + if train_dataset is not None: + model = model.load_from_checkpoint(model_path, dataset=train_dataset, optimizer=optimizer) + + else: + model = model.load_from_checkpoint(model_path) else: checkpoint = torch.load(model_path) model.load_from_checkpoint(checkpoint) diff --git a/icu_benchmarks/models/utils.py b/icu_benchmarks/models/utils.py index 6c944ae7..34a5de0b 100644 --- a/icu_benchmarks/models/utils.py +++ b/icu_benchmarks/models/utils.py @@ -8,13 +8,14 @@ import logging import numpy as np import torch - +from quantus.functions.similarity_func import correlation_spearman, cosine from pytorch_lightning.loggers.logger import Logger from pytorch_lightning.utilities import rank_zero_only from torch.nn import Module from torch.optim import Optimizer, Adam, SGD, RAdam from typing import Optional, Union from torch.optim.lr_scheduler import _LRScheduler, CosineAnnealingLR, MultiStepLR, ExponentialLR +import captum def save_config_file(log_dir): @@ -188,3 +189,379 @@ def version(self): @rank_zero_only def log_hyperparams(self, params): pass + + +def Faithfulness_Correlation( + model, + x, + attribution, + similarity_func=None, + nr_runs=100, + pertrub=None, + subset_size=3, + feature=False, + time_step=False, + feature_timestep=False, +): + """ + Calculates 
faithfulness scores for captum attributions + + Args: + - x:Batch input + -attribution: attribution generated by captum, + - similarity_func:function to determine similarity between sum of attributions and difference in prediction + - nr_runs: How many times to repeat the experiment, + - pertrub: What change to do to the input, + - subset_size: The size of the subset of featrues to alter , + - feature: Determines if to calcualte faithfulness of feature attributions, + - time_step: Determines if to calcualte faithfulness of timesteps attributions, + - feature_timestep: Determines if to calcualte faithfulness of featrues per timesteps attributions, + Returns: + score: similarity score between sum of attributions and difference in prediction averaged over nr_runs + + Implementation of faithfulness correlation by Bhatt et al., 2020. + + The Faithfulness Correlation metric intend to capture an explanation's relative faithfulness + (or 'fidelity') with respect to the model behaviour. + + Faithfulness correlation scores shows to what extent the predicted logits of each modified test point and + the average explanation attribution for only the subset of features are (linearly) correlated, taking the + average over multiple runs and test samples. The metric returns one float per input-attribution pair that + ranges between -1 and 1, where higher scores are better. + + For each test sample, |S| features are randomly selected and replace them with baseline values (zero baseline + or average of set). Thereafter, Pearson’s correlation coefficient between the predicted logits of each modified + test point and the average explanation attribution for only the subset of features is calculated. Results is + average over multiple runs and several test samples. + This code is adapted from the quantus libray to suit our use case + + References: + 1) Umang Bhatt et al.: "Evaluating and aggregating feature-based model + explanations." IJCAI (2020): 3016-3022. + 2)Hedström, Anna, et al. 
"Quantus: An explainable ai toolkit for + responsible evaluation of neural network explanations and beyond." + Journal of Machine Learning Research 24.34 (2023): 1-11. + """ + + attribution = torch.tensor(attribution).to(model.device) + + # Other initializations + if similarity_func is None: + similarity_func = correlation_spearman + if pertrub is None: + pertrub = "baseline" + similarities = [] + + # Assuming this is a method to prepare your data + + y_pred = model(model.prep_data(x)).detach() # Keep on GPU + pred_deltas = [] + att_sums = [] + + for i_ix in range(nr_runs): + if time_step: + timesteps_idx = np.random.choice(24, subset_size, replace=False) + patient_idx = np.random.choice(64, 1, replace=False) + a_ix = [patient_idx, timesteps_idx] + + elif feature: + feature_idx = np.random.choice(53, subset_size, replace=False) + patient_idx = np.random.choice(64, 1, replace=False) + a_ix = [patient_idx, feature_idx] + elif feature_timestep: + timesteps_idx = np.random.choice(24, subset_size[0], replace=False) + feature_idx = np.random.choice(53, subset_size[1], replace=False) + patient_idx = np.random.choice(64, 1, replace=False) + a_ix = [patient_idx, timesteps_idx, feature_idx] + + # Apply perturbation + if pertrub == "Noise": + x = model.add_noise(x, a_ix, time_step, feature, feature_timestep) + elif pertrub == "baseline": + x = model.apply_baseline(x, a_ix, time_step, feature, feature_timestep) + + # Predict on perturbed input and calculate deltas + y_pred_perturb = (model(model.prep_data(x))).detach() # Keep on GPU + + if time_step: + if attribution.size() == torch.Size([24]): + att_sums.append((attribution[timesteps_idx]).sum()) + else: + att_sums.append((attribution[patient_idx, :, :][:, timesteps_idx, :]).sum()) + elif feature: + if len(attribution) == 53: + att_sums.append((attribution[feature_idx]).sum()) + else: + att_sums.append((attribution[patient_idx, :, :][:, :, feature_idx]).sum()) + elif feature_timestep: + 
att_sums.append((attribution[patient_idx, :, :][:, timesteps_idx, :][:, :, feature_idx]).sum()) + + pred_deltas.append((y_pred - y_pred_perturb)[patient_idx].item()) + # Convert to CPU for numpy operations + + pred_deltas_cpu = torch.tensor(pred_deltas).cpu().numpy() + att_sums_cpu = torch.tensor(att_sums).cpu().numpy() + + similarities.append(similarity_func(pred_deltas_cpu, att_sums_cpu)) + + score = np.nanmean(similarities) + return score + + +def Data_Randomization( + model, + x, + attribution, + explain_method, + random_model, + similarity_func=cosine, + dataloader=None, + method_name="", + **kwargs, +): + """ + + Args: + - x:Batch input + -attribution: attribution + - explain_method:function to generate explantations + - random_model: Reference to model trained on random labels + - similarity_func: Function to measure similiarity + - dataloader:In case of using Attention as the explain method need to pass the dataloader instead of the batch , + - method_name: Name of the explantation + + Returns: + score: similarity score between attributions of model trained on random data and model trained on real data + + Implementation of the Random Logit Metric by Sixt et al., 2020. + + The Random Logit Metric computes the distance between the original explanation and a reference explanation of + a randomly chosen non-target class. + This code is adapted from the quantus libray to suit our use case + + References: + 1) Leon Sixt et al.: "When Explanations Lie: Why Many Modified BP + Attributions Fail." ICML (2020): 9046-9057. + 2)Hedström, Anna, et al. "Quantus: An explainable ai + toolkit for responsible evaluation of neural network explanations and beyond." + Journal of Machine Learning Research 24.34 (2023): 1-11. 
+ + """ + + if explain_method == "Attention": + Attention_weights = random_model.interpertations(dataloader) + attribution = attribution.cpu().numpy() + min_val = np.min(attribution) + max_val = np.max(attribution) + + attribution = (attribution - min_val) / (max_val - min_val) + random_attr = Attention_weights["attention"].cpu().numpy() + min_val = np.min(random_attr) + max_val = np.max(random_attr) + random_attr = (random_attr - min_val) / (max_val - min_val) + score = similarity_func(random_attr, attribution) + elif explain_method == "Random": + score = similarity_func(np.random.normal(size=[64, 24, 53]).flatten(), attribution.flatten()) + else: + data, baselines = model.prep_data_captum(x) + + explantation = explain_method(random_model.forward_captum) + # Reformat attributions. + if explain_method is not captum.attr.Saliency: + attr = explantation.attribute(data, baselines=baselines, **kwargs) + else: + attr = explantation.attribute(data, **kwargs) + + # Process and store the calculated attributions + random_attr = ( + attr[1].cpu().detach().numpy() + if method_name in ["Lime", "FeatureAblation"] + else torch.stack(attr).cpu().detach().numpy() + ) + + attribution = attribution.flatten() + min_val = np.min(attribution) + max_val = np.max(attribution) + attribution = (attribution - min_val) / (max_val - min_val) + random_attr = random_attr.flatten() + min_val = np.min(random_attr) + max_val = np.max(random_attr) + random_attr = (random_attr - min_val) / (max_val - min_val) + + score = similarity_func(random_attr, attribution) + return score + + +def Relative_Stability( + model, + x, + attribution, + explain_method, + method_name, + dataloader=None, + threshold=0.5, + **kwargs, +): + """ + Args: + - x:Batch input + -attribution: attribution + - explain_method:function to generate explantations + - method_name: Name of the explantation + - dataloader:In case of using Attention as the explain method need to pass the dataloader instead of the batch , + + + Returns: + 
RIS : relative distance between the explantation and the input + ROS: relative distance between the explantation and the output + + + References: + 1) `https://arxiv.org/pdf/2203.06877.pdf + 2)Hedström, Anna, et al. "Quantus: An explainable ai toolkit for responsible evaluation + of neural network explanations and beyond." Journal of Machine Learning Research 24.34 (2023): 1-11. + + """ + + def relative_stability_objective(x, xs, e_x, e_xs, eps_min=0.0001, input=False, device="cuda") -> torch.Tensor: + """ + Computes relative input and output stabilities maximization objective + as defined here :ref:`https://arxiv.org/pdf/2203.06877.pdf` by the authors. + + Args: + + x: Input tensor + xs: perturbed tensor. + e_x: Explanations for x. + e_xs: Explanations for xs. + eps_min:Value to avoid division by zero if needed + input:Boolean to indicate if this is an input or an output + device: the device to keep the tensors on + + Returns: + + ris_obj: Tensor + RIS maximization objective. + """ + + # Function to convert inputs to tensors if they are numpy arrays + def to_tensor(input_array): + if isinstance(input_array, np.ndarray): + return torch.tensor(input_array).to(device) + return input_array.to(device) + + # Convert all inputs to tensors and move to GPU + x, xs, e_x, e_xs = map(to_tensor, [x, xs, e_x, e_xs]) + + if input: + num_dim = x.ndim + else: + num_dim = e_x.ndim + + if num_dim == 3: + + def norm_function(arr): + return torch.norm(arr, dim=(-1, -2)) + + elif num_dim == 2: + + def norm_function(arr): + return torch.norm(arr, dim=-1) + + else: + + def norm_function(arr): + return torch.norm(arr) + + nominator = (e_x - e_xs) / (e_x + (e_x == 0) * eps_min) + nominator = norm_function(nominator) + + if input: + denominator = x - xs + denominator /= x + (x == 0) * eps_min + denominator = norm_function(denominator) + denominator += (denominator == 0) * eps_min + else: + denominator = torch.squeeze(x) - torch.squeeze(xs) + denominator = torch.norm(denominator, dim=-1) + 
denominator += (denominator == 0) * eps_min + + return nominator / denominator + + attribution = torch.tensor(attribution).to(model.device) + if explain_method == "Attention": + y_pred = model.model.predict(dataloader) + x_original = dataloader.dataset.data["reals"].clone() + + dataloader.dataset.add_noise() + x_perturb = dataloader.dataset.data["reals"].clone() + y_pred_perturb = model.model.predict(dataloader) + Attention_weights = model.interpertations(dataloader) + att_perturb = Attention_weights["attention"] + # Calculate the absolute difference + difference = torch.abs(y_pred_perturb - y_pred) + + # Find where the difference is less than or equal to a threshold + close_indices = torch.nonzero(difference <= threshold).squeeze() + RIS = relative_stability_objective( + x_original[close_indices, :, :].detach(), + x_perturb[close_indices, :, :].detach(), + attribution, + att_perturb, + input=True, + ) + + ROS = relative_stability_objective( + y_pred[close_indices], + y_pred_perturb[close_indices], + attribution, + att_perturb, + input=False, + ) + + else: + y_pred = model(model.prep_data(x)).detach() + x_original = x["encoder_cont"].detach().clone() + + with torch.no_grad(): + noise = torch.randn_like(x["encoder_cont"]) * 0.01 + x["encoder_cont"] += noise + y_pred_perturb = model(model.prep_data(x)).detach() + if explain_method == "Random": + att_perturb = np.random.normal(size=[64, 24, 53]) + att_perturb = torch.tensor(att_perturb).to(model.device) + else: + data, baselines = model.prep_data_captum(x) + + explantation = explain_method(model.forward_captum) + # Reformat attributions. 
+ if explain_method is not captum.attr.Saliency: + att_perturb = explantation.attribute(data, baselines=baselines, **kwargs) + else: + att_perturb = explantation.attribute(data, **kwargs) + + # Process and store the calculated attributions + att_perturb = ( + att_perturb[1].detach() if method_name in ["Lime", "FeatureAblation"] else torch.stack(att_perturb).detach() + ) + # Calculate the absolute difference + difference = torch.abs(y_pred_perturb - y_pred) + + # Find where the difference is less than or equal to a threshold + close_indices = torch.nonzero(difference <= threshold).squeeze() + RIS = relative_stability_objective( + x_original[close_indices, :, :].detach(), + x["encoder_cont"][close_indices, :, :].detach(), + attribution[close_indices, :, :], + att_perturb[close_indices, :, :], + input=True, + ) + ROS = relative_stability_objective( + y_pred[close_indices], + y_pred_perturb[close_indices], + attribution[close_indices, :, :], + att_perturb[close_indices, :, :], + input=False, + ) + + return np.max(RIS.cpu().numpy()).astype(np.float64), np.max(ROS.cpu().numpy()).astype(np.float64) diff --git a/icu_benchmarks/models/wrappers.py b/icu_benchmarks/models/wrappers.py index e8595a70..3ca1322a 100644 --- a/icu_benchmarks/models/wrappers.py +++ b/icu_benchmarks/models/wrappers.py @@ -1,16 +1,13 @@ import logging from abc import ABC from typing import Dict, Any, List, Optional, Union - import torchmetrics from sklearn.metrics import log_loss, mean_squared_error - import torch from torch.nn import MSELoss, CrossEntropyLoss import torch.nn as nn from torch import Tensor, FloatTensor from torch.optim import Optimizer, Adam - import inspect import gin import numpy as np @@ -19,14 +16,16 @@ from icu_benchmarks.models.utils import create_optimizer, create_scheduler from joblib import dump from pytorch_lightning import LightningModule - from icu_benchmarks.models.constants import MLMetrics, DLMetrics from icu_benchmarks.contants import RunMode +from 
icu_benchmarks.models.utils import Faithfulness_Correlation, Data_Randomization, Relative_Stability +import captum +from captum._utils.models.linear_model import SkLearnLasso gin.config.external_configurable(nn.functional.nll_loss, module="torch.nn.functional") gin.config.external_configurable(nn.functional.cross_entropy, module="torch.nn.functional") gin.config.external_configurable(nn.functional.mse_loss, module="torch.nn.functional") - +gin.config.external_configurable(nn.functional.l1_loss, module="torch.nn.functional") gin.config.external_configurable(mean_squared_error, module="sklearn.metrics") gin.config.external_configurable(log_loss, module="sklearn.metrics") @@ -97,7 +96,11 @@ def check_supported_runmode(self, runmode: RunMode): class DLWrapper(BaseModule, ABC): requires_backprop = True _metrics_warning_printed = set() - _supported_run_modes = [RunMode.classification, RunMode.regression, RunMode.imputation] + _supported_run_modes = [ + RunMode.classification, + RunMode.regression, + RunMode.imputation, + ] def __init__( self, @@ -155,22 +158,24 @@ def on_train_start(self): def finalize_step(self, step_prefix=""): try: - self.log_dict( - { - f"{step_prefix}/{name}": ( - np.float32(metric.compute()) if isinstance(metric.compute(), np.float64) else metric.compute() - ) - for name, metric in self.metrics[step_prefix].items() - if "_Curve" not in name - }, - sync_dist=True, - ) + for name, metric in self.metrics[step_prefix].items(): + try: + value = np.float32(metric.compute()) if isinstance(metric.compute(), np.float64) else metric.compute() + self.log_dict({f"{step_prefix}/{name}": value}, sync_dist=True) + + except (NotComputableError, ValueError) as e: + if step_prefix not in self._metrics_warning_printed: + self._metrics_warning_printed.add(step_prefix) + logging.warning(f"Metric for {step_prefix}/{name} not computable: {e}") + for metric in self.metrics[step_prefix].values(): metric.reset() - except (NotComputableError, ValueError): + except 
(NotComputableError, ValueError) as e: if step_prefix not in self._metrics_warning_printed: self._metrics_warning_printed.add(step_prefix) logging.warning(f"Metrics for {step_prefix} not computable") + print(e) + pass def configure_optimizers(self): @@ -187,7 +192,11 @@ def configure_optimizers(self): if self.hparams.lr_scheduler is None or self.hparams.lr_scheduler == "": return optimizer scheduler = create_scheduler( - self.hparams.lr_scheduler, optimizer, self.hparams.lr_factor, self.hparams.lr_steps, self.hparams.epochs + self.hparams.lr_scheduler, + optimizer, + self.hparams.lr_factor, + self.hparams.lr_steps, + self.hparams.epochs, ) optimizers = {"optimizer": optimizer, "lr_scheduler": scheduler} logging.info(f"Using: {optimizers}") @@ -229,6 +238,9 @@ def __init__( epochs: int = 100, input_size: Tensor = None, initialization_method: str = "normal", + pytorch_forecasting: bool = False, + explain: list = [], + XAI_metric: list = [], **kwargs, ): super().__init__( @@ -248,14 +260,17 @@ def __init__( ) self.output_transform = None self.loss_weights = None + self.pytorch_forecasting = pytorch_forecasting + self.explain = explain + self.XAI_metric = XAI_metric def set_weight(self, weight, dataset): """Set the weight for the loss function.""" - if isinstance(weight, list): weight = FloatTensor(weight).to(self.device) elif weight == "balanced": weight = FloatTensor(dataset.get_balance()).to(self.device) + self.loss_weights = weight def set_metrics(self, *args): @@ -276,9 +291,11 @@ def softmax_multi_output_transform(output): # Output transform is not applied for contrib metrics, so we do our own. 
if self.run_mode == RunMode.classification: # Binary classification + if self.logit.out_features == 2: self.output_transform = softmax_binary_output_transform - metrics = DLMetrics.BINARY_CLASSIFICATION + metrics = DLMetrics.BINARY_CLASSIFICATION_TORCHMETRICS + else: # Multiclass classification self.output_transform = softmax_multi_output_transform @@ -302,9 +319,9 @@ def step_fn(self, element, step_prefix=""): element (object): step_prefix (str): Step type, by default: test, train, val. """ - + if len(element) == 2: - data, labels = element[0], element[1].to(self.device) + data, labels = element[0], (element[1]).to(self.device) if isinstance(data, list): for i in range(len(data)): data[i] = data[i].float().to(self.device) @@ -313,7 +330,11 @@ def step_fn(self, element, step_prefix=""): mask = torch.ones_like(labels).bool() elif len(element) == 3: - data, labels, mask = element[0], element[1].to(self.device), element[2].to(self.device) + data, labels, mask = ( + element[0], + element[1].to(self.device), + element[2].to(self.device), + ) if isinstance(data, list): for i in range(len(data)): data[i] = data[i].float().to(self.device) @@ -321,15 +342,18 @@ def step_fn(self, element, step_prefix=""): data = data.float().to(self.device) else: raise Exception("Loader should return either (data, label) or (data, label, mask)") + out = self(data) - + # If aux_loss is present, it is returned as a tuple if len(out) == 2 and isinstance(out, tuple): out, aux_loss = out else: aux_loss = 0 # Get prediction and target + prediction = torch.masked_select(out, mask.unsqueeze(-1)).reshape(-1, out.shape[-1]).to(self.device) + target = torch.masked_select(labels, mask).to(self.device) if prediction.shape[-1] > 1 and self.run_mode == RunMode.classification: @@ -347,7 +371,12 @@ def step_fn(self, element, step_prefix=""): if isinstance(value, torchmetrics.Metric): if key == "Binary_Fairness": feature_names = key.feature_helper(self.trainer) - value.update(transformed_output[0], 
transformed_output[1], data, feature_names) + value.update( + transformed_output[0], + transformed_output[1], + data, + feature_names, + ) else: value.update(transformed_output[0], transformed_output[1]) else: @@ -356,6 +385,510 @@ def step_fn(self, element, step_prefix=""): return loss +@gin.configurable("DLPredictionPytorchForecastingWrapper") +class DLPredictionPytorchForecastingWrapper(DLPredictionWrapper): + """Interface for Deep Learning models.""" + + _supported_run_modes = [RunMode.classification, RunMode.regression] + + def __init__( + self, + loss=CrossEntropyLoss(), + optimizer=torch.optim.Adam, + run_mode: RunMode = RunMode.classification, + input_shape=None, + lr: float = 0.002, + momentum: float = 0.9, + lr_scheduler: Optional[str] = None, + lr_factor: float = 0.99, + lr_steps: Optional[List[int]] = None, + epochs: int = 100, + input_size: Tensor = None, + initialization_method: str = "normal", + pytorch_forecasting: bool = False, + **kwargs, + ): + super().__init__( + loss=loss, + optimizer=optimizer, + run_mode=run_mode, + input_shape=input_shape, + lr=lr, + momentum=momentum, + lr_scheduler=lr_scheduler, + lr_factor=lr_factor, + lr_steps=lr_steps, + epochs=epochs, + input_size=input_size, + initialization_method=initialization_method, + kwargs=kwargs, + ) + + def step_fn(self, element, step_prefix=""): + """Perform a step in the DL prediction model training loop. + + Args: + element (object): + step_prefix (str): Step type, by default: test, train, val. 
+ """ + + dic, labels = element[0], element[1][0] + + if isinstance(labels, list): + labels = labels[-1] + + data = self.prep_data(dic) + + out = self(data) + + # If aux_loss is present, it is returned as a tuple + if len(out) == 2 and isinstance(out, tuple): + out, aux_loss = out + else: + aux_loss = 0 + # Get prediction and target + + prediction = out.to(self.device).squeeze(-1) + + target = labels.to(self.device) + + if prediction.shape[-1] > 1 and self.run_mode == RunMode.classification: + # Classification task + loss = self.loss(prediction, target.long(), weight=self.loss_weights.to(self.device)) + aux_loss + # Returns torch.long because negative log likelihood loss + elif self.run_mode == RunMode.regression: + # Regression task + + loss = self.loss(prediction[:, 0], target.float()) + aux_loss + else: + raise ValueError(f"Run mode {self.run_mode} not yet supported. Please implement it.") + transformed_output = self.output_transform((prediction, target)) + + for key, value in self.metrics[step_prefix].items(): + if isinstance(value, torchmetrics.Metric): + if key == "Binary_Fairness": + feature_names = self.metrics[step_prefix][key].feature_helper(self.trainer, step_prefix) + value.update( + transformed_output[0], + transformed_output[1].int(), + data, + feature_names, + ) + + else: + value.update(transformed_output[0], transformed_output[1].int()) + else: + value.update(transformed_output) + self.log(f"{step_prefix}/loss", loss, on_step=False, on_epoch=True, sync_dist=True) + return loss + + def prep_data_captum(self, x): + """ + Prepares data to be fed into captum and generates baseline as well. 
+ + Args: + - x:Batch from dataloader + Returns: + - data:batch data in a tuple after being prepared + - baselines:Basically zero tensors in the input + """ + # captum requires gradient and float values + + data = ( + x["encoder_cat"].float().requires_grad_(), + x["encoder_cont"].requires_grad_(), + x["encoder_target"].float().requires_grad_(), + x["encoder_lengths"].float().requires_grad_(), + x["decoder_cat"].float().requires_grad_(), + x["decoder_cont"].requires_grad_(), + x["decoder_target"].float().requires_grad_(), + x["decoder_lengths"].float().requires_grad_(), + x["decoder_time_idx"].float().requires_grad_(), + x["groups"].float().requires_grad_(), + x["target_scale"].requires_grad_(), + ) + baselines = ( + data[0].to(self.device), # encoder_cat, no cat variables + torch.zeros_like(data[1]).to(self.device), # encoder_cont, set to zero + torch.zeros_like(data[2]).to(self.device), # encoder_target, set to zero + data[3].to(self.device), # encoder_lengths, leave unchanged + data[4].to(self.device), # decoder_cat, no cat variables + torch.zeros_like(data[5]).to(self.device), # decoder_cont, set to zero + torch.zeros_like(data[6]).to(self.device), # decoder_target, set to zero + data[7].to(self.device), # decoder_lengths, leave unchanged + data[8].to(self.device), # decoder_time_idx, unchanged + data[9].to(self.device), # groups, leave unchanged + data[10].to(self.device), # target_scale, leave unchanged + ) + return data, baselines + + def explantation( + self, + dataloader, + method, + log_dir=".", + plot=False, + XAI_metric=False, + random_model=None, + test_dataset=None, + **kwargs, + ): + """ + Generic method to combine pytorchforecasting data loading , interpertations and captum to generate attributions + + Args: + - dataloader: pytorchforecasting data loader + - method: The explantation method chosen + - log_dir= The directory to output the plots + - plot= Determines if plots should be done or not + - XAI_metric=Determines if XAI metrics should be 
calculated or not + Returns: + - all_attrs : Attribtuons of features per timesteps + - features_attrs : Attribtuons of features averaged over timesteps + - timestep_attrs : Attribtuons of timesteps averaged over features + - f_ts_v_score: Faithfulness score for attribtuons of features per timesteps + - f_ts_score: Faithfulness score for attribtuons of timesteps averaged over features + """ + # Initialize lists to store attribution values for all instances + all_attrs = [] + f_ts_score = [] + f_ts_v_score = [] + f_v_score = [] + r_score = [] + st_i_score = [] + st_o_score = [] + + method_name = method if (method == "Random") or (method == "Attention") else (method.__name__) + if (method_name == "Random") or (method_name == "Attention"): + if method_name == "Attention": + Interpertations = self.interpertations(dataloader=dataloader, log_dir=log_dir, plot=plot) + timestep_attrs = Interpertations["attention"] + features_attrs = Interpertations["static_variables"].tolist() + features_attrs.extend(Interpertations["encoder_variables"].tolist()) + r_score = Data_Randomization( + self, + x=None, + attribution=timestep_attrs, + explain_method=method, + random_model=random_model, + dataloader=dataloader, + method_name=method_name, + ) + st_i_score, st_o_score = Relative_Stability( + self, + x=None, + attribution=timestep_attrs, + explain_method=method, + method_name=method_name, + dataloader=dataloader, + **kwargs, + ) + elif method_name == "Random": + # Generate random attributions for baseline comparison + all_attrs = np.random.normal(size=[64, 24, 53]) + features_attrs = all_attrs.mean(axis=(1)) + timestep_attrs = all_attrs.mean(axis=(2)) + if XAI_metric: + for batch in dataloader: + for key, value in batch[0].items(): + batch[0][key] = batch[0][key].to(self.device) + x = batch[0] + + if method_name == "Random": + f_ts_v_score.append( + Faithfulness_Correlation( + self, + x, + all_attrs, + pertrub="baseline", + feature_timestep=True, + subset_size=[4, 9], + nr_runs=100, + 
) + ) + f_ts_score.append( + Faithfulness_Correlation( + self, + x, + all_attrs, + pertrub="baseline", + time_step=True, + subset_size=4, + nr_runs=100, + ) + ) + f_v_score.append( + Faithfulness_Correlation( + self, + x, + all_attrs, + pertrub="baseline", + feature=True, + subset_size=9, + nr_runs=100, + ) + ) + + r_score.append( + Data_Randomization( + self, + x, + attribution=all_attrs, + explain_method=method, + random_model=random_model, + method_name=method_name, + ) + ) + res1, res2 = Relative_Stability( + self, + x, + all_attrs, + explain_method=method, + method_name=method_name, + dataloader=None, + **kwargs, + ) + st_i_score.append(res1) + st_o_score.append(res2) + else: + f_ts_score.append( + Faithfulness_Correlation( + self, + x, + timestep_attrs, + pertrub="baseline", + time_step=True, + subset_size=4, + nr_runs=100, + ) + ) + f_v_score.append( + Faithfulness_Correlation( + self, + x, + features_attrs, + pertrub="baseline", + feature=True, + subset_size=9, + nr_runs=100, + ) + ) + + # Faithfulness score for attribtuons of features per timesteps + f_ts_v_score = np.mean(f_ts_v_score) + # Faithfulness score for attribtuons of timesteps averaged over features + f_ts_score = np.mean(f_ts_score) + f_v_score = np.mean(f_v_score) + + if method_name != "Attention": + # r_score = (r_score - min_val) / (max_val - min_val) + r_score = np.mean(r_score) + st_i_score = np.max(st_i_score) + st_o_score = np.max(st_o_score) + return ( + all_attrs, + features_attrs, + timestep_attrs, + f_ts_v_score, + f_ts_score, + f_v_score, + r_score, + st_i_score, + st_o_score, + ) + + # Loop through the dataloader to compute attributions for all instances + for batch in dataloader: + for key, value in batch[0].items(): + batch[0][key] = batch[0][key].to(self.device) + x = batch[0] + + data, baselines = self.prep_data_captum(x) + + # Initialize the explanation method + explanation = ( + method(self.forward_captum, interpretable_model=SkLearnLasso(alpha=0.4)) + if method_name == 
"Lime" + else method(self.forward_captum) + ) + + # Calculate attributions using the selected method + if method is not captum.attr.Saliency: + attr = explanation.attribute(data, baselines=baselines, **kwargs) + else: + attr = explanation.attribute(data, **kwargs) + + # Process and store the calculated attributions + stacked_attr = ( + attr[1].cpu().detach().numpy() + if method_name in ["Lime", "FeatureAblation"] + else torch.stack(attr).cpu().detach().numpy() + ) + if XAI_metric: + f_ts_v_score.append( + Faithfulness_Correlation( + self, + x, + stacked_attr, + pertrub="baseline", + feature_timestep=True, + subset_size=[4, 9], + nr_runs=100, + ) + ) + + f_ts_score.append( + Faithfulness_Correlation( + self, + x, + stacked_attr, + pertrub="baseline", + time_step=True, + subset_size=4, + nr_runs=100, + ) + ) + f_v_score.append( + Faithfulness_Correlation( + self, + x, + stacked_attr, + pertrub="baseline", + feature=True, + subset_size=9, + nr_runs=100, + ) + ) + r_score.append( + Data_Randomization( + self, + x, + attribution=stacked_attr, + explain_method=method, + random_model=random_model, + method_name=method_name, + ) + ) + + res1, res2 = Relative_Stability( + self, + x, + stacked_attr, + explain_method=method, + method_name=method_name, + dataloader=None, + **kwargs, + ) + st_i_score.append(res1) + st_o_score.append(res2) + + # aggregate over batch + attr = np.mean(stacked_attr, axis=0) + all_attrs.append(attr) + # aggregate over all batches + all_attrs = np.array(all_attrs).mean(axis=(0)) + # aggregate over all timesteps + features_attrs = all_attrs.mean(axis=(0)) + # aggregate over all features + timestep_attrs = all_attrs.mean(axis=(1)) + # Faithfulness score for attribtuons of features per timesteps + f_ts_v_score = np.mean(f_ts_v_score) + # Faithfulness score for attribtuons of timesteps averaged over features + f_ts_score = np.mean(f_ts_score) + # Faithfulness score for attribtuons of timesteps averaged over timesteps + f_v_score = np.mean(f_v_score) + + 
# Random data score + r_score = np.mean(r_score) + st_i_score = np.max(st_i_score) + st_o_score = np.max(st_o_score) + + # Return computed attributions and metrics + return ( + all_attrs, + features_attrs, + timestep_attrs, + f_ts_v_score, + f_ts_score, + f_v_score, + r_score, + st_i_score, + st_o_score, + ) + # normalized_means = (means - means.min()) / (means.max() - means.min()) + + def prep_data(self, x): + """ + Prepares data for custom forward method + + Args: + - x:Batch returned from dataloader + Returns: + data:Tuple consisting of the tensors of X in the format the forward method needs + """ + data = ( + x["encoder_cat"], + x["encoder_cont"], + x["encoder_target"], + x["encoder_lengths"], + x["decoder_cat"], + x["decoder_cont"], + x["decoder_target"], + x["decoder_lengths"], + x["decoder_time_idx"], + x["groups"], + x["target_scale"], + ) + return data + + def add_noise(self, x, indices, time_step, feature, feature_timestep): + noise = torch.randn_like(x["encoder_cont"]) + if time_step: + idx0, idx1 = np.meshgrid(indices[0], indices[1], indexing="ij") + + with torch.no_grad(): + x["encoder_cont"][idx0, idx1, :] += noise[idx0, idx1, :] + + elif feature: + idx0, idx1 = np.meshgrid(indices[0], indices[1], indexing="ij") + + with torch.no_grad(): + x["encoder_cont"][idx0, :, idx1] += noise[idx0, :, idx1] + + elif feature_timestep: + idx0, idx1, idx2 = np.meshgrid(indices[0], indices[1], indices[2], indexing="ij") + + with torch.no_grad(): + x["encoder_cont"][idx0, idx1, idx2] += noise[idx0, idx1, idx2] + return x + + def apply_baseline(self, x, indices, time_step, feature, feature_timestep): + mask = torch.ones_like(x["encoder_cont"]) + if time_step: + ( + idx0, + idx1, + ) = np.meshgrid(indices[0], indices[1], indexing="ij") + + mask[idx0, idx1, :] -= mask[idx0, idx1, :] + elif feature: + ( + idx0, + idx1, + ) = np.meshgrid(indices[0], indices[1], indexing="ij") + + mask[idx0, :, idx1] -= mask[idx0, :, idx1] + + elif feature_timestep: + idx0, idx1, idx2 = 
np.meshgrid(indices[0], indices[1], indices[2], indexing="ij") + + mask[idx0, idx1, idx2] -= mask[idx0, idx1, idx2] + + with torch.no_grad(): + x["encoder_cont"] *= mask + return x + + @gin.configurable("MLWrapper") class MLWrapper(BaseModule, ABC): """Interface for prediction with traditional Scikit-learn-like Machine Learning models.""" @@ -363,7 +896,15 @@ class MLWrapper(BaseModule, ABC): requires_backprop = False _supported_run_modes = [RunMode.classification, RunMode.regression] - def __init__(self, *args, run_mode=RunMode.classification, loss=log_loss, patience=10, mps=False, **kwargs): + def __init__( + self, + *args, + run_mode=RunMode.classification, + loss=log_loss, + patience=10, + mps=False, + **kwargs, + ): super().__init__() self.save_hyperparameters() self.scaler = None @@ -439,17 +980,25 @@ def validation_step(self, val_dataset, _): def test_step(self, dataset, _): test_rep, test_label = dataset - test_rep, test_label = test_rep.squeeze().cpu().numpy(), test_label.squeeze().cpu().numpy() + test_rep, test_label = ( + test_rep.squeeze().cpu().numpy(), + test_label.squeeze().cpu().numpy(), + ) self.set_metrics(test_label) test_pred = self.predict(test_rep) if self.mps: - self.log("test/loss", np.float32(self.loss(test_label, test_pred)), sync_dist=True) + self.log( + "test/loss", + np.float32(self.loss(test_label, test_pred)), + sync_dist=True, + ) self.log_metrics(np.float32(test_label), np.float32(test_pred), "test") else: self.log("test/loss", self.loss(test_label, test_pred), sync_dist=True) self.log_metrics(test_label, test_pred, "test") logging.debug(f"Test loss: {self.loss(test_label, test_pred)}") + self.log_metrics(np.float32(test_label), np.float32(test_pred), "test") def predict(self, features): if self.run_mode == RunMode.regression: @@ -469,7 +1018,10 @@ def log_metrics(self, label, pred, metric_type): # Fore very metric for name, metric in self.metrics.items() # Filter out metrics that return a tuple (e.g. 
precision_recall_curve) - if not isinstance(metric(self.label_transform(label), self.output_transform(pred)), tuple) + if not isinstance( + metric(self.label_transform(label), self.output_transform(pred)), + tuple, + ) }, sync_dist=True, ) @@ -588,7 +1140,10 @@ def step_fn(self, batch, step_prefix=""): for metric in self.metrics[step_prefix].values(): metric.update( - (torch.flatten(amputated.detach(), start_dim=1).clone(), torch.flatten(target.detach(), start_dim=1).clone()) + ( + torch.flatten(amputated.detach(), start_dim=1).clone(), + torch.flatten(target.detach(), start_dim=1).clone(), + ) ) return loss diff --git a/icu_benchmarks/run.py b/icu_benchmarks/run.py index 3d596ccb..5c2d3bac 100644 --- a/icu_benchmarks/run.py +++ b/icu_benchmarks/run.py @@ -5,7 +5,11 @@ import sys from pathlib import Path import torch.cuda -from icu_benchmarks.wandb_utils import update_wandb_config, apply_wandb_sweep, set_wandb_experiment_name +from icu_benchmarks.wandb_utils import ( + update_wandb_config, + apply_wandb_sweep, + set_wandb_experiment_name, +) from icu_benchmarks.tuning.hyperparameters import choose_and_bind_hyperparameters from scripts.plotting.utils import plot_aggregated_results from icu_benchmarks.cross_validation import execute_repeated_cv @@ -48,6 +52,11 @@ def main(my_args=tuple(sys.argv[1:])): evaluate = args.eval experiment = args.experiment source_dir = args.source_dir + explain = args.explain + pytorch_forecasting = args.pytorch_forecasting + XAI_metric = args.XAI_metric + random_labels = args.random_labels + # Load task config gin.parse_config_file(f"configs/tasks/{task}.gin") mode = get_mode() @@ -73,7 +82,7 @@ def main(my_args=tuple(sys.argv[1:])): else "None" } ) - + random_model_dir = args.random_model log_dir_name = args.log_dir / name log_dir = ( (log_dir_name / experiment) @@ -85,10 +94,14 @@ def main(my_args=tuple(sys.argv[1:])): # Check cuda availability if torch.cuda.is_available(): for name in range(0, torch.cuda.device_count()): - 
log_full_line(f"Available GPU {name}: {torch.cuda.get_device_name(name)}", level=logging.INFO) + log_full_line( + f"Available GPU {name}: {torch.cuda.get_device_name(name)}", + level=logging.INFO, + ) else: log_full_line( - "No GPUs available: please check your device and Torch,Cuda installation if unintended.", level=logging.WARNING + "No GPUs available: please check your device and Torch,Cuda installation if unintended.", + level=logging.WARNING, ) if args.preprocessor: @@ -118,7 +131,7 @@ def main(my_args=tuple(sys.argv[1:])): name_datasets(args.name, args.name, args.name) hp_checkpoint = log_dir / args.hp_checkpoint if args.hp_checkpoint else None model_path = ( - Path("configs") / ("imputation_models" if mode == RunMode.imputation else "prediction_models") / f"{model}.gin" + Path("configs") / ("imputation_models" if mode == RunMode.imputation else "prediction_models") / f"{model}.gin" ) gin_config_files = ( [Path(f"configs/experiments/{args.experiment}.gin")] @@ -139,6 +152,8 @@ def main(my_args=tuple(sys.argv[1:])): generate_cache=args.generate_cache, load_cache=args.load_cache, verbose=verbose, + pytorch_forecasting=pytorch_forecasting, + random_labels=random_labels, ) log_full_line(f"Logging to {run_dir.resolve()}", level=logging.INFO) @@ -151,6 +166,7 @@ def main(my_args=tuple(sys.argv[1:])): log_full_line(mode_string, level=logging.INFO, char="=", num_newlines=3) start_time = datetime.now() + execute_repeated_cv( data_dir, run_dir, @@ -169,6 +185,11 @@ def main(my_args=tuple(sys.argv[1:])): cpu=args.cpu, wandb=args.wandb_sweep, complete_train=args.complete_train, + explain=explain, + pytorch_forecasting=pytorch_forecasting, + XAI_metric=XAI_metric, + random_model_dir=random_model_dir, + random_labels=random_labels, ) log_full_line("FINISHED TRAINING", level=logging.INFO, char="=", num_newlines=3) diff --git a/icu_benchmarks/run_utils.py b/icu_benchmarks/run_utils.py index 85b676ac..dac9730d 100644 --- a/icu_benchmarks/run_utils.py +++ 
b/icu_benchmarks/run_utils.py @@ -15,6 +15,7 @@ from statistics import mean, pstdev from icu_benchmarks.models.utils import JsonResultLoggingEncoder from icu_benchmarks.wandb_utils import wandb_log +import numpy as np def build_parser() -> ArgumentParser: @@ -52,6 +53,14 @@ def build_parser() -> ArgumentParser: parser.add_argument("-sn", "--source-name", type=Path, help="Name of the source dataset.") parser.add_argument("--source-dir", type=Path, help="Directory containing gin and model weights.") parser.add_argument("-sa", "--samples", type=int, default=None, help="Number of samples to use for evaluation.") + parser.add_argument("--explain", default=False, action=BOA, help="Provide explaintations for predictions.") + parser.add_argument("--pytorch-forecasting", default=False, action=BOA, help="use pytorch forecasting library ") + parser.add_argument("--XAI_metric", default=False, action=BOA, help="Compare explantations ") + parser.add_argument("--random_labels", default=False, action=BOA, help="randmize target labels") + parser.add_argument( + "--random_model", default=Path("."), type=Path, help="path for model weights that is trained on random labels" + ) + return parser @@ -112,11 +121,16 @@ def aggregate_results(log_dir: Path, execution_time: timedelta = None): with open(fold_iter / "val_metrics.csv", "r") as f: result = json.load(f) aggregated[repetition.name][fold_iter.name].update(result) + # Add durations to metrics if (fold_iter / "durations.json").is_file(): with open(fold_iter / "durations.json", "r") as f: result = json.load(f) aggregated[repetition.name][fold_iter.name].update(result) + if (fold_iter / "XAI_metrics.json").is_file(): + with open(fold_iter / "XAI_metrics.json", "r") as f: + result = json.load(f) + aggregated[repetition.name][fold_iter.name].update(result) # Aggregate results per metric list_scores = {} @@ -132,7 +146,10 @@ def aggregate_results(log_dir: Path, execution_time: timedelta = None): # Calculate the population standard 
deviation over aggregated results over folds/iterations # Divide by sqrt(n) to get standard deviation. - std_scores = {metric: (pstdev(list) / sqrt(len(list))) for metric, list in list_scores.items()} + + std_scores = { + metric: (pstdev(list) / sqrt(len(list))) for metric, list in list_scores.items() if not (np.isnan(list).all()) + } confidence_interval = { metric: (stats.t.interval(0.95, len(list) - 1, loc=mean(list), scale=stats.sem(list))) @@ -145,6 +162,10 @@ def aggregate_results(log_dir: Path, execution_time: timedelta = None): "CI_0.95": confidence_interval, "execution_time": execution_time.total_seconds() if execution_time is not None else 0.0, } + log_dir_plots = log_dir / "plots" + if not (log_dir_plots.exists()): + log_dir_plots.mkdir(parents=True) + # plot_XAI_Metrics(accumulated_metrics, log_dir_plots=log_dir_plots) with open(log_dir / "aggregated_test_metrics.json", "w") as f: json.dump(aggregated, f, cls=JsonResultLoggingEncoder) @@ -175,7 +196,12 @@ def log_full_line(msg: str, level: int = logging.INFO, char: str = "-", num_newl reserved_chars = len(logging.getLevelName(level)) + 28 logging.log( level, - "{0:{char}^{width}}{1}".format(msg, "\n" * num_newlines, char=char, width=terminal_size.columns - reserved_chars), + "{0:{char}^{width}}{1}".format( + msg, + "\n" * num_newlines, + char=char, + width=terminal_size.columns - reserved_chars, + ), ) diff --git a/icu_benchmarks/tuning/hyperparameters.py b/icu_benchmarks/tuning/hyperparameters.py index 212a7753..1458a7c2 100644 --- a/icu_benchmarks/tuning/hyperparameters.py +++ b/icu_benchmarks/tuning/hyperparameters.py @@ -36,6 +36,8 @@ def choose_and_bind_hyperparameters( debug: bool = False, verbose: bool = False, wandb: bool = False, + pytorch_forecasting: bool = False, + random_labels: bool = False, ): """Choose hyperparameters to tune and bind them to gin. 
@@ -86,7 +88,10 @@ def choose_and_bind_hyperparameters( n_calls, configuration, evaluation = load_checkpoint(checkpoint_path, n_calls) # Check if we surpassed maximum tuning iterations if n_calls <= 0: - logging.log(TUNE, "No more hyperparameter tuning iterations left, skipping tuning.") + logging.log( + TUNE, + "No more hyperparameter tuning iterations left, skipping tuning.", + ) logging.info("Training with these hyperparameters:") bind_gin_params(hyperparams_names, configuration[np.argmin(evaluation)]) # bind best hyperparameters return @@ -112,6 +117,8 @@ def bind_params_and_train(hyperparams): debug=debug, verbose=verbose, wandb=wandb, + pytorch_forecasting=pytorch_forecasting, + random_labels=random_labels, ) header = ["ITERATION"] + hyperparams_names + ["LOSS AT ITERATION"] diff --git a/scripts/plotting/plotting.py b/scripts/plotting/plotting.py index 779eea26..881d4e1b 100644 --- a/scripts/plotting/plotting.py +++ b/scripts/plotting/plotting.py @@ -1,4 +1,5 @@ import matplotlib.pyplot as plt +import numpy as np class Plotter: @@ -49,3 +50,194 @@ def calibration_curve(self): plt.legend(loc="lower right") plt.savefig(self.save_dir / f"call_curve {self.specifier}.png") plt.clf() + + def plot_XAI_Metrics(accumulated_metrics, log_dir_plots): + groups = {} + for key in accumulated_metrics["avg"]: + if key in ["loss", "MAE"]: + continue + suffix = key.split("_")[-1] + if suffix not in groups: + groups[suffix] = [] + groups[suffix].append(key) + + # Define a dictionary for legend labels + legend_labels = { + "IG": "Integrated Gradient", + "G": "Gradient", + "R": "Random", + "FA": "Feature Ablation", + "Att": "Attention", + "VSN": "Variable Selection Network", + "L": "Lime", + } + colors = [ + "navy", + "skyblue", + "crimson", + "salmon", + "teal", + "orange", + "darkgreen", + "lightgreen", + ] + + # Plotting + num_groups = len(groups) + fig, axs = plt.subplots(num_groups, 1, figsize=(10, num_groups * 5)) + + # Custom handles for the legend + # handles = 
[plt.Rectangle((0, 0), 1, 1, color="none", + # label=f"{key}: {value}") for key, value in legend_labels.items()] + + for i, (suffix, keys) in enumerate(groups.items()): + ax = axs[i] if num_groups > 1 else axs + # Extract values and errors + avg_values = [accumulated_metrics["avg"][key] for key in keys] + ci_lower = [accumulated_metrics["CI_0.95"][key][0] for key in keys] + ci_upper = [accumulated_metrics["CI_0.95"][key][1] for key in keys] + ci_error = [np.abs([a - b, c - a]) for a, b, c in zip(avg_values, ci_lower, ci_upper)] + + # Sort by absolute values of avg_values + sorted_indices = np.argsort([np.abs(val) for val in avg_values])[::-1] # Indices to sort in descending order + sorted_keys = np.array(keys)[sorted_indices] + sorted_avg_values = np.array(avg_values)[sorted_indices] + sorted_ci_error = np.array(ci_error)[sorted_indices] + + # Plot bars + bars = ax.bar( + sorted_keys, + np.abs(sorted_avg_values), + yerr=np.array(sorted_ci_error).T, + capsize=5, + color=colors, + ) + + # Set titles and labels + title_suffix = sorted_keys[0].split("_")[1] + ax.set_title(f'Metric: "{title_suffix}"') + ax.set_ylabel("Values") + ax.axhline(0, color="grey", linewidth=0.8) + ax.grid(axis="y") + + # Set x-ticks + ax.set_xticks(sorted_keys) + ax.set_xticklabels([key.split("_")[0] for key in sorted_keys], rotation=45, ha="right") + # Create a custom legend for each subplot + custom_labels = [legend_labels[key.split("_")[0]] for key in sorted_keys] + ax.legend(bars, custom_labels, loc="upper right") + + plt.tight_layout() + plt.savefig(log_dir_plots / "metrics_plot.png", bbox_inches="tight") + + def plot_attributions(self, features_attrs, timestep_attrs, method_name, log_dir): + """ + Plots the attribution values for features and timesteps. + + Args: + - features_attrs: Array of feature attribution values. + - timestep_attrs: Array of timestep attribution values. + - method_name: Name of the attribution method. + - log_dir: Directory to save the plots. 
+ Returns: + Nothing + """ + + # Plot for feature attributions + x_values = np.arange(1, len(features_attrs) + 1) + plt.figure(figsize=(8, 6)) + plt.plot( + x_values, + features_attrs, + marker="o", + color="skyblue", + linestyle="-", + linewidth=2, + markersize=8, + ) + plt.xlabel("Feature") + plt.ylabel("{} Attribution".format(method_name)) + plt.title("{} Attribution Values".format(method_name)) + plt.xticks( + x_values, + [ + "height", + "weight", + "age", + "sex", + "time_idx", + "alb", + "alp", + "alt", + "ast", + "be", + "bicar", + "bili", + "bili_dir", + "bnd", + "bun", + "ca", + "cai", + "ck", + "ckmb", + "cl", + "crea", + "crp", + "dbp", + "fgn", + "fio2", + "glu", + "hgb", + "hr", + "inr_pt", + "k", + "lact", + "lymph", + "map", + "mch", + "mchc", + "mcv", + "methb", + "mg", + "na", + "neut", + "o2sat", + "pco2", + "ph", + "phos", + "plt", + "po2", + "ptt", + "resp", + "sbp", + "temp", + "tnt", + "urine", + "wbc", + ], + rotation=90, + ) + plt.tight_layout() + plt.savefig( + log_dir / "{}_attribution_features_plot.png".format(method_name), + bbox_inches="tight", + ) + + # Plot for timestep attributions + x_values = np.arange(1, len(timestep_attrs) + 1) + plt.figure(figsize=(8, 6)) + plt.plot( + x_values, + timestep_attrs, + marker="o", + color="skyblue", + linestyle="-", + linewidth=2, + markersize=8, + ) + plt.xlabel("Time Step") + plt.ylabel("{} Attribution".format(method_name)) + plt.title("{} Attribution Values".format(method_name)) + plt.xticks(x_values) + plt.tight_layout() + plt.savefig(log_dir / "{}_attribution_plot.png".format(method_name), bbox_inches="tight")