From 853055bd643891ca2c8116483d04145fbb267df6 Mon Sep 17 00:00:00 2001 From: prockenschaub Date: Tue, 24 Mar 2026 11:09:02 +0100 Subject: [PATCH 1/7] use float32 throughout --- icu_benchmarks/data/loader.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/icu_benchmarks/data/loader.py b/icu_benchmarks/data/loader.py index bfa944ad..8dc3952a 100644 --- a/icu_benchmarks/data/loader.py +++ b/icu_benchmarks/data/loader.py @@ -111,7 +111,7 @@ def __getitem__(self, idx: int) -> tuple[Tensor, Tensor, Tensor]: window = ( self.features_df.filter(pl.col(self.vars["GROUP"]) == stay_id).select(pl.exclude(self.vars["GROUP"])).to_numpy() ) - labels = self.outcome_df.filter(pl.col(self.vars["GROUP"]) == stay_id)[self.vars["LABEL"]].to_numpy().astype(float) + labels = self.outcome_df.filter(pl.col(self.vars["GROUP"]) == stay_id)[self.vars["LABEL"]].to_numpy().astype(np.float32) if len(labels) == 1: # only one label per stay, align with window @@ -157,7 +157,7 @@ def get_data_and_labels(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]: Returns: A Tuple containing data points and label for the split. """ - labels = self.outcome_df[self.vars["LABEL"]].to_numpy().astype(float) + labels = self.outcome_df[self.vars["LABEL"]].to_numpy().astype(np.float32) rep = self.features_df if len(labels) == self.num_stays: @@ -166,17 +166,19 @@ def get_data_and_labels(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]: else: # Adding segment count for each stay id and timestep. rep = rep.with_columns(pl.col(self.vars["GROUP"]).cum_count().over(self.vars["GROUP"]).alias("counter")) - rep = rep.to_numpy().astype(float) + rep = rep.to_numpy().astype(np.float32) logging.debug(f"rep shape: {rep.shape}") logging.debug(f"labels shape: {labels.shape}") return rep, labels, self.row_indicators.to_numpy() - def to_tensor(self) -> tuple[Union[Tensor, np.ndarray], Union[Tensor, np.ndarray], Union[Tensor, np.ndarray]]: + def to_tensor(self) -> tuple[Tensor, Tensor, Tensor]: data, labels, row_indicators = self.get_data_and_labels() - if self.mps: - return from_numpy(data).to(float32), from_numpy(labels).to(float32), from_numpy(row_indicators).to(float32) - else: - return from_numpy(data), from_numpy(labels), row_indicators + # Always use float32 for memory efficiency and MPS compatibility + return ( + from_numpy(data), + from_numpy(labels), + from_numpy(row_indicators.astype(np.float32)), + ) @gin.configurable("CommonPandasDataset") @@ -271,7 +273,7 @@ def __getitem__(self, idx: int) -> tuple[Tensor, Tensor, Tensor]: # slice to make sure to always return a DF window = self.features_df.loc[stay_id:stay_id].to_numpy() - labels = self.outcome_df.loc[stay_id:stay_id][self.vars["LABEL"]].to_numpy(dtype=float) + labels = self.outcome_df.loc[stay_id:stay_id][self.vars["LABEL"]].to_numpy(dtype=np.float32) if len(labels) == 1: # only one label per stay, align with window @@ -315,21 +317,19 @@ def get_data_and_labels(self) -> tuple[np.ndarray, np.ndarray]: Returns: A Tuple containing data points and label for the split. """ - labels = self.outcome_df[self.vars["LABEL"]].to_numpy().astype(float) + labels = self.outcome_df[self.vars["LABEL"]].to_numpy().astype(np.float32) rep = self.features_df if len(labels) == self.num_stays: # order of groups could be random, we make sure not to change it rep = rep.groupby(level=self.vars["GROUP"], sort=False).last() - rep = rep.to_numpy().astype(float) + rep = rep.to_numpy().astype(np.float32) return rep, labels - def to_tensor(self) -> tuple[Union[Tensor, np.ndarray], Union[Tensor, np.ndarray]]: + def to_tensor(self) -> tuple[Tensor, Tensor]: data, labels = self.get_data_and_labels() - if self.mps: - return from_numpy(data).to(float32), from_numpy(labels).to(float32) - else: - return from_numpy(data), from_numpy(labels) + # Always use float32 for memory efficiency and MPS compatibility + return from_numpy(data), from_numpy(labels) @gin.configurable("ImputationPandasDataset") From 646882680a7fde03c65b635063d4707cb9bad053 Mon Sep 17 00:00:00 2001 From: Patrick Rockenschaub Date: Thu, 23 Apr 2026 09:22:42 +0200 Subject: [PATCH 2/7] fix precision bug --- icu_benchmarks/models/train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/icu_benchmarks/models/train.py b/icu_benchmarks/models/train.py index d7fe83cc..9b58a2d6 100644 --- a/icu_benchmarks/models/train.py +++ b/icu_benchmarks/models/train.py @@ -42,7 +42,7 @@ def train_common( model: DLModel | MLModelClassifier | MLModelRegression | object = gin.REQUIRED, weight: str = "", optimizer: type = Adam, - precision: Optional[Literal[16] | Literal[32] | Literal[64] | Literal["16-true"]] = 32, + precision: Optional[Literal[16] | Literal[32] | Literal[64] | str] = 32, batch_size: int = 1, epochs: int = 100, patience: int = 20, @@ -154,8 +154,8 @@ def train_common( ] if verbose: callbacks.append(TQDMProgressBar(refresh_rate=min(100, len(train_loader) // 2))) - if precision == 16 or "16-mixed": - torch.set_float32_matmul_precision("medium") + if precision in (16, "16-mixed", "bf16", "bf16-mixed"): + torch.set_float32_matmul_precision("high") trainer = Trainer( max_epochs=epochs if model.requires_backprop else 1, From 4f7099f4de4c0ecab465431819ec50cbdd3a47cd Mon Sep 17 00:00:00 2001 From: Patrick Rockenschaub Date: Thu, 23 Apr 2026 09:23:01 +0200 Subject: [PATCH 3/7] pin memory during training --- icu_benchmarks/models/train.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/icu_benchmarks/models/train.py b/icu_benchmarks/models/train.py index 9b58a2d6..977fb9cc 100644 --- a/icu_benchmarks/models/train.py +++ b/icu_benchmarks/models/train.py @@ -118,6 +118,7 @@ def train_common( num_workers=num_workers, drop_last=True, persistent_workers=persistent_workers, + pin_memory=True, ) val_loader = DataLoader( val_dataset, @@ -126,6 +127,7 @@ def train_common( num_workers=num_workers, drop_last=True, persistent_workers=persistent_workers, + pin_memory=True, ) data_shape = next(iter(train_loader))[0].shape From 816964fe284ac864590625dea7cc8d45abfc2096 Mon Sep 17 00:00:00 2001 From: Patrick Rockenschaub Date: Thu, 23 Apr 2026 21:29:50 +0200 Subject: [PATCH 4/7] optimise cache --- icu_benchmarks/data/loader.py | 50 ++++++++++++++++++++++------------- 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/icu_benchmarks/data/loader.py b/icu_benchmarks/data/loader.py index 8dc3952a..c4ab0ffc 100644 --- a/icu_benchmarks/data/loader.py +++ b/icu_benchmarks/data/loader.py @@ -90,28 +90,30 @@ class PredictionPolarsDataset(CommonPolarsDataset): def __init__(self, *args, ram_cache: bool = True, **kwargs): super().__init__(*args, **kwargs) self.outcome_df = self.grouping_df - self.ram_cache(ram_cache) - def __getitem__(self, idx: int) -> tuple[Tensor, Tensor, Tensor]: - """Function to sample from the data split of choice. Used for deep learning implementations. + # One-pass partition keyed by stay id. Avoids the O(N_stays) full-frame filter + # that previously ran inside every __getitem__. + GROUP = self.vars["GROUP"] + LABEL = self.vars["LABEL"] + label_partitions = self.outcome_df.partition_by(GROUP, maintain_order=True) + feat_partitions = self.features_df.partition_by(GROUP, maintain_order=True) + self._stay_order = [part[GROUP][0] for part in label_partitions] + self._feat_arrays = { + part[GROUP][0]: part.select(pl.exclude(GROUP)).to_numpy().astype(np.float32) + for part in feat_partitions + } + self._label_arrays = { + part[GROUP][0]: part[LABEL].to_numpy().astype(np.float32) for part in label_partitions + } - Args: - idx: A specific row index to sample. - - Returns: - A sample from the data, consisting of data, labels and padding mask. - """ - if self._cached_dataset is not None: - return self._cached_dataset[idx] + self.ram_cache(ram_cache) + def _build_item(self, idx: int) -> tuple[Tensor, Tensor, Tensor]: pad_value = 0.0 - stay_id = self.outcome_df[self.vars["GROUP"]].unique()[idx] # [self.vars["GROUP"]] + stay_id = self._stay_order[idx] - # slice to make sure to always return a DF - window = ( - self.features_df.filter(pl.col(self.vars["GROUP"]) == stay_id).select(pl.exclude(self.vars["GROUP"])).to_numpy() - ) - labels = self.outcome_df.filter(pl.col(self.vars["GROUP"]) == stay_id)[self.vars["LABEL"]].to_numpy().astype(np.float32) + window = self._feat_arrays[stay_id] + labels = self._label_arrays[stay_id] if len(labels) == 1: # only one label per stay, align with window @@ -122,7 +124,6 @@ def __getitem__(self, idx: int) -> tuple[Tensor, Tensor, Tensor]: # Padding the array to fulfill size requirement if length_diff > 0: - # window shorter than the longest window in dataset, pad to same length window = np.concatenate([window, np.ones((length_diff, window.shape[1])) * pad_value], axis=0) labels = np.concatenate([labels, np.ones(length_diff) * pad_value], axis=0) pad_mask = np.concatenate([pad_mask, np.zeros(length_diff)], axis=0) @@ -138,6 +139,19 @@ def __getitem__(self, idx: int) -> tuple[Tensor, Tensor, Tensor]: return from_numpy(data), from_numpy(labels), from_numpy(pad_mask) + def __getitem__(self, idx: int) -> tuple[Tensor, Tensor, Tensor]: + """Function to sample from the data split of choice. Used for deep learning implementations. + + Args: + idx: A specific row index to sample. + + Returns: + A sample from the data, consisting of data, labels and padding mask. + """ + if self._cached_dataset is not None: + return self._cached_dataset[idx] + return self._build_item(idx) + def get_balance(self) -> list: """Return the weight balance for the split of interest. From ed8fb2b2aa649d302e0a829094be2e302b1ffa4a Mon Sep 17 00:00:00 2001 From: Patrick Rockenschaub Date: Fri, 24 Apr 2026 09:51:49 +0200 Subject: [PATCH 5/7] fix in-place mutation of cached _label_arrays[stay_id] --- icu_benchmarks/data/loader.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/icu_benchmarks/data/loader.py b/icu_benchmarks/data/loader.py index c4ab0ffc..a3a79722 100644 --- a/icu_benchmarks/data/loader.py +++ b/icu_benchmarks/data/loader.py @@ -102,9 +102,14 @@ def __init__(self, *args, ram_cache: bool = True, **kwargs): part[GROUP][0]: part.select(pl.exclude(GROUP)).to_numpy().astype(np.float32) for part in feat_partitions } + for arr in self._feat_arrays.values(): + arr.setflags(write=False) + self._label_arrays = { part[GROUP][0]: part[LABEL].to_numpy().astype(np.float32) for part in label_partitions } + for arr in self._label_arrays.values(): + arr.setflags(write=False) self.ram_cache(ram_cache) @@ -113,7 +118,7 @@ def _build_item(self, idx: int) -> tuple[Tensor, Tensor, Tensor]: stay_id = self._stay_order[idx] window = self._feat_arrays[stay_id] - labels = self._label_arrays[stay_id] + labels = self._label_arrays[stay_id].copy() # copy to avoid in-place NaN replacement if len(labels) == 1: # only one label per stay, align with window From 89be3984b69c70b46ca414e40e4d94b99290a856 Mon Sep 17 00:00:00 2001 From: Patrick Rockenschaub Date: Fri, 24 Apr 2026 09:56:58 +0200 Subject: [PATCH 6/7] preserve native dtype of row_indicators in to_tensor --- icu_benchmarks/data/loader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/icu_benchmarks/data/loader.py b/icu_benchmarks/data/loader.py index a3a79722..71b317ae 100644 --- a/icu_benchmarks/data/loader.py +++ b/icu_benchmarks/data/loader.py @@ -192,11 +192,11 @@ def get_data_and_labels(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]: def to_tensor(self) -> tuple[Tensor, Tensor, Tensor]: data, labels, row_indicators = self.get_data_and_labels() - # Always use float32 for memory efficiency and MPS compatibility + return ( from_numpy(data), from_numpy(labels), - from_numpy(row_indicators.astype(np.float32)), + from_numpy(row_indicators), ) From 0ffe22d6163fa2d7a8c67db6d3f4c96eb14b4adc Mon Sep 17 00:00:00 2001 From: Patrick Rockenschaub Date: Fri, 24 Apr 2026 09:58:27 +0200 Subject: [PATCH 7/7] gate pin_memory on accelerator availability --- icu_benchmarks/models/train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/icu_benchmarks/models/train.py b/icu_benchmarks/models/train.py index 977fb9cc..1cbd1e21 100644 --- a/icu_benchmarks/models/train.py +++ b/icu_benchmarks/models/train.py @@ -118,7 +118,7 @@ def train_common( num_workers=num_workers, drop_last=True, persistent_workers=persistent_workers, - pin_memory=True, + pin_memory=not cpu, ) val_loader = DataLoader( val_dataset, @@ -127,7 +127,7 @@ def train_common( num_workers=num_workers, drop_last=True, persistent_workers=persistent_workers, - pin_memory=True, + pin_memory=not cpu, ) data_shape = next(iter(train_loader))[0].shape @@ -196,7 +196,7 @@ def train_common( batch_size=min(batch_size * 4, len(test_dataset)), shuffle=False, num_workers=num_workers, - pin_memory=True, + pin_memory=not cpu, drop_last=True, persistent_workers=persistent_workers, )