[transformers-to-mlx skill] Add OLMo Hybrid (GatedDeltaNet + Full Attention) #942

pcuenca wants to merge 1 commit into ml-explore:main
Conversation
Implements mlx_lm/models/olmo_hybrid.py for the OLMo Hybrid architecture recently merged into transformers. Also adds olmo_hybrid to test_all_models.

Architecture: 7.4B parameter hybrid model interleaving GatedDeltaNet (linear_attention) and standard MHA (full_attention) in a 3:1 ratio.

Key implementation details:

- Separate q/k/v depthwise conv1d per head (vs fused in qwen3_next)
- allow_neg_eigval: beta scaled by 2.0 via direct gated_delta_ops call
- Weight key names match the checkpoint (attention_layer_norm/feedforward_layer_norm)
- Conv1d weights transposed in sanitize(): (out, 1, kernel) → (out, kernel, 1)
- rope_parameters extracted from nested dict in ModelArgs.__post_init__
- ArraysCache(size=4) for linear attention layers

Tested against hf-internal-testing/olmo-hybrid:

- 10/10 top-10 token overlap vs transformers (max logit diff: 0.1875)
- Greedy generation matches transformers output
- Output dtype: bfloat16 throughout (no float32 contamination)
- Long sequence (500 tokens): coherent, no degradation
- 4-bit quantized: 110 tok/s, 4.28 GB

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Update: fix norm key naming for compatibility with fine-tuned models

The original checkpoint uses OLMo-core names (attention_layer_norm, feedforward_layer_norm) for the linear attention layer norms. The transformers model uses different attribute names (input_layernorm, post_attention_layernorm); transformers handles this transparently via a key-renaming step at load time.

The initial version of this PR used the checkpoint names as MLX attribute names, which worked for the original checkpoint but would break for models fine-tuned and saved with transformers (which would use the transformers attribute names). This update switches the MLX attribute names to match transformers, remapping the original checkpoint keys instead.

Numerical results are unchanged (same logits, same generation output).
Closing (experimental, will reopen)
Hey @pcuenca, can this be reopened, since the model is public?
Yes, good idea @Goekdeniz-Guelmez, I'll test on the published checkpoints first.
I made some fixes and ran some tests in my fork here; does it look reasonable to you @Goekdeniz-Guelmez? If so, I think we could go ahead and open the PR here :) And sorry for the delay, I was traveling for a few days!
Thanks for testing, at first glance it looks good to go. I'll let you know once I've run it too, thanks and no worries!
Superseded by #1023
Summary

- mlx_lm/models/olmo_hybrid.py: MLX implementation of the OLMo Hybrid architecture (recently merged into transformers, no official checkpoints yet)
- Adds an olmo_hybrid entry to test_all_models in tests/test_models.py
- Tested against hf-internal-testing/olmo-hybrid (7.4B params, bfloat16)

Architecture

OLMo Hybrid interleaves two layer types in a 3:1 ratio ([linear, linear, linear, full] × 8, 32 layers total):

- GatedDeltaNet layers (linear_attention): linear recurrent attention with gated delta rule, separate depthwise conv1d per q/k/v projection, learnable decay, pre-norm style
- Full attention layers (full_attention): standard MHA with q/k normalization (RMSNorm on full projections) and RoPE, post-norm style
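As a sketch, the interleaving pattern above can be generated in a few lines of Python. The function name and the layer-type labels below are illustrative (the labels mirror the descriptions in this PR; the real config field names may differ):

```python
# Illustrative sketch of the 3:1 interleaving: every 4th layer is full
# attention, the rest use GatedDeltaNet. build_layer_types is a
# hypothetical helper, not mlx-lm API.

def build_layer_types(num_layers: int = 32, period: int = 4) -> list[str]:
    return [
        "full_attention" if (i + 1) % period == 0 else "linear_attention"
        for i in range(num_layers)
    ]

layer_types = build_layer_types()
# [linear, linear, linear, full] repeated 8 times over 32 layers
```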
Non-trivial implementation details

1. Weight key name mismatch between checkpoint and transformers

The hf-internal-testing/olmo-hybrid checkpoint uses OLMo-core naming (attention_layer_norm, feedforward_layer_norm) for the linear attention layer norms, while the merged transformers code uses input_layernorm/post_attention_layernorm. The MLX model uses the checkpoint names directly, so no sanitize() remapping is needed. (Transformers loads these with ones-initialized weights unless manually patched — this was the cause of an early investigation detour.)

2. Conv1d weight transpose in sanitize()

The checkpoint stores depthwise conv1d weights as (out_channels, 1, kernel_size), but MLX Conv1d expects (out_channels, kernel_size, 1). Fixed with v.moveaxis(2, 1).
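A minimal sketch of that transpose, using NumPy as a stand-in for mx.array (both support moveaxis with this signature); sanitize_conv_weight is a hypothetical helper, not the PR's actual sanitize():

```python
import numpy as np

def sanitize_conv_weight(v: np.ndarray) -> np.ndarray:
    # Checkpoint layout: (out_channels, 1, kernel_size)
    # MLX Conv1d layout: (out_channels, kernel_size, 1)
    if v.ndim == 3 and v.shape[1] == 1 and v.shape[2] > 1:
        return np.moveaxis(v, 2, 1)
    return v  # already in MLX layout

w = np.zeros((128, 1, 4))  # depthwise conv1d weight as stored on disk
assert sanitize_conv_weight(w).shape == (128, 4, 1)
```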
3. allow_neg_eigval beta scaling

linear_allow_neg_eigval=True requires beta ∈ [0, 2] instead of [0, 1]. This requires calling gated_delta_ops/gated_delta_kernel directly rather than gated_delta_update, which computes beta = sigmoid(b) internally without a scaling hook.
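The scaling itself is just a factor of 2 applied after the sigmoid. A NumPy sketch of that range change (not the actual gated_delta_ops call):

```python
import numpy as np

def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))

# gated_delta_update computes beta = sigmoid(b) internally, capping beta at 1.
# With linear_allow_neg_eigval=True the model needs beta in [0, 2], so the
# gate is scaled by 2.0 before the delta-rule update.
b = np.array([-3.0, 0.0, 3.0])
beta = 2.0 * sigmoid(b)  # values now span (0, 2)
```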
4. rope_parameters as a nested dict

The RoPE config is stored as {"rope_theta": 500000, "rope_type": "default"} rather than as top-level fields. It is extracted in ModelArgs.__post_init__.
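A sketch of that extraction in __post_init__. The rope_parameters and rope_theta field names follow the PR description; the dataclass shape and fallback value are illustrative:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelArgs:
    # Nested dict as it appears in the model config
    rope_parameters: Optional[dict] = None
    rope_theta: float = 10000.0  # fallback when no nested dict is present

    def __post_init__(self):
        # Promote rope_theta from the nested dict to a top-level field
        if self.rope_parameters is not None:
            self.rope_theta = self.rope_parameters.get("rope_theta", self.rope_theta)

args = ModelArgs(rope_parameters={"rope_theta": 500000, "rope_type": "default"})
assert args.rope_theta == 500000
```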
5. L2Norm via RMSNorm + scale factor

GatedDeltaNet applies l2norm to q and k; this is equivalent to rms_norm(x) * inv_scale (for k) and rms_norm(x) * inv_scale² (for q, after the scale-compensating attention dot product), matching the pattern from qwen3_next.py.
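The underlying identity is that, for head dimension d, rms(x) = ||x|| / sqrt(d), so l2norm(x) = rms_norm(x) / sqrt(d) and the 1/sqrt(d) factor can be folded into the attention scale. A NumPy check of the identity (eps terms omitted for clarity):

```python
import numpy as np

def l2norm(x: np.ndarray) -> np.ndarray:
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

def rms_norm(x: np.ndarray) -> np.ndarray:
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True))

x = np.random.default_rng(0).standard_normal((2, 8, 64))
d = x.shape[-1]
# l2norm(x) == rms_norm(x) * d**-0.5, since rms(x) = ||x|| / sqrt(d)
assert np.allclose(l2norm(x), rms_norm(x) * d ** -0.5)
```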
6. Cache structure

Linear attention layers use ArraysCache(size=4), holding [conv_q_state, conv_k_state, conv_v_state, recurrent_state]; full attention layers use a standard KVCache.

Test results
Numerical comparison (MLX bfloat16 Metal vs transformers bfloat16 CPU)
Prompt: "The history of artificial intelligence began in" (7 tokens)

Greedy generation comparison
MLX output:
Transformers output (patched weights):
Identical except for the leading space (mlx-lm's streaming detokenizer trims the leading space from the first generated token).
Output dtype
Verified for both the fp16 and the 4-bit quantized model.
Long sequence (500 tokens)
500 tokens of coherent structured text, no degradation.
4-bit quantization
Output dtype bfloat16 ✓ after quantization.
Unit tests
```
python -m pytest tests/test_models.py::TestModels::test_all_models -v
# 1 passed, 51 subtests passed
```