Changes from all commits (38 commits)
99fc347  move weights loading related logic to ModelLoader (QiJune, Sep 6, 2025)
018b022  fix (QiJune, Sep 6, 2025)
0242a6e  rebase (QiJune, Sep 8, 2025)
e4a542c  Merge branch 'main' into model_loader (QiJune, Sep 8, 2025)
cf84a58  clean (QiJune, Sep 8, 2025)
746e486  fix (QiJune, Sep 8, 2025)
2c92b6e  Merge branch 'main' into model_loader (QiJune, Sep 9, 2025)
38505f5  Merge branch 'main' into model_loader (QiJune, Sep 9, 2025)
c64d320  fix ci (QiJune, Sep 9, 2025)
e0e5bf8  rebase (QiJune, Sep 9, 2025)
18e79fe  Merge branch 'main' into model_loader (QiJune, Sep 16, 2025)
96e700e  Merge branch 'main' into model_loader (QiJune, Sep 16, 2025)
1a9b420  Merge branch 'main' into model_loader (QiJune, Sep 17, 2025)
91d79d6  Merge branch 'main' into model_loader (QiJune, Sep 17, 2025)
9656901  Merge branch 'main' into model_loader (QiJune, Sep 18, 2025)
133a9eb  Merge branch 'main' into model_loader (QiJune, Sep 18, 2025)
9556a94  Merge branch 'main' into model_loader (QiJune, Sep 18, 2025)
1ae8cfd  rebase (QiJune, Sep 19, 2025)
fd2a93e  Merge branch 'main' into model_loader (QiJune, Sep 19, 2025)
41291db  Merge branch 'main' into model_loader (QiJune, Sep 20, 2025)
bd0a47a  Merge branch 'main' into model_loader (QiJune, Sep 22, 2025)
199face  clean (QiJune, Sep 22, 2025)
c3a01d6  fix ci (QiJune, Sep 22, 2025)
7169c51  fix (QiJune, Sep 22, 2025)
cf13d0f  Merge branch 'main' into model_loader (QiJune, Sep 22, 2025)
4a05fb8  Merge branch 'main' into model_loader (QiJune, Sep 22, 2025)
21ff657  fix (QiJune, Sep 23, 2025)
810e37d  Merge branch 'main' into model_loader (QiJune, Sep 23, 2025)
8904e9c  clean (QiJune, Sep 24, 2025)
1ff33a0  Merge branch 'main' into model_loader (QiJune, Sep 24, 2025)
797f5e1  fix (QiJune, Sep 24, 2025)
fa85991  Merge branch 'main' into model_loader (QiJune, Sep 24, 2025)
382f419  rebase (QiJune, Sep 25, 2025)
336c2ef  [None][feat] DeepEP LL fp8 dispatch/combine (#7927) (yilin-void, Sep 25, 2025)
98726a3  [None][chore] Update trtllm-bench documentation on setting FP8 KV cac… (achartier, Sep 25, 2025)
bb60671  [None][chroe] Update the cuda and tensorrt version in homepage icons.… (nv-guomingz, Sep 25, 2025)
0945403  [TRTLLM-6541][test] Add NIM perf test cases (#7924) (fredricz-20070104, Sep 25, 2025)
14a98ba  Merge branch 'main' into model_loader (QiJune, Sep 25, 2025)
README.md: 4 changes (2 additions, 2 deletions)
@@ -7,8 +7,8 @@ TensorRT-LLM
[![Documentation](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](https://nvidia.github.io/TensorRT-LLM/)
[![python](https://img.shields.io/badge/python-3.12-green)](https://www.python.org/downloads/release/python-3123/)
[![python](https://img.shields.io/badge/python-3.10-green)](https://www.python.org/downloads/release/python-31012/)
[![cuda](https://img.shields.io/badge/cuda-12.9.1-green)](https://developer.nvidia.com/cuda-downloads)
[![trt](https://img.shields.io/badge/TRT-10.11.0-green)](https://developer.nvidia.com/tensorrt)
[![cuda](https://img.shields.io/badge/cuda-13.0.0-green)](https://developer.nvidia.com/cuda-downloads)
[![trt](https://img.shields.io/badge/TRT-10.13.2-green)](https://developer.nvidia.com/tensorrt)
[![version](https://img.shields.io/badge/release-1.1.0rc6-green)](./tensorrt_llm/version.py)
[![license](https://img.shields.io/badge/license-Apache%202-blue)](./LICENSE)

cpp/tensorrt_llm/deep_ep/CMakeLists.txt: 2 changes (1 addition, 1 deletion)
@@ -1,4 +1,4 @@
set(DEEP_EP_COMMIT 515a311f290eb6d9592fcccfcc80c40f5123ca72)
set(DEEP_EP_COMMIT be2582ffe69b5e7d61c3bc9bf7a5316bc48261f9)
set(NVSHMEM_URL_HASH
SHA256=eb2c8fb3b7084c2db86bd9fd905387909f1dfd483e7b45f7b3c3d5fcf5374b5a)

@@ -553,6 +553,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
|| std::is_same_v<T, __nv_fp8_e5m2>) &&!std::is_same_v<WeightType, cutlass::uint4b_t>;
static constexpr bool use_w4afp8
= std::is_same_v<WeightType, cutlass::uint4b_t> && std::is_same_v<T, __nv_fp8_e4m3>;
static constexpr bool use_fp8_input = std::is_same_v<InputType, __nv_fp8_e4m3>;
static_assert(!std::is_same_v<BackBoneType, __nv_fp8_e4m3>, "Current logic requires backbone type to be >=16-bits");
static_assert(!std::is_same_v<OutputType, __nv_fp8_e4m3>, "Current logic requires output type to be >=16-bits");
#else
@@ -1625,7 +1625,7 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
else if constexpr (std::is_same_v<ExpandedActivationsType, __nv_fp8_e4m3>
&& std::is_same_v<InputActivationsType, __nv_fp8_e4m3>)
{
TLLM_CHECK_WITH_INFO(!prequant_scales, "NVFP4 is not supported for AWQ");
TLLM_CHECK_WITH_INFO(!prequant_scales, "FP8 is not supported for AWQ");
return quant_params.mxfp8_mxfp4.fc1.weight_block_scale
? &expandInputRowsKernel<InputActivationsType, ExpandedActivationsType,
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX, false>
@@ -3689,7 +3689,7 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
permuted_token_final_scales_, permuted_row_to_unpermuted_row_, num_rows, hidden_size, experts_per_token,
num_experts_per_node, quant_params, use_per_expert_act_scale, expert_first_token_offset_,
fc1_fp4_act_scale_, input_sf, swizzled_input_sf,
use_w4afp8 ? quant_params.groupwise.fc1.act_scales : nullptr, stream);
(use_w4afp8 && !use_fp8_input) ? quant_params.groupwise.fc1.act_scales : nullptr, stream);
auto const* gemm1_input = gemm1_input_expand;

sync_check_cuda_error(stream);
@@ -4755,6 +4755,7 @@ template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, half, half>;
template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>;
template class CutlassMoeFCRunner<__nv_bfloat16, __nv_fp8_e4m3, __nv_bfloat16>;
template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16>;
template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_fp8_e4m3>;
#endif
#endif
#ifdef ENABLE_FP4
cpp/tensorrt_llm/thop/moeOp.cpp: 15 changes (15 additions, 0 deletions)
@@ -201,6 +201,21 @@ class FusedMoeRunner : public torch::CustomClassHolder
}
switch (mActivationDtype)
{
#ifdef ENABLE_FP8
case c10::ScalarType::Float8_e4m3fn:
{
if (isInt4Quant() and mUseW4GroupScaling)
{
mKernelRunner = std::make_unique<
kernels::CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_fp8_e4m3>>();
}
else
{
C10_THROW_ERROR_FORMATTED(Error, "FP8 activation type is not supported for non-W4A8 quantization");
}
break;
}
#endif
case c10::ScalarType::Half: mKernelRunner = create_weight_quant_runner<half>(); break;
case c10::ScalarType::BFloat16: mKernelRunner = create_weight_quant_runner<__nv_bfloat16>(); break;
default: C10_THROW_ERROR_FORMATTED(Error, "Unsupported activation type for int-type weight");
docs/source/developer-guide/perf-benchmarking.md: 5 changes (3 additions, 2 deletions)
@@ -460,9 +460,10 @@ If you would like to force the KV cache quantization, you can specify the follow
when the checkpoint precision is `null`:

```yaml
kv_cache_dtype: "fp8"
kv_cache_config:
dtype: fp8
```

```{tip}
The two valid values for `kv_cache_dtype` are `auto` and `fp8`.
The two valid values for `kv_cache_config.dtype` are `auto` and `fp8`.
```
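The hunk above replaces the flat `kv_cache_dtype` key with the nested `kv_cache_config.dtype` form. The sketch below is not part of the diff; it only illustrates, under the assumption that the surrounding trtllm-bench docs still pass an extra-options YAML via `--extra_llm_api_options`, how the documented override might be written out.

```python
# Minimal sketch (assumption, not from the diff): write the documented
# kv_cache_config.dtype override to an extra-options YAML for trtllm-bench.
import yaml

extra_options = {
    "kv_cache_config": {
        "dtype": "fp8",  # valid values per the tip above: "auto" or "fp8"
    }
}

with open("extra_llm_api_options.yaml", "w") as f:
    yaml.safe_dump(extra_options, f)

# Then, assuming trtllm-bench accepts an extra-options file:
#   trtllm-bench throughput --model <model> --extra_llm_api_options extra_llm_api_options.yaml ...
```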
tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py: 18 changes (12 additions, 6 deletions)
@@ -2,7 +2,7 @@
# https://github.com/deepseek-ai/DeepEP/blob/aae9fa9a6dd0fec2a723fbb85ec4b22460fab670/README.md
import os
import weakref
from typing import List, Tuple, Union
from typing import List, Optional, Tuple, Union

import torch

@@ -179,12 +179,18 @@ def low_latency_dispatch_fp4(self, hidden_states: torch.Tensor,

return recv_hidden_states, recv_scales, recv_expert_count, handle

def low_latency_combine_fp4(self, hidden_states: torch.Tensor,
global_scales: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor, handle: Tuple):
def low_latency_combine_low_precision(self, precision: str,
hidden_states: torch.Tensor,
global_scales: Optional[torch.Tensor],
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
handle: Tuple):
"""
Arguments:
precision: the precision of the low-precision kernel, "fp8" for FP8, "nvfp4" for NVFP4.
"""
combined_hidden_states, event, hook = \
self.buffer.low_latency_combine_fp4(hidden_states, global_scales, topk_idx, topk_weights, handle)
self.buffer.low_latency_combine_low_precision(precision, hidden_states, global_scales, topk_idx, topk_weights, handle)
assert event.event is None
assert hook is None

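For readers of the renamed API above, here is a minimal sketch (not part of the diff) of how a caller could pick `precision` and `global_scales` for `low_latency_combine_low_precision`. It mirrors the call site that appears later in this PR in fused_moe_wide_ep.py; `buffer` stands in for a deep_ep_utils buffer instance and everything else is a placeholder.

```python
from typing import Optional, Tuple

import torch


def combine_low_precision(buffer, final_hidden_states: torch.Tensor,
                          recv_expert_count: torch.Tensor,
                          topk_idx: torch.Tensor, topk_weights: torch.Tensor,
                          handle: Tuple, has_nvfp4: bool) -> torch.Tensor:
    """Illustrative mirror of the call site in fused_moe_wide_ep.py below."""
    if has_nvfp4:
        # NVFP4 combine needs per-expert global scales.
        precision = "nvfp4"
        global_scales: Optional[torch.Tensor] = \
            torch.ops.trtllm.calculate_nvfp4_global_scale(
                final_hidden_states, recv_expert_count)
    else:
        # FP8 combine (w4afp8 or fp8 qdq paths) passes no global scales.
        precision = "fp8"
        global_scales = None
    return buffer.low_latency_combine_low_precision(
        precision, final_hidden_states, global_scales, topk_idx, topk_weights,
        handle)
```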
tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py: 154 changes (86 additions, 68 deletions)
@@ -20,8 +20,8 @@
from .ops import MoEOp, MoEOpSelector
from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod,
DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm,
FP8QDQFusedMoEMethod, MoEWeightLoadingMode,
NVFP4CutlassFusedMoEMethod,
FP8QDQFusedMoEMethod, FusedMoEQuantScalesW4A8,
MoEWeightLoadingMode, NVFP4CutlassFusedMoEMethod,
UnquantizedFusedMoEMethod, WInt4AFP8FusedMoEMethod)
from .routing import BaseMoeRoutingMethod

@@ -191,13 +191,10 @@ def __init__(
self.use_postquant_alltoall = False
self.use_low_precision_combine = False
if self.enable_alltoall:
qm = self.quant_config.quant_mode
self.use_postquant_alltoall = (os.environ.get(
"TRTLLM_MOE_POST_QUANT_ALLTOALLV", "1")
== "1") and qm.has_nvfp4()
"TRTLLM_MOE_POST_QUANT_ALLTOALLV", "1") == "1")
self.use_low_precision_combine = (os.environ.get(
"TRTLLM_MOE_USE_LOW_PRECISION_COMBINE", "0")
== "1") and qm.has_nvfp4()
"TRTLLM_MOE_USE_LOW_PRECISION_COMBINE", "0") == "1")

if self.alltoall_method_type == AlltoallMethodType.MNNVL:
MnnvlMemory.initialize()
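The `__init__` hunk above drops the `has_nvfp4()` guard, so the two environment variables alone now gate post-quant alltoall and low-precision combine for any quantized wide-EP MoE. A tiny sketch, assuming the variables are read at module construction time as shown above:

```python
import os

# Illustrative only: set before the MoE module is constructed so the reads
# above pick these up. Defaults per the diff: post-quant alltoall "1",
# low-precision combine "0".
os.environ.setdefault("TRTLLM_MOE_POST_QUANT_ALLTOALLV", "1")
os.environ["TRTLLM_MOE_USE_LOW_PRECISION_COMBINE"] = "1"
```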
@@ -319,6 +316,35 @@ def can_use_alltoall(self, all_rank_num_tokens, all_rank_max_num_tokens):

return self.enable_alltoall

def deep_ep_low_latency_dispatch_modify_output_to_adapt_fused_moe(
self, x: torch.Tensor, x_sf: Optional[torch.Tensor],
recv_expert_count: torch.Tensor, final_scales_dtype: torch.dtype
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor,
torch.Tensor]:
# x shape: [#local experts, EP size * all_rank_max_num_tokens, hidden_size]
# recv_expert_count shape: [#local experts]

# Adapter between `torch.ops.trtllm.fused_moe` and DeepEP
# TODO: remove the adapter by changing `torch.ops.trtllm.fused_moe` API
mask = torch.arange(x.shape[1],
dtype=torch.int32, device=x.device).expand(
x.shape[0],
x.shape[1]) < recv_expert_count.unsqueeze(1)
token_selected_slots = torch.where(
mask,
torch.arange(x.shape[0] * self.mapping.moe_ep_rank,
x.shape[0] * (self.mapping.moe_ep_rank + 1),
dtype=torch.int32,
device=x.device).unsqueeze(1), self.num_slots)
x = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
if x_sf is not None:
x_sf = x_sf.reshape(x_sf.shape[0] * x_sf.shape[1], x_sf.shape[2])
# Cheat the fused_moe API with fake top_k=1
token_selected_slots = token_selected_slots.view(x.shape[0], 1)
token_final_scales = torch.ones_like(token_selected_slots,
dtype=final_scales_dtype)
return x, x_sf, token_selected_slots, token_final_scales

def _get_quant_method(self):
if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant(
exclude_kv_cache=True):
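To make the new `deep_ep_low_latency_dispatch_modify_output_to_adapt_fused_moe` helper above easier to follow, here is a standalone toy reproduction of its mask/slot construction; the shapes and the `ep_rank`/`num_slots` values are made up for illustration.

```python
import torch

# Toy reproduction of the adapter logic above.
num_local_experts, capacity, hidden = 2, 4, 8
ep_rank, num_slots = 1, 6          # global slots 2 and 3 live on this rank
x = torch.zeros(num_local_experts, capacity, hidden)
recv_expert_count = torch.tensor([3, 1], dtype=torch.int32)  # valid rows per local expert

mask = torch.arange(capacity, dtype=torch.int32).expand(
    num_local_experts, capacity) < recv_expert_count.unsqueeze(1)
token_selected_slots = torch.where(
    mask,
    torch.arange(num_local_experts * ep_rank, num_local_experts * (ep_rank + 1),
                 dtype=torch.int32).unsqueeze(1),
    num_slots)
# token_selected_slots is now:
#   [[2, 2, 2, 6],
#    [3, 6, 6, 6]]
# Rows beyond recv_expert_count point at num_slots (a dummy slot); everything is
# then flattened and handed to fused_moe with a fake top_k of 1 and unit scales.
x = x.reshape(num_local_experts * capacity, hidden)
token_selected_slots = token_selected_slots.view(x.shape[0], 1)
token_final_scales = torch.ones_like(token_selected_slots, dtype=torch.float32)
```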
@@ -468,7 +494,7 @@ def forward_chunk(
use_allgather = not use_all_to_all

# If alltoall is disabled, we need also disable use_postquant_alltoall
use_postquant_alltoall = self.use_postquant_alltoall and use_all_to_all
use_postquant_alltoall = self.use_postquant_alltoall and use_all_to_all and self.has_any_quant

# Prepare additional information for profiling in case padding is applied when using alltoall.
# Only the non-alltoall case is considered for profiling in the warmup phase.
@@ -518,28 +544,8 @@
assert all_rank_max_num_tokens <= self.deep_ep_max_num_tokens
x, recv_expert_count, deep_ep_handle = \
self.deep_ep_buffer.low_latency_dispatch(x, deep_ep_topk_idx, all_rank_max_num_tokens, self.num_slots)
# x shape: [#local experts, EP size * all_rank_max_num_tokens, hidden_size]
# recv_expert_count shape: [#local experts]

# Adapter between `torch.ops.trtllm.fused_moe` and DeepEP
# TODO: remove the adapter by changing `torch.ops.trtllm.fused_moe` API
mask = torch.arange(
x.shape[1], dtype=torch.int32, device=x.device).expand(
x.shape[0],
x.shape[1]) < recv_expert_count.unsqueeze(1)
token_selected_slots = torch.where(
mask,
torch.arange(
x.shape[0] * self.mapping.moe_ep_rank,
x.shape[0] * (self.mapping.moe_ep_rank + 1),
dtype=torch.int32,
device=x.device).unsqueeze(1), self.num_slots)
x = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
# Cheat the fused_moe API with fake top_k=1
token_selected_slots = token_selected_slots.view(
x.shape[0], 1)
token_final_scales = torch.ones_like(
token_selected_slots, dtype=token_final_scales.dtype)
x, _, token_selected_slots, token_final_scales = self.deep_ep_low_latency_dispatch_modify_output_to_adapt_fused_moe(
x, None, recv_expert_count, token_final_scales.dtype)

x_sf = None
x_row = x.shape[0]
@@ -621,41 +627,48 @@ def forward_chunk(
if x_sf is not None:
x_sf = x_sf.view(x_sf_dtype)
elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
token_num = x_row
hidden_size = x_col
assert x_sf is not None and self.has_nvfp4
assert hidden_size % 32 == 0
assert x.dtype == torch.uint8 and x_sf.dtype == torch.uint8
assert x_sf.shape[0] == token_num and x_sf.shape[
1] == hidden_size // 16
assert x.shape[0] == token_num and x.shape[1] == hidden_size // 2

assert self.has_any_quant, "DeepEPLowLatency postquant alltoall should have quantization"
assert all_rank_max_num_tokens <= self.deep_ep_max_num_tokens
deep_ep_topk_idx = token_selected_slots
deep_ep_topk_weights = token_final_scales

assert all_rank_max_num_tokens <= self.deep_ep_max_num_tokens
x, x_sf, recv_expert_count, deep_ep_handle = \
self.deep_ep_buffer.low_latency_dispatch_fp4(x, x_sf, deep_ep_topk_idx, all_rank_max_num_tokens, self.num_slots)
assert x.dtype == torch.uint8 and x_sf.dtype == torch.uint8
assert x.dim() == 3 and x_sf.dim() == 3
assert x.shape[2] == hidden_size // 2 and x_sf.shape[
2] == hidden_size // 16

mask = torch.arange(
x.shape[1], dtype=torch.int32, device=x.device).expand(
x.shape[0], x.shape[1]) < recv_expert_count.unsqueeze(1)
token_selected_slots = torch.where(
mask,
torch.arange(x.shape[0] * self.mapping.moe_ep_rank,
x.shape[0] * (self.mapping.moe_ep_rank + 1),
dtype=torch.int32,
device=x.device).unsqueeze(1), self.num_slots)
x = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
x_sf = x_sf.reshape(x_sf.shape[0] * x_sf.shape[1],
x_sf.shape[2])
token_selected_slots = token_selected_slots.view(x.shape[0], 1)
token_final_scales = torch.ones_like(
token_selected_slots, dtype=token_final_scales.dtype)
if self.has_fp8_qdq:
assert x.dtype == torch.float8_e4m3fn and x_sf is None, "x should be torch.float8_e4m3fn and x_sf should be None in fp8 postquant alltoall"
x = x.view(torch.bfloat16)
x, recv_expert_count, deep_ep_handle = \
self.deep_ep_buffer.low_latency_dispatch(x, deep_ep_topk_idx, all_rank_max_num_tokens, self.num_slots)
x = x.view(torch.float8_e4m3fn)
elif self.has_nvfp4:
token_num = x_row
hidden_size = x_col
assert x.dtype == torch.uint8 and x_sf is not None and x_sf.dtype == torch.uint8
assert hidden_size % 32 == 0, "HiddenSize should be divisible by 32 in nvfp4 postquant alltoall"
assert x_sf.shape[0] == token_num and x_sf.shape[
1] == hidden_size // 16
assert x.shape[0] == token_num and x.shape[
1] == hidden_size // 2

x, x_sf, recv_expert_count, deep_ep_handle = \
self.deep_ep_buffer.low_latency_dispatch_fp4(x, x_sf, deep_ep_topk_idx, all_rank_max_num_tokens, self.num_slots)
assert x.dtype == torch.uint8 and x_sf.dtype == torch.uint8
assert x.dim() == 3 and x_sf.dim() == 3
assert x.shape[2] == hidden_size // 2 and x_sf.shape[
2] == hidden_size // 16
elif self.has_w4afp8:
assert isinstance(quant_scales, FusedMoEQuantScalesW4A8)
pre_quant_scales = quant_scales.pre_quant_scale_1
assert pre_quant_scales.shape == (
1, x.shape[1]) and pre_quant_scales.dtype == x.dtype
x = (x * pre_quant_scales).to(torch.float8_e4m3fn).view(
torch.bfloat16)
x, recv_expert_count, deep_ep_handle = \
self.deep_ep_buffer.low_latency_dispatch(x, deep_ep_topk_idx, all_rank_max_num_tokens, self.num_slots)
x = x.view(torch.float8_e4m3fn)
else:
raise ValueError(
f"unsupported quantization mode in postquant alltoall: {self.quant_config.quant_mode}"
)
x, x_sf, token_selected_slots, token_final_scales = self.deep_ep_low_latency_dispatch_modify_output_to_adapt_fused_moe(
x, x_sf, recv_expert_count, token_final_scales.dtype)
else:
raise NotImplementedError(
f"Not available alltoall method type: {self.alltoall_method_type!r}"
@@ -704,11 +717,16 @@ def forward_chunk(
self.expert_size_per_partition,
num_tokens_per_expert_for_fused_moe, self.hidden_size)
if self.use_low_precision_combine:
global_scales = torch.ops.trtllm.calculate_nvfp4_global_scale(
final_hidden_states, recv_expert_count)
final_hidden_states = self.deep_ep_buffer.low_latency_combine_fp4(
final_hidden_states, global_scales, deep_ep_topk_idx,
deep_ep_topk_weights, deep_ep_handle)
assert self.has_nvfp4 or self.has_w4afp8 or self.has_fp8_qdq, "Low precision combine only supports nvfp4, w4afp8 and fp8 qdq"
precision = "fp8"
global_scales = None
if self.has_nvfp4:
precision = "nvfp4"
global_scales = torch.ops.trtllm.calculate_nvfp4_global_scale(
final_hidden_states, recv_expert_count)
final_hidden_states = self.deep_ep_buffer.low_latency_combine_low_precision(
precision, final_hidden_states, global_scales,
deep_ep_topk_idx, deep_ep_topk_weights, deep_ep_handle)
else:
final_hidden_states = self.deep_ep_buffer.low_latency_combine(
final_hidden_states, deep_ep_topk_idx,
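One detail of the fp8 and w4afp8 post-quant dispatch branches above that is easy to miss: the FP8 activation buffer is reinterpreted as bfloat16 (`view(torch.bfloat16)`) so it can travel through the bf16 `low_latency_dispatch`, then viewed back to FP8 afterwards, which halves and then restores the last dimension. A standalone sketch of just that reinterpretation, assuming a recent PyTorch with `float8_e4m3fn` support:

```python
import torch

# An FP8 row of H bytes is reinterpreted as H/2 bf16 elements for dispatch,
# then viewed back to FP8; no bytes are changed by the two views.
hidden_size = 8
x_fp8 = torch.randn(4, hidden_size).to(torch.float8_e4m3fn)   # [tokens, H], 1 byte/elem

x_as_bf16 = x_fp8.view(torch.bfloat16)                        # [tokens, H // 2], 2 bytes/elem
assert x_as_bf16.shape == (4, hidden_size // 2)

# ... low_latency_dispatch(x_as_bf16) would happen here ...

x_back = x_as_bf16.view(torch.float8_e4m3fn)                  # back to [tokens, H]
assert torch.equal(x_back.view(torch.uint8), x_fp8.view(torch.uint8))
```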