Skip to content

Commit be7721f

Browse files
committed
format code with pre-commit
Signed-off-by: augusto.yjh <augusto.yjh@antgroup.com>
1 parent 629af91 commit be7721f

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

vllm/attention/ops/common.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,13 @@ def _cp_lse_common(
203203

204204
cp_attn_lse = cp_attn_lse.contiguous()
205205
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
206-
out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx, is_lse_base_on_e=is_lse_base_on_e)
206+
out, lse = correct_attn_out(
207+
cp_attn_out,
208+
lses,
209+
cp_group.rank_in_group,
210+
ctx,
211+
is_lse_base_on_e=is_lse_base_on_e,
212+
)
207213
return out, lse
208214

209215

@@ -219,7 +225,9 @@ def cp_lse_ag_out_rs(
219225
cp_attn_out: [ B, H, D ]
220226
cp_attn_lse: [ B, H ]
221227
"""
222-
out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx, is_lse_base_on_e=is_lse_base_on_e)
228+
out, lse = _cp_lse_common(
229+
cp_attn_out, cp_attn_lse, cp_group, ctx=ctx, is_lse_base_on_e=is_lse_base_on_e
230+
)
223231
out = cp_group.reduce_scatter(out, dim=1)
224232

225233
if return_lse:
@@ -242,7 +250,9 @@ def cp_lse_ag_out_ar(
242250
cp_attn_out: [ B, H, D ]
243251
cp_attn_lse: [ B, H ]
244252
"""
245-
out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx, is_lse_base_on_e=is_lse_base_on_e)
253+
out, lse = _cp_lse_common(
254+
cp_attn_out, cp_attn_lse, cp_group, ctx=ctx, is_lse_base_on_e=is_lse_base_on_e
255+
)
246256
out = cp_group.all_reduce(out)
247257

248258
if return_lse:

0 commit comments

Comments
 (0)