From 2b8c1987466b8568e7fe966fbfbad30b987a90cf Mon Sep 17 00:00:00 2001 From: Yonghye Kwon Date: Sat, 14 Jun 2025 12:23:23 +0900 Subject: [PATCH] Add common model interface methods to UniRepLKNet --- timm/models/unireplknet.py | 833 ++++++++++++++++++++++++++----------- 1 file changed, 585 insertions(+), 248 deletions(-) diff --git a/timm/models/unireplknet.py b/timm/models/unireplknet.py index 71f2c68c66..aa1b01f0d7 100644 --- a/timm/models/unireplknet.py +++ b/timm/models/unireplknet.py @@ -11,12 +11,16 @@ import torch import torch.nn as nn import torch.nn.functional as F +from typing import Optional, List, Union, Tuple + from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.layers import trunc_normal_, DropPath, to_2tuple from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._registry import register_model, generate_default_cfgs from functools import partial import torch.utils.checkpoint as checkpoint + try: from huggingface_hub import hf_hub_download except Exception: @@ -24,11 +28,12 @@ class GRNwithNHWC(nn.Module): - """ GRN (Global Response Normalization) layer + """GRN (Global Response Normalization) layer Originally proposed in ConvNeXt V2 (https://arxiv.org/abs/2301.00808) This implementation is more efficient than the original (https://github.com/facebookresearch/ConvNeXt-V2) We assume the inputs to this layer are (N, H, W, C) """ + def __init__(self, dim, use_bias=True): super().__init__() self.use_bias = use_bias @@ -60,15 +65,33 @@ def __init__(self): def forward(self, x): return x.permute(0, 3, 1, 2) -def get_conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, - attempt_use_lk_impl=True): + +def get_conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + attempt_use_lk_impl=True, +): kernel_size = to_2tuple(kernel_size) if padding is None: padding = (kernel_size[0] // 2, kernel_size[1] // 2) else: padding = to_2tuple(padding) - return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, - padding=padding, dilation=dilation, groups=groups, bias=bias) + return nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ) def get_bn(dim, use_sync_bn=False): @@ -77,17 +100,29 @@ def get_bn(dim, use_sync_bn=False): else: return nn.BatchNorm2d(dim) + class SEBlock(nn.Module): """ Squeeze-and-Excitation Block proposed in SENet (https://arxiv.org/abs/1709.01507) We assume the inputs to this layer are (N, C, H, W) """ + def __init__(self, input_channels, internal_neurons): super(SEBlock, self).__init__() - self.down = nn.Conv2d(in_channels=input_channels, out_channels=internal_neurons, - kernel_size=1, stride=1, bias=True) - self.up = nn.Conv2d(in_channels=internal_neurons, out_channels=input_channels, - kernel_size=1, stride=1, bias=True) + self.down = nn.Conv2d( + in_channels=input_channels, + out_channels=internal_neurons, + kernel_size=1, + stride=1, + bias=True, + ) + self.up = nn.Conv2d( + in_channels=internal_neurons, + out_channels=input_channels, + kernel_size=1, + stride=1, + bias=True, + ) self.input_channels = input_channels self.nonlinear = nn.ReLU(inplace=True) @@ -99,10 +134,15 @@ def forward(self, inputs): x = F.sigmoid(x) return inputs * x.view(-1, self.input_channels, 1, 1) + def fuse_bn(conv, bn): conv_bias = 0 if conv.bias is None else conv.bias std = (bn.running_var + bn.eps).sqrt() - return conv.weight * (bn.weight / std).reshape(-1, 1, 1, 1), bn.bias + (conv_bias - bn.running_mean) * bn.weight / std + return ( + conv.weight * (bn.weight / std).reshape(-1, 1, 1, 1), + bn.bias + (conv_bias - bn.running_mean) * bn.weight / std, + ) + def convert_dilated_to_nondilated(kernel, dilate_rate): identity_kernel = torch.ones((1, 1, 1, 1)).to(kernel.device) @@ -114,10 +154,13 @@ def convert_dilated_to_nondilated(kernel, dilate_rate): # This is a dense or group-wise (but not DW) kernel slices = [] for i in range(kernel.size(1)): - dilated = F.conv_transpose2d(kernel[:,i:i+1,:,:], identity_kernel, stride=dilate_rate) + dilated = F.conv_transpose2d( + kernel[:, i : i + 1, :, :], identity_kernel, stride=dilate_rate + ) slices.append(dilated) return torch.cat(slices, dim=1) + def merge_dilated_into_large_kernel(large_kernel, dilated_kernel, dilated_r): large_k = large_kernel.size(2) dilated_k = dilated_kernel.size(2) @@ -133,11 +176,22 @@ class DilatedReparamBlock(nn.Module): Dilated Reparam Block proposed in UniRepLKNet (https://github.com/AILab-CVC/UniRepLKNet) We assume the inputs to this block are (N, C, H, W) """ - def __init__(self, channels, kernel_size, deploy, use_sync_bn=False, attempt_use_lk_impl=True): + + def __init__( + self, channels, kernel_size, deploy, use_sync_bn=False, attempt_use_lk_impl=True + ): super().__init__() - self.lk_origin = get_conv2d(channels, channels, kernel_size, stride=1, - padding=kernel_size//2, dilation=1, groups=channels, bias=deploy, - attempt_use_lk_impl=attempt_use_lk_impl) + self.lk_origin = get_conv2d( + channels, + channels, + kernel_size, + stride=1, + padding=kernel_size // 2, + dilation=1, + groups=channels, + bias=deploy, + attempt_use_lk_impl=attempt_use_lk_impl, + ) self.attempt_use_lk_impl = attempt_use_lk_impl # Default settings. We did not tune them carefully. Different settings may work better. @@ -163,74 +217,108 @@ def __init__(self, channels, kernel_size, deploy, use_sync_bn=False, attempt_use self.kernel_sizes = [3, 3] self.dilates = [1, 2] else: - raise ValueError('Dilated Reparam Block requires kernel_size >= 5') + raise ValueError("Dilated Reparam Block requires kernel_size >= 5") if not deploy: self.origin_bn = get_bn(channels, use_sync_bn) for k, r in zip(self.kernel_sizes, self.dilates): - self.__setattr__('dil_conv_k{}_{}'.format(k, r), - nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=k, stride=1, - padding=(r * (k - 1) + 1) // 2, dilation=r, groups=channels, - bias=False)) - self.__setattr__('dil_bn_k{}_{}'.format(k, r), get_bn(channels, use_sync_bn=use_sync_bn)) + self.__setattr__( + "dil_conv_k{}_{}".format(k, r), + nn.Conv2d( + in_channels=channels, + out_channels=channels, + kernel_size=k, + stride=1, + padding=(r * (k - 1) + 1) // 2, + dilation=r, + groups=channels, + bias=False, + ), + ) + self.__setattr__( + "dil_bn_k{}_{}".format(k, r), + get_bn(channels, use_sync_bn=use_sync_bn), + ) def forward(self, x): - if not hasattr(self, 'origin_bn'): # deploy mode + if not hasattr(self, "origin_bn"): # deploy mode return self.lk_origin(x) out = self.origin_bn(self.lk_origin(x)) for k, r in zip(self.kernel_sizes, self.dilates): - conv = self.__getattr__('dil_conv_k{}_{}'.format(k, r)) - bn = self.__getattr__('dil_bn_k{}_{}'.format(k, r)) + conv = self.__getattr__("dil_conv_k{}_{}".format(k, r)) + bn = self.__getattr__("dil_bn_k{}_{}".format(k, r)) out = out + bn(conv(x)) return out def merge_dilated_branches(self): - if hasattr(self, 'origin_bn'): + if hasattr(self, "origin_bn"): origin_k, origin_b = fuse_bn(self.lk_origin, self.origin_bn) for k, r in zip(self.kernel_sizes, self.dilates): - conv = self.__getattr__('dil_conv_k{}_{}'.format(k, r)) - bn = self.__getattr__('dil_bn_k{}_{}'.format(k, r)) + conv = self.__getattr__("dil_conv_k{}_{}".format(k, r)) + bn = self.__getattr__("dil_bn_k{}_{}".format(k, r)) branch_k, branch_b = fuse_bn(conv, bn) origin_k = merge_dilated_into_large_kernel(origin_k, branch_k, r) origin_b += branch_b - merged_conv = get_conv2d(origin_k.size(0), origin_k.size(0), origin_k.size(2), stride=1, - padding=origin_k.size(2)//2, dilation=1, groups=origin_k.size(0), bias=True, - attempt_use_lk_impl=self.attempt_use_lk_impl) + merged_conv = get_conv2d( + origin_k.size(0), + origin_k.size(0), + origin_k.size(2), + stride=1, + padding=origin_k.size(2) // 2, + dilation=1, + groups=origin_k.size(0), + bias=True, + attempt_use_lk_impl=self.attempt_use_lk_impl, + ) merged_conv.weight.data = origin_k merged_conv.bias.data = origin_b self.lk_origin = merged_conv - self.__delattr__('origin_bn') + self.__delattr__("origin_bn") for k, r in zip(self.kernel_sizes, self.dilates): - self.__delattr__('dil_conv_k{}_{}'.format(k, r)) - self.__delattr__('dil_bn_k{}_{}'.format(k, r)) + self.__delattr__("dil_conv_k{}_{}".format(k, r)) + self.__delattr__("dil_bn_k{}_{}".format(k, r)) class UniRepLKNetBlock(nn.Module): - def __init__(self, - dim, - kernel_size, - drop_path=0., - layer_scale_init_value=1e-6, - deploy=False, - attempt_use_lk_impl=True, - with_cp=False, - use_sync_bn=False, - ffn_factor=4): + def __init__( + self, + dim, + kernel_size, + drop_path=0.0, + layer_scale_init_value=1e-6, + deploy=False, + attempt_use_lk_impl=True, + with_cp=False, + use_sync_bn=False, + ffn_factor=4, + ): super().__init__() self.with_cp = with_cp if kernel_size == 0: self.dwconv = nn.Identity() elif kernel_size >= 7: - self.dwconv = DilatedReparamBlock(dim, kernel_size, deploy=deploy, - use_sync_bn=use_sync_bn, - attempt_use_lk_impl=attempt_use_lk_impl) + self.dwconv = DilatedReparamBlock( + dim, + kernel_size, + deploy=deploy, + use_sync_bn=use_sync_bn, + attempt_use_lk_impl=attempt_use_lk_impl, + ) else: assert kernel_size in [3, 5] - self.dwconv = get_conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=kernel_size // 2, - dilation=1, groups=dim, bias=deploy, - attempt_use_lk_impl=attempt_use_lk_impl) + self.dwconv = get_conv2d( + dim, + dim, + kernel_size=kernel_size, + stride=1, + padding=kernel_size // 2, + dilation=1, + groups=dim, + bias=deploy, + attempt_use_lk_impl=attempt_use_lk_impl, + ) if deploy or kernel_size == 0: self.norm = nn.Identity() @@ -240,26 +328,25 @@ def __init__(self, self.se = SEBlock(dim, dim // 4) ffn_dim = int(ffn_factor * dim) - self.pwconv1 = nn.Sequential( - NCHWtoNHWC(), - nn.Linear(dim, ffn_dim)) - self.act = nn.Sequential( - nn.GELU(), - GRNwithNHWC(ffn_dim, use_bias=not deploy)) + self.pwconv1 = nn.Sequential(NCHWtoNHWC(), nn.Linear(dim, ffn_dim)) + self.act = nn.Sequential(nn.GELU(), GRNwithNHWC(ffn_dim, use_bias=not deploy)) if deploy: - self.pwconv2 = nn.Sequential( - nn.Linear(ffn_dim, dim), - NHWCtoNCHW()) + self.pwconv2 = nn.Sequential(nn.Linear(ffn_dim, dim), NHWCtoNCHW()) else: self.pwconv2 = nn.Sequential( nn.Linear(ffn_dim, dim, bias=False), NHWCtoNCHW(), - get_bn(dim, use_sync_bn=use_sync_bn)) - - self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim), - requires_grad=True) if (not deploy) and layer_scale_init_value is not None \ - and layer_scale_init_value > 0 else None - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + get_bn(dim, use_sync_bn=use_sync_bn), + ) + + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) + if (not deploy) + and layer_scale_init_value is not None + and layer_scale_init_value > 0 + else None + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() def compute_residual(self, x): y = self.se(self.norm(self.dwconv(x))) @@ -280,19 +367,35 @@ def _f(x): return out def reparameterize(self): - if hasattr(self.dwconv, 'merge_dilated_branches'): + if hasattr(self.dwconv, "merge_dilated_branches"): self.dwconv.merge_dilated_branches() - if hasattr(self.norm, 'running_var'): + if hasattr(self.norm, "running_var"): std = (self.norm.running_var + self.norm.eps).sqrt() - if hasattr(self.dwconv, 'lk_origin'): - self.dwconv.lk_origin.weight.data *= (self.norm.weight / std).view(-1, 1, 1, 1) - self.dwconv.lk_origin.bias.data = self.norm.bias + ( - self.dwconv.lk_origin.bias - self.norm.running_mean) * self.norm.weight / std + if hasattr(self.dwconv, "lk_origin"): + self.dwconv.lk_origin.weight.data *= (self.norm.weight / std).view( + -1, 1, 1, 1 + ) + self.dwconv.lk_origin.bias.data = ( + self.norm.bias + + (self.dwconv.lk_origin.bias - self.norm.running_mean) + * self.norm.weight + / std + ) else: - conv = nn.Conv2d(self.dwconv.in_channels, self.dwconv.out_channels, self.dwconv.kernel_size, - padding=self.dwconv.padding, groups=self.dwconv.groups, bias=True) - conv.weight.data = self.dwconv.weight * (self.norm.weight / std).view(-1, 1, 1, 1) - conv.bias.data = self.norm.bias - self.norm.running_mean * self.norm.weight / std + conv = nn.Conv2d( + self.dwconv.in_channels, + self.dwconv.out_channels, + self.dwconv.kernel_size, + padding=self.dwconv.padding, + groups=self.dwconv.groups, + bias=True, + ) + conv.weight.data = self.dwconv.weight * (self.norm.weight / std).view( + -1, 1, 1, 1 + ) + conv.bias.data = ( + self.norm.bias - self.norm.running_mean * self.norm.weight / std + ) self.dwconv = conv self.norm = nn.Identity() if self.gamma is not None: @@ -302,37 +405,78 @@ def reparameterize(self): final_scale = 1 if self.act[1].use_bias and len(self.pwconv2) == 3: grn_bias = self.act[1].beta.data - self.act[1].__delattr__('beta') + self.act[1].__delattr__("beta") self.act[1].use_bias = False linear = self.pwconv2[0] - grn_bias_projected_bias = (linear.weight.data @ grn_bias.view(-1, 1)).squeeze() + grn_bias_projected_bias = ( + linear.weight.data @ grn_bias.view(-1, 1) + ).squeeze() bn = self.pwconv2[2] std = (bn.running_var + bn.eps).sqrt() new_linear = nn.Linear(linear.in_features, linear.out_features, bias=True) - new_linear.weight.data = linear.weight * (bn.weight / std * final_scale).view(-1, 1) + new_linear.weight.data = linear.weight * ( + bn.weight / std * final_scale + ).view(-1, 1) linear_bias = 0 if linear.bias is None else linear.bias.data linear_bias += grn_bias_projected_bias - new_linear.bias.data = (bn.bias + (linear_bias - bn.running_mean) * bn.weight / std) * final_scale + new_linear.bias.data = ( + bn.bias + (linear_bias - bn.running_mean) * bn.weight / std + ) * final_scale self.pwconv2 = nn.Sequential(new_linear, self.pwconv2[1]) - -default_UniRepLKNet_A_F_P_kernel_sizes = ((3, 3), - (13, 13), - (13, 13, 13, 13, 13, 13), - (13, 13)) -default_UniRepLKNet_N_kernel_sizes = ((3, 3), - (13, 13), - (13, 13, 13, 13, 13, 13, 13, 13), - (13, 13)) -default_UniRepLKNet_T_kernel_sizes = ((3, 3, 3), - (13, 13, 13), - (13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3), - (13, 13, 13)) -default_UniRepLKNet_S_B_L_XL_kernel_sizes = ((3, 3, 3), - (13, 13, 13), - (13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3), - (13, 13, 13)) +default_UniRepLKNet_A_F_P_kernel_sizes = ( + (3, 3), + (13, 13), + (13, 13, 13, 13, 13, 13), + (13, 13), +) +default_UniRepLKNet_N_kernel_sizes = ( + (3, 3), + (13, 13), + (13, 13, 13, 13, 13, 13, 13, 13), + (13, 13), +) +default_UniRepLKNet_T_kernel_sizes = ( + (3, 3, 3), + (13, 13, 13), + (13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3), + (13, 13, 13), +) +default_UniRepLKNet_S_B_L_XL_kernel_sizes = ( + (3, 3, 3), + (13, 13, 13), + ( + 13, + 3, + 3, + 13, + 3, + 3, + 13, + 3, + 3, + 13, + 3, + 3, + 13, + 3, + 3, + 13, + 3, + 3, + 13, + 3, + 3, + 13, + 3, + 3, + 13, + 3, + 3, + ), + (13, 13, 13), +) UniRepLKNet_A_F_P_depths = (2, 2, 6, 2) UniRepLKNet_N_depths = (2, 2, 8, 2) UniRepLKNet_T_depths = (3, 3, 18, 3) @@ -342,11 +486,12 @@ def reparameterize(self): UniRepLKNet_A_F_P_depths: default_UniRepLKNet_A_F_P_kernel_sizes, UniRepLKNet_N_depths: default_UniRepLKNet_N_kernel_sizes, UniRepLKNet_T_depths: default_UniRepLKNet_T_kernel_sizes, - UniRepLKNet_S_B_L_XL_depths: default_UniRepLKNet_S_B_L_XL_kernel_sizes + UniRepLKNet_S_B_L_XL_depths: default_UniRepLKNet_S_B_L_XL_kernel_sizes, } + class UniRepLKNet(nn.Module): - r""" UniRepLKNet + r"""UniRepLKNet A PyTorch impl of UniRepLKNet Args: @@ -364,22 +509,24 @@ class UniRepLKNet(nn.Module): attempt_use_lk_impl (bool): try to load the efficient iGEMM large-kernel impl. Setting it to False disabling the iGEMM impl. Default: True use_sync_bn (bool): use_sync_bn = True means using sync BN. Use it if your batch size is small. Default: False """ - def __init__(self, - in_chans=3, - num_classes=1000, - depths=(3, 3, 27, 3), - dims=(96, 192, 384, 768), - drop_path_rate=0., - layer_scale_init_value=1e-6, - head_init_scale=1., - kernel_sizes=None, - deploy=False, - with_cp=False, - init_cfg=None, - attempt_use_lk_impl=True, - use_sync_bn=False, - **kwargs - ): + + def __init__( + self, + in_chans=3, + num_classes=1000, + depths=(3, 3, 27, 3), + dims=(96, 192, 384, 768), + drop_path_rate=0.0, + layer_scale_init_value=1e-6, + head_init_scale=1.0, + kernel_sizes=None, + deploy=False, + with_cp=False, + init_cfg=None, + attempt_use_lk_impl=True, + use_sync_bn=False, + **kwargs, + ): super().__init__() depths = tuple(depths) @@ -387,45 +534,64 @@ def __init__(self, if depths in default_depths_to_kernel_sizes: kernel_sizes = default_depths_to_kernel_sizes[depths] else: - raise ValueError('no default kernel size settings for the given depths, ' - 'please specify kernel sizes for each block, e.g., ' - '((3, 3), (13, 13), (13, 13, 13, 13, 13, 13), (13, 13))') + raise ValueError( + "no default kernel size settings for the given depths, " + "please specify kernel sizes for each block, e.g., " + "((3, 3), (13, 13), (13, 13, 13, 13, 13, 13), (13, 13))" + ) for i in range(4): - assert len(kernel_sizes[i]) == depths[i], 'kernel sizes do not match the depths' + assert ( + len(kernel_sizes[i]) == depths[i] + ), "kernel sizes do not match the depths" self.with_cp = with_cp dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] - self.downsample_layers = nn.ModuleList() - self.downsample_layers.append(nn.Sequential( - nn.Conv2d(in_chans, dims[0] // 2, kernel_size=3, stride=2, padding=1), - LayerNorm(dims[0] // 2, eps=1e-6, data_format="channels_first"), - nn.GELU(), - nn.Conv2d(dims[0] // 2, dims[0], kernel_size=3, stride=2, padding=1), - LayerNorm(dims[0], eps=1e-6, data_format="channels_first"))) + self.downsample_layers.append( + nn.Sequential( + nn.Conv2d(in_chans, dims[0] // 2, kernel_size=3, stride=2, padding=1), + LayerNorm(dims[0] // 2, eps=1e-6, data_format="channels_first"), + nn.GELU(), + nn.Conv2d(dims[0] // 2, dims[0], kernel_size=3, stride=2, padding=1), + LayerNorm(dims[0], eps=1e-6, data_format="channels_first"), + ) + ) for i in range(3): - self.downsample_layers.append(nn.Sequential( - nn.Conv2d(dims[i], dims[i + 1], kernel_size=3, stride=2, padding=1), - LayerNorm(dims[i + 1], eps=1e-6, data_format="channels_first"))) + self.downsample_layers.append( + nn.Sequential( + nn.Conv2d(dims[i], dims[i + 1], kernel_size=3, stride=2, padding=1), + LayerNorm(dims[i + 1], eps=1e-6, data_format="channels_first"), + ) + ) self.stages = nn.ModuleList() cur = 0 for i in range(4): main_stage = nn.Sequential( - *[UniRepLKNetBlock(dim=dims[i], kernel_size=kernel_sizes[i][j], drop_path=dp_rates[cur + j], - layer_scale_init_value=layer_scale_init_value, deploy=deploy, - attempt_use_lk_impl=attempt_use_lk_impl, - with_cp=with_cp, use_sync_bn=use_sync_bn) for j in - range(depths[i])]) + *[ + UniRepLKNetBlock( + dim=dims[i], + kernel_size=kernel_sizes[i][j], + drop_path=dp_rates[cur + j], + layer_scale_init_value=layer_scale_init_value, + deploy=deploy, + attempt_use_lk_impl=attempt_use_lk_impl, + with_cp=with_cp, + use_sync_bn=use_sync_bn, + ) + for j in range(depths[i]) + ] + ) self.stages.append(main_stage) cur += depths[i] last_channels = dims[-1] + self.num_features = last_channels self.is_classification = num_classes is not None if self.is_classification: @@ -434,64 +600,130 @@ def __init__(self, self.apply(self._init_weights) self.head.weight.data.mul_(head_init_scale) self.head.bias.data.mul_(head_init_scale) - self.output_mode = 'logits' + self.output_mode = "logits" else: - self.output_mode = 'features' + self.output_mode = "features" norm_layer = partial(LayerNorm, eps=1e-6, data_format="channels_first") for i_layer in range(4): layer = norm_layer(dims[i_layer]) - layer_name = f'norm{i_layer}' + layer_name = f"norm{i_layer}" self.add_module(layer_name, layer) if init_cfg is not None: - ckpt = torch.load(init_cfg, map_location='cpu') - if 'state_dict' in ckpt: - ckpt = ckpt['state_dict'] - elif 'model' in ckpt: - ckpt = ckpt['model'] + ckpt = torch.load(init_cfg, map_location="cpu") + if "state_dict" in ckpt: + ckpt = ckpt["state_dict"] + elif "model" in ckpt: + ckpt = ckpt["model"] self.load_state_dict(ckpt, strict=False) - - def _init_weights(self, m): if isinstance(m, (nn.Conv2d, nn.Linear)): - trunc_normal_(m.weight, std=.02) - if hasattr(m, 'bias') and m.bias is not None: + trunc_normal_(m.weight, std=0.02) + if hasattr(m, "bias") and m.bias is not None: nn.init.constant_(m.bias, 0) - def forward(self, x): - if self.output_mode == 'logits': - for stage_idx in range(4): - x = self.downsample_layers[stage_idx](x) - x = self.stages[stage_idx](x) - x = self.norm(x.mean([-2, -1])) - x = self.head(x) + @torch.jit.ignore + def get_classifier(self) -> nn.Module: + return self.head + + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + self.num_classes = num_classes + self.is_classification = num_classes is not None and num_classes > 0 + self.output_mode = "logits" if self.is_classification else "features" + self.head = ( + nn.Linear(self.num_features, num_classes) + if num_classes > 0 + else nn.Identity() + ) + + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = "NCHW", + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + assert output_fmt in ("NCHW",), "Output shape must be NCHW." + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + last_idx = len(self.stages) - 1 + + if torch.jit.is_scripting() or not stop_early: + stage_range = range(len(self.stages)) + else: + stage_range = range(max_index + 1) + for feat_idx in stage_range: + x = self.downsample_layers[feat_idx](x) + x = self.stages[feat_idx](x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ) -> List[int]: + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[: max_index + 1] + self.downsample_layers = self.downsample_layers[: max_index + 1] + if prune_head: + self.reset_classifier(0, "") + return take_indices + + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + for stage_idx in range(len(self.stages)): + x = self.downsample_layers[stage_idx](x) + x = self.stages[stage_idx](x) + return x + + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + x = x.mean(dim=(-2, -1)) + x = self.norm(x) if hasattr(self, "norm") else x + if pre_logits: return x - elif self.output_mode == 'features': + return self.head(x) + + def forward(self, x): + if self.output_mode == "features": outs = [] - for stage_idx in range(4): + for stage_idx in range(len(self.stages)): x = self.downsample_layers[stage_idx](x) x = self.stages[stage_idx](x) - outs.append(self.__getattr__(f'norm{stage_idx}')(x)) + outs.append(self.__getattr__(f"norm{stage_idx}")(x)) return outs - else: - raise ValueError('Defined new output mode?') + x = self.forward_features(x) + x = self.forward_head(x) + return x def reparameterize_unireplknet(self): for m in self.modules(): - if hasattr(m, 'reparameterize'): + if hasattr(m, "reparameterize"): m.reparameterize() - class LayerNorm(nn.Module): - r""" LayerNorm implementation used in ConvNeXt + r"""LayerNorm implementation used in ConvNeXt LayerNorm that supports two data formats: channels_last (default) or channels_first. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). """ - def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", reshape_last_to_first=False): + def __init__( + self, + normalized_shape, + eps=1e-6, + data_format="channels_last", + reshape_last_to_first=False, + ): super().__init__() self.weight = nn.Parameter(torch.ones(normalized_shape)) self.bias = nn.Parameter(torch.zeros(normalized_shape)) @@ -504,7 +736,9 @@ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", resh def forward(self, x): if self.data_format == "channels_last": - return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + return F.layer_norm( + x, self.normalized_shape, self.weight, self.bias, self.eps + ) elif self.data_format == "channels_first": u = x.mean(1, keepdim=True) s = (x - u).pow(2).mean(1, keepdim=True) @@ -513,10 +747,8 @@ def forward(self, x): return x - - model_urls = { - #TODO: it seems that google drive does not support direct downloading with url? so where to upload the checkpoints other than huggingface? any suggestions? + # TODO: it seems that google drive does not support direct downloading with url? so where to upload the checkpoints other than huggingface? any suggestions? } huggingface_file_names = { @@ -533,28 +765,36 @@ def forward(self, x): "unireplknet_l_22k": "unireplknet_l_in22k_pretrain.pth", "unireplknet_l_22k_to_1k": "unireplknet_l_in22k_to_in1k_384_acc87.88.pth", "unireplknet_xl_22k": "unireplknet_xl_in22k_pretrain.pth", - "unireplknet_xl_22k_to_1k": "unireplknet_xl_in22k_to_in1k_384_acc87.96.pth" + "unireplknet_xl_22k_to_1k": "unireplknet_xl_in22k_to_in1k_384_acc87.96.pth", } + def load_with_key(model, key): # if huggingface hub is found, download from our huggingface repo if hf_hub_download is not None: - repo_id = 'DingXiaoH/UniRepLKNet' - cache_file = hf_hub_download(repo_id=repo_id, filename=huggingface_file_names[key]) - checkpoint = torch.load(cache_file, map_location='cpu') + repo_id = "DingXiaoH/UniRepLKNet" + cache_file = hf_hub_download( + repo_id=repo_id, filename=huggingface_file_names[key] + ) + checkpoint = torch.load(cache_file, map_location="cpu") else: - checkpoint = torch.hub.load_state_dict_from_url(url=model_urls[key], map_location="cpu", check_hash=True) - if 'model' in checkpoint: - checkpoint = checkpoint['model'] + checkpoint = torch.hub.load_state_dict_from_url( + url=model_urls[key], map_location="cpu", check_hash=True + ) + if "model" in checkpoint: + checkpoint = checkpoint["model"] model.load_state_dict(checkpoint) -def initialize_with_pretrained(model, model_name, in_1k_pretrained, in_22k_pretrained, in_22k_to_1k): + +def initialize_with_pretrained( + model, model_name, in_1k_pretrained, in_22k_pretrained, in_22k_to_1k +): if in_1k_pretrained: - key = model_name + '_1k' + key = model_name + "_1k" elif in_22k_pretrained: - key = model_name + '_22k' + key = model_name + "_22k" elif in_22k_to_1k: - key = model_name + '_22k_to_1k' + key = model_name + "_22k_to_1k" else: key = None if key: @@ -563,123 +803,220 @@ def initialize_with_pretrained(model, model_name, in_1k_pretrained, in_22k_pretr def _create_unireplknet(variant, pretrained=False, **kwargs): return build_model_with_cfg( - UniRepLKNet, variant, pretrained, + UniRepLKNet, + variant, + pretrained, feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True), **kwargs, ) -def _cfg(url='', hf_hub_id='', hf_hub_filename='', **kwargs): + +def _cfg(url="", hf_hub_id="", hf_hub_filename="", **kwargs): return { - 'url': url, - 'hf_hub_id': hf_hub_id, - 'hf_hub_filename': hf_hub_filename, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), - 'crop_pct': 0.875, 'interpolation': 'bicubic', - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'downsample_layers.0.0', 'classifier': 'head', - 'license': 'apache-2.0', - **kwargs -} + "url": url, + "hf_hub_id": hf_hub_id, + "hf_hub_filename": hf_hub_filename, + "num_classes": 1000, + "input_size": (3, 224, 224), + "pool_size": (7, 7), + "crop_pct": 0.875, + "interpolation": "bicubic", + "mean": IMAGENET_DEFAULT_MEAN, + "std": IMAGENET_DEFAULT_STD, + "first_conv": "downsample_layers.0.0", + "classifier": "head", + "license": "apache-2.0", + **kwargs, + } -def _cfg_22k(url='', hf_hub_id='', hf_hub_filename='', **kwargs): +def _cfg_22k(url="", hf_hub_id="", hf_hub_filename="", **kwargs): return _cfg(url, hf_hub_id, hf_hub_filename, num_classes=21841, **kwargs) -def _cfg_384(url='', hf_hub_id='', hf_hub_filename='', **kwargs): - return _cfg(url, hf_hub_id, hf_hub_filename, input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, **kwargs) - - -default_cfgs = generate_default_cfgs({ - 'unireplknet_a.in1k': _cfg( - url='https://huggingface.co/DingXiaoH/UniRepLKNet/resolve/main/unireplknet_a_in1k_224_acc77.03.pth', - hf_hub_id='DingXiaoH/UniRepLKNet', hf_hub_filename='unireplknet_a_in1k_224_acc77.03.pth'), - 'unireplknet_f.in1k': _cfg( - url='https://huggingface.co/DingXiaoH/UniRepLKNet/resolve/main/unireplknet_f_in1k_224_acc78.58.pth', - hf_hub_id='DingXiaoH/UniRepLKNet', hf_hub_filename='unireplknet_f_in1k_224_acc78.58.pth'), - 'unireplknet_p.in1k': _cfg( - url='https://huggingface.co/DingXiaoH/UniRepLKNet/resolve/main/unireplknet_p_in1k_224_acc80.23.pth', - hf_hub_id='DingXiaoH/UniRepLKNet', hf_hub_filename='unireplknet_p_in1k_224_acc80.23.pth'), - 'unireplknet_n.in1k': _cfg( - url='https://huggingface.co/DingXiaoH/UniRepLKNet/resolve/main/unireplknet_n_in1k_224_acc81.64.pth', - hf_hub_id='DingXiaoH/UniRepLKNet', hf_hub_filename='unireplknet_n_in1k_224_acc81.64.pth'), - 'unireplknet_t.in1k': _cfg( - url='https://huggingface.co/DingXiaoH/UniRepLKNet/resolve/main/unireplknet_t_in1k_224_acc83.21.pth', - hf_hub_id='DingXiaoH/UniRepLKNet', hf_hub_filename='unireplknet_t_in1k_224_acc83.21.pth'), - 'unireplknet_s.in1k': _cfg( - url='https://huggingface.co/DingXiaoH/UniRepLKNet/resolve/main/unireplknet_s_in1k_224_acc83.91.pth', - hf_hub_id='DingXiaoH/UniRepLKNet', hf_hub_filename='unireplknet_s_in1k_224_acc83.91.pth'), - 'unireplknet_s.in22k': _cfg_22k( - url='https://huggingface.co/DingXiaoH/UniRepLKNet/resolve/main/unireplknet_s_in22k_pretrain.pth', - hf_hub_id='DingXiaoH/UniRepLKNet', hf_hub_filename='unireplknet_s_in22k_pretrain.pth'), - 'unireplknet_s.in22k_ft_in1k': _cfg_384( - url='https://huggingface.co/DingXiaoH/UniRepLKNet/resolve/main/unireplknet_s_in22k_to_in1k_384_acc86.44.pth', - hf_hub_id='DingXiaoH/UniRepLKNet', hf_hub_filename='unireplknet_s_in22k_to_in1k_384_acc86.44.pth'), - 'unireplknet_b.in22k': _cfg_22k( - url='https://huggingface.co/DingXiaoH/UniRepLKNet/resolve/main/unireplknet_b_in22k_pretrain.pth', - hf_hub_id='DingXiaoH/UniRepLKNet', hf_hub_filename='unireplknet_b_in22k_pretrain.pth'), - 'unireplknet_b.in22k_ft_in1k': _cfg_384( - url='https://huggingface.co/DingXiaoH/UniRepLKNet/resolve/main/unireplknet_b_in22k_to_in1k_384_acc87.40.pth', - hf_hub_id='DingXiaoH/UniRepLKNet', hf_hub_filename='unireplknet_b_in22k_to_in1k_384_acc87.40.pth'), - 'unireplknet_l.in22k': _cfg_22k( - url='https://huggingface.co/DingXiaoH/UniRepLKNet/resolve/main/unireplknet_l_in22k_pretrain.pth', - hf_hub_id='DingXiaoH/UniRepLKNet', hf_hub_filename='unireplknet_l_in22k_pretrain.pth'), - 'unireplknet_l.in22k_ft_in1k': _cfg_384( - url='https://huggingface.co/DingXiaoH/UniRepLKNet/resolve/main/unireplknet_l_in22k_to_in1k_384_acc87.88.pth', - hf_hub_id='DingXiaoH/UniRepLKNet', hf_hub_filename='unireplknet_l_in22k_to_in1k_384_acc87.88.pth'), - 'unireplknet_xl.in22k': _cfg_22k( - url='https://huggingface.co/DingXiaoH/UniRepLKNet/resolve/main/unireplknet_xl_in22k_pretrain.pth', - hf_hub_id='DingXiaoH/UniRepLKNet', hf_hub_filename='unireplknet_xl_in22k_pretrain.pth'), - 'unireplknet_xl.in22k_ft_in1k': _cfg_384( - url='https://huggingface.co/DingXiaoH/UniRepLKNet/resolve/main/unireplknet_xl_in22k_to_in1k_384_acc87.96.pth', - hf_hub_id='DingXiaoH/UniRepLKNet', hf_hub_filename='unireplknet_xl_in22k_to_in1k_384_acc87.96.pth'), -}) +def _cfg_384(url="", hf_hub_id="", hf_hub_filename="", **kwargs): + return _cfg( + url, + hf_hub_id, + hf_hub_filename, + input_size=(3, 384, 384), + pool_size=(12, 12), + crop_pct=1.0, + **kwargs, + ) + + +default_cfgs = generate_default_cfgs( + { + "unireplknet_a.in1k": _cfg( + url="https://huggingface.co/DingXiaoH/UniRepLKNet/resolve/main/unireplknet_a_in1k_224_acc77.03.pth", + hf_hub_id="DingXiaoH/UniRepLKNet", + hf_hub_filename="unireplknet_a_in1k_224_acc77.03.pth", + ), + "unireplknet_f.in1k": _cfg( + url="https://huggingface.co/DingXiaoH/UniRepLKNet/resolve/main/unireplknet_f_in1k_224_acc78.58.pth", + hf_hub_id="DingXiaoH/UniRepLKNet", + hf_hub_filename="unireplknet_f_in1k_224_acc78.58.pth", + ), + "unireplknet_p.in1k": _cfg( + url="https://huggingface.co/DingXiaoH/UniRepLKNet/resolve/main/unireplknet_p_in1k_224_acc80.23.pth", + hf_hub_id="DingXiaoH/UniRepLKNet", + hf_hub_filename="unireplknet_p_in1k_224_acc80.23.pth", + ), + "unireplknet_n.in1k": _cfg( + url="https://huggingface.co/DingXiaoH/UniRepLKNet/resolve/main/unireplknet_n_in1k_224_acc81.64.pth", + hf_hub_id="DingXiaoH/UniRepLKNet", + hf_hub_filename="unireplknet_n_in1k_224_acc81.64.pth", + ), + "unireplknet_t.in1k": _cfg( + url="https://huggingface.co/DingXiaoH/UniRepLKNet/resolve/main/unireplknet_t_in1k_224_acc83.21.pth", + hf_hub_id="DingXiaoH/UniRepLKNet", + hf_hub_filename="unireplknet_t_in1k_224_acc83.21.pth", + ), + "unireplknet_s.in1k": _cfg( + url="https://huggingface.co/DingXiaoH/UniRepLKNet/resolve/main/unireplknet_s_in1k_224_acc83.91.pth", + hf_hub_id="DingXiaoH/UniRepLKNet", + hf_hub_filename="unireplknet_s_in1k_224_acc83.91.pth", + ), + "unireplknet_s.in22k": _cfg_22k( + url="https://huggingface.co/DingXiaoH/UniRepLKNet/resolve/main/unireplknet_s_in22k_pretrain.pth", + hf_hub_id="DingXiaoH/UniRepLKNet", + hf_hub_filename="unireplknet_s_in22k_pretrain.pth", + ), + "unireplknet_s.in22k_ft_in1k": _cfg_384( + url="https://huggingface.co/DingXiaoH/UniRepLKNet/resolve/main/unireplknet_s_in22k_to_in1k_384_acc86.44.pth", + hf_hub_id="DingXiaoH/UniRepLKNet", + hf_hub_filename="unireplknet_s_in22k_to_in1k_384_acc86.44.pth", + ), + "unireplknet_b.in22k": _cfg_22k( + url="https://huggingface.co/DingXiaoH/UniRepLKNet/resolve/main/unireplknet_b_in22k_pretrain.pth", + hf_hub_id="DingXiaoH/UniRepLKNet", + hf_hub_filename="unireplknet_b_in22k_pretrain.pth", + ), + "unireplknet_b.in22k_ft_in1k": _cfg_384( + url="https://huggingface.co/DingXiaoH/UniRepLKNet/resolve/main/unireplknet_b_in22k_to_in1k_384_acc87.40.pth", + hf_hub_id="DingXiaoH/UniRepLKNet", + hf_hub_filename="unireplknet_b_in22k_to_in1k_384_acc87.40.pth", + ), + "unireplknet_l.in22k": _cfg_22k( + url="https://huggingface.co/DingXiaoH/UniRepLKNet/resolve/main/unireplknet_l_in22k_pretrain.pth", + hf_hub_id="DingXiaoH/UniRepLKNet", + hf_hub_filename="unireplknet_l_in22k_pretrain.pth", + ), + "unireplknet_l.in22k_ft_in1k": _cfg_384( + url="https://huggingface.co/DingXiaoH/UniRepLKNet/resolve/main/unireplknet_l_in22k_to_in1k_384_acc87.88.pth", + hf_hub_id="DingXiaoH/UniRepLKNet", + hf_hub_filename="unireplknet_l_in22k_to_in1k_384_acc87.88.pth", + ), + "unireplknet_xl.in22k": _cfg_22k( + url="https://huggingface.co/DingXiaoH/UniRepLKNet/resolve/main/unireplknet_xl_in22k_pretrain.pth", + hf_hub_id="DingXiaoH/UniRepLKNet", + hf_hub_filename="unireplknet_xl_in22k_pretrain.pth", + ), + "unireplknet_xl.in22k_ft_in1k": _cfg_384( + url="https://huggingface.co/DingXiaoH/UniRepLKNet/resolve/main/unireplknet_xl_in22k_to_in1k_384_acc87.96.pth", + hf_hub_id="DingXiaoH/UniRepLKNet", + hf_hub_filename="unireplknet_xl_in22k_to_in1k_384_acc87.96.pth", + ), + } +) @register_model def unireplknet_a(pretrained: bool = False, **kwargs) -> UniRepLKNet: - model = _create_unireplknet('unireplknet_a', pretrained=pretrained, depths=UniRepLKNet_A_F_P_depths, dims=(40, 80, 160, 320), **kwargs) + model = _create_unireplknet( + "unireplknet_a", + pretrained=pretrained, + depths=UniRepLKNet_A_F_P_depths, + dims=(40, 80, 160, 320), + **kwargs, + ) return model @register_model def unireplknet_f(pretrained: bool = False, **kwargs) -> UniRepLKNet: - return _create_unireplknet('unireplknet_f', pretrained=pretrained, depths=UniRepLKNet_A_F_P_depths, dims=(48, 96, 192, 384), **kwargs) + return _create_unireplknet( + "unireplknet_f", + pretrained=pretrained, + depths=UniRepLKNet_A_F_P_depths, + dims=(48, 96, 192, 384), + **kwargs, + ) @register_model def unireplknet_p(pretrained: bool = False, **kwargs) -> UniRepLKNet: - return _create_unireplknet('unireplknet_p', pretrained=pretrained, depths=UniRepLKNet_A_F_P_depths, dims=(64, 128, 256, 512), **kwargs) + return _create_unireplknet( + "unireplknet_p", + pretrained=pretrained, + depths=UniRepLKNet_A_F_P_depths, + dims=(64, 128, 256, 512), + **kwargs, + ) @register_model def unireplknet_n(pretrained: bool = False, **kwargs) -> UniRepLKNet: - return _create_unireplknet('unireplknet_n', pretrained=pretrained, depths=UniRepLKNet_N_depths, dims=(80, 160, 320, 640), **kwargs) + return _create_unireplknet( + "unireplknet_n", + pretrained=pretrained, + depths=UniRepLKNet_N_depths, + dims=(80, 160, 320, 640), + **kwargs, + ) @register_model def unireplknet_t(pretrained: bool = False, **kwargs) -> UniRepLKNet: - return _create_unireplknet('unireplknet_t', pretrained=pretrained, depths=UniRepLKNet_T_depths, dims=(80, 160, 320, 640), **kwargs) + return _create_unireplknet( + "unireplknet_t", + pretrained=pretrained, + depths=UniRepLKNet_T_depths, + dims=(80, 160, 320, 640), + **kwargs, + ) @register_model def unireplknet_s(pretrained: bool = False, **kwargs) -> UniRepLKNet: - return _create_unireplknet('unireplknet_s', pretrained=pretrained, depths=UniRepLKNet_S_B_L_XL_depths, dims=(96, 192, 384, 768), **kwargs) + return _create_unireplknet( + "unireplknet_s", + pretrained=pretrained, + depths=UniRepLKNet_S_B_L_XL_depths, + dims=(96, 192, 384, 768), + **kwargs, + ) @register_model def unireplknet_b(pretrained: bool = False, **kwargs) -> UniRepLKNet: - return _create_unireplknet('unireplknet_b', pretrained=pretrained, depths=UniRepLKNet_S_B_L_XL_depths, dims=(128, 256, 512, 1024), **kwargs) + return _create_unireplknet( + "unireplknet_b", + pretrained=pretrained, + depths=UniRepLKNet_S_B_L_XL_depths, + dims=(128, 256, 512, 1024), + **kwargs, + ) @register_model def unireplknet_l(pretrained: bool = False, **kwargs) -> UniRepLKNet: - return _create_unireplknet('unireplknet_l', pretrained=pretrained, depths=UniRepLKNet_S_B_L_XL_depths, dims=(192, 384, 768, 1536), **kwargs) + return _create_unireplknet( + "unireplknet_l", + pretrained=pretrained, + depths=UniRepLKNet_S_B_L_XL_depths, + dims=(192, 384, 768, 1536), + **kwargs, + ) @register_model def unireplknet_xl(pretrained: bool = False, **kwargs) -> UniRepLKNet: - return _create_unireplknet('unireplknet_xl', pretrained=pretrained, depths=UniRepLKNet_S_B_L_XL_depths, dims=(256, 512, 1024, 2048), **kwargs) - - - + return _create_unireplknet( + "unireplknet_xl", + pretrained=pretrained, + depths=UniRepLKNet_S_B_L_XL_depths, + dims=(256, 512, 1024, 2048), + **kwargs, + )