Commit 5a3f957

Fix swish fn (#52)
1 parent 3ff1824 commit 5a3f957

File tree

efficientnet/__init__.py
efficientnet/model.py

2 files changed: +26 −23 lines

efficientnet/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -63,7 +63,7 @@ def init_keras_custom_objects():
     from . import model
 
     custom_objects = {
-        'swish': model.swish,
+        'swish': inject_keras_modules(model.get_swish)(),
         'FixedDropout': inject_keras_modules(model.get_dropout)()
     }
 
@@ -75,7 +75,7 @@ def init_tfkeras_custom_objects():
     from . import model
 
     custom_objects = {
-        'swish': model.swish,
+        'swish': inject_tfkeras_modules(model.get_swish)(),
         'FixedDropout': inject_tfkeras_modules(model.get_dropout)()
     }
 
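Why the change: `model.swish` was a module-level function whose body referenced a `backend` module that, in keras_applications-style code, is only supplied through injected kwargs, so registering it directly could bind the activation to the wrong backend, or to none at all. Registering `inject_keras_modules(model.get_swish)()` stores a closure that has already captured the correct backend. A minimal usage sketch, assuming `init_keras_custom_objects()` pushes the dict into Keras' global registry (the registration call sits outside this hunk) and with `model.h5` as a hypothetical file path:

import efficientnet
from keras.models import load_model

# Assumption: this registers 'swish' and 'FixedDropout' globally, as the
# function name suggests, so load_model can resolve them by name.
efficientnet.init_keras_custom_objects()

model = load_model('model.h5')  # hypothetical file saved from an EfficientNet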

efficientnet/model.py

Lines changed: 24 additions & 21 deletions
@@ -128,20 +128,23 @@ def preprocess_input(x, **kwargs):
     return _preprocess_input(x, mode='torch', **kwargs)
 
 
-def swish(x):
-    """Swish activation function: x * sigmoid(x).
-    Reference: [Searching for Activation Functions](https://arxiv.org/abs/1710.05941)
-    """
+def get_swish(**kwargs):
+    backend, layers, models, keras_utils = get_submodules_from_kwargs(kwargs)
+    def swish(x):
+        """Swish activation function: x * sigmoid(x).
+        Reference: [Searching for Activation Functions](https://arxiv.org/abs/1710.05941)
+        """
 
-    if backend.backend() == 'tensorflow':
-        try:
-            # The native TF implementation has a more
-            # memory-efficient gradient implementation
-            return backend.tf.nn.swish(x)
-        except AttributeError:
-            pass
+        if backend.backend() == 'tensorflow':
+            try:
+                # The native TF implementation has a more
+                # memory-efficient gradient implementation
+                return backend.tf.nn.swish(x)
+            except AttributeError:
+                pass
 
-    return x * backend.sigmoid(x)
+        return x * backend.sigmoid(x)
+    return swish
 
 
 def get_dropout(**kwargs):
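The fix is a factory/closure pattern: `get_swish` unpacks the injected submodules once, and the returned `swish` closes over that `backend`. A minimal, self-contained illustration of the pattern, using stub names rather than the package's real helpers:

import math

class _StubBackend:
    """Hypothetical stand-in for an injected Keras backend module."""

    @staticmethod
    def backend():
        return 'stub'  # not 'tensorflow', so the pure fallback below runs

    @staticmethod
    def sigmoid(x):
        return 1.0 / (1.0 + math.exp(-x))

def get_swish(**kwargs):
    backend = kwargs['backend']  # simplified stand-in for get_submodules_from_kwargs

    def swish(x):
        # Mirrors the committed logic: prefer the native TF op when present,
        # otherwise fall back to x * sigmoid(x).
        if backend.backend() == 'tensorflow':
            try:
                return backend.tf.nn.swish(x)
            except AttributeError:
                pass
        return x * backend.sigmoid(x)

    return swish

swish = get_swish(backend=_StubBackend)
print(round(swish(1.0), 3))  # 0.731, i.e. 1 * sigmoid(1)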
@@ -185,7 +188,7 @@ def round_repeats(repeats, depth_coefficient):
     return int(math.ceil(depth_coefficient * repeats))
 
 
-def mb_conv_block(inputs, block_args, drop_rate=None, relu_fn=swish, prefix='', ):
+def mb_conv_block(inputs, block_args, activation, drop_rate=None, prefix='', ):
     """Mobile Inverted Residual Bottleneck."""
 
     has_se = (block_args.se_ratio is not None) and (0 < block_args.se_ratio <= 1)
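One caveat for downstream callers (an observation about the hunk above, not text from the commit): `activation` now occupies the positional slot before `drop_rate` and has no default, so external code that called `mb_conv_block` positionally, or that relied on the old `relu_fn=swish` default, must be updated. A hypothetical external call site, adjusted for the new signature:

x = mb_conv_block(x, block_args,
                  activation=get_swish(**kwargs),  # was relu_fn=swish by default
                  drop_rate=drop_rate,
                  prefix='block1a_')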
@@ -208,7 +211,7 @@ def mb_conv_block(inputs, block_args, drop_rate=None, relu_fn=swish, prefix='',
                           kernel_initializer=CONV_KERNEL_INITIALIZER,
                           name=prefix + 'expand_conv')(inputs)
         x = layers.BatchNormalization(axis=bn_axis, name=prefix + 'expand_bn')(x)
-        x = layers.Activation(relu_fn, name=prefix + 'expand_activation')(x)
+        x = layers.Activation(activation, name=prefix + 'expand_activation')(x)
     else:
         x = inputs
 

@@ -220,7 +223,7 @@ def mb_conv_block(inputs, block_args, drop_rate=None, relu_fn=swish, prefix='',
                                depthwise_initializer=CONV_KERNEL_INITIALIZER,
                                name=prefix + 'dwconv')(x)
     x = layers.BatchNormalization(axis=bn_axis, name=prefix + 'bn')(x)
-    x = layers.Activation(relu_fn, name=prefix + 'activation')(x)
+    x = layers.Activation(activation, name=prefix + 'activation')(x)
 
     # Squeeze and Excitation phase
     if has_se:
@@ -232,7 +235,7 @@ def mb_conv_block(inputs, block_args, drop_rate=None, relu_fn=swish, prefix='',
         target_shape = (1, 1, filters) if backend.image_data_format() == 'channels_last' else (filters, 1, 1)
         se_tensor = layers.Reshape(target_shape, name=prefix + 'se_reshape')(se_tensor)
         se_tensor = layers.Conv2D(num_reduced_filters, 1,
-                                  activation=relu_fn,
+                                  activation=activation,
                                   padding='same',
                                   use_bias=True,
                                   kernel_initializer=CONV_KERNEL_INITIALIZER,
@@ -278,7 +281,6 @@ def EfficientNet(width_coefficient,
                  dropout_rate=0.2,
                  drop_connect_rate=0.2,
                  depth_divisor=8,
-                 relu_fn=swish,
                  blocks_args=DEFAULT_BLOCKS_ARGS,
                  model_name='efficientnet',
                  include_top=True,
@@ -367,6 +369,7 @@ def EfficientNet(width_coefficient,
         img_input = input_tensor
 
     bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1
+    activation = get_swish(**kwargs)
 
     # Build stem
     x = img_input
@@ -377,7 +380,7 @@ def EfficientNet(width_coefficient,
                       kernel_initializer=CONV_KERNEL_INITIALIZER,
                       name='stem_conv')(x)
     x = layers.BatchNormalization(axis=bn_axis, name='stem_bn')(x)
-    x = layers.Activation(relu_fn, name='stem_activation')(x)
+    x = layers.Activation(activation, name='stem_activation')(x)
 
     # Build blocks
     num_blocks_total = sum(block_args.num_repeat for block_args in blocks_args)
@@ -395,7 +398,7 @@ def EfficientNet(width_coefficient,
         # The first block needs to take care of stride and filter size increase.
         drop_rate = drop_connect_rate * float(block_num) / num_blocks_total
         x = mb_conv_block(x, block_args,
-                          relu_fn=relu_fn,
+                          activation=activation,
                           drop_rate=drop_rate,
                           prefix='block{}a_'.format(idx + 1))
         block_num += 1
@@ -411,7 +414,7 @@ def EfficientNet(width_coefficient,
                     string.ascii_lowercase[bidx + 1]
                 )
                 x = mb_conv_block(x, block_args,
-                                  relu_fn=relu_fn,
+                                  activation=activation,
                                   drop_rate=drop_rate,
                                   prefix=block_prefix)
                 block_num += 1
@@ -423,7 +426,7 @@ def EfficientNet(width_coefficient,
                       kernel_initializer=CONV_KERNEL_INITIALIZER,
                       name='top_conv')(x)
     x = layers.BatchNormalization(axis=bn_axis, name='top_bn')(x)
-    x = layers.Activation(relu_fn, name='top_activation')(x)
+    x = layers.Activation(activation, name='top_activation')(x)
     if include_top:
         x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
         if dropout_rate and dropout_rate > 0:
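Taken together: `relu_fn` disappears from the public `EfficientNet` signature, and the activation is built internally via `get_swish(**kwargs)`, so the closure always matches the backend the model is constructed with. An end-to-end sketch under stated assumptions: the package's tf.keras flavor exists as `efficientnet.tfkeras` (implied by `init_tfkeras_custom_objects` above but not shown in this diff), and `efficientnet-b0.h5` is a hypothetical file produced by `model.save(...)`:

import efficientnet.tfkeras  # assumption: importing this registers the custom objects
from tensorflow.keras.models import load_model

model = load_model('efficientnet-b0.h5')  # hypothetical path; 'swish' now resolves
model.summary()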
