
Conversation

@ChristinaZ ChristinaZ commented Oct 1, 2025

📌 Description

Update the routing code to align with the TRTLLM implementation, and add support for Kimi K2 and Qwen.

Also revised the unit test based on the Kimi K2 config (https://huggingface.co/moonshotai/Kimi-K2-Instruct/blob/main/config.json).

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Signed-off-by: Christina Zhang <83400082+ChristinaZ@users.noreply.github.com>
@gemini-code-assist gemini-code-assist bot commented

Summary of Changes

Hello @ChristinaZ, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly refactors the Mixture-of-Experts (MoE) routing logic within TRTLLMGEN to enhance its compatibility and performance for new models like Kimi K2 and Qwen. The changes introduce dynamic resource allocation for kernels, standardize data handling for top-K expert selection, and improve the robustness of numerical operations, thereby broadening the framework's support for diverse MoE architectures.
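
To make "dynamic resource allocation" concrete, here is a rough host-side sketch (the name launchRouting and the exact formulas are invented for illustration; the actual kernels template on MaxNumExperts):

  #include <cuda_runtime.h>

  // Hypothetical sketch: derive the launch configuration from the expert
  // count instead of hard-coded constants, so the same kernel can serve a
  // 256-expert model and Kimi K2's 384 experts.
  constexpr int WarpSize = 32;

  template <int MaxNumExperts>
  void launchRouting(cudaStream_t stream, int numBlocks /*, params... */) {
    // One thread per expert, rounded up to a whole number of warps.
    constexpr int NumThreads = ((MaxNumExperts + WarpSize - 1) / WarpSize) * WarpSize;
    // Shared memory sized for one score and one index per expert.
    constexpr size_t SmemBytes = MaxNumExperts * (sizeof(float) + sizeof(int));
    // routingKernel<MaxNumExperts><<<numBlocks, NumThreads, SmemBytes, stream>>>(...);
  }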

Highlights

  • New Model Support: The routing code has been updated to support Kimi K2 and Qwen models, aligning with the existing TRTLLM implementation.
  • Dynamic Expert Count Handling: The routing kernels now dynamically determine thread and shared memory configurations based on the number of experts (MaxNumExperts), replacing previously fixed constants. This allows for more flexible and efficient handling of varying MoE configurations.
  • Unified Top-K Data Pointers: The data structures and kernel parameters have been refactored to use standardized pointer names like mPtrTopKPacked, mPtrTopKWeights, and mPtrTopKIds for managing top-K expert scores and indices, improving consistency across different routing methods.
  • Enhanced Top-K Reduction Logic: The reduceTopK mechanism has been improved to handle a broader range of candidate numbers more efficiently, including a new overload that processes larger inputs in chunks.
  • Softmax Calculation Improvements: Softmax calculations within the kernels now explicitly use float for intermediate computations, ensuring better numerical stability when dealing with half or bfloat16 input types (see the sketch after this list).
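
As a minimal illustration of the last point, here is a simplified sketch of softmax with float intermediates (the function name and fixed-size array interface are invented for the example; the real kernels additionally use warp-level reductions):

  #include <cuda_bf16.h>
  #include <cuda_fp16.h>

  // Hypothetical sketch, not the kernel's actual code: softmax over N logits
  // of a 16-bit type T, with max, exp, and sum all computed in float so the
  // limited range of half/bfloat16 cannot overflow or lose precision.
  template <typename T, int N>
  __device__ void softmaxInFloat(T const (&logits)[N], float (&probs)[N]) {
    float maxLogit = -INFINITY;
    for (int i = 0; i < N; ++i) {
      maxLogit = fmaxf(maxLogit, static_cast<float>(logits[i]));
    }
    float sum = 0.f;
    for (int i = 0; i < N; ++i) {
      probs[i] = expf(static_cast<float>(logits[i]) - maxLogit);
      sum += probs[i];
    }
    float const invSum = 1.f / sum;
    for (int i = 0; i < N; ++i) {
      probs[i] *= invSum;
    }
  }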

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces significant updates to the MoE routing kernels to support new models like Kimi K2 and Qwen, and to align with the TRTLLM implementation. The changes involve making the kernels more generic to handle variable numbers of experts, refactoring pointer names for consistency (e.g., mPtrExpertIdx to mPtrTopKPacked), and adding new execution paths for pre-computed top-K indices. Using __launch_bounds__ and replacing cudaMemsetAsync with a dedicated kernel are both good improvements.
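
For readers unfamiliar with those two patterns, here is a minimal sketch (hypothetical names and block size, not the PR's actual kernel): cudaMemsetAsync can only write a repeated byte value, so initializing buffers to values like negative infinity requires a small fill kernel, and __launch_bounds__ lets the compiler budget registers for a known maximum block size.

  #include <cuda_runtime.h>

  // Hypothetical sketch, not the PR's kernel: a grid-stride fill that replaces
  // cudaMemsetAsync for non-byte-pattern values; __launch_bounds__(256)
  // promises at most 256 threads per block so registers can be budgeted.
  template <typename T>
  __global__ void __launch_bounds__(256) fillBuffer(T* ptr, T value, int n) {
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
         i += gridDim.x * blockDim.x) {
      ptr[i] = value;
    }
  }

  // Usage, enqueued on the same stream as the routing kernels:
  //   fillBuffer<float><<<(n + 255) / 256, 256, 0, stream>>>(dScores, -INFINITY, n);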

My review focuses on a few areas for improvement:

  • A typo in a variable name that affects readability.
  • A confusing static_assert comment in the top-K reduction logic.
  • A potential bug in the new reduceTopK implementation related to an unresolved @todo and suspicious index initialization, which could lead to incorrect behavior.

Overall, the changes are well-structured and move towards a more flexible and robust implementation. Addressing the identified issues will further improve the code quality.

Comment on lines +371 to +372
topKBufferValue[ii] = minValue;
topKBufferIdx[ii] = ii * WarpSize - 1; //@todo: check if this is correct

Severity: high

This initialization seems suspicious, and the @todo comment indicates it might not be fully validated.

The initialization topKBufferIdx[ii] = ii * WarpSize - 1; can result in idx = -1 when ii=0. This value is then used in RedType::makeCmpVal which calculates maxIdx - idx. With idx = -1, this becomes 65535 - (-1) = 65536, which overflows the 0xFFFF mask used for the index part of the packed value. This could lead to incorrect tie-breaking or other subtle bugs.

Please resolve the @todo and consider a safer way to initialize invalid indices, for example using RedType::maxIdx or another value that won't cause overflow issues.
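
To make the failure mode concrete, here is a small standalone reproduction of the arithmetic described above. The packing scheme is assumed from this description (index stored as maxIdx - idx in the low 16 bits so smaller indices win ties); it is not taken from the library's actual makeCmpVal.

  #include <cstdint>
  #include <cstdio>

  // Assumed packing scheme (not the actual makeCmpVal): the index field
  // stores (maxIdx - idx) masked to 16 bits.
  static constexpr int32_t maxIdx = 0xFFFF;  // 65535

  static uint32_t packIdx(int32_t idx) {
    return static_cast<uint32_t>(maxIdx - idx) & 0xFFFFu;
  }

  int main() {
    // idx = -1 gives 65535 - (-1) = 65536, which wraps to 0 under the mask,
    // making it indistinguishable from a valid entry with idx == maxIdx.
    printf("packIdx(-1)    = %u\n", packIdx(-1));      // prints 0 (overflow)
    printf("packIdx(65535) = %u\n", packIdx(0xFFFF));  // prints 0 (collision)
    return 0;
  }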

Comment on lines +192 to +205
        float intermidiateScore[NumInterTopKPerThread];
        int32_t intermidiateExpert[NumInterTopKPerThread];
        for (int i = laneIdx; i < NumInterTopKPerThread * WarpSize; i += WarpSize) {
          int ii = i / WarpSize;
          if (i < NumInterTopK) {
            intermidiateScore[ii] = smemInterTopScores[i];
            intermidiateExpert[ii] = smemInterTopExperts[i];
          } else {
            intermidiateScore[ii] = invalidScoreFloat;
            intermidiateExpert[ii] = KernelParams::MaxNumExperts - 1;
          }
        }
        topk::reduceTopK(warp, topScores, topExperts, intermidiateScore, intermidiateExpert,
                         /* minValue */ invalidScoreFloat, params.mTopK);

Severity: medium

There's a typo in intermidiateScore and intermidiateExpert. It should be intermediateScore and intermediateExpert for better readability.

        float intermediateScore[NumInterTopKPerThread];
        int32_t intermediateExpert[NumInterTopKPerThread];
        for (int i = laneIdx; i < NumInterTopKPerThread * WarpSize; i += WarpSize) {
          int ii = i / WarpSize;
          if (i < NumInterTopK) {
            intermediateScore[ii] = smemInterTopScores[i];
            intermediateExpert[ii] = smemInterTopExperts[i];
          } else {
            intermediateScore[ii] = invalidScoreFloat;
            intermediateExpert[ii] = KernelParams::MaxNumExperts - 1;
          }
        }
        topk::reduceTopK(warp, topScores, topExperts, intermediateScore, intermediateExpert,
                         /* minValue */ invalidScoreFloat, params.mTopK);

static_assert(K < WarpSize, "Top K must have K < WarpSize");
static_assert(N > 0, "Top K must have N > 0");
static_assert(N <= 16, "Only support candidates number less than or equal to 128");
static_assert(N < 5, "Only support candidates number less than or equal to 128");

Severity: medium

The comment in this static_assert is confusing. It says "Only support candidates number less than or equal to 128", but the condition is N < 5. This seems to be a copy-paste error. The condition N < 5 (i.e., N <= 4) is what's actually enforced. The comment should be updated to reflect this.

  static_assert(N < 5, "Only support up to 4 candidates per thread in this function.");

Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
TVM_FFI_ICHECK_LT(top_k, (topk_group * num_experts / n_group))
<< "top_k must be less than total number of experts in selected groups";
}
// if (n_group <= 0 || topk_group <= 0) {

@ChristinaZ @jiahanc is this PR ready? If so, let's remove these comments.


static constexpr int NumThreads = 384;
static constexpr int NumWarps = NumThreads / WarpSize;
static constexpr int NumKimiK2Experts = 384;

cc @GordonGustafson this function is updated.
