Support MiniMax M2.1 (FP8 checkpoint) #817
base: main
Conversation
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
📝 Walkthrough

Adds FP8 quantization support for the MiniMax M2.1 model.

Changes
Sequence Diagram(s)

sequenceDiagram
participant Model as MiniMax M2.1 Model
participant HFPlugin as HuggingFace Plugin
participant Registry as QuantModuleRegistry
participant ExportUtil as Export Utilities
participant Kernel as FP8 Kernel
Model->>HFPlugin: Load model (MiniMaxM2ForCausalLM)
HFPlugin->>HFPlugin: register_minimax_m2_moe_on_the_fly()
HFPlugin->>Registry: Register _QuantSparseMoe
HFPlugin->>Registry: Register _QuantFP8Linear
ExportUtil->>Model: Scan modules
ExportUtil->>ExportUtil: is_quantlinear() checks for QuantFP8Linear
alt QuantFP8Linear detected
ExportUtil->>ExportUtil: Check element_size <= 1
ExportUtil->>Kernel: Call weight_dequant with dtype
Kernel->>Kernel: Dequantize weights using scaling factors
Kernel-->>ExportUtil: Return dequantized tensor
end
ExportUtil-->>Model: Export with unpacked FP8 weights
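
The "Dequantize weights using scaling factors" step in the diagram refers to block-wise dequantization: each block of the packed low-precision weight is multiplied by its stored inverse scale. The sketch below is only a reference illustration of that idea, not the PR's `weight_dequant` kernel; square blocks, a 2-D weight, and a default block size of 128 are assumptions made here for the example.

```python
import torch

def blockwise_dequant(weight_q: torch.Tensor,
                      scale_inv: torch.Tensor,
                      block_size: int = 128,
                      dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    """Reference (non-kernel) block-wise dequantization.

    weight_q:  (out_features, in_features) packed low-precision weight
    scale_inv: (ceil(out/block), ceil(in/block)) per-block inverse scales
    """
    out_f, in_f = weight_q.shape
    # Broadcast each per-block scale over its block, then crop to the weight shape.
    scales = scale_inv.repeat_interleave(block_size, dim=0)
    scales = scales.repeat_interleave(block_size, dim=1)[:out_f, :in_f]
    return weight_q.to(dtype) * scales.to(dtype)

# Toy usage with plain float storage standing in for FP8:
w_q = torch.randn(256, 384)
s_inv = torch.rand(2, 3) + 0.5
w = blockwise_dequant(w_q, s_inv, block_size=128)
print(w.shape, w.dtype)  # torch.Size([256, 384]) torch.bfloat16
```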
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks: ✅ Passed checks (3 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In `@CHANGELOG.rst`:
- Around line 13-15: The changelog entry duplicates the word "support" on the
MiniMax line; edit the sentence that currently reads "Add support for MiniMax
M2.1 model quantization support for the original FP8 checkpoint." (the MiniMax
M2.1 line) to remove the extra "support", e.g. "Add MiniMax M2.1 model
quantization for the original FP8 checkpoint."
In `@modelopt/torch/quantization/plugins/huggingface.py`:
- Around line 746-754: unpack_weight currently calls weight_dequant without
specifying dtype, which can mismatch forward (which uses dtype=input.dtype); fix
by storing the target dtype during _setup (e.g., save self._orig_dtype or
self.target_dtype) or by passing an explicit dtype parameter into unpack_weight,
then call weight_dequant(weight, scale_inv, self.block_size,
dtype=self._orig_dtype) so unpacked weights match forward; update any callers
and remove weight_scale_inv as before. Reference: unpack_weight, forward,
weight_dequant, and _setup.
- Around line 454-473: In the forward override where TRANSFORMERS_VERSION_GE_5_0
is checked, remove the stray unconditional assignment "self.top_k =
original_top_k" that appears after the version-gated if/else; that assignment is
incorrect for the TRANSFORMERS_VERSION_GE_5_0 branch (original_top_k references
self.gate.topk) and the two branches already restore their respective
attributes, so deleting this line in the forward method (the block manipulating
self.gate.topk and self.top_k) fixes the bug without changing branch-specific
restoration logic.
- Add standalone type inference option (``--use_standalone_type_inference``) in ONNX AutoCast as an alternative to ONNX's ``infer_shapes``. This experimental feature performs type-only inference without shape inference, useful as a workaround when shape inference fails or to avoid unnecessary shape inference overhead.
- Add support for Kimi K2 Thinking model quantization from the original int4 checkpoint.
- Add support for MiniMax M2.1 model quantization support for the original FP8 checkpoint.
Fix duplicated wording in the changelog entry.
Line 15 repeats “support”. Consider the tweak below for clarity.
📝 Proposed fix
-- Add support for MiniMax M2.1 model quantization support for the original FP8 checkpoint.
+- Add support for MiniMax M2.1 model quantization for the original FP8 checkpoint.

🤖 Prompt for AI Agents
In `@CHANGELOG.rst` around lines 13 - 15, The changelog entry duplicates the word
"support" on the MiniMax line; edit the sentence that currently reads "Add
support for MiniMax M2.1 model quantization support for the original FP8
checkpoint." (the MiniMax M2.1 line) to remove the extra "support", e.g. "Add
MiniMax M2.1 model quantization for the original FP8 checkpoint."
if TRANSFORMERS_VERSION_GE_5_0:
    assert hasattr(self, "gate")
    # Path for transformers >= 5.0
    original_top_k = self.gate.topk
    self.gate.topk = self.gate.num_experts
    super().forward(hidden_states)
    self.gate.topk = original_top_k
else:
    # Path for transformers < 5.0
    original_top_k = self.top_k
    if hasattr(self, "num_experts"):
        self.top_k = self.num_experts
    elif hasattr(self, "experts"):
        self.top_k = self.experts.num_experts
    else:
        raise ValueError(f"Could not find num_experts in module {self}")
    super().forward(hidden_states)
    self.top_k = original_top_k
self.top_k = original_top_k
return super().forward(hidden_states)
Bug: Stray self.top_k assignment outside version-gated block.
Line 472 (self.top_k = original_top_k) executes unconditionally after both branches. For TRANSFORMERS_VERSION_GE_5_0:
- original_top_k holds self.gate.topk, not self.top_k
- Setting self.top_k may fail if the attribute doesn't exist, or corrupt state if it does
This line appears to be a leftover from refactoring and should be removed since each branch already handles its own restoration.
🐛 Proposed fix
        super().forward(hidden_states)
        self.top_k = original_top_k
-   self.top_k = original_top_k
    return super().forward(hidden_states)

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
if TRANSFORMERS_VERSION_GE_5_0:
    assert hasattr(self, "gate")
    # Path for transformers >= 5.0
    original_top_k = self.gate.topk
    self.gate.topk = self.gate.num_experts
    super().forward(hidden_states)
    self.gate.topk = original_top_k
else:
    # Path for transformers < 5.0
    original_top_k = self.top_k
    if hasattr(self, "num_experts"):
        self.top_k = self.num_experts
    elif hasattr(self, "experts"):
        self.top_k = self.experts.num_experts
    else:
        raise ValueError(f"Could not find num_experts in module {self}")
    super().forward(hidden_states)
    self.top_k = original_top_k
return super().forward(hidden_states)
🤖 Prompt for AI Agents
In `@modelopt/torch/quantization/plugins/huggingface.py` around lines 454 - 473,
In the forward override where TRANSFORMERS_VERSION_GE_5_0 is checked, remove the
stray unconditional assignment "self.top_k = original_top_k" that appears after
the version-gated if/else; that assignment is incorrect for the
TRANSFORMERS_VERSION_GE_5_0 branch (original_top_k references self.gate.topk)
and the two branches already restore their respective attributes, so deleting
this line in the forward method (the block manipulating self.gate.topk and
self.top_k) fixes the bug without changing branch-specific restoration logic.
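
On the restoration pattern itself: an exception-safe way to express "temporarily route to all experts, then restore" is a context manager built around try/finally, so the original top_k comes back even if the calibration forward raises. The sketch below is purely illustrative; route_to_all_experts and ToyMoE are hypothetical names, not modelopt APIs, and it mirrors only the pre-5.0 branch that swaps self.top_k for self.num_experts.

```python
from contextlib import contextmanager

@contextmanager
def route_to_all_experts(moe):
    """Temporarily force the router to select every expert for calibration,
    restoring the original top_k even if the forward pass raises."""
    original_top_k = moe.top_k
    moe.top_k = moe.num_experts
    try:
        yield moe
    finally:
        moe.top_k = original_top_k

# Toy usage: no real model needed to see the save/override/restore behavior.
class ToyMoE:
    top_k = 2
    num_experts = 8

moe = ToyMoE()
with route_to_all_experts(moe):
    assert moe.top_k == 8   # calibration forward would run here
assert moe.top_k == 2       # restored afterwards, exactly once
```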
def unpack_weight(self):
    with torch.cuda.device(self.weight.device):
        weight, scale_inv = self._get_weight_and_scale_inv()
        self.weight = nn.Parameter(
            weight_dequant(weight, scale_inv, self.block_size),
            requires_grad=False,
        )
        if hasattr(self, "weight_scale_inv"):
            del self.weight_scale_inv
Consider specifying dtype in unpack_weight for consistency.
In forward(), weight_dequant is called with dtype=input.dtype, preserving the input's precision. However, unpack_weight() omits dtype, defaulting to torch.get_default_dtype() (typically float32).
If this is intentional for export, a comment would clarify the design. Otherwise, consider accepting/storing a target dtype to ensure consistency.
💡 Suggested improvement
 def unpack_weight(self):
     with torch.cuda.device(self.weight.device):
         weight, scale_inv = self._get_weight_and_scale_inv()
         self.weight = nn.Parameter(
-            weight_dequant(weight, scale_inv, self.block_size),
+            weight_dequant(weight, scale_inv, self.block_size, dtype=torch.bfloat16),
             requires_grad=False,
         )

Or store the original dtype during _setup and use it here.
🤖 Prompt for AI Agents
In `@modelopt/torch/quantization/plugins/huggingface.py` around lines 746 - 754,
unpack_weight currently calls weight_dequant without specifying dtype, which can
mismatch forward (which uses dtype=input.dtype); fix by storing the target dtype
during _setup (e.g., save self._orig_dtype or self.target_dtype) or by passing
an explicit dtype parameter into unpack_weight, then call weight_dequant(weight,
scale_inv, self.block_size, dtype=self._orig_dtype) so unpacked weights match
forward; update any callers and remove weight_scale_inv as before. Reference:
unpack_weight, forward, weight_dequant, and _setup.
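
Tying the suggestion together, one way to keep unpack_weight() consistent with forward() is to record the target dtype once during setup and pass it to the dequant call. The sketch below is a hypothetical illustration, not the modelopt implementation: the class, the _orig_dtype attribute, and the injected dequant_fn are invented for the example; only the general shape of unpack_weight mirrors the snippet above.

```python
import torch
import torch.nn as nn

class FP8LinearSketch(nn.Module):
    """Illustrative only: capture the checkpoint's compute dtype once, then reuse
    it when expanding the packed FP8 weight, so unpack_weight() and forward()
    agree on precision."""

    def __init__(self, packed_weight, scale_inv, block_size, dequant_fn,
                 orig_dtype=torch.bfloat16):
        super().__init__()
        self.weight = nn.Parameter(packed_weight, requires_grad=False)
        self.weight_scale_inv = nn.Parameter(scale_inv, requires_grad=False)
        self.block_size = block_size
        self._dequant_fn = dequant_fn      # any (weight, scale_inv, block_size, dtype) callable
        self._orig_dtype = orig_dtype      # would be recorded during _setup in practice

    def unpack_weight(self):
        # Expand to the remembered dtype instead of torch.get_default_dtype().
        dense = self._dequant_fn(self.weight, self.weight_scale_inv,
                                 self.block_size, dtype=self._orig_dtype)
        self.weight = nn.Parameter(dense, requires_grad=False)
        del self.weight_scale_inv
```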
What does this PR do?
Type of change: new feature

Overview: Support loading the MiniMax M2.1 (FP8) checkpoint for PTQ.
Usage
scripts/huggingface_example.sh --model --quant nvfp4 --trust_remote_code
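
As a quick sanity check before running PTQ, one could verify that a checkpoint shard really contains 1-byte FP8 weights together with the weight_scale_inv tensors referenced in this PR. This is a hypothetical helper, not part of the script above; it only relies on the standard safetensors API.

```python
# Hypothetical helper: confirm a safetensors shard holds packed FP8 weights
# (element_size == 1) plus the per-block inverse scales used for dequantization.
from safetensors import safe_open

def looks_like_fp8_checkpoint(shard_path: str) -> bool:
    with safe_open(shard_path, framework="pt") as f:
        names = list(f.keys())
        has_scales = any(name.endswith("weight_scale_inv") for name in names)
        # Loading each tensor just to inspect its element size is slow but simple.
        has_fp8_weights = any(
            name.endswith(".weight") and f.get_tensor(name).element_size() == 1
            for name in names
        )
    return has_scales and has_fp8_weights
```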
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Improvements
Documentation
✏️ Tip: You can customize this high-level summary in your review settings.