Closed
12 changes: 10 additions & 2 deletions src/stamp/modeling/vision_transformer.py
Contributor

I am confused: where are the attention weights used in the gradcam function?
In general, there are two different approaches for Transformer heatmaps:

  1. gradient back to input (our standard nowadays)

  2. attention activation (attention weights) visualization; this has the disadvantage that parts of the model (the MLP blocks) are largely neglected, and that there are multiple heads and layers, each with its own activations.

So I am unsure whether this approach mixes the two. Also, returning the weights will affect the performance of the ViT.
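For reference, approach 1 can be sketched in isolation roughly like this (a minimal PyTorch sketch; the stand-in model and tensor shapes are assumptions, not the actual stamp ViT):

```python
import torch
import torch.nn as nn

# Minimal sketch of approach 1: backpropagate a class score to the input
# and use the per-token gradient magnitude as a saliency/heatmap signal.
# The model here is a stand-in, not the actual stamp ViT.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
model.eval()

x = torch.randn(1, 5, 8, requires_grad=True)  # (batch, tokens, features)
score = model(x)[0].sum(0)[1]  # target-class logit, pooled over tokens
score.backward()

# Per-token saliency: gradient norm at the input, shape (tokens,)
saliency = x.grad.norm(dim=-1).squeeze(0)
```

This is what makes approach 1 attractive here: the gradient flows through every part of the model, including the MLP blocks, so nothing is neglected.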

@@ -41,6 +41,7 @@ def __init__(
         super().__init__()
         self.heads = num_heads
         self.norm = nn.LayerNorm(dim)
+        self.attn_weights = None  # NEW

         if use_alibi:
             self.mhsa = MultiHeadALiBi(
@@ -75,17 +76,18 @@ def forward(
         x = self.norm(x)
         match self.mhsa:
             case nn.MultiheadAttention():
-                attn_output, _ = self.mhsa(
+                attn_output, attn_weights = self.mhsa(
                     x,
                     x,
                     x,
-                    need_weights=False,
+                    need_weights=True,  # NEW: enable attention weights
                     attn_mask=(
                         attn_mask.repeat(self.mhsa.num_heads, 1, 1)
                         if attn_mask is not None
                         else None
                     ),
                 )
+                self.attn_weights = attn_weights.detach()  # NEW
             case MultiHeadALiBi():
                 attn_output = self.mhsa(
                     q=x,
@@ -96,6 +98,7 @@ def forward(
                     attn_mask=attn_mask,
                     alibi_mask=alibi_mask,
                 )
+                self.attn_weights = None  # NEW: no weights for ALiBi
             case _ as unreachable:
                 assert_never(unreachable)

@@ -242,3 +245,8 @@ def forward(
         bags = bags[:, 0]

         return self.mlp_head(bags)
+
+    @property
+    def attention_weights(self):
+        # NEW: expose attention weights to user
+        return self.transformer.attn_weights_per_layer
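For concreteness, the mechanism this diff relies on, `nn.MultiheadAttention` returning attention weights when `need_weights=True`, behaves like this in isolation (shapes are illustrative, not the stamp ViT's):

```python
import torch
import torch.nn as nn

# Isolated sketch of the mechanism used in the diff:
# nn.MultiheadAttention returns (output, weights) when need_weights=True.
# By default average_attn_weights=True, so the per-head weights are
# averaged into a single (batch, query, key) map.
mhsa = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
x = torch.randn(2, 5, 16)  # (batch, tokens, dim)

attn_output, attn_weights = mhsa(x, x, x, need_weights=True)

# attn_output: (2, 5, 16); attn_weights: (2, 5, 5), each row sums to 1
```

Note that the averaging over heads illustrates the reviewer's concern: the returned map collapses the individual heads, and computing and materializing it at all adds overhead to every forward pass.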