Commit b728070

bugfix: correct attn output with base 2 or e

FlashInfer attention uses 2 as the base of the LSE instead of e; see
https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/attention/mla.cuh#L400

Signed-off-by: augusto.yjh <augusto.yjh@antgroup.com>

1 parent 3380ed5
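
Background, with a small illustration (a sketch, not vLLM code; merge_partial_outputs and its arguments are made-up names): when per-rank partial attention outputs are merged under context parallelism, each rank's output is rescaled by base**(rank_lse - merged_lse). FlashInfer returns base-2 lses, so the factors must be computed with exp2/log2 rather than exp/log:

import torch

def merge_partial_outputs(outs, lses, base_two: bool):
    # outs: [N, B, H, D] per-rank partial outputs; lses: [N, B, H].
    exp = torch.exp2 if base_two else torch.exp
    log = torch.log2 if base_two else torch.log
    m = lses.amax(dim=0, keepdim=True)
    merged = log(exp(lses - m).sum(dim=0, keepdim=True)) + m  # merged lse, same base
    w = exp(lses - merged)                                    # per-rank correction factors
    return (w.unsqueeze(-1) * outs).sum(dim=0), merged.squeeze(0)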

File tree

2 files changed: +91 −6

vllm/attention/ops/common.py

Lines changed: 89 additions & 4 deletions
@@ -88,6 +88,88 @@ def _correct_attn_cp_out_kernel(
     tl.store(new_output_ptr + output_offsets, output)
 
 
+@triton.jit
+def _correct_attn_cp_out_kernel_for_flashinfer(
+    outputs_ptr,
+    new_output_ptr,
+    lses_ptr,
+    vlse_ptr,
+    outputs_stride_B,
+    outputs_stride_H,
+    outputs_stride_D,
+    lses_stride_N,
+    lses_stride_B,
+    lses_stride_H,
+    lse_idx,
+    HEAD_DIM: tl.constexpr,
+    N_ROUNDED: tl.constexpr,
+):
+    """
+    Apply the all-gathered lses to correct each local rank's attention
+    output. We still need to perform a cross-rank reduction to obtain
+    the final attention output.
+
+    Args:
+        outputs_ptr (triton.PointerType):
+            Pointer to input tensor of shape [ B, H, D ]
+        lses_ptr (triton.PointerType):
+            Pointer to input tensor of shape [ N, B, H ]
+        new_output_ptr (triton.PointerType):
+            Pointer to output tensor of shape [ B, H, D ]
+        vlse_ptr (triton.PointerType):
+            Pointer to output tensor of shape [ B, H ]
+    """
+    batch_idx = tl.program_id(axis=0).to(tl.int64)
+    head_idx = tl.program_id(axis=1).to(tl.int64)
+    d_offsets = tl.arange(0, HEAD_DIM)
+    num_n_offsets = tl.arange(0, N_ROUNDED)
+
+    # shape = [N]
+    lse_offsets = (
+        num_n_offsets * lses_stride_N
+        + batch_idx * lses_stride_B
+        + head_idx * lses_stride_H
+    )
+
+    # compute merged lse across ranks (base 2, max-shifted for stability)
+    lse = tl.load(lses_ptr + lse_offsets)
+    lse = tl.where((lse != lse) | (lse == float("inf")), -float("inf"), lse)
+    lse_max = tl.max(lse, axis=0)
+    lse_max = tl.where(lse_max == -float("inf"), 0, lse_max)
+    lse -= lse_max
+    lse_exp = tl.exp2(lse)
+    lse_acc = tl.sum(lse_exp, axis=0)
+    lse = tl.log2(lse_acc)
+    lse += lse_max
+
+    lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H
+    tl.store(vlse_ptr + lse_offsets, lse)
+
+    # shape = [D]
+    output_offsets = (
+        batch_idx * outputs_stride_B
+        + head_idx * outputs_stride_H
+        + d_offsets * outputs_stride_D
+    )
+
+    # correct this rank's output: scale by 2**(local_lse - merged_lse)
+    lse_offset = (
+        lse_idx * lses_stride_N + batch_idx * lses_stride_B + head_idx * lses_stride_H
+    )
+    lse_tmp = tl.load(lses_ptr + lse_offset)
+    lse_finally = lse_tmp - lse
+    lse_finally = tl.where(
+        (lse_finally != lse_finally) | (lse_finally == float("inf")),
+        -float("inf"),
+        lse_finally,
+    )
+    factor = tl.exp2(lse_finally)
+    output = tl.load(outputs_ptr + output_offsets)
+    output = output * factor
+
+    tl.store(new_output_ptr + output_offsets, output)
+
+
 class CPTritonContext:
     """The CPTritonContext is used to avoid recompilation of the Triton JIT."""
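For reference, the reduction performed by the new kernel can be mirrored with plain PyTorch ops (illustrative sketch, not part of the commit; merge_lse_base2 is a made-up name):

import torch

def merge_lse_base2(lses: torch.Tensor) -> torch.Tensor:
    # lses: [N, B, H] base-2 lse values gathered from N ranks.
    # Mask NaN/inf entries (ranks that saw no valid tokens), as the kernel does.
    lses = torch.where(torch.isnan(lses) | torch.isinf(lses),
                       torch.full_like(lses, float("-inf")), lses)
    lse_max = lses.amax(dim=0)
    lse_max = torch.where(torch.isinf(lse_max), torch.zeros_like(lse_max), lse_max)
    # Max-shifted exp2/log2 reduction keeps the intermediate sum in range.
    return torch.log2(torch.exp2(lses - lse_max).sum(dim=0)) + lse_max  # [B, H]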

@@ -102,7 +184,7 @@ def call_kernel(self, kernel, grid, *regular_args, **const_args):
 
 
 def correct_attn_out(
-    out: torch.Tensor, lses: torch.Tensor, cp_rank: int, ctx: CPTritonContext
+    out: torch.Tensor, lses: torch.Tensor, cp_rank: int, ctx: CPTritonContext, is_lse_base_on_e: bool = True
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Correct the attention output using the all-gathered lses.
@@ -164,8 +246,10 @@ def correct_attn_out(
         cp_rank,
     )
     const_args = {"HEAD_DIM": D, "N_ROUNDED": N}
-
-    ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args, **const_args)
+    correct_attn_kernel = _correct_attn_cp_out_kernel
+    if not is_lse_base_on_e:
+        correct_attn_kernel = _correct_attn_cp_out_kernel_for_flashinfer
+    ctx.call_kernel(correct_attn_kernel, grid, *regular_args, **const_args)
     return out, lse
 
 
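An alternative design (not what this commit does) would be to convert the gathered base-2 lses to base e once and reuse the existing kernel; since log_e(S) = log_2(S) * ln 2, the conversion is a single multiply. A sketch with a hypothetical helper:

import math
import torch

def lse_base2_to_base_e(lses: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: rescale log2-sum-exp values to the natural base.
    return lses * math.log(2.0)

Dispatching to a dedicated exp2/log2 kernel, as the commit does, avoids this extra elementwise pass over lses.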

@@ -175,6 +259,7 @@ def cp_lse_ag_out_rs(
     cp_group: GroupCoordinator,
     ctx: CPTritonContext = None,
     return_lse=False,
+    is_lse_base_on_e=True,
 ):
     """
     cp_attn_out: [ B, H, D ]
@@ -194,7 +279,7 @@ def cp_lse_ag_out_rs(
 
     cp_attn_lse = cp_attn_lse.contiguous()
     lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
-    out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
+    out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx, is_lse_base_on_e=is_lse_base_on_e)
     out = cp_group.reduce_scatter(out, dim=1)
 
     if return_lse:

vllm/v1/attention/backends/flashinfer.py

Lines changed: 2 additions & 2 deletions
@@ -250,7 +250,7 @@ def run(
             return_lse=True,
         )
         output_context, lse_context = cp_lse_ag_out_rs(
-            output_context_tmp, lse_context_tmp, get_dcp_group(), return_lse=True
+            output_context_tmp, lse_context_tmp, get_dcp_group(), return_lse=True, is_lse_base_on_e=False,
         )
         lse_context = lse_context.transpose(0, 1).contiguous()

@@ -1343,7 +1343,7 @@ def forward(
                 return_lse=True,
             )
             output[:num_decode_tokens] = cp_lse_ag_out_rs(
-                output_tmp, lse, get_dcp_group()
+                output_tmp, lse, get_dcp_group(), is_lse_base_on_e=False,
             )
         else:
             decode_wrapper.run(
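
A quick numeric sanity check of the bug being fixed (toy scalars, illustrative only, not from the commit): applying base-e correction to base-2 lses yields factors that no longer sum to 1.

import math

scores = [1.0, 3.0]                              # toy per-rank scores
lse2 = [s / math.log(2.0) for s in scores]       # base-2 lse, as FlashInfer stores it
total2 = math.log2(sum(2.0 ** l for l in lse2))  # merged lse, base 2
good = [2.0 ** (l - total2) for l in lse2]       # exp2 factors, sum to 1.0
bad = [math.exp(l - total2) for l in lse2]       # base-e factors, mis-scaled
print(sum(good), sum(bad))                       # 1.0 vs ~0.88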
