From 8e70e8ca1ff1c63237d42368b613e0641b1dd900 Mon Sep 17 00:00:00 2001 From: Shang-Pin Sheng Date: Fri, 6 Jun 2025 09:53:47 -0700 Subject: [PATCH 1/9] fix log probs and add for mtp --- tensorrt_llm/_torch/speculative/mtp.py | 55 +++++++++++++++---- tensorrt_llm/serve/openai_protocol.py | 2 + tensorrt_llm/serve/postprocess_handlers.py | 9 ++- ...test_mtp_sample_and_accept_draft_tokens.py | 2 +- 4 files changed, 53 insertions(+), 15 deletions(-) diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py index 144633f4c40..25907436d8e 100644 --- a/tensorrt_llm/_torch/speculative/mtp.py +++ b/tensorrt_llm/_torch/speculative/mtp.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from typing import List, Optional +from tensorrt_llm.executor.result import Logprob import torch from torch import nn @@ -9,7 +10,7 @@ from ..attention_backend import AttentionMetadata from ..pyexecutor.llm_request import LlmRequest, LlmRequestState from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager -from ..pyexecutor.sampler import SampleState, SampleStateTensors, TorchSampler +from ..pyexecutor.sampler import SampleState, SampleStateTensors, TorchSampler, greedy_search_sampling_batch from ..pyexecutor.scheduler import ScheduledRequests from .interface import SpecConfig, SpecMetadata, SpeculativeDecodingMode @@ -18,6 +19,7 @@ class SampleStateTensorsMTP(SampleStateTensors): new_tokens_lens: torch.Tensor next_draft_tokens: torch.Tensor + next_tokens_log_probs: Optional[torch.Tensor] = None @dataclass(frozen=True, kw_only=True) @@ -240,20 +242,29 @@ def update_requests(self, state: SampleStateMTP) -> None: new_tokens_list = state.host.new_tokens.tolist() new_tokens_lens_list = state.host.new_tokens_lens.tolist() next_draft_tokens_list = state.host.next_draft_tokens.tolist() + next_tokens_log_probs_list = state.host.next_tokens_log_probs.tolist() + + def handle_logprobs(request, tokens, log_probs): + token_log_probs = [{ + token: Logprob(logprob=logprob, rank=1) + } for token, logprob in zip(tokens, log_probs)] + request.py_result.append_log_probs([token_log_probs]) idx = 0 beam_idx = 0 for request in state.scheduled_requests.context_requests: assert not request.py_return_context_logits, "return_context_logits not implemented for MTPSampler" assert not request.py_return_generation_logits, "return_generation_logits not implemented for MTPSampler" - assert not request.py_return_log_probs, "return_log_probs not implemented for MTPSampler" + # assert not request.py_return_log_probs, "return_log_probs not implemented for MTPSampler" if request.get_context_remaining_length() != 0: idx += 1 continue if request.state != LlmRequestState.GENERATION_COMPLETE: new_token = new_tokens_list[idx][0] + new_token_log_prob = next_tokens_log_probs_list[idx][0] num_tokens = request.add_new_token(new_token, beam_idx) + handle_logprobs(request, [new_token], [new_token_log_prob]) should_stop = self._handle_stop_criteria( request, new_token, num_tokens, beam_idx) if self._draft_meet_max_token_stop_criteria( @@ -267,14 +278,17 @@ def update_requests(self, state: SampleStateMTP) -> None: for request in state.scheduled_requests.generation_requests: assert not request.py_return_context_logits, "return_context_logits not implemented for MTPSampler" assert not request.py_return_generation_logits, "return_generation_logits not implemented for MTPSampler" - assert not request.py_return_log_probs, "return_log_probs not implemented for MTPSampler" + # assert not request.py_return_log_probs, 
"return_log_probs not implemented for MTPSampler" if request.state != LlmRequestState.GENERATION_COMPLETE: new_tokens = new_tokens_list[idx] + new_tokens_log_probs = next_tokens_log_probs_list[idx] num_new_tokens = new_tokens_lens_list[idx] should_stop = False for i in range(num_new_tokens): new_token = new_tokens[i] + new_token_log_prob = new_tokens_log_probs[i] num_tokens = request.add_new_token(new_token, beam_idx) + handle_logprobs(request, [new_token], [new_token_log_prob]) should_stop = self._handle_stop_criteria( request, new_token, num_tokens, beam_idx) if should_stop: @@ -298,17 +312,21 @@ def sample_async(self, scheduled_requests: ScheduledRequests, new_tokens_lens_device = model_outputs['new_tokens_lens'] next_draft_tokens_device = model_outputs['next_draft_tokens'] next_new_tokens_device = model_outputs['next_new_tokens'] + next_tokens_log_probs_device = model_outputs["log_probs"] device = SampleStateTensorsMTP( new_tokens=next_new_tokens_device, new_tokens_lens=new_tokens_lens_device, next_draft_tokens=next_draft_tokens_device, + next_tokens_log_probs=next_tokens_log_probs_device, ) host = SampleStateTensorsMTP( new_tokens=new_tokens_device.to('cpu', non_blocking=True), new_tokens_lens=new_tokens_lens_device.to('cpu', non_blocking=True), next_draft_tokens=next_draft_tokens_device.to('cpu', non_blocking=True), + next_tokens_log_probs=next_tokens_log_probs_device.to( + 'cpu', non_blocking=True), ) sampler_event = torch.cuda.Event() sampler_event.record() @@ -346,7 +364,7 @@ def forward( # Sample and verify draft tokens raw_logits = logits - accepted_tokens, num_accepted_tokens = self.sample_and_accept_draft_tokens( + accepted_tokens, num_accepted_tokens, log_probs = self.sample_and_accept_draft_tokens( input_ids, logits, spec_metadata, attn_metadata) # Update MTP past hidden states @@ -398,7 +416,8 @@ def forward( 'new_tokens': accepted_tokens, 'new_tokens_lens': num_accepted_tokens, 'next_draft_tokens': next_draft_tokens, - 'next_new_tokens': next_new_tokens + 'next_new_tokens': next_new_tokens, + "log_probs": log_probs } def skip_forward( @@ -418,6 +437,9 @@ def skip_forward( accepted_tokens = torch.empty((batch_size, (mtp_num_modules + 1)), dtype=torch.int, device=logits.device) + log_probs = torch.empty((batch_size, (mtp_num_modules + 1)), + dtype=torch.float32, + device=logits.device) num_accepted_tokens = torch.ones(batch_size, dtype=torch.int, device=logits.device) @@ -432,7 +454,8 @@ def skip_forward( 'new_tokens': accepted_tokens, 'new_tokens_lens': num_accepted_tokens, 'next_draft_tokens': next_draft_tokens, - 'next_new_tokens': next_new_tokens + 'next_new_tokens': next_new_tokens, + 'log_probs': log_probs } def update_mtp_hidden_states( @@ -728,6 +751,9 @@ def sample_and_accept_draft_tokens( accepted_tokens = torch.empty((batch_size, (mtp_num_modules + 1)), dtype=torch.int, device=logits.device) + log_probs = torch.empty((batch_size, (mtp_num_modules + 1)), + dtype=torch.float32, + device=logits.device) num_accepted_tokens = torch.ones(batch_size, dtype=torch.int, device=logits.device) @@ -784,7 +810,7 @@ def sample_and_accept_draft_tokens( # Strict acceptance else: - if self.is_thop: + if False: # Temporary buffer target_tokens_cache = torch.zeros(batch_size * (mtp_num_modules + 1), @@ -795,15 +821,21 @@ def sample_and_accept_draft_tokens( mtp_num_modules, batch_size, num_contexts, logits.shape[-1]) else: # Do greedy sampling for the input logits - target_tokens = torch.argmax(logits, dim=-1) + target_tokens, target_log_probs = greedy_search_sampling_batch(logits) # 
context accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts] + log_probs[:num_contexts, 0] = target_log_probs[:num_contexts] # generation gen_target_tokens = target_tokens[num_contexts:].reshape( num_gens, mtp_num_modules + 1) + gen_target_log_probs = target_log_probs[num_contexts:].reshape( + num_gens, mtp_num_modules + 1) + accepted_tokens[num_contexts:, :] = gen_target_tokens + log_probs[num_contexts:, :] = gen_target_log_probs + draft_tokens = spec_metadata.draft_tokens.reshape( num_gens, mtp_num_modules) num_accepted_tokens[num_contexts:] += torch.cumprod( @@ -811,7 +843,7 @@ def sample_and_accept_draft_tokens( ).int(), dim=-1).sum(1) - return accepted_tokens, num_accepted_tokens + return accepted_tokens, num_accepted_tokens, log_probs def change_attn_metadata(self, num_accepted_tokens: torch.Tensor, attn_metadata: AttentionMetadata): @@ -1067,7 +1099,7 @@ def forward( # Sample and verify draft tokens raw_logits = logits - accepted_tokens, num_accepted_tokens = self.sample_and_accept_draft_tokens( + accepted_tokens, num_accepted_tokens, log_probs = self.sample_and_accept_draft_tokens( input_ids, logits, spec_metadata, attn_metadata) # Save the old attn_metadata and spec_metadata @@ -1150,7 +1182,8 @@ def forward( 'new_tokens': accepted_tokens, 'new_tokens_lens': num_accepted_tokens, 'next_draft_tokens': next_draft_tokens, - 'next_new_tokens': next_new_tokens + 'next_new_tokens': next_new_tokens, + 'log_probs': log_probs } def prepare_drafter_inputs( diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py index 98c9d04dae6..41b56ac04a0 100644 --- a/tensorrt_llm/serve/openai_protocol.py +++ b/tensorrt_llm/serve/openai_protocol.py @@ -244,6 +244,7 @@ def to_sampling_params(self) -> SamplingParams: # TODO: migrate to use logprobs and prompt_logprobs _return_log_probs=self.logprobs, + logprobs=self.logprobs, ) return sampling_params @@ -553,6 +554,7 @@ def to_sampling_params(self) -> SamplingParams: # TODO: migrate to use logprobs and prompt_logprobs _return_log_probs=self.logprobs, + logprobs=self.logprobs, ) return sampling_params diff --git a/tensorrt_llm/serve/postprocess_handlers.py b/tensorrt_llm/serve/postprocess_handlers.py index 9846dfe4a75..6487ea25814 100644 --- a/tensorrt_llm/serve/postprocess_handlers.py +++ b/tensorrt_llm/serve/postprocess_handlers.py @@ -1,6 +1,8 @@ from dataclasses import dataclass, field from typing import List, Literal, Optional, Tuple, Union +from tensorrt_llm.executor.result import TokenLogprobs + from .._utils import nvtx_range_debug from ..executor import (DetokenizedGenerationResultBase, GenerationResult, GenerationResultBase) @@ -59,16 +61,17 @@ def from_request(cls, request: ChatCompletionRequest): def create_logprobs(token_ids: List[int], tokenizer: TransformersTokenizer, - logprobs: List[float]) -> ChatCompletionLogProbs: + logprobs: TokenLogprobs) -> ChatCompletionLogProbs: assert len(token_ids) == len(logprobs), \ "token_ids and logprobs have different lengths" content: List[ChatCompletionLogProbsContent] = [] - for token_id, logprob in zip(token_ids, logprobs): + for logprob in logprobs: + token_id, lp = list(logprob.items())[0] token = tokenizer.decode(token_id) # returning multiple logprobs is not supported first_logprob = ChatCompletionLogProbsContent( token=token, - logprob=max(logprob, -9999.0), + logprob=max(lp.logprob, -9999.0), bytes=list(token.encode("utf-8", errors="replace"))) content.append(first_logprob) chat_logprobs = ChatCompletionLogProbs(content=content) diff --git 
a/tests/unittest/_torch/speculative/test_mtp_sample_and_accept_draft_tokens.py b/tests/unittest/_torch/speculative/test_mtp_sample_and_accept_draft_tokens.py index 954d5476507..4aa5ec49cf6 100644 --- a/tests/unittest/_torch/speculative/test_mtp_sample_and_accept_draft_tokens.py +++ b/tests/unittest/_torch/speculative/test_mtp_sample_and_accept_draft_tokens.py @@ -312,7 +312,7 @@ def test_sample_and_accept_draft_tokens(self, mtp_num_modules, logits, for is_thop in [True, False]: mtpworker.is_thop = is_thop # TODO: add unit tests for relaxed acceptance - accepted_tokens, num_accepted_tokens = mtpworker.sample_and_accept_draft_tokens( + accepted_tokens, num_accepted_tokens, log_probs = mtpworker.sample_and_accept_draft_tokens( None, logits, spec_metadata, attn_metadata) torch.testing.assert_close(num_accepted_tokens, From 3f4f00ba954e628743cfe68a7bdf918a25267e4e Mon Sep 17 00:00:00 2001 From: Shang-Pin Sheng Date: Fri, 6 Jun 2025 16:54:53 -0700 Subject: [PATCH 2/9] mtp logprobs cuda --- .../kernels/speculativeDecoding/mtpKernels.cu | 24 +++++++++++++++---- .../kernels/speculativeDecoding/mtpKernels.h | 2 ++ cpp/tensorrt_llm/thop/mtpOp.cpp | 10 +++++--- tensorrt_llm/_torch/speculative/mtp.py | 8 +++++-- 4 files changed, 34 insertions(+), 10 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/speculativeDecoding/mtpKernels.cu b/cpp/tensorrt_llm/kernels/speculativeDecoding/mtpKernels.cu index 469e64a159c..d8802f8f5e2 100644 --- a/cpp/tensorrt_llm/kernels/speculativeDecoding/mtpKernels.cu +++ b/cpp/tensorrt_llm/kernels/speculativeDecoding/mtpKernels.cu @@ -203,7 +203,7 @@ template void invokeMTPPrepareDrafterInputs<__nv_bfloat16>(MTPPrepareDrafterInpu template __global__ void mtpGreedySampling(int const numMTPModules, int const batchSize, int const numContextRequest, - int const vocabSize, T const* logits, int* targetTokens) + int const vocabSize, T const* logits, int* targetTokens, float* targetTokenLogprobs) { /* In a batch of request: context request (at the beginning) + generation requests @@ -217,6 +217,7 @@ __global__ void mtpGreedySampling(int const numMTPModules, int const batchSize, __shared__ T maxValueCache[BLOCK_SIZE]; __shared__ int maxValueIndexCache[BLOCK_SIZE]; + __shared__ float sumExpCache[BLOCK_SIZE]; int const bid = static_cast(blockIdx.x); int const tid = static_cast(threadIdx.x); @@ -227,19 +228,23 @@ __global__ void mtpGreedySampling(int const numMTPModules, int const batchSize, T tmpMaxValue = curLogitsPtr[0]; int tmpMaxValueIndex = 0; + float tmpNorm = 0.0f; int ii = tid; while (ii < vocabSize) { if (curLogitsPtr[ii] >= tmpMaxValue) { + tmpNorm *= expf(tmpMaxValue - curLogitsPtr[ii]); // Find the first top-1 tmpMaxValueIndex = (curLogitsPtr[ii] == tmpMaxValue) ? min(tmpMaxValueIndex, ii) : ii; tmpMaxValue = curLogitsPtr[ii]; } + tmpNorm += expf(curLogitsPtr[ii] - tmpMaxValue); ii += blockDim.x; } maxValueCache[tid] = tmpMaxValue; maxValueIndexCache[tid] = tmpMaxValueIndex; + sumExpCache[tid] = tmpNorm; __syncthreads(); @@ -251,11 +256,17 @@ __global__ void mtpGreedySampling(int const numMTPModules, int const batchSize, { if (maxValueCache[tid] <= maxValueCache[tid + ii]) { + sumExpCache[tid] *= expf(maxValueCache[tid] - maxValueCache[tid + ii]); maxValueIndexCache[tid] = (maxValueCache[tid] == maxValueCache[tid + ii]) ? 
min(maxValueIndexCache[tid], maxValueIndexCache[tid + ii]) : maxValueIndexCache[tid + ii]; maxValueCache[tid] = maxValueCache[tid + ii]; } + else + { + sumExpCache[tid + ii] *= expf(maxValueCache[tid + ii] - maxValueCache[tid]); + } + sumExpCache[tid] += sumExpCache[tid + ii]; } __syncthreads(); ii /= 2; @@ -264,11 +275,12 @@ __global__ void mtpGreedySampling(int const numMTPModules, int const batchSize, if (tid == 0) { targetTokens[bid] = maxValueIndexCache[tid]; + targetTokenLogprobs[bid] = logf(1.0f / sumExpCache[tid]); } } __global__ void mtpAcceptDraftToken(int const numMTPModules, int const batchSize, int const numContextRequest, - int const* draftTokens, int* targetTokens, int* acceptedTokens, int* numAcceptedTokens) + int const* draftTokens, int* targetTokens, float* targetTokenLogprobs, int* acceptedTokens, int* numAcceptedTokens, float* logprobs) { /* In a batch of request: context request (at the beginning) + generation requests @@ -324,9 +336,11 @@ __global__ void mtpAcceptDraftToken(int const numMTPModules, int const batchSize // Write back to acceptedTokens auto curAcceptedTokensPtr = acceptedTokens + tid * (numMTPModules + 1); + auto curLogprobsPtr = logprobs + tid * (numMTPModules + 1); for (int jj = 0; jj < curAcceptedLen; jj++) { curAcceptedTokensPtr[jj] = targetTokens[targetTokensStartOffset + jj]; + curLogprobsPtr[jj] = targetTokenLogprobs[targetTokensStartOffset + jj]; } } } @@ -340,12 +354,12 @@ void invokeMTPSampleAndAcceptDraftTokens(MTPSampleAndAcceptDraftTokensParam& par int greedyBlockSize = min(BLOCK_SIZE, params.vocabSize); mtpGreedySampling<<>>(params.numMTPModules, params.batchSize, - params.numContextRequest, params.vocabSize, reinterpret_cast(params.logits), params.targetTokens); + params.numContextRequest, params.vocabSize, reinterpret_cast(params.logits), params.targetTokens, params.targetTokenLogprobs); sync_check_cuda_error(stream); mtpAcceptDraftToken<<>>(params.numMTPModules, - params.batchSize, params.numContextRequest, params.draftTokens, reinterpret_cast(params.targetTokens), - params.acceptedTokens, params.numAcceptedTokens); + params.batchSize, params.numContextRequest, params.draftTokens, reinterpret_cast(params.targetTokens), reinterpret_cast(params.targetTokenLogprobs), + params.acceptedTokens, params.numAcceptedTokens, params.logprobs); sync_check_cuda_error(stream); } diff --git a/cpp/tensorrt_llm/kernels/speculativeDecoding/mtpKernels.h b/cpp/tensorrt_llm/kernels/speculativeDecoding/mtpKernels.h index a930f12356b..279305e5f15 100644 --- a/cpp/tensorrt_llm/kernels/speculativeDecoding/mtpKernels.h +++ b/cpp/tensorrt_llm/kernels/speculativeDecoding/mtpKernels.h @@ -62,7 +62,9 @@ struct MTPSampleAndAcceptDraftTokensParam void* __restrict__ logits; int* draftTokens; int* targetTokens; + float* targetTokenLogprobs; int* acceptedTokens; + float* logprobs; int* numAcceptedTokens; }; diff --git a/cpp/tensorrt_llm/thop/mtpOp.cpp b/cpp/tensorrt_llm/thop/mtpOp.cpp index dd880f2b792..ec5e8b885a1 100644 --- a/cpp/tensorrt_llm/thop/mtpOp.cpp +++ b/cpp/tensorrt_llm/thop/mtpOp.cpp @@ -92,8 +92,8 @@ std::tuple mtp_prepare_drafter_inputs_op(th::Tensor& inp } //////////////////////////////////////////////////////////////////////////////////////////////////////////// -std::tuple mtp_sampling_and_accepted_draft_tokens_op(th::Tensor& logits, - th::Tensor& draftTokens, th::Tensor& targetTokens, int64_t numMTPModules, int64_t batchSize, +std::tuple mtp_sampling_and_accepted_draft_tokens_op(th::Tensor& logits, + th::Tensor& draftTokens, th::Tensor& targetTokens, 
th::Tensor& targetTokenLogprobs, int64_t numMTPModules, int64_t batchSize, int64_t numContextRequest, int64_t vocabSize) { int const numGenerationRequest = batchSize - numContextRequest; @@ -111,6 +111,8 @@ std::tuple mtp_sampling_and_accepted_draft_tokens_op(th: auto stream = at::cuda::getCurrentCUDAStream(logits.get_device()); auto acceptedTokens = torch::empty( {batchSize, numMTPModules + 1}, at::TensorOptions().dtype(torch::kInt32).device(logits.device())); + auto logprobs = torch::empty( + {batchSize, numMTPModules + 1}, at::TensorOptions().dtype(torch::kFloat32).device(logits.device())); auto numAcceptedTokens = torch::ones({batchSize}, at::TensorOptions().dtype(torch::kInt32).device(logits.device())); // Fill params @@ -122,6 +124,8 @@ std::tuple mtp_sampling_and_accepted_draft_tokens_op(th: params.draftTokens = reinterpret_cast(draftTokens.data_ptr()); params.targetTokens = reinterpret_cast(targetTokens.data_ptr()); params.acceptedTokens = reinterpret_cast(acceptedTokens.data_ptr()); + params.targetTokenLogprobs = reinterpret_cast(targetTokenLogprobs.data_ptr()); + params.logprobs = reinterpret_cast(logprobs.data_ptr()); params.numAcceptedTokens = reinterpret_cast(numAcceptedTokens.data_ptr()); params.logits = logits.data_ptr(); @@ -145,7 +149,7 @@ std::tuple mtp_sampling_and_accepted_draft_tokens_op(th: break; } - return std::make_tuple(acceptedTokens, numAcceptedTokens); + return std::make_tuple(acceptedTokens, numAcceptedTokens, logprobs); } //////////////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py index 25907436d8e..73ea7698009 100644 --- a/tensorrt_llm/_torch/speculative/mtp.py +++ b/tensorrt_llm/_torch/speculative/mtp.py @@ -810,14 +810,18 @@ def sample_and_accept_draft_tokens( # Strict acceptance else: - if False: + if self.is_thop: # Temporary buffer target_tokens_cache = torch.zeros(batch_size * (mtp_num_modules + 1), dtype=torch.int, device=logits.device) + target_token_logprobs_cache = torch.zeros(batch_size * + (mtp_num_modules + 1), + dtype=torch.float32, + device=logits.device) accepted_tokens, num_accepted_tokens = torch.ops.trtllm.mtp_sampling_and_accepted_draft_tokens_op( - logits, spec_metadata.draft_tokens, target_tokens_cache, + logits, spec_metadata.draft_tokens, target_tokens_cache, target_token_logprobs_cache, mtp_num_modules, batch_size, num_contexts, logits.shape[-1]) else: # Do greedy sampling for the input logits From df5230f0279e1ba7382ae41aabd5416b2cff4e2a Mon Sep 17 00:00:00 2001 From: Shang-Pin Sheng Date: Mon, 9 Jun 2025 16:24:36 -0700 Subject: [PATCH 3/9] fix completion --- cpp/tensorrt_llm/thop/mtpOp.cpp | 4 +-- .../_torch/custom_ops/cpp_custom_ops.py | 2 +- tensorrt_llm/_torch/speculative/mtp.py | 4 +-- tensorrt_llm/serve/openai_protocol.py | 7 ----- tensorrt_llm/serve/postprocess_handlers.py | 27 +++++++++++++++++-- 5 files changed, 30 insertions(+), 14 deletions(-) diff --git a/cpp/tensorrt_llm/thop/mtpOp.cpp b/cpp/tensorrt_llm/thop/mtpOp.cpp index ec5e8b885a1..fe170dabb23 100644 --- a/cpp/tensorrt_llm/thop/mtpOp.cpp +++ b/cpp/tensorrt_llm/thop/mtpOp.cpp @@ -295,8 +295,8 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m) { m.def( "mtp_sampling_and_accepted_draft_tokens_op(Tensor logits, Tensor draftTokens, Tensor " - "targetTokens, int numMTPModules, " - "int batchSize, int numContextRequest, int vocabSize) -> (Tensor, Tensor)"); + "targetTokens, Tensor targetTokenLogprobs, int numMTPModules, " + "int batchSize, int 
numContextRequest, int vocabSize) -> (Tensor, Tensor, Tensor)"); } TORCH_LIBRARY_IMPL(trtllm, CUDA, m) diff --git a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py index 39a3788b558..c922635716a 100644 --- a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py @@ -311,7 +311,7 @@ def _( @torch.library.register_fake( "trtllm::mtp_sampling_and_accepted_draft_tokens_op") def _(logits: torch.Tensor, draft_tokens: torch.Tensor, - target_tokens: torch.Tensor, num_mtp_modules: int, batch_size: int, + target_tokens: torch.Tensor, target_token_logprobs: torch.Tensor, num_mtp_modules: int, batch_size: int, num_context_request: int, vocab_size: int): return logits.new_empty((batch_size, num_mtp_modules + 1), dtype=torch.int32), logits.new_empty( diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py index 73ea7698009..76abe05a931 100644 --- a/tensorrt_llm/_torch/speculative/mtp.py +++ b/tensorrt_llm/_torch/speculative/mtp.py @@ -244,7 +244,7 @@ def update_requests(self, state: SampleStateMTP) -> None: next_draft_tokens_list = state.host.next_draft_tokens.tolist() next_tokens_log_probs_list = state.host.next_tokens_log_probs.tolist() - def handle_logprobs(request, tokens, log_probs): + def handle_logprobs(request: LlmRequest, tokens, log_probs): token_log_probs = [{ token: Logprob(logprob=logprob, rank=1) } for token, logprob in zip(tokens, log_probs)] @@ -820,7 +820,7 @@ def sample_and_accept_draft_tokens( (mtp_num_modules + 1), dtype=torch.float32, device=logits.device) - accepted_tokens, num_accepted_tokens = torch.ops.trtllm.mtp_sampling_and_accepted_draft_tokens_op( + accepted_tokens, num_accepted_tokens, log_probs = torch.ops.trtllm.mtp_sampling_and_accepted_draft_tokens_op( logits, spec_metadata.draft_tokens, target_tokens_cache, target_token_logprobs_cache, mtp_num_modules, batch_size, num_contexts, logits.shape[-1]) else: diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py index 41b56ac04a0..545333094c7 100644 --- a/tensorrt_llm/serve/openai_protocol.py +++ b/tensorrt_llm/serve/openai_protocol.py @@ -259,13 +259,6 @@ def check_beam_search(self): "Only support one response per prompt without beam search") return self - @model_validator(mode="before") - @classmethod - def check_logprobs(cls, data): - if data.get("logprobs"): - raise ValueError("logprobs is not supported") - return data - @model_validator(mode="before") @classmethod def validate_stream_options(cls, data): diff --git a/tensorrt_llm/serve/postprocess_handlers.py b/tensorrt_llm/serve/postprocess_handlers.py index 6487ea25814..a8d18412466 100644 --- a/tensorrt_llm/serve/postprocess_handlers.py +++ b/tensorrt_llm/serve/postprocess_handlers.py @@ -18,7 +18,7 @@ ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, - ChatCompletionToolsParam, ChatMessage, + ChatCompletionToolsParam, ChatMessage, CompletionLogProbs, CompletionRequest, CompletionResponse, CompletionResponseChoice, CompletionResponseStreamChoice, @@ -77,6 +77,22 @@ def create_logprobs(token_ids: List[int], chat_logprobs = ChatCompletionLogProbs(content=content) return chat_logprobs +def create_logprobs_completion(token_ids: List[int], + tokenizer: TransformersTokenizer, + logprobs: TokenLogprobs) -> CompletionLogProbs: + assert len(token_ids) == len(logprobs), \ + "token_ids and logprobs have different lengths" + token_logprobs: List[Optional[float]] = 
[] + tokens: List[str] = [] + for logprob in logprobs: + token_id, lp = list(logprob.items())[0] + token = tokenizer.decode(token_id) + # returning multiple logprobs is not supported + token_logprobs.append(max(lp.logprob, -9999.0)) + tokens.append(token) + completion_logprobs = CompletionLogProbs(token_logprobs=token_logprobs,tokens=tokens) + return completion_logprobs + def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str, streaming: bool) -> Tuple[bool, str, str]: reasoning_parser = None @@ -261,6 +277,7 @@ class CompletionPostprocArgs(PostprocArgs): prompt_idx: int = 0 prompt: Optional[str] = None stream_options: Optional[StreamOptions] = None + return_logprobs: bool = False @classmethod def from_request(cls, request: CompletionRequest): @@ -269,6 +286,7 @@ def from_request(cls, request: CompletionRequest): model=request.model, num_choices=request.n if request.n else 1, stream_options=request.stream_options, + return_logprobs=request.logprobs, ) @@ -295,6 +313,10 @@ def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args: finish_reason = output.finish_reason, stop_reason = output.stop_reason, ) + if args.return_logprobs: + logprobs, args.last_logprobs_len = output.logprobs_diff_safe(args.last_logprobs_len) + token_ids, args.last_token_ids_len = output.token_ids_diff_safe(args.last_token_ids_len) + choice.logprobs = create_logprobs_completion(token_ids, args.tokenizer, logprobs) chunk = CompletionStreamResponse(model=args.model, choices=[choice]) if include_continuous_usage: chunk.usage = UsageInfo(prompt_tokens=prompt_tokens, @@ -336,7 +358,8 @@ def completion_response_post_processor(rsp: GenerationResult, args: CompletionPo stop_reason=output.stop_reason, finish_reason=output.finish_reason, ) - + if args.return_logprobs: + choice.logprobs = create_logprobs_completion(output.token_ids, args.tokenizer, output.logprobs) completion_tokens += output.length choices.append(choice) From 940d632a388ee26b14c35829f9dc832b472b374c Mon Sep 17 00:00:00 2001 From: Shang-Pin Sheng Date: Fri, 13 Jun 2025 17:01:14 -0700 Subject: [PATCH 4/9] logprobs use diff --- tensorrt_llm/_torch/pyexecutor/llm_request.py | 7 ++++++- tensorrt_llm/executor/result.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index 799a8867e55..81bb13d2d68 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -1,3 +1,4 @@ +import copy from typing import List, Optional import torch @@ -265,8 +266,12 @@ def create_response( use_fast_logits=False, mpi_world_rank=0) -> tensorrt_llm.bindings.executor.Response | None: response = super().create_response(use_fast_logits, mpi_world_rank) + py_result = None + if response: + py_result = copy.copy(self.py_result) + self.py_result._log_probs = LogProbStorage() return LlmResponse(response, - self.py_result) if response is not None else None + py_result) if response is not None else None def convert_wordlist(word_list) -> List[List[int]]: diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py index c94b7af319b..a64fc7f5701 100644 --- a/tensorrt_llm/executor/result.py +++ b/tensorrt_llm/executor/result.py @@ -234,7 +234,7 @@ def _handle_sequence(self, if response_tensors.log_probs is not None: output._last_logprobs_len = len(output.logprobs) - output.logprobs = response_tensors.log_probs[src_idx] + 
output.logprobs.extend(response_tensors.log_probs[src_idx]) # overcome some WAR in the cpp executor if finish_reasons[src_idx] != tllm.FinishReason.CANCELLED: assert len(output.logprobs) == output.length From 24770e66f6b181b20f147734eddd6f4c8b5dd060 Mon Sep 17 00:00:00 2001 From: Yiqing Yan Date: Thu, 5 Jun 2025 14:03:39 +0800 Subject: [PATCH 5/9] Downgrade NCCL version from 2.26.5 to 2.25.1 (#4931) Signed-off-by: Yiqing Yan --- .devcontainer/docker-compose.yml | 2 +- docker/common/install_tensorrt.sh | 6 +++--- jenkins/L0_MergeRequest.groovy | 8 ++++---- jenkins/controlCCache.groovy | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.devcontainer/docker-compose.yml b/.devcontainer/docker-compose.yml index 6d5618a9b6a..523d914a216 100644 --- a/.devcontainer/docker-compose.yml +++ b/.devcontainer/docker-compose.yml @@ -1,7 +1,7 @@ version: "3.9" services: tensorrt_llm-dev: - image: urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505191345-4400 + image: urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505292346-4931 network_mode: host ipc: host diff --git a/docker/common/install_tensorrt.sh b/docker/common/install_tensorrt.sh index 8dda2552295..38c3a5bbaea 100644 --- a/docker/common/install_tensorrt.sh +++ b/docker/common/install_tensorrt.sh @@ -9,9 +9,9 @@ CUDA_VER="12.9" # 12.9.0 # Keep the installation for cuDNN if users want to install PyTorch with source codes. # PyTorch 2.x can compile with cuDNN v9. CUDNN_VER="9.9.0.52-1" -# NCCL version 2.26.3 used in the NGC PyTorch 25.04 image but not existing in public. -# Use NCCL version 2.26.5 instead. -NCCL_VER="2.26.5-1+cuda12.9" +# NCCL version 2.26.x used in the NGC PyTorch 25.04 image but has a performance regression issue. +# Use NCCL version 2.25.1 instead. +NCCL_VER="2.25.1-1+cuda12.8" # cuBLAS version 12.9.0.2 used in the NGC PyTorch 25.04 image but not existing in public. # Use cuBLAS version 12.9.0.13 instead. CUBLAS_VER="12.9.0.13-1" diff --git a/jenkins/L0_MergeRequest.groovy b/jenkins/L0_MergeRequest.groovy index b7f4974b1e3..d4c31d2dde7 100644 --- a/jenkins/L0_MergeRequest.groovy +++ b/jenkins/L0_MergeRequest.groovy @@ -21,10 +21,10 @@ UPLOAD_PATH = env.uploadPath ? 
env.uploadPath : "sw-tensorrt-generic/llm-artifac // Container configuration // available tags can be found in: https://urm.nvidia.com/artifactory/sw-tensorrt-docker/tensorrt-llm/ // [base_image_name]-[arch]-[os](-[python_version])-[trt_version]-[torch_install_type]-[stage]-[date]-[mr_id] -LLM_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505191345-4400" -LLM_SBSA_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-aarch64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505191345-4400" -LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py310-trt10.10.0.31-skip-tritondevel-202505191345-4400" -LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py312-trt10.10.0.31-skip-tritondevel-202505191345-4400" +LLM_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505292346-4931" +LLM_SBSA_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-aarch64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505292346-4931" +LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py310-trt10.10.0.31-skip-tritondevel-202505292346-4931" +LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py312-trt10.10.0.31-skip-tritondevel-202505292346-4931" // TODO: Move common variables to an unified location BUILD_CORES_REQUEST = "8" diff --git a/jenkins/controlCCache.groovy b/jenkins/controlCCache.groovy index 4f202fc1bf2..379f39dd6c0 100644 --- a/jenkins/controlCCache.groovy +++ b/jenkins/controlCCache.groovy @@ -1,7 +1,7 @@ import java.lang.InterruptedException -DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505191345-4400" +DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505292346-4931" def createKubernetesPodConfig(image) { From 20e7bcb0fd7979fa81ce4ba6aebdeadeb9d0f304 Mon Sep 17 00:00:00 2001 From: Patrick Horn Date: Tue, 10 Jun 2025 20:57:24 +0000 Subject: [PATCH 6/9] pin flashinfer due to compat break --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 7f4fa01fb52..a9250d61d29 100644 --- a/requirements.txt +++ b/requirements.txt @@ -47,7 +47,7 @@ setuptools<80 ordered-set peft einops -flashinfer-python~=0.2.3 +flashinfer-python==0.2.5 opencv-python-headless xgrammar==0.1.16 backoff From 90c74b319ec28c94239527e0e50bebbc031e5d10 Mon Sep 17 00:00:00 2001 From: Shang-Pin Sheng Date: Wed, 18 Jun 2025 10:24:13 -0700 Subject: [PATCH 7/9] fix crash --- tensorrt_llm/executor/result.py | 9 ++++++--- tensorrt_llm/serve/postprocess_handlers.py | 8 ++++---- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py index a64fc7f5701..edd035041e2 100644 --- a/tensorrt_llm/executor/result.py +++ b/tensorrt_llm/executor/result.py @@ -116,13 +116,16 @@ def length(self) -> int: return len(self.token_ids) def text_diff_safe(self, last_text_len) -> Tuple[str, int]: - return self.text[last_text_len:], len(self.text) + l = 
len(self.text) + return self.text[last_text_len:l], l def logprobs_diff_safe(self, last_logprobs_len) -> Tuple[List[float], int]: - return self.logprobs[last_logprobs_len:], len(self.logprobs) + l = len(self.logprobs) + return self.logprobs[last_logprobs_len:l], l def token_ids_diff_safe(self, last_token_ids_len) -> Tuple[List[int], int]: - return self.logprobs[last_token_ids_len:], len(self.logprobs) + l = len(self.token_ids) + return self.token_ids[last_token_ids_len:l], l #@property #def text_diff(self) -> str: diff --git a/tensorrt_llm/serve/postprocess_handlers.py b/tensorrt_llm/serve/postprocess_handlers.py index a8d18412466..2f788b8d3f4 100644 --- a/tensorrt_llm/serve/postprocess_handlers.py +++ b/tensorrt_llm/serve/postprocess_handlers.py @@ -62,8 +62,8 @@ def from_request(cls, request: ChatCompletionRequest): def create_logprobs(token_ids: List[int], tokenizer: TransformersTokenizer, logprobs: TokenLogprobs) -> ChatCompletionLogProbs: - assert len(token_ids) == len(logprobs), \ - "token_ids and logprobs have different lengths" + # assert len(token_ids) == len(logprobs), \ + # "token_ids and logprobs have different lengths" content: List[ChatCompletionLogProbsContent] = [] for logprob in logprobs: token_id, lp = list(logprob.items())[0] @@ -80,8 +80,8 @@ def create_logprobs(token_ids: List[int], def create_logprobs_completion(token_ids: List[int], tokenizer: TransformersTokenizer, logprobs: TokenLogprobs) -> CompletionLogProbs: - assert len(token_ids) == len(logprobs), \ - "token_ids and logprobs have different lengths" + # assert len(token_ids) == len(logprobs), \ + # "token_ids and logprobs have different lengths" token_logprobs: List[Optional[float]] = [] tokens: List[str] = [] for logprob in logprobs: From 07d3f0141fa0490457859be9f8ab39ddd246dfe0 Mon Sep 17 00:00:00 2001 From: Shang-Pin Sheng Date: Wed, 18 Jun 2025 13:57:11 -0700 Subject: [PATCH 8/9] fix crash without logprobs --- tensorrt_llm/_torch/pyexecutor/llm_request.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index 81bb13d2d68..a8b268a2307 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -269,7 +269,8 @@ def create_response( py_result = None if response: py_result = copy.copy(self.py_result) - self.py_result._log_probs = LogProbStorage() + if self.py_result._log_probs: + self.py_result._log_probs = LogProbStorage() return LlmResponse(response, py_result) if response is not None else None From 8c5e4c95a88dbb8ef92b2aced77e380f0fc72920 Mon Sep 17 00:00:00 2001 From: Shang-Pin Sheng Date: Tue, 24 Jun 2025 13:30:24 -0700 Subject: [PATCH 9/9] change serve.py model --- tensorrt_llm/commands/serve.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/tensorrt_llm/commands/serve.py b/tensorrt_llm/commands/serve.py index bf36d9a80e4..266e67b4f8f 100644 --- a/tensorrt_llm/commands/serve.py +++ b/tensorrt_llm/commands/serve.py @@ -100,12 +100,7 @@ def launch_server(host: str, port: int, llm_args: dict): @click.command("serve") -@click.argument("model_name", type=str, default=None) -@click.option("--model", - type=str, - default=None, - help="model name or path." - "Model name to use. 
Defaults to model_path.") +@click.argument("model", type=str) @click.option("--served-model-name", type=str, default=None, @@ -201,7 +196,7 @@ def launch_server(host: str, port: int, llm_args: dict): default=None, help="[Experimental] Specify the parser for reasoning models.", ) -def serve(model_name: Optional[str], model: Optional[str], +def serve(model: str, served_model_name: Optional[str], tokenizer: Optional[str], host: str, port: int, log_level: str, backend: str, max_beam_width: int, @@ -218,7 +213,6 @@ def serve(model_name: Optional[str], model: Optional[str], MODEL: model name | HF checkpoint path | TensorRT engine path """ logger.set_level(log_level) - model = model or model_name llm_args, _ = get_llm_args( model=model,
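
The glue between [PATCH 1/9] and [PATCH 3/9] is the per-token logprob payload: handle_logprobs() in MTPSampler.update_requests stores one {token_id: Logprob} dict per accepted token, and create_logprobs() / create_logprobs_completion() unpack exactly that top-1 entry. A minimal sketch of the payload shape (illustrative values; the import assumes the tensorrt_llm package from this repo is installed):

# Illustrative only: the shape produced by handle_logprobs() and consumed by
# the postprocess handlers; token id and logprob value below are made up.
from tensorrt_llm.executor.result import Logprob

token_id, logprob_value = 8421, -0.0312
# One dict per accepted token, with only the top-1 entry populated (rank=1).
# handle_logprobs wraps this list once more (per beam) before calling
# request.py_result.append_log_probs([...]).
step_entry = [{token_id: Logprob(logprob=logprob_value, rank=1)}]

# The postprocess handlers read back exactly one (token_id, Logprob) pair,
# mirroring list(logprob.items())[0] in create_logprobs().
tid, lp = list(step_entry[0].items())[0]
assert tid == token_id and lp.logprob == logprob_value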
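
The kernel change in [PATCH 2/9] computes the greedy token's log-probability with an online-softmax reduction: each thread keeps a running maximum and a running sum of exp(logit - max), partial results are rescaled to a common maximum when merged, and the value written to targetTokenLogprobs is -log(sum), i.e. the log-softmax of the argmax token. A minimal NumPy sketch of that reduction (not the CUDA kernel; the vocabulary size and partition count are arbitrary stand-ins):

# Keep (max, sum of exp(x - max)) per partition, rescale both partial sums to
# the shared maximum when merging, and the top-1 logprob is -log(sum).
import numpy as np

def partial_stats(chunk):
    m = chunk.max()
    return m, np.exp(chunk - m).sum()

def merge(a, b):
    (ma, sa), (mb, sb) = a, b
    m = max(ma, mb)
    return m, sa * np.exp(ma - m) + sb * np.exp(mb - m)

rng = np.random.default_rng(0)
logits = rng.standard_normal(32000)          # stand-in vocabulary size
chunks = np.array_split(logits, 8)           # stand-in for per-thread partitions

m, s = partial_stats(chunks[0])
for c in chunks[1:]:
    m, s = merge((m, s), partial_stats(c))

top_logprob = -np.log(s)                     # what logf(1.0f / sumExpCache[0]) returns
ref = -np.log(np.exp(logits - logits.max()).sum())   # log_softmax(logits)[logits.argmax()]
assert np.isclose(top_logprob, ref)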
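
End to end, the series removes the check_logprobs validator from openai_protocol.py and wires logprobs through the completion postprocessor, so a completions request with logprobs set should now return per-token values. A hypothetical client-side check, assuming a server already launched with trtllm-serve on the default port and a placeholder served model name:

# Hypothetical check, not part of the patches: host, port, and model name are
# placeholders for an already running trtllm-serve instance.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "deepseek-ai/DeepSeek-V3",   # placeholder served model name
        "prompt": "The capital of France is",
        "max_tokens": 8,
        "logprobs": 1,
    },
)
choice = resp.json()["choices"][0]
# tokens / token_logprobs are filled by create_logprobs_completion();
# each logprob value is clamped at -9999.0.
print(list(zip(choice["logprobs"]["tokens"],
               choice["logprobs"]["token_logprobs"])))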