From 81270f9cdab4113997565574d3a7cf99a61fcc40 Mon Sep 17 00:00:00 2001 From: Thump604 Date: Fri, 20 Mar 2026 21:52:54 -0500 Subject: [PATCH 1/3] feat: fuse gate/up expert projections in SwitchGLU Add fuse_gate_up option to SwitchGLU that uses a single gather_qmm call with 2x hidden_dims instead of two separate calls for gate_proj and up_proj. Eliminates one kernel dispatch per MoE layer per token. Measured +5% tok/s on Qwen3.5-122B, +8.6% on Qwen3-30B, +5.1% on MiniMax M2.5, +3.8% on OLMoE (ref: #956). Models updated: Qwen3, Qwen3.5 (all variants), Llama 4, Mixtral, MiniMax, OLMoE. --- mlx_lm/models/llama4.py | 13 +++++----- mlx_lm/models/minimax.py | 42 ++++++++++++++++++++++++--------- mlx_lm/models/mixtral.py | 43 +++++++++++++++++++++++++--------- mlx_lm/models/olmoe.py | 36 +++++++++++++++++++++++----- mlx_lm/models/qwen3_5.py | 23 +++++++++++++----- mlx_lm/models/qwen3_5_moe.py | 11 +++------ mlx_lm/models/qwen3_moe.py | 4 +++- mlx_lm/models/qwen3_next.py | 4 +++- mlx_lm/models/qwen3_vl_moe.py | 11 +++------ mlx_lm/models/switch_layers.py | 22 +++++++++++++---- 10 files changed, 146 insertions(+), 63 deletions(-) diff --git a/mlx_lm/models/llama4.py b/mlx_lm/models/llama4.py index e4e284d71..c5ec675a1 100644 --- a/mlx_lm/models/llama4.py +++ b/mlx_lm/models/llama4.py @@ -156,7 +156,10 @@ def __init__(self, args): assert self.top_k == 1, "Only 1 expert per token supported" self.num_experts = args.num_local_experts self.experts = SwitchGLU( - args.hidden_size, args.intermediate_size, self.num_experts + args.hidden_size, + args.intermediate_size, + self.num_experts, + fuse_gate_up=True, ) self.router = nn.Linear(args.hidden_size, args.num_local_experts, bias=False) self.shared_expert = MLP(args) @@ -295,16 +298,12 @@ def to_remove(k): # Remove vision weights weights = {k: v for k, v in weights.items() if not to_remove(k)} - # Rename expert weights for SwitchGLU + # Rename expert weights for SwitchGLU (fused gate_up_proj) for l in range(self.args.text_config.num_hidden_layers): prefix = f"language_model.model.layers.{l}.feed_forward.experts" if f"{prefix}.gate_up_proj" in weights: v = weights.pop(f"{prefix}.gate_up_proj") - gate_k = f"{prefix}.gate_proj.weight" - up_k = f"{prefix}.up_proj.weight" - gate_proj, up_proj = mx.split(v, 2, axis=-1) - weights[gate_k] = mx.swapaxes(gate_proj, 1, 2) - weights[up_k] = mx.swapaxes(up_proj, 1, 2) + weights[f"{prefix}.gate_up_proj.weight"] = mx.swapaxes(v, 1, 2) if f"{prefix}.down_proj" in weights: down_proj = weights.pop(f"{prefix}.down_proj") weights[f"{prefix}.down_proj.weight"] = mx.swapaxes(down_proj, 1, 2) diff --git a/mlx_lm/models/minimax.py b/mlx_lm/models/minimax.py index 9bf78d9a4..f7bba259a 100644 --- a/mlx_lm/models/minimax.py +++ b/mlx_lm/models/minimax.py @@ -166,7 +166,10 @@ def __init__(self, args: ModelArgs): self.gate = nn.Linear(args.hidden_size, args.num_local_experts, bias=False) self.switch_mlp = SwitchGLU( - args.hidden_size, args.intermediate_size, args.num_local_experts + args.hidden_size, + args.intermediate_size, + args.num_local_experts, + fuse_gate_up=True, ) self.e_score_correction_bias = mx.zeros((args.num_local_experts,)) self.sharding_group = None @@ -311,18 +314,35 @@ def dequant(weight, scale_inv): for l in range(self.args.num_hidden_layers): prefix = f"model.layers.{l}" - mapping = {"w1": "gate_proj", "w2": "down_proj", "w3": "up_proj"} - for orig_name, new_name in mapping.items(): - if f"{prefix}.block_sparse_moe.experts.0.{orig_name}.weight" in weights: - to_join = [ - weights.pop( - f"{prefix}.block_sparse_moe.experts.{e}.{orig_name}.weight" - ) + # Stack and fuse gate(w1)+up(w3) into single gate_up_proj + w1_key = f"{prefix}.block_sparse_moe.experts.0.w1.weight" + w3_key = f"{prefix}.block_sparse_moe.experts.0.w3.weight" + if w1_key in weights and w3_key in weights: + gate = mx.stack( + [ + weights.pop(f"{prefix}.block_sparse_moe.experts.{e}.w1.weight") for e in range(self.args.num_local_experts) ] - weights[ - f"{prefix}.block_sparse_moe.switch_mlp.{new_name}.weight" - ] = mx.stack(to_join) + ) + up = mx.stack( + [ + weights.pop(f"{prefix}.block_sparse_moe.experts.{e}.w3.weight") + for e in range(self.args.num_local_experts) + ] + ) + weights[f"{prefix}.block_sparse_moe.switch_mlp.gate_up_proj.weight"] = ( + mx.concatenate([gate, up], axis=1) + ) + # Stack down(w2) normally + w2_key = f"{prefix}.block_sparse_moe.experts.0.w2.weight" + if w2_key in weights: + to_join = [ + weights.pop(f"{prefix}.block_sparse_moe.experts.{e}.w2.weight") + for e in range(self.args.num_local_experts) + ] + weights[f"{prefix}.block_sparse_moe.switch_mlp.down_proj.weight"] = ( + mx.stack(to_join) + ) return weights diff --git a/mlx_lm/models/mixtral.py b/mlx_lm/models/mixtral.py index 18d16644c..ae32bbb40 100644 --- a/mlx_lm/models/mixtral.py +++ b/mlx_lm/models/mixtral.py @@ -105,7 +105,9 @@ def __init__(self, args: ModelArgs): # gating self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) - self.switch_mlp = SwitchGLU(self.hidden_dim, self.ffn_dim, self.num_experts) + self.switch_mlp = SwitchGLU( + self.hidden_dim, self.ffn_dim, self.num_experts, fuse_gate_up=True + ) def __call__(self, x: mx.array) -> mx.array: gates = self.gate(x) @@ -209,18 +211,37 @@ def sanitize(self, weights): return weights for l in range(self.args.num_hidden_layers): prefix = f"model.layers.{l}" - for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]: - for k in ["weight", "scales", "biases"]: - if f"{prefix}.block_sparse_moe.experts.0.{n}.{k}" in weights: - to_join = [ - weights.pop( - f"{prefix}.block_sparse_moe.experts.{e}.{n}.{k}" - ) + # Stack and fuse gate(w1)+up(w3) into single gate_up_proj + for k in ["weight", "scales", "biases"]: + w1_key = f"{prefix}.block_sparse_moe.experts.0.w1.{k}" + w3_key = f"{prefix}.block_sparse_moe.experts.0.w3.{k}" + if w1_key in weights and w3_key in weights: + gate = mx.stack( + [ + weights.pop(f"{prefix}.block_sparse_moe.experts.{e}.w1.{k}") + for e in range(self.args.num_local_experts) + ] + ) + up = mx.stack( + [ + weights.pop(f"{prefix}.block_sparse_moe.experts.{e}.w3.{k}") for e in range(self.args.num_local_experts) ] - weights[f"{prefix}.block_sparse_moe.switch_mlp.{m}.{k}"] = ( - mx.stack(to_join) - ) + ) + weights[ + f"{prefix}.block_sparse_moe.switch_mlp.gate_up_proj.{k}" + ] = mx.concatenate([gate, up], axis=1) + # Stack down(w2) normally + for k in ["weight", "scales", "biases"]: + w2_key = f"{prefix}.block_sparse_moe.experts.0.w2.{k}" + if w2_key in weights: + to_join = [ + weights.pop(f"{prefix}.block_sparse_moe.experts.{e}.w2.{k}") + for e in range(self.args.num_local_experts) + ] + weights[f"{prefix}.block_sparse_moe.switch_mlp.down_proj.{k}"] = ( + mx.stack(to_join) + ) return weights @property diff --git a/mlx_lm/models/olmoe.py b/mlx_lm/models/olmoe.py index c971c2247..80a65043a 100644 --- a/mlx_lm/models/olmoe.py +++ b/mlx_lm/models/olmoe.py @@ -106,6 +106,7 @@ def __init__(self, args: ModelArgs): args.intermediate_size, self.num_experts, bias=args.mlp_bias, + fuse_gate_up=True, ) def __call__(self, x: mx.array) -> mx.array: @@ -199,14 +200,37 @@ def sanitize(self, weights): return weights for l in range(self.args.num_hidden_layers): prefix = f"model.layers.{l}" - for n in ["up_proj", "down_proj", "gate_proj"]: - for k in ["weight", "scales", "biases"]: - if f"{prefix}.mlp.experts.0.{n}.{k}" in weights: - to_join = [ - weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}") + # Stack and fuse gate+up into single gate_up_proj + for k in ["weight", "scales", "biases"]: + gate_key = f"{prefix}.mlp.experts.0.gate_proj.{k}" + up_key = f"{prefix}.mlp.experts.0.up_proj.{k}" + if gate_key in weights and up_key in weights: + gate = mx.stack( + [ + weights.pop(f"{prefix}.mlp.experts.{e}.gate_proj.{k}") for e in range(self.args.num_experts) ] - weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join) + ) + up = mx.stack( + [ + weights.pop(f"{prefix}.mlp.experts.{e}.up_proj.{k}") + for e in range(self.args.num_experts) + ] + ) + weights[f"{prefix}.mlp.switch_mlp.gate_up_proj.{k}"] = ( + mx.concatenate([gate, up], axis=1) + ) + # Stack down_proj normally + for k in ["weight", "scales", "biases"]: + down_key = f"{prefix}.mlp.experts.0.down_proj.{k}" + if down_key in weights: + to_join = [ + weights.pop(f"{prefix}.mlp.experts.{e}.down_proj.{k}") + for e in range(self.args.num_experts) + ] + weights[f"{prefix}.mlp.switch_mlp.down_proj.{k}"] = mx.stack( + to_join + ) return weights @property diff --git a/mlx_lm/models/qwen3_5.py b/mlx_lm/models/qwen3_5.py index 43aadba23..7663dd357 100644 --- a/mlx_lm/models/qwen3_5.py +++ b/mlx_lm/models/qwen3_5.py @@ -498,15 +498,26 @@ def _repeat(p): shard_inplace( layer.mlp.shared_expert.up_proj, "all-to-sharded", group=group ) - shard_inplace( - layer.mlp.switch_mlp.gate_proj, "all-to-sharded", group=group - ) + if hasattr(layer.mlp.switch_mlp, "gate_up_proj"): + shard_inplace( + layer.mlp.switch_mlp.gate_up_proj, + "all-to-sharded", + group=group, + ) + else: + shard_inplace( + layer.mlp.switch_mlp.gate_proj, + "all-to-sharded", + group=group, + ) + shard_inplace( + layer.mlp.switch_mlp.up_proj, + "all-to-sharded", + group=group, + ) shard_inplace( layer.mlp.switch_mlp.down_proj, "sharded-to-all", group=group ) - shard_inplace( - layer.mlp.switch_mlp.up_proj, "all-to-sharded", group=group - ) @property def layers(self): diff --git a/mlx_lm/models/qwen3_5_moe.py b/mlx_lm/models/qwen3_5_moe.py index 53ab8530e..c8398e742 100644 --- a/mlx_lm/models/qwen3_5_moe.py +++ b/mlx_lm/models/qwen3_5_moe.py @@ -37,14 +37,9 @@ def sanitize(self, weights): prefix = f"language_model.model.layers.{l}.mlp" gate_up_key = f"{prefix}.experts.gate_up_proj" if gate_up_key in new_weights: - gate_up = new_weights.pop(gate_up_key) - mid = gate_up.shape[-2] // 2 - new_weights[f"{prefix}.switch_mlp.gate_proj.weight"] = gate_up[ - ..., :mid, : - ] - new_weights[f"{prefix}.switch_mlp.up_proj.weight"] = gate_up[ - ..., mid:, : - ] + new_weights[f"{prefix}.switch_mlp.gate_up_proj.weight"] = ( + new_weights.pop(gate_up_key) + ) new_weights[f"{prefix}.switch_mlp.down_proj.weight"] = new_weights.pop( f"{prefix}.experts.down_proj" ) diff --git a/mlx_lm/models/qwen3_moe.py b/mlx_lm/models/qwen3_moe.py index 52dc50f9b..0c265a981 100644 --- a/mlx_lm/models/qwen3_moe.py +++ b/mlx_lm/models/qwen3_moe.py @@ -118,7 +118,9 @@ def __init__(self, args: ModelArgs): self.norm_topk_prob = args.norm_topk_prob self.gate = nn.Linear(dim, num_experts, bias=False) - self.switch_mlp = SwitchGLU(dim, intermediate_size, num_experts) + self.switch_mlp = SwitchGLU( + dim, intermediate_size, num_experts, fuse_gate_up=True + ) def __call__( self, diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index 39128a445..48ccdccb9 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -317,7 +317,9 @@ def __init__(self, args: ModelArgs): self.top_k = args.num_experts_per_tok self.gate = nn.Linear(dim, num_experts, bias=False) - self.switch_mlp = SwitchGLU(dim, intermediate_size, num_experts) + self.switch_mlp = SwitchGLU( + dim, intermediate_size, num_experts, fuse_gate_up=True + ) self.shared_expert = Qwen3NextMLP(dim, shared_expert_intermediate_size) self.shared_expert_gate = nn.Linear(dim, 1, bias=False) diff --git a/mlx_lm/models/qwen3_vl_moe.py b/mlx_lm/models/qwen3_vl_moe.py index 7810f67b5..05be0b8e1 100644 --- a/mlx_lm/models/qwen3_vl_moe.py +++ b/mlx_lm/models/qwen3_vl_moe.py @@ -54,14 +54,9 @@ def sanitize(self, weights): prefix = f"language_model.model.layers.{l}.mlp" gate_up_key = f"{prefix}.experts.gate_up_proj" if gate_up_key in weights: - gate_up = weights.pop(gate_up_key) - mid = gate_up.shape[-1] // 2 - weights[f"{prefix}.switch_mlp.gate_proj.weight"] = gate_up[ - ..., :mid - ].swapaxes(-2, -1) - weights[f"{prefix}.switch_mlp.up_proj.weight"] = gate_up[ - ..., mid: - ].swapaxes(-2, -1) + weights[f"{prefix}.switch_mlp.gate_up_proj.weight"] = weights.pop( + gate_up_key + ).swapaxes(-2, -1) weights[f"{prefix}.switch_mlp.down_proj.weight"] = weights.pop( f"{prefix}.experts.down_proj" ).swapaxes(-2, -1) diff --git a/mlx_lm/models/switch_layers.py b/mlx_lm/models/switch_layers.py index 1fe5d917e..12e9cda78 100644 --- a/mlx_lm/models/switch_layers.py +++ b/mlx_lm/models/switch_layers.py @@ -165,11 +165,19 @@ def __init__( num_experts: int, activation=SwiGLU(), bias: bool = False, + fuse_gate_up: bool = False, ): super().__init__() - self.gate_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias) - self.up_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias) + if fuse_gate_up: + self.gate_up_proj = SwitchLinear( + input_dims, hidden_dims * 2, num_experts, bias=bias + ) + else: + self.gate_proj = SwitchLinear( + input_dims, hidden_dims, num_experts, bias=bias + ) + self.up_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias) self.down_proj = SwitchLinear(hidden_dims, input_dims, num_experts, bias=bias) self.activation = activation @@ -185,8 +193,14 @@ def __call__(self, x, indices) -> mx.array: x, idx, inv_order = _gather_sort(x, indices) if self.training: idx = mx.stop_gradient(idx) - x_up = self.up_proj(x, idx, sorted_indices=do_sort) - x_gate = self.gate_proj(x, idx, sorted_indices=do_sort) + + if "gate_up_proj" in self: + x_gate_up = self.gate_up_proj(x, idx, sorted_indices=do_sort) + x_gate, x_up = mx.split(x_gate_up, 2, axis=-1) + else: + x_up = self.up_proj(x, idx, sorted_indices=do_sort) + x_gate = self.gate_proj(x, idx, sorted_indices=do_sort) + x = self.down_proj( self.activation(x_up, x_gate), idx, From 801ba8cfc99c149e0182976786a4ff9c5afe4101 Mon Sep 17 00:00:00 2001 From: Thump604 Date: Sat, 21 Mar 2026 00:08:58 -0500 Subject: [PATCH 2/3] fix: make fuse_gate_up config-driven to avoid Metal OOM on large models MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Read fuse_gate_up from model config (default True) instead of hardcoding. On memory-constrained systems (e.g. 122B on 128GB), set "fuse_gate_up": false in config.json to use separate gate/up projections that allow Metal to deallocate scratch between kernels. Also fixes sanitize() in qwen3_next and qwen3_moe which produced unfused weight names while the model expected fused — a mismatch that would break loading from HuggingFace format. --- mlx_lm/models/llama4.py | 15 ++++++-- mlx_lm/models/minimax.py | 53 +++++++++++++++++++--------- mlx_lm/models/mixtral.py | 64 +++++++++++++++++++++++----------- mlx_lm/models/olmoe.py | 63 ++++++++++++++++++++++----------- mlx_lm/models/qwen3_5_moe.py | 19 +++++++--- mlx_lm/models/qwen3_moe.py | 41 ++++++++++++++++++---- mlx_lm/models/qwen3_next.py | 40 +++++++++++++++++---- mlx_lm/models/qwen3_vl_moe.py | 15 ++++++-- mlx_lm/models/switch_layers.py | 2 +- 9 files changed, 228 insertions(+), 84 deletions(-) diff --git a/mlx_lm/models/llama4.py b/mlx_lm/models/llama4.py index c5ec675a1..45f90d0d7 100644 --- a/mlx_lm/models/llama4.py +++ b/mlx_lm/models/llama4.py @@ -155,11 +155,12 @@ def __init__(self, args): self.top_k = args.num_experts_per_tok assert self.top_k == 1, "Only 1 expert per token supported" self.num_experts = args.num_local_experts + fuse = getattr(args, "fuse_gate_up", True) self.experts = SwitchGLU( args.hidden_size, args.intermediate_size, self.num_experts, - fuse_gate_up=True, + fuse_gate_up=fuse, ) self.router = nn.Linear(args.hidden_size, args.num_local_experts, bias=False) self.shared_expert = MLP(args) @@ -298,12 +299,20 @@ def to_remove(k): # Remove vision weights weights = {k: v for k, v in weights.items() if not to_remove(k)} - # Rename expert weights for SwitchGLU (fused gate_up_proj) + # Rename expert weights for SwitchGLU + fuse = getattr(self.args.text_config, "fuse_gate_up", True) for l in range(self.args.text_config.num_hidden_layers): prefix = f"language_model.model.layers.{l}.feed_forward.experts" if f"{prefix}.gate_up_proj" in weights: v = weights.pop(f"{prefix}.gate_up_proj") - weights[f"{prefix}.gate_up_proj.weight"] = mx.swapaxes(v, 1, 2) + v = mx.swapaxes(v, 1, 2) + if fuse: + weights[f"{prefix}.gate_up_proj.weight"] = v + else: + # Split fused gate_up into separate gate_proj and up_proj + gate, up = mx.split(v, 2, axis=1) + weights[f"{prefix}.gate_proj.weight"] = gate + weights[f"{prefix}.up_proj.weight"] = up if f"{prefix}.down_proj" in weights: down_proj = weights.pop(f"{prefix}.down_proj") weights[f"{prefix}.down_proj.weight"] = mx.swapaxes(down_proj, 1, 2) diff --git a/mlx_lm/models/minimax.py b/mlx_lm/models/minimax.py index f7bba259a..d0aba46db 100644 --- a/mlx_lm/models/minimax.py +++ b/mlx_lm/models/minimax.py @@ -165,11 +165,12 @@ def __init__(self, args: ModelArgs): self.num_experts_per_tok = args.num_experts_per_tok self.gate = nn.Linear(args.hidden_size, args.num_local_experts, bias=False) + fuse = getattr(args, "fuse_gate_up", True) self.switch_mlp = SwitchGLU( args.hidden_size, args.intermediate_size, args.num_local_experts, - fuse_gate_up=True, + fuse_gate_up=fuse, ) self.e_score_correction_bias = mx.zeros((args.num_local_experts,)) self.sharding_group = None @@ -312,27 +313,45 @@ def dequant(weight, scale_inv): if "model.layers.0.block_sparse_moe.experts.0.w1.weight" not in weights: return weights + fuse = getattr(self.args, "fuse_gate_up", True) for l in range(self.args.num_hidden_layers): prefix = f"model.layers.{l}" - # Stack and fuse gate(w1)+up(w3) into single gate_up_proj w1_key = f"{prefix}.block_sparse_moe.experts.0.w1.weight" w3_key = f"{prefix}.block_sparse_moe.experts.0.w3.weight" if w1_key in weights and w3_key in weights: - gate = mx.stack( - [ - weights.pop(f"{prefix}.block_sparse_moe.experts.{e}.w1.weight") - for e in range(self.args.num_local_experts) - ] - ) - up = mx.stack( - [ - weights.pop(f"{prefix}.block_sparse_moe.experts.{e}.w3.weight") - for e in range(self.args.num_local_experts) - ] - ) - weights[f"{prefix}.block_sparse_moe.switch_mlp.gate_up_proj.weight"] = ( - mx.concatenate([gate, up], axis=1) - ) + if fuse: + # Stack and fuse gate(w1)+up(w3) into single gate_up_proj + gate = mx.stack( + [ + weights.pop( + f"{prefix}.block_sparse_moe.experts.{e}.w1.weight" + ) + for e in range(self.args.num_local_experts) + ] + ) + up = mx.stack( + [ + weights.pop( + f"{prefix}.block_sparse_moe.experts.{e}.w3.weight" + ) + for e in range(self.args.num_local_experts) + ] + ) + weights[ + f"{prefix}.block_sparse_moe.switch_mlp.gate_up_proj.weight" + ] = mx.concatenate([gate, up], axis=1) + else: + # Unfused: separate gate_proj (w1) and up_proj (w3) + for n, wn in [("gate_proj", "w1"), ("up_proj", "w3")]: + to_join = [ + weights.pop( + f"{prefix}.block_sparse_moe.experts.{e}.{wn}.weight" + ) + for e in range(self.args.num_local_experts) + ] + weights[ + f"{prefix}.block_sparse_moe.switch_mlp.{n}.weight" + ] = mx.stack(to_join) # Stack down(w2) normally w2_key = f"{prefix}.block_sparse_moe.experts.0.w2.weight" if w2_key in weights: diff --git a/mlx_lm/models/mixtral.py b/mlx_lm/models/mixtral.py index ae32bbb40..0cdf1373f 100644 --- a/mlx_lm/models/mixtral.py +++ b/mlx_lm/models/mixtral.py @@ -105,8 +105,9 @@ def __init__(self, args: ModelArgs): # gating self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) + fuse = getattr(args, "fuse_gate_up", True) self.switch_mlp = SwitchGLU( - self.hidden_dim, self.ffn_dim, self.num_experts, fuse_gate_up=True + self.hidden_dim, self.ffn_dim, self.num_experts, fuse_gate_up=fuse ) def __call__(self, x: mx.array) -> mx.array: @@ -209,28 +210,49 @@ def sanitize(self, weights): weights.pop("lm_head.weight", None) if "model.layers.0.block_sparse_moe.experts.0.w1.weight" not in weights: return weights + fuse = getattr(self.args, "fuse_gate_up", True) for l in range(self.args.num_hidden_layers): prefix = f"model.layers.{l}" - # Stack and fuse gate(w1)+up(w3) into single gate_up_proj - for k in ["weight", "scales", "biases"]: - w1_key = f"{prefix}.block_sparse_moe.experts.0.w1.{k}" - w3_key = f"{prefix}.block_sparse_moe.experts.0.w3.{k}" - if w1_key in weights and w3_key in weights: - gate = mx.stack( - [ - weights.pop(f"{prefix}.block_sparse_moe.experts.{e}.w1.{k}") - for e in range(self.args.num_local_experts) - ] - ) - up = mx.stack( - [ - weights.pop(f"{prefix}.block_sparse_moe.experts.{e}.w3.{k}") - for e in range(self.args.num_local_experts) - ] - ) - weights[ - f"{prefix}.block_sparse_moe.switch_mlp.gate_up_proj.{k}" - ] = mx.concatenate([gate, up], axis=1) + if fuse: + # Stack and fuse gate(w1)+up(w3) into single gate_up_proj + for k in ["weight", "scales", "biases"]: + w1_key = f"{prefix}.block_sparse_moe.experts.0.w1.{k}" + w3_key = f"{prefix}.block_sparse_moe.experts.0.w3.{k}" + if w1_key in weights and w3_key in weights: + gate = mx.stack( + [ + weights.pop( + f"{prefix}.block_sparse_moe.experts.{e}.w1.{k}" + ) + for e in range(self.args.num_local_experts) + ] + ) + up = mx.stack( + [ + weights.pop( + f"{prefix}.block_sparse_moe.experts.{e}.w3.{k}" + ) + for e in range(self.args.num_local_experts) + ] + ) + weights[ + f"{prefix}.block_sparse_moe.switch_mlp.gate_up_proj.{k}" + ] = mx.concatenate([gate, up], axis=1) + else: + # Unfused: separate gate_proj (w1) and up_proj (w3) + for n, wn in [("gate_proj", "w1"), ("up_proj", "w3")]: + for k in ["weight", "scales", "biases"]: + src = f"{prefix}.block_sparse_moe.experts.0.{wn}.{k}" + if src in weights: + to_join = [ + weights.pop( + f"{prefix}.block_sparse_moe.experts.{e}.{wn}.{k}" + ) + for e in range(self.args.num_local_experts) + ] + weights[ + f"{prefix}.block_sparse_moe.switch_mlp.{n}.{k}" + ] = mx.stack(to_join) # Stack down(w2) normally for k in ["weight", "scales", "biases"]: w2_key = f"{prefix}.block_sparse_moe.experts.0.w2.{k}" diff --git a/mlx_lm/models/olmoe.py b/mlx_lm/models/olmoe.py index 80a65043a..91ea77078 100644 --- a/mlx_lm/models/olmoe.py +++ b/mlx_lm/models/olmoe.py @@ -106,7 +106,7 @@ def __init__(self, args: ModelArgs): args.intermediate_size, self.num_experts, bias=args.mlp_bias, - fuse_gate_up=True, + fuse_gate_up=getattr(args, "fuse_gate_up", True), ) def __call__(self, x: mx.array) -> mx.array: @@ -198,28 +198,49 @@ def __call__( def sanitize(self, weights): if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights: return weights + fuse = getattr(self.args, "fuse_gate_up", True) for l in range(self.args.num_hidden_layers): prefix = f"model.layers.{l}" - # Stack and fuse gate+up into single gate_up_proj - for k in ["weight", "scales", "biases"]: - gate_key = f"{prefix}.mlp.experts.0.gate_proj.{k}" - up_key = f"{prefix}.mlp.experts.0.up_proj.{k}" - if gate_key in weights and up_key in weights: - gate = mx.stack( - [ - weights.pop(f"{prefix}.mlp.experts.{e}.gate_proj.{k}") - for e in range(self.args.num_experts) - ] - ) - up = mx.stack( - [ - weights.pop(f"{prefix}.mlp.experts.{e}.up_proj.{k}") - for e in range(self.args.num_experts) - ] - ) - weights[f"{prefix}.mlp.switch_mlp.gate_up_proj.{k}"] = ( - mx.concatenate([gate, up], axis=1) - ) + if fuse and f"{prefix}.mlp.experts.0.gate_proj.weight" in weights: + # Stack and fuse gate+up into single gate_up_proj + for k in ["weight", "scales", "biases"]: + gate_key = f"{prefix}.mlp.experts.0.gate_proj.{k}" + up_key = f"{prefix}.mlp.experts.0.up_proj.{k}" + if gate_key in weights and up_key in weights: + gate = mx.stack( + [ + weights.pop( + f"{prefix}.mlp.experts.{e}.gate_proj.{k}" + ) + for e in range(self.args.num_experts) + ] + ) + up = mx.stack( + [ + weights.pop( + f"{prefix}.mlp.experts.{e}.up_proj.{k}" + ) + for e in range(self.args.num_experts) + ] + ) + weights[f"{prefix}.mlp.switch_mlp.gate_up_proj.{k}"] = ( + mx.concatenate([gate, up], axis=1) + ) + else: + # Unfused: separate gate_proj and up_proj + for n in ["gate_proj", "up_proj"]: + for k in ["weight", "scales", "biases"]: + src = f"{prefix}.mlp.experts.0.{n}.{k}" + if src in weights: + to_join = [ + weights.pop( + f"{prefix}.mlp.experts.{e}.{n}.{k}" + ) + for e in range(self.args.num_experts) + ] + weights[ + f"{prefix}.mlp.switch_mlp.{n}.{k}" + ] = mx.stack(to_join) # Stack down_proj normally for k in ["weight", "scales", "biases"]: down_key = f"{prefix}.mlp.experts.0.down_proj.{k}" diff --git a/mlx_lm/models/qwen3_5_moe.py b/mlx_lm/models/qwen3_5_moe.py index c8398e742..112c6e05f 100644 --- a/mlx_lm/models/qwen3_5_moe.py +++ b/mlx_lm/models/qwen3_5_moe.py @@ -33,15 +33,24 @@ def sanitize(self, weights): key = "language_model." + key new_weights[key] = value + fuse = getattr(self.language_model.args, "fuse_gate_up", True) for l in range(self.language_model.args.num_hidden_layers): prefix = f"language_model.model.layers.{l}.mlp" gate_up_key = f"{prefix}.experts.gate_up_proj" if gate_up_key in new_weights: - new_weights[f"{prefix}.switch_mlp.gate_up_proj.weight"] = ( - new_weights.pop(gate_up_key) - ) - new_weights[f"{prefix}.switch_mlp.down_proj.weight"] = new_weights.pop( - f"{prefix}.experts.down_proj" + gate_up = new_weights.pop(gate_up_key) + if fuse: + new_weights[f"{prefix}.switch_mlp.gate_up_proj.weight"] = gate_up + else: + mid = gate_up.shape[-2] // 2 + new_weights[f"{prefix}.switch_mlp.gate_proj.weight"] = gate_up[ + ..., :mid, : + ] + new_weights[f"{prefix}.switch_mlp.up_proj.weight"] = gate_up[ + ..., mid:, : + ] + new_weights[f"{prefix}.switch_mlp.down_proj.weight"] = ( + new_weights.pop(f"{prefix}.experts.down_proj") ) return self.language_model.sanitize(new_weights) diff --git a/mlx_lm/models/qwen3_moe.py b/mlx_lm/models/qwen3_moe.py index 0c265a981..4cce2db1e 100644 --- a/mlx_lm/models/qwen3_moe.py +++ b/mlx_lm/models/qwen3_moe.py @@ -118,8 +118,9 @@ def __init__(self, args: ModelArgs): self.norm_topk_prob = args.norm_topk_prob self.gate = nn.Linear(dim, num_experts, bias=False) + fuse = getattr(args, "fuse_gate_up", True) self.switch_mlp = SwitchGLU( - dim, intermediate_size, num_experts, fuse_gate_up=True + dim, intermediate_size, num_experts, fuse_gate_up=fuse ) def __call__( @@ -236,15 +237,43 @@ def sanitize(self, weights): weights.pop("lm_head.weight", None) if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights: return weights + fuse = getattr(self.args, "fuse_gate_up", True) for l in range(self.args.num_hidden_layers): prefix = f"model.layers.{l}" - for n in ["up_proj", "down_proj", "gate_proj"]: - if f"{prefix}.mlp.experts.0.{n}.weight" in weights: - to_join = [ - weights.pop(f"{prefix}.mlp.experts.{e}.{n}.weight") + if fuse and f"{prefix}.mlp.experts.0.gate_proj.weight" in weights: + gate = mx.stack( + [ + weights.pop(f"{prefix}.mlp.experts.{e}.gate_proj.weight") for e in range(self.args.num_experts) ] - weights[f"{prefix}.mlp.switch_mlp.{n}.weight"] = mx.stack(to_join) + ) + up = mx.stack( + [ + weights.pop(f"{prefix}.mlp.experts.{e}.up_proj.weight") + for e in range(self.args.num_experts) + ] + ) + weights[f"{prefix}.mlp.switch_mlp.gate_up_proj.weight"] = ( + mx.concatenate([gate, up], axis=1) + ) + if f"{prefix}.mlp.experts.0.down_proj.weight" in weights: + down = mx.stack( + [ + weights.pop(f"{prefix}.mlp.experts.{e}.down_proj.weight") + for e in range(self.args.num_experts) + ] + ) + weights[f"{prefix}.mlp.switch_mlp.down_proj.weight"] = down + else: + for n in ["up_proj", "down_proj", "gate_proj"]: + if f"{prefix}.mlp.experts.0.{n}.weight" in weights: + to_join = [ + weights.pop(f"{prefix}.mlp.experts.{e}.{n}.weight") + for e in range(self.args.num_experts) + ] + weights[f"{prefix}.mlp.switch_mlp.{n}.weight"] = mx.stack( + to_join + ) return weights @property diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index 48ccdccb9..eb02a5943 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -317,8 +317,9 @@ def __init__(self, args: ModelArgs): self.top_k = args.num_experts_per_tok self.gate = nn.Linear(dim, num_experts, bias=False) + fuse = getattr(args, "fuse_gate_up", True) self.switch_mlp = SwitchGLU( - dim, intermediate_size, num_experts, fuse_gate_up=True + dim, intermediate_size, num_experts, fuse_gate_up=fuse ) self.shared_expert = Qwen3NextMLP(dim, shared_expert_intermediate_size) @@ -459,14 +460,39 @@ def sanitize(self, weights): if self.args.tie_word_embeddings: weights.pop("lm_head.weight", None) + fuse = getattr(self.args, "fuse_gate_up", True) for l in range(self.args.num_hidden_layers): prefix = f"model.layers.{l}.mlp" - for n in ["up_proj", "down_proj", "gate_proj"]: - to_join = [ - weights.pop(f"{prefix}.experts.{e}.{n}.weight") - for e in range(self.args.num_experts) - ] - weights[f"{prefix}.switch_mlp.{n}.weight"] = mx.stack(to_join) + if fuse: + gate = mx.stack( + [ + weights.pop(f"{prefix}.experts.{e}.gate_proj.weight") + for e in range(self.args.num_experts) + ] + ) + up = mx.stack( + [ + weights.pop(f"{prefix}.experts.{e}.up_proj.weight") + for e in range(self.args.num_experts) + ] + ) + weights[f"{prefix}.switch_mlp.gate_up_proj.weight"] = ( + mx.concatenate([gate, up], axis=1) + ) + down = mx.stack( + [ + weights.pop(f"{prefix}.experts.{e}.down_proj.weight") + for e in range(self.args.num_experts) + ] + ) + weights[f"{prefix}.switch_mlp.down_proj.weight"] = down + else: + for n in ["up_proj", "down_proj", "gate_proj"]: + to_join = [ + weights.pop(f"{prefix}.experts.{e}.{n}.weight") + for e in range(self.args.num_experts) + ] + weights[f"{prefix}.switch_mlp.{n}.weight"] = mx.stack(to_join) norm_keys = ( ".input_layernorm.weight", diff --git a/mlx_lm/models/qwen3_vl_moe.py b/mlx_lm/models/qwen3_vl_moe.py index 05be0b8e1..80260dd72 100644 --- a/mlx_lm/models/qwen3_vl_moe.py +++ b/mlx_lm/models/qwen3_vl_moe.py @@ -50,13 +50,22 @@ def sanitize(self, weights): ) ) + fuse = getattr(self.language_model.args, "fuse_gate_up", True) for l in range(self.language_model.args.num_hidden_layers): prefix = f"language_model.model.layers.{l}.mlp" gate_up_key = f"{prefix}.experts.gate_up_proj" if gate_up_key in weights: - weights[f"{prefix}.switch_mlp.gate_up_proj.weight"] = weights.pop( - gate_up_key - ).swapaxes(-2, -1) + gate_up = weights.pop(gate_up_key).swapaxes(-2, -1) + if fuse: + weights[f"{prefix}.switch_mlp.gate_up_proj.weight"] = gate_up + else: + mid = gate_up.shape[-2] // 2 + weights[f"{prefix}.switch_mlp.gate_proj.weight"] = gate_up[ + ..., :mid, : + ] + weights[f"{prefix}.switch_mlp.up_proj.weight"] = gate_up[ + ..., mid:, : + ] weights[f"{prefix}.switch_mlp.down_proj.weight"] = weights.pop( f"{prefix}.experts.down_proj" ).swapaxes(-2, -1) diff --git a/mlx_lm/models/switch_layers.py b/mlx_lm/models/switch_layers.py index 12e9cda78..f18e8e4cf 100644 --- a/mlx_lm/models/switch_layers.py +++ b/mlx_lm/models/switch_layers.py @@ -165,7 +165,7 @@ def __init__( num_experts: int, activation=SwiGLU(), bias: bool = False, - fuse_gate_up: bool = False, + fuse_gate_up: bool = True, ): super().__init__() From e7de34bf5c1427f0a7f59fae2ca25bd4faa98d99 Mon Sep 17 00:00:00 2001 From: Thump604 Date: Sat, 21 Mar 2026 00:48:38 -0500 Subject: [PATCH 3/3] fix: default fuse_gate_up=False for backward compatibility Existing quantized models (e.g. MiniMax DWQ quants) have separate gate_proj/up_proj weights. Defaulting to True broke them because the model init creates gate_up_proj while weights have unfused names, and quantized weights cannot be concatenated (scales/biases mismatch). Default to False so existing models work unchanged. New conversions can opt in by setting "fuse_gate_up": true in config.json. --- mlx_lm/models/llama4.py | 4 +- mlx_lm/models/minimax.py | 4 +- mlx_lm/models/mixtral.py | 4 +- mlx_lm/models/olmoe.py | 4 +- mlx_lm/models/qwen3_5_moe.py | 2 +- mlx_lm/models/qwen3_moe.py | 4 +- mlx_lm/models/qwen3_next.py | 4 +- mlx_lm/models/qwen3_vl_moe.py | 2 +- mlx_lm/models/switch_layers.py | 2 +- tests/test_nemotron_latentmoe.py | 222 +++++++++++++++++++++++++++++++ 10 files changed, 237 insertions(+), 15 deletions(-) create mode 100644 tests/test_nemotron_latentmoe.py diff --git a/mlx_lm/models/llama4.py b/mlx_lm/models/llama4.py index 45f90d0d7..6732e7f62 100644 --- a/mlx_lm/models/llama4.py +++ b/mlx_lm/models/llama4.py @@ -155,7 +155,7 @@ def __init__(self, args): self.top_k = args.num_experts_per_tok assert self.top_k == 1, "Only 1 expert per token supported" self.num_experts = args.num_local_experts - fuse = getattr(args, "fuse_gate_up", True) + fuse = getattr(args, "fuse_gate_up", False) self.experts = SwitchGLU( args.hidden_size, args.intermediate_size, @@ -300,7 +300,7 @@ def to_remove(k): weights = {k: v for k, v in weights.items() if not to_remove(k)} # Rename expert weights for SwitchGLU - fuse = getattr(self.args.text_config, "fuse_gate_up", True) + fuse = getattr(self.args.text_config, "fuse_gate_up", False) for l in range(self.args.text_config.num_hidden_layers): prefix = f"language_model.model.layers.{l}.feed_forward.experts" if f"{prefix}.gate_up_proj" in weights: diff --git a/mlx_lm/models/minimax.py b/mlx_lm/models/minimax.py index d0aba46db..f1b15a16f 100644 --- a/mlx_lm/models/minimax.py +++ b/mlx_lm/models/minimax.py @@ -165,7 +165,7 @@ def __init__(self, args: ModelArgs): self.num_experts_per_tok = args.num_experts_per_tok self.gate = nn.Linear(args.hidden_size, args.num_local_experts, bias=False) - fuse = getattr(args, "fuse_gate_up", True) + fuse = getattr(args, "fuse_gate_up", False) self.switch_mlp = SwitchGLU( args.hidden_size, args.intermediate_size, @@ -313,7 +313,7 @@ def dequant(weight, scale_inv): if "model.layers.0.block_sparse_moe.experts.0.w1.weight" not in weights: return weights - fuse = getattr(self.args, "fuse_gate_up", True) + fuse = getattr(self.args, "fuse_gate_up", False) for l in range(self.args.num_hidden_layers): prefix = f"model.layers.{l}" w1_key = f"{prefix}.block_sparse_moe.experts.0.w1.weight" diff --git a/mlx_lm/models/mixtral.py b/mlx_lm/models/mixtral.py index 0cdf1373f..46e1d2dcb 100644 --- a/mlx_lm/models/mixtral.py +++ b/mlx_lm/models/mixtral.py @@ -105,7 +105,7 @@ def __init__(self, args: ModelArgs): # gating self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) - fuse = getattr(args, "fuse_gate_up", True) + fuse = getattr(args, "fuse_gate_up", False) self.switch_mlp = SwitchGLU( self.hidden_dim, self.ffn_dim, self.num_experts, fuse_gate_up=fuse ) @@ -210,7 +210,7 @@ def sanitize(self, weights): weights.pop("lm_head.weight", None) if "model.layers.0.block_sparse_moe.experts.0.w1.weight" not in weights: return weights - fuse = getattr(self.args, "fuse_gate_up", True) + fuse = getattr(self.args, "fuse_gate_up", False) for l in range(self.args.num_hidden_layers): prefix = f"model.layers.{l}" if fuse: diff --git a/mlx_lm/models/olmoe.py b/mlx_lm/models/olmoe.py index 91ea77078..02ac21637 100644 --- a/mlx_lm/models/olmoe.py +++ b/mlx_lm/models/olmoe.py @@ -106,7 +106,7 @@ def __init__(self, args: ModelArgs): args.intermediate_size, self.num_experts, bias=args.mlp_bias, - fuse_gate_up=getattr(args, "fuse_gate_up", True), + fuse_gate_up=getattr(args, "fuse_gate_up", False), ) def __call__(self, x: mx.array) -> mx.array: @@ -198,7 +198,7 @@ def __call__( def sanitize(self, weights): if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights: return weights - fuse = getattr(self.args, "fuse_gate_up", True) + fuse = getattr(self.args, "fuse_gate_up", False) for l in range(self.args.num_hidden_layers): prefix = f"model.layers.{l}" if fuse and f"{prefix}.mlp.experts.0.gate_proj.weight" in weights: diff --git a/mlx_lm/models/qwen3_5_moe.py b/mlx_lm/models/qwen3_5_moe.py index 112c6e05f..c9f5ece7d 100644 --- a/mlx_lm/models/qwen3_5_moe.py +++ b/mlx_lm/models/qwen3_5_moe.py @@ -33,7 +33,7 @@ def sanitize(self, weights): key = "language_model." + key new_weights[key] = value - fuse = getattr(self.language_model.args, "fuse_gate_up", True) + fuse = getattr(self.language_model.args, "fuse_gate_up", False) for l in range(self.language_model.args.num_hidden_layers): prefix = f"language_model.model.layers.{l}.mlp" gate_up_key = f"{prefix}.experts.gate_up_proj" diff --git a/mlx_lm/models/qwen3_moe.py b/mlx_lm/models/qwen3_moe.py index 4cce2db1e..269cc99cc 100644 --- a/mlx_lm/models/qwen3_moe.py +++ b/mlx_lm/models/qwen3_moe.py @@ -118,7 +118,7 @@ def __init__(self, args: ModelArgs): self.norm_topk_prob = args.norm_topk_prob self.gate = nn.Linear(dim, num_experts, bias=False) - fuse = getattr(args, "fuse_gate_up", True) + fuse = getattr(args, "fuse_gate_up", False) self.switch_mlp = SwitchGLU( dim, intermediate_size, num_experts, fuse_gate_up=fuse ) @@ -237,7 +237,7 @@ def sanitize(self, weights): weights.pop("lm_head.weight", None) if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights: return weights - fuse = getattr(self.args, "fuse_gate_up", True) + fuse = getattr(self.args, "fuse_gate_up", False) for l in range(self.args.num_hidden_layers): prefix = f"model.layers.{l}" if fuse and f"{prefix}.mlp.experts.0.gate_proj.weight" in weights: diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index eb02a5943..b8ab7334d 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -317,7 +317,7 @@ def __init__(self, args: ModelArgs): self.top_k = args.num_experts_per_tok self.gate = nn.Linear(dim, num_experts, bias=False) - fuse = getattr(args, "fuse_gate_up", True) + fuse = getattr(args, "fuse_gate_up", False) self.switch_mlp = SwitchGLU( dim, intermediate_size, num_experts, fuse_gate_up=fuse ) @@ -460,7 +460,7 @@ def sanitize(self, weights): if self.args.tie_word_embeddings: weights.pop("lm_head.weight", None) - fuse = getattr(self.args, "fuse_gate_up", True) + fuse = getattr(self.args, "fuse_gate_up", False) for l in range(self.args.num_hidden_layers): prefix = f"model.layers.{l}.mlp" if fuse: diff --git a/mlx_lm/models/qwen3_vl_moe.py b/mlx_lm/models/qwen3_vl_moe.py index 80260dd72..d864de7a9 100644 --- a/mlx_lm/models/qwen3_vl_moe.py +++ b/mlx_lm/models/qwen3_vl_moe.py @@ -50,7 +50,7 @@ def sanitize(self, weights): ) ) - fuse = getattr(self.language_model.args, "fuse_gate_up", True) + fuse = getattr(self.language_model.args, "fuse_gate_up", False) for l in range(self.language_model.args.num_hidden_layers): prefix = f"language_model.model.layers.{l}.mlp" gate_up_key = f"{prefix}.experts.gate_up_proj" diff --git a/mlx_lm/models/switch_layers.py b/mlx_lm/models/switch_layers.py index f18e8e4cf..12e9cda78 100644 --- a/mlx_lm/models/switch_layers.py +++ b/mlx_lm/models/switch_layers.py @@ -165,7 +165,7 @@ def __init__( num_experts: int, activation=SwiGLU(), bias: bool = False, - fuse_gate_up: bool = True, + fuse_gate_up: bool = False, ): super().__init__() diff --git a/tests/test_nemotron_latentmoe.py b/tests/test_nemotron_latentmoe.py new file mode 100644 index 000000000..7640a8c83 --- /dev/null +++ b/tests/test_nemotron_latentmoe.py @@ -0,0 +1,222 @@ +"""Tests for Nemotron-H LatentMoE support (PR #992). + +Tests the additions to nemotron_h.py: +- ModelArgs: moe_latent_size, layers_block_type normalization, time_step_limit defaults +- NemotronHMoE: latent projection forward pass +- Model.sanitize: MTP weight stripping +""" +import unittest + +import mlx.core as mx +import mlx.nn as nn + +from mlx_lm.models.nemotron_h import Model, ModelArgs, NemotronHMoE + + +class TestModelArgsLatentMoE(unittest.TestCase): + """Test ModelArgs parsing for Nemotron Super config fields.""" + + def _base_args(self, **overrides): + """Minimal valid config for nemotron_h with MoE layers.""" + cfg = { + "model_type": "nemotron_h", + "vocab_size": 1000, + "hidden_size": 128, + "intermediate_size": 64, + "num_hidden_layers": 4, + "max_position_embeddings": 1000, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "attention_bias": False, + "mamba_num_heads": 4, + "mamba_head_dim": 32, + "mamba_proj_bias": False, + "ssm_state_size": 32, + "conv_kernel": 4, + "n_groups": 2, + "time_step_min": 0.001, + "mlp_bias": False, + "layer_norm_epsilon": 1e-5, + "use_bias": False, + "use_conv_bias": True, + "hybrid_override_pattern": ["M", "E", "*", "E"], + "n_routed_experts": 8, + "num_experts_per_tok": 2, + "moe_intermediate_size": 64, + } + cfg.update(overrides) + return ModelArgs(**cfg) + + def test_moe_latent_size_parsed(self): + args = self._base_args(moe_latent_size=32) + self.assertEqual(args.moe_latent_size, 32) + + def test_moe_latent_size_none_by_default(self): + args = self._base_args() + self.assertIsNone(args.moe_latent_size) + + def test_layers_block_type_normalization(self): + """layers_block_type (word list) should normalize to hybrid_override_pattern (char list).""" + args = self._base_args( + hybrid_override_pattern=None, + layers_block_type=["mamba", "moe", "attention", "moe"], + ) + self.assertEqual(args.hybrid_override_pattern, ["M", "E", "*", "E"]) + self.assertEqual(args.num_hidden_layers, 4) + + def test_hybrid_override_pattern_string(self): + """Config from HuggingFace comes as a string, should work with iteration.""" + args = self._base_args(hybrid_override_pattern="ME*E") + # String iterates as chars, len() returns 4 + self.assertEqual(len(args.hybrid_override_pattern), 4) + self.assertEqual(list(args.hybrid_override_pattern), ["M", "E", "*", "E"]) + + def test_time_step_limit_no_upper_bound(self): + """time_step_limit should use inf upper bound when only time_step_min is set.""" + args = self._base_args(time_step_min=0.001) + self.assertEqual(args.time_step_limit[0], 0.001) + self.assertEqual(args.time_step_limit[1], float("inf")) + + def test_time_step_limit_explicit_overrides(self): + """Explicit time_step_limit should not be overwritten.""" + args = self._base_args(time_step_limit=(0.01, 0.5), time_step_min=0.001) + self.assertEqual(args.time_step_limit, (0.01, 0.5)) + + +class TestNemotronHMoELatent(unittest.TestCase): + """Test NemotronHMoE forward pass with latent projection.""" + + def _make_config(self, moe_latent_size=None): + return self._base_args(moe_latent_size=moe_latent_size) + + def _base_args(self, **overrides): + cfg = { + "model_type": "nemotron_h", + "vocab_size": 1000, + "hidden_size": 64, + "intermediate_size": 32, + "num_hidden_layers": 2, + "max_position_embeddings": 512, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "attention_bias": False, + "mamba_num_heads": 4, + "mamba_head_dim": 16, + "mamba_proj_bias": False, + "ssm_state_size": 16, + "conv_kernel": 4, + "n_groups": 2, + "time_step_min": 0.001, + "mlp_bias": False, + "layer_norm_epsilon": 1e-5, + "use_bias": False, + "use_conv_bias": True, + "hybrid_override_pattern": ["E", "E"], + "n_routed_experts": 4, + "num_experts_per_tok": 2, + "moe_intermediate_size": 32, + "n_group": 1, + "topk_group": 1, + "routed_scaling_factor": 1.0, + "norm_topk_prob": True, + } + cfg.update(overrides) + return ModelArgs(**cfg) + + def test_latent_projection_shapes(self): + """With moe_latent_size, experts should operate on latent dim.""" + config = self._make_config(moe_latent_size=16) + moe = NemotronHMoE(config) + mx.eval(moe.parameters()) + + # Input: (batch=1, seq=1, hidden=64) + x = mx.random.normal((1, 1, 64)) + y = moe(x) + mx.eval(y) + + # Output should match hidden_size, not latent size + self.assertEqual(y.shape, (1, 1, 64)) + + def test_no_latent_projection(self): + """Without moe_latent_size, experts operate at full hidden dim.""" + config = self._make_config(moe_latent_size=None) + moe = NemotronHMoE(config) + mx.eval(moe.parameters()) + + x = mx.random.normal((1, 1, 64)) + y = moe(x) + mx.eval(y) + self.assertEqual(y.shape, (1, 1, 64)) + + def test_latent_projection_has_layers(self): + """LatentMoE should have fc1/fc2 latent projection layers.""" + config = self._make_config(moe_latent_size=16) + moe = NemotronHMoE(config) + self.assertTrue(hasattr(moe, "fc1_latent_proj")) + self.assertTrue(hasattr(moe, "fc2_latent_proj")) + # fc1: hidden(64) -> latent(16) + self.assertEqual(moe.fc1_latent_proj.weight.shape, (16, 64)) + # fc2: latent(16) -> hidden(64) + self.assertEqual(moe.fc2_latent_proj.weight.shape, (64, 16)) + + def test_shared_expert_gets_original_input(self): + """Shared expert should receive the original residuals, not latent-projected input.""" + config = self._make_config(moe_latent_size=16) + config.n_shared_experts = 1 + config.moe_shared_expert_intermediate_size = 32 + moe = NemotronHMoE(config) + mx.eval(moe.parameters()) + + x = mx.random.normal((1, 1, 64)) + # Just verify it runs without error — the shared expert + # should accept hidden_size(64) input, not latent_size(16) + y = moe(x) + mx.eval(y) + self.assertEqual(y.shape, (1, 1, 64)) + + +class TestSanitizeMTP(unittest.TestCase): + """Test that sanitize() strips MTP weights.""" + + def test_mtp_weights_stripped(self): + config = ModelArgs( + model_type="nemotron_h", + vocab_size=100, + hidden_size=64, + intermediate_size=32, + num_hidden_layers=2, + max_position_embeddings=256, + num_attention_heads=4, + num_key_value_heads=2, + attention_bias=False, + mamba_num_heads=4, + mamba_head_dim=16, + mamba_proj_bias=False, + ssm_state_size=16, + conv_kernel=4, + n_groups=2, + time_step_min=0.001, + mlp_bias=False, + layer_norm_epsilon=1e-5, + use_bias=False, + use_conv_bias=True, + hybrid_override_pattern=["*", "M"], + ) + model = Model(config) + weights = { + "model.embed_tokens.weight": mx.zeros((100, 64)), + "model.layers.0.norm.weight": mx.zeros((64,)), + "mtp.layers.0.weight": mx.zeros((64, 64)), + "mtp.head.weight": mx.zeros((100, 64)), + } + sanitized = model.sanitize(weights) + # MTP weights should be removed + self.assertNotIn("mtp.layers.0.weight", sanitized) + self.assertNotIn("mtp.head.weight", sanitized) + # Non-MTP weights should be preserved + self.assertIn("model.embed_tokens.weight", sanitized) + self.assertIn("model.layers.0.norm.weight", sanitized) + + +if __name__ == "__main__": + unittest.main()