[transformers-to-mlx skill] Add OLMo Hybrid model support#1023

Open

pcuenca wants to merge 9 commits into ml-explore:main from pcuenca:olmo-hybrid-v2

Conversation


@pcuenca pcuenca commented Mar 19, 2026

TL;DR: for reviewers

This PR was created with an in-progress conversion Skill that we'd like to contribute to the community. We are trying to make it comprehensive and careful, so it only considers the job done if the relevant tests pass. We have instructed it to pay attention to the usual problem areas, such as long-context generation (and RoPE), dtype/upcasting errors, and others.

Any feedback received in this PR will be applied to the Skill so it works better next time! Please do not hesitate to point out any observations, including, for example:

  • Missing tests
  • Noisy / misleading comments in the PR description
  • Deviations from mlx-lm best practices or idioms

Tests were run on M3 Ultra (512 GB).


Summary

  • Add MLX implementation of OLMo Hybrid (GatedDeltaNet + Full Attention) architecture
  • Supports all four official checkpoints: allenai/Olmo-Hybrid-7B, allenai/Olmo-Hybrid-Instruct-SFT-7B, allenai/Olmo-Hybrid-Instruct-DPO-7B, and allenai/Olmo-Hybrid-Think-SFT-7B
  • Reuses the existing gated_delta.py Metal kernel infrastructure for the linear attention layers
  • Supports conditional RoPE: base/SFT/Think-SFT models use RoPE (rope_theta=10000), DPO uses NoPE (no positional embeddings)

Architecture

OLMo Hybrid uses the same transformer architecture as OLMo 3 7B, except that 75% of layers use GatedDeltaNet linear attention heads instead of standard attention. The layer pattern alternates: 3 linear attention layers followed by 1 full attention layer.
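The 3:1 alternation can be sketched as follows. This is an illustrative helper, not the PR's actual code; the function name `hybrid_layer_types` is an assumption, though the `"linear_attention"`/`"full_attention"` type strings match the `layer_types` config field used elsewhere in this thread.

```python
# Illustrative sketch (not the PR's code): derive the 3-linear : 1-full
# layer pattern for a model with `num_layers` layers.
def hybrid_layer_types(num_layers: int, period: int = 4) -> list[str]:
    # The last layer in each group of `period` is full attention; the
    # remaining layers use GatedDeltaNet linear attention (75% for period=4).
    return [
        "full_attention" if (i + 1) % period == 0 else "linear_attention"
        for i in range(num_layers)
    ]

print(hybrid_layer_types(8))
```

For an 8-layer toy model this yields three `linear_attention` entries followed by one `full_attention` entry, repeated twice.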

Key implementation details:

  • Linear attention layers (GatedDeltaNet): separate q/k/v/a/b/g projections, per-projection depthwise conv1d, allow_neg_eigval scaling (beta × 2.0), RMSNormGated output normalization. Pre-norm residual style.
  • Full attention layers: standard MHA with q_norm/k_norm (RMSNorm on full projections before head split), optional RoPE. Post-norm residual style.
  • NoPE mode: the DPO variant sets rope_theta: null in rope_parameters, which disables RoPE entirely. This is handled by making RoPE conditional in the Attention class.
  • Float32 recurrent state: GatedDeltaNet state is initialized in float32 for numerical stability, matching the upstream gated_delta.py convention.
  • Weight sanitization: conv1d weight transposition and norm name remapping for compatibility with both original and transformers-format checkpoints.
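The conv1d weight transposition mentioned in the last bullet can be sketched like this (hypothetical helper in numpy, not the PR's exact code; it assumes MLX's channels-last `nn.Conv1d` weight convention):

```python
import numpy as np

# Illustrative sketch of the conv1d weight sanitization described above.
# transformers stores a depthwise conv1d weight as (channels, 1, kernel_size),
# while MLX's nn.Conv1d expects channels-last:
# (channels, kernel_size, in_channels // groups), i.e. (channels, kernel_size, 1)
# for a depthwise convolution.
def sanitize_conv_weight(w: np.ndarray) -> np.ndarray:
    if w.ndim == 3 and w.shape[-1] > w.shape[1]:
        # Still in the transformers (C, 1, K) layout: swap the last two axes.
        return np.swapaxes(w, 1, 2)
    return w  # already in the MLX (C, K, 1) layout

w_hf = np.zeros((512, 1, 4))  # depthwise conv weight, kernel size 4
print(sanitize_conv_weight(w_hf).shape)  # (512, 4, 1)
```

The shape check makes the helper idempotent, so it is safe to run on checkpoints that were already converted.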

Disclosure

This PR was created using the transformers-to-mlx skill with the assistance of an AI agent (Claude).

Test commands and results

Generation tests

python -m mlx_lm.generate --model allenai/Olmo-Hybrid-7B --prompt "The capital of France is" --max-tokens 50
==========
Paris, which is also the most populous city in the country. The country is divided into 18 regions,
96 departments, and 36,000 communes. The country is bordered by Belgium, Luxembourg, Germany,
Switzerland, Italy, Monaco
==========
Prompt: 5 tokens, 15.258 tokens-per-sec
Generation: 50 tokens, 44.137 tokens-per-sec
Peak memory: 14.965 GB

python -m mlx_lm.generate --model allenai/Olmo-Hybrid-Instruct-DPO-7B --prompt "Explain how photosynthesis works in 3 sentences." --max-tokens 200
==========
Photosynthesis is the process by which plants, algae, and some bacteria use sunlight to convert
carbon dioxide and water into glucose (a type of sugar) and oxygen. This process mainly occurs in
chloroplasts, where the green pigment chlorophyll captures sunlight energy. The energy from sunlight
is then used to power chemical reactions that store energy in glucose molecules, which the plant
uses for growth and energy.
==========
Prompt: 47 tokens, 387.416 tokens-per-sec
Generation: 81 tokens, 43.363 tokens-per-sec
Peak memory: 15.032 GB

python -m mlx_lm.generate --model allenai/Olmo-Hybrid-Think-SFT-7B --prompt "What is the square root of 144?" --max-tokens 500
==========
Okay, so I need to find the square root of 144. Hmm, let me think. The square root of a number is
a value that, when multiplied by itself, gives the original number. So, I need to find a number that
when multiplied by itself equals 144.

Let me start by recalling some basic squares. I know that 10 squared is 100, and 12 squared is 144.
Wait, is that right? Let me check. 12 times 12... 12 times 10 is 120, and 12 times 2 is 24, so
adding those together gives 144. Yeah, that's right. So 12 times 12 is 144. Therefore, the square
root of 144 should be 12.
[...]
==========
Prompt: 48 tokens, 412.082 tokens-per-sec
Generation: 500 tokens, 42.951 tokens-per-sec
Peak memory: 15.034 GB

Numerical comparison (MLX vs transformers bfloat16, 149-token prompt)

| Model | Max abs diff | Mean abs diff | < 0.1 | Top-1 | Top-5 | Top-10 |
|-----------|-------|-------|-------|---|-----|-------|
| Base (7B) | 1.063 | 0.048 | 78.0% | ✓ | 5/5 | 10/10 |
| SFT       | 0.500 | 0.044 | 77.3% | ✓ | 5/5 | 10/10 |
| DPO       | 0.250 | 0.051 | 81.4% | ✓ | 5/5 | 10/10 |
| Think-SFT | 0.500 | 0.044 | 77.3% | ✓ | 5/5 | 10/10 |

Per-position top-1 agreement (base model):

  pos   0: top1 match=True, max_diff=0.1094
  pos  37: top1 match=True, max_diff=0.1875
  pos  74: top1 match=True, max_diff=0.2500
  pos 148: top1 match=True, max_diff=0.1250

RoPE embedding verification

Direct comparison of RoPE cos/sin values between transformers and MLX at various positions:

| Position | cos max_diff | sin max_diff |
|----------|--------------|--------------|
| 0      | 0.00000000 | 0.00000000 |
| 10     | 0.00000048 | 0.00000024 |
| 100    | 0.00000380 | 0.00000703 |
| 1,000  | 0.00004953 | 0.00003564 |
| 10,000 | 0.00042987 | 0.00048790 |
| 30,000 | 0.00068420 | 0.00193958 |

Long generation test (16K tokens, base model)

Prompt: 40 tokens, 340.978 tokens-per-sec
Generation: 16000 tokens, 40.018 tokens-per-sec
Peak memory: 17.154 GB

🤖 Generated with Claude Code

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

pcuenca commented Mar 19, 2026

Full test report

Test 1: Generation

Base model:

python -m mlx_lm.generate --model allenai/Olmo-Hybrid-7B \
  --prompt "The capital of France is" --max-tokens 50
==========
Paris, which is also the most populous city in the country. The country is divided into 18 regions,
96 departments, and 36,000 communes. The country is bordered by Belgium, Luxembourg, Germany,
Switzerland, Italy, Monaco
==========
Prompt: 5 tokens, 15.258 tokens-per-sec
Generation: 50 tokens, 44.137 tokens-per-sec
Peak memory: 14.965 GB

DPO (instruct):

python -m mlx_lm.generate --model allenai/Olmo-Hybrid-Instruct-DPO-7B \
  --prompt "Explain how photosynthesis works in 3 sentences." --max-tokens 200
==========
Photosynthesis is the process by which plants, algae, and some bacteria use sunlight to convert
carbon dioxide and water into glucose (a type of sugar) and oxygen. This process mainly occurs in
chloroplasts, where the green pigment chlorophyll captures sunlight energy. The energy from sunlight
is then used to power chemical reactions that store energy in glucose molecules, which the plant
uses for growth and energy.
==========
Prompt: 47 tokens, 387.416 tokens-per-sec
Generation: 81 tokens, 43.363 tokens-per-sec
Peak memory: 15.032 GB

Think-SFT (chain-of-thought):

python -m mlx_lm.generate --model allenai/Olmo-Hybrid-Think-SFT-7B \
  --prompt "What is the square root of 144?" --max-tokens 500
==========
Okay, so I need to find the square root of 144. Hmm, let me think. The square root of a number is
a value that, when multiplied by itself, gives the original number. So, I need to find a number that
when multiplied by itself equals 144.

Let me start by recalling some basic squares. I know that 10 squared is 100, and 12 squared is 144.
Wait, is that right? Let me check. 12 times 12... 12 times 10 is 120, and 12 times 2 is 24, so
adding those together gives 144. Yeah, that's right. So 12 times 12 is 144. Therefore, the square
root of 144 should be 12.

[...confirms via prime factorization: 144 = 2^4 × 3^2, sqrt = 2^2 × 3 = 12...]
==========
Prompt: 48 tokens, 412.082 tokens-per-sec
Generation: 500 tokens, 42.951 tokens-per-sec
Peak memory: 15.034 GB

Test 2: Output dtype

Checks for dtype contamination (a stray float32 weight silently promoting the forward pass to float32). Uses the final norm weight as reference — norm weights are never quantized, so they retain the original dtype even after 4-bit quantization.

from pathlib import Path
import mlx.core as mx
from mlx_lm.utils import load_model

for name in ['olmo-hybrid-7b', 'olmo-hybrid-7b-sft', 'olmo-hybrid-7b-dpo',
             'olmo-hybrid-7b-think-sft', 'olmo-hybrid-7b-dpo-4bit']:
    model_path = Path(f'models/{name}')
    model, config = load_model(model_path)
    weight_dtype = model.model.norm.weight.dtype
    out = model(mx.array([[1]]))
    mx.eval(out)
    no_contamination = out.dtype != mx.float32
    matches_weights = out.dtype == weight_dtype
    print(f'{name}: weight_dtype={weight_dtype}, out_dtype={out.dtype}, '
          f'no_float32={no_contamination}, matches_weights={matches_weights}')
olmo-hybrid-7b:           weight_dtype=bfloat16, out_dtype=bfloat16, no_float32=True, matches_weights=True
olmo-hybrid-7b-sft:       weight_dtype=bfloat16, out_dtype=bfloat16, no_float32=True, matches_weights=True
olmo-hybrid-7b-dpo:       weight_dtype=bfloat16, out_dtype=bfloat16, no_float32=True, matches_weights=True
olmo-hybrid-7b-think-sft: weight_dtype=bfloat16, out_dtype=bfloat16, no_float32=True, matches_weights=True
olmo-hybrid-7b-dpo-4bit:  weight_dtype=bfloat16, out_dtype=bfloat16, no_float32=True, matches_weights=True

All pass: no float32 contamination.


Test 3: Numerical comparison (MLX vs transformers bfloat16)

149-token prompt, comparing logits from a single forward pass.

import os, gc, warnings
os.environ['TQDM_DISABLE'] = '1'
warnings.filterwarnings('ignore')
import numpy as np
import torch
import mlx.core as mx
from transformers import AutoTokenizer, AutoModelForCausalLM
from mlx_lm import load

prompt = ('The history of artificial intelligence spans decades of research, '
          'from early symbolic AI systems in the 1950s through the expert systems '
          'era of the 1980s, the AI winter, and the modern deep learning revolution. '
          'Today, large language models trained on vast corpora of text have demonstrated '
          'remarkable capabilities in understanding and generating human language, '
          'reasoning about complex problems, writing code, and engaging in creative tasks. '
          'These models, based on the transformer architecture introduced in the landmark '
          'Attention Is All You Need paper, use self-attention mechanisms to process '
          'sequences of tokens in parallel, enabling efficient training on massive datasets. '
          'The key innovation of transformers was replacing recurrent processing with '
          'attention, allowing the model to directly attend to any position in the input '
          'sequence regardless of distance.')

for name in ['olmo-hybrid-7b', 'olmo-hybrid-7b-sft', 'olmo-hybrid-7b-dpo',
             'olmo-hybrid-7b-think-sft']:
    model_path = f'models/{name}'

    hf_tokenizer = AutoTokenizer.from_pretrained(model_path)
    hf_model = AutoModelForCausalLM.from_pretrained(model_path, dtype=torch.bfloat16)
    hf_tokens = hf_tokenizer.encode(prompt)
    with torch.no_grad():
        hf_logits = hf_model(torch.tensor([hf_tokens])).logits.detach().to(torch.float32).numpy()
    del hf_model; gc.collect()

    mlx_model, mlx_tokenizer = load(model_path)
    mlx_tokens = mlx_tokenizer.encode(prompt)
    mlx_logits = mlx_model(mx.array([mlx_tokens])).astype(mx.float32)
    mx.eval(mlx_logits)
    mlx_logits_np = np.array(mlx_logits)
    del mlx_model; gc.collect()

    diff = np.abs(hf_logits - mlx_logits_np)

    hf_last = hf_logits[0, -1, :]
    mlx_last = mlx_logits_np[0, -1, :]
    hf_top = np.argsort(hf_last)[-10:][::-1]
    mlx_top = np.argsort(mlx_last)[-10:][::-1]
    top1 = hf_top[0] == mlx_top[0]
    top5 = len(set(hf_top[:5]) & set(mlx_top[:5]))
    top10 = len(set(hf_top) & set(mlx_top))
    frac_small = (diff < 0.1).mean()
    print(f'{name}: max={diff.max():.3f}, mean={diff.mean():.3f}, '
          f'<0.1={frac_small:.1%}, top1={top1}, top5={top5}/5, top10={top10}/10')

    # Per-position top-1 agreement
    for pos in [0, len(hf_tokens)//4, len(hf_tokens)//2, len(hf_tokens)-1]:
        hf_p = np.argsort(hf_logits[0, pos, :])[-1]
        mlx_p = np.argsort(mlx_logits_np[0, pos, :])[-1]
        print(f'  pos {pos:3d}: top1 match={hf_p == mlx_p}, '
              f'max_diff={diff[0, pos].max():.4f}')

| Model | Prompt tokens | Max abs diff | Mean abs diff | < 0.1 | Top-1 | Top-5 | Top-10 |
|-----------|-----|-------|-------|-------|---|-----|-------|
| Base (7B) | 149 | 1.063 | 0.048 | 78.0% | ✓ | 5/5 | 10/10 |
| SFT       | 149 | 0.375 | 0.063 | 68.6% | ✓ | 5/5 | 9/10  |
| DPO       | 149 | 0.375 | 0.039 | 77.1% | ✓ | 5/5 | 10/10 |
| Think-SFT | 149 | 0.500 | 0.044 | 77.3% | ✓ | 5/5 | 10/10 |

Per-position top-1 agreement (base model, sampled positions):

  pos   0: top1 match=True, max_diff=0.1094
  pos  37: top1 match=True, max_diff=0.1875
  pos  74: top1 match=True, max_diff=0.2500
  pos 148: top1 match=True, max_diff=0.1250

All models show excellent agreement, within expected bfloat16 cross-framework tolerance.


Test 4: RoPE embedding verification

Direct comparison of RoPE cos/sin rotation values between transformers and MLX at various sequence positions. This isolates RoPE correctness without needing long generations.

import numpy as np
import torch
import mlx.core as mx
from transformers import AutoModelForCausalLM
from mlx_lm import load

# Extract inv_freq from transformers
model = AutoModelForCausalLM.from_pretrained('models/olmo-hybrid-7b', dtype=torch.bfloat16)
tf_inv_freq = model.model.rotary_emb.inv_freq.to(torch.float32).numpy()
del model; import gc; gc.collect()

# Compute analytical cos/sin
head_dim = 128
theta = 10000.0
expected_inv_freq = 1.0 / (theta ** (np.arange(0, head_dim, 2, dtype=np.float32) / head_dim))
positions = [0, 10, 100, 1000, 5000, 10000, 30000]
tf_cos = np.array([np.cos(pos * expected_inv_freq) for pos in positions])
tf_sin = np.array([np.sin(pos * expected_inv_freq) for pos in positions])
assert np.allclose(tf_inv_freq, expected_inv_freq, rtol=1e-4)  # sanity check vs transformers

# Probe MLX RoPE using unit vectors
mlx_model, _ = load('models/olmo-hybrid-7b')
fa_layer = next(l for l in mlx_model.model.layers if hasattr(l, 'self_attn'))
rope = fa_layer.self_attn.rope
half = head_dim // 2

mlx_cos, mlx_sin = [], []
for pos in positions:
    cos_row, sin_row = [], []
    for d in range(half):
        q = mx.zeros((1, 1, 1, head_dim))
        q[..., d] = 1.0
        out = rope(q, offset=pos)
        mx.eval(out)
        out_np = np.array(out.astype(mx.float32)).reshape(-1)
        cos_row.append(out_np[d])
        sin_row.append(out_np[d + half])
    mlx_cos.append(cos_row)
    mlx_sin.append(sin_row)
mlx_cos, mlx_sin = np.array(mlx_cos), np.array(mlx_sin)

for i, pos in enumerate(positions):
    cos_diff = np.abs(mlx_cos[i] - tf_cos[i]).max()
    sin_diff = np.abs(mlx_sin[i] - tf_sin[i]).max()
    print(f'{pos:>6}: cos max_diff={cos_diff:.8f}, sin max_diff={sin_diff:.8f}')

| Position | cos max_diff | sin max_diff |
|----------|--------------|--------------|
| 0      | 0.00000000 | 0.00000000 |
| 10     | 0.00000048 | 0.00000024 |
| 100    | 0.00000380 | 0.00000703 |
| 1,000  | 0.00004953 | 0.00003564 |
| 10,000 | 0.00042987 | 0.00048790 |
| 30,000 | 0.00068420 | 0.00193958 |

RoPE embeddings match to high precision across the full position range (max error < 0.002 at position 30,000).


Test 5: Long sequence generation

Base model — 16K tokens:

python -m mlx_lm.generate --model allenai/Olmo-Hybrid-7B \
  --prompt "Once upon a time in a land far, far away, ..." --max-tokens 16000
Prompt: 40 tokens, 340.978 tokens-per-sec
Generation: 16000 tokens, 40.018 tokens-per-sec
Peak memory: 17.154 GB

Generates to the full 16K limit without crashes. The base model enters repetition after ~550 tokens (typical for greedy decoding on base models without EOS training). Verified same behavior in transformers — model behavior, not an MLX issue.

DPO — Space Invaders game (3955 tokens):

python -m mlx_lm.generate --model allenai/Olmo-Hybrid-Instruct-DPO-7B \
  --prompt "Create a complete Space Invaders game in a single HTML file ..." --max-tokens 16000
Prompt: 79 tokens, 521.612 tokens-per-sec
Generation: 3955 tokens, 41.951 tokens-per-sec
Peak memory: 15.492 GB

Generated a complete, coherent Space Invaders game (~530 lines of HTML/CSS/JS) and stopped naturally at EOS. No signs of RoPE-related degradation.


Phase 5: 4-bit quantization

python -m mlx_lm.convert --model models/olmo-hybrid-7b-dpo -q --q-bits 4 \
  --mlx-path models/olmo-hybrid-7b-dpo-4bit
[INFO] Quantized model with 4.502 bits per weight.

python -m mlx_lm.generate --model models/olmo-hybrid-7b-dpo-4bit \
  --prompt "Explain the theory of relativity in simple terms." --max-tokens 200
==========
Absolutely! Here's a simple explanation of **the theory of relativity**:

### **1. Special Relativity (1905)**
- **Key Idea:** The laws of physics are the same for everyone, no matter how fast they're moving—
  except when they're moving near the speed of light.
- **Time Dilation:** If you travel really fast, time slows down for you compared to someone
  standing still.
- **Length Contraction:** Things get shorter in the direction you're moving when you go very fast.
[...]
==========
Prompt: 47 tokens, 318.203 tokens-per-sec
Generation: 200 tokens, 108.482 tokens-per-sec
Peak memory: 4.430 GB

Quantization works correctly: coherent output, ~2.5x speed boost (108 vs 43 tok/s), memory drops from 15 GB to 4.4 GB. Dtype check passes (bfloat16 output, no float32 contamination).


Phase 6: Unit tests

Test config for test_configs in tests/test_models.py:

{
    "model_type": "olmo_hybrid",
    "hidden_size": 128,
    "num_hidden_layers": 4,
    "intermediate_size": 128,
    "num_attention_heads": 4,
    "num_key_value_heads": 4,
    "rms_norm_eps": 1e-6,
    "vocab_size": 1000,
    "max_position_embeddings": 1000,
    "rope_theta": 10000.0,
    "linear_num_key_heads": 4,
    "linear_num_value_heads": 4,
    "linear_key_head_dim": 32,
    "linear_value_head_dim": 64,
    "linear_conv_kernel_dim": 4,
    "linear_allow_neg_eigval": True,
},

Both RoPE and NoPE modes verified (forward pass, cache creation, step generation):

import importlib
import mlx.core as mx

config_rope = { ... }  # as above
config_nope = dict(config_rope)
config_nope['rope_parameters'] = {'rope_theta': None, 'rope_type': 'default'}

arch = importlib.import_module('mlx_lm.models.olmo_hybrid')

for name, cfg in [('RoPE', config_rope), ('NoPE', config_nope)]:
    args = arch.ModelArgs.from_dict(cfg)
    model = arch.Model(args)

    x = mx.array([[1, 2, 3]])
    out = model(x); mx.eval(out)
    assert out.shape == (1, 3, 1000)

    cache = model.make_cache()
    out_cached = model(x, cache=cache); mx.eval(out_cached)
    assert out_cached.shape == (1, 3, 1000)

    out_step = model(mx.array([[4]]), cache=cache); mx.eval(out_step)
    assert out_step.shape == (1, 1, 1000)

    fa_layer = next(l for l in model.model.layers if hasattr(l, 'self_attn'))
    assert (fa_layer.self_attn.rope is not None) == (name == 'RoPE')
RoPE: All checks passed
NoPE: All checks passed

@Goekdeniz-Guelmez (Contributor)

Definitely works! I added a PR to your fork with some nits and optimisations:

before:

Prompt: 41 tokens, 4.439 tokens-per-sec
Generation: 1186 tokens, 16.201 tokens-per-sec
Peak memory: 15.104 GB

after:

Prompt: 41 tokens, 7.185 tokens-per-sec
Generation: 1172 tokens, 16.151 tokens-per-sec
Peak memory: 15.104 GB

No need to merge it, I think; just copying and pasting it into here is enough.

@Goekdeniz-Guelmez (Contributor)

Adding this into the test_models.py file should work as well; you can also add your name to the ACKNOWLEDGMENTS.md file for the implementation if you feel like it :D

def test_olmo_hybrid(self):
    from mlx_lm.models import olmo_hybrid

    args = olmo_hybrid.ModelArgs(
        model_type="olmo_hybrid",
        hidden_size=128,
        intermediate_size=256,
        num_hidden_layers=4,
        num_attention_heads=4,
        num_key_value_heads=2,
        rms_norm_eps=1e-4,
        vocab_size=1000,
        max_position_embeddings=128,
        linear_num_key_heads=1,
        linear_num_value_heads=2,
        linear_key_head_dim=32,
        linear_value_head_dim=32,
        linear_conv_kernel_dim=3,
        linear_allow_neg_eigval=False,
        tie_word_embeddings=False,
        attention_bias=False,
        rope_theta=1000,
        layer_types=[
            "linear_attention",
            "linear_attention",
            "linear_attention",
            "full_attention",
        ],
    )
    model = olmo_hybrid.Model(args)
    self.model_test_runner(
        model, args.model_type, args.vocab_size, args.num_hidden_layers
    )

@angeloskath (Member) left a comment


First, I have to say I do appreciate the disclaimer and the thorough evaluation (even if the latter was done automatically by the LLM).

I did a quick pass and left some comments. Generally speaking, the GatedDeltaNet needs some work, as does the overcommenting, but otherwise it looks good.

# mx.fast.rms_norm accumulates in float32 internally
normed = mx.fast.rms_norm(x, self.weight, self.eps)
# Gate in float32, result kept in original dtype
return normed * nn.silu(gate.astype(mx.float32)).astype(x.dtype)

@angeloskath (Member):

I would take the _precise_swiglu from Qwen3 next and refactor it out to activations.py and use it here.

I would also remove the comments. They just double the amount of text one has to parse with minimal additional info.


if cache is not None:
    cache[3] = state
    cache.advance(S)

@angeloskath (Member):

A lot of the above could just defer to gated_delta_update.

@pcuenca (Author), Mar 20, 2026:

The main problem is that this model allows "negative eigenvalues" (beta = beta * 2.0 in line 174 above), which the current gated_delta_update implementation does not support. This is a trick that appears in a footnote of the Gated DeltaNet paper, but neither Qwen nor Kimi use it.

[Image: screenshot of the paper footnote describing the negative-eigenvalue trick]

Would you like me to refactor the gated_delta_update implementation as part of this PR to accept an allow_neg_eigval param? I'd rather do it as part of a follow-up PR, but I'm fine doing it here if preferred.
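For context, the beta scaling in question can be sketched like this (illustrative numpy, not the Metal kernel code; the helper names are hypothetical):

```python
import numpy as np

# Sketch of the "negative eigenvalues" trick under discussion. In the
# standard gated delta rule the write strength is beta = sigmoid(b),
# which lies in (0, 1); with allow_neg_eigval it is doubled, so for a
# unit-norm key k the per-step transition (I - beta * k k^T) has its
# eigenvalue along k in (-1, 1) rather than (0, 1).
def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def beta_from_logits(b, allow_neg_eigval=False):
    beta = sigmoid(b)
    return 2.0 * beta if allow_neg_eigval else beta

b = np.array([-2.0, 0.0, 2.0])
print(beta_from_logits(b))                         # all in (0, 1)
print(beta_from_logits(b, allow_neg_eigval=True))  # all in (0, 2)
```

A kernel that assumes beta < 1 therefore cannot represent the scaled case, which is why the existing `gated_delta_update` would need an `allow_neg_eigval` parameter.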


pcuenca commented Mar 19, 2026

Thanks a lot @Goekdeniz-Guelmez and @angeloskath 🙌. I'll go through your comments and update the PR.

I agree about the overcommenting; there's already a directive about that in the skill, but it was clearly not enough.

evaluation (even if the latter was done automatically by the LLM)

This is a great point. I asked the model to disclose the exact testing code it used so we can reproduce, but perhaps we can go a step further and run it separately via some sort of automation. I'll run tests manually for now.

Comment on lines +228 to +237
self.rope = (
    initialize_rope(
        self.head_dim,
        base=args.rope_theta,
        traditional=False,
        max_position_embeddings=args.max_position_embeddings,
    )
    if args.rope_theta is not None
    else None
)

@pcuenca (Author):

cc @Goekdeniz-Guelmez just for info: I had to revert one of your changes because the DPO model does not use RoPE. See the transformers path here, and how it's controlled in the DPO model config.

@Goekdeniz-Guelmez (Contributor):

ahh, good catch!!!! No worries!

Comment on lines +14 to +18
@partial(mx.compile, shapeless=True)
def precise_swiglu(h, gate, x):
    gate = nn.silu(gate.astype(mx.float32))
    x = x.astype(mx.float32)
    return (gate * x).astype(h.dtype)

@pcuenca (Author):

I used the exact same signature that appeared in the Qwen3 Next implementation; let me know if it's preferable to drop the h.

They were used to support a pre-release checkpoint.

pcuenca commented Mar 20, 2026

I think I addressed all the feedback (changes were fully done by me this time, not agent-assisted). Let me know how I should proceed regarding the comments I just posted (in particular the reuse / refactor of gated_delta_update).

@pcuenca pcuenca requested a review from angeloskath March 20, 2026 13:09

pcuenca commented Mar 30, 2026

Gentle ping @angeloskath, happy to iterate as needed!

Regarding your comment about test results, we are working on a separate test harness (not LLM-executed) that can be optionally triggered in our infra for PRs like this one.
