[Vulkan] Lower common decode primitive sequences into fused execution regions #17
Description
Problem
The decode path still executes many stable primitive sequences one primitive at a time. Each primitive adds its own recording boundary, descriptor writes, and synchronization points, even though the sequence shape repeats almost unchanged from token to token.
This is related to #4 but is not a duplicate of it. Issue #4 is about descriptor-set churn. This issue is about lowering stable decode subsequences into fewer Vulkan execution regions so there is less descriptor and scheduling work to do in the first place.
Why This Matters
The reference runtimes reduce overhead not only through low-level reuse, but also by making decode execution structurally larger and more fusion-friendly:
- ggml reorders/fuses graph patterns to preserve efficient execution regions
- Zinc hand-writes large token-local execution blocks instead of dispatching everything as isolated primitives
MLX should identify repeated decode subsequences and lower them into fused or region-based Vulkan execution paths.
Tasks
- Identify common decode subsequences such as norm -> projection, residual/add chains, and KV update + immediate attention consumption
- Lower those subsequences into fewer Vulkan execution regions or fused paths
- Reuse bindings/scratch/layout state across the region instead of rebinding per primitive when possible
- Measure reduction in host enqueue time and descriptor traffic per token
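The first two tasks can be illustrated with a small host-side sketch. It is only a sketch under stated assumptions: the primitive names (`rms_norm`, `matmul`, `kv_update`, `attention`, `silu`, `add`), the pattern table, and the `lower_to_regions` helper are hypothetical and do not come from MLX, ggml, or Zinc. It greedily matches known subsequences in a flat per-token op list and groups each match into one region, leaving everything else as a single-primitive region.

```python
from typing import List, Sequence, Tuple

# Hypothetical pattern table mirroring the subsequences named in the task list.
FUSABLE_PATTERNS: List[Tuple[str, ...]] = [
    ("rms_norm", "matmul"),      # norm -> projection
    ("kv_update", "attention"),  # KV update + immediate attention consumption
    ("add", "add"),              # residual/add chain
]


def lower_to_regions(
    ops: Sequence[str],
    patterns: Sequence[Tuple[str, ...]] = FUSABLE_PATTERNS,
) -> List[Tuple[str, ...]]:
    """Group a flat op sequence into regions using greedy pattern matching."""
    # Prefer longer patterns so a long match is not shadowed by a short one.
    ordered = sorted(patterns, key=len, reverse=True)
    regions: List[Tuple[str, ...]] = []
    i = 0
    while i < len(ops):
        for pat in ordered:
            if tuple(ops[i:i + len(pat)]) == pat:
                regions.append(pat)
                i += len(pat)
                break
        else:
            # No pattern starts here: keep the primitive as its own region.
            regions.append((ops[i],))
            i += 1
    return regions


# Illustrative op order for one decoder layer (an assumption, not a trace).
decode_layer = [
    "rms_norm", "matmul", "kv_update", "attention", "matmul", "add",
    "rms_norm", "matmul", "silu", "matmul", "add",
]
regions = lower_to_regions(decode_layer)
print(len(decode_layer), "primitives ->", len(regions), "regions")
```

In a real backend each region would map to one recording scope with shared bindings and scratch state; here a region is just a tuple of op names, so the sketch only shows how the grouping could reduce the number of boundaries.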
Acceptance Criteria
- Lower host-side enqueue time per token on Qwen3 decode, measured against the current per-primitive path
- Fewer recording boundaries and descriptor updates in the hot path
- No generation correctness regressions
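The first two criteria could be checked with per-token counters along the following lines. This is a minimal sketch with assumed names: `DecodeStats`, `time_enqueue`, and `meets_criteria` are hypothetical helpers, not existing MLX APIs, and the backend would have to increment the counters at its own recording and descriptor-update sites. Correctness regressions need a separate output comparison that is not shown here.

```python
import time
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Dict


@dataclass
class DecodeStats:
    """Hypothetical accumulator for host-side decode overhead."""
    tokens: int = 0
    enqueue_ns: int = 0
    recording_boundaries: int = 0
    descriptor_updates: int = 0

    @contextmanager
    def time_enqueue(self):
        # Wrap the host-side enqueue call to accumulate wall-clock time.
        start = time.perf_counter_ns()
        try:
            yield
        finally:
            self.enqueue_ns += time.perf_counter_ns() - start

    def per_token(self) -> Dict[str, float]:
        n = max(self.tokens, 1)  # avoid dividing by zero before any token
        return {
            "enqueue_ns": self.enqueue_ns / n,
            "recording_boundaries": self.recording_boundaries / n,
            "descriptor_updates": self.descriptor_updates / n,
        }


def meets_criteria(baseline: DecodeStats, candidate: DecodeStats) -> bool:
    """True only if every per-token metric is strictly lower than baseline."""
    b, c = baseline.per_token(), candidate.per_token()
    return all(c[key] < b[key] for key in b)
```

A usage pattern would be to run the same prompt through the per-primitive path and the region-based path, fill one `DecodeStats` for each, and compare them with `meets_criteria`. Requiring a strict improvement on all three metrics is a design choice of this sketch; the issue itself does not state a threshold.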
References
- mlx-vulkan-reference-conclusions.md
- references/ggml-vulkan-findings.md
- references/zinc-findings.md
- references/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp (graph reordering / fusion-friendly execution)
- references/zinc/src/compute/forward.zig (imperative token-local decode schedule)