Inference optimization for transformer models: two independent techniques that can be used separately or stacked together.
Tested on Qwen2.5-14B running on Apple Silicon via MLX.
```
forge-edge-repo/
├── README.md
├── requirements.txt
├── infer_forge_edge_stacked.py       # Run Stage 1 + Stage 2A together (benchmark)
├── chat_forge_edge_v5.py             # Interactive chat with both stages active
│
├── word_router/                      # Stage 1 — Surprisal Router
│   ├── train_router_8b.py            # Train the router MLP on cached hidden states
│   ├── forge_edge_poc_v2.py          # Proof-of-concept / earlier training variant
│   ├── infer_forge_edge_qwen14b.py   # Inference engine with router gating
│   └── benchmark_4.py                # Per-token benchmark with router decision logging
│
└── ablation/                         # Stage 2A — MLP Lobotomy
    ├── profile_ffn_sparsity.py       # Analyse FFN activation sparsity per layer
    ├── benchmark_mlp_ablation_mlx.py # Benchmark ablated model vs baseline
    ├── lobotomy_qwen2.5_14b_instruct_bf16.json # Layer sensitivity ranking (14B bf16)
    ├── ablation_benchmark_mlx.json   # Benchmark results
    └── ablation_mlx_v2.json          # v2 benchmark results
```
Transformer token usage follows a Zipfian distribution: roughly 20% of the vocabulary accounts for ~95% of all generated tokens. The router exploits this.
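As a rough sanity check on that claim: even an ideal Zipf distribution with exponent 1 over a 32K vocabulary puts ~85% of probability mass in the top 20% of ranks, and real model output is typically more skewed than that. A quick sketch (pure Python, illustrative numbers only):

```python
# Probability mass covered by the top 20% of ranks under an ideal
# Zipf distribution, p(r) ∝ 1/r, over a 32K-token vocabulary.
VOCAB = 32_768

weights = [1.0 / r for r in range(1, VOCAB + 1)]
total = sum(weights)

top_20pct = int(0.2 * VOCAB)  # ~6.5K highest-frequency tokens
coverage = sum(weights[:top_20pct]) / total

print(f"top 20% of vocab covers {coverage:.1%} of token mass")
```

Generated text concentrates even harder on common tokens, which is why the common head can answer most steps.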
At each generation step, a small MLP (1.6M params) reads the hidden state and predicts whether the next token will be rare. If not, the model skips the expensive full-vocabulary matmul (hidden × W_vocab, 32K tokens) and uses a pre-sliced common head covering only ~6,400 tokens instead — a ~5× cheaper operation.
```
hidden state
     │
     ├──► Router MLP ──► P(rare) > threshold?
     │                          │
     │              Yes         │          No
     │        ┌─────────────────┴──────────────┐
     │        ▼                                ▼
     │   Full head (32K)              Common head (~6.4K)
     │   expensive                    ~5× cheaper
     └──────────────────────────────────────────
```
Result: ~20% latency reduction with <0.5% quality loss.
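The gating step can be sketched in a few lines of NumPy (shapes, router sizes, and the `common_ids` placeholder here are illustrative; the real implementation runs on MLX with the trained 1.6M-param router and the token set from `common_ids.npy`):

```python
import numpy as np

rng = np.random.default_rng(0)
HIDDEN, VOCAB, COMMON = 64, 32_768, 6_400

W_vocab = rng.standard_normal((VOCAB, HIDDEN)).astype(np.float32)

# Pre-sliced common head: rows of W_vocab for the common token set.
# A simple range is a placeholder; the real ids come from training.
common_ids = np.arange(COMMON)
W_common = W_vocab[common_ids]

# Tiny router MLP ending in a sigmoid (sizes illustrative, not the real 1.6M).
W1 = 0.05 * rng.standard_normal((128, HIDDEN)).astype(np.float32)
W2 = 0.05 * rng.standard_normal(128).astype(np.float32)

def router_p_rare(h):
    """Predicted probability that the next token falls outside the common set."""
    return 1.0 / (1.0 + np.exp(-(W2 @ np.maximum(W1 @ h, 0.0))))

def next_token(h, threshold=0.75):
    if router_p_rare(h) > threshold:
        logits = W_vocab @ h                # full head: 32K-row matmul
        return int(np.argmax(logits))
    logits = W_common @ h                   # common head: ~5× cheaper
    return int(common_ids[int(np.argmax(logits))])  # map back to vocab ids
```

Because `W_common` is just a row slice of `W_vocab`, common-head logits are exactly equal to the corresponding full-head logits; quality loss comes only from the rare steps the router misroutes.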
Not all transformer layers contribute equally. A sensitivity analysis (profile_ffn_sparsity.py) identifies which FFN/MLP blocks can be zeroed out with minimal impact on output quality. These are then replaced at load time with a ZeroMLP that returns zeros, keeping the residual stream intact but skipping the computation entirely.
The cut list for Qwen2.5-14B bf16 is pre-computed in lobotomy_qwen2.5_14b_instruct_bf16.json (11 layers removed).
Result: ~15–20% additional latency reduction with <1% quality loss.
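The load-time replacement can be sketched as follows (a framework-agnostic NumPy stand-in with a toy block; the repo does the equivalent with MLX modules and the cut list from the JSON file):

```python
import numpy as np

class ZeroMLP:
    """Drop-in FFN replacement: contributes nothing to the residual stream."""
    def __call__(self, x):
        return np.zeros_like(x)

class Block:
    """Toy transformer block: residual + FFN (attention omitted for brevity)."""
    def __init__(self, hidden, seed=0):
        rng = np.random.default_rng(seed)
        self.W = (0.1 * rng.standard_normal((hidden, hidden))).astype(np.float32)
        self.mlp = lambda x: np.tanh(x @ self.W)

    def __call__(self, x):
        return x + self.mlp(x)            # residual stream stays intact

def apply_lobotomy(blocks, cut_layers):
    """Replace the FFN of each listed layer with ZeroMLP at load time."""
    for i in cut_layers:
        blocks[i].mlp = ZeroMLP()
    return blocks
```

An ablated block becomes the identity on the residual stream, so downstream layers still see a well-formed input while the FFN matmuls are skipped entirely.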
Both stages compose cleanly. infer_forge_edge_stacked.py and chat_forge_edge_v5.py apply Stage 2A at load time and Stage 1 at each generation step.
Combined result: ~35–40% throughput improvement, ~30–40% energy reduction, <2% quality loss.
```
pip install -r requirements.txt
```

Requires Python 3.10+ and Apple Silicon (MLX). For NVIDIA GPUs, the router training and inference scripts also support PyTorch checkpoints.
Phase 1 caches backbone hidden states (~15 min, one-time). Phase 2 trains the router MLP (~2 min).
```
python word_router/train_router_8b.py

# Skip Phase 1 if cache already exists
python word_router/train_router_8b.py --skip-cache

# More training steps
python word_router/train_router_8b.py --max-steps 20000
```

A checkpoint is saved to `./checkpoints_8b/step_XXXXXXX/` containing:

- `router.safetensors` — router weights
- `common_ids.npy` — the common token set
- `meta.txt` — training metadata
Profiles which FFN layers are safest to cut based on activation sparsity (Gini coefficient / Zipfian structure).
```
python ablation/profile_ffn_sparsity.py --model mlx-community/Qwen2.5-14B-Instruct-bf16
```

Output: per-layer fire rates, Gini coefficients, estimated bandwidth savings. Use the results to build a lobotomy JSON.
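One way to turn the profiler output into a cut list might look like this. The fire-rate values and the JSON layout are assumptions for illustration; match the field names to what `profile_ffn_sparsity.py` and the loader actually use:

```python
import json

# Hypothetical profiler output: per-layer fraction of FFN units that fire.
fire_rates = {0: 0.62, 1: 0.55, 2: 0.08, 3: 0.41, 4: 0.05, 5: 0.07}

# Cut the N layers whose FFNs fire least often (sparsest = safest to ablate).
N = 3
cut_layers = sorted(fire_rates, key=fire_rates.get)[:N]

# Assumed JSON shape; write this string to e.g. lobotomy_custom.json.
config = json.dumps({"cut_layers": sorted(cut_layers)}, indent=2)
print(sorted(cut_layers))  # layers 2, 4, 5 in this toy example
```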
```
# Router vs baseline comparison
python word_router/benchmark_4.py --mode both \
    --checkpoint ./qwen14b \
    --max-tokens 200

# Router only
python word_router/benchmark_4.py --mode fe \
    --checkpoint ./qwen14b \
    --threshold 0.75
```

```
# Router only
python infer_forge_edge_stacked.py \
    --checkpoint ./qwen14b \
    --max-tokens 500

# Router + Lobotomy (Stage 1 + 2A)
python infer_forge_edge_stacked.py \
    --checkpoint ./qwen14b \
    --lobotomy \
    --max-tokens 500
```

```
python chat_forge_edge_v5.py \
    --model mlx-community/Qwen2.5-14B-Instruct-bf16 \
    --checkpoint ./qwen14b/step_0015000 \
    --lobotomy ablation/lobotomy_qwen2.5_14b_instruct_bf16.json \
    --threshold 0.8
```

Type `exit` or `quit` to end the session. Throughput (tok/s) is printed after each response.
| Configuration | Throughput | Latency | Power | Quality loss |
|---|---|---|---|---|
| Baseline | ~25 tok/s | ~40 ms/t | ~110 W | — |
| + Stage 1 (router) | ~30 tok/s | ~33 ms/t | ~100 W | <0.5% |
| + Stage 2A (lobotomy) | ~34 tok/s | ~29 ms/t | ~85 W | <1% |
| Stacked (1 + 2A) | ~40 tok/s | ~25 ms/t | ~70 W | <2% |
Pre-trained router checkpoints for Qwen2.5-14B are in `./qwen14b/`. Place your trained checkpoints in the same structure:
```
qwen14b/
└── step_0015000/
    ├── router.safetensors
    ├── common_ids.npy
    └── meta.txt
```