diff --git a/.gitattributes b/.gitattributes index c536b92..191ace7 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,2 +1,5 @@ tests/testdata/eicu_demo/dyn.parquet filter=lfs diff=lfs merge=lfs -text tests/testdata/eicu_demo/sta.parquet filter=lfs diff=lfs merge=lfs -text +tests/testdata/eicu_demo/hospital.parquet filter=lfs diff=lfs merge=lfs -text +tests/testdata/mimic_demo/dyn.parquet filter=lfs diff=lfs merge=lfs -text +tests/testdata/mimic_demo/sta.parquet filter=lfs diff=lfs merge=lfs -text \ No newline at end of file diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index ec4fce9..f166cf0 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -5,7 +5,7 @@ on: pull_request: jobs: - unittest: + integration-tests: name: Unit tests - ${{ matrix.PYTHON_VERSION }} runs-on: ubuntu-latest strategy: @@ -28,12 +28,20 @@ jobs: environment-file: environment.yml - name: Install repository run: python -m pip install --no-build-isolation --no-deps --disable-pip-version-check -e . 
- - name: Run feature_engineering.py - run: python icu_features/feature_engineering.py --dataset "eicu_demo" --data_dir "tests/testdata" + - name: install icd-mappings + run: pip install icd-mappings + - name: feature engineering for eicu_demo + run: | + python icu_features/icd_codes.py --data_dir "tests/testdata" --dataset "eicu_demo" + python icu_features/feature_engineering.py --dataset "eicu_demo" --data_dir "tests/testdata" + - name: feature engineering for mimic_demo + run: | + python icu_features/split_datasets.py --data_dir "tests/testdata" + python icu_features/icd_codes.py --data_dir "tests/testdata" --dataset "mimic_demo-carevue" + python icu_features/feature_engineering.py --dataset "mimic_demo-carevue" --data_dir "tests/testdata" - name: Pytest run: pytest tests - pre-commit-checks: name: "Linux - pre-commit checks - Python 3.12" timeout-minutes: 30 diff --git a/environment.yml b/environment.yml index 9f0a8cb..5ed7372 100644 --- a/environment.yml +++ b/environment.yml @@ -8,6 +8,6 @@ dependencies: - pip - click - numpy>=2.1 - - polars + - polars==1.26 - pytest - pyarrow \ No newline at end of file diff --git a/icu_features/__init__.py b/icu_features/__init__.py index e69de29..8129cfc 100644 --- a/icu_features/__init__.py +++ b/icu_features/__init__.py @@ -0,0 +1,3 @@ +from .load import load + +__all__ = ["load"] diff --git a/icu_features/constants.py b/icu_features/constants.py index 7002ebb..dadb565 100644 --- a/icu_features/constants.py +++ b/icu_features/constants.py @@ -2,7 +2,7 @@ VARIABLE_REFERENCE_PATH = Path(__file__).parents[1] / "resources" / "variables.tsv" -HORIZONS = [8, 24, 72] +HORIZONS = [8, 24] CAT_MISSING_NAME = "(MISSING)" @@ -124,7 +124,6 @@ CONTINUOUS_FEATURES = [ "mean", - "sq_mean", "std", "slope", "fraction_nonnull", diff --git a/icu_features/feature_engineering.py b/icu_features/feature_engineering.py index e5ae339..4d72bb4 100644 --- a/icu_features/feature_engineering.py +++ b/icu_features/feature_engineering.py @@ -382,7 
+382,7 @@ def eep_label(events: pl.Expr, horizon: int, switches_only: bool = True): if switches_only is True: label: 1 - 0 0 0 - - 0 0 0 0 0 1 1 1 1 - - - 0 1 1 1 1 - 1 - - else: + if switches_only is False: label: 1 - 0 0 0 - - 0 0 0 0 0 1 1 1 1 - 1 - 0 1 1 1 1 - 1 - Note that at the time step of a positive event, the label is always missing. At the @@ -435,7 +435,7 @@ def polars_nan_or(*args: pl.Expr): ) -def outcomes(): +def outcomes(dataset): """ Compute outcomes. @@ -447,15 +447,25 @@ def outcomes(): false. This does not have missing values. - respiratory_failure_at_24h: Whether the patient has a respiratory failure within the next 24 hours. If the PaO2/FiO2 ratio is below 200, the patient is considered - to have a respiratory failure (event). + to have a respiratory failure (event). For the datasets nwicu, picdb, and + zigong, where the PaO2/FiO2 ratio is not available, we use the condition + SaO2 <= 90% instead. + - severe_respiratory_failure_at_24h: Whether the patient has a severe respiratory failure within + the next 24 hours. If the PaO2/FiO2 ratio is below 100, the patient is considered + to have a severe respiratory failure (event). For the datasets nwicu, picdb, and + zigong, where the PaO2/FiO2 ratio is not available, we use the condition + SaO2 <= 90% instead. - remaining_los: The remaining length of stay in the ICU. Missing for the last time- step (instead of zero). - circulatory_failure_at_8h: Whether the patient has a circulatory failure within the next 8 hours. Circulatory failure is defined via blood pressure and lactate. - Blood pressure is considered low if the mean arterial pressure is below 65 mmHg or if - the patient is on any blood pressure increasing drug. Lactate is high if it is - above 2 mmol/l. If the two criteria (map and lactate) don't agree, or if one of them is - missing, the event label is missing. + Blood pressure is considered low if the mean arterial pressure is below 65 mmHg or + if the patient is on any blood pressure increasing drug. 
Lactate is high if it is + above 2 mmol/l. If the two criteria (map and lactate) don't agree, or if one of + them is missing, the event label is missing. For picdb and zigong, very few + measurements of map are available. For them, we relax the condition for a "False" + event: A patient has a negative event if the lactate is good and the map is good + or not measured. - kidney_failure_at_48h: Whether the patient has a kidney failure within the next 48 hours. The patient has a kidney failure if they are in stage 3 according to the KDIGO guidelines: @@ -480,31 +490,40 @@ def outcomes(): .otherwise(False) ).alias("decompensation_at_24h") - # respiratory_failure_at_24h - # If the PaO2/FiO2 ratio is below 200, the patient is considered to have a - # respiratory failure (event). This used pf_ratio from other_variables(). This uses - # a fio2 which was imputed by 21% if the patient was not ventilated. - RESP_PF_DEF_TSH = 200.0 - events = pl.col("pf_ratio") < RESP_PF_DEF_TSH - resp_failure_at_24h = eep_label(events, 24).alias("respiratory_failure_at_24h") - # remaining_los + # remaining length of stay in days remaining_los = pl.col("los_icu") - pl.col("time_hours") / 24.0 remaining_los = pl.when(remaining_los > 0).then(remaining_los).otherwise(None) remaining_los = remaining_los.alias("remaining_los") - # Severe respiratory failure label with simple imputation of po2 and fio2, related - # to Hueser et al. https://www.medrxiv.org/content/10.1101/2024.01.23.24301516v1. - SEVERE_RESP_PF_DEF_TSH = 100 - pf_ratio = ( - 100 - * pl.col("po2").forward_fill(1).backward_fill(1) - / pl.col("fio2").forward_fill(1).backward_fill(1) + # (severe) respiratory_failure_at_24h + # If the PaO2/FiO2 ratio is below 200, the patient is considered to have a + # respiratory failure event. If it's below 100, the patient is considered to have a + # severe respiratory event. We use a fio2 which was imputed by 21% if the patient + # was not ventilated. 
+ vent_ind = pl.col("vent_ind") | pl.col("vent_ind").shift(-1, fill_value=False) + fio2 = (pl.col("fio2") / 100.0).fill_null(pl.when(~vent_ind).then(0.21)) + + RESP_PF_DEF_TSH = 200.0 + SEVERE_RESP_PF_DEF_TSH = 100.0 + RESP_SAO2_DEF_TSH = 90 + SEVERE_RESP_SAO2_DEF_TSH = 90 + + # There are very few pf values in nwicu, picdb, and zigong. We use sao2 instead. + pf_ratio = pl.col("po2") / fio2 + if dataset in ["nwicu", "picdb", "zigong"]: + events = pl.col("sao2") <= RESP_SAO2_DEF_TSH + severe_events = pl.col("sao2") <= SEVERE_RESP_SAO2_DEF_TSH + else: + events = pf_ratio < RESP_PF_DEF_TSH + severe_events = pf_ratio < SEVERE_RESP_PF_DEF_TSH + + respiratory_failure_at_24h = eep_label(events, 24).alias( + "respiratory_failure_at_24h" + ) + severe_respiratory_failure_at_24h = eep_label(severe_events, 24).alias( + "severe_respiratory_failure_at_24h" ) - events = pf_ratio < SEVERE_RESP_PF_DEF_TSH - respiratory_failure_at_24h_severe_imputed = eep_label( - events, 24, switches_only=False - ).alias("respiratory_failure_at_24h_severe_imputed") # circulatory_failure_at_8h # A patient is considered to have a circulatory failure if the mean arterial @@ -541,49 +560,25 @@ def outcomes(): # If the map is "good" (not low and not on drugs) and the lactate is not high, the # event label is negative. If the map and lact don't agree, or if one of them is # missing, the event label is missing. - event = ( - pl.when(bad_map & bad_lact).then(True).when(~bad_map & ~bad_lact).then(False) - ) + if dataset in ["picdb", "zigong"]: + # For PICdb and Zigong, map is not consistently available. We relax the + # condition for a "False": A patient is not in an event if the lactate is good + # and the map is good or not available. 
+ event = ( + pl.when(bad_map & bad_lact) + .then(True) + .when(~bad_map.fill_null(False) & ~bad_lact) + .then(False) + ) + else: + event = ( + pl.when(bad_map & bad_lact) + .then(True) + .when(~bad_map & ~bad_lact) + .then(False) + ) circulatory_failure_at_8h = eep_label(event, 8).alias("circulatory_failure_at_8h") - # Circulatory failure label using interpolated lactate, inspured by Hyland et al. - # https://www.nature.com/articles/s41591-020-0789-4 - # First, we interpolate lactate values. If the time difference between two - # consecutive lactate measurements is less than 6 hours, or if both lactate values - # are either above or below 2, we linearly interpolate the lactate value. Else, we - # fill the value forward and backward for 3 hours. - time = pl.when(pl.col("lact").is_not_null()).then(pl.col("time_hours")) - interp = (time.backward_fill() - time.forward_fill() < 6) | ( - bad_lact.forward_fill() == bad_lact.backward_fill() - ) - lact_interp = ( - pl.when(interp) - .then(pl.col("lact").interpolate(method="linear")) - .otherwise(pl.col("lact").backward_fill(3).forward_fill(3)) - ) - # On the boundary, if the first value of lactate is below the threshold ("good"), we - # fill backwards indefinitely. If the last value is below the threshold, fill - # forward indefinitely. Else, we fill forward and backward for 3 hours. 
- lact_interp = pl.coalesce( - lact_interp, - pl.when(pl.col("time").le(time.min()) & bad_lact.backward_fill()).then( - lact_interp.backward_fill() - ), - pl.when(pl.col("time").ge(time.max()) & bad_lact.forward_fill()).then( - lact_interp.forward_fill() - ), - ) - bad_lact_interp = lact_interp >= HIGH_LACT_TSH - event = ( - pl.when(bad_map & bad_lact_interp) - .then(True) - .when(~bad_map & ~bad_lact_interp) - .then(False) - ) - circulatory_failure_at_8h_imputed = eep_label(event, 8, switches_only=False).alias( - "circulatory_failure_at_8h_imputed" - ) - # kidney_failure_at_48h # The patient has a kidney failure if they are in stage 3 according to # https://kdigo.org/wp-content/uploads/2016/10/KDIGO-2012-AKI-Guideline-English.pdf @@ -621,8 +616,8 @@ def kdigo_3(crea, rel_urine_rate, crrt): crrt.cast(pl.Boolean).replace(False, None).forward_fill().fill_null(False), (crea / crea_baseline) >= 3.0, high_creatine, - low_urine_rate_24, - anuria, + low_urine_rate_24.fill_null(False), + anuria.fill_null(False), ) aki_3 = kdigo_3( @@ -632,41 +627,6 @@ def kdigo_3(crea, rel_urine_rate, crrt): ) kidney_failure_at_48h = eep_label(aki_3, 48).alias("kidney_failure_at_48h") - # Kidney failure with creatine imputation motivated by Lyu et al. 2024: - # https://www.medrxiv.org/content/10.1101/2024.02.01.24302063v1 - # Similarly to circulatory_failure_8h_imputed, we linearly interpolate creatine - # values. We only interpolate if the time difference between two consecutive - # creatine measurements is less than 48 hours. We ffill and bfill the first and last - # measurements. 
- time = pl.when(pl.col("crea").is_not_null()).then(pl.col("time_hours")) - interpolate = time.backward_fill() - time.forward_fill() < 48 - crea = pl.coalesce( - pl.when(interpolate).then(pl.col("crea").interpolate(method="linear")), - pl.when(pl.col("time").le(time.min())).then(pl.col("crea").backward_fill(48)), - pl.when(pl.col("time").ge(time.max())).then(pl.col("crea").forward_fill(48)), - ) - urine_rate = pl.col("urine_rate").backward_fill(48) - - aki_3 = kdigo_3( - crea, - urine_rate / pl.col("weight"), - pl.col("ufilt_ind"), - ) - kidney_failure_at_48h_imputed = eep_label(aki_3, 48, switches_only=False).alias( - "kidney_failure_at_48h_imputed" - ) - - # hyperglycemia_at_8h and hypoglycemia_at_8h according to Mehdizavareha et al., - # https://arxiv.org/pdf/2411.01418 - hyperglycemia_event = pl.col("glu") > 180 # mg/dl - hypoglycemia_event = pl.col("glu") < 70 # mg/dl - hyperglycemia_at_8h = eep_label(hyperglycemia_event, 8, switches_only=False).alias( - "hyperglycemia_at_8h" - ) - hypoglycemia_at_8h = eep_label(hypoglycemia_event, 8, switches_only=False).alias( - "hypoglycemia_at_8h" - ) - # Liver failure according to # - MELD score > 30: https://en.wikipedia.org/wiki/Model_for_End-Stage_Liver_Disease # - SOFA score >= 3: https://en.wikipedia.org/wiki/SOFA_score @@ -678,44 +638,62 @@ def kdigo_3(crea, rel_urine_rate, crrt): .forward_fill(7 * 24) ) .then(4.0) - .otherwise(pl.col("crea").forward_fill(1).backward_fill(1)) + .otherwise(pl.col("crea").backward_fill(1)) + ) + bili = pl.col("bili").backward_fill(1) + + if dataset in ["sic", "nwicu", "picdb"]: + # These datasets don't report inr_pt, but just the non-standardized value pt. + # We impute, assuming an ISI=1.0 (appears reasonable for modern equipment) and + # mean normal PT = 12s. 
+ inr_pt = (pl.col("pt") / 12).backward_fill(1) + else: + inr_pt = pl.col("inr_pt").backward_fill(1) + + SEVERE_MELD_DEF_TSH = 30.0 + meld_event = ( + 3.78 * bili.clip(1.0, None).log() + + 11.2 * inr_pt.clip(1.0, None).log() + + 9.57 * crea.clip(1.0, None).log() + + 6.43 + >= SEVERE_MELD_DEF_TSH ) - bili = pl.col("bili").forward_fill(1).backward_fill(1).clip(1, None) - inr = pl.col("inr_pt").forward_fill(1).backward_fill(1).clip(1, None) - meld_score = 3.78 * bili.log() + 11.2 * inr.log() + 9.57 * crea.log() + 6.43 - meld_event = meld_score > 30 severe_meld_at_48h = eep_label(meld_event, 48, switches_only=False).alias( "severe_meld_at_48h" ) - liver_sofa3 = pl.col("bili").forward_fill(1).backward_fill(1) > 6.0 - liver_sofa3_at_48h = eep_label(liver_sofa3, 48).alias("liver_sofa3_at_48h") + # Different to the meld above, we clip from below by 0.1 instead of 1.0. + meld_score = ( + 3.78 * bili.clip(0.1, None).log() + + 11.2 * inr_pt.clip(0.1, None).log() + + 9.57 * crea.clip(0.1, None).log() + + 6.43 + ) + meld_score_in_24h = meld_score.shift(-24).alias("meld_score_in_24h") + # log_lactate_in_4h # log(lactate) in 4 hours. This is 1/2 the forecast horizon of circ. failure eep. 
log_lactate_in_4h = ( (pl.col("lact") + 0.1).log().shift(-4).alias("log_lactate_in_4h") ) - log_pf_ratio_in_12h = ( - pl.col("pf_ratio").log().shift(-12).alias("log_pf_ratio_in_12h") + # log_creatinine_in_24h + log_creatinine_in_24h = ( + (pl.col("crea") + 0.1).log().shift(-24).alias("log_creatinine_in_24h") ) return [ mortality_at_24h, decompensation_at_24h, - resp_failure_at_24h, - respiratory_failure_at_24h_severe_imputed, + respiratory_failure_at_24h, + severe_respiratory_failure_at_24h, remaining_los, circulatory_failure_at_8h, - circulatory_failure_at_8h_imputed, kidney_failure_at_48h, - kidney_failure_at_48h_imputed, - hyperglycemia_at_8h, - hypoglycemia_at_8h, severe_meld_at_48h, - liver_sofa3_at_48h, + meld_score_in_24h, log_lactate_in_4h, - log_pf_ratio_in_12h, + log_creatinine_in_24h, ] @@ -731,8 +709,10 @@ def main(dataset: str, data_dir: str | Path): # noqa D if "patient_id" not in sta.collect_schema().names(): sta = sta.with_columns(pl.col("stay_id").alias("patient_id")) + if "hospital_id" not in sta.collect_schema().names(): + sta = sta.with_columns(pl.lit(0).alias("hospital_id")) - dyn = dyn.join(sta, on="stay_id", how="full", coalesce=True, validate="m:1") + dyn = dyn.join(sta, on="stay_id", how="left", coalesce=True, validate="m:1") dyn = dyn.with_columns( (pl.col("time").dt.total_hours()).cast(pl.Int32).alias("time_hours") ) @@ -769,7 +749,7 @@ def main(dataset: str, data_dir: str | Path): # noqa D dyn = dyn.join(time_ranges, on=["stay_id", "time_hours"], how="full", coalesce=True) dyn = dyn.sort(["stay_id", "time_hours"]) - expressions = ["time_hours", "anchoryear", "carevue", "metavision", "patient_id"] + expressions = ["time_hours"] for row in variables.rows(named=True): tag = row["VariableTag"] @@ -809,15 +789,36 @@ def main(dataset: str, data_dir: str | Path): # noqa D dyn = dyn.with_columns(col) - expressions += additional_variables() + outcomes() - - q = dyn.group_by("stay_id").agg(expressions).explode(pl.exclude("stay_id")) + expressions 
+= additional_variables() + outcomes(dataset=dataset) + + dyn = dyn.group_by("stay_id").agg(expressions).explode(pl.exclude("stay_id")) + dyn = dyn.join( + sta.select( + [ + "stay_id", + "year", + "carevue", + "metavision", + "patient_id", + "adm", + "insurance", + "ward", + "hospital_id", + "icd10_blocks", + ] + ), + on="stay_id", + how="left", + coalesce=True, + validate="m:1", + ) - # These hashes are useful for subsetting. .hash() returns a u64-int. By normalizing - # with 2**64, we get a pseudo-random float between 0 and 1. - q = q.with_columns( + dyn = dyn.with_columns( + # These hashes are useful for subsetting. .hash() returns a u64-int. By + # normalizing with 2**64, we get a pseudo-random float between 0 and 1. (pl.col("stay_id").hash() / 2.0**64).alias("stay_id_hash"), (pl.col("patient_id").hash() / 2.0**64).alias("patient_id_hash"), + pl.col("time_hours").log1p().alias("log_time_hours"), ).with_columns( pl.when(pl.col("patient_id_hash") < 0.7) .then(pl.lit("train")) @@ -831,22 +832,16 @@ def main(dataset: str, data_dir: str | Path): # noqa D ) feature_names = set(features()) - schema_names = set(q.collect_schema().keys()) + schema_names = set(dyn.collect_schema().keys()) missing_features = feature_names - schema_names if missing_features: raise ValueError(f"Missing features: {missing_features}") tic = perf_counter() - out = q.collect() + dyn.sink_parquet(data_dir / dataset / "features.parquet", engine="streaming") toc = perf_counter() logger.info(f"Time to compute features: {toc - tic:.2f}s") - logger.info(f"out.shape: {out.shape}") - - tic = perf_counter() - out.write_parquet(data_dir / dataset / "features.parquet") - toc = perf_counter() - logger.info(f"Time to write features: {toc - tic:.2f}s") if __name__ == "__main__": diff --git a/icu_features/icd_codes.py b/icu_features/icd_codes.py new file mode 100644 index 0000000..f1f880d --- /dev/null +++ b/icu_features/icd_codes.py @@ -0,0 +1,84 @@ +from pathlib import Path + +import click +import polars as 
pl +from icdmappings import Mapper + +datasets = [ + "miiv", + "aumc", + "mimic", + "eicu", + "zigong", + "picdb", + "hirid", + "sic", + "nwicu", + "mimic-carevue", +] + +mapper = Mapper() + + +def icd10_blocks(x): + """Map from icd 10 strings to ~130 Blocks of ICD-10 CM codes.""" + return mapper.map(x, source="icd10", target="block") or "" + + +def icd9_blocks(x): + """Map from icd 9 strings to ~130 Blocks of ICD-10 CM codes.""" + icd10code = mapper.map(x, source="icd9", target="icd10") + if x is not None and icd10code is None: + icd10code = mapper.map(x.replace(".", ""), source="icd9", target="icd10") + if icd10code is None: + return "" + return mapper.map(icd10code, source="icd10", target="block") or "" + + +@click.command() +@click.option("--data_dir", type=click.Path(exists=True)) +@click.option("--dataset", type=str, default="eicu_demo") +def main(data_dir: str, dataset: str): # noqa D + sta = pl.scan_parquet(Path(data_dir) / dataset / "sta.parquet") + sta = sta.collect() + + if dataset in ["eicu", "eicu_demo"]: + # For eicu, we don't extract as list, but as a string of comma-separated values. 
+ sta = sta.with_columns( + pl.col("icd9_diagnosis").map_elements( + lambda s: s.map_elements( + lambda ls: ls.split(", ") if ls is not None else [], + skip_nulls=False, + return_dtype=pl.List(pl.String), + ), + return_dtype=pl.List(pl.List(pl.String)), + ) + ) + + sta = sta.with_columns( + pl.col("icd10_diagnosis").cast(pl.List(pl.String)).fill_null([]), + pl.col("icd9_diagnosis").cast(pl.List(pl.String)).fill_null([]), + ) + sta = sta.with_columns( + pl.col("icd10_diagnosis") + .map_elements( + lambda s: sorted(z for z in {icd10_blocks(x) for x in s} if z != ""), + return_dtype=pl.List(pl.String), + ) + .alias("icd10_blocks"), + pl.col("icd9_diagnosis") + .map_elements( + lambda s: sorted(z for z in {icd9_blocks(x) for x in s} if z != ""), + return_dtype=pl.List(pl.String), + ) + .alias("icd9_blocks"), + ) + + sta = sta.with_columns( + pl.concat_list("icd10_blocks", "icd9_blocks").alias("icd10_blocks"), + ) + sta.write_parquet(Path(data_dir) / dataset / "sta.parquet") + + +if __name__ == "__main__": + main() diff --git a/icu_features/load.py b/icu_features/load.py index b3a794e..d27bf54 100644 --- a/icu_features/load.py +++ b/icu_features/load.py @@ -64,7 +64,7 @@ def load( `icu_features.constants.TREATMENT_CONTINUOUS_FEATURES` are loaded. horizons : list of int, optional, default = None The horizons for which to load features. If `None`, all horizons - `icu_benchmarks.constants.HORIZONS` are loaded. + `icu_features.constants.HORIZONS` are loaded. other_columns : list of str, optional, default = None Other columns to load. E.g., `["stay_id_hash"]`. 
@@ -106,7 +106,7 @@ def load( if other_columns is None: other_columns = [] - columns_to_load = columns + [outcome, "dataset", "stay_id_hash"] + columns_to_load = sorted(set(columns + [outcome, "dataset", "stay_id_hash"])) if "time_hours" not in columns: columns_to_load += ["time_hours"] @@ -122,9 +122,7 @@ def load( y = df[outcome].to_numpy() assert np.isnan(y).sum() == 0 - return (df.select(columns), y) + tuple( - df.select(c).to_series() for c in other_columns - ) + return df.select(columns), y, df.select(other_columns) def features( @@ -159,7 +157,7 @@ def features( `icu_features.constants.TREATMENT_CONTINUOUS_FEATURES` are loaded. horizons : list of int, optional, default = None The horizons for which to load features. If `None`, all horizons - `icu_benchmarks.constants.HORIZONS` are loaded. + `icu_features.constants.HORIZONS` are loaded. Returns ------- diff --git a/icu_features/split_datasets.py b/icu_features/split_datasets.py new file mode 100644 index 0000000..05839f1 --- /dev/null +++ b/icu_features/split_datasets.py @@ -0,0 +1,78 @@ +from pathlib import Path + +import click +import polars as pl + + +@click.command() +@click.option("--data_dir", type=click.Path(exists=True)) +def main(data_dir): + """ + Split mimic, miiv, and aumc datasets into subsets. + + These are: + - mimic -> mimic-carevue and mimic-metavision. + - miiv -> miiv-late. + - aumc -> aumc-early and aumc-late. + + The mimic hospital used the Philips Carevue EHR system until 2008 and the iMDsoft + Metavision ICU system from 2008 onwards. The data format differs between these two + EHR systems. We split the mimic dataset into mimic-carevue and mimic-metavision + subsets to allow for separate analysis of these two EHR systems. The miiv (mimic-iv) + dataset contains data from the mimic hospital after the switch to the Metavision EHR + system. There is an overlap between miiv and mimic-metavision. This cannot be + identified by stay_id due to anonymization. 
We thus filter out all "early" stays + from the miiv dataset to create the miiv-late dataset with no overlap with + mimic-metavision. + + The aumc dataset contains data from the Amsterdam UMC hospital from 2003 - 2016. + The year has been anonymized and grouped into two periods: around 2006 and around + 2013. We split the aumc dataset into aumc-early and aumc-late subsets to allow for + separate analysis of these two periods. + + Parameters + ---------- + data_dir : str + Path to the data directory. We expect input data at, e.g., + `data_dir/mimic/sta.parquet` and `data_dir/mimic/dyn.parquet`. We write the + output data to `data_dir/mimic-carevue/sta.parquet`, etc. + """ + for source, target, filter_ in [ + ("mimic", "mimic-carevue", pl.col("carevue") & pl.col("metavision").is_null()), + ( + "mimic", + "mimic-metavision", + pl.col("metavision") & pl.col("carevue").is_null(), + ), + ( + "mimic_demo", + "mimic_demo-carevue", + pl.col("carevue") & pl.col("metavision").is_null(), + ), + ( + "mimic_demo", + "mimic_demo-metavision", + pl.col("metavision") & pl.col("carevue").is_null(), + ), + ("miiv", "miiv-late", pl.col("year") > 2012), + ("aumc", "aumc-early", pl.col("year") == 2006), + ("aumc", "aumc-late", pl.col("year") == 2013), + ]: + input_path = Path(data_dir) / source + if not input_path.exists(): + print(f"Skipping {source} as it does not exist in {data_dir}.") + continue + + sta = pl.scan_parquet(Path(data_dir) / source / "sta.parquet").filter(filter_) + sta = sta.collect() + + dyn = pl.scan_parquet(Path(data_dir) / source / "dyn.parquet") + dyn = dyn.filter(pl.col("stay_id").is_in(sta["stay_id"])).collect() + + (Path(data_dir) / target).mkdir(parents=True, exist_ok=True) + sta.write_parquet(Path(data_dir) / target / "sta.parquet") + dyn.write_parquet(Path(data_dir) / target / "dyn.parquet") + + +if __name__ == "__main__": + main() diff --git a/resources/variables.tsv b/resources/variables.tsv index 92f14a5..ea63065 100644 --- a/resources/variables.tsv +++ 
b/resources/variables.tsv @@ -31,7 +31,7 @@ VariableID VariableTag VariableName VariableType OrganSystem TreatmentGroups Uni 30 crp C-reactive protein observation infection None mg/L continuous 0.0 2000.0 Inflammation marker true 0.01 31 dbp Diastolic blood pressure observation circulatory None mmHg continuous 0.0 200.0 Diastolic blood pressure, blood pressure when heart rests false None 32 fgn Fibrinogen observation circulatory None mg/dL continuous 0.0 1500.0 None false None -33 glu Glucose observation metabolic_renal None mg/dL continuous 0.0 1000.0 None false None +33 glu Glucose observation metabolic_renal None mg/dL continuous 0.0 1000.0 None true 5.0 34 hgb Hemoglobin observation circulatory None g/dL continuous 4.0 18.0 Oxygen carrying protein false None 35 inr_pt Prothrombin observation circulatory None None continuous 0.0 50.0 Prothrombin time/international normalized ratio true 0.01 36 k Potassium observation metabolic_renal None mmol/l continuous 0.0 10.0 None false None @@ -61,7 +61,6 @@ VariableID VariableTag VariableName VariableType OrganSystem TreatmentGroups Uni 63 rass Richmond Agitation Sedation Scale observation neuro None None continuous -5.0 4.0 Scale used to measure the agitation or sedation level of a person false None 64 hbco Carboxyhemoglobin observation circulatory None None continuous 0.0 100.0 Hemoglobin in which the sites usually bound to oxygen are bound to carbon monoxide. true 0.1 65 esr Erythrocyte sedimentation rate observation infection None None continuous 0.0 200.0 Erythrocyte sedimentation rate true 0.1 -68 adm Patient admission type static demographic None None categorical None None Patient admission type false None 69 hba1c Hemoglobin A1C observation metabolic_renal None % continuous 0.0 100.0 A1C is a blood test for type 2 diabetes and prediabetes. 
true 0.01 71 samp Body fluid sampling, detected bacterial growth observation infection None None categorical None None Body fluid sampling, testing for bacterial infection, True/False/Missing false None 72 spo2 Pulse Oxymetry Oxygen Saturation observation respiratory None % continuous 50.0 100.0 Pulse oxymetry measured blood oxygen saturation false None diff --git a/tests/test_load.py b/tests/test_load.py index c0ec0dd..d6ffd01 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -13,7 +13,9 @@ ["mortality_at_24h", "circulatory_failure_at_8h", "respiratory_failure_at_24h"], ) def test_eicu_demo(outcome): - df, y = load(["eicu_demo"], outcome=outcome, data_dir=DATA_DIR) + df, y, _ = load( + ["eicu_demo", "mimic_demo-carevue"], outcome=outcome, data_dir=DATA_DIR + ) assert np.isnan(y).sum() == 0 assert len(df) == len(y) > 0 assert sorted(df.columns) == sorted(features()) diff --git a/tests/test_outcomes.py b/tests/test_outcomes.py index 1f81ec6..ec19973 100644 --- a/tests/test_outcomes.py +++ b/tests/test_outcomes.py @@ -124,7 +124,7 @@ def test_polars_nan_or(args, expected): ], ) def test_outcomes(outcome_name, input, expected): - expr = [e for e in outcomes() if e.meta.output_name() == outcome_name][0] + expr = [e for e in outcomes("eicu") if e.meta.output_name() == outcome_name][0] assert_series_equal( input.with_columns(expr).select(outcome_name).to_series(), expected.rename(outcome_name), @@ -134,16 +134,6 @@ def test_outcomes(outcome_name, input, expected): @pytest.mark.parametrize( "outcome_name, input, expected_events, expected_labels", [ - ( - "respiratory_failure_at_24h", - pl.DataFrame( - { - "pf_ratio": [None] * 36 + [100] * 12 + [None] * 24 + [500] * 12, - } - ), - pl.Series([None] * 36 + [True] * 12 + [None] * 24 + [False] * 12), - pl.Series([None] * 12 + [True] * 24 + [None] * 12 + [False] * 35 + [None]), - ), ( "circulatory_failure_at_8h", pl.DataFrame( @@ -209,7 +199,7 @@ def test_outcomes(outcome_name, input, expected): ], ) def 
test_eep_outcomes(outcome_name, input, expected_events, expected_labels): - expr = [e for e in outcomes() if e.meta.output_name() == outcome_name][0] + expr = [e for e in outcomes("eicu") if e.meta.output_name() == outcome_name][0] assert_series_equal( input.with_columns(expr).select(outcome_name).to_series(), diff --git a/tests/testdata/eicu_demo/dyn.parquet b/tests/testdata/eicu_demo/dyn.parquet index 6f6e620..0e1beab 100644 --- a/tests/testdata/eicu_demo/dyn.parquet +++ b/tests/testdata/eicu_demo/dyn.parquet @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:b0f721728c4836c15f8c7b638e22d44e56c6b945a05d573410e43f4ba00f627b -size 2256898 +oid sha256:1d4cbf01a2c17dc32814fc8df3b119c676d736f4158c47c2ee88b81a2f0e8f77 +size 2252914 diff --git a/tests/testdata/eicu_demo/hospital.parquet b/tests/testdata/eicu_demo/hospital.parquet new file mode 100644 index 0000000..a96c583 --- /dev/null +++ b/tests/testdata/eicu_demo/hospital.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:203c21ace7b2d8ca0bf9882e1b530eea49b035845dbc86650f61398e76563ce3 +size 2905 diff --git a/tests/testdata/eicu_demo/sta.parquet b/tests/testdata/eicu_demo/sta.parquet index 78dcfd9..93511c9 100644 --- a/tests/testdata/eicu_demo/sta.parquet +++ b/tests/testdata/eicu_demo/sta.parquet @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:cc400e3bab39ca0a1a6fd4ec73040eb13024260915c171bb67bc7cfef861d629 -size 68736 +oid sha256:5d86e836864335842b38e24030268bf311625d3dc9e7c8d88b14c7bf5c798cb9 +size 68334 diff --git a/tests/testdata/mimic_demo/dyn.parquet b/tests/testdata/mimic_demo/dyn.parquet new file mode 100644 index 0000000..a97f594 Binary files /dev/null and b/tests/testdata/mimic_demo/dyn.parquet differ diff --git a/tests/testdata/mimic_demo/sta.parquet b/tests/testdata/mimic_demo/sta.parquet new file mode 100644 index 0000000..81997e1 Binary files /dev/null and b/tests/testdata/mimic_demo/sta.parquet differ