Skip to content

Commit 7b18ea8

Browse files
committed
black
1 parent 6cc3050 commit 7b18ea8

File tree

4 files changed

+168
-118
lines changed

4 files changed

+168
-118
lines changed

src/model_constructor/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
from model_constructor.convmixer import ConvMixer # noqa F401
2-
from model_constructor.model_constructor import ModelConstructor, ResBlock, ModelCfg # noqa F401
3-
2+
from model_constructor.model_constructor import (
3+
ModelConstructor,
4+
ResBlock,
5+
ModelCfg,
6+
) # noqa F401
7+
48
from model_constructor.version import __version__ # noqa F401

src/model_constructor/model_constructor.py

Lines changed: 96 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -56,51 +56,66 @@ def __init__(
5656
groups = int(mid_channels / div_groups)
5757
if expansion == 1:
5858
layers = [
59-
("conv_0", conv_layer(
60-
in_channels,
61-
mid_channels,
62-
3,
63-
stride=stride, # type: ignore
64-
act_fn=act_fn,
65-
bn_1st=bn_1st,
66-
groups=in_channels if dw else groups,
67-
),),
68-
("conv_1", conv_layer(
69-
mid_channels,
70-
out_channels,
71-
3,
72-
zero_bn=zero_bn,
73-
act_fn=False,
74-
bn_1st=bn_1st,
75-
groups=mid_channels if dw else groups,
76-
),),
59+
(
60+
"conv_0",
61+
conv_layer(
62+
in_channels,
63+
mid_channels,
64+
3,
65+
stride=stride, # type: ignore
66+
act_fn=act_fn,
67+
bn_1st=bn_1st,
68+
groups=in_channels if dw else groups,
69+
),
70+
),
71+
(
72+
"conv_1",
73+
conv_layer(
74+
mid_channels,
75+
out_channels,
76+
3,
77+
zero_bn=zero_bn,
78+
act_fn=False,
79+
bn_1st=bn_1st,
80+
groups=mid_channels if dw else groups,
81+
),
82+
),
7783
]
7884
else:
7985
layers = [
80-
("conv_0", conv_layer(
81-
in_channels,
82-
mid_channels,
83-
1,
84-
act_fn=act_fn,
85-
bn_1st=bn_1st,
86-
),),
87-
("conv_1", conv_layer(
88-
mid_channels,
89-
mid_channels,
90-
3,
91-
stride=stride,
92-
act_fn=act_fn,
93-
bn_1st=bn_1st,
94-
groups=mid_channels if dw else groups,
95-
),),
96-
("conv_2", conv_layer(
97-
mid_channels,
98-
out_channels,
99-
1,
100-
zero_bn=zero_bn,
101-
act_fn=False,
102-
bn_1st=bn_1st,
103-
),), # noqa E501
86+
(
87+
"conv_0",
88+
conv_layer(
89+
in_channels,
90+
mid_channels,
91+
1,
92+
act_fn=act_fn,
93+
bn_1st=bn_1st,
94+
),
95+
),
96+
(
97+
"conv_1",
98+
conv_layer(
99+
mid_channels,
100+
mid_channels,
101+
3,
102+
stride=stride,
103+
act_fn=act_fn,
104+
bn_1st=bn_1st,
105+
groups=mid_channels if dw else groups,
106+
),
107+
),
108+
(
109+
"conv_2",
110+
conv_layer(
111+
mid_channels,
112+
out_channels,
113+
1,
114+
zero_bn=zero_bn,
115+
act_fn=False,
116+
bn_1st=bn_1st,
117+
),
118+
), # noqa E501
104119
]
105120
if se:
106121
layers.append(("se", se(out_channels)))
@@ -109,16 +124,23 @@ def __init__(
109124
self.convs = nn.Sequential(OrderedDict(layers))
110125
if stride != 1 or in_channels != out_channels:
111126
id_layers = []
112-
if stride != 1 and pool is not None: # if pool - reduce by pool else stride 2 art id_conv
127+
if (
128+
stride != 1 and pool is not None
129+
): # if pool - reduce by pool else stride 2 art id_conv
113130
id_layers.append(("pool", pool()))
114131
if in_channels != out_channels or (stride != 1 and pool is None):
115-
id_layers += [("id_conv", conv_layer(
116-
in_channels,
117-
out_channels,
118-
1,
119-
stride=1 if pool else stride,
120-
act_fn=False,
121-
),)]
132+
id_layers += [
133+
(
134+
"id_conv",
135+
conv_layer(
136+
in_channels,
137+
out_channels,
138+
1,
139+
stride=1 if pool else stride,
140+
act_fn=False,
141+
),
142+
)
143+
]
122144
self.id_conv = nn.Sequential(OrderedDict(id_layers))
123145
else:
124146
self.id_conv = None
@@ -132,16 +154,17 @@ def forward(self, x):
132154
def make_stem(cfg: TModelCfg) -> nn.Sequential: # type: ignore
133155
len_stem = len(cfg.stem_sizes)
134156
stem: List[tuple[str, nn.Module]] = [
135-
(f"conv_{i}", cfg.conv_layer(
136-
cfg.stem_sizes[i - 1] if i else cfg.in_chans, # type: ignore
137-
cfg.stem_sizes[i],
138-
stride=2 if i == cfg.stem_stride_on else 1,
139-
bn_layer=(not cfg.stem_bn_end)
140-
if i == (len_stem - 1)
141-
else True,
142-
act_fn=cfg.act_fn,
143-
bn_1st=cfg.bn_1st,
144-
),)
157+
(
158+
f"conv_{i}",
159+
cfg.conv_layer(
160+
cfg.stem_sizes[i - 1] if i else cfg.in_chans, # type: ignore
161+
cfg.stem_sizes[i],
162+
stride=2 if i == cfg.stem_stride_on else 1,
163+
bn_layer=(not cfg.stem_bn_end) if i == (len_stem - 1) else True,
164+
act_fn=cfg.act_fn,
165+
bn_1st=cfg.bn_1st,
166+
),
167+
)
145168
for i in range(len_stem)
146169
]
147170
if cfg.stem_pool:
@@ -164,7 +187,9 @@ def make_layer(cfg: TModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
164187
f"bl_{block_num}",
165188
cfg.block(
166189
cfg.expansion, # type: ignore
167-
block_chs[layer_num] if block_num == 0 else block_chs[layer_num + 1],
190+
block_chs[layer_num]
191+
if block_num == 0
192+
else block_chs[layer_num + 1],
168193
block_chs[layer_num + 1],
169194
stride if block_num == 0 else 1,
170195
sa=cfg.sa
@@ -191,10 +216,7 @@ def make_body(cfg: TModelCfg) -> nn.Sequential: # type: ignore
191216
return nn.Sequential(
192217
OrderedDict(
193218
[
194-
(
195-
f"l_{layer_num}",
196-
cfg.make_layer(cfg, layer_num) # type: ignore
197-
)
219+
(f"l_{layer_num}", cfg.make_layer(cfg, layer_num)) # type: ignore
198220
for layer_num in range(len(cfg.layers))
199221
]
200222
)
@@ -222,7 +244,9 @@ class ModelCfg(BaseModel):
222244
layers: List[int] = [2, 2, 2, 2]
223245
norm: Type[nn.Module] = nn.BatchNorm2d
224246
act_fn: Type[nn.Module] = nn.ReLU
225-
pool: Callable[[Any], nn.Module] = partial(nn.AvgPool2d, kernel_size=2, ceil_mode=True)
247+
pool: Callable[[Any], nn.Module] = partial(
248+
nn.AvgPool2d, kernel_size=2, ceil_mode=True
249+
)
226250
expansion: int = 1
227251
groups: int = 1
228252
dw: bool = False
@@ -235,7 +259,9 @@ class ModelCfg(BaseModel):
235259
zero_bn: bool = True
236260
stem_stride_on: int = 0
237261
stem_sizes: List[int] = [32, 32, 64]
238-
stem_pool: Union[Callable[[], nn.Module], None] = partial(nn.MaxPool2d, kernel_size=3, stride=2, padding=1)
262+
stem_pool: Union[Callable[[], nn.Module], None] = partial(
263+
nn.MaxPool2d, kernel_size=3, stride=2, padding=1
264+
)
239265
stem_bn_end: bool = False
240266
init_cnn: Callable[[nn.Module], None] = init_cnn
241267
make_stem: Callable[[TModelCfg], Union[nn.Module, nn.Sequential]] = make_stem # type: ignore
@@ -301,7 +327,7 @@ def from_cfg(cls, cfg: ModelCfg):
301327

302328
def __call__(self):
303329
model_name = self.name or self.__class__.__name__
304-
named_sequential = type(model_name, (nn.Sequential, ), {})
330+
named_sequential = type(model_name, (nn.Sequential,), {})
305331
model = named_sequential(
306332
OrderedDict([("stem", self.stem), ("body", self.body), ("head", self.head)])
307333
)
@@ -314,7 +340,8 @@ def __call__(self):
314340
def _get_extra_repr(self) -> str:
315341
return " ".join(
316342
f"{field}: {self._get_str_value(field)},"
317-
for field in self.__fields_set__ if field != "name"
343+
for field in self.__fields_set__
344+
if field != "name"
318345
)[:-1]
319346

320347
def __repr__(self):

src/model_constructor/mxresnet.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,13 @@
44
from .net import Net
55

66

7-
__all__ = ['mxresnet_parameters', 'mxresnet34', 'mxresnet50']
7+
__all__ = ["mxresnet_parameters", "mxresnet34", "mxresnet50"]
88

99

10-
mxresnet_parameters = {'stem_sizes': [3, 32, 64, 64], 'act_fn': Mish()}
11-
mxresnet34 = partial(Net, name='MXResnet32', expansion=1, layers=[3, 4, 6, 3], **mxresnet_parameters)
12-
mxresnet50 = partial(Net, name='MXResnet50', expansion=4, layers=[3, 4, 6, 3], **mxresnet_parameters)
10+
mxresnet_parameters = {"stem_sizes": [3, 32, 64, 64], "act_fn": Mish()}
11+
mxresnet34 = partial(
12+
Net, name="MXResnet32", expansion=1, layers=[3, 4, 6, 3], **mxresnet_parameters
13+
)
14+
mxresnet50 = partial(
15+
Net, name="MXResnet50", expansion=4, layers=[3, 4, 6, 3], **mxresnet_parameters
16+
)

src/model_constructor/yaresnet.py

Lines changed: 58 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919

2020
class YaResBlock(nn.Module):
21-
'''YaResBlock. Reduce by pool instead of stride 2'''
21+
"""YaResBlock. Reduce by pool instead of stride 2"""
2222

2323
def __init__(
2424
self,
@@ -53,51 +53,66 @@ def __init__(
5353
self.reduce = None
5454
if expansion == 1:
5555
layers = [
56-
("conv_0", conv_layer(
57-
in_channels,
58-
mid_channels,
59-
3,
60-
stride=1,
61-
act_fn=act_fn,
62-
bn_1st=bn_1st,
63-
groups=in_channels if dw else groups,
64-
),),
65-
("conv_1", conv_layer(
66-
mid_channels,
67-
out_channels,
68-
3,
69-
zero_bn=zero_bn,
70-
act_fn=False,
71-
bn_1st=bn_1st,
72-
groups=mid_channels if dw else groups,
73-
),),
56+
(
57+
"conv_0",
58+
conv_layer(
59+
in_channels,
60+
mid_channels,
61+
3,
62+
stride=1,
63+
act_fn=act_fn,
64+
bn_1st=bn_1st,
65+
groups=in_channels if dw else groups,
66+
),
67+
),
68+
(
69+
"conv_1",
70+
conv_layer(
71+
mid_channels,
72+
out_channels,
73+
3,
74+
zero_bn=zero_bn,
75+
act_fn=False,
76+
bn_1st=bn_1st,
77+
groups=mid_channels if dw else groups,
78+
),
79+
),
7480
]
7581
else:
7682
layers = [
77-
("conv_0", conv_layer(
78-
in_channels,
79-
mid_channels,
80-
1,
81-
act_fn=act_fn,
82-
bn_1st=bn_1st,
83-
),),
84-
("conv_1", conv_layer(
85-
mid_channels,
86-
mid_channels,
87-
3,
88-
stride=1,
89-
act_fn=act_fn,
90-
bn_1st=bn_1st,
91-
groups=mid_channels if dw else groups,
92-
),),
93-
("conv_2", conv_layer(
94-
mid_channels,
95-
out_channels,
96-
1,
97-
zero_bn=zero_bn,
98-
act_fn=False,
99-
bn_1st=bn_1st,
100-
),), # noqa E501
83+
(
84+
"conv_0",
85+
conv_layer(
86+
in_channels,
87+
mid_channels,
88+
1,
89+
act_fn=act_fn,
90+
bn_1st=bn_1st,
91+
),
92+
),
93+
(
94+
"conv_1",
95+
conv_layer(
96+
mid_channels,
97+
mid_channels,
98+
3,
99+
stride=1,
100+
act_fn=act_fn,
101+
bn_1st=bn_1st,
102+
groups=mid_channels if dw else groups,
103+
),
104+
),
105+
(
106+
"conv_2",
107+
conv_layer(
108+
mid_channels,
109+
out_channels,
110+
1,
111+
zero_bn=zero_bn,
112+
act_fn=False,
113+
bn_1st=bn_1st,
114+
),
115+
), # noqa E501
101116
]
102117
if se:
103118
layers.append(("se", se(out_channels)))

0 commit comments

Comments
 (0)