Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 52 additions & 33 deletions icu_benchmarks/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,28 +90,35 @@ class PredictionPolarsDataset(CommonPolarsDataset):
def __init__(self, *args, ram_cache: bool = True, **kwargs):
super().__init__(*args, **kwargs)
self.outcome_df = self.grouping_df
self.ram_cache(ram_cache)

def __getitem__(self, idx: int) -> tuple[Tensor, Tensor, Tensor]:
"""Function to sample from the data split of choice. Used for deep learning implementations.

Args:
idx: A specific row index to sample.
# One-pass partition keyed by stay id. Avoids the O(N_stays) full-frame filter
# that previously ran inside every __getitem__.
GROUP = self.vars["GROUP"]
LABEL = self.vars["LABEL"]
label_partitions = self.outcome_df.partition_by(GROUP, maintain_order=True)
feat_partitions = self.features_df.partition_by(GROUP, maintain_order=True)
self._stay_order = [part[GROUP][0] for part in label_partitions]
self._feat_arrays = {
part[GROUP][0]: part.select(pl.exclude(GROUP)).to_numpy().astype(np.float32)
for part in feat_partitions
}
for arr in self._feat_arrays.values():
arr.setflags(write=False)

self._label_arrays = {
part[GROUP][0]: part[LABEL].to_numpy().astype(np.float32) for part in label_partitions
}
for arr in self._label_arrays.values():
arr.setflags(write=False)

Returns:
A sample from the data, consisting of data, labels and padding mask.
"""
if self._cached_dataset is not None:
return self._cached_dataset[idx]
self.ram_cache(ram_cache)

def _build_item(self, idx: int) -> tuple[Tensor, Tensor, Tensor]:
pad_value = 0.0
stay_id = self.outcome_df[self.vars["GROUP"]].unique()[idx] # [self.vars["GROUP"]]
stay_id = self._stay_order[idx]

# slice to make sure to always return a DF
window = (
self.features_df.filter(pl.col(self.vars["GROUP"]) == stay_id).select(pl.exclude(self.vars["GROUP"])).to_numpy()
)
labels = self.outcome_df.filter(pl.col(self.vars["GROUP"]) == stay_id)[self.vars["LABEL"]].to_numpy().astype(float)
window = self._feat_arrays[stay_id]
labels = self._label_arrays[stay_id].copy() # copy to avoid in-place NaN replacement

if len(labels) == 1:
# only one label per stay, align with window
Expand All @@ -122,7 +129,6 @@ def __getitem__(self, idx: int) -> tuple[Tensor, Tensor, Tensor]:

# Padding the array to fulfill size requirement
if length_diff > 0:
# window shorter than the longest window in dataset, pad to same length
window = np.concatenate([window, np.ones((length_diff, window.shape[1])) * pad_value], axis=0)
labels = np.concatenate([labels, np.ones(length_diff) * pad_value], axis=0)
pad_mask = np.concatenate([pad_mask, np.zeros(length_diff)], axis=0)
Expand All @@ -138,6 +144,19 @@ def __getitem__(self, idx: int) -> tuple[Tensor, Tensor, Tensor]:

return from_numpy(data), from_numpy(labels), from_numpy(pad_mask)

def __getitem__(self, idx: int) -> tuple[Tensor, Tensor, Tensor]:
    """Function to sample from the data split of choice. Used for deep learning implementations.

    Serves the sample from the RAM cache when one has been materialized
    (``self._cached_dataset`` is not ``None``); otherwise constructs the
    sample on demand via ``_build_item``.

    Args:
        idx: A specific row index to sample.

    Returns:
        A sample from the data, consisting of data, labels and padding mask.
    """
    if self._cached_dataset is not None:
        # Fast path: precomputed (data, labels, pad_mask) tuples stored in RAM.
        return self._cached_dataset[idx]
    return self._build_item(idx)

def get_balance(self) -> list:
"""Return the weight balance for the split of interest.

Expand All @@ -157,7 +176,7 @@ def get_data_and_labels(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
Returns:
A Tuple containing data points and label for the split.
"""
labels = self.outcome_df[self.vars["LABEL"]].to_numpy().astype(float)
labels = self.outcome_df[self.vars["LABEL"]].to_numpy().astype(np.float32)
rep = self.features_df

