From f6232403fdf3e7ed97be33de7e14db82fd3e709d Mon Sep 17 00:00:00 2001 From: Eason WaveKat Date: Mon, 6 Apr 2026 20:19:37 +1200 Subject: [PATCH 01/15] feat: Qwen3-TTS 1.7B VoiceDesign INT4 via HF Hub MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Switch qwen3-tts backend from 0.6B/elbruno to the WaveKat 1.7B VoiceDesign ONNX repo with INT4 by default. - download: replace ureq HTTP with hf-hub client; downloads int4/ ONNX + embeddings/ + tokenizer/ from wavekat/Qwen3-TTS-1.7B-VoiceDesign-ONNX (pinned revision) - model: HIDDEN_DIM 1024→2048, MAX_NEW_TOKENS 2048→8192, sampling updated to config.json values (top_k=50, temp=0.9), non-streaming prefill (all text in prefill, matches generate_onnx.py), fix last-position logits extraction bug, min_new_tokens=2, VoiceDesign codec prefix (4 tokens, no speaker slot), int4/ and embeddings/ subdirectory paths - sampler: replace top_p with top_k - tokenizer: load from tokenizer/ subdirectory Co-Authored-By: Claude Sonnet 4.6 --- crates/wavekat-tts/Cargo.toml | 4 +- .../src/backends/qwen3_tts/download.rs | 295 +++++------------- .../wavekat-tts/src/backends/qwen3_tts/mod.rs | 34 +- .../src/backends/qwen3_tts/model.rs | 185 +++++------ .../src/backends/qwen3_tts/sampler.rs | 56 ++-- .../src/backends/qwen3_tts/tokenizer.rs | 6 +- 6 files changed, 222 insertions(+), 358 deletions(-) diff --git a/crates/wavekat-tts/Cargo.toml b/crates/wavekat-tts/Cargo.toml index ce39065..3158dd2 100644 --- a/crates/wavekat-tts/Cargo.toml +++ b/crates/wavekat-tts/Cargo.toml @@ -12,7 +12,7 @@ categories = ["multimedia::audio"] default = [] # Local inference backends (all ONNX-based) -qwen3-tts = ["dep:ort", "dep:ndarray", "dep:tokenizers", "dep:npyz", "dep:rand", "dep:ureq"] +qwen3-tts = ["dep:ort", "dep:ndarray", "dep:tokenizers", "dep:npyz", "dep:rand", "dep:hf-hub"] cosyvoice = ["dep:ort", "dep:ndarray"] [dependencies] @@ -27,7 +27,7 @@ ndarray = { version = "0.17", optional = true } tokenizers = { version 
= "0.21", optional = true, default-features = false, features = ["onig"] } npyz = { version = "0.8", optional = true } rand = { version = "0.9", optional = true } -ureq = { version = "2", optional = true } +hf-hub = { version = "0.3", optional = true } hound = { version = "3.5", optional = true } [dev-dependencies] diff --git a/crates/wavekat-tts/src/backends/qwen3_tts/download.rs b/crates/wavekat-tts/src/backends/qwen3_tts/download.rs index da7b523..1b746c5 100644 --- a/crates/wavekat-tts/src/backends/qwen3_tts/download.rs +++ b/crates/wavekat-tts/src/backends/qwen3_tts/download.rs @@ -1,229 +1,96 @@ -//! Auto-download model files to a local cache directory. +//! Download model files from HuggingFace Hub. -use std::fs; -use std::io::{Read, Write}; -use std::path::{Path, PathBuf}; +use std::path::PathBuf; -use crate::TtsError; - -/// Pinned commit — guarantees immutable file URLs. -const REVISION: &str = "6a297d9641354ef0c16e63d329a93a6239bca0a2"; +use hf_hub::api::sync::Api; +use hf_hub::{Repo, RepoType}; -const BASE_URL: &str = "https://huggingface.co/elbruno/Qwen3-TTS-12Hz-0.6B-Base-ONNX/resolve"; +use crate::TtsError; -/// (remote_path, local_filename) — remote paths map to the HF repo layout, -/// local filenames are flattened into the cache directory. 
-const MODEL_FILES: &[(&str, &str)] = &[ - // ONNX sessions + external weight files - ("talker_prefill.onnx", "talker_prefill.onnx"), - ("talker_prefill.onnx.data", "talker_prefill.onnx.data"), - ("talker_decode.onnx", "talker_decode.onnx"), - ("talker_decode.onnx.data", "talker_decode.onnx.data"), - ("code_predictor.onnx", "code_predictor.onnx"), - ("vocoder.onnx", "vocoder.onnx"), - ("vocoder.onnx.data", "vocoder.onnx.data"), - // Embedding tables (in embeddings/ on HF, flattened locally) - ("embeddings/text_embedding.npy", "text_embedding.npy"), - ( - "embeddings/text_projection_fc1_weight.npy", - "text_projection_fc1_weight.npy", - ), - ( - "embeddings/text_projection_fc1_bias.npy", - "text_projection_fc1_bias.npy", - ), - ( - "embeddings/text_projection_fc2_weight.npy", - "text_projection_fc2_weight.npy", - ), - ( - "embeddings/text_projection_fc2_bias.npy", - "text_projection_fc2_bias.npy", - ), - ( - "embeddings/talker_codec_embedding.npy", - "talker_codec_embedding.npy", - ), - ( - "embeddings/cp_codec_embedding_0.npy", - "cp_codec_embedding_0.npy", - ), - ( - "embeddings/cp_codec_embedding_1.npy", - "cp_codec_embedding_1.npy", - ), - ( - "embeddings/cp_codec_embedding_2.npy", - "cp_codec_embedding_2.npy", - ), - ( - "embeddings/cp_codec_embedding_3.npy", - "cp_codec_embedding_3.npy", - ), - ( - "embeddings/cp_codec_embedding_4.npy", - "cp_codec_embedding_4.npy", - ), - ( - "embeddings/cp_codec_embedding_5.npy", - "cp_codec_embedding_5.npy", - ), - ( - "embeddings/cp_codec_embedding_6.npy", - "cp_codec_embedding_6.npy", - ), - ( - "embeddings/cp_codec_embedding_7.npy", - "cp_codec_embedding_7.npy", - ), - ( - "embeddings/cp_codec_embedding_8.npy", - "cp_codec_embedding_8.npy", - ), - ( - "embeddings/cp_codec_embedding_9.npy", - "cp_codec_embedding_9.npy", - ), - ( - "embeddings/cp_codec_embedding_10.npy", - "cp_codec_embedding_10.npy", - ), - ( - "embeddings/cp_codec_embedding_11.npy", - "cp_codec_embedding_11.npy", - ), - ( - 
"embeddings/cp_codec_embedding_12.npy", - "cp_codec_embedding_12.npy", - ), - ( - "embeddings/cp_codec_embedding_13.npy", - "cp_codec_embedding_13.npy", - ), - ( - "embeddings/cp_codec_embedding_14.npy", - "cp_codec_embedding_14.npy", - ), - // Tokenizer (in tokenizer/ on HF, flattened locally) - ("tokenizer/vocab.json", "vocab.json"), - ("tokenizer/merges.txt", "merges.txt"), +const REPO_ID: &str = "wavekat/Qwen3-TTS-1.7B-VoiceDesign-ONNX"; +const REVISION: &str = "62c7863a68800d72bcee4f2931148c441571ace7"; + +/// Files required for INT4 inference (ONNX models + embeddings + tokenizer). +const MODEL_FILES: &[&str] = &[ + "config.json", + // INT4 ONNX models + "int4/talker_prefill.onnx", + "int4/talker_prefill.onnx.data", + "int4/talker_decode.onnx", + "int4/talker_decode.onnx.data", + "int4/code_predictor.onnx", + "int4/code_predictor.onnx.data", + "int4/vocoder.onnx", + "int4/vocoder.onnx.data", + // Embedding tables + "embeddings/text_embedding.npy", + "embeddings/text_projection_fc1_weight.npy", + "embeddings/text_projection_fc1_bias.npy", + "embeddings/text_projection_fc2_weight.npy", + "embeddings/text_projection_fc2_bias.npy", + "embeddings/talker_codec_embedding.npy", + "embeddings/cp_codec_embedding_0.npy", + "embeddings/cp_codec_embedding_1.npy", + "embeddings/cp_codec_embedding_2.npy", + "embeddings/cp_codec_embedding_3.npy", + "embeddings/cp_codec_embedding_4.npy", + "embeddings/cp_codec_embedding_5.npy", + "embeddings/cp_codec_embedding_6.npy", + "embeddings/cp_codec_embedding_7.npy", + "embeddings/cp_codec_embedding_8.npy", + "embeddings/cp_codec_embedding_9.npy", + "embeddings/cp_codec_embedding_10.npy", + "embeddings/cp_codec_embedding_11.npy", + "embeddings/cp_codec_embedding_12.npy", + "embeddings/cp_codec_embedding_13.npy", + "embeddings/cp_codec_embedding_14.npy", + // Tokenizer + "tokenizer/vocab.json", + "tokenizer/merges.txt", ]; -/// Resolve the default model cache directory, downloading any missing files. 
+/// Resolve the local HF Hub snapshot directory for the Qwen3-TTS model, +/// downloading any missing files as needed. +/// +/// Set `WAVEKAT_MODEL_DIR` to skip HF Hub and load from a local directory +/// that mirrors the repo layout (`int4/`, `embeddings/`, `tokenizer/`). /// -/// Resolution order: -/// 1. `$WAVEKAT_MODEL_DIR` if set -/// 2. `$XDG_CACHE_HOME/wavekat/qwen3-tts-0.6b/` -/// 3. `$HOME/.cache/wavekat/qwen3-tts-0.6b/` +/// Authentication: set `HF_TOKEN` if the repo requires it. +/// Cache location: `$HF_HOME/hub/` (default `~/.cache/huggingface/hub/`). pub fn ensure_model_dir() -> Result { - let dir = default_cache_dir()?; - ensure_files(&dir)?; - Ok(dir) -} - -fn default_cache_dir() -> Result { if let Ok(dir) = std::env::var("WAVEKAT_MODEL_DIR") { return Ok(PathBuf::from(dir)); } - let base = if let Ok(xdg) = std::env::var("XDG_CACHE_HOME") { - PathBuf::from(xdg) - } else if let Ok(home) = std::env::var("HOME") { - PathBuf::from(home).join(".cache") - } else { - return Err(TtsError::Model( - "cannot determine cache directory: set WAVEKAT_MODEL_DIR or HOME".into(), - )); - }; - Ok(base.join("wavekat").join("qwen3-tts-0.6b")) -} -fn ensure_files(dir: &Path) -> Result<(), TtsError> { - let missing: Vec<(&str, &str)> = MODEL_FILES - .iter() - .filter(|(_, local)| !dir.join(local).exists()) - .copied() - .collect(); - - if missing.is_empty() { - return Ok(()); - } - - fs::create_dir_all(dir).map_err(|e| { - TtsError::Model(format!("failed to create cache dir {}: {e}", dir.display())) - })?; - - let total = missing.len(); - eprintln!( - "Downloading Qwen3-TTS model ({total} files) to {} ...", - dir.display() - ); - - for (i, (remote, local)) in missing.iter().enumerate() { - let url = format!("{BASE_URL}/{REVISION}/{remote}"); - let dest = dir.join(local); - eprintln!("[{}/{}] {}", i + 1, total, local); - download_file(&url, &dest)?; + let api = Api::new() + .map_err(|e| TtsError::Model(format!("failed to initialize HF Hub client: {e}")))?; + + let repo = 
api.repo(Repo::with_revision( + REPO_ID.to_string(), + RepoType::Model, + REVISION.to_string(), + )); + + let total = MODEL_FILES.len(); + eprintln!("Ensuring Qwen3-TTS 1.7B model ({total} files from {REPO_ID})..."); + + // config.json is always first — its parent is the snapshot root. + eprintln!("[1/{total}] {}", MODEL_FILES[0]); + let config_path = repo + .get(MODEL_FILES[0]) + .map_err(|e| TtsError::Model(format!("failed to download {}: {e}", MODEL_FILES[0])))?; + + let model_dir = config_path + .parent() + .ok_or_else(|| TtsError::Model("unexpected cache path for config.json".into()))? + .to_path_buf(); + + for (i, filename) in MODEL_FILES[1..].iter().enumerate() { + eprintln!("[{}/{total}] {filename}", i + 2); + repo.get(filename) + .map_err(|e| TtsError::Model(format!("failed to download {filename}: {e}")))?; } - eprintln!("Download complete."); - Ok(()) -} - -fn download_file(url: &str, dest: &Path) -> Result<(), TtsError> { - let response = ureq::get(url) - .call() - .map_err(|e| TtsError::Model(format!("download failed for {url}: {e}")))?; - - let content_length: Option = response - .header("Content-Length") - .and_then(|s| s.parse().ok()); - - let mut reader = response.into_reader(); - - // Write to a temp file first, then rename to avoid partial files on interrupt. - let tmp = dest.with_extension("tmp"); - let mut file = fs::File::create(&tmp) - .map_err(|e| TtsError::Model(format!("failed to create {}: {e}", tmp.display())))?; - - let mut buf = [0u8; 256 * 1024]; - let mut downloaded: u64 = 0; - let mut last_report: u64 = 0; - - loop { - let n = reader - .read(&mut buf) - .map_err(|e| TtsError::Model(format!("download read error: {e}")))?; - if n == 0 { - break; - } - file.write_all(&buf[..n]) - .map_err(|e| TtsError::Model(format!("write error: {e}")))?; - downloaded += n as u64; - - // Report progress every 50 MB for large files. 
- if downloaded - last_report >= 50_000_000 { - if let Some(total) = content_length { - let mb = downloaded / 1_000_000; - let total_mb = total / 1_000_000; - eprint!("\r {mb}/{total_mb} MB"); - } - last_report = downloaded; - } - } - - // Clear progress line if we printed any. - if last_report > 0 { - eprintln!(); - } - - drop(file); - fs::rename(&tmp, dest).map_err(|e| { - TtsError::Model(format!( - "failed to rename {} → {}: {e}", - tmp.display(), - dest.display() - )) - })?; - - Ok(()) + eprintln!("Model ready."); + Ok(model_dir) } diff --git a/crates/wavekat-tts/src/backends/qwen3_tts/mod.rs b/crates/wavekat-tts/src/backends/qwen3_tts/mod.rs index db257be..9cf24d9 100644 --- a/crates/wavekat-tts/src/backends/qwen3_tts/mod.rs +++ b/crates/wavekat-tts/src/backends/qwen3_tts/mod.rs @@ -1,16 +1,17 @@ -//! Qwen3-TTS backend (ONNX, 12Hz-0.6B). +//! Qwen3-TTS backend (ONNX INT4, 1.7B VoiceDesign). //! -//! Runs the Qwen3-TTS-12Hz-0.6B model via ONNX Runtime. +//! Runs the Qwen3-TTS-12Hz-1.7B-VoiceDesign model via ONNX Runtime using the +//! INT4 weight-only quantized models from `wavekat/Qwen3-TTS-1.7B-VoiceDesign-ONNX`. //! //! ```ignore //! use wavekat_tts::{TtsBackend, SynthesizeRequest}; //! use wavekat_tts::backends::qwen3_tts::Qwen3Tts; //! -//! // Auto-download model files (~3.8 GB, cached for reuse): +//! // Auto-download model files via HF Hub (cached at ~/.cache/huggingface/hub/): //! let tts = Qwen3Tts::new()?; //! -//! // Or load from an explicit directory: -//! let tts = Qwen3Tts::from_dir("models/qwen3-tts-0.6b")?; +//! // Or load from an explicit directory (must mirror the HF repo layout): +//! let tts = Qwen3Tts::from_dir("models/qwen3-tts-1.7b")?; //! //! let request = SynthesizeRequest::new("Hello, world"); //! let audio = tts.synthesize(&request)?; @@ -36,27 +37,24 @@ pub struct Qwen3Tts { } impl Qwen3Tts { - /// Create a new backend, auto-downloading model files if needed. + /// Create a new backend, downloading model files from HF Hub if needed. 
/// - /// Model files (~3.8 GB) are cached at: - /// - `$WAVEKAT_MODEL_DIR` if set, otherwise - /// - `$XDG_CACHE_HOME/wavekat/qwen3-tts-0.6b/`, otherwise - /// - `$HOME/.cache/wavekat/qwen3-tts-0.6b/` + /// Files are cached by the HF Hub client (default `~/.cache/huggingface/hub/`). + /// Set `HF_HOME` to change the cache root, or `HF_TOKEN` for authentication. + /// Set `WAVEKAT_MODEL_DIR` to load from a local directory and skip all downloads. /// - /// Use [`from_dir`](Self::from_dir) to skip auto-download and load from - /// a specific directory. + /// Use [`from_dir`](Self::from_dir) to load from an explicit path. pub fn new() -> Result { let model_dir = download::ensure_model_dir()?; Self::from_dir(model_dir) } - /// Load the model from a directory containing ONNX files and embeddings. + /// Load the model from a directory that mirrors the HF repo layout. /// - /// Expected files: - /// - `talker_prefill.onnx`, `talker_decode.onnx`, `code_predictor.onnx`, `vocoder.onnx` - /// - `text_embedding.npy`, `text_projection_fc1_weight.npy`, etc. - /// - `talker_codec_embedding.npy`, `cp_codec_embedding_{0..14}.npy` - /// - `vocab.json`, `merges.txt` + /// Expected subdirectories: + /// - `int4/` — ONNX models (`talker_prefill.onnx`, `talker_decode.onnx`, etc.) 
+ /// - `embeddings/` — `.npy` embedding tables + /// - `tokenizer/` — `vocab.json`, `merges.txt` pub fn from_dir(model_dir: impl AsRef) -> Result { let model_dir = model_dir.as_ref(); let model = model::Model::load(model_dir)?; diff --git a/crates/wavekat-tts/src/backends/qwen3_tts/model.rs b/crates/wavekat-tts/src/backends/qwen3_tts/model.rs index 4be7990..5008f34 100644 --- a/crates/wavekat-tts/src/backends/qwen3_tts/model.rs +++ b/crates/wavekat-tts/src/backends/qwen3_tts/model.rs @@ -11,7 +11,7 @@ use crate::TtsError; use super::sampler::{self, SamplerConfig}; use super::tokenizer::{self, ASSISTANT, IM_START, NEWLINE, TTS_BOS, TTS_EOS, TTS_PAD}; -// Codec control token IDs +// Codec control token IDs (from config.json) const CODEC_PAD: i64 = 2148; const CODEC_BOS: i64 = 2149; const CODEC_THINK: i64 = 2154; @@ -21,27 +21,29 @@ const CODEC_THINK_EOS: i64 = 2157; /// Talker output: (logits, hidden_state, kv_keys, kv_values). type TalkerOutput = (Vec, Array3, Array5, Array5); -// Model dimensions (0.6B-12Hz) -const HIDDEN_DIM: usize = 1024; +// Model dimensions — Qwen3-TTS-12Hz-1.7B-VoiceDesign +const HIDDEN_DIM: usize = 2048; const NUM_LAYERS: usize = 28; const NUM_KV_HEADS: usize = 8; const HEAD_DIM: usize = 128; +const TALKER_VOCAB_SIZE: usize = 3072; const CP_NUM_LAYERS: usize = 5; +const CP_NUM_KV_HEADS: usize = 8; const NUM_CP_GROUPS: usize = 15; // codebook groups 1-15 const SAMPLE_RATE: u32 = 24000; const CODEC_EOS: i64 = 2150; -const MAX_NEW_TOKENS: usize = 2048; +const MAX_NEW_TOKENS: usize = 8192; -/// Sampling defaults matching the reference implementation. +/// Sampling defaults from config.json generate_config. 
const TALKER_SAMPLER: SamplerConfig = SamplerConfig { - temperature: 0.7, - top_p: 0.8, + temperature: 0.9, + top_k: 50, repetition_penalty: 1.05, }; const CP_SAMPLER: SamplerConfig = SamplerConfig { - temperature: 0.2, - top_p: 0.5, + temperature: 0.9, + top_k: 50, repetition_penalty: 1.0, }; @@ -57,22 +59,26 @@ pub struct Model { // Embedding tables (immutable after construction) text_embedding: Array2, // (vocab, 2048) - text_proj_fc1_weight: Array2, // (1024, 2048) - text_proj_fc1_bias: Array1, // (1024,) - text_proj_fc2_weight: Array2, // (1024, 1024) - text_proj_fc2_bias: Array1, // (1024,) - talker_codec_embedding: Array2, // (3072, 1024) - cp_codec_embeddings: Vec>, // 15 × (2048, 1024) + text_proj_fc1_weight: Array2, // (2048, 2048) + text_proj_fc1_bias: Array1, // (2048,) + text_proj_fc2_weight: Array2, // (2048, 2048) + text_proj_fc2_bias: Array1, // (2048,) + talker_codec_embedding: Array2, // (3072, 2048) + cp_codec_embeddings: Vec>, // 15 × (2048, 2048) // Precomputed - tts_pad_embed: Array1, // (1024,) projected tts_pad text embedding + tts_pad_embed: Array1, // (2048,) projected tts_pad text embedding } impl Model { /// Load all ONNX sessions and embedding tables from `model_dir`. + /// + /// Expected layout (matches the HF repo): + /// - `model_dir/int4/talker_prefill.onnx` (+ .data), etc. + /// - `model_dir/embeddings/text_embedding.npy`, etc. pub fn load(model_dir: &Path) -> Result { let load_session = |name: &str| -> Result { - let path = model_dir.join(name); + let path = model_dir.join("int4").join(name); Session::builder() .map_err(|e| TtsError::Model(format!("session builder error: {e}")))? 
.commit_from_file(&path) @@ -84,18 +90,24 @@ impl Model { let code_predictor = load_session("code_predictor.onnx")?; let vocoder = load_session("vocoder.onnx")?; - let text_embedding = load_npy2(model_dir, "text_embedding.npy")?; - let text_proj_fc1_weight = load_npy2(model_dir, "text_projection_fc1_weight.npy")?; - let text_proj_fc1_bias = load_npy1(model_dir, "text_projection_fc1_bias.npy")?; - let text_proj_fc2_weight = load_npy2(model_dir, "text_projection_fc2_weight.npy")?; - let text_proj_fc2_bias = load_npy1(model_dir, "text_projection_fc2_bias.npy")?; - let talker_codec_embedding = load_npy2(model_dir, "talker_codec_embedding.npy")?; + let text_embedding = + load_npy2(model_dir, "embeddings/text_embedding.npy")?; + let text_proj_fc1_weight = + load_npy2(model_dir, "embeddings/text_projection_fc1_weight.npy")?; + let text_proj_fc1_bias = + load_npy1(model_dir, "embeddings/text_projection_fc1_bias.npy")?; + let text_proj_fc2_weight = + load_npy2(model_dir, "embeddings/text_projection_fc2_weight.npy")?; + let text_proj_fc2_bias = + load_npy1(model_dir, "embeddings/text_projection_fc2_bias.npy")?; + let talker_codec_embedding = + load_npy2(model_dir, "embeddings/talker_codec_embedding.npy")?; let mut cp_codec_embeddings = Vec::with_capacity(NUM_CP_GROUPS); for i in 0..NUM_CP_GROUPS { cp_codec_embeddings.push(load_npy2( model_dir, - &format!("cp_codec_embedding_{i}.npy"), + &format!("embeddings/cp_codec_embedding_{i}.npy"), )?); } @@ -133,10 +145,10 @@ impl Model { let lang_id = tokenizer::language_id(language) .ok_or_else(|| TtsError::UnsupportedLanguage(language.to_string()))?; - let (prefill_embeds, trailing) = self.build_prefill_embeds(text_tokens, lang_id)?; + let prefill_embeds = self.build_prefill_embeds(text_tokens, lang_id)?; let prefill_len = prefill_embeds.shape()[1]; - // Run talker prefill + // Run talker prefill — returns last-position logits let (logits, hidden_states, mut past_keys, mut past_values) = self.run_talker_prefill(&prefill_embeds, 
prefill_len)?; @@ -145,7 +157,6 @@ impl Model { let mut talker_past_tokens: Vec = Vec::new(); let mut current_logits = logits; - // Hidden state for code predictor: starts from prefill last position let mut current_hidden = hidden_states .slice(s![0, prefill_len - 1.., ..]) .to_owned() @@ -153,11 +164,12 @@ impl Model { .map_err(|e| TtsError::Synthesis(format!("reshape hidden: {e}")))?; for step in 0..MAX_NEW_TOKENS { + // Suppress CODEC_EOS for the first 2 steps (min_new_tokens=2) let group0 = sampler::sample( ¤t_logits, &TALKER_SAMPLER, &talker_past_tokens, - sampler::talker_mask, + |tok| sampler::talker_mask(tok) || (step < 2 && tok == CODEC_EOS as usize), ) as i64; if group0 == CODEC_EOS { @@ -171,17 +183,13 @@ impl Model { self.run_code_predictor(¤t_hidden, &mut codes)?; all_codes.push(codes); - // Build next talker input: sum of all 16 group embeddings + trailing text + // Build next talker input: sum of all 16 group embeddings + tts_pad (non-streaming) let mut next_embed = self.talker_codec_embedding.row(group0 as usize).to_owned(); for g in 0..NUM_CP_GROUPS { let cp_embed = self.cp_codec_embeddings[g].row(codes[g + 1] as usize); next_embed += &cp_embed; } - if step < trailing.len() { - next_embed += &trailing[step]; - } else { - next_embed += &self.tts_pad_embed; - } + next_embed += &self.tts_pad_embed; let next_embed = next_embed .into_shape_with_order((1, 1, HIDDEN_DIM)) @@ -219,53 +227,42 @@ impl Model { ) } - /// Build prefill embeddings and trailing text hidden states. + /// Build prefill embeddings (non-streaming: all text embedded in prefill). 
/// - /// Prefill layout (matching the C# reference): /// ```text - /// [im_start, assistant, \n] — role prefix (text proj only, no codec) - /// [think, think_bos, lang, think_eos, speaker(=pad)] — tts_pad_embed + codec_embed - /// [tts_bos_embed + codec_embed(pad)] — transition marker - /// [text_proj(first_text) + codec_embed(bos)] — first text token enters here + /// [im_start, assistant, \n] — role prefix (text proj only) + /// [think, think_bos, lang_id, think_eos] — codec prefix (tts_pad + codec_embed) + /// [tts_bos + codec_pad] — transition + /// [text_proj(tok) + codec_pad] × N — all text tokens + /// [text_proj(TTS_EOS) + codec_pad] — TTS_EOS + /// [tts_pad + codec_bos] — final /// ``` - /// - /// Trailing text hidden: `[text_proj(tok) for tok in text[1:]] + [text_proj(TTS_EOS)]` - /// These are consumed one per decode step; after exhaustion, tts_pad_embed is used. fn build_prefill_embeds( &self, text_tokens: &[u32], lang_id: i64, - ) -> Result<(Array3, Vec>), TtsError> { + ) -> Result, TtsError> { + let codec_pad_embed = self.talker_codec_embedding.row(CODEC_PAD as usize).to_owned(); + let codec_bos_embed = self.talker_codec_embedding.row(CODEC_BOS as usize).to_owned(); let tts_bos_embed = self.text_project_token(TTS_BOS); + let tts_eos_embed = self.text_project_token(TTS_EOS); - let first_text = if text_tokens.is_empty() { - TTS_PAD - } else { - text_tokens[0] - }; + // VoiceDesign codec prefix — no speaker slot + let codec_prefix = [CODEC_THINK, CODEC_THINK_BOS, lang_id, CODEC_THINK_EOS]; - // Codec prefix: [think, think_bos, lang_id, think_eos, speaker_id(=pad)] - let codec_prefix = [ - CODEC_THINK, - CODEC_THINK_BOS, - lang_id, - CODEC_THINK_EOS, - CODEC_PAD, - ]; - - // 3 role + 5 codec prefix + 1 transition + 1 first text = 10 - let seq_len = 3 + codec_prefix.len() + 2; + // 3 role + 4 codec_prefix + 1 transition + N text + 1 TTS_EOS + 1 final + let seq_len = 3 + codec_prefix.len() + 1 + text_tokens.len() + 1 + 1; let mut embeds = 
Array3::::zeros((1, seq_len, HIDDEN_DIM)); let mut pos = 0; - // 1. Role prefix: text_project only, no codec component + // 1. Role prefix: text_proj only for &tok in &[IM_START, ASSISTANT, NEWLINE] { let embed = self.text_project_token(tok); embeds.slice_mut(s![0, pos, ..]).assign(&embed); pos += 1; } - // 2. Codec prefix: tts_pad_embed + codec_embed(token) + // 2. Codec prefix: tts_pad + codec_embed(token) for &codec_tok in &codec_prefix { let mut embed = self.tts_pad_embed.clone(); embed += &self.talker_codec_embedding.row(codec_tok as usize); @@ -273,36 +270,39 @@ impl Model { pos += 1; } - // 3. Transition: tts_bos_embed + codec_embed(pad) + // 3. Transition: tts_bos + codec_pad { - let mut embed = tts_bos_embed; - embed += &self.talker_codec_embedding.row(CODEC_PAD as usize); + let embed = &tts_bos_embed + &codec_pad_embed; embeds.slice_mut(s![0, pos, ..]).assign(&embed); pos += 1; } - // 4. First text token + codec_embed(bos) + // 4. All text tokens: text_proj(tok) + codec_pad + for &tok in text_tokens { + let embed = self.text_project_token(tok) + &codec_pad_embed; + embeds.slice_mut(s![0, pos, ..]).assign(&embed); + pos += 1; + } + + // 5. TTS_EOS: text_proj(TTS_EOS) + codec_pad { - let mut embed = self.text_project_token(first_text); - embed += &self.talker_codec_embedding.row(CODEC_BOS as usize); + let embed = tts_eos_embed + &codec_pad_embed; embeds.slice_mut(s![0, pos, ..]).assign(&embed); + pos += 1; } - // Trailing text hidden: remaining text tokens + TTS_EOS - let mut trailing = Vec::new(); - if text_tokens.len() > 1 { - for &tok in &text_tokens[1..] { - trailing.push(self.text_project_token(tok)); - } + // 6. Final: tts_pad + codec_bos + { + let embed = &self.tts_pad_embed + &codec_bos_embed; + embeds.slice_mut(s![0, pos, ..]).assign(&embed); } - trailing.push(self.text_project_token(TTS_EOS)); - Ok((embeds, trailing)) + Ok(embeds) } /// Run talker_prefill.onnx. /// - /// Returns (logits, hidden_states, past_keys, past_values). 
+ /// Returns (last-position logits, hidden_states, past_keys, past_values). fn run_talker_prefill( &self, inputs_embeds: &Array3, @@ -334,13 +334,13 @@ impl Model { ]) .map_err(|e| TtsError::Synthesis(format!("talker prefill failed: {e}")))?; - // Logits: last position (1, 1, 3072) → flat Vec + // Logits: (1, T, 3072) — extract only the last position let (_, logits_data) = outputs[0] .try_extract_tensor::() .map_err(|e| TtsError::Synthesis(format!("extract logits: {e}")))?; - let logits: Vec = logits_data.to_vec(); + let logits: Vec = logits_data[logits_data.len() - TALKER_VOCAB_SIZE..].to_vec(); - // Hidden states: (1, T, 1024) + // Hidden states: (1, T, 2048) let (_, hidden_data) = outputs[1] .try_extract_tensor::() .map_err(|e| TtsError::Synthesis(format!("extract hidden: {e}")))?; @@ -361,7 +361,6 @@ impl Model { .try_extract_tensor::() .map_err(|e| TtsError::Synthesis(format!("extract val layer {layer}: {e}")))?; - // (1, 8, T, 128) → insert axis 0 → (1, 1, 8, T, 128) let key_arr = ndarray::ArrayD::from_shape_vec( vec![1, NUM_KV_HEADS, seq_len, HEAD_DIM], key_data.to_vec(), @@ -401,7 +400,7 @@ impl Model { /// Run talker_decode.onnx for a single step. fn run_talker_decode( &self, - inputs_embeds: &Array3, // (1, 1, 1024) + inputs_embeds: &Array3, // (1, 1, 2048) total_seq: usize, position: i64, past_keys: &Array5, // (28, 1, 8, past_seq, 128) @@ -465,9 +464,12 @@ impl Model { } /// Run the code predictor to fill codebook groups 1-15. + /// + /// The code_predictor.onnx includes the small_to_mtp projection internally, + /// so host code passes 2048-dim embeddings directly. 
fn run_code_predictor( &self, - hidden_state: &Array3, // (1, 1, 1024) + hidden_state: &Array3, // (1, 1, 2048) codes: &mut [i64; 16], ) -> Result<(), TtsError> { let group0_embed = self @@ -477,14 +479,15 @@ impl Model { .into_shape_with_order((1, 1, HIDDEN_DIM)) .map_err(|e| TtsError::Synthesis(format!("reshape group0 embed: {e}")))?; - // First call: concat(hidden_state, group0_embed) → (1, 2, 1024) + // First call: concat(hidden_state, group0_embed) → (1, 2, 2048) let first_input = concatenate(Axis(1), &[hidden_state.view(), group0_embed.view()]) .map_err(|e| TtsError::Synthesis(format!("concat cp input: {e}")))?; // Empty KV cache: (5, 1, 8, 0, 128) - let mut cp_past_keys = Array5::::zeros((CP_NUM_LAYERS, 1, NUM_KV_HEADS, 0, HEAD_DIM)); + let mut cp_past_keys = + Array5::::zeros((CP_NUM_LAYERS, 1, CP_NUM_KV_HEADS, 0, HEAD_DIM)); let mut cp_past_values = - Array5::::zeros((CP_NUM_LAYERS, 1, NUM_KV_HEADS, 0, HEAD_DIM)); + Array5::::zeros((CP_NUM_LAYERS, 1, CP_NUM_KV_HEADS, 0, HEAD_DIM)); let mut cp_input = first_input; let mut session = self.code_predictor.lock().unwrap(); @@ -516,8 +519,8 @@ impl Model { let (_, logits_data) = outputs[0] .try_extract_tensor::() .map_err(|e| TtsError::Synthesis(format!("extract cp logits: {e}")))?; - let vocab_size = 2048; - let last_logits = &logits_data[logits_data.len() - vocab_size..]; + let cp_vocab_size = 2048; + let last_logits = &logits_data[logits_data.len() - cp_vocab_size..]; let token = sampler::sample(last_logits, &CP_SAMPLER, &[], sampler::no_mask) as i64; codes[group_idx + 1] = token; @@ -533,12 +536,12 @@ impl Model { .map_err(|e| TtsError::Synthesis(format!("extract cp values: {e}")))?; cp_past_keys = Array5::from_shape_vec( - (CP_NUM_LAYERS, 1, NUM_KV_HEADS, seq_so_far, HEAD_DIM), + (CP_NUM_LAYERS, 1, CP_NUM_KV_HEADS, seq_so_far, HEAD_DIM), keys_data.to_vec(), ) .map_err(|e| TtsError::Synthesis(format!("reshape cp keys: {e}")))?; cp_past_values = Array5::from_shape_vec( - (CP_NUM_LAYERS, 1, NUM_KV_HEADS, 
seq_so_far, HEAD_DIM), + (CP_NUM_LAYERS, 1, CP_NUM_KV_HEADS, seq_so_far, HEAD_DIM), values_data.to_vec(), ) .map_err(|e| TtsError::Synthesis(format!("reshape cp values: {e}")))?; @@ -593,7 +596,7 @@ impl Model { // Helpers // --------------------------------------------------------------------------- -/// SiLU-gated MLP text projection: 2048 → 1024. +/// SiLU-gated MLP text projection: 2048 → 2048. fn text_project( input: &Array1, fc1_weight: &Array2, diff --git a/crates/wavekat-tts/src/backends/qwen3_tts/sampler.rs b/crates/wavekat-tts/src/backends/qwen3_tts/sampler.rs index 6105b62..e1e5c60 100644 --- a/crates/wavekat-tts/src/backends/qwen3_tts/sampler.rs +++ b/crates/wavekat-tts/src/backends/qwen3_tts/sampler.rs @@ -4,14 +4,15 @@ use rand::Rng; #[derive(Debug, Clone)] pub struct SamplerConfig { pub temperature: f32, - pub top_p: f32, + pub top_k: usize, pub repetition_penalty: f32, } -/// Sample a token index from logits using top-p (nucleus) sampling. +/// Sample a token index from logits using top-k sampling with temperature and +/// optional repetition penalty. /// /// `mask` is called with a token index and returns `true` if the token should -/// be suppressed (set to -inf before softmax). +/// be suppressed (forced to -inf before sampling). pub fn sample( logits: &[f32], config: &SamplerConfig, @@ -48,7 +49,16 @@ pub fn sample( } } - // 4. Softmax + // 4. Top-k filtering: zero out all but the k highest scores + if config.top_k > 0 && config.top_k < scores.len() { + let mut indexed: Vec<(usize, f32)> = scores.iter().copied().enumerate().collect(); + indexed.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); + for &(i, _) in &indexed[config.top_k..] { + scores[i] = f32::NEG_INFINITY; + } + } + + // 5. 
Softmax let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max); let mut probs: Vec = scores.iter().map(|&s| (s - max).exp()).collect(); let sum: f32 = probs.iter().sum(); @@ -58,37 +68,24 @@ pub fn sample( } } - // 5. Top-p filtering - let mut indexed: Vec<(usize, f32)> = probs.iter().copied().enumerate().collect(); - indexed.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); - - let mut cumulative = 0.0; - let mut cutoff = indexed.len(); - for (i, &(_, p)) in indexed.iter().enumerate() { - cumulative += p; - if cumulative >= config.top_p { - cutoff = i + 1; - break; - } - } - let candidates = &indexed[..cutoff]; - - // Renormalize - let cand_sum: f32 = candidates.iter().map(|&(_, p)| p).sum(); - // 6. Sample let mut rng = rand::rng(); - let r: f32 = rng.random::() * cand_sum; + let r: f32 = rng.random::(); let mut accum = 0.0; - for &(idx, p) in candidates { + for (i, &p) in probs.iter().enumerate() { accum += p; if accum >= r { - return idx; + return i; } } - // Fallback: return the highest-probability token - candidates[0].0 + // Fallback: highest-probability token + probs + .iter() + .enumerate() + .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal)) + .map(|(i, _)| i) + .unwrap_or(0) } /// Logit mask for the Talker LM (group 0). 
@@ -127,7 +124,7 @@ mod tests { let logits = vec![0.0; 100]; let config = SamplerConfig { temperature: 1.0, - top_p: 0.9, + top_k: 50, repetition_penalty: 1.0, }; let idx = sample(&logits, &config, &[], no_mask); @@ -136,11 +133,10 @@ mod tests { #[test] fn sample_respects_mask() { - // All logits equal, but mask everything except token 5 let logits = vec![0.0; 10]; let config = SamplerConfig { temperature: 1.0, - top_p: 1.0, + top_k: 0, repetition_penalty: 1.0, }; let idx = sample(&logits, &config, &[], |i| i != 5); diff --git a/crates/wavekat-tts/src/backends/qwen3_tts/tokenizer.rs b/crates/wavekat-tts/src/backends/qwen3_tts/tokenizer.rs index 6f5b703..3d70ba7 100644 --- a/crates/wavekat-tts/src/backends/qwen3_tts/tokenizer.rs +++ b/crates/wavekat-tts/src/backends/qwen3_tts/tokenizer.rs @@ -33,10 +33,10 @@ pub struct Tokenizer { } impl Tokenizer { - /// Build tokenizer from `vocab.json` + `merges.txt` in `model_dir`. + /// Build tokenizer from `tokenizer/vocab.json` + `tokenizer/merges.txt` in `model_dir`. 
pub fn new(model_dir: &Path) -> Result { - let vocab_path = model_dir.join("vocab.json"); - let merges_path = model_dir.join("merges.txt"); + let vocab_path = model_dir.join("tokenizer").join("vocab.json"); + let merges_path = model_dir.join("tokenizer").join("merges.txt"); let bpe = tokenizers::models::bpe::BPE::from_file( &vocab_path.to_string_lossy(), From e031f0b390bb265107a55b27ebf3be83616f9ded Mon Sep 17 00:00:00 2001 From: Eason WaveKat Date: Mon, 6 Apr 2026 20:37:23 +1200 Subject: [PATCH 02/15] docs: add 1.7B VoiceDesign migration notes Co-Authored-By: Claude Sonnet 4.6 --- docs/04-qwen3-tts-1.7b-migration.md | 177 ++++++++++++++++++++++++++++ 1 file changed, 177 insertions(+) create mode 100644 docs/04-qwen3-tts-1.7b-migration.md diff --git a/docs/04-qwen3-tts-1.7b-migration.md b/docs/04-qwen3-tts-1.7b-migration.md new file mode 100644 index 0000000..925ea9a --- /dev/null +++ b/docs/04-qwen3-tts-1.7b-migration.md @@ -0,0 +1,177 @@ +# Qwen3-TTS: 1.7B VoiceDesign Migration + +Documents the changes made in `feat/qwen3-tts-1.7b-int4` to migrate the +Qwen3-TTS backend from the third-party 0.6B ONNX repo to the WaveKat-owned +1.7B VoiceDesign ONNX repo with INT4 quantization by default. + +## What changed and why + +| | Before | After | +|---|---|---| +| **Model** | Qwen3-TTS-12Hz-0.6B-Base | Qwen3-TTS-12Hz-1.7B-VoiceDesign | +| **HF repo** | `elbruno/Qwen3-TTS-12Hz-0.6B-Base-ONNX` | `wavekat/Qwen3-TTS-1.7B-VoiceDesign-ONNX` | +| **ONNX variant** | FP32 | INT4 weight-only (RTN, block=128) | +| **Download** | Custom HTTP + manual cache | `hf-hub` crate (standard HF cache) | +| **Prefill mode** | Streaming (text fed per decode step) | Non-streaming (all text in prefill) | +| **Sampling** | top-p | top-k | + +The 1.7B VoiceDesign model produces significantly higher quality audio, particularly +for longer utterances and non-English languages. The INT4 models are ~4× smaller than +FP32 (talker: 5.3 GB → 1.4 GB) with negligible quality loss. 
+ +## HF repo structure + +`wavekat/Qwen3-TTS-1.7B-VoiceDesign-ONNX` mirrors the export output from +`tools/qwen3-tts-onnx/`. The Rust backend downloads and uses only the INT4 variant: + +``` +{hf_snapshot_root}/ +├── config.json # model dimensions, token IDs, sampling config +├── int4/ # INT4 weight-only quantized ONNX models +│ ├── talker_prefill.onnx + .onnx.data (~1.4 GB) +│ ├── talker_decode.onnx + .onnx.data (~1.4 GB) +│ ├── code_predictor.onnx + .onnx.data (~322 MB) +│ └── vocoder.onnx + .onnx.data (~558 MB) +├── embeddings/ # pre-extracted weights as .npy +│ ├── text_embedding.npy +│ ├── text_projection_fc1_{weight,bias}.npy +│ ├── text_projection_fc2_{weight,bias}.npy +│ ├── talker_codec_embedding.npy +│ └── cp_codec_embedding_{0..14}.npy +└── tokenizer/ + ├── vocab.json + └── merges.txt +``` + +`embeddings/small_to_mtp_projection_{weight,bias}.npy` also exist in the repo +but are not downloaded by the Rust backend — the `small_to_mtp` projection is +baked into `code_predictor.onnx`. + +## Download: hf-hub + +The `ureq` dependency and custom HTTP download logic were replaced with +`hf-hub = "0.3"` (the official Rust HF Hub client). + +`download::ensure_model_dir()` calls `repo.get(filename)` for each required file. +HF Hub handles caching, LFS resolution, and authentication transparently. + +**Cache location**: `$HF_HOME/hub/` (default `~/.cache/huggingface/hub/`). +The function returns the snapshot root directory, which has the repo's subdirectory +layout intact. 
+ +**Environment variables**: +| Variable | Purpose | +|---|---| +| `WAVEKAT_MODEL_DIR` | Skip HF Hub; load from this local path directly | +| `HF_TOKEN` | Authentication for private repos | +| `HF_HOME` | Override cache root | + +## Model dimensions (1.7B vs 0.6B) + +| Parameter | 0.6B | 1.7B | +|---|---|---| +| `HIDDEN_DIM` | 1024 | **2048** | +| `NUM_LAYERS` | 28 | 28 | +| `NUM_KV_HEADS` | 8 | 8 | +| `HEAD_DIM` | 128 | 128 | +| `CP_NUM_LAYERS` | 5 | 5 | +| `CP_NUM_KV_HEADS` | 8 | 8 | +| `MAX_NEW_TOKENS` | 2048 | **8192** | + +Architecturally, only `HIDDEN_DIM` changes (1024 → 2048); `MAX_NEW_TOKENS` (2048 → 8192) is a +generation limit, not a model dimension. All KV cache shapes, the code predictor architecture, and the vocoder are identical between the two sizes. + +## Sampling (top-k replacing top-p) + +`SamplerConfig` now has `top_k: usize` instead of `top_p: f32`. +Values from `config.json`: + +| | Talker (group 0) | Code Predictor (groups 1-15) | +|---|---|---| +| `temperature` | 0.9 | 0.9 | +| `top_k` | 50 | 50 | +| `repetition_penalty` | 1.05 | 1.0 | + +The previous 0.6B values (temp=0.7, top_p=0.8 for talker; temp=0.2, top_p=0.5 +for CP) were derived from an older reference and did not match the official +`config.json`. + +Additionally, `min_new_tokens=2` is now enforced: CODEC_EOS is masked for the +first two decode steps, matching `generate_onnx.py`. + +## Prefill: non-streaming mode + +The 0.6B implementation used a **streaming** prefill: only the first text token +was included in the prefill; the rest were fed one-per-decode-step as a "trailing" +vector added to the codec embedding sum. + +The 1.7B implementation uses **non-streaming** prefill to match `generate_onnx.py`. 
+All text tokens are embedded in the prefill sequence: + +``` +[im_start, assistant, \n] — role (text_proj only) +[think, think_bos, lang_id, think_eos] — codec prefix (tts_pad + codec_embed) +[tts_bos + codec_pad] — transition +[text_proj(tok) + codec_pad] × N — all text tokens +[text_proj(TTS_EOS) + codec_pad] — TTS_EOS +[tts_pad + codec_bos] — final +``` + +The decode-loop "trailing" vector is always `tts_pad_embed` (a constant) rather than +per-step text projections. + +Non-streaming gives the Talker LM full visibility over the complete text from the +first token, which improves prosody and accuracy on longer inputs. + +**Note**: the 0.6B codec prefix had 5 tokens (included a speaker slot = CODEC_PAD). +VoiceDesign has no predefined speakers, so the codec prefix is 4 tokens. + +## Bug fix: prefill logits extraction + +`run_talker_prefill` previously returned `logits_data.to_vec()` which includes +logits for **all** T prefill positions — a flat vector of T × 3072 elements. +The sampler then treated the entire T × 3072 vector as a single vocabulary, producing wrong samples. + +Fixed by slicing only the last position: + +```rust +let logits: Vec<f32> = logits_data[logits_data.len() - TALKER_VOCAB_SIZE..].to_vec(); +``` + +The decode step was unaffected (it always returns shape `(1, 1, 3072)` = 3072 elements). + +## code_predictor.onnx and small_to_mtp + +For the 1.7B model, the Talker hidden size is 2048 but the Code Predictor +transformer runs at 1024. The `small_to_mtp_projection` (Linear 2048 → 1024) +is **baked into `code_predictor.onnx`** rather than applied in host code. 
+ +As a result: +- Host code passes `(1, 2, 2048)` to the code predictor on the first call + (concat of talker hidden and group-0 codec embedding, both 2048-dim) +- Subsequent calls pass `(1, 1, 2048)` from `cp_codec_embeddings[g]` (shape 2048×2048) +- No host-side projection step is needed + +`small_to_mtp_projection_{weight,bias}.npy` are present in the HF repo for +reference but are not downloaded or used by the Rust backend. + +## Usage + +```bash +# Auto-download via HF Hub and synthesize +cargo run --example synthesize --features qwen3-tts,hound -- "Hello, world!" + +# Interactive mode +cargo run --example synthesize --features qwen3-tts,hound -- -i + +# Load from manually downloaded snapshot +WAVEKAT_MODEL_DIR=/path/to/snapshot \ + cargo run --example synthesize --features qwen3-tts,hound -- "Hello" + +# Non-English +cargo run --example synthesize --features qwen3-tts,hound -- \ + --language chinese "让每一家小企业,都拥有大企业的声音。" +``` + +The first run downloads ~3.7 GB of INT4 ONNX files and ~0.6 GB of embeddings. +Subsequent runs load directly from the HF Hub cache. 
From efabea4e9cf3669bf9de2a4d0a114d9bf91079a1 Mon Sep 17 00:00:00 2001 From: Eason WaveKat Date: Mon, 6 Apr 2026 21:58:08 +1200 Subject: [PATCH 03/15] feat: VoiceDesign instruction API + interactive controls MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add `instruction` field to `SynthesizeRequest` with `with_instruction()` - Qwen3-TTS backend builds user-turn prefix from instruction tokens; warns on stderr when no instruction is provided - Upgrade hf-hub 0.3 → 0.5 (fixes relative-Location redirect bug) - synthesize example: default instruction, /lang, /langs, /instruct, /status, /help commands for live session control Co-Authored-By: Claude Sonnet 4.6 --- README.md | 2 +- crates/wavekat-tts/Cargo.toml | 2 +- crates/wavekat-tts/examples/synthesize.rs | 130 +++++++++++++++--- .../src/backends/qwen3_tts/download.rs | 23 +++- .../wavekat-tts/src/backends/qwen3_tts/mod.rs | 28 +++- .../src/backends/qwen3_tts/model.rs | 57 +++++--- .../src/backends/qwen3_tts/tokenizer.rs | 1 + crates/wavekat-tts/src/lib.rs | 8 +- crates/wavekat-tts/src/types.rs | 34 ++++- docs/04-qwen3-tts-1.7b-migration.md | 22 ++- tools/qwen3-tts-onnx/README.md | 10 +- 11 files changed, 252 insertions(+), 65 deletions(-) diff --git a/README.md b/README.md index ccfe448..6eeba86 100644 --- a/README.md +++ b/README.md @@ -73,7 +73,7 @@ Generate a WAV file from text (model files are auto-downloaded on first run): ```sh cargo run --example synthesize --features qwen3-tts,hound -- "Hello, world\!" -cargo run --example synthesize --features qwen3-tts,hound -- --language zh "你好世界" +cargo run --example synthesize --features qwen3-tts,hound -- --instruction "Speak in a warm, friendly tone." "Give every small business the voice of a big one." 
cargo run --example synthesize --features qwen3-tts,hound -- --model-dir /path/to/model --output hello.wav "Hello" ``` diff --git a/crates/wavekat-tts/Cargo.toml b/crates/wavekat-tts/Cargo.toml index 3158dd2..af3c4f0 100644 --- a/crates/wavekat-tts/Cargo.toml +++ b/crates/wavekat-tts/Cargo.toml @@ -27,7 +27,7 @@ ndarray = { version = "0.17", optional = true } tokenizers = { version = "0.21", optional = true, default-features = false, features = ["onig"] } npyz = { version = "0.8", optional = true } rand = { version = "0.9", optional = true } -hf-hub = { version = "0.3", optional = true } +hf-hub = { version = "0.5", optional = true, default-features = false, features = ["ureq"] } hound = { version = "3.5", optional = true } [dev-dependencies] diff --git a/crates/wavekat-tts/examples/synthesize.rs b/crates/wavekat-tts/examples/synthesize.rs index 0b32b2a..263ccb9 100644 --- a/crates/wavekat-tts/examples/synthesize.rs +++ b/crates/wavekat-tts/examples/synthesize.rs @@ -4,10 +4,21 @@ //! cargo run --example synthesize --features qwen3-tts,hound -- [OPTIONS] [TEXT] //! //! Options: -//! --model-dir Model directory (default: auto-download to cache) -//! --language Language code (default: en) -//! --output Output WAV path (default: output.wav) -//! -i, --interactive Interactive mode: keep model loaded, read text from stdin +//! --model-dir Model directory (default: auto-download to cache) +//! --language Language code (default: en) +//! --instruction Voice style instruction (VoiceDesign prompt) +//! Default: "Speak naturally and clearly." +//! --output Output WAV path (default: output.wav) +//! -i, --interactive Interactive mode: keep model loaded, read text from stdin +//! +//! Interactive commands (prefix with /): +//! /lang Switch language (e.g. /lang ja) +//! /langs List supported language codes +//! /instruct Change voice instruction (e.g. /instruct Speak slowly.) +//! /instruct Reset instruction to default +//! /status Show current settings +//! 
/help Show this command list +//! Empty line or Ctrl-D Quit //! //! Example: //! cargo run --example synthesize --features qwen3-tts,hound -- "Hello, world!" @@ -19,11 +30,14 @@ use std::path::PathBuf; use wavekat_tts::backends::qwen3_tts::Qwen3Tts; use wavekat_tts::{SynthesizeRequest, TtsBackend}; +const DEFAULT_INSTRUCTION: &str = "Speak naturally and clearly."; + fn main() { let args: Vec = std::env::args().skip(1).collect(); let mut model_dir: Option = None; let mut language = "en".to_string(); + let mut instruction: Option = None; let mut output = PathBuf::from("output.wav"); let mut interactive = false; let mut text_parts: Vec = Vec::new(); @@ -39,6 +53,10 @@ fn main() { i += 1; language = args[i].clone(); } + "--instruction" => { + i += 1; + instruction = Some(args[i].clone()); + } "--output" => { i += 1; output = PathBuf::from(&args[i]); @@ -52,13 +70,20 @@ fn main() { let text = text_parts.join(" "); if text.is_empty() && !interactive { eprintln!("Usage: synthesize [OPTIONS] [TEXT]"); - eprintln!(" --model-dir Model directory (default: auto-download)"); - eprintln!(" --language Language code (default: en)"); - eprintln!(" --output Output WAV path (default: output.wav)"); - eprintln!(" -i, --interactive Interactive mode (read from stdin)"); + eprintln!(" --model-dir Model directory (default: auto-download)"); + eprintln!(" --language Language code (default: en)"); + eprintln!(" --instruction Voice style instruction (VoiceDesign prompt)"); + eprintln!(" Default: \"{DEFAULT_INSTRUCTION}\""); + eprintln!(" --output Output WAV path (default: output.wav)"); + eprintln!(" -i, --interactive Interactive mode (read from stdin)"); std::process::exit(1); } + if instruction.is_none() { + eprintln!("note: no --instruction given, using default: \"{DEFAULT_INSTRUCTION}\""); + instruction = Some(DEFAULT_INSTRUCTION.to_string()); + } + eprintln!("Loading model ..."); let tts = match model_dir { Some(dir) => Qwen3Tts::from_dir(dir).expect("failed to load model"), @@ -66,14 
+91,28 @@ fn main() { }; if interactive { - run_interactive(&tts, &language, &output); + run_interactive(&tts, language, instruction.unwrap(), &output); } else { - synthesize_one(&tts, &text, &language, &output); + synthesize_one(&tts, &text, &language, instruction.as_deref(), &output); } } -fn run_interactive(tts: &Qwen3Tts, language: &str, default_output: &PathBuf) { - eprintln!("Interactive mode. Type text to synthesize, empty line to quit."); +fn run_interactive( + tts: &Qwen3Tts, + mut language: String, + mut instruction: String, + default_output: &PathBuf, +) { + let supported_langs: Vec = tts + .voices() + .unwrap_or_default() + .into_iter() + .flat_map(|v| v.languages) + .collect(); + + eprintln!("Interactive mode. Type text to synthesize, /help for commands, empty line to quit."); + eprintln!(" language={language} instruction=\"{instruction}\""); + let stdin = io::stdin(); let mut count = 0u32; @@ -85,11 +124,58 @@ fn run_interactive(tts: &Qwen3Tts, language: &str, default_output: &PathBuf) { if stdin.lock().read_line(&mut line).unwrap_or(0) == 0 { break; } - let text = line.trim(); - if text.is_empty() { + let input = line.trim(); + if input.is_empty() { break; } + if let Some(rest) = input.strip_prefix('/') { + let (cmd, arg) = rest + .split_once(' ') + .map_or((rest, ""), |(c, a)| (c, a.trim())); + match cmd { + "lang" | "language" => { + if arg.is_empty() { + eprintln!("usage: /lang — type /langs to list supported codes"); + } else if !supported_langs.iter().any(|l| l == arg) { + eprintln!("unsupported language: \"{arg}\""); + eprintln!("supported: {}", supported_langs.join(", ")); + } else { + language = arg.to_string(); + eprintln!("language set to: {language}"); + } + } + "langs" | "languages" => { + eprintln!("supported languages: {}", supported_langs.join(", ")); + } + "instruct" | "instruction" => { + if arg.is_empty() { + instruction = DEFAULT_INSTRUCTION.to_string(); + eprintln!("instruction reset to default: \"{instruction}\""); + } else { + 
instruction = arg.to_string(); + eprintln!("instruction set to: \"{instruction}\""); + } + } + "status" => { + eprintln!(" language={language}"); + eprintln!(" instruction=\"{instruction}\""); + eprintln!(" supported languages: {}", supported_langs.join(", ")); + } + "help" => { + eprintln!(" /lang Switch language"); + eprintln!(" /langs List supported language codes"); + eprintln!(" /instruct Change voice instruction"); + eprintln!(" /instruct Reset instruction to default"); + eprintln!(" /status Show current settings"); + eprintln!(" /help Show this help"); + eprintln!(" Empty line Quit"); + } + other => eprintln!("unknown command: /{other} (type /help for commands)"), + } + continue; + } + count += 1; let output = if *default_output == PathBuf::from("output.wav") { PathBuf::from(format!("output_{count:03}.wav")) @@ -97,12 +183,21 @@ fn run_interactive(tts: &Qwen3Tts, language: &str, default_output: &PathBuf) { default_output.clone() }; - synthesize_one(tts, text, language, &output); + synthesize_one(tts, input, &language, Some(&instruction), &output); } } -fn synthesize_one(tts: &Qwen3Tts, text: &str, language: &str, output: &PathBuf) { - let request = SynthesizeRequest::new(text).with_language(language); +fn synthesize_one( + tts: &Qwen3Tts, + text: &str, + language: &str, + instruction: Option<&str>, + output: &PathBuf, +) { + let mut request = SynthesizeRequest::new(text).with_language(language); + if let Some(instr) = instruction { + request = request.with_instruction(instr); + } eprintln!("Synthesizing: \"{text}\" (language={language})"); let start = std::time::Instant::now(); @@ -121,7 +216,6 @@ fn synthesize_one(tts: &Qwen3Tts, text: &str, language: &str, output: &PathBuf) rtf, ); - // Write WAV let spec = hound::WavSpec { channels: 1, sample_rate: audio.sample_rate(), diff --git a/crates/wavekat-tts/src/backends/qwen3_tts/download.rs b/crates/wavekat-tts/src/backends/qwen3_tts/download.rs index 1b746c5..fea7417 100644 --- 
a/crates/wavekat-tts/src/backends/qwen3_tts/download.rs +++ b/crates/wavekat-tts/src/backends/qwen3_tts/download.rs @@ -2,15 +2,18 @@ use std::path::PathBuf; -use hf_hub::api::sync::Api; +use hf_hub::api::sync::ApiBuilder; use hf_hub::{Repo, RepoType}; use crate::TtsError; const REPO_ID: &str = "wavekat/Qwen3-TTS-1.7B-VoiceDesign-ONNX"; -const REVISION: &str = "62c7863a68800d72bcee4f2931148c441571ace7"; +const REVISION: &str = "2026-04-06"; /// Files required for INT4 inference (ONNX models + embeddings + tokenizer). +/// +/// `embeddings/small_to_mtp_projection_{weight,bias}.npy` are intentionally +/// excluded — that projection is baked into `code_predictor.onnx`. const MODEL_FILES: &[&str] = &[ "config.json", // INT4 ONNX models @@ -55,14 +58,26 @@ const MODEL_FILES: &[&str] = &[ /// Set `WAVEKAT_MODEL_DIR` to skip HF Hub and load from a local directory /// that mirrors the repo layout (`int4/`, `embeddings/`, `tokenizer/`). /// -/// Authentication: set `HF_TOKEN` if the repo requires it. +/// Authentication: set `HF_TOKEN` if the repo requires it. hf-hub 0.5 does +/// not read `HF_TOKEN` from the environment natively; this function bridges +/// the gap by passing it to `ApiBuilder::with_token`. +/// /// Cache location: `$HF_HOME/hub/` (default `~/.cache/huggingface/hub/`). pub fn ensure_model_dir() -> Result { if let Ok(dir) = std::env::var("WAVEKAT_MODEL_DIR") { return Ok(PathBuf::from(dir)); } - let api = Api::new() + // from_env() reads HF_HOME / HF_ENDPOINT. + // Bridge HF_TOKEN which hf-hub doesn't read from the environment natively. 
+ let mut builder = ApiBuilder::from_env(); + if let Ok(token) = std::env::var("HF_TOKEN") { + if !token.is_empty() { + builder = builder.with_token(Some(token)); + } + } + let api = builder + .build() .map_err(|e| TtsError::Model(format!("failed to initialize HF Hub client: {e}")))?; let repo = api.repo(Repo::with_revision( diff --git a/crates/wavekat-tts/src/backends/qwen3_tts/mod.rs b/crates/wavekat-tts/src/backends/qwen3_tts/mod.rs index 9cf24d9..e14052c 100644 --- a/crates/wavekat-tts/src/backends/qwen3_tts/mod.rs +++ b/crates/wavekat-tts/src/backends/qwen3_tts/mod.rs @@ -25,6 +25,8 @@ use crate::error::TtsError; use crate::traits::TtsBackend; use crate::types::{SynthesizeRequest, VoiceInfo}; +use tokenizer::{IM_END, IM_START, NEWLINE}; + mod download; mod model; mod sampler; @@ -67,7 +69,31 @@ impl TtsBackend for Qwen3Tts { fn synthesize(&self, request: &SynthesizeRequest) -> Result, TtsError> { let tokens = self.tokenizer.encode(request.text)?; let language = request.language.unwrap_or("en"); - self.model.synthesize(&tokens, language) + + if request.instruction.is_none() { + eprintln!( + "wavekat-tts warning: Qwen3-TTS is a VoiceDesign model — \ + synthesize quality may be inconsistent without a style instruction. \ + Set `SynthesizeRequest::with_instruction` to control voice style." 
+ ); + } + + let instruction_tokens = if let Some(instr) = request.instruction { + let mut toks = vec![IM_START]; + toks.extend(self.tokenizer.encode("user")?); + toks.push(NEWLINE); + toks.extend(self.tokenizer.encode("")?); + toks.extend(self.tokenizer.encode(instr)?); + toks.extend(self.tokenizer.encode("")?); + toks.push(IM_END); + toks.push(NEWLINE); + Some(toks) + } else { + None + }; + + self.model + .synthesize(&tokens, language, instruction_tokens.as_deref()) } fn voices(&self) -> Result, TtsError> { diff --git a/crates/wavekat-tts/src/backends/qwen3_tts/model.rs b/crates/wavekat-tts/src/backends/qwen3_tts/model.rs index 5008f34..e597c12 100644 --- a/crates/wavekat-tts/src/backends/qwen3_tts/model.rs +++ b/crates/wavekat-tts/src/backends/qwen3_tts/model.rs @@ -90,18 +90,14 @@ impl Model { let code_predictor = load_session("code_predictor.onnx")?; let vocoder = load_session("vocoder.onnx")?; - let text_embedding = - load_npy2(model_dir, "embeddings/text_embedding.npy")?; + let text_embedding = load_npy2(model_dir, "embeddings/text_embedding.npy")?; let text_proj_fc1_weight = load_npy2(model_dir, "embeddings/text_projection_fc1_weight.npy")?; - let text_proj_fc1_bias = - load_npy1(model_dir, "embeddings/text_projection_fc1_bias.npy")?; + let text_proj_fc1_bias = load_npy1(model_dir, "embeddings/text_projection_fc1_bias.npy")?; let text_proj_fc2_weight = load_npy2(model_dir, "embeddings/text_projection_fc2_weight.npy")?; - let text_proj_fc2_bias = - load_npy1(model_dir, "embeddings/text_projection_fc2_bias.npy")?; - let talker_codec_embedding = - load_npy2(model_dir, "embeddings/talker_codec_embedding.npy")?; + let text_proj_fc2_bias = load_npy1(model_dir, "embeddings/text_projection_fc2_bias.npy")?; + let talker_codec_embedding = load_npy2(model_dir, "embeddings/talker_codec_embedding.npy")?; let mut cp_codec_embeddings = Vec::with_capacity(NUM_CP_GROUPS); for i in 0..NUM_CP_GROUPS { @@ -137,15 +133,20 @@ impl Model { } /// Run the full synthesis pipeline: 
prefill → decode → code predict → vocoder. + /// + /// `instruction_tokens` — optional user-turn prefix for VoiceDesign control. + /// When `Some`, these tokens are embedded (text_proj only) at the start of + /// the prefill before the role prefix. pub fn synthesize( &self, text_tokens: &[u32], language: &str, + instruction_tokens: Option<&[u32]>, ) -> Result, TtsError> { let lang_id = tokenizer::language_id(language) .ok_or_else(|| TtsError::UnsupportedLanguage(language.to_string()))?; - let prefill_embeds = self.build_prefill_embeds(text_tokens, lang_id)?; + let prefill_embeds = self.build_prefill_embeds(text_tokens, lang_id, instruction_tokens)?; let prefill_len = prefill_embeds.shape()[1]; // Run talker prefill — returns last-position logits @@ -230,31 +231,49 @@ impl Model { /// Build prefill embeddings (non-streaming: all text embedded in prefill). /// /// ```text - /// [im_start, assistant, \n] — role prefix (text proj only) - /// [think, think_bos, lang_id, think_eos] — codec prefix (tts_pad + codec_embed) - /// [tts_bos + codec_pad] — transition - /// [text_proj(tok) + codec_pad] × N — all text tokens - /// [text_proj(TTS_EOS) + codec_pad] — TTS_EOS - /// [tts_pad + codec_bos] — final + /// [instr_tok × M] — user turn (text_proj only, optional) + /// [im_start, assistant, \n] — role prefix (text proj only) + /// [think, think_bos, lang_id, think_eos] — codec prefix (tts_pad + codec_embed) + /// [tts_bos + codec_pad] — transition + /// [text_proj(tok) + codec_pad] × N — all text tokens + /// [text_proj(TTS_EOS) + codec_pad] — TTS_EOS + /// [tts_pad + codec_bos] — final /// ``` fn build_prefill_embeds( &self, text_tokens: &[u32], lang_id: i64, + instruction_tokens: Option<&[u32]>, ) -> Result, TtsError> { - let codec_pad_embed = self.talker_codec_embedding.row(CODEC_PAD as usize).to_owned(); - let codec_bos_embed = self.talker_codec_embedding.row(CODEC_BOS as usize).to_owned(); + let codec_pad_embed = self + .talker_codec_embedding + .row(CODEC_PAD as usize) 
+ .to_owned(); + let codec_bos_embed = self + .talker_codec_embedding + .row(CODEC_BOS as usize) + .to_owned(); let tts_bos_embed = self.text_project_token(TTS_BOS); let tts_eos_embed = self.text_project_token(TTS_EOS); // VoiceDesign codec prefix — no speaker slot let codec_prefix = [CODEC_THINK, CODEC_THINK_BOS, lang_id, CODEC_THINK_EOS]; - // 3 role + 4 codec_prefix + 1 transition + N text + 1 TTS_EOS + 1 final - let seq_len = 3 + codec_prefix.len() + 1 + text_tokens.len() + 1 + 1; + let instr_len = instruction_tokens.map_or(0, |t| t.len()); + // M instruction + 3 role + 4 codec_prefix + 1 transition + N text + 1 TTS_EOS + 1 final + let seq_len = instr_len + 3 + codec_prefix.len() + 1 + text_tokens.len() + 1 + 1; let mut embeds = Array3::::zeros((1, seq_len, HIDDEN_DIM)); let mut pos = 0; + // 0. Instruction / user-turn tokens: text_proj only (VoiceDesign control) + if let Some(instr_toks) = instruction_tokens { + for &tok in instr_toks { + let embed = self.text_project_token(tok); + embeds.slice_mut(s![0, pos, ..]).assign(&embed); + pos += 1; + } + } + // 1. Role prefix: text_proj only for &tok in &[IM_START, ASSISTANT, NEWLINE] { let embed = self.text_project_token(tok); diff --git a/crates/wavekat-tts/src/backends/qwen3_tts/tokenizer.rs b/crates/wavekat-tts/src/backends/qwen3_tts/tokenizer.rs index 3d70ba7..7f8dc32 100644 --- a/crates/wavekat-tts/src/backends/qwen3_tts/tokenizer.rs +++ b/crates/wavekat-tts/src/backends/qwen3_tts/tokenizer.rs @@ -4,6 +4,7 @@ use crate::TtsError; // Special text token IDs (Qwen3 BPE vocab) pub const IM_START: u32 = 151644; +pub const IM_END: u32 = 151645; pub const ASSISTANT: u32 = 77091; pub const NEWLINE: u32 = 198; pub const TTS_BOS: u32 = 151672; diff --git a/crates/wavekat-tts/src/lib.rs b/crates/wavekat-tts/src/lib.rs index 3bec35b..eed98de 100644 --- a/crates/wavekat-tts/src/lib.rs +++ b/crates/wavekat-tts/src/lib.rs @@ -22,10 +22,10 @@ //! //! # Feature flags //! -//! | Feature | Backend | Chinese | Requires | -//! 
|---------|---------|---------|----------| -//! | `qwen3-tts` | Qwen3-TTS (ONNX) | Excellent | ONNX model download | -//! | `cosyvoice` | CosyVoice (ONNX) | Excellent | ONNX model download | +//! | Feature | Backend | Multilingual | Requires | +//! |---------|---------|-------------|----------| +//! | `qwen3-tts` | Qwen3-TTS (ONNX) | 10 languages | ONNX model download | +//! | `cosyvoice` | CosyVoice (ONNX) | Yes | ONNX model download | //! //! # Quick start //! diff --git a/crates/wavekat-tts/src/types.rs b/crates/wavekat-tts/src/types.rs index 7685562..847d16f 100644 --- a/crates/wavekat-tts/src/types.rs +++ b/crates/wavekat-tts/src/types.rs @@ -1,7 +1,8 @@ /// A TTS synthesis request. /// /// Backend-agnostic parameters that describe what to synthesize. -/// Each backend interprets `voice` and `language` according to its own catalog. +/// Each backend interprets `voice`, `instruction`, and `language` according to +/// its own capabilities; unsupported fields are silently ignored. #[derive(Debug, Clone)] pub struct SynthesizeRequest<'a> { /// Text to synthesize. @@ -9,14 +10,30 @@ pub struct SynthesizeRequest<'a> { /// Voice identifier (backend-specific). /// - /// For Edge-TTS: `"zh-CN-XiaoxiaoNeural"`, `"zh-CN-YunxiNeural"`, etc. - /// For Kokoro: `"af_heart"`, `"zf_xiaobei"`, etc. + /// Used by backends with a fixed speaker catalog: + /// - Edge-TTS: `"zh-CN-XiaoxiaoNeural"`, `"zh-CN-YunxiNeural"`, … + /// - Kokoro: `"af_heart"`, `"zf_xiaobei"`, … + /// /// `None` uses the backend's default voice. pub voice: Option<&'a str>, + /// Free-form voice instruction / style prompt. + /// + /// Used by instruction-following backends (e.g. Qwen3-TTS VoiceDesign). + /// The text describes how the model should speak: + /// + /// ```text + /// "Speak in a calm, professional tone." + /// "Narrate with warmth and a gentle pace." + /// "Respond with high energy and enthusiasm!" + /// ``` + /// + /// `None` lets the backend use its default voice character. 
+ pub instruction: Option<&'a str>, + /// Language / locale code. /// - /// E.g. `"zh-CN"`, `"en-US"`, `"ja-JP"`. + /// E.g. `"zh"`, `"en"`, `"ja"`. /// `None` uses the backend's default or auto-detects. pub language: Option<&'a str>, @@ -33,17 +50,24 @@ impl<'a> SynthesizeRequest<'a> { Self { text, voice: None, + instruction: None, language: None, speed: None, } } - /// Set the voice. + /// Set the voice identifier. pub fn with_voice(mut self, voice: &'a str) -> Self { self.voice = Some(voice); self } + /// Set the voice instruction / style prompt. + pub fn with_instruction(mut self, instruction: &'a str) -> Self { + self.instruction = Some(instruction); + self + } + /// Set the language. pub fn with_language(mut self, language: &'a str) -> Self { self.language = Some(language); diff --git a/docs/04-qwen3-tts-1.7b-migration.md b/docs/04-qwen3-tts-1.7b-migration.md index 925ea9a..2337083 100644 --- a/docs/04-qwen3-tts-1.7b-migration.md +++ b/docs/04-qwen3-tts-1.7b-migration.md @@ -50,10 +50,12 @@ baked into `code_predictor.onnx`. ## Download: hf-hub The `ureq` dependency and custom HTTP download logic were replaced with -`hf-hub = "0.3"` (the official Rust HF Hub client). +`hf-hub = "0.5"` (the official Rust HF Hub client). -`download::ensure_model_dir()` calls `repo.get(filename)` for each required file. -HF Hub handles caching, LFS resolution, and authentication transparently. +`download::ensure_model_dir()` calls `repo.get(filename)` for each file in +the hardcoded `MODEL_FILES` list. HF Hub handles caching, LFS resolution, and +redirect following transparently. hf-hub 0.5 fixes the relative-`Location` +redirect handling that broke 0.3 with HuggingFace's `/api/resolve-cache/` backend. **Cache location**: `$HF_HOME/hub/` (default `~/.cache/huggingface/hub/`). The function returns the snapshot root directory, which has the repo's subdirectory @@ -63,8 +65,14 @@ layout intact. 
| Variable | Purpose | |---|---| | `WAVEKAT_MODEL_DIR` | Skip HF Hub; load from this local path directly | -| `HF_TOKEN` | Authentication for private repos | -| `HF_HOME` | Override cache root | +| `HF_TOKEN` | Authentication for private/gated repos | +| `HF_HOME` | Override cache root (also sets the token file location) | +| `HF_ENDPOINT` | Override the HuggingFace endpoint URL | + +> **Note on `HF_TOKEN`**: hf-hub 0.5 does not natively read `HF_TOKEN` +> from the environment (it reads `$HF_HOME/token`, written by +> `huggingface-cli login`). `ensure_model_dir()` bridges this by passing +> `HF_TOKEN` to `ApiBuilder::with_token` when the env var is set. ## Model dimensions (1.7B vs 0.6B) @@ -168,9 +176,9 @@ cargo run --example synthesize --features qwen3-tts,hound -- -i WAVEKAT_MODEL_DIR=/path/to/snapshot \ cargo run --example synthesize --features qwen3-tts,hound -- "Hello" -# Non-English +# With a VoiceDesign instruction cargo run --example synthesize --features qwen3-tts,hound -- \ - --language chinese "让每一家小企业,都拥有大企业的声音。" + --instruction "Speak in a calm, professional tone." "Hello, world!" ``` The first run downloads ~3.7 GB of INT4 ONNX files and ~0.6 GB of embeddings. diff --git a/tools/qwen3-tts-onnx/README.md b/tools/qwen3-tts-onnx/README.md index f6568c7..de5d9fc 100644 --- a/tools/qwen3-tts-onnx/README.md +++ b/tools/qwen3-tts-onnx/README.md @@ -51,11 +51,11 @@ python generate_onnx.py --variant int4 \ --instruct "Speak in a warm and friendly female voice" \ -o output_int4.wav -# Chinese -python generate_onnx.py --variant int4 --lang chinese \ - --text "让每一家小企业,都拥有大企业的声音。" \ - --instruct "Speak in a warm and professional female voice" \ - -o output_zh.wav +# With a custom style instruction +python generate_onnx.py --variant int4 \ + --text "Give every small business the voice of a big one." 
\ + --instruct "Speak with confidence and warmth" \ + -o output_styled.wav ``` ## Model Architecture From 3db024fe02c3d70c637774172731a59c9a16ce94 Mon Sep 17 00:00:00 2001 From: Eason WaveKat Date: Mon, 6 Apr 2026 22:04:56 +1200 Subject: [PATCH 04/15] feat: show model loading progress after download Co-Authored-By: Claude Sonnet 4.6 --- .../src/backends/qwen3_tts/download.rs | 2 +- .../wavekat-tts/src/backends/qwen3_tts/model.rs | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/crates/wavekat-tts/src/backends/qwen3_tts/download.rs b/crates/wavekat-tts/src/backends/qwen3_tts/download.rs index fea7417..59f9a92 100644 --- a/crates/wavekat-tts/src/backends/qwen3_tts/download.rs +++ b/crates/wavekat-tts/src/backends/qwen3_tts/download.rs @@ -106,6 +106,6 @@ pub fn ensure_model_dir() -> Result { .map_err(|e| TtsError::Model(format!("failed to download {filename}: {e}")))?; } - eprintln!("Model ready."); + eprintln!("Files ready. Loading model ..."); Ok(model_dir) } diff --git a/crates/wavekat-tts/src/backends/qwen3_tts/model.rs b/crates/wavekat-tts/src/backends/qwen3_tts/model.rs index e597c12..89aef55 100644 --- a/crates/wavekat-tts/src/backends/qwen3_tts/model.rs +++ b/crates/wavekat-tts/src/backends/qwen3_tts/model.rs @@ -85,11 +85,23 @@ impl Model { .map_err(|e| TtsError::Model(format!("failed to load {name}: {e}"))) }; + eprint!("Loading talker prefill ... "); let talker_prefill = load_session("talker_prefill.onnx")?; + eprintln!("done"); + + eprint!("Loading talker decode ... "); let talker_decode = load_session("talker_decode.onnx")?; + eprintln!("done"); + + eprint!("Loading code predictor ... "); let code_predictor = load_session("code_predictor.onnx")?; + eprintln!("done"); + + eprint!("Loading vocoder ... "); let vocoder = load_session("vocoder.onnx")?; + eprintln!("done"); + eprint!("Loading embeddings ... 
"); let text_embedding = load_npy2(model_dir, "embeddings/text_embedding.npy")?; let text_proj_fc1_weight = load_npy2(model_dir, "embeddings/text_projection_fc1_weight.npy")?; @@ -107,6 +119,9 @@ impl Model { )?); } + eprintln!("done"); + eprintln!("Model ready."); + let tts_pad_raw = text_embedding.row(TTS_PAD as usize).to_owned(); let tts_pad_embed = text_project( &tts_pad_raw, From dfa4423879cc545883baf4574e76ba2574a2174a Mon Sep 17 00:00:00 2001 From: Eason WaveKat Date: Mon, 6 Apr 2026 22:42:25 +1200 Subject: [PATCH 05/15] feat: add FP32 precision option to Qwen3-TTS backend Co-Authored-By: Claude Sonnet 4.6 --- crates/wavekat-tts/examples/synthesize.rs | 21 ++++++-- .../src/backends/qwen3_tts/download.rs | 52 ++++++++++++------ .../wavekat-tts/src/backends/qwen3_tts/mod.rs | 53 ++++++++++++++----- .../src/backends/qwen3_tts/model.rs | 6 +-- 4 files changed, 97 insertions(+), 35 deletions(-) diff --git a/crates/wavekat-tts/examples/synthesize.rs b/crates/wavekat-tts/examples/synthesize.rs index 263ccb9..f1f2ed2 100644 --- a/crates/wavekat-tts/examples/synthesize.rs +++ b/crates/wavekat-tts/examples/synthesize.rs @@ -5,6 +5,7 @@ //! //! Options: //! --model-dir Model directory (default: auto-download to cache) +//! --precision Model precision: int4 (default) or fp32 //! --language Language code (default: en) //! --instruction Voice style instruction (VoiceDesign prompt) //! Default: "Speak naturally and clearly." @@ -23,11 +24,12 @@ //! Example: //! cargo run --example synthesize --features qwen3-tts,hound -- "Hello, world!" //! cargo run --example synthesize --features qwen3-tts,hound -- -i +//! 
cargo run --example synthesize --features qwen3-tts,hound -- --precision fp32 -i use std::io::{self, BufRead, Write}; use std::path::PathBuf; -use wavekat_tts::backends::qwen3_tts::Qwen3Tts; +use wavekat_tts::backends::qwen3_tts::{ModelPrecision, Qwen3Tts}; use wavekat_tts::{SynthesizeRequest, TtsBackend}; const DEFAULT_INSTRUCTION: &str = "Speak naturally and clearly."; @@ -36,6 +38,7 @@ fn main() { let args: Vec = std::env::args().skip(1).collect(); let mut model_dir: Option = None; + let mut precision = ModelPrecision::Int4; let mut language = "en".to_string(); let mut instruction: Option = None; let mut output = PathBuf::from("output.wav"); @@ -49,6 +52,17 @@ fn main() { i += 1; model_dir = Some(PathBuf::from(&args[i])); } + "--precision" => { + i += 1; + precision = match args[i].as_str() { + "int4" => ModelPrecision::Int4, + "fp32" => ModelPrecision::Fp32, + other => { + eprintln!("error: unknown precision \"{other}\", expected int4 or fp32"); + std::process::exit(1); + } + }; + } "--language" => { i += 1; language = args[i].clone(); @@ -71,6 +85,7 @@ fn main() { if text.is_empty() && !interactive { eprintln!("Usage: synthesize [OPTIONS] [TEXT]"); eprintln!(" --model-dir Model directory (default: auto-download)"); + eprintln!(" --precision Model precision: int4 (default) or fp32"); eprintln!(" --language Language code (default: en)"); eprintln!(" --instruction Voice style instruction (VoiceDesign prompt)"); eprintln!(" Default: \"{DEFAULT_INSTRUCTION}\""); @@ -86,8 +101,8 @@ fn main() { eprintln!("Loading model ..."); let tts = match model_dir { - Some(dir) => Qwen3Tts::from_dir(dir).expect("failed to load model"), - None => Qwen3Tts::new().expect("failed to load model"), + Some(dir) => Qwen3Tts::from_dir(dir, precision).expect("failed to load model"), + None => Qwen3Tts::new_with_precision(precision).expect("failed to load model"), }; if interactive { diff --git a/crates/wavekat-tts/src/backends/qwen3_tts/download.rs 
b/crates/wavekat-tts/src/backends/qwen3_tts/download.rs index 59f9a92..d49a724 100644 --- a/crates/wavekat-tts/src/backends/qwen3_tts/download.rs +++ b/crates/wavekat-tts/src/backends/qwen3_tts/download.rs @@ -10,13 +10,8 @@ use crate::TtsError; const REPO_ID: &str = "wavekat/Qwen3-TTS-1.7B-VoiceDesign-ONNX"; const REVISION: &str = "2026-04-06"; -/// Files required for INT4 inference (ONNX models + embeddings + tokenizer). -/// -/// `embeddings/small_to_mtp_projection_{weight,bias}.npy` are intentionally -/// excluded — that projection is baked into `code_predictor.onnx`. -const MODEL_FILES: &[&str] = &[ - "config.json", - // INT4 ONNX models +/// ONNX model files for INT4 precision. +const ONNX_FILES_INT4: &[&str] = &[ "int4/talker_prefill.onnx", "int4/talker_prefill.onnx.data", "int4/talker_decode.onnx", @@ -25,6 +20,26 @@ const MODEL_FILES: &[&str] = &[ "int4/code_predictor.onnx.data", "int4/vocoder.onnx", "int4/vocoder.onnx.data", +]; + +/// ONNX model files for FP32 precision. +const ONNX_FILES_FP32: &[&str] = &[ + "fp32/talker_prefill.onnx", + "fp32/talker_prefill.onnx.data", + "fp32/talker_decode.onnx", + "fp32/talker_decode.onnx.data", + "fp32/code_predictor.onnx", + "fp32/code_predictor.onnx.data", + "fp32/vocoder.onnx", + "fp32/vocoder.onnx.data", +]; + +/// Shared files required for all precision variants (embeddings + tokenizer + config). +/// +/// `embeddings/small_to_mtp_projection_{weight,bias}.npy` are intentionally +/// excluded — that projection is baked into `code_predictor.onnx`. +const SHARED_FILES: &[&str] = &[ + "config.json", // Embedding tables "embeddings/text_embedding.npy", "embeddings/text_projection_fc1_weight.npy", @@ -56,14 +71,14 @@ const MODEL_FILES: &[&str] = &[ /// downloading any missing files as needed. /// /// Set `WAVEKAT_MODEL_DIR` to skip HF Hub and load from a local directory -/// that mirrors the repo layout (`int4/`, `embeddings/`, `tokenizer/`). 
+/// that mirrors the repo layout (`int4/` or `fp32/`, `embeddings/`, `tokenizer/`). /// /// Authentication: set `HF_TOKEN` if the repo requires it. hf-hub 0.5 does /// not read `HF_TOKEN` from the environment natively; this function bridges /// the gap by passing it to `ApiBuilder::with_token`. /// /// Cache location: `$HF_HOME/hub/` (default `~/.cache/huggingface/hub/`). -pub fn ensure_model_dir() -> Result { +pub fn ensure_model_dir(precision: super::ModelPrecision) -> Result { if let Ok(dir) = std::env::var("WAVEKAT_MODEL_DIR") { return Ok(PathBuf::from(dir)); } @@ -86,21 +101,26 @@ pub fn ensure_model_dir() -> Result { REVISION.to_string(), )); - let total = MODEL_FILES.len(); - eprintln!("Ensuring Qwen3-TTS 1.7B model ({total} files from {REPO_ID})..."); + let onnx_files = match precision { + super::ModelPrecision::Int4 => ONNX_FILES_INT4, + super::ModelPrecision::Fp32 => ONNX_FILES_FP32, + }; + let total = 1 + onnx_files.len() + SHARED_FILES[1..].len(); // config + onnx + shared (excl. config) + + eprintln!("Ensuring Qwen3-TTS 1.7B ({}) model ({total} files from {REPO_ID})...", precision.subdir()); - // config.json is always first — its parent is the snapshot root. - eprintln!("[1/{total}] {}", MODEL_FILES[0]); + // config.json first — its parent is the snapshot root. + eprintln!("[1/{total}] {}", SHARED_FILES[0]); let config_path = repo - .get(MODEL_FILES[0]) - .map_err(|e| TtsError::Model(format!("failed to download {}: {e}", MODEL_FILES[0])))?; + .get(SHARED_FILES[0]) + .map_err(|e| TtsError::Model(format!("failed to download {}: {e}", SHARED_FILES[0])))?; let model_dir = config_path .parent() .ok_or_else(|| TtsError::Model("unexpected cache path for config.json".into()))? 
.to_path_buf(); - for (i, filename) in MODEL_FILES[1..].iter().enumerate() { + for (i, filename) in onnx_files.iter().chain(SHARED_FILES[1..].iter()).enumerate() { eprintln!("[{}/{total}] {filename}", i + 2); repo.get(filename) .map_err(|e| TtsError::Model(format!("failed to download {filename}: {e}")))?; diff --git a/crates/wavekat-tts/src/backends/qwen3_tts/mod.rs b/crates/wavekat-tts/src/backends/qwen3_tts/mod.rs index e14052c..7aacd69 100644 --- a/crates/wavekat-tts/src/backends/qwen3_tts/mod.rs +++ b/crates/wavekat-tts/src/backends/qwen3_tts/mod.rs @@ -1,17 +1,20 @@ -//! Qwen3-TTS backend (ONNX INT4, 1.7B VoiceDesign). +//! Qwen3-TTS backend (ONNX, 1.7B VoiceDesign). //! -//! Runs the Qwen3-TTS-12Hz-1.7B-VoiceDesign model via ONNX Runtime using the -//! INT4 weight-only quantized models from `wavekat/Qwen3-TTS-1.7B-VoiceDesign-ONNX`. +//! Runs the Qwen3-TTS-12Hz-1.7B-VoiceDesign model via ONNX Runtime. +//! Supports INT4 (weight-only quantized, default) and FP32 precision. //! //! ```ignore //! use wavekat_tts::{TtsBackend, SynthesizeRequest}; -//! use wavekat_tts::backends::qwen3_tts::Qwen3Tts; +//! use wavekat_tts::backends::qwen3_tts::{Qwen3Tts, ModelPrecision}; //! -//! // Auto-download model files via HF Hub (cached at ~/.cache/huggingface/hub/): +//! // Auto-download INT4 model files via HF Hub (default): //! let tts = Qwen3Tts::new()?; //! +//! // Auto-download FP32 model files: +//! let tts = Qwen3Tts::new_with_precision(ModelPrecision::Fp32)?; +//! //! // Or load from an explicit directory (must mirror the HF repo layout): -//! let tts = Qwen3Tts::from_dir("models/qwen3-tts-1.7b")?; +//! let tts = Qwen3Tts::from_dir("models/qwen3-tts-1.7b", ModelPrecision::Int4)?; //! //! let request = SynthesizeRequest::new("Hello, world"); //! let audio = tts.synthesize(&request)?; @@ -32,6 +35,25 @@ mod model; mod sampler; mod tokenizer; +/// ONNX model precision variant. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum ModelPrecision { + /// Weight-only INT4 quantized — smaller download, faster load. Default. + #[default] + Int4, + /// Full FP32 — larger download, no quantization error. + Fp32, +} + +impl ModelPrecision { + pub(crate) fn subdir(self) -> &'static str { + match self { + Self::Int4 => "int4", + Self::Fp32 => "fp32", + } + } +} + /// Qwen3-TTS backend using ONNX Runtime. pub struct Qwen3Tts { model: model::Model, @@ -39,27 +61,32 @@ pub struct Qwen3Tts { } impl Qwen3Tts { - /// Create a new backend, downloading model files from HF Hub if needed. + /// Create a new INT4 backend, downloading model files from HF Hub if needed. /// /// Files are cached by the HF Hub client (default `~/.cache/huggingface/hub/`). /// Set `HF_HOME` to change the cache root, or `HF_TOKEN` for authentication. /// Set `WAVEKAT_MODEL_DIR` to load from a local directory and skip all downloads. /// - /// Use [`from_dir`](Self::from_dir) to load from an explicit path. + /// Use [`new_with_precision`](Self::new_with_precision) to select FP32. pub fn new() -> Result { - let model_dir = download::ensure_model_dir()?; - Self::from_dir(model_dir) + Self::new_with_precision(ModelPrecision::Int4) + } + + /// Create a new backend with the given precision, downloading files if needed. + pub fn new_with_precision(precision: ModelPrecision) -> Result { + let model_dir = download::ensure_model_dir(precision)?; + Self::from_dir(model_dir, precision) } /// Load the model from a directory that mirrors the HF repo layout. /// /// Expected subdirectories: - /// - `int4/` — ONNX models (`talker_prefill.onnx`, `talker_decode.onnx`, etc.) + /// - `int4/` or `fp32/` — ONNX models (`talker_prefill.onnx`, etc.) 
/// - `embeddings/` — `.npy` embedding tables /// - `tokenizer/` — `vocab.json`, `merges.txt` - pub fn from_dir(model_dir: impl AsRef) -> Result { + pub fn from_dir(model_dir: impl AsRef, precision: ModelPrecision) -> Result { let model_dir = model_dir.as_ref(); - let model = model::Model::load(model_dir)?; + let model = model::Model::load(model_dir, precision)?; let tokenizer = tokenizer::Tokenizer::new(model_dir)?; Ok(Self { model, tokenizer }) } diff --git a/crates/wavekat-tts/src/backends/qwen3_tts/model.rs b/crates/wavekat-tts/src/backends/qwen3_tts/model.rs index 89aef55..f0af9ad 100644 --- a/crates/wavekat-tts/src/backends/qwen3_tts/model.rs +++ b/crates/wavekat-tts/src/backends/qwen3_tts/model.rs @@ -74,11 +74,11 @@ impl Model { /// Load all ONNX sessions and embedding tables from `model_dir`. /// /// Expected layout (matches the HF repo): - /// - `model_dir/int4/talker_prefill.onnx` (+ .data), etc. + /// - `model_dir/{int4,fp32}/talker_prefill.onnx` (+ .data), etc. /// - `model_dir/embeddings/text_embedding.npy`, etc. - pub fn load(model_dir: &Path) -> Result { + pub fn load(model_dir: &Path, precision: super::ModelPrecision) -> Result { let load_session = |name: &str| -> Result { - let path = model_dir.join("int4").join(name); + let path = model_dir.join(precision.subdir()).join(name); Session::builder() .map_err(|e| TtsError::Model(format!("session builder error: {e}")))? 
.commit_from_file(&path) From 9fb5ba7775e611534f1d761be3d6895063b18ce6 Mon Sep 17 00:00:00 2001 From: Eason WaveKat Date: Tue, 7 Apr 2026 08:21:18 +1200 Subject: [PATCH 06/15] docs: update README for 1.7B VoiceDesign + FP32 precision Co-Authored-By: Claude Sonnet 4.6 --- README.md | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 6eeba86..eaac9d0 100644 --- a/README.md +++ b/README.md @@ -31,21 +31,26 @@ cargo add wavekat-tts --features qwen3-tts ```rust use wavekat_tts::{TtsBackend, SynthesizeRequest}; -use wavekat_tts::backends::qwen3_tts::Qwen3Tts; +use wavekat_tts::backends::qwen3_tts::{Qwen3Tts, ModelPrecision}; -// Auto-downloads model files (~3.8 GB) on first run: +// Auto-downloads INT4 model files on first run (default): let tts = Qwen3Tts::new()?; +// Or load FP32: +// let tts = Qwen3Tts::new_with_precision(ModelPrecision::Fp32)?; + // Or load from an explicit directory: -// let tts = Qwen3Tts::from_dir("models/qwen3-tts-0.6b")?; +// let tts = Qwen3Tts::from_dir("models/qwen3-tts-1.7b", ModelPrecision::Int4)?; -let request = SynthesizeRequest::new("Hello, world"); +let request = SynthesizeRequest::new("Hello, world") + .with_instruction("Speak naturally and clearly."); let audio = tts.synthesize(&request)?; println!("{}s at {} Hz", audio.duration_secs(), audio.sample_rate()); ``` -Model files are cached at `$WAVEKAT_MODEL_DIR` or `~/.cache/wavekat/qwen3-tts-0.6b/`. +Model files are cached by the HF Hub client at `$HF_HOME/hub/` (default `~/.cache/huggingface/hub/`). +Set `WAVEKAT_MODEL_DIR` to load from a local directory and skip all downloads. All backends produce `AudioFrame<'static>` from [`wavekat-core`](https://github.com/wavekat/wavekat-core) — the same type consumed by `wavekat-vad` and `wavekat-turn`. @@ -74,6 +79,7 @@ Generate a WAV file from text (model files are auto-downloaded on first run): ```sh cargo run --example synthesize --features qwen3-tts,hound -- "Hello, world\!" 
cargo run --example synthesize --features qwen3-tts,hound -- --instruction "Speak in a warm, friendly tone." "Give every small business the voice of a big one." +cargo run --example synthesize --features qwen3-tts,hound -- --precision fp32 "Hello" cargo run --example synthesize --features qwen3-tts,hound -- --model-dir /path/to/model --output hello.wav "Hello" ``` From 3ab1b924dc0192a73aa7bb2a4343c8c83db625d7 Mon Sep 17 00:00:00 2001 From: Eason WaveKat Date: Tue, 7 Apr 2026 08:22:35 +1200 Subject: [PATCH 07/15] docs: mark CosyVoice backend as planned Co-Authored-By: Claude Sonnet 4.6 --- README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index eaac9d0..f099b45 100644 --- a/README.md +++ b/README.md @@ -18,10 +18,10 @@ Same pattern as ## Backends -| Backend | Feature flag | License | -|---------|-------------|---------| -| [Qwen3-TTS](https://huggingface.co/Qwen/Qwen3-TTS) | `qwen3-tts` | Apache 2.0 | -| [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) | `cosyvoice` | Apache 2.0 | +| Backend | Feature flag | Status | License | +|---------|-------------|--------|---------| +| [Qwen3-TTS](https://huggingface.co/Qwen/Qwen3-TTS) | `qwen3-tts` | ✅ Available | Apache 2.0 | +| [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) | `cosyvoice` | 🚧 Planned | Apache 2.0 | ## Quick start @@ -88,7 +88,7 @@ cargo run --example synthesize --features qwen3-tts,hound -- --model-dir /path/t | Flag | Default | Description | |------|---------|-------------| | `qwen3-tts` | off | Qwen3-TTS local ONNX inference | -| `cosyvoice` | off | CosyVoice local ONNX inference | +| `cosyvoice` | off | CosyVoice local ONNX inference (planned) | ## License From f5da9a30ce542459bdaabf72fbec55fc655143a7 Mon Sep 17 00:00:00 2001 From: Eason WaveKat Date: Tue, 7 Apr 2026 09:43:11 +1200 Subject: [PATCH 08/15] chore: bump wavekat-core to 0.0.4, use AudioFrame::from_vec Co-Authored-By: Claude Sonnet 4.6 --- CLAUDE.md | 5 ----- 
crates/wavekat-tts/Cargo.toml | 2 +- crates/wavekat-tts/src/backends/qwen3_tts/model.rs | 4 ++-- 3 files changed, 3 insertions(+), 8 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 6e129ea..442e5b7 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -32,8 +32,3 @@ make test-all # test all backends use internal tokio runtimes. 4. **Feature flags per backend** — same pattern as wavekat-vad/turn. -## Pending wavekat-core change - -`AudioFrame::from_owned(Vec, u32) -> AudioFrame<'static>` — avoids the -borrow-then-clone path when creating frames from producer-owned data (TTS output). -Currently uses `AudioFrame::new(slice, rate).into_owned()` as a workaround. diff --git a/crates/wavekat-tts/Cargo.toml b/crates/wavekat-tts/Cargo.toml index af3c4f0..855b4a5 100644 --- a/crates/wavekat-tts/Cargo.toml +++ b/crates/wavekat-tts/Cargo.toml @@ -16,7 +16,7 @@ qwen3-tts = ["dep:ort", "dep:ndarray", "dep:tokenizers", "dep:npyz", "dep:rand", cosyvoice = ["dep:ort", "dep:ndarray"] [dependencies] -wavekat-core = "0.0.3" +wavekat-core = "0.0.4" thiserror = "2" serde = { version = "1", features = ["derive"] } serde_json = "1" diff --git a/crates/wavekat-tts/src/backends/qwen3_tts/model.rs b/crates/wavekat-tts/src/backends/qwen3_tts/model.rs index f0af9ad..5e5ef8e 100644 --- a/crates/wavekat-tts/src/backends/qwen3_tts/model.rs +++ b/crates/wavekat-tts/src/backends/qwen3_tts/model.rs @@ -620,9 +620,9 @@ impl Model { // Trim leading silence produced by the think phase. 
let start = waveform.iter().position(|&s| s.abs() > 0.01).unwrap_or(0); - let trimmed = &waveform[start..]; + let trimmed = waveform[start..].to_vec(); - Ok(AudioFrame::new(trimmed, SAMPLE_RATE).into_owned()) + Ok(AudioFrame::from_vec(trimmed, SAMPLE_RATE)) } } From f3c4fd603182028ecf59ccd9d6b8fe21b36a261b Mon Sep 17 00:00:00 2001 From: Eason WaveKat Date: Tue, 7 Apr 2026 09:46:41 +1200 Subject: [PATCH 09/15] docs: propose AudioFrame WAV I/O feature for wavekat-core Co-Authored-By: Claude Sonnet 4.6 --- docs/02-proposed-core-wav-io.md | 126 ++++++++++++++++++++++++++++++++ 1 file changed, 126 insertions(+) create mode 100644 docs/02-proposed-core-wav-io.md diff --git a/docs/02-proposed-core-wav-io.md b/docs/02-proposed-core-wav-io.md new file mode 100644 index 0000000..c908f09 --- /dev/null +++ b/docs/02-proposed-core-wav-io.md @@ -0,0 +1,126 @@ +# Proposed addition to wavekat-core: WAV I/O + +## Problem + +Every crate in the WaveKat ecosystem that reads or writes WAV files (wavekat-vad, +wavekat-turn, wavekat-tts) independently reaches for `hound` and repeats the same +boilerplate. Any spec change (e.g. bits-per-sample, channel count) must be updated +in multiple places. + +```rust +// Current: every crate does this manually +let spec = hound::WavSpec { + channels: 1, + sample_rate: audio.sample_rate(), + bits_per_sample: 32, + sample_format: hound::SampleFormat::Float, +}; +let mut writer = hound::WavWriter::create(path, spec)?; +for &sample in audio.samples() { + writer.write_sample(sample)?; +} +writer.finalize()?; +``` + +## Proposed change + +Add a `hound` feature flag to wavekat-core that extends `AudioFrame` with two +methods: + +```toml +# wavekat-core Cargo.toml +[features] +hound = ["dep:hound"] + +[dependencies] +hound = { version = "3.5", optional = true } +``` + +```rust +// In audio.rs, behind #[cfg(feature = "hound")]: + +impl AudioFrame<'_> { + /// Write this frame to a WAV file at `path`. 
+ /// + /// Always writes mono f32 PCM at the frame's native sample rate. + /// + /// # Example + /// + /// ```no_run + /// use wavekat_core::AudioFrame; + /// + /// let frame = AudioFrame::from_vec(vec![0.0f32; 16000], 16000); + /// frame.write_wav("output.wav").unwrap(); + /// ``` + #[cfg(feature = "hound")] + pub fn write_wav(&self, path: impl AsRef<Path>) -> Result<(), hound::Error> { + let spec = hound::WavSpec { + channels: 1, + sample_rate: self.sample_rate, + bits_per_sample: 32, + sample_format: hound::SampleFormat::Float, + }; + let mut writer = hound::WavWriter::create(path, spec)?; + for &sample in self.samples() { + writer.write_sample(sample)?; + } + writer.finalize() + } +} + +impl AudioFrame<'static> { + /// Read a mono WAV file and return an owned `AudioFrame`. + /// + /// Accepts both f32 and i16 WAV files. i16 samples are normalised to + /// `[-1.0, 1.0]` (divided by 32768). + /// + /// # Example + /// + /// ```no_run + /// use wavekat_core::AudioFrame; + /// + /// let frame = AudioFrame::from_wav("input.wav").unwrap(); + /// println!("{} Hz, {} samples", frame.sample_rate(), frame.len()); + /// ``` + #[cfg(feature = "hound")] + pub fn from_wav(path: impl AsRef<Path>) -> Result<Self, hound::Error> { + let mut reader = hound::WavReader::open(path)?; + let spec = reader.spec(); + let sample_rate = spec.sample_rate; + let samples: Vec<f32> = match spec.sample_format { + hound::SampleFormat::Float => { + reader.samples::<f32>().map(|s| s.unwrap()).collect() + } + hound::SampleFormat::Int => { + reader.samples::<i16>().map(|s| s.unwrap() as f32 / 32768.0).collect() + } + }; + Ok(AudioFrame::from_vec(samples, sample_rate)) + } +} +``` + +## Impact + +- **Zero breaking changes** — purely additive, opt-in via feature flag +- Consumers opt in: `wavekat-core = { version = "0.0.5", features = ["hound"] }` +- All examples and tests across wavekat-vad, wavekat-turn, wavekat-tts can drop + their own `hound` boilerplate and use the canonical implementation +- One place to maintain the WAV spec (mono, f32,
native sample rate) + +## Tests + +```rust +#[cfg(feature = "hound")] +#[test] +fn wav_round_trip() { + let original = AudioFrame::from_vec(vec![0.5f32, -0.5, 0.0, 1.0], 16000); + let path = std::env::temp_dir().join("wavekat_test.wav"); + original.write_wav(&path).unwrap(); + let loaded = AudioFrame::from_wav(&path).unwrap(); + assert_eq!(loaded.sample_rate(), 16000); + for (a, b) in original.samples().iter().zip(loaded.samples()) { + assert!((a - b).abs() < 1e-6, "sample mismatch: {a} vs {b}"); + } +} +``` From 22033807c7c99209005dd94ef7a3c27cbd0c4c96 Mon Sep 17 00:00:00 2001 From: Eason WaveKat Date: Tue, 7 Apr 2026 11:33:10 +1200 Subject: [PATCH 10/15] chore: use AudioFrame::write_wav from wavekat-core 0.0.5 Co-Authored-By: Claude Sonnet 4.6 --- crates/wavekat-tts/Cargo.toml | 7 +++---- crates/wavekat-tts/examples/synthesize.rs | 20 +++++--------------- 2 files changed, 8 insertions(+), 19 deletions(-) diff --git a/crates/wavekat-tts/Cargo.toml b/crates/wavekat-tts/Cargo.toml index 855b4a5..5a0b7fa 100644 --- a/crates/wavekat-tts/Cargo.toml +++ b/crates/wavekat-tts/Cargo.toml @@ -15,8 +15,9 @@ default = [] qwen3-tts = ["dep:ort", "dep:ndarray", "dep:tokenizers", "dep:npyz", "dep:rand", "dep:hf-hub"] cosyvoice = ["dep:ort", "dep:ndarray"] + [dependencies] -wavekat-core = "0.0.4" +wavekat-core = { version = "0.0.5", features = ["wav"] } thiserror = "2" serde = { version = "1", features = ["derive"] } serde_json = "1" @@ -28,11 +29,9 @@ tokenizers = { version = "0.21", optional = true, default-features = false, feat npyz = { version = "0.8", optional = true } rand = { version = "0.9", optional = true } hf-hub = { version = "0.5", optional = true, default-features = false, features = ["ureq"] } -hound = { version = "3.5", optional = true } [dev-dependencies] -hound = "3.5" [[example]] name = "synthesize" -required-features = ["qwen3-tts", "hound"] +required-features = ["qwen3-tts"] diff --git a/crates/wavekat-tts/examples/synthesize.rs 
b/crates/wavekat-tts/examples/synthesize.rs index f1f2ed2..df36a95 100644 --- a/crates/wavekat-tts/examples/synthesize.rs +++ b/crates/wavekat-tts/examples/synthesize.rs @@ -1,7 +1,7 @@ //! Synthesize text to a WAV file using Qwen3-TTS. //! //! Usage: -//! cargo run --example synthesize --features qwen3-tts,hound -- [OPTIONS] [TEXT] +//! cargo run --example synthesize --features qwen3-tts -- [OPTIONS] [TEXT] //! //! Options: //! --model-dir Model directory (default: auto-download to cache) @@ -22,9 +22,9 @@ //! Empty line or Ctrl-D Quit //! //! Example: -//! cargo run --example synthesize --features qwen3-tts,hound -- "Hello, world!" -//! cargo run --example synthesize --features qwen3-tts,hound -- -i -//! cargo run --example synthesize --features qwen3-tts,hound -- --precision fp32 -i +//! cargo run --example synthesize --features qwen3-tts -- "Hello, world!" +//! cargo run --example synthesize --features qwen3-tts -- -i +//! cargo run --example synthesize --features qwen3-tts -- --precision fp32 -i use std::io::{self, BufRead, Write}; use std::path::PathBuf; @@ -231,17 +231,7 @@ fn synthesize_one( rtf, ); - let spec = hound::WavSpec { - channels: 1, - sample_rate: audio.sample_rate(), - bits_per_sample: 32, - sample_format: hound::SampleFormat::Float, - }; - let mut writer = hound::WavWriter::create(output, spec).expect("failed to create WAV file"); - for &sample in audio.samples() { - writer.write_sample(sample).expect("failed to write sample"); - } - writer.finalize().expect("failed to finalize WAV"); + audio.write_wav(output).expect("failed to write WAV"); eprintln!("Wrote {}", output.display()); } From 3b39e0d8e88a4e109ee5b109e9642236c92d13a9 Mon Sep 17 00:00:00 2001 From: Eason WaveKat Date: Tue, 7 Apr 2026 11:34:17 +1200 Subject: [PATCH 11/15] docs: show write_wav in quick start and update example commands Co-Authored-By: Claude Sonnet 4.6 --- README.md | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md 
index f099b45..20e97d9 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,7 @@ Same pattern as ```sh cargo add wavekat-tts --features qwen3-tts +cargo add wavekat-core --features wav ``` ```rust @@ -46,6 +47,9 @@ let request = SynthesizeRequest::new("Hello, world") .with_instruction("Speak naturally and clearly."); let audio = tts.synthesize(&request)?; +// Save to WAV (wavekat-core includes WAV I/O via the `wav` feature): +audio.write_wav("output.wav")?; + println!("{}s at {} Hz", audio.duration_secs(), audio.sample_rate()); ``` @@ -77,10 +81,10 @@ Two trait families: Generate a WAV file from text (model files are auto-downloaded on first run): ```sh -cargo run --example synthesize --features qwen3-tts,hound -- "Hello, world\!" -cargo run --example synthesize --features qwen3-tts,hound -- --instruction "Speak in a warm, friendly tone." "Give every small business the voice of a big one." -cargo run --example synthesize --features qwen3-tts,hound -- --precision fp32 "Hello" -cargo run --example synthesize --features qwen3-tts,hound -- --model-dir /path/to/model --output hello.wav "Hello" +cargo run --example synthesize --features qwen3-tts -- "Hello, world\!" +cargo run --example synthesize --features qwen3-tts -- --instruction "Speak in a warm, friendly tone." "Give every small business the voice of a big one." +cargo run --example synthesize --features qwen3-tts -- --precision fp32 "Hello" +cargo run --example synthesize --features qwen3-tts -- --model-dir /path/to/model --output hello.wav "Hello" ``` ## Feature flags @@ -90,6 +94,8 @@ cargo run --example synthesize --features qwen3-tts,hound -- --model-dir /path/t | `qwen3-tts` | off | Qwen3-TTS local ONNX inference | | `cosyvoice` | off | CosyVoice local ONNX inference (planned) | +WAV I/O (`write_wav` / `from_wav`) is provided by `wavekat-core` via its `wav` feature flag. + ## License Licensed under [Apache 2.0](LICENSE). 
From 5ce743e6cbd4969154312b0bc20971387809d0c8 Mon Sep 17 00:00:00 2001 From: Eason WaveKat Date: Tue, 7 Apr 2026 11:35:40 +1200 Subject: [PATCH 12/15] docs: remove redundant cargo add wavekat-core from quick start Co-Authored-By: Claude Sonnet 4.6 --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index 20e97d9..f0218f9 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,6 @@ Same pattern as ```sh cargo add wavekat-tts --features qwen3-tts -cargo add wavekat-core --features wav ``` ```rust From f286d3a12685127348a599a41e9d96db114d9f60 Mon Sep 17 00:00:00 2001 From: Eason WaveKat Date: Tue, 7 Apr 2026 11:38:05 +1200 Subject: [PATCH 13/15] style: fix rustfmt formatting in qwen3_tts backend Co-Authored-By: Claude Sonnet 4.6 --- crates/wavekat-tts/src/backends/qwen3_tts/download.rs | 11 +++++++++-- crates/wavekat-tts/src/backends/qwen3_tts/mod.rs | 5 ++++- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/crates/wavekat-tts/src/backends/qwen3_tts/download.rs b/crates/wavekat-tts/src/backends/qwen3_tts/download.rs index d49a724..f58ae0a 100644 --- a/crates/wavekat-tts/src/backends/qwen3_tts/download.rs +++ b/crates/wavekat-tts/src/backends/qwen3_tts/download.rs @@ -107,7 +107,10 @@ pub fn ensure_model_dir(precision: super::ModelPrecision) -> Result Result, precision: ModelPrecision) -> Result { + pub fn from_dir( + model_dir: impl AsRef, + precision: ModelPrecision, + ) -> Result { let model_dir = model_dir.as_ref(); let model = model::Model::load(model_dir, precision)?; let tokenizer = tokenizer::Tokenizer::new(model_dir)?; From 219e66c99e79932f27e673220e64f0240a5676df Mon Sep 17 00:00:00 2001 From: Eason WaveKat Date: Tue, 7 Apr 2026 11:51:05 +1200 Subject: [PATCH 14/15] refactor: replace Qwen3Tts constructors with ModelConfig MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduce ModelConfig (precision + execution_provider + model_dir) and replace new_with_precision/from_dir 
with a single from_config constructor. Model dir resolution order: config field → WAVEKAT_MODEL_DIR → HF Hub. Co-Authored-By: Claude Sonnet 4.6 --- README.md | 16 ++- crates/wavekat-tts/examples/synthesize.rs | 15 +- .../src/backends/qwen3_tts/download.rs | 18 ++- .../wavekat-tts/src/backends/qwen3_tts/mod.rs | 128 +++++++++++++----- .../src/backends/qwen3_tts/model.rs | 29 +++- 5 files changed, 153 insertions(+), 53 deletions(-) diff --git a/README.md b/README.md index f0218f9..8541d65 100644 --- a/README.md +++ b/README.md @@ -31,16 +31,20 @@ cargo add wavekat-tts --features qwen3-tts ```rust use wavekat_tts::{TtsBackend, SynthesizeRequest}; -use wavekat_tts::backends::qwen3_tts::{Qwen3Tts, ModelPrecision}; +use wavekat_tts::backends::qwen3_tts::{Qwen3Tts, ModelConfig, ModelPrecision, ExecutionProvider}; -// Auto-downloads INT4 model files on first run (default): +// Auto-downloads INT4 model files on first run, runs on CPU (default): let tts = Qwen3Tts::new()?; -// Or load FP32: -// let tts = Qwen3Tts::new_with_precision(ModelPrecision::Fp32)?; +// Or FP32 on CPU: +// let tts = Qwen3Tts::from_config(ModelConfig::default().with_precision(ModelPrecision::Fp32))?; -// Or load from an explicit directory: -// let tts = Qwen3Tts::from_dir("models/qwen3-tts-1.7b", ModelPrecision::Int4)?; +// Or INT4 from a local directory on CUDA: +// let tts = Qwen3Tts::from_config( +// ModelConfig::default() +// .with_dir("models/qwen3-tts-1.7b") +// .with_execution_provider(ExecutionProvider::Cuda), +// )?; let request = SynthesizeRequest::new("Hello, world") .with_instruction("Speak naturally and clearly."); diff --git a/crates/wavekat-tts/examples/synthesize.rs b/crates/wavekat-tts/examples/synthesize.rs index df36a95..f0b3ca4 100644 --- a/crates/wavekat-tts/examples/synthesize.rs +++ b/crates/wavekat-tts/examples/synthesize.rs @@ -29,7 +29,7 @@ use std::io::{self, BufRead, Write}; use std::path::PathBuf; -use wavekat_tts::backends::qwen3_tts::{ModelPrecision, Qwen3Tts}; +use 
wavekat_tts::backends::qwen3_tts::{ModelConfig, ModelPrecision, Qwen3Tts}; use wavekat_tts::{SynthesizeRequest, TtsBackend}; const DEFAULT_INSTRUCTION: &str = "Speak naturally and clearly."; @@ -100,10 +100,11 @@ fn main() { } eprintln!("Loading model ..."); - let tts = match model_dir { - Some(dir) => Qwen3Tts::from_dir(dir, precision).expect("failed to load model"), - None => Qwen3Tts::new_with_precision(precision).expect("failed to load model"), - }; + let mut config = ModelConfig::default().with_precision(precision); + if let Some(dir) = model_dir { + config = config.with_dir(dir); + } + let tts = Qwen3Tts::from_config(config).expect("failed to load model"); if interactive { run_interactive(&tts, language, instruction.unwrap(), &output); @@ -192,7 +193,7 @@ fn run_interactive( } count += 1; - let output = if *default_output == PathBuf::from("output.wav") { + let output = if default_output == std::path::Path::new("output.wav") { PathBuf::from(format!("output_{count:03}.wav")) } else { default_output.clone() @@ -220,7 +221,7 @@ fn synthesize_one( let elapsed = start.elapsed(); let duration = audio.duration_secs(); - let rtf = elapsed.as_secs_f64() / duration as f64; + let rtf = elapsed.as_secs_f64() / duration; eprintln!( "Generated {} samples at {} Hz ({:.2}s) in {:.2}s (RTF: {:.2})", diff --git a/crates/wavekat-tts/src/backends/qwen3_tts/download.rs b/crates/wavekat-tts/src/backends/qwen3_tts/download.rs index f58ae0a..480ea01 100644 --- a/crates/wavekat-tts/src/backends/qwen3_tts/download.rs +++ b/crates/wavekat-tts/src/backends/qwen3_tts/download.rs @@ -67,22 +67,30 @@ const SHARED_FILES: &[&str] = &[ "tokenizer/merges.txt", ]; -/// Resolve the local HF Hub snapshot directory for the Qwen3-TTS model, -/// downloading any missing files as needed. +/// Resolve the local directory for the Qwen3-TTS model files, downloading +/// any missing files from HF Hub as needed. 
/// -/// Set `WAVEKAT_MODEL_DIR` to skip HF Hub and load from a local directory -/// that mirrors the repo layout (`int4/` or `fp32/`, `embeddings/`, `tokenizer/`). +/// Resolution order: +/// 1. `config.model_dir` (if set) +/// 2. `WAVEKAT_MODEL_DIR` environment variable +/// 3. Auto-download from HF Hub /// /// Authentication: set `HF_TOKEN` if the repo requires it. hf-hub 0.5 does /// not read `HF_TOKEN` from the environment natively; this function bridges /// the gap by passing it to `ApiBuilder::with_token`. /// /// Cache location: `$HF_HOME/hub/` (default `~/.cache/huggingface/hub/`). -pub fn ensure_model_dir(precision: super::ModelPrecision) -> Result<PathBuf, TtsError> { +pub fn resolve_model_dir(config: &super::ModelConfig) -> Result<PathBuf, TtsError> { + if let Some(dir) = &config.model_dir { + return Ok(dir.clone()); + } + if let Ok(dir) = std::env::var("WAVEKAT_MODEL_DIR") { return Ok(PathBuf::from(dir)); } + let precision = config.precision; + // from_env() reads HF_HOME / HF_ENDPOINT. // Bridge HF_TOKEN which hf-hub doesn't read from the environment natively. let mut builder = ApiBuilder::from_env(); diff --git a/crates/wavekat-tts/src/backends/qwen3_tts/mod.rs b/crates/wavekat-tts/src/backends/qwen3_tts/mod.rs index 4e8d4e7..8acb92e 100644 --- a/crates/wavekat-tts/src/backends/qwen3_tts/mod.rs +++ b/crates/wavekat-tts/src/backends/qwen3_tts/mod.rs @@ -5,22 +5,26 @@ //! //! ```ignore //! use wavekat_tts::{TtsBackend, SynthesizeRequest}; -//! use wavekat_tts::backends::qwen3_tts::{Qwen3Tts, ModelPrecision}; +//! use wavekat_tts::backends::qwen3_tts::{Qwen3Tts, ModelConfig, ModelPrecision, ExecutionProvider}; //! -//! // Auto-download INT4 model files via HF Hub (default): +//! // Auto-download INT4 model files via HF Hub, run on CPU (default): //! let tts = Qwen3Tts::new()?; //! -//! // Auto-download FP32 model files: -//! let tts = Qwen3Tts::new_with_precision(ModelPrecision::Fp32)?; +//! // Auto-download FP32, run on CPU: +//!
let tts = Qwen3Tts::from_config(ModelConfig::default().with_precision(ModelPrecision::Fp32))?; //! -//! // Or load from an explicit directory (must mirror the HF repo layout): -//! let tts = Qwen3Tts::from_dir("models/qwen3-tts-1.7b", ModelPrecision::Int4)?; +//! // INT4 from a local directory, run on CUDA: +//! let tts = Qwen3Tts::from_config( +//! ModelConfig::default() +//! .with_dir("models/qwen3-tts-1.7b") +//! .with_execution_provider(ExecutionProvider::Cuda), +//! )?; //! //! let request = SynthesizeRequest::new("Hello, world"); //! let audio = tts.synthesize(&request)?; //! ``` -use std::path::Path; +use std::path::PathBuf; use wavekat_core::AudioFrame; @@ -36,6 +40,8 @@ mod sampler; mod tokenizer; /// ONNX model precision variant. +/// +/// Selects which quantized model files to download and load. #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub enum ModelPrecision { /// Weight-only INT4 quantized — smaller download, faster load. Default. @@ -54,6 +60,77 @@ impl ModelPrecision { } } +/// ONNX execution provider (inference hardware backend). +/// +/// Selecting a provider that is unavailable at runtime causes an error at load +/// time rather than silently falling back. Use [`ExecutionProvider::Cpu`] (the +/// default) if you need guaranteed availability. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum ExecutionProvider { + /// CPU inference via ONNX Runtime. Always available. Default. + #[default] + Cpu, + /// NVIDIA CUDA GPU inference. Requires an ORT build with CUDA support. + Cuda, + /// Apple CoreML (macOS / iOS). Requires an ORT build with CoreML support. + CoreMl, +} + +/// Model loading configuration for [`Qwen3Tts`]. +/// +/// All fields default to sensible values: INT4 quantization, CPU inference, +/// and auto-download from HF Hub. 
+/// +/// # Examples +/// +/// ```rust,no_run +/// # use wavekat_tts::backends::qwen3_tts::{ModelConfig, ModelPrecision, ExecutionProvider}; +/// // INT4, CPU, auto-download (equivalent to ModelConfig::default()) +/// let config = ModelConfig::default(); +/// +/// // FP32 from a local directory +/// let config = ModelConfig::default() +/// .with_precision(ModelPrecision::Fp32) +/// .with_dir("models/qwen3-tts-1.7b"); +/// +/// // INT4, CUDA, auto-download +/// let config = ModelConfig::default() +/// .with_execution_provider(ExecutionProvider::Cuda); +/// ``` +#[derive(Debug, Clone, Default)] +pub struct ModelConfig { + /// Weight quantization variant (determines which ONNX files to load). + pub precision: ModelPrecision, + /// Inference hardware backend. + pub execution_provider: ExecutionProvider, + /// Local model directory. `None` = resolve via `WAVEKAT_MODEL_DIR` env var, + /// then auto-download from HF Hub. + pub model_dir: Option<PathBuf>, +} + +impl ModelConfig { + /// Set a local model directory, bypassing HF Hub download. + /// + /// The directory must mirror the HF repo layout: + /// `int4/` or `fp32/`, `embeddings/`, `tokenizer/`. + pub fn with_dir(mut self, dir: impl Into<PathBuf>) -> Self { + self.model_dir = Some(dir.into()); + self + } + + /// Set the model precision (quantization variant). + pub fn with_precision(mut self, precision: ModelPrecision) -> Self { + self.precision = precision; + self + } + + /// Set the ONNX execution provider. + pub fn with_execution_provider(mut self, ep: ExecutionProvider) -> Self { + self.execution_provider = ep; + self + } +} + /// Qwen3-TTS backend using ONNX Runtime. pub struct Qwen3Tts { model: model::Model, @@ -61,36 +138,25 @@ pub struct Qwen3Tts { } impl Qwen3Tts { - /// Create a new INT4 backend, downloading model files from HF Hub if needed. + /// Create a new backend with default config (INT4, CPU, auto-download). /// /// Files are cached by the HF Hub client (default `~/.cache/huggingface/hub/`).
- /// Set `HF_HOME` to change the cache root, or `HF_TOKEN` for authentication. - /// Set `WAVEKAT_MODEL_DIR` to load from a local directory and skip all downloads. - /// - /// Use [`new_with_precision`](Self::new_with_precision) to select FP32. + /// Set `HF_HOME` to change the cache root, `HF_TOKEN` for authentication, or + /// `WAVEKAT_MODEL_DIR` to load from a local directory and skip all downloads. pub fn new() -> Result<Self, TtsError> { - Self::new_with_precision(ModelPrecision::Int4) - } - - /// Create a new backend with the given precision, downloading files if needed. - pub fn new_with_precision(precision: ModelPrecision) -> Result<Self, TtsError> { - let model_dir = download::ensure_model_dir(precision)?; - Self::from_dir(model_dir, precision) + Self::from_config(ModelConfig::default()) } - /// Load the model from a directory that mirrors the HF repo layout. + /// Create a new backend with the given [`ModelConfig`]. /// - /// Expected subdirectories: - /// - `int4/` or `fp32/` — ONNX models (`talker_prefill.onnx`, etc.) - /// - `embeddings/` — `.npy` embedding tables - /// - `tokenizer/` — `vocab.json`, `merges.txt` - pub fn from_dir( - model_dir: impl AsRef<Path>, - precision: ModelPrecision, - ) -> Result<Self, TtsError> { - let model_dir = model_dir.as_ref(); - let model = model::Model::load(model_dir, precision)?; - let tokenizer = tokenizer::Tokenizer::new(model_dir)?; + /// Model files are resolved in priority order: + /// 1. `config.model_dir` (if set) + /// 2. `WAVEKAT_MODEL_DIR` environment variable + /// 3.
Auto-download from HF Hub + pub fn from_config(config: ModelConfig) -> Result<Self, TtsError> { + let model_dir = download::resolve_model_dir(&config)?; + let model = model::Model::load(model_dir.as_ref(), &config)?; + let tokenizer = tokenizer::Tokenizer::new(&model_dir)?; Ok(Self { model, tokenizer }) } } diff --git a/crates/wavekat-tts/src/backends/qwen3_tts/model.rs b/crates/wavekat-tts/src/backends/qwen3_tts/model.rs index 5e5ef8e..ca0d99a 100644 --- a/crates/wavekat-tts/src/backends/qwen3_tts/model.rs +++ b/crates/wavekat-tts/src/backends/qwen3_tts/model.rs @@ -76,11 +76,12 @@ impl Model { /// Expected layout (matches the HF repo): /// - `model_dir/{int4,fp32}/talker_prefill.onnx` (+ .data), etc. /// - `model_dir/embeddings/text_embedding.npy`, etc. - pub fn load(model_dir: &Path, precision: super::ModelPrecision) -> Result<Self, TtsError> { + pub fn load(model_dir: &Path, config: &super::ModelConfig) -> Result<Self, TtsError> { let load_session = |name: &str| -> Result<Session, TtsError> { - let path = model_dir.join(precision.subdir()).join(name); - Session::builder() - .map_err(|e| TtsError::Model(format!("session builder error: {e}")))? + let path = model_dir.join(config.precision.subdir()).join(name); + let builder = Session::builder() + .map_err(|e| TtsError::Model(format!("session builder error: {e}")))?; + apply_execution_provider(builder, config.execution_provider)? .commit_from_file(&path) .map_err(|e| TtsError::Model(format!("failed to load {name}: {e}"))) }; @@ -630,6 +631,26 @@ impl Model { // Helpers // --------------------------------------------------------------------------- +/// Register the requested execution provider on a session builder. +/// +/// CPU is the ORT default — no registration needed. CUDA and CoreML require +/// an ORT build that includes those providers; otherwise ORT will return an error.
+fn apply_execution_provider( + builder: ort::session::builder::SessionBuilder, + ep: super::ExecutionProvider, +) -> Result<ort::session::builder::SessionBuilder, TtsError> { + use ort::execution_providers::{CUDAExecutionProvider, CoreMLExecutionProvider}; + match ep { + super::ExecutionProvider::Cpu => Ok(builder), + super::ExecutionProvider::Cuda => builder + .with_execution_providers([CUDAExecutionProvider::default().build()]) + .map_err(|e| TtsError::Model(format!("CUDA execution provider error: {e}"))), + super::ExecutionProvider::CoreMl => builder + .with_execution_providers([CoreMLExecutionProvider::default().build()]) + .map_err(|e| TtsError::Model(format!("CoreML execution provider error: {e}"))), + } +} + /// SiLU-gated MLP text projection: 2048 → 2048. fn text_project( input: &Array1, From e1c137a0ab165b2557b395d2fd6b363c6d2724a2 Mon Sep 17 00:00:00 2001 From: Eason WaveKat Date: Tue, 7 Apr 2026 14:04:32 +1200 Subject: [PATCH 15/15] fix: one-time warning and Cargo.toml cleanup - Wrap no-instruction eprintln! in Once so it fires at most once per process rather than on every synthesize() call - Remove double blank line and empty [dev-dependencies] in Cargo.toml Co-Authored-By: Claude Sonnet 4.6 --- crates/wavekat-tts/Cargo.toml | 3 --- crates/wavekat-tts/src/backends/qwen3_tts/mod.rs | 16 +++++++++++----- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/crates/wavekat-tts/Cargo.toml b/crates/wavekat-tts/Cargo.toml index 5a0b7fa..73b7c20 100644 --- a/crates/wavekat-tts/Cargo.toml +++ b/crates/wavekat-tts/Cargo.toml @@ -15,7 +15,6 @@ default = [] qwen3-tts = ["dep:ort", "dep:ndarray", "dep:tokenizers", "dep:npyz", "dep:rand", "dep:hf-hub"] cosyvoice = ["dep:ort", "dep:ndarray"] - [dependencies] wavekat-core = { version = "0.0.5", features = ["wav"] } thiserror = "2" @@ -30,8 +29,6 @@ npyz = { version = "0.8", optional = true } rand = { version = "0.9", optional = true } hf-hub = { version = "0.5", optional = true, default-features = false, features = ["ureq"] } -[dev-dependencies] -
[[example]] name = "synthesize" required-features = ["qwen3-tts"] diff --git a/crates/wavekat-tts/src/backends/qwen3_tts/mod.rs b/crates/wavekat-tts/src/backends/qwen3_tts/mod.rs index 8acb92e..eed072d 100644 --- a/crates/wavekat-tts/src/backends/qwen3_tts/mod.rs +++ b/crates/wavekat-tts/src/backends/qwen3_tts/mod.rs @@ -32,8 +32,12 @@ use crate::error::TtsError; use crate::traits::TtsBackend; use crate::types::{SynthesizeRequest, VoiceInfo}; +use std::sync::Once; + use tokenizer::{IM_END, IM_START, NEWLINE}; +static WARNED_NO_INSTRUCTION: Once = Once::new(); + mod download; mod model; mod sampler; @@ -167,11 +171,13 @@ impl TtsBackend for Qwen3Tts { let language = request.language.unwrap_or("en"); if request.instruction.is_none() { - eprintln!( - "wavekat-tts warning: Qwen3-TTS is a VoiceDesign model — \ - synthesize quality may be inconsistent without a style instruction. \ - Set `SynthesizeRequest::with_instruction` to control voice style." - ); + WARNED_NO_INSTRUCTION.call_once(|| { + eprintln!( + "wavekat-tts warning: Qwen3-TTS is a VoiceDesign model — \ + synthesize quality may be inconsistent without a style instruction. \ + Set `SynthesizeRequest::with_instruction` to control voice style." + ); + }); } let instruction_tokens = if let Some(instr) = request.instruction {