
Commit dad1ca1

Some rework of csatv2 to better fit timm norms, re-use existing layers where possible
1 parent: 8d5e51e

3 files changed: +548, -430 lines

tests/test_models.py

Lines changed: 2 additions & 1 deletion

@@ -56,7 +56,8 @@
     'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*', 'tnt',
     'tiny_vit', 'vovnet', 'tresnet', 'rexnet', 'resnetv2', 'repghost', 'repvit', 'pvt_v2', 'nextvit', 'nest',
     'mambaout', 'inception_next', 'inception_v4', 'hgnet', 'gcvit', 'focalnet', 'efficientformer_v2', 'edgenext',
-    'davit', 'rdnet', 'convnext', 'pit', 'starnet', 'shvit', 'fasternet', 'swiftformer', 'ghostnet', 'naflexvit'
+    'davit', 'rdnet', 'convnext', 'pit', 'starnet', 'shvit', 'fasternet', 'swiftformer', 'ghostnet', 'naflexvit',
+    'csatv2'
 ]

 # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.

timm/layers/attention.py

Lines changed: 54 additions & 35 deletions

@@ -29,6 +29,8 @@ def __init__(
             self,
             dim: int,
             num_heads: int = 8,
+            attn_head_dim: Optional[int] = None,
+            dim_out: Optional[int] = None,
             qkv_bias: bool = False,
             qk_norm: bool = False,
             scale_norm: bool = False,
@@ -37,36 +39,45 @@ def __init__(
             proj_drop: float = 0.,
             norm_layer: Optional[Type[nn.Module]] = None,
             device=None,
-            dtype=None
+            dtype=None,
     ) -> None:
         """Initialize the Attention module.

         Args:
-            dim: Input dimension of the token embeddings
-            num_heads: Number of attention heads
-            qkv_bias: Whether to use bias in the query, key, value projections
-            qk_norm: Whether to apply normalization to query and key vectors
-            proj_bias: Whether to use bias in the output projection
-            attn_drop: Dropout rate applied to the attention weights
-            proj_drop: Dropout rate applied after the output projection
-            norm_layer: Normalization layer constructor for QK normalization if enabled
+            dim: Input dimension of the token embeddings.
+            num_heads: Number of attention heads.
+            attn_head_dim: Dimension of each attention head. If None, computed as dim // num_heads.
+            dim_out: Output dimension. If None, same as dim.
+            qkv_bias: Whether to use bias in the query, key, value projections.
+            qk_norm: Whether to apply normalization to query and key vectors.
+            scale_norm: Whether to apply normalization to attention output before projection.
+            proj_bias: Whether to use bias in the output projection.
+            attn_drop: Dropout rate applied to the attention weights.
+            proj_drop: Dropout rate applied after the output projection.
+            norm_layer: Normalization layer constructor for QK normalization if enabled.
         """
         super().__init__()
         dd = {'device': device, 'dtype': dtype}
-        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+        dim_out = dim_out or dim
+        head_dim = attn_head_dim
+        if head_dim is None:
+            assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+            head_dim = dim // num_heads
         if qk_norm or scale_norm:
             assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'
+
         self.num_heads = num_heads
-        self.head_dim = dim // num_heads
-        self.scale = self.head_dim ** -0.5
+        self.head_dim = head_dim
+        self.attn_dim = num_heads * head_dim
+        self.scale = head_dim ** -0.5
         self.fused_attn = use_fused_attn()

-        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd)
-        self.q_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
-        self.k_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
+        self.qkv = nn.Linear(dim, self.attn_dim * 3, bias=qkv_bias, **dd)
+        self.q_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity()
         self.attn_drop = nn.Dropout(attn_drop)
-        self.norm = norm_layer(dim, **dd) if scale_norm else nn.Identity()
-        self.proj = nn.Linear(dim, dim, bias=proj_bias, **dd)
+        self.norm = norm_layer(self.attn_dim, **dd) if scale_norm else nn.Identity()
+        self.proj = nn.Linear(self.attn_dim, dim_out, bias=proj_bias, **dd)
         self.proj_drop = nn.Dropout(proj_drop)

     def forward(
@@ -93,7 +104,7 @@ def forward(
             attn = self.attn_drop(attn)
             x = attn @ v

-        x = x.transpose(1, 2).reshape(B, N, C)
+        x = x.transpose(1, 2).reshape(B, N, self.attn_dim)
         x = self.norm(x)
         x = self.proj(x)
         x = self.proj_drop(x)
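
For context, a minimal usage sketch (not part of the commit) of what the reworked Attention layer allows: per-head width via attn_head_dim and a decoupled output width via dim_out, with the default behaviour unchanged. The tensor sizes below are illustrative assumptions. The remaining hunks below apply the same treatment to AttentionRope.

# Minimal sketch, assuming this commit is checked out; shapes are illustrative.
import torch
from timm.layers.attention import Attention

x = torch.randn(2, 197, 256)  # (batch, tokens, dim)

# Default behaviour is unchanged: head_dim = dim // num_heads, output dim == input dim.
attn = Attention(dim=256, num_heads=8)
print(attn(x).shape)  # expected: torch.Size([2, 197, 256])

# New in this commit: explicit per-head width and an output width that differs from dim,
# so models like csatv2 can re-use this layer instead of defining their own attention.
attn = Attention(dim=256, num_heads=8, attn_head_dim=64, dim_out=512)
print(attn(x).shape)  # expected: torch.Size([2, 197, 512])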
@@ -114,6 +125,7 @@ def __init__(
             self,
             dim: int,
             num_heads: int = 8,
+            dim_out: Optional[int] = None,
             qkv_bias: bool = True,
             qkv_fused: bool = True,
             num_prefix_tokens: int = 1,
@@ -133,45 +145,52 @@ def __init__(
         Args:
             dim: Input dimension of the token embeddings
             num_heads: Number of attention heads
+            dim_out: Output dimension. If None, same as dim.
             qkv_bias: Whether to add a bias term to the query, key, and value projections
+            qkv_fused: Whether to use fused QKV projection (single linear) or separate projections
             num_prefix_tokens: Number of reg/cls tokens at the beginning of the sequence that
                 should not have position embeddings applied
             attn_drop: Dropout rate for attention weights
             proj_drop: Dropout rate for the output projection
-            attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads)
+            attn_head_dim: Dimension of each attention head. If None, computed as dim // num_heads.
             norm_layer: Normalization layer constructor to use for QK and scale normalization
             qk_norm: Enable normalization of query (Q) and key (K) vectors with norm_layer
             scale_norm: Enable normalization (scaling) of attention output with norm_layer
+            proj_bias: Whether to use bias in the output projection
             rotate_half: Use 'half' ROPE layout instead of default 'interleaved'
         """
         super().__init__()
         dd = {'device': device, 'dtype': dtype}
+        dim_out = dim_out or dim
+        head_dim = attn_head_dim
+        if head_dim is None:
+            assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+            head_dim = dim // num_heads
         if scale_norm or qk_norm:
             assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'
+
         self.num_heads = num_heads
-        head_dim = dim // num_heads
-        if attn_head_dim is not None:
-            head_dim = attn_head_dim
-        attn_dim = head_dim * self.num_heads
+        self.head_dim = head_dim
+        self.attn_dim = head_dim * num_heads
         self.scale = head_dim ** -0.5
         self.num_prefix_tokens = num_prefix_tokens
         self.fused_attn = use_fused_attn()
         self.rotate_half = rotate_half

         if qkv_fused:
-            self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias, **dd)
+            self.qkv = nn.Linear(dim, self.attn_dim * 3, bias=qkv_bias, **dd)
             self.q_proj = self.k_proj = self.v_proj = None
         else:
             self.qkv = None
-            self.q_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd)
-            self.k_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd)
-            self.v_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd)
+            self.q_proj = nn.Linear(dim, self.attn_dim, bias=qkv_bias, **dd)
+            self.k_proj = nn.Linear(dim, self.attn_dim, bias=qkv_bias, **dd)
+            self.v_proj = nn.Linear(dim, self.attn_dim, bias=qkv_bias, **dd)

         self.q_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity()
         self.k_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity()
         self.attn_drop = nn.Dropout(attn_drop)
-        self.norm = norm_layer(attn_dim, **dd) if scale_norm else nn.Identity()
-        self.proj = nn.Linear(attn_dim, dim, bias=proj_bias, **dd)
+        self.norm = norm_layer(self.attn_dim, **dd) if scale_norm else nn.Identity()
+        self.proj = nn.Linear(self.attn_dim, dim_out, bias=proj_bias, **dd)
         self.proj_drop = nn.Dropout(proj_drop)

     def forward(
@@ -188,18 +207,18 @@ def forward(
             attn_mask: Optional attention mask to apply during attention computation

         Returns:
-            Tensor of shape (batch_size, sequence_length, embedding_dim)
+            Tensor of shape (batch_size, sequence_length, dim_out)
         """
         B, N, C = x.shape

         if self.qkv is not None:
             qkv = self.qkv(x)
-            qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+            qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
             q, k, v = qkv.unbind(0)  # B, num_heads, N, head_dim
         else:
-            q = self.q_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)  # B, num_heads, N, C
-            k = self.k_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
-            v = self.v_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
+            q = self.q_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
+            k = self.k_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
+            v = self.v_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)

         q, k = self.q_norm(q), self.k_norm(k)

@@ -224,7 +243,7 @@ def forward(
             attn = self.attn_drop(attn)
             x = attn @ v

-        x = x.transpose(1, 2).reshape(B, N, C)
+        x = x.transpose(1, 2).reshape(B, N, self.attn_dim)
         x = self.norm(x)
         x = self.proj(x)
         x = self.proj_drop(x)
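
A similar hedged sketch for AttentionRope, the second class touched in this file: dim_out decouples the projection output from the input width, while qkv_fused toggles between a single fused QKV linear and separate q/k/v projections. The concrete sizes here are assumptions, and the rope argument is omitted on the assumption that the rotary embeddings remain optional in forward().

# Hedged usage sketch for AttentionRope after this change; sizes are assumptions.
import torch
from timm.layers.attention import AttentionRope

x = torch.randn(2, 64, 192)  # (batch, tokens, dim)

attn = AttentionRope(
    dim=192,
    num_heads=4,
    dim_out=384,          # new: output width decoupled from input width
    qkv_fused=False,      # separate q/k/v projections instead of one fused linear
    num_prefix_tokens=0,  # plain token sequence, no cls/reg tokens
)
out = attn(x)  # rope embeddings omitted here, assuming they remain optional
print(out.shape)  # expected: torch.Size([2, 64, 384])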
