195195
196196import torch
197197from tqdm import tqdm
198+ import numpy as np
198199
199200import vllm .envs as envs
200201from vllm import _custom_ops as ops
@@ -845,15 +846,7 @@ def build(
845846 None ,
846847 self .dcp_local_block_size ,
847848 )
848- # Note(qcs): The max local context lengths
849- # padded to `dcp_local_block_size`.
850- local_context_lens_cpu = (
851- cdiv (
852- context_lens_cpu ,
853- self .dcp_virtual_block_size ,
854- )
855- * self .dcp_local_block_size
856- )
849+ local_context_lens_cpu = local_context_lens_allrank [:, self .dcp_rank ]
857850 # Note(hc): The above max_context_chunk already enforces
858851 # block_size alignment; DCP only needs the block_size to
859852 # be divisible by dcp_world_size, because DCP uses
@@ -989,7 +982,7 @@ def reorg_kvcache(
989982 local_context_lens_allrank : list [list [int ]],
990983 sum_seq_len : int ,
991984 max_seq_len : int ,
992- toks : int ,
985+ local_context_lens_sum : list [ int ] ,
993986) -> tuple [torch .Tensor , torch .Tensor ]:
994987 """
995988 Reorganize the kvcache after the cp local gather into the tp layout for the attn kernel.
@@ -1001,30 +994,35 @@ def reorg_kvcache(
1001994 sum_seq_len: the sum of cp_chunk_seq_lens_lst.
1002995 max_seq_len: the max value of cp_chunk_seq_lens_lst.
1003996 toks: the number of tokens for local gather cache.
997+ local_context_lens_sum: the total number of context tokens, summed
998+ over all requests, on each CP rank.
1004999 """
10051000 kv_c_segments = []
10061001 k_pe_segments = []
10071002 src_token_idx = 0
10081003 max_seq_len_check = 0
1004+
10091005 for local_chunk_seq_len , local_context_lens in zip (
10101006 local_chunk_seq_lens_lst , local_context_lens_allrank
10111007 ):
10121008 cur_seq_len = 0
1009+ context_len_across_rank = 0
10131010 for rank , local_context_len in enumerate (local_context_lens ):
10141011 if local_context_len != 0 :
10151012 kv_c_segment = allgatered_kv_c_normed [
1016- rank * toks + src_token_idx : rank * toks
1013+ context_len_across_rank + src_token_idx : context_len_across_rank
10171014 + src_token_idx
10181015 + local_context_len
10191016 ]
10201017 k_pe_segment = allgatered_k_pe [
1021- rank * toks + src_token_idx : rank * toks
1018+ context_len_across_rank + src_token_idx : context_len_across_rank
10221019 + src_token_idx
10231020 + local_context_len
10241021 ]
10251022 kv_c_segments .append (kv_c_segment )
10261023 k_pe_segments .append (k_pe_segment )
10271024 cur_seq_len += local_context_len
1025+ context_len_across_rank += local_context_lens_sum [rank ]
10281026 max_seq_len_check = max (max_seq_len_check , cur_seq_len )
10291027 src_token_idx += local_chunk_seq_len
10301028 reorganized_kv_c_normed = torch .cat (kv_c_segments , dim = 0 )
@@ -1613,11 +1611,21 @@ def _context_parallel_compute_prefill_context(
16131611 cur_allgather_workspace = workspace [
16141612 allgather_offset : allgather_offset * (1 + dcp_world_size )
16151613 ]
1614+ local_context_lens_allrank = (
1615+ prefill_metadata .chunked_context .local_context_lens_allrank
1616+ )
1617+ local_context_lens_sum = np .sum (local_context_lens_allrank , axis = 0 ).tolist ()
16161618 assert toks * dcp_world_size <= cur_allgather_workspace .shape [0 ]
1617- cur_allgather_kvcache = cur_allgather_workspace [: toks * dcp_world_size ]
1619+ cur_allgather_kvcache = cur_allgather_workspace [: sum (local_context_lens_sum )]
1620+
16181621 cur_allgather_kvcache .copy_ (
1619- get_dcp_group ().all_gather (local_gathered_kvcache , dim = 0 )
1622+ get_dcp_group ().all_gatherv (
1623+ local_gathered_kvcache ,
1624+ dim = 0 ,
1625+ sizes = local_context_lens_sum
1626+ )
16201627 )
1628+
16211629 assert (
16221630 cur_allgather_kvcache .shape [- 1 ]
16231631 == self .kv_lora_rank + self .qk_rope_head_dim
@@ -1632,10 +1640,11 @@ def _context_parallel_compute_prefill_context(
16321640 local_chunk_seq_lens_lst = prefill_metadata .chunked_context .local_chunk_seq_lens [
16331641 i
16341642 ],
1635- local_context_lens_allrank = prefill_metadata .chunked_context .local_context_lens_allrank ,
1643+ local_context_lens_allrank =
1644+ prefill_metadata .chunked_context .local_context_lens_allrank ,
16361645 sum_seq_len = prefill_metadata .chunked_context .cu_seq_lens_lst [i ][- 1 ],
16371646 max_seq_len = prefill_metadata .chunked_context .max_seq_lens [i ],
1638- toks = toks ,
1647+ local_context_lens_sum = local_context_lens_sum ,
16391648 )
16401649
16411650 kv_nope = self .kv_b_proj (kv_c_normed )[0 ].view (
0 commit comments