diff --git a/.devcontainer/docker-compose.yml b/.devcontainer/docker-compose.yml
index 6d5618a9b6a..523d914a216 100644
--- a/.devcontainer/docker-compose.yml
+++ b/.devcontainer/docker-compose.yml
@@ -1,7 +1,7 @@
 version: "3.9"
 services:
   tensorrt_llm-dev:
-    image: urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505191345-4400
+    image: urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505292346-4931
     network_mode: host
     ipc: host
diff --git a/cpp/tensorrt_llm/kernels/speculativeDecoding/mtpKernels.cu b/cpp/tensorrt_llm/kernels/speculativeDecoding/mtpKernels.cu
index 469e64a159c..d8802f8f5e2 100644
--- a/cpp/tensorrt_llm/kernels/speculativeDecoding/mtpKernels.cu
+++ b/cpp/tensorrt_llm/kernels/speculativeDecoding/mtpKernels.cu
@@ -203,7 +203,7 @@ template void invokeMTPPrepareDrafterInputs<__nv_bfloat16>(MTPPrepareDrafterInpu
 template <typename T>
 __global__ void mtpGreedySampling(int const numMTPModules, int const batchSize, int const numContextRequest,
-    int const vocabSize, T const* logits, int* targetTokens)
+    int const vocabSize, T const* logits, int* targetTokens, float* targetTokenLogprobs)
 {
     /* In a batch of request: context request (at the beginning) + generation requests
@@ -217,6 +217,7 @@ __global__ void mtpGreedySampling(int const numMTPModules, int const batchSize,
 
     __shared__ T maxValueCache[BLOCK_SIZE];
     __shared__ int maxValueIndexCache[BLOCK_SIZE];
+    __shared__ float sumExpCache[BLOCK_SIZE];
 
     int const bid = static_cast<int>(blockIdx.x);
     int const tid = static_cast<int>(threadIdx.x);
@@ -227,19 +228,23 @@ __global__ void mtpGreedySampling(int const numMTPModules, int const batchSize,
 
     T tmpMaxValue = curLogitsPtr[0];
     int tmpMaxValueIndex = 0;
+    float tmpNorm = 0.0f;
     int ii = tid;
     while (ii < vocabSize)
     {
         if (curLogitsPtr[ii] >= tmpMaxValue)
        {
+            tmpNorm *= expf(tmpMaxValue - curLogitsPtr[ii]);
             // Find the first top-1
             tmpMaxValueIndex = (curLogitsPtr[ii] == tmpMaxValue) ? min(tmpMaxValueIndex, ii) : ii;
             tmpMaxValue = curLogitsPtr[ii];
         }
+        tmpNorm += expf(curLogitsPtr[ii] - tmpMaxValue);
         ii += blockDim.x;
     }
     maxValueCache[tid] = tmpMaxValue;
     maxValueIndexCache[tid] = tmpMaxValueIndex;
+    sumExpCache[tid] = tmpNorm;
 
     __syncthreads();
@@ -251,11 +256,17 @@ __global__ void mtpGreedySampling(int const numMTPModules, int const batchSize,
         {
             if (maxValueCache[tid] <= maxValueCache[tid + ii])
             {
+                sumExpCache[tid] *= expf(maxValueCache[tid] - maxValueCache[tid + ii]);
                 maxValueIndexCache[tid] = (maxValueCache[tid] == maxValueCache[tid + ii]) ?
                     min(maxValueIndexCache[tid], maxValueIndexCache[tid + ii]) : maxValueIndexCache[tid + ii];
                 maxValueCache[tid] = maxValueCache[tid + ii];
             }
+            else
+            {
+                sumExpCache[tid + ii] *= expf(maxValueCache[tid + ii] - maxValueCache[tid]);
+            }
+            sumExpCache[tid] += sumExpCache[tid + ii];
         }
         __syncthreads();
         ii /= 2;
@@ -264,11 +275,12 @@ __global__ void mtpGreedySampling(int const numMTPModules, int const batchSize,
     if (tid == 0)
     {
         targetTokens[bid] = maxValueIndexCache[tid];
+        targetTokenLogprobs[bid] = logf(1.0f / sumExpCache[tid]);
     }
 }
 
 __global__ void mtpAcceptDraftToken(int const numMTPModules, int const batchSize, int const numContextRequest,
-    int const* draftTokens, int* targetTokens, int* acceptedTokens, int* numAcceptedTokens)
+    int const* draftTokens, int* targetTokens, float* targetTokenLogprobs, int* acceptedTokens, int* numAcceptedTokens, float* logprobs)
 {
     /* In a batch of request: context request (at the beginning) + generation requests
@@ -324,9 +336,11 @@ __global__ void mtpAcceptDraftToken(int const numMTPModules, int const batchSize
         // Write back to acceptedTokens
         auto curAcceptedTokensPtr = acceptedTokens + tid * (numMTPModules + 1);
+        auto curLogprobsPtr = logprobs + tid * (numMTPModules + 1);
         for (int jj = 0; jj < curAcceptedLen; jj++)
         {
             curAcceptedTokensPtr[jj] = targetTokens[targetTokensStartOffset + jj];
+            curLogprobsPtr[jj] = targetTokenLogprobs[targetTokensStartOffset + jj];
         }
     }
 }
@@ -340,12 +354,12 @@ void invokeMTPSampleAndAcceptDraftTokens(MTPSampleAndAcceptDraftTokensParam& par
 
     int greedyBlockSize = min(BLOCK_SIZE, params.vocabSize);
     mtpGreedySampling<<>>(params.numMTPModules, params.batchSize,
-        params.numContextRequest, params.vocabSize, reinterpret_cast<T*>(params.logits), params.targetTokens);
+        params.numContextRequest, params.vocabSize, reinterpret_cast<T*>(params.logits), params.targetTokens, params.targetTokenLogprobs);
     sync_check_cuda_error(stream);
 
     mtpAcceptDraftToken<<>>(params.numMTPModules,
-        params.batchSize, params.numContextRequest, params.draftTokens, reinterpret_cast<int*>(params.targetTokens),
-        params.acceptedTokens, params.numAcceptedTokens);
+        params.batchSize, params.numContextRequest, params.draftTokens, reinterpret_cast<int*>(params.targetTokens), reinterpret_cast<float*>(params.targetTokenLogprobs),
+        params.acceptedTokens, params.numAcceptedTokens, params.logprobs);
     sync_check_cuda_error(stream);
 }
diff --git a/cpp/tensorrt_llm/kernels/speculativeDecoding/mtpKernels.h b/cpp/tensorrt_llm/kernels/speculativeDecoding/mtpKernels.h
index a930f12356b..279305e5f15 100644
--- a/cpp/tensorrt_llm/kernels/speculativeDecoding/mtpKernels.h
+++ b/cpp/tensorrt_llm/kernels/speculativeDecoding/mtpKernels.h
@@ -62,7 +62,9 @@ struct MTPSampleAndAcceptDraftTokensParam
     void* __restrict__ logits;
     int* draftTokens;
     int* targetTokens;
+    float* targetTokenLogprobs;
     int* acceptedTokens;
+    float* logprobs;
     int* numAcceptedTokens;
 };
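For reference, the updated `mtpGreedySampling` kernel keeps a running maximum and a running sum of exponentials (rescaling the partial sum by `expf(oldMax - newMax)` whenever the maximum changes), so `logf(1.0f / sumExpCache[tid])` is the log-softmax value of the greedy token. A minimal PyTorch sketch of the same quantity, illustrative only and not part of the patch:

```python
import torch

def greedy_token_and_logprob(logits_row: torch.Tensor):
    """Greedy token id plus its log-probability: the two values the kernel now
    writes per logits row (targetTokens[bid] and targetTokenLogprobs[bid])."""
    max_val, max_idx = logits_row.max(dim=-1)
    # log p(argmax) = max - logsumexp(logits) = -log(sum(exp(logits - max)))
    logprob = -(logits_row - max_val).exp().sum().log()
    return int(max_idx), float(logprob)

logits = torch.randn(32000, dtype=torch.float64)
tok, lp = greedy_token_and_logprob(logits)
assert tok == int(logits.argmax())
assert abs(lp - float(torch.log_softmax(logits, dim=-1)[tok])) < 1e-6
```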
diff --git a/cpp/tensorrt_llm/thop/mtpOp.cpp b/cpp/tensorrt_llm/thop/mtpOp.cpp
index dd880f2b792..fe170dabb23 100644
--- a/cpp/tensorrt_llm/thop/mtpOp.cpp
+++ b/cpp/tensorrt_llm/thop/mtpOp.cpp
@@ -92,8 +92,8 @@ std::tuple mtp_prepare_drafter_inputs_op(th::Tensor& inp
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////////////
-std::tuple<th::Tensor, th::Tensor> mtp_sampling_and_accepted_draft_tokens_op(th::Tensor& logits,
-    th::Tensor& draftTokens, th::Tensor& targetTokens, int64_t numMTPModules, int64_t batchSize,
+std::tuple<th::Tensor, th::Tensor, th::Tensor> mtp_sampling_and_accepted_draft_tokens_op(th::Tensor& logits,
+    th::Tensor& draftTokens, th::Tensor& targetTokens, th::Tensor& targetTokenLogprobs, int64_t numMTPModules, int64_t batchSize,
     int64_t numContextRequest, int64_t vocabSize)
 {
     int const numGenerationRequest = batchSize - numContextRequest;
@@ -111,6 +111,8 @@ std::tuple mtp_sampling_and_accepted_draft_tokens_op(th:
     auto stream = at::cuda::getCurrentCUDAStream(logits.get_device());
     auto acceptedTokens = torch::empty(
         {batchSize, numMTPModules + 1}, at::TensorOptions().dtype(torch::kInt32).device(logits.device()));
+    auto logprobs = torch::empty(
+        {batchSize, numMTPModules + 1}, at::TensorOptions().dtype(torch::kFloat32).device(logits.device()));
     auto numAcceptedTokens = torch::ones({batchSize}, at::TensorOptions().dtype(torch::kInt32).device(logits.device()));
 
     // Fill params
@@ -122,6 +124,8 @@ std::tuple mtp_sampling_and_accepted_draft_tokens_op(th:
     params.draftTokens = reinterpret_cast<int*>(draftTokens.data_ptr());
     params.targetTokens = reinterpret_cast<int*>(targetTokens.data_ptr());
     params.acceptedTokens = reinterpret_cast<int*>(acceptedTokens.data_ptr());
+    params.targetTokenLogprobs = reinterpret_cast<float*>(targetTokenLogprobs.data_ptr());
+    params.logprobs = reinterpret_cast<float*>(logprobs.data_ptr());
     params.numAcceptedTokens = reinterpret_cast<int*>(numAcceptedTokens.data_ptr());
     params.logits = logits.data_ptr();
@@ -145,7 +149,7 @@ std::tuple mtp_sampling_and_accepted_draft_tokens_op(th:
         break;
     }
 
-    return std::make_tuple(acceptedTokens, numAcceptedTokens);
+    return std::make_tuple(acceptedTokens, numAcceptedTokens, logprobs);
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -291,8 +295,8 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
 {
     m.def(
         "mtp_sampling_and_accepted_draft_tokens_op(Tensor logits, Tensor draftTokens, Tensor "
-        "targetTokens, int numMTPModules, "
-        "int batchSize, int numContextRequest, int vocabSize) -> (Tensor, Tensor)");
+        "targetTokens, Tensor targetTokenLogprobs, int numMTPModules, "
+        "int batchSize, int numContextRequest, int vocabSize) -> (Tensor, Tensor, Tensor)");
 }
 
 TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
diff --git a/docker/common/install_tensorrt.sh b/docker/common/install_tensorrt.sh
index 8dda2552295..38c3a5bbaea 100644
--- a/docker/common/install_tensorrt.sh
+++ b/docker/common/install_tensorrt.sh
@@ -9,9 +9,9 @@ CUDA_VER="12.9" # 12.9.0
 # Keep the installation for cuDNN if users want to install PyTorch with source codes.
 # PyTorch 2.x can compile with cuDNN v9.
 CUDNN_VER="9.9.0.52-1"
-# NCCL version 2.26.3 used in the NGC PyTorch 25.04 image but not existing in public.
-# Use NCCL version 2.26.5 instead.
-NCCL_VER="2.26.5-1+cuda12.9"
+# NCCL version 2.26.x used in the NGC PyTorch 25.04 image but has a performance regression issue.
+# Use NCCL version 2.25.1 instead.
+NCCL_VER="2.25.1-1+cuda12.8"
 # cuBLAS version 12.9.0.2 used in the NGC PyTorch 25.04 image but not existing in public.
 # Use cuBLAS version 12.9.0.13 instead.
 CUBLAS_VER="12.9.0.13-1"
diff --git a/jenkins/L0_MergeRequest.groovy b/jenkins/L0_MergeRequest.groovy
index b7f4974b1e3..d4c31d2dde7 100644
--- a/jenkins/L0_MergeRequest.groovy
+++ b/jenkins/L0_MergeRequest.groovy
@@ -21,10 +21,10 @@ UPLOAD_PATH = env.uploadPath ? env.uploadPath : "sw-tensorrt-generic/llm-artifac
 // Container configuration
 // available tags can be found in: https://urm.nvidia.com/artifactory/sw-tensorrt-docker/tensorrt-llm/
 // [base_image_name]-[arch]-[os](-[python_version])-[trt_version]-[torch_install_type]-[stage]-[date]-[mr_id]
-LLM_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505191345-4400"
-LLM_SBSA_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-aarch64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505191345-4400"
-LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py310-trt10.10.0.31-skip-tritondevel-202505191345-4400"
-LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py312-trt10.10.0.31-skip-tritondevel-202505191345-4400"
+LLM_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505292346-4931"
+LLM_SBSA_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-aarch64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505292346-4931"
+LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py310-trt10.10.0.31-skip-tritondevel-202505292346-4931"
+LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py312-trt10.10.0.31-skip-tritondevel-202505292346-4931"
 
 // TODO: Move common variables to an unified location
 BUILD_CORES_REQUEST = "8"
diff --git a/jenkins/controlCCache.groovy b/jenkins/controlCCache.groovy
index 4f202fc1bf2..379f39dd6c0 100644
--- a/jenkins/controlCCache.groovy
+++ b/jenkins/controlCCache.groovy
@@ -1,7 +1,7 @@
 import java.lang.InterruptedException
 
-DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505191345-4400"
+DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505292346-4931"
 
 def createKubernetesPodConfig(image)
 {
diff --git a/requirements.txt b/requirements.txt
index 7f4fa01fb52..a9250d61d29 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -47,7 +47,7 @@ setuptools<80
 ordered-set
 peft
 einops
-flashinfer-python~=0.2.3
+flashinfer-python==0.2.5
 opencv-python-headless
 xgrammar==0.1.16
 backoff
diff --git a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
index 39a3788b558..c922635716a 100644
--- a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
+++ b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
@@ -311,7 +311,7 @@ def _(
 @torch.library.register_fake(
     "trtllm::mtp_sampling_and_accepted_draft_tokens_op")
 def _(logits: torch.Tensor, draft_tokens: torch.Tensor,
-      target_tokens: torch.Tensor, num_mtp_modules: int, batch_size: int,
+      target_tokens: torch.Tensor, target_token_logprobs: torch.Tensor, num_mtp_modules: int, batch_size: int,
       num_context_request: int, vocab_size: int):
     return logits.new_empty((batch_size, num_mtp_modules + 1),
                             dtype=torch.int32), logits.new_empty(
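The C++ op above now returns three tensors (accepted tokens, accepted counts, per-token log-probs; see mtpOp.cpp). For illustration only, a standalone sketch of a meta/fake function mirroring those shapes and dtypes — the name and the assumption that the third output is registered the same way are mine, not part of the diff:

```python
import torch

# Shapes/dtypes mirror the C++ op in mtpOp.cpp: acceptedTokens int32 [batch, modules + 1],
# numAcceptedTokens int32 [batch], logprobs float32 [batch, modules + 1].
def fake_mtp_sampling_and_accept(logits, draft_tokens, target_tokens,
                                 target_token_logprobs, num_mtp_modules: int,
                                 batch_size: int, num_context_request: int,
                                 vocab_size: int):
    accepted = logits.new_empty((batch_size, num_mtp_modules + 1), dtype=torch.int32)
    num_accepted = logits.new_empty((batch_size, ), dtype=torch.int32)
    logprobs = logits.new_empty((batch_size, num_mtp_modules + 1), dtype=torch.float32)
    return accepted, num_accepted, logprobs
```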
diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py
index 799a8867e55..a8b268a2307 100644
--- a/tensorrt_llm/_torch/pyexecutor/llm_request.py
+++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -1,3 +1,4 @@
+import copy
 from typing import List, Optional
 
 import torch
@@ -265,8 +266,13 @@ def create_response(
             use_fast_logits=False,
             mpi_world_rank=0) -> tensorrt_llm.bindings.executor.Response | None:
         response = super().create_response(use_fast_logits, mpi_world_rank)
+        py_result = None
+        if response:
+            py_result = copy.copy(self.py_result)
+            if self.py_result._log_probs:
+                self.py_result._log_probs = LogProbStorage()
         return LlmResponse(response,
-                           self.py_result) if response is not None else None
+                           py_result) if response is not None else None
 
 
 def convert_wordlist(word_list) -> List[List[int]]:
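The `create_response` change above snapshots the accumulated per-request results and then resets the live log-prob storage, so each streamed response only carries logprobs produced since the previous one. A simplified, self-contained sketch of that pattern; `PyResult` and `LogProbStorage` here are stand-ins for illustration, not the real classes:

```python
import copy
from dataclasses import dataclass, field

@dataclass
class LogProbStorage:          # stand-in for the real storage class
    log_probs: list = field(default_factory=list)

class PyResult:                # simplified stand-in for the request's py_result
    def __init__(self):
        self._log_probs = LogProbStorage()

def snapshot_for_response(py_result: PyResult) -> PyResult:
    # Shallow-copy the accumulated state for the outgoing response, then reset
    # the live storage so the next response only carries newly produced logprobs.
    snap = copy.copy(py_result)
    if py_result._log_probs:
        py_result._log_probs = LogProbStorage()
    return snap

result = PyResult()
result._log_probs.log_probs.append({42: -0.1})
snap = snapshot_for_response(result)
assert snap._log_probs.log_probs and not result._log_probs.log_probs
```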
diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py
index 144633f4c40..76abe05a931 100644
--- a/tensorrt_llm/_torch/speculative/mtp.py
+++ b/tensorrt_llm/_torch/speculative/mtp.py
@@ -1,6 +1,7 @@
 from dataclasses import dataclass
 from typing import List, Optional
+from tensorrt_llm.executor.result import Logprob
 
 import torch
 from torch import nn
@@ -9,7 +10,7 @@
 from ..attention_backend import AttentionMetadata
 from ..pyexecutor.llm_request import LlmRequest, LlmRequestState
 from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager
-from ..pyexecutor.sampler import SampleState, SampleStateTensors, TorchSampler
+from ..pyexecutor.sampler import SampleState, SampleStateTensors, TorchSampler, greedy_search_sampling_batch
 from ..pyexecutor.scheduler import ScheduledRequests
 from .interface import SpecConfig, SpecMetadata, SpeculativeDecodingMode
@@ -18,6 +19,7 @@ class SampleStateTensorsMTP(SampleStateTensors):
     new_tokens_lens: torch.Tensor
     next_draft_tokens: torch.Tensor
+    next_tokens_log_probs: Optional[torch.Tensor] = None
 
 
 @dataclass(frozen=True, kw_only=True)
@@ -240,20 +242,29 @@ def update_requests(self, state: SampleStateMTP) -> None:
         new_tokens_list = state.host.new_tokens.tolist()
         new_tokens_lens_list = state.host.new_tokens_lens.tolist()
         next_draft_tokens_list = state.host.next_draft_tokens.tolist()
+        next_tokens_log_probs_list = state.host.next_tokens_log_probs.tolist()
+
+        def handle_logprobs(request: LlmRequest, tokens, log_probs):
+            token_log_probs = [{
+                token: Logprob(logprob=logprob, rank=1)
+            } for token, logprob in zip(tokens, log_probs)]
+            request.py_result.append_log_probs([token_log_probs])
 
         idx = 0
         beam_idx = 0
         for request in state.scheduled_requests.context_requests:
             assert not request.py_return_context_logits, "return_context_logits not implemented for MTPSampler"
             assert not request.py_return_generation_logits, "return_generation_logits not implemented for MTPSampler"
-            assert not request.py_return_log_probs, "return_log_probs not implemented for MTPSampler"
+            # assert not request.py_return_log_probs, "return_log_probs not implemented for MTPSampler"
             if request.get_context_remaining_length() != 0:
                 idx += 1
                 continue
             if request.state != LlmRequestState.GENERATION_COMPLETE:
                 new_token = new_tokens_list[idx][0]
+                new_token_log_prob = next_tokens_log_probs_list[idx][0]
                 num_tokens = request.add_new_token(new_token, beam_idx)
+                handle_logprobs(request, [new_token], [new_token_log_prob])
                 should_stop = self._handle_stop_criteria(
                     request, new_token, num_tokens, beam_idx)
                 if self._draft_meet_max_token_stop_criteria(
@@ -267,14 +278,17 @@
         for request in state.scheduled_requests.generation_requests:
             assert not request.py_return_context_logits, "return_context_logits not implemented for MTPSampler"
             assert not request.py_return_generation_logits, "return_generation_logits not implemented for MTPSampler"
-            assert not request.py_return_log_probs, "return_log_probs not implemented for MTPSampler"
+            # assert not request.py_return_log_probs, "return_log_probs not implemented for MTPSampler"
             if request.state != LlmRequestState.GENERATION_COMPLETE:
                 new_tokens = new_tokens_list[idx]
+                new_tokens_log_probs = next_tokens_log_probs_list[idx]
                 num_new_tokens = new_tokens_lens_list[idx]
                 should_stop = False
                 for i in range(num_new_tokens):
                     new_token = new_tokens[i]
+                    new_token_log_prob = new_tokens_log_probs[i]
                     num_tokens = request.add_new_token(new_token, beam_idx)
+                    handle_logprobs(request, [new_token], [new_token_log_prob])
                     should_stop = self._handle_stop_criteria(
                         request, new_token, num_tokens, beam_idx)
                     if should_stop:
@@ -298,17 +312,21 @@ def sample_async(self, scheduled_requests: ScheduledRequests,
         new_tokens_lens_device = model_outputs['new_tokens_lens']
         next_draft_tokens_device = model_outputs['next_draft_tokens']
         next_new_tokens_device = model_outputs['next_new_tokens']
+        next_tokens_log_probs_device = model_outputs["log_probs"]
         device = SampleStateTensorsMTP(
             new_tokens=next_new_tokens_device,
             new_tokens_lens=new_tokens_lens_device,
             next_draft_tokens=next_draft_tokens_device,
+            next_tokens_log_probs=next_tokens_log_probs_device,
         )
         host = SampleStateTensorsMTP(
             new_tokens=new_tokens_device.to('cpu', non_blocking=True),
             new_tokens_lens=new_tokens_lens_device.to('cpu', non_blocking=True),
             next_draft_tokens=next_draft_tokens_device.to('cpu', non_blocking=True),
+            next_tokens_log_probs=next_tokens_log_probs_device.to(
+                'cpu', non_blocking=True),
         )
         sampler_event = torch.cuda.Event()
         sampler_event.record()
@@ -346,7 +364,7 @@ def forward(
 
         # Sample and verify draft tokens
         raw_logits = logits
-        accepted_tokens, num_accepted_tokens = self.sample_and_accept_draft_tokens(
+        accepted_tokens, num_accepted_tokens, log_probs = self.sample_and_accept_draft_tokens(
             input_ids, logits, spec_metadata, attn_metadata)
 
         # Update MTP past hidden states
@@ -398,7 +416,8 @@ def forward(
             'new_tokens': accepted_tokens,
             'new_tokens_lens': num_accepted_tokens,
             'next_draft_tokens': next_draft_tokens,
-            'next_new_tokens': next_new_tokens
+            'next_new_tokens': next_new_tokens,
+            "log_probs": log_probs
         }
 
     def skip_forward(
@@ -418,6 +437,9 @@
         accepted_tokens = torch.empty((batch_size, (mtp_num_modules + 1)),
                                       dtype=torch.int,
                                       device=logits.device)
+        log_probs = torch.empty((batch_size, (mtp_num_modules + 1)),
+                                dtype=torch.float32,
+                                device=logits.device)
         num_accepted_tokens = torch.ones(batch_size,
                                          dtype=torch.int,
                                          device=logits.device)
@@ -432,7 +454,8 @@ def skip_forward(
             'new_tokens': accepted_tokens,
             'new_tokens_lens': num_accepted_tokens,
             'next_draft_tokens': next_draft_tokens,
-            'next_new_tokens': next_new_tokens
+            'next_new_tokens': next_new_tokens,
+            'log_probs': log_probs
         }
 
     def update_mtp_hidden_states(
@@ -728,6 +751,9 @@ def sample_and_accept_draft_tokens(
         accepted_tokens = torch.empty((batch_size, (mtp_num_modules + 1)),
                                       dtype=torch.int,
                                       device=logits.device)
+        log_probs = torch.empty((batch_size, (mtp_num_modules + 1)),
+                                dtype=torch.float32,
+                                device=logits.device)
         num_accepted_tokens = torch.ones(batch_size,
                                          dtype=torch.int,
                                          device=logits.device)
@@ -790,20 +816,30 @@
                                                  (mtp_num_modules + 1),
                                                  dtype=torch.int,
                                                  device=logits.device)
-            accepted_tokens, num_accepted_tokens = torch.ops.trtllm.mtp_sampling_and_accepted_draft_tokens_op(
-                logits, spec_metadata.draft_tokens, target_tokens_cache,
+            target_token_logprobs_cache = torch.zeros(batch_size *
+                                                      (mtp_num_modules + 1),
+                                                      dtype=torch.float32,
+                                                      device=logits.device)
+            accepted_tokens, num_accepted_tokens, log_probs = torch.ops.trtllm.mtp_sampling_and_accepted_draft_tokens_op(
+                logits, spec_metadata.draft_tokens, target_tokens_cache, target_token_logprobs_cache,
                 mtp_num_modules, batch_size, num_contexts, logits.shape[-1])
         else:
             # Do greedy sampling for the input logits
-            target_tokens = torch.argmax(logits, dim=-1)
+            target_tokens, target_log_probs = greedy_search_sampling_batch(logits)
 
             # context
             accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts]
+            log_probs[:num_contexts, 0] = target_log_probs[:num_contexts]
 
             # generation
             gen_target_tokens = target_tokens[num_contexts:].reshape(
                 num_gens, mtp_num_modules + 1)
+            gen_target_log_probs = target_log_probs[num_contexts:].reshape(
+                num_gens, mtp_num_modules + 1)
+
             accepted_tokens[num_contexts:, :] = gen_target_tokens
+            log_probs[num_contexts:, :] = gen_target_log_probs
+
             draft_tokens = spec_metadata.draft_tokens.reshape(
                 num_gens, mtp_num_modules)
             num_accepted_tokens[num_contexts:] += torch.cumprod(
@@ -811,7 +847,7 @@
                 ).int(), dim=-1).sum(1)
 
-        return accepted_tokens, num_accepted_tokens
+        return accepted_tokens, num_accepted_tokens, log_probs
 
     def change_attn_metadata(self, num_accepted_tokens: torch.Tensor,
                              attn_metadata: AttentionMetadata):
@@ -1067,7 +1103,7 @@ def forward(
 
         # Sample and verify draft tokens
         raw_logits = logits
-        accepted_tokens, num_accepted_tokens = self.sample_and_accept_draft_tokens(
+        accepted_tokens, num_accepted_tokens, log_probs = self.sample_and_accept_draft_tokens(
            input_ids, logits, spec_metadata, attn_metadata)
 
         # Save the old attn_metadata and spec_metadata
@@ -1150,7 +1186,8 @@ def forward(
             'new_tokens': accepted_tokens,
             'new_tokens_lens': num_accepted_tokens,
             'next_draft_tokens': next_draft_tokens,
-            'next_new_tokens': next_new_tokens
+            'next_new_tokens': next_new_tokens,
+            'log_probs': log_probs
         }
 
     def prepare_drafter_inputs(
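In the non-THOP fallback above, `torch.argmax` is replaced by `greedy_search_sampling_batch`, which is expected to return both the greedy tokens and their log-probabilities. A plausible PyTorch equivalent, shown only to make the intended shapes explicit (assuming logits of shape `[num_tokens, vocab]`; this is a sketch, not the actual helper):

```python
import torch

def greedy_with_logprobs(logits: torch.Tensor):
    # Greedy token per row and its log-probability; shapes: [N, vocab] -> ([N], [N]).
    tokens = torch.argmax(logits, dim=-1)
    logprobs = torch.log_softmax(logits.float(), dim=-1).gather(
        -1, tokens.unsqueeze(-1)).squeeze(-1)
    return tokens, logprobs
```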
diff --git a/tensorrt_llm/commands/serve.py b/tensorrt_llm/commands/serve.py
index bf36d9a80e4..266e67b4f8f 100644
--- a/tensorrt_llm/commands/serve.py
+++ b/tensorrt_llm/commands/serve.py
@@ -100,12 +100,7 @@ def launch_server(host: str, port: int, llm_args: dict):
 
 
 @click.command("serve")
-@click.argument("model_name", type=str, default=None)
-@click.option("--model",
-              type=str,
-              default=None,
-              help="model name or path."
-              "Model name to use. Defaults to model_path.")
+@click.argument("model", type=str)
 @click.option("--served-model-name",
               type=str,
               default=None,
@@ -201,7 +196,7 @@ def launch_server(host: str, port: int, llm_args: dict):
     default=None,
     help="[Experimental] Specify the parser for reasoning models.",
 )
-def serve(model_name: Optional[str], model: Optional[str],
+def serve(model: str,
           served_model_name: Optional[str], tokenizer: Optional[str],
           host: str, port: int, log_level: str, backend: str, max_beam_width: int,
@@ -218,7 +213,6 @@ def serve(model_name: Optional[str], model: Optional[str],
     MODEL: model name | HF checkpoint path | TensorRT engine path
     """
     logger.set_level(log_level)
-    model = model or model_name
 
     llm_args, _ = get_llm_args(
         model=model,
diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py
index c94b7af319b..edd035041e2 100644
--- a/tensorrt_llm/executor/result.py
+++ b/tensorrt_llm/executor/result.py
@@ -116,13 +116,16 @@ def length(self) -> int:
         return len(self.token_ids)
 
     def text_diff_safe(self, last_text_len) -> Tuple[str, int]:
-        return self.text[last_text_len:], len(self.text)
+        l = len(self.text)
+        return self.text[last_text_len:l], l
 
     def logprobs_diff_safe(self, last_logprobs_len) -> Tuple[List[float], int]:
-        return self.logprobs[last_logprobs_len:], len(self.logprobs)
+        l = len(self.logprobs)
+        return self.logprobs[last_logprobs_len:l], l
 
     def token_ids_diff_safe(self, last_token_ids_len) -> Tuple[List[int], int]:
-        return self.logprobs[last_token_ids_len:], len(self.logprobs)
+        l = len(self.token_ids)
+        return self.token_ids[last_token_ids_len:l], l
 
     #@property
     #def text_diff(self) -> str:
@@ -234,7 +237,7 @@ def _handle_sequence(self,
 
         if response_tensors.log_probs is not None:
             output._last_logprobs_len = len(output.logprobs)
-            output.logprobs = response_tensors.log_probs[src_idx]
+            output.logprobs.extend(response_tensors.log_probs[src_idx])
             # overcome some WAR in the cpp executor
             if finish_reasons[src_idx] != tllm.FinishReason.CANCELLED:
                 assert len(output.logprobs) == output.length
diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py
index 98c9d04dae6..545333094c7 100644
--- a/tensorrt_llm/serve/openai_protocol.py
+++ b/tensorrt_llm/serve/openai_protocol.py
@@ -244,6 +244,7 @@ def to_sampling_params(self) -> SamplingParams:
 
             # TODO: migrate to use logprobs and prompt_logprobs
             _return_log_probs=self.logprobs,
+            logprobs=self.logprobs,
         )
         return sampling_params
 
@@ -258,13 +259,6 @@ def check_beam_search(self):
                 "Only support one response per prompt without beam search")
         return self
 
-    @model_validator(mode="before")
-    @classmethod
-    def check_logprobs(cls, data):
-        if data.get("logprobs"):
-            raise ValueError("logprobs is not supported")
-        return data
-
     @model_validator(mode="before")
     @classmethod
     def validate_stream_options(cls, data):
@@ -553,6 +547,7 @@ def to_sampling_params(self) -> SamplingParams:
 
             # TODO: migrate to use logprobs and prompt_logprobs
             _return_log_probs=self.logprobs,
+            logprobs=self.logprobs,
         )
         return sampling_params
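With the `check_logprobs` rejection removed and `logprobs` now forwarded into `SamplingParams`, an OpenAI-compatible client can request per-token log-probabilities from `trtllm-serve`. A sketch using the `openai` Python client; the base URL, model name, and the integer `logprobs=1` value are assumptions for illustration:

```python
from openai import OpenAI

# Assumes a locally running `trtllm-serve` endpoint; adjust base_url and model name.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")
resp = client.completions.create(
    model="served-model",
    prompt="Hello, my name is",
    max_tokens=16,
    logprobs=1,                 # ask for per-token logprobs
)
print(resp.choices[0].logprobs.tokens)
print(resp.choices[0].logprobs.token_logprobs)
```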
diff --git a/tensorrt_llm/serve/postprocess_handlers.py b/tensorrt_llm/serve/postprocess_handlers.py
index 9846dfe4a75..2f788b8d3f4 100644
--- a/tensorrt_llm/serve/postprocess_handlers.py
+++ b/tensorrt_llm/serve/postprocess_handlers.py
@@ -1,6 +1,8 @@
 from dataclasses import dataclass, field
 from typing import List, Literal, Optional, Tuple, Union
 
+from tensorrt_llm.executor.result import TokenLogprobs
+
 from .._utils import nvtx_range_debug
 from ..executor import (DetokenizedGenerationResultBase, GenerationResult,
                         GenerationResultBase)
@@ -16,7 +18,7 @@
                                ChatCompletionResponseChoice,
                                ChatCompletionResponseStreamChoice,
                                ChatCompletionStreamResponse,
-                               ChatCompletionToolsParam, ChatMessage,
+                               ChatCompletionToolsParam, ChatMessage, CompletionLogProbs,
                                CompletionRequest, CompletionResponse,
                                CompletionResponseChoice,
                                CompletionResponseStreamChoice,
@@ -59,21 +61,38 @@ def from_request(cls, request: ChatCompletionRequest):
 
 def create_logprobs(token_ids: List[int], tokenizer: TransformersTokenizer,
-                    logprobs: List[float]) -> ChatCompletionLogProbs:
-    assert len(token_ids) == len(logprobs), \
-        "token_ids and logprobs have different lengths"
+                    logprobs: TokenLogprobs) -> ChatCompletionLogProbs:
+    # assert len(token_ids) == len(logprobs), \
+    #     "token_ids and logprobs have different lengths"
     content: List[ChatCompletionLogProbsContent] = []
-    for token_id, logprob in zip(token_ids, logprobs):
+    for logprob in logprobs:
+        token_id, lp = list(logprob.items())[0]
         token = tokenizer.decode(token_id)
         # returning multiple logprobs is not supported
         first_logprob = ChatCompletionLogProbsContent(
             token=token,
-            logprob=max(logprob, -9999.0),
+            logprob=max(lp.logprob, -9999.0),
             bytes=list(token.encode("utf-8", errors="replace")))
         content.append(first_logprob)
     chat_logprobs = ChatCompletionLogProbs(content=content)
     return chat_logprobs
 
+
+def create_logprobs_completion(token_ids: List[int],
+                               tokenizer: TransformersTokenizer,
+                               logprobs: TokenLogprobs) -> CompletionLogProbs:
+    # assert len(token_ids) == len(logprobs), \
+    #     "token_ids and logprobs have different lengths"
+    token_logprobs: List[Optional[float]] = []
+    tokens: List[str] = []
+    for logprob in logprobs:
+        token_id, lp = list(logprob.items())[0]
+        token = tokenizer.decode(token_id)
+        # returning multiple logprobs is not supported
+        token_logprobs.append(max(lp.logprob, -9999.0))
+        tokens.append(token)
+    completion_logprobs = CompletionLogProbs(token_logprobs=token_logprobs,tokens=tokens)
+    return completion_logprobs
+
 def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int,
                            text: str, streaming: bool) -> Tuple[bool, str, str]:
     reasoning_parser = None
@@ -258,6 +277,7 @@ class CompletionPostprocArgs(PostprocArgs):
     prompt_idx: int = 0
     prompt: Optional[str] = None
     stream_options: Optional[StreamOptions] = None
+    return_logprobs: bool = False
 
     @classmethod
     def from_request(cls, request: CompletionRequest):
@@ -266,6 +286,7 @@ def from_request(cls, request: CompletionRequest):
             model=request.model,
             num_choices=request.n if request.n else 1,
             stream_options=request.stream_options,
+            return_logprobs=request.logprobs,
         )
@@ -292,6 +313,10 @@ def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args:
                 finish_reason = output.finish_reason,
                 stop_reason = output.stop_reason,
             )
+            if args.return_logprobs:
+                logprobs, args.last_logprobs_len = output.logprobs_diff_safe(args.last_logprobs_len)
+                token_ids, args.last_token_ids_len = output.token_ids_diff_safe(args.last_token_ids_len)
+                choice.logprobs = create_logprobs_completion(token_ids, args.tokenizer, logprobs)
             chunk = CompletionStreamResponse(model=args.model, choices=[choice])
             if include_continuous_usage:
                 chunk.usage = UsageInfo(prompt_tokens=prompt_tokens,
@@ -333,7 +358,8 @@ def completion_response_post_processor(rsp: GenerationResult, args: CompletionPo
             stop_reason=output.stop_reason,
             finish_reason=output.finish_reason,
         )
-
+        if args.return_logprobs:
+            choice.logprobs = create_logprobs_completion(output.token_ids, args.tokenizer, output.logprobs)
         completion_tokens += output.length
         choices.append(choice)
diff --git a/tests/unittest/_torch/speculative/test_mtp_sample_and_accept_draft_tokens.py b/tests/unittest/_torch/speculative/test_mtp_sample_and_accept_draft_tokens.py
index 954d5476507..4aa5ec49cf6 100644
--- a/tests/unittest/_torch/speculative/test_mtp_sample_and_accept_draft_tokens.py
+++ b/tests/unittest/_torch/speculative/test_mtp_sample_and_accept_draft_tokens.py
@@ -312,7 +312,7 @@ def test_sample_and_accept_draft_tokens(self, mtp_num_modules, logits,
         for is_thop in [True, False]:
             mtpworker.is_thop = is_thop
             # TODO: add unit tests for relaxed acceptance
-            accepted_tokens, num_accepted_tokens = mtpworker.sample_and_accept_draft_tokens(
+            accepted_tokens, num_accepted_tokens, log_probs = mtpworker.sample_and_accept_draft_tokens(
                 None, logits, spec_metadata, attn_metadata)
 
             torch.testing.assert_close(num_accepted_tokens,
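A possible follow-up check for the updated test, not part of the diff: gather `log_softmax(logits)` at the accepted token ids using the kernel's row layout (context requests contribute one logits row each, generation requests `mtp_num_modules + 1` rows) and compare it against the returned `log_probs` for the accepted positions. Variable names follow the test above but are assumptions here:

```python
import torch

def reference_logprobs(logits, accepted_tokens, num_contexts, mtp_num_modules):
    """Reference log-probs for the accepted tokens. Only the first
    num_accepted_tokens[b] entries of each row are meaningful in the real output."""
    ref = torch.log_softmax(logits.float(), dim=-1)
    out = torch.zeros_like(accepted_tokens, dtype=torch.float32)
    row = 0
    for b in range(accepted_tokens.shape[0]):
        n_rows = 1 if b < num_contexts else mtp_num_modules + 1
        for j in range(n_rows):
            out[b, j] = ref[row + j, accepted_tokens[b, j]]
        row += n_rows
    return out
```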