Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions src/stamp/modeling/alibi.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def forward(
v: Float[Tensor, "batch key v_feature"],
coords_q: Float[Tensor, "batch query coord"],
coords_k: Float[Tensor, "batch key coord"],
attn_mask: Bool[Tensor, "batch query key"],
alibi_mask: Bool[Tensor, "batch query key"],
attn_mask: Bool[Tensor, "batch query key"] | None,
alibi_mask: Bool[Tensor, "batch query key"] | None,
) -> Float[Tensor, "batch query v_feature"]:
"""
Args:
Expand All @@ -51,12 +51,18 @@ def forward(
coords_q.unsqueeze(2) - coords_k.unsqueeze(1), dim=-1
)
scaled_distances = self.scale_distance(distances) * self.bias_scale
masked_distances = scaled_distances.where(~alibi_mask, 0.0)

if alibi_mask is not None:
scaled_distances = scaled_distances.where(~alibi_mask, 0.0)

weights = torch.softmax(weight_logits, dim=-1)
masked = (weights - masked_distances).where(~attn_mask, 0.0)

attention = torch.einsum("bqk,bvf->bqf", masked, v)
if attn_mask is not None:
weights = (weights - scaled_distances).where(~attn_mask, 0.0)
else:
weights = weights - scaled_distances

attention = torch.einsum("bqk,bkf->bqf", weights, v)

return attention

Expand Down Expand Up @@ -116,7 +122,7 @@ def forward(
coords_q: Float[Tensor, "batch query coord"],
coords_k: Float[Tensor, "batch key coord"],
attn_mask: Bool[Tensor, "batch query key"] | None,
alibi_mask: Bool[Tensor, "batch query key"],
alibi_mask: Bool[Tensor, "batch query key"] | None,
) -> Float[Tensor, "batch query mh_v_feature"]:
stacked_attentions = torch.stack(
[
Expand Down
100 changes: 55 additions & 45 deletions src/stamp/modeling/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,36 @@


class LitVisionTransformer(lightning.LightningModule):
"""
PyTorch Lightning wrapper for the Vision Transformer (ViT) model used in weakly supervised
learning settings, such as Multiple Instance Learning (MIL) for whole-slide images or patch-based data.

This class encapsulates training, validation, testing, and prediction logic, along with:
- Masking logic that ensures only valid tiles (patches) participate in attention during training.
- AUROC metric tracking during validation for multiclass classification.
- Compatibility checks based on the `stamp` framework version.
- Integration of class imbalance handling through weighted cross-entropy loss.

The attention mask is applied *only* during training, where it hides padding tiles.
It is skipped during validation, testing, and prediction to reduce memory usage.

