@@ -56,51 +56,66 @@ def __init__(
56
56
groups = int (mid_channels / div_groups )
57
57
if expansion == 1 :
58
58
layers = [
59
- ("conv_0" , conv_layer (
60
- in_channels ,
61
- mid_channels ,
62
- 3 ,
63
- stride = stride , # type: ignore
64
- act_fn = act_fn ,
65
- bn_1st = bn_1st ,
66
- groups = in_channels if dw else groups ,
67
- ),),
68
- ("conv_1" , conv_layer (
69
- mid_channels ,
70
- out_channels ,
71
- 3 ,
72
- zero_bn = zero_bn ,
73
- act_fn = False ,
74
- bn_1st = bn_1st ,
75
- groups = mid_channels if dw else groups ,
76
- ),),
59
+ (
60
+ "conv_0" ,
61
+ conv_layer (
62
+ in_channels ,
63
+ mid_channels ,
64
+ 3 ,
65
+ stride = stride , # type: ignore
66
+ act_fn = act_fn ,
67
+ bn_1st = bn_1st ,
68
+ groups = in_channels if dw else groups ,
69
+ ),
70
+ ),
71
+ (
72
+ "conv_1" ,
73
+ conv_layer (
74
+ mid_channels ,
75
+ out_channels ,
76
+ 3 ,
77
+ zero_bn = zero_bn ,
78
+ act_fn = False ,
79
+ bn_1st = bn_1st ,
80
+ groups = mid_channels if dw else groups ,
81
+ ),
82
+ ),
77
83
]
78
84
else :
79
85
layers = [
80
- ("conv_0" , conv_layer (
81
- in_channels ,
82
- mid_channels ,
83
- 1 ,
84
- act_fn = act_fn ,
85
- bn_1st = bn_1st ,
86
- ),),
87
- ("conv_1" , conv_layer (
88
- mid_channels ,
89
- mid_channels ,
90
- 3 ,
91
- stride = stride ,
92
- act_fn = act_fn ,
93
- bn_1st = bn_1st ,
94
- groups = mid_channels if dw else groups ,
95
- ),),
96
- ("conv_2" , conv_layer (
97
- mid_channels ,
98
- out_channels ,
99
- 1 ,
100
- zero_bn = zero_bn ,
101
- act_fn = False ,
102
- bn_1st = bn_1st ,
103
- ),), # noqa E501
86
+ (
87
+ "conv_0" ,
88
+ conv_layer (
89
+ in_channels ,
90
+ mid_channels ,
91
+ 1 ,
92
+ act_fn = act_fn ,
93
+ bn_1st = bn_1st ,
94
+ ),
95
+ ),
96
+ (
97
+ "conv_1" ,
98
+ conv_layer (
99
+ mid_channels ,
100
+ mid_channels ,
101
+ 3 ,
102
+ stride = stride ,
103
+ act_fn = act_fn ,
104
+ bn_1st = bn_1st ,
105
+ groups = mid_channels if dw else groups ,
106
+ ),
107
+ ),
108
+ (
109
+ "conv_2" ,
110
+ conv_layer (
111
+ mid_channels ,
112
+ out_channels ,
113
+ 1 ,
114
+ zero_bn = zero_bn ,
115
+ act_fn = False ,
116
+ bn_1st = bn_1st ,
117
+ ),
118
+ ), # noqa E501
104
119
]
105
120
if se :
106
121
layers .append (("se" , se (out_channels )))
@@ -109,16 +124,23 @@ def __init__(
109
124
self .convs = nn .Sequential (OrderedDict (layers ))
110
125
if stride != 1 or in_channels != out_channels :
111
126
id_layers = []
112
- if stride != 1 and pool is not None : # if pool - reduce by pool else stride 2 art id_conv
127
+ if (
128
+ stride != 1 and pool is not None
129
+ ): # if pool - reduce by pool else stride 2 art id_conv
113
130
id_layers .append (("pool" , pool ()))
114
131
if in_channels != out_channels or (stride != 1 and pool is None ):
115
- id_layers += [("id_conv" , conv_layer (
116
- in_channels ,
117
- out_channels ,
118
- 1 ,
119
- stride = 1 if pool else stride ,
120
- act_fn = False ,
121
- ),)]
132
+ id_layers += [
133
+ (
134
+ "id_conv" ,
135
+ conv_layer (
136
+ in_channels ,
137
+ out_channels ,
138
+ 1 ,
139
+ stride = 1 if pool else stride ,
140
+ act_fn = False ,
141
+ ),
142
+ )
143
+ ]
122
144
self .id_conv = nn .Sequential (OrderedDict (id_layers ))
123
145
else :
124
146
self .id_conv = None
@@ -132,16 +154,17 @@ def forward(self, x):
132
154
def make_stem (cfg : TModelCfg ) -> nn .Sequential : # type: ignore
133
155
len_stem = len (cfg .stem_sizes )
134
156
stem : List [tuple [str , nn .Module ]] = [
135
- (f"conv_{ i } " , cfg .conv_layer (
136
- cfg .stem_sizes [i - 1 ] if i else cfg .in_chans , # type: ignore
137
- cfg .stem_sizes [i ],
138
- stride = 2 if i == cfg .stem_stride_on else 1 ,
139
- bn_layer = (not cfg .stem_bn_end )
140
- if i == (len_stem - 1 )
141
- else True ,
142
- act_fn = cfg .act_fn ,
143
- bn_1st = cfg .bn_1st ,
144
- ),)
157
+ (
158
+ f"conv_{ i } " ,
159
+ cfg .conv_layer (
160
+ cfg .stem_sizes [i - 1 ] if i else cfg .in_chans , # type: ignore
161
+ cfg .stem_sizes [i ],
162
+ stride = 2 if i == cfg .stem_stride_on else 1 ,
163
+ bn_layer = (not cfg .stem_bn_end ) if i == (len_stem - 1 ) else True ,
164
+ act_fn = cfg .act_fn ,
165
+ bn_1st = cfg .bn_1st ,
166
+ ),
167
+ )
145
168
for i in range (len_stem )
146
169
]
147
170
if cfg .stem_pool :
@@ -164,7 +187,9 @@ def make_layer(cfg: TModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
164
187
f"bl_{ block_num } " ,
165
188
cfg .block (
166
189
cfg .expansion , # type: ignore
167
- block_chs [layer_num ] if block_num == 0 else block_chs [layer_num + 1 ],
190
+ block_chs [layer_num ]
191
+ if block_num == 0
192
+ else block_chs [layer_num + 1 ],
168
193
block_chs [layer_num + 1 ],
169
194
stride if block_num == 0 else 1 ,
170
195
sa = cfg .sa
@@ -191,10 +216,7 @@ def make_body(cfg: TModelCfg) -> nn.Sequential: # type: ignore
191
216
return nn .Sequential (
192
217
OrderedDict (
193
218
[
194
- (
195
- f"l_{ layer_num } " ,
196
- cfg .make_layer (cfg , layer_num ) # type: ignore
197
- )
219
+ (f"l_{ layer_num } " , cfg .make_layer (cfg , layer_num )) # type: ignore
198
220
for layer_num in range (len (cfg .layers ))
199
221
]
200
222
)
@@ -222,7 +244,9 @@ class ModelCfg(BaseModel):
222
244
layers : List [int ] = [2 , 2 , 2 , 2 ]
223
245
norm : Type [nn .Module ] = nn .BatchNorm2d
224
246
act_fn : Type [nn .Module ] = nn .ReLU
225
- pool : Callable [[Any ], nn .Module ] = partial (nn .AvgPool2d , kernel_size = 2 , ceil_mode = True )
247
+ pool : Callable [[Any ], nn .Module ] = partial (
248
+ nn .AvgPool2d , kernel_size = 2 , ceil_mode = True
249
+ )
226
250
expansion : int = 1
227
251
groups : int = 1
228
252
dw : bool = False
@@ -235,7 +259,9 @@ class ModelCfg(BaseModel):
235
259
zero_bn : bool = True
236
260
stem_stride_on : int = 0
237
261
stem_sizes : List [int ] = [32 , 32 , 64 ]
238
- stem_pool : Union [Callable [[], nn .Module ], None ] = partial (nn .MaxPool2d , kernel_size = 3 , stride = 2 , padding = 1 )
262
+ stem_pool : Union [Callable [[], nn .Module ], None ] = partial (
263
+ nn .MaxPool2d , kernel_size = 3 , stride = 2 , padding = 1
264
+ )
239
265
stem_bn_end : bool = False
240
266
init_cnn : Callable [[nn .Module ], None ] = init_cnn
241
267
make_stem : Callable [[TModelCfg ], Union [nn .Module , nn .Sequential ]] = make_stem # type: ignore
@@ -301,7 +327,7 @@ def from_cfg(cls, cfg: ModelCfg):
301
327
302
328
def __call__ (self ):
303
329
model_name = self .name or self .__class__ .__name__
304
- named_sequential = type (model_name , (nn .Sequential , ), {})
330
+ named_sequential = type (model_name , (nn .Sequential ,), {})
305
331
model = named_sequential (
306
332
OrderedDict ([("stem" , self .stem ), ("body" , self .body ), ("head" , self .head )])
307
333
)
@@ -314,7 +340,8 @@ def __call__(self):
314
340
def _get_extra_repr (self ) -> str :
315
341
return " " .join (
316
342
f"{ field } : { self ._get_str_value (field )} ,"
317
- for field in self .__fields_set__ if field != "name"
343
+ for field in self .__fields_set__
344
+ if field != "name"
318
345
)[:- 1 ]
319
346
320
347
def __repr__ (self ):
0 commit comments