2 changes: 1 addition & 1 deletion .devcontainer/docker-compose.yml
@@ -1,7 +1,7 @@
version: "3.9"
services:
tensorrt_llm-dev:
image: urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505191345-4400
image: urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505292346-4931
network_mode: host
ipc: host

24 changes: 19 additions & 5 deletions cpp/tensorrt_llm/kernels/speculativeDecoding/mtpKernels.cu
@@ -203,7 +203,7 @@ template void invokeMTPPrepareDrafterInputs<__nv_bfloat16>(MTPPrepareDrafterInpu

template <typename T, int BLOCK_SIZE>
__global__ void mtpGreedySampling(int const numMTPModules, int const batchSize, int const numContextRequest,
int const vocabSize, T const* logits, int* targetTokens)
int const vocabSize, T const* logits, int* targetTokens, float* targetTokenLogprobs)
{
/*
In a batch of requests: context requests (at the beginning) + generation requests
@@ -217,6 +217,7 @@ __global__ void mtpGreedySampling(int const numMTPModules, int const batchSize,

__shared__ T maxValueCache[BLOCK_SIZE];
__shared__ int maxValueIndexCache[BLOCK_SIZE];
__shared__ float sumExpCache[BLOCK_SIZE];

int const bid = static_cast<int>(blockIdx.x);
int const tid = static_cast<int>(threadIdx.x);
@@ -227,19 +228,23 @@ __global__ void mtpGreedySampling(int const numMTPModules, int const batchSize,

T tmpMaxValue = curLogitsPtr[0];
int tmpMaxValueIndex = 0;
float tmpNorm = 0.0f;
int ii = tid;
while (ii < vocabSize)
{
if (curLogitsPtr[ii] >= tmpMaxValue)
{
tmpNorm *= expf(tmpMaxValue - curLogitsPtr[ii]);
// Find the first top-1
tmpMaxValueIndex = (curLogitsPtr[ii] == tmpMaxValue) ? min(tmpMaxValueIndex, ii) : ii;
tmpMaxValue = curLogitsPtr[ii];
}
tmpNorm += expf(curLogitsPtr[ii] - tmpMaxValue);
ii += blockDim.x;
}
maxValueCache[tid] = tmpMaxValue;
maxValueIndexCache[tid] = tmpMaxValueIndex;
sumExpCache[tid] = tmpNorm;

__syncthreads();

@@ -251,11 +256,17 @@ __global__ void mtpGreedySampling(int const numMTPModules, int const batchSize,
{
if (maxValueCache[tid] <= maxValueCache[tid + ii])
{
sumExpCache[tid] *= expf(maxValueCache[tid] - maxValueCache[tid + ii]);
maxValueIndexCache[tid] = (maxValueCache[tid] == maxValueCache[tid + ii])
? min(maxValueIndexCache[tid], maxValueIndexCache[tid + ii])
: maxValueIndexCache[tid + ii];
maxValueCache[tid] = maxValueCache[tid + ii];
}
else
{
sumExpCache[tid + ii] *= expf(maxValueCache[tid + ii] - maxValueCache[tid]);
}
sumExpCache[tid] += sumExpCache[tid + ii];
}
__syncthreads();
ii /= 2;
@@ -264,11 +275,12 @@ __global__ void mtpGreedySampling(int const numMTPModules, int const batchSize
if (tid == 0)
{
targetTokens[bid] = maxValueIndexCache[tid];
targetTokenLogprobs[bid] = logf(1.0f / sumExpCache[tid]);
}
}
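The per-thread loop above computes the greedy token and its log-probability in a single pass: it tracks a running maximum and a running sum of exp(logit - max), rescaling the sum whenever a larger logit appears, so the final log-probability is simply log(1 / sum). A minimal serial sketch of that online log-sum-exp pattern (a hypothetical NumPy reference, not part of this PR):

import numpy as np

def greedy_token_logprob(logits):
    # Reference for what each block computes per logits row: the argmax token
    # and its log-probability under softmax(logits), in one streaming pass.
    running_max = logits[0]
    running_sum = 0.0                      # sum of exp(logit - running_max) so far
    argmax = 0
    for i, x in enumerate(logits):
        if x >= running_max:
            running_sum *= np.exp(running_max - x)              # rescale to the new max
            argmax = min(argmax, i) if x == running_max else i  # keep the first top-1 on ties
            running_max = x
        running_sum += np.exp(x - running_max)
    # softmax probability of the argmax token is exp(0) / running_sum
    return argmax, float(np.log(1.0 / running_sum))

The shared-memory reduction below merges per-thread partials the same way: the partial sum belonging to the smaller maximum is rescaled by expf(smallerMax - largerMax) before the two sums are added.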

__global__ void mtpAcceptDraftToken(int const numMTPModules, int const batchSize, int const numContextRequest,
int const* draftTokens, int* targetTokens, int* acceptedTokens, int* numAcceptedTokens)
int const* draftTokens, int* targetTokens, float* targetTokenLogprobs, int* acceptedTokens, int* numAcceptedTokens, float* logprobs)
{
/*
In a batch of requests: context requests (at the beginning) + generation requests
@@ -324,9 +336,11 @@ __global__ void mtpAcceptDraftToken(int const numMTPModules, int const batchSize

// Write back to acceptedTokens
auto curAcceptedTokensPtr = acceptedTokens + tid * (numMTPModules + 1);
auto curLogprobsPtr = logprobs + tid * (numMTPModules + 1);
for (int jj = 0; jj < curAcceptedLen; jj++)
{
curAcceptedTokensPtr[jj] = targetTokens[targetTokensStartOffset + jj];
curLogprobsPtr[jj] = targetTokenLogprobs[targetTokensStartOffset + jj];
}
}
}
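The write-back loop above now copies each accepted token's log-probability alongside the token itself. The acceptance logic is collapsed in this view; assuming the usual MTP rule (accept draft tokens until the first one that disagrees with the greedy target token, then always emit the target token at that position), a hypothetical host-side reference for a single generation request would be:

def accept_draft_tokens(draft_tokens, target_tokens, target_logprobs, num_mtp_modules):
    # Sketch only, not the kernel's exact code: accepted tokens are always taken
    # from the target (greedy) tokens, and the log-probabilities produced by
    # mtpGreedySampling are copied along with them.
    accepted_len = 1
    for j in range(num_mtp_modules):
        if draft_tokens[j] == target_tokens[j]:
            accepted_len += 1
        else:
            break
    accepted = list(target_tokens[:accepted_len])
    logprobs = list(target_logprobs[:accepted_len])
    return accepted, logprobs, accepted_len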
@@ -340,12 +354,12 @@ void invokeMTPSampleAndAcceptDraftTokens(MTPSampleAndAcceptDraftTokensParam& par
int greedyBlockSize = min(BLOCK_SIZE, params.vocabSize);

mtpGreedySampling<T, BLOCK_SIZE><<<numLogits, greedyBlockSize, 0, stream>>>(params.numMTPModules, params.batchSize,
params.numContextRequest, params.vocabSize, reinterpret_cast<T*>(params.logits), params.targetTokens);
params.numContextRequest, params.vocabSize, reinterpret_cast<T*>(params.logits), params.targetTokens, params.targetTokenLogprobs);
sync_check_cuda_error(stream);

mtpAcceptDraftToken<<<divUp(params.batchSize, BLOCK_SIZE), BLOCK_SIZE, 0, stream>>>(params.numMTPModules,
params.batchSize, params.numContextRequest, params.draftTokens, reinterpret_cast<int*>(params.targetTokens),
params.acceptedTokens, params.numAcceptedTokens);
params.batchSize, params.numContextRequest, params.draftTokens, reinterpret_cast<int*>(params.targetTokens), reinterpret_cast<float*>(params.targetTokenLogprobs),
params.acceptedTokens, params.numAcceptedTokens, params.logprobs);
sync_check_cuda_error(stream);
}

2 changes: 2 additions & 0 deletions cpp/tensorrt_llm/kernels/speculativeDecoding/mtpKernels.h
@@ -62,7 +62,9 @@ struct MTPSampleAndAcceptDraftTokensParam
void* __restrict__ logits;
int* draftTokens;
int* targetTokens;
float* targetTokenLogprobs;
int* acceptedTokens;
float* logprobs;
int* numAcceptedTokens;
};

14 changes: 9 additions & 5 deletions cpp/tensorrt_llm/thop/mtpOp.cpp
@@ -92,8 +92,8 @@ std::tuple<th::Tensor, th::Tensor> mtp_prepare_drafter_inputs_op(th::Tensor& inp
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////
std::tuple<th::Tensor, th::Tensor> mtp_sampling_and_accepted_draft_tokens_op(th::Tensor& logits,
th::Tensor& draftTokens, th::Tensor& targetTokens, int64_t numMTPModules, int64_t batchSize,
std::tuple<th::Tensor, th::Tensor, th::Tensor> mtp_sampling_and_accepted_draft_tokens_op(th::Tensor& logits,
th::Tensor& draftTokens, th::Tensor& targetTokens, th::Tensor& targetTokenLogprobs, int64_t numMTPModules, int64_t batchSize,
int64_t numContextRequest, int64_t vocabSize)
{
int const numGenerationRequest = batchSize - numContextRequest;
@@ -111,6 +111,8 @@ std::tuple<th::Tensor, th::Tensor> mtp_sampling_and_accepted_draft_tokens_op(th:
auto stream = at::cuda::getCurrentCUDAStream(logits.get_device());
auto acceptedTokens = torch::empty(
{batchSize, numMTPModules + 1}, at::TensorOptions().dtype(torch::kInt32).device(logits.device()));
auto logprobs = torch::empty(
{batchSize, numMTPModules + 1}, at::TensorOptions().dtype(torch::kFloat32).device(logits.device()));
auto numAcceptedTokens = torch::ones({batchSize}, at::TensorOptions().dtype(torch::kInt32).device(logits.device()));

// Fill params
@@ -122,6 +124,8 @@ std::tuple<th::Tensor, th::Tensor> mtp_sampling_and_accepted_draft_tokens_op(th:
params.draftTokens = reinterpret_cast<int*>(draftTokens.data_ptr());
params.targetTokens = reinterpret_cast<int*>(targetTokens.data_ptr());
params.acceptedTokens = reinterpret_cast<int*>(acceptedTokens.data_ptr());
params.targetTokenLogprobs = reinterpret_cast<float*>(targetTokenLogprobs.data_ptr());
params.logprobs = reinterpret_cast<float*>(logprobs.data_ptr());
params.numAcceptedTokens = reinterpret_cast<int*>(numAcceptedTokens.data_ptr());
params.logits = logits.data_ptr();

@@ -145,7 +149,7 @@ std::tuple<th::Tensor, th::Tensor> mtp_sampling_and_accepted_draft_tokens_op(th:
break;
}

return std::make_tuple(acceptedTokens, numAcceptedTokens);
return std::make_tuple(acceptedTokens, numAcceptedTokens, logprobs);
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -291,8 +295,8 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
"mtp_sampling_and_accepted_draft_tokens_op(Tensor logits, Tensor draftTokens, Tensor "
"targetTokens, int numMTPModules, "
"int batchSize, int numContextRequest, int vocabSize) -> (Tensor, Tensor)");
"targetTokens, Tensor targetTokenLogprobs, int numMTPModules, "
"int batchSize, int numContextRequest, int vocabSize) -> (Tensor, Tensor, Tensor)");
}
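With the extra targetTokenLogprobs input and the third output, a call site for the registered op would look roughly like the sketch below. The example sizes, shapes, and bfloat16 dtype are assumptions for illustration; targetTokens and targetTokenLogprobs are caller-provided buffers that the greedy-sampling kernel writes into.

import torch
import tensorrt_llm  # assumed to register the trtllm custom ops

num_mtp_modules, batch_size, num_ctx, vocab_size = 1, 4, 1, 32000
num_gen = batch_size - num_ctx
num_logits = num_ctx + num_gen * (num_mtp_modules + 1)  # assumed: one logits row per sampled position

logits = torch.randn(num_logits, vocab_size, dtype=torch.bfloat16, device="cuda")
draft_tokens = torch.zeros(num_gen * num_mtp_modules, dtype=torch.int32, device="cuda")
target_tokens = torch.empty(num_logits, dtype=torch.int32, device="cuda")
target_token_logprobs = torch.empty(num_logits, dtype=torch.float32, device="cuda")

accepted_tokens, num_accepted_tokens, logprobs = torch.ops.trtllm.mtp_sampling_and_accepted_draft_tokens_op(
    logits, draft_tokens, target_tokens, target_token_logprobs,
    num_mtp_modules, batch_size, num_ctx, vocab_size)
# accepted_tokens, logprobs: [batch_size, num_mtp_modules + 1]; num_accepted_tokens: [batch_size]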

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
6 changes: 3 additions & 3 deletions docker/common/install_tensorrt.sh
@@ -9,9 +9,9 @@ CUDA_VER="12.9" # 12.9.0
# Keep the installation for cuDNN if users want to install PyTorch with source codes.
# PyTorch 2.x can compile with cuDNN v9.
CUDNN_VER="9.9.0.52-1"
# NCCL version 2.26.3 used in the NGC PyTorch 25.04 image but not existing in public.
# Use NCCL version 2.26.5 instead.
NCCL_VER="2.26.5-1+cuda12.9"
# NCCL version 2.26.x is used in the NGC PyTorch 25.04 image but has a performance regression issue.
# Use NCCL version 2.25.1 instead.
NCCL_VER="2.25.1-1+cuda12.8"
# cuBLAS version 12.9.0.2 used in the NGC PyTorch 25.04 image but not existing in public.
# Use cuBLAS version 12.9.0.13 instead.
CUBLAS_VER="12.9.0.13-1"
8 changes: 4 additions & 4 deletions jenkins/L0_MergeRequest.groovy
@@ -21,10 +21,10 @@ UPLOAD_PATH = env.uploadPath ? env.uploadPath : "sw-tensorrt-generic/llm-artifac
// Container configuration
// available tags can be found in: https://urm.nvidia.com/artifactory/sw-tensorrt-docker/tensorrt-llm/
// [base_image_name]-[arch]-[os](-[python_version])-[trt_version]-[torch_install_type]-[stage]-[date]-[mr_id]
LLM_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505191345-4400"
LLM_SBSA_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-aarch64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505191345-4400"
LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py310-trt10.10.0.31-skip-tritondevel-202505191345-4400"
LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py312-trt10.10.0.31-skip-tritondevel-202505191345-4400"
LLM_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505292346-4931"
LLM_SBSA_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-aarch64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505292346-4931"
LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py310-trt10.10.0.31-skip-tritondevel-202505292346-4931"
LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py312-trt10.10.0.31-skip-tritondevel-202505292346-4931"

// TODO: Move common variables to an unified location
BUILD_CORES_REQUEST = "8"
2 changes: 1 addition & 1 deletion jenkins/controlCCache.groovy
@@ -1,7 +1,7 @@

import java.lang.InterruptedException

DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505191345-4400"
DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505292346-4931"

def createKubernetesPodConfig(image)
{
2 changes: 1 addition & 1 deletion requirements.txt
@@ -47,7 +47,7 @@ setuptools<80
ordered-set
peft
einops
flashinfer-python~=0.2.3
flashinfer-python==0.2.5
opencv-python-headless
xgrammar==0.1.16
backoff
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
@@ -311,7 +311,7 @@ def _(
@torch.library.register_fake(
"trtllm::mtp_sampling_and_accepted_draft_tokens_op")
def _(logits: torch.Tensor, draft_tokens: torch.Tensor,
target_tokens: torch.Tensor, num_mtp_modules: int, batch_size: int,
target_tokens: torch.Tensor, target_token_logprobs: torch.Tensor, num_mtp_modules: int, batch_size: int,
num_context_request: int, vocab_size: int):
return logits.new_empty((batch_size, num_mtp_modules + 1),
dtype=torch.int32), logits.new_empty(
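The updated fake (meta) registration is truncated above. Since the schema now declares three outputs, the shape-only implementation would presumably also return a float32 logprobs tensor; a sketch of what that could look like (an assumption, not necessarily the PR's exact code):

import torch

def _fake_mtp_sampling_and_accept(logits, draft_tokens, target_tokens, target_token_logprobs,
                                  num_mtp_modules, batch_size, num_context_request, vocab_size):
    # Shape/dtype-only outputs mirroring the real op: accepted tokens,
    # accepted-token counts, and per-token log-probabilities.
    accepted_tokens = logits.new_empty((batch_size, num_mtp_modules + 1), dtype=torch.int32)
    num_accepted_tokens = logits.new_empty((batch_size,), dtype=torch.int32)
    logprobs = logits.new_empty((batch_size, num_mtp_modules + 1), dtype=torch.float32)
    return accepted_tokens, num_accepted_tokens, logprobs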
8 changes: 7 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -1,3 +1,4 @@
import copy
from typing import List, Optional

import torch
@@ -265,8 +266,13 @@ def create_response(
use_fast_logits=False,
mpi_world_rank=0) -> tensorrt_llm.bindings.executor.Response | None:
response = super().create_response(use_fast_logits, mpi_world_rank)
py_result = None
if response:
py_result = copy.copy(self.py_result)
if self.py_result._log_probs:
self.py_result._log_probs = LogProbStorage()
return LlmResponse(response,
self.py_result) if response is not None else None
py_result) if response is not None else None
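The change above hands the response a shallow copy of py_result and then swaps a fresh LogProbStorage into the live request, so log-probs already delivered with this response are not reported again on the next one. A stripped-down sketch of that snapshot-and-reset pattern (stand-in classes, for illustration only):

import copy

class LogProbStorage:
    # Stand-in for the real storage: truthy only once something has been accumulated.
    def __init__(self):
        self.log_probs = []
    def __bool__(self):
        return bool(self.log_probs)

class PyResult:
    def __init__(self):
        self._log_probs = LogProbStorage()

def snapshot_for_response(py_result):
    # The shallow copy keeps a reference to the populated storage; rebinding
    # py_result._log_probs afterwards leaves the snapshot untouched while the
    # request continues with an empty accumulator.
    snapshot = copy.copy(py_result)
    if py_result._log_probs:
        py_result._log_probs = LogProbStorage()
    return snapshot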


def convert_wordlist(word_list) -> List[List[int]]: