EAGLE-3 speculative decoding prototype: implementation + Apple Silicon analysis #890
kmsalah started this conversation in Show and tell
Vibe-coded while trying to optimize for the M3 Ultra, including the post above. The post is intended to save another researcher (or agent) the time and research in case they're interested in EAGLE-3 as an optimization approach.
I implemented EAGLE-3 speculative decoding[^1] on MLX to see how it performs on Apple Silicon. Sharing the implementation and findings here since there's no existing EAGLE port for MLX, and the results have some implications for speculative decoding on this platform in general.
What is EAGLE-3
EAGLE-3 uses the target model's own hidden states to predict future tokens, instead of a separate draft model[^1]. A single extra decoder layer (the "draft head")[^2] fuses hidden states from 3 target-model layers — selected as `{2, N//2, N-3}` where N is the layer count[^3] — and predicts the next token in a reduced vocabulary of 32,000 tokens[^4]. On CUDA, EAGLE-3 reports 3.3-6.5x speedup via tree-based verification[^5].
Implementation
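The layer selection and fusion described above can be sketched framework-agnostically. This is an illustration, not the actual MLX code: `eagle_layer_indices` and `fuse_hidden_states` are hypothetical names.

```python
def eagle_layer_indices(num_layers: int) -> tuple[int, int, int]:
    """EAGLE-3's tapped layers: {2, N//2, N-3} for an N-layer target model."""
    return (2, num_layers // 2, num_layers - 3)

def fuse_hidden_states(hidden_by_layer: dict, num_layers: int) -> list:
    """Concatenate the tapped pre-layer activations into one fused vector.

    The real draft head also projects the fused vector back down to
    hidden_size before concatenating it with the token embedding (which
    is what yields the 2x hidden_size Q/K/V input); that projection is
    omitted in this sketch.
    """
    fused = []
    for idx in eagle_layer_indices(num_layers):
        fused.extend(hidden_by_layer[idx])
    return fused
```

For 32-layer LLaMA-3.1-8B this selects layers 2, 16 and 29.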
1,148 lines across 4 files[^6], tested with `Meta-Llama-3.1-8B-Instruct-4bit`[^7] + the pretrained EAGLE-3 head[^8]:

- A draft `nn.Module` (modified decoder layer with 2×hidden_size Q/K/V input[^9], reduced-vocabulary lm_head, draft↔target vocab mapping)
- `forward_with_hidden_states()` that taps pre-layer activations[^10] from layers 2, 16, 29[^3]
- Greedy single-path verification (no tree)

Source: branch | gist
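The greedy single-path verification can be sketched as follows. This is a framework-agnostic illustration (`verify_greedy` is a hypothetical name): the target model scores all drafted tokens in one forward pass, a prefix is accepted while each draft token matches the target's argmax, and the target's own prediction at the first mismatch (or after the last draft) is appended as the bonus token.

```python
def verify_greedy(draft_tokens: list, target_argmax: list) -> list:
    """Greedy single-path verification (no tree).

    draft_tokens:  k tokens proposed by the draft head.
    target_argmax: k+1 argmax predictions from one target forward pass,
                   where target_argmax[i] is the target's choice for
                   draft slot i (index k is the prediction following
                   the last drafted token).
    Returns the accepted tokens; always ends with one target-chosen token,
    so every step emits at least one token.
    """
    accepted = []
    for i, tok in enumerate(draft_tokens):
        if tok != target_argmax[i]:
            break  # first mismatch: stop accepting draft tokens
        accepted.append(tok)
    # Bonus token: the target's own prediction at the first rejected slot
    # (or the position after the last accepted draft token).
    accepted.append(target_argmax[len(accepted)])
    return accepted
```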
Key implementation details:

- `nn.quantize(draft_model, bits=4)` at runtime — cuts the draft forward pass from ~0.9ms to ~0.4ms with zero accuracy loss
- A single `mx.eval` call per step (no intermediate syncs)

Benchmark results
M3 Ultra 96GB, LLaMA-3.1-8B-Instruct 4-bit, 128 tokens, 5 prompts averaged, 2 warmup runs.

| Configuration | Throughput | Speedup |
| --- | --- | --- |
| Baseline (autoregressive) | 113.7 tok/s | 1.00x |
| EAGLE-3 (this prototype) | 118.9 tok/s | 1.05x |
The 1.05x result comes from three stacked optimizations: 4-bit draft quantization, eliminating the "cache completeness" draft call (compensated with RoPE position offsets), and collapsing multiple `mx.eval` calls into one.
Why the speedup is limited on Apple Silicon
On CUDA, EAGLE-3 gets 3.3-6.5x[^5] because tree verification is nearly free — verifying 64 candidates costs about the same as 1 token (memory-bound matmuls, the batch dimension fills the GPU). On Apple Silicon with a 4-bit model, the dynamics are different:
Per-token activation cost is ~0.5-1ms. The 4-bit 8B model is small enough that each extra verification token costs real time, unlike CUDA where batch parallelism amortizes this.
The overhead budget is razor-thin. At ~8.8ms per baseline token[^11]: the draft head costs ~0.4ms, the extra verification token costs ~1ms, sync overhead ~0.5ms. With a 34% acceptance rate[^12], each EAGLE step produces 1.34 tokens on average. The 0.34 bonus tokens save 0.34 × 8.8ms = 3.0ms, but we spend ~1.9ms on overhead. Net saving: ~1.1ms/step, which matches the measured 1.05x[^13].
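Spelling out the budget arithmetic above as a quick sanity check (all numbers are the measured values quoted in this post):

```python
baseline_ms = 8.8      # per-token cost of the 4-bit 8B target model
draft_ms = 0.4         # quantized draft head forward pass
extra_verify_ms = 1.0  # one extra verification token in the target pass
sync_ms = 0.5          # mx.eval synchronization overhead
acceptance = 0.34      # measured draft acceptance rate

bonus_tokens = acceptance                           # extra tokens beyond 1 per step
saved_ms = bonus_tokens * baseline_ms               # ~3.0 ms saved per step
overhead_ms = draft_ms + extra_verify_ms + sync_ms  # ~1.9 ms spent per step
net_ms = saved_ms - overhead_ms                     # ~1.1 ms net saving per step
```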
Tree attention is blocked by KVCache. The current `KVCache`[^14] uses a single `offset` integer for RoPE positions[^15]. Tree verification requires different position IDs for candidates at different depths, which isn't possible without modifying the cache or the model forward pass.
Observations that might be useful
Speculative decoding has a model-size sweet spot on Apple Silicon. For small quantized models (4-bit 8B at ~114 tok/s), per-token time is already fast enough that draft overhead struggles to pay for itself. Larger or less-quantized models (where baseline is slower) should benefit more.
`position_ids` support in the model forward pass would unlock tree attention for both EAGLE-style and standard speculative decoding. Currently all models derive RoPE positions from `cache.offset`[^15]. Adding an optional `position_ids` argument to `__call__` (similar to HuggingFace transformers) would enable tree masks, which is where the real speedup potential lives. Related open issues: #846, #250.

Draft model quantization at runtime is free. `nn.quantize(model, bits=4)` on the EAGLE head produced identical acceptance rates (0.34)[^12] while halving the forward pass cost. Might be worth defaulting to for the existing speculative decoding pipeline too.

RoPE position drift is surprisingly tolerable in EAGLE. We skipped a draft model call per step (saving ~0.4ms) by letting the draft KV cache fall behind the target cache, compensating only with a RoPE offset. The accumulated drift reached 48+ positions at 256 tokens with no measurable accuracy degradation — the fused target hidden states dominate the draft model's predictions.
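To make the `position_ids` constraint concrete, here is a toy sketch (the tree layout and `tree_position_ids` are hypothetical): in a candidate tree, siblings occupy different cache slots but sit at the same depth, so they need the same RoPE position — something a single running `offset` integer cannot express.

```python
def tree_position_ids(cache_offset: int, depths: list) -> list:
    """RoPE position for each candidate node: offset + its tree depth.

    With a linear KV cache, slot i always gets position cache_offset + i.
    In a tree, two siblings occupy DIFFERENT cache slots but the SAME
    depth, so they must share a position id -- impossible to derive from
    a single incrementing offset.
    """
    return [cache_offset + d for d in depths]

# A depth-2 tree flattened as [root, child_a, child_b, grandchild]:
depths = [0, 1, 1, 2]
positions = tree_position_ids(100, depths)
# child_a and child_b share position 101 despite distinct cache slots,
# whereas a linear offset would have assigned them 101 and 102.
```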
Running the code
Happy to answer questions or clean this up into a PR if there's interest.
Footnotes
[^1]: Li et al., "EAGLE-3: Scaling up Inference Acceleration of Large Language Models via Training-Free Token-Level Blending," NeurIPS 2025. arXiv:2503.01840.
[^2]: EAGLE-3 checkpoint config specifies `num_hidden_layers: 1` (single decoder layer). `config.json`.
[^3]: Layer indices computed as `{2, N//2, N-3}` in `modeling_llama_kv.py`. For 32-layer LLaMA-3.1-8B: layers 2, 16, 29.
[^4]: `draft_vocab_size: 32000` from checkpoint `config.json`, vs the full LLaMA-3.1 vocab of 128,256.
[^5]: Paper Table 1, temperature=0 greedy decoding. Min 3.27x (LLaMA-3.3-70B on CNN/DM), max 6.47x (Vicuna-13B on HumanEval). arXiv:2503.01840.
[^6]: `eagle_draft.py` (256) + `eagle_generate.py` (604) + `eagle_convert.py` (164) + `eagle_bench.py` (124) = 1,148 lines.
[^7]: `mlx-community/Meta-Llama-3.1-8B-Instruct-4bit` — 4-bit quantized MLX conversion of `meta-llama/Meta-Llama-3.1-8B-Instruct`.
[^8]: `yuhuili/EAGLE3-LLaMA3.1-Instruct-8B` — pretrained EAGLE-3 draft head, 850MB PyTorch checkpoint.
[^9]: `eagle_draft.py` line 55: `input_size = 2 * config.hidden_size`. Q/K/V projections take concatenated `[token_embedding, fused_features]` as input. Gist source.
[^10]: EAGLE-3 captures pre-layer activations (the hidden state before the layer executes), matching the training convention in `modeling_llama_kv.py`.
[^11]: 1000ms / 113.7 tok/s = 8.80ms per token.
[^12]: Measured across 5 prompts, 128 tokens each, consistent at 0.34 with and without 4-bit draft quantization.
[^13]: 118.9 / 113.7 = 1.046x, rounded to 1.05x.
[^14]: `KVCache` class in mlx-lm.
[^15]: `self.offset = 0` at `cache.py` line 275. A plain integer, incremented by `keys.shape[2]` each step.