From 2b9aeea6d1166218424559be8e5bec5a79bace01 Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Mon, 12 Jan 2026 18:51:33 -0800 Subject: [PATCH 01/20] support latent moe import and fix local experts sync Signed-off-by: jenchen13 --- .../torch/export/plugins/mcore_nemotron.py | 8 +++++++- modelopt/torch/quantization/model_calib.py | 18 ++++++++++++------ 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/modelopt/torch/export/plugins/mcore_nemotron.py b/modelopt/torch/export/plugins/mcore_nemotron.py index 5fdb8ba1b..53fd0d232 100644 --- a/modelopt/torch/export/plugins/mcore_nemotron.py +++ b/modelopt/torch/export/plugins/mcore_nemotron.py @@ -81,8 +81,11 @@ "shared_experts.linear_fc2": NameRemapping( "backbone.layers.{}.mixer.shared_experts.down_proj.", ROW_TP ), -} + # Latent MoE + "fc1_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc1_latent_proj.", REPLICATE), + "fc2_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc2_latent_proj.", REPLICATE), +} nemotron_h_causal_lm_export: dict[str, CustomModuleMapping] = { "word_embeddings": NameRemapping("backbone.embeddings."), @@ -115,4 +118,7 @@ "shared_experts.linear_fc2": NameRemapping( "backbone.layers.{}.mixer.shared_experts.down_proj." ), + # Latent MoE + "fc1_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc1_latent_proj."), + "fc2_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc2_latent_proj."), } diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 0077d8666..5ec9224f3 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -95,13 +95,22 @@ def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state): quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group) # TODO: create sync_bias_across_distributed_group - # Step 1:Sync amax across data parallelism + # Step 1: Sync amax across local experts in a SequentialMLP + for name, module in model.named_modules(): + if hasattr(module, "sync_moe_local_experts_amax"): + module.sync_moe_local_experts_amax() + + # TODO just for testing + if "experts" in name and "weight_quantizer" in name: + assert child.amax is not None + + # Step 2:Sync amax across data parallelism for name, module in model.named_modules(): if isinstance(module, QuantModule): for child in module.children(): if isinstance(child, (TensorQuantizer, SequentialQuantizer)): sync_quantizer_amax_across_dp_ep(child, module.parallel_state) - # TP sync: + # Step 3: TP sync # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same # ColumnParallel: X @ [A_1, A_2] (weights split along Cout) @@ -182,10 +191,7 @@ def sync_quantizer_amax_across_tp( parallel_state=module.parallel_state, ) - # MOE Quantization - if hasattr(module, "sync_moe_local_experts_amax"): - module.sync_moe_local_experts_amax() - + # KV Cache Quantization if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"): # We only support KVCache quantization with scalar per-tensor states for now (NVFP4 & FP8 KV cache) From 989eaab3f4d7e79c0e83c59c41303d0e28f39820 Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Mon, 12 Jan 2026 21:06:15 -0800 Subject: [PATCH 02/20] patch TransformerLayer forward Signed-off-by: jenchen13 --- modelopt/torch/export/plugins/mcore_nemotron.py | 2 ++ modelopt/torch/quantization/plugins/megatron.py | 8 ++++++++ 2 files changed, 10 insertions(+) diff --git a/modelopt/torch/export/plugins/mcore_nemotron.py 
b/modelopt/torch/export/plugins/mcore_nemotron.py index 53fd0d232..a61fc367e 100644 --- a/modelopt/torch/export/plugins/mcore_nemotron.py +++ b/modelopt/torch/export/plugins/mcore_nemotron.py @@ -87,6 +87,8 @@ } +# TODO later support MTP import/export + nemotron_h_causal_lm_export: dict[str, CustomModuleMapping] = { "word_embeddings": NameRemapping("backbone.embeddings."), "final_norm": NameRemapping("backbone.norm_f."), diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 6a2bdc15f..9e0e9c818 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -23,6 +23,7 @@ import megatron.core.parallel_state as mcore_parallel import megatron.core.tensor_parallel.layers as megatron_parallel import megatron.core.transformer.mlp as megatron_mlp +import megatron.core.transformer.transformer_layer as megatron_transformer_layer import megatron.core.transformer.moe.experts as megatron_moe import megatron.core.transformer.moe.moe_layer as megatron_moe_layer import torch @@ -40,6 +41,7 @@ register_modelopt_extra_state_callbacks, ) from modelopt.torch.utils.distributed import ParallelState +import torch.distributed as dist from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer from ..nn.modules.quant_linear import RealQuantLinear @@ -593,12 +595,18 @@ def sync_moe_local_experts_amax(self): if stored_amax is None else torch.maximum(stored_amax, amax_tensor) ) + #if isinstance(module, TensorQuantizer) and module.amax is None: + # print(f"MISSING AMAX BEFORE SYNC in expert rank {dist.get_rank()}: {name}", flush=True) + + # Apply synchronized amax values back to all local experts for expert in self.local_experts: for name, module in expert.named_modules(): if isinstance(module, TensorQuantizer) and module.amax is not None: module.amax = amax_dict[name].detach().clone().to(module.amax.device) + #if isinstance(module, TensorQuantizer) and module.amax is None: + # print(f"MISSING AMAX AFTER SYNC in expert rank {dist.get_rank()}: {name}", flush=True) def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): """Override the default to enable singleton_local_shards. 
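For reference, the local-expert amax synchronization that patches 01-02 wire into calibration reduces to a two-pass element-wise maximum over the experts of a SequentialMLP: first collect the maximum amax per quantizer name across all local experts, then write the reduced value back into every expert so they all quantize with identical scales. A minimal standalone sketch of that idea (illustrative only, not the plugin code itself; the helper name is mine, and it assumes each quantizer exposes its calibrated range as an amax tensor attribute):

import torch

def sync_local_experts_amax(local_experts):
    """Give structurally identical experts a shared per-quantizer amax (element-wise max)."""
    amax_dict = {}
    # Pass 1: reduce each quantizer's amax across experts with an element-wise maximum.
    for expert in local_experts:
        for name, module in expert.named_modules():
            amax = getattr(module, "amax", None)
            if amax is None:
                continue
            stored = amax_dict.get(name)
            amax_dict[name] = amax.detach().clone() if stored is None else torch.maximum(stored, amax)
    # Pass 2: copy the reduced amax back into every expert that has a calibrated quantizer.
    for expert in local_experts:
        for name, module in expert.named_modules():
            if getattr(module, "amax", None) is not None:
                module.amax = amax_dict[name].detach().clone().to(module.amax.device)

Running this before the cross-rank (DP/EP) amax synchronization, as the reordering in this series does, keeps the per-expert scales consistent no matter how the router distributed calibration tokens across the local experts.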
From 2c6dc7183692cd68a255714bfd14c27f71cf3e97 Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Tue, 20 Jan 2026 12:18:13 -0800 Subject: [PATCH 03/20] fix kv bmm export Signed-off-by: jenchen13 --- modelopt/torch/export/plugins/mcore_custom.py | 9 +++ .../torch/export/plugins/mcore_nemotron.py | 12 ++- modelopt/torch/export/quant_utils.py | 31 +++----- .../torch/export/unified_export_megatron.py | 73 ++++++++----------- modelopt/torch/quantization/model_calib.py | 5 -- .../torch/quantization/plugins/megatron.py | 2 - 6 files changed, 61 insertions(+), 71 deletions(-) diff --git a/modelopt/torch/export/plugins/mcore_custom.py b/modelopt/torch/export/plugins/mcore_custom.py index 23804b322..25a2cd0cb 100644 --- a/modelopt/torch/export/plugins/mcore_custom.py +++ b/modelopt/torch/export/plugins/mcore_custom.py @@ -126,6 +126,15 @@ def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] func_kwargs=func_kwargs, ) +class SelfAttentionScaling(CustomModuleMapping): + """A custom module mapping that scales self attention.""" + def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] = {}): + """Create a custom module mapping that scales self attention.""" + super().__init__( + func_name="self_attention_scaling", + target_name_or_prefix=target_name_or_prefix, + func_kwargs=func_kwargs, + ) class GatedMLPSlicing(CustomModuleMapping): """A custom module mapping that slices gate_proj and up_proj.""" diff --git a/modelopt/torch/export/plugins/mcore_nemotron.py b/modelopt/torch/export/plugins/mcore_nemotron.py index a61fc367e..f857230ae 100644 --- a/modelopt/torch/export/plugins/mcore_nemotron.py +++ b/modelopt/torch/export/plugins/mcore_nemotron.py @@ -26,6 +26,7 @@ NameRemapping, QKVMerging, QKVSlicing, + SelfAttentionScaling, ) # Example on adding a new CausalLM. @@ -84,10 +85,18 @@ # Latent MoE "fc1_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc1_latent_proj.", REPLICATE), "fc2_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc2_latent_proj.", REPLICATE), + # MTP + #"enorm": NameRemapping("mtp.layers.{}.enorm.", REPLICATE), + #"hnorm": NameRemapping("mtp.layers.{}.hnorm.", REPLICATE), + #"eh_proj": NameRemapping("mtp.layers.{}.eh_proj.", REPLICATE), + #"layer_norm": NameRemapping("mtp.layers.{}.final_layernorm.", REPLICATE), + #"norm": NameRemapping("mtp.layers.{}.norm", REPLICATE) + # "transformer_layer": NameRemapping("mtp.layers.{}.mixer", REPLICATE), + } -# TODO later support MTP import/export +# TODO ADD MTP export nemotron_h_causal_lm_export: dict[str, CustomModuleMapping] = { "word_embeddings": NameRemapping("backbone.embeddings."), @@ -106,6 +115,7 @@ "input_layernorm": NameRemapping("backbone.layers.{}.norm."), "linear_qkv": QKVSlicing("backbone.layers.{}.mixer."), "linear_proj": NameRemapping("backbone.layers.{}.mixer.o_proj."), + "core_attention": SelfAttentionScaling("backbone.layers.{}.mixer."), # MLP "pre_mlp_layernorm": NameRemapping("backbone.layers.{}.norm."), "linear_fc1": NameRemapping("backbone.layers.{}.mixer.up_proj."), diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index e7587d222..00f03b9aa 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -387,7 +387,7 @@ def get_prequant_scaling_factor(module: nn.Module) -> torch.Tensor: if prequant_scaling_factor is not None: assert torch.all(prequant_scaling_factor > 0), ( f"prequant scaling factor {prequant_scaling_factor} not positive." 
- ) + ) return prequant_scaling_factor @@ -399,32 +399,22 @@ def get_kv_cache_bias(kv_module: nn.Module) -> list[torch.Tensor]: kv_bias.append(getattr(quantizer_module, "_bias_value", None)) return kv_bias - -def get_kv_cache_scaling_factor(kv_module: nn.Module) -> list[torch.Tensor]: - """Returns the kv_cache scaling factor if output quantizer is set. Else returns None by default.""" - if not hasattr(kv_module, "k_bmm_quantizer") or not hasattr(kv_module, "v_bmm_quantizer"): +def get_kv_cache_scaling_factor(self_attention_module: nn.Module) -> torch.Tensor: + """ + Returns the k and v BMM scaling factors if BMM quantizers are set in the self attention module. + Else returns None by default. + """ + if not hasattr(self_attention_module, "k_bmm_quantizer") or not hasattr(self_attention_module, "v_bmm_quantizer"): return [None, None] scaling_factors = [ - get_scaling_factor(getattr(kv_module, quantizer)) + get_scaling_factor(getattr(self_attention_module, quantizer)) for quantizer in ("k_bmm_quantizer", "v_bmm_quantizer") ] - - # For FP8, we recommend default kv cache scaling factor to be 1. - if get_kv_cache_dtype(kv_module) == KV_CACHE_FP8: - for i, factor in enumerate(scaling_factors): - if factor.item() > 0.5: - warn( - f"Warning: Large KV activation detected: {factor.item()}, " - "Quantized KV cache may lead to higher accuracy drop." - ) - scaling_factors[i] = torch.max( - factor, torch.tensor([1.0], dtype=torch.float, device=factor.device) - ) - return scaling_factors + def get_kv_cache_dtype(modules: list[nn.Module] | nn.Module) -> str | None: """Returns the kv_cache dtype. @@ -445,8 +435,7 @@ def get_kv_cache_dtype(modules: list[nn.Module] | nn.Module) -> str | None: for module in modules: # Case where the module has both k_bmm_quantizer and v_bmm_quantizer - # Still check for output quantizer for the unified_megatron_export path - for quantizer in ("k_bmm_quantizer", "v_bmm_quantizer", "output_quantizer"): + for quantizer in ("k_bmm_quantizer", "v_bmm_quantizer"): quantizer_attr = getattr(module, quantizer, None) if quantizer_attr and quantizer_attr.is_enabled: num_bits_list.append(quantizer_attr.num_bits) diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index 7b3883509..412d4b513 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -52,6 +52,8 @@ from .quant_utils import ( get_activation_scaling_factor, get_kv_cache_dtype, + get_kv_cache_scaling_factor, + get_quant_config, get_quantization_format, get_scaling_factor, get_weight_block_size, @@ -86,33 +88,6 @@ ] -# This path uses output_quantizer for KV cache quantization. -# The function below is the old version of get_kv_cache_scaling_factor which is now refactored to handle bmm_quantizer. -def get_kv_cache_scaling_factor(kv_module: nn.Module) -> torch.Tensor: - """Returns the kv_cache scaling factor if output quantizer is set. Else returns None by default.""" - scaling_factor = ( - get_scaling_factor(kv_module.output_quantizer) - if hasattr(kv_module, "output_quantizer") - else None - ) - - if not scaling_factor: - return None - - # For FP8, we recommend default kv cache scaling factor to be 1. - if get_kv_cache_dtype(kv_module) == KV_CACHE_FP8: - if scaling_factor.item() > 0.5: - warn( - f"!!!!Large KV activations detected: {scaling_factor.item()}, " - "Quantized KV cache may lead to higher accuracy drop.\n!!!!" 
- ) - scaling_factor = torch.max( - scaling_factor, - torch.tensor([1.0], dtype=torch.float, device=scaling_factor.device), - ) - return scaling_factor - - class GPTModelExporter: """Megatron Core GPTModel Exporter. @@ -283,6 +258,7 @@ def save_pretrained( kv_cache_quantization = None kv_cache_dtype = get_kv_cache_dtype(self.model) + print("kv_cache_dtype: ", kv_cache_dtype) if kv_cache_dtype in (KV_CACHE_FP8, KV_CACHE_NVFP4): # FP8 KV Cache is supported in VLLM; NVFP4 supported in TRTLLM kv_cache_quantization = kv_cache_dtype @@ -320,7 +296,9 @@ def save_pretrained( pass if is_last_stage_main_rank and quantization is not None: - hf_quant_config = { + # TODO refactor to use mte.quant_utils.get_quant_config + # except layer names are different in MCore and HF + hf_quant_config = { "producer": { "name": "modelopt", "version": __version__, @@ -328,9 +306,11 @@ def save_pretrained( "quantization": { "quant_algo": quantization, "kv_cache_quant_algo": kv_cache_quantization, - "exclude_modules": ["lm_head"], + "exclude_modules": ["lm_head"], # TODO update this dynamically }, } + if quantization == "NVFP4": + hf_quant_config["quantization"]["group_size"] = 16 with open(save_directory + "/hf_quant_config.json", "w") as f: json.dump(hf_quant_config, f, indent=4) @@ -473,6 +453,7 @@ def _custom_mapping_to_lambda(mapping): method_map = { "name_remapping": self._name_remapping, "qkv_slicing": self._qkv_slicing, + "self_attention_scaling": self._self_attention_scaling, "gated_mlp_slicing": self._gated_mlp_slicing, "pack_name_remapping": self._pack_name_remapping, "pack_name_remapping_gpt_oss": self._pack_name_remapping_gpt_oss, @@ -541,12 +522,8 @@ def _get_quantized_state( # TODO (chenhany): support AWQ with pre_quant_scale if hasattr(module.input_quantizer, "_pre_quant_scale"): raise ValueError("Detect pre_quant_scale! SmoothQuant/AWQ are not yet supported!") - - if hasattr(module, "output_quantizer"): - output_scale = get_kv_cache_scaling_factor(module) - if output_scale is not None: - name_to_value["output_scale"] = output_scale - + + return name_to_value, qformat, block_size def _get_quantization_format(self, module: torch.nn.Module): @@ -674,9 +651,7 @@ def _qkv_slicing( q_proj_name="q_proj", k_proj_name="k_proj", v_proj_name="v_proj", - k_scale_name="k_scale", - v_scale_name="v_scale", - ): + ): name_to_value, qformat, block_size = self._get_quantized_state(module, self.dtype) q_proj_prefix = prefix + q_proj_name + "." 
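With this refactor, k_scale and v_scale no longer piggyback on the fused-QKV export path: the next hunk drops the old output_scale handling from _qkv_slicing, and the new core_attention rule (_self_attention_scaling, added below) reads the calibrated k_bmm_quantizer / v_bmm_quantizer instead. For orientation, here is a rough sketch of what such a rule produces, assuming the common per-tensor FP8 (E4M3) convention scale = amax / 448; the exact formula lives in get_scaling_factor and may differ:

import torch

FP8_E4M3_MAX = 448.0  # assumed maxbound for per-tensor FP8 KV cache scales

def export_kv_scales(state_dict, prefix, k_amax, v_amax, maxbound=FP8_E4M3_MAX):
    """Write '<prefix>k_scale' / '<prefix>v_scale' when both KV BMM quantizers were calibrated."""
    if k_amax is None or v_amax is None:
        return state_dict  # KV cache quantization disabled or never calibrated
    state_dict[prefix + "k_scale"] = (k_amax.float() / maxbound).detach().clone()
    state_dict[prefix + "v_scale"] = (v_amax.float() / maxbound).detach().clone()
    return state_dict

# Example: with the Nemotron-H mapping added above, the scales land next to the attention
# weights, e.g. "backbone.layers.3.mixer.k_scale" and "backbone.layers.3.mixer.v_scale".
sd = export_kv_scales({}, "backbone.layers.3.mixer.", torch.tensor(2.8), torch.tensor(1.6))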
@@ -774,10 +749,7 @@ def _qkv_slicing( q_proj_key = q_proj_prefix + key k_proj_key = k_proj_prefix + key v_proj_key = v_proj_prefix + key - if key == "output_scale": - self._state_dict[prefix + k_scale_name] = val.detach().clone() - self._state_dict[prefix + v_scale_name] = val.detach().clone() - elif key == "bias": + if key == "bias": # Slice bias similar to weight bias = val.detach().clone() bias = bias.reshape([qkv_total_dim, head_size]) @@ -790,6 +762,17 @@ def _qkv_slicing( self._state_dict[k_proj_key] = val.detach().clone() self._state_dict[v_proj_key] = val.detach().clone() + def _self_attention_scaling(self, module, prefix, k_scale_name="k_scale", v_scale_name="v_scale"): + """KV cache scaling for self attention module.""" + k_scale_key = prefix + k_scale_name + v_scale_key = prefix + v_scale_name + if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"): + kv_scales = get_kv_cache_scaling_factor(module) + if all(s is not None for s in kv_scales): + self._state_dict[k_scale_key] = kv_scales[0] + self._state_dict[v_scale_key] = kv_scales[1] + + def _pack_name_remapping(self, module, prefix, layer_type=None): """Pack name remapping into one tensor.""" weight_list = [] @@ -1149,6 +1132,8 @@ def _get_state_dict(self): self.rules["q_layernorm"](layer.self_attention.q_layernorm, layer_id) self.rules["k_layernorm"](layer.self_attention.k_layernorm, layer_id) self.rules["linear_qkv"](layer.self_attention.linear_qkv, layer_id) + if hasattr(layer.self_attention, "core_attention"): + self.rules["core_attention"](layer.self_attention.core_attention, layer_id) self.rules["linear_proj"](layer.self_attention.linear_proj, layer_id) if ( getattr(layer.self_attention.core_attention, "softmax_offset", None) @@ -1166,6 +1151,10 @@ def _get_state_dict(self): self.rules["router"]( layer.mlp.router, layer_id, dtype=self.moe_router_dtype ) + if hasattr(layer.mlp, "fc1_latent_proj") and layer.mlp.fc1_latent_proj is not None: + self.rules["fc1_latent_proj"](layer.mlp.fc1_latent_proj, layer_id) + if hasattr(layer.mlp, "fc2_latent_proj") and layer.mlp.fc2_latent_proj is not None: + self.rules["fc2_latent_proj"](layer.mlp.fc2_latent_proj, layer_id) if ( hasattr(layer.mlp, "shared_experts") and layer.mlp.shared_experts is not None diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 5ec9224f3..8a45d24fe 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -100,10 +100,6 @@ def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state): if hasattr(module, "sync_moe_local_experts_amax"): module.sync_moe_local_experts_amax() - # TODO just for testing - if "experts" in name and "weight_quantizer" in name: - assert child.amax is not None - # Step 2:Sync amax across data parallelism for name, module in model.named_modules(): if isinstance(module, QuantModule): @@ -165,7 +161,6 @@ def sync_quantizer_amax_across_tp( axes_for_sync=[None, -1], parallel_state=module.parallel_state, ) - sync_quantizer_amax_across_tp( module.weight_quantizer, name, diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 9e0e9c818..60c63ac27 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -595,8 +595,6 @@ def sync_moe_local_experts_amax(self): if stored_amax is None else torch.maximum(stored_amax, amax_tensor) ) - #if isinstance(module, TensorQuantizer) and module.amax is None: - # 
print(f"MISSING AMAX BEFORE SYNC in expert rank {dist.get_rank()}: {name}", flush=True) From aa752fcdff6aa45bdaa2d065467899b2324f08cf Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Tue, 20 Jan 2026 12:26:17 -0800 Subject: [PATCH 04/20] small fixes Signed-off-by: jenchen13 --- modelopt/torch/export/unified_export_megatron.py | 6 +++--- modelopt/torch/quantization/model_calib.py | 11 ++++++----- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index 412d4b513..5fd41e91f 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -298,7 +298,7 @@ def save_pretrained( if is_last_stage_main_rank and quantization is not None: # TODO refactor to use mte.quant_utils.get_quant_config # except layer names are different in MCore and HF - hf_quant_config = { + hf_quant_config = { "producer": { "name": "modelopt", "version": __version__, @@ -309,7 +309,7 @@ def save_pretrained( "exclude_modules": ["lm_head"], # TODO update this dynamically }, } - if quantization == "NVFP4": + if quantization == "NVFP4": # update block size hf_quant_config["quantization"]["group_size"] = 16 with open(save_directory + "/hf_quant_config.json", "w") as f: json.dump(hf_quant_config, f, indent=4) @@ -763,7 +763,7 @@ def _qkv_slicing( self._state_dict[v_proj_key] = val.detach().clone() def _self_attention_scaling(self, module, prefix, k_scale_name="k_scale", v_scale_name="v_scale"): - """KV cache scaling for self attention module.""" + """KV cache scaling for CoreAttention module.""" k_scale_key = prefix + k_scale_name v_scale_key = prefix + v_scale_name if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"): diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 8a45d24fe..233469277 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -81,6 +81,11 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis forward_loop(model) finish_stats_collection(model) + # Step 1: Sync amax across local experts in a SequentialMLP + for name, module in model.named_modules(): + if hasattr(module, "sync_moe_local_experts_amax"): + module.sync_moe_local_experts_amax() + if not distributed_sync: return @@ -95,11 +100,7 @@ def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state): quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group) # TODO: create sync_bias_across_distributed_group - # Step 1: Sync amax across local experts in a SequentialMLP - for name, module in model.named_modules(): - if hasattr(module, "sync_moe_local_experts_amax"): - module.sync_moe_local_experts_amax() - + # Step 2:Sync amax across data parallelism for name, module in model.named_modules(): if isinstance(module, QuantModule): From 256578aac26928d31d798f60f5722f50a4400c08 Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Thu, 22 Jan 2026 09:39:40 -0800 Subject: [PATCH 05/20] mtp import fixes Signed-off-by: jenchen13 --- .../torch/export/plugins/mcore_nemotron.py | 10 +- .../torch/export/plugins/megatron_importer.py | 333 ++++++++++-------- modelopt/torch/export/quant_utils.py | 31 +- .../torch/export/unified_export_megatron.py | 17 +- 4 files changed, 217 insertions(+), 174 deletions(-) diff --git a/modelopt/torch/export/plugins/mcore_nemotron.py b/modelopt/torch/export/plugins/mcore_nemotron.py index 
f857230ae..07f4656bf 100644 --- a/modelopt/torch/export/plugins/mcore_nemotron.py +++ b/modelopt/torch/export/plugins/mcore_nemotron.py @@ -86,12 +86,10 @@ "fc1_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc1_latent_proj.", REPLICATE), "fc2_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc2_latent_proj.", REPLICATE), # MTP - #"enorm": NameRemapping("mtp.layers.{}.enorm.", REPLICATE), - #"hnorm": NameRemapping("mtp.layers.{}.hnorm.", REPLICATE), - #"eh_proj": NameRemapping("mtp.layers.{}.eh_proj.", REPLICATE), - #"layer_norm": NameRemapping("mtp.layers.{}.final_layernorm.", REPLICATE), - #"norm": NameRemapping("mtp.layers.{}.norm", REPLICATE) - # "transformer_layer": NameRemapping("mtp.layers.{}.mixer", REPLICATE), + "mtp.enorm": NameRemapping("mtp.layers.{}.enorm.", REPLICATE), + "mtp.hnorm": NameRemapping("mtp.layers.{}.hnorm.", REPLICATE), + "mtp.eh_proj": NameRemapping("mtp.layers.{}.eh_proj.", REPLICATE), + "mtp.final_layernorm": NameRemapping("mtp.layers.{}.final_layernorm.", REPLICATE), } diff --git a/modelopt/torch/export/plugins/megatron_importer.py b/modelopt/torch/export/plugins/megatron_importer.py index f663e1921..42a483b99 100644 --- a/modelopt/torch/export/plugins/megatron_importer.py +++ b/modelopt/torch/export/plugins/megatron_importer.py @@ -150,7 +150,14 @@ def _name_remapping( mapping={}, parallel_config: ParallelConfig | None = None, dtype: torch.dtype | None = None, + is_mtp: bool = False, ): + if is_mtp: + if "backbone" in prefix: + prefix = prefix.replace("backbone", "mtp") + else: + prefix = prefix.replace("model", "mtp") + print(f"name_remapping: {prefix}, mapping: {mapping}") if dtype is None: dtype = self.dtype if isinstance(module, torch.Tensor): @@ -262,7 +269,13 @@ def _qkv_merging( k_proj_name="k_proj", v_proj_name="v_proj", parallel_config: ParallelConfig | None = None, + is_mtp: bool = False, ): + if is_mtp: + if "backbone" in prefix: + prefix = prefix.replace("backbone", "mtp") + else: + prefix = prefix.replace("model", "mtp") config = module.config hidden_size = config.hidden_size num_query_groups = config.num_query_groups @@ -469,9 +482,111 @@ def _unpack_name_remapping_gpt_oss( linear_module.load_state_dict(state_dict) + def _import_mamba_layer(self, layer, layer_id, layer_pbar): + layer_pbar.set_description("Importing Mamba layer") + if not isinstance(layer.norm, IdentityOp): + self.rules["norm"](layer.norm, layer_id) + + self.rules["mixer_norm"](layer.mixer.norm, layer_id) + self.rules["A_log"](layer.mixer.A_log, layer_id) + self.rules["D"](layer.mixer.D, layer_id) + self.rules["dt_bias"](layer.mixer.dt_bias, layer_id) + self.rules["conv1d"](layer.mixer.conv1d, layer_id) + self.rules["in_proj"](layer.mixer.in_proj, layer_id) + self.rules["out_proj"](layer.mixer.out_proj, layer_id) + + def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp=False): + if not isinstance(layer.input_layernorm, IdentityOp): + self.rules["input_layernorm"](layer.input_layernorm, layer_id, is_mtp=is_mtp) + + attention = layer.self_attention + if not isinstance(attention, IdentityOp): + if "MLASelfAttention" in str(type(attention)): + if hasattr(attention, "linear_q_proj"): + layer_pbar.set_description("Importing MLA (without q LoRA)") + self.rules["linear_q_proj"](attention.linear_q_proj, layer_id, is_mtp=is_mtp) + else: + layer_pbar.set_description("Importing MLA (with q LoRA)") + self.rules["linear_q_down_proj"](attention.linear_q_down_proj, layer_id, is_mtp=is_mtp) + self.rules["linear_q_layernorm"](attention.q_layernorm, layer_id, 
is_mtp=is_mtp) + self.rules["linear_q_up_proj"](attention.linear_q_up_proj, layer_id, is_mtp=is_mtp) + self.rules["linear_kv_down_proj"](attention.linear_kv_down_proj, layer_id, is_mtp=is_mtp) + self.rules["linear_kv_layernorm"](attention.kv_layernorm, layer_id, is_mtp=is_mtp) + self.rules["linear_kv_up_proj"](attention.linear_kv_up_proj, layer_id, is_mtp=is_mtp) + self.rules["linear_proj"](attention.linear_proj, layer_id, is_mtp=is_mtp) + else: + layer_pbar.set_description("Importing GQA/MHA") + if attention.q_layernorm is not None and not isinstance( + attention.q_layernorm, (IdentityOp, L2Norm) + ): + self.rules["q_layernorm"](attention.q_layernorm, layer_id, is_mtp=is_mtp) + self.rules["k_layernorm"](attention.k_layernorm, layer_id, is_mtp=is_mtp) + self.rules["linear_qkv"](attention.linear_qkv, layer_id, is_mtp=is_mtp) + self.rules["linear_proj"](attention.linear_proj, layer_id, is_mtp=is_mtp) + if getattr(attention.core_attention, "softmax_offset", None) is not None: + self.rules["softmax_offset"]( + attention.core_attention.softmax_offset, layer_id, is_mtp=is_mtp + ) + + if not isinstance(layer.pre_mlp_layernorm, IdentityOp): + self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id, is_mtp=is_mtp) + + if not isinstance(layer.mlp, IdentityOp): + if "MoE" in str(type(layer.mlp)): + layer_pbar.set_description("Importing MoE") + self.rules["router"]( + layer.mlp.router, layer_id, dtype=self.moe_router_dtype, is_mtp=is_mtp + ) + if ( + hasattr(layer.mlp, "shared_experts") + and layer.mlp.shared_experts is not None + ): + layer_pbar.set_description("Importing MoE shared experts") + fc1 = layer.mlp.shared_experts.linear_fc1 + fc2 = layer.mlp.shared_experts.linear_fc2 + self.rules["shared_experts.linear_fc1"](fc1, layer_id, is_mtp=is_mtp) + self.rules["shared_experts.linear_fc2"](fc2, layer_id, is_mtp=is_mtp) + if not self.rules.get("use_packed_local_experts", False): + for local_expert_id, expert in tqdm( + enumerate(layer.mlp.experts.local_experts), + desc="Importing MoE local experts", + leave=False, + disable=self.disable_tqdm, + ): + expert_id = layer.mlp.local_expert_indices[local_expert_id] + fc1 = expert.linear_fc1 + fc2 = expert.linear_fc2 + self.rules["local_experts.linear_fc1"](fc1, layer_id, expert_id, is_mtp=is_mtp) + self.rules["local_experts.linear_fc2"](fc2, layer_id, expert_id, is_mtp=is_mtp) + # We only support either EP or ETP for now + elif get_expert_tensor_parallel_world_size() > 1: + # ETP supports for packed MoE + # ETP is not supported for gpt-oss model + if self.arch in ["GptOssForCausalLM"]: + raise ValueError("ETP is not supported for gpt-oss model") + self.rules["local_experts.linear_fc1_etp"]( + layer.mlp.experts.local_experts, layer_id, is_mtp=is_mtp + ) + self.rules["local_experts.linear_fc2_etp"]( + layer.mlp.experts.local_experts, layer_id, is_mtp=is_mtp + ) + else: + # EP supports for packed MoE + self.rules["local_experts.linear_fc1_ep"]( + layer.mlp.experts.local_experts, layer_id, is_mtp=is_mtp + ) + self.rules["local_experts.linear_fc2_ep"]( + layer.mlp.experts.local_experts, layer_id, is_mtp=is_mtp + ) + else: + layer_pbar.set_description("Importing MLP") + self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id, is_mtp=is_mtp) + self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id, is_mtp=is_mtp) + + def _import_state_dict(self): model = self.model - + print(model, flush=True) layer_pbar = tqdm(model.decoder.layers, disable=self.disable_tqdm) # Embedding @@ -481,108 +596,13 @@ def _import_state_dict(self): # Decoder layers for layer in 
layer_pbar: + print(f"Importing layer {layer.layer_number}", flush=True) layer_id = layer.layer_number - 1 if isinstance(layer, MambaLayer): - if not isinstance(layer.norm, IdentityOp): - self.rules["norm"](layer.norm, layer_id) - - self.rules["mixer_norm"](layer.mixer.norm, layer_id) - self.rules["A_log"](layer.mixer.A_log, layer_id) - self.rules["D"](layer.mixer.D, layer_id) - self.rules["dt_bias"](layer.mixer.dt_bias, layer_id) - - self.rules["conv1d"](layer.mixer.conv1d, layer_id) - self.rules["in_proj"](layer.mixer.in_proj, layer_id) - self.rules["out_proj"](layer.mixer.out_proj, layer_id) - + self._import_mamba_layer(layer, layer_id, layer_pbar) elif isinstance(layer, TransformerLayer): - if not isinstance(layer.input_layernorm, IdentityOp): - self.rules["input_layernorm"](layer.input_layernorm, layer_id) - - attention = layer.self_attention - if not isinstance(attention, IdentityOp): - if "MLASelfAttention" in str(type(attention)): - if hasattr(attention, "linear_q_proj"): - layer_pbar.set_description("Importing MLA (without q LoRA)") - self.rules["linear_q_proj"](attention.linear_q_proj, layer_id) - else: - layer_pbar.set_description("Importing MLA (with q LoRA)") - self.rules["linear_q_down_proj"](attention.linear_q_down_proj, layer_id) - self.rules["linear_q_layernorm"](attention.q_layernorm, layer_id) - self.rules["linear_q_up_proj"](attention.linear_q_up_proj, layer_id) - self.rules["linear_kv_down_proj"](attention.linear_kv_down_proj, layer_id) - self.rules["linear_kv_layernorm"](attention.kv_layernorm, layer_id) - self.rules["linear_kv_up_proj"](attention.linear_kv_up_proj, layer_id) - self.rules["linear_proj"](attention.linear_proj, layer_id) - else: - layer_pbar.set_description("Importing GQA/MHA") - if attention.q_layernorm is not None and not isinstance( - attention.q_layernorm, (IdentityOp, L2Norm) - ): - self.rules["q_layernorm"](attention.q_layernorm, layer_id) - self.rules["k_layernorm"](attention.k_layernorm, layer_id) - self.rules["linear_qkv"](attention.linear_qkv, layer_id) - self.rules["linear_proj"](attention.linear_proj, layer_id) - if getattr(attention.core_attention, "softmax_offset", None) is not None: - self.rules["softmax_offset"]( - attention.core_attention.softmax_offset, layer_id - ) - - if not isinstance(layer.pre_mlp_layernorm, IdentityOp): - self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id) - - if not isinstance(layer.mlp, IdentityOp): - if "MoE" in str(type(layer.mlp)): - layer_pbar.set_description("Importing MoE") - self.rules["router"]( - layer.mlp.router, layer_id, dtype=self.moe_router_dtype - ) - if ( - hasattr(layer.mlp, "shared_experts") - and layer.mlp.shared_experts is not None - ): - layer_pbar.set_description("Importing MoE shared experts") - fc1 = layer.mlp.shared_experts.linear_fc1 - fc2 = layer.mlp.shared_experts.linear_fc2 - self.rules["shared_experts.linear_fc1"](fc1, layer_id) - self.rules["shared_experts.linear_fc2"](fc2, layer_id) - if not self.rules.get("use_packed_local_experts", False): - for local_expert_id, expert in tqdm( - enumerate(layer.mlp.experts.local_experts), - desc="Importing MoE local experts", - leave=False, - disable=self.disable_tqdm, - ): - expert_id = layer.mlp.local_expert_indices[local_expert_id] - fc1 = expert.linear_fc1 - fc2 = expert.linear_fc2 - self.rules["local_experts.linear_fc1"](fc1, layer_id, expert_id) - self.rules["local_experts.linear_fc2"](fc2, layer_id, expert_id) - # We only support either EP or ETP for now - elif get_expert_tensor_parallel_world_size() > 1: - # ETP supports 
for packed MoE - # ETP is not supported for gpt-oss model - if self.arch in ["GptOssForCausalLM"]: - raise ValueError("ETP is not supported for gpt-oss model") - self.rules["local_experts.linear_fc1_etp"]( - layer.mlp.experts.local_experts, layer_id - ) - self.rules["local_experts.linear_fc2_etp"]( - layer.mlp.experts.local_experts, layer_id - ) - else: - # EP supports for packed MoE - self.rules["local_experts.linear_fc1_ep"]( - layer.mlp.experts.local_experts, layer_id - ) - self.rules["local_experts.linear_fc2_ep"]( - layer.mlp.experts.local_experts, layer_id - ) - else: - layer_pbar.set_description("Importing MLP") - self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id) - self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id) + self._import_transformer_layer(layer, layer_id, layer_pbar) if self.verbose: print( @@ -591,71 +611,92 @@ def _import_state_dict(self): ), flush=True, ) + break # TODO: remove this # Final layernorm if hasattr(model.decoder, "final_layernorm") and model.decoder.final_layernorm: self.rules["final_layernorm"](model.decoder.final_layernorm) - if hasattr(model.decoder, "final_norm") and model.decoder.final_norm: self.rules["final_norm"](model.decoder.final_norm) # Output layer if hasattr(model, "output_layer") and not model.share_embeddings_and_output_weights: self.rules["output_layer"](model.output_layer) + # MTP if hasattr(model, "mtp"): + print("Importing MTP", flush=True) # MTP is the last layer in DeepSeek V3/R1 - layer_id += 1 - for mtp in model.mtp: - self.rules["mtp.fc"](mtp.fc, layer_id) + if len(model.mtp.layers) == 1: # Repeated MTP + layer_id = 0 # reset layer_id for repeated MTP + mtp = model.mtp.layers[0] + + self.rules["mtp.eh_proj"](mtp.eh_proj, layer_id) self.rules["mtp.enorm"](mtp.enorm, layer_id) self.rules["mtp.hnorm"](mtp.hnorm, layer_id) - self.rules["mtp.input_layernorm"](mtp.decoder.layers[0].input_layernorm, layer_id) - if hasattr(mtp.decoder.layers[0].self_attention, "linear_q_proj"): - self.rules["mtp.linear_q_proj"]( - mtp.decoder.layers[0].self_attention.linear_q_proj, layer_id + + mtp_model_layers = mtp.mtp_model_layer.layers + for mtp_model_layer in mtp_model_layers: + if isinstance(mtp_model_layer, MambaLayer): + self._import_mamba_layer(mtp_model_layer, layer_id, layer_pbar) + elif isinstance(mtp_model_layer, TransformerLayer): + self._import_transformer_layer(mtp_model_layer, layer_id, layer_pbar, is_mtp=True) + else: + raise ValueError(f"Unsupported layer type during MTP import: {type(mtp_model_layer)}") + + layer_id += 1 + else: # non-repeated MTP + + for mtp in model.mtp.layers: + self.rules["mtp.eh_proj"](mtp.eh_proj, layer_id) + self.rules["mtp.enorm"](mtp.enorm, layer_id) + self.rules["mtp.hnorm"](mtp.hnorm, layer_id) + self.rules["mtp.input_layernorm"](mtp.decoder.layers[0].input_layernorm, layer_id) + if hasattr(mtp.decoder.layers[0].self_attention, "linear_q_proj"): + self.rules["mtp.linear_q_proj"]( + mtp.decoder.layers[0].self_attention.linear_q_proj, layer_id + ) + else: + self.rules["mtp.linear_q_down_proj"]( + mtp.decoder.layers[0].self_attention.linear_q_down_proj, layer_id + ) + self.rules["mtp.linear_q_layernorm"]( + mtp.decoder.layers[0].self_attention.q_layernorm, layer_id + ) + self.rules["mtp.linear_q_up_proj"]( + mtp.decoder.layers[0].self_attention.linear_q_up_proj, layer_id + ) + self.rules["mtp.linear_kv_down_proj"]( + mtp.decoder.layers[0].self_attention.linear_kv_down_proj, layer_id ) - else: - self.rules["mtp.linear_q_down_proj"]( - mtp.decoder.layers[0].self_attention.linear_q_down_proj, layer_id + 
self.rules["mtp.linear_kv_layernorm"]( + mtp.decoder.layers[0].self_attention.kv_layernorm, layer_id ) - self.rules["mtp.linear_q_layernorm"]( - mtp.decoder.layers[0].self_attention.q_layernorm, layer_id + self.rules["mtp.linear_kv_up_proj"]( + mtp.decoder.layers[0].self_attention.linear_kv_up_proj, layer_id ) - self.rules["mtp.linear_q_up_proj"]( - mtp.decoder.layers[0].self_attention.linear_q_up_proj, layer_id + self.rules["mtp.linear_proj"]( + mtp.decoder.layers[0].self_attention.linear_proj, layer_id ) - self.rules["mtp.linear_kv_down_proj"]( - mtp.decoder.layers[0].self_attention.linear_kv_down_proj, layer_id - ) - self.rules["mtp.linear_kv_layernorm"]( - mtp.decoder.layers[0].self_attention.kv_layernorm, layer_id - ) - self.rules["mtp.linear_kv_up_proj"]( - mtp.decoder.layers[0].self_attention.linear_kv_up_proj, layer_id - ) - self.rules["mtp.linear_proj"]( - mtp.decoder.layers[0].self_attention.linear_proj, layer_id - ) - self.rules["mtp.pre_mlp_layernorm"]( - mtp.decoder.layers[0].pre_mlp_layernorm, layer_id - ) - self.rules["mtp.router"](mtp.decoder.layers[0].mlp.router, layer_id) - self.rules["mtp.shared_experts.linear_fc1"]( - mtp.decoder.layers[0].mlp.shared_experts.linear_fc1, layer_id - ) - self.rules["mtp.shared_experts.linear_fc2"]( - mtp.decoder.layers[0].mlp.shared_experts.linear_fc2, layer_id - ) - for expert_id, expert in tqdm( - enumerate(mtp.decoder.layers[0].mlp.experts.local_experts), - desc="Importing MoE local experts", - leave=False, - disable=self.disable_tqdm, - ): - self.rules["mtp.local_experts.linear_fc1"]( - expert.linear_fc1, layer_id, expert_id + self.rules["mtp.pre_mlp_layernorm"]( + mtp.decoder.layers[0].pre_mlp_layernorm, layer_id ) - self.rules["mtp.local_experts.linear_fc2"]( - expert.linear_fc2, layer_id, expert_id + self.rules["mtp.router"](mtp.decoder.layers[0].mlp.router, layer_id) + self.rules["mtp.shared_experts.linear_fc1"]( + mtp.decoder.layers[0].mlp.shared_experts.linear_fc1, layer_id ) + self.rules["mtp.shared_experts.linear_fc2"]( + mtp.decoder.layers[0].mlp.shared_experts.linear_fc2, layer_id + ) + for expert_id, expert in tqdm( + enumerate(mtp.decoder.layers[0].mlp.experts.local_experts), + desc="Importing MoE local experts", + leave=False, + disable=self.disable_tqdm, + ): + self.rules["mtp.local_experts.linear_fc1"]( + expert.linear_fc1, layer_id, expert_id + ) + self.rules["mtp.local_experts.linear_fc2"]( + expert.linear_fc2, layer_id, expert_id + ) diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index 00f03b9aa..48dec6ae1 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -435,12 +435,29 @@ def get_kv_cache_dtype(modules: list[nn.Module] | nn.Module) -> str | None: for module in modules: # Case where the module has both k_bmm_quantizer and v_bmm_quantizer - for quantizer in ("k_bmm_quantizer", "v_bmm_quantizer"): + # Still check for output quantizer for the unified_megatron_export path + for quantizer in ("k_bmm_quantizer", "v_bmm_quantizer", "output_quantizer"): quantizer_attr = getattr(module, quantizer, None) if quantizer_attr and quantizer_attr.is_enabled: num_bits_list.append(quantizer_attr.num_bits) is_affine &= hasattr(quantizer_attr, "_bias_value") + return _compute_kv_cache_dtype(num_bits_list) + +def _compute_kv_cache_dtype(num_bits_list: list[int]) -> str | None: + """Returns the kv_cache dtype. + + If num_bits of output_quantizer is (4, 3) then returns FP8; if it is 8, returns int8, + otherwise returns None. 
+
+    Args:
+        num_bits_list: The num_bits values collected from the enabled KV cache quantizers.
+
+    Returns:
+        The kv_cache dtype.
+    """
+    is_affine = True
+
     if (4, 3) in num_bits_list:
         return KV_CACHE_FP8
     elif 8 in num_bits_list:
@@ -996,18 +1013,6 @@ def postprocess_state_dict(
                     value = value.float() / maxbound
 
-                    # Warn if scale exceeds threshold
-                    if quantization == KV_CACHE_FP8 and value.item() > 0.5:
-                        logger.warning(
-                            "Large KV activations detected. Quantized KV cache may lead to higher accuracy drop. "
-                            "Setting KV cache scaling factor to at least 1."
-                        )
-
-                    # Ensure scale is at least 1 for KV_CACHE_FP8
-                    # We export real value for KV_CACHE_NVFP4
-                    if quantization == KV_CACHE_FP8:
-                        value.clamp_(min=1.0)
-
                 post_state_dict[prefix + new_suffix] = value
                 break
diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py
index 5fd41e91f..e55183aed 100644
--- a/modelopt/torch/export/unified_export_megatron.py
+++ b/modelopt/torch/export/unified_export_megatron.py
@@ -51,8 +51,8 @@ from .plugins.megatron_importer import GPTModelImporter
 from .quant_utils import (
     get_activation_scaling_factor,
-    get_kv_cache_dtype,
     get_kv_cache_scaling_factor,
+    get_kv_cache_dtype,
     get_quant_config,
     get_quantization_format,
     get_scaling_factor,
@@ -256,12 +256,6 @@ def save_pretrained(
         elif quantization_format == QUANTIZATION_NVFP4:
             quantization = "NVFP4"
 
-        kv_cache_quantization = None
-        kv_cache_dtype = get_kv_cache_dtype(self.model)
-        print("kv_cache_dtype: ", kv_cache_dtype)
-        if kv_cache_dtype in (KV_CACHE_FP8, KV_CACHE_NVFP4):
-            # FP8 KV Cache is supported in VLLM; NVFP4 supported in TRTLLM
-            kv_cache_quantization = kv_cache_dtype
 
         # We use the last PP rank and the 1st EP rank to write the config because
         # medusa_heads and eagle_module only exist in the last stage.
if is_last_stage_main_rank: @@ -305,12 +299,13 @@ def save_pretrained( }, "quantization": { "quant_algo": quantization, - "kv_cache_quant_algo": kv_cache_quantization, "exclude_modules": ["lm_head"], # TODO update this dynamically }, } if quantization == "NVFP4": # update block size hf_quant_config["quantization"]["group_size"] = 16 + if hasattr(self, "kv_cache_dtype"): + hf_quant_config["quantization"]["kv_cache_quant_algo"] = self.kv_cache_dtype with open(save_directory + "/hf_quant_config.json", "w") as f: json.dump(hf_quant_config, f, indent=4) @@ -731,7 +726,7 @@ def _qkv_slicing( quantized_weight = to_quantized_weight( weight, scale, - qformat, + qformat, weight_scale_2, block_size, ) @@ -772,6 +767,10 @@ def _self_attention_scaling(self, module, prefix, k_scale_name="k_scale", v_scal self._state_dict[k_scale_key] = kv_scales[0] self._state_dict[v_scale_key] = kv_scales[1] + kv_cache_dtype = get_kv_cache_dtype(module) + if kv_cache_dtype in (KV_CACHE_FP8, KV_CACHE_NVFP4): + # FP8 KV Cache is supported in VLLM; NVFP4 supported in TRTLLM + self.kv_cache_dtype = kv_cache_dtype def _pack_name_remapping(self, module, prefix, layer_type=None): """Pack name remapping into one tensor.""" From 75dd530dd60e69e73622b91eca39001bf5c1862a Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Thu, 22 Jan 2026 13:29:08 -0800 Subject: [PATCH 06/20] enable TELinear quant Signed-off-by: jenchen13 --- modelopt/torch/export/plugins/mcore_nemotron.py | 6 ++++++ modelopt/torch/quantization/plugins/megatron.py | 7 +++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/modelopt/torch/export/plugins/mcore_nemotron.py b/modelopt/torch/export/plugins/mcore_nemotron.py index 07f4656bf..d2cb2858a 100644 --- a/modelopt/torch/export/plugins/mcore_nemotron.py +++ b/modelopt/torch/export/plugins/mcore_nemotron.py @@ -131,4 +131,10 @@ # Latent MoE "fc1_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc1_latent_proj."), "fc2_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc2_latent_proj."), + # MTP + "mtp.enorm": NameRemapping("mtp.layers.{}.enorm."), + "mtp.hnorm": NameRemapping("mtp.layers.{}.hnorm."), + "mtp.eh_proj": NameRemapping("mtp.layers.{}.eh_proj."), + "mtp.final_layernorm": NameRemapping("mtp.layers.{}.final_layernorm."), + } diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 60c63ac27..03b3888f3 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -50,6 +50,7 @@ try: from megatron.core.extensions.transformer_engine import ( + TELinear, TEColumnParallelGroupedLinear, TEColumnParallelLinear, TEDotProductAttention, @@ -603,8 +604,6 @@ def sync_moe_local_experts_amax(self): for name, module in expert.named_modules(): if isinstance(module, TensorQuantizer) and module.amax is not None: module.amax = amax_dict[name].detach().clone().to(module.amax.device) - #if isinstance(module, TensorQuantizer) and module.amax is None: - # print(f"MISSING AMAX AFTER SYNC in expert rank {dist.get_rank()}: {name}", flush=True) def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): """Override the default to enable singleton_local_shards. 
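The next hunk registers plain TELinear with QuantModuleRegistry, mirroring the existing column-/row-parallel registrations, so TE linears that are neither column- nor row-parallel also get converted during quantization. Once a class is registered, the standard quantize-and-calibrate entry point picks it up with no per-layer wiring; below is a small usage sketch on a toy model (the config name and call signature follow the public modelopt.torch.quantization API as I understand it, and the plain nn.Linear stack merely stands in for the Megatron-Core model):

import torch
import torch.nn as nn
import modelopt.torch.quantization as mtq

# Toy stand-in model; in practice this is the Megatron-Core model whose TELinear layers
# the new registry entry makes quantizable.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
calib_batches = [torch.randn(4, 16) for _ in range(8)]

def forward_loop(m):
    # Calibration: run a few batches so every inserted TensorQuantizer records its amax.
    with torch.no_grad():
        for x in calib_batches:
            m(x)

# Registered module classes are swapped for their quantized counterparts, then calibrated.
model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)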
@@ -624,6 +623,10 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): if HAS_TE: + @QuantModuleRegistry.register({TELinear: "te_mcore_Linear"}) + class _QuantTEMCoreLinear(_QuantTELinear): + pass + @QuantModuleRegistry.register({TERowParallelLinear: "te_mcore_RowParallelLinear"}) class _QuantTEMCoreRowParallelLinear(_QuantTELinear, _MegatronRowParallelLinear): pass From d99b6a5a39c9810b854c861f6f9511742f194fd5 Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Fri, 23 Jan 2026 15:09:05 -0800 Subject: [PATCH 07/20] import grouped mlp in mtp Signed-off-by: jenchen13 --- modelopt/torch/export/plugins/mcore_custom.py | 9 ++ .../torch/export/plugins/mcore_nemotron.py | 15 +- .../torch/export/plugins/megatron_importer.py | 149 ++++++++++++------ .../torch/export/unified_export_megatron.py | 2 +- modelopt/torch/quantization/model_calib.py | 1 - 5 files changed, 123 insertions(+), 53 deletions(-) diff --git a/modelopt/torch/export/plugins/mcore_custom.py b/modelopt/torch/export/plugins/mcore_custom.py index 25a2cd0cb..77a3208bb 100644 --- a/modelopt/torch/export/plugins/mcore_custom.py +++ b/modelopt/torch/export/plugins/mcore_custom.py @@ -102,7 +102,16 @@ def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] func_kwargs=func_kwargs, ) +class GroupedMLPMerging(CustomModuleMapping): + """A custom module mapping that merges up_proj and down_proj for Grouped MLP.""" + def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] = {}): + """Create a custom module mapping that merges up_proj and down_proj for Grouped MLP.""" + super().__init__( + func_name="grouped_mlp_merging", + target_name_or_prefix=target_name_or_prefix, + func_kwargs=func_kwargs, + ) class GatedMLPMerging(CustomModuleMapping): """A custom module mapping that merges gate_proj and up_proj.""" diff --git a/modelopt/torch/export/plugins/mcore_nemotron.py b/modelopt/torch/export/plugins/mcore_nemotron.py index d2cb2858a..385418a37 100644 --- a/modelopt/torch/export/plugins/mcore_nemotron.py +++ b/modelopt/torch/export/plugins/mcore_nemotron.py @@ -26,6 +26,7 @@ NameRemapping, QKVMerging, QKVSlicing, + GroupedMLPMerging, SelfAttentionScaling, ) @@ -85,12 +86,14 @@ # Latent MoE "fc1_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc1_latent_proj.", REPLICATE), "fc2_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc2_latent_proj.", REPLICATE), - # MTP - "mtp.enorm": NameRemapping("mtp.layers.{}.enorm.", REPLICATE), - "mtp.hnorm": NameRemapping("mtp.layers.{}.hnorm.", REPLICATE), - "mtp.eh_proj": NameRemapping("mtp.layers.{}.eh_proj.", REPLICATE), - "mtp.final_layernorm": NameRemapping("mtp.layers.{}.final_layernorm.", REPLICATE), - + # Repeated MTP module + "mtp.enorm": NameRemapping("mtp.layers.{}.enorm.", {"is_mtp": True}), + "mtp.hnorm": NameRemapping("mtp.layers.{}.hnorm.", {"is_mtp": True}), + "mtp.eh_proj": NameRemapping("mtp.layers.{}.eh_proj.", {"is_mtp": True}), + "mtp.final_layernorm": NameRemapping("mtp.layers.{}.final_layernorm.", {"is_mtp": True}), + # Grouped local experts in MTP + "experts.linear_fc1": GroupedMLPMerging("mtp.layers.{}.experts.{{}}.up_proj", COL_ETP | {"is_mtp": True}), + "experts.linear_fc2": GroupedMLPMerging("mtp.layers.{}.experts.{{}}.down_proj", ROW_ETP | {"is_mtp": True}), } diff --git a/modelopt/torch/export/plugins/megatron_importer.py b/modelopt/torch/export/plugins/megatron_importer.py index 42a483b99..79ce62573 100644 --- a/modelopt/torch/export/plugins/megatron_importer.py +++ 
b/modelopt/torch/export/plugins/megatron_importer.py @@ -19,7 +19,7 @@ from pathlib import Path import torch -import torch.distributed +import torch.distributed as dist from huggingface_hub import snapshot_download from tqdm import tqdm @@ -40,6 +40,7 @@ with import_plugin("megatron"): from megatron.core.parallel_state import ( get_expert_tensor_parallel_world_size, + get_expert_tensor_parallel_rank, get_tensor_model_parallel_world_size, ) from megatron.core.ssm.mamba_layer import MambaLayer @@ -94,12 +95,12 @@ def __init__( if workspace_dir is None: workspace_dir = tempfile.gettempdir() pretrained_model_path = workspace_dir + "/" + pretrained_model_name_or_path - if torch.distributed.get_rank() == 0: + if dist.get_rank() == 0: snapshot_download( repo_id=pretrained_model_name_or_path, local_dir=pretrained_model_path, ) - torch.distributed.barrier() + dist.barrier() self.arch = self._hf_config.architectures[0] self.all_rules = self._populate_rule_book() self.rules = self.all_rules[self.arch] @@ -108,7 +109,7 @@ def __init__( self.dtype = dtype self.dequantize = dequantize self.verbose = verbose - self.disable_tqdm = torch.distributed.get_rank() > 0 or verbose + self.disable_tqdm = dist.get_rank() > 0 or verbose def _populate_rule_book(self): """The rule book maps each state_dict key to a Callable.""" @@ -119,6 +120,7 @@ def _custom_mapping_to_lambda(mapping): "name_remapping": self._name_remapping, "qkv_merging": self._qkv_merging, "gated_mlp_merging": self._gated_mlp_merging, + "grouped_mlp_merging": self._grouped_mlp_merging, "unpack_name_remapping": self._unpack_name_remapping, "unpack_name_remapping_gpt_oss": self._unpack_name_remapping_gpt_oss, } @@ -157,7 +159,6 @@ def _name_remapping( prefix = prefix.replace("backbone", "mtp") else: prefix = prefix.replace("model", "mtp") - print(f"name_remapping: {prefix}, mapping: {mapping}") if dtype is None: dtype = self.dtype if isinstance(module, torch.Tensor): @@ -261,6 +262,37 @@ def _gated_mlp_merging( module.load_state_dict(state_dict) + def _grouped_mlp_merging( + self, + module, + prefix, + parallel_config: ParallelConfig | None = None, + is_mtp: bool = False, + init_expert_id: int = 0, + num_local_experts: int = 1, + ): + if is_mtp: + if "backbone" in prefix: + prefix = prefix.replace("backbone", "mtp") + else: + prefix = prefix.replace("model", "mtp") + + state_dict = module.state_dict() + weight = state_dict.get("weight", None) + print(f"mcore weight.shape: {weight.shape}") + weight_scale = state_dict.get("weight_quantizer._scale", None) + + all_experts = [] + for expert_id in range(init_expert_id, init_expert_id + num_local_experts): + tensor = self._get_safetensor(prefix.format(expert_id) + ".weight") + print(f"HF weight.shape: {tensor.shape}") + all_experts.append(tensor) + all_experts = torch.cat(all_experts, dim=0) + print(f"all_experts.shape: {all_experts.shape}") + state_dict["weight"] = all_experts + + module.load_state_dict(state_dict) + def _qkv_merging( self, module, @@ -302,8 +334,9 @@ def _qkv_merging( state_dict = {} - weight = module.state_dict().get("weight", None) - weight_scale = module.state_dict().get("weight_quantizer._scale", None) + module_state_dict = module.state_dict() + weight = module_state_dict.get("weight", None) + weight_scale = module_state_dict.get("weight_quantizer._scale", None) if weight is None: raise ValueError(f"{module!s} does not contain weight!") @@ -357,7 +390,7 @@ def _qkv_merging( state_dict["weight"] = tensor.reshape(-1, hidden_size) # Handle bias merging - bias = 
module.state_dict().get("bias", None) + bias = module_state_dict.get("bias", None) if bias is not None: q_bias = self._get_safetensor( prefix + q_proj_name + ".bias", parallel_config=parallel_config @@ -384,6 +417,11 @@ def _qkv_merging( state_dict["bias"] = bias_tensor.reshape(-1) + layer_norm_weight = module_state_dict.get("layer_norm_weight", None) + if layer_norm_weight is not None: + state_dict["layer_norm_weight"] = layer_norm_weight + state_dict["_extra_state"] = None # for TE modules require _extra_state key + module.load_state_dict(state_dict) def _unpack_name_remapping( @@ -495,47 +533,47 @@ def _import_mamba_layer(self, layer, layer_id, layer_pbar): self.rules["in_proj"](layer.mixer.in_proj, layer_id) self.rules["out_proj"](layer.mixer.out_proj, layer_id) - def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp=False): + def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool = False): if not isinstance(layer.input_layernorm, IdentityOp): - self.rules["input_layernorm"](layer.input_layernorm, layer_id, is_mtp=is_mtp) + self.rules["input_layernorm"](layer.input_layernorm, layer_id) attention = layer.self_attention if not isinstance(attention, IdentityOp): if "MLASelfAttention" in str(type(attention)): if hasattr(attention, "linear_q_proj"): layer_pbar.set_description("Importing MLA (without q LoRA)") - self.rules["linear_q_proj"](attention.linear_q_proj, layer_id, is_mtp=is_mtp) + self.rules["linear_q_proj"](attention.linear_q_proj, layer_id) else: layer_pbar.set_description("Importing MLA (with q LoRA)") - self.rules["linear_q_down_proj"](attention.linear_q_down_proj, layer_id, is_mtp=is_mtp) - self.rules["linear_q_layernorm"](attention.q_layernorm, layer_id, is_mtp=is_mtp) - self.rules["linear_q_up_proj"](attention.linear_q_up_proj, layer_id, is_mtp=is_mtp) - self.rules["linear_kv_down_proj"](attention.linear_kv_down_proj, layer_id, is_mtp=is_mtp) - self.rules["linear_kv_layernorm"](attention.kv_layernorm, layer_id, is_mtp=is_mtp) - self.rules["linear_kv_up_proj"](attention.linear_kv_up_proj, layer_id, is_mtp=is_mtp) - self.rules["linear_proj"](attention.linear_proj, layer_id, is_mtp=is_mtp) + self.rules["linear_q_down_proj"](attention.linear_q_down_proj, layer_id) + self.rules["linear_q_layernorm"](attention.q_layernorm, layer_id) + self.rules["linear_q_up_proj"](attention.linear_q_up_proj, layer_id) + self.rules["linear_kv_down_proj"](attention.linear_kv_down_proj, layer_id) + self.rules["linear_kv_layernorm"](attention.kv_layernorm, layer_id) + self.rules["linear_kv_up_proj"](attention.linear_kv_up_proj, layer_id) + self.rules["linear_proj"](attention.linear_proj, layer_id) else: layer_pbar.set_description("Importing GQA/MHA") if attention.q_layernorm is not None and not isinstance( attention.q_layernorm, (IdentityOp, L2Norm) ): - self.rules["q_layernorm"](attention.q_layernorm, layer_id, is_mtp=is_mtp) - self.rules["k_layernorm"](attention.k_layernorm, layer_id, is_mtp=is_mtp) + self.rules["q_layernorm"](attention.q_layernorm, layer_id) + self.rules["k_layernorm"](attention.k_layernorm, layer_id) self.rules["linear_qkv"](attention.linear_qkv, layer_id, is_mtp=is_mtp) self.rules["linear_proj"](attention.linear_proj, layer_id, is_mtp=is_mtp) if getattr(attention.core_attention, "softmax_offset", None) is not None: self.rules["softmax_offset"]( - attention.core_attention.softmax_offset, layer_id, is_mtp=is_mtp + attention.core_attention.softmax_offset, layer_id ) if not isinstance(layer.pre_mlp_layernorm, IdentityOp): - 
self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id, is_mtp=is_mtp) + self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id) if not isinstance(layer.mlp, IdentityOp): if "MoE" in str(type(layer.mlp)): layer_pbar.set_description("Importing MoE") self.rules["router"]( - layer.mlp.router, layer_id, dtype=self.moe_router_dtype, is_mtp=is_mtp + layer.mlp.router, layer_id, dtype=self.moe_router_dtype ) if ( hasattr(layer.mlp, "shared_experts") @@ -544,20 +582,41 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp=False): layer_pbar.set_description("Importing MoE shared experts") fc1 = layer.mlp.shared_experts.linear_fc1 fc2 = layer.mlp.shared_experts.linear_fc2 - self.rules["shared_experts.linear_fc1"](fc1, layer_id, is_mtp=is_mtp) - self.rules["shared_experts.linear_fc2"](fc2, layer_id, is_mtp=is_mtp) - if not self.rules.get("use_packed_local_experts", False): - for local_expert_id, expert in tqdm( - enumerate(layer.mlp.experts.local_experts), - desc="Importing MoE local experts", - leave=False, - disable=self.disable_tqdm, - ): - expert_id = layer.mlp.local_expert_indices[local_expert_id] - fc1 = expert.linear_fc1 - fc2 = expert.linear_fc2 - self.rules["local_experts.linear_fc1"](fc1, layer_id, expert_id, is_mtp=is_mtp) - self.rules["local_experts.linear_fc2"](fc2, layer_id, expert_id, is_mtp=is_mtp) + self.rules["shared_experts.linear_fc1"](fc1, layer_id) + self.rules["shared_experts.linear_fc2"](fc2, layer_id) + if not self.rules.get("use_packed_local_experts", False): # Import local experts + experts = layer.mlp.experts + if hasattr(experts, "local_experts"): + for local_expert_id, expert in tqdm( + enumerate(layer.mlp.experts.local_experts), + desc="Importing MoE local experts", + leave=False, + disable=self.disable_tqdm, + ): + expert_id = layer.mlp.local_expert_indices[local_expert_id] + fc1 = expert.linear_fc1 + fc2 = expert.linear_fc2 + self.rules["local_experts.linear_fc1"](fc1, layer_id, expert_id) + self.rules["local_experts.linear_fc2"](fc2, layer_id, expert_id) + else: # Slice TEGroupedMLP + layer_pbar.set_description("Importing MoE grouped local experts") + num_local_experts = experts.num_local_experts + num_global_experts = experts.config.num_moe_experts + print(f"num_local_experts: {num_local_experts}") + print(f"num_global_experts: {num_global_experts}") + + if parallel_config is not None: + etp_size = get_expert_tensor_parallel_world_size() + # etp_rank = get_expert_tensor_parallel_rank() # this gives group rank + etp_rank = dist.get_rank() + print(f"etp_size: {etp_size}") + print(f"etp_rank: {etp_rank}") + assert num_local_experts * etp_size == num_global_experts + init_index = etp_rank * num_local_experts + + self.rules["experts.linear_fc1"](experts.linear_fc1, layer_id, init_expert_id=init_index, num_local_experts=num_local_experts) + self.rules["experts.linear_fc2"](experts.linear_fc2, layer_id, init_expert_id=init_index, num_local_experts=num_local_experts ) + # We only support either EP or ETP for now elif get_expert_tensor_parallel_world_size() > 1: # ETP supports for packed MoE @@ -565,28 +624,28 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp=False): if self.arch in ["GptOssForCausalLM"]: raise ValueError("ETP is not supported for gpt-oss model") self.rules["local_experts.linear_fc1_etp"]( - layer.mlp.experts.local_experts, layer_id, is_mtp=is_mtp + layer.mlp.experts.local_experts, layer_id ) self.rules["local_experts.linear_fc2_etp"]( - layer.mlp.experts.local_experts, layer_id, 
is_mtp=is_mtp + layer.mlp.experts.local_experts, layer_id ) else: # EP supports for packed MoE self.rules["local_experts.linear_fc1_ep"]( - layer.mlp.experts.local_experts, layer_id, is_mtp=is_mtp + layer.mlp.experts.local_experts, layer_id ) self.rules["local_experts.linear_fc2_ep"]( - layer.mlp.experts.local_experts, layer_id, is_mtp=is_mtp + layer.mlp.experts.local_experts, layer_id ) else: layer_pbar.set_description("Importing MLP") - self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id, is_mtp=is_mtp) - self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id, is_mtp=is_mtp) + self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id) + self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id) def _import_state_dict(self): model = self.model - print(model, flush=True) + # print(model, flush=True) layer_pbar = tqdm(model.decoder.layers, disable=self.disable_tqdm) # Embedding @@ -607,7 +666,7 @@ def _import_state_dict(self): if self.verbose: print( "{:3}/{:3} completes importing layer {:3}.".format( - torch.distributed.get_rank(), torch.distributed.get_world_size(), layer_id + dist.get_rank(), dist.get_world_size(), layer_id ), flush=True, ) diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index e55183aed..a2d70a20f 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -726,7 +726,7 @@ def _qkv_slicing( quantized_weight = to_quantized_weight( weight, scale, - qformat, + qformat, weight_scale_2, block_size, ) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 233469277..5360e9415 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -187,7 +187,6 @@ def sync_quantizer_amax_across_tp( parallel_state=module.parallel_state, ) - # KV Cache Quantization if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"): # We only support KVCache quantization with scalar per-tensor states for now (NVFP4 & FP8 KV cache) From d6e62e840ba277a67a427fce736132c83bed9b16 Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Mon, 26 Jan 2026 13:17:54 -0800 Subject: [PATCH 08/20] fix grouped mlp import Signed-off-by: jenchen13 --- .../torch/export/plugins/mcore_nemotron.py | 4 ++-- .../torch/export/plugins/megatron_importer.py | 24 ++++++++----------- .../torch/export/unified_export_megatron.py | 6 ++--- 3 files changed, 15 insertions(+), 19 deletions(-) diff --git a/modelopt/torch/export/plugins/mcore_nemotron.py b/modelopt/torch/export/plugins/mcore_nemotron.py index 385418a37..12408480b 100644 --- a/modelopt/torch/export/plugins/mcore_nemotron.py +++ b/modelopt/torch/export/plugins/mcore_nemotron.py @@ -92,8 +92,8 @@ "mtp.eh_proj": NameRemapping("mtp.layers.{}.eh_proj.", {"is_mtp": True}), "mtp.final_layernorm": NameRemapping("mtp.layers.{}.final_layernorm.", {"is_mtp": True}), # Grouped local experts in MTP - "experts.linear_fc1": GroupedMLPMerging("mtp.layers.{}.experts.{{}}.up_proj", COL_ETP | {"is_mtp": True}), - "experts.linear_fc2": GroupedMLPMerging("mtp.layers.{}.experts.{{}}.down_proj", ROW_ETP | {"is_mtp": True}), + "experts.linear_fc1": GroupedMLPMerging("mtp.layers.{}.mixer.experts.{{}}.up_proj", COL_ETP | {"is_mtp": True}), + "experts.linear_fc2": GroupedMLPMerging("mtp.layers.{}.mixer.experts.{{}}.down_proj", ROW_ETP | {"is_mtp": True}), } diff --git a/modelopt/torch/export/plugins/megatron_importer.py 
b/modelopt/torch/export/plugins/megatron_importer.py index 79ce62573..be14934cd 100644 --- a/modelopt/torch/export/plugins/megatron_importer.py +++ b/modelopt/torch/export/plugins/megatron_importer.py @@ -277,19 +277,14 @@ def _grouped_mlp_merging( else: prefix = prefix.replace("model", "mtp") - state_dict = module.state_dict() - weight = state_dict.get("weight", None) - print(f"mcore weight.shape: {weight.shape}") - weight_scale = state_dict.get("weight_quantizer._scale", None) + state_dict = module.state_dict() + # TODO handle weight_scale + #weight_scale = state_dict.get("weight_quantizer._scale", None) - all_experts = [] + assert module.num_gemms == num_local_experts, "num_gemms must be equal to num_local_experts in TEGroupedMLP" for expert_id in range(init_expert_id, init_expert_id + num_local_experts): tensor = self._get_safetensor(prefix.format(expert_id) + ".weight") - print(f"HF weight.shape: {tensor.shape}") - all_experts.append(tensor) - all_experts = torch.cat(all_experts, dim=0) - print(f"all_experts.shape: {all_experts.shape}") - state_dict["weight"] = all_experts + state_dict[f"weight{expert_id}"] = tensor module.load_state_dict(state_dict) @@ -602,9 +597,9 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool = layer_pbar.set_description("Importing MoE grouped local experts") num_local_experts = experts.num_local_experts num_global_experts = experts.config.num_moe_experts - print(f"num_local_experts: {num_local_experts}") - print(f"num_global_experts: {num_global_experts}") + assert num_local_experts == num_global_experts, "num_local_experts must be equal to num_global_experts during MoE import" + ''' if parallel_config is not None: etp_size = get_expert_tensor_parallel_world_size() # etp_rank = get_expert_tensor_parallel_rank() # this gives group rank @@ -613,9 +608,11 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool = print(f"etp_rank: {etp_rank}") assert num_local_experts * etp_size == num_global_experts init_index = etp_rank * num_local_experts + ''' + init_index = 0 self.rules["experts.linear_fc1"](experts.linear_fc1, layer_id, init_expert_id=init_index, num_local_experts=num_local_experts) - self.rules["experts.linear_fc2"](experts.linear_fc2, layer_id, init_expert_id=init_index, num_local_experts=num_local_experts ) + self.rules["experts.linear_fc2"](experts.linear_fc2, layer_id, init_expert_id=init_index, num_local_experts=num_local_experts) # We only support either EP or ETP for now elif get_expert_tensor_parallel_world_size() > 1: @@ -670,7 +667,6 @@ def _import_state_dict(self): ), flush=True, ) - break # TODO: remove this # Final layernorm if hasattr(model.decoder, "final_layernorm") and model.decoder.final_layernorm: diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index a2d70a20f..3410ddb5d 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -487,16 +487,16 @@ def _get_quantized_state( qformat: str = self._get_quantization_format(module) block_size = get_weight_block_size(module) - if hasattr(module, "weight") and module.weight is not None: + if hasattr(module, "weight") and module.weight is not None and module.weight.numel() > 0: weight = module.weight.to(dtype).cpu() name_to_value["weight"] = weight else: return name_to_value, qformat, block_size - if hasattr(module, "bias") and module.bias is not None: + if hasattr(module, "bias") and module.bias is not None and 
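
# A minimal standalone sketch of the TEGroupedMLP import path introduced in
# _grouped_mlp_merging above. The helper name load_hf_expert_weight is
# illustrative, not part of the patch: it stands in for reading one expert's
# up_proj/down_proj tensor from the HF safetensors.
import torch

def pack_grouped_expert_weights(
    load_hf_expert_weight, init_expert_id: int, num_local_experts: int
) -> dict[str, torch.Tensor]:
    """Store one HF expert tensor per GEMM under numbered weight{i} keys."""
    state_dict: dict[str, torch.Tensor] = {}
    for expert_id in range(init_expert_id, init_expert_id + num_local_experts):
        # Grouped MLP modules keep a separate weight per expert instead of one
        # concatenated tensor, so each tensor is loaded under its own key.
        state_dict[f"weight{expert_id}"] = load_hf_expert_weight(expert_id)
    return state_dict
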
module.bias.numel() > 0: name_to_value["bias"] = module.bias.to(dtype).cpu() - if hasattr(module, "expert_bias") and module.expert_bias is not None: + if hasattr(module, "expert_bias") and module.expert_bias is not None and module.expert_bias.numel() > 0: name_to_value["expert_bias"] = module.expert_bias.to(dtype).cpu() if qformat == QUANTIZATION_NONE: From e6e10da50d4edca21680b2ca262561578d9bfe10 Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Mon, 26 Jan 2026 19:16:43 -0800 Subject: [PATCH 09/20] fix config.json Signed-off-by: jenchen13 --- .../torch/export/plugins/megatron_importer.py | 6 ++++ .../torch/export/unified_export_megatron.py | 33 ++++++++++++++++++- 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/export/plugins/megatron_importer.py b/modelopt/torch/export/plugins/megatron_importer.py index be14934cd..ea47869ee 100644 --- a/modelopt/torch/export/plugins/megatron_importer.py +++ b/modelopt/torch/export/plugins/megatron_importer.py @@ -567,9 +567,15 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool = if not isinstance(layer.mlp, IdentityOp): if "MoE" in str(type(layer.mlp)): layer_pbar.set_description("Importing MoE") + print(f"moe_router_dtype: {self.moe_router_dtype}") self.rules["router"]( layer.mlp.router, layer_id, dtype=self.moe_router_dtype ) + if hasattr(layer.mlp, "fc1_latent_proj") and layer.mlp.fc1_latent_proj is not None: + self.rules["fc1_latent_proj"](layer.mlp.fc1_latent_proj, layer_id) + if hasattr(layer.mlp, "fc2_latent_proj") and layer.mlp.fc2_latent_proj is not None: + self.rules["fc2_latent_proj"](layer.mlp.fc2_latent_proj, layer_id) + if ( hasattr(layer.mlp, "shared_experts") and layer.mlp.shared_experts is not None diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index 3410ddb5d..4f4844f6a 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -47,7 +47,7 @@ QUANTIZATION_NVFP4, ) from .plugins.mcore_common import all_mcore_hf_export_mapping -from .plugins.mcore_custom import CustomModuleMapping, save_safetensors +from .plugins.mcore_custom import CustomModuleMapping, save_safetensors, get_safetensor from .plugins.megatron_importer import GPTModelImporter from .quant_utils import ( get_activation_scaling_factor, @@ -129,6 +129,7 @@ def __init__( self.moe_router_dtype = torch.float32 elif moe_router_dtype == "fp64": self.moe_router_dtype = torch.float64 + print(f"moe_router_dtype: {self.moe_router_dtype}") # If multimodal, extra the text_config self._hf_text_config = getattr(self._hf_config, "text_config", self._hf_config) @@ -309,6 +310,15 @@ def save_pretrained( with open(save_directory + "/hf_quant_config.json", "w") as f: json.dump(hf_quant_config, f, indent=4) + # Newer versions of VLLM expect config.json with hf_quant_config + config_file = save_directory + "/config.json" + if os.path.exists(config_file): + with open(config_file, "r") as f: + config = json.load(f) + config["quantization"] = hf_quant_config["quantization"] + with open(config_file, "w") as f: + json.dump(config, f, indent=4) + if ( is_first_stage_main_rank and self.is_multimodal @@ -1185,6 +1195,27 @@ def _get_state_dict(self): self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id) else: raise ValueError("Only TransformerLayer or MambaLayer are supported.") + + # MTP module + # Hacky version for now: copy MTP weights from pretrained model + if os.path.isdir(self._hf_pretrained_model_name): + 
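
# A minimal sketch of the config.json update added above, with an illustrative
# function name; paths follow the patch. Newer vLLM versions read the
# quantization settings from config.json, so the "quantization" block of
# hf_quant_config.json is merged into it after the checkpoint is written.
# (A later patch in this series converts this block via
# convert_hf_quant_config_format and stores it under "quantization_config".)
import json
from pathlib import Path

def merge_quantization_into_config(save_directory: str, hf_quant_config: dict) -> None:
    config_path = Path(save_directory) / "config.json"
    if not config_path.exists():
        return
    config = json.loads(config_path.read_text())
    config["quantization"] = hf_quant_config["quantization"]
    config_path.write_text(json.dumps(config, indent=4))
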
safetensors_index_file = Path(self._hf_pretrained_model_name) / "model.safetensors.index.json" + else: + safetensors_index_file = hf_hub_download( + repo_id=self._hf_pretrained_model_name, + filename="model.safetensors.index.json") + + print(f"safetensors_index_file: {safetensors_index_file}") + if safetensors_index_file and os.path.exists(safetensors_index_file): + with open(safetensors_index_file, "r") as f: + safetensors_index = json.load(f) + model_dir = Path(safetensors_index_file).parent + for key in safetensors_index["weight_map"]: + if "mtp" in key: + self._state_dict[key] = get_safetensor(model_dir, key) + + # TODO implement actual MTP export + def export_mcore_gpt_to_hf( From 7d49569bfc5da3dff057d2107d0dcaab96e9cc9c Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Tue, 27 Jan 2026 07:43:39 -0800 Subject: [PATCH 10/20] fix import router dtype bug Signed-off-by: jenchen13 --- modelopt/torch/export/plugins/megatron_importer.py | 5 ++--- modelopt/torch/export/unified_export_megatron.py | 1 - 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/modelopt/torch/export/plugins/megatron_importer.py b/modelopt/torch/export/plugins/megatron_importer.py index ea47869ee..81d9a5b39 100644 --- a/modelopt/torch/export/plugins/megatron_importer.py +++ b/modelopt/torch/export/plugins/megatron_importer.py @@ -191,7 +191,7 @@ def _name_remapping( tensor = expanded_tensor state_dict["weight"] = tensor.view(dtype=weight.dtype).to(device=weight.device) else: - state_dict["weight"] = tensor.to(dtype=self.dtype).to(device=weight.device) + state_dict["weight"] = tensor.to(dtype=dtype).to(device=weight.device) # Handle the rest of the state_dict. for key, val in module.state_dict().items(): @@ -566,8 +566,7 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool = if not isinstance(layer.mlp, IdentityOp): if "MoE" in str(type(layer.mlp)): - layer_pbar.set_description("Importing MoE") - print(f"moe_router_dtype: {self.moe_router_dtype}") + layer_pbar.set_description(f"Importing MoE with moe_router_dtype: {self.moe_router_dtype}") self.rules["router"]( layer.mlp.router, layer_id, dtype=self.moe_router_dtype ) diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index 4f4844f6a..279a9b698 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -129,7 +129,6 @@ def __init__( self.moe_router_dtype = torch.float32 elif moe_router_dtype == "fp64": self.moe_router_dtype = torch.float64 - print(f"moe_router_dtype: {self.moe_router_dtype}") # If multimodal, extra the text_config self._hf_text_config = getattr(self._hf_config, "text_config", self._hf_config) From 53f09b8929d81e704fa7e4c3bd41200da7ea2016 Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Tue, 27 Jan 2026 12:46:19 -0800 Subject: [PATCH 11/20] make export exclude modules dynamic Signed-off-by: jenchen13 --- .../torch/export/unified_export_megatron.py | 568 +++++++++--------- 1 file changed, 296 insertions(+), 272 deletions(-) diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index 279a9b698..95f138cd5 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -129,6 +129,7 @@ def __init__( self.moe_router_dtype = torch.float32 elif moe_router_dtype == "fp64": self.moe_router_dtype = torch.float64 + print(f"Exporting model with moe_router_dtype: {self.moe_router_dtype}") # If 
multimodal, extra the text_config self._hf_text_config = getattr(self._hf_config, "text_config", self._hf_config) @@ -156,6 +157,7 @@ def __init__( del self._hf_config.quantization_config self.all_rules = self._populate_rule_book() self.rules = self.all_rules[self.arch] + self.exclude_modules = [] if not hasattr(model, "_modelopt_state"): return @@ -292,32 +294,25 @@ def save_pretrained( if is_last_stage_main_rank and quantization is not None: # TODO refactor to use mte.quant_utils.get_quant_config # except layer names are different in MCore and HF - hf_quant_config = { + self.exclude_modules.append("lm_head") + self._hf_quant_config = { "producer": { "name": "modelopt", "version": __version__, }, "quantization": { "quant_algo": quantization, - "exclude_modules": ["lm_head"], # TODO update this dynamically + "exclude_modules": self.exclude_modules, }, } if quantization == "NVFP4": # update block size - hf_quant_config["quantization"]["group_size"] = 16 + self._hf_quant_config["quantization"]["group_size"] = 16 if hasattr(self, "kv_cache_dtype"): - hf_quant_config["quantization"]["kv_cache_quant_algo"] = self.kv_cache_dtype + self._hf_quant_config["quantization"]["kv_cache_quant_algo"] = self.kv_cache_dtype with open(save_directory + "/hf_quant_config.json", "w") as f: - json.dump(hf_quant_config, f, indent=4) - - # Newer versions of VLLM expect config.json with hf_quant_config - config_file = save_directory + "/config.json" - if os.path.exists(config_file): - with open(config_file, "r") as f: - config = json.load(f) - config["quantization"] = hf_quant_config["quantization"] - with open(config_file, "w") as f: - json.dump(config, f, indent=4) + json.dump(self._hf_quant_config, f, indent=4) + if ( is_first_stage_main_rank and self.is_multimodal @@ -434,6 +429,15 @@ def save_pretrained( if modeling_file and os.path.exists(modeling_file): shutil.copy(modeling_file, f"{save_directory}/modeling_{model_type}.py") + # Newer versions of VLLM expect config.json with hf_quant_config + config_json_file = save_directory + "/config.json" + if self._hf_quant_config and os.path.exists(config_json_file): + with open(config_json_file, "r") as f: + config_dict = json.load(f) + config_dict["quantization"] = self._hf_quant_config["quantization"] + with open(config_json_file, "w") as f: + json.dump(config_dict, f, indent=4) + save_safetensors(state_dict, save_directory) @property @@ -450,6 +454,268 @@ def extra_state_dict(self): self._get_eagle_module_state_dict() return self._state_dict + def _get_state_dict(self): + model = self.model + + # Embedding + if hasattr(model, "embedding"): + self.rules["word_embeddings"](model.embedding.word_embeddings) + + # Final layernorm + if hasattr(model.decoder, "final_layernorm") and model.decoder.final_layernorm: + self.rules["final_layernorm"](model.decoder.final_layernorm) + + if hasattr(model.decoder, "final_norm") and model.decoder.final_norm: + self.rules["final_norm"](model.decoder.final_norm) + + # Output layer + if hasattr(model, "output_layer") and not model.share_embeddings_and_output_weights: + self.rules["output_layer"](model.output_layer) + + # Decoder layers + for layer in model.decoder.layers: + layer_id = layer.layer_number - 1 + if isinstance(layer, MambaLayer): + self._get_mamba_layer_state_dict(layer, layer_id) + elif isinstance(layer, TransformerLayer): + self._get_transformer_layer_state_dict(layer, layer_id) + else: + raise ValueError("Only TransformerLayer or MambaLayer are supported.") + + # MTP layer + self._get_mtp_state_dict() + + def 
_get_transformer_layer_state_dict(self, layer, layer_id): + if not isinstance(layer.input_layernorm, IdentityOp): + self.rules["input_layernorm"](layer.input_layernorm, layer_id) + + if not isinstance(layer.self_attention, IdentityOp): + if "MLASelfAttention" in str(type(layer.self_attention)): + if hasattr(layer.self_attention, "linear_q_proj"): + self.rules["linear_q_proj"]( + layer.self_attention.linear_q_proj, layer_id + ) + else: + self.rules["linear_q_down_proj"]( + layer.self_attention.linear_q_down_proj, layer_id + ) + self.rules["linear_q_layernorm"]( + layer.self_attention.q_layernorm, layer_id + ) + self.rules["linear_q_up_proj"]( + layer.self_attention.linear_q_up_proj, layer_id + ) + + self.rules["linear_kv_down_proj"]( + layer.self_attention.linear_kv_down_proj, layer_id + ) + self.rules["linear_kv_layernorm"]( + layer.self_attention.kv_layernorm, layer_id + ) + self.rules["linear_kv_up_proj"]( + layer.self_attention.linear_kv_up_proj, layer_id + ) + self.rules["linear_proj"](layer.self_attention.linear_proj, layer_id) + else: + if layer.self_attention.q_layernorm is not None and not isinstance( + layer.self_attention.q_layernorm, (IdentityOp, L2Norm) + ): + self.rules["q_layernorm"](layer.self_attention.q_layernorm, layer_id) + self.rules["k_layernorm"](layer.self_attention.k_layernorm, layer_id) + self.rules["linear_qkv"](layer.self_attention.linear_qkv, layer_id) + if hasattr(layer.self_attention, "core_attention"): + self.rules["core_attention"](layer.self_attention.core_attention, layer_id) + self.rules["linear_proj"](layer.self_attention.linear_proj, layer_id) + if ( + getattr(layer.self_attention.core_attention, "softmax_offset", None) + is not None + ): + self.rules["softmax_offset"]( + layer.self_attention.core_attention.softmax_offset, layer_id + ) + + if not isinstance(layer.pre_mlp_layernorm, IdentityOp): + self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id) + + if not isinstance(layer.mlp, IdentityOp): + if "MoE" in str(type(layer.mlp)): + self.rules["router"]( + layer.mlp.router, layer_id, dtype=self.moe_router_dtype + ) + if hasattr(layer.mlp, "fc1_latent_proj") and layer.mlp.fc1_latent_proj is not None: + self.rules["fc1_latent_proj"](layer.mlp.fc1_latent_proj, layer_id) + if hasattr(layer.mlp, "fc2_latent_proj") and layer.mlp.fc2_latent_proj is not None: + self.rules["fc2_latent_proj"](layer.mlp.fc2_latent_proj, layer_id) + if ( + hasattr(layer.mlp, "shared_experts") + and layer.mlp.shared_experts is not None + ): + self.rules["shared_experts.linear_fc1"]( + layer.mlp.shared_experts.linear_fc1, layer_id + ) + self.rules["shared_experts.linear_fc2"]( + layer.mlp.shared_experts.linear_fc2, layer_id + ) + if not self.rules.get("use_packed_local_experts", False): + for expert_id, expert in enumerate(layer.mlp.experts.local_experts): + self.rules["local_experts.linear_fc1"]( + expert.linear_fc1, layer_id, expert_id + ) + self.rules["local_experts.linear_fc2"]( + expert.linear_fc2, layer_id, expert_id + ) + else: + # For llama 4, in hf unified checkpoint, all local experts share one scale + self.rules["local_experts.linear_fc1"]( + layer.mlp.experts.local_experts, layer_id + ) + self.rules["local_experts.linear_fc2"]( + layer.mlp.experts.local_experts, layer_id + ) + else: + self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id) + self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id) + + def _get_mtp_state_dict(self): + # TODO implement actual MTP export + # Hacky version for now: copy MTP weights from pretrained model + if 
os.path.isdir(self._hf_pretrained_model_name): + safetensors_index_file = Path(self._hf_pretrained_model_name) / "model.safetensors.index.json" + else: + safetensors_index_file = hf_hub_download( + repo_id=self._hf_pretrained_model_name, + filename="model.safetensors.index.json") + + print(f"Exporting MTP: using safetensors_index_file: {safetensors_index_file}") + if safetensors_index_file and os.path.exists(safetensors_index_file): + with open(safetensors_index_file, "r") as f: + safetensors_index = json.load(f) + model_dir = Path(safetensors_index_file).parent + for key in safetensors_index["weight_map"]: + if key.startswith("mtp.") and key not in self._state_dict: + self._state_dict[key] = get_safetensor(model_dir, key) + + + def _get_mamba_layer_state_dict(self, layer, layer_id): + if not isinstance(layer.norm, IdentityOp): + self.rules["norm"](layer.norm, layer_id) + + self.rules["mixer_norm"](layer.mixer.norm, layer_id) + self.rules["A_log"](layer.mixer.A_log, layer_id) + self.rules["D"](layer.mixer.D, layer_id) + self.rules["dt_bias"](layer.mixer.dt_bias, layer_id) + + self.rules["conv1d"](layer.mixer.conv1d, layer_id) + self.rules["in_proj"](layer.mixer.in_proj, layer_id) + self.rules["out_proj"](layer.mixer.out_proj, layer_id) + + def _get_medusa_heads_state_dict(self): + medusa_heads = getattr(self.model, "medusa_heads", None) + if medusa_heads is None: + return + + for head_id, head in enumerate(medusa_heads): + self.rules["medusa_heads.lm_head"](head.lm_head, head_id) + for layer_id, layer in enumerate(head.medusa_layers): + self.rules["medusa_heads.medusa_layers.linear"](layer.linear, head_id, layer_id) + + def _get_eagle_module_state_dict(self): + eagle_module = getattr(self.model, "eagle_module", None) + + if eagle_module is None: + return + + # if hasattr(self.model, "embedding"): + # self.rules["word_embeddings"](self.model.embedding.word_embeddings) + + self.rules["fc"](eagle_module.fc) + if self.model.eagle_config.use_aux_hidden_state: + self.rules["enorm"](eagle_module.enorm) + elif self.model.eagle_config.use_mtp_layernorm: + self.rules["enorm"](eagle_module.enorm) + self.rules["hnorm"](eagle_module.hnorm) + + if self.model.eagle_config.use_last_layernorm: + self.rules["final_layernorm"](eagle_module.decoder.final_layernorm) + + if hasattr(self.model.eagle_module, "eagle_output_layer"): + self.rules["output_layer"](eagle_module.eagle_output_layer) + if hasattr(self.model.eagle_module, "dt2"): + self.rules["d2t"](eagle_module.d2t) + + for layer in eagle_module.decoder.layers: + layer_id = layer.layer_number - 1 + + # The first layernorm needs special handling here. We have a dedicated mapping + # for the first layernorm since in EAGLE3 it will be mapped to hidden_norm + # instead of input_layernorm (due to the specialized transformer layer). + # The remaining EAGLE3 layers (if more than 1) are normal transformer layers + # where input_layernorm is mapped to input_layernorm. 
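
# A minimal sketch of the _get_mtp_state_dict copy defined earlier in this
# hunk. copy_tensor_fn is an illustrative stand-in for get_safetensor: the
# pretrained checkpoint's weight map is scanned and every "mtp." entry is
# carried over as-is (BF16, unquantized) into the exported state dict.
import json
from pathlib import Path

def copy_mtp_weights(checkpoint_dir: str, state_dict: dict, copy_tensor_fn) -> None:
    index_file = Path(checkpoint_dir) / "model.safetensors.index.json"
    if not index_file.exists():
        return
    weight_map = json.loads(index_file.read_text())["weight_map"]
    for key in weight_map:
        if key.startswith("mtp.") and key not in state_dict:
            state_dict[key] = copy_tensor_fn(index_file.parent, key)
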
+ if layer_id == 0 and self.model.eagle_config.use_input_layernorm_in_first_layer: + self.rules["first_input_layernorm"](layer.input_layernorm, layer_id) + elif layer_id > 0: + self.rules["input_layernorm"](layer.input_layernorm, layer_id) + + if "MLASelfAttention" in str(type(layer.self_attention)): + if hasattr(layer.self_attention, "linear_q_proj"): + self.rules["eagle_module.linear_q_proj"]( + layer.self_attention.linear_q_proj, layer_id + ) + else: + self.rules["eagle_module.linear_q_down_proj"]( + layer.self_attention.linear_q_down_proj, layer_id + ) + self.rules["eagle_module.linear_q_layernorm"]( + layer.self_attention.q_layernorm, layer_id + ) + self.rules["eagle_module.linear_q_up_proj"]( + layer.self_attention.linear_q_up_proj, layer_id + ) + + self.rules["eagle_module.linear_kv_down_proj"]( + layer.self_attention.linear_kv_down_proj, layer_id + ) + self.rules["eagle_module.linear_kv_layernorm"]( + layer.self_attention.kv_layernorm, layer_id + ) + self.rules["eagle_module.linear_kv_up_proj"]( + layer.self_attention.linear_kv_up_proj, layer_id + ) + else: + self.rules["linear_qkv"](layer.self_attention.linear_qkv, layer_id) + + self.rules["linear_proj"](layer.self_attention.linear_proj, layer_id) + self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id) + + if "MoE" in str(type(layer.mlp)): + self.rules["eagle_module.router"](layer.mlp.router, layer_id) + if hasattr(layer.mlp, "shared_experts") and layer.mlp.shared_experts is not None: + self.rules["eagle_module.shared_experts.linear_fc1"]( + layer.mlp.shared_experts.linear_fc1, layer_id + ) + self.rules["eagle_module.shared_experts.linear_fc2"]( + layer.mlp.shared_experts.linear_fc2, layer_id + ) + for expert_id, expert in enumerate(layer.mlp.experts.local_experts): + self.rules["eagle_module.local_experts.linear_fc1"]( + expert.linear_fc1, layer_id, expert_id + ) + self.rules["eagle_module.local_experts.linear_fc2"]( + expert.linear_fc2, layer_id, expert_id + ) + else: + self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id) + self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id) + + parallel_draft_heads = getattr(eagle_module, "parallel_draft_heads", None) + if parallel_draft_heads is not None: + for head_id, head in enumerate(parallel_draft_heads.medusa_heads): + for layer_id, layer in enumerate(head): + self.rules["parallel_draft_heads.medusa_layers"]( + layer.linear, head_id, layer_id + ) + self.rules["parallel_draft_heads.lm_head"](parallel_draft_heads.lm_head) + def _populate_rule_book(self): all_rules = {} @@ -482,18 +748,22 @@ def _get_quantized_state( self, module: torch.nn.Module, dtype: torch.dtype = torch.float16, + prefix: str = "", ) -> tuple[dict[str, torch.Tensor], str, int]: """Return a state_dict, quantization format, and block_size of the module. Args: module: The target module to perform real quantization. dtype: The default data type. + prefix: The prefix of the layer. Returns: Tuple: state_dict, quantization format, and block_size of the module. 
""" name_to_value = {} qformat: str = self._get_quantization_format(module) + if qformat is None: + self.exclude_modules.append(prefix) block_size = get_weight_block_size(module) if hasattr(module, "weight") and module.weight is not None and module.weight.numel() > 0: @@ -561,12 +831,16 @@ def _name_remapping( self._state_dict[prefix] = module return - name_to_value, qformat, block_size = self._get_quantized_state(module, dtype) + name_to_value, qformat, block_size = self._get_quantized_state(module, dtype, prefix=prefix) weight = name_to_value.pop("weight") + if "mixer.gate" in prefix: + print(f"{prefix}: weight dtype: {weight.dtype}") weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat) if weight_scale is None: + if "mixer.gate" in prefix: + print(f"{prefix}: weight_scale is None") self._state_dict[prefix + "weight"] = weight else: self._state_dict[prefix + "weight"] = to_quantized_weight( @@ -589,11 +863,14 @@ def _name_remapping( else: source_key = mapping.get(key, key) self._state_dict[prefix + source_key] = val + print(f"{prefix + source_key}: {self._state_dict[prefix + source_key].dtype}") + if "mixer.gate" in prefix: + print(f"{prefix}weight: {self._state_dict[prefix + 'weight'].dtype}") def _gated_mlp_slicing( self, module, prefix, gate_proj_name="gate_proj", up_proj_name="up_proj" ): - name_to_value, qformat, block_size = self._get_quantized_state(module, self.dtype) + name_to_value, qformat, block_size = self._get_quantized_state(module, self.dtype, prefix=prefix) weight = name_to_value.pop("weight") weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat) @@ -656,7 +933,7 @@ def _qkv_slicing( k_proj_name="k_proj", v_proj_name="v_proj", ): - name_to_value, qformat, block_size = self._get_quantized_state(module, self.dtype) + name_to_value, qformat, block_size = self._get_quantized_state(module, self.dtype, prefix=prefix) q_proj_prefix = prefix + q_proj_name + "." k_proj_prefix = prefix + k_proj_name + "." @@ -791,7 +1068,7 @@ def _pack_name_remapping(self, module, prefix, layer_type=None): for expert in module: assert layer_type is not None, "layer_type is required for pack_name_remapping" name_to_value, qformat, block_size = self._get_quantized_state( - getattr(expert, layer_type), self.dtype + getattr(expert, layer_type), self.dtype, prefix=prefix ) weight = name_to_value.pop("weight") weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat) @@ -857,7 +1134,7 @@ def _pack_name_remapping_gpt_oss(self, module, prefix, layer_type=None): for expert in module: assert layer_type is not None, "layer_type is required for pack_name_remapping" name_to_value, qformat, block_size = self._get_quantized_state( - getattr(expert, layer_type), self.dtype + getattr(expert, layer_type), self.dtype, prefix=prefix ) weight = name_to_value.pop("weight") bias = name_to_value.pop("bias", None) @@ -960,260 +1237,7 @@ def _pack_name_remapping_gpt_oss(self, module, prefix, layer_type=None): # TODO: May need to modify the key name later. 
self._state_dict[prefix + "_bias"] = merged_bias - def _get_medusa_heads_state_dict(self): - medusa_heads = getattr(self.model, "medusa_heads", None) - if medusa_heads is None: - return - - for head_id, head in enumerate(medusa_heads): - self.rules["medusa_heads.lm_head"](head.lm_head, head_id) - for layer_id, layer in enumerate(head.medusa_layers): - self.rules["medusa_heads.medusa_layers.linear"](layer.linear, head_id, layer_id) - - def _get_eagle_module_state_dict(self): - eagle_module = getattr(self.model, "eagle_module", None) - - if eagle_module is None: - return - - # if hasattr(self.model, "embedding"): - # self.rules["word_embeddings"](self.model.embedding.word_embeddings) - - self.rules["fc"](eagle_module.fc) - if self.model.eagle_config.use_aux_hidden_state: - self.rules["enorm"](eagle_module.enorm) - elif self.model.eagle_config.use_mtp_layernorm: - self.rules["enorm"](eagle_module.enorm) - self.rules["hnorm"](eagle_module.hnorm) - - if self.model.eagle_config.use_last_layernorm: - self.rules["final_layernorm"](eagle_module.decoder.final_layernorm) - - if hasattr(self.model.eagle_module, "eagle_output_layer"): - self.rules["output_layer"](eagle_module.eagle_output_layer) - if hasattr(self.model.eagle_module, "dt2"): - self.rules["d2t"](eagle_module.d2t) - - for layer in eagle_module.decoder.layers: - layer_id = layer.layer_number - 1 - - # The first layernorm needs special handling here. We have a dedicated mapping - # for the first layernorm since in EAGLE3 it will be mapped to hidden_norm - # instead of input_layernorm (due to the specialized transformer layer). - # The remaining EAGLE3 layers (if more than 1) are normal transformer layers - # where input_layernorm is mapped to input_layernorm. - if layer_id == 0 and self.model.eagle_config.use_input_layernorm_in_first_layer: - self.rules["first_input_layernorm"](layer.input_layernorm, layer_id) - elif layer_id > 0: - self.rules["input_layernorm"](layer.input_layernorm, layer_id) - - if "MLASelfAttention" in str(type(layer.self_attention)): - if hasattr(layer.self_attention, "linear_q_proj"): - self.rules["eagle_module.linear_q_proj"]( - layer.self_attention.linear_q_proj, layer_id - ) - else: - self.rules["eagle_module.linear_q_down_proj"]( - layer.self_attention.linear_q_down_proj, layer_id - ) - self.rules["eagle_module.linear_q_layernorm"]( - layer.self_attention.q_layernorm, layer_id - ) - self.rules["eagle_module.linear_q_up_proj"]( - layer.self_attention.linear_q_up_proj, layer_id - ) - - self.rules["eagle_module.linear_kv_down_proj"]( - layer.self_attention.linear_kv_down_proj, layer_id - ) - self.rules["eagle_module.linear_kv_layernorm"]( - layer.self_attention.kv_layernorm, layer_id - ) - self.rules["eagle_module.linear_kv_up_proj"]( - layer.self_attention.linear_kv_up_proj, layer_id - ) - else: - self.rules["linear_qkv"](layer.self_attention.linear_qkv, layer_id) - - self.rules["linear_proj"](layer.self_attention.linear_proj, layer_id) - self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id) - - if "MoE" in str(type(layer.mlp)): - self.rules["eagle_module.router"](layer.mlp.router, layer_id) - if hasattr(layer.mlp, "shared_experts") and layer.mlp.shared_experts is not None: - self.rules["eagle_module.shared_experts.linear_fc1"]( - layer.mlp.shared_experts.linear_fc1, layer_id - ) - self.rules["eagle_module.shared_experts.linear_fc2"]( - layer.mlp.shared_experts.linear_fc2, layer_id - ) - for expert_id, expert in enumerate(layer.mlp.experts.local_experts): - 
self.rules["eagle_module.local_experts.linear_fc1"]( - expert.linear_fc1, layer_id, expert_id - ) - self.rules["eagle_module.local_experts.linear_fc2"]( - expert.linear_fc2, layer_id, expert_id - ) - else: - self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id) - self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id) - - parallel_draft_heads = getattr(eagle_module, "parallel_draft_heads", None) - if parallel_draft_heads is not None: - for head_id, head in enumerate(parallel_draft_heads.medusa_heads): - for layer_id, layer in enumerate(head): - self.rules["parallel_draft_heads.medusa_layers"]( - layer.linear, head_id, layer_id - ) - self.rules["parallel_draft_heads.lm_head"](parallel_draft_heads.lm_head) - - def _get_state_dict(self): - model = self.model - - # Embedding - if hasattr(model, "embedding"): - self.rules["word_embeddings"](model.embedding.word_embeddings) - - # Final layernorm - if hasattr(model.decoder, "final_layernorm") and model.decoder.final_layernorm: - self.rules["final_layernorm"](model.decoder.final_layernorm) - - if hasattr(model.decoder, "final_norm") and model.decoder.final_norm: - self.rules["final_norm"](model.decoder.final_norm) - - # Output layer - if hasattr(model, "output_layer") and not model.share_embeddings_and_output_weights: - self.rules["output_layer"](model.output_layer) - - # Decoder layers - for layer in model.decoder.layers: - layer_id = layer.layer_number - 1 - - if isinstance(layer, MambaLayer): - if not isinstance(layer.norm, IdentityOp): - self.rules["norm"](layer.norm, layer_id) - - self.rules["mixer_norm"](layer.mixer.norm, layer_id) - self.rules["A_log"](layer.mixer.A_log, layer_id) - self.rules["D"](layer.mixer.D, layer_id) - self.rules["dt_bias"](layer.mixer.dt_bias, layer_id) - - self.rules["conv1d"](layer.mixer.conv1d, layer_id) - self.rules["in_proj"](layer.mixer.in_proj, layer_id) - self.rules["out_proj"](layer.mixer.out_proj, layer_id) - - elif isinstance(layer, TransformerLayer): - if not isinstance(layer.input_layernorm, IdentityOp): - self.rules["input_layernorm"](layer.input_layernorm, layer_id) - - if not isinstance(layer.self_attention, IdentityOp): - if "MLASelfAttention" in str(type(layer.self_attention)): - if hasattr(layer.self_attention, "linear_q_proj"): - self.rules["linear_q_proj"]( - layer.self_attention.linear_q_proj, layer_id - ) - else: - self.rules["linear_q_down_proj"]( - layer.self_attention.linear_q_down_proj, layer_id - ) - self.rules["linear_q_layernorm"]( - layer.self_attention.q_layernorm, layer_id - ) - self.rules["linear_q_up_proj"]( - layer.self_attention.linear_q_up_proj, layer_id - ) - - self.rules["linear_kv_down_proj"]( - layer.self_attention.linear_kv_down_proj, layer_id - ) - self.rules["linear_kv_layernorm"]( - layer.self_attention.kv_layernorm, layer_id - ) - self.rules["linear_kv_up_proj"]( - layer.self_attention.linear_kv_up_proj, layer_id - ) - self.rules["linear_proj"](layer.self_attention.linear_proj, layer_id) - else: - if layer.self_attention.q_layernorm is not None and not isinstance( - layer.self_attention.q_layernorm, (IdentityOp, L2Norm) - ): - self.rules["q_layernorm"](layer.self_attention.q_layernorm, layer_id) - self.rules["k_layernorm"](layer.self_attention.k_layernorm, layer_id) - self.rules["linear_qkv"](layer.self_attention.linear_qkv, layer_id) - if hasattr(layer.self_attention, "core_attention"): - self.rules["core_attention"](layer.self_attention.core_attention, layer_id) - self.rules["linear_proj"](layer.self_attention.linear_proj, layer_id) - if ( - 
getattr(layer.self_attention.core_attention, "softmax_offset", None) - is not None - ): - self.rules["softmax_offset"]( - layer.self_attention.core_attention.softmax_offset, layer_id - ) - - if not isinstance(layer.pre_mlp_layernorm, IdentityOp): - self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id) - if not isinstance(layer.mlp, IdentityOp): - if "MoE" in str(type(layer.mlp)): - self.rules["router"]( - layer.mlp.router, layer_id, dtype=self.moe_router_dtype - ) - if hasattr(layer.mlp, "fc1_latent_proj") and layer.mlp.fc1_latent_proj is not None: - self.rules["fc1_latent_proj"](layer.mlp.fc1_latent_proj, layer_id) - if hasattr(layer.mlp, "fc2_latent_proj") and layer.mlp.fc2_latent_proj is not None: - self.rules["fc2_latent_proj"](layer.mlp.fc2_latent_proj, layer_id) - if ( - hasattr(layer.mlp, "shared_experts") - and layer.mlp.shared_experts is not None - ): - self.rules["shared_experts.linear_fc1"]( - layer.mlp.shared_experts.linear_fc1, layer_id - ) - self.rules["shared_experts.linear_fc2"]( - layer.mlp.shared_experts.linear_fc2, layer_id - ) - if not self.rules.get("use_packed_local_experts", False): - for expert_id, expert in enumerate(layer.mlp.experts.local_experts): - self.rules["local_experts.linear_fc1"]( - expert.linear_fc1, layer_id, expert_id - ) - self.rules["local_experts.linear_fc2"]( - expert.linear_fc2, layer_id, expert_id - ) - else: - # For llama 4, in hf unified checkpoint, all local experts share one scale - self.rules["local_experts.linear_fc1"]( - layer.mlp.experts.local_experts, layer_id - ) - self.rules["local_experts.linear_fc2"]( - layer.mlp.experts.local_experts, layer_id - ) - else: - self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id) - self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id) - else: - raise ValueError("Only TransformerLayer or MambaLayer are supported.") - - # MTP module - # Hacky version for now: copy MTP weights from pretrained model - if os.path.isdir(self._hf_pretrained_model_name): - safetensors_index_file = Path(self._hf_pretrained_model_name) / "model.safetensors.index.json" - else: - safetensors_index_file = hf_hub_download( - repo_id=self._hf_pretrained_model_name, - filename="model.safetensors.index.json") - - print(f"safetensors_index_file: {safetensors_index_file}") - if safetensors_index_file and os.path.exists(safetensors_index_file): - with open(safetensors_index_file, "r") as f: - safetensors_index = json.load(f) - model_dir = Path(safetensors_index_file).parent - for key in safetensors_index["weight_map"]: - if "mtp" in key: - self._state_dict[key] = get_safetensor(model_dir, key) - - # TODO implement actual MTP export From 31fcdc451ac032c1596b994c1ec3745dccebc9d1 Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Thu, 29 Jan 2026 08:50:54 -0800 Subject: [PATCH 12/20] dynamically export exclude modules Signed-off-by: jenchen13 --- .../torch/export/unified_export_megatron.py | 30 ++++++++++++++++--- .../torch/quantization/plugins/megatron.py | 2 -- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index 95f138cd5..ca03e2d00 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -61,6 +61,7 @@ get_weight_scaling_factor_2, to_quantized_weight, ) +from .convert_hf_config import convert_hf_quant_config_format with import_plugin("transformers", verbose=False): import transformers @@ -244,6 +245,8 @@ def save_pretrained( # Main export process 
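
# A minimal sketch, assuming an initialized torch.distributed process group, of
# how the per-rank exclude_modules lists are combined in this patch: each rank
# only sees the layers it owns, so the lists are all-gathered and their union
# is written to hf_quant_config.json.
import torch.distributed as dist

def gather_exclude_modules(local_exclude: list[str]) -> list[str]:
    gathered: list[list[str] | None] = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, local_exclude)
    combined: set[str] = set()
    for modules in gathered:
        if modules:
            combined.update(modules)
    return sorted(combined)
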
state_dict = self.extra_state_dict if self.export_extra_modules else self.state_dict + + quantization_format = self._get_quantization_format(self.model) quantization = None @@ -291,10 +294,18 @@ def save_pretrained( except (OSError, ValueError, ImportError): pass + # Gather exclude_modules from all ranks to ensure hf_quant_config is complete + all_exclude_modules = [None] * torch.distributed.get_world_size() + torch.distributed.all_gather_object(all_exclude_modules, self.exclude_modules) + combined_exclude_modules = set() + for modules in all_exclude_modules: + if modules: + combined_exclude_modules.update(modules) + self.exclude_modules = sorted(list(combined_exclude_modules)) + if is_last_stage_main_rank and quantization is not None: # TODO refactor to use mte.quant_utils.get_quant_config # except layer names are different in MCore and HF - self.exclude_modules.append("lm_head") self._hf_quant_config = { "producer": { "name": "modelopt", @@ -432,9 +443,10 @@ def save_pretrained( # Newer versions of VLLM expect config.json with hf_quant_config config_json_file = save_directory + "/config.json" if self._hf_quant_config and os.path.exists(config_json_file): + converted_quant_config = convert_hf_quant_config_format(self._hf_quant_config) with open(config_json_file, "r") as f: config_dict = json.load(f) - config_dict["quantization"] = self._hf_quant_config["quantization"] + config_dict["quantization_config"] = converted_quant_config with open(config_json_file, "w") as f: json.dump(config_dict, f, indent=4) @@ -577,7 +589,12 @@ def _get_transformer_layer_state_dict(self, layer, layer_id): self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id) def _get_mtp_state_dict(self): - # TODO implement actual MTP export + """ + Export the MTP module. + + Currently, we copy the BF16 MTP weights from the pretrained model. 
+ """ + # TODO Implement MTP export for quantized MTP # Hacky version for now: copy MTP weights from pretrained model if os.path.isdir(self._hf_pretrained_model_name): safetensors_index_file = Path(self._hf_pretrained_model_name) / "model.safetensors.index.json" @@ -587,6 +604,7 @@ def _get_mtp_state_dict(self): filename="model.safetensors.index.json") print(f"Exporting MTP: using safetensors_index_file: {safetensors_index_file}") + mtp_exists = False if safetensors_index_file and os.path.exists(safetensors_index_file): with open(safetensors_index_file, "r") as f: safetensors_index = json.load(f) @@ -594,6 +612,10 @@ def _get_mtp_state_dict(self): for key in safetensors_index["weight_map"]: if key.startswith("mtp.") and key not in self._state_dict: self._state_dict[key] = get_safetensor(model_dir, key) + mtp_exists = True + + if mtp_exists: + self.exclude_modules.append("mtp*") def _get_mamba_layer_state_dict(self, layer, layer_id): @@ -762,7 +784,7 @@ def _get_quantized_state( """ name_to_value = {} qformat: str = self._get_quantization_format(module) - if qformat is None: + if qformat is None and "norm" not in prefix: # Add exclude layers for hf_quant_config self.exclude_modules.append(prefix) block_size = get_weight_block_size(module) diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 03b3888f3..c2a2eca0c 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -597,8 +597,6 @@ def sync_moe_local_experts_amax(self): else torch.maximum(stored_amax, amax_tensor) ) - - # Apply synchronized amax values back to all local experts for expert in self.local_experts: for name, module in expert.named_modules(): From 53103e9a65b0a4f037fe1f10096549a86d328168 Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Thu, 29 Jan 2026 09:46:26 -0800 Subject: [PATCH 13/20] linting Signed-off-by: jenchen13 --- modelopt/torch/export/plugins/mcore_custom.py | 6 ++ .../torch/export/plugins/mcore_nemotron.py | 22 +++-- .../torch/export/plugins/megatron_importer.py | 74 ++++++++------ modelopt/torch/export/quant_utils.py | 19 ++-- .../torch/export/unified_export_megatron.py | 98 ++++++++----------- modelopt/torch/quantization/model_calib.py | 1 - .../torch/quantization/plugins/megatron.py | 4 +- 7 files changed, 115 insertions(+), 109 deletions(-) diff --git a/modelopt/torch/export/plugins/mcore_custom.py b/modelopt/torch/export/plugins/mcore_custom.py index 77a3208bb..c269cef1d 100644 --- a/modelopt/torch/export/plugins/mcore_custom.py +++ b/modelopt/torch/export/plugins/mcore_custom.py @@ -102,6 +102,7 @@ def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] func_kwargs=func_kwargs, ) + class GroupedMLPMerging(CustomModuleMapping): """A custom module mapping that merges up_proj and down_proj for Grouped MLP.""" @@ -112,6 +113,8 @@ def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] target_name_or_prefix=target_name_or_prefix, func_kwargs=func_kwargs, ) + + class GatedMLPMerging(CustomModuleMapping): """A custom module mapping that merges gate_proj and up_proj.""" @@ -135,8 +138,10 @@ def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] func_kwargs=func_kwargs, ) + class SelfAttentionScaling(CustomModuleMapping): """A custom module mapping that scales self attention.""" + def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] = {}): """Create a custom module mapping that scales self 
attention.""" super().__init__( @@ -145,6 +150,7 @@ def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] func_kwargs=func_kwargs, ) + class GatedMLPSlicing(CustomModuleMapping): """A custom module mapping that slices gate_proj and up_proj.""" diff --git a/modelopt/torch/export/plugins/mcore_nemotron.py b/modelopt/torch/export/plugins/mcore_nemotron.py index 12408480b..70bd00fb2 100644 --- a/modelopt/torch/export/plugins/mcore_nemotron.py +++ b/modelopt/torch/export/plugins/mcore_nemotron.py @@ -23,10 +23,10 @@ ROW_ETP, ROW_TP, CustomModuleMapping, + GroupedMLPMerging, NameRemapping, QKVMerging, QKVSlicing, - GroupedMLPMerging, SelfAttentionScaling, ) @@ -87,17 +87,20 @@ "fc1_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc1_latent_proj.", REPLICATE), "fc2_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc2_latent_proj.", REPLICATE), # Repeated MTP module - "mtp.enorm": NameRemapping("mtp.layers.{}.enorm.", {"is_mtp": True}), - "mtp.hnorm": NameRemapping("mtp.layers.{}.hnorm.", {"is_mtp": True}), - "mtp.eh_proj": NameRemapping("mtp.layers.{}.eh_proj.", {"is_mtp": True}), - "mtp.final_layernorm": NameRemapping("mtp.layers.{}.final_layernorm.", {"is_mtp": True}), + "mtp.enorm": NameRemapping("mtp.layers.{}.enorm.", {"is_mtp": True}), + "mtp.hnorm": NameRemapping("mtp.layers.{}.hnorm.", {"is_mtp": True}), + "mtp.eh_proj": NameRemapping("mtp.layers.{}.eh_proj.", {"is_mtp": True}), + "mtp.final_layernorm": NameRemapping("mtp.layers.{}.final_layernorm.", {"is_mtp": True}), # Grouped local experts in MTP - "experts.linear_fc1": GroupedMLPMerging("mtp.layers.{}.mixer.experts.{{}}.up_proj", COL_ETP | {"is_mtp": True}), - "experts.linear_fc2": GroupedMLPMerging("mtp.layers.{}.mixer.experts.{{}}.down_proj", ROW_ETP | {"is_mtp": True}), - + "experts.linear_fc1": GroupedMLPMerging( + "mtp.layers.{}.mixer.experts.{{}}.up_proj", COL_ETP | {"is_mtp": True} + ), + "experts.linear_fc2": GroupedMLPMerging( + "mtp.layers.{}.mixer.experts.{{}}.down_proj", ROW_ETP | {"is_mtp": True} + ), } -# TODO ADD MTP export +# TODO ADD MTP export nemotron_h_causal_lm_export: dict[str, CustomModuleMapping] = { "word_embeddings": NameRemapping("backbone.embeddings."), @@ -139,5 +142,4 @@ "mtp.hnorm": NameRemapping("mtp.layers.{}.hnorm."), "mtp.eh_proj": NameRemapping("mtp.layers.{}.eh_proj."), "mtp.final_layernorm": NameRemapping("mtp.layers.{}.final_layernorm."), - } diff --git a/modelopt/torch/export/plugins/megatron_importer.py b/modelopt/torch/export/plugins/megatron_importer.py index 81d9a5b39..b1c2a4546 100644 --- a/modelopt/torch/export/plugins/megatron_importer.py +++ b/modelopt/torch/export/plugins/megatron_importer.py @@ -40,7 +40,6 @@ with import_plugin("megatron"): from megatron.core.parallel_state import ( get_expert_tensor_parallel_world_size, - get_expert_tensor_parallel_rank, get_tensor_model_parallel_world_size, ) from megatron.core.ssm.mamba_layer import MambaLayer @@ -272,16 +271,18 @@ def _grouped_mlp_merging( num_local_experts: int = 1, ): if is_mtp: - if "backbone" in prefix: + if "backbone" in prefix: prefix = prefix.replace("backbone", "mtp") else: prefix = prefix.replace("model", "mtp") state_dict = module.state_dict() # TODO handle weight_scale - #weight_scale = state_dict.get("weight_quantizer._scale", None) + # weight_scale = state_dict.get("weight_quantizer._scale", None) - assert module.num_gemms == num_local_experts, "num_gemms must be equal to num_local_experts in TEGroupedMLP" + assert module.num_gemms == num_local_experts, ( + "num_gemms must be equal 
to num_local_experts in TEGroupedMLP" + ) for expert_id in range(init_expert_id, init_expert_id + num_local_experts): tensor = self._get_safetensor(prefix.format(expert_id) + ".weight") state_dict[f"weight{expert_id}"] = tensor @@ -299,7 +300,7 @@ def _qkv_merging( is_mtp: bool = False, ): if is_mtp: - if "backbone" in prefix: + if "backbone" in prefix: prefix = prefix.replace("backbone", "mtp") else: prefix = prefix.replace("model", "mtp") @@ -527,7 +528,7 @@ def _import_mamba_layer(self, layer, layer_id, layer_pbar): self.rules["conv1d"](layer.mixer.conv1d, layer_id) self.rules["in_proj"](layer.mixer.in_proj, layer_id) self.rules["out_proj"](layer.mixer.out_proj, layer_id) - + def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool = False): if not isinstance(layer.input_layernorm, IdentityOp): self.rules["input_layernorm"](layer.input_layernorm, layer_id) @@ -557,34 +558,29 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool = self.rules["linear_qkv"](attention.linear_qkv, layer_id, is_mtp=is_mtp) self.rules["linear_proj"](attention.linear_proj, layer_id, is_mtp=is_mtp) if getattr(attention.core_attention, "softmax_offset", None) is not None: - self.rules["softmax_offset"]( - attention.core_attention.softmax_offset, layer_id - ) + self.rules["softmax_offset"](attention.core_attention.softmax_offset, layer_id) if not isinstance(layer.pre_mlp_layernorm, IdentityOp): self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id) if not isinstance(layer.mlp, IdentityOp): if "MoE" in str(type(layer.mlp)): - layer_pbar.set_description(f"Importing MoE with moe_router_dtype: {self.moe_router_dtype}") - self.rules["router"]( - layer.mlp.router, layer_id, dtype=self.moe_router_dtype + layer_pbar.set_description( + f"Importing MoE with moe_router_dtype: {self.moe_router_dtype}" ) + self.rules["router"](layer.mlp.router, layer_id, dtype=self.moe_router_dtype) if hasattr(layer.mlp, "fc1_latent_proj") and layer.mlp.fc1_latent_proj is not None: self.rules["fc1_latent_proj"](layer.mlp.fc1_latent_proj, layer_id) if hasattr(layer.mlp, "fc2_latent_proj") and layer.mlp.fc2_latent_proj is not None: self.rules["fc2_latent_proj"](layer.mlp.fc2_latent_proj, layer_id) - if ( - hasattr(layer.mlp, "shared_experts") - and layer.mlp.shared_experts is not None - ): + if hasattr(layer.mlp, "shared_experts") and layer.mlp.shared_experts is not None: layer_pbar.set_description("Importing MoE shared experts") fc1 = layer.mlp.shared_experts.linear_fc1 fc2 = layer.mlp.shared_experts.linear_fc2 self.rules["shared_experts.linear_fc1"](fc1, layer_id) self.rules["shared_experts.linear_fc2"](fc2, layer_id) - if not self.rules.get("use_packed_local_experts", False): # Import local experts + if not self.rules.get("use_packed_local_experts", False): # Import local experts experts = layer.mlp.experts if hasattr(experts, "local_experts"): for local_expert_id, expert in tqdm( @@ -598,13 +594,15 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool = fc2 = expert.linear_fc2 self.rules["local_experts.linear_fc1"](fc1, layer_id, expert_id) self.rules["local_experts.linear_fc2"](fc2, layer_id, expert_id) - else: # Slice TEGroupedMLP + else: # Slice TEGroupedMLP layer_pbar.set_description("Importing MoE grouped local experts") num_local_experts = experts.num_local_experts num_global_experts = experts.config.num_moe_experts - assert num_local_experts == num_global_experts, "num_local_experts must be equal to num_global_experts during MoE import" + assert 
num_local_experts == num_global_experts, ( + "num_local_experts must be equal to num_global_experts during MoE import" + ) - ''' + """ if parallel_config is not None: etp_size = get_expert_tensor_parallel_world_size() # etp_rank = get_expert_tensor_parallel_rank() # this gives group rank @@ -613,11 +611,21 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool = print(f"etp_rank: {etp_rank}") assert num_local_experts * etp_size == num_global_experts init_index = etp_rank * num_local_experts - ''' + """ init_index = 0 - self.rules["experts.linear_fc1"](experts.linear_fc1, layer_id, init_expert_id=init_index, num_local_experts=num_local_experts) - self.rules["experts.linear_fc2"](experts.linear_fc2, layer_id, init_expert_id=init_index, num_local_experts=num_local_experts) + self.rules["experts.linear_fc1"]( + experts.linear_fc1, + layer_id, + init_expert_id=init_index, + num_local_experts=num_local_experts, + ) + self.rules["experts.linear_fc2"]( + experts.linear_fc2, + layer_id, + init_expert_id=init_index, + num_local_experts=num_local_experts, + ) # We only support either EP or ETP for now elif get_expert_tensor_parallel_world_size() > 1: @@ -644,7 +652,6 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool = self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id) self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id) - def _import_state_dict(self): model = self.model # print(model, flush=True) @@ -687,8 +694,8 @@ def _import_state_dict(self): if hasattr(model, "mtp"): print("Importing MTP", flush=True) # MTP is the last layer in DeepSeek V3/R1 - if len(model.mtp.layers) == 1: # Repeated MTP - layer_id = 0 # reset layer_id for repeated MTP + if len(model.mtp.layers) == 1: # Repeated MTP + layer_id = 0 # reset layer_id for repeated MTP mtp = model.mtp.layers[0] self.rules["mtp.eh_proj"](mtp.eh_proj, layer_id) @@ -700,18 +707,23 @@ def _import_state_dict(self): if isinstance(mtp_model_layer, MambaLayer): self._import_mamba_layer(mtp_model_layer, layer_id, layer_pbar) elif isinstance(mtp_model_layer, TransformerLayer): - self._import_transformer_layer(mtp_model_layer, layer_id, layer_pbar, is_mtp=True) + self._import_transformer_layer( + mtp_model_layer, layer_id, layer_pbar, is_mtp=True + ) else: - raise ValueError(f"Unsupported layer type during MTP import: {type(mtp_model_layer)}") + raise ValueError( + f"Unsupported layer type during MTP import: {type(mtp_model_layer)}" + ) layer_id += 1 - else: # non-repeated MTP - + else: # non-repeated MTP for mtp in model.mtp.layers: self.rules["mtp.eh_proj"](mtp.eh_proj, layer_id) self.rules["mtp.enorm"](mtp.enorm, layer_id) self.rules["mtp.hnorm"](mtp.hnorm, layer_id) - self.rules["mtp.input_layernorm"](mtp.decoder.layers[0].input_layernorm, layer_id) + self.rules["mtp.input_layernorm"]( + mtp.decoder.layers[0].input_layernorm, layer_id + ) if hasattr(mtp.decoder.layers[0].self_attention, "linear_q_proj"): self.rules["mtp.linear_q_proj"]( mtp.decoder.layers[0].self_attention.linear_q_proj, layer_id diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index 48dec6ae1..370a3e753 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -387,7 +387,7 @@ def get_prequant_scaling_factor(module: nn.Module) -> torch.Tensor: if prequant_scaling_factor is not None: assert torch.all(prequant_scaling_factor > 0), ( f"prequant scaling factor {prequant_scaling_factor} not positive." 
- ) + ) return prequant_scaling_factor @@ -399,12 +399,19 @@ def get_kv_cache_bias(kv_module: nn.Module) -> list[torch.Tensor]: kv_bias.append(getattr(quantizer_module, "_bias_value", None)) return kv_bias + def get_kv_cache_scaling_factor(self_attention_module: nn.Module) -> torch.Tensor: + """Get the K and V BMM scaling factors for the self attention module. + + Args: + self_attention_module: The self attention module to get the K and V BMM scaling factors from. + + Returns: + The K and V BMM scaling factors. """ - Returns the k and v BMM scaling factors if BMM quantizers are set in the self attention module. - Else returns None by default. - """ - if not hasattr(self_attention_module, "k_bmm_quantizer") or not hasattr(self_attention_module, "v_bmm_quantizer"): + if not hasattr(self_attention_module, "k_bmm_quantizer") or not hasattr( + self_attention_module, "v_bmm_quantizer" + ): return [None, None] scaling_factors = [ @@ -414,7 +421,6 @@ def get_kv_cache_scaling_factor(self_attention_module: nn.Module) -> torch.Tenso return scaling_factors - def get_kv_cache_dtype(modules: list[nn.Module] | nn.Module) -> str | None: """Returns the kv_cache dtype. @@ -444,6 +450,7 @@ def get_kv_cache_dtype(modules: list[nn.Module] | nn.Module) -> str | None: return _compute_kv_cache_dtype(num_bits_list) + def _compute_kv_cache_dtype(num_bits_list: list[int]) -> str | None: """Returns the kv_cache dtype. diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index ca03e2d00..16fe0f6ad 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -25,11 +25,9 @@ from collections import OrderedDict from pathlib import Path from typing import Any -from warnings import warn import torch import torch.distributed -import torch.nn as nn from huggingface_hub import hf_hub_download, snapshot_download from safetensors.torch import safe_open, save_file from tqdm import tqdm @@ -37,6 +35,7 @@ from modelopt import __version__ from modelopt.torch.utils import import_plugin +from .convert_hf_config import convert_hf_quant_config_format from .model_config import ( KV_CACHE_FP8, KV_CACHE_NVFP4, @@ -47,21 +46,18 @@ QUANTIZATION_NVFP4, ) from .plugins.mcore_common import all_mcore_hf_export_mapping -from .plugins.mcore_custom import CustomModuleMapping, save_safetensors, get_safetensor +from .plugins.mcore_custom import CustomModuleMapping, get_safetensor, save_safetensors from .plugins.megatron_importer import GPTModelImporter from .quant_utils import ( get_activation_scaling_factor, - get_kv_cache_scaling_factor, get_kv_cache_dtype, - get_quant_config, + get_kv_cache_scaling_factor, get_quantization_format, - get_scaling_factor, get_weight_block_size, get_weight_scaling_factor, get_weight_scaling_factor_2, to_quantized_weight, ) -from .convert_hf_config import convert_hf_quant_config_format with import_plugin("transformers", verbose=False): import transformers @@ -246,7 +242,6 @@ def save_pretrained( # Main export process state_dict = self.extra_state_dict if self.export_extra_modules else self.state_dict - quantization_format = self._get_quantization_format(self.model) quantization = None @@ -301,7 +296,7 @@ def save_pretrained( for modules in all_exclude_modules: if modules: combined_exclude_modules.update(modules) - self.exclude_modules = sorted(list(combined_exclude_modules)) + self.exclude_modules = sorted(combined_exclude_modules) if is_last_stage_main_rank and quantization is not None: # TODO 
refactor to use mte.quant_utils.get_quant_config @@ -316,14 +311,13 @@ def save_pretrained( "exclude_modules": self.exclude_modules, }, } - if quantization == "NVFP4": # update block size + if quantization == "NVFP4": # update block size self._hf_quant_config["quantization"]["group_size"] = 16 if hasattr(self, "kv_cache_dtype"): self._hf_quant_config["quantization"]["kv_cache_quant_algo"] = self.kv_cache_dtype with open(save_directory + "/hf_quant_config.json", "w") as f: json.dump(self._hf_quant_config, f, indent=4) - if ( is_first_stage_main_rank and self.is_multimodal @@ -444,7 +438,7 @@ def save_pretrained( config_json_file = save_directory + "/config.json" if self._hf_quant_config and os.path.exists(config_json_file): converted_quant_config = convert_hf_quant_config_format(self._hf_quant_config) - with open(config_json_file, "r") as f: + with open(config_json_file) as f: config_dict = json.load(f) config_dict["quantization_config"] = converted_quant_config with open(config_json_file, "w") as f: @@ -493,7 +487,7 @@ def _get_state_dict(self): self._get_transformer_layer_state_dict(layer, layer_id) else: raise ValueError("Only TransformerLayer or MambaLayer are supported.") - + # MTP layer self._get_mtp_state_dict() @@ -504,29 +498,19 @@ def _get_transformer_layer_state_dict(self, layer, layer_id): if not isinstance(layer.self_attention, IdentityOp): if "MLASelfAttention" in str(type(layer.self_attention)): if hasattr(layer.self_attention, "linear_q_proj"): - self.rules["linear_q_proj"]( - layer.self_attention.linear_q_proj, layer_id - ) + self.rules["linear_q_proj"](layer.self_attention.linear_q_proj, layer_id) else: self.rules["linear_q_down_proj"]( layer.self_attention.linear_q_down_proj, layer_id ) - self.rules["linear_q_layernorm"]( - layer.self_attention.q_layernorm, layer_id - ) - self.rules["linear_q_up_proj"]( - layer.self_attention.linear_q_up_proj, layer_id - ) + self.rules["linear_q_layernorm"](layer.self_attention.q_layernorm, layer_id) + self.rules["linear_q_up_proj"](layer.self_attention.linear_q_up_proj, layer_id) self.rules["linear_kv_down_proj"]( layer.self_attention.linear_kv_down_proj, layer_id ) - self.rules["linear_kv_layernorm"]( - layer.self_attention.kv_layernorm, layer_id - ) - self.rules["linear_kv_up_proj"]( - layer.self_attention.linear_kv_up_proj, layer_id - ) + self.rules["linear_kv_layernorm"](layer.self_attention.kv_layernorm, layer_id) + self.rules["linear_kv_up_proj"](layer.self_attention.linear_kv_up_proj, layer_id) self.rules["linear_proj"](layer.self_attention.linear_proj, layer_id) else: if layer.self_attention.q_layernorm is not None and not isinstance( @@ -538,10 +522,7 @@ def _get_transformer_layer_state_dict(self, layer, layer_id): if hasattr(layer.self_attention, "core_attention"): self.rules["core_attention"](layer.self_attention.core_attention, layer_id) self.rules["linear_proj"](layer.self_attention.linear_proj, layer_id) - if ( - getattr(layer.self_attention.core_attention, "softmax_offset", None) - is not None - ): + if getattr(layer.self_attention.core_attention, "softmax_offset", None) is not None: self.rules["softmax_offset"]( layer.self_attention.core_attention.softmax_offset, layer_id ) @@ -551,17 +532,12 @@ def _get_transformer_layer_state_dict(self, layer, layer_id): if not isinstance(layer.mlp, IdentityOp): if "MoE" in str(type(layer.mlp)): - self.rules["router"]( - layer.mlp.router, layer_id, dtype=self.moe_router_dtype - ) + self.rules["router"](layer.mlp.router, layer_id, dtype=self.moe_router_dtype) if hasattr(layer.mlp, 
"fc1_latent_proj") and layer.mlp.fc1_latent_proj is not None: self.rules["fc1_latent_proj"](layer.mlp.fc1_latent_proj, layer_id) if hasattr(layer.mlp, "fc2_latent_proj") and layer.mlp.fc2_latent_proj is not None: self.rules["fc2_latent_proj"](layer.mlp.fc2_latent_proj, layer_id) - if ( - hasattr(layer.mlp, "shared_experts") - and layer.mlp.shared_experts is not None - ): + if hasattr(layer.mlp, "shared_experts") and layer.mlp.shared_experts is not None: self.rules["shared_experts.linear_fc1"]( layer.mlp.shared_experts.linear_fc1, layer_id ) @@ -589,35 +565,35 @@ def _get_transformer_layer_state_dict(self, layer, layer_id): self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id) def _get_mtp_state_dict(self): - """ - Export the MTP module. - + """Export the MTP module. + Currently, we copy the BF16 MTP weights from the pretrained model. """ # TODO Implement MTP export for quantized MTP # Hacky version for now: copy MTP weights from pretrained model if os.path.isdir(self._hf_pretrained_model_name): - safetensors_index_file = Path(self._hf_pretrained_model_name) / "model.safetensors.index.json" + safetensors_index_file = ( + Path(self._hf_pretrained_model_name) / "model.safetensors.index.json" + ) else: safetensors_index_file = hf_hub_download( - repo_id=self._hf_pretrained_model_name, - filename="model.safetensors.index.json") + repo_id=self._hf_pretrained_model_name, filename="model.safetensors.index.json" + ) print(f"Exporting MTP: using safetensors_index_file: {safetensors_index_file}") mtp_exists = False if safetensors_index_file and os.path.exists(safetensors_index_file): - with open(safetensors_index_file, "r") as f: + with open(safetensors_index_file) as f: safetensors_index = json.load(f) model_dir = Path(safetensors_index_file).parent for key in safetensors_index["weight_map"]: if key.startswith("mtp.") and key not in self._state_dict: self._state_dict[key] = get_safetensor(model_dir, key) mtp_exists = True - + if mtp_exists: self.exclude_modules.append("mtp*") - def _get_mamba_layer_state_dict(self, layer, layer_id): if not isinstance(layer.norm, IdentityOp): self.rules["norm"](layer.norm, layer_id) @@ -797,7 +773,11 @@ def _get_quantized_state( if hasattr(module, "bias") and module.bias is not None and module.bias.numel() > 0: name_to_value["bias"] = module.bias.to(dtype).cpu() - if hasattr(module, "expert_bias") and module.expert_bias is not None and module.expert_bias.numel() > 0: + if ( + hasattr(module, "expert_bias") + and module.expert_bias is not None + and module.expert_bias.numel() > 0 + ): name_to_value["expert_bias"] = module.expert_bias.to(dtype).cpu() if qformat == QUANTIZATION_NONE: @@ -818,8 +798,7 @@ def _get_quantized_state( # TODO (chenhany): support AWQ with pre_quant_scale if hasattr(module.input_quantizer, "_pre_quant_scale"): raise ValueError("Detect pre_quant_scale! 
SmoothQuant/AWQ are not yet supported!") - - + return name_to_value, qformat, block_size def _get_quantization_format(self, module: torch.nn.Module): @@ -892,7 +871,9 @@ def _name_remapping( def _gated_mlp_slicing( self, module, prefix, gate_proj_name="gate_proj", up_proj_name="up_proj" ): - name_to_value, qformat, block_size = self._get_quantized_state(module, self.dtype, prefix=prefix) + name_to_value, qformat, block_size = self._get_quantized_state( + module, self.dtype, prefix=prefix + ) weight = name_to_value.pop("weight") weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat) @@ -954,8 +935,10 @@ def _qkv_slicing( q_proj_name="q_proj", k_proj_name="k_proj", v_proj_name="v_proj", - ): - name_to_value, qformat, block_size = self._get_quantized_state(module, self.dtype, prefix=prefix) + ): + name_to_value, qformat, block_size = self._get_quantized_state( + module, self.dtype, prefix=prefix + ) q_proj_prefix = prefix + q_proj_name + "." k_proj_prefix = prefix + k_proj_name + "." @@ -1065,7 +1048,9 @@ def _qkv_slicing( self._state_dict[k_proj_key] = val.detach().clone() self._state_dict[v_proj_key] = val.detach().clone() - def _self_attention_scaling(self, module, prefix, k_scale_name="k_scale", v_scale_name="v_scale"): + def _self_attention_scaling( + self, module, prefix, k_scale_name="k_scale", v_scale_name="v_scale" + ): """KV cache scaling for CoreAttention module.""" k_scale_key = prefix + k_scale_name v_scale_key = prefix + v_scale_name @@ -1260,9 +1245,6 @@ def _pack_name_remapping_gpt_oss(self, module, prefix, layer_type=None): self._state_dict[prefix + "_bias"] = merged_bias - - - def export_mcore_gpt_to_hf( model: torch.nn.Module, pretrained_model_name_or_path: str | os.PathLike | None = None, diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 5360e9415..386fe8c2c 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -100,7 +100,6 @@ def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state): quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group) # TODO: create sync_bias_across_distributed_group - # Step 2:Sync amax across data parallelism for name, module in model.named_modules(): if isinstance(module, QuantModule): diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index c2a2eca0c..6fa99e819 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -23,7 +23,6 @@ import megatron.core.parallel_state as mcore_parallel import megatron.core.tensor_parallel.layers as megatron_parallel import megatron.core.transformer.mlp as megatron_mlp -import megatron.core.transformer.transformer_layer as megatron_transformer_layer import megatron.core.transformer.moe.experts as megatron_moe import megatron.core.transformer.moe.moe_layer as megatron_moe_layer import torch @@ -41,7 +40,6 @@ register_modelopt_extra_state_callbacks, ) from modelopt.torch.utils.distributed import ParallelState -import torch.distributed as dist from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer from ..nn.modules.quant_linear import RealQuantLinear @@ -50,11 +48,11 @@ try: from megatron.core.extensions.transformer_engine import ( - TELinear, TEColumnParallelGroupedLinear, TEColumnParallelLinear, TEDotProductAttention, TELayerNormColumnParallelLinear, + TELinear, TERowParallelGroupedLinear, TERowParallelLinear, ) 
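For reference, the rank-wide exclude_modules aggregation adjusted in this patch (and factored into a _gather_exclude_modules helper in the next one) follows a standard all_gather_object-and-union pattern. A minimal standalone sketch, assuming an already initialized torch.distributed process group; the helper name below is illustrative and not part of the repository:

import torch.distributed as dist

def union_excludes_across_ranks(local_excludes: list[str]) -> list[str]:
    # Collect every rank's list of excluded module names.
    gathered: list[list[str] | None] = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, local_excludes)
    # Union and sort so all ranks agree on one deterministic, complete list
    # before it is written into hf_quant_config.json.
    combined: set[str] = set()
    for items in gathered:
        if items:
            combined.update(items)
    return sorted(combined)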
From e41783a4eb1dce31e7a7796e1f184c2fb9d1d9d7 Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Thu, 29 Jan 2026 10:22:54 -0800 Subject: [PATCH 14/20] remove comments Signed-off-by: jenchen13 --- .../torch/export/plugins/megatron_importer.py | 101 ++++++++++-------- modelopt/torch/export/quant_utils.py | 2 +- .../torch/export/unified_export_megatron.py | 35 +++--- 3 files changed, 69 insertions(+), 69 deletions(-) diff --git a/modelopt/torch/export/plugins/megatron_importer.py b/modelopt/torch/export/plugins/megatron_importer.py index b1c2a4546..254b0504d 100644 --- a/modelopt/torch/export/plugins/megatron_importer.py +++ b/modelopt/torch/export/plugins/megatron_importer.py @@ -277,8 +277,6 @@ def _grouped_mlp_merging( prefix = prefix.replace("model", "mtp") state_dict = module.state_dict() - # TODO handle weight_scale - # weight_scale = state_dict.get("weight_quantizer._scale", None) assert module.num_gemms == num_local_experts, ( "num_gemms must be equal to num_local_experts in TEGroupedMLP" @@ -286,6 +284,7 @@ def _grouped_mlp_merging( for expert_id in range(init_expert_id, init_expert_id + num_local_experts): tensor = self._get_safetensor(prefix.format(expert_id) + ".weight") state_dict[f"weight{expert_id}"] = tensor + # TODO handle weight_scale module.load_state_dict(state_dict) @@ -531,55 +530,71 @@ def _import_mamba_layer(self, layer, layer_id, layer_pbar): def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool = False): if not isinstance(layer.input_layernorm, IdentityOp): - self.rules["input_layernorm"](layer.input_layernorm, layer_id) + self.rules["input_layernorm"](layer.input_layernorm, layer_id, is_mtp=is_mtp) attention = layer.self_attention if not isinstance(attention, IdentityOp): if "MLASelfAttention" in str(type(attention)): if hasattr(attention, "linear_q_proj"): layer_pbar.set_description("Importing MLA (without q LoRA)") - self.rules["linear_q_proj"](attention.linear_q_proj, layer_id) + self.rules["linear_q_proj"](attention.linear_q_proj, layer_id, is_mtp=is_mtp) else: layer_pbar.set_description("Importing MLA (with q LoRA)") - self.rules["linear_q_down_proj"](attention.linear_q_down_proj, layer_id) - self.rules["linear_q_layernorm"](attention.q_layernorm, layer_id) - self.rules["linear_q_up_proj"](attention.linear_q_up_proj, layer_id) - self.rules["linear_kv_down_proj"](attention.linear_kv_down_proj, layer_id) - self.rules["linear_kv_layernorm"](attention.kv_layernorm, layer_id) - self.rules["linear_kv_up_proj"](attention.linear_kv_up_proj, layer_id) - self.rules["linear_proj"](attention.linear_proj, layer_id) + self.rules["linear_q_down_proj"]( + attention.linear_q_down_proj, layer_id, is_mtp=is_mtp + ) + self.rules["linear_q_layernorm"](attention.q_layernorm, layer_id, is_mtp=is_mtp) + self.rules["linear_q_up_proj"]( + attention.linear_q_up_proj, layer_id, is_mtp=is_mtp + ) + self.rules["linear_kv_down_proj"]( + attention.linear_kv_down_proj, layer_id, is_mtp=is_mtp + ) + self.rules["linear_kv_layernorm"](attention.kv_layernorm, layer_id, is_mtp=is_mtp) + self.rules["linear_kv_up_proj"]( + attention.linear_kv_up_proj, layer_id, is_mtp=is_mtp + ) + self.rules["linear_proj"](attention.linear_proj, layer_id, is_mtp=is_mtp) else: layer_pbar.set_description("Importing GQA/MHA") if attention.q_layernorm is not None and not isinstance( attention.q_layernorm, (IdentityOp, L2Norm) ): - self.rules["q_layernorm"](attention.q_layernorm, layer_id) - self.rules["k_layernorm"](attention.k_layernorm, layer_id) + self.rules["q_layernorm"](attention.q_layernorm, 
layer_id, is_mtp=is_mtp) + self.rules["k_layernorm"](attention.k_layernorm, layer_id, is_mtp=is_mtp) self.rules["linear_qkv"](attention.linear_qkv, layer_id, is_mtp=is_mtp) self.rules["linear_proj"](attention.linear_proj, layer_id, is_mtp=is_mtp) if getattr(attention.core_attention, "softmax_offset", None) is not None: - self.rules["softmax_offset"](attention.core_attention.softmax_offset, layer_id) + self.rules["softmax_offset"]( + attention.core_attention.softmax_offset, layer_id, is_mtp=is_mtp + ) if not isinstance(layer.pre_mlp_layernorm, IdentityOp): - self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id) + self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id, is_mtp=is_mtp) if not isinstance(layer.mlp, IdentityOp): if "MoE" in str(type(layer.mlp)): layer_pbar.set_description( f"Importing MoE with moe_router_dtype: {self.moe_router_dtype}" ) - self.rules["router"](layer.mlp.router, layer_id, dtype=self.moe_router_dtype) + self.rules["router"]( + layer.mlp.router, layer_id, dtype=self.moe_router_dtype, is_mtp=is_mtp + ) if hasattr(layer.mlp, "fc1_latent_proj") and layer.mlp.fc1_latent_proj is not None: - self.rules["fc1_latent_proj"](layer.mlp.fc1_latent_proj, layer_id) + self.rules["fc1_latent_proj"]( + layer.mlp.fc1_latent_proj, layer_id, is_mtp=is_mtp + ) if hasattr(layer.mlp, "fc2_latent_proj") and layer.mlp.fc2_latent_proj is not None: - self.rules["fc2_latent_proj"](layer.mlp.fc2_latent_proj, layer_id) + self.rules["fc2_latent_proj"]( + layer.mlp.fc2_latent_proj, layer_id, is_mtp=is_mtp + ) if hasattr(layer.mlp, "shared_experts") and layer.mlp.shared_experts is not None: layer_pbar.set_description("Importing MoE shared experts") fc1 = layer.mlp.shared_experts.linear_fc1 fc2 = layer.mlp.shared_experts.linear_fc2 - self.rules["shared_experts.linear_fc1"](fc1, layer_id) - self.rules["shared_experts.linear_fc2"](fc2, layer_id) + self.rules["shared_experts.linear_fc1"](fc1, layer_id, is_mtp=is_mtp) + self.rules["shared_experts.linear_fc2"](fc2, layer_id, is_mtp=is_mtp) if not self.rules.get("use_packed_local_experts", False): # Import local experts experts = layer.mlp.experts if hasattr(experts, "local_experts"): @@ -592,8 +607,12 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool = expert_id = layer.mlp.local_expert_indices[local_expert_id] fc1 = expert.linear_fc1 fc2 = expert.linear_fc2 - self.rules["local_experts.linear_fc1"](fc1, layer_id, expert_id) - self.rules["local_experts.linear_fc2"](fc2, layer_id, expert_id) + self.rules["local_experts.linear_fc1"]( + fc1, layer_id, expert_id, is_mtp=is_mtp + ) + self.rules["local_experts.linear_fc2"]( + fc2, layer_id, expert_id, is_mtp=is_mtp + ) else: # Slice TEGroupedMLP layer_pbar.set_description("Importing MoE grouped local experts") num_local_experts = experts.num_local_experts @@ -601,17 +620,6 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool = assert num_local_experts == num_global_experts, ( "num_local_experts must be equal to num_global_experts during MoE import" ) - - """ - if parallel_config is not None: - etp_size = get_expert_tensor_parallel_world_size() - # etp_rank = get_expert_tensor_parallel_rank() # this gives group rank - etp_rank = dist.get_rank() - print(f"etp_size: {etp_size}") - print(f"etp_rank: {etp_rank}") - assert num_local_experts * etp_size == num_global_experts - init_index = etp_rank * num_local_experts - """ init_index = 0 self.rules["experts.linear_fc1"]( @@ -619,12 +627,14 @@ def _import_transformer_layer(self, layer, 
layer_id, layer_pbar, is_mtp: bool = layer_id, init_expert_id=init_index, num_local_experts=num_local_experts, + is_mtp=is_mtp, ) self.rules["experts.linear_fc2"]( experts.linear_fc2, layer_id, init_expert_id=init_index, num_local_experts=num_local_experts, + is_mtp=is_mtp, ) # We only support either EP or ETP for now @@ -634,27 +644,26 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool = if self.arch in ["GptOssForCausalLM"]: raise ValueError("ETP is not supported for gpt-oss model") self.rules["local_experts.linear_fc1_etp"]( - layer.mlp.experts.local_experts, layer_id + layer.mlp.experts.local_experts, layer_id, is_mtp=is_mtp ) self.rules["local_experts.linear_fc2_etp"]( - layer.mlp.experts.local_experts, layer_id + layer.mlp.experts.local_experts, layer_id, is_mtp=is_mtp ) else: # EP supports for packed MoE self.rules["local_experts.linear_fc1_ep"]( - layer.mlp.experts.local_experts, layer_id + layer.mlp.experts.local_experts, layer_id, is_mtp=is_mtp ) self.rules["local_experts.linear_fc2_ep"]( - layer.mlp.experts.local_experts, layer_id + layer.mlp.experts.local_experts, layer_id, is_mtp=is_mtp ) else: layer_pbar.set_description("Importing MLP") - self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id) - self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id) + self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id, is_mtp=is_mtp) + self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id, is_mtp=is_mtp) def _import_state_dict(self): model = self.model - # print(model, flush=True) layer_pbar = tqdm(model.decoder.layers, disable=self.disable_tqdm) # Embedding @@ -664,7 +673,7 @@ def _import_state_dict(self): # Decoder layers for layer in layer_pbar: - print(f"Importing layer {layer.layer_number}", flush=True) + layer_pbar.set_description(f"Importing Decoder layer {layer.layer_number}") layer_id = layer.layer_number - 1 if isinstance(layer, MambaLayer): @@ -692,8 +701,7 @@ def _import_state_dict(self): # MTP if hasattr(model, "mtp"): - print("Importing MTP", flush=True) - # MTP is the last layer in DeepSeek V3/R1 + layer_pbar.set_description("Importing MTP") if len(model.mtp.layers) == 1: # Repeated MTP layer_id = 0 # reset layer_id for repeated MTP mtp = model.mtp.layers[0] @@ -704,19 +712,18 @@ def _import_state_dict(self): mtp_model_layers = mtp.mtp_model_layer.layers for mtp_model_layer in mtp_model_layers: - if isinstance(mtp_model_layer, MambaLayer): - self._import_mamba_layer(mtp_model_layer, layer_id, layer_pbar) - elif isinstance(mtp_model_layer, TransformerLayer): + if isinstance(mtp_model_layer, TransformerLayer): self._import_transformer_layer( mtp_model_layer, layer_id, layer_pbar, is_mtp=True ) else: raise ValueError( - f"Unsupported layer type during MTP import: {type(mtp_model_layer)}" + f"Unsupported layer type during MTP import: {type(mtp_model_layer)}. Only TransformerLayer is supported." 
) layer_id += 1 else: # non-repeated MTP + # MTP is the last layer in DeepSeek V3/R1 for mtp in model.mtp.layers: self.rules["mtp.eh_proj"](mtp.eh_proj, layer_id) self.rules["mtp.enorm"](mtp.enorm, layer_id) diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index 370a3e753..ac168ed2a 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -400,7 +400,7 @@ def get_kv_cache_bias(kv_module: nn.Module) -> list[torch.Tensor]: return kv_bias -def get_kv_cache_scaling_factor(self_attention_module: nn.Module) -> torch.Tensor: +def get_kv_cache_scaling_factor(self_attention_module: nn.Module) -> list[torch.Tensor | None]: """Get the K and V BMM scaling factors for the self attention module. Args: diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index 16fe0f6ad..13d31237b 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -243,9 +243,7 @@ def save_pretrained( state_dict = self.extra_state_dict if self.export_extra_modules else self.state_dict quantization_format = self._get_quantization_format(self.model) - quantization = None - if quantization_format in ( QUANTIZATION_FP8_PB_REAL, QUANTIZATION_FP8_PB_WO, @@ -289,18 +287,8 @@ def save_pretrained( except (OSError, ValueError, ImportError): pass - # Gather exclude_modules from all ranks to ensure hf_quant_config is complete - all_exclude_modules = [None] * torch.distributed.get_world_size() - torch.distributed.all_gather_object(all_exclude_modules, self.exclude_modules) - combined_exclude_modules = set() - for modules in all_exclude_modules: - if modules: - combined_exclude_modules.update(modules) - self.exclude_modules = sorted(combined_exclude_modules) - if is_last_stage_main_rank and quantization is not None: - # TODO refactor to use mte.quant_utils.get_quant_config - # except layer names are different in MCore and HF + self._gather_exclude_modules() # gather exclude_modules from all ranks self._hf_quant_config = { "producer": { "name": "modelopt", @@ -489,7 +477,8 @@ def _get_state_dict(self): raise ValueError("Only TransformerLayer or MambaLayer are supported.") # MTP layer - self._get_mtp_state_dict() + if self._hf_pretrained_model_name is not None: + self._get_mtp_state_dict() def _get_transformer_layer_state_dict(self, layer, layer_id): if not isinstance(layer.input_layernorm, IdentityOp): @@ -567,7 +556,7 @@ def _get_transformer_layer_state_dict(self, layer, layer_id): def _get_mtp_state_dict(self): """Export the MTP module. - Currently, we copy the BF16 MTP weights from the pretrained model. + Currently, we copy the BF16 MTP weights from the pretrained model if the pretrained model has MTP layers. 
""" # TODO Implement MTP export for quantized MTP # Hacky version for now: copy MTP weights from pretrained model @@ -835,13 +824,9 @@ def _name_remapping( name_to_value, qformat, block_size = self._get_quantized_state(module, dtype, prefix=prefix) weight = name_to_value.pop("weight") - if "mixer.gate" in prefix: - print(f"{prefix}: weight dtype: {weight.dtype}") weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat) if weight_scale is None: - if "mixer.gate" in prefix: - print(f"{prefix}: weight_scale is None") self._state_dict[prefix + "weight"] = weight else: self._state_dict[prefix + "weight"] = to_quantized_weight( @@ -865,8 +850,6 @@ def _name_remapping( source_key = mapping.get(key, key) self._state_dict[prefix + source_key] = val print(f"{prefix + source_key}: {self._state_dict[prefix + source_key].dtype}") - if "mixer.gate" in prefix: - print(f"{prefix}weight: {self._state_dict[prefix + 'weight'].dtype}") def _gated_mlp_slicing( self, module, prefix, gate_proj_name="gate_proj", up_proj_name="up_proj" @@ -1244,6 +1227,16 @@ def _pack_name_remapping_gpt_oss(self, module, prefix, layer_type=None): # TODO: May need to modify the key name later. self._state_dict[prefix + "_bias"] = merged_bias + def _gather_exclude_modules(self): + """Get exclude_modules from all ranks to ensure hf_quant_config is complete.""" + all_exclude_modules = [None] * torch.distributed.get_world_size() + torch.distributed.all_gather_object(all_exclude_modules, self.exclude_modules) + combined_exclude_modules = set() + for modules in all_exclude_modules: + if modules: + combined_exclude_modules.update(modules) + return sorted(combined_exclude_modules) + def export_mcore_gpt_to_hf( model: torch.nn.Module, From 45de68f0d7e4128b9f1820d1204c0e2d37424fd5 Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Thu, 29 Jan 2026 15:16:42 -0800 Subject: [PATCH 15/20] fix bugs Signed-off-by: jenchen13 --- modelopt/torch/export/plugins/mcore_llama.py | 2 + .../torch/export/plugins/mcore_nemotron.py | 3 +- .../torch/export/plugins/megatron_importer.py | 13 +++- modelopt/torch/export/quant_utils.py | 2 +- .../torch/export/unified_export_megatron.py | 57 +++++++------- modelopt/torch/quantization/model_calib.py | 2 +- .../torch/quantization/plugins/megatron.py | 11 ++- tests/gpu/torch/export/test_export.py | 2 +- .../export/test_unified_export_megatron.py | 75 +++++++++++++++---- .../quantization/plugins/test_megatron.py | 3 - 10 files changed, 117 insertions(+), 53 deletions(-) diff --git a/modelopt/torch/export/plugins/mcore_llama.py b/modelopt/torch/export/plugins/mcore_llama.py index 03a2c5fe7..c4a1ff82b 100644 --- a/modelopt/torch/export/plugins/mcore_llama.py +++ b/modelopt/torch/export/plugins/mcore_llama.py @@ -38,6 +38,8 @@ "input_layernorm": NameRemapping("model.layers.{}.input_layernorm."), "linear_qkv": QKVSlicing("model.layers.{}.self_attn."), "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj."), + # KV cache quant export + "core_attention": NameRemapping("model.layers.{}.self_attn."), "pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm."), "linear_fc1": GatedMLPSlicing("model.layers.{}.mlp."), "linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj."), diff --git a/modelopt/torch/export/plugins/mcore_nemotron.py b/modelopt/torch/export/plugins/mcore_nemotron.py index 70bd00fb2..92611f54b 100644 --- a/modelopt/torch/export/plugins/mcore_nemotron.py +++ b/modelopt/torch/export/plugins/mcore_nemotron.py @@ -37,6 +37,7 @@ "input_layernorm": 
NameRemapping("model.layers.{}.input_layernorm."), "linear_qkv": QKVSlicing("model.layers.{}.self_attn."), "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj."), + "core_attention": SelfAttentionScaling("backbone.layers.{}.mixer."), "pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm."), # NemotronForCausalLM is using square-relu where no gated handle is needed. "linear_fc1": NameRemapping("model.layers.{}.mlp.up_proj."), @@ -100,8 +101,6 @@ ), } -# TODO ADD MTP export - nemotron_h_causal_lm_export: dict[str, CustomModuleMapping] = { "word_embeddings": NameRemapping("backbone.embeddings."), "final_norm": NameRemapping("backbone.norm_f."), diff --git a/modelopt/torch/export/plugins/megatron_importer.py b/modelopt/torch/export/plugins/megatron_importer.py index 254b0504d..b4c1ec694 100644 --- a/modelopt/torch/export/plugins/megatron_importer.py +++ b/modelopt/torch/export/plugins/megatron_importer.py @@ -223,7 +223,14 @@ def _gated_mlp_merging( gate_proj_name="gate_proj", up_proj_name="up_proj", parallel_config: ParallelConfig | None = None, + is_mtp: bool = False, ): + if is_mtp: + if "backbone" in prefix: + prefix = prefix.replace("backbone", "mtp") + else: + prefix = prefix.replace("model", "mtp") + weight = module.state_dict().get("weight", None) weight_scale = module.state_dict().get("weight_quantizer._scale", None) @@ -425,6 +432,7 @@ def _unpack_name_remapping( prefix, layer_type: str, parallel_config: ParallelConfig | None = None, + is_mtp: bool = False, # no-op: necessary for _import_transformer_layer ): tensor = self._get_safetensor(prefix, parallel_config=parallel_config) @@ -455,6 +463,7 @@ def _unpack_name_remapping_gpt_oss( prefix, layer_type: str, parallel_config: ParallelConfig | None = None, + is_mtp: bool = False, # no-op: necessary for _import_transformer_layer ): tensor_blocks = self._get_safetensor(prefix + "_blocks", parallel_config=parallel_config) tensor_bias = self._get_safetensor(prefix + "_bias", parallel_config=parallel_config) @@ -718,12 +727,14 @@ def _import_state_dict(self): ) else: raise ValueError( - f"Unsupported layer type during MTP import: {type(mtp_model_layer)}. Only TransformerLayer is supported." + f"Unsupported layer type during MTP import: {type(mtp_model_layer)}.\n" + "Only TransformerLayer is supported." ) layer_id += 1 else: # non-repeated MTP # MTP is the last layer in DeepSeek V3/R1 + layer_id += 1 for mtp in model.mtp.layers: self.rules["mtp.eh_proj"](mtp.eh_proj, layer_id) self.rules["mtp.enorm"](mtp.enorm, layer_id) diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index ac168ed2a..146fb6e19 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -451,7 +451,7 @@ def get_kv_cache_dtype(modules: list[nn.Module] | nn.Module) -> str | None: return _compute_kv_cache_dtype(num_bits_list) -def _compute_kv_cache_dtype(num_bits_list: list[int]) -> str | None: +def _compute_kv_cache_dtype(num_bits_list: list[int | tuple[int, int]]) -> str | None: """Returns the kv_cache dtype. 
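For reference, the num_bits-to-dtype mapping that the _compute_kv_cache_dtype docstring above describes can be sketched as follows. This is an illustrative approximation only, with plain string labels assumed; the actual constants and edge-case handling live in quant_utils.py and model_config:

def kv_cache_dtype_sketch(num_bits_list: list[int | tuple[int, int]]) -> str | None:
    # (4, 3) denotes an FP8 E4M3 output quantizer; 8 denotes an int8 one.
    if not num_bits_list:
        return None
    if all(bits == (4, 3) for bits in num_bits_list):
        return "FP8"
    if all(bits == 8 for bits in num_bits_list):
        return "int8"
    return None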
If num_bits of output_quantizer is (4, 3) then returns FP8; if it is 8, returns int8, diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index 13d31237b..dc5264fbf 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -101,6 +101,8 @@ class GPTModelExporter: export_extra_modules: If True, export extra modules like medusa_heads or eagle_module. Otherwise, only export the base model. dtype: The weights data type to export the unquantized layers. + trust_remote_code: Whether to trust remote code in the HuggingFace pretrained model. + moe_router_dtype: The data type of the MoE router. Can be "fp32", "fp64", or None (default to the model dtype). """ def __init__( @@ -110,7 +112,7 @@ def __init__( export_extra_modules: bool = False, dtype=torch.bfloat16, trust_remote_code: bool = False, - moe_router_dtype: torch.dtype | None = None, + moe_router_dtype: str | None = None, ): """Create a GPTModel exporter instance.""" if not isinstance(model, (GPTModel, MambaModel, LLaVAModel)): @@ -476,9 +478,8 @@ def _get_state_dict(self): else: raise ValueError("Only TransformerLayer or MambaLayer are supported.") - # MTP layer - if self._hf_pretrained_model_name is not None: - self._get_mtp_state_dict() + # Get MTP layer if exists + self._get_mtp_state_dict() def _get_transformer_layer_state_dict(self, layer, layer_id): if not isinstance(layer.input_layernorm, IdentityOp): @@ -508,7 +509,10 @@ def _get_transformer_layer_state_dict(self, layer, layer_id): self.rules["q_layernorm"](layer.self_attention.q_layernorm, layer_id) self.rules["k_layernorm"](layer.self_attention.k_layernorm, layer_id) self.rules["linear_qkv"](layer.self_attention.linear_qkv, layer_id) - if hasattr(layer.self_attention, "core_attention"): + if ( + hasattr(layer.self_attention, "core_attention") + and "core_attention" in self.rules + ): # KV cache quant export self.rules["core_attention"](layer.self_attention.core_attention, layer_id) self.rules["linear_proj"](layer.self_attention.linear_proj, layer_id) if getattr(layer.self_attention.core_attention, "softmax_offset", None) is not None: @@ -560,28 +564,29 @@ def _get_mtp_state_dict(self): """ # TODO Implement MTP export for quantized MTP # Hacky version for now: copy MTP weights from pretrained model - if os.path.isdir(self._hf_pretrained_model_name): - safetensors_index_file = ( - Path(self._hf_pretrained_model_name) / "model.safetensors.index.json" - ) - else: - safetensors_index_file = hf_hub_download( - repo_id=self._hf_pretrained_model_name, filename="model.safetensors.index.json" - ) + if self._hf_pretrained_model_name: + if os.path.isdir(self._hf_pretrained_model_name): + safetensors_index_file = ( + Path(self._hf_pretrained_model_name) / "model.safetensors.index.json" + ) + else: + safetensors_index_file = hf_hub_download( + repo_id=self._hf_pretrained_model_name, filename="model.safetensors.index.json" + ) - print(f"Exporting MTP: using safetensors_index_file: {safetensors_index_file}") - mtp_exists = False - if safetensors_index_file and os.path.exists(safetensors_index_file): - with open(safetensors_index_file) as f: - safetensors_index = json.load(f) - model_dir = Path(safetensors_index_file).parent - for key in safetensors_index["weight_map"]: - if key.startswith("mtp.") and key not in self._state_dict: - self._state_dict[key] = get_safetensor(model_dir, key) - mtp_exists = True - - if mtp_exists: - self.exclude_modules.append("mtp*") + 
print(f"Exporting MTP: using safetensors_index_file: {safetensors_index_file}") + mtp_exists = False + if safetensors_index_file and os.path.exists(safetensors_index_file): + with open(safetensors_index_file) as f: + safetensors_index = json.load(f) + model_dir = Path(safetensors_index_file).parent + for key in safetensors_index["weight_map"]: + if key.startswith("mtp.") and key not in self._state_dict: + self._state_dict[key] = get_safetensor(model_dir, key) + mtp_exists = True + + if mtp_exists: + self.exclude_modules.append("mtp*") def _get_mamba_layer_state_dict(self, layer, layer_id): if not isinstance(layer.norm, IdentityOp): diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 386fe8c2c..88ebaa906 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -84,7 +84,7 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis # Step 1: Sync amax across local experts in a SequentialMLP for name, module in model.named_modules(): if hasattr(module, "sync_moe_local_experts_amax"): - module.sync_moe_local_experts_amax() + module.layer_sync_moe_local_experts_amax() if not distributed_sync: return diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 6fa99e819..44784409b 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -575,12 +575,15 @@ def _setup(self): expert.linear_fc1.parallel_state = self.parallel_state expert.linear_fc2.parallel_state = self.parallel_state - def sync_moe_local_experts_amax(self): + def layer_sync_moe_local_experts_amax(self): """Sync amax across local experts in a SequentialMLP. - amax across EP and ETP (for RowParallel) are synchronized as part of model_calib.max_calibrate(). - This function is called to synchronize the amax values across local experts s.t. all localexperts will - share the same amax. + Synchronize the amax values across local experts in a lyaer such that all local experts will + share the same amax. This function operates on a single rank and does not require distributed sync. + + Distributed amax sync across EP and ETP (for RowParallel) happens in model_calib.max_calibrate(). + This function should be called before the distributed sync to ensure the amax values + are synchronized across the layer first. 
""" # Collect amax from all local experts amax_dict = {} diff --git a/tests/gpu/torch/export/test_export.py b/tests/gpu/torch/export/test_export.py index 19aa63505..9d2afba41 100644 --- a/tests/gpu/torch/export/test_export.py +++ b/tests/gpu/torch/export/test_export.py @@ -222,7 +222,7 @@ def test_get_scaling_factor_from_weight(weight, group_size, expected): KV_CACHE_FP8, 128.0, { - "layer1.k_proj.k_scale": torch.tensor([1.0]), + "layer1.k_proj.k_scale": torch.tensor([0.001]), "layer1.v_proj.v_scale": torch.tensor([2.0]), }, ), diff --git a/tests/gpu/torch/export/test_unified_export_megatron.py b/tests/gpu/torch/export/test_unified_export_megatron.py index c07c2b565..700a23124 100644 --- a/tests/gpu/torch/export/test_unified_export_megatron.py +++ b/tests/gpu/torch/export/test_unified_export_megatron.py @@ -16,6 +16,7 @@ import json from copy import deepcopy from functools import partial +from pathlib import Path import pytest import torch @@ -27,14 +28,39 @@ skip_if_no_megatron(apex_or_te_required=True) +import modelopt.torch.quantization as mtq import modelopt.torch.speculative as mtsp -from modelopt.torch.export import export_mcore_gpt_to_hf, import_mcore_gpt_from_hf +from modelopt.torch.export import KV_CACHE_FP8, export_mcore_gpt_to_hf, import_mcore_gpt_from_hf from modelopt.torch.speculative.eagle.default_config import default_eagle_config from modelopt.torch.speculative.plugins.megatron_eagle import _DynamicEagleGPTModel from modelopt.torch.speculative.plugins.megatron_medusa import _DynamicMedusaGPTModel -def _test_unified_export_megatron(tmp_path, model_type, arch, algo, rank, size): +def _verify_model_configs( + export_dir: Path, quant_config: str | None = None, kv_cache_quant_cfg: str | None = None +): + """Verify config.json and hf_quant_config.json""" + config_dict = json.load(open(export_dir / "config.json")) + hf_quant_config_dict = json.load(open(export_dir / "hf_quant_config.json")) + assert config_dict["quantization_config"] == hf_quant_config_dict + assert hf_quant_config_dict["producer"]["name"] == "modelopt" + if quant_config: + quant_config_dict = hf_quant_config_dict["quantization"] + quant_type = quant_config_dict["quant_algo"] + assert ( + quant_type in quant_config + ) # quant config str is subset of quant config e.g. 
NVFP4 -> NVFP4_DEFAULT_CFG + assert len(quant_config_dict["exclude_modules"]) > 1 # Dynamically added exclude modules + if quant_type == "NVFP4": + assert quant_config_dict["group_size"] == 16 + + if kv_cache_quant_cfg: + assert quant_config["kv_cache_quant_algo"] == KV_CACHE_FP8 + + +def _test_unified_export_megatron( + tmp_path, model_type, arch, extra_module, quant_config, kv_cache_quant_cfg, rank, size +): num_layers = 2 hidden_size = 64 num_attention_heads = 8 @@ -63,14 +89,23 @@ def _test_unified_export_megatron(tmp_path, model_type, arch, algo, rank, size): transformer_impl="modelopt", ).cuda() - if algo == "medusa": + if quant_config is not None: + quant_config_dict = getattr(mtq, quant_config) + if kv_cache_quant_cfg: + kv_quant_cfg = getattr(mtq, kv_cache_quant_cfg)["quant_cfg"] + quant_config_dict = mtq.utils.update_quant_cfg_with_kv_cache_quant( + quant_config_dict, kv_quant_cfg + ) + model = mtq.quantize(model, quant_config_dict) + + if extra_module == "medusa": config = { "medusa_num_heads": 1, "medusa_num_layers": 1, } model = mtsp.convert(model, [("medusa", config)]) assert isinstance(model, _DynamicMedusaGPTModel) - elif algo == "eagle": + elif extra_module == "eagle": config = {"eagle_architecture_config": deepcopy(default_eagle_config)} model = mtsp.convert(model, [("eagle", config)]) assert isinstance(model, _DynamicEagleGPTModel) @@ -91,25 +126,35 @@ def _test_unified_export_megatron(tmp_path, model_type, arch, algo, rank, size): with open(tmp_path / "config.json", "w") as f: json.dump(pretrained_config, f) + tmp_export_dir = tmp_path / "export" export_mcore_gpt_to_hf( model, - tmp_path if arch is not None else None, + tmp_path / "pretrained" if arch is not None else None, dtype=torch.bfloat16, + export_dir=tmp_export_dir, ) + _verify_model_configs(tmp_export_dir, quant_config, kv_cache_quant_cfg) + @pytest.mark.parametrize( - ("model_type", "arch", "algo"), + ("model_type", "arch", "extra_module", "quant_config", "kv_cache_quant_cfg"), [ - ("nemotron", "NemotronForCausalLM", None), - ("nemotron", "NemotronForCausalLM", "eagle"), - ("nemotron", "NemotronForCausalLM", "medusa"), - ("llama", "LlamaForCausalLM", None), - ("llama", "LlamaForCausalLM", "eagle"), - ("llama", "LlamaForCausalLM", "medusa"), + ("nemotron", "NemotronForCausalLM", None, None, None), + ("nemotron", "NemotronForCausalLM", None, "NVFP4_DEFAULT_CFG", None), + ("nemotron", "NemotronForCausalLM", None, "NVFP4_DEFAULT_CFG", "FP8_KV_CFG"), + ("nemotron", "NemotronForCausalLM", "eagle", None, None), + ("nemotron", "NemotronForCausalLM", "medusa", None, None), + ("llama", "LlamaForCausalLM", None, None, None), + ("llama", "LlamaForCausalLM", None, "FP8_DEFAULT_CFG", None), + ("llama", "LlamaForCausalLM", None, "FP8_DEFAULT_CFG", "FP8_KV_CFG"), + ("llama", "LlamaForCausalLM", "eagle", None, None), + ("llama", "LlamaForCausalLM", "medusa", None, None), ], ) -def test_unified_export_megatron(tmp_path, model_type, arch, algo): +def test_unified_export_megatron( + tmp_path, model_type, arch, extra_module, quant_config, kv_cache_quant_cfg +): # TODO: Fix TP>1 failures spawn_multiprocess_job( size=1, # torch.cuda.device_count(), @@ -118,7 +163,9 @@ def test_unified_export_megatron(tmp_path, model_type, arch, algo): tmp_path, model_type, arch, - algo, + extra_module, + quant_config, + kv_cache_quant_cfg, ), backend="nccl", ) diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index d02b02c18..c07fd9b53 100644 --- 
a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -841,9 +841,6 @@ def _test_expert_model_parallel_amax_sync( ) # calibrate the model with distributed sync and test synchronization mtq.model_calib.max_calibrate(model, forward, distributed_sync=True) - for module in model.modules(): - if hasattr(module, "sync_moe_local_experts_amax"): - module.sync_moe_local_experts_amax() final_sync, quantizer_type, rank_values = compare_amax_sync_across_expert_parallel(model) assert final_sync, f"Inconsistent amax for expert {quantizer_type} across ranks: {rank_values}" From 14c3d8457c370dd815f5cfc7e67f30b524da7528 Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Fri, 30 Jan 2026 10:44:08 -0800 Subject: [PATCH 16/20] restore kv cache clamp Signed-off-by: jenchen13 --- modelopt/torch/export/quant_utils.py | 23 +++++++++++++++++++ tests/gpu/torch/export/test_export.py | 2 +- .../export/test_unified_export_megatron.py | 2 +- 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index 146fb6e19..ad9a19d7f 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -418,6 +418,18 @@ def get_kv_cache_scaling_factor(self_attention_module: nn.Module) -> list[torch. get_scaling_factor(getattr(self_attention_module, quantizer)) for quantizer in ("k_bmm_quantizer", "v_bmm_quantizer") ] + + # For FP8, we recommend default kv cache scaling factor to be 1. + if get_kv_cache_dtype(kv_module) == KV_CACHE_FP8: + for i, factor in enumerate(scaling_factors): + if factor.item() > 0.5: + warn( + f"Warning: Large KV activation detected: {factor.item()}, " + "Quantized KV cache may lead to higher accuracy drop." + ) + scaling_factors[i] = torch.max( + factor, torch.tensor([1.0], dtype=torch.float, device=factor.device) + ) return scaling_factors @@ -1020,6 +1032,17 @@ def postprocess_state_dict( value = value.float() / maxbound + # Warn if scale exceeds threshold + if quantization == KV_CACHE_FP8 and value.item() > 0.5: + logger.warning( + "Large KV activations detected. Quantized KV cache may lead to higher accuracy drop. " + "Setting KV cache scaling factor to at least 1." 
+ ) + + # Ensure scale is at least 1 for KV_CACHE_FP8 + # We export real value for KV_CACHE_NVFP4 + if quantization == KV_CACHE_FP8: + value.clamp_(min=1.0) post_state_dict[prefix + new_suffix] = value break diff --git a/tests/gpu/torch/export/test_export.py b/tests/gpu/torch/export/test_export.py index 9d2afba41..19aa63505 100644 --- a/tests/gpu/torch/export/test_export.py +++ b/tests/gpu/torch/export/test_export.py @@ -222,7 +222,7 @@ def test_get_scaling_factor_from_weight(weight, group_size, expected): KV_CACHE_FP8, 128.0, { - "layer1.k_proj.k_scale": torch.tensor([0.001]), + "layer1.k_proj.k_scale": torch.tensor([1.0]), "layer1.v_proj.v_scale": torch.tensor([2.0]), }, ), diff --git a/tests/gpu/torch/export/test_unified_export_megatron.py b/tests/gpu/torch/export/test_unified_export_megatron.py index 700a23124..4ed587c4b 100644 --- a/tests/gpu/torch/export/test_unified_export_megatron.py +++ b/tests/gpu/torch/export/test_unified_export_megatron.py @@ -129,7 +129,7 @@ def _test_unified_export_megatron( tmp_export_dir = tmp_path / "export" export_mcore_gpt_to_hf( model, - tmp_path / "pretrained" if arch is not None else None, + tmp_path if arch is not None else None, dtype=torch.bfloat16, export_dir=tmp_export_dir, ) From 5d1c704abacacba37c8a995879fe55c193ce930e Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Fri, 30 Jan 2026 11:19:49 -0800 Subject: [PATCH 17/20] fix bug Signed-off-by: jenchen13 --- modelopt/torch/export/quant_utils.py | 2 +- tests/gpu/torch/export/test_unified_export_megatron.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index ad9a19d7f..0d99d44f0 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -420,7 +420,7 @@ def get_kv_cache_scaling_factor(self_attention_module: nn.Module) -> list[torch. ] # For FP8, we recommend default kv cache scaling factor to be 1. 
- if get_kv_cache_dtype(kv_module) == KV_CACHE_FP8: + if get_kv_cache_dtype(self_attention_module) == KV_CACHE_FP8: for i, factor in enumerate(scaling_factors): if factor.item() > 0.5: warn( diff --git a/tests/gpu/torch/export/test_unified_export_megatron.py b/tests/gpu/torch/export/test_unified_export_megatron.py index 4ed587c4b..c0bef3bb6 100644 --- a/tests/gpu/torch/export/test_unified_export_megatron.py +++ b/tests/gpu/torch/export/test_unified_export_megatron.py @@ -131,7 +131,7 @@ def _test_unified_export_megatron( model, tmp_path if arch is not None else None, dtype=torch.bfloat16, - export_dir=tmp_export_dir, + export_dir=str(tmp_export_dir), ) _verify_model_configs(tmp_export_dir, quant_config, kv_cache_quant_cfg) From e7c32a993ea80ef8f5345bba64e6395e0c44a549 Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Fri, 30 Jan 2026 12:10:42 -0800 Subject: [PATCH 18/20] fxi bug Signed-off-by: jenchen13 --- modelopt/torch/export/plugins/mcore_llama.py | 3 ++- tests/gpu/torch/export/test_unified_export_megatron.py | 7 ++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/modelopt/torch/export/plugins/mcore_llama.py b/modelopt/torch/export/plugins/mcore_llama.py index c4a1ff82b..7fb8ec76a 100644 --- a/modelopt/torch/export/plugins/mcore_llama.py +++ b/modelopt/torch/export/plugins/mcore_llama.py @@ -30,6 +30,7 @@ PackNameRemapping, QKVMerging, QKVSlicing, + SelfAttentionScaling, UnpackNameRemapping, ) @@ -39,7 +40,7 @@ "linear_qkv": QKVSlicing("model.layers.{}.self_attn."), "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj."), # KV cache quant export - "core_attention": NameRemapping("model.layers.{}.self_attn."), + "core_attention": SelfAttentionScaling("model.layers.{}.self_attn."), "pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm."), "linear_fc1": GatedMLPSlicing("model.layers.{}.mlp."), "linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj."), diff --git a/tests/gpu/torch/export/test_unified_export_megatron.py b/tests/gpu/torch/export/test_unified_export_megatron.py index c0bef3bb6..0004fb781 100644 --- a/tests/gpu/torch/export/test_unified_export_megatron.py +++ b/tests/gpu/torch/export/test_unified_export_megatron.py @@ -36,7 +36,7 @@ from modelopt.torch.speculative.plugins.megatron_medusa import _DynamicMedusaGPTModel -def _verify_model_configs( +def _verify_model_quant_config( export_dir: Path, quant_config: str | None = None, kv_cache_quant_cfg: str | None = None ): """Verify config.json and hf_quant_config.json""" @@ -89,7 +89,7 @@ def _test_unified_export_megatron( transformer_impl="modelopt", ).cuda() - if quant_config is not None: + if quant_config: quant_config_dict = getattr(mtq, quant_config) if kv_cache_quant_cfg: kv_quant_cfg = getattr(mtq, kv_cache_quant_cfg)["quant_cfg"] @@ -134,7 +134,8 @@ def _test_unified_export_megatron( export_dir=str(tmp_export_dir), ) - _verify_model_configs(tmp_export_dir, quant_config, kv_cache_quant_cfg) + if quant_config: + _verify_model_quant_config(tmp_export_dir, quant_config, kv_cache_quant_cfg) @pytest.mark.parametrize( From d54846272b53ed4796cf69a51d0688dfa42e7ce0 Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Fri, 30 Jan 2026 12:33:50 -0800 Subject: [PATCH 19/20] fix bug Signed-off-by: jenchen13 --- .../gpu/torch/export/test_unified_export_megatron.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/gpu/torch/export/test_unified_export_megatron.py b/tests/gpu/torch/export/test_unified_export_megatron.py index 0004fb781..723ab9d67 100644 --- 
a/tests/gpu/torch/export/test_unified_export_megatron.py +++ b/tests/gpu/torch/export/test_unified_export_megatron.py @@ -42,8 +42,15 @@ def _verify_model_quant_config( """Verify config.json and hf_quant_config.json""" config_dict = json.load(open(export_dir / "config.json")) hf_quant_config_dict = json.load(open(export_dir / "hf_quant_config.json")) - assert config_dict["quantization_config"] == hf_quant_config_dict - assert hf_quant_config_dict["producer"]["name"] == "modelopt" + # Make sure config.json and hf_quant_config.json are consistent + assert config_dict["quantization_config"]["quant_algo"] == hf_quant_config_dict["quantization"]["quant_algo"] + assert config_dict["quantization_config"]["ignore"] == hf_quant_config_dict["quantization"]["exclude_modules"] + + # Verify config.json + if kv_cache_quant_cfg: + assert config_dict["quantization_config"]["kv_cache_scheme"]["num_bits"] == 8 + + # Verify hf_quant_config.json if quant_config: quant_config_dict = hf_quant_config_dict["quantization"] quant_type = quant_config_dict["quant_algo"] From 9ad2d9f7efe48c8d77bc6601424d6e6999206779 Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Fri, 30 Jan 2026 13:50:44 -0800 Subject: [PATCH 20/20] fix tests Signed-off-by: jenchen13 --- tests/_test_utils/torch/megatron/utils.py | 41 ++++++++++++++++++ .../export/test_unified_export_megatron.py | 14 ++++-- .../quantization/plugins/test_megatron.py | 43 +------------------ 3 files changed, 53 insertions(+), 45 deletions(-) diff --git a/tests/_test_utils/torch/megatron/utils.py b/tests/_test_utils/torch/megatron/utils.py index bb91f83cd..63904ba44 100644 --- a/tests/_test_utils/torch/megatron/utils.py +++ b/tests/_test_utils/torch/megatron/utils.py @@ -129,6 +129,47 @@ def run_mcore_inference_with_dummy_input( return run_mcore_inference(model, prompt_tokens, hidden_size) +def get_batch(model, batch_size=2): + seq_length = model.max_sequence_length + vocab_size = model.vocab_size + + input_ids = torch.randint(0, vocab_size, (batch_size, seq_length)).cuda() + labels = torch.randint(0, vocab_size, (batch_size, seq_length)).cuda() + position_ids = ( + torch.arange(seq_length, dtype=torch.int64).unsqueeze(0).repeat(batch_size, 1).cuda() + ) + attention_mask = torch.tril( + torch.ones((batch_size, 1, seq_length, seq_length), dtype=torch.bool) + ).cuda() + loss_mask = torch.ones((batch_size, seq_length), dtype=torch.float32).cuda() + + return input_ids, labels, position_ids, attention_mask, loss_mask + + +def get_forward(model, batch_size=2): + """Return a forward function with cached batch inputs.""" + input_ids, labels, position_ids, attention_mask, loss_mask = get_batch(model, batch_size) + + def forward(model): + # MambaModel doesn't accept loss_mask argument + if isinstance(model, MambaModel): + return model.forward( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + labels=labels, + ) + return model.forward( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + labels=labels, + loss_mask=loss_mask, + ) + + return forward + + def initialize_for_megatron( tensor_model_parallel_size=1, pipeline_model_parallel_size=1, diff --git a/tests/gpu/torch/export/test_unified_export_megatron.py b/tests/gpu/torch/export/test_unified_export_megatron.py index 723ab9d67..54d6c16f2 100644 --- a/tests/gpu/torch/export/test_unified_export_megatron.py +++ b/tests/gpu/torch/export/test_unified_export_megatron.py @@ -24,6 +24,7 @@ from _test_utils.import_helper import skip_if_no_megatron from 
_test_utils.torch.distributed.utils import spawn_multiprocess_job from _test_utils.torch.megatron.models import get_mcore_gpt_model +from _test_utils.torch.megatron.utils import get_forward from _test_utils.torch.transformers_models import create_tiny_llama_dir skip_if_no_megatron(apex_or_te_required=True) @@ -43,8 +44,14 @@ def _verify_model_quant_config( config_dict = json.load(open(export_dir / "config.json")) hf_quant_config_dict = json.load(open(export_dir / "hf_quant_config.json")) # Make sure config.json and hf_quant_config.json are consistent - assert config_dict["quantization_config"]["quant_algo"] == hf_quant_config_dict["quantization"]["quant_algo"] - assert config_dict["quantization_config"]["ignore"] == hf_quant_config_dict["quantization"]["exclude_modules"] + assert ( + config_dict["quantization_config"]["quant_algo"] + == hf_quant_config_dict["quantization"]["quant_algo"] + ) + assert ( + config_dict["quantization_config"]["ignore"] + == hf_quant_config_dict["quantization"]["exclude_modules"] + ) # Verify config.json if kv_cache_quant_cfg: @@ -103,7 +110,8 @@ def _test_unified_export_megatron( quant_config_dict = mtq.utils.update_quant_cfg_with_kv_cache_quant( quant_config_dict, kv_quant_cfg ) - model = mtq.quantize(model, quant_config_dict) + forward = get_forward(model) + model = mtq.quantize(model, quant_config_dict, forward) if extra_module == "medusa": config = { diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index c07fd9b53..2ac5915e4 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -21,7 +21,6 @@ from _test_utils.import_helper import skip_if_no_megatron from _test_utils.torch.distributed.utils import spawn_multiprocess_job from _test_utils.torch.megatron.models import ( - MambaModel, MegatronModel, get_mcore_gpt_model, get_mcore_mamba_hybrid_model, @@ -29,6 +28,7 @@ from _test_utils.torch.megatron.utils import ( compare_amax_sync_across_expert_parallel, copy_weights_from_grouped_to_non_grouped, + get_forward, initialize_for_megatron, run_mcore_inference, sharded_state_dict_test_helper, @@ -69,47 +69,6 @@ SEED = 1234 -def get_batch(model, batch_size=2): - seq_length = model.max_sequence_length - vocab_size = model.vocab_size - - input_ids = torch.randint(0, vocab_size, (batch_size, seq_length)).cuda() - labels = torch.randint(0, vocab_size, (batch_size, seq_length)).cuda() - position_ids = ( - torch.arange(seq_length, dtype=torch.int64).unsqueeze(0).repeat(batch_size, 1).cuda() - ) - attention_mask = torch.tril( - torch.ones((batch_size, 1, seq_length, seq_length), dtype=torch.bool) - ).cuda() - loss_mask = torch.ones((batch_size, seq_length), dtype=torch.float32).cuda() - - return input_ids, labels, position_ids, attention_mask, loss_mask - - -def get_forward(model, batch_size=2): - """Return a forward function with cached batch inputs.""" - input_ids, labels, position_ids, attention_mask, loss_mask = get_batch(model, batch_size) - - def forward(model): - # MambaModel doesn't accept loss_mask argument - if isinstance(model, MambaModel): - return model.forward( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - labels=labels, - ) - return model.forward( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - labels=labels, - loss_mask=loss_mask, - ) - - return forward - - def test_convert_megatron_parallel_linear(distributed_setup_size_1): 
initialize_for_megatron(seed=SEED) set_seed(SEED)