diff --git a/src/stamp/modeling/vision_transformer.py b/src/stamp/modeling/vision_transformer.py
index cbc95c56..fa15cd20 100755
--- a/src/stamp/modeling/vision_transformer.py
+++ b/src/stamp/modeling/vision_transformer.py
@@ -41,6 +41,7 @@ def __init__(
         super().__init__()
         self.heads = num_heads
         self.norm = nn.LayerNorm(dim)
+        self.attn_weights = None  # NEW
 
         if use_alibi:
             self.mhsa = MultiHeadALiBi(
@@ -75,17 +76,18 @@ def forward(
         x = self.norm(x)
         match self.mhsa:
             case nn.MultiheadAttention():
-                attn_output, _ = self.mhsa(
+                attn_output, attn_weights = self.mhsa(
                     x,
                     x,
                     x,
-                    need_weights=False,
+                    need_weights=True,  # NEW: enable attention weights
                     attn_mask=(
                         attn_mask.repeat(self.mhsa.num_heads, 1, 1)
                         if attn_mask is not None
                         else None
                     ),
                 )
+                self.attn_weights = attn_weights.detach()  # NEW
             case MultiHeadALiBi():
                 attn_output = self.mhsa(
                     q=x,
@@ -96,6 +98,7 @@ def forward(
                     attn_mask=attn_mask,
                     alibi_mask=alibi_mask,
                 )
+                self.attn_weights = None  # NEW: no weights for ALiBi
             case _ as unreachable:
                 assert_never(unreachable)
 
@@ -242,3 +245,8 @@ def forward(
         bags = bags[:, 0]
 
         return self.mlp_head(bags)
+
+    @property
+    def attention_weights(self):
+        # NEW: expose attention weights to user
+        return self.transformer.attn_weights_per_layer
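
For reference, a minimal sketch (plain PyTorch, not STAMP-specific) of the mechanism this patch relies on: with need_weights=True, nn.MultiheadAttention returns the attention map alongside the output, and PyTorch averages it over heads by default. The tensor sizes below are illustrative assumptions, not values from the repository.

# Sketch only: shows what need_weights=True returns and why the patch detaches it.
import torch
import torch.nn as nn

dim, num_heads = 64, 8  # illustrative sizes
mhsa = nn.MultiheadAttention(dim, num_heads, batch_first=True)

x = torch.randn(2, 100, dim)  # (batch, tokens, embedding dim), e.g. tile embeddings
attn_output, attn_weights = mhsa(x, x, x, need_weights=True)

print(attn_output.shape)   # torch.Size([2, 100, 64])
print(attn_weights.shape)  # torch.Size([2, 100, 100]); averaged over heads by default

# Detaching, as in the patch, keeps the stored map from holding on to the autograd graph.
stored = attn_weights.detach()

If per-head maps are needed, average_attn_weights=False can be passed to the same call; the patch stores the default head-averaged map.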