add sliding window support for Gemma3 #3742

Draft · wants to merge 4 commits into base: main
2 changes: 1 addition & 1 deletion tools/llm/run_llm.py
@@ -116,7 +116,7 @@ def compile_torchtrt(model, input_ids, args):
         use_fp32_acc=use_fp32_acc,
         device=DEVICE,
         disable_tf32=True,
-        use_python_runtime=True,
+        use_python_runtime=False,
         debug=args.debug,
         offload_module_to_cpu=True,
         min_block_size=args.min_block_size,
68 changes: 68 additions & 0 deletions tools/llm/test_trt_sdpa.py
@@ -0,0 +1,68 @@
import torch
import torch_tensorrt
from torch.export import Dim
from torchtrt_ext import register_sdpa


class SimpleNetwork(torch.nn.Module):
    def __init__(self):
        super(SimpleNetwork, self).__init__()

    def forward(self, query, key, value, attn_mask):
        with torch.backends.cuda.sdp_kernel(
            enable_flash=False,
            enable_math=False,
            enable_mem_efficient=True,
        ):
            return torch.nn.functional.scaled_dot_product_attention(
                query, key, value, attn_mask, 0.0, False, scale=0.0625
            )


dtype = torch.float32

dyn_dim = Dim("dyn_dim", min=3, max=32)

query = torch.randn((1, 4, 13, 256), dtype=dtype).cuda()
key = torch.randn((1, 4, 13, 256), dtype=dtype).cuda()
value = torch.randn((1, 4, 13, 256), dtype=dtype).cuda()
attn_mask = torch.ones((13, 13), dtype=torch.bool).tril(diagonal=0).cuda()
inputs = (query, key, value, attn_mask)

model = SimpleNetwork().eval().cuda()
output_pyt = model(*inputs)
exp_program = torch.export.export(
    model,
    inputs,
    strict=False,
    dynamic_shapes={
        "query": {2: dyn_dim},
        "key": {2: dyn_dim},
        "value": {2: dyn_dim},
        "attn_mask": {0: dyn_dim, 1: dyn_dim},
    },
)
DEBUG_LOGGING_DIR = "./debug_logs"
with torch_tensorrt.dynamo.Debugger(
    "graphs",
    logging_dir=DEBUG_LOGGING_DIR,
    capture_fx_graph_after=["complex_graph_detection"],
    save_engine_profile=True,
    profile_format="trex",
    engine_builder_monitor=True,
):
    trt_model = torch_tensorrt.dynamo.compile(
        exp_program,
        inputs=inputs,
        enabled_precisions={dtype},
        min_block_size=1,
        cache_built_engines=False,
        reuse_cached_engines=False,
        truncate_double=True,
        use_python_runtime=False,
    )
outputs_trt = trt_model(*inputs)
breakpoint()
assert torch.allclose(output_pyt, outputs_trt, rtol=1e-2, atol=1e-2)

print("Done")
4 changes: 3 additions & 1 deletion tools/llm/torchtrt_ext/register_sdpa.py
@@ -89,7 +89,9 @@ def replace_variants_of_sdpa(
                 logger.warning(
                     f"This current version of SDPA converter only supports attn_mask = None, dropout_p = 0.0 and is_causal = True configuration. This could cause issues with accuracy for models with different configurations."
                 )
-            modified_input_args = (query, key, value, None, dropout_p, True)
+            # TODO: lan to figure out why the attn_mask passed in from transformers is not working
+            # modified_input_args = (query, key, value, None, dropout_p, True)
+            modified_input_args = (query, key, value, attn_mask, dropout_p, is_causal)
             # Create a new node with torch.nn.functional.scaled_dot_product_attention
             # The input args is (query, key, value, is_causal). kwargs has scale
             with gm.graph.inserting_after(node):
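For reference on what this change forwards, in eager PyTorch a full lower-triangular boolean attn_mask and is_causal=True are interchangeable when the query and key lengths match, so the accuracy question in the TODO is about the converter path rather than the SDPA semantics. A quick sanity check (shapes are arbitrary, chosen to mirror the test above):

```python
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(1, 4, 13, 256) for _ in range(3))
causal_mask = torch.ones(13, 13, dtype=torch.bool).tril(diagonal=0)

# old behaviour: drop the incoming mask and force is_causal=True
out_forced = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)
# new behaviour: forward the mask and is_causal flag as provided by the model
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask, is_causal=False)

assert torch.allclose(out_forced, out_masked, rtol=1e-4, atol=1e-5)
```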
132 changes: 96 additions & 36 deletions tools/llm/torchtrt_ext/sdpa_converter.py
@@ -27,24 +27,51 @@ def tril(
     name: str,
     row: TRTTensor,
     col: TRTTensor,
+    sliding_window_size: Optional[int] = None,
 ) -> TRTTensor:
 
     row_arange_tensor = impl.arange.arange(
         ctx, target, source_ir, name + "_arange_row", start=0, end=row, step=1
     )
-    row_reshape_tensor = impl.shuffle.reshape(
-        ctx, target, source_ir, name + "_reshape_row", row_arange_tensor, [row, 1]
-    )
-
     col_arange_tensor = impl.arange.arange(
         ctx, target, source_ir, name + "_arange_col", start=0, end=col, step=1
     )
-    col_reshape_tensor = impl.shuffle.reshape(
-        ctx, target, source_ir, name + "_reshape_col", col_arange_tensor, [1, col]
+    row_arange_tensor = impl.unsqueeze.unsqueeze(
+        ctx, target, source_ir, name + "_unsqueeze_row", row_arange_tensor, -1
+    )
+    col_arange_tensor = impl.unsqueeze.unsqueeze(
+        ctx, target, source_ir, name + "_unsqueeze_col", col_arange_tensor, 0
+    )
+    # sub will return the following mask tensor:
+    # [[0, -1, -2, -3],
+    #  [1,  0, -1, -2],
+    #  [2,  1,  0, -1],
+    #  [3,  2,  1,  0]]
+    mask = impl.elementwise.sub(
+        ctx, target, source_ir, name + "_sub", row_arange_tensor, col_arange_tensor
+    )
+    ge_0_mask = impl.elementwise.ge(ctx, target, source_ir, name + "_ge_0", mask, 0.0)
+    if sliding_window_size is None:
+        # return the following lower triangular mask includes the main diagonal:
+        # 0 ■ ⬚ ⬚ ⬚ ⬚    tensor([[[[ True, False, False, False, False],
+        # 1 ■ ■ ⬚ ⬚ ⬚            [ True,  True, False, False, False],
+        # 2 ■ ■ ■ ⬚ ⬚            [ True,  True,  True, False, False],
+        # 3 ■ ■ ■ ■ ⬚            [ True,  True,  True,  True, False],
+        # 4 ■ ■ ■ ■ ■            [ True,  True,  True,  True,  True]]]])
+        return ge_0_mask
 
-    mask = impl.elementwise.ge(
-        ctx, target, source_ir, name + "_ge", row_reshape_tensor, col_reshape_tensor
+    lt_window_mask = impl.elementwise.lt(
+        ctx, target, source_ir, name + "_lt_window_size", mask, sliding_window_size
     )
+    mask = impl.elementwise.logical_and(
+        ctx, target, source_ir, name + "_logical_and", ge_0_mask, lt_window_mask
+    )
+    # return the following mask if sliding_window_size is 3:
+    # 0 ■ ⬚ ⬚ ⬚ ⬚    tensor([[[[ True, False, False, False, False],
+    # 1 ■ ■ ⬚ ⬚ ⬚            [ True,  True, False, False, False],
+    # 2 ■ ■ ■ ⬚ ⬚            [ True,  True,  True, False, False],
+    # 3 ⬚ ■ ■ ■ ⬚            [False,  True,  True,  True, False],
+    # 4 ⬚ ⬚ ■ ■ ■            [False, False,  True,  True,  True]]]])
     return mask
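For a quick cross-check of the comments above, the same index arithmetic in plain PyTorch (the helper name is illustrative, not part of the converter):

```python
import torch


def sliding_window_tril(rows: int, cols: int, sliding_window_size=None) -> torch.Tensor:
    # row index minus column index, e.g. for 4x4:
    # [[0, -1, -2, -3],
    #  [1,  0, -1, -2],
    #  [2,  1,  0, -1],
    #  [3,  2,  1,  0]]
    diff = torch.arange(rows).unsqueeze(-1) - torch.arange(cols).unsqueeze(0)
    causal = diff >= 0  # lower triangle, main diagonal included
    if sliding_window_size is None:
        return causal
    # also drop keys that are sliding_window_size or more positions behind the query
    return causal & (diff < sliding_window_size)


# reproduces the 5x5 examples drawn in the comments above
print(sliding_window_tril(5, 5))
print(sliding_window_tril(5, 5, sliding_window_size=3))
```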


@@ -66,7 +93,7 @@ def scaled_dot_product_attention(
     # TODO: remove this once we have a better way to handle the causal mask
     scale = kwargs.get("scale", None)
     source_ir = SourceIR.ATEN
-    is_causal = True
+
     # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
     use_fp32_acc = kwargs.get("use_fp32_acc", False)
     query_dtype = query.dtype
@@ -134,37 +161,70 @@
     L = impl.shape.shape(ctx, target, source_ir, name + "_shape_0", query, 2)
     if S < 0:
         S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2)
 
-    # generate the mask tensor
-    tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S)
+    if is_causal:
+        tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S)
+    else:
+        # TODO: lan to figure out why attn_mask passed in from transformers is not working
+        # tried both 2d and 4d, but both are not working
+        assert len(attn_mask.shape) in [
+            2,
+            4,
+        ], f"attn_mask must be 2D or 4D, but got {attn_mask.shape=}"
+        if len(attn_mask.shape) == 4:
+            if attn_mask.shape[0] != 1:
+                attn_mask = impl.slice.slice_op(
+                    ctx, target, source_ir, name + "_slice", attn_mask, 0, 0, 1, 1
+                )
+            if attn_mask.shape[1] != 1:
+                attn_mask = impl.slice.slice_op(
+                    ctx, target, source_ir, name + "_slice", attn_mask, 1, 0, 1, 1
+                )
+            attn_mask = impl.squeeze.squeeze(
+                ctx, target, source_ir, name + "_squeeze", attn_mask, (0, 1)
+            )
+        tril_tensor = attn_mask
 
-    temp_mask = impl.unary.logical_not(
-        ctx, target, source_ir, name + "_logical_not", tril_tensor
-    )
+    # generate attn_bias via where instead of (logical_and, sub, log) to see whether nan is related to this
+    attn_bias_via_where = True
+    if attn_bias_via_where:
+        attn_bias = impl.condition.where(
+            ctx,
+            target,
+            source_ir,
+            name + "_where",
+            torch.tensor(0.0, dtype=torch.float32).cuda(),
+            torch.tensor(-float("inf"), dtype=torch.float32).cuda(),
+            tril_tensor,
+        )
+    else:
+        temp_mask = impl.unary.logical_not(
+            ctx, target, source_ir, name + "_logical_not", tril_tensor
+        )
 
-    # This need_mask determines if we want to use the causal mask or not
-    # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask.
-    # So need_mask will be all False values in this case.
-    # TODO: Implement more general case where L != 1 and S != L
-    need_mask = impl.elementwise.eq(ctx, target, source_ir, name + "_eq", L, S)
-    temp_mask = impl.elementwise.logical_and(
-        ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask
-    )
-    temp_mask_casted = cast_trt_tensor(
-        ctx, temp_mask, query_dtype, name + "_casted_bool", target, source_ir
-    )
+        # This need_mask determines if we want to use the causal mask or not
+        # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask.
+        # So need_mask will be all False values in this case.
+        # TODO: Implement more general case where L != 1 and S != L
+        need_mask = impl.elementwise.eq(ctx, target, source_ir, name + "_eq", L, S)
+        temp_mask = impl.elementwise.logical_and(
+            ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask
+        )
+        temp_mask_casted = cast_trt_tensor(
+            ctx, temp_mask, query_dtype, name + "_casted_bool", target, source_ir
+        )
 
-    one_minus_temp_mask = impl.elementwise.sub(
-        ctx,
-        target,
-        source_ir,
-        name + "_one_minus_temp_mask",
-        1.0,
-        temp_mask_casted,
-    )
-    attn_bias = impl.unary.log(
-        ctx, target, source_ir, name + "_log", one_minus_temp_mask
-    )
+        one_minus_temp_mask = impl.elementwise.sub(
+            ctx,
+            target,
+            source_ir,
+            name + "_one_minus_temp_mask",
+            1.0,
+            temp_mask_casted,
+        )
+        attn_bias = impl.unary.log(
+            ctx, target, source_ir, name + "_log", one_minus_temp_mask
+        )
 
     scaled_add_attn_bias = impl.elementwise.add(
         ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias
1 change: 0 additions & 1 deletion tools/llm/utils.py
@@ -179,7 +179,6 @@ def generate_with_dynamic_cache(model, input_seq, max_output_seq_length, eos_tok
     num_tokens_generated = 0
     kv_cache = get_zeroed_dynamic_cache_inputs(model)
     last_position_id = position_ids[-1, -1].item()
-    breakpoint()
     while num_tokens_generated < num_output_tokens:
         is_generate = False if input_seq.shape[1] > 1 else True
         position_ids = (