453 changes: 451 additions & 2 deletions csrc/trtllm_fused_moe_kernel_launcher.cu

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion csrc/trtllm_fused_moe_runner.cu
@@ -138,7 +138,6 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
routingData.mDoSoftmaxBeforeTopK = routingMethodType == RoutingMethodType::RenormalizeNaive;
routingData.mNormTopkProb = routingMethodType == RoutingMethodType::RenormalizeNaive;
routingData.mApplySoftmaxAfterTopK = routingMethodType == RoutingMethodType::Renormalize;

routingData.mPtrScores = routingLogits;

//
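The three flags set in this hunk differ only in where the softmax sits relative to the top-k selection. A minimal Python sketch of the semantics these flags appear to encode (illustrative host-side code only, not the device routing kernel):

```python
import torch

def renormalize_naive(logits: torch.Tensor, top_k: int):
    # Softmax over all experts first, then select top-k and renormalize the
    # selected probabilities (mDoSoftmaxBeforeTopK / mNormTopkProb).
    probs = torch.softmax(logits, dim=-1)
    topk_probs, topk_idx = torch.topk(probs, top_k, dim=-1)
    return topk_probs / topk_probs.sum(dim=-1, keepdim=True), topk_idx

def renormalize(logits: torch.Tensor, top_k: int):
    # Select top-k logits first, then softmax over just the selected experts
    # (mApplySoftmaxAfterTopK).
    topk_logits, topk_idx = torch.topk(logits, top_k, dim=-1)
    return torch.softmax(topk_logits, dim=-1), topk_idx
```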
4 changes: 2 additions & 2 deletions flashinfer/artifacts.py
@@ -76,7 +76,7 @@ def get_available_cubin_files(
class ArtifactPath:
TRTLLM_GEN_FMHA: str = "7206d64e67f4c8949286246d6e2e07706af5d223/fmha/trtllm-gen"
TRTLLM_GEN_BMM: str = (
"e6f22dcc3fdeb29ff87af2f4a2cb3d30b8d273e0/batched_gemm-45beda1-ee6a802"
"a72d85b019dc125b9f711300cb989430f762f5a6/batched_gemm-145d1b1-9e1d49a"
)
TRTLLM_GEN_GEMM: str = (
"037e528e719ec3456a7d7d654f26b805e44c63b1/gemm-8704aa4-f91dc9e"
@@ -91,7 +91,7 @@ class MetaInfoHash:
"2f605255e71d673768f5bece66dde9e2e9f4c873347bfe8fefcffbf86a3c847d"
)
TRTLLM_GEN_BMM: str = (
"c98b4ce69a39fd41556d67033c30ea814ef76b0a2fe16e798e55baf0104acc34"
"8c5f97d582df0e4fd9f69ddeb3b72cc3a37915c6c20b4d0905fec69702310b63"
)
DEEPGEMM: str = "b4374f857c3066089c4ec6b5e79e785559fa2c05ce2623710b0b04bf86414a48"
TRTLLM_GEN_GEMM: str = (
2 changes: 2 additions & 0 deletions flashinfer/fused_moe/__init__.py
@@ -29,6 +29,7 @@
trtllm_fp4_block_scale_routed_moe,
trtllm_fp8_block_scale_moe,
trtllm_fp8_per_tensor_scale_moe,
trtllm_bf16_moe,
)

__all__ = [
@@ -44,4 +45,5 @@
"trtllm_fp4_block_scale_moe",
"trtllm_fp8_block_scale_moe",
"trtllm_fp8_per_tensor_scale_moe",
"trtllm_bf16_moe",
]
120 changes: 120 additions & 0 deletions flashinfer/fused_moe/core.py
@@ -1268,6 +1268,81 @@ def refine_tuning_config(cls, tune_max_num_tokens: int):
),
)

@register_custom_op(
"flashinfer::trtllm_bf16_moe",
mutates_args=(""),
)
def trtllm_bf16_moe_op(
routing_logits: torch.Tensor,
routing_bias: Optional[torch.Tensor],
hidden_states: torch.Tensor,
gemm1_weights: torch.Tensor,
gemm2_weights: torch.Tensor,
num_experts: int,
top_k: int,
n_group: int,
topk_group: int,
intermediate_size: int,
local_expert_offset: int,
local_num_experts: int,
tile_tokens_dim: int,
routing_method_type: int,
use_shuffled_weight: bool,
weight_layout: int,
moe_tactic: int,
enable_pdl: Optional[bool] = None,
) -> torch.Tensor:
if enable_pdl is None:
enable_pdl = device_support_pdl(hidden_states.device)
# Call the C++ BF16 MoE function
output = moe_op.trtllm_bf16_moe(
routing_logits,
routing_bias,
hidden_states,
gemm1_weights,
gemm2_weights,
num_experts,
top_k,
n_group,
topk_group,
intermediate_size,
local_expert_offset,
local_num_experts,
tile_tokens_dim,
routing_method_type,
use_shuffled_weight,
weight_layout,
moe_tactic,
enable_pdl,
)
return output

@register_fake_op("flashinfer::trtllm_bf16_moe")
def _fake_trtllm_bf16_moe(
routing_logits: torch.Tensor,
routing_bias: Optional[torch.Tensor],
hidden_states: torch.Tensor,
gemm1_weights: torch.Tensor,
gemm2_weights: torch.Tensor,
num_experts: int,
top_k: int,
n_group: int,
topk_group: int,
intermediate_size: int,
local_expert_offset: int,
local_num_experts: int,
tile_tokens_dim: int,
routing_method_type: int,
use_shuffled_weight: bool,
weight_layout: int,
moe_tactic: int,
enable_pdl: Optional[bool] = None,
):
seq_len = hidden_states.shape[0]
hidden_size = hidden_states.shape[1]

return [hidden_states.new_empty([seq_len, hidden_size], dtype=torch.bfloat16)]

@register_custom_op(
"flashinfer::trtllm_fp8_per_tensor_scale_moe",
mutates_args=(""),
@@ -1658,12 +1733,57 @@ def _fake_trtllm_fp4_block_scale_moe(
return [hidden_states.new_empty([seq_len, hidden_size], dtype=torch.bfloat16)]

return SimpleNamespace(
trtllm_bf16_moe=trtllm_bf16_moe_op,
trtllm_fp8_per_tensor_scale_moe=trtllm_fp8_per_tensor_scale_moe_op,
trtllm_fp8_block_scale_moe=trtllm_fp8_block_scale_moe_op,
trtllm_fp4_block_scale_moe=trtllm_fp4_block_scale_moe_op,
)


def trtllm_bf16_moe(
routing_logits: torch.Tensor,
routing_bias: Optional[torch.Tensor],
hidden_states: torch.Tensor,
gemm1_weights: torch.Tensor,
gemm2_weights: torch.Tensor,
num_experts: int,
top_k: int,
n_group: int,
topk_group: int,
intermediate_size: int,
local_expert_offset: int,
local_num_experts: int,
*,
tile_tokens_dim: int = 8,
routing_method_type: int = 0,
use_shuffled_weight: bool = True,
weight_layout: int = WeightLayout.BlockMajorK,
moe_tactic: int = -1,
enable_pdl: bool = True,
) -> torch.Tensor:
"""BF16 block scale MoE operation."""
return get_trtllm_moe_sm100_module().trtllm_bf16_moe(
routing_logits,
routing_bias,
hidden_states,
gemm1_weights,
gemm2_weights,
num_experts,
top_k,
n_group or 0, # may receive None from test configs, convert to 0
topk_group or 0,
intermediate_size,
local_expert_offset,
local_num_experts,
tile_tokens_dim,
routing_method_type,
use_shuffled_weight,
weight_layout,
moe_tactic,
enable_pdl,
)
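A usage sketch for the new public wrapper (not part of the PR; tensor shapes and dtypes are assumptions based on the fake-op shape logic above and the usual gated-MLP convention, and real deployments must pre-process weights to match use_shuffled_weight/weight_layout):

```python
import torch
from flashinfer.fused_moe import trtllm_bf16_moe

num_tokens, hidden_size, intermediate_size = 128, 4096, 2048
num_experts, top_k = 8, 2

routing_logits = torch.randn(num_tokens, num_experts, device="cuda", dtype=torch.bfloat16)
hidden_states = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.bfloat16)
# Gated MLP: gemm1 packs the gate and up projections, hence the factor of 2.
gemm1_weights = torch.randn(
    num_experts, 2 * intermediate_size, hidden_size, device="cuda", dtype=torch.bfloat16
)
gemm2_weights = torch.randn(
    num_experts, hidden_size, intermediate_size, device="cuda", dtype=torch.bfloat16
)

output = trtllm_bf16_moe(
    routing_logits,
    None,            # routing_bias
    hidden_states,
    gemm1_weights,
    gemm2_weights,
    num_experts,
    top_k,
    None,            # n_group (wrapper converts None to 0)
    None,            # topk_group
    intermediate_size,
    0,               # local_expert_offset
    num_experts,     # local_num_experts
)
# output: [num_tokens, hidden_size] in bfloat16
```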


def trtllm_fp8_per_tensor_scale_moe(
routing_logits: torch.Tensor,
routing_bias: Optional[torch.Tensor],
@@ -506,12 +506,27 @@ class BatchedGemmInterface {
throw std::invalid_argument("Invalid combination of options");
}

int32_t const numCtasTile =
if (batchM) {
numCtasBatch = gemm::divUpMul(numCtasBatch, options.mClusterDimX);
} else {
numCtasBatch = gemm::divUpMul(numCtasBatch, options.mClusterDimY);
}

int32_t numCtasTile =
batchM ? gemm::divUp(options.mN, options.mTileN) : gemm::divUp(options.mM, options.mTileM);
if (batchM) {
numCtasTile = gemm::divUpMul(numCtasTile, options.mClusterDimY);
} else {
numCtasTile = gemm::divUpMul(numCtasTile, options.mClusterDimX);
}
int32_t const numCtasInner = options.mNumSlicesForSplitK;
return std::make_tuple(numCtasBatch, numCtasTile, numCtasInner);
}
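The reworked hunk above pads both the batch and tile CTA counts up to the cluster dimensions so that thread-block clusters are never split across the grid. A minimal Python sketch of the rounding helpers, assuming gemm::divUp and gemm::divUpMul are the usual ceiling-divide and round-up-to-multiple utilities:

```python
def div_up(a: int, b: int) -> int:
    # Ceiling division: smallest k such that k * b >= a.
    return (a + b - 1) // b

def div_up_mul(a: int, b: int) -> int:
    # Round a up to the next multiple of b.
    return div_up(a, b) * b

# e.g. 7 tiles with a cluster dimension of 2 launches 8 CTAs in that dimension,
# so every 2-CTA cluster stays complete.
assert div_up_mul(7, 2) == 8
```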

// Creates GemmOptions from kernel and data.
BatchedGemmOptions getOptionsFromConfigAndData(BatchedGemmConfig const& config,
BatchedGemmData const& data) const;

// Returns the number of CTAs of the current kernel.
int32_t getNumCtas(BatchedGemmOptions const& options,
std::optional<int32_t> maxNumCtasInBatchDim = std::nullopt) const {
@@ -522,10 +537,6 @@ class BatchedGemmInterface {
// Returns true if the configuration of the cubin can be executed for the given params.
bool isValidConfig(BatchedGemmConfig const& config, BatchedGemmData const& data) const;

// Creates GemmOptions from kernel and data.
BatchedGemmOptions getOptionsFromConfigAndData(BatchedGemmConfig const& config,
BatchedGemmData const& data) const;

private:
// Aligns the pointer to the alignment
template <typename Dtype>
@@ -781,6 +792,7 @@ int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspa
if (result != CUDA_SUCCESS) {
return -1;
}

// If a module cache has not been given, unload the module to avoid leaking
if (!moduleCache.has_value()) {
cuModuleUnload(cuModule);
@@ -76,39 +76,43 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions {
// FIXME We create explicit constructor with all options to WAR stubgen issue in TRT-LLM.
BatchedGemmOptions(
gemm::AllReduceAlgo allReduceAlgo, gemm::BiasType biasType, int blockK, int clusterDimX,
int clusterDimY, int clusterDimZ, tg::Dtype dtypeAcc, tg::Dtype dtypeA, tg::Dtype dtypeB,
tg::Dtype dtypeC, tg::Dtype dtypeMmaA, tg::Dtype dtypeMmaB, bool enablesEarlyExit,
bool enablesDelayedEarlyExit, bool enablesGlobalPtxKnobs, int epilogueLdtmDps,
int epilogueLdtmBits, int epilogueTileM, int epilogueTileN, bool gridTriggerSecondaryA,
bool gridTriggerSecondaryB, bool gridWaitForPrimaryEarlyExit, bool gridWaitForPrimaryA,
bool gridWaitForPrimaryB, bool hoistLoadTaskInit, bool hoistMmaTaskTryWaits, int k,
gemm::KernelTraits kernelTraits, gemm::MatrixLayout layoutA, gemm::MatrixLayout layoutB,
int m, int mmaK, tg::MmaKind mmaKind, int mmaM, int mmaN, bool mockAllReduce, int n,
int numSlicesForSplitK, int numSlicesForSliceK, int numStages, int numStagesMma,
int numStagesMmaWithinWorkTile, int numStagesMmaAcrossWorkTile, int numStagesWorkId,
bool outputDebugTensors, bool patchF2fp, std::optional<int32_t> sfBlockSizeA,
tg::SfLayout sfLayoutA, tg::SfLayout sfLayoutB, tg::SfLayout sfLayoutC,
int32_t sfReshapeFactor, bool sliceK, gemm::SplitK splitK, int tileK, int tileM, int tileN,
gemm::TileScheduler tileScheduler, bool transposeMmaOutput, bool useCustomMmaSchedule,
bool useDeepSeekFp8, bool useHoistTryWaitForCustomMmaSchedule, bool usePerTokenSfA,
bool usePerTokenSfB, bool useShuffledMatrixA, bool useTmaStore, bool useTwoTmaLoadWarps,
bool useTwoMmaWarps, bool useUnrollLoop2xForMma, int worldSize, gemmGatedAct::ActType actType,
bool clampBeforeAct, std::vector<int> batchedM, std::vector<int> batchedN,
BatchMode batchMode, int numBatches, bool isStaticBatch, int numTokens, RouteImpl routeImpl,
bool gridWaitForPrimaryRouting, bool fusedAct, int numRegsPerThreadNonEpilogueWarp,
int numRegsPerThreadEpilogueWarp, int numRegsCastAWarps, bool useTmaOobOpt)
int clusterDimY, int clusterDimZ, gemm::CtaSwizzleType ctaSwizzleType, tg::Dtype dtypeAcc,
tg::Dtype dtypeA, tg::Dtype dtypeB, tg::Dtype dtypeC, tg::Dtype dtypeMmaA,
tg::Dtype dtypeMmaB, bool enablesEarlyExit, bool enablesDelayedEarlyExit,
bool enablesGlobalPtxKnobs, int epilogueLdtmDps, int epilogueLdtmBits, int epilogueTileM,
int epilogueTileN, bool gridTriggerSecondaryA, bool gridTriggerSecondaryB,
bool gridWaitForPrimaryEarlyExit, bool gridWaitForPrimaryA, bool gridWaitForPrimaryB,
bool hoistLoadTaskInit, bool hoistMmaTaskTryWaits, int k, gemm::KernelTraits kernelTraits,
gemm::MatrixLayout layoutA, gemm::MatrixLayout layoutB, int m, int mmaK, tg::MmaKind mmaKind,
int mmaM, int mmaN, bool mockAllReduce, int n, int numRegsCastAWarps,
int numRegsCopySfLdsSttm, int numRegsPerThreadEpilogueWarp,
int numRegsPerThreadNonEpilogueWarp, int numSlicesForSplitK, int numSlicesForSliceK,
int numStages, int numStagesMma, int numStagesMmaWithinWorkTile,
int numStagesMmaAcrossWorkTile, int numStagesWorkId, bool outputDebugTensors, bool patchF2fp,
std::optional<int32_t> sfBlockSizeA, tg::SfLayout sfLayoutA, tg::SfLayout sfLayoutB,
tg::SfLayout sfLayoutC, int32_t sfReshapeFactor, bool sliceK, gemm::SplitK splitK, int tileK,
int tileM, int tileN, gemm::TileScheduler tileScheduler, bool transposeMmaOutput,
bool useCustomMmaSchedule, bool useDeepSeekFp8, bool useHoistTryWaitForCustomMmaSchedule,
bool usePerTokenSfA, bool usePerTokenSfB, bool useShuffledMatrixA, bool useTmaStore,
bool useTwoTmaLoadWarps, bool useTwoMmaWarps, bool useUnrollLoop2xForMma, int worldSize,
gemmGatedAct::ActType actType, bool clampBeforeAct, std::vector<int> batchedM,
std::vector<int> batchedN, BatchMode batchMode, int numBatches, bool isStaticBatch,
int numTokens, RouteImpl routeImpl, bool gridWaitForPrimaryRouting, bool fusedAct,
bool useTmaOobOpt)
: gemmGatedAct::GemmGatedActOptions(
gemm::GemmOptions(
allReduceAlgo, biasType, blockK, clusterDimX, clusterDimY, clusterDimZ, dtypeAcc,
dtypeA, dtypeB, dtypeC, dtypeMmaA, dtypeMmaB, enablesEarlyExit,
enablesDelayedEarlyExit, enablesGlobalPtxKnobs, epilogueLdtmDps, epilogueLdtmBits,
epilogueTileM, epilogueTileN, gridTriggerSecondaryA, gridTriggerSecondaryB,
gridWaitForPrimaryEarlyExit, gridWaitForPrimaryA, gridWaitForPrimaryB,
hoistLoadTaskInit, hoistMmaTaskTryWaits, k, kernelTraits, layoutA, layoutB, m, mmaK,
mmaKind, mmaM, mmaN, mockAllReduce, n, numSlicesForSplitK, numSlicesForSliceK,
numStages, numStagesMma, numStagesMmaWithinWorkTile, numStagesMmaAcrossWorkTile,
numStagesWorkId, outputDebugTensors, patchF2fp, sfBlockSizeA, sfLayoutA, sfLayoutB,
sfLayoutC, sfReshapeFactor, sliceK, splitK, tileK, tileM, tileN, tileScheduler,
allReduceAlgo, biasType, blockK, clusterDimX, clusterDimY, clusterDimZ,
ctaSwizzleType, dtypeAcc, dtypeA, dtypeB, dtypeC, dtypeMmaA, dtypeMmaB,
enablesEarlyExit, enablesDelayedEarlyExit, enablesGlobalPtxKnobs, epilogueLdtmDps,
epilogueLdtmBits, epilogueTileM, epilogueTileN, gridTriggerSecondaryA,
gridTriggerSecondaryB, gridWaitForPrimaryEarlyExit, gridWaitForPrimaryA,
gridWaitForPrimaryB, hoistLoadTaskInit, hoistMmaTaskTryWaits, k, kernelTraits,
layoutA, layoutB, m, mmaK, mmaKind, mmaM, mmaN, mockAllReduce, n, numRegsCastAWarps,
numRegsCopySfLdsSttm, numRegsPerThreadEpilogueWarp, numRegsPerThreadNonEpilogueWarp,
numSlicesForSplitK, numSlicesForSliceK, numStages, numStagesMma,
numStagesMmaWithinWorkTile, numStagesMmaAcrossWorkTile, numStagesWorkId,
outputDebugTensors, patchF2fp, sfBlockSizeA, sfLayoutA, sfLayoutB, sfLayoutC,
sfReshapeFactor, sliceK, splitK, tileK, tileM, tileN, tileScheduler,
transposeMmaOutput, useCustomMmaSchedule, useDeepSeekFp8,
useHoistTryWaitForCustomMmaSchedule, usePerTokenSfA, usePerTokenSfB,
useShuffledMatrixA, useTmaStore, useTwoTmaLoadWarps, useTwoMmaWarps,
@@ -121,9 +125,6 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions {
mGridWaitForPrimaryRouting(gridWaitForPrimaryRouting),
mIsStaticBatch(isStaticBatch),
mNumBatches(numBatches),
mNumRegsPerThreadNonEpilogueWarp(numRegsPerThreadNonEpilogueWarp),
mNumRegsPerThreadEpilogueWarp(numRegsPerThreadEpilogueWarp),
mNumRegsCastAWarps(numRegsCastAWarps),
mNumTokens(numTokens),
mRouteImpl(routeImpl),
mUseTmaOobOpt(useTmaOobOpt) {}
@@ -143,12 +144,6 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions {
bool mIsStaticBatch{true};
// Number of Gemm batches.
int mNumBatches;
// Number of registers per thread for non-epilogue warps
int mNumRegsPerThreadNonEpilogueWarp{0};
// Number of registers per thread for epilogue warps
int mNumRegsPerThreadEpilogueWarp{0};
// Number of registers for the cast A warps.
int mNumRegsCastAWarps{0};
// Total number of tokens.
int mNumTokens{32};
// Whether load the input tokens and do routing.
@@ -269,16 +264,8 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw
}
}

if (doesRouteImplUseTma(options.mRouteImpl)) {
TLLM_CHECK_ERROR(!batchM, "UTMALDG.GATHER4 only supported for batch N.");

if (tg::mmaKindIsBlockFmt(options.mMmaKind)) {
auto dtypeRoute = batchM ? options.mDtypeA : options.mDtypeB;
TLLM_CHECK_ERROR(options.mTileK % tg::dtypeNumEltsPerSf(dtypeRoute) == 0,
"tileK needs to be a multiple of 16 * tg::dtypeNumEltsPerSf(dtypeA).");
TLLM_CHECK_ERROR(options.mTileK % (tg::dtypeNumEltsPerSf(dtypeRoute) * 16) == 0,
"tileK needs to be a multiple of 16 * tg::dtypeNumEltsPerSf(dtypeA).");
}
if (options.mClusterDimX > 1) {
TLLM_CHECK_ERROR(!batchM, "2CTA Gemm currently only supports batch N.");
}

if (!batchM || doesRouteImplUseNoRoute(options.mRouteImpl)) {
@@ -323,6 +310,7 @@ struct BatchedGemmConfig {
char const* mHash{nullptr};
#else
trtllm::gen::CudaRunner* mCudaRunner{nullptr};
int32_t mInstanceIdx{0};
#endif

BatchedGemmOptions mOptions;
@@ -345,11 +333,6 @@ inline std::string dumpOptions(BatchedGemmOptions const& options) {
<< std::endl;
ss << "mGridWaitForPrimaryRouting=" << options.mGridWaitForPrimaryRouting << "," << std::endl;
ss << "mFusedAct=" << options.mFusedAct << "," << std::endl;
ss << "mNumRegsPerThreadNonEpilogueWarp=" << options.mNumRegsPerThreadNonEpilogueWarp << ","
<< std::endl;
ss << "mNumRegsPerThreadEpilogueWarp=" << options.mNumRegsPerThreadEpilogueWarp << ","
<< std::endl;
ss << "mNumRegsCastAWarps=" << options.mNumRegsCastAWarps << "," << std::endl;
ss << "mUseTmaOobOpt=" << options.mUseTmaOobOpt << std::endl;
return ss.str();
}
@@ -97,6 +97,23 @@ enum class TileScheduler {

////////////////////////////////////////////////////////////////////////////////////////////////////

enum class CtaSwizzleType : uint32_t {
// Rasterize CTAs along the M dimension.
RasterizeAlongM = 0,
// Rasterize CTAs along the N dimension.
RasterizeAlongN,
// Swizzle CTAs in zig-zag pattern along M dimension, Zig-zag width is 2.
ZigZagAlongM2,
// Swizzle CTAs in zig-zag pattern along N dimension, Zig-zag width is 2.
ZigZagAlongN2,
// Swizzle CTAs in zig-zag pattern along M dimension, Zig-zag width is 4.
ZigZagAlongM4,
// Swizzle CTAs in zig-zag pattern along N dimension, Zig-zag width is 4.
ZigZagAlongN4,
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Helper functions to check the SplitK type.

#define SPLIT_K_FUNCTION(Mode) \
@@ -191,6 +191,7 @@ struct GemmGatedActConfig {
char const* mHash{nullptr};
#else
trtllm::gen::CudaRunner* mCudaRunner{nullptr};
int32_t mInstanceIdx{0};
#endif

GemmGatedActOptions mOptions{};