
Commit 1a94cb7

gjc0824 and pisceskkk committed
[Refactor] all gather the accurate context lengths
Co-authored-by: gaojc <1055866782@qq.com>
Co-authored-by: QiuChunshuo <qiuchunshuo@huawei.com>
Signed-off-by: gaojc <1055866782@qq.com>
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
1 parent 1f8ffde commit 1a94cb7

File tree

1 file changed (+25, -16 lines)


vllm/v1/attention/backends/mla/common.py

Lines changed: 25 additions & 16 deletions
@@ -195,6 +195,7 @@
 
 import torch
 from tqdm import tqdm
+import numpy as np
 
 import vllm.envs as envs
 from vllm import _custom_ops as ops
@@ -845,15 +846,7 @@ def build(
                 None,
                 self.dcp_local_block_size,
             )
-            # Note(qcs): The max local context lengths
-            # padded to `dcp_local_block_size`.
-            local_context_lens_cpu = (
-                cdiv(
-                    context_lens_cpu,
-                    self.dcp_virtual_block_size,
-                )
-                * self.dcp_local_block_size
-            )
+            local_context_lens_cpu = local_context_lens_allrank[:, self.dcp_rank]
             # Note(hc): The above max_context_chunk already enforces
             # block_size alignment, DCP just need the block_size can
             # be divisible by dcp_world_size, because DCP use
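
The build() change above replaces a padded upper bound with the exact per-rank context lengths already gathered into local_context_lens_allrank: the removed branch rounded every request up to whole local blocks, while the new line simply reads this rank's column. Below is a minimal sketch of the difference with made-up numbers; cdiv is a local stand-in for vLLM's ceiling-division helper, and the relation dcp_virtual_block_size = dcp_local_block_size * dcp_world_size is an assumption that only mirrors the rounding in the removed code.

import torch

def cdiv(a: torch.Tensor, b: int) -> torch.Tensor:
    # Ceiling division; local stand-in, not the vLLM utility itself.
    return -(a // -b)

# Illustrative configuration: 2 DCP ranks sharing blocks of 16 tokens.
dcp_world_size = 2
dcp_local_block_size = 8
dcp_virtual_block_size = dcp_local_block_size * dcp_world_size  # assumed relation

context_lens_cpu = torch.tensor([5, 21, 40])  # per-request context lengths

# Removed behaviour: each request's local share is rounded up to whole
# local blocks, so the gather buffer is sized for the worst case.
padded_local_lens = cdiv(context_lens_cpu, dcp_virtual_block_size) * dcp_local_block_size
print(padded_local_lens.tolist())  # [8, 16, 24]

# New behaviour: take the exact token counts this rank holds, e.g. rank 0 of
# a hypothetical 2-rank split (columns sum back to the context lengths).
local_context_lens_allrank = torch.tensor([[3, 2], [11, 10], [20, 20]])
print(local_context_lens_allrank[:, 0].tolist())  # [3, 11, 20], no padding
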
@@ -989,7 +982,7 @@ def reorg_kvcache(
     local_context_lens_allrank: list[list[int]],
     sum_seq_len: int,
     max_seq_len: int,
-    toks: int,
+    local_context_lens_sum: list[int],
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """
     reorg kvcache after cp local gather to tp layout for attn kernel.
@@ -1001,30 +994,35 @@ def reorg_kvcache(
         sum_seq_len: the sum of cp_chunk_seq_lens_lst.
         max_seq_len: the max value of cp_chunk_seq_lens_lst.
         toks: the number of tokens for local gather cache.
+        local_context_lens_sum: the total context tokens of all request
+            on each CP rank.
     """
     kv_c_segments = []
     k_pe_segments = []
     src_token_idx = 0
     max_seq_len_check = 0
+
     for local_chunk_seq_len, local_context_lens in zip(
         local_chunk_seq_lens_lst, local_context_lens_allrank
     ):
         cur_seq_len = 0
+        context_len_across_rank = 0
         for rank, local_context_len in enumerate(local_context_lens):
             if local_context_len != 0:
                 kv_c_segment = allgatered_kv_c_normed[
-                    rank * toks + src_token_idx : rank * toks
+                    context_len_across_rank + src_token_idx : context_len_across_rank
                     + src_token_idx
                     + local_context_len
                 ]
                 k_pe_segment = allgatered_k_pe[
-                    rank * toks + src_token_idx : rank * toks
+                    context_len_across_rank + src_token_idx : context_len_across_rank
                     + src_token_idx
                     + local_context_len
                 ]
                 kv_c_segments.append(kv_c_segment)
                 k_pe_segments.append(k_pe_segment)
                 cur_seq_len += local_context_len
+            context_len_across_rank += local_context_lens_sum[rank]
         max_seq_len_check = max(max_seq_len_check, cur_seq_len)
         src_token_idx += local_chunk_seq_len
     reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0)
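
The indexing change in reorg_kvcache follows from the exact gather: previously rank r's region of the gathered buffer always began at rank * toks, a fixed padded stride, whereas now it begins at the running sum of the preceding ranks' total token counts, which is what context_len_across_rank accumulates as it is advanced by local_context_lens_sum[rank]. A self-contained sketch of that offset arithmetic with toy numbers, outside the vLLM code path:

from itertools import accumulate

# Toy per-request token counts held by each of 3 DCP ranks
# (rows = requests, columns = ranks); values are illustrative.
local_context_lens_allrank = [
    [3, 5, 2],
    [4, 0, 6],
]

# Column-wise totals: how many tokens each rank contributes to the all-gather.
local_context_lens_sum = [sum(col) for col in zip(*local_context_lens_allrank)]
print(local_context_lens_sum)   # [7, 5, 8]

# Start of each rank's region in the gathered buffer: the running sum of the
# previous ranks' totals, i.e. the values context_len_across_rank steps through.
rank_offsets = [0] + list(accumulate(local_context_lens_sum))[:-1]
print(rank_offsets)             # [0, 7, 12]

# Under the removed layout every region was toks tokens wide regardless of how
# many tokens the rank actually held, so the offsets were [0, toks, 2 * toks].
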
@@ -1613,11 +1611,21 @@ def _context_parallel_compute_prefill_context(
             cur_allgather_workspace = workspace[
                 allgather_offset : allgather_offset * (1 + dcp_world_size)
             ]
+            local_context_lens_allrank = (
+                prefill_metadata.chunked_context.local_context_lens_allrank
+            )
+            local_context_lens_sum = np.sum(local_context_lens_allrank, axis=0).tolist()
             assert toks * dcp_world_size <= cur_allgather_workspace.shape[0]
-            cur_allgather_kvcache = cur_allgather_workspace[: toks * dcp_world_size]
+            cur_allgather_kvcache = cur_allgather_workspace[: sum(local_context_lens_sum)]
+
             cur_allgather_kvcache.copy_(
-                get_dcp_group().all_gather(local_gathered_kvcache, dim=0)
+                get_dcp_group().all_gatherv(
+                    local_gathered_kvcache,
+                    dim=0,
+                    sizes=local_context_lens_sum
+                )
             )
+
             assert (
                 cur_allgather_kvcache.shape[-1]
                 == self.kv_lora_rank + self.qk_rope_head_dim
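
The gather call changes to match: rather than a fixed-size all_gather that moves toks rows per rank, the workspace is sliced to sum(local_context_lens_sum) rows and filled by a size-aware gather over the DCP group. A single-process sketch of the resulting layout, with plain torch.cat standing in for the collective; the list below plays the role of the sizes argument in the diff:

import torch

# Rows each simulated rank contributes; 4 columns stand in for
# kv_lora_rank + qk_rope_head_dim.
local_context_lens_sum = [7, 5, 8]
per_rank_kvcache = [
    torch.full((n, 4), float(rank)) for rank, n in enumerate(local_context_lens_sum)
]

# A size-aware gather concatenates the per-rank tensors in rank order, so the
# result holds exactly sum(local_context_lens_sum) rows.  A fixed-size gather
# would pad every rank to the same toks rows and move toks * dcp_world_size
# rows, many of them unused.
cur_allgather_kvcache = torch.cat(per_rank_kvcache, dim=0)
assert cur_allgather_kvcache.shape[0] == sum(local_context_lens_sum)  # 20 rows
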
@@ -1632,10 +1640,11 @@ def _context_parallel_compute_prefill_context(
                 local_chunk_seq_lens_lst=prefill_metadata.chunked_context.local_chunk_seq_lens[
                     i
                 ],
-                local_context_lens_allrank=prefill_metadata.chunked_context.local_context_lens_allrank,
+                local_context_lens_allrank=
+                prefill_metadata.chunked_context.local_context_lens_allrank,
                 sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i][-1],
                 max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i],
-                toks=toks,
+                local_context_lens_sum=local_context_lens_sum,
             )
 
             kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
