diff --git a/icu_features/feature_engineering.py b/icu_features/feature_engineering.py index f05add1..5255aca 100644 --- a/icu_features/feature_engineering.py +++ b/icu_features/feature_engineering.py @@ -7,6 +7,7 @@ import polars as pl from icu_features.constants import CAT_MISSING_NAME, HORIZONS, VARIABLE_REFERENCE_PATH +from icu_features.load import features logger = logging.getLogger(__name__) @@ -353,17 +354,18 @@ def treatment_continuous_features( return expressions -def eep_label(events: pl.Expr, horizon: int): +def eep_label(events: pl.Expr, horizon: int, switches_only: bool = True): """ From an event series, create a label for the early event prediction (eep) task. - If there is a positive event at the current time step, the label is missing. - - If the last event in the history was positive, and there is a positive event - within the next `horizon` hours, the label is missing. + - If `switches_only` is `True`, and if the last event in the history was positive, + and there is a positive event within the next `horizon` hours, the label is + missing. That is, we only predict switches from stable to unstable. - If there was no event in the history or the last event in the history was - negative, and there is a positive event within the next `horizon` hours, the - label is true. This holds even if there is a negative event at the current time - step. + negative or `switches_only` is `False`, and there is a positive event within the + next `horizon` hours, the label is true. This holds even if there is a negative + event at the current time step. - Else, if there is a negative event within the next `horizon` hours (and no positive event within the next `horizon` hours or at the current time step), the label is false. 
@@ -377,7 +379,11 @@ def eep_label(events: pl.Expr, horizon: int): positive_labels: 1 1 - - - - - - - - - - 1 1 1 1 1 1 1 - 1 1 1 1 1 1 1 negative_labels: - 0 0 0 0 0 - 0 0 0 0 0 0 0 - - 0 0 0 0 0 0 0 0 0 0 - coalesced_labels: 1 1 0 0 0 0 - 0 0 0 0 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 + + if switches_only is True: label: 1 - 0 0 0 - - 0 0 0 0 0 1 1 1 1 - - - 0 1 1 1 1 - 1 - + else: + label: 1 - 0 0 0 - - 0 0 0 0 0 1 1 1 1 - 1 - 0 1 1 1 1 - 1 - Note that at the time step of a positive event, the label is always missing. At the time step of a negative event, the label could be true, false, or missing. @@ -390,17 +396,20 @@ def eep_label(events: pl.Expr, horizon: int): An expression for an event series. Boolean with possibly missing values. horizon : int The horizon for the early event prediction task. + switches_only : bool, optional, default = True + Whether to only assign True labels for switches from stable to unstable. I.e., + no positive labels for events `1 - 1`. """ - events_ffilled = events.forward_fill().replace(None, False) - positive_labels = events.replace(False, None).backward_fill(horizon) # shift(-1) and backward_fill(horizon - 1) excludes the last zero. 
negative_labels = events.replace(True, None).shift(-1).backward_fill(horizon - 1) - coalesced_label = pl.coalesce(positive_labels, negative_labels) - return pl.when(coalesced_label.eq(False) | events_ffilled.eq(False)).then( - coalesced_label - ) + + if switches_only: + events = events.forward_fill() + events = events.replace(None, False) + + return pl.when(coalesced_label.eq(False) | events.eq(False)).then(coalesced_label) def polars_nan_or(*args: pl.Expr): @@ -412,11 +421,11 @@ def polars_nan_or(*args: pl.Expr): Examples -------- - >>> import polars as pl - >>> a = pl.Series("a", [1, 2, None]) - >>> b = pl.Series("b", [0, None, 2]) - >>> polars_nan_or(a < 0, b == 2) - [ False, None, True ] + >>> polars_nan_or( + >>> pl.Series("a", [True, True, True, False, False, False]), + >>> pl.Series("b", [True, None, False, True, None, False]), + >>> ) + [True, True, True, True, None, False] """ return ( pl.when(pl.max_horizontal(*args)) # This ignores nans @@ -432,7 +441,7 @@ def outcomes(): These are: - mortality_at_24h: A single label at time 24h after entry to the ICU whether the - patient dies in the ICU. THis is a "once per patient" prediction task. + patient dies in the ICU. This is a "once per patient" prediction task. - decompensation_at_24h: Whether the patient decompensates within the next 24 hours. This has label is true if the patient dies within the next 24 hours. Else, this is false. This does not have missing values. @@ -451,10 +460,6 @@ def outcomes(): hours. The patient has a kidney failure if they are in stage 3 according to the KDIGO guidelines: https://kdigo.org/wp-content/uploads/2016/10/KDIGO-2012-AKI-Guideline-English.pdf - - los_at_24h: The length of stay in the ICU at 24 hours after entry. - - log_creatine_in_1h: The log of the creatinine value in 1 hour. - - log_lactate_in_1h: The log of the lactate value in 1 hour. - - log_po2: The logarithm of the PaO2 value. """ # mortality_at_24h # This is a "once per patient" prediction task. 
At time step 24h, a label is @@ -477,7 +482,8 @@ def outcomes(): # respiratory_failure_at_24h # If the PaO2/FiO2 ratio is below 200, the patient is considered to have a - # respiratory failure (event). This used pf_ratio from other_variables(). + # respiratory failure (event). This used pf_ratio from other_variables(). This uses + # a fio2 which was imputed by 21% if the patient was not ventilated. RESP_PF_DEF_TSH = 200.0 events = pl.col("pf_ratio") < RESP_PF_DEF_TSH resp_failure_at_24h = eep_label(events, 24).alias("respiratory_failure_at_24h") @@ -487,6 +493,19 @@ def outcomes(): remaining_los = pl.when(remaining_los > 0).then(remaining_los).otherwise(None) remaining_los = remaining_los.alias("remaining_los") + # Severe respiratory failure label with simple imputation of po2 and fio2, related + # to Hueser et al. https://www.medrxiv.org/content/10.1101/2024.01.23.24301516v1. + SEVERE_RESP_PF_DEF_TSH = 100 + pf_ratio = ( + 100 + * pl.col("po2").forward_fill(1).backward_fill(1) + / pl.col("fio2").forward_fill(1).backward_fill(1) + ) + events = pf_ratio < SEVERE_RESP_PF_DEF_TSH + respiratory_failure_at_24h_severe_imputed = eep_label( + events, 24, switches_only=False + ).alias("respiratory_failure_at_24h_severe_imputed") + # circulatory_failure_at_8h # A patient is considered to have a circulatory failure if the mean arterial # is low (below 65, or being raised by a drug) and the lactate is high (above 2). @@ -527,57 +546,150 @@ def outcomes(): ) circulatory_failure_at_8h = eep_label(event, 8).alias("circulatory_failure_at_8h") + # Circulatory failure label using interpolated lactate, inspired by Hyland et al. + # https://www.nature.com/articles/s41591-020-0789-4 + # First, we interpolate lactate values. If the time difference between two + # consecutive lactate measurements is less than 6 hours, or if both lactate values + # are either above or below 2, we linearly interpolate the lactate value. Else, we + # fill the value forward and backward for 3 hours. 
+ time = pl.when(pl.col("lact").is_not_null()).then(pl.col("time_hours")) + interp = (time.backward_fill() - time.forward_fill() < 6) | ( + bad_lact.forward_fill() == bad_lact.backward_fill() + ) + lact_interp = ( + pl.when(interp) + .then(pl.col("lact").interpolate(method="linear")) + .otherwise(pl.col("lact").backward_fill(3).forward_fill(3)) + ) + # On the boundary, if the first value of lactate is below the threshold ("good"), we + # fill backwards indefinitely. If the last value is below the threshold, fill + # forward indefinitely. Else, we fill forward and backward for 3 hours. + lact_interp = pl.coalesce( + lact_interp, + pl.when(pl.col("time").le(time.min()) & bad_lact.backward_fill()).then( + lact_interp.backward_fill() + ), + pl.when(pl.col("time").ge(time.max()) & bad_lact.forward_fill()).then( + lact_interp.forward_fill() + ), + ) + bad_lact_interp = lact_interp >= HIGH_LACT_TSH + event = ( + pl.when(bad_map & bad_lact_interp) + .then(True) + .when(~bad_map & ~bad_lact_interp) + .then(False) + ) + circulatory_failure_at_8h_imputed = eep_label(event, 8, switches_only=False).alias( + "circulatory_failure_at_8h_imputed" + ) + # kidney_failure_at_48h # The patient has a kidney failure if they are in stage 3 according to # https://kdigo.org/wp-content/uploads/2016/10/KDIGO-2012-AKI-Guideline-English.pdf - relative_creatine = pl.col("crea") / pl.col("crea").shift(1).rolling_min( - window_size=7 * 24, min_samples=1 + def kdigo_3(crea, rel_urine_rate, crrt): + """Compute the KDIGO stage 3 kidney failure label.""" + crea_baseline = crea.shift(1).rolling_min(7 * 24, min_samples=1) + + # AKI 1 is + # - max absolute creatinine increase of 0.3 within 48h or + # - a relative creatinine increase of 1.5. 
+ creatine_min_48 = crea.rolling_min(window_size=48, min_samples=1) + creatine_max_48 = crea.rolling_max(window_size=48, min_samples=1) + creatine_change_48 = creatine_max_48 - creatine_min_48 + aki_1 = polars_nan_or(creatine_change_48 >= 0.3, crea / crea_baseline >= 1.5) + + # AKI 3 is any of the following: + # - a relative creatine increase of 3.0 x baseline + # - AKI 1 and creatinine >= 4.0 + # - not more than 0.3ml/kg/h urine rate for 24h + # - no urine for 12h + # - initiation of renal replacement therapy (crrt) + good_urine_rate = (rel_urine_rate >= 0.3).cast(pl.Int32) + low_urine_rate_24 = ~(good_urine_rate.rolling_sum(24, min_samples=1).gt(0)) + + # high_creatine is True if aki_1 is True and creatine >= 4 (neither missing). + # False if either aki_1 is False or creatine < 4. Else, missing. + high_creatine = ~polars_nan_or(~aki_1, ~crea.gt(4)) + + # anuria True if the urine_rate is consistently equal to 0 for 12 hours. False + # if it is ever above 0. If all values are missing, the result is missing. + not_anuria = rel_urine_rate.gt(0).cast(pl.Int32) + anuria = ~(not_anuria.rolling_sum(window_size=12, min_samples=1).gt(0)) + + return polars_nan_or( + crrt.cast(pl.Boolean).replace(False, None).forward_fill().fill_null(False), + (crea / crea_baseline) >= 3.0, + high_creatine, + low_urine_rate_24, + anuria, + ) + + aki_3 = kdigo_3( + pl.col("crea"), + pl.col("rel_urine_rate"), + pl.col("ufilt_ind"), ) + kidney_failure_at_48h = eep_label(aki_3, 48).alias("kidney_failure_at_48h") - # AKI 1 is - # - max absolute creatinine increase of 0.3 within 48h or - # - a relative creatinine increase of 1.5. 
- creatine_min_48 = pl.col("crea").rolling_min(window_size=48, min_samples=1) - creatine_max_48 = pl.col("crea").rolling_max(window_size=48, min_samples=1) - creatine_change_48 = creatine_max_48 - creatine_min_48 - aki_1 = polars_nan_or(creatine_change_48 >= 0.3, relative_creatine >= 1.5) - - # AKI 3 is any of - # - a relative creatine increase of 3.0 x baseline - # - AKI 1 and creatinine >= 4.0 - # - not more than 0.3ml/kg/h urine rate for 24h - # - no urine for 12h - low_urine_rate = ((pl.col("urine_rate") / pl.col("weight")) < 0.3).cast(pl.Int32) - low_urine_rate_24 = low_urine_rate.rolling_sum(window_size=24, min_samples=1).eq(24) - - # True if aki_1 is True and creatine >= 4 (neither missing). False if either aki_1 - # is False or creatine < 4. Else, missing. - high_creatine = ~polars_nan_or(~aki_1, ~pl.col("crea").gt(4)) - # True if the urine_rate is consistently equal to 0 for 12 hours. False if it is - # ever above 0. If all values are missing, the result is missing. - anuria = ( - pl.col("urine_rate") - .eq(0) - .cast(pl.Int32) - .rolling_sum(window_size=12, min_samples=1) - .eq(12) + # Kidney failure with creatine imputation motivated by Lyu et al. 2024: + # https://www.medrxiv.org/content/10.1101/2024.02.01.24302063v1 + # Similarly to circulatory_failure_8h_imputed, we linearly interpolate creatine + # values. We only interpolate if the time difference between two consecutive + # creatine measurements is less than 48 hours. We ffill and bfill the first and last + # measurements. 
+ time = pl.when(pl.col("crea").is_not_null()).then(pl.col("time_hours")) + interpolate = time.backward_fill() - time.forward_fill() < 48 + crea = pl.coalesce( + pl.when(interpolate).then(pl.col("crea").interpolate(method="linear")), + pl.when(pl.col("time").le(time.min())).then(pl.col("crea").backward_fill(48)), + pl.when(pl.col("time").ge(time.max())).then(pl.col("crea").forward_fill(48)), ) + urine_rate = pl.col("urine_rate").backward_fill(48) - aki_3 = polars_nan_or( - relative_creatine >= 3.0, - high_creatine, - low_urine_rate_24, - anuria, + aki_3 = kdigo_3( + crea, + urine_rate / pl.col("weight"), + pl.col("ufilt_ind"), + ) + kidney_failure_at_48h_imputed = eep_label(aki_3, 48, switches_only=False).alias( + "kidney_failure_at_48h_imputed" ) - # If the weight is missing, the patient could only ever have a positive label, as - # urine related conditions are always missing. We thus set the label to missing. - aki_3 = pl.when(pl.col("weight").is_null()).then(None).otherwise(aki_3) - kidney_failure_at_48h = eep_label(aki_3, 48).alias("kidney_failure_at_48h") - # total length of stay, predicted at 24h after entry to the ICU. 
- los_at_24h = pl.when(pl.col("time_hours").eq(24)).then(pl.col("los_icu")) - los_at_24h = pl.when(los_at_24h > 0.1).then(los_at_24h).alias("los_at_24h") - log_los_at_24h = los_at_24h.log().alias("log_los_at_24h") + # hyperglycemia_at_8h and hypoglycemia_at_8h according to Mehdizavareha et al., + # https://arxiv.org/pdf/2411.01418 + hyperglycemia_event = pl.col("glu") > 180 # mg/dl + hypoglycemia_event = pl.col("glu") < 70 # mg/dl + hyperglycemia_at_8h = eep_label(hyperglycemia_event, 8, switches_only=False).alias( + "hyperglycemia_at_8h" + ) + hypoglycemia_at_8h = eep_label(hypoglycemia_event, 8, switches_only=False).alias( + "hypoglycemia_at_8h" + ) + + # Liver failure according to + # - MELD score > 30: https://en.wikipedia.org/wiki/Model_for_End-Stage_Liver_Disease + # - SOFA score >= 3: https://en.wikipedia.org/wiki/SOFA_score + crea = ( + pl.when( + pl.col("ufilt_ind") + .cast(pl.Boolean) + .replace(False, None) + .forward_fill(7 * 24) + ) + .then(4.0) + .otherwise(pl.col("crea").forward_fill(1).backward_fill(1)) + ) + bili = pl.col("bili").forward_fill(1).backward_fill(1).clip(1, None) + inr = pl.col("inr_pt").forward_fill(1).backward_fill(1).clip(1, None) + meld_score = 3.78 * crea.log() + 11.2 * inr.log() + 9.57 * bili.log() + 6.43 + meld_event = meld_score > 30 + severe_meld_at_48h = eep_label(meld_event, 48, switches_only=False).alias( + "severe_meld_at_48h" + ) + + liver_sofa3 = pl.col("bili").forward_fill(1).backward_fill(1) > 6.0 + liver_sofa3_at_48h = eep_label(liver_sofa3, 48).alias("liver_sofa3_at_48h") # log(lactate) in 4 hours. This is 1/2 the forecast horizon of circ. failure eep. log_lactate_in_4h = ( @@ -588,32 +700,21 @@ def outcomes(): pl.col("pf_ratio").log().shift(-12).alias("log_pf_ratio_in_12h") ) - # The "raw" ICU data contains urine measurements from a bag at specific times (ml). - # In `ricu`, we divide these measurement values by the time distance to the last - # measurement. This gives us a urine rate (ml/h). 
This divided by the patient's - # weight is the relative urine rate (ml/h/kg). These "relative rate" measurements - # are only non-missing at the timepoint of the measurement. - # We assign a label if there is a (positive) measurement in 2 hours. - log_rel_urine_rate_in_2h = ( - pl.when( - pl.col("rel_urine_rate").is_not_null() & pl.col("rel_urine_rate").ge(0.01) - ) - .then(pl.col("rel_urine_rate").log()) - .shift(-2) - .alias("log_rel_urine_rate_in_2h") - ) - return [ mortality_at_24h, decompensation_at_24h, resp_failure_at_24h, + respiratory_failure_at_24h_severe_imputed, remaining_los, circulatory_failure_at_8h, + circulatory_failure_at_8h_imputed, kidney_failure_at_48h, - los_at_24h, - log_los_at_24h, + kidney_failure_at_48h_imputed, + hyperglycemia_at_8h, + hypoglycemia_at_8h, + severe_meld_at_48h, + liver_sofa3_at_48h, log_lactate_in_4h, - log_rel_urine_rate_in_2h, log_pf_ratio_in_12h, ] @@ -729,6 +830,13 @@ def main(dataset: str, data_dir: str | Path): # noqa D pl.col("time_hours").log1p().alias("log_time_hours"), ) + feature_names = set(features()) + schema_names = set(q.collect_schema().keys()) + missing_features = feature_names - schema_names + + if missing_features: + raise ValueError(f"Missing features: {missing_features}") + tic = perf_counter() out = q.collect() toc = perf_counter() diff --git a/tests/test_outcomes.py b/tests/test_outcomes.py index 8fe4971..1f81ec6 100644 --- a/tests/test_outcomes.py +++ b/tests/test_outcomes.py @@ -13,22 +13,29 @@ def to_bool(x): @pytest.mark.parametrize( - "events, expected, horizon", + "events, expected, horizon, switches_only", [ ( "- 1 - - 0 0 - - - - - 0 0 0 - - 1 - 1 - 0 0 0 - 1 0 1", "1 - 0 0 0 - - 0 0 0 0 0 1 1 1 1 - - - 0 1 1 1 1 - 1 -", 4, + True, + ), + ( + "- 1 - - 0 0 - - - - - 0 0 0 - - 1 - 1 - 0 0 0 - 1 0 1", + "1 - 0 0 0 - - 0 0 0 0 0 1 1 1 1 - 1 - 0 1 1 1 1 - 1 -", + 4, + False, ), ], ) -def test_eep_labels(events, expected, horizon): +def test_eep_labels(events, expected, horizon, switches_only): 
df = pl.DataFrame( { "events": to_bool(events), "expected": to_bool(expected), } - ).with_columns(eep_label(pl.col("events"), horizon).alias("labels")) + ).with_columns(eep_label(pl.col("events"), horizon, switches_only).alias("labels")) assert_series_equal(df["labels"], df["expected"], check_names=False) @@ -95,16 +102,6 @@ def test_polars_nan_or(args, expected): ), pl.Series(np.arange(4 * 24, 0, -1) / 24), ), - ( - "los_at_24h", - pl.DataFrame( - { - "los_icu": 4.0, - "time_hours": np.arange(0, 4 * 24), - } - ), - pl.Series(24 * [None] + [4.0] + 71 * [None]), - ), ( "log_lactate_in_4h", pl.DataFrame( @@ -192,8 +189,8 @@ def test_outcomes(outcome_name, input, expected): pl.DataFrame( { "crea": [1.0] * 24 + [None] * 72 + [2.0] * 72 + [4.0] * 48, - "urine_rate": 70, - "weight": 70, + "rel_urine_rate": 1, + "ufilt_ind": [False] * 216, } ), pl.Series(