3 changes: 2 additions & 1 deletion tests/test_models.py
@@ -56,7 +56,8 @@
'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*', 'tnt',
'tiny_vit', 'vovnet', 'tresnet', 'rexnet', 'resnetv2', 'repghost', 'repvit', 'pvt_v2', 'nextvit', 'nest',
'mambaout', 'inception_next', 'inception_v4', 'hgnet', 'gcvit', 'focalnet', 'efficientformer_v2', 'edgenext',
'davit', 'rdnet', 'convnext', 'pit', 'starnet', 'shvit', 'fasternet', 'swiftformer', 'ghostnet', 'naflexvit'
'davit', 'rdnet', 'convnext', 'pit', 'starnet', 'shvit', 'fasternet', 'swiftformer', 'ghostnet', 'naflexvit',
'csatv2'
]

# transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
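Adding 'csatv2' to the list above routes the new architecture through the non-standard test paths. Purely as a rough sketch of how a pattern list like this can be applied (the actual test helpers may differ, e.g. by matching module names rather than model names):

# Illustrative only, not the timm test-suite code.
from fnmatch import fnmatch

non_std = ['aimv2*', 'naflexvit', 'csatv2']

def matches_any(name, patterns):
    # a bare entry is treated as a prefix here purely for illustration
    return any(fnmatch(name, p if '*' in p else p + '*') for p in patterns)

print(matches_any('csatv2_base', non_std))  # True under this sketch's rules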
89 changes: 54 additions & 35 deletions timm/layers/attention.py
@@ -29,6 +29,8 @@ def __init__(
self,
dim: int,
num_heads: int = 8,
attn_head_dim: Optional[int] = None,
dim_out: Optional[int] = None,
qkv_bias: bool = False,
qk_norm: bool = False,
scale_norm: bool = False,
@@ -37,36 +39,45 @@
proj_drop: float = 0.,
norm_layer: Optional[Type[nn.Module]] = None,
device=None,
dtype=None
dtype=None,
) -> None:
"""Initialize the Attention module.

Args:
dim: Input dimension of the token embeddings
num_heads: Number of attention heads
qkv_bias: Whether to use bias in the query, key, value projections
qk_norm: Whether to apply normalization to query and key vectors
proj_bias: Whether to use bias in the output projection
attn_drop: Dropout rate applied to the attention weights
proj_drop: Dropout rate applied after the output projection
norm_layer: Normalization layer constructor for QK normalization if enabled
dim: Input dimension of the token embeddings.
num_heads: Number of attention heads.
attn_head_dim: Dimension of each attention head. If None, computed as dim // num_heads.
dim_out: Output dimension. If None, same as dim.
qkv_bias: Whether to use bias in the query, key, value projections.
qk_norm: Whether to apply normalization to query and key vectors.
scale_norm: Whether to apply normalization to attention output before projection.
proj_bias: Whether to use bias in the output projection.
attn_drop: Dropout rate applied to the attention weights.
proj_drop: Dropout rate applied after the output projection.
norm_layer: Normalization layer constructor for QK normalization if enabled.
"""
super().__init__()
dd = {'device': device, 'dtype': dtype}
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
dim_out = dim_out or dim
head_dim = attn_head_dim
if head_dim is None:
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
head_dim = dim // num_heads
if qk_norm or scale_norm:
assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'

self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.head_dim = head_dim
self.attn_dim = num_heads * head_dim
self.scale = head_dim ** -0.5
self.fused_attn = use_fused_attn()

self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd)
self.q_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
self.qkv = nn.Linear(dim, self.attn_dim * 3, bias=qkv_bias, **dd)
self.q_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity()
self.k_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.norm = norm_layer(dim, **dd) if scale_norm else nn.Identity()
self.proj = nn.Linear(dim, dim, bias=proj_bias, **dd)
self.norm = norm_layer(self.attn_dim, **dd) if scale_norm else nn.Identity()
self.proj = nn.Linear(self.attn_dim, dim_out, bias=proj_bias, **dd)
self.proj_drop = nn.Dropout(proj_drop)

def forward(
@@ -93,7 +104,7 @@ def forward
attn = self.attn_drop(attn)
x = attn @ v

x = x.transpose(1, 2).reshape(B, N, C)
x = x.transpose(1, 2).reshape(B, N, self.attn_dim)
x = self.norm(x)
x = self.proj(x)
x = self.proj_drop(x)
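Taken together, the new attn_head_dim and dim_out arguments decouple the per-head width and the output width from dim. A minimal shape-check sketch, assuming Attention is importable from timm.layers.attention and that forward accepts just the token tensor:

import torch
from timm.layers.attention import Attention  # assumed import path

attn = Attention(
    dim=256,           # input token width
    num_heads=8,
    attn_head_dim=48,  # per-head width, no longer forced to dim // num_heads
    dim_out=512,       # output width, no longer forced to dim
)
x = torch.randn(2, 197, 256)
out = attn(x)
print(out.shape)  # expected: torch.Size([2, 197, 512])
# Internally attn_dim = num_heads * attn_head_dim = 384, so qkv maps 256 -> 3 * 384
# and proj maps 384 -> 512, matching the reshape to self.attn_dim in forward().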
@@ -114,6 +125,7 @@ def __init__(
self,
dim: int,
num_heads: int = 8,
dim_out: Optional[int] = None,
qkv_bias: bool = True,
qkv_fused: bool = True,
num_prefix_tokens: int = 1,
@@ -133,45 +145,52 @@
Args:
dim: Input dimension of the token embeddings
num_heads: Number of attention heads
dim_out: Output dimension. If None, same as dim.
qkv_bias: Whether to add a bias term to the query, key, and value projections
qkv_fused: Whether to use fused QKV projection (single linear) or separate projections
num_prefix_tokens: Number of reg/cls tokens at the beginning of the sequence that
should not have position embeddings applied
attn_drop: Dropout rate for attention weights
proj_drop: Dropout rate for the output projection
attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads)
attn_head_dim: Dimension of each attention head. If None, computed as dim // num_heads.
norm_layer: Normalization layer constructor to use for QK and scale normalization
qk_norm: Enable normalization of query (Q) and key (K) vectors with norm_layer
scale_norm: Enable normalization (scaling) of attention output with norm_layer
proj_bias: Whether to use bias in the output projection
rotate_half: Use 'half' ROPE layout instead of default 'interleaved'
"""
super().__init__()
dd = {'device': device, 'dtype': dtype}
dim_out = dim_out or dim
head_dim = attn_head_dim
if head_dim is None:
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
head_dim = dim // num_heads
if scale_norm or qk_norm:
assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'

self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
attn_dim = head_dim * self.num_heads
self.head_dim = head_dim
self.attn_dim = head_dim * num_heads
self.scale = head_dim ** -0.5
self.num_prefix_tokens = num_prefix_tokens
self.fused_attn = use_fused_attn()
self.rotate_half = rotate_half

if qkv_fused:
self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias, **dd)
self.qkv = nn.Linear(dim, self.attn_dim * 3, bias=qkv_bias, **dd)
self.q_proj = self.k_proj = self.v_proj = None
else:
self.qkv = None
self.q_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd)
self.k_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd)
self.v_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd)
self.q_proj = nn.Linear(dim, self.attn_dim, bias=qkv_bias, **dd)
self.k_proj = nn.Linear(dim, self.attn_dim, bias=qkv_bias, **dd)
self.v_proj = nn.Linear(dim, self.attn_dim, bias=qkv_bias, **dd)

self.q_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity()
self.k_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.norm = norm_layer(attn_dim, **dd) if scale_norm else nn.Identity()
self.proj = nn.Linear(attn_dim, dim, bias=proj_bias, **dd)
self.norm = norm_layer(self.attn_dim, **dd) if scale_norm else nn.Identity()
self.proj = nn.Linear(self.attn_dim, dim_out, bias=proj_bias, **dd)
self.proj_drop = nn.Dropout(proj_drop)

def forward(
@@ -188,18 +207,18 @@ def forward
attn_mask: Optional attention mask to apply during attention computation

Returns:
Tensor of shape (batch_size, sequence_length, embedding_dim)
Tensor of shape (batch_size, sequence_length, dim_out)
"""
B, N, C = x.shape

if self.qkv is not None:
qkv = self.qkv(x)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # B, num_heads, N, head_dim
else:
q = self.q_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2) # B, num_heads, N, C
k = self.k_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
v = self.v_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
q = self.q_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)

q, k = self.q_norm(q), self.k_norm(k)

@@ -224,7 +243,7 @@ def forward
attn = self.attn_drop(attn)
x = attn @ v

x = x.transpose(1, 2).reshape(B, N, C)
x = x.transpose(1, 2).reshape(B, N, self.attn_dim)
x = self.norm(x)
x = self.proj(x)
x = self.proj_drop(x)
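The same decoupling applies to the rope variant: with qkv_fused=False the three separate projections each map dim to num_heads * attn_head_dim, and the output projection maps attn_dim to dim_out. A quick parameter-shape sketch, assuming the class is AttentionRope under the same import path and that qk_norm / scale_norm default to False as in the fused Attention class:

from timm.layers.attention import AttentionRope  # assumed import path

m = AttentionRope(dim=256, num_heads=8, attn_head_dim=48, dim_out=320, qkv_fused=False)
print(m.q_proj.weight.shape)  # torch.Size([384, 256]): attn_dim x dim
print(m.v_proj.weight.shape)  # torch.Size([384, 256])
print(m.proj.weight.shape)    # torch.Size([320, 384]): dim_out x attn_dim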
1 change: 1 addition & 0 deletions timm/models/__init__.py
@@ -7,6 +7,7 @@
from .convmixer import *
from .convnext import *
from .crossvit import *
from .csatv2 import *
from .cspnet import *
from .davit import *
from .deit import *
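With the import in place, the new family is registered and discoverable through the usual timm APIs. A small sketch (the actual csatv2 model names depend on what csatv2.py registers):

import timm

names = timm.list_models('csatv2*')  # pattern match over registered model names
print(names)                         # whatever configs csatv2.py registers
# m = timm.create_model(names[0], pretrained=False)  # instantiate one of them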