Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions mlx_lm/models/llama4.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,12 @@ def __init__(self, args):
self.top_k = args.num_experts_per_tok
assert self.top_k == 1, "Only 1 expert per token supported"
self.num_experts = args.num_local_experts
fuse = getattr(args, "fuse_gate_up", False)
self.experts = SwitchGLU(
args.hidden_size, args.intermediate_size, self.num_experts
args.hidden_size,
args.intermediate_size,
self.num_experts,
fuse_gate_up=fuse,
)
self.router = nn.Linear(args.hidden_size, args.num_local_experts, bias=False)
self.shared_expert = MLP(args)
Expand Down Expand Up @@ -296,15 +300,19 @@ def to_remove(k):
weights = {k: v for k, v in weights.items() if not to_remove(k)}

# Rename expert weights for SwitchGLU
fuse = getattr(self.args.text_config, "fuse_gate_up", False)
for l in range(self.args.text_config.num_hidden_layers):
prefix = f"language_model.model.layers.{l}.feed_forward.experts"
if f"{prefix}.gate_up_proj" in weights:
v = weights.pop(f"{prefix}.gate_up_proj")
gate_k = f"{prefix}.gate_proj.weight"
up_k = f"{prefix}.up_proj.weight"
gate_proj, up_proj = mx.split(v, 2, axis=-1)
weights[gate_k] = mx.swapaxes(gate_proj, 1, 2)
weights[up_k] = mx.swapaxes(up_proj, 1, 2)
v = mx.swapaxes(v, 1, 2)
if fuse:
weights[f"{prefix}.gate_up_proj.weight"] = v
else:
# Split fused gate_up into separate gate_proj and up_proj
gate, up = mx.split(v, 2, axis=1)
weights[f"{prefix}.gate_proj.weight"] = gate
weights[f"{prefix}.up_proj.weight"] = up
if f"{prefix}.down_proj" in weights:
down_proj = weights.pop(f"{prefix}.down_proj")
weights[f"{prefix}.down_proj.weight"] = mx.swapaxes(down_proj, 1, 2)
Expand Down
63 changes: 51 additions & 12 deletions mlx_lm/models/minimax.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,12 @@ def __init__(self, args: ModelArgs):
self.num_experts_per_tok = args.num_experts_per_tok

self.gate = nn.Linear(args.hidden_size, args.num_local_experts, bias=False)
fuse = getattr(args, "fuse_gate_up", False)
self.switch_mlp = SwitchGLU(
args.hidden_size, args.intermediate_size, args.num_local_experts
args.hidden_size,
args.intermediate_size,
args.num_local_experts,
fuse_gate_up=fuse,
)
self.e_score_correction_bias = mx.zeros((args.num_local_experts,))
self.sharding_group = None
Expand Down Expand Up @@ -309,20 +313,55 @@ def dequant(weight, scale_inv):
if "model.layers.0.block_sparse_moe.experts.0.w1.weight" not in weights:
return weights

fuse = getattr(self.args, "fuse_gate_up", False)
for l in range(self.args.num_hidden_layers):
prefix = f"model.layers.{l}"
mapping = {"w1": "gate_proj", "w2": "down_proj", "w3": "up_proj"}
for orig_name, new_name in mapping.items():
if f"{prefix}.block_sparse_moe.experts.0.{orig_name}.weight" in weights:
to_join = [
weights.pop(
f"{prefix}.block_sparse_moe.experts.{e}.{orig_name}.weight"
)
for e in range(self.args.num_local_experts)
]
w1_key = f"{prefix}.block_sparse_moe.experts.0.w1.weight"
w3_key = f"{prefix}.block_sparse_moe.experts.0.w3.weight"
if w1_key in weights and w3_key in weights:
if fuse:
# Stack and fuse gate(w1)+up(w3) into single gate_up_proj
gate = mx.stack(
[
weights.pop(
f"{prefix}.block_sparse_moe.experts.{e}.w1.weight"
)
for e in range(self.args.num_local_experts)
]
)
up = mx.stack(
[
weights.pop(
f"{prefix}.block_sparse_moe.experts.{e}.w3.weight"
)
for e in range(self.args.num_local_experts)
]
)
weights[
f"{prefix}.block_sparse_moe.switch_mlp.{new_name}.weight"
] = mx.stack(to_join)
f"{prefix}.block_sparse_moe.switch_mlp.gate_up_proj.weight"
] = mx.concatenate([gate, up], axis=1)
else:
# Unfused: separate gate_proj (w1) and up_proj (w3)
for n, wn in [("gate_proj", "w1"), ("up_proj", "w3")]:
to_join = [
weights.pop(
f"{prefix}.block_sparse_moe.experts.{e}.{wn}.weight"
)
for e in range(self.args.num_local_experts)
]
weights[
f"{prefix}.block_sparse_moe.switch_mlp.{n}.weight"
] = mx.stack(to_join)
# Stack down(w2) normally
w2_key = f"{prefix}.block_sparse_moe.experts.0.w2.weight"
if w2_key in weights:
to_join = [
weights.pop(f"{prefix}.block_sparse_moe.experts.{e}.w2.weight")
for e in range(self.args.num_local_experts)
]
weights[f"{prefix}.block_sparse_moe.switch_mlp.down_proj.weight"] = (
mx.stack(to_join)
)

