diff --git a/.gitignore b/.gitignore index e116b6cc..d57e68b8 100644 --- a/.gitignore +++ b/.gitignore @@ -38,3 +38,5 @@ test-results/ # PM agent persistent memory .pm/ /companion/target/ +# ONNX models (too large for git, download via script) +public/models/*.onnx diff --git a/.llm/research/bpm-chord-detection-research.md b/.llm/research/bpm-chord-detection-research.md new file mode 100644 index 00000000..b38b4ab2 --- /dev/null +++ b/.llm/research/bpm-chord-detection-research.md @@ -0,0 +1,218 @@ +# BPM Detection & Chord Recognition — Research Report + +> Date: 2026-03-27 | For: ACE-Step-DAW web-local inference + +--- + +## Executive Summary + +**Best deployment strategy: ONNX Runtime Web (WASM + WebGPU) in a Web Worker** + +- BPM: **Beat This! (small, 8MB ONNX)** — current SOTA, C++ port exists with ONNX export +- Chord: **consonance-ACE (ISMIR 2025)** — decomposed conformer, SOTA chord estimation +- Feature extraction: **Essentia.js** (WASM) or **Rust (rustfft + mel-spec → WASM)** +- Fallback/lightweight: **Essentia.js** has built-in BPM + chord detection (DSP-based, lower accuracy) + +NOT recommended: C++ or Rust standalone — the models are Python/PyTorch, so native compilation means reimplementing inference. ONNX is the universal bridge. + +--- + +## Part 1: BPM Detection + +### Model Comparison + +| Model | Accuracy | Size | Real-time | Web-ready | License | +|-------|----------|------|-----------|-----------|---------| +| **Beat This! (CPJKU, 2024)** | SOTA (best F1 across 16 datasets) | 8MB (small) / 97MB (full) | Offline | ONNX export exists via beat_this_cpp | MIT | +| Madmom (CPJKU) | Very high | Large (multiple RNNs) | No | No WASM path | BSD | +| BeatNet+ | High (best online method) | Medium | Yes (<50ms) | Needs ONNX export | MIT | +| Essentia RhythmExtractor | Good | ~2.5MB WASM | Yes | **essentia.js ready** | AGPL-3.0 | +| Aubio | Moderate | <1MB | Yes | **aubiojs ready** | GPL-3.0 | + +### Recommendation: Beat This! (small variant) + +- **Why**: SOTA accuracy without needing DBN post-processing; 8MB small model is web-friendly +- **How**: C++ port already exists at [mosynthkey/beat_this_cpp](https://github.com/mosynthkey/beat_this_cpp) + - Uses ONNX Runtime for inference + - 97MB ONNX model (full) — small variant is ~8MB + - Pipeline: audio → mel spectrogram → transformer → beat/downbeat positions +- **Web path**: Export small model to ONNX → run via `onnxruntime-web` (WASM or WebGPU) + +### Fallback: Essentia.js + +- Already works in browser, zero additional work +- `RhythmExtractor2013` gives BPM + beat positions +- Lower accuracy but production-ready today + +--- + +## Part 2: Chord Recognition + +### Model Comparison + +| Model | Accuracy | Vocabulary | Architecture | Web Path | License | +|-------|----------|------------|--------------|----------|---------| +| **consonance-ACE (ISMIR 2025)** | SOTA | 170 classes | Conformer (decomposed) | PyTorch → ONNX → ort-web | MIT | +| BTC | ~80-86% majmin | 25 classes | Transformer | PyTorch → ONNX → ort-web | — | +| CREMA | ~75-80% | 602 classes | CNN+RNN | TF → ONNX → ort-web | — | +| Chordino | ~70-75% | maj/min/7th | NNLS+HMM | C++ → WASM | GPL | +| Essentia ChordsDetection | ~65-70% | maj/min | HPCP+template | **essentia.js ready** | AGPL | + +### Recommendation: consonance-ACE + +- **Why**: ISMIR 2025 SOTA, decomposed output (root + bass + pitch activations), 170 chord vocabulary, MIT license +- **Repo**: [andreamust/consonance-ACE](https://github.com/andreamust/consonance-ACE) +- **Paper**: [arxiv.org/abs/2509.01588](https://arxiv.org/abs/2509.01588) +- **Architecture**: Conformer with decomposed heads — separately estimates root, bass, and note activations, then reconstructs chord labels +- **Key innovation**: Consonance-based label smoothing handles annotator subjectivity and class imbalance +- **Input**: Audio (WAV) → 20s chunks +- **Output**: `.lab` format (start_time, end_time, chord_label e.g. `E:maj`) +- **Pipeline**: audio → conformer → decomposed heads (root/bass/notes) → chord label +- **Web path**: PyTorch checkpoint → `torch.onnx.export()` → ONNX → `onnxruntime-web` +- **Training data**: Isophonics, McGill Billboard (via ChoCo corpus) + +### Beat-synchronous chord detection (DAW integration) + +1. Run BPM/beat detection first (Beat This!) +2. Segment audio at beat boundaries +3. Run consonance-ACE per segment (or on full audio, then snap to beats) +4. Post-processing: merge short segments, snap chord changes to nearest beat/bar +5. Output: chord track aligned to DAW grid + +### Fallback: Essentia.js ChordsDetectionBeats + +- Already works in browser +- Beat-synchronous chord detection built-in +- Lower accuracy (~65-70%) but zero integration work + +--- + +## Part 3: Web Deployment Architecture + +### Recommended Stack + +``` +┌─────────────────────────────────────────────────┐ +│ Main Thread (React) │ +│ - UI rendering │ +│ - Receives results via postMessage │ +└─────────────┬───────────────────────────────────┘ + │ postMessage(audioBuffer) + ▼ +┌─────────────────────────────────────────────────┐ +│ Web Worker │ +│ │ +│ ┌─────────────────────────────────────────────┐ │ +│ │ Feature Extraction (WASM) │ │ +│ │ Option A: essentia.js (C++ → WASM) │ │ +│ │ Option B: Rust (rustfft + mel-spec → WASM) │ │ +│ │ - Mel spectrogram for BPM model │ │ +│ │ - CQT / chromagram for chord model │ │ +│ └──────────────┬──────────────────────────────┘ │ +│ │ Float32Array │ +│ ▼ │ +│ ┌─────────────────────────────────────────────┐ │ +│ │ Model Inference │ │ +│ │ onnxruntime-web (WASM CPU or WebGPU) │ │ +│ │ - Beat This! small (8MB ONNX) → beats/BPM │ │ +│ │ - consonance-ACE (ONNX) → chord labels │ │ +│ └──────────────┬──────────────────────────────┘ │ +│ │ results │ +│ ▼ │ +│ Post-processing: Viterbi smoothing, beat snap │ +└─────────────────────────────────────────────────┘ +``` + +### Runtime Options Comparison + +| Runtime | Pros | Cons | Best For | +|---------|------|------|----------| +| **onnxruntime-web** | Best operator coverage, INT8 quant, WebGPU support | Larger WASM binary (~5MB) | Production deployment | +| Tract (Rust → WASM) | Pure Rust, single binary, lightweight | Less operator coverage | Simple models | +| Candle (HF Rust) | Self-contained WASM, proven with Whisper | Need to reimplement model in Candle | Custom models | +| TensorFlow.js WASM | Mature ecosystem | Heavier, ecosystem moving to ONNX | Legacy TF models | + +### Performance Expectations + +| Operation | Latency (WASM, M2 MacBook) | +|-----------|---------------------------| +| Mel spectrogram (5s clip) | 5-15ms | +| ONNX model inference (small CNN) | 8-12ms | +| ONNX model inference (transformer, 8MB) | 50-200ms | +| WebGPU inference (same transformer) | 5-20ms | +| Total pipeline (5s clip → BPM + chords) | ~100-500ms (WASM) / ~30-100ms (WebGPU) | + +--- + +## Part 4: Implementation Roadmap + +### Phase 1: Quick Win (essentia.js) +- Install `essentia.js` npm package +- Use `RhythmExtractor2013` for BPM + beats +- Use `ChordsDetectionBeats` for beat-synced chords +- Run in Web Worker +- Accuracy: ~70% for both — usable but not great +- **Effort: 1-2 days** + +### Phase 2: High-Accuracy BPM (Beat This! ONNX) +- Clone [beat_this_cpp](https://github.com/mosynthkey/beat_this_cpp), get ONNX model +- Use small variant (8MB) or quantize full model to INT8 +- Implement mel spectrogram in WASM (essentia.js or Rust `mel-spec`) +- Run ONNX inference via `onnxruntime-web` +- Beat-synced output → snap to DAW grid +- **Effort: 3-5 days** + +### Phase 3: High-Accuracy Chords (consonance-ACE ONNX) +- Export consonance-ACE conformer_decomposed model to ONNX from PyTorch +- Model outputs decomposed root/bass/note activations → reconstruct chord labels +- Run ONNX inference via `onnxruntime-web` +- Use beat positions from Phase 2 for beat-synchronous snapping +- 170 chord vocabulary — rich enough for DAW display + +### Phase 4: Optimization +- INT8 quantization of both models (2-3x faster WASM) +- WebGPU acceleration for devices that support it +- Streaming/chunked analysis for long files +- Cache results in IndexedDB + +--- + +## Part 5: Key Repos & Links + +### BPM +- Beat This!: https://github.com/CPJKU/beat_this (Python) | https://github.com/mosynthkey/beat_this_cpp (C++ ONNX) +- BeatNet: https://github.com/mjhydri/BeatNet +- Essentia.js: https://github.com/mtg/essentia.js/ + +### Chords +- consonance-ACE: https://github.com/andreamust/consonance-ACE (ISMIR 2025, MIT) +- BTC: https://github.com/jayg996/BTC-ISMIR19 +- CREMA: https://github.com/bmcfee/crema + +### Inference Runtimes +- onnxruntime-web: https://www.npmjs.com/package/onnxruntime-web +- Tract (Rust ONNX): https://github.com/sonos/tract +- Candle (Rust ML): https://github.com/huggingface/candle + +### Audio Preprocessing +- essentia.js: https://github.com/mtg/essentia.js/ +- rust-melspec-wasm: https://github.com/nicolvisser/rust-melspec-wasm +- mel_spec crate: https://crates.io/crates/mel_spec +- spectrograms crate: https://docs.rs/spectrograms/latest/spectrograms/ + +### Reference Implementations +- basicpitch.cpp (ONNX + WASM): https://github.com/sevagh/basicpitch.cpp +- Candle Whisper WASM: https://huggingface.co/spaces/lmz/candle-whisper + +--- + +## Decision: Why NOT Pure C++ or Rust? + +| Approach | Problem | +|----------|---------| +| Rewrite model in C++ | Models are defined in PyTorch; reimplementing transformer/RNN in C++ is months of work | +| Rewrite model in Rust (candle/burn) | Same problem — must port architecture + load weights | +| Compile Python + PyTorch to WASM | Not feasible | +| **Export to ONNX + run via ort-web** | **Universal bridge: any PyTorch model → ONNX → browser. This is the answer.** | + +C++ and Rust are excellent for the **preprocessing** pipeline (FFT, mel spectrogram, CQT), but for **model inference**, ONNX is the standard interchange format and ort-web is the best runtime for browsers. diff --git a/package-lock.json b/package-lock.json index 34e976df..3c709b52 100644 --- a/package-lock.json +++ b/package-lock.json @@ -32,6 +32,7 @@ "idb-keyval": "^6.2.0", "mp4-muxer": "^5.2.2", "node-pty": "^1.2.0-beta.12", + "onnxruntime-web": "^1.24.3", "react": "^19.0.0", "react-dom": "^19.0.0", "tone": "^15.1.22", @@ -1680,6 +1681,70 @@ "node": ">=18" } }, + "node_modules/@protobufjs/aspromise": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/@protobufjs/aspromise/-/aspromise-1.1.2.tgz", + "integrity": "sha512-j+gKExEuLmKwvz3OgROXtrJ2UG2x8Ch2YZUxahh+s1F2HZ+wAceUNLkvy6zKCPVRkU++ZWQrdxsUeQXmcg4uoQ==", + "license": "BSD-3-Clause" + }, + "node_modules/@protobufjs/base64": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/@protobufjs/base64/-/base64-1.1.2.tgz", + "integrity": "sha512-AZkcAA5vnN/v4PDqKyMR5lx7hZttPDgClv83E//FMNhR2TMcLUhfRUBHCmSl0oi9zMgDDqRUJkSxO3wm85+XLg==", + "license": "BSD-3-Clause" + }, + "node_modules/@protobufjs/codegen": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/@protobufjs/codegen/-/codegen-2.0.4.tgz", + "integrity": "sha512-YyFaikqM5sH0ziFZCN3xDC7zeGaB/d0IUb9CATugHWbd1FRFwWwt4ld4OYMPWu5a3Xe01mGAULCdqhMlPl29Jg==", + "license": "BSD-3-Clause" + }, + "node_modules/@protobufjs/eventemitter": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@protobufjs/eventemitter/-/eventemitter-1.1.0.tgz", + "integrity": "sha512-j9ednRT81vYJ9OfVuXG6ERSTdEL1xVsNgqpkxMsbIabzSo3goCjDIveeGv5d03om39ML71RdmrGNjG5SReBP/Q==", + "license": "BSD-3-Clause" + }, + "node_modules/@protobufjs/fetch": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@protobufjs/fetch/-/fetch-1.1.0.tgz", + "integrity": "sha512-lljVXpqXebpsijW71PZaCYeIcE5on1w5DlQy5WH6GLbFryLUrBD4932W/E2BSpfRJWseIL4v/KPgBFxDOIdKpQ==", + "license": "BSD-3-Clause", + "dependencies": { + "@protobufjs/aspromise": "^1.1.1", + "@protobufjs/inquire": "^1.1.0" + } + }, + "node_modules/@protobufjs/float": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/@protobufjs/float/-/float-1.0.2.tgz", + "integrity": "sha512-Ddb+kVXlXst9d+R9PfTIxh1EdNkgoRe5tOX6t01f1lYWOvJnSPDBlG241QLzcyPdoNTsblLUdujGSE4RzrTZGQ==", + "license": "BSD-3-Clause" + }, + "node_modules/@protobufjs/inquire": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@protobufjs/inquire/-/inquire-1.1.0.tgz", + "integrity": "sha512-kdSefcPdruJiFMVSbn801t4vFK7KB/5gd2fYvrxhuJYg8ILrmn9SKSX2tZdV6V+ksulWqS7aXjBcRXl3wHoD9Q==", + "license": "BSD-3-Clause" + }, + "node_modules/@protobufjs/path": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/@protobufjs/path/-/path-1.1.2.tgz", + "integrity": "sha512-6JOcJ5Tm08dOHAbdR3GrvP+yUUfkjG5ePsHYczMFLq3ZmMkAD98cDgcT2iA1lJ9NVwFd4tH/iSSoe44YWkltEA==", + "license": "BSD-3-Clause" + }, + "node_modules/@protobufjs/pool": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@protobufjs/pool/-/pool-1.1.0.tgz", + "integrity": "sha512-0kELaGSIDBKvcgS4zkjz1PeddatrjYcmMWOlAuAPwAeccUrPHdUqo/J6LiymHHEiJT5NrF1UVwxY14f+fy4WQw==", + "license": "BSD-3-Clause" + }, + "node_modules/@protobufjs/utf8": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@protobufjs/utf8/-/utf8-1.1.0.tgz", + "integrity": "sha512-Vvn3zZrhQZkkBE8LSuW3em98c0FwgO4nxzv6OdSxPKJIEKY2bGbHn+mhGIPerzI4twdxaP8/0+06HBpwf345Lw==", + "license": "BSD-3-Clause" + }, "node_modules/@replit/codemirror-emacs": { "version": "6.1.0", "resolved": "https://registry.npmjs.org/@replit/codemirror-emacs/-/codemirror-emacs-6.1.0.tgz", @@ -3350,7 +3415,6 @@ "version": "25.5.0", "resolved": "https://registry.npmjs.org/@types/node/-/node-25.5.0.tgz", "integrity": "sha512-jp2P3tQMSxWugkCUKLRPVUpGaL5MVFwF8RDuSRztfwgN1wmqJeMSbKlnEtQqU8UrhTmzEmZdu2I6v2dpp7XIxw==", - "dev": true, "license": "MIT", "dependencies": { "undici-types": "~7.18.0" @@ -4546,6 +4610,12 @@ "babel-plugin-add-module-exports": "^0.2.1" } }, + "node_modules/flatbuffers": { + "version": "25.9.23", + "resolved": "https://registry.npmjs.org/flatbuffers/-/flatbuffers-25.9.23.tgz", + "integrity": "sha512-MI1qs7Lo4Syw0EOzUl0xjs2lsoeqFku44KpngfIduHBYvzm8h2+7K8YMQh1JtVVVrUvhLpNwqVi4DERegUJhPQ==", + "license": "Apache-2.0" + }, "node_modules/focus-trap": { "version": "7.8.0", "resolved": "https://registry.npmjs.org/focus-trap/-/focus-trap-7.8.0.tgz", @@ -4610,6 +4680,12 @@ "dev": true, "license": "ISC" }, + "node_modules/guid-typescript": { + "version": "1.0.9", + "resolved": "https://registry.npmjs.org/guid-typescript/-/guid-typescript-1.0.9.tgz", + "integrity": "sha512-Y8T4vYhEfwJOTbouREvG+3XDsjr8E3kIr7uf+JZ0BYloFsttiHU0WfvANVsR7TxNUJa/WpCnw/Ino/p+DeBhBQ==", + "license": "ISC" + }, "node_modules/hasown": { "version": "2.0.2", "resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.2.tgz", @@ -5146,6 +5222,12 @@ "url": "https://opencollective.com/parcel" } }, + "node_modules/long": { + "version": "5.3.2", + "resolved": "https://registry.npmjs.org/long/-/long-5.3.2.tgz", + "integrity": "sha512-mNAgZ1GmyNhD7AuqnTG3/VQ26o760+ZYBPKjPvugO8+nLbYfX6TVpJPseBvopbdY+qpZ/lKUnmEc1LeZYS3QAA==", + "license": "Apache-2.0" + }, "node_modules/loose-envify": { "version": "1.4.0", "resolved": "https://registry.npmjs.org/loose-envify/-/loose-envify-1.4.0.tgz", @@ -5472,6 +5554,26 @@ "regex-recursion": "^6.0.2" } }, + "node_modules/onnxruntime-common": { + "version": "1.24.3", + "resolved": "https://registry.npmjs.org/onnxruntime-common/-/onnxruntime-common-1.24.3.tgz", + "integrity": "sha512-GeuPZO6U/LBJXvwdaqHbuUmoXiEdeCjWi/EG7Y1HNnDwJYuk6WUbNXpF6luSUY8yASul3cmUlLGrCCL1ZgVXqA==", + "license": "MIT" + }, + "node_modules/onnxruntime-web": { + "version": "1.24.3", + "resolved": "https://registry.npmjs.org/onnxruntime-web/-/onnxruntime-web-1.24.3.tgz", + "integrity": "sha512-41dDq7fxtTm0XzGE7N0d6m8FcOY8EWtUA65GkOixJPB/G7DGzBmiDAnVVXHznRw9bgUZpb+4/1lQK/PNxGpbrQ==", + "license": "MIT", + "dependencies": { + "flatbuffers": "^25.1.24", + "guid-typescript": "^1.0.9", + "long": "^5.2.3", + "onnxruntime-common": "1.24.3", + "platform": "^1.3.6", + "protobufjs": "^7.2.4" + } + }, "node_modules/parse5": { "version": "8.0.0", "resolved": "https://registry.npmjs.org/parse5/-/parse5-8.0.0.tgz", @@ -5543,6 +5645,12 @@ "url": "https://github.com/sponsors/jonschlinkert" } }, + "node_modules/platform": { + "version": "1.3.6", + "resolved": "https://registry.npmjs.org/platform/-/platform-1.3.6.tgz", + "integrity": "sha512-fnWVljUchTro6RiCFvCXBbNhJc2NijN7oIQxbwsyL0buWJPG85v81ehlHI9fXrJsMNgTofEoWIQeClKpgxFLrg==", + "license": "MIT" + }, "node_modules/playwright": { "version": "1.58.2", "resolved": "https://registry.npmjs.org/playwright/-/playwright-1.58.2.tgz", @@ -5657,6 +5765,30 @@ "url": "https://github.com/sponsors/wooorm" } }, + "node_modules/protobufjs": { + "version": "7.5.4", + "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.5.4.tgz", + "integrity": "sha512-CvexbZtbov6jW2eXAvLukXjXUW1TzFaivC46BpWc/3BpcCysb5Vffu+B3XHMm8lVEuy2Mm4XGex8hBSg1yapPg==", + "hasInstallScript": true, + "license": "BSD-3-Clause", + "dependencies": { + "@protobufjs/aspromise": "^1.1.2", + "@protobufjs/base64": "^1.1.2", + "@protobufjs/codegen": "^2.0.4", + "@protobufjs/eventemitter": "^1.1.0", + "@protobufjs/fetch": "^1.1.0", + "@protobufjs/float": "^1.0.2", + "@protobufjs/inquire": "^1.1.0", + "@protobufjs/path": "^1.1.2", + "@protobufjs/pool": "^1.1.0", + "@protobufjs/utf8": "^1.1.0", + "@types/node": ">=13.7.0", + "long": "^5.0.0" + }, + "engines": { + "node": ">=12.0.0" + } + }, "node_modules/punycode": { "version": "2.3.1", "resolved": "https://registry.npmjs.org/punycode/-/punycode-2.3.1.tgz", @@ -6310,7 +6442,6 @@ "version": "7.18.2", "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-7.18.2.tgz", "integrity": "sha512-AsuCzffGHJybSaRrmr5eHr81mwJU3kjw6M+uprWvCXiNeN9SOGwQ3Jn8jb8m3Z6izVgknn1R0FTCEAP2QrLY/w==", - "dev": true, "license": "MIT" }, "node_modules/unist-util-is": { diff --git a/package.json b/package.json index 4c075d85..72347620 100644 --- a/package.json +++ b/package.json @@ -42,6 +42,7 @@ "idb-keyval": "^6.2.0", "mp4-muxer": "^5.2.2", "node-pty": "^1.2.0-beta.12", + "onnxruntime-web": "^1.24.3", "react": "^19.0.0", "react-dom": "^19.0.0", "tone": "^15.1.22", diff --git a/scripts/download-models.sh b/scripts/download-models.sh new file mode 100755 index 00000000..38cbb48e --- /dev/null +++ b/scripts/download-models.sh @@ -0,0 +1,40 @@ +#!/usr/bin/env bash +# Download ONNX models for BPM detection and chord recognition. +# Models are too large for git (~100MB total), so they're downloaded on demand. +# +# Usage: ./scripts/download-models.sh + +set -euo pipefail + +MODELS_DIR="$(cd "$(dirname "$0")/.." && pwd)/public/models" +mkdir -p "$MODELS_DIR" + +echo "Downloading ONNX models to $MODELS_DIR..." + +# Beat This! (79MB) — CPJKU ISMIR 2024 SOTA beat/BPM detection +# Source: https://github.com/mosynthkey/beat_this_cpp +BEAT_THIS_URL="https://github.com/mosynthkey/beat_this_cpp/raw/main/onnx/beat_this.onnx" +if [ ! -f "$MODELS_DIR/beat-this.onnx" ]; then + echo "Downloading Beat This! model (79MB)..." + curl -L -o "$MODELS_DIR/beat-this.onnx" "$BEAT_THIS_URL" + echo " Done: beat-this.onnx ($(du -h "$MODELS_DIR/beat-this.onnx" | cut -f1))" +else + echo " beat-this.onnx already exists, skipping" +fi + +# consonance-ACE — must be exported from PyTorch checkpoint +# The ONNX file should already exist if you ran the export script. +if [ ! -f "$MODELS_DIR/consonance-ace.onnx" ]; then + echo "" + echo "consonance-ace.onnx not found." + echo "To export it, run:" + echo " python scripts/export-consonance-ace.py" + echo "" + echo "Or download from the project's release assets (if available)." +else + echo " consonance-ace.onnx already exists, skipping" +fi + +echo "" +echo "Model files:" +ls -lh "$MODELS_DIR"/*.onnx 2>/dev/null || echo " (none found)" diff --git a/scripts/export-consonance-ace.py b/scripts/export-consonance-ace.py new file mode 100644 index 00000000..7df26b7e --- /dev/null +++ b/scripts/export-consonance-ace.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 +""" +Export consonance-ACE conformer_decomposed model to ONNX. + +Prerequisites: + pip install torch torchaudio librosa lightning gin-config torchmetrics + + git clone https://github.com/andreamust/consonance-ACE /tmp/consonance-ACE + cd /tmp/consonance-ACE && pip install -r requirements.txt + +Usage: + python scripts/export-consonance-ace.py + +Output: + public/models/consonance-ace.onnx (~20MB) +""" + +import sys +import os + +# Add consonance-ACE to path +ACE_REPO = "/tmp/consonance-ACE" +if not os.path.isdir(ACE_REPO): + print(f"ERROR: Clone consonance-ACE first:") + print(f" git clone https://github.com/andreamust/consonance-ACE {ACE_REPO}") + sys.exit(1) + +sys.path.insert(0, ACE_REPO) + +import torch +import numpy as np + +from ACE.models.conformer_decomposed import ConformerDecomposedModel + + +class ACEWrapper(torch.nn.Module): + """Wraps ConformerDecomposedModel to return tuple (ONNX-compatible).""" + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, x): + out = self.model(x) + return out["root"], out["bass"], out["onehot"] + + +def main(): + ckpt = os.path.join(ACE_REPO, "ACE/checkpoints/conformer_decomposed_smooth.ckpt") + if not os.path.exists(ckpt): + print(f"ERROR: Checkpoint not found: {ckpt}") + sys.exit(1) + + print(f"Loading model from {ckpt}...") + model = ConformerDecomposedModel.load_from_checkpoint( + ckpt, + vocabularies={"root": 13, "bass": 13, "onehot": 12}, + map_location="cpu", + loss="consonance_decomposed", + vocab_path=os.path.join(ACE_REPO, "ACE/chords_vocab.joblib"), + strict=False, + ) + model.eval() + print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}") + + wrapper = ACEWrapper(model) + wrapper.eval() + + # Dummy input: [batch=1, channels=1, freq=144, time=862] + # 862 frames = 20s at sr=22050, hop=512 + dummy = torch.randn(1, 1, 144, 862) + + output_dir = os.path.join(os.path.dirname(__file__), "..", "public", "models") + os.makedirs(output_dir, exist_ok=True) + output_path = os.path.join(output_dir, "consonance-ace.onnx") + + print(f"Exporting to {output_path}...") + torch.onnx.export( + wrapper, dummy, output_path, + input_names=["cqt_features"], + output_names=["root_logits", "bass_logits", "chord_logits"], + dynamic_axes={ + "cqt_features": {3: "n_frames"}, + "root_logits": {1: "n_frames"}, + "bass_logits": {1: "n_frames"}, + "chord_logits": {1: "n_frames"}, + }, + opset_version=17, + do_constant_folding=True, + ) + + size_mb = os.path.getsize(output_path) / 1024 / 1024 + print(f"Exported: {output_path} ({size_mb:.1f} MB)") + + # Verify + import onnxruntime as ort + sess = ort.InferenceSession(output_path) + result = sess.run(None, {"cqt_features": dummy.numpy()}) + + with torch.no_grad(): + pt_r, pt_b, pt_c = wrapper(dummy) + + diff = max( + np.abs(result[0] - pt_r.numpy()).max(), + np.abs(result[1] - pt_b.numpy()).max(), + np.abs(result[2] - pt_c.numpy()).max(), + ) + print(f"Max PyTorch vs ONNX diff: {diff:.6f}") + print("PASS" if diff < 0.001 else "WARN: large diff") + + +if __name__ == "__main__": + main() diff --git a/scripts/verify-onnx-models.py b/scripts/verify-onnx-models.py new file mode 100644 index 00000000..9aee967e --- /dev/null +++ b/scripts/verify-onnx-models.py @@ -0,0 +1,285 @@ +#!/usr/bin/env python3 +""" +Verify ONNX model inference for Beat This! (BPM) and consonance-ACE (chords). + +Usage: + python scripts/verify-onnx-models.py [path/to/audio.wav] + +Validates: + 1. Models load correctly in ONNX Runtime + 2. Input/output shapes match expectations + 3. Beat detection produces reasonable BPM (30-300 range) + 4. Chord detection produces valid chord labels + 5. Processing time is within acceptable range for web deployment +""" + +import sys +import time +import os + +import numpy as np +import onnxruntime as ort +import librosa + +MODELS_DIR = os.path.join(os.path.dirname(__file__), "..", "public", "models") +BEAT_THIS_PATH = os.path.join(MODELS_DIR, "beat-this.onnx") +CONSONANCE_ACE_PATH = os.path.join(MODELS_DIR, "consonance-ace.onnx") + +# Constants matching the model training configs +BEAT_THIS_SR = 22050 +BEAT_THIS_N_FFT = 2048 +BEAT_THIS_HOP = 441 # 20ms hop @ 22050 +BEAT_THIS_N_MELS = 128 + +ACE_SR = 22050 +ACE_HOP = 512 +ACE_N_BINS = 144 # CQT bins + +# Chord label maps +ROOT_LABELS = ["N", "C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"] +PITCH_CLASSES = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"] + + +def compute_mel_spectrogram(audio: np.ndarray, sr: int) -> np.ndarray: + """Compute log-mel spectrogram using Beat This! official preprocessing. + + Uses torchaudio MelSpectrogram with exact config: + n_fft=1024, hop=441, f_min=30, f_max=11000, n_mels=128, + mel_scale=slaney, normalized=frame_length, power=1, + output = log1p(1000 * mel).T -> [time, freq] + """ + try: + # Use official Beat This! preprocessing if available + import torch + from beat_this.preprocessing import LogMelSpect + audio_t = torch.from_numpy(audio).float() + spect = LogMelSpect(sample_rate=sr, device="cpu") + with torch.no_grad(): + mel = spect(audio_t) # [T, 128] + return mel.numpy()[np.newaxis, :, :] # [1, T, 128] + except ImportError: + # Fallback: approximate with librosa + import warnings + warnings.warn("beat_this not installed, using librosa approximation for mel spectrogram") + mel = librosa.feature.melspectrogram( + y=audio, sr=sr, + n_fft=1024, hop_length=441, n_mels=128, + fmin=30, fmax=11000, power=1, + norm="slaney", htk=False, + ) + log_mel = np.log1p(1000.0 * mel) + return log_mel.T[np.newaxis, :, :] # [1, T, 128] + + +def compute_cqt(audio: np.ndarray, sr: int) -> np.ndarray: + """Compute CQT features matching consonance-ACE CQTransform exactly. + + Config from ACE/preprocess/transforms.py: + sr=22050, hop=512, bins_per_octave=24, num_octaves=6, start_note=C1 + Output = abs(cqt) (raw magnitude, NOT dB) + Audio is normalized to [-1, 1] before CQT. + """ + # Normalize audio to [-1, 1] (matching AudioProcessor._normalize) + max_val = np.abs(audio).max() + if max_val > 0: + audio = audio / max_val + + fmin = librosa.note_to_hz("C1") + cqt = librosa.cqt( + y=audio, sr=sr, hop_length=ACE_HOP, + n_bins=ACE_N_BINS, bins_per_octave=24, + fmin=fmin, + ) + cqt_mag = np.abs(cqt) + # Model expects [batch, 1, freq, time] — raw magnitude + return cqt_mag[np.newaxis, np.newaxis, :, :].astype(np.float32) # [1, 1, 144, T] + + +def verify_beat_this(audio: np.ndarray, sr: int): + """Verify Beat This! ONNX model.""" + print("\n" + "=" * 60) + print("BEAT THIS! — BPM Detection Verification") + print("=" * 60) + + if not os.path.exists(BEAT_THIS_PATH): + print(f" SKIP: {BEAT_THIS_PATH} not found") + return False + + sess = ort.InferenceSession(BEAT_THIS_PATH) + inp = sess.get_inputs()[0] + print(f" Model input: {inp.name}, shape={inp.shape}, type={inp.type}") + for o in sess.get_outputs(): + print(f" Model output: {o.name}, shape={o.shape}") + + # Compute mel spectrogram + mel = compute_mel_spectrogram(audio, sr) + print(f" Mel spectrogram: shape={mel.shape}, range=[{mel.min():.2f}, {mel.max():.2f}]") + + # Run inference + t0 = time.time() + beat_logits, downbeat_logits = sess.run(None, {inp.name: mel.astype(np.float32)}) + elapsed = time.time() - t0 + print(f" Inference time: {elapsed * 1000:.0f}ms") + + print(f" Beat logits: shape={beat_logits.shape}, range=[{beat_logits.min():.3f}, {beat_logits.max():.3f}]") + print(f" Downbeat logits: shape={downbeat_logits.shape}") + + # Post-processing: local max-pool peak picking (matching Beat This! minimal postprocessor) + # 1. max_pool1d with kernel=7 (±70ms at 50fps) to find local maxima + # 2. Keep peaks where logit > 0 (probability > 0.5) + def peak_pick(logits_1d: np.ndarray, kernel: int = 7) -> np.ndarray: + """Pick local maxima from logits, matching Beat This! postprocessor.""" + from scipy.ndimage import maximum_filter1d + maxpool = maximum_filter1d(logits_1d, size=kernel, mode='constant', cval=-1000) + peaks = (logits_1d == maxpool) & (logits_1d > 0) + return np.where(peaks)[0] + + beat_frames = peak_pick(beat_logits[0]) + downbeat_frames = peak_pick(downbeat_logits[0]) + + # Convert frames to time (hop=441 @ 22050Hz = 20ms per frame) + frame_duration = 441.0 / BEAT_THIS_SR # 0.02s per frame + beat_times = beat_frames * frame_duration + downbeat_times = downbeat_frames * frame_duration + + print(f" Detected {len(beat_times)} beats, {len(downbeat_times)} downbeats") + + if len(beat_times) >= 2: + # Compute BPM from inter-beat intervals + ibis = np.diff(beat_times) + median_ibi = np.median(ibis) + bpm = 60.0 / median_ibi if median_ibi > 0 else 0 + print(f" Estimated BPM: {bpm:.1f}") + print(f" First 10 beat times (s): {beat_times[:10].round(2).tolist()}") + print(f" First 5 downbeat times (s): {downbeat_times[:5].round(2).tolist()}") + + # Sanity checks + ok = True + if bpm < 30 or bpm > 300: + print(f" WARN: BPM {bpm:.1f} outside expected range [30, 300]") + ok = False + if len(beat_times) < 4: + print(f" WARN: Too few beats detected ({len(beat_times)})") + ok = False + + if ok: + print(" PASS: Beat detection looks correct") + return ok + else: + print(" FAIL: Fewer than 2 beats detected") + return False + + +def verify_consonance_ace(audio: np.ndarray, sr: int): + """Verify consonance-ACE ONNX model.""" + print("\n" + "=" * 60) + print("CONSONANCE-ACE — Chord Recognition Verification") + print("=" * 60) + + if not os.path.exists(CONSONANCE_ACE_PATH): + print(f" SKIP: {CONSONANCE_ACE_PATH} not found") + return False + + sess = ort.InferenceSession(CONSONANCE_ACE_PATH) + inp = sess.get_inputs()[0] + print(f" Model input: {inp.name}, shape={inp.shape}, type={inp.type}") + for o in sess.get_outputs(): + print(f" Model output: {o.name}, shape={o.shape}") + + # Compute CQT — process in 20s chunks like the original + chunk_dur = 20.0 + n_samples = int(chunk_dur * sr) + audio_chunk = audio[:n_samples] + if len(audio_chunk) < n_samples: + audio_chunk = np.pad(audio_chunk, (0, n_samples - len(audio_chunk))) + + cqt = compute_cqt(audio_chunk, sr) + print(f" CQT features: shape={cqt.shape}, range=[{cqt.min():.2f}, {cqt.max():.2f}]") + + # Run inference + t0 = time.time() + root_logits, bass_logits, chord_logits = sess.run(None, {inp.name: cqt.astype(np.float32)}) + elapsed = time.time() - t0 + print(f" Inference time: {elapsed * 1000:.0f}ms") + + print(f" Root logits: shape={root_logits.shape}") + print(f" Bass logits: shape={bass_logits.shape}") + print(f" Chord logits: shape={chord_logits.shape}") + + # Decode predictions + root_preds = np.argmax(root_logits[0], axis=-1) # [T] + bass_preds = np.argmax(bass_logits[0], axis=-1) # [T] + chord_probs = 1 / (1 + np.exp(-chord_logits[0])) # sigmoid -> [T, 12] + + n_frames = root_preds.shape[0] + frame_dur = chunk_dur / n_frames + print(f" {n_frames} frames, {frame_dur * 1000:.1f}ms per frame") + + # Sample chord labels at 1-second intervals + print("\n Chord timeline (every 1s):") + for sec in range(min(int(chunk_dur), 20)): + frame_idx = int(sec / frame_dur) + if frame_idx >= n_frames: + break + root = ROOT_LABELS[root_preds[frame_idx]] + bass = ROOT_LABELS[bass_preds[frame_idx]] + active_notes = np.where(chord_probs[frame_idx] > 0.5)[0] + notes_str = ",".join([PITCH_CLASSES[n] for n in active_notes]) if len(active_notes) > 0 else "none" + chord_label = f"{root}" if root != "N" else "N" + print(f" {sec:2d}s: root={root:>2s} bass={bass:>2s} notes=[{notes_str}] -> {chord_label}") + + # Sanity checks + ok = True + unique_roots = len(set(root_preds.tolist())) + if unique_roots < 2: + print(f"\n WARN: Only {unique_roots} unique root predictions (model may not be discriminating)") + + # Check that not all predictions are "N" (no chord) + n_ratio = np.mean(root_preds == 0) + if n_ratio > 0.95: + print(f" WARN: {n_ratio * 100:.0f}% of frames predicted as 'N' (no chord)") + ok = False + + if ok: + print("\n PASS: Chord detection looks correct") + return ok + + +def main(): + if len(sys.argv) > 1: + audio_path = sys.argv[1] + else: + # Use a default test file + audio_path = "/Users/gongjunmin/timedomain/nanoclaw/groups/main/funk_rock_groove.mp3" + + if not os.path.exists(audio_path): + print(f"Audio file not found: {audio_path}") + print("Usage: python scripts/verify-onnx-models.py [path/to/audio.wav]") + sys.exit(1) + + print(f"Loading audio: {audio_path}") + audio, sr = librosa.load(audio_path, sr=22050, mono=True) + duration = len(audio) / sr + print(f" Duration: {duration:.1f}s, SR: {sr}, Samples: {len(audio)}") + + results = {} + results["beat_this"] = verify_beat_this(audio, sr) + results["consonance_ace"] = verify_consonance_ace(audio, sr) + + print("\n" + "=" * 60) + print("SUMMARY") + print("=" * 60) + for name, ok in results.items(): + status = "PASS" if ok else "FAIL" + print(f" {name}: {status}") + + if all(results.values()): + print("\nAll models verified successfully!") + sys.exit(0) + else: + print("\nSome models failed verification.") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/src/components/generation/AudioAnalysisPanel.tsx b/src/components/generation/AudioAnalysisPanel.tsx index 60a89ad5..c0b333c9 100644 --- a/src/components/generation/AudioAnalysisPanel.tsx +++ b/src/components/generation/AudioAnalysisPanel.tsx @@ -2,12 +2,17 @@ import { useState, useEffect, useCallback } from 'react'; import { useProjectStore } from '../../store/projectStore'; import { useUIStore } from '../../store/uiStore'; import { useGenerationStore } from '../../store/generationStore'; +import { useAnalysisStore } from '../../store/analysisStore'; import * as api from '../../services/aceStepApi'; import { loadAudioBlobByKey } from '../../services/audioFileManager'; +import { analyzeClipLocally } from '../../services/localAnalysisService'; import type { TaskResultItem } from '../../types/api'; +import type { LocalAnalysisResult, ChordEvent } from '../../types/analysis'; import { POLL_INTERVAL_MS, MAX_POLL_DURATION_MS } from '../../constants/defaults'; -interface AnalysisResult { +type AnalysisMode = 'local' | 'server'; + +interface ServerAnalysisResult { bpm: number | undefined; keyScale: string | undefined; timeSignature: string | undefined; @@ -25,14 +30,22 @@ export function AudioAnalysisPanel() { const clip = analysisClipId ? getClipById(analysisClipId) : null; const track = project?.tracks.find((t) => t.clips.some((c) => c.id === analysisClipId)) ?? null; + const [mode, setMode] = useState('local'); const [analyzing, setAnalyzing] = useState(false); - const [result, setResult] = useState(null); + const [serverResult, setServerResult] = useState(null); + const [localResult, setLocalResult] = useState(null); const [error, setError] = useState(''); const [applied, setApplied] = useState(false); + // Local analysis progress from store + const analysisJob = useAnalysisStore((s) => + analysisClipId ? s.getJobForClip(analysisClipId) : undefined, + ); + // Reset when clip changes useEffect(() => { - setResult(null); + setServerResult(null); + setLocalResult(null); setError(''); setApplied(false); setAnalyzing(false); @@ -48,14 +61,31 @@ export function AudioAnalysisPanel() { return () => window.removeEventListener('keydown', handleEsc); }, [onClose]); - const handleAnalyze = useCallback(async () => { + // ---------- Local analysis ---------- + const handleLocalAnalyze = useCallback(async () => { + if (!clip || !analysisClipId || analyzing) return; + setAnalyzing(true); + setError(''); + setLocalResult(null); + + try { + const result = await analyzeClipLocally(analysisClipId); + setLocalResult(result); + } catch (e) { + setError(e instanceof Error ? e.message : 'Local analysis failed'); + } finally { + setAnalyzing(false); + } + }, [clip, analysisClipId, analyzing]); + + // ---------- Server analysis ---------- + const handleServerAnalyze = useCallback(async () => { if (!clip || analyzing || isGenerating) return; setAnalyzing(true); setError(''); - setResult(null); + setServerResult(null); try { - // Load clip audio let audioBlob: Blob | null = null; if (clip.isolatedAudioKey) { audioBlob = (await loadAudioBlobByKey(clip.isolatedAudioKey)) ?? null; @@ -68,15 +98,13 @@ export function AudioAnalysisPanel() { return; } - // Send as a cover task with minimal transformation — we just want the metas back. - // The cover task returns inferred BPM, key, etc. in the result metas. const coverParams = { task_type: 'cover' as const, caption: 'analyze audio properties', lyrics: '', - audio_cover_strength: 0.0, // No transformation — just analyze + audio_cover_strength: 0.0, audio_duration: clip.duration, - inference_steps: 10, // Minimal steps for fast analysis + inference_steps: 10, guidance_scale: 1.0, shift: 1.0, batch_size: 1, @@ -99,7 +127,7 @@ export function AudioAnalysisPanel() { const items: TaskResultItem[] = JSON.parse(entry.result); const first = items?.[0]; if (first) { - setResult({ + setServerResult({ bpm: first.metas?.bpm, keyScale: first.metas?.keyscale, timeSignature: first.metas?.timesignature, @@ -123,30 +151,39 @@ export function AudioAnalysisPanel() { } }, [clip, analyzing, isGenerating, project]); + const handleAnalyze = mode === 'local' ? handleLocalAnalyze : handleServerAnalyze; + const handleApplyToProject = useCallback(() => { - if (!result || !project) return; + if (!project) return; const updates: Record = {}; - if (result.bpm) updates.bpm = Math.round(result.bpm); - if (result.keyScale) updates.keyScale = result.keyScale; + if (mode === 'local' && localResult) { + if (localResult.bpm) updates.bpm = Math.round(localResult.bpm); + if (localResult.keyScale) updates.keyScale = localResult.keyScale; + } else if (mode === 'server' && serverResult) { + if (serverResult.bpm) updates.bpm = Math.round(serverResult.bpm); + if (serverResult.keyScale) updates.keyScale = serverResult.keyScale; + } if (Object.keys(updates).length > 0) { useProjectStore.getState().updateProject(updates as { bpm?: number; keyScale?: string }); setApplied(true); } - }, [result, project]); + }, [mode, localResult, serverResult, project]); if (!analysisClipId || !clip || !track) return null; const hasAudio = !!(clip.isolatedAudioKey || clip.cumulativeMixKey); - - // If clip already has inferred metas, show them immediately const existingMetas = clip.inferredMetas; + const hasResult = mode === 'local' ? !!localResult : !!serverResult; + const hasBpmOrKey = mode === 'local' + ? !!(localResult?.bpm || localResult?.keyScale) + : !!(serverResult?.bpm || serverResult?.keyScale); return (
{ if (e.target === e.currentTarget) onClose(); }} > -
+
{/* Header */}
@@ -165,6 +202,30 @@ export function AudioAnalysisPanel() { {/* Body */}
+ {/* Mode selector */} +
+ + +
+ {/* Source clip info */}

Clip

@@ -175,10 +236,33 @@ export function AudioAnalysisPanel() {

{clip.duration.toFixed(1)}s

+ {/* Local analysis progress */} + {mode === 'local' && analyzing && analysisJob && ( +
+

+ Analyzing... +

+
+
+
+

{analysisJob.message}

+
+ )} + {/* Existing inferred metas */} {existingMetas && (
-

Previously Inferred

+

+ Previously Inferred + {existingMetas.analysisSource && ( + + ({existingMetas.analysisSource}) + + )} +

{existingMetas.bpm && (
@@ -205,42 +289,72 @@ export function AudioAnalysisPanel() {
)}
+ + {/* Chord display for local analysis results */} + {existingMetas.chords && existingMetas.chords.length > 0 && ( +
+ Chords +
+ {existingMetas.chords.slice(0, 16).map((chord, i) => ( + + {chord.label} + + ))} + {existingMetas.chords.length > 16 && ( + + +{existingMetas.chords.length - 16} more + + )} +
+
+ )}
)} - {/* Analysis results */} - {result && ( + {/* Local analysis results */} + {mode === 'local' && localResult && ( + + )} + + {/* Server analysis results */} + {mode === 'server' && serverResult && (
-

Analysis Results

+

+ Server Results +

- {result.bpm && ( + {serverResult.bpm && (
BPM -

{Math.round(result.bpm)}

+

{Math.round(serverResult.bpm)}

)} - {result.keyScale && ( + {serverResult.keyScale && (
Key -

{result.keyScale}

+

{serverResult.keyScale}

)} - {result.timeSignature && ( + {serverResult.timeSignature && (
Time Sig -

{result.timeSignature}

+

{serverResult.timeSignature}

)} - {result.genres && ( + {serverResult.genres && (
Genre -

{result.genres}

+

{serverResult.genres}

)} - {result.caption && ( + {serverResult.caption && (
Description -

{result.caption}

+

{serverResult.caption}

)}
@@ -258,6 +372,13 @@ export function AudioAnalysisPanel() { No audio available — generate the clip first before analyzing.

)} + + {mode === 'local' && !analyzing && !localResult && hasAudio && ( +

+ Local analysis uses Beat This! for BPM detection and Consonance-ACE for chord recognition. + Models are downloaded on first use (~23MB total) and cached locally. +

+ )}
{/* Footer */} @@ -269,7 +390,7 @@ export function AudioAnalysisPanel() { Close
- {result && (result.bpm || result.keyScale) && ( + {hasResult && hasBpmOrKey && (