Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
274 changes: 191 additions & 83 deletions icu_features/feature_engineering.py
Comment thread
manuelburger marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import polars as pl

from icu_features.constants import CAT_MISSING_NAME, HORIZONS, VARIABLE_REFERENCE_PATH
from icu_features.load import features

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -353,17 +354,18 @@ def treatment_continuous_features(
return expressions


def eep_label(events: pl.Expr, horizon: int):
def eep_label(events: pl.Expr, horizon: int, switches_only: bool = True):
"""
From an event series, create a label for the early event prediction (eep) task.

- If there is a positive event at the current time step, the label is missing.
- If the last event in the history was positive, and there is a positive event
within the next `horizon` hours, the label is missing.
- If `switches_only` is `True`, and if the last event in the history was positive,
and there is a positive event within the next `horizon` hours, the label is
missing. That is, we only predict switches from stable to unstable.
- If there was no event in the history or the last event in the history was
negative, and there is a positive event within the next `horizon` hours, the
label is true. This holds even if there is a negative event at the current time
step.
negative or `switches_only` is `False`, and there is a positive event within the
next `horizon` hours, the label is true. This holds even if there is a negative
event at the current time step.
- Else, if there is a negative event within the next `horizon` hours (and no
positive event within the next `horizon` hours or at the current time step), the
label is false.
Expand All @@ -377,7 +379,11 @@ def eep_label(events: pl.Expr, horizon: int):
positive_labels: 1 1 - - - - - - - - - - 1 1 1 1 1 1 1 - 1 1 1 1 1 1 1
negative_labels: - 0 0 0 0 0 - 0 0 0 0 0 0 0 - - 0 0 0 0 0 0 0 0 0 0 -
coalesced_labels: 1 1 0 0 0 0 - 0 0 0 0 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1

if switches_only is True:
label: 1 - 0 0 0 - - 0 0 0 0 0 1 1 1 1 - - - 0 1 1 1 1 - 1 -
else:
label: 1 - 0 0 0 - - 0 0 0 0 0 1 1 1 1 - 1 - 0 1 1 1 1 - 1 -

Note that at the time step of a positive event, the label is always missing. At the
time step of a negative event, the label could be true, false, or missing.
Expand All @@ -390,17 +396,20 @@ def eep_label(events: pl.Expr, horizon: int):
An expression for an event series. Boolean with possibly missing values.
horizon : int
The horizon for the early event prediction task.
switches_only : bool, optional, default = True
Whether to only assign True labels for switches from stable to unstable. I.e.,
no positive labels for events `1 - 1`.
"""
events_ffilled = events.forward_fill().replace(None, False)

positive_labels = events.replace(False, None).backward_fill(horizon)
# shift(-1) and backward_fill(horizon - 1) excludes the last zero.
negative_labels = events.replace(True, None).shift(-1).backward_fill(horizon - 1)

coalesced_label = pl.coalesce(positive_labels, negative_labels)
return pl.when(coalesced_label.eq(False) | events_ffilled.eq(False)).then(
coalesced_label
)

if switches_only:
events = events.forward_fill()
events = events.replace(None, False)

return pl.when(coalesced_label.eq(False) | events.eq(False)).then(coalesced_label)


def polars_nan_or(*args: pl.Expr):
Expand All @@ -412,11 +421,11 @@ def polars_nan_or(*args: pl.Expr):

Examples
--------
>>> import polars as pl
>>> a = pl.Series("a", [1, 2, None])
>>> b = pl.Series("b", [0, None, 2])
>>> polars_nan_or(a < 0, b == 2)
[ False, None, True ]
>>> polars_nan_or(
...     pl.Series("a", [True, True, True, False, False, False]),
...     pl.Series("b", [True, None, False, True, None, False]),
... )
[True, True, True, True, None, False]
"""
return (
pl.when(pl.max_horizontal(*args)) # This ignores nans
Expand All @@ -432,7 +441,7 @@ def outcomes():

These are:
- mortality_at_24h: A single label at time 24h after entry to the ICU whether the
patient dies in the ICU. THis is a "once per patient" prediction task.
patient dies in the ICU. This is a "once per patient" prediction task.
- decompensation_at_24h: Whether the patient decompensates within the next 24 hours.
This label is true if the patient dies within the next 24 hours. Else, this is
false. This does not have missing values.
Expand All @@ -451,10 +460,6 @@ def outcomes():
hours. The patient has a kidney failure if they are in stage 3 according to the
KDIGO guidelines:
https://kdigo.org/wp-content/uploads/2016/10/KDIGO-2012-AKI-Guideline-English.pdf
- los_at_24h: The length of stay in the ICU at 24 hours after entry.
- log_creatine_in_1h: The log of the creatinine value in 1 hour.
- log_lactate_in_1h: The log of the lactate value in 1 hour.
- log_po2: The logarithm of the PaO2 value.
"""
# mortality_at_24h
# This is a "once per patient" prediction task. At time step 24h, a label is
Expand All @@ -477,7 +482,8 @@ def outcomes():

# respiratory_failure_at_24h
# If the PaO2/FiO2 ratio is below 200, the patient is considered to have a
# respiratory failure (event). This used pf_ratio from other_variables().
# respiratory failure (event). This uses pf_ratio from other_variables(). This relies
# on an fio2 which was imputed as 21% if the patient was not ventilated.
RESP_PF_DEF_TSH = 200.0
events = pl.col("pf_ratio") < RESP_PF_DEF_TSH
resp_failure_at_24h = eep_label(events, 24).alias("respiratory_failure_at_24h")
Expand All @@ -487,6 +493,19 @@ def outcomes():
remaining_los = pl.when(remaining_los > 0).then(remaining_los).otherwise(None)
remaining_los = remaining_los.alias("remaining_los")

# Severe respiratory failure label with simple imputation of po2 and fio2, related
# to Hueser et al. https://www.medrxiv.org/content/10.1101/2024.01.23.24301516v1.
SEVERE_RESP_PF_DEF_TSH = 100
pf_ratio = (
100
* pl.col("po2").forward_fill(1).backward_fill(1)
/ pl.col("fio2").forward_fill(1).backward_fill(1)
)
events = pf_ratio < SEVERE_RESP_PF_DEF_TSH
respiratory_failure_at_24h_severe_imputed = eep_label(
events, 24, switches_only=False
).alias("respiratory_failure_at_24h_severe_imputed")

# circulatory_failure_at_8h
# A patient is considered to have a circulatory failure if the mean arterial
# pressure is low (below 65, or being raised by a drug) and the lactate is high
# (above 2).
Expand Down Expand Up @@ -527,57 +546,150 @@ def outcomes():
)
circulatory_failure_at_8h = eep_label(event, 8).alias("circulatory_failure_at_8h")

