66 changes: 63 additions & 3 deletions icu_benchmarks/data/loader.py
@@ -47,6 +47,14 @@ def __init__(
logging.info("Using static dataset")
self.row_indicators = data[split][DataSegment.features][self.vars["GROUP"]]
self.features_df = data[split][DataSegment.features]

        # Order columns: group ID first, then features (alphabetical), then missing indicators (alphabetical)
cols = self.features_df.columns
m_index = [self.vars["GROUP"]]
front = sorted([c for c in cols if not c.startswith("MissingIndicator_") and c not in m_index])
back = sorted([c for c in cols if c.startswith("MissingIndicator_") and c not in m_index])
self.features_df = self.features_df[m_index + front + back]

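The reordering above makes the feature-matrix column layout deterministic across runs and datasets. A minimal sketch of the same rule on a toy frame (the column names `hr`, `abpm`, `stay_id` are made up, not from the repo):

```python
# Toy illustration of the deterministic column ordering (hypothetical columns).
import polars as pl

df = pl.DataFrame({"hr": [80], "MissingIndicator_hr": [0], "stay_id": [1], "abpm": [70]})
group = "stay_id"
front = sorted(c for c in df.columns if not c.startswith("MissingIndicator_") and c != group)
back = sorted(c for c in df.columns if c.startswith("MissingIndicator_") and c != group)
assert df.select([group] + front + back).columns == ["stay_id", "abpm", "hr", "MissingIndicator_hr"]
```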
# calculate basic info for the data
self.num_stays = self.grouping_df[self.vars["GROUP"]].unique().shape[0]
self.maxlen = self.features_df.group_by([self.vars["GROUP"]]).len().max().item(0, 1)
@@ -102,7 +110,7 @@ def __getitem__(self, idx: int) -> tuple[Tensor, Tensor, Tensor]:
A sample from the data, consisting of data, labels and padding mask.
"""
if self._cached_dataset is not None:
-            return self._cached_dataset[idx]
+            return tuple(element[idx] for element in self._cached_dataset)

pad_value = 0.0
stay_id = self.outcome_df[self.vars["GROUP"]].unique()[idx] # [self.vars["GROUP"]]
@@ -115,7 +123,7 @@ def __getitem__(self, idx: int) -> tuple[Tensor, Tensor, Tensor]:

if len(labels) == 1:
# only one label per stay, align with window
-            labels = np.concatenate([np.empty(window.shape[0] - 1) * np.nan, labels], axis=0)
+            labels = np.concatenate([np.full((window.shape[0] - 1, len(self.vars["LABEL"])), np.nan), labels], axis=0)

length_diff = self.maxlen - window.shape[0]
pad_mask = np.ones(window.shape[0])
@@ -124,7 +132,7 @@ def __getitem__(self, idx: int) -> tuple[Tensor, Tensor, Tensor]:
if length_diff > 0:
# window shorter than the longest window in dataset, pad to same length
window = np.concatenate([window, np.ones((length_diff, window.shape[1])) * pad_value], axis=0)
-            labels = np.concatenate([labels, np.ones(length_diff) * pad_value], axis=0)
+            labels = np.concatenate([labels, np.ones((length_diff, len(self.vars["LABEL"]))) * pad_value], axis=0)
pad_mask = np.concatenate([pad_mask, np.zeros(length_diff)], axis=0)

not_labeled = np.argwhere(np.isnan(labels))
@@ -138,6 +146,58 @@ def __getitem__(self, idx: int) -> tuple[Tensor, Tensor, Tensor]:

return from_numpy(data), from_numpy(labels), from_numpy(pad_mask)
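Together with the two label fixes above, `labels` is now consistently 2-D with shape `(timesteps, len(self.vars["LABEL"]))` rather than 1-D, which is what makes multi-label outcomes work. A standalone sketch of the single-label alignment (shapes and values here are assumptions):

```python
# Standalone sketch of aligning one label per stay with the time window (assumed shapes).
import numpy as np

window = np.zeros((5, 3))   # 5 timesteps, 3 features
labels = np.array([[1.0]])  # one label for the whole stay, shape (1, n_labels)
n_labels = labels.shape[1]
# Prepend NaNs so only the final timestep carries the label.
aligned = np.concatenate([np.full((window.shape[0] - 1, n_labels), np.nan), labels], axis=0)
assert aligned.shape == (5, 1)
assert np.isnan(aligned[:-1]).all() and aligned[-1, 0] == 1.0
```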

def ram_cache(self, cache: bool = True) -> None:
"""Prepares a in-memory cache of the data, transforms the DataFrames to padded Tensors.
saves (padded_features, padded_labels, pad_mask) in `self._cached_dataset`
"""
self._cached_dataset = None
if cache:
logging.info(f"Caching {self.split} dataset in ram.")

            # Sort both frames first so per-stay lengths, offsets, and labels
            # all refer to the same (sorted) group order
            self.features_df = self.features_df.sort(self.vars["GROUP"])
            self.outcome_df = self.outcome_df.sort(self.vars["GROUP"])

            # Get per-stay lengths in sorted group order
            lengths = self.features_df.group_by(self.vars["GROUP"], maintain_order=True).len()["len"].to_numpy().astype(int)
            offsets = np.concatenate([[0], lengths.cumsum()[:-1]])

            n_stays = self.num_stays
            n_features = len(self.features_df.columns) - 1  # exclude GROUP
            n_labels = len(self.vars["LABEL"])

            # Extract full arrays once
            data_np = self.features_df.select(pl.exclude(self.vars["GROUP"])).to_numpy().astype(np.float32)
            labels_np = self.outcome_df[self.vars["LABEL"]].to_numpy().astype(np.float32)

single_label_per_stay = self.outcome_df.shape[0] == n_stays

# Pre-allocate
padded_data = np.zeros((n_stays, self.maxlen, n_features), dtype=np.float32)
padded_labels = np.zeros((n_stays, self.maxlen, n_labels), dtype=np.float32)
pad_mask = np.zeros((n_stays, self.maxlen), dtype=bool)

for i, (offset, length) in enumerate(zip(offsets, lengths, strict=True)):
padded_data[i, :length] = data_np[offset : offset + length]
if single_label_per_stay:
# mirror __getitem__: all NaN except final timestep
stay_labels = np.full((length, n_labels), np.nan, dtype=np.float32)
stay_labels[-1] = labels_np[i]
padded_labels[i, :length] = stay_labels
else:
padded_labels[i, :length] = labels_np[offset : offset + length]
pad_mask[i, :length] = True

# Replace nan labels with -1 and mask them out (mirrors __getitem__)
nan_mask = np.isnan(padded_labels) # (n_stays, maxlen, n_labels)
padded_labels[nan_mask] = -1
# If any label is nan at a timestep, zero the mask
pad_mask[nan_mask.any(axis=-1)] = False

self._cached_dataset = (
from_numpy(padded_data),
from_numpy(padded_labels),
from_numpy(pad_mask),
)
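Worth keeping in mind: the cache materializes three dense tensors of shape `(n_stays, maxlen, ...)`, so memory scales with the longest stay rather than the total number of rows. A rough, hypothetical estimate (the helper and the numbers are illustrative, not from the repo):

```python
# Hypothetical back-of-envelope estimate of the RAM cache size (illustrative only).
def cache_size_mb(n_stays: int, maxlen: int, n_features: int, n_labels: int) -> float:
    floats = n_stays * maxlen * (n_features + n_labels) * 4  # float32 data + labels
    mask = n_stays * maxlen * 1                              # bool pad mask
    return (floats + mask) / 1e6

print(cache_size_mb(n_stays=10_000, maxlen=336, n_features=50, n_labels=1))  # ~689 MB
```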

def get_balance(self) -> list:
"""Return the weight balance for the split of interest.

4 changes: 2 additions & 2 deletions icu_benchmarks/data/preprocessor.py
@@ -312,7 +312,7 @@ def _process_outcome(
outcome_rec.add_step(
StepSklearn(
sklearn_transformer=FunctionTransformer(
-                    func=lambda x: ((x - self.outcome_min) / (self.outcome_max - self.outcome_min))
+                    func=lambda x: (x - self.outcome_min) / (self.outcome_max - self.outcome_min)
),
sel=all_outcomes(),
)
@@ -528,7 +528,7 @@ def _process_outcome(self, data, vars, split):
outcome_rec.add_step(
StepSklearn(
sklearn_transformer=FunctionTransformer(
-                    func=lambda x: ((x - self.outcome_min) / (self.outcome_max - self.outcome_min))
+                    func=lambda x: (x - self.outcome_min) / (self.outcome_max - self.outcome_min)
),
sel=all_outcomes(),
)
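Both hunks in this file only drop a redundant outer pair of parentheses around the min-max expression; behavior is unchanged. For reference, a minimal standalone version of the same scaling (the min/max values are made up):

```python
# Minimal sketch of the outcome min-max scaling (assumed values, not repo config).
import numpy as np
from sklearn.preprocessing import FunctionTransformer

outcome_min, outcome_max = 0.0, 200.0
scaler = FunctionTransformer(func=lambda x: (x - outcome_min) / (outcome_max - outcome_min))
print(scaler.fit_transform(np.array([[0.0], [100.0], [200.0]])))  # [[0.], [0.5], [1.]]
```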
35 changes: 21 additions & 14 deletions icu_benchmarks/data/split_process_data.py
@@ -173,7 +173,7 @@ def preprocess_data(
)
else:
# If full train is set, we use all data for training/validation
-        sanitized_data = make_train_val_polars(data, vars, train_size=None, seed=seed, debug=debug, runmode=runmode)
+        sanitized_data = make_train_val_polars(data, vars, train_size=train_size, seed=seed, debug=debug, runmode=runmode)

# Apply preprocessing
start = timer()
@@ -191,7 +191,7 @@ def preprocess_data(
sel = _dict[key].select(pl.all().has_nulls())
logging.debug(sel.select(col.name for col in sel if col.item(0)))
_dict[key] = val.fill_null(strategy="zero")
-                _dict[key] = val.fill_nan(0)
+                _dict[key] = _dict[key].fill_nan(0)
logging.debug("Dropping columns with nulls")
sel = _dict[key].select(pl.all().has_nulls())
logging.debug(sel.select(col.name for col in sel if col.item(0)))
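The `fill_nan` fix above matters because `val` is the frame as it was before the preceding `fill_null` call; chaining `fill_nan` off `val` silently discarded the null fill. A toy reproduction of the pattern (data is made up):

```python
# Toy reproduction: fill_nan must chain off the fill_null result, not the original frame.
import polars as pl

val = pl.DataFrame({"x": [1.0, None, float("nan")]})
bad = val.fill_nan(0)                              # NaN -> 0, but nulls survive
good = val.fill_null(strategy="zero").fill_nan(0)  # nulls -> 0, then NaN -> 0
print(bad["x"].to_list())   # [1.0, None, 0.0]
print(good["x"].to_list())  # [1.0, 0.0, 0.0]
```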
@@ -368,10 +368,14 @@ def make_train_val_polars(
)

if debug:
logging.info("Using only 1% of the data for debugging. Note that this might lead to errors for small datasets.")
data[DataSegment.outcome] = data[DataSegment.outcome].sample(fraction=0.01, seed=seed)
logging.info("Using only 1% of stay_id for debugging. Note that this might lead to errors for small datasets.")
sampled_ids = data[DataSegment.outcome][_id].unique().sample(fraction=0.01, seed=seed)
data[DataSegment.outcome] = data[DataSegment.outcome].filter(pl.col(_id).is_in(sampled_ids))
data[DataSegment.dynamic] = data[DataSegment.dynamic].filter(pl.col(_id).is_in(sampled_ids))
if DataSegment.static in data:
data[DataSegment.static] = data[DataSegment.static].filter(pl.col(_id).is_in(sampled_ids))

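Sampling whole stay IDs instead of 1% of outcome rows keeps every retained stay intact across the outcome, dynamic, and static segments, so downstream grouping still sees complete stays. A condensed sketch of the pattern (toy frames, assumed column names):

```python
# Condensed sketch of ID-level debug subsampling (toy data, not repo code).
import polars as pl

outcome = pl.DataFrame({"stay_id": [1, 1, 2, 3], "label": [0, 1, 0, 1]})
dynamic = pl.DataFrame({"stay_id": [1, 1, 2, 2, 3], "hr": [80, 82, 70, 71, 90]})

sampled_ids = outcome["stay_id"].unique().sample(fraction=0.5, seed=42)
outcome_small = outcome.filter(pl.col("stay_id").is_in(sampled_ids))
dynamic_small = dynamic.filter(pl.col("stay_id").is_in(sampled_ids))
# Every stay kept in outcome_small still has all of its rows in dynamic_small.
```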
-    stays = pl.Series(name=_id, values=data[DataSegment.outcome][_id].unique())
+    stays = pl.Series(name=_id, values=data[DataSegment.outcome][_id].unique().sort())

if VarType.label in vars and runmode is RunMode.classification:
labels = data[DataSegment.outcome].group_by(_id).max()[label]
@@ -391,7 +395,7 @@ def make_train_val_polars(
for fold in split.keys():
data_split[fold] = {
data_type: split[fold]
-            .join(data[data_type].with_columns(pl.col(_id).cast(pl.datatypes.Int64)), on=_id, how="left")
+            .join(data[data_type].with_columns(pl.col(_id).cast(pl.datatypes.Int64)), on=_id, how="inner")
.sort(by=_id)
for data_type in data.keys()
}
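The join change from `left` to `inner` means the fold membership no longer fabricates all-null rows for IDs that are missing from a segment; such IDs are simply dropped. A toy illustration (frames are made up):

```python
# Toy illustration of the left -> inner join change (assumed data, not repo code).
import polars as pl

fold = pl.DataFrame({"stay_id": [1, 2]})
static = pl.DataFrame({"stay_id": [1], "age": [63]})

print(fold.join(static, on="stay_id", how="left"))   # stay_id 2 kept, age is null
print(fold.join(static, on="stay_id", how="inner"))  # stay_id 2 dropped
```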
@@ -535,18 +539,21 @@ def make_single_split_polars(
    For more detailed documentation, refer to make_single_splits(...)
"""
# ID variable
-    id = vars[VarType.group]
+    _id = vars[VarType.group]
if debug:
# Only use 1% of the data
logging.info("Using only 1% of the data for debugging. Note that this might lead to errors for small datasets.")
data[DataSegment.outcome] = data[DataSegment.outcome].sample(fraction=0.01, seed=seed)
logging.info("Using only stay_id's of the data for debugging. Note that this might lead to errors for small datasets.")
sampled_ids = data[DataSegment.outcome][_id].unique().sample(fraction=0.01, seed=seed)
data[DataSegment.outcome] = data[DataSegment.outcome].filter(pl.col(_id).is_in(sampled_ids))
data[DataSegment.dynamic] = data[DataSegment.dynamic].filter(pl.col(_id).is_in(sampled_ids))
if DataSegment.static in data:
data[DataSegment.static] = data[DataSegment.static].filter(pl.col(_id).is_in(sampled_ids))

# Get stay IDs from outcome segment
-    stays = pl.Series(name=id, values=data[DataSegment.outcome][id].unique())
+    stays = pl.Series(name=_id, values=data[DataSegment.outcome][_id].unique().sort())
# If there are labels, and the task is classification, use stratified k-fold
if VarType.label in vars and runmode is RunMode.classification:
        # Get labels from outcome data (takes the highest value, or True, in the case of seq2seq classification)
-        labels: pl.Series = data[DataSegment.outcome].group_by(id).max().sort(id)[vars[VarType.label]]
+        labels: pl.Series = data[DataSegment.outcome].group_by(_id).max().sort(_id)[vars[VarType.label]]
if labels.value_counts().min().item(0, 1) < cv_folds:
raise Exception(
f"The smallest amount of samples in a class is: {labels.value_counts().min()}, "
@@ -586,8 +593,8 @@ def make_single_split_polars(
# set sort to true to make sure that IDs are reordered after scrambling earlier
data_split[fold] = {
data_type: split[fold]
-            .join(data[data_type].with_columns(pl.col(id).cast(pl.datatypes.Int64)), on=id, how="left")
-            .sort(by=id)
+            .join(data[data_type].with_columns(pl.col(_id).cast(pl.datatypes.Int64)), on=_id, how="inner")
+            .sort(by=_id)
for data_type in data.keys()
}
