Update the routing for TRTLLMGEN to support kimi k2 and qwen #1831
base: main
Conversation
Signed-off-by: Christina Zhang <83400082+ChristinaZ@users.noreply.github.com>
Summary of Changes
This pull request significantly refactors the Mixture-of-Experts (MoE) routing logic within TRTLLMGEN to enhance its compatibility and performance for new models like Kimi K2 and Qwen. The changes introduce dynamic resource allocation for kernels, standardize data handling for top-K expert selection, and improve the robustness of numerical operations, thereby broadening the framework's support for diverse MoE architectures.
Code Review
This pull request introduces significant updates to the MoE routing kernels to support new models like Kimi K2 and Qwen, and to align with the TRTLLM implementations. The changes involve making the kernels more generic to handle variable numbers of experts, refactoring pointer names for consistency (e.g., `mPtrExpertIdx` to `mPtrTopKPacked`), and adding new execution paths for pre-computed top-K indices. The use of `__launch_bounds__` and the replacement of `cudaMemsetAsync` with a dedicated kernel are good improvements.
My review focuses on a few areas for improvement:
- A typo in a variable name that affects readability.
- A confusing `static_assert` comment in the top-K reduction logic.
- A potential bug in the new `reduceTopK` implementation related to an unresolved `@todo` and suspicious index initialization, which could lead to incorrect behavior.
Overall, the changes are well-structured and move towards a more flexible and robust implementation. Addressing the identified issues will further improve the code quality.
```cpp
topKBufferValue[ii] = minValue;
topKBufferIdx[ii] = ii * WarpSize - 1;  //@todo: check if this is correct
```
This initialization seems suspicious, and the `@todo` comment indicates it might not be fully validated.
The initialization `topKBufferIdx[ii] = ii * WarpSize - 1;` can result in `idx = -1` when `ii == 0`. This value is then used in `RedType::makeCmpVal`, which calculates `maxIdx - idx`. With `idx = -1`, this becomes `65535 - (-1) = 65536`, which overflows the `0xFFFF` mask used for the index part of the packed value. This could lead to incorrect tie-breaking or other subtle bugs.
Please resolve the `@todo` and consider a safer way to initialize invalid indices, for example using `RedType::maxIdx` or another value that won't cause overflow issues.
```cpp
float intermidiateScore[NumInterTopKPerThread];
int32_t intermidiateExpert[NumInterTopKPerThread];
for (int i = laneIdx; i < NumInterTopKPerThread * WarpSize; i += WarpSize) {
  int ii = i / WarpSize;
  if (i < NumInterTopK) {
    intermidiateScore[ii] = smemInterTopScores[i];
    intermidiateExpert[ii] = smemInterTopExperts[i];
  } else {
    intermidiateScore[ii] = invalidScoreFloat;
    intermidiateExpert[ii] = KernelParams::MaxNumExperts - 1;
  }
}
topk::reduceTopK(warp, topScores, topExperts, intermidiateScore, intermidiateExpert,
                 /* minValue */ invalidScoreFloat, params.mTopK);
```
There's a typo in `intermidiateScore` and `intermidiateExpert`. They should be `intermediateScore` and `intermediateExpert` for better readability.
```cpp
float intermediateScore[NumInterTopKPerThread];
int32_t intermediateExpert[NumInterTopKPerThread];
for (int i = laneIdx; i < NumInterTopKPerThread * WarpSize; i += WarpSize) {
  int ii = i / WarpSize;
  if (i < NumInterTopK) {
    intermediateScore[ii] = smemInterTopScores[i];
    intermediateExpert[ii] = smemInterTopExperts[i];
  } else {
    intermediateScore[ii] = invalidScoreFloat;
    intermediateExpert[ii] = KernelParams::MaxNumExperts - 1;
  }
}
topk::reduceTopK(warp, topScores, topExperts, intermediateScore, intermediateExpert,
                 /* minValue */ invalidScoreFloat, params.mTopK);
```
```cpp
static_assert(K < WarpSize, "Top K must have K < WarpSize");
static_assert(N > 0, "Top K must have N > 0");
static_assert(N <= 16, "Only support candidates number less than or equal to 128");
static_assert(N < 5, "Only support candidates number less than or equal to 128");
```
The comment in this `static_assert` is confusing. It says "Only support candidates number less than or equal to 128", but the condition is `N < 5`. This looks like a copy-paste error. The condition `N < 5` (i.e., `N <= 4`) is what's actually enforced, so the comment should be updated to match:
```cpp
static_assert(N < 5, "Only support up to 4 candidates per thread in this function.");
```
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
```cpp
TVM_FFI_ICHECK_LT(top_k, (topk_group * num_experts / n_group))
    << "top_k must be less than total number of experts in selected groups";
}
// if (n_group <= 0 || topk_group <= 0) {
```
@ChristinaZ @jiahanc is this PR ready? If so, let's remove these comments.
```cpp
static constexpr int NumThreads = 384;
static constexpr int NumWarps = NumThreads / WarpSize;
static constexpr int NumKimiK2Experts = 384;
```
cc @GordonGustafson this function is updated.
📌 Description
Update the routing code to align with the implementation in TRTLLM and add support for Kimi K2 and Qwen.
Also revised the unit test based on the config of Kimi K2 (https://huggingface.co/moonshotai/Kimi-K2-Instruct/blob/main/config.json).
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.
🧪 Tests
- Tests have been added or updated as needed.
- All tests are passing (`unittest`, etc.).
Reviewer Notes