83 changes: 58 additions & 25 deletions csrc/trtllm_fmha_kernel_launcher.cu
@@ -79,9 +79,10 @@ void trtllm_paged_attention_launcher(
int64_t max_kv_len, int64_t num_pages_in_mem_pool, int64_t num_qo_heads, int64_t num_kv_heads,
int64_t head_dim_qk, int64_t head_dim_vo, int64_t page_size, int64_t kv_stride_keys_values,
int64_t kv_stride_heads, int64_t kv_stride_batch, int64_t max_num_blocks_per_seq,
double bmm1_scale, double bmm2_scale, double o_sf_scale, int64_t o_sf_vec_size,
int64_t o_sf_start_index, int64_t window_left, int64_t sum_seq_q, int64_t sm_count,
bool enable_pdl, int64_t workspace_size, cudaStream_t stream) {
double bmm1_scale, double bmm2_scale, float* bmm1_scale_log2_ptr, float* bmm2_scale_ptr,
double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t window_left,
int64_t sum_seq_q, int64_t sm_count, bool enable_pdl, int64_t workspace_size,
cudaStream_t stream) {
if (num_qo_heads % num_kv_heads != 0) {
std::ostringstream err_msg;
err_msg << "num_qo_heads must be a multiple of num_kv_heads, got num_kv_heads: " << num_kv_heads
@@ -120,6 +121,8 @@ void trtllm_paged_attention_launcher(
runner_params.stream = stream;
runner_params.outputScale = bmm2_scale;
runner_params.scaleSoftmaxLog2 = bmm1_scale * M_LOG2E;
runner_params.outputScalePtr = bmm2_scale_ptr;
runner_params.scaleSoftmaxLog2Ptr = bmm1_scale_log2_ptr;
runner_params.oSfPtr = out_scale_factor;
runner_params.mSfStartTokenIdx = o_sf_start_index;
runner_params.mScaleSfO = o_sf_scale;
@@ -202,7 +205,9 @@ void trtllm_paged_attention_decode(at::Tensor out, std::optional<at::Tensor> out
double bmm2_scale, double o_sf_scale, int64_t o_sf_vec_size,
int64_t o_sf_start_index, int64_t window_left, int64_t sm_count,
bool enable_pdl, int64_t workspace_size,
std::optional<at::Tensor> attention_sinks) {
std::optional<at::Tensor> attention_sinks,
std::optional<at::Tensor> bmm1_scale_log2_tensor,
std::optional<at::Tensor> bmm2_scale_tensor) {
auto q_data_type = torch_dtype_to_tllm_data_type(query.scalar_type());
auto kv_data_type = torch_dtype_to_tllm_data_type(key_cache.scalar_type());
TORCH_CHECK_EQ(key_cache.dim(), value_cache.dim());
@@ -249,6 +254,14 @@ void trtllm_paged_attention_decode(at::Tensor out, std::optional<at::Tensor> out
"attention_sinks must be a float tensor");
attention_sinks_ptr = attention_sinks->data_ptr<float>();
}
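  // Optional on-device scale factors; when provided, the kernel uses these instead of the scalar bmm1/bmm2 scales.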
float* bmm1_scale_log2_ptr = nullptr;
float* bmm2_scale_ptr = nullptr;
if (bmm1_scale_log2_tensor.has_value()) {
bmm1_scale_log2_ptr = static_cast<float*>(bmm1_scale_log2_tensor.value().data_ptr());
}
if (bmm2_scale_tensor.has_value()) {
bmm2_scale_ptr = static_cast<float*>(bmm2_scale_tensor.value().data_ptr());
}

trtllm_paged_attention_launcher(
out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(),
@@ -259,20 +272,19 @@ void trtllm_paged_attention_decode(at::Tensor out, std::optional<at::Tensor> ou
TllmPagedAttentionMode::ForGen, batch_size, /*max_q_len=*/q_len_per_request, max_kv_len,
num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size,
kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale,
bmm2_scale, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sm_count,
enable_pdl, workspace_size, stream);
bmm2_scale, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index,
window_left, sum_seq_q, sm_count, enable_pdl, workspace_size, stream);
}

void trtllm_paged_attention_context(at::Tensor out, std::optional<at::Tensor> out_scale_factor,
at::Tensor query, at::Tensor key_cache, at::Tensor value_cache,
at::Tensor workspace_buffer, at::Tensor block_tables,
at::Tensor seq_lens, int64_t max_q_len, int64_t max_kv_len,
double bmm1_scale, double bmm2_scale, double o_sf_scale,
int64_t o_sf_vec_size, int64_t o_sf_start_index,
int64_t batch_size, int64_t window_left,
at::Tensor cum_seq_lens_q, at::Tensor cum_seq_lens_kv,
int64_t sm_count, bool enable_pdl, int64_t workspace_size,
std::optional<at::Tensor> attention_sinks) {
void trtllm_paged_attention_context(
at::Tensor out, std::optional<at::Tensor> out_scale_factor, at::Tensor query,
at::Tensor key_cache, at::Tensor value_cache, at::Tensor workspace_buffer,
at::Tensor block_tables, at::Tensor seq_lens, int64_t max_q_len, int64_t max_kv_len,
double bmm1_scale, double bmm2_scale, double o_sf_scale, int64_t o_sf_vec_size,
int64_t o_sf_start_index, int64_t batch_size, int64_t window_left, at::Tensor cum_seq_lens_q,
at::Tensor cum_seq_lens_kv, int64_t sm_count, bool enable_pdl, int64_t workspace_size,
std::optional<at::Tensor> attention_sinks, std::optional<at::Tensor> bmm1_scale_log2_tensor,
std::optional<at::Tensor> bmm2_scale_tensor) {
auto q_data_type = torch_dtype_to_tllm_data_type(query.scalar_type());
auto kv_data_type = torch_dtype_to_tllm_data_type(key_cache.scalar_type());
auto o_data_type = torch_dtype_to_tllm_data_type(out.scalar_type());
@@ -309,7 +321,14 @@ void trtllm_paged_attention_context(at::Tensor out, std::optional<at::Tensor> ou
"attention_sinks must be a float tensor");
attention_sinks_ptr = attention_sinks->data_ptr<float>();
}

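  // Optional on-device scale factors; when provided, the kernel uses these instead of the scalar bmm1/bmm2 scales.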
float* bmm1_scale_log2_ptr = nullptr;
float* bmm2_scale_ptr = nullptr;
if (bmm1_scale_log2_tensor.has_value()) {
bmm1_scale_log2_ptr = static_cast<float*>(bmm1_scale_log2_tensor.value().data_ptr());
}
if (bmm2_scale_tensor.has_value()) {
bmm2_scale_ptr = static_cast<float*>(bmm2_scale_tensor.value().data_ptr());
}
trtllm_paged_attention_launcher(
out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(),
workspace_buffer.data_ptr(), static_cast<int*>(block_tables.data_ptr()),
@@ -319,8 +338,9 @@ void trtllm_paged_attention_context(at::Tensor out, std::optional<at::Tensor> ou
q_data_type, kv_data_type, o_data_type, TllmPagedAttentionMode::Context, batch_size,
max_q_len, max_kv_len, num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q,
head_dim_o, page_size, kv_stride_keys_values, kv_stride_heads, kv_stride_batch,
max_num_blocks_per_seq, bmm1_scale, bmm2_scale, o_sf_scale, o_sf_vec_size, o_sf_start_index,
window_left, sum_seq_q, sm_count, enable_pdl, workspace_size, stream);
max_num_blocks_per_seq, bmm1_scale, bmm2_scale, bmm1_scale_log2_ptr, bmm2_scale_ptr,
o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sm_count, enable_pdl,
workspace_size, stream);
}

void trtllm_ragged_attention_launcher(
@@ -329,8 +349,9 @@ void trtllm_ragged_attention_launcher(
Data_type q_data_type, Data_type kv_data_type, Data_type o_data_type, int64_t max_q_len,
int64_t max_kv_len, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim_qk,
int64_t head_dim_v, int64_t sum_seq_q, int64_t sum_seq_kv, double bmm1_scale, double bmm2_scale,
double o_sf_scale, int64_t batch_size, int64_t window_left, int64_t sm_count, bool enable_pdl,
bool is_causal, int64_t k_stride_keys_values, int64_t k_stride_heads, int64_t k_stride_batch,
float* bmm1_scale_log2_ptr, float* bmm2_scale_ptr, double o_sf_scale, int64_t batch_size,
int64_t window_left, int64_t sm_count, bool enable_pdl, bool is_causal,
int64_t k_stride_keys_values, int64_t k_stride_heads, int64_t k_stride_batch,
int64_t v_stride_keys_values, int64_t v_stride_heads, int64_t v_stride_batch,
int64_t workspace_size, cudaStream_t stream) {
if (num_qo_heads % num_kv_heads != 0) {
@@ -360,6 +381,8 @@ void trtllm_ragged_attention_launcher(
runner_params.stream = stream;
runner_params.outputScale = bmm2_scale;
runner_params.scaleSoftmaxLog2 = bmm1_scale * M_LOG2E;
runner_params.outputScalePtr = bmm2_scale_ptr;
runner_params.scaleSoftmaxLog2Ptr = bmm1_scale_log2_ptr;
runner_params.mScaleSfO = o_sf_scale;
runner_params.mChunkedAttentionSize = INT_MAX; // disable chunked attention by INT_MAX
runner_params.mAttentionWindowSize =
@@ -417,7 +440,9 @@ void trtllm_ragged_attention(at::Tensor out, at::Tensor query, at::Tensor key, a
at::Tensor cum_seq_lens_q, at::Tensor cum_seq_lens_kv,
int64_t sm_count, bool enable_pdl, bool is_causal,
int64_t workspace_size, std::optional<at::Tensor> attention_sinks,
std::optional<at::Tensor> lse) {
std::optional<at::Tensor> lse,
std::optional<at::Tensor> bmm1_scale_log2_tensor,
std::optional<at::Tensor> bmm2_scale_tensor) {
float* attention_sinks_ptr = nullptr;
if (attention_sinks) {
TORCH_CHECK(attention_sinks->scalar_type() == at::ScalarType::Float,
@@ -429,6 +454,14 @@ void trtllm_ragged_attention(at::Tensor out, at::Tensor query, at::Tensor key, a
TORCH_CHECK(lse->scalar_type() == at::ScalarType::Float, "lse must be a float tensor");
lse_ptr = lse->data_ptr<float>();
}
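  // Optional on-device scale factors; when provided, the kernel uses these instead of the scalar bmm1/bmm2 scales.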
float* bmm1_scale_log2_ptr = nullptr;
float* bmm2_scale_ptr = nullptr;
if (bmm1_scale_log2_tensor.has_value()) {
bmm1_scale_log2_ptr = static_cast<float*>(bmm1_scale_log2_tensor.value().data_ptr());
}
if (bmm2_scale_tensor.has_value()) {
bmm2_scale_ptr = static_cast<float*>(bmm2_scale_tensor.value().data_ptr());
}
TORCH_CHECK(out.dim() == 3, "out must be a 3D tensor");
TORCH_CHECK(query.dim() == 3, "query must be a 3D tensor");
TORCH_CHECK(key.dim() == 3, "key must be a 3D tensor");
@@ -458,9 +491,9 @@ void trtllm_ragged_attention(at::Tensor out, at::Tensor query, at::Tensor key, a
static_cast<int*>(cum_seq_lens_q.data_ptr()), static_cast<int*>(cum_seq_lens_kv.data_ptr()),
attention_sinks_ptr, lse_ptr, q_data_type, kv_data_type, o_data_type, max_q_len, max_kv_len,
num_qo_heads, num_kv_heads, head_dim_qk, head_dim_v, sum_seq_q, sum_seq_kv, bmm1_scale,
bmm2_scale, o_sf_scale, batch_size, window_left, sm_count, enable_pdl, is_causal,
k_stride_keys_values, k_stride_heads, k_stride_batch, v_stride_keys_values, v_stride_heads,
v_stride_batch, workspace_size, stream);
bmm2_scale, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, batch_size, window_left,
sm_count, enable_pdl, is_causal, k_stride_keys_values, k_stride_heads, k_stride_batch,
v_stride_keys_values, v_stride_heads, v_stride_batch, workspace_size, stream);
}

namespace trtllm_cubin_loader {
43 changes: 37 additions & 6 deletions flashinfer/decode.py
@@ -1147,6 +1147,8 @@ def run(
window_left: Optional[int] = None,
sinks: Optional[torch.Tensor] = None,
q_len_per_req: Optional[int] = 1,
bmm1_scale_log2_tensor: Optional[torch.Tensor] = None,
bmm2_scale_tensor: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
r"""Compute batch decode attention between query and paged kv cache.

@@ -1186,6 +1188,10 @@ def run(
Only supported for >= sm90, and currently only for FA2 and CUDA core decode.
q_len_per_req : int
The number of query tokens per request, if not provided, will be set to ``1``.
bmm1_scale_log2_tensor : Optional[torch.Tensor]
The on-device fused scale tensor for the bmm1 input. Must already be multiplied by M_LOG2E before being passed in. If provided, the scalar ``bmm1_scale`` is ignored.
bmm2_scale_tensor : Optional[torch.Tensor]
The on-device fused scale tensor for the bmm2 input. If provided, the scalar ``bmm2_scale`` is ignored.
Returns
-------
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
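A minimal sketch of how a caller might build these on-device scale tensors before calling ``run``. The single-element float32 layout, the helper name, and the example scale values are assumptions for illustration; the only documented requirement is that the bmm1 tensor already carries the M_LOG2E factor.

```python
import math

import torch

M_LOG2E = math.log2(math.e)


def make_dynamic_scales(bmm1_scale: float, bmm2_scale: float, device: str = "cuda"):
    # Hypothetical helper: pack the scalar scales into on-device tensors so the
    # kernel can read them at launch time instead of baking them into launch args.
    bmm1_scale_log2_tensor = torch.tensor(
        [bmm1_scale * M_LOG2E], dtype=torch.float32, device=device
    )
    bmm2_scale_tensor = torch.tensor([bmm2_scale], dtype=torch.float32, device=device)
    return bmm1_scale_log2_tensor, bmm2_scale_tensor


# Example: 0.08838 ~ 1/sqrt(128); pass the returned tensors to run(), after which the
# scalar bmm1_scale / bmm2_scale arguments are ignored.
bmm1_t, bmm2_t = make_dynamic_scales(bmm1_scale=0.08838, bmm2_scale=1.0)
```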
@@ -1296,6 +1302,8 @@ def run(
page_size,
self._max_kv_len,
sinks,
bmm1_scale_log2_tensor,
bmm2_scale_tensor,
]

self._cached_module.paged_run(*run_args)
@@ -1832,6 +1840,8 @@ def _paged_run(
enable_pdl: bool = None,
out: Optional[torch.Tensor] = None,
sinks: Optional[torch.Tensor] = None,
bmm1_scale_log2_tensor: Optional[torch.Tensor] = None,
bmm2_scale_tensor: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if out is None:
out = torch.empty_like(query)
@@ -1858,6 +1868,8 @@ def _paged_run(
enable_pdl,
workspace_size,
sinks,
bmm1_scale_log2_tensor,
bmm2_scale_tensor,
)
return out

@@ -1919,6 +1931,8 @@ def paged_run(
page_size: Optional[int] = None,
max_kv_len: Optional[int] = None,
sinks: Optional[torch.Tensor] = None,
bmm1_scale_log2_tensor: Optional[torch.Tensor] = None,
bmm2_scale_tensor: Optional[torch.Tensor] = None,
) -> None:
assert maybe_lse is None
assert paged_kv_cache is not None
@@ -1944,6 +1958,8 @@ def paged_run(
enable_pdl,
out=o,
sinks=sinks,
bmm1_scale_log2_tensor=bmm1_scale_log2_tensor,
bmm2_scale_tensor=bmm2_scale_tensor,
)

@register_fake_op(f"flashinfer::{uri}_paged_run")
@@ -1983,6 +1999,8 @@ def _fake_paged_run(
page_size: Optional[int] = None,
max_kv_len: Optional[int] = None,
sinks: Optional[torch.Tensor] = None,
bmm1_scale_log2_tensor: Optional[torch.Tensor] = None,
bmm2_scale_tensor: Optional[torch.Tensor] = None,
) -> None:
pass

@@ -2013,6 +2031,8 @@ def trtllm_batch_decode_with_kv_cache(
sinks: Optional[List[torch.Tensor]] = None,
enable_pdl: bool = None,
q_len_per_req: Optional[int] = 1,
bmm1_scale_log2_tensor: Optional[torch.Tensor] = None,
bmm2_scale_tensor: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, FP4Tensor]:
"""
Parameters
Expand All @@ -2037,10 +2057,10 @@ def trtllm_batch_decode_with_kv_cache(
max sequence length for kv_cache

bmm1_scale : float
fused scale for bmm1 input.
fused scale for bmm1 input. We recommend using ``bmm1_scale_log2_tensor`` instead for better CUDA graph support.

bmm2_scale : float
fused scale for bmm2 input.
fused scale for bmm2 input. We recommend using ``bmm2_scale_tensor`` instead for better CUDA graph support.

window_left : int = -1
The left (inclusive) window size for the attention window, when set to ``-1``, the window
@@ -2065,6 +2085,12 @@
Whether to enable Programmatic Dependent Launch (PDL). See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization
Only supported for >= sm90, and currently only for FA2, CUDA core, and trtllm-gen decode.

bmm1_scale_log2_tensor : Optional[torch.Tensor]
The on-device fused scale tensor for the bmm1 input. Must already be multiplied by M_LOG2E before being passed in. If provided, the scalar ``bmm1_scale`` is ignored.

bmm2_scale_tensor : Optional[torch.Tensor]
The on-device fused scale tensor for the bmm2 input. If provided, the scalar ``bmm2_scale`` is ignored.

Returns
-------
out : Union[torch.Tensor, FP4Tensor]
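A sketch of the CUDA-graph usage this change enables. It assumes ``query``, ``kv_cache``, ``workspace_buffer``, ``block_tables``, ``seq_lens``, and ``max_seq_len`` are prepared elsewhere and that the keyword names match the signature in this diff; since the scales are read from device memory, their values can be refreshed in place between replays without re-capturing the graph.

```python
import math

import torch
from flashinfer import trtllm_batch_decode_with_kv_cache

M_LOG2E = math.log2(math.e)

# On-device scales captured by the graph; only their *contents* change later.
bmm1_scale_log2_tensor = torch.tensor([0.1 * M_LOG2E], dtype=torch.float32, device="cuda")
bmm2_scale_tensor = torch.tensor([1.0], dtype=torch.float32, device="cuda")

# query, kv_cache, workspace_buffer, block_tables, seq_lens, max_seq_len: prepared elsewhere.
# (Warm-up iterations before capture are omitted for brevity.)
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    out = trtllm_batch_decode_with_kv_cache(
        query=query,
        kv_cache=kv_cache,
        workspace_buffer=workspace_buffer,
        block_tables=block_tables,
        seq_lens=seq_lens,
        max_seq_len=max_seq_len,
        bmm1_scale=0.1,  # ignored: the tensors below take precedence
        bmm2_scale=1.0,
        bmm1_scale_log2_tensor=bmm1_scale_log2_tensor,
        bmm2_scale_tensor=bmm2_scale_tensor,
    )

# New scales for the next step: update in place, then replay the captured graph.
bmm1_scale_log2_tensor.fill_(0.2 * M_LOG2E)
bmm2_scale_tensor.fill_(0.5)
graph.replay()
```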
@@ -2182,6 +2208,8 @@ def trtllm_batch_decode_with_kv_cache(
enable_pdl,
workspace_buffer.numel() * workspace_buffer.element_size(),
sinks,
bmm1_scale_log2_tensor,
bmm2_scale_tensor,
)

return (
@@ -2263,10 +2291,10 @@ def trtllm_batch_decode_with_kv_cache_mla(
seq_lens: query_len
max_seq_len: max sequence length for kv_cache
out: output tensor, if not provided, will be allocated internally
bmm1_scale: fused scale for mla bmm1 input.
bmm2_scale: fused scale for mla bmm2 input.
bmm1_scale_log2_tensor: On-device fused scale tensor for mla bmm1 input. Must be fused with * M_LOG2E before passing in.
bmm2_scale_tensor: On-device fused scale tensor for mla bmm2 input.
bmm1_scale: fused scale for mla bmm1 input. We recommend using bmm1_scale_log2_tensor for better CUDA graph support.
bmm2_scale: fused scale for mla bmm2 input. We recommend using bmm2_scale_tensor for better CUDA graph support.
bmm1_scale_log2_tensor: On-device fused scale tensor for mla bmm1 input. Must already be multiplied by M_LOG2E before being passed in. If provided, the scalar bmm1_scale is ignored.
bmm2_scale_tensor: On-device fused scale tensor for mla bmm2 input. If provided, the scalar bmm2_scale is ignored.
sinks: additional value per head in the denominator of the softmax.

Note:
@@ -2320,6 +2348,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
"out",
)

# todo(Yingyi): check support for dynamic scale factors
if bmm1_scale_log2_tensor is not None and bmm2_scale_tensor is not None:
# dynamic scale factors
if query.dtype != torch.float8_e4m3fn or kv_cache.dtype != torch.float8_e4m3fn:
@@ -2347,5 +2376,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
enable_pdl,
workspace_buffer.numel() * workspace_buffer.element_size(),
sinks,
bmm1_scale_log2_tensor,
bmm2_scale_tensor,
)
return out
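For the fp8 MLA path above, one possible way to compose the fused values these arguments carry is sketched below. The dequant factors, the head-dim value, and the exact fusion recipe are illustrative assumptions, not something this diff specifies; the real recipe depends on how q/k/v were quantized.

```python
import math

# Placeholder fp8 dequant scales and softmax scale for the MLA head dim (example value).
q_scale, k_scale, v_scale = 1.0, 1.0, 1.0
head_dim = 192  # placeholder: qk_nope_head_dim + qk_rope_head_dim
softmax_scale = 1.0 / math.sqrt(head_dim)

bmm1_scale = q_scale * k_scale * softmax_scale    # scalar bmm1_scale argument
bmm1_scale_log2 = bmm1_scale * math.log2(math.e)  # value stored in bmm1_scale_log2_tensor
bmm2_scale = v_scale                              # scalar bmm2_scale / value in bmm2_scale_tensor
```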