Skip to content
7 changes: 7 additions & 0 deletions mlx_lm/models/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@ def swiglu(gate, x):
return nn.silu(gate) * x


@partial(mx.compile, shapeless=True)
def precise_swiglu(h, gate, x):
gate = nn.silu(gate.astype(mx.float32))
x = x.astype(mx.float32)
return (gate * x).astype(h.dtype)
Comment on lines +14 to +18
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used the exact same signature that appeared in the Qwen 3 Next implementation; let me know if it's preferable to rid the h.



@partial(mx.compile, shapeless=True)
def xielu(x, alpha_p, alpha_n, beta, eps):
alpha_p = nn.softplus(alpha_p)
Expand Down
Loading