@@ -130,18 +130,19 @@ def forward(self, x):
130
130
131
131
132
132
def make_stem (cfg : TModelCfg ) -> nn .Sequential : # type: ignore
133
+ len_stem = len (cfg .stem_sizes )
133
134
stem : List [tuple [str , nn .Module ]] = [
134
135
(f"conv_{ i } " , cfg .conv_layer (
135
- cfg .stem_sizes [i ] , # type: ignore
136
- cfg .stem_sizes [i + 1 ],
136
+ cfg .stem_sizes [i - 1 ] if i else cfg . in_chans , # type: ignore
137
+ cfg .stem_sizes [i ],
137
138
stride = 2 if i == cfg .stem_stride_on else 1 ,
138
139
bn_layer = (not cfg .stem_bn_end )
139
- if i == (len ( cfg . stem_sizes ) - 2 )
140
+ if i == (len_stem - 1 )
140
141
else True ,
141
142
act_fn = cfg .act_fn ,
142
143
bn_1st = cfg .bn_1st ,
143
144
),)
144
- for i in range (len ( cfg . stem_sizes ) - 1 )
145
+ for i in range (len_stem )
145
146
]
146
147
if cfg .stem_pool :
147
148
stem .append (("stem_pool" , cfg .stem_pool ()))
@@ -262,8 +263,6 @@ class ModelConstructor(ModelCfg):
262
263
263
264
@root_validator
264
265
def post_init (cls , values ): # pylint: disable=E0213
265
- if values ["stem_sizes" ][0 ] != values ["in_chans" ]:
266
- values ["stem_sizes" ] = [values ["in_chans" ]] + values ["stem_sizes" ]
267
266
if values ["se" ] and isinstance (values ["se" ], (bool , int )): # if se=1 or se=True
268
267
values ["se" ] = SEModule
269
268
if values ["sa" ] and isinstance (values ["sa" ], (bool , int )): # if sa=1 or sa=True
0 commit comments