# Circulatory failure label using interpolated lactate, inspired by Hyland et al.
# https://www.nature.com/articles/s41591-020-0789-4
# First, we interpolate lactate values. If the time difference between two
# consecutive lactate measurements is less than 6 hours, or if both lactate values
# are either above or below 2, we linearly interpolate the lactate value. Else, we
# fill the value forward and backward for 3 hours.
time = pl.when(pl.col("lact").is_not_null()).then(pl.col("time_hours"))
interp = (time.backward_fill() - time.forward_fill() < 6) | (
bad_lact.forward_fill() == bad_lact.backward_fill()
)
lact_interp = (
pl.when(interp)
.then(pl.col("lact").interpolate(method="linear"))
.otherwise(pl.col("lact").backward_fill(3).forward_fill(3))
)
# On the boundary, if the first value of lactate is below the threshold ("good"), we
# fill backwards indefinitely. If the last value is below the threshold, fill
# forward indefinitely. Else, we fill forward and backward for 3 hours.
lact_interp = pl.coalesce(
lact_interp,
pl.when(pl.col("time").le(time.min()) & bad_lact.backward_fill()).then(
Comment thread
manuelburger marked this conversation as resolved.
lact_interp.backward_fill()
),
pl.when(pl.col("time").ge(time.max()) & bad_lact.forward_fill()).then(
lact_interp.forward_fill()
),
)
bad_lact_interp = lact_interp >= HIGH_LACT_TSH
event = (
pl.when(bad_map & bad_lact_interp)
Comment thread
mlondschien marked this conversation as resolved.
.then(True)
.when(~bad_map & ~bad_lact_interp)
.then(False)
)
circulatory_failure_at_8h_imputed = eep_label(event, 8, switches_only=False).alias(
"circulatory_failure_at_8h_imputed"
)

