From c0b6fd341b340e84459029591e5bf0d735460728 Mon Sep 17 00:00:00 2001
From: Abhinav Goel
Date: Wed, 7 Jun 2023 16:54:09 -0700
Subject: [PATCH] adding alternate method to apply mask to allow XLA to detect MHA pattern more easily

---
 praxis/layers/attentions.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/praxis/layers/attentions.py b/praxis/layers/attentions.py
index d670febf..7148d8ac 100644
--- a/praxis/layers/attentions.py
+++ b/praxis/layers/attentions.py
@@ -979,6 +979,7 @@ class DotProductAttention(base_layer.BaseLayer):
   zero_fully_masked: bool = False
   qk_einsum_tpl: LayerTpl = template_field(base_ops.Einsum)
   pv_einsum_tpl: LayerTpl = template_field(base_ops.Einsum)
+  mha_mask_addition_pattern: bool = True
 
   # SPMD partition related params.
   #
@@ -1307,8 +1308,14 @@ def _dot_atten(
     logits = self._cap_logits(logits)
     # Attention softmax is always carried out in fp32.
     logits = logits.astype(jnp.float32)
+
     # Apply attention masking
-    padded_logits = py_utils.apply_mask_to_logits(logits, atten_mask)
+    if self.mha_mask_addition_pattern:
+      padded_logits = logits + atten_mask.astype(jnp.float32)
+    else:
+      padded_logits = py_utils.apply_mask_to_logits(logits, atten_mask)
+
+
     if self.attention_mask_summary:
       self.add_summary('attention_mask', atten_mask)
     if self.attention_extra_logit is None:
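
For readers skimming the diff: with mha_mask_addition_pattern enabled, masking becomes a plain addition of the (already additive) attention mask onto the fp32 logits, instead of a call to py_utils.apply_mask_to_logits. The sketch below is a minimal, self-contained illustration of the two paths; it assumes Praxis's additive mask convention (0.0 for allowed positions, a large negative value for masked ones) and approximates apply_mask_to_logits with a jnp.where select, so the helper names and constants here are illustrative rather than the library's actual implementation.

import jax
import jax.numpy as jnp


def mask_with_select(logits, atten_mask):
  # Select-based masking (roughly what a where/select helper such as
  # py_utils.apply_mask_to_logits does): masked positions are overwritten
  # with a large negative value. The resulting XLA `select` op between the
  # QK^T matmul and the softmax can hide the MHA pattern from the compiler.
  big_neg = -0.7 * jnp.finfo(logits.dtype).max  # illustrative sentinel value
  return jnp.where(atten_mask < 0, big_neg, logits)


def mask_with_addition(logits, atten_mask):
  # Addition-based masking (the mha_mask_addition_pattern=True path):
  # logits + mask is the canonical bias-add before softmax, which XLA's
  # fused-attention rewrites can pattern-match directly.
  return logits + atten_mask.astype(jnp.float32)


# Tiny usage example: batch=1, heads=1, target_len=2, source_len=2.
logits = jnp.ones((1, 1, 2, 2), dtype=jnp.float32)
# Additive-style mask: 0.0 keeps a position, a large negative value masks it.
atten_mask = jnp.array([[[[0.0, -1e9],
                          [0.0, 0.0]]]], dtype=jnp.float32)
probs_add = jax.nn.softmax(mask_with_addition(logits, atten_mask), axis=-1)
probs_sel = jax.nn.softmax(mask_with_select(logits, atten_mask), axis=-1)
# Both paths yield effectively the same probabilities; only the emitted HLO
# differs, which is what the flag in the patch is about.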