From 71052dfb10faebdf578673685cb1a718c51cd646 Mon Sep 17 00:00:00 2001 From: Alex Yang Date: Fri, 26 Sep 2025 00:08:13 +0000 Subject: [PATCH 01/12] bf16 --- csrc/trtllm_fused_moe_kernel_launcher.cu | 481 +++++++++++++++++++++++ flashinfer/artifacts.py | 2 +- flashinfer/fused_moe/__init__.py | 2 + flashinfer/fused_moe/core.py | 145 +++++++ tests/test_trtllm_gen_fused_moe.py | 195 ++++++++- 5 files changed, 815 insertions(+), 10 deletions(-) diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index da9b6d630f..5cd2e227fb 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -39,9 +39,468 @@ namespace flashinfer { namespace btg = batchedGemm::trtllm::gen; +using batchedGemm::gemm::MatrixLayout; using tensorrt_llm::kernels::trtllmgen_moe::MoE::GatedActType; using tensorrt_llm::kernels::trtllmgen_moe::Routing::RoutingMethodType; +/* + +Abstraction layers: + +1. TORCH_LIBRARY_FRAGMENT bindings +These are currently the same signature as the public python APIs. +We strive to make the python interface relatively stable +and the naming of parameters meaningful to the users. + +2. FusedMoeLauncher +This performs checks and preparations for the execution, +organized in several stages, see FusedMoeLauncher::run(). + +3. MoE::Runner +Orchestrate and dispatch all the kernels executions to fulfill the requested operation. +This includes PermuteGemm1, Gemm2, activation (if not fused), and finalize. + +4. TrtllmGenBatchedGemmRunner +This provides tactic selection if not determined yet at the public API (or auto-tuning) + +5. BatchedGemm Runner +The low-level gemm kernel executor which is updated together with the kernels. + +6. BatchedGemmInterface +Driver calls take place to carry out the gemm operations. +*/ + +class FusedMoeLauncher { + protected: + at::Tensor const* routing_logits{}; + at::Tensor const* routing_bias{}; + at::Tensor const* hidden_states{}; + at::Tensor const* gemm1_weights{}; + at::Tensor const* output1_scales_scalar{}; + at::Tensor const* output1_scales_gate_scalar{}; + at::Tensor const* gemm2_weights{}; + at::Tensor const* output2_scales_scalar{}; + + int64_t tile_tokens_dim{}; + int64_t routing_method_type{}; + bool use_shuffled_weight{}; + MatrixLayout weight_layout{MatrixLayout::MajorK}; + + std::tuple device_version; + std::unique_ptr args; + tensorrt_llm::kernels::trtllmgen_moe::MoE::MoEWorkspace workspace; + + btg::Dtype mDtypeAct{btg::Dtype::Bfloat16}; + btg::Dtype mDtypeWeights{btg::Dtype::Bfloat16}; + GatedActType gated_act_type{GatedActType::SwiGlu}; + + // Initialize common data necessary for later. + // May throw exception from TORCH_CHECK. 
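+  // Expected call sequence (a sketch of how the derived launchers below drive this base class):
+  //   1. The derived init() forwards the tensors and args here through init_common().
+  //   2. run() then calls check_routing()/prepare_routing(), launches the routing kernels,
+  //      calls check_moe()/prepare_moe(), and finally launches the MoE kernels.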
+ void init_common(at::Tensor const* routing_logits, at::Tensor const* routing_bias, + at::Tensor const* hidden_states, at::Tensor const* gemm1_weights, + at::Tensor const* gemm2_weights, + std::unique_ptr&& args, + int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, + int64_t weight_layout, int64_t gated_act_type); + + // Routing logits [num_tokens, num_experts] + void check_routing_logits_shape() const { + TORCH_CHECK(routing_logits->dim() == 2, "routing_logits must be 2D."); + TORCH_CHECK(routing_logits->sizes()[0] == hidden_states->sizes()[0], + "routing_logits and hidden_states must have the same number of tokens."); + TORCH_CHECK(routing_logits->sizes()[1] == args->num_experts, + "routing_logits dim1 must match num_experts."); + } + + // Routing bias [num_experts] + void check_routing_bias_shape() const { + if (routing_bias != nullptr) { + TORCH_CHECK(routing_bias->dim() == 1, "routing_bias must be 1D."); + TORCH_CHECK(routing_bias->sizes()[0] == args->num_experts, + "routing_bias has incorrect shape."); + } + } + + // Hidden states [num_tokens, hidden_size] + void check_hidden_states_shape() const { + TORCH_CHECK(hidden_states->dim() == 2, "hidden_states must be 2D."); + TORCH_CHECK(hidden_states->sizes()[1] == args->intermediate_size, + "hidden_states has incorrect shape."); + } + + // GEMM1 or GEMM2 weights [num_experts, M, K] or [num_experts, K/block_k, M, block_k] + void check_weights_shape(std::string which_weights) const { + at::Tensor const* weights{}; + if (which_weights == "gemm1") { + weights = gemm1_weights; + } else if (which_weights == "gemm2") { + weights = gemm2_weights; + } else { + TORCH_CHECK(false, "Internal error: which_weights = ", which_weights); + } + + int64_t Mn = 0, K = 0; + if (weight_layout == MatrixLayout::MajorK) { + // MajorK [num_experts, M, K] + Mn = weights->sizes()[1]; + K = weights->sizes()[2]; + } else if (weight_layout == MatrixLayout::BlockMajorK) { + // BlockMajorK [num_experts, K/block_k, M, block_k] + Mn = weights->sizes()[2]; + int64_t block_k = weights->sizes()[3]; + K = weights->sizes()[1] * block_k; + } else { + TORCH_CHECK(false, "Unsupported weight_layout: ", weight_layout); + } + TORCH_CHECK(weights->sizes()[0] == args->num_experts, + which_weights + " weights expert dimension must match num_experts"); + if (which_weights == "gemm1") { + TORCH_CHECK(Mn % 2 == 0, which_weights + " weights Mn dimension must be even."); + TORCH_CHECK(args->intermediate_size == Mn / 2, "intermediate_size has incorrect shape."); + TORCH_CHECK(K == hidden_states->sizes()[1], + which_weights + " weights K dimension must be equal to hidden_size."); + } else if (which_weights == "gemm2") { + TORCH_CHECK(K == args->intermediate_size, + which_weights + " weights K dimension must be equal to intermediate_size."); + } + } + + void check_routing_common() const { + TORCH_CHECK(args->top_k > 0 && args->top_k <= args->num_experts, + "top_k must be between 1 and num_experts"); + TORCH_CHECK(args->local_num_experts > 0 && args->local_num_experts <= args->num_experts, + "local_num_experts must be between 1 and num_experts"); + TORCH_CHECK(args->local_expert_offset >= 0 && + args->local_expert_offset + args->local_num_experts <= args->num_experts, + "expert offset and count must be within valid range"); + + check_routing_logits_shape(); + + if (routing_bias) { + check_routing_bias_shape(); + } + } + + // Routing phase workspace tensors (allocated in prepare_routing() or prepare_routing_common()) + at::Tensor num_tokens_per_expert; + at::Tensor 
total_num_padded_tokens; + at::Tensor expanded_idx_to_permuted_idx; + at::Tensor permuted_idx_to_token_idx; + at::Tensor expert_weights; + at::Tensor expert_indexes; + at::Tensor expert_count_histogram; + at::Tensor cta_idx_xy_to_batch_idx; + at::Tensor cta_idx_xy_to_mn_limit; + at::Tensor num_non_exiting_ctas; + + void prepare_routing_common() { + // Allocate routing phase workspace tensors + int32_t max_num_padded_tokens = + tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxPermutedPaddedCount( + args->num_tokens, args->top_k, args->num_experts, tile_tokens_dim); + + // Common routing workspace tensors allocation + num_tokens_per_expert = at::detail::empty_cuda({args->num_experts}, at::ScalarType::Int, + routing_logits->device(), std::nullopt); + + total_num_padded_tokens = at::empty( + {}, at::TensorOptions().device(routing_logits->device()).dtype(at::ScalarType::Int)); + + expanded_idx_to_permuted_idx = + at::detail::empty_cuda({args->num_tokens * args->top_k}, at::ScalarType::Int, + routing_logits->device(), std::nullopt); + + permuted_idx_to_token_idx = at::detail::empty_cuda({max_num_padded_tokens}, at::ScalarType::Int, + routing_logits->device(), std::nullopt); + + expert_indexes = at::detail::empty_cuda({args->num_tokens, args->top_k}, at::ScalarType::Int, + routing_logits->device(), std::nullopt); + + // expert_weights allocation should be done by derived class since data type could vary + + int64_t const size_of_expert_count_histogram = std::max(args->num_experts * 2, 256 * 2); + expert_count_histogram = + at::detail::empty_cuda({size_of_expert_count_histogram}, + at::ScalarType::Int, // 256 is the max number of threads per block + // and max number of experts + routing_logits->device(), std::nullopt); + + int32_t max_num_ctas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCtasInBatchDim( + args->num_tokens, args->top_k, args->num_experts, tile_tokens_dim); + + cta_idx_xy_to_batch_idx = at::detail::empty_cuda({max_num_ctas}, at::ScalarType::Int, + routing_logits->device(), std::nullopt); + + cta_idx_xy_to_mn_limit = at::detail::empty_cuda({max_num_ctas}, at::ScalarType::Int, + routing_logits->device(), std::nullopt); + + num_non_exiting_ctas = at::empty( + {}, at::TensorOptions().device(routing_logits->device()).dtype(at::ScalarType::Int)); + + workspace.total_num_padded_tokens = total_num_padded_tokens.data_ptr(); + workspace.total_max_padded_tokens = max_num_padded_tokens; + workspace.ProjUpTileN = tile_tokens_dim; + workspace.routing_expert_indexes = expert_indexes.data_ptr(); + workspace.permuted_idx_size = total_num_padded_tokens.data_ptr(); + workspace.expanded_idx_to_permuted_idx = expanded_idx_to_permuted_idx.data_ptr(); + workspace.permuted_idx_to_token_idx = permuted_idx_to_token_idx.data_ptr(); + // workspace.expert_weights will be set by derived class after expert_weights allocation + workspace.cta_idx_xy_to_batch_idx = cta_idx_xy_to_batch_idx.data_ptr(); + workspace.cta_idx_xy_to_mn_limit = cta_idx_xy_to_mn_limit.data_ptr(); + workspace.num_non_exiting_ctas = num_non_exiting_ctas.data_ptr(); + } + + void check_moe_common() const { + // Hidden states [num_tokens, hidden_size] + TORCH_CHECK(hidden_states->dim() == 2, "hidden_states must be 2D."); + } + + // MoE computation phase workspace tensors (allocated in prepare_moe() or prepare_moe_common()) + at::Tensor gemm1_output; + at::Tensor activation_output; + at::Tensor gemm2_output; + at::Tensor workspace_fc1; + at::Tensor workspace_fc2; + at::Tensor output; + int64_t moe_tactic{-1}; + std::unique_ptr 
moe_runner; + + void prepare_moe_common(int64_t& moe_tactic) { + using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner; + moe_runner = std::make_unique( + this->mDtypeAct, this->mDtypeWeights, args->mUseDeepSeekFp8, (int32_t)tile_tokens_dim, + static_cast(this->gated_act_type), this->use_shuffled_weight); + + if (moe_tactic == -1) { + moe_tactic = moe_runner->getDefaultValidConfigIndex( + args->top_k, args->hidden_size, args->intermediate_size, args->local_num_experts, + args->num_tokens); + } + this->moe_tactic = moe_tactic; + + auto workspace_sizes = moe_runner->getWorkspaceSizeInBytes(*args, moe_tactic); + workspace_fc1 = at::detail::empty_cuda({std::get<0>(workspace_sizes)}, at::ScalarType::Char, + hidden_states->device(), std::nullopt); + workspace_fc2 = at::detail::empty_cuda({std::get<1>(workspace_sizes)}, at::ScalarType::Char, + hidden_states->device(), std::nullopt); + workspace.bmm1_workspace = workspace_fc1.data_ptr(); + workspace.bmm2_workspace = workspace_fc2.data_ptr(); + } + + public: + virtual void check_routing() const = 0; + virtual void prepare_routing() = 0; + virtual void check_moe() const = 0; + virtual void prepare_moe(int64_t& moe_tactic) = 0; + + // Main entry point for all the executions. + // Do initializations prior to calling this as the initializations are different for bf16, fp8 and + // fp4. The executions are non-blocking by default. + std::vector run(int64_t moe_tactic, bool enable_pdl = true) { + check_routing(); + prepare_routing(); + + // Execute routing + tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim); + int routing_device = routing_logits->get_device(); + auto const& routing_stream = at::cuda::getCurrentCUDAStream(routing_device); + routing_runner.run( + routing_logits->data_ptr(), args->routing_bias, args->num_tokens, args->num_experts, + args->top_k, args->n_group, args->topk_group, args->local_expert_offset, + args->local_num_experts, args->routed_scaling_factor, expert_indexes.data_ptr(), + expert_count_histogram.data_ptr(), total_num_padded_tokens.data_ptr(), + expanded_idx_to_permuted_idx.data_ptr(), + nullptr /*permuted_idx_to_expanded_idx.data_ptr()*/, + permuted_idx_to_token_idx.data_ptr(), expert_weights.data_ptr(), + num_tokens_per_expert.data_ptr(), cta_idx_xy_to_batch_idx.data_ptr(), + cta_idx_xy_to_mn_limit.data_ptr(), num_non_exiting_ctas.data_ptr(), + args->mDtypeElt, false, true, static_cast(routing_method_type), + routing_stream); + + check_moe(); + // if moe_tactic is -1, it will be set to the default valid config index + prepare_moe(moe_tactic); + + // Execute MoE + int moe_device = hidden_states->get_device(); + auto const& moe_stream = at::cuda::getCurrentCUDAStream(moe_device); + moe_runner->run(*args, workspace, moe_device, moe_stream, moe_tactic, enable_pdl); + + if (args->do_finalize) { + return {output}; + } + return {gemm2_output, expert_weights, expanded_idx_to_permuted_idx}; + } +}; + +void FusedMoeLauncher::init_common( + at::Tensor const* routing_logits, at::Tensor const* routing_bias, + at::Tensor const* hidden_states, at::Tensor const* gemm1_weights, + at::Tensor const* gemm2_weights, + std::unique_ptr&& args, + int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, + int64_t weight_layout, int64_t gated_act_type) { + // Check devicearchitecture: Blackwell (SM 10.x) required + TORCH_CHECK(hidden_states != nullptr, "hidden_states is required"); + auto device = hidden_states->device().index(); + int major = 0, minor = 0; + 
cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device); + cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device); + TORCH_CHECK(major == 10, "BF16 MoE requires 10.x architecture. Current device has SM ", major, + minor); + this->device_version = std::make_tuple(major, minor); + + this->routing_logits = routing_logits; + this->routing_bias = routing_bias; + this->hidden_states = hidden_states; + this->gemm1_weights = gemm1_weights; + this->gemm2_weights = gemm2_weights; + + args->routing_logits = routing_logits->data_ptr(); + args->routing_bias = routing_bias ? routing_bias->data_ptr() : nullptr; + args->hidden_states = hidden_states->data_ptr(); + args->gemm1_weights = gemm1_weights->data_ptr(); + args->gemm2_weights = gemm2_weights->data_ptr(); + + this->args = std::move(args); + this->tile_tokens_dim = tile_tokens_dim; + this->routing_method_type = routing_method_type; + this->use_shuffled_weight = use_shuffled_weight; + TORCH_CHECK(0 <= weight_layout && weight_layout <= 2, + "the value of weight_layout is not recognized"); + this->weight_layout = static_cast(weight_layout); + TORCH_CHECK(0 <= gated_act_type && gated_act_type <= 1, + "the value of gated_act_type is not recognized"); + this->gated_act_type = static_cast(gated_act_type); +} + +class Bf16MoeLauncher : public FusedMoeLauncher { + public: + Bf16MoeLauncher() = default; + + void init(at::Tensor const& routing_logits, std::optional const& routing_bias, + at::Tensor const& hidden_states, at::Tensor const& gemm1_weights, + at::Tensor const& gemm2_weights, + std::unique_ptr&& args, + int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, + int64_t weight_layout) { + constexpr int64_t gated_act_type = + static_cast(GatedActType::SwiGlu); // not exposed in api for now + + // Do base class init and perform common checks + FusedMoeLauncher::init_common( + &routing_logits, routing_bias.has_value() ? &routing_bias.value() : nullptr, &hidden_states, + &gemm1_weights, &gemm2_weights, std::move(args), tile_tokens_dim, routing_method_type, + use_shuffled_weight, weight_layout, gated_act_type); + } + + void check_routing() const override { + FusedMoeLauncher::check_routing_common(); + + // TODO n_group, topk_group validation? 
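+    // A possible sketch (assumed constraints): grouped routing parameters are only meaningful
+    // when n_group > 0, e.g. for group-limited routing methods.
+    if (args->n_group > 0) {
+      TORCH_CHECK(args->num_experts % args->n_group == 0,
+                  "num_experts must be divisible by n_group.");
+      TORCH_CHECK(args->topk_group > 0 && args->topk_group <= args->n_group,
+                  "topk_group must be between 1 and n_group.");
+    }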
+ } + + void prepare_routing() override { + FusedMoeLauncher::prepare_routing_common(); + + args->mDtypeElt = btg::Dtype::Bfloat16; + args->mDtypeExpW = btg::Dtype::Bfloat16; + args->mUseDeepSeekFp8 = false; + + auto const routing_bias_dtype = at::ScalarType::BFloat16; + expert_weights = at::detail::empty_cuda({args->num_tokens, args->top_k}, routing_bias_dtype, + routing_logits->device(), std::nullopt); + + workspace.expert_weights = expert_weights.data_ptr(); + } + + void check_moe() const override { + FusedMoeLauncher::check_moe_common(); + + TORCH_CHECK(weight_layout == MatrixLayout::BlockMajorK, + "BF16 Moe: weight_layout must be BlockMajorK"); + check_weights_shape("gemm1"); + check_weights_shape("gemm2"); + + TORCH_CHECK(args->intermediate_size % 128 == 0, + "the second dimension of weights must be a multiple of 128."); + } + + void prepare_moe(int64_t& moe_tactic) override { + // in the next line moe_tactic is passed by reference so modification will be propagated back + // here + FusedMoeLauncher::prepare_moe_common(moe_tactic); + + int32_t max_num_padded_tokens = workspace.total_max_padded_tokens; + gemm1_output = + at::detail::empty_cuda({max_num_padded_tokens, 2 * args->intermediate_size}, + at::ScalarType::BFloat16, hidden_states->device(), std::nullopt); + activation_output = + at::detail::empty_cuda({max_num_padded_tokens, args->intermediate_size}, + at::ScalarType::BFloat16, hidden_states->device(), std::nullopt); + gemm2_output = + at::detail::empty_cuda({max_num_padded_tokens, args->hidden_size}, at::ScalarType::BFloat16, + hidden_states->device(), std::nullopt); + + workspace.hidden_states_scale_linear = nullptr; + workspace.gemm1_output = gemm1_output.data_ptr(); + workspace.gemm1_output_scale = nullptr; // BF16 doesn't use scale tensors + workspace.activation_output = activation_output.data_ptr(); + workspace.activation_output_scale = nullptr; // BF16 doesn't use scale tensors + workspace.gemm2_output = gemm2_output.data_ptr(); + workspace.gemm2_output_scale = nullptr; + + output = at::detail::empty_cuda({args->num_tokens, args->hidden_size}, at::ScalarType::BFloat16, + hidden_states->device(), std::nullopt); + args->output = output.data_ptr(); + args->output_scale = nullptr; + } +}; + +at::Tensor trtllm_bf16_moe(at::Tensor const& routing_logits, + std::optional const& routing_bias, + at::Tensor const& hidden_states, at::Tensor const& gemm1_weights, + at::Tensor const& gemm2_weights, int64_t num_experts, int64_t top_k, + int64_t n_group, int64_t topk_group, int64_t intermediate_size, + int64_t local_expert_offset, int64_t local_num_experts, + int64_t tile_tokens_dim, int64_t routing_method_type, + bool use_shuffled_weight, int64_t weight_layout, int64_t moe_tactic, + bool enable_pdl) { + // Just some basic type validation first and leave more checks to the launcher + TORCH_CHECK(routing_logits.scalar_type() == at::ScalarType::Float || + routing_logits.scalar_type() == at::ScalarType::BFloat16, + "BF16 MoE: routing_logits must be bfoat16 or float."); + if (routing_bias.has_value()) { + TORCH_CHECK(routing_bias.value().scalar_type() == at::ScalarType::BFloat16, + "BF16 MoE: routing_bias must be bfloat16."); + } + TORCH_CHECK(hidden_states.scalar_type() == at::ScalarType::BFloat16, + "BF16 MoE: hidden_states must be bfloat16."); + TORCH_CHECK(gemm1_weights.scalar_type() == at::ScalarType::BFloat16, + "BF16 MoE: gemm1_weights must be bfloat16."); + TORCH_CHECK(gemm2_weights.scalar_type() == at::ScalarType::BFloat16, + "BF16 MoE: gemm2_weights must be bfloat16."); + + // 
Save params to MoE arguments + auto args = std::make_unique(); + args->num_tokens = hidden_states.sizes()[0]; + args->num_experts = num_experts; + args->hidden_size = hidden_states.sizes()[1]; + args->hidden_size_output = args->hidden_size; + args->top_k = top_k; + args->n_group = n_group; + args->topk_group = topk_group; + args->local_expert_offset = local_expert_offset; + args->local_num_experts = local_num_experts; + args->intermediate_size = intermediate_size; + + Bf16MoeLauncher launcher; + launcher.init(routing_logits, routing_bias, hidden_states, gemm1_weights, gemm2_weights, + std::move(args), tile_tokens_dim, routing_method_type, use_shuffled_weight, + weight_layout); + return launcher.run(moe_tactic, enable_pdl)[0]; +} + at::Tensor trtllm_fp8_per_tensor_scale_moe_launcher( at::Tensor const& routing_logits, std::optional routing_bias, at::Tensor const& hidden_states, at::Tensor const& gemm1_weights, @@ -301,6 +760,27 @@ at::Tensor trtllm_fp8_per_tensor_scale_moe( auto dtype = hidden_states.dtype(); if (dtype == at::ScalarType::Half || dtype == at::ScalarType::BFloat16 || dtype == at::ScalarType::Float8_e4m3fn) { + // // Create unified runner for FP8 per-tensor mode + // using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner; + // auto mRunner = std::make_unique( + // btg::Dtype::E4m3, false, tile_tokens_dim, /*useShuffledMatrixA*/ true); + + // auto const moeConfigIndex = mRunner->getDefaultValidConfigIndex( + // top_k, hidden_states.sizes()[1], intermediate_size, local_num_experts, + // hidden_states.sizes()[0]); + + // // Call unified launcher with nullopt for expert_indices, expert_weights, and output (will be + // created internally) auto results = trtllm_fp4_block_scale_moe_launcher( + // routing_logits, std::nullopt, std::nullopt, routing_bias, hidden_states, std::nullopt, + // gemm1_weights, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + // gemm2_weights, std::nullopt, std::nullopt, + // output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar, + // num_experts, top_k, n_group, topk_group, intermediate_size, local_expert_offset, + // local_num_experts, routed_scaling_factor, tile_tokens_dim, routing_method_type, true, // + // do_finalize = true *mRunner, btg::Dtype::E4m3, btg::Dtype::E4m3, moeConfigIndex, + // enable_pdl); + + // return results[0]; // Return the first tensor from the vector return trtllm_fp8_per_tensor_scale_moe_launcher( routing_logits, routing_bias, hidden_states, gemm1_weights, output1_scales_scalar, output1_scales_gate_scalar, gemm2_weights, output2_scales_scalar, num_experts, top_k, @@ -1161,6 +1641,7 @@ namespace trtllm_cubin_loader { } TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { + m.def("trtllm_bf16_moe", trtllm_bf16_moe); m.def("trtllm_fp8_per_tensor_scale_moe", trtllm_fp8_per_tensor_scale_moe); m.def("trtllm_fp8_block_scale_moe", trtllm_fp8_block_scale_moe); m.def("trtllm_fp4_block_scale_moe", trtllm_fp4_block_scale_moe); diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index 0f0d18d680..d8c49ff2fd 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -111,7 +111,7 @@ def get_available_cubin_files(source, retries=3, delay=5, timeout=10): class ArtifactPath: TRTLLM_GEN_FMHA: str = "037e528e719ec3456a7d7d654f26b805e44c63b1/fmha/trtllm-gen/" TRTLLM_GEN_BMM: str = ( - "e6f22dcc3fdeb29ff87af2f4a2cb3d30b8d273e0/batched_gemm-45beda1-ee6a802/" + "b5c82312e606632b7571c7370f4335cfae46e206/batched_gemm-145d1b1-9e1d49a/" ) TRTLLM_GEN_GEMM: str = ( 
"037e528e719ec3456a7d7d654f26b805e44c63b1/gemm-8704aa4-f91dc9e/" diff --git a/flashinfer/fused_moe/__init__.py b/flashinfer/fused_moe/__init__.py index 8c26c73b01..937f791226 100644 --- a/flashinfer/fused_moe/__init__.py +++ b/flashinfer/fused_moe/__init__.py @@ -28,6 +28,7 @@ trtllm_fp4_block_scale_routed_moe, trtllm_fp8_block_scale_moe, trtllm_fp8_per_tensor_scale_moe, + trtllm_bf16_moe, ) __all__ = [ @@ -42,4 +43,5 @@ "trtllm_fp4_block_scale_moe", "trtllm_fp8_block_scale_moe", "trtllm_fp8_per_tensor_scale_moe", + "trtllm_bf16_moe", ] diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 34c34a6a0b..f3baee3ad1 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -1224,6 +1224,106 @@ def refine_tuning_config(cls, tune_max_num_tokens: int): ), ) + @register_custom_op( + "flashinfer::trtllm_bf16_moe", + mutates_args=(""), + ) + def trtllm_bf16_moe_op( + routing_logits: torch.Tensor, + routing_bias: Optional[torch.Tensor], + hidden_states: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm2_weights: torch.Tensor, + num_experts: int, + top_k: int, + n_group: int, + topk_group: int, + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + tile_tokens_dim: int, + routing_method_type: int, + use_shuffled_weight: bool, + weight_layout: int, + moe_tactic: int, + enable_pdl: Optional[bool] = None, + ) -> torch.Tensor: + if enable_pdl is None: + enable_pdl = device_support_pdl(hidden_states.device) + # Call the C++ function for block scale MoE + # FIXME: remove these prints once done + print( + "@@@@", + routing_logits.dtype, + "routing_bias.dtype", + hidden_states.dtype, + gemm1_weights.dtype, + gemm2_weights.dtype, + ) + print("@@@@", gemm1_weights.shape, gemm2_weights.shape) + print( + num_experts, + top_k, + n_group, + topk_group, + intermediate_size, + local_expert_offset, + local_num_experts, + tile_tokens_dim, + use_shuffled_weight, + weight_layout, + moe_tactic, + routing_method_type, + enable_pdl, + ) + output = moe_op.trtllm_bf16_moe( + routing_logits.to(torch.float), # FIXME what's the supported type? 
+ routing_bias, + hidden_states, + gemm1_weights, + gemm2_weights, + num_experts, + top_k, + n_group, + topk_group, + intermediate_size, + local_expert_offset, + local_num_experts, + tile_tokens_dim, + routing_method_type, + use_shuffled_weight, + weight_layout, + moe_tactic, + enable_pdl, + ) + return output + + @register_fake_op("flashinfer::trtllm_bf16_moe") + def _fake_trtllm_bf16_moe( + routing_logits: torch.Tensor, + routing_bias: Optional[torch.Tensor], + hidden_states: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm2_weights: torch.Tensor, + num_experts: int, + top_k: int, + n_group: int, + topk_group: int, + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + tile_tokens_dim: int, + routing_method_type: int, + use_shuffled_weight: bool, + weight_layout: int, + moe_tactic: int, + enable_pdl: Optional[bool] = None, + ): + seq_len = hidden_states.shape[0] + hidden_size = hidden_states.shape[1] + + return [hidden_states.new_empty([seq_len, hidden_size], dtype=torch.bfloat16)] + @register_custom_op( "flashinfer::trtllm_fp8_per_tensor_scale_moe", mutates_args=(""), @@ -1607,12 +1707,57 @@ def _fake_trtllm_fp4_block_scale_moe( return [hidden_states.new_empty([seq_len, hidden_size], dtype=torch.bfloat16)] return SimpleNamespace( + trtllm_bf16_moe=trtllm_bf16_moe_op, trtllm_fp8_per_tensor_scale_moe=trtllm_fp8_per_tensor_scale_moe_op, trtllm_fp8_block_scale_moe=trtllm_fp8_block_scale_moe_op, trtllm_fp4_block_scale_moe=trtllm_fp4_block_scale_moe_op, ) +def trtllm_bf16_moe( + routing_logits: torch.Tensor, + routing_bias: Optional[torch.Tensor], + hidden_states: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm2_weights: torch.Tensor, + num_experts: int, + top_k: int, + n_group: int, + topk_group: int, + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + *, + tile_tokens_dim: int = 8, + routing_method_type: int = 0, + use_shuffled_weight: bool = True, + weight_layout: int = WeightLayout.BlockMajorK, + moe_tactic: int = -1, + enable_pdl: bool = True, +) -> torch.Tensor: + """BF16 block scale MoE operation.""" + return get_trtllm_moe_sm100_module().trtllm_bf16_moe( + routing_logits, + routing_bias, + hidden_states, + gemm1_weights, + gemm2_weights, + num_experts, + top_k, + n_group or 0, # may receive None from test configs, convert to 0 + topk_group or 0, + intermediate_size, + local_expert_offset, + local_num_experts, + tile_tokens_dim, + routing_method_type, + use_shuffled_weight, + weight_layout, + moe_tactic, + enable_pdl, + ) + + def trtllm_fp8_per_tensor_scale_moe( routing_logits: torch.Tensor, routing_bias: Optional[torch.Tensor], diff --git a/tests/test_trtllm_gen_fused_moe.py b/tests/test_trtllm_gen_fused_moe.py index 02e41a21e2..f811348628 100644 --- a/tests/test_trtllm_gen_fused_moe.py +++ b/tests/test_trtllm_gen_fused_moe.py @@ -38,6 +38,7 @@ from flashinfer.fused_moe import ( WeightLayout, convert_to_block_layout, + trtllm_bf16_moe, trtllm_fp4_block_scale_moe, trtllm_fp8_block_scale_moe, trtllm_fp8_per_tensor_scale_moe, @@ -983,6 +984,142 @@ def get_tolerances(self): return {"atol": 0.1, "rtol": 0.85, "percent": 0.925} +# ==================================================================================== +# BF16 Implementation +# ==================================================================================== + + +class BF16Moe(Moe): + """BF16 MoE implementation.""" + + def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample): + """No scaling for weights.""" + return { + 
"hidden_states_scale_global": None, + "gemm1_weights": gemm1_weights.to(torch.bfloat16), + "gemm1_scales": None, + "gemm1_scales_global": None, + "gemm2_weights": gemm2_weights.to(torch.bfloat16), + "gemm2_scales": None, + "gemm2_scales_global": None, + } + + def quantize_inputs(self, hidden_states, *unused_args): + """No scaling for hidden states.""" + return { + "hidden_states": hidden_states.to(torch.bfloat16), + "hidden_states_scale": None, + } + + def prepare_static_weights_for_kernel( + self, + args_dequant, + args, + gemm1_weights_orig, + gemm2_weights_orig, + hidden_size, + intermediate_size, + num_experts, + weight_processing, + ): + """Prepare quantized weights for kernel (done offline with weights).""" + + # Use shuffled weights with BlockMajorK layout for better performance + use_shuffled_weight = weight_processing["use_shuffled_weight"] + weight_layout = weight_processing["layout"] + + if use_shuffled_weight: + # FIXME: this depends on the kernel internals + epilogue_tile_m = 128 + + # Reorder rows of W1 for fused gated activation + gemm1_weights_bf16_interleaved = [] + for i in range(num_experts): + gemm1_weights_bf16_interleaved.append( + reorder_rows_for_gated_act_gemm(args.gemm1_weights[i].clone()) + ) + + # Stack weights and scales for all experts + gemm1_weights_bf16_interleaved = torch.stack( + gemm1_weights_bf16_interleaved + ).reshape(num_experts, 2 * intermediate_size, hidden_size) + + # Shuffle weights and scaling factors for transposed mma output + gemm1_weights_bf16_shuffled = [] + gemm2_weights_bf16_shuffled = [] + for i in range(num_experts): + tmp_weights1 = shuffle_matrix_a( + args.gemm1_weights[i].view(torch.uint8), epilogue_tile_m + ) + tmp_weights2 = shuffle_matrix_a( + args.gemm2_weights[i].view(torch.uint8), epilogue_tile_m + ) + + if weight_layout == WeightLayout.BlockMajorK: + block_k = 128 + tmp_weights1 = convert_to_block_layout(tmp_weights1, block_k) + tmp_weights2 = convert_to_block_layout(tmp_weights2, block_k) + + gemm1_weights_bf16_shuffled.append(tmp_weights1) + gemm2_weights_bf16_shuffled.append(tmp_weights2) + + # Stack weights for all experts + gemm1_weights_bf16_shuffled = torch.stack(gemm1_weights_bf16_shuffled).view( + torch.bfloat16 + ) + gemm2_weights_bf16_shuffled = torch.stack(gemm2_weights_bf16_shuffled).view( + torch.bfloat16 + ) + + return { + "gemm1_weights": gemm1_weights_bf16_shuffled, + "gemm2_weights": gemm2_weights_bf16_shuffled, + } + + def call_moe( + self, static_data, hidden_states_orig, hidden_states_scale_global, **kwargs + ): + """Call MoE with runtime input quantization + kernel execution (done at runtime).""" + expert_logits = kwargs["expert_logits"] + routing_bias = kwargs["routing_bias"] + num_experts = kwargs["num_experts"] + top_k = kwargs["top_k"] + n_groups = kwargs["n_groups"] + top_k_groups = kwargs["top_k_groups"] + intermediate_size = kwargs["intermediate_size"] + routing_method_type = kwargs["routing_method_type"] + tile_tokens_dim = kwargs["tile_tokens_dim"] + + output = trtllm_bf16_moe( + expert_logits, # float + routing_bias, + hidden_states_orig, + static_data["gemm1_weights"], + static_data["gemm2_weights"], + num_experts, + top_k, + n_groups, + top_k_groups, + intermediate_size, + 0, + num_experts, + # the rest are enforced by the api to be passed in the keyword form + # as opposed to the positional form + tile_tokens_dim=tile_tokens_dim, + routing_method_type=routing_method_type, + ) + + return output.to(torch.float) + + def compute_reference(self, args): + """BF16 reference implementation.""" + 
return run_moe_reference_bf16(args) + + def get_tolerances(self): + """Get BF16 accuracy tolerances.""" + return {"atol": 0.1, "rtol": 0.85, "percent": 0.925} + + # ==================================================================================== # Quantizer Factory # ==================================================================================== @@ -1789,6 +1926,37 @@ def run_moe_reference_per_tensor_scale_fp8(args): return run_moe_dequant(args_dequant, QuantMode.FP8_PER_TENSOR), args_dequant +def run_moe_reference_bf16(args): + """BF16 reference implementation.""" + + # no scaling for hidden states and weights + hidden_states_dequant = args.hidden_states.to(torch.float) + gemm1_weights_dequant = {} + for i in range(args.num_experts): + gemm1_weights_dequant[i] = args.gemm1_weights[i].to(torch.float) + gemm2_weights_dequant = {} + for i in range(args.num_experts): + gemm2_weights_dequant[i] = args.gemm2_weights[i].to(torch.float) + + args_dequant = moe_args_dequant( + args.num_tokens, + args.num_experts, + args.hidden_size, + args.intermediate_size, + args.top_k, + args.padding, + hidden_states_dequant, + args.expert_logits, + gemm1_weights_dequant, + gemm2_weights_dequant, + args.permute_info, + args.use_routing_scales_on_input, + GatedActType.SwiGlu.value, # gated_act_type + ) + + return run_moe_dequant(args_dequant, QuantMode.FP8_PER_TENSOR), args_dequant + + def _compute_moe_actual_unified(moe_impl, args_dequant, args, **kwargs): """Unified actual computation that delegates to implementation-specific methods.""" # 1. Prepare static weights for the kernel (offline processing) @@ -1848,6 +2016,7 @@ def cache_permute_indices(): pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_Bf16), id="MxFP4xBf16"), pytest.param(FP8BlockScaleMoe(), id="FP8_Block"), pytest.param(FP8PerTensorMoe(), id="FP8_Tensor"), + pytest.param(BF16Moe(), id="All_BF16"), ], ) @pytest.mark.parametrize( @@ -1897,7 +2066,12 @@ def cache_permute_indices(): "routed_scaling": None, "has_routing_bias": False, "routing_method_type": RoutingMethodType.Renormalize, - "compatible_moe_impls": [FP4Moe, FP8PerTensorMoe, FP8BlockScaleMoe], + "compatible_moe_impls": [ + FP4Moe, + FP8PerTensorMoe, + FP8BlockScaleMoe, + BF16Moe, + ], }, id="Renorm", marks=pytest.mark.skip( @@ -1914,7 +2088,7 @@ def cache_permute_indices(): "routed_scaling": None, "has_routing_bias": False, "routing_method_type": RoutingMethodType.RenormalizeNaive, - "compatible_moe_impls": [FP4Moe], + "compatible_moe_impls": [FP4Moe, BF16Moe], }, id="RenormNaive", ), @@ -1971,7 +2145,7 @@ def cache_permute_indices(): { "use_shuffled_weight": True, "layout": WeightLayout.BlockMajorK, - "compatible_moe_impls": [FP8BlockScaleMoe], + "compatible_moe_impls": [FP8BlockScaleMoe, BF16Moe], }, id="Shuffled_BlockMajorK", ), @@ -2004,31 +2178,34 @@ def test_moe_quantization_classes( Each quantization class clearly shows which precision is being used. 
""" # Skip incompatible combinations + # NOTE Don't worry about the short-lived variables between if & skip - formatting simplifications if gated_act_type == GatedActType.GeGlu and ( type(moe_impl) is not FP4Moe or moe_impl.quant_mode != QuantMode.FP4_NVFP4_NVFP4 or routing_config["routing_method_type"] != RoutingMethodType.TopK or num_tokens > 128 ): + routing = routing_config["routing_method_type"].name # GeGlu is only supported for FP4Moe FP4_NVFP4_NVFP4 and TopK routing pytest.skip( - f"Incompatible: {moe_impl.name} + {gated_act_type} + {routing_config['routing_method_type']} + {num_tokens}" + f"Incompatible: {moe_impl.name} {gated_act_type.name} {routing=} {num_tokens=}" ) elif gated_act_type == GatedActType.SwiGlu and ( hidden_size > 1024 or intermediate_size > 1024 ): # Skip some tests for SwiGlu for testing speed pytest.skip( - f"Skip for testing speed: {gated_act_type} + {hidden_size} + {intermediate_size}" + f"Skip for testing speed: {gated_act_type.name} {hidden_size=} {intermediate_size=}" ) if type(moe_impl) not in routing_config["compatible_moe_impls"]: - pytest.skip( - f"Incompatible: {moe_impl.name} + {routing_config['routing_method_type'].name}" - ) + routing = routing_config["routing_method_type"].name + pytest.skip(f"Incompatible routing: {moe_impl.name} {routing=}") if type(moe_impl) not in weight_processing["compatible_moe_impls"]: + layout = weight_processing["layout"].name + use_shuffled_weight = weight_processing["use_shuffled_weight"] pytest.skip( - f"Incompatible: {moe_impl.name} + {weight_processing['use_shuffled_weight']} + {weight_processing['layout']}" + f"Incompatible weight format: {moe_impl.name} {use_shuffled_weight=} {layout=}" ) moe_impl._cache_permute_indices = cache_permute_indices From 9d35e9e5741ae4e621a464366a92ccf709fc9ed2 Mon Sep 17 00:00:00 2001 From: Nikita Korobov <14355239+nekorobov@users.noreply.github.com> Date: Fri, 26 Sep 2025 01:57:47 -0700 Subject: [PATCH 02/12] fix compilable kernel but test fails --- csrc/trtllm_fused_moe_kernel_launcher.cu | 2 +- .../trtllmGen_bmm_export/BatchedGemmEnums.h | 48 +- .../BatchedGemmInterface.h | 109 ++- .../trtllmGen_bmm_export/BatchedGemmOptions.h | 336 +++++--- .../batched_gemm/trtllmGen_bmm_export/Enums.h | 63 +- .../GemmGatedActOptions.h | 99 ++- .../trtllmGen_bmm_export/GemmOptions.h | 753 +++++++++++------- .../trtllmGen_bmm_export/KernelParams.h | 530 +++++++----- .../trtllmGen_bmm_export/KernelParamsDecl.h | 81 +- .../trtllmGen_bmm_export/KernelTraits.h | 186 +++-- .../trtllmGen_bmm_export/TmaDescriptor.h | 88 +- .../trtllm/gen/CommonUtils.h | 57 +- .../trtllm/gen/CudaKernelLauncher.h | 62 +- .../trtllm/gen/DtypeDecl.h | 184 ++--- .../trtllmGen_bmm_export/trtllm/gen/MmaDecl.h | 74 +- .../trtllm/gen/SfLayoutDecl.h | 58 +- tests/test_trtllm_gen_fused_moe.py | 6 +- 17 files changed, 1638 insertions(+), 1098 deletions(-) diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index 5cd2e227fb..4dea0c7736 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -269,7 +269,7 @@ class FusedMoeLauncher { using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner; moe_runner = std::make_unique( this->mDtypeAct, this->mDtypeWeights, args->mUseDeepSeekFp8, (int32_t)tile_tokens_dim, - static_cast(this->gated_act_type), this->use_shuffled_weight); + static_cast(this->gated_act_type), this->use_shuffled_weight, this->weight_layout); if (moe_tactic == -1) { moe_tactic = moe_runner->getDefaultValidConfigIndex( diff 
--git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h index 27955d2bdc..9052618d1f 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h @@ -1,23 +1,23 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & - * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +* SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & +* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ #pragma once -#include #include +#include namespace batchedGemm { @@ -36,20 +36,26 @@ enum class RouteImpl { //////////////////////////////////////////////////////////////////////////////////////////////////// -inline bool doesRouteImplUseNoRoute(RouteImpl mode) { return (mode == RouteImpl::NoRoute); } +inline bool doesRouteImplUseNoRoute(RouteImpl mode) { + return (mode == RouteImpl::NoRoute); +} //////////////////////////////////////////////////////////////////////////////////////////////////// -inline bool doesRouteImplUseLdgsts(RouteImpl mode) { return (mode == RouteImpl::Ldgsts); } +inline bool doesRouteImplUseLdgsts(RouteImpl mode) { + return (mode == RouteImpl::Ldgsts); +} //////////////////////////////////////////////////////////////////////////////////////////////////// -inline bool doesRouteImplUseTma(RouteImpl mode) { return (mode == RouteImpl::Tma); } +inline bool doesRouteImplUseTma(RouteImpl mode) { + return (mode == RouteImpl::Tma); +} //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace batchedGemm +} // namespace batchedGemm //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace batchedGemm +} // namespace batchedGemm diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h index 21fd39f59c..fbf6d56ce6 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h @@ -454,20 +454,25 @@ struct BatchedGemmData { //////////////////////////////////////////////////////////////////////////////////////////////////// class BatchedGemmInterface { - public: +public: using ModuleCache = std::unordered_map>; BatchedGemmInterface() {} // Launch the cubin from the provided config. It calls all necessary memsets for internal buffers. // Provided config must be validated with isValidConfig before the call. - int32_t run(BatchedGemmConfig const& config, void* workspace, BatchedGemmData const& options, - void* cudaStream, int32_t multiProcessorCount, bool usePdl = true, + int32_t run(BatchedGemmConfig const& config, + void* workspace, + BatchedGemmData const& options, + void* cudaStream, + int32_t multiProcessorCount, + bool usePdl = true, std::optional> moduleCache = std::nullopt); // Initializes the buffers before the world sync. Must be called before run. int32_t runInitBeforeWorldSync(BatchedGemmConfig const& /* config */, - BatchedGemmData const& /* data */, void* /* cudaStream */) const { + BatchedGemmData const& /* data */, + void* /* cudaStream */) const { return 0; }; @@ -482,8 +487,8 @@ class BatchedGemmInterface { // Returns the grid dimensions of the current kernel. std::tuple getGridDim( - BatchedGemmOptions const& options, - std::optional maxNumCtasInBatchDim = std::nullopt) const { + BatchedGemmOptions const& options, + std::optional maxNumCtasInBatchDim = std::nullopt) const { bool const batchM = options.mBatchMode == BatchedGemmOptions::BatchMode::BatchM; int32_t numCtasBatch{0}; @@ -506,12 +511,27 @@ class BatchedGemmInterface { throw std::invalid_argument("Invalid combination of options"); } - int32_t const numCtasTile = - batchM ? 
gemm::divUp(options.mN, options.mTileN) : gemm::divUp(options.mM, options.mTileM); + if (batchM) { + numCtasBatch = gemm::divUpMul(numCtasBatch, options.mClusterDimX); + } else { + numCtasBatch = gemm::divUpMul(numCtasBatch, options.mClusterDimY); + } + + int32_t numCtasTile = + batchM ? gemm::divUp(options.mN, options.mTileN) : gemm::divUp(options.mM, options.mTileM); + if (batchM) { + numCtasTile = gemm::divUpMul(numCtasTile, options.mClusterDimY); + } else { + numCtasTile = gemm::divUpMul(numCtasTile, options.mClusterDimX); + } int32_t const numCtasInner = options.mNumSlicesForSplitK; return std::make_tuple(numCtasBatch, numCtasTile, numCtasInner); } + // Creates GemmOptions from kernel and data. + BatchedGemmOptions getOptionsFromConfigAndData(BatchedGemmConfig const& config, + BatchedGemmData const& data) const; + // Returns the number of CTAs of the current kernel. int32_t getNumCtas(BatchedGemmOptions const& options, std::optional maxNumCtasInBatchDim = std::nullopt) const { @@ -522,14 +542,9 @@ class BatchedGemmInterface { // Returns true if the configuration of the cubin can be executed for the given params. bool isValidConfig(BatchedGemmConfig const& config, BatchedGemmData const& data) const; - // Creates GemmOptions from kernel and data. - BatchedGemmOptions getOptionsFromConfigAndData(BatchedGemmConfig const& config, - BatchedGemmData const& data) const; - - private: +private: // Aligns the pointer to the alignment - template - inline Dtype* alignPtr(Dtype* ptr, int64_t alignment) const; + template inline Dtype* alignPtr(Dtype* ptr, int64_t alignment) const; // Returns the size of the workspace buffers in bytes std::vector getWorkspaceSizesInBytes(BatchedGemmConfig const& config, @@ -572,7 +587,8 @@ size_t BatchedGemmInterface::getNumBatchedGemmConfigs() const { //////////////////////////////////////////////////////////////////////////////////////////////////// BatchedGemmOptions BatchedGemmInterface::getOptionsFromConfigAndData( - BatchedGemmConfig const& config, BatchedGemmData const& data) const { + BatchedGemmConfig const& config, + BatchedGemmData const& data) const { // Create options from config and data. BatchedGemmOptions options; options = config.mOptions; @@ -599,7 +615,8 @@ bool BatchedGemmInterface::isValidConfig(BatchedGemmConfig const& config, bool isBlackwell = gemm::isSmVersionBlackwell(config.mSm); // Check options without modifications. - return checkAndUpdateBatchedGemmOptions(options, isBlackwell, + return checkAndUpdateBatchedGemmOptions(options, + isBlackwell, /* updateOptions */ false); } @@ -623,7 +640,9 @@ size_t BatchedGemmInterface::getWorkspaceSizeInBytes(BatchedGemmConfig const& co //////////////////////////////////////////////////////////////////////////////////////////////////// std::vector BatchedGemmInterface::getWorkspaceSizesInBytes( - BatchedGemmConfig const& config, BatchedGemmData const& data) const { + BatchedGemmConfig const& config, + BatchedGemmData const& data) const { + std::vector workspaceSizes; // Get options from config and data. 
@@ -667,9 +686,12 @@ std::vector BatchedGemmInterface::getWorkspaceSizesInBytes( } //////////////////////////////////////////////////////////////////////////////////////////////////// -int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspace, - BatchedGemmData const& batchedGemmData, void* cudaStream, - int32_t /* multiProcessorCount */, bool usePdl, +int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, + void* workspace, + BatchedGemmData const& batchedGemmData, + void* cudaStream, + int32_t /* multiProcessorCount */, + bool usePdl, std::optional> moduleCache) { // Might be used. (void)usePdl; @@ -698,20 +720,32 @@ int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspa } auto [numCtaBatch, numCtaTile, numCtaInner] = - getGridDim(options, batchedGemmData.mProblemDimensions.mMaxNumCtasInTokenDim); - auto kernelParams = KernelParamsSetup::setKernelParams( - options, batchM, batchedGemmData.mInputBuffers.mPtrA, batchedGemmData.mInputBuffers.mPtrB, - batchedGemmData.mOutputBuffers.mPtrC, batchedGemmData.mInputBuffers.mPtrSfA, - batchedGemmData.mInputBuffers.mPtrSfB, batchedGemmData.mInputBuffers.mPtrPerTokenSfA, - batchedGemmData.mInputBuffers.mPtrPerTokenSfB, batchedGemmData.mInputBuffers.mPtrBias, - batchedGemmData.mOutputBuffers.mPtrSfC, batchedGemmData.mInputBuffers.mPtrScaleC, - batchedGemmData.mInputBuffers.mPtrScaleGate, batchedGemmData.mInputBuffers.mPtrClampLimit, - batchedGemmData.mInputBuffers.mPtrGatedActAlpha, - batchedGemmData.mInputBuffers.mPtrGatedActBeta, batchedGemmData.mInputBuffers.mPtrRouteMap, - dPtrRowMax, dPtrRowMaxBars, batchedGemmData.mInputBuffers.mPtrNumNonExitingCtas, - batchedGemmData.mInputBuffers.mPtrTotalNumPaddedTokens, - batchedGemmData.mInputBuffers.mPtrCtaIdxXyToBatchIdx, - batchedGemmData.mInputBuffers.mPtrCtaIdxXyToMnLimit, numCtaBatch); + getGridDim(options, batchedGemmData.mProblemDimensions.mMaxNumCtasInTokenDim); + auto kernelParams = + KernelParamsSetup::setKernelParams(options, + batchM, + batchedGemmData.mInputBuffers.mPtrA, + batchedGemmData.mInputBuffers.mPtrB, + batchedGemmData.mOutputBuffers.mPtrC, + batchedGemmData.mInputBuffers.mPtrSfA, + batchedGemmData.mInputBuffers.mPtrSfB, + batchedGemmData.mInputBuffers.mPtrPerTokenSfA, + batchedGemmData.mInputBuffers.mPtrPerTokenSfB, + batchedGemmData.mInputBuffers.mPtrBias, + batchedGemmData.mOutputBuffers.mPtrSfC, + batchedGemmData.mInputBuffers.mPtrScaleC, + batchedGemmData.mInputBuffers.mPtrScaleGate, + batchedGemmData.mInputBuffers.mPtrClampLimit, + batchedGemmData.mInputBuffers.mPtrGatedActAlpha, + batchedGemmData.mInputBuffers.mPtrGatedActBeta, + batchedGemmData.mInputBuffers.mPtrRouteMap, + dPtrRowMax, + dPtrRowMaxBars, + batchedGemmData.mInputBuffers.mPtrNumNonExitingCtas, + batchedGemmData.mInputBuffers.mPtrTotalNumPaddedTokens, + batchedGemmData.mInputBuffers.mPtrCtaIdxXyToBatchIdx, + batchedGemmData.mInputBuffers.mPtrCtaIdxXyToMnLimit, + numCtaBatch); // The size of the grid. std::vector grid = batchM ? 
std::vector{numCtaBatch, numCtaTile, numCtaInner} @@ -781,6 +815,7 @@ int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspa if (result != CUDA_SUCCESS) { return -1; } + // If a module cache has not been given, unload the module to avoid leaking if (!moduleCache.has_value()) { cuModuleUnload(cuModule); @@ -794,8 +829,8 @@ int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspa //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace batchedGemm +} // namespace batchedGemm //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace batchedGemm +} // namespace batchedGemm diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h index c29fb24b0a..0178148ae4 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h @@ -1,57 +1,57 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & - * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +* SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & +* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ #pragma once +#include "GemmOptions.h" +#include "GemmGatedActOptions.h" +#include "BatchedGemmEnums.h" + #include #include -#include "BatchedGemmEnums.h" -#include "GemmGatedActOptions.h" -#include "GemmOptions.h" - #ifndef TLLM_GEN_EXPORT_INTERFACE -#include "trtllm/gen/CudaRunner.h" #include "trtllm/gen/GenCtx.h" +#include "trtllm/gen/CudaRunner.h" #else #include -#define TLLM_CHECK_ERROR(cond, ...) \ - if (!(cond)) { \ - printArgs(__VA_ARGS__); \ - printArgs("\n"); \ - return false; \ +#define TLLM_CHECK_ERROR(cond, ...) \ + if (!(cond)) { \ + printArgs(__VA_ARGS__); \ + printArgs("\n"); \ + return false; \ } #define TLLM_LOG_ERROR(...) TLLM_CHECK_ERROR(false, __VA_ARGS__) #define TLLM_CHECK_ERROR_FMT(cond, ...) TLLM_CHECK_ERROR(cond, __VA_ARGS__) -#define TLLM_CHECK_WARNING(cond, ...) \ - if (!(cond)) { \ - printArgs(__VA_ARGS__); \ - printArgs("\n"); \ - return false; \ +#define TLLM_CHECK_WARNING(cond, ...) 
\ + if (!(cond)) { \ + printArgs(__VA_ARGS__); \ + printArgs("\n"); \ + return false; \ } #define TLLM_LOG_WARNING(...) TLLM_CHECK_WARNING(false, __VA_ARGS__) #define TLLM_LOG_INFO(...) TLLM_CHECK_WARNING(false, __VA_ARGS__) -#endif // TLLM_GEN_EXPORT_INTERFACE +#endif // TLLM_GEN_EXPORT_INTERFACE namespace batchedGemm { @@ -67,6 +67,7 @@ namespace tg = trtllm::gen; // We inherit from GemmGatedActOptions, which is inherited from // GemmOptions to get GemmOptions and GemmGatedActOptions at the same time. struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions { + // Dtor. Allow down-casting. virtual ~BatchedGemmOptions() = default; @@ -74,59 +75,178 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions { BatchedGemmOptions() = default; // FIXME We create explicit constructor with all options to WAR stubgen issue in TRT-LLM. - BatchedGemmOptions( - gemm::AllReduceAlgo allReduceAlgo, gemm::BiasType biasType, int blockK, int clusterDimX, - int clusterDimY, int clusterDimZ, tg::Dtype dtypeAcc, tg::Dtype dtypeA, tg::Dtype dtypeB, - tg::Dtype dtypeC, tg::Dtype dtypeMmaA, tg::Dtype dtypeMmaB, bool enablesEarlyExit, - bool enablesDelayedEarlyExit, bool enablesGlobalPtxKnobs, int epilogueLdtmDps, - int epilogueLdtmBits, int epilogueTileM, int epilogueTileN, bool gridTriggerSecondaryA, - bool gridTriggerSecondaryB, bool gridWaitForPrimaryEarlyExit, bool gridWaitForPrimaryA, - bool gridWaitForPrimaryB, bool hoistLoadTaskInit, bool hoistMmaTaskTryWaits, int k, - gemm::KernelTraits kernelTraits, gemm::MatrixLayout layoutA, gemm::MatrixLayout layoutB, - int m, int mmaK, tg::MmaKind mmaKind, int mmaM, int mmaN, bool mockAllReduce, int n, - int numSlicesForSplitK, int numSlicesForSliceK, int numStages, int numStagesMma, - int numStagesMmaWithinWorkTile, int numStagesMmaAcrossWorkTile, int numStagesWorkId, - bool outputDebugTensors, bool patchF2fp, std::optional sfBlockSizeA, - tg::SfLayout sfLayoutA, tg::SfLayout sfLayoutB, tg::SfLayout sfLayoutC, - int32_t sfReshapeFactor, bool sliceK, gemm::SplitK splitK, int tileK, int tileM, int tileN, - gemm::TileScheduler tileScheduler, bool transposeMmaOutput, bool useCustomMmaSchedule, - bool useDeepSeekFp8, bool useHoistTryWaitForCustomMmaSchedule, bool usePerTokenSfA, - bool usePerTokenSfB, bool useShuffledMatrixA, bool useTmaStore, bool useTwoTmaLoadWarps, - bool useTwoMmaWarps, bool useUnrollLoop2xForMma, int worldSize, gemmGatedAct::ActType actType, - bool clampBeforeAct, std::vector batchedM, std::vector batchedN, - BatchMode batchMode, int numBatches, bool isStaticBatch, int numTokens, RouteImpl routeImpl, - bool gridWaitForPrimaryRouting, bool fusedAct, int numRegsPerThreadNonEpilogueWarp, - int numRegsPerThreadEpilogueWarp, int numRegsCastAWarps, bool useTmaOobOpt) - : gemmGatedAct::GemmGatedActOptions( - gemm::GemmOptions( - allReduceAlgo, biasType, blockK, clusterDimX, clusterDimY, clusterDimZ, dtypeAcc, - dtypeA, dtypeB, dtypeC, dtypeMmaA, dtypeMmaB, enablesEarlyExit, - enablesDelayedEarlyExit, enablesGlobalPtxKnobs, epilogueLdtmDps, epilogueLdtmBits, - epilogueTileM, epilogueTileN, gridTriggerSecondaryA, gridTriggerSecondaryB, - gridWaitForPrimaryEarlyExit, gridWaitForPrimaryA, gridWaitForPrimaryB, - hoistLoadTaskInit, hoistMmaTaskTryWaits, k, kernelTraits, layoutA, layoutB, m, mmaK, - mmaKind, mmaM, mmaN, mockAllReduce, n, numSlicesForSplitK, numSlicesForSliceK, - numStages, numStagesMma, numStagesMmaWithinWorkTile, numStagesMmaAcrossWorkTile, - numStagesWorkId, outputDebugTensors, patchF2fp, sfBlockSizeA, sfLayoutA, 
sfLayoutB, - sfLayoutC, sfReshapeFactor, sliceK, splitK, tileK, tileM, tileN, tileScheduler, - transposeMmaOutput, useCustomMmaSchedule, useDeepSeekFp8, - useHoistTryWaitForCustomMmaSchedule, usePerTokenSfA, usePerTokenSfB, - useShuffledMatrixA, useTmaStore, useTwoTmaLoadWarps, useTwoMmaWarps, - useUnrollLoop2xForMma, worldSize), - actType, clampBeforeAct), - mBatchedM(batchedM), - mBatchedN(batchedN), - mBatchMode(BatchMode(batchMode)), - mFusedAct(fusedAct), - mGridWaitForPrimaryRouting(gridWaitForPrimaryRouting), - mIsStaticBatch(isStaticBatch), - mNumBatches(numBatches), - mNumRegsPerThreadNonEpilogueWarp(numRegsPerThreadNonEpilogueWarp), - mNumRegsPerThreadEpilogueWarp(numRegsPerThreadEpilogueWarp), - mNumRegsCastAWarps(numRegsCastAWarps), - mNumTokens(numTokens), - mRouteImpl(routeImpl), - mUseTmaOobOpt(useTmaOobOpt) {} + BatchedGemmOptions(gemm::AllReduceAlgo allReduceAlgo, + gemm::BiasType biasType, + int blockK, + int clusterDimX, + int clusterDimY, + int clusterDimZ, + gemm::CtaSwizzleType ctaSwizzleType, + tg::Dtype dtypeAcc, + tg::Dtype dtypeA, + tg::Dtype dtypeB, + tg::Dtype dtypeC, + tg::Dtype dtypeMmaA, + tg::Dtype dtypeMmaB, + bool enablesEarlyExit, + bool enablesDelayedEarlyExit, + bool enablesGlobalPtxKnobs, + int epilogueLdtmDps, + int epilogueLdtmBits, + int epilogueTileM, + int epilogueTileN, + bool gridTriggerSecondaryA, + bool gridTriggerSecondaryB, + bool gridWaitForPrimaryEarlyExit, + bool gridWaitForPrimaryA, + bool gridWaitForPrimaryB, + bool hoistLoadTaskInit, + bool hoistMmaTaskTryWaits, + int k, + gemm::KernelTraits kernelTraits, + gemm::MatrixLayout layoutA, + gemm::MatrixLayout layoutB, + int m, + int mmaK, + tg::MmaKind mmaKind, + int mmaM, + int mmaN, + bool mockAllReduce, + int n, + int numRegsCastAWarps, + int numRegsCopySfLdsSttm, + int numRegsPerThreadEpilogueWarp, + int numRegsPerThreadNonEpilogueWarp, + int numSlicesForSplitK, + int numSlicesForSliceK, + int numStages, + int numStagesMma, + int numStagesMmaWithinWorkTile, + int numStagesMmaAcrossWorkTile, + int numStagesWorkId, + bool outputDebugTensors, + bool patchF2fp, + std::optional sfBlockSizeA, + tg::SfLayout sfLayoutA, + tg::SfLayout sfLayoutB, + tg::SfLayout sfLayoutC, + int32_t sfReshapeFactor, + bool sliceK, + gemm::SplitK splitK, + int tileK, + int tileM, + int tileN, + gemm::TileScheduler tileScheduler, + bool transposeMmaOutput, + bool useCustomMmaSchedule, + bool useDeepSeekFp8, + bool useHoistTryWaitForCustomMmaSchedule, + bool usePerTokenSfA, + bool usePerTokenSfB, + bool useShuffledMatrixA, + bool useTmaStore, + bool useTwoTmaLoadWarps, + bool useTwoMmaWarps, + bool useUnrollLoop2xForMma, + int worldSize, + gemmGatedAct::ActType actType, + bool clampBeforeAct, + std::vector batchedM, + std::vector batchedN, + BatchMode batchMode, + int numBatches, + bool isStaticBatch, + int numTokens, + RouteImpl routeImpl, + bool gridWaitForPrimaryRouting, + bool fusedAct, + bool useTmaOobOpt) + : gemmGatedAct::GemmGatedActOptions(gemm::GemmOptions(allReduceAlgo, + biasType, + blockK, + clusterDimX, + clusterDimY, + clusterDimZ, + ctaSwizzleType, + dtypeAcc, + dtypeA, + dtypeB, + dtypeC, + dtypeMmaA, + dtypeMmaB, + enablesEarlyExit, + enablesDelayedEarlyExit, + enablesGlobalPtxKnobs, + epilogueLdtmDps, + epilogueLdtmBits, + epilogueTileM, + epilogueTileN, + gridTriggerSecondaryA, + gridTriggerSecondaryB, + gridWaitForPrimaryEarlyExit, + gridWaitForPrimaryA, + gridWaitForPrimaryB, + hoistLoadTaskInit, + hoistMmaTaskTryWaits, + k, + kernelTraits, + layoutA, + layoutB, + m, + mmaK, + mmaKind, + mmaM, 
+ mmaN, + mockAllReduce, + n, + numRegsCastAWarps, + numRegsCopySfLdsSttm, + numRegsPerThreadEpilogueWarp, + numRegsPerThreadNonEpilogueWarp, + numSlicesForSplitK, + numSlicesForSliceK, + numStages, + numStagesMma, + numStagesMmaWithinWorkTile, + numStagesMmaAcrossWorkTile, + numStagesWorkId, + outputDebugTensors, + patchF2fp, + sfBlockSizeA, + sfLayoutA, + sfLayoutB, + sfLayoutC, + sfReshapeFactor, + sliceK, + splitK, + tileK, + tileM, + tileN, + tileScheduler, + transposeMmaOutput, + useCustomMmaSchedule, + useDeepSeekFp8, + useHoistTryWaitForCustomMmaSchedule, + usePerTokenSfA, + usePerTokenSfB, + useShuffledMatrixA, + useTmaStore, + useTwoTmaLoadWarps, + useTwoMmaWarps, + useUnrollLoop2xForMma, + worldSize), + actType, + clampBeforeAct) + , mBatchedM(batchedM) + , mBatchedN(batchedN) + , mBatchMode(BatchMode(batchMode)) + , mFusedAct(fusedAct) + , mGridWaitForPrimaryRouting(gridWaitForPrimaryRouting) + , mIsStaticBatch(isStaticBatch) + , mNumBatches(numBatches) + , mNumTokens(numTokens) + , mRouteImpl(routeImpl) + , mUseTmaOobOpt(useTmaOobOpt) {} // Batched M-dimensions of GEMM. std::vector mBatchedM; @@ -143,12 +263,6 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions { bool mIsStaticBatch{true}; // Number of Gemm batches. int mNumBatches; - // Number of registers per thread for non-epilogue warps - int mNumRegsPerThreadNonEpilogueWarp{0}; - // Number of registers per thread for epilogue warps - int mNumRegsPerThreadEpilogueWarp{0}; - // Number of registers for the cast A warps. - int mNumRegsCastAWarps{0}; // Total number of tokens. int mNumTokens{32}; // Whether load the input tokens and do routing. @@ -161,8 +275,10 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions { //////////////////////////////////////////////////////////////////////////////////////////////////// // Check if the options are valid or not. -bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackwell, +bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, + bool isBlackwell, bool updateOptions = true) { + bool isValid = true; if (options.mUseTmaOobOpt && !options.mUseTwoTmaLoadWarps) { if (updateOptions) { @@ -179,7 +295,7 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw isValid = gemmGatedAct::checkAndUpdateGemmGatedActOptions(options, isBlackwell, updateOptions); } else { isValid = - gemm::checkAndUpdateGemmOptions(options, isBlackwell, 1 /* tpGrpSize */, updateOptions); + gemm::checkAndUpdateGemmOptions(options, isBlackwell, 1 /* tpGrpSize */, updateOptions); } bool batchM = options.mBatchMode == BatchedGemmOptions::BatchMode::BatchM; @@ -231,7 +347,7 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw options.mK); TLLM_CHECK_ERROR(options.mDtypeC != tg::Dtype::E2m1 && options.mDtypeA == tg::Dtype::E4m3 && - options.mDtypeB == tg::Dtype::E4m3, + options.mDtypeB == tg::Dtype::E4m3, "E2m1 is not supported with DeepSeek FP8"); } @@ -263,22 +379,14 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw "Tokens need use SF linear layout when being routed"); } else { // Note: if B is cast from a non-block format to a block format, there are no SFs to load. 
- TLLM_CHECK_ERROR( - options.mSfLayoutB == tg::SfLayout::Linear || !tg::dtypeIsBlockFmt(options.mDtypeB), - "Tokens need use SF linear layout when being routed"); + TLLM_CHECK_ERROR(options.mSfLayoutB == tg::SfLayout::Linear || + !tg::dtypeIsBlockFmt(options.mDtypeB), + "Tokens need use SF linear layout when being routed"); } } - if (doesRouteImplUseTma(options.mRouteImpl)) { - TLLM_CHECK_ERROR(!batchM, "UTMALDG.GATHER4 only supported for batch N."); - - if (tg::mmaKindIsBlockFmt(options.mMmaKind)) { - auto dtypeRoute = batchM ? options.mDtypeA : options.mDtypeB; - TLLM_CHECK_ERROR(options.mTileK % tg::dtypeNumEltsPerSf(dtypeRoute) == 0, - "tileK needs to be a multiple of 16 * tg::dtypeNumEltsPerSf(dtypeA)."); - TLLM_CHECK_ERROR(options.mTileK % (tg::dtypeNumEltsPerSf(dtypeRoute) * 16) == 0, - "tileK needs to be a multiple of 16 * tg::dtypeNumEltsPerSf(dtypeA)."); - } + if (options.mClusterDimX > 1) { + TLLM_CHECK_ERROR(!batchM, "2CTA Gemm currently only supports batch N."); } if (!batchM || doesRouteImplUseNoRoute(options.mRouteImpl)) { @@ -290,8 +398,8 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw if (!gemm::isBiasTypeNone(options.mBiasType)) { TLLM_CHECK_ERROR((gemm::isBiasTypeN(options.mBiasType) && options.mBatchMode == BatchedGemmOptions::BatchMode::BatchM) || - (gemm::isBiasTypeM(options.mBiasType) && - options.mBatchMode == BatchedGemmOptions::BatchMode::BatchN), + (gemm::isBiasTypeM(options.mBiasType) && + options.mBatchMode == BatchedGemmOptions::BatchMode::BatchN), "BatchedGemm supports only per channel bias."); } @@ -323,6 +431,7 @@ struct BatchedGemmConfig { char const* mHash{nullptr}; #else trtllm::gen::CudaRunner* mCudaRunner{nullptr}; + int32_t mInstanceIdx{0}; #endif BatchedGemmOptions mOptions; @@ -345,18 +454,13 @@ inline std::string dumpOptions(BatchedGemmOptions const& options) { << std::endl; ss << "mGridWaitForPrimaryRouting=" << options.mGridWaitForPrimaryRouting << "," << std::endl; ss << "mFusedAct=" << options.mFusedAct << "," << std::endl; - ss << "mNumRegsPerThreadNonEpilogueWarp=" << options.mNumRegsPerThreadNonEpilogueWarp << "," - << std::endl; - ss << "mNumRegsPerThreadEpilogueWarp=" << options.mNumRegsPerThreadEpilogueWarp << "," - << std::endl; - ss << "mNumRegsCastAWarps=" << options.mNumRegsCastAWarps << "," << std::endl; ss << "mUseTmaOobOpt=" << options.mUseTmaOobOpt << std::endl; return ss.str(); } //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace batchedGemm +} // namespace batchedGemm //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -369,6 +473,6 @@ inline std::string dumpOptions(BatchedGemmOptions const& options) { #undef TLLM_LOG_INFO #undef TLLM_LOG_ERROR -#endif // TLLM_GEN_EXPORT_INTERFACE +#endif // TLLM_GEN_EXPORT_INTERFACE -} // namespace batchedGemm +} // namespace batchedGemm diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h index 6f2b1c270d..79c6c027d0 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h @@ -1,19 +1,19 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & - * AFFILIATES. All rights reserved. 
SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +* SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & +* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ #pragma once #include @@ -97,10 +97,29 @@ enum class TileScheduler { //////////////////////////////////////////////////////////////////////////////////////////////////// +enum class CtaSwizzleType : uint32_t { + // Rasterize CTAs along the M dimension. + RasterizeAlongM = 0, + // Rasterize CTAs along the N dimension. + RasterizeAlongN, + // Swizzle CTAs in zig-zag pattern along M dimension, Zig-zag width is 2. + ZigZagAlongM2, + // Swizzle CTAs in zig-zag pattern along N dimension, Zig-zag width is 2. + ZigZagAlongN2, + // Swizzle CTAs in zig-zag pattern along M dimension, Zig-zag width is 4. + ZigZagAlongM4, + // Swizzle CTAs in zig-zag pattern along N dimension, Zig-zag width is 4. + ZigZagAlongN4, +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + // Helper functions to check the SplitK type. -#define SPLIT_K_FUNCTION(Mode) \ - inline bool doesSplitKUse##Mode(SplitK mode) { return (mode == SplitK::Mode); } +#define SPLIT_K_FUNCTION(Mode) \ + inline bool doesSplitKUse##Mode(SplitK mode) { \ + return (mode == SplitK::Mode); \ + } SPLIT_K_FUNCTION(Gmem) SPLIT_K_FUNCTION(Dsmem) @@ -111,8 +130,10 @@ SPLIT_K_FUNCTION(Dsmem) // Helper functions to check the Bias type. 
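// Illustrative expansion: each of these *_FUNCTION helper macros defines a small inline
// predicate over the corresponding enum. For instance, SPLIT_K_FUNCTION(Gmem) above generates
//   inline bool doesSplitKUseGmem(SplitK mode) { return (mode == SplitK::Gmem); }
// and the BIAS_TYPE_FUNCTION macro below generates isBiasTypeNone(), isBiasTypeN(), etc.
// in the same way.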
-#define BIAS_TYPE_FUNCTION(Mode) \ - inline bool isBiasType##Mode(BiasType type) { return (type == BiasType::Mode); } +#define BIAS_TYPE_FUNCTION(Mode) \ + inline bool isBiasType##Mode(BiasType type) { \ + return (type == BiasType::Mode); \ + } BIAS_TYPE_FUNCTION(None) BIAS_TYPE_FUNCTION(N) @@ -123,6 +144,6 @@ BIAS_TYPE_FUNCTION(Mn) //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace gemm +} // namespace gemm -} // namespace batchedGemm +} // namespace batchedGemm diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h index 1086cd4fd5..c9a9a4663f 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h @@ -1,19 +1,19 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & - * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +* SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & +* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ #pragma once #include "GemmOptions.h" @@ -21,20 +21,20 @@ #ifdef TLLM_GEN_EXPORT_INTERFACE #include -#define TLLM_CHECK_ERROR(cond, ...) \ - if (!(cond)) { \ - printArgs(__VA_ARGS__); \ - return false; \ +#define TLLM_CHECK_ERROR(cond, ...) \ + if (!(cond)) { \ + printArgs(__VA_ARGS__); \ + return false; \ } #define TLLM_LOG_ERROR(...) TLLM_CHECK_ERROR(false, __VA_ARGS__) #define TLLM_CHECK_ERROR_FMT(...) TLLM_CHECK_ERROR(false, __VA_ARGS__) -#define TLLM_CHECK_WARNING(cond, ...) \ - if (!(cond)) { \ - printArgs(__VA_ARGS__); \ - return false; \ +#define TLLM_CHECK_WARNING(cond, ...) \ + if (!(cond)) { \ + printArgs(__VA_ARGS__); \ + return false; \ } #define TLLM_LOG_WARNING(...) TLLM_CHECK_WARNING(false, __VA_ARGS__) @@ -74,8 +74,10 @@ enum class ActType { // Helper functions to check the ActType type. 
-#define TLLM_ACT_TYPE_FUNCTION(actType) \ - inline bool is##actType(ActType type) { return (type == ActType::actType); } +#define TLLM_ACT_TYPE_FUNCTION(actType) \ + inline bool is##actType(ActType type) { \ + return (type == ActType::actType); \ + } TLLM_ACT_TYPE_FUNCTION(SwiGlu) TLLM_ACT_TYPE_FUNCTION(GeGlu) @@ -86,12 +88,12 @@ TLLM_ACT_TYPE_FUNCTION(GeGlu) inline std::string getActTypeName(ActType type) { switch (type) { - case ActType::SwiGlu: - return "SwiGlu"; - case ActType::GeGlu: - return "GeGlu"; - default: - return "Unknown type"; + case ActType::SwiGlu: + return "SwiGlu"; + case ActType::GeGlu: + return "GeGlu"; + default: + return "Unknown type"; } } @@ -100,7 +102,9 @@ inline std::string getActTypeName(ActType type) { struct GemmGatedActOptions : public gemm::GemmOptions { GemmGatedActOptions() = default; GemmGatedActOptions(gemm::GemmOptions options, ActType actType, bool clampBeforeAct) - : gemm::GemmOptions(options), mActType(actType), mClampBeforeAct(clampBeforeAct) {} + : gemm::GemmOptions(options) + , mActType(actType) + , mClampBeforeAct(clampBeforeAct) {} // Type of the gated activation. ActType mActType{ActType::SwiGlu}; @@ -112,12 +116,14 @@ struct GemmGatedActOptions : public gemm::GemmOptions { // Check if the options are valid or not. inline bool checkAndUpdateGemmGatedActOptions(gemmGatedAct::GemmGatedActOptions& options, - bool isBlackwell, bool updateOptions = true) { + bool isBlackwell, + bool updateOptions = true) { + // tmpOut is already transposed at this stage auto const hiddenSizeStr = options.mTransposeMmaOutput ? "M" : "N"; auto const hiddenSize = options.mTransposeMmaOutput ? options.mM : options.mN; auto const hiddenEpilogueTileSize = - options.mTransposeMmaOutput ? options.mEpilogueTileM : options.mEpilogueTileN; + options.mTransposeMmaOutput ? options.mEpilogueTileM : options.mEpilogueTileN; TLLM_CHECK_ERROR(hiddenSize % 2 == 0, hiddenSizeStr, " must be a multiple of 2."); @@ -126,8 +132,8 @@ inline bool checkAndUpdateGemmGatedActOptions(gemmGatedAct::GemmGatedActOptions& if (options.mUseTmaStore) { TLLM_CHECK_ERROR( - hiddenEpilogueTileSize * tg::dtypeGetNumBits(options.mDtypeC) / /* bits */ 8 % 32 == 0, - "Unsupported output hidden tile size"); + hiddenEpilogueTileSize * tg::dtypeGetNumBits(options.mDtypeC) / /* bits */ 8 % 32 == 0, + "Unsupported output hidden tile size"); } if (options.mUseDeepSeekFp8) { @@ -137,12 +143,18 @@ inline bool checkAndUpdateGemmGatedActOptions(gemmGatedAct::GemmGatedActOptions& if (options.mDtypeC == tg::Dtype::E2m1 || options.mDtypeC == tg::Dtype::MxE4m3) { int const outHiddenSize = (options.mTransposeMmaOutput ? 
options.mM : options.mN) / 2; int const hiddenGranularity = 4 * tg::dtypeNumEltsPerSf(options.mDtypeC); - TLLM_CHECK_ERROR(outHiddenSize % hiddenGranularity == 0, "Output hidden size (", outHiddenSize, - ") must be a multiple of ", hiddenGranularity, " for block-scaled outputs."); + TLLM_CHECK_ERROR(outHiddenSize % hiddenGranularity == 0, + "Output hidden size (", + outHiddenSize, + ") must be a multiple of ", + hiddenGranularity, + " for block-scaled outputs."); } - auto isValid = gemm::checkAndUpdateGemmOptions(options, isBlackwell, - /* tpGrpSize */ 1, updateOptions); + auto isValid = gemm::checkAndUpdateGemmOptions(options, + isBlackwell, + /* tpGrpSize */ 1, + updateOptions); if (!isValid) { return false; @@ -191,6 +203,7 @@ struct GemmGatedActConfig { char const* mHash{nullptr}; #else trtllm::gen::CudaRunner* mCudaRunner{nullptr}; + int32_t mInstanceIdx{0}; #endif GemmGatedActOptions mOptions{}; @@ -199,7 +212,7 @@ struct GemmGatedActConfig { //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace gemmGatedAct +} // namespace gemmGatedAct #ifdef TLLM_GEN_EXPORT_INTERFACE @@ -209,6 +222,6 @@ struct GemmGatedActConfig { #undef TLLM_LOG_WARNING #undef TLLM_LOG_INFO #undef TLLM_LOG_ERROR -#endif // TLLM_GEN_EXPORT_INTERFACE +#endif // TLLM_GEN_EXPORT_INTERFACE -} // namespace batchedGemm +} // namespace batchedGemm diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h index f9c7044700..18386f3a2b 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h @@ -1,19 +1,19 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & - * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +* SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & +* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ #pragma once #include @@ -24,23 +24,21 @@ #include "KernelParams.h" #include "KernelTraits.h" #include "trtllm/gen/DtypeDecl.h" -#include "trtllm/gen/MmaDecl.h" #include "trtllm/gen/SfLayoutDecl.h" +#include "trtllm/gen/MmaDecl.h" #ifndef TLLM_GEN_EXPORT_INTERFACE -#include "trtllm/gen/CudaRunner.h" #include "trtllm/gen/GenCtx.h" +#include "trtllm/gen/CudaRunner.h" #else #include -template -void printArgs(T arg) { +template void printArgs(T arg) { #ifdef TLLM_GEN_DEBUG std::cout << arg; #endif } -template -void printArgs(T first, Args... args) { +template void printArgs(T first, Args... args) { printArgs(first); if constexpr (sizeof...(args) > 0) { printArgs(", "); @@ -48,29 +46,29 @@ void printArgs(T first, Args... args) { } } -#define TLLM_CHECK_ERROR(cond, ...) \ - if (!(cond)) { \ - printArgs(__VA_ARGS__); \ - printArgs("\n"); \ - return false; \ +#define TLLM_CHECK_ERROR(cond, ...) \ + if (!(cond)) { \ + printArgs(__VA_ARGS__); \ + printArgs("\n"); \ + return false; \ } #define TLLM_LOG_ERROR(...) TLLM_CHECK_ERROR(false, __VA_ARGS__) #define TLLM_CHECK_ERROR_FMT(cond, ...) TLLM_CHECK_ERROR(cond, __VA_ARGS__) -#define TLLM_CHECK_WARNING(cond, ...) \ - if (!(cond)) { \ - printArgs(__VA_ARGS__); \ - printArgs("\n"); \ - return false; \ +#define TLLM_CHECK_WARNING(cond, ...) \ + if (!(cond)) { \ + printArgs(__VA_ARGS__); \ + printArgs("\n"); \ + return false; \ } #define TLLM_LOG_WARNING(...) TLLM_CHECK_WARNING(false, __VA_ARGS__) #define TLLM_LOG_INFO(...) TLLM_CHECK_WARNING(false, __VA_ARGS__) -#endif // TLLM_GEN_EXPORT_INTERFACE +#endif // TLLM_GEN_EXPORT_INTERFACE namespace batchedGemm { @@ -91,95 +89,154 @@ struct GemmOptions { #endif GemmOptions() = default; - GemmOptions(AllReduceAlgo allReduceAlgo, BiasType biasType, int blockK, int clusterDimX, - int clusterDimY, int clusterDimZ, tg::Dtype dtypeAcc, tg::Dtype dtypeA, - tg::Dtype dtypeB, tg::Dtype dtypeC, tg::Dtype dtypeMmaA, tg::Dtype dtypeMmaB, - bool enablesEarlyExit, bool enablesDelayedEarlyExit, bool enablesGlobalPtxKnobs, - int epilogueLdtmDps, int epilogueLdtmBits, int epilogueTileM, int epilogueTileN, - bool gridTriggerSecondaryA, bool gridTriggerSecondaryB, - bool gridWaitForPrimaryEarlyExit, bool gridWaitForPrimaryA, bool gridWaitForPrimaryB, - bool hoistLoadTaskInit, bool hoistMmaTaskTryWaits, int k, KernelTraits kernelTraits, - MatrixLayout layoutA, MatrixLayout layoutB, int m, int mmaK, tg::MmaKind mmaKind, - int mmaM, int mmaN, bool mockAllReduce, int n, int numSlicesForSplitK, - int numSlicesForSliceK, int numStages, int numStagesMma, - int numStagesMmaWithinWorkTile, int numStagesMmaAcrossWorkTile, int numStagesWorkId, - bool outputDebugTensors, bool patchF2fp, std::optional sfBlockSizeA, - tg::SfLayout sfLayoutA, tg::SfLayout sfLayoutB, tg::SfLayout sfLayoutC, - int sfReshapeFactor, bool sliceK, SplitK splitK, int tileK, int tileM, int tileN, - TileScheduler tileScheduler, bool transposeMmaOutput, bool useCustomMmaSchedule, - bool useDeepSeekFp8, bool useHoistTryWaitForCustomMmaSchedule, bool usePerTokenSfA, - bool usePerTokenSfB, bool useShuffledMatrixA, bool useTmaStore, - bool useTwoTmaLoadWarps, bool useTwoMmaWarps, bool useUnrollLoop2xForMma, + GemmOptions(AllReduceAlgo allReduceAlgo, + BiasType biasType, + int blockK, + int clusterDimX, + int clusterDimY, + int clusterDimZ, + CtaSwizzleType ctaSwizzleType, + tg::Dtype dtypeAcc, + tg::Dtype dtypeA, + tg::Dtype dtypeB, + tg::Dtype dtypeC, + tg::Dtype dtypeMmaA, + tg::Dtype dtypeMmaB, + bool enablesEarlyExit, + bool enablesDelayedEarlyExit, + bool 
enablesGlobalPtxKnobs, + int epilogueLdtmDps, + int epilogueLdtmBits, + int epilogueTileM, + int epilogueTileN, + bool gridTriggerSecondaryA, + bool gridTriggerSecondaryB, + bool gridWaitForPrimaryEarlyExit, + bool gridWaitForPrimaryA, + bool gridWaitForPrimaryB, + bool hoistLoadTaskInit, + bool hoistMmaTaskTryWaits, + int k, + KernelTraits kernelTraits, + MatrixLayout layoutA, + MatrixLayout layoutB, + int m, + int mmaK, + tg::MmaKind mmaKind, + int mmaM, + int mmaN, + bool mockAllReduce, + int n, + int numRegsCastAWarps, + int numRegsCopySfLdsSttm, + int numRegsPerThreadEpilogueWarp, + int numRegsPerThreadNonEpilogueWarp, + int numSlicesForSplitK, + int numSlicesForSliceK, + int numStages, + int numStagesMma, + int numStagesMmaWithinWorkTile, + int numStagesMmaAcrossWorkTile, + int numStagesWorkId, + bool outputDebugTensors, + bool patchF2fp, + std::optional sfBlockSizeA, + tg::SfLayout sfLayoutA, + tg::SfLayout sfLayoutB, + tg::SfLayout sfLayoutC, + int sfReshapeFactor, + bool sliceK, + SplitK splitK, + int tileK, + int tileM, + int tileN, + TileScheduler tileScheduler, + bool transposeMmaOutput, + bool useCustomMmaSchedule, + bool useDeepSeekFp8, + bool useHoistTryWaitForCustomMmaSchedule, + bool usePerTokenSfA, + bool usePerTokenSfB, + bool useShuffledMatrixA, + bool useTmaStore, + bool useTwoTmaLoadWarps, + bool useTwoMmaWarps, + bool useUnrollLoop2xForMma, int worldSize) - : mAllReduceAlgo{allReduceAlgo}, - mBiasType{biasType}, - mBlockK(blockK), - mClusterDimX{clusterDimX}, - mClusterDimY{clusterDimY}, - mClusterDimZ{clusterDimZ}, - mDtypeAcc{dtypeAcc}, - mDtypeA{dtypeA}, - mDtypeB{dtypeB}, - mDtypeC{dtypeC}, - mDtypeMmaA{dtypeMmaA}, - mDtypeMmaB{dtypeMmaB}, - mEnablesEarlyExit{enablesEarlyExit}, - mEnablesDelayedEarlyExit{enablesDelayedEarlyExit}, - mEnablesGlobalPtxKnobs{enablesGlobalPtxKnobs}, - mEpilogueLdtmDps{epilogueLdtmDps}, - mEpilogueLdtmBits{epilogueLdtmBits}, - mEpilogueTileM{epilogueTileM}, - mEpilogueTileN{epilogueTileN}, - mGridTriggerSecondaryA{gridTriggerSecondaryA}, - mGridTriggerSecondaryB{gridTriggerSecondaryB}, - mGridWaitForPrimaryEarlyExit{gridWaitForPrimaryEarlyExit}, - mGridWaitForPrimaryA{gridWaitForPrimaryA}, - mGridWaitForPrimaryB{gridWaitForPrimaryB}, - mHoistLoadTaskInit{hoistLoadTaskInit}, - mHoistMmaTaskTryWaits{hoistMmaTaskTryWaits}, - mK{k}, - mKernelTraits{kernelTraits}, - mLayoutA{layoutA}, - mLayoutB{layoutB}, - mM{m}, - mMmaK{mmaK}, - mMmaKind{mmaKind}, - mMmaM{mmaM}, - mMmaN{mmaN}, - mMockAllReduce{mockAllReduce}, - mN{n}, - mNumSlicesForSplitK{numSlicesForSplitK}, - mNumSlicesForSliceK{numSlicesForSliceK}, - mNumStages{numStages}, - mNumStagesMma{numStagesMma}, - mNumStagesMmaWithinWorkTile{numStagesMmaWithinWorkTile}, - mNumStagesMmaAcrossWorkTile{numStagesMmaAcrossWorkTile}, - mNumStagesWorkId{numStagesWorkId}, - mOutputDebugTensors{outputDebugTensors}, - mPatchF2fp{patchF2fp}, - mSfBlockSizeA{sfBlockSizeA}, - mSfLayoutA{sfLayoutA}, - mSfLayoutB{sfLayoutB}, - mSfLayoutC{sfLayoutC}, - mSfReshapeFactor{sfReshapeFactor}, - mSliceK{sliceK}, - mSplitK{splitK}, - mTileK{tileK}, - mTileM{tileM}, - mTileN{tileN}, - mTileScheduler{tileScheduler}, - mTransposeMmaOutput{transposeMmaOutput}, - mUseCustomMmaSchedule{useCustomMmaSchedule}, - mUseDeepSeekFp8{useDeepSeekFp8}, - mUseHoistTryWaitForCustomMmaSchedule{useHoistTryWaitForCustomMmaSchedule}, - mUsePerTokenSfA{usePerTokenSfA}, - mUsePerTokenSfB{usePerTokenSfB}, - mUseShuffledMatrixA{useShuffledMatrixA}, - mUseTmaStore{useTmaStore}, - mUseTwoTmaLoadWarps{useTwoTmaLoadWarps}, - 
mUseTwoMmaWarps{useTwoMmaWarps}, - mUseUnrollLoop2xForMma{useUnrollLoop2xForMma}, - mWorldSize{worldSize} {} + : mAllReduceAlgo{allReduceAlgo} + , mBiasType{biasType} + , mBlockK(blockK) + , mClusterDimX{clusterDimX} + , mClusterDimY{clusterDimY} + , mClusterDimZ{clusterDimZ} + , mCtaSwizzleType{ctaSwizzleType} + , mDtypeAcc{dtypeAcc} + , mDtypeA{dtypeA} + , mDtypeB{dtypeB} + , mDtypeC{dtypeC} + , mDtypeMmaA{dtypeMmaA} + , mDtypeMmaB{dtypeMmaB} + , mEnablesEarlyExit{enablesEarlyExit} + , mEnablesDelayedEarlyExit{enablesDelayedEarlyExit} + , mEnablesGlobalPtxKnobs{enablesGlobalPtxKnobs} + , mEpilogueLdtmDps{epilogueLdtmDps} + , mEpilogueLdtmBits{epilogueLdtmBits} + , mEpilogueTileM{epilogueTileM} + , mEpilogueTileN{epilogueTileN} + , mGridTriggerSecondaryA{gridTriggerSecondaryA} + , mGridTriggerSecondaryB{gridTriggerSecondaryB} + , mGridWaitForPrimaryEarlyExit{gridWaitForPrimaryEarlyExit} + , mGridWaitForPrimaryA{gridWaitForPrimaryA} + , mGridWaitForPrimaryB{gridWaitForPrimaryB} + , mHoistLoadTaskInit{hoistLoadTaskInit} + , mHoistMmaTaskTryWaits{hoistMmaTaskTryWaits} + , mK{k} + , mKernelTraits{kernelTraits} + , mLayoutA{layoutA} + , mLayoutB{layoutB} + , mM{m} + , mMmaK{mmaK} + , mMmaKind{mmaKind} + , mMmaM{mmaM} + , mMmaN{mmaN} + , mMockAllReduce{mockAllReduce} + , mN{n} + , mNumRegsCastAWarps(numRegsCastAWarps) + , mNumRegsCopySfLdsSttm(numRegsCopySfLdsSttm) + , mNumRegsPerThreadEpilogueWarp(numRegsPerThreadEpilogueWarp) + , mNumRegsPerThreadNonEpilogueWarp(numRegsPerThreadNonEpilogueWarp) + , mNumSlicesForSplitK{numSlicesForSplitK} + , mNumSlicesForSliceK{numSlicesForSliceK} + , mNumStages{numStages} + , mNumStagesMma{numStagesMma} + , mNumStagesMmaWithinWorkTile{numStagesMmaWithinWorkTile} + , mNumStagesMmaAcrossWorkTile{numStagesMmaAcrossWorkTile} + , mNumStagesWorkId{numStagesWorkId} + , mOutputDebugTensors{outputDebugTensors} + , mPatchF2fp{patchF2fp} + , mSfBlockSizeA{sfBlockSizeA} + , mSfLayoutA{sfLayoutA} + , mSfLayoutB{sfLayoutB} + , mSfLayoutC{sfLayoutC} + , mSfReshapeFactor{sfReshapeFactor} + , mSliceK{sliceK} + , mSplitK{splitK} + , mTileK{tileK} + , mTileM{tileM} + , mTileN{tileN} + , mTileScheduler{tileScheduler} + , mTransposeMmaOutput{transposeMmaOutput} + , mUseCustomMmaSchedule{useCustomMmaSchedule} + , mUseDeepSeekFp8{useDeepSeekFp8} + , mUseHoistTryWaitForCustomMmaSchedule{useHoistTryWaitForCustomMmaSchedule} + , mUsePerTokenSfA{usePerTokenSfA} + , mUsePerTokenSfB{usePerTokenSfB} + , mUseShuffledMatrixA{useShuffledMatrixA} + , mUseTmaStore{useTmaStore} + , mUseTwoTmaLoadWarps{useTwoTmaLoadWarps} + , mUseTwoMmaWarps{useTwoMmaWarps} + , mUseUnrollLoop2xForMma{useUnrollLoop2xForMma} + , mWorldSize{worldSize} {} // The all-reduce algorithm. AllReduceAlgo mAllReduceAlgo{AllReduceAlgo::None}; @@ -193,6 +250,8 @@ struct GemmOptions { int mClusterDimY{1}; // Cluster size in Z dim. int mClusterDimZ{1}; + // The type of CTA swizzle. + CtaSwizzleType mCtaSwizzleType{CtaSwizzleType::RasterizeAlongM}; // Data type of the accumulators. tg::Dtype mDtypeAcc{tg::Dtype::Fp32}; // Data type of the A matrix. @@ -263,6 +322,14 @@ struct GemmOptions { bool mMockAllReduce{false}; // The N dimension of GEMM. int mN{64 * 4}; + // Number of registers for the cast A warps. + int mNumRegsCastAWarps{0}; + // Number of registers for the LDS+STTM warps. 
+ int mNumRegsCopySfLdsSttm{0}; + // Number of registers per thread for epilogue warps + int mNumRegsPerThreadEpilogueWarp{0}; + // Number of registers per thread for non-epilogue warps + int mNumRegsPerThreadNonEpilogueWarp{0}; // Number of partitions along the K dimension. When mNumSlicesForSplitK > 1, // the problem is distributed across several SMs, where each CTA works on its local K slice. // Partial results are accumulated afterwards using either GMEM or DSMEM (in CGA) @@ -369,6 +436,7 @@ struct GemmConfig { char const* mHash{nullptr}; #else trtllm::gen::CudaRunner* mCudaRunner{nullptr}; + int32_t mInstanceIdx{0}; #endif GemmOptions mOptions{}; @@ -378,22 +446,19 @@ struct GemmConfig { //////////////////////////////////////////////////////////////////////////////////////////////////// // Serialization helpers. -template -inline std::string toString(T e) { +template inline std::string toString(T e) { return std::to_string(e); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template <> -inline std::string toString(trtllm::gen::Dtype e) { +template <> inline std::string toString(trtllm::gen::Dtype e) { return trtllm::gen::dtypeToString(e); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template <> -inline std::string toString(trtllm::gen::MmaKind e) { +template <> inline std::string toString(trtllm::gen::MmaKind e) { return trtllm::gen::mmaKindToString(e); } @@ -409,6 +474,8 @@ inline std::string dumpOptions(GemmOptions const& options) { ss << "mClusterDimX=" << options.mClusterDimX << "," << std::endl; ss << "mClusterDimY=" << options.mClusterDimY << "," << std::endl; ss << "mClusterDimZ=" << options.mClusterDimZ << "," << std::endl; + ss << "mCtaSwizzleType=" << "gemm::CtaSwizzleType(" + << static_cast(options.mCtaSwizzleType) << ")" << "," << std::endl; ss << "mDtypeAcc=" << "trtllm::gen::Dtype(" << static_cast(options.mDtypeAcc) << ")" << "," << std::endl; ss << "mDtypeA=" << "trtllm::gen::Dtype(" << static_cast(options.mDtypeA) << ")" << "," @@ -449,6 +516,12 @@ inline std::string dumpOptions(GemmOptions const& options) { ss << "mMmaN=" << options.mMmaN << "," << std::endl; ss << "mMockAllReduce=" << options.mMockAllReduce << "," << std::endl; ss << "mN=" << options.mN << "," << std::endl; + ss << "mNumRegsCastAWarps=" << options.mNumRegsCastAWarps << "," << std::endl; + ss << "mNumRegsCopySfLdsSttm=" << options.mNumRegsCopySfLdsSttm << "," << std::endl; + ss << "mNumRegsPerThreadEpilogueWarp=" << options.mNumRegsPerThreadEpilogueWarp << "," + << std::endl; + ss << "mNumRegsPerThreadNonEpilogueWarp=" << options.mNumRegsPerThreadNonEpilogueWarp << "," + << std::endl; ss << "mNumSlicesForSplitK=" << options.mNumSlicesForSplitK << "," << std::endl; ss << "mNumSlicesForSliceK=" << options.mNumSlicesForSliceK << "," << std::endl; ss << "mNumStages=" << options.mNumStages << "," << std::endl; @@ -496,15 +569,13 @@ inline std::string dumpOptions(GemmOptions const& options) { //////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline T divUp(T a, T b) { +template inline T divUp(T a, T b) { return (a + b - 1) / b; } //////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline T divUpMul(T a, T b) { +template inline T divUpMul(T a, T b) { return gemm::divUp(a, b) * b; } @@ -521,7 +592,9 @@ inline int32_t getShuffleBlockSize(int epilogueTileM) { 
//////////////////////////////////////////////////////////////////////////////////////////////////// // Check if the options are valid or not. -inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, int tpGrpSize, +inline bool checkAndUpdateGemmOptions(GemmOptions& options, + bool isBlackwell, + int tpGrpSize, bool updateOptions = true) { options.mWorldSize = tpGrpSize; @@ -552,21 +625,25 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in // Check that the A cast is supported. // Currently, we only support {MxFp4, NvFp4} -> Bf16. TLLM_CHECK_ERROR( - (options.mDtypeA == options.mDtypeMmaA) || - ((options.mDtypeA == tg::Dtype::MxE2m1 || options.mDtypeA == tg::Dtype::E2m1) && - options.mDtypeMmaA == tg::Dtype::Bfloat16) || - (options.mDtypeA == tg::Dtype::E2m1 && options.mDtypeMmaA == tg::Dtype::E4m3), - "Unsupported cast for A: ", tg::dtypeToString(options.mDtypeA), " -> ", - tg::dtypeToString(options.mDtypeMmaA)); + (options.mDtypeA == options.mDtypeMmaA) || + ((options.mDtypeA == tg::Dtype::MxE2m1 || options.mDtypeA == tg::Dtype::E2m1) && + options.mDtypeMmaA == tg::Dtype::Bfloat16) || + (options.mDtypeA == tg::Dtype::E2m1 && options.mDtypeMmaA == tg::Dtype::E4m3), + "Unsupported cast for A: ", + tg::dtypeToString(options.mDtypeA), + " -> ", + tg::dtypeToString(options.mDtypeMmaA)); // Check that the B cast is supported. // Currently, we only support Fp8 -> MxFp8. // TODO: add same support for A (no transpose) TLLM_CHECK_ERROR( - (options.mDtypeB == options.mDtypeMmaB) || - (options.mDtypeB == tg::Dtype::E4m3 && options.mDtypeMmaB == tg::Dtype::MxE4m3), - "Unsupported cast for B: ", tg::dtypeToString(options.mDtypeB), " -> ", - tg::dtypeToString(options.mDtypeMmaB)); + (options.mDtypeB == options.mDtypeMmaB) || + (options.mDtypeB == tg::Dtype::E4m3 && options.mDtypeMmaB == tg::Dtype::MxE4m3), + "Unsupported cast for B: ", + tg::dtypeToString(options.mDtypeB), + " -> ", + tg::dtypeToString(options.mDtypeMmaB)); if (options.mDtypeA != options.mDtypeMmaA) { TLLM_CHECK_ERROR(options.mTileM == 128, @@ -574,9 +651,9 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in } if (options.mPatchF2fp) { - TLLM_CHECK_ERROR( - options.mDtypeA == tg::Dtype::MxE2m1 && options.mDtypeMmaA == tg::Dtype::Bfloat16, - "PatchF2fp is only supported for MxFp4 to Bf16 casts."); + TLLM_CHECK_ERROR(options.mDtypeA == tg::Dtype::MxE2m1 && + options.mDtypeMmaA == tg::Dtype::Bfloat16, + "PatchF2fp is only supported for MxFp4 to Bf16 casts."); } // FIXME: We do not support different dtypes for A and B when not on Blackwell. 
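// Note on the check macros used throughout checkAndUpdateGemmOptions: under
// TLLM_GEN_EXPORT_INTERFACE, TLLM_CHECK_ERROR(cond, ...) does not throw. It forwards its
// arguments to printArgs (which only prints when TLLM_GEN_DEBUG is defined) and makes the
// enclosing function return false. As a sketch, a call such as
//   TLLM_CHECK_ERROR(options.mMmaK == 32, "Unsupported MmaK: ", options.mMmaK);
// expands to
//   if (!(options.mMmaK == 32)) {
//     printArgs("Unsupported MmaK: ", options.mMmaK);
//     printArgs("\n");
//     return false;
//   }
// so invalid configurations surface through the bool result rather than an exception.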
@@ -594,14 +671,14 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in // kind::mxf8f6f4 if (options.mDtypeMmaA == tg::Dtype::MxE4m3 || options.mDtypeMmaA == tg::Dtype::MxE2m1) { - TLLM_CHECK_ERROR( - options.mDtypeMmaB == tg::Dtype::MxE4m3 || options.mDtypeMmaB == tg::Dtype::MxE2m1, - "For dtypeMmaA = MxE4m3 or MxE2m1, dtypeMmaB must also be MxE4m3 or MxE2m1."); + TLLM_CHECK_ERROR(options.mDtypeMmaB == tg::Dtype::MxE4m3 || + options.mDtypeMmaB == tg::Dtype::MxE2m1, + "For dtypeMmaA = MxE4m3 or MxE2m1, dtypeMmaB must also be MxE4m3 or MxE2m1."); } if (options.mDtypeMmaB == tg::Dtype::MxE4m3 || options.mDtypeMmaB == tg::Dtype::MxE2m1) { - TLLM_CHECK_ERROR( - options.mDtypeMmaA == tg::Dtype::MxE4m3 || options.mDtypeMmaA == tg::Dtype::MxE2m1, - "For dtypeMmaB = MxE4m3 or MxE2m1, dtypeMmaA must also be MxE4m3 or MxE2m1."); + TLLM_CHECK_ERROR(options.mDtypeMmaA == tg::Dtype::MxE4m3 || + options.mDtypeMmaA == tg::Dtype::MxE2m1, + "For dtypeMmaB = MxE4m3 or MxE2m1, dtypeMmaA must also be MxE4m3 or MxE2m1."); } // kind::f16 @@ -635,8 +712,11 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in if ((options.mMmaKind == tg::MmaKind::Fp8Fp6Fp4 || options.mMmaKind == tg::MmaKind::MxFp8Fp6Fp4) && options.mMmaK != 32) { - TLLM_LOG_WARNING("Unsupported MmaK (", options.mMmaK, - ") for MmaKind=", gemm::toString(options.mMmaKind), ". Setting MmaK to 32"); + TLLM_LOG_WARNING("Unsupported MmaK (", + options.mMmaK, + ") for MmaKind=", + gemm::toString(options.mMmaKind), + ". Setting MmaK to 32"); if (updateOptions) { options.mMmaK = 32; options.mTileK = std::max(options.mMmaK, options.mTileK); @@ -648,13 +728,19 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in // Check LDTM shape. if (isBlackwell) { TLLM_CHECK_ERROR((options.mEpilogueLdtmDps == 16 && options.mEpilogueLdtmBits == 256) || - (options.mEpilogueLdtmDps == 32 && options.mEpilogueLdtmBits == 32), - "Unsupported LDTM shape: ", options.mEpilogueLdtmDps, "dp", - options.mEpilogueLdtmBits, "bit."); + (options.mEpilogueLdtmDps == 32 && options.mEpilogueLdtmBits == 32), + "Unsupported LDTM shape: ", + options.mEpilogueLdtmDps, + "dp", + options.mEpilogueLdtmBits, + "bit."); if (options.mEpilogueTileM == 64) { TLLM_CHECK_ERROR(options.mEpilogueLdtmDps == 16, - "Unsupported LDTM shape for epilogueTileM=64: ", options.mEpilogueLdtmDps, - "dp", options.mEpilogueLdtmBits, "bit."); + "Unsupported LDTM shape for epilogueTileM=64: ", + options.mEpilogueLdtmDps, + "dp", + options.mEpilogueLdtmBits, + "bit."); } if (options.mTransposeMmaOutput) { // We can't use 32dp32bit LDTM for transposed outputs because we need each thread to own @@ -664,27 +750,43 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in } } else { TLLM_CHECK_ERROR( - options.mEpilogueLdtmDps == 16 && options.mEpilogueLdtmBits == 256, - "Hopper does not use TMEM. The register layout corresponds to 16dp256bit. Got ", - options.mEpilogueLdtmDps, "dp", options.mEpilogueLdtmBits, "bit."); + options.mEpilogueLdtmDps == 16 && options.mEpilogueLdtmBits == 256, + "Hopper does not use TMEM. The register layout corresponds to 16dp256bit. Got ", + options.mEpilogueLdtmDps, + "dp", + options.mEpilogueLdtmBits, + "bit."); } // Constraints for NvFp4 and MxFp8. 
if ((options.mMmaKind == tg::MmaKind::MxFp4NvFp4 || options.mMmaKind == tg::MmaKind::MxFp8Fp6Fp4 || options.mDtypeC == tg::Dtype::MxE4m3) && options.mMmaM != 128) { - // MMA M must be 128 when the input uses block scaling, or when the output is an Mx format. - int newTileM = 128 * divUp(options.mTileM, 128); - TLLM_LOG_WARNING("Unsupported MmaM (", options.mMmaM, - ") for MmaKind=", gemm::toString(options.mMmaKind), - ". Setting MmaM to 128 and TileM to ", newTileM); - if (updateOptions) { - options.mMmaM = 128; - options.mTileM = newTileM; + + if (options.mClusterDimX == 1) { + // MMA M must be 128 when the input uses block scaling, or when the output is an Mx format. + int newTileM = 128 * divUp(options.mTileM, 128); + TLLM_LOG_WARNING("Unsupported MmaM (", + options.mMmaM, + ") for MmaKind=", + gemm::toString(options.mMmaKind), + ". Setting MmaM to 128 and TileM to ", + newTileM); + if (updateOptions) { + options.mMmaM = 128; + options.mTileM = newTileM; + } else { + return false; + } } else { - return false; + TLLM_CHECK_ERROR(options.mMmaM == 256 && options.mTileM == 128, + "2CTA UTCxMMA only supports mmaM = 256 and tileM = 128."); } } + if (options.mClusterDimX > 1) { + TLLM_CHECK_ERROR(options.mLayoutB != MatrixLayout::BlockMajorK, + "layoutB == MatrixLayout::BlockMajorK is not supported for now"); + } if (options.mMmaKind == tg::MmaKind::MxFp4NvFp4 || options.mMmaKind == tg::MmaKind::MxFp8Fp6Fp4) { TLLM_CHECK_ERROR(isBlackwell, "Block scaling is only supported on Blackwell"); @@ -700,9 +802,14 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in } if (options.mMmaK != mmaK) { int newTileK = mmaK * divUp(options.mTileK, mmaK); - TLLM_LOG_WARNING("Unsupported MmaK (", options.mMmaK, - ") for MmaKind=", gemm::toString(options.mMmaKind), ". Setting MmaK to ", - mmaK, " and TileK to ", newTileK); + TLLM_LOG_WARNING("Unsupported MmaK (", + options.mMmaK, + ") for MmaKind=", + gemm::toString(options.mMmaKind), + ". Setting MmaK to ", + mmaK, + " and TileK to ", + newTileK); if (updateOptions) { options.mMmaK = mmaK; options.mTileK = newTileK; @@ -712,8 +819,12 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in } // The MMA N may only be smaller than 64 if it is equal to the tile N. - TLLM_CHECK_ERROR(options.mMmaN >= 64 || options.mMmaN == options.mTileN, "MmaN (", - options.mMmaN, ") must be >= 64 or equal to TileN (", options.mTileN, ")"); + TLLM_CHECK_ERROR(options.mMmaN >= 64 || options.mMmaN == options.mTileN, + "MmaN (", + options.mMmaN, + ") must be >= 64 or equal to TileN (", + options.mTileN, + ")"); } if (options.mSfBlockSizeA.has_value()) { @@ -721,7 +832,8 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in TLLM_CHECK_ERROR(options.mDtypeA == tg::Dtype::E2m1 && options.mDtypeB == tg::Dtype::E4m3, "sfBlockSizeA is only supported for E2m1 and E4m3 types. Found dtypeA=", tg::dtypeToString(options.mDtypeA), - " dtypeB=", tg::dtypeToString(options.mDtypeB)); + " dtypeB=", + tg::dtypeToString(options.mDtypeB)); // sfBlockSizeA must be 16 or 32. // SfBlockSizeA can also support 64 and 128, although they are not officially supported Nvida @@ -730,38 +842,56 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in // If we want to support sfBlockSizeA=8, we can write another version of convertE2m1ToSfE4m3, // which only packs 8 e2m1 elements. 
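    // Worked example for the scale-factor granularity checks below: for a block-format A with
    // sfBlockSizeA = 16, numEltsPerSfA is 16, so TileK must be a multiple of 4 * 16 = 64 and
    // K / 16 must be a multiple of 4 (i.e. K a multiple of 64); with sfBlockSizeA = 32 the
    // corresponding granularity is 128.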
TLLM_CHECK_ERROR(options.mSfBlockSizeA.value() == 16 || options.mSfBlockSizeA.value() == 32, - "SfBlockSizeA (", options.mSfBlockSizeA.value(), ") must be 16 or 32."); + "SfBlockSizeA (", + options.mSfBlockSizeA.value(), + ") must be 16 or 32."); } if (tg::dtypeIsBlockFmt(options.mDtypeA)) { int numEltsPerSfA = options.mSfBlockSizeA.value_or(tg::dtypeNumEltsPerSf(options.mDtypeA)); - TLLM_CHECK_ERROR(options.mTileK % (4 * numEltsPerSfA) == 0, "TileK (", options.mTileK, - ") must be a multiple of ", (4 * numEltsPerSfA), " for typeA ", + TLLM_CHECK_ERROR(options.mTileK % (4 * numEltsPerSfA) == 0, + "TileK (", + options.mTileK, + ") must be a multiple of ", + (4 * numEltsPerSfA), + " for typeA ", gemm::toString(options.mDtypeA)); auto const numEltsPerSfAInK = options.mK / numEltsPerSfA; - TLLM_CHECK_ERROR(numEltsPerSfAInK % 4 == 0, "K dimension of scaling factors for A (", - numEltsPerSfAInK, ") must be a multiple of 4"); + TLLM_CHECK_ERROR(numEltsPerSfAInK % 4 == 0, + "K dimension of scaling factors for A (", + numEltsPerSfAInK, + ") must be a multiple of 4"); } if (tg::dtypeIsBlockFmt(options.mDtypeB)) { TLLM_CHECK_ERROR(options.mSfLayoutB == tg::SfLayout::R128c4 || - options.mSfLayoutB == tg::SfLayout::R8c4 || - options.mSfLayoutB == tg::SfLayout::Linear, + options.mSfLayoutB == tg::SfLayout::R8c4 || + options.mSfLayoutB == tg::SfLayout::Linear, "Only the 128x4 and 8x4 SF layouts are supported for B, got ", tg::sfLayoutToString(options.mSfLayoutB)); // TileN must be a multiple of the number of rows per SF tile. int const numSfTileRowsB = options.mSfLayoutB == tg::SfLayout::R128c4 ? 128 : 8; - TLLM_CHECK_ERROR(options.mTileN % numSfTileRowsB == 0, "TileN (", options.mTileN, - ") must be a multiple of ", numSfTileRowsB, " for B SF layout ", + TLLM_CHECK_ERROR(options.mTileN % numSfTileRowsB == 0, + "TileN (", + options.mTileN, + ") must be a multiple of ", + numSfTileRowsB, + " for B SF layout ", tg::sfLayoutToString(options.mSfLayoutB)); int numEltsPerSfB = tg::dtypeNumEltsPerSf(options.mDtypeB); - TLLM_CHECK_ERROR(options.mTileK % (4 * numEltsPerSfB) == 0, "TileK (", options.mTileK, - ") must be a multiple of ", (4 * numEltsPerSfB), " for typeB ", + TLLM_CHECK_ERROR(options.mTileK % (4 * numEltsPerSfB) == 0, + "TileK (", + options.mTileK, + ") must be a multiple of ", + (4 * numEltsPerSfB), + " for typeB ", gemm::toString(options.mDtypeB)); auto const numEltsPerSfBInK = options.mK / numEltsPerSfB; - TLLM_CHECK_ERROR(numEltsPerSfBInK % 4 == 0, "K dimension of scaling factors for B (", - numEltsPerSfBInK, ") must be a multiple of 4"); + TLLM_CHECK_ERROR(numEltsPerSfBInK % 4 == 0, + "K dimension of scaling factors for B (", + numEltsPerSfBInK, + ") must be a multiple of 4"); } int32_t padMultiplierA = 1; @@ -774,30 +904,36 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in padMultiplierB = 2; } } - TLLM_CHECK_ERROR( - (padMultiplierA * tg::dtypeGetNumBits(options.mDtypeA) * options.mK / 8) % 16 == 0, - "K dimension of A must be aligned to 16 bytes."); - TLLM_CHECK_ERROR( - (padMultiplierB * tg::dtypeGetNumBits(options.mDtypeB) * options.mK / 8) % 16 == 0, - "K dimension of B must be aligned to 16 bytes."); + TLLM_CHECK_ERROR((padMultiplierA * tg::dtypeGetNumBits(options.mDtypeA) * options.mK / 8) % 16 == + 0, + "K dimension of A must be aligned to 16 bytes."); + TLLM_CHECK_ERROR((padMultiplierB * tg::dtypeGetNumBits(options.mDtypeB) * options.mK / 8) % 16 == + 0, + "K dimension of B must be aligned to 16 bytes."); if (options.mDtypeC == tg::Dtype::E2m1 || 
options.mDtypeC == tg::Dtype::MxE4m3) { TLLM_CHECK_ERROR(isBlackwell, "Block scaling is only supported on Blackwell"); - TLLM_CHECK_ERROR( - options.mSfLayoutC == tg::SfLayout::R128c4 || options.mSfLayoutC == tg::SfLayout::R8c4, - "Only the 128x4 and 8x4 SF layouts are supported for C."); + TLLM_CHECK_ERROR(options.mSfLayoutC == tg::SfLayout::R128c4 || + options.mSfLayoutC == tg::SfLayout::R8c4, + "Only the 128x4 and 8x4 SF layouts are supported for C."); int const numSfTileRowsC = options.mSfLayoutC == tg::SfLayout::R128c4 ? 128 : 8; int const tileTokenDim = options.mTransposeMmaOutput ? options.mTileN : options.mTileM; TLLM_CHECK_ERROR_FMT(tileTokenDim % numSfTileRowsC == 0, "Tile%s (%d) must be a multiple of %d for C SF layout %s", - options.mTransposeMmaOutput ? "N" : "M", tileTokenDim, numSfTileRowsC, + options.mTransposeMmaOutput ? "N" : "M", + tileTokenDim, + numSfTileRowsC, tg::sfLayoutToString(options.mSfLayoutC).c_str()); int const hiddenDim = options.mTransposeMmaOutput ? options.mM : options.mN; int const hiddenGranularity = 4 * tg::dtypeNumEltsPerSf(options.mDtypeC); - TLLM_CHECK_ERROR(hiddenDim % hiddenGranularity == 0, "Hidden dim (", hiddenDim, - ") must be a multiple of ", hiddenGranularity, " for block-scaled outputs."); + TLLM_CHECK_ERROR(hiddenDim % hiddenGranularity == 0, + "Hidden dim (", + hiddenDim, + ") must be a multiple of ", + hiddenGranularity, + " for block-scaled outputs."); TLLM_CHECK_ERROR(!options.mTransposeMmaOutput || options.mUseShuffledMatrixA, "Transposing block-scaled outputs requires shuffled A."); } @@ -814,8 +950,11 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in // Set epilogue tile sizes to the output tile sizes, when epilogue tile sizes are incorrect. if (options.mTileM % options.mEpilogueTileM != 0) { - TLLM_LOG_WARNING("TileM (", options.mTileM, ") must be divisible by EpilogueTileM (", - options.mEpilogueTileM, "). Setting EpilogueTileM to TileM"); + TLLM_LOG_WARNING("TileM (", + options.mTileM, + ") must be divisible by EpilogueTileM (", + options.mEpilogueTileM, + "). Setting EpilogueTileM to TileM"); if (updateOptions) { options.mEpilogueTileM = options.mTileM; } else { @@ -824,8 +963,11 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in } if (options.mTileN % options.mEpilogueTileN != 0) { - TLLM_LOG_WARNING("TileN (", options.mTileN, ") must be divisible by EpilogueTileN (", - options.mEpilogueTileN, "). Setting EpilogueTileN to TileN"); + TLLM_LOG_WARNING("TileN (", + options.mTileN, + ") must be divisible by EpilogueTileN (", + options.mEpilogueTileN, + "). Setting EpilogueTileN to TileN"); if (updateOptions) { options.mEpilogueTileN = options.mTileN; } else { @@ -837,7 +979,7 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in if (!isBlackwell && (options.mEpilogueTileM != options.mTileM || options.mEpilogueTileN != options.mTileN)) { TLLM_LOG_WARNING( - "Overwriting epilogueTileM and epilogueTileN to match tileM and tileN respectively"); + "Overwriting epilogueTileM and epilogueTileN to match tileM and tileN respectively"); if (updateOptions) { options.mEpilogueTileM = options.mTileM; options.mEpilogueTileN = options.mTileN; @@ -849,7 +991,7 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in // Unsupported epilogue tile size. if (options.mMmaM == 128 && options.mEpilogueTileM != options.mTileM) { TLLM_LOG_WARNING( - "When MmaM = 128, EpilogueTileM must be equal to TileM. 
Setting EpilogueTileM to TileM"); + "When MmaM = 128, EpilogueTileM must be equal to TileM. Setting EpilogueTileM to TileM"); if (updateOptions) { options.mEpilogueTileM = options.mTileM; } else { @@ -864,18 +1006,31 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in if (options.mUseShuffledMatrixA) { auto const shuffleBlockSize = getShuffleBlockSize(options.mEpilogueTileM); TLLM_CHECK_ERROR(options.mM % shuffleBlockSize == 0, - "M must be a multiple of shuffle block size (", shuffleBlockSize, + "M must be a multiple of shuffle block size (", + shuffleBlockSize, ") when useShuffledMatrixA"); } if (!options.mSliceK) { - TLLM_CHECK_ERROR(options.mMmaM <= options.mEpilogueTileM, + TLLM_CHECK_ERROR(options.mMmaM / options.mClusterDimX <= options.mEpilogueTileM, "EpilogueTileM must be larger or equal than mmaM."); + } else { + // FIXME: this is not necessary limitation. Simply fixing num repeats in TmemSliceKA should be + // enough. + TLLM_CHECK_ERROR((options.mTileN & (options.mTileN - 1)) == 0, + "For Slice-K TileN is required to be a power of 2"); + } + + if (options.mClusterDimX == 2) { + TLLM_CHECK_ERROR(options.mMmaM == 256, "Only mmaM = 256 is supported for 2CTA UTCMMA."); + TLLM_CHECK_ERROR(options.mMmaN % 16 == 0, "mmaN needs to be multiple of 16 for 2CTA UTCMMA."); } + TLLM_CHECK_ERROR( - options.mTileM % options.mEpilogueTileM == 0 && options.mTileN % options.mEpilogueTileN == 0, - "TileM and TileN must be divisible by EpilogueTileM and EpilogueTileN respectively."); - TLLM_CHECK_ERROR(options.mClusterDimX == 1 && options.mClusterDimY == 1, + options.mTileM % options.mEpilogueTileM == 0 && options.mTileN % options.mEpilogueTileN == 0, + "TileM and TileN must be divisible by EpilogueTileM and EpilogueTileN respectively."); + TLLM_CHECK_ERROR((options.mClusterDimX == 1 || options.mClusterDimX == 2) && + options.mClusterDimY == 1, "GEMM does not support cluster in X and Y dimensions."); TLLM_CHECK_ERROR(options.mClusterDimZ == 1 || options.mNumSlicesForSplitK > 1, "Cluster DimZ is only allowed for split-k."); @@ -891,8 +1046,8 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in if (options.mUseShuffledMatrixA) { // TODO add matrix shuffle for N-major epilogue. TLLM_CHECK_ERROR( - options.mTransposeMmaOutput, - "Shuffled matrix A is only supported with M-major epilogue. Set -transposeMmaOutput"); + options.mTransposeMmaOutput, + "Shuffled matrix A is only supported with M-major epilogue. Set -transposeMmaOutput"); } // Check all-reduce options. @@ -902,18 +1057,24 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in // // See: https://docs.nvidia.com/cuda/parallel-thread-execution/ // #data-movement-and-conversion-instructions-cp-reduce-async-bulk-tensor - std::set dtypeSupported{tg::Dtype::UInt32, tg::Dtype::Int32, tg::Dtype::UInt64, - tg::Dtype::Fp32, tg::Dtype::Fp16, tg::Dtype::Bfloat16}; + std::set dtypeSupported{tg::Dtype::UInt32, + tg::Dtype::Int32, + tg::Dtype::UInt64, + tg::Dtype::Fp32, + tg::Dtype::Fp16, + tg::Dtype::Bfloat16}; TLLM_CHECK_ERROR(dtypeSupported.find(options.mDtypeC) != dtypeSupported.end(), - "Unsupported output dtype ", tg::dtypeToString(options.mDtypeC)); + "Unsupported output dtype ", + tg::dtypeToString(options.mDtypeC)); } else if (options.mAllReduceAlgo == AllReduceAlgo::TwoShot) { // TODO(anchengc): // Input dtype == output dtype -> can perform all-reduce in-place. // Input dtype != output dtype -> must perform all-reduce out of place. 
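    // For example, the default dtypeAcc = Fp32 combined with dtypeC = Bfloat16 would be the
    // mixed-dtype case and would need an out-of-place update, which is why the check below
    // currently requires dtypeC == dtypeAcc for the two-shot all-reduce.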
TLLM_CHECK_ERROR_FMT( - options.mDtypeC == options.mDtypeAcc, - "Not implemented - mixed dtype (dtypeC (%s) != dtypeAcc (%s)) requires out of place update", - tg::dtypeToString(options.mDtypeC).c_str(), tg::dtypeToString(options.mDtypeAcc).c_str()); + options.mDtypeC == options.mDtypeAcc, + "Not implemented - mixed dtype (dtypeC (%s) != dtypeAcc (%s)) requires out of place update", + tg::dtypeToString(options.mDtypeC).c_str(), + tg::dtypeToString(options.mDtypeAcc).c_str()); } if (options.mAllReduceAlgo != AllReduceAlgo::None) { TLLM_CHECK_ERROR(options.mUseTmaStore, "Non-TMA store with all-reduce is not implemented"); @@ -941,7 +1102,7 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in if ((options.mEpilogueTileM != options.mTileM || options.mEpilogueTileN != options.mTileN) && !options.mUseDeepSeekFp8) { TLLM_LOG_WARNING( - "Overwriting epilogueTileM and epilogueTileN to match tileM and tileN respectively"); + "Overwriting epilogueTileM and epilogueTileN to match tileM and tileN respectively"); if (updateOptions) { options.mEpilogueTileM = options.mTileM; options.mEpilogueTileN = options.mTileN; @@ -963,37 +1124,40 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in if (updateOptions) { options.mNumStagesMmaAcrossWorkTile = std::min(2, options.mNumStagesMma); options.mNumStagesMmaWithinWorkTile = - options.mNumStagesMma / options.mNumStagesMmaAcrossWorkTile; + options.mNumStagesMma / options.mNumStagesMmaAcrossWorkTile; } else { return false; } } else if (options.mNumStagesMmaWithinWorkTile == -1) { if (updateOptions) { options.mNumStagesMmaWithinWorkTile = - options.mNumStagesMma / options.mNumStagesMmaAcrossWorkTile; + options.mNumStagesMma / options.mNumStagesMmaAcrossWorkTile; } else { return false; } } else if (options.mNumStagesMmaAcrossWorkTile == -1) { if (updateOptions) { options.mNumStagesMmaAcrossWorkTile = - options.mNumStagesMma / options.mNumStagesMmaWithinWorkTile; + options.mNumStagesMma / options.mNumStagesMmaWithinWorkTile; } else { return false; } } // Check mma stages. TLLM_CHECK_ERROR_FMT(options.mNumStagesMmaWithinWorkTile * options.mNumStagesMmaAcrossWorkTile == - options.mNumStagesMma && - options.mNumStagesMmaAcrossWorkTile <= 2, + options.mNumStagesMma && + options.mNumStagesMmaAcrossWorkTile <= 2, "Condition numStagesMmaWithinWorkTile (%d) * numStagesMmaAcrossWorkTile " "(%d) == numStagesMma (%d) && numStagesMmaAcrossWorkTile (%d) <= 2 must be " "satisfied. Check arguments.", - options.mNumStagesMmaWithinWorkTile, options.mNumStagesMmaAcrossWorkTile, - options.mNumStagesMma, options.mNumStagesMmaAcrossWorkTile); + options.mNumStagesMmaWithinWorkTile, + options.mNumStagesMmaAcrossWorkTile, + options.mNumStagesMma, + options.mNumStagesMmaAcrossWorkTile); // Mma stage must be 1 for pre-Hopper. TLLM_CHECK_ERROR(isBlackwell || options.mNumStagesMma == 1, - "Mma stage must be 1 for pre-Hopper. Found ", options.mNumStagesMma); + "Mma stage must be 1 for pre-Hopper. 
Found ", + options.mNumStagesMma); // DeepSeek Fp8 if (!options.mUseDeepSeekFp8) { TLLM_CHECK_ERROR(options.mNumStagesMmaWithinWorkTile == 1, @@ -1003,11 +1167,15 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in "Non-DeepSeekFp8 requires persistent scheduler when using numStagesMma >1"); } } + if (options.mUseDeepSeekFp8) { + TLLM_CHECK_ERROR(options.mClusterDimX == 1, "2CTA Gemm is not supported for DeepSeekFp8"); + } if (options.mUseDeepSeekFp8) { TLLM_CHECK_ERROR(options.mDtypeA == tg::Dtype::E4m3 && options.mDtypeB == tg::Dtype::E4m3, "A and B dtype must be E4m3 for DeepSeek Fp8. Found dtypeA=", tg::dtypeToString(options.mDtypeA), - " dtypeB=", tg::dtypeToString(options.mDtypeB)); + " dtypeB=", + tg::dtypeToString(options.mDtypeB)); TLLM_CHECK_ERROR(isBlackwell, "DeepSeek Fp8 is not supported for Hopper"); TLLM_CHECK_ERROR(options.mAllReduceAlgo == AllReduceAlgo::None, @@ -1019,7 +1187,7 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in // Tile sizes of the output hidden dimension. auto hiddenDimPerOutputTile = options.mTransposeMmaOutput ? options.mTileM : options.mTileN; auto hiddenDimPerEpilogueTile = - options.mTransposeMmaOutput ? options.mEpilogueTileM : options.mEpilogueTileN; + options.mTransposeMmaOutput ? options.mEpilogueTileM : options.mEpilogueTileN; auto hiddenDimPerMma = options.mTransposeMmaOutput ? options.mMmaM : options.mMmaN; auto hiddenDimName = options.mTransposeMmaOutput ? "M" : "N"; TLLM_CHECK_WARNING(options.mNumStagesMmaWithinWorkTile > 1, @@ -1037,14 +1205,26 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in // Check that the output tile N can be processed with the epilogue tile granularity. TLLM_CHECK_ERROR((hiddenDimPerOutputTile / 2) % hiddenDimPerEpilogueTile == 0, - "DeepSeek Fp8 requires Tile", hiddenDimName, " / 2 (", - hiddenDimPerOutputTile / 2, ") being a multiple of EpilogueTile", - hiddenDimName, " (", hiddenDimPerEpilogueTile, ")"); + "DeepSeek Fp8 requires Tile", + hiddenDimName, + " / 2 (", + hiddenDimPerOutputTile / 2, + ") being a multiple of EpilogueTile", + hiddenDimName, + " (", + hiddenDimPerEpilogueTile, + ")"); // Check that the output tile N can be processed with the epilogue tile granularity. 
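    // The check below applies the same half-tile requirement at MMA granularity
    // (hiddenDimPerMma). For instance, with a non-transposed output and TileN = 128,
    // EpilogueTileN = 64, MmaN = 64, the DeepSeek Fp8 half-tile of 128 / 2 = 64 satisfies
    // both granularity checks.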
TLLM_CHECK_ERROR((hiddenDimPerOutputTile / 2) % hiddenDimPerMma == 0, - "DeepSeek Fp8 requires Tile", hiddenDimName, " / 2 (", - hiddenDimPerOutputTile / 2, ") being a multiple of mma", hiddenDimName, " (", - hiddenDimPerMma, ")"); + "DeepSeek Fp8 requires Tile", + hiddenDimName, + " / 2 (", + hiddenDimPerOutputTile / 2, + ") being a multiple of mma", + hiddenDimName, + " (", + hiddenDimPerMma, + ")"); } if (options.mSliceK) { @@ -1079,9 +1259,13 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in return false; } } - TLLM_CHECK_ERROR((options.mTileK / options.mMmaK) % options.mNumSlicesForSliceK == 0, "TileK (", - options.mTileK, ") / MmaK (", options.mMmaK, - ") must be a multiple of mNumSlicesForSliceK (", options.mNumSlicesForSliceK, + TLLM_CHECK_ERROR((options.mTileK / options.mMmaK) % options.mNumSlicesForSliceK == 0, + "TileK (", + options.mTileK, + ") / MmaK (", + options.mMmaK, + ") must be a multiple of mNumSlicesForSliceK (", + options.mNumSlicesForSliceK, ")"); } @@ -1111,9 +1295,15 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in (clampedAndPaddedPerCtaK % (options.mTileK * 2) != 0); if (notSupported) { TLLM_LOG_WARNING("Size K / splitK must be a multiple of TileK * 2. Found TileK=", - options.mTileK, " and K=", options.mK, " (paddedK=", paddedK, - " clampedAndPaddedPerCtaK=", clampedAndPaddedPerCtaK, - ") and numSlicesForSplitK=", options.mNumSlicesForSplitK, + options.mTileK, + " and K=", + options.mK, + " (paddedK=", + paddedK, + " clampedAndPaddedPerCtaK=", + clampedAndPaddedPerCtaK, + ") and numSlicesForSplitK=", + options.mNumSlicesForSplitK, ". Disabling unrollLoop2xForMma."); if (updateOptions) { options.mUseUnrollLoop2xForMma = false; @@ -1124,8 +1314,8 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in } if (options.mNumSlicesForSplitK > 1) { TLLM_CHECK_ERROR( - perCtaK * (options.mNumSlicesForSplitK - 1) < options.mK, - "K must be greater than perCtaK * (numSlicesForSplitK - 1) to ensure each CTA has work"); + perCtaK * (options.mNumSlicesForSplitK - 1) < options.mK, + "K must be greater than perCtaK * (numSlicesForSplitK - 1) to ensure each CTA has work"); } if (!isBlackwell && options.mTileScheduler == TileScheduler::Persistent) { @@ -1139,9 +1329,8 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in } if (options.mEnablesDelayedEarlyExit && options.mEnablesEarlyExit) { - TLLM_LOG_WARNING( - "Only one of early exit and delayed early exit should be enabled. Disabling " - "delayed early exit"); + TLLM_LOG_WARNING("Only one of early exit and delayed early exit should be enabled. 
Disabling " + "delayed early exit"); if (updateOptions) { options.mEnablesDelayedEarlyExit = false; } else { @@ -1168,11 +1357,11 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in // visible- Warp 2: -------------------ACQBULK-- Kernel 1,2 output visible // ---------- TLLM_CHECK_ERROR( - (options.mGridWaitForPrimaryA || !options.mGridTriggerSecondaryA), - "A: If a task triggers a secondary kernel, it must also wait for primary kernel."); + (options.mGridWaitForPrimaryA || !options.mGridTriggerSecondaryA), + "A: If a task triggers a secondary kernel, it must also wait for primary kernel."); TLLM_CHECK_ERROR( - (options.mGridWaitForPrimaryB || !options.mGridTriggerSecondaryB), - "B: If a task triggers a secondary kernel, it must also wait for primary kernel."); + (options.mGridWaitForPrimaryB || !options.mGridTriggerSecondaryB), + "B: If a task triggers a secondary kernel, it must also wait for primary kernel."); if (options.mUsePerTokenSfA || options.mUsePerTokenSfB) { // Checks applicable to both MetaFP8 and RoutingScalesOnInput @@ -1184,20 +1373,21 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in TLLM_CHECK_ERROR(options.mDtypeA == tg::Dtype::E4m3 && options.mDtypeB == tg::Dtype::E4m3, "A and B dtype must be E4m3 for Meta Fp8. Found dtypeA=", tg::dtypeToString(options.mDtypeA), - " dtypeB=", tg::dtypeToString(options.mDtypeB)); + " dtypeB=", + tg::dtypeToString(options.mDtypeB)); } else { // RoutingScalesOnInput case TLLM_CHECK_ERROR((options.mUsePerTokenSfA && !options.mTransposeMmaOutput) || - (options.mUsePerTokenSfB && options.mTransposeMmaOutput), + (options.mUsePerTokenSfB && options.mTransposeMmaOutput), "In RoutingScalesOnInput mode, perToken scales must be used on activations"); } } // The generation should support non K-major layouts for both A and B; however, it is unclear if // there is a use-case - TLLM_CHECK_ERROR( - (options.mLayoutA == MatrixLayout::MajorK) || (options.mLayoutB == MatrixLayout::MajorK), - "At least one matrix must be in k-major layout"); + TLLM_CHECK_ERROR((options.mLayoutA == MatrixLayout::MajorK) || + (options.mLayoutB == MatrixLayout::MajorK), + "At least one matrix must be in k-major layout"); // Some features are currently only support when both matrices are in K-major format if (options.mLayoutB != MatrixLayout::MajorK || options.mLayoutB != MatrixLayout::MajorK) { @@ -1221,7 +1411,7 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in // TODO Leaving this as an option for now in case we want to expertiment with other block sizes // As the user is not expected to set this, do not fail if updateOptions is false int32_t const elemSizeInBits = - (isBlockA) ? tg::dtypeGetNumBits(options.mDtypeA) : tg::dtypeGetNumBits(options.mDtypeB); + (isBlockA) ? 
tg::dtypeGetNumBits(options.mDtypeA) : tg::dtypeGetNumBits(options.mDtypeB); int32_t const elemsIn128B = 128 * 8 /* Bits in byte */ / elemSizeInBits; if (options.mBlockK != elemsIn128B) { @@ -1234,12 +1424,12 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in if (options.mBlockK > options.mTileK) { TLLM_CHECK_ERROR( - options.mBlockK % options.mTileK == 0, - "If block size is greater than tile size, block size must be a multiple of tile size"); + options.mBlockK % options.mTileK == 0, + "If block size is greater than tile size, block size must be a multiple of tile size"); } else if (options.mBlockK < options.mTileK) { TLLM_CHECK_ERROR( - options.mTileK % options.mBlockK == 0, - "If tile size is greater than block size, tile size must be a multiple of block size"); + options.mTileK % options.mBlockK == 0, + "If tile size is greater than block size, tile size must be a multiple of block size"); } } @@ -1252,14 +1442,33 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in if (updateOptions) { // Init kernel traits. - options.mKernelTraits = KernelTraits( - options.mDtypeA, options.mDtypeB, options.mDtypeC, options.mDtypeAcc, options.mDtypeMmaA, - options.mDtypeMmaB, options.mMmaKind, options.mMmaK, options.mTileM, options.mTileN, - options.mTileK, options.mEpilogueTileM, options.mEpilogueTileN, options.mNumStages, - options.mNumStagesMma, options.mNumSlicesForSplitK, options.mNumSlicesForSliceK, - options.mSplitK, options.mUseTmaStore, options.mTransposeMmaOutput, options.mAllReduceAlgo, - options.mTileScheduler == TileScheduler::Persistent, options.mUseDeepSeekFp8, - options.mUsePerTokenSfA, options.mUsePerTokenSfB, options.mBiasType); + options.mKernelTraits = KernelTraits(options.mDtypeA, + options.mDtypeB, + options.mDtypeC, + options.mDtypeAcc, + options.mDtypeMmaA, + options.mDtypeMmaB, + options.mMmaKind, + options.mMmaK, + options.mTileM, + options.mTileN, + options.mTileK, + options.mEpilogueTileM, + options.mEpilogueTileN, + options.mNumStages, + options.mNumStagesMma, + options.mNumSlicesForSplitK, + options.mNumSlicesForSliceK, + options.mSplitK, + options.mUseTmaStore, + options.mTransposeMmaOutput, + options.mAllReduceAlgo, + options.mTileScheduler == TileScheduler::Persistent, + options.mUseDeepSeekFp8, + options.mUsePerTokenSfA, + options.mUsePerTokenSfB, + /* useTwoCtas*/ options.mClusterDimX == 2, + options.mBiasType); } return true; @@ -1267,7 +1476,7 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace gemm +} // namespace gemm #ifdef TLLM_GEN_EXPORT_INTERFACE @@ -1278,6 +1487,6 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in #undef TLLM_LOG_INFO #undef TLLM_LOG_ERROR -#endif // TLLM_GEN_EXPORT_INTERFACE +#endif // TLLM_GEN_EXPORT_INTERFACE -} // namespace batchedGemm +} // namespace batchedGemm diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h index eba3f54737..03edfb149e 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h @@ -1,28 +1,28 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & - * AFFILIATES. All rights reserved. 
SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +* SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & +* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ #pragma once #include +#include "trtllm/gen/SfLayoutDecl.h" +#include "trtllm/gen/CommonUtils.h" -#include "BatchedGemmEnums.h" -#include "Enums.h" #include "TmaDescriptor.h" -#include "trtllm/gen/CommonUtils.h" -#include "trtllm/gen/SfLayoutDecl.h" +#include "Enums.h" +#include "BatchedGemmEnums.h" // NOTE: keep this code dependency free. It has to be included by the device code and has to be // compilable with NVRTC. @@ -31,12 +31,10 @@ namespace batchedGemm { namespace batchedGemm { - //////////////////////////////////////////////////////////////////////////////////////////////////// // TODO: Find a better header to put this in, that we can include from here. -template -inline T ceilDiv(T m, T n) { +template inline T ceilDiv(T m, T n) { return (m + n - T(1)) / n; } @@ -57,24 +55,21 @@ enum class MatrixType { MatrixA = 0, MatrixB, MatrixC }; // ////////////////////////////////////////////////////////////////////////////////////////////////// -template -bool useTmaOobOptA(BatchedGemmOptions const& options) { +template bool useTmaOobOptA(BatchedGemmOptions const& options) { return options.mBatchMode == BatchedGemmOptions::BatchMode::BatchM && doesRouteImplUseNoRoute(options.mRouteImpl) && options.mUseTmaOobOpt; } ////////////////////////////////////////////////////////////////////////////////////////////////// -template -bool useTmaOobOptB(BatchedGemmOptions const& options) { +template bool useTmaOobOptB(BatchedGemmOptions const& options) { return options.mBatchMode == BatchedGemmOptions::BatchMode::BatchN && doesRouteImplUseNoRoute(options.mRouteImpl) && options.mUseTmaOobOpt; } ////////////////////////////////////////////////////////////////////////////////////////////////// -template -bool useTmaOobOptC(BatchedGemmOptions const& options) { +template bool useTmaOobOptC(BatchedGemmOptions const& options) { return options.mUseTmaStore && options.mUseTmaOobOpt; } @@ -82,8 +77,14 @@ bool useTmaOobOptC(BatchedGemmOptions const& options) { // Create the TMA shape/stride for A/B/C. 
template -static auto makeTmaShapeStrideAbc(GemmOptions const& options, int mM, int mN, int mK, int tileM, - int tileN, int tileK, MatrixType matrixType) { +static auto makeTmaShapeStrideAbc(GemmOptions const& options, + int mM, + int mN, + int mK, + int tileM, + int tileN, + int tileK, + MatrixType matrixType) { // Weights matrix is A if we transpose the output of MMA (to have it M-major). // Otherwise, it is B, when the output of MMA is K-major. bool const isWeights = (matrixType == MatrixType::MatrixA && options.mTransposeMmaOutput) || @@ -98,13 +99,13 @@ static auto makeTmaShapeStrideAbc(GemmOptions const& options, int mM, int mN, in // The outer dimension. auto numTokens = - (matrixType == MatrixType::MatrixA || matrixType == MatrixType::MatrixC) ? mM : mN; + (matrixType == MatrixType::MatrixA || matrixType == MatrixType::MatrixC) ? mM : mN; // The outer dimension tile size. auto ctaTileNumTokens = - (matrixType == MatrixType::MatrixA || matrixType == MatrixType::MatrixC) ? tileM : tileN; + (matrixType == MatrixType::MatrixA || matrixType == MatrixType::MatrixC) ? tileM : tileN; // The outer dimension of TMA box shape. auto tileNumTokens = - (matrixType == MatrixType::MatrixC) ? options.mEpilogueTileM : ctaTileNumTokens; + (matrixType == MatrixType::MatrixC) ? options.mEpilogueTileM : ctaTileNumTokens; // The inner dimension. auto hiddenSize = (matrixType == MatrixType::MatrixC) ? mN : mK; @@ -112,7 +113,7 @@ static auto makeTmaShapeStrideAbc(GemmOptions const& options, int mM, int mN, in auto ctaTileHiddenSize = (matrixType == MatrixType::MatrixC) ? tileN : tileK; // The inner dimension of TMA box shape. auto tileHiddenSize = - (matrixType == MatrixType::MatrixC) ? options.mEpilogueTileN : ctaTileHiddenSize; + (matrixType == MatrixType::MatrixC) ? options.mEpilogueTileN : ctaTileHiddenSize; // Swap matrix C sizes if output is transposed. if (matrixType == MatrixType::MatrixC && options.mTransposeMmaOutput) { @@ -137,15 +138,17 @@ static auto makeTmaShapeStrideAbc(GemmOptions const& options, int mM, int mN, in std::vector shape = {static_cast(hiddenSize), static_cast(numTokens)}; if (useTmaOobOpt /* also implies input/output activation */) { - // If TMA OOB optimization is used, we use 3D logical shape (M, tileM, K) or (N, tileN, K). - // The outer dimension is extended to make room for the possible counterbalance positive - // offset from the middle "bound" dimension. The counterbalance should be no more than - // ctaTileNumTokens. - shape = {static_cast(hiddenSize), static_cast(ctaTileNumTokens), - static_cast(numTokens + ctaTileNumTokens)}; + // If TMA OOB optimization is used: + // Shape [hidden, tokens] Stride [1, hidden] becomes + // Shape [hidden, tileN, TmaDimMax, TmaDimMax] Stride [1, hidden, XLargeN - hidden, hidden] + shape = {static_cast(hiddenSize), + static_cast(ctaTileNumTokens), + static_cast(tg::TmaDimMax), + static_cast(tg::TmaDimMax)}; } else if (isWeights) { // If the matrix is a weights matrix, we use 3D logical shape (B, M, K) or (B, N, K). - shape = {static_cast(hiddenSize), static_cast(numTokens), + shape = {static_cast(hiddenSize), + static_cast(numTokens), static_cast(options.mNumBatches)}; } @@ -153,9 +156,13 @@ static auto makeTmaShapeStrideAbc(GemmOptions const& options, int mM, int mN, in // Swap the first two dimension as mentioned before. 
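// Illustrative sketch (hypothetical standalone helpers, not from this file): the two
// simplest layouts produced by makeTmaShapeStrideAbc -- a K-major 2D activation tensor and
// a per-batch (per-expert) 3D weights tensor; the matching strides are set just below, and
// the TMA-OOB and BlockMajorK variants extend the same fastest-dimension-first pattern.
// The uint64_t element type and availability of <vector>/<cstdint> are assumptions here.
struct ShapeStrideSketch {
  std::vector<uint64_t> shape;
  std::vector<uint64_t> stride;
};

// Activations: elements contiguous along the hidden dimension, one row per token.
inline ShapeStrideSketch activationShapeStrideSketch(uint64_t hiddenSize, uint64_t numTokens) {
  return {{hiddenSize, numTokens}, {1, hiddenSize}};
}

// Weights: one [numTokens, hiddenSize] matrix per batch/expert, batches stacked contiguously.
inline ShapeStrideSketch weightsShapeStrideSketch(uint64_t hiddenSize, uint64_t numTokens,
                                                  uint64_t numBatches) {
  return {{hiddenSize, numTokens, numBatches}, {1, hiddenSize, hiddenSize * numTokens}};
}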
std::vector stride = {1, static_cast(hiddenSize)}; if (useTmaOobOpt) { - stride = {1, static_cast(hiddenSize), static_cast(hiddenSize)}; + stride = {1, + static_cast(hiddenSize), + static_cast(tg::XLargeN - hiddenSize), + static_cast(hiddenSize)}; } else if (isWeights) { - stride = {1, static_cast(hiddenSize), + stride = {1, + static_cast(hiddenSize), static_cast(hiddenSize) * static_cast(numTokens)}; } @@ -165,7 +172,7 @@ static auto makeTmaShapeStrideAbc(GemmOptions const& options, int mM, int mN, in // Alternate layouts (MajorMn and BlockMajorK) do not apply to matrixC if (matrixType != MatrixType::MatrixC) { gemm::MatrixLayout layout = - (matrixType == MatrixType::MatrixA) ? options.mLayoutA : options.mLayoutB; + (matrixType == MatrixType::MatrixA) ? options.mLayoutA : options.mLayoutB; // Note, only the weights support non MajorK layouts if (layout == gemm::MatrixLayout::MajorMn) { // Apply transpose if necessary @@ -174,10 +181,12 @@ static auto makeTmaShapeStrideAbc(GemmOptions const& options, int mM, int mN, in std::swap(tileShape[0], tileShape[1]); } else if (layout == gemm::MatrixLayout::BlockMajorK) { // Set shapes based on blocking layout - shape = {static_cast(options.mBlockK), static_cast(numTokens), + shape = {static_cast(options.mBlockK), + static_cast(numTokens), static_cast(mK / options.mBlockK), static_cast(options.mNumBatches)}; - stride = {1, static_cast(options.mBlockK), + stride = {1, + static_cast(options.mBlockK), static_cast(numTokens * options.mBlockK), static_cast(hiddenSize * numTokens)}; @@ -191,9 +200,17 @@ static auto makeTmaShapeStrideAbc(GemmOptions const& options, int mM, int mN, in } // Create the TMA shape/stride for A/B block scaling factors. -static auto makeTmaShapeStrideSfAb(int mM, int mN, int mK, MatrixType matrixType, int tileM, - int tileN, int tileK, tg::SfLayout layout, int sfReshapeFactor, +static auto makeTmaShapeStrideSfAb(int mM, + int mN, + int mK, + MatrixType matrixType, + int tileM, + int tileN, + int tileK, + tg::SfLayout layout, + int sfReshapeFactor, const int32_t numEltsPerSf) { + // The outer dimension. auto numTokens = matrixType == MatrixType::MatrixA ? mM : mN; // The inner dimension. @@ -204,102 +221,121 @@ static auto makeTmaShapeStrideSfAb(int mM, int mN, int mK, MatrixType matrixType auto hiddenSizePerTile = tileK; switch (layout) { - case tg::SfLayout::R128c4: { - // The scaling factor tensor packs 128x4 tiles into contiguous 512B blocks. - // The 512B block maps to a 32x16B (32x128b) block in TMEM. - // See https://nvbugspro.nvidia.com/bug/4165523 - // - // Additionally, we have to meet constraints of TMA that the box dimensions are less - // than 256 and boxDim[0] is a multiple of 16B. - // - // The "logical" tensor is: [outer, inner / numEltsPerSf] - // The aforementioned format is: [outer / 128, inner / numEltsPerSf / 4, 512] - // The shape we use for TMA is: [outer / 128, inner / numEltsPerSf / 4, 2, 256] - - auto shape = std::vector{ - 256, 2, static_cast(ceilDiv(hiddenSize, numEltsPerSf * 4)), - static_cast(ceilDiv(numTokens, 128))}; - - std::vector stride(shape.size()); - stride[0] = 1; - for (size_t i = 1; i < shape.size(); i++) { - stride[i] = shape[i - 1] * stride[i - 1]; - } - - auto tileShapes = std::vector{ - 256, 2, static_cast(ceilDiv(hiddenSizePerTile, numEltsPerSf * 4)), - static_cast(ceilDiv(numTokensPerTile, 128))}; - - return std::make_tuple(shape, stride, tileShapes); + case tg::SfLayout::R128c4: { + // The scaling factor tensor packs 128x4 tiles into contiguous 512B blocks. 
+ // The 512B block maps to a 32x16B (32x128b) block in TMEM. + // See https://nvbugspro.nvidia.com/bug/4165523 + // + // Additionally, we have to meet constraints of TMA that the box dimensions are less + // than 256 and boxDim[0] is a multiple of 16B. + // + // The "logical" tensor is: [outer, inner / numEltsPerSf] + // The aforementioned format is: [outer / 128, inner / numEltsPerSf / 4, 512] + // The shape we use for TMA is: [outer / 128, inner / numEltsPerSf / 4, 2, 256] + + auto shape = std::vector{256, + 2, + static_cast(ceilDiv(hiddenSize, numEltsPerSf * 4)), + static_cast(ceilDiv(numTokens, 128))}; + + std::vector stride(shape.size()); + stride[0] = 1; + for (size_t i = 1; i < shape.size(); i++) { + stride[i] = shape[i - 1] * stride[i - 1]; } - case tg::SfLayout::R8c4: { - // The scaling factor tensor packs 8x4 tiles into contiguous 32B blocks. - // - // As the inner dimension (k) is often a multiple of the tile size, we can reshape to use - // fewer read requests, if the tile dimensions allow. It does not reduce the number of - // instructions. - // - // I.e., let's define r = min(⌈hiddenSizePerTile / (numEltsPerSf * 4)⌉, 8) - // - // The "logical" tensor is: [outer, inner / numEltsPerSf] - // The 8x4 SF layout is: [⌈outer / 8⌉, inner / (4 * numEltsPerSf), 32] - // The TMA tensor shape is: [⌈outer / 8⌉, inner / (4 * numEltsPerSf * r), r * 32] - // - // The caveat of NumRepeats>1 is we must pad the hidden dimension of SF to multiples of - // NumRepeats * numEltsPerSf * 4. - - // Detect if the supplied factor is power of 2. E.g., 0b0100 and (0b0100 - 1) == 0b0000. - int const r = sfReshapeFactor; - if (r > 0 && (r & (r - 1)) != 0) { - throw std::runtime_error("mSfReshapeFactor must be positive and a power of 2. Found " + - std::to_string(r)); - } + auto tileShapes = + std::vector{256, + 2, + static_cast(ceilDiv(hiddenSizePerTile, numEltsPerSf * 4)), + static_cast(ceilDiv(numTokensPerTile, 128))}; - // Sanitize number of repeats so it doesn't exceed the dimension. - int const repeats = std::min(ceilDiv(hiddenSizePerTile, numEltsPerSf * 4), r); + return std::make_tuple(shape, stride, tileShapes); + } - // Detect if the input hidden size K is a multiple of the repeats. - if (ceilDiv(hiddenSize, numEltsPerSf * 4) % repeats != 0) { - throw std::runtime_error( - "SF hiddenSize K (" + std::to_string(ceilDiv(hiddenSize, numEltsPerSf * 4)) + - ") must be a multiple of repeats (" + std::to_string(repeats) + ")"); - } + case tg::SfLayout::R8c4: { + // The scaling factor tensor packs 8x4 tiles into contiguous 32B blocks. + // + // As the inner dimension (k) is often a multiple of the tile size, we can reshape to use + // fewer read requests, if the tile dimensions allow. It does not reduce the number of + // instructions. + // + // I.e., let's define r = min(⌈hiddenSizePerTile / (numEltsPerSf * 4)⌉, 8) + // + // The "logical" tensor is: [outer, inner / numEltsPerSf] + // The 8x4 SF layout is: [⌈outer / 8⌉, inner / (4 * numEltsPerSf), 32] + // The TMA tensor shape is: [⌈outer / 8⌉, inner / (4 * numEltsPerSf * r), r * 32] + // + // The caveat of NumRepeats>1 is we must pad the hidden dimension of SF to multiples of + // NumRepeats * numEltsPerSf * 4. + + // Detect if the supplied factor is power of 2. E.g., 0b0100 and (0b0100 - 1) == 0b0000. + int const r = sfReshapeFactor; + if (r > 0 && (r & (r - 1)) != 0) { + throw std::runtime_error("mSfReshapeFactor must be positive and a power of 2. 
Found " + + std::to_string(r)); + } - auto shape = std::vector{ - static_cast(repeats * 32), - static_cast(ceilDiv(hiddenSize, numEltsPerSf * 4 * repeats)), - static_cast(ceilDiv(numTokens, 8))}; + // Sanitize number of repeats so it doesn't exceed the dimension. + int const repeats = std::min(ceilDiv(hiddenSizePerTile, numEltsPerSf * 4), r); - std::vector stride(shape.size()); - stride[0] = 1; - for (size_t i = 1; i < shape.size(); i++) { - stride[i] = shape[i - 1] * stride[i - 1]; - } + // Detect if the input hidden size K is a multiple of the repeats. + if (ceilDiv(hiddenSize, numEltsPerSf * 4) % repeats != 0) { + throw std::runtime_error("SF hiddenSize K (" + + std::to_string(ceilDiv(hiddenSize, numEltsPerSf * 4)) + + ") must be a multiple of repeats (" + std::to_string(repeats) + ")"); + } - auto tileShapes = std::vector{ - static_cast(repeats * 32), - static_cast(ceilDiv(hiddenSizePerTile, numEltsPerSf * 4 * repeats)), - static_cast(ceilDiv(numTokensPerTile, 8))}; + auto shape = + std::vector{static_cast(repeats * 32), + static_cast(ceilDiv(hiddenSize, numEltsPerSf * 4 * repeats)), + static_cast(ceilDiv(numTokens, 8))}; - return std::make_tuple(shape, stride, tileShapes); + std::vector stride(shape.size()); + stride[0] = 1; + for (size_t i = 1; i < shape.size(); i++) { + stride[i] = shape[i - 1] * stride[i - 1]; } - default: - throw std::runtime_error("Unsupported SF layout"); + auto tileShapes = std::vector{ + static_cast(repeats * 32), + static_cast(ceilDiv(hiddenSizePerTile, numEltsPerSf * 4 * repeats)), + static_cast(ceilDiv(numTokensPerTile, 8))}; + + return std::make_tuple(shape, stride, tileShapes); + } + + default: + throw std::runtime_error("Unsupported SF layout"); } return std::make_tuple(std::vector{}, std::vector{}, std::vector{}); } template -static KernelParams setKernelParams( - GemmOptions_ const& options, bool const batchM, void const* ptrA, void const* ptrB, void* ptrC, - void const* dSfA, void const* dSfB, void const* ptrPerTokenSfA, void const* ptrPerTokenSfB, - void const* ptrBias, void* dSfC, float const* ptrScaleC, float const* ptrScaleGate, - float const* ptrClampLimit, float const* ptrGatedActAlpha, float const* ptrGatedActBeta, - int32_t const* routeMap, float* rowMax, uint32_t* rowMaxBars, - int32_t const* ptrNumNonExitingCtas = nullptr, int32_t const* ptrTotalNumPaddedTokens = nullptr, - int32_t const* ptrCtaIdxXyToBatchIdx = nullptr, int32_t const* ptrCtaIdxXyToMnLimit = nullptr, - int32_t const maxNumCtas = KernelParams::MaxNumCtas) { +static KernelParams setKernelParams(GemmOptions_ const& options, + bool const batchM, + void const* ptrA, + void const* ptrB, + void* ptrC, + void const* dSfA, + void const* dSfB, + void const* ptrPerTokenSfA, + void const* ptrPerTokenSfB, + void const* ptrBias, + void* dSfC, + float const* ptrScaleC, + float const* ptrScaleGate, + float const* ptrClampLimit, + float const* ptrGatedActAlpha, + float const* ptrGatedActBeta, + int32_t const* routeMap, + float* rowMax, + uint32_t* rowMaxBars, + int32_t const* ptrNumNonExitingCtas = nullptr, + int32_t const* ptrTotalNumPaddedTokens = nullptr, + int32_t const* ptrCtaIdxXyToBatchIdx = nullptr, + int32_t const* ptrCtaIdxXyToMnLimit = nullptr, + int32_t const maxNumCtas = KernelParams::MaxNumCtas) { + static_assert(sizeof(KernelParams) <= 32 * 1024, "sizeof(KernelParams) has to be less or equal than 32KB"); @@ -324,6 +360,7 @@ static KernelParams setKernelParams( if (options.mIsStaticBatch) { params.totalNumPaddedTokens = 0; for (int b = 0; b < options.mNumBatches; b++) { + int 
mM = batchM ? options.mBatchedM[b] : options.mM; int mN = batchM ? options.mN : options.mBatchedN[b]; @@ -351,7 +388,7 @@ static KernelParams setKernelParams( // This is now an identity map and it is no longer needed. // params.ctaIdxXyToTileIdxMn[ctaOffset + cta] = ctaOffset + cta; params.ctaIdxXyToMnLimit[ctaOffset + cta] = - std::min((ctaOffset + cta + 1) * tile, ctaOffset * tile + tokensPerTile); + std::min((ctaOffset + cta + 1) * tile, ctaOffset * tile + tokensPerTile); } ctaOffset += numCtas; @@ -385,12 +422,21 @@ static KernelParams setKernelParams( params.tileStridePerBatch = options.mM / options.mTileM; params.nm = options.mM; // Shape/stride for gmem tensor A. - auto [shapeA, strideA, tileShapeA] = - makeTmaShapeStrideAbc(options, options.mM, options.mN, options.mK, options.mTileM, - options.mTileN, options.mTileK, MatrixType::MatrixA); + auto [shapeA, strideA, tileShapeA] = makeTmaShapeStrideAbc(options, + options.mM, + options.mN, + options.mK, + options.mTileM, + options.mTileN, + options.mTileK, + MatrixType::MatrixA); // Build tma descriptor for A. - params.tmaA[0] = gemm::buildNdTmaDescriptor(options.mDtypeA, options.mMmaKind, shapeA, strideA, - tileShapeA, const_cast(ptrA)); + params.tmaA[0] = gemm::buildNdTmaDescriptor(options.mDtypeA, + options.mMmaKind, + shapeA, + strideA, + tileShapeA, + const_cast(ptrA)); // The input is padded: // [act0, padding, padding, ... TileN size .., act1, padding, padding, ...] @@ -400,35 +446,55 @@ static KernelParams setKernelParams( bool useRouteAct = batchedGemm::doesRouteImplUseTma(options.mRouteImpl); // B is the activation // Shape/stride for gmem tensor B. - auto [shapeB, strideB, tileShapeB] = makeTmaShapeStrideAbc( - options, options.mM, useRouteAct ? options.mNumTokens : inputNumTokens, options.mK, - options.mTileM, (useRouteAct ? 1 : options.mTileN), options.mTileK, MatrixType::MatrixB); + auto [shapeB, strideB, tileShapeB] = + makeTmaShapeStrideAbc(options, + options.mM, + useRouteAct ? options.mNumTokens : inputNumTokens, + options.mK, + options.mTileM, + (useRouteAct ? 1 : options.mTileN), + options.mTileK, + MatrixType::MatrixB); // Build tma descriptor for B. - params.tmaB[0] = gemm::buildNdTmaDescriptor(options.mDtypeB, options.mMmaKind, shapeB, - strideB, tileShapeB, const_cast(ptrB)); + params.tmaB[0] = gemm::buildNdTmaDescriptor(options.mDtypeB, + options.mMmaKind, + shapeB, + strideB, + tileShapeB, + const_cast(ptrB)); } if (options.mDtypeA == tg::Dtype::E2m1 || options.mDtypeA == tg::Dtype::MxE4m3 || options.mDtypeA == tg::Dtype::MxE2m1) { tg::Dtype const dTypeSf = - (options.mDtypeA == tg::Dtype::E2m1) ? tg::Dtype::E4m3 : tg::Dtype::UE8m0; + (options.mDtypeA == tg::Dtype::E2m1) ? tg::Dtype::E4m3 : tg::Dtype::UE8m0; // Build TMA descriptor for gmem A block scaling factors. 
auto [shapeSfA, strideSfA, tileShapesSfA] = makeTmaShapeStrideSfAb( - options.mM * options.mNumBatches, options.mN, options.mK, MatrixType::MatrixA, - options.mTileM, options.mTileN, options.mTileK, tg::SfLayout::R128c4, - options.mSfReshapeFactor, - options.mSfBlockSizeA.value_or(tg::dtypeNumEltsPerSf(options.mDtypeA))); - params.tmaSfA[0] = gemm::buildSfTmaDescriptor(dTypeSf, shapeSfA, strideSfA, tileShapesSfA, + options.mM * options.mNumBatches, + options.mN, + options.mK, + MatrixType::MatrixA, + options.mTileM, + options.mTileN, + options.mTileK, + tg::SfLayout::R128c4, + options.mSfReshapeFactor, + options.mSfBlockSizeA.value_or(tg::dtypeNumEltsPerSf(options.mDtypeA))); + params.tmaSfA[0] = gemm::buildSfTmaDescriptor(dTypeSf, + shapeSfA, + strideSfA, + tileShapesSfA, const_cast(dSfA)); } if (options.mDtypeB == tg::Dtype::E2m1 || options.mDtypeB == tg::Dtype::MxE4m3 || options.mDtypeB == tg::Dtype::MxE2m1) { tg::Dtype const dTypeSf = - (options.mDtypeB == tg::Dtype::E2m1) ? tg::Dtype::E4m3 : tg::Dtype::UE8m0; + (options.mDtypeB == tg::Dtype::E2m1) ? tg::Dtype::E4m3 : tg::Dtype::UE8m0; if (batchedGemm::doesRouteImplUseTma(options.mRouteImpl)) { + // The input is NOT padded: // [act0, act1, act2, ...] @@ -439,24 +505,45 @@ static KernelParams setKernelParams( auto numSfsInK = options.mK / numEltsPerSf; numSfsInK = ceilDiv(numSfsInK, 16) * 16; - auto [shapeSfB, strideSfB, tileShapesSfB] = makeTmaShapeStrideAbc( - options, options.mM, options.mNumTokens, numSfsInK, options.mTileM, 1 /* tileN */, - options.mTileK / numEltsPerSf, MatrixType::MatrixB); - params.tmaSfB[0] = gemm::buildNdTmaDescriptor( - dTypeSf, options.mMmaKind, shapeSfB, strideSfB, tileShapesSfB, const_cast(dSfB), - /*doSwizzle*/ true); + auto [shapeSfB, strideSfB, tileShapesSfB] = + makeTmaShapeStrideAbc(options, + options.mM, + options.mNumTokens, + numSfsInK, + options.mTileM, + 1 /* tileN */, + options.mTileK / numEltsPerSf, + MatrixType::MatrixB); + params.tmaSfB[0] = gemm::buildNdTmaDescriptor(dTypeSf, + options.mMmaKind, + shapeSfB, + strideSfB, + tileShapesSfB, + const_cast(dSfB), + /*doSwizzle*/ true); } else if (batchedGemm::doesRouteImplUseNoRoute(options.mRouteImpl)) { + // The input is padded: // [act0, padding, padding, ... TileN size .., act1, padding, padding, ...] auto const inputNumTokensSfB = ctaOffset * options.mTileN; // Build TMA descriptor for gmem B block scaling factors. - auto [shapeSfB, strideSfB, tileShapesSfB] = makeTmaShapeStrideSfAb( - options.mM, inputNumTokensSfB, options.mK, MatrixType::MatrixB, options.mTileM, - options.mTileN, options.mTileK, options.mSfLayoutB, options.mSfReshapeFactor, - tg::dtypeNumEltsPerSf(options.mDtypeB)); - params.tmaSfB[0] = gemm::buildSfTmaDescriptor(dTypeSf, shapeSfB, strideSfB, tileShapesSfB, + auto [shapeSfB, strideSfB, tileShapesSfB] = + makeTmaShapeStrideSfAb(options.mM, + inputNumTokensSfB, + options.mK, + MatrixType::MatrixB, + options.mTileM, + options.mTileN, + options.mTileK, + options.mSfLayoutB, + options.mSfReshapeFactor, + tg::dtypeNumEltsPerSf(options.mDtypeB)); + params.tmaSfB[0] = gemm::buildSfTmaDescriptor(dTypeSf, + shapeSfB, + strideSfB, + tileShapesSfB, const_cast(dSfB)); } } @@ -464,12 +551,21 @@ static KernelParams setKernelParams( // C is the output activation if (options.mUseTmaStore) { // Shape/stride for gmem tensor C. 
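// Quick illustration of the rounding applied above to the per-token scale-factor count
// along K: ceilDiv(m, n) * n rounds m up to the next multiple of n (here n = 16). This is
// only a restatement of the arithmetic in the hunk, not new behaviour.
inline int roundUpToMultipleOf16Sketch(int numSfsInK) {
  return ((numSfsInK + 15) / 16) * 16;  // e.g. 72 -> 80, 448 -> 448
}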
- auto [shapeC, strideC, tileShapeC] = makeTmaShapeStrideAbc( - options, options.mM, ctaOffset * options.mTileN, options.mK, options.mTileM, - options.mTileN, options.mTileK, MatrixType::MatrixC); + auto [shapeC, strideC, tileShapeC] = makeTmaShapeStrideAbc(options, + options.mM, + ctaOffset * options.mTileN, + options.mK, + options.mTileM, + options.mTileN, + options.mTileK, + MatrixType::MatrixC); // Build tma descriptor for C. - params.tmaC[0] = gemm::buildNdTmaDescriptor(options.mDtypeC, tg::MmaKind::Auto, shapeC, - strideC, tileShapeC, ptrC); + params.tmaC[0] = gemm::buildNdTmaDescriptor(options.mDtypeC, + tg::MmaKind::Auto, + shapeC, + strideC, + tileShapeC, + ptrC); } else { params.ptrC = ptrC; } @@ -482,12 +578,21 @@ static KernelParams setKernelParams( params.tileStridePerBatch = options.mN / options.mTileN; params.nm = options.mN; // Shape/stride for gmem tensor B. - auto [shapeB, strideB, tileShapeB] = - makeTmaShapeStrideAbc(options, options.mM, options.mN, options.mK, options.mTileM, - options.mTileN, options.mTileK, MatrixType::MatrixB); + auto [shapeB, strideB, tileShapeB] = makeTmaShapeStrideAbc(options, + options.mM, + options.mN, + options.mK, + options.mTileM, + options.mTileN, + options.mTileK, + MatrixType::MatrixB); // Build tma descriptor for B. - params.tmaB[0] = gemm::buildNdTmaDescriptor(options.mDtypeB, options.mMmaKind, shapeB, strideB, - tileShapeB, const_cast(ptrB)); + params.tmaB[0] = gemm::buildNdTmaDescriptor(options.mDtypeB, + options.mMmaKind, + shapeB, + strideB, + tileShapeB, + const_cast(ptrB)); if (options.mRouteImpl == batchedGemm::RouteImpl::NoRoute) { // A is the activation @@ -495,30 +600,50 @@ static KernelParams setKernelParams( // The input is padded: // [act0, padding, padding, ... tileM size .., act1, padding, padding, ...] auto const inputNumTokens = ctaOffset * options.mTileM; - auto [shapeA, strideA, tileShapeA] = - makeTmaShapeStrideAbc(options, inputNumTokens, options.mN, options.mK, options.mTileM, - options.mTileN, options.mTileK, MatrixType::MatrixA); + auto [shapeA, strideA, tileShapeA] = makeTmaShapeStrideAbc(options, + inputNumTokens, + options.mN, + options.mK, + options.mTileM, + options.mTileN, + options.mTileK, + MatrixType::MatrixA); // Build tma descriptor for A. - params.tmaA[0] = gemm::buildNdTmaDescriptor(options.mDtypeA, options.mMmaKind, shapeA, - strideA, tileShapeA, const_cast(ptrA)); + params.tmaA[0] = gemm::buildNdTmaDescriptor(options.mDtypeA, + options.mMmaKind, + shapeA, + strideA, + tileShapeA, + const_cast(ptrA)); } if (options.mDtypeA == tg::Dtype::E2m1 || options.mDtypeA == tg::Dtype::MxE4m3 || options.mDtypeA == tg::Dtype::MxE2m1) { tg::Dtype const dTypeSf = - (options.mDtypeA == tg::Dtype::E2m1) ? tg::Dtype::E4m3 : tg::Dtype::UE8m0; + (options.mDtypeA == tg::Dtype::E2m1) ? tg::Dtype::E4m3 : tg::Dtype::UE8m0; if (options.mRouteImpl == batchedGemm::RouteImpl::NoRoute) { + // The input is padded: // [act0, padding, padding, ... tileM size .., act1, padding, padding, ...] auto const inputNumTokensSfA = ctaOffset * options.mTileM; // Build TMA descriptor for gmem A block scaling factors. 
auto [shapeSfA, strideSfA, tileShapesSfA] = makeTmaShapeStrideSfAb( - inputNumTokensSfA, options.mN, options.mK, MatrixType::MatrixA, options.mTileM, - options.mTileN, options.mTileK, tg::SfLayout::R128c4, options.mSfReshapeFactor, - options.mSfBlockSizeA.value_or(tg::dtypeNumEltsPerSf(options.mDtypeA))); - params.tmaSfA[0] = gemm::buildSfTmaDescriptor(dTypeSf, shapeSfA, strideSfA, tileShapesSfA, + inputNumTokensSfA, + options.mN, + options.mK, + MatrixType::MatrixA, + options.mTileM, + options.mTileN, + options.mTileK, + tg::SfLayout::R128c4, + options.mSfReshapeFactor, + options.mSfBlockSizeA.value_or(tg::dtypeNumEltsPerSf(options.mDtypeA))); + params.tmaSfA[0] = gemm::buildSfTmaDescriptor(dTypeSf, + shapeSfA, + strideSfA, + tileShapesSfA, const_cast(dSfA)); } } @@ -526,26 +651,45 @@ static KernelParams setKernelParams( if (options.mDtypeB == tg::Dtype::E2m1 || options.mDtypeB == tg::Dtype::MxE4m3 || options.mDtypeB == tg::Dtype::MxE2m1) { tg::Dtype const dTypeSf = - (options.mDtypeB == tg::Dtype::E2m1) ? tg::Dtype::E4m3 : tg::Dtype::UE8m0; + (options.mDtypeB == tg::Dtype::E2m1) ? tg::Dtype::E4m3 : tg::Dtype::UE8m0; // Build TMA descriptor for gmem B block scaling factors. - auto [shapeSfB, strideSfB, tileShapesSfB] = makeTmaShapeStrideSfAb( - options.mM, options.mN * options.mNumBatches, options.mK, MatrixType::MatrixB, - options.mTileM, options.mTileN, options.mTileK, options.mSfLayoutB, - options.mSfReshapeFactor, tg::dtypeNumEltsPerSf(options.mDtypeB)); - params.tmaSfB[0] = gemm::buildSfTmaDescriptor(dTypeSf, shapeSfB, strideSfB, tileShapesSfB, + auto [shapeSfB, strideSfB, tileShapesSfB] = + makeTmaShapeStrideSfAb(options.mM, + options.mN * options.mNumBatches, + options.mK, + MatrixType::MatrixB, + options.mTileM, + options.mTileN, + options.mTileK, + options.mSfLayoutB, + options.mSfReshapeFactor, + tg::dtypeNumEltsPerSf(options.mDtypeB)); + params.tmaSfB[0] = gemm::buildSfTmaDescriptor(dTypeSf, + shapeSfB, + strideSfB, + tileShapesSfB, const_cast(dSfB)); } // C is the output activation if (options.mUseTmaStore) { // Shape/stride for gmem tensor C. - auto [shapeC, strideC, tileShapeC] = makeTmaShapeStrideAbc( - options, ctaOffset * options.mTileM, options.mN, options.mK, options.mTileM, - options.mTileN, options.mTileK, MatrixType::MatrixC); + auto [shapeC, strideC, tileShapeC] = makeTmaShapeStrideAbc(options, + ctaOffset * options.mTileM, + options.mN, + options.mK, + options.mTileM, + options.mTileN, + options.mTileK, + MatrixType::MatrixC); // Build tma descriptor for C. 
- params.tmaC[0] = gemm::buildNdTmaDescriptor(options.mDtypeC, tg::MmaKind::Auto, shapeC, - strideC, tileShapeC, ptrC); + params.tmaC[0] = gemm::buildNdTmaDescriptor(options.mDtypeC, + tg::MmaKind::Auto, + shapeC, + strideC, + tileShapeC, + ptrC); } else { params.ptrC = ptrC; } @@ -570,10 +714,10 @@ static KernelParams setKernelParams( return params; } #endif -}; // namespace KernelParamsSetup +}; // namespace KernelParamsSetup //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace batchedGemm +} // namespace batchedGemm -} // namespace batchedGemm +} // namespace batchedGemm diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h index 16b4af3149..a6056009de 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h @@ -1,22 +1,23 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & - * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +* SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & +* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ #pragma once namespace batchedGemm { + // This is device code struct KernelParams { @@ -29,54 +30,6 @@ struct KernelParams { // Maximum number of CTAs in the batch-token dimension. static constexpr int MaxNumCtas = 2048; - // NOTE: TMA out-of-bounds optimization for MoE padded tokens: - // - // Originally the padded tokens is a 2D tensor [hiddenDim, ctaGridDimY * tileN] with stride [1, - // hiddenDim] and box size [tileM, tileN] at pointer p. We waste bandwidth bytes since we only - // want to load [0, batchEnd) out of the [0, tileN) box size: batchEnd is a runtime variable while - // box size needs to be fixed at compile time. - // - // To deal with this, we reshape the tensor to 3D: [hiddenDim, tileN, ctaGridDimY * tileN] with - // stride [1, hiddenDim, hiddenDim] and box size [tileM, tileN, 1]. 
For the original 2D - // tensor, - // - // Offset Coords [ : , ctaIdxY * tileN ], - // Box Sizes [ : , tileN ], - // Coords Range [ : , ctaIdxY * tileN : ctaIdxY * tileN + tileN], - // - // while we only want load the range [ctaIdxY * tileN, ctaIdxY * tileN + batchEnd), 1 <= batchEnd - // <= tileN - // - // For the reshaped 3D tensor, - // - // Offset Coords [ : , tileN - batchEnd , - // ctaIdxY * tileN + batchEnd ], - // Box Sizes [ : , tileN , - // 1 ], - // Coords Range [ : , tileN - batchEnd : min(tileN, 2 * tileN - batchEnd), - // ctaIdxY * tileN + batchEnd : ctaIdx * tileN + batchEnd + 1], - // - // while min(tileN, 2 * tileN - batchEnd) always evaluates to tileN. The unwanted tokens are - // essentially filtered out by utilizing the OOB feature of TMA. Since the 2nd and 3rd dimension - // has the same stride, we end up loading the following (adding the left and right end of the 2nd - // and 3rd dimension ranges): - // - // Effective 2D Coords Range - // [ : , tileN + ctaIdxY * tileN : tileN + ctaIdxY * tileN + batchEnd], - // - // This is exactly the same as the original range except for the offset tileN, thus we also need - // to offset the pointer in the opposite direction: - // - // Ptr (p) -> Ptr (p - tileN * hiddenDim) - // - // Due to the restrictions of TMA unit, the above operations requires the TMA descriptor and the - // underlying buffer be constructed differently: - // - Requires valid buffer at (p - tileN * hidden) - needs prepending `tileN` tokens. - // - TMA outermost dimension must be extended by `tileN` or loads will OOB in the rightmost side. - // The latter is because when batchEnd == tileN, the offset coords in the 3rd dimension becomes - // ctaIdxY * tileN + tileN. When ctaIdxY = ctaGridDimY - 1, it becomes ((ctaGridDimY - 1) * tileN - // + tileN = ctaGridDimY * tileN which is equal to the 3rd dimension size and will be filtered - // out. That's why we need to extend the tensor size by tileN. // // TMA descriptor for A. // Must be setup using gemm::buildNdTmaDescriptor with shapes and strides from @@ -541,4 +494,4 @@ struct KernelParams { /////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace batchedGemm +} // namespace batchedGemm diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h index 640b3a69f0..6699530009 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h @@ -1,28 +1,27 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & - * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +* SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & +* AFFILIATES. All rights reserved. 
SPDX-License-Identifier: Apache-2.0 +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ #pragma once #include #include - -#include "Enums.h" -#include "trtllm/gen/CommonUtils.h" #include "trtllm/gen/DtypeDecl.h" +#include "trtllm/gen/CommonUtils.h" #include "trtllm/gen/MmaDecl.h" +#include "Enums.h" namespace batchedGemm { @@ -36,14 +35,17 @@ namespace tg = trtllm::gen; // Structure to manage memory allocation with configurable reuse class MemAllocatorHelper { - public: +public: // The default constructor. MemAllocatorHelper() {} // Constructor to initialize chunk sizes, alignments, and reuse flags MemAllocatorHelper(std::vector> const& sizes, - std::vector const& reuse, std::vector const& names) - : mNumBytesAndAlignmentPerSmemChunk(sizes), mFirstChunkReuse(reuse), mSmemChunkNames(names) {} + std::vector const& reuse, + std::vector const& names) + : mNumBytesAndAlignmentPerSmemChunk(sizes) + , mFirstChunkReuse(reuse) + , mSmemChunkNames(names) {} // Function to calculate the size of the array from 0 to jj chunks int32_t getOffsetBeforeChunk(int jj) const { @@ -95,14 +97,17 @@ class MemAllocatorHelper { // Print the contents of this object. void print() const { for (size_t ii = 0; ii < mNumBytesAndAlignmentPerSmemChunk.size(); ++ii) { - printf("Chunk %zd %s: %d bytes, %d alignment, reuse %s, offset %d\n", ii, - mSmemChunkNames[ii].c_str(), mNumBytesAndAlignmentPerSmemChunk[ii].first, - mNumBytesAndAlignmentPerSmemChunk[ii].second, mFirstChunkReuse[ii] ? "true" : "false", + printf("Chunk %zd %s: %d bytes, %d alignment, reuse %s, offset %d\n", + ii, + mSmemChunkNames[ii].c_str(), + mNumBytesAndAlignmentPerSmemChunk[ii].first, + mNumBytesAndAlignmentPerSmemChunk[ii].second, + mFirstChunkReuse[ii] ? "true" : "false", getChunkOffset(ii)); } } - private: +private: int32_t getChunkOffset(int32_t ii) const { if (mFirstChunkReuse[ii]) { // Reuse the offset of the 0th chunk. @@ -124,7 +129,7 @@ class MemAllocatorHelper { return (size + alignment - 1) & ~(alignment - 1); } - private: +private: // Sizes and alignment requirements of each chunk // NOTE: be careful and make sure that the memory dependency is clear and // chunks in the beginning of the SMEM can be overwritten. @@ -151,20 +156,39 @@ inline int getNumSmemBitsPerElt(tg::Dtype dtype, tg::MmaKind mmaKind) { //////////////////////////////////////////////////////////////////////////////////////////////////// class KernelTraits { - public: +public: // The default constructor. KernelTraits() {} // The constructor. 
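// Rough sketch of the offset arithmetic used by MemAllocatorHelper (the body of
// getOffsetBeforeChunk is not shown in this hunk, so this is a plausible reading rather
// than the actual implementation): chunks are laid out back to back with power-of-two
// alignment, and chunks flagged for reuse alias the offset of chunk 0.
inline int32_t alignUpSketch(int32_t offset, int32_t alignment) {
  // Same bit trick as MemAllocatorHelper::alignOffset; requires a power-of-two alignment.
  return (offset + alignment - 1) & ~(alignment - 1);
}
// E.g. a 32768-byte chunk with 1024-byte alignment starts at offset 0 and the next
// 1024-byte-aligned chunk starts at offset 32768, consistent with the example dump
// further below in this file.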
- KernelTraits(tg::Dtype dtypeA, tg::Dtype dtypeB, tg::Dtype dtypeC, tg::Dtype dtypeAcc, - tg::Dtype dtypeMmaA, tg::Dtype dtypeMmaB, tg::MmaKind mmaKind, int32_t mmaK, - int32_t tileM, int32_t tileN, int32_t tileK, int32_t epilogueTileM, - int32_t epilogueTileN, int32_t numStages, int32_t numStagesMma, - int32_t numSlicesForSplitK, int32_t numSlicesForSliceK, SplitK splitK, - bool useTmaStore, bool transposeMmaOutput, AllReduceAlgo allReduceAlgo, - bool usePersistentScheduler, bool useDeepSeekFp8, bool usePerTokenSfA, - bool usePerTokenSfB, BiasType biasType) - : mMmaKind{mmaKind} { + KernelTraits(tg::Dtype dtypeA, + tg::Dtype dtypeB, + tg::Dtype dtypeC, + tg::Dtype dtypeAcc, + tg::Dtype dtypeMmaA, + tg::Dtype dtypeMmaB, + tg::MmaKind mmaKind, + int32_t mmaK, + int32_t tileM, + int32_t tileN, + int32_t tileK, + int32_t epilogueTileM, + int32_t epilogueTileN, + int32_t numStages, + int32_t numStagesMma, + int32_t numSlicesForSplitK, + int32_t numSlicesForSliceK, + SplitK splitK, + bool useTmaStore, + bool transposeMmaOutput, + AllReduceAlgo allReduceAlgo, + bool usePersistentScheduler, + bool useDeepSeekFp8, + bool usePerTokenSfA, + bool usePerTokenSfB, + bool useTwoCtas, + BiasType biasType) + : mMmaKind{mmaKind} { // // SMEM // @@ -198,7 +222,7 @@ class KernelTraits { { // Number of bytes in load A shared memory. auto const numSmemBytesLoadA = - numStages * tileM * tileK * getNumSmemBitsPerElt(dtypeA, mMmaKind) / 8 /* bits */; + numStages * tileM * tileK * getNumSmemBitsPerElt(dtypeA, mMmaKind) / 8 /* bits */; // Number of bytes for load A alignment for TMA load. auto const numBytesAlignmentLoadA = 1024; // loadA is already at first chunk. No need to reuse it. @@ -206,15 +230,15 @@ class KernelTraits { // Add info. smemChunkNames.emplace_back("smemLoadA"); numBytesAndAlignmentPerSmemChunk.emplace_back( - std::make_pair(numSmemBytesLoadA, numBytesAlignmentLoadA)); + std::make_pair(numSmemBytesLoadA, numBytesAlignmentLoadA)); firstChunkReuseSmem.emplace_back(reuseChunksSmemLoadA); } // LoadB { // Number of bytes in load B shared memory. - auto const numSmemBytesLoadB = - numStages * tileN * tileK * getNumSmemBitsPerElt(dtypeB, mMmaKind) / 8 /* bits */; + auto const numSmemBytesLoadB = numStages * (useTwoCtas ? tileN / 2 : tileN) * tileK * + getNumSmemBitsPerElt(dtypeB, mMmaKind) / 8 /* bits */; // Number of bytes for load B alignment for TMA load. auto const numBytesAlignmentLoadB = 1024; // No need to reuse the first chunk. @@ -222,7 +246,7 @@ class KernelTraits { // Add info. smemChunkNames.emplace_back("smemLoadB"); numBytesAndAlignmentPerSmemChunk.emplace_back( - std::make_pair(numSmemBytesLoadB, numBytesAlignmentLoadB)); + std::make_pair(numSmemBytesLoadB, numBytesAlignmentLoadB)); firstChunkReuseSmem.emplace_back(reuseChunksSmemLoadB); } @@ -234,9 +258,9 @@ class KernelTraits { { // Number of bytes in save shuffled B in shared memory. auto const numSmemBytesLoadB = - numSlicesForSliceK > 1 - ? numStages * tileN * tileK * getNumSmemBitsPerElt(dtypeB, mMmaKind) / 8 /* bits */ - : 0; + numSlicesForSliceK > 1 + ? numStages * tileN * tileK * getNumSmemBitsPerElt(dtypeB, mMmaKind) / 8 /* bits */ + : 0; // Number of bytes for load B alignment for TMA load. auto const numBytesAlignmentLoadB = 1024; // No need to reuse the first chunk. @@ -245,7 +269,7 @@ class KernelTraits { // Add info. 
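// Worked example (hypothetical tile configuration, bf16 operands assumed at 16 bits/elt)
// for the SMEM budget formulas above, including the new 2-CTA case where each CTA only
// stages half of tileN for B:
//   loadA: numStages = 2, tileM = 128, tileK = 64 -> 2 * 128 * 64 * 16 / 8 = 32768 bytes
//   loadB: numStages = 2, tileN = 128, tileK = 64, useTwoCtas -> 2 * (128/2) * 64 * 16 / 8 = 16384 bytes
inline int64_t smemBytesForLoadBSketch(int numStages, int tileN, int tileK, int bitsPerElt,
                                       bool useTwoCtas) {
  return int64_t(numStages) * (useTwoCtas ? tileN / 2 : tileN) * tileK * bitsPerElt / 8;
}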
smemChunkNames.emplace_back("smemBShuffle"); numBytesAndAlignmentPerSmemChunk.emplace_back( - std::make_pair(numSmemBytesLoadB, numBytesAlignmentLoadB)); + std::make_pair(numSmemBytesLoadB, numBytesAlignmentLoadB)); firstChunkReuseSmem.emplace_back(reuseChunksSmemLoadB); } @@ -273,21 +297,21 @@ class KernelTraits { // Number of bytes to store the output in smem. auto const numBytesSmemStoreC = usesSmemForGmemC - ? extraGmemCMultiplier * epilogueTileM * epilogueTileN * - tg::dtypeGetNumBits(dtypeSmemC) / 8 /* bits */ - : 0; + ? extraGmemCMultiplier * epilogueTileM * epilogueTileN * + tg::dtypeGetNumBits(dtypeSmemC) / 8 /* bits */ + : 0; // Number of bytes for store C alignment for TMA store. auto const numBytesAlignmentStoreC = 1024; // gmemC reuses loadAb memory for split-K in DSMEM. // Epilogue1 does not reuse and continues after the memory allocated Epilogue0 // NOTE: we can always reuse loadAb SMEM as long as we don't have persistent scheduler. auto const reuseFirstChunksSmemStoreC = - doesSplitKUseDsmem(splitK) && resIdx == 0 && !usePersistentScheduler; + doesSplitKUseDsmem(splitK) && resIdx == 0 && !usePersistentScheduler; // Add info. smemChunkNames.emplace_back("smemGmemC" + std::to_string(resIdx)); numBytesAndAlignmentPerSmemChunk.emplace_back( - std::make_pair(numBytesSmemStoreC, numBytesAlignmentStoreC)); + std::make_pair(numBytesSmemStoreC, numBytesAlignmentStoreC)); firstChunkReuseSmem.emplace_back(reuseFirstChunksSmemStoreC); } @@ -304,7 +328,7 @@ class KernelTraits { // Add info. smemChunkNames.emplace_back("smemRowMax"); numBytesAndAlignmentPerSmemChunk.emplace_back( - std::make_pair(numBytesSmemRowMax, numBytesAlignmentRowMax)); + std::make_pair(numBytesSmemRowMax, numBytesAlignmentRowMax)); firstChunkReuseSmem.emplace_back(false); } @@ -312,7 +336,7 @@ class KernelTraits { { // Real tile size before slice-K reduction. auto const tileSize = - numSlicesForSliceK > 1 ? numSlicesForSliceK * tileM * numSlicesForSliceK * tileN : 0; + numSlicesForSliceK > 1 ? numSlicesForSliceK * tileM * numSlicesForSliceK * tileN : 0; // Number of bytes for tile in SMEM. auto const numBytesSmemTile = tileSize * tg::dtypeGetNumBits(dtypeAcc) / 8 /* bits */; // Number of bytes alignment for rowMax in SMEM. @@ -321,7 +345,7 @@ class KernelTraits { // Add info. smemChunkNames.emplace_back("smemSliceK"); numBytesAndAlignmentPerSmemChunk.emplace_back( - std::make_pair(numBytesSmemTile, numBytesAlignmentTile)); + std::make_pair(numBytesSmemTile, numBytesAlignmentTile)); firstChunkReuseSmem.emplace_back(false); } @@ -335,7 +359,7 @@ class KernelTraits { // Add info. smemChunkNames.emplace_back("smemPerTokenSf"); numBytesAndAlignmentPerSmemChunk.emplace_back( - std::make_pair(numBytesSmemPerTokenSf, numBytesAlignmentPerTokenSf)); + std::make_pair(numBytesSmemPerTokenSf, numBytesAlignmentPerTokenSf)); firstChunkReuseSmem.emplace_back(false); } @@ -354,7 +378,7 @@ class KernelTraits { // Add info. smemChunkNames.emplace_back("smemBias"); numBytesAndAlignmentPerSmemChunk.emplace_back( - std::make_pair(numBytesSmemBias, numBytesAlignmentBias)); + std::make_pair(numBytesSmemBias, numBytesAlignmentBias)); firstChunkReuseSmem.emplace_back(false); } @@ -368,7 +392,7 @@ class KernelTraits { // Add info. 
smemChunkNames.emplace_back("smemBlockAmax"); numBytesAndAlignmentPerSmemChunk.emplace_back( - std::make_pair(numBytesSmemBlockAmax, numBytesAlignmentBlockAmax)); + std::make_pair(numBytesSmemBlockAmax, numBytesAlignmentBlockAmax)); firstChunkReuseSmem.emplace_back(false); } @@ -387,13 +411,13 @@ class KernelTraits { // Add info. smemChunkNames.emplace_back("smemConstSfBuf"); numBytesAndAlignmentPerSmemChunk.emplace_back( - std::make_pair(numSmemBytesConstSfBuf, numBytesAlignmentConstSfBuf)); + std::make_pair(numSmemBytesConstSfBuf, numBytesAlignmentConstSfBuf)); firstChunkReuseSmem.emplace_back(reuseChunksSmemConstSfBuf); } // Create SMEM helper object. mSmemAllocatorHelper = - MemAllocatorHelper(numBytesAndAlignmentPerSmemChunk, firstChunkReuseSmem, smemChunkNames); + MemAllocatorHelper(numBytesAndAlignmentPerSmemChunk, firstChunkReuseSmem, smemChunkNames); #if 0 // E.g., // Chunk 0 smemLoadA: 32768 bytes, 1024 alignment, false, offset 0 @@ -430,7 +454,7 @@ class KernelTraits { // Add info. tmemChunkNames.emplace_back("tmemD"); numBytesAndAlignmentPerTmemChunk.emplace_back( - std::make_pair(numTmemColsD, numColsAlignmentD)); + std::make_pair(numTmemColsD, numColsAlignmentD)); firstChunkReuseTmem.emplace_back(reuseChunksTmemD); } @@ -440,10 +464,10 @@ class KernelTraits { bool const useTmemA = (numSlicesForSliceK > 1) || (dtypeMmaA != dtypeA); // Number of columns for A. auto const numTmemColsA = - useTmemA ? numStages * tileK / - (numSlicesForSliceK * tg::dtypeGetNumBits(tg::Dtype::UInt32) / - tg::dtypeGetNumBits(dtypeMmaA)) - : 0; + useTmemA ? numStages * tileK / + (numSlicesForSliceK * tg::dtypeGetNumBits(tg::Dtype::UInt32) / + tg::dtypeGetNumBits(dtypeMmaA)) + : 0; // Number of columns for A alignment. auto const numColsAlignmentA = 4; // No need to reuse TMEM. @@ -452,7 +476,7 @@ class KernelTraits { // Add info. tmemChunkNames.emplace_back("tmemA"); numBytesAndAlignmentPerTmemChunk.emplace_back( - std::make_pair(numTmemColsA, numColsAlignmentA)); + std::make_pair(numTmemColsA, numColsAlignmentA)); firstChunkReuseTmem.emplace_back(reuseChunksTmemA); } @@ -464,11 +488,10 @@ class KernelTraits { bool const useConstSfA = useBlockScalingA && !tg::dtypeIsBlockFmt(dtypeA); // Number of columns for scaling factors of A. auto const numTmemColsSfA = - useConstSfA - ? tg::roundUp((tileK / 64) * tg::getTmemColStridePerGroup(tileM, mmaK), 4) - : (useBlockScalingA - ? ((tileK / 64) * tg::getTmemColStridePerGroup(tileM, mmaK)) * numStages - : 0); + useConstSfA ? tg::roundUp((tileK / 64) * tg::getTmemColStridePerGroup(tileM, mmaK), 4) + : (useBlockScalingA + ? ((tileK / 64) * tg::getTmemColStridePerGroup(tileM, mmaK)) * numStages + : 0); // Number of columns for Sf alignment. auto const numColsAlignmentSfA = 4; // No need to reuse TMEM. @@ -477,7 +500,7 @@ class KernelTraits { // Add info. tmemChunkNames.emplace_back("tmemSfA"); numBytesAndAlignmentPerTmemChunk.emplace_back( - std::make_pair(numTmemColsSfA, numColsAlignmentSfA)); + std::make_pair(numTmemColsSfA, numColsAlignmentSfA)); firstChunkReuseTmem.emplace_back(reuseChunksTmemSfA); } @@ -489,11 +512,10 @@ class KernelTraits { bool const useConstSfB = useBlockScalingB && !tg::dtypeIsBlockFmt(dtypeB); // Number of columns for scaling factors of B. auto const numTmemColsSfB = - useConstSfB - ? tg::roundUp((tileK / 64) * tg::getTmemColStridePerGroup(tileN, mmaK), 4) - : (useBlockScalingB - ? ((tileK / 64) * tg::getTmemColStridePerGroup(tileN, mmaK)) * numStages - : 0); + useConstSfB ? 
tg::roundUp((tileK / 64) * tg::getTmemColStridePerGroup(tileN, mmaK), 4) + : (useBlockScalingB + ? ((tileK / 64) * tg::getTmemColStridePerGroup(tileN, mmaK)) * numStages + : 0); // Number of columns for Sf alignment. auto const numColsAlignmentSfB = 4; // No need to reuse TMEM. @@ -502,17 +524,17 @@ class KernelTraits { // Add info. tmemChunkNames.emplace_back("tmemSfB"); numBytesAndAlignmentPerTmemChunk.emplace_back( - std::make_pair(numTmemColsSfB, numColsAlignmentSfB)); + std::make_pair(numTmemColsSfB, numColsAlignmentSfB)); firstChunkReuseTmem.emplace_back(reuseChunksTmemSfB); } // Create TMEM helper object. mTmemAllocatorHelper = - MemAllocatorHelper(numBytesAndAlignmentPerTmemChunk, firstChunkReuseTmem, tmemChunkNames); + MemAllocatorHelper(numBytesAndAlignmentPerTmemChunk, firstChunkReuseTmem, tmemChunkNames); } } - public: +public: // The MMA kind. tg::MmaKind mMmaKind; // Helper for SMEM allocation. @@ -551,7 +573,9 @@ inline int32_t getSmemOffsetLoadB(KernelTraits traits) { //////////////////////////////////////////////////////////////////////////////////////////////////// -inline int32_t getSmemOffsetLoadAb(KernelTraits traits) { return getSmemOffsetLoadA(traits); } +inline int32_t getSmemOffsetLoadAb(KernelTraits traits) { + return getSmemOffsetLoadA(traits); +} //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -638,6 +662,6 @@ inline int32_t getTmemOffsetSfB(KernelTraits traits) { //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace gemm +} // namespace gemm -} // namespace batchedGemm +} // namespace batchedGemm diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h index a1412444ae..c20ce9c00f 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h @@ -1,25 +1,24 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & - * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +* SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & +* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ #pragma once -#include - #include "trtllm/gen/DtypeDecl.h" #include "trtllm/gen/MmaDecl.h" +#include #ifdef TLLM_ENABLE_CUDA #include @@ -39,10 +38,12 @@ namespace tg = trtllm::gen; #ifdef TLLM_ENABLE_CUDA -inline CUtensorMap buildNdTmaDescriptor(tg::Dtype dtype, tg::MmaKind mmaKind, +inline CUtensorMap buildNdTmaDescriptor(tg::Dtype dtype, + tg::MmaKind mmaKind, std::vector const& shapes, std::vector const& strides, - std::vector const& tileShapes, void* gmemAddr, + std::vector const& tileShapes, + void* gmemAddr, bool doSwizzle = true) { // The multiplication factor of the data padding in SMEM. int32_t padMultiplier = 1; @@ -76,7 +77,7 @@ inline CUtensorMap buildNdTmaDescriptor(tg::Dtype dtype, tg::MmaKind mmaKind, // The swizzle type. CUtensorMapSwizzle swizzleType{CU_TENSOR_MAP_SWIZZLE_NONE}; int32_t fastestDimTileSizeBytes = - (tileShapes[0] * tg::dtypeGetNumBits(dtype) * padMultiplier) / /* bits */ 8; + (tileShapes[0] * tg::dtypeGetNumBits(dtype) * padMultiplier) / /* bits */ 8; if (doSwizzle) { if ((fastestDimTileSizeBytes % 128) == 0) { swizzleType = CU_TENSOR_MAP_SWIZZLE_128B; @@ -96,7 +97,7 @@ inline CUtensorMap buildNdTmaDescriptor(tg::Dtype dtype, tg::MmaKind mmaKind, } // Check gmem address must be 16B-aligned - assert((reinterpret_cast(gmemAddr) & 0b1111) == 0); // + assert((reinterpret_cast(gmemAddr) & 0b1111) == 0); // // Check shape must be in range [1, 2^32] int32_t dim = shapes.size(); @@ -105,8 +106,8 @@ inline CUtensorMap buildNdTmaDescriptor(tg::Dtype dtype, tg::MmaKind mmaKind, assert(dim == 2 || dim == 3 || dim == 4); // Check shape range. for (int32_t ii = 0; ii < dim; ++ii) { - assert(shapes[ii] >= (uint64_t(1))); // Size must be min 1 - assert(shapes[ii] <= (uint64_t(1) << 32)); // Size must be max 2^32 + assert(shapes[ii] >= (uint64_t(1))); // Size must be min 1 + assert(shapes[ii] <= (uint64_t(1) << 32)); // Size must be max 2^32 } // TMA descriptor does not store the zeroth stride and assumes it is 1. @@ -144,19 +145,24 @@ inline CUtensorMap buildNdTmaDescriptor(tg::Dtype dtype, tg::MmaKind mmaKind, std::vector tileStrides(dim, 1); // Build the descriptor. - CUresult result = - cuTensorMapEncodeTiled(&desc, tmaDataFormat, - /*tensorRank=*/dim, gmemAddr, shapes.data(), stridesInBytes.data(), - boxDim.data(), tileStrides.data(), - /*interleave=*/CU_TENSOR_MAP_INTERLEAVE_NONE, swizzleType, - /*l2Promotion=*/CU_TENSOR_MAP_L2_PROMOTION_L2_128B, - /*oobFill=*/CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); + CUresult result = cuTensorMapEncodeTiled(&desc, + tmaDataFormat, + /*tensorRank=*/dim, + gmemAddr, + shapes.data(), + stridesInBytes.data(), + boxDim.data(), + tileStrides.data(), + /*interleave=*/CU_TENSOR_MAP_INTERLEAVE_NONE, + swizzleType, + /*l2Promotion=*/CU_TENSOR_MAP_L2_PROMOTION_L2_128B, + /*oobFill=*/CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); if (result != CUDA_SUCCESS) { char const* errorString; cuGetErrorString(result, &errorString); std::stringstream ss; - ss << "Error: Failed to initialize the TMA descriptor " << result << std::endl; + ss << "Error: Failed to initialize the TMA descriptor. " << errorString << std::endl; ss << "tmaFormat: " << static_cast(tmaDataFormat) << " dim: " << dim << " gmem: " << gmemAddr << std::endl; @@ -193,9 +199,11 @@ inline CUtensorMap buildNdTmaDescriptor(tg::Dtype dtype, tg::MmaKind mmaKind, } // TODO: make it work with the above descriptor? 
-inline CUtensorMap buildSfTmaDescriptor(tg::Dtype dtype, std::vector const& shapes, +inline CUtensorMap buildSfTmaDescriptor(tg::Dtype dtype, + std::vector const& shapes, std::vector const& strides, - const std::vector& tileShapes, void* gmemAddr) { + const std::vector& tileShapes, + void* gmemAddr) { CUtensorMap desc{}; CUtensorMapDataType tmaDataFormat; if (dtype == tg::Dtype::E4m3 || dtype == tg::Dtype::UE8m0) { @@ -209,14 +217,14 @@ inline CUtensorMap buildSfTmaDescriptor(tg::Dtype dtype, std::vector c CUtensorMapSwizzle swizzleType = CU_TENSOR_MAP_SWIZZLE_NONE; // Check gmem address must be 16B-aligned - assert((reinterpret_cast(gmemAddr) & 0b1111) == 0); // + assert((reinterpret_cast(gmemAddr) & 0b1111) == 0); // // Check shape must be in range [1, 2^32] int32_t dim = shapes.size(); // Check shape range. for (int32_t ii = 0; ii < dim; ++ii) { - assert(shapes[ii] >= (uint64_t(1))); // Size must be min 1 - assert(shapes[ii] <= (uint64_t(1) << 32)); // Size must be max 2^32 + assert(shapes[ii] >= (uint64_t(1))); // Size must be min 1 + assert(shapes[ii] <= (uint64_t(1) << 32)); // Size must be max 2^32 } // TMA descriptor does not store the zeroth stride and assumes it is 1. @@ -251,7 +259,7 @@ inline CUtensorMap buildSfTmaDescriptor(tg::Dtype dtype, std::vector c char const* errorString; cuGetErrorString(result, &errorString); std::stringstream ss; - ss << "Error: Failed to initialize the TMA descriptor for SF " << errorString << std::endl; + ss << "Error: Failed to initialize the TMA descriptor for SF. " << errorString << std::endl; ss << "tmaFormat: " << static_cast(tmaDataFormat) << " dim: " << dim << " gmem: " << gmemAddr << std::endl; @@ -288,10 +296,10 @@ inline CUtensorMap buildSfTmaDescriptor(tg::Dtype dtype, std::vector c return desc; } -#endif // defined TLLM_ENABLE_CUDA +#endif // defined TLLM_ENABLE_CUDA //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace gemm +} // namespace gemm -} // namespace batchedGemm +} // namespace batchedGemm diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.h index 393949a516..798da2a23a 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.h @@ -1,19 +1,19 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & - * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +* SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & +* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ #pragma once namespace batchedGemm { @@ -23,21 +23,34 @@ namespace gen { //////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline T ceilDiv(T m, T n) { +// +// TMA OOB optimization constants. +// +// CUDA Programming Guide states that "globalDim must be non-zero and less than or equal to 2^32". +// In practice, the kernel acts funny with TMA shape of 2^32 so we use 2^31. +constexpr unsigned long TmaDimMax = 1UL << 31; +// Chosen so that LargeN * XLargeN * sizeof(dtype) >= 2^64 which causes overflow and effectively +// becomes 0. As sizeof(dtype) can be as small as 0.5B, we choose LargeN = 2^30 and XLargeN = 2^35 +// so overflow can happen. +constexpr unsigned long LargeN = 1UL << 30; +// Used in TMA stride. Should be less than 2^40. +constexpr unsigned long XLargeN = 1UL << 35; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template inline T ceilDiv(T m, T n) { return (m + n - T(1)) / n; } //////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline T roundUp(T m, T n) { +template inline T roundUp(T m, T n) { return ceilDiv(m, n) * n; } //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace gen -} // namespace trtllm +} // namespace gen +} // namespace trtllm -} // namespace batchedGemm +} // namespace batchedGemm diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaKernelLauncher.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaKernelLauncher.h index 42bc884f92..7b819ba753 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaKernelLauncher.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaKernelLauncher.h @@ -1,27 +1,26 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & - * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +* SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & +* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ #pragma once #ifdef TLLM_ENABLE_CUDA -#include -#include - #include #include +#include +#include #endif namespace batchedGemm { @@ -31,13 +30,18 @@ namespace gen { //////////////////////////////////////////////////////////////////////////////////////////////////// #ifdef TLLM_ENABLE_CUDA -inline CUresult launchKernel(void* kernelParams, void* cudaStream, int32_t smemSize, - CUfunction kernel, dim3 block3, dim3 grid3, dim3 cluster3, +inline CUresult launchKernel(void* kernelParams, + void* cudaStream, + int32_t smemSize, + CUfunction kernel, + dim3 block3, + dim3 grid3, + dim3 cluster3, bool enablesPdl) { // Make sure we can launch with that much shared memory. if (smemSize > 48 * 1024) { CUresult result = - cuFuncSetAttribute(kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smemSize); + cuFuncSetAttribute(kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smemSize); if (result != CUDA_SUCCESS) { return result; } @@ -62,7 +66,7 @@ inline CUresult launchKernel(void* kernelParams, void* cudaStream, int32_t smemS launchAttrs[0].value.clusterDim.z = cluster3.z; launchAttrs[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE; launchAttrs[1].value.clusterSchedulingPolicyPreference = - (clusterDim > 1) ? CU_CLUSTER_SCHEDULING_POLICY_SPREAD : CU_CLUSTER_SCHEDULING_POLICY_DEFAULT; + (clusterDim > 1) ? CU_CLUSTER_SCHEDULING_POLICY_SPREAD : CU_CLUSTER_SCHEDULING_POLICY_DEFAULT; launchAttrs[2].id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION; launchAttrs[2].value.programmaticStreamSerializationAllowed = enablesPdl; launchConfig.attrs = launchAttrs; @@ -70,10 +74,10 @@ inline CUresult launchKernel(void* kernelParams, void* cudaStream, int32_t smemS // Add setting for non-portable cluster size. if (clusterDim > 8) { - CUresult result = - cuFuncSetAttribute(kernel, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, - 1 // Enable non-portable cluster sizes - ); + CUresult result = cuFuncSetAttribute(kernel, + CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, + 1 // Enable non-portable cluster sizes + ); if (result != CUDA_SUCCESS) { return result; } @@ -86,7 +90,7 @@ inline CUresult launchKernel(void* kernelParams, void* cudaStream, int32_t smemS //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace gen -} // namespace trtllm +} // namespace gen +} // namespace trtllm -} // namespace batchedGemm +} // namespace batchedGemm diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h index 0866256492..2b26230f33 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h @@ -1,25 +1,25 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & - * AFFILIATES. All rights reserved. 
SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +* SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & +* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ #pragma once -#include #include -#include +#include #include +#include #ifndef TLLM_GEN_EXPORT_INTERFACE #include "trtllm/gen/MmaDecl.h" #else @@ -50,9 +50,9 @@ enum class Dtype : uint32_t { // Bit 4: is it signed? 0x1 if true, 0x0 otherwise. // Byte 3: Is it a block format? 0x1 if true, 0x0 otherwise. -#define TLLM_ENCODE_DTYPE(BlockFormatBit, SignedBit, IntegerBit, NumBits, Uid) \ - uint32_t { \ - (BlockFormatBit << 24) | (SignedBit << 20) | (IntegerBit << 16) | (NumBits << 8) | (Uid) \ +#define TLLM_ENCODE_DTYPE(BlockFormatBit, SignedBit, IntegerBit, NumBits, Uid) \ + uint32_t { \ + (BlockFormatBit << 24) | (SignedBit << 20) | (IntegerBit << 16) | (NumBits << 8) | (Uid) \ } // clang-format off @@ -109,7 +109,9 @@ inline bool dtypeIsFloat(Dtype dtype) { //////////////////////////////////////////////////////////////////////////////////////////////////// // Is a given data type an 8-bit floating-point type? 
-inline bool dtypeIsFp8(Dtype dtype) { return dtype == Dtype::E4m3 || dtype == Dtype::E5m2; } +inline bool dtypeIsFp8(Dtype dtype) { + return dtype == Dtype::E4m3 || dtype == Dtype::E5m2; +} //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -132,51 +134,51 @@ inline bool dtypeIsSigned(Dtype dtype) { // For logging and error reporting inline std::string dtypeToString(Dtype dtype) { switch (dtype) { - case Dtype::Bfloat16: - return "Bfloat16"; - case Dtype::Bool: - return "Bool"; - case Dtype::E2m1: - return "E2m1"; - case Dtype::E2m3: - return "E2m3"; - case Dtype::E3m2: - return "E3m2"; - case Dtype::E4m3: - return "E4m3"; - case Dtype::E5m2: - return "E5m2"; - case Dtype::Fp16: - return "Fp16"; - case Dtype::Fp32: - return "Fp32"; - case Dtype::Int8: - return "Int8"; - case Dtype::Int32: - return "Int32"; - case Dtype::Int64: - return "Int64"; - case Dtype::MxE4m3: - return "MxE4m3"; - case Dtype::MxE2m1: - return "MxE2m1"; - case Dtype::UE8m0: - return "UE8m0"; - case Dtype::UInt8: - return "UInt8"; - case Dtype::UInt16: - return "UInt16"; - case Dtype::UInt32: - return "UInt32"; - case Dtype::UInt64: - return "UInt64"; - case Dtype::UInt128: - return "UInt128"; - case Dtype::Void: - return "Void"; - default: - assert(false); - return "Unsupported type"; + case Dtype::Bfloat16: + return "Bfloat16"; + case Dtype::Bool: + return "Bool"; + case Dtype::E2m1: + return "E2m1"; + case Dtype::E2m3: + return "E2m3"; + case Dtype::E3m2: + return "E3m2"; + case Dtype::E4m3: + return "E4m3"; + case Dtype::E5m2: + return "E5m2"; + case Dtype::Fp16: + return "Fp16"; + case Dtype::Fp32: + return "Fp32"; + case Dtype::Int8: + return "Int8"; + case Dtype::Int32: + return "Int32"; + case Dtype::Int64: + return "Int64"; + case Dtype::MxE4m3: + return "MxE4m3"; + case Dtype::MxE2m1: + return "MxE2m1"; + case Dtype::UE8m0: + return "UE8m0"; + case Dtype::UInt8: + return "UInt8"; + case Dtype::UInt16: + return "UInt16"; + case Dtype::UInt32: + return "UInt32"; + case Dtype::UInt64: + return "UInt64"; + case Dtype::UInt128: + return "UInt128"; + case Dtype::Void: + return "Void"; + default: + assert(false); + return "Unsupported type"; } } @@ -184,12 +186,12 @@ inline std::string dtypeToString(Dtype dtype) { inline Dtype dtypeEltType(Dtype dtype) { switch (dtype) { - case Dtype::MxE2m1: - return Dtype::E2m1; - case Dtype::MxE4m3: - return Dtype::E4m3; - default: - return dtype; + case Dtype::MxE2m1: + return Dtype::E2m1; + case Dtype::MxE4m3: + return Dtype::E4m3; + default: + return dtype; } } @@ -197,14 +199,14 @@ inline Dtype dtypeEltType(Dtype dtype) { inline int dtypeNumEltsPerSf(Dtype dtype) { switch (dtype) { - case Dtype::E2m1: - return 16; - case Dtype::MxE2m1: - case Dtype::MxE4m3: - return 32; - default: - assert(false); - return -1; + case Dtype::E2m1: + return 16; + case Dtype::MxE2m1: + case Dtype::MxE4m3: + return 32; + default: + assert(false); + return -1; } } @@ -213,14 +215,14 @@ inline int dtypeNumEltsPerSf(Dtype dtype) { // Returns the dtype of scaling factors, if applicable. 
inline Dtype dtypeGetBlockSfType(Dtype dtype) { switch (dtype) { - case Dtype::E2m1: - return Dtype::E4m3; - case Dtype::MxE2m1: - case Dtype::MxE4m3: - return Dtype::UE8m0; - default: - assert(false); - return Dtype::Void; + case Dtype::E2m1: + return Dtype::E4m3; + case Dtype::MxE2m1: + case Dtype::MxE4m3: + return Dtype::UE8m0; + default: + assert(false); + return Dtype::Void; } } @@ -265,7 +267,7 @@ inline MmaKind dtypeGetMmaKind(Dtype dtypeA, Dtype dtypeB) { //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace gen -} // namespace trtllm +} // namespace gen +} // namespace trtllm -} // namespace batchedGemm +} // namespace batchedGemm diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/MmaDecl.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/MmaDecl.h index c8de154396..7633090dc8 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/MmaDecl.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/MmaDecl.h @@ -1,19 +1,19 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & - * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +* SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & +* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ #pragma once #include @@ -21,9 +21,9 @@ #include #ifndef TLLM_GEN_EXPORT_INTERFACE #include "trtllm/gen/CommonUtils.h" -#else // TLLM_GEN_EXPORT_INTERFACE +#else // TLLM_GEN_EXPORT_INTERFACE #include "CommonUtils.h" -#endif // TLLM_GEN_EXPORT_INTERFACE +#endif // TLLM_GEN_EXPORT_INTERFACE namespace batchedGemm { @@ -73,23 +73,23 @@ inline bool mmaKindIsBlockFmt(MmaKind mmaKind) { // For logging and error reporting inline std::string mmaKindToString(MmaKind mmaKind) { switch (mmaKind) { - case MmaKind::Auto: - return "Auto"; - case MmaKind::Fp16: - return "Fp16"; - case MmaKind::Fp8Fp6Fp4: - return "Fp8Fp6Fp4"; - case MmaKind::Int8: - return "Int8"; - case MmaKind::MxFp4NvFp4: - return "MxFp4NvFp4"; - case MmaKind::MxFp8Fp6Fp4: - return "MxFp8Fp6Fp4"; - case MmaKind::Tf32: - return "Tf32"; - default: - assert(false); - return "Unsupported type"; + case MmaKind::Auto: + return "Auto"; + case MmaKind::Fp16: + return "Fp16"; + case MmaKind::Fp8Fp6Fp4: + return "Fp8Fp6Fp4"; + case MmaKind::Int8: + return "Int8"; + case MmaKind::MxFp4NvFp4: + return "MxFp4NvFp4"; + case MmaKind::MxFp8Fp6Fp4: + return "MxFp8Fp6Fp4"; + case MmaKind::Tf32: + return "Tf32"; + default: + assert(false); + return "Unsupported type"; } } @@ -104,7 +104,7 @@ inline int32_t getTmemColStridePerGroup(int32_t tileMn, int32_t mmaK) { //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace gen -} // namespace trtllm +} // namespace gen +} // namespace trtllm -} // namespace batchedGemm +} // namespace batchedGemm diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/SfLayoutDecl.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/SfLayoutDecl.h index 965bb1b7b8..c64105696e 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/SfLayoutDecl.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/SfLayoutDecl.h @@ -1,19 +1,19 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & - * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +* SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & +* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ #pragma once #include @@ -74,23 +74,23 @@ enum class SfLayout { inline std::string sfLayoutToString(SfLayout layout) { switch (layout) { - case SfLayout::Linear: - return "linear"; - case SfLayout::R8c4: - return "8x4"; - case SfLayout::R8c16: - return "8x16"; - case SfLayout::R128c4: - return "128x4"; - default: - assert(false); - return "Unsupported layout"; + case SfLayout::Linear: + return "linear"; + case SfLayout::R8c4: + return "8x4"; + case SfLayout::R8c16: + return "8x16"; + case SfLayout::R128c4: + return "128x4"; + default: + assert(false); + return "Unsupported layout"; } } //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace gen -} // namespace trtllm +} // namespace gen +} // namespace trtllm -} // namespace batchedGemm +} // namespace batchedGemm diff --git a/tests/test_trtllm_gen_fused_moe.py b/tests/test_trtllm_gen_fused_moe.py index f811348628..2eb739fe7e 100644 --- a/tests/test_trtllm_gen_fused_moe.py +++ b/tests/test_trtllm_gen_fused_moe.py @@ -1056,7 +1056,7 @@ def prepare_static_weights_for_kernel( ) if weight_layout == WeightLayout.BlockMajorK: - block_k = 128 + block_k = 64 tmp_weights1 = convert_to_block_layout(tmp_weights1, block_k) tmp_weights2 = convert_to_block_layout(tmp_weights2, block_k) @@ -1074,6 +1074,8 @@ def prepare_static_weights_for_kernel( return { "gemm1_weights": gemm1_weights_bf16_shuffled, "gemm2_weights": gemm2_weights_bf16_shuffled, + "use_shuffled_weight": use_shuffled_weight, + "weight_layout": weight_layout, } def call_moe( @@ -1105,6 +1107,8 @@ def call_moe( num_experts, # the rest are enforced by the api to be passed in the keyword form # as opposed to the positional form + use_shuffled_weight=static_data["use_shuffled_weight"], + weight_layout=static_data["weight_layout"], tile_tokens_dim=tile_tokens_dim, routing_method_type=routing_method_type, ) From dd1d53d077d6da3d6db4fcb42e5af09bfa7fd148 Mon Sep 17 00:00:00 2001 From: Nikita Korobov <14355239+nekorobov@users.noreply.github.com> Date: Fri, 26 Sep 2025 05:16:28 -0700 Subject: [PATCH 03/12] some other fixes --- tests/test_trtllm_gen_fused_moe.py | 45 ++++++++++++++++-------------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/tests/test_trtllm_gen_fused_moe.py b/tests/test_trtllm_gen_fused_moe.py index 2eb739fe7e..8065e1a186 100644 --- a/tests/test_trtllm_gen_fused_moe.py +++ b/tests/test_trtllm_gen_fused_moe.py @@ -220,6 +220,7 @@ class QuantMode(IntEnum): FP4_MXFP4_Bf16 = 3 FP8_BLOCK_SCALE = 4 FP8_PER_TENSOR = 5 + BF16 = 6 # ==================================================================================== @@ -1027,38 +1028,28 @@ def prepare_static_weights_for_kernel( # Use shuffled weights with BlockMajorK layout for better performance use_shuffled_weight = weight_processing["use_shuffled_weight"] weight_layout = weight_processing["layout"] - + if use_shuffled_weight: # FIXME: this depends on the kernel internals epilogue_tile_m = 128 # Reorder rows of W1 for fused gated activation - gemm1_weights_bf16_interleaved = [] - for i in range(num_experts): - gemm1_weights_bf16_interleaved.append( - reorder_rows_for_gated_act_gemm(args.gemm1_weights[i].clone()) - ) - - # Stack weights and scales for all experts - gemm1_weights_bf16_interleaved = torch.stack( - gemm1_weights_bf16_interleaved - ).reshape(num_experts, 2 * intermediate_size, hidden_size) - - # Shuffle weights and scaling factors for transposed mma output gemm1_weights_bf16_shuffled = [] gemm2_weights_bf16_shuffled = [] for i in 
range(num_experts): + tmp_weights1 = reorder_rows_for_gated_act_gemm(args.gemm1_weights[i].clone()) + tmp_weights1 = shuffle_matrix_a( - args.gemm1_weights[i].view(torch.uint8), epilogue_tile_m + tmp_weights1.view(torch.uint8), epilogue_tile_m ) tmp_weights2 = shuffle_matrix_a( - args.gemm2_weights[i].view(torch.uint8), epilogue_tile_m + args.gemm2_weights[i].clone().view(torch.uint8), epilogue_tile_m ) if weight_layout == WeightLayout.BlockMajorK: block_k = 64 - tmp_weights1 = convert_to_block_layout(tmp_weights1, block_k) - tmp_weights2 = convert_to_block_layout(tmp_weights2, block_k) + tmp_weights1 = convert_to_block_layout(tmp_weights1.view(torch.bfloat16), block_k) + tmp_weights2 = convert_to_block_layout(tmp_weights2.view(torch.bfloat16), block_k) gemm1_weights_bf16_shuffled.append(tmp_weights1) gemm2_weights_bf16_shuffled.append(tmp_weights2) @@ -1066,10 +1057,11 @@ def prepare_static_weights_for_kernel( # Stack weights for all experts gemm1_weights_bf16_shuffled = torch.stack(gemm1_weights_bf16_shuffled).view( torch.bfloat16 - ) + ).contiguous() gemm2_weights_bf16_shuffled = torch.stack(gemm2_weights_bf16_shuffled).view( torch.bfloat16 - ) + ).contiguous() + print(gemm1_weights_bf16_shuffled.shape, gemm2_weights_bf16_shuffled.shape) return { "gemm1_weights": gemm1_weights_bf16_shuffled, @@ -1412,6 +1404,7 @@ def check_accuracy(a, b, atol, rtol, percent): raise Exception("Inf in actual output") assert a.shape == b.shape, f"Shape mismatch: {a.shape} vs {b.shape}" + print(a, b) left = torch.abs(a - b) right = atol + rtol * torch.abs(b) count = torch.sum(left > right) @@ -1725,6 +1718,10 @@ def run_moe_dequant(args, quant_mode: QuantMode): .to(torch.float) ) args.c_global_sf = 1.0 + elif quant_mode == QuantMode.BF16: + activation_output = activation_output.to(torch.bfloat16).to(torch.float) + args.c_global_sf = 1.0 + print(activation_output) else: # mxfp4Bf16 activation_output = activation_output.to(torch.bfloat16).to(torch.float) args.c_global_sf = 1.0 @@ -1958,7 +1955,7 @@ def run_moe_reference_bf16(args): GatedActType.SwiGlu.value, # gated_act_type ) - return run_moe_dequant(args_dequant, QuantMode.FP8_PER_TENSOR), args_dequant + return run_moe_dequant(args_dequant, QuantMode.BF16), args_dequant def _compute_moe_actual_unified(moe_impl, args_dequant, args, **kwargs): @@ -2255,7 +2252,7 @@ def test_moe_quantization_classes( else: routing_bias = None - hidden_states = 2 * torch.randn( + hidden_states = torch.ones( (num_tokens, hidden_size), device="cuda", dtype=torch.bfloat16 ) gemm1_weights = torch.randn( @@ -2263,6 +2260,12 @@ def test_moe_quantization_classes( device="cuda", dtype=torch.bfloat16, ) + # for n in range(intermediate_size): + # for k in range(hidden_size): + # gemm1_weights[:, 1, k] = k * 0.01 + # gemm1_weights[:, 1 + intermediate_size, k] = (k + k%2 + 30) * 0.00001 + # gemm1_weights[:, 1 + intermediate_size, k] = k * 0.01 + gemm2_weights = torch.randn( (num_experts, hidden_size, intermediate_size), device="cuda", From 2b570cf85d7f314eaae51d516859d0a812265aa4 Mon Sep 17 00:00:00 2001 From: Nikita Korobov <14355239+nekorobov@users.noreply.github.com> Date: Mon, 29 Sep 2025 14:02:45 -0700 Subject: [PATCH 04/12] fix: test is passing --- csrc/trtllm_fused_moe_kernel_launcher.cu | 15 +++++++------ csrc/trtllm_fused_moe_runner.cu | 1 - flashinfer/fused_moe/core.py | 27 +---------------------- tests/test_trtllm_gen_fused_moe.py | 28 +++++++----------------- 4 files changed, 17 insertions(+), 54 deletions(-) diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu 
b/csrc/trtllm_fused_moe_kernel_launcher.cu index 4dea0c7736..ebf5b07ad1 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -71,7 +71,7 @@ Driver calls take place to carry out the gemm operations. */ class FusedMoeLauncher { - protected: +protected: at::Tensor const* routing_logits{}; at::Tensor const* routing_bias{}; at::Tensor const* hidden_states{}; @@ -305,7 +305,7 @@ class FusedMoeLauncher { int routing_device = routing_logits->get_device(); auto const& routing_stream = at::cuda::getCurrentCUDAStream(routing_device); routing_runner.run( - routing_logits->data_ptr(), args->routing_bias, args->num_tokens, args->num_experts, + routing_logits->data_ptr(), args->routing_bias, args->num_tokens, args->num_experts, args->top_k, args->n_group, args->topk_group, args->local_expert_offset, args->local_num_experts, args->routed_scaling_factor, expert_indexes.data_ptr(), expert_count_histogram.data_ptr(), total_num_padded_tokens.data_ptr(), @@ -356,7 +356,7 @@ void FusedMoeLauncher::init_common( this->gemm1_weights = gemm1_weights; this->gemm2_weights = gemm2_weights; - args->routing_logits = routing_logits->data_ptr(); + args->routing_logits = routing_logits->data_ptr(); args->routing_bias = routing_bias ? routing_bias->data_ptr() : nullptr; args->hidden_states = hidden_states->data_ptr(); args->gemm1_weights = gemm1_weights->data_ptr(); @@ -433,7 +433,7 @@ class Bf16MoeLauncher : public FusedMoeLauncher { int32_t max_num_padded_tokens = workspace.total_max_padded_tokens; gemm1_output = - at::detail::empty_cuda({max_num_padded_tokens, 2 * args->intermediate_size}, + at::detail::empty_cuda({max_num_padded_tokens, args->intermediate_size}, at::ScalarType::BFloat16, hidden_states->device(), std::nullopt); activation_output = at::detail::empty_cuda({max_num_padded_tokens, args->intermediate_size}, @@ -496,9 +496,10 @@ at::Tensor trtllm_bf16_moe(at::Tensor const& routing_logits, Bf16MoeLauncher launcher; launcher.init(routing_logits, routing_bias, hidden_states, gemm1_weights, gemm2_weights, - std::move(args), tile_tokens_dim, routing_method_type, use_shuffled_weight, - weight_layout); - return launcher.run(moe_tactic, enable_pdl)[0]; + std::move(args), tile_tokens_dim, routing_method_type, use_shuffled_weight, + weight_layout); + auto data = launcher.run(moe_tactic, enable_pdl)[0]; + return data; } at::Tensor trtllm_fp8_per_tensor_scale_moe_launcher( diff --git a/csrc/trtllm_fused_moe_runner.cu b/csrc/trtllm_fused_moe_runner.cu index d17222a361..2b03625938 100644 --- a/csrc/trtllm_fused_moe_runner.cu +++ b/csrc/trtllm_fused_moe_runner.cu @@ -137,7 +137,6 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 routingData.mDoSoftmaxBeforeTopK = routingMethodType == RoutingMethodType::RenormalizeNaive; routingData.mNormTopkProb = routingMethodType == RoutingMethodType::RenormalizeNaive; routingData.mApplySoftmaxAfterTopK = routingMethodType == RoutingMethodType::Renormalize; - routingData.mPtrScores = routingLogits; // diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index f3baee3ad1..60d9734932 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -1251,33 +1251,8 @@ def trtllm_bf16_moe_op( if enable_pdl is None: enable_pdl = device_support_pdl(hidden_states.device) # Call the C++ function for block scale MoE - # FIXME: remove these prints once done - print( - "@@@@", - routing_logits.dtype, - "routing_bias.dtype", - hidden_states.dtype, - gemm1_weights.dtype, - 
gemm2_weights.dtype, - ) - print("@@@@", gemm1_weights.shape, gemm2_weights.shape) - print( - num_experts, - top_k, - n_group, - topk_group, - intermediate_size, - local_expert_offset, - local_num_experts, - tile_tokens_dim, - use_shuffled_weight, - weight_layout, - moe_tactic, - routing_method_type, - enable_pdl, - ) output = moe_op.trtllm_bf16_moe( - routing_logits.to(torch.float), # FIXME what's the supported type? + routing_logits, routing_bias, hidden_states, gemm1_weights, diff --git a/tests/test_trtllm_gen_fused_moe.py b/tests/test_trtllm_gen_fused_moe.py index 8065e1a186..eaf7942c63 100644 --- a/tests/test_trtllm_gen_fused_moe.py +++ b/tests/test_trtllm_gen_fused_moe.py @@ -1037,22 +1037,21 @@ def prepare_static_weights_for_kernel( gemm1_weights_bf16_shuffled = [] gemm2_weights_bf16_shuffled = [] for i in range(num_experts): - tmp_weights1 = reorder_rows_for_gated_act_gemm(args.gemm1_weights[i].clone()) - + tmp_weights1 = reorder_rows_for_gated_act_gemm(args.gemm1_weights[i].clone().view(torch.uint8)) tmp_weights1 = shuffle_matrix_a( - tmp_weights1.view(torch.uint8), epilogue_tile_m + tmp_weights1, epilogue_tile_m ) tmp_weights2 = shuffle_matrix_a( args.gemm2_weights[i].clone().view(torch.uint8), epilogue_tile_m ) if weight_layout == WeightLayout.BlockMajorK: - block_k = 64 - tmp_weights1 = convert_to_block_layout(tmp_weights1.view(torch.bfloat16), block_k) - tmp_weights2 = convert_to_block_layout(tmp_weights2.view(torch.bfloat16), block_k) + block_k = 128 + tmp_weights1 = convert_to_block_layout(tmp_weights1.view(torch.uint8), block_k) + tmp_weights2 = convert_to_block_layout(tmp_weights2.view(torch.uint8), block_k) - gemm1_weights_bf16_shuffled.append(tmp_weights1) - gemm2_weights_bf16_shuffled.append(tmp_weights2) + gemm1_weights_bf16_shuffled.append(tmp_weights1.view(torch.bfloat16)) + gemm2_weights_bf16_shuffled.append(tmp_weights2.view(torch.bfloat16)) # Stack weights for all experts gemm1_weights_bf16_shuffled = torch.stack(gemm1_weights_bf16_shuffled).view( @@ -1061,7 +1060,6 @@ def prepare_static_weights_for_kernel( gemm2_weights_bf16_shuffled = torch.stack(gemm2_weights_bf16_shuffled).view( torch.bfloat16 ).contiguous() - print(gemm1_weights_bf16_shuffled.shape, gemm2_weights_bf16_shuffled.shape) return { "gemm1_weights": gemm1_weights_bf16_shuffled, @@ -1404,14 +1402,11 @@ def check_accuracy(a, b, atol, rtol, percent): raise Exception("Inf in actual output") assert a.shape == b.shape, f"Shape mismatch: {a.shape} vs {b.shape}" - print(a, b) left = torch.abs(a - b) right = atol + rtol * torch.abs(b) count = torch.sum(left > right) mismatch_percent = count / a.numel() if mismatch_percent > 1 - percent: - print(a) - print(b) raise Exception( f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} " f"(threshold: {1 - percent:.4f})" @@ -1721,7 +1716,6 @@ def run_moe_dequant(args, quant_mode: QuantMode): elif quant_mode == QuantMode.BF16: activation_output = activation_output.to(torch.bfloat16).to(torch.float) args.c_global_sf = 1.0 - print(activation_output) else: # mxfp4Bf16 activation_output = activation_output.to(torch.bfloat16).to(torch.float) args.c_global_sf = 1.0 @@ -2252,7 +2246,7 @@ def test_moe_quantization_classes( else: routing_bias = None - hidden_states = torch.ones( + hidden_states = 2 * torch.randn( (num_tokens, hidden_size), device="cuda", dtype=torch.bfloat16 ) gemm1_weights = torch.randn( @@ -2260,12 +2254,6 @@ def test_moe_quantization_classes( device="cuda", dtype=torch.bfloat16, ) - # for n in range(intermediate_size): - # for k in 
range(hidden_size): - # gemm1_weights[:, 1, k] = k * 0.01 - # gemm1_weights[:, 1 + intermediate_size, k] = (k + k%2 + 30) * 0.00001 - # gemm1_weights[:, 1 + intermediate_size, k] = k * 0.01 - gemm2_weights = torch.randn( (num_experts, hidden_size, intermediate_size), device="cuda", From 50bfd472e74267dae81160c2593d5277b5315471 Mon Sep 17 00:00:00 2001 From: Alex Yang Date: Wed, 1 Oct 2025 21:45:30 +0000 Subject: [PATCH 05/12] patch hash for public cubins (tested) --- flashinfer/artifacts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index d8c49ff2fd..d41df0f80c 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -111,7 +111,7 @@ def get_available_cubin_files(source, retries=3, delay=5, timeout=10): class ArtifactPath: TRTLLM_GEN_FMHA: str = "037e528e719ec3456a7d7d654f26b805e44c63b1/fmha/trtllm-gen/" TRTLLM_GEN_BMM: str = ( - "b5c82312e606632b7571c7370f4335cfae46e206/batched_gemm-145d1b1-9e1d49a/" + "24fce2b18de57a4dba2740cb8a5fdbc0ea451ecb/batched_gemm-145d1b1-9e1d49a/" ) TRTLLM_GEN_GEMM: str = ( "037e528e719ec3456a7d7d654f26b805e44c63b1/gemm-8704aa4-f91dc9e/" @@ -125,7 +125,7 @@ class MetaInfoHash: "0ff77215b86997665cf75973e13cd2932f551d46b4e008f851d32d47e1d9560f" ) TRTLLM_GEN_BMM: str = ( - "c98b4ce69a39fd41556d67033c30ea814ef76b0a2fe16e798e55baf0104acc34" + "9880c131d1ae4959ed38ff59c9b4ec2fb5aca74316ceba789c890fa7abebacdb" ) DEEPGEMM: str = "b4374f857c3066089c4ec6b5e79e785559fa2c05ce2623710b0b04bf86414a48" TRTLLM_GEN_GEMM: str = ( From b1cf91fe0e1c4646dbfca8b03a7e7610e05c75ad Mon Sep 17 00:00:00 2001 From: Alex Yang Date: Wed, 1 Oct 2025 21:47:31 +0000 Subject: [PATCH 06/12] precommit + clean --- csrc/trtllm_fused_moe_kernel_launcher.cu | 34 +- .../trtllmGen_bmm_export/BatchedGemmEnums.h | 48 +- .../BatchedGemmInterface.h | 89 +-- .../trtllmGen_bmm_export/BatchedGemmOptions.h | 313 +++----- .../batched_gemm/trtllmGen_bmm_export/Enums.h | 46 +- .../GemmGatedActOptions.h | 98 ++- .../trtllmGen_bmm_export/GemmOptions.h | 708 +++++++----------- .../trtllmGen_bmm_export/KernelParams.h | 521 +++++-------- .../trtllmGen_bmm_export/KernelParamsDecl.h | 33 +- .../trtllmGen_bmm_export/KernelTraits.h | 182 ++--- .../trtllmGen_bmm_export/TmaDescriptor.h | 84 +-- .../trtllm/gen/CommonUtils.h | 42 +- .../trtllm/gen/CudaKernelLauncher.h | 62 +- .../trtllm/gen/DtypeDecl.h | 184 +++-- .../trtllmGen_bmm_export/trtllm/gen/MmaDecl.h | 74 +- .../trtllm/gen/SfLayoutDecl.h | 58 +- tests/test_trtllm_gen_fused_moe.py | 32 +- 17 files changed, 1045 insertions(+), 1563 deletions(-) diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index ebf5b07ad1..27f9ce6f60 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -71,7 +71,7 @@ Driver calls take place to carry out the gemm operations. 
*/ class FusedMoeLauncher { -protected: + protected: at::Tensor const* routing_logits{}; at::Tensor const* routing_bias{}; at::Tensor const* hidden_states{}; @@ -267,9 +267,10 @@ protected: void prepare_moe_common(int64_t& moe_tactic) { using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner; - moe_runner = std::make_unique( - this->mDtypeAct, this->mDtypeWeights, args->mUseDeepSeekFp8, (int32_t)tile_tokens_dim, - static_cast(this->gated_act_type), this->use_shuffled_weight, this->weight_layout); + moe_runner = std::make_unique(this->mDtypeAct, this->mDtypeWeights, + args->mUseDeepSeekFp8, (int32_t)tile_tokens_dim, + static_cast(this->gated_act_type), + this->use_shuffled_weight, this->weight_layout); if (moe_tactic == -1) { moe_tactic = moe_runner->getDefaultValidConfigIndex( @@ -496,8 +497,8 @@ at::Tensor trtllm_bf16_moe(at::Tensor const& routing_logits, Bf16MoeLauncher launcher; launcher.init(routing_logits, routing_bias, hidden_states, gemm1_weights, gemm2_weights, - std::move(args), tile_tokens_dim, routing_method_type, use_shuffled_weight, - weight_layout); + std::move(args), tile_tokens_dim, routing_method_type, use_shuffled_weight, + weight_layout); auto data = launcher.run(moe_tactic, enable_pdl)[0]; return data; } @@ -761,27 +762,6 @@ at::Tensor trtllm_fp8_per_tensor_scale_moe( auto dtype = hidden_states.dtype(); if (dtype == at::ScalarType::Half || dtype == at::ScalarType::BFloat16 || dtype == at::ScalarType::Float8_e4m3fn) { - // // Create unified runner for FP8 per-tensor mode - // using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner; - // auto mRunner = std::make_unique( - // btg::Dtype::E4m3, false, tile_tokens_dim, /*useShuffledMatrixA*/ true); - - // auto const moeConfigIndex = mRunner->getDefaultValidConfigIndex( - // top_k, hidden_states.sizes()[1], intermediate_size, local_num_experts, - // hidden_states.sizes()[0]); - - // // Call unified launcher with nullopt for expert_indices, expert_weights, and output (will be - // created internally) auto results = trtllm_fp4_block_scale_moe_launcher( - // routing_logits, std::nullopt, std::nullopt, routing_bias, hidden_states, std::nullopt, - // gemm1_weights, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - // gemm2_weights, std::nullopt, std::nullopt, - // output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar, - // num_experts, top_k, n_group, topk_group, intermediate_size, local_expert_offset, - // local_num_experts, routed_scaling_factor, tile_tokens_dim, routing_method_type, true, // - // do_finalize = true *mRunner, btg::Dtype::E4m3, btg::Dtype::E4m3, moeConfigIndex, - // enable_pdl); - - // return results[0]; // Return the first tensor from the vector return trtllm_fp8_per_tensor_scale_moe_launcher( routing_logits, routing_bias, hidden_states, gemm1_weights, output1_scales_scalar, output1_scales_gate_scalar, gemm2_weights, output2_scales_scalar, num_experts, top_k, diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h index 9052618d1f..27955d2bdc 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h @@ -1,23 +1,23 @@ /* -* SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & -* AFFILIATES. All rights reserved. 
SPDX-License-Identifier: Apache-2.0 -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once -#include #include +#include namespace batchedGemm { @@ -36,26 +36,20 @@ enum class RouteImpl { //////////////////////////////////////////////////////////////////////////////////////////////////// -inline bool doesRouteImplUseNoRoute(RouteImpl mode) { - return (mode == RouteImpl::NoRoute); -} +inline bool doesRouteImplUseNoRoute(RouteImpl mode) { return (mode == RouteImpl::NoRoute); } //////////////////////////////////////////////////////////////////////////////////////////////////// -inline bool doesRouteImplUseLdgsts(RouteImpl mode) { - return (mode == RouteImpl::Ldgsts); -} +inline bool doesRouteImplUseLdgsts(RouteImpl mode) { return (mode == RouteImpl::Ldgsts); } //////////////////////////////////////////////////////////////////////////////////////////////////// -inline bool doesRouteImplUseTma(RouteImpl mode) { - return (mode == RouteImpl::Tma); -} +inline bool doesRouteImplUseTma(RouteImpl mode) { return (mode == RouteImpl::Tma); } //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace batchedGemm +} // namespace batchedGemm //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace batchedGemm +} // namespace batchedGemm diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h index fbf6d56ce6..3a72aab90c 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h @@ -454,25 +454,20 @@ struct BatchedGemmData { //////////////////////////////////////////////////////////////////////////////////////////////////// class BatchedGemmInterface { -public: + public: using ModuleCache = std::unordered_map>; BatchedGemmInterface() {} // Launch the cubin from the provided config. It calls all necessary memsets for internal buffers. // Provided config must be validated with isValidConfig before the call. 
- int32_t run(BatchedGemmConfig const& config, - void* workspace, - BatchedGemmData const& options, - void* cudaStream, - int32_t multiProcessorCount, - bool usePdl = true, + int32_t run(BatchedGemmConfig const& config, void* workspace, BatchedGemmData const& options, + void* cudaStream, int32_t multiProcessorCount, bool usePdl = true, std::optional> moduleCache = std::nullopt); // Initializes the buffers before the world sync. Must be called before run. int32_t runInitBeforeWorldSync(BatchedGemmConfig const& /* config */, - BatchedGemmData const& /* data */, - void* /* cudaStream */) const { + BatchedGemmData const& /* data */, void* /* cudaStream */) const { return 0; }; @@ -487,8 +482,8 @@ class BatchedGemmInterface { // Returns the grid dimensions of the current kernel. std::tuple getGridDim( - BatchedGemmOptions const& options, - std::optional maxNumCtasInBatchDim = std::nullopt) const { + BatchedGemmOptions const& options, + std::optional maxNumCtasInBatchDim = std::nullopt) const { bool const batchM = options.mBatchMode == BatchedGemmOptions::BatchMode::BatchM; int32_t numCtasBatch{0}; @@ -518,7 +513,7 @@ class BatchedGemmInterface { } int32_t numCtasTile = - batchM ? gemm::divUp(options.mN, options.mTileN) : gemm::divUp(options.mM, options.mTileM); + batchM ? gemm::divUp(options.mN, options.mTileN) : gemm::divUp(options.mM, options.mTileM); if (batchM) { numCtasTile = gemm::divUpMul(numCtasTile, options.mClusterDimY); } else { @@ -530,7 +525,7 @@ class BatchedGemmInterface { // Creates GemmOptions from kernel and data. BatchedGemmOptions getOptionsFromConfigAndData(BatchedGemmConfig const& config, - BatchedGemmData const& data) const; + BatchedGemmData const& data) const; // Returns the number of CTAs of the current kernel. int32_t getNumCtas(BatchedGemmOptions const& options, @@ -542,9 +537,10 @@ class BatchedGemmInterface { // Returns true if the configuration of the cubin can be executed for the given params. bool isValidConfig(BatchedGemmConfig const& config, BatchedGemmData const& data) const; -private: + private: // Aligns the pointer to the alignment - template inline Dtype* alignPtr(Dtype* ptr, int64_t alignment) const; + template + inline Dtype* alignPtr(Dtype* ptr, int64_t alignment) const; // Returns the size of the workspace buffers in bytes std::vector getWorkspaceSizesInBytes(BatchedGemmConfig const& config, @@ -587,8 +583,7 @@ size_t BatchedGemmInterface::getNumBatchedGemmConfigs() const { //////////////////////////////////////////////////////////////////////////////////////////////////// BatchedGemmOptions BatchedGemmInterface::getOptionsFromConfigAndData( - BatchedGemmConfig const& config, - BatchedGemmData const& data) const { + BatchedGemmConfig const& config, BatchedGemmData const& data) const { // Create options from config and data. BatchedGemmOptions options; options = config.mOptions; @@ -615,8 +610,7 @@ bool BatchedGemmInterface::isValidConfig(BatchedGemmConfig const& config, bool isBlackwell = gemm::isSmVersionBlackwell(config.mSm); // Check options without modifications. 
- return checkAndUpdateBatchedGemmOptions(options, - isBlackwell, + return checkAndUpdateBatchedGemmOptions(options, isBlackwell, /* updateOptions */ false); } @@ -640,9 +634,7 @@ size_t BatchedGemmInterface::getWorkspaceSizeInBytes(BatchedGemmConfig const& co //////////////////////////////////////////////////////////////////////////////////////////////////// std::vector BatchedGemmInterface::getWorkspaceSizesInBytes( - BatchedGemmConfig const& config, - BatchedGemmData const& data) const { - + BatchedGemmConfig const& config, BatchedGemmData const& data) const { std::vector workspaceSizes; // Get options from config and data. @@ -686,12 +678,9 @@ std::vector BatchedGemmInterface::getWorkspaceSizesInBytes( } //////////////////////////////////////////////////////////////////////////////////////////////////// -int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, - void* workspace, - BatchedGemmData const& batchedGemmData, - void* cudaStream, - int32_t /* multiProcessorCount */, - bool usePdl, +int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspace, + BatchedGemmData const& batchedGemmData, void* cudaStream, + int32_t /* multiProcessorCount */, bool usePdl, std::optional> moduleCache) { // Might be used. (void)usePdl; @@ -720,32 +709,20 @@ int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, } auto [numCtaBatch, numCtaTile, numCtaInner] = - getGridDim(options, batchedGemmData.mProblemDimensions.mMaxNumCtasInTokenDim); - auto kernelParams = - KernelParamsSetup::setKernelParams(options, - batchM, - batchedGemmData.mInputBuffers.mPtrA, - batchedGemmData.mInputBuffers.mPtrB, - batchedGemmData.mOutputBuffers.mPtrC, - batchedGemmData.mInputBuffers.mPtrSfA, - batchedGemmData.mInputBuffers.mPtrSfB, - batchedGemmData.mInputBuffers.mPtrPerTokenSfA, - batchedGemmData.mInputBuffers.mPtrPerTokenSfB, - batchedGemmData.mInputBuffers.mPtrBias, - batchedGemmData.mOutputBuffers.mPtrSfC, - batchedGemmData.mInputBuffers.mPtrScaleC, - batchedGemmData.mInputBuffers.mPtrScaleGate, - batchedGemmData.mInputBuffers.mPtrClampLimit, - batchedGemmData.mInputBuffers.mPtrGatedActAlpha, - batchedGemmData.mInputBuffers.mPtrGatedActBeta, - batchedGemmData.mInputBuffers.mPtrRouteMap, - dPtrRowMax, - dPtrRowMaxBars, - batchedGemmData.mInputBuffers.mPtrNumNonExitingCtas, - batchedGemmData.mInputBuffers.mPtrTotalNumPaddedTokens, - batchedGemmData.mInputBuffers.mPtrCtaIdxXyToBatchIdx, - batchedGemmData.mInputBuffers.mPtrCtaIdxXyToMnLimit, - numCtaBatch); + getGridDim(options, batchedGemmData.mProblemDimensions.mMaxNumCtasInTokenDim); + auto kernelParams = KernelParamsSetup::setKernelParams( + options, batchM, batchedGemmData.mInputBuffers.mPtrA, batchedGemmData.mInputBuffers.mPtrB, + batchedGemmData.mOutputBuffers.mPtrC, batchedGemmData.mInputBuffers.mPtrSfA, + batchedGemmData.mInputBuffers.mPtrSfB, batchedGemmData.mInputBuffers.mPtrPerTokenSfA, + batchedGemmData.mInputBuffers.mPtrPerTokenSfB, batchedGemmData.mInputBuffers.mPtrBias, + batchedGemmData.mOutputBuffers.mPtrSfC, batchedGemmData.mInputBuffers.mPtrScaleC, + batchedGemmData.mInputBuffers.mPtrScaleGate, batchedGemmData.mInputBuffers.mPtrClampLimit, + batchedGemmData.mInputBuffers.mPtrGatedActAlpha, + batchedGemmData.mInputBuffers.mPtrGatedActBeta, batchedGemmData.mInputBuffers.mPtrRouteMap, + dPtrRowMax, dPtrRowMaxBars, batchedGemmData.mInputBuffers.mPtrNumNonExitingCtas, + batchedGemmData.mInputBuffers.mPtrTotalNumPaddedTokens, + batchedGemmData.mInputBuffers.mPtrCtaIdxXyToBatchIdx, + 
batchedGemmData.mInputBuffers.mPtrCtaIdxXyToMnLimit, numCtaBatch); // The size of the grid. std::vector grid = batchM ? std::vector{numCtaBatch, numCtaTile, numCtaInner} @@ -829,8 +806,8 @@ int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace batchedGemm +} // namespace batchedGemm //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace batchedGemm +} // namespace batchedGemm diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h index 0178148ae4..f2a7d5dafd 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h @@ -1,57 +1,57 @@ /* -* SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & -* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once -#include "GemmOptions.h" -#include "GemmGatedActOptions.h" -#include "BatchedGemmEnums.h" - #include #include +#include "BatchedGemmEnums.h" +#include "GemmGatedActOptions.h" +#include "GemmOptions.h" + #ifndef TLLM_GEN_EXPORT_INTERFACE -#include "trtllm/gen/GenCtx.h" #include "trtllm/gen/CudaRunner.h" +#include "trtllm/gen/GenCtx.h" #else #include -#define TLLM_CHECK_ERROR(cond, ...) \ - if (!(cond)) { \ - printArgs(__VA_ARGS__); \ - printArgs("\n"); \ - return false; \ +#define TLLM_CHECK_ERROR(cond, ...) \ + if (!(cond)) { \ + printArgs(__VA_ARGS__); \ + printArgs("\n"); \ + return false; \ } #define TLLM_LOG_ERROR(...) TLLM_CHECK_ERROR(false, __VA_ARGS__) #define TLLM_CHECK_ERROR_FMT(cond, ...) TLLM_CHECK_ERROR(cond, __VA_ARGS__) -#define TLLM_CHECK_WARNING(cond, ...) \ - if (!(cond)) { \ - printArgs(__VA_ARGS__); \ - printArgs("\n"); \ - return false; \ +#define TLLM_CHECK_WARNING(cond, ...) \ + if (!(cond)) { \ + printArgs(__VA_ARGS__); \ + printArgs("\n"); \ + return false; \ } #define TLLM_LOG_WARNING(...) TLLM_CHECK_WARNING(false, __VA_ARGS__) #define TLLM_LOG_INFO(...) 
TLLM_CHECK_WARNING(false, __VA_ARGS__) -#endif // TLLM_GEN_EXPORT_INTERFACE +#endif // TLLM_GEN_EXPORT_INTERFACE namespace batchedGemm { @@ -67,7 +67,6 @@ namespace tg = trtllm::gen; // We inherit from GemmGatedActOptions, which is inherited from // GemmOptions to get GemmOptions and GemmGatedActOptions at the same time. struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions { - // Dtor. Allow down-casting. virtual ~BatchedGemmOptions() = default; @@ -75,178 +74,60 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions { BatchedGemmOptions() = default; // FIXME We create explicit constructor with all options to WAR stubgen issue in TRT-LLM. - BatchedGemmOptions(gemm::AllReduceAlgo allReduceAlgo, - gemm::BiasType biasType, - int blockK, - int clusterDimX, - int clusterDimY, - int clusterDimZ, - gemm::CtaSwizzleType ctaSwizzleType, - tg::Dtype dtypeAcc, - tg::Dtype dtypeA, - tg::Dtype dtypeB, - tg::Dtype dtypeC, - tg::Dtype dtypeMmaA, - tg::Dtype dtypeMmaB, - bool enablesEarlyExit, - bool enablesDelayedEarlyExit, - bool enablesGlobalPtxKnobs, - int epilogueLdtmDps, - int epilogueLdtmBits, - int epilogueTileM, - int epilogueTileN, - bool gridTriggerSecondaryA, - bool gridTriggerSecondaryB, - bool gridWaitForPrimaryEarlyExit, - bool gridWaitForPrimaryA, - bool gridWaitForPrimaryB, - bool hoistLoadTaskInit, - bool hoistMmaTaskTryWaits, - int k, - gemm::KernelTraits kernelTraits, - gemm::MatrixLayout layoutA, - gemm::MatrixLayout layoutB, - int m, - int mmaK, - tg::MmaKind mmaKind, - int mmaM, - int mmaN, - bool mockAllReduce, - int n, - int numRegsCastAWarps, - int numRegsCopySfLdsSttm, - int numRegsPerThreadEpilogueWarp, - int numRegsPerThreadNonEpilogueWarp, - int numSlicesForSplitK, - int numSlicesForSliceK, - int numStages, - int numStagesMma, - int numStagesMmaWithinWorkTile, - int numStagesMmaAcrossWorkTile, - int numStagesWorkId, - bool outputDebugTensors, - bool patchF2fp, - std::optional sfBlockSizeA, - tg::SfLayout sfLayoutA, - tg::SfLayout sfLayoutB, - tg::SfLayout sfLayoutC, - int32_t sfReshapeFactor, - bool sliceK, - gemm::SplitK splitK, - int tileK, - int tileM, - int tileN, - gemm::TileScheduler tileScheduler, - bool transposeMmaOutput, - bool useCustomMmaSchedule, - bool useDeepSeekFp8, - bool useHoistTryWaitForCustomMmaSchedule, - bool usePerTokenSfA, - bool usePerTokenSfB, - bool useShuffledMatrixA, - bool useTmaStore, - bool useTwoTmaLoadWarps, - bool useTwoMmaWarps, - bool useUnrollLoop2xForMma, - int worldSize, - gemmGatedAct::ActType actType, - bool clampBeforeAct, - std::vector batchedM, - std::vector batchedN, - BatchMode batchMode, - int numBatches, - bool isStaticBatch, - int numTokens, - RouteImpl routeImpl, - bool gridWaitForPrimaryRouting, - bool fusedAct, - bool useTmaOobOpt) - : gemmGatedAct::GemmGatedActOptions(gemm::GemmOptions(allReduceAlgo, - biasType, - blockK, - clusterDimX, - clusterDimY, - clusterDimZ, - ctaSwizzleType, - dtypeAcc, - dtypeA, - dtypeB, - dtypeC, - dtypeMmaA, - dtypeMmaB, - enablesEarlyExit, - enablesDelayedEarlyExit, - enablesGlobalPtxKnobs, - epilogueLdtmDps, - epilogueLdtmBits, - epilogueTileM, - epilogueTileN, - gridTriggerSecondaryA, - gridTriggerSecondaryB, - gridWaitForPrimaryEarlyExit, - gridWaitForPrimaryA, - gridWaitForPrimaryB, - hoistLoadTaskInit, - hoistMmaTaskTryWaits, - k, - kernelTraits, - layoutA, - layoutB, - m, - mmaK, - mmaKind, - mmaM, - mmaN, - mockAllReduce, - n, - numRegsCastAWarps, - numRegsCopySfLdsSttm, - numRegsPerThreadEpilogueWarp, - numRegsPerThreadNonEpilogueWarp, - 
numSlicesForSplitK, - numSlicesForSliceK, - numStages, - numStagesMma, - numStagesMmaWithinWorkTile, - numStagesMmaAcrossWorkTile, - numStagesWorkId, - outputDebugTensors, - patchF2fp, - sfBlockSizeA, - sfLayoutA, - sfLayoutB, - sfLayoutC, - sfReshapeFactor, - sliceK, - splitK, - tileK, - tileM, - tileN, - tileScheduler, - transposeMmaOutput, - useCustomMmaSchedule, - useDeepSeekFp8, - useHoistTryWaitForCustomMmaSchedule, - usePerTokenSfA, - usePerTokenSfB, - useShuffledMatrixA, - useTmaStore, - useTwoTmaLoadWarps, - useTwoMmaWarps, - useUnrollLoop2xForMma, - worldSize), - actType, - clampBeforeAct) - , mBatchedM(batchedM) - , mBatchedN(batchedN) - , mBatchMode(BatchMode(batchMode)) - , mFusedAct(fusedAct) - , mGridWaitForPrimaryRouting(gridWaitForPrimaryRouting) - , mIsStaticBatch(isStaticBatch) - , mNumBatches(numBatches) - , mNumTokens(numTokens) - , mRouteImpl(routeImpl) - , mUseTmaOobOpt(useTmaOobOpt) {} + BatchedGemmOptions( + gemm::AllReduceAlgo allReduceAlgo, gemm::BiasType biasType, int blockK, int clusterDimX, + int clusterDimY, int clusterDimZ, gemm::CtaSwizzleType ctaSwizzleType, tg::Dtype dtypeAcc, + tg::Dtype dtypeA, tg::Dtype dtypeB, tg::Dtype dtypeC, tg::Dtype dtypeMmaA, + tg::Dtype dtypeMmaB, bool enablesEarlyExit, bool enablesDelayedEarlyExit, + bool enablesGlobalPtxKnobs, int epilogueLdtmDps, int epilogueLdtmBits, int epilogueTileM, + int epilogueTileN, bool gridTriggerSecondaryA, bool gridTriggerSecondaryB, + bool gridWaitForPrimaryEarlyExit, bool gridWaitForPrimaryA, bool gridWaitForPrimaryB, + bool hoistLoadTaskInit, bool hoistMmaTaskTryWaits, int k, gemm::KernelTraits kernelTraits, + gemm::MatrixLayout layoutA, gemm::MatrixLayout layoutB, int m, int mmaK, tg::MmaKind mmaKind, + int mmaM, int mmaN, bool mockAllReduce, int n, int numRegsCastAWarps, + int numRegsCopySfLdsSttm, int numRegsPerThreadEpilogueWarp, + int numRegsPerThreadNonEpilogueWarp, int numSlicesForSplitK, int numSlicesForSliceK, + int numStages, int numStagesMma, int numStagesMmaWithinWorkTile, + int numStagesMmaAcrossWorkTile, int numStagesWorkId, bool outputDebugTensors, bool patchF2fp, + std::optional sfBlockSizeA, tg::SfLayout sfLayoutA, tg::SfLayout sfLayoutB, + tg::SfLayout sfLayoutC, int32_t sfReshapeFactor, bool sliceK, gemm::SplitK splitK, int tileK, + int tileM, int tileN, gemm::TileScheduler tileScheduler, bool transposeMmaOutput, + bool useCustomMmaSchedule, bool useDeepSeekFp8, bool useHoistTryWaitForCustomMmaSchedule, + bool usePerTokenSfA, bool usePerTokenSfB, bool useShuffledMatrixA, bool useTmaStore, + bool useTwoTmaLoadWarps, bool useTwoMmaWarps, bool useUnrollLoop2xForMma, int worldSize, + gemmGatedAct::ActType actType, bool clampBeforeAct, std::vector batchedM, + std::vector batchedN, BatchMode batchMode, int numBatches, bool isStaticBatch, + int numTokens, RouteImpl routeImpl, bool gridWaitForPrimaryRouting, bool fusedAct, + bool useTmaOobOpt) + : gemmGatedAct::GemmGatedActOptions( + gemm::GemmOptions( + allReduceAlgo, biasType, blockK, clusterDimX, clusterDimY, clusterDimZ, + ctaSwizzleType, dtypeAcc, dtypeA, dtypeB, dtypeC, dtypeMmaA, dtypeMmaB, + enablesEarlyExit, enablesDelayedEarlyExit, enablesGlobalPtxKnobs, epilogueLdtmDps, + epilogueLdtmBits, epilogueTileM, epilogueTileN, gridTriggerSecondaryA, + gridTriggerSecondaryB, gridWaitForPrimaryEarlyExit, gridWaitForPrimaryA, + gridWaitForPrimaryB, hoistLoadTaskInit, hoistMmaTaskTryWaits, k, kernelTraits, + layoutA, layoutB, m, mmaK, mmaKind, mmaM, mmaN, mockAllReduce, n, numRegsCastAWarps, + numRegsCopySfLdsSttm, 
numRegsPerThreadEpilogueWarp, numRegsPerThreadNonEpilogueWarp, + numSlicesForSplitK, numSlicesForSliceK, numStages, numStagesMma, + numStagesMmaWithinWorkTile, numStagesMmaAcrossWorkTile, numStagesWorkId, + outputDebugTensors, patchF2fp, sfBlockSizeA, sfLayoutA, sfLayoutB, sfLayoutC, + sfReshapeFactor, sliceK, splitK, tileK, tileM, tileN, tileScheduler, + transposeMmaOutput, useCustomMmaSchedule, useDeepSeekFp8, + useHoistTryWaitForCustomMmaSchedule, usePerTokenSfA, usePerTokenSfB, + useShuffledMatrixA, useTmaStore, useTwoTmaLoadWarps, useTwoMmaWarps, + useUnrollLoop2xForMma, worldSize), + actType, clampBeforeAct), + mBatchedM(batchedM), + mBatchedN(batchedN), + mBatchMode(BatchMode(batchMode)), + mFusedAct(fusedAct), + mGridWaitForPrimaryRouting(gridWaitForPrimaryRouting), + mIsStaticBatch(isStaticBatch), + mNumBatches(numBatches), + mNumTokens(numTokens), + mRouteImpl(routeImpl), + mUseTmaOobOpt(useTmaOobOpt) {} // Batched M-dimensions of GEMM. std::vector mBatchedM; @@ -275,10 +156,8 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions { //////////////////////////////////////////////////////////////////////////////////////////////////// // Check if the options are valid or not. -bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, - bool isBlackwell, +bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackwell, bool updateOptions = true) { - bool isValid = true; if (options.mUseTmaOobOpt && !options.mUseTwoTmaLoadWarps) { if (updateOptions) { @@ -295,7 +174,7 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, isValid = gemmGatedAct::checkAndUpdateGemmGatedActOptions(options, isBlackwell, updateOptions); } else { isValid = - gemm::checkAndUpdateGemmOptions(options, isBlackwell, 1 /* tpGrpSize */, updateOptions); + gemm::checkAndUpdateGemmOptions(options, isBlackwell, 1 /* tpGrpSize */, updateOptions); } bool batchM = options.mBatchMode == BatchedGemmOptions::BatchMode::BatchM; @@ -347,7 +226,7 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, options.mK); TLLM_CHECK_ERROR(options.mDtypeC != tg::Dtype::E2m1 && options.mDtypeA == tg::Dtype::E4m3 && - options.mDtypeB == tg::Dtype::E4m3, + options.mDtypeB == tg::Dtype::E4m3, "E2m1 is not supported with DeepSeek FP8"); } @@ -379,9 +258,9 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, "Tokens need use SF linear layout when being routed"); } else { // Note: if B is cast from a non-block format to a block format, there are no SFs to load. 
- TLLM_CHECK_ERROR(options.mSfLayoutB == tg::SfLayout::Linear || - !tg::dtypeIsBlockFmt(options.mDtypeB), - "Tokens need use SF linear layout when being routed"); + TLLM_CHECK_ERROR( + options.mSfLayoutB == tg::SfLayout::Linear || !tg::dtypeIsBlockFmt(options.mDtypeB), + "Tokens need use SF linear layout when being routed"); } } @@ -398,8 +277,8 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, if (!gemm::isBiasTypeNone(options.mBiasType)) { TLLM_CHECK_ERROR((gemm::isBiasTypeN(options.mBiasType) && options.mBatchMode == BatchedGemmOptions::BatchMode::BatchM) || - (gemm::isBiasTypeM(options.mBiasType) && - options.mBatchMode == BatchedGemmOptions::BatchMode::BatchN), + (gemm::isBiasTypeM(options.mBiasType) && + options.mBatchMode == BatchedGemmOptions::BatchMode::BatchN), "BatchedGemm supports only per channel bias."); } @@ -460,7 +339,7 @@ inline std::string dumpOptions(BatchedGemmOptions const& options) { //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace batchedGemm +} // namespace batchedGemm //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -473,6 +352,6 @@ inline std::string dumpOptions(BatchedGemmOptions const& options) { #undef TLLM_LOG_INFO #undef TLLM_LOG_ERROR -#endif // TLLM_GEN_EXPORT_INTERFACE +#endif // TLLM_GEN_EXPORT_INTERFACE -} // namespace batchedGemm +} // namespace batchedGemm diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h index 79c6c027d0..e9d5a23a65 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h @@ -1,19 +1,19 @@ /* -* SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & -* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once #include @@ -116,10 +116,8 @@ enum class CtaSwizzleType : uint32_t { // Helper functions to check the SplitK type. 
-#define SPLIT_K_FUNCTION(Mode) \ - inline bool doesSplitKUse##Mode(SplitK mode) { \ - return (mode == SplitK::Mode); \ - } +#define SPLIT_K_FUNCTION(Mode) \ + inline bool doesSplitKUse##Mode(SplitK mode) { return (mode == SplitK::Mode); } SPLIT_K_FUNCTION(Gmem) SPLIT_K_FUNCTION(Dsmem) @@ -130,10 +128,8 @@ SPLIT_K_FUNCTION(Dsmem) // Helper functions to check the Bias type. -#define BIAS_TYPE_FUNCTION(Mode) \ - inline bool isBiasType##Mode(BiasType type) { \ - return (type == BiasType::Mode); \ - } +#define BIAS_TYPE_FUNCTION(Mode) \ + inline bool isBiasType##Mode(BiasType type) { return (type == BiasType::Mode); } BIAS_TYPE_FUNCTION(None) BIAS_TYPE_FUNCTION(N) @@ -144,6 +140,6 @@ BIAS_TYPE_FUNCTION(Mn) //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace gemm +} // namespace gemm -} // namespace batchedGemm +} // namespace batchedGemm diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h index c9a9a4663f..e796bcc23c 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h @@ -1,19 +1,19 @@ /* -* SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & -* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once #include "GemmOptions.h" @@ -21,20 +21,20 @@ #ifdef TLLM_GEN_EXPORT_INTERFACE #include -#define TLLM_CHECK_ERROR(cond, ...) \ - if (!(cond)) { \ - printArgs(__VA_ARGS__); \ - return false; \ +#define TLLM_CHECK_ERROR(cond, ...) \ + if (!(cond)) { \ + printArgs(__VA_ARGS__); \ + return false; \ } #define TLLM_LOG_ERROR(...) TLLM_CHECK_ERROR(false, __VA_ARGS__) #define TLLM_CHECK_ERROR_FMT(...) TLLM_CHECK_ERROR(false, __VA_ARGS__) -#define TLLM_CHECK_WARNING(cond, ...) \ - if (!(cond)) { \ - printArgs(__VA_ARGS__); \ - return false; \ +#define TLLM_CHECK_WARNING(cond, ...) \ + if (!(cond)) { \ + printArgs(__VA_ARGS__); \ + return false; \ } #define TLLM_LOG_WARNING(...) 
TLLM_CHECK_WARNING(false, __VA_ARGS__) @@ -74,10 +74,8 @@ enum class ActType { // Helper functions to check the ActType type. -#define TLLM_ACT_TYPE_FUNCTION(actType) \ - inline bool is##actType(ActType type) { \ - return (type == ActType::actType); \ - } +#define TLLM_ACT_TYPE_FUNCTION(actType) \ + inline bool is##actType(ActType type) { return (type == ActType::actType); } TLLM_ACT_TYPE_FUNCTION(SwiGlu) TLLM_ACT_TYPE_FUNCTION(GeGlu) @@ -88,12 +86,12 @@ TLLM_ACT_TYPE_FUNCTION(GeGlu) inline std::string getActTypeName(ActType type) { switch (type) { - case ActType::SwiGlu: - return "SwiGlu"; - case ActType::GeGlu: - return "GeGlu"; - default: - return "Unknown type"; + case ActType::SwiGlu: + return "SwiGlu"; + case ActType::GeGlu: + return "GeGlu"; + default: + return "Unknown type"; } } @@ -102,9 +100,7 @@ inline std::string getActTypeName(ActType type) { struct GemmGatedActOptions : public gemm::GemmOptions { GemmGatedActOptions() = default; GemmGatedActOptions(gemm::GemmOptions options, ActType actType, bool clampBeforeAct) - : gemm::GemmOptions(options) - , mActType(actType) - , mClampBeforeAct(clampBeforeAct) {} + : gemm::GemmOptions(options), mActType(actType), mClampBeforeAct(clampBeforeAct) {} // Type of the gated activation. ActType mActType{ActType::SwiGlu}; @@ -116,14 +112,12 @@ struct GemmGatedActOptions : public gemm::GemmOptions { // Check if the options are valid or not. inline bool checkAndUpdateGemmGatedActOptions(gemmGatedAct::GemmGatedActOptions& options, - bool isBlackwell, - bool updateOptions = true) { - + bool isBlackwell, bool updateOptions = true) { // tmpOut is already transposed at this stage auto const hiddenSizeStr = options.mTransposeMmaOutput ? "M" : "N"; auto const hiddenSize = options.mTransposeMmaOutput ? options.mM : options.mN; auto const hiddenEpilogueTileSize = - options.mTransposeMmaOutput ? options.mEpilogueTileM : options.mEpilogueTileN; + options.mTransposeMmaOutput ? options.mEpilogueTileM : options.mEpilogueTileN; TLLM_CHECK_ERROR(hiddenSize % 2 == 0, hiddenSizeStr, " must be a multiple of 2."); @@ -132,8 +126,8 @@ inline bool checkAndUpdateGemmGatedActOptions(gemmGatedAct::GemmGatedActOptions& if (options.mUseTmaStore) { TLLM_CHECK_ERROR( - hiddenEpilogueTileSize * tg::dtypeGetNumBits(options.mDtypeC) / /* bits */ 8 % 32 == 0, - "Unsupported output hidden tile size"); + hiddenEpilogueTileSize * tg::dtypeGetNumBits(options.mDtypeC) / /* bits */ 8 % 32 == 0, + "Unsupported output hidden tile size"); } if (options.mUseDeepSeekFp8) { @@ -143,18 +137,12 @@ inline bool checkAndUpdateGemmGatedActOptions(gemmGatedAct::GemmGatedActOptions& if (options.mDtypeC == tg::Dtype::E2m1 || options.mDtypeC == tg::Dtype::MxE4m3) { int const outHiddenSize = (options.mTransposeMmaOutput ? 
options.mM : options.mN) / 2; int const hiddenGranularity = 4 * tg::dtypeNumEltsPerSf(options.mDtypeC); - TLLM_CHECK_ERROR(outHiddenSize % hiddenGranularity == 0, - "Output hidden size (", - outHiddenSize, - ") must be a multiple of ", - hiddenGranularity, - " for block-scaled outputs."); + TLLM_CHECK_ERROR(outHiddenSize % hiddenGranularity == 0, "Output hidden size (", outHiddenSize, + ") must be a multiple of ", hiddenGranularity, " for block-scaled outputs."); } - auto isValid = gemm::checkAndUpdateGemmOptions(options, - isBlackwell, - /* tpGrpSize */ 1, - updateOptions); + auto isValid = gemm::checkAndUpdateGemmOptions(options, isBlackwell, + /* tpGrpSize */ 1, updateOptions); if (!isValid) { return false; @@ -212,7 +200,7 @@ struct GemmGatedActConfig { //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace gemmGatedAct +} // namespace gemmGatedAct #ifdef TLLM_GEN_EXPORT_INTERFACE @@ -222,6 +210,6 @@ struct GemmGatedActConfig { #undef TLLM_LOG_WARNING #undef TLLM_LOG_INFO #undef TLLM_LOG_ERROR -#endif // TLLM_GEN_EXPORT_INTERFACE +#endif // TLLM_GEN_EXPORT_INTERFACE -} // namespace batchedGemm +} // namespace batchedGemm diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h index 18386f3a2b..8710da2a85 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h @@ -1,19 +1,19 @@ /* -* SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & -* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once #include @@ -24,21 +24,23 @@ #include "KernelParams.h" #include "KernelTraits.h" #include "trtllm/gen/DtypeDecl.h" -#include "trtllm/gen/SfLayoutDecl.h" #include "trtllm/gen/MmaDecl.h" +#include "trtllm/gen/SfLayoutDecl.h" #ifndef TLLM_GEN_EXPORT_INTERFACE -#include "trtllm/gen/GenCtx.h" #include "trtllm/gen/CudaRunner.h" +#include "trtllm/gen/GenCtx.h" #else #include -template void printArgs(T arg) { +template +void printArgs(T arg) { #ifdef TLLM_GEN_DEBUG std::cout << arg; #endif } -template void printArgs(T first, Args... 
args) { +template +void printArgs(T first, Args... args) { printArgs(first); if constexpr (sizeof...(args) > 0) { printArgs(", "); @@ -46,29 +48,29 @@ template void printArgs(T first, Args... args) { } } -#define TLLM_CHECK_ERROR(cond, ...) \ - if (!(cond)) { \ - printArgs(__VA_ARGS__); \ - printArgs("\n"); \ - return false; \ +#define TLLM_CHECK_ERROR(cond, ...) \ + if (!(cond)) { \ + printArgs(__VA_ARGS__); \ + printArgs("\n"); \ + return false; \ } #define TLLM_LOG_ERROR(...) TLLM_CHECK_ERROR(false, __VA_ARGS__) #define TLLM_CHECK_ERROR_FMT(cond, ...) TLLM_CHECK_ERROR(cond, __VA_ARGS__) -#define TLLM_CHECK_WARNING(cond, ...) \ - if (!(cond)) { \ - printArgs(__VA_ARGS__); \ - printArgs("\n"); \ - return false; \ +#define TLLM_CHECK_WARNING(cond, ...) \ + if (!(cond)) { \ + printArgs(__VA_ARGS__); \ + printArgs("\n"); \ + return false; \ } #define TLLM_LOG_WARNING(...) TLLM_CHECK_WARNING(false, __VA_ARGS__) #define TLLM_LOG_INFO(...) TLLM_CHECK_WARNING(false, __VA_ARGS__) -#endif // TLLM_GEN_EXPORT_INTERFACE +#endif // TLLM_GEN_EXPORT_INTERFACE namespace batchedGemm { @@ -89,154 +91,102 @@ struct GemmOptions { #endif GemmOptions() = default; - GemmOptions(AllReduceAlgo allReduceAlgo, - BiasType biasType, - int blockK, - int clusterDimX, - int clusterDimY, - int clusterDimZ, - CtaSwizzleType ctaSwizzleType, - tg::Dtype dtypeAcc, - tg::Dtype dtypeA, - tg::Dtype dtypeB, - tg::Dtype dtypeC, - tg::Dtype dtypeMmaA, - tg::Dtype dtypeMmaB, - bool enablesEarlyExit, - bool enablesDelayedEarlyExit, - bool enablesGlobalPtxKnobs, - int epilogueLdtmDps, - int epilogueLdtmBits, - int epilogueTileM, - int epilogueTileN, - bool gridTriggerSecondaryA, - bool gridTriggerSecondaryB, - bool gridWaitForPrimaryEarlyExit, - bool gridWaitForPrimaryA, - bool gridWaitForPrimaryB, - bool hoistLoadTaskInit, - bool hoistMmaTaskTryWaits, - int k, - KernelTraits kernelTraits, - MatrixLayout layoutA, - MatrixLayout layoutB, - int m, - int mmaK, - tg::MmaKind mmaKind, - int mmaM, - int mmaN, - bool mockAllReduce, - int n, - int numRegsCastAWarps, - int numRegsCopySfLdsSttm, - int numRegsPerThreadEpilogueWarp, - int numRegsPerThreadNonEpilogueWarp, - int numSlicesForSplitK, - int numSlicesForSliceK, - int numStages, - int numStagesMma, - int numStagesMmaWithinWorkTile, - int numStagesMmaAcrossWorkTile, - int numStagesWorkId, - bool outputDebugTensors, - bool patchF2fp, - std::optional sfBlockSizeA, - tg::SfLayout sfLayoutA, - tg::SfLayout sfLayoutB, - tg::SfLayout sfLayoutC, - int sfReshapeFactor, - bool sliceK, - SplitK splitK, - int tileK, - int tileM, - int tileN, - TileScheduler tileScheduler, - bool transposeMmaOutput, - bool useCustomMmaSchedule, - bool useDeepSeekFp8, - bool useHoistTryWaitForCustomMmaSchedule, - bool usePerTokenSfA, - bool usePerTokenSfB, - bool useShuffledMatrixA, - bool useTmaStore, - bool useTwoTmaLoadWarps, - bool useTwoMmaWarps, - bool useUnrollLoop2xForMma, + GemmOptions(AllReduceAlgo allReduceAlgo, BiasType biasType, int blockK, int clusterDimX, + int clusterDimY, int clusterDimZ, CtaSwizzleType ctaSwizzleType, tg::Dtype dtypeAcc, + tg::Dtype dtypeA, tg::Dtype dtypeB, tg::Dtype dtypeC, tg::Dtype dtypeMmaA, + tg::Dtype dtypeMmaB, bool enablesEarlyExit, bool enablesDelayedEarlyExit, + bool enablesGlobalPtxKnobs, int epilogueLdtmDps, int epilogueLdtmBits, + int epilogueTileM, int epilogueTileN, bool gridTriggerSecondaryA, + bool gridTriggerSecondaryB, bool gridWaitForPrimaryEarlyExit, + bool gridWaitForPrimaryA, bool gridWaitForPrimaryB, bool hoistLoadTaskInit, + bool hoistMmaTaskTryWaits, int 
k, KernelTraits kernelTraits, MatrixLayout layoutA, + MatrixLayout layoutB, int m, int mmaK, tg::MmaKind mmaKind, int mmaM, int mmaN, + bool mockAllReduce, int n, int numRegsCastAWarps, int numRegsCopySfLdsSttm, + int numRegsPerThreadEpilogueWarp, int numRegsPerThreadNonEpilogueWarp, + int numSlicesForSplitK, int numSlicesForSliceK, int numStages, int numStagesMma, + int numStagesMmaWithinWorkTile, int numStagesMmaAcrossWorkTile, int numStagesWorkId, + bool outputDebugTensors, bool patchF2fp, std::optional sfBlockSizeA, + tg::SfLayout sfLayoutA, tg::SfLayout sfLayoutB, tg::SfLayout sfLayoutC, + int sfReshapeFactor, bool sliceK, SplitK splitK, int tileK, int tileM, int tileN, + TileScheduler tileScheduler, bool transposeMmaOutput, bool useCustomMmaSchedule, + bool useDeepSeekFp8, bool useHoistTryWaitForCustomMmaSchedule, bool usePerTokenSfA, + bool usePerTokenSfB, bool useShuffledMatrixA, bool useTmaStore, + bool useTwoTmaLoadWarps, bool useTwoMmaWarps, bool useUnrollLoop2xForMma, int worldSize) - : mAllReduceAlgo{allReduceAlgo} - , mBiasType{biasType} - , mBlockK(blockK) - , mClusterDimX{clusterDimX} - , mClusterDimY{clusterDimY} - , mClusterDimZ{clusterDimZ} - , mCtaSwizzleType{ctaSwizzleType} - , mDtypeAcc{dtypeAcc} - , mDtypeA{dtypeA} - , mDtypeB{dtypeB} - , mDtypeC{dtypeC} - , mDtypeMmaA{dtypeMmaA} - , mDtypeMmaB{dtypeMmaB} - , mEnablesEarlyExit{enablesEarlyExit} - , mEnablesDelayedEarlyExit{enablesDelayedEarlyExit} - , mEnablesGlobalPtxKnobs{enablesGlobalPtxKnobs} - , mEpilogueLdtmDps{epilogueLdtmDps} - , mEpilogueLdtmBits{epilogueLdtmBits} - , mEpilogueTileM{epilogueTileM} - , mEpilogueTileN{epilogueTileN} - , mGridTriggerSecondaryA{gridTriggerSecondaryA} - , mGridTriggerSecondaryB{gridTriggerSecondaryB} - , mGridWaitForPrimaryEarlyExit{gridWaitForPrimaryEarlyExit} - , mGridWaitForPrimaryA{gridWaitForPrimaryA} - , mGridWaitForPrimaryB{gridWaitForPrimaryB} - , mHoistLoadTaskInit{hoistLoadTaskInit} - , mHoistMmaTaskTryWaits{hoistMmaTaskTryWaits} - , mK{k} - , mKernelTraits{kernelTraits} - , mLayoutA{layoutA} - , mLayoutB{layoutB} - , mM{m} - , mMmaK{mmaK} - , mMmaKind{mmaKind} - , mMmaM{mmaM} - , mMmaN{mmaN} - , mMockAllReduce{mockAllReduce} - , mN{n} - , mNumRegsCastAWarps(numRegsCastAWarps) - , mNumRegsCopySfLdsSttm(numRegsCopySfLdsSttm) - , mNumRegsPerThreadEpilogueWarp(numRegsPerThreadEpilogueWarp) - , mNumRegsPerThreadNonEpilogueWarp(numRegsPerThreadNonEpilogueWarp) - , mNumSlicesForSplitK{numSlicesForSplitK} - , mNumSlicesForSliceK{numSlicesForSliceK} - , mNumStages{numStages} - , mNumStagesMma{numStagesMma} - , mNumStagesMmaWithinWorkTile{numStagesMmaWithinWorkTile} - , mNumStagesMmaAcrossWorkTile{numStagesMmaAcrossWorkTile} - , mNumStagesWorkId{numStagesWorkId} - , mOutputDebugTensors{outputDebugTensors} - , mPatchF2fp{patchF2fp} - , mSfBlockSizeA{sfBlockSizeA} - , mSfLayoutA{sfLayoutA} - , mSfLayoutB{sfLayoutB} - , mSfLayoutC{sfLayoutC} - , mSfReshapeFactor{sfReshapeFactor} - , mSliceK{sliceK} - , mSplitK{splitK} - , mTileK{tileK} - , mTileM{tileM} - , mTileN{tileN} - , mTileScheduler{tileScheduler} - , mTransposeMmaOutput{transposeMmaOutput} - , mUseCustomMmaSchedule{useCustomMmaSchedule} - , mUseDeepSeekFp8{useDeepSeekFp8} - , mUseHoistTryWaitForCustomMmaSchedule{useHoistTryWaitForCustomMmaSchedule} - , mUsePerTokenSfA{usePerTokenSfA} - , mUsePerTokenSfB{usePerTokenSfB} - , mUseShuffledMatrixA{useShuffledMatrixA} - , mUseTmaStore{useTmaStore} - , mUseTwoTmaLoadWarps{useTwoTmaLoadWarps} - , mUseTwoMmaWarps{useTwoMmaWarps} - , mUseUnrollLoop2xForMma{useUnrollLoop2xForMma} - , 
mWorldSize{worldSize} {} + : mAllReduceAlgo{allReduceAlgo}, + mBiasType{biasType}, + mBlockK(blockK), + mClusterDimX{clusterDimX}, + mClusterDimY{clusterDimY}, + mClusterDimZ{clusterDimZ}, + mCtaSwizzleType{ctaSwizzleType}, + mDtypeAcc{dtypeAcc}, + mDtypeA{dtypeA}, + mDtypeB{dtypeB}, + mDtypeC{dtypeC}, + mDtypeMmaA{dtypeMmaA}, + mDtypeMmaB{dtypeMmaB}, + mEnablesEarlyExit{enablesEarlyExit}, + mEnablesDelayedEarlyExit{enablesDelayedEarlyExit}, + mEnablesGlobalPtxKnobs{enablesGlobalPtxKnobs}, + mEpilogueLdtmDps{epilogueLdtmDps}, + mEpilogueLdtmBits{epilogueLdtmBits}, + mEpilogueTileM{epilogueTileM}, + mEpilogueTileN{epilogueTileN}, + mGridTriggerSecondaryA{gridTriggerSecondaryA}, + mGridTriggerSecondaryB{gridTriggerSecondaryB}, + mGridWaitForPrimaryEarlyExit{gridWaitForPrimaryEarlyExit}, + mGridWaitForPrimaryA{gridWaitForPrimaryA}, + mGridWaitForPrimaryB{gridWaitForPrimaryB}, + mHoistLoadTaskInit{hoistLoadTaskInit}, + mHoistMmaTaskTryWaits{hoistMmaTaskTryWaits}, + mK{k}, + mKernelTraits{kernelTraits}, + mLayoutA{layoutA}, + mLayoutB{layoutB}, + mM{m}, + mMmaK{mmaK}, + mMmaKind{mmaKind}, + mMmaM{mmaM}, + mMmaN{mmaN}, + mMockAllReduce{mockAllReduce}, + mN{n}, + mNumRegsCastAWarps(numRegsCastAWarps), + mNumRegsCopySfLdsSttm(numRegsCopySfLdsSttm), + mNumRegsPerThreadEpilogueWarp(numRegsPerThreadEpilogueWarp), + mNumRegsPerThreadNonEpilogueWarp(numRegsPerThreadNonEpilogueWarp), + mNumSlicesForSplitK{numSlicesForSplitK}, + mNumSlicesForSliceK{numSlicesForSliceK}, + mNumStages{numStages}, + mNumStagesMma{numStagesMma}, + mNumStagesMmaWithinWorkTile{numStagesMmaWithinWorkTile}, + mNumStagesMmaAcrossWorkTile{numStagesMmaAcrossWorkTile}, + mNumStagesWorkId{numStagesWorkId}, + mOutputDebugTensors{outputDebugTensors}, + mPatchF2fp{patchF2fp}, + mSfBlockSizeA{sfBlockSizeA}, + mSfLayoutA{sfLayoutA}, + mSfLayoutB{sfLayoutB}, + mSfLayoutC{sfLayoutC}, + mSfReshapeFactor{sfReshapeFactor}, + mSliceK{sliceK}, + mSplitK{splitK}, + mTileK{tileK}, + mTileM{tileM}, + mTileN{tileN}, + mTileScheduler{tileScheduler}, + mTransposeMmaOutput{transposeMmaOutput}, + mUseCustomMmaSchedule{useCustomMmaSchedule}, + mUseDeepSeekFp8{useDeepSeekFp8}, + mUseHoistTryWaitForCustomMmaSchedule{useHoistTryWaitForCustomMmaSchedule}, + mUsePerTokenSfA{usePerTokenSfA}, + mUsePerTokenSfB{usePerTokenSfB}, + mUseShuffledMatrixA{useShuffledMatrixA}, + mUseTmaStore{useTmaStore}, + mUseTwoTmaLoadWarps{useTwoTmaLoadWarps}, + mUseTwoMmaWarps{useTwoMmaWarps}, + mUseUnrollLoop2xForMma{useUnrollLoop2xForMma}, + mWorldSize{worldSize} {} // The all-reduce algorithm. AllReduceAlgo mAllReduceAlgo{AllReduceAlgo::None}; @@ -446,19 +396,22 @@ struct GemmConfig { //////////////////////////////////////////////////////////////////////////////////////////////////// // Serialization helpers. 
-template inline std::string toString(T e) { +template +inline std::string toString(T e) { return std::to_string(e); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template <> inline std::string toString(trtllm::gen::Dtype e) { +template <> +inline std::string toString(trtllm::gen::Dtype e) { return trtllm::gen::dtypeToString(e); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template <> inline std::string toString(trtllm::gen::MmaKind e) { +template <> +inline std::string toString(trtllm::gen::MmaKind e) { return trtllm::gen::mmaKindToString(e); } @@ -569,13 +522,15 @@ inline std::string dumpOptions(GemmOptions const& options) { //////////////////////////////////////////////////////////////////////////////////////////////////// -template inline T divUp(T a, T b) { +template +inline T divUp(T a, T b) { return (a + b - 1) / b; } //////////////////////////////////////////////////////////////////////////////////////////////////// -template inline T divUpMul(T a, T b) { +template +inline T divUpMul(T a, T b) { return gemm::divUp(a, b) * b; } @@ -592,9 +547,7 @@ inline int32_t getShuffleBlockSize(int epilogueTileM) { //////////////////////////////////////////////////////////////////////////////////////////////////// // Check if the options are valid or not. -inline bool checkAndUpdateGemmOptions(GemmOptions& options, - bool isBlackwell, - int tpGrpSize, +inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, int tpGrpSize, bool updateOptions = true) { options.mWorldSize = tpGrpSize; @@ -625,25 +578,21 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, // Check that the A cast is supported. // Currently, we only support {MxFp4, NvFp4} -> Bf16. TLLM_CHECK_ERROR( - (options.mDtypeA == options.mDtypeMmaA) || - ((options.mDtypeA == tg::Dtype::MxE2m1 || options.mDtypeA == tg::Dtype::E2m1) && - options.mDtypeMmaA == tg::Dtype::Bfloat16) || - (options.mDtypeA == tg::Dtype::E2m1 && options.mDtypeMmaA == tg::Dtype::E4m3), - "Unsupported cast for A: ", - tg::dtypeToString(options.mDtypeA), - " -> ", - tg::dtypeToString(options.mDtypeMmaA)); + (options.mDtypeA == options.mDtypeMmaA) || + ((options.mDtypeA == tg::Dtype::MxE2m1 || options.mDtypeA == tg::Dtype::E2m1) && + options.mDtypeMmaA == tg::Dtype::Bfloat16) || + (options.mDtypeA == tg::Dtype::E2m1 && options.mDtypeMmaA == tg::Dtype::E4m3), + "Unsupported cast for A: ", tg::dtypeToString(options.mDtypeA), " -> ", + tg::dtypeToString(options.mDtypeMmaA)); // Check that the B cast is supported. // Currently, we only support Fp8 -> MxFp8. 
// TODO: add same support for A (no transpose) TLLM_CHECK_ERROR( - (options.mDtypeB == options.mDtypeMmaB) || - (options.mDtypeB == tg::Dtype::E4m3 && options.mDtypeMmaB == tg::Dtype::MxE4m3), - "Unsupported cast for B: ", - tg::dtypeToString(options.mDtypeB), - " -> ", - tg::dtypeToString(options.mDtypeMmaB)); + (options.mDtypeB == options.mDtypeMmaB) || + (options.mDtypeB == tg::Dtype::E4m3 && options.mDtypeMmaB == tg::Dtype::MxE4m3), + "Unsupported cast for B: ", tg::dtypeToString(options.mDtypeB), " -> ", + tg::dtypeToString(options.mDtypeMmaB)); if (options.mDtypeA != options.mDtypeMmaA) { TLLM_CHECK_ERROR(options.mTileM == 128, @@ -651,9 +600,9 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, } if (options.mPatchF2fp) { - TLLM_CHECK_ERROR(options.mDtypeA == tg::Dtype::MxE2m1 && - options.mDtypeMmaA == tg::Dtype::Bfloat16, - "PatchF2fp is only supported for MxFp4 to Bf16 casts."); + TLLM_CHECK_ERROR( + options.mDtypeA == tg::Dtype::MxE2m1 && options.mDtypeMmaA == tg::Dtype::Bfloat16, + "PatchF2fp is only supported for MxFp4 to Bf16 casts."); } // FIXME: We do not support different dtypes for A and B when not on Blackwell. @@ -671,14 +620,14 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, // kind::mxf8f6f4 if (options.mDtypeMmaA == tg::Dtype::MxE4m3 || options.mDtypeMmaA == tg::Dtype::MxE2m1) { - TLLM_CHECK_ERROR(options.mDtypeMmaB == tg::Dtype::MxE4m3 || - options.mDtypeMmaB == tg::Dtype::MxE2m1, - "For dtypeMmaA = MxE4m3 or MxE2m1, dtypeMmaB must also be MxE4m3 or MxE2m1."); + TLLM_CHECK_ERROR( + options.mDtypeMmaB == tg::Dtype::MxE4m3 || options.mDtypeMmaB == tg::Dtype::MxE2m1, + "For dtypeMmaA = MxE4m3 or MxE2m1, dtypeMmaB must also be MxE4m3 or MxE2m1."); } if (options.mDtypeMmaB == tg::Dtype::MxE4m3 || options.mDtypeMmaB == tg::Dtype::MxE2m1) { - TLLM_CHECK_ERROR(options.mDtypeMmaA == tg::Dtype::MxE4m3 || - options.mDtypeMmaA == tg::Dtype::MxE2m1, - "For dtypeMmaB = MxE4m3 or MxE2m1, dtypeMmaA must also be MxE4m3 or MxE2m1."); + TLLM_CHECK_ERROR( + options.mDtypeMmaA == tg::Dtype::MxE4m3 || options.mDtypeMmaA == tg::Dtype::MxE2m1, + "For dtypeMmaB = MxE4m3 or MxE2m1, dtypeMmaA must also be MxE4m3 or MxE2m1."); } // kind::f16 @@ -712,11 +661,8 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, if ((options.mMmaKind == tg::MmaKind::Fp8Fp6Fp4 || options.mMmaKind == tg::MmaKind::MxFp8Fp6Fp4) && options.mMmaK != 32) { - TLLM_LOG_WARNING("Unsupported MmaK (", - options.mMmaK, - ") for MmaKind=", - gemm::toString(options.mMmaKind), - ". Setting MmaK to 32"); + TLLM_LOG_WARNING("Unsupported MmaK (", options.mMmaK, + ") for MmaKind=", gemm::toString(options.mMmaKind), ". Setting MmaK to 32"); if (updateOptions) { options.mMmaK = 32; options.mTileK = std::max(options.mMmaK, options.mTileK); @@ -728,19 +674,13 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, // Check LDTM shape. 
if (isBlackwell) { TLLM_CHECK_ERROR((options.mEpilogueLdtmDps == 16 && options.mEpilogueLdtmBits == 256) || - (options.mEpilogueLdtmDps == 32 && options.mEpilogueLdtmBits == 32), - "Unsupported LDTM shape: ", - options.mEpilogueLdtmDps, - "dp", - options.mEpilogueLdtmBits, - "bit."); + (options.mEpilogueLdtmDps == 32 && options.mEpilogueLdtmBits == 32), + "Unsupported LDTM shape: ", options.mEpilogueLdtmDps, "dp", + options.mEpilogueLdtmBits, "bit."); if (options.mEpilogueTileM == 64) { TLLM_CHECK_ERROR(options.mEpilogueLdtmDps == 16, - "Unsupported LDTM shape for epilogueTileM=64: ", - options.mEpilogueLdtmDps, - "dp", - options.mEpilogueLdtmBits, - "bit."); + "Unsupported LDTM shape for epilogueTileM=64: ", options.mEpilogueLdtmDps, + "dp", options.mEpilogueLdtmBits, "bit."); } if (options.mTransposeMmaOutput) { // We can't use 32dp32bit LDTM for transposed outputs because we need each thread to own @@ -750,28 +690,21 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, } } else { TLLM_CHECK_ERROR( - options.mEpilogueLdtmDps == 16 && options.mEpilogueLdtmBits == 256, - "Hopper does not use TMEM. The register layout corresponds to 16dp256bit. Got ", - options.mEpilogueLdtmDps, - "dp", - options.mEpilogueLdtmBits, - "bit."); + options.mEpilogueLdtmDps == 16 && options.mEpilogueLdtmBits == 256, + "Hopper does not use TMEM. The register layout corresponds to 16dp256bit. Got ", + options.mEpilogueLdtmDps, "dp", options.mEpilogueLdtmBits, "bit."); } // Constraints for NvFp4 and MxFp8. if ((options.mMmaKind == tg::MmaKind::MxFp4NvFp4 || options.mMmaKind == tg::MmaKind::MxFp8Fp6Fp4 || options.mDtypeC == tg::Dtype::MxE4m3) && options.mMmaM != 128) { - if (options.mClusterDimX == 1) { // MMA M must be 128 when the input uses block scaling, or when the output is an Mx format. int newTileM = 128 * divUp(options.mTileM, 128); - TLLM_LOG_WARNING("Unsupported MmaM (", - options.mMmaM, - ") for MmaKind=", - gemm::toString(options.mMmaKind), - ". Setting MmaM to 128 and TileM to ", - newTileM); + TLLM_LOG_WARNING("Unsupported MmaM (", options.mMmaM, + ") for MmaKind=", gemm::toString(options.mMmaKind), + ". Setting MmaM to 128 and TileM to ", newTileM); if (updateOptions) { options.mMmaM = 128; options.mTileM = newTileM; @@ -802,14 +735,9 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, } if (options.mMmaK != mmaK) { int newTileK = mmaK * divUp(options.mTileK, mmaK); - TLLM_LOG_WARNING("Unsupported MmaK (", - options.mMmaK, - ") for MmaKind=", - gemm::toString(options.mMmaKind), - ". Setting MmaK to ", - mmaK, - " and TileK to ", - newTileK); + TLLM_LOG_WARNING("Unsupported MmaK (", options.mMmaK, + ") for MmaKind=", gemm::toString(options.mMmaKind), ". Setting MmaK to ", + mmaK, " and TileK to ", newTileK); if (updateOptions) { options.mMmaK = mmaK; options.mTileK = newTileK; @@ -819,12 +747,8 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, } // The MMA N may only be smaller than 64 if it is equal to the tile N. 
- TLLM_CHECK_ERROR(options.mMmaN >= 64 || options.mMmaN == options.mTileN, - "MmaN (", - options.mMmaN, - ") must be >= 64 or equal to TileN (", - options.mTileN, - ")"); + TLLM_CHECK_ERROR(options.mMmaN >= 64 || options.mMmaN == options.mTileN, "MmaN (", + options.mMmaN, ") must be >= 64 or equal to TileN (", options.mTileN, ")"); } if (options.mSfBlockSizeA.has_value()) { @@ -832,8 +756,7 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, TLLM_CHECK_ERROR(options.mDtypeA == tg::Dtype::E2m1 && options.mDtypeB == tg::Dtype::E4m3, "sfBlockSizeA is only supported for E2m1 and E4m3 types. Found dtypeA=", tg::dtypeToString(options.mDtypeA), - " dtypeB=", - tg::dtypeToString(options.mDtypeB)); + " dtypeB=", tg::dtypeToString(options.mDtypeB)); // sfBlockSizeA must be 16 or 32. // SfBlockSizeA can also support 64 and 128, although they are not officially supported Nvida @@ -842,56 +765,38 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, // If we want to support sfBlockSizeA=8, we can write another version of convertE2m1ToSfE4m3, // which only packs 8 e2m1 elements. TLLM_CHECK_ERROR(options.mSfBlockSizeA.value() == 16 || options.mSfBlockSizeA.value() == 32, - "SfBlockSizeA (", - options.mSfBlockSizeA.value(), - ") must be 16 or 32."); + "SfBlockSizeA (", options.mSfBlockSizeA.value(), ") must be 16 or 32."); } if (tg::dtypeIsBlockFmt(options.mDtypeA)) { int numEltsPerSfA = options.mSfBlockSizeA.value_or(tg::dtypeNumEltsPerSf(options.mDtypeA)); - TLLM_CHECK_ERROR(options.mTileK % (4 * numEltsPerSfA) == 0, - "TileK (", - options.mTileK, - ") must be a multiple of ", - (4 * numEltsPerSfA), - " for typeA ", + TLLM_CHECK_ERROR(options.mTileK % (4 * numEltsPerSfA) == 0, "TileK (", options.mTileK, + ") must be a multiple of ", (4 * numEltsPerSfA), " for typeA ", gemm::toString(options.mDtypeA)); auto const numEltsPerSfAInK = options.mK / numEltsPerSfA; - TLLM_CHECK_ERROR(numEltsPerSfAInK % 4 == 0, - "K dimension of scaling factors for A (", - numEltsPerSfAInK, - ") must be a multiple of 4"); + TLLM_CHECK_ERROR(numEltsPerSfAInK % 4 == 0, "K dimension of scaling factors for A (", + numEltsPerSfAInK, ") must be a multiple of 4"); } if (tg::dtypeIsBlockFmt(options.mDtypeB)) { TLLM_CHECK_ERROR(options.mSfLayoutB == tg::SfLayout::R128c4 || - options.mSfLayoutB == tg::SfLayout::R8c4 || - options.mSfLayoutB == tg::SfLayout::Linear, + options.mSfLayoutB == tg::SfLayout::R8c4 || + options.mSfLayoutB == tg::SfLayout::Linear, "Only the 128x4 and 8x4 SF layouts are supported for B, got ", tg::sfLayoutToString(options.mSfLayoutB)); // TileN must be a multiple of the number of rows per SF tile. int const numSfTileRowsB = options.mSfLayoutB == tg::SfLayout::R128c4 ? 
128 : 8; - TLLM_CHECK_ERROR(options.mTileN % numSfTileRowsB == 0, - "TileN (", - options.mTileN, - ") must be a multiple of ", - numSfTileRowsB, - " for B SF layout ", + TLLM_CHECK_ERROR(options.mTileN % numSfTileRowsB == 0, "TileN (", options.mTileN, + ") must be a multiple of ", numSfTileRowsB, " for B SF layout ", tg::sfLayoutToString(options.mSfLayoutB)); int numEltsPerSfB = tg::dtypeNumEltsPerSf(options.mDtypeB); - TLLM_CHECK_ERROR(options.mTileK % (4 * numEltsPerSfB) == 0, - "TileK (", - options.mTileK, - ") must be a multiple of ", - (4 * numEltsPerSfB), - " for typeB ", + TLLM_CHECK_ERROR(options.mTileK % (4 * numEltsPerSfB) == 0, "TileK (", options.mTileK, + ") must be a multiple of ", (4 * numEltsPerSfB), " for typeB ", gemm::toString(options.mDtypeB)); auto const numEltsPerSfBInK = options.mK / numEltsPerSfB; - TLLM_CHECK_ERROR(numEltsPerSfBInK % 4 == 0, - "K dimension of scaling factors for B (", - numEltsPerSfBInK, - ") must be a multiple of 4"); + TLLM_CHECK_ERROR(numEltsPerSfBInK % 4 == 0, "K dimension of scaling factors for B (", + numEltsPerSfBInK, ") must be a multiple of 4"); } int32_t padMultiplierA = 1; @@ -904,36 +809,30 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, padMultiplierB = 2; } } - TLLM_CHECK_ERROR((padMultiplierA * tg::dtypeGetNumBits(options.mDtypeA) * options.mK / 8) % 16 == - 0, - "K dimension of A must be aligned to 16 bytes."); - TLLM_CHECK_ERROR((padMultiplierB * tg::dtypeGetNumBits(options.mDtypeB) * options.mK / 8) % 16 == - 0, - "K dimension of B must be aligned to 16 bytes."); + TLLM_CHECK_ERROR( + (padMultiplierA * tg::dtypeGetNumBits(options.mDtypeA) * options.mK / 8) % 16 == 0, + "K dimension of A must be aligned to 16 bytes."); + TLLM_CHECK_ERROR( + (padMultiplierB * tg::dtypeGetNumBits(options.mDtypeB) * options.mK / 8) % 16 == 0, + "K dimension of B must be aligned to 16 bytes."); if (options.mDtypeC == tg::Dtype::E2m1 || options.mDtypeC == tg::Dtype::MxE4m3) { TLLM_CHECK_ERROR(isBlackwell, "Block scaling is only supported on Blackwell"); - TLLM_CHECK_ERROR(options.mSfLayoutC == tg::SfLayout::R128c4 || - options.mSfLayoutC == tg::SfLayout::R8c4, - "Only the 128x4 and 8x4 SF layouts are supported for C."); + TLLM_CHECK_ERROR( + options.mSfLayoutC == tg::SfLayout::R128c4 || options.mSfLayoutC == tg::SfLayout::R8c4, + "Only the 128x4 and 8x4 SF layouts are supported for C."); int const numSfTileRowsC = options.mSfLayoutC == tg::SfLayout::R128c4 ? 128 : 8; int const tileTokenDim = options.mTransposeMmaOutput ? options.mTileN : options.mTileM; TLLM_CHECK_ERROR_FMT(tileTokenDim % numSfTileRowsC == 0, "Tile%s (%d) must be a multiple of %d for C SF layout %s", - options.mTransposeMmaOutput ? "N" : "M", - tileTokenDim, - numSfTileRowsC, + options.mTransposeMmaOutput ? "N" : "M", tileTokenDim, numSfTileRowsC, tg::sfLayoutToString(options.mSfLayoutC).c_str()); int const hiddenDim = options.mTransposeMmaOutput ? 
options.mM : options.mN; int const hiddenGranularity = 4 * tg::dtypeNumEltsPerSf(options.mDtypeC); - TLLM_CHECK_ERROR(hiddenDim % hiddenGranularity == 0, - "Hidden dim (", - hiddenDim, - ") must be a multiple of ", - hiddenGranularity, - " for block-scaled outputs."); + TLLM_CHECK_ERROR(hiddenDim % hiddenGranularity == 0, "Hidden dim (", hiddenDim, + ") must be a multiple of ", hiddenGranularity, " for block-scaled outputs."); TLLM_CHECK_ERROR(!options.mTransposeMmaOutput || options.mUseShuffledMatrixA, "Transposing block-scaled outputs requires shuffled A."); } @@ -950,11 +849,8 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, // Set epilogue tile sizes to the output tile sizes, when epilogue tile sizes are incorrect. if (options.mTileM % options.mEpilogueTileM != 0) { - TLLM_LOG_WARNING("TileM (", - options.mTileM, - ") must be divisible by EpilogueTileM (", - options.mEpilogueTileM, - "). Setting EpilogueTileM to TileM"); + TLLM_LOG_WARNING("TileM (", options.mTileM, ") must be divisible by EpilogueTileM (", + options.mEpilogueTileM, "). Setting EpilogueTileM to TileM"); if (updateOptions) { options.mEpilogueTileM = options.mTileM; } else { @@ -963,11 +859,8 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, } if (options.mTileN % options.mEpilogueTileN != 0) { - TLLM_LOG_WARNING("TileN (", - options.mTileN, - ") must be divisible by EpilogueTileN (", - options.mEpilogueTileN, - "). Setting EpilogueTileN to TileN"); + TLLM_LOG_WARNING("TileN (", options.mTileN, ") must be divisible by EpilogueTileN (", + options.mEpilogueTileN, "). Setting EpilogueTileN to TileN"); if (updateOptions) { options.mEpilogueTileN = options.mTileN; } else { @@ -979,7 +872,7 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, if (!isBlackwell && (options.mEpilogueTileM != options.mTileM || options.mEpilogueTileN != options.mTileN)) { TLLM_LOG_WARNING( - "Overwriting epilogueTileM and epilogueTileN to match tileM and tileN respectively"); + "Overwriting epilogueTileM and epilogueTileN to match tileM and tileN respectively"); if (updateOptions) { options.mEpilogueTileM = options.mTileM; options.mEpilogueTileN = options.mTileN; @@ -991,7 +884,7 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, // Unsupported epilogue tile size. if (options.mMmaM == 128 && options.mEpilogueTileM != options.mTileM) { TLLM_LOG_WARNING( - "When MmaM = 128, EpilogueTileM must be equal to TileM. Setting EpilogueTileM to TileM"); + "When MmaM = 128, EpilogueTileM must be equal to TileM. 
Setting EpilogueTileM to TileM"); if (updateOptions) { options.mEpilogueTileM = options.mTileM; } else { @@ -1006,8 +899,7 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, if (options.mUseShuffledMatrixA) { auto const shuffleBlockSize = getShuffleBlockSize(options.mEpilogueTileM); TLLM_CHECK_ERROR(options.mM % shuffleBlockSize == 0, - "M must be a multiple of shuffle block size (", - shuffleBlockSize, + "M must be a multiple of shuffle block size (", shuffleBlockSize, ") when useShuffledMatrixA"); } @@ -1027,11 +919,11 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, } TLLM_CHECK_ERROR( - options.mTileM % options.mEpilogueTileM == 0 && options.mTileN % options.mEpilogueTileN == 0, - "TileM and TileN must be divisible by EpilogueTileM and EpilogueTileN respectively."); - TLLM_CHECK_ERROR((options.mClusterDimX == 1 || options.mClusterDimX == 2) && - options.mClusterDimY == 1, - "GEMM does not support cluster in X and Y dimensions."); + options.mTileM % options.mEpilogueTileM == 0 && options.mTileN % options.mEpilogueTileN == 0, + "TileM and TileN must be divisible by EpilogueTileM and EpilogueTileN respectively."); + TLLM_CHECK_ERROR( + (options.mClusterDimX == 1 || options.mClusterDimX == 2) && options.mClusterDimY == 1, + "GEMM does not support cluster in X and Y dimensions."); TLLM_CHECK_ERROR(options.mClusterDimZ == 1 || options.mNumSlicesForSplitK > 1, "Cluster DimZ is only allowed for split-k."); TLLM_CHECK_ERROR(options.mTileM <= 128, "GEMM does not support TileM > 128."); @@ -1046,8 +938,8 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, if (options.mUseShuffledMatrixA) { // TODO add matrix shuffle for N-major epilogue. TLLM_CHECK_ERROR( - options.mTransposeMmaOutput, - "Shuffled matrix A is only supported with M-major epilogue. Set -transposeMmaOutput"); + options.mTransposeMmaOutput, + "Shuffled matrix A is only supported with M-major epilogue. Set -transposeMmaOutput"); } // Check all-reduce options. @@ -1057,24 +949,18 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, // // See: https://docs.nvidia.com/cuda/parallel-thread-execution/ // #data-movement-and-conversion-instructions-cp-reduce-async-bulk-tensor - std::set dtypeSupported{tg::Dtype::UInt32, - tg::Dtype::Int32, - tg::Dtype::UInt64, - tg::Dtype::Fp32, - tg::Dtype::Fp16, - tg::Dtype::Bfloat16}; + std::set dtypeSupported{tg::Dtype::UInt32, tg::Dtype::Int32, tg::Dtype::UInt64, + tg::Dtype::Fp32, tg::Dtype::Fp16, tg::Dtype::Bfloat16}; TLLM_CHECK_ERROR(dtypeSupported.find(options.mDtypeC) != dtypeSupported.end(), - "Unsupported output dtype ", - tg::dtypeToString(options.mDtypeC)); + "Unsupported output dtype ", tg::dtypeToString(options.mDtypeC)); } else if (options.mAllReduceAlgo == AllReduceAlgo::TwoShot) { // TODO(anchengc): // Input dtype == output dtype -> can perform all-reduce in-place. // Input dtype != output dtype -> must perform all-reduce out of place. 
TLLM_CHECK_ERROR_FMT( - options.mDtypeC == options.mDtypeAcc, - "Not implemented - mixed dtype (dtypeC (%s) != dtypeAcc (%s)) requires out of place update", - tg::dtypeToString(options.mDtypeC).c_str(), - tg::dtypeToString(options.mDtypeAcc).c_str()); + options.mDtypeC == options.mDtypeAcc, + "Not implemented - mixed dtype (dtypeC (%s) != dtypeAcc (%s)) requires out of place update", + tg::dtypeToString(options.mDtypeC).c_str(), tg::dtypeToString(options.mDtypeAcc).c_str()); } if (options.mAllReduceAlgo != AllReduceAlgo::None) { TLLM_CHECK_ERROR(options.mUseTmaStore, "Non-TMA store with all-reduce is not implemented"); @@ -1102,7 +988,7 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, if ((options.mEpilogueTileM != options.mTileM || options.mEpilogueTileN != options.mTileN) && !options.mUseDeepSeekFp8) { TLLM_LOG_WARNING( - "Overwriting epilogueTileM and epilogueTileN to match tileM and tileN respectively"); + "Overwriting epilogueTileM and epilogueTileN to match tileM and tileN respectively"); if (updateOptions) { options.mEpilogueTileM = options.mTileM; options.mEpilogueTileN = options.mTileN; @@ -1124,40 +1010,37 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, if (updateOptions) { options.mNumStagesMmaAcrossWorkTile = std::min(2, options.mNumStagesMma); options.mNumStagesMmaWithinWorkTile = - options.mNumStagesMma / options.mNumStagesMmaAcrossWorkTile; + options.mNumStagesMma / options.mNumStagesMmaAcrossWorkTile; } else { return false; } } else if (options.mNumStagesMmaWithinWorkTile == -1) { if (updateOptions) { options.mNumStagesMmaWithinWorkTile = - options.mNumStagesMma / options.mNumStagesMmaAcrossWorkTile; + options.mNumStagesMma / options.mNumStagesMmaAcrossWorkTile; } else { return false; } } else if (options.mNumStagesMmaAcrossWorkTile == -1) { if (updateOptions) { options.mNumStagesMmaAcrossWorkTile = - options.mNumStagesMma / options.mNumStagesMmaWithinWorkTile; + options.mNumStagesMma / options.mNumStagesMmaWithinWorkTile; } else { return false; } } // Check mma stages. TLLM_CHECK_ERROR_FMT(options.mNumStagesMmaWithinWorkTile * options.mNumStagesMmaAcrossWorkTile == - options.mNumStagesMma && - options.mNumStagesMmaAcrossWorkTile <= 2, + options.mNumStagesMma && + options.mNumStagesMmaAcrossWorkTile <= 2, "Condition numStagesMmaWithinWorkTile (%d) * numStagesMmaAcrossWorkTile " "(%d) == numStagesMma (%d) && numStagesMmaAcrossWorkTile (%d) <= 2 must be " "satisfied. Check arguments.", - options.mNumStagesMmaWithinWorkTile, - options.mNumStagesMmaAcrossWorkTile, - options.mNumStagesMma, - options.mNumStagesMmaAcrossWorkTile); + options.mNumStagesMmaWithinWorkTile, options.mNumStagesMmaAcrossWorkTile, + options.mNumStagesMma, options.mNumStagesMmaAcrossWorkTile); // Mma stage must be 1 for pre-Hopper. TLLM_CHECK_ERROR(isBlackwell || options.mNumStagesMma == 1, - "Mma stage must be 1 for pre-Hopper. Found ", - options.mNumStagesMma); + "Mma stage must be 1 for pre-Hopper. Found ", options.mNumStagesMma); // DeepSeek Fp8 if (!options.mUseDeepSeekFp8) { TLLM_CHECK_ERROR(options.mNumStagesMmaWithinWorkTile == 1, @@ -1174,8 +1057,7 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, TLLM_CHECK_ERROR(options.mDtypeA == tg::Dtype::E4m3 && options.mDtypeB == tg::Dtype::E4m3, "A and B dtype must be E4m3 for DeepSeek Fp8. 
Found dtypeA=", tg::dtypeToString(options.mDtypeA), - " dtypeB=", - tg::dtypeToString(options.mDtypeB)); + " dtypeB=", tg::dtypeToString(options.mDtypeB)); TLLM_CHECK_ERROR(isBlackwell, "DeepSeek Fp8 is not supported for Hopper"); TLLM_CHECK_ERROR(options.mAllReduceAlgo == AllReduceAlgo::None, @@ -1187,7 +1069,7 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, // Tile sizes of the output hidden dimension. auto hiddenDimPerOutputTile = options.mTransposeMmaOutput ? options.mTileM : options.mTileN; auto hiddenDimPerEpilogueTile = - options.mTransposeMmaOutput ? options.mEpilogueTileM : options.mEpilogueTileN; + options.mTransposeMmaOutput ? options.mEpilogueTileM : options.mEpilogueTileN; auto hiddenDimPerMma = options.mTransposeMmaOutput ? options.mMmaM : options.mMmaN; auto hiddenDimName = options.mTransposeMmaOutput ? "M" : "N"; TLLM_CHECK_WARNING(options.mNumStagesMmaWithinWorkTile > 1, @@ -1205,26 +1087,14 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, // Check that the output tile N can be processed with the epilogue tile granularity. TLLM_CHECK_ERROR((hiddenDimPerOutputTile / 2) % hiddenDimPerEpilogueTile == 0, - "DeepSeek Fp8 requires Tile", - hiddenDimName, - " / 2 (", - hiddenDimPerOutputTile / 2, - ") being a multiple of EpilogueTile", - hiddenDimName, - " (", - hiddenDimPerEpilogueTile, - ")"); + "DeepSeek Fp8 requires Tile", hiddenDimName, " / 2 (", + hiddenDimPerOutputTile / 2, ") being a multiple of EpilogueTile", + hiddenDimName, " (", hiddenDimPerEpilogueTile, ")"); // Check that the output tile N can be processed with the epilogue tile granularity. TLLM_CHECK_ERROR((hiddenDimPerOutputTile / 2) % hiddenDimPerMma == 0, - "DeepSeek Fp8 requires Tile", - hiddenDimName, - " / 2 (", - hiddenDimPerOutputTile / 2, - ") being a multiple of mma", - hiddenDimName, - " (", - hiddenDimPerMma, - ")"); + "DeepSeek Fp8 requires Tile", hiddenDimName, " / 2 (", + hiddenDimPerOutputTile / 2, ") being a multiple of mma", hiddenDimName, " (", + hiddenDimPerMma, ")"); } if (options.mSliceK) { @@ -1259,13 +1129,9 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, return false; } } - TLLM_CHECK_ERROR((options.mTileK / options.mMmaK) % options.mNumSlicesForSliceK == 0, - "TileK (", - options.mTileK, - ") / MmaK (", - options.mMmaK, - ") must be a multiple of mNumSlicesForSliceK (", - options.mNumSlicesForSliceK, + TLLM_CHECK_ERROR((options.mTileK / options.mMmaK) % options.mNumSlicesForSliceK == 0, "TileK (", + options.mTileK, ") / MmaK (", options.mMmaK, + ") must be a multiple of mNumSlicesForSliceK (", options.mNumSlicesForSliceK, ")"); } @@ -1295,15 +1161,9 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, (clampedAndPaddedPerCtaK % (options.mTileK * 2) != 0); if (notSupported) { TLLM_LOG_WARNING("Size K / splitK must be a multiple of TileK * 2. Found TileK=", - options.mTileK, - " and K=", - options.mK, - " (paddedK=", - paddedK, - " clampedAndPaddedPerCtaK=", - clampedAndPaddedPerCtaK, - ") and numSlicesForSplitK=", - options.mNumSlicesForSplitK, + options.mTileK, " and K=", options.mK, " (paddedK=", paddedK, + " clampedAndPaddedPerCtaK=", clampedAndPaddedPerCtaK, + ") and numSlicesForSplitK=", options.mNumSlicesForSplitK, ". 
Disabling unrollLoop2xForMma."); if (updateOptions) { options.mUseUnrollLoop2xForMma = false; @@ -1314,8 +1174,8 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, } if (options.mNumSlicesForSplitK > 1) { TLLM_CHECK_ERROR( - perCtaK * (options.mNumSlicesForSplitK - 1) < options.mK, - "K must be greater than perCtaK * (numSlicesForSplitK - 1) to ensure each CTA has work"); + perCtaK * (options.mNumSlicesForSplitK - 1) < options.mK, + "K must be greater than perCtaK * (numSlicesForSplitK - 1) to ensure each CTA has work"); } if (!isBlackwell && options.mTileScheduler == TileScheduler::Persistent) { @@ -1329,8 +1189,9 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, } if (options.mEnablesDelayedEarlyExit && options.mEnablesEarlyExit) { - TLLM_LOG_WARNING("Only one of early exit and delayed early exit should be enabled. Disabling " - "delayed early exit"); + TLLM_LOG_WARNING( + "Only one of early exit and delayed early exit should be enabled. Disabling " + "delayed early exit"); if (updateOptions) { options.mEnablesDelayedEarlyExit = false; } else { @@ -1357,11 +1218,11 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, // visible- Warp 2: -------------------ACQBULK-- Kernel 1,2 output visible // ---------- TLLM_CHECK_ERROR( - (options.mGridWaitForPrimaryA || !options.mGridTriggerSecondaryA), - "A: If a task triggers a secondary kernel, it must also wait for primary kernel."); + (options.mGridWaitForPrimaryA || !options.mGridTriggerSecondaryA), + "A: If a task triggers a secondary kernel, it must also wait for primary kernel."); TLLM_CHECK_ERROR( - (options.mGridWaitForPrimaryB || !options.mGridTriggerSecondaryB), - "B: If a task triggers a secondary kernel, it must also wait for primary kernel."); + (options.mGridWaitForPrimaryB || !options.mGridTriggerSecondaryB), + "B: If a task triggers a secondary kernel, it must also wait for primary kernel."); if (options.mUsePerTokenSfA || options.mUsePerTokenSfB) { // Checks applicable to both MetaFP8 and RoutingScalesOnInput @@ -1373,21 +1234,20 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, TLLM_CHECK_ERROR(options.mDtypeA == tg::Dtype::E4m3 && options.mDtypeB == tg::Dtype::E4m3, "A and B dtype must be E4m3 for Meta Fp8. 
Found dtypeA=", tg::dtypeToString(options.mDtypeA), - " dtypeB=", - tg::dtypeToString(options.mDtypeB)); + " dtypeB=", tg::dtypeToString(options.mDtypeB)); } else { // RoutingScalesOnInput case TLLM_CHECK_ERROR((options.mUsePerTokenSfA && !options.mTransposeMmaOutput) || - (options.mUsePerTokenSfB && options.mTransposeMmaOutput), + (options.mUsePerTokenSfB && options.mTransposeMmaOutput), "In RoutingScalesOnInput mode, perToken scales must be used on activations"); } } // The generation should support non K-major layouts for both A and B; however, it is unclear if // there is a use-case - TLLM_CHECK_ERROR((options.mLayoutA == MatrixLayout::MajorK) || - (options.mLayoutB == MatrixLayout::MajorK), - "At least one matrix must be in k-major layout"); + TLLM_CHECK_ERROR( + (options.mLayoutA == MatrixLayout::MajorK) || (options.mLayoutB == MatrixLayout::MajorK), + "At least one matrix must be in k-major layout"); // Some features are currently only support when both matrices are in K-major format if (options.mLayoutB != MatrixLayout::MajorK || options.mLayoutB != MatrixLayout::MajorK) { @@ -1411,7 +1271,7 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, // TODO Leaving this as an option for now in case we want to expertiment with other block sizes // As the user is not expected to set this, do not fail if updateOptions is false int32_t const elemSizeInBits = - (isBlockA) ? tg::dtypeGetNumBits(options.mDtypeA) : tg::dtypeGetNumBits(options.mDtypeB); + (isBlockA) ? tg::dtypeGetNumBits(options.mDtypeA) : tg::dtypeGetNumBits(options.mDtypeB); int32_t const elemsIn128B = 128 * 8 /* Bits in byte */ / elemSizeInBits; if (options.mBlockK != elemsIn128B) { @@ -1424,12 +1284,12 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, if (options.mBlockK > options.mTileK) { TLLM_CHECK_ERROR( - options.mBlockK % options.mTileK == 0, - "If block size is greater than tile size, block size must be a multiple of tile size"); + options.mBlockK % options.mTileK == 0, + "If block size is greater than tile size, block size must be a multiple of tile size"); } else if (options.mBlockK < options.mTileK) { TLLM_CHECK_ERROR( - options.mTileK % options.mBlockK == 0, - "If tile size is greater than block size, tile size must be a multiple of block size"); + options.mTileK % options.mBlockK == 0, + "If tile size is greater than block size, tile size must be a multiple of block size"); } } @@ -1442,33 +1302,15 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, if (updateOptions) { // Init kernel traits. 
- options.mKernelTraits = KernelTraits(options.mDtypeA, - options.mDtypeB, - options.mDtypeC, - options.mDtypeAcc, - options.mDtypeMmaA, - options.mDtypeMmaB, - options.mMmaKind, - options.mMmaK, - options.mTileM, - options.mTileN, - options.mTileK, - options.mEpilogueTileM, - options.mEpilogueTileN, - options.mNumStages, - options.mNumStagesMma, - options.mNumSlicesForSplitK, - options.mNumSlicesForSliceK, - options.mSplitK, - options.mUseTmaStore, - options.mTransposeMmaOutput, - options.mAllReduceAlgo, - options.mTileScheduler == TileScheduler::Persistent, - options.mUseDeepSeekFp8, - options.mUsePerTokenSfA, - options.mUsePerTokenSfB, - /* useTwoCtas*/ options.mClusterDimX == 2, - options.mBiasType); + options.mKernelTraits = KernelTraits( + options.mDtypeA, options.mDtypeB, options.mDtypeC, options.mDtypeAcc, options.mDtypeMmaA, + options.mDtypeMmaB, options.mMmaKind, options.mMmaK, options.mTileM, options.mTileN, + options.mTileK, options.mEpilogueTileM, options.mEpilogueTileN, options.mNumStages, + options.mNumStagesMma, options.mNumSlicesForSplitK, options.mNumSlicesForSliceK, + options.mSplitK, options.mUseTmaStore, options.mTransposeMmaOutput, options.mAllReduceAlgo, + options.mTileScheduler == TileScheduler::Persistent, options.mUseDeepSeekFp8, + options.mUsePerTokenSfA, options.mUsePerTokenSfB, + /* useTwoCtas*/ options.mClusterDimX == 2, options.mBiasType); } return true; @@ -1476,7 +1318,7 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace gemm +} // namespace gemm #ifdef TLLM_GEN_EXPORT_INTERFACE @@ -1487,6 +1329,6 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, #undef TLLM_LOG_INFO #undef TLLM_LOG_ERROR -#endif // TLLM_GEN_EXPORT_INTERFACE +#endif // TLLM_GEN_EXPORT_INTERFACE -} // namespace batchedGemm +} // namespace batchedGemm diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h index 03edfb149e..9cb615c750 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h @@ -1,28 +1,28 @@ /* -* SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & -* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once #include -#include "trtllm/gen/SfLayoutDecl.h" -#include "trtllm/gen/CommonUtils.h" -#include "TmaDescriptor.h" -#include "Enums.h" #include "BatchedGemmEnums.h" +#include "Enums.h" +#include "TmaDescriptor.h" +#include "trtllm/gen/CommonUtils.h" +#include "trtllm/gen/SfLayoutDecl.h" // NOTE: keep this code dependency free. It has to be included by the device code and has to be // compilable with NVRTC. @@ -34,7 +34,8 @@ namespace batchedGemm { //////////////////////////////////////////////////////////////////////////////////////////////////// // TODO: Find a better header to put this in, that we can include from here. -template inline T ceilDiv(T m, T n) { +template +inline T ceilDiv(T m, T n) { return (m + n - T(1)) / n; } @@ -55,21 +56,24 @@ enum class MatrixType { MatrixA = 0, MatrixB, MatrixC }; // ////////////////////////////////////////////////////////////////////////////////////////////////// -template bool useTmaOobOptA(BatchedGemmOptions const& options) { +template +bool useTmaOobOptA(BatchedGemmOptions const& options) { return options.mBatchMode == BatchedGemmOptions::BatchMode::BatchM && doesRouteImplUseNoRoute(options.mRouteImpl) && options.mUseTmaOobOpt; } ////////////////////////////////////////////////////////////////////////////////////////////////// -template bool useTmaOobOptB(BatchedGemmOptions const& options) { +template +bool useTmaOobOptB(BatchedGemmOptions const& options) { return options.mBatchMode == BatchedGemmOptions::BatchMode::BatchN && doesRouteImplUseNoRoute(options.mRouteImpl) && options.mUseTmaOobOpt; } ////////////////////////////////////////////////////////////////////////////////////////////////// -template bool useTmaOobOptC(BatchedGemmOptions const& options) { +template +bool useTmaOobOptC(BatchedGemmOptions const& options) { return options.mUseTmaStore && options.mUseTmaOobOpt; } @@ -77,14 +81,8 @@ template bool useTmaOobOptC(BatchedGemmOptions con // Create the TMA shape/stride for A/B/C. template -static auto makeTmaShapeStrideAbc(GemmOptions const& options, - int mM, - int mN, - int mK, - int tileM, - int tileN, - int tileK, - MatrixType matrixType) { +static auto makeTmaShapeStrideAbc(GemmOptions const& options, int mM, int mN, int mK, int tileM, + int tileN, int tileK, MatrixType matrixType) { // Weights matrix is A if we transpose the output of MMA (to have it M-major). // Otherwise, it is B, when the output of MMA is K-major. bool const isWeights = (matrixType == MatrixType::MatrixA && options.mTransposeMmaOutput) || @@ -99,13 +97,13 @@ static auto makeTmaShapeStrideAbc(GemmOptions const& options, // The outer dimension. auto numTokens = - (matrixType == MatrixType::MatrixA || matrixType == MatrixType::MatrixC) ? mM : mN; + (matrixType == MatrixType::MatrixA || matrixType == MatrixType::MatrixC) ? mM : mN; // The outer dimension tile size. auto ctaTileNumTokens = - (matrixType == MatrixType::MatrixA || matrixType == MatrixType::MatrixC) ? tileM : tileN; + (matrixType == MatrixType::MatrixA || matrixType == MatrixType::MatrixC) ? tileM : tileN; // The outer dimension of TMA box shape. 
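Aside (illustrative, not part of the patch): ceilDiv above, together with the roundUp helper defined in trtllm/gen/CommonUtils.h later in this patch, carries most of the padding arithmetic in these headers. Standalone copies with a few concrete values:

#include <cassert>

template <class T>
inline T ceilDiv(T m, T n) { return (m + n - T(1)) / n; }

template <class T>
inline T roundUp(T m, T n) { return ceilDiv(m, n) * n; }

int main() {
  assert(ceilDiv(100, 64) == 2);    // 100 tokens need two 64-token tiles
  assert(roundUp(100, 64) == 128);  // ...and are padded out to 128 tokens
  assert(roundUp(128, 64) == 128);  // already-aligned sizes are unchanged
  return 0;
}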
auto tileNumTokens = - (matrixType == MatrixType::MatrixC) ? options.mEpilogueTileM : ctaTileNumTokens; + (matrixType == MatrixType::MatrixC) ? options.mEpilogueTileM : ctaTileNumTokens; // The inner dimension. auto hiddenSize = (matrixType == MatrixType::MatrixC) ? mN : mK; @@ -113,7 +111,7 @@ static auto makeTmaShapeStrideAbc(GemmOptions const& options, auto ctaTileHiddenSize = (matrixType == MatrixType::MatrixC) ? tileN : tileK; // The inner dimension of TMA box shape. auto tileHiddenSize = - (matrixType == MatrixType::MatrixC) ? options.mEpilogueTileN : ctaTileHiddenSize; + (matrixType == MatrixType::MatrixC) ? options.mEpilogueTileN : ctaTileHiddenSize; // Swap matrix C sizes if output is transposed. if (matrixType == MatrixType::MatrixC && options.mTransposeMmaOutput) { @@ -141,14 +139,11 @@ static auto makeTmaShapeStrideAbc(GemmOptions const& options, // If TMA OOB optimization is used: // Shape [hidden, tokens] Stride [1, hidden] becomes // Shape [hidden, tileN, TmaDimMax, TmaDimMax] Stride [1, hidden, XLargeN - hidden, hidden] - shape = {static_cast(hiddenSize), - static_cast(ctaTileNumTokens), - static_cast(tg::TmaDimMax), - static_cast(tg::TmaDimMax)}; + shape = {static_cast(hiddenSize), static_cast(ctaTileNumTokens), + static_cast(tg::TmaDimMax), static_cast(tg::TmaDimMax)}; } else if (isWeights) { // If the matrix is a weights matrix, we use 3D logical shape (B, M, K) or (B, N, K). - shape = {static_cast(hiddenSize), - static_cast(numTokens), + shape = {static_cast(hiddenSize), static_cast(numTokens), static_cast(options.mNumBatches)}; } @@ -156,13 +151,10 @@ static auto makeTmaShapeStrideAbc(GemmOptions const& options, // Swap the first two dimension as mentioned before. std::vector stride = {1, static_cast(hiddenSize)}; if (useTmaOobOpt) { - stride = {1, - static_cast(hiddenSize), - static_cast(tg::XLargeN - hiddenSize), + stride = {1, static_cast(hiddenSize), static_cast(tg::XLargeN - hiddenSize), static_cast(hiddenSize)}; } else if (isWeights) { - stride = {1, - static_cast(hiddenSize), + stride = {1, static_cast(hiddenSize), static_cast(hiddenSize) * static_cast(numTokens)}; } @@ -172,7 +164,7 @@ static auto makeTmaShapeStrideAbc(GemmOptions const& options, // Alternate layouts (MajorMn and BlockMajorK) do not apply to matrixC if (matrixType != MatrixType::MatrixC) { gemm::MatrixLayout layout = - (matrixType == MatrixType::MatrixA) ? options.mLayoutA : options.mLayoutB; + (matrixType == MatrixType::MatrixA) ? options.mLayoutA : options.mLayoutB; // Note, only the weights support non MajorK layouts if (layout == gemm::MatrixLayout::MajorMn) { // Apply transpose if necessary @@ -181,12 +173,10 @@ static auto makeTmaShapeStrideAbc(GemmOptions const& options, std::swap(tileShape[0], tileShape[1]); } else if (layout == gemm::MatrixLayout::BlockMajorK) { // Set shapes based on blocking layout - shape = {static_cast(options.mBlockK), - static_cast(numTokens), + shape = {static_cast(options.mBlockK), static_cast(numTokens), static_cast(mK / options.mBlockK), static_cast(options.mNumBatches)}; - stride = {1, - static_cast(options.mBlockK), + stride = {1, static_cast(options.mBlockK), static_cast(numTokens * options.mBlockK), static_cast(hiddenSize * numTokens)}; @@ -200,17 +190,9 @@ static auto makeTmaShapeStrideAbc(GemmOptions const& options, } // Create the TMA shape/stride for A/B block scaling factors. 
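A simplified standalone sketch (not part of the patch, all sizes made up) of the shape/stride bookkeeping above for the two common cases, a K-major activation tensor and a BlockMajorK weights tensor; the TMA-OOB and MajorMn branches are omitted. Dimensions are listed innermost first, as the surrounding code arranges them.

#include <cstdint>
#include <vector>
#include <cassert>

struct ShapeStride {
  std::vector<uint64_t> shape;
  std::vector<uint64_t> stride;
};

// K-major activations: logical [tokens, hidden] exposed as shape {hidden, tokens}.
ShapeStride kMajorActivation(int numTokens, int hiddenSize) {
  return {{(uint64_t)hiddenSize, (uint64_t)numTokens},
          {1, (uint64_t)hiddenSize}};
}

// BlockMajorK weights: K is folded into blocks of blockK contiguous elements.
ShapeStride blockMajorKWeights(int numTokens, int hiddenSize, int blockK, int numBatches) {
  return {{(uint64_t)blockK, (uint64_t)numTokens,
           (uint64_t)(hiddenSize / blockK), (uint64_t)numBatches},
          {1, (uint64_t)blockK,
           (uint64_t)numTokens * blockK, (uint64_t)hiddenSize * numTokens}};
}

int main() {
  auto act = kMajorActivation(/*numTokens=*/512, /*hiddenSize=*/4096);
  assert(act.stride[1] == 4096);  // adjacent tokens are hiddenSize elements apart
  auto w = blockMajorKWeights(/*numTokens=*/2048, /*hiddenSize=*/4096,
                              /*blockK=*/128, /*numBatches=*/8);
  assert(w.shape[2] == 32 && w.stride[2] == 2048u * 128);
  return 0;
}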
-static auto makeTmaShapeStrideSfAb(int mM, - int mN, - int mK, - MatrixType matrixType, - int tileM, - int tileN, - int tileK, - tg::SfLayout layout, - int sfReshapeFactor, +static auto makeTmaShapeStrideSfAb(int mM, int mN, int mK, MatrixType matrixType, int tileM, + int tileN, int tileK, tg::SfLayout layout, int sfReshapeFactor, const int32_t numEltsPerSf) { - // The outer dimension. auto numTokens = matrixType == MatrixType::MatrixA ? mM : mN; // The inner dimension. @@ -221,121 +203,102 @@ static auto makeTmaShapeStrideSfAb(int mM, auto hiddenSizePerTile = tileK; switch (layout) { - case tg::SfLayout::R128c4: { - // The scaling factor tensor packs 128x4 tiles into contiguous 512B blocks. - // The 512B block maps to a 32x16B (32x128b) block in TMEM. - // See https://nvbugspro.nvidia.com/bug/4165523 - // - // Additionally, we have to meet constraints of TMA that the box dimensions are less - // than 256 and boxDim[0] is a multiple of 16B. - // - // The "logical" tensor is: [outer, inner / numEltsPerSf] - // The aforementioned format is: [outer / 128, inner / numEltsPerSf / 4, 512] - // The shape we use for TMA is: [outer / 128, inner / numEltsPerSf / 4, 2, 256] - - auto shape = std::vector{256, - 2, - static_cast(ceilDiv(hiddenSize, numEltsPerSf * 4)), - static_cast(ceilDiv(numTokens, 128))}; - - std::vector stride(shape.size()); - stride[0] = 1; - for (size_t i = 1; i < shape.size(); i++) { - stride[i] = shape[i - 1] * stride[i - 1]; - } + case tg::SfLayout::R128c4: { + // The scaling factor tensor packs 128x4 tiles into contiguous 512B blocks. + // The 512B block maps to a 32x16B (32x128b) block in TMEM. + // See https://nvbugspro.nvidia.com/bug/4165523 + // + // Additionally, we have to meet constraints of TMA that the box dimensions are less + // than 256 and boxDim[0] is a multiple of 16B. + // + // The "logical" tensor is: [outer, inner / numEltsPerSf] + // The aforementioned format is: [outer / 128, inner / numEltsPerSf / 4, 512] + // The shape we use for TMA is: [outer / 128, inner / numEltsPerSf / 4, 2, 256] + + auto shape = std::vector{ + 256, 2, static_cast(ceilDiv(hiddenSize, numEltsPerSf * 4)), + static_cast(ceilDiv(numTokens, 128))}; + + std::vector stride(shape.size()); + stride[0] = 1; + for (size_t i = 1; i < shape.size(); i++) { + stride[i] = shape[i - 1] * stride[i - 1]; + } - auto tileShapes = - std::vector{256, - 2, - static_cast(ceilDiv(hiddenSizePerTile, numEltsPerSf * 4)), - static_cast(ceilDiv(numTokensPerTile, 128))}; + auto tileShapes = std::vector{ + 256, 2, static_cast(ceilDiv(hiddenSizePerTile, numEltsPerSf * 4)), + static_cast(ceilDiv(numTokensPerTile, 128))}; - return std::make_tuple(shape, stride, tileShapes); - } - - case tg::SfLayout::R8c4: { - // The scaling factor tensor packs 8x4 tiles into contiguous 32B blocks. - // - // As the inner dimension (k) is often a multiple of the tile size, we can reshape to use - // fewer read requests, if the tile dimensions allow. It does not reduce the number of - // instructions. - // - // I.e., let's define r = min(⌈hiddenSizePerTile / (numEltsPerSf * 4)⌉, 8) - // - // The "logical" tensor is: [outer, inner / numEltsPerSf] - // The 8x4 SF layout is: [⌈outer / 8⌉, inner / (4 * numEltsPerSf), 32] - // The TMA tensor shape is: [⌈outer / 8⌉, inner / (4 * numEltsPerSf * r), r * 32] - // - // The caveat of NumRepeats>1 is we must pad the hidden dimension of SF to multiples of - // NumRepeats * numEltsPerSf * 4. - - // Detect if the supplied factor is power of 2. E.g., 0b0100 and (0b0100 - 1) == 0b0000. 
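A worked example (arbitrary sizes, not part of the patch) for the R128c4 scaling-factor layout described above, showing how the 4D TMA shape and its cumulative strides come out:

#include <cstdint>
#include <vector>
#include <cassert>

int main() {
  auto ceilDiv = [](int64_t m, int64_t n) { return (m + n - 1) / n; };
  int64_t numTokens = 256, hiddenSize = 7168, numEltsPerSf = 16;

  // Shape {256, 2, K / (numEltsPerSf * 4), tokens / 128}, as in the R128c4 case.
  std::vector<uint64_t> shape{256, 2,
                              (uint64_t)ceilDiv(hiddenSize, numEltsPerSf * 4),
                              (uint64_t)ceilDiv(numTokens, 128)};
  std::vector<uint64_t> stride(shape.size());
  stride[0] = 1;
  for (size_t i = 1; i < shape.size(); ++i) stride[i] = shape[i - 1] * stride[i - 1];

  assert(shape[2] == 112 && shape[3] == 2);   // 7168 / 64 and 256 / 128
  assert(stride[1] == 256 && stride[2] == 512 && stride[3] == 57344);
  return 0;
}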
- int const r = sfReshapeFactor; - if (r > 0 && (r & (r - 1)) != 0) { - throw std::runtime_error("mSfReshapeFactor must be positive and a power of 2. Found " + - std::to_string(r)); + return std::make_tuple(shape, stride, tileShapes); } - // Sanitize number of repeats so it doesn't exceed the dimension. - int const repeats = std::min(ceilDiv(hiddenSizePerTile, numEltsPerSf * 4), r); + case tg::SfLayout::R8c4: { + // The scaling factor tensor packs 8x4 tiles into contiguous 32B blocks. + // + // As the inner dimension (k) is often a multiple of the tile size, we can reshape to use + // fewer read requests, if the tile dimensions allow. It does not reduce the number of + // instructions. + // + // I.e., let's define r = min(⌈hiddenSizePerTile / (numEltsPerSf * 4)⌉, 8) + // + // The "logical" tensor is: [outer, inner / numEltsPerSf] + // The 8x4 SF layout is: [⌈outer / 8⌉, inner / (4 * numEltsPerSf), 32] + // The TMA tensor shape is: [⌈outer / 8⌉, inner / (4 * numEltsPerSf * r), r * 32] + // + // The caveat of NumRepeats>1 is we must pad the hidden dimension of SF to multiples of + // NumRepeats * numEltsPerSf * 4. + + // Detect if the supplied factor is power of 2. E.g., 0b0100 and (0b0100 - 1) == 0b0000. + int const r = sfReshapeFactor; + if (r > 0 && (r & (r - 1)) != 0) { + throw std::runtime_error("mSfReshapeFactor must be positive and a power of 2. Found " + + std::to_string(r)); + } - // Detect if the input hidden size K is a multiple of the repeats. - if (ceilDiv(hiddenSize, numEltsPerSf * 4) % repeats != 0) { - throw std::runtime_error("SF hiddenSize K (" + - std::to_string(ceilDiv(hiddenSize, numEltsPerSf * 4)) + - ") must be a multiple of repeats (" + std::to_string(repeats) + ")"); - } + // Sanitize number of repeats so it doesn't exceed the dimension. + int const repeats = std::min(ceilDiv(hiddenSizePerTile, numEltsPerSf * 4), r); - auto shape = - std::vector{static_cast(repeats * 32), - static_cast(ceilDiv(hiddenSize, numEltsPerSf * 4 * repeats)), - static_cast(ceilDiv(numTokens, 8))}; + // Detect if the input hidden size K is a multiple of the repeats. 
+ if (ceilDiv(hiddenSize, numEltsPerSf * 4) % repeats != 0) { + throw std::runtime_error( + "SF hiddenSize K (" + std::to_string(ceilDiv(hiddenSize, numEltsPerSf * 4)) + + ") must be a multiple of repeats (" + std::to_string(repeats) + ")"); + } - std::vector stride(shape.size()); - stride[0] = 1; - for (size_t i = 1; i < shape.size(); i++) { - stride[i] = shape[i - 1] * stride[i - 1]; - } + auto shape = std::vector{ + static_cast(repeats * 32), + static_cast(ceilDiv(hiddenSize, numEltsPerSf * 4 * repeats)), + static_cast(ceilDiv(numTokens, 8))}; - auto tileShapes = std::vector{ - static_cast(repeats * 32), - static_cast(ceilDiv(hiddenSizePerTile, numEltsPerSf * 4 * repeats)), - static_cast(ceilDiv(numTokensPerTile, 8))}; + std::vector stride(shape.size()); + stride[0] = 1; + for (size_t i = 1; i < shape.size(); i++) { + stride[i] = shape[i - 1] * stride[i - 1]; + } - return std::make_tuple(shape, stride, tileShapes); - } + auto tileShapes = std::vector{ + static_cast(repeats * 32), + static_cast(ceilDiv(hiddenSizePerTile, numEltsPerSf * 4 * repeats)), + static_cast(ceilDiv(numTokensPerTile, 8))}; - default: - throw std::runtime_error("Unsupported SF layout"); + return std::make_tuple(shape, stride, tileShapes); + } + + default: + throw std::runtime_error("Unsupported SF layout"); } return std::make_tuple(std::vector{}, std::vector{}, std::vector{}); } template -static KernelParams setKernelParams(GemmOptions_ const& options, - bool const batchM, - void const* ptrA, - void const* ptrB, - void* ptrC, - void const* dSfA, - void const* dSfB, - void const* ptrPerTokenSfA, - void const* ptrPerTokenSfB, - void const* ptrBias, - void* dSfC, - float const* ptrScaleC, - float const* ptrScaleGate, - float const* ptrClampLimit, - float const* ptrGatedActAlpha, - float const* ptrGatedActBeta, - int32_t const* routeMap, - float* rowMax, - uint32_t* rowMaxBars, - int32_t const* ptrNumNonExitingCtas = nullptr, - int32_t const* ptrTotalNumPaddedTokens = nullptr, - int32_t const* ptrCtaIdxXyToBatchIdx = nullptr, - int32_t const* ptrCtaIdxXyToMnLimit = nullptr, - int32_t const maxNumCtas = KernelParams::MaxNumCtas) { - +static KernelParams setKernelParams( + GemmOptions_ const& options, bool const batchM, void const* ptrA, void const* ptrB, void* ptrC, + void const* dSfA, void const* dSfB, void const* ptrPerTokenSfA, void const* ptrPerTokenSfB, + void const* ptrBias, void* dSfC, float const* ptrScaleC, float const* ptrScaleGate, + float const* ptrClampLimit, float const* ptrGatedActAlpha, float const* ptrGatedActBeta, + int32_t const* routeMap, float* rowMax, uint32_t* rowMaxBars, + int32_t const* ptrNumNonExitingCtas = nullptr, int32_t const* ptrTotalNumPaddedTokens = nullptr, + int32_t const* ptrCtaIdxXyToBatchIdx = nullptr, int32_t const* ptrCtaIdxXyToMnLimit = nullptr, + int32_t const maxNumCtas = KernelParams::MaxNumCtas) { static_assert(sizeof(KernelParams) <= 32 * 1024, "sizeof(KernelParams) has to be less or equal than 32KB"); @@ -360,7 +323,6 @@ static KernelParams setKernelParams(GemmOptions_ const& options, if (options.mIsStaticBatch) { params.totalNumPaddedTokens = 0; for (int b = 0; b < options.mNumBatches; b++) { - int mM = batchM ? options.mBatchedM[b] : options.mM; int mN = batchM ? options.mN : options.mBatchedN[b]; @@ -388,7 +350,7 @@ static KernelParams setKernelParams(GemmOptions_ const& options, // This is now an identity map and it is no longer needed. 
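Aside (illustrative, not part of the patch): the R8c4 reshape logic above rejects non-power-of-two reshape factors, clamps the factor to the per-tile SF count, and requires the total SF count along K to divide evenly by the resulting repeat count. A standalone sketch with arbitrary values, using a conventional power-of-two test:

#include <cassert>
#include <algorithm>

int main() {
  auto ceilDiv = [](int m, int n) { return (m + n - 1) / n; };
  auto isPow2 = [](int r) { return r > 0 && (r & (r - 1)) == 0; };

  int sfReshapeFactor = 8, numEltsPerSf = 32;
  int hiddenSizePerTile = 512, hiddenSize = 4096;
  assert(isPow2(sfReshapeFactor));

  // Clamp the factor so it never exceeds the per-tile SF group count.
  int repeats = std::min(ceilDiv(hiddenSizePerTile, numEltsPerSf * 4), sfReshapeFactor);
  assert(repeats == 4);                                      // 512 / 128
  // 4096 / 128 = 32 SF groups along K, which is a multiple of 4 repeats.
  assert(ceilDiv(hiddenSize, numEltsPerSf * 4) % repeats == 0);
  return 0;
}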
// params.ctaIdxXyToTileIdxMn[ctaOffset + cta] = ctaOffset + cta; params.ctaIdxXyToMnLimit[ctaOffset + cta] = - std::min((ctaOffset + cta + 1) * tile, ctaOffset * tile + tokensPerTile); + std::min((ctaOffset + cta + 1) * tile, ctaOffset * tile + tokensPerTile); } ctaOffset += numCtas; @@ -422,21 +384,12 @@ static KernelParams setKernelParams(GemmOptions_ const& options, params.tileStridePerBatch = options.mM / options.mTileM; params.nm = options.mM; // Shape/stride for gmem tensor A. - auto [shapeA, strideA, tileShapeA] = makeTmaShapeStrideAbc(options, - options.mM, - options.mN, - options.mK, - options.mTileM, - options.mTileN, - options.mTileK, - MatrixType::MatrixA); + auto [shapeA, strideA, tileShapeA] = + makeTmaShapeStrideAbc(options, options.mM, options.mN, options.mK, options.mTileM, + options.mTileN, options.mTileK, MatrixType::MatrixA); // Build tma descriptor for A. - params.tmaA[0] = gemm::buildNdTmaDescriptor(options.mDtypeA, - options.mMmaKind, - shapeA, - strideA, - tileShapeA, - const_cast(ptrA)); + params.tmaA[0] = gemm::buildNdTmaDescriptor(options.mDtypeA, options.mMmaKind, shapeA, strideA, + tileShapeA, const_cast(ptrA)); // The input is padded: // [act0, padding, padding, ... TileN size .., act1, padding, padding, ...] @@ -446,55 +399,35 @@ static KernelParams setKernelParams(GemmOptions_ const& options, bool useRouteAct = batchedGemm::doesRouteImplUseTma(options.mRouteImpl); // B is the activation // Shape/stride for gmem tensor B. - auto [shapeB, strideB, tileShapeB] = - makeTmaShapeStrideAbc(options, - options.mM, - useRouteAct ? options.mNumTokens : inputNumTokens, - options.mK, - options.mTileM, - (useRouteAct ? 1 : options.mTileN), - options.mTileK, - MatrixType::MatrixB); + auto [shapeB, strideB, tileShapeB] = makeTmaShapeStrideAbc( + options, options.mM, useRouteAct ? options.mNumTokens : inputNumTokens, options.mK, + options.mTileM, (useRouteAct ? 1 : options.mTileN), options.mTileK, MatrixType::MatrixB); // Build tma descriptor for B. - params.tmaB[0] = gemm::buildNdTmaDescriptor(options.mDtypeB, - options.mMmaKind, - shapeB, - strideB, - tileShapeB, - const_cast(ptrB)); + params.tmaB[0] = gemm::buildNdTmaDescriptor(options.mDtypeB, options.mMmaKind, shapeB, + strideB, tileShapeB, const_cast(ptrB)); } if (options.mDtypeA == tg::Dtype::E2m1 || options.mDtypeA == tg::Dtype::MxE4m3 || options.mDtypeA == tg::Dtype::MxE2m1) { tg::Dtype const dTypeSf = - (options.mDtypeA == tg::Dtype::E2m1) ? tg::Dtype::E4m3 : tg::Dtype::UE8m0; + (options.mDtypeA == tg::Dtype::E2m1) ? tg::Dtype::E4m3 : tg::Dtype::UE8m0; // Build TMA descriptor for gmem A block scaling factors. 
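For reference (standalone numbers, not from this patch): the MN limit computed above clamps the last CTA of each batch to the real (unpadded) token count so it never processes padding. One batch at offset zero:

#include <cassert>
#include <algorithm>

int main() {
  auto ceilDiv = [](int m, int n) { return (m + n - 1) / n; };
  int tile = 64, tokensInBatch = 100, ctaOffset = 0;

  int numCtas = ceilDiv(tokensInBatch, tile);  // 2 CTAs cover 100 tokens
  int limits[2];
  for (int cta = 0; cta < numCtas; ++cta) {
    // Same expression as ctaIdxXyToMnLimit above, with illustrative names.
    limits[cta] = std::min((ctaOffset + cta + 1) * tile,
                           ctaOffset * tile + tokensInBatch);
  }
  assert(numCtas == 2);
  assert(limits[0] == 64 && limits[1] == 100);  // last CTA stops at token 100
  return 0;
}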
auto [shapeSfA, strideSfA, tileShapesSfA] = makeTmaShapeStrideSfAb( - options.mM * options.mNumBatches, - options.mN, - options.mK, - MatrixType::MatrixA, - options.mTileM, - options.mTileN, - options.mTileK, - tg::SfLayout::R128c4, - options.mSfReshapeFactor, - options.mSfBlockSizeA.value_or(tg::dtypeNumEltsPerSf(options.mDtypeA))); - params.tmaSfA[0] = gemm::buildSfTmaDescriptor(dTypeSf, - shapeSfA, - strideSfA, - tileShapesSfA, + options.mM * options.mNumBatches, options.mN, options.mK, MatrixType::MatrixA, + options.mTileM, options.mTileN, options.mTileK, tg::SfLayout::R128c4, + options.mSfReshapeFactor, + options.mSfBlockSizeA.value_or(tg::dtypeNumEltsPerSf(options.mDtypeA))); + params.tmaSfA[0] = gemm::buildSfTmaDescriptor(dTypeSf, shapeSfA, strideSfA, tileShapesSfA, const_cast(dSfA)); } if (options.mDtypeB == tg::Dtype::E2m1 || options.mDtypeB == tg::Dtype::MxE4m3 || options.mDtypeB == tg::Dtype::MxE2m1) { tg::Dtype const dTypeSf = - (options.mDtypeB == tg::Dtype::E2m1) ? tg::Dtype::E4m3 : tg::Dtype::UE8m0; + (options.mDtypeB == tg::Dtype::E2m1) ? tg::Dtype::E4m3 : tg::Dtype::UE8m0; if (batchedGemm::doesRouteImplUseTma(options.mRouteImpl)) { - // The input is NOT padded: // [act0, act1, act2, ...] @@ -505,45 +438,24 @@ static KernelParams setKernelParams(GemmOptions_ const& options, auto numSfsInK = options.mK / numEltsPerSf; numSfsInK = ceilDiv(numSfsInK, 16) * 16; - auto [shapeSfB, strideSfB, tileShapesSfB] = - makeTmaShapeStrideAbc(options, - options.mM, - options.mNumTokens, - numSfsInK, - options.mTileM, - 1 /* tileN */, - options.mTileK / numEltsPerSf, - MatrixType::MatrixB); - params.tmaSfB[0] = gemm::buildNdTmaDescriptor(dTypeSf, - options.mMmaKind, - shapeSfB, - strideSfB, - tileShapesSfB, - const_cast(dSfB), - /*doSwizzle*/ true); + auto [shapeSfB, strideSfB, tileShapesSfB] = makeTmaShapeStrideAbc( + options, options.mM, options.mNumTokens, numSfsInK, options.mTileM, 1 /* tileN */, + options.mTileK / numEltsPerSf, MatrixType::MatrixB); + params.tmaSfB[0] = gemm::buildNdTmaDescriptor( + dTypeSf, options.mMmaKind, shapeSfB, strideSfB, tileShapesSfB, const_cast(dSfB), + /*doSwizzle*/ true); } else if (batchedGemm::doesRouteImplUseNoRoute(options.mRouteImpl)) { - // The input is padded: // [act0, padding, padding, ... TileN size .., act1, padding, padding, ...] auto const inputNumTokensSfB = ctaOffset * options.mTileN; // Build TMA descriptor for gmem B block scaling factors. - auto [shapeSfB, strideSfB, tileShapesSfB] = - makeTmaShapeStrideSfAb(options.mM, - inputNumTokensSfB, - options.mK, - MatrixType::MatrixB, - options.mTileM, - options.mTileN, - options.mTileK, - options.mSfLayoutB, - options.mSfReshapeFactor, - tg::dtypeNumEltsPerSf(options.mDtypeB)); - params.tmaSfB[0] = gemm::buildSfTmaDescriptor(dTypeSf, - shapeSfB, - strideSfB, - tileShapesSfB, + auto [shapeSfB, strideSfB, tileShapesSfB] = makeTmaShapeStrideSfAb( + options.mM, inputNumTokensSfB, options.mK, MatrixType::MatrixB, options.mTileM, + options.mTileN, options.mTileK, options.mSfLayoutB, options.mSfReshapeFactor, + tg::dtypeNumEltsPerSf(options.mDtypeB)); + params.tmaSfB[0] = gemm::buildSfTmaDescriptor(dTypeSf, shapeSfB, strideSfB, tileShapesSfB, const_cast(dSfB)); } } @@ -551,21 +463,12 @@ static KernelParams setKernelParams(GemmOptions_ const& options, // C is the output activation if (options.mUseTmaStore) { // Shape/stride for gmem tensor C. 
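Aside (illustrative, not part of the patch): the routed-activation SF descriptor above pads the per-K scaling-factor count up to a multiple of 16 before building the TMA shape. A quick numeric check with made-up sizes:

#include <cassert>

int main() {
  auto ceilDiv = [](int m, int n) { return (m + n - 1) / n; };
  int K = 2880, numEltsPerSf = 32;
  int numSfsInK = ceilDiv(K / numEltsPerSf, 16) * 16;
  // 2880 / 32 = 90 scaling factors along K, padded up to 96.
  assert(K / numEltsPerSf == 90 && numSfsInK == 96);
  return 0;
}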
- auto [shapeC, strideC, tileShapeC] = makeTmaShapeStrideAbc(options, - options.mM, - ctaOffset * options.mTileN, - options.mK, - options.mTileM, - options.mTileN, - options.mTileK, - MatrixType::MatrixC); + auto [shapeC, strideC, tileShapeC] = makeTmaShapeStrideAbc( + options, options.mM, ctaOffset * options.mTileN, options.mK, options.mTileM, + options.mTileN, options.mTileK, MatrixType::MatrixC); // Build tma descriptor for C. - params.tmaC[0] = gemm::buildNdTmaDescriptor(options.mDtypeC, - tg::MmaKind::Auto, - shapeC, - strideC, - tileShapeC, - ptrC); + params.tmaC[0] = gemm::buildNdTmaDescriptor(options.mDtypeC, tg::MmaKind::Auto, shapeC, + strideC, tileShapeC, ptrC); } else { params.ptrC = ptrC; } @@ -578,21 +481,12 @@ static KernelParams setKernelParams(GemmOptions_ const& options, params.tileStridePerBatch = options.mN / options.mTileN; params.nm = options.mN; // Shape/stride for gmem tensor B. - auto [shapeB, strideB, tileShapeB] = makeTmaShapeStrideAbc(options, - options.mM, - options.mN, - options.mK, - options.mTileM, - options.mTileN, - options.mTileK, - MatrixType::MatrixB); + auto [shapeB, strideB, tileShapeB] = + makeTmaShapeStrideAbc(options, options.mM, options.mN, options.mK, options.mTileM, + options.mTileN, options.mTileK, MatrixType::MatrixB); // Build tma descriptor for B. - params.tmaB[0] = gemm::buildNdTmaDescriptor(options.mDtypeB, - options.mMmaKind, - shapeB, - strideB, - tileShapeB, - const_cast(ptrB)); + params.tmaB[0] = gemm::buildNdTmaDescriptor(options.mDtypeB, options.mMmaKind, shapeB, strideB, + tileShapeB, const_cast(ptrB)); if (options.mRouteImpl == batchedGemm::RouteImpl::NoRoute) { // A is the activation @@ -600,50 +494,30 @@ static KernelParams setKernelParams(GemmOptions_ const& options, // The input is padded: // [act0, padding, padding, ... tileM size .., act1, padding, padding, ...] auto const inputNumTokens = ctaOffset * options.mTileM; - auto [shapeA, strideA, tileShapeA] = makeTmaShapeStrideAbc(options, - inputNumTokens, - options.mN, - options.mK, - options.mTileM, - options.mTileN, - options.mTileK, - MatrixType::MatrixA); + auto [shapeA, strideA, tileShapeA] = + makeTmaShapeStrideAbc(options, inputNumTokens, options.mN, options.mK, options.mTileM, + options.mTileN, options.mTileK, MatrixType::MatrixA); // Build tma descriptor for A. - params.tmaA[0] = gemm::buildNdTmaDescriptor(options.mDtypeA, - options.mMmaKind, - shapeA, - strideA, - tileShapeA, - const_cast(ptrA)); + params.tmaA[0] = gemm::buildNdTmaDescriptor(options.mDtypeA, options.mMmaKind, shapeA, + strideA, tileShapeA, const_cast(ptrA)); } if (options.mDtypeA == tg::Dtype::E2m1 || options.mDtypeA == tg::Dtype::MxE4m3 || options.mDtypeA == tg::Dtype::MxE2m1) { tg::Dtype const dTypeSf = - (options.mDtypeA == tg::Dtype::E2m1) ? tg::Dtype::E4m3 : tg::Dtype::UE8m0; + (options.mDtypeA == tg::Dtype::E2m1) ? tg::Dtype::E4m3 : tg::Dtype::UE8m0; if (options.mRouteImpl == batchedGemm::RouteImpl::NoRoute) { - // The input is padded: // [act0, padding, padding, ... tileM size .., act1, padding, padding, ...] auto const inputNumTokensSfA = ctaOffset * options.mTileM; // Build TMA descriptor for gmem A block scaling factors. 
auto [shapeSfA, strideSfA, tileShapesSfA] = makeTmaShapeStrideSfAb( - inputNumTokensSfA, - options.mN, - options.mK, - MatrixType::MatrixA, - options.mTileM, - options.mTileN, - options.mTileK, - tg::SfLayout::R128c4, - options.mSfReshapeFactor, - options.mSfBlockSizeA.value_or(tg::dtypeNumEltsPerSf(options.mDtypeA))); - params.tmaSfA[0] = gemm::buildSfTmaDescriptor(dTypeSf, - shapeSfA, - strideSfA, - tileShapesSfA, + inputNumTokensSfA, options.mN, options.mK, MatrixType::MatrixA, options.mTileM, + options.mTileN, options.mTileK, tg::SfLayout::R128c4, options.mSfReshapeFactor, + options.mSfBlockSizeA.value_or(tg::dtypeNumEltsPerSf(options.mDtypeA))); + params.tmaSfA[0] = gemm::buildSfTmaDescriptor(dTypeSf, shapeSfA, strideSfA, tileShapesSfA, const_cast(dSfA)); } } @@ -651,45 +525,26 @@ static KernelParams setKernelParams(GemmOptions_ const& options, if (options.mDtypeB == tg::Dtype::E2m1 || options.mDtypeB == tg::Dtype::MxE4m3 || options.mDtypeB == tg::Dtype::MxE2m1) { tg::Dtype const dTypeSf = - (options.mDtypeB == tg::Dtype::E2m1) ? tg::Dtype::E4m3 : tg::Dtype::UE8m0; + (options.mDtypeB == tg::Dtype::E2m1) ? tg::Dtype::E4m3 : tg::Dtype::UE8m0; // Build TMA descriptor for gmem B block scaling factors. - auto [shapeSfB, strideSfB, tileShapesSfB] = - makeTmaShapeStrideSfAb(options.mM, - options.mN * options.mNumBatches, - options.mK, - MatrixType::MatrixB, - options.mTileM, - options.mTileN, - options.mTileK, - options.mSfLayoutB, - options.mSfReshapeFactor, - tg::dtypeNumEltsPerSf(options.mDtypeB)); - params.tmaSfB[0] = gemm::buildSfTmaDescriptor(dTypeSf, - shapeSfB, - strideSfB, - tileShapesSfB, + auto [shapeSfB, strideSfB, tileShapesSfB] = makeTmaShapeStrideSfAb( + options.mM, options.mN * options.mNumBatches, options.mK, MatrixType::MatrixB, + options.mTileM, options.mTileN, options.mTileK, options.mSfLayoutB, + options.mSfReshapeFactor, tg::dtypeNumEltsPerSf(options.mDtypeB)); + params.tmaSfB[0] = gemm::buildSfTmaDescriptor(dTypeSf, shapeSfB, strideSfB, tileShapesSfB, const_cast(dSfB)); } // C is the output activation if (options.mUseTmaStore) { // Shape/stride for gmem tensor C. - auto [shapeC, strideC, tileShapeC] = makeTmaShapeStrideAbc(options, - ctaOffset * options.mTileM, - options.mN, - options.mK, - options.mTileM, - options.mTileN, - options.mTileK, - MatrixType::MatrixC); + auto [shapeC, strideC, tileShapeC] = makeTmaShapeStrideAbc( + options, ctaOffset * options.mTileM, options.mN, options.mK, options.mTileM, + options.mTileN, options.mTileK, MatrixType::MatrixC); // Build tma descriptor for C. 
- params.tmaC[0] = gemm::buildNdTmaDescriptor(options.mDtypeC, - tg::MmaKind::Auto, - shapeC, - strideC, - tileShapeC, - ptrC); + params.tmaC[0] = gemm::buildNdTmaDescriptor(options.mDtypeC, tg::MmaKind::Auto, shapeC, + strideC, tileShapeC, ptrC); } else { params.ptrC = ptrC; } @@ -714,10 +569,10 @@ static KernelParams setKernelParams(GemmOptions_ const& options, return params; } #endif -}; // namespace KernelParamsSetup +}; // namespace KernelParamsSetup //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace batchedGemm +} // namespace batchedGemm -} // namespace batchedGemm +} // namespace batchedGemm diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h index a6056009de..e11374739f 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h @@ -1,23 +1,22 @@ /* -* SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & -* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once namespace batchedGemm { - // This is device code struct KernelParams { @@ -494,4 +493,4 @@ struct KernelParams { /////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace batchedGemm +} // namespace batchedGemm diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h index 6699530009..4d79f83c23 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h @@ -1,27 +1,28 @@ /* -* SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & -* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. 
-* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once #include #include -#include "trtllm/gen/DtypeDecl.h" + +#include "Enums.h" #include "trtllm/gen/CommonUtils.h" +#include "trtllm/gen/DtypeDecl.h" #include "trtllm/gen/MmaDecl.h" -#include "Enums.h" namespace batchedGemm { @@ -35,17 +36,14 @@ namespace tg = trtllm::gen; // Structure to manage memory allocation with configurable reuse class MemAllocatorHelper { -public: + public: // The default constructor. MemAllocatorHelper() {} // Constructor to initialize chunk sizes, alignments, and reuse flags MemAllocatorHelper(std::vector> const& sizes, - std::vector const& reuse, - std::vector const& names) - : mNumBytesAndAlignmentPerSmemChunk(sizes) - , mFirstChunkReuse(reuse) - , mSmemChunkNames(names) {} + std::vector const& reuse, std::vector const& names) + : mNumBytesAndAlignmentPerSmemChunk(sizes), mFirstChunkReuse(reuse), mSmemChunkNames(names) {} // Function to calculate the size of the array from 0 to jj chunks int32_t getOffsetBeforeChunk(int jj) const { @@ -97,17 +95,14 @@ class MemAllocatorHelper { // Print the contents of this object. void print() const { for (size_t ii = 0; ii < mNumBytesAndAlignmentPerSmemChunk.size(); ++ii) { - printf("Chunk %zd %s: %d bytes, %d alignment, reuse %s, offset %d\n", - ii, - mSmemChunkNames[ii].c_str(), - mNumBytesAndAlignmentPerSmemChunk[ii].first, - mNumBytesAndAlignmentPerSmemChunk[ii].second, - mFirstChunkReuse[ii] ? "true" : "false", + printf("Chunk %zd %s: %d bytes, %d alignment, reuse %s, offset %d\n", ii, + mSmemChunkNames[ii].c_str(), mNumBytesAndAlignmentPerSmemChunk[ii].first, + mNumBytesAndAlignmentPerSmemChunk[ii].second, mFirstChunkReuse[ii] ? "true" : "false", getChunkOffset(ii)); } } -private: + private: int32_t getChunkOffset(int32_t ii) const { if (mFirstChunkReuse[ii]) { // Reuse the offset of the 0th chunk. @@ -129,7 +124,7 @@ class MemAllocatorHelper { return (size + alignment - 1) & ~(alignment - 1); } -private: + private: // Sizes and alignment requirements of each chunk // NOTE: be careful and make sure that the memory dependency is clear and // chunks in the beginning of the SMEM can be overwritten. @@ -156,39 +151,20 @@ inline int getNumSmemBitsPerElt(tg::Dtype dtype, tg::MmaKind mmaKind) { //////////////////////////////////////////////////////////////////////////////////////////////////// class KernelTraits { -public: + public: // The default constructor. KernelTraits() {} // The constructor. 
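For reference (standalone sketch, not from this patch): the chunk offsets reported by MemAllocatorHelper above are aligned with the usual power-of-two round-up trick. A few concrete values:

#include <cstdint>

constexpr int32_t alignUp(int32_t size, int32_t alignment) {
  // Same bit trick as the helper above: round size up to a multiple of a
  // power-of-two alignment.
  return (size + alignment - 1) & ~(alignment - 1);
}

static_assert(alignUp(1000, 1024) == 1024, "1000B rounds up to one 1024B unit");
static_assert(alignUp(2049, 1024) == 3072, "2049B rounds up to three 1024B units");
static_assert(alignUp(4096, 1024) == 4096, "already-aligned sizes are unchanged");

int main() { return 0; }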
- KernelTraits(tg::Dtype dtypeA, - tg::Dtype dtypeB, - tg::Dtype dtypeC, - tg::Dtype dtypeAcc, - tg::Dtype dtypeMmaA, - tg::Dtype dtypeMmaB, - tg::MmaKind mmaKind, - int32_t mmaK, - int32_t tileM, - int32_t tileN, - int32_t tileK, - int32_t epilogueTileM, - int32_t epilogueTileN, - int32_t numStages, - int32_t numStagesMma, - int32_t numSlicesForSplitK, - int32_t numSlicesForSliceK, - SplitK splitK, - bool useTmaStore, - bool transposeMmaOutput, - AllReduceAlgo allReduceAlgo, - bool usePersistentScheduler, - bool useDeepSeekFp8, - bool usePerTokenSfA, - bool usePerTokenSfB, - bool useTwoCtas, - BiasType biasType) - : mMmaKind{mmaKind} { + KernelTraits(tg::Dtype dtypeA, tg::Dtype dtypeB, tg::Dtype dtypeC, tg::Dtype dtypeAcc, + tg::Dtype dtypeMmaA, tg::Dtype dtypeMmaB, tg::MmaKind mmaKind, int32_t mmaK, + int32_t tileM, int32_t tileN, int32_t tileK, int32_t epilogueTileM, + int32_t epilogueTileN, int32_t numStages, int32_t numStagesMma, + int32_t numSlicesForSplitK, int32_t numSlicesForSliceK, SplitK splitK, + bool useTmaStore, bool transposeMmaOutput, AllReduceAlgo allReduceAlgo, + bool usePersistentScheduler, bool useDeepSeekFp8, bool usePerTokenSfA, + bool usePerTokenSfB, bool useTwoCtas, BiasType biasType) + : mMmaKind{mmaKind} { // // SMEM // @@ -222,7 +198,7 @@ class KernelTraits { { // Number of bytes in load A shared memory. auto const numSmemBytesLoadA = - numStages * tileM * tileK * getNumSmemBitsPerElt(dtypeA, mMmaKind) / 8 /* bits */; + numStages * tileM * tileK * getNumSmemBitsPerElt(dtypeA, mMmaKind) / 8 /* bits */; // Number of bytes for load A alignment for TMA load. auto const numBytesAlignmentLoadA = 1024; // loadA is already at first chunk. No need to reuse it. @@ -230,7 +206,7 @@ class KernelTraits { // Add info. smemChunkNames.emplace_back("smemLoadA"); numBytesAndAlignmentPerSmemChunk.emplace_back( - std::make_pair(numSmemBytesLoadA, numBytesAlignmentLoadA)); + std::make_pair(numSmemBytesLoadA, numBytesAlignmentLoadA)); firstChunkReuseSmem.emplace_back(reuseChunksSmemLoadA); } @@ -246,7 +222,7 @@ class KernelTraits { // Add info. smemChunkNames.emplace_back("smemLoadB"); numBytesAndAlignmentPerSmemChunk.emplace_back( - std::make_pair(numSmemBytesLoadB, numBytesAlignmentLoadB)); + std::make_pair(numSmemBytesLoadB, numBytesAlignmentLoadB)); firstChunkReuseSmem.emplace_back(reuseChunksSmemLoadB); } @@ -258,9 +234,9 @@ class KernelTraits { { // Number of bytes in save shuffled B in shared memory. auto const numSmemBytesLoadB = - numSlicesForSliceK > 1 - ? numStages * tileN * tileK * getNumSmemBitsPerElt(dtypeB, mMmaKind) / 8 /* bits */ - : 0; + numSlicesForSliceK > 1 + ? numStages * tileN * tileK * getNumSmemBitsPerElt(dtypeB, mMmaKind) / 8 /* bits */ + : 0; // Number of bytes for load B alignment for TMA load. auto const numBytesAlignmentLoadB = 1024; // No need to reuse the first chunk. @@ -269,7 +245,7 @@ class KernelTraits { // Add info. smemChunkNames.emplace_back("smemBShuffle"); numBytesAndAlignmentPerSmemChunk.emplace_back( - std::make_pair(numSmemBytesLoadB, numBytesAlignmentLoadB)); + std::make_pair(numSmemBytesLoadB, numBytesAlignmentLoadB)); firstChunkReuseSmem.emplace_back(reuseChunksSmemLoadB); } @@ -297,21 +273,21 @@ class KernelTraits { // Number of bytes to store the output in smem. auto const numBytesSmemStoreC = usesSmemForGmemC - ? extraGmemCMultiplier * epilogueTileM * epilogueTileN * - tg::dtypeGetNumBits(dtypeSmemC) / 8 /* bits */ - : 0; + ? 
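A back-of-the-envelope check (illustrative, not part of the patch) of the smemLoadA sizing above, assuming the SMEM element width equals the dtype width (16 bits for bf16) and the 1024B alignment required for the TMA load:

#include <cassert>

int main() {
  int numStages = 4, tileM = 128, tileK = 128;
  int bitsPerElt = 16;  // e.g. bf16, assuming no block-scaled SMEM packing
  long long bytes = 1LL * numStages * tileM * tileK * bitsPerElt / 8;
  assert(bytes == 131072);    // 32 KiB per stage times 4 stages
  assert(bytes % 1024 == 0);  // already satisfies the 1024B TMA alignment
  return 0;
}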
extraGmemCMultiplier * epilogueTileM * epilogueTileN * + tg::dtypeGetNumBits(dtypeSmemC) / 8 /* bits */ + : 0; // Number of bytes for store C alignment for TMA store. auto const numBytesAlignmentStoreC = 1024; // gmemC reuses loadAb memory for split-K in DSMEM. // Epilogue1 does not reuse and continues after the memory allocated Epilogue0 // NOTE: we can always reuse loadAb SMEM as long as we don't have persistent scheduler. auto const reuseFirstChunksSmemStoreC = - doesSplitKUseDsmem(splitK) && resIdx == 0 && !usePersistentScheduler; + doesSplitKUseDsmem(splitK) && resIdx == 0 && !usePersistentScheduler; // Add info. smemChunkNames.emplace_back("smemGmemC" + std::to_string(resIdx)); numBytesAndAlignmentPerSmemChunk.emplace_back( - std::make_pair(numBytesSmemStoreC, numBytesAlignmentStoreC)); + std::make_pair(numBytesSmemStoreC, numBytesAlignmentStoreC)); firstChunkReuseSmem.emplace_back(reuseFirstChunksSmemStoreC); } @@ -328,7 +304,7 @@ class KernelTraits { // Add info. smemChunkNames.emplace_back("smemRowMax"); numBytesAndAlignmentPerSmemChunk.emplace_back( - std::make_pair(numBytesSmemRowMax, numBytesAlignmentRowMax)); + std::make_pair(numBytesSmemRowMax, numBytesAlignmentRowMax)); firstChunkReuseSmem.emplace_back(false); } @@ -336,7 +312,7 @@ class KernelTraits { { // Real tile size before slice-K reduction. auto const tileSize = - numSlicesForSliceK > 1 ? numSlicesForSliceK * tileM * numSlicesForSliceK * tileN : 0; + numSlicesForSliceK > 1 ? numSlicesForSliceK * tileM * numSlicesForSliceK * tileN : 0; // Number of bytes for tile in SMEM. auto const numBytesSmemTile = tileSize * tg::dtypeGetNumBits(dtypeAcc) / 8 /* bits */; // Number of bytes alignment for rowMax in SMEM. @@ -345,7 +321,7 @@ class KernelTraits { // Add info. smemChunkNames.emplace_back("smemSliceK"); numBytesAndAlignmentPerSmemChunk.emplace_back( - std::make_pair(numBytesSmemTile, numBytesAlignmentTile)); + std::make_pair(numBytesSmemTile, numBytesAlignmentTile)); firstChunkReuseSmem.emplace_back(false); } @@ -359,7 +335,7 @@ class KernelTraits { // Add info. smemChunkNames.emplace_back("smemPerTokenSf"); numBytesAndAlignmentPerSmemChunk.emplace_back( - std::make_pair(numBytesSmemPerTokenSf, numBytesAlignmentPerTokenSf)); + std::make_pair(numBytesSmemPerTokenSf, numBytesAlignmentPerTokenSf)); firstChunkReuseSmem.emplace_back(false); } @@ -378,7 +354,7 @@ class KernelTraits { // Add info. smemChunkNames.emplace_back("smemBias"); numBytesAndAlignmentPerSmemChunk.emplace_back( - std::make_pair(numBytesSmemBias, numBytesAlignmentBias)); + std::make_pair(numBytesSmemBias, numBytesAlignmentBias)); firstChunkReuseSmem.emplace_back(false); } @@ -392,7 +368,7 @@ class KernelTraits { // Add info. smemChunkNames.emplace_back("smemBlockAmax"); numBytesAndAlignmentPerSmemChunk.emplace_back( - std::make_pair(numBytesSmemBlockAmax, numBytesAlignmentBlockAmax)); + std::make_pair(numBytesSmemBlockAmax, numBytesAlignmentBlockAmax)); firstChunkReuseSmem.emplace_back(false); } @@ -411,13 +387,13 @@ class KernelTraits { // Add info. smemChunkNames.emplace_back("smemConstSfBuf"); numBytesAndAlignmentPerSmemChunk.emplace_back( - std::make_pair(numSmemBytesConstSfBuf, numBytesAlignmentConstSfBuf)); + std::make_pair(numSmemBytesConstSfBuf, numBytesAlignmentConstSfBuf)); firstChunkReuseSmem.emplace_back(reuseChunksSmemConstSfBuf); } // Create SMEM helper object. 
mSmemAllocatorHelper = - MemAllocatorHelper(numBytesAndAlignmentPerSmemChunk, firstChunkReuseSmem, smemChunkNames); + MemAllocatorHelper(numBytesAndAlignmentPerSmemChunk, firstChunkReuseSmem, smemChunkNames); #if 0 // E.g., // Chunk 0 smemLoadA: 32768 bytes, 1024 alignment, false, offset 0 @@ -454,7 +430,7 @@ class KernelTraits { // Add info. tmemChunkNames.emplace_back("tmemD"); numBytesAndAlignmentPerTmemChunk.emplace_back( - std::make_pair(numTmemColsD, numColsAlignmentD)); + std::make_pair(numTmemColsD, numColsAlignmentD)); firstChunkReuseTmem.emplace_back(reuseChunksTmemD); } @@ -464,10 +440,10 @@ class KernelTraits { bool const useTmemA = (numSlicesForSliceK > 1) || (dtypeMmaA != dtypeA); // Number of columns for A. auto const numTmemColsA = - useTmemA ? numStages * tileK / - (numSlicesForSliceK * tg::dtypeGetNumBits(tg::Dtype::UInt32) / - tg::dtypeGetNumBits(dtypeMmaA)) - : 0; + useTmemA ? numStages * tileK / + (numSlicesForSliceK * tg::dtypeGetNumBits(tg::Dtype::UInt32) / + tg::dtypeGetNumBits(dtypeMmaA)) + : 0; // Number of columns for A alignment. auto const numColsAlignmentA = 4; // No need to reuse TMEM. @@ -476,7 +452,7 @@ class KernelTraits { // Add info. tmemChunkNames.emplace_back("tmemA"); numBytesAndAlignmentPerTmemChunk.emplace_back( - std::make_pair(numTmemColsA, numColsAlignmentA)); + std::make_pair(numTmemColsA, numColsAlignmentA)); firstChunkReuseTmem.emplace_back(reuseChunksTmemA); } @@ -488,10 +464,11 @@ class KernelTraits { bool const useConstSfA = useBlockScalingA && !tg::dtypeIsBlockFmt(dtypeA); // Number of columns for scaling factors of A. auto const numTmemColsSfA = - useConstSfA ? tg::roundUp((tileK / 64) * tg::getTmemColStridePerGroup(tileM, mmaK), 4) - : (useBlockScalingA - ? ((tileK / 64) * tg::getTmemColStridePerGroup(tileM, mmaK)) * numStages - : 0); + useConstSfA + ? tg::roundUp((tileK / 64) * tg::getTmemColStridePerGroup(tileM, mmaK), 4) + : (useBlockScalingA + ? ((tileK / 64) * tg::getTmemColStridePerGroup(tileM, mmaK)) * numStages + : 0); // Number of columns for Sf alignment. auto const numColsAlignmentSfA = 4; // No need to reuse TMEM. @@ -500,7 +477,7 @@ class KernelTraits { // Add info. tmemChunkNames.emplace_back("tmemSfA"); numBytesAndAlignmentPerTmemChunk.emplace_back( - std::make_pair(numTmemColsSfA, numColsAlignmentSfA)); + std::make_pair(numTmemColsSfA, numColsAlignmentSfA)); firstChunkReuseTmem.emplace_back(reuseChunksTmemSfA); } @@ -512,10 +489,11 @@ class KernelTraits { bool const useConstSfB = useBlockScalingB && !tg::dtypeIsBlockFmt(dtypeB); // Number of columns for scaling factors of B. auto const numTmemColsSfB = - useConstSfB ? tg::roundUp((tileK / 64) * tg::getTmemColStridePerGroup(tileN, mmaK), 4) - : (useBlockScalingB - ? ((tileK / 64) * tg::getTmemColStridePerGroup(tileN, mmaK)) * numStages - : 0); + useConstSfB + ? tg::roundUp((tileK / 64) * tg::getTmemColStridePerGroup(tileN, mmaK), 4) + : (useBlockScalingB + ? ((tileK / 64) * tg::getTmemColStridePerGroup(tileN, mmaK)) * numStages + : 0); // Number of columns for Sf alignment. auto const numColsAlignmentSfB = 4; // No need to reuse TMEM. @@ -524,17 +502,17 @@ class KernelTraits { // Add info. tmemChunkNames.emplace_back("tmemSfB"); numBytesAndAlignmentPerTmemChunk.emplace_back( - std::make_pair(numTmemColsSfB, numColsAlignmentSfB)); + std::make_pair(numTmemColsSfB, numColsAlignmentSfB)); firstChunkReuseTmem.emplace_back(reuseChunksTmemSfB); } // Create TMEM helper object. 
mTmemAllocatorHelper = - MemAllocatorHelper(numBytesAndAlignmentPerTmemChunk, firstChunkReuseTmem, tmemChunkNames); + MemAllocatorHelper(numBytesAndAlignmentPerTmemChunk, firstChunkReuseTmem, tmemChunkNames); } } -public: + public: // The MMA kind. tg::MmaKind mMmaKind; // Helper for SMEM allocation. @@ -573,9 +551,7 @@ inline int32_t getSmemOffsetLoadB(KernelTraits traits) { //////////////////////////////////////////////////////////////////////////////////////////////////// -inline int32_t getSmemOffsetLoadAb(KernelTraits traits) { - return getSmemOffsetLoadA(traits); -} +inline int32_t getSmemOffsetLoadAb(KernelTraits traits) { return getSmemOffsetLoadA(traits); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -662,6 +638,6 @@ inline int32_t getTmemOffsetSfB(KernelTraits traits) { //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace gemm +} // namespace gemm -} // namespace batchedGemm +} // namespace batchedGemm diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h index c20ce9c00f..c7b18af138 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h @@ -1,24 +1,25 @@ /* -* SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & -* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once +#include + #include "trtllm/gen/DtypeDecl.h" #include "trtllm/gen/MmaDecl.h" -#include #ifdef TLLM_ENABLE_CUDA #include @@ -38,12 +39,10 @@ namespace tg = trtllm::gen; #ifdef TLLM_ENABLE_CUDA -inline CUtensorMap buildNdTmaDescriptor(tg::Dtype dtype, - tg::MmaKind mmaKind, +inline CUtensorMap buildNdTmaDescriptor(tg::Dtype dtype, tg::MmaKind mmaKind, std::vector const& shapes, std::vector const& strides, - std::vector const& tileShapes, - void* gmemAddr, + std::vector const& tileShapes, void* gmemAddr, bool doSwizzle = true) { // The multiplication factor of the data padding in SMEM. 
int32_t padMultiplier = 1; @@ -77,7 +76,7 @@ inline CUtensorMap buildNdTmaDescriptor(tg::Dtype dtype, // The swizzle type. CUtensorMapSwizzle swizzleType{CU_TENSOR_MAP_SWIZZLE_NONE}; int32_t fastestDimTileSizeBytes = - (tileShapes[0] * tg::dtypeGetNumBits(dtype) * padMultiplier) / /* bits */ 8; + (tileShapes[0] * tg::dtypeGetNumBits(dtype) * padMultiplier) / /* bits */ 8; if (doSwizzle) { if ((fastestDimTileSizeBytes % 128) == 0) { swizzleType = CU_TENSOR_MAP_SWIZZLE_128B; @@ -97,7 +96,7 @@ inline CUtensorMap buildNdTmaDescriptor(tg::Dtype dtype, } // Check gmem address must be 16B-aligned - assert((reinterpret_cast(gmemAddr) & 0b1111) == 0); // + assert((reinterpret_cast(gmemAddr) & 0b1111) == 0); // // Check shape must be in range [1, 2^32] int32_t dim = shapes.size(); @@ -106,8 +105,8 @@ inline CUtensorMap buildNdTmaDescriptor(tg::Dtype dtype, assert(dim == 2 || dim == 3 || dim == 4); // Check shape range. for (int32_t ii = 0; ii < dim; ++ii) { - assert(shapes[ii] >= (uint64_t(1))); // Size must be min 1 - assert(shapes[ii] <= (uint64_t(1) << 32)); // Size must be max 2^32 + assert(shapes[ii] >= (uint64_t(1))); // Size must be min 1 + assert(shapes[ii] <= (uint64_t(1) << 32)); // Size must be max 2^32 } // TMA descriptor does not store the zeroth stride and assumes it is 1. @@ -145,18 +144,13 @@ inline CUtensorMap buildNdTmaDescriptor(tg::Dtype dtype, std::vector tileStrides(dim, 1); // Build the descriptor. - CUresult result = cuTensorMapEncodeTiled(&desc, - tmaDataFormat, - /*tensorRank=*/dim, - gmemAddr, - shapes.data(), - stridesInBytes.data(), - boxDim.data(), - tileStrides.data(), - /*interleave=*/CU_TENSOR_MAP_INTERLEAVE_NONE, - swizzleType, - /*l2Promotion=*/CU_TENSOR_MAP_L2_PROMOTION_L2_128B, - /*oobFill=*/CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); + CUresult result = + cuTensorMapEncodeTiled(&desc, tmaDataFormat, + /*tensorRank=*/dim, gmemAddr, shapes.data(), stridesInBytes.data(), + boxDim.data(), tileStrides.data(), + /*interleave=*/CU_TENSOR_MAP_INTERLEAVE_NONE, swizzleType, + /*l2Promotion=*/CU_TENSOR_MAP_L2_PROMOTION_L2_128B, + /*oobFill=*/CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); if (result != CUDA_SUCCESS) { char const* errorString; @@ -199,11 +193,9 @@ inline CUtensorMap buildNdTmaDescriptor(tg::Dtype dtype, } // TODO: make it work with the above descriptor? -inline CUtensorMap buildSfTmaDescriptor(tg::Dtype dtype, - std::vector const& shapes, +inline CUtensorMap buildSfTmaDescriptor(tg::Dtype dtype, std::vector const& shapes, std::vector const& strides, - const std::vector& tileShapes, - void* gmemAddr) { + const std::vector& tileShapes, void* gmemAddr) { CUtensorMap desc{}; CUtensorMapDataType tmaDataFormat; if (dtype == tg::Dtype::E4m3 || dtype == tg::Dtype::UE8m0) { @@ -217,14 +209,14 @@ inline CUtensorMap buildSfTmaDescriptor(tg::Dtype dtype, CUtensorMapSwizzle swizzleType = CU_TENSOR_MAP_SWIZZLE_NONE; // Check gmem address must be 16B-aligned - assert((reinterpret_cast(gmemAddr) & 0b1111) == 0); // + assert((reinterpret_cast(gmemAddr) & 0b1111) == 0); // // Check shape must be in range [1, 2^32] int32_t dim = shapes.size(); // Check shape range. for (int32_t ii = 0; ii < dim; ++ii) { - assert(shapes[ii] >= (uint64_t(1))); // Size must be min 1 - assert(shapes[ii] <= (uint64_t(1) << 32)); // Size must be max 2^32 + assert(shapes[ii] >= (uint64_t(1))); // Size must be min 1 + assert(shapes[ii] <= (uint64_t(1) << 32)); // Size must be max 2^32 } // TMA descriptor does not store the zeroth stride and assumes it is 1. 
@@ -296,10 +288,10 @@ inline CUtensorMap buildSfTmaDescriptor(tg::Dtype dtype, return desc; } -#endif // defined TLLM_ENABLE_CUDA +#endif // defined TLLM_ENABLE_CUDA //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace gemm +} // namespace gemm -} // namespace batchedGemm +} // namespace batchedGemm diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.h index 798da2a23a..53155c8ffb 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.h @@ -1,19 +1,19 @@ /* -* SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & -* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once namespace batchedGemm { @@ -38,19 +38,21 @@ constexpr unsigned long XLargeN = 1UL << 35; //////////////////////////////////////////////////////////////////////////////////////////////////// -template inline T ceilDiv(T m, T n) { +template +inline T ceilDiv(T m, T n) { return (m + n - T(1)) / n; } //////////////////////////////////////////////////////////////////////////////////////////////////// -template inline T roundUp(T m, T n) { +template +inline T roundUp(T m, T n) { return ceilDiv(m, n) * n; } //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace gen -} // namespace trtllm +} // namespace gen +} // namespace trtllm -} // namespace batchedGemm +} // namespace batchedGemm diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaKernelLauncher.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaKernelLauncher.h index 7b819ba753..42bc884f92 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaKernelLauncher.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaKernelLauncher.h @@ -1,26 +1,27 @@ /* -* SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & -* AFFILIATES. All rights reserved. 
SPDX-License-Identifier: Apache-2.0 -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once #ifdef TLLM_ENABLE_CUDA +#include +#include + #include #include -#include -#include #endif namespace batchedGemm { @@ -30,18 +31,13 @@ namespace gen { //////////////////////////////////////////////////////////////////////////////////////////////////// #ifdef TLLM_ENABLE_CUDA -inline CUresult launchKernel(void* kernelParams, - void* cudaStream, - int32_t smemSize, - CUfunction kernel, - dim3 block3, - dim3 grid3, - dim3 cluster3, +inline CUresult launchKernel(void* kernelParams, void* cudaStream, int32_t smemSize, + CUfunction kernel, dim3 block3, dim3 grid3, dim3 cluster3, bool enablesPdl) { // Make sure we can launch with that much shared memory. if (smemSize > 48 * 1024) { CUresult result = - cuFuncSetAttribute(kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smemSize); + cuFuncSetAttribute(kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smemSize); if (result != CUDA_SUCCESS) { return result; } @@ -66,7 +62,7 @@ inline CUresult launchKernel(void* kernelParams, launchAttrs[0].value.clusterDim.z = cluster3.z; launchAttrs[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE; launchAttrs[1].value.clusterSchedulingPolicyPreference = - (clusterDim > 1) ? CU_CLUSTER_SCHEDULING_POLICY_SPREAD : CU_CLUSTER_SCHEDULING_POLICY_DEFAULT; + (clusterDim > 1) ? CU_CLUSTER_SCHEDULING_POLICY_SPREAD : CU_CLUSTER_SCHEDULING_POLICY_DEFAULT; launchAttrs[2].id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION; launchAttrs[2].value.programmaticStreamSerializationAllowed = enablesPdl; launchConfig.attrs = launchAttrs; @@ -74,10 +70,10 @@ inline CUresult launchKernel(void* kernelParams, // Add setting for non-portable cluster size. 
if (clusterDim > 8) { - CUresult result = cuFuncSetAttribute(kernel, - CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, - 1 // Enable non-portable cluster sizes - ); + CUresult result = + cuFuncSetAttribute(kernel, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, + 1 // Enable non-portable cluster sizes + ); if (result != CUDA_SUCCESS) { return result; } @@ -90,7 +86,7 @@ inline CUresult launchKernel(void* kernelParams, //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace gen -} // namespace trtllm +} // namespace gen +} // namespace trtllm -} // namespace batchedGemm +} // namespace batchedGemm diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h index 2b26230f33..0866256492 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h @@ -1,25 +1,25 @@ /* -* SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & -* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once -#include #include -#include +#include #include +#include #ifndef TLLM_GEN_EXPORT_INTERFACE #include "trtllm/gen/MmaDecl.h" #else @@ -50,9 +50,9 @@ enum class Dtype : uint32_t { // Bit 4: is it signed? 0x1 if true, 0x0 otherwise. // Byte 3: Is it a block format? 0x1 if true, 0x0 otherwise. -#define TLLM_ENCODE_DTYPE(BlockFormatBit, SignedBit, IntegerBit, NumBits, Uid) \ - uint32_t { \ - (BlockFormatBit << 24) | (SignedBit << 20) | (IntegerBit << 16) | (NumBits << 8) | (Uid) \ +#define TLLM_ENCODE_DTYPE(BlockFormatBit, SignedBit, IntegerBit, NumBits, Uid) \ + uint32_t { \ + (BlockFormatBit << 24) | (SignedBit << 20) | (IntegerBit << 16) | (NumBits << 8) | (Uid) \ } // clang-format off @@ -109,9 +109,7 @@ inline bool dtypeIsFloat(Dtype dtype) { //////////////////////////////////////////////////////////////////////////////////////////////////// // Is a given data type an 8-bit floating-point type? 
-inline bool dtypeIsFp8(Dtype dtype) { - return dtype == Dtype::E4m3 || dtype == Dtype::E5m2; -} +inline bool dtypeIsFp8(Dtype dtype) { return dtype == Dtype::E4m3 || dtype == Dtype::E5m2; } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -134,51 +132,51 @@ inline bool dtypeIsSigned(Dtype dtype) { // For logging and error reporting inline std::string dtypeToString(Dtype dtype) { switch (dtype) { - case Dtype::Bfloat16: - return "Bfloat16"; - case Dtype::Bool: - return "Bool"; - case Dtype::E2m1: - return "E2m1"; - case Dtype::E2m3: - return "E2m3"; - case Dtype::E3m2: - return "E3m2"; - case Dtype::E4m3: - return "E4m3"; - case Dtype::E5m2: - return "E5m2"; - case Dtype::Fp16: - return "Fp16"; - case Dtype::Fp32: - return "Fp32"; - case Dtype::Int8: - return "Int8"; - case Dtype::Int32: - return "Int32"; - case Dtype::Int64: - return "Int64"; - case Dtype::MxE4m3: - return "MxE4m3"; - case Dtype::MxE2m1: - return "MxE2m1"; - case Dtype::UE8m0: - return "UE8m0"; - case Dtype::UInt8: - return "UInt8"; - case Dtype::UInt16: - return "UInt16"; - case Dtype::UInt32: - return "UInt32"; - case Dtype::UInt64: - return "UInt64"; - case Dtype::UInt128: - return "UInt128"; - case Dtype::Void: - return "Void"; - default: - assert(false); - return "Unsupported type"; + case Dtype::Bfloat16: + return "Bfloat16"; + case Dtype::Bool: + return "Bool"; + case Dtype::E2m1: + return "E2m1"; + case Dtype::E2m3: + return "E2m3"; + case Dtype::E3m2: + return "E3m2"; + case Dtype::E4m3: + return "E4m3"; + case Dtype::E5m2: + return "E5m2"; + case Dtype::Fp16: + return "Fp16"; + case Dtype::Fp32: + return "Fp32"; + case Dtype::Int8: + return "Int8"; + case Dtype::Int32: + return "Int32"; + case Dtype::Int64: + return "Int64"; + case Dtype::MxE4m3: + return "MxE4m3"; + case Dtype::MxE2m1: + return "MxE2m1"; + case Dtype::UE8m0: + return "UE8m0"; + case Dtype::UInt8: + return "UInt8"; + case Dtype::UInt16: + return "UInt16"; + case Dtype::UInt32: + return "UInt32"; + case Dtype::UInt64: + return "UInt64"; + case Dtype::UInt128: + return "UInt128"; + case Dtype::Void: + return "Void"; + default: + assert(false); + return "Unsupported type"; } } @@ -186,12 +184,12 @@ inline std::string dtypeToString(Dtype dtype) { inline Dtype dtypeEltType(Dtype dtype) { switch (dtype) { - case Dtype::MxE2m1: - return Dtype::E2m1; - case Dtype::MxE4m3: - return Dtype::E4m3; - default: - return dtype; + case Dtype::MxE2m1: + return Dtype::E2m1; + case Dtype::MxE4m3: + return Dtype::E4m3; + default: + return dtype; } } @@ -199,14 +197,14 @@ inline Dtype dtypeEltType(Dtype dtype) { inline int dtypeNumEltsPerSf(Dtype dtype) { switch (dtype) { - case Dtype::E2m1: - return 16; - case Dtype::MxE2m1: - case Dtype::MxE4m3: - return 32; - default: - assert(false); - return -1; + case Dtype::E2m1: + return 16; + case Dtype::MxE2m1: + case Dtype::MxE4m3: + return 32; + default: + assert(false); + return -1; } } @@ -215,14 +213,14 @@ inline int dtypeNumEltsPerSf(Dtype dtype) { // Returns the dtype of scaling factors, if applicable. 
inline Dtype dtypeGetBlockSfType(Dtype dtype) { switch (dtype) { - case Dtype::E2m1: - return Dtype::E4m3; - case Dtype::MxE2m1: - case Dtype::MxE4m3: - return Dtype::UE8m0; - default: - assert(false); - return Dtype::Void; + case Dtype::E2m1: + return Dtype::E4m3; + case Dtype::MxE2m1: + case Dtype::MxE4m3: + return Dtype::UE8m0; + default: + assert(false); + return Dtype::Void; } } @@ -267,7 +265,7 @@ inline MmaKind dtypeGetMmaKind(Dtype dtypeA, Dtype dtypeB) { //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace gen -} // namespace trtllm +} // namespace gen +} // namespace trtllm -} // namespace batchedGemm +} // namespace batchedGemm diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/MmaDecl.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/MmaDecl.h index 7633090dc8..c8de154396 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/MmaDecl.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/MmaDecl.h @@ -1,19 +1,19 @@ /* -* SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & -* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ #pragma once #include @@ -21,9 +21,9 @@ #include #ifndef TLLM_GEN_EXPORT_INTERFACE #include "trtllm/gen/CommonUtils.h" -#else // TLLM_GEN_EXPORT_INTERFACE +#else // TLLM_GEN_EXPORT_INTERFACE #include "CommonUtils.h" -#endif // TLLM_GEN_EXPORT_INTERFACE +#endif // TLLM_GEN_EXPORT_INTERFACE namespace batchedGemm { @@ -73,23 +73,23 @@ inline bool mmaKindIsBlockFmt(MmaKind mmaKind) { // For logging and error reporting inline std::string mmaKindToString(MmaKind mmaKind) { switch (mmaKind) { - case MmaKind::Auto: - return "Auto"; - case MmaKind::Fp16: - return "Fp16"; - case MmaKind::Fp8Fp6Fp4: - return "Fp8Fp6Fp4"; - case MmaKind::Int8: - return "Int8"; - case MmaKind::MxFp4NvFp4: - return "MxFp4NvFp4"; - case MmaKind::MxFp8Fp6Fp4: - return "MxFp8Fp6Fp4"; - case MmaKind::Tf32: - return "Tf32"; - default: - assert(false); - return "Unsupported type"; + case MmaKind::Auto: + return "Auto"; + case MmaKind::Fp16: + return "Fp16"; + case MmaKind::Fp8Fp6Fp4: + return "Fp8Fp6Fp4"; + case MmaKind::Int8: + return "Int8"; + case MmaKind::MxFp4NvFp4: + return "MxFp4NvFp4"; + case MmaKind::MxFp8Fp6Fp4: + return "MxFp8Fp6Fp4"; + case MmaKind::Tf32: + return "Tf32"; + default: + assert(false); + return "Unsupported type"; } } @@ -104,7 +104,7 @@ inline int32_t getTmemColStridePerGroup(int32_t tileMn, int32_t mmaK) { //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace gen -} // namespace trtllm +} // namespace gen +} // namespace trtllm -} // namespace batchedGemm +} // namespace batchedGemm diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/SfLayoutDecl.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/SfLayoutDecl.h index c64105696e..965bb1b7b8 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/SfLayoutDecl.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/SfLayoutDecl.h @@ -1,19 +1,19 @@ /* -* SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & -* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ #pragma once #include @@ -74,23 +74,23 @@ enum class SfLayout { inline std::string sfLayoutToString(SfLayout layout) { switch (layout) { - case SfLayout::Linear: - return "linear"; - case SfLayout::R8c4: - return "8x4"; - case SfLayout::R8c16: - return "8x16"; - case SfLayout::R128c4: - return "128x4"; - default: - assert(false); - return "Unsupported layout"; + case SfLayout::Linear: + return "linear"; + case SfLayout::R8c4: + return "8x4"; + case SfLayout::R8c16: + return "8x16"; + case SfLayout::R128c4: + return "128x4"; + default: + assert(false); + return "Unsupported layout"; } } //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace gen -} // namespace trtllm +} // namespace gen +} // namespace trtllm -} // namespace batchedGemm +} // namespace batchedGemm diff --git a/tests/test_trtllm_gen_fused_moe.py b/tests/test_trtllm_gen_fused_moe.py index eaf7942c63..c1a317f49d 100644 --- a/tests/test_trtllm_gen_fused_moe.py +++ b/tests/test_trtllm_gen_fused_moe.py @@ -1028,7 +1028,7 @@ def prepare_static_weights_for_kernel( # Use shuffled weights with BlockMajorK layout for better performance use_shuffled_weight = weight_processing["use_shuffled_weight"] weight_layout = weight_processing["layout"] - + if use_shuffled_weight: # FIXME: this depends on the kernel internals epilogue_tile_m = 128 @@ -1037,29 +1037,37 @@ def prepare_static_weights_for_kernel( gemm1_weights_bf16_shuffled = [] gemm2_weights_bf16_shuffled = [] for i in range(num_experts): - tmp_weights1 = reorder_rows_for_gated_act_gemm(args.gemm1_weights[i].clone().view(torch.uint8)) - tmp_weights1 = shuffle_matrix_a( - tmp_weights1, epilogue_tile_m + tmp_weights1 = reorder_rows_for_gated_act_gemm( + args.gemm1_weights[i].clone().view(torch.uint8) ) + tmp_weights1 = shuffle_matrix_a(tmp_weights1, epilogue_tile_m) tmp_weights2 = shuffle_matrix_a( args.gemm2_weights[i].clone().view(torch.uint8), epilogue_tile_m ) if weight_layout == WeightLayout.BlockMajorK: block_k = 128 - tmp_weights1 = convert_to_block_layout(tmp_weights1.view(torch.uint8), block_k) - tmp_weights2 = convert_to_block_layout(tmp_weights2.view(torch.uint8), block_k) + tmp_weights1 = convert_to_block_layout( + tmp_weights1.view(torch.uint8), block_k + ) + tmp_weights2 = convert_to_block_layout( + tmp_weights2.view(torch.uint8), block_k + ) gemm1_weights_bf16_shuffled.append(tmp_weights1.view(torch.bfloat16)) gemm2_weights_bf16_shuffled.append(tmp_weights2.view(torch.bfloat16)) # Stack weights for all experts - gemm1_weights_bf16_shuffled = torch.stack(gemm1_weights_bf16_shuffled).view( - torch.bfloat16 - ).contiguous() - gemm2_weights_bf16_shuffled = torch.stack(gemm2_weights_bf16_shuffled).view( - torch.bfloat16 - ).contiguous() + gemm1_weights_bf16_shuffled = ( + torch.stack(gemm1_weights_bf16_shuffled) + .view(torch.bfloat16) + .contiguous() + ) + gemm2_weights_bf16_shuffled = ( + torch.stack(gemm2_weights_bf16_shuffled) + .view(torch.bfloat16) + .contiguous() + ) return { "gemm1_weights": gemm1_weights_bf16_shuffled, From baab3ac6f3c90fd9dd4a4f174ba4eeb548bd89b9 Mon Sep 17 00:00:00 2001 From: Alex Yang Date: Sat, 4 Oct 2025 02:01:42 +0000 Subject: [PATCH 07/12] upgrade to tvm ffi --- csrc/trtllm_fused_moe_kernel_launcher.cu | 344 +++++++++++------------ 1 file changed, 165 insertions(+), 179 deletions(-) diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index 0bd23584de..76dc9ce294 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ 
b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -65,14 +65,14 @@ Driver calls take place to carry out the gemm operations. class FusedMoeLauncher { protected: - at::Tensor const* routing_logits{}; - at::Tensor const* routing_bias{}; - at::Tensor const* hidden_states{}; - at::Tensor const* gemm1_weights{}; - at::Tensor const* output1_scales_scalar{}; - at::Tensor const* output1_scales_gate_scalar{}; - at::Tensor const* gemm2_weights{}; - at::Tensor const* output2_scales_scalar{}; + Tensor const* routing_logits{}; + Tensor const* routing_bias{}; + Tensor const* hidden_states{}; + Tensor const* gemm1_weights{}; + Tensor const* output1_scales_scalar{}; + Tensor const* output1_scales_gate_scalar{}; + Tensor const* gemm2_weights{}; + Tensor const* output2_scales_scalar{}; int64_t tile_tokens_dim{}; int64_t routing_method_type{}; @@ -88,84 +88,85 @@ class FusedMoeLauncher { GatedActType gated_act_type{GatedActType::SwiGlu}; // Initialize common data necessary for later. - // May throw exception from TORCH_CHECK. - void init_common(at::Tensor const* routing_logits, at::Tensor const* routing_bias, - at::Tensor const* hidden_states, at::Tensor const* gemm1_weights, - at::Tensor const* gemm2_weights, + // May throw exception from TVM_FFI_ICHECK. + void init_common(Tensor const* routing_logits, Tensor const* routing_bias, + Tensor const* hidden_states, Tensor const* gemm1_weights, + Tensor const* gemm2_weights, std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout, int64_t gated_act_type); // Routing logits [num_tokens, num_experts] void check_routing_logits_shape() const { - TORCH_CHECK(routing_logits->dim() == 2, "routing_logits must be 2D."); - TORCH_CHECK(routing_logits->sizes()[0] == hidden_states->sizes()[0], - "routing_logits and hidden_states must have the same number of tokens."); - TORCH_CHECK(routing_logits->sizes()[1] == args->num_experts, - "routing_logits dim1 must match num_experts."); + TVM_FFI_ICHECK_EQ(routing_logits->ndim, 2) << "routing_logits must be 2D."; + TVM_FFI_ICHECK_EQ(routing_logits->shape[0], hidden_states->shape[0]) + << "routing_logits and hidden_states must have the same number of tokens."; + TVM_FFI_ICHECK_EQ(routing_logits->shape[1], args->num_experts) + << "routing_logits dim1 must match num_experts."; } // Routing bias [num_experts] void check_routing_bias_shape() const { if (routing_bias != nullptr) { - TORCH_CHECK(routing_bias->dim() == 1, "routing_bias must be 1D."); - TORCH_CHECK(routing_bias->sizes()[0] == args->num_experts, - "routing_bias has incorrect shape."); + TVM_FFI_ICHECK_EQ(routing_bias->ndim, 1) << "routing_bias must be 1D."; + TVM_FFI_ICHECK_EQ(routing_bias->shape[0], args->num_experts) + << "routing_bias has incorrect shape."; } } // Hidden states [num_tokens, hidden_size] void check_hidden_states_shape() const { - TORCH_CHECK(hidden_states->dim() == 2, "hidden_states must be 2D."); - TORCH_CHECK(hidden_states->sizes()[1] == args->intermediate_size, - "hidden_states has incorrect shape."); + TVM_FFI_ICHECK_EQ(hidden_states->ndim, 2) << "hidden_states must be 2D."; + TVM_FFI_ICHECK_EQ(hidden_states->shape[1], args->intermediate_size) + << "hidden_states has incorrect shape."; } // GEMM1 or GEMM2 weights [num_experts, M, K] or [num_experts, K/block_k, M, block_k] void check_weights_shape(std::string which_weights) const { - at::Tensor const* weights{}; + Tensor const* weights{}; if (which_weights == "gemm1") { weights = gemm1_weights; } else if (which_weights == "gemm2") { weights 
= gemm2_weights; } else { - TORCH_CHECK(false, "Internal error: which_weights = ", which_weights); + TVM_FFI_LOG_AND_THROW(InternalError) << "Internal error: which_weights = " << which_weights; } int64_t Mn = 0, K = 0; if (weight_layout == MatrixLayout::MajorK) { // MajorK [num_experts, M, K] - Mn = weights->sizes()[1]; - K = weights->sizes()[2]; + Mn = weights->shape[1]; + K = weights->shape[2]; } else if (weight_layout == MatrixLayout::BlockMajorK) { // BlockMajorK [num_experts, K/block_k, M, block_k] - Mn = weights->sizes()[2]; - int64_t block_k = weights->sizes()[3]; - K = weights->sizes()[1] * block_k; + Mn = weights->shape[2]; + int64_t block_k = weights->shape[3]; + K = weights->shape[1] * block_k; } else { - TORCH_CHECK(false, "Unsupported weight_layout: ", weight_layout); + TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported weight_layout: " << weight_layout; } - TORCH_CHECK(weights->sizes()[0] == args->num_experts, - which_weights + " weights expert dimension must match num_experts"); + TVM_FFI_ICHECK_EQ(weights->shape[0], args->num_experts) + << which_weights << " weights expert dimension must match num_experts"; if (which_weights == "gemm1") { - TORCH_CHECK(Mn % 2 == 0, which_weights + " weights Mn dimension must be even."); - TORCH_CHECK(args->intermediate_size == Mn / 2, "intermediate_size has incorrect shape."); - TORCH_CHECK(K == hidden_states->sizes()[1], - which_weights + " weights K dimension must be equal to hidden_size."); + TVM_FFI_ICHECK_EQ(Mn % 2, 0) << which_weights << " weights Mn dimension must be even."; + TVM_FFI_ICHECK_EQ(args->intermediate_size, Mn / 2) + << "intermediate_size has incorrect shape."; + TVM_FFI_ICHECK_EQ(K, hidden_states->shape[1]) + << which_weights << " weights K dimension must be equal to hidden_size."; } else if (which_weights == "gemm2") { - TORCH_CHECK(K == args->intermediate_size, - which_weights + " weights K dimension must be equal to intermediate_size."); + TVM_FFI_ICHECK_EQ(K, args->intermediate_size) + << which_weights << " weights K dimension must be equal to intermediate_size."; } } void check_routing_common() const { - TORCH_CHECK(args->top_k > 0 && args->top_k <= args->num_experts, - "top_k must be between 1 and num_experts"); - TORCH_CHECK(args->local_num_experts > 0 && args->local_num_experts <= args->num_experts, - "local_num_experts must be between 1 and num_experts"); - TORCH_CHECK(args->local_expert_offset >= 0 && - args->local_expert_offset + args->local_num_experts <= args->num_experts, - "expert offset and count must be within valid range"); + TVM_FFI_ICHECK(args->top_k > 0 && args->top_k <= args->num_experts) + << "top_k must be between 1 and num_experts"; + TVM_FFI_ICHECK(args->local_num_experts > 0 && args->local_num_experts <= args->num_experts) + << "local_num_experts must be between 1 and num_experts"; + TVM_FFI_ICHECK(args->local_expert_offset >= 0 && + args->local_expert_offset + args->local_num_experts <= args->num_experts) + << "expert offset and count must be within valid range"; check_routing_logits_shape(); @@ -175,16 +176,16 @@ class FusedMoeLauncher { } // Routing phase workspace tensors (allocated in prepare_routing() or prepare_routing_common()) - at::Tensor num_tokens_per_expert; - at::Tensor total_num_padded_tokens; - at::Tensor expanded_idx_to_permuted_idx; - at::Tensor permuted_idx_to_token_idx; - at::Tensor expert_weights; - at::Tensor expert_indexes; - at::Tensor expert_count_histogram; - at::Tensor cta_idx_xy_to_batch_idx; - at::Tensor cta_idx_xy_to_mn_limit; - at::Tensor 
num_non_exiting_ctas; + Tensor num_tokens_per_expert; + Tensor total_num_padded_tokens; + Tensor expanded_idx_to_permuted_idx; + Tensor permuted_idx_to_token_idx; + Tensor expert_weights; + Tensor expert_indexes; + Tensor expert_count_histogram; + Tensor cta_idx_xy_to_batch_idx; + Tensor cta_idx_xy_to_mn_limit; + Tensor num_non_exiting_ctas; void prepare_routing_common() { // Allocate routing phase workspace tensors @@ -193,68 +194,61 @@ class FusedMoeLauncher { args->num_tokens, args->top_k, args->num_experts, tile_tokens_dim); // Common routing workspace tensors allocation - num_tokens_per_expert = at::detail::empty_cuda({args->num_experts}, at::ScalarType::Int, - routing_logits->device(), std::nullopt); + num_tokens_per_expert = alloc_tensor({args->num_experts}, dl_int32, routing_logits->device); - total_num_padded_tokens = at::empty( - {}, at::TensorOptions().device(routing_logits->device()).dtype(at::ScalarType::Int)); + total_num_padded_tokens = alloc_tensor({1}, dl_int32, routing_logits->device); expanded_idx_to_permuted_idx = - at::detail::empty_cuda({args->num_tokens * args->top_k}, at::ScalarType::Int, - routing_logits->device(), std::nullopt); + alloc_tensor({args->num_tokens * args->top_k}, dl_int32, routing_logits->device); - permuted_idx_to_token_idx = at::detail::empty_cuda({max_num_padded_tokens}, at::ScalarType::Int, - routing_logits->device(), std::nullopt); + permuted_idx_to_token_idx = + alloc_tensor({max_num_padded_tokens}, dl_int32, routing_logits->device); - expert_indexes = at::detail::empty_cuda({args->num_tokens, args->top_k}, at::ScalarType::Int, - routing_logits->device(), std::nullopt); + expert_indexes = + alloc_tensor({args->num_tokens, args->top_k}, dl_int32, routing_logits->device); // expert_weights allocation should be done by derived class since data type could vary int64_t const size_of_expert_count_histogram = std::max(args->num_experts * 2, 256 * 2); - expert_count_histogram = - at::detail::empty_cuda({size_of_expert_count_histogram}, - at::ScalarType::Int, // 256 is the max number of threads per block + expert_count_histogram = alloc_tensor({size_of_expert_count_histogram}, + dl_int32, // 256 is the max number of threads per block // and max number of experts - routing_logits->device(), std::nullopt); + routing_logits->device); int32_t max_num_ctas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCtasInBatchDim( args->num_tokens, args->top_k, args->num_experts, tile_tokens_dim); - cta_idx_xy_to_batch_idx = at::detail::empty_cuda({max_num_ctas}, at::ScalarType::Int, - routing_logits->device(), std::nullopt); + cta_idx_xy_to_batch_idx = alloc_tensor({max_num_ctas}, dl_int32, routing_logits->device); - cta_idx_xy_to_mn_limit = at::detail::empty_cuda({max_num_ctas}, at::ScalarType::Int, - routing_logits->device(), std::nullopt); + cta_idx_xy_to_mn_limit = alloc_tensor({max_num_ctas}, dl_int32, routing_logits->device); - num_non_exiting_ctas = at::empty( - {}, at::TensorOptions().device(routing_logits->device()).dtype(at::ScalarType::Int)); + num_non_exiting_ctas = alloc_tensor({1}, dl_int32, routing_logits->device); - workspace.total_num_padded_tokens = total_num_padded_tokens.data_ptr(); + workspace.total_num_padded_tokens = static_cast(total_num_padded_tokens->data); workspace.total_max_padded_tokens = max_num_padded_tokens; workspace.ProjUpTileN = tile_tokens_dim; - workspace.routing_expert_indexes = expert_indexes.data_ptr(); - workspace.permuted_idx_size = total_num_padded_tokens.data_ptr(); - workspace.expanded_idx_to_permuted_idx = 
expanded_idx_to_permuted_idx.data_ptr(); - workspace.permuted_idx_to_token_idx = permuted_idx_to_token_idx.data_ptr(); + workspace.routing_expert_indexes = static_cast(expert_indexes->data); + workspace.permuted_idx_size = static_cast(total_num_padded_tokens->data); + workspace.expanded_idx_to_permuted_idx = static_cast(expanded_idx_to_permuted_idx->data); + workspace.permuted_idx_to_token_idx = static_cast(permuted_idx_to_token_idx->data); // workspace.expert_weights will be set by derived class after expert_weights allocation - workspace.cta_idx_xy_to_batch_idx = cta_idx_xy_to_batch_idx.data_ptr(); - workspace.cta_idx_xy_to_mn_limit = cta_idx_xy_to_mn_limit.data_ptr(); - workspace.num_non_exiting_ctas = num_non_exiting_ctas.data_ptr(); + workspace.cta_idx_xy_to_batch_idx = static_cast(cta_idx_xy_to_batch_idx->data); + workspace.cta_idx_xy_to_mn_limit = static_cast(cta_idx_xy_to_mn_limit->data); + workspace.num_non_exiting_ctas = static_cast(num_non_exiting_ctas->data); } void check_moe_common() const { // Hidden states [num_tokens, hidden_size] - TORCH_CHECK(hidden_states->dim() == 2, "hidden_states must be 2D."); + TVM_FFI_ICHECK_EQ(hidden_states->ndim, 2) << "hidden_states must be 2D."; } // MoE computation phase workspace tensors (allocated in prepare_moe() or prepare_moe_common()) - at::Tensor gemm1_output; - at::Tensor activation_output; - at::Tensor gemm2_output; - at::Tensor workspace_fc1; - at::Tensor workspace_fc2; - at::Tensor output; + Tensor gemm1_output; + Tensor activation_output; + Tensor gemm2_output; + Tensor workspace_fc1; + Tensor workspace_fc2; + Tensor output; int64_t moe_tactic{-1}; std::unique_ptr moe_runner; @@ -273,12 +267,10 @@ class FusedMoeLauncher { this->moe_tactic = moe_tactic; auto workspace_sizes = moe_runner->getWorkspaceSizeInBytes(*args, moe_tactic); - workspace_fc1 = at::detail::empty_cuda({std::get<0>(workspace_sizes)}, at::ScalarType::Char, - hidden_states->device(), std::nullopt); - workspace_fc2 = at::detail::empty_cuda({std::get<1>(workspace_sizes)}, at::ScalarType::Char, - hidden_states->device(), std::nullopt); - workspace.bmm1_workspace = workspace_fc1.data_ptr(); - workspace.bmm2_workspace = workspace_fc2.data_ptr(); + workspace_fc1 = alloc_tensor({std::get<0>(workspace_sizes)}, dl_int8, hidden_states->device); + workspace_fc2 = alloc_tensor({std::get<1>(workspace_sizes)}, dl_int8, hidden_states->device); + workspace.bmm1_workspace = workspace_fc1->data; + workspace.bmm2_workspace = workspace_fc2->data; } public: @@ -290,35 +282,36 @@ class FusedMoeLauncher { // Main entry point for all the executions. // Do initializations prior to calling this as the initializations are different for bf16, fp8 and // fp4. The executions are non-blocking by default. 
- std::vector run(int64_t moe_tactic, bool enable_pdl = true) { + Array run(int64_t moe_tactic, bool enable_pdl = true) { check_routing(); prepare_routing(); // Execute routing tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim); - int routing_device = routing_logits->get_device(); - auto const& routing_stream = at::cuda::getCurrentCUDAStream(routing_device); - routing_runner.run( - routing_logits->data_ptr(), args->routing_bias, args->num_tokens, args->num_experts, - args->top_k, args->n_group, args->topk_group, args->local_expert_offset, - args->local_num_experts, args->routed_scaling_factor, expert_indexes.data_ptr(), - expert_count_histogram.data_ptr(), total_num_padded_tokens.data_ptr(), - expanded_idx_to_permuted_idx.data_ptr(), - nullptr /*permuted_idx_to_expanded_idx.data_ptr()*/, - permuted_idx_to_token_idx.data_ptr(), expert_weights.data_ptr(), - num_tokens_per_expert.data_ptr(), cta_idx_xy_to_batch_idx.data_ptr(), - cta_idx_xy_to_mn_limit.data_ptr(), num_non_exiting_ctas.data_ptr(), - args->mDtypeElt, false, true, static_cast(routing_method_type), - routing_stream); + cudaStream_t routing_stream = get_stream(routing_logits->device); + routing_runner.run(routing_logits->data, args->routing_bias, args->num_tokens, + args->num_experts, args->top_k, args->n_group, args->topk_group, + args->local_expert_offset, args->local_num_experts, + args->routed_scaling_factor, static_cast(expert_indexes->data), + static_cast(expert_count_histogram->data), + static_cast(total_num_padded_tokens->data), + static_cast(expanded_idx_to_permuted_idx->data), + nullptr /*permuted_idx_to_expanded_idx->data*/, + static_cast(permuted_idx_to_token_idx->data), expert_weights->data, + static_cast(num_tokens_per_expert->data), + static_cast(cta_idx_xy_to_batch_idx->data), + static_cast(cta_idx_xy_to_mn_limit->data), + static_cast(num_non_exiting_ctas->data), args->mDtypeElt, false, true, + static_cast(routing_method_type), routing_stream); check_moe(); // if moe_tactic is -1, it will be set to the default valid config index prepare_moe(moe_tactic); // Execute MoE - int moe_device = hidden_states->get_device(); - auto const& moe_stream = at::cuda::getCurrentCUDAStream(moe_device); - moe_runner->run(*args, workspace, moe_device, moe_stream, moe_tactic, enable_pdl); + cudaStream_t moe_stream = get_stream(hidden_states->device); + moe_runner->run(*args, workspace, hidden_states->device.device_id, moe_stream, moe_tactic, + enable_pdl); if (args->do_finalize) { return {output}; @@ -328,20 +321,19 @@ class FusedMoeLauncher { }; void FusedMoeLauncher::init_common( - at::Tensor const* routing_logits, at::Tensor const* routing_bias, - at::Tensor const* hidden_states, at::Tensor const* gemm1_weights, - at::Tensor const* gemm2_weights, + Tensor const* routing_logits, Tensor const* routing_bias, Tensor const* hidden_states, + Tensor const* gemm1_weights, Tensor const* gemm2_weights, std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout, int64_t gated_act_type) { // Check devicearchitecture: Blackwell (SM 10.x) required - TORCH_CHECK(hidden_states != nullptr, "hidden_states is required"); - auto device = hidden_states->device().index(); + TVM_FFI_ICHECK(hidden_states != nullptr) << "hidden_states is required"; + auto device = hidden_states->device.device_id; int major = 0, minor = 0; cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device); cudaDeviceGetAttribute(&minor, 
cudaDevAttrComputeCapabilityMinor, device); - TORCH_CHECK(major == 10, "BF16 MoE requires 10.x architecture. Current device has SM ", major, - minor); + TVM_FFI_ICHECK_EQ(major, 10) << "BF16 MoE requires 10.x architecture. Current device has SM " + << major << minor; this->device_version = std::make_tuple(major, minor); this->routing_logits = routing_logits; @@ -350,21 +342,21 @@ void FusedMoeLauncher::init_common( this->gemm1_weights = gemm1_weights; this->gemm2_weights = gemm2_weights; - args->routing_logits = routing_logits->data_ptr(); - args->routing_bias = routing_bias ? routing_bias->data_ptr() : nullptr; - args->hidden_states = hidden_states->data_ptr(); - args->gemm1_weights = gemm1_weights->data_ptr(); - args->gemm2_weights = gemm2_weights->data_ptr(); + args->routing_logits = routing_logits->data; + args->routing_bias = routing_bias ? routing_bias->data : nullptr; + args->hidden_states = hidden_states->data; + args->gemm1_weights = gemm1_weights->data; + args->gemm2_weights = gemm2_weights->data; this->args = std::move(args); this->tile_tokens_dim = tile_tokens_dim; this->routing_method_type = routing_method_type; this->use_shuffled_weight = use_shuffled_weight; - TORCH_CHECK(0 <= weight_layout && weight_layout <= 2, - "the value of weight_layout is not recognized"); + TVM_FFI_ICHECK(0 <= weight_layout && weight_layout <= 2) + << "the value of weight_layout is not recognized"; this->weight_layout = static_cast(weight_layout); - TORCH_CHECK(0 <= gated_act_type && gated_act_type <= 1, - "the value of gated_act_type is not recognized"); + TVM_FFI_ICHECK(0 <= gated_act_type && gated_act_type <= 1) + << "the value of gated_act_type is not recognized"; this->gated_act_type = static_cast(gated_act_type); } @@ -372,9 +364,8 @@ class Bf16MoeLauncher : public FusedMoeLauncher { public: Bf16MoeLauncher() = default; - void init(at::Tensor const& routing_logits, std::optional const& routing_bias, - at::Tensor const& hidden_states, at::Tensor const& gemm1_weights, - at::Tensor const& gemm2_weights, + void init(Tensor const& routing_logits, Optional const& routing_bias, + Tensor const& hidden_states, Tensor const& gemm1_weights, Tensor const& gemm2_weights, std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout) { @@ -401,23 +392,23 @@ class Bf16MoeLauncher : public FusedMoeLauncher { args->mDtypeExpW = btg::Dtype::Bfloat16; args->mUseDeepSeekFp8 = false; - auto const routing_bias_dtype = at::ScalarType::BFloat16; - expert_weights = at::detail::empty_cuda({args->num_tokens, args->top_k}, routing_bias_dtype, - routing_logits->device(), std::nullopt); + auto const routing_bias_dtype = dl_bfloat16; + expert_weights = + alloc_tensor({args->num_tokens, args->top_k}, routing_bias_dtype, routing_logits->device); - workspace.expert_weights = expert_weights.data_ptr(); + workspace.expert_weights = expert_weights->data; } void check_moe() const override { FusedMoeLauncher::check_moe_common(); - TORCH_CHECK(weight_layout == MatrixLayout::BlockMajorK, - "BF16 Moe: weight_layout must be BlockMajorK"); + TVM_FFI_ICHECK_EQ(weight_layout, MatrixLayout::BlockMajorK) + << "BF16 Moe: weight_layout must be BlockMajorK"; check_weights_shape("gemm1"); check_weights_shape("gemm2"); - TORCH_CHECK(args->intermediate_size % 128 == 0, - "the second dimension of weights must be a multiple of 128."); + TVM_FFI_ICHECK_EQ(args->intermediate_size % 128, 0) + << "the second dimension of weights must be a multiple of 128."; } void prepare_moe(int64_t& 
moe_tactic) override { @@ -426,60 +417,55 @@ class Bf16MoeLauncher : public FusedMoeLauncher { FusedMoeLauncher::prepare_moe_common(moe_tactic); int32_t max_num_padded_tokens = workspace.total_max_padded_tokens; - gemm1_output = - at::detail::empty_cuda({max_num_padded_tokens, args->intermediate_size}, - at::ScalarType::BFloat16, hidden_states->device(), std::nullopt); - activation_output = - at::detail::empty_cuda({max_num_padded_tokens, args->intermediate_size}, - at::ScalarType::BFloat16, hidden_states->device(), std::nullopt); - gemm2_output = - at::detail::empty_cuda({max_num_padded_tokens, args->hidden_size}, at::ScalarType::BFloat16, - hidden_states->device(), std::nullopt); + gemm1_output = alloc_tensor({max_num_padded_tokens, args->intermediate_size}, dl_bfloat16, + hidden_states->device); + activation_output = alloc_tensor({max_num_padded_tokens, args->intermediate_size}, dl_bfloat16, + hidden_states->device); + gemm2_output = alloc_tensor({max_num_padded_tokens, args->hidden_size}, dl_bfloat16, + hidden_states->device); workspace.hidden_states_scale_linear = nullptr; - workspace.gemm1_output = gemm1_output.data_ptr(); + workspace.gemm1_output = gemm1_output->data; workspace.gemm1_output_scale = nullptr; // BF16 doesn't use scale tensors - workspace.activation_output = activation_output.data_ptr(); + workspace.activation_output = activation_output->data; workspace.activation_output_scale = nullptr; // BF16 doesn't use scale tensors - workspace.gemm2_output = gemm2_output.data_ptr(); + workspace.gemm2_output = gemm2_output->data; workspace.gemm2_output_scale = nullptr; - output = at::detail::empty_cuda({args->num_tokens, args->hidden_size}, at::ScalarType::BFloat16, - hidden_states->device(), std::nullopt); - args->output = output.data_ptr(); + output = + alloc_tensor({args->num_tokens, args->hidden_size}, dl_bfloat16, hidden_states->device); + args->output = output->data; args->output_scale = nullptr; } }; -at::Tensor trtllm_bf16_moe(at::Tensor const& routing_logits, - std::optional const& routing_bias, - at::Tensor const& hidden_states, at::Tensor const& gemm1_weights, - at::Tensor const& gemm2_weights, int64_t num_experts, int64_t top_k, - int64_t n_group, int64_t topk_group, int64_t intermediate_size, - int64_t local_expert_offset, int64_t local_num_experts, - int64_t tile_tokens_dim, int64_t routing_method_type, - bool use_shuffled_weight, int64_t weight_layout, int64_t moe_tactic, - bool enable_pdl) { +Tensor trtllm_bf16_moe(Tensor const& routing_logits, Optional const& routing_bias, + Tensor const& hidden_states, Tensor const& gemm1_weights, + Tensor const& gemm2_weights, int64_t num_experts, int64_t top_k, + int64_t n_group, int64_t topk_group, int64_t intermediate_size, + int64_t local_expert_offset, int64_t local_num_experts, + int64_t tile_tokens_dim, int64_t routing_method_type, + bool use_shuffled_weight, int64_t weight_layout, int64_t moe_tactic, + bool enable_pdl) { // Just some basic type validation first and leave more checks to the launcher - TORCH_CHECK(routing_logits.scalar_type() == at::ScalarType::Float || - routing_logits.scalar_type() == at::ScalarType::BFloat16, - "BF16 MoE: routing_logits must be bfoat16 or float."); + TVM_FFI_ICHECK(routing_logits->dtype == dl_float32 || routing_logits->dtype == dl_bfloat16) + << "BF16 MoE: routing_logits must be bfloat16 or float."; if (routing_bias.has_value()) { - TORCH_CHECK(routing_bias.value().scalar_type() == at::ScalarType::BFloat16, - "BF16 MoE: routing_bias must be bfloat16."); + 
TVM_FFI_ICHECK_EQ(routing_bias.value()->dtype, dl_bfloat16) + << "BF16 MoE: routing_bias must be bfloat16."; } - TORCH_CHECK(hidden_states.scalar_type() == at::ScalarType::BFloat16, - "BF16 MoE: hidden_states must be bfloat16."); - TORCH_CHECK(gemm1_weights.scalar_type() == at::ScalarType::BFloat16, - "BF16 MoE: gemm1_weights must be bfloat16."); - TORCH_CHECK(gemm2_weights.scalar_type() == at::ScalarType::BFloat16, - "BF16 MoE: gemm2_weights must be bfloat16."); + TVM_FFI_ICHECK_EQ(hidden_states->dtype, dl_bfloat16) + << "BF16 MoE: hidden_states must be bfloat16."; + TVM_FFI_ICHECK_EQ(gemm1_weights->dtype, dl_bfloat16) + << "BF16 MoE: gemm1_weights must be bfloat16."; + TVM_FFI_ICHECK_EQ(gemm2_weights->dtype, dl_bfloat16) + << "BF16 MoE: gemm2_weights must be bfloat16."; // Save params to MoE arguments auto args = std::make_unique(); - args->num_tokens = hidden_states.sizes()[0]; + args->num_tokens = hidden_states->shape[0]; args->num_experts = num_experts; - args->hidden_size = hidden_states.sizes()[1]; + args->hidden_size = hidden_states->shape[1]; args->hidden_size_output = args->hidden_size; args->top_k = top_k; args->n_group = n_group; From 69e64b0e560702afdb92f920817ded8e6729c7f2 Mon Sep 17 00:00:00 2001 From: Alex Yang Date: Sat, 4 Oct 2025 02:07:27 +0000 Subject: [PATCH 08/12] .. --- csrc/trtllm_fused_moe_kernel_launcher.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index 76dc9ce294..4e94f4d2d7 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -35,6 +35,8 @@ namespace btg = batchedGemm::trtllm::gen; using batchedGemm::gemm::MatrixLayout; using tensorrt_llm::kernels::trtllmgen_moe::MoE::GatedActType; using tensorrt_llm::kernels::trtllmgen_moe::Routing::RoutingMethodType; +using tvm::ffi::Array; +using tvm::ffi::Optional; /* @@ -481,8 +483,6 @@ Tensor trtllm_bf16_moe(Tensor const& routing_logits, Optional const& rou auto data = launcher.run(moe_tactic, enable_pdl)[0]; return data; } -using tvm::ffi::Array; -using tvm::ffi::Optional; Tensor trtllm_fp8_per_tensor_scale_moe_launcher( Tensor routing_logits, Optional routing_bias, Tensor hidden_states, From 49a1a734183b91a42ca921388b884fa3dfbd5c12 Mon Sep 17 00:00:00 2001 From: Alex Yang Date: Sat, 4 Oct 2025 02:27:22 +0000 Subject: [PATCH 09/12] .. --- csrc/trtllm_fused_moe_kernel_launcher.cu | 38 +++++++++++------------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index 4e94f4d2d7..bc46a35eb8 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -67,14 +67,14 @@ Driver calls take place to carry out the gemm operations. class FusedMoeLauncher { protected: - Tensor const* routing_logits{}; - Tensor const* routing_bias{}; - Tensor const* hidden_states{}; - Tensor const* gemm1_weights{}; - Tensor const* output1_scales_scalar{}; - Tensor const* output1_scales_gate_scalar{}; - Tensor const* gemm2_weights{}; - Tensor const* output2_scales_scalar{}; + Tensor routing_logits{}; + Tensor routing_bias{}; + Tensor hidden_states{}; + Tensor gemm1_weights{}; + Tensor output1_scales_scalar{}; + Tensor output1_scales_gate_scalar{}; + Tensor gemm2_weights{}; + Tensor output2_scales_scalar{}; int64_t tile_tokens_dim{}; int64_t routing_method_type{}; @@ -91,9 +91,9 @@ class FusedMoeLauncher { // Initialize common data necessary for later. 
// May throw exception from TVM_FFI_ICHECK. - void init_common(Tensor const* routing_logits, Tensor const* routing_bias, - Tensor const* hidden_states, Tensor const* gemm1_weights, - Tensor const* gemm2_weights, + void init_common(Tensor const& routing_logits, Optional const& routing_bias, + Tensor const& hidden_states, Tensor const& gemm1_weights, + Tensor const& gemm2_weights, std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout, int64_t gated_act_type); @@ -109,7 +109,7 @@ class FusedMoeLauncher { // Routing bias [num_experts] void check_routing_bias_shape() const { - if (routing_bias != nullptr) { + if (routing_bias.defined()) { TVM_FFI_ICHECK_EQ(routing_bias->ndim, 1) << "routing_bias must be 1D."; TVM_FFI_ICHECK_EQ(routing_bias->shape[0], args->num_experts) << "routing_bias has incorrect shape."; @@ -125,7 +125,7 @@ class FusedMoeLauncher { // GEMM1 or GEMM2 weights [num_experts, M, K] or [num_experts, K/block_k, M, block_k] void check_weights_shape(std::string which_weights) const { - Tensor const* weights{}; + Tensor weights{}; if (which_weights == "gemm1") { weights = gemm1_weights; } else if (which_weights == "gemm2") { @@ -323,13 +323,12 @@ class FusedMoeLauncher { }; void FusedMoeLauncher::init_common( - Tensor const* routing_logits, Tensor const* routing_bias, Tensor const* hidden_states, - Tensor const* gemm1_weights, Tensor const* gemm2_weights, + Tensor const& routing_logits, Optional const& routing_bias, Tensor const& hidden_states, + Tensor const& gemm1_weights, Tensor const& gemm2_weights, std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout, int64_t gated_act_type) { // Check devicearchitecture: Blackwell (SM 10.x) required - TVM_FFI_ICHECK(hidden_states != nullptr) << "hidden_states is required"; auto device = hidden_states->device.device_id; int major = 0, minor = 0; cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device); @@ -345,7 +344,7 @@ void FusedMoeLauncher::init_common( this->gemm2_weights = gemm2_weights; args->routing_logits = routing_logits->data; - args->routing_bias = routing_bias ? routing_bias->data : nullptr; + args->routing_bias = routing_bias.has_value() ? routing_bias.value()->data : nullptr; args->hidden_states = hidden_states->data; args->gemm1_weights = gemm1_weights->data; args->gemm2_weights = gemm2_weights->data; @@ -376,9 +375,8 @@ class Bf16MoeLauncher : public FusedMoeLauncher { // Do base class init and perform common checks FusedMoeLauncher::init_common( - &routing_logits, routing_bias.has_value() ? &routing_bias.value() : nullptr, &hidden_states, - &gemm1_weights, &gemm2_weights, std::move(args), tile_tokens_dim, routing_method_type, - use_shuffled_weight, weight_layout, gated_act_type); + routing_logits, routing_bias, hidden_states, gemm1_weights, gemm2_weights, std::move(args), + tile_tokens_dim, routing_method_type, use_shuffled_weight, weight_layout, gated_act_type); } void check_routing() const override { From 0da156625f67ff7fe38a07fd06427bab7d78d2d2 Mon Sep 17 00:00:00 2001 From: Alex Yang Date: Sat, 4 Oct 2025 02:35:15 +0000 Subject: [PATCH 10/12] .. 
--- csrc/trtllm_fused_moe_kernel_launcher.cu | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index bc46a35eb8..fff8dcc2b6 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -145,7 +145,8 @@ class FusedMoeLauncher { int64_t block_k = weights->shape[3]; K = weights->shape[1] * block_k; } else { - TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported weight_layout: " << weight_layout; + TVM_FFI_LOG_AND_THROW(NotImplementedError) + << "Unsupported weight_layout: " << (int)weight_layout; } TVM_FFI_ICHECK_EQ(weights->shape[0], args->num_experts) << which_weights << " weights expert dimension must match num_experts"; @@ -172,7 +173,7 @@ class FusedMoeLauncher { check_routing_logits_shape(); - if (routing_bias) { + if (routing_bias.has_value()) { check_routing_bias_shape(); } } @@ -338,7 +339,9 @@ void FusedMoeLauncher::init_common( this->device_version = std::make_tuple(major, minor); this->routing_logits = routing_logits; - this->routing_bias = routing_bias; + if (routing_bias.has_value()) { + this->routing_bias = routing_bias.value(); + } this->hidden_states = hidden_states; this->gemm1_weights = gemm1_weights; this->gemm2_weights = gemm2_weights; @@ -735,9 +738,9 @@ Tensor trtllm_fp8_per_tensor_scale_moe( n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, use_routing_scales_on_input, tile_tokens_dim, routing_method_type, enable_pdl); - } else { - TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input dtype."; } + + TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input dtype."; } void trtllm_fp8_block_scale_moe_launcher( From a60c01e55f5836779596d24a2b09475322fa5b80 Mon Sep 17 00:00:00 2001 From: Alex Yang Date: Sat, 4 Oct 2025 02:37:03 +0000 Subject: [PATCH 11/12] .. --- csrc/trtllm_fused_moe_kernel_launcher.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index fff8dcc2b6..0c1b876f07 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -173,7 +173,7 @@ class FusedMoeLauncher { check_routing_logits_shape(); - if (routing_bias.has_value()) { + if (routing_bias.defined()) { check_routing_bias_shape(); } } From ef08cdd592ca9fc836a753eadb6aa04732841f26 Mon Sep 17 00:00:00 2001 From: Alex Yang Date: Sat, 4 Oct 2025 02:43:25 +0000 Subject: [PATCH 12/12] .. --- csrc/trtllm_fused_moe_kernel_launcher.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index 0c1b876f07..25b938b11b 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -405,7 +405,7 @@ class Bf16MoeLauncher : public FusedMoeLauncher { void check_moe() const override { FusedMoeLauncher::check_moe_common(); - TVM_FFI_ICHECK_EQ(weight_layout, MatrixLayout::BlockMajorK) + TVM_FFI_ICHECK(weight_layout == MatrixLayout::BlockMajorK) << "BF16 Moe: weight_layout must be BlockMajorK"; check_weights_shape("gemm1"); check_weights_shape("gemm2");
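For reference, the tensor-validation idiom that patches 07/12 through 12/12 converge on can be summarized outside the diff as follows. This is a minimal sketch, not part of the patch series: it reuses only constructs that already appear in the hunks above (tvm::ffi::Tensor field access via ->ndim / ->shape / ->dtype, the TVM_FFI_ICHECK* macros, and the dl_bfloat16 dtype constant). The check_2d_bf16 helper name and the standalone framing are illustrative assumptions, and the includes are assumed to be the ones already present in trtllm_fused_moe_kernel_launcher.cu.

// Minimal sketch (illustrative only): the post-migration check idiom used by the launcher.
// Assumes the includes/helpers of trtllm_fused_moe_kernel_launcher.cu are in scope,
// in particular tvm::ffi::Tensor, the TVM_FFI_ICHECK* macros, and dl_bfloat16.
using tvm::ffi::Tensor;

// Hypothetical helper, not part of the patch: validates a [num_tokens, expected_cols]
// bfloat16 tensor the same way the launcher's check_* methods do after the migration.
inline void check_2d_bf16(Tensor const& t, int64_t expected_cols, char const* name) {
  // Tensor::operator-> exposes the underlying DLTensor fields (ndim, shape, dtype).
  TVM_FFI_ICHECK_EQ(t->ndim, 2) << name << " must be 2D.";
  TVM_FFI_ICHECK_EQ(t->shape[1], expected_cols)
      << name << " must have " << expected_cols << " columns.";
  TVM_FFI_ICHECK_EQ(t->dtype, dl_bfloat16) << name << " must be bfloat16.";
}

Unlike TORCH_CHECK, which takes its message as trailing arguments, the TVM FFI check macros stream the message with operator<<; that is why every check rewritten in these patches moves its message behind a << chain, and why error paths use TVM_FFI_LOG_AND_THROW instead of TORCH_CHECK(false, ...).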