diff --git a/icu_benchmarks/data/loader.py b/icu_benchmarks/data/loader.py index bfa944ad..71b317ae 100644 --- a/icu_benchmarks/data/loader.py +++ b/icu_benchmarks/data/loader.py @@ -90,28 +90,35 @@ class PredictionPolarsDataset(CommonPolarsDataset): def __init__(self, *args, ram_cache: bool = True, **kwargs): super().__init__(*args, **kwargs) self.outcome_df = self.grouping_df - self.ram_cache(ram_cache) - - def __getitem__(self, idx: int) -> tuple[Tensor, Tensor, Tensor]: - """Function to sample from the data split of choice. Used for deep learning implementations. - Args: - idx: A specific row index to sample. + # One-pass partition keyed by stay id. Avoids the O(N_stays) full-frame filter + # that previously ran inside every __getitem__. + GROUP = self.vars["GROUP"] + LABEL = self.vars["LABEL"] + label_partitions = self.outcome_df.partition_by(GROUP, maintain_order=True) + feat_partitions = self.features_df.partition_by(GROUP, maintain_order=True) + self._stay_order = [part[GROUP][0] for part in label_partitions] + self._feat_arrays = { + part[GROUP][0]: part.select(pl.exclude(GROUP)).to_numpy().astype(np.float32) + for part in feat_partitions + } + for arr in self._feat_arrays.values(): + arr.setflags(write=False) + + self._label_arrays = { + part[GROUP][0]: part[LABEL].to_numpy().astype(np.float32) for part in label_partitions + } + for arr in self._label_arrays.values(): + arr.setflags(write=False) - Returns: - A sample from the data, consisting of data, labels and padding mask. - """ - if self._cached_dataset is not None: - return self._cached_dataset[idx] + self.ram_cache(ram_cache) + def _build_item(self, idx: int) -> tuple[Tensor, Tensor, Tensor]: pad_value = 0.0 - stay_id = self.outcome_df[self.vars["GROUP"]].unique()[idx] # [self.vars["GROUP"]] + stay_id = self._stay_order[idx] - # slice to make sure to always return a DF - window = ( - self.features_df.filter(pl.col(self.vars["GROUP"]) == stay_id).select(pl.exclude(self.vars["GROUP"])).to_numpy() - ) - labels = self.outcome_df.filter(pl.col(self.vars["GROUP"]) == stay_id)[self.vars["LABEL"]].to_numpy().astype(float) + window = self._feat_arrays[stay_id] + labels = self._label_arrays[stay_id].copy() # copy to avoid in-place NaN replacement if len(labels) == 1: # only one label per stay, align with window @@ -122,7 +129,6 @@ def __getitem__(self, idx: int) -> tuple[Tensor, Tensor, Tensor]: # Padding the array to fulfill size requirement if length_diff > 0: - # window shorter than the longest window in dataset, pad to same length window = np.concatenate([window, np.ones((length_diff, window.shape[1])) * pad_value], axis=0) labels = np.concatenate([labels, np.ones(length_diff) * pad_value], axis=0) pad_mask = np.concatenate([pad_mask, np.zeros(length_diff)], axis=0) @@ -138,6 +144,19 @@ def __getitem__(self, idx: int) -> tuple[Tensor, Tensor, Tensor]: return from_numpy(data), from_numpy(labels), from_numpy(pad_mask) + def __getitem__(self, idx: int) -> tuple[Tensor, Tensor, Tensor]: + """Function to sample from the data split of choice. Used for deep learning implementations. + + Args: + idx: A specific row index to sample. + + Returns: + A sample from the data, consisting of data, labels and padding mask. + """ + if self._cached_dataset is not None: + return self._cached_dataset[idx] + return self._build_item(idx) + def get_balance(self) -> list: """Return the weight balance for the split of interest. @@ -157,7 +176,7 @@ def get_data_and_labels(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]: Returns: A Tuple containing data points and label for the split. """ - labels = self.outcome_df[self.vars["LABEL"]].to_numpy().astype(float) + labels = self.outcome_df[self.vars["LABEL"]].to_numpy().astype(np.float32) rep = self.features_df if len(labels) == self.num_stays: @@ -166,17 +185,19 @@ def get_data_and_labels(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]: else: # Adding segment count for each stay id and timestep. rep = rep.with_columns(pl.col(self.vars["GROUP"]).cum_count().over(self.vars["GROUP"]).alias("counter")) - rep = rep.to_numpy().astype(float) + rep = rep.to_numpy().astype(np.float32) logging.debug(f"rep shape: {rep.shape}") logging.debug(f"labels shape: {labels.shape}") return rep, labels, self.row_indicators.to_numpy() - def to_tensor(self) -> tuple[Union[Tensor, np.ndarray], Union[Tensor, np.ndarray], Union[Tensor, np.ndarray]]: + def to_tensor(self) -> tuple[Tensor, Tensor, Tensor]: data, labels, row_indicators = self.get_data_and_labels() - if self.mps: - return from_numpy(data).to(float32), from_numpy(labels).to(float32), from_numpy(row_indicators).to(float32) - else: - return from_numpy(data), from_numpy(labels), row_indicators + + return ( + from_numpy(data), + from_numpy(labels), + from_numpy(row_indicators), + ) @gin.configurable("CommonPandasDataset") @@ -271,7 +292,7 @@ def __getitem__(self, idx: int) -> tuple[Tensor, Tensor, Tensor]: # slice to make sure to always return a DF window = self.features_df.loc[stay_id:stay_id].to_numpy() - labels = self.outcome_df.loc[stay_id:stay_id][self.vars["LABEL"]].to_numpy(dtype=float) + labels = self.outcome_df.loc[stay_id:stay_id][self.vars["LABEL"]].to_numpy(dtype=np.float32) if len(labels) == 1: # only one label per stay, align with window @@ -315,21 +336,19 @@ def get_data_and_labels(self) -> tuple[np.ndarray, np.ndarray]: Returns: A Tuple containing data points and label for the split. """ - labels = self.outcome_df[self.vars["LABEL"]].to_numpy().astype(float) + labels = self.outcome_df[self.vars["LABEL"]].to_numpy().astype(np.float32) rep = self.features_df if len(labels) == self.num_stays: # order of groups could be random, we make sure not to change it rep = rep.groupby(level=self.vars["GROUP"], sort=False).last() - rep = rep.to_numpy().astype(float) + rep = rep.to_numpy().astype(np.float32) return rep, labels - def to_tensor(self) -> tuple[Union[Tensor, np.ndarray], Union[Tensor, np.ndarray]]: + def to_tensor(self) -> tuple[Tensor, Tensor]: data, labels = self.get_data_and_labels() - if self.mps: - return from_numpy(data).to(float32), from_numpy(labels).to(float32) - else: - return from_numpy(data), from_numpy(labels) + # Always use float32 for memory efficiency and MPS compatibility + return from_numpy(data), from_numpy(labels) @gin.configurable("ImputationPandasDataset") diff --git a/icu_benchmarks/models/train.py b/icu_benchmarks/models/train.py index d7fe83cc..1cbd1e21 100644 --- a/icu_benchmarks/models/train.py +++ b/icu_benchmarks/models/train.py @@ -42,7 +42,7 @@ def train_common( model: DLModel | MLModelClassifier | MLModelRegression | object = gin.REQUIRED, weight: str = "", optimizer: type = Adam, - precision: Optional[Literal[16] | Literal[32] | Literal[64] | Literal["16-true"]] = 32, + precision: Optional[Literal[16] | Literal[32] | Literal[64] | str] = 32, batch_size: int = 1, epochs: int = 100, patience: int = 20, @@ -118,6 +118,7 @@ def train_common( num_workers=num_workers, drop_last=True, persistent_workers=persistent_workers, + pin_memory=not cpu, ) val_loader = DataLoader( val_dataset, @@ -126,6 +127,7 @@ def train_common( num_workers=num_workers, drop_last=True, persistent_workers=persistent_workers, + pin_memory=not cpu, ) data_shape = next(iter(train_loader))[0].shape @@ -154,8 +156,8 @@ def train_common( ] if verbose: callbacks.append(TQDMProgressBar(refresh_rate=min(100, len(train_loader) // 2))) - if precision == 16 or "16-mixed": - torch.set_float32_matmul_precision("medium") + if precision in (16, "16-mixed", "bf16", "bf16-mixed"): + torch.set_float32_matmul_precision("high") trainer = Trainer( max_epochs=epochs if model.requires_backprop else 1, @@ -194,7 +196,7 @@ def train_common( batch_size=min(batch_size * 4, len(test_dataset)), shuffle=False, num_workers=num_workers, - pin_memory=True, + pin_memory=not cpu, drop_last=True, persistent_workers=persistent_workers, )