# kidney_failure_at_48h
# The patient has a kidney failure if they are in stage 3 according to
# https://kdigo.org/wp-content/uploads/2016/10/KDIGO-2012-AKI-Guideline-English.pdf
relative_creatine = pl.col("crea") / pl.col("crea").shift(1).rolling_min(
window_size=7 * 24, min_samples=1
def kdigo_3(crea, rel_urine_rate, crrt):
"""Compute the KDIGO stage 3 kidney failure label."""
crea_baseline = crea.shift(1).rolling_min(7 * 24, min_samples=1)

# AKI 1 is
# - max absolute creatinine increase of 0.3 within 48h or
# - a relative creatinine increase of 1.5.
creatine_min_48 = crea.rolling_min(window_size=48, min_samples=1)
creatine_max_48 = crea.rolling_max(window_size=48, min_samples=1)
creatine_change_48 = creatine_max_48 - creatine_min_48
aki_1 = polars_nan_or(creatine_change_48 >= 0.3, crea / crea_baseline >= 1.5)
Comment thread
mlondschien marked this conversation as resolved.

# AKI 3 is any of the following:
# - a relative creatinine increase of 3.0 x baseline
# - AKI 1 and creatinine >= 4.0
# - not more than 0.3ml/kg/h urine rate for 24h
# - no urine for 12h
# - initiation of renal replacement therapy (crrt)
good_urine_rate = (rel_urine_rate >= 0.3).cast(pl.Int32)
low_urine_rate_24 = ~(good_urine_rate.rolling_sum(24, min_samples=1).gt(0))

# high_creatine is True if aki_1 is True and creatinine > 4 (neither missing).
# False if either aki_1 is False or creatinine <= 4. Else, missing.
high_creatine = ~polars_nan_or(~aki_1, ~crea.gt(4))

# anuria True if the urine_rate is consistently equal to 0 for 12 hours. False
# if it is ever above 0. If all values are missing, the result is missing.
not_anuria = rel_urine_rate.gt(0).cast(pl.Int32)
anuria = ~(not_anuria.rolling_sum(window_size=12, min_samples=1).gt(0))

return polars_nan_or(
crrt.cast(pl.Boolean).replace(False, None).forward_fill().fill_null(False),
(crea / crea_baseline) >= 3.0,
high_creatine,
low_urine_rate_24,
anuria,
)

aki_3 = kdigo_3(
pl.col("crea"),
pl.col("rel_urine_rate"),
pl.col("ufilt_ind"),
)
kidney_failure_at_48h = eep_label(aki_3, 48).alias("kidney_failure_at_48h")

# AKI 1 is
# - max absolute creatinine increase of 0.3 within 48h or
# - a relative creatinine increase of 1.5.
creatine_min_48 = pl.col("crea").rolling_min(window_size=48, min_samples=1)
creatine_max_48 = pl.col("crea").rolling_max(window_size=48, min_samples=1)
creatine_change_48 = creatine_max_48 - creatine_min_48
aki_1 = polars_nan_or(creatine_change_48 >= 0.3, relative_creatine >= 1.5)

# AKI 3 is any of
# - a relative creatine increase of 3.0 x baseline
# - AKI 1 and creatinine >= 4.0
# - not more than 0.3ml/kg/h urine rate for 24h
# - no urine for 12h
low_urine_rate = ((pl.col("urine_rate") / pl.col("weight")) < 0.3).cast(pl.Int32)
low_urine_rate_24 = low_urine_rate.rolling_sum(window_size=24, min_samples=1).eq(24)

# True if aki_1 is True and creatine >= 4 (neither missing). False if either aki_1
# is False or creatine < 4. Else, missing.
high_creatine = ~polars_nan_or(~aki_1, ~pl.col("crea").gt(4))
# True if the urine_rate is consistently equal to 0 for 12 hours. False if it is
# ever above 0. If all values are missing, the result is missing.
anuria = (
pl.col("urine_rate")
.eq(0)
.cast(pl.Int32)
.rolling_sum(window_size=12, min_samples=1)
.eq(12)
# Kidney failure with creatinine imputation motivated by Lyu et al. 2024:
# https://www.medrxiv.org/content/10.1101/2024.02.01.24302063v1
# Similarly to circulatory_failure_at_8h_imputed, we linearly interpolate
# creatinine values. We only interpolate if the time difference between two
# consecutive creatinine measurements is less than 48 hours. We bfill before the
# first and ffill after the last measurement (each with a limit of 48).
time = pl.when(pl.col("crea").is_not_null()).then(pl.col("time_hours"))
interpolate = time.backward_fill() - time.forward_fill() < 48
crea = pl.coalesce(
pl.when(interpolate).then(pl.col("crea").interpolate(method="linear")),
pl.when(pl.col("time").le(time.min())).then(pl.col("crea").backward_fill(48)),
pl.when(pl.col("time").ge(time.max())).then(pl.col("crea").forward_fill(48)),
)
urine_rate = pl.col("urine_rate").backward_fill(48)

aki_3 = polars_nan_or(
relative_creatine >= 3.0,
high_creatine,
low_urine_rate_24,
anuria,
aki_3 = kdigo_3(
crea,
urine_rate / pl.col("weight"),
pl.col("ufilt_ind"),
)
kidney_failure_at_48h_imputed = eep_label(aki_3, 48, switches_only=False).alias(
"kidney_failure_at_48h_imputed"
)
# If the weight is missing, the patient could only ever have a positive label, as
# urine related conditions are always missing. We thus set the label to missing.
aki_3 = pl.when(pl.col("weight").is_null()).then(None).otherwise(aki_3)
kidney_failure_at_48h = eep_label(aki_3, 48).alias("kidney_failure_at_48h")

# total length of stay, predicted at 24h after entry to the ICU.
los_at_24h = pl.when(pl.col("time_hours").eq(24)).then(pl.col("los_icu"))
los_at_24h = pl.when(los_at_24h > 0.1).then(los_at_24h).alias("los_at_24h")
log_los_at_24h = los_at_24h.log().alias("log_los_at_24h")
# hyperglycemia_at_8h and hypoglycemia_at_8h according to Mehdizavareha et al.,
# https://arxiv.org/pdf/2411.01418
hyperglycemia_event = pl.col("glu") > 180 # mg/dl
hypoglycemia_event = pl.col("glu") < 70 # mg/dl
hyperglycemia_at_8h = eep_label(hyperglycemia_event, 8, switches_only=False).alias(
"hyperglycemia_at_8h"
)
hypoglycemia_at_8h = eep_label(hypoglycemia_event, 8, switches_only=False).alias(
"hypoglycemia_at_8h"
)

# Liver failure according to
# - MELD score > 30: https://en.wikipedia.org/wiki/Model_for_End-Stage_Liver_Disease
# - SOFA score >= 3: https://en.wikipedia.org/wiki/SOFA_score
crea = (
pl.when(
pl.col("ufilt_ind")
.cast(pl.Boolean)
.replace(False, None)
.forward_fill(7 * 24)
)
.then(4.0)
.otherwise(pl.col("crea").forward_fill(1).backward_fill(1))
)
bili = pl.col("bili").forward_fill(1).backward_fill(1).clip(1, None)
inr = pl.col("inr_pt").forward_fill(1).backward_fill(1).clip(1, None)
meld_score = 3.78 * crea.log() + 11.2 * inr.log() + 9.57 * bili.log() + 6.43
meld_event = meld_score > 30
severe_meld_at_48h = eep_label(meld_event, 48, switches_only=False).alias(
"severe_meld_at_48h"
)

liver_sofa3 = pl.col("bili").forward_fill(1).backward_fill(1) > 6.0
liver_sofa3_at_48h = eep_label(liver_sofa3, 48).alias("liver_sofa3_at_48h")

# log(lactate) in 4 hours. This is 1/2 the forecast horizon of circ. failure eep.
log_lactate_in_4h = (
Expand All @@ -588,32 +700,21 @@ def outcomes():
pl.col("pf_ratio").log().shift(-12).alias("log_pf_ratio_in_12h")
)

# The "raw" ICU data contains urine measurements from a bag at specific times (ml).
# In `ricu`, we divide these measurement values by the time distance to the last
# measurement. This gives us a urine rate (ml/h). This divided by the patient's
# weight is the relative urine rate (ml/h/kg). These "relative rate" measurements
# are only non-missing at the timepoint of the measurement.
# We assign a label if there is a (positive) measurement in 2 hours.
log_rel_urine_rate_in_2h = (
pl.when(
pl.col("rel_urine_rate").is_not_null() & pl.col("rel_urine_rate").ge(0.01)
)
.then(pl.col("rel_urine_rate").log())
.shift(-2)
.alias("log_rel_urine_rate_in_2h")
)

return [
mortality_at_24h,
decompensation_at_24h,
resp_failure_at_24h,
respiratory_failure_at_24h_severe_imputed,
remaining_los,
circulatory_failure_at_8h,
circulatory_failure_at_8h_imputed,
kidney_failure_at_48h,
los_at_24h,
log_los_at_24h,
kidney_failure_at_48h_imputed,
hyperglycemia_at_8h,
hypoglycemia_at_8h,
severe_meld_at_48h,
liver_sofa3_at_48h,
log_lactate_in_4h,
log_rel_urine_rate_in_2h,
log_pf_ratio_in_12h,
]

Expand Down Expand Up @@ -729,6 +830,13 @@ def main(dataset: str, data_dir: str | Path): # noqa D
pl.col("time_hours").log1p().alias("log_time_hours"),
)

feature_names = set(features())
schema_names = set(q.collect_schema().keys())
missing_features = feature_names - schema_names

if missing_features:
raise ValueError(f"Missing features: {missing_features}")

tic = perf_counter()
out = q.collect()
toc = perf_counter()
Expand Down
Loading