return weights

Expand Down
65 changes: 54 additions & 11 deletions mlx_lm/models/mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,10 @@ def __init__(self, args: ModelArgs):
# gating
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

self.switch_mlp = SwitchGLU(self.hidden_dim, self.ffn_dim, self.num_experts)
fuse = getattr(args, "fuse_gate_up", False)
self.switch_mlp = SwitchGLU(
self.hidden_dim, self.ffn_dim, self.num_experts, fuse_gate_up=fuse
)

def __call__(self, x: mx.array) -> mx.array:
gates = self.gate(x)
Expand Down Expand Up @@ -207,20 +210,60 @@ def sanitize(self, weights):
weights.pop("lm_head.weight", None)
if "model.layers.0.block_sparse_moe.experts.0.w1.weight" not in weights:
return weights
fuse = getattr(self.args, "fuse_gate_up", False)
for l in range(self.args.num_hidden_layers):
prefix = f"model.layers.{l}"
for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
if fuse:
# Stack and fuse gate(w1)+up(w3) into single gate_up_proj
for k in ["weight", "scales", "biases"]:
if f"{prefix}.block_sparse_moe.experts.0.{n}.{k}" in weights:
to_join = [
weights.pop(
f"{prefix}.block_sparse_moe.experts.{e}.{n}.{k}"
)
for e in range(self.args.num_local_experts)
]
weights[f"{prefix}.block_sparse_moe.switch_mlp.{m}.{k}"] = (
mx.stack(to_join)
w1_key = f"{prefix}.block_sparse_moe.experts.0.w1.{k}"
w3_key = f"{prefix}.block_sparse_moe.experts.0.w3.{k}"
if w1_key in weights and w3_key in weights:
gate = mx.stack(
[
weights.pop(
f"{prefix}.block_sparse_moe.experts.{e}.w1.{k}"
)
for e in range(self.args.num_local_experts)
]
)
up = mx.stack(
[
weights.pop(
f"{prefix}.block_sparse_moe.experts.{e}.w3.{k}"
)
for e in range(self.args.num_local_experts)
]
)
weights[
f"{prefix}.block_sparse_moe.switch_mlp.gate_up_proj.{k}"
] = mx.concatenate([gate, up], axis=1)
else:
# Unfused: separate gate_proj (w1) and up_proj (w3)
for n, wn in [("gate_proj", "w1"), ("up_proj", "w3")]:
for k in ["weight", "scales", "biases"]:
src = f"{prefix}.block_sparse_moe.experts.0.{wn}.{k}"
if src in weights:
to_join = [
weights.pop(
f"{prefix}.block_sparse_moe.experts.{e}.{wn}.{k}"
)
for e in range(self.args.num_local_experts)
]
weights[
f"{prefix}.block_sparse_moe.switch_mlp.{n}.{k}"
] = mx.stack(to_join)
# Stack down(w2) normally
for k in ["weight", "scales", "biases"]:
w2_key = f"{prefix}.block_sparse_moe.experts.0.w2.{k}"
if w2_key in weights:
to_join = [
weights.pop(f"{prefix}.block_sparse_moe.experts.{e}.w2.{k}")
for e in range(self.args.num_local_experts)
]
weights[f"{prefix}.block_sparse_moe.switch_mlp.down_proj.{k}"] = (
mx.stack(to_join)
)
return weights

@property
Expand Down
59 changes: 52 additions & 7 deletions mlx_lm/models/olmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def __init__(self, args: ModelArgs):
args.intermediate_size,
self.num_experts,
bias=args.mlp_bias,
fuse_gate_up=getattr(args, "fuse_gate_up", False),
)

