Skip to content

Commit 1c6b0f5

Browse files
committed
Weird mix of camel/snake conv naming -> snake case
1 parent b0cb744 commit 1c6b0f5

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

timm/models/csatv2.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -189,9 +189,9 @@ def __init__(
189189
self.unfold = nn.Unfold(kernel_size=(kernel_size, kernel_size), stride=(kernel_size, kernel_size))
190190
self.transform = Dct2d(kernel_size, kernel_type, orthonormal, **factory_kwargs)
191191
self.permutation = _zigzag_permutation(kernel_size, kernel_size)
192-
self.Y_Conv = nn.Conv2d(kernel_size ** 2, 24, kernel_size=1, padding=0)
193-
self.Cb_Conv = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0)
194-
self.Cr_Conv = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0)
192+
self.conv_y = nn.Conv2d(kernel_size ** 2, 24, kernel_size=1, padding=0)
193+
self.conv_cb = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0)
194+
self.conv_cr = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0)
195195

196196
self.register_buffer('mean', torch.tensor(_DCT_MEAN), persistent=False)
197197
self.register_buffer('var', torch.tensor(_DCT_VAR), persistent=False)
@@ -228,9 +228,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
228228
x = self._frequency_normalize(x)
229229
x = x.reshape(b, h // self.k, w // self.k, c, -1)
230230
x = x.permute(0, 3, 4, 1, 2).contiguous()
231-
x_y = self.Y_Conv(x[:, 0])
232-
x_cb = self.Cb_Conv(x[:, 1])
233-
x_cr = self.Cr_Conv(x[:, 2])
231+
x_y = self.conv_y(x[:, 0])
232+
x_cb = self.conv_cb(x[:, 1])
233+
x_cr = self.conv_cr(x[:, 2])
234234
return torch.cat([x_y, x_cb, x_cr], dim=1)
235235

236236

@@ -697,8 +697,11 @@ def remap_stage(m):
697697

698698
out_dict = {}
699699
for k, v in state_dict.items():
700-
# Remap dct -> stem_dct
700+
# Remap dct -> stem_dct, and Y_Conv/Cb_Conv/Cr_Conv -> conv_y/conv_cb/conv_cr
701701
k = re.sub(r'^dct\.', 'stem_dct.', k)
702+
k = k.replace('.Y_Conv.', '.conv_y.')
703+
k = k.replace('.Cb_Conv.', '.conv_cb.')
704+
k = k.replace('.Cr_Conv.', '.conv_cr.')
702705

703706
# Remap stage names with index adjustments for downsample relocation
704707
k = re.sub(r'^stages([1-4])\.(\d+)\.(.*)$', remap_stage, k)

0 commit comments

Comments
 (0)