Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions .github/workflows/turn-accuracy.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
name: Turn Accuracy

# Manually triggered workflow: runs the ignored `accuracy_report` test, copies
# the markdown table it prints into README.md, and opens a PR with the result.
on:
  workflow_dispatch:

# Write access is needed by the final step, which pushes a branch and opens a PR.
permissions:
  contents: write
  pull-requests: write

jobs:
  accuracy:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - uses: dtolnay/rust-toolchain@stable

      - uses: Swatinem/rust-cache@v2

      - name: Run accuracy report
        run: |
          # Without pipefail the step's exit status is tee's, so a failing
          # `cargo test` would be silently ignored and stale or partial
          # output would flow into the README update below.
          set -o pipefail
          cargo test --features pipecat --test accuracy \
            -- --ignored accuracy_report --nocapture 2>&1 | tee accuracy-output.txt

      - name: Update README benchmark table
        run: |
          python3 << 'PYEOF'
          import re, sys

          with open('accuracy-output.txt') as f:
              output = f.read()

          # The test prints a `BENCHMARK_VERSION=<version>` line; take the first one.
          version = None
          for line in output.split('\n'):
              if line.startswith('BENCHMARK_VERSION='):
                  version = line.split('=', 1)[1].strip()
                  break

          if not version:
              print("ERROR: No version found in test output")
              sys.exit(1)

          # A markdown table needs at least header + separator + one data row.
          table_lines = [l for l in output.split('\n') if l.startswith('|')]
          if len(table_lines) < 3:
              print("ERROR: No benchmark table found in test output")
              sys.exit(1)

          table = '\n'.join(table_lines)
          block = f'*v{version}*\n\n{table}'

          with open('README.md') as f:
              readme = f.read()

          pattern = r'(<!-- benchmark-table-start -->).*?(<!-- benchmark-table-end -->)'

          # Use a callable replacement so the generated block is inserted
          # verbatim; a template string would interpret any backslash in the
          # test output as a regex escape sequence.
          def render(match):
              return f'{match.group(1)}\n{block}\n{match.group(2)}'

          updated, count = re.subn(pattern, render, readme, flags=re.DOTALL)

          # Missing markers mean the table can never be published; fail loudly
          # instead of reporting "no changes".
          if count == 0:
              print("ERROR: benchmark table markers not found in README.md")
              sys.exit(1)

          if updated == readme:
              print("No changes to README")
          else:
              with open('README.md', 'w') as f:
                  f.write(updated)
              print(f"README updated with v{version} benchmarks:")
              print(table)
          PYEOF

      - name: Job summary
        # Runs even when an earlier step failed so the table (if any) is still visible.
        if: always()
        run: |
          echo "## Turn Accuracy Report" >> "$GITHUB_STEP_SUMMARY"
          if [ -f accuracy-output.txt ]; then
            # `|| true` keeps the step green when grep finds no table rows.
            TABLE=$(grep '^|' accuracy-output.txt || true)
            if [ -n "$TABLE" ]; then
              echo "$TABLE" >> "$GITHUB_STEP_SUMMARY"
            fi
          fi

      - name: Create PR with updated benchmarks
        uses: peter-evans/create-pull-request@v6
        with:
          commit-message: "docs: update accuracy table"
          title: "docs: update accuracy table"
          body: |
            Auto-generated accuracy update from the turn detection pipeline cross-validation.
          branch: docs/update-accuracy
          delete-branch: true
10 changes: 10 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,13 @@ Cargo.lock
*.swo
.DS_Store
.cargo/config.toml

