From 2f09be6e9cc768c3d8a3d61c92aaf35a5474810c Mon Sep 17 00:00:00 2001 From: Ishant Mrinal Haloi Date: Sun, 28 Mar 2021 16:09:37 +0530 Subject: [PATCH] fix BottleNeckBlock skip_connect bug --- model_search/blocks.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/model_search/blocks.py b/model_search/blocks.py index b38f64d8..7dc347e4 100644 --- a/model_search/blocks.py +++ b/model_search/blocks.py @@ -633,15 +633,15 @@ def __init__(self, projection_size=32, skip_connect=False): def build(self, input_tensors, is_training, lengths=None, hparams=None): input_tensor = input_tensors[-1] - net = tf.keras.layers.Flatten(name='flatten')(input_tensor) - net = tf.keras.layers.Dense(self._projection_size, name='lower_dim')(net) + flat_input_tensor = tf.keras.layers.Flatten(name='flatten')(input_tensor) + net = tf.keras.layers.Dense(self._projection_size, name='lower_dim')(flat_input_tensor) net = tf.nn.leaky_relu(net) net = tf.keras.layers.Dense( - input_tensor.get_shape()[1], name='expand_dim')( + flat_input_tensor.get_shape()[1], name='expand_dim')( net) net = tf.nn.leaky_relu(net) if self._skip_connect: - net += input_tensor + net += flat_input_tensor return input_tensors + [net] @property