def __call__(self, x: mx.array) -> mx.array:
Expand Down Expand Up @@ -197,16 +198,60 @@ def __call__(
def sanitize(self, weights):
if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
return weights
fuse = getattr(self.args, "fuse_gate_up", False)
for l in range(self.args.num_hidden_layers):
prefix = f"model.layers.{l}"
for n in ["up_proj", "down_proj", "gate_proj"]:
if fuse and f"{prefix}.mlp.experts.0.gate_proj.weight" in weights:
# Stack and fuse gate+up into single gate_up_proj
for k in ["weight", "scales", "biases"]:
if f"{prefix}.mlp.experts.0.{n}.{k}" in weights:
to_join = [
weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}")
for e in range(self.args.num_experts)
]
weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join)
gate_key = f"{prefix}.mlp.experts.0.gate_proj.{k}"
up_key = f"{prefix}.mlp.experts.0.up_proj.{k}"
if gate_key in weights and up_key in weights:
gate = mx.stack(
[
weights.pop(
f"{prefix}.mlp.experts.{e}.gate_proj.{k}"
)
for e in range(self.args.num_experts)
]
)
up = mx.stack(
[
weights.pop(
f"{prefix}.mlp.experts.{e}.up_proj.{k}"
)
for e in range(self.args.num_experts)
]
)
weights[f"{prefix}.mlp.switch_mlp.gate_up_proj.{k}"] = (
mx.concatenate([gate, up], axis=1)
)
else:
# Unfused: separate gate_proj and up_proj
for n in ["gate_proj", "up_proj"]:
for k in ["weight", "scales", "biases"]:
src = f"{prefix}.mlp.experts.0.{n}.{k}"
if src in weights:
to_join = [
weights.pop(
f"{prefix}.mlp.experts.{e}.{n}.{k}"
)
for e in range(self.args.num_experts)
]
weights[
f"{prefix}.mlp.switch_mlp.{n}.{k}"
] = mx.stack(to_join)
# Stack down_proj normally
for k in ["weight", "scales", "biases"]:
down_key = f"{prefix}.mlp.experts.0.down_proj.{k}"
if down_key in weights:
to_join = [
weights.pop(f"{prefix}.mlp.experts.{e}.down_proj.{k}")
for e in range(self.args.num_experts)
]
weights[f"{prefix}.mlp.switch_mlp.down_proj.{k}"] = mx.stack(
to_join
)
return weights

@property
Expand Down
23 changes: 17 additions & 6 deletions mlx_lm/models/qwen3_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,15 +498,26 @@ def _repeat(p):
shard_inplace(
layer.mlp.shared_expert.up_proj, "all-to-sharded", group=group
)
shard_inplace(
layer.mlp.switch_mlp.gate_proj, "all-to-sharded", group=group
)
if hasattr(layer.mlp.switch_mlp, "gate_up_proj"):
shard_inplace(
layer.mlp.switch_mlp.gate_up_proj,
"all-to-sharded",
group=group,
)
else:
shard_inplace(
layer.mlp.switch_mlp.gate_proj,
"all-to-sharded",
group=group,
)
shard_inplace(
layer.mlp.switch_mlp.up_proj,
"all-to-sharded",
group=group,
)
shard_inplace(
layer.mlp.switch_mlp.down_proj, "sharded-to-all", group=group
)
shard_inplace(
layer.mlp.switch_mlp.up_proj, "all-to-sharded", group=group
)

@property
def layers(self):
Expand Down
22 changes: 13 additions & 9 deletions mlx_lm/models/qwen3_5_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,24 @@ def sanitize(self, weights):
key = "language_model." + key
new_weights[key] = value

fuse = getattr(self.language_model.args, "fuse_gate_up", False)
for l in range(self.language_model.args.num_hidden_layers):
prefix = f"language_model.model.layers.{l}.mlp"
gate_up_key = f"{prefix}.experts.gate_up_proj"
if gate_up_key in new_weights:
gate_up = new_weights.pop(gate_up_key)
mid = gate_up.shape[-2] // 2
new_weights[f"{prefix}.switch_mlp.gate_proj.weight"] = gate_up[
..., :mid, :
]
new_weights[f"{prefix}.switch_mlp.up_proj.weight"] = gate_up[
..., mid:, :
]
new_weights[f"{prefix}.switch_mlp.down_proj.weight"] = new_weights.pop(
f"{prefix}.experts.down_proj"
if fuse:
new_weights[f"{prefix}.switch_mlp.gate_up_proj.weight"] = gate_up
else:
mid = gate_up.shape[-2] // 2
new_weights[f"{prefix}.switch_mlp.gate_proj.weight"] = gate_up[
..., :mid, :
]
new_weights[f"{prefix}.switch_mlp.up_proj.weight"] = gate_up[
..., mid:, :
]
new_weights[f"{prefix}.switch_mlp.down_proj.weight"] = (
new_weights.pop(f"{prefix}.experts.down_proj")
)

return self.language_model.sanitize(new_weights)
Loading