# Python tooling (scripts/)
scripts/.venv/
scripts/__pycache__/
scripts/*.onnx
__pycache__/
*.pyc

# Generated mel reference tensors (regenerate with scripts/gen_reference.py)
*.mel.npy
24 changes: 17 additions & 7 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
.PHONY: help check test fmt lint doc ci
.PHONY: help check test fmt lint doc ci accuracy mel

help:
@echo "Available targets:"
@echo " check Check workspace compiles"
@echo " test Run all tests"
@echo " fmt Format code"
@echo " lint Run clippy with warnings as errors"
@echo " doc Build and open docs in browser"
@echo " ci Run all CI checks locally (fmt, clippy, test, doc, features)"
@echo " check Check workspace compiles"
@echo " test Run all tests"
@echo " accuracy Cross-validate Rust pipeline against Python reference"
@echo " mel Compare Rust vs Python mel spectrograms element-wise"
@echo " fmt Format code"
@echo " lint Run clippy with warnings as errors"
@echo " doc Build and open docs in browser"
@echo " ci Run all CI checks locally (fmt, clippy, test, doc, features)"

# Check workspace compiles
check:
Expand All @@ -17,6 +19,14 @@ check:
test:
cargo test --workspace

# Cross-validate Rust mel+ONNX pipeline against Python reference probabilities
accuracy:
cargo test --features pipecat --test accuracy -- --ignored accuracy_report --nocapture

# Compare Rust vs Python mel spectrograms element-wise (requires .npy fixtures)
mel:
cargo test --features pipecat -- mel_report --ignored --nocapture

# Format code
fmt:
cargo fmt --all
Expand Down
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,16 @@ wavekat-voice --> orchestrates VAD + turn + ASR + LLM + TTS
- Text-based detectors depend on ASR transcript quality. Pair with a
streaming ASR provider for best results.

## Accuracy

Cross-validated against the original Python (Pipecat) pipeline on three fixture clips.
Tolerance: ±0.02 probability.

<!-- benchmark-table-start -->
<!-- benchmark-table-end -->

Run locally with `make accuracy`. See [`scripts/README.md`](scripts/README.md) for how to regenerate the Python reference.

## License

Licensed under [Apache 2.0](LICENSE).
Expand Down
3 changes: 3 additions & 0 deletions crates/wavekat-turn/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ ureq = { version = "3", optional = true }

[dev-dependencies]
hound = "3.5"
ndarray-npy = "0.10"
serde = { version = "1", features = ["derive"] }
serde_json = "1"

[package.metadata.docs.rs]
all-features = true
Expand Down
162 changes: 156 additions & 6 deletions crates/wavekat-turn/src/audio/pipecat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,12 +143,22 @@ impl MelExtractor {
fn extract(&mut self, audio: &[f32], shift_frames: usize) -> Array2<f32> {
debug_assert_eq!(audio.len(), RING_CAPACITY);

// ---- Center-pad: N_FFT/2 zeros on each side → 128 400 samples ----
// This replicates librosa/PyTorch `center=True` STFT behaviour, which
// gives exactly N_FRAMES + 1 = 801 frames; we discard the last one.
let pad = N_FFT / 2;
let mut padded = vec![0.0f32; pad + audio.len() + pad];
padded[pad..pad + audio.len()].copy_from_slice(audio);
// ---- Center-pad: N_FFT/2 reflect samples on each side → 128 400 samples ----
// Matches WhisperFeatureExtractor: np.pad(waveform, n_fft//2, mode="reflect").
// Reflect (not zero) padding ensures the boundary frames match Python exactly.
// Gives exactly N_FRAMES + 1 = 801 frames; we discard the last one.
let pad = N_FFT / 2; // 200
let n = audio.len(); // 128 000
let mut padded = vec![0.0f32; pad + n + pad];
padded[pad..pad + n].copy_from_slice(audio);
// Left reflect: padded[0..pad] = audio[pad..1] reversed (exclude edge)
for i in 0..pad {
padded[i] = audio[pad - i];
}
// Right reflect: padded[pad+n..pad+n+pad] = audio[n-2..n-2-pad] reversed
for i in 0..pad {
padded[pad + n + i] = audio[n - 2 - i];
}

// n_total = (128 400 − 400) / 160 + 1 = 801
let n_total_frames = (padded.len() - N_FFT) / HOP_LENGTH + 1;
Expand Down Expand Up @@ -527,3 +537,143 @@ impl AudioTurnDetector for PipecatSmartTurn {
self.mel.invalidate_cache();
}
}

// ---------------------------------------------------------------------------
// Mel comparison tests (unit tests — need access to private MelExtractor)
// ---------------------------------------------------------------------------

#[cfg(test)]
mod mel_tests {
    //! Element-wise comparison of the Rust mel spectrogram against reference
    //! tensors stored as `<clip>.mel.npy` next to the WAV fixtures.

    use std::path::{Path, PathBuf};

    use ndarray::Array2;
    use ndarray_npy::ReadNpyExt;

    use super::{prepare_audio, MelExtractor, RING_CAPACITY, SAMPLE_RATE};

    /// Max allowed element-wise absolute difference between Rust and Python mel.
    const MEL_TOLERANCE: f32 = 0.05;

    /// Directory holding the WAV clips and `.mel.npy` references:
    /// `<repo root>/tests/fixtures`, two levels above this crate's manifest.
    fn fixtures_dir() -> PathBuf {
        Path::new(env!("CARGO_MANIFEST_DIR"))
            .parent()
            .unwrap() // crates/
            .parent()
            .unwrap() // repo root
            .join("tests/fixtures")
    }

    /// Load 16 kHz mono WAV as f32 in [-1, 1], normalised the same way as
    /// Python's soundfile (divide by 32768, not i16::MAX).
    ///
    /// # Panics
    /// Panics if the file cannot be opened, is not 16 kHz mono, is integer PCM
    /// with a bit depth other than 16, or contains an undecodable sample.
    fn load_wav_f32(path: &Path) -> Vec<f32> {
        let mut reader = hound::WavReader::open(path)
            .unwrap_or_else(|e| panic!("failed to open {}: {}", path.display(), e));
        let spec = reader.spec();
        assert_eq!(spec.sample_rate, SAMPLE_RATE, "expected 16 kHz");
        assert_eq!(spec.channels, 1, "expected mono");
        match spec.sample_format {
            hound::SampleFormat::Int => {
                // The 32768 divisor is only correct for 16-bit samples; any
                // other depth would be silently mis-scaled or fail to decode.
                assert_eq!(spec.bits_per_sample, 16, "expected 16-bit PCM");
                reader
                    .samples::<i16>()
                    .map(|s| s.expect("failed to decode i16 sample") as f32 / 32768.0)
                    .collect()
            }
            hound::SampleFormat::Float => reader
                .samples::<f32>()
                .map(|s| s.expect("failed to decode f32 sample"))
                .collect(),
        }
    }

    /// Read the reference mel tensor for `clip` from `<clip>.mel.npy` in the
    /// fixtures directory.
    ///
    /// # Panics
    /// Panics if the file cannot be opened or is not a 2-D f32 `.npy` array.
    fn load_python_mel(clip: &str) -> Array2<f32> {
        let path = fixtures_dir().join(format!("{clip}.mel.npy"));
        let file = std::fs::File::open(&path).unwrap_or_else(|_| {
            panic!(
                "missing {}: run `python scripts/gen_reference.py` first",
                path.display()
            )
        });
        Array2::<f32>::read_npy(file).expect("failed to parse .npy")
    }

    /// Summary statistics of the element-wise |rust - python| mel difference.
    #[derive(Debug)]
    struct MelDiff {
        max_diff: f32,
        mean_diff: f32,
        /// (mel_bin, frame) of the single largest diff
        max_at: (usize, usize),
        /// fraction of elements with diff > 0.01
        outlier_frac: f32,
    }

    /// Run the Rust mel extractor on fixture `clip` and diff the result against
    /// the stored reference tensor.
    ///
    /// A NaN difference is treated as infinitely large so that it surfaces as
    /// the maximum (and fails the tolerance check) instead of being skipped by
    /// the `>` comparisons.
    ///
    /// # Panics
    /// Panics if a fixture is missing or malformed, or if the two tensors have
    /// different shapes.
    fn compare_mel(clip: &str) -> MelDiff {
        let samples = load_wav_f32(&fixtures_dir().join(clip));
        let audio = prepare_audio(&samples);
        assert_eq!(audio.len(), RING_CAPACITY);

        let mut extractor = MelExtractor::new();
        let rust_mel = extractor.extract(&audio, 0);
        let python_mel = load_python_mel(clip);

        assert_eq!(
            rust_mel.shape(),
            python_mel.shape(),
            "{clip}: mel shape mismatch"
        );

        let shape = rust_mel.shape();
        let (n_mels, n_frames) = (shape[0], shape[1]);

        let mut max_diff = 0.0f32;
        let mut max_at = (0, 0);
        // f64 accumulator: summing every element in f32 loses precision in the mean.
        let mut sum_diff = 0.0f64;
        let mut outliers = 0usize;

        for m in 0..n_mels {
            for t in 0..n_frames {
                let raw = (rust_mel[[m, t]] - python_mel[[m, t]]).abs();
                // NaN compares false against everything; map it to +inf so it
                // is reported as the worst possible mismatch.
                let d = if raw.is_nan() { f32::INFINITY } else { raw };
                sum_diff += f64::from(d);
                if d > max_diff {
                    max_diff = d;
                    max_at = (m, t);
                }
                if d > 0.01 {
                    outliers += 1;
                }
            }
        }

        let total = n_mels * n_frames;
        MelDiff {
            max_diff,
            mean_diff: (sum_diff / total as f64) as f32,
            max_at,
            outlier_frac: outliers as f32 / total as f32,
        }
    }

    /// Print a markdown table of mel-level diffs between Rust and Python.
    /// Run with: `make mel`
    ///
    /// This is a report: the PASS/FAIL column reflects `MEL_TOLERANCE`, but the
    /// test itself only fails if a fixture cannot be loaded or shapes differ.
    #[test]
    #[ignore]
    fn mel_report() {
        let clips = ["silence_2s.wav", "speech_finished.wav", "speech_mid.wav"];

        println!();
        println!("MEL_TOLERANCE={MEL_TOLERANCE}");
        println!();
        println!("| Clip | Max Diff | Mean Diff | Max at (mel,frame) | Outliers >0.01 | Status |");
        println!("|------|----------|-----------|---------------------|----------------|--------|");
        for clip in clips {
            let d = compare_mel(clip);
            let status = if d.max_diff <= MEL_TOLERANCE {
                "PASS"
            } else {
                "FAIL"
            };
            println!(
                "| `{clip}` | {:.6} | {:.6} | ({},{}) | {:.2}% | {status} |",
                d.max_diff,
                d.mean_diff,
                d.max_at.0,
                d.max_at.1,
                d.outlier_frac * 100.0,
            );
        }
        println!();
    }
}
Loading