From c73ba5c3fc8233fa9f38744b996f84c5fb284cf8 Mon Sep 17 00:00:00 2001 From: Avery Yingyi Huang Date: Wed, 3 Sep 2025 00:15:49 -0400 Subject: [PATCH 01/13] init --- csrc/trtllm_fmha_kernel_launcher.cu | 83 ++++++++++++++++++++--------- flashinfer/decode.py | 33 ++++++++++++ flashinfer/prefill.py | 38 ++++++++++++- tests/test_trtllm_gen_attention.py | 2 + 4 files changed, 129 insertions(+), 27 deletions(-) diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu index 37ae25d6df..69c78c528b 100644 --- a/csrc/trtllm_fmha_kernel_launcher.cu +++ b/csrc/trtllm_fmha_kernel_launcher.cu @@ -79,9 +79,10 @@ void trtllm_paged_attention_launcher( int64_t max_kv_len, int64_t num_pages_in_mem_pool, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim_qk, int64_t head_dim_vo, int64_t page_size, int64_t kv_stride_keys_values, int64_t kv_stride_heads, int64_t kv_stride_batch, int64_t max_num_blocks_per_seq, - double bmm1_scale, double bmm2_scale, double o_sf_scale, int64_t o_sf_vec_size, - int64_t o_sf_start_index, int64_t window_left, int64_t sum_seq_q, int64_t sm_count, - bool enable_pdl, int64_t workspace_size, cudaStream_t stream) { + double bmm1_scale, double bmm2_scale, float* bmm1_scale_log2_ptr, float* bmm2_scale_ptr, + double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t window_left, + int64_t sum_seq_q, int64_t sm_count, bool enable_pdl, int64_t workspace_size, + cudaStream_t stream) { if (num_qo_heads % num_kv_heads != 0) { std::ostringstream err_msg; err_msg << "num_qo_heads must be a multiple of num_kv_heads, got num_kv_heads: " << num_kv_heads @@ -120,6 +121,8 @@ void trtllm_paged_attention_launcher( runner_params.stream = stream; runner_params.outputScale = bmm2_scale; runner_params.scaleSoftmaxLog2 = bmm1_scale * M_LOG2E; + runner_params.outputScalePtr = bmm2_scale_ptr; + runner_params.scaleSoftmaxLog2Ptr = bmm1_scale_log2_ptr; runner_params.oSfPtr = out_scale_factor; runner_params.mSfStartTokenIdx = o_sf_start_index; runner_params.mScaleSfO = o_sf_scale; @@ -202,7 +205,9 @@ void trtllm_paged_attention_decode(at::Tensor out, std::optional out double bmm2_scale, double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t window_left, int64_t sm_count, bool enable_pdl, int64_t workspace_size, - std::optional attention_sinks) { + std::optional attention_sinks, + std::optional bmm1_scale_log2_tensor, + std::optional bmm2_scale_tensor) { auto q_data_type = torch_dtype_to_tllm_data_type(query.scalar_type()); auto kv_data_type = torch_dtype_to_tllm_data_type(key_cache.scalar_type()); TORCH_CHECK_EQ(key_cache.dim(), value_cache.dim()); @@ -249,6 +254,14 @@ void trtllm_paged_attention_decode(at::Tensor out, std::optional out "attention_sinks must be a float tensor"); attention_sinks_ptr = attention_sinks->data_ptr(); } + float* bmm1_scale_log2_ptr = nullptr; + float* bmm2_scale_ptr = nullptr; + if (bmm1_scale_log2_tensor.has_value()) { + bmm1_scale_log2_ptr = static_cast(bmm1_scale_log2_tensor.value().data_ptr()); + } + if (bmm2_scale_tensor.has_value()) { + bmm2_scale_ptr = static_cast(bmm2_scale_tensor.value().data_ptr()); + } trtllm_paged_attention_launcher( out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(), @@ -259,20 +272,19 @@ void trtllm_paged_attention_decode(at::Tensor out, std::optional out TllmPagedAttentionMode::ForGen, batch_size, /*max_q_len=*/q_len_per_request, max_kv_len, num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size, 
kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale, - bmm2_scale, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sm_count, - enable_pdl, workspace_size, stream); + bmm2_scale, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, + window_left, sum_seq_q, sm_count, enable_pdl, workspace_size, stream); } -void trtllm_paged_attention_context(at::Tensor out, std::optional out_scale_factor, - at::Tensor query, at::Tensor key_cache, at::Tensor value_cache, - at::Tensor workspace_buffer, at::Tensor block_tables, - at::Tensor seq_lens, int64_t max_q_len, int64_t max_kv_len, - double bmm1_scale, double bmm2_scale, double o_sf_scale, - int64_t o_sf_vec_size, int64_t o_sf_start_index, - int64_t batch_size, int64_t window_left, - at::Tensor cum_seq_lens_q, at::Tensor cum_seq_lens_kv, - int64_t sm_count, bool enable_pdl, int64_t workspace_size, - std::optional attention_sinks) { +void trtllm_paged_attention_context( + at::Tensor out, std::optional out_scale_factor, at::Tensor query, + at::Tensor key_cache, at::Tensor value_cache, at::Tensor workspace_buffer, + at::Tensor block_tables, at::Tensor seq_lens, int64_t max_q_len, int64_t max_kv_len, + double bmm1_scale, double bmm2_scale, double o_sf_scale, int64_t o_sf_vec_size, + int64_t o_sf_start_index, int64_t batch_size, int64_t window_left, at::Tensor cum_seq_lens_q, + at::Tensor cum_seq_lens_kv, int64_t sm_count, bool enable_pdl, int64_t workspace_size, + std::optional attention_sinks, std::optional bmm1_scale_log2_tensor, + std::optional bmm2_scale_tensor) { auto q_data_type = torch_dtype_to_tllm_data_type(query.scalar_type()); auto kv_data_type = torch_dtype_to_tllm_data_type(key_cache.scalar_type()); auto o_data_type = torch_dtype_to_tllm_data_type(out.scalar_type()); @@ -309,7 +321,14 @@ void trtllm_paged_attention_context(at::Tensor out, std::optional ou "attention_sinks must be a float tensor"); attention_sinks_ptr = attention_sinks->data_ptr(); } - + float* bmm1_scale_log2_ptr = nullptr; + float* bmm2_scale_ptr = nullptr; + if (bmm1_scale_log2_tensor.has_value()) { + bmm1_scale_log2_ptr = static_cast(bmm1_scale_log2_tensor.value().data_ptr()); + } + if (bmm2_scale_tensor.has_value()) { + bmm2_scale_ptr = static_cast(bmm2_scale_tensor.value().data_ptr()); + } trtllm_paged_attention_launcher( out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(), workspace_buffer.data_ptr(), static_cast(block_tables.data_ptr()), @@ -319,8 +338,9 @@ void trtllm_paged_attention_context(at::Tensor out, std::optional ou q_data_type, kv_data_type, o_data_type, TllmPagedAttentionMode::Context, batch_size, max_q_len, max_kv_len, num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size, kv_stride_keys_values, kv_stride_heads, kv_stride_batch, - max_num_blocks_per_seq, bmm1_scale, bmm2_scale, o_sf_scale, o_sf_vec_size, o_sf_start_index, - window_left, sum_seq_q, sm_count, enable_pdl, workspace_size, stream); + max_num_blocks_per_seq, bmm1_scale, bmm2_scale, bmm1_scale_log2_ptr, bmm2_scale_ptr, + o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sm_count, enable_pdl, + workspace_size, stream); } void trtllm_ragged_attention_launcher( @@ -329,8 +349,9 @@ void trtllm_ragged_attention_launcher( Data_type q_data_type, Data_type kv_data_type, Data_type o_data_type, int64_t max_q_len, int64_t max_kv_len, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim_qk, int64_t head_dim_v, int64_t 
sum_seq_q, int64_t sum_seq_kv, double bmm1_scale, double bmm2_scale, - double o_sf_scale, int64_t batch_size, int64_t window_left, int64_t sm_count, bool enable_pdl, - bool is_causal, int64_t k_stride_keys_values, int64_t k_stride_heads, int64_t k_stride_batch, + float* bmm1_scale_log2_ptr, float* bmm2_scale_ptr, double o_sf_scale, int64_t batch_size, + int64_t window_left, int64_t sm_count, bool enable_pdl, bool is_causal, + int64_t k_stride_keys_values, int64_t k_stride_heads, int64_t k_stride_batch, int64_t v_stride_keys_values, int64_t v_stride_heads, int64_t v_stride_batch, int64_t workspace_size, cudaStream_t stream) { if (num_qo_heads % num_kv_heads != 0) { @@ -360,6 +381,8 @@ void trtllm_ragged_attention_launcher( runner_params.stream = stream; runner_params.outputScale = bmm2_scale; runner_params.scaleSoftmaxLog2 = bmm1_scale * M_LOG2E; + runner_params.outputScalePtr = bmm2_scale_ptr; + runner_params.scaleSoftmaxLog2Ptr = bmm1_scale_log2_ptr; runner_params.mScaleSfO = o_sf_scale; runner_params.mChunkedAttentionSize = INT_MAX; // disable chunked attention by INT_MAX runner_params.mAttentionWindowSize = @@ -417,7 +440,9 @@ void trtllm_ragged_attention(at::Tensor out, at::Tensor query, at::Tensor key, a at::Tensor cum_seq_lens_q, at::Tensor cum_seq_lens_kv, int64_t sm_count, bool enable_pdl, bool is_causal, int64_t workspace_size, std::optional attention_sinks, - std::optional lse) { + std::optional lse, + std::optional bmm1_scale_log2_tensor, + std::optional bmm2_scale_tensor) { float* attention_sinks_ptr = nullptr; if (attention_sinks) { TORCH_CHECK(attention_sinks->scalar_type() == at::ScalarType::Float, @@ -429,6 +454,14 @@ void trtllm_ragged_attention(at::Tensor out, at::Tensor query, at::Tensor key, a TORCH_CHECK(lse->scalar_type() == at::ScalarType::Float, "lse must be a float tensor"); lse_ptr = lse->data_ptr(); } + float* bmm1_scale_log2_ptr = nullptr; + float* bmm2_scale_ptr = nullptr; + if (bmm1_scale_log2_tensor.has_value()) { + bmm1_scale_log2_ptr = static_cast(bmm1_scale_log2_tensor.value().data_ptr()); + } + if (bmm2_scale_tensor.has_value()) { + bmm2_scale_ptr = static_cast(bmm2_scale_tensor.value().data_ptr()); + } TORCH_CHECK(out.dim() == 3, "out must be a 3D tensor"); TORCH_CHECK(query.dim() == 3, "query must be a 3D tensor"); TORCH_CHECK(key.dim() == 3, "key must be a 3D tensor"); @@ -458,9 +491,9 @@ void trtllm_ragged_attention(at::Tensor out, at::Tensor query, at::Tensor key, a static_cast(cum_seq_lens_q.data_ptr()), static_cast(cum_seq_lens_kv.data_ptr()), attention_sinks_ptr, lse_ptr, q_data_type, kv_data_type, o_data_type, max_q_len, max_kv_len, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_v, sum_seq_q, sum_seq_kv, bmm1_scale, - bmm2_scale, o_sf_scale, batch_size, window_left, sm_count, enable_pdl, is_causal, - k_stride_keys_values, k_stride_heads, k_stride_batch, v_stride_keys_values, v_stride_heads, - v_stride_batch, workspace_size, stream); + bmm1_scale_log2_ptr, bmm2_scale_ptr, bmm2_scale, o_sf_scale, batch_size, window_left, + sm_count, enable_pdl, is_causal, k_stride_keys_values, k_stride_heads, k_stride_batch, + v_stride_keys_values, v_stride_heads, v_stride_batch, workspace_size, stream); } namespace trtllm_cubin_loader { diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 4d1e36e361..60ebb7a23c 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -1147,6 +1147,8 @@ def run( window_left: Optional[int] = None, sinks: Optional[torch.Tensor] = None, q_len_per_req: Optional[int] = 1, + bmm1_scale_log2_tensor: 
Optional[torch.Tensor] = None, + bmm2_scale_tensor: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Compute batch decode attention between query and paged kv cache. @@ -1186,6 +1188,10 @@ def run( Only supported for >= sm90, and currently only for FA2 and CUDA core decode. q_len_per_req : int The number of query tokens per request, if not provided, will be set to ``1``. + bmm1_scale_log2_tensor : Optional[torch.Tensor] + The on-device fused scale tensor for bmm1 input. Must be fused with * M_LOG2E before passing in. + bmm2_scale_tensor : Optional[torch.Tensor] + The on-device fused scale tensor for bmm2 input. Returns ------- Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] @@ -1296,6 +1302,8 @@ def run( page_size, self._max_kv_len, sinks, + bmm1_scale_log2_tensor, + bmm2_scale_tensor, ] self._cached_module.paged_run(*run_args) @@ -1322,6 +1330,8 @@ def run( TensorLayout[self._kv_layout].value, window_left, enable_pdl, + bmm1_scale_log2_tensor, + bmm2_scale_tensor, ] if self._jit_module is not None: @@ -1832,6 +1842,8 @@ def _paged_run( enable_pdl: bool = None, out: Optional[torch.Tensor] = None, sinks: Optional[torch.Tensor] = None, + bmm1_scale_log2_tensor: Optional[torch.Tensor] = None, + bmm2_scale_tensor: Optional[torch.Tensor] = None, ) -> torch.Tensor: if out is None: out = torch.empty_like(query) @@ -1858,6 +1870,8 @@ def _paged_run( enable_pdl, workspace_size, sinks, + bmm1_scale_log2_tensor, + bmm2_scale_tensor, ) return out @@ -1919,6 +1933,8 @@ def paged_run( page_size: Optional[int] = None, max_kv_len: Optional[int] = None, sinks: Optional[torch.Tensor] = None, + bmm1_scale_log2_tensor: Optional[torch.Tensor] = None, + bmm2_scale_tensor: Optional[torch.Tensor] = None, ) -> None: assert maybe_lse is None assert paged_kv_cache is not None @@ -1944,6 +1960,8 @@ def paged_run( enable_pdl, out=o, sinks=sinks, + bmm1_scale_log2_tensor=bmm1_scale_log2_tensor, + bmm2_scale_tensor=bmm2_scale_tensor, ) @register_fake_op(f"flashinfer::{uri}_paged_run") @@ -1983,6 +2001,8 @@ def _fake_paged_run( page_size: Optional[int] = None, max_kv_len: Optional[int] = None, sinks: Optional[torch.Tensor] = None, + bmm1_scale_log2_tensor: Optional[torch.Tensor] = None, + bmm2_scale_tensor: Optional[torch.Tensor] = None, ) -> None: pass @@ -2013,6 +2033,8 @@ def trtllm_batch_decode_with_kv_cache( sinks: Optional[List[torch.Tensor]] = None, enable_pdl: bool = None, q_len_per_req: Optional[int] = 1, + bmm1_scale_log2_tensor: Optional[torch.Tensor] = None, + bmm2_scale_tensor: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, FP4Tensor]: """ Parameters @@ -2065,6 +2087,12 @@ def trtllm_batch_decode_with_kv_cache( Whether to enable Programmatic Dependent Launch (PDL). See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization Only supported for >= sm90, and currently only for FA2, CUDA core, and trtllm-gen decode. + bmm1_scale_log2_tensor : Optional[torch.Tensor] + The on-device fused scale tensor for bmm1 input. Must be fused with * M_LOG2E before passing in. + + bmm2_scale_tensor : Optional[torch.Tensor] + The on-device fused scale tensor for bmm2 input. 
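A minimal sketch of how a caller can build these two tensors, mirroring the tests added later in this series; the head size and the q/k/v/output quantization scales below are illustrative placeholders rather than values fixed by this patch:

    import math
    import torch

    head_dim = 128                                 # illustrative head size
    q_scale = k_scale = v_scale = o_scale = 1.0    # assumed per-tensor quantization scales

    sm_scale = 1.0 / (head_dim ** 0.5)
    # The bmm1 tensor must already carry the log2(e) factor ("fused with * M_LOG2E"),
    # matching scaleSoftmaxLog2 = bmm1_scale * M_LOG2E on the scalar path in the launcher.
    bmm1_scale_log2_tensor = torch.tensor(
        [q_scale * k_scale * sm_scale * math.log2(math.e)],
        dtype=torch.float32, device="cuda",
    )
    # The bmm2 tensor folds the V-cache and output quantization scales.
    bmm2_scale_tensor = torch.tensor(
        [v_scale / o_scale], dtype=torch.float32, device="cuda",
    )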
+ Returns ------- out : Union[torch.Tensor, FP4Tensor] @@ -2182,6 +2210,8 @@ def trtllm_batch_decode_with_kv_cache( enable_pdl, workspace_buffer.numel() * workspace_buffer.element_size(), sinks, + bmm1_scale_log2_tensor, + bmm2_scale_tensor, ) return ( @@ -2320,6 +2350,7 @@ def trtllm_batch_decode_with_kv_cache_mla( "out", ) + # todo(Yingyi): check support for dynamic scale factors if bmm1_scale_log2_tensor is not None and bmm2_scale_tensor is not None: # dynamic scale factors if query.dtype != torch.float8_e4m3fn or kv_cache.dtype != torch.float8_e4m3fn: @@ -2347,5 +2378,7 @@ def trtllm_batch_decode_with_kv_cache_mla( enable_pdl, workspace_buffer.numel() * workspace_buffer.element_size(), sinks, + bmm1_scale_log2_tensor, + bmm2_scale_tensor, ) return out diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index cd5ba70d24..a0ff13299e 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -193,6 +193,8 @@ def _paged_run( window_left: int = -1, out: Optional[torch.Tensor] = None, sinks: Optional[torch.Tensor] = None, + bmm1_scale_log2_tensor: Optional[torch.Tensor] = None, + bmm2_scale_tensor: Optional[torch.Tensor] = None, ) -> torch.Tensor: sm_count = get_device_sm_count(query.device) if out is None: @@ -221,6 +223,8 @@ def _paged_run( enable_pdl, workspace_size, sinks, + bmm1_scale_log2_tensor, + bmm2_scale_tensor, ) return out @@ -541,6 +545,8 @@ def paged_run( cum_seq_lens_q: Optional[torch.Tensor] = None, cum_seq_lens_kv: Optional[torch.Tensor] = None, sinks: Optional[torch.Tensor] = None, + bmm1_scale_log2_tensor: Optional[torch.Tensor] = None, + bmm2_scale_tensor: Optional[torch.Tensor] = None, ) -> None: if backend == "trtllm-gen": assert maybe_lse is None @@ -573,9 +579,13 @@ def paged_run( window_left, out=o, sinks=sinks, + bmm1_scale_log2_tensor=bmm1_scale_log2_tensor, + bmm2_scale_tensor=bmm2_scale_tensor, ) elif backend == "fa2": assert not is_float8(q) + assert bmm1_scale_log2_tensor is None + assert bmm2_scale_tensor is None paged_run_func( float_workspace_buffer, int_workspace_buffer, @@ -606,6 +616,8 @@ def paged_run( token_pos_in_items_len, ) else: + assert bmm1_scale_log2_tensor is None + assert bmm2_scale_tensor is None if not is_float8(q): paged_run_func( float_workspace_buffer, @@ -1957,6 +1969,8 @@ def run( enable_pdl: Optional[bool] = None, window_left: Optional[int] = None, sinks: Optional[torch.Tensor] = None, + bmm1_scale_log2_tensor: Optional[torch.Tensor] = None, + bmm2_scale_tensor: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Compute batch prefill/append attention between query and paged kv-cache. @@ -1993,6 +2007,10 @@ def run( enable_pdl : bool Whether to enable Programmatic Dependent Launch (PDL). See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization Only supported for >= sm90, and currently only for FA2 and CUDA core decode. + bmm1_scale_log2_tensor : Optional[torch.Tensor] + The on-device fused scale tensor for bmm1 input. Must be fused with * M_LOG2E before passing in. + bmm2_scale_tensor : Optional[torch.Tensor] + The on-device fused scale tensor for bmm2 input. 
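These tensors are intended to be allocated once and refreshed in place between calls, so the kernel reads the current scales from device memory at launch time; this is presumably why a later commit in this series recommends them for CUDA-graph use. A rough, self-contained sketch of that update pattern, with all names illustrative:

    import math
    import torch

    # one-element device tensors, created once up front
    bmm1_scale_log2_tensor = torch.zeros(1, dtype=torch.float32, device="cuda")
    bmm2_scale_tensor = torch.ones(1, dtype=torch.float32, device="cuda")

    def refresh_scales(q_scale, k_scale, v_scale, o_scale, head_dim):
        # overwrite in place; a captured CUDA graph (or any pre-built argument list)
        # sees the new values on the next launch without a host-side rebuild
        sm_scale = 1.0 / math.sqrt(head_dim)
        bmm1_scale_log2_tensor.fill_(q_scale * k_scale * sm_scale * math.log2(math.e))
        bmm2_scale_tensor.fill_(v_scale / o_scale)

    refresh_scales(1.0, 1.0, 1.0, 1.0, 128)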
Returns ------- Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] @@ -2160,6 +2178,8 @@ def run( self._qo_indptr_buf, self._vector_sparse_indptr_buffer, sinks, + bmm1_scale_log2_tensor, + bmm2_scale_tensor, ] assert self._cached_module is not None, "cached module is not initialized" @@ -3156,6 +3176,8 @@ def trtllm_ragged_attention_deepseek( attention_sinks: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, + bmm1_scale_log2_tensor: Optional[torch.Tensor] = None, + bmm2_scale_tensor: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ Parameters @@ -3198,7 +3220,10 @@ def trtllm_ragged_attention_deepseek( output tensor, if not provided, will be allocated with shape [query.shape[0], query.shape[1], value.shape[2]] lse : Optional[torch.Tensor] lse tensor, if not provided, will be allocated with shape [query.shape[0], query.shape[1]] - + bmm1_scale_log2_tensor: Optional[torch.Tensor] = None + The on-device fused scale tensor for bmm1 input. Must be fused with * M_LOG2E before passing in. + bmm2_scale_tensor: Optional[torch.Tensor] = None + The on-device fused scale tensor for bmm2 input. Returns ------- out: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] @@ -3254,6 +3279,8 @@ def trtllm_ragged_attention_deepseek( workspace_size, attention_sinks, lse, + bmm1_scale_log2_tensor, + bmm2_scale_tensor, ) if return_lse: return out, lse @@ -3281,6 +3308,8 @@ def trtllm_batch_context_with_kv_cache( o_sf_vec_size: Optional[int] = None, enable_pdl: Optional[bool] = None, sinks: Optional[List[torch.Tensor]] = None, + bmm1_scale_log2_tensor: Optional[torch.Tensor] = None, + bmm2_scale_tensor: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, FP4Tensor]: """ Parameters @@ -3323,7 +3352,10 @@ def trtllm_batch_context_with_kv_cache( vector size for nvfp4 output tensor scale factor. sinks : Optional[List[torch.Tensor]] = None additional value per head in the denominator of the softmax. - + bmm1_scale_log2_tensor: Optional[torch.Tensor] = None + The on-device fused scale tensor for bmm1 input. Must be fused with * M_LOG2E before passing in. + bmm2_scale_tensor: Optional[torch.Tensor] = None + The on-device fused scale tensor for bmm2 input. Returns ------- out: Union[torch.Tensor, FP4Tensor] @@ -3446,6 +3478,8 @@ def trtllm_batch_context_with_kv_cache( enable_pdl, workspace_size, sinks, + bmm1_scale_log2_tensor, + bmm2_scale_tensor, ) return ( out diff --git a/tests/test_trtllm_gen_attention.py b/tests/test_trtllm_gen_attention.py index f42d418d04..970a70f8ee 100755 --- a/tests/test_trtllm_gen_attention.py +++ b/tests/test_trtllm_gen_attention.py @@ -575,6 +575,7 @@ def test_trtllm_batch_decode( o_sf_vec_size=o_sf_vec_size, enable_pdl=enable_pdl, q_len_per_req=q_len_per_req, + # todo(Yingyi): add bmm1_scale_log2_tensor and bmm2_scale_tensor later ) if o_dtype == "nvfp4": @@ -618,6 +619,7 @@ def test_trtllm_batch_decode( v_scale=v_scale / o_scale, enable_pdl=enable_pdl, q_len_per_req=q_len_per_req, + # todo(Yingyi): add bmm1_scale_log2_tensor and bmm2_scale_tensor later ) # v_scale, o_scale in wrapper is emulated by multiplying output by v_scale instead of fused into kernel. 
if v_scale == o_scale == 1.0: From a2f5be1fe50a90961273b7c571672deb700c1473 Mon Sep 17 00:00:00 2001 From: Avery Yingyi Huang Date: Wed, 3 Sep 2025 01:06:22 -0400 Subject: [PATCH 02/13] upd decode test --- tests/test_trtllm_gen_attention.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/test_trtllm_gen_attention.py b/tests/test_trtllm_gen_attention.py index 970a70f8ee..8400edfade 100755 --- a/tests/test_trtllm_gen_attention.py +++ b/tests/test_trtllm_gen_attention.py @@ -558,6 +558,10 @@ def test_trtllm_batch_decode( # Run trtllm-gen function call sm_scale = float(1.0 / (head_dim**0.5)) + bmm1_scale_log2_tensor = torch.tensor( + [sm_scale * math.log2(math.e)], device=GPU_DEVICE + ) + bmm2_scale_tensor = torch.tensor([1.0], device=GPU_DEVICE) output = flashinfer.decode.trtllm_batch_decode_with_kv_cache( q.contiguous(), @@ -575,7 +579,8 @@ def test_trtllm_batch_decode( o_sf_vec_size=o_sf_vec_size, enable_pdl=enable_pdl, q_len_per_req=q_len_per_req, - # todo(Yingyi): add bmm1_scale_log2_tensor and bmm2_scale_tensor later + bmm1_scale_log2_tensor=bmm1_scale_log2_tensor, + bmm2_scale_tensor=bmm2_scale_tensor, ) if o_dtype == "nvfp4": @@ -619,7 +624,8 @@ def test_trtllm_batch_decode( v_scale=v_scale / o_scale, enable_pdl=enable_pdl, q_len_per_req=q_len_per_req, - # todo(Yingyi): add bmm1_scale_log2_tensor and bmm2_scale_tensor later + bmm1_scale_log2_tensor=bmm1_scale_log2_tensor, + bmm2_scale_tensor=bmm2_scale_tensor, ) # v_scale, o_scale in wrapper is emulated by multiplying output by v_scale instead of fused into kernel. if v_scale == o_scale == 1.0: From 1e3ac69ac296f95d11423938a5eb90c08bde9e2e Mon Sep 17 00:00:00 2001 From: Avery Yingyi Huang Date: Wed, 3 Sep 2025 01:12:45 -0400 Subject: [PATCH 03/13] upd prefill test --- tests/test_trtllm_gen_attention.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/test_trtllm_gen_attention.py b/tests/test_trtllm_gen_attention.py index 8400edfade..83fddf654b 100755 --- a/tests/test_trtllm_gen_attention.py +++ b/tests/test_trtllm_gen_attention.py @@ -352,6 +352,10 @@ def test_trtllm_batch_prefill( # Run trtllm-gen function call sm_scale = float(1.0 / (head_dim**0.5)) + bmm1_scale_log2_tensor = torch.tensor( + [sm_scale * math.log2(math.e)], device=GPU_DEVICE + ) + bmm2_scale_tensor = torch.tensor([1.0], device=GPU_DEVICE) output = flashinfer.prefill.trtllm_batch_context_with_kv_cache( q.contiguous(), kv_cache, @@ -371,6 +375,8 @@ def test_trtllm_batch_prefill( o_sf_scale=o_sf_scale, o_sf_vec_size=o_sf_vec_size, enable_pdl=enable_pdl, + bmm1_scale_log2_tensor=bmm1_scale_log2_tensor, + bmm2_scale_tensor=bmm2_scale_tensor, ) if o_dtype == "nvfp4": @@ -406,6 +412,8 @@ def test_trtllm_batch_prefill( k_scale=k_scale, v_scale=v_scale / o_scale, enable_pdl=enable_pdl, + bmm1_scale_log2_tensor=bmm1_scale_log2_tensor, + bmm2_scale_tensor=bmm2_scale_tensor, ) # v_scale, o_scale in wrapper is emulated by multiplying output by v_scale instead of fused into kernel. 
if v_scale == o_scale == 1.0: @@ -751,6 +759,8 @@ def test_trtllm_gen_prefill_deepseek( bmm1_scale = scale bmm2_scale = 1.0 + bmm1_scale_log2_tensor = torch.tensor([scale * math.log2(math.e)], device=device) + bmm2_scale_tensor = torch.tensor([1.0], device=device) output_trtllm, lse_trtllm = flashinfer.prefill.trtllm_ragged_attention_deepseek( q, k_cache, @@ -770,6 +780,8 @@ def test_trtllm_gen_prefill_deepseek( causal, True, out=output, + bmm1_scale_log2_tensor=bmm1_scale_log2_tensor, + bmm2_scale_tensor=bmm2_scale_tensor, ) torch.testing.assert_close( output_trtllm, From 04bc848abf917ebfbc18498d1a71da9cca3848a4 Mon Sep 17 00:00:00 2001 From: Avery Yingyi Huang Date: Wed, 3 Sep 2025 01:16:24 -0400 Subject: [PATCH 04/13] upd mla test --- tests/test_trtllm_gen_attention.py | 131 ----------------------------- tests/test_trtllm_gen_mla.py | 131 +++++++++++++++++++++++++++++ 2 files changed, 131 insertions(+), 131 deletions(-) diff --git a/tests/test_trtllm_gen_attention.py b/tests/test_trtllm_gen_attention.py index 83fddf654b..9b7f5eaf0a 100755 --- a/tests/test_trtllm_gen_attention.py +++ b/tests/test_trtllm_gen_attention.py @@ -666,137 +666,6 @@ def test_trtllm_batch_decode( ) -@pytest.mark.parametrize("batch_size", [4, 128, 256]) -@pytest.mark.parametrize("s_qo", [32, 64, 87]) -@pytest.mark.parametrize("s_kv", [32, 64, 87]) -@pytest.mark.parametrize("num_kv_heads", [16, 32]) -@pytest.mark.parametrize("head_grp_size", [1, 5, 8]) -@pytest.mark.parametrize("causal", [True, False]) -def test_trtllm_gen_prefill_deepseek( - batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal -): - if s_qo > s_kv: - pytest.skip("s_qo > s_kv, skipping test as causal") - - num_qo_heads = num_kv_heads * head_grp_size - head_dim_qk = 192 - head_dim_vo = 128 - - seed = 0 - torch.manual_seed(seed) - device = "cuda:0" - - actual_seq_lens_q = torch.randint( - 1, s_qo + 1, (batch_size, 1, 1, 1), dtype=torch.int32, device=device - ) - - actual_seq_lens_kv = torch.randint( - s_qo, s_kv + 1, (batch_size, 1, 1, 1), dtype=torch.int32, device=device - ) - - cumsum_s_qo = torch.sum(actual_seq_lens_q) - cumsum_s_kv = torch.sum(actual_seq_lens_kv) - - q = torch.randn( - cumsum_s_qo, num_qo_heads, head_dim_qk, device=device, dtype=torch.bfloat16 - ) - - k_cache = torch.randn( - (cumsum_s_kv, num_kv_heads, head_dim_qk), - device=device, - dtype=torch.bfloat16, - ) - v_cache = torch.randn( - (cumsum_s_kv, num_kv_heads, head_dim_vo), - device=device, - dtype=torch.bfloat16, - ) - - # Initialize scale - scale = float(1.0 / (head_dim_qk**0.5)) - - workspace_buffer = torch.empty(workspace_size, dtype=torch.int8, device=device) - - qo_indptr = torch.cat( - [ - torch.tensor([0], device=device), - torch.cumsum(actual_seq_lens_q.view(-1), dim=0), - ] - ).int() - - # kv_indptr = torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * s_kv - - # Create kv_indptr as cumulative sum of actual_seq_lens_kv - kv_indptr = torch.cat( - [ - torch.tensor( - [0], - device=device, - ), - torch.cumsum(actual_seq_lens_kv.view(-1), dim=0), - ] - ).int() - - wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( - torch.zeros(workspace_size, device="cuda", dtype=torch.uint8), - kv_layout="NHD", - backend="cutlass", - ) - wrapper.plan( - qo_indptr, - kv_indptr, - num_qo_heads, - num_kv_heads, - head_dim_qk, - head_dim_vo=head_dim_vo, - causal=causal, - sm_scale=scale, - q_data_type=torch.bfloat16, - kv_data_type=torch.bfloat16, - ) - output_ref, lse_ref = wrapper.run(q, k_cache, v_cache, return_lse=True) - output = 
torch.empty_like(output_ref) - - bmm1_scale = scale - bmm2_scale = 1.0 - bmm1_scale_log2_tensor = torch.tensor([scale * math.log2(math.e)], device=device) - bmm2_scale_tensor = torch.tensor([1.0], device=device) - output_trtllm, lse_trtllm = flashinfer.prefill.trtllm_ragged_attention_deepseek( - q, - k_cache, - v_cache, - workspace_buffer, - actual_seq_lens_kv, - s_qo, - s_kv, - bmm1_scale, - bmm2_scale, - -1, - batch_size, - -1, - qo_indptr, - kv_indptr, - False, - causal, - True, - out=output, - bmm1_scale_log2_tensor=bmm1_scale_log2_tensor, - bmm2_scale_tensor=bmm2_scale_tensor, - ) - torch.testing.assert_close( - output_trtllm, - output_ref, - atol=1e-2, - rtol=1e-2, - ) - torch.testing.assert_close( - lse_trtllm, - lse_ref, - atol=1e-3, - rtol=1e-3, - ) - - if __name__ == "__main__": test_trtllm_batch_prefill("HND", 128, 32, 2, 5, -1, "fp16", "fp16", "fp16", False) test_trtllm_batch_decode("HND", 256, 3, 64, 4, 5, -1, "fp8", "fp8", "fp8", True) diff --git a/tests/test_trtllm_gen_mla.py b/tests/test_trtllm_gen_mla.py index e73da75337..0dc2349e11 100644 --- a/tests/test_trtllm_gen_mla.py +++ b/tests/test_trtllm_gen_mla.py @@ -9,6 +9,137 @@ workspace_size = 128 * 1024 * 1024 +@pytest.mark.parametrize("batch_size", [4, 128, 256]) +@pytest.mark.parametrize("s_qo", [32, 64, 87]) +@pytest.mark.parametrize("s_kv", [32, 64, 87]) +@pytest.mark.parametrize("num_kv_heads", [16, 32]) +@pytest.mark.parametrize("head_grp_size", [1, 5, 8]) +@pytest.mark.parametrize("causal", [True, False]) +def test_trtllm_gen_prefill_deepseek( + batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal +): + if s_qo > s_kv: + pytest.skip("s_qo > s_kv, skipping test as causal") + + num_qo_heads = num_kv_heads * head_grp_size + head_dim_qk = 192 + head_dim_vo = 128 + + seed = 0 + torch.manual_seed(seed) + device = "cuda:0" + + actual_seq_lens_q = torch.randint( + 1, s_qo + 1, (batch_size, 1, 1, 1), dtype=torch.int32, device=device + ) + + actual_seq_lens_kv = torch.randint( + s_qo, s_kv + 1, (batch_size, 1, 1, 1), dtype=torch.int32, device=device + ) + + cumsum_s_qo = torch.sum(actual_seq_lens_q) + cumsum_s_kv = torch.sum(actual_seq_lens_kv) + + q = torch.randn( + cumsum_s_qo, num_qo_heads, head_dim_qk, device=device, dtype=torch.bfloat16 + ) + + k_cache = torch.randn( + (cumsum_s_kv, num_kv_heads, head_dim_qk), + device=device, + dtype=torch.bfloat16, + ) + v_cache = torch.randn( + (cumsum_s_kv, num_kv_heads, head_dim_vo), + device=device, + dtype=torch.bfloat16, + ) + + # Initialize scale + scale = float(1.0 / (head_dim_qk**0.5)) + + workspace_buffer = torch.empty(workspace_size, dtype=torch.int8, device=device) + + qo_indptr = torch.cat( + [ + torch.tensor([0], device=device), + torch.cumsum(actual_seq_lens_q.view(-1), dim=0), + ] + ).int() + + # kv_indptr = torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * s_kv + + # Create kv_indptr as cumulative sum of actual_seq_lens_kv + kv_indptr = torch.cat( + [ + torch.tensor( + [0], + device=device, + ), + torch.cumsum(actual_seq_lens_kv.view(-1), dim=0), + ] + ).int() + + wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( + torch.zeros(workspace_size, device="cuda", dtype=torch.uint8), + kv_layout="NHD", + backend="cutlass", + ) + wrapper.plan( + qo_indptr, + kv_indptr, + num_qo_heads, + num_kv_heads, + head_dim_qk, + head_dim_vo=head_dim_vo, + causal=causal, + sm_scale=scale, + q_data_type=torch.bfloat16, + kv_data_type=torch.bfloat16, + ) + output_ref, lse_ref = wrapper.run(q, k_cache, v_cache, return_lse=True) + output = 
torch.empty_like(output_ref) + + bmm1_scale = scale + bmm2_scale = 1.0 + bmm1_scale_log2_tensor = torch.tensor([scale * math.log2(math.e)], device=device) + bmm2_scale_tensor = torch.tensor([1.0], device=device) + output_trtllm, lse_trtllm = flashinfer.prefill.trtllm_ragged_attention_deepseek( + q, + k_cache, + v_cache, + workspace_buffer, + actual_seq_lens_kv, + s_qo, + s_kv, + bmm1_scale, + bmm2_scale, + -1, + batch_size, + -1, + qo_indptr, + kv_indptr, + False, + causal, + True, + out=output, + bmm1_scale_log2_tensor=bmm1_scale_log2_tensor, + bmm2_scale_tensor=bmm2_scale_tensor, + ) + torch.testing.assert_close( + output_trtllm, + output_ref, + atol=1e-2, + rtol=1e-2, + ) + torch.testing.assert_close( + lse_trtllm, + lse_ref, + atol=1e-3, + rtol=1e-3, + ) + + @pytest.mark.parametrize( "batch_size", [1, 2, 4, 16, 32, 64, 128, 256, 512, 768, 1024], From c0c07d151e876e202bd0d1211ff9f76bfc85c722 Mon Sep 17 00:00:00 2001 From: Avery Yingyi Huang Date: Wed, 3 Sep 2025 01:49:07 -0400 Subject: [PATCH 05/13] fix --- csrc/trtllm_fmha_kernel_launcher.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu index 69c78c528b..784b443dca 100644 --- a/csrc/trtllm_fmha_kernel_launcher.cu +++ b/csrc/trtllm_fmha_kernel_launcher.cu @@ -491,7 +491,7 @@ void trtllm_ragged_attention(at::Tensor out, at::Tensor query, at::Tensor key, a static_cast(cum_seq_lens_q.data_ptr()), static_cast(cum_seq_lens_kv.data_ptr()), attention_sinks_ptr, lse_ptr, q_data_type, kv_data_type, o_data_type, max_q_len, max_kv_len, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_v, sum_seq_q, sum_seq_kv, bmm1_scale, - bmm1_scale_log2_ptr, bmm2_scale_ptr, bmm2_scale, o_sf_scale, batch_size, window_left, + bmm2_scale, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, batch_size, window_left, sm_count, enable_pdl, is_causal, k_stride_keys_values, k_stride_heads, k_stride_batch, v_stride_keys_values, v_stride_heads, v_stride_batch, workspace_size, stream); } From d347bb6b69bebe8f631d2a99bec570dc7c408bc4 Mon Sep 17 00:00:00 2001 From: Avery Yingyi Huang Date: Wed, 3 Sep 2025 15:01:10 -0400 Subject: [PATCH 06/13] fix --- flashinfer/decode.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 60ebb7a23c..c981db9a36 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -1330,8 +1330,6 @@ def run( TensorLayout[self._kv_layout].value, window_left, enable_pdl, - bmm1_scale_log2_tensor, - bmm2_scale_tensor, ] if self._jit_module is not None: From ec0e22c431d5c444fa8c36fd5cfd7aba4794237f Mon Sep 17 00:00:00 2001 From: Avery Yingyi Huang Date: Wed, 3 Sep 2025 17:22:43 -0400 Subject: [PATCH 07/13] comment --- flashinfer/decode.py | 20 ++++++++++---------- flashinfer/prefill.py | 16 ++++++++-------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/flashinfer/decode.py b/flashinfer/decode.py index c981db9a36..f8178a3017 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -1189,9 +1189,9 @@ def run( q_len_per_req : int The number of query tokens per request, if not provided, will be set to ``1``. bmm1_scale_log2_tensor : Optional[torch.Tensor] - The on-device fused scale tensor for bmm1 input. Must be fused with * M_LOG2E before passing in. + The on-device fused scale tensor for bmm1 input. Must be fused with * M_LOG2E before passing in. If provided, the scalar scale factor will be ignored. bmm2_scale_tensor : Optional[torch.Tensor] - The on-device fused scale tensor for bmm2 input. 
+ The on-device fused scale tensor for bmm2 input. If provided, the scalar scale factor will be ignored. Returns ------- Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] @@ -2057,10 +2057,10 @@ def trtllm_batch_decode_with_kv_cache( max sequence length for kv_cache bmm1_scale : float - fused scale for bmm1 input. + fused scale for bmm1 input. We recommend to use bmm1_scale_log2_tensor for better cuda graph support. bmm2_scale : float - fused scale for bmm2 input. + fused scale for bmm2 input. We recommend to use bmm2_scale_tensor for better cuda graph support. window_left : int = -1 The left (inclusive) window size for the attention window, when set to ``-1``, the window @@ -2086,10 +2086,10 @@ def trtllm_batch_decode_with_kv_cache( Only supported for >= sm90, and currently only for FA2, CUDA core, and trtllm-gen decode. bmm1_scale_log2_tensor : Optional[torch.Tensor] - The on-device fused scale tensor for bmm1 input. Must be fused with * M_LOG2E before passing in. + The on-device fused scale tensor for bmm1 input. Must be fused with * M_LOG2E before passing in. If provided, the scalar scale factor will be ignored. bmm2_scale_tensor : Optional[torch.Tensor] - The on-device fused scale tensor for bmm2 input. + The on-device fused scale tensor for bmm2 input. If provided, the scalar scale factor will be ignored. Returns ------- @@ -2291,10 +2291,10 @@ def trtllm_batch_decode_with_kv_cache_mla( seq_lens: query_len max_seq_len: max sequence length for kv_cache out: output tensor, if not provided, will be allocated internally - bmm1_scale: fused scale for mla bmm1 input. - bmm2_scale: fused scale for mla bmm2 input. - bmm1_scale_log2_tensor: On-device fused scale tensor for mla bmm1 input. Must be fused with * M_LOG2E before passing in. - bmm2_scale_tensor: On-device fused scale tensor for mla bmm2 input. + bmm1_scale: fused scale for mla bmm1 input. We recommend to use bmm1_scale_log2_tensor for better cuda graph support. + bmm2_scale: fused scale for mla bmm2 input. We recommend to use bmm2_scale_tensor for better cuda graph support. + bmm1_scale_log2_tensor: On-device fused scale tensor for mla bmm1 input. Must be fused with * M_LOG2E before passing in. If provided, the scalar scale factor will be ignored. + bmm2_scale_tensor: On-device fused scale tensor for mla bmm2 input. If provided, the scalar scale factor will be ignored. sinks: additional value per head in the denominator of the softmax. Note: diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index a0ff13299e..b48a8c1172 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -3197,9 +3197,9 @@ def trtllm_ragged_attention_deepseek( max_kv_len : int max key/value length bmm1_scale : float - scale for bmm1, scale_q * scale_k * 1.0 / (head_dim_qk ** 0.5) + scale for bmm1, scale_q * scale_k * 1.0 / (head_dim_qk ** 0.5), but we recommend to use bmm1_scale_log2_tensor for better cuda graph support. bmm2_scale : float - scale for bmm2, scale_v + scale for bmm2, scale_v, but we recommend to use bmm2_scale_tensor for better cuda graph support. o_sf_scale : float scale for output batch_size : int @@ -3221,9 +3221,9 @@ def trtllm_ragged_attention_deepseek( lse : Optional[torch.Tensor] lse tensor, if not provided, will be allocated with shape [query.shape[0], query.shape[1]] bmm1_scale_log2_tensor: Optional[torch.Tensor] = None - The on-device fused scale tensor for bmm1 input. Must be fused with * M_LOG2E before passing in. + The on-device fused scale tensor for bmm1 input. Must be fused with * M_LOG2E before passing in. 
If provided, the bmm1_scale will be ignored. bmm2_scale_tensor: Optional[torch.Tensor] = None - The on-device fused scale tensor for bmm2 input. + The on-device fused scale tensor for bmm2 input. If provided, the bmm2_scale will be ignored. Returns ------- out: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] @@ -3330,9 +3330,9 @@ def trtllm_batch_context_with_kv_cache( max_kv_len : int max sequence length for kv_cache bmm1_scale : float - fused scale for bmm1 input. + fused scale for bmm1 input. But we recommend to use bmm1_scale_log2_tensor for better cuda graph support. bmm2_scale : float - fused scale for bmm2 input. + fused scale for bmm2 input. But we recommend to use bmm2_scale_tensor for better cuda graph support. batch_size : int batch size cum_seq_lens_q : torch.Tensor @@ -3353,9 +3353,9 @@ def trtllm_batch_context_with_kv_cache( sinks : Optional[List[torch.Tensor]] = None additional value per head in the denominator of the softmax. bmm1_scale_log2_tensor: Optional[torch.Tensor] = None - The on-device fused scale tensor for bmm1 input. Must be fused with * M_LOG2E before passing in. + The on-device fused scale tensor for bmm1 input. Must be fused with * M_LOG2E before passing in. If provided, the bmm1_scale will be ignored. bmm2_scale_tensor: Optional[torch.Tensor] = None - The on-device fused scale tensor for bmm2 input. + The on-device fused scale tensor for bmm2 input. If provided, the bmm2_scale will be ignored. Returns ------- out: Union[torch.Tensor, FP4Tensor] From 45f8ad0ee55bdd7f80c5872e3ee4f875670c0acd Mon Sep 17 00:00:00 2001 From: Avery Yingyi Huang Date: Wed, 3 Sep 2025 17:24:35 -0400 Subject: [PATCH 08/13] upd mla test --- tests/test_trtllm_gen_mla.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/test_trtllm_gen_mla.py b/tests/test_trtllm_gen_mla.py index 0dc2349e11..361a69f633 100644 --- a/tests/test_trtllm_gen_mla.py +++ b/tests/test_trtllm_gen_mla.py @@ -102,8 +102,10 @@ def test_trtllm_gen_prefill_deepseek( bmm1_scale = scale bmm2_scale = 1.0 - bmm1_scale_log2_tensor = torch.tensor([scale * math.log2(math.e)], device=device) - bmm2_scale_tensor = torch.tensor([1.0], device=device) + bmm1_scale_log2_tensor = torch.tensor( + [bmm1_scale * math.log2(math.e)], device=device + ) + bmm2_scale_tensor = torch.tensor([bmm2_scale], device=device) output_trtllm, lse_trtllm = flashinfer.prefill.trtllm_ragged_attention_deepseek( q, k_cache, @@ -112,8 +114,8 @@ def test_trtllm_gen_prefill_deepseek( actual_seq_lens_kv, s_qo, s_kv, - bmm1_scale, - bmm2_scale, + 1.0, # should be bmm1_scale, just for testing dynamic in-memory scale factors + 0.0, # should be bmm2_scale, just for testing dynamic in-memory scale factors -1, batch_size, -1, @@ -260,8 +262,8 @@ def test_trtllm_batch_decode_mla( block_tables=block_tables, seq_lens=seq_lens_tensor, max_seq_len=max_seq_len, - bmm1_scale=scale / ((128 + 64) ** 0.5), - bmm2_scale=1.0, + bmm1_scale=1.0, # should be scale / ((128 + 64) ** 0.5), just for testing dynamic in-memory scale factors + bmm2_scale=0.0, # should be 1.0, just for testing dynamic in-memory scale factors bmm1_scale_log2_tensor=bmm1_log2_scale_tensor, bmm2_scale_tensor=bmm2_scale_tensor, enable_pdl=enable_pdl, From 85f185419d4667ad5aada010853bd5970962c811 Mon Sep 17 00:00:00 2001 From: Avery Yingyi Huang Date: Wed, 3 Sep 2025 17:47:04 -0400 Subject: [PATCH 09/13] fix --- tests/test_trtllm_gen_attention.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git 
a/tests/test_trtllm_gen_attention.py b/tests/test_trtllm_gen_attention.py index 9b7f5eaf0a..a4a6b3cbb9 100755 --- a/tests/test_trtllm_gen_attention.py +++ b/tests/test_trtllm_gen_attention.py @@ -352,10 +352,12 @@ def test_trtllm_batch_prefill( # Run trtllm-gen function call sm_scale = float(1.0 / (head_dim**0.5)) + bmm1_scale = q_scale * k_scale * sm_scale bmm1_scale_log2_tensor = torch.tensor( - [sm_scale * math.log2(math.e)], device=GPU_DEVICE + [bmm1_scale * math.log2(math.e)], device=GPU_DEVICE ) - bmm2_scale_tensor = torch.tensor([1.0], device=GPU_DEVICE) + bmm2_scale = v_scale / o_scale + bmm2_scale_tensor = torch.tensor([bmm2_scale], device=GPU_DEVICE) output = flashinfer.prefill.trtllm_batch_context_with_kv_cache( q.contiguous(), kv_cache, @@ -364,8 +366,8 @@ def test_trtllm_batch_prefill( seq_lens.to(GPU_DEVICE), torch.max(q_lens).item(), torch.max(seq_lens).item(), - q_scale * k_scale * sm_scale, # bmm1_scale - v_scale / o_scale, # bmm2_scale + bmm1_scale, + bmm2_scale, batch_size, q_indptr, kv_indptr, @@ -566,10 +568,12 @@ def test_trtllm_batch_decode( # Run trtllm-gen function call sm_scale = float(1.0 / (head_dim**0.5)) + bmm1_scale = q_scale * k_scale * sm_scale bmm1_scale_log2_tensor = torch.tensor( - [sm_scale * math.log2(math.e)], device=GPU_DEVICE + [bmm1_scale * math.log2(math.e)], device=GPU_DEVICE ) - bmm2_scale_tensor = torch.tensor([1.0], device=GPU_DEVICE) + bmm2_scale = v_scale / o_scale + bmm2_scale_tensor = torch.tensor([bmm2_scale], device=GPU_DEVICE) output = flashinfer.decode.trtllm_batch_decode_with_kv_cache( q.contiguous(), @@ -578,8 +582,8 @@ def test_trtllm_batch_decode( page_table, seq_lens.to(GPU_DEVICE), torch.max(seq_lens).item(), - q_scale * k_scale * sm_scale, # bmm1_scale - v_scale / o_scale, # bmm2_scale + bmm1_scale, # bmm1_scale + bmm2_scale, # bmm2_scale window_left, # window_left out=out, out_dtype=out_dtype, From f9518d0252e6bee2ef5cfc9596ad6806d7c761d7 Mon Sep 17 00:00:00 2001 From: Avery Yingyi Huang Date: Wed, 3 Sep 2025 18:42:46 -0400 Subject: [PATCH 10/13] upd mla test --- tests/test_trtllm_gen_mla.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/test_trtllm_gen_mla.py b/tests/test_trtllm_gen_mla.py index 361a69f633..640d5d3f32 100644 --- a/tests/test_trtllm_gen_mla.py +++ b/tests/test_trtllm_gen_mla.py @@ -236,9 +236,11 @@ def test_trtllm_batch_decode_mla( ) workspace_buffer = global_workspace_buffer + bmm1_scale = scale / ((128 + 64) ** 0.5) + bmm2_scale = 1.0 bmm1_log2_scale_tensor = ( torch.tensor( - [scale / ((128 + 64) ** 0.5 * math.log2(math.e))], + [bmm1_scale * math.log2(math.e)], dtype=torch.float32, device=device, ) @@ -246,7 +248,7 @@ def test_trtllm_batch_decode_mla( else None ) bmm2_scale_tensor = ( - torch.tensor([1.0], dtype=torch.float32, device=device) + torch.tensor([bmm2_scale], dtype=torch.float32, device=device) if dynamic_scale else None ) @@ -262,8 +264,8 @@ def test_trtllm_batch_decode_mla( block_tables=block_tables, seq_lens=seq_lens_tensor, max_seq_len=max_seq_len, - bmm1_scale=1.0, # should be scale / ((128 + 64) ** 0.5), just for testing dynamic in-memory scale factors - bmm2_scale=0.0, # should be 1.0, just for testing dynamic in-memory scale factors + bmm1_scale=1.0, # should be bmm1_scale, just for testing dynamic in-memory scale factors + bmm2_scale=0.0, # should be bmm2_scale, just for testing dynamic in-memory scale factors bmm1_scale_log2_tensor=bmm1_log2_scale_tensor, bmm2_scale_tensor=bmm2_scale_tensor, enable_pdl=enable_pdl, From 
2e98f38e93365cd1ed7301ff192f14afeb4cbfcf Mon Sep 17 00:00:00 2001 From: Avery Yingyi Huang Date: Wed, 3 Sep 2025 22:31:06 -0400 Subject: [PATCH 11/13] fix mha test --- tests/test_trtllm_gen_attention.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_trtllm_gen_attention.py b/tests/test_trtllm_gen_attention.py index a4a6b3cbb9..6d0f11ddfc 100755 --- a/tests/test_trtllm_gen_attention.py +++ b/tests/test_trtllm_gen_attention.py @@ -407,6 +407,9 @@ def test_trtllm_batch_prefill( plan_params["q_data_type"] = q.dtype plan_params["kv_data_type"] = kv_cache.dtype wrapper_trtllm_gen.plan(**plan_params) + bmm2_scale_tensor = torch.tensor( + [1.0], device=GPU_DEVICE + ) # todo(Yingyi): wrapper accept fixed bmm2_scale as 1.0 output_wrapper = wrapper_trtllm_gen.run( q.contiguous(), kv_cache, @@ -628,6 +631,9 @@ def test_trtllm_batch_decode( plan_params["q_data_type"] = q.dtype plan_params["kv_data_type"] = kv_cache.dtype wrapper_trtllm_gen.plan(**plan_params) + bmm2_scale_tensor = torch.tensor( + [1.0], device=GPU_DEVICE + ) # todo(Yingyi): wrapper accept fixed bmm2_scale as 1.0 output_wrapper = wrapper_trtllm_gen.run( q.contiguous(), kv_cache, From ed6aff776b3fa2166f9fbd655e637226c0275672 Mon Sep 17 00:00:00 2001 From: Avery Yingyi Huang Date: Wed, 3 Sep 2025 22:40:27 -0400 Subject: [PATCH 12/13] undo mla test --- tests/test_trtllm_gen_mla.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/test_trtllm_gen_mla.py b/tests/test_trtllm_gen_mla.py index 640d5d3f32..25e9f9b9bb 100644 --- a/tests/test_trtllm_gen_mla.py +++ b/tests/test_trtllm_gen_mla.py @@ -114,8 +114,8 @@ def test_trtllm_gen_prefill_deepseek( actual_seq_lens_kv, s_qo, s_kv, - 1.0, # should be bmm1_scale, just for testing dynamic in-memory scale factors - 0.0, # should be bmm2_scale, just for testing dynamic in-memory scale factors + bmm1_scale, # todo(Yingyi): disable this scale for testing dynamic in-memory scale factors + bmm2_scale, # todo(Yingyi): disable this scale for testing dynamic in-memory scale factors -1, batch_size, -1, @@ -125,6 +125,7 @@ def test_trtllm_gen_prefill_deepseek( causal, True, out=output, + # todo(Yingyi): enable this scale for testing dynamic in-memory scale factors bmm1_scale_log2_tensor=bmm1_scale_log2_tensor, bmm2_scale_tensor=bmm2_scale_tensor, ) @@ -264,8 +265,9 @@ def test_trtllm_batch_decode_mla( block_tables=block_tables, seq_lens=seq_lens_tensor, max_seq_len=max_seq_len, - bmm1_scale=1.0, # should be bmm1_scale, just for testing dynamic in-memory scale factors - bmm2_scale=0.0, # should be bmm2_scale, just for testing dynamic in-memory scale factors + bmm1_scale=bmm1_scale, # todo(Yingyi): disable this scale for testing dynamic in-memory scale factors + bmm2_scale=bmm2_scale, # todo(Yingyi): disable this scale for testing dynamic in-memory scale factors + # todo(Yingyi): enable this scale for testing dynamic in-memory scale factors bmm1_scale_log2_tensor=bmm1_log2_scale_tensor, bmm2_scale_tensor=bmm2_scale_tensor, enable_pdl=enable_pdl, From fe284e90ced12b153e7194601afd58789e60b6a8 Mon Sep 17 00:00:00 2001 From: Avery Yingyi Huang Date: Wed, 3 Sep 2025 23:11:09 -0400 Subject: [PATCH 13/13] disable scalar scale for mha test --- tests/test_trtllm_gen_attention.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_trtllm_gen_attention.py b/tests/test_trtllm_gen_attention.py index 6d0f11ddfc..612b76efd8 100755 --- a/tests/test_trtllm_gen_attention.py +++ b/tests/test_trtllm_gen_attention.py @@ -366,8 +366,8 @@ def 
test_trtllm_batch_prefill( seq_lens.to(GPU_DEVICE), torch.max(q_lens).item(), torch.max(seq_lens).item(), - bmm1_scale, - bmm2_scale, + 0.0, # should be bmm1_scale, disabled to use dynamic in-memory scale factors, but fp16 and bf16 failed to run + 0.0, # should be bmm2_scale, disabled to use dynamic in-memory scale factors, but fp16 and bf16 failed to run batch_size, q_indptr, kv_indptr, @@ -585,8 +585,8 @@ def test_trtllm_batch_decode( page_table, seq_lens.to(GPU_DEVICE), torch.max(seq_lens).item(), - bmm1_scale, # bmm1_scale - bmm2_scale, # bmm2_scale + 0.0, # should be bmm1_scale, disabled to use dynamic in-memory scale factors, but fp16 and bf16 failed to run + 0.0, # should be bmm2_scale, disabled to use dynamic in-memory scale factors, but fp16 and bf16 failed to run window_left, # window_left out=out, out_dtype=out_dtype,