diff --git a/docqa/nn/attention.py b/docqa/nn/attention.py
index 652c903..8c500a5 100644
--- a/docqa/nn/attention.py
+++ b/docqa/nn/attention.py
@@ -132,7 +132,7 @@ def apply(self, is_train, x, keys, memories, x_mask=None, mem_mask=None):
             return tf.concat([x, select_query], axis=2)
 
         # select query-to-context
-        context_dist = tf.reduce_max(dist_matrix, axis=2) # (batch, x_word``s)
+        context_dist = tf.reduce_max(dist_matrix, axis=2) # (batch, x_words)
         context_probs = tf.nn.softmax(context_dist) # (batch, x_words)
         select_context = tf.einsum("ai,aik->ak", context_probs, x) # (batch, x_dim)
         select_context = tf.expand_dims(select_context, 1)