From 066452f1dcceac399be9f2f1b7fc1fafdbf8320c Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Sat, 2 Aug 2025 16:01:30 -0700 Subject: [PATCH 1/5] add sliding window support for Gemma3 --- tools/llm/torchtrt_ext/register_sdpa.py | 5 +- tools/llm/torchtrt_ext/sdpa_converter.py | 64 +++++++++++++++++++++++- 2 files changed, 66 insertions(+), 3 deletions(-) diff --git a/tools/llm/torchtrt_ext/register_sdpa.py b/tools/llm/torchtrt_ext/register_sdpa.py index 90a00a5798..b04a7e5b3a 100644 --- a/tools/llm/torchtrt_ext/register_sdpa.py +++ b/tools/llm/torchtrt_ext/register_sdpa.py @@ -1,6 +1,7 @@ import copy import logging import operator +from re import I from typing import Callable, Sequence, Tuple import torch @@ -89,7 +90,9 @@ def replace_variants_of_sdpa( logger.warning( f"This current version of SDPA converter only supports attn_mask = None, dropout_p = 0.0 and is_causal = True configuration. This could cause issues with accuracy for models with different configurations." ) - modified_input_args = (query, key, value, None, dropout_p, True) + # TODO: lan to figure out why is_causal is always False in google/gemma-3-1b-it, as in the config file it should be every 5 sliding window layer followed by a full attention layer + # also to figure out why the attn_mask passed in from transformers is not working + modified_input_args = (query, key, value, None, dropout_p, is_causal) # Create a new node with torch.nn.functional.scaled_dot_product_attention # The input args is (query, key, value, is_causal). kwargs has scale with gm.graph.inserting_after(node): diff --git a/tools/llm/torchtrt_ext/sdpa_converter.py b/tools/llm/torchtrt_ext/sdpa_converter.py index 47083c7b48..9779294b77 100644 --- a/tools/llm/torchtrt_ext/sdpa_converter.py +++ b/tools/llm/torchtrt_ext/sdpa_converter.py @@ -27,7 +27,53 @@ def tril( name: str, row: TRTTensor, col: TRTTensor, + sliding_window_size: Optional[int] = None, ) -> TRTTensor: + + row_arange_tensor = impl.arange.arange( + ctx, target, source_ir, name + "_arange_row", start=0, end=row, step=1 + ) + col_arange_tensor = impl.arange.arange( + ctx, target, source_ir, name + "_arange_col", start=0, end=col, step=1 + ) + row_arange_tensor = impl.unsqueeze.unsqueeze( + ctx, target, source_ir, name + "_unsqueeze_row", row_arange_tensor, -1 + ) + col_arange_tensor = impl.unsqueeze.unsqueeze( + ctx, target, source_ir, name + "_unsqueeze_col", col_arange_tensor, 0 + ) + # sub will return the following mask tensor: + # [[0, -1, -2, -3], + # [1, 0, -1, -2], + # [2, 1, 0, -1], + # [3, 2, 1, 0]] + mask = impl.elementwise.sub( + ctx, target, source_ir, name + "_sub", row_arange_tensor, col_arange_tensor + ) + ge_0_mask = impl.elementwise.ge(ctx, target, source_ir, name + "_ge_0", mask, 0.0) + if sliding_window_size is None: + # return the following lower triangular mask includes the main diagonal: + # 0 ■ ⬚ ⬚ ⬚ ⬚ tensor([[[[ True, False, False, False, False], + # 1 ■ ■ ⬚ ⬚ ⬚ [ True, True, False, False, False], + # 2 ■ ■ ■ ⬚ ⬚ [ True, True, True, False, False], + # 3 ■ ■ ■ ■ ⬚ [ True, True, True, True, False], + # 4 ■ ■ ■ ■ ■ [ True, True, True, True, True]]]]) + return ge_0_mask + + lt_window_mask = impl.elementwise.lt( + ctx, target, source_ir, name + "_lt_window_size", mask, sliding_window_size + ) + mask = impl.elementwise.logical_and( + ctx, target, source_ir, name + "_logical_and", ge_0_mask, lt_window_mask + ) + # return the following mask if sliding_window_size is 3: + # 0 ■ ⬚ ⬚ ⬚ ⬚ tensor([[[[ True, False, False, False, False], + # 1 ■ ■ ⬚ ⬚ ⬚ [ True, True, False, False, 
False], + # 2 ■ ■ ■ ⬚ ⬚ [ True, True, True, False, False], + # 3 ⬚ ■ ■ ■ ⬚ [False, True, True, True, False], + # 4 ⬚ ⬚ ■ ■ ■ [False, False, True, True,True]]]]) + return mask + row_arange_tensor = impl.arange.arange( ctx, target, source_ir, name + "_arange_row", start=0, end=row, step=1 ) @@ -66,7 +112,7 @@ def scaled_dot_product_attention( # TODO: remove this once we have a better way to handle the causal mask scale = kwargs.get("scale", None) source_ir = SourceIR.ATEN - is_causal = True + # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html use_fp32_acc = kwargs.get("use_fp32_acc", False) query_dtype = query.dtype @@ -136,7 +182,21 @@ def scaled_dot_product_attention( S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2) # generate the mask tensor - tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S) + if is_causal: + tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S) + else: + # hard code the sliding window size to 512 for now + tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S, 512) + # TODO: lan to figure out why attn_mask passed in from transformers is not working + # tried both 2d and 4d, but both are not working, hence the following code is commented out + # assert len(attn_mask.shape) in [2, 4], f"attn_mask must be 2D or 4D, but got {attn_mask.shape=}" + # if len(attn_mask.shape) == 4: + # if attn_mask.shape[0] != 1: + # attn_mask = impl.slice.slice_op(ctx, target, source_ir, name + "_slice", attn_mask, 0, 0, 1, 1) + # if attn_mask.shape[1] != 1: + # attn_mask = impl.slice.slice_op(ctx, target, source_ir, name + "_slice", attn_mask, 1, 0, 1, 1) + # attn_mask = impl.squeeze.squeeze(ctx, target, source_ir, name + "_squeeze", attn_mask, (0, 1)) + # tril_tensor = attn_mask temp_mask = impl.unary.logical_not( ctx, target, source_ir, name + "_logical_not", tril_tensor From a58d17bd1ff743d071c8852a435fbe3a021249cd Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Sat, 2 Aug 2025 16:19:45 -0700 Subject: [PATCH 2/5] test --- tools/llm/torchtrt_ext/register_sdpa.py | 1 - tools/llm/torchtrt_ext/sdpa_converter.py | 19 ------------------- 2 files changed, 20 deletions(-) diff --git a/tools/llm/torchtrt_ext/register_sdpa.py b/tools/llm/torchtrt_ext/register_sdpa.py index b04a7e5b3a..d8472cba55 100644 --- a/tools/llm/torchtrt_ext/register_sdpa.py +++ b/tools/llm/torchtrt_ext/register_sdpa.py @@ -1,7 +1,6 @@ import copy import logging import operator -from re import I from typing import Callable, Sequence, Tuple import torch diff --git a/tools/llm/torchtrt_ext/sdpa_converter.py b/tools/llm/torchtrt_ext/sdpa_converter.py index 9779294b77..8f4ba4e32f 100644 --- a/tools/llm/torchtrt_ext/sdpa_converter.py +++ b/tools/llm/torchtrt_ext/sdpa_converter.py @@ -74,25 +74,6 @@ def tril( # 4 ⬚ ⬚ ■ ■ ■ [False, False, True, True,True]]]]) return mask - row_arange_tensor = impl.arange.arange( - ctx, target, source_ir, name + "_arange_row", start=0, end=row, step=1 - ) - row_reshape_tensor = impl.shuffle.reshape( - ctx, target, source_ir, name + "_reshape_row", row_arange_tensor, [row, 1] - ) - - col_arange_tensor = impl.arange.arange( - ctx, target, source_ir, name + "_arange_col", start=0, end=col, step=1 - ) - col_reshape_tensor = impl.shuffle.reshape( - ctx, target, source_ir, name + "_reshape_col", col_arange_tensor, [1, col] - ) - - mask = impl.elementwise.ge( - ctx, target, source_ir, name + "_ge", row_reshape_tensor, col_reshape_tensor - ) - return mask - 
@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter( torch.nn.functional.scaled_dot_product_attention, From a65f0f1ae5792470feb02d71163ffeb8cf2bd91b Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Thu, 14 Aug 2025 20:19:06 -0700 Subject: [PATCH 3/5] add test case --- tools/llm/run_llm.py | 2 +- tools/llm/test_trt_sdpa.py | 68 +++++++++++++++ tools/llm/torchtrt_ext/register_sdpa.py | 6 +- tools/llm/torchtrt_ext/sdpa_converter.py | 102 ++++++++++++++--------- tools/llm/utils.py | 1 - 5 files changed, 136 insertions(+), 43 deletions(-) create mode 100644 tools/llm/test_trt_sdpa.py diff --git a/tools/llm/run_llm.py b/tools/llm/run_llm.py index 7e50b515c2..5647a10a7a 100644 --- a/tools/llm/run_llm.py +++ b/tools/llm/run_llm.py @@ -116,7 +116,7 @@ def compile_torchtrt(model, input_ids, args): use_fp32_acc=use_fp32_acc, device=DEVICE, disable_tf32=True, - use_python_runtime=True, + use_python_runtime=False, debug=args.debug, offload_module_to_cpu=True, min_block_size=args.min_block_size, diff --git a/tools/llm/test_trt_sdpa.py b/tools/llm/test_trt_sdpa.py new file mode 100644 index 0000000000..28827c0184 --- /dev/null +++ b/tools/llm/test_trt_sdpa.py @@ -0,0 +1,68 @@ +import torch +import torch_tensorrt +from torch.export import Dim +from torchtrt_ext import register_sdpa + + +class SimpleNetwork(torch.nn.Module): + def __init__(self): + super(SimpleNetwork, self).__init__() + + def forward(self, query, key, value, attn_mask): + with torch.backends.cuda.sdp_kernel( + enable_flash=False, + enable_math=False, + enable_mem_efficient=True, + ): + return torch.nn.functional.scaled_dot_product_attention( + query, key, value, attn_mask, 0.0, False, scale=0.0625 + ) + + +dtype = torch.float32 + +dyn_dim = Dim("dyn_dim", min=3, max=32) + +query = torch.randn((1, 4, 13, 256), dtype=dtype).cuda() +key = torch.randn((1, 4, 13, 256), dtype=dtype).cuda() +value = torch.randn((1, 4, 13, 256), dtype=dtype).cuda() +attn_mask = torch.ones((13, 13), dtype=torch.bool).tril(diagonal=0).cuda() +inputs = (query, key, value, attn_mask) + +model = SimpleNetwork().eval().cuda() +output_pyt = model(*inputs) +exp_program = torch.export.export( + model, + inputs, + strict=False, + dynamic_shapes={ + "query": {2: dyn_dim}, + "key": {2: dyn_dim}, + "value": {2: dyn_dim}, + "attn_mask": {0: dyn_dim, 1: dyn_dim}, + }, +) +DEBUG_LOGGING_DIR = "./debug_logs" +with torch_tensorrt.dynamo.Debugger( + "graphs", + logging_dir=DEBUG_LOGGING_DIR, + capture_fx_graph_after=["complex_graph_detection"], + save_engine_profile=True, + profile_format="trex", + engine_builder_monitor=True, +): + trt_model = torch_tensorrt.dynamo.compile( + exp_program, + inputs=inputs, + enabled_precisions={dtype}, + min_block_size=1, + cache_built_engines=False, + reuse_cached_engines=False, + truncate_double=True, + use_python_runtime=False, + ) + outputs_trt = trt_model(*inputs) + breakpoint() + assert torch.allclose(output_pyt, outputs_trt, rtol=1e-2, atol=1e-2) + +print("Done") diff --git a/tools/llm/torchtrt_ext/register_sdpa.py b/tools/llm/torchtrt_ext/register_sdpa.py index d8472cba55..c5f1f68665 100644 --- a/tools/llm/torchtrt_ext/register_sdpa.py +++ b/tools/llm/torchtrt_ext/register_sdpa.py @@ -89,9 +89,9 @@ def replace_variants_of_sdpa( logger.warning( f"This current version of SDPA converter only supports attn_mask = None, dropout_p = 0.0 and is_causal = True configuration. This could cause issues with accuracy for models with different configurations." 
) - # TODO: lan to figure out why is_causal is always False in google/gemma-3-1b-it, as in the config file it should be every 5 sliding window layer followed by a full attention layer - # also to figure out why the attn_mask passed in from transformers is not working - modified_input_args = (query, key, value, None, dropout_p, is_causal) + # TODO: lan to figure out why the attn_mask passed in from transformers is not working + # modified_input_args = (query, key, value, None, dropout_p, True) + modified_input_args = (query, key, value, attn_mask, dropout_p, is_causal) # Create a new node with torch.nn.functional.scaled_dot_product_attention # The input args is (query, key, value, is_causal). kwargs has scale with gm.graph.inserting_after(node): diff --git a/tools/llm/torchtrt_ext/sdpa_converter.py b/tools/llm/torchtrt_ext/sdpa_converter.py index 8f4ba4e32f..fdb04022da 100644 --- a/tools/llm/torchtrt_ext/sdpa_converter.py +++ b/tools/llm/torchtrt_ext/sdpa_converter.py @@ -161,51 +161,77 @@ def scaled_dot_product_attention( L = impl.shape.shape(ctx, target, source_ir, name + "_shape_0", query, 2) if S < 0: S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2) - # generate the mask tensor if is_causal: tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S) else: - # hard code the sliding window size to 512 for now - tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S, 512) # TODO: lan to figure out why attn_mask passed in from transformers is not working - # tried both 2d and 4d, but both are not working, hence the following code is commented out - # assert len(attn_mask.shape) in [2, 4], f"attn_mask must be 2D or 4D, but got {attn_mask.shape=}" - # if len(attn_mask.shape) == 4: - # if attn_mask.shape[0] != 1: - # attn_mask = impl.slice.slice_op(ctx, target, source_ir, name + "_slice", attn_mask, 0, 0, 1, 1) - # if attn_mask.shape[1] != 1: - # attn_mask = impl.slice.slice_op(ctx, target, source_ir, name + "_slice", attn_mask, 1, 0, 1, 1) - # attn_mask = impl.squeeze.squeeze(ctx, target, source_ir, name + "_squeeze", attn_mask, (0, 1)) - # tril_tensor = attn_mask - - temp_mask = impl.unary.logical_not( - ctx, target, source_ir, name + "_logical_not", tril_tensor - ) + # tried both 2d and 4d, but both are not working + assert len(attn_mask.shape) in [ + 2, + 4, + ], f"attn_mask must be 2D or 4D, but got {attn_mask.shape=}" + if len(attn_mask.shape) == 4: + if attn_mask.shape[0] != 1: + attn_mask = impl.slice.slice_op( + ctx, target, source_ir, name + "_slice", attn_mask, 0, 0, 1, 1 + ) + if attn_mask.shape[1] != 1: + attn_mask = impl.slice.slice_op( + ctx, target, source_ir, name + "_slice", attn_mask, 1, 0, 1, 1 + ) + attn_mask = impl.squeeze.squeeze( + ctx, target, source_ir, name + "_squeeze", attn_mask, (0, 1) + ) + tril_tensor = attn_mask - # This need_mask determines if we want to use the causal mask or not - # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask. - # So need_mask will be all False values in this case. 
- # TODO: Implement more general case where L != 1 and S != L - need_mask = impl.elementwise.eq(ctx, target, source_ir, name + "_eq", L, S) - temp_mask = impl.elementwise.logical_and( - ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask - ) - temp_mask_casted = cast_trt_tensor( - ctx, temp_mask, query_dtype, name + "_casted_bool", target, source_ir - ) + # generate attn_bias via where instead of (logical_and, sub, log) to see whether nan is related to this + attn_bias_via_where = True + if attn_bias_via_where: + attn_bias = impl.condition.where( + ctx, + target, + source_ir, + name + "_where", + torch.tensor(0.0, dtype=torch.float32).cuda(), + torch.tensor(-float("inf"), dtype=torch.float32).cuda(), + tril_tensor, + ) + else: + temp_mask = impl.unary.logical_not( + ctx, target, source_ir, name + "_logical_not", tril_tensor + ) + temp_mask = cast_trt_tensor( + ctx, temp_mask, trt.float32, name + "_casted_bool", target, source_ir + ) + temp_mask = impl.elementwise.mul( + ctx, target, source_ir, name + "_mul_-inf", temp_mask, float("-inf") + ) + attn_bias = temp_mask - one_minus_temp_mask = impl.elementwise.sub( - ctx, - target, - source_ir, - name + "_one_minus_temp_mask", - 1.0, - temp_mask_casted, - ) - attn_bias = impl.unary.log( - ctx, target, source_ir, name + "_log", one_minus_temp_mask - ) + # This need_mask determines if we want to use the causal mask or not + # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask. + # So need_mask will be all False values in this case. + # TODO: Implement more general case where L != 1 and S != L + need_mask = impl.elementwise.eq(ctx, target, source_ir, name + "_eq", L, S) + temp_mask = impl.elementwise.logical_and( + ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask + ) + temp_mask_casted = cast_trt_tensor( + ctx, temp_mask, query_dtype, name + "_casted_bool", target, source_ir + ) + + one_minus_temp_mask = impl.elementwise.sub( + ctx, + target, + source_ir, + name + "_one_minus_temp_mask", + 1.0, + temp_mask_casted, + ) + attn_bias = impl.unary.log( + ctx, target, source_ir, name + "_log", one_minus_temp_mask + ) scaled_add_attn_bias = impl.elementwise.add( ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias diff --git a/tools/llm/utils.py b/tools/llm/utils.py index 2c3434b0ed..c56aa9b490 100644 --- a/tools/llm/utils.py +++ b/tools/llm/utils.py @@ -179,7 +179,6 @@ def generate_with_dynamic_cache(model, input_seq, max_output_seq_length, eos_tok num_tokens_generated = 0 kv_cache = get_zeroed_dynamic_cache_inputs(model) last_position_id = position_ids[-1, -1].item() - breakpoint() while num_tokens_generated < num_output_tokens: is_generate = False if input_seq.shape[1] > 1 else True position_ids = ( From 47abe2c352f4d5e117ffa23f12a9bfac01934089 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Thu, 14 Aug 2025 20:22:14 -0700 Subject: [PATCH 4/5] test --- tools/llm/torchtrt_ext/sdpa_converter.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tools/llm/torchtrt_ext/sdpa_converter.py b/tools/llm/torchtrt_ext/sdpa_converter.py index fdb04022da..7d467942bd 100644 --- a/tools/llm/torchtrt_ext/sdpa_converter.py +++ b/tools/llm/torchtrt_ext/sdpa_converter.py @@ -201,13 +201,6 @@ def scaled_dot_product_attention( temp_mask = impl.unary.logical_not( ctx, target, source_ir, name + "_logical_not", tril_tensor ) - temp_mask = cast_trt_tensor( - ctx, temp_mask, trt.float32, name + "_casted_bool", target, source_ir - ) - temp_mask = impl.elementwise.mul( - ctx, target, 
source_ir, name + "_mul_-inf", temp_mask, float("-inf") - ) - attn_bias = temp_mask # This need_mask determines if we want to use the causal mask or not # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask. From 779e17477b7aba20375ac9e14d23f1824b8a774b Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Mon, 18 Aug 2025 12:35:15 -0700 Subject: [PATCH 5/5] resolve the attn_mask nan issue --- tools/llm/test_trt_sdpa.py | 1 + tools/llm/torchtrt_ext/sdpa_converter.py | 92 ++++++++++++------------ 2 files changed, 49 insertions(+), 44 deletions(-) diff --git a/tools/llm/test_trt_sdpa.py b/tools/llm/test_trt_sdpa.py index 28827c0184..5691b206df 100644 --- a/tools/llm/test_trt_sdpa.py +++ b/tools/llm/test_trt_sdpa.py @@ -13,6 +13,7 @@ def forward(self, query, key, value, attn_mask): enable_flash=False, enable_math=False, enable_mem_efficient=True, + enable_cudnn=False, ): return torch.nn.functional.scaled_dot_product_attention( query, key, value, attn_mask, 0.0, False, scale=0.0625 diff --git a/tools/llm/torchtrt_ext/sdpa_converter.py b/tools/llm/torchtrt_ext/sdpa_converter.py index 7d467942bd..c793851471 100644 --- a/tools/llm/torchtrt_ext/sdpa_converter.py +++ b/tools/llm/torchtrt_ext/sdpa_converter.py @@ -162,11 +162,7 @@ def scaled_dot_product_attention( if S < 0: S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2) # generate the mask tensor - if is_causal: - tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S) - else: - # TODO: lan to figure out why attn_mask passed in from transformers is not working - # tried both 2d and 4d, but both are not working + if not is_causal: assert len(attn_mask.shape) in [ 2, 4, @@ -183,48 +179,56 @@ def scaled_dot_product_attention( attn_mask = impl.squeeze.squeeze( ctx, target, source_ir, name + "_squeeze", attn_mask, (0, 1) ) - tril_tensor = attn_mask - - # generate attn_bias via where instead of (logical_and, sub, log) to see whether nan is related to this - attn_bias_via_where = True - if attn_bias_via_where: - attn_bias = impl.condition.where( - ctx, - target, - source_ir, - name + "_where", - torch.tensor(0.0, dtype=torch.float32).cuda(), - torch.tensor(-float("inf"), dtype=torch.float32).cuda(), - tril_tensor, - ) + attn_bias = attn_mask else: - temp_mask = impl.unary.logical_not( - ctx, target, source_ir, name + "_logical_not", tril_tensor - ) + tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S) + # generate attn_bias via where instead of (logical_and, sub, log) to see whether nan is related to this + attn_bias_via_where = True + if attn_bias_via_where: + attn_bias = impl.condition.where( + ctx, + target, + source_ir, + name + "_where", + torch.tensor(0.0, dtype=torch.float32).cuda(), + torch.tensor(-float("inf"), dtype=torch.float32).cuda(), + tril_tensor, + ) + else: + temp_mask = impl.unary.logical_not( + ctx, target, source_ir, name + "_logical_not", tril_tensor + ) - # This need_mask determines if we want to use the causal mask or not - # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask. - # So need_mask will be all False values in this case. 
- # TODO: Implement more general case where L != 1 and S != L - need_mask = impl.elementwise.eq(ctx, target, source_ir, name + "_eq", L, S) - temp_mask = impl.elementwise.logical_and( - ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask - ) - temp_mask_casted = cast_trt_tensor( - ctx, temp_mask, query_dtype, name + "_casted_bool", target, source_ir - ) + # This need_mask determines if we want to use the causal mask or not + # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask. + # So need_mask will be all False values in this case. + # TODO: Implement more general case where L != 1 and S != L + need_mask = impl.elementwise.eq( + ctx, target, source_ir, name + "_eq", L, S + ) + temp_mask = impl.elementwise.logical_and( + ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask + ) + temp_mask_casted = cast_trt_tensor( + ctx, + temp_mask, + query_dtype, + name + "_casted_bool", + target, + source_ir, + ) - one_minus_temp_mask = impl.elementwise.sub( - ctx, - target, - source_ir, - name + "_one_minus_temp_mask", - 1.0, - temp_mask_casted, - ) - attn_bias = impl.unary.log( - ctx, target, source_ir, name + "_log", one_minus_temp_mask - ) + one_minus_temp_mask = impl.elementwise.sub( + ctx, + target, + source_ir, + name + "_one_minus_temp_mask", + 1.0, + temp_mask_casted, + ) + attn_bias = impl.unary.log( + ctx, target, source_ir, name + "_log", one_minus_temp_mask + ) scaled_add_attn_bias = impl.elementwise.add( ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias
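
---
Reviewer note (illustration only, not part of the patches above): the tril() helper added in patch 1 builds its mask from a row-index minus column-index tensor, keeps entries >= 0 for the causal part, and optionally also requires the difference to be below a sliding window size; patch 1 wires it up with a hard-coded window of 512, while the later patches route the non-causal path through the transformer-provided attn_mask instead. The short PyTorch sketch below reproduces the same True/False patterns shown in the in-code diagrams. The helper name and the window value of 3 are chosen here purely for illustration.

    from typing import Optional

    import torch


    def sliding_window_causal_mask(
        L: int, S: int, sliding_window_size: Optional[int] = None
    ) -> torch.Tensor:
        # row - col mirrors the "sub" tensor built in the converter:
        # 0 on the main diagonal, positive below it, negative above it.
        diff = torch.arange(L).unsqueeze(-1) - torch.arange(S).unsqueeze(0)
        causal = diff >= 0  # lower-triangular mask including the main diagonal
        if sliding_window_size is None:
            return causal
        # additionally drop positions sliding_window_size or more steps in the past
        return causal & (diff < sliding_window_size)


    print(sliding_window_causal_mask(5, 5))     # plain causal mask
    print(sliding_window_causal_mask(5, 5, 3))  # banded mask from the patch comments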