_:hammer_and_wrench: Refactor suggestion_
**Optimize `__getitem__` method using more Polars operations**
The current implementation of `__getitem__` still relies on NumPy operations, which may be less efficient than keeping the data in Polars until the final tensor conversion. Consider refactoring it to use Polars operations throughout:
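```python
import numpy as np
import polars as pl
from torch import Tensor, from_numpy
from typing import Tuple

def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor]:
    if self._cached_dataset is not None:
        return self._cached_dataset[idx]
    pad_value = 0.0
    group = self.vars["GROUP"]
    # maintain_order=True keeps the idx -> stay_id mapping stable between calls
    stay_id = self.outcome_df[group].unique(maintain_order=True)[idx]
    # Cast to Float32 so the padding rows below share the window's schema
    window = self.features_df.filter(pl.col(group) == stay_id).cast(pl.Float32)
    labels = self.outcome_df.filter(pl.col(group) == stay_id)[self.vars["LABEL"]].cast(pl.Float32)
    if len(labels) == 1:
        # Single stay-level label: align it with the last time step, leave earlier steps null
        labels = pl.concat([pl.Series([None] * (window.height - 1), dtype=pl.Float32), labels])
    length_diff = self.maxlen - window.height
    pad_mask = pl.Series([True] * window.height)
    if length_diff > 0:
        padding = pl.DataFrame(
            {col: [pad_value] * length_diff for col in window.columns},
            schema={col: pl.Float32 for col in window.columns},
        )
        window = pl.concat([window, padding])
        labels = pl.concat([labels, pl.Series([pad_value] * length_diff, dtype=pl.Float32)])
        pad_mask = pl.concat([pad_mask, pl.Series([False] * length_diff)])
    labels = labels.fill_null(-1)
    # Mask out padded steps and steps without a label
    pad_mask = pad_mask & (labels != -1)
    # astype() copies into writable arrays, which torch.from_numpy expects
    return (
        from_numpy(window.to_numpy().astype(np.float32)),
        from_numpy(labels.to_numpy().astype(np.float32)),
        from_numpy(pad_mask.to_numpy().astype(bool)),
    )
```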
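This refactored version keeps filtering, padding, and masking in Polars and converts to NumPy only once, when building the tensors. The speedup has not been measured. For large datasets, the per-call `filter` over the full DataFrame and the per-call `unique` are likely the larger costs, so computing them once up front may matter more than the NumPy-to-Polars switch.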
Originally posted by @coderabbitai[bot] in #155 (comment)