@@ -128,20 +128,23 @@ def preprocess_input(x, **kwargs):
     return _preprocess_input(x, mode='torch', **kwargs)


-def swish(x):
-    """Swish activation function: x * sigmoid(x).
-    Reference: [Searching for Activation Functions](https://arxiv.org/abs/1710.05941)
-    """
+def get_swish(**kwargs):
+    backend, layers, models, keras_utils = get_submodules_from_kwargs(kwargs)
+    def swish(x):
+        """Swish activation function: x * sigmoid(x).
+        Reference: [Searching for Activation Functions](https://arxiv.org/abs/1710.05941)
+        """

-    if backend.backend() == 'tensorflow':
-        try:
-            # The native TF implementation has a more
-            # memory-efficient gradient implementation
-            return backend.tf.nn.swish(x)
-        except AttributeError:
-            pass
+        if backend.backend() == 'tensorflow':
+            try:
+                # The native TF implementation has a more
+                # memory-efficient gradient implementation
+                return backend.tf.nn.swish(x)
+            except AttributeError:
+                pass

-    return x * backend.sigmoid(x)
+        return x * backend.sigmoid(x)
+    return swish


 def get_dropout(**kwargs):
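Note on the hunk above: `swish` used to be a module-level function that relied on a module-level `backend`, whereas `get_swish(**kwargs)` returns a closure bound to whichever submodules the caller injects through `get_submodules_from_kwargs`. A minimal usage sketch, assuming the usual `backend`/`layers`/`models`/`utils` keyword names and plain Keras as the injected framework (both are assumptions, not shown in this diff):

import keras

# Hedged sketch: build a swish closure bound to the stock Keras submodules.
swish = get_swish(backend=keras.backend,
                  layers=keras.layers,
                  models=keras.models,
                  utils=keras.utils)

inp = keras.layers.Input(shape=(8,))
out = keras.layers.Activation(swish)(inp)  # x * sigmoid(x); tf.nn.swish on the TF backend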
@@ -185,7 +188,7 @@ def round_repeats(repeats, depth_coefficient):
     return int(math.ceil(depth_coefficient * repeats))


-def mb_conv_block(inputs, block_args, drop_rate=None, relu_fn=swish, prefix='', ):
+def mb_conv_block(inputs, block_args, activation, drop_rate=None, prefix='', ):
     """Mobile Inverted Residual Bottleneck."""

     has_se = (block_args.se_ratio is not None) and (0 < block_args.se_ratio <= 1)
@@ -208,7 +211,7 @@ def mb_conv_block(inputs, block_args, drop_rate=None, relu_fn=swish, prefix='',
                           kernel_initializer=CONV_KERNEL_INITIALIZER,
                           name=prefix + 'expand_conv')(inputs)
         x = layers.BatchNormalization(axis=bn_axis, name=prefix + 'expand_bn')(x)
-        x = layers.Activation(relu_fn, name=prefix + 'expand_activation')(x)
+        x = layers.Activation(activation, name=prefix + 'expand_activation')(x)
     else:
         x = inputs

@@ -220,7 +223,7 @@ def mb_conv_block(inputs, block_args, drop_rate=None, relu_fn=swish, prefix='',
                                depthwise_initializer=CONV_KERNEL_INITIALIZER,
                                name=prefix + 'dwconv')(x)
     x = layers.BatchNormalization(axis=bn_axis, name=prefix + 'bn')(x)
-    x = layers.Activation(relu_fn, name=prefix + 'activation')(x)
+    x = layers.Activation(activation, name=prefix + 'activation')(x)

     # Squeeze and Excitation phase
     if has_se:
@@ -232,7 +235,7 @@ def mb_conv_block(inputs, block_args, drop_rate=None, relu_fn=swish, prefix='',
         target_shape = (1, 1, filters) if backend.image_data_format() == 'channels_last' else (filters, 1, 1)
         se_tensor = layers.Reshape(target_shape, name=prefix + 'se_reshape')(se_tensor)
         se_tensor = layers.Conv2D(num_reduced_filters, 1,
-                                  activation=relu_fn,
+                                  activation=activation,
                                   padding='same',
                                   use_bias=True,
                                   kernel_initializer=CONV_KERNEL_INITIALIZER,
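As the hunks above show, `mb_conv_block` drops the `relu_fn=swish` default and takes the activation as an explicit argument that is threaded through the expand, depthwise, and squeeze-and-excite stages. A hedged sketch of a call site after this change (the tensor, block args, and drop rate are illustrative):

# Hedged sketch: the activation closure is created once and passed per block.
activation = get_swish(**kwargs)           # kwargs carry the injected submodules
x = mb_conv_block(x, block_args,
                  activation=activation,   # was: relu_fn=swish
                  drop_rate=0.1,
                  prefix='block1a_')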
@@ -278,7 +281,6 @@ def EfficientNet(width_coefficient,
                  dropout_rate=0.2,
                  drop_connect_rate=0.2,
                  depth_divisor=8,
-                 relu_fn=swish,
                  blocks_args=DEFAULT_BLOCKS_ARGS,
                  model_name='efficientnet',
                  include_top=True,
@@ -367,6 +369,7 @@ def EfficientNet(width_coefficient,
             img_input = input_tensor

     bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1
+    activation = get_swish(**kwargs)

     # Build stem
     x = img_input
@@ -377,7 +380,7 @@ def EfficientNet(width_coefficient,
                       kernel_initializer=CONV_KERNEL_INITIALIZER,
                       name='stem_conv')(x)
     x = layers.BatchNormalization(axis=bn_axis, name='stem_bn')(x)
-    x = layers.Activation(relu_fn, name='stem_activation')(x)
+    x = layers.Activation(activation, name='stem_activation')(x)

     # Build blocks
     num_blocks_total = sum(block_args.num_repeat for block_args in blocks_args)
@@ -395,7 +398,7 @@ def EfficientNet(width_coefficient,
         # The first block needs to take care of stride and filter size increase.
         drop_rate = drop_connect_rate * float(block_num) / num_blocks_total
         x = mb_conv_block(x, block_args,
-                          relu_fn=relu_fn,
+                          activation=activation,
                           drop_rate=drop_rate,
                           prefix='block{}a_'.format(idx + 1))
         block_num += 1
@@ -411,7 +414,7 @@ def EfficientNet(width_coefficient,
                     string.ascii_lowercase[bidx + 1]
                 )
                 x = mb_conv_block(x, block_args,
-                                  relu_fn=relu_fn,
+                                  activation=activation,
                                   drop_rate=drop_rate,
                                   prefix=block_prefix)
                 block_num += 1
@@ -423,7 +426,7 @@ def EfficientNet(width_coefficient,
                       kernel_initializer=CONV_KERNEL_INITIALIZER,
                       name='top_conv')(x)
     x = layers.BatchNormalization(axis=bn_axis, name='top_bn')(x)
-    x = layers.Activation(relu_fn, name='top_activation')(x)
+    x = layers.Activation(activation, name='top_activation')(x)
     if include_top:
         x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
         if dropout_rate and dropout_rate > 0:
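With `relu_fn` removed from the public signature, `EfficientNet` now derives the activation internally via the `activation = get_swish(**kwargs)` line added above, so callers only supply the scaling coefficients and the usual keyword arguments. A hedged construction sketch using the commonly cited B0 values (the name of the resolution parameter is not visible in this diff, so the first three arguments are passed positionally):

# Hedged sketch: B0-sized model after the change; note there is no relu_fn argument.
# 1.0, 1.0, 224 are the usual B0 width/depth coefficients and input resolution.
model = EfficientNet(1.0, 1.0, 224,
                     dropout_rate=0.2,
                     model_name='efficientnet-b0',
                     include_top=True)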