
[transformers-to-mlx skill] Add OLMo Hybrid (GatedDeltaNet + Full Attention)#942

Closed
pcuenca wants to merge 1 commit into ml-explore:main from pcuenca:conversions

Conversation

@pcuenca pcuenca commented Mar 2, 2026

Disclosure: This implementation and all tests were produced using the transformers-to-mlx skill for Claude Code.

Summary

  • Adds mlx_lm/models/olmo_hybrid.py: MLX implementation of the OLMo Hybrid architecture (recently merged into transformers, no official checkpoints yet)
  • Adds olmo_hybrid entry to test_all_models in tests/test_models.py
  • Tested against hf-internal-testing/olmo-hybrid (7.4B params, bfloat16)

Architecture

OLMo Hybrid interleaves two layer types in a 3:1 ratio ([linear, linear, linear, full] × 8, 32 layers total):

  • GatedDeltaNet (linear_attention): linear recurrent attention with gated delta rule, separate depthwise conv1d per q/k/v projection, learnable decay, pre-norm style
  • Full Attention (full_attention): standard MHA with q/k normalization (RMSNorm on full projections) and RoPE, post-norm style
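
The 3:1 interleaving can be sketched as a layer-type schedule (a minimal illustration; the `linear_attention` / `full_attention` labels follow the config's layer-type convention, not verbatim library code):

```python
# Build the hybrid layer schedule: three GatedDeltaNet layers followed by
# one full-attention layer, repeated 8 times for 32 layers total.
n_layers = 32
block = ["linear_attention"] * 3 + ["full_attention"]
layer_types = [block[i % len(block)] for i in range(n_layers)]

print(layer_types.count("linear_attention"), layer_types.count("full_attention"))
```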

Non-trivial implementation details

1. Weight key name mismatch between checkpoint and transformers
The hf-internal-testing/olmo-hybrid checkpoint uses OLMo-core naming (attention_layer_norm, feedforward_layer_norm) for linear attention layer norms, while the merged transformers code uses input_layernorm / post_attention_layernorm. The MLX model uses the checkpoint names directly so no sanitize() remapping is needed. (Transformers loads these with ones-initialized weights unless manually patched — this was the cause of an early investigation detour.)

2. Conv1d weight transpose in sanitize()
Checkpoint stores depthwise conv1d weights as (out_channels, 1, kernel_size), but MLX Conv1d expects (out_channels, kernel_size, 1). Fixed with v.moveaxis(2, 1).
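
A minimal sketch of the `sanitize()` transpose, using NumPy stand-ins for the checkpoint arrays (the key pattern and filter condition are illustrative, not the exact ones in the PR):

```python
import numpy as np

def sanitize(weights):
    """Transpose depthwise conv1d kernels from the checkpoint layout
    (out_channels, 1, kernel_size) to the MLX Conv1d layout
    (out_channels, kernel_size, 1)."""
    out = {}
    for k, v in weights.items():
        if "conv1d" in k and v.ndim == 3 and v.shape[1] == 1:
            v = np.moveaxis(v, 2, 1)  # mx.array exposes the same moveaxis call
        out[k] = v
    return out

# Dummy depthwise kernel: 8 channels, kernel size 4.
w = {"layers.0.linear_attn.q_conv1d.weight": np.zeros((8, 1, 4))}
print(sanitize(w)["layers.0.linear_attn.q_conv1d.weight"].shape)  # (8, 4, 1)
```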

3. allow_neg_eigval beta scaling
linear_allow_neg_eigval=True requires beta ∈ [0, 2] instead of [0, 1]. This requires calling gated_delta_ops/gated_delta_kernel directly rather than gated_delta_update, which computes beta = sigmoid(b) internally without a scaling hook.
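
The difference is easy to see in plain NumPy (`gated_delta_update` / `gated_delta_ops` are the mlx-lm internals named above and are not redefined here; this only illustrates the beta ranges):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

b = np.linspace(-6.0, 6.0, 101)

beta_standard = sigmoid(b)       # what gated_delta_update computes: (0, 1)
beta_neg_eig = 2.0 * sigmoid(b)  # allow_neg_eigval scaling: (0, 2)

# With beta > 1, the delta-rule transition (1 - beta * k k^T) can take
# negative eigenvalues, hence the flag name.
print(beta_standard.max(), beta_neg_eig.max())
```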

4. rope_parameters as nested dict
RoPE config is stored as {"rope_theta": 500000, "rope_type": "default"} rather than top-level fields. Extracted in ModelArgs.__post_init__.
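
A sketch of the extraction, with a hypothetical trimmed-down `ModelArgs` (only the fields relevant here; the real dataclass has many more):

```python
from dataclasses import dataclass, field

@dataclass
class ModelArgs:
    # Hypothetical subset of the real ModelArgs fields.
    rope_parameters: dict = field(default_factory=dict)
    rope_theta: float = 10000.0

    def __post_init__(self):
        # Pull rope_theta out of the nested rope_parameters dict when present.
        if self.rope_parameters:
            self.rope_theta = self.rope_parameters.get("rope_theta", self.rope_theta)

args = ModelArgs(rope_parameters={"rope_theta": 500000, "rope_type": "default"})
print(args.rope_theta)  # 500000
```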

5. L2Norm via RMSNorm + scale factor
GatedDeltaNet applies l2norm to q and k; this is equivalent to rms_norm(x) * inv_scale (for k) and rms_norm(x) * inv_scale² (for q after the scale-compensating attention dot product), matching the pattern from qwen3_next.py.
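
The underlying identity is easy to check numerically (NumPy sketch; the extra `inv_scale` factor on q folds into the attention dot-product scale, as described above):

```python
import numpy as np

d = 64
x = np.random.default_rng(0).standard_normal(d)

l2 = x / np.linalg.norm(x)          # l2norm
rms = x / np.sqrt(np.mean(x ** 2))  # RMSNorm with unit weight
inv_scale = 1.0 / np.sqrt(d)

# rms_norm(x) = x * sqrt(d) / ||x||, so l2norm(x) == rms_norm(x) * inv_scale.
print(np.allclose(l2, rms * inv_scale))  # True
```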

6. Cache structure

  • Linear attention: ArraysCache(size=4) — [conv_q_state, conv_k_state, conv_v_state, recurrent_state]
  • Full attention: standard KVCache

Test results

Numerical comparison (MLX bfloat16 Metal vs transformers bfloat16 CPU)

Prompt: "The history of artificial intelligence began in" (7 tokens)

Max diff:   0.1875
Mean diff:  0.031
Top-5 overlap:  5/5
Top-10 overlap: 10/10

Top-10 tokens:
     279 ' the'       : MLX=9.7500, HF=9.7500, diff=0.0000
     220 ' '          : MLX=9.4375, HF=9.4375, diff=0.0000
   14154 ' ancient'   : MLX=8.6250, HF=8.6250, diff=0.0000
   61386 ' antiqu'    : MLX=7.8750, HF=7.9375, diff=0.0625
   38050 ' Ancient'   : MLX=6.0938, HF=6.1875, diff=0.0938
   55349 ' earnest'   : MLX=5.1562, HF=5.1562, diff=0.0000
     264 ' a'         : MLX=4.5312, HF=4.5312, diff=0.0000
    6287 ' August'    : MLX=4.0625, HF=4.0625, diff=0.0000
   29924 ' classical' : MLX=4.0312, HF=4.0625, diff=0.0312
    4435 ' World'     : MLX=3.9062, HF=3.9375, diff=0.0312

Greedy generation comparison

MLX output:

the 1950s with the Dartmouth Conference, where the term "artificial intelligence"
was coined. Since then, AI has evolved significantly, with advancements in machine
learning, neural networks, and robotics driving its growth. Today, AI is integrated
into various aspects of our lives, from virtual assistants and self-driving cars to
healthcare and finance.

The Future of AI

The future of AI holds immense

Transformers output (patched weights):

 the 1950s with the Dartmouth Conference, where the term "artificial intelligence"
was coined. Since then, AI has evolved significantly, with advancements in machine
learning, neural networks, and robotics driving its growth. Today, AI is integrated
into various aspects of our lives, from virtual assistants and self-driving cars to
healthcare and finance.

The Future of AI

The future of AI holds immense

Identical except for the leading space (mlx-lm's streaming detokenizer trims the leading space from the first generated token).

Output dtype

Config dtype: 'bfloat16' -> expected: mlx.core.bfloat16
Output dtype: mlx.core.bfloat16
PASS: Output dtype matches config

Verified for both the fp16 and the 4-bit quantized model.

Long sequence (500 tokens)

mlx_lm.generate --model hf-internal-testing/olmo-hybrid \
  --prompt "Write a detailed explanation of how the Python programming language works..." \
  --max-tokens 500
# Prompt: 28 tokens, 195.842 tokens-per-sec
# Generation: 500 tokens, 43.106 tokens-per-sec
# Peak memory: 15.002 GB

500 tokens of coherent structured text, no degradation.

4-bit quantization

mlx_lm.convert --hf-path hf-internal-testing/olmo-hybrid \
  --mlx-path olmo-hybrid-mlx-4bit -q
# [INFO] Quantized model with 4.502 bits per weight.

mlx_lm.generate --model olmo-hybrid-mlx-4bit \
  --prompt "The history of artificial intelligence began in" --max-tokens 80
# Prompt: 7 tokens, 123.534 tokens-per-sec
# Generation: 80 tokens, 109.696 tokens-per-sec
# Peak memory: 4.280 GB

Output dtype bfloat16 ✓ after quantization.

Unit tests

python -m pytest tests/test_models.py::TestModels::test_all_models -v
# 1 passed, 51 subtests passed

Implements mlx_lm/models/olmo_hybrid.py for the OLMo Hybrid architecture
recently merged into transformers. Also adds olmo_hybrid to test_all_models.

Architecture: 7.4B parameter hybrid model interleaving GatedDeltaNet
(linear_attention) and standard MHA (full_attention) in a 3:1 ratio.

Key implementation details:
- Separate q/k/v depthwise conv1d per head (vs fused in qwen3_next)
- allow_neg_eigval: beta scaled by 2.0 via direct gated_delta_ops call
- Weight key names match checkpoint (attention_layer_norm/feedforward_layer_norm)
- Conv1d weights transposed in sanitize(): (out, 1, kernel) → (out, kernel, 1)
- rope_parameters extracted from nested dict in ModelArgs.__post_init__
- ArraysCache(size=4) for linear attention layers

Tested against hf-internal-testing/olmo-hybrid:
- 10/10 top-10 token overlap vs transformers (max logit diff: 0.1875)
- Greedy generation matches transformers output
- Output dtype: bfloat16 throughout (no float32 contamination)
- Long sequence (500 tokens): coherent, no degradation
- 4-bit quantized: 110 tok/s, 4.28 GB

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

pcuenca commented Mar 3, 2026

Update: fix norm key naming for compatibility with fine-tuned models

The original checkpoint (hf-internal-testing/olmo-hybrid) uses OLMo-core naming for linear attention layer norms:

  • attention_layer_norm (pre-norm before linear attention)
  • feedforward_layer_norm (pre-norm before MLP)

The transformers model uses different attribute names:

  • input_layernorm
  • post_attention_layernorm

Transformers handles this transparently via a WeightRenaming registered in conversion_mapping.py.

The initial version of this PR used the checkpoint names as MLX attribute names, which worked for the original checkpoint but would break for models fine-tuned and saved with transformers (which would use the transformers attribute names).

This update switches the MLX attribute names to match transformers (input_layernorm, post_attention_layernorm) and remaps the original checkpoint names in sanitize(). This way both the original checkpoint and any future transformers-saved checkpoints load correctly.
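
A sketch of the `sanitize()` remapping described above (the key patterns are illustrative; float stand-ins replace the real weight arrays):

```python
def sanitize(weights):
    # Map OLMo-core norm names (original checkpoint) to the transformers
    # attribute names the MLX model now uses. Keys already using the
    # transformers names pass through unchanged.
    renames = {
        "attention_layer_norm": "input_layernorm",
        "feedforward_layer_norm": "post_attention_layernorm",
    }
    out = {}
    for k, v in weights.items():
        for old, new in renames.items():
            if old in k:
                k = k.replace(old, new)
        out[k] = v
    return out

ckpt = {"model.layers.0.attention_layer_norm.weight": 1.0}
print(sanitize(ckpt))  # {'model.layers.0.input_layernorm.weight': 1.0}
```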

Numerical results are unchanged (same logits, same generation output).


pcuenca commented Mar 3, 2026

Closing (experimental, will reopen)

@pcuenca pcuenca closed this Mar 3, 2026
@Goekdeniz-Guelmez (Contributor) commented

Hey @pcuenca, can this be reopened, since the model is public?


pcuenca commented Mar 17, 2026

Yes, good idea @Goekdeniz-Guelmez, I'll test on the published checkpoints first.


pcuenca commented Mar 18, 2026

I made some fixes and ran some tests in my fork here, does it look reasonable to you @Goekdeniz-Guelmez? If so, I think we could go ahead and open the PR here :)

And sorry for the delay, I was traveling for a few days!

@Goekdeniz-Guelmez (Contributor) commented

> I made some fixes and ran some tests in my fork here, does it look reasonable to you @Goekdeniz-Guelmez? If so, I think we could go ahead and open the PR here :)
>
> And sorry for the delay, I was traveling for a few days!

Thanks for testing! At first glance it looks good to go; I'll let you know once I've run it too. Thanks, and no worries!


pcuenca commented Mar 19, 2026

Superseded by #1023
