diff --git a/.gitignore b/.gitignore index 72c21a8..fd90917 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ build/ # Virtual environments .venv/ +.venv-mlx-dev/ venv/ ENV/ @@ -33,28 +34,36 @@ htmlcov/ zig/.zig-cache/ zig/zig-out/ -# Project-specific exclusions (sensitive/local-only) -firebase-debug.log -SETUP_TRAINING.md -NEXT_SESSION.md -CLAUDE.md +# Session-specific prompt files (local handoff docs, not general documentation) +CAMPAIGN_PLAN.md HANDOFF_PROMPT.md REVIEW_PROMPT.md IMPLEMENTATION_SUMMARY.md -AGENTS.md -UPSTREAM_PLAN.md PARALLEL_MOE_PROMPT.md +NEXT_SESSION_PROMPT.md +NEXT_SESSION.md +NEXT_AI_KERNEL_DISCOVERY_TAKEOVER_PROMPT.md +SETUP_TRAINING.md +firebase-debug.log + +# Ephemeral output directories (generated locally, not shipped) +sessions/ +runs/ +training_data/ +discover_sessions/ benchmarks/results/ -# Local-only directories (not shipped) +# External projects (cloned locally, not part of ZMLX) mlx_local/ -_worktrees/ -stable-diffcoder-mlx/ -experiments/ exo/ vllm-metal/ +stable-diffcoder-mlx/ zmlx_kvtc_integration/ +_worktrees/ + +# Local experiment scripts (not shipped) +experiments/ + +# Tool state .aleph/ -NEXT_SESSION_PROMPT.md .hf_cache/ -src/zmlx/foundry/ diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..bb61a5a --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,72 @@ +# AGENTS.md + +Guidance for AI coding agents working in this repository. + +## Rules + +- Do **not** include machine-specific absolute paths (e.g., `/Volumes/VIXinSSD/...`) in README, docs, or user-facing text. Use placeholders like `<path>`, `$HF_HOME`, or repository-relative paths. +- Do **not** fabricate benchmark numbers. All performance claims must come from actual measurements with repro capsules in `benchmarks/repro_capsules/`. +- Do **not** modify `mlx_local/` or `exo/` — these are external projects cloned locally and gitignored. +- Always activate the venv (`source .venv/bin/activate`) before running any Python commands. 
+- Run `ruff check .` and `pytest -q` before considering any code change complete. + +## Project Overview + +ZMLX is a Metal kernel toolkit for MLX on Apple Silicon. It provides: + +1. **Kernel authoring** — `elementwise("x * tanh(log(1 + exp(x)))")` compiles to Metal +2. **Model patching** — `patch(model)` fuses MoE expert dispatch for faster decode +3. **70+ kernel catalog** — activations, attention, norms, MoE, quant, loss, etc. +4. **Custom C++ primitive** — `gather_qmm_swiglu` for GLM/Qwen3 (optional, ~800 lines Metal/C++) + +### What Actually Works (Proven Results) + +| Model | Speedup | Requires | +|:--|--:|:--| +| LFM2-8B-A1B-4bit | +11.6% decode | stock MLX | +| GLM-4.7-Flash-4bit | +8.5% decode | custom `gather_qmm_swiglu` | +| Qwen3-30B-A3B-4bit | +5.5% decode | custom `gather_qmm_swiglu` | + +All token-identical under greedy decoding. + +## Key Commands + +```bash +pytest -q # ~670 tests +ruff check . # lint +python -m zmlx.validate --runs 3 # fidelity + throughput +python -m zmlx.matrix catalog # 58 models with metadata +python -m zmlx.matrix report # test matrix heatmap +``` + +## File Layout + +``` +src/zmlx/ + patch/ # model patching (the main win) + patterns/moe_mlp.py # fused MoE expert dispatch + patterns/swiglu_mlp.py # dense SwiGLU fusion + __init__.py # patch(), safety excludes + kernels/ # 70+ Metal kernels (19 modules) + matrix/ # test matrix: catalog, runner, reports + foundry/ # kernel template evaluation + SFT dataset export + discover/ # LLM-guided PUCT kernel optimization search + train/ # LoRA training CLI + validate.py # fidelity + throughput validation + api.py # kernel authoring API + metal.py # Metal kernel wrapper +tests/ # ~670 tests +benchmarks/ # benchmark scripts + repro capsules +configs/ # training/foundry config YAML files +docs/ # user-facing documentation +integrations/ # custom MLX primitive patch +``` + +## Model-Aware Safety + +`patch()` auto-detects model family and skips patterns with known issues: +- **Qwen**: 
`swiglu_mlp` and `residual_norm` break fidelity +- **GLM/Qwen on stock MLX**: `moe_mlp` regresses (needs custom primitive) +- **Mixtral**: `moe_mlp` breaks fidelity + +See `_FIDELITY_EXCLUDES` and `_PERF_EXCLUDES` in `src/zmlx/patch/__init__.py`. diff --git a/CHANGELOG.md b/CHANGELOG.md index 871d6a8..ebeddb2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - GLM combine-mode default now resolves to `fp32_no_fma` when `ZMLX_GLM_COMBINE_MODE` is unset/invalid, aligning runtime behavior with the documented default path. - ## [0.8.5] - 2026-02-11 ### Added diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..36f54df --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,267 @@ +# CLAUDE.md + +Context for Claude Code when working in this repository. + +## What ZMLX Is + +ZMLX is an open-source toolkit that extends [MLX](https://github.com/ml-explore/mlx) with custom Metal kernels for Apple silicon. It provides a kernel authoring API, automatic gradient support, a catalog of over 70 ready-to-use kernels, and a model patching system that replaces MLX model layers with fused kernel equivalents. Think "Triton for MLX": author kernels in Python, compile to Metal, and wire them into real models. + +**Historical best-case wins:** ZMLX has achieved measurable, token-identical +decode speedups on real models shipping today: + +| Model | Hardware | Speedup | MLX Required | +|:--|:--|--:|:--| +| LFM2-8B-A1B-4bit | M4 Max 36GB | **+11.6%** | stock (pip install) | +| LFM2-8B-A1B-4bit | M1 Pro 16GB | **+9.3%** | stock (pip install) | +| GLM-4.7-Flash-4bit | M4 Max 36GB | **+8.5%** | custom `gather_qmm_swiglu` | +| LFM2-24B-A2B-4bit | M4 Max 36GB | **+7.2%** | stock (pip install) | +| Qwen3-30B-A3B-4bit | M4 Max 36GB | **+5.5%** | custom `gather_qmm_swiglu` | + +All results were verified token-identical under greedy decoding in their +recorded run windows. 
Repro capsules live in `benchmarks/repro_capsules/`. + +**Current benchmark truth set (2026-02-24 update):** +- LFM2-24B-A2B-MLX-4bit: D-SIMD gate kernel + native combine, fused SwiGLU + disabled. **+7.2% decode** on M4 Max, 500/500 fidelity PASS (stock MLX). + Auto-enabled for K>=3. +- GLM-4.7-Flash: keep `glm_combine_fp32_no_fma` as active default + (`promote_candidate`; not fully promoted). +- Qwen3-30B-A3B: keep `control_patterns_moe_mlp` as benchmark anchor; no + custom-kernel variant is currently promoted. + +### How It Works + +The kernel authoring API compiles Metal Shading Language directly from Python expressions. A call like `elementwise("x * tanh(log(1 + exp(x)))")` generates a Metal kernel, compiles it, caches it, and returns a callable that behaves like any other MLX op. Gradient support is built on `mx.custom_function`, where the backward passes are themselves Metal kernels rather than autodiff tape operations. + +The primary use case is faster MoE decode inference. In MoE models, each token is routed to a subset of expert networks, and the standard path dispatches multiple Metal kernels per expert per layer. ZMLX fuses several of these into single dispatches, reducing command buffer overhead. The fused SwiGLU expert dispatch activates only when M <= 32 (decode); during prefill the standard MLX code path runs unchanged. + +### Custom `gather_qmm_swiglu` Primitive + +For GLM and Qwen3, we wrote a custom C++ Metal primitive (~800 lines) that fuses gate projection + up projection + SwiGLU activation for quantized MoE experts into a single GPU dispatch. This primitive is **not part of released MLX** — it lives in `mlx_local/` and is applied via the patch in `integrations/mlx_local_integration/`. On stock MLX, ZMLX auto-detects its absence and safely skips these models (0 modules patched, no regressions). + +## Dev Environment + +The project uses a venv at `.venv/`. Always use the venv Python — never the system Python. 
+ +```bash +source .venv/bin/activate # activate before any work +pip install -e ".[dev]" # install editable (already done) +``` + +All commands below assume the venv is active. If you see `ModuleNotFoundError: No module named 'zmlx'`, you forgot to activate. + +## Build and Test + +```bash +pytest -q # full test suite (~670 tests) +pytest tests/test_kernels_catalog.py -q # kernel correctness only +pytest -k "test_softmax" -q # single test by name + +ruff check . # lint +mypy src # type check +``` + +End-to-end model validation compares patched and unpatched models under greedy decoding: + +```bash +python -m zmlx.validate --max-tokens 500 --runs 5 +``` + +Repro capsules capture raw per-run data with full environment metadata: + +```bash +python -m zmlx.bench.report benchmarks/repro_capsules/lfm2_m4max_20260131.json +``` + +Tests auto-skip on platforms without Apple silicon or Metal GPU support (see `tests/conftest.py`). + +## Benchmark Guardrails + +For GLM/Qwen3-sized benchmarks: +- Do not run multi-variant sweeps in one Python process. +- Use `benchmarks/bench_iso_variant_sweep.py` so each variant runs in an + isolated subprocess. +- For GLM consistency checks, run AB/BA blocks with cooldown windows. +- For matrix controls, include `--patterns none` anchors. + +Campaign entrypoint (phase-based): + +```bash +bash benchmarks/run_3hr_benchmark_campaign.sh +``` + +Recommended phase flow: +1. `quick` +2. `glm_abba_200` +3. `glm_abba_1024` +4. `qwen_isolation_200` +5. `qwen_isolation_1024` +6. `glm_stress` +7. 
`matrix_controls` + +## Test Matrix + +The matrix module (`src/zmlx/matrix/`) provides structured model catalog, test runner, and reporting across all Exo-supported models: + +```bash +python -m zmlx.matrix catalog # 58 models with metadata +python -m zmlx.matrix run --runs 3 # validate one model +python -m zmlx.matrix run --all --runs 3 # all models that fit in RAM +python -m zmlx.matrix report # terminal heatmap table +python -m zmlx.matrix csv # CSV export +python -m zmlx.matrix html > matrix.html # self-contained HTML +``` + +The catalog ingests Exo's TOML model cards plus manual LFM2 entries. Each model has inferred architecture (MoE/dense), expected ZMLX patterns, and exclusion reasons. + +## Custom MLX Primitive (optional) + +Some MoE fusions require a local MLX fork that exposes `mx.gather_qmm_swiglu`. ZMLX auto-detects this at runtime and only enables those paths when present. + +```bash +python -c "import mlx.core as mx; print(hasattr(mx, 'gather_qmm_swiglu'))" +``` + +Build instructions: `docs/EXPERIMENTAL_MLX.md`. The patch: `integrations/mlx_local_integration/gather_qmm_swiglu.patch`. + +## Architecture + +The codebase is organized under `src/zmlx/`: + +### Metal kernel infrastructure +`metal.py` wraps `mx.fast.metal_kernel` with in-process caching keyed on source hash and configuration. `cache.py` manages the global compilation cache. `_compat.py` detects the platform and guards imports on non-Apple-silicon systems. + +### Code generation +`codegen.py` provides template generators for elementwise, rowwise reduction, and two-pass map-reduce patterns. `elementwise.py` exposes `unary()`, `binary()`, and `map()` builders. `autograd.py` creates differentiable ops via `mx.custom_function`. `autotune.py` searches threadgroup size candidates. 
+ +### Kernel catalog (`kernels/`) +70+ kernels across 19 modules: activations, attention, bits, fused, fused_moe, image, indexing, linear, loss, moe, norms, optimizers, quant, reductions, rope, scan, softmax, transformer, vlsp. + +### Patch system (`patch/`) +Walks a model tree, detects the architecture, replaces matching layers with fused kernel equivalents. Each pattern is a single file in `patch/patterns/`: + +- `moe_mlp.py` — fused MoE expert dispatch (the main performance win) +- `swiglu_mlp.py` — dense SwiGLU layer fusion +- `geglu_mlp.py` — GeGLU activation fusion +- `rmsnorm.py`, `layernorm.py` — norm kernel replacements +- `softmax.py` — softmax kernel replacement +- `residual_norm.py` — fused residual + norm + +Model-aware safety: `_FIDELITY_EXCLUDES` and `_PERF_EXCLUDES` in `patch/__init__.py` auto-skip patterns with known issues per model family (Qwen: swiglu/residual fidelity; GLM/Qwen: moe_mlp perf risk on stock MLX). + +### Test matrix (`matrix/`) +Model catalog from Exo TOML cards, JSONL ledger for results, terminal/CSV/HTML reports. See `python -m zmlx.matrix --help`. + +### Foundry module (`foundry/`) +Kernel template evaluation and dataset generation. Generates Metal kernel variants across 16 ops (9 kernel classes), evaluates correctness and performance, and exports training data for kernel-writing SFT. + +CLI: `python -m zmlx.foundry {run, report, export, export-sft, list}`. + +Key submodules: `ops/` (op registry), `templates/` (Metal template discovery), `harness/` (compile + benchmark orchestrator), `sampling/` (candidate generation), `scheduler.py` (curriculum staging), `export/` (JSONL + chat SFT export), `workers.py` (multi-process execution). + +### Discover module (`discover/`) +LLM-guided PUCT tree search for Metal kernel optimization. Uses LLM backends (Claude, OpenAI, mock) to propose kernel variants, evaluates them against baselines, and tracks the search tree. 
+CLI: `python -m zmlx.discover {search, autorun, compare, report, list, export}`. + +### Training module (`train/`) +LoRA fine-tuning CLI: `zmlx train`. Config, runner, callbacks, export. + +### Integrations +- `exo.py` + `_exo_bootstrap/` — exo distributed inference integration +- `mlx_lm_compat.py` — compatibility layer for mlx-lm API differences +- `kv_cache.py` — quantized KV cache support + +## Model Support + +| Family | ZMLX Family Key | Architecture | Stock MLX | + Custom Primitive | +|:--|:--|:--|:--|:--| +| LFM2 | `lfm` | MoE | **+6-12% decode** (8B: fused SwiGLU; 24B: D-SIMD gate) | same | +| GLM-4.7 | `glm` | MoE | 0% (auto-skipped) | custom primitive path; current default `fp32_no_fma` (`promote_candidate`) | +| Qwen3-MoE | `qwen` | MoE | 0% (auto-skipped) | no promoted custom-kernel variant in current truth set | +| GPT-OSS | `gpt_oss` | MoE | ~+1% | same | +| DeepSeek-V3 | `deepseek` | MoE | patterns apply | untested (needs >300GB) | +| Kimi-K2.5 | `deepseek` | MoE | patterns apply | untested (needs >300GB) | +| Llama | `llama` | Dense | swiglu_mlp | same | +| Other | `unknown` | varies | safe no-op | same | + +`patch()` is always safe to call. It returns unchanged if no patterns match or if the model family has known issues. + +## CLI Entry Points + +| Command | Purpose | +|:--|:--| +| `python -m zmlx.validate <model>` | Fidelity + throughput validation | +| `python -m zmlx.matrix {catalog,run,report,csv,html}` | Test matrix across models | +| `python -m zmlx.bench.report <capsule>` | Print benchmark report from capsule | +| `python -m zmlx.foundry {run,report,export,export-sft,list}` | Kernel template evaluation + dataset generation | +| `python -m zmlx.discover {search,autorun,compare,report,list,export}` | LLM-guided kernel optimization search | +| `zmlx train` | LoRA fine-tuning CLI | + +## Design Decisions + +MLX caches compiled Metal programs by source string. 
ZMLX generates MSL deterministically, so the same kernel configuration always reuses the compiled binary. No separate compilation cache needed. + +In MoE decode, sequence length M is typically 1. At this scale, kernel dispatch overhead dominates over compute. The fused SwiGLU path activates only when M <= 32. At larger M (prefill), the standard MLX code path runs unchanged — prefill throughput is neutral. + +Binary elementwise ops require matching shapes. Deliberate choice to avoid silent shape-dependent bugs from broadcasting. + +## Conventions + +Linting: `ruff check .` — line-length 100, target py310, rules `[E, F, I, UP, B]`, E501 ignored. +Type checking: `mypy src` — `ignore_missing_imports = true`, `warn_return_any = true`. + +Every kernel needs a correctness test against the MLX reference op (see `tests/test_kernels_catalog.py`). +Performance claims need a repro capsule in `benchmarks/repro_capsules/` with hardware, OS, MLX version, raw per-run data. +New patch patterns go in `src/zmlx/patch/patterns/` as a single file, registered in `patterns/__init__.py`. + +## Release Policy + +- **Stable (default)**: stock MLX, only token-identical patches enabled. +- **Fast (opt-in)**: custom MLX allowed; experimental kernels opt-in, must be validated. +- **Edge (opt-in)**: nightly/dev MLX + experimental kernels for local testing only. + +## Model Storage + +HuggingFace models resolve through `HF_HOME`. Set it to a directory with enough space: + +``` +export HF_HOME=/path/to/your/hf_cache +``` + +## Key Files + +| File | Purpose | +|:--|:--| +| `src/zmlx/patch/__init__.py` | Main `patch()` function, safety excludes, model family detection | +| `src/zmlx/patch/patterns/moe_mlp.py` | Fused MoE pattern (the main performance win) | +| `src/zmlx/patch/patterns/swiglu_mlp.py` | Dense SwiGLU fusion | +| `src/zmlx/kernels/fused_moe.py` | `gather_qmm_swiglu` detection and wrapper | +| `src/zmlx/kernels/moe.py` | MoE combine kernels (`moe_combine_no_fma`, etc.) 
| +| `src/zmlx/validate.py` | Fidelity + throughput validation CLI | +| `src/zmlx/matrix/` | Test matrix: catalog, runner, reports | +| `src/zmlx/api.py` | Public kernel authoring API | +| `src/zmlx/mlx_lm_compat.py` | mlx-lm version compatibility | +| `src/zmlx/foundry/__main__.py` | Foundry CLI entry point | +| `src/zmlx/foundry/ops/` | 16 registered ops across 9 kernel classes | +| `src/zmlx/foundry/export/sft.py` | Chat SFT dataset export | +| `src/zmlx/discover/__main__.py` | Discover CLI entry point | +| `integrations/mlx_local_integration/` | Custom MLX primitive patch | + +## Documentation + +| File | Contents | +|:--|:--| +| `README.md` | Benchmarks, usage, API overview, model support | +| `docs/TOUR.md` | Quick walkthrough for new contributors | +| `docs/ARCHITECTURE.md` | Design philosophy and multi-frontend vision | +| `docs/KERNELS.md` | Complete kernel catalog reference | +| `docs/QUICKSTART.md` | 5-minute kernel authoring tutorial | +| `docs/COOKBOOK.md` | Recipes for common patterns | +| `docs/BENCHMARKS.md` | Benchmark methodology and raw data | +| `docs/EXO.md` | exo integration guide (GLM/Qwen3) | +| `docs/EXPERIMENTAL_MLX.md` | Custom `gather_qmm_swiglu` primitive build instructions | +| `docs/ROADMAP.md` | Feature roadmap with priority matrix | +| `docs/FOUNDRY.md` | Foundry module: kernel template evaluation and SFT export | +| `docs/DEVELOPMENT.md` | Development areas and backlog | +| `UPSTREAM_PLAN.md` | What belongs upstream in MLX vs stays in ZMLX | diff --git a/README.md b/README.md index 4a9fe32..af0712a 100644 --- a/README.md +++ b/README.md @@ -11,8 +11,8 @@ ZMLX extends [MLX](https://github.com/ml-explore/mlx) with a Python-first Metal - **Metal kernels from Python:** write `elementwise("x * tanh(log(1 + exp(x)))")` and get a compiled Metal kernel with caching, autograd support, and the 70+ kernel catalog. 
- **Model patching:** `patch(model)` replaces MoE gating/combine/activation sequences with fused Metal kernels, reducing dispatch overhead during decode. Token-identical output; verify with `python -m zmlx.validate`. +- **Works with stock MLX:** LFM2-8B (+12%) and LFM2-24B (+7%) show consistent decode gains with `pip install mlx` — no custom builds required. - **Optional custom primitive (GLM/Qwen3):** build the custom `gather_qmm_swiglu` primitive to fuse quantized expert projections for GLM-4.7-Flash and Qwen3-30B-A3B. See the GLM-4.7-Flash stress benchmark results below + [`docs/EXPERIMENTAL_MLX.md`](docs/EXPERIMENTAL_MLX.md). On stock MLX these models auto-skip safely. -- **Proven on current MLX:** LFM2-8B-A1B-4bit shows consistent decode gains in current matrix runs with token-identical output. ## Benchmark Snapshot (2026-02-08) @@ -55,7 +55,7 @@ If you are using GLM with custom MLX, this is already the default behavior: | GLM-4.7-Flash-4bit-mxfp4 | `patch(model)` default (`glm_combine_fp32_no_fma`) | `+6.2%` (200), `+6.7%` (1024), `~+6.4%` average | `+2.3%` average (`+0.3%..+6.7%`) | PASS | `benchmarks/repro_capsules/glm47_combo_v8_fp32nofmaonly_t200_r2_summary.json`, `benchmarks/repro_capsules/glm47_combo_v8_fp32nofmaonly_t1024_r2_summary.json`, `benchmarks/repro_capsules/benchmark_vs_baseline_followup_20260211.json` | | Qwen3-30B-A3B-4bit | keep control baseline | no promoted overall gain claim | no reliable decode gain yet | PASS | `benchmarks/repro_capsules/benchmark_vs_baseline_followup_20260211.json` | -GLM long-context confirmation (`runs=5`, `max_tokens=1024`): decode `+0.93%` vs control (PASS fidelity). +GLM long-context confirmation (`runs=5`, `max_tokens=1024`): decode `+0.93%` vs control (PASS fidelity). Capsule: `benchmarks/repro_capsules/glm47_final_longconfirm_t1024_r5_20260211_summary.json`. 
How to actually get the extra GLM speedup: @@ -66,6 +66,28 @@ How to actually get the extra GLM speedup: For full protocol and per-variant detail, see `benchmarks/LAB_NOTEBOOK.md`. +### Benchmark Execution Protocol (2026-02-13) + +Use the isolation-first benchmark flow for GLM/Qwen3: +- Run one variant per process via `benchmarks/bench_iso_variant_sweep.py`. +- Use AB/BA replicate blocks (with cooldown) for GLM consistency checks. +- Treat Qwen custom-kernel variants as experimental only until a decode-positive + signal is reproduced against `control_patterns_moe_mlp`. + +Phase runner (explicit phase selection, no hidden background fanout): + +```bash +source .venv/bin/activate + +# Run a single phase (recommended) +bash benchmarks/run_3hr_benchmark_campaign.sh quick +bash benchmarks/run_3hr_benchmark_campaign.sh glm_abba_200 +bash benchmarks/run_3hr_benchmark_campaign.sh glm_abba_1024 + +# Optional full sequence +bash benchmarks/run_3hr_benchmark_campaign.sh all +``` + ## GLM-4.7-Flash Stress Benchmark (Historical Reference) Historical stress result (M4 Max, MLX `0.30.4.dev20260204+2f324cc`, 5 prompts x 3 lengths x 5 runs): @@ -130,8 +152,9 @@ pip install "zmlx[lm]" # includes mlx-lm for model patching import mlx_lm from zmlx.patch import patch -model, tokenizer = mlx_lm.load("mlx-community/LFM2-8B-A1B-4bit") -patch(model) # safe inference defaults for supported model families +# Works with any supported model — just change the model ID +model, tokenizer = mlx_lm.load("LiquidAI/LFM2-24B-A2B-MLX-4bit") +patch(model) # auto-detects model family, applies safe optimizations print( mlx_lm.generate( @@ -143,12 +166,31 @@ print( ) ``` +That's it. `patch(model)` handles everything automatically — model detection, kernel selection, and safety checks. No env vars or configuration needed. + 3. 
Verify token fidelity + throughput on your hardware: ```bash +# LFM2-24B (+7% on M4 Max) +python -m zmlx.validate LiquidAI/LFM2-24B-A2B-MLX-4bit --max-tokens 200 --runs 3 + +# LFM2-8B (+12% on M4 Max) python -m zmlx.validate mlx-community/LFM2-8B-A1B-4bit --max-tokens 200 --runs 3 ``` +One-command smoke inference (loads model, applies `zmlx.patch.patch(model)`, then generates): + +```bash +source .venv/bin/activate && python examples/inference_smoke.py --model-id <model-id> --prompt "<prompt>" --max-tokens 64 +``` + +Expected output shape: +- `[load] model=<model-id>` +- `[patch] Applying zmlx.patch.patch(model) with safe defaults` +- `[patch] Patched ...` +- `[generate] prompt='...' max_tokens=64` +- `[output]` followed by generated text + Tip: large model downloads use the Hugging Face cache; set `HF_HOME` to control its location. ## What's Inside @@ -190,6 +232,7 @@ ZMLX hooks into exo's model loading at runtime — when GLM/Qwen3 load with the | [`docs/COOKBOOK.md`](docs/COOKBOOK.md) | Recipes for common patterns | | [`docs/KERNELS.md`](docs/KERNELS.md) | Kernel catalog (by module/domain) | | [`docs/KNOWLEDGE_BASE.md`](docs/KNOWLEDGE_BASE.md) | Canonical KB schema, rebuild, and validation | +| [`docs/FOUNDRY.md`](docs/FOUNDRY.md) | Kernel template evaluation, dataset generation, SFT export | | [`docs/kernel_discovery.md`](docs/kernel_discovery.md) | Hamiltonian-guided fused-boundary kernel discovery (`zmlx.kd`) | | [`docs/BENCHMARKS.md`](docs/BENCHMARKS.md) | Benchmark methodology + raw data | | [`docs/ARCHITECTURE.md`](docs/ARCHITECTURE.md) | Design philosophy | @@ -221,6 +264,7 @@ Full methodology and raw data: [`docs/BENCHMARKS.md`](docs/BENCHMARKS.md). 
|:--|:--|--:|--:|:--|:--| | LFM2-8B-A1B-4bit | M4 Max 36 GB | 197.8 tok/s -> 223.2 tok/s | **+12.8%** | token-identical | [`benchmarks/repro_capsules/lfm2_m4max_20260205_rerun_mlx0304dev2f324cc.json`](benchmarks/repro_capsules/lfm2_m4max_20260205_rerun_mlx0304dev2f324cc.json) | | LFM2-8B-A1B-4bit | M1 Pro 16 GB | 105.5 tok/s -> 115.3 tok/s | +9.3% | token-identical | [`benchmarks/repro_capsules/lfm2_m1pro_20260131.json`](benchmarks/repro_capsules/lfm2_m1pro_20260131.json) | +| LFM2-24B-A2B-4bit | M4 Max 36 GB | 152.0 tok/s -> 161.1 tok/s | **+6.0%** | token-identical (500 tok) | [`benchmarks/repro_capsules/lfm2_24b_dsimd_gate_m4max_20260224.json`](benchmarks/repro_capsules/lfm2_24b_dsimd_gate_m4max_20260224.json) | | GPT-OSS-20B-4bit | M4 Max 36 GB | 121.8 tok/s -> 122.9 tok/s | +1.0% | token-identical | — | To print a report from a capsule: @@ -242,7 +286,7 @@ ZMLX provides the model-side integration: auto-detecting MoE architectures, rewi | Model | Recommended config | Overall decode gain vs unpatched baseline | Fidelity | Evidence | |:--|:--|--:|:--|:--| -| GLM-4.7-Flash-4bit-mxfp4 | `glm_combine_fp32_no_fma` | `+6.2%` (200), `+6.7%` (1024), `~+6.4%` average | PASS | `benchmarks/repro_capsules/glm47_combo_v8_fp32nofmaonly_t200_r2_summary.json`, `benchmarks/repro_capsules/glm47_combo_v8_fp32nofmaonly_t1024_r2_summary.json` | +| GLM-4.7-Flash-4bit-mxfp4 | `glm_combine_fp32_no_fma` | `+6.2%` (200), `+6.7%` (1024), `~+6.4%` average | PASS | `benchmarks/repro_capsules/glm47_combo_v8_fp32nofmaonly_t200_r2_summary.json`, `benchmarks/repro_capsules/glm47_combo_v8_fp32nofmaonly_t1024_r2_summary.json`, `benchmarks/repro_capsules/benchmark_vs_baseline_followup_20260211.json` | Qwen note: no candidate is promoted yet; keep control baseline until a clear decode-positive variant is reproduced. 
@@ -261,7 +305,8 @@ See [`docs/EXPERIMENTAL_MLX.md`](docs/EXPERIMENTAL_MLX.md) for build instruction | Model | Stock MLX | + Custom primitive | What ZMLX does | |:--|:--|:--|:--| -| LFM2-8B-A1B | speedup (see stock MLX table) | same | ZMLX Metal kernels: fused MoE gating + combine + SwiGLU | +| LFM2-8B-A1B | **+12% decode** | same | Fused MoE gating + combine + SwiGLU activation | +| LFM2-24B-A2B | **+6-7% decode** | same | D-SIMD fused gating kernel (64 experts, K=4) | | GLM-4.7-Flash | 0% (auto-skipped) | speedup (see custom primitive table) | ZMLX patching + custom `gather_qmm_swiglu` primitive | | Qwen3-30B-A3B | 0% (auto-skipped) | speedup (see custom primitive table) | ZMLX patching + custom `gather_qmm_swiglu` primitive | | GPT-OSS-20B | fused SwiGLU activation | same | ZMLX Metal kernel: fused SwiGLU activation | diff --git a/UPSTREAM_PLAN.md b/UPSTREAM_PLAN.md new file mode 100644 index 0000000..f8ae250 --- /dev/null +++ b/UPSTREAM_PLAN.md @@ -0,0 +1,66 @@ +# Upstream Plan + +What belongs in MLX, what stays in ZMLX, and how to get there. + +## Scope + +ZMLX sits on top of MLX's `mx.fast.metal_kernel` and `mx.custom_function` APIs. Most of ZMLX (kernel authoring, autograd wrappers, the catalog) is a toolkit that doesn't need to be upstream. The fused C++ Metal primitives in `mlx_local/` are the upstream candidates — they need access to MLX internals that `metal_kernel` can't reach. + +## Principles + +- Prefer general-purpose primitives that benefit many model families. +- Keep MoE‑specific fusions in ZMLX unless MLX maintainers request them. +- Default user path must remain token‑identical (performance without silent drift). + +## Upstream candidates (general‑purpose first) + +| Candidate | Status | Why it belongs in MLX | +|:--|:--|:--| +| `mx.fast.swiglu` | Proposed | Common transformer activation (Llama/Qwen/Mistral/Gemma/Phi). Small, clean primitive that mirrors `mx.fast.*` patterns. 
| +| `add_rms_norm` | Benchmark‑gated | Only useful if MLX doesn’t already fuse `x + residual` into `rms_norm`. | +| `gather_qmm_swiglu` | Local fork (RFC first) | MoE‑specific, uses MLX quantized matmul internals. Needs maintainers’ buy‑in before a PR. | + +### Minimal PR: `mx.fast.swiglu` + +**What it does:** Computes `silu(gate) * up` in a single pass with a small elementwise Metal kernel. + +**Why it helps:** Every modern transformer uses SwiGLU. A fused op reduces dispatch overhead and intermediate buffers and fits MLX’s existing `mx.fast` API pattern. + +**PR structure:** +1. Declare `swiglu()` in `mlx/fast.h` +2. Add primitive + VJP (patterned on RMSNorm) +3. Implement a tiny Metal kernel (single‑pass elementwise) +4. Add Python binding +5. Tests: forward correctness + VJP + CPU/Metal parity + +### Benchmark gate: `add_rms_norm` + +Before proposing a primitive, benchmark whether MLX already fuses: + +``` +mx.fast.rms_norm(x + residual, w, eps) +``` + +If MLX already fuses this into a single dispatch, an `add_rms_norm` primitive adds complexity without benefit. If it does not, a fused add+norm can halve memory bandwidth at every layer. + +### MoE‑specific fusions + +- **`gather_qmm_swiglu`**: keep in `mlx_local/` until an MLX Discussion confirms interest. Share LFM2 and Qwen3 repro capsules to justify it. +- **`gather_qmm_combine`**: stays in ZMLX (achievable via `metal_kernel` and too specialized for core MLX). + +## Release policy (ZMLX) + +- **Stable (default)**: stock MLX, only token‑identical patches enabled. +- **Fast (opt‑in)**: dev MLX allowed; experimental kernels are opt‑in and must be validated. +- **Edge (opt‑in)**: nightly/dev MLX + experimental kernels for local testing only. + +ZMLX should auto‑detect dev MLX features (e.g. `mx.gather_qmm_swiglu`) and only enable them when present. + +## Roadmap (sequenced) + +1. **Dev‑MLX delivery**: publish a reproducible dev‑MLX build path (wheel or source) and a runtime capability check. +2. 
**Qwen3‑30B‑A3B (base + instruct)**: run token‑parity + perf on stock MLX vs dev MLX and record results in capsules. +3. **Capsules + docs**: update README/CLAUDE with release policy and dev‑MLX behavior. +4. **Upstream `mx.fast.swiglu`**: submit the small, general PR once tests are ready. +5. **`add_rms_norm`**: submit only if the benchmark shows no existing fusion. +6. **MoE RFC**: open an MLX Discussion for `gather_qmm_swiglu` before any PR. diff --git a/benchmarks/LAB_NOTEBOOK.md b/benchmarks/LAB_NOTEBOOK.md index 50cfc53..13868d1 100644 --- a/benchmarks/LAB_NOTEBOOK.md +++ b/benchmarks/LAB_NOTEBOOK.md @@ -769,3 +769,312 @@ Decision update: - GLM long-context confirmation (`runs=5`, `max_tokens=1024`) remained fidelity-safe and decode-positive (`+0.93 pp` vs control), but with prefill regression in that block (`+2.35 pp` vs control). - GLM: keep `glm_combine_fp32_no_fma` as active default, with status `promote_candidate` (not fully promoted) pending tighter prefill variance. - Qwen: no promotion; keep baseline control behavior as the benchmark truth anchor. + +## EXP-20260211T164421Z-GLM-T2-ROW-TILE-TRANSFER + +Question: +- Does foundry-inspired `t2_row_tile` transfer into production GLM combine and beat current `glm_combine_fp32_no_fma` with fidelity preserved? + +Code changes: +- `src/zmlx/kernels/moe.py` + - Added `moe_combine_fp32_no_fma_row_tile()` plus row-tile kernel builder. +- `src/zmlx/patch/patterns/moe_mlp.py` + - Added GLM combine mode `fp32_no_fma_row_tile`. +- `benchmarks/bench_glm47_flash_experiments.py` + - Added benchmark variant `glm_combine_fp32_no_fma_row_tile`. 
+ +Method: +- `python benchmarks/bench_iso_variant_sweep.py --suite glm47 --variants control_swiglu_moe glm_combine_fp32_no_fma glm_combine_fp32_no_fma_row_tile --runs 3 --max-tokens 200 --prefix glm47_t2_row_tile_transfer_t200_r3_20260211` +- `python benchmarks/bench_iso_variant_sweep.py --suite glm47 --variants control_swiglu_moe glm_combine_fp32_no_fma glm_combine_fp32_no_fma_row_tile --runs 3 --max-tokens 1024 --prefix glm47_t2_row_tile_transfer_t1024_r3_20260211` + +Artifacts: +- `benchmarks/repro_capsules/glm47_t2_row_tile_transfer_t200_r3_20260211_summary.json` +- `benchmarks/repro_capsules/glm47_t2_row_tile_transfer_t200_r3_20260211_control_swiglu_moe.json` +- `benchmarks/repro_capsules/glm47_t2_row_tile_transfer_t200_r3_20260211_glm_combine_fp32_no_fma.json` +- `benchmarks/repro_capsules/glm47_t2_row_tile_transfer_t200_r3_20260211_glm_combine_fp32_no_fma_row_tile.json` +- `benchmarks/repro_capsules/glm47_t2_row_tile_transfer_t1024_r3_20260211_summary.json` +- `benchmarks/repro_capsules/glm47_t2_row_tile_transfer_t1024_r3_20260211_control_swiglu_moe.json` +- `benchmarks/repro_capsules/glm47_t2_row_tile_transfer_t1024_r3_20260211_glm_combine_fp32_no_fma.json` +- `benchmarks/repro_capsules/glm47_t2_row_tile_transfer_t1024_r3_20260211_glm_combine_fp32_no_fma_row_tile.json` + +Results: +- `t=200`: + - `control_swiglu_moe`: fidelity `PASS` (`200/200`), decode speedup `0.9594x`, prefill change `-0.46%` + - `glm_combine_fp32_no_fma`: fidelity `PASS` (`200/200`), decode speedup `0.9582x`, prefill change `-1.44%` + - `glm_combine_fp32_no_fma_row_tile`: fidelity `PASS` (`200/200`), decode speedup `0.9283x`, prefill change `-7.88%` +- `t=1024`: + - `control_swiglu_moe`: fidelity `PASS` (`1024/1024`), decode speedup `0.9308x`, prefill change `+1.12%` + - `glm_combine_fp32_no_fma`: fidelity `PASS` (`1024/1024`), decode speedup `0.9688x`, prefill change `-0.79%` + - `glm_combine_fp32_no_fma_row_tile`: fidelity `PASS` (`1024/1024`), decode speedup `0.9537x`, prefill 
change `-2.91%` + +Interpretation: +- Fidelity gate passes for row-tile at both lengths. +- Transfer does not beat current best: + - decode is materially worse at `t=200` and still worse at `t=1024` vs `glm_combine_fp32_no_fma`. + +Decision: +- Do not promote `fp32_no_fma_row_tile` into GLM default path. +- Keep `glm_combine_fp32_no_fma` as current best production combine mode. + +## EXP-20260211T165600Z-GLM-PREFILL-CONSISTENCY-REPC-REPD + +Question: +- Do two additional long-context ABBA blocks (`t=1024`, `runs=5`) stabilize GLM prefill deltas within `+/-1 pp` and allow promotion from `promote_candidate` to `promoted`? + +Method: +- Block C (AB order): + - `python benchmarks/bench_iso_variant_sweep.py --suite glm47 --variants control_swiglu_moe glm_combine_fp32_no_fma --runs 5 --max-tokens 1024 --prefix glm47_prefill_consistency_abba_t1024_r5_repC_AB_20260211` +- Explicit cooldown between blocks: + - `sleep 45` +- Block D (BA order): + - `python benchmarks/bench_iso_variant_sweep.py --suite glm47 --variants glm_combine_fp32_no_fma control_swiglu_moe --runs 5 --max-tokens 1024 --prefix glm47_prefill_consistency_abba_t1024_r5_repD_BA_20260211` + +Artifacts: +- `benchmarks/repro_capsules/glm47_prefill_consistency_abba_t1024_r5_repC_AB_20260211_summary.json` +- `benchmarks/repro_capsules/glm47_prefill_consistency_abba_t1024_r5_repC_AB_20260211_control_swiglu_moe.json` +- `benchmarks/repro_capsules/glm47_prefill_consistency_abba_t1024_r5_repC_AB_20260211_glm_combine_fp32_no_fma.json` +- `benchmarks/repro_capsules/glm47_prefill_consistency_abba_t1024_r5_repD_BA_20260211_summary.json` +- `benchmarks/repro_capsules/glm47_prefill_consistency_abba_t1024_r5_repD_BA_20260211_control_swiglu_moe.json` +- `benchmarks/repro_capsules/glm47_prefill_consistency_abba_t1024_r5_repD_BA_20260211_glm_combine_fp32_no_fma.json` + +Results: +- Block C: + - `control_swiglu_moe`: fidelity `PASS`, decode speedup `0.9458x`, prefill change `-1.45%` + - `glm_combine_fp32_no_fma`: fidelity `PASS`, 
decode speedup `0.9532x`, prefill change `-0.22%` +- Block D: + - `glm_combine_fp32_no_fma`: fidelity `PASS`, decode speedup `0.9508x`, prefill change `-3.03%` + - `control_swiglu_moe`: fidelity `PASS`, decode speedup `1.0041x`, prefill change `+1.98%` +- Across all four `t=1024,r=5` AB/BA blocks (`repA`,`repB`,`repC`,`repD`), prefill delta (candidate vs control) ranges from `-1.26 pp` to `+0.79 pp` (spread `2.05 pp`), exceeding the `+/-1 pp` stabilization target. + +Interpretation: +- Fidelity stays stable (`PASS`) across all additional blocks. +- Prefill variance remains slightly outside the promotion envelope due to the `repD` outlier (`-1.26 pp`). +- Decode signal is mostly positive but not monotonic (`repD` decode delta vs control is negative). + +Promotion decision: +- Keep GLM `glm_combine_fp32_no_fma` status as `promote_candidate`. +- Do not flip to `promoted` yet. + +## EXP-20260211T171500Z-WEIGHTED-SUM-FP32-NOFMA-K8-SPECIALIZATION + +Question: +- Can a new production `moe_combine_weighted_sum_fp32_no_fma` (`k=8` specialization) beat current `moe_combine_fp32_no_fma` in isolation and then at model level? + +Code changes: +- `src/zmlx/kernels/moe.py` + - Added `moe_combine_weighted_sum_fp32_no_fma()` and dedicated `K=8` row-tile/no-FMA kernel. +- `src/zmlx/patch/patterns/moe_mlp.py` + - Added GLM combine mode `fp32_no_fma_weighted_sum`. +- `benchmarks/bench_glm47_flash_experiments.py` + - Added benchmark variant `glm_combine_fp32_no_fma_weighted_sum`. +- `benchmarks/bench_moe_combine_weighted_sum_fp32_no_fma.py` (new) + - Added reproducible microbench entrypoint for this specialization. 
+ +Isolation microbench method: +- `python benchmarks/bench_moe_combine_weighted_sum_fp32_no_fma.py --json-out benchmarks/repro_capsules/moe_combine_weighted_sum_fp32_no_fma_microbench_20260211.json --warmup 6 --repeats 50` + +Isolation microbench artifact: +- `benchmarks/repro_capsules/moe_combine_weighted_sum_fp32_no_fma_microbench_20260211.json` + +Isolation microbench result: +- median `specialized/base` ratio: `1.012044` (slightly slower) +- wins: `2/5` cases +- max abs diff vs baseline kernel: `0.0` + +Model-level method: +- `python benchmarks/bench_iso_variant_sweep.py --suite glm47 --variants control_swiglu_moe glm_combine_fp32_no_fma glm_combine_fp32_no_fma_weighted_sum --runs 3 --max-tokens 200 --prefix glm47_weighted_sum_fp32_no_fma_t200_r3_20260211` +- `python benchmarks/bench_iso_variant_sweep.py --suite glm47 --variants control_swiglu_moe glm_combine_fp32_no_fma glm_combine_fp32_no_fma_weighted_sum --runs 3 --max-tokens 1024 --prefix glm47_weighted_sum_fp32_no_fma_t1024_r3_20260211` + +Model-level artifacts: +- `benchmarks/repro_capsules/glm47_weighted_sum_fp32_no_fma_t200_r3_20260211_summary.json` +- `benchmarks/repro_capsules/glm47_weighted_sum_fp32_no_fma_t200_r3_20260211_control_swiglu_moe.json` +- `benchmarks/repro_capsules/glm47_weighted_sum_fp32_no_fma_t200_r3_20260211_glm_combine_fp32_no_fma.json` +- `benchmarks/repro_capsules/glm47_weighted_sum_fp32_no_fma_t200_r3_20260211_glm_combine_fp32_no_fma_weighted_sum.json` +- `benchmarks/repro_capsules/glm47_weighted_sum_fp32_no_fma_t1024_r3_20260211_summary.json` +- `benchmarks/repro_capsules/glm47_weighted_sum_fp32_no_fma_t1024_r3_20260211_control_swiglu_moe.json` +- `benchmarks/repro_capsules/glm47_weighted_sum_fp32_no_fma_t1024_r3_20260211_glm_combine_fp32_no_fma.json` +- `benchmarks/repro_capsules/glm47_weighted_sum_fp32_no_fma_t1024_r3_20260211_glm_combine_fp32_no_fma_weighted_sum.json` + +Model-level results: +- `t=200`: + - `glm_combine_fp32_no_fma`: fidelity `PASS`, decode speedup `0.9518x` 
+ - `glm_combine_fp32_no_fma_weighted_sum`: fidelity `PASS`, decode speedup `0.9389x` +- `t=1024`: + - `glm_combine_fp32_no_fma`: fidelity `PASS`, decode speedup `0.9563x` + - `glm_combine_fp32_no_fma_weighted_sum`: fidelity `PASS`, decode speedup `0.9641x` (better than control, but still below current mode's patched decode tok/s in this run set) + +Decision: +- Keep weighted-sum specialization as experimental only. +- No promotion over `glm_combine_fp32_no_fma` at this stage. + +## EXP-20260212T210500Z-JIT-FUSION-16GB-GUARDED-MODEL-SWEEP + +Question: +- Under a hard 16GB RSS guard, which currently cached models show real decode/prefill/memory improvement from fusion-related patch paths? + +Method: +- Activated project venv before all Python runs: `source .venv/bin/activate`. +- Used a guarded runner that sampled process-tree RSS (`ps -o rss` + child aggregation) and killed at `16*1024*1024` KB. +- Ran `python -m zmlx.matrix run ... --max-tokens 64` with an initial sweep (`runs=1`) and targeted confirmations (`runs=3`). +- Used `--patterns none` controls per model to correct baseline-first ordering bias in `zmlx.matrix run`. + +Artifacts: +- `runs/fusion_guard_20260212/matrix.jsonl` +- `runs/fusion_guard_20260212/guard_summary.tsv` +- `runs/fusion_guard_20260212_confirm/matrix.jsonl` +- `runs/fusion_guard_20260212_confirm/guard_summary.tsv` +- `runs/fusion_guard_20260212_extra/matrix.jsonl` +- `runs/fusion_guard_20260212_extra/guard_summary.tsv` + +Key results: +- `mlx-community/LFM2-8B-A1B-4bit`: + - Control (`--patterns none`, `runs=3`): patched decode median `226.66 tok/s`, prefill `746.10 tok/s`, mem `5.30 GB`, guard peak `5.30 GB`, fidelity `PASS`. + - Candidate (`--patterns moe_mlp`, `runs=3`): patched decode median `236.52 tok/s`, prefill `733.77 tok/s`, mem `5.30 GB`, guard peak `5.33 GB`, fidelity `PASS`. + - Candidate delta vs control (patched throughput): decode `+4.35%`, prefill `-1.65%`, memory `+0.00 GB`. 
+- `mlx-community/LFM2-8B-A1B-8bit-MLX`: + - Control (`--patterns none`, `runs=3`): patched decode median `156.8 tok/s`, mem `9.45 GB`. + - Candidate (`--patterns moe_mlp`, `runs=3`): patched decode median `150.7 tok/s`, mem `9.45 GB`. + - Candidate delta vs control: decode `-3.89%`, prefill `-3.46%`, memory `+0.00 GB` (reject). +- `mlx-community/Qwen3-8B-4bit`: + - Control (`--patterns none`, `runs=3`): patched decode median `76.93 tok/s`, mem `4.71 GB`. + - Candidate (`--patterns swiglu_mlp`, `runs=3`): patched decode median `75.82 tok/s`, mem `4.72 GB`. + - Candidate delta vs control: decode `-1.44%`, prefill not improved, memory `+0.01 GB` (reject). +- `mlx-community/GLM-4.7-Flash-4bit`: + - Initial sweep (`runs=1`, default patch) showed patched decode `90.7 tok/s` and model peak memory `16.91 GB`. + - Guarded confirm control (`--patterns none`, `runs=3`) was killed at `16,978,416 KB` RSS: + - `[FAIL] glm47_control_none_r3: exceeded 16GB RSS (16978416 KB). Killing.` + - Under hard 16GB cap, GLM path is not admissible in this run set. + +Decision: +- Keep `moe_mlp` on `LFM2-8B-A1B-4bit` as the only currently validated decode-positive fusion path under the 16GB guard. +- Do not enable tested fusion candidates for `LFM2-8B-A1B-8bit-MLX` or `Qwen3-8B-4bit`. +- Treat GLM-4.7-Flash as out-of-budget for strict 16GB-cap validation unless a lower-memory configuration is introduced. + +## EXP-20260213T223500Z-CAMPAIGN-PROTOCOL-CORRECTION + +Question: +- Which benchmark instructions became stale or operationally unsafe, and what is + the corrected execution protocol? + +Observed issues: +- The previous `benchmarks/run_3hr_benchmark_campaign.sh` used multi-variant + single-process sweeps for GLM/Qwen3-sized runs. This conflicts with the + earlier pivot to one-variant-per-process isolation after OOM incidents. 
+- The previous campaign flow treated `qwen_combine_exact` as a promoted + benchmark target, while the latest follow-up snapshot kept Qwen control as + the benchmark anchor and marked tested Qwen custom variants as non-promoted. +- The previous matrix phase did not enforce `--patterns none` control anchors + that were adopted to reduce ordering bias in matrix runs. + +Code + docs updates: +- Reworked `benchmarks/run_3hr_benchmark_campaign.sh` into a phase-based, + isolation-first runner: + - explicit phase selector (`quick`, `glm_abba_200`, `glm_abba_1024`, + `qwen_isolation_200`, `qwen_isolation_1024`, `glm_stress`, + `matrix_controls`, `checks`, `all`) + - subprocess-isolated variant sweeps via + `benchmarks/bench_iso_variant_sweep.py` + - AB/BA GLM blocks with cooldown windows + - matrix control anchors with `--patterns none` +- Updated benchmark workflow documentation: + - `docs/BENCHMARKS.md` + - `README.md` + - `CLAUDE.md` + - `benchmarks/NEXT_AI_PIVOT_PROMPT.md` + +Decision: +- Keep isolation-first execution as the default protocol for GLM/Qwen3 campaign + work. +- Keep Qwen control (`control_patterns_moe_mlp`) as benchmark truth anchor + until a custom variant is reproduced decode-positive under isolation runs. + +## EXP-20260214T013000Z-ISOLATION-FIRST-CAMPAIGN-CONTINUATION + +Question: +- Continue the corrected isolation-first GLM/Qwen campaign phase-by-phase, collect capsule-backed decode/prefill/fidelity/memory evidence, and decide variant status. 
+ +Method: +- Exact phase commands executed: + - `source .venv/bin/activate && bash benchmarks/run_3hr_benchmark_campaign.sh quick` + - `source .venv/bin/activate && bash benchmarks/run_3hr_benchmark_campaign.sh glm_abba_200` + - `source .venv/bin/activate && bash benchmarks/run_3hr_benchmark_campaign.sh glm_abba_1024` + - `source .venv/bin/activate && bash benchmarks/run_3hr_benchmark_campaign.sh qwen_isolation_200` + - `source .venv/bin/activate && bash benchmarks/run_3hr_benchmark_campaign.sh qwen_isolation_1024` + - `source .venv/bin/activate && bash benchmarks/run_3hr_benchmark_campaign.sh glm_stress` + - `source .venv/bin/activate && bash benchmarks/run_3hr_benchmark_campaign.sh matrix_controls` +- Pause handling: + - `glm_stress` was manually interrupted with `Ctrl-C` and rerun with the same command to completion. +- Aggregation approach for GLM/Qwen comparison tables: + - GLM: averaged AB/BA summaries per token length. + - Qwen: averaged repA/repB summaries per token length. + - Control anchor: + - GLM: `control_swiglu_moe` + - Qwen: `control_patterns_moe_mlp` + +Artifacts: +- Phase summaries: + - `benchmarks/results/campaign_20260213_quick_summary.txt` + - `benchmarks/results/campaign_20260213_glm_abba_200_summary.txt` + - `benchmarks/results/campaign_20260213_glm_abba_1024_summary.txt` + - `benchmarks/results/campaign_20260213_qwen_isolation_200_summary.txt` + - `benchmarks/results/campaign_20260213_qwen_isolation_1024_summary.txt` + - `benchmarks/results/campaign_20260213_glm_stress_summary.txt` + - `benchmarks/results/campaign_20260213_matrix_controls_summary.txt` +- GLM summaries: + - `benchmarks/repro_capsules/glm47_consistency_ab_t200_r5_20260213_summary.json` + - `benchmarks/repro_capsules/glm47_consistency_ba_t200_r5_20260213_summary.json` + - `benchmarks/repro_capsules/glm47_consistency_ab_t1024_r5_20260213_summary.json` + - `benchmarks/repro_capsules/glm47_consistency_ba_t1024_r5_20260213_summary.json` +- Qwen summaries: + - 
`benchmarks/repro_capsules/qwen3_isolation_ordered_t200_r5_repA_20260213_summary.json` + - `benchmarks/repro_capsules/qwen3_isolation_ordered_t200_r5_repB_20260213_summary.json` + - `benchmarks/repro_capsules/qwen3_isolation_ordered_t1024_r5_repA_20260213_summary.json` + - `benchmarks/repro_capsules/qwen3_isolation_ordered_t1024_r5_repB_20260213_summary.json` +- Stress artifacts: + - `benchmarks/repro_capsules/glm47_stress_full_20260213.json` + - `benchmarks/results/glm_stress/glm_stress_full_20260213_185549.log` +- Matrix control anchors (`--patterns none`) recorded in: + - `benchmarks/matrix.jsonl` + +Results: +- GLM-4.7-Flash-4bit-mxfp4 (control vs candidate medians, AB/BA averaged): + +| Tokens | Candidate | Control decode (tok/s) | Candidate decode (tok/s) | Decode delta vs control | Control prefill (tok/s) | Candidate prefill (tok/s) | Prefill delta vs control | Fidelity | Peak mem delta (GB) | Label | Reason | +|---|---|---:|---:|---:|---:|---:|---:|---|---:|---|---| +| 200 | `glm_combine_fp32_no_fma` | 73.95 | 76.65 | +3.65% | 268.35 | 278.50 | +3.78% | PASS | +0.00 | hold | Beats control anchor, but both control and candidate remain decode-negative vs own baselines in this environment snapshot. | +| 1024 | `glm_combine_fp32_no_fma` | 70.45 | 70.50 | +0.07% | 167.45 | 169.25 | +1.07% | PASS | +0.00 | hold | Near-tie vs control; still decode-negative vs own baseline and not strong enough for promotion. | + +- Qwen3-30B-A3B-4bit (control vs candidate medians, repA/repB averaged): + +| Tokens | Candidate | Control decode (tok/s) | Candidate decode (tok/s) | Decode delta vs control | Control prefill (tok/s) | Candidate prefill (tok/s) | Prefill delta vs control | Fidelity | Peak mem delta (GB) | Label | Reason | +|---|---|---:|---:|---:|---:|---:|---:|---|---:|---|---| +| 200 | `qwen_combine_exact` | 107.10 | 106.10 | -0.93% | 331.70 | 329.15 | -0.77% | PASS | +0.00 | reject | Decode regresses vs control at short context; not promotable. 
| +| 200 | `qwen_router_argpartition_logits` | 107.10 | 105.50 | -1.49% | 331.70 | 329.65 | -0.62% | PASS | +0.00 | reject | Largest short-context decode regression vs control. | +| 200 | `qwen_router_argpartition_logits_topk_combine_exact` | 107.10 | 106.65 | -0.42% | 331.70 | 329.00 | -0.81% | PASS | +0.00 | reject | Still decode-negative at short context; inconsistent. | +| 1024 | `qwen_combine_exact` | 101.75 | 101.65 | -0.10% | 324.80 | 321.20 | -1.11% | PASS | -0.01 | reject | No decode win vs control and prefill regression. | +| 1024 | `qwen_router_argpartition_logits` | 101.75 | 102.80 | +1.03% | 324.80 | 324.40 | -0.12% | PASS | +0.00 | reject | Long-context decode improvement is not enough to offset short-context regression. | +| 1024 | `qwen_router_argpartition_logits_topk_combine_exact` | 101.75 | 102.45 | +0.69% | 324.80 | 322.95 | -0.57% | PASS | +0.00 | reject | Same pattern as router-only: mixed by length, not robust enough. | + +- GLM stress (`glm47_stress_full_20260213.json`): + - Fidelity: all PASS. + - Average decode speedup by length: `256=1.006x`, `1024=1.002x`, `2048=1.008x`. + - Average decode speedup by prompt family: `english_technical=0.990x`, `chinese=0.993x`, `code=0.993x`, `math_reasoning=1.033x`, `creative=1.018x`. + - Peak memory remained stable (`baseline max=16.07 GB`, `patched max=16.07 GB`). + +- Matrix control anchors (`--patterns none`, `patterns_applied=[]`, PASS): + - `mlx-community/GLM-4.7-Flash-4bit-mxfp4`: decode `84.12 tok/s`, prefill `288.5 tok/s`, mem `15.98 GB`. + - `mlx-community/Qwen3-30B-A3B-4bit`: decode `113.88 tok/s`, prefill `336.57 tok/s`, mem `17.24 GB`. + - `mlx-community/LFM2-8B-A1B-4bit`: decode `223.67 tok/s`, prefill `737.92 tok/s`, mem `5.30 GB`. + +Variant decisions: +- `control_swiglu_moe`: hold (control anchor; no promotion semantics). +- `glm_combine_fp32_no_fma`: hold. +- `control_patterns_moe_mlp`: hold (control anchor; Qwen benchmark truth anchor). +- `qwen_combine_exact`: reject. 
+- `qwen_router_argpartition_logits`: reject. +- `qwen_router_argpartition_logits_topk_combine_exact`: reject. + +Decision: +- Keep isolation-first, subprocess-separated variant sweeps as mandatory for GLM/Qwen. +- Keep Qwen benchmark anchor at `control_patterns_moe_mlp`; do not promote `qwen_combine_exact`. +- Keep GLM `glm_combine_fp32_no_fma` at `hold` pending a cleaner environment where control/candidate are decode-positive vs own baselines. diff --git a/benchmarks/NEXT_AI_PIVOT_PROMPT.md b/benchmarks/NEXT_AI_PIVOT_PROMPT.md index c48c659..0b38d77 100644 --- a/benchmarks/NEXT_AI_PIVOT_PROMPT.md +++ b/benchmarks/NEXT_AI_PIVOT_PROMPT.md @@ -1,83 +1,86 @@ -# Next-AI Pivot Prompt (Custom MLX First) +# Next-AI Prompt (Isolation-First + Custom-MLX Pivot) Use this prompt verbatim with the next coding agent. --- -You are working in the ZMLX repo. We need to pivot from patch-heavy ZMLX experiments to custom MLX kernel authoring (C++/Metal) first, then thin ZMLX wiring. +You are working in the ZMLX repo. Continue the custom-MLX-first pivot, but use +strict isolation-first benchmarking for GLM/Qwen3 to avoid stale conclusions and +OOM-prone methodology. ## Mission -Deliver the highest-confidence decode throughput gains for: -- Qwen3-30B-A3B-4bit -- GLM-4.7-Flash-4bit +Produce high-confidence, reproducible benchmark evidence for: +- `mlx-community/GLM-4.7-Flash-4bit-mxfp4` +- `mlx-community/Qwen3-30B-A3B-4bit` without fidelity regressions. ## Hard constraints - Do not fabricate benchmark numbers. -- Every claim must be backed by repro capsules under `benchmarks/repro_capsules/`. -- Use in-repo notebook/matrix tracking: +- Every claim must be backed by repro capsules in `benchmarks/repro_capsules/`. +- Use in-repo tracking: - notebook: `benchmarks/LAB_NOTEBOOK.md` - matrix ledger: `benchmarks/matrix.jsonl` -- Do not edit external clones under `mlx_local/` or `exo/` directly unless required by the integration workflow. 
+- Do not edit external clones under `mlx_local/` or `exo/`. - Run with venv active: - `source .venv/bin/activate` - Before finalizing any change set, run: - `ruff check .` - `pytest -q` -## Existing findings to preserve -- Qwen best validated combo currently includes combine-exact path (see latest qwen v6 confirm capsules). -- New GLM override work shows: - - `ZMLX_GLM_COMBINE_MODE=fp32_no_fma` is fidelity-safe and beneficial. - - `exact` and `fp32` combine modes fail fidelity for GLM. - -## Where custom MLX lives -- Patch source: `integrations/mlx_local_integration/gather_qmm_swiglu.patch` -- Setup/build workflow: `integrations/mlx_local_integration/setup_mlx_local.sh` -- Workflow docs: `docs/EXPERIMENTAL_MLX.md` - -## Priority kernel targets (C++/Metal-first) +## Current benchmark truth to preserve +- GLM: + - active default path: `glm_combine_fp32_no_fma` + - status: `promote_candidate` (not fully promoted) +- Qwen: + - benchmark anchor: `control_patterns_moe_mlp` + - no custom-kernel variant is currently promoted +- Rejected variants to keep rejected unless new evidence reverses them: + - `qwen_combine_exact` + - `qwen_router_argpartition_logits_topk_combine_exact` + - `glm_combine_fp32_no_fma_shared_overlap_streams2` + - `glm_combine_fp32_no_fma_glm47_rope` + +## Benchmark methodology (mandatory) +- Use `benchmarks/bench_iso_variant_sweep.py` for GLM/Qwen3 variant sweeps. +- One variant per subprocess only. +- Use AB/BA replicate blocks with cooldown windows for GLM consistency checks. +- For matrix controls, run `python -m zmlx.matrix run ... --patterns none`. +- Do not reintroduce multi-variant single-process sweeps for GLM/Qwen3. 
+ +## Campaign command pattern +Run one phase at a time (recommended): + +```bash +source .venv/bin/activate +bash benchmarks/run_3hr_benchmark_campaign.sh quick +bash benchmarks/run_3hr_benchmark_campaign.sh glm_abba_200 +bash benchmarks/run_3hr_benchmark_campaign.sh glm_abba_1024 +bash benchmarks/run_3hr_benchmark_campaign.sh qwen_isolation_200 +bash benchmarks/run_3hr_benchmark_campaign.sh qwen_isolation_1024 +bash benchmarks/run_3hr_benchmark_campaign.sh glm_stress +bash benchmarks/run_3hr_benchmark_campaign.sh matrix_controls +``` + +## Custom-MLX kernel priorities (C++/Metal-first) 1. `moe_router_argpartition_logits_topk` -- Goal: fuse top-k selection on logits + normalization over selected logits for Qwen-style routing. -- Why: router overhead is still a bottleneck at decode batch sizes 1-4. -- Fidelity guardrail: preserve argpartition/top-k semantics exactly. +- Goal: fuse top-k selection + normalization for Qwen-style routing. +- Guardrail: exact semantic parity for selected indices and normalized weights. 2. `moe_gather_qmm_swiglu_downproj_combine` -- Goal: one primitive for gather -> qmm swiglu -> downproj -> weighted combine. -- Why: remove intermediate allocations and launch overhead. -- Fidelity guardrail: deterministic accumulation mode (`fp32_no_fma`) option. +- Goal: fuse gather -> qmm swiglu -> downproj -> weighted combine. +- Guardrail: deterministic accumulation mode option (`fp32_no_fma`). 3. `moe_combine_weighted_sum_fp32_no_fma` specialization -- Goal: optimize small-k combine for decode (k=8 common case). -- Why: this appears portable across Qwen and GLM. -- Fidelity guardrail: strict token-level comparison vs baseline. - -## Execution plan -1. Implement one kernel at a time in custom MLX (C++/Metal), wire capability detection. -2. Keep ZMLX-side change minimal and opt-in with env flags for A/B safety. -3. Benchmark each variant isolated, then in controlled combos. -4. 
For each run, write: -- repro capsule JSON in `benchmarks/repro_capsules/` -- matrix entry in `benchmarks/matrix.jsonl` -- short experiment note in `benchmarks/LAB_NOTEBOOK.md` -5. Promote only variants that satisfy both: -- token fidelity pass at configured token budget -- repeatable decode speedup across short and long context settings - -## Benchmark minimums -For every candidate kernel and combo: -- short decode check: `tokens=200`, `runs>=3` -- long decode check: `tokens=1024`, `runs>=2` -- capture and compare `peak_mem_gb` -- include baseline and patched runs on same hardware session +- Goal: optimize decode-typical small-k combine. +- Guardrail: strict token-level fidelity checks against control. ## Deliverables -- Ranked table of tested variants by decode speedup and fidelity status. -- Explicit "best combo" recommendation for Qwen3 and GLM (may differ). -- Clear reject list with failure reason (fidelity, perf regression, or memory regression). -- Updated notebook + matrix + repro capsules. +- Ranked variant table with decode delta, prefill delta, fidelity verdict, memory delta. +- Explicit recommendation for GLM and Qwen (promote/hold/reject). +- Reject list with concrete failure reasons and capsule paths. +- Updated `benchmarks/LAB_NOTEBOOK.md` entry with exact commands and artifacts. -Proceed immediately; do not stop at planning. +Proceed immediately and execute, not just plan. 
--- diff --git a/benchmarks/bench_glm47_flash_experiments.py b/benchmarks/bench_glm47_flash_experiments.py index e31d5c6..ed4ac38 100644 --- a/benchmarks/bench_glm47_flash_experiments.py +++ b/benchmarks/bench_glm47_flash_experiments.py @@ -24,7 +24,7 @@ import mlx.core as mx -DEFAULT_MODEL = "mlx-community/GLM-4.7-Flash-4bit" +DEFAULT_MODEL = "mlx-community/GLM-4.7-Flash-4bit-mxfp4" DEFAULT_PROMPT = ( "Explain the key differences between TCP and UDP protocols, " "including their use cases, reliability guarantees, and " @@ -156,7 +156,11 @@ def main() -> None: "Variant names to run (default: all). " "Choices: control_swiglu_moe, control_swiglu_moe_residual_norm, " "control_swiglu_moe_downproj_combine, glm_combine_exact, " - "glm_combine_fp32, glm_combine_fp32_no_fma, glm47_rope, " + "glm_combine_fp32, glm_combine_fp32_no_fma, " + "glm_combine_fp32_no_fma_row_tile, " + "glm_combine_fp32_no_fma_weighted_sum, " + "glm_combine_fp32_no_fma_shared_overlap_streams2, " + "glm_combine_fp32_no_fma_glm47_rope, glm47_rope, " "shared_experts_overlap_streams2" ), ) @@ -225,13 +229,13 @@ def main() -> None: { "name": "control_swiglu_moe", "patterns": ["swiglu_mlp", "moe_mlp"], - "env": {}, - "notes": "Control: current best patch set (no RoPE fusion).", + "env": {"ZMLX_GLM_COMBINE_MODE": "off"}, + "notes": "Control: current best patch set (no RoPE fusion, combine off).", }, { "name": "control_swiglu_moe_residual_norm", "patterns": ["swiglu_mlp", "moe_mlp", "residual_norm"], - "env": {}, + "env": {"ZMLX_GLM_COMBINE_MODE": "off"}, "notes": ( "Experimental: residual_norm fusion (breaks greedy token fidelity on " "GLM-4.7-Flash in current testing)." 
@@ -240,7 +244,10 @@ def main() -> None: { "name": "control_swiglu_moe_downproj_combine", "patterns": ["swiglu_mlp", "moe_mlp"], - "env": {"ZMLX_GLM_FUSED_DOWNPROJ_COMBINE": "1"}, + "env": { + "ZMLX_GLM_FUSED_DOWNPROJ_COMBINE": "1", + "ZMLX_GLM_COMBINE_MODE": "off", + }, "unsafe": True, "notes": ( "Experimental: fuse down-proj + combine for decode-like shapes " @@ -274,6 +281,25 @@ def main() -> None: "(ZMLX_GLM_COMBINE_MODE=fp32_no_fma)." ), }, + { + "name": "glm_combine_fp32_no_fma_row_tile", + "patterns": ["swiglu_mlp", "moe_mlp"], + "env": {"ZMLX_GLM_COMBINE_MODE": "fp32_no_fma_row_tile"}, + "notes": ( + "Experimental: force GLM combine via t2-row-tile-inspired fp32 " + "no-FMA kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma_row_tile)." + ), + }, + { + "name": "glm_combine_fp32_no_fma_weighted_sum", + "patterns": ["swiglu_mlp", "moe_mlp"], + "env": {"ZMLX_GLM_COMBINE_MODE": "fp32_no_fma_weighted_sum"}, + "notes": ( + "Experimental: force GLM combine via K=8-specialized fp32 no-FMA " + "weighted-sum kernel " + "(ZMLX_GLM_COMBINE_MODE=fp32_no_fma_weighted_sum)." + ), + }, { "name": "glm47_rope", "patterns": ["swiglu_mlp", "moe_mlp", "glm47_rope"], @@ -284,11 +310,34 @@ def main() -> None: "name": "shared_experts_overlap_streams2", "patterns": ["swiglu_mlp", "moe_mlp"], "env": { + "ZMLX_GLM_COMBINE_MODE": "off", "ZMLX_MOE_STREAMS": "2", "ZMLX_MOE_SHARED_EXPERTS_OVERLAP": "1", }, "notes": "Experimental: overlap shared_experts(x) on a separate stream.", }, + { + "name": "glm_combine_fp32_no_fma_shared_overlap_streams2", + "patterns": ["swiglu_mlp", "moe_mlp"], + "env": { + "ZMLX_GLM_COMBINE_MODE": "fp32_no_fma", + "ZMLX_MOE_STREAMS": "2", + "ZMLX_MOE_SHARED_EXPERTS_OVERLAP": "1", + }, + "notes": ( + "Experimental additive combo: fp32-no-fma combine + " + "shared_experts overlap streams=2." 
+ ), + }, + { + "name": "glm_combine_fp32_no_fma_glm47_rope", + "patterns": ["swiglu_mlp", "moe_mlp", "glm47_rope"], + "env": {"ZMLX_GLM_COMBINE_MODE": "fp32_no_fma"}, + "notes": ( + "Experimental additive combo: fp32-no-fma combine + " + "glm47_rope decode path." + ), + }, ] if args.variants is not None: wanted = {v.strip() for v in args.variants if v.strip()} diff --git a/benchmarks/bench_glm_stress.py b/benchmarks/bench_glm_stress.py index b70afee..b3b5d3e 100644 --- a/benchmarks/bench_glm_stress.py +++ b/benchmarks/bench_glm_stress.py @@ -38,7 +38,7 @@ import mlx.core as mx -DEFAULT_MODEL = "mlx-community/GLM-4.7-Flash-4bit" +DEFAULT_MODEL = "mlx-community/GLM-4.7-Flash-4bit-mxfp4" DEFAULT_PROMPTS = "english_technical,chinese,code,math_reasoning,creative" DEFAULT_LENGTHS = "256,1024,2048" diff --git a/benchmarks/bench_iso_variant_sweep.py b/benchmarks/bench_iso_variant_sweep.py index ef2170b..d70fa62 100644 --- a/benchmarks/bench_iso_variant_sweep.py +++ b/benchmarks/bench_iso_variant_sweep.py @@ -82,6 +82,31 @@ def _load_capsule(path: Path) -> dict[str, Any]: } +def _fmt_pct(value: float | None) -> str: + if value is None: + return "n/a" + return f"{value * 100:+.2f}%" + + +def _fmt_ratio(value: float | None) -> str: + if value is None: + return "n/a" + return f"{value:.4f}x" + + +def _fmt_num(value: float | None, *, digits: int = 1) -> str: + if value is None: + return "n/a" + return f"{value:.{digits}f}" + + +def _fidelity_str(fidelity: dict[str, Any]) -> str: + verdict = str(fidelity.get("verdict", "UNKNOWN")).upper() + matched = fidelity.get("matched", fidelity.get("matched_tokens", 0)) + total = fidelity.get("total", fidelity.get("total_tokens", 0)) + return f"{verdict} ({matched}/{total})" + + def _run_variant( *, repo_root: Path, @@ -241,6 +266,33 @@ def main() -> None: if rc == 0 and capsule_path.exists(): try: result["metrics"] = _load_capsule(capsule_path) + metrics = result["metrics"] + fidelity = _fidelity_str(metrics.get("fidelity", {})) + decode 
= metrics.get("decode", {}) + prefill = metrics.get("prefill", {}) + memory = metrics.get("memory_gb", {}) + decode_baseline = decode.get("baseline") + decode_patched = decode.get("patched") + prefill_baseline = prefill.get("baseline") + prefill_patched = prefill.get("patched") + mem_delta = None + if ( + memory.get("baseline") is not None + and memory.get("patched") is not None + ): + mem_delta = float(memory["patched"]) - float(memory["baseline"]) + print( + " summary: " + f"fidelity={fidelity}, " + "decode=" + f"{_fmt_num(decode_baseline)}->{_fmt_num(decode_patched)} tok/s " + f"({_fmt_ratio(decode.get('speedup'))}), " + "prefill=" + f"{_fmt_num(prefill_baseline)}->{_fmt_num(prefill_patched)} tok/s " + f"({_fmt_pct(prefill.get('change'))}), " + "mem_delta_gb=" + f"{_fmt_num(mem_delta, digits=2) if mem_delta is not None else 'n/a'}" + ) except Exception as exc: # pragma: no cover - defensive fallback result["metrics_error"] = str(exc) else: diff --git a/benchmarks/bench_lfm2_24b_sweep.py b/benchmarks/bench_lfm2_24b_sweep.py new file mode 100644 index 0000000..1640296 --- /dev/null +++ b/benchmarks/bench_lfm2_24b_sweep.py @@ -0,0 +1,225 @@ +#!/usr/bin/env python3 +"""LFM2-24B-A2B variant sweep: gate mode × combine mode matrix. + +Runs each variant in an isolated subprocess to avoid cross-contamination. +Produces repro capsule JSON for each variant. 
+ +Usage: + # Full matrix (9 variants, ~20 min) + python benchmarks/bench_lfm2_24b_sweep.py + + # Quick fidelity gate only (200 tokens, 1 run) + python benchmarks/bench_lfm2_24b_sweep.py --quick + + # Single variant + python benchmarks/bench_lfm2_24b_sweep.py --gate native --combine native + + # Baseline only (no patching) + python benchmarks/bench_lfm2_24b_sweep.py --baseline-only +""" +from __future__ import annotations + +import argparse +import json +import os +import subprocess +import sys +import time +from datetime import datetime +from pathlib import Path + +MODEL = "LiquidAI/LFM2-24B-A2B-MLX-4bit" +CAPSULE_DIR = Path("benchmarks/repro_capsules") + +GATE_MODES = ["kernel", "native", "fast64"] +COMBINE_MODES = ["kernel", "native", "fp32_no_fma"] + + +def run_variant( + *, + gate: str | None, + combine: str | None, + max_tokens: int, + runs: int, + label: str, + patterns: list[str] | None = None, +) -> dict: + """Run a single variant via zmlx.validate in a subprocess.""" + env = os.environ.copy() + if gate is not None: + env["ZMLX_LFM2_GATE_MODE"] = gate + if combine is not None: + env["ZMLX_LFM2_COMBINE_MODE"] = combine + # Force moe_mlp pattern for LFM2 (normally auto-excluded by perf risk) + env["ZMLX_PATCH_ALLOW_PERF_RISK"] = "1" + + cmd = [ + sys.executable, "-m", "zmlx.validate", MODEL, + "--max-tokens", str(max_tokens), + "--runs", str(runs), + "--verbose", + ] + if patterns is not None: + cmd.extend(["--patterns"] + patterns) + + print(f"\n{'='*70}") + print(f" Variant: {label}") + print(f" Gate: {gate or 'N/A'}, Combine: {combine or 'N/A'}") + print(f" Tokens: {max_tokens}, Runs: {runs}") + print(f" Patterns: {patterns or 'default'}") + print(f"{'='*70}") + + t0 = time.time() + result = subprocess.run( + cmd, + env=env, + capture_output=True, + text=True, + timeout=600, + ) + elapsed = time.time() - t0 + + # Parse output + out = result.stdout + result.stderr + print(out[-3000:] if len(out) > 3000 else out) + + return { + "label": label, + 
"gate_mode": gate, + "combine_mode": combine, + "patterns": patterns, + "max_tokens": max_tokens, + "runs": runs, + "returncode": result.returncode, + "elapsed_s": round(elapsed, 1), + "stdout": result.stdout, + "stderr": result.stderr, + } + + +def parse_validate_output(text: str) -> dict: + """Extract key metrics from zmlx.validate output.""" + metrics = {} + for line in text.split("\n"): + line = line.strip() + if "fidelity" in line.lower(): + metrics["fidelity_line"] = line + if "matched" in line.lower() and "/" in line: + # e.g. "Fidelity: 200/200 matched" + parts = line.split() + for p in parts: + if "/" in p: + try: + a, b = p.split("/") + metrics["matched"] = int(a) + metrics["total"] = int(b) + except ValueError: + pass + if "decode" in line.lower() and "tok/s" in line.lower(): + metrics["decode_line"] = line + if "speedup" in line.lower() or "ratio" in line.lower(): + metrics["speedup_line"] = line + if "median" in line.lower(): + metrics["median_line"] = line + return metrics + + +def main(): + parser = argparse.ArgumentParser(description="LFM2-24B variant sweep") + parser.add_argument("--gate", choices=GATE_MODES, help="Single gate mode") + parser.add_argument("--combine", choices=COMBINE_MODES, help="Single combine mode") + parser.add_argument("--max-tokens", type=int, default=200) + parser.add_argument("--runs", type=int, default=3) + parser.add_argument("--quick", action="store_true", help="Quick fidelity check (200 tok, 1 run)") + parser.add_argument("--baseline-only", action="store_true") + parser.add_argument("--skip-baseline", action="store_true") + args = parser.parse_args() + + if args.quick: + args.max_tokens = 200 + args.runs = 1 + + CAPSULE_DIR.mkdir(parents=True, exist_ok=True) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + results = [] + + # 1. 
Baseline (no patching) — always run first + if not args.skip_baseline: + r = run_variant( + gate=None, + combine=None, + max_tokens=args.max_tokens, + runs=args.runs, + label="baseline_unpatched", + patterns=["none"], # no patterns = baseline + ) + r["metrics"] = parse_validate_output(r["stdout"] + r["stderr"]) + results.append(r) + + if args.baseline_only: + _print_summary(results) + return + + # 2. Variant matrix + if args.gate and args.combine: + # Single variant + variants = [(args.gate, args.combine)] + elif args.gate: + variants = [(args.gate, c) for c in COMBINE_MODES] + elif args.combine: + variants = [(g, args.combine) for g in GATE_MODES] + else: + # Full matrix + variants = [(g, c) for g in GATE_MODES for c in COMBINE_MODES] + + for gate, combine in variants: + label = f"lfm2_24b_g{gate}_c{combine}" + r = run_variant( + gate=gate, + combine=combine, + max_tokens=args.max_tokens, + runs=args.runs, + label=label, + patterns=["moe_mlp"], + ) + r["metrics"] = parse_validate_output(r["stdout"] + r["stderr"]) + results.append(r) + + # Save combined results + capsule_path = CAPSULE_DIR / f"lfm2_24b_sweep_{timestamp}.json" + with open(capsule_path, "w") as f: + json.dump( + { + "model": MODEL, + "timestamp": timestamp, + "variants": [ + {k: v for k, v in r.items() if k not in ("stdout", "stderr")} + for r in results + ], + }, + f, + indent=2, + ) + print(f"\nCapsule saved: {capsule_path}") + + _print_summary(results) + + +def _print_summary(results: list[dict]) -> None: + print(f"\n{'='*70}") + print(" SUMMARY") + print(f"{'='*70}") + print(f"{'Label':<35} {'Fidelity':<12} {'Return':<8}") + print(f"{'-'*35} {'-'*12} {'-'*8}") + for r in results: + m = r.get("metrics", {}) + matched = m.get("matched", "?") + total = m.get("total", "?") + fid = f"{matched}/{total}" + rc = r["returncode"] + print(f"{r['label']:<35} {fid:<12} {rc:<8}") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/inference_benchmark.py b/benchmarks/inference_benchmark.py index 
e41268d..038efea 100644 --- a/benchmarks/inference_benchmark.py +++ b/benchmarks/inference_benchmark.py @@ -37,7 +37,7 @@ "qwen3-32b": "mlx-community/Qwen3-32B-4bit", "lfm2-8b": "mlx-community/LFM2-8B-A1B-8bit-MLX", "lfm2-8b-4bit": "mlx-community/LFM2-8B-A1B-4bit", - "glm-4.7-flash": "mlx-community/GLM-4.7-Flash-4bit", + "glm-4.7-flash": "mlx-community/GLM-4.7-Flash-4bit-mxfp4", "gemma3-27b": "mlx-community/gemma-3-27b-it-qat-4bit", "deepseek-r1-32b": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-4bit", } diff --git a/benchmarks/matrix.jsonl b/benchmarks/matrix.jsonl index 51a0c99..b506d32 100644 --- a/benchmarks/matrix.jsonl +++ b/benchmarks/matrix.jsonl @@ -60,3 +60,225 @@ {"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: patch(model, patterns=['moe_mlp']) capsule=benchmarks/repro_capsules/qwen3_combo_v9_routertopk_t1024_r3_confirm_control_patterns_moe_mlp.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.3","zmlx_commit":"094296f","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T18:50:25.757149+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":111.3,"decode_tps_patched":115.1,"decode_speedup":1.0341,"decode_runs_baseline":[110.3,111.3,113.1],"decode_runs_patched":[114.8,115.2,115.1],"prefill_tps_baseline":334.1,"prefill_tps_patched":337.1,"prefill_change":0.009,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.34,"max_tokens":1024,"gen_tokens":1024,"runs":3} {"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen combine via ZMLX exact-dtype kernel (ZMLX_QWEN_COMBINE_MODE=exact) 
capsule=benchmarks/repro_capsules/qwen3_combo_v9_routertopk_t1024_r3_confirm_qwen_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.3","zmlx_commit":"094296f","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T18:52:03.319579+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":112.4,"decode_tps_patched":119.3,"decode_speedup":1.0614,"decode_runs_baseline":[113.0,112.2,112.4],"decode_runs_patched":[119.1,119.3,119.6],"prefill_tps_baseline":321.5,"prefill_tps_patched":335.2,"prefill_change":0.0426,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.34,"max_tokens":1024,"gen_tokens":1024,"runs":3} {"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen argpartition(logits) fused top-k softmax plus exact combine kernel (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1 ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS_TOPK=1 ZMLX_QWEN_COMBINE_MODE=exact) capsule=benchmarks/repro_capsules/qwen3_combo_v9_routertopk_t1024_r3_confirm_qwen_router_argpartition_logits_topk_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.3","zmlx_commit":"094296f","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T18:53:41.428156+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":113.1,"decode_tps_patched":118.7,"decode_speedup":1.0495,"decode_runs_baseline":[113.3,113.1,109.3],"decode_runs_patched":[118.7,118.1,119.1],"prefill_tps_baseline":334.2,"prefill_tps_patched":336.0,"prefill_change":0.0054,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.34,"max_tokens":1024,"gen_tokens":1024,"runs":3} 
+{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: patch(model, patterns=['moe_mlp']) capsule=benchmarks/repro_capsules/qwen3_takeover_combo_t200_r3_control_patterns_moe_mlp.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T20:07:44.530545+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":117.3,"decode_tps_patched":114.0,"decode_speedup":0.9719,"decode_runs_baseline":[113.9,117.3,118.0],"decode_runs_patched":[117.1,113.4,114.0],"prefill_tps_baseline":337.0,"prefill_tps_patched":337.7,"prefill_change":0.0021,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen combine via ZMLX exact-dtype kernel (ZMLX_QWEN_COMBINE_MODE=exact) capsule=benchmarks/repro_capsules/qwen3_takeover_combo_t200_r3_qwen_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T20:08:44.179304+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":117.7,"decode_tps_patched":117.6,"decode_speedup":0.9992,"decode_runs_baseline":[117.6,117.7,117.7],"decode_runs_patched":[113.4,117.7,117.6],"prefill_tps_baseline":338.6,"prefill_tps_patched":335.5,"prefill_change":-0.0092,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":3} 
+{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen argpartition(logits) fused top-k softmax plus exact combine kernel (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1 ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS_TOPK=1 ZMLX_QWEN_COMBINE_MODE=exact) capsule=benchmarks/repro_capsules/qwen3_takeover_combo_t200_r3_qwen_router_argpartition_logits_topk_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T20:09:42.475870+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":116.9,"decode_tps_patched":117.5,"decode_speedup":1.0051,"decode_runs_baseline":[113.9,116.9,116.9],"decode_runs_patched":[117.5,117.8,116.5],"prefill_tps_baseline":339.0,"prefill_tps_patched":333.2,"prefill_change":-0.0171,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: patch(model, patterns=['moe_mlp']) capsule=benchmarks/repro_capsules/qwen3_takeover_combo_t1024_r2_control_patterns_moe_mlp.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T20:11:11.893718+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens 
identical","modules_patched":48,"decode_tps_baseline":112.5,"decode_tps_patched":111.2,"decode_speedup":0.9884,"decode_runs_baseline":[112.4,112.6],"decode_runs_patched":[109.6,112.7],"prefill_tps_baseline":327.8,"prefill_tps_patched":335.9,"prefill_change":0.0247,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":2} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen combine via ZMLX exact-dtype kernel (ZMLX_QWEN_COMBINE_MODE=exact) capsule=benchmarks/repro_capsules/qwen3_takeover_combo_t1024_r2_qwen_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T20:12:35.297776+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":113.1,"decode_tps_patched":113.9,"decode_speedup":1.0071,"decode_runs_baseline":[113.1,113.1],"decode_runs_patched":[113.9,113.9],"prefill_tps_baseline":328.4,"prefill_tps_patched":334.4,"prefill_change":0.0183,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":2} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen argpartition(logits) fused top-k softmax plus exact combine kernel (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1 ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS_TOPK=1 ZMLX_QWEN_COMBINE_MODE=exact) 
capsule=benchmarks/repro_capsules/qwen3_takeover_combo_t1024_r2_qwen_router_argpartition_logits_topk_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T20:13:58.090367+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":112.8,"decode_tps_patched":113.6,"decode_speedup":1.0071,"decode_runs_baseline":[112.5,113.0],"decode_runs_patched":[113.3,113.9],"prefill_tps_baseline":330.0,"prefill_tps_patched":332.6,"prefill_change":0.0079,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":2} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion). capsule=benchmarks/repro_capsules/glm47_takeover_combo_t200_r3_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T20:15:09.728714+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":85.3,"decode_tps_patched":82.4,"decode_speedup":0.966,"decode_runs_baseline":[85.1,85.5,85.3],"decode_runs_patched":[83.7,79.9,82.4],"prefill_tps_baseline":276.7,"prefill_tps_patched":273.0,"prefill_change":-0.0134,"peak_mem_baseline_gb":16.91,"peak_mem_patched_gb":16.91,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma). 
capsule=benchmarks/repro_capsules/glm47_takeover_combo_t200_r3_glm_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T20:16:11.797192+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":82.6,"decode_tps_patched":84.3,"decode_speedup":1.0206,"decode_runs_baseline":[82.6,81.3,82.9],"decode_runs_patched":[84.3,87.9,83.3],"prefill_tps_baseline":273.1,"prefill_tps_patched":274.3,"prefill_change":0.0044,"peak_mem_baseline_gb":16.91,"peak_mem_patched_gb":16.91,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: overlap shared_experts(x) on a separate stream. capsule=benchmarks/repro_capsules/glm47_takeover_combo_t200_r3_shared_experts_overlap_streams2.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T20:17:17.776144+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":85.4,"decode_tps_patched":58.3,"decode_speedup":0.6827,"decode_runs_baseline":[85.4,85.6,84.0],"decode_runs_patched":[57.5,58.3,58.7],"prefill_tps_baseline":277.3,"prefill_tps_patched":237.4,"prefill_change":-0.1439,"peak_mem_baseline_gb":16.91,"peak_mem_patched_gb":16.91,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit","model_family":"glm","architecture":"moe","patterns_applied":["glm47_rope","moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Decode-only fused RoPE+concat for glm4_moe_lite attention. 
capsule=benchmarks/repro_capsules/glm47_takeover_combo_t200_r3_glm47_rope.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T20:18:20.287813+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":140,"decode_tps_baseline":84.9,"decode_tps_patched":83.0,"decode_speedup":0.9776,"decode_runs_baseline":[84.9,84.8,85.3],"decode_runs_patched":[81.7,83.1,83.0],"prefill_tps_baseline":277.2,"prefill_tps_patched":272.3,"prefill_change":-0.0177,"peak_mem_baseline_gb":16.91,"peak_mem_patched_gb":16.91,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion). capsule=benchmarks/repro_capsules/glm47_takeover_combo_t1024_r2_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T20:20:03.008706+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":78.9,"decode_tps_patched":81.8,"decode_speedup":1.0368,"decode_runs_baseline":[78.7,79.1],"decode_runs_patched":[81.5,82.1],"prefill_tps_baseline":237.6,"prefill_tps_patched":243.5,"prefill_change":0.0248,"peak_mem_baseline_gb":16.94,"peak_mem_patched_gb":16.95,"max_tokens":1024,"gen_tokens":1024,"runs":2} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma). 
capsule=benchmarks/repro_capsules/glm47_takeover_combo_t1024_r2_glm_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T20:21:39.991240+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":78.6,"decode_tps_patched":81.5,"decode_speedup":1.0369,"decode_runs_baseline":[79.1,78.1],"decode_runs_patched":[80.3,82.8],"prefill_tps_baseline":237.4,"prefill_tps_patched":238.4,"prefill_change":0.0042,"peak_mem_baseline_gb":16.94,"peak_mem_patched_gb":16.95,"max_tokens":1024,"gen_tokens":1024,"runs":2} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: overlap shared_experts(x) on a separate stream. capsule=benchmarks/repro_capsules/glm47_takeover_combo_t1024_r2_shared_experts_overlap_streams2.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T20:23:31.312719+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":78.6,"decode_tps_patched":54.6,"decode_speedup":0.6947,"decode_runs_baseline":[79.3,78.0],"decode_runs_patched":[54.5,54.7],"prefill_tps_baseline":235.2,"prefill_tps_patched":203.7,"prefill_change":-0.1339,"peak_mem_baseline_gb":16.94,"peak_mem_patched_gb":16.95,"max_tokens":1024,"gen_tokens":1024,"runs":2} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit","model_family":"glm","architecture":"moe","patterns_applied":["glm47_rope","moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Decode-only fused RoPE+concat for glm4_moe_lite attention. 
capsule=benchmarks/repro_capsules/glm47_takeover_combo_t1024_r2_glm47_rope.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T20:25:11.879858+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":140,"decode_tps_baseline":78.4,"decode_tps_patched":77.1,"decode_speedup":0.9834,"decode_runs_baseline":[78.4,78.3],"decode_runs_patched":[76.9,77.4],"prefill_tps_baseline":238.4,"prefill_tps_patched":239.3,"prefill_change":0.0038,"peak_mem_baseline_gb":16.94,"peak_mem_patched_gb":16.95,"max_tokens":1024,"gen_tokens":1024,"runs":2} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion). capsule=benchmarks/repro_capsules/glm47_takeover_additive_combo_t200_r3_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T20:27:26.609189+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":84.8,"decode_tps_patched":86.3,"decode_speedup":1.0177,"decode_runs_baseline":[84.8,85.2,84.5],"decode_runs_patched":[85.1,86.5,86.3],"prefill_tps_baseline":276.3,"prefill_tps_patched":275.5,"prefill_change":-0.0029,"peak_mem_baseline_gb":16.91,"peak_mem_patched_gb":16.91,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental additive combo: fp32-no-fma combine + shared_experts overlap streams=2. 
capsule=benchmarks/repro_capsules/glm47_takeover_additive_combo_t200_r3_glm_combine_fp32_no_fma_shared_overlap_streams2.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T20:28:30.845632+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":84.5,"decode_tps_patched":58.8,"decode_speedup":0.6959,"decode_runs_baseline":[83.7,84.5,84.9],"decode_runs_patched":[58.8,58.7,58.8],"prefill_tps_baseline":275.8,"prefill_tps_patched":242.3,"prefill_change":-0.1215,"peak_mem_baseline_gb":16.91,"peak_mem_patched_gb":16.91,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit","model_family":"glm","architecture":"moe","patterns_applied":["glm47_rope","moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental additive combo: fp32-no-fma combine + glm47_rope decode path. capsule=benchmarks/repro_capsules/glm47_takeover_additive_combo_t200_r3_glm_combine_fp32_no_fma_glm47_rope.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T20:29:30.121857+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":140,"decode_tps_baseline":84.3,"decode_tps_patched":84.7,"decode_speedup":1.0047,"decode_runs_baseline":[84.3,84.7,83.9],"decode_runs_patched":[84.7,84.5,85.0],"prefill_tps_baseline":276.2,"prefill_tps_patched":275.2,"prefill_change":-0.0036,"peak_mem_baseline_gb":16.91,"peak_mem_patched_gb":16.91,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion). 
capsule=benchmarks/repro_capsules/glm47_takeover_additive_combo_t1024_r2_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T20:31:13.833579+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":78.0,"decode_tps_patched":80.2,"decode_speedup":1.0282,"decode_runs_baseline":[77.5,78.4],"decode_runs_patched":[80.3,80.1],"prefill_tps_baseline":233.9,"prefill_tps_patched":241.1,"prefill_change":0.0308,"peak_mem_baseline_gb":16.94,"peak_mem_patched_gb":16.95,"max_tokens":1024,"gen_tokens":1024,"runs":2} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental additive combo: fp32-no-fma combine + shared_experts overlap streams=2. capsule=benchmarks/repro_capsules/glm47_takeover_additive_combo_t1024_r2_glm_combine_fp32_no_fma_shared_overlap_streams2.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T20:33:03.967635+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":78.5,"decode_tps_patched":55.5,"decode_speedup":0.707,"decode_runs_baseline":[78.2,78.8],"decode_runs_patched":[55.5,55.5],"prefill_tps_baseline":237.4,"prefill_tps_patched":212.0,"prefill_change":-0.107,"peak_mem_baseline_gb":16.94,"peak_mem_patched_gb":16.95,"max_tokens":1024,"gen_tokens":1024,"runs":2} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit","model_family":"glm","architecture":"moe","patterns_applied":["glm47_rope","moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental additive combo: fp32-no-fma combine + glm47_rope decode path. 
capsule=benchmarks/repro_capsules/glm47_takeover_additive_combo_t1024_r2_glm_combine_fp32_no_fma_glm47_rope.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T20:34:46.759105+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":140,"decode_tps_baseline":78.8,"decode_tps_patched":77.1,"decode_speedup":0.9784,"decode_runs_baseline":[78.9,78.8],"decode_runs_patched":[77.9,76.4],"prefill_tps_baseline":236.4,"prefill_tps_patched":236.7,"prefill_change":0.0013,"peak_mem_baseline_gb":16.94,"peak_mem_patched_gb":16.95,"max_tokens":1024,"gen_tokens":1024,"runs":2} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: patch(model, patterns=['moe_mlp']) capsule=benchmarks/repro_capsules/qwen3_recal_t200_r3_control_patterns_moe_mlp.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T20:54:31.612786+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":114.9,"decode_tps_patched":111.5,"decode_speedup":0.9704,"decode_runs_baseline":[115.3,114.9,114.4],"decode_runs_patched":[115.6,111.5,110.7],"prefill_tps_baseline":328.6,"prefill_tps_patched":339.7,"prefill_change":0.0338,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen combine via ZMLX exact-dtype kernel (ZMLX_QWEN_COMBINE_MODE=exact) 
capsule=benchmarks/repro_capsules/qwen3_recal_t200_r3_qwen_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T20:55:31.025471+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":112.2,"decode_tps_patched":116.7,"decode_speedup":1.0401,"decode_runs_baseline":[111.1,115.5,112.2],"decode_runs_patched":[116.7,116.8,115.4],"prefill_tps_baseline":332.1,"prefill_tps_patched":334.5,"prefill_change":0.0072,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: patch(model, patterns=['moe_mlp']) capsule=benchmarks/repro_capsules/qwen3_recal_t1024_r2_control_patterns_moe_mlp.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T20:57:05.316607+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":111.6,"decode_tps_patched":111.2,"decode_speedup":0.9964,"decode_runs_baseline":[111.4,111.8],"decode_runs_patched":[111.1,111.3],"prefill_tps_baseline":330.3,"prefill_tps_patched":334.4,"prefill_change":0.0124,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":2} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen combine via ZMLX exact-dtype kernel (ZMLX_QWEN_COMBINE_MODE=exact) 
capsule=benchmarks/repro_capsules/qwen3_recal_t1024_r2_qwen_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T20:58:29.689425+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":111.9,"decode_tps_patched":110.6,"decode_speedup":0.9884,"decode_runs_baseline":[111.8,112.0],"decode_runs_patched":[110.5,110.7],"prefill_tps_baseline":328.9,"prefill_tps_patched":331.6,"prefill_change":0.0082,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":2} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion). capsule=benchmarks/repro_capsules/glm47_recal_t200_r3_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T20:59:40.460187+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":82.5,"decode_tps_patched":85.8,"decode_speedup":1.04,"decode_runs_baseline":[82.8,82.5,82.4],"decode_runs_patched":[85.9,83.7,85.8],"prefill_tps_baseline":277.8,"prefill_tps_patched":275.7,"prefill_change":-0.0076,"peak_mem_baseline_gb":16.91,"peak_mem_patched_gb":16.91,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma). 
capsule=benchmarks/repro_capsules/glm47_recal_t200_r3_glm_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T21:00:40.892195+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":85.0,"decode_tps_patched":88.2,"decode_speedup":1.0376,"decode_runs_baseline":[85.0,85.5,84.7],"decode_runs_patched":[87.8,88.2,88.9],"prefill_tps_baseline":277.0,"prefill_tps_patched":277.9,"prefill_change":0.0032,"peak_mem_baseline_gb":16.91,"peak_mem_patched_gb":16.91,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion). capsule=benchmarks/repro_capsules/glm47_recal_t1024_r2_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T21:02:22.132027+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":79.8,"decode_tps_patched":81.7,"decode_speedup":1.0238,"decode_runs_baseline":[79.7,79.8],"decode_runs_patched":[82.0,81.4],"prefill_tps_baseline":239.2,"prefill_tps_patched":241.7,"prefill_change":0.0105,"peak_mem_baseline_gb":16.94,"peak_mem_patched_gb":16.95,"max_tokens":1024,"gen_tokens":1024,"runs":2} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma). 
capsule=benchmarks/repro_capsules/glm47_recal_t1024_r2_glm_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T21:03:59.934186+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":77.9,"decode_tps_patched":83.5,"decode_speedup":1.0719,"decode_runs_baseline":[78.1,77.7],"decode_runs_patched":[83.3,83.6],"prefill_tps_baseline":235.1,"prefill_tps_patched":243.4,"prefill_change":0.0353,"peak_mem_baseline_gb":16.94,"peak_mem_patched_gb":16.95,"max_tokens":1024,"gen_tokens":1024,"runs":2} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: patch(model, patterns=['moe_mlp']) capsule=benchmarks/repro_capsules/qwen3_fp32_vs_nofma_t200_r2_summary.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T21:47:14.760722+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":116.4,"decode_tps_patched":116.3,"decode_speedup":0.9991,"decode_runs_baseline":[116.6,116.3],"decode_runs_patched":[116.4,116.3],"prefill_tps_baseline":333.6,"prefill_tps_patched":336.4,"prefill_change":0.0084,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":2} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen combine via ZMLX fp32 kernel (ZMLX_QWEN_COMBINE_MODE=fp32) 
capsule=benchmarks/repro_capsules/qwen3_fp32_probe_t128_r1_summary.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T21:49:39.367308+00:00","fidelity":"FAIL","fidelity_detail":"93/128 tokens identical","modules_patched":48,"decode_tps_baseline":117.5,"decode_tps_patched":42.3,"decode_speedup":0.36,"decode_runs_baseline":[117.5],"decode_runs_patched":[42.3],"prefill_tps_baseline":323.0,"prefill_tps_patched":205.6,"prefill_change":-0.3635,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.71,"max_tokens":128,"gen_tokens":128,"runs":1} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen combine via ZMLX fp32 no-FMA kernel (ZMLX_QWEN_COMBINE_MODE=fp32_no_fma) capsule=benchmarks/repro_capsules/qwen3_fp32_nofma_probe_t128_r1_summary.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T21:50:30.890645+00:00","fidelity":"FAIL","fidelity_detail":"93/128 tokens identical","modules_patched":48,"decode_tps_baseline":116.2,"decode_tps_patched":42.0,"decode_speedup":0.3614,"decode_runs_baseline":[116.2],"decode_runs_patched":[42.0],"prefill_tps_baseline":323.1,"prefill_tps_patched":203.8,"prefill_change":-0.3692,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.71,"max_tokens":128,"gen_tokens":128,"runs":1} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 kernel (ZMLX_GLM_COMBINE_MODE=fp32). 
capsule=benchmarks/repro_capsules/glm47_fp32_probe_t128_r1_summary.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T21:51:29.932732+00:00","fidelity":"FAIL","fidelity_detail":"2/128 tokens identical","modules_patched":93,"decode_tps_baseline":83.5,"decode_tps_patched":82.6,"decode_speedup":0.9892,"decode_runs_baseline":[83.5],"decode_runs_patched":[82.6],"prefill_tps_baseline":292.1,"prefill_tps_patched":301.6,"prefill_change":0.0325,"peak_mem_baseline_gb":16.91,"peak_mem_patched_gb":16.91,"max_tokens":128,"gen_tokens":128,"runs":1} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma). capsule=benchmarks/repro_capsules/glm47_fp32_nofma_probe_t128_r1_summary.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T21:52:23.386621+00:00","fidelity":"PASS","fidelity_detail":"128/128 tokens identical","modules_patched":93,"decode_tps_baseline":80.8,"decode_tps_patched":82.1,"decode_speedup":1.0161,"decode_runs_baseline":[80.8],"decode_runs_patched":[82.1],"prefill_tps_baseline":292.1,"prefill_tps_patched":301.6,"prefill_change":0.0325,"peak_mem_baseline_gb":16.91,"peak_mem_patched_gb":16.91,"max_tokens":128,"gen_tokens":128,"runs":1} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen combine via ZMLX fp32 kernel (ZMLX_QWEN_COMBINE_MODE=fp32) 
capsule=benchmarks/repro_capsules/moe_combine_transfer_probe_t128_r1_20260210_qwen3_qwen_combine_fp32.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T22:14:57.511858+00:00","fidelity":"FAIL","fidelity_detail":"93/128 tokens identical","modules_patched":48,"decode_tps_baseline":119.1,"decode_tps_patched":42.7,"decode_speedup":0.3585,"decode_runs_baseline":[119.1],"decode_runs_patched":[42.7],"prefill_tps_baseline":326.0,"prefill_tps_patched":206.0,"prefill_change":-0.3681,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.71,"max_tokens":128,"gen_tokens":128,"runs":1} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen combine via ZMLX fp32 no-FMA kernel (ZMLX_QWEN_COMBINE_MODE=fp32_no_fma) capsule=benchmarks/repro_capsules/moe_combine_transfer_probe_t128_r1_20260210_qwen3_qwen_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T22:15:45.270656+00:00","fidelity":"FAIL","fidelity_detail":"9/128 tokens identical","modules_patched":48,"decode_tps_baseline":118.8,"decode_tps_patched":42.6,"decode_speedup":0.3586,"decode_runs_baseline":[118.8],"decode_runs_patched":[42.6],"prefill_tps_baseline":325.2,"prefill_tps_patched":205.4,"prefill_change":-0.3684,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.71,"max_tokens":128,"gen_tokens":128,"runs":1} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 kernel (ZMLX_GLM_COMBINE_MODE=fp32). 
capsule=benchmarks/repro_capsules/moe_combine_transfer_probe_t128_r1_20260210_glm47_glm_combine_fp32.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T22:16:36.493825+00:00","fidelity":"FAIL","fidelity_detail":"2/128 tokens identical","modules_patched":93,"decode_tps_baseline":85.8,"decode_tps_patched":86.8,"decode_speedup":1.0117,"decode_runs_baseline":[85.8],"decode_runs_patched":[86.8],"prefill_tps_baseline":297.0,"prefill_tps_patched":304.8,"prefill_change":0.0263,"peak_mem_baseline_gb":16.91,"peak_mem_patched_gb":16.91,"max_tokens":128,"gen_tokens":128,"runs":1} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma). capsule=benchmarks/repro_capsules/moe_combine_transfer_probe_t128_r1_20260210_glm47_glm_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T22:17:22.325792+00:00","fidelity":"PASS","fidelity_detail":"128/128 tokens identical","modules_patched":93,"decode_tps_baseline":84.7,"decode_tps_patched":89.2,"decode_speedup":1.0531,"decode_runs_baseline":[84.7],"decode_runs_patched":[89.2],"prefill_tps_baseline":290.1,"prefill_tps_patched":306.2,"prefill_change":0.0555,"peak_mem_baseline_gb":16.91,"peak_mem_patched_gb":16.91,"max_tokens":128,"gen_tokens":128,"runs":1} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion). 
capsule=benchmarks/repro_capsules/glm47_fp32_no_fma_longconfirm_t1024_r2_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T22:22:42.137097+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":76.9,"decode_tps_patched":80.3,"decode_speedup":1.0442,"decode_runs_baseline":[76.8,77.1],"decode_runs_patched":[80.3,80.4],"prefill_tps_baseline":238.6,"prefill_tps_patched":238.1,"prefill_change":-0.0021,"peak_mem_baseline_gb":16.94,"peak_mem_patched_gb":16.95,"max_tokens":1024,"gen_tokens":1024,"runs":2} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma). capsule=benchmarks/repro_capsules/glm47_fp32_no_fma_longconfirm_t1024_r2_glm_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T22:24:18.526417+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":78.2,"decode_tps_patched":82.7,"decode_speedup":1.0575,"decode_runs_baseline":[77.9,78.4],"decode_runs_patched":[83.0,82.4],"prefill_tps_baseline":236.3,"prefill_tps_patched":242.5,"prefill_change":0.0262,"peak_mem_baseline_gb":16.94,"peak_mem_patched_gb":16.95,"max_tokens":1024,"gen_tokens":1024,"runs":2} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: patch(model, patterns=['moe_mlp']) 
capsule=benchmarks/repro_capsules/qwen3_return_t200_r3_20260210_control_patterns_moe_mlp.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T22:30:12.175240+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":116.3,"decode_tps_patched":114.9,"decode_speedup":0.988,"decode_runs_baseline":[116.8,116.0,116.3],"decode_runs_patched":[114.9,115.3,102.0],"prefill_tps_baseline":330.7,"prefill_tps_patched":334.1,"prefill_change":0.0103,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen combine via ZMLX exact-dtype kernel (ZMLX_QWEN_COMBINE_MODE=exact) capsule=benchmarks/repro_capsules/qwen3_return_t200_r3_20260210_qwen_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T22:31:10.748371+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":114.2,"decode_tps_patched":112.2,"decode_speedup":0.9825,"decode_runs_baseline":[114.2,112.2,114.9],"decode_runs_patched":[111.7,112.2,116.4],"prefill_tps_baseline":313.1,"prefill_tps_patched":330.6,"prefill_change":0.0559,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen argpartition(logits) fused top-k softmax plus exact combine kernel (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1 
ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS_TOPK=1 ZMLX_QWEN_COMBINE_MODE=exact) capsule=benchmarks/repro_capsules/qwen3_return_t200_r3_20260210_qwen_router_argpartition_logits_topk_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T22:32:07.226081+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":113.9,"decode_tps_patched":111.3,"decode_speedup":0.9772,"decode_runs_baseline":[115.3,111.2,113.9],"decode_runs_patched":[112.1,111.3,110.2],"prefill_tps_baseline":324.4,"prefill_tps_patched":329.5,"prefill_change":0.0157,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: patch(model, patterns=['moe_mlp']) capsule=benchmarks/repro_capsules/qwen3_return_t1024_r2_20260210_control_patterns_moe_mlp.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T22:33:48.599740+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":111.9,"decode_tps_patched":111.3,"decode_speedup":0.9946,"decode_runs_baseline":[111.9,111.8],"decode_runs_patched":[111.2,111.4],"prefill_tps_baseline":325.7,"prefill_tps_patched":331.7,"prefill_change":0.0184,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":2} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: patch(model, patterns=['moe_mlp']) 
capsule=benchmarks/repro_capsules/qwen3_benchmark_vs_baseline_t200_r3_repA_20260210_control_patterns_moe_mlp.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T22:43:21.513571+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":116.0,"decode_tps_patched":116.2,"decode_speedup":1.0017,"decode_runs_baseline":[116.0,117.4,115.3],"decode_runs_patched":[116.2,115.8,117.0],"prefill_tps_baseline":323.5,"prefill_tps_patched":336.0,"prefill_change":0.0386,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen combine via ZMLX exact-dtype kernel (ZMLX_QWEN_COMBINE_MODE=exact) capsule=benchmarks/repro_capsules/qwen3_benchmark_vs_baseline_t200_r3_repA_20260210_qwen_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T22:44:17.153926+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":117.2,"decode_tps_patched":116.2,"decode_speedup":0.9915,"decode_runs_baseline":[117.1,117.2,117.3],"decode_runs_patched":[116.8,116.1,116.2],"prefill_tps_baseline":337.5,"prefill_tps_patched":332.0,"prefill_change":-0.0163,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen router fast path using argpartition on logits with top-k softmax 
(ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1) capsule=benchmarks/repro_capsules/qwen3_benchmark_vs_baseline_t200_r3_repA_20260210_qwen_router_argpartition_logits.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T22:45:12.745196+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":111.7,"decode_tps_patched":114.1,"decode_speedup":1.0215,"decode_runs_baseline":[110.0,111.7,112.2],"decode_runs_patched":[108.9,115.8,114.1],"prefill_tps_baseline":331.5,"prefill_tps_patched":333.3,"prefill_change":0.0054,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen argpartition(logits) fused top-k softmax plus exact combine kernel (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1 ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS_TOPK=1 ZMLX_QWEN_COMBINE_MODE=exact) capsule=benchmarks/repro_capsules/qwen3_benchmark_vs_baseline_t200_r3_repA_20260210_qwen_router_argpartition_logits_topk_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T22:46:08.440869+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":113.7,"decode_tps_patched":113.5,"decode_speedup":0.9982,"decode_runs_baseline":[113.6,113.7,114.0],"decode_runs_patched":[113.5,113.7,113.2],"prefill_tps_baseline":334.4,"prefill_tps_patched":331.3,"prefill_change":-0.0093,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":3} 
+{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: patch(model, patterns=['moe_mlp']) capsule=benchmarks/repro_capsules/qwen3_benchmark_vs_baseline_t1024_r3_repA_20260210_control_patterns_moe_mlp.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T22:48:47.007038+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":112.7,"decode_tps_patched":112.3,"decode_speedup":0.9965,"decode_runs_baseline":[112.6,112.7,113.1],"decode_runs_patched":[112.9,112.0,112.3],"prefill_tps_baseline":326.2,"prefill_tps_patched":332.2,"prefill_change":0.0184,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":3} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen combine via ZMLX exact-dtype kernel (ZMLX_QWEN_COMBINE_MODE=exact) capsule=benchmarks/repro_capsules/qwen3_benchmark_vs_baseline_t1024_r3_repA_20260210_qwen_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T22:50:28.810952+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":111.7,"decode_tps_patched":111.9,"decode_speedup":1.0018,"decode_runs_baseline":[112.4,111.7,111.6],"decode_runs_patched":[111.9,112.6,110.9],"prefill_tps_baseline":322.5,"prefill_tps_patched":331.2,"prefill_change":0.027,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":3} 
+{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen router fast path using argpartition on logits with top-k softmax (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1) capsule=benchmarks/repro_capsules/qwen3_benchmark_vs_baseline_t1024_r3_repA_20260210_qwen_router_argpartition_logits.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T22:52:11.571875+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":112.9,"decode_tps_patched":113.1,"decode_speedup":1.0018,"decode_runs_baseline":[112.9,112.9,112.7],"decode_runs_patched":[113.2,113.1,113.1],"prefill_tps_baseline":323.9,"prefill_tps_patched":332.4,"prefill_change":0.0262,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":3} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen argpartition(logits) fused top-k softmax plus exact combine kernel (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1 ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS_TOPK=1 ZMLX_QWEN_COMBINE_MODE=exact) capsule=benchmarks/repro_capsules/qwen3_benchmark_vs_baseline_t1024_r3_repA_20260210_qwen_router_argpartition_logits_topk_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T22:53:52.293589+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens 
identical","modules_patched":48,"decode_tps_baseline":112.9,"decode_tps_patched":112.9,"decode_speedup":1.0,"decode_runs_baseline":[112.9,113.0,112.2],"decode_runs_patched":[113.4,112.9,112.5],"prefill_tps_baseline":333.2,"prefill_tps_patched":330.5,"prefill_change":-0.0081,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":3} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: patch(model, patterns=['moe_mlp']) capsule=benchmarks/repro_capsules/qwen3_benchmark_vs_baseline_t200_r3_repB_20260210_control_patterns_moe_mlp.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T22:55:32.201656+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":117.0,"decode_tps_patched":111.9,"decode_speedup":0.9564,"decode_runs_baseline":[117.2,117.0,115.0],"decode_runs_patched":[111.6,111.9,112.5],"prefill_tps_baseline":337.1,"prefill_tps_patched":332.3,"prefill_change":-0.0142,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen combine via ZMLX exact-dtype kernel (ZMLX_QWEN_COMBINE_MODE=exact) capsule=benchmarks/repro_capsules/qwen3_benchmark_vs_baseline_t200_r3_repB_20260210_qwen_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T22:56:29.781386+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens 
identical","modules_patched":48,"decode_tps_baseline":116.9,"decode_tps_patched":114.7,"decode_speedup":0.9812,"decode_runs_baseline":[116.9,117.5,116.1],"decode_runs_patched":[114.7,113.9,117.3],"prefill_tps_baseline":338.1,"prefill_tps_patched":333.9,"prefill_change":-0.0124,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen router fast path using argpartition on logits with top-k softmax (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1) capsule=benchmarks/repro_capsules/qwen3_benchmark_vs_baseline_t200_r3_repB_20260210_qwen_router_argpartition_logits.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T22:57:27.353945+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":117.2,"decode_tps_patched":111.6,"decode_speedup":0.9522,"decode_runs_baseline":[117.3,117.2,117.2],"decode_runs_patched":[115.9,110.1,111.6],"prefill_tps_baseline":338.8,"prefill_tps_patched":333.9,"prefill_change":-0.0145,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen argpartition(logits) fused top-k softmax plus exact combine kernel (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1 ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS_TOPK=1 ZMLX_QWEN_COMBINE_MODE=exact) 
capsule=benchmarks/repro_capsules/qwen3_benchmark_vs_baseline_t200_r3_repB_20260210_qwen_router_argpartition_logits_topk_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T22:58:27.881257+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":113.7,"decode_tps_patched":116.6,"decode_speedup":1.0255,"decode_runs_baseline":[114.7,113.7,111.2],"decode_runs_patched":[113.7,116.6,118.0],"prefill_tps_baseline":330.7,"prefill_tps_patched":331.2,"prefill_change":0.0015,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: patch(model, patterns=['moe_mlp']) capsule=benchmarks/repro_capsules/qwen3_benchmark_vs_baseline_t1024_r3_repB_20260210_control_patterns_moe_mlp.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T23:00:49.664251+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":112.9,"decode_tps_patched":113.5,"decode_speedup":1.0053,"decode_runs_baseline":[113.1,112.5,112.9],"decode_runs_patched":[113.6,113.5,113.2],"prefill_tps_baseline":333.7,"prefill_tps_patched":332.5,"prefill_change":-0.0036,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":3} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen combine via ZMLX exact-dtype kernel (ZMLX_QWEN_COMBINE_MODE=exact) 
capsule=benchmarks/repro_capsules/qwen3_benchmark_vs_baseline_t1024_r3_repB_20260210_qwen_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T23:02:30.712424+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":113.0,"decode_tps_patched":109.5,"decode_speedup":0.969,"decode_runs_baseline":[113.0,113.0,113.1],"decode_runs_patched":[110.6,109.4,109.5],"prefill_tps_baseline":333.6,"prefill_tps_patched":330.5,"prefill_change":-0.0093,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":3} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen router fast path using argpartition on logits with top-k softmax (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1) capsule=benchmarks/repro_capsules/qwen3_benchmark_vs_baseline_t1024_r3_repB_20260210_qwen_router_argpartition_logits.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T23:04:14.206876+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":112.0,"decode_tps_patched":108.8,"decode_speedup":0.9714,"decode_runs_baseline":[112.0,112.1,111.2],"decode_runs_patched":[108.8,108.4,109.2],"prefill_tps_baseline":327.7,"prefill_tps_patched":332.3,"prefill_change":0.014,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":3} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen argpartition(logits) 
fused top-k softmax plus exact combine kernel (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1 ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS_TOPK=1 ZMLX_QWEN_COMBINE_MODE=exact) capsule=benchmarks/repro_capsules/qwen3_benchmark_vs_baseline_t1024_r3_repB_20260210_qwen_router_argpartition_logits_topk_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T23:05:59.305829+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":111.1,"decode_tps_patched":108.5,"decode_speedup":0.9766,"decode_runs_baseline":[111.1,111.3,108.9],"decode_runs_patched":[108.5,108.9,107.3],"prefill_tps_baseline":334.9,"prefill_tps_patched":332.0,"prefill_change":-0.0087,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion, combine off). 
capsule=benchmarks/repro_capsules/glm47_benchmark_vs_baseline_t200_r3_repA_20260210_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T23:11:59.113287+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":83.6,"decode_tps_patched":80.2,"decode_speedup":0.9593,"decode_runs_baseline":[83.1,83.7,83.6],"decode_runs_patched":[80.2,80.1,80.2],"prefill_tps_baseline":284.4,"prefill_tps_patched":282.7,"prefill_change":-0.006,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma). capsule=benchmarks/repro_capsules/glm47_benchmark_vs_baseline_t200_r3_repA_20260210_glm_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T23:12:52.604844+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":85.3,"decode_tps_patched":81.3,"decode_speedup":0.9531,"decode_runs_baseline":[85.0,85.3,85.5],"decode_runs_patched":[80.7,81.3,81.4],"prefill_tps_baseline":287.7,"prefill_tps_patched":285.4,"prefill_change":-0.008,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: overlap shared_experts(x) on a separate stream. 
capsule=benchmarks/repro_capsules/glm47_benchmark_vs_baseline_t200_r3_repA_20260210_shared_experts_overlap_streams2.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T23:13:47.641610+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":84.9,"decode_tps_patched":59.5,"decode_speedup":0.7008,"decode_runs_baseline":[84.9,84.8,85.1],"decode_runs_patched":[59.5,59.7,59.4],"prefill_tps_baseline":286.7,"prefill_tps_patched":252.1,"prefill_change":-0.1207,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["glm47_rope","moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Decode-only fused RoPE+concat for glm4_moe_lite attention. capsule=benchmarks/repro_capsules/glm47_benchmark_vs_baseline_t200_r3_repA_20260210_glm47_rope.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T23:14:41.784611+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":140,"decode_tps_baseline":85.2,"decode_tps_patched":77.4,"decode_speedup":0.9085,"decode_runs_baseline":[82.7,85.2,85.3],"decode_runs_patched":[77.4,77.6,75.3],"prefill_tps_baseline":286.3,"prefill_tps_patched":279.6,"prefill_change":-0.0234,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental additive combo: fp32-no-fma combine + shared_experts overlap streams=2. 
capsule=benchmarks/repro_capsules/glm47_benchmark_vs_baseline_t200_r3_repA_20260210_glm_combine_fp32_no_fma_shared_overlap_streams2.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T23:15:43.467630+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":82.0,"decode_tps_patched":58.4,"decode_speedup":0.7122,"decode_runs_baseline":[82.0,80.7,82.0],"decode_runs_patched":[58.3,58.4,58.4],"prefill_tps_baseline":277.8,"prefill_tps_patched":246.3,"prefill_change":-0.1134,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["glm47_rope","moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental additive combo: fp32-no-fma combine + glm47_rope decode path. 
capsule=benchmarks/repro_capsules/glm47_benchmark_vs_baseline_t200_r3_repA_20260210_glm_combine_fp32_no_fma_glm47_rope.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T23:16:39.087161+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":140,"decode_tps_baseline":81.7,"decode_tps_patched":74.2,"decode_speedup":0.9082,"decode_runs_baseline":[81.5,81.7,82.2],"decode_runs_patched":[73.7,74.2,74.4],"prefill_tps_baseline":279.9,"prefill_tps_patched":273.5,"prefill_change":-0.0229,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion, combine off). capsule=benchmarks/repro_capsules/glm47_benchmark_vs_baseline_t1024_r3_repA_20260210_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T23:19:17.214848+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":79.3,"decode_tps_patched":75.1,"decode_speedup":0.947,"decode_runs_baseline":[79.3,79.6,79.1],"decode_runs_patched":[75.2,74.9,75.1],"prefill_tps_baseline":180.9,"prefill_tps_patched":175.3,"prefill_change":-0.031,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel 
(ZMLX_GLM_COMBINE_MODE=fp32_no_fma). capsule=benchmarks/repro_capsules/glm47_benchmark_vs_baseline_t1024_r3_repA_20260210_glm_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T23:21:16.687842+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":79.5,"decode_tps_patched":75.7,"decode_speedup":0.9522,"decode_runs_baseline":[79.5,79.5,78.5],"decode_runs_patched":[75.6,75.7,75.7],"prefill_tps_baseline":179.7,"prefill_tps_patched":170.1,"prefill_change":-0.0534,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: overlap shared_experts(x) on a separate stream. 
capsule=benchmarks/repro_capsules/glm47_benchmark_vs_baseline_t1024_r3_repA_20260210_shared_experts_overlap_streams2.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T23:23:30.497749+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":78.8,"decode_tps_patched":55.5,"decode_speedup":0.7043,"decode_runs_baseline":[78.5,79.1,78.8],"decode_runs_patched":[55.5,55.5,55.5],"prefill_tps_baseline":176.9,"prefill_tps_patched":157.4,"prefill_change":-0.1102,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["glm47_rope","moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Decode-only fused RoPE+concat for glm4_moe_lite attention. 
capsule=benchmarks/repro_capsules/glm47_benchmark_vs_baseline_t1024_r3_repA_20260210_glm47_rope.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T23:25:29.500368+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":140,"decode_tps_baseline":79.0,"decode_tps_patched":74.7,"decode_speedup":0.9456,"decode_runs_baseline":[78.8,79.0,79.3],"decode_runs_patched":[74.5,74.9,74.7],"prefill_tps_baseline":176.4,"prefill_tps_patched":179.7,"prefill_change":0.0187,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental additive combo: fp32-no-fma combine + shared_experts overlap streams=2. capsule=benchmarks/repro_capsules/glm47_benchmark_vs_baseline_t1024_r3_repA_20260210_glm_combine_fp32_no_fma_shared_overlap_streams2.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T23:27:38.667436+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":80.6,"decode_tps_patched":57.1,"decode_speedup":0.7084,"decode_runs_baseline":[80.2,80.6,80.9],"decode_runs_patched":[57.1,57.1,57.1],"prefill_tps_baseline":182.6,"prefill_tps_patched":164.3,"prefill_change":-0.1002,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["glm47_rope","moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental additive combo: fp32-no-fma 
combine + glm47_rope decode path. capsule=benchmarks/repro_capsules/glm47_benchmark_vs_baseline_t1024_r3_repA_20260210_glm_combine_fp32_no_fma_glm47_rope.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T23:29:35.606197+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":140,"decode_tps_baseline":80.2,"decode_tps_patched":74.2,"decode_speedup":0.9252,"decode_runs_baseline":[79.9,80.2,80.4],"decode_runs_patched":[74.3,74.2,72.5],"prefill_tps_baseline":181.7,"prefill_tps_patched":180.2,"prefill_change":-0.0083,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion, combine off). 
capsule=benchmarks/repro_capsules/glm47_benchmark_vs_baseline_t200_r3_repB_20260210_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T23:31:07.836067+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":84.3,"decode_tps_patched":79.0,"decode_speedup":0.9371,"decode_runs_baseline":[84.2,84.3,84.6],"decode_runs_patched":[78.8,79.1,79.0],"prefill_tps_baseline":283.6,"prefill_tps_patched":276.4,"prefill_change":-0.0254,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma). capsule=benchmarks/repro_capsules/glm47_benchmark_vs_baseline_t200_r3_repB_20260210_glm_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T23:31:59.948563+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":84.1,"decode_tps_patched":80.5,"decode_speedup":0.9572,"decode_runs_baseline":[83.9,84.2,84.1],"decode_runs_patched":[80.5,80.7,80.3],"prefill_tps_baseline":283.8,"prefill_tps_patched":281.1,"prefill_change":-0.0095,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: overlap shared_experts(x) on a separate stream. 
capsule=benchmarks/repro_capsules/glm47_benchmark_vs_baseline_t200_r3_repB_20260210_shared_experts_overlap_streams2.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T23:32:54.673060+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":84.2,"decode_tps_patched":58.2,"decode_speedup":0.6912,"decode_runs_baseline":[84.2,84.2,84.4],"decode_runs_patched":[58.2,58.1,58.2],"prefill_tps_baseline":283.0,"prefill_tps_patched":251.0,"prefill_change":-0.1131,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["glm47_rope","moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Decode-only fused RoPE+concat for glm4_moe_lite attention. capsule=benchmarks/repro_capsules/glm47_benchmark_vs_baseline_t200_r3_repB_20260210_glm47_rope.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T23:33:47.139107+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":140,"decode_tps_baseline":84.2,"decode_tps_patched":77.9,"decode_speedup":0.9252,"decode_runs_baseline":[84.4,84.0,84.2],"decode_runs_patched":[77.4,77.9,80.7],"prefill_tps_baseline":282.0,"prefill_tps_patched":282.9,"prefill_change":0.0032,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental additive combo: fp32-no-fma combine + shared_experts overlap streams=2. 
capsule=benchmarks/repro_capsules/glm47_benchmark_vs_baseline_t200_r3_repB_20260210_glm_combine_fp32_no_fma_shared_overlap_streams2.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T23:34:40.387983+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":86.1,"decode_tps_patched":60.2,"decode_speedup":0.6992,"decode_runs_baseline":[86.0,86.2,86.1],"decode_runs_patched":[60.3,60.2,60.1],"prefill_tps_baseline":288.2,"prefill_tps_patched":257.1,"prefill_change":-0.1079,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["glm47_rope","moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental additive combo: fp32-no-fma combine + glm47_rope decode path. 
capsule=benchmarks/repro_capsules/glm47_benchmark_vs_baseline_t200_r3_repB_20260210_glm_combine_fp32_no_fma_glm47_rope.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T23:35:30.950846+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":140,"decode_tps_baseline":86.0,"decode_tps_patched":79.6,"decode_speedup":0.9256,"decode_runs_baseline":[85.9,86.0,86.0],"decode_runs_patched":[80.4,79.6,79.6],"prefill_tps_baseline":289.1,"prefill_tps_patched":284.2,"prefill_change":-0.0169,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion, combine off). capsule=benchmarks/repro_capsules/glm47_benchmark_vs_baseline_t1024_r3_repB_20260210_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T23:38:04.709791+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":80.8,"decode_tps_patched":76.6,"decode_speedup":0.948,"decode_runs_baseline":[80.8,80.8,80.6],"decode_runs_patched":[76.8,76.3,76.6],"prefill_tps_baseline":181.3,"prefill_tps_patched":179.8,"prefill_change":-0.0083,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel 
(ZMLX_GLM_COMBINE_MODE=fp32_no_fma). capsule=benchmarks/repro_capsules/glm47_benchmark_vs_baseline_t1024_r3_repB_20260210_glm_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T23:40:01.685817+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":78.8,"decode_tps_patched":75.6,"decode_speedup":0.9594,"decode_runs_baseline":[78.6,78.8,79.1],"decode_runs_patched":[75.4,75.6,76.5],"prefill_tps_baseline":173.0,"prefill_tps_patched":170.9,"prefill_change":-0.0121,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: overlap shared_experts(x) on a separate stream. 
capsule=benchmarks/repro_capsules/glm47_benchmark_vs_baseline_t1024_r3_repB_20260210_shared_experts_overlap_streams2.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T23:42:15.072964+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":78.9,"decode_tps_patched":55.6,"decode_speedup":0.7047,"decode_runs_baseline":[78.8,78.9,79.1],"decode_runs_patched":[55.6,55.6,55.6],"prefill_tps_baseline":176.9,"prefill_tps_patched":156.2,"prefill_change":-0.117,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["glm47_rope","moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Decode-only fused RoPE+concat for glm4_moe_lite attention. 
capsule=benchmarks/repro_capsules/glm47_benchmark_vs_baseline_t1024_r3_repB_20260210_glm47_rope.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T23:44:13.295338+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":140,"decode_tps_baseline":78.7,"decode_tps_patched":75.9,"decode_speedup":0.9644,"decode_runs_baseline":[78.7,79.7,77.3],"decode_runs_patched":[75.9,75.9,76.1],"prefill_tps_baseline":184.1,"prefill_tps_patched":180.9,"prefill_change":-0.0174,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental additive combo: fp32-no-fma combine + shared_experts overlap streams=2. capsule=benchmarks/repro_capsules/glm47_benchmark_vs_baseline_t1024_r3_repB_20260210_glm_combine_fp32_no_fma_shared_overlap_streams2.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T23:46:25.279607+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":81.4,"decode_tps_patched":57.7,"decode_speedup":0.7088,"decode_runs_baseline":[81.2,81.9,81.4],"decode_runs_patched":[57.8,57.7,57.7],"prefill_tps_baseline":184.3,"prefill_tps_patched":163.4,"prefill_change":-0.1134,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["glm47_rope","moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental additive combo: 
fp32-no-fma combine + glm47_rope decode path. capsule=benchmarks/repro_capsules/glm47_benchmark_vs_baseline_t1024_r3_repB_20260210_glm_combine_fp32_no_fma_glm47_rope.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-10T23:48:23.902988+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":140,"decode_tps_baseline":81.5,"decode_tps_patched":76.0,"decode_speedup":0.9325,"decode_runs_baseline":[81.4,81.5,81.6],"decode_runs_patched":[76.0,75.9,76.1],"prefill_tps_baseline":183.2,"prefill_tps_patched":180.3,"prefill_change":-0.0158,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion, combine off). 
capsule=benchmarks/repro_capsules/glm47_consistency_abba_t200_r5_repA_20260211_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T03:24:41.484771+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":86.9,"decode_tps_patched":77.0,"decode_speedup":0.8861,"decode_runs_baseline":[86.5,86.9,86.9,86.9,86.9],"decode_runs_patched":[80.0,78.3,76.9,77.0,76.6],"prefill_tps_baseline":287.1,"prefill_tps_patched":276.3,"prefill_change":-0.0376,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma). 
capsule=benchmarks/repro_capsules/glm47_consistency_abba_t200_r5_repA_20260211_glm_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T03:25:49.331320+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":81.6,"decode_tps_patched":77.8,"decode_speedup":0.9534,"decode_runs_baseline":[81.5,81.4,81.9,81.7,81.6],"decode_runs_patched":[77.5,77.9,77.8,77.8,77.8],"prefill_tps_baseline":278.5,"prefill_tps_patched":276.2,"prefill_change":-0.0083,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma). 
capsule=benchmarks/repro_capsules/glm47_consistency_abba_t200_r5_repB_20260211_glm_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T03:27:51.957528+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":81.7,"decode_tps_patched":77.7,"decode_speedup":0.951,"decode_runs_baseline":[81.4,81.7,81.0,81.7,81.7],"decode_runs_patched":[77.3,77.4,77.8,77.8,77.7],"prefill_tps_baseline":276.9,"prefill_tps_patched":274.6,"prefill_change":-0.0083,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion, combine off). capsule=benchmarks/repro_capsules/glm47_consistency_abba_t200_r5_repB_20260211_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T03:29:01.319574+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":81.4,"decode_tps_patched":76.5,"decode_speedup":0.9398,"decode_runs_baseline":[81.4,81.2,81.4,81.4,81.3],"decode_runs_patched":[76.2,76.4,76.5,76.5,76.6],"prefill_tps_baseline":278.8,"prefill_tps_patched":275.0,"prefill_change":-0.0136,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion, combine off). 
capsule=benchmarks/repro_capsules/glm47_consistency_abba_t1024_r5_repA_20260211_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T03:32:58.829229+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":76.4,"decode_tps_patched":72.1,"decode_speedup":0.9437,"decode_runs_baseline":[76.1,76.3,76.4,76.6,76.5],"decode_runs_patched":[71.9,72.1,72.1,71.9,72.1],"prefill_tps_baseline":176.7,"prefill_tps_patched":172.0,"prefill_change":-0.0266,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma). 
capsule=benchmarks/repro_capsules/glm47_consistency_abba_t1024_r5_repA_20260211_glm_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T03:35:59.364846+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":76.7,"decode_tps_patched":73.2,"decode_speedup":0.9544,"decode_runs_baseline":[76.5,76.7,76.7,76.6,76.7],"decode_runs_patched":[72.9,73.1,73.3,73.2,73.4],"prefill_tps_baseline":174.3,"prefill_tps_patched":173.0,"prefill_change":-0.0075,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma). 
capsule=benchmarks/repro_capsules/glm47_consistency_abba_t1024_r5_repB_20260211_glm_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T03:39:52.505219+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":78.1,"decode_tps_patched":73.9,"decode_speedup":0.9462,"decode_runs_baseline":[78.1,78.2,78.3,78.1,77.7],"decode_runs_patched":[73.9,73.9,73.7,73.9,74.0],"prefill_tps_baseline":177.4,"prefill_tps_patched":173.0,"prefill_change":-0.0248,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion, combine off). 
capsule=benchmarks/repro_capsules/glm47_consistency_abba_t1024_r5_repB_20260211_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T03:42:51.165663+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":77.3,"decode_tps_patched":72.9,"decode_speedup":0.9431,"decode_runs_baseline":[77.3,77.4,77.4,77.3,77.0],"decode_runs_patched":[72.9,73.1,72.8,72.9,72.8],"prefill_tps_baseline":175.2,"prefill_tps_patched":174.1,"prefill_change":-0.0063,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: patch(model, patterns=['moe_mlp']) capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t200_r5_repA_20260211_control_patterns_moe_mlp.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T03:45:19.738152+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":113.2,"decode_tps_patched":110.7,"decode_speedup":0.9779,"decode_runs_baseline":[113.2,113.3,113.0,112.8,113.4],"decode_runs_patched":[110.3,111.1,111.4,110.6,110.7],"prefill_tps_baseline":334.1,"prefill_tps_patched":329.8,"prefill_change":-0.0129,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen combine via ZMLX exact-dtype kernel (ZMLX_QWEN_COMBINE_MODE=exact) 
capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t200_r5_repA_20260211_qwen_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T03:46:23.914857+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":113.3,"decode_tps_patched":104.0,"decode_speedup":0.9179,"decode_runs_baseline":[112.9,113.3,113.3,113.3,113.1],"decode_runs_patched":[106.0,103.3,104.0,100.6,107.4],"prefill_tps_baseline":330.7,"prefill_tps_patched":318.0,"prefill_change":-0.0384,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen router fast path using argpartition on logits with top-k softmax (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1) capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t200_r5_repA_20260211_qwen_router_argpartition_logits.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T03:47:32.445829+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":112.0,"decode_tps_patched":109.1,"decode_speedup":0.9741,"decode_runs_baseline":[111.1,112.0,112.2,111.9,112.0],"decode_runs_patched":[109.2,109.1,109.1,109.2,108.9],"prefill_tps_baseline":329.5,"prefill_tps_patched":327.6,"prefill_change":-0.0058,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 
36GB","notes":"Experimental: Qwen argpartition(logits) fused top-k softmax plus exact combine kernel (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1 ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS_TOPK=1 ZMLX_QWEN_COMBINE_MODE=exact) capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t200_r5_repA_20260211_qwen_router_argpartition_logits_topk_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T03:48:40.666879+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":109.9,"decode_tps_patched":101.2,"decode_speedup":0.9208,"decode_runs_baseline":[110.9,110.8,109.8,109.9,105.9],"decode_runs_patched":[101.2,100.6,101.2,101.5,101.0],"prefill_tps_baseline":326.3,"prefill_tps_patched":303.5,"prefill_change":-0.0699,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen argpartition(logits) fused top-k softmax plus exact combine kernel (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1 ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS_TOPK=1 ZMLX_QWEN_COMBINE_MODE=exact) capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t200_r5_repB_20260211_qwen_router_argpartition_logits_topk_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T03:50:52.914343+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens 
identical","modules_patched":48,"decode_tps_baseline":110.8,"decode_tps_patched":109.1,"decode_speedup":0.9847,"decode_runs_baseline":[106.4,110.8,111.1,111.6,110.8],"decode_runs_patched":[108.7,109.2,108.9,109.4,109.1],"prefill_tps_baseline":331.3,"prefill_tps_patched":327.7,"prefill_change":-0.0109,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen router fast path using argpartition on logits with top-k softmax (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1) capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t200_r5_repB_20260211_qwen_router_argpartition_logits.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T03:52:00.062924+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":111.2,"decode_tps_patched":108.4,"decode_speedup":0.9748,"decode_runs_baseline":[110.5,111.0,111.7,111.2,111.3],"decode_runs_patched":[108.3,108.4,107.9,108.5,108.5],"prefill_tps_baseline":328.6,"prefill_tps_patched":326.9,"prefill_change":-0.0052,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen combine via ZMLX exact-dtype kernel (ZMLX_QWEN_COMBINE_MODE=exact) 
capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t200_r5_repB_20260211_qwen_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T03:53:08.499994+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":111.4,"decode_tps_patched":108.5,"decode_speedup":0.974,"decode_runs_baseline":[111.4,111.4,111.3,111.3,111.6],"decode_runs_patched":[108.7,108.7,108.5,108.5,108.4],"prefill_tps_baseline":328.6,"prefill_tps_patched":327.0,"prefill_change":-0.0049,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: patch(model, patterns=['moe_mlp']) capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t200_r5_repB_20260211_control_patterns_moe_mlp.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T03:54:14.230646+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":111.7,"decode_tps_patched":109.4,"decode_speedup":0.9794,"decode_runs_baseline":[111.7,111.8,111.8,111.3,111.6],"decode_runs_patched":[109.4,109.3,109.6,109.4,109.4],"prefill_tps_baseline":328.6,"prefill_tps_patched":327.9,"prefill_change":-0.0021,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: patch(model, patterns=['moe_mlp']) 
capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t1024_r5_repA_20260211_control_patterns_moe_mlp.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T03:57:45.452563+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":107.6,"decode_tps_patched":105.3,"decode_speedup":0.9786,"decode_runs_baseline":[107.9,108.1,106.9,107.6,107.5],"decode_runs_patched":[105.3,105.5,105.3,104.9,105.1],"prefill_tps_baseline":326.3,"prefill_tps_patched":324.0,"prefill_change":-0.007,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen combine via ZMLX exact-dtype kernel (ZMLX_QWEN_COMBINE_MODE=exact) capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t1024_r5_repA_20260211_qwen_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T04:00:12.378390+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":108.2,"decode_tps_patched":103.3,"decode_speedup":0.9547,"decode_runs_baseline":[108.0,108.3,108.3,108.1,108.2],"decode_runs_patched":[102.9,103.1,103.4,103.5,103.3],"prefill_tps_baseline":324.7,"prefill_tps_patched":316.9,"prefill_change":-0.024,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen router fast path using 
argpartition on logits with top-k softmax (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1) capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t1024_r5_repA_20260211_qwen_router_argpartition_logits.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T04:02:39.973427+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":107.8,"decode_tps_patched":106.2,"decode_speedup":0.9852,"decode_runs_baseline":[106.5,107.3,107.8,107.8,107.9],"decode_runs_patched":[106.4,106.3,106.2,106.2,105.8],"prefill_tps_baseline":325.6,"prefill_tps_patched":324.1,"prefill_change":-0.0046,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen argpartition(logits) fused top-k softmax plus exact combine kernel (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1 ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS_TOPK=1 ZMLX_QWEN_COMBINE_MODE=exact) capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t1024_r5_repA_20260211_qwen_router_argpartition_logits_topk_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T04:05:04.369888+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens 
identical","modules_patched":48,"decode_tps_baseline":108.4,"decode_tps_patched":105.1,"decode_speedup":0.9696,"decode_runs_baseline":[108.1,108.4,108.7,108.5,108.3],"decode_runs_patched":[105.2,105.0,105.1,105.1,105.1],"prefill_tps_baseline":325.7,"prefill_tps_patched":322.2,"prefill_change":-0.0107,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen argpartition(logits) fused top-k softmax plus exact combine kernel (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1 ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS_TOPK=1 ZMLX_QWEN_COMBINE_MODE=exact) capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t1024_r5_repB_20260211_qwen_router_argpartition_logits_topk_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T04:08:31.526829+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":108.1,"decode_tps_patched":105.5,"decode_speedup":0.9759,"decode_runs_baseline":[108.2,108.0,108.1,107.9,108.4],"decode_runs_patched":[105.5,105.5,105.4,105.5,105.5],"prefill_tps_baseline":327.1,"prefill_tps_patched":321.1,"prefill_change":-0.0183,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen router fast path using argpartition on logits with top-k softmax (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1) 
capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t1024_r5_repB_20260211_qwen_router_argpartition_logits.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T04:10:56.368693+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":108.0,"decode_tps_patched":105.3,"decode_speedup":0.975,"decode_runs_baseline":[107.8,107.4,108.0,108.1,108.0],"decode_runs_patched":[105.3,105.5,105.2,105.1,106.0],"prefill_tps_baseline":324.4,"prefill_tps_patched":326.6,"prefill_change":0.0068,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen combine via ZMLX exact-dtype kernel (ZMLX_QWEN_COMBINE_MODE=exact) capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t1024_r5_repB_20260211_qwen_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T04:13:22.048772+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":107.8,"decode_tps_patched":105.1,"decode_speedup":0.975,"decode_runs_baseline":[108.0,107.7,107.7,107.8,108.1],"decode_runs_patched":[105.5,105.1,105.1,104.9,104.2],"prefill_tps_baseline":324.4,"prefill_tps_patched":323.0,"prefill_change":-0.0043,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: patch(model, 
patterns=['moe_mlp']) capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t1024_r5_repB_20260211_control_patterns_moe_mlp.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T04:15:49.223898+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":107.9,"decode_tps_patched":105.9,"decode_speedup":0.9815,"decode_runs_baseline":[107.8,107.9,107.7,107.9,107.9],"decode_runs_patched":[106.0,106.0,105.8,104.8,105.9],"prefill_tps_baseline":326.9,"prefill_tps_patched":322.4,"prefill_change":-0.0138,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion, combine off). 
capsule=benchmarks/repro_capsules/glm47_final_longconfirm_t1024_r5_20260211_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T04:41:13.090255+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":75.0,"decode_tps_patched":70.8,"decode_speedup":0.944,"decode_runs_baseline":[74.8,74.9,75.0,75.0,75.1],"decode_runs_patched":[70.6,70.8,70.8,70.8,70.8],"prefill_tps_baseline":170.4,"prefill_tps_patched":167.8,"prefill_change":-0.0153,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma). 
capsule=benchmarks/repro_capsules/glm47_final_longconfirm_t1024_r5_20260211_glm_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.4","zmlx_commit":"eebfa6e","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T04:44:20.439726+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":75.0,"decode_tps_patched":71.5,"decode_speedup":0.9533,"decode_runs_baseline":[74.8,75.0,75.0,74.9,75.0],"decode_runs_patched":[71.0,71.5,71.3,72.2,72.4],"prefill_tps_baseline":170.3,"prefill_tps_patched":171.7,"prefill_change":0.0082,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion, combine off). capsule=benchmarks/repro_capsules/glm47_t2_row_tile_transfer_t200_r3_20260211_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T16:09:48.261493+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":83.8,"decode_tps_patched":80.4,"decode_speedup":0.9594,"decode_runs_baseline":[83.6,83.8,84.5],"decode_runs_patched":[79.5,80.5,80.4],"prefill_tps_baseline":283.8,"prefill_tps_patched":282.5,"prefill_change":-0.0046,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel 
(ZMLX_GLM_COMBINE_MODE=fp32_no_fma). capsule=benchmarks/repro_capsules/glm47_t2_row_tile_transfer_t200_r3_20260211_glm_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T16:10:51.479126+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":83.7,"decode_tps_patched":80.2,"decode_speedup":0.9582,"decode_runs_baseline":[83.6,83.7,84.4],"decode_runs_patched":[79.5,80.3,80.2],"prefill_tps_baseline":285.7,"prefill_tps_patched":281.6,"prefill_change":-0.0144,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via t2-row-tile-inspired fp32 no-FMA kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma_row_tile). 
capsule=benchmarks/repro_capsules/glm47_t2_row_tile_transfer_t200_r3_20260211_glm_combine_fp32_no_fma_row_tile.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T16:11:56.118027+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":80.9,"decode_tps_patched":75.1,"decode_speedup":0.9283,"decode_runs_baseline":[82.7,80.9,80.5],"decode_runs_patched":[73.5,75.1,77.8],"prefill_tps_baseline":285.5,"prefill_tps_patched":263.0,"prefill_change":-0.0788,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion, combine off). capsule=benchmarks/repro_capsules/glm47_t2_row_tile_transfer_t1024_r3_20260211_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T16:14:18.018794+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":76.6,"decode_tps_patched":71.3,"decode_speedup":0.9308,"decode_runs_baseline":[73.4,77.5,76.6],"decode_runs_patched":[69.7,71.3,73.5],"prefill_tps_baseline":178.8,"prefill_tps_patched":180.8,"prefill_change":0.0112,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma). 
capsule=benchmarks/repro_capsules/glm47_t2_row_tile_transfer_t1024_r3_20260211_glm_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T16:16:28.475791+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":76.8,"decode_tps_patched":74.4,"decode_speedup":0.9688,"decode_runs_baseline":[76.4,77.7,76.8],"decode_runs_patched":[74.9,73.5,74.4],"prefill_tps_baseline":178.2,"prefill_tps_patched":176.8,"prefill_change":-0.0079,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via t2-row-tile-inspired fp32 no-FMA kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma_row_tile). 
capsule=benchmarks/repro_capsules/glm47_t2_row_tile_transfer_t1024_r3_20260211_glm_combine_fp32_no_fma_row_tile.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T16:18:38.706586+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":77.8,"decode_tps_patched":74.2,"decode_speedup":0.9537,"decode_runs_baseline":[77.9,77.8,77.7],"decode_runs_patched":[74.0,74.3,74.2],"prefill_tps_baseline":179.0,"prefill_tps_patched":173.8,"prefill_change":-0.0291,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion, combine off). capsule=benchmarks/repro_capsules/glm47_prefill_consistency_abba_t1024_r5_repC_AB_20260211_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T16:21:53.177941+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":77.5,"decode_tps_patched":73.3,"decode_speedup":0.9458,"decode_runs_baseline":[75.8,77.5,77.8,77.7,75.9],"decode_runs_patched":[71.4,73.3,73.7,73.1,73.9],"prefill_tps_baseline":179.6,"prefill_tps_patched":177.0,"prefill_change":-0.0145,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel 
(ZMLX_GLM_COMBINE_MODE=fp32_no_fma). capsule=benchmarks/repro_capsules/glm47_prefill_consistency_abba_t1024_r5_repC_AB_20260211_glm_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T16:25:00.088282+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":77.0,"decode_tps_patched":73.4,"decode_speedup":0.9532,"decode_runs_baseline":[77.3,76.7,77.2,76.3,77.0],"decode_runs_patched":[73.4,73.3,73.6,73.8,72.4],"prefill_tps_baseline":178.8,"prefill_tps_patched":178.4,"prefill_change":-0.0022,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma). 
capsule=benchmarks/repro_capsules/glm47_prefill_consistency_abba_t1024_r5_repD_BA_20260211_glm_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T16:29:02.784830+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":77.2,"decode_tps_patched":73.4,"decode_speedup":0.9508,"decode_runs_baseline":[75.6,77.2,77.2,78.0,77.9],"decode_runs_patched":[73.4,74.5,73.3,73.1,74.2],"prefill_tps_baseline":178.3,"prefill_tps_patched":172.9,"prefill_change":-0.0303,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion, combine off). 
capsule=benchmarks/repro_capsules/glm47_prefill_consistency_abba_t1024_r5_repD_BA_20260211_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T16:32:14.445351+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":73.7,"decode_tps_patched":74.0,"decode_speedup":1.0041,"decode_runs_baseline":[75.2,73.7,73.0,76.1,67.6],"decode_runs_patched":[74.5,74.0,74.0,71.6,68.9],"prefill_tps_baseline":171.7,"prefill_tps_patched":175.1,"prefill_change":0.0198,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion, combine off). 
capsule=benchmarks/repro_capsules/glm47_weighted_sum_fp32_no_fma_t200_r3_20260211_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T16:34:20.041116+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":77.9,"decode_tps_patched":72.4,"decode_speedup":0.9294,"decode_runs_baseline":[78.4,77.9,77.2],"decode_runs_patched":[71.9,72.7,72.4],"prefill_tps_baseline":274.8,"prefill_tps_patched":262.6,"prefill_change":-0.0444,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma). 
capsule=benchmarks/repro_capsules/glm47_weighted_sum_fp32_no_fma_t200_r3_20260211_glm_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T16:35:25.491393+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":78.8,"decode_tps_patched":75.0,"decode_speedup":0.9518,"decode_runs_baseline":[79.0,78.4,78.8],"decode_runs_patched":[75.0,75.0,75.1],"prefill_tps_baseline":274.3,"prefill_tps_patched":270.3,"prefill_change":-0.0146,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via K=8-specialized fp32 no-FMA weighted-sum kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma_weighted_sum). 
capsule=benchmarks/repro_capsules/glm47_weighted_sum_fp32_no_fma_t200_r3_20260211_glm_combine_fp32_no_fma_weighted_sum.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T16:36:30.450164+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":78.5,"decode_tps_patched":73.7,"decode_speedup":0.9389,"decode_runs_baseline":[78.4,78.5,78.9],"decode_runs_patched":[72.7,73.7,75.5],"prefill_tps_baseline":274.2,"prefill_tps_patched":269.1,"prefill_change":-0.0186,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion, combine off). capsule=benchmarks/repro_capsules/glm47_weighted_sum_fp32_no_fma_t1024_r3_20260211_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T16:38:52.211691+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":73.6,"decode_tps_patched":69.4,"decode_speedup":0.9429,"decode_runs_baseline":[73.2,73.6,73.9],"decode_runs_patched":[69.3,69.4,69.6],"prefill_tps_baseline":170.3,"prefill_tps_patched":162.7,"prefill_change":-0.0446,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel 
(ZMLX_GLM_COMBINE_MODE=fp32_no_fma). capsule=benchmarks/repro_capsules/glm47_weighted_sum_fp32_no_fma_t1024_r3_20260211_glm_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T16:41:07.459206+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":73.3,"decode_tps_patched":70.1,"decode_speedup":0.9563,"decode_runs_baseline":[73.2,73.3,73.6],"decode_runs_patched":[70.1,70.5,70.0],"prefill_tps_baseline":166.2,"prefill_tps_patched":165.1,"prefill_change":-0.0066,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via K=8-specialized fp32 no-FMA weighted-sum kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma_weighted_sum). 
capsule=benchmarks/repro_capsules/glm47_weighted_sum_fp32_no_fma_t1024_r3_20260211_glm_combine_fp32_no_fma_weighted_sum.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-11T16:43:23.578459+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":72.5,"decode_tps_patched":69.9,"decode_speedup":0.9641,"decode_runs_baseline":[72.2,72.5,73.6],"decode_runs_patched":[70.1,69.9,69.1],"prefill_tps_baseline":162.3,"prefill_tps_patched":162.1,"prefill_change":-0.0012,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion, combine off). capsule=benchmarks/repro_capsules/glm47_quickval_t200_r3_20260213.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-13T22:06:33.229033+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":82.7,"decode_tps_patched":77.6,"decode_speedup":0.9383,"decode_runs_baseline":[82.8,82.7,82.0],"decode_runs_patched":[77.0,77.9,77.6],"prefill_tps_baseline":285.3,"prefill_tps_patched":281.4,"prefill_change":-0.0137,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion, combine off). 
capsule=benchmarks/repro_capsules/glm47_quickval_t200_r3_20260213.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-13T22:10:43.917682+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":81.6,"decode_tps_patched":76.3,"decode_speedup":0.935,"decode_runs_baseline":[81.4,81.6,82.1],"decode_runs_patched":[74.8,76.3,76.9],"prefill_tps_baseline":283.7,"prefill_tps_patched":277.6,"prefill_change":-0.0215,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion, combine off). capsule=benchmarks/repro_capsules/glm47_quickval_t200_r3_20260213.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-13T22:14:46.594452+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":80.1,"decode_tps_patched":56.1,"decode_speedup":0.7004,"decode_runs_baseline":[80.3,80.1,79.6],"decode_runs_patched":[56.4,56.1,54.2],"prefill_tps_baseline":276.9,"prefill_tps_patched":289.4,"prefill_change":0.0451,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion, combine off). 
capsule=benchmarks/repro_capsules/glm47_quickval_t200_r3_20260213_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-13T22:18:14.639935+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":81.0,"decode_tps_patched":75.3,"decode_speedup":0.9296,"decode_runs_baseline":[81.0,81.0,80.7],"decode_runs_patched":[74.3,75.5,75.3],"prefill_tps_baseline":280.7,"prefill_tps_patched":274.9,"prefill_change":-0.0207,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma). capsule=benchmarks/repro_capsules/glm47_quickval_t200_r3_20260213_glm_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-13T22:19:26.749535+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":81.5,"decode_tps_patched":77.4,"decode_speedup":0.9497,"decode_runs_baseline":[81.2,81.7,81.5],"decode_runs_patched":[77.6,77.1,77.4],"prefill_tps_baseline":280.3,"prefill_tps_patched":276.7,"prefill_change":-0.0128,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion, combine off). 
capsule=benchmarks/repro_capsules/glm47_quickval_iso_t200_r3_20260213_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-13T22:32:15.835373+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":79.8,"decode_tps_patched":73.2,"decode_speedup":0.9173,"decode_runs_baseline":[78.2,80.0,79.8],"decode_runs_patched":[72.8,73.2,73.5],"prefill_tps_baseline":281.7,"prefill_tps_patched":270.0,"prefill_change":-0.0415,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma). capsule=benchmarks/repro_capsules/glm47_quickval_iso_t200_r3_20260213_glm_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-13T22:33:21.770397+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":79.5,"decode_tps_patched":76.0,"decode_speedup":0.956,"decode_runs_baseline":[80.0,79.5,79.1],"decode_runs_patched":[77.0,76.0,75.6],"prefill_tps_baseline":275.0,"prefill_tps_patched":275.1,"prefill_change":0.0004,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: patch(model, patterns=['moe_mlp']) 
capsule=benchmarks/repro_capsules/qwen3_quickval_iso_t200_r3_20260213_control_patterns_moe_mlp.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-13T22:34:37.846419+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":41.3,"decode_tps_patched":106.7,"decode_speedup":2.5835,"decode_runs_baseline":[37.9,41.3,65.5],"decode_runs_patched":[104.3,106.7,107.1],"prefill_tps_baseline":242.0,"prefill_tps_patched":326.6,"prefill_change":0.3496,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen combine via ZMLX exact-dtype kernel (ZMLX_QWEN_COMBINE_MODE=exact) capsule=benchmarks/repro_capsules/qwen3_quickval_iso_t200_r3_20260213_qwen_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-13T22:35:46.440067+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":109.8,"decode_tps_patched":99.4,"decode_speedup":0.9053,"decode_runs_baseline":[110.0,109.6,109.8],"decode_runs_patched":[65.9,99.4,101.2],"prefill_tps_baseline":335.0,"prefill_tps_patched":318.6,"prefill_change":-0.049,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen router fast path using argpartition on logits with top-k softmax (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1) 
capsule=benchmarks/repro_capsules/qwen3_quickval_iso_t200_r3_20260213_qwen_router_argpartition_logits.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-13T22:36:51.865982+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":108.4,"decode_tps_patched":105.9,"decode_speedup":0.9769,"decode_runs_baseline":[108.4,108.2,108.6],"decode_runs_patched":[105.6,106.4,105.9],"prefill_tps_baseline":331.8,"prefill_tps_patched":331.0,"prefill_change":-0.0024,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen argpartition(logits) fused top-k softmax plus exact combine kernel (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1 ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS_TOPK=1 ZMLX_QWEN_COMBINE_MODE=exact) capsule=benchmarks/repro_capsules/qwen3_quickval_iso_t200_r3_20260213_qwen_router_argpartition_logits_topk_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-13T22:37:56.533826+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":109.4,"decode_tps_patched":106.8,"decode_speedup":0.9762,"decode_runs_baseline":[109.4,110.0,109.4],"decode_runs_patched":[107.4,105.7,106.8],"prefill_tps_baseline":335.1,"prefill_tps_patched":329.8,"prefill_change":-0.0158,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":3} 
+{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion, combine off). capsule=benchmarks/repro_capsules/glm47_consistency_ab_t200_r5_20260213_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-13T22:39:23.744820+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":79.7,"decode_tps_patched":74.6,"decode_speedup":0.936,"decode_runs_baseline":[79.3,80.3,79.8,79.7,79.7],"decode_runs_patched":[75.0,74.2,74.7,73.9,74.6],"prefill_tps_baseline":277.9,"prefill_tps_patched":271.2,"prefill_change":-0.0241,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma). 
capsule=benchmarks/repro_capsules/glm47_consistency_ab_t200_r5_20260213_glm_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-13T22:40:41.313278+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":78.4,"decode_tps_patched":76.1,"decode_speedup":0.9707,"decode_runs_baseline":[78.4,78.4,78.0,79.2,78.9],"decode_runs_patched":[74.6,75.3,77.3,77.0,76.1],"prefill_tps_baseline":276.4,"prefill_tps_patched":278.0,"prefill_change":0.0058,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma). 
capsule=benchmarks/repro_capsules/glm47_consistency_ba_t200_r5_20260213_glm_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-13T22:42:43.224012+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":79.2,"decode_tps_patched":77.2,"decode_speedup":0.9747,"decode_runs_baseline":[77.3,79.2,80.1,70.7,80.5],"decode_runs_patched":[76.8,77.2,77.9,78.1,75.1],"prefill_tps_baseline":277.4,"prefill_tps_patched":279.0,"prefill_change":0.0058,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion, combine off). capsule=benchmarks/repro_capsules/glm47_consistency_ba_t200_r5_20260213_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-13T22:44:04.832601+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":73.2,"decode_tps_patched":73.3,"decode_speedup":1.0014,"decode_runs_baseline":[79.8,81.1,61.6,43.7,73.2],"decode_runs_patched":[74.6,73.6,69.8,71.6,73.3],"prefill_tps_baseline":278.5,"prefill_tps_patched":265.5,"prefill_change":-0.0467,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion, combine off). 
capsule=benchmarks/repro_capsules/glm47_consistency_ab_t1024_r5_20260213_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-13T22:47:31.704304+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":75.7,"decode_tps_patched":70.4,"decode_speedup":0.93,"decode_runs_baseline":[63.7,75.7,76.2,75.3,75.9],"decode_runs_patched":[69.6,71.2,70.2,70.4,70.7],"prefill_tps_baseline":174.1,"prefill_tps_patched":168.8,"prefill_change":-0.0304,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma). 
capsule=benchmarks/repro_capsules/glm47_consistency_ab_t1024_r5_20260213_glm_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-13T22:50:46.371054+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":75.1,"decode_tps_patched":69.8,"decode_speedup":0.9294,"decode_runs_baseline":[75.1,75.4,75.6,74.5,73.6],"decode_runs_patched":[70.7,69.8,68.4,68.9,71.8],"prefill_tps_baseline":172.9,"prefill_tps_patched":170.2,"prefill_change":-0.0156,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma). 
capsule=benchmarks/repro_capsules/glm47_consistency_ba_t1024_r5_20260213_glm_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-13T22:54:45.825225+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":74.3,"decode_tps_patched":71.2,"decode_speedup":0.9583,"decode_runs_baseline":[74.8,71.7,74.4,74.1,74.3],"decode_runs_patched":[70.3,70.5,71.5,71.5,71.2],"prefill_tps_baseline":171.2,"prefill_tps_patched":168.3,"prefill_change":-0.0169,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion, combine off). capsule=benchmarks/repro_capsules/glm47_consistency_ba_t1024_r5_20260213_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-13T22:57:59.568893+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":74.7,"decode_tps_patched":70.5,"decode_speedup":0.9438,"decode_runs_baseline":[74.8,74.6,74.0,74.7,75.3],"decode_runs_patched":[70.3,71.3,71.1,70.0,70.5],"prefill_tps_baseline":170.0,"prefill_tps_patched":166.1,"prefill_change":-0.0229,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: patch(model, patterns=['moe_mlp']) 
capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t200_r5_repA_20260213_control_patterns_moe_mlp.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-13T22:59:20.410445+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":112.7,"decode_tps_patched":109.9,"decode_speedup":0.9752,"decode_runs_baseline":[112.7,112.6,113.4,112.9,112.2],"decode_runs_patched":[109.6,110.2,110.9,109.9,109.6],"prefill_tps_baseline":336.0,"prefill_tps_patched":335.2,"prefill_change":-0.0024,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen combine via ZMLX exact-dtype kernel (ZMLX_QWEN_COMBINE_MODE=exact) capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t200_r5_repA_20260213_qwen_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-13T23:00:31.653180+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":111.9,"decode_tps_patched":108.2,"decode_speedup":0.9669,"decode_runs_baseline":[112.0,112.2,111.9,111.4,111.5],"decode_runs_patched":[108.8,108.2,107.4,107.3,108.2],"prefill_tps_baseline":335.1,"prefill_tps_patched":329.9,"prefill_change":-0.0155,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen router fast path using 
argpartition on logits with top-k softmax (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1) capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t200_r5_repA_20260213_qwen_router_argpartition_logits.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-13T23:01:42.753009+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":110.4,"decode_tps_patched":108.3,"decode_speedup":0.981,"decode_runs_baseline":[111.1,110.2,107.5,111.1,110.4],"decode_runs_patched":[108.3,107.4,107.8,108.3,108.9],"prefill_tps_baseline":333.7,"prefill_tps_patched":331.7,"prefill_change":-0.006,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen argpartition(logits) fused top-k softmax plus exact combine kernel (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1 ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS_TOPK=1 ZMLX_QWEN_COMBINE_MODE=exact) capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t200_r5_repA_20260213_qwen_router_argpartition_logits_topk_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-13T23:02:54.221162+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens 
identical","modules_patched":48,"decode_tps_baseline":111.1,"decode_tps_patched":107.5,"decode_speedup":0.9676,"decode_runs_baseline":[112.4,111.1,110.1,111.7,110.7],"decode_runs_patched":[107.5,104.9,105.2,108.3,108.9],"prefill_tps_baseline":335.0,"prefill_tps_patched":329.4,"prefill_change":-0.0167,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen argpartition(logits) fused top-k softmax plus exact combine kernel (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1 ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS_TOPK=1 ZMLX_QWEN_COMBINE_MODE=exact) capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t200_r5_repB_20260213_qwen_router_argpartition_logits_topk_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-13T23:04:51.610588+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":108.0,"decode_tps_patched":105.8,"decode_speedup":0.9796,"decode_runs_baseline":[107.9,107.1,108.0,108.5,108.8],"decode_runs_patched":[105.8,105.9,106.2,104.7,105.8],"prefill_tps_baseline":332.0,"prefill_tps_patched":328.6,"prefill_change":-0.0102,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen router fast path using argpartition on logits with top-k softmax (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1) 
capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t200_r5_repB_20260213_qwen_router_argpartition_logits.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-13T23:06:05.097824+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":105.9,"decode_tps_patched":102.7,"decode_speedup":0.9698,"decode_runs_baseline":[105.7,105.9,106.2,105.2,106.6],"decode_runs_patched":[103.7,102.7,102.0,103.5,101.4],"prefill_tps_baseline":330.5,"prefill_tps_patched":327.6,"prefill_change":-0.0088,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen combine via ZMLX exact-dtype kernel (ZMLX_QWEN_COMBINE_MODE=exact) capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t200_r5_repB_20260213_qwen_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-13T23:07:19.117305+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":105.9,"decode_tps_patched":104.0,"decode_speedup":0.9821,"decode_runs_baseline":[105.9,107.0,106.5,104.3,105.0],"decode_runs_patched":[104.4,103.9,104.0,102.8,104.8],"prefill_tps_baseline":329.6,"prefill_tps_patched":328.4,"prefill_change":-0.0036,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: patch(model, patterns=['moe_mlp']) 
capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t200_r5_repB_20260213_control_patterns_moe_mlp.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-13T23:08:32.403671+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":107.3,"decode_tps_patched":104.3,"decode_speedup":0.972,"decode_runs_baseline":[107.6,107.3,107.5,106.9,106.1],"decode_runs_patched":[104.5,104.6,103.6,104.3,104.1],"prefill_tps_baseline":332.9,"prefill_tps_patched":328.2,"prefill_change":-0.0141,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: patch(model, patterns=['moe_mlp']) capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t1024_r5_repA_20260213_control_patterns_moe_mlp.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-13T23:11:14.333494+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":105.0,"decode_tps_patched":103.2,"decode_speedup":0.9829,"decode_runs_baseline":[103.7,100.8,105.0,105.7,105.5],"decode_runs_patched":[102.2,103.6,103.2,103.3,102.9],"prefill_tps_baseline":326.9,"prefill_tps_patched":327.9,"prefill_change":0.0031,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen combine via ZMLX exact-dtype kernel (ZMLX_QWEN_COMBINE_MODE=exact) 
capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t1024_r5_repA_20260213_qwen_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-13T23:13:44.976561+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":108.6,"decode_tps_patched":104.2,"decode_speedup":0.9595,"decode_runs_baseline":[105.7,105.7,108.6,109.3,108.6],"decode_runs_patched":[104.2,103.6,104.0,104.2,104.9],"prefill_tps_baseline":329.5,"prefill_tps_patched":323.8,"prefill_change":-0.0173,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen router fast path using argpartition on logits with top-k softmax (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1) capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t1024_r5_repA_20260213_qwen_router_argpartition_logits.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-13T23:16:14.563591+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":108.4,"decode_tps_patched":105.2,"decode_speedup":0.9705,"decode_runs_baseline":[108.2,108.5,108.6,108.1,108.4],"decode_runs_patched":[105.7,105.2,105.5,104.9,104.6],"prefill_tps_baseline":331.2,"prefill_tps_patched":327.3,"prefill_change":-0.0118,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 
36GB","notes":"Experimental: Qwen argpartition(logits) fused top-k softmax plus exact combine kernel (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1 ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS_TOPK=1 ZMLX_QWEN_COMBINE_MODE=exact) capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t1024_r5_repA_20260213_qwen_router_argpartition_logits_topk_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-13T23:18:44.408429+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":108.8,"decode_tps_patched":104.0,"decode_speedup":0.9559,"decode_runs_baseline":[108.8,109.1,109.0,108.6,108.7],"decode_runs_patched":[104.8,104.0,103.6,102.9,104.3],"prefill_tps_baseline":331.1,"prefill_tps_patched":323.6,"prefill_change":-0.0227,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen argpartition(logits) fused top-k softmax plus exact combine kernel (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1 ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS_TOPK=1 ZMLX_QWEN_COMBINE_MODE=exact) capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t1024_r5_repB_20260213_qwen_router_argpartition_logits_topk_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-13T23:22:03.278313+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens 
identical","modules_patched":48,"decode_tps_baseline":104.2,"decode_tps_patched":100.9,"decode_speedup":0.9683,"decode_runs_baseline":[104.2,104.4,103.5,104.2,104.2],"decode_runs_patched":[100.6,100.9,100.6,101.0,101.0],"prefill_tps_baseline":325.8,"prefill_tps_patched":322.3,"prefill_change":-0.0107,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen router fast path using argpartition on logits with top-k softmax (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1) capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t1024_r5_repB_20260213_qwen_router_argpartition_logits.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-13T23:24:38.644053+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":103.4,"decode_tps_patched":100.4,"decode_speedup":0.971,"decode_runs_baseline":[102.3,102.6,103.9,104.0,103.4],"decode_runs_patched":[100.4,100.4,100.3,100.4,100.2],"prefill_tps_baseline":326.4,"prefill_tps_patched":321.5,"prefill_change":-0.015,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen combine via ZMLX exact-dtype kernel (ZMLX_QWEN_COMBINE_MODE=exact) 
capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t1024_r5_repB_20260213_qwen_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.2","custom_mlx":true,"timestamp":"2026-02-13T23:27:18.124048+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":102.8,"decode_tps_patched":99.1,"decode_speedup":0.964,"decode_runs_baseline":[103.1,102.8,103.4,102.4,98.8],"decode_runs_patched":[100.2,96.6,99.5,98.9,99.1],"prefill_tps_baseline":322.5,"prefill_tps_patched":318.6,"prefill_change":-0.0121,"peak_mem_baseline_gb":17.31,"peak_mem_patched_gb":17.31,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: patch(model, patterns=['moe_mlp']) capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t1024_r5_repB_20260213_control_patterns_moe_mlp.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-13T23:30:09.544024+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":103.4,"decode_tps_patched":100.3,"decode_speedup":0.97,"decode_runs_baseline":[102.8,103.4,103.2,103.7,103.6],"decode_runs_patched":[100.3,100.2,100.6,100.5,90.9],"prefill_tps_baseline":327.6,"prefill_tps_patched":321.7,"prefill_change":-0.018,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":[],"hardware":"Apple M4 Max 36GB","notes":"Isolation campaign control anchor (patterns 
none)","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-14T01:19:58.029114+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":0,"decode_tps_baseline":84.16,"decode_tps_patched":84.12,"decode_speedup":0.9995,"decode_runs_baseline":[84.16,84.29,84.07],"decode_runs_patched":[83.68,84.72,84.12],"prefill_tps_baseline":287.5,"prefill_tps_patched":288.5,"prefill_change":0.0035,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":[],"hardware":"Apple M4 Max 36GB","notes":"Isolation campaign control anchor (patterns none)","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-14T01:21:01.076203+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":0,"decode_tps_baseline":112.63,"decode_tps_patched":113.88,"decode_speedup":1.0111,"decode_runs_baseline":[110.02,112.63,114.79],"decode_runs_patched":[113.88,114.7,113.02],"prefill_tps_baseline":327.15,"prefill_tps_patched":336.57,"prefill_change":0.0288,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/LFM2-8B-A1B-4bit","model_family":"lfm","architecture":"moe","patterns_applied":[],"hardware":"Apple M4 Max 36GB","notes":"Isolation campaign control anchor (patterns none)","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-14T01:21:17.434584+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens 
identical","modules_patched":0,"decode_tps_baseline":224.03,"decode_tps_patched":223.67,"decode_speedup":0.9984,"decode_runs_baseline":[224.03,224.11,223.82],"decode_runs_patched":[222.56,223.67,224.29],"prefill_tps_baseline":728.08,"prefill_tps_patched":737.92,"prefill_change":0.0135,"peak_mem_baseline_gb":5.3,"peak_mem_patched_gb":5.3,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: patch(model, patterns=['moe_mlp']) capsule=benchmarks/repro_capsules/qwen3_frontier_t200_r2_20260214_control_patterns_moe_mlp.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-14T02:08:39.592983+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":118.1,"decode_tps_patched":117.0,"decode_speedup":0.9907,"decode_runs_baseline":[118.1,118.0],"decode_runs_patched":[116.8,117.1],"prefill_tps_baseline":334.9,"prefill_tps_patched":338.4,"prefill_change":0.0105,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":2} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen fused router top-k softmax (ZMLX_QWEN_FUSED_ROUTER_TOPK=1) capsule=benchmarks/repro_capsules/qwen3_frontier_t200_r2_20260214_qwen_fused_router_topk.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-14T02:09:44.288521+00:00","fidelity":"FAIL","fidelity_detail":"163/200 tokens 
identical","modules_patched":48,"decode_tps_baseline":115.8,"decode_tps_patched":42.8,"decode_speedup":0.3696,"decode_runs_baseline":[114.6,117.0],"decode_runs_patched":[42.5,43.0],"prefill_tps_baseline":332.6,"prefill_tps_patched":207.4,"prefill_change":-0.3764,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.72,"max_tokens":200,"gen_tokens":200,"runs":2} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen fused down-proj+combine (ZMLX_QWEN_FUSED_DOWNPROJ_COMBINE=1) capsule=benchmarks/repro_capsules/qwen3_frontier_t200_r2_20260214_qwen_fused_downproj_combine.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-14T02:10:45.443195+00:00","fidelity":"FAIL","fidelity_detail":"11/200 tokens identical","modules_patched":48,"decode_tps_baseline":112.8,"decode_tps_patched":65.9,"decode_speedup":0.5842,"decode_runs_baseline":[112.2,113.3],"decode_runs_patched":[66.7,65.2],"prefill_tps_baseline":329.4,"prefill_tps_patched":293.6,"prefill_change":-0.1087,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":2} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen fused down-proj+combine with K-vectorized gather combine (ZMLX_QWEN_FUSED_DOWNPROJ_COMBINE=1 ZMLX_QWEN_FUSED_DOWNPROJ_COMBINE_KVEC=1) capsule=benchmarks/repro_capsules/qwen3_frontier_t200_r2_20260214_qwen_fused_downproj_combine_kvec.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-14T02:11:43.656968+00:00","fidelity":"FAIL","fidelity_detail":"9/200 
tokens identical","modules_patched":48,"decode_tps_baseline":117.3,"decode_tps_patched":107.0,"decode_speedup":0.9122,"decode_runs_baseline":[116.3,118.3],"decode_runs_patched":[106.8,107.2],"prefill_tps_baseline":333.2,"prefill_tps_patched":325.3,"prefill_change":-0.0237,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":2} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen argpartition(logits) with fused top-k softmax kernel (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1 ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS_TOPK=1) capsule=benchmarks/repro_capsules/qwen3_frontier_t200_r2_20260214_qwen_router_argpartition_logits_topk.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-14T02:12:47.331901+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":116.8,"decode_tps_patched":56.3,"decode_speedup":0.482,"decode_runs_baseline":[118.3,115.3],"decode_runs_patched":[56.6,56.1],"prefill_tps_baseline":335.1,"prefill_tps_patched":188.3,"prefill_change":-0.4381,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":2} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen router argpartition(logits)+topk softmax plus exact-dtype combine kernel (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1 ZMLX_QWEN_COMBINE_MODE=exact) 
capsule=benchmarks/repro_capsules/qwen3_frontier_t200_r2_20260214_qwen_router_argpartition_logits_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-14T02:14:16.008898+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":56.5,"decode_tps_patched":54.8,"decode_speedup":0.9699,"decode_runs_baseline":[56.4,56.6],"decode_runs_patched":[56.1,53.6],"prefill_tps_baseline":186.7,"prefill_tps_patched":196.4,"prefill_change":0.052,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":2} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen router argpartition(logits)+topk softmax plus fp32 no-FMA combine kernel (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1 ZMLX_QWEN_COMBINE_MODE=fp32_no_fma) capsule=benchmarks/repro_capsules/qwen3_frontier_t200_r2_20260214_qwen_router_argpartition_logits_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-14T02:15:56.598451+00:00","fidelity":"FAIL","fidelity_detail":"2/200 tokens identical","modules_patched":48,"decode_tps_baseline":56.6,"decode_tps_patched":26.3,"decode_speedup":0.4647,"decode_runs_baseline":[56.3,56.9],"decode_runs_patched":[26.4,26.3],"prefill_tps_baseline":177.8,"prefill_tps_patched":122.4,"prefill_change":-0.3116,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.72,"max_tokens":200,"gen_tokens":200,"runs":2} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen router 
argpartition(logits)+topk softmax plus fused down-proj+combine K-vectorized path (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1 ZMLX_QWEN_FUSED_DOWNPROJ_COMBINE=1 ZMLX_QWEN_FUSED_DOWNPROJ_COMBINE_KVEC=1) capsule=benchmarks/repro_capsules/qwen3_frontier_t200_r2_20260214_qwen_router_argpartition_logits_downproj_kvec.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-14T02:17:25.333349+00:00","fidelity":"FAIL","fidelity_detail":"9/200 tokens identical","modules_patched":48,"decode_tps_baseline":55.7,"decode_tps_patched":57.2,"decode_speedup":1.0269,"decode_runs_baseline":[54.4,57.0],"decode_runs_patched":[57.4,56.9],"prefill_tps_baseline":170.7,"prefill_tps_patched":175.1,"prefill_change":0.0258,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":2} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: patch(model, patterns=['moe_mlp']) capsule=benchmarks/repro_capsules/qwen3_frontier_t1024_r2_20260214_control_patterns_moe_mlp.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-14T02:20:03.301432+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":57.3,"decode_tps_patched":55.0,"decode_speedup":0.9599,"decode_runs_baseline":[59.5,55.2],"decode_runs_patched":[54.8,55.2],"prefill_tps_baseline":167.2,"prefill_tps_patched":170.1,"prefill_change":0.0173,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":2} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 
Max 36GB","notes":"Experimental: Qwen fused router top-k softmax (ZMLX_QWEN_FUSED_ROUTER_TOPK=1) capsule=benchmarks/repro_capsules/qwen3_frontier_t1024_r2_20260214_qwen_fused_router_topk.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-14T02:23:07.982211+00:00","fidelity":"FAIL","fidelity_detail":"175/1024 tokens identical","modules_patched":48,"decode_tps_baseline":55.2,"decode_tps_patched":25.9,"decode_speedup":0.4692,"decode_runs_baseline":[55.2,55.1],"decode_runs_patched":[25.8,25.9],"prefill_tps_baseline":168.8,"prefill_tps_patched":119.7,"prefill_change":-0.2909,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.92,"max_tokens":1024,"gen_tokens":1024,"runs":2} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen fused down-proj+combine (ZMLX_QWEN_FUSED_DOWNPROJ_COMBINE=1) capsule=benchmarks/repro_capsules/qwen3_frontier_t1024_r2_20260214_qwen_fused_downproj_combine.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-14T02:25:44.897208+00:00","fidelity":"FAIL","fidelity_detail":"15/1024 tokens identical","modules_patched":48,"decode_tps_baseline":55.4,"decode_tps_patched":36.1,"decode_speedup":0.6516,"decode_runs_baseline":[55.5,55.4],"decode_runs_patched":[36.3,36.0],"prefill_tps_baseline":182.8,"prefill_tps_patched":169.0,"prefill_change":-0.0755,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":2} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen fused down-proj+combine with K-vectorized 
gather combine (ZMLX_QWEN_FUSED_DOWNPROJ_COMBINE=1 ZMLX_QWEN_FUSED_DOWNPROJ_COMBINE_KVEC=1) capsule=benchmarks/repro_capsules/qwen3_frontier_t1024_r2_20260214_qwen_fused_downproj_combine_kvec.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-14T02:28:05.924399+00:00","fidelity":"FAIL","fidelity_detail":"15/1024 tokens identical","modules_patched":48,"decode_tps_baseline":55.9,"decode_tps_patched":57.8,"decode_speedup":1.034,"decode_runs_baseline":[55.9,55.9],"decode_runs_patched":[56.0,59.5],"prefill_tps_baseline":200.5,"prefill_tps_patched":198.3,"prefill_change":-0.011,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":2} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen argpartition(logits) with fused top-k softmax kernel (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1 ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS_TOPK=1) capsule=benchmarks/repro_capsules/qwen3_frontier_t1024_r2_20260214_qwen_router_argpartition_logits_topk.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-14T02:30:29.935030+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":55.6,"decode_tps_patched":52.5,"decode_speedup":0.9442,"decode_runs_baseline":[55.7,55.6],"decode_runs_patched":[55.1,50.0],"prefill_tps_baseline":169.7,"prefill_tps_patched":196.3,"prefill_change":0.1567,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":2} 
+{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen router argpartition(logits)+topk softmax plus exact-dtype combine kernel (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1 ZMLX_QWEN_COMBINE_MODE=exact) capsule=benchmarks/repro_capsules/qwen3_frontier_t1024_r2_20260214_qwen_router_argpartition_logits_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-14T02:32:43.519083+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":55.6,"decode_tps_patched":55.8,"decode_speedup":1.0036,"decode_runs_baseline":[55.7,55.5],"decode_runs_patched":[55.7,56.0],"prefill_tps_baseline":180.8,"prefill_tps_patched":187.0,"prefill_change":0.0343,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":2} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen router argpartition(logits)+topk softmax plus fp32 no-FMA combine kernel (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1 ZMLX_QWEN_COMBINE_MODE=fp32_no_fma) capsule=benchmarks/repro_capsules/qwen3_frontier_t1024_r2_20260214_qwen_router_argpartition_logits_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-14T02:35:58.165345+00:00","fidelity":"FAIL","fidelity_detail":"7/1024 tokens 
identical","modules_patched":48,"decode_tps_baseline":57.3,"decode_tps_patched":25.8,"decode_speedup":0.4503,"decode_runs_baseline":[55.1,59.5],"decode_runs_patched":[25.7,25.8],"prefill_tps_baseline":181.3,"prefill_tps_patched":121.3,"prefill_change":-0.3309,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.89,"max_tokens":1024,"gen_tokens":1024,"runs":2} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen router argpartition(logits)+topk softmax plus fused down-proj+combine K-vectorized path (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1 ZMLX_QWEN_FUSED_DOWNPROJ_COMBINE=1 ZMLX_QWEN_FUSED_DOWNPROJ_COMBINE_KVEC=1) capsule=benchmarks/repro_capsules/qwen3_frontier_t1024_r2_20260214_qwen_router_argpartition_logits_downproj_kvec.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-14T02:38:20.160192+00:00","fidelity":"FAIL","fidelity_detail":"15/1024 tokens identical","modules_patched":48,"decode_tps_baseline":50.8,"decode_tps_patched":55.7,"decode_speedup":1.0965,"decode_runs_baseline":[55.3,46.2],"decode_runs_patched":[56.0,55.4],"prefill_tps_baseline":160.4,"prefill_tps_patched":180.0,"prefill_change":0.1222,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":2} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion, combine off). 
capsule=benchmarks/repro_capsules/glm47_frontier_t1024_r3_20260214_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-14T02:41:59.843699+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":39.4,"decode_tps_patched":39.0,"decode_speedup":0.9898,"decode_runs_baseline":[39.4,39.4,39.2],"decode_runs_patched":[38.9,39.0,39.1],"prefill_tps_baseline":104.8,"prefill_tps_patched":124.1,"prefill_change":0.1842,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma). capsule=benchmarks/repro_capsules/glm47_frontier_t1024_r3_20260214_glm_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-14T02:45:35.549258+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":37.8,"decode_tps_patched":39.8,"decode_speedup":1.0529,"decode_runs_baseline":[35.8,37.8,39.3],"decode_runs_patched":[39.8,39.5,39.8],"prefill_tps_baseline":109.6,"prefill_tps_patched":135.3,"prefill_change":0.2345,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX exact-dtype kernel (ZMLX_GLM_COMBINE_MODE=exact). 
capsule=benchmarks/repro_capsules/glm47_frontier_t1024_r3_20260214_glm_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-14T02:49:37.999212+00:00","fidelity":"FAIL","fidelity_detail":"343/1024 tokens identical","modules_patched":93,"decode_tps_baseline":39.4,"decode_tps_patched":39.0,"decode_speedup":0.9898,"decode_runs_baseline":[23.3,39.4,39.7],"decode_runs_patched":[39.0,38.8,39.2],"prefill_tps_baseline":118.3,"prefill_tps_patched":116.3,"prefill_change":-0.0169,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via t2-row-tile-inspired fp32 no-FMA kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma_row_tile). 
capsule=benchmarks/repro_capsules/glm47_frontier_t1024_r3_20260214_glm_combine_fp32_no_fma_row_tile.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-14T02:53:24.141218+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":38.1,"decode_tps_patched":38.3,"decode_speedup":1.0052,"decode_runs_baseline":[38.1,38.5,32.9],"decode_runs_patched":[38.3,34.0,38.5],"prefill_tps_baseline":121.4,"prefill_tps_patched":119.1,"prefill_change":-0.0189,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion, combine off). capsule=benchmarks/repro_capsules/glm47_combo_fp32nofma_t1024_r3_20260214_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-14T05:34:15.209851+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":38.9,"decode_tps_patched":38.8,"decode_speedup":0.9974,"decode_runs_baseline":[38.8,38.9,40.2],"decode_runs_patched":[38.7,38.8,38.8],"prefill_tps_baseline":133.6,"prefill_tps_patched":113.1,"prefill_change":-0.1534,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma). 
capsule=benchmarks/repro_capsules/glm47_combo_fp32nofma_t1024_r3_20260214_glm_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-14T05:37:51.155041+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":38.9,"decode_tps_patched":39.0,"decode_speedup":1.0026,"decode_runs_baseline":[37.8,38.9,39.0],"decode_runs_patched":[39.0,38.8,39.0],"prefill_tps_baseline":126.7,"prefill_tps_patched":104.1,"prefill_change":-0.1784,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["glm47_rope","moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental additive combo: fp32-no-fma combine + glm47_rope decode path. 
capsule=benchmarks/repro_capsules/glm47_combo_fp32nofma_t1024_r3_20260214_glm_combine_fp32_no_fma_glm47_rope.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-14T05:41:33.282139+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":140,"decode_tps_baseline":39.0,"decode_tps_patched":37.4,"decode_speedup":0.959,"decode_runs_baseline":[39.0,39.0,38.7],"decode_runs_patched":[38.6,34.3,37.4],"prefill_tps_baseline":103.3,"prefill_tps_patched":112.3,"prefill_change":0.0871,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental additive combo: fp32-no-fma combine + shared_experts overlap streams=2. 
capsule=benchmarks/repro_capsules/glm47_combo_fp32nofma_t1024_r3_20260214_glm_combine_fp32_no_fma_shared_overlap_streams2.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-14T05:45:36.754992+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":38.8,"decode_tps_patched":30.7,"decode_speedup":0.7912,"decode_runs_baseline":[38.7,38.8,39.1],"decode_runs_patched":[31.6,30.7,30.2],"prefill_tps_baseline":127.3,"prefill_tps_patched":106.5,"prefill_change":-0.1634,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion, combine off). capsule=benchmarks/repro_capsules/glm47_quickval_iso_t200_r3_20260217_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-17T22:29:33.974098+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":86.0,"decode_tps_patched":76.7,"decode_speedup":0.8919,"decode_runs_baseline":[86.0,85.9,86.3],"decode_runs_patched":[80.9,75.2,76.7],"prefill_tps_baseline":290.6,"prefill_tps_patched":273.8,"prefill_change":-0.0578,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel 
(ZMLX_GLM_COMBINE_MODE=fp32_no_fma). capsule=benchmarks/repro_capsules/glm47_quickval_iso_t200_r3_20260217_glm_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-17T22:30:38.106185+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":83.9,"decode_tps_patched":79.8,"decode_speedup":0.9511,"decode_runs_baseline":[85.2,83.9,82.2],"decode_runs_patched":[81.1,79.8,79.4],"prefill_tps_baseline":290.5,"prefill_tps_patched":287.9,"prefill_change":-0.009,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion, combine off). 
capsule=benchmarks/repro_capsules/glm47_quickval_iso_t200_r3_20260217_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-17T22:36:22.521384+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":81.3,"decode_tps_patched":76.0,"decode_speedup":0.9348,"decode_runs_baseline":[79.4,81.3,81.7],"decode_runs_patched":[82.5,76.0,75.4],"prefill_tps_baseline":290.3,"prefill_tps_patched":281.2,"prefill_change":-0.0313,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma). capsule=benchmarks/repro_capsules/glm47_quickval_iso_t200_r3_20260217_glm_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-17T22:36:56.206838+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":79.7,"decode_tps_patched":76.4,"decode_speedup":0.9586,"decode_runs_baseline":[80.3,79.1,79.7],"decode_runs_patched":[76.5,76.3,76.4],"prefill_tps_baseline":283.2,"prefill_tps_patched":270.2,"prefill_change":-0.0459,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: patch(model, patterns=['moe_mlp']) 
capsule=benchmarks/repro_capsules/qwen3_quickval_iso_t200_r3_20260217_control_patterns_moe_mlp.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-17T22:40:52.031970+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":112.9,"decode_tps_patched":101.5,"decode_speedup":0.899,"decode_runs_baseline":[110.4,112.9,115.4],"decode_runs_patched":[105.7,99.6,101.5],"prefill_tps_baseline":335.1,"prefill_tps_patched":322.8,"prefill_change":-0.0367,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen combine via ZMLX exact-dtype kernel (ZMLX_QWEN_COMBINE_MODE=exact) capsule=benchmarks/repro_capsules/qwen3_quickval_iso_t200_r3_20260217_qwen_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-17T22:41:22.206373+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":112.2,"decode_tps_patched":115.0,"decode_speedup":1.025,"decode_runs_baseline":[113.8,112.2,110.6],"decode_runs_patched":[110.0,115.0,117.1],"prefill_tps_baseline":332.2,"prefill_tps_patched":336.5,"prefill_change":0.0129,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen router fast path using argpartition on logits with top-k softmax (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1) 
capsule=benchmarks/repro_capsules/qwen3_quickval_iso_t200_r3_20260217_qwen_router_argpartition_logits.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-17T22:41:53.315724+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":107.5,"decode_tps_patched":106.1,"decode_speedup":0.987,"decode_runs_baseline":[96.0,107.5,111.1],"decode_runs_patched":[95.3,106.1,108.5],"prefill_tps_baseline":313.7,"prefill_tps_patched":326.7,"prefill_change":0.0414,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen argpartition(logits) fused top-k softmax plus exact combine kernel (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1 ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS_TOPK=1 ZMLX_QWEN_COMBINE_MODE=exact) capsule=benchmarks/repro_capsules/qwen3_quickval_iso_t200_r3_20260217_qwen_router_argpartition_logits_topk_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-17T22:42:23.579953+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":112.0,"decode_tps_patched":108.7,"decode_speedup":0.9705,"decode_runs_baseline":[112.0,113.1,111.6],"decode_runs_patched":[112.3,106.0,108.7],"prefill_tps_baseline":334.8,"prefill_tps_patched":336.7,"prefill_change":0.0057,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":3} 
+{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion, combine off). capsule=benchmarks/repro_capsules/glm47_consistency_ab_t200_r5_20260217_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-17T22:43:07.868019+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":82.0,"decode_tps_patched":77.1,"decode_speedup":0.9402,"decode_runs_baseline":[81.3,82.0,81.5,83.6,83.1],"decode_runs_patched":[80.4,79.2,77.1,75.6,75.3],"prefill_tps_baseline":284.1,"prefill_tps_patched":280.6,"prefill_change":-0.0123,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma). 
capsule=benchmarks/repro_capsules/glm47_consistency_ab_t200_r5_20260217_glm_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-17T22:43:53.120019+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":79.7,"decode_tps_patched":75.5,"decode_speedup":0.9473,"decode_runs_baseline":[81.6,79.2,79.7,81.3,79.6],"decode_runs_patched":[76.4,75.1,75.5,76.3,75.2],"prefill_tps_baseline":278.3,"prefill_tps_patched":279.8,"prefill_change":0.0054,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma). 
capsule=benchmarks/repro_capsules/glm47_consistency_ba_t200_r5_20260217_glm_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-17T22:45:27.487099+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":78.5,"decode_tps_patched":74.8,"decode_speedup":0.9529,"decode_runs_baseline":[75.6,77.5,78.5,79.6,79.3],"decode_runs_patched":[70.7,74.7,74.8,75.7,76.6],"prefill_tps_baseline":276.3,"prefill_tps_patched":272.8,"prefill_change":-0.0127,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion, combine off). capsule=benchmarks/repro_capsules/glm47_consistency_ba_t200_r5_20260217_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-17T22:46:14.515042+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":93,"decode_tps_baseline":78.9,"decode_tps_patched":75.5,"decode_speedup":0.9569,"decode_runs_baseline":[78.5,81.3,80.5,78.9,78.6],"decode_runs_patched":[74.2,75.3,76.8,75.5,78.0],"prefill_tps_baseline":280.9,"prefill_tps_patched":280.3,"prefill_change":-0.0021,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion, combine off). 
capsule=benchmarks/repro_capsules/glm47_consistency_ab_t1024_r5_20260217_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-17T22:48:53.338737+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":76.8,"decode_tps_patched":73.9,"decode_speedup":0.9622,"decode_runs_baseline":[71.9,74.1,77.3,77.4,76.8],"decode_runs_patched":[72.4,73.9,73.2,74.3,77.5],"prefill_tps_baseline":182.3,"prefill_tps_patched":186.5,"prefill_change":0.023,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma). 
capsule=benchmarks/repro_capsules/glm47_consistency_ab_t1024_r5_20260217_glm_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-17T22:51:30.770762+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":77.4,"decode_tps_patched":75.8,"decode_speedup":0.9793,"decode_runs_baseline":[74.5,77.2,77.9,78.0,77.4],"decode_runs_patched":[75.8,77.6,77.5,75.3,70.8],"prefill_tps_baseline":185.8,"prefill_tps_patched":186.4,"prefill_change":0.0032,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: force GLM combine via ZMLX fp32 no-FMA kernel (ZMLX_GLM_COMBINE_MODE=fp32_no_fma). 
capsule=benchmarks/repro_capsules/glm47_consistency_ba_t1024_r5_20260217_glm_combine_fp32_no_fma.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-17T22:54:45.714643+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":80.4,"decode_tps_patched":77.5,"decode_speedup":0.9639,"decode_runs_baseline":[80.3,80.4,79.5,80.7,80.5],"decode_runs_patched":[77.5,77.4,78.2,78.0,77.5],"prefill_tps_baseline":187.0,"prefill_tps_patched":183.2,"prefill_change":-0.0203,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":["moe_mlp","swiglu_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: current best patch set (no RoPE fusion, combine off). capsule=benchmarks/repro_capsules/glm47_consistency_ba_t1024_r5_20260217_control_swiglu_moe.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-17T22:57:16.449347+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":93,"decode_tps_baseline":80.1,"decode_tps_patched":76.4,"decode_speedup":0.9538,"decode_runs_baseline":[77.8,80.9,80.1,80.1,81.0],"decode_runs_patched":[76.1,76.6,76.7,76.1,76.4],"prefill_tps_baseline":184.9,"prefill_tps_patched":180.4,"prefill_change":-0.0243,"peak_mem_baseline_gb":16.0,"peak_mem_patched_gb":16.0,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: patch(model, patterns=['moe_mlp']) 
capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t200_r5_repA_20260217_control_patterns_moe_mlp.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-17T22:57:52.499401+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":117.6,"decode_tps_patched":114.3,"decode_speedup":0.9719,"decode_runs_baseline":[117.7,117.8,111.4,115.3,117.6],"decode_runs_patched":[114.3,114.4,111.7,116.1,110.3],"prefill_tps_baseline":334.2,"prefill_tps_patched":335.5,"prefill_change":0.0039,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen combine via ZMLX exact-dtype kernel (ZMLX_QWEN_COMBINE_MODE=exact) capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t200_r5_repA_20260217_qwen_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-17T22:58:28.654345+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":116.6,"decode_tps_patched":113.9,"decode_speedup":0.9768,"decode_runs_baseline":[114.4,116.8,116.6,117.1,116.6],"decode_runs_patched":[113.9,113.7,110.2,114.5,115.7],"prefill_tps_baseline":339.5,"prefill_tps_patched":335.0,"prefill_change":-0.0133,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen router fast path using argpartition 
on logits with top-k softmax (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1) capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t200_r5_repA_20260217_qwen_router_argpartition_logits.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-17T22:59:05.403670+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":114.7,"decode_tps_patched":109.4,"decode_speedup":0.9538,"decode_runs_baseline":[114.1,114.5,114.7,115.1,117.1],"decode_runs_patched":[108.0,109.4,112.0,109.5,107.6],"prefill_tps_baseline":335.2,"prefill_tps_patched":332.4,"prefill_change":-0.0084,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen argpartition(logits) fused top-k softmax plus exact combine kernel (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1 ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS_TOPK=1 ZMLX_QWEN_COMBINE_MODE=exact) capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t200_r5_repA_20260217_qwen_router_argpartition_logits_topk_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-17T22:59:45.455588+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens 
identical","modules_patched":48,"decode_tps_baseline":106.9,"decode_tps_patched":111.1,"decode_speedup":1.0393,"decode_runs_baseline":[107.6,108.0,104.1,106.9,104.3],"decode_runs_patched":[106.7,111.2,108.9,111.1,113.8],"prefill_tps_baseline":324.9,"prefill_tps_patched":336.3,"prefill_change":0.0351,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen argpartition(logits) fused top-k softmax plus exact combine kernel (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1 ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS_TOPK=1 ZMLX_QWEN_COMBINE_MODE=exact) capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t200_r5_repB_20260217_qwen_router_argpartition_logits_topk_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-17T23:01:17.904696+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":105.6,"decode_tps_patched":105.2,"decode_speedup":0.9962,"decode_runs_baseline":[111.4,108.1,105.6,52.9,39.1],"decode_runs_patched":[107.4,105.2,105.1,105.6,101.2],"prefill_tps_baseline":329.8,"prefill_tps_patched":328.4,"prefill_change":-0.0042,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen router fast path using argpartition on logits with top-k softmax (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1) 
capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t200_r5_repB_20260217_qwen_router_argpartition_logits.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-17T23:01:59.228668+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":110.8,"decode_tps_patched":112.1,"decode_speedup":1.0117,"decode_runs_baseline":[108.7,109.2,110.8,115.8,115.6],"decode_runs_patched":[115.6,111.9,105.4,116.6,112.1],"prefill_tps_baseline":332.0,"prefill_tps_patched":338.1,"prefill_change":0.0184,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen combine via ZMLX exact-dtype kernel (ZMLX_QWEN_COMBINE_MODE=exact) capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t200_r5_repB_20260217_qwen_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-17T23:02:34.948911+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":118.1,"decode_tps_patched":114.1,"decode_speedup":0.9661,"decode_runs_baseline":[113.7,118.1,116.9,118.2,118.2],"decode_runs_patched":[116.8,117.1,114.1,110.1,110.8],"prefill_tps_baseline":340.4,"prefill_tps_patched":334.1,"prefill_change":-0.0185,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: patch(model, patterns=['moe_mlp']) 
capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t200_r5_repB_20260217_control_patterns_moe_mlp.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-17T23:03:10.249347+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":48,"decode_tps_baseline":116.7,"decode_tps_patched":112.8,"decode_speedup":0.9666,"decode_runs_baseline":[117.5,114.1,116.8,115.3,116.7],"decode_runs_patched":[112.7,108.7,112.8,116.1,115.6],"prefill_tps_baseline":336.1,"prefill_tps_patched":334.2,"prefill_change":-0.0057,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: patch(model, patterns=['moe_mlp']) capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t1024_r5_repA_20260217_control_patterns_moe_mlp.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-17T23:05:03.485411+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":109.8,"decode_tps_patched":104.4,"decode_speedup":0.9508,"decode_runs_baseline":[111.6,109.9,109.8,109.3,108.6],"decode_runs_patched":[103.5,104.4,103.8,108.3,110.6],"prefill_tps_baseline":323.9,"prefill_tps_patched":329.3,"prefill_change":0.0167,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen combine via ZMLX exact-dtype kernel (ZMLX_QWEN_COMBINE_MODE=exact) 
capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t1024_r5_repA_20260217_qwen_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-17T23:07:05.132366+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":108.5,"decode_tps_patched":108.9,"decode_speedup":1.0037,"decode_runs_baseline":[108.5,109.0,111.6,103.0,107.5],"decode_runs_patched":[108.9,91.9,84.7,111.0,109.9],"prefill_tps_baseline":329.0,"prefill_tps_patched":331.0,"prefill_change":0.0061,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen router fast path using argpartition on logits with top-k softmax (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1) capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t1024_r5_repA_20260217_qwen_router_argpartition_logits.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-17T23:08:57.638988+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":112.4,"decode_tps_patched":109.2,"decode_speedup":0.9715,"decode_runs_baseline":[110.1,112.4,109.5,112.6,113.3],"decode_runs_patched":[109.2,109.2,111.2,110.9,106.9],"prefill_tps_baseline":333.1,"prefill_tps_patched":334.1,"prefill_change":0.003,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 
36GB","notes":"Experimental: Qwen argpartition(logits) fused top-k softmax plus exact combine kernel (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1 ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS_TOPK=1 ZMLX_QWEN_COMBINE_MODE=exact) capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t1024_r5_repA_20260217_qwen_router_argpartition_logits_topk_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-17T23:10:53.361944+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":112.1,"decode_tps_patched":104.9,"decode_speedup":0.9358,"decode_runs_baseline":[110.0,112.3,112.1,112.4,105.4],"decode_runs_patched":[104.9,104.2,104.8,109.1,105.5],"prefill_tps_baseline":334.4,"prefill_tps_patched":326.1,"prefill_change":-0.0248,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen argpartition(logits) fused top-k softmax plus exact combine kernel (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1 ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS_TOPK=1 ZMLX_QWEN_COMBINE_MODE=exact) capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t1024_r5_repB_20260217_qwen_router_argpartition_logits_topk_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-17T23:13:38.847427+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens 
identical","modules_patched":48,"decode_tps_baseline":110.9,"decode_tps_patched":106.2,"decode_speedup":0.9576,"decode_runs_baseline":[112.3,111.0,110.9,109.7,109.0],"decode_runs_patched":[105.7,107.9,107.7,106.2,105.8],"prefill_tps_baseline":326.6,"prefill_tps_patched":327.6,"prefill_change":0.0031,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen router fast path using argpartition on logits with top-k softmax (ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1) capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t1024_r5_repB_20260217_qwen_router_argpartition_logits.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-17T23:15:34.487345+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":106.8,"decode_tps_patched":107.6,"decode_speedup":1.0075,"decode_runs_baseline":[110.2,106.8,106.7,109.8,106.4],"decode_runs_patched":[107.6,106.1,106.2,108.6,108.1],"prefill_tps_baseline":330.0,"prefill_tps_patched":329.8,"prefill_change":-0.0006,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Experimental: Qwen combine via ZMLX exact-dtype kernel (ZMLX_QWEN_COMBINE_MODE=exact) 
capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t1024_r5_repB_20260217_qwen_combine_exact.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-17T23:17:31.259383+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":110.1,"decode_tps_patched":107.4,"decode_speedup":0.9755,"decode_runs_baseline":[104.8,110.1,110.7,110.8,109.1],"decode_runs_patched":[107.4,106.6,100.1,108.0,107.4],"prefill_tps_baseline":332.8,"prefill_tps_patched":328.0,"prefill_change":-0.0144,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":["moe_mlp"],"hardware":"Apple M4 Max 36GB","notes":"Control: patch(model, patterns=['moe_mlp']) capsule=benchmarks/repro_capsules/qwen3_isolation_ordered_t1024_r5_repB_20260217_control_patterns_moe_mlp.json","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-17T23:19:25.346700+00:00","fidelity":"PASS","fidelity_detail":"1024/1024 tokens identical","modules_patched":48,"decode_tps_baseline":111.7,"decode_tps_patched":109.8,"decode_speedup":0.983,"decode_runs_baseline":[111.7,111.6,111.4,113.3,113.0],"decode_runs_patched":[109.8,110.3,110.6,109.5,106.6],"prefill_tps_baseline":329.2,"prefill_tps_patched":333.0,"prefill_change":0.0115,"peak_mem_baseline_gb":17.33,"peak_mem_patched_gb":17.33,"max_tokens":1024,"gen_tokens":1024,"runs":5} +{"model_id":"mlx-community/GLM-4.7-Flash-4bit-mxfp4","model_family":"glm","architecture":"moe","patterns_applied":[],"hardware":"Apple M4 Max 36GB","notes":"Isolation campaign control anchor (patterns 
none)","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-17T23:41:50.470511+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":0,"decode_tps_baseline":86.75,"decode_tps_patched":86.13,"decode_speedup":0.9928,"decode_runs_baseline":[86.75,87.02,84.02],"decode_runs_patched":[86.13,83.92,86.88],"prefill_tps_baseline":287.14,"prefill_tps_patched":287.86,"prefill_change":0.0025,"peak_mem_baseline_gb":15.98,"peak_mem_patched_gb":15.98,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/Qwen3-30B-A3B-4bit","model_family":"qwen","architecture":"moe","patterns_applied":[],"hardware":"Apple M4 Max 36GB","notes":"Isolation campaign control anchor (patterns none)","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-17T23:42:18.034882+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens identical","modules_patched":0,"decode_tps_baseline":118.32,"decode_tps_patched":117.94,"decode_speedup":0.9968,"decode_runs_baseline":[118.22,118.32,118.6],"decode_runs_patched":[117.94,118.35,113.08],"prefill_tps_baseline":341.09,"prefill_tps_patched":340.07,"prefill_change":-0.003,"peak_mem_baseline_gb":17.24,"peak_mem_patched_gb":17.24,"max_tokens":200,"gen_tokens":200,"runs":3} +{"model_id":"mlx-community/LFM2-8B-A1B-4bit","model_family":"lfm","architecture":"moe","patterns_applied":[],"hardware":"Apple M4 Max 36GB","notes":"Isolation campaign control anchor (patterns none)","macos_version":"26.1","mlx_version":"0.30.7.dev20260207+8fe1d092","zmlx_version":"0.8.5","zmlx_commit":"3824ede","python_version":"3.14.3","custom_mlx":true,"timestamp":"2026-02-17T23:43:20.631107+00:00","fidelity":"PASS","fidelity_detail":"200/200 tokens 
identical","modules_patched":0,"decode_tps_baseline":228.4,"decode_tps_patched":228.37,"decode_speedup":0.9999,"decode_runs_baseline":[227.74,228.75,228.4],"decode_runs_patched":[226.4,228.37,228.76],"prefill_tps_baseline":755.24,"prefill_tps_patched":754.08,"prefill_change":-0.0015,"peak_mem_baseline_gb":5.3,"peak_mem_patched_gb":5.3,"max_tokens":200,"gen_tokens":200,"runs":3} diff --git a/benchmarks/repro_capsules/lfm2_24b_dsimd_gate_m4max_20260224.json b/benchmarks/repro_capsules/lfm2_24b_dsimd_gate_m4max_20260224.json new file mode 100644 index 0000000..f0440b3 --- /dev/null +++ b/benchmarks/repro_capsules/lfm2_24b_dsimd_gate_m4max_20260224.json @@ -0,0 +1,58 @@ +{ + "experiment": "lfm2_24b_dsimd_gate_kernel", + "model": "LiquidAI/LFM2-24B-A2B-MLX-4bit", + "timestamp": "2026-02-24T10:29:17.816582", + "hardware": { + "chip": "Apple M4 Max", + "memory_gb": 36, + "os": "macOS-26.1-arm64-arm-64bit-Mach-O", + "python": "3.14.3 (main, Feb 3 2026, 15:32:20) [Clang 17.0.0 (clang-1700.6.3.2)]" + }, + "configuration": { + "gate_mode": "kernel (D-SIMD, 2 SIMD groups, ascending order)", + "combine_mode": "native", + "fused_swiglu": false, + "pattern": "moe_mlp", + "patched_modules": 38, + "kernel": "topk_softmax_bias_dsimd D=64 K=4 TG=64" + }, + "quick_gate_200tok_3runs": { + "baseline_decode_median": 152.6, + "patched_decode_median": 162.8, + "speedup": 1.067, + "fidelity": "200/200 PASS", + "baseline_prompt_median": 491.9, + "patched_prompt_median": 499.6, + "peak_mem_gb": 13.45 + }, + "confirmation_500tok_5runs": { + "baseline_decode_median": 152.0, + "patched_decode_median": 161.1, + "speedup": 1.06, + "fidelity": "500/500 PASS", + "baseline_prompt_median": 490.7, + "patched_prompt_median": 501.8, + "peak_mem_gb": 13.45, + "baseline_runs": [ + 151.0, + 152.4, + 150.1, + 152.0, + 152.7 + ], + "patched_runs": [ + 163.7, + 161.1, + 160.6, + 162.8, + 160.6 + ] + }, + "key_findings": [ + "D-SIMD kernel fuses softmax+bias+topK into 1 Metal dispatch (replaces ~10 
dispatches)", + "Uses 2 SIMD groups (64 threads) with cross-group reductions (only 4K barriers vs 24+ in TG=256)", + "Ascending value order output matches MLX argpartition ordering (critical for fidelity)", + "gather_qmm_swiglu HURTS LFM2-24B (K=4): ~0.77x regression + fidelity failure", + "Fused SwiGLU designed for K=2 (LFM2-8B); K=4 experts overwhelm the fused dispatch" + ] +} \ No newline at end of file diff --git a/configs/qwen3_1p7b_kernel_sft_lora.yaml b/configs/qwen3_1p7b_kernel_sft_lora.yaml new file mode 100644 index 0000000..37daa83 --- /dev/null +++ b/configs/qwen3_1p7b_kernel_sft_lora.yaml @@ -0,0 +1,24 @@ +model: mlx-community/Qwen3-1.7B-4bit +train: true +fine_tune_type: lora +optimizer: adamw +data: training_data/kernel_sft_qwen3_1p7b +mask_prompt: true +num_layers: 12 +batch_size: 1 +iters: 6000 +val_batches: 64 +learning_rate: 5.0e-5 +steps_per_report: 20 +steps_per_eval: 200 +grad_accumulation_steps: 8 +save_every: 200 +test: true +test_batches: 64 +max_seq_length: 1024 +grad_checkpoint: true +seed: 42 +lora_parameters: + rank: 16 + dropout: 0.05 + scale: 32.0 diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md index 590d649..7289c5f 100644 --- a/docs/ARCHITECTURE.md +++ b/docs/ARCHITECTURE.md @@ -101,4 +101,4 @@ Good upstream candidates: - a tiny threadgroup heuristic helper - minimal correctness tests around `metal_kernel` -See `docs/UPSTREAMING.md`. +See [`UPSTREAM_PLAN.md`](../UPSTREAM_PLAN.md). diff --git a/docs/BENCHMARKS.md b/docs/BENCHMARKS.md index 23b3557..085389e 100644 --- a/docs/BENCHMARKS.md +++ b/docs/BENCHMARKS.md @@ -9,6 +9,35 @@ All benchmarks use `python -m zmlx.validate` with the following defaults: - **Patched:** `patch(model)` with default settings (stable mode) - **Fidelity:** every generated token ID compared between patched and unpatched +## Current large-model protocol (2026-02-13+) + +For GLM-4.7-Flash and Qwen3-30B-A3B, do not run multi-variant sweeps in a +single Python process. 
Those runs can accumulate allocator pressure and run out of memory (OOM).
**Export** -- NDJSON training data and chat-style SFT JSONL for kernel-writing fine-tuning. + +## CLI + +```bash +python -m zmlx.foundry run # generate kernel dataset +python -m zmlx.foundry run --ops rmsnorm swiglu --n 500 --workers 4 +python -m zmlx.foundry report sessions/my_run # coverage + pareto reports +python -m zmlx.foundry export sessions/my_run --out training_data/ +python -m zmlx.foundry export-sft --out training_data/kernel_sft +python -m zmlx.foundry list # list registered ops +``` + +## SFT export workflow + +The `export-sft` command produces chat-format JSONL suitable for `mlx_lm.lora` fine-tuning. It ingests from four sources: foundry sessions, KD runs, discover sessions, and raw templates. Output includes train/valid/test splits with deduplication and a manifest JSON. + +Example end-to-end training: + +```bash +# 1. Export SFT dataset +python -m zmlx.foundry export-sft --out training_data/kernel_sft + +# 2. Run LoRA fine-tuning +python -m mlx_lm.lora \ + --model mlx-community/Qwen3-1.7B-4bit \ + --data training_data/kernel_sft \ + --train \ + --config configs/qwen3_1p7b_kernel_sft_lora.yaml +``` + +See `configs/qwen3_1p7b_kernel_sft_lora.yaml` for the training configuration. 
+ +## Module layout + +``` +src/zmlx/foundry/ + __init__.py # module docstring, __all__ + __main__.py # CLI (run, report, export, export-sft, list) + taxonomy.py # foundation types (KernelClass, OpSpec, KernelCandidate) + ids.py # stable identifiers for attempts and cache keys + ndjson.py # append-only NDJSON logging + scheduler.py # curriculum-based staged op unlocking + session.py # session directory management + workers.py # multi-process worker orchestration + ops/ # op registry (16 ops) + templates/ # Metal template discovery and rendering + harness/ # compile, correctness, benchmark orchestrator + sampling/ # candidate generation strategies + plugins/ # protocol-based extensibility + reports/ # coverage analysis and Pareto extraction + export/ # training JSONL + chat SFT export +``` + +## Adding a new op + +1. Create a template `.metal` file in `templates/`. +2. Register the op in `ops/` with its `OpSpec` (kernel class, reference function, shapes, dtypes). +3. Run `python -m zmlx.foundry run --ops your_op --n 100` to verify the harness evaluates it correctly. +4. Run `python -m zmlx.foundry report` to check coverage. diff --git a/docs/ROADMAP.md b/docs/ROADMAP.md index cff7599..d125520 100644 --- a/docs/ROADMAP.md +++ b/docs/ROADMAP.md @@ -372,7 +372,7 @@ Items 1, 2, 3, 5, and 6 can be developed in parallel. Item 4 feeds into Item 7. 
- [MLX Documentation](https://ml-explore.github.io/mlx/build/html/index.html) - [MLX Custom Metal Kernels](https://ml-explore.github.io/mlx/build/html/dev/custom_metal_kernels.html) - [MLX GitHub Issues](https://github.com/ml-explore/mlx/issues) -- [ZMLX Benchmarks](../benchmarks/results/TEST_SUMMARY.md) +- [ZMLX Benchmarks](BENCHMARKS.md) --- diff --git a/pyproject.toml b/pyproject.toml index a469552..fe45c90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "hatchling.build" [project] name = "zmlx" version = "0.8.5" -description = "ZMLX: Metal-kernel toolkit and optimization lab for MLX on Apple Silicon. Fused MoE decode (+5-12% on LFM2-8B-A1B), custom GPU kernels in one line, 70+ kernel catalog." +description = "ZMLX: Metal-kernel toolkit and optimization lab for MLX on Apple Silicon. Fused MoE decode (+6-12% on LFM2 8B/24B), custom GPU kernels in one line, 70+ kernel catalog." readme = "README.md" requires-python = ">=3.10" license = { text = "MIT" } @@ -78,3 +78,7 @@ disallow_untyped_defs = false no_implicit_optional = true check_untyped_defs = true ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = ["zmlx.foundry.*", "zmlx.fusion.*"] +ignore_errors = true diff --git a/src/zmlx/__init__.py b/src/zmlx/__init__.py index 38fd643..f3bcf2b 100644 --- a/src/zmlx/__init__.py +++ b/src/zmlx/__init__.py @@ -53,7 +53,7 @@ def __getattr__(name: str): # Submodules (includes elementwise as a module for backward compat) if name in { "autograd", "elementwise", "kernels", "metal", "registry", "rowwise", "msl", - "patch", "codegen", "autotune", "optimizers", "nn", "kv_cache", + "patch", "codegen", "autotune", "optimizers", "nn", "kv_cache", "fusion", "testing", "bench", "profile", "device_profile", "kd", "kb", }: import importlib @@ -95,6 +95,7 @@ def __dir__() -> list[str]: "nn", "device_profile", "kv_cache", + "fusion", # Backend compatibility "_compat", # Kernel authoring (from zmlx.api) diff --git 
a/src/zmlx/foundry/__init__.py b/src/zmlx/foundry/__init__.py new file mode 100644 index 0000000..0141c60 --- /dev/null +++ b/src/zmlx/foundry/__init__.py @@ -0,0 +1,53 @@ +"""ZMLX Foundry -- kernel template system and evaluation harness. + +The foundry unifies four worktree repos (Discover, Lab, Foundry, DataFoundry) +into a single module for Metal kernel dataset generation, evaluation, and +training data export. + +Submodules: + taxonomy -- Foundation types (KernelClass, OpSpec, KernelCandidate, etc.) + ids -- Stable identifiers for attempts and cache keys + ndjson -- Append-only NDJSON logging with crash tolerance + ops -- Op registry (16 ops across 9 kernel classes) + templates -- Metal template discovery, loading, and rendering + harness -- Compile, correctness-check, and benchmark orchestrator + sampling -- Random/coverage/mutation/CEM candidate sampling + session -- Session directory management and NDJSON logging + scheduler -- Curriculum-based staged op unlocking + plugins -- Protocol-based extensibility (entry-point and local) + reports -- Coverage analysis and Pareto extraction + export -- Training JSONL export + workers -- Multi-process worker orchestration + +CLI: + python -m zmlx.foundry run # generate kernel dataset + python -m zmlx.foundry report # coverage + pareto reports + python -m zmlx.foundry export # export training JSONL + python -m zmlx.foundry list # list registered ops +""" +from __future__ import annotations + +__all__ = [ + # Foundation + "taxonomy", + "ids", + "ndjson", + # Ops + "ops", + # Templates + "templates", + # Harness + "harness", + # Sampling + "sampling", + # Session & scheduling + "session", + "scheduler", + # Plugins + "plugins", + # Reports & export + "reports", + "export", + # Workers + "workers", +] diff --git a/src/zmlx/foundry/__main__.py b/src/zmlx/foundry/__main__.py new file mode 100644 index 0000000..930585a --- /dev/null +++ b/src/zmlx/foundry/__main__.py @@ -0,0 +1,302 @@ +"""CLI entry point for the ZMLX foundry. 
def _cmd_run(args: argparse.Namespace) -> None:
    """Run a foundry evaluation session (``python -m zmlx.foundry run``)."""
    from .ops import get_registry
    from .scheduler import flatten_curriculum
    from .workers import spawn_workers

    available = get_registry()

    # Resolve the op list: explicit --ops (validated against the registry),
    # otherwise the full flattened curriculum.
    if not args.ops:
        selected = flatten_curriculum()
    else:
        selected = args.ops
        bad = [name for name in selected if name not in available]
        if bad:
            print(f"Unknown ops: {', '.join(bad)}", file=sys.stderr)
            print(f"Available: {', '.join(sorted(available))}", file=sys.stderr)
            sys.exit(1)

    # Session directory: honor --session-dir, else timestamp a fresh one.
    session_dir = args.session_dir or str(
        Path("sessions") / f"foundry_{time.strftime('%Y%m%d_%H%M%S')}"
    )

    for banner in (
        f"Session: {session_dir}",
        f"Ops: {', '.join(selected)}",
        f"Attempts: {args.n}, Workers: {args.workers}, Mode: {args.mode}",
        f"Backend: {args.backend}, Stage: {args.stage}",
        "",
    ):
        print(banner)

    result = spawn_workers(
        session_dir=session_dir,
        ops=selected,
        n_attempts=args.n,
        num_workers=args.workers,
        mode=args.mode,
        seed=args.seed,
        backend=args.backend,
        correctness_tests=args.correctness_tests,
        warmup=args.warmup,
        repeats=args.repeats,
        bench_timeout_s=args.bench_timeout,
        stage=args.stage,
    )

    print(f"\nDone in {result['total_elapsed_s']:.1f}s")
    for worker in result["workers"]:
        line = (
            f"  worker {worker['worker_id']}: "
            f"{worker['n_written']} written, "
            f"{worker['n_skipped']} skipped, "
            f"{worker['elapsed_s']:.1f}s"
        )
        print(line)
    print(f"Merged log: {result['merged_path']}")


def _cmd_report(args: argparse.Namespace) -> None:
    """Build coverage + pareto reports for a finished session directory."""
    from .reports import build_coverage, extract_pareto_front, write_coverage_reports
    from .reports.pareto import load_attempts

    session_dir = Path(args.session_dir)
    if not session_dir.exists():
        print(f"Session directory not found: {session_dir}", file=sys.stderr)
        sys.exit(1)

    md_path, json_path = write_coverage_reports(session_dir)
    print(f"Coverage report: {md_path}")
    print(f"Coverage data: {json_path}")

    # Headline summary: ops evaluated, total attempts, pass rate.
    coverage = build_coverage(session_dir)
    per_op = coverage["ops"]
    n_ops = len(per_op)
    total = sum(stats["attempts"] for stats in per_op.values())
    total_ok = sum(stats["ok"] for stats in per_op.values())
    rate = total_ok / total * 100 if total > 0 else 0
    print(f"\n{n_ops} ops, {total} attempts, {total_ok} passed ({rate:.1f}%)")

    for op, stats in sorted(per_op.items()):
        best = stats["best_template_latency_ms"]
        best_s = f"{best:.4f} ms" if isinstance(best, (int, float)) else "-"
        print(f"  {op:30s} {stats['ok']:4d}/{stats['attempts']:<4d} best: {best_s}")

    # Pareto front over benchmarked attempts (top five shown).
    pareto = extract_pareto_front(load_attempts(str(session_dir)))
    if not pareto:
        print("\nNo pareto front (no successful benchmarked attempts)")
    else:
        print(f"\nPareto front: {len(pareto)} kernels")
        for p in pareto[:5]:
            print(f"  {p['op']}/{p.get('template_id', '?')}: {p.get('p50_ms', '?')} ms")
def _cmd_export_sft(args: argparse.Namespace) -> None:
    """Export chat-style SFT JSONL from foundry/discovery artifacts."""
    from .export.sft import export_kernel_sft_jsonl

    # The exporter returns the manifest it also writes to disk.
    manifest = export_kernel_sft_jsonl(
        out_dir=args.out,
        sessions_root=args.sessions_root,
        runs_root=args.runs_root,
        discover_root=args.discover_root,
        train_fraction=args.train_fraction,
        valid_fraction=args.valid_fraction,
        seed=args.seed,
        max_examples=args.max_examples,
        include_failed_kd=args.include_failed_kd,
    )

    c = manifest["counts"]
    summary = (
        "Exported SFT dataset "
        f"(total={c['total']}, train={c['train']}, "
        f"valid={c['valid']}, test={c['test']})"
    )
    print(summary)
    print(f"Manifest: {args.out}/manifest.json")
def _build_parser() -> argparse.ArgumentParser:
    """Construct the top-level CLI parser with all subcommands."""
    parser = argparse.ArgumentParser(
        prog="zmlx-foundry",
        description="ZMLX Foundry -- Metal kernel template evaluation and dataset generation",
    )
    sub = parser.add_subparsers(dest="command")

    # -- run ---------------------------------------------------------------
    p_run = sub.add_parser("run", help="Run a foundry evaluation session")
    p_run.add_argument("--ops", nargs="+", help="Ops to evaluate (default: all curriculum)")
    p_run.add_argument("-n", type=int, default=100,
                       help="Number of attempts per op (default: 100)")
    p_run.add_argument("--workers", type=int, default=1, help="Number of parallel workers")
    p_run.add_argument("--mode", default="mix",
                       choices=["random", "coverage", "mutation", "mix"],
                       help="Sampling mode (default: mix)")
    p_run.add_argument("--seed", type=int, default=42, help="Base random seed")
    p_run.add_argument("--backend", default="mlx", choices=["mlx", "mock"],
                       help="Evaluation backend (default: mlx)")
    p_run.add_argument("--stage", type=int, default=4,
                       help="Curriculum stage (0-4, default: 4 = all ops)")
    p_run.add_argument("--session-dir", help="Session directory (default: auto-generated)")
    p_run.add_argument("--correctness-tests", type=int, default=3)
    p_run.add_argument("--warmup", type=int, default=10)
    p_run.add_argument("--repeats", type=int, default=50)
    p_run.add_argument("--bench-timeout", type=float, default=10.0)

    # -- report ------------------------------------------------------------
    p_report = sub.add_parser("report", help="Generate coverage/pareto reports")
    p_report.add_argument("session_dir", help="Session directory to analyze")

    # -- export ------------------------------------------------------------
    p_export = sub.add_parser("export", help="Export training JSONL from a session")
    p_export.add_argument("session_dir", help="Session directory")
    p_export.add_argument("--out", default="training_data", help="Output directory")
    p_export.add_argument("--ops", nargs="+", help="Filter to specific ops")
    p_export.add_argument("--min-p50", type=float, help="Min p50 latency filter (ms)")
    p_export.add_argument("--max-p50", type=float, help="Max p50 latency filter (ms)")

    # -- export-sft --------------------------------------------------------
    p_sft = sub.add_parser(
        "export-sft",
        help="Export chat SFT JSONL from sessions/runs/discover artifacts",
    )
    p_sft.add_argument("--out", default="training_data/kernel_sft",
                       help="Output directory for train/valid/test JSONL")
    p_sft.add_argument("--sessions-root", default="sessions",
                       help="Root directory containing foundry session attempts")
    p_sft.add_argument("--runs-root", default="runs",
                       help="Root directory containing KD run logs")
    p_sft.add_argument("--discover-root", default="discover_sessions",
                       help="Root directory containing discover session JSON files")
    p_sft.add_argument("--train-fraction", type=float, default=0.96,
                       help="Fraction of examples assigned to train split")
    p_sft.add_argument("--valid-fraction", type=float, default=0.03,
                       help="Fraction of examples assigned to valid split")
    p_sft.add_argument("--seed", type=int, default=42, help="Shuffle seed")
    p_sft.add_argument("--max-examples", type=int,
                       help="Optional cap on total examples after shuffle")
    p_sft.add_argument("--include-failed-kd", action="store_true",
                       help="Include KD failed candidates as negative examples")

    # -- list --------------------------------------------------------------
    sub.add_parser("list", help="List registered ops and their metadata")
    return parser


def main() -> None:
    """CLI entry point: parse arguments and dispatch to the subcommand handler."""
    parser = _build_parser()
    args = parser.parse_args()

    # No subcommand given: show usage and exit cleanly.
    if args.command is None:
        parser.print_help()
        sys.exit(0)

    handlers = {
        "run": _cmd_run,
        "report": _cmd_report,
        "export": _cmd_export,
        "export-sft": _cmd_export_sft,
        "list": _cmd_list,
    }
    handlers[args.command](args)


if __name__ == "__main__":
    main()
+ + Output files: + - ``train.jsonl`` + - ``valid.jsonl`` + - ``test.jsonl`` + - ``manifest.json`` + """ + if not (0.0 < train_fraction < 1.0): + raise ValueError("train_fraction must be in (0, 1)") + if not (0.0 <= valid_fraction < 1.0): + raise ValueError("valid_fraction must be in [0, 1)") + if train_fraction + valid_fraction >= 1.0: + raise ValueError("train_fraction + valid_fraction must be < 1") + + examples: list[dict[str, Any]] = [] + source_counts = { + "foundry_attempts": 0, + "kd_runs": 0, + "discover_sessions": 0, + "templates": 0, + } + dedup: set[str] = set() + + out_path = Path(out_dir) + out_path.mkdir(parents=True, exist_ok=True) + + def _append(example: dict[str, Any], source_key: str) -> None: + key = _example_key(example) + if key in dedup: + return + dedup.add(key) + examples.append(example) + source_counts[source_key] += 1 + + for ex in _iter_foundry_examples(Path(sessions_root)): + _append(ex, "foundry_attempts") + + for ex in _iter_kd_examples(Path(runs_root), include_failed=include_failed_kd): + _append(ex, "kd_runs") + + for ex in _iter_discover_examples(Path(discover_root)): + _append(ex, "discover_sessions") + + for ex in _iter_template_examples(): + _append(ex, "templates") + + rng = random.Random(seed) + rng.shuffle(examples) + if max_examples is not None: + examples = examples[: max(0, int(max_examples))] + + train_set, valid_set, test_set = _split_examples( + examples, + train_fraction=train_fraction, + valid_fraction=valid_fraction, + ) + + train_path = out_path / "train.jsonl" + valid_path = out_path / "valid.jsonl" + test_path = out_path / "test.jsonl" + + _write_jsonl(train_path, train_set) + _write_jsonl(valid_path, valid_set) + _write_jsonl(test_path, test_set) + + manifest = { + "schema_version": "1.0", + "seed": seed, + "counts": { + "total": len(examples), + "train": len(train_set), + "valid": len(valid_set), + "test": len(test_set), + }, + "source_counts": source_counts, + "paths": { + "train": str(train_path), + "valid": 
def _iter_foundry_examples(sessions_root: Path) -> list[dict[str, Any]]:
    """Harvest chat examples from successful foundry attempt logs.

    Walks every ``*.ndjson`` attempt log under *sessions_root*, keeps only
    successful attempts whose rendered Metal source can still be read from
    disk, and turns each into one prompt/completion chat example.
    """
    if not sessions_root.exists():
        return []

    collected: list[dict[str, Any]] = []
    for log_path in sorted(sessions_root.rglob("*.ndjson")):
        if not _is_foundry_attempt_log(log_path):
            continue
        session_dir = _attempt_session_dir(log_path)
        for rec in _iter_jsonl(log_path):
            if not _foundry_success(rec):
                continue
            kernel = rec.get("kernel", {})
            source_rel = str(kernel.get("source_metal", "")).strip()
            if not source_rel:
                continue
            source = _read_source_from_attempt(session_dir, log_path.parent, source_rel)
            if not source:
                continue

            prompt = _build_foundry_prompt(rec)
            meta = {
                "source_type": "foundry_attempt",
                "session": session_dir.name,
                "op": rec.get("op"),
                "dtype": rec.get("dtype"),
                "shape": rec.get("shape", {}),
                "template_id": kernel.get("template_id"),
                "knobs": kernel.get("knobs", {}),
                "latency_ms_p50": rec.get("bench", {}).get("latency_ms", {}).get("p50"),
            }
            collected.append(_chat_example(prompt, source, meta))
    return collected


def _iter_kd_examples(runs_root: Path, *, include_failed: bool) -> list[dict[str, Any]]:
    """Harvest chat examples from KD discovery run logs (``runs/*/run.ndjson``).

    By default only candidates that reached ``benchmarked`` status without a
    recorded failure are kept; *include_failed* admits every record that has
    Metal source (useful as negative examples).
    """
    if not runs_root.exists():
        return []

    collected: list[dict[str, Any]] = []
    for run_path in sorted(runs_root.rglob("run.ndjson")):
        for rec in _iter_jsonl(run_path):
            source = str(rec.get("metal_source", "")).strip()
            if not source:
                continue

            status = rec.get("status")
            if not include_failed:
                # Keep only clean, fully benchmarked candidates.
                if status != "benchmarked" or rec.get("metrics", {}).get("failure"):
                    continue

            prompt = _build_kd_prompt(rec)
            meta = {
                "source_type": "kd_run",
                "run_dir": run_path.parent.name,
                "op": rec.get("op_name"),
                "status": status,
                "candidate_id": rec.get("candidate_id"),
                "speedup_vs_ref": rec.get("metrics", {}).get("speedup_vs_ref"),
            }
            collected.append(_chat_example(prompt, source, meta))
    return collected
+ ) + meta = { + "source_type": "discover_session", + "session_file": str(session_path), + "target": target_name, + "best_speedup": metadata.get("best_speedup"), + } + examples.append(_chat_example(prompt, best_source, meta)) + + return examples + + +def _iter_template_examples() -> list[dict[str, Any]]: + examples: list[dict[str, Any]] = [] + here = Path(__file__).resolve() + zmlx_root = here.parents[2] + foundry_templates = here.parents[1] / "templates" + kd_templates = zmlx_root / "kd" / "templates" + + for tpl_path in sorted(foundry_templates.rglob("*.metal")): + source = tpl_path.read_text(encoding="utf-8") + op = tpl_path.parent.name + prompt = ( + "Write a parameterized Metal template for op " + f"`{op}` using ``{{PLACEHOLDER}}`` markers where needed." + ) + meta = { + "source_type": "template", + "family": "foundry", + "op": op, + "template_file": str(tpl_path), + } + examples.append(_chat_example(prompt, source, meta)) + + for tpl_path in sorted(kd_templates.rglob("*.tmpl")): + source = tpl_path.read_text(encoding="utf-8") + op = tpl_path.stem.split(".")[0] + prompt = ( + "Write a Metal kernel body template for discovery op " + f"`{op}` using explicit compile-time constants." 
+ ) + meta = { + "source_type": "template", + "family": "kd", + "op": op, + "template_file": str(tpl_path), + } + examples.append(_chat_example(prompt, source, meta)) + + return examples + + +def _iter_tree_nodes(root: dict[str, Any]) -> list[dict[str, Any]]: + out: list[dict[str, Any]] = [] + stack = [root] + while stack: + node = stack.pop() + out.append(node) + for child in node.get("children", []): + if isinstance(child, dict): + stack.append(child) + return out + + +def _chat_example(prompt: str, completion: str, metadata: dict[str, Any]) -> dict[str, Any]: + return { + "messages": [ + {"role": "user", "content": prompt.strip()}, + {"role": "assistant", "content": completion.rstrip()}, + ], + "metadata": metadata, + } + + +def _split_examples( + examples: list[dict[str, Any]], + *, + train_fraction: float, + valid_fraction: float, +) -> tuple[list[dict[str, Any]], list[dict[str, Any]], list[dict[str, Any]]]: + n = len(examples) + if n == 0: + return [], [], [] + + n_train = int(n * train_fraction) + n_valid = int(n * valid_fraction) + + if n >= 3: + n_train = max(1, min(n_train, n - 2)) + n_valid = max(1, min(n_valid, n - n_train - 1)) + else: + n_train = max(1, n - 1) + n_valid = 0 + + n_test = n - n_train - n_valid + if n_test < 0: + n_test = 0 + n_valid = max(0, n - n_train) + + train = examples[:n_train] + valid = examples[n_train : n_train + n_valid] + test = examples[n_train + n_valid :] + return train, valid, test + + +def _example_key(example: dict[str, Any]) -> str: + payload = json.dumps(example, sort_keys=True, ensure_ascii=False) + return hashlib.sha256(payload.encode("utf-8")).hexdigest() + + +def _write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None: + with path.open("w", encoding="utf-8") as f: + for row in rows: + f.write( + json.dumps( + row, + sort_keys=True, + separators=(",", ":"), + ensure_ascii=False, + ) + ) + f.write("\n") + + +def _iter_jsonl(path: Path) -> list[dict[str, Any]]: + out: list[dict[str, Any]] = [] + for line in 
path.read_text(encoding="utf-8").splitlines(): + line = line.strip() + if not line: + continue + try: + out.append(json.loads(line)) + except json.JSONDecodeError: + continue + return out + + +def _is_foundry_attempt_log(path: Path) -> bool: + if path.name == "attempts.ndjson": + return True + if path.name.startswith("attempts.worker"): + return True + return path.parent.name == "attempts" and path.suffix == ".ndjson" + + +def _attempt_session_dir(log_path: Path) -> Path: + if log_path.parent.name == "attempts": + return log_path.parent.parent + return log_path.parent + + +def _foundry_success(rec: dict[str, Any]) -> bool: + return bool( + rec.get("build", {}).get("ok") + and rec.get("correctness", {}).get("ok") + and rec.get("bench", {}).get("ok") + ) + + +def _read_source_from_attempt(session_dir: Path, log_parent: Path, source_rel: str) -> str: + candidates = [ + session_dir / source_rel, + log_parent / source_rel, + ] + for source_path in candidates: + if source_path.exists(): + return source_path.read_text(encoding="utf-8") + return "" + + +def _build_foundry_prompt(rec: dict[str, Any]) -> str: + op = rec.get("op") + dtype = rec.get("dtype") + shape = rec.get("shape", {}) + spec = rec.get("spec", {}) + kernel = rec.get("kernel", {}) + + return ( + "Write a Metal kernel body for this MLX operation.\n" + f"op: {op}\n" + f"dtype: {dtype}\n" + f"shape: {_compact(shape)}\n" + f"math: {spec.get('math', '')}\n" + f"constraints: {_compact(spec.get('constraints', {}))}\n" + f"template_id: {kernel.get('template_id')}\n" + f"knobs: {_compact(kernel.get('knobs', {}))}\n" + "Return only Metal source code." 
+ ) + + +def _build_kd_prompt(rec: dict[str, Any]) -> str: + op = rec.get("op_name") + inputs_spec = rec.get("inputs_spec", []) + outputs_spec = rec.get("outputs_spec", []) + launch = rec.get("launch_params", {}) + params = rec.get("template_params", {}) + notes = rec.get("notes", {}) + + return ( + "Write a Metal kernel body for this discovery target.\n" + f"op: {op}\n" + f"inputs: {_compact(inputs_spec)}\n" + f"outputs: {_compact(outputs_spec)}\n" + f"template_params: {_compact(params)}\n" + f"launch: {_compact(launch)}\n" + f"shape_signature: {_compact(notes.get('shape_signature', {}))}\n" + "Return only Metal source code." + ) + + +def _build_discover_prompt( + target_name: str, + node: dict[str, Any], + metadata: dict[str, Any], +) -> str: + strategy = node.get("candidate", {}).get("llm_reasoning", "") + eval_result = node.get("eval_result", {}) + return ( + "Write an optimized Metal kernel body for this discover target.\n" + f"target: {target_name}\n" + f"device_chip: {metadata.get('device_chip', 'unknown')}\n" + f"strategy_hint: {strategy}\n" + f"observed_speedup: {eval_result.get('speedup')}\n" + "Return only Metal source code." + ) + + +def _compact(payload: Any) -> str: + return json.dumps(payload, sort_keys=True, separators=(",", ":"), ensure_ascii=False) + diff --git a/src/zmlx/foundry/export/training.py b/src/zmlx/foundry/export/training.py new file mode 100644 index 0000000..7a9d9b8 --- /dev/null +++ b/src/zmlx/foundry/export/training.py @@ -0,0 +1,184 @@ +"""JSONL training data export. + +Filters a session's attempts down to successful, correct, benchmarked +records and exports them as JSONL with the kernel source inlined. +This is the primary format consumed by downstream training pipelines. + +Adapted from mlx-kernel-lab's ``export/training_export.py``. 
def export_training_jsonl(
    session_dir: str,
    out_dir: str,
    *,
    ops: list[str] | None = None,
    min_p50_ms: float | None = None,
    max_p50_ms: float | None = None,
) -> dict[str, Any]:
    """Export successful attempts as JSONL for training.

    Each output record carries the attempt id, the problem specification
    (``op``/``dtype``/``shape``/``spec``), the kernel configuration
    (``template_id``/``knobs``), the inlined Metal ``source`` (read from
    ``kernels/<id>.metal`` when present), and the ``bench`` metrics.

    Parameters
    ----------
    session_dir : str
        Path to the session directory.
    out_dir : str
        Output directory for the JSONL file.
    ops : list of str, optional
        If given, only export attempts for these ops.
    min_p50_ms : float, optional
        Exclude attempts faster than this (filter out measurement noise).
    max_p50_ms : float, optional
        Exclude attempts slower than this (filter out regressions).

    Returns
    -------
    dict
        ``{"out_path": str, "n_records": int, "n_skipped": int}``.
    """
    session_path = Path(session_dir)
    target_dir = Path(out_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    jsonl_path = target_dir / f"{session_path.name}_training.jsonl"
    kernels_dir = session_path / "kernels"

    written = 0
    skipped = 0

    with jsonl_path.open("w", encoding="utf-8") as fh:
        # Records come from every layout convention the session may use.
        for rec in _collect_records(session_path):
            rec_op = rec.get("op")

            # Filters: op allow-list, then success, then latency window.
            keep = True
            if ops is not None and rec_op not in ops:
                keep = False
            elif not _is_successful(rec):
                keep = False
            else:
                p50 = _get_p50(rec)
                if p50 is not None and (
                    (min_p50_ms is not None and p50 < min_p50_ms)
                    or (max_p50_ms is not None and p50 > max_p50_ms)
                ):
                    keep = False
            if not keep:
                skipped += 1
                continue

            # Attempt id: field name differs between layouts.
            aid = rec.get("id") or rec.get("attempt_id", "")

            # Inline the kernel source when the session saved it to disk.
            source = ""
            if aid:
                src_file = kernels_dir / f"{aid}.metal"
                if src_file.exists():
                    source = src_file.read_text(encoding="utf-8")

            row: dict[str, Any] = {
                "id": aid,
                "op": rec_op,
                "dtype": rec.get("dtype"),
                "shape": rec.get("shape"),
                "spec": rec.get("spec") or rec.get("op_params"),
                "template_id": (
                    rec.get("template_id")
                    or rec.get("kernel", {}).get("template_id")
                ),
                "knobs": (
                    rec.get("knobs")
                    or rec.get("kernel", {}).get("knobs")
                ),
                "source": source,
                "bench": rec.get("bench") or rec.get("result", {}),
            }
            fh.write(json.dumps(row, sort_keys=True, separators=(",", ":"), ensure_ascii=False))
            fh.write("\n")
            written += 1

    return {
        "out_path": str(jsonl_path),
        "n_records": written,
        "n_skipped": skipped,
    }
_collect_records(session_dir: Path) -> list[dict[str, Any]]: + """Collect all attempt records from a session (all layout conventions).""" + records: list[dict[str, Any]] = [] + + # Single-file layout + single = session_dir / "attempts.ndjson" + if single.exists(): + records.extend(iter_records(single)) + + # Worker shards + for wp in sorted(session_dir.glob("attempts.worker*.ndjson")): + records.extend(iter_records(wp)) + + # Sub-directory layout + attempts_dir = session_dir / "attempts" + if attempts_dir.is_dir(): + for fp in sorted(attempts_dir.glob("*.ndjson")): + records.extend(iter_records(fp)) + + return records + + +def _is_successful(rec: dict[str, Any]) -> bool: + """Check whether a record represents a successful attempt.""" + # DataFoundry layout + if ( + rec.get("build", {}).get("ok") + and rec.get("correctness", {}).get("ok") + and rec.get("bench", {}).get("ok") + ): + return True + # Foundry layout + if rec.get("result", {}).get("status") == "ok": + return True + return False + + +def _get_p50(rec: dict[str, Any]) -> float | None: + """Extract p50 latency from a record.""" + lat = rec.get("bench", {}).get("latency_ms") + if isinstance(lat, dict): + p50 = lat.get("p50") + if p50 is not None: + return float(p50) + rlat = rec.get("result", {}).get("template_latency_ms") + if isinstance(rlat, (int, float)): + return float(rlat) + return None diff --git a/src/zmlx/foundry/harness/__init__.py b/src/zmlx/foundry/harness/__init__.py new file mode 100644 index 0000000..e1dd239 --- /dev/null +++ b/src/zmlx/foundry/harness/__init__.py @@ -0,0 +1,20 @@ +"""Evaluation harness for ZMLX foundry kernel candidates. + +Compile, correctness-check, and benchmark Metal kernels against a +reference implementation using a generic ``KernelOp`` protocol. 
+ +Public API: + evaluate_attempt -- main orchestrator (compile + correctness + bench) + MLXBackend -- real Metal backend (requires Apple Silicon + mlx) + MockBackend -- in-process mock for CI / non-GPU environments +""" +from __future__ import annotations + +from .backend import MLXBackend, MockBackend +from .evaluate import evaluate_attempt + +__all__ = [ + "evaluate_attempt", + "MLXBackend", + "MockBackend", +] diff --git a/src/zmlx/foundry/harness/backend.py b/src/zmlx/foundry/harness/backend.py new file mode 100644 index 0000000..dce706a --- /dev/null +++ b/src/zmlx/foundry/harness/backend.py @@ -0,0 +1,219 @@ +"""MLX and Mock backends for the evaluation harness. + +Adapted from DataFoundry's harness/backend.py. The ``maybe_quantize`` +helper is inlined here (instead of importing from datafoundry.ops.common) +to keep the foundry self-contained. + +Both backends satisfy the ``Backend`` protocol in ``taxonomy.py``. +""" +from __future__ import annotations + +import platform +from dataclasses import dataclass +from typing import Any + +import numpy as np + +# --------------------------------------------------------------------------- +# Inlined maybe_quantize (was: from datafoundry.ops.common import maybe_quantize) +# --------------------------------------------------------------------------- + + +def _quantize_to_bfloat16(x: np.ndarray) -> np.ndarray: + """Simulate bfloat16 by truncating the lower 16 bits of float32 mantissa.""" + x32 = x.astype(np.float32, copy=False) + u = x32.view(np.uint32) + u_trunc = u & np.uint32(0xFFFF0000) + return u_trunc.view(np.float32) + + +def maybe_quantize(x: np.ndarray, dtype: str) -> np.ndarray: + """Cast *x* to the precision level of *dtype*. + + * ``float16`` / ``float32`` -- straightforward ``astype``. + * ``bfloat16`` -- numpy has no native bf16 so we truncate float32 mantissa + bits to simulate the reduced precision. 
+ """ + if dtype == "float16": + return x.astype(np.float16) + if dtype == "float32": + return x.astype(np.float32) + if dtype == "bfloat16": + return _quantize_to_bfloat16(x.astype(np.float32)) + raise ValueError(f"unsupported dtype: {dtype}") + + +# --------------------------------------------------------------------------- +# MLXBackend -- real Metal GPU execution +# --------------------------------------------------------------------------- + + +class MLXBackend: + """Backend that delegates to ``mlx.core`` for real GPU execution.""" + + name = "mlx" + + def __init__(self) -> None: + try: + import mlx.core as mx # noqa: F811 + self._mx = mx + except Exception as exc: + self._mx = None + self._import_error = exc + + def is_available(self) -> bool: + return self._mx is not None + + def mlx_version(self) -> str | None: + if self._mx is None: + return None + try: + import mlx + return getattr(mlx, "__version__", None) + except Exception: + return None + + def device_info(self) -> dict[str, Any]: + info: dict[str, Any] = {} + if self._mx is None: + return info + try: + from mlx.core import metal + di = metal.device_info() + info.update(di) + except Exception: + pass + return info + + def array(self, np_array: Any, dtype: str) -> Any: + assert self._mx is not None, "MLX is not available" + mx = self._mx + if dtype == "float16": + return mx.array(np_array.astype(np.float16)) + if dtype == "float32": + return mx.array(np_array.astype(np.float32)) + if dtype == "bfloat16": + return mx.array(np_array.astype(np.float32)).astype(mx.bfloat16) + raise ValueError(f"unsupported dtype: {dtype}") + + def to_numpy(self, arr: Any) -> Any: + assert self._mx is not None, "MLX is not available" + return np.array(arr) + + def eval(self, arr: Any) -> None: + assert self._mx is not None, "MLX is not available" + self._mx.eval(arr) + + def synchronize(self) -> None: + assert self._mx is not None, "MLX is not available" + try: + self._mx.synchronize() + except Exception: + # Older MLX 
versions may lack synchronize(); eval acts as sync. + pass + + def metal_kernel( + self, + name: str, + input_names: list[str], + output_names: list[str], + source: str, + header: str = "", + ensure_row_contiguous: bool = True, + atomic_outputs: bool = False, + ) -> Any: + assert self._mx is not None, "MLX is not available" + mx = self._mx + return mx.fast.metal_kernel( + name=name, + input_names=input_names, + output_names=output_names, + source=source, + header=header, + ensure_row_contiguous=ensure_row_contiguous, + atomic_outputs=atomic_outputs, + ) + + +# --------------------------------------------------------------------------- +# MockBackend -- no GPU required; for CI and testing +# --------------------------------------------------------------------------- + + +@dataclass +class MockKernel: + """Placeholder kernel object returned by ``MockBackend.metal_kernel``. + + Calling it raises ``RuntimeError`` -- the harness is expected to use + the reference implementation instead of executing mock kernels. The + one exception is compile-error simulation: if the source contains + ``THIS_WILL_NOT_COMPILE``, ``metal_kernel`` raises at creation time. + """ + + name: str + source: str + header: str + + def __call__( + self, + *, + inputs: list[Any], + template: list[Any], + grid: tuple[int, int, int], + threadgroup: tuple[int, int, int], + output_shapes: list[tuple[int, ...]], + output_dtypes: list[Any], + verbose: bool = False, + ) -> Any: + if "THIS_WILL_NOT_COMPILE" in self.source or "THIS_WILL_NOT_COMPILE" in self.header: + raise RuntimeError( + "Mock compile error: injected syntax error (THIS_WILL_NOT_COMPILE)." 
+ ) + raise RuntimeError("MockKernel should not be executed directly.") + + +class MockBackend: + """In-process mock backend for environments without Apple Silicon / MLX.""" + + name = "mock" + + def is_available(self) -> bool: + return True + + def mlx_version(self) -> str | None: + return None + + def device_info(self) -> dict[str, Any]: + return { + "chip": platform.machine(), + "gpu": "mock", + "mem_gb": None, + } + + def array(self, np_array: Any, dtype: str) -> Any: + return maybe_quantize(np_array, dtype) + + def to_numpy(self, arr: Any) -> Any: + return np.array(arr) + + def eval(self, arr: Any) -> None: + return + + def synchronize(self) -> None: + return + + def metal_kernel( + self, + name: str, + input_names: list[str], + output_names: list[str], + source: str, + header: str = "", + ensure_row_contiguous: bool = True, + atomic_outputs: bool = False, + ) -> Any: + if "THIS_WILL_NOT_COMPILE" in source or "THIS_WILL_NOT_COMPILE" in header: + raise RuntimeError( + "Mock compile error: injected syntax error (THIS_WILL_NOT_COMPILE)." + ) + return MockKernel(name=name, source=source, header=header) diff --git a/src/zmlx/foundry/harness/bench.py b/src/zmlx/foundry/harness/bench.py new file mode 100644 index 0000000..10aa983 --- /dev/null +++ b/src/zmlx/foundry/harness/bench.py @@ -0,0 +1,89 @@ +"""Adaptive benchmarking with timeout, warmup, and percentile stats. + +Adapted from DataFoundry's harness/bench.py. Runs a kernel repeatedly, +adaptively increasing the repeat count until timing noise stabilises or +the timeout is hit. 
+""" +from __future__ import annotations + +import statistics +import time +from collections.abc import Callable + + +def _percentile(xs: list[float], q: float) -> float: + """Simple percentile on a pre-sorted list (nearest-rank).""" + xs_sorted = sorted(xs) + idx = int(round((len(xs_sorted) - 1) * q)) + return float(xs_sorted[max(0, min(len(xs_sorted) - 1, idx))]) + + +def bench( + *, + run_once: Callable[[], None], + sync: Callable[[], None], + warmup: int, + repeats: int, + timeout_s: float, + adaptive: bool = True, +) -> tuple[bool, dict[str, float | None], bool]: + """Benchmark a kernel and return ``(ok, latency_ms_dict, timed_out)``. + + Parameters: + run_once: Callable that dispatches the kernel once. + sync: Callable that forces GPU completion (e.g. ``mx.synchronize``). + warmup: Number of warmup iterations (not timed). + repeats: Target number of timed iterations. + timeout_s: Wall-clock budget in seconds. + adaptive: If ``True``, adaptively increase repeats until noise < 10%. + + Returns: + A 3-tuple: + * ``ok`` -- ``True`` if at least one timed sample was collected. + * ``latency_ms`` -- dict with keys ``p50``, ``p90``, ``min``, ``mean`` + (values are ``None`` on failure). + * ``timed_out`` -- ``True`` if the timeout was hit before completing. + """ + # Warmup + for _ in range(max(0, warmup)): + run_once() + sync() + + times: list[float] = [] + t_start = time.perf_counter() + timed_out = False + + r = max(1, repeats) + while True: + times.clear() + for _ in range(r): + t0 = time.perf_counter() + run_once() + sync() + dt = (time.perf_counter() - t0) * 1000.0 + times.append(dt) + if (time.perf_counter() - t_start) > timeout_s: + timed_out = True + break + if timed_out: + break + if not adaptive or r >= repeats: + break + # Crude noise estimate: if CV < 10% we have enough samples. 
+ if len(times) >= 5: + mean = statistics.mean(times) + stdev = statistics.pstdev(times) + if mean > 0 and (stdev / mean) < 0.10: + break + r = min(r * 2, 256) + if r >= repeats: + break + + if timed_out or not times: + return False, {"p50": None, "p90": None, "min": None, "mean": None}, True + + p50 = _percentile(times, 0.50) + p90 = _percentile(times, 0.90) + mn = min(times) + mean = statistics.mean(times) + return True, {"p50": p50, "p90": p90, "min": mn, "mean": mean}, False diff --git a/src/zmlx/foundry/harness/cache.py b/src/zmlx/foundry/harness/cache.py new file mode 100644 index 0000000..68cd9aa --- /dev/null +++ b/src/zmlx/foundry/harness/cache.py @@ -0,0 +1,42 @@ +"""In-process compile cache for Metal kernels. + +A simple dict-based cache that stores compiled kernel objects (or the +error that prevented compilation). The cache is keyed on a stable +hash of (op, dtype, template_id, knobs, normalized_source). +""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + + +@dataclass +class CacheEntry: + ok: bool + kernel: Any = None + error: Exception | None = None + + +class CompileCache: + """Thread-local (single-process) compile cache. + + Not designed for cross-process sharing -- each worker gets its own. 
+ """ + + def __init__(self) -> None: + self._cache: dict[str, CacheEntry] = {} + + def get(self, key: str) -> CacheEntry | None: + return self._cache.get(key) + + def put_ok(self, key: str, kernel: Any) -> None: + self._cache[key] = CacheEntry(ok=True, kernel=kernel) + + def put_err(self, key: str, error: Exception) -> None: + self._cache[key] = CacheEntry(ok=False, error=error) + + def __len__(self) -> int: + return len(self._cache) + + def clear(self) -> None: + self._cache.clear() diff --git a/src/zmlx/foundry/harness/compile.py b/src/zmlx/foundry/harness/compile.py new file mode 100644 index 0000000..75dd950 --- /dev/null +++ b/src/zmlx/foundry/harness/compile.py @@ -0,0 +1,71 @@ +"""Compile a Metal kernel via the backend, with caching. + +Adapted from DataFoundry's harness/compile.py. Wraps the backend's +``metal_kernel()`` call with ``CompileCache`` lookup/insert and timing. +""" +from __future__ import annotations + +import time +from typing import Any + +from ..taxonomy import BuildResult +from .cache import CompileCache + + +def compile_kernel( + *, + backend: Any, + cache: CompileCache, + cache_key: str, + name: str, + input_names: list[str], + output_names: list[str], + source: str, + header: str, + ensure_row_contiguous: bool, + atomic_outputs: bool = False, +) -> tuple[Any, BuildResult]: + """Compile (or cache-hit) a Metal kernel and return ``(kernel_obj, BuildResult)``. + + On compile failure ``kernel_obj`` is ``None`` and ``BuildResult.ok`` is + ``False`` with error details populated. + """ + # Cache hit? 
def compile_kernel(
    *,
    backend: Any,
    cache: CompileCache,
    cache_key: str,
    name: str,
    input_names: list[str],
    output_names: list[str],
    source: str,
    header: str,
    ensure_row_contiguous: bool,
    atomic_outputs: bool = False,
) -> tuple[Any, BuildResult]:
    """Compile (or cache-hit) a Metal kernel, returning ``(kernel_obj, BuildResult)``.

    On failure ``kernel_obj`` is ``None`` and the ``BuildResult`` carries the
    error classification. Failures are cached too, so a repeated key replays
    the stored error without touching the compiler again.
    """
    cached = cache.get(cache_key)
    if cached is not None:
        if cached.ok:
            return cached.kernel, BuildResult(ok=True, ms=0.0, cache_hit=True)
        msg = str(cached.error)
        return None, BuildResult(
            ok=False,
            ms=0.0,
            error_type="compile_error",
            error_summary=msg[:200],
            error_log=msg[:2000],
            cache_hit=True,
        )

    started = time.perf_counter()
    try:
        kernel = backend.metal_kernel(
            name=name,
            input_names=input_names,
            output_names=output_names,
            source=source,
            header=header,
            ensure_row_contiguous=ensure_row_contiguous,
            atomic_outputs=atomic_outputs,
        )
    except Exception as exc:
        elapsed_ms = (time.perf_counter() - started) * 1000.0
        cache.put_err(cache_key, exc)
        full = str(exc)
        return None, BuildResult(
            ok=False,
            ms=elapsed_ms,
            error_type="compile_error",
            error_summary=full.split("\n", 1)[0][:200],
            error_log=full[:2000],
            cache_hit=False,
        )

    elapsed_ms = (time.perf_counter() - started) * 1000.0
    cache.put_ok(cache_key, kernel)
    return kernel, BuildResult(ok=True, ms=elapsed_ms, cache_hit=False)
+""" +from __future__ import annotations + +from typing import Any + +import numpy as np + +from ..taxonomy import tolerances + +# --------------------------------------------------------------------------- +# Numpy metric helpers +# --------------------------------------------------------------------------- + + +def _max_rel_err_np(y: np.ndarray, y_ref: np.ndarray) -> float: + denom = np.maximum(np.abs(y_ref), 1e-8) + return float(np.max(np.abs(y - y_ref) / denom)) + + +def _max_abs_err_np(y: np.ndarray, y_ref: np.ndarray) -> float: + return float(np.max(np.abs(y - y_ref))) + + +def _ulp32_np(y: np.ndarray, y_ref: np.ndarray) -> float: + """ULP (unit in the last place) distance for float32.""" + yi = y.astype(np.float32, copy=False).view(np.int32) + ri = y_ref.astype(np.float32, copy=False).view(np.int32) + return float(np.max(np.abs(yi - ri))) + + +# --------------------------------------------------------------------------- +# Public metric computation +# --------------------------------------------------------------------------- + + +def compute_metrics_numpy( + y: np.ndarray, y_ref: np.ndarray, dtype: str +) -> tuple[float, float, float | None]: + """Compute (max_abs_err, max_rel_err, ulp) from numpy arrays.""" + y32 = y.astype(np.float32, copy=False) + r32 = y_ref.astype(np.float32, copy=False) + max_abs = _max_abs_err_np(y32, r32) + max_rel = _max_rel_err_np(y32, r32) + ulp = _ulp32_np(y32, r32) if dtype == "float32" else None + return max_abs, max_rel, ulp + + +def compute_metrics_mlx( + mx: Any, y: Any, y_ref: Any, dtype: str +) -> tuple[float, float, float | None]: + """Compute error metrics on-device via MLX and return python scalars.""" + diff = mx.abs(y.astype(mx.float32) - y_ref.astype(mx.float32)) + max_abs = float(mx.max(diff).item()) + denom = mx.maximum(mx.abs(y_ref.astype(mx.float32)), 1e-8) + max_rel = float(mx.max(diff / denom).item()) + ulp: float | None = None + return max_abs, max_rel, ulp + + +# 
--------------------------------------------------------------------------- +# Tolerance gate +# --------------------------------------------------------------------------- + + +def check_pass( + dtype: str, max_abs: float, max_rel: float +) -> tuple[bool, str]: + """Return ``(passed, reason)`` where *reason* is empty on pass.""" + tol = tolerances(dtype) + if max_abs <= tol.max_abs and max_rel <= tol.max_rel: + return True, "" + return False, ( + f"exceeds_tolerance(" + f"abs={max_abs:.3g}, rel={max_rel:.3g}, " + f"tol_abs={tol.max_abs:.3g}, tol_rel={tol.max_rel:.3g})" + ) diff --git a/src/zmlx/foundry/harness/evaluate.py b/src/zmlx/foundry/harness/evaluate.py new file mode 100644 index 0000000..78db065 --- /dev/null +++ b/src/zmlx/foundry/harness/evaluate.py @@ -0,0 +1,542 @@ +"""Generic evaluation orchestrator for foundry kernel candidates. + +This is the KEY refactor from DataFoundry: instead of hard-coded +``if op == "rmsnorm" ... elif op == "swiglu"`` branches, the +orchestrator accepts a ``KernelOp`` instance that encapsulates all +op-specific logic (input generation, reference computation, output +shape inference, threadgroup memory estimation). + +The flow is: + 1. Render the Metal template with knob values. + 2. Compile (via ``compile_kernel`` with caching). + 3. For each correctness seed: generate inputs, run kernel (or + mock with reference), compute error metrics, gate on tolerances. + 4. If correctness passes: benchmark with adaptive repeats. + 5. Return a structured attempt record dict (ready for NDJSON). 
+""" +from __future__ import annotations + +import datetime as _dt +import platform +import sys +from pathlib import Path +from typing import Any + +import numpy as np + +from ..ids import attempt_id, cache_key, shape_class +from ..ndjson import append_record +from ..taxonomy import KernelCandidate, KernelOp +from ..templates.render import render_template +from .backend import MLXBackend, MockBackend +from .bench import bench as bench_fn +from .cache import CompileCache +from .compile import compile_kernel +from .correctness import check_pass, compute_metrics_mlx, compute_metrics_numpy + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _iso_now() -> str: + return _dt.datetime.utcnow().replace(microsecond=0).isoformat() + "Z" + + +def _env(backend: Any) -> dict[str, Any]: + return { + "os": platform.platform(), + "mlx_version": backend.mlx_version() if hasattr(backend, "mlx_version") else None, + "python": sys.version.split()[0], + "device": backend.device_info() if hasattr(backend, "device_info") else {}, + } + + +def _trim(s: Any, n: int) -> str: + if s is None: + return "" + s = str(s) + return s if len(s) <= n else (s[:n] + "...") + + +def _classify_error(e: Exception) -> tuple[str, str, str]: + """Classify an exception into (error_type, summary, log).""" + msg = str(e) + summary = msg.split("\n", 1)[0] + etype = "api_error" + lowered = msg.lower() + if any( + tok in lowered + for tok in ( + "unable to build metal library", + "program_source", + "error:", + "metal::device", + ) + ): + etype = "compile_error" + return etype, _trim(summary, 200), _trim(msg, 4000) + + +def _metal_out_type(dtype: str) -> str: + """Map a dtype string to the Metal type name for output casts.""" + if dtype == "float16": + return "float16_t" + if dtype == "bfloat16": + return "bfloat16_t" + if dtype == "float32": + return "float" + raise 
ValueError(f"unsupported dtype for Metal output: {dtype}") + + +def _mlx_dtype(mx: Any, dtype: str) -> Any: + """Map a dtype string to an ``mx.*`` dtype object.""" + if dtype == "float16": + return mx.float16 + if dtype == "float32": + return mx.float32 + if dtype == "bfloat16": + return mx.bfloat16 + raise ValueError(f"unsupported mlx dtype: {dtype}") + + +def _np_strides_elems(x: np.ndarray) -> list[int]: + """Convert numpy byte-strides to element-strides.""" + item = x.dtype.itemsize + return [int(s // item) for s in x.strides] + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def evaluate_attempt( + *, + session_dir: Path, + backend_name: str, + candidate: KernelCandidate, + op: KernelOp, + cache: CompileCache, + correctness_tests: int = 3, + correctness_seed: int = 0, + warmup: int = 10, + repeats: int = 50, + bench_timeout_s: float = 10.0, + compile_timeout_s: float = 30.0, +) -> dict[str, Any]: + """Compile, correctness-check, and benchmark one kernel attempt. + + This is the main orchestrator. It is **generic** over ops: all + op-specific logic is delegated to the ``op`` parameter (a ``KernelOp`` + instance). + + Parameters: + session_dir: Directory for writing kernel source files and logs. + backend_name: ``"mlx"`` for real GPU or ``"mock"`` for testing. + candidate: The ``KernelCandidate`` describing template + knobs. + op: A ``KernelOp`` instance for input generation, reference + computation, and output shape inference. + cache: ``CompileCache`` instance (shared across attempts). + correctness_tests: Number of random seeds to test. + correctness_seed: Base seed for reproducibility. + warmup: Benchmark warmup iterations. + repeats: Benchmark timed iterations. + bench_timeout_s: Wall-clock timeout for benchmarking. + compile_timeout_s: Wall-clock timeout for compilation (advisory). 
+ + Returns: + A dict suitable for NDJSON serialisation with keys: + ``id``, ``ts``, ``op``, ``dtype``, ``shape``, ``layout``, + ``spec``, ``kernel``, ``build``, ``correctness``, ``bench``, ``env``. + """ + backend = MLXBackend() if backend_name == "mlx" else MockBackend() + if backend_name == "mlx" and not backend.is_available(): + raise RuntimeError( + "MLX backend requested but MLX is not available. " + "Install mlx and run on Apple Silicon." + ) + + op_name = candidate.op + dtype = candidate.dtype + shape = dict(candidate.shape) + layout = dict(candidate.layout) + + # -- Op-specific metadata via the KernelOp protocol ---------------- + # Support both taxonomy.KernelOp Protocol and ops.base.KernelOp ABC + has_bridge = hasattr(op, "generate_inputs_numpy_bridge") + + math = op.canonical_math() + constraints = op.correctness_constraints(dtype) + bytes_est, flops_est = op.bytes_and_flops(shape, dtype) + input_names = list(op.input_names) + output_names = list(op.output_names) + + # -- Render template ----------------------------------------------- + template_path = ( + Path(__file__).resolve().parent.parent + / "templates" + / op_name + / f"{candidate.template_id}.metal" + ) + if not template_path.exists(): + raise FileNotFoundError( + f"Template not found: {template_path}" + ) + + tg_size = int(candidate.knobs.get("tg_size", 256)) + values = { + "TG_SIZE": tg_size, + "VEC": int(candidate.knobs.get("vec", 1)), + "UNROLL": int(candidate.knobs.get("unroll", 1)), + "FAST_MATH": 1 if candidate.knobs.get("fast_math", False) else 0, + "INJECT_INCORRECT": 1 if candidate.knobs.get("inject_incorrect", False) else 0, + "OUT_TYPE": _metal_out_type(dtype), + "COMPILE_ERROR_SNIPPET": ( + "THIS_WILL_NOT_COMPILE;" + if candidate.knobs.get("inject_compile_error", False) + else "" + ), + } + rendered = render_template(template_path, values) + + # -- Stable IDs ---------------------------------------------------- + sc = shape_class(op_name, dtype, shape, layout) + rid = attempt_id( + 
op=op_name, + dtype=dtype, + shape_class=sc, + template_id=candidate.template_id, + knobs=candidate.knobs, + normalized_source=rendered.normalized_full_text, + ) + ck = cache_key( + op=op_name, + dtype=dtype, + template_id=candidate.template_id, + knobs=candidate.knobs, + normalized_source=rendered.normalized_full_text, + ) + + # -- Write kernel source to disk ----------------------------------- + kernels_dir = session_dir / "kernels" + kernels_dir.mkdir(parents=True, exist_ok=True) + kernel_file = kernels_dir / f"{rid}.metal" + if not kernel_file.exists(): + kernel_file.write_text(rendered.full_text, encoding="utf-8") + + # -- Launch config (one threadgroup per row) ----------------------- + rows = int(shape.get("batch", 1)) * int(shape.get("seq", 1)) + grid = (tg_size, rows, 1) + threadgroup = (tg_size, 1, 1) + tg_mem = op.tg_mem_bytes(candidate.template_id, tg_size) + + # -- Compile ------------------------------------------------------- + ensure_row_contig = bool(layout.get("contiguous", True)) + kernel_obj, build = compile_kernel( + backend=backend, + cache=cache, + cache_key=ck, + name=f"{op_name}_{candidate.template_id}", + input_names=input_names, + output_names=output_names, + source=rendered.source, + header=rendered.header, + ensure_row_contiguous=ensure_row_contig, + atomic_outputs=False, + ) + + # -- Base record --------------------------------------------------- + record: dict[str, Any] = { + "id": rid, + "ts": _iso_now(), + "op": op_name, + "dtype": dtype, + "shape": shape, + "layout": { + "contiguous": bool(layout.get("contiguous", True)), + "strides": layout.get("strides", []), + }, + "spec": { + "math": math, + "constraints": constraints, + "reference_impl": f"{op_name}_reference_numpy_v1", + }, + "kernel": { + "template_id": candidate.template_id, + "knobs": candidate.knobs, + "source_metal": str(Path("kernels") / f"{rid}.metal"), + "launch": { + "grid": list(grid), + "tg": list(threadgroup), + "tg_mem_bytes": tg_mem, + }, + "compile_flags": ( + 
["fast_math"] if candidate.knobs.get("fast_math", False) else [] + ), + "cache_key": ck, + }, + "build": { + "ok": bool(build.ok), + "ms": float(build.ms), + "error_type": build.error_type, + "error_summary": build.error_summary, + "error_log": build.error_log, + }, + "correctness": { + "ok": False, + "max_abs_err": None, + "max_rel_err": None, + "ulp": None, + "n_tests": int(correctness_tests), + "seed": int(correctness_seed), + "fail_reason": "build_failed" if not build.ok else None, + }, + "bench": { + "ok": False, + "warmup": int(warmup), + "repeats": int(repeats), + "latency_ms": {"p50": None, "p90": None, "min": None, "mean": None}, + "throughput_gbs": None, + "throughput_gflops": None, + "timeout": False, + }, + "env": _env(backend), + } + + if not build.ok: + return record + + # ================================================================ + # CORRECTNESS + # ================================================================ + max_abs_all = 0.0 + max_rel_all = 0.0 + ulp_all: float | None = None + fail_reason: str | None = None + + use_mlx = backend_name == "mlx" + mx = None + if use_mlx: + import mlx.core as mx # type: ignore[no-redef] + + def _to_mlx_input(np_arr: np.ndarray) -> Any: + """Preserve integer tensors (e.g. index buffers), cast float tensors to candidate dtype.""" + assert mx is not None + if np.issubdtype(np_arr.dtype, np.integer): + return mx.array(np_arr.astype(np.int32, copy=False)) + if np_arr.ndim == 0 or np_arr.size == 1: + return backend.array(np_arr, "float32") + return backend.array(np_arr, dtype) + + for i in range(correctness_tests): + seed_i = correctness_seed + i + rng = np.random.default_rng(seed_i) + + # Generic input generation -- supports both protocol variants. + if has_bridge: + in_np = op.generate_inputs_numpy_bridge(rng, shape, dtype, layout) + else: + in_np = op.generate_inputs_numpy(rng, shape, dtype, layout) + + # Record strides from the first test case. + if i == 0: + # Use the first array that has meaningful strides. 
+ for arr_name in input_names: + if arr_name in in_np and hasattr(in_np[arr_name], "strides"): + record["layout"]["strides"] = _np_strides_elems(in_np[arr_name]) + break + + if use_mlx: + # Convert numpy inputs to MLX arrays. + mlx_inputs: list[Any] = [] + for arr_name in input_names: + np_arr = in_np[arr_name] + mlx_inputs.append(_to_mlx_input(np_arr)) + + try: + outs = kernel_obj( + inputs=mlx_inputs, + template=[], + grid=grid, + threadgroup=threadgroup, + output_shapes=op.output_shapes(in_np, shape), + output_dtypes=[_mlx_dtype(mx, dtype) for _ in output_names], + ) + y = outs[0] + + # Build mlx-typed input dict for the reference. + mlx_input_dict: dict[str, Any] = {} + for idx, arr_name in enumerate(input_names): + mlx_input_dict[arr_name] = mlx_inputs[idx] + + try: + y_ref = op.reference_mlx(mlx_input_dict, dtype) + backend.eval(y) + backend.eval(y_ref) + backend.synchronize() + max_abs, max_rel, ulp = compute_metrics_mlx(mx, y, y_ref, dtype) + except NotImplementedError: + # Most ops only provide numpy references; compare in numpy. + backend.eval(y) + backend.synchronize() + y_np = np.asarray(backend.to_numpy(y.astype(mx.float32))) + if has_bridge: + y_ref_np = op.reference_numpy_bridge(in_np, dtype) + else: + y_ref_np = op.reference_numpy(in_np, dtype) + y_ref_np = np.asarray(y_ref_np) + max_abs, max_rel, ulp = compute_metrics_numpy(y_np, y_ref_np, dtype) + except Exception as exc: + et, summ, log = _classify_error(exc) + record["build"]["ok"] = False + record["build"]["error_type"] = et + record["build"]["error_summary"] = summ + record["build"]["error_log"] = log + record["correctness"]["ok"] = False + record["correctness"]["fail_reason"] = "kernel_call_error" + return record + else: + # Mock path: use numpy reference, optionally corrupt output. 
+ if has_bridge: + y_ref = op.reference_numpy_bridge(in_np, dtype) + else: + y_ref = op.reference_numpy(in_np, dtype) + y = y_ref.copy() + if candidate.knobs.get("inject_incorrect", False): + # Inject a small offset so correctness fails. + y = y + 0.1 + max_abs, max_rel, ulp = compute_metrics_numpy(y, y_ref, dtype) + + max_abs_all = max(max_abs_all, float(max_abs)) + max_rel_all = max(max_rel_all, float(max_rel)) + if ulp is not None: + ulp_all = float(ulp) if ulp_all is None else max(ulp_all, float(ulp)) + + ok, reason = check_pass(dtype, max_abs, max_rel) + if not ok: + fail_reason = reason + break + + corr_ok = fail_reason is None + record["correctness"] = { + "ok": bool(corr_ok), + "max_abs_err": float(max_abs_all) if corr_ok or max_abs_all > 0 else None, + "max_rel_err": float(max_rel_all) if corr_ok or max_rel_all > 0 else None, + "ulp": ulp_all, + "n_tests": int(correctness_tests), + "seed": int(correctness_seed), + "fail_reason": fail_reason, + } + + if not corr_ok: + return record + + # ================================================================ + # BENCHMARK + # ================================================================ + rngb = np.random.default_rng(correctness_seed + 10_000) + if has_bridge: + in_np_bench = op.generate_inputs_numpy_bridge(rngb, shape, dtype, layout) + else: + in_np_bench = op.generate_inputs_numpy(rngb, shape, dtype, layout) + + if use_mlx: + mlx_bench_inputs: list[Any] = [] + for arr_name in input_names: + np_arr = in_np_bench[arr_name] + mlx_bench_inputs.append(_to_mlx_input(np_arr)) + + bench_out_shapes = op.output_shapes(in_np_bench, shape) + bench_out_dtypes = [_mlx_dtype(mx, dtype) for _ in output_names] + + def run_once() -> None: + outs = kernel_obj( + inputs=mlx_bench_inputs, + template=[], + grid=grid, + threadgroup=threadgroup, + output_shapes=bench_out_shapes, + output_dtypes=bench_out_dtypes, + ) + backend.eval(outs[0]) + + def sync() -> None: + backend.synchronize() + else: + def run_once() -> None: + if 
has_bridge: + _ = op.reference_numpy_bridge(in_np_bench, dtype) + else: + _ = op.reference_numpy(in_np_bench, dtype) + + def sync() -> None: + return + + try: + ok_bench, lat, timed_out = bench_fn( + run_once=run_once, + sync=sync, + warmup=warmup, + repeats=repeats, + timeout_s=bench_timeout_s, + adaptive=True, + ) + except Exception as exc: + et, summ, log = _classify_error(exc) + record["bench"] = { + "ok": False, + "warmup": int(warmup), + "repeats": int(repeats), + "latency_ms": {"p50": None, "p90": None, "min": None, "mean": None}, + "throughput_gbs": None, + "throughput_gflops": None, + "timeout": False, + "error_type": et, + "error_summary": summ, + "error_log": log, + } + return record + + throughput_gbs: float | None = None + throughput_gflops: float | None = None + if ok_bench and lat.get("p50") is not None and float(lat["p50"]) > 0: + sec = float(lat["p50"]) / 1000.0 + throughput_gbs = (bytes_est / sec) / 1e9 + throughput_gflops = (flops_est / sec) / 1e9 + + record["bench"] = { + "ok": bool(ok_bench), + "warmup": int(warmup), + "repeats": int(repeats), + "latency_ms": { + k: (float(v) if v is not None else None) for k, v in lat.items() + }, + "throughput_gbs": ( + float(throughput_gbs) if throughput_gbs is not None else None + ), + "throughput_gflops": ( + float(throughput_gflops) if throughput_gflops is not None else None + ), + "timeout": bool(timed_out), + } + return record + + +# --------------------------------------------------------------------------- +# Record writer (convenience wrapper around ndjson.append_record) +# --------------------------------------------------------------------------- + + +def write_record( + *, + session_dir: Path, + record: dict[str, Any], + worker_id: int | None = None, +) -> None: + """Append an attempt record to the session's NDJSON log.""" + if worker_id is None: + path = session_dir / "attempts.ndjson" + else: + path = session_dir / f"attempts.worker{worker_id}.ndjson" + append_record(path, record) diff --git 
a/src/zmlx/foundry/ids.py b/src/zmlx/foundry/ids.py new file mode 100644 index 0000000..7497101 --- /dev/null +++ b/src/zmlx/foundry/ids.py @@ -0,0 +1,116 @@ +"""Stable identifiers for attempts, cache keys, and shape classes. + +Adapted from DataFoundry's logging/ids.py. Every function is pure +(deterministic, no side effects) and depends only on the stdlib. +""" +from __future__ import annotations + +import hashlib +import json +import re +from typing import Any + +# --------------------------------------------------------------------------- +# Source normalisation +# --------------------------------------------------------------------------- + +_WHITESPACE_RE = re.compile(r"[ \t]+") + + +def normalize_source(src: str) -> str: + """Normalize Metal source for stable hashing. + + * Converts line endings to ``\\n``. + * Collapses runs of horizontal whitespace in non-indented lines. + * Strips trailing blank lines. + """ + src = src.replace("\r\n", "\n").replace("\r", "\n") + lines: list[str] = [] + for line in src.split("\n"): + line = line.rstrip() + # Non-indented lines: collapse internal runs of spaces/tabs. + if line.lstrip() == line: + line = _WHITESPACE_RE.sub(" ", line) + lines.append(line) + # Remove trailing blank lines. 
+ while lines and lines[-1] == "": + lines.pop() + return "\n".join(lines) + "\n" + + +# --------------------------------------------------------------------------- +# Canonical JSON (deterministic serialisation for hashing) +# --------------------------------------------------------------------------- + + +def canonical_json(obj: Any) -> str: + return json.dumps( + obj, sort_keys=True, separators=(",", ":"), ensure_ascii=False + ) + + +# --------------------------------------------------------------------------- +# Shape class -- human-readable dimension fingerprint +# --------------------------------------------------------------------------- + + +def shape_class( + op: str, dtype: str, shape: dict[str, int], layout: dict[str, Any] +) -> str: + """Replay-stable shape class string. + + Later summarisation can bucket these; we keep exact dims here. + """ + parts = [f"b{shape.get('batch', 1)}", f"s{shape.get('seq', 1)}"] + if op == "swiglu": + parts.append(f"h{shape.get('hidden', 0)}") + parts.append(f"hin{shape.get('hidden_in', shape.get('hidden', 0) * 2)}") + else: + parts.append(f"h{shape.get('hidden', 0)}") + parts.append( + "strided" if not bool(layout.get("contiguous", True)) else "contig" + ) + return "_".join(parts) + + +# --------------------------------------------------------------------------- +# Attempt ID and compile cache key +# --------------------------------------------------------------------------- + + +def attempt_id( + *, + op: str, + dtype: str, + shape_class: str, + template_id: str, + knobs: dict[str, Any], + normalized_source: str, +) -> str: + payload = { + "op": op, + "dtype": dtype, + "shape_class": shape_class, + "template_id": template_id, + "knobs": knobs, + "source": normalized_source, + } + return hashlib.sha256(canonical_json(payload).encode("utf-8")).hexdigest() + + +def cache_key( + *, + op: str, + dtype: str, + template_id: str, + knobs: dict[str, Any], + normalized_source: str, +) -> str: + payload = { + "op": op, + "dtype": 
dtype, + "template_id": template_id, + "knobs": knobs, + "source": normalized_source, + } + return hashlib.sha256(canonical_json(payload).encode("utf-8")).hexdigest() diff --git a/src/zmlx/foundry/ndjson.py b/src/zmlx/foundry/ndjson.py new file mode 100644 index 0000000..263c6c8 --- /dev/null +++ b/src/zmlx/foundry/ndjson.py @@ -0,0 +1,94 @@ +"""Append-only NDJSON log for attempt records. + +Adapted from DataFoundry's logging/ndjson.py. Each record is a single +JSON line, fsynced after write so partial records from crashes are +confined to the last line (which ``iter_records`` tolerates). +""" +from __future__ import annotations + +import json +import os +from collections.abc import Iterator +from pathlib import Path +from typing import Any + + +def dumps(record: dict[str, Any]) -> str: + """Compact, deterministic-order JSON serialisation.""" + return json.dumps( + record, + sort_keys=False, + ensure_ascii=False, + separators=(",", ":"), + allow_nan=False, + ) + + +def append_record(path: Path, record: dict[str, Any]) -> None: + """Append one JSON line and fsync.""" + path.parent.mkdir(parents=True, exist_ok=True) + line = dumps(record) + "\n" + with open(path, "a", encoding="utf-8") as f: + f.write(line) + f.flush() + os.fsync(f.fileno()) + + +def iter_records(path: Path) -> Iterator[dict[str, Any]]: + """Iterate NDJSON records, tolerating partial/corrupt lines.""" + if not path.exists(): + return + with open(path, encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + try: + yield json.loads(line) + except json.JSONDecodeError: + continue + + +def load_existing_ids(session_dir: Path) -> set[str]: + """Collect all attempt IDs already recorded in a session directory.""" + ids: set[str] = set() + if not session_dir.exists(): + return ids + candidates = [session_dir / "attempts.ndjson"] + candidates += sorted(session_dir.glob("attempts.worker*.ndjson")) + for p in candidates: + for rec in iter_records(p): + rid = rec.get("id") + if 
isinstance(rid, str): + ids.add(rid) + return ids + + +def merge_worker_logs( + session_dir: Path, *, out_name: str = "attempts.ndjson" +) -> Path: + """Merge per-worker shards into a single deduplicated log.""" + session_dir.mkdir(parents=True, exist_ok=True) + out_path = session_dir / out_name + + seen: set[str] = set() + for rec in iter_records(out_path): + rid = rec.get("id") + if isinstance(rid, str): + seen.add(rid) + + worker_paths = sorted(session_dir.glob("attempts.worker*.ndjson")) + if not worker_paths: + return out_path + + with open(out_path, "a", encoding="utf-8") as out: + for wp in worker_paths: + for rec in iter_records(wp): + rid = rec.get("id") + if not isinstance(rid, str) or rid in seen: + continue + out.write(dumps(rec) + "\n") + seen.add(rid) + out.flush() + os.fsync(out.fileno()) + return out_path diff --git a/src/zmlx/foundry/ops/__init__.py b/src/zmlx/foundry/ops/__init__.py new file mode 100644 index 0000000..d011ab8 --- /dev/null +++ b/src/zmlx/foundry/ops/__init__.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from .base import KernelOp +from .dequantize import DequantizeOp +from .gather import GatherOp +from .grouped_gemm import GroupedGemmOp +from .kv_append import KVAppendOp +from .layernorm import LayerNormOp +from .moe_combine import MoECombineOp +from .moe_dispatch import MoEDispatchOp +from .moe_pack import MoEPackOp +from .moe_topk import MoETopKOp +from .quantize import QuantizeOp +from .rmsnorm import RMSNormOp +from .rope import RoPEOp +from .scatter import ScatterOp +from .softmax import SoftmaxOp +from .swiglu import SwiGLUOp +from .topk import TopKOp + +_ALL_OPS = [ + RMSNormOp, + SwiGLUOp, + RoPEOp, + SoftmaxOp, + LayerNormOp, + QuantizeOp, + DequantizeOp, + KVAppendOp, + GatherOp, + ScatterOp, + TopKOp, + MoETopKOp, + MoEPackOp, + MoEDispatchOp, + MoECombineOp, + GroupedGemmOp, +] + + +def get_registry() -> dict[str, KernelOp]: + """Return a name -> KernelOp instance mapping for all registered ops.""" + 
return {cls.name: cls() for cls in _ALL_OPS} diff --git a/src/zmlx/foundry/ops/base.py b/src/zmlx/foundry/ops/base.py new file mode 100644 index 0000000..8ab8b05 --- /dev/null +++ b/src/zmlx/foundry/ops/base.py @@ -0,0 +1,250 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any + +import numpy as np + +from ..taxonomy import OpSpec, Tolerances + +# --------------------------------------------------------------------------- +# Shape ladders (canonical sizes for LLM inference workloads) +# --------------------------------------------------------------------------- + +TOKENS_LADDER: list[int] = [16, 32, 64, 128, 256, 512, 1024, 2048] +HIDDEN_LADDER: list[int] = [512, 768, 1024, 1536, 2048, 3072, 4096, 6144, 8192] +N_EXPERTS_LADDER: list[int] = [8, 16, 32, 64] +MOE_K_LADDER: list[int] = [2, 4] + + +# --------------------------------------------------------------------------- +# Quantization helpers (from datafoundry common.py) +# --------------------------------------------------------------------------- + +DEFAULT_TOLERANCES: dict[str, Tolerances] = { + "float32": Tolerances(max_abs=5e-5, max_rel=5e-5), + "float16": Tolerances(max_abs=5e-3, max_rel=5e-3), + "bfloat16": Tolerances(max_abs=1e-2, max_rel=1e-2), +} + + +def np_dtype(dtype: str): + if dtype == "float16": + return np.float16 + if dtype == "float32": + return np.float32 + if dtype == "bfloat16": + # numpy lacks native bfloat16; store as float32 and truncate when needed + return np.float32 + raise ValueError(f"unknown dtype {dtype}") + + +def quantize_to_bfloat16(x: np.ndarray) -> np.ndarray: + """Simulate bf16 by truncating lower 16 bits of float32 mantissa.""" + x32 = x.astype(np.float32, copy=False) + u = x32.view(np.uint32) + u_trunc = (u & np.uint32(0xFFFF0000)) + return u_trunc.view(np.float32) + + +def maybe_quantize(x: np.ndarray, dtype: str) -> np.ndarray: + if dtype == "float16": + return x.astype(np.float16) + if dtype == "float32": + return 
x.astype(np.float32) + if dtype == "bfloat16": + return quantize_to_bfloat16(x.astype(np.float32)) + raise ValueError(dtype) + + +# --------------------------------------------------------------------------- +# Numpy I/O helpers +# --------------------------------------------------------------------------- + +def randn_np(shape: tuple[int, ...], *, dtype: str, seed: int, scale: float = 1.0) -> np.ndarray: + rng = np.random.default_rng(seed & 0xFFFFFFFF) + x = rng.standard_normal(size=shape).astype(np.float32) * scale + return maybe_quantize(x, dtype) + + +def randint_np(low: int, high: int, shape: tuple[int, ...], *, dtype: str, seed: int) -> np.ndarray: + rng = np.random.default_rng(seed & 0xFFFFFFFF) + x = rng.integers(low, high, size=shape) + if dtype == "int8": + return x.astype(np.int8) + if dtype == "int32": + return x.astype(np.int32) + return x + + +def maybe_to_float32(x: np.ndarray) -> np.ndarray: + return np.asarray(x, dtype=np.float32) + + +def max_abs_err(a: np.ndarray, b: np.ndarray) -> float: + a = np.asarray(a) + b = np.asarray(b) + if a.shape != b.shape: + raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}") + return float(np.max(np.abs(a - b))) if a.size else 0.0 + + +# --------------------------------------------------------------------------- +# KernelOp ABC +# --------------------------------------------------------------------------- + +class KernelOp(ABC): + """Standard op interface for the unified foundry. 
+ + Every op provides: + - spec(): static metadata + - supported_dtypes(): what dtypes this op handles + - templates(): list of Metal template IDs (or ["ref"] for reference-only) + - knob_space(): dict describing tunable knob ranges per template + - sample_shape(): random shape sampling for data generation + - generate_inputs_numpy(): produce NumPy input tensors + - reference_numpy(): golden reference implementation in NumPy + - validate_knobs(): check if a knob configuration is valid + - bytes_and_flops(): throughput estimation + """ + + name: str = "" + + @abstractmethod + def spec(self) -> OpSpec: + ... + + def supported_dtypes(self) -> list[str]: + return ["float16", "float32"] + + def templates(self) -> list[str]: + return ["ref"] + + def knob_space(self, template_id: str) -> dict[str, Any]: + return {} + + @abstractmethod + def sample_shape(self, rng: np.random.Generator) -> dict[str, int]: + ... + + def sample_op_params(self, shape: dict[str, int], rng: np.random.Generator) -> dict[str, Any]: + return {} + + @abstractmethod + def generate_inputs_numpy( + self, + shape: dict[str, int], + dtype: str, + op_params: dict[str, Any], + seed: int, + ) -> dict[str, np.ndarray]: + ... + + @abstractmethod + def reference_numpy( + self, + inputs: dict[str, np.ndarray], + op_params: dict[str, Any], + ) -> dict[str, np.ndarray]: + ... + + def validate_knobs( + self, template_id: str, knobs: dict[str, Any], shape: dict[str, int], dtype: str + ) -> tuple[bool, str]: + return True, "" + + def bytes_and_flops(self, shape: dict[str, int], dtype: str) -> tuple[int, int]: + return 0, 0 + + # ------------------------------------------------------------------ + # Bridge methods: satisfy the taxonomy.KernelOp Protocol interface + # so that harness/evaluate.py can use ops directly. + # ------------------------------------------------------------------ + + @property + def input_names(self) -> list[str]: + """Buffer names from spec().inputs, e.g. 
'x[tokens,hidden]' -> 'x'.""" + return [s.split("[")[0] for s in self.spec().inputs] + + @property + def output_names(self) -> list[str]: + """Buffer names from spec().outputs, e.g. 'y[tokens,hidden]' -> 'y'.""" + return [s.split("[")[0] for s in self.spec().outputs] + + def canonical_math(self) -> str: + return self.spec().summary + + def correctness_constraints(self, dtype: str) -> dict[str, Any]: + tol = DEFAULT_TOLERANCES.get(dtype, Tolerances(max_abs=1e-2, max_rel=1e-2)) + return {"max_abs": tol.max_abs, "max_rel": tol.max_rel} + + def generate_inputs_numpy_bridge( + self, + rng: Any, + shape: dict[str, int], + dtype: str, + layout: dict[str, Any], + ) -> dict[str, Any]: + """Adapter for taxonomy Protocol's generate_inputs_numpy signature. + + Translates sampler shape format (batch/seq/hidden) into op-specific + format (e.g. tokens/hidden for norms), then delegates to the + concrete generate_inputs_numpy. + """ + # Translate batch+seq -> tokens if the op expects tokens + adapted_shape = dict(shape) + if "tokens" not in adapted_shape and "batch" in adapted_shape: + adapted_shape["tokens"] = ( + adapted_shape.get("batch", 1) * adapted_shape.get("seq", 1) + ) + + seed = int(rng.integers(0, 2**31)) if hasattr(rng, "integers") else 42 + gen_rng = np.random.default_rng(seed) + op_params = self.sample_op_params(adapted_shape, gen_rng) + return self.generate_inputs_numpy(adapted_shape, dtype, op_params, seed) + + def reference_numpy_bridge( + self, + inputs: dict[str, Any], + dtype: str, + ) -> Any: + """Adapter for taxonomy Protocol's reference_numpy signature.""" + result = self.reference_numpy(inputs, {}) + # If result is a dict, return the first output array + if isinstance(result, dict): + out_names = self.output_names + if out_names and out_names[0] in result: + return result[out_names[0]] + # Fallback: return first value + return next(iter(result.values())) + return result + + def reference_mlx( + self, + inputs: dict[str, Any], + dtype: str, + ) -> Any: + 
"""Default: no MLX reference. The harness falls back to numpy.""" + raise NotImplementedError(f"{self.name} has no MLX reference implementation") + + def output_shapes( + self, + inputs: dict[str, Any], + shape: dict[str, int], + ) -> list[tuple[int, ...]]: + """Derive output shapes from the reference numpy computation.""" + try: + result = self.reference_numpy(inputs, {}) + if isinstance(result, dict): + return [np.asarray(v).shape for v in result.values()] + return [np.asarray(result).shape] + except Exception: + # Fallback: assume output matches first input shape + for v in inputs.values(): + if hasattr(v, "shape"): + return [v.shape] + return [(1,)] + + def tg_mem_bytes(self, template_id: str, tg_size: int) -> int: + """Threadgroup memory estimate. Default: 0 (no shared memory).""" + return 0 diff --git a/src/zmlx/foundry/ops/dequantize.py b/src/zmlx/foundry/ops/dequantize.py new file mode 100644 index 0000000..742427a --- /dev/null +++ b/src/zmlx/foundry/ops/dequantize.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np + +from ..taxonomy import KernelClass, OpSpec +from .base import HIDDEN_LADDER, TOKENS_LADDER, KernelOp, randint_np + + +def _unpack_int4(packed: np.ndarray) -> np.ndarray: + """Unpack uint8 packed int4 (two's complement) into int8 in [-8,7].""" + p = packed.astype(np.uint8) + low = np.bitwise_and(p, 0xF).astype(np.int8) + high = np.bitwise_and(p >> 4, 0xF).astype(np.int8) + # two's complement sign extend 4-bit + low = low - (low >= 8).astype(np.int8) * 16 + high = high - (high >= 8).astype(np.int8) * 16 + out = np.empty(p.shape[:-1] + (p.shape[-1] * 2,), dtype=np.int8) + out[..., 0::2] = low + out[..., 1::2] = high + return out + + +class DequantizeOp(KernelOp): + name = "dequantize" + + def spec(self) -> OpSpec: + return OpSpec( + name=self.name, + kernel_class=KernelClass.QUANT, + summary="Dequantize int8 or packed int4 -> fp16/bf16", + inputs=["q[tokens,hidden] (int8) OR 
q_packed[tokens,hidden/2] (int4 packed)"], + outputs=["y[tokens,hidden]"], + op_params_schema={ + "q_dtype": {"type": "str", "enum": ["int8", "int4"], "default": "int8"}, + "scale": {"type": "float", "default": 0.02}, + "out_dtype": {"type": "str", "enum": ["float16", "bfloat16", "float32"], "default": "float16"}, + }, + shape_hints={"tokens": TOKENS_LADDER, "hidden": HIDDEN_LADDER}, + dtype_hints=["int8", "int4"], + templates=["ref"], + ) + + def supported_dtypes(self) -> list[str]: + return ["float16", "bfloat16", "float32"] + + def sample_shape(self, rng: np.random.Generator) -> dict[str, int]: + tokens = int(rng.choice(TOKENS_LADDER)) + hidden = int(rng.choice(HIDDEN_LADDER)) + if hidden % 2 != 0: + hidden += 1 + return {"tokens": tokens, "hidden": hidden} + + def sample_op_params(self, shape: dict[str, int], rng: np.random.Generator) -> dict[str, Any]: + q_dtype = str(rng.choice(["int8", "int4"])) + scale = float(rng.choice([0.01, 0.02, 0.05])) + out_dtype = str(rng.choice(["float16", "bfloat16", "float32"])) + return {"q_dtype": q_dtype, "scale": scale, "out_dtype": out_dtype} + + def generate_inputs_numpy( + self, shape: dict[str, int], dtype: str, op_params: dict[str, Any], seed: int + ) -> dict[str, np.ndarray]: + q_dtype = str(op_params.get("q_dtype", "int8")) + if q_dtype == "int8": + q = randint_np(-128, 128, (shape["tokens"], shape["hidden"]), dtype="int8", seed=seed) + return {"q": q} + if q_dtype == "int4": + packed = randint_np( + 0, 256, (shape["tokens"], shape["hidden"] // 2), dtype="int32", seed=seed + ).astype(np.uint8) + return {"q_packed": packed} + raise ValueError(f"Unknown q_dtype: {q_dtype}") + + def reference_numpy( + self, inputs: dict[str, np.ndarray], op_params: dict[str, Any] + ) -> dict[str, np.ndarray]: + q_dtype = str(op_params.get("q_dtype", "int8")) + scale = float(op_params.get("scale", 0.02)) + if scale <= 0: + scale = 0.02 + + if q_dtype == "int8": + q = np.asarray(inputs["q"]).astype(np.int8) + y = q.astype(np.float32) * 
scale + return {"y": y} + if q_dtype == "int4": + packed = np.asarray(inputs["q_packed"]).astype(np.uint8) + q = _unpack_int4(packed).astype(np.int8) + y = q.astype(np.float32) * scale + return {"y": y} + raise ValueError(f"Unknown q_dtype: {q_dtype}") diff --git a/src/zmlx/foundry/ops/gather.py b/src/zmlx/foundry/ops/gather.py new file mode 100644 index 0000000..b26519a --- /dev/null +++ b/src/zmlx/foundry/ops/gather.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np + +from ..taxonomy import KernelClass, OpSpec +from .base import HIDDEN_LADDER, TOKENS_LADDER, KernelOp, randint_np, randn_np + + +class GatherOp(KernelOp): + name = "gather" + + def spec(self) -> OpSpec: + return OpSpec( + name=self.name, + kernel_class=KernelClass.DATA_MOVEMENT, + summary="Gather (index select) along axis=0", + inputs=["x[tokens,hidden]", "idx[m]"], + outputs=["y[m,hidden]"], + op_params_schema={"axis": {"type": "int", "default": 0}}, + shape_hints={ + "tokens": TOKENS_LADDER, + "hidden": HIDDEN_LADDER, + "m": [16, 32, 64, 128, 256, 512], + }, + dtype_hints=["float16", "bfloat16", "float32"], + templates=["ref"], + ) + + def supported_dtypes(self) -> list[str]: + return ["float16", "bfloat16", "float32"] + + def sample_shape(self, rng: np.random.Generator) -> dict[str, int]: + tokens = int(rng.choice(TOKENS_LADDER)) + hidden = int(rng.choice(HIDDEN_LADDER)) + m = int(rng.choice([16, 32, 64, 128, 256])) + m = min(m, tokens) if tokens > 0 else m + return {"tokens": tokens, "hidden": hidden, "m": m} + + def generate_inputs_numpy( + self, shape: dict[str, int], dtype: str, op_params: dict[str, Any], seed: int + ) -> dict[str, np.ndarray]: + x = randn_np((shape["tokens"], shape["hidden"]), dtype=dtype, seed=seed, scale=0.5) + idx = randint_np(0, max(1, shape["tokens"]), (shape["m"],), dtype="int32", seed=seed + 1) + return {"x": x, "idx": idx} + + def reference_numpy( + self, inputs: dict[str, np.ndarray], op_params: dict[str, Any] + ) 
-> dict[str, np.ndarray]: + x = np.asarray(inputs["x"]) + idx = np.asarray(inputs["idx"]).astype(np.int64) + y = x[idx, :] + return {"y": y} diff --git a/src/zmlx/foundry/ops/grouped_gemm.py b/src/zmlx/foundry/ops/grouped_gemm.py new file mode 100644 index 0000000..e75b746 --- /dev/null +++ b/src/zmlx/foundry/ops/grouped_gemm.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np + +from ..taxonomy import KernelClass, OpSpec +from .base import ( + HIDDEN_LADDER, + MOE_K_LADDER, + N_EXPERTS_LADDER, + TOKENS_LADDER, + KernelOp, + randn_np, +) +from .moe_pack import pack_assignments +from .moe_topk import softmax_1d + + +class GroupedGemmOp(KernelOp): + """Grouped GEMM hook for MoE. + + Not a Metal kernel yet. Provides a stable interface that can be swapped + for a future specialized grouped/batched matmul kernel. + """ + + name = "grouped_gemm" + + def spec(self) -> OpSpec: + return OpSpec( + name=self.name, + kernel_class=KernelClass.GEMM, + summary="Grouped GEMM hook for MoE (reference uses NumPy matmul per expert)", + inputs=[ + "expert_x[tokens*k,hidden]", + "expert_offsets[n_experts+1]", + "(implicit) W_expert[n_experts,hidden,out_hidden] (generated from seed)", + ], + outputs=["expert_y[tokens*k,out_hidden]"], + op_params_schema={ + "n_experts": {"type": "int", "enum": N_EXPERTS_LADDER, "default": 16}, + "k": {"type": "int", "enum": MOE_K_LADDER, "default": 2}, + "out_hidden": {"type": "int", "default": 512}, + }, + shape_hints={ + "tokens": TOKENS_LADDER, + "hidden": HIDDEN_LADDER, + "out_hidden": [512, 1024, 2048], + "n_experts": N_EXPERTS_LADDER, + "k": MOE_K_LADDER, + }, + dtype_hints=["float16", "bfloat16", "float32"], + templates=["ref"], + ) + + def supported_dtypes(self) -> list[str]: + return ["float16", "bfloat16", "float32"] + + def sample_shape(self, rng: np.random.Generator) -> dict[str, int]: + tokens = int(rng.choice([16, 32, 64, 128, 256])) + hidden = int(rng.choice([512, 768, 1024])) + n_experts 
= int(rng.choice(N_EXPERTS_LADDER)) + k = int(rng.choice(MOE_K_LADDER)) + out_hidden = int(rng.choice([512, 1024])) + return { + "tokens": tokens, + "hidden": hidden, + "n_experts": n_experts, + "k": min(k, n_experts), + "out_hidden": out_hidden, + } + + def sample_op_params(self, shape: dict[str, int], rng: np.random.Generator) -> dict[str, Any]: + return { + "n_experts": int(shape["n_experts"]), + "k": int(shape["k"]), + "out_hidden": int(shape["out_hidden"]), + } + + def generate_inputs_numpy( + self, shape: dict[str, int], dtype: str, op_params: dict[str, Any], seed: int + ) -> dict[str, np.ndarray]: + # Build synthetic expert_x + expert_offsets from router assignments + x = randn_np((shape["tokens"], shape["hidden"]), dtype=dtype, seed=seed, scale=0.5) + logits = randn_np( + (shape["tokens"], shape["n_experts"]), dtype=dtype, seed=seed + 1, scale=1.0 + ) + topk_idx = np.argsort(-logits, axis=-1)[:, : shape["k"]].astype(np.int32) + top = np.take_along_axis(logits, topk_idx, axis=-1) + topk_w = softmax_1d(top, axis=-1).astype(np.float32) + _, expert_offsets, packed_token_ids, _ = pack_assignments( + topk_idx, topk_w, shape["n_experts"] + ) + expert_x = x[packed_token_ids.astype(np.int64), :] + return {"expert_x": expert_x, "expert_offsets": expert_offsets} + + def reference_numpy( + self, inputs: dict[str, np.ndarray], op_params: dict[str, Any] + ) -> dict[str, np.ndarray]: + expert_x = np.asarray(inputs["expert_x"]).astype(np.float32) + offsets = np.asarray(inputs["expert_offsets"]).astype(np.int32) + n_experts = int(op_params.get("n_experts", len(offsets) - 1)) + out_hidden = int(op_params.get("out_hidden", expert_x.shape[-1])) + + # Per-expert weight matrices from a fixed seed (hook; real system provides W_expert) + W = [] + base_rng = np.random.default_rng(12345) + for _ in range(n_experts): + W.append( + base_rng.standard_normal(size=(expert_x.shape[-1], out_hidden)).astype(np.float32) + * 0.02 + ) + + y = np.zeros((expert_x.shape[0], out_hidden), 
dtype=np.float32) + for e in range(n_experts): + s = int(offsets[e]) + t = int(offsets[e + 1]) + if t <= s: + continue + y[s:t, :] = expert_x[s:t, :] @ W[e] + return {"expert_y": y} diff --git a/src/zmlx/foundry/ops/kv_append.py b/src/zmlx/foundry/ops/kv_append.py new file mode 100644 index 0000000..6aa46a4 --- /dev/null +++ b/src/zmlx/foundry/ops/kv_append.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np + +from ..taxonomy import KernelClass, OpSpec +from .base import TOKENS_LADDER, KernelOp, randn_np + + +class KVAppendOp(KernelOp): + name = "kv_append" + + def spec(self) -> OpSpec: + return OpSpec( + name=self.name, + kernel_class=KernelClass.KV_CACHE, + summary="KV cache append: write k_new/v_new into cache at [start_pos:start_pos+tokens]", + inputs=[ + "k_cache[max_seq,n_kv_heads,head_dim]", + "v_cache[max_seq,n_kv_heads,head_dim]", + "k_new[tokens,n_kv_heads,head_dim]", + "v_new[tokens,n_kv_heads,head_dim]", + ], + outputs=["k_cache_out", "v_cache_out"], + op_params_schema={ + "max_seq": {"type": "int", "default": 2048}, + "start_pos": {"type": "int", "default": 0}, + }, + shape_hints={"tokens": TOKENS_LADDER, "n_kv_heads": [4, 8, 16], "head_dim": [64, 128]}, + dtype_hints=["float16", "bfloat16"], + templates=["ref"], + ) + + def supported_dtypes(self) -> list[str]: + return ["float16", "bfloat16", "float32"] + + def sample_shape(self, rng: np.random.Generator) -> dict[str, int]: + tokens = int(rng.choice(TOKENS_LADDER)) + n_kv_heads = int(rng.choice([4, 8, 16])) + head_dim = int(rng.choice([64, 128])) + max_seq = int(rng.choice([512, 1024, 2048, 4096])) + if tokens > max_seq: + max_seq = tokens + return {"tokens": tokens, "n_kv_heads": n_kv_heads, "head_dim": head_dim, "max_seq": max_seq} + + def sample_op_params(self, shape: dict[str, int], rng: np.random.Generator) -> dict[str, Any]: + max_seq = int(shape["max_seq"]) + tokens = int(shape["tokens"]) + start_pos = int(rng.integers(0, max(1, max_seq - 
tokens + 1))) + return {"max_seq": max_seq, "start_pos": start_pos} + + def generate_inputs_numpy( + self, shape: dict[str, int], dtype: str, op_params: dict[str, Any], seed: int + ) -> dict[str, np.ndarray]: + max_seq = int(op_params.get("max_seq", shape["max_seq"])) + dims = (shape["n_kv_heads"], shape["head_dim"]) + k_cache = randn_np((max_seq, *dims), dtype=dtype, seed=seed, scale=0.1) + v_cache = randn_np((max_seq, *dims), dtype=dtype, seed=seed + 1, scale=0.1) + k_new = randn_np((shape["tokens"], *dims), dtype=dtype, seed=seed + 2, scale=0.5) + v_new = randn_np((shape["tokens"], *dims), dtype=dtype, seed=seed + 3, scale=0.5) + return {"k_cache": k_cache, "v_cache": v_cache, "k_new": k_new, "v_new": v_new} + + def reference_numpy( + self, inputs: dict[str, np.ndarray], op_params: dict[str, Any] + ) -> dict[str, np.ndarray]: + k_cache = np.asarray(inputs["k_cache"]).copy() + v_cache = np.asarray(inputs["v_cache"]).copy() + k_new = np.asarray(inputs["k_new"]) + v_new = np.asarray(inputs["v_new"]) + start_pos = int(op_params.get("start_pos", 0)) + end = start_pos + k_new.shape[0] + k_cache[start_pos:end, :, :] = k_new + v_cache[start_pos:end, :, :] = v_new + return {"k_cache_out": k_cache, "v_cache_out": v_cache} diff --git a/src/zmlx/foundry/ops/layernorm.py b/src/zmlx/foundry/ops/layernorm.py new file mode 100644 index 0000000..a7e7dd9 --- /dev/null +++ b/src/zmlx/foundry/ops/layernorm.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np + +from ..taxonomy import KernelClass, OpSpec +from .base import ( + HIDDEN_LADDER, + TOKENS_LADDER, + KernelOp, + maybe_to_float32, + randn_np, +) + + +class LayerNormOp(KernelOp): + name = "layernorm" + + def spec(self) -> OpSpec: + return OpSpec( + name=self.name, + kernel_class=KernelClass.REDUCTION, + summary="LayerNorm: (x-mean)/sqrt(var+eps)*gamma+beta", + inputs=["x[tokens,hidden]", "gamma[hidden]", "beta[hidden]"], + outputs=["y[tokens,hidden]"], + 
op_params_schema={"eps": {"type": "float", "default": 1e-5}}, + shape_hints={"tokens": TOKENS_LADDER, "hidden": HIDDEN_LADDER}, + dtype_hints=["float16", "bfloat16", "float32"], + templates=["ref"], + ) + + def supported_dtypes(self) -> list[str]: + return ["float16", "bfloat16", "float32"] + + def sample_shape(self, rng: np.random.Generator) -> dict[str, int]: + tokens = int(rng.choice(TOKENS_LADDER)) + hidden = int(rng.choice(HIDDEN_LADDER)) + return {"tokens": tokens, "hidden": hidden} + + def sample_op_params(self, shape: dict[str, int], rng: np.random.Generator) -> dict[str, Any]: + eps = float(rng.choice([1e-5, 1e-6, 1e-4])) + return {"eps": eps} + + def generate_inputs_numpy( + self, shape: dict[str, int], dtype: str, op_params: dict[str, Any], seed: int + ) -> dict[str, np.ndarray]: + x = randn_np((shape["tokens"], shape["hidden"]), dtype=dtype, seed=seed, scale=0.5) + gamma = randn_np((shape["hidden"],), dtype=dtype, seed=seed + 1, scale=0.1) + gamma = gamma + 1.0 # center around 1 + beta = randn_np((shape["hidden"],), dtype=dtype, seed=seed + 2, scale=0.1) + return {"x": x, "gamma": gamma, "beta": beta} + + def reference_numpy( + self, inputs: dict[str, np.ndarray], op_params: dict[str, Any] + ) -> dict[str, np.ndarray]: + x = maybe_to_float32(inputs["x"]) + gamma = maybe_to_float32(inputs["gamma"]) + beta = maybe_to_float32(inputs["beta"]) + eps = float(op_params.get("eps", 1e-5)) + mean = np.mean(x, axis=-1, keepdims=True) + var = np.mean((x - mean) ** 2, axis=-1, keepdims=True) + y = (x - mean) / np.sqrt(var + eps) + y = y * gamma + beta + return {"y": y.astype(np.float32)} diff --git a/src/zmlx/foundry/ops/moe_combine.py b/src/zmlx/foundry/ops/moe_combine.py new file mode 100644 index 0000000..50a40d0 --- /dev/null +++ b/src/zmlx/foundry/ops/moe_combine.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np + +from ..taxonomy import KernelClass, OpSpec +from .base import ( + HIDDEN_LADDER, + 
MOE_K_LADDER, + N_EXPERTS_LADDER, + TOKENS_LADDER, + KernelOp, + randn_np, +) +from .moe_pack import pack_assignments +from .moe_topk import softmax_1d + +MOE_COMBINE_TEMPLATES = ["t0_basic", "t1_k8_unrolled", "t2_row_tile"] +TG_SIZES = [32, 64, 128, 256] +UNROLLS = [1, 2, 4, 8] + + +class MoECombineOp(KernelOp): + name = "moe_combine" + extra_shape_dims = {"n_experts": N_EXPERTS_LADDER, "k": MOE_K_LADDER} + + def spec(self) -> OpSpec: + return OpSpec( + name=self.name, + kernel_class=KernelClass.MOE, + summary="MoE combine scatter: accumulate expert_y back to tokens (weighted)", + inputs=[ + "expert_y[tokens*k,hidden]", + "packed_token_ids[tokens*k]", + "packed_weights[tokens*k]", + ], + outputs=["y[tokens,hidden]"], + op_params_schema={ + "n_experts": {"type": "int", "enum": N_EXPERTS_LADDER, "default": 16}, + "k": {"type": "int", "enum": MOE_K_LADDER, "default": 2}, + "distribution": {"type": "str", "enum": ["uniform", "peaked", "ties"], "default": "uniform"}, + }, + shape_hints={ + "tokens": TOKENS_LADDER, + "hidden": HIDDEN_LADDER, + "n_experts": N_EXPERTS_LADDER, + "k": MOE_K_LADDER, + }, + dtype_hints=["float16", "bfloat16", "float32"], + templates=MOE_COMBINE_TEMPLATES, + ) + + def supported_dtypes(self) -> list[str]: + return ["float16", "bfloat16", "float32"] + + def templates(self) -> list[str]: + return MOE_COMBINE_TEMPLATES + + def knob_space(self, template_id: str) -> dict[str, Any]: + _ = template_id + return { + "tg_size": {"type": "int", "values": TG_SIZES}, + "unroll": {"type": "int", "values": UNROLLS}, + "fast_math": {"type": "bool"}, + } + + def sample_shape(self, rng: np.random.Generator) -> dict[str, int]: + tokens = int(rng.choice(TOKENS_LADDER)) + hidden = int(rng.choice(HIDDEN_LADDER)) + n_experts = int(rng.choice(N_EXPERTS_LADDER)) + k = int(rng.choice(MOE_K_LADDER)) + k = min(k, n_experts) + return {"tokens": tokens, "hidden": hidden, "n_experts": n_experts, "k": k} + + def sample_op_params(self, shape: dict[str, int], rng: 
np.random.Generator) -> dict[str, Any]: + return { + "n_experts": int(shape["n_experts"]), + "k": int(shape["k"]), + "distribution": str(rng.choice(["uniform", "peaked", "ties"])), + } + + def generate_inputs_numpy( + self, shape: dict[str, int], dtype: str, op_params: dict[str, Any], seed: int + ) -> dict[str, np.ndarray]: + x = randn_np((shape["tokens"], shape["hidden"]), dtype=dtype, seed=seed, scale=0.5) + + # Router -> topk -> pack + dist = str(op_params.get("distribution", "uniform")) + logits = randn_np( + (shape["tokens"], shape["n_experts"]), dtype=dtype, seed=seed + 1, scale=1.0 + ) + rng = np.random.default_rng(int(seed + 1) & 0xFFFFFFFF) + if dist == "peaked": + peak_idx = rng.integers(0, shape["n_experts"], size=(shape["tokens"],)) + for t in range(shape["tokens"]): + logits[t, peak_idx[t]] += 6.0 + elif dist == "ties": + logits = np.round(logits * 2.0) / 2.0 + if shape["n_experts"] >= 4: + logits[0, :4] = 1.0 + + topk_idx = np.argsort(-logits, axis=-1)[:, : shape["k"]].astype(np.int32) + top = np.take_along_axis(logits, topk_idx, axis=-1) + topk_w = softmax_1d(top, axis=-1).astype(np.float32) + + _, _, packed_token_ids, packed_weights = pack_assignments( + topk_idx, topk_w, shape["n_experts"] + ) + + # expert_y = x[packed_token_ids] for a strong roundtrip test + expert_y = x[packed_token_ids.astype(np.int64), :] + + return { + "expert_y": expert_y, + "packed_token_ids": packed_token_ids, + "packed_weights": packed_weights, + } + + def reference_numpy( + self, inputs: dict[str, np.ndarray], op_params: dict[str, Any] + ) -> dict[str, np.ndarray]: + expert_y = np.asarray(inputs["expert_y"]).astype(np.float32) + ids = np.asarray(inputs["packed_token_ids"]).astype(np.int64) + w = np.asarray(inputs["packed_weights"]).astype(np.float32) + tokens = int(ids.max()) + 1 if ids.size else 0 + hidden = expert_y.shape[-1] + y = np.zeros((tokens, hidden), dtype=np.float32) + np.add.at(y, ids, expert_y * w[:, None]) + return {"y": y} + + def validate_knobs( + self, 
template_id: str, knobs: dict[str, Any], shape: dict[str, int], dtype: str + ) -> tuple[bool, str]: + _ = template_id, shape, dtype + tg = int(knobs.get("tg_size", 0)) + if tg not in TG_SIZES: + return False, "invalid_tg_size" + if tg & (tg - 1) != 0: + return False, "tg_size_not_pow2" + unroll = int(knobs.get("unroll", 1)) + if unroll not in UNROLLS: + return False, "invalid_unroll" + return True, "" + + def bytes_and_flops(self, shape: dict[str, int], dtype: str) -> tuple[int, int]: + tokens = int(shape.get("tokens", shape.get("batch", 1) * shape.get("seq", 1))) + hidden = int(shape["hidden"]) + k = int(shape.get("k", 2)) + pairs = tokens * k + bytes_per_elem = 4 if dtype == "float32" else 2 + + bytes_expert_y = pairs * hidden * bytes_per_elem + bytes_ids = pairs * 4 + bytes_weights = pairs * 4 + bytes_y = tokens * hidden * bytes_per_elem + + # One multiply and one add per (assignment, hidden) contribution. + flops = pairs * hidden * 2 + return int(bytes_expert_y + bytes_ids + bytes_weights + bytes_y), int(flops) diff --git a/src/zmlx/foundry/ops/moe_dispatch.py b/src/zmlx/foundry/ops/moe_dispatch.py new file mode 100644 index 0000000..366c8e8 --- /dev/null +++ b/src/zmlx/foundry/ops/moe_dispatch.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np + +from ..taxonomy import KernelClass, OpSpec +from .base import ( + HIDDEN_LADDER, + MOE_K_LADDER, + N_EXPERTS_LADDER, + TOKENS_LADDER, + KernelOp, + randn_np, +) +from .moe_pack import pack_assignments +from .moe_topk import softmax_1d + + +class MoEDispatchOp(KernelOp): + name = "moe_dispatch" + + def spec(self) -> OpSpec: + return OpSpec( + name=self.name, + kernel_class=KernelClass.MOE, + summary="MoE dispatch gather: build expert_x by gathering x using packed_token_ids", + inputs=["x[tokens,hidden]", "packed_token_ids[tokens*k]", "expert_offsets[n_experts+1]"], + outputs=["expert_x[tokens*k,hidden]"], + op_params_schema={ + "n_experts": {"type": "int", "enum": 
N_EXPERTS_LADDER, "default": 16}, + "k": {"type": "int", "enum": MOE_K_LADDER, "default": 2}, + "distribution": {"type": "str", "enum": ["uniform", "peaked", "ties"], "default": "uniform"}, + }, + shape_hints={ + "tokens": TOKENS_LADDER, + "hidden": HIDDEN_LADDER, + "n_experts": N_EXPERTS_LADDER, + "k": MOE_K_LADDER, + }, + dtype_hints=["float16", "bfloat16", "float32"], + templates=["ref"], + ) + + def supported_dtypes(self) -> list[str]: + return ["float16", "bfloat16", "float32"] + + def sample_shape(self, rng: np.random.Generator) -> dict[str, int]: + tokens = int(rng.choice(TOKENS_LADDER)) + hidden = int(rng.choice(HIDDEN_LADDER)) + n_experts = int(rng.choice(N_EXPERTS_LADDER)) + k = int(rng.choice(MOE_K_LADDER)) + k = min(k, n_experts) + return {"tokens": tokens, "hidden": hidden, "n_experts": n_experts, "k": k} + + def sample_op_params(self, shape: dict[str, int], rng: np.random.Generator) -> dict[str, Any]: + return { + "n_experts": int(shape["n_experts"]), + "k": int(shape["k"]), + "distribution": str(rng.choice(["uniform", "peaked", "ties"])), + } + + def generate_inputs_numpy( + self, shape: dict[str, int], dtype: str, op_params: dict[str, Any], seed: int + ) -> dict[str, np.ndarray]: + x = randn_np((shape["tokens"], shape["hidden"]), dtype=dtype, seed=seed, scale=0.5) + + # Build packed assignments from synthetic router logits + dist = str(op_params.get("distribution", "uniform")) + logits = randn_np( + (shape["tokens"], shape["n_experts"]), dtype=dtype, seed=seed + 1, scale=1.0 + ) + rng = np.random.default_rng(int(seed + 1) & 0xFFFFFFFF) + if dist == "peaked": + peak_idx = rng.integers(0, shape["n_experts"], size=(shape["tokens"],)) + for t in range(shape["tokens"]): + logits[t, peak_idx[t]] += 6.0 + elif dist == "ties": + logits = np.round(logits * 2.0) / 2.0 + if shape["n_experts"] >= 4: + logits[0, :4] = 1.0 + + topk_idx = np.argsort(-logits, axis=-1)[:, : shape["k"]].astype(np.int32) + top = np.take_along_axis(logits, topk_idx, axis=-1) + topk_w = 
softmax_1d(top, axis=-1).astype(np.float32) + + _, expert_offsets, packed_token_ids, _ = pack_assignments( + topk_idx, topk_w, shape["n_experts"] + ) + return { + "x": x, + "packed_token_ids": packed_token_ids, + "expert_offsets": expert_offsets, + } + + def reference_numpy( + self, inputs: dict[str, np.ndarray], op_params: dict[str, Any] + ) -> dict[str, np.ndarray]: + x = np.asarray(inputs["x"]) + ids = np.asarray(inputs["packed_token_ids"]).astype(np.int64) + expert_x = x[ids, :] + return {"expert_x": expert_x} diff --git a/src/zmlx/foundry/ops/moe_pack.py b/src/zmlx/foundry/ops/moe_pack.py new file mode 100644 index 0000000..b33f107 --- /dev/null +++ b/src/zmlx/foundry/ops/moe_pack.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np + +from ..taxonomy import KernelClass, OpSpec +from .base import ( + MOE_K_LADDER, + N_EXPERTS_LADDER, + TOKENS_LADDER, + KernelOp, + randn_np, +) +from .moe_topk import softmax_1d + + +def pack_assignments( + topk_idx: np.ndarray, topk_w: np.ndarray, n_experts: int +) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Pack token->expert assignments into expert-grouped layout. + + Returns (expert_counts, expert_offsets, packed_token_ids, packed_weights). 
+ """ + tokens, k = topk_idx.shape + pairs = [] + for t in range(tokens): + for j in range(k): + e = int(topk_idx[t, j]) + pairs.append((e, t, float(topk_w[t, j]))) + pairs.sort(key=lambda x: x[0]) + + expert_counts = np.zeros((n_experts,), dtype=np.int32) + for e, _, _ in pairs: + expert_counts[e] += 1 + + expert_offsets = np.zeros((n_experts + 1,), dtype=np.int32) + expert_offsets[1:] = np.cumsum(expert_counts) + + packed_token_ids = np.zeros((tokens * k,), dtype=np.int32) + packed_weights = np.zeros((tokens * k,), dtype=np.float32) + for i, (_, t, w) in enumerate(pairs): + packed_token_ids[i] = t + packed_weights[i] = w + + return expert_counts, expert_offsets, packed_token_ids, packed_weights + + +class MoEPackOp(KernelOp): + name = "moe_pack" + + def spec(self) -> OpSpec: + return OpSpec( + name=self.name, + kernel_class=KernelClass.MOE, + summary="MoE: build expert_counts/offsets + packed token ids/weights grouped by expert", + inputs=["topk_idx[tokens,k]", "topk_w[tokens,k]"], + outputs=[ + "expert_counts[n_experts]", + "expert_offsets[n_experts+1]", + "packed_token_ids[tokens*k]", + "packed_weights[tokens*k]", + ], + op_params_schema={ + "n_experts": {"type": "int", "enum": N_EXPERTS_LADDER, "default": 16}, + "k": {"type": "int", "enum": MOE_K_LADDER, "default": 2}, + "distribution": {"type": "str", "enum": ["uniform", "peaked", "ties"], "default": "uniform"}, + }, + shape_hints={"tokens": TOKENS_LADDER, "n_experts": N_EXPERTS_LADDER, "k": MOE_K_LADDER}, + dtype_hints=["float16", "float32"], + templates=["ref"], + ) + + def supported_dtypes(self) -> list[str]: + return ["float16", "bfloat16", "float32"] + + def sample_shape(self, rng: np.random.Generator) -> dict[str, int]: + tokens = int(rng.choice(TOKENS_LADDER)) + n_experts = int(rng.choice(N_EXPERTS_LADDER)) + k = int(rng.choice(MOE_K_LADDER)) + k = min(k, n_experts) + return {"tokens": tokens, "n_experts": n_experts, "k": k} + + def sample_op_params(self, shape: dict[str, int], rng: 
np.random.Generator) -> dict[str, Any]: + return { + "n_experts": int(shape["n_experts"]), + "k": int(shape["k"]), + "distribution": str(rng.choice(["uniform", "peaked", "ties"])), + } + + def generate_inputs_numpy( + self, shape: dict[str, int], dtype: str, op_params: dict[str, Any], seed: int + ) -> dict[str, np.ndarray]: + dist = str(op_params.get("distribution", "uniform")) + logits = randn_np((shape["tokens"], shape["n_experts"]), dtype=dtype, seed=seed, scale=1.0) + rng = np.random.default_rng(int(seed) & 0xFFFFFFFF) + if dist == "peaked": + peak_idx = rng.integers(0, shape["n_experts"], size=(shape["tokens"],)) + for t in range(shape["tokens"]): + logits[t, peak_idx[t]] += 6.0 + elif dist == "ties": + logits = np.round(logits * 2.0) / 2.0 + if shape["n_experts"] >= 4: + logits[0, :4] = 1.0 + idx = np.argsort(-logits, axis=-1)[:, : shape["k"]] + top = np.take_along_axis(logits, idx, axis=-1) + w = softmax_1d(top, axis=-1).astype(np.float32) + return {"topk_idx": idx.astype(np.int32), "topk_w": w} + + def reference_numpy( + self, inputs: dict[str, np.ndarray], op_params: dict[str, Any] + ) -> dict[str, np.ndarray]: + topk_idx = np.asarray(inputs["topk_idx"]).astype(np.int32) + topk_w = np.asarray(inputs["topk_w"]).astype(np.float32) + n_experts = int(op_params.get("n_experts", int(topk_idx.max()) + 1)) + expert_counts, expert_offsets, packed_token_ids, packed_weights = pack_assignments( + topk_idx, topk_w, n_experts + ) + return { + "expert_counts": expert_counts, + "expert_offsets": expert_offsets, + "packed_token_ids": packed_token_ids, + "packed_weights": packed_weights, + } diff --git a/src/zmlx/foundry/ops/moe_topk.py b/src/zmlx/foundry/ops/moe_topk.py new file mode 100644 index 0000000..7cc8185 --- /dev/null +++ b/src/zmlx/foundry/ops/moe_topk.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np + +from ..taxonomy import KernelClass, OpSpec +from .base import ( + MOE_K_LADDER, + N_EXPERTS_LADDER, + 
TOKENS_LADDER, + KernelOp, + maybe_to_float32, + randn_np, +) + + +def softmax_1d(x: np.ndarray, axis: int = -1) -> np.ndarray: + x = x - np.max(x, axis=axis, keepdims=True) + e = np.exp(x) + return e / np.sum(e, axis=axis, keepdims=True) + + +class MoETopKOp(KernelOp): + name = "moe_topk" + + def spec(self) -> OpSpec: + return OpSpec( + name=self.name, + kernel_class=KernelClass.MOE, + summary="MoE routing: router logits -> top-k indices + weights", + inputs=["logits[tokens,n_experts]"], + outputs=["topk_idx[tokens,k]", "topk_w[tokens,k]"], + op_params_schema={ + "k": {"type": "int", "enum": MOE_K_LADDER, "default": 2}, + "distribution": {"type": "str", "enum": ["uniform", "peaked", "ties"], "default": "uniform"}, + "temperature": {"type": "float", "default": 1.0}, + }, + shape_hints={"tokens": TOKENS_LADDER, "n_experts": N_EXPERTS_LADDER}, + dtype_hints=["float16", "bfloat16", "float32"], + templates=["ref"], + ) + + def supported_dtypes(self) -> list[str]: + return ["float16", "bfloat16", "float32"] + + def sample_shape(self, rng: np.random.Generator) -> dict[str, int]: + tokens = int(rng.choice(TOKENS_LADDER)) + n_experts = int(rng.choice(N_EXPERTS_LADDER)) + return {"tokens": tokens, "n_experts": n_experts} + + def sample_op_params(self, shape: dict[str, int], rng: np.random.Generator) -> dict[str, Any]: + k = int(rng.choice(MOE_K_LADDER)) + k = min(k, shape["n_experts"]) + distribution = str(rng.choice(["uniform", "peaked", "ties"])) + temperature = float(rng.choice([0.7, 1.0, 1.5])) + return {"k": k, "distribution": distribution, "temperature": temperature} + + def generate_inputs_numpy( + self, shape: dict[str, int], dtype: str, op_params: dict[str, Any], seed: int + ) -> dict[str, np.ndarray]: + dist = str(op_params.get("distribution", "uniform")) + logits = randn_np((shape["tokens"], shape["n_experts"]), dtype=dtype, seed=seed, scale=1.0) + rng = np.random.default_rng(int(seed) & 0xFFFFFFFF) + if dist == "peaked": + peak_idx = rng.integers(0, 
shape["n_experts"], size=(shape["tokens"],)) + for t in range(shape["tokens"]): + logits[t, peak_idx[t]] += 6.0 + elif dist == "ties": + logits = np.round(logits * 2.0) / 2.0 + if shape["n_experts"] >= 4: + logits[0, :4] = 1.0 + return {"logits": logits} + + def reference_numpy( + self, inputs: dict[str, np.ndarray], op_params: dict[str, Any] + ) -> dict[str, np.ndarray]: + logits = maybe_to_float32(inputs["logits"]) + k = int(op_params.get("k", 2)) + temperature = float(op_params.get("temperature", 1.0)) + if temperature <= 0: + temperature = 1.0 + x = logits / temperature + idx = np.argsort(-x, axis=-1)[:, :k] + top = np.take_along_axis(x, idx, axis=-1) + w = softmax_1d(top, axis=-1) + return {"topk_idx": idx.astype(np.int32), "topk_w": w.astype(np.float32)} diff --git a/src/zmlx/foundry/ops/quantize.py b/src/zmlx/foundry/ops/quantize.py new file mode 100644 index 0000000..af7dfbe --- /dev/null +++ b/src/zmlx/foundry/ops/quantize.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np + +from ..taxonomy import KernelClass, OpSpec +from .base import ( + HIDDEN_LADDER, + TOKENS_LADDER, + KernelOp, + maybe_to_float32, + randn_np, +) + + +def _pack_int4(vals: np.ndarray) -> np.ndarray: + """Pack signed int4 values (two's complement nibbles) into uint8.""" + v = vals.astype(np.int8) + if v.shape[-1] % 2 != 0: + raise ValueError("int4 pack requires even last dimension") + low = np.bitwise_and(v[..., 0::2], 0xF).astype(np.uint8) + high = np.bitwise_and(v[..., 1::2], 0xF).astype(np.uint8) + return low | (high << 4) + + +class QuantizeOp(KernelOp): + name = "quantize" + + def spec(self) -> OpSpec: + return OpSpec( + name=self.name, + kernel_class=KernelClass.QUANT, + summary="Quantize fp16/bf16 -> int8 or packed int4 (symmetric)", + inputs=["x[tokens,hidden]"], + outputs=["q[tokens,hidden] (int8) OR q_packed[tokens,hidden/2] (int4 packed)"], + op_params_schema={ + "q_dtype": {"type": "str", "enum": ["int8", "int4"], 
"default": "int8"}, + "scale": {"type": "float", "default": 0.02}, + }, + shape_hints={"tokens": TOKENS_LADDER, "hidden": HIDDEN_LADDER}, + dtype_hints=["float16", "bfloat16"], + templates=["ref"], + ) + + def supported_dtypes(self) -> list[str]: + return ["float16", "bfloat16", "float32"] + + def sample_shape(self, rng: np.random.Generator) -> dict[str, int]: + tokens = int(rng.choice(TOKENS_LADDER)) + hidden = int(rng.choice(HIDDEN_LADDER)) + if hidden % 2 != 0: + hidden += 1 + return {"tokens": tokens, "hidden": hidden} + + def sample_op_params(self, shape: dict[str, int], rng: np.random.Generator) -> dict[str, Any]: + q_dtype = str(rng.choice(["int8", "int4"])) + scale = float(rng.choice([0.01, 0.02, 0.05])) + return {"q_dtype": q_dtype, "scale": scale} + + def generate_inputs_numpy( + self, shape: dict[str, int], dtype: str, op_params: dict[str, Any], seed: int + ) -> dict[str, np.ndarray]: + x = randn_np((shape["tokens"], shape["hidden"]), dtype=dtype, seed=seed, scale=0.5) + return {"x": x} + + def reference_numpy( + self, inputs: dict[str, np.ndarray], op_params: dict[str, Any] + ) -> dict[str, np.ndarray]: + x = maybe_to_float32(inputs["x"]) + q_dtype = str(op_params.get("q_dtype", "int8")) + scale = float(op_params.get("scale", 0.02)) + if scale <= 0: + scale = 0.02 + + if q_dtype == "int8": + q = np.clip(np.round(x / scale), -128, 127).astype(np.int8) + return {"q": q} + if q_dtype == "int4": + q4 = np.clip(np.round(x / scale), -8, 7).astype(np.int8) + q_packed = _pack_int4(q4) + return {"q_packed": q_packed} + raise ValueError(f"Unknown q_dtype: {q_dtype}") diff --git a/src/zmlx/foundry/ops/rmsnorm.py b/src/zmlx/foundry/ops/rmsnorm.py new file mode 100644 index 0000000..68eef5b --- /dev/null +++ b/src/zmlx/foundry/ops/rmsnorm.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np + +from ..taxonomy import KernelClass, OpSpec +from .base import ( + HIDDEN_LADDER, + TOKENS_LADDER, + KernelOp, + 
maybe_to_float32, + randn_np, +) + +# --------------------------------------------------------------------------- +# Knob space constants (from datafoundry rmsnorm_knobs) +# --------------------------------------------------------------------------- + +RMSNORM_TEMPLATES = ["t0_basic", "t1_tgmem"] +TG_SIZES = [32, 64, 128, 256] +VECS = [1, 2, 4] +UNROLLS = [1, 2, 4] + + +class RMSNormOp(KernelOp): + name = "rmsnorm" + + def spec(self) -> OpSpec: + return OpSpec( + name=self.name, + kernel_class=KernelClass.REDUCTION, + summary="RMSNorm: y = x * w / sqrt(mean(x^2)+eps)", + inputs=["x[tokens,hidden]", "w[hidden]"], + outputs=["y[tokens,hidden]"], + op_params_schema={"eps": {"type": "float", "default": 1e-5}}, + shape_hints={"tokens": TOKENS_LADDER, "hidden": HIDDEN_LADDER}, + dtype_hints=["float16", "bfloat16", "float32"], + templates=RMSNORM_TEMPLATES, + ) + + def supported_dtypes(self) -> list[str]: + return ["float16", "bfloat16", "float32"] + + def templates(self) -> list[str]: + return RMSNORM_TEMPLATES + + def knob_space(self, template_id: str) -> dict[str, Any]: + space: dict[str, Any] = { + "tg_size": {"type": "int", "values": TG_SIZES}, + "unroll": {"type": "int", "values": UNROLLS}, + "fast_math": {"type": "bool"}, + } + if template_id == "t0_basic": + space["vec"] = {"type": "int", "values": VECS} + return space + + def sample_shape(self, rng: np.random.Generator) -> dict[str, int]: + tokens = int(rng.choice(TOKENS_LADDER)) + hidden = int(rng.choice(HIDDEN_LADDER)) + return {"tokens": tokens, "hidden": hidden} + + def sample_op_params(self, shape: dict[str, int], rng: np.random.Generator) -> dict[str, Any]: + eps = float(rng.choice([1e-5, 1e-6, 1e-4])) + return {"eps": eps} + + def generate_inputs_numpy( + self, + shape: dict[str, int], + dtype: str, + op_params: dict[str, Any], + seed: int, + ) -> dict[str, np.ndarray]: + x = randn_np((shape["tokens"], shape["hidden"]), dtype=dtype, seed=seed, scale=0.5) + w = randn_np((shape["hidden"],), dtype=dtype, 
seed=seed + 1, scale=0.1) + return {"x": x, "w": w} + + def reference_numpy( + self, + inputs: dict[str, np.ndarray], + op_params: dict[str, Any], + ) -> dict[str, np.ndarray]: + x = maybe_to_float32(inputs["x"]) + w = maybe_to_float32(inputs["w"]) + eps = float(op_params.get("eps", 1e-5)) + ms = np.mean(x * x, axis=-1, keepdims=True) + inv = 1.0 / np.sqrt(ms + eps) + y = x * inv * w + return {"y": y.astype(np.float32)} + + def validate_knobs( + self, template_id: str, knobs: dict[str, Any], shape: dict[str, int], dtype: str + ) -> tuple[bool, str]: + tg = int(knobs.get("tg_size", 0)) + if tg not in TG_SIZES: + return False, "invalid_tg_size" + if tg & (tg - 1) != 0: + return False, "tg_size_not_pow2" + unroll = int(knobs.get("unroll", 1)) + if unroll not in UNROLLS: + return False, "invalid_unroll" + vec = int(knobs.get("vec", 1)) + if template_id == "t0_basic" and vec not in VECS: + return False, "invalid_vec" + if template_id != "t0_basic" and vec != 1: + return False, "vec_not_supported" + return True, "" + + def bytes_and_flops(self, shape: dict[str, int], dtype: str) -> tuple[int, int]: + tokens = int(shape.get("tokens", shape.get("batch", 1) * shape.get("seq", 1))) + h = int(shape["hidden"]) + bytes_per_elem = 4 if dtype == "float32" else 2 + # x read + y write + bytes_rw = tokens * h * bytes_per_elem * 2 + # weight read (amortized across tokens) + bytes_w = h * bytes_per_elem + # flops: square + sum (2), rsqrt (~1), mul (2) per element + flops = tokens * h * 4 + return int(bytes_rw + bytes_w), int(flops) diff --git a/src/zmlx/foundry/ops/rope.py b/src/zmlx/foundry/ops/rope.py new file mode 100644 index 0000000..17e5736 --- /dev/null +++ b/src/zmlx/foundry/ops/rope.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np + +from ..taxonomy import KernelClass, OpSpec +from .base import TOKENS_LADDER, KernelOp, maybe_to_float32, randn_np + + +def _rope_apply(x: np.ndarray, theta: float, offset: int = 0) -> 
np.ndarray: + tokens, n_heads, head_dim = x.shape + if head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + half = head_dim // 2 + + pos = np.arange(tokens, dtype=np.float32) + float(offset) + inv_freq = theta ** (-np.arange(0, half, dtype=np.float32) / float(half)) + angles = np.outer(pos, inv_freq) + cos = np.cos(angles)[:, None, :] + sin = np.sin(angles)[:, None, :] + + x1 = x[..., :half] + x2 = x[..., half:] + y1 = x1 * cos - x2 * sin + y2 = x1 * sin + x2 * cos + return np.concatenate([y1, y2], axis=-1) + + +class RoPEOp(KernelOp): + name = "rope" + + def spec(self) -> OpSpec: + return OpSpec( + name=self.name, + kernel_class=KernelClass.ATTENTION, + summary="Apply RoPE to Q/K tensors (rotary embedding)", + inputs=["x[tokens,n_heads,head_dim]"], + outputs=["y[tokens,n_heads,head_dim]"], + op_params_schema={ + "theta": {"type": "float", "default": 10000.0}, + "offset": {"type": "int", "default": 0}, + }, + shape_hints={"tokens": TOKENS_LADDER, "head_dim": [64, 96, 128], "n_heads": [8, 16, 32]}, + dtype_hints=["float16", "bfloat16", "float32"], + templates=["ref"], + ) + + def supported_dtypes(self) -> list[str]: + return ["float16", "bfloat16", "float32"] + + def sample_shape(self, rng: np.random.Generator) -> dict[str, int]: + tokens = int(rng.choice(TOKENS_LADDER)) + n_heads = int(rng.choice([8, 16, 32])) + head_dim = int(rng.choice([64, 96, 128])) + head_dim = head_dim + (head_dim % 2) + return {"tokens": tokens, "n_heads": n_heads, "head_dim": head_dim} + + def sample_op_params(self, shape: dict[str, int], rng: np.random.Generator) -> dict[str, Any]: + theta = float(rng.choice([10000.0, 50000.0, 1000.0])) + offset = int(rng.integers(0, 2048)) + return {"theta": theta, "offset": offset} + + def generate_inputs_numpy( + self, shape: dict[str, int], dtype: str, op_params: dict[str, Any], seed: int + ) -> dict[str, np.ndarray]: + x = randn_np( + (shape["tokens"], shape["n_heads"], shape["head_dim"]), + dtype=dtype, seed=seed, scale=0.5, + ) 
+ return {"x": x} + + def reference_numpy( + self, inputs: dict[str, np.ndarray], op_params: dict[str, Any] + ) -> dict[str, np.ndarray]: + x = maybe_to_float32(inputs["x"]) + theta = float(op_params.get("theta", 10000.0)) + offset = int(op_params.get("offset", 0)) + y = _rope_apply(x, theta=theta, offset=offset) + return {"y": y.astype(np.float32)} diff --git a/src/zmlx/foundry/ops/scatter.py b/src/zmlx/foundry/ops/scatter.py new file mode 100644 index 0000000..51697bd --- /dev/null +++ b/src/zmlx/foundry/ops/scatter.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np + +from ..taxonomy import KernelClass, OpSpec +from .base import HIDDEN_LADDER, TOKENS_LADDER, KernelOp, randint_np, randn_np + + +class ScatterOp(KernelOp): + name = "scatter" + + def spec(self) -> OpSpec: + return OpSpec( + name=self.name, + kernel_class=KernelClass.DATA_MOVEMENT, + summary="Scatter-add along axis=0: y[idx] += updates", + inputs=["updates[m,hidden]", "idx[m]", "base[tokens,hidden]"], + outputs=["y[tokens,hidden]"], + op_params_schema={"mode": {"type": "str", "enum": ["add"], "default": "add"}}, + shape_hints={ + "tokens": TOKENS_LADDER, + "hidden": HIDDEN_LADDER, + "m": [16, 32, 64, 128, 256, 512], + }, + dtype_hints=["float16", "bfloat16", "float32"], + templates=["ref"], + ) + + def supported_dtypes(self) -> list[str]: + return ["float16", "bfloat16", "float32"] + + def sample_shape(self, rng: np.random.Generator) -> dict[str, int]: + tokens = int(rng.choice(TOKENS_LADDER)) + hidden = int(rng.choice(HIDDEN_LADDER)) + m = int(rng.choice([16, 32, 64, 128, 256])) + return {"tokens": tokens, "hidden": hidden, "m": m} + + def generate_inputs_numpy( + self, shape: dict[str, int], dtype: str, op_params: dict[str, Any], seed: int + ) -> dict[str, np.ndarray]: + updates = randn_np((shape["m"], shape["hidden"]), dtype=dtype, seed=seed, scale=0.5) + idx = randint_np(0, max(1, shape["tokens"]), (shape["m"],), dtype="int32", seed=seed + 1) + 
base = randn_np((shape["tokens"], shape["hidden"]), dtype=dtype, seed=seed + 2, scale=0.1) + return {"updates": updates, "idx": idx, "base": base} + + def reference_numpy( + self, inputs: dict[str, np.ndarray], op_params: dict[str, Any] + ) -> dict[str, np.ndarray]: + updates = np.asarray(inputs["updates"]) + idx = np.asarray(inputs["idx"]).astype(np.int64) + base = np.asarray(inputs["base"]).copy() + np.add.at(base, idx, updates) + return {"y": base} diff --git a/src/zmlx/foundry/ops/softmax.py b/src/zmlx/foundry/ops/softmax.py new file mode 100644 index 0000000..150e6d7 --- /dev/null +++ b/src/zmlx/foundry/ops/softmax.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np + +from ..taxonomy import KernelClass, OpSpec +from .base import TOKENS_LADDER, KernelOp, maybe_to_float32, randn_np + + +def _softmax(x: np.ndarray, axis: int = -1) -> np.ndarray: + x = x - np.max(x, axis=axis, keepdims=True) + e = np.exp(x) + return e / np.sum(e, axis=axis, keepdims=True) + + +class SoftmaxOp(KernelOp): + name = "softmax" + + def spec(self) -> OpSpec: + return OpSpec( + name=self.name, + kernel_class=KernelClass.ATTENTION, + summary="Softmax (attention-style) along last axis", + inputs=["x[tokens,dim]"], + outputs=["y[tokens,dim]"], + op_params_schema={"axis": {"type": "int", "default": -1}}, + shape_hints={"tokens": TOKENS_LADDER, "dim": [32, 64, 128, 256, 512, 1024, 2048]}, + dtype_hints=["float16", "bfloat16", "float32"], + templates=["ref"], + ) + + def supported_dtypes(self) -> list[str]: + return ["float16", "bfloat16", "float32"] + + def sample_shape(self, rng: np.random.Generator) -> dict[str, int]: + tokens = int(rng.choice(TOKENS_LADDER)) + dim = int(rng.choice([32, 64, 128, 256, 512, 1024])) + return {"tokens": tokens, "dim": dim} + + def sample_op_params(self, shape: dict[str, int], rng: np.random.Generator) -> dict[str, Any]: + return {"axis": -1} + + def generate_inputs_numpy( + self, shape: dict[str, int], dtype: 
str, op_params: dict[str, Any], seed: int + ) -> dict[str, np.ndarray]: + x = randn_np((shape["tokens"], shape["dim"]), dtype=dtype, seed=seed, scale=1.0) + return {"x": x} + + def reference_numpy( + self, inputs: dict[str, np.ndarray], op_params: dict[str, Any] + ) -> dict[str, np.ndarray]: + x = maybe_to_float32(inputs["x"]) + axis = int(op_params.get("axis", -1)) + y = _softmax(x, axis=axis) + return {"y": y.astype(np.float32)} diff --git a/src/zmlx/foundry/ops/swiglu.py b/src/zmlx/foundry/ops/swiglu.py new file mode 100644 index 0000000..bb2f54c --- /dev/null +++ b/src/zmlx/foundry/ops/swiglu.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np + +from ..taxonomy import KernelClass, OpSpec +from .base import ( + HIDDEN_LADDER, + TOKENS_LADDER, + KernelOp, + maybe_to_float32, + randn_np, +) + +# --------------------------------------------------------------------------- +# Knob space constants (from datafoundry swiglu_knobs) +# --------------------------------------------------------------------------- + +SWIGLU_TEMPLATES = ["t0_basic", "t1_unrolled"] +TG_SIZES = [32, 64, 128, 256] +VECS = [1, 2, 4] +UNROLLS = [1, 2, 4] + + +def _silu(x: np.ndarray) -> np.ndarray: + return x / (1.0 + np.exp(-x)) + + +class SwiGLUOp(KernelOp): + name = "swiglu" + + def spec(self) -> OpSpec: + return OpSpec( + name=self.name, + kernel_class=KernelClass.FUSED_POINTWISE, + summary="SwiGLU gate: y = silu(a) * b where x=[a|b]", + inputs=["x[tokens,2*hidden]"], + outputs=["y[tokens,hidden]"], + op_params_schema={}, + shape_hints={"tokens": TOKENS_LADDER, "hidden": HIDDEN_LADDER}, + dtype_hints=["float16", "bfloat16", "float32"], + templates=SWIGLU_TEMPLATES, + ) + + def supported_dtypes(self) -> list[str]: + return ["float16", "bfloat16", "float32"] + + def templates(self) -> list[str]: + return SWIGLU_TEMPLATES + + def knob_space(self, template_id: str) -> dict[str, Any]: + space: dict[str, Any] = { + "tg_size": {"type": "int", 
"values": TG_SIZES}, + "unroll": {"type": "int", "values": UNROLLS}, + "fast_math": {"type": "bool"}, + } + if template_id == "t1_unrolled": + space["vec"] = {"type": "int", "values": VECS} + return space + + def sample_shape(self, rng: np.random.Generator) -> dict[str, int]: + tokens = int(rng.choice(TOKENS_LADDER)) + hidden = int(rng.choice(HIDDEN_LADDER)) + return {"tokens": tokens, "hidden": hidden} + + def generate_inputs_numpy( + self, + shape: dict[str, int], + dtype: str, + op_params: dict[str, Any], + seed: int, + ) -> dict[str, np.ndarray]: + x = randn_np((shape["tokens"], shape["hidden"] * 2), dtype=dtype, seed=seed, scale=0.5) + return {"x": x} + + def reference_numpy( + self, + inputs: dict[str, np.ndarray], + op_params: dict[str, Any], + ) -> dict[str, np.ndarray]: + x = maybe_to_float32(inputs["x"]) + h2 = x.shape[-1] + h = h2 // 2 + x0 = x[..., :h] + x1 = x[..., h:2 * h] + sig = 1.0 / (1.0 + np.exp(-x0)) + y = (x0 * sig) * x1 + return {"y": y.astype(np.float32)} + + def validate_knobs( + self, template_id: str, knobs: dict[str, Any], shape: dict[str, int], dtype: str + ) -> tuple[bool, str]: + tg = int(knobs.get("tg_size", 0)) + if tg not in TG_SIZES: + return False, "invalid_tg_size" + if tg & (tg - 1) != 0: + return False, "tg_size_not_pow2" + unroll = int(knobs.get("unroll", 1)) + if unroll not in UNROLLS: + return False, "invalid_unroll" + vec = int(knobs.get("vec", 1)) + if template_id == "t1_unrolled": + if vec not in VECS: + return False, "invalid_vec" + else: + if vec != 1: + return False, "vec_not_supported" + return True, "" + + def bytes_and_flops(self, shape: dict[str, int], dtype: str) -> tuple[int, int]: + tokens = int(shape.get("tokens", shape.get("batch", 1) * shape.get("seq", 1))) + h = int(shape["hidden"]) + bytes_per_elem = 4 if dtype == "float32" else 2 + # reads 2H inputs, writes H output + bytes_rw = tokens * (2 * h * bytes_per_elem + h * bytes_per_elem) + # flops per output elem: sigmoid approx + muls + flops = tokens * h * 12 
+ return int(bytes_rw), int(flops) diff --git a/src/zmlx/foundry/ops/topk.py b/src/zmlx/foundry/ops/topk.py new file mode 100644 index 0000000..f0559f9 --- /dev/null +++ b/src/zmlx/foundry/ops/topk.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np + +from ..taxonomy import KernelClass, OpSpec +from .base import TOKENS_LADDER, KernelOp, maybe_to_float32, randn_np + + +class TopKOp(KernelOp): + name = "topk" + + def spec(self) -> OpSpec: + return OpSpec( + name=self.name, + kernel_class=KernelClass.REDUCTION, + summary="TopK along last axis", + inputs=["x[tokens,dim]"], + outputs=["values[tokens,k]", "indices[tokens,k]"], + op_params_schema={"k": {"type": "int", "enum": [2, 4, 8], "default": 4}}, + shape_hints={"tokens": TOKENS_LADDER, "dim": [8, 16, 32, 64, 128, 256]}, + dtype_hints=["float16", "bfloat16", "float32"], + templates=["ref"], + ) + + def supported_dtypes(self) -> list[str]: + return ["float16", "bfloat16", "float32"] + + def sample_shape(self, rng: np.random.Generator) -> dict[str, int]: + tokens = int(rng.choice(TOKENS_LADDER)) + dim = int(rng.choice([8, 16, 32, 64, 128])) + return {"tokens": tokens, "dim": dim} + + def sample_op_params(self, shape: dict[str, int], rng: np.random.Generator) -> dict[str, Any]: + k = int(rng.choice([2, 4, 8])) + k = min(k, shape["dim"]) + return {"k": k} + + def generate_inputs_numpy( + self, shape: dict[str, int], dtype: str, op_params: dict[str, Any], seed: int + ) -> dict[str, np.ndarray]: + x = randn_np((shape["tokens"], shape["dim"]), dtype=dtype, seed=seed, scale=1.0) + return {"x": x} + + def reference_numpy( + self, inputs: dict[str, np.ndarray], op_params: dict[str, Any] + ) -> dict[str, np.ndarray]: + x = maybe_to_float32(inputs["x"]) + k = int(op_params.get("k", 4)) + idx = np.argsort(-x, axis=-1)[:, :k] + vals = np.take_along_axis(x, idx, axis=-1) + return {"values": vals.astype(np.float32), "indices": idx.astype(np.int32)} diff --git 
a/src/zmlx/foundry/plugins/__init__.py b/src/zmlx/foundry/plugins/__init__.py new file mode 100644 index 0000000..9970f6e --- /dev/null +++ b/src/zmlx/foundry/plugins/__init__.py @@ -0,0 +1,30 @@ +"""Plugin system for the ZMLX foundry. + +Provides Protocol classes defining the plugin contracts, plus entry-point +and local-module loaders for extensibility. +""" +from __future__ import annotations + +from .protocols import ( + DiscoveryContext, + DiscoveryPlugin, + ExportArtifacts, + ExportPlugin, + FoundryContext, + FoundryPlugin, + MoEPlugin, +) +from .registry import LoadedPlugin, load_entrypoint_plugins, load_local_plugin + +__all__ = [ + "DiscoveryPlugin", + "FoundryPlugin", + "MoEPlugin", + "ExportPlugin", + "DiscoveryContext", + "FoundryContext", + "ExportArtifacts", + "LoadedPlugin", + "load_entrypoint_plugins", + "load_local_plugin", +] diff --git a/src/zmlx/foundry/plugins/protocols.py b/src/zmlx/foundry/plugins/protocols.py new file mode 100644 index 0000000..de4f464 --- /dev/null +++ b/src/zmlx/foundry/plugins/protocols.py @@ -0,0 +1,117 @@ +"""Plugin protocol definitions for the ZMLX foundry. + +Each protocol defines a minimal contract that plugins must satisfy. +Using ``Protocol`` (structural subtyping) rather than ABC so that +external plugin authors do not need to inherit from anything. 
+""" +from __future__ import annotations + +from collections.abc import Callable, Iterator +from dataclasses import dataclass, field +from typing import Any, Protocol + +# --------------------------------------------------------------------------- +# Context objects passed to plugins +# --------------------------------------------------------------------------- + +@dataclass +class DiscoveryContext: + """Context passed to discovery plugins.""" + + op: str + dtype: str + shape_space: dict[str, Any] + constraints: dict[str, Any] = field(default_factory=dict) + seed: int = 0 + + +@dataclass +class FoundryContext: + """Context passed to foundry plugins.""" + + op: str + dtype: str + curriculum: dict[str, Any] = field(default_factory=dict) + seed: int = 0 + + +# --------------------------------------------------------------------------- +# Candidate / artifact types +# --------------------------------------------------------------------------- + +@dataclass(frozen=True) +class Candidate: + """A proposed kernel candidate from a discovery plugin.""" + + template_id: str + knobs: dict[str, Any] + notes: str | None = None + + +@dataclass(frozen=True) +class AttemptSpec: + """A fully specified attempt plan entry consumed by the harness. + + Carries everything needed to compile, run, and validate a single + kernel attempt. 
+ """ + + op: str + dtype: str + shape: dict[str, Any] + spec: dict[str, Any] + template_id: str + knobs: dict[str, Any] + compile_flags: list[str] = field(default_factory=list) + launch: dict[str, Any] = field(default_factory=dict) + + +@dataclass(frozen=True) +class ExportArtifacts: + """Describes the output of an export plugin.""" + + out_dir: str + files: list[str] + n_records: int + + +# --------------------------------------------------------------------------- +# Plugin protocols +# --------------------------------------------------------------------------- + +class DiscoveryPlugin(Protocol): + """Proposes kernel candidates and learns from evaluation results.""" + + def propose_candidates(self, context: DiscoveryContext) -> list[Candidate]: + ... + + def update(self, context: DiscoveryContext, results: list[dict[str, Any]]) -> None: + ... + + +class FoundryPlugin(Protocol): + """Generates a full attempt plan for dataset construction.""" + + def generate_attempt_plan( + self, + target_attempts: int, + curriculum: dict[str, Any], + ) -> Iterator[AttemptSpec]: + ... + + +class MoEPlugin(Protocol): + """Registers MoE-specific ops and provides property tests.""" + + def register_ops(self, registry: Any) -> None: + ... + + def property_tests(self) -> list[Callable[[], None]]: + ... + + +class ExportPlugin(Protocol): + """Exports session data to a target format.""" + + def export(self, session_dir: str, out_dir: str) -> ExportArtifacts: + ... diff --git a/src/zmlx/foundry/plugins/registry.py b/src/zmlx/foundry/plugins/registry.py new file mode 100644 index 0000000..dc9eb31 --- /dev/null +++ b/src/zmlx/foundry/plugins/registry.py @@ -0,0 +1,80 @@ +"""Plugin discovery and loading for the ZMLX foundry. + +Supports two loading mechanisms: + +1. **Entry points** (``zmlx.foundry.plugins`` group): external packages + declare plugins via ``[project.entry-points."zmlx.foundry.plugins"]`` + in their pyproject.toml. + +2. 
@dataclass(frozen=True)
class LoadedPlugin:
    """Container for a loaded plugin reference."""

    name: str  # entry-point name or "module[:attr]" path
    obj: Any   # the plugin object itself (module, class, or instance)


# Entry-point group name. External packages register under this group
# to be auto-discovered by the foundry.
_EP_GROUP = "zmlx.foundry.plugins"


def load_entrypoint_plugins(group: str = _EP_GROUP) -> list[LoadedPlugin]:
    """Load all plugins registered under the given entry-point group.

    Silently skips plugins that fail to load (broken install, missing
    dependency, etc.) to avoid crashing the entire foundry when one
    plugin is broken.
    """
    plugins: list[LoadedPlugin] = []
    try:
        eps = importlib.metadata.entry_points()
        # entry_points().select() exists on Python 3.10+; on 3.9 the call
        # returns a dict-like mapping of group name -> entry points.
        if hasattr(eps, "select"):
            group_eps = eps.select(group=group)
        else:
            group_eps = eps.get(group, [])  # type: ignore[union-attr]
        for ep in group_eps:
            try:
                plugins.append(LoadedPlugin(name=ep.name, obj=ep.load()))
            except Exception:
                # Best-effort: one broken plugin must not take down discovery.
                # Log at debug level in production; here we silently skip.
                continue
    except Exception:
        pass
    return plugins


def load_local_plugin(
    module_path: str,
    attr: str | None = None,
) -> LoadedPlugin:
    """Load a plugin from a local dotted module path.

    Parameters
    ----------
    module_path : str
        Fully qualified module path, e.g. ``"zmlx.foundry.myplugin"``.
    attr : str, optional
        Attribute name within the module. If omitted, the module itself
        is returned as the plugin object.

    Returns
    -------
    LoadedPlugin

    Raises
    ------
    ModuleNotFoundError
        If *module_path* cannot be imported.
    AttributeError
        If *attr* is given but missing from the module.
    """
    module = importlib.import_module(module_path)
    plugin_obj = getattr(module, attr) if attr else module
    label = f"{module_path}:{attr}" if attr else module_path
    return LoadedPlugin(name=label, obj=plugin_obj)
def _iter_attempts(session_dir: Path):
    """Yield attempt records from every NDJSON file in *session_dir*.

    Understands three layouts: a single ``attempts.ndjson``, per-worker
    ``attempts.worker*.ndjson`` shards, and an ``attempts/`` sub-directory
    (zmlx-kernel-foundry style).
    """
    single = session_dir / "attempts.ndjson"
    if single.exists():
        yield from iter_records(single)

    for shard in sorted(session_dir.glob("attempts.worker*.ndjson")):
        yield from iter_records(shard)

    subdir = session_dir / "attempts"
    if subdir.is_dir():
        for path in sorted(subdir.glob("*.ndjson")):
            yield from iter_records(path)


def _infer_shape_class(op: str, shape: dict[str, int]) -> str:
    """Lightweight shape-class string when the op registry is unavailable.

    *op* is currently unused; it is kept for signature parity with
    registry-backed classifiers.
    """
    batch = shape.get("batch", 1)
    seq = shape.get("seq", shape.get("tokens", 1))
    hidden = shape.get("hidden", 0)
    layout = "contig" if bool(shape.get("contiguous", True)) else "strided"
    return "_".join([f"b{batch}", f"s{seq}", f"h{hidden}", layout])


# ---------------------------------------------------------------------------
# Build coverage dict
# ---------------------------------------------------------------------------

def build_coverage(session_dir: Path) -> dict[str, Any]:
    """Compute coverage statistics for a session.

    Returns a dict with keys:

    - ``session_dir``: str
    - ``ops``: per-op summary dicts
    - ``combos``: sorted list of per-combo dicts
    """
    per_op: dict[str, dict[str, Any]] = {}
    by_key: dict[tuple, dict[str, Any]] = defaultdict(
        lambda: {"attempts": 0, "ok": 0, "best_latency_ms": None},
    )

    for rec in _iter_attempts(session_dir):
        op = rec.get("op")
        if not op:
            continue

        # First sighting of an op creates its summary entry.
        if op not in per_op:
            per_op[op] = {
                "kernel_class": rec.get("kernel_class", "?"),
                "attempts": 0,
                "ok": 0,
                "best_template_latency_ms": None,
                "dtypes_seen": set(),
                "shape_classes_seen": set(),
                "templates_seen": set(),
            }

        dtype = rec.get("dtype", "?")
        template = rec.get("template_id", rec.get("kernel", {}).get("template_id", "ref"))
        shape_class = rec.get("shape_class") or _infer_shape_class(op, rec.get("shape", {}))

        summary = per_op[op]
        summary["attempts"] += 1
        summary["dtypes_seen"].add(dtype)
        summary["shape_classes_seen"].add(shape_class)
        summary["templates_seen"].add(template)

        # Success is recognized in both Foundry (result.status) and
        # DataFoundry (correctness/bench) record layouts.
        res = rec.get("result", {})
        is_ok = (res.get("status") == "ok") or (
            rec.get("correctness", {}).get("ok", False)
            and rec.get("bench", {}).get("ok", False)
        )

        key = (op, dtype, shape_class, template)
        by_key[key]["attempts"] += 1

        if is_ok:
            summary["ok"] += 1
            by_key[key]["ok"] += 1

            # Latency may live in either record layout.
            lat = (
                res.get("template_latency_ms")
                or rec.get("bench", {}).get("latency_ms", {}).get("p50")
            )
            if isinstance(lat, (int, float)):
                prev = by_key[key]["best_latency_ms"]
                by_key[key]["best_latency_ms"] = lat if prev is None else min(prev, lat)
                prev_op = summary["best_template_latency_ms"]
                summary["best_template_latency_ms"] = lat if prev_op is None else min(prev_op, lat)

    # JSON cannot serialize sets: freeze them into sorted lists.
    for summary in per_op.values():
        summary["dtypes_seen"] = sorted(summary["dtypes_seen"])
        summary["shape_classes_seen"] = sorted(summary["shape_classes_seen"])
        summary["templates_seen"] = sorted(summary["templates_seen"])

    combos: list[dict[str, Any]] = [
        {
            "op": op,
            "dtype": dtype,
            "shape_class": shape_class,
            "template_id": template,
            "attempts": stats["attempts"],
            "ok": stats["ok"],
            "best_latency_ms": stats["best_latency_ms"],
        }
        for (op, dtype, shape_class, template), stats in by_key.items()
    ]
    combos.sort(key=lambda c: (c["op"], c["dtype"], c["shape_class"], c["template_id"]))

    return {
        "session_dir": str(session_dir),
        "ops": per_op,
        "combos": combos,
    }
def write_coverage_reports(session_dir: Path) -> tuple[Path, Path]:
    """Generate and write coverage reports to the session directory.

    Writes ``coverage.json`` (machine-readable) and ``coverage.md``
    (human-readable tables) and returns (markdown_path, json_path).
    """
    cov = build_coverage(session_dir)

    out_json = session_dir / "coverage.json"
    out_md = session_dir / "coverage.md"

    out_json.write_text(
        json.dumps(cov, indent=2, ensure_ascii=False), encoding="utf-8"
    )

    # -- Markdown report ---------------------------------------------------
    report: list[str] = [
        f"# Coverage Report -- {session_dir.name}\n",
        "## Ops\n",
        "| op | class | attempts | ok | best_latency_ms | dtypes | shape_classes | templates |",
        "|---|---|---:|---:|---:|---|---|---|",
    ]

    for op, summary in sorted(cov["ops"].items()):
        best = summary["best_template_latency_ms"]
        best_cell = f"{best:.4f}" if isinstance(best, (int, float)) else "-"
        shape_classes = summary["shape_classes_seen"]
        # Only the first three shape classes are shown, to keep rows short.
        sc_cell = ", ".join(shape_classes[:3])
        if len(shape_classes) > 3:
            sc_cell += "..."
        report.append(
            f"| {op} | {summary['kernel_class']} | {summary['attempts']} | {summary['ok']}"
            f" | {best_cell}"
            f" | {', '.join(summary['dtypes_seen'])}"
            f" | {sc_cell}"
            f" | {', '.join(summary['templates_seen'])} |"
        )

    # Group ops by kernel class
    report.append("\n## Kernel classes\n")
    by_class: dict[str, list[str]] = defaultdict(list)
    for op, summary in cov["ops"].items():
        by_class[summary["kernel_class"]].append(op)

    for kernel_class in sorted(by_class):
        report.append(f"### Class {kernel_class}\n")
        report.append(", ".join(sorted(by_class[kernel_class])))
        report.append("")

    out_md.write_text("\n".join(report), encoding="utf-8")

    return out_md, out_json
+""" +from __future__ import annotations + +from collections.abc import Iterable +from pathlib import Path +from typing import Any + +from ..ndjson import iter_records + +# --------------------------------------------------------------------------- +# Attempt loading +# --------------------------------------------------------------------------- + +def load_attempts(session_dir: str) -> list[dict[str, Any]]: + """Load all attempt records from a session directory.""" + p = Path(session_dir) + out: list[dict[str, Any]] = [] + + # Single-file layout + single = p / "attempts.ndjson" + if single.exists(): + out.extend(iter_records(single)) + + # Worker shards + for wp in sorted(p.glob("attempts.worker*.ndjson")): + out.extend(iter_records(wp)) + + # Sub-directory layout + attempts_dir = p / "attempts" + if attempts_dir.is_dir(): + for fp in sorted(attempts_dir.glob("*.ndjson")): + out.extend(iter_records(fp)) + + return out + + +# --------------------------------------------------------------------------- +# Filter helpers +# --------------------------------------------------------------------------- + +def _is_successful(rec: dict[str, Any]) -> bool: + """Check whether a record represents a successful attempt. + + Supports both DataFoundry layout (build/correctness/bench nested dicts) + and Foundry layout (result.status == "ok"). 
+ """ + # DataFoundry layout + if rec.get("build", {}).get("ok") and rec.get("correctness", {}).get("ok") and rec.get("bench", {}).get("ok"): + return True + # Foundry layout + if rec.get("result", {}).get("status") == "ok": + return True + return False + + +def _get_p50(rec: dict[str, Any]) -> float | None: + """Extract p50 latency from a record (multiple layout conventions).""" + # DataFoundry: bench.latency_ms.p50 + lat = rec.get("bench", {}).get("latency_ms") + if isinstance(lat, dict): + p50 = lat.get("p50") + if p50 is not None: + return float(p50) + # Foundry: result.template_latency_ms + rlat = rec.get("result", {}).get("template_latency_ms") + if isinstance(rlat, (int, float)): + return float(rlat) + return None + + +# --------------------------------------------------------------------------- +# Best kernel by p50 +# --------------------------------------------------------------------------- + +def best_kernel_by_p50( + attempts: Iterable[dict[str, Any]], + *, + op: str | None = None, +) -> dict[str, Any] | None: + """Return the fastest successful attempt by p50 latency. + + Parameters + ---------- + attempts : iterable of dicts + Attempt records. + op : str, optional + If given, filter to attempts for this op only. 
+ """ + best: dict[str, Any] | None = None + best_p50: float | None = None + + for rec in attempts: + if op is not None and rec.get("op") != op: + continue + if not _is_successful(rec): + continue + p50 = _get_p50(rec) + if p50 is None: + continue + if best is None or p50 < best_p50: # type: ignore[operator] + best = rec + best_p50 = p50 + + return best + + +# --------------------------------------------------------------------------- +# Pareto front extraction +# --------------------------------------------------------------------------- + +def extract_pareto_front( + attempts: Iterable[dict[str, Any]], + *, + op: str | None = None, + objectives: list[str] | None = None, +) -> list[dict[str, Any]]: + """Extract the Pareto-optimal subset of successful attempts. + + By default, minimizes p50 latency and maximizes correctness quality + (lower max_abs_err is better). Custom objectives can be specified. + + Parameters + ---------- + attempts : iterable of dicts + Attempt records. + op : str, optional + Filter to this op. + objectives : list of str, optional + Objective keys to extract from records. Each objective is + minimized. Defaults to ``["p50_latency_ms", "max_abs_err"]``. + + Returns + ------- + list of dict + The Pareto-non-dominated records. + """ + if objectives is None: + objectives = ["p50_latency_ms", "max_abs_err"] + + # Collect (objective_vector, record) pairs for successful attempts + candidates: list[tuple[list[float], dict[str, Any]]] = [] + + for rec in attempts: + if op is not None and rec.get("op") != op: + continue + if not _is_successful(rec): + continue + + obj_vec: list[float] = [] + valid = True + for obj_key in objectives: + val = _extract_objective(rec, obj_key) + if val is None: + valid = False + break + obj_vec.append(val) + + if valid: + candidates.append((obj_vec, rec)) + + if not candidates: + return [] + + # Classic Pareto filter: a point is dominated if another point is + # <= on all objectives and < on at least one. 
+ n_obj = len(objectives) + pareto: list[dict[str, Any]] = [] + + for i, (vec_i, rec_i) in enumerate(candidates): + dominated = False + for j, (vec_j, _) in enumerate(candidates): + if i == j: + continue + # Check if j dominates i (all <= and at least one <) + all_leq = all(vec_j[k] <= vec_i[k] for k in range(n_obj)) + any_lt = any(vec_j[k] < vec_i[k] for k in range(n_obj)) + if all_leq and any_lt: + dominated = True + break + if not dominated: + pareto.append(rec_i) + + return pareto + + +def _extract_objective(rec: dict[str, Any], key: str) -> float | None: + """Extract a named objective value from a record.""" + if key == "p50_latency_ms": + return _get_p50(rec) + if key == "max_abs_err": + corr = rec.get("correctness", {}) + mae = corr.get("max_abs_err") + if mae is not None: + return float(mae) + # Foundry layout: result.max_abs_err + mae2 = rec.get("result", {}).get("max_abs_err") + if mae2 is not None: + return float(mae2) + # If correctness is ok but no error metric, use 0.0 + if corr.get("ok") or rec.get("result", {}).get("status") == "ok": + return 0.0 + return None + if key == "max_rel_err": + corr = rec.get("correctness", {}) + mre = corr.get("max_rel_err") + if mre is not None: + return float(mre) + return None + # Generic: try to find key in top-level, result, or bench + for loc in [rec, rec.get("result", {}), rec.get("bench", {}), rec.get("correctness", {})]: + if key in loc: + try: + return float(loc[key]) + except (TypeError, ValueError): + pass + return None diff --git a/src/zmlx/foundry/sampling/__init__.py b/src/zmlx/foundry/sampling/__init__.py new file mode 100644 index 0000000..61a5f75 --- /dev/null +++ b/src/zmlx/foundry/sampling/__init__.py @@ -0,0 +1,11 @@ +"""Sampling and search policies for the ZMLX foundry. + +Provides random/coverage/mutation/mix sampling of kernel candidates, +plus CEM and random search policies for knob space exploration. 
+""" +from __future__ import annotations + +from .policies import CEMPolicy, RandomPolicy +from .sampler import Sampler + +__all__ = ["Sampler", "CEMPolicy", "RandomPolicy"] diff --git a/src/zmlx/foundry/sampling/coverage.py b/src/zmlx/foundry/sampling/coverage.py new file mode 100644 index 0000000..5b7bf7d --- /dev/null +++ b/src/zmlx/foundry/sampling/coverage.py @@ -0,0 +1,75 @@ +"""Shape and layout coverage generators for systematic exploration. + +Generates deterministic sequences of shapes and layouts to ensure coverage +across the (batch, seq, hidden) space, with occasional odd/adversarial +shapes to test edge cases. +""" +from __future__ import annotations + +import random +from typing import Any + +# -- Standard shape ladders for LLM inference workloads ----------------------- + +HIDDEN_SIZES: list[int] = [768, 1024, 1536, 2048, 2560, 3072, 4096, 5120, 6144, 8192] +SEQ_LENS: list[int] = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] +BATCHES: list[int] = [1, 2, 4, 8] + +# Edge-case sizes to probe non-power-of-two handling, off-by-one, etc. +ODD_HIDDEN: list[int] = [769, 1025, 1537, 2053, 2561, 3073, 4097] +ODD_SEQ: list[int] = [3, 5, 7, 9, 17, 33, 65, 127, 257, 513, 1025] + + +def shape_for_index( + i: int, + rng: random.Random, + op: str, + *, + extra_dims: dict[str, list[int]] | None = None, +) -> dict[str, int]: + """Deterministic shape for index *i*, cycling through standard ladders. + + Parameters + ---------- + i : int + Iteration index used for deterministic cycling. + rng : random.Random + Seeded RNG for occasional odd-shape injection. + op : str + Operation name; used to add op-specific shape keys. + extra_dims : dict, optional + Additional dimension ladders keyed by name. Each value list is + cycled through independently. E.g. ``{"n_experts": [8, 16, 32]}``. 
+ """ + use_odd = (i % 11 == 0) # ~9% of samples get adversarial shapes + batch = BATCHES[i % len(BATCHES)] + + if use_odd: + hidden = rng.choice(ODD_HIDDEN) + seq = rng.choice(ODD_SEQ) + else: + hidden = HIDDEN_SIZES[(i // len(SEQ_LENS)) % len(HIDDEN_SIZES)] + seq = SEQ_LENS[i % len(SEQ_LENS)] + + shape: dict[str, int] = {"batch": int(batch), "seq": int(seq), "hidden": int(hidden)} + + # Op-specific shape augmentations + if op == "swiglu": + shape["hidden_in"] = int(2 * hidden) + + # Plug in any caller-supplied extra dimension ladders + if extra_dims: + for dim_name, ladder in extra_dims.items(): + shape[dim_name] = ladder[i % len(ladder)] + + return shape + + +def layout_for_index(i: int, rng: random.Random) -> dict[str, Any]: + """Deterministic layout for index *i*. + + ~15% of samples are non-contiguous (strided) to exercise Metal + stride-handling code paths. + """ + strided = rng.random() < 0.15 + return {"contiguous": (not strided), "strides": []} diff --git a/src/zmlx/foundry/sampling/mutate.py b/src/zmlx/foundry/sampling/mutate.py new file mode 100644 index 0000000..78be875 --- /dev/null +++ b/src/zmlx/foundry/sampling/mutate.py @@ -0,0 +1,90 @@ +"""Knob mutation for elite-evolution sampling. + +Operates on a generic knob space (Dict[str, List[Any]]) rather than +hardcoded per-op types. A single knob field is perturbed per mutation +by stepping +/- 1 in its ordered value list. Boolean knobs are toggled. 
+""" +from __future__ import annotations + +import copy +import random +from typing import Any + + +def _mutate_choice(rng: random.Random, val: Any, choices: list[Any]) -> Any: + """Step the current value one position in *choices* (wrapping).""" + if val not in choices: + return rng.choice(choices) + idx = choices.index(val) + step = rng.choice([-1, 1]) + return choices[(idx + step) % len(choices)] + + +def mutate( + knobs: dict[str, Any], + knob_space: dict[str, list[Any]], + rng: random.Random, + *, + inject_error_prob: float = 0.02, + inject_incorrect_prob: float = 0.04, +) -> dict[str, Any]: + """Mutate one random knob field from *knobs*, guided by *knob_space*. + + Parameters + ---------- + knobs : dict + Current knob values (will be deep-copied, not mutated in place). + knob_space : dict + Mapping of knob name to its ordered domain (list of valid values). + Boolean knobs are detected automatically and toggled. + rng : random.Random + Seeded RNG for reproducibility. + inject_error_prob : float + Probability of injecting a deliberate compile error (for dataset + diversity -- training on failures is valuable). + inject_incorrect_prob : float + Probability of injecting a deliberate correctness error. + + Returns + ------- + dict + New knobs dict with exactly one field mutated. 
+ """ + k = copy.deepcopy(knobs) + + # Only mutate fields that exist in both the knob space and the current knobs + mutable_fields = [f for f in knob_space if f in k] + if not mutable_fields: + # Nothing to mutate -- return a copy with possible fault injection + return _maybe_inject(k, rng, inject_error_prob, inject_incorrect_prob) + + field = rng.choice(mutable_fields) + domain = list(knob_space[field]) + + # Boolean toggle shortcut + if set(domain) == {True, False} or set(domain) == {0, 1}: + k[field] = not bool(k.get(field, False)) + else: + k[field] = _mutate_choice(rng, k.get(field), domain) + + return _maybe_inject(k, rng, inject_error_prob, inject_incorrect_prob) + + +def _maybe_inject( + k: dict[str, Any], + rng: random.Random, + error_prob: float, + incorrect_prob: float, +) -> dict[str, Any]: + """Optionally inject compile-error or correctness-error flags.""" + r = rng.random() + if r < error_prob: + k["inject_compile_error"] = True + k["inject_incorrect"] = False + elif r < error_prob + incorrect_prob: + k["inject_compile_error"] = False + k["inject_incorrect"] = True + else: + k["inject_compile_error"] = False + k["inject_incorrect"] = False + return k diff --git a/src/zmlx/foundry/sampling/policies.py b/src/zmlx/foundry/sampling/policies.py new file mode 100644 index 0000000..2e3d5fb --- /dev/null +++ b/src/zmlx/foundry/sampling/policies.py @@ -0,0 +1,250 @@ +"""Search policies for knob-space exploration. + +Generalized from Discover's RmsnormKnobSpace-specific policies to accept +any ``Dict[str, List[Any]]`` knob space. The CEM math (smoothing, elite +fraction, minimum probability) is preserved identically. 
+""" +from __future__ import annotations + +import random +from collections.abc import Mapping, Sequence +from dataclasses import dataclass +from typing import Any + +# --------------------------------------------------------------------------- +# Shared types +# --------------------------------------------------------------------------- + +@dataclass(frozen=True, slots=True) +class SampledVariant: + """A single sampled point in the knob space.""" + knobs: dict[str, Any] + + +# --------------------------------------------------------------------------- +# Utility +# --------------------------------------------------------------------------- + +def _sample_categorical( + rng: random.Random, + items: Sequence[Any], + probs: Sequence[float], +) -> Any: + """Sample one item from *items* according to *probs*.""" + r = rng.random() + acc = 0.0 + for item, p in zip(items, probs, strict=False): + acc += p + if r <= acc: + return item + return items[-1] + + +# --------------------------------------------------------------------------- +# RandomPolicy +# --------------------------------------------------------------------------- + +class RandomPolicy: + """Uniform-random sampling over a discrete knob space. + + Parameters + ---------- + knob_space : dict + Mapping of knob name to its domain (list of allowed values). 
+ Example: ``{"tg_size": [64, 128, 256], "vec": [1, 2, 4]}`` + """ + + def __init__(self, knob_space: dict[str, list[Any]]) -> None: + self.knob_space = knob_space + + def state(self) -> dict[str, Any]: + return {"type": "random"} + + def sample(self, rng: random.Random, n: int) -> list[SampledVariant]: + """Draw *n* independent uniform samples.""" + out: list[SampledVariant] = [] + for _ in range(int(n)): + knobs: dict[str, Any] = {} + for key, domain in self.knob_space.items(): + knobs[key] = rng.choice(list(domain)) + out.append(SampledVariant(knobs=knobs)) + return out + + def update(self, variants: Sequence[tuple[SampledVariant, float]]) -> None: + """No-op: random policy does not learn.""" + return + + +# --------------------------------------------------------------------------- +# CEMPolicy +# --------------------------------------------------------------------------- + +class CEMPolicy: + """Cross-Entropy Method over independent categorical knob dimensions. + + Maintains a factored probability distribution (one categorical per knob) + and updates toward elite samples with exponential smoothing. + + Parameters + ---------- + knob_space : dict + Mapping of knob name to its domain (list of allowed values). + elite_frac : float + Fraction of samples to treat as elite in each update. + smoothing : float + Weight on the *old* distribution during update (higher = more + conservative, slower adaptation). + min_prob : float + Floor probability for any value in any dimension, to prevent + premature convergence and maintain exploration. 
+ """ + + def __init__( + self, + knob_space: dict[str, list[Any]], + *, + elite_frac: float = 0.25, + smoothing: float = 0.7, + min_prob: float = 0.02, + ) -> None: + self.knob_space = knob_space + self.elite_frac = float(elite_frac) + self.smoothing = float(smoothing) + self.min_prob = float(min_prob) + + # Initialize uniform distribution over each knob dimension + self._dist: dict[str, dict[Any, float]] = {} + for key, domain in knob_space.items(): + n = len(domain) + self._dist[key] = {v: 1.0 / n for v in domain} + + # -- Serialization / deserialization ------------------------------------ + + @staticmethod + def from_state( + knob_space: dict[str, list[Any]], + state: Mapping[str, Any], + ) -> CEMPolicy: + """Reconstruct a CEMPolicy from a previously saved state dict.""" + pol = CEMPolicy( + knob_space, + elite_frac=float(state.get("elite_frac", 0.25)), + smoothing=float(state.get("smoothing", 0.7)), + min_prob=float(state.get("min_prob", 0.02)), + ) + dist = state.get("dist", {}) + for key, entries in dist.items(): + if key not in pol._dist or not isinstance(entries, list): + continue + new: dict[Any, float] = {} + for e in entries: + if not isinstance(e, Mapping): + continue + if "value" not in e or "prob" not in e: + continue + new[e["value"]] = float(e["prob"]) + if new: + pol._dist[key] = new + pol._renormalize(key) + return pol + + def state(self) -> dict[str, Any]: + """Serialize the policy state for persistence.""" + return { + "type": "cem", + "elite_frac": self.elite_frac, + "smoothing": self.smoothing, + "min_prob": self.min_prob, + "dist": { + k: [{"value": kk, "prob": float(vv)} for kk, vv in d.items()] + for k, d in self._dist.items() + }, + } + + # -- Sampling ----------------------------------------------------------- + + def sample(self, rng: random.Random, n: int) -> list[SampledVariant]: + """Draw *n* samples from the current distribution.""" + out: list[SampledVariant] = [] + for _ in range(int(n)): + knobs: dict[str, Any] = {} + for key, 
domain in self.knob_space.items(): + knobs[key] = self._sample_key(rng, key, list(domain)) + out.append(SampledVariant(knobs=knobs)) + return out + + # -- Update (the CEM step) --------------------------------------------- + + def update(self, variants: Sequence[tuple[SampledVariant, float]]) -> None: + """Update the distribution toward elite samples. + + Parameters + ---------- + variants : sequence of (SampledVariant, reward) tuples + Higher reward is better. + """ + if not variants: + return + + # Sort by reward descending + sorted_pairs = sorted(variants, key=lambda x: x[1], reverse=True) + elite_n = max(1, int(len(sorted_pairs) * self.elite_frac)) + elite = sorted_pairs[:elite_n] + + # Count elite frequencies per knob dimension + counts: dict[str, dict[Any, int]] = { + key: {} for key in self._dist + } + + for v, _reward in elite: + for key in self._dist: + val = v.knobs.get(key) + if val is not None: + counts[key][val] = counts[key].get(val, 0) + 1 + + # Update each categorical with exponential smoothing + for key, old in self._dist.items(): + total = float(sum(counts[key].values())) + if total <= 0.0: + continue + + new: dict[Any, float] = {} + for k in old.keys(): + target = float(counts[key].get(k, 0)) / total + new[k] = self.smoothing * old[k] + (1.0 - self.smoothing) * target + + # Enforce minimum probability to maintain exploration + for k in new: + new[k] = max(new[k], self.min_prob) + + self._dist[key] = new + self._renormalize(key) + + # -- Internal helpers --------------------------------------------------- + + def _renormalize(self, key: str) -> None: + """Re-normalize the distribution for *key* to sum to 1.""" + d = self._dist[key] + s = float(sum(d.values())) + if s <= 0: + n = float(len(d)) + for k in d: + d[k] = 1.0 / n + return + for k in d: + d[k] = d[k] / s + + def _sample_key( + self, + rng: random.Random, + key: str, + domain: Sequence[Any], + ) -> Any: + """Sample one value for knob *key* from the current distribution.""" + probs = 
[float(self._dist[key].get(k, 0.0)) for k in domain] + s = float(sum(probs)) + if s <= 0.0: + probs = [1.0 / len(domain)] * len(domain) + else: + probs = [p / s for p in probs] + return _sample_categorical(rng, domain, probs) diff --git a/src/zmlx/foundry/sampling/sampler.py b/src/zmlx/foundry/sampling/sampler.py new file mode 100644 index 0000000..5e99aa8 --- /dev/null +++ b/src/zmlx/foundry/sampling/sampler.py @@ -0,0 +1,208 @@ +"""Generalized kernel candidate sampler. + +Adapted from DataFoundry's hardcoded op-to-knobs dispatch to accept a +generic knob_space dict from the op registry. Supports four sampling +modes: random, coverage, mutation, and mix (cyclic blend of all three). +""" +from __future__ import annotations + +import random +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from ..ndjson import iter_records +from ..taxonomy import KernelCandidate +from .coverage import layout_for_index, shape_for_index +from .mutate import mutate + +# --------------------------------------------------------------------------- +# Dtype sampling +# --------------------------------------------------------------------------- + +DTYPES: list[str] = ["float16", "bfloat16", "float32"] + + +def dtype_for_index(i: int, rng: random.Random) -> str: + """Biased dtype sampling: ~45% fp16, ~35% bf16, ~20% fp32.""" + r = rng.random() + if r < 0.45: + return "float16" + if r < 0.80: + return "bfloat16" + return "float32" + + +# --------------------------------------------------------------------------- +# Elite loading (for mutation mode) +# --------------------------------------------------------------------------- + +def load_elite( + session_dir: Path, + op: str, + *, + top_k: int = 32, +) -> list[dict[str, Any]]: + """Load the top-k fastest correct attempts from the session log. + + Returns a list of dicts with ``template_id`` and ``knobs`` keys. 
+ """ + path = session_dir / "attempts.ndjson" + if not path.exists(): + return [] + + good: list[tuple[float, dict[str, Any]]] = [] + for rec in iter_records(path): + if rec.get("op") != op: + continue + if not rec.get("correctness", {}).get("ok", False): + continue + bench = rec.get("bench", {}) + if not bench.get("ok", False): + continue + p50 = bench.get("latency_ms", {}).get("p50") + if p50 is None: + continue + try: + p50f = float(p50) + except (TypeError, ValueError): + continue + good.append((p50f, rec)) + + good.sort(key=lambda t: t[0]) + elites: list[dict[str, Any]] = [] + for _, rec in good[:top_k]: + elites.append({ + "template_id": rec.get("kernel", {}).get("template_id"), + "knobs": rec.get("kernel", {}).get("knobs", {}), + }) + return elites + + +# --------------------------------------------------------------------------- +# Sampler +# --------------------------------------------------------------------------- + +@dataclass +class Sampler: + """Generates KernelCandidate instances using configurable strategies. + + Parameters + ---------- + op : str + Operation name (e.g. "rmsnorm", "swiglu"). + knob_space : dict + Mapping of knob name to list of valid values for that knob. + Obtained from the op registry rather than hardcoded here. + templates : list of str + Available template IDs for this op. + mode : str + Sampling mode: "random", "coverage", "mutation", or "mix". + seed : int + Base seed for deterministic sampling. + session_dir : Path + Session directory for loading elite candidates (mutation mode). + extra_shape_dims : dict, optional + Extra dimension ladders passed through to shape_for_index. 
+ """ + + op: str + knob_space: dict[str, list[Any]] + templates: list[str] + mode: str # random | coverage | mutation | mix + seed: int + session_dir: Path + extra_shape_dims: dict[str, list[int]] | None = None + + # Internal state (set in __post_init__) + rng: random.Random = field(init=False, repr=False) + _mix_cycle: int = field(init=False, default=0, repr=False) + + def __post_init__(self) -> None: + self.rng = random.Random(self.seed) + + def sample_knobs_random(self, template_id: str) -> dict[str, Any]: + """Uniform random sample from the knob space.""" + knobs: dict[str, Any] = {} + for key, domain in self.knob_space.items(): + knobs[key] = self.rng.choice(list(domain)) + return knobs + + def sample_knobs_coverage(self, template_id: str, i: int) -> dict[str, Any]: + """Stratified sweep through knob values for systematic coverage. + + Cycles through each dimension's domain at different rates using + prime-offset strides, ensuring broad coverage over many iterations. + """ + knobs: dict[str, Any] = {} + # Use different stride primes per dimension to avoid alignment + stride_primes = [1, 3, 7, 11, 13, 17, 19, 23, 29, 31] + for idx, (key, domain) in enumerate(self.knob_space.items()): + stride = stride_primes[idx % len(stride_primes)] + knobs[key] = domain[(i * stride) % len(domain)] + + # Low-probability fault injection for dataset diversity + knobs["inject_compile_error"] = (i % 97 == 0) + knobs["inject_incorrect"] = (i % 113 == 0) and not knobs["inject_compile_error"] + return knobs + + def next_candidate(self, i: int) -> KernelCandidate: + """Generate the *i*-th candidate. 
+ + The mode determines how shapes, layouts, and knobs are selected: + - **random**: all dimensions sampled uniformly + - **coverage**: deterministic cycling for systematic exploration + - **mutation**: perturb elite candidates + - **mix**: cycles coverage -> random -> random -> mutation + """ + mode = self.mode + if mode == "mix": + cyc = self._mix_cycle % 4 + self._mix_cycle += 1 + mode = ["coverage", "random", "random", "mutation"][cyc] + + dtype = dtype_for_index(i, self.rng) + template_id = self.templates[i % len(self.templates)] + + if mode == "coverage": + shape = shape_for_index( + i, self.rng, self.op, extra_dims=self.extra_shape_dims, + ) + layout = layout_for_index(i, self.rng) + knobs = self.sample_knobs_coverage(template_id, i) + + elif mode == "mutation": + elites = load_elite(self.session_dir, self.op, top_k=32) + if elites: + base = self.rng.choice(elites) + template_id = base.get("template_id") or template_id + knobs = mutate( + base.get("knobs", {}), + self.knob_space, + self.rng, + ) + else: + # No elites yet -- fall back to random knobs + knobs = self.sample_knobs_random(template_id) + # Mutation also samples new shapes to maintain shape coverage + shape = shape_for_index( + i, self.rng, self.op, extra_dims=self.extra_shape_dims, + ) + layout = layout_for_index(i, self.rng) + + else: + # random + shape = shape_for_index( + i, self.rng, self.op, extra_dims=self.extra_shape_dims, + ) + layout = layout_for_index(i, self.rng) + knobs = self.sample_knobs_random(template_id) + + return KernelCandidate( + op=self.op, + dtype=dtype, + shape=shape, + layout=layout, + template_id=template_id, + knobs=knobs, + ) diff --git a/src/zmlx/foundry/scheduler.py b/src/zmlx/foundry/scheduler.py new file mode 100644 index 0000000..4d68264 --- /dev/null +++ b/src/zmlx/foundry/scheduler.py @@ -0,0 +1,108 @@ +"""Curriculum scheduler for staged dataset generation. 
+ +Organizes kernel ops into stages ordered from cheap/simple to +expensive/complex, allowing generation runs to progressively unlock +harder ops. This mirrors LLM training curricula: start with easy +examples, add complexity as the model improves. +""" +from __future__ import annotations + +from dataclasses import dataclass + +# --------------------------------------------------------------------------- +# Default 5-stage curriculum +# --------------------------------------------------------------------------- + +DEFAULT_CURRICULUM: list[list[str]] = [ + # Stage 0: cheapest elementwise / fused norms + ["rmsnorm", "swiglu", "rope"], + # Stage 1: more expensive attention-adjacent ops + ["softmax", "layernorm"], + # Stage 2: quantization primitives + ["dequantize", "quantize"], + # Stage 3: data-movement / KV cache + ["gather", "scatter", "kv_append"], + # Stage 4: MoE routing + dispatch (most expensive, largest kernel families) + ["moe_topk", "moe_pack_assignments", "moe_dispatch_gather", "moe_combine_scatter"], +] + + +def flatten_curriculum(curriculum: list[list[str]] | None = None) -> list[str]: + """Flatten a nested curriculum into a single ordered list of op names.""" + if curriculum is None: + curriculum = DEFAULT_CURRICULUM + out: list[str] = [] + for stage in curriculum: + out.extend(stage) + return out + + +# --------------------------------------------------------------------------- +# Scheduler +# --------------------------------------------------------------------------- + +@dataclass +class CurriculumScheduler: + """Controls which ops are available at a given curriculum stage. + + Parameters + ---------- + ops : list of str + Requested ops (from config / CLI). The scheduler intersects these + with the curriculum to determine what is actually available. + curriculum : list of list of str, optional + Override the default 5-stage curriculum. + stage : int + Current stage (0-indexed). Stages 0..stage are unlocked. 
+ """ + + ops: list[str] + curriculum: list[list[str]] = None # type: ignore[assignment] + stage: int = 0 + + def __post_init__(self) -> None: + if self.curriculum is None: + self.curriculum = DEFAULT_CURRICULUM + + @property + def n_stages(self) -> int: + return len(self.curriculum) + + def set_stage(self, stage: int) -> None: + """Set the current curriculum stage (clamped to valid range).""" + self.stage = max(0, min(int(stage), self.n_stages - 1)) + + def available_ops(self) -> list[str]: + """Return ops unlocked at the current stage. + + The result is the intersection of requested ``ops`` and all ops + in curriculum stages 0 through ``stage`` (inclusive), preserving + curriculum order. Falls back to all requested ops if the + intersection would be empty. + """ + allowed: list[str] = [] + for i in range(min(self.stage + 1, self.n_stages)): + for op in self.curriculum[i]: + if op in self.ops and op not in allowed: + allowed.append(op) + if not allowed: + # Fallback: no curriculum overlap, return all requested ops + return list(self.ops) + return allowed + + def stage_for_op(self, op: str) -> int | None: + """Return the curriculum stage that introduces *op*, or None.""" + for i, stage_ops in enumerate(self.curriculum): + if op in stage_ops: + return i + return None + + def summary(self) -> dict[str, object]: + """Diagnostic summary for logging.""" + return { + "n_stages": self.n_stages, + "current_stage": self.stage, + "requested_ops": self.ops, + "available_ops": self.available_ops(), + "curriculum": self.curriculum, + } diff --git a/src/zmlx/foundry/session.py b/src/zmlx/foundry/session.py new file mode 100644 index 0000000..afcac47 --- /dev/null +++ b/src/zmlx/foundry/session.py @@ -0,0 +1,161 @@ +"""Session management for foundry runs. + +A Session owns a session directory, manages NDJSON logging, compile-cache +tracking, and existing-ID dedup. 
It combines the session-setup patterns
from datafoundry's ``run.py``, the kernel-foundry's ``SessionWriter``, and
the general NDJSON utilities already in this package.
"""
from __future__ import annotations

import datetime as _dt
import hashlib
import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

from .ndjson import append_record, load_existing_ids


def _utc_now_iso() -> str:
    """ISO-8601 UTC timestamp string."""
    return _dt.datetime.now(_dt.timezone.utc).isoformat()


@dataclass
class Session:
    """Manages a single foundry session on disk.

    Directory layout::

        <session_dir>/
            attempts.ndjson            # primary log
            attempts.worker*.ndjson    # per-worker shards (multi-worker)
            kernels/                   # saved Metal source files
            cache/                     # compile-cache metadata
            reports/                   # generated reports

    Parameters
    ----------
    session_dir : Path
        Root directory for this session.
    worker_id : int
        Worker ID for multi-worker runs (0 for single-worker).
    num_workers : int
        Total number of workers.
    """

    session_dir: Path
    worker_id: int = 0
    num_workers: int = 1

    # Internal state
    # _existing_ids: attempt IDs already present on disk (loaded at init)
    # _compile_cache: in-memory mirror of the on-disk cache/*.ok markers
    # _written: records appended by THIS instance (not the on-disk total)
    _existing_ids: set[str] = field(init=False, repr=False, default_factory=set)
    _compile_cache: dict[str, bool] = field(init=False, repr=False, default_factory=dict)
    _written: int = field(init=False, repr=False, default=0)
    _created_at: str = field(init=False, repr=False, default="")

    def __post_init__(self) -> None:
        # Creating the directory tree is idempotent, so re-opening an
        # existing session is safe.
        self.session_dir = Path(self.session_dir)
        self.session_dir.mkdir(parents=True, exist_ok=True)
        (self.session_dir / "kernels").mkdir(exist_ok=True)
        (self.session_dir / "cache").mkdir(exist_ok=True)
        (self.session_dir / "reports").mkdir(exist_ok=True)

        self._existing_ids = load_existing_ids(self.session_dir)
        self._created_at = _utc_now_iso()

    # -- Properties ---------------------------------------------------------

    @property
    def attempts_path(self) -> Path:
        """Path to this worker's NDJSON log."""
        if self.num_workers > 1:
            return self.session_dir / f"attempts.worker{self.worker_id}.ndjson"
        return self.session_dir / "attempts.ndjson"

    @property
    def kernels_dir(self) -> Path:
        return self.session_dir / "kernels"

    @property
    def n_existing(self) -> int:
        """Number of previously seen attempt IDs."""
        return len(self._existing_ids)

    @property
    def n_written(self) -> int:
        """Number of records written in this session instance."""
        return self._written

    # -- Dedup / claim ------------------------------------------------------

    def is_known(self, attempt_id: str) -> bool:
        """Check whether an attempt ID has already been recorded."""
        return attempt_id in self._existing_ids

    def try_claim(self, attempt_id: str) -> bool:
        """Atomically claim an attempt ID. Returns True if newly claimed.

        NOTE(review): "atomic" holds only within this process (a single
        in-memory set); nothing here coordinates across workers — confirm
        that cross-worker dedup is handled by the per-worker shard files.
        """
        if attempt_id in self._existing_ids:
            return False
        self._existing_ids.add(attempt_id)
        return True

    # -- Writing records ----------------------------------------------------

    def write_record(self, record: dict[str, Any]) -> None:
        """Append *record* to the session's NDJSON log.

        Automatically stamps ``session``, ``worker_id``, and ``created_at``
        if not already present.
        """
        record.setdefault("session", self.session_dir.name)
        record.setdefault("worker_id", self.worker_id)
        record.setdefault("created_at", _utc_now_iso())
        append_record(self.attempts_path, record)
        self._written += 1

    def save_kernel_source(self, attempt_id: str, source: str) -> Path:
        """Persist Metal source to ``kernels/<attempt_id>.metal``."""
        path = self.kernels_dir / f"{attempt_id}.metal"
        path.write_text(source, encoding="utf-8")
        return path

    # -- Compile cache helpers ----------------------------------------------

    def cache_key(self, **kwargs: Any) -> str:
        """Compute a SHA-256 compile-cache key from arbitrary keyword args.

        Keys are canonicalized (sorted keys, compact separators) so the
        same kwargs always hash identically.
        """
        payload = json.dumps(kwargs, sort_keys=True, separators=(",", ":"), ensure_ascii=False)
        return hashlib.sha256(payload.encode("utf-8")).hexdigest()

    def is_cached(self, key: str) -> bool:
        """Check if a compile result is cached (in-memory fast path)."""
        if key in self._compile_cache:
            return True
        # Check on-disk marker; promote a hit into the in-memory cache so
        # subsequent lookups skip the filesystem.
        marker = self.session_dir / "cache" / f"{key}.ok"
        if marker.exists():
            self._compile_cache[key] = True
            return True
        return False

    def mark_cached(self, key: str) -> None:
        """Record that a compile succeeded for *key*."""
        self._compile_cache[key] = True
        marker = self.session_dir / "cache" / f"{key}.ok"
        marker.parent.mkdir(parents=True, exist_ok=True)
        marker.touch()

    # -- Session metadata ---------------------------------------------------

    def metadata(self) -> dict[str, Any]:
        """Summary metadata for reporting."""
        return {
            "session_dir": str(self.session_dir),
            "session_name": self.session_dir.name,
            "worker_id": self.worker_id,
            "num_workers": self.num_workers,
            "created_at": self._created_at,
            "n_existing": self.n_existing,
            "n_written": self.n_written,
        }
diff --git a/src/zmlx/foundry/taxonomy.py b/src/zmlx/foundry/taxonomy.py
new file mode 100644
index 0000000..0e2bfe1
--- /dev/null
+++ b/src/zmlx/foundry/taxonomy.py
@@ -0,0 +1,257 @@
"""Foundation types for the ZMLX foundry module.

Adapted from DataFoundry's types.py. All types are plain dataclasses
(no framework dependency) so they serialize cleanly to NDJSON.
"""
from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Protocol

# ---------------------------------------------------------------------------
# Kernel class taxonomy (9 classes for LLM inference on-device)
# ---------------------------------------------------------------------------


class KernelClass(str, Enum):
    ELEMENTWISE = "A"       # Elementwise / broadcast
    REDUCTION = "B"         # Reductions (sum/max/norm)
    FUSED_POINTWISE = "C"   # Fused pointwise (bias+act, gated MLP)
    GEMM = "D"              # Matmul / GEMM-like (incl.
grouped/batched)
+    ATTENTION = "E"      # Attention primitives (softmax, RoPE)
+    QUANT = "F"          # Quantization primitives (dequant/quant)
+    DATA_MOVEMENT = "G"  # Data movement / re-layout (gather/scatter)
+    MOE = "H"            # MoE routing + dispatch
+    KV_CACHE = "I"       # KV cache management
+
+
+# ---------------------------------------------------------------------------
+# Op specification
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class OpSpec:
+    """Compact, serializable spec for an op (used in datasets & reports)."""
+
+    # Canonical op name (e.g. "rmsnorm"); doubles as the templates/ subdir name.
+    name: str
+    kernel_class: KernelClass
+    summary: str
+    inputs: list[str] = field(default_factory=list)
+    outputs: list[str] = field(default_factory=list)
+    op_params_schema: dict[str, Any] = field(default_factory=dict)
+    shape_hints: dict[str, Any] = field(default_factory=dict)
+    dtype_hints: list[str] = field(default_factory=list)
+    # False when no gold-standard reference implementation exists for the op.
+    has_reference: bool = True
+    templates: list[str] = field(default_factory=list)
+
+
+# ---------------------------------------------------------------------------
+# Type aliases
+# ---------------------------------------------------------------------------
+
+# Shape maps a logical dimension name to its extent; Layout carries arbitrary
+# layout metadata. NOTE(review): key vocabulary is op-defined — confirm against
+# each KernelOp's generate_inputs_numpy().
+Shape = dict[str, int]
+Layout = dict[str, Any]
+
+# ---------------------------------------------------------------------------
+# Template metadata
+# ---------------------------------------------------------------------------
+
+
+@dataclass(frozen=True)
+class KernelTemplate:
+    """Pointer to a .metal template on disk."""
+
+    op: str
+    template_id: str
+    path: str  # relative path under templates/<op>/ (bracketed text was lost in transcription — confirm)
+
+
+# ---------------------------------------------------------------------------
+# Candidate (the unit of work for the harness)
+# ---------------------------------------------------------------------------
+
+
+@dataclass(frozen=True)
+class KernelCandidate:
+    """Fully-specified kernel attempt: op + dtype + shape + template + knobs."""
+
+    op: str
+    dtype: str
+    shape: Shape
+    layout: Layout
+    template_id: str
+    # Template knob assignments (e.g. TG_SIZE, UNROLL) substituted at render time.
+    knobs: dict[str, Any]
+
+
+# ---------------------------------------------------------------------------
+# Result dataclasses
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class BuildResult:
+    # Outcome of compiling one candidate; `ms` is the build wall time.
+    ok: bool
+    ms: float
+    error_type: str | None = None  # 'compile_error' | 'api_error' | None
+    error_summary: str = ""
+    error_log: str = ""
+    cache_hit: bool = False
+
+
+@dataclass
+class CorrectnessResult:
+    # Aggregated over `n_tests` randomized cases seeded with `seed`.
+    ok: bool
+    max_abs_err: float | None
+    max_rel_err: float | None
+    ulp: float | None
+    n_tests: int
+    seed: int
+    fail_reason: str | None = None
+
+
+@dataclass
+class BenchResult:
+    ok: bool
+    warmup: int
+    repeats: int
+    latency_ms: dict[str, float | None]  # p50/p90/min/mean
+    throughput_gbs: float | None
+    throughput_gflops: float | None
+    timeout: bool = False
+
+
+# ---------------------------------------------------------------------------
+# Tolerances
+# ---------------------------------------------------------------------------
+
+
+@dataclass(frozen=True)
+class Tolerances:
+    max_abs: float
+    max_rel: float
+
+
+def tolerances(dtype: str) -> Tolerances:
+    """Conservative defaults for correctness gating.
+
+    Bounds widen with the dtype's precision (fp32 tightest, bf16 loosest).
+    Raises ``ValueError`` for any dtype outside {float32, float16, bfloat16}.
+    """
+    if dtype == "float32":
+        return Tolerances(max_abs=5e-5, max_rel=5e-5)
+    if dtype == "float16":
+        return Tolerances(max_abs=5e-3, max_rel=5e-3)
+    if dtype == "bfloat16":
+        return Tolerances(max_abs=1e-2, max_rel=1e-2)
+    raise ValueError(f"unknown dtype: {dtype}")
+
+
+# ---------------------------------------------------------------------------
+# Backend protocol -- structural typing so MLXBackend / MockBackend both match
+# ---------------------------------------------------------------------------
+
+
+class Backend(Protocol):
+    """Structural interface for array/kernel backends (duck-typed, no inheritance)."""
+
+    name: str
+
+    def is_available(self) -> bool: ...
+
+    def device_info(self) -> dict[str, Any]: ...
+
+    def mlx_version(self) -> str | None: ...
+
+    def array(self, np_array: Any, dtype: str) -> Any: ...
+
+    def to_numpy(self, arr: Any) -> Any: ...
+
+    def eval(self, arr: Any) -> None: ...
+
+    def synchronize(self) -> None: ...
+
+    # Mirrors mx.fast.metal_kernel's signature so MLXBackend can delegate directly.
+    def metal_kernel(
+        self,
+        name: str,
+        input_names: list[str],
+        output_names: list[str],
+        source: str,
+        header: str = "",
+        ensure_row_contiguous: bool = True,
+        atomic_outputs: bool = False,
+    ) -> Any: ...
+
+
+# ---------------------------------------------------------------------------
+# KernelOp protocol -- the GENERIC abstraction for op-specific logic
+# ---------------------------------------------------------------------------
+
+
+class KernelOp(Protocol):
+    """Protocol that every op module must satisfy.
+
+    Instead of DataFoundry's ``if op == "rmsnorm" ... elif op == "swiglu"``
+    switch, the harness accepts a KernelOp instance and delegates all
+    op-specific behaviour to it.
+    """
+
+    @property
+    def name(self) -> str:
+        """Canonical op name, e.g. ``"rmsnorm"``, ``"swiglu"``."""
+        ...
+
+    @property
+    def input_names(self) -> list[str]:
+        """Metal kernel input buffer names, e.g. ``["x", "w", "eps"]``."""
+        ...
+
+    @property
+    def output_names(self) -> list[str]:
+        """Metal kernel output buffer names, e.g. ``["y"]``."""
+        ...
+
+    def canonical_math(self) -> str:
+        """Human-readable math specification string."""
+        ...
+
+    def correctness_constraints(self, dtype: str) -> dict[str, Any]:
+        """Tolerance metadata dict for this dtype."""
+        ...
+
+    def bytes_and_flops(self, shape: Shape, dtype: str) -> tuple[int, int]:
+        """Estimated total bytes transferred and FLOPs for throughput calc."""
+        ...
+
+    def generate_inputs_numpy(
+        self,
+        rng: Any,
+        shape: Shape,
+        dtype: str,
+        layout: Layout,
+    ) -> dict[str, Any]:
+        """Create numpy input arrays for a single test case."""
+        ...
+
+    def reference_numpy(
+        self,
+        inputs: dict[str, Any],
+        dtype: str,
+    ) -> Any:
+        """Gold-standard numpy reference implementation."""
+        ...
+
+    def reference_mlx(
+        self,
+        inputs: dict[str, Any],
+        dtype: str,
+    ) -> Any:
+        """MLX reference implementation (for on-device comparison)."""
+        ...
+
+    def output_shapes(
+        self,
+        inputs: dict[str, Any],
+        shape: Shape,
+    ) -> list[tuple[int, ...]]:
+        """Expected output shapes for the kernel call."""
+        ...
+
+    def tg_mem_bytes(self, template_id: str, tg_size: int) -> int:
+        """Threadgroup memory in bytes for a given template + tg_size."""
+        ...
diff --git a/src/zmlx/foundry/templates/__init__.py b/src/zmlx/foundry/templates/__init__.py
new file mode 100644
index 0000000..93300f6
--- /dev/null
+++ b/src/zmlx/foundry/templates/__init__.py
@@ -0,0 +1,43 @@
+"""Template discovery and loading for foundry Metal kernel templates.
+
+Each op has a subdirectory (e.g. ``templates/rmsnorm/``) containing one
+or more ``.metal`` template files. Template IDs are derived from the
+filename stem (e.g. ``t0_basic.metal`` -> ``"t0_basic"``).
+"""
+from __future__ import annotations
+
+from pathlib import Path
+
+# The templates directory is this file's parent.
+_TEMPLATES_DIR = Path(__file__).resolve().parent
+
+
+def get_template_path(op: str, template_id: str) -> Path:
+    """Return the absolute Path to a ``.metal`` template file.
+
+    Raises ``FileNotFoundError`` if the template does not exist on disk.
+    """
+    path = _TEMPLATES_DIR / op / f"{template_id}.metal"
+    if not path.exists():
+        raise FileNotFoundError(
+            f"Template not found: {path} "
+            f"(op={op!r}, template_id={template_id!r})"
+        )
+    return path
+
+
+def list_templates(op: str) -> list[str]:
+    """Return sorted template IDs available for *op*.
+
+    An empty list is returned if the op subdirectory does not exist.
+    IDs are filename stems, sorted lexicographically.
+    """
+    op_dir = _TEMPLATES_DIR / op
+    if not op_dir.is_dir():
+        return []
+    return sorted(p.stem for p in op_dir.glob("*.metal"))
+
+
+def load_template(op: str, template_id: str) -> str:
+    """Read and return the raw text of a template file.
+
+    Propagates ``FileNotFoundError`` from ``get_template_path``.
+    """
+    path = get_template_path(op, template_id)
+    return path.read_text(encoding="utf-8")
diff --git a/src/zmlx/foundry/templates/moe_combine/t0_basic.metal b/src/zmlx/foundry/templates/moe_combine/t0_basic.metal
new file mode 100644
index 0000000..cb802e8
--- /dev/null
+++ b/src/zmlx/foundry/templates/moe_combine/t0_basic.metal
@@ -0,0 +1,47 @@
+// TEMPLATE_ID: t0_basic
+// OP: moe_combine
+// Combine expert outputs: for each output row (token) and column, sum
+// expert_y[p, col] * packed_weights[p] over all packed slots p whose
+// packed_token_ids[p] matches the row. Naive O(P) scan per output element.
+// {{...}} placeholders are filled by the template renderer;
+// NOTE(review): {{COMPILE_ERROR_SNIPPET}} / {{INJECT_INCORRECT}} look like
+// deliberate fault-injection hooks for negative training examples — confirm
+// with the foundry harness.
+//---HEADER---
+#include
+// NOTE(review): the include target appears lost in transcription — likely <metal_stdlib>; confirm.
+using namespace metal;
+//---BODY---
+constexpr uint TG_SIZE = {{TG_SIZE}};
+constexpr uint UNROLL = {{UNROLL}};
+
+// grid.x = column lane within a stripe, grid.y = output token row
+uint tid = thread_position_in_grid.x;
+uint row = thread_position_in_grid.y;
+
+uint H = (uint)expert_y_shape[expert_y_ndim - 1];
+uint P = (uint)packed_token_ids_shape[packed_token_ids_ndim - 1];
+
+for (uint col0 = tid; col0 < H; col0 += TG_SIZE * UNROLL) {
+    #pragma unroll
+    for (uint u = 0; u < UNROLL; ++u) {
+        uint col = col0 + u * TG_SIZE;
+        if (col < H) {
+            float acc = 0.0f;  // accumulate in fp32 regardless of storage dtype
+            for (uint p = 0; p < P; ++p) {
+                uint id_loc = elem_to_loc(
+                    p, packed_token_ids_shape, packed_token_ids_strides, packed_token_ids_ndim
+                );
+                uint tok = (uint)packed_token_ids[id_loc];
+                if (tok == row) {
+                    uint w_loc = elem_to_loc(
+                        p, packed_weights_shape, packed_weights_strides, packed_weights_ndim
+                    );
+                    float w = (float)packed_weights[w_loc];
+                    uint ey_elem = p * H + col;
+                    uint ey_loc = elem_to_loc(
+                        ey_elem, expert_y_shape, expert_y_strides, expert_y_ndim
+                    );
+                    acc += ((float)expert_y[ey_loc]) * w;
+                }
+            }
+
+            {{COMPILE_ERROR_SNIPPET}}
+            #if {{INJECT_INCORRECT}}
+            acc += 0.05f;
+            #endif
+            // Output is indexed contiguously (row-major), unlike the strided inputs.
+            y[row * H + col] = ({{OUT_TYPE}})acc;
+        }
+    }
+}
diff --git a/src/zmlx/foundry/templates/moe_combine/t1_k8_unrolled.metal b/src/zmlx/foundry/templates/moe_combine/t1_k8_unrolled.metal
new file mode 100644
index
0000000..1240c8b --- /dev/null +++ b/src/zmlx/foundry/templates/moe_combine/t1_k8_unrolled.metal @@ -0,0 +1,59 @@ +// TEMPLATE_ID: t1_k8_unrolled +// OP: moe_combine +//---HEADER--- +#include +using namespace metal; +//---BODY--- +constexpr uint TG_SIZE = {{TG_SIZE}}; +constexpr uint UNROLL = {{UNROLL}}; +constexpr uint K_UNROLL = 8; + +uint tid = thread_position_in_grid.x; +uint row = thread_position_in_grid.y; + +uint H = (uint)expert_y_shape[expert_y_ndim - 1]; +uint P = (uint)packed_token_ids_shape[packed_token_ids_ndim - 1]; + +for (uint col0 = tid; col0 < H; col0 += TG_SIZE * UNROLL) { + #pragma unroll + for (uint u = 0; u < UNROLL; ++u) { + uint col = col0 + u * TG_SIZE; + if (col < H) { + float acc = 0.0f; + for (uint p0 = 0; p0 < P; p0 += K_UNROLL) { + #pragma unroll + for (uint kk = 0; kk < K_UNROLL; ++kk) { + uint p = p0 + kk; + if (p < P) { + uint id_loc = elem_to_loc( + p, packed_token_ids_shape, packed_token_ids_strides, packed_token_ids_ndim + ); + uint tok = (uint)packed_token_ids[id_loc]; + if (tok == row) { + uint w_loc = elem_to_loc( + p, packed_weights_shape, packed_weights_strides, packed_weights_ndim + ); + float w = (float)packed_weights[w_loc]; + uint ey_elem = p * H + col; + uint ey_loc = elem_to_loc( + ey_elem, expert_y_shape, expert_y_strides, expert_y_ndim + ); + float ey = (float)expert_y[ey_loc]; + #if {{FAST_MATH}} + acc = fma(ey, w, acc); + #else + acc += ey * w; + #endif + } + } + } + } + + {{COMPILE_ERROR_SNIPPET}} + #if {{INJECT_INCORRECT}} + acc *= 0.995f; + #endif + y[row * H + col] = ({{OUT_TYPE}})acc; + } + } +} diff --git a/src/zmlx/foundry/templates/moe_combine/t2_row_tile.metal b/src/zmlx/foundry/templates/moe_combine/t2_row_tile.metal new file mode 100644 index 0000000..aaa1a5e --- /dev/null +++ b/src/zmlx/foundry/templates/moe_combine/t2_row_tile.metal @@ -0,0 +1,71 @@ +// TEMPLATE_ID: t2_row_tile +// OP: moe_combine +//---HEADER--- +#include +using namespace metal; +//---BODY--- +constexpr uint TG_SIZE = {{TG_SIZE}}; 
+constexpr uint UNROLL = {{UNROLL}}; + +uint tid = thread_position_in_grid.x; +uint row = thread_position_in_grid.y; + +uint H = (uint)expert_y_shape[expert_y_ndim - 1]; +uint P = (uint)packed_token_ids_shape[packed_token_ids_ndim - 1]; + +// Cache one p-tile of (token_id, weight) in threadgroup memory. +threadgroup int tg_ids[TG_SIZE]; +threadgroup float tg_w[TG_SIZE]; + +for (uint col0 = tid; col0 < H; col0 += TG_SIZE * UNROLL) { + #pragma unroll + for (uint u = 0; u < UNROLL; ++u) { + uint col = col0 + u * TG_SIZE; + if (col < H) { + float acc = 0.0f; + + for (uint p0 = 0; p0 < P; p0 += TG_SIZE) { + uint p_load = p0 + tid; + if (p_load < P) { + uint id_loc = elem_to_loc( + p_load, packed_token_ids_shape, packed_token_ids_strides, packed_token_ids_ndim + ); + uint w_loc = elem_to_loc( + p_load, packed_weights_shape, packed_weights_strides, packed_weights_ndim + ); + tg_ids[tid] = (int)packed_token_ids[id_loc]; + tg_w[tid] = (float)packed_weights[w_loc]; + } else { + tg_ids[tid] = -1; + tg_w[tid] = 0.0f; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + #pragma unroll + for (uint j = 0; j < TG_SIZE; ++j) { + uint p = p0 + j; + if (p < P && tg_ids[j] == (int)row) { + uint ey_elem = p * H + col; + uint ey_loc = elem_to_loc( + ey_elem, expert_y_shape, expert_y_strides, expert_y_ndim + ); + float ey = (float)expert_y[ey_loc]; + float w = tg_w[j]; + #if {{FAST_MATH}} + acc = fma(ey, w, acc); + #else + acc += ey * w; + #endif + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + {{COMPILE_ERROR_SNIPPET}} + #if {{INJECT_INCORRECT}} + acc -= 0.03f; + #endif + y[row * H + col] = ({{OUT_TYPE}})acc; + } + } +} diff --git a/src/zmlx/foundry/templates/render.py b/src/zmlx/foundry/templates/render.py new file mode 100644 index 0000000..ef31033 --- /dev/null +++ b/src/zmlx/foundry/templates/render.py @@ -0,0 +1,95 @@ +"""Mustache-style ``{{VAR}}`` template renderer for Metal kernel sources. + +Adapted from DataFoundry's harness/render.py. 
Templates may optionally +contain ``//---HEADER---`` and ``//---BODY---`` markers to split the +source into a header (``#include`` directives) and body (kernel logic). +The ``mx.fast.metal_kernel`` API accepts these separately. +""" +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from ..ids import normalize_source + +# --------------------------------------------------------------------------- +# Result dataclass +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class RenderedTemplate: + """The output of template rendering. + + Attributes: + header: Metal source for ``#include`` / ``using namespace`` preamble. + source: Metal kernel body (the code that runs per-thread). + full_text: Combined ``header + //---BODY--- + source`` for logging. + normalized_full_text: Whitespace-normalized version for stable hashing. + """ + + header: str + source: str + full_text: str + normalized_full_text: str + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _split_template(text: str) -> tuple[str, str]: + """Split raw template text at ``//---HEADER---`` / ``//---BODY---`` markers. + + If either marker is missing the entire text is treated as body and + the header is empty. This gracefully handles templates that omit + the markers. 
+ """ + if "//---HEADER---" not in text or "//---BODY---" not in text: + return "", text + header = text.split("//---HEADER---", 1)[1].split("//---BODY---", 1)[0] + body = text.split("//---BODY---", 1)[1] + return header.strip() + "\n", body.strip() + "\n" + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def render_template( + path: Path, values: dict[str, Any] +) -> RenderedTemplate: + """Render a ``.metal`` template with ``{{KEY}}`` substitution. + + Parameters: + path: Absolute path to the ``.metal`` template file. + values: Mapping of placeholder names to replacement values. + Each value is converted to ``str`` before substitution. + + Returns: + A ``RenderedTemplate`` with the rendered header, body, combined + full text, and a normalised copy of the full text suitable for + cache-key hashing. + """ + raw = path.read_text(encoding="utf-8") + header, body = _split_template(raw) + + def _substitute(s: str) -> str: + out = s + for k, v in values.items(): + out = out.replace("{{" + k + "}}", str(v)) + return out + + header_rendered = _substitute(header) + body_rendered = _substitute(body) + full = f"{header_rendered}\n//---BODY---\n{body_rendered}" + normalised = normalize_source(full) + return RenderedTemplate( + header=header_rendered, + source=body_rendered, + full_text=full, + normalized_full_text=normalised, + ) diff --git a/src/zmlx/foundry/templates/rmsnorm/t0_basic.metal b/src/zmlx/foundry/templates/rmsnorm/t0_basic.metal new file mode 100644 index 0000000..7ca075b --- /dev/null +++ b/src/zmlx/foundry/templates/rmsnorm/t0_basic.metal @@ -0,0 +1,80 @@ +// TEMPLATE_ID: t0_basic +// OP: rmsnorm +//---HEADER--- +#include +using namespace metal; +//---BODY--- +constexpr uint TG_SIZE = {{TG_SIZE}}; +constexpr uint VEC = {{VEC}}; +constexpr uint UNROLL = {{UNROLL}}; + +#if {{FAST_MATH}} +inline float inv_sqrt(float x) { return 
rsqrt(x); } +#else +inline float inv_sqrt(float x) { return 1.0f / sqrt(x); } +#endif + +uint tid = thread_position_in_grid.x; +uint row = thread_position_in_grid.y; + +uint H = (uint)x_shape[2]; +uint base = row * H; + +threadgroup float partial[TG_SIZE]; + +float sumsq = 0.0f; + +// Accumulate sumsq +for (uint col0 = tid; col0 < H; col0 += TG_SIZE * UNROLL) { + #pragma unroll + for (uint u = 0; u < UNROLL; ++u) { + uint col = col0 + u * TG_SIZE; + if (col < H) { + uint elem = base + col; + uint loc = elem_to_loc(elem, x_shape, x_strides, x_ndim); + float xv = (float)x[loc]; + sumsq += xv * xv; + } + } +} + +partial[tid] = sumsq; +threadgroup_barrier(mem_flags::mem_threadgroup); + +// Reduce within threadgroup (assumes TG_SIZE is power of 2) +for (uint stride = TG_SIZE / 2; stride > 0; stride >>= 1) { + if (tid < stride) { + partial[tid] += partial[tid + stride]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); +} + +float inv = 0.0f; +if (tid == 0) { + float mean = partial[0] / max(1.0f, (float)H); + float e = eps[0]; + inv = inv_sqrt(mean + e); + partial[0] = inv; +} +threadgroup_barrier(mem_flags::mem_threadgroup); +inv = partial[0]; + +// Write outputs +for (uint col0 = tid; col0 < H; col0 += TG_SIZE * UNROLL) { + #pragma unroll + for (uint u = 0; u < UNROLL; ++u) { + uint col = col0 + u * TG_SIZE; + if (col < H) { + uint elem = base + col; + uint loc = elem_to_loc(elem, x_shape, x_strides, x_ndim); + float xv = (float)x[loc]; + float wv = (float)w[col]; + float outv = xv * inv * wv; + {{COMPILE_ERROR_SNIPPET}} + #if {{INJECT_INCORRECT}} + outv += 0.1f; + #endif + y[elem] = ({{OUT_TYPE}})outv; + } + } +} diff --git a/src/zmlx/foundry/templates/rmsnorm/t1_tgmem.metal b/src/zmlx/foundry/templates/rmsnorm/t1_tgmem.metal new file mode 100644 index 0000000..7062b25 --- /dev/null +++ b/src/zmlx/foundry/templates/rmsnorm/t1_tgmem.metal @@ -0,0 +1,77 @@ +// TEMPLATE_ID: t1_tgmem +// OP: rmsnorm +//---HEADER--- +#include +using namespace metal; +//---BODY--- 
+constexpr uint TG_SIZE = {{TG_SIZE}}; +constexpr uint UNROLL = {{UNROLL}}; + +#if {{FAST_MATH}} +inline float inv_sqrt(float x) { return rsqrt(x); } +#else +inline float inv_sqrt(float x) { return 1.0f / sqrt(x); } +#endif + +uint tid = thread_position_in_grid.x; +uint row = thread_position_in_grid.y; + +uint H = (uint)x_shape[2]; +uint base = row * H; + +// Cache a small tile of weights in tgmem (for demonstration) +// NOTE: This is intentionally simple and may not always help. +threadgroup float w_tile[TG_SIZE]; +threadgroup float partial[TG_SIZE]; + +float sumsq = 0.0f; +for (uint col0 = tid; col0 < H; col0 += TG_SIZE) { + uint elem = base + col0; + uint loc = elem_to_loc(elem, x_shape, x_strides, x_ndim); + float xv = (float)x[loc]; + sumsq += xv * xv; +} +partial[tid] = sumsq; +threadgroup_barrier(mem_flags::mem_threadgroup); + +// Reduce +for (uint stride = TG_SIZE / 2; stride > 0; stride >>= 1) { + if (tid < stride) partial[tid] += partial[tid + stride]; + threadgroup_barrier(mem_flags::mem_threadgroup); +} + +float inv = 0.0f; +if (tid == 0) { + float mean = partial[0] / max(1.0f, (float)H); + inv = inv_sqrt(mean + eps[0]); + partial[0] = inv; +} +threadgroup_barrier(mem_flags::mem_threadgroup); +inv = partial[0]; + +// Write output, staging weight per tile +for (uint colBase = 0; colBase < H; colBase += TG_SIZE) { + uint col = colBase + tid; + if (col < H) { + w_tile[tid] = (float)w[col]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + #pragma unroll + for (uint u = 0; u < UNROLL; ++u) { + uint colu = colBase + tid + u * TG_SIZE; + if (colu < H) { + uint elem = base + colu; + uint loc = elem_to_loc(elem, x_shape, x_strides, x_ndim); + float xv = (float)x[loc]; + float wv = w_tile[tid]; // simplistic; same tid + float outv = xv * inv * wv; + {{COMPILE_ERROR_SNIPPET}} + #if {{INJECT_INCORRECT}} + outv -= 0.05f; + #endif + y[elem] = ({{OUT_TYPE}})outv; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); +} diff --git 
a/src/zmlx/foundry/templates/swiglu/t0_basic.metal b/src/zmlx/foundry/templates/swiglu/t0_basic.metal new file mode 100644 index 0000000..8e59bf1 --- /dev/null +++ b/src/zmlx/foundry/templates/swiglu/t0_basic.metal @@ -0,0 +1,48 @@ +// TEMPLATE_ID: t0_basic +// OP: swiglu +//---HEADER--- +#include +using namespace metal; +//---BODY--- +constexpr uint TG_SIZE = {{TG_SIZE}}; +constexpr uint UNROLL = {{UNROLL}}; + +inline float sigmoid_approx(float x) { +#if {{FAST_MATH}} + float ax = fabs(x); + float y = x / (1.0f + ax); + return 0.5f + 0.5f * y; +#else + return 1.0f / (1.0f + exp(-x)); +#endif +} + +uint tid = thread_position_in_grid.x; +uint row = thread_position_in_grid.y; + +uint H2 = (uint)x_shape[2]; +uint H = H2 / 2; +uint base_in = row * H2; +uint base_out = row * H; + +for (uint col0 = tid; col0 < H; col0 += TG_SIZE * UNROLL) { + #pragma unroll + for (uint u = 0; u < UNROLL; ++u) { + uint col = col0 + u * TG_SIZE; + if (col < H) { + uint elem0 = base_in + col; + uint elem1 = base_in + (H + col); + uint loc0 = elem_to_loc(elem0, x_shape, x_strides, x_ndim); + uint loc1 = elem_to_loc(elem1, x_shape, x_strides, x_ndim); + float a = (float)x[loc0]; + float b = (float)x[loc1]; + float s = sigmoid_approx(a); + float outv = (a * s) * b; + {{COMPILE_ERROR_SNIPPET}} + #if {{INJECT_INCORRECT}} + outv *= 0.98f; + #endif + y[base_out + col] = ({{OUT_TYPE}})outv; + } + } +} diff --git a/src/zmlx/foundry/templates/swiglu/t1_unrolled.metal b/src/zmlx/foundry/templates/swiglu/t1_unrolled.metal new file mode 100644 index 0000000..392869e --- /dev/null +++ b/src/zmlx/foundry/templates/swiglu/t1_unrolled.metal @@ -0,0 +1,46 @@ +// TEMPLATE_ID: t1_unrolled +// OP: swiglu +//---HEADER--- +#include +using namespace metal; +//---BODY--- +constexpr uint TG_SIZE = {{TG_SIZE}}; +constexpr uint UNROLL = {{UNROLL}}; +constexpr uint VEC = {{VEC}}; + +inline float sigmoid_exact(float x) { return 1.0f / (1.0f + exp(-x)); } + +uint tid = thread_position_in_grid.x; +uint row = 
thread_position_in_grid.y; + +uint H2 = (uint)x_shape[2]; +uint H = H2 / 2; +uint base_in = row * H2; +uint base_out = row * H; + +// Vector-ish inner loop: process VEC contiguous cols per iteration when possible +for (uint col0 = tid * VEC; col0 < H; col0 += TG_SIZE * VEC * UNROLL) { + #pragma unroll + for (uint u = 0; u < UNROLL; ++u) { + uint colu = col0 + u * TG_SIZE * VEC; + #pragma unroll + for (uint v = 0; v < VEC; ++v) { + uint col = colu + v; + if (col < H) { + uint elem0 = base_in + col; + uint elem1 = base_in + (H + col); + uint loc0 = elem_to_loc(elem0, x_shape, x_strides, x_ndim); + uint loc1 = elem_to_loc(elem1, x_shape, x_strides, x_ndim); + float a = (float)x[loc0]; + float b = (float)x[loc1]; + float s = sigmoid_exact(a); + float outv = (a * s) * b; + {{COMPILE_ERROR_SNIPPET}} + #if {{INJECT_INCORRECT}} + outv += 0.02f; + #endif + y[base_out + col] = ({{OUT_TYPE}})outv; + } + } + } +} diff --git a/src/zmlx/foundry/workers.py b/src/zmlx/foundry/workers.py new file mode 100644 index 0000000..9d425a4 --- /dev/null +++ b/src/zmlx/foundry/workers.py @@ -0,0 +1,263 @@ +"""Multi-process worker orchestration for foundry runs. + +Spawns N workers that each run an independent evaluate loop over +candidates from a shared op/scheduler configuration. Each worker +writes to its own NDJSON shard (``attempts.worker.ndjson``) to +avoid contention; after all workers complete, logs are merged into +a single deduplicated ``attempts.ndjson``. + +Adapted from DataFoundry's ``run.py`` multi-worker path. 
+""" +from __future__ import annotations + +import multiprocessing as mp +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from .harness.cache import CompileCache +from .harness.evaluate import evaluate_attempt, write_record +from .ndjson import merge_worker_logs +from .ops import get_registry +from .sampling.sampler import Sampler +from .scheduler import CurriculumScheduler +from .session import Session +from .templates import list_templates + +# --------------------------------------------------------------------------- +# Worker config (picklable across processes) +# --------------------------------------------------------------------------- + + +@dataclass +class WorkerConfig: + """Serializable configuration for a single worker process.""" + + session_dir: str + worker_id: int + num_workers: int + ops: list[str] + n_attempts: int + mode: str # random | coverage | mutation | mix + seed: int + backend: str # mlx | mock + correctness_tests: int + warmup: int + repeats: int + bench_timeout_s: float + stage: int + + +# --------------------------------------------------------------------------- +# Single-worker loop +# --------------------------------------------------------------------------- + + +def run_worker(cfg: WorkerConfig) -> dict[str, Any]: + """Execute the evaluation loop for one worker. + + This function is designed to be the target of ``mp.Process`` or + called directly for single-worker runs. + + Returns a summary dict with ``worker_id``, ``n_attempted``, + ``n_written``, ``n_skipped``, and ``elapsed_s``. 
+ """ + session = Session( + session_dir=Path(cfg.session_dir), + worker_id=cfg.worker_id, + num_workers=cfg.num_workers, + ) + + registry = get_registry() + scheduler = CurriculumScheduler(ops=cfg.ops) + scheduler.set_stage(cfg.stage) + available = scheduler.available_ops() + + cache = CompileCache() + t0 = time.monotonic() + n_attempted = 0 + n_written = 0 + n_skipped = 0 + + for op_name in available: + op = registry.get(op_name) + if op is None: + continue + + templates = list_templates(op_name) + if not templates: + templates = ["ref"] + + # Build a merged, flattened knob space from all templates. + # Op.knob_space(template_id) returns {"knob": {"type":..., "values":[...]}} + # Sampler expects {"knob": [val1, val2, ...]} + merged_knob_space: dict[str, list[Any]] = {} + if hasattr(op, "knob_space"): + for tid in templates: + try: + raw = op.knob_space(tid) + except Exception: + continue + for k, v in raw.items(): + if isinstance(v, dict) and "values" in v: + vals = list(v["values"]) + elif isinstance(v, dict) and v.get("type") == "bool": + vals = [True, False] + elif isinstance(v, list): + vals = list(v) + else: + continue + if k not in merged_knob_space: + merged_knob_space[k] = vals + else: + # Union of values, preserving order + existing = set(merged_knob_space[k]) + for val in vals: + if val not in existing: + merged_knob_space[k].append(val) + existing.add(val) + + sampler = Sampler( + op=op_name, + knob_space=merged_knob_space, + templates=templates, + mode=cfg.mode, + seed=cfg.seed + cfg.worker_id * 1_000_000, + session_dir=session.session_dir, + extra_shape_dims=getattr(op, "extra_shape_dims", None), + ) + + per_op_budget = max(1, cfg.n_attempts // max(len(available), 1)) + + for i in range(per_op_budget): + candidate = sampler.next_candidate(i) + + # Shard: only this worker handles candidates where + # hash(attempt_index) % num_workers == worker_id + global_idx = n_attempted + if cfg.num_workers > 1 and (global_idx % cfg.num_workers) != cfg.worker_id: + 
n_attempted += 1 + n_skipped += 1 + continue + + try: + record = evaluate_attempt( + session_dir=session.session_dir, + backend_name=cfg.backend, + candidate=candidate, + op=op, + cache=cache, + correctness_tests=cfg.correctness_tests, + warmup=cfg.warmup, + repeats=cfg.repeats, + bench_timeout_s=cfg.bench_timeout_s, + ) + except Exception as exc: + record = { + "id": f"error_{cfg.worker_id}_{n_attempted}", + "op": op_name, + "error": str(exc), + "build": {"ok": False}, + "correctness": {"ok": False}, + "bench": {"ok": False}, + } + + rid = record.get("id", "") + if session.try_claim(rid): + write_record( + session_dir=session.session_dir, + record=record, + worker_id=cfg.worker_id if cfg.num_workers > 1 else None, + ) + n_written += 1 + + n_attempted += 1 + + elapsed = time.monotonic() - t0 + return { + "worker_id": cfg.worker_id, + "n_attempted": n_attempted, + "n_written": n_written, + "n_skipped": n_skipped, + "elapsed_s": round(elapsed, 2), + } + + +def _worker_entry(cfg: WorkerConfig) -> None: + """Entry point for spawned worker processes (writes result to stdout).""" + import json + result = run_worker(cfg) + print(json.dumps(result), flush=True) + + +# --------------------------------------------------------------------------- +# Multi-worker orchestrator +# --------------------------------------------------------------------------- + + +def spawn_workers( + *, + session_dir: str, + ops: list[str], + n_attempts: int, + num_workers: int = 1, + mode: str = "mix", + seed: int = 42, + backend: str = "mlx", + correctness_tests: int = 3, + warmup: int = 10, + repeats: int = 50, + bench_timeout_s: float = 10.0, + stage: int = 4, +) -> dict[str, Any]: + """Run a foundry session with *num_workers* parallel workers. + + For ``num_workers == 1``, runs in the current process (no subprocess + overhead). For ``num_workers > 1``, spawns child processes and merges + the per-worker NDJSON shards on completion. 
+ + Returns a summary dict with ``session_dir``, ``workers`` (list of + per-worker results), ``merged_path``, and ``total_elapsed_s``. + """ + session_path = Path(session_dir) + session_path.mkdir(parents=True, exist_ok=True) + + configs = [ + WorkerConfig( + session_dir=session_dir, + worker_id=wid, + num_workers=num_workers, + ops=ops, + n_attempts=n_attempts, + mode=mode, + seed=seed, + backend=backend, + correctness_tests=correctness_tests, + warmup=warmup, + repeats=repeats, + bench_timeout_s=bench_timeout_s, + stage=stage, + ) + for wid in range(num_workers) + ] + + t0 = time.monotonic() + + if num_workers == 1: + results = [run_worker(configs[0])] + else: + with mp.Pool(processes=num_workers) as pool: + results = pool.map(run_worker, configs) + + # Merge worker shards + merged = merge_worker_logs(session_path) + total_elapsed = time.monotonic() - t0 + + return { + "session_dir": str(session_path), + "num_workers": num_workers, + "workers": results, + "merged_path": str(merged), + "total_elapsed_s": round(total_elapsed, 2), + } diff --git a/src/zmlx/fusion/__init__.py b/src/zmlx/fusion/__init__.py new file mode 100644 index 0000000..fbaf5a3 --- /dev/null +++ b/src/zmlx/fusion/__init__.py @@ -0,0 +1,10 @@ +"""Public fusion APIs.""" + +from .aot import build_aot_artifact, write_aot_artifact +from .jit import jit + +__all__ = [ + "jit", + "build_aot_artifact", + "write_aot_artifact", +] diff --git a/src/zmlx/fusion/aot.py b/src/zmlx/fusion/aot.py new file mode 100644 index 0000000..f14d830 --- /dev/null +++ b/src/zmlx/fusion/aot.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from .runtime import ( + CompiledFusedElementwise, + CompiledFusedReductionTwoPass, + CompiledQuantizedSharedInputFusion, + export_compiled_aot_payload, +) + + +@dataclass(frozen=True) +class FusedAOTArtifact: + """Minimal AOT artifact scaffold for fused kernels.""" + + descriptor: 
dict[str, Any]
+    kernels: dict[str, str]
+    sources: dict[str, str]
+
+
+def build_aot_artifact(
+    *,
+    descriptor: dict[str, Any],
+    compiled: CompiledFusedElementwise
+    | CompiledFusedReductionTwoPass
+    | CompiledQuantizedSharedInputFusion,
+) -> FusedAOTArtifact:
+    """Build an in-memory AOT artifact from a compiled fused kernel.
+
+    Delegates serialization to ``export_compiled_aot_payload`` and copies
+    the payload dicts so the artifact does not alias runtime state.
+    """
+    payload = export_compiled_aot_payload(descriptor=descriptor, compiled=compiled)
+    return FusedAOTArtifact(
+        descriptor=dict(payload["descriptor"]),
+        kernels=dict(payload["kernels"]),
+        sources=dict(payload["sources"]),
+    )
+
+
+def write_aot_artifact(
+    artifact: FusedAOTArtifact,
+    *,
+    output_dir: str | Path,
+    stem: str = "fused_kernel",
+) -> Path:
+    """Write descriptor JSON + MSL source files for offline packaging.
+
+    Creates ``output_dir`` if missing; writes ``<stem>.descriptor.json``
+    plus one ``<stem>.<name>.metal`` per source. Returns the descriptor path.
+    """
+
+    out_dir = Path(output_dir)
+    out_dir.mkdir(parents=True, exist_ok=True)
+    descriptor_path = out_dir / f"{stem}.descriptor.json"
+    descriptor_path.write_text(
+        json.dumps(
+            {
+                "descriptor": artifact.descriptor,
+                "kernels": artifact.kernels,
+                # Only source *names* go in the JSON; bodies get their own files.
+                "sources": sorted(artifact.sources.keys()),
+            },
+            indent=2,
+            sort_keys=True,
+        ),
+        encoding="utf-8",
+    )
+    for source_name, source in artifact.sources.items():
+        source_path = out_dir / f"{stem}.{source_name}.metal"
+        source_path.write_text(source, encoding="utf-8")
+    return descriptor_path
+
+
+__all__ = [
+    "FusedAOTArtifact",
+    "build_aot_artifact",
+    "write_aot_artifact",
+]
diff --git a/src/zmlx/fusion/cache.py b/src/zmlx/fusion/cache.py
new file mode 100644
index 0000000..b1b0e1d
--- /dev/null
+++ b/src/zmlx/fusion/cache.py
@@ -0,0 +1,105 @@
+from __future__ import annotations
+
+from collections import OrderedDict
+from collections.abc import Sequence
+from dataclasses import dataclass
+from threading import RLock
+from typing import Any
+
+
+def _shape_key(shape: Sequence[int]) -> tuple[int, ...]:
+    # Normalize any int-like sequence into a hashable tuple of plain ints.
+    return tuple(int(d) for d in shape)
+
+
+def _dtype_key(dtype: Any) -> str:
+    # MLX dtype objects stringify deterministically (e.g. "float16").
+    return str(dtype)
+
+
+@dataclass(frozen=True)
+class FusionCompileKey:
+    """Key for fused elementwise compile cache entries."""
+
+    op_signature: str
+    input_shapes: tuple[tuple[int, ...], ...]
+    input_dtypes: tuple[str, ...]
+
+    @classmethod
+    def from_parts(
+        cls,
+        *,
+        op_signature: str,
+        input_shapes: Sequence[Sequence[int]],
+        input_dtypes: Sequence[Any],
+    ) -> FusionCompileKey:
+        """Build a key from raw parts, normalizing shapes/dtypes to hashables."""
+        return cls(
+            op_signature=str(op_signature),
+            input_shapes=tuple(_shape_key(shape) for shape in input_shapes),
+            input_dtypes=tuple(_dtype_key(dtype) for dtype in input_dtypes),
+        )
+
+
+class FusionCompileCache:
+    """Thread-safe in-process LRU cache for compiled fused kernels.
+
+    Recency is tracked with an ``OrderedDict``: hits and inserts move the
+    key to the MRU end; eviction pops from the LRU end. An ``RLock`` guards
+    all access so methods may nest (e.g. ``set_max_entries`` -> eviction).
+    """
+
+    def __init__(self, max_entries: int = 1000) -> None:
+        if max_entries <= 0:
+            raise ValueError("max_entries must be > 0")
+        self._max_entries = int(max_entries)
+        self._entries: OrderedDict[FusionCompileKey, Any] = OrderedDict()
+        self._lock = RLock()
+
+    @property
+    def max_entries(self) -> int:
+        return self._max_entries
+
+    def set_max_entries(self, max_entries: int) -> None:
+        """Resize the cache, evicting LRU entries if it shrinks."""
+        if max_entries <= 0:
+            raise ValueError("max_entries must be > 0")
+        with self._lock:
+            self._max_entries = int(max_entries)
+            self._evict_if_needed()
+
+    def get(self, key: FusionCompileKey) -> Any | None:
+        """Return the cached value or None; a hit refreshes recency."""
+        with self._lock:
+            value = self._entries.get(key)
+            if value is None:
+                return None
+            self._entries.move_to_end(key)
+            return value
+
+    def put(self, key: FusionCompileKey, value: Any) -> Any:
+        """Insert (or overwrite) an entry, evicting LRU entries past capacity."""
+        with self._lock:
+            self._entries[key] = value
+            self._entries.move_to_end(key)
+            self._evict_if_needed()
+            return value
+
+    def clear(self) -> None:
+        with self._lock:
+            self._entries.clear()
+
+    def size(self) -> int:
+        with self._lock:
+            return len(self._entries)
+
+    def keys(self) -> list[FusionCompileKey]:
+        # Snapshot copy so callers can iterate without holding the lock.
+        with self._lock:
+            return list(self._entries.keys())
+
+    def __len__(self) -> int:
+        return self.size()
+
+    def _evict_if_needed(self) -> None:
+        # Caller must hold self._lock. Pops from the LRU (oldest) end.
+        while len(self._entries) > self._max_entries:
+
self._entries.popitem(last=False) + + +GLOBAL_FUSION_COMPILE_CACHE = FusionCompileCache(max_entries=1000) + + +__all__ = [ + "FusionCompileKey", + "FusionCompileCache", + "GLOBAL_FUSION_COMPILE_CACHE", +] diff --git a/src/zmlx/fusion/codegen.py b/src/zmlx/fusion/codegen.py new file mode 100644 index 0000000..afa5bc4 --- /dev/null +++ b/src/zmlx/fusion/codegen.py @@ -0,0 +1,741 @@ +from __future__ import annotations + +import hashlib +import json +from dataclasses import dataclass +from typing import Any + + +@dataclass(frozen=True) +class FusedNode: + """Single node in a fused elementwise DAG.""" + + id: str + op: str + args: tuple[str, ...] + + +@dataclass(frozen=True) +class FusedGraph: + """Normalized fused elementwise graph descriptor.""" + + input_names: tuple[str, ...] + nodes: tuple[FusedNode, ...] + output: str + name: str = "" + + def to_payload(self) -> dict[str, Any]: + return { + "name": self.name, + "inputs": list(self.input_names), + "nodes": [ + { + "id": node.id, + "op": node.op, + "args": list(node.args), + } + for node in self.nodes + ], + "output": self.output, + } + + +@dataclass(frozen=True) +class GeneratedFusedSource: + """Generated source payload used by the runtime compiler.""" + + kernel_name: str + input_names: tuple[str, ...] 
@dataclass(frozen=True)
class GeneratedReductionTwoPassSource:
    """Generated source payload for two-pass reduction fusion."""

    pass1_kernel_name: str
    pass2_kernel_name: str
    source_pass1: str
    source_pass2: str
    op_signature: str
    reduce_op: str
    no_fma: bool
    rows: int
    reduce_extent: int
    partial_cols: int
    rhs_mode: str


@dataclass(frozen=True)
class GeneratedQuantizedSharedInputSource:
    """Generated source payload for decode-oriented quantized QMM fusion."""

    kernel_name: str
    source: str
    op_signature: str
    bits: int
    group_size: int
    n: int
    k: int


# Alternate spellings accepted for op names, mapped to their canonical form.
_OP_ALIASES: dict[str, str] = {
    "plus": "add",
    "minus": "sub",
    "times": "mul",
    "multiply": "mul",
    "swish": "silu",
    "minimum": "min",
    "maximum": "max",
}


# op name -> (arity, MSL expression emitter over already-rendered args).
_SUPPORTED_OPS: dict[str, tuple[int, Any]] = {
    "add": (2, lambda v: f"({v[0]} + {v[1]})"),
    "sub": (2, lambda v: f"({v[0]} - {v[1]})"),
    "mul": (2, lambda v: f"({v[0]} * {v[1]})"),
    "div": (2, lambda v: f"({v[0]} / {v[1]})"),
    "neg": (1, lambda v: f"(-({v[0]}))"),
    "exp": (1, lambda v: f"metal::exp({v[0]})"),
    "tanh": (1, lambda v: f"metal::tanh({v[0]})"),
    "sigmoid": (1, lambda v: f"kk_sigmoid({v[0]})"),
    "relu": (1, lambda v: f"metal::max({v[0]}, 0.0f)"),
    "silu": (1, lambda v: f"kk_silu({v[0]})"),
    "log": (1, lambda v: f"metal::log({v[0]})"),
    "sqrt": (1, lambda v: f"metal::sqrt({v[0]})"),
    "rsqrt": (1, lambda v: f"metal::rsqrt({v[0]})"),
    "min": (2, lambda v: f"metal::min({v[0]}, {v[1]})"),
    "max": (2, lambda v: f"metal::max({v[0]}, {v[1]})"),
}


# Broadcast mode of the reduction RHS, encoded as a kernel constant.
_RHS_MODE_TO_INT: dict[str, int] = {
    "full": 0,
    "row": 1,
    "scalar": 2,
}


def _is_seq(x: Any) -> bool:
    """True for list/tuple (the only sequence kinds descriptors may use)."""
    return isinstance(x, (list, tuple))


def _canonical_op_name(op: Any) -> str:
    """Lower-case, alias-resolve, and validate a fusion op name."""
    if not isinstance(op, str):
        raise TypeError(f"op must be str, got {type(op).__name__}")
    normalized = op.strip().lower()
    normalized = _OP_ALIASES.get(normalized, normalized)
    if normalized in _SUPPORTED_OPS:
        return normalized
    supported = ", ".join(sorted(_SUPPORTED_OPS))
    raise ValueError(f"Unsupported fusion op {op!r}. Supported ops: {supported}")
def _extract_input_names(descriptor: dict[str, Any]) -> tuple[str, ...]:
    """Resolve the descriptor's input names from its many accepted spellings."""
    raw = None
    for field_name in ("input_names", "inputs", "args"):
        raw = descriptor.get(field_name)
        if raw is not None:
            break

    # An int means "this many anonymous inputs".
    if isinstance(raw, int):
        if raw <= 0:
            raise ValueError("descriptor input count must be > 0")
        return tuple(f"in{i}" for i in range(int(raw)))

    if not isinstance(raw, (list, tuple)):
        # Last resort: an explicit input count.
        n = descriptor.get("num_inputs")
        if isinstance(n, int) and n > 0:
            return tuple(f"in{i}" for i in range(int(n)))
        raise ValueError("descriptor must provide 'inputs' or 'input_names'")

    out: list[str] = []
    for i, item in enumerate(raw):
        if isinstance(item, dict):
            name = item.get("name") or item.get("id") or item.get("input") or f"in{i}"
            out.append(str(name))
        else:
            out.append(str(item))

    if not out:
        raise ValueError("descriptor must define at least one input")
    if len(set(out)) != len(out):
        raise ValueError(f"input names must be unique, got {out}")
    return tuple(out)


def _extract_raw_nodes(descriptor: dict[str, Any]) -> list[dict[str, Any]]:
    """Return the node dict list, accepting several descriptor spellings."""
    raw_nodes = None
    for field_name in ("nodes", "ops", "instructions"):
        raw_nodes = descriptor.get(field_name)
        if raw_nodes is not None:
            break

    if raw_nodes is None:
        # A top-level single-op descriptor is its own node.
        if "op" in descriptor:
            return [descriptor]
        raise ValueError("descriptor must provide 'nodes'/'ops' or a top-level 'op'")

    if not isinstance(raw_nodes, (list, tuple)):
        raise TypeError("descriptor nodes must be a list/tuple")

    out: list[dict[str, Any]] = []
    for raw in raw_nodes:
        if not isinstance(raw, dict):
            raise TypeError(f"node entries must be dicts, got {type(raw).__name__}")
        out.append(raw)
    return out


def _extract_node_args(node: dict[str, Any]) -> list[Any]:
    """Pull a node's argument list out of whichever key spelling it uses."""
    for key in ("args", "inputs"):
        if key in node:
            args = node[key]
            break
    else:
        if "lhs" in node and "rhs" in node:
            args = [node["lhs"], node["rhs"]]
        elif "a" in node and "b" in node:
            args = [node["a"], node["b"]]
        elif "x" in node:
            args = [node["x"]]
        elif "input" in node:
            args = [node["input"]]
        else:
            args = []

    # A bare scalar/ref becomes a one-element list.
    return list(args) if isinstance(args, (list, tuple)) else [args]
def _normalize_ref(raw: Any, input_names: tuple[str, ...], node_ids: tuple[str, ...]) -> str:
    """Resolve a reference (string, positional int, or wrapper dict) to a name."""
    if isinstance(raw, str):
        return raw

    if isinstance(raw, int):
        # Integers index inputs first, then nodes.
        idx = int(raw)
        if 0 <= idx < len(input_names):
            return input_names[idx]
        if 0 <= idx < len(node_ids):
            return node_ids[idx]
        raise ValueError(f"integer ref {raw} is out of range")

    if isinstance(raw, dict):
        if "input" in raw:
            return _normalize_ref(raw["input"], input_names, node_ids)
        if "node" in raw:
            return _normalize_ref(raw["node"], input_names, node_ids)
        if len(raw) == 1 and "id" in raw:
            return str(raw["id"])
        if len(raw) == 1 and "name" in raw:
            return str(raw["name"])

    raise TypeError(f"unsupported ref type {type(raw).__name__}: {raw!r}")


def _extract_output_ref(
    descriptor: dict[str, Any],
    input_names: tuple[str, ...],
    node_ids: tuple[str, ...],
) -> str:
    """Find the descriptor's output ref; default to the last node."""
    raw = None
    for field_name in ("output", "out", "result"):
        raw = descriptor.get(field_name)
        if raw is not None:
            break

    if raw is None:
        if node_ids:
            return node_ids[-1]
        raise ValueError("descriptor must define an output")

    if isinstance(raw, (list, tuple)):
        if len(raw) != 1:
            raise ValueError("phase-1 fused elementwise only supports one output")
        raw = raw[0]

    return _normalize_ref(raw, input_names, node_ids)


def _toposort_nodes(
    *,
    nodes_by_id: dict[str, FusedNode],
    output_ref: str,
    input_names: tuple[str, ...],
) -> tuple[FusedNode, ...]:
    """DFS from the output, returning only reachable nodes in dependency order."""
    inputs = set(input_names)
    done: set[str] = set()
    in_progress: set[str] = set()
    ordered: list[FusedNode] = []

    def _dfs(ref: str) -> None:
        if ref in inputs or ref in done:
            return
        if ref in in_progress:
            raise ValueError(f"cycle detected at {ref!r}")
        node = nodes_by_id.get(ref)
        if node is None:
            raise ValueError(f"unknown ref {ref!r}")
        in_progress.add(ref)
        for arg in node.args:
            _dfs(arg)
        in_progress.remove(ref)
        done.add(ref)
        ordered.append(node)

    _dfs(output_ref)
    return tuple(ordered)
def normalize_descriptor(descriptor: FusedGraph | dict[str, Any]) -> FusedGraph:
    """Normalize a user/fusion-rule descriptor into a deterministic graph."""

    if isinstance(descriptor, FusedGraph):
        return descriptor
    if not isinstance(descriptor, dict):
        raise TypeError(f"descriptor must be dict-like, got {type(descriptor).__name__}")

    input_names = _extract_input_names(descriptor)
    raw_nodes = _extract_raw_nodes(descriptor)

    # Assign each raw node a stable id (explicit id/name/out, else positional).
    node_ids: list[str] = []
    for i, raw in enumerate(raw_nodes):
        explicit = raw.get("id") or raw.get("name") or raw.get("out")
        node_ids.append(str(explicit) if explicit is not None else f"n{i}")

    if len(set(node_ids)) != len(node_ids):
        raise ValueError(f"node ids must be unique, got {node_ids}")
    input_name_set = set(input_names)
    for node_id in node_ids:
        if node_id in input_name_set:
            raise ValueError(f"node id {node_id!r} collides with an input name")

    # Canonicalize every node's op, validate arity, and resolve arg refs.
    node_ids_t = tuple(node_ids)
    nodes: list[FusedNode] = []
    for node_id, raw in zip(node_ids, raw_nodes):
        op = _canonical_op_name(raw.get("op") or raw.get("kind") or raw.get("type"))
        raw_args = _extract_node_args(raw)
        expected_arity, _ = _SUPPORTED_OPS[op]
        if len(raw_args) != expected_arity:
            raise ValueError(f"node {node_id!r} op {op!r} expects {expected_arity} args")
        resolved = tuple(_normalize_ref(arg, input_names, node_ids_t) for arg in raw_args)
        nodes.append(FusedNode(id=node_id, op=op, args=resolved))

    output_ref = _extract_output_ref(descriptor, input_names, node_ids_t)
    nodes_by_id = {node.id: node for node in nodes}
    # Keep only nodes reachable from the output, in dependency order.
    topo_nodes = _toposort_nodes(
        nodes_by_id=nodes_by_id,
        output_ref=output_ref,
        input_names=input_names,
    )

    if output_ref not in input_name_set and output_ref not in nodes_by_id:
        raise ValueError(f"output ref {output_ref!r} must reference an input or node")

    name = str(descriptor.get("name") or descriptor.get("kernel_name") or "")
    return FusedGraph(
        input_names=input_names,
        nodes=topo_nodes,
        output=output_ref,
        name=name,
    )
def _sanitize_identifier(name: str, fallback: str) -> str:
    """Coerce *name* into a valid MSL/C identifier, using *fallback* if empty."""
    chars = [
        ch if (ch.isascii() and (ch.isalnum() or ch == "_")) else "_"
        for ch in name
    ]
    cleaned = "".join(chars) or fallback
    if cleaned[0].isalpha() or cleaned[0] == "_":
        return cleaned
    return f"_{cleaned}"


def _sanitize_input_names(input_names: tuple[str, ...], output_name: str) -> tuple[str, ...]:
    """Sanitize input names, suffixing to avoid the output name and duplicates."""
    used: set[str] = {output_name}
    out: list[str] = []
    for i, name in enumerate(input_names):
        base = _sanitize_identifier(name, f"in{i}")
        candidate = base
        suffix = 1
        while candidate in used:
            candidate = f"{base}_{suffix}"
            suffix += 1
        used.add(candidate)
        out.append(candidate)
    return tuple(out)


def _emit_op(op: str, args: tuple[str, ...]) -> str:
    """Render one op over already-rendered argument expressions."""
    _, emitter = _SUPPORTED_OPS[op]
    return str(emitter(args))


def _hash_payload(payload: dict[str, Any]) -> str:
    """Stable sha256 of a JSON payload (sorted keys, compact separators)."""
    encoded = json.dumps(payload, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(encoded.encode("utf-8")).hexdigest()


def fused_op_signature(descriptor: FusedGraph | dict[str, Any]) -> str:
    """Return a stable signature for a fused graph's operation structure."""

    graph = normalize_descriptor(descriptor)
    return _hash_payload(graph.to_payload())


def generate_fused_elementwise_source(
    descriptor: FusedGraph | dict[str, Any],
    *,
    n_elements: int,
    kernel_name: str | None = None,
    output_name: str = "out",
) -> GeneratedFusedSource:
    """Generate MSL body for one fused elementwise kernel launch shape."""

    graph = normalize_descriptor(descriptor)
    signature = fused_op_signature(graph)

    safe_output_name = _sanitize_identifier(output_name, "out")
    safe_inputs = _sanitize_input_names(graph.input_names, safe_output_name)

    value_by_ref: dict[str, str] = {}
    # Prologue: bounds check against the baked element count.
    lines: list[str] = [
        f"constexpr uint N_ELEMS = {int(n_elements)}u;",
        "uint idx = thread_position_in_grid.x;",
        "if (idx >= N_ELEMS) {",
        "    return;",
        "}",
    ]

    # Load each input as float32 for numerically uniform intermediate math.
    for i, (raw_name, safe_name) in enumerate(zip(graph.input_names, safe_inputs, strict=True)):
        var = f"in_{i}"
        lines.append(f"float {var} = (float){safe_name}[idx];")
        value_by_ref[raw_name] = var

    # Emit nodes in topological order, binding each result to a temp.
    for i, node in enumerate(graph.nodes):
        rendered_args = tuple(value_by_ref[arg] for arg in node.args)
        var = f"tmp_{i}"
        lines.append(f"float {var} = {_emit_op(node.op, rendered_args)};")
        value_by_ref[node.id] = var

    if graph.output not in value_by_ref:
        raise ValueError(f"output ref {graph.output!r} was not resolved")

    lines.append(f"{safe_output_name}[idx] = (O)({value_by_ref[graph.output]});")
    source = "\n".join(lines)

    if kernel_name is None:
        base_name = graph.name or f"kk_fused_elem_{signature[:12]}"
        kernel_name = f"{_sanitize_identifier(base_name, 'kk_fused_elem')}_{int(n_elements)}"
    else:
        kernel_name = _sanitize_identifier(kernel_name, "kk_fused_elem")

    return GeneratedFusedSource(
        kernel_name=kernel_name,
        input_names=safe_inputs,
        output_name=safe_output_name,
        source=source,
        op_signature=signature,
        graph=graph,
    )
def generate_broadcast_mul_reduce_two_pass_source(
    *,
    rows: int,
    reduce_extent: int,
    partial_cols: int,
    reduce_op: str,
    rhs_mode: str,
    no_fma: bool,
    kernel_base: str | None = None,
) -> GeneratedReductionTwoPassSource:
    """Generate two-pass MSL for ``mul + reduce`` with float32 accumulation.

    Pass 1 reduces each row into ``partial_cols`` partials; pass 2 folds the
    partials into one value per row.  ``rhs_mode`` selects how the second
    operand is indexed (full tensor, one value per row, or a single scalar).
    ``no_fma`` disables FP contraction so results match unfused arithmetic.
    NOTE(review): the templates assume a power-of-two threadgroup width of at
    most 512 (fixed ``scratch`` size) — confirm against the launcher.
    """

    if reduce_op not in {"sum", "mean", "max"}:
        raise ValueError(f"Unsupported reduction op {reduce_op!r}.")
    if rhs_mode not in _RHS_MODE_TO_INT:
        raise ValueError(f"Unsupported rhs_mode {rhs_mode!r}.")
    if rows <= 0 or reduce_extent <= 0 or partial_cols <= 0:
        raise ValueError("rows/reduce_extent/partial_cols must be > 0.")

    payload = {
        "kind": "broadcast_mul_reduce",
        "rows": int(rows),
        "reduce_extent": int(reduce_extent),
        "partial_cols": int(partial_cols),
        "reduce_op": reduce_op,
        "rhs_mode": rhs_mode,
        "no_fma": bool(no_fma),
    }
    signature = _hash_payload(payload)
    kernel_root = _sanitize_identifier(
        kernel_base or f"kk_fused_reduce_{signature[:10]}",
        "kk_fused_reduce",
    )
    pass1_name = f"{kernel_root}_pass1"
    pass2_name = f"{kernel_root}_pass2"
    rhs_mode_int = _RHS_MODE_TO_INT[rhs_mode]

    # max-reduce needs a -inf identity and max-combining; sum/mean add.
    if reduce_op == "max":
        init_value = "-INFINITY"
        combine_expr = "metal::max(scratch[tid], scratch[tid + stride])"
        row_reduce_combine = "acc = metal::max(acc, partials[base + g]);"
    else:
        init_value = "0.0f"
        combine_expr = "scratch[tid] + scratch[tid + stride]"
        row_reduce_combine = "acc += partials[base + g];"

    if no_fma:
        mul_acc_stmt = "float prod = a * b; acc += prod;"
        pragma_stmt = "#pragma clang fp contract(off)\n"
    else:
        mul_acc_stmt = "acc = metal::fma(a, b, acc);"
        pragma_stmt = ""

    # max never accumulates with fma regardless of the no_fma flag.
    if reduce_op == "max":
        mul_acc_stmt = "float prod = a * b; acc = metal::max(acc, prod);"

    mean_finalize = "float outv = acc / (float)REDUCE_EXTENT;" if reduce_op == "mean" else "float outv = acc;"

    source_pass1 = f"""
    {pragma_stmt}constexpr uint ROWS = {int(rows)}u;
    constexpr uint REDUCE_EXTENT = {int(reduce_extent)}u;
    constexpr uint PARTIAL_COLS = {int(partial_cols)}u;
    constexpr uint RHS_MODE = {rhs_mode_int}u;

    uint tid = thread_position_in_threadgroup.x;
    uint packed = thread_position_in_grid.y;
    uint row = packed / PARTIAL_COLS;
    uint part = packed % PARTIAL_COLS;
    if (row >= ROWS || part >= PARTIAL_COLS) {{
        return;
    }}

    uint tgx = threads_per_threadgroup.x;
    uint lhs_base = row * REDUCE_EXTENT;
    uint start = part * tgx + tid;

    float acc = {init_value};
    uint stride = PARTIAL_COLS * tgx;
    for (uint j = start; j < REDUCE_EXTENT; j += stride) {{
        uint lhs_idx = lhs_base + j;
        float a = (float)lhs[lhs_idx];
        float b = 0.0f;
        if (RHS_MODE == 0u) {{
            b = (float)rhs[lhs_idx];
        }} else if (RHS_MODE == 1u) {{
            b = (float)rhs[row];
        }} else {{
            b = (float)rhs[0];
        }}
        {mul_acc_stmt}
    }}

    threadgroup float scratch[512];
    scratch[tid] = acc;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (uint stride = tgx >> 1; stride > 0; stride >>= 1) {{
        if (tid < stride) {{
            scratch[tid] = {combine_expr};
        }}
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }}

    if (tid == 0u) {{
        partials[row * PARTIAL_COLS + part] = scratch[0];
    }}
    """

    source_pass2 = f"""
    constexpr uint ROWS = {int(rows)}u;
    constexpr uint PARTIAL_COLS = {int(partial_cols)}u;
    constexpr uint REDUCE_EXTENT = {int(reduce_extent)}u;

    uint tid = thread_position_in_threadgroup.x;
    uint row = thread_position_in_grid.y;
    if (row >= ROWS) {{
        return;
    }}

    uint tgx = threads_per_threadgroup.x;
    uint base = row * PARTIAL_COLS;
    float acc = {init_value};
    for (uint g = tid; g < PARTIAL_COLS; g += tgx) {{
        {row_reduce_combine}
    }}

    threadgroup float scratch[512];
    scratch[tid] = acc;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (uint stride = tgx >> 1; stride > 0; stride >>= 1) {{
        if (tid < stride) {{
            scratch[tid] = {combine_expr};
        }}
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }}

    if (tid == 0u) {{
        {mean_finalize}
        out[row] = (O)(outv);
    }}
    """

    return GeneratedReductionTwoPassSource(
        pass1_kernel_name=pass1_name,
        pass2_kernel_name=pass2_name,
        source_pass1=source_pass1,
        source_pass2=source_pass2,
        op_signature=signature,
        reduce_op=reduce_op,
        no_fma=bool(no_fma),
        rows=int(rows),
        reduce_extent=int(reduce_extent),
        partial_cols=int(partial_cols),
        rhs_mode=rhs_mode,
    )
def generate_quantized_shared_input_source(
    *,
    n: int,
    k: int,
    bits: int,
    group_size: int,
    no_fma: bool = False,
    kernel_name: str | None = None,
) -> GeneratedQuantizedSharedInputSource:
    """Generate decode-oriented QMM pair + SwiGLU source (M=1).

    One thread per output column dequantizes the packed gate/up weight rows,
    accumulates both matvecs against the shared input ``x`` in float32, and
    applies SwiGLU (``silu(gate) * up``) before the final cast.
    """

    if bits not in (4, 8):
        raise ValueError("bits must be 4 or 8")
    if group_size <= 0:
        raise ValueError("group_size must be > 0")
    if n <= 0 or k <= 0:
        raise ValueError("n and k must be > 0")
    if k % group_size != 0:
        raise ValueError("group_size must divide k")

    # Quantized values are packed into 32-bit words.
    elements_per_word = 32 // bits
    if k % elements_per_word != 0:
        raise ValueError("k must be divisible by packed elements per word")

    k_packed = k // elements_per_word
    groups_per_row = k // group_size
    mask = (1 << bits) - 1
    payload = {
        "kind": "quantized_shared_input_swiglu",
        "n": int(n),
        "k": int(k),
        "bits": int(bits),
        "group_size": int(group_size),
        "no_fma": bool(no_fma),
    }
    signature = _hash_payload(payload)

    # no_fma keeps multiply and add separate so results match unfused math.
    pragma_stmt = "#pragma clang fp contract(off)\n" if no_fma else ""
    if no_fma:
        gate_acc = "float g_prod = gate_val * xv; acc_gate += g_prod;"
        up_acc = "float u_prod = up_val * xv; acc_up += u_prod;"
    else:
        gate_acc = "acc_gate = metal::fma(gate_val, xv, acc_gate);"
        up_acc = "acc_up = metal::fma(up_val, xv, acc_up);"

    source = f"""
    {pragma_stmt}constexpr uint N = {int(n)}u;
    constexpr uint K = {int(k)}u;
    constexpr uint K_PACKED = {int(k_packed)}u;
    constexpr uint GROUP = {int(group_size)}u;
    constexpr uint GROUPS = {int(groups_per_row)}u;
    constexpr uint ELEMENTS = {int(elements_per_word)}u;
    constexpr uint MASK = {int(mask)}u;

    uint col = thread_position_in_grid.x;
    if (col >= N) {{
        return;
    }}

    uint w_row = col * K_PACKED;
    uint s_row = col * GROUPS;
    float acc_gate = 0.0f;
    float acc_up = 0.0f;

    for (uint i = 0; i < K; ++i) {{
        uint pack_idx = i / ELEMENTS;
        uint shift = (i % ELEMENTS) * {int(bits)};
        uint q_gate = (gate_w[w_row + pack_idx] >> shift) & MASK;
        uint q_up = (up_w[w_row + pack_idx] >> shift) & MASK;
        uint g = i / GROUP;

        float gate_val = (float)q_gate * (float)gate_scales[s_row + g]
            + (float)gate_biases[s_row + g];
        float up_val = (float)q_up * (float)up_scales[s_row + g]
            + (float)up_biases[s_row + g];
        float xv = (float)x[i];

        {gate_acc}
        {up_acc}
    }}

    float outv = (acc_gate * kk_sigmoid(acc_gate)) * acc_up;
    out[col] = (O)outv;
    """

    safe_name = _sanitize_identifier(
        kernel_name or f"kk_qshared_swiglu_{signature[:12]}",
        "kk_qshared_swiglu",
    )
    return GeneratedQuantizedSharedInputSource(
        kernel_name=safe_name,
        source=source,
        op_signature=signature,
        bits=int(bits),
        group_size=int(group_size),
        n=int(n),
        k=int(k),
    )
def generate_msl(
    descriptor: FusedGraph | dict[str, Any],
    *,
    n_elements: int,
    kernel_name: str | None = None,
    output_name: str = "out",
) -> GeneratedFusedSource:
    """Alias for generate_fused_elementwise_source()."""

    return generate_fused_elementwise_source(
        descriptor,
        n_elements=n_elements,
        kernel_name=kernel_name,
        output_name=output_name,
    )


__all__ = [
    "FusedNode",
    "FusedGraph",
    "GeneratedFusedSource",
    "GeneratedReductionTwoPassSource",
    "GeneratedQuantizedSharedInputSource",
    "normalize_descriptor",
    "fused_op_signature",
    "generate_fused_elementwise_source",
    "generate_broadcast_mul_reduce_two_pass_source",
    "generate_quantized_shared_input_source",
    "generate_msl",
]

# --- file: src/zmlx/fusion/graph.py ---
"""Graph IR for phase-1/2/3 JIT fusion."""

# (from __future__ import annotations — hoisted to the top of the module)

from collections.abc import Iterable, Iterator, Mapping
from dataclasses import dataclass, field
from types import MappingProxyType
from typing import Any

# Closed op vocabulary recognized by the tracer and fusion passes.
ELEMENTWISE_UNARY_OPS: frozenset[str] = frozenset(
    {"neg", "exp", "tanh", "sigmoid", "relu", "silu", "log", "sqrt", "rsqrt"}
)
ELEMENTWISE_BINARY_OPS: frozenset[str] = frozenset(
    {"add", "sub", "mul", "div", "maximum", "minimum"}
)
ELEMENTWISE_OPS: frozenset[str] = ELEMENTWISE_UNARY_OPS | ELEMENTWISE_BINARY_OPS
REDUCTION_OPS: frozenset[str] = frozenset({"sum", "mean", "max"})
QUANTIZED_LINEAR_OPS: frozenset[str] = frozenset({"quantized_matmul"})
SUPPORTED_OPS: frozenset[str] = ELEMENTWISE_OPS | REDUCTION_OPS | QUANTIZED_LINEAR_OPS


@dataclass(frozen=True, slots=True)
class TensorMeta:
    """Tensor metadata captured during symbolic tracing."""

    shape: tuple[int, ...]
    dtype: str

    @property
    def is_scalar(self) -> bool:
        # Rank-0 shape marks a scalar.
        return self.shape == ()


@dataclass(frozen=True, slots=True)
class Node:
    """A graph node representing an input, scalar literal, or tensor op."""

    id: int
    op: str
    inputs: tuple[int, ...] = ()
    meta: TensorMeta | None = None
    attrs: Mapping[str, Any] = field(default_factory=dict)


@dataclass(frozen=True, slots=True)
class FusedOp:
    """Descriptor emitted by a fusion rule for downstream JIT lowering."""

    rule: str
    kind: str
    node_ids: tuple[int, ...]
    input_ids: tuple[int, ...]
    output_id: int
    output_meta: TensorMeta
    ops: tuple[str, ...]
    attrs: Mapping[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Snapshot attrs into a read-only view so a caller's dict cannot be
        # mutated after the descriptor is built (frozen dataclass requires
        # object.__setattr__).
        object.__setattr__(self, "attrs", MappingProxyType(dict(self.attrs)))
class Graph:
    """Small DAG used by phase-1 symbolic tracing and fusion matching.

    The graph is append-only: node ids are allocated monotonically and the
    insertion order is a valid topological order.
    """

    def __init__(self) -> None:
        self._nodes: dict[int, Node] = {}
        self._order: list[int] = []
        self._users: dict[int, list[int]] = {}
        self._next_id = 0

    def __len__(self) -> int:
        return len(self._order)

    def __iter__(self) -> Iterator[Node]:
        for node_id in self._order:
            yield self._nodes[node_id]

    def topological_order(self) -> tuple[int, ...]:
        """Return node ids in insertion order (the graph is append-only DAG)."""
        return tuple(self._order)

    def node(self, node_id: int) -> Node:
        """Return node by id."""
        return self._nodes[node_id]

    def users(self, node_id: int) -> tuple[int, ...]:
        """Return all direct consumers of ``node_id``."""
        return tuple(self._users.get(node_id, ()))

    def add_input(self, *, name: str, meta: TensorMeta) -> int:
        """Add an input placeholder."""
        return self._add_node(op="input", inputs=(), meta=meta, attrs={"name": name})

    def add_constant(self, *, value: int | float | bool, meta: TensorMeta) -> int:
        """Add a scalar literal node."""
        return self._add_node(op="constant", inputs=(), meta=meta, attrs={"value": value})

    def add_op(
        self,
        op: str,
        *,
        inputs: Iterable[int],
        meta: TensorMeta,
        attrs: Mapping[str, Any] | None = None,
    ) -> int:
        """Add a tensor op node with shape/dtype metadata."""
        if op not in SUPPORTED_OPS:
            raise ValueError(f"Unsupported fusion op: {op}")
        input_ids = tuple(inputs)
        if op == "quantized_matmul":
            # Variadic: x, w, and optional scales/biases.
            if not 2 <= len(input_ids) <= 4:
                raise ValueError(
                    "Op 'quantized_matmul' expects between 2 and 4 inputs "
                    f"(x, w, [scales], [biases]), got {len(input_ids)}"
                )
        else:
            # Everything else is strictly unary or binary.
            expected_arity = 2 if op in ELEMENTWISE_BINARY_OPS else 1
            if len(input_ids) != expected_arity:
                raise ValueError(
                    f"Op {op!r} expects {expected_arity} inputs, got {len(input_ids)}"
                )
        return self._add_node(op=op, inputs=input_ids, meta=meta, attrs=dict(attrs or {}))

    def _add_node(
        self,
        *,
        op: str,
        inputs: tuple[int, ...],
        meta: TensorMeta | None,
        attrs: dict[str, Any] | None,
    ) -> int:
        # Validate producers before allocating the new id.
        for input_id in inputs:
            if input_id not in self._nodes:
                raise KeyError(f"Input node {input_id} does not exist.")

        node_id = self._next_id
        self._next_id += 1

        frozen_attrs: Mapping[str, Any] = MappingProxyType(dict(attrs or {}))
        record = Node(id=node_id, op=op, inputs=inputs, meta=meta, attrs=frozen_attrs)
        self._nodes[node_id] = record
        self._order.append(node_id)
        # Register the new node as a user of each of its inputs.
        self._users.setdefault(node_id, [])
        for input_id in inputs:
            self._users.setdefault(input_id, []).append(node_id)
        return node_id
expects {expected_arity} inputs, got {len(input_ids)}" + ) + elif op == "quantized_matmul": + if len(input_ids) < 2 or len(input_ids) > 4: + raise ValueError( + "Op 'quantized_matmul' expects between 2 and 4 inputs " + f"(x, w, [scales], [biases]), got {len(input_ids)}" + ) + else: + raise ValueError(f"Unsupported fusion op: {op}") + return self._add_node(op=op, inputs=input_ids, meta=meta, attrs=dict(attrs or {})) + + def _add_node( + self, + *, + op: str, + inputs: tuple[int, ...], + meta: TensorMeta | None, + attrs: dict[str, Any] | None, + ) -> int: + for input_id in inputs: + if input_id not in self._nodes: + raise KeyError(f"Input node {input_id} does not exist.") + + node_id = self._next_id + self._next_id += 1 + + frozen_attrs: Mapping[str, Any] = MappingProxyType(dict(attrs or {})) + node = Node(id=node_id, op=op, inputs=inputs, meta=meta, attrs=frozen_attrs) + self._nodes[node_id] = node + self._order.append(node_id) + self._users.setdefault(node_id, []) + for input_id in inputs: + self._users.setdefault(input_id, []).append(node_id) + return node_id diff --git a/src/zmlx/fusion/integration.py b/src/zmlx/fusion/integration.py new file mode 100644 index 0000000..5db365d --- /dev/null +++ b/src/zmlx/fusion/integration.py @@ -0,0 +1,207 @@ +from __future__ import annotations + +from collections.abc import Callable +from functools import wraps +from typing import Any + +from .._compat import import_mx +from ..patch._types import ( + FusionCacheKey, + FusionConfig, + FusionLegacyCacheKey, + FusionRuntimeState, + FusionScopedCacheKey, + FusionSignature, + FusionSignatureItem, +) +from .jit import jit + + +def _is_tensor_like(value: Any) -> bool: + return hasattr(value, "shape") and hasattr(value, "dtype") + + +def _is_scalar(value: Any) -> bool: + return isinstance(value, (bool, int, float)) + + +def _signature_for(args: tuple[Any, ...], kwargs: dict[str, Any]) -> FusionSignature | None: + if kwargs: + return None + signature: list[FusionSignatureItem] = [] + 
def _outputs_close(expected: Any, actual: Any, *, rtol: float, atol: float) -> bool:
    """Compare fused vs. reference outputs within tolerance (recursively)."""
    mx = import_mx()
    if _is_tensor_like(expected) and _is_tensor_like(actual):
        mx.eval(expected, actual)
        return bool(mx.allclose(expected, actual, rtol=rtol, atol=atol).item())
    if isinstance(expected, (tuple, list)) and isinstance(actual, (tuple, list)):
        if len(expected) != len(actual):
            return False
        return all(
            _outputs_close(exp_item, act_item, rtol=rtol, atol=atol)
            for exp_item, act_item in zip(expected, actual, strict=True)
        )
    return expected == actual


def _fusion_stats(fn: Callable[..., Any]) -> dict[str, Any] | None:
    """Best-effort read of a callable's ``fusion_stats()``; None on any failure."""
    getter = getattr(fn, "fusion_stats", None)
    if not callable(getter):
        return None
    try:
        stats = getter()
    except Exception:
        return None
    return stats if isinstance(stats, dict) else None


def _decision_key_for(
    pattern_name: str,
    signature: FusionSignature,
    module_instance_key: str | None,
) -> FusionCacheKey:
    """Scoped 3-tuple key when a module key is given, legacy 2-tuple otherwise."""
    if module_instance_key is None:
        return (pattern_name, signature)
    return (pattern_name, module_instance_key, signature)


def _legacy_decision_key_for(
    pattern_name: str,
    signature: FusionSignature,
) -> FusionLegacyCacheKey:
    """Non-scoped (pattern, signature) key, kept for cache back-compat."""
    return (pattern_name, signature)


def _normalize_legacy_key(
    legacy_key: FusionLegacyCacheKey,
    module_instance_key: str,
) -> FusionScopedCacheKey:
    """Convert a legacy (non-scoped) key to scoped format."""
    pattern_name, signature = legacy_key
    return (pattern_name, module_instance_key, signature)


def _has_key_with_migration(
    entries: set[FusionCacheKey],
    key: FusionCacheKey,
    legacy_key: FusionLegacyCacheKey,
    module_instance_key: str | None,
    normalize: bool,
) -> bool:
    """
    Check if key exists in entries, with optional legacy key migration.

    When normalize=True and a legacy key is found, it is removed from entries
    and the scoped key is added instead (migrating to the new format).

    Returns True if either key format is found.
    """
    if key in entries:
        return True
    if legacy_key not in entries:
        return False
    # Found only the legacy form; migrate it in place when allowed.
    migratable = (
        normalize
        and module_instance_key is not None
        and isinstance(key, tuple)
        and len(key) == 3
    )
    if migratable:
        entries.remove(legacy_key)
        entries.add(key)
    return True
+) -> bool: + """ + Check if key exists in entries, with optional legacy key migration. + + When normalize=True and a legacy key is found, it is removed from entries + and the scoped key is added instead (migrating to the new format). + + Returns True if either key format is found. + """ + if key in entries: + return True + if legacy_key in entries: + if ( + normalize + and module_instance_key is not None + and isinstance(key, tuple) + and len(key) == 3 + ): + entries.remove(legacy_key) + entries.add(key) + return True + return False + + +def wrap_callable_with_fusion_validation( + fn: Callable[..., Any], + *, + pattern_name: str, + fusion: FusionConfig, + state: FusionRuntimeState, + module_instance_key: str | None = None, +) -> Callable[..., Any]: + """Wrap callable with opt-in fusion and per-signature validation/blacklist.""" + + fused_candidate = jit(fn) + + @wraps(fn) + def _wrapped(*args: Any, **kwargs: Any) -> Any: + signature = _signature_for(args, kwargs) + if signature is None: + return fn(*args, **kwargs) + + key = _decision_key_for(pattern_name, signature, module_instance_key) + legacy_key = _legacy_decision_key_for(pattern_name, signature) + if _has_key_with_migration( + state.blacklist, key, legacy_key, module_instance_key, fusion.normalize_legacy_keys + ): + state.blacklist_hits += 1 + return fn(*args, **kwargs) + + if fusion.validate and not _has_key_with_migration( + state.validated, key, legacy_key, module_instance_key, fusion.normalize_legacy_keys + ): + stats_before = _fusion_stats(fused_candidate) + try: + expected = fn(*args, **kwargs) + actual = fused_candidate(*args, **kwargs) + except Exception: + state.compile_failures += 1 + state.blacklist.add(key) + return fn(*args, **kwargs) + stats_after = _fusion_stats(fused_candidate) + if ( + stats_before is not None + and stats_after is not None + and int(stats_after.get("fallbacks", 0)) > int(stats_before.get("fallbacks", 0)) + and not bool(stats_after.get("fusable", False)) + ): + 
state.compile_failures += 1 + state.blacklist.add(key) + return expected + + if _outputs_close(expected, actual, rtol=fusion.rtol, atol=fusion.atol): + state.validated.add(key) + return actual + + state.validation_mismatches += 1 + state.blacklist.add(key) + return expected + + stats_before = _fusion_stats(fused_candidate) + try: + out = fused_candidate(*args, **kwargs) + except Exception: + state.runtime_failures += 1 + state.blacklist.add(key) + return fn(*args, **kwargs) + stats_after = _fusion_stats(fused_candidate) + if ( + stats_before is not None + and stats_after is not None + and int(stats_after.get("fallbacks", 0)) > int(stats_before.get("fallbacks", 0)) + ): + state.runtime_failures += 1 + state.blacklist.add(key) + return fn(*args, **kwargs) + return out + + _wrapped._zmlx_fusion_wrapped = True # type: ignore[attr-defined] + if hasattr(fused_candidate, "fusion_stats"): + _wrapped.fusion_stats = fused_candidate.fusion_stats # type: ignore[attr-defined] + return _wrapped + + +__all__ = ["wrap_callable_with_fusion_validation"] diff --git a/src/zmlx/fusion/jit.py b/src/zmlx/fusion/jit.py new file mode 100644 index 0000000..38150e0 --- /dev/null +++ b/src/zmlx/fusion/jit.py @@ -0,0 +1,611 @@ +from __future__ import annotations + +import inspect +from collections.abc import Callable +from contextlib import contextmanager +from dataclasses import dataclass, field +from functools import wraps +from typing import Any, Literal + +from .._compat import import_mx +from .graph import FusedOp, Graph, TensorMeta +from .passes import run_fusion +from .runtime import ( + CompiledFusedElementwise, + CompiledFusedReductionTwoPass, + CompiledQuantizedSharedInputFusion, + compile_broadcast_mul_reduce_two_pass, + compile_fused_elementwise, + compile_quantized_shared_input_swiglu, + launch_broadcast_mul_reduce_two_pass, + launch_quantized_shared_input_swiglu, +) +from .trace import ( + ProxyTensor, + exp, + log, + maximum, + minimum, + quantized_matmul, + reduce_max, + 
reduce_mean, + reduce_sum, + relu, + rsqrt, + sigmoid, + silu, + sqrt, + symbolic_trace, + tanh, +) + +Scalar = bool | int | float +SignatureItem = tuple[str, Any, Any] + + +class _UnsupportedFusion(ValueError): + pass + + +@dataclass(frozen=True, slots=True) +class _InputBinding: + name: str + kind: Literal["arg", "const"] + arg_index: int | None = None + const_value: Scalar | None = None + const_dtype: str | None = None + + +@dataclass(frozen=True, slots=True) +class _FusionPlan: + kind: str + descriptor: dict[str, Any] + bindings: tuple[_InputBinding, ...] + output_meta: TensorMeta + ops_repr: str + + +@dataclass(frozen=True, slots=True) +class _CompiledEntry: + plan: _FusionPlan + kernel: ( + CompiledFusedElementwise + | CompiledFusedReductionTwoPass + | CompiledQuantizedSharedInputFusion + ) + + +@dataclass +class _FusionState: + traced: bool = False + fusable: bool = False + unsupported_reason: str | None = None + expression: str | None = None + traces: int = 0 + compiles: int = 0 + cache_hits: int = 0 + fallbacks: int = 0 + cache: dict[tuple[SignatureItem, ...], _CompiledEntry] = field(default_factory=dict) + failed_signatures: set[tuple[SignatureItem, ...]] = field(default_factory=set) + + +def _is_tensor_like(value: Any) -> bool: + return hasattr(value, "shape") and hasattr(value, "dtype") + + +def _is_scalar(value: Any) -> bool: + return isinstance(value, (bool, int, float)) + + +def _arg_signature(args: tuple[Any, ...]) -> tuple[SignatureItem, ...]: + signature: list[SignatureItem] = [] + tensor_count = 0 + for arg in args: + if _is_tensor_like(arg): + tensor_count += 1 + signature.append(("tensor", tuple(int(dim) for dim in arg.shape), str(arg.dtype))) + continue + if _is_scalar(arg): + signature.append(("scalar", type(arg).__name__, repr(arg))) + continue + raise _UnsupportedFusion( + f"unsupported non-tensor argument type: {type(arg).__name__}" + ) + if tensor_count == 0: + raise _UnsupportedFusion("at least one tensor argument is required") + return 
tuple(signature) + + +def _first_tensor_arg(args: tuple[Any, ...]) -> Any: + for arg in args: + if _is_tensor_like(arg): + return arg + raise _UnsupportedFusion("no tensor argument found") + + +def _prod(shape: tuple[int, ...]) -> int: + n = 1 + for dim in shape: + n *= int(dim) + return int(n) + + +def _dtype_from_name(mx: Any, name: str | None, fallback: Any) -> Any: + if not name: + return fallback + return getattr(mx, name, fallback) + + +def _arg_names(fn: Callable[..., Any]) -> tuple[str, ...]: + sig = inspect.signature(fn) + names: list[str] = [] + for param in sig.parameters.values(): + if param.kind not in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + raise _UnsupportedFusion("only positional parameters are supported") + names.append(param.name) + return tuple(names) + + +@contextmanager +def _patched_mx_for_tracing() -> Any: + mx = import_mx() + proxy_ops: dict[str, Callable[..., Any]] = { + "sigmoid": sigmoid, + "silu": silu, + "tanh": tanh, + "exp": exp, + "log": log, + "sqrt": sqrt, + "rsqrt": rsqrt, + "relu": relu, + "maximum": maximum, + "minimum": minimum, + } + originals: dict[str, Any] = {} + + def _contains_proxy(values: tuple[Any, ...], kwargs: dict[str, Any]) -> bool: + if any(isinstance(v, ProxyTensor) for v in values): + return True + return any(isinstance(v, ProxyTensor) for v in kwargs.values()) + + def _wrap(original: Any, proxy_fn: Callable[..., Any]) -> Callable[..., Any]: + @wraps(original) + def _inner(*args: Any, **kwargs: Any) -> Any: + if _contains_proxy(args, kwargs): + if kwargs: + raise _UnsupportedFusion("keyword arguments in traced mx ops are unsupported") + return proxy_fn(*args) + return original(*args, **kwargs) + + return _inner + + def _wrap_reduce(original: Any, reduce_name: str) -> Callable[..., Any]: + @wraps(original) + def _inner(*args: Any, **kwargs: Any) -> Any: + if not _contains_proxy(args, kwargs): + return original(*args, **kwargs) + if len(args) != 1: + raise 
_UnsupportedFusion( + f"traced mx.{reduce_name} expects exactly one positional tensor argument" + ) + axis = kwargs.get("axis", -1) + keepdims = bool(kwargs.get("keepdims", False)) + x = args[0] + if reduce_name == "sum": + return reduce_sum(x, axis=axis, keepdims=keepdims) + if reduce_name == "mean": + return reduce_mean(x, axis=axis, keepdims=keepdims) + return reduce_max(x, axis=axis, keepdims=keepdims) + + return _inner + + def _wrap_quantized_matmul(original: Any) -> Callable[..., Any]: + @wraps(original) + def _inner(*args: Any, **kwargs: Any) -> Any: + if not _contains_proxy(args, kwargs): + return original(*args, **kwargs) + if len(args) < 2: + raise _UnsupportedFusion( + "traced mx.quantized_matmul requires at least lhs and rhs arguments" + ) + lhs = args[0] + rhs = args[1] + scales = kwargs.get("scales") + biases = kwargs.get("biases") + group_size = kwargs.get("group_size") + bits = kwargs.get("bits") + mode = kwargs.get("mode") + transpose = bool(kwargs.get("transpose", False)) + return quantized_matmul( + lhs, + rhs, + scales=scales, + biases=biases, + group_size=group_size, + bits=bits, + mode=mode, + transpose=transpose, + ) + + return _inner + + try: + for name, proxy_fn in proxy_ops.items(): + if not hasattr(mx, name): + continue + original = getattr(mx, name) + originals[name] = original + setattr(mx, name, _wrap(original, proxy_fn)) + for reduce_name in ("sum", "mean", "max"): + if not hasattr(mx, reduce_name): + continue + original = getattr(mx, reduce_name) + originals[reduce_name] = original + setattr(mx, reduce_name, _wrap_reduce(original, reduce_name)) + if hasattr(mx, "quantized_matmul"): + original = mx.quantized_matmul + originals["quantized_matmul"] = original + mx.quantized_matmul = _wrap_quantized_matmul(original) + yield + finally: + for name, original in originals.items(): + setattr(mx, name, original) + + +def _select_fused_op(graph: Graph, output_id: int) -> FusedOp: + result = run_fusion(graph) + for fused in result.fused_ops: + 
if fused.output_id == output_id: + return fused + raise _UnsupportedFusion("no fusion pattern ended at function output") + + +def _build_plan( + fn: Callable[..., Any], + args: tuple[Any, ...], + arg_names: tuple[str, ...], +) -> _FusionPlan: + tensor_positions: list[int] = [] + tensor_names: list[str] = [] + tensor_metas: list[TensorMeta] = [] + + for idx, arg in enumerate(args): + if not _is_tensor_like(arg): + continue + tensor_positions.append(idx) + tensor_names.append(arg_names[idx]) + tensor_metas.append( + TensorMeta( + shape=tuple(int(dim) for dim in arg.shape), + dtype=str(arg.dtype), + ) + ) + + if not tensor_positions: + raise _UnsupportedFusion("fused path requires at least one tensor input") + + def _invoke_with_proxy_tensors(*proxy_tensors: ProxyTensor) -> Any: + call_args = list(args) + for pos, proxy in zip(tensor_positions, proxy_tensors, strict=True): + call_args[pos] = proxy + with _patched_mx_for_tracing(): + return fn(*call_args) + + try: + trace = symbolic_trace( + _invoke_with_proxy_tensors, + *tensor_metas, + input_names=tensor_names, + ) + except Exception as exc: + raise _UnsupportedFusion(f"symbolic tracing failed: {exc}") from exc + + if len(trace.output_ids) != 1: + raise _UnsupportedFusion("phase-1 fusion currently supports a single output") + + fused = _select_fused_op(trace.graph, trace.output_ids[0]) + descriptor, bindings = _build_descriptor(trace.graph, fused, arg_names) + return _FusionPlan( + kind=fused.kind, + descriptor=descriptor, + bindings=bindings, + output_meta=fused.output_meta, + ops_repr=f"{fused.kind}: " + " -> ".join(fused.ops), + ) + + +def _build_descriptor( + graph: Graph, + fused: FusedOp, + arg_names: tuple[str, ...], +) -> tuple[dict[str, Any], tuple[_InputBinding, ...]]: + tensor_arg_by_name = {name: idx for idx, name in enumerate(arg_names)} + external_name_by_node_id: dict[int, str] = {} + bindings: list[_InputBinding] = [] + + for node_id in fused.input_ids: + node = graph.node(node_id) + if node.op == 
"input": + name_raw = node.attrs.get("name") + name = str(name_raw) if name_raw is not None else f"arg_{node_id}" + if name not in tensor_arg_by_name: + raise _UnsupportedFusion(f"trace input {name!r} was not found in function args") + binding = _InputBinding(name=name, kind="arg", arg_index=tensor_arg_by_name[name]) + external_name_by_node_id[node_id] = name + bindings.append(binding) + continue + + if node.op == "constant": + value = node.attrs.get("value") + if not _is_scalar(value): + raise _UnsupportedFusion("only scalar constants are supported") + name = f"const_{node_id}" + dtype = node.meta.dtype if node.meta is not None else "float32" + binding = _InputBinding( + name=name, + kind="const", + const_value=value, + const_dtype=dtype, + ) + external_name_by_node_id[node_id] = name + bindings.append(binding) + continue + + raise _UnsupportedFusion(f"unsupported fused external input op: {node.op}") + + if fused.kind == "elementwise_chain": + fused_set = set(fused.node_ids) + nodes: list[dict[str, Any]] = [] + for node_id in fused.node_ids: + node = graph.node(node_id) + args: list[str] = [] + for input_id in node.inputs: + if input_id in fused_set: + args.append(f"n{input_id}") + continue + ext = external_name_by_node_id.get(input_id) + if ext is None: + raise _UnsupportedFusion( + f"fused node {node_id} depends on unsupported external node {input_id}" + ) + args.append(ext) + + nodes.append({ + "id": f"n{node_id}", + "op": node.op, + "args": args, + }) + + descriptor = { + "name": "zmlx_phase1_elementwise", + "kind": "elementwise_chain", + "inputs": [binding.name for binding in bindings], + "nodes": nodes, + "output": f"n{fused.output_id}", + } + return descriptor, tuple(bindings) + + if fused.kind == "broadcast_mul_reduce": + if len(bindings) != 2: + raise _UnsupportedFusion( + f"broadcast_mul_reduce expects exactly 2 external inputs, got {len(bindings)}" + ) + descriptor = { + "name": "zmlx_phase2_broadcast_mul_reduce", + "kind": fused.kind, + "inputs": 
[binding.name for binding in bindings], + "reduce_op": str(fused.attrs.get("reduce_op", "sum")), + "axis": int(fused.attrs.get("axis", -1)), + "keepdims": bool(fused.attrs.get("keepdims", False)), + "no_fma": bool(fused.attrs.get("no_fma", False)), + } + return descriptor, tuple(bindings) + + if fused.kind == "quantized_shared_input_swiglu": + if len(bindings) != 7: + raise _UnsupportedFusion( + "quantized_shared_input_swiglu expects 7 external inputs " + f"(got {len(bindings)})" + ) + descriptor = { + "name": "zmlx_phase3_quantized_shared_input_swiglu", + "kind": fused.kind, + "inputs": [binding.name for binding in bindings], + "bits": fused.attrs.get("bits"), + "group_size": fused.attrs.get("group_size"), + "mode": fused.attrs.get("mode"), + "transpose": bool(fused.attrs.get("transpose", False)), + "no_fma": bool(fused.attrs.get("no_fma", False)), + } + return descriptor, tuple(bindings) + + raise _UnsupportedFusion(f"Unsupported fused kind {fused.kind!r}.") + + +def _materialize_runtime_inputs( + args: tuple[Any, ...], + bindings: tuple[_InputBinding, ...], + materialize_shape: tuple[int, ...], + materialize_dtype: Any, +) -> list[Any]: + mx = import_mx() + runtime_inputs: list[Any] = [] + for binding in bindings: + if binding.kind == "arg": + assert binding.arg_index is not None + runtime_inputs.append(args[binding.arg_index]) + continue + + assert binding.const_value is not None + dtype = _dtype_from_name(mx, binding.const_dtype, materialize_dtype) + runtime_inputs.append(mx.full(materialize_shape, binding.const_value, dtype=dtype)) + + return runtime_inputs + + +def _launch_with_compiled( + compiled: CompiledFusedElementwise, + inputs: list[Any], + out_shape: tuple[int, ...], + out_dtype: Any, +) -> Any: + mx = import_mx() + n = _prod(out_shape) + tgx = 1 if n <= 0 else min(256, n) + return compiled.kernel( + *inputs, + template=[("T", mx.float32), ("O", out_dtype)], + grid=(n, 1, 1), + threadgroup=(tgx, 1, 1), + output_shapes=[out_shape], + 
output_dtypes=[out_dtype], + )[0] + + +def _compile_for_plan( + plan: _FusionPlan, + runtime_inputs: list[Any], +) -> ( + CompiledFusedElementwise + | CompiledFusedReductionTwoPass + | CompiledQuantizedSharedInputFusion +): + if plan.kind == "elementwise_chain": + return compile_fused_elementwise(plan.descriptor, runtime_inputs) + if plan.kind == "broadcast_mul_reduce": + return compile_broadcast_mul_reduce_two_pass(plan.descriptor, runtime_inputs) + if plan.kind == "quantized_shared_input_swiglu": + return compile_quantized_shared_input_swiglu(plan.descriptor, runtime_inputs) + raise _UnsupportedFusion(f"Unsupported fusion kind {plan.kind!r}") + + +def _launch_plan_entry( + entry: _CompiledEntry, + runtime_inputs: list[Any], + output_meta: TensorMeta, + output_dtype: Any, +) -> Any: + if isinstance(entry.kernel, CompiledFusedElementwise): + return _launch_with_compiled( + entry.kernel, + runtime_inputs, + output_meta.shape, + output_dtype, + ) + if isinstance(entry.kernel, CompiledFusedReductionTwoPass): + return launch_broadcast_mul_reduce_two_pass( + entry.plan.descriptor, + runtime_inputs, + output_dtype=output_dtype, + ) + if isinstance(entry.kernel, CompiledQuantizedSharedInputFusion): + return launch_quantized_shared_input_swiglu( + entry.plan.descriptor, + runtime_inputs, + output_dtype=output_dtype, + ) + raise _UnsupportedFusion(f"Unsupported compiled kernel type {type(entry.kernel).__name__}") + + +def jit(fn: Callable[..., Any]) -> Callable[..., Any]: + """Phase-1/2/3 JIT fusion via symbolic tracing + Metal codegen/runtime.""" + + state = _FusionState() + try: + arg_names = _arg_names(fn) + except _UnsupportedFusion as exc: + arg_names = () + state.unsupported_reason = str(exc) + + def fusion_stats() -> dict[str, Any]: + return { + "traced": state.traced, + "fusable": state.fusable, + "unsupported_reason": state.unsupported_reason, + "expression": state.expression, + "traces": state.traces, + "compiles": state.compiles, + "cache_hits": state.cache_hits, 
+ "fallbacks": state.fallbacks, + "cache_entries": len(state.cache), + "failed_signatures": len(state.failed_signatures), + } + + @wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + if kwargs or not arg_names: + state.fallbacks += 1 + return fn(*args, **kwargs) + + try: + signature = _arg_signature(args) + except _UnsupportedFusion as exc: + state.traced = True + state.traces += 1 + state.fusable = False + state.unsupported_reason = str(exc) + state.fallbacks += 1 + return fn(*args, **kwargs) + + if signature in state.failed_signatures: + state.fallbacks += 1 + return fn(*args, **kwargs) + + entry = state.cache.get(signature) + if entry is None: + state.traced = True + state.traces += 1 + try: + plan = _build_plan(fn, args, arg_names) + first_tensor = _first_tensor_arg(args) + materialize_shape = tuple(int(dim) for dim in first_tensor.shape) + runtime_inputs = _materialize_runtime_inputs( + args, + plan.bindings, + materialize_shape, + first_tensor.dtype, + ) + compiled = _compile_for_plan(plan, runtime_inputs) + except Exception as exc: + state.fusable = False + state.unsupported_reason = str(exc) + state.failed_signatures.add(signature) + state.fallbacks += 1 + return fn(*args, **kwargs) + + entry = _CompiledEntry(plan=plan, kernel=compiled) + state.cache[signature] = entry + state.expression = plan.ops_repr + state.fusable = True + state.unsupported_reason = None + state.compiles += 1 + else: + state.cache_hits += 1 + + try: + first_tensor = _first_tensor_arg(args) + materialize_shape = tuple(int(dim) for dim in first_tensor.shape) + runtime_inputs = _materialize_runtime_inputs( + args, + entry.plan.bindings, + materialize_shape, + first_tensor.dtype, + ) + mx = import_mx() + output_dtype = _dtype_from_name(mx, entry.plan.output_meta.dtype, first_tensor.dtype) + return _launch_plan_entry( + entry, + runtime_inputs, + entry.plan.output_meta, + output_dtype, + ) + except Exception: + state.failed_signatures.add(signature) + state.fallbacks += 1 + return 
fn(*args, **kwargs) + + wrapper.fusion_stats = fusion_stats # type: ignore[attr-defined] + return wrapper diff --git a/src/zmlx/fusion/passes.py b/src/zmlx/fusion/passes.py new file mode 100644 index 0000000..1b9a33b --- /dev/null +++ b/src/zmlx/fusion/passes.py @@ -0,0 +1,77 @@ +"""Fusion pass orchestration for phase-1/2/3 JIT fusion.""" + +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass + +from .graph import FusedOp, Graph +from .rules import ( + BroadcastMulReduceRule, + ElementwiseChainRule, + FusionRule, + QuantizedSharedInputSwiGLURule, +) + + +@dataclass(frozen=True, slots=True) +class FusionPassResult: + """Result of running fusion rules over a graph.""" + + fused_ops: tuple[FusedOp, ...] + consumed_nodes: frozenset[int] + + +def run_phase1_fusion( + graph: Graph, + rules: Sequence[FusionRule] | None = None, +) -> FusionPassResult: + """Apply only phase-1 elementwise rules in priority order.""" + active_rules = _sort_rules(rules or (ElementwiseChainRule(),)) + return _run_fusion_with_rules(graph, active_rules) + + +def run_fusion( + graph: Graph, + rules: Sequence[FusionRule] | None = None, +) -> FusionPassResult: + """Apply phase-1/2/3 rules in priority order and return fused descriptors.""" + active_rules = _sort_rules( + rules + or ( + BroadcastMulReduceRule(), + QuantizedSharedInputSwiGLURule(), + ElementwiseChainRule(), + ) + ) + return _run_fusion_with_rules(graph, active_rules) + + +def _run_fusion_with_rules( + graph: Graph, + active_rules: Sequence[FusionRule], +) -> FusionPassResult: + consumed: set[int] = set() + fused_ops: list[FusedOp] = [] + + for node_id in graph.topological_order(): + if node_id in consumed: + continue + + for rule in active_rules: + match = rule.match(graph, start=node_id, consumed=consumed) + if match is None: + continue + fused_ops.append(rule.fuse(graph, match)) + consumed.update(match.node_ids) + break + + return FusionPassResult( + 
fused_ops=tuple(fused_ops), + consumed_nodes=frozenset(consumed), + ) + + +def _sort_rules(rules: Sequence[FusionRule]) -> list[FusionRule]: + """Sort rules by descending priority and then by rule name.""" + return sorted(rules, key=lambda rule: (-rule.priority, rule.name)) diff --git a/src/zmlx/fusion/rules.py b/src/zmlx/fusion/rules.py new file mode 100644 index 0000000..63fe562 --- /dev/null +++ b/src/zmlx/fusion/rules.py @@ -0,0 +1,434 @@ +"""Fusion rule protocol and concrete phase-1/2/3 matchers.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Set +from dataclasses import dataclass +from typing import Protocol, runtime_checkable + +from .graph import ELEMENTWISE_OPS, FusedOp, Graph, TensorMeta + + +@dataclass(frozen=True, slots=True) +class FusionMatch: + """A rule match over one or more graph nodes.""" + + node_ids: tuple[int, ...] + + +@runtime_checkable +class FusionRule(Protocol): + """Protocol for phase-1 fusion rules.""" + + name: str + priority: int + + def match( + self, + graph: Graph, + *, + start: int, + consumed: Set[int], + ) -> FusionMatch | None: + """Return a match rooted at ``start`` or ``None`` if not applicable.""" + + def fuse(self, graph: Graph, match: FusionMatch) -> FusedOp: + """Convert a match into a typed ``FusedOp`` descriptor.""" + + +class BaseFusionRule(ABC): + """Convenience base class for rules with ``priority``.""" + + name: str = "base_rule" + + def __init__(self, *, priority: int = 0) -> None: + self.priority = priority + + @abstractmethod + def match( + self, + graph: Graph, + *, + start: int, + consumed: Set[int], + ) -> FusionMatch | None: ... + + @abstractmethod + def fuse(self, graph: Graph, match: FusionMatch) -> FusedOp: ... 
+ + +class ElementwiseChainRule(BaseFusionRule): + """Fuse linear chains of elementwise ops into one descriptor.""" + + name = "elementwise_chain" + + def __init__(self, *, priority: int = 100, min_length: int = 2) -> None: + super().__init__(priority=priority) + self.min_length = min_length + + def match( + self, + graph: Graph, + *, + start: int, + consumed: Set[int], + ) -> FusionMatch | None: + if start in consumed: + return None + if not self._can_start_chain(graph, start=start, consumed=consumed): + return None + + chain: list[int] = [start] + current = start + + while True: + users = graph.users(current) + if len(users) != 1: + break + + nxt = users[0] + if nxt in consumed: + break + + nxt_node = graph.node(nxt) + if nxt_node.op not in ELEMENTWISE_OPS: + break + + elementwise_inputs = [ + input_id + for input_id in nxt_node.inputs + if input_id not in consumed and graph.node(input_id).op in ELEMENTWISE_OPS + ] + if elementwise_inputs != [current]: + break + + chain.append(nxt) + current = nxt + + if len(chain) < self.min_length: + return None + return FusionMatch(node_ids=tuple(chain)) + + def fuse(self, graph: Graph, match: FusionMatch) -> FusedOp: + chain_set = set(match.node_ids) + input_ids: list[int] = [] + seen_inputs: set[int] = set() + + for node_id in match.node_ids: + for input_id in graph.node(node_id).inputs: + if input_id in chain_set or input_id in seen_inputs: + continue + seen_inputs.add(input_id) + input_ids.append(input_id) + + output_id = match.node_ids[-1] + output_meta = graph.node(output_id).meta + if output_meta is None: + raise ValueError(f"Matched output node {output_id} has no metadata.") + + ops = tuple(graph.node(node_id).op for node_id in match.node_ids) + return FusedOp( + rule=self.name, + kind="elementwise_chain", + node_ids=match.node_ids, + input_ids=tuple(input_ids), + output_id=output_id, + output_meta=output_meta, + ops=ops, + ) + + def _can_start_chain( + self, + graph: Graph, + *, + start: int, + consumed: Set[int], + ) 
-> bool: + node = graph.node(start) + if node.op not in ELEMENTWISE_OPS: + return False + + preds = [ + input_id + for input_id in node.inputs + if input_id not in consumed and graph.node(input_id).op in ELEMENTWISE_OPS + ] + if len(preds) > 1: + return False + if len(preds) == 1 and len(graph.users(preds[0])) == 1: + return False + return True + + +def _is_last_axis(axis: int | None, *, rank: int) -> bool: + if axis is None: + return False + axis_norm = int(axis) + if axis_norm < 0: + axis_norm += rank + return axis_norm == rank - 1 + + +def _is_broadcast_mul( + lhs_meta: TensorMeta | None, + rhs_meta: TensorMeta | None, + out_meta: TensorMeta | None, +) -> bool: + if lhs_meta is None or rhs_meta is None or out_meta is None: + return False + if lhs_meta.shape == out_meta.shape and rhs_meta.shape == out_meta.shape: + return False + if lhs_meta.is_scalar or rhs_meta.is_scalar: + return True + return lhs_meta.shape != rhs_meta.shape + + +def _unique_external_inputs(graph: Graph, node_ids: tuple[int, ...]) -> tuple[int, ...]: + fused_set = set(node_ids) + out: list[int] = [] + seen: set[int] = set() + for node_id in node_ids: + node = graph.node(node_id) + for input_id in node.inputs: + if input_id in fused_set or input_id in seen: + continue + seen.add(input_id) + out.append(input_id) + return tuple(out) + + +class BroadcastMulReduceRule(BaseFusionRule): + """Fuse ``mul`` followed by a last-axis reduction (phase-2 path).""" + + name = "broadcast_mul_reduce" + + def __init__(self, *, priority: int = 250) -> None: + super().__init__(priority=priority) + + def match( + self, + graph: Graph, + *, + start: int, + consumed: Set[int], + ) -> FusionMatch | None: + if start in consumed: + return None + mul_node = graph.node(start) + if mul_node.op != "mul": + return None + if any(input_id in consumed for input_id in mul_node.inputs): + return None + + users = tuple(user for user in graph.users(start) if user not in consumed) + if len(users) != 1: + return None + + reduce_id = 
users[0] + reduce_node = graph.node(reduce_id) + if reduce_node.op not in {"sum", "mean", "max"}: + return None + if reduce_node.inputs != (start,): + return None + if reduce_node.meta is None or mul_node.meta is None: + return None + if not _is_last_axis(reduce_node.attrs.get("axis"), rank=len(mul_node.meta.shape)): + return None + + lhs_meta = graph.node(mul_node.inputs[0]).meta + rhs_meta = graph.node(mul_node.inputs[1]).meta + if not _is_broadcast_mul(lhs_meta, rhs_meta, mul_node.meta): + return None + return FusionMatch(node_ids=(start, reduce_id)) + + def fuse(self, graph: Graph, match: FusionMatch) -> FusedOp: + mul_id, reduce_id = match.node_ids + reduce_node = graph.node(reduce_id) + output_meta = reduce_node.meta + if output_meta is None: + raise ValueError(f"Matched reduction node {reduce_id} has no metadata.") + + attrs = { + "reduce_op": reduce_node.op, + "axis": int(reduce_node.attrs.get("axis", -1)), + "keepdims": bool(reduce_node.attrs.get("keepdims", False)), + "accumulate_dtype": "float32", + } + return FusedOp( + rule=self.name, + kind="broadcast_mul_reduce", + node_ids=match.node_ids, + input_ids=_unique_external_inputs(graph, match.node_ids), + output_id=reduce_id, + output_meta=output_meta, + ops=("mul", reduce_node.op), + attrs=attrs, + ) + + +def _decode_friendly_lhs(meta: TensorMeta | None) -> bool: + if meta is None: + return False + if len(meta.shape) == 1: + return True + if len(meta.shape) == 2 and int(meta.shape[0]) == 1: + return True + return False + + +def _is_affine_qmm(node_attrs: dict[str, object] | object) -> bool: + if not isinstance(node_attrs, dict): + return False + mode = str(node_attrs.get("mode", "")) + bits = node_attrs.get("bits") + transpose = bool(node_attrs.get("transpose", False)) + return mode == "affine" and transpose and bits in (4, 8) + + +class QuantizedSharedInputSwiGLURule(BaseFusionRule): + """Fuse shared-input QMM pair with ``silu(gate) * up`` epilogue (phase-3).""" + + name = 
"quantized_shared_input_swiglu" + + def __init__(self, *, priority: int = 220) -> None: + super().__init__(priority=priority) + + def match( + self, + graph: Graph, + *, + start: int, + consumed: Set[int], + ) -> FusionMatch | None: + if start in consumed: + return None + gate_qmm = graph.node(start) + if gate_qmm.op != "quantized_matmul": + return None + if not _is_affine_qmm(dict(gate_qmm.attrs)): + return None + if not gate_qmm.inputs: + return None + lhs_id = gate_qmm.inputs[0] + if not _decode_friendly_lhs(graph.node(lhs_id).meta): + return None + + epilogue = self._match_epilogue(graph, start=start, consumed=consumed) + if epilogue is None: + return None + epilogue_node_ids, up_qmm_id, output_mul_id = epilogue + + up_qmm = graph.node(up_qmm_id) + if up_qmm.op != "quantized_matmul": + return None + if not _is_affine_qmm(dict(up_qmm.attrs)): + return None + if not up_qmm.inputs or up_qmm.inputs[0] != lhs_id: + return None + if up_qmm.meta != gate_qmm.meta: + return None + + gate_bits = gate_qmm.attrs.get("bits") + up_bits = up_qmm.attrs.get("bits") + gate_group = gate_qmm.attrs.get("group_size") + up_group = up_qmm.attrs.get("group_size") + if gate_bits != up_bits or gate_group != up_group: + return None + + return FusionMatch(node_ids=epilogue_node_ids + (up_qmm_id, output_mul_id)) + + def _match_epilogue( + self, + graph: Graph, + *, + start: int, + consumed: Set[int], + ) -> tuple[tuple[int, ...], int, int] | None: + gate_users = tuple(user for user in graph.users(start) if user not in consumed) + # Canonical form: silu(gate_qmm) * up_qmm + if len(gate_users) == 1: + silu_id = gate_users[0] + silu_node = graph.node(silu_id) + if silu_node.op == "silu" and silu_node.inputs == (start,): + silu_users = tuple(user for user in graph.users(silu_id) if user not in consumed) + if len(silu_users) != 1: + return None + output_mul_id = silu_users[0] + output_mul = graph.node(output_mul_id) + if output_mul.op != "mul": + return None + other_inputs = [node_id for node_id 
in output_mul.inputs if node_id != silu_id] + if len(other_inputs) != 1: + return None + up_qmm_id = other_inputs[0] + if up_qmm_id in consumed: + return None + return (start, silu_id), up_qmm_id, output_mul_id + + # Expanded form: (gate_qmm * sigmoid(gate_qmm)) * up_qmm + sigmoid_ids = [node_id for node_id in gate_users if graph.node(node_id).op == "sigmoid"] + if len(sigmoid_ids) != 1: + return None + sigmoid_id = sigmoid_ids[0] + gate_mul_ids = [ + node_id + for node_id in gate_users + if graph.node(node_id).op == "mul" + and set(graph.node(node_id).inputs) == {start, sigmoid_id} + ] + if len(gate_mul_ids) != 1: + return None + gate_mul_id = gate_mul_ids[0] + gate_mul_users = tuple(user for user in graph.users(gate_mul_id) if user not in consumed) + if len(gate_mul_users) != 1: + return None + output_mul_id = gate_mul_users[0] + output_mul = graph.node(output_mul_id) + if output_mul.op != "mul": + return None + other_inputs = [node_id for node_id in output_mul.inputs if node_id != gate_mul_id] + if len(other_inputs) != 1: + return None + up_qmm_id = other_inputs[0] + if up_qmm_id in consumed: + return None + return (start, sigmoid_id, gate_mul_id), up_qmm_id, output_mul_id + + def fuse(self, graph: Graph, match: FusionMatch) -> FusedOp: + gate_qmm_id = match.node_ids[0] + mul_id = match.node_ids[-1] + up_candidates = [ + node_id + for node_id in match.node_ids[1:-1] + if graph.node(node_id).op == "quantized_matmul" + ] + if len(up_candidates) != 1: + raise ValueError( + "quantized_shared_input_swiglu expected exactly one up_qmm node in match." 
+ ) + up_qmm_id = up_candidates[0] + mul_node = graph.node(mul_id) + output_meta = mul_node.meta + if output_meta is None: + raise ValueError(f"Matched output node {mul_id} has no metadata.") + + gate_qmm = graph.node(gate_qmm_id) + attrs = { + "bits": gate_qmm.attrs.get("bits"), + "group_size": gate_qmm.attrs.get("group_size"), + "mode": gate_qmm.attrs.get("mode"), + "transpose": bool(gate_qmm.attrs.get("transpose", False)), + "gate_qmm_id": gate_qmm_id, + "up_qmm_id": up_qmm_id, + } + return FusedOp( + rule=self.name, + kind="quantized_shared_input_swiglu", + node_ids=match.node_ids, + input_ids=_unique_external_inputs(graph, match.node_ids), + output_id=mul_id, + output_meta=output_meta, + ops=tuple(graph.node(node_id).op for node_id in match.node_ids), + attrs=attrs, + ) diff --git a/src/zmlx/fusion/runtime.py b/src/zmlx/fusion/runtime.py new file mode 100644 index 0000000..d399253 --- /dev/null +++ b/src/zmlx/fusion/runtime.py @@ -0,0 +1,725 @@ +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any + +from .._compat import import_mx +from ..metal import kernel as metal_kernel +from ..msl import DEFAULT_HEADER +from .cache import GLOBAL_FUSION_COMPILE_CACHE, FusionCompileCache, FusionCompileKey +from .codegen import ( + FusedGraph, + GeneratedQuantizedSharedInputSource, + generate_broadcast_mul_reduce_two_pass_source, + generate_fused_elementwise_source, + generate_quantized_shared_input_source, + normalize_descriptor, +) + + +@dataclass(frozen=True) +class CompiledFusedElementwise: + """Compiled fused kernel object cached by graph + concrete input metadata.""" + + cache_key: FusionCompileKey + op_signature: str + kernel_name: str + input_names: tuple[str, ...] 
+ output_name: str + source: str + n_elements: int + kernel: Any + + +def _prod(shape: Sequence[int]) -> int: + n = 1 + for d in shape: + n *= int(d) + return int(n) + + +def _shape_of(x: Any) -> tuple[int, ...]: + return tuple(int(d) for d in x.shape) + + +def _dtype_of(x: Any) -> Any: + return x.dtype + + +def _as_input_list(inputs: Sequence[Any] | Any) -> list[Any]: + if isinstance(inputs, (list, tuple)): + return list(inputs) + return [inputs] + + +def _validate_inputs(inputs: list[Any], graph: FusedGraph) -> tuple[int, ...]: + if not inputs: + raise ValueError("fused runtime requires at least one input") + if len(inputs) != len(graph.input_names): + raise ValueError( + f"expected {len(graph.input_names)} inputs for graph, got {len(inputs)}" + ) + + shape0 = _shape_of(inputs[0]) + for i, x in enumerate(inputs[1:], start=1): + shape_i = _shape_of(x) + if shape_i != shape0: + raise ValueError( + "phase-1 fused elementwise launch requires identical shapes; " + f"input0={shape0} input{i}={shape_i}" + ) + return shape0 + + +def _cache_key_for( + *, + op_signature: str, + inputs: Sequence[Any], +) -> FusionCompileKey: + return FusionCompileKey.from_parts( + op_signature=op_signature, + input_shapes=[_shape_of(x) for x in inputs], + input_dtypes=[_dtype_of(x) for x in inputs], + ) + + +def compile_fused_elementwise( + descriptor: FusedGraph | dict[str, Any], + inputs: Sequence[Any], + *, + cache: FusionCompileCache | None = None, + header: str = DEFAULT_HEADER, + ensure_row_contiguous: bool = True, + kernel_name: str | None = None, +) -> CompiledFusedElementwise: + """Compile (or cache-hit) a fused elementwise Metal kernel for concrete inputs.""" + + graph = normalize_descriptor(descriptor) + inputs_list = _as_input_list(inputs) + shape0 = _validate_inputs(inputs_list, graph) + n_elements = _prod(shape0) + + generated = generate_fused_elementwise_source( + graph, + n_elements=n_elements, + kernel_name=kernel_name, + output_name="out", + ) + key = 
_cache_key_for(op_signature=generated.op_signature, inputs=inputs_list) + + cache_obj = cache or GLOBAL_FUSION_COMPILE_CACHE + cached = cache_obj.get(key) + if cached is not None: + return cached # type: ignore[no-any-return] + + compiled_kernel = metal_kernel( + name=generated.kernel_name, + input_names=list(generated.input_names), + output_names=[generated.output_name], + source=generated.source, + header=header or DEFAULT_HEADER, + ensure_row_contiguous=ensure_row_contiguous, + cache=True, + ) + + compiled = CompiledFusedElementwise( + cache_key=key, + op_signature=generated.op_signature, + kernel_name=generated.kernel_name, + input_names=generated.input_names, + output_name=generated.output_name, + source=generated.source, + n_elements=n_elements, + kernel=compiled_kernel, + ) + return cache_obj.put(key, compiled) # type: ignore[no-any-return] + + +def launch_fused_elementwise( + descriptor: FusedGraph | dict[str, Any], + *inputs: Any, + output_dtype: Any | None = None, + compute_dtype: Any | None = None, + threadgroup_x: int = 256, + cache: FusionCompileCache | None = None, + header: str = DEFAULT_HEADER, + ensure_row_contiguous: bool = True, + kernel_name: str | None = None, +) -> Any: + """Compile+launch a fused elementwise kernel with concrete MLX inputs.""" + + if len(inputs) == 1 and isinstance(inputs[0], (list, tuple)): + input_list = list(inputs[0]) + else: + input_list = list(inputs) + + graph = normalize_descriptor(descriptor) + shape0 = _validate_inputs(input_list, graph) + n_elements = _prod(shape0) + + compiled = compile_fused_elementwise( + graph, + input_list, + cache=cache, + header=header, + ensure_row_contiguous=ensure_row_contiguous, + kernel_name=kernel_name, + ) + + mx = import_mx() + effective_compute_dtype = compute_dtype if compute_dtype is not None else mx.float32 + effective_output_dtype = output_dtype if output_dtype is not None else input_list[0].dtype + + tgx = int(threadgroup_x) + if tgx <= 0: + raise ValueError("threadgroup_x must 
be > 0") + if n_elements > 0: + tgx = min(tgx, n_elements) + else: + tgx = 1 + + out = compiled.kernel( + *input_list, + template=[("T", effective_compute_dtype), ("O", effective_output_dtype)], + grid=(n_elements, 1, 1), + threadgroup=(tgx, 1, 1), + output_shapes=[shape0], + output_dtypes=[effective_output_dtype], + )[0] + return out + + +def run_fused_elementwise( + descriptor: FusedGraph | dict[str, Any], + inputs: Sequence[Any], + *, + output_dtype: Any | None = None, + compute_dtype: Any | None = None, + threadgroup_x: int = 256, + cache: FusionCompileCache | None = None, + header: str = DEFAULT_HEADER, + ensure_row_contiguous: bool = True, + kernel_name: str | None = None, +) -> Any: + """Sequence-based wrapper over launch_fused_elementwise().""" + + return launch_fused_elementwise( + descriptor, + list(inputs), + output_dtype=output_dtype, + compute_dtype=compute_dtype, + threadgroup_x=threadgroup_x, + cache=cache, + header=header, + ensure_row_contiguous=ensure_row_contiguous, + kernel_name=kernel_name, + ) + + +def compile_and_launch_fused_elementwise( + descriptor: FusedGraph | dict[str, Any], + inputs: Sequence[Any], + *, + output_dtype: Any | None = None, + compute_dtype: Any | None = None, + threadgroup_x: int = 256, + cache: FusionCompileCache | None = None, + header: str = DEFAULT_HEADER, + ensure_row_contiguous: bool = True, + kernel_name: str | None = None, +) -> Any: + """Compatibility alias for one-shot fused elementwise execution.""" + + return run_fused_elementwise( + descriptor, + inputs, + output_dtype=output_dtype, + compute_dtype=compute_dtype, + threadgroup_x=threadgroup_x, + cache=cache, + header=header, + ensure_row_contiguous=ensure_row_contiguous, + kernel_name=kernel_name, + ) + + +@dataclass(frozen=True) +class CompiledFusedReductionTwoPass: + """Compiled two-pass reduction fusion kernels.""" + + cache_key: FusionCompileKey + op_signature: str + reduce_op: str + axis: int + keepdims: bool + no_fma: bool + rhs_mode: str + 
source_pass1: str + source_pass2: str + pass1_kernel_name: str + pass2_kernel_name: str + rows: int + reduce_extent: int + partial_cols: int + output_shape: tuple[int, ...] + pass1_kernel: Any + pass2_kernel: Any + + +@dataclass(frozen=True) +class CompiledQuantizedSharedInputFusion: + """Compiled decode-friendly quantized shared-input fusion kernel.""" + + cache_key: FusionCompileKey + op_signature: str + kernel_name: str + source: str + bits: int + group_size: int + n: int + k: int + output_shape: tuple[int, ...] + kernel: Any + + +def _normalize_axis(shape: tuple[int, ...], axis: int) -> int: + rank = len(shape) + if rank == 0: + raise ValueError("Cannot reduce scalar input.") + axis_norm = int(axis) + if axis_norm < 0: + axis_norm += rank + if axis_norm < 0 or axis_norm >= rank: + raise ValueError(f"axis {axis} out of range for shape {shape}") + return axis_norm + + +def _output_shape_for_reduction(shape: tuple[int, ...], *, axis: int, keepdims: bool) -> tuple[int, ...]: + dims = list(shape) + if keepdims: + dims[axis] = 1 + else: + dims.pop(axis) + return tuple(dims) + + +def _rhs_mode_for(lhs_shape: tuple[int, ...], rhs_shape: tuple[int, ...]) -> str: + if rhs_shape == lhs_shape: + return "full" + if rhs_shape == (): + return "scalar" + if rhs_shape == lhs_shape[:-1]: + return "row" + if rhs_shape == lhs_shape[:-1] + (1,): + return "row" + if rhs_shape == (1,): + return "scalar" + raise ValueError( + "Unsupported broadcast shape for fused mul+reduce " + f"(lhs={lhs_shape}, rhs={rhs_shape})." 
+ ) + + +def compile_broadcast_mul_reduce_two_pass( + descriptor: dict[str, Any], + inputs: Sequence[Any], + *, + cache: FusionCompileCache | None = None, + header: str = DEFAULT_HEADER, + ensure_row_contiguous: bool = True, + threadgroup_x: int = 256, +) -> CompiledFusedReductionTwoPass: + """Compile a two-pass ``mul + reduce`` kernel pair.""" + + input_list = _as_input_list(inputs) + if len(input_list) != 2: + raise ValueError( + f"broadcast_mul_reduce expects 2 inputs (lhs, rhs), got {len(input_list)}" + ) + + lhs, rhs = input_list + lhs_shape = _shape_of(lhs) + rhs_shape = _shape_of(rhs) + reduce_op = str(descriptor.get("reduce_op", "sum")) + keepdims = bool(descriptor.get("keepdims", False)) + no_fma = bool(descriptor.get("no_fma", False)) + axis = _normalize_axis(lhs_shape, int(descriptor.get("axis", -1))) + if axis != len(lhs_shape) - 1: + raise ValueError( + f"Only last-axis reductions are supported (axis={axis}, shape={lhs_shape})." + ) + if reduce_op not in {"sum", "mean", "max"}: + raise ValueError(f"Unsupported reduce_op {reduce_op!r}.") + + rhs_mode = _rhs_mode_for(lhs_shape, rhs_shape) + rows = _prod(lhs_shape[:-1]) if lhs_shape[:-1] else 1 + reduce_extent = int(lhs_shape[-1]) + if reduce_extent <= 0: + raise ValueError("Reduction extent must be > 0 for fusion.") + + tgx = int(threadgroup_x) + if tgx <= 0: + raise ValueError("threadgroup_x must be > 0") + tgx = max(1, min(tgx, 512, reduce_extent)) + partial_cols = max(1, (reduce_extent + tgx - 1) // tgx) + + generated = generate_broadcast_mul_reduce_two_pass_source( + rows=rows, + reduce_extent=reduce_extent, + partial_cols=partial_cols, + reduce_op=reduce_op, + rhs_mode=rhs_mode, + no_fma=no_fma, + ) + key = _cache_key_for(op_signature=generated.op_signature, inputs=input_list) + + cache_obj = cache or GLOBAL_FUSION_COMPILE_CACHE + cached = cache_obj.get(key) + if cached is not None: + return cached # type: ignore[no-any-return] + + pass1_kernel = metal_kernel( + name=generated.pass1_kernel_name, + 
input_names=["lhs", "rhs"], + output_names=["partials"], + source=generated.source_pass1, + header=header or DEFAULT_HEADER, + ensure_row_contiguous=ensure_row_contiguous, + cache=True, + ) + pass2_kernel = metal_kernel( + name=generated.pass2_kernel_name, + input_names=["partials"], + output_names=["out"], + source=generated.source_pass2, + header=header or DEFAULT_HEADER, + ensure_row_contiguous=ensure_row_contiguous, + cache=True, + ) + + compiled = CompiledFusedReductionTwoPass( + cache_key=key, + op_signature=generated.op_signature, + reduce_op=reduce_op, + axis=axis, + keepdims=keepdims, + no_fma=no_fma, + rhs_mode=rhs_mode, + source_pass1=generated.source_pass1, + source_pass2=generated.source_pass2, + pass1_kernel_name=generated.pass1_kernel_name, + pass2_kernel_name=generated.pass2_kernel_name, + rows=rows, + reduce_extent=reduce_extent, + partial_cols=partial_cols, + output_shape=_output_shape_for_reduction(lhs_shape, axis=axis, keepdims=keepdims), + pass1_kernel=pass1_kernel, + pass2_kernel=pass2_kernel, + ) + return cache_obj.put(key, compiled) # type: ignore[no-any-return] + + +def launch_broadcast_mul_reduce_two_pass( + descriptor: dict[str, Any], + *inputs: Any, + output_dtype: Any | None = None, + cache: FusionCompileCache | None = None, + header: str = DEFAULT_HEADER, + ensure_row_contiguous: bool = True, + threadgroup_x: int = 256, +) -> Any: + """Compile and launch two-pass ``mul + reduce`` fusion.""" + + if len(inputs) == 1 and isinstance(inputs[0], (list, tuple)): + input_list = list(inputs[0]) + else: + input_list = list(inputs) + if len(input_list) != 2: + raise ValueError("broadcast_mul_reduce runtime expects exactly 2 inputs.") + + compiled = compile_broadcast_mul_reduce_two_pass( + descriptor, + input_list, + cache=cache, + header=header, + ensure_row_contiguous=ensure_row_contiguous, + threadgroup_x=threadgroup_x, + ) + + mx = import_mx() + lhs = input_list[0] + out_dtype = output_dtype if output_dtype is not None else lhs.dtype + tgx = 
int(threadgroup_x) + tgx = max(1, min(tgx, 512, compiled.reduce_extent)) + + partials = compiled.pass1_kernel( + lhs, + input_list[1], + grid=(tgx, compiled.rows * compiled.partial_cols, 1), + threadgroup=(tgx, 1, 1), + output_shapes=[(compiled.rows, compiled.partial_cols)], + output_dtypes=[mx.float32], + )[0] + out_flat = compiled.pass2_kernel( + partials, + template=[("O", out_dtype)], + grid=(tgx, compiled.rows, 1), + threadgroup=(tgx, 1, 1), + output_shapes=[(compiled.rows,)], + output_dtypes=[out_dtype], + )[0] + return out_flat.reshape(compiled.output_shape) + + +def _validate_quantized_shared_input_inputs( + descriptor: dict[str, Any], + inputs: list[Any], +) -> tuple[int, int, int, int, Any, Any, Any, Any, Any, Any, Any]: + if len(inputs) != 7: + raise ValueError( + "quantized_shared_input_swiglu expects 7 inputs " + "(x, gate_w, gate_scales, gate_biases, up_w, up_scales, up_biases)." + ) + x, gate_w, gate_scales, gate_biases, up_w, up_scales, up_biases = inputs + + mode = str(descriptor.get("mode", "")) + transpose = bool(descriptor.get("transpose", False)) + bits = int(descriptor.get("bits", 0)) + group_size = int(descriptor.get("group_size", 0)) + if mode != "affine" or not transpose: + raise ValueError("Only affine + transpose=True quantized fusion is supported.") + if bits not in (4, 8): + raise ValueError("Only 4-bit/8-bit quantized fusion is supported.") + if group_size <= 0: + raise ValueError("group_size must be > 0.") + + x_shape = _shape_of(x) + if len(x_shape) == 1: + m = 1 + k = int(x_shape[0]) + x_vec = x + elif len(x_shape) == 2 and int(x_shape[0]) == 1: + m = 1 + k = int(x_shape[1]) + x_vec = x.reshape((k,)) + else: + raise ValueError( + f"Decode-only M=1 shape is required for quantized fusion, got x.shape={x_shape}." 
+ ) + + gate_w_shape = _shape_of(gate_w) + up_w_shape = _shape_of(up_w) + if gate_w_shape != up_w_shape or len(gate_w_shape) != 2: + raise ValueError("gate_w/up_w must be matching rank-2 packed tensors.") + n = int(gate_w_shape[0]) + k_packed = int(gate_w_shape[1]) + elements_per_word = 32 // bits + if k_packed * elements_per_word != k: + raise ValueError("Packed weight shape does not match x K dimension.") + + if _shape_of(gate_scales) != _shape_of(up_scales): + raise ValueError("gate_scales and up_scales must match.") + if _shape_of(gate_biases) != _shape_of(up_biases): + raise ValueError("gate_biases and up_biases must match.") + expected_groups = k // group_size + if _shape_of(gate_scales) != (n, expected_groups): + raise ValueError( + "Unsupported scales shape for quantized fusion " + f"(got {_shape_of(gate_scales)}, expected {(n, expected_groups)})." + ) + if _shape_of(gate_biases) != (n, expected_groups): + raise ValueError( + "Unsupported biases shape for quantized fusion " + f"(got {_shape_of(gate_biases)}, expected {(n, expected_groups)})." 
+ ) + if m != 1: + raise ValueError("Only M=1 is supported in this phase.") + + return ( + n, + k, + bits, + group_size, + x_vec, + gate_w, + gate_scales, + gate_biases, + up_w, + up_scales, + up_biases, + ) + + +def compile_quantized_shared_input_swiglu( + descriptor: dict[str, Any], + inputs: Sequence[Any], + *, + cache: FusionCompileCache | None = None, + header: str = DEFAULT_HEADER, + ensure_row_contiguous: bool = True, +) -> CompiledQuantizedSharedInputFusion: + """Compile decode-oriented quantized shared-input fusion.""" + + input_list = _as_input_list(inputs) + ( + n, + k, + bits, + group_size, + _x_vec, + _gate_w, + _gate_scales, + _gate_biases, + _up_w, + _up_scales, + _up_biases, + ) = _validate_quantized_shared_input_inputs(descriptor, input_list) + no_fma = bool(descriptor.get("no_fma", False)) + + generated: GeneratedQuantizedSharedInputSource = generate_quantized_shared_input_source( + n=n, + k=k, + bits=bits, + group_size=group_size, + no_fma=no_fma, + ) + key = _cache_key_for(op_signature=generated.op_signature, inputs=input_list) + cache_obj = cache or GLOBAL_FUSION_COMPILE_CACHE + cached = cache_obj.get(key) + if cached is not None: + return cached # type: ignore[no-any-return] + + kernel = metal_kernel( + name=generated.kernel_name, + input_names=[ + "x", + "gate_w", + "gate_scales", + "gate_biases", + "up_w", + "up_scales", + "up_biases", + ], + output_names=["out"], + source=generated.source, + header=header or DEFAULT_HEADER, + ensure_row_contiguous=ensure_row_contiguous, + cache=True, + ) + + compiled = CompiledQuantizedSharedInputFusion( + cache_key=key, + op_signature=generated.op_signature, + kernel_name=generated.kernel_name, + source=generated.source, + bits=bits, + group_size=group_size, + n=n, + k=k, + output_shape=(1, n), + kernel=kernel, + ) + return cache_obj.put(key, compiled) # type: ignore[no-any-return] + + +def launch_quantized_shared_input_swiglu( + descriptor: dict[str, Any], + *inputs: Any, + output_dtype: Any | None = 
None, + cache: FusionCompileCache | None = None, + header: str = DEFAULT_HEADER, + ensure_row_contiguous: bool = True, +) -> Any: + """Compile and launch decode-oriented quantized shared-input fusion.""" + + if len(inputs) == 1 and isinstance(inputs[0], (list, tuple)): + input_list = list(inputs[0]) + else: + input_list = list(inputs) + ( + _n, + _k, + _bits, + _group_size, + x_vec, + gate_w, + gate_scales, + gate_biases, + up_w, + up_scales, + up_biases, + ) = _validate_quantized_shared_input_inputs(descriptor, input_list) + + compiled = compile_quantized_shared_input_swiglu( + descriptor, + input_list, + cache=cache, + header=header, + ensure_row_contiguous=ensure_row_contiguous, + ) + + mx = import_mx() + out_dtype = output_dtype if output_dtype is not None else mx.float32 + out = compiled.kernel( + x_vec, + gate_w, + gate_scales, + gate_biases, + up_w, + up_scales, + up_biases, + template=[("O", out_dtype)], + grid=(compiled.n, 1, 1), + threadgroup=(min(compiled.n, 256), 1, 1), + output_shapes=[(compiled.n,)], + output_dtypes=[out_dtype], + )[0] + return out.reshape(compiled.output_shape) + + +def export_compiled_aot_payload( + *, + descriptor: dict[str, Any], + compiled: CompiledFusedElementwise + | CompiledFusedReductionTwoPass + | CompiledQuantizedSharedInputFusion, +) -> dict[str, Any]: + """Return a minimal AOT scaffold payload (descriptor + source).""" + + if isinstance(compiled, CompiledFusedElementwise): + sources = {"kernel": compiled.source} + kernels = {"kernel_name": compiled.kernel_name} + elif isinstance(compiled, CompiledFusedReductionTwoPass): + sources = { + "pass1": compiled.source_pass1, + "pass2": compiled.source_pass2, + } + kernels = { + "pass1_kernel_name": compiled.pass1_kernel_name, + "pass2_kernel_name": compiled.pass2_kernel_name, + } + else: + sources = {"kernel": compiled.source} + kernels = {"kernel_name": compiled.kernel_name} + return { + "descriptor": descriptor, + "kernels": kernels, + "sources": sources, + } + + +__all__ = [ + 
"CompiledFusedElementwise", + "CompiledFusedReductionTwoPass", + "CompiledQuantizedSharedInputFusion", + "compile_fused_elementwise", + "launch_fused_elementwise", + "run_fused_elementwise", + "compile_and_launch_fused_elementwise", + "compile_broadcast_mul_reduce_two_pass", + "launch_broadcast_mul_reduce_two_pass", + "compile_quantized_shared_input_swiglu", + "launch_quantized_shared_input_swiglu", + "export_compiled_aot_payload", +] diff --git a/src/zmlx/fusion/trace.py b/src/zmlx/fusion/trace.py new file mode 100644 index 0000000..5fb1051 --- /dev/null +++ b/src/zmlx/fusion/trace.py @@ -0,0 +1,542 @@ +"""Symbolic tracing with proxy tensors for phase-1/2/3 JIT fusion.""" + +from __future__ import annotations + +from collections.abc import Callable, Sequence +from dataclasses import dataclass +from typing import Any, Protocol, runtime_checkable + +from .graph import Graph, TensorMeta + +Scalar = bool | int | float + + +@runtime_checkable +class TensorLike(Protocol): + """Minimal tensor protocol required by the symbolic tracer.""" + + shape: Any + dtype: Any + + +@dataclass(frozen=True, slots=True) +class TraceResult: + """Output of symbolic tracing.""" + + graph: Graph + input_ids: tuple[int, ...] + output_ids: tuple[int, ...] 
+ + +@dataclass(frozen=True, slots=True) +class ProxyTensor: + """Tensor proxy that records operations to a symbolic graph.""" + + _tracer: SymbolicTracer + _node_id: int + _meta: TensorMeta + + @property + def node_id(self) -> int: + return self._node_id + + @property + def meta(self) -> TensorMeta: + return self._meta + + @property + def shape(self) -> tuple[int, ...]: + return self._meta.shape + + @property + def dtype(self) -> str: + return self._meta.dtype + + def __add__(self, other: ProxyTensor | Scalar) -> ProxyTensor: + return self._tracer._binary("add", self, other) + + def __radd__(self, other: Scalar) -> ProxyTensor: + return self._tracer._binary("add", other, self) + + def __sub__(self, other: ProxyTensor | Scalar) -> ProxyTensor: + return self._tracer._binary("sub", self, other) + + def __rsub__(self, other: Scalar) -> ProxyTensor: + return self._tracer._binary("sub", other, self) + + def __mul__(self, other: ProxyTensor | Scalar) -> ProxyTensor: + return self._tracer._binary("mul", self, other) + + def __rmul__(self, other: Scalar) -> ProxyTensor: + return self._tracer._binary("mul", other, self) + + def __truediv__(self, other: ProxyTensor | Scalar) -> ProxyTensor: + return self._tracer._binary("div", self, other) + + def __rtruediv__(self, other: Scalar) -> ProxyTensor: + return self._tracer._binary("div", other, self) + + def __neg__(self) -> ProxyTensor: + return self._tracer._unary("neg", self) + + def exp(self) -> ProxyTensor: + return self._tracer._unary("exp", self) + + def tanh(self) -> ProxyTensor: + return self._tracer._unary("tanh", self) + + def sigmoid(self) -> ProxyTensor: + return self._tracer._unary("sigmoid", self) + + def relu(self) -> ProxyTensor: + return self._tracer._unary("relu", self) + + def silu(self) -> ProxyTensor: + return self._tracer._unary("silu", self) + + def log(self) -> ProxyTensor: + return self._tracer._unary("log", self) + + def sqrt(self) -> ProxyTensor: + return self._tracer._unary("sqrt", self) + + def 
rsqrt(self) -> ProxyTensor: + return self._tracer._unary("rsqrt", self) + + def maximum(self, other: ProxyTensor | Scalar) -> ProxyTensor: + return self._tracer._binary("maximum", self, other) + + def minimum(self, other: ProxyTensor | Scalar) -> ProxyTensor: + return self._tracer._binary("minimum", self, other) + + def sum(self, *, axis: int = -1, keepdims: bool = False) -> ProxyTensor: + return self._tracer._reduce("sum", self, axis=axis, keepdims=keepdims) + + def mean(self, *, axis: int = -1, keepdims: bool = False) -> ProxyTensor: + return self._tracer._reduce("mean", self, axis=axis, keepdims=keepdims) + + def max(self, *, axis: int = -1, keepdims: bool = False) -> ProxyTensor: + return self._tracer._reduce("max", self, axis=axis, keepdims=keepdims) + + +class SymbolicTracer: + """Constructs a symbolic graph by executing a function on proxies.""" + + def __init__(self, graph: Graph | None = None) -> None: + self.graph = graph or Graph() + + def input(self, *, name: str, meta: TensorMeta) -> ProxyTensor: + """Create an input proxy tensor.""" + node_id = self.graph.add_input(name=name, meta=meta) + return ProxyTensor(self, node_id, meta) + + def constant(self, value: Scalar, *, dtype: str) -> ProxyTensor: + """Create a scalar-literal proxy node.""" + meta = TensorMeta(shape=(), dtype=dtype) + node_id = self.graph.add_constant(value=value, meta=meta) + return ProxyTensor(self, node_id, meta) + + def _unary(self, op: str, value: ProxyTensor) -> ProxyTensor: + proxy = self._ensure_proxy(value) + node_id = self.graph.add_op(op, inputs=(proxy.node_id,), meta=proxy.meta) + return ProxyTensor(self, node_id, proxy.meta) + + def _binary( + self, + op: str, + lhs: ProxyTensor | Scalar, + rhs: ProxyTensor | Scalar, + ) -> ProxyTensor: + left, right = self._coerce_binary_operands(lhs, rhs) + out_meta = _binary_output_meta(left.meta, right.meta) + node_id = self.graph.add_op(op, inputs=(left.node_id, right.node_id), meta=out_meta) + return ProxyTensor(self, node_id, 
out_meta) + + def _reduce( + self, + op: str, + value: ProxyTensor, + *, + axis: int = -1, + keepdims: bool = False, + ) -> ProxyTensor: + proxy = self._ensure_proxy(value) + axis_norm = _normalize_reduction_axis(proxy.meta.shape, axis) + out_meta = _reduction_output_meta(proxy.meta, axis=axis_norm, keepdims=keepdims) + node_id = self.graph.add_op( + op, + inputs=(proxy.node_id,), + meta=out_meta, + attrs={"axis": axis_norm, "keepdims": bool(keepdims)}, + ) + return ProxyTensor(self, node_id, out_meta) + + def quantized_matmul( + self, + lhs: ProxyTensor, + rhs: ProxyTensor, + *, + scales: ProxyTensor | None = None, + biases: ProxyTensor | None = None, + group_size: int | None = None, + bits: int | None = None, + mode: str | None = None, + transpose: bool = False, + ) -> ProxyTensor: + left = self._ensure_proxy(lhs) + right = self._ensure_proxy(rhs) + out_meta = _quantized_matmul_output_meta( + left.meta, + right.meta, + transpose=transpose, + bits=bits, + ) + + input_ids: list[int] = [left.node_id, right.node_id] + if scales is not None: + input_ids.append(self._ensure_proxy(scales).node_id) + if biases is not None: + input_ids.append(self._ensure_proxy(biases).node_id) + + node_id = self.graph.add_op( + "quantized_matmul", + inputs=tuple(input_ids), + meta=out_meta, + attrs={ + "group_size": group_size, + "bits": bits, + "mode": mode, + "transpose": bool(transpose), + }, + ) + return ProxyTensor(self, node_id, out_meta) + + def _coerce_binary_operands( + self, + lhs: ProxyTensor | Scalar, + rhs: ProxyTensor | Scalar, + ) -> tuple[ProxyTensor, ProxyTensor]: + if isinstance(lhs, ProxyTensor) and isinstance(rhs, ProxyTensor): + return self._ensure_proxy(lhs), self._ensure_proxy(rhs) + if isinstance(lhs, ProxyTensor): + left = self._ensure_proxy(lhs) + right = self.constant(_ensure_scalar(rhs), dtype=left.dtype) + return left, right + if isinstance(rhs, ProxyTensor): + right = self._ensure_proxy(rhs) + left = self.constant(_ensure_scalar(lhs), dtype=right.dtype) + 
return left, right + raise TypeError("At least one operand must be a ProxyTensor.") + + def _ensure_proxy(self, value: ProxyTensor) -> ProxyTensor: + if value._tracer is not self: + raise ValueError("Cannot mix ProxyTensor objects from different traces.") + return value + + +def symbolic_trace( + fn: Callable[..., ProxyTensor | Sequence[ProxyTensor]], + *inputs: TensorMeta | TensorLike, + input_names: Sequence[str] | None = None, +) -> TraceResult: + """Trace ``fn`` with proxy tensors and return a symbolic graph.""" + tracer = SymbolicTracer() + names = tuple(input_names or ()) + if names and len(names) != len(inputs): + raise ValueError("input_names length must match number of inputs.") + + proxy_inputs: list[ProxyTensor] = [] + for idx, raw_input in enumerate(inputs): + meta = _to_meta(raw_input) + name = names[idx] if names else f"arg{idx}" + proxy_inputs.append(tracer.input(name=name, meta=meta)) + + raw_outputs = fn(*proxy_inputs) + outputs = _normalize_outputs(raw_outputs) + for out in outputs: + if out._tracer is not tracer: + raise ValueError("Trace output contains ProxyTensor from a different tracer.") + + return TraceResult( + graph=tracer.graph, + input_ids=tuple(proxy.node_id for proxy in proxy_inputs), + output_ids=tuple(proxy.node_id for proxy in outputs), + ) + + +def add(lhs: ProxyTensor | Scalar, rhs: ProxyTensor | Scalar) -> ProxyTensor: + return _binary_op("add", lhs, rhs) + + +def sub(lhs: ProxyTensor | Scalar, rhs: ProxyTensor | Scalar) -> ProxyTensor: + return _binary_op("sub", lhs, rhs) + + +def mul(lhs: ProxyTensor | Scalar, rhs: ProxyTensor | Scalar) -> ProxyTensor: + return _binary_op("mul", lhs, rhs) + + +def div(lhs: ProxyTensor | Scalar, rhs: ProxyTensor | Scalar) -> ProxyTensor: + return _binary_op("div", lhs, rhs) + + +def neg(x: ProxyTensor) -> ProxyTensor: + return x._tracer._unary("neg", x) + + +def exp(x: ProxyTensor) -> ProxyTensor: + return x._tracer._unary("exp", x) + + +def tanh(x: ProxyTensor) -> ProxyTensor: + return 
x._tracer._unary("tanh", x) + + +def sigmoid(x: ProxyTensor) -> ProxyTensor: + return x._tracer._unary("sigmoid", x) + + +def relu(x: ProxyTensor) -> ProxyTensor: + return x._tracer._unary("relu", x) + + +def silu(x: ProxyTensor) -> ProxyTensor: + return x._tracer._unary("silu", x) + + +def log(x: ProxyTensor) -> ProxyTensor: + return x._tracer._unary("log", x) + + +def sqrt(x: ProxyTensor) -> ProxyTensor: + return x._tracer._unary("sqrt", x) + + +def rsqrt(x: ProxyTensor) -> ProxyTensor: + return x._tracer._unary("rsqrt", x) + + +def maximum(lhs: ProxyTensor | Scalar, rhs: ProxyTensor | Scalar) -> ProxyTensor: + return _binary_op("maximum", lhs, rhs) + + +def minimum(lhs: ProxyTensor | Scalar, rhs: ProxyTensor | Scalar) -> ProxyTensor: + return _binary_op("minimum", lhs, rhs) + + +def reduce_sum(x: ProxyTensor, *, axis: int = -1, keepdims: bool = False) -> ProxyTensor: + return x._tracer._reduce("sum", x, axis=axis, keepdims=keepdims) + + +def reduce_mean(x: ProxyTensor, *, axis: int = -1, keepdims: bool = False) -> ProxyTensor: + return x._tracer._reduce("mean", x, axis=axis, keepdims=keepdims) + + +def reduce_max(x: ProxyTensor, *, axis: int = -1, keepdims: bool = False) -> ProxyTensor: + return x._tracer._reduce("max", x, axis=axis, keepdims=keepdims) + + +def quantized_matmul( + lhs: ProxyTensor, + rhs: ProxyTensor, + *, + scales: ProxyTensor | None = None, + biases: ProxyTensor | None = None, + group_size: int | None = None, + bits: int | None = None, + mode: str | None = None, + transpose: bool = False, +) -> ProxyTensor: + return lhs._tracer.quantized_matmul( + lhs, + rhs, + scales=scales, + biases=biases, + group_size=group_size, + bits=bits, + mode=mode, + transpose=transpose, + ) + + +def _binary_op(op: str, lhs: ProxyTensor | Scalar, rhs: ProxyTensor | Scalar) -> ProxyTensor: + tracer = _resolve_tracer(lhs, rhs) + return tracer._binary(op, lhs, rhs) + + +def _resolve_tracer(*values: ProxyTensor | Scalar) -> SymbolicTracer: + tracer: SymbolicTracer | 
None = None + for value in values: + if not isinstance(value, ProxyTensor): + continue + if tracer is None: + tracer = value._tracer + continue + if value._tracer is not tracer: + raise ValueError("Cannot mix ProxyTensor objects from different traces.") + if tracer is None: + raise TypeError("At least one operand must be a ProxyTensor.") + return tracer + + +def _normalize_outputs(raw_outputs: ProxyTensor | Sequence[ProxyTensor]) -> tuple[ProxyTensor, ...]: + if isinstance(raw_outputs, ProxyTensor): + return (raw_outputs,) + if not isinstance(raw_outputs, Sequence): + raise TypeError("Trace function must return ProxyTensor or sequence of ProxyTensor.") + outputs = tuple(raw_outputs) + if not outputs: + raise ValueError("Trace function must return at least one output.") + for out in outputs: + if not isinstance(out, ProxyTensor): + raise TypeError("All trace outputs must be ProxyTensor instances.") + return outputs + + +def _to_meta(value: TensorMeta | TensorLike) -> TensorMeta: + if isinstance(value, TensorMeta): + return value + if isinstance(value, TensorLike): + return TensorMeta(shape=_to_shape(value.shape), dtype=_to_dtype(value.dtype)) + raise TypeError("Trace inputs must be TensorMeta or TensorLike.") + + +def _to_shape(shape: Any) -> tuple[int, ...]: + if isinstance(shape, int): + return (int(shape),) + return tuple(int(dim) for dim in shape) + + +def _to_dtype(dtype: Any) -> str: + if isinstance(dtype, str): + return dtype + if hasattr(dtype, "name"): + return str(dtype.name) + return str(dtype) + + +def _ensure_scalar(value: ProxyTensor | Scalar) -> Scalar: + if isinstance(value, ProxyTensor): + raise TypeError("Expected scalar literal, got ProxyTensor.") + if isinstance(value, bool): + return value + if isinstance(value, (int, float)): + return value + raise TypeError(f"Unsupported scalar literal type: {type(value).__name__}") + + +def _binary_output_meta(lhs: TensorMeta, rhs: TensorMeta) -> TensorMeta: + if not lhs.is_scalar and not rhs.is_scalar and 
lhs.dtype != rhs.dtype: + raise ValueError( + "Elementwise tensor operands must share dtype " + f"(got {lhs.dtype!r} and {rhs.dtype!r})." + ) + out_shape = _broadcast_shape(lhs.shape, rhs.shape) + out_dtype = lhs.dtype + if lhs.is_scalar and not rhs.is_scalar: + out_dtype = rhs.dtype + return TensorMeta(shape=out_shape, dtype=out_dtype) + + +def _broadcast_shape(lhs: tuple[int, ...], rhs: tuple[int, ...]) -> tuple[int, ...]: + if lhs == rhs: + return lhs + if not lhs: + return rhs + if not rhs: + return lhs + + lhs_rank = len(lhs) + rhs_rank = len(rhs) + max_rank = lhs_rank if lhs_rank >= rhs_rank else rhs_rank + lhs_pad = (1,) * (max_rank - lhs_rank) + lhs + rhs_pad = (1,) * (max_rank - rhs_rank) + rhs + + out: list[int] = [] + for idx, (lhs_dim, rhs_dim) in enumerate(zip(lhs_pad, rhs_pad, strict=True)): + if lhs_dim == rhs_dim: + out.append(lhs_dim) + continue + if lhs_dim == 1: + out.append(rhs_dim) + continue + if rhs_dim == 1: + out.append(lhs_dim) + continue + raise ValueError( + "Elementwise operands are not broadcast-compatible " + f"(lhs={lhs}, rhs={rhs}, mismatch at dim {idx})." + ) + return tuple(out) + + +def _normalize_reduction_axis(shape: tuple[int, ...], axis: int) -> int: + if not shape: + raise ValueError("Reduction over scalar tensors is unsupported.") + if not isinstance(axis, int): + raise TypeError(f"Reduction axis must be int, got {type(axis).__name__}.") + + rank = len(shape) + axis_norm = axis + rank if axis < 0 else axis + if axis_norm < 0 or axis_norm >= rank: + raise ValueError(f"Reduction axis {axis} is out of range for shape {shape}.") + if axis_norm != rank - 1: + raise ValueError( + "Only last-axis reductions are supported in phase-2 tracing " + f"(axis={axis}, shape={shape})." 
+ ) + return axis_norm + + +def _reduction_output_meta( + value: TensorMeta, + *, + axis: int, + keepdims: bool, +) -> TensorMeta: + shape = list(value.shape) + if keepdims: + shape[axis] = 1 + else: + shape.pop(axis) + return TensorMeta(shape=tuple(shape), dtype=value.dtype) + + +def _quantized_matmul_output_meta( + lhs: TensorMeta, + rhs: TensorMeta, + *, + transpose: bool, + bits: int | None, +) -> TensorMeta: + if len(rhs.shape) != 2: + raise ValueError( + "quantized_matmul rhs must be rank-2 for tracing metadata propagation " + f"(got shape {rhs.shape})." + ) + if len(lhs.shape) < 1: + raise ValueError( + "quantized_matmul lhs must be rank-1 or higher " + f"(got shape {lhs.shape})." + ) + + k_lhs = lhs.shape[-1] + pack_factor = 1 + if bits in (4, 8): + pack_factor = 32 // int(bits) + if transpose: + n = rhs.shape[0] + k_rhs = rhs.shape[1] * pack_factor + else: + k_rhs = rhs.shape[0] * pack_factor + n = rhs.shape[1] + + if k_lhs != k_rhs: + raise ValueError( + "quantized_matmul inner dimensions must agree " + f"(lhs_k={k_lhs}, rhs_k={k_rhs}, transpose={transpose})." + ) + + if len(lhs.shape) == 1: + return TensorMeta(shape=(n,), dtype=lhs.dtype) + return TensorMeta(shape=lhs.shape[:-1] + (n,), dtype=lhs.dtype) diff --git a/src/zmlx/kernels/moe.py b/src/zmlx/kernels/moe.py index 8343e67..453bbe3 100644 --- a/src/zmlx/kernels/moe.py +++ b/src/zmlx/kernels/moe.py @@ -290,6 +290,224 @@ def _topk_softmax_bias_simd_kernel(d: int, k: int, renorm: bool) -> Any: cache=True, ) +@cache +def _topk_softmax_bias_dsimd_kernel(d: int, k: int, renorm: bool) -> Any: + """Optimized kernel for D in (32, 64] using exactly D threads. + + Uses 2 SIMD groups (32 threads each) with minimal cross-group barriers. + Much faster than the generic TG=256 path for D=64 (4 barriers vs 24+). 
+ """ + D = int(d) + K = int(k) + if D <= 32 or D > 64: + raise ValueError("_topk_softmax_bias_dsimd_kernel: requires 32 < D <= 64") + if K <= 0 or K > _MAX_FUSED_TOPK: + raise ValueError("_topk_softmax_bias_dsimd_kernel: invalid K") + + renorm_literal = str(bool(renorm)).lower() + + # TG = D rounded up to next multiple of 32 (always 64 for D in (32,64]) + TG = ((D + 31) // 32) * 32 + + source = f""" + constexpr uint D = {D}; + constexpr uint K = {K}; + constexpr uint TG = {TG}; + constexpr uint SG = 32; + constexpr bool RENORM = {renorm_literal}; + + uint gid = thread_position_in_grid.x; + uint tid = thread_position_in_threadgroup.x; + uint row = gid / TG; + uint base = row * D; + uint sg_id = tid / SG; // SIMD group index (0 or 1) + uint lane = tid % SG; // lane within SIMD group + + // --- Load one element per thread --- + float v = -INFINITY; + if (tid < D) {{ + v = (float)inp[base + tid]; + }} + + // --- Softmax: cross-SIMD max --- + float sg_max = simd_max(v); + threadgroup float cross_f[2]; + if (lane == 0) cross_f[sg_id] = sg_max; + threadgroup_barrier(mem_flags::mem_threadgroup); + float row_max = metal::max(cross_f[0], cross_f[1]); + + // --- Softmax: cross-SIMD exp-sum --- + float exp_v = (tid < D) ? metal::exp(v - row_max) : 0.0f; + float sg_sum = simd_sum(exp_v); + if (lane == 0) cross_f[sg_id] = sg_sum; + threadgroup_barrier(mem_flags::mem_threadgroup); + float row_sum = cross_f[0] + cross_f[1]; + + // --- Probability + bias --- + float p = (tid < D) ? 
(exp_v / row_sum + (float)bias[tid]) : -INFINITY; + + // --- Top-K via iterative cross-SIMD max extraction --- + // Each iteration: 1 simd_max + 1 simd_min + 2 barriers = 4 ops + threadgroup uint cross_u[2]; + + thread float topk_vals[K]; + thread uint topk_idx[K]; + + float cur = p; + for (uint i = 0; i < K; ++i) {{ + // Find global max across both SIMD groups + float sg_cur_max = simd_max(cur); + if (lane == 0) cross_f[sg_id] = sg_cur_max; + threadgroup_barrier(mem_flags::mem_threadgroup); + float global_max = metal::max(cross_f[0], cross_f[1]); + + // Find winner: lowest-index thread holding global_max + uint candidate = (cur == global_max && tid < D) ? tid : 0xFFFFu; + uint sg_winner = simd_min(candidate); + if (lane == 0) cross_u[sg_id] = sg_winner; + threadgroup_barrier(mem_flags::mem_threadgroup); + uint winner = metal::min(cross_u[0], cross_u[1]); + + if (tid == 0) {{ + topk_vals[i] = global_max; + topk_idx[i] = winner; + }} + // Mark winner as consumed + if (tid == winner) cur = -INFINITY; + }} + + // --- Renormalize and write output (ascending value order) --- + // MLX argpartition returns top-K in ascending value order; + // we write in reverse so slot 0 = smallest top-K value. + if (tid == 0) {{ + float denom = 1.0f; + if (RENORM) {{ + float sum_k = 0.0f; + for (uint i = 0; i < K; ++i) {{ + sum_k += topk_vals[i]; + }} + denom = sum_k + 1e-20f; + }} + uint out_base = row * K; + for (uint i = 0; i < K; ++i) {{ + uint ri = K - 1 - i; // reverse: descending -> ascending + float v_out = topk_vals[ri]; + weights[out_base + i] = (T)(RENORM ? 
(v_out / denom) : v_out); + indices[out_base + i] = topk_idx[ri]; + }} + }} + """ + return metal_kernel( + name=f"kk_topk_softmax_bias_dsimd_D{D}_K{K}_R{int(bool(renorm))}", + input_names=["inp", "bias"], + output_names=["weights", "indices"], + source=source, + header=DEFAULT_HEADER, + ensure_row_contiguous=True, + cache=True, + ) + + +@cache +def _topk_softmax_dsimd_kernel(d: int, k: int, renorm: bool) -> Any: + """Optimized kernel for D in (32, 64] WITHOUT bias.""" + D = int(d) + K = int(k) + if D <= 32 or D > 64: + raise ValueError("_topk_softmax_dsimd_kernel: requires 32 < D <= 64") + if K <= 0 or K > _MAX_FUSED_TOPK: + raise ValueError("_topk_softmax_dsimd_kernel: invalid K") + + renorm_literal = str(bool(renorm)).lower() + TG = ((D + 31) // 32) * 32 + + source = f""" + constexpr uint D = {D}; + constexpr uint K = {K}; + constexpr uint TG = {TG}; + constexpr uint SG = 32; + constexpr bool RENORM = {renorm_literal}; + + uint gid = thread_position_in_grid.x; + uint tid = thread_position_in_threadgroup.x; + uint row = gid / TG; + uint base = row * D; + uint sg_id = tid / SG; + uint lane = tid % SG; + + float v = -INFINITY; + if (tid < D) {{ + v = (float)inp[base + tid]; + }} + + float sg_max = simd_max(v); + threadgroup float cross_f[2]; + if (lane == 0) cross_f[sg_id] = sg_max; + threadgroup_barrier(mem_flags::mem_threadgroup); + float row_max = metal::max(cross_f[0], cross_f[1]); + + float exp_v = (tid < D) ? metal::exp(v - row_max) : 0.0f; + float sg_sum = simd_sum(exp_v); + if (lane == 0) cross_f[sg_id] = sg_sum; + threadgroup_barrier(mem_flags::mem_threadgroup); + float row_sum = cross_f[0] + cross_f[1]; + + float p = (tid < D) ? 
(exp_v / row_sum) : -INFINITY; + + threadgroup uint cross_u[2]; + thread float topk_vals[K]; + thread uint topk_idx[K]; + + float cur = p; + for (uint i = 0; i < K; ++i) {{ + float sg_cur_max = simd_max(cur); + if (lane == 0) cross_f[sg_id] = sg_cur_max; + threadgroup_barrier(mem_flags::mem_threadgroup); + float global_max = metal::max(cross_f[0], cross_f[1]); + + uint candidate = (cur == global_max && tid < D) ? tid : 0xFFFFu; + uint sg_winner = simd_min(candidate); + if (lane == 0) cross_u[sg_id] = sg_winner; + threadgroup_barrier(mem_flags::mem_threadgroup); + uint winner = metal::min(cross_u[0], cross_u[1]); + + if (tid == 0) {{ + topk_vals[i] = global_max; + topk_idx[i] = winner; + }} + if (tid == winner) cur = -INFINITY; + }} + + // Write output in ascending value order (matching argpartition) + if (tid == 0) {{ + float denom = 1.0f; + if (RENORM) {{ + float sum_k = 0.0f; + for (uint i = 0; i < K; ++i) {{ + sum_k += topk_vals[i]; + }} + denom = sum_k + 1e-20f; + }} + uint out_base = row * K; + for (uint i = 0; i < K; ++i) {{ + uint ri = K - 1 - i; // reverse: descending -> ascending + float v_out = topk_vals[ri]; + weights[out_base + i] = (T)(RENORM ? 
(v_out / denom) : v_out); + indices[out_base + i] = topk_idx[ri]; + }} + }} + """ + return metal_kernel( + name=f"kk_topk_softmax_dsimd_D{D}_K{K}_R{int(bool(renorm))}", + input_names=["inp"], + output_names=["weights", "indices"], + source=source, + header=DEFAULT_HEADER, + ensure_row_contiguous=True, + cache=True, + ) + + @cache def _top2_gating_kernel(d: int, tg: int) -> Any: D = int(d) @@ -1105,6 +1323,138 @@ def _moe_combine_kernel_fp32_no_fma(d: int, k: int) -> Any: ) +def _row_tile_launch_params(d: int) -> tuple[int, int]: + """Return (threadgroup_size, unroll) for row-tiled combine kernels.""" + D = int(d) + if D >= 256: + tg = 256 + elif D >= 128: + tg = 128 + elif D >= 64: + tg = 64 + elif D >= 32: + tg = 32 + else: + tg = max(1, D) + + if D >= tg * 4: + unroll = 4 + elif D >= tg * 2: + unroll = 2 + else: + unroll = 1 + return tg, unroll + + +@cache +def _moe_combine_kernel_fp32_no_fma_row_tile( + d: int, + k: int, + tg_size: int, + unroll: int, +) -> Any: + """FP32 no-FMA combine with row tiling adapted from foundry t2_row_tile.""" + D = int(d) + K = int(k) + TG_SIZE = int(tg_size) + UNROLL = int(unroll) + source = f""" + #pragma clang fp contract(off) + constexpr uint D = {D}; + constexpr uint K = {K}; + constexpr uint TG_SIZE = {TG_SIZE}; + constexpr uint UNROLL = {UNROLL}; + uint token_idx = thread_position_in_grid.y; + uint tid = thread_position_in_threadgroup.x; + + threadgroup float wbuf[K]; + for (uint i = tid; i < K; i += threads_per_threadgroup.x) {{ + wbuf[i] = weights[token_idx * K + i]; + }} + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (uint col0 = tid; col0 < D; col0 += TG_SIZE * UNROLL) {{ + #pragma unroll + for (uint u = 0; u < UNROLL; ++u) {{ + uint d_idx = col0 + u * TG_SIZE; + if (d_idx < D) {{ + float acc = 0.0f; + for (uint i = 0; i < K; ++i) {{ + float w = wbuf[i]; + float v = (float)expert_outputs[(token_idx * K + i) * D + d_idx]; + float prod = w * v; + acc = acc + prod; + }} + out[token_idx * D + d_idx] = acc; + }} + 
}} + }} + """ + return metal_kernel( + name=f"kk_moe_combine_fp32_no_fma_row_tile_D{D}_K{K}_TG{TG_SIZE}_U{UNROLL}", + input_names=["expert_outputs", "weights"], + output_names=["out"], + source=source, + header=DEFAULT_HEADER, + ensure_row_contiguous=True, + cache=True, + ) + + +@cache +def _moe_combine_weighted_sum_kernel_fp32_no_fma_k8_row_tile( + d: int, + tg_size: int, + unroll: int, +) -> Any: + """Specialized fp32/no-FMA weighted-sum combine for K=8.""" + D = int(d) + TG_SIZE = int(tg_size) + UNROLL = int(unroll) + source = f""" + #pragma clang fp contract(off) + constexpr uint D = {D}; + constexpr uint K = 8; + constexpr uint TG_SIZE = {TG_SIZE}; + constexpr uint UNROLL = {UNROLL}; + uint token_idx = thread_position_in_grid.y; + uint tid = thread_position_in_threadgroup.x; + + threadgroup float wbuf[K]; + for (uint i = tid; i < K; i += threads_per_threadgroup.x) {{ + wbuf[i] = weights[token_idx * K + i]; + }} + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (uint col0 = tid; col0 < D; col0 += TG_SIZE * UNROLL) {{ + #pragma unroll + for (uint u = 0; u < UNROLL; ++u) {{ + uint d_idx = col0 + u * TG_SIZE; + if (d_idx < D) {{ + float acc = 0.0f; + #pragma unroll + for (uint i = 0; i < K; ++i) {{ + float w = wbuf[i]; + float v = (float)expert_outputs[(token_idx * K + i) * D + d_idx]; + float prod = w * v; + acc = acc + prod; + }} + out[token_idx * D + d_idx] = acc; + }} + }} + }} + """ + return metal_kernel( + name=f"kk_moe_combine_weighted_sum_fp32_no_fma_k8_D{D}_TG{TG_SIZE}_U{UNROLL}", + input_names=["expert_outputs", "weights"], + output_names=["out"], + source=source, + header=DEFAULT_HEADER, + ensure_row_contiguous=True, + cache=True, + ) + + def moe_combine_fp32_no_fma(expert_outputs: Any, weights: Any) -> Any: """Like ``moe_combine_fp32`` but disables FMA contraction in accumulation.""" original_shape = weights.shape[:-1] @@ -1127,6 +1477,57 @@ def moe_combine_fp32_no_fma(expert_outputs: Any, weights: Any) -> Any: return 
out.reshape((*original_shape, D)) +def moe_combine_fp32_no_fma_row_tile(expert_outputs: Any, weights: Any) -> Any: + """FP32 no-FMA combine with t2-row-tile style column traversal.""" + original_shape = weights.shape[:-1] + K = int(weights.shape[-1]) + D = int(expert_outputs.shape[-1]) + + expert_outputs_flat = expert_outputs.reshape(-1, K, D) + weights_flat = weights.reshape(-1, K).astype(mx.float32) + B = int(weights_flat.shape[0]) + + tg_size, unroll = _row_tile_launch_params(D) + k = _moe_combine_kernel_fp32_no_fma_row_tile(D, K, tg_size, unroll) + out = k( + expert_outputs_flat, + weights_flat, + grid=(tg_size, B, 1), + threadgroup=(tg_size, 1, 1), + output_shapes=[(B, D)], + output_dtypes=[mx.float32], + )[0] + return out.reshape((*original_shape, D)) + + +def moe_combine_weighted_sum_fp32_no_fma(expert_outputs: Any, weights: Any) -> Any: + """Specialized fp32/no-FMA weighted sum for K=8 decode-like MoE combine. + + Falls back to ``moe_combine_fp32_no_fma`` when K != 8. + """ + K = int(weights.shape[-1]) + if K != 8: + return moe_combine_fp32_no_fma(expert_outputs, weights) + + original_shape = weights.shape[:-1] + D = int(expert_outputs.shape[-1]) + expert_outputs_flat = expert_outputs.reshape(-1, K, D) + weights_flat = weights.reshape(-1, K).astype(mx.float32) + B = int(weights_flat.shape[0]) + + tg_size, unroll = _row_tile_launch_params(D) + k = _moe_combine_weighted_sum_kernel_fp32_no_fma_k8_row_tile(D, tg_size, unroll) + out = k( + expert_outputs_flat, + weights_flat, + grid=(tg_size, B, 1), + threadgroup=(tg_size, 1, 1), + output_shapes=[(B, D)], + output_dtypes=[mx.float32], + )[0] + return out.reshape((*original_shape, D)) + + def _combine_projected_experts( projected: Any, gate: Any, @@ -1351,6 +1752,35 @@ def topk_gating_softmax( ) return weights, indices + # D-SIMD fast path for 32 < D <= 64: uses exactly D threads (2 SIMD groups) + # instead of TG=256 with 75%+ idle threads. 4K barriers vs 24+. 
+ use_dsimd = 32 < D <= 64 and K <= _MAX_FUSED_TOPK + if use_dsimd: + TG = ((D + 31) // 32) * 32 # 64 for D in (32, 64] + if bias is None: + kernel = _topk_softmax_dsimd_kernel(D, K, norm) + weights, indices = kernel( + x_cast, + template=[("T", cd)], + grid=(rows * TG, 1, 1), + threadgroup=(TG, 1, 1), + output_shapes=[x_cast.shape[:-1] + (K,), x_cast.shape[:-1] + (K,)], + output_dtypes=[cd, mx.uint32], + ) + else: + bias_cast = bias.astype(cd) if bias.dtype != cd else bias + kernel = _topk_softmax_bias_dsimd_kernel(D, K, norm) + weights, indices = kernel( + x_cast, + bias_cast, + template=[("T", cd)], + grid=(rows * TG, 1, 1), + threadgroup=(TG, 1, 1), + output_shapes=[x_cast.shape[:-1] + (K,), x_cast.shape[:-1] + (K,)], + output_dtypes=[cd, mx.uint32], + ) + return weights, indices + TG = _validate_tg(threadgroup) if bias is None: kernel = _topk_softmax_kernel(D, K, TG, norm) @@ -2185,6 +2615,8 @@ def gather_qmm_combine_quantized( "moe_combine_exact", "moe_combine_fp32", "moe_combine_fp32_no_fma", + "moe_combine_fp32_no_fma_row_tile", + "moe_combine_weighted_sum_fp32_no_fma", "gather_mmk_combine", "gather_qmmk_combine_quantized", "gather_qmm_combine", diff --git a/src/zmlx/patch/__init__.py b/src/zmlx/patch/__init__.py index f607f69..69350f3 100644 --- a/src/zmlx/patch/__init__.py +++ b/src/zmlx/patch/__init__.py @@ -37,7 +37,7 @@ from ._registry import _ensure_loaded, get_all_patterns, get_pattern, list_patterns from ._traversal import apply_patterns -from ._types import PatchConfig, PatchResult +from ._types import FusionConfig, PatchConfig, PatchResult # --------------------------------------------------------------------------- # Model-aware safety: auto-exclude patterns with known fidelity issues @@ -136,6 +136,11 @@ def patch( compute_dtype: str = "float32", threadgroup: int | str = 256, moe_fused_swiglu_max_tokens: int | None = None, + fusion: bool = False, + fusion_validate: bool = True, + fusion_patterns: list[str] | None = None, + fusion_rtol: float = 
1e-4, + fusion_atol: float = 1e-5, verbose: bool = False, ) -> nn.Module: """Patch an MLX model to use fused ZMLX Metal kernels. @@ -159,6 +164,14 @@ def patch( to autotune on first invocation. moe_fused_swiglu_max_tokens: Override the max token count for the fused SwiGLU MoE path. ``None`` uses the built-in default threshold. + fusion: Opt-in JIT fusion integration for patched modules. Disabled by + default for safety. + fusion_validate: When fusion is enabled, validate fused vs unfused + outputs before enabling each signature/pattern. + fusion_patterns: Optional allowlist of pattern names eligible for fusion + integration. ``None`` means all patched patterns. + fusion_rtol: Relative tolerance for validation gate. + fusion_atol: Absolute tolerance for validation gate. verbose: Print each replacement as it happens. Returns: @@ -191,6 +204,13 @@ def patch( compute_dtype=compute_dtype, threadgroup=threadgroup, moe_fused_swiglu_max_tokens=moe_fused_swiglu_max_tokens, + fusion=FusionConfig( + enabled=bool(fusion), + validate=bool(fusion_validate), + rtol=float(fusion_rtol), + atol=float(fusion_atol), + patterns=tuple(fusion_patterns or ()), + ), verbose=verbose, ) @@ -317,6 +337,14 @@ def unpatch(model: nn.Module) -> nn.Module: children = dict(model.children()) for _name, child in children.items(): + if hasattr(child, "_zmlx_fusion_original_class"): + try: + child.__class__ = child._zmlx_fusion_original_class + except Exception: + pass + del child._zmlx_fusion_original_class + if hasattr(child, "_zmlx_fusion_wrapped"): + del child._zmlx_fusion_wrapped if hasattr(child, "_zmlx_original_call") and child._zmlx_original_call is not None: child.__call__ = child._zmlx_original_call del child._zmlx_original_call diff --git a/src/zmlx/patch/_traversal.py b/src/zmlx/patch/_traversal.py index ef3d3d8..8c1e457 100644 --- a/src/zmlx/patch/_traversal.py +++ b/src/zmlx/patch/_traversal.py @@ -4,9 +4,96 @@ import mlx.nn as nn +from ..fusion.integration import 
wrap_callable_with_fusion_validation from ._types import PatchConfig, PatchPattern, PatchResult +def _module_fingerprint(module: Any) -> str: + """Generate a deterministic fingerprint for a module instance. + + Uses structural properties (class name, parameter count, child count) + rather than id() for reproducibility across runs while maintaining + reasonable cross-instance isolation. + """ + parts = [type(module).__name__] + params_fn = getattr(module, "parameters", None) + if callable(params_fn): + try: + parts.append(f"p{sum(1 for _ in params_fn())}") + except Exception: + parts.append("p?") + children_fn = getattr(module, "children", None) + if callable(children_fn): + try: + children_dict = children_fn() + if isinstance(children_dict, dict): + parts.append(f"c{len(children_dict)}") + else: + parts.append("c?") + except Exception: + parts.append("c?") + return ":".join(parts) + + +def _wrap_module_call_path( + module: Any, + wrapped_call: Any, +) -> None: + """Install a per-instance class-level __call__ wrapper for module(...).""" + + original_class = module.__class__ + + def _patched_call(_self: Any, *args: Any, **kwargs: Any) -> Any: + return wrapped_call(*args, **kwargs) + + _patched_call._zmlx_fusion_wrapped = True # type: ignore[attr-defined] + if hasattr(wrapped_call, "fusion_stats"): + _patched_call.fusion_stats = wrapped_call.fusion_stats # type: ignore[attr-defined] + + module.__class__ = type( + original_class.__name__, + (original_class,), + {"__call__": _patched_call}, + ) + module._zmlx_fusion_original_class = original_class # type: ignore[attr-defined] + module._zmlx_fusion_wrapped = True # type: ignore[attr-defined] + + +def _maybe_wrap_fusion( + module: Any, + *, + pattern_name: str, + path: str, + config: PatchConfig, + result: PatchResult, +) -> None: + fusion = config.fusion + if not fusion.allows_pattern(pattern_name): + return + if not callable(module): + result.fusion_skipped.append(f"{path}: missing __call__") + return + if 
getattr(module, "_zmlx_fusion_wrapped", False): + return + wrapped = module.__call__ + if getattr(wrapped, "_zmlx_fusion_wrapped", False): + module._zmlx_fusion_wrapped = True # type: ignore[attr-defined] + return + try: + wrapped_call = wrap_callable_with_fusion_validation( + wrapped, + pattern_name=pattern_name, + fusion=fusion, + state=config.fusion_state, + module_instance_key=f"{path}#{_module_fingerprint(module)}", + ) + _wrap_module_call_path(module, wrapped_call) + except Exception as exc: + result.fusion_skipped.append(f"{path}: {exc}") + return + result.fusion_wrapped_count += 1 + + def _walk_modules( module: Any, patterns: list[PatchPattern], @@ -48,6 +135,13 @@ def _walk_modules( result.pattern_counts[pattern.name] = ( result.pattern_counts.get(pattern.name, 0) + 1 ) + _maybe_wrap_fusion( + replacement, + pattern_name=pattern.name, + path=item_path, + config=config, + result=result, + ) if config.verbose: print(f" [zmlx.patch] {item_path}: {pattern.name}") except Exception as e: @@ -63,6 +157,13 @@ def _walk_modules( result.pattern_counts[pattern.name] = ( result.pattern_counts.get(pattern.name, 0) + 1 ) + _maybe_wrap_fusion( + replacement, + pattern_name=pattern.name, + path=child_path, + config=config, + result=result, + ) if config.verbose: print(f" [zmlx.patch] {child_path}: {pattern.name}") except Exception as e: diff --git a/src/zmlx/patch/_types.py b/src/zmlx/patch/_types.py index 6135744..741c89d 100644 --- a/src/zmlx/patch/_types.py +++ b/src/zmlx/patch/_types.py @@ -3,6 +3,45 @@ from dataclasses import dataclass, field from typing import Any, Protocol, runtime_checkable +FusionSignatureItem = tuple[str, Any, Any] +FusionSignature = tuple[FusionSignatureItem, ...] 
+FusionLegacyCacheKey = tuple[str, FusionSignature] +FusionScopedCacheKey = tuple[str, str, FusionSignature] +FusionCacheKey = FusionLegacyCacheKey | FusionScopedCacheKey + + +@dataclass(frozen=True) +class FusionConfig: + """Opt-in fusion runtime integration settings for patched modules.""" + + enabled: bool = False + validate: bool = True + rtol: float = 1e-4 + atol: float = 1e-5 + # Empty means "all patched patterns". + patterns: tuple[str, ...] = () + # When True, legacy (non-scoped) cache keys are migrated to scoped format on access. + normalize_legacy_keys: bool = False + + def allows_pattern(self, pattern_name: str) -> bool: + if not self.enabled: + return False + if not self.patterns: + return True + return pattern_name in self.patterns + + +@dataclass +class FusionRuntimeState: + """Mutable cache/state for fusion integration safety decisions.""" + + validated: set[FusionCacheKey] = field(default_factory=set) + blacklist: set[FusionCacheKey] = field(default_factory=set) + compile_failures: int = 0 + runtime_failures: int = 0 + validation_mismatches: int = 0 + blacklist_hits: int = 0 + @dataclass(frozen=True) class PatchConfig: @@ -12,6 +51,8 @@ class PatchConfig: threadgroup: int | str = 256 verbose: bool = False moe_fused_swiglu_max_tokens: int | None = None + fusion: FusionConfig = field(default_factory=FusionConfig) + fusion_state: FusionRuntimeState = field(default_factory=FusionRuntimeState) @dataclass @@ -23,6 +64,8 @@ class PatchResult: skipped: list[str] = field(default_factory=list) # Filled by smart_patch: pattern_name -> speedup ratio benchmarks: dict[str, float] = field(default_factory=dict) + fusion_wrapped_count: int = 0 + fusion_skipped: list[str] = field(default_factory=list) def summary(self) -> str: lines = [f"Patched {self.patched_count} modules:"] @@ -30,6 +73,10 @@ def summary(self) -> str: bench = self.benchmarks.get(name) suffix = f" ({bench:.3f}x)" if bench is not None else "" lines.append(f" {name}: {count}{suffix}") + if 
self.fusion_wrapped_count: + lines.append(f" Fusion wrappers: {self.fusion_wrapped_count}") + if self.fusion_skipped: + lines.append(f" Fusion skipped: {len(self.fusion_skipped)}") if self.skipped: lines.append(f" Skipped: {len(self.skipped)}") return "\n".join(lines) diff --git a/src/zmlx/patch/patterns/moe_mlp.py b/src/zmlx/patch/patterns/moe_mlp.py index 35a0693..59c09bc 100644 --- a/src/zmlx/patch/patterns/moe_mlp.py +++ b/src/zmlx/patch/patterns/moe_mlp.py @@ -199,6 +199,63 @@ def _qwen_combine_mode() -> str: return "off" +def _lfm2_gate_mode() -> str: + """Return LFM2 gate mode for experimental optimization. + + Modes: + - ``kernel`` (default): use fused ``topk_gating_softmax`` Metal kernel. + - ``native``: use exact original MLX ops (softmax + argpartition). + Guaranteed fidelity but may have higher dispatch count. + - ``fast64``: use new D-SIMD optimized kernel for D=64 + (2 SIMD groups, 4K barriers vs 24+). Experimental. + + Set via ``ZMLX_LFM2_GATE_MODE``. + """ + raw = os.environ.get(_LFM2_GATE_MODE_ENV, "kernel").strip().lower() + if raw in {"kernel", "native", "fast64"}: + return raw + return "kernel" + + +def _lfm2_combine_mode(num_experts_per_tok: int = 2) -> str: + """Return LFM2 combine mode for experimental optimization. + + Modes: + - ``kernel``: use ``moe_combine`` Metal kernel. + - ``native``: use MLX native ``(y * w[..., None]).sum(axis=-2)``. + - ``fp32_no_fma``: use ``moe_combine_fp32_no_fma``. + + Default: ``native`` for K>=3 (LFM2-24B), ``kernel`` for K<=2 (LFM2-8B). + Set via ``ZMLX_LFM2_COMBINE_MODE``. + """ + raw = os.environ.get(_LFM2_COMBINE_MODE_ENV) + if raw is not None: + raw = raw.strip().lower() + if raw in {"kernel", "native", "fp32_no_fma"}: + return raw + # LFM2-24B (K=4): native combine is fidelity-safe and fastest. + # LFM2-8B (K=2): kernel combine works fine. 
+ return "native" if num_experts_per_tok >= 3 else "kernel" + + +def _lfm2_allow_fused_swiglu(num_experts_per_tok: int = 2) -> bool: + """Return True if LFM2 should use gather_qmm_swiglu fused SwiGLU. + + Default: enabled for K<=2 (LFM2-8B where it gives +11.6%), + disabled for K>=3 (LFM2-24B where it regresses to ~0.77x). + Set ``ZMLX_LFM2_FUSED_SWIGLU=0`` or ``=1`` to override. + """ + raw = os.environ.get("ZMLX_LFM2_FUSED_SWIGLU") + if raw is not None: + try: + return int(raw) != 0 + except ValueError: + pass + # K=4 (LFM2-24B): gather_qmm_swiglu causes both fidelity failure and + # performance regression. K=2 (LFM2-8B): proven +11.6% win. + return num_experts_per_tok <= 2 + + def _gating( self_mod: Any, x: Any, @@ -307,6 +364,8 @@ def _gating( _QWEN_ROUTER_ARGPARTITION_LOGITS_ENV = "ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS" _QWEN_ROUTER_ARGPARTITION_LOGITS_TOPK_ENV = "ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS_TOPK" _QWEN_COMBINE_MODE_ENV = "ZMLX_QWEN_COMBINE_MODE" +_LFM2_GATE_MODE_ENV = "ZMLX_LFM2_GATE_MODE" +_LFM2_COMBINE_MODE_ENV = "ZMLX_LFM2_COMBINE_MODE" def _glm_combine_mode() -> str: @@ -314,6 +373,8 @@ def _glm_combine_mode() -> str: Modes: - ``fp32_no_fma`` (default): use ``moe_combine_fp32_no_fma``. + - ``fp32_no_fma_row_tile``: use ``moe_combine_fp32_no_fma_row_tile``. + - ``fp32_no_fma_weighted_sum``: use ``moe_combine_weighted_sum_fp32_no_fma``. - ``off``: keep current GLM path (discovered combine if available, otherwise native ``(y * w[..., None]).sum(axis=-2)``). - ``fp32``: use ``moe_combine_fp32``. @@ -321,10 +382,21 @@ def _glm_combine_mode() -> str: Set via ``ZMLX_GLM_COMBINE_MODE``. 
""" - raw = os.environ.get(_GLM_COMBINE_MODE_ENV, "fp32_no_fma").strip().lower() - if raw in {"", "0", "off", "false", "no"}: + raw = os.environ.get(_GLM_COMBINE_MODE_ENV) + if raw is None: return "fp32_no_fma" - if raw in {"fp32", "fp32_no_fma", "exact"}: + raw = raw.strip().lower() + if raw in {"", "default", "auto"}: + return "fp32_no_fma" + if raw in {"0", "off", "false", "no"}: + return "off" + if raw in { + "fp32", + "fp32_no_fma", + "fp32_no_fma_row_tile", + "fp32_no_fma_weighted_sum", + "exact", + }: return raw return "fp32_no_fma" @@ -846,6 +918,13 @@ def apply(self, module: Any, config: PatchConfig) -> Any: "ZMLX_GLM_FUSED_SWIGLU=0" ) _use_fused_swiglu = False + if is_lfm2 and _use_fused_swiglu and not _lfm2_allow_fused_swiglu(num_experts_per_tok): + if config.verbose: + print( + f" [moe_mlp] LFM2 detected: fused SwiGLU disabled " + f"(K={num_experts_per_tok})" + ) + _use_fused_swiglu = False moe_stream_pool = _get_moe_stream_pool() moe_stream_reduce = None if moe_stream_pool is not None: @@ -863,9 +942,15 @@ def apply(self, module: Any, config: PatchConfig) -> Any: qwen_fused_downproj_combine_kvec = _qwen_allow_fused_downproj_combine_kvec() qwen_combine_mode = _qwen_combine_mode() if is_glm and glm_combine_mode != "off" and config.verbose: - print(f" [moe_mlp] GLM combine override enabled (mode={glm_combine_mode})") + print(f" [moe_mlp] GLM combine mode active (mode={glm_combine_mode})") if is_qwen3 and qwen_combine_mode != "off" and config.verbose: print(f" [moe_mlp] Qwen combine override enabled (mode={qwen_combine_mode})") + lfm2_gate = _lfm2_gate_mode() if is_lfm2 else None + lfm2_combine = _lfm2_combine_mode(num_experts_per_tok) if is_lfm2 else None + if is_lfm2 and config.verbose: + print( + f" [moe_mlp] LFM2 mode: gate={lfm2_gate}, combine={lfm2_combine}" + ) # Optional: overlap GLM/DeepSeek-style `shared_experts(x)` with routed # expert execution. 
This is experimental and may regress on some MLX @@ -888,14 +973,32 @@ def patched_call(self_mod: Any, x: Any) -> Any: shared_out = shared(x) # 1. Gating — preserve the model's original logic exactly. - indices, gate_weights = _gating( - self_mod, - x, - _gate_attr, - num_experts_per_tok, - is_qwen3=is_qwen3, - is_gpt_oss=is_gpt_oss, - ) + if is_lfm2 and lfm2_gate == "native": + # LFM2 native gate: reproduce exact original MLX ops + # to guarantee fidelity (no custom kernel). + gate_fn = getattr(self_mod, _gate_attr) + gate_out = gate_fn(x).astype(mx.float32) + gate_out = mx.softmax(gate_out, axis=-1) + _eb = getattr(self_mod, "expert_bias", None) + if _eb is not None: + gate_out = gate_out + _eb + _k = num_experts_per_tok + inds = mx.argpartition(gate_out, kth=-_k, axis=-1)[..., -_k:] + scores = mx.take_along_axis(gate_out, inds, axis=-1) + _ntp = getattr(self_mod, "norm_topk_prob", False) + if _ntp: + scores = scores / (mx.sum(scores, axis=-1, keepdims=True) + 1e-20) + scores = scores.astype(x.dtype) + indices, gate_weights = inds, scores + else: + indices, gate_weights = _gating( + self_mod, + x, + _gate_attr, + num_experts_per_tok, + is_qwen3=is_qwen3, + is_gpt_oss=is_gpt_oss, + ) # 2. 
Expert Execution expert_outputs = None @@ -921,7 +1024,7 @@ def patched_call(self_mod: Any, x: Any) -> Any: require_fp32=use_exact_combine, vectorized_k=qwen_fused_downproj_combine_kvec, ) - elif is_lfm2: + elif is_lfm2 and _use_fused_swiglu: fused_out = _try_fused_downproj_combine( switch_mlp, x, @@ -958,7 +1061,7 @@ def patched_call(self_mod: Any, x: Any) -> Any: for i, expert in enumerate(experts): for k in range(K): mask = indices[:, k] == i - if mask.any(): + if mask.any(): # type: ignore[union-attr] expert_outputs[mask, k] = expert(x[mask]) else: stream_count = len(moe_stream_pool) @@ -1000,6 +1103,22 @@ def patched_call(self_mod: Any, x: Any) -> Any: ) y = moe.moe_combine_fp32_no_fma(expert_outputs, gw) y = y.astype(expert_outputs.dtype) + elif glm_combine_mode == "fp32_no_fma_row_tile": + gw = ( + gate_weights + if gate_weights.dtype == mx.float32 + else gate_weights.astype(mx.float32) + ) + y = moe.moe_combine_fp32_no_fma_row_tile(expert_outputs, gw) + y = y.astype(expert_outputs.dtype) + elif glm_combine_mode == "fp32_no_fma_weighted_sum": + gw = ( + gate_weights + if gate_weights.dtype == mx.float32 + else gate_weights.astype(mx.float32) + ) + y = moe.moe_combine_weighted_sum_fp32_no_fma(expert_outputs, gw) + y = y.astype(expert_outputs.dtype) elif glm_combine_mode == "exact": gw = ( gate_weights @@ -1041,6 +1160,20 @@ def patched_call(self_mod: Any, x: Any) -> Any: y = moe.moe_combine_exact(expert_outputs, gw) else: y = (expert_outputs * gate_weights[..., None]).sum(axis=-2) + elif is_lfm2: + # LFM2-specific combine mode selection. 
+ if lfm2_combine == "native": + y = (expert_outputs * gate_weights[..., None]).sum(axis=-2) + elif lfm2_combine == "fp32_no_fma": + gw = ( + gate_weights + if gate_weights.dtype == mx.float32 + else gate_weights.astype(mx.float32) + ) + y = moe.moe_combine_fp32_no_fma(expert_outputs, gw) + y = y.astype(expert_outputs.dtype) + else: + y = moe.moe_combine(expert_outputs, gate_weights) elif is_gpt_oss: # GPT-OSS: match MLX's dtype promotion and sum order. if gate_weights.dtype == mx.float32: diff --git a/src/zmlx/validate.py b/src/zmlx/validate.py index bc4502d..a262527 100644 --- a/src/zmlx/validate.py +++ b/src/zmlx/validate.py @@ -58,11 +58,11 @@ def peak_mem_gb(self) -> float: # Helpers # --------------------------------------------------------------------------- def _clear_gpu() -> None: + mx.synchronize() + gc.collect() + gc.collect() + mx.clear_cache() gc.collect() - if hasattr(mx, "clear_cache"): - mx.clear_cache() - elif hasattr(mx.metal, "clear_cache"): - mx.metal.clear_cache() def _generate_greedy( diff --git a/tests/test_fused_ops.py b/tests/test_fused_ops.py index c9de3d3..1807a7e 100644 --- a/tests/test_fused_ops.py +++ b/tests/test_fused_ops.py @@ -361,6 +361,57 @@ def test_moe_combine_no_fma_exact(self): mx.eval(ref, out) assert mx.array_equal(ref, out) + def test_moe_combine_fp32_no_fma_row_tile_matches_base(self): + """Row-tiled fp32/no-FMA variant should match current fp32/no-FMA combine.""" + from zmlx.kernels.moe import ( + moe_combine_fp32_no_fma, + moe_combine_fp32_no_fma_row_tile, + ) + + mx.random.seed(3) + B, K, D = 4, 4, 3072 + expert_outputs = mx.random.normal((B, K, D)).astype(mx.float16) + weights = mx.softmax(mx.random.normal((B, K)), axis=-1).astype(mx.float32) + + ref = moe_combine_fp32_no_fma(expert_outputs, weights) + got = moe_combine_fp32_no_fma_row_tile(expert_outputs, weights) + mx.eval(ref, got) + assert_allclose(got, ref, rtol=1e-6, atol=1e-6) + + def test_moe_combine_weighted_sum_fp32_no_fma_matches_base_for_k8(self): + 
"""K=8 weighted-sum specialization should match fp32/no-FMA reference.""" + from zmlx.kernels.moe import ( + moe_combine_fp32_no_fma, + moe_combine_weighted_sum_fp32_no_fma, + ) + + mx.random.seed(7) + B, K, D = 2, 8, 4096 + expert_outputs = mx.random.normal((B, K, D)).astype(mx.float16) + weights = mx.softmax(mx.random.normal((B, K)), axis=-1).astype(mx.float32) + + ref = moe_combine_fp32_no_fma(expert_outputs, weights) + got = moe_combine_weighted_sum_fp32_no_fma(expert_outputs, weights) + mx.eval(ref, got) + assert_allclose(got, ref, rtol=1e-6, atol=1e-6) + + def test_moe_combine_weighted_sum_fp32_no_fma_falls_back_when_not_k8(self): + """Weighted-sum specialization should fall back safely when K != 8.""" + from zmlx.kernels.moe import ( + moe_combine_fp32_no_fma, + moe_combine_weighted_sum_fp32_no_fma, + ) + + mx.random.seed(11) + B, K, D = 3, 4, 1536 + expert_outputs = mx.random.normal((B, K, D)).astype(mx.float16) + weights = mx.softmax(mx.random.normal((B, K)), axis=-1).astype(mx.float32) + + ref = moe_combine_fp32_no_fma(expert_outputs, weights) + got = moe_combine_weighted_sum_fp32_no_fma(expert_outputs, weights) + mx.eval(ref, got) + assert_allclose(got, ref, rtol=1e-6, atol=1e-6) + @pytest.mark.metal class TestMoEDispatch: diff --git a/tests/test_moe_fused_swiglu_gate.py b/tests/test_moe_fused_swiglu_gate.py index e6c9b3f..abcd909 100644 --- a/tests/test_moe_fused_swiglu_gate.py +++ b/tests/test_moe_fused_swiglu_gate.py @@ -112,6 +112,24 @@ def test_glm_combine_mode_default_fp32_no_fma(monkeypatch: pytest.MonkeyPatch): assert moe_mlp._glm_combine_mode() == "fp32_no_fma" +@pytest.mark.cpu +def test_glm_combine_mode_off(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("ZMLX_GLM_COMBINE_MODE", "off") + assert moe_mlp._glm_combine_mode() == "off" + + +@pytest.mark.cpu +def test_glm_combine_mode_row_tile(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("ZMLX_GLM_COMBINE_MODE", "fp32_no_fma_row_tile") + assert moe_mlp._glm_combine_mode() == 
"fp32_no_fma_row_tile" + + +@pytest.mark.cpu +def test_glm_combine_mode_weighted_sum(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("ZMLX_GLM_COMBINE_MODE", "fp32_no_fma_weighted_sum") + assert moe_mlp._glm_combine_mode() == "fp32_no_fma_weighted_sum" + + @pytest.mark.cpu def test_glm_combine_mode_valid(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("ZMLX_GLM_COMBINE_MODE", "exact")