diff --git a/mlx_lm/models/llama4.py b/mlx_lm/models/llama4.py index e4e284d71..6732e7f62 100644 --- a/mlx_lm/models/llama4.py +++ b/mlx_lm/models/llama4.py @@ -155,8 +155,12 @@ def __init__(self, args): self.top_k = args.num_experts_per_tok assert self.top_k == 1, "Only 1 expert per token supported" self.num_experts = args.num_local_experts + fuse = getattr(args, "fuse_gate_up", False) self.experts = SwitchGLU( - args.hidden_size, args.intermediate_size, self.num_experts + args.hidden_size, + args.intermediate_size, + self.num_experts, + fuse_gate_up=fuse, ) self.router = nn.Linear(args.hidden_size, args.num_local_experts, bias=False) self.shared_expert = MLP(args) @@ -296,15 +300,19 @@ def to_remove(k): weights = {k: v for k, v in weights.items() if not to_remove(k)} # Rename expert weights for SwitchGLU + fuse = getattr(self.args.text_config, "fuse_gate_up", False) for l in range(self.args.text_config.num_hidden_layers): prefix = f"language_model.model.layers.{l}.feed_forward.experts" if f"{prefix}.gate_up_proj" in weights: v = weights.pop(f"{prefix}.gate_up_proj") - gate_k = f"{prefix}.gate_proj.weight" - up_k = f"{prefix}.up_proj.weight" - gate_proj, up_proj = mx.split(v, 2, axis=-1) - weights[gate_k] = mx.swapaxes(gate_proj, 1, 2) - weights[up_k] = mx.swapaxes(up_proj, 1, 2) + v = mx.swapaxes(v, 1, 2) + if fuse: + weights[f"{prefix}.gate_up_proj.weight"] = v + else: + # Split fused gate_up into separate gate_proj and up_proj + gate, up = mx.split(v, 2, axis=1) + weights[f"{prefix}.gate_proj.weight"] = gate + weights[f"{prefix}.up_proj.weight"] = up if f"{prefix}.down_proj" in weights: down_proj = weights.pop(f"{prefix}.down_proj") weights[f"{prefix}.down_proj.weight"] = mx.swapaxes(down_proj, 1, 2) diff --git a/mlx_lm/models/minimax.py b/mlx_lm/models/minimax.py index 9bf78d9a4..f1b15a16f 100644 --- a/mlx_lm/models/minimax.py +++ b/mlx_lm/models/minimax.py @@ -165,8 +165,12 @@ def __init__(self, args: ModelArgs): self.num_experts_per_tok = args.num_experts_per_tok self.gate = nn.Linear(args.hidden_size, args.num_local_experts, bias=False) + fuse = getattr(args, "fuse_gate_up", False) self.switch_mlp = SwitchGLU( - args.hidden_size, args.intermediate_size, args.num_local_experts + args.hidden_size, + args.intermediate_size, + args.num_local_experts, + fuse_gate_up=fuse, ) self.e_score_correction_bias = mx.zeros((args.num_local_experts,)) self.sharding_group = None @@ -309,20 +313,55 @@ def dequant(weight, scale_inv): if "model.layers.0.block_sparse_moe.experts.0.w1.weight" not in weights: return weights + fuse = getattr(self.args, "fuse_gate_up", False) for l in range(self.args.num_hidden_layers): prefix = f"model.layers.{l}" - mapping = {"w1": "gate_proj", "w2": "down_proj", "w3": "up_proj"} - for orig_name, new_name in mapping.items(): - if f"{prefix}.block_sparse_moe.experts.0.{orig_name}.weight" in weights: - to_join = [ - weights.pop( - f"{prefix}.block_sparse_moe.experts.{e}.{orig_name}.weight" - ) - for e in range(self.args.num_local_experts) - ] + w1_key = f"{prefix}.block_sparse_moe.experts.0.w1.weight" + w3_key = f"{prefix}.block_sparse_moe.experts.0.w3.weight" + if w1_key in weights and w3_key in weights: + if fuse: + # Stack and fuse gate(w1)+up(w3) into single gate_up_proj + gate = mx.stack( + [ + weights.pop( + f"{prefix}.block_sparse_moe.experts.{e}.w1.weight" + ) + for e in range(self.args.num_local_experts) + ] + ) + up = mx.stack( + [ + weights.pop( + f"{prefix}.block_sparse_moe.experts.{e}.w3.weight" + ) + for e in range(self.args.num_local_experts) + ] + ) weights[ - f"{prefix}.block_sparse_moe.switch_mlp.{new_name}.weight" - ] = mx.stack(to_join) + f"{prefix}.block_sparse_moe.switch_mlp.gate_up_proj.weight" + ] = mx.concatenate([gate, up], axis=1) + else: + # Unfused: separate gate_proj (w1) and up_proj (w3) + for n, wn in [("gate_proj", "w1"), ("up_proj", "w3")]: + to_join = [ + weights.pop( + f"{prefix}.block_sparse_moe.experts.{e}.{wn}.weight" + ) + for e in range(self.args.num_local_experts) + ] + weights[ + f"{prefix}.block_sparse_moe.switch_mlp.{n}.weight" + ] = mx.stack(to_join) + # Stack down(w2) normally + w2_key = f"{prefix}.block_sparse_moe.experts.0.w2.weight" + if w2_key in weights: + to_join = [ + weights.pop(f"{prefix}.block_sparse_moe.experts.{e}.w2.weight") + for e in range(self.args.num_local_experts) + ] + weights[f"{prefix}.block_sparse_moe.switch_mlp.down_proj.weight"] = ( + mx.stack(to_join) + ) return weights diff --git a/mlx_lm/models/mixtral.py b/mlx_lm/models/mixtral.py index 18d16644c..46e1d2dcb 100644 --- a/mlx_lm/models/mixtral.py +++ b/mlx_lm/models/mixtral.py @@ -105,7 +105,10 @@ def __init__(self, args: ModelArgs): # gating self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) - self.switch_mlp = SwitchGLU(self.hidden_dim, self.ffn_dim, self.num_experts) + fuse = getattr(args, "fuse_gate_up", False) + self.switch_mlp = SwitchGLU( + self.hidden_dim, self.ffn_dim, self.num_experts, fuse_gate_up=fuse + ) def __call__(self, x: mx.array) -> mx.array: gates = self.gate(x) @@ -207,20 +210,60 @@ def sanitize(self, weights): weights.pop("lm_head.weight", None) if "model.layers.0.block_sparse_moe.experts.0.w1.weight" not in weights: return weights + fuse = getattr(self.args, "fuse_gate_up", False) for l in range(self.args.num_hidden_layers): prefix = f"model.layers.{l}" - for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]: + if fuse: + # Stack and fuse gate(w1)+up(w3) into single gate_up_proj for k in ["weight", "scales", "biases"]: - if f"{prefix}.block_sparse_moe.experts.0.{n}.{k}" in weights: - to_join = [ - weights.pop( - f"{prefix}.block_sparse_moe.experts.{e}.{n}.{k}" - ) - for e in range(self.args.num_local_experts) - ] - weights[f"{prefix}.block_sparse_moe.switch_mlp.{m}.{k}"] = ( - mx.stack(to_join) + w1_key = f"{prefix}.block_sparse_moe.experts.0.w1.{k}" + w3_key = f"{prefix}.block_sparse_moe.experts.0.w3.{k}" + if w1_key in weights and w3_key in weights: + gate = mx.stack( + [ + weights.pop( + f"{prefix}.block_sparse_moe.experts.{e}.w1.{k}" + ) + for e in range(self.args.num_local_experts) + ] + ) + up = mx.stack( + [ + weights.pop( + f"{prefix}.block_sparse_moe.experts.{e}.w3.{k}" + ) + for e in range(self.args.num_local_experts) + ] ) + weights[ + f"{prefix}.block_sparse_moe.switch_mlp.gate_up_proj.{k}" + ] = mx.concatenate([gate, up], axis=1) + else: + # Unfused: separate gate_proj (w1) and up_proj (w3) + for n, wn in [("gate_proj", "w1"), ("up_proj", "w3")]: + for k in ["weight", "scales", "biases"]: + src = f"{prefix}.block_sparse_moe.experts.0.{wn}.{k}" + if src in weights: + to_join = [ + weights.pop( + f"{prefix}.block_sparse_moe.experts.{e}.{wn}.{k}" + ) + for e in range(self.args.num_local_experts) + ] + weights[ + f"{prefix}.block_sparse_moe.switch_mlp.{n}.{k}" + ] = mx.stack(to_join) + # Stack down(w2) normally + for k in ["weight", "scales", "biases"]: + w2_key = f"{prefix}.block_sparse_moe.experts.0.w2.{k}" + if w2_key in weights: + to_join = [ + weights.pop(f"{prefix}.block_sparse_moe.experts.{e}.w2.{k}") + for e in range(self.args.num_local_experts) + ] + weights[f"{prefix}.block_sparse_moe.switch_mlp.down_proj.{k}"] = ( + mx.stack(to_join) + ) return weights @property diff --git a/mlx_lm/models/olmoe.py b/mlx_lm/models/olmoe.py index c971c2247..02ac21637 100644 --- a/mlx_lm/models/olmoe.py +++ b/mlx_lm/models/olmoe.py @@ -106,6 +106,7 @@ def __init__(self, args: ModelArgs): args.intermediate_size, self.num_experts, bias=args.mlp_bias, + fuse_gate_up=getattr(args, "fuse_gate_up", False), ) def __call__(self, x: mx.array) -> mx.array: @@ -197,16 +198,60 @@ def __call__( def sanitize(self, weights): if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights: return weights + fuse = getattr(self.args, "fuse_gate_up", False) for l in range(self.args.num_hidden_layers): prefix = f"model.layers.{l}" - for n in ["up_proj", "down_proj", "gate_proj"]: + if fuse and f"{prefix}.mlp.experts.0.gate_proj.weight" in weights: + # Stack and fuse gate+up into single gate_up_proj for k in ["weight", "scales", "biases"]: - if f"{prefix}.mlp.experts.0.{n}.{k}" in weights: - to_join = [ - weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}") - for e in range(self.args.num_experts) - ] - weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join) + gate_key = f"{prefix}.mlp.experts.0.gate_proj.{k}" + up_key = f"{prefix}.mlp.experts.0.up_proj.{k}" + if gate_key in weights and up_key in weights: + gate = mx.stack( + [ + weights.pop( + f"{prefix}.mlp.experts.{e}.gate_proj.{k}" + ) + for e in range(self.args.num_experts) + ] + ) + up = mx.stack( + [ + weights.pop( + f"{prefix}.mlp.experts.{e}.up_proj.{k}" + ) + for e in range(self.args.num_experts) + ] + ) + weights[f"{prefix}.mlp.switch_mlp.gate_up_proj.{k}"] = ( + mx.concatenate([gate, up], axis=1) + ) + else: + # Unfused: separate gate_proj and up_proj + for n in ["gate_proj", "up_proj"]: + for k in ["weight", "scales", "biases"]: + src = f"{prefix}.mlp.experts.0.{n}.{k}" + if src in weights: + to_join = [ + weights.pop( + f"{prefix}.mlp.experts.{e}.{n}.{k}" + ) + for e in range(self.args.num_experts) + ] + weights[ + f"{prefix}.mlp.switch_mlp.{n}.{k}" + ] = mx.stack(to_join) + # Stack down_proj normally + for k in ["weight", "scales", "biases"]: + down_key = f"{prefix}.mlp.experts.0.down_proj.{k}" + if down_key in weights: + to_join = [ + weights.pop(f"{prefix}.mlp.experts.{e}.down_proj.{k}") + for e in range(self.args.num_experts) + ] + weights[f"{prefix}.mlp.switch_mlp.down_proj.{k}"] = mx.stack( + to_join + ) return weights @property diff --git a/mlx_lm/models/qwen3_5.py b/mlx_lm/models/qwen3_5.py index 43aadba23..7663dd357 100644 --- a/mlx_lm/models/qwen3_5.py +++ b/mlx_lm/models/qwen3_5.py @@ -498,15 +498,26 @@ def _repeat(p): shard_inplace( layer.mlp.shared_expert.up_proj, "all-to-sharded", group=group ) - shard_inplace( - layer.mlp.switch_mlp.gate_proj, "all-to-sharded", group=group - ) + if hasattr(layer.mlp.switch_mlp, "gate_up_proj"): + shard_inplace( + layer.mlp.switch_mlp.gate_up_proj, + "all-to-sharded", + group=group, + ) + else: + shard_inplace( + layer.mlp.switch_mlp.gate_proj, + "all-to-sharded", + group=group, + ) + shard_inplace( + layer.mlp.switch_mlp.up_proj, + "all-to-sharded", + group=group, + ) shard_inplace( layer.mlp.switch_mlp.down_proj, "sharded-to-all", group=group ) - shard_inplace( - layer.mlp.switch_mlp.up_proj, "all-to-sharded", group=group - ) @property def layers(self): diff --git a/mlx_lm/models/qwen3_5_moe.py b/mlx_lm/models/qwen3_5_moe.py index 53ab8530e..c9f5ece7d 100644 --- a/mlx_lm/models/qwen3_5_moe.py +++ b/mlx_lm/models/qwen3_5_moe.py @@ -33,20 +33,24 @@ def sanitize(self, weights): key = "language_model." + key new_weights[key] = value + fuse = getattr(self.language_model.args, "fuse_gate_up", False) for l in range(self.language_model.args.num_hidden_layers): prefix = f"language_model.model.layers.{l}.mlp" gate_up_key = f"{prefix}.experts.gate_up_proj" if gate_up_key in new_weights: gate_up = new_weights.pop(gate_up_key) - mid = gate_up.shape[-2] // 2 - new_weights[f"{prefix}.switch_mlp.gate_proj.weight"] = gate_up[ - ..., :mid, : - ] - new_weights[f"{prefix}.switch_mlp.up_proj.weight"] = gate_up[ - ..., mid:, : - ] - new_weights[f"{prefix}.switch_mlp.down_proj.weight"] = new_weights.pop( - f"{prefix}.experts.down_proj" + if fuse: + new_weights[f"{prefix}.switch_mlp.gate_up_proj.weight"] = gate_up + else: + mid = gate_up.shape[-2] // 2 + new_weights[f"{prefix}.switch_mlp.gate_proj.weight"] = gate_up[ + ..., :mid, : + ] + new_weights[f"{prefix}.switch_mlp.up_proj.weight"] = gate_up[ + ..., mid:, : + ] + new_weights[f"{prefix}.switch_mlp.down_proj.weight"] = ( + new_weights.pop(f"{prefix}.experts.down_proj") ) return self.language_model.sanitize(new_weights) diff --git a/mlx_lm/models/qwen3_moe.py b/mlx_lm/models/qwen3_moe.py index 52dc50f9b..269cc99cc 100644 --- a/mlx_lm/models/qwen3_moe.py +++ b/mlx_lm/models/qwen3_moe.py @@ -118,7 +118,10 @@ def __init__(self, args: ModelArgs): self.norm_topk_prob = args.norm_topk_prob self.gate = nn.Linear(dim, num_experts, bias=False) - self.switch_mlp = SwitchGLU(dim, intermediate_size, num_experts) + fuse = getattr(args, "fuse_gate_up", False) + self.switch_mlp = SwitchGLU( + dim, intermediate_size, num_experts, fuse_gate_up=fuse + ) def __call__( self, @@ -234,15 +237,43 @@ def sanitize(self, weights): weights.pop("lm_head.weight", None) if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights: return weights + fuse = getattr(self.args, "fuse_gate_up", False) for l in range(self.args.num_hidden_layers): prefix = f"model.layers.{l}" - for n in ["up_proj", "down_proj", "gate_proj"]: - if f"{prefix}.mlp.experts.0.{n}.weight" in weights: - to_join = [ - weights.pop(f"{prefix}.mlp.experts.{e}.{n}.weight") + if fuse and f"{prefix}.mlp.experts.0.gate_proj.weight" in weights: + gate = mx.stack( + [ + weights.pop(f"{prefix}.mlp.experts.{e}.gate_proj.weight") + for e in range(self.args.num_experts) + ] + ) + up = mx.stack( + [ + weights.pop(f"{prefix}.mlp.experts.{e}.up_proj.weight") for e in range(self.args.num_experts) ] - weights[f"{prefix}.mlp.switch_mlp.{n}.weight"] = mx.stack(to_join) + ) + weights[f"{prefix}.mlp.switch_mlp.gate_up_proj.weight"] = ( + mx.concatenate([gate, up], axis=1) + ) + if f"{prefix}.mlp.experts.0.down_proj.weight" in weights: + down = mx.stack( + [ + weights.pop(f"{prefix}.mlp.experts.{e}.down_proj.weight") + for e in range(self.args.num_experts) + ] + ) + weights[f"{prefix}.mlp.switch_mlp.down_proj.weight"] = down + else: + for n in ["up_proj", "down_proj", "gate_proj"]: + if f"{prefix}.mlp.experts.0.{n}.weight" in weights: + to_join = [ + weights.pop(f"{prefix}.mlp.experts.{e}.{n}.weight") + for e in range(self.args.num_experts) + ] + weights[f"{prefix}.mlp.switch_mlp.{n}.weight"] = mx.stack( + to_join + ) return weights @property diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index 39128a445..b8ab7334d 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -317,7 +317,10 @@ def __init__(self, args: ModelArgs): self.top_k = args.num_experts_per_tok self.gate = nn.Linear(dim, num_experts, bias=False) - self.switch_mlp = SwitchGLU(dim, intermediate_size, num_experts) + fuse = getattr(args, "fuse_gate_up", False) + self.switch_mlp = SwitchGLU( + dim, intermediate_size, num_experts, fuse_gate_up=fuse + ) self.shared_expert = Qwen3NextMLP(dim, shared_expert_intermediate_size) self.shared_expert_gate = nn.Linear(dim, 1, bias=False) @@ -457,14 +460,39 @@ def sanitize(self, weights): if self.args.tie_word_embeddings: weights.pop("lm_head.weight", None) + fuse = getattr(self.args, "fuse_gate_up", False) for l in range(self.args.num_hidden_layers): prefix = f"model.layers.{l}.mlp" - for n in ["up_proj", "down_proj", "gate_proj"]: - to_join = [ - weights.pop(f"{prefix}.experts.{e}.{n}.weight") - for e in range(self.args.num_experts) - ] - weights[f"{prefix}.switch_mlp.{n}.weight"] = mx.stack(to_join) + if fuse: + gate = mx.stack( + [ + weights.pop(f"{prefix}.experts.{e}.gate_proj.weight") + for e in range(self.args.num_experts) + ] + ) + up = mx.stack( + [ + weights.pop(f"{prefix}.experts.{e}.up_proj.weight") + for e in range(self.args.num_experts) + ] + ) + weights[f"{prefix}.switch_mlp.gate_up_proj.weight"] = ( + mx.concatenate([gate, up], axis=1) + ) + down = mx.stack( + [ + weights.pop(f"{prefix}.experts.{e}.down_proj.weight") + for e in range(self.args.num_experts) + ] + ) + weights[f"{prefix}.switch_mlp.down_proj.weight"] = down + else: + for n in ["up_proj", "down_proj", "gate_proj"]: + to_join = [ + weights.pop(f"{prefix}.experts.{e}.{n}.weight") + for e in range(self.args.num_experts) + ] + weights[f"{prefix}.switch_mlp.{n}.weight"] = mx.stack(to_join) norm_keys = ( ".input_layernorm.weight", diff --git a/mlx_lm/models/qwen3_vl_moe.py b/mlx_lm/models/qwen3_vl_moe.py index 7810f67b5..d864de7a9 100644 --- a/mlx_lm/models/qwen3_vl_moe.py +++ b/mlx_lm/models/qwen3_vl_moe.py @@ -50,18 +50,22 @@ def sanitize(self, weights): ) ) + fuse = getattr(self.language_model.args, "fuse_gate_up", False) for l in range(self.language_model.args.num_hidden_layers): prefix = f"language_model.model.layers.{l}.mlp" gate_up_key = f"{prefix}.experts.gate_up_proj" if gate_up_key in weights: - gate_up = weights.pop(gate_up_key) - mid = gate_up.shape[-1] // 2 - weights[f"{prefix}.switch_mlp.gate_proj.weight"] = gate_up[ - ..., :mid - ].swapaxes(-2, -1) - weights[f"{prefix}.switch_mlp.up_proj.weight"] = gate_up[ - ..., mid: - ].swapaxes(-2, -1) + gate_up = weights.pop(gate_up_key).swapaxes(-2, -1) + if fuse: + weights[f"{prefix}.switch_mlp.gate_up_proj.weight"] = gate_up + else: + mid = gate_up.shape[-2] // 2 + weights[f"{prefix}.switch_mlp.gate_proj.weight"] = gate_up[ + ..., :mid, : + ] + weights[f"{prefix}.switch_mlp.up_proj.weight"] = gate_up[ + ..., mid:, : + ] weights[f"{prefix}.switch_mlp.down_proj.weight"] = weights.pop( f"{prefix}.experts.down_proj" ).swapaxes(-2, -1) diff --git a/mlx_lm/models/switch_layers.py b/mlx_lm/models/switch_layers.py index 1fe5d917e..12e9cda78 100644 --- a/mlx_lm/models/switch_layers.py +++ b/mlx_lm/models/switch_layers.py @@ -165,11 +165,19 @@ def __init__( num_experts: int, activation=SwiGLU(), bias: bool = False, + fuse_gate_up: bool = False, ): super().__init__() - self.gate_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias) - self.up_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias) + if fuse_gate_up: + self.gate_up_proj = SwitchLinear( + input_dims, hidden_dims * 2, num_experts, bias=bias + ) + else: + self.gate_proj = SwitchLinear( + input_dims, hidden_dims, num_experts, bias=bias + ) + self.up_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias) self.down_proj = SwitchLinear(hidden_dims, input_dims, num_experts, bias=bias) self.activation = activation @@ -185,8 +193,14 @@ def __call__(self, x, indices) -> mx.array: x, idx, inv_order = _gather_sort(x, indices) if self.training: idx = mx.stop_gradient(idx) - x_up = self.up_proj(x, idx, sorted_indices=do_sort) - x_gate = self.gate_proj(x, idx, sorted_indices=do_sort) + + if "gate_up_proj" in self: + x_gate_up = self.gate_up_proj(x, idx, sorted_indices=do_sort) + x_gate, x_up = mx.split(x_gate_up, 2, axis=-1) + else: + x_up = self.up_proj(x, idx, sorted_indices=do_sort) + x_gate = self.gate_proj(x, idx, sorted_indices=do_sort) + x = self.down_proj( self.activation(x_up, x_gate), idx, diff --git a/tests/test_nemotron_latentmoe.py b/tests/test_nemotron_latentmoe.py new file mode 100644 index 000000000..7640a8c83 --- /dev/null +++ b/tests/test_nemotron_latentmoe.py @@ -0,0 +1,222 @@ +"""Tests for Nemotron-H LatentMoE support (PR #992). + +Tests the additions to nemotron_h.py: +- ModelArgs: moe_latent_size, layers_block_type normalization, time_step_limit defaults +- NemotronHMoE: latent projection forward pass +- Model.sanitize: MTP weight stripping +""" +import unittest + +import mlx.core as mx +import mlx.nn as nn + +from mlx_lm.models.nemotron_h import Model, ModelArgs, NemotronHMoE + + +class TestModelArgsLatentMoE(unittest.TestCase): + """Test ModelArgs parsing for Nemotron Super config fields.""" + + def _base_args(self, **overrides): + """Minimal valid config for nemotron_h with MoE layers.""" + cfg = { + "model_type": "nemotron_h", + "vocab_size": 1000, + "hidden_size": 128, + "intermediate_size": 64, + "num_hidden_layers": 4, + "max_position_embeddings": 1000, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "attention_bias": False, + "mamba_num_heads": 4, + "mamba_head_dim": 32, + "mamba_proj_bias": False, + "ssm_state_size": 32, + "conv_kernel": 4, + "n_groups": 2, + "time_step_min": 0.001, + "mlp_bias": False, + "layer_norm_epsilon": 1e-5, + "use_bias": False, + "use_conv_bias": True, + "hybrid_override_pattern": ["M", "E", "*", "E"], + "n_routed_experts": 8, + "num_experts_per_tok": 2, + "moe_intermediate_size": 64, + } + cfg.update(overrides) + return ModelArgs(**cfg) + + def test_moe_latent_size_parsed(self): + args = self._base_args(moe_latent_size=32) + self.assertEqual(args.moe_latent_size, 32) + + def test_moe_latent_size_none_by_default(self): + args = self._base_args() + self.assertIsNone(args.moe_latent_size) + + def test_layers_block_type_normalization(self): + """layers_block_type (word list) should normalize to hybrid_override_pattern (char list).""" + args = self._base_args( + hybrid_override_pattern=None, + layers_block_type=["mamba", "moe", "attention", "moe"], + ) + self.assertEqual(args.hybrid_override_pattern, ["M", "E", "*", "E"]) + self.assertEqual(args.num_hidden_layers, 4) + + def test_hybrid_override_pattern_string(self): + """Config from HuggingFace comes as a string, should work with iteration.""" + args = self._base_args(hybrid_override_pattern="ME*E") + # String iterates as chars, len() returns 4 + self.assertEqual(len(args.hybrid_override_pattern), 4) + self.assertEqual(list(args.hybrid_override_pattern), ["M", "E", "*", "E"]) + + def test_time_step_limit_no_upper_bound(self): + """time_step_limit should use inf upper bound when only time_step_min is set.""" + args = self._base_args(time_step_min=0.001) + self.assertEqual(args.time_step_limit[0], 0.001) + self.assertEqual(args.time_step_limit[1], float("inf")) + + def test_time_step_limit_explicit_overrides(self): + """Explicit time_step_limit should not be overwritten.""" + args = self._base_args(time_step_limit=(0.01, 0.5), time_step_min=0.001) + self.assertEqual(args.time_step_limit, (0.01, 0.5)) + + +class TestNemotronHMoELatent(unittest.TestCase): + """Test NemotronHMoE forward pass with latent projection.""" + + def _make_config(self, moe_latent_size=None): + return self._base_args(moe_latent_size=moe_latent_size) + + def _base_args(self, **overrides): + cfg = { + "model_type": "nemotron_h", + "vocab_size": 1000, + "hidden_size": 64, + "intermediate_size": 32, + "num_hidden_layers": 2, + "max_position_embeddings": 512, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "attention_bias": False, + "mamba_num_heads": 4, + "mamba_head_dim": 16, + "mamba_proj_bias": False, + "ssm_state_size": 16, + "conv_kernel": 4, + "n_groups": 2, + "time_step_min": 0.001, + "mlp_bias": False, + "layer_norm_epsilon": 1e-5, + "use_bias": False, + "use_conv_bias": True, + "hybrid_override_pattern": ["E", "E"], + "n_routed_experts": 4, + "num_experts_per_tok": 2, + "moe_intermediate_size": 32, + "n_group": 1, + "topk_group": 1, + "routed_scaling_factor": 1.0, + "norm_topk_prob": True, + } + cfg.update(overrides) + return ModelArgs(**cfg) + + def test_latent_projection_shapes(self): + """With moe_latent_size, experts should operate on latent dim.""" + config = self._make_config(moe_latent_size=16) + moe = NemotronHMoE(config) + mx.eval(moe.parameters()) + + # Input: (batch=1, seq=1, hidden=64) + x = mx.random.normal((1, 1, 64)) + y = moe(x) + mx.eval(y) + + # Output should match hidden_size, not latent size + self.assertEqual(y.shape, (1, 1, 64)) + + def test_no_latent_projection(self): + """Without moe_latent_size, experts operate at full hidden dim.""" + config = self._make_config(moe_latent_size=None) + moe = NemotronHMoE(config) + mx.eval(moe.parameters()) + + x = mx.random.normal((1, 1, 64)) + y = moe(x) + mx.eval(y) + self.assertEqual(y.shape, (1, 1, 64)) + + def test_latent_projection_has_layers(self): + """LatentMoE should have fc1/fc2 latent projection layers.""" + config = self._make_config(moe_latent_size=16) + moe = NemotronHMoE(config) + self.assertTrue(hasattr(moe, "fc1_latent_proj")) + self.assertTrue(hasattr(moe, "fc2_latent_proj")) + # fc1: hidden(64) -> latent(16) + self.assertEqual(moe.fc1_latent_proj.weight.shape, (16, 64)) + # fc2: latent(16) -> hidden(64) + self.assertEqual(moe.fc2_latent_proj.weight.shape, (64, 16)) + + def test_shared_expert_gets_original_input(self): + """Shared expert should receive the original residuals, not latent-projected input.""" + config = self._make_config(moe_latent_size=16) + config.n_shared_experts = 1 + config.moe_shared_expert_intermediate_size = 32 + moe = NemotronHMoE(config) + mx.eval(moe.parameters()) + + x = mx.random.normal((1, 1, 64)) + # Just verify it runs without error — the shared expert + # should accept hidden_size(64) input, not latent_size(16) + y = moe(x) + mx.eval(y) + self.assertEqual(y.shape, (1, 1, 64)) + + +class TestSanitizeMTP(unittest.TestCase): + """Test that sanitize() strips MTP weights.""" + + def test_mtp_weights_stripped(self): + config = ModelArgs( + model_type="nemotron_h", + vocab_size=100, + hidden_size=64, + intermediate_size=32, + num_hidden_layers=2, + max_position_embeddings=256, + num_attention_heads=4, + num_key_value_heads=2, + attention_bias=False, + mamba_num_heads=4, + mamba_head_dim=16, + mamba_proj_bias=False, + ssm_state_size=16, + conv_kernel=4, + n_groups=2, + time_step_min=0.001, + mlp_bias=False, + layer_norm_epsilon=1e-5, + use_bias=False, + use_conv_bias=True, + hybrid_override_pattern=["*", "M"], + ) + model = Model(config) + weights = { + "model.embed_tokens.weight": mx.zeros((100, 64)), + "model.layers.0.norm.weight": mx.zeros((64,)), + "mtp.layers.0.weight": mx.zeros((64, 64)), + "mtp.head.weight": mx.zeros((100, 64)), + } + sanitized = model.sanitize(weights) + # MTP weights should be removed + self.assertNotIn("mtp.layers.0.weight", sanitized) + self.assertNotIn("mtp.head.weight", sanitized) + # Non-MTP weights should be preserved + self.assertIn("model.embed_tokens.weight", sanitized) + self.assertIn("model.layers.0.norm.weight", sanitized) + + +if __name__ == "__main__": + unittest.main()