21 changes: 17 additions & 4 deletions ultralytics/models/sam/modules/memory_attention.py
@@ -115,11 +115,24 @@ def _forward_ca(
             assert isinstance(self.cross_attn_image, RoPEAttention)
             kwds = {"num_k_exclude_rope": num_k_exclude_rope}
 
-        # Cross-Attention
-        tgt2 = self.norm2(tgt)
+        # Compute norm2(tgt) once and build q and k up front; the positional-encoding
+        # additions are applied only when they are enabled and the encodings are not None.
+        tgt2_normed = self.norm2(tgt)
+        if self.pos_enc_at_cross_attn_queries:
+            q = tgt2_normed if query_pos is None else tgt2_normed + query_pos
+        else:
+            q = tgt2_normed
+
+        # Add the positional encoding to the keys only when enabled and pos is not None.
+        if self.pos_enc_at_cross_attn_keys:
+            k = memory if pos is None else memory + pos
+        else:
+            k = memory
+
+        # Cross-attention over the memory features with the precomputed q and k.
         tgt2 = self.cross_attn_image(
-            q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
-            k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
+            q=q,
+            k=k,
             v=memory,
             **kwds,
         )
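
For reference, a minimal standalone sketch of the pattern the diff introduces. `MiniCrossAttnBlock`, its attribute names, and the use of `nn.MultiheadAttention` are hypothetical stand-ins for `MemoryAttentionLayer` and `RoPEAttention`; only the None-safe positional-encoding construction of `q` and `k` mirrors the change above.

```python
import torch
import torch.nn as nn


class MiniCrossAttnBlock(nn.Module):
    """Hypothetical stand-in for MemoryAttentionLayer, showing only the q/k construction."""

    def __init__(self, d_model=64, pos_enc_at_queries=True, pos_enc_at_keys=True):
        super().__init__()
        self.norm2 = nn.LayerNorm(d_model)
        self.cross_attn = nn.MultiheadAttention(d_model, num_heads=4, batch_first=True)
        self.pos_enc_at_queries = pos_enc_at_queries
        self.pos_enc_at_keys = pos_enc_at_keys

    def forward(self, tgt, memory, query_pos=None, pos=None):
        # Normalize once, then add positional encodings only when they are
        # enabled and actually provided (None-safe, as in the diff above).
        tgt2_normed = self.norm2(tgt)
        if self.pos_enc_at_queries and query_pos is not None:
            q = tgt2_normed + query_pos
        else:
            q = tgt2_normed
        if self.pos_enc_at_keys and pos is not None:
            k = memory + pos
        else:
            k = memory
        attn_out, _ = self.cross_attn(q, k, memory)
        return tgt + attn_out


# Usage: omitting query_pos / pos no longer adds None to a tensor.
block = MiniCrossAttnBlock()
tgt, memory = torch.randn(2, 10, 64), torch.randn(2, 30, 64)
print(block(tgt, memory).shape)  # torch.Size([2, 10, 64])
```

With this pattern, calling the layer without `query_pos` or `pos` is safe even when the corresponding `pos_enc_at_*` flag is set, since the addition is skipped instead of raising a TypeError.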