
Conversation

@cjluo-nv cjluo-nv commented Jan 25, 2026

What does this PR do?

Type of change: New feature

Overview: Support loading the MiniMax M2.1 (FP8) checkpoint for PTQ.

Usage

scripts/huggingface_example.sh --model --quant nvfp4 --trust_remote_code

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

  • New Features

    • Added MiniMax M2.1 model quantization support with nvfp4 format.
    • Extended FP8 quantization capabilities with configurable dtype parameter for enhanced precision control.
  • Improvements

    • Enhanced detection of quantized linear module variants.
    • Improved weight unpacking for FP8-based linear modules.
  • Documentation

    • Updated supported models table to include MiniMax M2.1.


copy-pr-bot bot commented Jan 25, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.


coderabbitai bot commented Jan 25, 2026

📝 Walkthrough

Adds FP8 quantization support for the MiniMax M2.1 model. Introduces a _QuantFP8Linear class in the HuggingFace plugin, registers MoE modules for the MiniMax M2 architecture on the fly, adds transformers version gating, updates the export utilities to handle FP8 modules, and adds a dtype parameter to the weight dequantization kernel signature.
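
For context, a flag such as TRANSFORMERS_VERSION_GE_5_0 is conventionally computed along these lines; this is a sketch of the usual pattern, not necessarily how the plugin defines it:

import transformers
from packaging.version import Version

# Gate code paths on the installed transformers major version.
TRANSFORMERS_VERSION_GE_5_0 = Version(transformers.__version__) >= Version("5.0.0")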

Changes

  • Documentation & Examples (CHANGELOG.rst, examples/llm_ptq/README.md): Added a changelog entry and a supported-models table entry for MiniMax M2.1 quantization with nvfp4 support.
  • Import Path Updates (examples/deepseek/ptq.py): Updated the weight_dequant import from ds_kernel to modelopt.torch.quantization.triton.fp8_kernel.
  • Export Utilities (modelopt/torch/export/layer_utils.py, modelopt/torch/export/unified_export_hf.py): Enhanced is_quantlinear to detect QuantFP8Linear and exclude lora/ds_kernel modules; extended the weight-unpacking condition to FP8 modules whose weight element size is ≤ 1 byte (see the sketch after the sequence diagram below).
  • FP8 Kernel Implementation (modelopt/torch/quantization/triton/fp8_kernel.py): Simplified the license header, updated documentation, and added a dtype parameter with a default value to weight_dequant (a rough dequantization sketch follows this list).
  • HuggingFace Plugin Infrastructure (modelopt/torch/quantization/plugins/huggingface.py): Introduced the _QuantFP8Linear class for FP8 weight quantization, added register_minimax_m2_moe_on_the_fly(), implemented transformers version gating (TRANSFORMERS_VERSION_GE_5_0) for KV attention paths, and registered FP8Linear in QuantModuleRegistry.
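
For orientation, here is a minimal pure-PyTorch sketch of what a blockwise FP8 weight dequantization with an explicit dtype argument looks like. It is only an illustration of the pattern; the real weight_dequant in fp8_kernel.py is a Triton kernel, and the per-tile scale layout assumed here may differ from it.

import torch

def blockwise_dequant_sketch(
    weight_fp8: torch.Tensor,   # (M, N) weight stored as torch.float8_e4m3fn
    scale_inv: torch.Tensor,    # per-tile dequant scales, shape (ceil(M/B), ceil(N/B))
    block_size: int = 128,
    dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Multiply each (block_size x block_size) tile by its scale and cast to dtype."""
    m, n = weight_fp8.shape
    # Broadcast each tile's scale back to per-element resolution, then crop to (m, n).
    scales = scale_inv.repeat_interleave(block_size, dim=0)
    scales = scales.repeat_interleave(block_size, dim=1)[:m, :n]
    return (weight_fp8.to(torch.float32) * scales.to(torch.float32)).to(dtype)

In this sketch the dtype argument only controls the output cast, which is the behavior the review comments below compare between forward() (dtype=input.dtype) and unpack_weight() (which currently relies on the default).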

Sequence Diagram(s)

sequenceDiagram
    participant Model as MiniMax M2.1 Model
    participant HFPlugin as HuggingFace Plugin
    participant Registry as QuantModuleRegistry
    participant ExportUtil as Export Utilities
    participant Kernel as FP8 Kernel

    Model->>HFPlugin: Load model (MiniMaxM2ForCausalLM)
    HFPlugin->>HFPlugin: register_minimax_m2_moe_on_the_fly()
    HFPlugin->>Registry: Register _QuantSparseMoe
    HFPlugin->>Registry: Register _QuantFP8Linear
    
    ExportUtil->>Model: Scan modules
    ExportUtil->>ExportUtil: is_quantlinear() checks for QuantFP8Linear
    alt QuantFP8Linear detected
        ExportUtil->>ExportUtil: Check element_size <= 1
        ExportUtil->>Kernel: Call weight_dequant with dtype
        Kernel->>Kernel: Dequantize weights using scaling factors
        Kernel-->>ExportUtil: Return dequantized tensor
    end
    
    ExportUtil-->>Model: Export with unpacked FP8 weights
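
The export-side branch in the diagram can be pictured with a short helper; aside from unpack_weight and the element-size check, which the changes above describe, the helper name and structure here are assumptions:

import torch
import torch.nn as nn

def maybe_unpack_fp8(module: nn.Module) -> None:
    """If a linear module still stores FP8 (1-byte) weights, dequantize them in place."""
    weight = getattr(module, "weight", None)
    if not isinstance(weight, torch.Tensor):
        return
    # FP8 storage shows up as element_size() == 1; the new export condition keys off <= 1.
    if weight.element_size() <= 1 and hasattr(module, "unpack_weight"):
        module.unpack_weight()  # replaces module.weight with a dequantized nn.Parameter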

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 3 passed

  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title 'Support MiniMax M2.1 (FP8 checkpoint)' clearly summarizes the main objective of the pull request: adding support for the MiniMax M2.1 model with an FP8 checkpoint.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 100.00%, above the required threshold of 80.00%.


@cjluo-nv cjluo-nv marked this pull request as ready for review January 28, 2026 18:09
@cjluo-nv cjluo-nv requested review from a team as code owners January 28, 2026 18:09

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🤖 Fix all issues with AI agents
In `@CHANGELOG.rst`:
- Around line 13-15: The changelog entry duplicates the word "support" on the
MiniMax line; edit the sentence that currently reads "Add support for MiniMax
M2.1 model quantization support for the original FP8 checkpoint." (the MiniMax
M2.1 line) to remove the extra "support", e.g. "Add MiniMax M2.1 model
quantization for the original FP8 checkpoint."

In `@modelopt/torch/quantization/plugins/huggingface.py`:
- Around line 746-754: unpack_weight currently calls weight_dequant without
specifying dtype, which can mismatch forward (which uses dtype=input.dtype); fix
by storing the target dtype during _setup (e.g., save self._orig_dtype or
self.target_dtype) or by passing an explicit dtype parameter into unpack_weight,
then call weight_dequant(weight, scale_inv, self.block_size,
dtype=self._orig_dtype) so unpacked weights match forward; update any callers
and remove weight_scale_inv as before. Reference: unpack_weight, forward,
weight_dequant, and _setup.
- Around line 454-473: In the forward override where TRANSFORMERS_VERSION_GE_5_0
is checked, remove the stray unconditional assignment "self.top_k =
original_top_k" that appears after the version-gated if/else; that assignment is
incorrect for the TRANSFORMERS_VERSION_GE_5_0 branch (original_top_k references
self.gate.topk) and the two branches already restore their respective
attributes, so deleting this line in the forward method (the block manipulating
self.gate.topk and self.top_k) fixes the bug without changing branch-specific
restoration logic.

Comment on lines 13 to +15
- Add standalone type inference option (``--use_standalone_type_inference``) in ONNX AutoCast as an alternative to ONNX's ``infer_shapes``. This experimental feature performs type-only inference without shape inference, useful as a workaround when shape inference fails or to avoid unnecessary shape inference overhead.
- Add support for Kimi K2 Thinking model quantization from the original int4 checkpoint.
- Add support for MiniMax M2.1 model quantization support for the original FP8 checkpoint.

⚠️ Potential issue | 🟡 Minor

Fix duplicated wording in the changelog entry.

Line 15 repeats “support”. Consider the tweak below for clarity.

📝 Proposed fix
-- Add support for MiniMax M2.1 model quantization support for the original FP8 checkpoint.
+- Add support for MiniMax M2.1 model quantization for the original FP8 checkpoint.

Comment on lines +454 to 473
        if TRANSFORMERS_VERSION_GE_5_0:
            assert hasattr(self, "gate")
            # Path for transformers >= 5.0
            original_top_k = self.gate.topk
            self.gate.topk = self.gate.num_experts
            super().forward(hidden_states)
            self.gate.topk = original_top_k
        else:
            # Path for transformers < 5.0
            original_top_k = self.top_k
            if hasattr(self, "num_experts"):
                self.top_k = self.num_experts
            elif hasattr(self, "experts"):
                self.top_k = self.experts.num_experts
            else:
                raise ValueError(f"Could not find num_experts in module {self}")
            super().forward(hidden_states)
            self.top_k = original_top_k
        self.top_k = original_top_k
        return super().forward(hidden_states)

⚠️ Potential issue | 🔴 Critical

Bug: Stray self.top_k assignment outside version-gated block.

Line 472 (self.top_k = original_top_k) executes unconditionally after both branches. For TRANSFORMERS_VERSION_GE_5_0:

  • original_top_k holds self.gate.topk, not self.top_k
  • Setting self.top_k may fail if the attribute doesn't exist, or corrupt state if it does

This line appears to be a leftover from refactoring and should be removed since each branch already handles its own restoration.

🐛 Proposed fix
             super().forward(hidden_states)
             self.top_k = original_top_k
-        self.top_k = original_top_k
         return super().forward(hidden_states)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change (the snippet above with the stray trailing self.top_k restore removed):

        if TRANSFORMERS_VERSION_GE_5_0:
            assert hasattr(self, "gate")
            # Path for transformers >= 5.0
            original_top_k = self.gate.topk
            self.gate.topk = self.gate.num_experts
            super().forward(hidden_states)
            self.gate.topk = original_top_k
        else:
            # Path for transformers < 5.0
            original_top_k = self.top_k
            if hasattr(self, "num_experts"):
                self.top_k = self.num_experts
            elif hasattr(self, "experts"):
                self.top_k = self.experts.num_experts
            else:
                raise ValueError(f"Could not find num_experts in module {self}")
            super().forward(hidden_states)
            self.top_k = original_top_k
        return super().forward(hidden_states)

Comment on lines +746 to +754
    def unpack_weight(self):
        with torch.cuda.device(self.weight.device):
            weight, scale_inv = self._get_weight_and_scale_inv()
            self.weight = nn.Parameter(
                weight_dequant(weight, scale_inv, self.block_size),
                requires_grad=False,
            )
            if hasattr(self, "weight_scale_inv"):
                del self.weight_scale_inv

⚠️ Potential issue | 🟡 Minor

Consider specifying dtype in unpack_weight for consistency.

In forward(), weight_dequant is called with dtype=input.dtype, preserving the input's precision. However, unpack_weight() omits dtype, defaulting to torch.get_default_dtype() (typically float32).

If this is intentional for export, a comment would clarify the design. Otherwise, consider accepting/storing a target dtype to ensure consistency.

💡 Suggested improvement
     def unpack_weight(self):
         with torch.cuda.device(self.weight.device):
             weight, scale_inv = self._get_weight_and_scale_inv()
             self.weight = nn.Parameter(
-                weight_dequant(weight, scale_inv, self.block_size),
+                weight_dequant(weight, scale_inv, self.block_size, dtype=torch.bfloat16),
                 requires_grad=False,
             )

Or store the original dtype during _setup and use it here.
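
A minimal sketch of that alternative, assuming _setup can run while the module still knows what dtype its outputs should have; the attribute name _orig_dtype and the bfloat16 fallback are hypothetical, not part of the actual plugin:

    def _setup(self):
        super()._setup()
        # Hypothetical: remember the dtype that unpacked weights should use; a real
        # implementation might read this from the model config instead.
        self._orig_dtype = getattr(self, "_orig_dtype", torch.bfloat16)

    def unpack_weight(self):
        with torch.cuda.device(self.weight.device):
            weight, scale_inv = self._get_weight_and_scale_inv()
            self.weight = nn.Parameter(
                weight_dequant(weight, scale_inv, self.block_size, dtype=self._orig_dtype),
                requires_grad=False,
            )
            if hasattr(self, "weight_scale_inv"):
                del self.weight_scale_inv

Whether bfloat16 is the right fallback depends on how the checkpoint is loaded, so treating it as a default rather than a hard-coded target seems safer.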

