
**Optimize __getitem__ method using more Polars operations** #158

@rvandewater

_:hammer_and_wrench: Refactor suggestion_

Optimize __getitem__ method using more Polars operations

The current implementation of __getitem__ still uses numpy operations, which may not be as efficient as using Polars operations throughout. Consider refactoring to use more Polars operations:

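import numpy as np
import polars as pl
from torch import Tensor, from_numpy
from typing import Tuple

def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor]:
    if self._cached_dataset is not None:
        return self._cached_dataset[idx]

    pad_value = 0.0
    group, label = self.vars["GROUP"], self.vars["LABEL"]
    # unique() does not guarantee order by default; maintain_order=True keeps idx -> stay stable.
    stay_id = self.outcome_df[group].unique(maintain_order=True)[idx]

    # Cast to Float32 up front so the padding rows and values match the data dtypes.
    window = self.features_df.filter(pl.col(group) == stay_id).cast(pl.Float32)
    labels = self.outcome_df.filter(pl.col(group) == stay_id)[label].cast(pl.Float32)

    if len(labels) == 1:
        # One label per stay: keep it on the last time step and leave earlier steps null.
        empty = pl.Series(labels.name, [None] * (window.height - 1), dtype=pl.Float32)
        labels = pl.concat([empty, labels])

    length_diff = self.maxlen - window.height
    pad_mask = pl.Series([True] * window.height, dtype=pl.Boolean)

    if length_diff > 0:
        padding = pl.DataFrame(
            {col: [pad_value] * length_diff for col in window.columns}, schema=window.schema
        )
        window = window.vstack(padding)
        labels = labels.extend(pl.Series(labels.name, [pad_value] * length_diff, dtype=pl.Float32))
        pad_mask = pad_mask.extend(pl.Series([False] * length_diff, dtype=pl.Boolean))

    # Mask unlabeled steps before writing the -1 sentinel, so a genuine label of -1 stays
    # unmasked (Polars Series has no NumPy-style where(condition, value)).
    pad_mask = pad_mask & labels.is_not_null()
    labels = labels.fill_null(-1)

    return (
        from_numpy(window.to_numpy().astype(np.float32)),
        from_numpy(labels.to_numpy().astype(np.float32)),
        from_numpy(pad_mask.to_numpy().astype(bool))
    )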

This refactored version keeps filtering, padding, and masking in Polars and converts to NumPy only when building the returned tensors, which should be more efficient, especially for larger datasets.
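
Since this changes how padding and masking are computed, it may help to pin down the expected output before swapping implementations. Below is a minimal, library-free sketch of the contract as I read it from the snippet above: a single per-stay label lands on the last observed step, padded steps get `pad_value`, and unlabeled or padded steps are masked out. The function name `pad_and_mask` and the constants are hypothetical and not part of the repository; treat it as a test oracle under those assumptions, not as the project's implementation.

```python
from typing import List, Optional, Sequence, Tuple

PAD_VALUE = 0.0   # assumed: same padding value as in the suggested __getitem__
UNLABELED = -1.0  # assumed: sentinel written into steps that carry no label


def pad_and_mask(
    window: Sequence[Sequence[float]],
    labels: Sequence[Optional[float]],
    maxlen: int,
) -> Tuple[List[List[float]], List[float], List[bool]]:
    """Hypothetical reference for the padding/masking behaviour (non-empty window assumed)."""
    rows = [list(row) for row in window]
    labs: List[Optional[float]] = list(labels)

    # A single per-stay label is placed on the last observed time step.
    if len(labs) == 1:
        labs = [None] * (len(rows) - 1) + labs

    mask = [True] * len(rows)

    missing = maxlen - len(rows)
    if missing > 0:
        width = len(rows[0])
        rows += [[PAD_VALUE] * width for _ in range(missing)]
        labs += [PAD_VALUE] * missing
        mask += [False] * missing

    # Steps without a label are masked out, then filled with the sentinel.
    mask = [m and lab is not None for m, lab in zip(mask, labs)]
    filled = [UNLABELED if lab is None else float(lab) for lab in labs]
    return rows, filled, mask


rows, filled, mask = pad_and_mask([[1.0, 2.0], [3.0, 4.0]], [1.0], maxlen=4)
```

In a unit test, the three tensors returned by the refactored `__getitem__` could be converted with `.tolist()` and compared against this output for a handful of stays, which would catch dtype or ordering regressions early.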

Originally posted by @coderabbitai[bot] in #155 (comment)
