diff --git a/tests/test_models.py b/tests/test_models.py index d73d4512fc..f7c647979c 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -56,7 +56,8 @@ 'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*', 'tnt', 'tiny_vit', 'vovnet', 'tresnet', 'rexnet', 'resnetv2', 'repghost', 'repvit', 'pvt_v2', 'nextvit', 'nest', 'mambaout', 'inception_next', 'inception_v4', 'hgnet', 'gcvit', 'focalnet', 'efficientformer_v2', 'edgenext', - 'davit', 'rdnet', 'convnext', 'pit', 'starnet', 'shvit', 'fasternet', 'swiftformer', 'ghostnet', 'naflexvit' + 'davit', 'rdnet', 'convnext', 'pit', 'starnet', 'shvit', 'fasternet', 'swiftformer', 'ghostnet', 'naflexvit', + 'csatv2' ] # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output. diff --git a/timm/layers/attention.py b/timm/layers/attention.py index 21329107fb..37bb6b6e9a 100644 --- a/timm/layers/attention.py +++ b/timm/layers/attention.py @@ -29,6 +29,8 @@ def __init__( self, dim: int, num_heads: int = 8, + attn_head_dim: Optional[int] = None, + dim_out: Optional[int] = None, qkv_bias: bool = False, qk_norm: bool = False, scale_norm: bool = False, @@ -37,36 +39,45 @@ def __init__( proj_drop: float = 0., norm_layer: Optional[Type[nn.Module]] = None, device=None, - dtype=None + dtype=None, ) -> None: """Initialize the Attention module. Args: - dim: Input dimension of the token embeddings - num_heads: Number of attention heads - qkv_bias: Whether to use bias in the query, key, value projections - qk_norm: Whether to apply normalization to query and key vectors - proj_bias: Whether to use bias in the output projection - attn_drop: Dropout rate applied to the attention weights - proj_drop: Dropout rate applied after the output projection - norm_layer: Normalization layer constructor for QK normalization if enabled + dim: Input dimension of the token embeddings. + num_heads: Number of attention heads. + attn_head_dim: Dimension of each attention head. If None, computed as dim // num_heads. + dim_out: Output dimension. If None, same as dim. + qkv_bias: Whether to use bias in the query, key, value projections. + qk_norm: Whether to apply normalization to query and key vectors. + scale_norm: Whether to apply normalization to attention output before projection. + proj_bias: Whether to use bias in the output projection. + attn_drop: Dropout rate applied to the attention weights. + proj_drop: Dropout rate applied after the output projection. + norm_layer: Normalization layer constructor for QK normalization if enabled. 
""" super().__init__() dd = {'device': device, 'dtype': dtype} - assert dim % num_heads == 0, 'dim should be divisible by num_heads' + dim_out = dim_out or dim + head_dim = attn_head_dim + if head_dim is None: + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + head_dim = dim // num_heads if qk_norm or scale_norm: assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True' + self.num_heads = num_heads - self.head_dim = dim // num_heads - self.scale = self.head_dim ** -0.5 + self.head_dim = head_dim + self.attn_dim = num_heads * head_dim + self.scale = head_dim ** -0.5 self.fused_attn = use_fused_attn() - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd) - self.q_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity() - self.k_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity() + self.qkv = nn.Linear(dim, self.attn_dim * 3, bias=qkv_bias, **dd) + self.q_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity() + self.k_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity() self.attn_drop = nn.Dropout(attn_drop) - self.norm = norm_layer(dim, **dd) if scale_norm else nn.Identity() - self.proj = nn.Linear(dim, dim, bias=proj_bias, **dd) + self.norm = norm_layer(self.attn_dim, **dd) if scale_norm else nn.Identity() + self.proj = nn.Linear(self.attn_dim, dim_out, bias=proj_bias, **dd) self.proj_drop = nn.Dropout(proj_drop) def forward( @@ -93,7 +104,7 @@ def forward( attn = self.attn_drop(attn) x = attn @ v - x = x.transpose(1, 2).reshape(B, N, C) + x = x.transpose(1, 2).reshape(B, N, self.attn_dim) x = self.norm(x) x = self.proj(x) x = self.proj_drop(x) @@ -114,6 +125,7 @@ def __init__( self, dim: int, num_heads: int = 8, + dim_out: Optional[int] = None, qkv_bias: bool = True, qkv_fused: bool = True, num_prefix_tokens: int = 1, @@ -133,45 +145,52 @@ def __init__( Args: dim: Input dimension of the token embeddings num_heads: Number of attention heads + dim_out: Output dimension. If None, same as dim. qkv_bias: Whether to add a bias term to the query, key, and value projections + qkv_fused: Whether to use fused QKV projection (single linear) or separate projections num_prefix_tokens: Number of reg/cls tokens at the beginning of the sequence that should not have position embeddings applied attn_drop: Dropout rate for attention weights proj_drop: Dropout rate for the output projection - attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads) + attn_head_dim: Dimension of each attention head. If None, computed as dim // num_heads. 
norm_layer: Normalization layer constructor to use for QK and scale normalization qk_norm: Enable normalization of query (Q) and key (K) vectors with norm_layer scale_norm: Enable normalization (scaling) of attention output with norm_layer + proj_bias: Whether to use bias in the output projection rotate_half: Use 'half' ROPE layout instead of default 'interleaved' """ super().__init__() dd = {'device': device, 'dtype': dtype} + dim_out = dim_out or dim + head_dim = attn_head_dim + if head_dim is None: + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + head_dim = dim // num_heads if scale_norm or qk_norm: assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True' + self.num_heads = num_heads - head_dim = dim // num_heads - if attn_head_dim is not None: - head_dim = attn_head_dim - attn_dim = head_dim * self.num_heads + self.head_dim = head_dim + self.attn_dim = head_dim * num_heads self.scale = head_dim ** -0.5 self.num_prefix_tokens = num_prefix_tokens self.fused_attn = use_fused_attn() self.rotate_half = rotate_half if qkv_fused: - self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias, **dd) + self.qkv = nn.Linear(dim, self.attn_dim * 3, bias=qkv_bias, **dd) self.q_proj = self.k_proj = self.v_proj = None else: self.qkv = None - self.q_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd) - self.k_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd) - self.v_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd) + self.q_proj = nn.Linear(dim, self.attn_dim, bias=qkv_bias, **dd) + self.k_proj = nn.Linear(dim, self.attn_dim, bias=qkv_bias, **dd) + self.v_proj = nn.Linear(dim, self.attn_dim, bias=qkv_bias, **dd) self.q_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity() self.k_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity() self.attn_drop = nn.Dropout(attn_drop) - self.norm = norm_layer(attn_dim, **dd) if scale_norm else nn.Identity() - self.proj = nn.Linear(attn_dim, dim, bias=proj_bias, **dd) + self.norm = norm_layer(self.attn_dim, **dd) if scale_norm else nn.Identity() + self.proj = nn.Linear(self.attn_dim, dim_out, bias=proj_bias, **dd) self.proj_drop = nn.Dropout(proj_drop) def forward( @@ -188,18 +207,18 @@ def forward( attn_mask: Optional attention mask to apply during attention computation Returns: - Tensor of shape (batch_size, sequence_length, embedding_dim) + Tensor of shape (batch_size, sequence_length, dim_out) """ B, N, C = x.shape if self.qkv is not None: qkv = self.qkv(x) - qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) # B, num_heads, N, head_dim else: - q = self.q_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2) # B, num_heads, N, C - k = self.k_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2) - v = self.v_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2) + q = self.q_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) + k = self.k_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) + v = self.v_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) q, k = self.q_norm(q), self.k_norm(k) @@ -224,7 +243,7 @@ def forward( attn = self.attn_drop(attn) x = attn @ v - x = x.transpose(1, 2).reshape(B, N, C) + x = x.transpose(1, 2).reshape(B, N, self.attn_dim) x = self.norm(x) x = self.proj(x) x = self.proj_drop(x) diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 
dc8cfa0359..e11c609bee 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -7,6 +7,7 @@ from .convmixer import * from .convnext import * from .crossvit import * +from .csatv2 import * from .cspnet import * from .davit import * from .deit import * diff --git a/timm/models/csatv2.py b/timm/models/csatv2.py new file mode 100644 index 0000000000..186498e805 --- /dev/null +++ b/timm/models/csatv2.py @@ -0,0 +1,797 @@ +"""CSATv2 + +A frequency-domain vision model using DCT transforms with spatial attention. + +Paper: TBD +""" +import math +import warnings +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.layers import trunc_normal_, DropPath, Mlp, LayerNorm2d, Attention, NormMlpClassifierHead +from timm.layers.grn import GlobalResponseNorm +from timm.models._builder import build_model_with_cfg +from timm.models._features import feature_take_indices +from timm.models._manipulate import checkpoint, checkpoint_seq +from ._registry import register_model, generate_default_cfgs + +__all__ = ['CSATv2', 'csatv2'] + +# DCT frequency normalization statistics (Y, Cb, Cr channels x 64 coefficients) +_DCT_MEAN = [ + [932.42657, -0.00260, 0.33415, -0.02840, 0.00003, -0.02792, -0.00183, 0.00006, + 0.00032, 0.03402, -0.00571, 0.00020, 0.00006, -0.00038, -0.00558, -0.00116, + -0.00000, -0.00047, -0.00008, -0.00030, 0.00942, 0.00161, -0.00009, -0.00006, + -0.00014, -0.00035, 0.00001, -0.00220, 0.00033, -0.00002, -0.00003, -0.00020, + 0.00007, -0.00000, 0.00005, 0.00293, -0.00004, 0.00006, 0.00019, 0.00004, + 0.00006, -0.00015, -0.00002, 0.00007, 0.00010, -0.00004, 0.00008, 0.00000, + 0.00008, -0.00001, 0.00015, 0.00002, 0.00007, 0.00003, 0.00004, -0.00001, + 0.00004, -0.00000, 0.00002, -0.00000, -0.00008, -0.00000, -0.00003, 0.00003], + [962.34735, -0.00428, 0.09835, 0.00152, -0.00009, 0.00312, -0.00141, -0.00001, + -0.00013, 0.01050, 0.00065, 0.00006, -0.00000, 0.00003, 0.00264, 0.00000, + 0.00001, 0.00007, -0.00006, 0.00003, 0.00341, 0.00163, 0.00004, 0.00003, + -0.00001, 0.00008, -0.00000, 0.00090, 0.00018, -0.00006, -0.00001, 0.00007, + -0.00003, -0.00001, 0.00006, 0.00084, -0.00000, -0.00001, 0.00000, 0.00004, + -0.00001, -0.00002, 0.00000, 0.00001, 0.00002, 0.00001, 0.00004, 0.00011, + 0.00000, -0.00003, 0.00011, -0.00002, 0.00001, 0.00001, 0.00001, 0.00001, + -0.00007, -0.00003, 0.00001, 0.00000, 0.00001, 0.00002, 0.00001, 0.00000], + [1053.16101, -0.00213, -0.09207, 0.00186, 0.00013, 0.00034, -0.00119, 0.00002, + 0.00011, -0.00984, 0.00046, -0.00007, -0.00001, -0.00005, 0.00180, 0.00042, + 0.00002, -0.00010, 0.00004, 0.00003, -0.00301, 0.00125, -0.00002, -0.00003, + -0.00001, -0.00001, -0.00001, 0.00056, 0.00021, 0.00001, -0.00001, 0.00002, + -0.00001, -0.00001, 0.00005, -0.00070, -0.00002, -0.00002, 0.00005, -0.00004, + -0.00000, 0.00002, -0.00002, 0.00001, 0.00000, -0.00003, 0.00004, 0.00007, + 0.00001, 0.00000, 0.00013, -0.00000, 0.00000, 0.00002, -0.00000, -0.00001, + -0.00004, -0.00003, 0.00000, 0.00001, -0.00001, 0.00001, -0.00000, 0.00000], +] + +_DCT_VAR = [ + [270372.37500, 6287.10645, 5974.94043, 1653.10889, 1463.91748, 1832.58997, 755.92468, 692.41528, + 648.57184, 641.46881, 285.79288, 301.62100, 380.43405, 349.84027, 374.15891, 190.30960, + 190.76746, 221.64578, 200.82646, 145.87979, 126.92046, 62.14622, 67.75562, 102.42001, + 129.74922, 130.04631, 103.12189, 97.76417, 53.17402, 54.81048, 73.48712, 81.04342, + 69.35100, 49.06024, 33.96053, 37.03279, 20.48858, 24.94830, 
33.90822, 44.54912, + 47.56363, 40.03160, 30.43313, 22.63899, 26.53739, 26.57114, 21.84404, 17.41557, + 15.18253, 10.69678, 11.24111, 12.97229, 15.08971, 15.31646, 8.90409, 7.44213, + 6.66096, 6.97719, 4.17834, 3.83882, 4.51073, 2.36646, 2.41363, 1.48266], + [18839.21094, 321.70932, 300.15259, 77.47830, 76.02293, 89.04748, 33.99642, 34.74807, + 32.12333, 28.19588, 12.04675, 14.26871, 18.45779, 16.59588, 15.67892, 7.37718, + 8.56312, 10.28946, 9.41013, 6.69090, 5.16453, 2.55186, 3.03073, 4.66765, + 5.85418, 5.74644, 4.33702, 3.66948, 1.95107, 2.26034, 3.06380, 3.50705, + 3.06359, 2.19284, 1.54454, 1.57860, 0.97078, 1.13941, 1.48653, 1.89996, + 1.95544, 1.64950, 1.24754, 0.93677, 1.09267, 1.09516, 0.94163, 0.78966, + 0.72489, 0.50841, 0.50909, 0.55664, 0.63111, 0.64125, 0.38847, 0.33378, + 0.30918, 0.33463, 0.20875, 0.19298, 0.21903, 0.13380, 0.13444, 0.09554], + [17127.39844, 292.81421, 271.45209, 66.64056, 63.60253, 76.35437, 28.06587, 27.84831, + 25.96656, 23.60370, 9.99173, 11.34992, 14.46955, 12.92553, 12.69353, 5.91537, + 6.60187, 7.90891, 7.32825, 5.32785, 4.29660, 2.13459, 2.44135, 3.66021, + 4.50335, 4.38959, 3.34888, 2.97181, 1.60633, 1.77010, 2.35118, 2.69018, + 2.38189, 1.74596, 1.26014, 1.31684, 0.79327, 0.92046, 1.17670, 1.47609, + 1.50914, 1.28725, 0.99898, 0.74832, 0.85736, 0.85800, 0.74663, 0.63508, + 0.58748, 0.41098, 0.41121, 0.44663, 0.50277, 0.51519, 0.31729, 0.27336, + 0.25399, 0.27241, 0.17353, 0.16255, 0.18440, 0.11602, 0.11511, 0.08450], +] + + +def _zigzag_permutation(rows: int, cols: int) -> List[int]: + """Generate zigzag scan order for DCT coefficients.""" + idx_matrix = np.arange(0, rows * cols, 1).reshape(rows, cols).tolist() + dia = [[] for _ in range(rows + cols - 1)] + zigzag = [] + for i in range(rows): + for j in range(cols): + s = i + j + if s % 2 == 0: + dia[s].insert(0, idx_matrix[i][j]) + else: + dia[s].append(idx_matrix[i][j]) + for d in dia: + zigzag.extend(d) + return zigzag + + +def _dct_kernel_type_2( + kernel_size: int, + orthonormal: bool, + device=None, + dtype=None, +) -> torch.Tensor: + """Generate Type-II DCT kernel matrix.""" + dd = dict(device=device, dtype=dtype) + x = torch.eye(kernel_size, **dd) + v = x.clone().contiguous().view(-1, kernel_size) + v = torch.cat([v, v.flip([1])], dim=-1) + v = torch.fft.fft(v, dim=-1)[:, :kernel_size] + try: + k = torch.tensor(-1j, **dd) * torch.pi * torch.arange(kernel_size, **dd)[None, :] + except AttributeError: + k = torch.tensor(-1j, **dd) * math.pi * torch.arange(kernel_size, **dd)[None, :] + k = torch.exp(k / (kernel_size * 2)) + v = v * k + v = v.real + if orthonormal: + v[:, 0] = v[:, 0] * torch.sqrt(torch.tensor(1 / (kernel_size * 4), **dd)) + v[:, 1:] = v[:, 1:] * torch.sqrt(torch.tensor(1 / (kernel_size * 2), **dd)) + v = v.contiguous().view(*x.shape) + return v + + +def _dct_kernel_type_3( + kernel_size: int, + orthonormal: bool, + device=None, + dtype=None, +) -> torch.Tensor: + """Generate Type-III DCT kernel matrix (inverse of Type-II).""" + return torch.linalg.inv(_dct_kernel_type_2(kernel_size, orthonormal, device, dtype)) + + +class Dct1d(nn.Module): + """1D Discrete Cosine Transform layer.""" + + def __init__( + self, + kernel_size: int, + kernel_type: int = 2, + orthonormal: bool = True, + device=None, + dtype=None, + ) -> None: + dd = dict(device=device, dtype=dtype) + super().__init__() + kernel = {'2': _dct_kernel_type_2, '3': _dct_kernel_type_3} + dct_weights = kernel[f'{kernel_type}'](kernel_size, orthonormal, **dd).T + self.register_buffer('weights', dct_weights) + 
self.register_parameter('bias', None) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return F.linear(x, self.weights, self.bias) + + +class Dct2d(nn.Module): + """2D Discrete Cosine Transform layer.""" + + def __init__( + self, + kernel_size: int, + kernel_type: int = 2, + orthonormal: bool = True, + device=None, + dtype=None, + ) -> None: + dd = dict(device=device, dtype=dtype) + super().__init__() + self.transform = Dct1d(kernel_size, kernel_type, orthonormal, **dd) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.transform(self.transform(x).transpose(-1, -2)).transpose(-1, -2) + + +class LearnableDct2d(nn.Module): + """Learnable 2D DCT stem with RGB to YCbCr conversion and frequency selection.""" + + def __init__( + self, + kernel_size: int, + kernel_type: int = 2, + orthonormal: bool = True, + device=None, + dtype=None, + ) -> None: + dd = dict(device=device, dtype=dtype) + super().__init__() + self.k = kernel_size + self.unfold = nn.Unfold(kernel_size=(kernel_size, kernel_size), stride=(kernel_size, kernel_size)) + self.transform = Dct2d(kernel_size, kernel_type, orthonormal, **dd) + self.permutation = _zigzag_permutation(kernel_size, kernel_size) + self.conv_y = nn.Conv2d(kernel_size ** 2, 24, kernel_size=1, padding=0, **dd) + self.conv_cb = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0, **dd) + self.conv_cr = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0, **dd) + + self.register_buffer('mean', torch.tensor(_DCT_MEAN, device=device), persistent=False) + self.register_buffer('var', torch.tensor(_DCT_VAR, device=device), persistent=False) + self.register_buffer('imagenet_mean', torch.tensor([0.485, 0.456, 0.406], device=device), persistent=False) + self.register_buffer('imagenet_std', torch.tensor([0.229, 0.224, 0.225], device=device), persistent=False) + + def _denormalize(self, x: torch.Tensor) -> torch.Tensor: + """Convert from ImageNet normalized to [0, 255] range.""" + return x.mul(self.imagenet_std).add_(self.imagenet_mean) * 255 + + def _rgb_to_ycbcr(self, x: torch.Tensor) -> torch.Tensor: + """Convert RGB to YCbCr color space.""" + y = (x[:, :, :, 0] * 0.299) + (x[:, :, :, 1] * 0.587) + (x[:, :, :, 2] * 0.114) + cb = 0.564 * (x[:, :, :, 2] - y) + 128 + cr = 0.713 * (x[:, :, :, 0] - y) + 128 + return torch.stack([y, cb, cr], dim=-1) + + def _frequency_normalize(self, x: torch.Tensor) -> torch.Tensor: + """Normalize DCT coefficients using precomputed statistics.""" + std = self.var ** 0.5 + 1e-8 + return (x - self.mean) / std + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, c, h, w = x.shape + x = x.permute(0, 2, 3, 1) + x = self._denormalize(x) + x = self._rgb_to_ycbcr(x) + x = x.permute(0, 3, 1, 2) + x = self.unfold(x).transpose(-1, -2) + x = x.reshape(b, h // self.k, w // self.k, c, self.k, self.k) + x = self.transform(x) + x = x.reshape(-1, c, self.k * self.k) + x = x[:, :, self.permutation] + x = self._frequency_normalize(x) + x = x.reshape(b, h // self.k, w // self.k, c, -1) + x = x.permute(0, 3, 4, 1, 2).contiguous() + x_y = self.conv_y(x[:, 0]) + x_cb = self.conv_cb(x[:, 1]) + x_cr = self.conv_cr(x[:, 2]) + return torch.cat([x_y, x_cb, x_cr], dim=1) + + +class Dct2dStats(nn.Module): + """Utility module to compute DCT coefficient statistics.""" + + def __init__( + self, + kernel_size: int, + kernel_type: int = 2, + orthonormal: bool = True, + device=None, + dtype=None, + ) -> None: + dd = dict(device=device, dtype=dtype) + super().__init__() + self.k = kernel_size + self.unfold = 
nn.Unfold(kernel_size=(kernel_size, kernel_size), stride=(kernel_size, kernel_size)) + self.transform = Dct2d(kernel_size, kernel_type, orthonormal, **dd) + self.permutation = _zigzag_permutation(kernel_size, kernel_size) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + b, c, h, w = x.shape + x = self.unfold(x).transpose(-1, -2) + x = x.reshape(b, h // self.k, w // self.k, c, self.k, self.k) + x = self.transform(x) + x = x.reshape(-1, c, self.k * self.k) + x = x[:, :, self.permutation] + x = x.reshape(b * (h // self.k) * (w // self.k), c, -1) + + mean_list = torch.zeros([3, 64]) + var_list = torch.zeros([3, 64]) + for i in range(3): + mean_list[i] = torch.mean(x[:, i], dim=0) + var_list[i] = torch.var(x[:, i], dim=0) + return mean_list, var_list + + +class Block(nn.Module): + """ConvNeXt-style block with spatial attention.""" + + def __init__( + self, + dim: int, + drop_path: float = 0., + device=None, + dtype=None, + ) -> None: + dd = dict(device=device, dtype=dtype) + super().__init__() + self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, **dd) + self.norm = nn.LayerNorm(dim, eps=1e-6, **dd) + self.pwconv1 = nn.Linear(dim, 4 * dim, **dd) + self.act = nn.GELU() + self.grn = GlobalResponseNorm(4 * dim, channels_last=True, **dd) + self.pwconv2 = nn.Linear(4 * dim, dim, **dd) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.attn = SpatialAttention(**dd) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.dwconv(x) + x = x.permute(0, 2, 3, 1) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.grn(x) + x = self.pwconv2(x) + x = x.permute(0, 3, 1, 2) + + attn = self.attn(x) + attn = F.interpolate(attn, size=x.shape[2:], mode='bilinear', align_corners=True) + x = x * attn + + return shortcut + self.drop_path(x) + + +class SpatialTransformerBlock(nn.Module): + """Lightweight transformer block for spatial attention (1-channel, 7x7 grid). + + This is a simplified transformer with single-head, 1-dim attention over spatial + positions. Used inside SpatialAttention where input is 1 channel at 7x7 resolution. 
+ """ + + def __init__( + self, + device=None, + dtype=None, + ) -> None: + dd = dict(device=device, dtype=dtype) + super().__init__() + # Single-head attention with 1-dim q/k/v (no output projection needed) + self.pos_embed = PosConv(in_chans=1, **dd) + self.norm1 = nn.LayerNorm(1, **dd) + self.qkv = nn.Linear(1, 3, bias=False, **dd) + + # Feedforward: 1 -> 4 -> 1 + self.norm2 = nn.LayerNorm(1, **dd) + self.mlp = Mlp(1, 4, 1, act_layer=nn.GELU, **dd) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, C, H, W = x.shape + + # Attention block + shortcut = x + x_t = x.flatten(2).transpose(1, 2) # (B, N, 1) + x_t = self.norm1(x_t) + x_t = self.pos_embed(x_t, (H, W)) + + # Simple single-head attention with scalar q/k/v + qkv = self.qkv(x_t) # (B, N, 3) + q, k, v = qkv.unbind(-1) # each (B, N) + attn = (q @ k.transpose(-1, -2)).softmax(dim=-1) # (B, N, N) + x_t = (attn @ v).unsqueeze(-1) # (B, N, 1) + + x_t = x_t.transpose(1, 2).reshape(B, C, H, W) + x = shortcut + x_t + + # Feedforward block + shortcut = x + x_t = x.flatten(2).transpose(1, 2) + x_t = self.mlp(self.norm2(x_t)) + x_t = x_t.transpose(1, 2).reshape(B, C, H, W) + x = shortcut + x_t + + return x + + +class SpatialAttention(nn.Module): + """Spatial attention module using channel statistics and transformer.""" + + def __init__( + self, + device=None, + dtype=None, + ) -> None: + dd = dict(device=device, dtype=dtype) + super().__init__() + self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) + self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, **dd) + self.attn = SpatialTransformerBlock(**dd) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_avg = x.mean(dim=1, keepdim=True) + x_max = x.amax(dim=1, keepdim=True) + x = torch.cat([x_avg, x_max], dim=1) + x = self.avgpool(x) + x = self.conv(x) + x = self.attn(x) + return x + + +class TransformerBlock(nn.Module): + """Transformer block with optional downsampling and convolutional position encoding.""" + + def __init__( + self, + inp: int, + oup: int, + num_heads: int = 8, + attn_head_dim: int = 32, + downsample: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + drop_path: float = 0., + device=None, + dtype=None, + ) -> None: + dd = dict(device=device, dtype=dtype) + super().__init__() + hidden_dim = int(inp * 4) + self.downsample = downsample + + if self.downsample: + self.pool1 = nn.MaxPool2d(3, 2, 1) + self.pool2 = nn.MaxPool2d(3, 2, 1) + self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False, **dd) + else: + self.pool1 = nn.Identity() + self.pool2 = nn.Identity() + self.proj = nn.Identity() + + self.pos_embed = PosConv(in_chans=inp, **dd) + self.norm1 = nn.LayerNorm(inp, **dd) + self.attn = Attention( + dim=inp, + num_heads=num_heads, + attn_head_dim=attn_head_dim, + dim_out=oup, + attn_drop=attn_drop, + proj_drop=proj_drop, + **dd, + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = nn.LayerNorm(oup, **dd) + self.mlp = Mlp(oup, hidden_dim, oup, act_layer=nn.GELU, drop=proj_drop, **dd) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.downsample: + shortcut = self.proj(self.pool1(x)) + x_t = self.pool2(x) + B, C, H, W = x_t.shape + x_t = x_t.flatten(2).transpose(1, 2) + x_t = self.norm1(x_t) + x_t = self.pos_embed(x_t, (H, W)) + x_t = self.attn(x_t) + x_t = x_t.transpose(1, 2).reshape(B, -1, H, W) + x = shortcut + self.drop_path1(x_t) + else: + B, C, H, W = x.shape + shortcut = x + x_t = x.flatten(2).transpose(1, 2) + x_t = self.norm1(x_t) + x_t = self.pos_embed(x_t, (H, W)) + x_t = self.attn(x_t) + x_t = x_t.transpose(1, 2).reshape(B, -1, H, W) + x = shortcut + self.drop_path1(x_t) + + # MLP block + B, C, H, W = x.shape + shortcut = x + x_t = x.flatten(2).transpose(1, 2) + x_t = self.mlp(self.norm2(x_t)) + x_t = x_t.transpose(1, 2).reshape(B, C, H, W) + x = shortcut + self.drop_path2(x_t) + + return x + + +class PosConv(nn.Module): + """Convolutional position encoding.""" + + def __init__( + self, + in_chans: int, + device=None, + dtype=None, + ) -> None: + dd = dict(device=device, dtype=dtype) + super().__init__() + self.proj = nn.Conv2d(in_chans, in_chans, kernel_size=3, stride=1, padding=1, bias=True, groups=in_chans, **dd) + + def forward(self, x: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor: + B, N, C = x.shape + H, W = size + cnn_feat = x.transpose(1, 2).view(B, C, H, W) + x = self.proj(cnn_feat) + cnn_feat + return x.flatten(2).transpose(1, 2) + + +class CSATv2(nn.Module): + """CSATv2: Frequency-domain vision model with spatial attention. + + A hybrid architecture that processes images in the DCT frequency domain + with ConvNeXt-style blocks and transformer attention. + """ + + def __init__( + self, + num_classes: int = 1000, + in_chans: int = 3, + dims: Tuple[int, ...] = (32, 72, 168, 386), + depths: Tuple[int, ...] = (2, 2, 8, 6), + transformer_depths: Tuple[int, ...] = (0, 0, 2, 2), + drop_path_rate: float = 0.0, + transformer_drop_path: bool = False, + global_pool: str = 'avg', + device=None, + dtype=None, + **kwargs, + ) -> None: + dd = dict(device=device, dtype=dtype) + super().__init__() + if in_chans != 3: + warnings.warn( + f'CSATv2 is designed for 3-channel RGB input. ' + f'in_chans={in_chans} may not work correctly with the DCT stem.' + ) + self.num_classes = num_classes + self.global_pool = global_pool + self.grad_checkpointing = False + + self.num_features = dims[-1] + self.head_hidden_size = self.num_features + + # Build feature_info dynamically + self.feature_info = [dict(num_chs=dims[0], reduction=8, module='stem_dct')] + reduction = 8 + for i, dim in enumerate(dims): + if i > 0: + reduction *= 2 + self.feature_info.append(dict(num_chs=dim, reduction=reduction, module=f'stages.{i}')) + + # Build drop path rates for all blocks (0 for transformer blocks when transformer_drop_path=False) + total_blocks = sum(depths) if transformer_drop_path else sum(d - t for d, t in zip(depths, transformer_depths)) + dp_iter = iter(torch.linspace(0, drop_path_rate, total_blocks).tolist()) + dp_rates = [] + for depth, t_depth in zip(depths, transformer_depths): + dp_rates += [next(dp_iter) for _ in range(depth - t_depth)] + dp_rates += [next(dp_iter) if transformer_drop_path else 0. 
for _ in range(t_depth)] + + self.stem_dct = LearnableDct2d(8, **dd) + + # Build stages dynamically + dp_iter = iter(dp_rates) + stages = [] + for i, (dim, depth, t_depth) in enumerate(zip(dims, depths, transformer_depths)): + layers = ( + # Downsample at start of stage (except first stage) + ([nn.Conv2d(dims[i - 1], dim, kernel_size=2, stride=2, **dd)] if i > 0 else []) + + # Conv blocks + [Block(dim=dim, drop_path=next(dp_iter), **dd) for _ in range(depth - t_depth)] + + # Transformer blocks at end of stage + [TransformerBlock(inp=dim, oup=dim, drop_path=next(dp_iter), **dd) for _ in range(t_depth)] + + # Trailing LayerNorm (except last stage) + ([LayerNorm2d(dim, eps=1e-6, **dd)] if i < len(depths) - 1 else []) + ) + stages.append(nn.Sequential(*layers)) + self.stages = nn.Sequential(*stages) + + self.head = NormMlpClassifierHead(dims[-1], num_classes, pool_type=global_pool, **dd) + + self.apply(self._init_weights) + + def _init_weights(self, m: nn.Module) -> None: + if isinstance(m, (nn.Conv2d, nn.Linear)): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + @torch.jit.ignore + def get_classifier(self) -> nn.Module: + return self.head.fc + + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None: + self.num_classes = num_classes + if global_pool is not None: + self.global_pool = global_pool + self.head.reset(num_classes, pool_type=global_pool) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable: bool = True) -> None: + self.grad_checkpointing = enable + + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + x = self.stem_dct(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.stages, x) + else: + x = self.stages(x) + return x + + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """Forward pass returning intermediate features. + + Args: + x: Input image tensor. + indices: Indices of features to return (0=stem_dct, 1-4=stages). None returns all. + norm: Apply norm layer to final intermediate (unused, for API compat). + stop_early: Stop iterating when last desired intermediate is reached. + output_fmt: Output format, must be 'NCHW'. + intermediates_only: Only return intermediate features. + + Returns: + List of intermediate features or tuple of (final features, intermediates). + """ + assert output_fmt == 'NCHW', 'Output format must be NCHW.' 
+ intermediates = [] + # 5 feature levels: stem_dct (0) + stages 0-3 (1-4) + take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices) + + x = self.stem_dct(x) + if 0 in take_indices: + intermediates.append(x) + + if torch.jit.is_scripting() or not stop_early: + stages = self.stages + else: + # max_index is 0-4, stages are 1-4, so we need max_index stages + stages = self.stages[:max_index] if max_index > 0 else [] + + for feat_idx, stage in enumerate(stages): + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(stage, x) + else: + x = stage(x) + if feat_idx + 1 in take_indices: # +1 because stem is index 0 + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ) -> List[int]: + """Prune layers not required for specified intermediates. + + Args: + indices: Indices of intermediate layers to keep (0=stem_dct, 1-4=stages). + prune_norm: Whether to prune the final norm layer. + prune_head: Whether to prune the classifier head. + + Returns: + List of indices that were kept. + """ + # 5 feature levels: stem_dct (0) + stages 0-3 (1-4) + take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices) + # max_index is 0-4, stages are 1-4, so we keep max_index stages + self.stages = self.stages[:max_index] if max_index > 0 else nn.Sequential() + + if prune_norm: + self.head.norm = nn.Identity() + if prune_head: + self.reset_classifier(0, '') + + return take_indices + + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + return self.head(x, pre_logits=pre_logits) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.forward_features(x) + return self.forward_head(x) + + +default_cfgs = generate_default_cfgs({ + 'csatv2': { + 'url': 'https://huggingface.co/Hyunil/CSATv2/resolve/main/CSATv2_ImageNet_timm.pth', + 'num_classes': 1000, + 'input_size': (3, 512, 512), + 'mean': (0.485, 0.456, 0.406), + 'std': (0.229, 0.224, 0.225), + 'interpolation': 'bilinear', + 'crop_pct': 1.0, + 'classifier': 'head.fc', + 'first_conv': [], + }, +}) + + +def checkpoint_filter_fn(state_dict: dict, model: nn.Module) -> dict: + """Remap original CSATv2 checkpoint to timm format. + + Handles two key structural changes: + 1. Stage naming: stages1/2/3/4 -> stages.0/1/2/3 + 2. 
Downsample position: moved from end of stage N to start of stage N+1 + """ + if 'stages.0.0.grn.weight' in state_dict: + return state_dict # Already in timm format + + import re + + # Downsample indices in original checkpoint (Conv2d at end of each stage) + # These move to index 0 of the next stage + downsample_idx = {1: 3, 2: 3, 3: 9} # stage -> downsample index + + def remap_stage(m): + stage = int(m.group(1)) + idx = int(m.group(2)) + rest = m.group(3) + if stage in downsample_idx and idx == downsample_idx[stage]: + # Downsample moves to start of next stage + return f'stages.{stage}.0.{rest}' + elif stage == 1: + # Stage 1 -> stages.0, indices unchanged + return f'stages.0.{idx}.{rest}' + else: + # Stages 2-4 -> stages.1-3, indices shift +1 (after downsample) + return f'stages.{stage - 1}.{idx + 1}.{rest}' + + out_dict = {} + for k, v in state_dict.items(): + # Remap dct -> stem_dct, and Y_Conv/Cb_Conv/Cr_Conv -> conv_y/conv_cb/conv_cr + k = re.sub(r'^dct\.', 'stem_dct.', k) + k = k.replace('.Y_Conv.', '.conv_y.') + k = k.replace('.Cb_Conv.', '.conv_cb.') + k = k.replace('.Cr_Conv.', '.conv_cr.') + + # Remap stage names with index adjustments for downsample relocation + k = re.sub(r'^stages([1-4])\.(\d+)\.(.*)$', remap_stage, k) + + # Remap GRN: gamma/beta -> weight/bias with reshape + if 'grn.gamma' in k: + k = k.replace('grn.gamma', 'grn.weight') + v = v.reshape(-1) + elif 'grn.beta' in k: + k = k.replace('grn.beta', 'grn.bias') + v = v.reshape(-1) + + # Remap FeedForward (nn.Sequential) to Mlp: net.0 -> fc1, net.3 -> fc2 + # Also rename ff -> mlp, ff_norm -> norm2, attn_norm -> norm1 + if '.ff.net.0.' in k: + k = k.replace('.ff.net.0.', '.mlp.fc1.') + elif '.ff.net.3.' in k: + k = k.replace('.ff.net.3.', '.mlp.fc2.') + elif '.ff_norm.' in k: + k = k.replace('.ff_norm.', '.norm2.') + elif '.attn_norm.' in k: + k = k.replace('.attn_norm.', '.norm1.') + + # Block.attention -> Block.attn (SpatialAttention) + # SpatialAttention.attention -> SpatialAttention.attn (SpatialTransformerBlock) + # Handle nested .attention.attention. first, then remaining .attention. + if '.attention.attention.' in k: + # SpatialTransformerBlock inner attn: remap to_qkv -> qkv + k = k.replace('.attention.attention.attn.to_qkv.', '.attn.attn.qkv.') + k = k.replace('.attention.attention.attn.', '.attn.attn.') + k = k.replace('.attention.attention.', '.attn.attn.') + elif '.attention.' in k: + # Block.attention -> Block.attn (catches SpatialAttention.conv etc) + k = k.replace('.attention.', '.attn.') + + # TransformerBlock: remap attention layer names + # to_qkv -> qkv, to_out.0 -> proj, attn.pos_embed -> pos_embed + # Note: only for TransformerBlock, not SpatialTransformerBlock (which has .attn.attn.) + if '.attn.to_qkv.' in k: + k = k.replace('.attn.to_qkv.', '.attn.qkv.') + elif '.attn.to_out.0.' in k: + k = k.replace('.attn.to_out.0.', '.attn.proj.') + + # TransformerBlock: .attn.pos_embed -> .pos_embed (but not .attn.attn.pos_embed) + if '.attn.pos_embed.' in k and '.attn.attn.' 
not in k: + k = k.replace('.attn.pos_embed.', '.pos_embed.') + + # Remap head -> head.fc, norm -> head.norm (order matters) + k = re.sub(r'^head\.', 'head.fc.', k) + k = re.sub(r'^norm\.', 'head.norm.', k) + + out_dict[k] = v + + return out_dict + + +def _create_csatv2(variant: str, pretrained: bool = False, **kwargs) -> CSATv2: + out_indices = kwargs.pop('out_indices', (1, 2, 3, 4)) + return build_model_with_cfg( + CSATv2, + variant, + pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(out_indices=out_indices, flatten_sequential=True), + default_cfg=default_cfgs[variant], + **kwargs, + ) + + +@register_model +def csatv2(pretrained: bool = False, **kwargs) -> CSATv2: + return _create_csatv2('csatv2', pretrained, **kwargs)
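
Example of the extended timm.layers.Attention interface, as a minimal sketch: attn_head_dim decouples the per-head width from the input dim, and dim_out sets the output projection width. The 168/8/32 values mirror TransformerBlock's defaults above; the tensor sizes are illustrative.

    import torch
    from timm.layers import Attention

    # 8 heads x 32 dims gives a 256-dim internal attention space, independent of the
    # 168-dim input: qkv becomes Linear(168, 768) and proj becomes Linear(256, 168).
    attn = Attention(dim=168, num_heads=8, attn_head_dim=32, dim_out=168)
    x = torch.randn(2, 196, 168)  # (B, N, C)
    print(attn(x).shape)          # torch.Size([2, 196, 168])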
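
A sketch of what the DCT stem produces, assuming this branch is importable. LearnableDct2d and _zigzag_permutation are internal names from the file above, and the random input is only for shape checking (real inputs are ImageNet-normalized RGB).

    import torch
    from timm.models.csatv2 import LearnableDct2d, _zigzag_permutation

    # Zigzag scan order of a 3x3 block, the same routine used for 8x8 DCT coefficients.
    print(_zigzag_permutation(3, 3))  # [0, 1, 3, 6, 4, 2, 5, 7, 8]

    # The stem folds 8x8 pixel blocks into per-block DCT coefficients, then 1x1 convs
    # keep 24 Y + 4 Cb + 4 Cr channel mixes = 32 channels at 1/8 resolution.
    stem = LearnableDct2d(8)
    print(stem(torch.randn(2, 3, 256, 256)).shape)  # torch.Size([2, 32, 32, 32])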
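
Usage sketch for the new model and its feature API. Shapes assume the 512x512 default config; pretrained=False keeps the example independent of the hosted checkpoint.

    import timm
    import torch

    model = timm.create_model('csatv2', pretrained=False)
    x = torch.randn(1, 3, 512, 512)
    print(model(x).shape)  # torch.Size([1, 1000])

    # Index 0 is the DCT stem (reduction 8); indices 1-4 are the four stages.
    feats = model.forward_intermediates(x, indices=(1, 2, 3, 4), intermediates_only=True)
    print([tuple(f.shape) for f in feats])
    # [(1, 32, 64, 64), (1, 72, 32, 32), (1, 168, 16, 16), (1, 386, 8, 8)]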
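
A small illustration of the checkpoint key remapping, derived from the rules in checkpoint_filter_fn above. The dummy tensors and the None model argument are placeholders; the function does not use the model.

    import torch
    from timm.models.csatv2 import checkpoint_filter_fn

    fake = {
        'dct.Y_Conv.weight': torch.zeros(1),        # DCT stem rename
        'stages1.3.weight': torch.zeros(1),         # stage-1 trailing downsample -> start of next stage
        'stages3.5.dwconv.weight': torch.zeros(1),  # 1-based stage names, block index shifts by +1
        'norm.weight': torch.zeros(1),              # final norm folds into the head
    }
    print(sorted(checkpoint_filter_fn(fake, None)))
    # ['head.norm.weight', 'stages.1.0.weight', 'stages.2.6.dwconv.weight', 'stem_dct.conv_y.weight']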