diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py
index af6766bdd161..bd6bc864d45d 100644
--- a/vllm/attention/ops/common.py
+++ b/vllm/attention/ops/common.py
@@ -21,6 +21,7 @@ def _correct_attn_cp_out_kernel(
     lse_idx,
     HEAD_DIM: tl.constexpr,
     N_ROUNDED: tl.constexpr,
+    IS_BASE_E: tl.constexpr,
 ):
     """
     Apply the all-gathered lses to correct each local rank's attention
@@ -55,9 +56,14 @@ def _correct_attn_cp_out_kernel(
     lse_max = tl.max(lse, axis=0)
     lse_max = tl.where(lse_max == -float("inf"), 0, lse_max)
     lse -= lse_max
-    lse_exp = tl.exp(lse)
-    lse_acc = tl.sum(lse_exp, axis=0)
-    lse = tl.log(lse_acc)
+    if IS_BASE_E:
+        lse_exp = tl.exp(lse)
+        lse_acc = tl.sum(lse_exp, axis=0)
+        lse = tl.log(lse_acc)
+    else:
+        lse_exp = tl.exp2(lse)
+        lse_acc = tl.sum(lse_exp, axis=0)
+        lse = tl.log2(lse_acc)
     lse += lse_max
 
     lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H
@@ -81,7 +87,7 @@ def _correct_attn_cp_out_kernel(
         -float("inf"),
         lse_finally,
     )
-    factor = tl.exp(lse_finally)
+    factor = tl.exp(lse_finally) if IS_BASE_E else tl.exp2(lse_finally)
     output = tl.load(outputs_ptr + output_offsets)
     output = output * factor
 
@@ -102,7 +108,11 @@ def call_kernel(self, kernel, grid, *regular_args, **const_args):
 
 
 def correct_attn_out(
-    out: torch.Tensor, lses: torch.Tensor, cp_rank: int, ctx: CPTritonContext
+    out: torch.Tensor,
+    lses: torch.Tensor,
+    cp_rank: int,
+    ctx: CPTritonContext,
+    is_lse_base_on_e: bool = True,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Correct the attention output using the all-gathered lses.
 
@@ -163,8 +173,7 @@ def correct_attn_out(
         l_sH,
         cp_rank,
     )
-    const_args = {"HEAD_DIM": D, "N_ROUNDED": N}
-
+    const_args = {"HEAD_DIM": D, "N_ROUNDED": N, "IS_BASE_E": is_lse_base_on_e}
     ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args, **const_args)
     return out, lse
 
@@ -174,6 +183,7 @@ def _cp_lse_common(
     cp_attn_lse: torch.Tensor,
     cp_group: GroupCoordinator,
     ctx: CPTritonContext | None = None,
+    is_lse_base_on_e=True,
 ):
     """
     cp_attn_out: [ B, H, D ]
@@ -193,7 +203,13 @@ def _cp_lse_common(
 
     cp_attn_lse = cp_attn_lse.contiguous()
     lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
-    out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
+    out, lse = correct_attn_out(
+        cp_attn_out,
+        lses,
+        cp_group.rank_in_group,
+        ctx,
+        is_lse_base_on_e=is_lse_base_on_e,
+    )
     return out, lse
 
 
@@ -203,12 +219,15 @@ def cp_lse_ag_out_rs(
     cp_group: GroupCoordinator,
     ctx: CPTritonContext | None = None,
     return_lse: bool = False,
+    is_lse_base_on_e=True,
 ):
     """
     cp_attn_out: [ B, H, D ]
     cp_attn_lse: [ B, H ]
     """
-    out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
+    out, lse = _cp_lse_common(
+        cp_attn_out, cp_attn_lse, cp_group, ctx=ctx, is_lse_base_on_e=is_lse_base_on_e
+    )
     out = cp_group.reduce_scatter(out, dim=1)
 
     if return_lse:
@@ -225,12 +244,15 @@ def cp_lse_ag_out_ar(
     cp_group: GroupCoordinator,
     ctx: CPTritonContext | None = None,
     return_lse: bool = False,
+    is_lse_base_on_e=True,
 ):
     """
     cp_attn_out: [ B, H, D ]
     cp_attn_lse: [ B, H ]
     """
-    out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
+    out, lse = _cp_lse_common(
+        cp_attn_out, cp_attn_lse, cp_group, ctx=ctx, is_lse_base_on_e=is_lse_base_on_e
+    )
     out = cp_group.all_reduce(out)
 
     if return_lse:
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 777398bf8a20..69a6a5e5fae8 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -249,7 +249,11 @@ def run(
             return_lse=True,
         )
         output_context, lse_context = cp_lse_ag_out_rs(
-            output_context_tmp, lse_context_tmp, get_dcp_group(), return_lse=True
+            output_context_tmp,
+            lse_context_tmp,
+            get_dcp_group(),
+            return_lse=True,
+            is_lse_base_on_e=False,
         )
         lse_context = lse_context.transpose(0, 1).contiguous()
 
@@ -1335,7 +1339,10 @@ def forward(
                 return_lse=True,
             )
             output[:num_decode_tokens] = cp_lse_ag_out_rs(
-                output_tmp, lse, get_dcp_group()
+                output_tmp,
+                lse,
+                get_dcp_group(),
+                is_lse_base_on_e=False,
             )
         else:
             decode_wrapper.run(
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index d94ed9183f63..b09541dbf791 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -2057,7 +2057,12 @@ def forward(
 
         # correct dcp attn_out with lse.
         if self.dcp_world_size > 1:
-            attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())
+            attn_out = cp_lse_ag_out_rs(
+                attn_out,
+                lse,
+                get_dcp_group(),
+                is_lse_base_on_e=not self._use_fi_prefill,
+            )
 
         # v_up projection
         self._v_up_proj(attn_out, out=output[:num_decode_tokens])
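
For reference, the correction that `_correct_attn_cp_out_kernel` performs can be
written as the plain-PyTorch sketch below; the helper name `correct_attn_out_ref`
and the tensor-level formulation are illustrative only and not part of this patch.
It makes explicit why the exp/log base must match the base of the gathered LSEs:
the FlashInfer call sites pass `is_lse_base_on_e=False` because FlashInfer reports
its log-sum-exp in base 2, so recombining those values with the previously
hard-coded natural exp/log would rescale each rank's partial output incorrectly.

import torch


def correct_attn_out_ref(
    out: torch.Tensor,    # [B, H, D] this rank's partial attention output
    lses: torch.Tensor,   # [N, B, H] all-gathered per-rank log-sum-exp values
    cp_rank: int,
    is_lse_base_on_e: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Pick exp/log in the same base the attention backend used for its LSEs.
    exp = torch.exp if is_lse_base_on_e else torch.exp2
    log = torch.log if is_lse_base_on_e else torch.log2

    # Numerically stable combination of the per-rank LSEs into a global LSE.
    lse_max = lses.amax(dim=0)
    lse_max = torch.where(torch.isinf(lse_max), torch.zeros_like(lse_max), lse_max)
    lse_total = log(exp(lses - lse_max).sum(dim=0)) + lse_max  # [B, H]

    # Rescale the local partial output by exp(lse_local - lse_total); a
    # non-finite difference (e.g. a rank that attended to no keys) contributes 0.
    diff = lses[cp_rank] - lse_total
    diff = torch.where(torch.isinf(diff), torch.full_like(diff, -float("inf")), diff)
    factor = exp(diff)  # [B, H]
    return out * factor.unsqueeze(-1), lse_total

As long as the combination and the rescaling use the same base, the resulting
factor equals S_local / S_total in either convention, so correctness hinges only
on matching the backend's LSE base, which is what the new `is_lse_base_on_e`
flag threads through `cp_lse_ag_out_rs` / `cp_lse_ag_out_ar` down to the kernel.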