diff --git a/model_search/blocks.py b/model_search/blocks.py index b38f64d8..7dc347e4 100644 --- a/model_search/blocks.py +++ b/model_search/blocks.py @@ -633,15 +633,15 @@ def __init__(self, projection_size=32, skip_connect=False): def build(self, input_tensors, is_training, lengths=None, hparams=None): input_tensor = input_tensors[-1] - net = tf.keras.layers.Flatten(name='flatten')(input_tensor) - net = tf.keras.layers.Dense(self._projection_size, name='lower_dim')(net) + flat_input_tensor = tf.keras.layers.Flatten(name='flatten')(input_tensor) + net = tf.keras.layers.Dense(self._projection_size, name='lower_dim')(flat_input_tensor) net = tf.nn.leaky_relu(net) net = tf.keras.layers.Dense( - input_tensor.get_shape()[1], name='expand_dim')( + flat_input_tensor.get_shape()[1], name='expand_dim')( net) net = tf.nn.leaky_relu(net) if self._skip_connect: - net += input_tensor + net += flat_input_tensor return input_tensors + [net] @property