@@ -97,17 +97,17 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim,
             )
         else:
             self.encoder = Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity,
-                                   kernel_sizes=[4, 4, 4],
-                                   channel_step=48 * (self.n_dim // 192),
-                                   double_conv=True,
+                                   kernel_sizes=[4, 4],
+                                   channel_step=48 * (self.n_dim // 192) * 2,
+                                   post_conv_num=3,
                                    flatten_output=False)
 
         self.slot_attention = SlotAttention(slots_num, self.n_dim, slots_iter_num, use_prev_slots)
         self.register_buffer('pos_enc', torch.from_numpy(get_position_encoding(self.slots_num, self.state_size // slots_num)).to(dtype=torch.float32))
         if self.encode_vit:
             self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (4, 4))
         else:
-            self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (6, 6))
+            self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (14, 14))
 
         self.slot_mlp = nn.Sequential(nn.Linear(self.n_dim, self.n_dim),
                                       nn.ReLU(inplace=True),
@@ -116,8 +116,8 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim,
         if decode_vit:
             self.dino_predictor = Decoder(rssm_dim + latent_dim * latent_classes,
                                           norm_layer=nn.GroupNorm if layer_norm else nn.Identity,
-                                          conv_kernel_sizes=[],
-                                          channel_step=self.vit_feat_dim,
+                                          conv_kernel_sizes=[3, 3],
+                                          channel_step=2 * self.vit_feat_dim,
                                           kernel_sizes=self.decoder_kernels,
                                           output_channels=self.vit_feat_dim + 1,
                                           return_dist=False)
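
A rough consistency check of the two encoder-related changes, as a sketch only: assuming each entry in kernel_sizes is a stride-2 convolution with no padding and the encoder input is 64x64 (assumptions not visible in this diff), dropping one conv layer changes the output feature map from 6x6 to 14x14, which matches the PositionalEmbedding grid change from (6, 6) to (14, 14).

# Sanity-check sketch, not part of the commit. Assumes stride-2 convs,
# no padding, and a 64x64 input image; none of these are confirmed by
# the diff itself.

def conv_out(size, kernel, stride=2, padding=0):
    # Standard output-size formula for one conv layer.
    return (size + 2 * padding - kernel) // stride + 1

def spatial_size(input_size, kernel_sizes):
    size = input_size
    for k in kernel_sizes:
        size = conv_out(size, k)
    return size

print(spatial_size(64, [4, 4, 4]))  # 6  -> old PositionalEmbedding grid (6, 6)
print(spatial_size(64, [4, 4]))     # 14 -> new PositionalEmbedding grid (14, 14)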