From accdc20de136026c3e67c169dabd91a667113968 Mon Sep 17 00:00:00 2001
From: Yaoyao Ding
Date: Wed, 26 Nov 2025 20:18:14 +0000
Subject: [PATCH 1/3] update and pass existing tests

---
 csrc/trtllm_fmha_kernel_launcher.cu          | 55 +++++++++++++++-----
 flashinfer/decode.py                         | 23 ++++++--
 tests/attention/test_trtllm_gen_attention.py |  4 ++
 3 files changed, 65 insertions(+), 17 deletions(-)

diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu
index 89fe53b874..6928b2196d 100644
--- a/csrc/trtllm_fmha_kernel_launcher.cu
+++ b/csrc/trtllm_fmha_kernel_launcher.cu
@@ -26,6 +26,7 @@
 #include
 #include
+#include "tvm/ffi/error.h"
 #include "tvm_ffi_utils.h"
 
 using tvm::ffi::Optional;
 
@@ -163,6 +164,9 @@ void trtllm_paged_attention_launcher(
       use_multi_block ? TileScheduler::Static : TileScheduler::Persistent;
   runner_params.mMultiCtasKvMode = use_multi_block;
 
+  runner_params.cumSeqLensQPtr = cum_seq_lens_q;
+  runner_params.cumSeqLensKvPtr = cum_seq_lens_kv;
+
   size_t max_batch_size = 8192;   // todo(Yingyi): get from dlfw
   size_t max_num_qo_heads = 256;  // todo(Yingyi): get from dlfw, in total 8MB
   size_t num_semaphores =
@@ -213,7 +217,11 @@ void trtllm_paged_attention_decode(
     TensorView seq_lens, int64_t max_kv_len, Variant<double, TensorView> bmm1_scale,
     Variant<double, TensorView> bmm2_scale, double o_sf_scale, int64_t o_sf_vec_size,
     int64_t o_sf_start_index, int64_t window_left, int64_t sparse_mla_top_k, int64_t sm_count,
-    bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks) {
+    bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks,
+    Optional<int64_t> optional_max_q_len,
+    Optional<TensorView> cum_seq_lens_q,
+    Optional<TensorView> cum_seq_lens_kv
+    ) {
   auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
   auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
   TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim());
@@ -221,14 +229,37 @@ void trtllm_paged_attention_decode(
     TVM_FFI_ICHECK_EQ(key_cache.size(i), value_cache.size(i));
   }
   auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype());
-  // NOTE(Zihao): query is [B, Q, H, D]
-  // where Q is the number of query tokens per request, used in MTP
-  // based on profiled results, always use decode mode for MTP (q_len is small)
-  // example: when kv_len = 10000, q < 200, decode mode is faster
-  int batch_size = query.size(0);
-  int q_len_per_request = query.size(1);
-  int sum_seq_q = batch_size * q_len_per_request;
-  int num_qo_heads = query.size(2);
+  int batch_size;
+  int max_q_len;
+  int sum_seq_q;
+  int num_qo_heads;
+  int* cum_seq_lens_q_ptr = nullptr;
+  int* cum_seq_lens_kv_ptr = nullptr;
+  if (!optional_max_q_len.has_value()) {
+    // Uniform case: every request has the same query length.
+
+    // NOTE(Zihao): query is [B, Q, H, D]
+    // where Q is the number of query tokens per request, used in MTP
+    // based on profiled results, always use decode mode for MTP (q_len is small)
+    // example: when kv_len = 10000, q < 200, decode mode is faster
+    int q_len_per_request = query.size(1);
+    batch_size = query.size(0);
+    sum_seq_q = batch_size * q_len_per_request;
+    num_qo_heads = query.size(2);
+    max_q_len = q_len_per_request;
+  } else {
+    // Ragged case: requests may have different query lengths.
+    TVM_FFI_CHECK(cum_seq_lens_q.has_value(), "cum_seq_lens_q must be provided when max_q_len is provided");
+    TVM_FFI_CHECK(cum_seq_lens_kv.has_value(), "cum_seq_lens_kv must be provided when max_q_len is provided");
+    // the shape of query: [sum_seq_q, num_qo_heads, head_dim_q]
+    // the shape of cum_seq_lens_q: [batch_size + 1]
+    batch_size = cum_seq_lens_q.value().size(0) - 1;
+    sum_seq_q = query.size(0);
+    num_qo_heads = query.size(1);
+    max_q_len = optional_max_q_len.value();
+    cum_seq_lens_q_ptr = static_cast<int*>(cum_seq_lens_q.value().data_ptr());
+    cum_seq_lens_kv_ptr = static_cast<int*>(cum_seq_lens_kv.value().data_ptr());
+  }
   // Multiply by two for FP4 tensor as it is stored as UINT8 dtype. Assume the dim is even.
   int head_dim_k = is_4bit(kv_data_type) ? key_cache.size(-1) * 2 : key_cache.size(-1);
   int head_dim_q = is_4bit(q_data_type) ? query.size(-1) * 2 : query.size(-1);
@@ -285,9 +316,9 @@ void trtllm_paged_attention_decode(
       out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(),
       value_cache.data_ptr(), workspace_buffer.data_ptr(),
       static_cast<int*>(block_tables.data_ptr()), static_cast<int*>(seq_lens.data_ptr()),
-      /*cum_seq_lens_q=*/nullptr,
-      /*cum_seq_lens_kv=*/nullptr, attention_sinks_ptr, q_data_type, kv_data_type, o_data_type,
-      TllmPagedAttentionMode::ForGen, batch_size, /*max_q_len=*/q_len_per_request, max_kv_len,
+      cum_seq_lens_q_ptr,
+      cum_seq_lens_kv_ptr, attention_sinks_ptr, q_data_type, kv_data_type, o_data_type,
+      TllmPagedAttentionMode::ForGen, batch_size, max_q_len, max_kv_len,
       num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size,
       kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq,
       bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale,
diff --git a/flashinfer/decode.py b/flashinfer/decode.py
index 3f9f03ebb7..b878d51a17 100644
--- a/flashinfer/decode.py
+++ b/flashinfer/decode.py
@@ -1927,6 +1927,9 @@ def _paged_run(
         enable_pdl,
         workspace_size,
         sinks,
+        None,  # max_q_len
+        None,  # cum_seq_lens_q
+        None,  # cum_seq_lens_kv
     )
     return out
 
@@ -2073,7 +2076,7 @@ def trtllm_batch_decode_with_kv_cache(
     workspace_buffer: torch.Tensor,
     block_tables: torch.Tensor,
     seq_lens: torch.Tensor,
-    max_seq_len: int,
+    max_kv_len: int,
     bmm1_scale: Union[float, torch.Tensor] = 1.0,
     bmm2_scale: Union[float, torch.Tensor] = 1.0,
     window_left: int = -1,
@@ -2088,12 +2091,15 @@ def trtllm_batch_decode_with_kv_cache(
     q_len_per_req: Optional[int] = 1,
     o_scale: Optional[float] = 1.0,
     mask: Optional[torch.Tensor] = None,
+    max_q_len: Optional[int] = None,
+    cum_seq_lens_q: Optional[torch.Tensor] = None,
+    cum_seq_lens_kv: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, FP4Tensor]:
     """
     Parameters
     ----------
     query : torch.Tensor
-        query tensor with shape [num_tokens, num_heads, head_dim], num_tokens = batch_size * q_len_per_request
+        query tensor with shape [num_tokens, num_heads, head_dim], where num_tokens is the total number of query tokens in the batch
     kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
         If kv_cache is a single tensor, it should be a tensor with shape
         [num_pages, 1 or 2, num_kv_heads, page_size, head_dim] if :attr:`kv_layout` is ``HND``,
@@ -2192,6 +2198,10 @@ def trtllm_batch_decode_with_kv_cache(
             raise ValueError("xqa backend does not support nvfp4 output")
         if o_sf_scale is not None or o_sf_vec_size is not None:
             raise ValueError("xqa backend does not support o_sf_scale or o_sf_vec_size")
+        if max_q_len is not None or cum_seq_lens_q is not None or cum_seq_lens_kv is not None:
+            raise ValueError(
+                "xqa backend does not support max_q_len, cum_seq_lens_q, or cum_seq_lens_kv"
+            )
 
         # Handle out and out_dtype
         if out_dtype is None:
@@ -2206,7 +2216,7 @@ def trtllm_batch_decode_with_kv_cache(
             workspace_buffer=workspace_buffer,
             block_tables=block_tables,
             seq_lens=seq_lens,
-            max_seq_len=max_seq_len,
+            max_seq_len=max_kv_len,
             bmm1_scale=bmm1_scale,
             bmm2_scale=bmm2_scale,
             window_left=window_left,
@@ -2316,13 +2326,13 @@ def trtllm_batch_decode_with_kv_cache(
                 q_len_per_req,
                 query.size(1),
                 query.size(2),
-            ),
+            ) if q_len_per_req is not None else query,
             k_cache,
             v_cache,
             workspace_buffer,
             block_tables,
             seq_lens,
-            max_seq_len,
+            max_kv_len,
             bmm1_scale,
             bmm2_scale,
             o_sf_scale or -1.0,
@@ -2334,6 +2344,9 @@ def trtllm_batch_decode_with_kv_cache(
             enable_pdl,
             workspace_buffer.numel() * workspace_buffer.element_size(),
             sinks,
+            max_q_len,
+            cum_seq_lens_q,
+            cum_seq_lens_kv,
         )
 
     return (
diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py
index dd0002ff06..56828f065c 100755
--- a/tests/attention/test_trtllm_gen_attention.py
+++ b/tests/attention/test_trtllm_gen_attention.py
@@ -1434,3 +1434,7 @@ def test_trtllm_gen_prefill_deepseek_bs1(
     test_trtllm_gen_prefill_deepseek(
         batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal
     )
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])

From b321b273826d27ab985e656ec5bd38a5099050c7 Mon Sep 17 00:00:00 2001
From: Yaoyao Ding
Date: Wed, 26 Nov 2025 20:43:08 +0000
Subject: [PATCH 2/3] wip: add a variable-length (ragged) decode test

---
 tests/attention/test_trtllm_gen_attention.py | 111 ++++++++++++++++++-
 1 file changed, 105 insertions(+), 6 deletions(-)

diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py
index 56828f065c..1302564413 100755
--- a/tests/attention/test_trtllm_gen_attention.py
+++ b/tests/attention/test_trtllm_gen_attention.py
@@ -1,5 +1,8 @@
 import math
+import sys
+
+sys.path.append("./")
 
 import pytest
 import torch
 from tests.test_helpers.utils_fp4 import (
@@ -54,8 +57,13 @@ def generate_seq_lens_prefill(batch_size, max_q_len, max_in_kv_len):
     return q_lens, in_kv_lens, seq_lens
 
 
-def generate_seq_lens_decode(batch_size, q_len_per_req, max_in_kv_len):
-    q_lens = torch.full((batch_size,), q_len_per_req, dtype=torch.int32)
+def generate_seq_lens_decode(batch_size, q_len_per_req, max_in_kv_len, max_q_len):
+    if q_len_per_req is not None:
+        assert max_q_len is None, "Cannot specify both q_len_per_req and max_q_len."
+        q_lens = torch.full((batch_size,), q_len_per_req, dtype=torch.int32)
+    else:
+        assert max_q_len is not None, "Must specify either q_len_per_req or max_q_len."
+        q_lens = torch.randint(1, max_q_len + 1, (batch_size,), dtype=torch.int32)
     in_kv_lens = torch.randint(0, max_in_kv_len + 1, (batch_size,), dtype=torch.int)
     in_kv_lens[-1] = max_in_kv_len
     seq_lens = q_lens + in_kv_lens
@@ -746,6 +754,7 @@ def _test_trtllm_batch_decode(
     max_in_kv_len,
     head_dim,
     device_scale=False,
+    max_q_len=None,
 ):
     """
     Common function for testing trtllm-gen decode.
@@ -780,7 +789,7 @@ def _test_trtllm_batch_decode(
     # Generate random sequence lengths
     num_qo_heads = num_kv_heads * head_grp_size
     q_lens, in_kv_lens, seq_lens = generate_seq_lens_decode(
-        batch_size, q_len_per_req, max_in_kv_len
+        batch_size, q_len_per_req, max_in_kv_len, max_q_len
     )
 
     # Create query tensor and related data
@@ -835,7 +844,7 @@ def _test_trtllm_batch_decode(
         "window_left": window_left,
     }
     if not enable_sink:
-        if q_len_per_req == 1:
+        if q_len_per_req is not None and q_len_per_req == 1:
             wrapper_ref = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(
                 workspace_buffer_ref, kv_layout, use_tensor_cores=True
             )
@@ -923,6 +932,9 @@ def _test_trtllm_batch_decode(
             q_len_per_req=q_len_per_req,
             o_scale=o_scale,
             mask=mask,
+            max_q_len=max_q_len,
+            cum_seq_lens_q=q_indptr if max_q_len is not None else None,
+            cum_seq_lens_kv=kv_indptr if max_q_len is not None else None,
         )
         if backend == "trtllm-gen":
             # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero
@@ -948,7 +960,7 @@ def _test_trtllm_batch_decode(
 
     # convert to float32 for fp8 is not supported by assert_close
     # relax rtol and atol for speculative decoding test
-    if q_len_per_req > 1:
+    if (q_len_per_req and q_len_per_req > 1) or (max_q_len and max_q_len > 1):
         rtol, atol = rtol * 2, atol * 2
 
     # Arbitary small mismatch rate
@@ -1436,5 +1448,92 @@ def test_trtllm_gen_prefill_deepseek_bs1(
     test_trtllm_gen_prefill_deepseek(
         batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal
     )
+
+
+def test_trtllm_batch_decode_spec(
+    kv_layout,
+    batch_size,
+    max_q_len,
+    page_size,
+    num_kv_heads,
+    head_grp_size,
+    window_left,
+    q_dtype,
+    o_dtype,
+    kv_dtype,
+    enable_pdl,
+    enable_sink,
+    max_in_kv_len,
+    head_dim,
+):
+    _test_trtllm_batch_decode(
+        "trtllm-gen",
+        kv_layout,
+        batch_size,
+        None,  # q_len_per_req
+        page_size,
+        num_kv_heads,
+        head_grp_size,
+        window_left,
+        q_dtype,
+        o_dtype,
+        kv_dtype,
+        enable_pdl,
+        enable_sink,
+        max_in_kv_len,
+        head_dim,
+        max_q_len=max_q_len,
+    )
+
+
 if __name__ == "__main__":
-    pytest.main([__file__])
+    # pytest.main([__file__])
+    test_trtllm_batch_decode_spec(
+        kv_layout="HND",
+        batch_size=4,
+        max_q_len=12,
+        page_size=64,
+        num_kv_heads=4,
+        head_grp_size=1,
+        window_left=-1,
+        q_dtype="bf16",
+        kv_dtype="bf16",
+        o_dtype="bf16",
+        enable_pdl=None,
+        enable_sink=False,
+        max_in_kv_len=110,
+        head_dim=128,
+    )
+    # _test_trtllm_batch_decode(
+    #     backend="trtllm-gen",
+    #     kv_layout="HND",
+    #     batch_size=4,
+    #     q_len_per_req=3,
+    #     page_size=64,
+    #     num_kv_heads=4,
+    #     head_grp_size=1,
+    #     window_left=-1,
+    #     q_dtype="bf16",
+    #     kv_dtype="bf16",
+    #     o_dtype="bf16",
+    #     enable_pdl=None,
+    #     enable_sink=False,
+    #     max_in_kv_len=110,
+    #     head_dim=128,
+    # )
+
+    # _test_trtllm_batch_decode(
+    #     backend="trtllm-gen",
+    #     kv_layout="HND",
+    #     batch_size=4,
+    #     q_len_per_req=1,
+    #     page_size=64,
+    #     num_kv_heads=4,
+    #     head_grp_size=1,
+    #     window_left=-1,
+    #     q_dtype="fp8",
+    #     kv_dtype="fp8",
+    #     o_dtype="nvfp4",
+    #     enable_pdl=None,
+    #     enable_sink=False,
+    #     max_in_kv_len=110,
+    #     head_dim=128,
+    # )

From a0d6c5dbca012cfd75dce254ed2a25d86d09305a Mon Sep 17 00:00:00 2001
From: Yaoyao Ding
Date: Wed, 26 Nov 2025 21:24:03 +0000
Subject: [PATCH 3/3] wip: run the uniform-length decode tests from __main__

---
 tests/attention/test_trtllm_gen_attention.py | 71 ++++++++++----------
 1 file changed, 35 insertions(+), 36 deletions(-)

diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py
index 1302564413..ab7f59e05e 100755
--- a/tests/attention/test_trtllm_gen_attention.py
+++ b/tests/attention/test_trtllm_gen_attention.py
@@ -895,7 +895,7 @@ def _test_trtllm_batch_decode(
         kv_indptr=kv_indptr_tokens,
     )
 
-    if q_len_per_req > 1:
+    if q_len_per_req is not None and q_len_per_req > 1:
         mask = generate_causal_mask(batch_size, q_len_per_req, GPU_DEVICE)
     else:
         mask = None
@@ -1502,38 +1502,37 @@ def test_trtllm_batch_decode_spec(
             max_in_kv_len=110,
             head_dim=128,
         )
-    # _test_trtllm_batch_decode(
-    #     backend="trtllm-gen",
-    #     kv_layout="HND",
-    #     batch_size=4,
-    #     q_len_per_req=3,
-    #     page_size=64,
-    #     num_kv_heads=4,
-    #     head_grp_size=1,
-    #     window_left=-1,
-    #     q_dtype="bf16",
-    #     kv_dtype="bf16",
-    #     o_dtype="bf16",
-    #     enable_pdl=None,
-    #     enable_sink=False,
-    #     max_in_kv_len=110,
-    #     head_dim=128,
-    # )
-
-    # _test_trtllm_batch_decode(
-    #     backend="trtllm-gen",
-    #     kv_layout="HND",
-    #     batch_size=4,
-    #     q_len_per_req=1,
-    #     page_size=64,
-    #     num_kv_heads=4,
-    #     head_grp_size=1,
-    #     window_left=-1,
-    #     q_dtype="fp8",
-    #     kv_dtype="fp8",
-    #     o_dtype="nvfp4",
-    #     enable_pdl=None,
-    #     enable_sink=False,
-    #     max_in_kv_len=110,
-    #     head_dim=128,
-    # )
+    _test_trtllm_batch_decode(
+        backend="trtllm-gen",
+        kv_layout="HND",
+        batch_size=4,
+        q_len_per_req=3,
+        page_size=64,
+        num_kv_heads=4,
+        head_grp_size=1,
+        window_left=-1,
+        q_dtype="bf16",
+        kv_dtype="bf16",
+        o_dtype="bf16",
+        enable_pdl=None,
+        enable_sink=False,
+        max_in_kv_len=110,
+        head_dim=128,
+    )
+    _test_trtllm_batch_decode(
+        backend="trtllm-gen",
+        kv_layout="HND",
+        batch_size=4,
+        q_len_per_req=1,
+        page_size=64,
+        num_kv_heads=4,
+        head_grp_size=1,
+        window_left=-1,
+        q_dtype="fp8",
+        kv_dtype="fp8",
+        o_dtype="nvfp4",
+        enable_pdl=None,
+        enable_sink=False,
+        max_in_kv_len=110,
+        head_dim=128,
+    )
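
As of this series, test_trtllm_batch_decode_spec carries no @pytest.mark.parametrize decorators, so pytest treats its parameters as fixtures and the test errors under normal collection; the __main__ driver above is currently the only working entry point. A minimal sketch of the decorators the test would eventually need, reusing the values from the __main__ invocation (the single-element value lists are illustrative placeholders, not a proposed test matrix):

import pytest

@pytest.mark.parametrize("kv_layout", ["HND"])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("max_q_len", [12])
@pytest.mark.parametrize("page_size", [64])
@pytest.mark.parametrize("num_kv_heads", [4])
@pytest.mark.parametrize("head_grp_size", [1])
@pytest.mark.parametrize("window_left", [-1])
@pytest.mark.parametrize("q_dtype", ["bf16"])
@pytest.mark.parametrize("o_dtype", ["bf16"])
@pytest.mark.parametrize("kv_dtype", ["bf16"])
@pytest.mark.parametrize("enable_pdl", [None])
@pytest.mark.parametrize("enable_sink", [False])
@pytest.mark.parametrize("max_in_kv_len", [110])
@pytest.mark.parametrize("head_dim", [128])
def test_trtllm_batch_decode_spec(
    kv_layout, batch_size, max_q_len, page_size, num_kv_heads, head_grp_size,
    window_left, q_dtype, o_dtype, kv_dtype, enable_pdl, enable_sink,
    max_in_kv_len, head_dim,
):
    ...  # body unchanged from PATCH 2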
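
For context, a self-contained sketch of how the new ragged decode path is meant to be driven end to end. Everything below is illustrative rather than taken from the patches: the shapes follow the updated docstring (a [num_tokens, num_heads, head_dim] query and an HND paged KV cache), the cum_seq_lens_* arrays are exclusive prefix sums with a leading zero as described in the launcher comments, and the batch sizes, random data, and workspace size are placeholders.

import torch
import flashinfer

torch.manual_seed(0)
device = "cuda"
batch_size, num_kv_heads, head_grp_size = 4, 4, 1
head_dim, page_size = 128, 32
num_qo_heads = num_kv_heads * head_grp_size

# Ragged per-request query lengths and total KV lengths (illustrative values).
q_lens = torch.tensor([3, 1, 5, 2], dtype=torch.int32, device=device)
seq_lens = torch.tensor([110, 64, 97, 33], dtype=torch.int32, device=device)

# cum_seq_lens_* have shape [batch_size + 1]: a leading zero followed by the
# running sums, so request i owns tokens cum[i]:cum[i + 1].
zero = torch.zeros(1, dtype=torch.int32, device=device)
cum_seq_lens_q = torch.cat([zero, torch.cumsum(q_lens, 0).to(torch.int32)])
cum_seq_lens_kv = torch.cat([zero, torch.cumsum(seq_lens, 0).to(torch.int32)])

# Ragged query layout: [sum(q_lens), num_qo_heads, head_dim].
query = torch.randn(
    int(q_lens.sum()), num_qo_heads, head_dim, dtype=torch.bfloat16, device=device
)

# Paged KV cache in HND layout: [num_pages, 2, num_kv_heads, page_size, head_dim].
pages_per_seq = (seq_lens + page_size - 1) // page_size
num_pages = int(pages_per_seq.sum())
kv_cache = torch.randn(
    num_pages, 2, num_kv_heads, page_size, head_dim, dtype=torch.bfloat16, device=device
)

# Row i of the block table lists the page ids of request i (unused slots stay 0).
block_tables = torch.zeros(
    batch_size, int(pages_per_seq.max()), dtype=torch.int32, device=device
)
next_page = 0
for i in range(batch_size):
    n = int(pages_per_seq[i])
    block_tables[i, :n] = torch.arange(
        next_page, next_page + n, dtype=torch.int32, device=device
    )
    next_page += n

# The workspace size here is a guess; real callers size it for the backend.
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device=device)

out = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
    query=query,
    kv_cache=kv_cache,
    workspace_buffer=workspace_buffer,
    block_tables=block_tables,
    seq_lens=seq_lens,
    max_kv_len=int(seq_lens.max()),
    q_len_per_req=None,           # disable the uniform-length (MTP) path
    max_q_len=int(q_lens.max()),  # select the ragged path added by this series
    cum_seq_lens_q=cum_seq_lens_q,
    cum_seq_lens_kv=cum_seq_lens_kv,
)
print(out.shape)  # torch.Size([sum(q_lens), num_qo_heads, head_dim])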