if len(labels) == self.num_stays:
Expand All @@ -166,17 +185,19 @@ def get_data_and_labels(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
else:
# Adding segment count for each stay id and timestep.
rep = rep.with_columns(pl.col(self.vars["GROUP"]).cum_count().over(self.vars["GROUP"]).alias("counter"))
rep = rep.to_numpy().astype(float)
rep = rep.to_numpy().astype(np.float32)
logging.debug(f"rep shape: {rep.shape}")
logging.debug(f"labels shape: {labels.shape}")
return rep, labels, self.row_indicators.to_numpy()

def to_tensor(self) -> tuple[Union[Tensor, np.ndarray], Union[Tensor, np.ndarray], Union[Tensor, np.ndarray]]:
def to_tensor(self) -> tuple[Tensor, Tensor, Tensor]:
data, labels, row_indicators = self.get_data_and_labels()
if self.mps:
return from_numpy(data).to(float32), from_numpy(labels).to(float32), from_numpy(row_indicators).to(float32)
else:
return from_numpy(data), from_numpy(labels), row_indicators

return (
from_numpy(data),
from_numpy(labels),
from_numpy(row_indicators),
)


@gin.configurable("CommonPandasDataset")
Expand Down Expand Up @@ -271,7 +292,7 @@ def __getitem__(self, idx: int) -> tuple[Tensor, Tensor, Tensor]:

# slice to make sure to always return a DF
window = self.features_df.loc[stay_id:stay_id].to_numpy()
labels = self.outcome_df.loc[stay_id:stay_id][self.vars["LABEL"]].to_numpy(dtype=float)
labels = self.outcome_df.loc[stay_id:stay_id][self.vars["LABEL"]].to_numpy(dtype=np.float32)

if len(labels) == 1:
# only one label per stay, align with window
Expand Down Expand Up @@ -315,21 +336,19 @@ def get_data_and_labels(self) -> tuple[np.ndarray, np.ndarray]:
Returns:
A Tuple containing data points and label for the split.
"""
labels = self.outcome_df[self.vars["LABEL"]].to_numpy().astype(float)
labels = self.outcome_df[self.vars["LABEL"]].to_numpy().astype(np.float32)
rep = self.features_df
if len(labels) == self.num_stays:
# order of groups could be random, we make sure not to change it
rep = rep.groupby(level=self.vars["GROUP"], sort=False).last()
rep = rep.to_numpy().astype(float)
rep = rep.to_numpy().astype(np.float32)

return rep, labels

def to_tensor(self) -> tuple[Union[Tensor, np.ndarray], Union[Tensor, np.ndarray]]:
def to_tensor(self) -> tuple[Tensor, Tensor]:
data, labels = self.get_data_and_labels()
if self.mps:
return from_numpy(data).to(float32), from_numpy(labels).to(float32)
else:
return from_numpy(data), from_numpy(labels)
# Always use float32 for memory efficiency and MPS compatibility
return from_numpy(data), from_numpy(labels)


@gin.configurable("ImputationPandasDataset")
Expand Down
10 changes: 6 additions & 4 deletions icu_benchmarks/models/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def train_common(
model: DLModel | MLModelClassifier | MLModelRegression | object = gin.REQUIRED,
weight: str = "",
optimizer: type = Adam,
precision: Optional[Literal[16] | Literal[32] | Literal[64] | Literal["16-true"]] = 32,
precision: Optional[Literal[16] | Literal[32] | Literal[64] | str] = 32,
batch_size: int = 1,
epochs: int = 100,
patience: int = 20,
Expand Down Expand Up @@ -118,6 +118,7 @@ def train_common(
num_workers=num_workers,
drop_last=True,
persistent_workers=persistent_workers,
pin_memory=not cpu,
)
val_loader = DataLoader(
val_dataset,
Expand All @@ -126,6 +127,7 @@ def train_common(
num_workers=num_workers,
drop_last=True,
persistent_workers=persistent_workers,
pin_memory=not cpu,
)

data_shape = next(iter(train_loader))[0].shape
Expand Down Expand Up @@ -154,8 +156,8 @@ def train_common(
]
if verbose:
callbacks.append(TQDMProgressBar(refresh_rate=min(100, len(train_loader) // 2)))
if precision == 16 or "16-mixed":
torch.set_float32_matmul_precision("medium")
if precision in (16, "16-mixed", "bf16", "bf16-mixed"):
torch.set_float32_matmul_precision("high")

trainer = Trainer(
max_epochs=epochs if model.requires_backprop else 1,
Expand Down Expand Up @@ -194,7 +196,7 @@ def train_common(
batch_size=min(batch_size * 4, len(test_dataset)),
shuffle=False,
num_workers=num_workers,
pin_memory=True,
pin_memory=not cpu,
drop_last=True,
persistent_workers=persistent_workers,
)
Expand Down
Loading