modelopt/torch/export/plugins/mcore_custom.py (24 additions, 0 deletions)
@@ -103,6 +103,18 @@ def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any]
)


class GroupedMLPMerging(CustomModuleMapping):
"""A custom module mapping that merges up_proj and down_proj for Grouped MLP."""

def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] = {}):
"""Create a custom module mapping that merges up_proj and down_proj for Grouped MLP."""
super().__init__(
func_name="grouped_mlp_merging",
target_name_or_prefix=target_name_or_prefix,
func_kwargs=func_kwargs,
)


class GatedMLPMerging(CustomModuleMapping):
"""A custom module mapping that merges gate_proj and up_proj."""

@@ -127,6 +139,18 @@ def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any]
)


class SelfAttentionScaling(CustomModuleMapping):
"""A custom module mapping that scales self attention."""

def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] = {}):
"""Create a custom module mapping that scales self attention."""
super().__init__(
func_name="self_attention_scaling",
target_name_or_prefix=target_name_or_prefix,
func_kwargs=func_kwargs,
)


class GatedMLPSlicing(CustomModuleMapping):
"""A custom module mapping that slices gate_proj and up_proj."""

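All three new classes follow the same pattern: a thin wrapper around CustomModuleMapping that only pins func_name, leaving the Hugging Face-side prefix and any extra kwargs to the call site. A minimal, self-contained sketch of how such a mapping could be consumed is below; MappingStub, EXPORT_FUNCS, and apply_mapping are illustrative assumptions, not the actual modelopt dispatch code.

from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass
class MappingStub:
    """Stand-in for CustomModuleMapping; the stored field names are assumptions."""

    func_name: str
    target_name_or_prefix: str
    func_kwargs: dict[str, Any] = field(default_factory=dict)


def self_attention_scaling(prefix: str, **kwargs: Any) -> None:
    # Placeholder handler: the real one would export attention scaling factors.
    print(f"export attention scaling factors under {prefix} (kwargs={kwargs})")


# Hypothetical registry keyed by the func_name that each subclass pins.
EXPORT_FUNCS: dict[str, Callable[..., None]] = {
    "self_attention_scaling": self_attention_scaling,
}


def apply_mapping(mapping: MappingStub, layer_idx: int) -> None:
    # Fill the layer index into the HF-side prefix, then dispatch by func_name.
    prefix = mapping.target_name_or_prefix.format(layer_idx)
    EXPORT_FUNCS[mapping.func_name](prefix, **mapping.func_kwargs)


apply_mapping(MappingStub("self_attention_scaling", "model.layers.{}.self_attn."), 3)
# -> export attention scaling factors under model.layers.3.self_attn. (kwargs={})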
modelopt/torch/export/plugins/mcore_llama.py (3 additions, 0 deletions)
@@ -30,6 +30,7 @@
PackNameRemapping,
QKVMerging,
QKVSlicing,
SelfAttentionScaling,
UnpackNameRemapping,
)

@@ -38,6 +39,8 @@
"input_layernorm": NameRemapping("model.layers.{}.input_layernorm."),
"linear_qkv": QKVSlicing("model.layers.{}.self_attn."),
"linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj."),
# KV cache quant export
"core_attention": SelfAttentionScaling("model.layers.{}.self_attn."),
"pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm."),
"linear_fc1": GatedMLPSlicing("model.layers.{}.mlp."),
"linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj."),
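The new core_attention entry routes the attention module through self_attention_scaling so that KV-cache quantization scales land under the same Hugging Face prefix as the rest of the attention weights. A quick illustration of how the prefix resolves; the layer index and the k_scale/v_scale key suffixes are assumptions for illustration, since the diff does not show which suffixes the export function actually writes.

# The per-layer target is a plain str.format template.
prefix = "model.layers.{}.self_attn.".format(5)
assert prefix == "model.layers.5.self_attn."

# Under KV cache quant export, scaling tensors would presumably land under this
# prefix, e.g. (suffixes are assumptions, not taken from the diff):
candidate_keys = [prefix + "k_proj.k_scale", prefix + "v_proj.v_scale"]
print(candidate_keys)
# ['model.layers.5.self_attn.k_proj.k_scale', 'model.layers.5.self_attn.v_proj.v_scale']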
modelopt/torch/export/plugins/mcore_nemotron.py (27 additions, 1 deletion)
@@ -23,9 +23,11 @@
ROW_ETP,
ROW_TP,
CustomModuleMapping,
GroupedMLPMerging,
NameRemapping,
QKVMerging,
QKVSlicing,
SelfAttentionScaling,
)

# Example of adding a new CausalLM.
@@ -35,6 +37,7 @@
"input_layernorm": NameRemapping("model.layers.{}.input_layernorm."),
"linear_qkv": QKVSlicing("model.layers.{}.self_attn."),
"linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj."),
"core_attention": SelfAttentionScaling("backbone.layers.{}.mixer."),
"pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm."),
# NemotronForCausalLM is using square-relu where no gated handle is needed.
"linear_fc1": NameRemapping("model.layers.{}.mlp.up_proj."),
@@ -81,9 +84,23 @@
"shared_experts.linear_fc2": NameRemapping(
"backbone.layers.{}.mixer.shared_experts.down_proj.", ROW_TP
),
# Latent MoE
"fc1_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc1_latent_proj.", REPLICATE),
"fc2_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc2_latent_proj.", REPLICATE),
# Repeated MTP module
"mtp.enorm": NameRemapping("mtp.layers.{}.enorm.", {"is_mtp": True}),
"mtp.hnorm": NameRemapping("mtp.layers.{}.hnorm.", {"is_mtp": True}),
"mtp.eh_proj": NameRemapping("mtp.layers.{}.eh_proj.", {"is_mtp": True}),
"mtp.final_layernorm": NameRemapping("mtp.layers.{}.final_layernorm.", {"is_mtp": True}),
# Grouped local experts in MTP
"experts.linear_fc1": GroupedMLPMerging(
"mtp.layers.{}.mixer.experts.{{}}.up_proj", COL_ETP | {"is_mtp": True}
),
"experts.linear_fc2": GroupedMLPMerging(
"mtp.layers.{}.mixer.experts.{{}}.down_proj", ROW_ETP | {"is_mtp": True}
),
}


nemotron_h_causal_lm_export: dict[str, CustomModuleMapping] = {
"word_embeddings": NameRemapping("backbone.embeddings."),
"final_norm": NameRemapping("backbone.norm_f."),
@@ -101,6 +118,7 @@
"input_layernorm": NameRemapping("backbone.layers.{}.norm."),
"linear_qkv": QKVSlicing("backbone.layers.{}.mixer."),
"linear_proj": NameRemapping("backbone.layers.{}.mixer.o_proj."),
"core_attention": SelfAttentionScaling("backbone.layers.{}.mixer."),
# MLP
"pre_mlp_layernorm": NameRemapping("backbone.layers.{}.norm."),
"linear_fc1": NameRemapping("backbone.layers.{}.mixer.up_proj."),
@@ -115,4 +133,12 @@
"shared_experts.linear_fc2": NameRemapping(
"backbone.layers.{}.mixer.shared_experts.down_proj."
),
# Latent MoE
"fc1_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc1_latent_proj."),
"fc2_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc2_latent_proj."),
# MTP
"mtp.enorm": NameRemapping("mtp.layers.{}.enorm."),
"mtp.hnorm": NameRemapping("mtp.layers.{}.hnorm."),
"mtp.eh_proj": NameRemapping("mtp.layers.{}.eh_proj."),
"mtp.final_layernorm": NameRemapping("mtp.layers.{}.final_layernorm."),
}
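Two idioms in the new Nemotron entries are easy to misread. The grouped-expert prefixes escape one pair of braces, so the first format call fills the layer index and leaves "{}" behind for a later expert-index substitution (presumably done by the export code that iterates over experts), and the COL_ETP | {"is_mtp": True} expressions are ordinary dict unions (PEP 584, Python 3.9+) that tag the parallelism kwargs with the MTP flag. A short demonstration; the ROW_ETP contents, layer index, and expert index are stand-ins:

# Stage 1: fill the layer index; the escaped "{{}}" survives as "{}".
template = "mtp.layers.{}.mixer.experts.{{}}.up_proj"
per_layer = template.format(0)
assert per_layer == "mtp.layers.0.mixer.experts.{}.up_proj"

# Stage 2: a later format call fills the expert index.
assert per_layer.format(7) == "mtp.layers.0.mixer.experts.7.up_proj"

# func_kwargs such as ROW_ETP | {"is_mtp": True} are plain dict unions.
ROW_ETP = {"parallel_mode": "row_etp"}  # stand-in; the real constant lives in mcore_custom.py
assert (ROW_ETP | {"is_mtp": True}) == {"parallel_mode": "row_etp", "is_mtp": True}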