Skip to content

Commit 9d04ff1

Browse files
committed
[NIXL] Add remote_request_id to kv_transfer_params
Include in kv_transfer_params the internal request ID that the prefill instance (P) expects the decode instance (D) to send to it in the NIXL notification. Right now, we rely on the proxy supplying the ID via X-Request-ID and on prefill and decode mangling this ID in identical ways. This is obviously quite brittle; P should be explicit about what ID it expects from D. Relates to #27987 (adding a random prefix to client-provided request IDs). Signed-off-by: Mark McLoughlin <markmc@redhat.com>
1 parent ccbdf51 commit 9d04ff1

File tree

3 files changed

+19
-2
lines changed

3 files changed

+19
-2
lines changed

tests/v1/kv_connector/unit/test_nixl_connector.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,7 @@ def test_multi_xfer_one_engine(
460460
num_xfers + 6,
461461
],
462462
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
463+
"remote_request_id": f"prefill-{request_id}",
463464
"remote_host": "localhost",
464465
"remote_port": 1234,
465466
"remote_tp_size": 1,
@@ -526,6 +527,7 @@ def test_async_load_kv(
526527
kv_transfer_params={
527528
"remote_block_ids": [4, 5, 6],
528529
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
530+
"remote_request_id": "prefill-id",
529531
"remote_host": "localhost",
530532
"remote_port": 1234,
531533
"remote_tp_size": prefill_tp_size,
@@ -581,6 +583,7 @@ def test_concurrent_load_kv(
581583
kv_transfer_params={
582584
"remote_block_ids": [4, 5, 6],
583585
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
586+
"remote_request_id": f"prefill-id-{i}",
584587
"remote_host": "localhost",
585588
"remote_port": 1234,
586589
"remote_tp_size": 1,
@@ -746,6 +749,7 @@ def test_kv_connector_stats(dist_init):
746749
kv_transfer_params={
747750
"remote_block_ids": [4, 5, 6],
748751
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
752+
"remote_request_id": f"prefill-{request_id}",
749753
"remote_host": "localhost",
750754
"remote_port": 1234,
751755
"remote_tp_size": 1,
@@ -1459,6 +1463,7 @@ def test_handshake_failure_returns_finished(dist_init):
14591463
kv_transfer_params={
14601464
"remote_block_ids": [4, 5, 6],
14611465
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
1466+
"remote_request_id": f"prefill-{request_id}",
14621467
"remote_host": "localhost",
14631468
"remote_port": 1234,
14641469
"remote_tp_size": 1,
@@ -1508,6 +1513,7 @@ def test_transfer_setup_failure_returns_finished(dist_init):
15081513
kv_transfer_params={
15091514
"remote_block_ids": [10, 11, 12],
15101515
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
1516+
"remote_request_id": f"prefill-{request_id}",
15111517
"remote_host": "localhost",
15121518
"remote_port": 1234,
15131519
"remote_tp_size": 1,

tests/v1/kv_connector/unit/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ def create_request(
187187
do_remote_prefill=True,
188188
do_remote_decode=False,
189189
remote_engine_id="my-engine-id",
190+
remote_request_id=f"prefill-{request_id}",
190191
remote_block_ids=list(range(num_remote_blocks)),
191192
remote_host="my-host",
192193
remote_port=1234,

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ class ReqMeta:
118118
remote_host: str
119119
remote_port: int
120120
remote_engine_id: str
121+
remote_request_id: str
121122
tp_size: int
122123

123124

@@ -144,6 +145,7 @@ def add_new_req(
144145
local_physical_block_ids=local_block_ids,
145146
remote_block_ids=kv_transfer_params["remote_block_ids"],
146147
remote_engine_id=kv_transfer_params["remote_engine_id"],
148+
remote_request_id=kv_transfer_params["remote_request_id"],
147149
remote_host=kv_transfer_params["remote_host"],
148150
remote_port=kv_transfer_params["remote_port"],
149151
# P workers don't need to receive tp_size from proxy here.
@@ -530,7 +532,12 @@ def update_state_after_alloc(
530532
if params.get("remote_block_ids"):
531533
if all(
532534
p in params
533-
for p in ("remote_engine_id", "remote_host", "remote_port")
535+
for p in (
536+
"remote_engine_id",
537+
"remote_request_id",
538+
"remote_host",
539+
"remote_port",
540+
)
534541
):
535542
# If remote_blocks and num_external_tokens = 0, we have
536543
# a full prefix cache hit on the D worker. We need to call
@@ -659,6 +666,7 @@ def request_finished(
659666
do_remote_decode=False,
660667
remote_block_ids=block_ids,
661668
remote_engine_id=self.engine_id,
669+
remote_request_id=request.request_id,
662670
remote_host=self.side_channel_host,
663671
remote_port=self.side_channel_port,
664672
tp_size=self.vllm_config.parallel_config.tensor_parallel_size,
@@ -1946,6 +1954,7 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
19461954
self._read_blocks(
19471955
request_id=req_id,
19481956
dst_engine_id=meta.remote_engine_id,
1957+
remote_request_id=meta.remote_request_id,
19491958
local_block_ids=meta.local_physical_block_ids,
19501959
remote_block_ids=meta.remote_block_ids,
19511960
)
@@ -1956,6 +1965,7 @@ def _read_blocks(
19561965
remote_block_ids: list[int],
19571966
dst_engine_id: str,
19581967
request_id: str,
1968+
remote_request_id: str,
19591969
):
19601970
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id)
19611971
if block_size_ratio > 1:
@@ -1988,7 +1998,7 @@ def _read_blocks(
19881998
# Number of D TP workers that will read from dst P. Propagate tp_ratio
19891999
# on notification so that dst worker can wait before freeing blocks.
19902000
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(dst_engine_id)
1991-
notif_id = f"{request_id}:{tp_ratio}".encode()
2001+
notif_id = f"{remote_request_id}:{tp_ratio}".encode()
19922002

19932003
# Full prefix cache hit: do not need to read remote blocks,
19942004
# just notify P worker that we have the blocks we need.

0 commit comments

Comments
 (0)