Skip to content

Commit 7b561f5

Browse files
committed
stem fix in_chans
1 parent fbd5b00 commit 7b561f5

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

src/model_constructor/model_constructor.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -130,18 +130,19 @@ def forward(self, x):
130130

131131

132132
def make_stem(cfg: TModelCfg) -> nn.Sequential: # type: ignore
133+
len_stem = len(cfg.stem_sizes)
133134
stem: List[tuple[str, nn.Module]] = [
134135
(f"conv_{i}", cfg.conv_layer(
135-
cfg.stem_sizes[i], # type: ignore
136-
cfg.stem_sizes[i + 1],
136+
cfg.stem_sizes[i - 1] if i else cfg.in_chans, # type: ignore
137+
cfg.stem_sizes[i],
137138
stride=2 if i == cfg.stem_stride_on else 1,
138139
bn_layer=(not cfg.stem_bn_end)
139-
if i == (len(cfg.stem_sizes) - 2)
140+
if i == (len_stem - 1)
140141
else True,
141142
act_fn=cfg.act_fn,
142143
bn_1st=cfg.bn_1st,
143144
),)
144-
for i in range(len(cfg.stem_sizes) - 1)
145+
for i in range(len_stem)
145146
]
146147
if cfg.stem_pool:
147148
stem.append(("stem_pool", cfg.stem_pool()))
@@ -262,8 +263,6 @@ class ModelConstructor(ModelCfg):
262263

263264
@root_validator
264265
def post_init(cls, values): # pylint: disable=E0213
265-
if values["stem_sizes"][0] != values["in_chans"]:
266-
values["stem_sizes"] = [values["in_chans"]] + values["stem_sizes"]
267266
if values["se"] and isinstance(values["se"], (bool, int)): # if se=1 or se=True
268267
values["se"] = SEModule
269268
if values["sa"] and isinstance(values["sa"], (bool, int)): # if sa=1 or sa=True

0 commit comments

Comments
 (0)