From 34b39e837e63498f5b9f133411659d176a0f07f7 Mon Sep 17 00:00:00 2001 From: "augusto.yjh" Date: Fri, 28 Nov 2025 10:55:17 +0800 Subject: [PATCH 1/8] fix merge conflicts Signed-off-by: augusto.yjh --- vllm/attention/ops/common.py | 99 ++++++++++++++++++++++-- vllm/v1/attention/backends/flashinfer.py | 4 +- 2 files changed, 95 insertions(+), 8 deletions(-) diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py index af6766bdd161..d61ec7c76d9d 100644 --- a/vllm/attention/ops/common.py +++ b/vllm/attention/ops/common.py @@ -88,6 +88,88 @@ def _correct_attn_cp_out_kernel( tl.store(new_output_ptr + output_offsets, output) +@triton.jit +def _correct_attn_cp_out_kernel_for_flashinfer( + outputs_ptr, + new_output_ptr, + lses_ptr, + vlse_ptr, + outputs_stride_B, + outputs_stride_H, + outputs_stride_D, + lses_stride_N, + lses_stride_B, + lses_stride_H, + lse_idx, + HEAD_DIM: tl.constexpr, + N_ROUNDED: tl.constexpr, +): + """ + Apply the all-gathered lses to correct each local rank's attention + output. we still need perform a cross-rank reduction to obtain the + final attention output. + + Args: + outputs_ptr (triton.PointerType): + Pointer to input tensor of shape [ B, H, D ] + lses_ptr (triton.PointerType): + Pointer to input tensor of shape [ N, B, H ] + new_output_ptr (triton.PointerType): + Pointer to output tensor of shape [ B, H, D ] + vlse_ptr (triton.PointerType): + Pointer to output tensor of shape [ B, H ] + """ + batch_idx = tl.program_id(axis=0).to(tl.int64) + head_idx = tl.program_id(axis=1).to(tl.int64) + d_offsets = tl.arange(0, HEAD_DIM) + num_n_offsets = tl.arange(0, N_ROUNDED) + + # shape = [N] + lse_offsets = ( + num_n_offsets * lses_stride_N + + batch_idx * lses_stride_B + + head_idx * lses_stride_H + ) + + # calc final lse + lse = tl.load(lses_ptr + lse_offsets) + lse = tl.where((lse != lse) | (lse == float("inf")), -float("inf"), lse) + lse_max = tl.max(lse, axis=0) + lse_max = tl.where(lse_max == -float("inf"), 0, lse_max) + lse -= lse_max + lse_exp = tl.exp2(lse) + lse_acc = tl.sum(lse_exp, axis=0) + lse = tl.log2(lse_acc) + lse += lse_max + + lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H + tl.store(vlse_ptr + lse_offsets, lse) + + # shape = [D] + output_offsets = ( + batch_idx * outputs_stride_B + + head_idx * outputs_stride_H + + d_offsets * outputs_stride_D + ) + + # correct output + lse_offset = ( + lse_idx * lses_stride_N + batch_idx * lses_stride_B + head_idx * lses_stride_H + ) + lse_tmp = tl.load(lses_ptr + lse_offset) + lse_finally = lse_tmp - lse + lse_finally = tl.where( + (lse_finally != lse_finally) | (lse_finally == float("inf")), + -float("inf"), + lse_finally, + ) + factor = tl.exp2(lse_finally) + output = tl.load(outputs_ptr + output_offsets) + output = output * factor + + tl.store(new_output_ptr + output_offsets, output) + + class CPTritonContext: """The CPTritonContext is used to avoid recompilation of the Triton JIT.""" @@ -102,7 +184,7 @@ def call_kernel(self, kernel, grid, *regular_args, **const_args): def correct_attn_out( - out: torch.Tensor, lses: torch.Tensor, cp_rank: int, ctx: CPTritonContext + out: torch.Tensor, lses: torch.Tensor, cp_rank: int, ctx: CPTritonContext, is_lse_base_on_e: bool=True, ) -> tuple[torch.Tensor, torch.Tensor]: """Correct the attention output using the all-gathered lses. 
@@ -164,8 +246,10 @@ def correct_attn_out( cp_rank, ) const_args = {"HEAD_DIM": D, "N_ROUNDED": N} - - ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args, **const_args) + correct_attn_kernel = _correct_attn_cp_out_kernel + if not is_lse_base_on_e: + correct_attn_kernel = _correct_attn_cp_out_kernel_for_flashinfer + ctx.call_kernel(correct_attn_kernel, grid, *regular_args, **const_args) return out, lse @@ -174,6 +258,7 @@ def _cp_lse_common( cp_attn_lse: torch.Tensor, cp_group: GroupCoordinator, ctx: CPTritonContext | None = None, + is_lse_base_on_e=True, ): """ cp_attn_out: [ B, H, D ] @@ -193,7 +278,7 @@ def _cp_lse_common( cp_attn_lse = cp_attn_lse.contiguous() lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses) - out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx) + out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx, is_lse_base_on_e=is_lse_base_on_e) return out, lse @@ -203,12 +288,13 @@ def cp_lse_ag_out_rs( cp_group: GroupCoordinator, ctx: CPTritonContext | None = None, return_lse: bool = False, + is_lse_base_on_e=True, ): """ cp_attn_out: [ B, H, D ] cp_attn_lse: [ B, H ] """ - out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx) + out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx, is_lse_base_on_e=is_lse_base_on_e) out = cp_group.reduce_scatter(out, dim=1) if return_lse: @@ -225,12 +311,13 @@ def cp_lse_ag_out_ar( cp_group: GroupCoordinator, ctx: CPTritonContext | None = None, return_lse: bool = False, + is_lse_base_on_e=True, ): """ cp_attn_out: [ B, H, D ] cp_attn_lse: [ B, H ] """ - out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx) + out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx, is_lse_base_on_e=is_lse_base_on_e) out = cp_group.all_reduce(out) if return_lse: diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 777398bf8a20..86c1e8d1f0f2 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -249,7 +249,7 @@ def run( return_lse=True, ) output_context, lse_context = cp_lse_ag_out_rs( - output_context_tmp, lse_context_tmp, get_dcp_group(), return_lse=True + output_context_tmp, lse_context_tmp, get_dcp_group(), return_lse=True, is_lse_base_on_e=False, ) lse_context = lse_context.transpose(0, 1).contiguous() @@ -1335,7 +1335,7 @@ def forward( return_lse=True, ) output[:num_decode_tokens] = cp_lse_ag_out_rs( - output_tmp, lse, get_dcp_group() + output_tmp, lse, get_dcp_group(), is_lse_base_on_e=False, ) else: decode_wrapper.run( From e8adf8d13c05af9f935bf1d4564fdd68c96e6726 Mon Sep 17 00:00:00 2001 From: "augusto.yjh" Date: Mon, 17 Nov 2025 15:21:19 +0800 Subject: [PATCH 2/8] add constexpr for triton kernel instead of write a new kernel Signed-off-by: augusto.yjh --- vllm/attention/ops/common.py | 104 +++++------------------------------ 1 file changed, 14 insertions(+), 90 deletions(-) diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py index d61ec7c76d9d..789c2614c86a 100644 --- a/vllm/attention/ops/common.py +++ b/vllm/attention/ops/common.py @@ -21,6 +21,7 @@ def _correct_attn_cp_out_kernel( lse_idx, HEAD_DIM: tl.constexpr, N_ROUNDED: tl.constexpr, + IS_BASE_E: tl.constexpr, ): """ Apply the all-gathered lses to correct each local rank's attention @@ -55,91 +56,14 @@ def _correct_attn_cp_out_kernel( lse_max = tl.max(lse, axis=0) lse_max = tl.where(lse_max == -float("inf"), 0, lse_max) lse -= lse_max - lse_exp = 
tl.exp(lse) - lse_acc = tl.sum(lse_exp, axis=0) - lse = tl.log(lse_acc) - lse += lse_max - - lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H - tl.store(vlse_ptr + lse_offsets, lse) - - # shape = [D] - output_offsets = ( - batch_idx * outputs_stride_B - + head_idx * outputs_stride_H - + d_offsets * outputs_stride_D - ) - - # correct output - lse_offset = ( - lse_idx * lses_stride_N + batch_idx * lses_stride_B + head_idx * lses_stride_H - ) - lse_tmp = tl.load(lses_ptr + lse_offset) - lse_finally = lse_tmp - lse - lse_finally = tl.where( - (lse_finally != lse_finally) | (lse_finally == float("inf")), - -float("inf"), - lse_finally, - ) - factor = tl.exp(lse_finally) - output = tl.load(outputs_ptr + output_offsets) - output = output * factor - - tl.store(new_output_ptr + output_offsets, output) - - -@triton.jit -def _correct_attn_cp_out_kernel_for_flashinfer( - outputs_ptr, - new_output_ptr, - lses_ptr, - vlse_ptr, - outputs_stride_B, - outputs_stride_H, - outputs_stride_D, - lses_stride_N, - lses_stride_B, - lses_stride_H, - lse_idx, - HEAD_DIM: tl.constexpr, - N_ROUNDED: tl.constexpr, -): - """ - Apply the all-gathered lses to correct each local rank's attention - output. we still need perform a cross-rank reduction to obtain the - final attention output. - - Args: - outputs_ptr (triton.PointerType): - Pointer to input tensor of shape [ B, H, D ] - lses_ptr (triton.PointerType): - Pointer to input tensor of shape [ N, B, H ] - new_output_ptr (triton.PointerType): - Pointer to output tensor of shape [ B, H, D ] - vlse_ptr (triton.PointerType): - Pointer to output tensor of shape [ B, H ] - """ - batch_idx = tl.program_id(axis=0).to(tl.int64) - head_idx = tl.program_id(axis=1).to(tl.int64) - d_offsets = tl.arange(0, HEAD_DIM) - num_n_offsets = tl.arange(0, N_ROUNDED) - - # shape = [N] - lse_offsets = ( - num_n_offsets * lses_stride_N - + batch_idx * lses_stride_B - + head_idx * lses_stride_H - ) - - # calc final lse - lse = tl.load(lses_ptr + lse_offsets) - lse = tl.where((lse != lse) | (lse == float("inf")), -float("inf"), lse) - lse_max = tl.max(lse, axis=0) - lse_max = tl.where(lse_max == -float("inf"), 0, lse_max) - lse -= lse_max - lse_exp = tl.exp2(lse) - lse_acc = tl.sum(lse_exp, axis=0) - lse = tl.log2(lse_acc) + if IS_BASE_E: + lse_exp = tl.exp(lse) + lse_acc = tl.sum(lse_exp, axis=0) + lse = tl.log(lse_acc) + else: + lse_exp = tl.exp2(lse) + lse_acc = tl.sum(lse_exp, axis=0) + lse = tl.log2(lse_acc) lse += lse_max lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H @@ -163,7 +87,10 @@ def _correct_attn_cp_out_kernel_for_flashinfer( -float("inf"), lse_finally, ) - factor = tl.exp2(lse_finally) + if IS_BASE_E: + factor = tl.exp(lse_finally) + else: + factor = tl.exp2(lse_finally) output = tl.load(outputs_ptr + output_offsets) output = output * factor @@ -245,10 +172,7 @@ def correct_attn_out( l_sH, cp_rank, ) - const_args = {"HEAD_DIM": D, "N_ROUNDED": N} - correct_attn_kernel = _correct_attn_cp_out_kernel - if not is_lse_base_on_e: - correct_attn_kernel = _correct_attn_cp_out_kernel_for_flashinfer + const_args = {"HEAD_DIM": D, "N_ROUNDED": N, "IS_BASE_E": is_lse_base_on_e} ctx.call_kernel(correct_attn_kernel, grid, *regular_args, **const_args) return out, lse From abf6be57eabd7015e97a1e49185b7f2927e97b53 Mon Sep 17 00:00:00 2001 From: "augusto.yjh" Date: Mon, 17 Nov 2025 15:58:54 +0800 Subject: [PATCH 3/8] format code and pass lse base for flashinfer chunked prefill Signed-off-by: augusto.yjh --- vllm/attention/ops/common.py | 14 +++++++++++++- 
vllm/v1/attention/backends/flashinfer.py | 11 +++++++++-- vllm/v1/attention/backends/mla/common.py | 7 ++++++- 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py index 789c2614c86a..ae059a2df435 100644 --- a/vllm/attention/ops/common.py +++ b/vllm/attention/ops/common.py @@ -111,7 +111,11 @@ def call_kernel(self, kernel, grid, *regular_args, **const_args): def correct_attn_out( - out: torch.Tensor, lses: torch.Tensor, cp_rank: int, ctx: CPTritonContext, is_lse_base_on_e: bool=True, + out: torch.Tensor, + lses: torch.Tensor, + cp_rank: int, + ctx: CPTritonContext, + is_lse_base_on_e: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """Correct the attention output using the all-gathered lses. @@ -202,6 +206,7 @@ def _cp_lse_common( cp_attn_lse = cp_attn_lse.contiguous() lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses) +<<<<<<< HEAD out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx, is_lse_base_on_e=is_lse_base_on_e) return out, lse @@ -219,6 +224,13 @@ def cp_lse_ag_out_rs( cp_attn_lse: [ B, H ] """ out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx, is_lse_base_on_e=is_lse_base_on_e) + out, lse = correct_attn_out( + cp_attn_out, + lses, + cp_group.rank_in_group, + ctx, + is_lse_base_on_e=is_lse_base_on_e, + ) out = cp_group.reduce_scatter(out, dim=1) if return_lse: diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 86c1e8d1f0f2..69a6a5e5fae8 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -249,7 +249,11 @@ def run( return_lse=True, ) output_context, lse_context = cp_lse_ag_out_rs( - output_context_tmp, lse_context_tmp, get_dcp_group(), return_lse=True, is_lse_base_on_e=False, + output_context_tmp, + lse_context_tmp, + get_dcp_group(), + return_lse=True, + is_lse_base_on_e=False, ) lse_context = lse_context.transpose(0, 1).contiguous() @@ -1335,7 +1339,10 @@ def forward( return_lse=True, ) output[:num_decode_tokens] = cp_lse_ag_out_rs( - output_tmp, lse, get_dcp_group(), is_lse_base_on_e=False, + output_tmp, + lse, + get_dcp_group(), + is_lse_base_on_e=False, ) else: decode_wrapper.run( diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index d94ed9183f63..7944021c3fbb 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -2057,7 +2057,12 @@ def forward( # correct dcp attn_out with lse. 
if self.dcp_world_size > 1: - attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group()) + attn_out = cp_lse_ag_out_rs( + attn_out, + lse, + get_dcp_group(), + is_lse_base_on_e: not self._use_fi_prefill, + ) # v_up projection self._v_up_proj(attn_out, out=output[:num_decode_tokens]) From 015e6344b08dce5db970c46b1f782118271a3176 Mon Sep 17 00:00:00 2001 From: "augusto.yjh" Date: Mon, 17 Nov 2025 16:04:13 +0800 Subject: [PATCH 4/8] fix syntax error Signed-off-by: augusto.yjh --- vllm/v1/attention/backends/mla/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 7944021c3fbb..b09541dbf791 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -2061,7 +2061,7 @@ def forward( attn_out, lse, get_dcp_group(), - is_lse_base_on_e: not self._use_fi_prefill, + is_lse_base_on_e=not self._use_fi_prefill, ) # v_up projection From 578e9c526dc7a97f63bfb6d26c8d57fc91ce7f5f Mon Sep 17 00:00:00 2001 From: "augusto.yjh" Date: Mon, 17 Nov 2025 16:38:27 +0800 Subject: [PATCH 5/8] recover triton kernel name in ctx.call_kernel Signed-off-by: augusto.yjh --- vllm/attention/ops/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py index ae059a2df435..100a67cba0aa 100644 --- a/vllm/attention/ops/common.py +++ b/vllm/attention/ops/common.py @@ -177,7 +177,7 @@ def correct_attn_out( cp_rank, ) const_args = {"HEAD_DIM": D, "N_ROUNDED": N, "IS_BASE_E": is_lse_base_on_e} - ctx.call_kernel(correct_attn_kernel, grid, *regular_args, **const_args) + ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args, **const_args) return out, lse From c2d13738a2faa4afef67c517c967d3afecbe5e5f Mon Sep 17 00:00:00 2001 From: "augusto.yjh" Date: Mon, 17 Nov 2025 16:39:34 +0800 Subject: [PATCH 6/8] simplify code in triton kernel _correct_attn_cp_out_kernel Signed-off-by: augusto.yjh --- vllm/attention/ops/common.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py index 100a67cba0aa..be34f8138fc1 100644 --- a/vllm/attention/ops/common.py +++ b/vllm/attention/ops/common.py @@ -87,10 +87,7 @@ def _correct_attn_cp_out_kernel( -float("inf"), lse_finally, ) - if IS_BASE_E: - factor = tl.exp(lse_finally) - else: - factor = tl.exp2(lse_finally) + factor = tl.exp(lse_finally) if IS_BASE_E else tl.exp2(lse_finally) output = tl.load(outputs_ptr + output_offsets) output = output * factor From 629af91e985e7095e80ed0d36611e12dfb06cf09 Mon Sep 17 00:00:00 2001 From: "augusto.yjh" Date: Fri, 28 Nov 2025 10:59:53 +0800 Subject: [PATCH 7/8] fix merge conflicts Signed-off-by: augusto.yjh --- vllm/attention/ops/common.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py index be34f8138fc1..25cb2d8e7f96 100644 --- a/vllm/attention/ops/common.py +++ b/vllm/attention/ops/common.py @@ -203,7 +203,6 @@ def _cp_lse_common( cp_attn_lse = cp_attn_lse.contiguous() lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses) -<<<<<<< HEAD out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx, is_lse_base_on_e=is_lse_base_on_e) return out, lse @@ -221,13 +220,6 @@ def cp_lse_ag_out_rs( cp_attn_lse: [ B, H ] """ out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx, is_lse_base_on_e=is_lse_base_on_e) - out, lse = correct_attn_out( - 
cp_attn_out, - lses, - cp_group.rank_in_group, - ctx, - is_lse_base_on_e=is_lse_base_on_e, - ) out = cp_group.reduce_scatter(out, dim=1) if return_lse: From be7721fc85555094023da72f80045b7e9f5d38ca Mon Sep 17 00:00:00 2001 From: "augusto.yjh" Date: Fri, 28 Nov 2025 11:21:11 +0800 Subject: [PATCH 8/8] format code with pre-commit Signed-off-by: augusto.yjh --- vllm/attention/ops/common.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py index 25cb2d8e7f96..bd6bc864d45d 100644 --- a/vllm/attention/ops/common.py +++ b/vllm/attention/ops/common.py @@ -203,7 +203,13 @@ def _cp_lse_common( cp_attn_lse = cp_attn_lse.contiguous() lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses) - out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx, is_lse_base_on_e=is_lse_base_on_e) + out, lse = correct_attn_out( + cp_attn_out, + lses, + cp_group.rank_in_group, + ctx, + is_lse_base_on_e=is_lse_base_on_e, + ) return out, lse @@ -219,7 +225,9 @@ def cp_lse_ag_out_rs( cp_attn_out: [ B, H, D ] cp_attn_lse: [ B, H ] """ - out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx, is_lse_base_on_e=is_lse_base_on_e) + out, lse = _cp_lse_common( + cp_attn_out, cp_attn_lse, cp_group, ctx=ctx, is_lse_base_on_e=is_lse_base_on_e + ) out = cp_group.reduce_scatter(out, dim=1) if return_lse: @@ -242,7 +250,9 @@ def cp_lse_ag_out_ar( cp_attn_out: [ B, H, D ] cp_attn_lse: [ B, H ] """ - out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx, is_lse_base_on_e=is_lse_base_on_e) + out, lse = _cp_lse_common( + cp_attn_out, cp_attn_lse, cp_group, ctx=ctx, is_lse_base_on_e=is_lse_base_on_e + ) out = cp_group.all_reduce(out) if return_lse:
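
Note on the series as a whole: the IS_BASE_E constexpr added in PATCH 2/8 only selects which exp/log pair the correction kernel uses; the reduction itself is unchanged. As a rough reference for reviewers, the per-rank correction that _correct_attn_cp_out_kernel performs amounts to the following plain-PyTorch sketch. This is an illustration only, under the shape conventions from the kernel docstring (outputs [N, B, H, D], lses [N, B, H], N = context-parallel world size); the function and variable names are illustrative and are not vLLM APIs.

    import torch


    def correct_attn_out_reference(
        outputs: torch.Tensor,        # [N, B, H, D] per-rank partial attention outputs
        lses: torch.Tensor,           # [N, B, H]    per-rank log-sum-exp values
        cp_rank: int,                 # this rank's index along dim 0
        is_lse_base_on_e: bool = True,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Pick exp/log in the base the backend reports its LSE in:
        # natural log for most backends, base 2 when is_lse_base_on_e=False.
        exp = torch.exp if is_lse_base_on_e else torch.exp2
        log = torch.log if is_lse_base_on_e else torch.log2

        # Mask NaN / inf entries, mirroring the Triton kernel.
        lses = lses.float()
        lses = lses.masked_fill(lses.isnan() | lses.isinf(), float("-inf"))

        # Numerically stable combination of per-rank LSEs into the global LSE.
        lse_max = lses.amax(dim=0)                              # [B, H]
        lse_max = lse_max.masked_fill(lse_max.isneginf(), 0.0)
        vlse = log(exp(lses - lse_max).sum(dim=0)) + lse_max    # [B, H]

        # Rescale the local partial output so that a later cross-rank sum
        # (reduce-scatter or all-reduce) yields the final attention output.
        diff = lses[cp_rank] - vlse                             # [B, H]
        diff = diff.masked_fill(diff.isnan() | (diff == float("inf")), float("-inf"))
        out = outputs[cp_rank] * exp(diff).unsqueeze(-1)        # [B, H, D]
        return out, vlse

In the patched call sites, is_lse_base_on_e=False is passed only on the FlashInfer paths (flashinfer.py and the _use_fi_prefill branch of mla/common.py), where the returned LSE is treated as base-2; every other caller keeps the natural-log default, so existing backends are unaffected.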