diff --git a/modelopt/torch/export/plugins/mcore_custom.py b/modelopt/torch/export/plugins/mcore_custom.py
index 23804b322..c269cef1d 100644
--- a/modelopt/torch/export/plugins/mcore_custom.py
+++ b/modelopt/torch/export/plugins/mcore_custom.py
@@ -103,6 +103,18 @@ def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any]
         )
 
 
+class GroupedMLPMerging(CustomModuleMapping):
+    """A custom module mapping that merges per-expert weights into a grouped MLP."""
+
+    def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] = {}):
+        """Create a custom module mapping that merges per-expert weights into a grouped MLP."""
+        super().__init__(
+            func_name="grouped_mlp_merging",
+            target_name_or_prefix=target_name_or_prefix,
+            func_kwargs=func_kwargs,
+        )
+
+
 class GatedMLPMerging(CustomModuleMapping):
     """A custom module mapping that merges gate_proj and up_proj."""
 
@@ -127,6 +139,18 @@ def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any]
         )
 
 
+class SelfAttentionScaling(CustomModuleMapping):
+    """A custom module mapping that exports KV cache scaling factors from self attention."""
+
+    def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] = {}):
+        """Create a custom module mapping that exports KV cache scaling factors."""
+        super().__init__(
+            func_name="self_attention_scaling",
+            target_name_or_prefix=target_name_or_prefix,
+            func_kwargs=func_kwargs,
+        )
+
+
 class GatedMLPSlicing(CustomModuleMapping):
     """A custom module mapping that slices gate_proj and up_proj."""
 
diff --git a/modelopt/torch/export/plugins/mcore_llama.py b/modelopt/torch/export/plugins/mcore_llama.py
index 03a2c5fe7..7fb8ec76a 100644
--- a/modelopt/torch/export/plugins/mcore_llama.py
+++ b/modelopt/torch/export/plugins/mcore_llama.py
@@ -30,6 +30,7 @@
     PackNameRemapping,
     QKVMerging,
     QKVSlicing,
+    SelfAttentionScaling,
     UnpackNameRemapping,
 )
 
@@ -38,6 +39,8 @@
     "input_layernorm": NameRemapping("model.layers.{}.input_layernorm."),
     "linear_qkv": QKVSlicing("model.layers.{}.self_attn."),
     "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj."),
+    # KV cache quant export
+    "core_attention": SelfAttentionScaling("model.layers.{}.self_attn."),
     "pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm."),
     "linear_fc1": GatedMLPSlicing("model.layers.{}.mlp."),
     "linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj."),
diff --git a/modelopt/torch/export/plugins/mcore_nemotron.py b/modelopt/torch/export/plugins/mcore_nemotron.py
index 5fdb8ba1b..92611f54b 100644
--- a/modelopt/torch/export/plugins/mcore_nemotron.py
+++ b/modelopt/torch/export/plugins/mcore_nemotron.py
@@ -23,9 +23,11 @@
     ROW_ETP,
     ROW_TP,
     CustomModuleMapping,
+    GroupedMLPMerging,
     NameRemapping,
     QKVMerging,
     QKVSlicing,
+    SelfAttentionScaling,
 )
 
 # Example on adding a new CausalLM.
@@ -35,6 +37,7 @@
     "input_layernorm": NameRemapping("model.layers.{}.input_layernorm."),
     "linear_qkv": QKVSlicing("model.layers.{}.self_attn."),
     "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj."),
+    "core_attention": SelfAttentionScaling("model.layers.{}.self_attn."),
     "pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm."),
     # NemotronForCausalLM is using square-relu where no gated handle is needed.
"linear_fc1": NameRemapping("model.layers.{}.mlp.up_proj."), @@ -81,9 +84,23 @@ "shared_experts.linear_fc2": NameRemapping( "backbone.layers.{}.mixer.shared_experts.down_proj.", ROW_TP ), + # Latent MoE + "fc1_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc1_latent_proj.", REPLICATE), + "fc2_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc2_latent_proj.", REPLICATE), + # Repeated MTP module + "mtp.enorm": NameRemapping("mtp.layers.{}.enorm.", {"is_mtp": True}), + "mtp.hnorm": NameRemapping("mtp.layers.{}.hnorm.", {"is_mtp": True}), + "mtp.eh_proj": NameRemapping("mtp.layers.{}.eh_proj.", {"is_mtp": True}), + "mtp.final_layernorm": NameRemapping("mtp.layers.{}.final_layernorm.", {"is_mtp": True}), + # Grouped local experts in MTP + "experts.linear_fc1": GroupedMLPMerging( + "mtp.layers.{}.mixer.experts.{{}}.up_proj", COL_ETP | {"is_mtp": True} + ), + "experts.linear_fc2": GroupedMLPMerging( + "mtp.layers.{}.mixer.experts.{{}}.down_proj", ROW_ETP | {"is_mtp": True} + ), } - nemotron_h_causal_lm_export: dict[str, CustomModuleMapping] = { "word_embeddings": NameRemapping("backbone.embeddings."), "final_norm": NameRemapping("backbone.norm_f."), @@ -101,6 +118,7 @@ "input_layernorm": NameRemapping("backbone.layers.{}.norm."), "linear_qkv": QKVSlicing("backbone.layers.{}.mixer."), "linear_proj": NameRemapping("backbone.layers.{}.mixer.o_proj."), + "core_attention": SelfAttentionScaling("backbone.layers.{}.mixer."), # MLP "pre_mlp_layernorm": NameRemapping("backbone.layers.{}.norm."), "linear_fc1": NameRemapping("backbone.layers.{}.mixer.up_proj."), @@ -115,4 +133,12 @@ "shared_experts.linear_fc2": NameRemapping( "backbone.layers.{}.mixer.shared_experts.down_proj." ), + # Latent MoE + "fc1_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc1_latent_proj."), + "fc2_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc2_latent_proj."), + # MTP + "mtp.enorm": NameRemapping("mtp.layers.{}.enorm."), + "mtp.hnorm": NameRemapping("mtp.layers.{}.hnorm."), + "mtp.eh_proj": NameRemapping("mtp.layers.{}.eh_proj."), + "mtp.final_layernorm": NameRemapping("mtp.layers.{}.final_layernorm."), } diff --git a/modelopt/torch/export/plugins/megatron_importer.py b/modelopt/torch/export/plugins/megatron_importer.py index f663e1921..b4c1ec694 100644 --- a/modelopt/torch/export/plugins/megatron_importer.py +++ b/modelopt/torch/export/plugins/megatron_importer.py @@ -19,7 +19,7 @@ from pathlib import Path import torch -import torch.distributed +import torch.distributed as dist from huggingface_hub import snapshot_download from tqdm import tqdm @@ -94,12 +94,12 @@ def __init__( if workspace_dir is None: workspace_dir = tempfile.gettempdir() pretrained_model_path = workspace_dir + "/" + pretrained_model_name_or_path - if torch.distributed.get_rank() == 0: + if dist.get_rank() == 0: snapshot_download( repo_id=pretrained_model_name_or_path, local_dir=pretrained_model_path, ) - torch.distributed.barrier() + dist.barrier() self.arch = self._hf_config.architectures[0] self.all_rules = self._populate_rule_book() self.rules = self.all_rules[self.arch] @@ -108,7 +108,7 @@ def __init__( self.dtype = dtype self.dequantize = dequantize self.verbose = verbose - self.disable_tqdm = torch.distributed.get_rank() > 0 or verbose + self.disable_tqdm = dist.get_rank() > 0 or verbose def _populate_rule_book(self): """The rule book maps each state_dict key to a Callable.""" @@ -119,6 +119,7 @@ def _custom_mapping_to_lambda(mapping): "name_remapping": self._name_remapping, "qkv_merging": self._qkv_merging, 
"gated_mlp_merging": self._gated_mlp_merging, + "grouped_mlp_merging": self._grouped_mlp_merging, "unpack_name_remapping": self._unpack_name_remapping, "unpack_name_remapping_gpt_oss": self._unpack_name_remapping_gpt_oss, } @@ -150,7 +151,13 @@ def _name_remapping( mapping={}, parallel_config: ParallelConfig | None = None, dtype: torch.dtype | None = None, + is_mtp: bool = False, ): + if is_mtp: + if "backbone" in prefix: + prefix = prefix.replace("backbone", "mtp") + else: + prefix = prefix.replace("model", "mtp") if dtype is None: dtype = self.dtype if isinstance(module, torch.Tensor): @@ -183,7 +190,7 @@ def _name_remapping( tensor = expanded_tensor state_dict["weight"] = tensor.view(dtype=weight.dtype).to(device=weight.device) else: - state_dict["weight"] = tensor.to(dtype=self.dtype).to(device=weight.device) + state_dict["weight"] = tensor.to(dtype=dtype).to(device=weight.device) # Handle the rest of the state_dict. for key, val in module.state_dict().items(): @@ -216,7 +223,14 @@ def _gated_mlp_merging( gate_proj_name="gate_proj", up_proj_name="up_proj", parallel_config: ParallelConfig | None = None, + is_mtp: bool = False, ): + if is_mtp: + if "backbone" in prefix: + prefix = prefix.replace("backbone", "mtp") + else: + prefix = prefix.replace("model", "mtp") + weight = module.state_dict().get("weight", None) weight_scale = module.state_dict().get("weight_quantizer._scale", None) @@ -254,6 +268,33 @@ def _gated_mlp_merging( module.load_state_dict(state_dict) + def _grouped_mlp_merging( + self, + module, + prefix, + parallel_config: ParallelConfig | None = None, + is_mtp: bool = False, + init_expert_id: int = 0, + num_local_experts: int = 1, + ): + if is_mtp: + if "backbone" in prefix: + prefix = prefix.replace("backbone", "mtp") + else: + prefix = prefix.replace("model", "mtp") + + state_dict = module.state_dict() + + assert module.num_gemms == num_local_experts, ( + "num_gemms must be equal to num_local_experts in TEGroupedMLP" + ) + for expert_id in range(init_expert_id, init_expert_id + num_local_experts): + tensor = self._get_safetensor(prefix.format(expert_id) + ".weight") + state_dict[f"weight{expert_id}"] = tensor + # TODO handle weight_scale + + module.load_state_dict(state_dict) + def _qkv_merging( self, module, @@ -262,7 +303,13 @@ def _qkv_merging( k_proj_name="k_proj", v_proj_name="v_proj", parallel_config: ParallelConfig | None = None, + is_mtp: bool = False, ): + if is_mtp: + if "backbone" in prefix: + prefix = prefix.replace("backbone", "mtp") + else: + prefix = prefix.replace("model", "mtp") config = module.config hidden_size = config.hidden_size num_query_groups = config.num_query_groups @@ -289,8 +336,9 @@ def _qkv_merging( state_dict = {} - weight = module.state_dict().get("weight", None) - weight_scale = module.state_dict().get("weight_quantizer._scale", None) + module_state_dict = module.state_dict() + weight = module_state_dict.get("weight", None) + weight_scale = module_state_dict.get("weight_quantizer._scale", None) if weight is None: raise ValueError(f"{module!s} does not contain weight!") @@ -344,7 +392,7 @@ def _qkv_merging( state_dict["weight"] = tensor.reshape(-1, hidden_size) # Handle bias merging - bias = module.state_dict().get("bias", None) + bias = module_state_dict.get("bias", None) if bias is not None: q_bias = self._get_safetensor( prefix + q_proj_name + ".bias", parallel_config=parallel_config @@ -371,6 +419,11 @@ def _qkv_merging( state_dict["bias"] = bias_tensor.reshape(-1) + layer_norm_weight = module_state_dict.get("layer_norm_weight", None) 
+ if layer_norm_weight is not None: + state_dict["layer_norm_weight"] = layer_norm_weight + state_dict["_extra_state"] = None # for TE modules require _extra_state key + module.load_state_dict(state_dict) def _unpack_name_remapping( @@ -379,6 +432,7 @@ def _unpack_name_remapping( prefix, layer_type: str, parallel_config: ParallelConfig | None = None, + is_mtp: bool = False, # no-op: necessary for _import_transformer_layer ): tensor = self._get_safetensor(prefix, parallel_config=parallel_config) @@ -409,6 +463,7 @@ def _unpack_name_remapping_gpt_oss( prefix, layer_type: str, parallel_config: ParallelConfig | None = None, + is_mtp: bool = False, # no-op: necessary for _import_transformer_layer ): tensor_blocks = self._get_safetensor(prefix + "_blocks", parallel_config=parallel_config) tensor_bias = self._get_safetensor(prefix + "_bias", parallel_config=parallel_config) @@ -469,9 +524,155 @@ def _unpack_name_remapping_gpt_oss( linear_module.load_state_dict(state_dict) + def _import_mamba_layer(self, layer, layer_id, layer_pbar): + layer_pbar.set_description("Importing Mamba layer") + if not isinstance(layer.norm, IdentityOp): + self.rules["norm"](layer.norm, layer_id) + + self.rules["mixer_norm"](layer.mixer.norm, layer_id) + self.rules["A_log"](layer.mixer.A_log, layer_id) + self.rules["D"](layer.mixer.D, layer_id) + self.rules["dt_bias"](layer.mixer.dt_bias, layer_id) + self.rules["conv1d"](layer.mixer.conv1d, layer_id) + self.rules["in_proj"](layer.mixer.in_proj, layer_id) + self.rules["out_proj"](layer.mixer.out_proj, layer_id) + + def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool = False): + if not isinstance(layer.input_layernorm, IdentityOp): + self.rules["input_layernorm"](layer.input_layernorm, layer_id, is_mtp=is_mtp) + + attention = layer.self_attention + if not isinstance(attention, IdentityOp): + if "MLASelfAttention" in str(type(attention)): + if hasattr(attention, "linear_q_proj"): + layer_pbar.set_description("Importing MLA (without q LoRA)") + self.rules["linear_q_proj"](attention.linear_q_proj, layer_id, is_mtp=is_mtp) + else: + layer_pbar.set_description("Importing MLA (with q LoRA)") + self.rules["linear_q_down_proj"]( + attention.linear_q_down_proj, layer_id, is_mtp=is_mtp + ) + self.rules["linear_q_layernorm"](attention.q_layernorm, layer_id, is_mtp=is_mtp) + self.rules["linear_q_up_proj"]( + attention.linear_q_up_proj, layer_id, is_mtp=is_mtp + ) + self.rules["linear_kv_down_proj"]( + attention.linear_kv_down_proj, layer_id, is_mtp=is_mtp + ) + self.rules["linear_kv_layernorm"](attention.kv_layernorm, layer_id, is_mtp=is_mtp) + self.rules["linear_kv_up_proj"]( + attention.linear_kv_up_proj, layer_id, is_mtp=is_mtp + ) + self.rules["linear_proj"](attention.linear_proj, layer_id, is_mtp=is_mtp) + else: + layer_pbar.set_description("Importing GQA/MHA") + if attention.q_layernorm is not None and not isinstance( + attention.q_layernorm, (IdentityOp, L2Norm) + ): + self.rules["q_layernorm"](attention.q_layernorm, layer_id, is_mtp=is_mtp) + self.rules["k_layernorm"](attention.k_layernorm, layer_id, is_mtp=is_mtp) + self.rules["linear_qkv"](attention.linear_qkv, layer_id, is_mtp=is_mtp) + self.rules["linear_proj"](attention.linear_proj, layer_id, is_mtp=is_mtp) + if getattr(attention.core_attention, "softmax_offset", None) is not None: + self.rules["softmax_offset"]( + attention.core_attention.softmax_offset, layer_id, is_mtp=is_mtp + ) + + if not isinstance(layer.pre_mlp_layernorm, IdentityOp): + 
self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id, is_mtp=is_mtp) + + if not isinstance(layer.mlp, IdentityOp): + if "MoE" in str(type(layer.mlp)): + layer_pbar.set_description( + f"Importing MoE with moe_router_dtype: {self.moe_router_dtype}" + ) + self.rules["router"]( + layer.mlp.router, layer_id, dtype=self.moe_router_dtype, is_mtp=is_mtp + ) + if hasattr(layer.mlp, "fc1_latent_proj") and layer.mlp.fc1_latent_proj is not None: + self.rules["fc1_latent_proj"]( + layer.mlp.fc1_latent_proj, layer_id, is_mtp=is_mtp + ) + if hasattr(layer.mlp, "fc2_latent_proj") and layer.mlp.fc2_latent_proj is not None: + self.rules["fc2_latent_proj"]( + layer.mlp.fc2_latent_proj, layer_id, is_mtp=is_mtp + ) + + if hasattr(layer.mlp, "shared_experts") and layer.mlp.shared_experts is not None: + layer_pbar.set_description("Importing MoE shared experts") + fc1 = layer.mlp.shared_experts.linear_fc1 + fc2 = layer.mlp.shared_experts.linear_fc2 + self.rules["shared_experts.linear_fc1"](fc1, layer_id, is_mtp=is_mtp) + self.rules["shared_experts.linear_fc2"](fc2, layer_id, is_mtp=is_mtp) + if not self.rules.get("use_packed_local_experts", False): # Import local experts + experts = layer.mlp.experts + if hasattr(experts, "local_experts"): + for local_expert_id, expert in tqdm( + enumerate(layer.mlp.experts.local_experts), + desc="Importing MoE local experts", + leave=False, + disable=self.disable_tqdm, + ): + expert_id = layer.mlp.local_expert_indices[local_expert_id] + fc1 = expert.linear_fc1 + fc2 = expert.linear_fc2 + self.rules["local_experts.linear_fc1"]( + fc1, layer_id, expert_id, is_mtp=is_mtp + ) + self.rules["local_experts.linear_fc2"]( + fc2, layer_id, expert_id, is_mtp=is_mtp + ) + else: # Slice TEGroupedMLP + layer_pbar.set_description("Importing MoE grouped local experts") + num_local_experts = experts.num_local_experts + num_global_experts = experts.config.num_moe_experts + assert num_local_experts == num_global_experts, ( + "num_local_experts must be equal to num_global_experts during MoE import" + ) + init_index = 0 + + self.rules["experts.linear_fc1"]( + experts.linear_fc1, + layer_id, + init_expert_id=init_index, + num_local_experts=num_local_experts, + is_mtp=is_mtp, + ) + self.rules["experts.linear_fc2"]( + experts.linear_fc2, + layer_id, + init_expert_id=init_index, + num_local_experts=num_local_experts, + is_mtp=is_mtp, + ) + + # We only support either EP or ETP for now + elif get_expert_tensor_parallel_world_size() > 1: + # ETP supports for packed MoE + # ETP is not supported for gpt-oss model + if self.arch in ["GptOssForCausalLM"]: + raise ValueError("ETP is not supported for gpt-oss model") + self.rules["local_experts.linear_fc1_etp"]( + layer.mlp.experts.local_experts, layer_id, is_mtp=is_mtp + ) + self.rules["local_experts.linear_fc2_etp"]( + layer.mlp.experts.local_experts, layer_id, is_mtp=is_mtp + ) + else: + # EP supports for packed MoE + self.rules["local_experts.linear_fc1_ep"]( + layer.mlp.experts.local_experts, layer_id, is_mtp=is_mtp + ) + self.rules["local_experts.linear_fc2_ep"]( + layer.mlp.experts.local_experts, layer_id, is_mtp=is_mtp + ) + else: + layer_pbar.set_description("Importing MLP") + self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id, is_mtp=is_mtp) + self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id, is_mtp=is_mtp) + def _import_state_dict(self): model = self.model - layer_pbar = tqdm(model.decoder.layers, disable=self.disable_tqdm) # Embedding @@ -481,113 +682,18 @@ def _import_state_dict(self): # Decoder layers for layer in 
layer_pbar: + layer_pbar.set_description(f"Importing Decoder layer {layer.layer_number}") layer_id = layer.layer_number - 1 if isinstance(layer, MambaLayer): - if not isinstance(layer.norm, IdentityOp): - self.rules["norm"](layer.norm, layer_id) - - self.rules["mixer_norm"](layer.mixer.norm, layer_id) - self.rules["A_log"](layer.mixer.A_log, layer_id) - self.rules["D"](layer.mixer.D, layer_id) - self.rules["dt_bias"](layer.mixer.dt_bias, layer_id) - - self.rules["conv1d"](layer.mixer.conv1d, layer_id) - self.rules["in_proj"](layer.mixer.in_proj, layer_id) - self.rules["out_proj"](layer.mixer.out_proj, layer_id) - + self._import_mamba_layer(layer, layer_id, layer_pbar) elif isinstance(layer, TransformerLayer): - if not isinstance(layer.input_layernorm, IdentityOp): - self.rules["input_layernorm"](layer.input_layernorm, layer_id) - - attention = layer.self_attention - if not isinstance(attention, IdentityOp): - if "MLASelfAttention" in str(type(attention)): - if hasattr(attention, "linear_q_proj"): - layer_pbar.set_description("Importing MLA (without q LoRA)") - self.rules["linear_q_proj"](attention.linear_q_proj, layer_id) - else: - layer_pbar.set_description("Importing MLA (with q LoRA)") - self.rules["linear_q_down_proj"](attention.linear_q_down_proj, layer_id) - self.rules["linear_q_layernorm"](attention.q_layernorm, layer_id) - self.rules["linear_q_up_proj"](attention.linear_q_up_proj, layer_id) - self.rules["linear_kv_down_proj"](attention.linear_kv_down_proj, layer_id) - self.rules["linear_kv_layernorm"](attention.kv_layernorm, layer_id) - self.rules["linear_kv_up_proj"](attention.linear_kv_up_proj, layer_id) - self.rules["linear_proj"](attention.linear_proj, layer_id) - else: - layer_pbar.set_description("Importing GQA/MHA") - if attention.q_layernorm is not None and not isinstance( - attention.q_layernorm, (IdentityOp, L2Norm) - ): - self.rules["q_layernorm"](attention.q_layernorm, layer_id) - self.rules["k_layernorm"](attention.k_layernorm, layer_id) - self.rules["linear_qkv"](attention.linear_qkv, layer_id) - self.rules["linear_proj"](attention.linear_proj, layer_id) - if getattr(attention.core_attention, "softmax_offset", None) is not None: - self.rules["softmax_offset"]( - attention.core_attention.softmax_offset, layer_id - ) - - if not isinstance(layer.pre_mlp_layernorm, IdentityOp): - self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id) - - if not isinstance(layer.mlp, IdentityOp): - if "MoE" in str(type(layer.mlp)): - layer_pbar.set_description("Importing MoE") - self.rules["router"]( - layer.mlp.router, layer_id, dtype=self.moe_router_dtype - ) - if ( - hasattr(layer.mlp, "shared_experts") - and layer.mlp.shared_experts is not None - ): - layer_pbar.set_description("Importing MoE shared experts") - fc1 = layer.mlp.shared_experts.linear_fc1 - fc2 = layer.mlp.shared_experts.linear_fc2 - self.rules["shared_experts.linear_fc1"](fc1, layer_id) - self.rules["shared_experts.linear_fc2"](fc2, layer_id) - if not self.rules.get("use_packed_local_experts", False): - for local_expert_id, expert in tqdm( - enumerate(layer.mlp.experts.local_experts), - desc="Importing MoE local experts", - leave=False, - disable=self.disable_tqdm, - ): - expert_id = layer.mlp.local_expert_indices[local_expert_id] - fc1 = expert.linear_fc1 - fc2 = expert.linear_fc2 - self.rules["local_experts.linear_fc1"](fc1, layer_id, expert_id) - self.rules["local_experts.linear_fc2"](fc2, layer_id, expert_id) - # We only support either EP or ETP for now - elif get_expert_tensor_parallel_world_size() > 1: 
- # ETP supports for packed MoE - # ETP is not supported for gpt-oss model - if self.arch in ["GptOssForCausalLM"]: - raise ValueError("ETP is not supported for gpt-oss model") - self.rules["local_experts.linear_fc1_etp"]( - layer.mlp.experts.local_experts, layer_id - ) - self.rules["local_experts.linear_fc2_etp"]( - layer.mlp.experts.local_experts, layer_id - ) - else: - # EP supports for packed MoE - self.rules["local_experts.linear_fc1_ep"]( - layer.mlp.experts.local_experts, layer_id - ) - self.rules["local_experts.linear_fc2_ep"]( - layer.mlp.experts.local_experts, layer_id - ) - else: - layer_pbar.set_description("Importing MLP") - self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id) - self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id) + self._import_transformer_layer(layer, layer_id, layer_pbar) if self.verbose: print( "{:3}/{:3} completes importing layer {:3}.".format( - torch.distributed.get_rank(), torch.distributed.get_world_size(), layer_id + dist.get_rank(), dist.get_world_size(), layer_id ), flush=True, ) @@ -595,67 +701,92 @@ def _import_state_dict(self): # Final layernorm if hasattr(model.decoder, "final_layernorm") and model.decoder.final_layernorm: self.rules["final_layernorm"](model.decoder.final_layernorm) - if hasattr(model.decoder, "final_norm") and model.decoder.final_norm: self.rules["final_norm"](model.decoder.final_norm) # Output layer if hasattr(model, "output_layer") and not model.share_embeddings_and_output_weights: self.rules["output_layer"](model.output_layer) + # MTP if hasattr(model, "mtp"): - # MTP is the last layer in DeepSeek V3/R1 - layer_id += 1 - for mtp in model.mtp: - self.rules["mtp.fc"](mtp.fc, layer_id) + layer_pbar.set_description("Importing MTP") + if len(model.mtp.layers) == 1: # Repeated MTP + layer_id = 0 # reset layer_id for repeated MTP + mtp = model.mtp.layers[0] + + self.rules["mtp.eh_proj"](mtp.eh_proj, layer_id) self.rules["mtp.enorm"](mtp.enorm, layer_id) self.rules["mtp.hnorm"](mtp.hnorm, layer_id) - self.rules["mtp.input_layernorm"](mtp.decoder.layers[0].input_layernorm, layer_id) - if hasattr(mtp.decoder.layers[0].self_attention, "linear_q_proj"): - self.rules["mtp.linear_q_proj"]( - mtp.decoder.layers[0].self_attention.linear_q_proj, layer_id + + mtp_model_layers = mtp.mtp_model_layer.layers + for mtp_model_layer in mtp_model_layers: + if isinstance(mtp_model_layer, TransformerLayer): + self._import_transformer_layer( + mtp_model_layer, layer_id, layer_pbar, is_mtp=True + ) + else: + raise ValueError( + f"Unsupported layer type during MTP import: {type(mtp_model_layer)}.\n" + "Only TransformerLayer is supported." 
+ ) + + layer_id += 1 + else: # non-repeated MTP + # MTP is the last layer in DeepSeek V3/R1 + layer_id += 1 + for mtp in model.mtp.layers: + self.rules["mtp.eh_proj"](mtp.eh_proj, layer_id) + self.rules["mtp.enorm"](mtp.enorm, layer_id) + self.rules["mtp.hnorm"](mtp.hnorm, layer_id) + self.rules["mtp.input_layernorm"]( + mtp.decoder.layers[0].input_layernorm, layer_id ) - else: - self.rules["mtp.linear_q_down_proj"]( - mtp.decoder.layers[0].self_attention.linear_q_down_proj, layer_id + if hasattr(mtp.decoder.layers[0].self_attention, "linear_q_proj"): + self.rules["mtp.linear_q_proj"]( + mtp.decoder.layers[0].self_attention.linear_q_proj, layer_id + ) + else: + self.rules["mtp.linear_q_down_proj"]( + mtp.decoder.layers[0].self_attention.linear_q_down_proj, layer_id + ) + self.rules["mtp.linear_q_layernorm"]( + mtp.decoder.layers[0].self_attention.q_layernorm, layer_id + ) + self.rules["mtp.linear_q_up_proj"]( + mtp.decoder.layers[0].self_attention.linear_q_up_proj, layer_id + ) + self.rules["mtp.linear_kv_down_proj"]( + mtp.decoder.layers[0].self_attention.linear_kv_down_proj, layer_id ) - self.rules["mtp.linear_q_layernorm"]( - mtp.decoder.layers[0].self_attention.q_layernorm, layer_id + self.rules["mtp.linear_kv_layernorm"]( + mtp.decoder.layers[0].self_attention.kv_layernorm, layer_id ) - self.rules["mtp.linear_q_up_proj"]( - mtp.decoder.layers[0].self_attention.linear_q_up_proj, layer_id + self.rules["mtp.linear_kv_up_proj"]( + mtp.decoder.layers[0].self_attention.linear_kv_up_proj, layer_id ) - self.rules["mtp.linear_kv_down_proj"]( - mtp.decoder.layers[0].self_attention.linear_kv_down_proj, layer_id - ) - self.rules["mtp.linear_kv_layernorm"]( - mtp.decoder.layers[0].self_attention.kv_layernorm, layer_id - ) - self.rules["mtp.linear_kv_up_proj"]( - mtp.decoder.layers[0].self_attention.linear_kv_up_proj, layer_id - ) - self.rules["mtp.linear_proj"]( - mtp.decoder.layers[0].self_attention.linear_proj, layer_id - ) - self.rules["mtp.pre_mlp_layernorm"]( - mtp.decoder.layers[0].pre_mlp_layernorm, layer_id - ) - self.rules["mtp.router"](mtp.decoder.layers[0].mlp.router, layer_id) - self.rules["mtp.shared_experts.linear_fc1"]( - mtp.decoder.layers[0].mlp.shared_experts.linear_fc1, layer_id - ) - self.rules["mtp.shared_experts.linear_fc2"]( - mtp.decoder.layers[0].mlp.shared_experts.linear_fc2, layer_id - ) - for expert_id, expert in tqdm( - enumerate(mtp.decoder.layers[0].mlp.experts.local_experts), - desc="Importing MoE local experts", - leave=False, - disable=self.disable_tqdm, - ): - self.rules["mtp.local_experts.linear_fc1"]( - expert.linear_fc1, layer_id, expert_id + self.rules["mtp.linear_proj"]( + mtp.decoder.layers[0].self_attention.linear_proj, layer_id + ) + self.rules["mtp.pre_mlp_layernorm"]( + mtp.decoder.layers[0].pre_mlp_layernorm, layer_id + ) + self.rules["mtp.router"](mtp.decoder.layers[0].mlp.router, layer_id) + self.rules["mtp.shared_experts.linear_fc1"]( + mtp.decoder.layers[0].mlp.shared_experts.linear_fc1, layer_id ) - self.rules["mtp.local_experts.linear_fc2"]( - expert.linear_fc2, layer_id, expert_id + self.rules["mtp.shared_experts.linear_fc2"]( + mtp.decoder.layers[0].mlp.shared_experts.linear_fc2, layer_id ) + for expert_id, expert in tqdm( + enumerate(mtp.decoder.layers[0].mlp.experts.local_experts), + desc="Importing MoE local experts", + leave=False, + disable=self.disable_tqdm, + ): + self.rules["mtp.local_experts.linear_fc1"]( + expert.linear_fc1, layer_id, expert_id + ) + self.rules["mtp.local_experts.linear_fc2"]( + expert.linear_fc2, layer_id, expert_id 
+ ) diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index e7587d222..0d99d44f0 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -400,18 +400,27 @@ def get_kv_cache_bias(kv_module: nn.Module) -> list[torch.Tensor]: return kv_bias -def get_kv_cache_scaling_factor(kv_module: nn.Module) -> list[torch.Tensor]: - """Returns the kv_cache scaling factor if output quantizer is set. Else returns None by default.""" - if not hasattr(kv_module, "k_bmm_quantizer") or not hasattr(kv_module, "v_bmm_quantizer"): +def get_kv_cache_scaling_factor(self_attention_module: nn.Module) -> list[torch.Tensor | None]: + """Get the K and V BMM scaling factors for the self attention module. + + Args: + self_attention_module: The self attention module to get the K and V BMM scaling factors from. + + Returns: + The K and V BMM scaling factors. + """ + if not hasattr(self_attention_module, "k_bmm_quantizer") or not hasattr( + self_attention_module, "v_bmm_quantizer" + ): return [None, None] scaling_factors = [ - get_scaling_factor(getattr(kv_module, quantizer)) + get_scaling_factor(getattr(self_attention_module, quantizer)) for quantizer in ("k_bmm_quantizer", "v_bmm_quantizer") ] # For FP8, we recommend default kv cache scaling factor to be 1. - if get_kv_cache_dtype(kv_module) == KV_CACHE_FP8: + if get_kv_cache_dtype(self_attention_module) == KV_CACHE_FP8: for i, factor in enumerate(scaling_factors): if factor.item() > 0.5: warn( @@ -421,7 +430,6 @@ def get_kv_cache_scaling_factor(kv_module: nn.Module) -> list[torch.Tensor]: scaling_factors[i] = torch.max( factor, torch.tensor([1.0], dtype=torch.float, device=factor.device) ) - return scaling_factors @@ -452,6 +460,23 @@ def get_kv_cache_dtype(modules: list[nn.Module] | nn.Module) -> str | None: num_bits_list.append(quantizer_attr.num_bits) is_affine &= hasattr(quantizer_attr, "_bias_value") + return _compute_kv_cache_dtype(num_bits_list) + + +def _compute_kv_cache_dtype(num_bits_list: list[int | tuple[int, int]]) -> str | None: + """Returns the kv_cache dtype. + + If num_bits of output_quantizer is (4, 3) then returns FP8; if it is 8, returns int8, + otherwise returns None. + + Args: + modules: The module or list of modules to inspect. + + Returns: + The kv_cache dtype. 
+ """ + is_affine = True + if (4, 3) in num_bits_list: return KV_CACHE_FP8 elif 8 in num_bits_list: @@ -1018,7 +1043,6 @@ def postprocess_state_dict( # We export real value for KV_CACHE_NVFP4 if quantization == KV_CACHE_FP8: value.clamp_(min=1.0) - post_state_dict[prefix + new_suffix] = value break diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index 7b3883509..dc5264fbf 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -25,11 +25,9 @@ from collections import OrderedDict from pathlib import Path from typing import Any -from warnings import warn import torch import torch.distributed -import torch.nn as nn from huggingface_hub import hf_hub_download, snapshot_download from safetensors.torch import safe_open, save_file from tqdm import tqdm @@ -37,6 +35,7 @@ from modelopt import __version__ from modelopt.torch.utils import import_plugin +from .convert_hf_config import convert_hf_quant_config_format from .model_config import ( KV_CACHE_FP8, KV_CACHE_NVFP4, @@ -47,13 +46,13 @@ QUANTIZATION_NVFP4, ) from .plugins.mcore_common import all_mcore_hf_export_mapping -from .plugins.mcore_custom import CustomModuleMapping, save_safetensors +from .plugins.mcore_custom import CustomModuleMapping, get_safetensor, save_safetensors from .plugins.megatron_importer import GPTModelImporter from .quant_utils import ( get_activation_scaling_factor, get_kv_cache_dtype, + get_kv_cache_scaling_factor, get_quantization_format, - get_scaling_factor, get_weight_block_size, get_weight_scaling_factor, get_weight_scaling_factor_2, @@ -86,33 +85,6 @@ ] -# This path uses output_quantizer for KV cache quantization. -# The function below is the old version of get_kv_cache_scaling_factor which is now refactored to handle bmm_quantizer. -def get_kv_cache_scaling_factor(kv_module: nn.Module) -> torch.Tensor: - """Returns the kv_cache scaling factor if output quantizer is set. Else returns None by default.""" - scaling_factor = ( - get_scaling_factor(kv_module.output_quantizer) - if hasattr(kv_module, "output_quantizer") - else None - ) - - if not scaling_factor: - return None - - # For FP8, we recommend default kv cache scaling factor to be 1. - if get_kv_cache_dtype(kv_module) == KV_CACHE_FP8: - if scaling_factor.item() > 0.5: - warn( - f"!!!!Large KV activations detected: {scaling_factor.item()}, " - "Quantized KV cache may lead to higher accuracy drop.\n!!!!" - ) - scaling_factor = torch.max( - scaling_factor, - torch.tensor([1.0], dtype=torch.float, device=scaling_factor.device), - ) - return scaling_factor - - class GPTModelExporter: """Megatron Core GPTModel Exporter. @@ -129,6 +101,8 @@ class GPTModelExporter: export_extra_modules: If True, export extra modules like medusa_heads or eagle_module. Otherwise, only export the base model. dtype: The weights data type to export the unquantized layers. + trust_remote_code: Whether to trust remote code in the HuggingFace pretrained model. + moe_router_dtype: The data type of the MoE router. Can be "fp32", "fp64", or None (default to the model dtype). 
""" def __init__( @@ -138,7 +112,7 @@ def __init__( export_extra_modules: bool = False, dtype=torch.bfloat16, trust_remote_code: bool = False, - moe_router_dtype: torch.dtype | None = None, + moe_router_dtype: str | None = None, ): """Create a GPTModel exporter instance.""" if not isinstance(model, (GPTModel, MambaModel, LLaVAModel)): @@ -154,6 +128,7 @@ def __init__( self.moe_router_dtype = torch.float32 elif moe_router_dtype == "fp64": self.moe_router_dtype = torch.float64 + print(f"Exporting model with moe_router_dtype: {self.moe_router_dtype}") # If multimodal, extra the text_config self._hf_text_config = getattr(self._hf_config, "text_config", self._hf_config) @@ -181,6 +156,7 @@ def __init__( del self._hf_config.quantization_config self.all_rules = self._populate_rule_book() self.rules = self.all_rules[self.arch] + self.exclude_modules = [] if not hasattr(model, "_modelopt_state"): return @@ -267,10 +243,9 @@ def save_pretrained( # Main export process state_dict = self.extra_state_dict if self.export_extra_modules else self.state_dict - quantization_format = self._get_quantization_format(self.model) + quantization_format = self._get_quantization_format(self.model) quantization = None - if quantization_format in ( QUANTIZATION_FP8_PB_REAL, QUANTIZATION_FP8_PB_WO, @@ -281,11 +256,6 @@ def save_pretrained( elif quantization_format == QUANTIZATION_NVFP4: quantization = "NVFP4" - kv_cache_quantization = None - kv_cache_dtype = get_kv_cache_dtype(self.model) - if kv_cache_dtype in (KV_CACHE_FP8, KV_CACHE_NVFP4): - # FP8 KV Cache is supported in VLLM; NVFP4 supported in TRTLLM - kv_cache_quantization = kv_cache_dtype # We use the last PP rank and the 1st EP rank to write the config because # medusa_heads and eagle_module only exist in the last stage. 
if is_last_stage_main_rank: @@ -320,19 +290,23 @@ def save_pretrained( pass if is_last_stage_main_rank and quantization is not None: - hf_quant_config = { + self._gather_exclude_modules() # gather exclude_modules from all ranks + self._hf_quant_config = { "producer": { "name": "modelopt", "version": __version__, }, "quantization": { "quant_algo": quantization, - "kv_cache_quant_algo": kv_cache_quantization, - "exclude_modules": ["lm_head"], + "exclude_modules": self.exclude_modules, }, } + if quantization == "NVFP4": # update block size + self._hf_quant_config["quantization"]["group_size"] = 16 + if hasattr(self, "kv_cache_dtype"): + self._hf_quant_config["quantization"]["kv_cache_quant_algo"] = self.kv_cache_dtype with open(save_directory + "/hf_quant_config.json", "w") as f: - json.dump(hf_quant_config, f, indent=4) + json.dump(self._hf_quant_config, f, indent=4) if ( is_first_stage_main_rank @@ -450,6 +424,16 @@ def save_pretrained( if modeling_file and os.path.exists(modeling_file): shutil.copy(modeling_file, f"{save_directory}/modeling_{model_type}.py") + # Newer versions of VLLM expect config.json with hf_quant_config + config_json_file = save_directory + "/config.json" + if self._hf_quant_config and os.path.exists(config_json_file): + converted_quant_config = convert_hf_quant_config_format(self._hf_quant_config) + with open(config_json_file) as f: + config_dict = json.load(f) + config_dict["quantization_config"] = converted_quant_config + with open(config_json_file, "w") as f: + json.dump(config_dict, f, indent=4) + save_safetensors(state_dict, save_directory) @property @@ -466,6 +450,264 @@ def extra_state_dict(self): self._get_eagle_module_state_dict() return self._state_dict + def _get_state_dict(self): + model = self.model + + # Embedding + if hasattr(model, "embedding"): + self.rules["word_embeddings"](model.embedding.word_embeddings) + + # Final layernorm + if hasattr(model.decoder, "final_layernorm") and model.decoder.final_layernorm: + self.rules["final_layernorm"](model.decoder.final_layernorm) + + if hasattr(model.decoder, "final_norm") and model.decoder.final_norm: + self.rules["final_norm"](model.decoder.final_norm) + + # Output layer + if hasattr(model, "output_layer") and not model.share_embeddings_and_output_weights: + self.rules["output_layer"](model.output_layer) + + # Decoder layers + for layer in model.decoder.layers: + layer_id = layer.layer_number - 1 + if isinstance(layer, MambaLayer): + self._get_mamba_layer_state_dict(layer, layer_id) + elif isinstance(layer, TransformerLayer): + self._get_transformer_layer_state_dict(layer, layer_id) + else: + raise ValueError("Only TransformerLayer or MambaLayer are supported.") + + # Get MTP layer if exists + self._get_mtp_state_dict() + + def _get_transformer_layer_state_dict(self, layer, layer_id): + if not isinstance(layer.input_layernorm, IdentityOp): + self.rules["input_layernorm"](layer.input_layernorm, layer_id) + + if not isinstance(layer.self_attention, IdentityOp): + if "MLASelfAttention" in str(type(layer.self_attention)): + if hasattr(layer.self_attention, "linear_q_proj"): + self.rules["linear_q_proj"](layer.self_attention.linear_q_proj, layer_id) + else: + self.rules["linear_q_down_proj"]( + layer.self_attention.linear_q_down_proj, layer_id + ) + self.rules["linear_q_layernorm"](layer.self_attention.q_layernorm, layer_id) + self.rules["linear_q_up_proj"](layer.self_attention.linear_q_up_proj, layer_id) + + self.rules["linear_kv_down_proj"]( + layer.self_attention.linear_kv_down_proj, layer_id + ) + 
self.rules["linear_kv_layernorm"](layer.self_attention.kv_layernorm, layer_id) + self.rules["linear_kv_up_proj"](layer.self_attention.linear_kv_up_proj, layer_id) + self.rules["linear_proj"](layer.self_attention.linear_proj, layer_id) + else: + if layer.self_attention.q_layernorm is not None and not isinstance( + layer.self_attention.q_layernorm, (IdentityOp, L2Norm) + ): + self.rules["q_layernorm"](layer.self_attention.q_layernorm, layer_id) + self.rules["k_layernorm"](layer.self_attention.k_layernorm, layer_id) + self.rules["linear_qkv"](layer.self_attention.linear_qkv, layer_id) + if ( + hasattr(layer.self_attention, "core_attention") + and "core_attention" in self.rules + ): # KV cache quant export + self.rules["core_attention"](layer.self_attention.core_attention, layer_id) + self.rules["linear_proj"](layer.self_attention.linear_proj, layer_id) + if getattr(layer.self_attention.core_attention, "softmax_offset", None) is not None: + self.rules["softmax_offset"]( + layer.self_attention.core_attention.softmax_offset, layer_id + ) + + if not isinstance(layer.pre_mlp_layernorm, IdentityOp): + self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id) + + if not isinstance(layer.mlp, IdentityOp): + if "MoE" in str(type(layer.mlp)): + self.rules["router"](layer.mlp.router, layer_id, dtype=self.moe_router_dtype) + if hasattr(layer.mlp, "fc1_latent_proj") and layer.mlp.fc1_latent_proj is not None: + self.rules["fc1_latent_proj"](layer.mlp.fc1_latent_proj, layer_id) + if hasattr(layer.mlp, "fc2_latent_proj") and layer.mlp.fc2_latent_proj is not None: + self.rules["fc2_latent_proj"](layer.mlp.fc2_latent_proj, layer_id) + if hasattr(layer.mlp, "shared_experts") and layer.mlp.shared_experts is not None: + self.rules["shared_experts.linear_fc1"]( + layer.mlp.shared_experts.linear_fc1, layer_id + ) + self.rules["shared_experts.linear_fc2"]( + layer.mlp.shared_experts.linear_fc2, layer_id + ) + if not self.rules.get("use_packed_local_experts", False): + for expert_id, expert in enumerate(layer.mlp.experts.local_experts): + self.rules["local_experts.linear_fc1"]( + expert.linear_fc1, layer_id, expert_id + ) + self.rules["local_experts.linear_fc2"]( + expert.linear_fc2, layer_id, expert_id + ) + else: + # For llama 4, in hf unified checkpoint, all local experts share one scale + self.rules["local_experts.linear_fc1"]( + layer.mlp.experts.local_experts, layer_id + ) + self.rules["local_experts.linear_fc2"]( + layer.mlp.experts.local_experts, layer_id + ) + else: + self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id) + self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id) + + def _get_mtp_state_dict(self): + """Export the MTP module. + + Currently, we copy the BF16 MTP weights from the pretrained model if the pretrained model has MTP layers. 
+ """ + # TODO Implement MTP export for quantized MTP + # Hacky version for now: copy MTP weights from pretrained model + if self._hf_pretrained_model_name: + if os.path.isdir(self._hf_pretrained_model_name): + safetensors_index_file = ( + Path(self._hf_pretrained_model_name) / "model.safetensors.index.json" + ) + else: + safetensors_index_file = hf_hub_download( + repo_id=self._hf_pretrained_model_name, filename="model.safetensors.index.json" + ) + + print(f"Exporting MTP: using safetensors_index_file: {safetensors_index_file}") + mtp_exists = False + if safetensors_index_file and os.path.exists(safetensors_index_file): + with open(safetensors_index_file) as f: + safetensors_index = json.load(f) + model_dir = Path(safetensors_index_file).parent + for key in safetensors_index["weight_map"]: + if key.startswith("mtp.") and key not in self._state_dict: + self._state_dict[key] = get_safetensor(model_dir, key) + mtp_exists = True + + if mtp_exists: + self.exclude_modules.append("mtp*") + + def _get_mamba_layer_state_dict(self, layer, layer_id): + if not isinstance(layer.norm, IdentityOp): + self.rules["norm"](layer.norm, layer_id) + + self.rules["mixer_norm"](layer.mixer.norm, layer_id) + self.rules["A_log"](layer.mixer.A_log, layer_id) + self.rules["D"](layer.mixer.D, layer_id) + self.rules["dt_bias"](layer.mixer.dt_bias, layer_id) + + self.rules["conv1d"](layer.mixer.conv1d, layer_id) + self.rules["in_proj"](layer.mixer.in_proj, layer_id) + self.rules["out_proj"](layer.mixer.out_proj, layer_id) + + def _get_medusa_heads_state_dict(self): + medusa_heads = getattr(self.model, "medusa_heads", None) + if medusa_heads is None: + return + + for head_id, head in enumerate(medusa_heads): + self.rules["medusa_heads.lm_head"](head.lm_head, head_id) + for layer_id, layer in enumerate(head.medusa_layers): + self.rules["medusa_heads.medusa_layers.linear"](layer.linear, head_id, layer_id) + + def _get_eagle_module_state_dict(self): + eagle_module = getattr(self.model, "eagle_module", None) + + if eagle_module is None: + return + + # if hasattr(self.model, "embedding"): + # self.rules["word_embeddings"](self.model.embedding.word_embeddings) + + self.rules["fc"](eagle_module.fc) + if self.model.eagle_config.use_aux_hidden_state: + self.rules["enorm"](eagle_module.enorm) + elif self.model.eagle_config.use_mtp_layernorm: + self.rules["enorm"](eagle_module.enorm) + self.rules["hnorm"](eagle_module.hnorm) + + if self.model.eagle_config.use_last_layernorm: + self.rules["final_layernorm"](eagle_module.decoder.final_layernorm) + + if hasattr(self.model.eagle_module, "eagle_output_layer"): + self.rules["output_layer"](eagle_module.eagle_output_layer) + if hasattr(self.model.eagle_module, "dt2"): + self.rules["d2t"](eagle_module.d2t) + + for layer in eagle_module.decoder.layers: + layer_id = layer.layer_number - 1 + + # The first layernorm needs special handling here. We have a dedicated mapping + # for the first layernorm since in EAGLE3 it will be mapped to hidden_norm + # instead of input_layernorm (due to the specialized transformer layer). + # The remaining EAGLE3 layers (if more than 1) are normal transformer layers + # where input_layernorm is mapped to input_layernorm. 
+ if layer_id == 0 and self.model.eagle_config.use_input_layernorm_in_first_layer: + self.rules["first_input_layernorm"](layer.input_layernorm, layer_id) + elif layer_id > 0: + self.rules["input_layernorm"](layer.input_layernorm, layer_id) + + if "MLASelfAttention" in str(type(layer.self_attention)): + if hasattr(layer.self_attention, "linear_q_proj"): + self.rules["eagle_module.linear_q_proj"]( + layer.self_attention.linear_q_proj, layer_id + ) + else: + self.rules["eagle_module.linear_q_down_proj"]( + layer.self_attention.linear_q_down_proj, layer_id + ) + self.rules["eagle_module.linear_q_layernorm"]( + layer.self_attention.q_layernorm, layer_id + ) + self.rules["eagle_module.linear_q_up_proj"]( + layer.self_attention.linear_q_up_proj, layer_id + ) + + self.rules["eagle_module.linear_kv_down_proj"]( + layer.self_attention.linear_kv_down_proj, layer_id + ) + self.rules["eagle_module.linear_kv_layernorm"]( + layer.self_attention.kv_layernorm, layer_id + ) + self.rules["eagle_module.linear_kv_up_proj"]( + layer.self_attention.linear_kv_up_proj, layer_id + ) + else: + self.rules["linear_qkv"](layer.self_attention.linear_qkv, layer_id) + + self.rules["linear_proj"](layer.self_attention.linear_proj, layer_id) + self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id) + + if "MoE" in str(type(layer.mlp)): + self.rules["eagle_module.router"](layer.mlp.router, layer_id) + if hasattr(layer.mlp, "shared_experts") and layer.mlp.shared_experts is not None: + self.rules["eagle_module.shared_experts.linear_fc1"]( + layer.mlp.shared_experts.linear_fc1, layer_id + ) + self.rules["eagle_module.shared_experts.linear_fc2"]( + layer.mlp.shared_experts.linear_fc2, layer_id + ) + for expert_id, expert in enumerate(layer.mlp.experts.local_experts): + self.rules["eagle_module.local_experts.linear_fc1"]( + expert.linear_fc1, layer_id, expert_id + ) + self.rules["eagle_module.local_experts.linear_fc2"]( + expert.linear_fc2, layer_id, expert_id + ) + else: + self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id) + self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id) + + parallel_draft_heads = getattr(eagle_module, "parallel_draft_heads", None) + if parallel_draft_heads is not None: + for head_id, head in enumerate(parallel_draft_heads.medusa_heads): + for layer_id, layer in enumerate(head): + self.rules["parallel_draft_heads.medusa_layers"]( + layer.linear, head_id, layer_id + ) + self.rules["parallel_draft_heads.lm_head"](parallel_draft_heads.lm_head) + def _populate_rule_book(self): all_rules = {} @@ -473,6 +715,7 @@ def _custom_mapping_to_lambda(mapping): method_map = { "name_remapping": self._name_remapping, "qkv_slicing": self._qkv_slicing, + "self_attention_scaling": self._self_attention_scaling, "gated_mlp_slicing": self._gated_mlp_slicing, "pack_name_remapping": self._pack_name_remapping, "pack_name_remapping_gpt_oss": self._pack_name_remapping_gpt_oss, @@ -497,30 +740,38 @@ def _get_quantized_state( self, module: torch.nn.Module, dtype: torch.dtype = torch.float16, + prefix: str = "", ) -> tuple[dict[str, torch.Tensor], str, int]: """Return a state_dict, quantization format, and block_size of the module. Args: module: The target module to perform real quantization. dtype: The default data type. + prefix: The prefix of the layer. Returns: Tuple: state_dict, quantization format, and block_size of the module. 
""" name_to_value = {} qformat: str = self._get_quantization_format(module) + if qformat is None and "norm" not in prefix: # Add exclude layers for hf_quant_config + self.exclude_modules.append(prefix) block_size = get_weight_block_size(module) - if hasattr(module, "weight") and module.weight is not None: + if hasattr(module, "weight") and module.weight is not None and module.weight.numel() > 0: weight = module.weight.to(dtype).cpu() name_to_value["weight"] = weight else: return name_to_value, qformat, block_size - if hasattr(module, "bias") and module.bias is not None: + if hasattr(module, "bias") and module.bias is not None and module.bias.numel() > 0: name_to_value["bias"] = module.bias.to(dtype).cpu() - if hasattr(module, "expert_bias") and module.expert_bias is not None: + if ( + hasattr(module, "expert_bias") + and module.expert_bias is not None + and module.expert_bias.numel() > 0 + ): name_to_value["expert_bias"] = module.expert_bias.to(dtype).cpu() if qformat == QUANTIZATION_NONE: @@ -542,11 +793,6 @@ def _get_quantized_state( if hasattr(module.input_quantizer, "_pre_quant_scale"): raise ValueError("Detect pre_quant_scale! SmoothQuant/AWQ are not yet supported!") - if hasattr(module, "output_quantizer"): - output_scale = get_kv_cache_scaling_factor(module) - if output_scale is not None: - name_to_value["output_scale"] = output_scale - return name_to_value, qformat, block_size def _get_quantization_format(self, module: torch.nn.Module): @@ -580,7 +826,7 @@ def _name_remapping( self._state_dict[prefix] = module return - name_to_value, qformat, block_size = self._get_quantized_state(module, dtype) + name_to_value, qformat, block_size = self._get_quantized_state(module, dtype, prefix=prefix) weight = name_to_value.pop("weight") weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat) @@ -608,11 +854,14 @@ def _name_remapping( else: source_key = mapping.get(key, key) self._state_dict[prefix + source_key] = val + print(f"{prefix + source_key}: {self._state_dict[prefix + source_key].dtype}") def _gated_mlp_slicing( self, module, prefix, gate_proj_name="gate_proj", up_proj_name="up_proj" ): - name_to_value, qformat, block_size = self._get_quantized_state(module, self.dtype) + name_to_value, qformat, block_size = self._get_quantized_state( + module, self.dtype, prefix=prefix + ) weight = name_to_value.pop("weight") weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat) @@ -674,10 +923,10 @@ def _qkv_slicing( q_proj_name="q_proj", k_proj_name="k_proj", v_proj_name="v_proj", - k_scale_name="k_scale", - v_scale_name="v_scale", ): - name_to_value, qformat, block_size = self._get_quantized_state(module, self.dtype) + name_to_value, qformat, block_size = self._get_quantized_state( + module, self.dtype, prefix=prefix + ) q_proj_prefix = prefix + q_proj_name + "." k_proj_prefix = prefix + k_proj_name + "." 
@@ -774,10 +1023,7 @@ def _qkv_slicing( q_proj_key = q_proj_prefix + key k_proj_key = k_proj_prefix + key v_proj_key = v_proj_prefix + key - if key == "output_scale": - self._state_dict[prefix + k_scale_name] = val.detach().clone() - self._state_dict[prefix + v_scale_name] = val.detach().clone() - elif key == "bias": + if key == "bias": # Slice bias similar to weight bias = val.detach().clone() bias = bias.reshape([qkv_total_dim, head_size]) @@ -790,6 +1036,23 @@ def _qkv_slicing( self._state_dict[k_proj_key] = val.detach().clone() self._state_dict[v_proj_key] = val.detach().clone() + def _self_attention_scaling( + self, module, prefix, k_scale_name="k_scale", v_scale_name="v_scale" + ): + """KV cache scaling for CoreAttention module.""" + k_scale_key = prefix + k_scale_name + v_scale_key = prefix + v_scale_name + if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"): + kv_scales = get_kv_cache_scaling_factor(module) + if all(s is not None for s in kv_scales): + self._state_dict[k_scale_key] = kv_scales[0] + self._state_dict[v_scale_key] = kv_scales[1] + + kv_cache_dtype = get_kv_cache_dtype(module) + if kv_cache_dtype in (KV_CACHE_FP8, KV_CACHE_NVFP4): + # FP8 KV Cache is supported in VLLM; NVFP4 supported in TRTLLM + self.kv_cache_dtype = kv_cache_dtype + def _pack_name_remapping(self, module, prefix, layer_type=None): """Pack name remapping into one tensor.""" weight_list = [] @@ -800,7 +1063,7 @@ def _pack_name_remapping(self, module, prefix, layer_type=None): for expert in module: assert layer_type is not None, "layer_type is required for pack_name_remapping" name_to_value, qformat, block_size = self._get_quantized_state( - getattr(expert, layer_type), self.dtype + getattr(expert, layer_type), self.dtype, prefix=prefix ) weight = name_to_value.pop("weight") weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat) @@ -866,7 +1129,7 @@ def _pack_name_remapping_gpt_oss(self, module, prefix, layer_type=None): for expert in module: assert layer_type is not None, "layer_type is required for pack_name_remapping" name_to_value, qformat, block_size = self._get_quantized_state( - getattr(expert, layer_type), self.dtype + getattr(expert, layer_type), self.dtype, prefix=prefix ) weight = name_to_value.pop("weight") bias = name_to_value.pop("bias", None) @@ -969,234 +1232,15 @@ def _pack_name_remapping_gpt_oss(self, module, prefix, layer_type=None): # TODO: May need to modify the key name later. 
self._state_dict[prefix + "_bias"] = merged_bias - def _get_medusa_heads_state_dict(self): - medusa_heads = getattr(self.model, "medusa_heads", None) - if medusa_heads is None: - return - - for head_id, head in enumerate(medusa_heads): - self.rules["medusa_heads.lm_head"](head.lm_head, head_id) - for layer_id, layer in enumerate(head.medusa_layers): - self.rules["medusa_heads.medusa_layers.linear"](layer.linear, head_id, layer_id) - - def _get_eagle_module_state_dict(self): - eagle_module = getattr(self.model, "eagle_module", None) - - if eagle_module is None: - return - - # if hasattr(self.model, "embedding"): - # self.rules["word_embeddings"](self.model.embedding.word_embeddings) - - self.rules["fc"](eagle_module.fc) - if self.model.eagle_config.use_aux_hidden_state: - self.rules["enorm"](eagle_module.enorm) - elif self.model.eagle_config.use_mtp_layernorm: - self.rules["enorm"](eagle_module.enorm) - self.rules["hnorm"](eagle_module.hnorm) - - if self.model.eagle_config.use_last_layernorm: - self.rules["final_layernorm"](eagle_module.decoder.final_layernorm) - - if hasattr(self.model.eagle_module, "eagle_output_layer"): - self.rules["output_layer"](eagle_module.eagle_output_layer) - if hasattr(self.model.eagle_module, "dt2"): - self.rules["d2t"](eagle_module.d2t) - - for layer in eagle_module.decoder.layers: - layer_id = layer.layer_number - 1 - - # The first layernorm needs special handling here. We have a dedicated mapping - # for the first layernorm since in EAGLE3 it will be mapped to hidden_norm - # instead of input_layernorm (due to the specialized transformer layer). - # The remaining EAGLE3 layers (if more than 1) are normal transformer layers - # where input_layernorm is mapped to input_layernorm. - if layer_id == 0 and self.model.eagle_config.use_input_layernorm_in_first_layer: - self.rules["first_input_layernorm"](layer.input_layernorm, layer_id) - elif layer_id > 0: - self.rules["input_layernorm"](layer.input_layernorm, layer_id) - - if "MLASelfAttention" in str(type(layer.self_attention)): - if hasattr(layer.self_attention, "linear_q_proj"): - self.rules["eagle_module.linear_q_proj"]( - layer.self_attention.linear_q_proj, layer_id - ) - else: - self.rules["eagle_module.linear_q_down_proj"]( - layer.self_attention.linear_q_down_proj, layer_id - ) - self.rules["eagle_module.linear_q_layernorm"]( - layer.self_attention.q_layernorm, layer_id - ) - self.rules["eagle_module.linear_q_up_proj"]( - layer.self_attention.linear_q_up_proj, layer_id - ) - - self.rules["eagle_module.linear_kv_down_proj"]( - layer.self_attention.linear_kv_down_proj, layer_id - ) - self.rules["eagle_module.linear_kv_layernorm"]( - layer.self_attention.kv_layernorm, layer_id - ) - self.rules["eagle_module.linear_kv_up_proj"]( - layer.self_attention.linear_kv_up_proj, layer_id - ) - else: - self.rules["linear_qkv"](layer.self_attention.linear_qkv, layer_id) - - self.rules["linear_proj"](layer.self_attention.linear_proj, layer_id) - self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id) - - if "MoE" in str(type(layer.mlp)): - self.rules["eagle_module.router"](layer.mlp.router, layer_id) - if hasattr(layer.mlp, "shared_experts") and layer.mlp.shared_experts is not None: - self.rules["eagle_module.shared_experts.linear_fc1"]( - layer.mlp.shared_experts.linear_fc1, layer_id - ) - self.rules["eagle_module.shared_experts.linear_fc2"]( - layer.mlp.shared_experts.linear_fc2, layer_id - ) - for expert_id, expert in enumerate(layer.mlp.experts.local_experts): - 
self.rules["eagle_module.local_experts.linear_fc1"]( - expert.linear_fc1, layer_id, expert_id - ) - self.rules["eagle_module.local_experts.linear_fc2"]( - expert.linear_fc2, layer_id, expert_id - ) - else: - self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id) - self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id) - - parallel_draft_heads = getattr(eagle_module, "parallel_draft_heads", None) - if parallel_draft_heads is not None: - for head_id, head in enumerate(parallel_draft_heads.medusa_heads): - for layer_id, layer in enumerate(head): - self.rules["parallel_draft_heads.medusa_layers"]( - layer.linear, head_id, layer_id - ) - self.rules["parallel_draft_heads.lm_head"](parallel_draft_heads.lm_head) - - def _get_state_dict(self): - model = self.model - - # Embedding - if hasattr(model, "embedding"): - self.rules["word_embeddings"](model.embedding.word_embeddings) - - # Final layernorm - if hasattr(model.decoder, "final_layernorm") and model.decoder.final_layernorm: - self.rules["final_layernorm"](model.decoder.final_layernorm) - - if hasattr(model.decoder, "final_norm") and model.decoder.final_norm: - self.rules["final_norm"](model.decoder.final_norm) - - # Output layer - if hasattr(model, "output_layer") and not model.share_embeddings_and_output_weights: - self.rules["output_layer"](model.output_layer) - - # Decoder layers - for layer in model.decoder.layers: - layer_id = layer.layer_number - 1 - - if isinstance(layer, MambaLayer): - if not isinstance(layer.norm, IdentityOp): - self.rules["norm"](layer.norm, layer_id) - - self.rules["mixer_norm"](layer.mixer.norm, layer_id) - self.rules["A_log"](layer.mixer.A_log, layer_id) - self.rules["D"](layer.mixer.D, layer_id) - self.rules["dt_bias"](layer.mixer.dt_bias, layer_id) - - self.rules["conv1d"](layer.mixer.conv1d, layer_id) - self.rules["in_proj"](layer.mixer.in_proj, layer_id) - self.rules["out_proj"](layer.mixer.out_proj, layer_id) - - elif isinstance(layer, TransformerLayer): - if not isinstance(layer.input_layernorm, IdentityOp): - self.rules["input_layernorm"](layer.input_layernorm, layer_id) - - if not isinstance(layer.self_attention, IdentityOp): - if "MLASelfAttention" in str(type(layer.self_attention)): - if hasattr(layer.self_attention, "linear_q_proj"): - self.rules["linear_q_proj"]( - layer.self_attention.linear_q_proj, layer_id - ) - else: - self.rules["linear_q_down_proj"]( - layer.self_attention.linear_q_down_proj, layer_id - ) - self.rules["linear_q_layernorm"]( - layer.self_attention.q_layernorm, layer_id - ) - self.rules["linear_q_up_proj"]( - layer.self_attention.linear_q_up_proj, layer_id - ) - - self.rules["linear_kv_down_proj"]( - layer.self_attention.linear_kv_down_proj, layer_id - ) - self.rules["linear_kv_layernorm"]( - layer.self_attention.kv_layernorm, layer_id - ) - self.rules["linear_kv_up_proj"]( - layer.self_attention.linear_kv_up_proj, layer_id - ) - self.rules["linear_proj"](layer.self_attention.linear_proj, layer_id) - else: - if layer.self_attention.q_layernorm is not None and not isinstance( - layer.self_attention.q_layernorm, (IdentityOp, L2Norm) - ): - self.rules["q_layernorm"](layer.self_attention.q_layernorm, layer_id) - self.rules["k_layernorm"](layer.self_attention.k_layernorm, layer_id) - self.rules["linear_qkv"](layer.self_attention.linear_qkv, layer_id) - self.rules["linear_proj"](layer.self_attention.linear_proj, layer_id) - if ( - getattr(layer.self_attention.core_attention, "softmax_offset", None) - is not None - ): - self.rules["softmax_offset"]( - 
layer.self_attention.core_attention.softmax_offset, layer_id - ) - - if not isinstance(layer.pre_mlp_layernorm, IdentityOp): - self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id) - - if not isinstance(layer.mlp, IdentityOp): - if "MoE" in str(type(layer.mlp)): - self.rules["router"]( - layer.mlp.router, layer_id, dtype=self.moe_router_dtype - ) - if ( - hasattr(layer.mlp, "shared_experts") - and layer.mlp.shared_experts is not None - ): - self.rules["shared_experts.linear_fc1"]( - layer.mlp.shared_experts.linear_fc1, layer_id - ) - self.rules["shared_experts.linear_fc2"]( - layer.mlp.shared_experts.linear_fc2, layer_id - ) - if not self.rules.get("use_packed_local_experts", False): - for expert_id, expert in enumerate(layer.mlp.experts.local_experts): - self.rules["local_experts.linear_fc1"]( - expert.linear_fc1, layer_id, expert_id - ) - self.rules["local_experts.linear_fc2"]( - expert.linear_fc2, layer_id, expert_id - ) - else: - # For llama 4, in hf unified checkpoint, all local experts share one scale - self.rules["local_experts.linear_fc1"]( - layer.mlp.experts.local_experts, layer_id - ) - self.rules["local_experts.linear_fc2"]( - layer.mlp.experts.local_experts, layer_id - ) - else: - self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id) - self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id) - else: - raise ValueError("Only TransformerLayer or MambaLayer are supported.") + def _gather_exclude_modules(self): + """Get exclude_modules from all ranks to ensure hf_quant_config is complete.""" + all_exclude_modules = [None] * dist.get_world_size() + dist.all_gather_object(all_exclude_modules, self.exclude_modules) + combined_exclude_modules = set() + for modules in all_exclude_modules: + if modules: + combined_exclude_modules.update(modules) + return sorted(combined_exclude_modules) def export_mcore_gpt_to_hf( 
diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 0077d8666..88ebaa906 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py 
@@ -81,6 +81,11 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis forward_loop(model) finish_stats_collection(model) + # Step 1: Sync amax across local experts in a SequentialMLP + for name, module in model.named_modules(): + if hasattr(module, "layer_sync_moe_local_experts_amax"): + module.layer_sync_moe_local_experts_amax() + if not distributed_sync: return 
@@ -95,13 +100,13 @@ def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state): quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group) # TODO: create sync_bias_across_distributed_group - # Step 1:Sync amax across data parallelism + # Step 2: Sync amax across data parallelism for name, module in model.named_modules(): if isinstance(module, QuantModule): for child in module.children(): if isinstance(child, (TensorQuantizer, SequentialQuantizer)): sync_quantizer_amax_across_dp_ep(child, module.parallel_state) - # TP sync: + # Step 3: TP sync # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same # ColumnParallel: X @ [A_1, A_2] (weights split along Cout) 
@@ -156,7 +161,6 @@ def sync_quantizer_amax_across_tp( axes_for_sync=[None, -1], parallel_state=module.parallel_state, ) - sync_quantizer_amax_across_tp( module.weight_quantizer, name, 
@@ -182,10 +186,6 @@ def sync_quantizer_amax_across_tp( parallel_state=module.parallel_state, ) - #
MOE Quantization - if hasattr(module, "sync_moe_local_experts_amax"): - module.sync_moe_local_experts_amax() - # KV Cache Quantization if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"): # We only support KVCache quantization with scalar per-tensor states for now (NVFP4 & FP8 KV cache) 
diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 6a2bdc15f..44784409b 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py 
@@ -52,6 +52,7 @@ TEColumnParallelLinear, TEDotProductAttention, TELayerNormColumnParallelLinear, + TELinear, TERowParallelGroupedLinear, TERowParallelLinear, ) 
@@ -574,12 +575,15 @@ def _setup(self): expert.linear_fc1.parallel_state = self.parallel_state expert.linear_fc2.parallel_state = self.parallel_state - def sync_moe_local_experts_amax(self): + def layer_sync_moe_local_experts_amax(self): """Sync amax across local experts in a SequentialMLP. - amax across EP and ETP (for RowParallel) are synchronized as part of model_calib.max_calibrate(). - This function is called to synchronize the amax values across local experts s.t. all localexperts will - share the same amax. + Synchronize the amax values across local experts in a layer such that all local experts will + share the same amax. This function operates on a single rank and does not require distributed sync. + + Distributed amax sync across EP and ETP (for RowParallel) happens in model_calib.max_calibrate(). + This function should be called before the distributed sync to ensure the amax values + are synchronized across the layer first. """ # Collect amax from all local experts amax_dict = {} 
@@ -618,6 +622,10 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): if HAS_TE: + @QuantModuleRegistry.register({TELinear: "te_mcore_Linear"}) + class _QuantTEMCoreLinear(_QuantTELinear): + pass + @QuantModuleRegistry.register({TERowParallelLinear: "te_mcore_RowParallelLinear"}) class _QuantTEMCoreRowParallelLinear(_QuantTELinear, _MegatronRowParallelLinear): pass 
diff --git a/tests/_test_utils/torch/megatron/utils.py b/tests/_test_utils/torch/megatron/utils.py index bb91f83cd..63904ba44 100644 --- a/tests/_test_utils/torch/megatron/utils.py +++ b/tests/_test_utils/torch/megatron/utils.py 
@@ -129,6 +129,47 @@ def run_mcore_inference_with_dummy_input( return run_mcore_inference(model, prompt_tokens, hidden_size) + + +def get_batch(model, batch_size=2): + seq_length = model.max_sequence_length + vocab_size = model.vocab_size + + input_ids = torch.randint(0, vocab_size, (batch_size, seq_length)).cuda() + labels = torch.randint(0, vocab_size, (batch_size, seq_length)).cuda() + position_ids = ( + torch.arange(seq_length, dtype=torch.int64).unsqueeze(0).repeat(batch_size, 1).cuda() + ) + attention_mask = torch.tril( + torch.ones((batch_size, 1, seq_length, seq_length), dtype=torch.bool) + ).cuda() + loss_mask = torch.ones((batch_size, seq_length), dtype=torch.float32).cuda() + + return input_ids, labels, position_ids, attention_mask, loss_mask + + +def get_forward(model, batch_size=2): + """Return a forward function with cached batch inputs.""" + input_ids, labels, position_ids, attention_mask, loss_mask = get_batch(model, batch_size) + + def forward(model): + # MambaModel doesn't accept loss_mask argument + if isinstance(model, MambaModel): + return model.forward( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + labels=labels, + ) +
return model.forward( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + labels=labels, + loss_mask=loss_mask, + ) + + return forward + + def initialize_for_megatron( tensor_model_parallel_size=1, pipeline_model_parallel_size=1, diff --git a/tests/gpu/torch/export/test_unified_export_megatron.py b/tests/gpu/torch/export/test_unified_export_megatron.py index c07c2b565..54d6c16f2 100644 --- a/tests/gpu/torch/export/test_unified_export_megatron.py +++ b/tests/gpu/torch/export/test_unified_export_megatron.py @@ -16,6 +16,7 @@ import json from copy import deepcopy from functools import partial +from pathlib import Path import pytest import torch @@ -23,18 +24,57 @@ from _test_utils.import_helper import skip_if_no_megatron from _test_utils.torch.distributed.utils import spawn_multiprocess_job from _test_utils.torch.megatron.models import get_mcore_gpt_model +from _test_utils.torch.megatron.utils import get_forward from _test_utils.torch.transformers_models import create_tiny_llama_dir skip_if_no_megatron(apex_or_te_required=True) +import modelopt.torch.quantization as mtq import modelopt.torch.speculative as mtsp -from modelopt.torch.export import export_mcore_gpt_to_hf, import_mcore_gpt_from_hf +from modelopt.torch.export import KV_CACHE_FP8, export_mcore_gpt_to_hf, import_mcore_gpt_from_hf from modelopt.torch.speculative.eagle.default_config import default_eagle_config from modelopt.torch.speculative.plugins.megatron_eagle import _DynamicEagleGPTModel from modelopt.torch.speculative.plugins.megatron_medusa import _DynamicMedusaGPTModel -def _test_unified_export_megatron(tmp_path, model_type, arch, algo, rank, size): +def _verify_model_quant_config( + export_dir: Path, quant_config: str | None = None, kv_cache_quant_cfg: str | None = None +): + """Verify config.json and hf_quant_config.json""" + config_dict = json.load(open(export_dir / "config.json")) + hf_quant_config_dict = json.load(open(export_dir / "hf_quant_config.json")) + # Make sure config.json and hf_quant_config.json are consistent + assert ( + config_dict["quantization_config"]["quant_algo"] + == hf_quant_config_dict["quantization"]["quant_algo"] + ) + assert ( + config_dict["quantization_config"]["ignore"] + == hf_quant_config_dict["quantization"]["exclude_modules"] + ) + + # Verify config.json + if kv_cache_quant_cfg: + assert config_dict["quantization_config"]["kv_cache_scheme"]["num_bits"] == 8 + + # Verify hf_quant_config.json + if quant_config: + quant_config_dict = hf_quant_config_dict["quantization"] + quant_type = quant_config_dict["quant_algo"] + assert ( + quant_type in quant_config + ) # quant config str is subset of quant config e.g. 
NVFP4 -> NVFP4_DEFAULT_CFG + assert len(quant_config_dict["exclude_modules"]) > 1 # Dynamically added exclude modules + if quant_type == "NVFP4": + assert quant_config_dict["group_size"] == 16 + + if kv_cache_quant_cfg: + assert hf_quant_config_dict["quantization"]["kv_cache_quant_algo"] == KV_CACHE_FP8 + + +def _test_unified_export_megatron( + tmp_path, model_type, arch, extra_module, quant_config, kv_cache_quant_cfg, rank, size +): num_layers = 2 hidden_size = 64 num_attention_heads = 8 
@@ -63,14 +103,24 @@ def _test_unified_export_megatron(tmp_path, model_type, arch, algo, rank, size): transformer_impl="modelopt", ).cuda() - if algo == "medusa": + if quant_config: + quant_config_dict = getattr(mtq, quant_config) + if kv_cache_quant_cfg: + kv_quant_cfg = getattr(mtq, kv_cache_quant_cfg)["quant_cfg"] + quant_config_dict = mtq.utils.update_quant_cfg_with_kv_cache_quant( + quant_config_dict, kv_quant_cfg + ) + forward = get_forward(model) + model = mtq.quantize(model, quant_config_dict, forward) + + if extra_module == "medusa": config = { "medusa_num_heads": 1, "medusa_num_layers": 1, } model = mtsp.convert(model, [("medusa", config)]) assert isinstance(model, _DynamicMedusaGPTModel) - elif algo == "eagle": + elif extra_module == "eagle": config = {"eagle_architecture_config": deepcopy(default_eagle_config)} model = mtsp.convert(model, [("eagle", config)]) assert isinstance(model, _DynamicEagleGPTModel) 
@@ -91,25 +141,36 @@ def _test_unified_export_megatron(tmp_path, model_type, arch, algo, rank, size): with open(tmp_path / "config.json", "w") as f: json.dump(pretrained_config, f) + tmp_export_dir = tmp_path / "export" export_mcore_gpt_to_hf( model, tmp_path if arch is not None else None, dtype=torch.bfloat16, + export_dir=str(tmp_export_dir), ) + if quant_config: + _verify_model_quant_config(tmp_export_dir, quant_config, kv_cache_quant_cfg) + @pytest.mark.parametrize( - ("model_type", "arch", "algo"), + ("model_type", "arch", "extra_module", "quant_config", "kv_cache_quant_cfg"), [ - ("nemotron", "NemotronForCausalLM", None), - ("nemotron", "NemotronForCausalLM", "eagle"), - ("nemotron", "NemotronForCausalLM", "medusa"), - ("llama", "LlamaForCausalLM", None), - ("llama", "LlamaForCausalLM", "eagle"), - ("llama", "LlamaForCausalLM", "medusa"), + ("nemotron", "NemotronForCausalLM", None, None, None), + ("nemotron", "NemotronForCausalLM", None, "NVFP4_DEFAULT_CFG", None), + ("nemotron", "NemotronForCausalLM", None, "NVFP4_DEFAULT_CFG", "FP8_KV_CFG"), + ("nemotron", "NemotronForCausalLM", "eagle", None, None), + ("nemotron", "NemotronForCausalLM", "medusa", None, None), + ("llama", "LlamaForCausalLM", None, None, None), + ("llama", "LlamaForCausalLM", None, "FP8_DEFAULT_CFG", None), + ("llama", "LlamaForCausalLM", None, "FP8_DEFAULT_CFG", "FP8_KV_CFG"), + ("llama", "LlamaForCausalLM", "eagle", None, None), + ("llama", "LlamaForCausalLM", "medusa", None, None), ], ) -def test_unified_export_megatron(tmp_path, model_type, arch, algo): +def test_unified_export_megatron( + tmp_path, model_type, arch, extra_module, quant_config, kv_cache_quant_cfg +): # TODO: Fix TP>1 failures spawn_multiprocess_job( size=1, # torch.cuda.device_count(), 
@@ -118,7 +179,9 @@ def test_unified_export_megatron(tmp_path, model_type, arch, algo): tmp_path, model_type, arch, - algo, + extra_module, + quant_config, + kv_cache_quant_cfg, ), backend="nccl", ) 
diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index d02b02c18..2ac5915e4 100644 ---
a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -21,7 +21,6 @@ from _test_utils.import_helper import skip_if_no_megatron from _test_utils.torch.distributed.utils import spawn_multiprocess_job from _test_utils.torch.megatron.models import ( - MambaModel, MegatronModel, get_mcore_gpt_model, get_mcore_mamba_hybrid_model, @@ -29,6 +28,7 @@ from _test_utils.torch.megatron.utils import ( compare_amax_sync_across_expert_parallel, copy_weights_from_grouped_to_non_grouped, + get_forward, initialize_for_megatron, run_mcore_inference, sharded_state_dict_test_helper, @@ -69,47 +69,6 @@ SEED = 1234 -def get_batch(model, batch_size=2): - seq_length = model.max_sequence_length - vocab_size = model.vocab_size - - input_ids = torch.randint(0, vocab_size, (batch_size, seq_length)).cuda() - labels = torch.randint(0, vocab_size, (batch_size, seq_length)).cuda() - position_ids = ( - torch.arange(seq_length, dtype=torch.int64).unsqueeze(0).repeat(batch_size, 1).cuda() - ) - attention_mask = torch.tril( - torch.ones((batch_size, 1, seq_length, seq_length), dtype=torch.bool) - ).cuda() - loss_mask = torch.ones((batch_size, seq_length), dtype=torch.float32).cuda() - - return input_ids, labels, position_ids, attention_mask, loss_mask - - -def get_forward(model, batch_size=2): - """Return a forward function with cached batch inputs.""" - input_ids, labels, position_ids, attention_mask, loss_mask = get_batch(model, batch_size) - - def forward(model): - # MambaModel doesn't accept loss_mask argument - if isinstance(model, MambaModel): - return model.forward( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - labels=labels, - ) - return model.forward( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - labels=labels, - loss_mask=loss_mask, - ) - - return forward - - def test_convert_megatron_parallel_linear(distributed_setup_size_1): initialize_for_megatron(seed=SEED) set_seed(SEED) @@ -841,9 +800,6 @@ def _test_expert_model_parallel_amax_sync( ) # calibrate the model with distributed sync and test synchronization mtq.model_calib.max_calibrate(model, forward, distributed_sync=True) - for module in model.modules(): - if hasattr(module, "sync_moe_local_experts_amax"): - module.sync_moe_local_experts_amax() final_sync, quantizer_type, rank_values = compare_amax_sync_across_expert_parallel(model) assert final_sync, f"Inconsistent amax for expert {quantizer_type} across ranks: {rank_values}"
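For reference, a minimal usage sketch of how the pieces in this change fit together (illustrative only, not part of the diff): the cached-batch forward helper drives max calibration inside mtq.quantize, which now also triggers the per-layer MoE local-expert amax sync, and the quantized Megatron-Core model is then exported to a unified HF checkpoint whose config.json and hf_quant_config.json are expected to agree. The wrapper name quantize_and_export and the choice of NVFP4 plus FP8 KV-cache configs are assumptions for illustration.

# Illustrative sketch only -- not part of this diff. Assumes a Megatron-Core GPT model
# has already been built (e.g. via get_mcore_gpt_model in the test utilities).
import torch

import modelopt.torch.quantization as mtq
from modelopt.torch.export import export_mcore_gpt_to_hf
from _test_utils.torch.megatron.utils import get_forward


def quantize_and_export(model, export_dir, use_fp8_kv_cache=False):
    # Pick a weight/activation quant config; optionally fold in the FP8 KV-cache quantizers.
    quant_cfg = mtq.NVFP4_DEFAULT_CFG
    if use_fp8_kv_cache:
        kv_cfg = mtq.FP8_KV_CFG["quant_cfg"]
        quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant(quant_cfg, kv_cfg)

    # get_forward() caches a random batch and returns a forward-loop callable; max
    # calibration (including the Step 1 per-layer expert amax sync above) runs inside
    # mtq.quantize().
    forward = get_forward(model)
    model = mtq.quantize(model, quant_cfg, forward)

    # Export a unified HF checkpoint; config.json and hf_quant_config.json should then
    # agree on quant_algo and on the exclude_modules gathered across ranks.
    export_mcore_gpt_to_hf(model, None, dtype=torch.bfloat16, export_dir=str(export_dir))
    return model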