Commit abf6be5

format code and pass lse base for flashinfer chunked prefill

Signed-off-by: augusto.yjh <augusto.yjh@antgroup.com>
1 parent: e8adf8d

3 files changed: 27 additions, 5 deletions

vllm/attention/ops/common.py (12 additions, 2 deletions)

@@ -111,7 +111,11 @@ def call_kernel(self, kernel, grid, *regular_args, **const_args):
 def correct_attn_out(
-    out: torch.Tensor, lses: torch.Tensor, cp_rank: int, ctx: CPTritonContext, is_lse_base_on_e: bool=True,
+    out: torch.Tensor,
+    lses: torch.Tensor,
+    cp_rank: int,
+    ctx: CPTritonContext,
+    is_lse_base_on_e: bool = True,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Correct the attention output using the all-gathered lses.

@@ -202,6 +206,12 @@ def _cp_lse_common(
     cp_attn_lse = cp_attn_lse.contiguous()
     lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
-    out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx, is_lse_base_on_e=is_lse_base_on_e)
+    out, lse = correct_attn_out(
+        cp_attn_out,
+        lses,
+        cp_group.rank_in_group,
+        ctx,
+        is_lse_base_on_e=is_lse_base_on_e,
+    )
     return out, lse

@@ -219,6 +229,6 @@ def cp_lse_ag_out_rs(
     cp_attn_lse: [ B, H ]
     """
     out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx, is_lse_base_on_e=is_lse_base_on_e)
     out = cp_group.reduce_scatter(out, dim=1)

     if return_lse:
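
For readers outside the diff context: correct_attn_out rescales each rank's partial attention output using the all-gathered log-sum-exp (LSE) values, and the new is_lse_base_on_e flag tells it which logarithm base the producing kernel used for those LSEs. Below is a minimal sketch of that correction in plain PyTorch; the real implementation is the Triton kernel behind correct_attn_out, and the names and shapes here ([N, B, H] stacked LSEs, [B, H, D] output) are illustrative assumptions, not the kernel's actual internals.

import math

import torch


def rescale_partial_out(
    out: torch.Tensor,       # [B, H, D] local rank's partial attention output
    lses: torch.Tensor,      # [N, B, H] all-gathered per-rank LSEs
    cp_rank: int,
    is_lse_base_on_e: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    # A base-2 LSE (from a kernel that works in log2 space) must be
    # converted to natural base first: ln(x) = log2(x) * ln(2).
    if not is_lse_base_on_e:
        lses = lses * math.log(2.0)
    # Numerically stable global LSE across all context-parallel ranks.
    global_lse = torch.logsumexp(lses, dim=0)        # [B, H]
    # This rank's softmax mass relative to the global normalizer.
    scale = torch.exp(lses[cp_rank] - global_lse)    # [B, H]
    # The rescaled partials from all ranks sum to the exact attention output;
    # in vLLM that sum happens in the reduce_scatter that follows.
    return out * scale.unsqueeze(-1), global_lse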

vllm/v1/attention/backends/flashinfer.py (9 additions, 2 deletions)

@@ -249,7 +249,11 @@ def run(
         return_lse=True,
     )
     output_context, lse_context = cp_lse_ag_out_rs(
-        output_context_tmp, lse_context_tmp, get_dcp_group(), return_lse=True, is_lse_base_on_e=False,
+        output_context_tmp,
+        lse_context_tmp,
+        get_dcp_group(),
+        return_lse=True,
+        is_lse_base_on_e=False,
     )
     lse_context = lse_context.transpose(0, 1).contiguous()

@@ -1335,7 +1339,10 @@ def forward(
         return_lse=True,
     )
     output[:num_decode_tokens] = cp_lse_ag_out_rs(
-        output_tmp, lse, get_dcp_group(), is_lse_base_on_e=False,
+        output_tmp,
+        lse,
+        get_dcp_group(),
+        is_lse_base_on_e=False,
     )
 else:
     decode_wrapper.run(
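
Both call sites here pass is_lse_base_on_e=False, presumably because the FlashInfer wrappers report the LSE in base 2 rather than base e (an assumption about the backend's convention, not stated in the diff itself). The conversion the flag triggers is a constant rescale, as this self-contained check illustrates:

import math

import torch

# Illustrative only: a base-2 LSE is the natural-base LSE divided by ln(2),
# so multiplying by ln(2) recovers base e before the correction runs.
scores = torch.randn(4, 8)
lse_e = torch.logsumexp(scores, dim=-1)   # base-e LSE
lse_2 = lse_e / math.log(2.0)             # what a base-2 kernel would report
assert torch.allclose(lse_2 * math.log(2.0), lse_e)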

vllm/v1/attention/backends/mla/common.py (6 additions, 1 deletion)

@@ -2057,7 +2057,12 @@ def forward(
     # correct dcp attn_out with lse.
     if self.dcp_world_size > 1:
-        attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())
+        attn_out = cp_lse_ag_out_rs(
+            attn_out,
+            lse,
+            get_dcp_group(),
+            is_lse_base_on_e=not self._use_fi_prefill,
+        )

     # v_up projection
     self._v_up_proj(attn_out, out=output[:num_decode_tokens])
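
Why LSE-based correction works at all (illustrative, not vLLM code): splitting attention over the KV sequence and reweighting each chunk's output by exp(lse_chunk - lse_global) reproduces full attention exactly, which is what cp_lse_ag_out_rs relies on when decode context parallelism is enabled. A minimal single-head, single-query check:

import torch

torch.manual_seed(0)
q = torch.randn(1, 4)    # one query vector, head_dim = 4
k = torch.randn(8, 4)    # 8 KV positions, split across two "ranks"
v = torch.randn(8, 4)


def partial_attn(ks: torch.Tensor, vs: torch.Tensor):
    s = q @ ks.T                           # attention scores [1, len(ks)]
    lse = torch.logsumexp(s, dim=-1)       # natural-base LSE [1]
    return torch.softmax(s, dim=-1) @ vs, lse


o0, l0 = partial_attn(k[:4], v[:4])
o1, l1 = partial_attn(k[4:], v[4:])
g = torch.logsumexp(torch.stack([l0, l1]), dim=0)   # global LSE
merged = (o0 * torch.exp(l0 - g).unsqueeze(-1)
          + o1 * torch.exp(l1 - g).unsqueeze(-1))
full = torch.softmax(q @ k.T, dim=-1) @ v           # reference: full attention
assert torch.allclose(merged, full, atol=1e-6)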
