diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
index 086dc2bf4a5..d71b6e89f6a 100644
--- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
+++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
@@ -2328,6 +2328,11 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
     /// @return An optional Response
     std::optional<executor::Response> createResponse(bool useFastLogits = false, int32_t mpiWorldRank = 0);
 
+    std::optional<executor::Result> createResult(bool useFastLogits = false, int32_t mpiWorldRank = 0);
+
+    void createSerializedResult(
+        std::vector<char>& serializedResult, bool& isFinal, bool useFastLogits = false, int32_t mpiWorldRank = 0);
+
     void validate(SizeType32 maxInputLen, SizeType32 maxSequenceLen, SizeType32 maxDraftLen, SizeType32 vocabSizePadded,
         std::optional<SizeType32> maxEncoderInputLen = std::nullopt, bool enableKVCacheReuse = false);
 
diff --git a/cpp/tensorrt_llm/batch_manager/llmRequest.cpp b/cpp/tensorrt_llm/batch_manager/llmRequest.cpp
index 6fc7051ad7e..433f349b07d 100644
--- a/cpp/tensorrt_llm/batch_manager/llmRequest.cpp
+++ b/cpp/tensorrt_llm/batch_manager/llmRequest.cpp
@@ -16,6 +16,7 @@
  */
 
 #include "tensorrt_llm/batch_manager/llmRequest.h"
+#include "tensorrt_llm/executor/serializeUtils.h"
 #include "tensorrt_llm/kernels/beamSearchKernels.h"
 
 namespace tensorrt_llm::batch_manager
@@ -39,8 +40,34 @@ runtime::SizeType32 GenericLlmRequest<TTensor, TStream>::getBeamWidthByIter(bool
 
 template class GenericLlmRequest<runtime::ITensor::SharedPtr>;
 
-/// Note that there is some dependency on the order of operations in this method. Modify with care!
 std::optional<executor::Response> LlmRequest::createResponse(bool useFastLogits, int32_t mpiWorldRank)
+{
+    auto requestId = isChild() ? mParentRequestId : mRequestId;
+    auto result = createResult(useFastLogits, mpiWorldRank);
+    if (result.has_value())
+    {
+        return executor::Response(requestId, result.value(), mClientId);
+    }
+    return std::nullopt;
+}
+
+void LlmRequest::createSerializedResult(
+    std::vector<char>& serializedResult, bool& isFinal, bool useFastLogits, int32_t mpiWorldRank)
+{
+    auto result = createResult(useFastLogits, mpiWorldRank);
+    if (result.has_value())
+    {
+        std::ostringstream oStream;
+        executor::serialize_utils::serialize(result.value(), oStream);
+        auto str = oStream.str();
+        serializedResult.resize(str.size());
+        std::copy(str.begin(), str.end(), serializedResult.begin());
+        isFinal = result.value().isFinal;
+    }
+}
+
+/// Note that there is some dependency on the order of operations in this method. Modify with care!
+std::optional<executor::Result> LlmRequest::createResult(bool useFastLogits, int32_t mpiWorldRank)
 {
     TLLM_CHECK(!isDisaggContextCompleteState());
     if (!(isFinished() || (mIsStreaming && mState == LlmRequestState::kGENERATION_IN_PROGRESS)))
@@ -192,11 +219,7 @@ std::optional<executor::Response> LlmRequest::createResponse(bool useFastLogits,
 
     // Update position of last sent response
    setMaxSentTokenLen(maxNbTokens);
-
-    auto requestId = isChild() ? mParentRequestId : mRequestId;
-    auto response = executor::Response(requestId, std::move(result), mClientId);
-
-    return response;
+    return result;
 }
 
 void LlmRequest::validate(SizeType32 maxInputLen, SizeType32 maxSequenceLen, SizeType32 maxDraftLen,
diff --git a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
index 35f32a3b128..9f8f95e8ed7 100644
--- a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
+++ b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
@@ -36,6 +36,7 @@
 #include
 #include
 #include
+#include
 
 namespace py = pybind11;
 namespace tb = tensorrt_llm::batch_manager;
@@ -360,6 +361,16 @@ void initBindings(pybind11::module_& m)
            py::arg("enable_kv_cache_reuse") = false)
        .def("create_response", &tb::LlmRequest::createResponse, py::arg("use_fast_logits") = false,
            py::arg("mpi_world_rank") = 0)
+        .def("create_result", &tb::LlmRequest::createResult, py::arg("use_fast_logits") = false,
+            py::arg("mpi_world_rank") = 0)
+        .def("create_serialized_result",
+            [](tb::LlmRequest& self, bool use_fast_logits = false, int mpi_world_rank = 0)
+            {
+                std::vector<char> serialized_result;
+                bool is_final = false;
+                self.createSerializedResult(serialized_result, is_final, use_fast_logits, mpi_world_rank);
+                return std::make_tuple(serialized_result, is_final);
+            })
        .def("move_prompt_embedding_table_to_gpu", &tb::LlmRequest::movePromptEmbeddingTableToGpu, py::arg("manager"))
        .def("move_lora_weights_to_gpu", &tb::LlmRequest::moveLoraWeightsToGpu, py::arg("manager"))
        .def("finish_by_reason", &tb::LlmRequest::finishByReason, py::arg("finish_reason"));
diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py
index 680206c6191..3ea0eb7ec46 100644
--- a/tensorrt_llm/_torch/pyexecutor/llm_request.py
+++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -213,31 +213,25 @@ def __init__(self, result: tensorrt_llm.bindings.executor.Result,
     def __getattr__(self, item):
         if item in self.py_result_properties:
             return getattr(self._py_result, item)
-        return getattr(self._result, item)
+        result = object.__getattribute__(self, '_result')
+        return getattr(result, item)
 
 
 class LlmResponse:
     """LlmResponse wraps `bindings.executor.Response` but detour some features to Python implementation"""
 
-    def __init__(self, response: tensorrt_llm.bindings.executor.Response,
-                 py_result: PyResult):
-        self._response = response
-        self._py_result = py_result
-
-    def __getstate__(self):
-        return self._response, self._py_result
-
-    def __setstate__(self, state):
-        self._response, self._py_result = state
+    def __init__(self,
+                 request_id: int,
+                 error_msg: str = None,
+                 result: LlmResult = None,
+                 client_id: int = None):
+        self.request_id = request_id
+        self.error_msg = error_msg
+        self.result = result
+        self.client_id = client_id
 
-    @property
-    def result(self) -> tensorrt_llm.bindings.executor.Result:
-        return LlmResult(
-            self._response.result,
-            self._py_result)  # LlmResult masquerades bindings.executor.Result
-
-    def __getattr__(self, item):
-        return getattr(self._response, item)
+    def has_error(self):
+        return self.error_msg is not None
 
 
 class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
@@ -269,6 +263,7 @@ def __init__(
             **kwargs)
        self.py_client_id = client_id
        self.py_request_id = self.request_id
+        self.py_llm_request_type = self.llm_request_type
        self.py_end_id = self.end_id
        self.py_prompt_len = self.prompt_len
        self.py_orig_prompt_len = self.orig_prompt_len
@@ -282,6 +277,8 @@ def __init__(
        self.is_cuda_graph_dummy = False
        self.py_lora_task_layer_module_configs = None
 
+        self.py_tokens = super().get_tokens()
+
        self.py_return_log_probs = return_log_probs
        self.py_return_context_logits = return_context_logits
        self.py_return_generation_logits = return_generation_logits
@@ -297,13 +294,29 @@ def __init__(
            return_generation_logits,
            exclude_last_generation_logits)
 
+    def is_generation_only_request(self):
+        return self.py_llm_request_type == LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY
+
+    def get_tokens(self, beam: int) -> int:
+        return self.py_tokens[beam]
+
+    def get_last_tokens(self, beam: int) -> int:
+        return self.py_tokens[beam][-1]
+
+    def add_new_token(self, token: int, beam: int) -> int:
+        self.py_tokens[beam].append(token)
+        # sync to C++ side
+        return super().add_new_token(token, beam)
+
     def create_response(
            self,
            use_fast_logits=False,
            mpi_world_rank=0) -> tensorrt_llm.bindings.executor.Response | None:
-        response = super().create_response(use_fast_logits, mpi_world_rank)
-        return LlmResponse(response,
-                           self.py_result) if response is not None else None
+        result = super().create_result(use_fast_logits, mpi_world_rank)
+        return LlmResponse(
+            request_id=self.py_request_id,
+            result=LlmResult(result, self.py_result),
+            client_id=self.py_client_id) if result is not None else None
 
     @property
     def is_dummy(self):
diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index 692c4f40398..a58811da323 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -1184,7 +1184,7 @@ def _prepare_tp_inputs(
            gather_ids.append(len(input_ids) - 1)
            sequence_lengths.append(len(prompt_tokens))
            prompt_lengths.append(len(prompt_tokens))
-            past_seen_token_num = request.context_current_position
+            past_seen_token_num = begin_compute
            num_cached_tokens_per_seq.append(past_seen_token_num)
            multimodal_embedding = request.multimodal_embedding
            if multimodal_embedding is not None:
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index 54ccc556504..5b84cd7e373 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -30,8 +30,8 @@
 
 from ..distributed import Distributed
 from .kv_cache_transceiver import KvCacheTransceiver
-from .llm_request import (ExecutorRequest, ExecutorResponse, LlmRequest,
-                          LlmRequestState, executor_request_to_llm_request)
+from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
+                          LlmResponse, executor_request_to_llm_request)
 from .model_engine import ModelEngine
 from .sampler import Sampler, SampleState, SampleStateTensors, TorchSampler
 from .scheduler import ScheduledRequests
@@ -323,14 +323,14 @@ def await_responses(
        self,
        id: Optional[Union[List[int], int]] = None,
        timeout: Optional[datetime.timedelta] = None,
-    ) -> Union[List[List[ExecutorResponse]], List[ExecutorResponse]]:
+    ) -> Union[List[List[LlmResponse]], List[LlmResponse]]:
        """
        Await for ready responses
        Args:
            id (Optional[Union[List[int], int]]): Request id
            timeout (Optional[datetime.timedelta]): The maximum time to wait for new responses
        Returns:
-            Union[List[tensorrt_llm.bindings.executor.Response], List[List[tensorrt_llm.bindings.executor.Response]]]: Responses
+            Union[List[LlmResponse], List[List[LlmResponse]]]: Responses
        """
        timeout = timeout.total_seconds() if timeout is not None else None
        if id is None:
@@ -1934,8 +1934,10 @@ def _handle_errors(self, error_msg: Optional[str] = None):
            req_id = request.py_request_id
            request.state = LlmRequestState.GENERATION_COMPLETE
            self._terminate_request(request)
-            error_responses[req_id] = ExecutorResponse(
-                req_id, error_msg, client_id=request.py_client_id)
+            error_responses[req_id] = LlmResponse(
+                request_id=req_id,
+                error_msg=error_msg,
+                client_id=request.py_client_id)
        self.active_requests.clear()
        self._enqueue_responses(error_responses)
 
@@ -1979,7 +1981,7 @@ def _handle_cancelled_requests(self):
        self._enqueue_responses(cancelled_responses)
 
    @nvtx_range("_enqueue_responses")
-    def _enqueue_responses(self, responses: Dict[int, ExecutorResponse]):
+    def _enqueue_responses(self, responses: Dict[int, LlmResponse]):
        if 0 not in self.dist.mapping.tp_group and not self.gather_all_responses:
            return
 
@@ -2036,7 +2038,7 @@ def _handle_responses(self):
                requests_to_terminate.append(request)
                continue
 
-            if request.is_generation_only_request:
+            if request.is_generation_only_request():
                # If request is in transmission, so we don't need to emit a response
                # Also, for the first iteration with overlap, we should skip since first
                # token has already been emitted previously
@@ -2048,7 +2050,7 @@ def _handle_responses(self):
 
            request.draft_tokens = request.py_draft_tokens
            request.decoding_iter = request.py_decoding_iter
-            response: Response = request.create_response(False, self.dist.rank)
+            response = request.create_response(False, self.dist.rank)
            request_done = False
            if response:
                request_done = response.result.is_final
@@ -2075,7 +2077,7 @@ def _terminate_ctx_finished_requests(self):
 
    def _await_any_response(self,
                            timeout: Optional[float] = None
-                            ) -> List[ExecutorResponse]:
+                            ) -> List[LlmResponse]:
 
        def any_responses_ready():
            return len(self.responses) > 0 or self.is_shutdown
@@ -2092,7 +2094,7 @@ def any_responses_ready():
    def _await_single_response(
            self,
            id: int,
-            timeout: Optional[float] = None) -> List[ExecutorResponse]:
+            timeout: Optional[float] = None) -> List[LlmResponse]:
        with self.response_cv:
 
            def key_has_response():
diff --git a/tensorrt_llm/executor/utils.py b/tensorrt_llm/executor/utils.py
index bb6466373f1..fd4cd8444ec 100644
--- a/tensorrt_llm/executor/utils.py
+++ b/tensorrt_llm/executor/utils.py
@@ -8,7 +8,6 @@
 from strenum import StrEnum
 
 from tensorrt_llm._utils import mpi_rank
-from tensorrt_llm.bindings.executor import Response
 from tensorrt_llm.llmapi.utils import print_colored_debug
 
 from ..llmapi.mpi_session import (MpiCommSession, MpiPoolSession, MpiSession,
@@ -144,8 +143,4 @@ class WorkerCommIpcAddrs(NamedTuple):
 
 
 def is_llm_response(instance):
-    from tensorrt_llm._torch.pyexecutor.llm_request import \
-        LlmResponse as PyLlmResponse
-
-    from .result import ResponseWrapper
-    return isinstance(instance, (Response, PyLlmResponse, ResponseWrapper))
+    return hasattr(instance, "result")
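
Usage note (editorial, not part of the patch): a minimal Python sketch of how the surface added above might be exercised. It assumes `request` is an already-finished `LlmRequest` from `tensorrt_llm._torch.pyexecutor.llm_request`; only names introduced in this diff are used, and the exact element type returned for the serialized payload depends on the pybind conversion of `std::vector<char>`.

# Illustrative sketch only; `request` is assumed to be a finished LlmRequest.
# create_response() now builds a pure-Python LlmResponse on top of create_result().
response = request.create_response(False, 0)  # use_fast_logits=False, mpi_world_rank=0
if response is not None and not response.has_error():
    print(response.request_id, response.client_id, response.result.is_final)

# The new create_serialized_result binding returns the serialized executor Result
# payload together with the is_final flag as a tuple.
serialized_result, is_final = request.create_serialized_result()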