diff --git a/tests/detokenizer/test_min_tokens.py b/tests/detokenizer/test_min_tokens.py
index 1f8e944695bd..1e2551eefed2 100644
--- a/tests/detokenizer/test_min_tokens.py
+++ b/tests/detokenizer/test_min_tokens.py
@@ -35,6 +35,7 @@ def test_min_tokens_with_stop(min_tokens: int, stop: str, truth: str):
     )
     request = EngineCoreRequest(
         request_id="",
+        external_req_id="",
         prompt_token_ids=prompt_token_ids,
         mm_features=None,
         sampling_params=params,
diff --git a/tests/detokenizer/test_stop_string_while_stop_model_terminates.py b/tests/detokenizer/test_stop_string_while_stop_model_terminates.py
index 5624332ef71d..c2ed4db4dda5 100644
--- a/tests/detokenizer/test_stop_string_while_stop_model_terminates.py
+++ b/tests/detokenizer/test_stop_string_while_stop_model_terminates.py
@@ -31,6 +31,7 @@ def _make_request(stop, include_stop_str_in_output: bool, min_tokens: int = 0):
     # Keep other fields minimal for unit test purposes.
     req = EngineCoreRequest(
         request_id="test",
+        external_req_id="test-ext",
         prompt_token_ids=[],
         mm_features=None,
         sampling_params=params,
diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py
index 9ea65f9fa6e7..f0236208eeb8 100644
--- a/tests/entrypoints/openai/test_serving_chat.py
+++ b/tests/entrypoints/openai/test_serving_chat.py
@@ -390,7 +390,9 @@ async def _fake_process_inputs(
         trace_headers,
         priority,
     ):
-        return dict(engine_prompt), {}
+        mock_request = MagicMock()
+        mock_request.request_id = request_id
+        return mock_request, {}
 
     serving_chat._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
     return serving_chat
@@ -662,7 +664,11 @@ async def test_serving_chat_data_parallel_rank_extraction():
     mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
     mock_engine.errored = False
     mock_engine.model_config = MockModelConfig()
+
+    mock_request = MagicMock()
+    mock_request.request_id = "test-request-internal"
     mock_engine.input_processor = MagicMock()
+    mock_engine.input_processor.process_inputs.return_value = mock_request
     mock_engine.io_processor = MagicMock()
 
     # Mock the generate method to return an async generator
@@ -689,7 +695,9 @@ async def mock_generate(*args, **kwargs):
             finished=True,
         )
 
-    mock_engine.generate = AsyncMock(side_effect=mock_generate)
+    mock_engine.generate = MagicMock(
+        side_effect=lambda *args, **kwargs: mock_generate()
+    )
 
     serving_chat = _build_serving_chat(mock_engine)
diff --git a/tests/tokenizers_/test_detokenize.py b/tests/tokenizers_/test_detokenize.py
index ae1d6b095672..f2abae652115 100644
--- a/tests/tokenizers_/test_detokenize.py
+++ b/tests/tokenizers_/test_detokenize.py
@@ -62,6 +62,7 @@ def _run_incremental_decode(
     )
     request = EngineCoreRequest(
         request_id="",
+        external_req_id="",
         prompt_token_ids=prompt_token_ids,
         mm_features=None,
         sampling_params=params,
diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py
index 25af55baa91f..36e4e0bb984f 100644
--- a/tests/v1/engine/test_async_llm.py
+++ b/tests/v1/engine/test_async_llm.py
@@ -253,7 +253,7 @@ async def test_multi_abort(output_kind: RequestOutputKind):
     # Use multi-abort to abort multiple requests at once
     abort_request_ids = [request_ids[i] for i in REQUEST_IDS_TO_ABORT]
-    await engine.abort(abort_request_ids)
+    await engine.abort(abort_request_ids, internal=False)
 
     # Wait for all tasks to complete
     results = await asyncio.gather(*tasks, return_exceptions=True)
@@ -548,7 +548,7 @@ async def test_abort_final_output(output_kind: RequestOutputKind):
     await asyncio.sleep(0.5)
 
     # Abort the request
-    await engine.abort(request_id)
+    await engine.abort(request_id, internal=False)
 
     # Wait for generation to complete and return final output
     final_output = await generated
diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py
index 5fa16897b4e0..4f96ded7ec35 100644
--- a/tests/v1/engine/test_engine_core.py
+++ b/tests/v1/engine/test_engine_core.py
@@ -40,10 +40,16 @@
 PROMPT = "I am Gyoubu Masataka Oniwa"
 PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
 
+_REQUEST_COUNTER = 0
+
 
 def make_request() -> EngineCoreRequest:
+    global _REQUEST_COUNTER
+    _REQUEST_COUNTER += 1
+    request_id = f"request-{_REQUEST_COUNTER}"
     return EngineCoreRequest(
-        request_id=str(uuid.uuid4()),
+        request_id=request_id,
+        external_req_id=f"{request_id}-{uuid.uuid4()}",
         prompt_token_ids=PROMPT_TOKENS,
         mm_features=None,
         sampling_params=SamplingParams(),
diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py
index 770560a5e549..7638ece92336 100644
--- a/tests/v1/engine/test_engine_core_client.py
+++ b/tests/v1/engine/test_engine_core_client.py
@@ -39,6 +39,8 @@
 PROMPT = "Hello my name is Robert and I love quantization kernels"
 PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
 
+_REQUEST_COUNTER = 0
+
 
 def make_request(
     params: SamplingParams, prompt_tokens_ids: list[int] | None = None
@@ -46,8 +48,12 @@
     if not prompt_tokens_ids:
         prompt_tokens_ids = PROMPT_TOKENS
 
+    global _REQUEST_COUNTER
+    _REQUEST_COUNTER += 1
+    request_id = f"request-{_REQUEST_COUNTER}"
     return EngineCoreRequest(
-        request_id=str(uuid.uuid4()),
+        request_id=request_id,
+        external_req_id=f"{request_id}-{uuid.uuid4()}",
         prompt_token_ids=prompt_tokens_ids,
         mm_features=None,
         sampling_params=params,
diff --git a/tests/v1/engine/test_fast_incdec_prefix_err.py b/tests/v1/engine/test_fast_incdec_prefix_err.py
index 77e67d54e587..67a3b6b012dc 100644
--- a/tests/v1/engine/test_fast_incdec_prefix_err.py
+++ b/tests/v1/engine/test_fast_incdec_prefix_err.py
@@ -27,6 +27,7 @@ def test_fast_inc_detok_invalid_utf8_err_case():
     params = SamplingParams(skip_special_tokens=True)
     request = EngineCoreRequest(
         request_id="test",
+        external_req_id="test-ext",
         prompt_token_ids=prompt_token_ids,
         mm_features=None,
         sampling_params=params,
diff --git a/tests/v1/engine/test_output_processor.py b/tests/v1/engine/test_output_processor.py
index 990aa9d92585..f1185222f713 100644
--- a/tests/v1/engine/test_output_processor.py
+++ b/tests/v1/engine/test_output_processor.py
@@ -58,12 +58,12 @@ def test_incremental_detokenization(
     output_processor = OutputProcessor(
         dummy_test_vectors.tokenizer, log_stats=False, stream_interval=stream_interval
     )
-    engine_core = MockEngineCore(tokens_list=dummy_test_vectors.generation_tokens)
 
     # Make N requests.
     requests = [
         EngineCoreRequest(
-            request_id=f"request-{idx}",
+            request_id=f"request-{idx}-int",
+            external_req_id=f"request-{idx}",
             prompt_token_ids=prompt_tokens,
             mm_features=None,
             eos_token_id=None,
@@ -83,6 +83,11 @@
         for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
     ]
 
+    engine_core = MockEngineCore(
+        tokens_list=dummy_test_vectors.generation_tokens,
+        request_ids=[req.request_id for req in requests],
+    )
+
     # Add requests to the detokenizer.
     for request, prompt in zip(requests, dummy_test_vectors.prompt_strings):
         output_processor.add_request(request, prompt)
@@ -438,15 +443,6 @@ def test_logprobs_processor(
     dummy_test_vectors,
 ):
     output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False)
-    engine_core = MockEngineCore(
-        tokens_list=dummy_test_vectors.generation_tokens,
-        generated_logprobs_raw=None
-        if num_sample_logprobs is None
-        else dummy_test_vectors.generation_logprobs,
-        prompt_logprobs_raw=None
-        if num_prompt_logprobs is None
-        else dummy_test_vectors.prompt_logprobs,
-    )
 
     # Make N requests.
     request_id_list = [
@@ -454,7 +450,8 @@
     ]
     requests = [
         EngineCoreRequest(
-            request_id=request_id_list[idx],
+            request_id=request_id_list[idx] + "-int",
+            external_req_id=request_id_list[idx],
             prompt_token_ids=prompt_tokens,
             mm_features=None,
             eos_token_id=None,
@@ -476,6 +473,17 @@
         for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
     ]
 
+    engine_core = MockEngineCore(
+        tokens_list=dummy_test_vectors.generation_tokens,
+        generated_logprobs_raw=None
+        if num_sample_logprobs is None
+        else dummy_test_vectors.generation_logprobs,
+        prompt_logprobs_raw=None
+        if num_prompt_logprobs is None
+        else dummy_test_vectors.prompt_logprobs,
+        request_ids=[req.request_id for req in requests],
+    )
+
     # Add requests to the detokenizer.
     for request, prompt in zip(requests, dummy_test_vectors.prompt_strings):
         output_processor.add_request(request, prompt)
@@ -621,19 +629,12 @@ def test_stop_token(
     ]
     prompt_string = dummy_test_vectors.prompt_strings[0]
     prompt_tokens = dummy_test_vectors.prompt_tokens[0]
-    engine_core = MockEngineCore(
-        tokens_list=[generation_tokens],
-        generated_logprobs_raw=[generation_logprobs] if do_logprobs else None,
-        prompt_logprobs_raw=None,
-        eos_token_id=eos_token_id,
-        stop_token_ids=stop_token_ids,
-        ignore_eos=ignore_eos,
-    )
 
     # Make request.
     request_id = "request-0"
     request = EngineCoreRequest(
         request_id=request_id,
+        external_req_id=request_id + "-ext",
         prompt_token_ids=prompt_tokens,
         mm_features=None,
         eos_token_id=eos_token_id,
@@ -655,6 +656,16 @@
         pooling_params=None,
     )
 
+    engine_core = MockEngineCore(
+        tokens_list=[generation_tokens],
+        generated_logprobs_raw=[generation_logprobs] if do_logprobs else None,
+        prompt_logprobs_raw=None,
+        eos_token_id=eos_token_id,
+        stop_token_ids=stop_token_ids,
+        ignore_eos=ignore_eos,
+        request_ids=[request.request_id],
+    )
+
     # Add request to the detokenizer.
     output_processor.add_request(request, prompt_string)
@@ -720,13 +731,6 @@ def test_stop_string(
     dummy_test_vectors,
 ):
     output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False)
-    engine_core = MockEngineCore(
-        tokens_list=dummy_test_vectors.generation_tokens,
-        generated_logprobs_raw=dummy_test_vectors.generation_logprobs
-        if num_sample_logprobs
-        else None,
-        prompt_logprobs_raw=None,
-    )
 
     # Make N requests.
     request_id_list = [
@@ -734,7 +738,8 @@
     ]
     requests = [
         EngineCoreRequest(
-            request_id=request_id_list[idx],
+            request_id=request_id_list[idx] + "-int",
+            external_req_id=request_id_list[idx],
             prompt_token_ids=prompt_tokens,
             mm_features=None,
             eos_token_id=None,
@@ -756,6 +761,15 @@
         for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
     ]
 
+    engine_core = MockEngineCore(
+        tokens_list=dummy_test_vectors.generation_tokens,
+        generated_logprobs_raw=dummy_test_vectors.generation_logprobs
+        if num_sample_logprobs
+        else None,
+        prompt_logprobs_raw=None,
+        request_ids=[req.request_id for req in requests],
+    )
+
     # Add requests to the detokenizer.
     for request, prompt in zip(requests, dummy_test_vectors.prompt_strings):
         output_processor.add_request(request, prompt)
@@ -813,9 +827,12 @@
     for idx, (ref_gen_str, stop_str) in enumerate(
         zip(dummy_test_vectors.generation_strings, STOP_STRINGS)
     ):
-        # Request should be aborted.
+        # Request should be aborted (check internal ID in abort list).
+        internal_request_id = f"request-{idx}-int"
+        assert internal_request_id in aborted
+
+        # Use external ID for collecting outputs
         request_id = f"request-{idx}"
-        assert request_id in aborted
 
         # Collected values that were generated.
         gen_str = gen_strings[request_id]
@@ -848,13 +865,13 @@
 def test_iteration_stats(dummy_test_vectors):
     output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=True)
-    engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
     engine_core_timestamp = time.monotonic()
 
     # Make N requests.
     requests = [
         EngineCoreRequest(
             request_id=f"request-{idx}",
+            external_req_id=f"request-{idx}-ext",
             prompt_token_ids=prompt_tokens,
             mm_features=None,
             eos_token_id=None,
@@ -868,6 +885,11 @@
         for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
     ]
 
+    engine_core = MockEngineCore(
+        dummy_test_vectors.generation_tokens,
+        request_ids=[req.request_id for req in requests],
+    )
+
     # Add all requests except one to the OutputProcessor.
     num_active = len(dummy_test_vectors.generation_tokens) - 1
     for request in requests[:num_active]:
@@ -922,7 +944,6 @@
 def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
     output_processor = OutputProcessor(
         dummy_test_vectors.tokenizer, log_stats=log_stats
     )
-    engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
     engine_core_timestamp = time.monotonic()
 
     # Create LoRA requests
@@ -936,7 +957,8 @@
     lora_assignments = [lora1, lora2, None]
     requests = [
         EngineCoreRequest(
-            request_id=f"request-{idx}",
+            request_id=f"request-{idx}-int",
+            external_req_id=f"request-{idx}",
             prompt_token_ids=prompt_tokens,
             mm_features=None,
             eos_token_id=None,
@@ -950,6 +972,11 @@
         for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
     ]
 
+    engine_core = MockEngineCore(
+        dummy_test_vectors.generation_tokens,
+        request_ids=[req.request_id for req in requests],
+    )
+
     # Add all requests to the OutputProcessor
     for request in requests:
         output_processor.add_request(request, None)
@@ -1015,9 +1042,9 @@
     outputs = EngineCoreOutputs(
         outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
     )
-    # Find and mark request-0 as finished (it uses lora-1)
+    # Find and mark request-0-int as finished (it uses lora-1)
     for output in outputs.outputs:
-        if output.request_id == "request-0":
+        if output.request_id == "request-0-int":
             output.finish_reason = FinishReason.LENGTH
             break
@@ -1040,9 +1067,9 @@
     outputs = EngineCoreOutputs(
         outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
     )
-    # Find and mark request-1 as finished (it uses lora-2)
+    # Find and mark request-1-int as finished (it uses lora-2)
     for output in outputs.outputs:
-        if output.request_id == "request-1":
+        if output.request_id == "request-1-int":
             output.finish_reason = FinishReason.LENGTH
             break
@@ -1064,9 +1091,9 @@
     outputs = EngineCoreOutputs(
         outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
     )
-    # Find and mark request-2 as finished (it has no LoRA)
+    # Find and mark request-2-int as finished (it has no LoRA)
     for output in outputs.outputs:
-        if output.request_id == "request-2":
+        if output.request_id == "request-2-int":
             output.finish_reason = FinishReason.LENGTH
             break
@@ -1107,7 +1134,9 @@ def make_outputs() -> list[RequestOutput]:
         for idx in range(NUM_REQS)
     ]
 
-    collector = RequestOutputCollector(RequestOutputKind.DELTA)
+    collector = RequestOutputCollector(
+        RequestOutputKind.DELTA, request_id="my-request-id-int"
+    )
 
     # CASE 1: Put then get.
     outputs = make_outputs()
@@ -1163,7 +1192,9 @@
 @pytest.mark.asyncio
 async def test_cumulative_output_collector_n():
     """Test collector correctly handles multiple outputs by index."""
-    collector = RequestOutputCollector(RequestOutputKind.CUMULATIVE)
+    collector = RequestOutputCollector(
+        RequestOutputKind.CUMULATIVE, request_id="my-request-id-int"
+    )
     outputs = [
         RequestOutput(
             request_id="my-request-id",
@@ -1242,11 +1273,13 @@ async def test_cumulative_output_collector_n():
 
 
 @pytest.mark.parametrize("runner", ["generate", "pooling"])
-def test_abort_requests(runner: str, dummy_test_vectors):
+@pytest.mark.parametrize("abort_by", ["internal", "external"])
+def test_abort_requests(runner: str, abort_by: str, dummy_test_vectors):
     output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=True)
     requests = [
         EngineCoreRequest(
             request_id=f"request-{idx}",
+            external_req_id=f"external-{idx}",
             prompt_token_ids=prompt_tokens,
             mm_features=None,
             eos_token_id=None,
@@ -1265,8 +1298,13 @@
             output_kind = request.sampling_params.output_kind
         else:
             output_kind = request.pooling_params.output_kind
-        queue = RequestOutputCollector(output_kind=output_kind)
+        queue = RequestOutputCollector(
+            output_kind=output_kind, request_id=request.request_id
+        )
         output_processor.add_request(request, None, queue=queue)
 
     for request in requests:
-        output_processor.abort_requests([request.request_id])
+        if abort_by == "internal":
+            output_processor.abort_requests([request.request_id], internal=True)
+        else:
+            output_processor.abort_requests([request.external_req_id], internal=False)
diff --git a/tests/v1/engine/test_parallel_sampling.py b/tests/v1/engine/test_parallel_sampling.py
index 736c0e54837f..dcb94fdc2764 100644
--- a/tests/v1/engine/test_parallel_sampling.py
+++ b/tests/v1/engine/test_parallel_sampling.py
@@ -8,7 +8,7 @@
 def test_parent_request_to_output_stream() -> None:
-    parent_request = ParentRequest("parent_id", SamplingParams(n=2))
+    parent_request = ParentRequest("parent_id", "ext_parent_id", SamplingParams(n=2))
     parent_request.child_requests = {"child_id_0", "child_id_1"}
     output_0 = CompletionOutput(
         index=0, text="child 0", token_ids=[], cumulative_logprob=None, logprobs=None
@@ -17,51 +17,33 @@
         index=1, text="child 1", token_ids=[], cumulative_logprob=None, logprobs=None
     )
     # Request not finished
-    assert ("parent_id", [output_0], False) == parent_request.get_outputs(
-        "child_id_0", output_0
-    )
-    assert ("parent_id", [output_1], False) == parent_request.get_outputs(
-        "child_id_1", output_1
-    )
-    assert ("parent_id", [output_0], False) == parent_request.get_outputs(
-        "child_id_0", output_0
-    )
-    assert ("parent_id", [output_1], False) == parent_request.get_outputs(
-        "child_id_1", output_1
-    )
+    assert ([output_0], False) == parent_request.get_outputs("child_id_0", output_0)
+    assert ([output_1], False) == parent_request.get_outputs("child_id_1", output_1)
+    assert ([output_0], False) == parent_request.get_outputs("child_id_0", output_0)
+    assert ([output_1], False) == parent_request.get_outputs("child_id_1", output_1)
     # output_1 finished
     output_1.finish_reason = "ended"
-    assert ("parent_id", [output_0], False) == parent_request.get_outputs(
-        "child_id_0", output_0
-    )
-    assert ("parent_id", [output_1], False) == parent_request.get_outputs(
-        "child_id_1", output_1
-    )
+    assert ([output_0], False) == parent_request.get_outputs("child_id_0", output_0)
+    assert ([output_1], False) == parent_request.get_outputs("child_id_1", output_1)
     # Finished output_1 had already returned, DO NOT returned again
-    assert ("parent_id", [output_0], False) == parent_request.get_outputs(
-        "child_id_0", output_0
-    )
-    assert parent_request.get_outputs("child_id_1", output_1) == (
-        "parent_id",
-        [],
-        False,
-    )
+    assert ([output_0], False) == parent_request.get_outputs("child_id_0", output_0)
+    assert parent_request.get_outputs("child_id_1", output_1) == ([], False)
     # output_0 finished
     output_0.finish_reason = "ended"
-    assert ("parent_id", [output_0], True) == parent_request.get_outputs(
-        "child_id_0", output_0
-    )
-    assert parent_request.get_outputs("child_id_1", output_1) == ("parent_id", [], True)
+    assert ([output_0], True) == parent_request.get_outputs("child_id_0", output_0)
+    assert parent_request.get_outputs("child_id_1", output_1) == ([], True)
     # Finished output_0 had already returned, DO NOT returned again
-    assert parent_request.get_outputs("child_id_0", output_0) == ("parent_id", [], True)
-    assert parent_request.get_outputs("child_id_1", output_1) == ("parent_id", [], True)
+    assert parent_request.get_outputs("child_id_0", output_0) == ([], True)
+    assert parent_request.get_outputs("child_id_1", output_1) == ([], True)
 
 
 def test_parent_request_to_output_final_only() -> None:
     parent_request = ParentRequest(
-        "parent_id", SamplingParams(n=2, output_kind=RequestOutputKind.FINAL_ONLY)
+        "parent_id",
+        "ext_parent_id",
+        SamplingParams(n=2, output_kind=RequestOutputKind.FINAL_ONLY),
     )
     parent_request.child_requests = {"child_id_0", "child_id_1"}
     output_0 = CompletionOutput(
@@ -71,33 +53,17 @@
         index=1, text="child 1", token_ids=[], cumulative_logprob=None, logprobs=None
     )
     # Request not finished, return nothing
-    assert parent_request.get_outputs("child_id_0", output_0) == (
-        "parent_id",
-        [],
-        False,
-    )
-    assert parent_request.get_outputs("child_id_1", output_1) == (
-        "parent_id",
-        [],
-        False,
-    )
+    assert parent_request.get_outputs("child_id_0", output_0) == ([], False)
+    assert parent_request.get_outputs("child_id_1", output_1) == ([], False)
     # output_1 finished, but outputs won't be returned until all child requests finished
     output_1.finish_reason = "ended"
-    assert parent_request.get_outputs("child_id_0", output_0) == (
-        "parent_id",
-        [],
-        False,
-    )
-    assert parent_request.get_outputs("child_id_1", output_1) == (
-        "parent_id",
-        [],
-        False,
-    )
+    assert parent_request.get_outputs("child_id_0", output_0) == ([], False)
+    assert parent_request.get_outputs("child_id_1", output_1) == ([], False)
     # output_0 finished, as all child requests finished, the output would be returned
     output_0.finish_reason = "ended"
-    assert ("parent_id", [output_0, output_1], True) == parent_request.get_outputs(
+    assert ([output_0, output_1], True) == parent_request.get_outputs(
         "child_id_0", output_0
     )
-    assert ("parent_id", [output_0, output_1], True) == parent_request.get_outputs(
+    assert ([output_0, output_1], True) == parent_request.get_outputs(
         "child_id_1", output_1
     )
diff --git a/tests/v1/engine/test_process_multi_modal_uuids.py b/tests/v1/engine/test_process_multi_modal_uuids.py
index 1b11b8af49d1..196284e9bb0d 100644
--- a/tests/v1/engine/test_process_multi_modal_uuids.py
+++ b/tests/v1/engine/test_process_multi_modal_uuids.py
@@ -6,6 +6,7 @@
 from vllm.assets.image import ImageAsset
 from vllm.assets.video import VideoAsset
 from vllm.config import CacheConfig, DeviceConfig, ModelConfig, VllmConfig
+from vllm.multimodal import MultiModalUUIDDict
 from vllm.sampling_params import SamplingParams
 from vllm.v1.engine import input_processor as input_processor_mod
 from vllm.v1.engine.input_processor import InputProcessor
@@ -166,7 +167,7 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
         monkeypatch, mm_cache_gb=0.0, enable_prefix_caching=False
     )
 
-    captured: dict[str, object] = {}
+    captured: dict[str, MultiModalUUIDDict] = {}
 
     def fake_preprocess(
         prompt, *, tokenization_kwargs=None, lora_request=None, mm_uuids=None
@@ -196,7 +197,16 @@ def fake_preprocess(
     )
 
     # Expect request-id-based overrides are passed through
-    assert captured["mm_uuids"] == {
-        "image": [f"{request_id}-image-0", f"{request_id}-image-1"],
-        "video": [f"{request_id}-video-0"],
-    }
+    mm_uuids = captured["mm_uuids"]
+    assert set(mm_uuids.keys()) == {"image", "video"}
+    assert len(mm_uuids["image"]) == 2
+    assert len(mm_uuids["video"]) == 1
+    assert mm_uuids["image"][0].startswith(f"{request_id}-") and mm_uuids["image"][
+        0
+    ].endswith("-image-0")
+    assert mm_uuids["image"][1].startswith(f"{request_id}-") and mm_uuids["image"][
+        1
+    ].endswith("-image-1")
+    assert mm_uuids["video"][0].startswith(f"{request_id}-") and mm_uuids["video"][
+        0
+    ].endswith("-video-0")
diff --git a/tests/v1/engine/utils.py b/tests/v1/engine/utils.py
index 3541ef89bfc1..d14775668147 100644
--- a/tests/v1/engine/utils.py
+++ b/tests/v1/engine/utils.py
@@ -343,6 +343,7 @@ def __init__(
         eos_token_id: int | None = None,
         stop_token_ids: list[int] | None = None,
         ignore_eos: bool = False,
+        request_ids: list[str] | None = None,
     ) -> None:
         self.num_requests = len(tokens_list)
         self.tokens_list = tokens_list
@@ -355,6 +356,11 @@
         self.eos_token_id = eos_token_id
         self.stop_token_ids = stop_token_ids
         self.ignore_eos = ignore_eos
+        self.request_ids = (
+            request_ids
+            if request_ids is not None
+            else [f"request-{i}" for i in range(self.num_requests)]
+        )
 
     def get_outputs(self) -> list[EngineCoreOutput]:
         do_logprobs = self.do_logprobs
@@ -386,7 +392,7 @@
                 prompt_logprobs = None
             new_token_id = token_ids[token_idx]
             output = EngineCoreOutput(
-                request_id=f"request-{req_idx}",
+                request_id=self.request_ids[req_idx],
                 new_token_ids=[new_token_id],
                 new_logprobs=logprobs,
                 new_prompt_logprobs_tensors=prompt_logprobs,
diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index 53da09cfbc21..6d24cef7859b 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -41,10 +41,13 @@
     has_kv_transfer_group,
 )
 from vllm.forward_context import ForwardContext
+from vllm.outputs import RequestOutput
 from vllm.platforms import current_platform
 from vllm.platforms.interface import Platform
 from vllm.sampling_params import SamplingParams
 from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
+from vllm.v1.engine import EngineCoreRequest
+from vllm.v1.engine.output_processor import OutputProcessor
 from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
 from vllm.v1.request import RequestStatus
@@ -1069,6 +1072,22 @@ def run_test_and_cleanup():
     run_test_and_cleanup()
 
 
+class RequestIdMapper:
+    """Helper class to map external request IDs to internal request IDs."""
+
+    def __init__(self, output_processor: OutputProcessor):
+        self.req_id_mapping: dict[str, str] = {}
+        self.original_add_request = output_processor.add_request
+        output_processor.add_request = self._add_request
+
+    def _add_request(self, request: EngineCoreRequest, *args, **kwargs):
+        self.req_id_mapping[request.external_req_id] = request.request_id
+        return self.original_add_request(request, *args, **kwargs)
+
+    def __call__(self, external_req_id: str) -> str:
+        return self.req_id_mapping[external_req_id]
+
+
 def _run_abort_timeout_test(llm: LLM, timeout: int):
     """Helper function to run the abort timeout test logic."""
     remote_prefill_opts = {
@@ -1090,24 +1109,34 @@
         0
     ].req_to_blocks
 
+    id_mapper = RequestIdMapper(llm.llm_engine.output_processor)
+
+    def req_id(outputs: list[RequestOutput]) -> str:
+        assert len(outputs) == 1
+        return id_mapper(outputs[0].request_id)
+
     padding = (
         "Just making this request a little longer so that we're sure "
         "we're not hitting the small-request lower bound beneath which we don't "
        "actually trigger the whole kv transfer, but rather just recompute the "
         "blocks on D."
     )
-    _ = llm.generate([f"What is the capital of Japan? {padding}"], sampling_params)
+    req0_id = req_id(
+        llm.generate([f"What is the capital of Japan? {padding}"], sampling_params)
+    )
 
     # Request finished but not freed
-    assert "0" in scheduler.finished_req_ids and "0" in req_to_blocks
+    assert req0_id in scheduler.finished_req_ids and req0_id in req_to_blocks
 
     # Some other request, 0 still not freed
-    _ = llm.generate([f"What is the capital of Italy? {padding}"], sampling_params)
-    assert "0" in req_to_blocks
-    assert "1" in scheduler.finished_req_ids and "1" in req_to_blocks
+    req1_id = req_id(
+        llm.generate([f"What is the capital of Italy? {padding}"], sampling_params)
+    )
+    assert req0_id in req_to_blocks
+    assert req1_id in scheduler.finished_req_ids and req1_id in req_to_blocks
 
     # Wait for timeout and trigger another scheduler loop
     time.sleep(timeout)
 
     _ = llm.generate([f"What is the capital of France? {padding}"], sampling_params)
     # Request-0 times out and is cleared!
- assert "0" not in req_to_blocks + assert req0_id not in req_to_blocks # Need to shutdown the background thread to release NIXL side channel port llm.llm_engine.engine_core.shutdown() diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 5d5c4a1cdb77..c0c88f73fe92 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1636,7 +1636,7 @@ def _validate_and_add_requests( added_request_ids.append(request_id) except Exception as e: if added_request_ids: - self.llm_engine.abort_request(added_request_ids) + self.llm_engine.abort_request(added_request_ids, internal=True) raise e def _validate_mm_data_and_uuids( @@ -1738,7 +1738,7 @@ def _add_request( ) self.llm_engine.add_request( - request_id, + engine_request.request_id, engine_request, params, lora_request=lora_request, @@ -1746,7 +1746,7 @@ def _add_request( priority=priority, prompt_text=prompt_text, ) - return request_id + return engine_request.request_id def _run_engine( self, *, use_tqdm: bool | Callable[..., tqdm] = True diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index c6333d170c66..acec883509b9 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -341,7 +341,7 @@ async def create_chat_completion( generator = self.engine_client.generate( engine_request, sampling_params, - sub_request_id, + engine_request.request_id, lora_request=lora_request, trace_headers=trace_headers, priority=request.priority, diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 3e421e21e3e8..957c098b743d 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -231,7 +231,7 @@ async def create_completion( generator = self.engine_client.generate( engine_request, sampling_params, - request_id_item, + engine_request.request_id, lora_request=lora_request, trace_headers=trace_headers, priority=request.priority, diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 99936f588f28..8c38d8442f61 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1297,7 +1297,7 @@ async def _generate_with_builtin_tools( generator = self.engine_client.generate( engine_request, sampling_params, - sub_request_id, + engine_request.request_id, lora_request=lora_request, priority=priority, prompt_text=prompt_text, diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index ce2aae77108d..6277e992c923 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -49,6 +49,7 @@ class EngineCoreRequest( gc=False, ): # type: ignore[call-arg] request_id: str + external_req_id: str prompt_token_ids: list[int] | None mm_features: list[MultiModalFeatureSpec] | None sampling_params: SamplingParams | None diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 931d13be3d9b..7c25e1b276ac 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -289,12 +289,15 @@ async def add_request( is_pooling = isinstance(params, PoolingParams) - # Create a new output collector for the request. - queue = RequestOutputCollector(output_kind=params.output_kind) - # Convert Input --> Request. 
         if isinstance(prompt, EngineCoreRequest):
             request = prompt
+            if request_id != request.request_id:
+                logger.warning_once(
+                    "AsyncLLM.add_request() was passed a request_id parameter that "
+                    "does not match the EngineCoreRequest.request_id attribute. The "
+                    "latter will be used, and the former will be ignored."
+                )
         else:
             assert prompt_text is None
             request = self.input_processor.process_inputs(
@@ -313,6 +316,9 @@
             elif isinstance(prompt, Mapping):
                 prompt_text = cast(str | None, prompt.get("prompt"))
 
+        # Create a new output collector for the request.
+        queue = RequestOutputCollector(params.output_kind, request.request_id)
+
         # Use cloned params that may have been updated in process_inputs()
         params = request.params
 
@@ -324,7 +330,9 @@
             assert isinstance(parent_params, SamplingParams)
 
             # Fan out child requests (for n>1).
-            parent_request = ParentRequest(request_id, parent_params)
+            parent_request = ParentRequest(
+                request.request_id, request.external_req_id, parent_params
+            )
             for idx in range(parent_params.n):
                 request_id, child_params = parent_request.get_child_info(idx)
                 child_request = request if idx == parent_params.n - 1 else copy(request)
@@ -395,6 +403,7 @@
                 "prompt logprobs"
             )
 
+        q: RequestOutputCollector | None = None
         try:
             # We start the output_handler on the first call to generate() so
             # we can call __init__ before the event loop, which enables us
@@ -445,7 +454,8 @@
         # is cancelled or the generator is garbage collected. So,
         # we abort the request if we end up here.
         except (asyncio.CancelledError, GeneratorExit):
-            await self.abort(request_id)
+            if q:
+                await self.abort(q.request_id, internal=True)
             if self.log_requests:
                 logger.info("Request %s aborted.", request_id)
             raise
@@ -464,7 +474,8 @@
 
         # Unexpected error in the generate() task (possibly recoverable).
         except Exception as e:
-            await self.abort(request_id)
+            if q:
+                await self.abort(q.request_id, internal=True)
             if self.log_requests:
                 logger.info("Request %s failed.", request_id)
             raise EngineGenerateError() from e
@@ -540,13 +551,15 @@
 
             self.output_handler = asyncio.create_task(output_handler())
 
-    async def abort(self, request_id: str | Iterable[str]) -> None:
+    async def abort(
+        self, request_id: str | Iterable[str], internal: bool = False
+    ) -> None:
         """Abort RequestId in OutputProcessor and EngineCore."""
 
         request_ids = (
             (request_id,) if isinstance(request_id, str) else as_list(request_id)
         )
-        all_request_ids = self.output_processor.abort_requests(request_ids)
+        all_request_ids = self.output_processor.abort_requests(request_ids, internal)
         await self.engine_core.abort_requests_async(all_request_ids)
 
         if self.log_requests:
@@ -580,7 +593,7 @@
         if not wait_for_inflight_requests:
             request_ids = list(self.output_processor.request_states.keys())
             if request_ids:
-                await self.abort(request_ids)
+                await self.abort(request_ids, internal=True)
 
         # Wait for running requests to drain before clearing cache.
         if self.output_processor.has_unfinished_requests():
@@ -629,6 +642,7 @@
         returning the RequestOutput back to the caller.
         """
 
+        q: RequestOutputCollector | None = None
         try:
             # We start the output_handler on the first call to generate() so
             # we can call __init__ before the event loop, which enables us
@@ -673,7 +687,8 @@
         # If the request is disconnected by the client, generate()
         # is cancelled. So, we abort the request if we end up here.
         except asyncio.CancelledError:
-            await self.abort(request_id)
+            if q:
+                await self.abort(q.request_id, internal=True)
             if self.log_requests:
                 logger.info("Request %s aborted.", request_id)
             raise
@@ -692,7 +707,8 @@
 
         # Unexpected error in the generate() task (possibly recoverable).
         except Exception as e:
-            await self.abort(request_id)
+            if q:
+                await self.abort(q.request_id, internal=True)
             if self.log_requests:
                 logger.info("Request %s failed.", request_id)
             raise EngineGenerateError() from e
diff --git a/vllm/v1/engine/input_processor.py b/vllm/v1/engine/input_processor.py
index e6a94f4e3de5..1965b0f49363 100644
--- a/vllm/v1/engine/input_processor.py
+++ b/vllm/v1/engine/input_processor.py
@@ -20,7 +20,7 @@
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
 from vllm.tokenizers import MistralTokenizer, TokenizerLike
-from vllm.utils import length_from_prompt_token_ids_or_embeds
+from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.metrics.stats import MultiModalCacheStats
 from vllm.v1.structured_output.backend_guidance import validate_guidance_grammar
@@ -382,6 +382,12 @@ def _extract_mm_data(p: PromptType):
         mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)]
         return mm_uuids
 
+    def _generate_request_id(self, request_id: str):
+        """Construct an internal request ID by adding 8 random characters
+        to the supplied request ID in order to ensure uniqueness.
+        """
+        return f"{request_id}-{random_uuid():.8}"
+
     def process_inputs(
         self,
         request_id: str,
@@ -409,6 +415,9 @@
         if arrival_time is None:
             arrival_time = time.time()
 
+        external_req_id = request_id
+        request_id = self._generate_request_id(request_id)
+
         # Optionally generate multimodal hash overrides to avoid hashing
         # multimodal data items by their content as their identifiers.
 
@@ -509,6 +518,7 @@
         return EngineCoreRequest(
             request_id=request_id,
+            external_req_id=external_req_id,
             prompt_token_ids=prompt_token_ids,
             prompt_embeds=prompt_embeds,
             mm_features=mm_features,
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index 4c3129100547..50d0747529d2 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -213,10 +213,10 @@ def validate_outputs(cls, outputs, output_type):
 
     def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
         return self.engine_core.get_supported_tasks()
 
-    def abort_request(self, request_ids: list[str]) -> None:
+    def abort_request(self, request_ids: list[str], internal: bool = False) -> None:
         """Remove request_ids from EngineCore and Detokenizer."""
 
-        request_ids = self.output_processor.abort_requests(request_ids)
+        request_ids = self.output_processor.abort_requests(request_ids, internal)
         self.engine_core.abort_requests(request_ids)
 
     def add_request(
@@ -238,6 +238,12 @@
         # Process raw inputs into the request.
         if isinstance(prompt, EngineCoreRequest):
             request = prompt
+            if request_id != request.request_id:
+                logger.warning_once(
+                    "LLMEngine.add_request() was passed a request_id parameter that "
+                    "does not match the EngineCoreRequest.request_id attribute. The "
+                    "latter will be used, and the former will be ignored."
+                )
         else:
             assert prompt_text is None
             request = self.input_processor.process_inputs(
@@ -268,7 +274,7 @@
             return
 
         # Fan out child requests (for n>1).
-        parent_req = ParentRequest(request_id, params)
+        parent_req = ParentRequest(request.request_id, request.external_req_id, params)
         for idx in range(n):
             request_id, child_params = parent_req.get_child_info(idx)
             child_request = request if idx == n - 1 else copy(request)
diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py
index 9be3f4da7352..ebe2644129e8 100644
--- a/vllm/v1/engine/output_processor.py
+++ b/vllm/v1/engine/output_processor.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import asyncio
+from collections import defaultdict
 from collections.abc import Iterable
 from dataclasses import dataclass
 from typing import Any, cast
@@ -39,8 +40,9 @@ class RequestOutputCollector:
     producer gets ahead of the consumer.
     """
 
-    def __init__(self, output_kind: RequestOutputKind):
+    def __init__(self, output_kind: RequestOutputKind, request_id: str):
         self.aggregate = output_kind == RequestOutputKind.DELTA
+        self.request_id = request_id
         self.output: RequestOutput | PoolingRequestOutput | Exception | None = None
         self.ready = asyncio.Event()
@@ -91,6 +93,7 @@ class RequestState:
     def __init__(
         self,
         request_id: str,
+        external_req_id: str,
         parent_req: ParentRequest | None,
         request_index: int,
         lora_name: str | None,
@@ -110,6 +113,7 @@
         temperature: float | None = None,
     ):
         self.request_id = request_id
+        self.external_req_id = external_req_id
         self.parent_req = parent_req
         self.request_index = request_index
         self.lora_name = lora_name
@@ -176,6 +180,7 @@
 
         return cls(
             request_id=request.request_id,
+            external_req_id=request.external_req_id,
             parent_req=parent_req,
             request_index=request_index,
             lora_name=(
@@ -235,10 +240,13 @@
             ]
             self.sent_tokens_offset = len(self.detokenizer.output_token_ids)
 
-        request_id = self.request_id
+        external_req_id = self.external_req_id
+
         if pooling_output is not None:
             return self._new_request_output(
-                request_id, [self._new_pooling_output(pooling_output)], finished
+                external_req_id,
+                [self._new_pooling_output(pooling_output)],
+                finished,
             )
 
         output = self._new_completion_output(new_token_ids, finish_reason, stop_reason)
@@ -246,19 +254,18 @@
         if self.parent_req is None:
             outputs = [output]
         else:
-            request_id, outputs, finished = self.parent_req.get_outputs(
-                request_id, output
-            )
+            outputs, finished = self.parent_req.get_outputs(self.request_id, output)
             if not outputs:
                 return None
+            external_req_id = self.parent_req.external_req_id
 
         return self._new_request_output(
-            request_id, outputs, finished, kv_transfer_params
+            external_req_id, outputs, finished, kv_transfer_params
         )
 
     def _new_request_output(
         self,
-        request_id: str,
+        external_req_id: str,
         outputs: list[CompletionOutput] | list[PoolingOutput],
         finished: bool,
         kv_transfer_params: dict[str, Any] | None = None,
@@ -269,7 +276,7 @@
 
             # Prompt embeddings are currently not supported by pooling requests.
             assert self.prompt_token_ids is not None
             return PoolingRequestOutput(
-                request_id=request_id,
+                request_id=external_req_id,
                 outputs=first_output,
                 num_cached_tokens=self.num_cached_tokens,
                 prompt_token_ids=self.prompt_token_ids,
@@ -288,7 +295,7 @@
             prompt_token_ids = [0] * len(self.prompt_embeds)
 
         return RequestOutput(
-            request_id=request_id,
+            request_id=external_req_id,  # request_id is what was provided externally
             prompt=self.prompt,
             prompt_token_ids=prompt_token_ids,
             prompt_logprobs=prompt_logprobs,
@@ -351,6 +358,7 @@
         self.stream_interval = stream_interval
         self.request_states: dict[str, RequestState] = {}
         self.parent_requests: dict[str, ParentRequest] = {}
+        self.external_req_ids: defaultdict[str, list[str]] = defaultdict(list)
         self.lora_states = LoRARequestStates(log_stats)
         self.tracer: Tracer | None = None
         self._requests_drained = asyncio.Event()
@@ -374,12 +382,41 @@
             assert state.queue is not None
             state.queue.put(e)
 
-    def abort_requests(
-        self,
-        request_ids: Iterable[str],
-    ) -> list[str]:
-        request_ids_to_abort = []
+    def abort_requests(self, request_ids: Iterable[str], internal: bool) -> list[str]:
+        """Abort a list of requests.
+
+        The request_ids may be either external request IDs (those passed to
+        InputProcessor.process_inputs()) or internal request IDs (those randomly
+        generated when creating the EngineCoreRequest).
+
+        If an external request ID is provided, and that external request ID
+        was used for multiple requests, all requests associated with that external
+        request ID are aborted.
+
+        In the case of parallel sampling, a request ID may be used to identify
+        a parent request, in which case the associated child requests are also
+        aborted.
+        """
+
+        internal_req_ids = []
         for request_id in request_ids:
+            if internal:
+                # Internal ID - this may be a parent request
+                internal_req_ids.append(request_id)
+
+                # Remove internal ID from the external->internal mapping
+                if req_state := self.request_states.get(request_id):
+                    external_req_id = req_state.external_req_id
+                    internal_ids = self.external_req_ids[external_req_id]
+                    internal_ids.remove(request_id)
+                    if not internal_ids:
+                        del self.external_req_ids[external_req_id]
+            elif internal_ids := self.external_req_ids.pop(request_id, []):
+                # External ID - abort all requests in the external->internal mapping
+                internal_req_ids.extend(internal_ids)
+
+        request_ids_to_abort = []
+        for request_id in internal_req_ids:
             req_state = self.request_states.pop(request_id, None)
             if req_state is not None:
                 self.lora_states.request_finished(request_id, req_state.lora_name)
@@ -403,7 +440,7 @@
                 # Abort children prior to removing the parent.
                 if parent.child_requests:
                     child_reqs = list(parent.child_requests)
-                    child_reqs = self.abort_requests(child_reqs)
+                    child_reqs = self.abort_requests(child_reqs, internal=True)
                     request_ids_to_abort.extend(child_reqs)
                 self.parent_requests.pop(request_id, None)
         if not self.request_states:
@@ -438,6 +475,9 @@
         if parent_req:
             self.parent_requests[parent_req.request_id] = parent_req
 
+        # Track the external_req_id -> [internal_req_id, ...] mapping
+        self.external_req_ids[req_state.external_req_id].append(request_id)
+
     def process_outputs(
         self,
         engine_core_outputs: list[EngineCoreOutput],
@@ -521,6 +561,12 @@
             # Free completed requests.
             if finish_reason is not None:
                 self.request_states.pop(req_id)
+
+                internal_ids = self.external_req_ids[req_state.external_req_id]
+                internal_ids.remove(req_id)
+                if not internal_ids:
+                    del self.external_req_ids[req_state.external_req_id]
+
                 # Remove parent request if applicable.
                 parent_req = req_state.parent_req
                 if parent_req and not parent_req.child_requests:
@@ -596,7 +642,9 @@
         )
 
         # meta
-        span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID, req_state.request_id)
+        span.set_attribute(
+            SpanAttributes.GEN_AI_REQUEST_ID, req_state.external_req_id
+        )
         if req_state.top_p:
             span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P, req_state.top_p)
         if req_state.max_tokens_param:
diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py
index 59aacd196307..28f3109c8a51 100644
--- a/vllm/v1/engine/parallel_sampling.py
+++ b/vllm/v1/engine/parallel_sampling.py
@@ -17,6 +17,7 @@ class ParentRequest:
     """
 
     request_id: str
+    external_req_id: str
     sampling_params: SamplingParams
 
     # To track the completion of child requests
@@ -31,8 +32,11 @@
     # To efficiently obtain child sampling params
     cached_child_sampling_params: SamplingParams | None
 
-    def __init__(self, request_id: str, sampling_params: SamplingParams) -> None:
+    def __init__(
+        self, request_id: str, external_req_id: str, sampling_params: SamplingParams
+    ) -> None:
         self.request_id = request_id
+        self.external_req_id = external_req_id
         self.sampling_params = sampling_params
 
         self.child_requests = set()
@@ -96,7 +100,7 @@
         self,
         child_request_id: str,
         completion_output: CompletionOutput,
-    ) -> tuple[str, list[CompletionOutput], bool]:
+    ) -> tuple[list[CompletionOutput], bool]:
         already_finished_and_returned: bool = False
         if completion_output.finished():
             if child_request_id in self.child_requests:
@@ -118,7 +122,7 @@
         outputs = [] if self.child_requests else self.output_aggregator
         finished = not self.child_requests
 
-        return self.request_id, outputs, finished
+        return outputs, finished
 
     def observe_num_generation_tokens(self, num_generation_tokens: int):
         self.max_num_generation_tokens = max(