From 9ac9259907137d7cc2edd0d1ac3fd01dbf27801d Mon Sep 17 00:00:00 2001
From: ashors1
Date: Mon, 18 Sep 2023 17:40:53 -0700
Subject: [PATCH] Add alternate method to apply mask to allow XLA to detect
 MHA pattern

---
 praxis/layers/attentions.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/praxis/layers/attentions.py b/praxis/layers/attentions.py
index a35ce8ba..52886bcf 100644
--- a/praxis/layers/attentions.py
+++ b/praxis/layers/attentions.py
@@ -1173,6 +1173,7 @@ class DotProductAttention(base_layer.BaseLayer):
   decode_cache: bool = True
   attention_mask_summary: bool = False
   zero_fully_masked: bool = False
+  mha_mask_addition_pattern: bool = True
   qk_einsum_tpl: LayerTpl = template_field(base_ops.EinsumOp)
   pv_einsum_tpl: LayerTpl = template_field(base_ops.EinsumOp)
   per_dim_scale_tpl: LayerTpl = template_field(PerDimScale)
@@ -1524,7 +1525,11 @@ def _dot_atten(
     # Attention softmax is always carried out in fp32.
     logits = logits.astype(jnp.float32)
     # Apply attention masking
-    padded_logits = py_utils.apply_mask_to_logits(logits, atten_mask)
+    if self.mha_mask_addition_pattern:
+      padded_logits = logits + atten_mask.astype(jnp.float32)
+    else:
+      padded_logits = py_utils.apply_mask_to_logits(logits, atten_mask)
+
     if self.attention_mask_summary:
       self.add_summary('attention_mask', atten_mask)
     if self.attention_extra_logit is None:
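
Note (not part of the patch): a minimal, self-contained sketch of the two masking styles the new mha_mask_addition_pattern flag toggles. The function names and the _LARGE_NEGATIVE constant below are illustrative stand-ins, not Praxis APIs; the real select-style helper is praxis.py_utils.apply_mask_to_logits, and its exact clamping behavior is assumed here. The point of the flag is that a plain addition of the additive mask is a shape XLA can recognize as part of its fused multi-head-attention pattern, whereas the select/where form may not be.

import jax.numpy as jnp

# Illustrative large negative constant for masked positions (assumption:
# Praxis uses a similarly scaled value internally).
_LARGE_NEGATIVE = jnp.float32(-0.7 * jnp.finfo(jnp.float32).max)

def mask_by_addition(logits, atten_mask):
  # mha_mask_addition_pattern=True: plain add of an additive mask
  # (0.0 at visible positions, large negative at masked positions).
  return logits + atten_mask.astype(jnp.float32)

def mask_by_select(logits, atten_mask):
  # mha_mask_addition_pattern=False: roughly what apply_mask_to_logits is
  # assumed to do, selecting a large negative constant at masked positions.
  return jnp.where(atten_mask >= _LARGE_NEGATIVE * 0.5, logits, _LARGE_NEGATIVE)

# Tiny usage example: one query over three keys, last key masked out.
logits = jnp.array([[1.0, 2.0, 3.0]], dtype=jnp.float32)
atten_mask = jnp.array([[0.0, 0.0, _LARGE_NEGATIVE]], dtype=jnp.float32)
print(mask_by_addition(logits, atten_mask))
print(mask_by_select(logits, atten_mask))

Both paths drive the masked logit toward a very large negative value before the fp32 softmax, so the numerical result is effectively the same; only the emitted HLO shape differs.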