From c83cd1dbea2e122c65c6872b11ac24f22e92d36c Mon Sep 17 00:00:00 2001 From: Juan Pablo Date: Mon, 23 Jun 2025 17:08:59 +0100 Subject: [PATCH 1/3] remove att masking on valid, test and pred - make validation and testing more efficient - improve lightning_model code - make alibi mask optional when no attention mask is passed as a parameter --- src/stamp/modeling/lightning_model.py | 123 +++++++++++------------ src/stamp/modeling/vision_transformer.py | 8 +- 2 files changed, 62 insertions(+), 69 deletions(-) diff --git a/src/stamp/modeling/lightning_model.py b/src/stamp/modeling/lightning_model.py index f6642931..1437de56 100644 --- a/src/stamp/modeling/lightning_model.py +++ b/src/stamp/modeling/lightning_model.py @@ -27,6 +27,35 @@ class LitVisionTransformer(lightning.LightningModule): + """ + PyTorch Lightning wrapper for the Vision Transformer (ViT) model used in weakly supervised + learning settings, such as Multiple Instance Learning (MIL) for whole-slide images or patch-based data. + + This class encapsulates training, validation, testing, and prediction logic, along with: + - Masking logic that ensures only valid tiles (patches) participate in attention during training. + - AUROC metric tracking during validation for multiclass classification. + - Compatibility checks based on the `stamp` framework version. + - Integration of class imbalance handling through weighted cross-entropy loss. + + The attention mask is applied *only* during training to hide paddings + and is skipped during evaluation and inference for reducing memory usage. + + Args: + categories: List of class labels. + category_weights: Class weights for cross-entropy loss to handle imbalance. + dim_input: Input feature dimensionality per tile. + dim_model: Latent dimensionality used inside the transformer. + dim_feedforward: Dimensionality of the transformer MLP block. + n_heads: Number of self-attention heads. + n_layers: Number of transformer layers. 
+ dropout: Dropout rate used throughout the model. + use_alibi: Whether to use ALiBi-style positional bias in attention (optional). + ground_truth_label: Column name for accessing ground-truth labels from metadata. + train_patients: List of patient IDs used for training. + valid_patients: List of patient IDs used for validation. + stamp_version: Version of the `stamp` framework used during training. + **metadata: Additional metadata to store with the model. + """ def __init__( self, *, @@ -40,7 +69,7 @@ def __init__( dropout: float, # Experimental features # TODO remove default values for stamp 3; they're only here for backwards compatibility - use_alibi: bool = False, + use_alibi: bool, # Metadata used by other parts of stamp, but not by the model itself ground_truth_label: PandasLabel, train_patients: Iterable[PatientId], @@ -49,17 +78,11 @@ def __init__( # Other metadata **metadata, ) -> None: - """ - Args: - metadata: - Any additional information to be saved in the models, - but not directly influencing the model. - """ super().__init__() if len(categories) != len(category_weights): raise ValueError( - "the number of category weights has to mathc the number of categories!" + "the number of category weights has to match the number of categories!" ) self.vision_transformer = VisionTransformer( @@ -73,6 +96,15 @@ def __init__( use_alibi=use_alibi, ) self.class_weights = category_weights + self.valid_auroc = MulticlassAUROC(len(categories)) + + # Used during deployment + self.ground_truth_label = ground_truth_label + self.categories = np.array(categories) + self.train_patients = train_patients + self.valid_patients = valid_patients + + _ = metadata # unused, but saved in model # Check if version is compatible. # This should only happen when the model is loaded, @@ -92,16 +124,6 @@ def __init__( "Please upgrade stamp to a compatible version." 
) - self.valid_auroc = MulticlassAUROC(len(categories)) - - # Used during deployment - self.ground_truth_label = ground_truth_label - self.categories = np.array(categories) - self.train_patients = train_patients - self.valid_patients = valid_patients - - _ = metadata # unused, but saved in model - self.save_hyperparameters() def forward( @@ -112,21 +134,20 @@ def forward( def _step( self, - *, - step_name: str, batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], - batch_idx: int, + step_name: str, + use_mask: bool, ) -> Loss: - _ = batch_idx # unused - bags, coords, bag_sizes, targets = batch - logits = self.vision_transformer( - bags, coords=coords, mask=_mask_from_bags(bags=bags, bag_sizes=bag_sizes) - ) + mask = _mask_from_bags(bags=bags, bag_sizes=bag_sizes) if use_mask else None + + logits = self.vision_transformer(bags, coords=coords, mask=mask) loss = nn.functional.cross_entropy( - logits, targets.type_as(logits), weight=self.class_weights.type_as(logits) + logits, + targets.type_as(logits), + weight=self.class_weights.type_as(logits), ) self.log( @@ -140,7 +161,7 @@ def _step( if step_name == "validation": # TODO this is a bit ugly, we'd like to have `_step` without special cases - self.valid_auroc.update(logits, targets.long().argmax(-1)) + self.valid_auroc.update(logits, targets.argmax(dim=-1)) self.log( f"{step_name}_auroc", self.valid_auroc, @@ -151,48 +172,18 @@ def _step( return loss - def training_step( - self, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], - batch_idx: int, - ) -> Loss: - return self._step( - step_name="training", - batch=batch, - batch_idx=batch_idx, - ) + def training_step(self, batch, batch_idx) -> Loss: + return self._step(batch, step_name="training", use_mask=True) - def validation_step( - self, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], - batch_idx: int, - ) -> Loss: - return self._step( - step_name="validation", - batch=batch, - batch_idx=batch_idx, - ) + def 
validation_step(self, batch, batch_idx) -> Loss: + return self._step(batch, step_name="validation", use_mask=False) - def test_step( - self, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], - batch_idx: int, - ) -> Loss: - return self._step( - step_name="test", - batch=batch, - batch_idx=batch_idx, - ) + def test_step(self, batch, batch_idx) -> Loss: + return self._step(batch, step_name="test", use_mask=False) - def predict_step( - self, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], - batch_idx: int = -1, - ) -> Float[Tensor, "batch logit"]: + def predict_step(self, batch, batch_idx: int = -1) -> Float[Tensor, "batch logit"]: bags, coords, bag_sizes, _ = batch - return self.vision_transformer( - bags, coords=coords, mask=_mask_from_bags(bags=bags, bag_sizes=bag_sizes) - ) + return self.vision_transformer(bags, coords=coords, mask=None) def configure_optimizers(self) -> optim.Optimizer: optimizer = optim.Adam(self.parameters(), lr=1e-3) diff --git a/src/stamp/modeling/vision_transformer.py b/src/stamp/modeling/vision_transformer.py index 88225f9d..cbc95c56 100755 --- a/src/stamp/modeling/vision_transformer.py +++ b/src/stamp/modeling/vision_transformer.py @@ -58,7 +58,7 @@ def forward( coords: Float[Tensor, "batch sequence xy"], attn_mask: Bool[Tensor, "batch sequence sequence"] | None, # Help, my abstractions are leaking! 
- alibi_mask: Bool[Tensor, "batch sequence sequence"], + alibi_mask: Bool[Tensor, "batch sequence sequence"] | None, ) -> Float[Tensor, "batch sequence proj_feature"]: """ Args: @@ -144,7 +144,7 @@ def forward( *, coords: Float[Tensor, "batch sequence 2"], attn_mask: Bool[Tensor, "batch sequence sequence"] | None, - alibi_mask: Bool[Tensor, "batch sequence sequence"], + alibi_mask: Bool[Tensor, "batch sequence sequence"] | None, ) -> Float[Tensor, "batch sequence proj_feature"]: for attn, ff in cast(Iterable[tuple[nn.Module, nn.Module]], self.layers): x_attn = attn(x, coords=coords, attn_mask=attn_mask, alibi_mask=alibi_mask) @@ -212,7 +212,9 @@ def forward( match mask: case None: - bags = self.transformer(bags, coords=coords, attn_mask=None) + bags = self.transformer( + bags, coords=coords, attn_mask=None, alibi_mask=None + ) case _: mask_with_class_token = torch.cat( From b35c121aad75e1f3cc5c189ee3946fa49decfbb0 Mon Sep 17 00:00:00 2001 From: Juan Pablo Date: Tue, 24 Jun 2025 10:23:50 +0100 Subject: [PATCH 2/3] fix compatibility and format --- src/stamp/modeling/alibi.py | 18 ++++++++++++------ src/stamp/modeling/lightning_model.py | 5 +++-- tests/test_alibi.py | 2 +- 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/src/stamp/modeling/alibi.py b/src/stamp/modeling/alibi.py index ae0916a6..2714b26b 100644 --- a/src/stamp/modeling/alibi.py +++ b/src/stamp/modeling/alibi.py @@ -38,8 +38,8 @@ def forward( v: Float[Tensor, "batch key v_feature"], coords_q: Float[Tensor, "batch query coord"], coords_k: Float[Tensor, "batch key coord"], - attn_mask: Bool[Tensor, "batch query key"], - alibi_mask: Bool[Tensor, "batch query key"], + attn_mask: Bool[Tensor, "batch query key"] | None, + alibi_mask: Bool[Tensor, "batch query key"] | None, ) -> Float[Tensor, "batch query v_feature"]: """ Args: @@ -51,12 +51,18 @@ def forward( coords_q.unsqueeze(2) - coords_k.unsqueeze(1), dim=-1 ) scaled_distances = self.scale_distance(distances) * self.bias_scale - 
masked_distances = scaled_distances.where(~alibi_mask, 0.0) + + if alibi_mask is not None: + scaled_distances = scaled_distances.where(~alibi_mask, 0.0) weights = torch.softmax(weight_logits, dim=-1) - masked = (weights - masked_distances).where(~attn_mask, 0.0) - attention = torch.einsum("bqk,bvf->bqf", masked, v) + if attn_mask is not None: + weights = (weights - scaled_distances).where(~attn_mask, 0.0) + else: + weights = weights - scaled_distances + + attention = torch.einsum("bqk,bkf->bqf", weights, v) return attention @@ -116,7 +122,7 @@ def forward( coords_q: Float[Tensor, "batch query coord"], coords_k: Float[Tensor, "batch key coord"], attn_mask: Bool[Tensor, "batch query key"] | None, - alibi_mask: Bool[Tensor, "batch query key"], + alibi_mask: Bool[Tensor, "batch query key"] | None, ) -> Float[Tensor, "batch query mh_v_feature"]: stacked_attentions = torch.stack( [ diff --git a/src/stamp/modeling/lightning_model.py b/src/stamp/modeling/lightning_model.py index 1437de56..69e6aea1 100644 --- a/src/stamp/modeling/lightning_model.py +++ b/src/stamp/modeling/lightning_model.py @@ -56,6 +56,7 @@ class LitVisionTransformer(lightning.LightningModule): stamp_version: Version of the `stamp` framework used during training. **metadata: Additional metadata to store with the model. 
""" + def __init__( self, *, @@ -69,7 +70,7 @@ def __init__( dropout: float, # Experimental features # TODO remove default values for stamp 3; they're only here for backwards compatibility - use_alibi: bool, + use_alibi: bool = False, # Metadata used by other parts of stamp, but not by the model itself ground_truth_label: PandasLabel, train_patients: Iterable[PatientId], @@ -161,7 +162,7 @@ def _step( if step_name == "validation": # TODO this is a bit ugly, we'd like to have `_step` without special cases - self.valid_auroc.update(logits, targets.argmax(dim=-1)) + self.valid_auroc.update(logits, targets.long().argmax(dim=-1)) self.log( f"{step_name}_auroc", self.valid_auroc, diff --git a/tests/test_alibi.py b/tests/test_alibi.py index 4462f2af..dc0b2378 100644 --- a/tests/test_alibi.py +++ b/tests/test_alibi.py @@ -11,7 +11,7 @@ def test_alibi_shapes(embed_dim: int = 32, num_heads: int = 8) -> None: q = torch.rand(2, 23, embed_dim) k = torch.rand(2, 34, embed_dim) - v = torch.rand(2, 8, embed_dim) + v = torch.rand(2, 34, embed_dim) coords_q = torch.rand(2, 23, 2) coords_k = torch.rand(2, 34, 2) attn_mask = torch.rand(2, 23, 34) > 0.5 From da565bc03c52ff7c0f2639402ee339d4d29db717 Mon Sep 17 00:00:00 2001 From: Juan Pablo Date: Wed, 25 Jun 2025 09:14:27 +0100 Subject: [PATCH 3/3] add formatting --- src/stamp/modeling/lightning_model.py | 32 +++++++++++++++++++++------ 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/src/stamp/modeling/lightning_model.py b/src/stamp/modeling/lightning_model.py index 69e6aea1..ca950bab 100644 --- a/src/stamp/modeling/lightning_model.py +++ b/src/stamp/modeling/lightning_model.py @@ -135,6 +135,7 @@ def forward( def _step( self, + *, batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], step_name: str, use_mask: bool, @@ -173,17 +174,34 @@ def _step( return loss - def training_step(self, batch, batch_idx) -> Loss: - return self._step(batch, step_name="training", use_mask=True) + def training_step( + self, + batch: 
tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch_idx: int, + ) -> Loss: + return self._step(batch=batch, step_name="training", use_mask=True) - def validation_step(self, batch, batch_idx) -> Loss: - return self._step(batch, step_name="validation", use_mask=False) + def validation_step( + self, + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch_idx: int, + ) -> Loss: + return self._step(batch=batch, step_name="validation", use_mask=False) - def test_step(self, batch, batch_idx) -> Loss: - return self._step(batch, step_name="test", use_mask=False) + def test_step( + self, + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch_idx: int, + ) -> Loss: + return self._step(batch=batch, step_name="test", use_mask=False) - def predict_step(self, batch, batch_idx: int = -1) -> Float[Tensor, "batch logit"]: + def predict_step( + self, + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch_idx: int, + ) -> Float[Tensor, "batch logit"]: bags, coords, bag_sizes, _ = batch + # adding a mask here will *drastically* and *unbearably* increase memory usage return self.vision_transformer(bags, coords=coords, mask=None) def configure_optimizers(self) -> optim.Optimizer: