From 78d6b196ab1f4f93a7a1bf44e5da3895e975b47c Mon Sep 17 00:00:00 2001 From: gusdlf93 Date: Sat, 6 Dec 2025 23:44:34 +0900 Subject: [PATCH 01/14] Upload CSATv2 --- timm/models/__init__.py | 1 + timm/models/csatv2.py | 612 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 613 insertions(+) create mode 100644 timm/models/csatv2.py diff --git a/timm/models/__init__.py b/timm/models/__init__.py index dc8cfa0359..e11c609bee 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -7,6 +7,7 @@ from .convmixer import * from .convnext import * from .crossvit import * +from .csatv2 import * from .cspnet import * from .davit import * from .deit import * diff --git a/timm/models/csatv2.py b/timm/models/csatv2.py new file mode 100644 index 0000000000..27881fe71c --- /dev/null +++ b/timm/models/csatv2.py @@ -0,0 +1,612 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import warnings +from timm.models._builder import build_model_with_cfg +from ._registry import register_model, generate_default_cfgs +from typing import List +import numpy as np +import math + +__all__ = ['DCT2D'] + +# Helper Functions +mean=[[932.42657,-0.00260,0.33415,-0.02840,0.00003,-0.02792,-0.00183,0.00006,0.00032,0.03402,-0.00571,0.00020,0.00006,-0.00038,-0.00558,-0.00116,-0.00000,-0.00047,-0.00008,-0.00030,0.00942,0.00161,-0.00009,-0.00006,-0.00014,-0.00035,0.00001,-0.00220,0.00033,-0.00002,-0.00003,-0.00020,0.00007,-0.00000,0.00005,0.00293,-0.00004,0.00006,0.00019,0.00004,0.00006,-0.00015,-0.00002,0.00007,0.00010,-0.00004,0.00008,0.00000,0.00008,-0.00001,0.00015,0.00002,0.00007,0.00003,0.00004,-0.00001,0.00004,-0.00000,0.00002,-0.00000,-0.00008,-0.00000,-0.00003,0.00003], +[962.34735,-0.00428,0.09835,0.00152,-0.00009,0.00312,-0.00141,-0.00001,-0.00013,0.01050,0.00065,0.00006,-0.00000,0.00003,0.00264,0.00000,0.00001,0.00007,-0.00006,0.00003,0.00341,0.00163,0.00004,0.00003,-0.00001,0.00008,-0.00000,0.00090,0.00018,-0.00006,-0.00001,0.00007,-0.00003,-0.00001,0.00006,0.00084,-0.00000,-0.00001,0.00000,0.00004,-0.00001,-0.00002,0.00000,0.00001,0.00002,0.00001,0.00004,0.00011,0.00000,-0.00003,0.00011,-0.00002,0.00001,0.00001,0.00001,0.00001,-0.00007,-0.00003,0.00001,0.00000,0.00001,0.00002,0.00001,0.00000], +[1053.16101,-0.00213,-0.09207,0.00186,0.00013,0.00034,-0.00119,0.00002,0.00011,-0.00984,0.00046,-0.00007,-0.00001,-0.00005,0.00180,0.00042,0.00002,-0.00010,0.00004,0.00003,-0.00301,0.00125,-0.00002,-0.00003,-0.00001,-0.00001,-0.00001,0.00056,0.00021,0.00001,-0.00001,0.00002,-0.00001,-0.00001,0.00005,-0.00070,-0.00002,-0.00002,0.00005,-0.00004,-0.00000,0.00002,-0.00002,0.00001,0.00000,-0.00003,0.00004,0.00007,0.00001,0.00000,0.00013,-0.00000,0.00000,0.00002,-0.00000,-0.00001,-0.00004,-0.00003,0.00000,0.00001,-0.00001,0.00001,-0.00000,0.00000]] + +var=[[270372.37500,6287.10645,5974.94043,1653.10889,1463.91748,1832.58997,755.92468,692.41528,648.57184,641.46881,285.79288,301.62100,380.43405,349.84027,374.15891,190.30960,190.76746,221.64578,200.82646,145.87979,126.92046,62.14622,67.75562,102.42001,129.74922,130.04631,103.12189,97.76417,53.17402,54.81048,73.48712,81.04342,69.35100,49.06024,33.96053,37.03279,20.48858,24.94830,33.90822,44.54912,47.56363,40.03160,30.43313,22.63899,26.53739,26.57114,21.84404,17.41557,15.18253,10.69678,11.24111,12.97229,15.08971,15.31646,8.90409,7.44213,6.66096,6.97719,4.17834,3.83882,4.51073,2.36646,2.41363,1.48266], 
+[18839.21094,321.70932,300.15259,77.47830,76.02293,89.04748,33.99642,34.74807,32.12333,28.19588,12.04675,14.26871,18.45779,16.59588,15.67892,7.37718,8.56312,10.28946,9.41013,6.69090,5.16453,2.55186,3.03073,4.66765,5.85418,5.74644,4.33702,3.66948,1.95107,2.26034,3.06380,3.50705,3.06359,2.19284,1.54454,1.57860,0.97078,1.13941,1.48653,1.89996,1.95544,1.64950,1.24754,0.93677,1.09267,1.09516,0.94163,0.78966,0.72489,0.50841,0.50909,0.55664,0.63111,0.64125,0.38847,0.33378,0.30918,0.33463,0.20875,0.19298,0.21903,0.13380,0.13444,0.09554], +[17127.39844,292.81421,271.45209,66.64056,63.60253,76.35437,28.06587,27.84831,25.96656,23.60370,9.99173,11.34992,14.46955,12.92553,12.69353,5.91537,6.60187,7.90891,7.32825,5.32785,4.29660,2.13459,2.44135,3.66021,4.50335,4.38959,3.34888,2.97181,1.60633,1.77010,2.35118,2.69018,2.38189,1.74596,1.26014,1.31684,0.79327,0.92046,1.17670,1.47609,1.50914,1.28725,0.99898,0.74832,0.85736,0.85800,0.74663,0.63508,0.58748,0.41098,0.41121,0.44663,0.50277,0.51519,0.31729,0.27336,0.25399,0.27241,0.17353,0.16255,0.18440,0.11602,0.11511,0.08450]] + + +def _zigzag_permutation(rows: int, cols: int) -> List[int]: + idx_matrix = np.arange(0, rows * cols, 1).reshape(rows, cols).tolist() + dia = [[] for _ in range(rows + cols - 1)] + zigzag = [] + for i in range(rows): + for j in range(cols): + s = i + j + if s % 2 == 0: + dia[s].insert(0, idx_matrix[i][j]) + else: + dia[s].append(idx_matrix[i][j]) + for d in dia: + zigzag.extend(d) + return zigzag + + +def _dct_kernel_type_2(kernel_size: int, orthonormal: bool, device=None, dtype=None) -> torch.Tensor: + factory_kwargs = dict(device=device, dtype=dtype) + x = torch.eye(kernel_size, **factory_kwargs) + v = x.clone().contiguous().view(-1, kernel_size) + v = torch.cat([v, v.flip([1])], dim=-1) + v = torch.fft.fft(v, dim=-1)[:, :kernel_size] + try: + k = torch.tensor(-1j, **factory_kwargs) * torch.pi * torch.arange(kernel_size, **factory_kwargs)[None, :] + except: + k = torch.tensor(-1j, **factory_kwargs) * math.pi * torch.arange(kernel_size, **factory_kwargs)[None, :] + k = torch.exp(k / (kernel_size * 2)) + v = v * k + v = v.real + if orthonormal: + v[:, 0] = v[:, 0] * torch.sqrt(torch.tensor(1 / (kernel_size * 4), **factory_kwargs)) + v[:, 1:] = v[:, 1:] * torch.sqrt(torch.tensor(1 / (kernel_size * 2), **factory_kwargs)) + v = v.contiguous().view(*x.shape) + return v + + +def _dct_kernel_type_3(kernel_size: int, orthonormal: bool, device=None, dtype=None) -> torch.Tensor: + return torch.linalg.inv(_dct_kernel_type_2(kernel_size, orthonormal, device, dtype)) + + +class _DCT1D(nn.Module): + def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = True, + device=None, dtype=None) -> None: + factory_kwargs = dict(device=device, dtype=dtype) + super(_DCT1D, self).__init__() + kernel = {'2': _dct_kernel_type_2, '3': _dct_kernel_type_3} + self.weights = nn.Parameter(kernel[f'{kernel_type}'](kernel_size, orthonormal, **factory_kwargs).T, False) + self.register_parameter('bias', None) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return nn.functional.linear(x, self.weights, self.bias) + + +class _DCT2D(nn.Module): + def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = True, + device=None, dtype=None) -> None: + factory_kwargs = dict(device=device, dtype=dtype) + super(_DCT2D, self).__init__() + self.transform = _DCT1D(kernel_size, kernel_type, orthonormal, **factory_kwargs) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.transform(self.transform(x).transpose(-1, 
-2)).transpose(-1, -2) + + +class Learnable_DCT2D(nn.Module): + def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = True, + device=None, dtype=None) -> None: + factory_kwargs = dict(device=device, dtype=dtype) + super(Learnable_DCT2D, self).__init__() + self.k = kernel_size + self.unfold = nn.Unfold(kernel_size=(kernel_size, kernel_size), stride=(kernel_size, kernel_size)) + self.transform = _DCT2D(kernel_size, kernel_type, orthonormal, **factory_kwargs) + self.permutation = _zigzag_permutation(kernel_size, kernel_size) + self.Y_Conv = nn.Conv2d(kernel_size ** 2, 24, kernel_size=1, padding=0) + self.Cb_Conv = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0) + self.Cr_Conv = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0) + self.mean = torch.tensor(mean, requires_grad=False) + self.var = torch.tensor(var, requires_grad=False) + self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406], requires_grad=False) + self.imagenet_var = torch.tensor([0.229, 0.224, 0.225], requires_grad=False) + + def denormalize(self, x): + x = x.multiply(self.imagenet_var.to(x.device)).add_(self.imagenet_mean.to(x.device)) * 255 + return x + + def rgb2ycbcr(self, x): + y = (x[:, :, :, 0] * 0.299) + (x[:, :, :, 1] * 0.587) + (x[:, :, :, 2] * 0.114) + cb = 0.564 * (x[:, :, :, 2] - y) + 128 + cr = 0.713 * (x[:, :, :, 0] - y) + 128 + x = torch.stack([y, cb, cr], dim=-1) + return x + + def frequncy_normalize(self, x): + x[:, 0, ].sub_(self.mean.to(x.device)[0]).div_((self.var.to(x.device)[0] ** 0.5 + 1e-8)) + x[:, 1, ].sub_(self.mean.to(x.device)[1]).div_((self.var.to(x.device)[1] ** 0.5 + 1e-8)) + x[:, 2, ].sub_(self.mean.to(x.device)[2]).div_((self.var.to(x.device)[2] ** 0.5 + 1e-8)) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, c, h, w = x.shape + x = x.permute(0, 2, 3, 1) + x = self.denormalize(x) + x = self.rgb2ycbcr(x) + x = x.permute(0, 3, 1, 2) + x = self.unfold(x).transpose(-1, -2) + x = x.reshape(b, h // self.k, w // self.k, c, self.k, self.k) + x = self.transform(x) + x = x.reshape(-1, c, self.k * self.k) + x = x[:, :, self.permutation] + x = self.frequncy_normalize(x) + x = x.reshape(b, h // self.k, w // self.k, c, -1) + x = x.permute(0, 3, 4, 1, 2).contiguous() + x_Y = self.Y_Conv(x[:, 0, ]) + x_Cb = self.Cb_Conv(x[:, 1, ]) + x_Cr = self.Cr_Conv(x[:, 2, ]) + x = torch.cat([x_Y, x_Cb, x_Cr], axis=1) + return x + + +class DCT2D(nn.Module): + def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = True, + device=None, dtype=None) -> None: + factory_kwargs = dict(device=device, dtype=dtype) + super(DCT2D, self).__init__() + self.k = kernel_size + self.unfold = nn.Unfold(kernel_size=(kernel_size, kernel_size), stride=(kernel_size, kernel_size)) + self.transform = _DCT2D(kernel_size, kernel_type, orthonormal, **factory_kwargs) + self.permutation = _zigzag_permutation(kernel_size, kernel_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, c, h, w = x.shape + x = self.unfold(x).transpose(-1, -2) + x = x.reshape(b, h // self.k, w // self.k, c, self.k, self.k) + x = self.transform(x) + x = x.reshape(-1, c, self.k * self.k) + x = x[:, :, self.permutation] + x = x.reshape(b * (h // self.k) * (w // self.k), c, -1) + + mean_list = torch.zeros([3, 64]) + var_list = torch.zeros([3, 64]) + mean_list[0] = torch.mean(x[:, 0, ], axis=0) + mean_list[1] = torch.mean(x[:, 1, ], axis=0) + mean_list[2] = torch.mean(x[:, 2, ], axis=0) + var_list[0] = torch.var(x[:, 0, ], axis=0) + var_list[1] = torch.var(x[:, 1, ], axis=0) + 
var_list[2] = torch.var(x[:, 2, ], axis=0) + return mean_list, var_list + + +class Block(nn.Module): + def __init__(self, dim, drop_path=0.): + super().__init__() + self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) + self.norm = LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, 4 * dim) + self.act = nn.GELU() + self.grn = GRN(4 * dim) + self.pwconv2 = nn.Linear(4 * dim, dim) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.attention = Spatial_Attention() + + def forward(self, x): + input = x + x = self.dwconv(x) + x = x.permute(0, 2, 3, 1) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.grn(x) + x = self.pwconv2(x) + x = x.permute(0, 3, 1, 2) + + # Spatial Attention logic + attention = self.attention(x) + # Upsampling to match x spatial size + x = x * nn.UpsamplingBilinear2d(x.shape[2:])(attention) + + x = input + self.drop_path(x) + return x + + +class Spatial_Attention(nn.Module): + def __init__(self): + super().__init__() + self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) + self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3) + # downsample=False, dropout=0. by default + self.attention = TransformerBlock(1, 1, heads=1, dim_head=1, img_size=[7, 7], downsample=False) + + def forward(self, x): + x_avg = x.mean([1]).unsqueeze(1) + x_max = x.max(dim=1).values.unsqueeze(1) + x = torch.cat([x_avg, x_max], dim=1) + x = self.avgpool(x) + x = self.conv(x) + x = self.attention(x) + return x + + +class TransformerBlock(nn.Module): + """ + Refactored TransformerBlock without einops and PreNorm class wrapper. + Manual reshaping is performed in forward(). + """ + + def __init__(self, inp, oup, heads=8, dim_head=32, img_size=None, downsample=False, dropout=0.): + super().__init__() + hidden_dim = int(inp * 4) + + self.downsample = downsample + self.ih, self.iw = img_size + + if self.downsample: + self.pool1 = nn.MaxPool2d(3, 2, 1) + self.pool2 = nn.MaxPool2d(3, 2, 1) + self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False) + + # Attention block components + # Note: In old code, PreNorm wrapped Attention. Here we split them. + self.attn_norm = nn.LayerNorm(inp) + self.attn = Attention(inp, oup, heads, dim_head, dropout) + + # FeedForward block components + self.ff_norm = nn.LayerNorm(oup) + self.ff = FeedForward(oup, hidden_dim, dropout) + + def forward(self, x): + # x shape: (B, C, H, W) + if self.downsample: + # Identity path with projection + shortcut = self.proj(self.pool1(x)) + + # Attention path + x_t = self.pool2(x) + B, C, H, W = x_t.shape + + # Flatten spatial: (B, C, H, W) -> (B, H*W, C) + x_t = x_t.flatten(2).transpose(1, 2) + + # PreNorm -> Attention + x_t = self.attn_norm(x_t) + x_t = self.attn(x_t) + + # Unflatten: (B, H*W, C) -> (B, C, H, W) + x_t = x_t.transpose(1, 2).reshape(B, C, H, W) + + x = shortcut + x_t + else: + # Standard PreNorm Residual Attention + B, C, H, W = x.shape + shortcut = x + + # Flatten + x_t = x.flatten(2).transpose(1, 2) + + # PreNorm -> Attention + x_t = self.attn_norm(x_t) + x_t = self.attn(x_t) + + # Unflatten + x_t = x_t.transpose(1, 2).reshape(B, C, H, W) + + x = shortcut + x_t + + # FeedForward Block + B, C, H, W = x.shape + shortcut = x + + # Flatten + x_t = x.flatten(2).transpose(1, 2) + + # PreNorm -> FeedForward + x_t = self.ff_norm(x_t) + x_t = self.ff(x_t) + + # Unflatten + x_t = x_t.transpose(1, 2).reshape(B, C, H, W) + + x = shortcut + x_t + + return x + + +class Attention(nn.Module): + """ + Refactored Attention without einops.rearrange. 
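+
+    For reference, the manual reshapes in forward() stand in for the einops patterns
+    below (shape sketch only; the einops calls are shown for illustration and are not used):
+
+        qkv: (B, N, 3*heads*dim_head) -> reshape/permute -> 3 x (B, heads, N, dim_head)
+             ~ rearrange(qkv, 'b n (three h d) -> three b h n d', three=3, h=heads)
+        out: (B, heads, N, dim_head) -> transpose/reshape -> (B, N, heads*dim_head)
+             ~ rearrange(out, 'b h n d -> b n (h d)')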
+ """ + + def __init__(self, inp, oup, heads=8, dim_head=32, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == inp) + + self.heads = heads + self.dim_head = dim_head + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim=-1) + self.to_qkv = nn.Linear(inp, inner_dim * 3, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, oup), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + self.pos_embed = PosCNN(in_chans=inp) + + def forward(self, x): + # x shape: (B, N, C) + # Positional Embedding (expects B, N, C -> internally converts to spatial) + x = self.pos_embed(x) + + B, N, C = x.shape + + # Generate Q, K, V + # qkv: (B, N, 3 * heads * dim_head) + qkv = self.to_qkv(x) + + # Reshape to (B, N, 3, heads, dim_head) and permute to (3, B, heads, N, dim_head) + qkv = qkv.reshape(B, N, 3, self.heads, self.dim_head).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0], qkv[1], qkv[2] # Each is (B, heads, N, dim_head) + + # Attention Score + # (B, heads, N, dim_head) @ (B, heads, dim_head, N) -> (B, heads, N, N) + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + attn = self.attend(dots) + + # Weighted Sum + # (B, heads, N, N) @ (B, heads, N, dim_head) -> (B, heads, N, dim_head) + out = torch.matmul(attn, v) + + # Rearrange output: (B, heads, N, dim_head) -> (B, N, heads*dim_head) + out = out.transpose(1, 2).reshape(B, N, -1) + + out = self.to_out(out) + return out + + +class CSATv2(nn.Module): + def __init__( + self, + img_size=512, + num_classes=1000, + in_chans=3, + drop_path_rate=0.0, + global_pool='avg', + **kwargs, + ): + super().__init__() + + self.num_classes = num_classes + self.global_pool = global_pool + + if isinstance(img_size, (tuple, list)): + img_size = img_size[0] + + self.img_size = img_size + + dims = [32, 72, 168, 386] + channel_order = "channels_first" + depths = [2, 2, 6, 4] + dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + + self.dct = Learnable_DCT2D(8) + + self.stages1 = nn.Sequential( + Block(dim=dims[0], drop_path=dp_rates[0]), + Block(dim=dims[0], drop_path=dp_rates[1]), + LayerNorm(dims[0], eps=1e-6, data_format=channel_order), + nn.Conv2d(dims[0], dims[0 + 1], kernel_size=2, stride=2), + ) + + self.stages2 = nn.Sequential( + Block(dim=dims[1], drop_path=dp_rates[0]), + Block(dim=dims[1], drop_path=dp_rates[1]), + LayerNorm(dims[1], eps=1e-6, data_format=channel_order), + nn.Conv2d(dims[1], dims[1 + 1], kernel_size=2, stride=2), + ) + + self.stages3 = nn.Sequential( + Block(dim=dims[2], drop_path=dp_rates[0]), + Block(dim=dims[2], drop_path=dp_rates[1]), + Block(dim=dims[2], drop_path=dp_rates[2]), + Block(dim=dims[2], drop_path=dp_rates[3]), + Block(dim=dims[2], drop_path=dp_rates[4]), + Block(dim=dims[2], drop_path=dp_rates[5]), + TransformerBlock( + inp=dims[2], + oup=dims[2], + img_size=[int(self.img_size / 32), int(self.img_size / 32)], + ), + TransformerBlock( + inp=dims[2], + oup=dims[2], + img_size=[int(self.img_size / 32), int(self.img_size / 32)], + ), + LayerNorm(dims[2], eps=1e-6, data_format=channel_order), + nn.Conv2d(dims[2], dims[2 + 1], kernel_size=2, stride=2), + ) + + self.stages4 = nn.Sequential( + Block(dim=dims[3], drop_path=dp_rates[0]), + Block(dim=dims[3], drop_path=dp_rates[1]), + Block(dim=dims[3], drop_path=dp_rates[2]), + Block(dim=dims[3], drop_path=dp_rates[3]), + TransformerBlock( + inp=dims[3], + oup=dims[3], + img_size=[int(self.img_size / 64), int(self.img_size / 64)], + ), + TransformerBlock( + inp=dims[3], + 
oup=dims[3], + img_size=[int(self.img_size / 64), int(self.img_size / 64)], + ), + ) + + self.norm = nn.LayerNorm(dims[-1], eps=1e-6) + + self.head = nn.Linear(dims[-1], num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward_features(self, x): + x = self.dct(x) + x = self.stages1(x) + x = self.stages2(x) + x = self.stages3(x) + x = self.stages4(x) + x = x.mean(dim=(-2, -1)) + x = self.norm(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + if pre_logits: + return x + x = self.head(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +# --- Components like LayerNorm, GRN, DropPath, FeedForward, PosCNN, trunc_normal_ --- +# (이 부분은 einops와 무관하므로 위 코드와 동일하게 유지합니다. 여기서는 공간 절약을 위해 생략) +# 기존 코드의 LayerNorm, GRN, DropPath, FeedForward, PosCNN, trunc_normal_ 함수를 그대로 사용하세요. + +class LayerNorm(nn.Module): + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.normalized_shape = (normalized_shape,) + + def forward(self, x): + if self.data_format == "channels_last": + return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + elif self.data_format == "channels_first": + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +class GRN(nn.Module): + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + if drop_prob == 0. 
or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout=0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + + def forward(self, x): + return self.net(x) + + +class PosCNN(nn.Module): + def __init__(self, in_chans): + super(PosCNN, self).__init__() + self.proj = nn.Conv2d(in_chans, in_chans, kernel_size=3, stride=1, padding=1, bias=True, groups=in_chans) + + def forward(self, x): + B, N, C = x.shape + feat_token = x + H, W = int(N ** 0.5), int(N ** 0.5) + cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W) + x = self.proj(cnn_feat) + cnn_feat + x = x.flatten(2).transpose(1, 2) + return x + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + def norm_cdf(x): + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_.", stacklevel=2) + with torch.no_grad(): + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + tensor.uniform_(2 * l - 1, 2 * u - 1) + tensor.erfinv_() + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + tensor.clamp_(min=a, max=b) + return tensor + + +default_cfgs = generate_default_cfgs({ + 'csatv2': { + 'url': 'https://huggingface.co/Hyunil/CSATv2/resolve/main/CSATv2_ImageNet_timm.pth', + 'num_classes': 1000, + 'input_size': (3, 512, 512), + 'mean': (0.485, 0.456, 0.406), + 'std': (0.229, 0.224, 0.225), + 'interpolation': 'bilinear', + 'crop_pct': 1.0, + }, +}) + + +def _create_csatv2(variant: str, pretrained: bool = False, **kwargs) -> CSATv2: + return build_model_with_cfg( + CSATv2, + variant, + pretrained, + default_cfg=default_cfgs[variant], + **kwargs, + ) + + +@register_model +def csatv2(pretrained: bool = False, **kwargs) -> CSATv2: + model_args = dict( + img_size=kwargs.pop('img_size', 512), + num_classes=kwargs.pop('num_classes', 1000), + ) + return _create_csatv2('csatv2', pretrained, **dict(model_args, **kwargs)) \ No newline at end of file From c5843040f1c6a819b3237ac641dc8f3322c25054 Mon Sep 17 00:00:00 2001 From: gusdlf93 Date: Sun, 7 Dec 2025 00:00:58 +0900 Subject: [PATCH 02/14] Update csatv2.py --- timm/models/csatv2.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/timm/models/csatv2.py b/timm/models/csatv2.py index 27881fe71c..75a01680ba 100644 --- a/timm/models/csatv2.py +++ b/timm/models/csatv2.py @@ -470,11 +470,6 @@ def forward(self, x): x = self.forward_head(x) return x - -# --- Components like LayerNorm, GRN, DropPath, FeedForward, PosCNN, trunc_normal_ --- -# (이 부분은 einops와 무관하므로 위 코드와 동일하게 유지합니다. 여기서는 공간 절약을 위해 생략) -# 기존 코드의 LayerNorm, GRN, DropPath, FeedForward, PosCNN, trunc_normal_ 함수를 그대로 사용하세요. 
- class LayerNorm(nn.Module): def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): super().__init__() @@ -609,4 +604,5 @@ def csatv2(pretrained: bool = False, **kwargs) -> CSATv2: img_size=kwargs.pop('img_size', 512), num_classes=kwargs.pop('num_classes', 1000), ) - return _create_csatv2('csatv2', pretrained, **dict(model_args, **kwargs)) \ No newline at end of file + + return _create_csatv2('csatv2', pretrained, **dict(model_args, **kwargs)) From 5dc2e6d32efc2df740e509dc2d99d62f2c0bd0d2 Mon Sep 17 00:00:00 2001 From: gusdlf93 Date: Sun, 7 Dec 2025 16:22:29 +0900 Subject: [PATCH 03/14] Update csatv2.py --- timm/models/csatv2.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/timm/models/csatv2.py b/timm/models/csatv2.py index 75a01680ba..bd386ea20d 100644 --- a/timm/models/csatv2.py +++ b/timm/models/csatv2.py @@ -135,7 +135,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x_Y = self.Y_Conv(x[:, 0, ]) x_Cb = self.Cb_Conv(x[:, 1, ]) x_Cr = self.Cr_Conv(x[:, 2, ]) - x = torch.cat([x_Y, x_Cb, x_Cr], axis=1) + x = torch.cat([x_Y, x_Cb, x_Cr], dim=1) return x @@ -160,12 +160,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: mean_list = torch.zeros([3, 64]) var_list = torch.zeros([3, 64]) - mean_list[0] = torch.mean(x[:, 0, ], axis=0) - mean_list[1] = torch.mean(x[:, 1, ], axis=0) - mean_list[2] = torch.mean(x[:, 2, ], axis=0) - var_list[0] = torch.var(x[:, 0, ], axis=0) - var_list[1] = torch.var(x[:, 1, ], axis=0) - var_list[2] = torch.var(x[:, 2, ], axis=0) + mean_list[0] = torch.mean(x[:, 0, ], dim=0) + mean_list[1] = torch.mean(x[:, 1, ], dim=0) + mean_list[2] = torch.mean(x[:, 2, ], dim=0) + var_list[0] = torch.var(x[:, 0, ], dim=0) + var_list[1] = torch.var(x[:, 1, ], dim=0) + var_list[2] = torch.var(x[:, 2, ], dim=0) return mean_list, var_list @@ -470,6 +470,11 @@ def forward(self, x): x = self.forward_head(x) return x + +# --- Components like LayerNorm, GRN, DropPath, FeedForward, PosCNN, trunc_normal_ --- +# (이 부분은 einops와 무관하므로 위 코드와 동일하게 유지합니다. 여기서는 공간 절약을 위해 생략) +# 기존 코드의 LayerNorm, GRN, DropPath, FeedForward, PosCNN, trunc_normal_ 함수를 그대로 사용하세요. + class LayerNorm(nn.Module): def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): super().__init__() @@ -604,5 +609,4 @@ def csatv2(pretrained: bool = False, **kwargs) -> CSATv2: img_size=kwargs.pop('img_size', 512), num_classes=kwargs.pop('num_classes', 1000), ) - return _create_csatv2('csatv2', pretrained, **dict(model_args, **kwargs)) From f1de3274fe18e28221dc2a4b93347cac6273e5cc Mon Sep 17 00:00:00 2001 From: gusdlf93 Date: Mon, 8 Dec 2025 18:16:49 +0900 Subject: [PATCH 04/14] Add files via upload --- timm/models/csatv2.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/timm/models/csatv2.py b/timm/models/csatv2.py index bd386ea20d..ea74e91f77 100644 --- a/timm/models/csatv2.py +++ b/timm/models/csatv2.py @@ -194,8 +194,12 @@ def forward(self, x): # Spatial Attention logic attention = self.attention(x) - # Upsampling to match x spatial size - x = x * nn.UpsamplingBilinear2d(x.shape[2:])(attention) + + # [Fix] nn.UpsamplingBilinear2d 클래스 생성 -> F.interpolate 함수 사용 + # align_corners=False가 최신 기본값에 가깝습니다. 
(성능 차이는 미미함) + attention = F.interpolate(attention, size=x.shape[2:], mode='bilinear', align_corners=False) + + x = x * attention x = input + self.drop_path(x) return x @@ -236,6 +240,11 @@ def __init__(self, inp, oup, heads=8, dim_head=32, img_size=None, downsample=Fal self.pool1 = nn.MaxPool2d(3, 2, 1) self.pool2 = nn.MaxPool2d(3, 2, 1) self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False) + else: + # [Fix] JIT 컴파일 에러 방지: 사용하지 않더라도 속성을 정의해야 함 + self.pool1 = nn.Identity() + self.pool2 = nn.Identity() + self.proj = nn.Identity() # Attention block components # Note: In old code, PreNorm wrapped Attention. Here we split them. @@ -476,6 +485,12 @@ def forward(self, x): # 기존 코드의 LayerNorm, GRN, DropPath, FeedForward, PosCNN, trunc_normal_ 함수를 그대로 사용하세요. class LayerNorm(nn.Module): + """ LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). + """ + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): super().__init__() self.weight = nn.Parameter(torch.ones(normalized_shape)) @@ -489,7 +504,9 @@ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): def forward(self, x): if self.data_format == "channels_last": return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) - elif self.data_format == "channels_first": + else: + # [Fix] elif -> else로 변경 + # JIT이 "모든 경로에서 Tensor가 반환됨"을 알 수 있게 함 u = x.mean(1, keepdim=True) s = (x - u).pow(2).mean(1, keepdim=True) x = (x - u) / torch.sqrt(s + self.eps) From 24a884b3148a14608f0de088fc1c4d447a3457f9 Mon Sep 17 00:00:00 2001 From: gusdlf93 Date: Mon, 8 Dec 2025 18:19:41 +0900 Subject: [PATCH 05/14] Add files via upload --- csatv2.py | 628 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 628 insertions(+) create mode 100644 csatv2.py diff --git a/csatv2.py b/csatv2.py new file mode 100644 index 0000000000..f70c3237f5 --- /dev/null +++ b/csatv2.py @@ -0,0 +1,628 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import warnings +from timm.models._builder import build_model_with_cfg +from ._registry import register_model, generate_default_cfgs +from typing import List +import numpy as np +import math + +__all__ = ['DCT2D'] + +# Helper Functions +mean=[[932.42657,-0.00260,0.33415,-0.02840,0.00003,-0.02792,-0.00183,0.00006,0.00032,0.03402,-0.00571,0.00020,0.00006,-0.00038,-0.00558,-0.00116,-0.00000,-0.00047,-0.00008,-0.00030,0.00942,0.00161,-0.00009,-0.00006,-0.00014,-0.00035,0.00001,-0.00220,0.00033,-0.00002,-0.00003,-0.00020,0.00007,-0.00000,0.00005,0.00293,-0.00004,0.00006,0.00019,0.00004,0.00006,-0.00015,-0.00002,0.00007,0.00010,-0.00004,0.00008,0.00000,0.00008,-0.00001,0.00015,0.00002,0.00007,0.00003,0.00004,-0.00001,0.00004,-0.00000,0.00002,-0.00000,-0.00008,-0.00000,-0.00003,0.00003], 
+[962.34735,-0.00428,0.09835,0.00152,-0.00009,0.00312,-0.00141,-0.00001,-0.00013,0.01050,0.00065,0.00006,-0.00000,0.00003,0.00264,0.00000,0.00001,0.00007,-0.00006,0.00003,0.00341,0.00163,0.00004,0.00003,-0.00001,0.00008,-0.00000,0.00090,0.00018,-0.00006,-0.00001,0.00007,-0.00003,-0.00001,0.00006,0.00084,-0.00000,-0.00001,0.00000,0.00004,-0.00001,-0.00002,0.00000,0.00001,0.00002,0.00001,0.00004,0.00011,0.00000,-0.00003,0.00011,-0.00002,0.00001,0.00001,0.00001,0.00001,-0.00007,-0.00003,0.00001,0.00000,0.00001,0.00002,0.00001,0.00000], +[1053.16101,-0.00213,-0.09207,0.00186,0.00013,0.00034,-0.00119,0.00002,0.00011,-0.00984,0.00046,-0.00007,-0.00001,-0.00005,0.00180,0.00042,0.00002,-0.00010,0.00004,0.00003,-0.00301,0.00125,-0.00002,-0.00003,-0.00001,-0.00001,-0.00001,0.00056,0.00021,0.00001,-0.00001,0.00002,-0.00001,-0.00001,0.00005,-0.00070,-0.00002,-0.00002,0.00005,-0.00004,-0.00000,0.00002,-0.00002,0.00001,0.00000,-0.00003,0.00004,0.00007,0.00001,0.00000,0.00013,-0.00000,0.00000,0.00002,-0.00000,-0.00001,-0.00004,-0.00003,0.00000,0.00001,-0.00001,0.00001,-0.00000,0.00000]] + +var=[[270372.37500,6287.10645,5974.94043,1653.10889,1463.91748,1832.58997,755.92468,692.41528,648.57184,641.46881,285.79288,301.62100,380.43405,349.84027,374.15891,190.30960,190.76746,221.64578,200.82646,145.87979,126.92046,62.14622,67.75562,102.42001,129.74922,130.04631,103.12189,97.76417,53.17402,54.81048,73.48712,81.04342,69.35100,49.06024,33.96053,37.03279,20.48858,24.94830,33.90822,44.54912,47.56363,40.03160,30.43313,22.63899,26.53739,26.57114,21.84404,17.41557,15.18253,10.69678,11.24111,12.97229,15.08971,15.31646,8.90409,7.44213,6.66096,6.97719,4.17834,3.83882,4.51073,2.36646,2.41363,1.48266], +[18839.21094,321.70932,300.15259,77.47830,76.02293,89.04748,33.99642,34.74807,32.12333,28.19588,12.04675,14.26871,18.45779,16.59588,15.67892,7.37718,8.56312,10.28946,9.41013,6.69090,5.16453,2.55186,3.03073,4.66765,5.85418,5.74644,4.33702,3.66948,1.95107,2.26034,3.06380,3.50705,3.06359,2.19284,1.54454,1.57860,0.97078,1.13941,1.48653,1.89996,1.95544,1.64950,1.24754,0.93677,1.09267,1.09516,0.94163,0.78966,0.72489,0.50841,0.50909,0.55664,0.63111,0.64125,0.38847,0.33378,0.30918,0.33463,0.20875,0.19298,0.21903,0.13380,0.13444,0.09554], +[17127.39844,292.81421,271.45209,66.64056,63.60253,76.35437,28.06587,27.84831,25.96656,23.60370,9.99173,11.34992,14.46955,12.92553,12.69353,5.91537,6.60187,7.90891,7.32825,5.32785,4.29660,2.13459,2.44135,3.66021,4.50335,4.38959,3.34888,2.97181,1.60633,1.77010,2.35118,2.69018,2.38189,1.74596,1.26014,1.31684,0.79327,0.92046,1.17670,1.47609,1.50914,1.28725,0.99898,0.74832,0.85736,0.85800,0.74663,0.63508,0.58748,0.41098,0.41121,0.44663,0.50277,0.51519,0.31729,0.27336,0.25399,0.27241,0.17353,0.16255,0.18440,0.11602,0.11511,0.08450]] + + +def _zigzag_permutation(rows: int, cols: int) -> List[int]: + idx_matrix = np.arange(0, rows * cols, 1).reshape(rows, cols).tolist() + dia = [[] for _ in range(rows + cols - 1)] + zigzag = [] + for i in range(rows): + for j in range(cols): + s = i + j + if s % 2 == 0: + dia[s].insert(0, idx_matrix[i][j]) + else: + dia[s].append(idx_matrix[i][j]) + for d in dia: + zigzag.extend(d) + return zigzag + + +def _dct_kernel_type_2(kernel_size: int, orthonormal: bool, device=None, dtype=None) -> torch.Tensor: + factory_kwargs = dict(device=device, dtype=dtype) + x = torch.eye(kernel_size, **factory_kwargs) + v = x.clone().contiguous().view(-1, kernel_size) + v = torch.cat([v, v.flip([1])], dim=-1) + v = torch.fft.fft(v, dim=-1)[:, :kernel_size] + try: + k = torch.tensor(-1j, 
**factory_kwargs) * torch.pi * torch.arange(kernel_size, **factory_kwargs)[None, :] + except: + k = torch.tensor(-1j, **factory_kwargs) * math.pi * torch.arange(kernel_size, **factory_kwargs)[None, :] + k = torch.exp(k / (kernel_size * 2)) + v = v * k + v = v.real + if orthonormal: + v[:, 0] = v[:, 0] * torch.sqrt(torch.tensor(1 / (kernel_size * 4), **factory_kwargs)) + v[:, 1:] = v[:, 1:] * torch.sqrt(torch.tensor(1 / (kernel_size * 2), **factory_kwargs)) + v = v.contiguous().view(*x.shape) + return v + + +def _dct_kernel_type_3(kernel_size: int, orthonormal: bool, device=None, dtype=None) -> torch.Tensor: + return torch.linalg.inv(_dct_kernel_type_2(kernel_size, orthonormal, device, dtype)) + + +class _DCT1D(nn.Module): + def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = True, + device=None, dtype=None) -> None: + factory_kwargs = dict(device=device, dtype=dtype) + super(_DCT1D, self).__init__() + kernel = {'2': _dct_kernel_type_2, '3': _dct_kernel_type_3} + self.weights = nn.Parameter(kernel[f'{kernel_type}'](kernel_size, orthonormal, **factory_kwargs).T, False) + self.register_parameter('bias', None) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return nn.functional.linear(x, self.weights, self.bias) + + +class _DCT2D(nn.Module): + def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = True, + device=None, dtype=None) -> None: + factory_kwargs = dict(device=device, dtype=dtype) + super(_DCT2D, self).__init__() + self.transform = _DCT1D(kernel_size, kernel_type, orthonormal, **factory_kwargs) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.transform(self.transform(x).transpose(-1, -2)).transpose(-1, -2) + + +class Learnable_DCT2D(nn.Module): + def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = True, + device=None, dtype=None) -> None: + factory_kwargs = dict(device=device, dtype=dtype) + super(Learnable_DCT2D, self).__init__() + self.k = kernel_size + self.unfold = nn.Unfold(kernel_size=(kernel_size, kernel_size), stride=(kernel_size, kernel_size)) + self.transform = _DCT2D(kernel_size, kernel_type, orthonormal, **factory_kwargs) + self.permutation = _zigzag_permutation(kernel_size, kernel_size) + self.Y_Conv = nn.Conv2d(kernel_size ** 2, 24, kernel_size=1, padding=0) + self.Cb_Conv = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0) + self.Cr_Conv = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0) + self.mean = torch.tensor(mean, requires_grad=False) + self.var = torch.tensor(var, requires_grad=False) + self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406], requires_grad=False) + self.imagenet_var = torch.tensor([0.229, 0.224, 0.225], requires_grad=False) + + def denormalize(self, x): + x = x.multiply(self.imagenet_var.to(x.device)).add_(self.imagenet_mean.to(x.device)) * 255 + return x + + def rgb2ycbcr(self, x): + y = (x[:, :, :, 0] * 0.299) + (x[:, :, :, 1] * 0.587) + (x[:, :, :, 2] * 0.114) + cb = 0.564 * (x[:, :, :, 2] - y) + 128 + cr = 0.713 * (x[:, :, :, 0] - y) + 128 + x = torch.stack([y, cb, cr], dim=-1) + return x + + def frequncy_normalize(self, x): + x[:, 0, ].sub_(self.mean.to(x.device)[0]).div_((self.var.to(x.device)[0] ** 0.5 + 1e-8)) + x[:, 1, ].sub_(self.mean.to(x.device)[1]).div_((self.var.to(x.device)[1] ** 0.5 + 1e-8)) + x[:, 2, ].sub_(self.mean.to(x.device)[2]).div_((self.var.to(x.device)[2] ** 0.5 + 1e-8)) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, c, h, w = x.shape + x = x.permute(0, 2, 3, 1) + x = 
self.denormalize(x) + x = self.rgb2ycbcr(x) + x = x.permute(0, 3, 1, 2) + x = self.unfold(x).transpose(-1, -2) + x = x.reshape(b, h // self.k, w // self.k, c, self.k, self.k) + x = self.transform(x) + x = x.reshape(-1, c, self.k * self.k) + x = x[:, :, self.permutation] + x = self.frequncy_normalize(x) + x = x.reshape(b, h // self.k, w // self.k, c, -1) + x = x.permute(0, 3, 4, 1, 2).contiguous() + x_Y = self.Y_Conv(x[:, 0, ]) + x_Cb = self.Cb_Conv(x[:, 1, ]) + x_Cr = self.Cr_Conv(x[:, 2, ]) + x = torch.cat([x_Y, x_Cb, x_Cr], dim=1) + return x + + +class DCT2D(nn.Module): + def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = True, + device=None, dtype=None) -> None: + factory_kwargs = dict(device=device, dtype=dtype) + super(DCT2D, self).__init__() + self.k = kernel_size + self.unfold = nn.Unfold(kernel_size=(kernel_size, kernel_size), stride=(kernel_size, kernel_size)) + self.transform = _DCT2D(kernel_size, kernel_type, orthonormal, **factory_kwargs) + self.permutation = _zigzag_permutation(kernel_size, kernel_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, c, h, w = x.shape + x = self.unfold(x).transpose(-1, -2) + x = x.reshape(b, h // self.k, w // self.k, c, self.k, self.k) + x = self.transform(x) + x = x.reshape(-1, c, self.k * self.k) + x = x[:, :, self.permutation] + x = x.reshape(b * (h // self.k) * (w // self.k), c, -1) + + mean_list = torch.zeros([3, 64]) + var_list = torch.zeros([3, 64]) + mean_list[0] = torch.mean(x[:, 0, ], dim=0) + mean_list[1] = torch.mean(x[:, 1, ], dim=0) + mean_list[2] = torch.mean(x[:, 2, ], dim=0) + var_list[0] = torch.var(x[:, 0, ], dim=0) + var_list[1] = torch.var(x[:, 1, ], dim=0) + var_list[2] = torch.var(x[:, 2, ], dim=0) + return mean_list, var_list + + +class Block(nn.Module): + def __init__(self, dim, drop_path=0.): + super().__init__() + self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) + self.norm = LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, 4 * dim) + self.act = nn.GELU() + self.grn = GRN(4 * dim) + self.pwconv2 = nn.Linear(4 * dim, dim) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.attention = Spatial_Attention() + + def forward(self, x): + input = x + x = self.dwconv(x) + x = x.permute(0, 2, 3, 1) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.grn(x) + x = self.pwconv2(x) + x = x.permute(0, 3, 1, 2) + + # Spatial Attention logic + attention = self.attention(x) + + # [Fix] create nn.UpsamplingBilinear2d class -> use F.interpolate function + attention = F.interpolate(attention, size=x.shape[2:], mode='bilinear', align_corners=False) + + x = x * attention + + x = input + self.drop_path(x) + return x + + +class Spatial_Attention(nn.Module): + def __init__(self): + super().__init__() + self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) + self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3) + # downsample=False, dropout=0. by default + self.attention = TransformerBlock(1, 1, heads=1, dim_head=1, img_size=[7, 7], downsample=False) + + def forward(self, x): + x_avg = x.mean([1]).unsqueeze(1) + x_max = x.max(dim=1).values.unsqueeze(1) + x = torch.cat([x_avg, x_max], dim=1) + x = self.avgpool(x) + x = self.conv(x) + x = self.attention(x) + return x + + +class TransformerBlock(nn.Module): + """ + Refactored TransformerBlock without einops and PreNorm class wrapper. + Manual reshaping is performed in forward(). 
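+
+    For reference, the flatten/unflatten pairs in forward() stand in for the einops
+    patterns below (shape sketch only; einops shown for illustration and not used):
+
+        flatten:   (B, C, H, W) -> x.flatten(2).transpose(1, 2) -> (B, H*W, C)
+                   ~ rearrange(x, 'b c h w -> b (h w) c')
+        unflatten: (B, H*W, C) -> x_t.transpose(1, 2).reshape(B, C, H, W) -> (B, C, H, W)
+                   ~ rearrange(x_t, 'b (h w) c -> b c h w', h=H, w=W)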
+ """ + + def __init__(self, inp, oup, heads=8, dim_head=32, img_size=None, downsample=False, dropout=0.): + super().__init__() + hidden_dim = int(inp * 4) + + self.downsample = downsample + self.ih, self.iw = img_size + + if self.downsample: + self.pool1 = nn.MaxPool2d(3, 2, 1) + self.pool2 = nn.MaxPool2d(3, 2, 1) + self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False) + else: + # [Fix] Prevent JIT compile error + self.pool1 = nn.Identity() + self.pool2 = nn.Identity() + self.proj = nn.Identity() + + # Attention block components + # Note: In old code, PreNorm wrapped Attention. Here we split them. + self.attn_norm = nn.LayerNorm(inp) + self.attn = Attention(inp, oup, heads, dim_head, dropout) + + # FeedForward block components + self.ff_norm = nn.LayerNorm(oup) + self.ff = FeedForward(oup, hidden_dim, dropout) + + def forward(self, x): + # x shape: (B, C, H, W) + if self.downsample: + # Identity path with projection + shortcut = self.proj(self.pool1(x)) + + # Attention path + x_t = self.pool2(x) + B, C, H, W = x_t.shape + + # Flatten spatial: (B, C, H, W) -> (B, H*W, C) + x_t = x_t.flatten(2).transpose(1, 2) + + # PreNorm -> Attention + x_t = self.attn_norm(x_t) + x_t = self.attn(x_t) + + # Unflatten: (B, H*W, C) -> (B, C, H, W) + x_t = x_t.transpose(1, 2).reshape(B, C, H, W) + + x = shortcut + x_t + else: + # Standard PreNorm Residual Attention + B, C, H, W = x.shape + shortcut = x + + # Flatten + x_t = x.flatten(2).transpose(1, 2) + + # PreNorm -> Attention + x_t = self.attn_norm(x_t) + x_t = self.attn(x_t) + + # Unflatten + x_t = x_t.transpose(1, 2).reshape(B, C, H, W) + + x = shortcut + x_t + + # FeedForward Block + B, C, H, W = x.shape + shortcut = x + + # Flatten + x_t = x.flatten(2).transpose(1, 2) + + # PreNorm -> FeedForward + x_t = self.ff_norm(x_t) + x_t = self.ff(x_t) + + # Unflatten + x_t = x_t.transpose(1, 2).reshape(B, C, H, W) + + x = shortcut + x_t + + return x + + +class Attention(nn.Module): + """ + Refactored Attention without einops.rearrange. 
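+
+    For reference, the manual reshapes in forward() stand in for the einops patterns
+    below (shape sketch only; the einops calls are shown for illustration and are not used):
+
+        qkv: (B, N, 3*heads*dim_head) -> reshape/permute -> 3 x (B, heads, N, dim_head)
+             ~ rearrange(qkv, 'b n (three h d) -> three b h n d', three=3, h=heads)
+        out: (B, heads, N, dim_head) -> transpose/reshape -> (B, N, heads*dim_head)
+             ~ rearrange(out, 'b h n d -> b n (h d)')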
+ """ + + def __init__(self, inp, oup, heads=8, dim_head=32, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == inp) + + self.heads = heads + self.dim_head = dim_head + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim=-1) + self.to_qkv = nn.Linear(inp, inner_dim * 3, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, oup), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + self.pos_embed = PosCNN(in_chans=inp) + + def forward(self, x): + # x shape: (B, N, C) + # Positional Embedding (expects B, N, C -> internally converts to spatial) + x = self.pos_embed(x) + + B, N, C = x.shape + + # Generate Q, K, V + # qkv: (B, N, 3 * heads * dim_head) + qkv = self.to_qkv(x) + + # Reshape to (B, N, 3, heads, dim_head) and permute to (3, B, heads, N, dim_head) + qkv = qkv.reshape(B, N, 3, self.heads, self.dim_head).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0], qkv[1], qkv[2] # Each is (B, heads, N, dim_head) + + # Attention Score + # (B, heads, N, dim_head) @ (B, heads, dim_head, N) -> (B, heads, N, N) + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + attn = self.attend(dots) + + # Weighted Sum + # (B, heads, N, N) @ (B, heads, N, dim_head) -> (B, heads, N, dim_head) + out = torch.matmul(attn, v) + + # Rearrange output: (B, heads, N, dim_head) -> (B, N, heads*dim_head) + out = out.transpose(1, 2).reshape(B, N, -1) + + out = self.to_out(out) + return out + + +class CSATv2(nn.Module): + def __init__( + self, + img_size=512, + num_classes=1000, + in_chans=3, + drop_path_rate=0.0, + global_pool='avg', + **kwargs, + ): + super().__init__() + + self.num_classes = num_classes + self.global_pool = global_pool + + if isinstance(img_size, (tuple, list)): + img_size = img_size[0] + + self.img_size = img_size + + dims = [32, 72, 168, 386] + channel_order = "channels_first" + depths = [2, 2, 6, 4] + dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + + self.dct = Learnable_DCT2D(8) + + self.stages1 = nn.Sequential( + Block(dim=dims[0], drop_path=dp_rates[0]), + Block(dim=dims[0], drop_path=dp_rates[1]), + LayerNorm(dims[0], eps=1e-6, data_format=channel_order), + nn.Conv2d(dims[0], dims[0 + 1], kernel_size=2, stride=2), + ) + + self.stages2 = nn.Sequential( + Block(dim=dims[1], drop_path=dp_rates[0]), + Block(dim=dims[1], drop_path=dp_rates[1]), + LayerNorm(dims[1], eps=1e-6, data_format=channel_order), + nn.Conv2d(dims[1], dims[1 + 1], kernel_size=2, stride=2), + ) + + self.stages3 = nn.Sequential( + Block(dim=dims[2], drop_path=dp_rates[0]), + Block(dim=dims[2], drop_path=dp_rates[1]), + Block(dim=dims[2], drop_path=dp_rates[2]), + Block(dim=dims[2], drop_path=dp_rates[3]), + Block(dim=dims[2], drop_path=dp_rates[4]), + Block(dim=dims[2], drop_path=dp_rates[5]), + TransformerBlock( + inp=dims[2], + oup=dims[2], + img_size=[int(self.img_size / 32), int(self.img_size / 32)], + ), + TransformerBlock( + inp=dims[2], + oup=dims[2], + img_size=[int(self.img_size / 32), int(self.img_size / 32)], + ), + LayerNorm(dims[2], eps=1e-6, data_format=channel_order), + nn.Conv2d(dims[2], dims[2 + 1], kernel_size=2, stride=2), + ) + + self.stages4 = nn.Sequential( + Block(dim=dims[3], drop_path=dp_rates[0]), + Block(dim=dims[3], drop_path=dp_rates[1]), + Block(dim=dims[3], drop_path=dp_rates[2]), + Block(dim=dims[3], drop_path=dp_rates[3]), + TransformerBlock( + inp=dims[3], + oup=dims[3], + img_size=[int(self.img_size / 64), int(self.img_size / 64)], + ), + TransformerBlock( + inp=dims[3], + 
oup=dims[3], + img_size=[int(self.img_size / 64), int(self.img_size / 64)], + ), + ) + + self.norm = nn.LayerNorm(dims[-1], eps=1e-6) + + self.head = nn.Linear(dims[-1], num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward_features(self, x): + x = self.dct(x) + x = self.stages1(x) + x = self.stages2(x) + x = self.stages3(x) + x = self.stages4(x) + x = x.mean(dim=(-2, -1)) + x = self.norm(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + if pre_logits: + return x + x = self.head(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +# --- Components like LayerNorm, GRN, DropPath, FeedForward, PosCNN, trunc_normal_ --- +# (이 부분은 einops와 무관하므로 위 코드와 동일하게 유지합니다. 여기서는 공간 절약을 위해 생략) +# 기존 코드의 LayerNorm, GRN, DropPath, FeedForward, PosCNN, trunc_normal_ 함수를 그대로 사용하세요. + +class LayerNorm(nn.Module): + """ LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). + """ + + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.normalized_shape = (normalized_shape,) + + def forward(self, x): + if self.data_format == "channels_last": + return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + else: + # [Fix] change elif -> else + # JIT should know Tensor return path of all cases. + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +class GRN(nn.Module): + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + if drop_prob == 0. 
or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout=0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + + def forward(self, x): + return self.net(x) + + +class PosCNN(nn.Module): + def __init__(self, in_chans): + super(PosCNN, self).__init__() + self.proj = nn.Conv2d(in_chans, in_chans, kernel_size=3, stride=1, padding=1, bias=True, groups=in_chans) + + def forward(self, x): + B, N, C = x.shape + feat_token = x + H, W = int(N ** 0.5), int(N ** 0.5) + cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W) + x = self.proj(cnn_feat) + cnn_feat + x = x.flatten(2).transpose(1, 2) + return x + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + def norm_cdf(x): + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_.", stacklevel=2) + with torch.no_grad(): + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + tensor.uniform_(2 * l - 1, 2 * u - 1) + tensor.erfinv_() + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + tensor.clamp_(min=a, max=b) + return tensor + + +default_cfgs = generate_default_cfgs({ + 'csatv2': { + 'url': 'https://huggingface.co/Hyunil/CSATv2/resolve/main/CSATv2_ImageNet_timm.pth', + 'num_classes': 1000, + 'input_size': (3, 512, 512), + 'mean': (0.485, 0.456, 0.406), + 'std': (0.229, 0.224, 0.225), + 'interpolation': 'bilinear', + 'crop_pct': 1.0, + }, +}) + + +def _create_csatv2(variant: str, pretrained: bool = False, **kwargs) -> CSATv2: + return build_model_with_cfg( + CSATv2, + variant, + pretrained, + default_cfg=default_cfgs[variant], + **kwargs, + ) + + +@register_model +def csatv2(pretrained: bool = False, **kwargs) -> CSATv2: + model_args = dict( + img_size=kwargs.pop('img_size', 512), + num_classes=kwargs.pop('num_classes', 1000), + ) + return _create_csatv2('csatv2', pretrained, **dict(model_args, **kwargs)) From 6462c73fc73dc179a2c0aec84bfe36b44c904721 Mon Sep 17 00:00:00 2001 From: gusdlf93 Date: Mon, 8 Dec 2025 18:20:42 +0900 Subject: [PATCH 06/14] Delete csatv2.py --- csatv2.py | 628 ------------------------------------------------------ 1 file changed, 628 deletions(-) delete mode 100644 csatv2.py diff --git a/csatv2.py b/csatv2.py deleted file mode 100644 index f70c3237f5..0000000000 --- a/csatv2.py +++ /dev/null @@ -1,628 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import warnings -from timm.models._builder import build_model_with_cfg -from ._registry import register_model, generate_default_cfgs -from typing import List -import numpy as np -import math - -__all__ = ['DCT2D'] - -# Helper Functions 
-mean=[[932.42657,-0.00260,0.33415,-0.02840,0.00003,-0.02792,-0.00183,0.00006,0.00032,0.03402,-0.00571,0.00020,0.00006,-0.00038,-0.00558,-0.00116,-0.00000,-0.00047,-0.00008,-0.00030,0.00942,0.00161,-0.00009,-0.00006,-0.00014,-0.00035,0.00001,-0.00220,0.00033,-0.00002,-0.00003,-0.00020,0.00007,-0.00000,0.00005,0.00293,-0.00004,0.00006,0.00019,0.00004,0.00006,-0.00015,-0.00002,0.00007,0.00010,-0.00004,0.00008,0.00000,0.00008,-0.00001,0.00015,0.00002,0.00007,0.00003,0.00004,-0.00001,0.00004,-0.00000,0.00002,-0.00000,-0.00008,-0.00000,-0.00003,0.00003], -[962.34735,-0.00428,0.09835,0.00152,-0.00009,0.00312,-0.00141,-0.00001,-0.00013,0.01050,0.00065,0.00006,-0.00000,0.00003,0.00264,0.00000,0.00001,0.00007,-0.00006,0.00003,0.00341,0.00163,0.00004,0.00003,-0.00001,0.00008,-0.00000,0.00090,0.00018,-0.00006,-0.00001,0.00007,-0.00003,-0.00001,0.00006,0.00084,-0.00000,-0.00001,0.00000,0.00004,-0.00001,-0.00002,0.00000,0.00001,0.00002,0.00001,0.00004,0.00011,0.00000,-0.00003,0.00011,-0.00002,0.00001,0.00001,0.00001,0.00001,-0.00007,-0.00003,0.00001,0.00000,0.00001,0.00002,0.00001,0.00000], -[1053.16101,-0.00213,-0.09207,0.00186,0.00013,0.00034,-0.00119,0.00002,0.00011,-0.00984,0.00046,-0.00007,-0.00001,-0.00005,0.00180,0.00042,0.00002,-0.00010,0.00004,0.00003,-0.00301,0.00125,-0.00002,-0.00003,-0.00001,-0.00001,-0.00001,0.00056,0.00021,0.00001,-0.00001,0.00002,-0.00001,-0.00001,0.00005,-0.00070,-0.00002,-0.00002,0.00005,-0.00004,-0.00000,0.00002,-0.00002,0.00001,0.00000,-0.00003,0.00004,0.00007,0.00001,0.00000,0.00013,-0.00000,0.00000,0.00002,-0.00000,-0.00001,-0.00004,-0.00003,0.00000,0.00001,-0.00001,0.00001,-0.00000,0.00000]] - -var=[[270372.37500,6287.10645,5974.94043,1653.10889,1463.91748,1832.58997,755.92468,692.41528,648.57184,641.46881,285.79288,301.62100,380.43405,349.84027,374.15891,190.30960,190.76746,221.64578,200.82646,145.87979,126.92046,62.14622,67.75562,102.42001,129.74922,130.04631,103.12189,97.76417,53.17402,54.81048,73.48712,81.04342,69.35100,49.06024,33.96053,37.03279,20.48858,24.94830,33.90822,44.54912,47.56363,40.03160,30.43313,22.63899,26.53739,26.57114,21.84404,17.41557,15.18253,10.69678,11.24111,12.97229,15.08971,15.31646,8.90409,7.44213,6.66096,6.97719,4.17834,3.83882,4.51073,2.36646,2.41363,1.48266], -[18839.21094,321.70932,300.15259,77.47830,76.02293,89.04748,33.99642,34.74807,32.12333,28.19588,12.04675,14.26871,18.45779,16.59588,15.67892,7.37718,8.56312,10.28946,9.41013,6.69090,5.16453,2.55186,3.03073,4.66765,5.85418,5.74644,4.33702,3.66948,1.95107,2.26034,3.06380,3.50705,3.06359,2.19284,1.54454,1.57860,0.97078,1.13941,1.48653,1.89996,1.95544,1.64950,1.24754,0.93677,1.09267,1.09516,0.94163,0.78966,0.72489,0.50841,0.50909,0.55664,0.63111,0.64125,0.38847,0.33378,0.30918,0.33463,0.20875,0.19298,0.21903,0.13380,0.13444,0.09554], -[17127.39844,292.81421,271.45209,66.64056,63.60253,76.35437,28.06587,27.84831,25.96656,23.60370,9.99173,11.34992,14.46955,12.92553,12.69353,5.91537,6.60187,7.90891,7.32825,5.32785,4.29660,2.13459,2.44135,3.66021,4.50335,4.38959,3.34888,2.97181,1.60633,1.77010,2.35118,2.69018,2.38189,1.74596,1.26014,1.31684,0.79327,0.92046,1.17670,1.47609,1.50914,1.28725,0.99898,0.74832,0.85736,0.85800,0.74663,0.63508,0.58748,0.41098,0.41121,0.44663,0.50277,0.51519,0.31729,0.27336,0.25399,0.27241,0.17353,0.16255,0.18440,0.11602,0.11511,0.08450]] - - -def _zigzag_permutation(rows: int, cols: int) -> List[int]: - idx_matrix = np.arange(0, rows * cols, 1).reshape(rows, cols).tolist() - dia = [[] for _ in range(rows + cols - 1)] - zigzag = [] - for i in range(rows): - for 
j in range(cols): - s = i + j - if s % 2 == 0: - dia[s].insert(0, idx_matrix[i][j]) - else: - dia[s].append(idx_matrix[i][j]) - for d in dia: - zigzag.extend(d) - return zigzag - - -def _dct_kernel_type_2(kernel_size: int, orthonormal: bool, device=None, dtype=None) -> torch.Tensor: - factory_kwargs = dict(device=device, dtype=dtype) - x = torch.eye(kernel_size, **factory_kwargs) - v = x.clone().contiguous().view(-1, kernel_size) - v = torch.cat([v, v.flip([1])], dim=-1) - v = torch.fft.fft(v, dim=-1)[:, :kernel_size] - try: - k = torch.tensor(-1j, **factory_kwargs) * torch.pi * torch.arange(kernel_size, **factory_kwargs)[None, :] - except: - k = torch.tensor(-1j, **factory_kwargs) * math.pi * torch.arange(kernel_size, **factory_kwargs)[None, :] - k = torch.exp(k / (kernel_size * 2)) - v = v * k - v = v.real - if orthonormal: - v[:, 0] = v[:, 0] * torch.sqrt(torch.tensor(1 / (kernel_size * 4), **factory_kwargs)) - v[:, 1:] = v[:, 1:] * torch.sqrt(torch.tensor(1 / (kernel_size * 2), **factory_kwargs)) - v = v.contiguous().view(*x.shape) - return v - - -def _dct_kernel_type_3(kernel_size: int, orthonormal: bool, device=None, dtype=None) -> torch.Tensor: - return torch.linalg.inv(_dct_kernel_type_2(kernel_size, orthonormal, device, dtype)) - - -class _DCT1D(nn.Module): - def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = True, - device=None, dtype=None) -> None: - factory_kwargs = dict(device=device, dtype=dtype) - super(_DCT1D, self).__init__() - kernel = {'2': _dct_kernel_type_2, '3': _dct_kernel_type_3} - self.weights = nn.Parameter(kernel[f'{kernel_type}'](kernel_size, orthonormal, **factory_kwargs).T, False) - self.register_parameter('bias', None) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return nn.functional.linear(x, self.weights, self.bias) - - -class _DCT2D(nn.Module): - def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = True, - device=None, dtype=None) -> None: - factory_kwargs = dict(device=device, dtype=dtype) - super(_DCT2D, self).__init__() - self.transform = _DCT1D(kernel_size, kernel_type, orthonormal, **factory_kwargs) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.transform(self.transform(x).transpose(-1, -2)).transpose(-1, -2) - - -class Learnable_DCT2D(nn.Module): - def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = True, - device=None, dtype=None) -> None: - factory_kwargs = dict(device=device, dtype=dtype) - super(Learnable_DCT2D, self).__init__() - self.k = kernel_size - self.unfold = nn.Unfold(kernel_size=(kernel_size, kernel_size), stride=(kernel_size, kernel_size)) - self.transform = _DCT2D(kernel_size, kernel_type, orthonormal, **factory_kwargs) - self.permutation = _zigzag_permutation(kernel_size, kernel_size) - self.Y_Conv = nn.Conv2d(kernel_size ** 2, 24, kernel_size=1, padding=0) - self.Cb_Conv = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0) - self.Cr_Conv = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0) - self.mean = torch.tensor(mean, requires_grad=False) - self.var = torch.tensor(var, requires_grad=False) - self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406], requires_grad=False) - self.imagenet_var = torch.tensor([0.229, 0.224, 0.225], requires_grad=False) - - def denormalize(self, x): - x = x.multiply(self.imagenet_var.to(x.device)).add_(self.imagenet_mean.to(x.device)) * 255 - return x - - def rgb2ycbcr(self, x): - y = (x[:, :, :, 0] * 0.299) + (x[:, :, :, 1] * 0.587) + (x[:, :, :, 2] * 0.114) - cb = 0.564 * 
(x[:, :, :, 2] - y) + 128 - cr = 0.713 * (x[:, :, :, 0] - y) + 128 - x = torch.stack([y, cb, cr], dim=-1) - return x - - def frequncy_normalize(self, x): - x[:, 0, ].sub_(self.mean.to(x.device)[0]).div_((self.var.to(x.device)[0] ** 0.5 + 1e-8)) - x[:, 1, ].sub_(self.mean.to(x.device)[1]).div_((self.var.to(x.device)[1] ** 0.5 + 1e-8)) - x[:, 2, ].sub_(self.mean.to(x.device)[2]).div_((self.var.to(x.device)[2] ** 0.5 + 1e-8)) - return x - - def forward(self, x: torch.Tensor) -> torch.Tensor: - b, c, h, w = x.shape - x = x.permute(0, 2, 3, 1) - x = self.denormalize(x) - x = self.rgb2ycbcr(x) - x = x.permute(0, 3, 1, 2) - x = self.unfold(x).transpose(-1, -2) - x = x.reshape(b, h // self.k, w // self.k, c, self.k, self.k) - x = self.transform(x) - x = x.reshape(-1, c, self.k * self.k) - x = x[:, :, self.permutation] - x = self.frequncy_normalize(x) - x = x.reshape(b, h // self.k, w // self.k, c, -1) - x = x.permute(0, 3, 4, 1, 2).contiguous() - x_Y = self.Y_Conv(x[:, 0, ]) - x_Cb = self.Cb_Conv(x[:, 1, ]) - x_Cr = self.Cr_Conv(x[:, 2, ]) - x = torch.cat([x_Y, x_Cb, x_Cr], dim=1) - return x - - -class DCT2D(nn.Module): - def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = True, - device=None, dtype=None) -> None: - factory_kwargs = dict(device=device, dtype=dtype) - super(DCT2D, self).__init__() - self.k = kernel_size - self.unfold = nn.Unfold(kernel_size=(kernel_size, kernel_size), stride=(kernel_size, kernel_size)) - self.transform = _DCT2D(kernel_size, kernel_type, orthonormal, **factory_kwargs) - self.permutation = _zigzag_permutation(kernel_size, kernel_size) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - b, c, h, w = x.shape - x = self.unfold(x).transpose(-1, -2) - x = x.reshape(b, h // self.k, w // self.k, c, self.k, self.k) - x = self.transform(x) - x = x.reshape(-1, c, self.k * self.k) - x = x[:, :, self.permutation] - x = x.reshape(b * (h // self.k) * (w // self.k), c, -1) - - mean_list = torch.zeros([3, 64]) - var_list = torch.zeros([3, 64]) - mean_list[0] = torch.mean(x[:, 0, ], dim=0) - mean_list[1] = torch.mean(x[:, 1, ], dim=0) - mean_list[2] = torch.mean(x[:, 2, ], dim=0) - var_list[0] = torch.var(x[:, 0, ], dim=0) - var_list[1] = torch.var(x[:, 1, ], dim=0) - var_list[2] = torch.var(x[:, 2, ], dim=0) - return mean_list, var_list - - -class Block(nn.Module): - def __init__(self, dim, drop_path=0.): - super().__init__() - self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) - self.norm = LayerNorm(dim, eps=1e-6) - self.pwconv1 = nn.Linear(dim, 4 * dim) - self.act = nn.GELU() - self.grn = GRN(4 * dim) - self.pwconv2 = nn.Linear(4 * dim, dim) - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.attention = Spatial_Attention() - - def forward(self, x): - input = x - x = self.dwconv(x) - x = x.permute(0, 2, 3, 1) - x = self.norm(x) - x = self.pwconv1(x) - x = self.act(x) - x = self.grn(x) - x = self.pwconv2(x) - x = x.permute(0, 3, 1, 2) - - # Spatial Attention logic - attention = self.attention(x) - - # [Fix] create nn.UpsamplingBilinear2d class -> use F.interpolate function - attention = F.interpolate(attention, size=x.shape[2:], mode='bilinear', align_corners=False) - - x = x * attention - - x = input + self.drop_path(x) - return x - - -class Spatial_Attention(nn.Module): - def __init__(self): - super().__init__() - self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) - self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3) - # downsample=False, dropout=0. 
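# Editor's sketch (not part of the patch): expected stem shapes. The 8x8 block DCT
# shrinks the spatial grid by 8x, and the three 1x1 convs keep 24 + 4 + 4 = 32
# frequency channels for Y/Cb/Cr, so a 512x512 image becomes a 32 x 64 x 64 map.
import torch
stem = Learnable_DCT2D(8)                  # uses the class defined above
x = torch.randn(2, 3, 512, 512)            # ImageNet-normalized RGB batch
with torch.no_grad():
    y = stem(x)
print(y.shape)                             # torch.Size([2, 32, 64, 64])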
by default - self.attention = TransformerBlock(1, 1, heads=1, dim_head=1, img_size=[7, 7], downsample=False) - - def forward(self, x): - x_avg = x.mean([1]).unsqueeze(1) - x_max = x.max(dim=1).values.unsqueeze(1) - x = torch.cat([x_avg, x_max], dim=1) - x = self.avgpool(x) - x = self.conv(x) - x = self.attention(x) - return x - - -class TransformerBlock(nn.Module): - """ - Refactored TransformerBlock without einops and PreNorm class wrapper. - Manual reshaping is performed in forward(). - """ - - def __init__(self, inp, oup, heads=8, dim_head=32, img_size=None, downsample=False, dropout=0.): - super().__init__() - hidden_dim = int(inp * 4) - - self.downsample = downsample - self.ih, self.iw = img_size - - if self.downsample: - self.pool1 = nn.MaxPool2d(3, 2, 1) - self.pool2 = nn.MaxPool2d(3, 2, 1) - self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False) - else: - # [Fix] Prevent JIT compile error - self.pool1 = nn.Identity() - self.pool2 = nn.Identity() - self.proj = nn.Identity() - - # Attention block components - # Note: In old code, PreNorm wrapped Attention. Here we split them. - self.attn_norm = nn.LayerNorm(inp) - self.attn = Attention(inp, oup, heads, dim_head, dropout) - - # FeedForward block components - self.ff_norm = nn.LayerNorm(oup) - self.ff = FeedForward(oup, hidden_dim, dropout) - - def forward(self, x): - # x shape: (B, C, H, W) - if self.downsample: - # Identity path with projection - shortcut = self.proj(self.pool1(x)) - - # Attention path - x_t = self.pool2(x) - B, C, H, W = x_t.shape - - # Flatten spatial: (B, C, H, W) -> (B, H*W, C) - x_t = x_t.flatten(2).transpose(1, 2) - - # PreNorm -> Attention - x_t = self.attn_norm(x_t) - x_t = self.attn(x_t) - - # Unflatten: (B, H*W, C) -> (B, C, H, W) - x_t = x_t.transpose(1, 2).reshape(B, C, H, W) - - x = shortcut + x_t - else: - # Standard PreNorm Residual Attention - B, C, H, W = x.shape - shortcut = x - - # Flatten - x_t = x.flatten(2).transpose(1, 2) - - # PreNorm -> Attention - x_t = self.attn_norm(x_t) - x_t = self.attn(x_t) - - # Unflatten - x_t = x_t.transpose(1, 2).reshape(B, C, H, W) - - x = shortcut + x_t - - # FeedForward Block - B, C, H, W = x.shape - shortcut = x - - # Flatten - x_t = x.flatten(2).transpose(1, 2) - - # PreNorm -> FeedForward - x_t = self.ff_norm(x_t) - x_t = self.ff(x_t) - - # Unflatten - x_t = x_t.transpose(1, 2).reshape(B, C, H, W) - - x = shortcut + x_t - - return x - - -class Attention(nn.Module): - """ - Refactored Attention without einops.rearrange. 
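# Editor's sketch (not part of the patch): the manual reshapes in TransformerBlock
# above are the einops-free equivalent of rearrange(x, 'b c h w -> b (h w) c') and
# its inverse, which is why the block no longer needs the einops dependency.
import torch
x = torch.randn(2, 16, 7, 7)
tokens = x.flatten(2).transpose(1, 2)                   # (2, 49, 16)
restored = tokens.transpose(1, 2).reshape(2, 16, 7, 7)
assert torch.equal(restored, x)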
- """ - - def __init__(self, inp, oup, heads=8, dim_head=32, dropout=0.): - super().__init__() - inner_dim = dim_head * heads - project_out = not (heads == 1 and dim_head == inp) - - self.heads = heads - self.dim_head = dim_head - self.scale = dim_head ** -0.5 - - self.attend = nn.Softmax(dim=-1) - self.to_qkv = nn.Linear(inp, inner_dim * 3, bias=False) - - self.to_out = nn.Sequential( - nn.Linear(inner_dim, oup), - nn.Dropout(dropout) - ) if project_out else nn.Identity() - self.pos_embed = PosCNN(in_chans=inp) - - def forward(self, x): - # x shape: (B, N, C) - # Positional Embedding (expects B, N, C -> internally converts to spatial) - x = self.pos_embed(x) - - B, N, C = x.shape - - # Generate Q, K, V - # qkv: (B, N, 3 * heads * dim_head) - qkv = self.to_qkv(x) - - # Reshape to (B, N, 3, heads, dim_head) and permute to (3, B, heads, N, dim_head) - qkv = qkv.reshape(B, N, 3, self.heads, self.dim_head).permute(2, 0, 3, 1, 4) - - q, k, v = qkv[0], qkv[1], qkv[2] # Each is (B, heads, N, dim_head) - - # Attention Score - # (B, heads, N, dim_head) @ (B, heads, dim_head, N) -> (B, heads, N, N) - dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale - attn = self.attend(dots) - - # Weighted Sum - # (B, heads, N, N) @ (B, heads, N, dim_head) -> (B, heads, N, dim_head) - out = torch.matmul(attn, v) - - # Rearrange output: (B, heads, N, dim_head) -> (B, N, heads*dim_head) - out = out.transpose(1, 2).reshape(B, N, -1) - - out = self.to_out(out) - return out - - -class CSATv2(nn.Module): - def __init__( - self, - img_size=512, - num_classes=1000, - in_chans=3, - drop_path_rate=0.0, - global_pool='avg', - **kwargs, - ): - super().__init__() - - self.num_classes = num_classes - self.global_pool = global_pool - - if isinstance(img_size, (tuple, list)): - img_size = img_size[0] - - self.img_size = img_size - - dims = [32, 72, 168, 386] - channel_order = "channels_first" - depths = [2, 2, 6, 4] - dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] - - self.dct = Learnable_DCT2D(8) - - self.stages1 = nn.Sequential( - Block(dim=dims[0], drop_path=dp_rates[0]), - Block(dim=dims[0], drop_path=dp_rates[1]), - LayerNorm(dims[0], eps=1e-6, data_format=channel_order), - nn.Conv2d(dims[0], dims[0 + 1], kernel_size=2, stride=2), - ) - - self.stages2 = nn.Sequential( - Block(dim=dims[1], drop_path=dp_rates[0]), - Block(dim=dims[1], drop_path=dp_rates[1]), - LayerNorm(dims[1], eps=1e-6, data_format=channel_order), - nn.Conv2d(dims[1], dims[1 + 1], kernel_size=2, stride=2), - ) - - self.stages3 = nn.Sequential( - Block(dim=dims[2], drop_path=dp_rates[0]), - Block(dim=dims[2], drop_path=dp_rates[1]), - Block(dim=dims[2], drop_path=dp_rates[2]), - Block(dim=dims[2], drop_path=dp_rates[3]), - Block(dim=dims[2], drop_path=dp_rates[4]), - Block(dim=dims[2], drop_path=dp_rates[5]), - TransformerBlock( - inp=dims[2], - oup=dims[2], - img_size=[int(self.img_size / 32), int(self.img_size / 32)], - ), - TransformerBlock( - inp=dims[2], - oup=dims[2], - img_size=[int(self.img_size / 32), int(self.img_size / 32)], - ), - LayerNorm(dims[2], eps=1e-6, data_format=channel_order), - nn.Conv2d(dims[2], dims[2 + 1], kernel_size=2, stride=2), - ) - - self.stages4 = nn.Sequential( - Block(dim=dims[3], drop_path=dp_rates[0]), - Block(dim=dims[3], drop_path=dp_rates[1]), - Block(dim=dims[3], drop_path=dp_rates[2]), - Block(dim=dims[3], drop_path=dp_rates[3]), - TransformerBlock( - inp=dims[3], - oup=dims[3], - img_size=[int(self.img_size / 64), int(self.img_size / 64)], - ), - TransformerBlock( - inp=dims[3], - 
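# Editor's sketch (not part of the patch): how the fused qkv projection above is
# split into per-head tensors without einops (the sizes here are illustrative only).
import torch
B, N, heads, dim_head = 2, 49, 8, 32
qkv = torch.randn(B, N, 3 * heads * dim_head)
qkv = qkv.reshape(B, N, 3, heads, dim_head).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]                        # each (B, heads, N, dim_head)
attn = (q @ k.transpose(-1, -2) * dim_head ** -0.5).softmax(dim=-1)
out = (attn @ v).transpose(1, 2).reshape(B, N, -1)      # back to (B, N, heads*dim_head)
print(out.shape)                                        # torch.Size([2, 49, 256])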
oup=dims[3], - img_size=[int(self.img_size / 64), int(self.img_size / 64)], - ), - ) - - self.norm = nn.LayerNorm(dims[-1], eps=1e-6) - - self.head = nn.Linear(dims[-1], num_classes) if num_classes > 0 else nn.Identity() - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, (nn.Conv2d, nn.Linear)): - trunc_normal_(m.weight, std=.02) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - - def forward_features(self, x): - x = self.dct(x) - x = self.stages1(x) - x = self.stages2(x) - x = self.stages3(x) - x = self.stages4(x) - x = x.mean(dim=(-2, -1)) - x = self.norm(x) - return x - - def forward_head(self, x, pre_logits: bool = False): - if pre_logits: - return x - x = self.head(x) - return x - - def forward(self, x): - x = self.forward_features(x) - x = self.forward_head(x) - return x - - -# --- Components like LayerNorm, GRN, DropPath, FeedForward, PosCNN, trunc_normal_ --- -# (이 부분은 einops와 무관하므로 위 코드와 동일하게 유지합니다. 여기서는 공간 절약을 위해 생략) -# 기존 코드의 LayerNorm, GRN, DropPath, FeedForward, PosCNN, trunc_normal_ 함수를 그대로 사용하세요. - -class LayerNorm(nn.Module): - """ LayerNorm that supports two data formats: channels_last (default) or channels_first. - The ordering of the dimensions in the inputs. channels_last corresponds to inputs with - shape (batch_size, height, width, channels) while channels_first corresponds to inputs - with shape (batch_size, channels, height, width). - """ - - def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): - super().__init__() - self.weight = nn.Parameter(torch.ones(normalized_shape)) - self.bias = nn.Parameter(torch.zeros(normalized_shape)) - self.eps = eps - self.data_format = data_format - if self.data_format not in ["channels_last", "channels_first"]: - raise NotImplementedError - self.normalized_shape = (normalized_shape,) - - def forward(self, x): - if self.data_format == "channels_last": - return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) - else: - # [Fix] change elif -> else - # JIT should know Tensor return path of all cases. - u = x.mean(1, keepdim=True) - s = (x - u).pow(2).mean(1, keepdim=True) - x = (x - u) / torch.sqrt(s + self.eps) - x = self.weight[:, None, None] * x + self.bias[:, None, None] - return x - - -class GRN(nn.Module): - def __init__(self, dim): - super().__init__() - self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) - self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) - - def forward(self, x): - Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) - Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) - return self.gamma * (x * Nx) + self.beta + x - - -def drop_path(x, drop_prob: float = 0., training: bool = False): - if drop_prob == 0. 
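# Editor's sketch (not part of the patch): because gamma and beta start at zero, the
# GRN defined above is an exact identity at initialization; the response normalization
# only takes effect once those parameters move during training.
import torch
grn = GRN(64)                        # channels-last GRN from above
x = torch.randn(2, 7, 7, 64)
assert torch.allclose(grn(x), x)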
or not training: - return x - keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) - random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) - random_tensor.floor_() - output = x.div(keep_prob) * random_tensor - return output - - -class DropPath(nn.Module): - def __init__(self, drop_prob=None): - super(DropPath, self).__init__() - self.drop_prob = drop_prob - - def forward(self, x): - return drop_path(x, self.drop_prob, self.training) - - -class FeedForward(nn.Module): - def __init__(self, dim, hidden_dim, dropout=0.): - super().__init__() - self.net = nn.Sequential( - nn.Linear(dim, hidden_dim), - nn.GELU(), - nn.Dropout(dropout), - nn.Linear(hidden_dim, dim), - nn.Dropout(dropout) - ) - - def forward(self, x): - return self.net(x) - - -class PosCNN(nn.Module): - def __init__(self, in_chans): - super(PosCNN, self).__init__() - self.proj = nn.Conv2d(in_chans, in_chans, kernel_size=3, stride=1, padding=1, bias=True, groups=in_chans) - - def forward(self, x): - B, N, C = x.shape - feat_token = x - H, W = int(N ** 0.5), int(N ** 0.5) - cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W) - x = self.proj(cnn_feat) + cnn_feat - x = x.flatten(2).transpose(1, 2) - return x - - -def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): - return _no_grad_trunc_normal_(tensor, mean, std, a, b) - - -def _no_grad_trunc_normal_(tensor, mean, std, a, b): - def norm_cdf(x): - return (1. + math.erf(x / math.sqrt(2.))) / 2. - - if (mean < a - 2 * std) or (mean > b + 2 * std): - warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_.", stacklevel=2) - with torch.no_grad(): - l = norm_cdf((a - mean) / std) - u = norm_cdf((b - mean) / std) - tensor.uniform_(2 * l - 1, 2 * u - 1) - tensor.erfinv_() - tensor.mul_(std * math.sqrt(2.)) - tensor.add_(mean) - tensor.clamp_(min=a, max=b) - return tensor - - -default_cfgs = generate_default_cfgs({ - 'csatv2': { - 'url': 'https://huggingface.co/Hyunil/CSATv2/resolve/main/CSATv2_ImageNet_timm.pth', - 'num_classes': 1000, - 'input_size': (3, 512, 512), - 'mean': (0.485, 0.456, 0.406), - 'std': (0.229, 0.224, 0.225), - 'interpolation': 'bilinear', - 'crop_pct': 1.0, - }, -}) - - -def _create_csatv2(variant: str, pretrained: bool = False, **kwargs) -> CSATv2: - return build_model_with_cfg( - CSATv2, - variant, - pretrained, - default_cfg=default_cfgs[variant], - **kwargs, - ) - - -@register_model -def csatv2(pretrained: bool = False, **kwargs) -> CSATv2: - model_args = dict( - img_size=kwargs.pop('img_size', 512), - num_classes=kwargs.pop('num_classes', 1000), - ) - return _create_csatv2('csatv2', pretrained, **dict(model_args, **kwargs)) From 9e40d5ac8077e0c7771b07be6b14a277cd1fe570 Mon Sep 17 00:00:00 2001 From: gusdlf93 Date: Mon, 8 Dec 2025 18:20:54 +0900 Subject: [PATCH 07/14] Add files via upload --- timm/models/csatv2.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/timm/models/csatv2.py b/timm/models/csatv2.py index ea74e91f77..f70c3237f5 100644 --- a/timm/models/csatv2.py +++ b/timm/models/csatv2.py @@ -195,8 +195,7 @@ def forward(self, x): # Spatial Attention logic attention = self.attention(x) - # [Fix] nn.UpsamplingBilinear2d 클래스 생성 -> F.interpolate 함수 사용 - # align_corners=False가 최신 기본값에 가깝습니다. 
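# Editor's sketch (not part of the patch): end-to-end smoke test of the model as
# defined in this file (assumes the module is imported as-is); the registered
# 'csatv2' variant is intended to be reachable via timm.create_model('csatv2').
import torch
model = CSATv2(img_size=512, num_classes=1000)
x = torch.randn(1, 3, 512, 512)
with torch.no_grad():
    logits = model(x)
print(logits.shape)                  # torch.Size([1, 1000])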
(성능 차이는 미미함) + # [Fix] create nn.UpsamplingBilinear2d class -> use F.interpolate function attention = F.interpolate(attention, size=x.shape[2:], mode='bilinear', align_corners=False) x = x * attention @@ -241,7 +240,7 @@ def __init__(self, inp, oup, heads=8, dim_head=32, img_size=None, downsample=Fal self.pool2 = nn.MaxPool2d(3, 2, 1) self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False) else: - # [Fix] JIT 컴파일 에러 방지: 사용하지 않더라도 속성을 정의해야 함 + # [Fix] Prevent JIT compile error self.pool1 = nn.Identity() self.pool2 = nn.Identity() self.proj = nn.Identity() @@ -505,8 +504,8 @@ def forward(self, x): if self.data_format == "channels_last": return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) else: - # [Fix] elif -> else로 변경 - # JIT이 "모든 경로에서 Tensor가 반환됨"을 알 수 있게 함 + # [Fix] change elif -> else + # JIT should know Tensor return path of all cases. u = x.mean(1, keepdim=True) s = (x - u).pow(2).mean(1, keepdim=True) x = (x - u) / torch.sqrt(s + self.eps) From 8d5e51edef0afbd74079c458d844852d967911f9 Mon Sep 17 00:00:00 2001 From: gusdlf93 Date: Wed, 10 Dec 2025 00:32:45 +0900 Subject: [PATCH 08/14] Pass the Test --- timm/models/csatv2.py | 102 ++++++++++++++++++++++++++---------------- 1 file changed, 64 insertions(+), 38 deletions(-) diff --git a/timm/models/csatv2.py b/timm/models/csatv2.py index f70c3237f5..ba2dcc2f3c 100644 --- a/timm/models/csatv2.py +++ b/timm/models/csatv2.py @@ -66,13 +66,16 @@ def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = T factory_kwargs = dict(device=device, dtype=dtype) super(_DCT1D, self).__init__() kernel = {'2': _dct_kernel_type_2, '3': _dct_kernel_type_3} - self.weights = nn.Parameter(kernel[f'{kernel_type}'](kernel_size, orthonormal, **factory_kwargs).T, False) + dct_weights = kernel[f'{kernel_type}'](kernel_size, orthonormal, **factory_kwargs).T + self.register_buffer('weights', dct_weights) + self.register_parameter('bias', None) def forward(self, x: torch.Tensor) -> torch.Tensor: return nn.functional.linear(x, self.weights, self.bias) + class _DCT2D(nn.Module): def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = True, device=None, dtype=None) -> None: @@ -96,6 +99,7 @@ def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = T self.Y_Conv = nn.Conv2d(kernel_size ** 2, 24, kernel_size=1, padding=0) self.Cb_Conv = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0) self.Cr_Conv = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0) + self.mean = torch.tensor(mean, requires_grad=False) self.var = torch.tensor(var, requires_grad=False) self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406], requires_grad=False) @@ -113,9 +117,14 @@ def rgb2ycbcr(self, x): return x def frequncy_normalize(self, x): - x[:, 0, ].sub_(self.mean.to(x.device)[0]).div_((self.var.to(x.device)[0] ** 0.5 + 1e-8)) - x[:, 1, ].sub_(self.mean.to(x.device)[1]).div_((self.var.to(x.device)[1] ** 0.5 + 1e-8)) - x[:, 2, ].sub_(self.mean.to(x.device)[2]).div_((self.var.to(x.device)[2] ** 0.5 + 1e-8)) + + mean_tensor = self.mean.to(x.device) + var_tensor = self.var.to(x.device) + + std = var_tensor ** 0.5 + 1e-8 + + x = (x - mean_tensor) / std + return x def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -181,11 +190,13 @@ def __init__(self, dim, drop_path=0.): self.drop_path = DropPath(drop_path) if drop_path > 0. 
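# Editor's sketch (not part of the patch): effect of the register_buffer change above.
# A buffer follows .to()/.cuda() and is saved in state_dict, but it is not a trainable
# parameter, which is the right fit for a fixed DCT matrix.
import torch
import torch.nn as nn

class FrozenWeight(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('weights', torch.eye(8))

m = FrozenWeight()
assert len(list(m.parameters())) == 0     # nothing handed to the optimizer
assert 'weights' in m.state_dict()        # still serialized with the model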
else nn.Identity() self.attention = Spatial_Attention() - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: input = x x = self.dwconv(x) x = x.permute(0, 2, 3, 1) + x = self.norm(x) + x = self.pwconv1(x) x = self.act(x) x = self.grn(x) @@ -194,11 +205,9 @@ def forward(self, x): # Spatial Attention logic attention = self.attention(x) - - # [Fix] create nn.UpsamplingBilinear2d class -> use F.interpolate function - attention = F.interpolate(attention, size=x.shape[2:], mode='bilinear', align_corners=False) - - x = x * attention + # x = x * nn.UpsamplingBilinear2d(size=x.shape[2:])(attention) + up_attn = F.interpolate(attention, size=x.shape[2:], mode='bilinear', align_corners=True) + x = x * up_attn x = input + self.drop_path(x) return x @@ -223,11 +232,6 @@ def forward(self, x): class TransformerBlock(nn.Module): - """ - Refactored TransformerBlock without einops and PreNorm class wrapper. - Manual reshaping is performed in forward(). - """ - def __init__(self, inp, oup, heads=8, dim_head=32, img_size=None, downsample=False, dropout=0.): super().__init__() hidden_dim = int(inp * 4) @@ -240,15 +244,16 @@ def __init__(self, inp, oup, heads=8, dim_head=32, img_size=None, downsample=Fal self.pool2 = nn.MaxPool2d(3, 2, 1) self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False) else: - # [Fix] Prevent JIT compile error + # [change] for use TorchScript all variable need to declare for Identity self.pool1 = nn.Identity() self.pool2 = nn.Identity() self.proj = nn.Identity() # Attention block components - # Note: In old code, PreNorm wrapped Attention. Here we split them. self.attn_norm = nn.LayerNorm(inp) - self.attn = Attention(inp, oup, heads, dim_head, dropout) + # self.attn = Attention(inp, oup, heads, dim_head, dropout) + self.attn = Attention(inp, oup, heads, dim_head, dropout, img_size=img_size) + # FeedForward block components self.ff_norm = nn.LayerNorm(oup) @@ -316,7 +321,7 @@ class Attention(nn.Module): Refactored Attention without einops.rearrange. """ - def __init__(self, inp, oup, heads=8, dim_head=32, dropout=0.): + def __init__(self, inp, oup, heads=8, dim_head=32, dropout=0., img_size=None): super().__init__() inner_dim = dim_head * heads project_out = not (heads == 1 and dim_head == inp) @@ -332,7 +337,8 @@ def __init__(self, inp, oup, heads=8, dim_head=32, dropout=0.): nn.Linear(inner_dim, oup), nn.Dropout(dropout) ) if project_out else nn.Identity() - self.pos_embed = PosCNN(in_chans=inp) + # self.pos_embed = PosCNN(in_chans=inp) + self.pos_embed = PosCNN(in_chans=inp, img_size=img_size) def forward(self, x): # x shape: (B, N, C) @@ -387,6 +393,16 @@ def __init__( self.img_size = img_size dims = [32, 72, 168, 386] + self.num_features = dims[-1] + self.head_hidden_size = self.num_features + + self.feature_info = [ + dict(num_chs=32, reduction=8, module='dct'), + dict(num_chs=dims[1], reduction=16, module='stages1'), + dict(num_chs=dims[2], reduction=32, module='stages2'), + dict(num_chs=dims[3], reduction=64, module='stages3'), + dict(num_chs=dims[3], reduction=64, module='stages4'), + ] channel_order = "channels_first" depths = [2, 2, 6, 4] dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] @@ -479,17 +495,7 @@ def forward(self, x): return x -# --- Components like LayerNorm, GRN, DropPath, FeedForward, PosCNN, trunc_normal_ --- -# (이 부분은 einops와 무관하므로 위 코드와 동일하게 유지합니다. 여기서는 공간 절약을 위해 생략) -# 기존 코드의 LayerNorm, GRN, DropPath, FeedForward, PosCNN, trunc_normal_ 함수를 그대로 사용하세요. 
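# Editor's sketch (not part of the patch): why the hunk above assigns nn.Identity()
# in the non-downsample branch. TorchScript resolves every attribute referenced in
# forward(), so the attribute has to exist on both configuration paths.
import torch
import torch.nn as nn

class PoolMaybe(nn.Module):
    def __init__(self, downsample: bool):
        super().__init__()
        self.pool = nn.MaxPool2d(3, 2, 1) if downsample else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.pool(x)

scripted = torch.jit.script(PoolMaybe(False))
print(scripted(torch.randn(1, 8, 16, 16)).shape)   # torch.Size([1, 8, 16, 16])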
- class LayerNorm(nn.Module): - """ LayerNorm that supports two data formats: channels_last (default) or channels_first. - The ordering of the dimensions in the inputs. channels_last corresponds to inputs with - shape (batch_size, height, width, channels) while channels_first corresponds to inputs - with shape (batch_size, channels, height, width). - """ - def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): super().__init__() self.weight = nn.Parameter(torch.ones(normalized_shape)) @@ -500,12 +506,10 @@ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): raise NotImplementedError self.normalized_shape = (normalized_shape,) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: if self.data_format == "channels_last": return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) - else: - # [Fix] change elif -> else - # JIT should know Tensor return path of all cases. + else: #elif self.data_format == "channels_first": u = x.mean(1, keepdim=True) s = (x - u).pow(2).mean(1, keepdim=True) x = (x - u) / torch.sqrt(s + self.eps) @@ -561,20 +565,41 @@ def forward(self, x): class PosCNN(nn.Module): - def __init__(self, in_chans): + def __init__(self, in_chans, img_size=None): super(PosCNN, self).__init__() self.proj = nn.Conv2d(in_chans, in_chans, kernel_size=3, stride=1, padding=1, bias=True, groups=in_chans) + self.img_size = img_size + + # ignore JIT + safety variable type change + @torch.jit.ignore + def _get_dynamic_size(self, N): + # 1. Eager Mode (normal execution) + if isinstance(N, int): + s = int(N ** 0.5) + return s, s + + # 2. FX Tracing & Runtime + # if n is Proxy or Tensor, change to float Tensor + N_float = N * torch.tensor(1.0) # float promotion (Proxy 호환) + s = (N_float ** 0.5).to(torch.int) + return s, s def forward(self, x): B, N, C = x.shape feat_token = x - H, W = int(N ** 0.5), int(N ** 0.5) + + # JIT mode vs others + if torch.jit.is_scripting(): + H = int(N ** 0.5) + W = H + else: + H, W = self._get_dynamic_size(N) + cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W) x = self.proj(cnn_feat) + cnn_feat x = x.flatten(2).transpose(1, 2) return x - def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): return _no_grad_trunc_normal_(tensor, mean, std, a, b) @@ -605,10 +630,11 @@ def norm_cdf(x): 'std': (0.229, 0.224, 0.225), 'interpolation': 'bilinear', 'crop_pct': 1.0, + 'classifier': 'head', + 'first_conv': [], }, }) - def _create_csatv2(variant: str, pretrained: bool = False, **kwargs) -> CSATv2: return build_model_with_cfg( CSATv2, From dad1ca1ee04e92baa2d5721d923b67b412abf69e Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 9 Dec 2025 14:43:09 -0800 Subject: [PATCH 09/14] Some rework of csatv2 to better fit timm norms, re-use existing layers where possible --- tests/test_models.py | 3 +- timm/layers/attention.py | 89 ++-- timm/models/csatv2.py | 886 ++++++++++++++++++++++----------------- 3 files changed, 548 insertions(+), 430 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index d73d4512fc..f7c647979c 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -56,7 +56,8 @@ 'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*', 'tnt', 'tiny_vit', 'vovnet', 'tresnet', 'rexnet', 'resnetv2', 'repghost', 'repvit', 'pvt_v2', 'nextvit', 'nest', 'mambaout', 'inception_next', 'inception_v4', 'hgnet', 'gcvit', 'focalnet', 'efficientformer_v2', 'edgenext', - 'davit', 'rdnet', 'convnext', 'pit', 'starnet', 'shvit', 
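# Editor's sketch (not part of the patch): the PosCNN path above recovers the 2D grid
# as H = W = sqrt(N), so it assumes a square token grid (e.g. 49 tokens -> 7x7).
import torch
N = 49
H = W = int(N ** 0.5)
tokens = torch.randn(2, N, 16)
grid = tokens.transpose(1, 2).view(2, 16, H, W)   # (B, C, H, W) for the depthwise conv
assert grid.shape == (2, 16, 7, 7)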
'fasternet', 'swiftformer', 'ghostnet', 'naflexvit' + 'davit', 'rdnet', 'convnext', 'pit', 'starnet', 'shvit', 'fasternet', 'swiftformer', 'ghostnet', 'naflexvit', + 'csatv2' ] # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output. diff --git a/timm/layers/attention.py b/timm/layers/attention.py index 21329107fb..37bb6b6e9a 100644 --- a/timm/layers/attention.py +++ b/timm/layers/attention.py @@ -29,6 +29,8 @@ def __init__( self, dim: int, num_heads: int = 8, + attn_head_dim: Optional[int] = None, + dim_out: Optional[int] = None, qkv_bias: bool = False, qk_norm: bool = False, scale_norm: bool = False, @@ -37,36 +39,45 @@ def __init__( proj_drop: float = 0., norm_layer: Optional[Type[nn.Module]] = None, device=None, - dtype=None + dtype=None, ) -> None: """Initialize the Attention module. Args: - dim: Input dimension of the token embeddings - num_heads: Number of attention heads - qkv_bias: Whether to use bias in the query, key, value projections - qk_norm: Whether to apply normalization to query and key vectors - proj_bias: Whether to use bias in the output projection - attn_drop: Dropout rate applied to the attention weights - proj_drop: Dropout rate applied after the output projection - norm_layer: Normalization layer constructor for QK normalization if enabled + dim: Input dimension of the token embeddings. + num_heads: Number of attention heads. + attn_head_dim: Dimension of each attention head. If None, computed as dim // num_heads. + dim_out: Output dimension. If None, same as dim. + qkv_bias: Whether to use bias in the query, key, value projections. + qk_norm: Whether to apply normalization to query and key vectors. + scale_norm: Whether to apply normalization to attention output before projection. + proj_bias: Whether to use bias in the output projection. + attn_drop: Dropout rate applied to the attention weights. + proj_drop: Dropout rate applied after the output projection. + norm_layer: Normalization layer constructor for QK normalization if enabled. 
""" super().__init__() dd = {'device': device, 'dtype': dtype} - assert dim % num_heads == 0, 'dim should be divisible by num_heads' + dim_out = dim_out or dim + head_dim = attn_head_dim + if head_dim is None: + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + head_dim = dim // num_heads if qk_norm or scale_norm: assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True' + self.num_heads = num_heads - self.head_dim = dim // num_heads - self.scale = self.head_dim ** -0.5 + self.head_dim = head_dim + self.attn_dim = num_heads * head_dim + self.scale = head_dim ** -0.5 self.fused_attn = use_fused_attn() - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd) - self.q_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity() - self.k_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity() + self.qkv = nn.Linear(dim, self.attn_dim * 3, bias=qkv_bias, **dd) + self.q_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity() + self.k_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity() self.attn_drop = nn.Dropout(attn_drop) - self.norm = norm_layer(dim, **dd) if scale_norm else nn.Identity() - self.proj = nn.Linear(dim, dim, bias=proj_bias, **dd) + self.norm = norm_layer(self.attn_dim, **dd) if scale_norm else nn.Identity() + self.proj = nn.Linear(self.attn_dim, dim_out, bias=proj_bias, **dd) self.proj_drop = nn.Dropout(proj_drop) def forward( @@ -93,7 +104,7 @@ def forward( attn = self.attn_drop(attn) x = attn @ v - x = x.transpose(1, 2).reshape(B, N, C) + x = x.transpose(1, 2).reshape(B, N, self.attn_dim) x = self.norm(x) x = self.proj(x) x = self.proj_drop(x) @@ -114,6 +125,7 @@ def __init__( self, dim: int, num_heads: int = 8, + dim_out: Optional[int] = None, qkv_bias: bool = True, qkv_fused: bool = True, num_prefix_tokens: int = 1, @@ -133,45 +145,52 @@ def __init__( Args: dim: Input dimension of the token embeddings num_heads: Number of attention heads + dim_out: Output dimension. If None, same as dim. qkv_bias: Whether to add a bias term to the query, key, and value projections + qkv_fused: Whether to use fused QKV projection (single linear) or separate projections num_prefix_tokens: Number of reg/cls tokens at the beginning of the sequence that should not have position embeddings applied attn_drop: Dropout rate for attention weights proj_drop: Dropout rate for the output projection - attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads) + attn_head_dim: Dimension of each attention head. If None, computed as dim // num_heads. 
norm_layer: Normalization layer constructor to use for QK and scale normalization qk_norm: Enable normalization of query (Q) and key (K) vectors with norm_layer scale_norm: Enable normalization (scaling) of attention output with norm_layer + proj_bias: Whether to use bias in the output projection rotate_half: Use 'half' ROPE layout instead of default 'interleaved' """ super().__init__() dd = {'device': device, 'dtype': dtype} + dim_out = dim_out or dim + head_dim = attn_head_dim + if head_dim is None: + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + head_dim = dim // num_heads if scale_norm or qk_norm: assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True' + self.num_heads = num_heads - head_dim = dim // num_heads - if attn_head_dim is not None: - head_dim = attn_head_dim - attn_dim = head_dim * self.num_heads + self.head_dim = head_dim + self.attn_dim = head_dim * num_heads self.scale = head_dim ** -0.5 self.num_prefix_tokens = num_prefix_tokens self.fused_attn = use_fused_attn() self.rotate_half = rotate_half if qkv_fused: - self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias, **dd) + self.qkv = nn.Linear(dim, self.attn_dim * 3, bias=qkv_bias, **dd) self.q_proj = self.k_proj = self.v_proj = None else: self.qkv = None - self.q_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd) - self.k_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd) - self.v_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd) + self.q_proj = nn.Linear(dim, self.attn_dim, bias=qkv_bias, **dd) + self.k_proj = nn.Linear(dim, self.attn_dim, bias=qkv_bias, **dd) + self.v_proj = nn.Linear(dim, self.attn_dim, bias=qkv_bias, **dd) self.q_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity() self.k_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity() self.attn_drop = nn.Dropout(attn_drop) - self.norm = norm_layer(attn_dim, **dd) if scale_norm else nn.Identity() - self.proj = nn.Linear(attn_dim, dim, bias=proj_bias, **dd) + self.norm = norm_layer(self.attn_dim, **dd) if scale_norm else nn.Identity() + self.proj = nn.Linear(self.attn_dim, dim_out, bias=proj_bias, **dd) self.proj_drop = nn.Dropout(proj_drop) def forward( @@ -188,18 +207,18 @@ def forward( attn_mask: Optional attention mask to apply during attention computation Returns: - Tensor of shape (batch_size, sequence_length, embedding_dim) + Tensor of shape (batch_size, sequence_length, dim_out) """ B, N, C = x.shape if self.qkv is not None: qkv = self.qkv(x) - qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) # B, num_heads, N, head_dim else: - q = self.q_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2) # B, num_heads, N, C - k = self.k_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2) - v = self.v_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2) + q = self.q_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) + k = self.k_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) + v = self.v_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) q, k = self.q_norm(q), self.k_norm(k) @@ -224,7 +243,7 @@ def forward( attn = self.attn_drop(attn) x = attn @ v - x = x.transpose(1, 2).reshape(B, N, C) + x = x.transpose(1, 2).reshape(B, N, self.attn_dim) x = self.norm(x) x = self.proj(x) x = self.proj_drop(x) diff --git a/timm/models/csatv2.py b/timm/models/csatv2.py index ba2dcc2f3c..c9768969f7 
100644 --- a/timm/models/csatv2.py +++ b/timm/models/csatv2.py @@ -1,26 +1,84 @@ +"""CSATv2 + +A frequency-domain vision model using DCT transforms with spatial attention. + +Paper: TBD +""" +import math +import warnings +from typing import List, Optional, Tuple, Union + +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -import warnings + +from timm.layers import trunc_normal_, DropPath, Mlp, LayerNorm2d, Attention +from timm.layers.grn import GlobalResponseNorm from timm.models._builder import build_model_with_cfg +from timm.models._features import feature_take_indices from ._registry import register_model, generate_default_cfgs -from typing import List -import numpy as np -import math - -__all__ = ['DCT2D'] -# Helper Functions -mean=[[932.42657,-0.00260,0.33415,-0.02840,0.00003,-0.02792,-0.00183,0.00006,0.00032,0.03402,-0.00571,0.00020,0.00006,-0.00038,-0.00558,-0.00116,-0.00000,-0.00047,-0.00008,-0.00030,0.00942,0.00161,-0.00009,-0.00006,-0.00014,-0.00035,0.00001,-0.00220,0.00033,-0.00002,-0.00003,-0.00020,0.00007,-0.00000,0.00005,0.00293,-0.00004,0.00006,0.00019,0.00004,0.00006,-0.00015,-0.00002,0.00007,0.00010,-0.00004,0.00008,0.00000,0.00008,-0.00001,0.00015,0.00002,0.00007,0.00003,0.00004,-0.00001,0.00004,-0.00000,0.00002,-0.00000,-0.00008,-0.00000,-0.00003,0.00003], -[962.34735,-0.00428,0.09835,0.00152,-0.00009,0.00312,-0.00141,-0.00001,-0.00013,0.01050,0.00065,0.00006,-0.00000,0.00003,0.00264,0.00000,0.00001,0.00007,-0.00006,0.00003,0.00341,0.00163,0.00004,0.00003,-0.00001,0.00008,-0.00000,0.00090,0.00018,-0.00006,-0.00001,0.00007,-0.00003,-0.00001,0.00006,0.00084,-0.00000,-0.00001,0.00000,0.00004,-0.00001,-0.00002,0.00000,0.00001,0.00002,0.00001,0.00004,0.00011,0.00000,-0.00003,0.00011,-0.00002,0.00001,0.00001,0.00001,0.00001,-0.00007,-0.00003,0.00001,0.00000,0.00001,0.00002,0.00001,0.00000], -[1053.16101,-0.00213,-0.09207,0.00186,0.00013,0.00034,-0.00119,0.00002,0.00011,-0.00984,0.00046,-0.00007,-0.00001,-0.00005,0.00180,0.00042,0.00002,-0.00010,0.00004,0.00003,-0.00301,0.00125,-0.00002,-0.00003,-0.00001,-0.00001,-0.00001,0.00056,0.00021,0.00001,-0.00001,0.00002,-0.00001,-0.00001,0.00005,-0.00070,-0.00002,-0.00002,0.00005,-0.00004,-0.00000,0.00002,-0.00002,0.00001,0.00000,-0.00003,0.00004,0.00007,0.00001,0.00000,0.00013,-0.00000,0.00000,0.00002,-0.00000,-0.00001,-0.00004,-0.00003,0.00000,0.00001,-0.00001,0.00001,-0.00000,0.00000]] - -var=[[270372.37500,6287.10645,5974.94043,1653.10889,1463.91748,1832.58997,755.92468,692.41528,648.57184,641.46881,285.79288,301.62100,380.43405,349.84027,374.15891,190.30960,190.76746,221.64578,200.82646,145.87979,126.92046,62.14622,67.75562,102.42001,129.74922,130.04631,103.12189,97.76417,53.17402,54.81048,73.48712,81.04342,69.35100,49.06024,33.96053,37.03279,20.48858,24.94830,33.90822,44.54912,47.56363,40.03160,30.43313,22.63899,26.53739,26.57114,21.84404,17.41557,15.18253,10.69678,11.24111,12.97229,15.08971,15.31646,8.90409,7.44213,6.66096,6.97719,4.17834,3.83882,4.51073,2.36646,2.41363,1.48266], 
-[18839.21094,321.70932,300.15259,77.47830,76.02293,89.04748,33.99642,34.74807,32.12333,28.19588,12.04675,14.26871,18.45779,16.59588,15.67892,7.37718,8.56312,10.28946,9.41013,6.69090,5.16453,2.55186,3.03073,4.66765,5.85418,5.74644,4.33702,3.66948,1.95107,2.26034,3.06380,3.50705,3.06359,2.19284,1.54454,1.57860,0.97078,1.13941,1.48653,1.89996,1.95544,1.64950,1.24754,0.93677,1.09267,1.09516,0.94163,0.78966,0.72489,0.50841,0.50909,0.55664,0.63111,0.64125,0.38847,0.33378,0.30918,0.33463,0.20875,0.19298,0.21903,0.13380,0.13444,0.09554], -[17127.39844,292.81421,271.45209,66.64056,63.60253,76.35437,28.06587,27.84831,25.96656,23.60370,9.99173,11.34992,14.46955,12.92553,12.69353,5.91537,6.60187,7.90891,7.32825,5.32785,4.29660,2.13459,2.44135,3.66021,4.50335,4.38959,3.34888,2.97181,1.60633,1.77010,2.35118,2.69018,2.38189,1.74596,1.26014,1.31684,0.79327,0.92046,1.17670,1.47609,1.50914,1.28725,0.99898,0.74832,0.85736,0.85800,0.74663,0.63508,0.58748,0.41098,0.41121,0.44663,0.50277,0.51519,0.31729,0.27336,0.25399,0.27241,0.17353,0.16255,0.18440,0.11602,0.11511,0.08450]] +__all__ = ['CSATv2', 'csatv2'] + +# DCT frequency normalization statistics (Y, Cb, Cr channels x 64 coefficients) +_DCT_MEAN = [ + [932.42657, -0.00260, 0.33415, -0.02840, 0.00003, -0.02792, -0.00183, 0.00006, + 0.00032, 0.03402, -0.00571, 0.00020, 0.00006, -0.00038, -0.00558, -0.00116, + -0.00000, -0.00047, -0.00008, -0.00030, 0.00942, 0.00161, -0.00009, -0.00006, + -0.00014, -0.00035, 0.00001, -0.00220, 0.00033, -0.00002, -0.00003, -0.00020, + 0.00007, -0.00000, 0.00005, 0.00293, -0.00004, 0.00006, 0.00019, 0.00004, + 0.00006, -0.00015, -0.00002, 0.00007, 0.00010, -0.00004, 0.00008, 0.00000, + 0.00008, -0.00001, 0.00015, 0.00002, 0.00007, 0.00003, 0.00004, -0.00001, + 0.00004, -0.00000, 0.00002, -0.00000, -0.00008, -0.00000, -0.00003, 0.00003], + [962.34735, -0.00428, 0.09835, 0.00152, -0.00009, 0.00312, -0.00141, -0.00001, + -0.00013, 0.01050, 0.00065, 0.00006, -0.00000, 0.00003, 0.00264, 0.00000, + 0.00001, 0.00007, -0.00006, 0.00003, 0.00341, 0.00163, 0.00004, 0.00003, + -0.00001, 0.00008, -0.00000, 0.00090, 0.00018, -0.00006, -0.00001, 0.00007, + -0.00003, -0.00001, 0.00006, 0.00084, -0.00000, -0.00001, 0.00000, 0.00004, + -0.00001, -0.00002, 0.00000, 0.00001, 0.00002, 0.00001, 0.00004, 0.00011, + 0.00000, -0.00003, 0.00011, -0.00002, 0.00001, 0.00001, 0.00001, 0.00001, + -0.00007, -0.00003, 0.00001, 0.00000, 0.00001, 0.00002, 0.00001, 0.00000], + [1053.16101, -0.00213, -0.09207, 0.00186, 0.00013, 0.00034, -0.00119, 0.00002, + 0.00011, -0.00984, 0.00046, -0.00007, -0.00001, -0.00005, 0.00180, 0.00042, + 0.00002, -0.00010, 0.00004, 0.00003, -0.00301, 0.00125, -0.00002, -0.00003, + -0.00001, -0.00001, -0.00001, 0.00056, 0.00021, 0.00001, -0.00001, 0.00002, + -0.00001, -0.00001, 0.00005, -0.00070, -0.00002, -0.00002, 0.00005, -0.00004, + -0.00000, 0.00002, -0.00002, 0.00001, 0.00000, -0.00003, 0.00004, 0.00007, + 0.00001, 0.00000, 0.00013, -0.00000, 0.00000, 0.00002, -0.00000, -0.00001, + -0.00004, -0.00003, 0.00000, 0.00001, -0.00001, 0.00001, -0.00000, 0.00000], +] + +_DCT_VAR = [ + [270372.37500, 6287.10645, 5974.94043, 1653.10889, 1463.91748, 1832.58997, 755.92468, 692.41528, + 648.57184, 641.46881, 285.79288, 301.62100, 380.43405, 349.84027, 374.15891, 190.30960, + 190.76746, 221.64578, 200.82646, 145.87979, 126.92046, 62.14622, 67.75562, 102.42001, + 129.74922, 130.04631, 103.12189, 97.76417, 53.17402, 54.81048, 73.48712, 81.04342, + 69.35100, 49.06024, 33.96053, 37.03279, 20.48858, 24.94830, 33.90822, 44.54912, + 47.56363, 
40.03160, 30.43313, 22.63899, 26.53739, 26.57114, 21.84404, 17.41557, + 15.18253, 10.69678, 11.24111, 12.97229, 15.08971, 15.31646, 8.90409, 7.44213, + 6.66096, 6.97719, 4.17834, 3.83882, 4.51073, 2.36646, 2.41363, 1.48266], + [18839.21094, 321.70932, 300.15259, 77.47830, 76.02293, 89.04748, 33.99642, 34.74807, + 32.12333, 28.19588, 12.04675, 14.26871, 18.45779, 16.59588, 15.67892, 7.37718, + 8.56312, 10.28946, 9.41013, 6.69090, 5.16453, 2.55186, 3.03073, 4.66765, + 5.85418, 5.74644, 4.33702, 3.66948, 1.95107, 2.26034, 3.06380, 3.50705, + 3.06359, 2.19284, 1.54454, 1.57860, 0.97078, 1.13941, 1.48653, 1.89996, + 1.95544, 1.64950, 1.24754, 0.93677, 1.09267, 1.09516, 0.94163, 0.78966, + 0.72489, 0.50841, 0.50909, 0.55664, 0.63111, 0.64125, 0.38847, 0.33378, + 0.30918, 0.33463, 0.20875, 0.19298, 0.21903, 0.13380, 0.13444, 0.09554], + [17127.39844, 292.81421, 271.45209, 66.64056, 63.60253, 76.35437, 28.06587, 27.84831, + 25.96656, 23.60370, 9.99173, 11.34992, 14.46955, 12.92553, 12.69353, 5.91537, + 6.60187, 7.90891, 7.32825, 5.32785, 4.29660, 2.13459, 2.44135, 3.66021, + 4.50335, 4.38959, 3.34888, 2.97181, 1.60633, 1.77010, 2.35118, 2.69018, + 2.38189, 1.74596, 1.26014, 1.31684, 0.79327, 0.92046, 1.17670, 1.47609, + 1.50914, 1.28725, 0.99898, 0.74832, 0.85736, 0.85800, 0.74663, 0.63508, + 0.58748, 0.41098, 0.41121, 0.44663, 0.50277, 0.51519, 0.31729, 0.27336, + 0.25399, 0.27241, 0.17353, 0.16255, 0.18440, 0.11602, 0.11511, 0.08450], +] def _zigzag_permutation(rows: int, cols: int) -> List[int]: + """Generate zigzag scan order for DCT coefficients.""" idx_matrix = np.arange(0, rows * cols, 1).reshape(rows, cols).tolist() dia = [[] for _ in range(rows + cols - 1)] zigzag = [] @@ -36,7 +94,13 @@ def _zigzag_permutation(rows: int, cols: int) -> List[int]: return zigzag -def _dct_kernel_type_2(kernel_size: int, orthonormal: bool, device=None, dtype=None) -> torch.Tensor: +def _dct_kernel_type_2( + kernel_size: int, + orthonormal: bool, + device=None, + dtype=None, +) -> torch.Tensor: + """Generate Type-II DCT kernel matrix.""" factory_kwargs = dict(device=device, dtype=dtype) x = torch.eye(kernel_size, **factory_kwargs) v = x.clone().contiguous().view(-1, kernel_size) @@ -44,7 +108,7 @@ def _dct_kernel_type_2(kernel_size: int, orthonormal: bool, device=None, dtype=N v = torch.fft.fft(v, dim=-1)[:, :kernel_size] try: k = torch.tensor(-1j, **factory_kwargs) * torch.pi * torch.arange(kernel_size, **factory_kwargs)[None, :] - except: + except AttributeError: k = torch.tensor(-1j, **factory_kwargs) * math.pi * torch.arange(kernel_size, **factory_kwargs)[None, :] k = torch.exp(k / (kernel_size * 2)) v = v * k @@ -56,109 +120,138 @@ def _dct_kernel_type_2(kernel_size: int, orthonormal: bool, device=None, dtype=N return v -def _dct_kernel_type_3(kernel_size: int, orthonormal: bool, device=None, dtype=None) -> torch.Tensor: +def _dct_kernel_type_3( + kernel_size: int, + orthonormal: bool, + device=None, + dtype=None, +) -> torch.Tensor: + """Generate Type-III DCT kernel matrix (inverse of Type-II).""" return torch.linalg.inv(_dct_kernel_type_2(kernel_size, orthonormal, device, dtype)) -class _DCT1D(nn.Module): - def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = True, - device=None, dtype=None) -> None: +class Dct1d(nn.Module): + """1D Discrete Cosine Transform layer.""" + + def __init__( + self, + kernel_size: int, + kernel_type: int = 2, + orthonormal: bool = True, + device=None, + dtype=None, + ) -> None: factory_kwargs = dict(device=device, dtype=dtype) - super(_DCT1D, 
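# Editor's sketch (not part of the patch): torch.pi is missing from older PyTorch
# releases, which is what the narrowed `except AttributeError` fallback above guards
# against; both constants are the same value wherever torch.pi exists.
import math
import torch
pi = getattr(torch, 'pi', math.pi)
assert abs(pi - math.pi) < 1e-12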
self).__init__() + super().__init__() kernel = {'2': _dct_kernel_type_2, '3': _dct_kernel_type_3} dct_weights = kernel[f'{kernel_type}'](kernel_size, orthonormal, **factory_kwargs).T self.register_buffer('weights', dct_weights) - self.register_parameter('bias', None) def forward(self, x: torch.Tensor) -> torch.Tensor: - return nn.functional.linear(x, self.weights, self.bias) + return F.linear(x, self.weights, self.bias) +class Dct2d(nn.Module): + """2D Discrete Cosine Transform layer.""" -class _DCT2D(nn.Module): - def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = True, - device=None, dtype=None) -> None: + def __init__( + self, + kernel_size: int, + kernel_type: int = 2, + orthonormal: bool = True, + device=None, + dtype=None, + ) -> None: factory_kwargs = dict(device=device, dtype=dtype) - super(_DCT2D, self).__init__() - self.transform = _DCT1D(kernel_size, kernel_type, orthonormal, **factory_kwargs) + super().__init__() + self.transform = Dct1d(kernel_size, kernel_type, orthonormal, **factory_kwargs) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.transform(self.transform(x).transpose(-1, -2)).transpose(-1, -2) -class Learnable_DCT2D(nn.Module): - def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = True, - device=None, dtype=None) -> None: +class LearnableDct2d(nn.Module): + """Learnable 2D DCT stem with RGB to YCbCr conversion and frequency selection.""" + + def __init__( + self, + kernel_size: int, + kernel_type: int = 2, + orthonormal: bool = True, + device=None, + dtype=None, + ) -> None: factory_kwargs = dict(device=device, dtype=dtype) - super(Learnable_DCT2D, self).__init__() + super().__init__() self.k = kernel_size self.unfold = nn.Unfold(kernel_size=(kernel_size, kernel_size), stride=(kernel_size, kernel_size)) - self.transform = _DCT2D(kernel_size, kernel_type, orthonormal, **factory_kwargs) + self.transform = Dct2d(kernel_size, kernel_type, orthonormal, **factory_kwargs) self.permutation = _zigzag_permutation(kernel_size, kernel_size) self.Y_Conv = nn.Conv2d(kernel_size ** 2, 24, kernel_size=1, padding=0) self.Cb_Conv = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0) self.Cr_Conv = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0) - self.mean = torch.tensor(mean, requires_grad=False) - self.var = torch.tensor(var, requires_grad=False) - self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406], requires_grad=False) - self.imagenet_var = torch.tensor([0.229, 0.224, 0.225], requires_grad=False) + self.register_buffer('mean', torch.tensor(_DCT_MEAN), persistent=False) + self.register_buffer('var', torch.tensor(_DCT_VAR), persistent=False) + self.register_buffer('imagenet_mean', torch.tensor([0.485, 0.456, 0.406]), persistent=False) + self.register_buffer('imagenet_std', torch.tensor([0.229, 0.224, 0.225]), persistent=False) - def denormalize(self, x): - x = x.multiply(self.imagenet_var.to(x.device)).add_(self.imagenet_mean.to(x.device)) * 255 - return x + def _denormalize(self, x: torch.Tensor) -> torch.Tensor: + """Convert from ImageNet normalized to [0, 255] range.""" + return x.mul(self.imagenet_std).add_(self.imagenet_mean) * 255 - def rgb2ycbcr(self, x): + def _rgb_to_ycbcr(self, x: torch.Tensor) -> torch.Tensor: + """Convert RGB to YCbCr color space.""" y = (x[:, :, :, 0] * 0.299) + (x[:, :, :, 1] * 0.587) + (x[:, :, :, 2] * 0.114) cb = 0.564 * (x[:, :, :, 2] - y) + 128 cr = 0.713 * (x[:, :, :, 0] - y) + 128 - x = torch.stack([y, cb, cr], dim=-1) - return x - - def 
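# Editor's sketch (not part of the patch): the persistent=False buffers registered
# above travel with the module across devices but are left out of state_dict, so
# checkpoints stay free of the hard-coded normalization tables.
import torch
import torch.nn as nn

class Stats(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('mean', torch.zeros(3, 64), persistent=False)

m = Stats()
assert 'mean' not in m.state_dict()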
frequncy_normalize(self, x): - - mean_tensor = self.mean.to(x.device) - var_tensor = self.var.to(x.device) + return torch.stack([y, cb, cr], dim=-1) - std = var_tensor ** 0.5 + 1e-8 - - x = (x - mean_tensor) / std - - return x + def _frequency_normalize(self, x: torch.Tensor) -> torch.Tensor: + """Normalize DCT coefficients using precomputed statistics.""" + std = self.var ** 0.5 + 1e-8 + return (x - self.mean) / std def forward(self, x: torch.Tensor) -> torch.Tensor: b, c, h, w = x.shape x = x.permute(0, 2, 3, 1) - x = self.denormalize(x) - x = self.rgb2ycbcr(x) + x = self._denormalize(x) + x = self._rgb_to_ycbcr(x) x = x.permute(0, 3, 1, 2) x = self.unfold(x).transpose(-1, -2) x = x.reshape(b, h // self.k, w // self.k, c, self.k, self.k) x = self.transform(x) x = x.reshape(-1, c, self.k * self.k) x = x[:, :, self.permutation] - x = self.frequncy_normalize(x) + x = self._frequency_normalize(x) x = x.reshape(b, h // self.k, w // self.k, c, -1) x = x.permute(0, 3, 4, 1, 2).contiguous() - x_Y = self.Y_Conv(x[:, 0, ]) - x_Cb = self.Cb_Conv(x[:, 1, ]) - x_Cr = self.Cr_Conv(x[:, 2, ]) - x = torch.cat([x_Y, x_Cb, x_Cr], dim=1) - return x + x_y = self.Y_Conv(x[:, 0]) + x_cb = self.Cb_Conv(x[:, 1]) + x_cr = self.Cr_Conv(x[:, 2]) + return torch.cat([x_y, x_cb, x_cr], dim=1) + +class Dct2dStats(nn.Module): + """Utility module to compute DCT coefficient statistics.""" -class DCT2D(nn.Module): - def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = True, - device=None, dtype=None) -> None: + def __init__( + self, + kernel_size: int, + kernel_type: int = 2, + orthonormal: bool = True, + device=None, + dtype=None, + ) -> None: factory_kwargs = dict(device=device, dtype=dtype) - super(DCT2D, self).__init__() + super().__init__() self.k = kernel_size self.unfold = nn.Unfold(kernel_size=(kernel_size, kernel_size), stride=(kernel_size, kernel_size)) - self.transform = _DCT2D(kernel_size, kernel_type, orthonormal, **factory_kwargs) + self.transform = Dct2d(kernel_size, kernel_type, orthonormal, **factory_kwargs) self.permutation = _zigzag_permutation(kernel_size, kernel_size) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: b, c, h, w = x.shape x = self.unfold(x).transpose(-1, -2) x = x.reshape(b, h // self.k, w // self.k, c, self.k, self.k) @@ -169,61 +262,106 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: mean_list = torch.zeros([3, 64]) var_list = torch.zeros([3, 64]) - mean_list[0] = torch.mean(x[:, 0, ], dim=0) - mean_list[1] = torch.mean(x[:, 1, ], dim=0) - mean_list[2] = torch.mean(x[:, 2, ], dim=0) - var_list[0] = torch.var(x[:, 0, ], dim=0) - var_list[1] = torch.var(x[:, 1, ], dim=0) - var_list[2] = torch.var(x[:, 2, ], dim=0) + for i in range(3): + mean_list[i] = torch.mean(x[:, i], dim=0) + var_list[i] = torch.var(x[:, i], dim=0) return mean_list, var_list class Block(nn.Module): - def __init__(self, dim, drop_path=0.): + """ConvNeXt-style block with spatial attention.""" + + def __init__( + self, + dim: int, + drop_path: float = 0., + ) -> None: super().__init__() self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) - self.norm = LayerNorm(dim, eps=1e-6) + self.norm = nn.LayerNorm(dim, eps=1e-6) self.pwconv1 = nn.Linear(dim, 4 * dim) self.act = nn.GELU() - self.grn = GRN(4 * dim) + self.grn = GlobalResponseNorm(4 * dim, channels_last=True) self.pwconv2 = nn.Linear(4 * dim, dim) self.drop_path = DropPath(drop_path) if drop_path > 0. 
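# Editor's sketch (not part of the patch): Dct2dStats above is presumably how the
# _DCT_MEAN / _DCT_VAR tables are (re)estimated: feed it YCbCr blocks in the [0, 255]
# range and it returns per-channel, per-coefficient mean and variance.
import torch
stats = Dct2dStats(8)
ycbcr = torch.rand(16, 3, 64, 64) * 255
mean_list, var_list = stats(ycbcr)
print(mean_list.shape, var_list.shape)   # torch.Size([3, 64]) torch.Size([3, 64])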
else nn.Identity() - self.attention = Spatial_Attention() + self.attention = SpatialAttention() def forward(self, x: torch.Tensor) -> torch.Tensor: - input = x + shortcut = x x = self.dwconv(x) x = x.permute(0, 2, 3, 1) - x = self.norm(x) - x = self.pwconv1(x) x = self.act(x) x = self.grn(x) x = self.pwconv2(x) x = x.permute(0, 3, 1, 2) - # Spatial Attention logic attention = self.attention(x) - # x = x * nn.UpsamplingBilinear2d(size=x.shape[2:])(attention) up_attn = F.interpolate(attention, size=x.shape[2:], mode='bilinear', align_corners=True) x = x * up_attn - x = input + self.drop_path(x) + return shortcut + self.drop_path(x) + + +class SpatialTransformerBlock(nn.Module): + """Lightweight transformer block for spatial attention (1-channel, 7x7 grid). + + This is a simplified transformer with single-head, 1-dim attention over spatial + positions. Used inside SpatialAttention where input is 1 channel at 7x7 resolution. + """ + + def __init__(self) -> None: + super().__init__() + # Single-head attention with 1-dim q/k/v (no output projection needed) + self.pos_embed = PosConv(in_chans=1) + self.norm1 = nn.LayerNorm(1) + self.qkv = nn.Linear(1, 3, bias=False) + + # Feedforward: 1 -> 4 -> 1 + self.norm2 = nn.LayerNorm(1) + self.mlp = Mlp(1, 4, 1, act_layer=nn.GELU) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, C, H, W = x.shape + + # Attention block + shortcut = x + x_t = x.flatten(2).transpose(1, 2) # (B, N, 1) + x_t = self.norm1(x_t) + x_t = self.pos_embed(x_t, (H, W)) + + # Simple single-head attention with scalar q/k/v + qkv = self.qkv(x_t) # (B, N, 3) + q, k, v = qkv.unbind(-1) # each (B, N) + attn = (q @ k.transpose(-1, -2)).softmax(dim=-1) # (B, N, N) + x_t = (attn @ v).unsqueeze(-1) # (B, N, 1) + + x_t = x_t.transpose(1, 2).reshape(B, C, H, W) + x = shortcut + x_t + + # Feedforward block + shortcut = x + x_t = x.flatten(2).transpose(1, 2) + x_t = self.mlp(self.norm2(x_t)) + x_t = x_t.transpose(1, 2).reshape(B, C, H, W) + x = shortcut + x_t + return x -class Spatial_Attention(nn.Module): - def __init__(self): +class SpatialAttention(nn.Module): + """Spatial attention module using channel statistics and transformer.""" + + def __init__(self) -> None: super().__init__() self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3) - # downsample=False, dropout=0. 
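# Editor's sketch (not part of the patch): the spatial attention map used in the Block
# above is computed on a fixed 7x7 grid and stretched back to the feature resolution
# with bilinear interpolation before gating the features.
import torch
import torch.nn.functional as F
attn = torch.rand(2, 1, 7, 7)
up = F.interpolate(attn, size=(32, 32), mode='bilinear', align_corners=True)
print(up.shape)   # torch.Size([2, 1, 32, 32])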
by default - self.attention = TransformerBlock(1, 1, heads=1, dim_head=1, img_size=[7, 7], downsample=False) + self.attention = SpatialTransformerBlock() - def forward(self, x): - x_avg = x.mean([1]).unsqueeze(1) - x_max = x.max(dim=1).values.unsqueeze(1) + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_avg = x.mean(dim=1, keepdim=True) + x_max = x.amax(dim=1, keepdim=True) x = torch.cat([x_avg, x_max], dim=1) x = self.avgpool(x) x = self.conv(x) @@ -232,393 +370,276 @@ def forward(self, x): class TransformerBlock(nn.Module): - def __init__(self, inp, oup, heads=8, dim_head=32, img_size=None, downsample=False, dropout=0.): + """Transformer block with optional downsampling and convolutional position encoding.""" + + def __init__( + self, + inp: int, + oup: int, + num_heads: int = 8, + attn_head_dim: int = 32, + downsample: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + ) -> None: super().__init__() hidden_dim = int(inp * 4) - self.downsample = downsample - self.ih, self.iw = img_size if self.downsample: self.pool1 = nn.MaxPool2d(3, 2, 1) self.pool2 = nn.MaxPool2d(3, 2, 1) self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False) else: - # [change] for use TorchScript all variable need to declare for Identity self.pool1 = nn.Identity() self.pool2 = nn.Identity() self.proj = nn.Identity() - # Attention block components - self.attn_norm = nn.LayerNorm(inp) - # self.attn = Attention(inp, oup, heads, dim_head, dropout) - self.attn = Attention(inp, oup, heads, dim_head, dropout, img_size=img_size) - + self.pos_embed = PosConv(in_chans=inp) + self.norm1 = nn.LayerNorm(inp) + self.attn = Attention( + dim=inp, + num_heads=num_heads, + attn_head_dim=attn_head_dim, + dim_out=oup, + attn_drop=attn_drop, + proj_drop=proj_drop, + ) - # FeedForward block components - self.ff_norm = nn.LayerNorm(oup) - self.ff = FeedForward(oup, hidden_dim, dropout) + self.norm2 = nn.LayerNorm(oup) + self.mlp = Mlp(oup, hidden_dim, oup, act_layer=nn.GELU, drop=proj_drop) - def forward(self, x): - # x shape: (B, C, H, W) + def forward(self, x: torch.Tensor) -> torch.Tensor: if self.downsample: - # Identity path with projection shortcut = self.proj(self.pool1(x)) - - # Attention path x_t = self.pool2(x) B, C, H, W = x_t.shape - - # Flatten spatial: (B, C, H, W) -> (B, H*W, C) x_t = x_t.flatten(2).transpose(1, 2) - - # PreNorm -> Attention - x_t = self.attn_norm(x_t) + x_t = self.norm1(x_t) + x_t = self.pos_embed(x_t, (H, W)) x_t = self.attn(x_t) - - # Unflatten: (B, H*W, C) -> (B, C, H, W) - x_t = x_t.transpose(1, 2).reshape(B, C, H, W) - + x_t = x_t.transpose(1, 2).reshape(B, -1, H, W) x = shortcut + x_t else: - # Standard PreNorm Residual Attention B, C, H, W = x.shape shortcut = x - - # Flatten x_t = x.flatten(2).transpose(1, 2) - - # PreNorm -> Attention - x_t = self.attn_norm(x_t) + x_t = self.norm1(x_t) + x_t = self.pos_embed(x_t, (H, W)) x_t = self.attn(x_t) - - # Unflatten - x_t = x_t.transpose(1, 2).reshape(B, C, H, W) - + x_t = x_t.transpose(1, 2).reshape(B, -1, H, W) x = shortcut + x_t - # FeedForward Block + # MLP block B, C, H, W = x.shape shortcut = x - - # Flatten x_t = x.flatten(2).transpose(1, 2) - - # PreNorm -> FeedForward - x_t = self.ff_norm(x_t) - x_t = self.ff(x_t) - - # Unflatten + x_t = self.mlp(self.norm2(x_t)) x_t = x_t.transpose(1, 2).reshape(B, C, H, W) - x = shortcut + x_t return x -class Attention(nn.Module): - """ - Refactored Attention without einops.rearrange. 
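# Editor's sketch (not part of the patch): SpatialAttention above summarizes the
# feature map with a per-pixel channel mean and channel max, then pools that
# 2-channel descriptor to the fixed 7x7 grid the tiny transformer operates on.
import torch
import torch.nn.functional as F
x = torch.randn(2, 168, 16, 16)
desc = torch.cat([x.mean(dim=1, keepdim=True), x.amax(dim=1, keepdim=True)], dim=1)
desc = F.adaptive_avg_pool2d(desc, (7, 7))
print(desc.shape)   # torch.Size([2, 2, 7, 7])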
- """ +class PosConv(nn.Module): + """Convolutional position encoding.""" - def __init__(self, inp, oup, heads=8, dim_head=32, dropout=0., img_size=None): + def __init__( + self, + in_chans: int, + ) -> None: super().__init__() - inner_dim = dim_head * heads - project_out = not (heads == 1 and dim_head == inp) - - self.heads = heads - self.dim_head = dim_head - self.scale = dim_head ** -0.5 - - self.attend = nn.Softmax(dim=-1) - self.to_qkv = nn.Linear(inp, inner_dim * 3, bias=False) - - self.to_out = nn.Sequential( - nn.Linear(inner_dim, oup), - nn.Dropout(dropout) - ) if project_out else nn.Identity() - # self.pos_embed = PosCNN(in_chans=inp) - self.pos_embed = PosCNN(in_chans=inp, img_size=img_size) - - def forward(self, x): - # x shape: (B, N, C) - # Positional Embedding (expects B, N, C -> internally converts to spatial) - x = self.pos_embed(x) + self.proj = nn.Conv2d(in_chans, in_chans, kernel_size=3, stride=1, padding=1, bias=True, groups=in_chans) + def forward(self, x: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor: B, N, C = x.shape + H, W = size + cnn_feat = x.transpose(1, 2).view(B, C, H, W) + x = self.proj(cnn_feat) + cnn_feat + return x.flatten(2).transpose(1, 2) - # Generate Q, K, V - # qkv: (B, N, 3 * heads * dim_head) - qkv = self.to_qkv(x) - - # Reshape to (B, N, 3, heads, dim_head) and permute to (3, B, heads, N, dim_head) - qkv = qkv.reshape(B, N, 3, self.heads, self.dim_head).permute(2, 0, 3, 1, 4) - - q, k, v = qkv[0], qkv[1], qkv[2] # Each is (B, heads, N, dim_head) - - # Attention Score - # (B, heads, N, dim_head) @ (B, heads, dim_head, N) -> (B, heads, N, N) - dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale - attn = self.attend(dots) - - # Weighted Sum - # (B, heads, N, N) @ (B, heads, N, dim_head) -> (B, heads, N, dim_head) - out = torch.matmul(attn, v) - - # Rearrange output: (B, heads, N, dim_head) -> (B, N, heads*dim_head) - out = out.transpose(1, 2).reshape(B, N, -1) - out = self.to_out(out) - return out +class CSATv2(nn.Module): + """CSATv2: Frequency-domain vision model with spatial attention. + A hybrid architecture that processes images in the DCT frequency domain + with ConvNeXt-style blocks and transformer attention. 
+ """ -class CSATv2(nn.Module): def __init__( self, - img_size=512, - num_classes=1000, - in_chans=3, - drop_path_rate=0.0, - global_pool='avg', + num_classes: int = 1000, + in_chans: int = 3, + drop_path_rate: float = 0.0, + global_pool: str = 'avg', **kwargs, - ): + ) -> None: super().__init__() - self.num_classes = num_classes self.global_pool = global_pool - if isinstance(img_size, (tuple, list)): - img_size = img_size[0] - - self.img_size = img_size - dims = [32, 72, 168, 386] self.num_features = dims[-1] self.head_hidden_size = self.num_features self.feature_info = [ - dict(num_chs=32, reduction=8, module='dct'), - dict(num_chs=dims[1], reduction=16, module='stages1'), - dict(num_chs=dims[2], reduction=32, module='stages2'), - dict(num_chs=dims[3], reduction=64, module='stages3'), - dict(num_chs=dims[3], reduction=64, module='stages4'), + dict(num_chs=dims[0], reduction=8, module='stem_dct'), + dict(num_chs=dims[0], reduction=8, module='stages.0'), + dict(num_chs=dims[1], reduction=16, module='stages.1'), + dict(num_chs=dims[2], reduction=32, module='stages.2'), + dict(num_chs=dims[3], reduction=64, module='stages.3'), ] - channel_order = "channels_first" + depths = [2, 2, 6, 4] dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] - self.dct = Learnable_DCT2D(8) - - self.stages1 = nn.Sequential( - Block(dim=dims[0], drop_path=dp_rates[0]), - Block(dim=dims[0], drop_path=dp_rates[1]), - LayerNorm(dims[0], eps=1e-6, data_format=channel_order), - nn.Conv2d(dims[0], dims[0 + 1], kernel_size=2, stride=2), - ) - - self.stages2 = nn.Sequential( - Block(dim=dims[1], drop_path=dp_rates[0]), - Block(dim=dims[1], drop_path=dp_rates[1]), - LayerNorm(dims[1], eps=1e-6, data_format=channel_order), - nn.Conv2d(dims[1], dims[1 + 1], kernel_size=2, stride=2), - ) + self.stem_dct = LearnableDct2d(8) - self.stages3 = nn.Sequential( - Block(dim=dims[2], drop_path=dp_rates[0]), - Block(dim=dims[2], drop_path=dp_rates[1]), - Block(dim=dims[2], drop_path=dp_rates[2]), - Block(dim=dims[2], drop_path=dp_rates[3]), - Block(dim=dims[2], drop_path=dp_rates[4]), - Block(dim=dims[2], drop_path=dp_rates[5]), - TransformerBlock( - inp=dims[2], - oup=dims[2], - img_size=[int(self.img_size / 32), int(self.img_size / 32)], + self.stages = nn.Sequential( + nn.Sequential( + Block(dim=dims[0], drop_path=dp_rates[0]), + Block(dim=dims[0], drop_path=dp_rates[1]), + LayerNorm2d(dims[0], eps=1e-6), ), - TransformerBlock( - inp=dims[2], - oup=dims[2], - img_size=[int(self.img_size / 32), int(self.img_size / 32)], + nn.Sequential( + nn.Conv2d(dims[0], dims[1], kernel_size=2, stride=2), + Block(dim=dims[1], drop_path=dp_rates[2]), + Block(dim=dims[1], drop_path=dp_rates[3]), + LayerNorm2d(dims[1], eps=1e-6), ), - LayerNorm(dims[2], eps=1e-6, data_format=channel_order), - nn.Conv2d(dims[2], dims[2 + 1], kernel_size=2, stride=2), - ) - - self.stages4 = nn.Sequential( - Block(dim=dims[3], drop_path=dp_rates[0]), - Block(dim=dims[3], drop_path=dp_rates[1]), - Block(dim=dims[3], drop_path=dp_rates[2]), - Block(dim=dims[3], drop_path=dp_rates[3]), - TransformerBlock( - inp=dims[3], - oup=dims[3], - img_size=[int(self.img_size / 64), int(self.img_size / 64)], + nn.Sequential( + nn.Conv2d(dims[1], dims[2], kernel_size=2, stride=2), + Block(dim=dims[2], drop_path=dp_rates[4]), + Block(dim=dims[2], drop_path=dp_rates[5]), + Block(dim=dims[2], drop_path=dp_rates[6]), + Block(dim=dims[2], drop_path=dp_rates[7]), + Block(dim=dims[2], drop_path=dp_rates[8]), + Block(dim=dims[2], drop_path=dp_rates[9]), + 
TransformerBlock(inp=dims[2], oup=dims[2]), + TransformerBlock(inp=dims[2], oup=dims[2]), + LayerNorm2d(dims[2], eps=1e-6), ), - TransformerBlock( - inp=dims[3], - oup=dims[3], - img_size=[int(self.img_size / 64), int(self.img_size / 64)], + nn.Sequential( + nn.Conv2d(dims[2], dims[3], kernel_size=2, stride=2), + Block(dim=dims[3], drop_path=dp_rates[10]), + Block(dim=dims[3], drop_path=dp_rates[11]), + Block(dim=dims[3], drop_path=dp_rates[12]), + Block(dim=dims[3], drop_path=dp_rates[13]), + TransformerBlock(inp=dims[3], oup=dims[3]), + TransformerBlock(inp=dims[3], oup=dims[3]), ), ) self.norm = nn.LayerNorm(dims[-1], eps=1e-6) - self.head = nn.Linear(dims[-1], num_classes) if num_classes > 0 else nn.Identity() self.apply(self._init_weights) - def _init_weights(self, m): + def _init_weights(self, m: nn.Module) -> None: if isinstance(m, (nn.Conv2d, nn.Linear)): - trunc_normal_(m.weight, std=.02) + trunc_normal_(m.weight, std=0.02) if m.bias is not None: nn.init.constant_(m.bias, 0) - def forward_features(self, x): - x = self.dct(x) - x = self.stages1(x) - x = self.stages2(x) - x = self.stages3(x) - x = self.stages4(x) - x = x.mean(dim=(-2, -1)) - x = self.norm(x) - return x - - def forward_head(self, x, pre_logits: bool = False): - if pre_logits: - return x - x = self.head(x) - return x - - def forward(self, x): - x = self.forward_features(x) - x = self.forward_head(x) - return x - - -class LayerNorm(nn.Module): - def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): - super().__init__() - self.weight = nn.Parameter(torch.ones(normalized_shape)) - self.bias = nn.Parameter(torch.zeros(normalized_shape)) - self.eps = eps - self.data_format = data_format - if self.data_format not in ["channels_last", "channels_first"]: - raise NotImplementedError - self.normalized_shape = (normalized_shape,) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.data_format == "channels_last": - return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) - else: #elif self.data_format == "channels_first": - u = x.mean(1, keepdim=True) - s = (x - u).pow(2).mean(1, keepdim=True) - x = (x - u) / torch.sqrt(s + self.eps) - x = self.weight[:, None, None] * x + self.bias[:, None, None] - return x - - -class GRN(nn.Module): - def __init__(self, dim): - super().__init__() - self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) - self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) - - def forward(self, x): - Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) - Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) - return self.gamma * (x * Nx) + self.beta + x + @torch.jit.ignore + def get_classifier(self) -> nn.Module: + return self.head + def reset_classifier(self, num_classes: int, global_pool: str = 'avg') -> None: + self.num_classes = num_classes + self.global_pool = global_pool + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() -def drop_path(x, drop_prob: float = 0., training: bool = False): - if drop_prob == 0. 
or not training: + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + x = self.stem_dct(x) + x = self.stages(x) return x - keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) - random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) - random_tensor.floor_() - output = x.div(keep_prob) * random_tensor - return output - - -class DropPath(nn.Module): - def __init__(self, drop_prob=None): - super(DropPath, self).__init__() - self.drop_prob = drop_prob - - def forward(self, x): - return drop_path(x, self.drop_prob, self.training) - - -class FeedForward(nn.Module): - def __init__(self, dim, hidden_dim, dropout=0.): - super().__init__() - self.net = nn.Sequential( - nn.Linear(dim, hidden_dim), - nn.GELU(), - nn.Dropout(dropout), - nn.Linear(hidden_dim, dim), - nn.Dropout(dropout) - ) - - def forward(self, x): - return self.net(x) - - -class PosCNN(nn.Module): - def __init__(self, in_chans, img_size=None): - super(PosCNN, self).__init__() - self.proj = nn.Conv2d(in_chans, in_chans, kernel_size=3, stride=1, padding=1, bias=True, groups=in_chans) - self.img_size = img_size - - # ignore JIT + safety variable type change - @torch.jit.ignore - def _get_dynamic_size(self, N): - # 1. Eager Mode (normal execution) - if isinstance(N, int): - s = int(N ** 0.5) - return s, s - - # 2. FX Tracing & Runtime - # if n is Proxy or Tensor, change to float Tensor - N_float = N * torch.tensor(1.0) # float promotion (Proxy 호환) - s = (N_float ** 0.5).to(torch.int) - return s, s - - def forward(self, x): - B, N, C = x.shape - feat_token = x - # JIT mode vs others - if torch.jit.is_scripting(): - H = int(N ** 0.5) - W = H + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """Forward pass returning intermediate features. + + Args: + x: Input image tensor. + indices: Indices of features to return (0=stem_dct, 1-4=stages). None returns all. + norm: Apply norm layer to final intermediate (unused, for API compat). + stop_early: Stop iterating when last desired intermediate is reached. + output_fmt: Output format, must be 'NCHW'. + intermediates_only: Only return intermediate features. + + Returns: + List of intermediate features or tuple of (final features, intermediates). + """ + assert output_fmt == 'NCHW', 'Output format must be NCHW.' 
+ intermediates = [] + # 5 feature levels: stem_dct (0) + stages 0-3 (1-4) + take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices) + + x = self.stem_dct(x) + if 0 in take_indices: + intermediates.append(x) + + if torch.jit.is_scripting() or not stop_early: + stages = self.stages else: - H, W = self._get_dynamic_size(N) + # max_index is 0-4, stages are 1-4, so we need max_index stages + stages = self.stages[:max_index] if max_index > 0 else [] - cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W) - x = self.proj(cnn_feat) + cnn_feat - x = x.flatten(2).transpose(1, 2) - return x + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx + 1 in take_indices: # +1 because stem is index 0 + intermediates.append(x) -def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): - return _no_grad_trunc_normal_(tensor, mean, std, a, b) + if intermediates_only: + return intermediates + return x, intermediates -def _no_grad_trunc_normal_(tensor, mean, std, a, b): - def norm_cdf(x): - return (1. + math.erf(x / math.sqrt(2.))) / 2. + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ) -> List[int]: + """Prune layers not required for specified intermediates. + + Args: + indices: Indices of intermediate layers to keep (0=stem_dct, 1-4=stages). + prune_norm: Whether to prune the final norm layer. + prune_head: Whether to prune the classifier head. + + Returns: + List of indices that were kept. + """ + # 5 feature levels: stem_dct (0) + stages 0-3 (1-4) + take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices) + # max_index is 0-4, stages are 1-4, so we keep max_index stages + self.stages = self.stages[:max_index] if max_index > 0 else nn.Sequential() + + if prune_norm: + self.norm = nn.Identity() + if prune_head: + self.reset_classifier(0, '') + + return take_indices + + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + x = x.mean(dim=(-2, -1)) + x = self.norm(x) + if pre_logits: + return x + return self.head(x) - if (mean < a - 2 * std) or (mean > b + 2 * std): - warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_.", stacklevel=2) - with torch.no_grad(): - l = norm_cdf((a - mean) / std) - u = norm_cdf((b - mean) / std) - tensor.uniform_(2 * l - 1, 2 * u - 1) - tensor.erfinv_() - tensor.mul_(std * math.sqrt(2.)) - tensor.add_(mean) - tensor.clamp_(min=a, max=b) - return tensor + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.forward_features(x) + return self.forward_head(x) default_cfgs = generate_default_cfgs({ @@ -635,11 +656,92 @@ def norm_cdf(x): }, }) + +def checkpoint_filter_fn(state_dict: dict, model: nn.Module) -> dict: + """Remap original CSATv2 checkpoint to timm format. + + Handles two key structural changes: + 1. Stage naming: stages1/2/3/4 -> stages.0/1/2/3 + 2. 
Downsample position: moved from end of stage N to start of stage N+1 + """ + if 'grn.weight' in state_dict or 'stages.0.0.grn.weight' in state_dict: + return state_dict # Already in timm format + + import re + + # Downsample indices in original checkpoint (Conv2d at end of each stage) + # These move to index 0 of the next stage + downsample_idx = {1: 3, 2: 3, 3: 9} # stage -> downsample index + + def remap_stage(m): + stage = int(m.group(1)) + idx = int(m.group(2)) + rest = m.group(3) + if stage in downsample_idx and idx == downsample_idx[stage]: + # Downsample moves to start of next stage + return f'stages.{stage}.0.{rest}' + elif stage == 1: + # Stage 1 -> stages.0, indices unchanged + return f'stages.0.{idx}.{rest}' + else: + # Stages 2-4 -> stages.1-3, indices shift +1 (after downsample) + return f'stages.{stage - 1}.{idx + 1}.{rest}' + + out_dict = {} + for k, v in state_dict.items(): + # Remap dct -> stem_dct + k = re.sub(r'^dct\.', 'stem_dct.', k) + + # Remap stage names with index adjustments for downsample relocation + k = re.sub(r'^stages([1-4])\.(\d+)\.(.*)$', remap_stage, k) + + # Remap GRN: gamma/beta -> weight/bias with reshape + if 'grn.gamma' in k: + k = k.replace('grn.gamma', 'grn.weight') + v = v.reshape(-1) + elif 'grn.beta' in k: + k = k.replace('grn.beta', 'grn.bias') + v = v.reshape(-1) + + # Remap FeedForward (nn.Sequential) to Mlp: net.0 -> fc1, net.3 -> fc2 + # Also rename ff -> mlp, ff_norm -> norm2, attn_norm -> norm1 + if '.ff.net.0.' in k: + k = k.replace('.ff.net.0.', '.mlp.fc1.') + elif '.ff.net.3.' in k: + k = k.replace('.ff.net.3.', '.mlp.fc2.') + elif '.ff_norm.' in k: + k = k.replace('.ff_norm.', '.norm2.') + elif '.attn_norm.' in k: + k = k.replace('.attn_norm.', '.norm1.') + + # SpatialTransformerBlock: flatten .attention.attention.attn. -> .attention.attention. + # and remap to_qkv -> qkv + if '.attention.attention.attn.' in k: + k = k.replace('.attention.attention.attn.to_qkv.', '.attention.attention.qkv.') + k = k.replace('.attention.attention.attn.', '.attention.attention.') + + # TransformerBlock: remap attention layer names + # to_qkv -> qkv, to_out.0 -> proj, attn.pos_embed -> pos_embed + if '.attn.to_qkv.' in k: + k = k.replace('.attn.to_qkv.', '.attn.qkv.') + elif '.attn.to_out.0.' in k: + k = k.replace('.attn.to_out.0.', '.attn.proj.') + + if '.attn.pos_embed.' in k: + k = k.replace('.attn.pos_embed.', '.pos_embed.') + + out_dict[k] = v + + return out_dict + + def _create_csatv2(variant: str, pretrained: bool = False, **kwargs) -> CSATv2: return build_model_with_cfg( CSATv2, variant, pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(out_indices=(0, 1, 2, 3, 4), flatten_sequential=True), default_cfg=default_cfgs[variant], **kwargs, ) @@ -647,8 +749,4 @@ def _create_csatv2(variant: str, pretrained: bool = False, **kwargs) -> CSATv2: @register_model def csatv2(pretrained: bool = False, **kwargs) -> CSATv2: - model_args = dict( - img_size=kwargs.pop('img_size', 512), - num_classes=kwargs.pop('num_classes', 1000), - ) - return _create_csatv2('csatv2', pretrained, **dict(model_args, **kwargs)) + return _create_csatv2('csatv2', pretrained, **kwargs) From b0cb7448d324d34d28f6d9e64a52826be8e54841 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 9 Dec 2025 15:34:53 -0800 Subject: [PATCH 10/14] Add grad_checkpointing. Remap head to use NormMlpClassifierHead and final norm is 2d so we can disable pooling if desired. 
Still inconsistent line endings --- timm/models/csatv2.py | 1516 +++++++++++++++++++++-------------------- 1 file changed, 764 insertions(+), 752 deletions(-) diff --git a/timm/models/csatv2.py b/timm/models/csatv2.py index c9768969f7..1d29c8b21c 100644 --- a/timm/models/csatv2.py +++ b/timm/models/csatv2.py @@ -1,752 +1,764 @@ -"""CSATv2 - -A frequency-domain vision model using DCT transforms with spatial attention. - -Paper: TBD -""" -import math -import warnings -from typing import List, Optional, Tuple, Union - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F - -from timm.layers import trunc_normal_, DropPath, Mlp, LayerNorm2d, Attention -from timm.layers.grn import GlobalResponseNorm -from timm.models._builder import build_model_with_cfg -from timm.models._features import feature_take_indices -from ._registry import register_model, generate_default_cfgs - -__all__ = ['CSATv2', 'csatv2'] - -# DCT frequency normalization statistics (Y, Cb, Cr channels x 64 coefficients) -_DCT_MEAN = [ - [932.42657, -0.00260, 0.33415, -0.02840, 0.00003, -0.02792, -0.00183, 0.00006, - 0.00032, 0.03402, -0.00571, 0.00020, 0.00006, -0.00038, -0.00558, -0.00116, - -0.00000, -0.00047, -0.00008, -0.00030, 0.00942, 0.00161, -0.00009, -0.00006, - -0.00014, -0.00035, 0.00001, -0.00220, 0.00033, -0.00002, -0.00003, -0.00020, - 0.00007, -0.00000, 0.00005, 0.00293, -0.00004, 0.00006, 0.00019, 0.00004, - 0.00006, -0.00015, -0.00002, 0.00007, 0.00010, -0.00004, 0.00008, 0.00000, - 0.00008, -0.00001, 0.00015, 0.00002, 0.00007, 0.00003, 0.00004, -0.00001, - 0.00004, -0.00000, 0.00002, -0.00000, -0.00008, -0.00000, -0.00003, 0.00003], - [962.34735, -0.00428, 0.09835, 0.00152, -0.00009, 0.00312, -0.00141, -0.00001, - -0.00013, 0.01050, 0.00065, 0.00006, -0.00000, 0.00003, 0.00264, 0.00000, - 0.00001, 0.00007, -0.00006, 0.00003, 0.00341, 0.00163, 0.00004, 0.00003, - -0.00001, 0.00008, -0.00000, 0.00090, 0.00018, -0.00006, -0.00001, 0.00007, - -0.00003, -0.00001, 0.00006, 0.00084, -0.00000, -0.00001, 0.00000, 0.00004, - -0.00001, -0.00002, 0.00000, 0.00001, 0.00002, 0.00001, 0.00004, 0.00011, - 0.00000, -0.00003, 0.00011, -0.00002, 0.00001, 0.00001, 0.00001, 0.00001, - -0.00007, -0.00003, 0.00001, 0.00000, 0.00001, 0.00002, 0.00001, 0.00000], - [1053.16101, -0.00213, -0.09207, 0.00186, 0.00013, 0.00034, -0.00119, 0.00002, - 0.00011, -0.00984, 0.00046, -0.00007, -0.00001, -0.00005, 0.00180, 0.00042, - 0.00002, -0.00010, 0.00004, 0.00003, -0.00301, 0.00125, -0.00002, -0.00003, - -0.00001, -0.00001, -0.00001, 0.00056, 0.00021, 0.00001, -0.00001, 0.00002, - -0.00001, -0.00001, 0.00005, -0.00070, -0.00002, -0.00002, 0.00005, -0.00004, - -0.00000, 0.00002, -0.00002, 0.00001, 0.00000, -0.00003, 0.00004, 0.00007, - 0.00001, 0.00000, 0.00013, -0.00000, 0.00000, 0.00002, -0.00000, -0.00001, - -0.00004, -0.00003, 0.00000, 0.00001, -0.00001, 0.00001, -0.00000, 0.00000], -] - -_DCT_VAR = [ - [270372.37500, 6287.10645, 5974.94043, 1653.10889, 1463.91748, 1832.58997, 755.92468, 692.41528, - 648.57184, 641.46881, 285.79288, 301.62100, 380.43405, 349.84027, 374.15891, 190.30960, - 190.76746, 221.64578, 200.82646, 145.87979, 126.92046, 62.14622, 67.75562, 102.42001, - 129.74922, 130.04631, 103.12189, 97.76417, 53.17402, 54.81048, 73.48712, 81.04342, - 69.35100, 49.06024, 33.96053, 37.03279, 20.48858, 24.94830, 33.90822, 44.54912, - 47.56363, 40.03160, 30.43313, 22.63899, 26.53739, 26.57114, 21.84404, 17.41557, - 15.18253, 10.69678, 11.24111, 12.97229, 15.08971, 15.31646, 8.90409, 7.44213, - 6.66096, 
6.97719, 4.17834, 3.83882, 4.51073, 2.36646, 2.41363, 1.48266], - [18839.21094, 321.70932, 300.15259, 77.47830, 76.02293, 89.04748, 33.99642, 34.74807, - 32.12333, 28.19588, 12.04675, 14.26871, 18.45779, 16.59588, 15.67892, 7.37718, - 8.56312, 10.28946, 9.41013, 6.69090, 5.16453, 2.55186, 3.03073, 4.66765, - 5.85418, 5.74644, 4.33702, 3.66948, 1.95107, 2.26034, 3.06380, 3.50705, - 3.06359, 2.19284, 1.54454, 1.57860, 0.97078, 1.13941, 1.48653, 1.89996, - 1.95544, 1.64950, 1.24754, 0.93677, 1.09267, 1.09516, 0.94163, 0.78966, - 0.72489, 0.50841, 0.50909, 0.55664, 0.63111, 0.64125, 0.38847, 0.33378, - 0.30918, 0.33463, 0.20875, 0.19298, 0.21903, 0.13380, 0.13444, 0.09554], - [17127.39844, 292.81421, 271.45209, 66.64056, 63.60253, 76.35437, 28.06587, 27.84831, - 25.96656, 23.60370, 9.99173, 11.34992, 14.46955, 12.92553, 12.69353, 5.91537, - 6.60187, 7.90891, 7.32825, 5.32785, 4.29660, 2.13459, 2.44135, 3.66021, - 4.50335, 4.38959, 3.34888, 2.97181, 1.60633, 1.77010, 2.35118, 2.69018, - 2.38189, 1.74596, 1.26014, 1.31684, 0.79327, 0.92046, 1.17670, 1.47609, - 1.50914, 1.28725, 0.99898, 0.74832, 0.85736, 0.85800, 0.74663, 0.63508, - 0.58748, 0.41098, 0.41121, 0.44663, 0.50277, 0.51519, 0.31729, 0.27336, - 0.25399, 0.27241, 0.17353, 0.16255, 0.18440, 0.11602, 0.11511, 0.08450], -] - - -def _zigzag_permutation(rows: int, cols: int) -> List[int]: - """Generate zigzag scan order for DCT coefficients.""" - idx_matrix = np.arange(0, rows * cols, 1).reshape(rows, cols).tolist() - dia = [[] for _ in range(rows + cols - 1)] - zigzag = [] - for i in range(rows): - for j in range(cols): - s = i + j - if s % 2 == 0: - dia[s].insert(0, idx_matrix[i][j]) - else: - dia[s].append(idx_matrix[i][j]) - for d in dia: - zigzag.extend(d) - return zigzag - - -def _dct_kernel_type_2( - kernel_size: int, - orthonormal: bool, - device=None, - dtype=None, -) -> torch.Tensor: - """Generate Type-II DCT kernel matrix.""" - factory_kwargs = dict(device=device, dtype=dtype) - x = torch.eye(kernel_size, **factory_kwargs) - v = x.clone().contiguous().view(-1, kernel_size) - v = torch.cat([v, v.flip([1])], dim=-1) - v = torch.fft.fft(v, dim=-1)[:, :kernel_size] - try: - k = torch.tensor(-1j, **factory_kwargs) * torch.pi * torch.arange(kernel_size, **factory_kwargs)[None, :] - except AttributeError: - k = torch.tensor(-1j, **factory_kwargs) * math.pi * torch.arange(kernel_size, **factory_kwargs)[None, :] - k = torch.exp(k / (kernel_size * 2)) - v = v * k - v = v.real - if orthonormal: - v[:, 0] = v[:, 0] * torch.sqrt(torch.tensor(1 / (kernel_size * 4), **factory_kwargs)) - v[:, 1:] = v[:, 1:] * torch.sqrt(torch.tensor(1 / (kernel_size * 2), **factory_kwargs)) - v = v.contiguous().view(*x.shape) - return v - - -def _dct_kernel_type_3( - kernel_size: int, - orthonormal: bool, - device=None, - dtype=None, -) -> torch.Tensor: - """Generate Type-III DCT kernel matrix (inverse of Type-II).""" - return torch.linalg.inv(_dct_kernel_type_2(kernel_size, orthonormal, device, dtype)) - - -class Dct1d(nn.Module): - """1D Discrete Cosine Transform layer.""" - - def __init__( - self, - kernel_size: int, - kernel_type: int = 2, - orthonormal: bool = True, - device=None, - dtype=None, - ) -> None: - factory_kwargs = dict(device=device, dtype=dtype) - super().__init__() - kernel = {'2': _dct_kernel_type_2, '3': _dct_kernel_type_3} - dct_weights = kernel[f'{kernel_type}'](kernel_size, orthonormal, **factory_kwargs).T - self.register_buffer('weights', dct_weights) - self.register_parameter('bias', None) - - def forward(self, x: torch.Tensor) -> 
torch.Tensor: - return F.linear(x, self.weights, self.bias) - - -class Dct2d(nn.Module): - """2D Discrete Cosine Transform layer.""" - - def __init__( - self, - kernel_size: int, - kernel_type: int = 2, - orthonormal: bool = True, - device=None, - dtype=None, - ) -> None: - factory_kwargs = dict(device=device, dtype=dtype) - super().__init__() - self.transform = Dct1d(kernel_size, kernel_type, orthonormal, **factory_kwargs) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.transform(self.transform(x).transpose(-1, -2)).transpose(-1, -2) - - -class LearnableDct2d(nn.Module): - """Learnable 2D DCT stem with RGB to YCbCr conversion and frequency selection.""" - - def __init__( - self, - kernel_size: int, - kernel_type: int = 2, - orthonormal: bool = True, - device=None, - dtype=None, - ) -> None: - factory_kwargs = dict(device=device, dtype=dtype) - super().__init__() - self.k = kernel_size - self.unfold = nn.Unfold(kernel_size=(kernel_size, kernel_size), stride=(kernel_size, kernel_size)) - self.transform = Dct2d(kernel_size, kernel_type, orthonormal, **factory_kwargs) - self.permutation = _zigzag_permutation(kernel_size, kernel_size) - self.Y_Conv = nn.Conv2d(kernel_size ** 2, 24, kernel_size=1, padding=0) - self.Cb_Conv = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0) - self.Cr_Conv = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0) - - self.register_buffer('mean', torch.tensor(_DCT_MEAN), persistent=False) - self.register_buffer('var', torch.tensor(_DCT_VAR), persistent=False) - self.register_buffer('imagenet_mean', torch.tensor([0.485, 0.456, 0.406]), persistent=False) - self.register_buffer('imagenet_std', torch.tensor([0.229, 0.224, 0.225]), persistent=False) - - def _denormalize(self, x: torch.Tensor) -> torch.Tensor: - """Convert from ImageNet normalized to [0, 255] range.""" - return x.mul(self.imagenet_std).add_(self.imagenet_mean) * 255 - - def _rgb_to_ycbcr(self, x: torch.Tensor) -> torch.Tensor: - """Convert RGB to YCbCr color space.""" - y = (x[:, :, :, 0] * 0.299) + (x[:, :, :, 1] * 0.587) + (x[:, :, :, 2] * 0.114) - cb = 0.564 * (x[:, :, :, 2] - y) + 128 - cr = 0.713 * (x[:, :, :, 0] - y) + 128 - return torch.stack([y, cb, cr], dim=-1) - - def _frequency_normalize(self, x: torch.Tensor) -> torch.Tensor: - """Normalize DCT coefficients using precomputed statistics.""" - std = self.var ** 0.5 + 1e-8 - return (x - self.mean) / std - - def forward(self, x: torch.Tensor) -> torch.Tensor: - b, c, h, w = x.shape - x = x.permute(0, 2, 3, 1) - x = self._denormalize(x) - x = self._rgb_to_ycbcr(x) - x = x.permute(0, 3, 1, 2) - x = self.unfold(x).transpose(-1, -2) - x = x.reshape(b, h // self.k, w // self.k, c, self.k, self.k) - x = self.transform(x) - x = x.reshape(-1, c, self.k * self.k) - x = x[:, :, self.permutation] - x = self._frequency_normalize(x) - x = x.reshape(b, h // self.k, w // self.k, c, -1) - x = x.permute(0, 3, 4, 1, 2).contiguous() - x_y = self.Y_Conv(x[:, 0]) - x_cb = self.Cb_Conv(x[:, 1]) - x_cr = self.Cr_Conv(x[:, 2]) - return torch.cat([x_y, x_cb, x_cr], dim=1) - - -class Dct2dStats(nn.Module): - """Utility module to compute DCT coefficient statistics.""" - - def __init__( - self, - kernel_size: int, - kernel_type: int = 2, - orthonormal: bool = True, - device=None, - dtype=None, - ) -> None: - factory_kwargs = dict(device=device, dtype=dtype) - super().__init__() - self.k = kernel_size - self.unfold = nn.Unfold(kernel_size=(kernel_size, kernel_size), stride=(kernel_size, kernel_size)) - self.transform = Dct2d(kernel_size, 
kernel_type, orthonormal, **factory_kwargs) - self.permutation = _zigzag_permutation(kernel_size, kernel_size) - - def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - b, c, h, w = x.shape - x = self.unfold(x).transpose(-1, -2) - x = x.reshape(b, h // self.k, w // self.k, c, self.k, self.k) - x = self.transform(x) - x = x.reshape(-1, c, self.k * self.k) - x = x[:, :, self.permutation] - x = x.reshape(b * (h // self.k) * (w // self.k), c, -1) - - mean_list = torch.zeros([3, 64]) - var_list = torch.zeros([3, 64]) - for i in range(3): - mean_list[i] = torch.mean(x[:, i], dim=0) - var_list[i] = torch.var(x[:, i], dim=0) - return mean_list, var_list - - -class Block(nn.Module): - """ConvNeXt-style block with spatial attention.""" - - def __init__( - self, - dim: int, - drop_path: float = 0., - ) -> None: - super().__init__() - self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) - self.norm = nn.LayerNorm(dim, eps=1e-6) - self.pwconv1 = nn.Linear(dim, 4 * dim) - self.act = nn.GELU() - self.grn = GlobalResponseNorm(4 * dim, channels_last=True) - self.pwconv2 = nn.Linear(4 * dim, dim) - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.attention = SpatialAttention() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - shortcut = x - x = self.dwconv(x) - x = x.permute(0, 2, 3, 1) - x = self.norm(x) - x = self.pwconv1(x) - x = self.act(x) - x = self.grn(x) - x = self.pwconv2(x) - x = x.permute(0, 3, 1, 2) - - attention = self.attention(x) - up_attn = F.interpolate(attention, size=x.shape[2:], mode='bilinear', align_corners=True) - x = x * up_attn - - return shortcut + self.drop_path(x) - - -class SpatialTransformerBlock(nn.Module): - """Lightweight transformer block for spatial attention (1-channel, 7x7 grid). - - This is a simplified transformer with single-head, 1-dim attention over spatial - positions. Used inside SpatialAttention where input is 1 channel at 7x7 resolution. 
- """ - - def __init__(self) -> None: - super().__init__() - # Single-head attention with 1-dim q/k/v (no output projection needed) - self.pos_embed = PosConv(in_chans=1) - self.norm1 = nn.LayerNorm(1) - self.qkv = nn.Linear(1, 3, bias=False) - - # Feedforward: 1 -> 4 -> 1 - self.norm2 = nn.LayerNorm(1) - self.mlp = Mlp(1, 4, 1, act_layer=nn.GELU) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - B, C, H, W = x.shape - - # Attention block - shortcut = x - x_t = x.flatten(2).transpose(1, 2) # (B, N, 1) - x_t = self.norm1(x_t) - x_t = self.pos_embed(x_t, (H, W)) - - # Simple single-head attention with scalar q/k/v - qkv = self.qkv(x_t) # (B, N, 3) - q, k, v = qkv.unbind(-1) # each (B, N) - attn = (q @ k.transpose(-1, -2)).softmax(dim=-1) # (B, N, N) - x_t = (attn @ v).unsqueeze(-1) # (B, N, 1) - - x_t = x_t.transpose(1, 2).reshape(B, C, H, W) - x = shortcut + x_t - - # Feedforward block - shortcut = x - x_t = x.flatten(2).transpose(1, 2) - x_t = self.mlp(self.norm2(x_t)) - x_t = x_t.transpose(1, 2).reshape(B, C, H, W) - x = shortcut + x_t - - return x - - -class SpatialAttention(nn.Module): - """Spatial attention module using channel statistics and transformer.""" - - def __init__(self) -> None: - super().__init__() - self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) - self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3) - self.attention = SpatialTransformerBlock() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x_avg = x.mean(dim=1, keepdim=True) - x_max = x.amax(dim=1, keepdim=True) - x = torch.cat([x_avg, x_max], dim=1) - x = self.avgpool(x) - x = self.conv(x) - x = self.attention(x) - return x - - -class TransformerBlock(nn.Module): - """Transformer block with optional downsampling and convolutional position encoding.""" - - def __init__( - self, - inp: int, - oup: int, - num_heads: int = 8, - attn_head_dim: int = 32, - downsample: bool = False, - attn_drop: float = 0., - proj_drop: float = 0., - ) -> None: - super().__init__() - hidden_dim = int(inp * 4) - self.downsample = downsample - - if self.downsample: - self.pool1 = nn.MaxPool2d(3, 2, 1) - self.pool2 = nn.MaxPool2d(3, 2, 1) - self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False) - else: - self.pool1 = nn.Identity() - self.pool2 = nn.Identity() - self.proj = nn.Identity() - - self.pos_embed = PosConv(in_chans=inp) - self.norm1 = nn.LayerNorm(inp) - self.attn = Attention( - dim=inp, - num_heads=num_heads, - attn_head_dim=attn_head_dim, - dim_out=oup, - attn_drop=attn_drop, - proj_drop=proj_drop, - ) - - self.norm2 = nn.LayerNorm(oup) - self.mlp = Mlp(oup, hidden_dim, oup, act_layer=nn.GELU, drop=proj_drop) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.downsample: - shortcut = self.proj(self.pool1(x)) - x_t = self.pool2(x) - B, C, H, W = x_t.shape - x_t = x_t.flatten(2).transpose(1, 2) - x_t = self.norm1(x_t) - x_t = self.pos_embed(x_t, (H, W)) - x_t = self.attn(x_t) - x_t = x_t.transpose(1, 2).reshape(B, -1, H, W) - x = shortcut + x_t - else: - B, C, H, W = x.shape - shortcut = x - x_t = x.flatten(2).transpose(1, 2) - x_t = self.norm1(x_t) - x_t = self.pos_embed(x_t, (H, W)) - x_t = self.attn(x_t) - x_t = x_t.transpose(1, 2).reshape(B, -1, H, W) - x = shortcut + x_t - - # MLP block - B, C, H, W = x.shape - shortcut = x - x_t = x.flatten(2).transpose(1, 2) - x_t = self.mlp(self.norm2(x_t)) - x_t = x_t.transpose(1, 2).reshape(B, C, H, W) - x = shortcut + x_t - - return x - - -class PosConv(nn.Module): - """Convolutional position encoding.""" - - def __init__( - self, - in_chans: int, - ) -> None: - 
super().__init__() - self.proj = nn.Conv2d(in_chans, in_chans, kernel_size=3, stride=1, padding=1, bias=True, groups=in_chans) - - def forward(self, x: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor: - B, N, C = x.shape - H, W = size - cnn_feat = x.transpose(1, 2).view(B, C, H, W) - x = self.proj(cnn_feat) + cnn_feat - return x.flatten(2).transpose(1, 2) - - -class CSATv2(nn.Module): - """CSATv2: Frequency-domain vision model with spatial attention. - - A hybrid architecture that processes images in the DCT frequency domain - with ConvNeXt-style blocks and transformer attention. - """ - - def __init__( - self, - num_classes: int = 1000, - in_chans: int = 3, - drop_path_rate: float = 0.0, - global_pool: str = 'avg', - **kwargs, - ) -> None: - super().__init__() - self.num_classes = num_classes - self.global_pool = global_pool - - dims = [32, 72, 168, 386] - self.num_features = dims[-1] - self.head_hidden_size = self.num_features - - self.feature_info = [ - dict(num_chs=dims[0], reduction=8, module='stem_dct'), - dict(num_chs=dims[0], reduction=8, module='stages.0'), - dict(num_chs=dims[1], reduction=16, module='stages.1'), - dict(num_chs=dims[2], reduction=32, module='stages.2'), - dict(num_chs=dims[3], reduction=64, module='stages.3'), - ] - - depths = [2, 2, 6, 4] - dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] - - self.stem_dct = LearnableDct2d(8) - - self.stages = nn.Sequential( - nn.Sequential( - Block(dim=dims[0], drop_path=dp_rates[0]), - Block(dim=dims[0], drop_path=dp_rates[1]), - LayerNorm2d(dims[0], eps=1e-6), - ), - nn.Sequential( - nn.Conv2d(dims[0], dims[1], kernel_size=2, stride=2), - Block(dim=dims[1], drop_path=dp_rates[2]), - Block(dim=dims[1], drop_path=dp_rates[3]), - LayerNorm2d(dims[1], eps=1e-6), - ), - nn.Sequential( - nn.Conv2d(dims[1], dims[2], kernel_size=2, stride=2), - Block(dim=dims[2], drop_path=dp_rates[4]), - Block(dim=dims[2], drop_path=dp_rates[5]), - Block(dim=dims[2], drop_path=dp_rates[6]), - Block(dim=dims[2], drop_path=dp_rates[7]), - Block(dim=dims[2], drop_path=dp_rates[8]), - Block(dim=dims[2], drop_path=dp_rates[9]), - TransformerBlock(inp=dims[2], oup=dims[2]), - TransformerBlock(inp=dims[2], oup=dims[2]), - LayerNorm2d(dims[2], eps=1e-6), - ), - nn.Sequential( - nn.Conv2d(dims[2], dims[3], kernel_size=2, stride=2), - Block(dim=dims[3], drop_path=dp_rates[10]), - Block(dim=dims[3], drop_path=dp_rates[11]), - Block(dim=dims[3], drop_path=dp_rates[12]), - Block(dim=dims[3], drop_path=dp_rates[13]), - TransformerBlock(inp=dims[3], oup=dims[3]), - TransformerBlock(inp=dims[3], oup=dims[3]), - ), - ) - - self.norm = nn.LayerNorm(dims[-1], eps=1e-6) - self.head = nn.Linear(dims[-1], num_classes) if num_classes > 0 else nn.Identity() - - self.apply(self._init_weights) - - def _init_weights(self, m: nn.Module) -> None: - if isinstance(m, (nn.Conv2d, nn.Linear)): - trunc_normal_(m.weight, std=0.02) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - - @torch.jit.ignore - def get_classifier(self) -> nn.Module: - return self.head - - def reset_classifier(self, num_classes: int, global_pool: str = 'avg') -> None: - self.num_classes = num_classes - self.global_pool = global_pool - self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() - - def forward_features(self, x: torch.Tensor) -> torch.Tensor: - x = self.stem_dct(x) - x = self.stages(x) - return x - - def forward_intermediates( - self, - x: torch.Tensor, - indices: Optional[Union[int, List[int]]] = None, - norm: bool = 
False, - stop_early: bool = False, - output_fmt: str = 'NCHW', - intermediates_only: bool = False, - ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: - """Forward pass returning intermediate features. - - Args: - x: Input image tensor. - indices: Indices of features to return (0=stem_dct, 1-4=stages). None returns all. - norm: Apply norm layer to final intermediate (unused, for API compat). - stop_early: Stop iterating when last desired intermediate is reached. - output_fmt: Output format, must be 'NCHW'. - intermediates_only: Only return intermediate features. - - Returns: - List of intermediate features or tuple of (final features, intermediates). - """ - assert output_fmt == 'NCHW', 'Output format must be NCHW.' - intermediates = [] - # 5 feature levels: stem_dct (0) + stages 0-3 (1-4) - take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices) - - x = self.stem_dct(x) - if 0 in take_indices: - intermediates.append(x) - - if torch.jit.is_scripting() or not stop_early: - stages = self.stages - else: - # max_index is 0-4, stages are 1-4, so we need max_index stages - stages = self.stages[:max_index] if max_index > 0 else [] - - for feat_idx, stage in enumerate(stages): - x = stage(x) - if feat_idx + 1 in take_indices: # +1 because stem is index 0 - intermediates.append(x) - - if intermediates_only: - return intermediates - - return x, intermediates - - def prune_intermediate_layers( - self, - indices: Union[int, List[int]] = 1, - prune_norm: bool = False, - prune_head: bool = True, - ) -> List[int]: - """Prune layers not required for specified intermediates. - - Args: - indices: Indices of intermediate layers to keep (0=stem_dct, 1-4=stages). - prune_norm: Whether to prune the final norm layer. - prune_head: Whether to prune the classifier head. - - Returns: - List of indices that were kept. - """ - # 5 feature levels: stem_dct (0) + stages 0-3 (1-4) - take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices) - # max_index is 0-4, stages are 1-4, so we keep max_index stages - self.stages = self.stages[:max_index] if max_index > 0 else nn.Sequential() - - if prune_norm: - self.norm = nn.Identity() - if prune_head: - self.reset_classifier(0, '') - - return take_indices - - def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: - x = x.mean(dim=(-2, -1)) - x = self.norm(x) - if pre_logits: - return x - return self.head(x) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.forward_features(x) - return self.forward_head(x) - - -default_cfgs = generate_default_cfgs({ - 'csatv2': { - 'url': 'https://huggingface.co/Hyunil/CSATv2/resolve/main/CSATv2_ImageNet_timm.pth', - 'num_classes': 1000, - 'input_size': (3, 512, 512), - 'mean': (0.485, 0.456, 0.406), - 'std': (0.229, 0.224, 0.225), - 'interpolation': 'bilinear', - 'crop_pct': 1.0, - 'classifier': 'head', - 'first_conv': [], - }, -}) - - -def checkpoint_filter_fn(state_dict: dict, model: nn.Module) -> dict: - """Remap original CSATv2 checkpoint to timm format. - - Handles two key structural changes: - 1. Stage naming: stages1/2/3/4 -> stages.0/1/2/3 - 2. 
Downsample position: moved from end of stage N to start of stage N+1 - """ - if 'grn.weight' in state_dict or 'stages.0.0.grn.weight' in state_dict: - return state_dict # Already in timm format - - import re - - # Downsample indices in original checkpoint (Conv2d at end of each stage) - # These move to index 0 of the next stage - downsample_idx = {1: 3, 2: 3, 3: 9} # stage -> downsample index - - def remap_stage(m): - stage = int(m.group(1)) - idx = int(m.group(2)) - rest = m.group(3) - if stage in downsample_idx and idx == downsample_idx[stage]: - # Downsample moves to start of next stage - return f'stages.{stage}.0.{rest}' - elif stage == 1: - # Stage 1 -> stages.0, indices unchanged - return f'stages.0.{idx}.{rest}' - else: - # Stages 2-4 -> stages.1-3, indices shift +1 (after downsample) - return f'stages.{stage - 1}.{idx + 1}.{rest}' - - out_dict = {} - for k, v in state_dict.items(): - # Remap dct -> stem_dct - k = re.sub(r'^dct\.', 'stem_dct.', k) - - # Remap stage names with index adjustments for downsample relocation - k = re.sub(r'^stages([1-4])\.(\d+)\.(.*)$', remap_stage, k) - - # Remap GRN: gamma/beta -> weight/bias with reshape - if 'grn.gamma' in k: - k = k.replace('grn.gamma', 'grn.weight') - v = v.reshape(-1) - elif 'grn.beta' in k: - k = k.replace('grn.beta', 'grn.bias') - v = v.reshape(-1) - - # Remap FeedForward (nn.Sequential) to Mlp: net.0 -> fc1, net.3 -> fc2 - # Also rename ff -> mlp, ff_norm -> norm2, attn_norm -> norm1 - if '.ff.net.0.' in k: - k = k.replace('.ff.net.0.', '.mlp.fc1.') - elif '.ff.net.3.' in k: - k = k.replace('.ff.net.3.', '.mlp.fc2.') - elif '.ff_norm.' in k: - k = k.replace('.ff_norm.', '.norm2.') - elif '.attn_norm.' in k: - k = k.replace('.attn_norm.', '.norm1.') - - # SpatialTransformerBlock: flatten .attention.attention.attn. -> .attention.attention. - # and remap to_qkv -> qkv - if '.attention.attention.attn.' in k: - k = k.replace('.attention.attention.attn.to_qkv.', '.attention.attention.qkv.') - k = k.replace('.attention.attention.attn.', '.attention.attention.') - - # TransformerBlock: remap attention layer names - # to_qkv -> qkv, to_out.0 -> proj, attn.pos_embed -> pos_embed - if '.attn.to_qkv.' in k: - k = k.replace('.attn.to_qkv.', '.attn.qkv.') - elif '.attn.to_out.0.' in k: - k = k.replace('.attn.to_out.0.', '.attn.proj.') - - if '.attn.pos_embed.' in k: - k = k.replace('.attn.pos_embed.', '.pos_embed.') - - out_dict[k] = v - - return out_dict - - -def _create_csatv2(variant: str, pretrained: bool = False, **kwargs) -> CSATv2: - return build_model_with_cfg( - CSATv2, - variant, - pretrained, - pretrained_filter_fn=checkpoint_filter_fn, - feature_cfg=dict(out_indices=(0, 1, 2, 3, 4), flatten_sequential=True), - default_cfg=default_cfgs[variant], - **kwargs, - ) - - -@register_model -def csatv2(pretrained: bool = False, **kwargs) -> CSATv2: - return _create_csatv2('csatv2', pretrained, **kwargs) +"""CSATv2 + +A frequency-domain vision model using DCT transforms with spatial attention. 
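+
+Usage (minimal sketch):
+
+    import timm
+    import torch
+
+    model = timm.create_model('csatv2', pretrained=False)
+    logits = model(torch.randn(1, 3, 512, 512))  # default cfg: 512x512 input, 1000 classes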
+ +Paper: TBD +""" +import math +import warnings +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.layers import trunc_normal_, DropPath, Mlp, LayerNorm2d, Attention, NormMlpClassifierHead +from timm.layers.grn import GlobalResponseNorm +from timm.models._builder import build_model_with_cfg +from timm.models._features import feature_take_indices +from timm.models._manipulate import checkpoint, checkpoint_seq +from ._registry import register_model, generate_default_cfgs + +__all__ = ['CSATv2', 'csatv2'] + +# DCT frequency normalization statistics (Y, Cb, Cr channels x 64 coefficients) +_DCT_MEAN = [ + [932.42657, -0.00260, 0.33415, -0.02840, 0.00003, -0.02792, -0.00183, 0.00006, + 0.00032, 0.03402, -0.00571, 0.00020, 0.00006, -0.00038, -0.00558, -0.00116, + -0.00000, -0.00047, -0.00008, -0.00030, 0.00942, 0.00161, -0.00009, -0.00006, + -0.00014, -0.00035, 0.00001, -0.00220, 0.00033, -0.00002, -0.00003, -0.00020, + 0.00007, -0.00000, 0.00005, 0.00293, -0.00004, 0.00006, 0.00019, 0.00004, + 0.00006, -0.00015, -0.00002, 0.00007, 0.00010, -0.00004, 0.00008, 0.00000, + 0.00008, -0.00001, 0.00015, 0.00002, 0.00007, 0.00003, 0.00004, -0.00001, + 0.00004, -0.00000, 0.00002, -0.00000, -0.00008, -0.00000, -0.00003, 0.00003], + [962.34735, -0.00428, 0.09835, 0.00152, -0.00009, 0.00312, -0.00141, -0.00001, + -0.00013, 0.01050, 0.00065, 0.00006, -0.00000, 0.00003, 0.00264, 0.00000, + 0.00001, 0.00007, -0.00006, 0.00003, 0.00341, 0.00163, 0.00004, 0.00003, + -0.00001, 0.00008, -0.00000, 0.00090, 0.00018, -0.00006, -0.00001, 0.00007, + -0.00003, -0.00001, 0.00006, 0.00084, -0.00000, -0.00001, 0.00000, 0.00004, + -0.00001, -0.00002, 0.00000, 0.00001, 0.00002, 0.00001, 0.00004, 0.00011, + 0.00000, -0.00003, 0.00011, -0.00002, 0.00001, 0.00001, 0.00001, 0.00001, + -0.00007, -0.00003, 0.00001, 0.00000, 0.00001, 0.00002, 0.00001, 0.00000], + [1053.16101, -0.00213, -0.09207, 0.00186, 0.00013, 0.00034, -0.00119, 0.00002, + 0.00011, -0.00984, 0.00046, -0.00007, -0.00001, -0.00005, 0.00180, 0.00042, + 0.00002, -0.00010, 0.00004, 0.00003, -0.00301, 0.00125, -0.00002, -0.00003, + -0.00001, -0.00001, -0.00001, 0.00056, 0.00021, 0.00001, -0.00001, 0.00002, + -0.00001, -0.00001, 0.00005, -0.00070, -0.00002, -0.00002, 0.00005, -0.00004, + -0.00000, 0.00002, -0.00002, 0.00001, 0.00000, -0.00003, 0.00004, 0.00007, + 0.00001, 0.00000, 0.00013, -0.00000, 0.00000, 0.00002, -0.00000, -0.00001, + -0.00004, -0.00003, 0.00000, 0.00001, -0.00001, 0.00001, -0.00000, 0.00000], +] + +_DCT_VAR = [ + [270372.37500, 6287.10645, 5974.94043, 1653.10889, 1463.91748, 1832.58997, 755.92468, 692.41528, + 648.57184, 641.46881, 285.79288, 301.62100, 380.43405, 349.84027, 374.15891, 190.30960, + 190.76746, 221.64578, 200.82646, 145.87979, 126.92046, 62.14622, 67.75562, 102.42001, + 129.74922, 130.04631, 103.12189, 97.76417, 53.17402, 54.81048, 73.48712, 81.04342, + 69.35100, 49.06024, 33.96053, 37.03279, 20.48858, 24.94830, 33.90822, 44.54912, + 47.56363, 40.03160, 30.43313, 22.63899, 26.53739, 26.57114, 21.84404, 17.41557, + 15.18253, 10.69678, 11.24111, 12.97229, 15.08971, 15.31646, 8.90409, 7.44213, + 6.66096, 6.97719, 4.17834, 3.83882, 4.51073, 2.36646, 2.41363, 1.48266], + [18839.21094, 321.70932, 300.15259, 77.47830, 76.02293, 89.04748, 33.99642, 34.74807, + 32.12333, 28.19588, 12.04675, 14.26871, 18.45779, 16.59588, 15.67892, 7.37718, + 8.56312, 10.28946, 9.41013, 6.69090, 5.16453, 2.55186, 3.03073, 4.66765, + 5.85418, 5.74644, 
4.33702, 3.66948, 1.95107, 2.26034, 3.06380, 3.50705, + 3.06359, 2.19284, 1.54454, 1.57860, 0.97078, 1.13941, 1.48653, 1.89996, + 1.95544, 1.64950, 1.24754, 0.93677, 1.09267, 1.09516, 0.94163, 0.78966, + 0.72489, 0.50841, 0.50909, 0.55664, 0.63111, 0.64125, 0.38847, 0.33378, + 0.30918, 0.33463, 0.20875, 0.19298, 0.21903, 0.13380, 0.13444, 0.09554], + [17127.39844, 292.81421, 271.45209, 66.64056, 63.60253, 76.35437, 28.06587, 27.84831, + 25.96656, 23.60370, 9.99173, 11.34992, 14.46955, 12.92553, 12.69353, 5.91537, + 6.60187, 7.90891, 7.32825, 5.32785, 4.29660, 2.13459, 2.44135, 3.66021, + 4.50335, 4.38959, 3.34888, 2.97181, 1.60633, 1.77010, 2.35118, 2.69018, + 2.38189, 1.74596, 1.26014, 1.31684, 0.79327, 0.92046, 1.17670, 1.47609, + 1.50914, 1.28725, 0.99898, 0.74832, 0.85736, 0.85800, 0.74663, 0.63508, + 0.58748, 0.41098, 0.41121, 0.44663, 0.50277, 0.51519, 0.31729, 0.27336, + 0.25399, 0.27241, 0.17353, 0.16255, 0.18440, 0.11602, 0.11511, 0.08450], +] + + +def _zigzag_permutation(rows: int, cols: int) -> List[int]: + """Generate zigzag scan order for DCT coefficients.""" + idx_matrix = np.arange(0, rows * cols, 1).reshape(rows, cols).tolist() + dia = [[] for _ in range(rows + cols - 1)] + zigzag = [] + for i in range(rows): + for j in range(cols): + s = i + j + if s % 2 == 0: + dia[s].insert(0, idx_matrix[i][j]) + else: + dia[s].append(idx_matrix[i][j]) + for d in dia: + zigzag.extend(d) + return zigzag + + +def _dct_kernel_type_2( + kernel_size: int, + orthonormal: bool, + device=None, + dtype=None, +) -> torch.Tensor: + """Generate Type-II DCT kernel matrix.""" + factory_kwargs = dict(device=device, dtype=dtype) + x = torch.eye(kernel_size, **factory_kwargs) + v = x.clone().contiguous().view(-1, kernel_size) + v = torch.cat([v, v.flip([1])], dim=-1) + v = torch.fft.fft(v, dim=-1)[:, :kernel_size] + try: + k = torch.tensor(-1j, **factory_kwargs) * torch.pi * torch.arange(kernel_size, **factory_kwargs)[None, :] + except AttributeError: + k = torch.tensor(-1j, **factory_kwargs) * math.pi * torch.arange(kernel_size, **factory_kwargs)[None, :] + k = torch.exp(k / (kernel_size * 2)) + v = v * k + v = v.real + if orthonormal: + v[:, 0] = v[:, 0] * torch.sqrt(torch.tensor(1 / (kernel_size * 4), **factory_kwargs)) + v[:, 1:] = v[:, 1:] * torch.sqrt(torch.tensor(1 / (kernel_size * 2), **factory_kwargs)) + v = v.contiguous().view(*x.shape) + return v + + +def _dct_kernel_type_3( + kernel_size: int, + orthonormal: bool, + device=None, + dtype=None, +) -> torch.Tensor: + """Generate Type-III DCT kernel matrix (inverse of Type-II).""" + return torch.linalg.inv(_dct_kernel_type_2(kernel_size, orthonormal, device, dtype)) + + +class Dct1d(nn.Module): + """1D Discrete Cosine Transform layer.""" + + def __init__( + self, + kernel_size: int, + kernel_type: int = 2, + orthonormal: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = dict(device=device, dtype=dtype) + super().__init__() + kernel = {'2': _dct_kernel_type_2, '3': _dct_kernel_type_3} + dct_weights = kernel[f'{kernel_type}'](kernel_size, orthonormal, **factory_kwargs).T + self.register_buffer('weights', dct_weights) + self.register_parameter('bias', None) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return F.linear(x, self.weights, self.bias) + + +class Dct2d(nn.Module): + """2D Discrete Cosine Transform layer.""" + + def __init__( + self, + kernel_size: int, + kernel_type: int = 2, + orthonormal: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = dict(device=device, dtype=dtype) + 
super().__init__() + self.transform = Dct1d(kernel_size, kernel_type, orthonormal, **factory_kwargs) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.transform(self.transform(x).transpose(-1, -2)).transpose(-1, -2) + + +class LearnableDct2d(nn.Module): + """Learnable 2D DCT stem with RGB to YCbCr conversion and frequency selection.""" + + def __init__( + self, + kernel_size: int, + kernel_type: int = 2, + orthonormal: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = dict(device=device, dtype=dtype) + super().__init__() + self.k = kernel_size + self.unfold = nn.Unfold(kernel_size=(kernel_size, kernel_size), stride=(kernel_size, kernel_size)) + self.transform = Dct2d(kernel_size, kernel_type, orthonormal, **factory_kwargs) + self.permutation = _zigzag_permutation(kernel_size, kernel_size) + self.Y_Conv = nn.Conv2d(kernel_size ** 2, 24, kernel_size=1, padding=0) + self.Cb_Conv = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0) + self.Cr_Conv = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0) + + self.register_buffer('mean', torch.tensor(_DCT_MEAN), persistent=False) + self.register_buffer('var', torch.tensor(_DCT_VAR), persistent=False) + self.register_buffer('imagenet_mean', torch.tensor([0.485, 0.456, 0.406]), persistent=False) + self.register_buffer('imagenet_std', torch.tensor([0.229, 0.224, 0.225]), persistent=False) + + def _denormalize(self, x: torch.Tensor) -> torch.Tensor: + """Convert from ImageNet normalized to [0, 255] range.""" + return x.mul(self.imagenet_std).add_(self.imagenet_mean) * 255 + + def _rgb_to_ycbcr(self, x: torch.Tensor) -> torch.Tensor: + """Convert RGB to YCbCr color space.""" + y = (x[:, :, :, 0] * 0.299) + (x[:, :, :, 1] * 0.587) + (x[:, :, :, 2] * 0.114) + cb = 0.564 * (x[:, :, :, 2] - y) + 128 + cr = 0.713 * (x[:, :, :, 0] - y) + 128 + return torch.stack([y, cb, cr], dim=-1) + + def _frequency_normalize(self, x: torch.Tensor) -> torch.Tensor: + """Normalize DCT coefficients using precomputed statistics.""" + std = self.var ** 0.5 + 1e-8 + return (x - self.mean) / std + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, c, h, w = x.shape + x = x.permute(0, 2, 3, 1) + x = self._denormalize(x) + x = self._rgb_to_ycbcr(x) + x = x.permute(0, 3, 1, 2) + x = self.unfold(x).transpose(-1, -2) + x = x.reshape(b, h // self.k, w // self.k, c, self.k, self.k) + x = self.transform(x) + x = x.reshape(-1, c, self.k * self.k) + x = x[:, :, self.permutation] + x = self._frequency_normalize(x) + x = x.reshape(b, h // self.k, w // self.k, c, -1) + x = x.permute(0, 3, 4, 1, 2).contiguous() + x_y = self.Y_Conv(x[:, 0]) + x_cb = self.Cb_Conv(x[:, 1]) + x_cr = self.Cr_Conv(x[:, 2]) + return torch.cat([x_y, x_cb, x_cr], dim=1) + + +class Dct2dStats(nn.Module): + """Utility module to compute DCT coefficient statistics.""" + + def __init__( + self, + kernel_size: int, + kernel_type: int = 2, + orthonormal: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = dict(device=device, dtype=dtype) + super().__init__() + self.k = kernel_size + self.unfold = nn.Unfold(kernel_size=(kernel_size, kernel_size), stride=(kernel_size, kernel_size)) + self.transform = Dct2d(kernel_size, kernel_type, orthonormal, **factory_kwargs) + self.permutation = _zigzag_permutation(kernel_size, kernel_size) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + b, c, h, w = x.shape + x = self.unfold(x).transpose(-1, -2) + x = x.reshape(b, h // self.k, w // self.k, c, self.k, self.k) + x = 
self.transform(x) + x = x.reshape(-1, c, self.k * self.k) + x = x[:, :, self.permutation] + x = x.reshape(b * (h // self.k) * (w // self.k), c, -1) + + mean_list = torch.zeros([3, 64]) + var_list = torch.zeros([3, 64]) + for i in range(3): + mean_list[i] = torch.mean(x[:, i], dim=0) + var_list[i] = torch.var(x[:, i], dim=0) + return mean_list, var_list + + +class Block(nn.Module): + """ConvNeXt-style block with spatial attention.""" + + def __init__( + self, + dim: int, + drop_path: float = 0., + ) -> None: + super().__init__() + self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) + self.norm = nn.LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, 4 * dim) + self.act = nn.GELU() + self.grn = GlobalResponseNorm(4 * dim, channels_last=True) + self.pwconv2 = nn.Linear(4 * dim, dim) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.attention = SpatialAttention() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.dwconv(x) + x = x.permute(0, 2, 3, 1) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.grn(x) + x = self.pwconv2(x) + x = x.permute(0, 3, 1, 2) + + attention = self.attention(x) + up_attn = F.interpolate(attention, size=x.shape[2:], mode='bilinear', align_corners=True) + x = x * up_attn + + return shortcut + self.drop_path(x) + + +class SpatialTransformerBlock(nn.Module): + """Lightweight transformer block for spatial attention (1-channel, 7x7 grid). + + This is a simplified transformer with single-head, 1-dim attention over spatial + positions. Used inside SpatialAttention where input is 1 channel at 7x7 resolution. + """ + + def __init__(self) -> None: + super().__init__() + # Single-head attention with 1-dim q/k/v (no output projection needed) + self.pos_embed = PosConv(in_chans=1) + self.norm1 = nn.LayerNorm(1) + self.qkv = nn.Linear(1, 3, bias=False) + + # Feedforward: 1 -> 4 -> 1 + self.norm2 = nn.LayerNorm(1) + self.mlp = Mlp(1, 4, 1, act_layer=nn.GELU) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, C, H, W = x.shape + + # Attention block + shortcut = x + x_t = x.flatten(2).transpose(1, 2) # (B, N, 1) + x_t = self.norm1(x_t) + x_t = self.pos_embed(x_t, (H, W)) + + # Simple single-head attention with scalar q/k/v + qkv = self.qkv(x_t) # (B, N, 3) + q, k, v = qkv.unbind(-1) # each (B, N) + attn = (q @ k.transpose(-1, -2)).softmax(dim=-1) # (B, N, N) + x_t = (attn @ v).unsqueeze(-1) # (B, N, 1) + + x_t = x_t.transpose(1, 2).reshape(B, C, H, W) + x = shortcut + x_t + + # Feedforward block + shortcut = x + x_t = x.flatten(2).transpose(1, 2) + x_t = self.mlp(self.norm2(x_t)) + x_t = x_t.transpose(1, 2).reshape(B, C, H, W) + x = shortcut + x_t + + return x + + +class SpatialAttention(nn.Module): + """Spatial attention module using channel statistics and transformer.""" + + def __init__(self) -> None: + super().__init__() + self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) + self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3) + self.attention = SpatialTransformerBlock() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_avg = x.mean(dim=1, keepdim=True) + x_max = x.amax(dim=1, keepdim=True) + x = torch.cat([x_avg, x_max], dim=1) + x = self.avgpool(x) + x = self.conv(x) + x = self.attention(x) + return x + + +class TransformerBlock(nn.Module): + """Transformer block with optional downsampling and convolutional position encoding.""" + + def __init__( + self, + inp: int, + oup: int, + num_heads: int = 8, + attn_head_dim: int = 32, + downsample: bool = False, + 
attn_drop: float = 0., + proj_drop: float = 0., + ) -> None: + super().__init__() + hidden_dim = int(inp * 4) + self.downsample = downsample + + if self.downsample: + self.pool1 = nn.MaxPool2d(3, 2, 1) + self.pool2 = nn.MaxPool2d(3, 2, 1) + self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False) + else: + self.pool1 = nn.Identity() + self.pool2 = nn.Identity() + self.proj = nn.Identity() + + self.pos_embed = PosConv(in_chans=inp) + self.norm1 = nn.LayerNorm(inp) + self.attn = Attention( + dim=inp, + num_heads=num_heads, + attn_head_dim=attn_head_dim, + dim_out=oup, + attn_drop=attn_drop, + proj_drop=proj_drop, + ) + + self.norm2 = nn.LayerNorm(oup) + self.mlp = Mlp(oup, hidden_dim, oup, act_layer=nn.GELU, drop=proj_drop) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.downsample: + shortcut = self.proj(self.pool1(x)) + x_t = self.pool2(x) + B, C, H, W = x_t.shape + x_t = x_t.flatten(2).transpose(1, 2) + x_t = self.norm1(x_t) + x_t = self.pos_embed(x_t, (H, W)) + x_t = self.attn(x_t) + x_t = x_t.transpose(1, 2).reshape(B, -1, H, W) + x = shortcut + x_t + else: + B, C, H, W = x.shape + shortcut = x + x_t = x.flatten(2).transpose(1, 2) + x_t = self.norm1(x_t) + x_t = self.pos_embed(x_t, (H, W)) + x_t = self.attn(x_t) + x_t = x_t.transpose(1, 2).reshape(B, -1, H, W) + x = shortcut + x_t + + # MLP block + B, C, H, W = x.shape + shortcut = x + x_t = x.flatten(2).transpose(1, 2) + x_t = self.mlp(self.norm2(x_t)) + x_t = x_t.transpose(1, 2).reshape(B, C, H, W) + x = shortcut + x_t + + return x + + +class PosConv(nn.Module): + """Convolutional position encoding.""" + + def __init__( + self, + in_chans: int, + ) -> None: + super().__init__() + self.proj = nn.Conv2d(in_chans, in_chans, kernel_size=3, stride=1, padding=1, bias=True, groups=in_chans) + + def forward(self, x: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor: + B, N, C = x.shape + H, W = size + cnn_feat = x.transpose(1, 2).view(B, C, H, W) + x = self.proj(cnn_feat) + cnn_feat + return x.flatten(2).transpose(1, 2) + + +class CSATv2(nn.Module): + """CSATv2: Frequency-domain vision model with spatial attention. + + A hybrid architecture that processes images in the DCT frequency domain + with ConvNeXt-style blocks and transformer attention. 
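
    Minimal usage sketch (the only assumptions are the 'csatv2' registration and
    the 512x512 default config defined later in this file):

        import timm, torch
        model = timm.create_model('csatv2', pretrained=False)
        logits = model(torch.randn(1, 3, 512, 512))   # -> (1, 1000)
        feats = model.forward_intermediates(
            torch.randn(1, 3, 512, 512), intermediates_only=True)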
+ """ + + def __init__( + self, + num_classes: int = 1000, + in_chans: int = 3, + drop_path_rate: float = 0.0, + global_pool: str = 'avg', + **kwargs, + ) -> None: + super().__init__() + self.num_classes = num_classes + self.global_pool = global_pool + self.grad_checkpointing = False + + dims = [32, 72, 168, 386] + self.num_features = dims[-1] + self.head_hidden_size = self.num_features + + self.feature_info = [ + dict(num_chs=dims[0], reduction=8, module='stem_dct'), + dict(num_chs=dims[0], reduction=8, module='stages.0'), + dict(num_chs=dims[1], reduction=16, module='stages.1'), + dict(num_chs=dims[2], reduction=32, module='stages.2'), + dict(num_chs=dims[3], reduction=64, module='stages.3'), + ] + + depths = [2, 2, 6, 4] + dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + + self.stem_dct = LearnableDct2d(8) + + self.stages = nn.Sequential( + nn.Sequential( + Block(dim=dims[0], drop_path=dp_rates[0]), + Block(dim=dims[0], drop_path=dp_rates[1]), + LayerNorm2d(dims[0], eps=1e-6), + ), + nn.Sequential( + nn.Conv2d(dims[0], dims[1], kernel_size=2, stride=2), + Block(dim=dims[1], drop_path=dp_rates[2]), + Block(dim=dims[1], drop_path=dp_rates[3]), + LayerNorm2d(dims[1], eps=1e-6), + ), + nn.Sequential( + nn.Conv2d(dims[1], dims[2], kernel_size=2, stride=2), + Block(dim=dims[2], drop_path=dp_rates[4]), + Block(dim=dims[2], drop_path=dp_rates[5]), + Block(dim=dims[2], drop_path=dp_rates[6]), + Block(dim=dims[2], drop_path=dp_rates[7]), + Block(dim=dims[2], drop_path=dp_rates[8]), + Block(dim=dims[2], drop_path=dp_rates[9]), + TransformerBlock(inp=dims[2], oup=dims[2]), + TransformerBlock(inp=dims[2], oup=dims[2]), + LayerNorm2d(dims[2], eps=1e-6), + ), + nn.Sequential( + nn.Conv2d(dims[2], dims[3], kernel_size=2, stride=2), + Block(dim=dims[3], drop_path=dp_rates[10]), + Block(dim=dims[3], drop_path=dp_rates[11]), + Block(dim=dims[3], drop_path=dp_rates[12]), + Block(dim=dims[3], drop_path=dp_rates[13]), + TransformerBlock(inp=dims[3], oup=dims[3]), + TransformerBlock(inp=dims[3], oup=dims[3]), + ), + ) + + self.head = NormMlpClassifierHead(dims[-1], num_classes, pool_type=global_pool) + + self.apply(self._init_weights) + + def _init_weights(self, m: nn.Module) -> None: + if isinstance(m, (nn.Conv2d, nn.Linear)): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + @torch.jit.ignore + def get_classifier(self) -> nn.Module: + return self.head.fc + + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None: + self.num_classes = num_classes + if global_pool is not None: + self.global_pool = global_pool + self.head.reset(num_classes, pool_type=global_pool) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable: bool = True) -> None: + self.grad_checkpointing = enable + + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + x = self.stem_dct(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.stages, x) + else: + x = self.stages(x) + return x + + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """Forward pass returning intermediate features. + + Args: + x: Input image tensor. + indices: Indices of features to return (0=stem_dct, 1-4=stages). None returns all. 
+ norm: Apply norm layer to final intermediate (unused, for API compat). + stop_early: Stop iterating when last desired intermediate is reached. + output_fmt: Output format, must be 'NCHW'. + intermediates_only: Only return intermediate features. + + Returns: + List of intermediate features or tuple of (final features, intermediates). + """ + assert output_fmt == 'NCHW', 'Output format must be NCHW.' + intermediates = [] + # 5 feature levels: stem_dct (0) + stages 0-3 (1-4) + take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices) + + x = self.stem_dct(x) + if 0 in take_indices: + intermediates.append(x) + + if torch.jit.is_scripting() or not stop_early: + stages = self.stages + else: + # max_index is 0-4, stages are 1-4, so we need max_index stages + stages = self.stages[:max_index] if max_index > 0 else [] + + for feat_idx, stage in enumerate(stages): + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(stage, x) + else: + x = stage(x) + if feat_idx + 1 in take_indices: # +1 because stem is index 0 + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ) -> List[int]: + """Prune layers not required for specified intermediates. + + Args: + indices: Indices of intermediate layers to keep (0=stem_dct, 1-4=stages). + prune_norm: Whether to prune the final norm layer. + prune_head: Whether to prune the classifier head. + + Returns: + List of indices that were kept. + """ + # 5 feature levels: stem_dct (0) + stages 0-3 (1-4) + take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices) + # max_index is 0-4, stages are 1-4, so we keep max_index stages + self.stages = self.stages[:max_index] if max_index > 0 else nn.Sequential() + + if prune_norm: + self.head.norm = nn.Identity() + if prune_head: + self.reset_classifier(0, '') + + return take_indices + + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + return self.head(x, pre_logits=pre_logits) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.forward_features(x) + return self.forward_head(x) + + +default_cfgs = generate_default_cfgs({ + 'csatv2': { + 'url': 'https://huggingface.co/Hyunil/CSATv2/resolve/main/CSATv2_ImageNet_timm.pth', + 'num_classes': 1000, + 'input_size': (3, 512, 512), + 'mean': (0.485, 0.456, 0.406), + 'std': (0.229, 0.224, 0.225), + 'interpolation': 'bilinear', + 'crop_pct': 1.0, + 'classifier': 'head.fc', + 'first_conv': [], + }, +}) + + +def checkpoint_filter_fn(state_dict: dict, model: nn.Module) -> dict: + """Remap original CSATv2 checkpoint to timm format. + + Handles two key structural changes: + 1. Stage naming: stages1/2/3/4 -> stages.0/1/2/3 + 2. 
Downsample position: moved from end of stage N to start of stage N+1 + """ + if 'stages.0.0.grn.weight' in state_dict: + return state_dict # Already in timm format + + import re + + # Downsample indices in original checkpoint (Conv2d at end of each stage) + # These move to index 0 of the next stage + downsample_idx = {1: 3, 2: 3, 3: 9} # stage -> downsample index + + def remap_stage(m): + stage = int(m.group(1)) + idx = int(m.group(2)) + rest = m.group(3) + if stage in downsample_idx and idx == downsample_idx[stage]: + # Downsample moves to start of next stage + return f'stages.{stage}.0.{rest}' + elif stage == 1: + # Stage 1 -> stages.0, indices unchanged + return f'stages.0.{idx}.{rest}' + else: + # Stages 2-4 -> stages.1-3, indices shift +1 (after downsample) + return f'stages.{stage - 1}.{idx + 1}.{rest}' + + out_dict = {} + for k, v in state_dict.items(): + # Remap dct -> stem_dct + k = re.sub(r'^dct\.', 'stem_dct.', k) + + # Remap stage names with index adjustments for downsample relocation + k = re.sub(r'^stages([1-4])\.(\d+)\.(.*)$', remap_stage, k) + + # Remap GRN: gamma/beta -> weight/bias with reshape + if 'grn.gamma' in k: + k = k.replace('grn.gamma', 'grn.weight') + v = v.reshape(-1) + elif 'grn.beta' in k: + k = k.replace('grn.beta', 'grn.bias') + v = v.reshape(-1) + + # Remap FeedForward (nn.Sequential) to Mlp: net.0 -> fc1, net.3 -> fc2 + # Also rename ff -> mlp, ff_norm -> norm2, attn_norm -> norm1 + if '.ff.net.0.' in k: + k = k.replace('.ff.net.0.', '.mlp.fc1.') + elif '.ff.net.3.' in k: + k = k.replace('.ff.net.3.', '.mlp.fc2.') + elif '.ff_norm.' in k: + k = k.replace('.ff_norm.', '.norm2.') + elif '.attn_norm.' in k: + k = k.replace('.attn_norm.', '.norm1.') + + # SpatialTransformerBlock: flatten .attention.attention.attn. -> .attention.attention. + # and remap to_qkv -> qkv + if '.attention.attention.attn.' in k: + k = k.replace('.attention.attention.attn.to_qkv.', '.attention.attention.qkv.') + k = k.replace('.attention.attention.attn.', '.attention.attention.') + + # TransformerBlock: remap attention layer names + # to_qkv -> qkv, to_out.0 -> proj, attn.pos_embed -> pos_embed + if '.attn.to_qkv.' in k: + k = k.replace('.attn.to_qkv.', '.attn.qkv.') + elif '.attn.to_out.0.' in k: + k = k.replace('.attn.to_out.0.', '.attn.proj.') + + if '.attn.pos_embed.' 
in k: + k = k.replace('.attn.pos_embed.', '.pos_embed.') + + # Remap head -> head.fc, norm -> head.norm (order matters) + k = re.sub(r'^head\.', 'head.fc.', k) + k = re.sub(r'^norm\.', 'head.norm.', k) + + out_dict[k] = v + + return out_dict + + +def _create_csatv2(variant: str, pretrained: bool = False, **kwargs) -> CSATv2: + return build_model_with_cfg( + CSATv2, + variant, + pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(out_indices=(0, 1, 2, 3, 4), flatten_sequential=True), + default_cfg=default_cfgs[variant], + **kwargs, + ) + + +@register_model +def csatv2(pretrained: bool = False, **kwargs) -> CSATv2: + return _create_csatv2('csatv2', pretrained, **kwargs) From 1c6b0f5e73aed8304b170da12fcbd2094e5a9f5d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 9 Dec 2025 15:41:48 -0800 Subject: [PATCH 11/14] Weird mix of camel/snake conv naming -> snake case --- timm/models/csatv2.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/timm/models/csatv2.py b/timm/models/csatv2.py index 1d29c8b21c..eb3e60b6c5 100644 --- a/timm/models/csatv2.py +++ b/timm/models/csatv2.py @@ -189,9 +189,9 @@ def __init__( self.unfold = nn.Unfold(kernel_size=(kernel_size, kernel_size), stride=(kernel_size, kernel_size)) self.transform = Dct2d(kernel_size, kernel_type, orthonormal, **factory_kwargs) self.permutation = _zigzag_permutation(kernel_size, kernel_size) - self.Y_Conv = nn.Conv2d(kernel_size ** 2, 24, kernel_size=1, padding=0) - self.Cb_Conv = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0) - self.Cr_Conv = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0) + self.conv_y = nn.Conv2d(kernel_size ** 2, 24, kernel_size=1, padding=0) + self.conv_cb = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0) + self.conv_cr = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0) self.register_buffer('mean', torch.tensor(_DCT_MEAN), persistent=False) self.register_buffer('var', torch.tensor(_DCT_VAR), persistent=False) @@ -228,9 +228,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self._frequency_normalize(x) x = x.reshape(b, h // self.k, w // self.k, c, -1) x = x.permute(0, 3, 4, 1, 2).contiguous() - x_y = self.Y_Conv(x[:, 0]) - x_cb = self.Cb_Conv(x[:, 1]) - x_cr = self.Cr_Conv(x[:, 2]) + x_y = self.conv_y(x[:, 0]) + x_cb = self.conv_cb(x[:, 1]) + x_cr = self.conv_cr(x[:, 2]) return torch.cat([x_y, x_cb, x_cr], dim=1) @@ -697,8 +697,11 @@ def remap_stage(m): out_dict = {} for k, v in state_dict.items(): - # Remap dct -> stem_dct + # Remap dct -> stem_dct, and Y_Conv/Cb_Conv/Cr_Conv -> conv_y/conv_cb/conv_cr k = re.sub(r'^dct\.', 'stem_dct.', k) + k = k.replace('.Y_Conv.', '.conv_y.') + k = k.replace('.Cb_Conv.', '.conv_cb.') + k = k.replace('.Cr_Conv.', '.conv_cr.') # Remap stage names with index adjustments for downsample relocation k = re.sub(r'^stages([1-4])\.(\d+)\.(.*)$', remap_stage, k) From 4706c8ecca82806b1085d55d2c1a8deb459f7b06 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 9 Dec 2025 16:21:33 -0800 Subject: [PATCH 12/14] Consistent use of device/dtype factory kwargs through csatv2 and submodules --- timm/models/csatv2.py | 159 ++++++++++++++++++++++++------------------ 1 file changed, 91 insertions(+), 68 deletions(-) diff --git a/timm/models/csatv2.py b/timm/models/csatv2.py index eb3e60b6c5..25820829e0 100644 --- a/timm/models/csatv2.py +++ b/timm/models/csatv2.py @@ -102,21 +102,21 @@ def _dct_kernel_type_2( dtype=None, ) -> torch.Tensor: """Generate Type-II DCT kernel matrix.""" - 
factory_kwargs = dict(device=device, dtype=dtype) - x = torch.eye(kernel_size, **factory_kwargs) + dd = dict(device=device, dtype=dtype) + x = torch.eye(kernel_size, **dd) v = x.clone().contiguous().view(-1, kernel_size) v = torch.cat([v, v.flip([1])], dim=-1) v = torch.fft.fft(v, dim=-1)[:, :kernel_size] try: - k = torch.tensor(-1j, **factory_kwargs) * torch.pi * torch.arange(kernel_size, **factory_kwargs)[None, :] + k = torch.tensor(-1j, **dd) * torch.pi * torch.arange(kernel_size, **dd)[None, :] except AttributeError: - k = torch.tensor(-1j, **factory_kwargs) * math.pi * torch.arange(kernel_size, **factory_kwargs)[None, :] + k = torch.tensor(-1j, **dd) * math.pi * torch.arange(kernel_size, **dd)[None, :] k = torch.exp(k / (kernel_size * 2)) v = v * k v = v.real if orthonormal: - v[:, 0] = v[:, 0] * torch.sqrt(torch.tensor(1 / (kernel_size * 4), **factory_kwargs)) - v[:, 1:] = v[:, 1:] * torch.sqrt(torch.tensor(1 / (kernel_size * 2), **factory_kwargs)) + v[:, 0] = v[:, 0] * torch.sqrt(torch.tensor(1 / (kernel_size * 4), **dd)) + v[:, 1:] = v[:, 1:] * torch.sqrt(torch.tensor(1 / (kernel_size * 2), **dd)) v = v.contiguous().view(*x.shape) return v @@ -142,10 +142,10 @@ def __init__( device=None, dtype=None, ) -> None: - factory_kwargs = dict(device=device, dtype=dtype) + dd = dict(device=device, dtype=dtype) super().__init__() kernel = {'2': _dct_kernel_type_2, '3': _dct_kernel_type_3} - dct_weights = kernel[f'{kernel_type}'](kernel_size, orthonormal, **factory_kwargs).T + dct_weights = kernel[f'{kernel_type}'](kernel_size, orthonormal, **dd).T self.register_buffer('weights', dct_weights) self.register_parameter('bias', None) @@ -164,9 +164,9 @@ def __init__( device=None, dtype=None, ) -> None: - factory_kwargs = dict(device=device, dtype=dtype) + dd = dict(device=device, dtype=dtype) super().__init__() - self.transform = Dct1d(kernel_size, kernel_type, orthonormal, **factory_kwargs) + self.transform = Dct1d(kernel_size, kernel_type, orthonormal, **dd) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.transform(self.transform(x).transpose(-1, -2)).transpose(-1, -2) @@ -183,20 +183,20 @@ def __init__( device=None, dtype=None, ) -> None: - factory_kwargs = dict(device=device, dtype=dtype) + dd = dict(device=device, dtype=dtype) super().__init__() self.k = kernel_size self.unfold = nn.Unfold(kernel_size=(kernel_size, kernel_size), stride=(kernel_size, kernel_size)) - self.transform = Dct2d(kernel_size, kernel_type, orthonormal, **factory_kwargs) + self.transform = Dct2d(kernel_size, kernel_type, orthonormal, **dd) self.permutation = _zigzag_permutation(kernel_size, kernel_size) - self.conv_y = nn.Conv2d(kernel_size ** 2, 24, kernel_size=1, padding=0) - self.conv_cb = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0) - self.conv_cr = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0) + self.conv_y = nn.Conv2d(kernel_size ** 2, 24, kernel_size=1, padding=0, **dd) + self.conv_cb = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0, **dd) + self.conv_cr = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0, **dd) - self.register_buffer('mean', torch.tensor(_DCT_MEAN), persistent=False) - self.register_buffer('var', torch.tensor(_DCT_VAR), persistent=False) - self.register_buffer('imagenet_mean', torch.tensor([0.485, 0.456, 0.406]), persistent=False) - self.register_buffer('imagenet_std', torch.tensor([0.229, 0.224, 0.225]), persistent=False) + self.register_buffer('mean', torch.tensor(_DCT_MEAN, device=device), persistent=False) + self.register_buffer('var', 
torch.tensor(_DCT_VAR, device=device), persistent=False) + self.register_buffer('imagenet_mean', torch.tensor([0.485, 0.456, 0.406], device=device), persistent=False) + self.register_buffer('imagenet_std', torch.tensor([0.229, 0.224, 0.225], device=device), persistent=False) def _denormalize(self, x: torch.Tensor) -> torch.Tensor: """Convert from ImageNet normalized to [0, 255] range.""" @@ -245,11 +245,11 @@ def __init__( device=None, dtype=None, ) -> None: - factory_kwargs = dict(device=device, dtype=dtype) + dd = dict(device=device, dtype=dtype) super().__init__() self.k = kernel_size self.unfold = nn.Unfold(kernel_size=(kernel_size, kernel_size), stride=(kernel_size, kernel_size)) - self.transform = Dct2d(kernel_size, kernel_type, orthonormal, **factory_kwargs) + self.transform = Dct2d(kernel_size, kernel_type, orthonormal, **dd) self.permutation = _zigzag_permutation(kernel_size, kernel_size) def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: @@ -276,16 +276,19 @@ def __init__( self, dim: int, drop_path: float = 0., + device=None, + dtype=None, ) -> None: + dd = dict(device=device, dtype=dtype) super().__init__() - self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) - self.norm = nn.LayerNorm(dim, eps=1e-6) - self.pwconv1 = nn.Linear(dim, 4 * dim) + self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, **dd) + self.norm = nn.LayerNorm(dim, eps=1e-6, **dd) + self.pwconv1 = nn.Linear(dim, 4 * dim, **dd) self.act = nn.GELU() - self.grn = GlobalResponseNorm(4 * dim, channels_last=True) - self.pwconv2 = nn.Linear(4 * dim, dim) + self.grn = GlobalResponseNorm(4 * dim, channels_last=True, **dd) + self.pwconv2 = nn.Linear(4 * dim, dim, **dd) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.attention = SpatialAttention() + self.attention = SpatialAttention(**dd) def forward(self, x: torch.Tensor) -> torch.Tensor: shortcut = x @@ -312,16 +315,21 @@ class SpatialTransformerBlock(nn.Module): positions. Used inside SpatialAttention where input is 1 channel at 7x7 resolution. 
""" - def __init__(self) -> None: + def __init__( + self, + device=None, + dtype=None, + ) -> None: + dd = dict(device=device, dtype=dtype) super().__init__() # Single-head attention with 1-dim q/k/v (no output projection needed) - self.pos_embed = PosConv(in_chans=1) - self.norm1 = nn.LayerNorm(1) - self.qkv = nn.Linear(1, 3, bias=False) + self.pos_embed = PosConv(in_chans=1, **dd) + self.norm1 = nn.LayerNorm(1, **dd) + self.qkv = nn.Linear(1, 3, bias=False, **dd) # Feedforward: 1 -> 4 -> 1 - self.norm2 = nn.LayerNorm(1) - self.mlp = Mlp(1, 4, 1, act_layer=nn.GELU) + self.norm2 = nn.LayerNorm(1, **dd) + self.mlp = Mlp(1, 4, 1, act_layer=nn.GELU, **dd) def forward(self, x: torch.Tensor) -> torch.Tensor: B, C, H, W = x.shape @@ -354,11 +362,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class SpatialAttention(nn.Module): """Spatial attention module using channel statistics and transformer.""" - def __init__(self) -> None: + def __init__( + self, + device=None, + dtype=None, + ) -> None: + dd = dict(device=device, dtype=dtype) super().__init__() self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) - self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3) - self.attention = SpatialTransformerBlock() + self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, **dd) + self.attention = SpatialTransformerBlock(**dd) def forward(self, x: torch.Tensor) -> torch.Tensor: x_avg = x.mean(dim=1, keepdim=True) @@ -382,7 +395,10 @@ def __init__( downsample: bool = False, attn_drop: float = 0., proj_drop: float = 0., + device=None, + dtype=None, ) -> None: + dd = dict(device=device, dtype=dtype) super().__init__() hidden_dim = int(inp * 4) self.downsample = downsample @@ -390,14 +406,14 @@ def __init__( if self.downsample: self.pool1 = nn.MaxPool2d(3, 2, 1) self.pool2 = nn.MaxPool2d(3, 2, 1) - self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False) + self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False, **dd) else: self.pool1 = nn.Identity() self.pool2 = nn.Identity() self.proj = nn.Identity() - self.pos_embed = PosConv(in_chans=inp) - self.norm1 = nn.LayerNorm(inp) + self.pos_embed = PosConv(in_chans=inp, **dd) + self.norm1 = nn.LayerNorm(inp, **dd) self.attn = Attention( dim=inp, num_heads=num_heads, @@ -405,10 +421,11 @@ def __init__( dim_out=oup, attn_drop=attn_drop, proj_drop=proj_drop, + **dd, ) - self.norm2 = nn.LayerNorm(oup) - self.mlp = Mlp(oup, hidden_dim, oup, act_layer=nn.GELU, drop=proj_drop) + self.norm2 = nn.LayerNorm(oup, **dd) + self.mlp = Mlp(oup, hidden_dim, oup, act_layer=nn.GELU, drop=proj_drop, **dd) def forward(self, x: torch.Tensor) -> torch.Tensor: if self.downsample: @@ -448,9 +465,12 @@ class PosConv(nn.Module): def __init__( self, in_chans: int, + device=None, + dtype=None, ) -> None: + dd = dict(device=device, dtype=dtype) super().__init__() - self.proj = nn.Conv2d(in_chans, in_chans, kernel_size=3, stride=1, padding=1, bias=True, groups=in_chans) + self.proj = nn.Conv2d(in_chans, in_chans, kernel_size=3, stride=1, padding=1, bias=True, groups=in_chans, **dd) def forward(self, x: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor: B, N, C = x.shape @@ -473,8 +493,11 @@ def __init__( in_chans: int = 3, drop_path_rate: float = 0.0, global_pool: str = 'avg', + device=None, + dtype=None, **kwargs, ) -> None: + dd = dict(device=device, dtype=dtype) super().__init__() self.num_classes = num_classes self.global_pool = global_pool @@ -495,44 +518,44 @@ def __init__( depths = [2, 2, 6, 4] dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] - self.stem_dct = 
LearnableDct2d(8) + self.stem_dct = LearnableDct2d(8, **dd) self.stages = nn.Sequential( nn.Sequential( - Block(dim=dims[0], drop_path=dp_rates[0]), - Block(dim=dims[0], drop_path=dp_rates[1]), - LayerNorm2d(dims[0], eps=1e-6), + Block(dim=dims[0], drop_path=dp_rates[0], **dd), + Block(dim=dims[0], drop_path=dp_rates[1], **dd), + LayerNorm2d(dims[0], eps=1e-6, **dd), ), nn.Sequential( - nn.Conv2d(dims[0], dims[1], kernel_size=2, stride=2), - Block(dim=dims[1], drop_path=dp_rates[2]), - Block(dim=dims[1], drop_path=dp_rates[3]), - LayerNorm2d(dims[1], eps=1e-6), + nn.Conv2d(dims[0], dims[1], kernel_size=2, stride=2, **dd), + Block(dim=dims[1], drop_path=dp_rates[2], **dd), + Block(dim=dims[1], drop_path=dp_rates[3], **dd), + LayerNorm2d(dims[1], eps=1e-6, **dd), ), nn.Sequential( - nn.Conv2d(dims[1], dims[2], kernel_size=2, stride=2), - Block(dim=dims[2], drop_path=dp_rates[4]), - Block(dim=dims[2], drop_path=dp_rates[5]), - Block(dim=dims[2], drop_path=dp_rates[6]), - Block(dim=dims[2], drop_path=dp_rates[7]), - Block(dim=dims[2], drop_path=dp_rates[8]), - Block(dim=dims[2], drop_path=dp_rates[9]), - TransformerBlock(inp=dims[2], oup=dims[2]), - TransformerBlock(inp=dims[2], oup=dims[2]), - LayerNorm2d(dims[2], eps=1e-6), + nn.Conv2d(dims[1], dims[2], kernel_size=2, stride=2, **dd), + Block(dim=dims[2], drop_path=dp_rates[4], **dd), + Block(dim=dims[2], drop_path=dp_rates[5], **dd), + Block(dim=dims[2], drop_path=dp_rates[6], **dd), + Block(dim=dims[2], drop_path=dp_rates[7], **dd), + Block(dim=dims[2], drop_path=dp_rates[8], **dd), + Block(dim=dims[2], drop_path=dp_rates[9], **dd), + TransformerBlock(inp=dims[2], oup=dims[2], **dd), + TransformerBlock(inp=dims[2], oup=dims[2], **dd), + LayerNorm2d(dims[2], eps=1e-6, **dd), ), nn.Sequential( - nn.Conv2d(dims[2], dims[3], kernel_size=2, stride=2), - Block(dim=dims[3], drop_path=dp_rates[10]), - Block(dim=dims[3], drop_path=dp_rates[11]), - Block(dim=dims[3], drop_path=dp_rates[12]), - Block(dim=dims[3], drop_path=dp_rates[13]), - TransformerBlock(inp=dims[3], oup=dims[3]), - TransformerBlock(inp=dims[3], oup=dims[3]), + nn.Conv2d(dims[2], dims[3], kernel_size=2, stride=2, **dd), + Block(dim=dims[3], drop_path=dp_rates[10], **dd), + Block(dim=dims[3], drop_path=dp_rates[11], **dd), + Block(dim=dims[3], drop_path=dp_rates[12], **dd), + Block(dim=dims[3], drop_path=dp_rates[13], **dd), + TransformerBlock(inp=dims[3], oup=dims[3], **dd), + TransformerBlock(inp=dims[3], oup=dims[3], **dd), ), ) - self.head = NormMlpClassifierHead(dims[-1], num_classes, pool_type=global_pool) + self.head = NormMlpClassifierHead(dims[-1], num_classes, pool_type=global_pool, **dd) self.apply(self._init_weights) From 6cec2f0c14fb958156be4e88f621af67d9126e8b Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 9 Dec 2025 16:46:45 -0800 Subject: [PATCH 13/14] Make default out_indices use stage outputs and skip the dct out --- timm/models/csatv2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/timm/models/csatv2.py b/timm/models/csatv2.py index 25820829e0..78edc0bad2 100644 --- a/timm/models/csatv2.py +++ b/timm/models/csatv2.py @@ -774,12 +774,13 @@ def remap_stage(m): def _create_csatv2(variant: str, pretrained: bool = False, **kwargs) -> CSATv2: + out_indices = kwargs.pop('out_indices', (1, 2, 3, 4)) return build_model_with_cfg( CSATv2, variant, pretrained, pretrained_filter_fn=checkpoint_filter_fn, - feature_cfg=dict(out_indices=(0, 1, 2, 3, 4), flatten_sequential=True), + feature_cfg=dict(out_indices=out_indices, flatten_sequential=True), 
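        # out_indices: 0 = stem_dct (DCT stem) output, 1-4 = the four stage outputs;
        # the (1, 2, 3, 4) default above skips the DCT stem feature.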
default_cfg=default_cfgs[variant], **kwargs, ) From b6eb61aba5dc185b4f9dd1b2ffe8b612e25af46a Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 11 Dec 2025 15:13:01 -0800 Subject: [PATCH 14/14] Another round of consistency changes for csatv2, make stage building dynamic for other network shapes, allow drop path option for transformer blocks. --- timm/models/csatv2.py | 126 ++++++++++++++++++++++-------------------- 1 file changed, 66 insertions(+), 60 deletions(-) diff --git a/timm/models/csatv2.py b/timm/models/csatv2.py index 78edc0bad2..186498e805 100644 --- a/timm/models/csatv2.py +++ b/timm/models/csatv2.py @@ -288,7 +288,7 @@ def __init__( self.grn = GlobalResponseNorm(4 * dim, channels_last=True, **dd) self.pwconv2 = nn.Linear(4 * dim, dim, **dd) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.attention = SpatialAttention(**dd) + self.attn = SpatialAttention(**dd) def forward(self, x: torch.Tensor) -> torch.Tensor: shortcut = x @@ -301,9 +301,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.pwconv2(x) x = x.permute(0, 3, 1, 2) - attention = self.attention(x) - up_attn = F.interpolate(attention, size=x.shape[2:], mode='bilinear', align_corners=True) - x = x * up_attn + attn = self.attn(x) + attn = F.interpolate(attn, size=x.shape[2:], mode='bilinear', align_corners=True) + x = x * attn return shortcut + self.drop_path(x) @@ -371,7 +371,7 @@ def __init__( super().__init__() self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, **dd) - self.attention = SpatialTransformerBlock(**dd) + self.attn = SpatialTransformerBlock(**dd) def forward(self, x: torch.Tensor) -> torch.Tensor: x_avg = x.mean(dim=1, keepdim=True) @@ -379,7 +379,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = torch.cat([x_avg, x_max], dim=1) x = self.avgpool(x) x = self.conv(x) - x = self.attention(x) + x = self.attn(x) return x @@ -395,6 +395,7 @@ def __init__( downsample: bool = False, attn_drop: float = 0., proj_drop: float = 0., + drop_path: float = 0., device=None, dtype=None, ) -> None: @@ -423,9 +424,11 @@ def __init__( proj_drop=proj_drop, **dd, ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = nn.LayerNorm(oup, **dd) self.mlp = Mlp(oup, hidden_dim, oup, act_layer=nn.GELU, drop=proj_drop, **dd) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: if self.downsample: @@ -437,7 +440,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x_t = self.pos_embed(x_t, (H, W)) x_t = self.attn(x_t) x_t = x_t.transpose(1, 2).reshape(B, -1, H, W) - x = shortcut + x_t + x = shortcut + self.drop_path1(x_t) else: B, C, H, W = x.shape shortcut = x @@ -446,7 +449,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x_t = self.pos_embed(x_t, (H, W)) x_t = self.attn(x_t) x_t = x_t.transpose(1, 2).reshape(B, -1, H, W) - x = shortcut + x_t + x = shortcut + self.drop_path1(x_t) # MLP block B, C, H, W = x.shape @@ -454,7 +457,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x_t = x.flatten(2).transpose(1, 2) x_t = self.mlp(self.norm2(x_t)) x_t = x_t.transpose(1, 2).reshape(B, C, H, W) - x = shortcut + x_t + x = shortcut + self.drop_path2(x_t) return x @@ -491,7 +494,11 @@ def __init__( self, num_classes: int = 1000, in_chans: int = 3, + dims: Tuple[int, ...] = (32, 72, 168, 386), + depths: Tuple[int, ...] = (2, 2, 8, 6), + transformer_depths: Tuple[int, ...] 
= (0, 0, 2, 2), drop_path_rate: float = 0.0, + transformer_drop_path: bool = False, global_pool: str = 'avg', device=None, dtype=None, @@ -499,61 +506,52 @@ def __init__( ) -> None: dd = dict(device=device, dtype=dtype) super().__init__() + if in_chans != 3: + warnings.warn( + f'CSATv2 is designed for 3-channel RGB input. ' + f'in_chans={in_chans} may not work correctly with the DCT stem.' + ) self.num_classes = num_classes self.global_pool = global_pool self.grad_checkpointing = False - dims = [32, 72, 168, 386] self.num_features = dims[-1] self.head_hidden_size = self.num_features - self.feature_info = [ - dict(num_chs=dims[0], reduction=8, module='stem_dct'), - dict(num_chs=dims[0], reduction=8, module='stages.0'), - dict(num_chs=dims[1], reduction=16, module='stages.1'), - dict(num_chs=dims[2], reduction=32, module='stages.2'), - dict(num_chs=dims[3], reduction=64, module='stages.3'), - ] - - depths = [2, 2, 6, 4] - dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + # Build feature_info dynamically + self.feature_info = [dict(num_chs=dims[0], reduction=8, module='stem_dct')] + reduction = 8 + for i, dim in enumerate(dims): + if i > 0: + reduction *= 2 + self.feature_info.append(dict(num_chs=dim, reduction=reduction, module=f'stages.{i}')) + + # Build drop path rates for all blocks (0 for transformer blocks when transformer_drop_path=False) + total_blocks = sum(depths) if transformer_drop_path else sum(d - t for d, t in zip(depths, transformer_depths)) + dp_iter = iter(torch.linspace(0, drop_path_rate, total_blocks).tolist()) + dp_rates = [] + for depth, t_depth in zip(depths, transformer_depths): + dp_rates += [next(dp_iter) for _ in range(depth - t_depth)] + dp_rates += [next(dp_iter) if transformer_drop_path else 0. 
for _ in range(t_depth)] self.stem_dct = LearnableDct2d(8, **dd) - self.stages = nn.Sequential( - nn.Sequential( - Block(dim=dims[0], drop_path=dp_rates[0], **dd), - Block(dim=dims[0], drop_path=dp_rates[1], **dd), - LayerNorm2d(dims[0], eps=1e-6, **dd), - ), - nn.Sequential( - nn.Conv2d(dims[0], dims[1], kernel_size=2, stride=2, **dd), - Block(dim=dims[1], drop_path=dp_rates[2], **dd), - Block(dim=dims[1], drop_path=dp_rates[3], **dd), - LayerNorm2d(dims[1], eps=1e-6, **dd), - ), - nn.Sequential( - nn.Conv2d(dims[1], dims[2], kernel_size=2, stride=2, **dd), - Block(dim=dims[2], drop_path=dp_rates[4], **dd), - Block(dim=dims[2], drop_path=dp_rates[5], **dd), - Block(dim=dims[2], drop_path=dp_rates[6], **dd), - Block(dim=dims[2], drop_path=dp_rates[7], **dd), - Block(dim=dims[2], drop_path=dp_rates[8], **dd), - Block(dim=dims[2], drop_path=dp_rates[9], **dd), - TransformerBlock(inp=dims[2], oup=dims[2], **dd), - TransformerBlock(inp=dims[2], oup=dims[2], **dd), - LayerNorm2d(dims[2], eps=1e-6, **dd), - ), - nn.Sequential( - nn.Conv2d(dims[2], dims[3], kernel_size=2, stride=2, **dd), - Block(dim=dims[3], drop_path=dp_rates[10], **dd), - Block(dim=dims[3], drop_path=dp_rates[11], **dd), - Block(dim=dims[3], drop_path=dp_rates[12], **dd), - Block(dim=dims[3], drop_path=dp_rates[13], **dd), - TransformerBlock(inp=dims[3], oup=dims[3], **dd), - TransformerBlock(inp=dims[3], oup=dims[3], **dd), - ), - ) + # Build stages dynamically + dp_iter = iter(dp_rates) + stages = [] + for i, (dim, depth, t_depth) in enumerate(zip(dims, depths, transformer_depths)): + layers = ( + # Downsample at start of stage (except first stage) + ([nn.Conv2d(dims[i - 1], dim, kernel_size=2, stride=2, **dd)] if i > 0 else []) + + # Conv blocks + [Block(dim=dim, drop_path=next(dp_iter), **dd) for _ in range(depth - t_depth)] + + # Transformer blocks at end of stage + [TransformerBlock(inp=dim, oup=dim, drop_path=next(dp_iter), **dd) for _ in range(t_depth)] + + # Trailing LayerNorm (except last stage) + ([LayerNorm2d(dim, eps=1e-6, **dd)] if i < len(depths) - 1 else []) + ) + stages.append(nn.Sequential(*layers)) + self.stages = nn.Sequential(*stages) self.head = NormMlpClassifierHead(dims[-1], num_classes, pool_type=global_pool, **dd) @@ -748,20 +746,28 @@ def remap_stage(m): elif '.attn_norm.' in k: k = k.replace('.attn_norm.', '.norm1.') - # SpatialTransformerBlock: flatten .attention.attention.attn. -> .attention.attention. - # and remap to_qkv -> qkv - if '.attention.attention.attn.' in k: - k = k.replace('.attention.attention.attn.to_qkv.', '.attention.attention.qkv.') - k = k.replace('.attention.attention.attn.', '.attention.attention.') + # Block.attention -> Block.attn (SpatialAttention) + # SpatialAttention.attention -> SpatialAttention.attn (SpatialTransformerBlock) + # Handle nested .attention.attention. first, then remaining .attention. + if '.attention.attention.' in k: + # SpatialTransformerBlock inner attn: remap to_qkv -> qkv + k = k.replace('.attention.attention.attn.to_qkv.', '.attn.attn.qkv.') + k = k.replace('.attention.attention.attn.', '.attn.attn.') + k = k.replace('.attention.attention.', '.attn.attn.') + elif '.attention.' in k: + # Block.attention -> Block.attn (catches SpatialAttention.conv etc) + k = k.replace('.attention.', '.attn.') # TransformerBlock: remap attention layer names # to_qkv -> qkv, to_out.0 -> proj, attn.pos_embed -> pos_embed + # Note: only for TransformerBlock, not SpatialTransformerBlock (which has .attn.attn.) if '.attn.to_qkv.' 
in k: k = k.replace('.attn.to_qkv.', '.attn.qkv.') elif '.attn.to_out.0.' in k: k = k.replace('.attn.to_out.0.', '.attn.proj.') - if '.attn.pos_embed.' in k: + # TransformerBlock: .attn.pos_embed -> .pos_embed (but not .attn.attn.pos_embed) + if '.attn.pos_embed.' in k and '.attn.attn.' not in k: k = k.replace('.attn.pos_embed.', '.pos_embed.') # Remap head -> head.fc, norm -> head.norm (order matters)
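        # e.g. 'head.weight' -> 'head.fc.weight', 'norm.weight' -> 'head.norm.weight'.
        # The head rule must run first; otherwise it would re-match the 'head.norm.'
        # prefix produced by the norm rule and turn it into 'head.fc.norm.weight'.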