[transformers-to-mlx skill] Add OLMo Hybrid model support #1023
pcuenca wants to merge 9 commits into ml-explore:main
Conversation
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
**Full test report**

**Test 1: Generation**

Base model:

```shell
python -m mlx_lm.generate --model allenai/Olmo-Hybrid-7B \
  --prompt "The capital of France is" --max-tokens 50
```

DPO (instruct):

```shell
python -m mlx_lm.generate --model allenai/Olmo-Hybrid-Instruct-DPO-7B \
  --prompt "Explain how photosynthesis works in 3 sentences." --max-tokens 200
```

Think-SFT (chain-of-thought):

```shell
python -m mlx_lm.generate --model allenai/Olmo-Hybrid-Think-SFT-7B \
  --prompt "What is the square root of 144?" --max-tokens 500
```

**Test 2: Output dtype**

Checks for dtype contamination (a stray float32 weight silently promoting the forward pass to float32). Uses the final norm weight as reference — norm weights are never quantized, so they retain the original dtype even after 4-bit quantization.

```python
from pathlib import Path
import mlx.core as mx
from mlx_lm.utils import load_model

for name in ['olmo-hybrid-7b', 'olmo-hybrid-7b-sft', 'olmo-hybrid-7b-dpo',
             'olmo-hybrid-7b-think-sft', 'olmo-hybrid-7b-dpo-4bit']:
    model_path = Path(f'models/{name}')
    model, config = load_model(model_path)
    weight_dtype = model.model.norm.weight.dtype
    out = model(mx.array([[1]]))
    mx.eval(out)
    no_contamination = out.dtype != mx.float32
    matches_weights = out.dtype == weight_dtype
    print(f'{name}: weight_dtype={weight_dtype}, out_dtype={out.dtype}, '
          f'no_float32={no_contamination}, matches_weights={matches_weights}')
```

All pass: no float32 contamination.

**Test 3: Numerical comparison (MLX vs transformers bfloat16)**

149-token prompt, comparing logits from a single forward pass.

```python
import os, gc, warnings
os.environ['TQDM_DISABLE'] = '1'
warnings.filterwarnings('ignore')
import numpy as np
import torch
import mlx.core as mx
from transformers import AutoTokenizer, AutoModelForCausalLM
from mlx_lm import load

prompt = ('The history of artificial intelligence spans decades of research, '
          'from early symbolic AI systems in the 1950s through the expert systems '
          'era of the 1980s, the AI winter, and the modern deep learning revolution. '
          'Today, large language models trained on vast corpora of text have demonstrated '
          'remarkable capabilities in understanding and generating human language, '
          'reasoning about complex problems, writing code, and engaging in creative tasks. '
          'These models, based on the transformer architecture introduced in the landmark '
          'Attention Is All You Need paper, use self-attention mechanisms to process '
          'sequences of tokens in parallel, enabling efficient training on massive datasets. '
          'The key innovation of transformers was replacing recurrent processing with '
          'attention, allowing the model to directly attend to any position in the input '
          'sequence regardless of distance.')

for name in ['olmo-hybrid-7b', 'olmo-hybrid-7b-sft', 'olmo-hybrid-7b-dpo',
             'olmo-hybrid-7b-think-sft']:
    model_path = f'models/{name}'
    hf_tokenizer = AutoTokenizer.from_pretrained(model_path)
    hf_model = AutoModelForCausalLM.from_pretrained(model_path, dtype=torch.bfloat16)
    hf_tokens = hf_tokenizer.encode(prompt)
    with torch.no_grad():
        hf_logits = hf_model(torch.tensor([hf_tokens])).logits.detach().to(torch.float32).numpy()
    del hf_model; gc.collect()
    mlx_model, mlx_tokenizer = load(model_path)
    mlx_tokens = mlx_tokenizer.encode(prompt)
    mlx_logits = mlx_model(mx.array([mlx_tokens])).astype(mx.float32)
    mx.eval(mlx_logits)
    mlx_logits_np = np.array(mlx_logits)
    del mlx_model; gc.collect()
    diff = np.abs(hf_logits - mlx_logits_np)
    hf_last = hf_logits[0, -1, :]
    mlx_last = mlx_logits_np[0, -1, :]
    hf_top = np.argsort(hf_last)[-10:][::-1]
    mlx_top = np.argsort(mlx_last)[-10:][::-1]
    top1 = hf_top[0] == mlx_top[0]
    top5 = len(set(hf_top[:5]) & set(mlx_top[:5]))
    top10 = len(set(hf_top) & set(mlx_top))
    print(f'{name}: max_abs_diff={diff.max():.3f}, top1_match={top1}, '
          f'top5_overlap={top5}/5, top10_overlap={top10}/10')
    # Per-position top-1 agreement
    for pos in [0, len(hf_tokens)//4, len(hf_tokens)//2, len(hf_tokens)-1]:
        hf_p = np.argsort(hf_logits[0, pos, :])[-1]
        mlx_p = np.argsort(mlx_logits_np[0, pos, :])[-1]
        print(f'  pos {pos}: top1_match={hf_p == mlx_p}')
```

Per-position top-1 agreement (base model, sampled positions): all models show excellent agreement, within expected bfloat16 cross-framework tolerance.

**RoPE embedding verification**

Direct comparison of RoPE cos/sin rotation values between transformers and MLX at various sequence positions. This isolates RoPE correctness without needing long generations.

```python
import numpy as np
import torch
import mlx.core as mx
from transformers import AutoModelForCausalLM
from mlx_lm import load

# Extract inv_freq from transformers
model = AutoModelForCausalLM.from_pretrained('models/olmo-hybrid-7b', dtype=torch.bfloat16)
tf_inv_freq = model.model.rotary_emb.inv_freq.to(torch.float32).numpy()
del model; import gc; gc.collect()

# Compute analytical cos/sin
head_dim = 128
theta = 10000.0
expected_inv_freq = 1.0 / (theta ** (np.arange(0, head_dim, 2, dtype=np.float32) / head_dim))
assert np.allclose(tf_inv_freq, expected_inv_freq)
positions = [0, 10, 100, 1000, 5000, 10000, 30000]
tf_cos = np.array([np.cos(pos * expected_inv_freq) for pos in positions])
tf_sin = np.array([np.sin(pos * expected_inv_freq) for pos in positions])

# Probe MLX RoPE using unit vectors
mlx_model, _ = load('models/olmo-hybrid-7b')
fa_layer = next(l for l in mlx_model.model.layers if hasattr(l, 'self_attn'))
rope = fa_layer.self_attn.rope
half = head_dim // 2
mlx_cos, mlx_sin = [], []
for pos in positions:
    cos_row, sin_row = [], []
    for d in range(half):
        q = mx.zeros((1, 1, 1, head_dim))
        q[..., d] = 1.0
        out = rope(q, offset=pos)
        mx.eval(out)
        out_np = np.array(out.astype(mx.float32)).reshape(-1)
        cos_row.append(out_np[d])
        sin_row.append(out_np[d + half])
    mlx_cos.append(cos_row)
    mlx_sin.append(sin_row)
mlx_cos, mlx_sin = np.array(mlx_cos), np.array(mlx_sin)
print('max cos error:', np.abs(mlx_cos - tf_cos).max())
print('max sin error:', np.abs(mlx_sin - tf_sin).max())
```

RoPE embeddings match to high precision across the full position range (max error < 0.002 at position 30,000).

**Test 5: Long sequence generation**

Base model — 16K tokens:

```shell
python -m mlx_lm.generate --model allenai/Olmo-Hybrid-7B \
  --prompt "Once upon a time in a land far, far away, ..." --max-tokens 16000
```

Generates to the full 16K limit without crashes. The base model enters repetition after ~550 tokens (typical for greedy decoding on base models without EOS training). Verified the same behavior in transformers — model behavior, not an MLX issue.

DPO — Space Invaders game (3955 tokens):

```shell
python -m mlx_lm.generate --model allenai/Olmo-Hybrid-Instruct-DPO-7B \
  --prompt "Create a complete Space Invaders game in a single HTML file ..." --max-tokens 16000
```

Generated a complete, coherent Space Invaders game (~530 lines of HTML/CSS/JS) and stopped naturally at EOS. No signs of RoPE-related degradation.

**Phase 5: 4-bit quantization**

```shell
python -m mlx_lm.convert --model models/olmo-hybrid-7b-dpo -q --q-bits 4 \
  --mlx-path models/olmo-hybrid-7b-dpo-4bit
python -m mlx_lm.generate --model models/olmo-hybrid-7b-dpo-4bit \
  --prompt "Explain the theory of relativity in simple terms." --max-tokens 200
```

Quantization works correctly: coherent output, ~2.5x speed boost (108 vs 43 tok/s), memory drops from 15 GB to 4.4 GB. Dtype check passes (bfloat16 output, no float32 contamination).

**Phase 6: Unit tests**

Test config:

```python
{
    "model_type": "olmo_hybrid",
    "hidden_size": 128,
    "num_hidden_layers": 4,
    "intermediate_size": 128,
    "num_attention_heads": 4,
    "num_key_value_heads": 4,
    "rms_norm_eps": 1e-6,
    "vocab_size": 1000,
    "max_position_embeddings": 1000,
    "rope_theta": 10000.0,
    "linear_num_key_heads": 4,
    "linear_num_value_heads": 4,
    "linear_key_head_dim": 32,
    "linear_value_head_dim": 64,
    "linear_conv_kernel_dim": 4,
    "linear_allow_neg_eigval": True,
}
```

Both RoPE and NoPE modes verified (forward pass, cache creation, step generation):

```python
import importlib
import mlx.core as mx

config_rope = { ... }  # as above
config_nope = dict(config_rope)
config_nope['rope_parameters'] = {'rope_theta': None, 'rope_type': 'default'}
arch = importlib.import_module('mlx_lm.models.olmo_hybrid')
for name, cfg in [('RoPE', config_rope), ('NoPE', config_nope)]:
    args = arch.ModelArgs.from_dict(cfg)
    model = arch.Model(args)
    x = mx.array([[1, 2, 3]])
    out = model(x); mx.eval(out)
    assert out.shape == (1, 3, 1000)
    cache = model.make_cache()
    out_cached = model(x, cache=cache); mx.eval(out_cached)
    assert out_cached.shape == (1, 3, 1000)
    out_step = model(mx.array([[4]]), cache=cache); mx.eval(out_step)
    assert out_step.shape == (1, 1, 1000)
    fa_layer = next(l for l in model.model.layers if hasattr(l, 'self_attn'))
    assert (fa_layer.self_attn.rope is not None) == (name == 'RoPE')
```
definitely works! Added a PR to your fork; it adds some nits and optimisations. No need to merge it, I think just copying and pasting it into here is enough.

Adding this into the tests:

```python
def test_olmo_hybrid(self):
    from mlx_lm.models import olmo_hybrid
    args = olmo_hybrid.ModelArgs(
        model_type="olmo_hybrid",
        hidden_size=128,
        intermediate_size=256,
        num_hidden_layers=4,
        num_attention_heads=4,
        num_key_value_heads=2,
        rms_norm_eps=1e-4,
        vocab_size=1000,
        max_position_embeddings=128,
        linear_num_key_heads=1,
        linear_num_value_heads=2,
        linear_key_head_dim=32,
        linear_value_head_dim=32,
        linear_conv_kernel_dim=3,
        linear_allow_neg_eigval=False,
        tie_word_embeddings=False,
        attention_bias=False,
        rope_theta=1000,
        layer_types=[
            "linear_attention",
            "linear_attention",
            "linear_attention",
            "full_attention",
        ],
    )
    model = olmo_hybrid.Model(args)
    self.model_test_runner(
        model, args.model_type, args.vocab_size, args.num_hidden_layers
    )
```

**angeloskath** left a comment:
First, I have to say I do appreciate the disclaimer and the thorough evaluation (even if the latter was done automatically by the LLM).
I did a quick pass and left some comments. Generally speaking, the GatedDeltaNet needs some work, as does the overcommenting, but otherwise it looks good.
`mlx_lm/models/olmo_hybrid.py` (outdated):

```python
# mx.fast.rms_norm accumulates in float32 internally
normed = mx.fast.rms_norm(x, self.weight, self.eps)
# Gate in float32, result kept in original dtype
return normed * nn.silu(gate.astype(mx.float32)).astype(x.dtype)
```
I would take the `_precise_swiglu` from Qwen3 Next, refactor it out to `activations.py`, and use it here.
I would also remove the comments: they just double the amount of text one has to parse with minimal additional info.
```python
if cache is not None:
    cache[3] = state
    cache.advance(S)
```
A lot of the above could just defer to gated_delta_update.
The main problem is that this model allows "negative eigenvalues" (`beta = beta * 2.0` in line 174 above), which the current `gated_delta_update` implementation does not support. This is a trick that appears in a footnote of the Gated Delta paper, but neither Qwen nor Kimi use it.
Would you like me to refactor the gated_delta_update implementation as part of this PR to accept an allow_neg_eigval param? I'd rather do it as part of a follow-up PR, but I'm fine doing it here if preferred.
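For context, a minimal numpy sketch of the trick (the function name and shapes are illustrative, not this PR's code): the delta-rule mixing coefficient beta comes from a sigmoid, so scaling it by 2 widens its range from (0, 1) to (0, 2), which is what lets the effective state-transition eigenvalues go negative.

```python
import numpy as np

# Illustrative only: beta gating with the optional "negative eigenvalue"
# scaling. With allow_neg_eigval, beta can exceed 1, so terms of the form
# (1 - beta) in the delta-rule state update can become negative.
def beta_from_logits(b, allow_neg_eigval):
    beta = 1.0 / (1.0 + np.exp(-b))  # sigmoid, in (0, 1)
    if allow_neg_eigval:
        beta = beta * 2.0  # now in (0, 2)
    return beta

b = np.array([-2.0, 0.0, 2.0])
print(beta_from_logits(b, False))  # all values below 1
print(beta_from_logits(b, True))   # values can exceed 1, so 1 - beta < 0
```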
Thanks a lot @Goekdeniz-Guelmez and @angeloskath 🙌. I'll go through your comments and update the PR. I agree about the overcommenting; there's already a directive about that in the skill, but it was clearly not enough.
This is a great point. I asked the model to disclose the exact testing code it used so we can reproduce, but perhaps we can go a step further and run it separately via some sort of automation. I'll run tests manually for now.
Nits on olmo hybrid
The DPO model does not use RoPE. This is controlled here: https://github.com/huggingface/transformers/blob/aad13b87ed59f2afcfaebc985f403301887a35fc/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py#L956. The DPO model config explicitly sets theta to null: https://huggingface.co/allenai/Olmo-Hybrid-Instruct-DPO-7B/blob/main/config.json#L60
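A hypothetical sketch (the helper name and default value are assumptions, not the PR's code) of how a `rope_theta: null` inside `rope_parameters` can be normalized so that `args.rope_theta` ends up `None` for NoPE checkpoints:

```python
# Hypothetical config normalization: prefer rope_parameters.rope_theta when
# present (null/None means NoPE), otherwise fall back to a top-level
# rope_theta, then to an assumed default of 10000.0 (for illustration).
def resolve_rope_theta(config):
    rope_params = config.get("rope_parameters") or {}
    if "rope_theta" in rope_params:
        return rope_params["rope_theta"]  # may be None -> RoPE disabled
    return config.get("rope_theta", 10000.0)

# DPO-style config: theta explicitly null -> NoPE
print(resolve_rope_theta({"rope_parameters": {"rope_theta": None}}))  # None
# Base-style config with a top-level theta
print(resolve_rope_theta({"rope_theta": 10000.0}))  # 10000.0
```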
```python
self.rope = (
    initialize_rope(
        self.head_dim,
        base=args.rope_theta,
        traditional=False,
        max_position_embeddings=args.max_position_embeddings,
    )
    if args.rope_theta is not None
    else None
)
```
cc @Goekdeniz-Guelmez just for info: I had to revert one of your changes because the DPO model does not use RoPE. See the transformers path here, and how it's controlled in the DPO model config.
ahh, good catch!!!! No worries!
```python
@partial(mx.compile, shapeless=True)
def precise_swiglu(h, gate, x):
    gate = nn.silu(gate.astype(mx.float32))
    x = x.astype(mx.float32)
    return (gate * x).astype(h.dtype)
```
I used the exact same signature that appeared in the Qwen3 Next implementation; let me know if it's preferable to get rid of the `h`.
They were used to support a pre-release checkpoint.
I think I addressed all the feedback (changes were fully done by me this time, not agent-assisted). Let me know how I should proceed regarding the comments I just posted (in particular the reuse / refactor of `gated_delta_update`).
Gentle ping @angeloskath, happy to iterate as needed! Regarding your comment about test results, we are working on a separate test harness (not LLM-executed) that can be optionally triggered in our infra for PRs like this one.
TL;DR for reviewers
This PR was created with an in-progress conversion Skill that we'd like to contribute to the community. We are trying to make it comprehensive and careful, so it only considers the job done if relevant tests pass. We have instructed it to pay attention to the usual problem areas, like long-context generation (and RoPE), dtype/upcasting errors, and others.
Any feedback received in this PR will be applied to the Skill so it works better next time! Please do not hesitate to point out any observations.
Tests were run on M3 Ultra (512 GB).
Summary
- Adds support for `allenai/Olmo-Hybrid-7B`, `allenai/Olmo-Hybrid-Instruct-SFT-7B`, `allenai/Olmo-Hybrid-Instruct-DPO-7B`, and `allenai/Olmo-Hybrid-Think-SFT-7B`
- Reuses the `gated_delta.py` Metal kernel infrastructure for the linear attention layers
- Base, SFT, and Think-SFT use RoPE (`rope_theta=10000`); DPO uses NoPE (no positional embeddings)

Architecture
OLMo Hybrid uses the same transformer architecture as OLMo 3 7B, except that 75% of layers use GatedDeltaNet linear attention heads instead of standard attention. The layer pattern alternates: 3 linear attention layers followed by 1 full attention layer.
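The 3:1 alternating pattern can be sketched as follows (illustrative helper, not the repo's code; the real pattern comes from the checkpoint's `layer_types` config):

```python
# Illustrative: generate the hybrid layer pattern where every 4th layer is
# full attention and the rest are GatedDeltaNet linear attention.
def hybrid_layer_types(num_layers):
    return [
        "full_attention" if (i + 1) % 4 == 0 else "linear_attention"
        for i in range(num_layers)
    ]

print(hybrid_layer_types(8))  # 3 linear layers, then 1 full, repeated
```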
Key implementation details:
- Linear attention layers (GatedDeltaNet): `allow_neg_eigval` scaling (beta × 2.0), `RMSNormGated` output normalization. Pre-norm residual style.
- Full attention layers: `q_norm`/`k_norm` (RMSNorm on full projections before head split), optional RoPE. Post-norm residual style.
- The DPO model sets `rope_theta: null` in `rope_parameters`, which disables RoPE entirely. This is handled by making RoPE conditional in the `Attention` class.
- Cache and state handling follow the `gated_delta.py` convention.

Disclosure
This PR was created using the `transformers-to-mlx` skill with the assistance of an AI agent (Claude).
Generation tests
Numerical comparison (MLX vs transformers bfloat16, 149-token prompt)
Per-position top-1 agreement (base model):
RoPE embedding verification
Direct comparison of RoPE cos/sin values between transformers and MLX at various positions:
Long generation test (16K tokens, base model)
🤖 Generated with Claude Code