42 changes: 32 additions & 10 deletions vllm/attention/ops/common.py
@@ -21,6 +21,7 @@ def _correct_attn_cp_out_kernel(
     lse_idx,
     HEAD_DIM: tl.constexpr,
     N_ROUNDED: tl.constexpr,
+    IS_BASE_E: tl.constexpr,
 ):
     """
     Apply the all-gathered lses to correct each local rank's attention
@@ -55,9 +56,14 @@ def _correct_attn_cp_out_kernel(
     lse_max = tl.max(lse, axis=0)
     lse_max = tl.where(lse_max == -float("inf"), 0, lse_max)
     lse -= lse_max
-    lse_exp = tl.exp(lse)
-    lse_acc = tl.sum(lse_exp, axis=0)
-    lse = tl.log(lse_acc)
+    if IS_BASE_E:
+        lse_exp = tl.exp(lse)
+        lse_acc = tl.sum(lse_exp, axis=0)
+        lse = tl.log(lse_acc)
+    else:
+        lse_exp = tl.exp2(lse)
+        lse_acc = tl.sum(lse_exp, axis=0)
+        lse = tl.log2(lse_acc)
     lse += lse_max
 
     lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H
@@ -81,7 +87,7 @@ def _correct_attn_cp_out_kernel(
         -float("inf"),
         lse_finally,
     )
-    factor = tl.exp(lse_finally)
+    factor = tl.exp(lse_finally) if IS_BASE_E else tl.exp2(lse_finally)
     output = tl.load(outputs_ptr + output_offsets)
     output = output * factor
 
@@ -102,7 +108,11 @@ def call_kernel(self, kernel, grid, *regular_args, **const_args):
 
 
 def correct_attn_out(
-    out: torch.Tensor, lses: torch.Tensor, cp_rank: int, ctx: CPTritonContext
+    out: torch.Tensor,
+    lses: torch.Tensor,
+    cp_rank: int,
+    ctx: CPTritonContext,
+    is_lse_base_on_e: bool = True,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Correct the attention output using the all-gathered lses.
 
@@ -163,8 +173,7 @@ def correct_attn_out(
         l_sH,
         cp_rank,
     )
-    const_args = {"HEAD_DIM": D, "N_ROUNDED": N}
-
+    const_args = {"HEAD_DIM": D, "N_ROUNDED": N, "IS_BASE_E": is_lse_base_on_e}
     ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args, **const_args)
     return out, lse
 
@@ -174,6 +183,7 @@ def _cp_lse_common(
     cp_attn_lse: torch.Tensor,
     cp_group: GroupCoordinator,
     ctx: CPTritonContext | None = None,
+    is_lse_base_on_e=True,
 ):
     """
     cp_attn_out: [ B, H, D ]
@@ -193,7 +203,13 @@ def _cp_lse_common(
 
     cp_attn_lse = cp_attn_lse.contiguous()
     lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
-    out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
+    out, lse = correct_attn_out(
+        cp_attn_out,
+        lses,
+        cp_group.rank_in_group,
+        ctx,
+        is_lse_base_on_e=is_lse_base_on_e,
+    )
     return out, lse
 
 
@@ -203,12 +219,15 @@ def cp_lse_ag_out_rs(
     cp_group: GroupCoordinator,
     ctx: CPTritonContext | None = None,
     return_lse: bool = False,
+    is_lse_base_on_e=True,
 ):
     """
     cp_attn_out: [ B, H, D ]
     cp_attn_lse: [ B, H ]
     """
-    out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
+    out, lse = _cp_lse_common(
+        cp_attn_out, cp_attn_lse, cp_group, ctx=ctx, is_lse_base_on_e=is_lse_base_on_e
+    )
     out = cp_group.reduce_scatter(out, dim=1)
 
     if return_lse:
@@ -225,12 +244,15 @@ def cp_lse_ag_out_ar(
     cp_group: GroupCoordinator,
     ctx: CPTritonContext | None = None,
     return_lse: bool = False,
+    is_lse_base_on_e=True,
 ):
     """
     cp_attn_out: [ B, H, D ]
     cp_attn_lse: [ B, H ]
     """
-    out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
+    out, lse = _cp_lse_common(
+        cp_attn_out, cp_attn_lse, cp_group, ctx=ctx, is_lse_base_on_e=is_lse_base_on_e
+    )
     out = cp_group.all_reduce(out)
 
     if return_lse:
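For review purposes, the correction the kernel implements can be summarized in eager PyTorch. The sketch below is illustrative only (the name correct_attn_out_reference and the eager implementation are not part of the diff); it assumes the same [N, B, H] layout for the all-gathered LSEs and shows how IS_BASE_E merely selects exp/log versus exp2/log2 while the merge itself is unchanged:

import torch


def correct_attn_out_reference(
    out: torch.Tensor,   # this rank's partial attention output, [B, H, D]
    lses: torch.Tensor,  # all-gathered LSEs from every CP rank, [N, B, H]
    cp_rank: int,
    is_lse_base_on_e: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Choose the exp/log pair matching the base the LSEs were produced in.
    exp, log = (torch.exp, torch.log) if is_lse_base_on_e else (torch.exp2, torch.log2)

    # Numerically stable combined LSE across ranks.
    lse_max = lses.max(dim=0).values
    lse_max = torch.where(torch.isinf(lse_max), torch.zeros_like(lse_max), lse_max)
    lse_total = log(exp(lses - lse_max).sum(dim=0)) + lse_max

    # Each rank's weight is its share of the global softmax denominator.
    factor = exp(lses[cp_rank] - lse_total)
    factor = torch.nan_to_num(factor, nan=0.0, posinf=0.0)
    return out * factor.unsqueeze(-1), lse_total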
11 changes: 9 additions & 2 deletions vllm/v1/attention/backends/flashinfer.py
@@ -249,7 +249,11 @@ def run(
             return_lse=True,
         )
         output_context, lse_context = cp_lse_ag_out_rs(
-            output_context_tmp, lse_context_tmp, get_dcp_group(), return_lse=True
+            output_context_tmp,
+            lse_context_tmp,
+            get_dcp_group(),
+            return_lse=True,
+            is_lse_base_on_e=False,
         )
         lse_context = lse_context.transpose(0, 1).contiguous()
 
@@ -1335,7 +1339,10 @@ def forward(
                 return_lse=True,
             )
             output[:num_decode_tokens] = cp_lse_ag_out_rs(
-                output_tmp, lse, get_dcp_group()
+                output_tmp,
+                lse,
+                get_dcp_group(),
+                is_lse_base_on_e=False,
             )
         else:
             decode_wrapper.run(
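Both call sites above pass is_lse_base_on_e=False, i.e. the LSE returned by these FlashInfer paths is treated as a base-2 log-sum-exp. The correction is base-agnostic as long as exp/log and exp2/log2 are paired consistently, which is what the kernel's new branch does; the following check is illustrative only (values and names are made up) and is not part of the diff:

import math

import torch

# Merging base-2 LSEs with exp2/log2 yields the same per-rank weights as
# merging the equivalent base-e LSEs with exp/log, since log2(S) = ln(S) / ln(2).
lse_e = torch.tensor([[1.7, -0.3], [2.2, 0.9]])  # base-e LSEs, [num_ranks, B]
lse_2 = lse_e / math.log(2)                      # the same values in log2 space

w_e = torch.exp(lse_e - torch.logsumexp(lse_e, dim=0))
w_2 = torch.exp2(lse_2 - torch.log2(torch.exp2(lse_2).sum(dim=0)))
assert torch.allclose(w_e, w_2)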
7 changes: 6 additions & 1 deletion vllm/v1/attention/backends/mla/common.py
@@ -2057,7 +2057,12 @@ def forward(
 
             # correct dcp attn_out with lse.
             if self.dcp_world_size > 1:
-                attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())
+                attn_out = cp_lse_ag_out_rs(
+                    attn_out,
+                    lse,
+                    get_dcp_group(),
+                    is_lse_base_on_e=not self._use_fi_prefill,
+                )
 
             # v_up projection
             self._v_up_proj(attn_out, out=output[:num_decode_tokens])