Skip to content

Commit c519e25

Browse files
committed
[NIXL] Add remote_request_id to kv_transfer_params
Include in kv_transfer_params the internal request ID that the prefill instance expects to receive from the decode instance in the NIXL notification. Right now, we rely on the proxy supplying the ID via X-Request-ID and on prefill and decode mangling this ID in identical ways. This is obviously quite brittle; P should be explicit about what ID it expects from D. Relates to #27987 - adding a random prefix to client-provided request IDs. Signed-off-by: Mark McLoughlin <markmc@redhat.com>
1 parent ccbdf51 commit c519e25

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ class ReqMeta:
118118
remote_host: str
119119
remote_port: int
120120
remote_engine_id: str
121+
remote_request_id: str
121122
tp_size: int
122123

123124

@@ -144,6 +145,7 @@ def add_new_req(
144145
local_physical_block_ids=local_block_ids,
145146
remote_block_ids=kv_transfer_params["remote_block_ids"],
146147
remote_engine_id=kv_transfer_params["remote_engine_id"],
148+
remote_request_id=kv_transfer_params["remote_request_id"],
147149
remote_host=kv_transfer_params["remote_host"],
148150
remote_port=kv_transfer_params["remote_port"],
149151
# P workers don't need to receive tp_size from proxy here.
@@ -530,7 +532,12 @@ def update_state_after_alloc(
530532
if params.get("remote_block_ids"):
531533
if all(
532534
p in params
533-
for p in ("remote_engine_id", "remote_host", "remote_port")
535+
for p in (
536+
"remote_engine_id",
537+
"remote_request_id",
538+
"remote_host",
539+
"remote_port",
540+
)
534541
):
535542
# If remote_blocks and num_external_tokens = 0, we have
536543
# a full prefix cache hit on the D worker. We need to call
@@ -659,6 +666,7 @@ def request_finished(
659666
do_remote_decode=False,
660667
remote_block_ids=block_ids,
661668
remote_engine_id=self.engine_id,
669+
remote_request_id=request.request_id,
662670
remote_host=self.side_channel_host,
663671
remote_port=self.side_channel_port,
664672
tp_size=self.vllm_config.parallel_config.tensor_parallel_size,
@@ -1946,6 +1954,7 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
19461954
self._read_blocks(
19471955
request_id=req_id,
19481956
dst_engine_id=meta.remote_engine_id,
1957+
remote_request_id=meta.remote_request_id,
19491958
local_block_ids=meta.local_physical_block_ids,
19501959
remote_block_ids=meta.remote_block_ids,
19511960
)
@@ -1956,6 +1965,7 @@ def _read_blocks(
19561965
remote_block_ids: list[int],
19571966
dst_engine_id: str,
19581967
request_id: str,
1968+
remote_request_id: str,
19591969
):
19601970
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id)
19611971
if block_size_ratio > 1:
@@ -1988,7 +1998,7 @@ def _read_blocks(
19881998
# Number of D TP workers that will read from dst P. Propagate tp_ratio
19891999
# on notification so that dst worker can wait before freeing blocks.
19902000
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(dst_engine_id)
1991-
notif_id = f"{request_id}:{tp_ratio}".encode()
2001+
notif_id = f"{remote_request_id}:{tp_ratio}".encode()
19922002

19932003
# Full prefix cache hit: do not need to read remote blocks,
19942004
# just notify P worker that we have the blocks we need.

0 commit comments

Comments
 (0)