@@ -189,9 +189,9 @@ def __init__(
189189 self .unfold = nn .Unfold (kernel_size = (kernel_size , kernel_size ), stride = (kernel_size , kernel_size ))
190190 self .transform = Dct2d (kernel_size , kernel_type , orthonormal , ** factory_kwargs )
191191 self .permutation = _zigzag_permutation (kernel_size , kernel_size )
192- self .Y_Conv = nn .Conv2d (kernel_size ** 2 , 24 , kernel_size = 1 , padding = 0 )
193- self .Cb_Conv = nn .Conv2d (kernel_size ** 2 , 4 , kernel_size = 1 , padding = 0 )
194- self .Cr_Conv = nn .Conv2d (kernel_size ** 2 , 4 , kernel_size = 1 , padding = 0 )
192+ self .conv_y = nn .Conv2d (kernel_size ** 2 , 24 , kernel_size = 1 , padding = 0 )
193+ self .conv_cb = nn .Conv2d (kernel_size ** 2 , 4 , kernel_size = 1 , padding = 0 )
194+ self .conv_cr = nn .Conv2d (kernel_size ** 2 , 4 , kernel_size = 1 , padding = 0 )
195195
196196 self .register_buffer ('mean' , torch .tensor (_DCT_MEAN ), persistent = False )
197197 self .register_buffer ('var' , torch .tensor (_DCT_VAR ), persistent = False )
@@ -228,9 +228,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
228228 x = self ._frequency_normalize (x )
229229 x = x .reshape (b , h // self .k , w // self .k , c , - 1 )
230230 x = x .permute (0 , 3 , 4 , 1 , 2 ).contiguous ()
231- x_y = self .Y_Conv (x [:, 0 ])
232- x_cb = self .Cb_Conv (x [:, 1 ])
233- x_cr = self .Cr_Conv (x [:, 2 ])
231+ x_y = self .conv_y (x [:, 0 ])
232+ x_cb = self .conv_cb (x [:, 1 ])
233+ x_cr = self .conv_cr (x [:, 2 ])
234234 return torch .cat ([x_y , x_cb , x_cr ], dim = 1 )
235235
236236
@@ -697,8 +697,11 @@ def remap_stage(m):
697697
698698 out_dict = {}
699699 for k , v in state_dict .items ():
700- # Remap dct -> stem_dct
700+ # Remap dct -> stem_dct, and Y_Conv/Cb_Conv/Cr_Conv -> conv_y/conv_cb/conv_cr
701701 k = re .sub (r'^dct\.' , 'stem_dct.' , k )
702+ k = k .replace ('.Y_Conv.' , '.conv_y.' )
703+ k = k .replace ('.Cb_Conv.' , '.conv_cb.' )
704+ k = k .replace ('.Cr_Conv.' , '.conv_cr.' )
702705
703706 # Remap stage names with index adjustments for downsample relocation
704707 k = re .sub (r'^stages([1-4])\.(\d+)\.(.*)$' , remap_stage , k )
0 commit comments