Args:
categories: List of class labels.
category_weights: Class weights for cross-entropy loss to handle imbalance.
dim_input: Input feature dimensionality per tile.
dim_model: Latent dimensionality used inside the transformer.
dim_feedforward: Dimensionality of the transformer MLP block.
n_heads: Number of self-attention heads.
n_layers: Number of transformer layers.
dropout: Dropout rate used throughout the model.
use_alibi: Whether to use ALiBi-style positional bias in attention (optional).
ground_truth_label: Column name for accessing ground-truth labels from metadata.
train_patients: List of patient IDs used for training.
valid_patients: List of patient IDs used for validation.
stamp_version: Version of the `stamp` framework used during training.
**metadata: Additional information to save with the model; it is stored but does not otherwise influence the model.
"""

def __init__(
self,
*,
Expand All @@ -49,17 +79,11 @@ def __init__(
# Other metadata
**metadata,
) -> None:
"""
Args:
metadata:
Any additional information to be saved in the models,
but not directly influencing the model.
"""
super().__init__()

if len(categories) != len(category_weights):
raise ValueError(
"the number of category weights has to mathc the number of categories!"
"the number of category weights has to match the number of categories!"
)

self.vision_transformer = VisionTransformer(
Expand All @@ -73,6 +97,15 @@ def __init__(
use_alibi=use_alibi,
)
self.class_weights = category_weights
self.valid_auroc = MulticlassAUROC(len(categories))

# Used during deployment
self.ground_truth_label = ground_truth_label
self.categories = np.array(categories)
self.train_patients = train_patients
self.valid_patients = valid_patients

_ = metadata # unused, but saved in model

# Check if version is compatible.
# This should only happen when the model is loaded,
Expand All @@ -92,16 +125,6 @@ def __init__(
"Please upgrade stamp to a compatible version."
)

self.valid_auroc = MulticlassAUROC(len(categories))

# Used during deployment
self.ground_truth_label = ground_truth_label
self.categories = np.array(categories)
self.train_patients = train_patients
self.valid_patients = valid_patients

_ = metadata # unused, but saved in model

self.save_hyperparameters()

def forward(
Expand All @@ -113,20 +136,20 @@ def forward(
def _step(
self,
*,
step_name: str,
batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets],
batch_idx: int,
step_name: str,
use_mask: bool,
) -> Loss:
_ = batch_idx # unused

bags, coords, bag_sizes, targets = batch

logits = self.vision_transformer(
bags, coords=coords, mask=_mask_from_bags(bags=bags, bag_sizes=bag_sizes)
)
mask = _mask_from_bags(bags=bags, bag_sizes=bag_sizes) if use_mask else None

logits = self.vision_transformer(bags, coords=coords, mask=mask)

loss = nn.functional.cross_entropy(
logits, targets.type_as(logits), weight=self.class_weights.type_as(logits)
logits,
targets.type_as(logits),
weight=self.class_weights.type_as(logits),
)

self.log(
Expand All @@ -140,7 +163,7 @@ def _step(

if step_name == "validation":
# TODO this is a bit ugly, we'd like to have `_step` without special cases
self.valid_auroc.update(logits, targets.long().argmax(-1))
self.valid_auroc.update(logits, targets.long().argmax(dim=-1))
self.log(
f"{step_name}_auroc",
self.valid_auroc,
Expand All @@ -156,43 +179,30 @@ def training_step(
batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets],
batch_idx: int,
) -> Loss:
return self._step(
step_name="training",
batch=batch,
batch_idx=batch_idx,
)
return self._step(batch=batch, step_name="training", use_mask=True)

def validation_step(
self,
batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets],
batch_idx: int,
) -> Loss:
return self._step(
step_name="validation",
batch=batch,
batch_idx=batch_idx,
)
return self._step(batch=batch, step_name="validation", use_mask=False)

def test_step(
self,
batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets],
batch_idx: int,
) -> Loss:
return self._step(
step_name="test",
batch=batch,
batch_idx=batch_idx,
)
return self._step(batch=batch, step_name="test", use_mask=False)

def predict_step(
self,
batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets],
batch_idx: int = -1,
batch_idx: int,
) -> Float[Tensor, "batch logit"]:
bags, coords, bag_sizes, _ = batch
return self.vision_transformer(
bags, coords=coords, mask=_mask_from_bags(bags=bags, bag_sizes=bag_sizes)
)
# adding a mask here will *drastically* and *unbearably* increase memory usage
return self.vision_transformer(bags, coords=coords, mask=None)

def configure_optimizers(self) -> optim.Optimizer:
optimizer = optim.Adam(self.parameters(), lr=1e-3)
Expand Down
8 changes: 5 additions & 3 deletions src/stamp/modeling/vision_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def forward(
coords: Float[Tensor, "batch sequence xy"],
attn_mask: Bool[Tensor, "batch sequence sequence"] | None,
# Help, my abstractions are leaking!
alibi_mask: Bool[Tensor, "batch sequence sequence"],
alibi_mask: Bool[Tensor, "batch sequence sequence"] | None,
) -> Float[Tensor, "batch sequence proj_feature"]:
"""
Args:
Expand Down Expand Up @@ -144,7 +144,7 @@ def forward(
*,
coords: Float[Tensor, "batch sequence 2"],
attn_mask: Bool[Tensor, "batch sequence sequence"] | None,
alibi_mask: Bool[Tensor, "batch sequence sequence"],
alibi_mask: Bool[Tensor, "batch sequence sequence"] | None,
) -> Float[Tensor, "batch sequence proj_feature"]:
for attn, ff in cast(Iterable[tuple[nn.Module, nn.Module]], self.layers):
x_attn = attn(x, coords=coords, attn_mask=attn_mask, alibi_mask=alibi_mask)
Expand Down Expand Up @@ -212,7 +212,9 @@ def forward(

match mask:
case None:
bags = self.transformer(bags, coords=coords, attn_mask=None)
bags = self.transformer(
bags, coords=coords, attn_mask=None, alibi_mask=None
)

case _:
mask_with_class_token = torch.cat(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_alibi.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def test_alibi_shapes(embed_dim: int = 32, num_heads: int = 8) -> None:

q = torch.rand(2, 23, embed_dim)
k = torch.rand(2, 34, embed_dim)
v = torch.rand(2, 8, embed_dim)
v = torch.rand(2, 34, embed_dim)
coords_q = torch.rand(2, 23, 2)
coords_k = torch.rand(2, 34, 2)
attn_mask = torch.rand(2, 23, 34) > 0.5
Expand Down