
Commit 064a690

[Core] Add a random suffix to frontend-provided request IDs
Since #9550 and #10968 we support clients supplying a custom request ID. The motivation is that it can be very helpful when you need to correlate vLLM logs with the logs of a related service.

Since the request ID is used ubiquitously across vLLM as a unique key, it is obviously problematic if we ever have multiple in-flight requests using the same client-provided request ID. We saw this happening recently when `vllm serve bench` started including a request ID and the request IDs from multiple concurrent instances caused collisions. See #27723.

We currently try to guard against request ID collisions in the frontend, in `OutputProcessor`:

```
def add_request(...):
    if request_id in self.request_states:
        raise ValueError(f"Request id {request_id} already running.")
```

However, this is not always effective:

1) We can have abort race conditions where a request is no longer tracked by the frontend but is still not completed in the engine. See #15326 for an attempt to fix this.
2) We can have async scheduling race conditions where a request ID has been removed from the output processor and a new request with that ID is being scheduled while the older request with that ID is still being completed by the model runner. See #29355.
3) With P/D, a request will continue to be tracked by the prefill engine long after the prefill request has been completed in the frontend, while we wait for the decode side to fetch the KV blocks. See #20139.

Let's instead ensure we use a unique request ID internally, even when a client provides a custom request ID. We can do this simply by appending a short random suffix to any request ID provided by the frontend.

We need to ensure we track the external->internal request ID mapping, because abort() will be supplied an external request ID. In the case where an external request ID maps to multiple running requests, we assume the caller requires all of those requests to be aborted. The caller can use `EngineCoreRequest.request_id` as the request ID if they want to be more specific.

A full 32-character random UUID would be overkill as a suffix, so how many random characters would be sufficient? 8 hex characters give us 32 bits of entropy, or 16^8 possible suffixes. Using the collision probability approximation from https://preshing.com/20110504/hash-collision-probabilities: with N = 16^8 and k the number of generated suffixes, the probability of collision is approximately (k^2)/(2N). So if a client somehow caused vLLM to hold 10k requests that reuse the same client-provided ID, there would be a 1.16% chance of collision:

```
>>> N = 16 ** 8
>>> k = 10_000
>>> (k ** 2) / (2 * N)
0.011641532182693481
```

That seems [super good enough](https://hownot2.com/products/hownot2-super-good-enough-t-shirt).

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
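Editor's note: a minimal sketch of the idea described in the commit message, assuming hypothetical names (`make_internal_request_id`, `RequestIdTracker`) that are not the actual vLLM API; the real implementation lives in `InputProcessor` and the diffs below.

```python
# Minimal sketch (hypothetical names, not the actual vLLM implementation):
# append an 8-hex-character random suffix to any client-provided request ID,
# and keep an external -> internal mapping so abort() can find every
# in-flight request that shares the same external ID.
import secrets
from collections import defaultdict


def make_internal_request_id(external_req_id: str) -> str:
    # secrets.token_hex(4) yields 8 hex characters, i.e. 32 bits of entropy.
    return f"{external_req_id}-{secrets.token_hex(4)}"


class RequestIdTracker:
    def __init__(self) -> None:
        # external request ID -> internal request IDs still in flight
        self._by_external: dict[str, set[str]] = defaultdict(set)

    def add(self, external_req_id: str) -> str:
        internal_id = make_internal_request_id(external_req_id)
        self._by_external[external_req_id].add(internal_id)
        return internal_id

    def abort(self, external_req_id: str) -> set[str]:
        # Abort all in-flight requests sharing this external ID.
        return self._by_external.pop(external_req_id, set())


if __name__ == "__main__":
    tracker = RequestIdTracker()
    a = tracker.add("client-req-1")
    b = tracker.add("client-req-1")  # same external ID, distinct internal IDs
    assert a != b
    assert tracker.abort("client-req-1") == {a, b}
```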
1 parent ec7035c commit 064a690

File tree

13 files changed: +111 -26 lines changed


tests/tokenizers_/test_detokenize.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -62,6 +62,7 @@ def _run_incremental_decode(
     )
     request = EngineCoreRequest(
         request_id="",
+        external_req_id="",
         prompt_token_ids=prompt_token_ids,
         mm_features=None,
         sampling_params=params,
```

tests/v1/engine/test_process_multi_modal_uuids.py

Lines changed: 15 additions & 5 deletions
```diff
@@ -6,6 +6,7 @@
 from vllm.assets.image import ImageAsset
 from vllm.assets.video import VideoAsset
 from vllm.config import CacheConfig, DeviceConfig, ModelConfig, VllmConfig
+from vllm.multimodal import MultiModalUUIDDict
 from vllm.sampling_params import SamplingParams
 from vllm.v1.engine import input_processor as input_processor_mod
 from vllm.v1.engine.input_processor import InputProcessor
@@ -166,7 +167,7 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
         monkeypatch, mm_cache_gb=0.0, enable_prefix_caching=False
     )

-    captured: dict[str, object] = {}
+    captured: dict[str, MultiModalUUIDDict] = {}

     def fake_preprocess(
         prompt, *, tokenization_kwargs=None, lora_request=None, mm_uuids=None
@@ -196,7 +197,16 @@ def fake_preprocess(
     )

     # Expect request-id-based overrides are passed through
-    assert captured["mm_uuids"] == {
-        "image": [f"{request_id}-image-0", f"{request_id}-image-1"],
-        "video": [f"{request_id}-video-0"],
-    }
+    mm_uuids = captured["mm_uuids"]
+    assert set(mm_uuids.keys()) == {"image", "video"}
+    assert len(mm_uuids["image"]) == 2
+    assert len(mm_uuids["video"]) == 1
+    assert mm_uuids["image"][0].startswith(f"{request_id}-") and mm_uuids["image"][
+        0
+    ].endswith("-image-0")
+    assert mm_uuids["image"][1].startswith(f"{request_id}-") and mm_uuids["image"][
+        1
+    ].endswith("-image-1")
+    assert mm_uuids["video"][0].startswith(f"{request_id}-") and mm_uuids["video"][
+        0
+    ].endswith("-video-0")
```
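For context on the looser assertions above: the request-ID-based multimodal UUID overrides now embed the randomized internal request ID, so only the client-provided prefix and the `-image-N` / `-video-N` tail are stable. A rough illustration, where the middle hex segment is an assumed example of the random suffix, not the exact format:

```python
# Hypothetical example of an override the updated test accepts.
request_id = "req-42"                       # client-provided request ID
mm_uuid = f"{request_id}-1a2b3c4d-image-0"  # assumed internal-ID-based form
assert mm_uuid.startswith(f"{request_id}-")
assert mm_uuid.endswith("-image-0")
```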

tests/v1/kv_connector/unit/test_nixl_connector.py

Lines changed: 15 additions & 6 deletions
```diff
@@ -36,6 +36,7 @@
     has_kv_transfer_group,
 )
 from vllm.forward_context import ForwardContext
+from vllm.outputs import RequestOutput
 from vllm.platforms.interface import Platform
 from vllm.sampling_params import SamplingParams
 from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
@@ -1077,24 +1078,32 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
         0
     ].req_to_blocks

+    def req_id(outputs: list[RequestOutput]):
+        assert len(outputs) == 1
+        return outputs[0].request_id
+
     padding = "Just making this request a little longer so that we're sure "
     "we're not hitting the small-request lower bound beneath which we don't "
     "actually trigger the whole kv transfer, but rather just recompute the "
     "blocks on D."
-    _ = llm.generate([f"What is the capital of Japan? {padding}"], sampling_params)
+    req0_id = req_id(
+        llm.generate([f"What is the capital of Japan? {padding}"], sampling_params)
+    )

     # Request finished but not freed
-    assert "0" in scheduler.finished_req_ids and "0" in req_to_blocks
+    assert req0_id in scheduler.finished_req_ids and req0_id in req_to_blocks
     # Some other request, 0 still not freed
-    _ = llm.generate([f"What is the capital of Italy? {padding}"], sampling_params)
-    assert "0" in req_to_blocks
-    assert "1" in scheduler.finished_req_ids and "1" in req_to_blocks
+    req1_id = req_id(
+        llm.generate([f"What is the capital of Italy? {padding}"], sampling_params)
+    )
+    assert req0_id in req_to_blocks
+    assert req1_id in scheduler.finished_req_ids and req1_id in req_to_blocks

     # Wait for timeout and trigger another scheduler loop
     time.sleep(timeout)
     _ = llm.generate([f"What is the capital of France? {padding}"], sampling_params)
     # Request-0 times out and is cleared!
-    assert "0" not in req_to_blocks
+    assert req0_id not in req_to_blocks
     # Need to shutdown the background thread to release NIXL side channel port
     llm.llm_engine.engine_core.shutdown()

```

vllm/entrypoints/llm.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1700,15 +1700,30 @@ def _add_request(
17001700
)
17011701

17021702
self.llm_engine.add_request(
1703-
request_id,
1703+
engine_request.request_id,
17041704
engine_request,
17051705
params,
17061706
lora_request=lora_request,
17071707
tokenization_kwargs=tokenization_kwargs,
17081708
priority=priority,
17091709
prompt_text=prompt_text,
17101710
)
1711-
return request_id
1711+
return engine_request.request_id
1712+
1713+
@staticmethod
1714+
def _sort_outputs(
1715+
outputs: list[RequestOutput | PoolingRequestOutput],
1716+
) -> list[RequestOutput | PoolingRequestOutput]:
1717+
# Sort the outputs by request ID.
1718+
# This is necessary because some requests may be finished earlier than
1719+
# its previous requests.
1720+
1721+
# Extract the original request ID prefix for sorting.
1722+
# See how InputProcessor._generate_request_id() adds a random suffix
1723+
def extract_request_id_prefix(request_id: str) -> int:
1724+
return int(request_id.rsplit("-", 1)[0])
1725+
1726+
return sorted(outputs, key=lambda x: extract_request_id_prefix(x.request_id))
17121727

17131728
def _run_engine(
17141729
self, *, use_tqdm: bool | Callable[..., tqdm] = True
@@ -1756,7 +1771,5 @@ def _run_engine(
17561771

17571772
if use_tqdm:
17581773
pbar.close()
1759-
# Sort the outputs by request ID.
1760-
# This is necessary because some requests may be finished earlier than
1761-
# its previous requests.
1762-
return sorted(outputs, key=lambda x: int(x.request_id))
1774+
1775+
return self._sort_outputs(outputs)
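A rough illustration of why `_sort_outputs()` now strips the suffix before sorting: these request IDs used to be plain integers, so `int(x.request_id)` sufficed; with a random suffix appended, only the numeric prefix is comparable. The `"<counter>-<suffix>"` ID format shown here is an assumption for illustration only:

```python
# Assumed internal IDs of the form "<counter>-<random-suffix>".
def extract_request_id_prefix(request_id: str) -> int:
    return int(request_id.rsplit("-", 1)[0])


ids = ["2-9f3c1ab0", "0-4e5d6c7b", "1-0a1b2c3d"]
assert sorted(ids, key=extract_request_id_prefix) == [
    "0-4e5d6c7b",
    "1-0a1b2c3d",
    "2-9f3c1ab0",
]
```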

vllm/entrypoints/openai/serving_chat.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -341,7 +341,7 @@ async def create_chat_completion(
         generator = self.engine_client.generate(
             engine_request,
             sampling_params,
-            sub_request_id,
+            engine_request.request_id,
             lora_request=lora_request,
             trace_headers=trace_headers,
             priority=request.priority,
```

vllm/entrypoints/openai/serving_completion.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -231,7 +231,7 @@ async def create_completion(
         generator = self.engine_client.generate(
             engine_request,
             sampling_params,
-            request_id_item,
+            engine_request.request_id,
             lora_request=lora_request,
             trace_headers=trace_headers,
             priority=request.priority,
```

vllm/entrypoints/openai/serving_engine.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1260,7 +1260,7 @@ async def _generate_with_builtin_tools(
         generator = self.engine_client.generate(
             engine_request,
             sampling_params,
-            sub_request_id,
+            engine_request.request_id,
             lora_request=lora_request,
             priority=priority,
             prompt_text=prompt_text,
```

vllm/entrypoints/pooling/embed/serving.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -536,8 +536,8 @@ async def _collect_batch(
             # Non-chunked result - extract prompt_idx from request_id
             parts = result.request_id.split("-")
             try:
-                # Last part should be prompt index
-                prompt_idx = int(parts[-1])
+                # Second-to-last part should be prompt index
+                prompt_idx = int(parts[-2])
             except (ValueError, IndexError):
                 prompt_idx = result_idx  # Fallback to result_idx

```
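The switch from `parts[-1]` to `parts[-2]` follows from the random suffix now being the last `-`-separated field, which pushes the prompt index to second-to-last. A hypothetical example, where the exact request ID layout is an assumption:

```python
# Hypothetical request ID of the form "<base>-<prompt_idx>-<random_suffix>".
request_id = "embd-abc123-7-1f2e3d4c"
parts = request_id.split("-")
assert int(parts[-2]) == 7  # prompt index is now second-to-last
```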

vllm/v1/engine/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -49,6 +49,7 @@ class EngineCoreRequest(
     gc=False,
 ):  # type: ignore[call-arg]
     request_id: str
+    external_req_id: str
     prompt_token_ids: list[int] | None
     mm_features: list[MultiModalFeatureSpec] | None
     sampling_params: SamplingParams | None
```

vllm/v1/engine/async_llm.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -304,6 +304,12 @@ async def add_request(
         # Convert Input --> Request.
         if isinstance(prompt, EngineCoreRequest):
             request = prompt
+            if request_id != request.request_id:
+                logger.warning_once(
+                    "AsyncLLM.add_request() was passed a request_id parameter that "
+                    "does not match the EngineCoreRequest.request_id attribute. The "
+                    "latter will be used, and the former will be ignored."
+                )
         else:
             assert prompt_text is None
             request = self.input_processor.process_inputs(
@@ -333,7 +339,7 @@ async def add_request(
         assert isinstance(parent_params, SamplingParams)

         # Fan out child requests (for n>1).
-        parent_request = ParentRequest(request_id, parent_params)
+        parent_request = ParentRequest(request.request_id, parent_params)
         for idx in range(parent_params.n):
             request_id, child_params = parent_request.get_child_info(idx)
             child_request = request if idx == parent_params.n - 1 else copy(request)
```
