diff --git a/CLAUDE.md b/CLAUDE.md index 6e129ea..442e5b7 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -32,8 +32,3 @@ make test-all # test all backends use internal tokio runtimes. 4. **Feature flags per backend** โ€” same pattern as wavekat-vad/turn. -## Pending wavekat-core change - -`AudioFrame::from_owned(Vec, u32) -> AudioFrame<'static>` โ€” avoids the -borrow-then-clone path when creating frames from producer-owned data (TTS output). -Currently uses `AudioFrame::new(slice, rate).into_owned()` as a workaround. diff --git a/README.md b/README.md index ccfe448..8541d65 100644 --- a/README.md +++ b/README.md @@ -18,10 +18,10 @@ Same pattern as ## Backends -| Backend | Feature flag | License | -|---------|-------------|---------| -| [Qwen3-TTS](https://huggingface.co/Qwen/Qwen3-TTS) | `qwen3-tts` | Apache 2.0 | -| [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) | `cosyvoice` | Apache 2.0 | +| Backend | Feature flag | Status | License | +|---------|-------------|--------|---------| +| [Qwen3-TTS](https://huggingface.co/Qwen/Qwen3-TTS) | `qwen3-tts` | โœ… Available | Apache 2.0 | +| [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) | `cosyvoice` | ๐Ÿšง Planned | Apache 2.0 | ## Quick start @@ -31,21 +31,33 @@ cargo add wavekat-tts --features qwen3-tts ```rust use wavekat_tts::{TtsBackend, SynthesizeRequest}; -use wavekat_tts::backends::qwen3_tts::Qwen3Tts; +use wavekat_tts::backends::qwen3_tts::{Qwen3Tts, ModelConfig, ModelPrecision, ExecutionProvider}; -// Auto-downloads model files (~3.8 GB) on first run: +// Auto-downloads INT4 model files on first run, runs on CPU (default): let tts = Qwen3Tts::new()?; -// Or load from an explicit directory: -// let tts = Qwen3Tts::from_dir("models/qwen3-tts-0.6b")?; +// Or FP32 on CPU: +// let tts = Qwen3Tts::from_config(ModelConfig::default().with_precision(ModelPrecision::Fp32))?; -let request = SynthesizeRequest::new("Hello, world"); +// Or INT4 from a local directory on CUDA: +// let tts = Qwen3Tts::from_config( +// ModelConfig::default() +// .with_dir("models/qwen3-tts-1.7b") +// .with_execution_provider(ExecutionProvider::Cuda), +// )?; + +let request = SynthesizeRequest::new("Hello, world") + .with_instruction("Speak naturally and clearly."); let audio = tts.synthesize(&request)?; +// Save to WAV (wavekat-core includes WAV I/O via the `wav` feature): +audio.write_wav("output.wav")?; + println!("{}s at {} Hz", audio.duration_secs(), audio.sample_rate()); ``` -Model files are cached at `$WAVEKAT_MODEL_DIR` or `~/.cache/wavekat/qwen3-tts-0.6b/`. +Model files are cached by the HF Hub client at `$HF_HOME/hub/` (default `~/.cache/huggingface/hub/`). +Set `WAVEKAT_MODEL_DIR` to load from a local directory and skip all downloads. All backends produce `AudioFrame<'static>` from [`wavekat-core`](https://github.com/wavekat/wavekat-core) โ€” the same type consumed by `wavekat-vad` and `wavekat-turn`. @@ -72,9 +84,10 @@ Two trait families: Generate a WAV file from text (model files are auto-downloaded on first run): ```sh -cargo run --example synthesize --features qwen3-tts,hound -- "Hello, world\!" -cargo run --example synthesize --features qwen3-tts,hound -- --language zh "ไฝ ๅฅฝไธ–็•Œ" -cargo run --example synthesize --features qwen3-tts,hound -- --model-dir /path/to/model --output hello.wav "Hello" +cargo run --example synthesize --features qwen3-tts -- "Hello, world\!" +cargo run --example synthesize --features qwen3-tts -- --instruction "Speak in a warm, friendly tone." "Give every small business the voice of a big one." +cargo run --example synthesize --features qwen3-tts -- --precision fp32 "Hello" +cargo run --example synthesize --features qwen3-tts -- --model-dir /path/to/model --output hello.wav "Hello" ``` ## Feature flags @@ -82,7 +95,9 @@ cargo run --example synthesize --features qwen3-tts,hound -- --model-dir /path/t | Flag | Default | Description | |------|---------|-------------| | `qwen3-tts` | off | Qwen3-TTS local ONNX inference | -| `cosyvoice` | off | CosyVoice local ONNX inference | +| `cosyvoice` | off | CosyVoice local ONNX inference (planned) | + +WAV I/O (`write_wav` / `from_wav`) is provided by `wavekat-core` via its `wav` feature flag. ## License diff --git a/crates/wavekat-tts/Cargo.toml b/crates/wavekat-tts/Cargo.toml index ce39065..73b7c20 100644 --- a/crates/wavekat-tts/Cargo.toml +++ b/crates/wavekat-tts/Cargo.toml @@ -12,11 +12,11 @@ categories = ["multimedia::audio"] default = [] # Local inference backends (all ONNX-based) -qwen3-tts = ["dep:ort", "dep:ndarray", "dep:tokenizers", "dep:npyz", "dep:rand", "dep:ureq"] +qwen3-tts = ["dep:ort", "dep:ndarray", "dep:tokenizers", "dep:npyz", "dep:rand", "dep:hf-hub"] cosyvoice = ["dep:ort", "dep:ndarray"] [dependencies] -wavekat-core = "0.0.3" +wavekat-core = { version = "0.0.5", features = ["wav"] } thiserror = "2" serde = { version = "1", features = ["derive"] } serde_json = "1" @@ -27,12 +27,8 @@ ndarray = { version = "0.17", optional = true } tokenizers = { version = "0.21", optional = true, default-features = false, features = ["onig"] } npyz = { version = "0.8", optional = true } rand = { version = "0.9", optional = true } -ureq = { version = "2", optional = true } -hound = { version = "3.5", optional = true } - -[dev-dependencies] -hound = "3.5" +hf-hub = { version = "0.5", optional = true, default-features = false, features = ["ureq"] } [[example]] name = "synthesize" -required-features = ["qwen3-tts", "hound"] +required-features = ["qwen3-tts"] diff --git a/crates/wavekat-tts/examples/synthesize.rs b/crates/wavekat-tts/examples/synthesize.rs index 0b32b2a..f0b3ca4 100644 --- a/crates/wavekat-tts/examples/synthesize.rs +++ b/crates/wavekat-tts/examples/synthesize.rs @@ -1,29 +1,46 @@ //! Synthesize text to a WAV file using Qwen3-TTS. //! //! Usage: -//! cargo run --example synthesize --features qwen3-tts,hound -- [OPTIONS] [TEXT] +//! cargo run --example synthesize --features qwen3-tts -- [OPTIONS] [TEXT] //! //! Options: -//! --model-dir Model directory (default: auto-download to cache) -//! --language Language code (default: en) -//! --output Output WAV path (default: output.wav) -//! -i, --interactive Interactive mode: keep model loaded, read text from stdin +//! --model-dir Model directory (default: auto-download to cache) +//! --precision Model precision: int4 (default) or fp32 +//! --language Language code (default: en) +//! --instruction Voice style instruction (VoiceDesign prompt) +//! Default: "Speak naturally and clearly." +//! --output Output WAV path (default: output.wav) +//! -i, --interactive Interactive mode: keep model loaded, read text from stdin +//! +//! Interactive commands (prefix with /): +//! /lang Switch language (e.g. /lang ja) +//! /langs List supported language codes +//! /instruct Change voice instruction (e.g. /instruct Speak slowly.) +//! /instruct Reset instruction to default +//! /status Show current settings +//! /help Show this command list +//! Empty line or Ctrl-D Quit //! //! Example: -//! cargo run --example synthesize --features qwen3-tts,hound -- "Hello, world!" -//! cargo run --example synthesize --features qwen3-tts,hound -- -i +//! cargo run --example synthesize --features qwen3-tts -- "Hello, world!" +//! cargo run --example synthesize --features qwen3-tts -- -i +//! cargo run --example synthesize --features qwen3-tts -- --precision fp32 -i use std::io::{self, BufRead, Write}; use std::path::PathBuf; -use wavekat_tts::backends::qwen3_tts::Qwen3Tts; +use wavekat_tts::backends::qwen3_tts::{ModelConfig, ModelPrecision, Qwen3Tts}; use wavekat_tts::{SynthesizeRequest, TtsBackend}; +const DEFAULT_INSTRUCTION: &str = "Speak naturally and clearly."; + fn main() { let args: Vec = std::env::args().skip(1).collect(); let mut model_dir: Option = None; + let mut precision = ModelPrecision::Int4; let mut language = "en".to_string(); + let mut instruction: Option = None; let mut output = PathBuf::from("output.wav"); let mut interactive = false; let mut text_parts: Vec = Vec::new(); @@ -35,10 +52,25 @@ fn main() { i += 1; model_dir = Some(PathBuf::from(&args[i])); } + "--precision" => { + i += 1; + precision = match args[i].as_str() { + "int4" => ModelPrecision::Int4, + "fp32" => ModelPrecision::Fp32, + other => { + eprintln!("error: unknown precision \"{other}\", expected int4 or fp32"); + std::process::exit(1); + } + }; + } "--language" => { i += 1; language = args[i].clone(); } + "--instruction" => { + i += 1; + instruction = Some(args[i].clone()); + } "--output" => { i += 1; output = PathBuf::from(&args[i]); @@ -52,28 +84,51 @@ fn main() { let text = text_parts.join(" "); if text.is_empty() && !interactive { eprintln!("Usage: synthesize [OPTIONS] [TEXT]"); - eprintln!(" --model-dir Model directory (default: auto-download)"); - eprintln!(" --language Language code (default: en)"); - eprintln!(" --output Output WAV path (default: output.wav)"); - eprintln!(" -i, --interactive Interactive mode (read from stdin)"); + eprintln!(" --model-dir Model directory (default: auto-download)"); + eprintln!(" --precision Model precision: int4 (default) or fp32"); + eprintln!(" --language Language code (default: en)"); + eprintln!(" --instruction Voice style instruction (VoiceDesign prompt)"); + eprintln!(" Default: \"{DEFAULT_INSTRUCTION}\""); + eprintln!(" --output Output WAV path (default: output.wav)"); + eprintln!(" -i, --interactive Interactive mode (read from stdin)"); std::process::exit(1); } + if instruction.is_none() { + eprintln!("note: no --instruction given, using default: \"{DEFAULT_INSTRUCTION}\""); + instruction = Some(DEFAULT_INSTRUCTION.to_string()); + } + eprintln!("Loading model ..."); - let tts = match model_dir { - Some(dir) => Qwen3Tts::from_dir(dir).expect("failed to load model"), - None => Qwen3Tts::new().expect("failed to load model"), - }; + let mut config = ModelConfig::default().with_precision(precision); + if let Some(dir) = model_dir { + config = config.with_dir(dir); + } + let tts = Qwen3Tts::from_config(config).expect("failed to load model"); if interactive { - run_interactive(&tts, &language, &output); + run_interactive(&tts, language, instruction.unwrap(), &output); } else { - synthesize_one(&tts, &text, &language, &output); + synthesize_one(&tts, &text, &language, instruction.as_deref(), &output); } } -fn run_interactive(tts: &Qwen3Tts, language: &str, default_output: &PathBuf) { - eprintln!("Interactive mode. Type text to synthesize, empty line to quit."); +fn run_interactive( + tts: &Qwen3Tts, + mut language: String, + mut instruction: String, + default_output: &PathBuf, +) { + let supported_langs: Vec = tts + .voices() + .unwrap_or_default() + .into_iter() + .flat_map(|v| v.languages) + .collect(); + + eprintln!("Interactive mode. Type text to synthesize, /help for commands, empty line to quit."); + eprintln!(" language={language} instruction=\"{instruction}\""); + let stdin = io::stdin(); let mut count = 0u32; @@ -85,24 +140,80 @@ fn run_interactive(tts: &Qwen3Tts, language: &str, default_output: &PathBuf) { if stdin.lock().read_line(&mut line).unwrap_or(0) == 0 { break; } - let text = line.trim(); - if text.is_empty() { + let input = line.trim(); + if input.is_empty() { break; } + if let Some(rest) = input.strip_prefix('/') { + let (cmd, arg) = rest + .split_once(' ') + .map_or((rest, ""), |(c, a)| (c, a.trim())); + match cmd { + "lang" | "language" => { + if arg.is_empty() { + eprintln!("usage: /lang โ€” type /langs to list supported codes"); + } else if !supported_langs.iter().any(|l| l == arg) { + eprintln!("unsupported language: \"{arg}\""); + eprintln!("supported: {}", supported_langs.join(", ")); + } else { + language = arg.to_string(); + eprintln!("language set to: {language}"); + } + } + "langs" | "languages" => { + eprintln!("supported languages: {}", supported_langs.join(", ")); + } + "instruct" | "instruction" => { + if arg.is_empty() { + instruction = DEFAULT_INSTRUCTION.to_string(); + eprintln!("instruction reset to default: \"{instruction}\""); + } else { + instruction = arg.to_string(); + eprintln!("instruction set to: \"{instruction}\""); + } + } + "status" => { + eprintln!(" language={language}"); + eprintln!(" instruction=\"{instruction}\""); + eprintln!(" supported languages: {}", supported_langs.join(", ")); + } + "help" => { + eprintln!(" /lang Switch language"); + eprintln!(" /langs List supported language codes"); + eprintln!(" /instruct Change voice instruction"); + eprintln!(" /instruct Reset instruction to default"); + eprintln!(" /status Show current settings"); + eprintln!(" /help Show this help"); + eprintln!(" Empty line Quit"); + } + other => eprintln!("unknown command: /{other} (type /help for commands)"), + } + continue; + } + count += 1; - let output = if *default_output == PathBuf::from("output.wav") { + let output = if default_output == std::path::Path::new("output.wav") { PathBuf::from(format!("output_{count:03}.wav")) } else { default_output.clone() }; - synthesize_one(tts, text, language, &output); + synthesize_one(tts, input, &language, Some(&instruction), &output); } } -fn synthesize_one(tts: &Qwen3Tts, text: &str, language: &str, output: &PathBuf) { - let request = SynthesizeRequest::new(text).with_language(language); +fn synthesize_one( + tts: &Qwen3Tts, + text: &str, + language: &str, + instruction: Option<&str>, + output: &PathBuf, +) { + let mut request = SynthesizeRequest::new(text).with_language(language); + if let Some(instr) = instruction { + request = request.with_instruction(instr); + } eprintln!("Synthesizing: \"{text}\" (language={language})"); let start = std::time::Instant::now(); @@ -110,7 +221,7 @@ fn synthesize_one(tts: &Qwen3Tts, text: &str, language: &str, output: &PathBuf) let elapsed = start.elapsed(); let duration = audio.duration_secs(); - let rtf = elapsed.as_secs_f64() / duration as f64; + let rtf = elapsed.as_secs_f64() / duration; eprintln!( "Generated {} samples at {} Hz ({:.2}s) in {:.2}s (RTF: {:.2})", @@ -121,18 +232,7 @@ fn synthesize_one(tts: &Qwen3Tts, text: &str, language: &str, output: &PathBuf) rtf, ); - // Write WAV - let spec = hound::WavSpec { - channels: 1, - sample_rate: audio.sample_rate(), - bits_per_sample: 32, - sample_format: hound::SampleFormat::Float, - }; - let mut writer = hound::WavWriter::create(output, spec).expect("failed to create WAV file"); - for &sample in audio.samples() { - writer.write_sample(sample).expect("failed to write sample"); - } - writer.finalize().expect("failed to finalize WAV"); + audio.write_wav(output).expect("failed to write WAV"); eprintln!("Wrote {}", output.display()); } diff --git a/crates/wavekat-tts/src/backends/qwen3_tts/download.rs b/crates/wavekat-tts/src/backends/qwen3_tts/download.rs index da7b523..480ea01 100644 --- a/crates/wavekat-tts/src/backends/qwen3_tts/download.rs +++ b/crates/wavekat-tts/src/backends/qwen3_tts/download.rs @@ -1,229 +1,146 @@ -//! Auto-download model files to a local cache directory. +//! Download model files from HuggingFace Hub. -use std::fs; -use std::io::{Read, Write}; -use std::path::{Path, PathBuf}; +use std::path::PathBuf; + +use hf_hub::api::sync::ApiBuilder; +use hf_hub::{Repo, RepoType}; use crate::TtsError; -/// Pinned commit โ€” guarantees immutable file URLs. -const REVISION: &str = "6a297d9641354ef0c16e63d329a93a6239bca0a2"; - -const BASE_URL: &str = "https://huggingface.co/elbruno/Qwen3-TTS-12Hz-0.6B-Base-ONNX/resolve"; - -/// (remote_path, local_filename) โ€” remote paths map to the HF repo layout, -/// local filenames are flattened into the cache directory. -const MODEL_FILES: &[(&str, &str)] = &[ - // ONNX sessions + external weight files - ("talker_prefill.onnx", "talker_prefill.onnx"), - ("talker_prefill.onnx.data", "talker_prefill.onnx.data"), - ("talker_decode.onnx", "talker_decode.onnx"), - ("talker_decode.onnx.data", "talker_decode.onnx.data"), - ("code_predictor.onnx", "code_predictor.onnx"), - ("vocoder.onnx", "vocoder.onnx"), - ("vocoder.onnx.data", "vocoder.onnx.data"), - // Embedding tables (in embeddings/ on HF, flattened locally) - ("embeddings/text_embedding.npy", "text_embedding.npy"), - ( - "embeddings/text_projection_fc1_weight.npy", - "text_projection_fc1_weight.npy", - ), - ( - "embeddings/text_projection_fc1_bias.npy", - "text_projection_fc1_bias.npy", - ), - ( - "embeddings/text_projection_fc2_weight.npy", - "text_projection_fc2_weight.npy", - ), - ( - "embeddings/text_projection_fc2_bias.npy", - "text_projection_fc2_bias.npy", - ), - ( - "embeddings/talker_codec_embedding.npy", - "talker_codec_embedding.npy", - ), - ( - "embeddings/cp_codec_embedding_0.npy", - "cp_codec_embedding_0.npy", - ), - ( - "embeddings/cp_codec_embedding_1.npy", - "cp_codec_embedding_1.npy", - ), - ( - "embeddings/cp_codec_embedding_2.npy", - "cp_codec_embedding_2.npy", - ), - ( - "embeddings/cp_codec_embedding_3.npy", - "cp_codec_embedding_3.npy", - ), - ( - "embeddings/cp_codec_embedding_4.npy", - "cp_codec_embedding_4.npy", - ), - ( - "embeddings/cp_codec_embedding_5.npy", - "cp_codec_embedding_5.npy", - ), - ( - "embeddings/cp_codec_embedding_6.npy", - "cp_codec_embedding_6.npy", - ), - ( - "embeddings/cp_codec_embedding_7.npy", - "cp_codec_embedding_7.npy", - ), - ( - "embeddings/cp_codec_embedding_8.npy", - "cp_codec_embedding_8.npy", - ), - ( - "embeddings/cp_codec_embedding_9.npy", - "cp_codec_embedding_9.npy", - ), - ( - "embeddings/cp_codec_embedding_10.npy", - "cp_codec_embedding_10.npy", - ), - ( - "embeddings/cp_codec_embedding_11.npy", - "cp_codec_embedding_11.npy", - ), - ( - "embeddings/cp_codec_embedding_12.npy", - "cp_codec_embedding_12.npy", - ), - ( - "embeddings/cp_codec_embedding_13.npy", - "cp_codec_embedding_13.npy", - ), - ( - "embeddings/cp_codec_embedding_14.npy", - "cp_codec_embedding_14.npy", - ), - // Tokenizer (in tokenizer/ on HF, flattened locally) - ("tokenizer/vocab.json", "vocab.json"), - ("tokenizer/merges.txt", "merges.txt"), +const REPO_ID: &str = "wavekat/Qwen3-TTS-1.7B-VoiceDesign-ONNX"; +const REVISION: &str = "2026-04-06"; + +/// ONNX model files for INT4 precision. +const ONNX_FILES_INT4: &[&str] = &[ + "int4/talker_prefill.onnx", + "int4/talker_prefill.onnx.data", + "int4/talker_decode.onnx", + "int4/talker_decode.onnx.data", + "int4/code_predictor.onnx", + "int4/code_predictor.onnx.data", + "int4/vocoder.onnx", + "int4/vocoder.onnx.data", +]; + +/// ONNX model files for FP32 precision. +const ONNX_FILES_FP32: &[&str] = &[ + "fp32/talker_prefill.onnx", + "fp32/talker_prefill.onnx.data", + "fp32/talker_decode.onnx", + "fp32/talker_decode.onnx.data", + "fp32/code_predictor.onnx", + "fp32/code_predictor.onnx.data", + "fp32/vocoder.onnx", + "fp32/vocoder.onnx.data", +]; + +/// Shared files required for all precision variants (embeddings + tokenizer + config). +/// +/// `embeddings/small_to_mtp_projection_{weight,bias}.npy` are intentionally +/// excluded โ€” that projection is baked into `code_predictor.onnx`. +const SHARED_FILES: &[&str] = &[ + "config.json", + // Embedding tables + "embeddings/text_embedding.npy", + "embeddings/text_projection_fc1_weight.npy", + "embeddings/text_projection_fc1_bias.npy", + "embeddings/text_projection_fc2_weight.npy", + "embeddings/text_projection_fc2_bias.npy", + "embeddings/talker_codec_embedding.npy", + "embeddings/cp_codec_embedding_0.npy", + "embeddings/cp_codec_embedding_1.npy", + "embeddings/cp_codec_embedding_2.npy", + "embeddings/cp_codec_embedding_3.npy", + "embeddings/cp_codec_embedding_4.npy", + "embeddings/cp_codec_embedding_5.npy", + "embeddings/cp_codec_embedding_6.npy", + "embeddings/cp_codec_embedding_7.npy", + "embeddings/cp_codec_embedding_8.npy", + "embeddings/cp_codec_embedding_9.npy", + "embeddings/cp_codec_embedding_10.npy", + "embeddings/cp_codec_embedding_11.npy", + "embeddings/cp_codec_embedding_12.npy", + "embeddings/cp_codec_embedding_13.npy", + "embeddings/cp_codec_embedding_14.npy", + // Tokenizer + "tokenizer/vocab.json", + "tokenizer/merges.txt", ]; -/// Resolve the default model cache directory, downloading any missing files. +/// Resolve the local directory for the Qwen3-TTS model files, downloading +/// any missing files from HF Hub as needed. /// /// Resolution order: -/// 1. `$WAVEKAT_MODEL_DIR` if set -/// 2. `$XDG_CACHE_HOME/wavekat/qwen3-tts-0.6b/` -/// 3. `$HOME/.cache/wavekat/qwen3-tts-0.6b/` -pub fn ensure_model_dir() -> Result { - let dir = default_cache_dir()?; - ensure_files(&dir)?; - Ok(dir) -} +/// 1. `config.model_dir` (if set) +/// 2. `WAVEKAT_MODEL_DIR` environment variable +/// 3. Auto-download from HF Hub +/// +/// Authentication: set `HF_TOKEN` if the repo requires it. hf-hub 0.5 does +/// not read `HF_TOKEN` from the environment natively; this function bridges +/// the gap by passing it to `ApiBuilder::with_token`. +/// +/// Cache location: `$HF_HOME/hub/` (default `~/.cache/huggingface/hub/`). +pub fn resolve_model_dir(config: &super::ModelConfig) -> Result { + if let Some(dir) = &config.model_dir { + return Ok(dir.clone()); + } -fn default_cache_dir() -> Result { if let Ok(dir) = std::env::var("WAVEKAT_MODEL_DIR") { return Ok(PathBuf::from(dir)); } - let base = if let Ok(xdg) = std::env::var("XDG_CACHE_HOME") { - PathBuf::from(xdg) - } else if let Ok(home) = std::env::var("HOME") { - PathBuf::from(home).join(".cache") - } else { - return Err(TtsError::Model( - "cannot determine cache directory: set WAVEKAT_MODEL_DIR or HOME".into(), - )); - }; - Ok(base.join("wavekat").join("qwen3-tts-0.6b")) -} -fn ensure_files(dir: &Path) -> Result<(), TtsError> { - let missing: Vec<(&str, &str)> = MODEL_FILES - .iter() - .filter(|(_, local)| !dir.join(local).exists()) - .copied() - .collect(); + let precision = config.precision; - if missing.is_empty() { - return Ok(()); + // from_env() reads HF_HOME / HF_ENDPOINT. + // Bridge HF_TOKEN which hf-hub doesn't read from the environment natively. + let mut builder = ApiBuilder::from_env(); + if let Ok(token) = std::env::var("HF_TOKEN") { + if !token.is_empty() { + builder = builder.with_token(Some(token)); + } } + let api = builder + .build() + .map_err(|e| TtsError::Model(format!("failed to initialize HF Hub client: {e}")))?; + + let repo = api.repo(Repo::with_revision( + REPO_ID.to_string(), + RepoType::Model, + REVISION.to_string(), + )); + + let onnx_files = match precision { + super::ModelPrecision::Int4 => ONNX_FILES_INT4, + super::ModelPrecision::Fp32 => ONNX_FILES_FP32, + }; + let total = 1 + onnx_files.len() + SHARED_FILES[1..].len(); // config + onnx + shared (excl. config) - fs::create_dir_all(dir).map_err(|e| { - TtsError::Model(format!("failed to create cache dir {}: {e}", dir.display())) - })?; - - let total = missing.len(); eprintln!( - "Downloading Qwen3-TTS model ({total} files) to {} ...", - dir.display() + "Ensuring Qwen3-TTS 1.7B ({}) model ({total} files from {REPO_ID})...", + precision.subdir() ); - for (i, (remote, local)) in missing.iter().enumerate() { - let url = format!("{BASE_URL}/{REVISION}/{remote}"); - let dest = dir.join(local); - eprintln!("[{}/{}] {}", i + 1, total, local); - download_file(&url, &dest)?; - } - - eprintln!("Download complete."); - Ok(()) -} - -fn download_file(url: &str, dest: &Path) -> Result<(), TtsError> { - let response = ureq::get(url) - .call() - .map_err(|e| TtsError::Model(format!("download failed for {url}: {e}")))?; - - let content_length: Option = response - .header("Content-Length") - .and_then(|s| s.parse().ok()); - - let mut reader = response.into_reader(); + // config.json first โ€” its parent is the snapshot root. + eprintln!("[1/{total}] {}", SHARED_FILES[0]); + let config_path = repo + .get(SHARED_FILES[0]) + .map_err(|e| TtsError::Model(format!("failed to download {}: {e}", SHARED_FILES[0])))?; - // Write to a temp file first, then rename to avoid partial files on interrupt. - let tmp = dest.with_extension("tmp"); - let mut file = fs::File::create(&tmp) - .map_err(|e| TtsError::Model(format!("failed to create {}: {e}", tmp.display())))?; + let model_dir = config_path + .parent() + .ok_or_else(|| TtsError::Model("unexpected cache path for config.json".into()))? + .to_path_buf(); - let mut buf = [0u8; 256 * 1024]; - let mut downloaded: u64 = 0; - let mut last_report: u64 = 0; - - loop { - let n = reader - .read(&mut buf) - .map_err(|e| TtsError::Model(format!("download read error: {e}")))?; - if n == 0 { - break; - } - file.write_all(&buf[..n]) - .map_err(|e| TtsError::Model(format!("write error: {e}")))?; - downloaded += n as u64; - - // Report progress every 50 MB for large files. - if downloaded - last_report >= 50_000_000 { - if let Some(total) = content_length { - let mb = downloaded / 1_000_000; - let total_mb = total / 1_000_000; - eprint!("\r {mb}/{total_mb} MB"); - } - last_report = downloaded; - } - } - - // Clear progress line if we printed any. - if last_report > 0 { - eprintln!(); + for (i, filename) in onnx_files + .iter() + .chain(SHARED_FILES[1..].iter()) + .enumerate() + { + eprintln!("[{}/{total}] {filename}", i + 2); + repo.get(filename) + .map_err(|e| TtsError::Model(format!("failed to download {filename}: {e}")))?; } - drop(file); - fs::rename(&tmp, dest).map_err(|e| { - TtsError::Model(format!( - "failed to rename {} โ†’ {}: {e}", - tmp.display(), - dest.display() - )) - })?; - - Ok(()) + eprintln!("Files ready. Loading model ..."); + Ok(model_dir) } diff --git a/crates/wavekat-tts/src/backends/qwen3_tts/mod.rs b/crates/wavekat-tts/src/backends/qwen3_tts/mod.rs index db257be..eed072d 100644 --- a/crates/wavekat-tts/src/backends/qwen3_tts/mod.rs +++ b/crates/wavekat-tts/src/backends/qwen3_tts/mod.rs @@ -1,22 +1,30 @@ -//! Qwen3-TTS backend (ONNX, 12Hz-0.6B). +//! Qwen3-TTS backend (ONNX, 1.7B VoiceDesign). //! -//! Runs the Qwen3-TTS-12Hz-0.6B model via ONNX Runtime. +//! Runs the Qwen3-TTS-12Hz-1.7B-VoiceDesign model via ONNX Runtime. +//! Supports INT4 (weight-only quantized, default) and FP32 precision. //! //! ```ignore //! use wavekat_tts::{TtsBackend, SynthesizeRequest}; -//! use wavekat_tts::backends::qwen3_tts::Qwen3Tts; +//! use wavekat_tts::backends::qwen3_tts::{Qwen3Tts, ModelConfig, ModelPrecision, ExecutionProvider}; //! -//! // Auto-download model files (~3.8 GB, cached for reuse): +//! // Auto-download INT4 model files via HF Hub, run on CPU (default): //! let tts = Qwen3Tts::new()?; //! -//! // Or load from an explicit directory: -//! let tts = Qwen3Tts::from_dir("models/qwen3-tts-0.6b")?; +//! // Auto-download FP32, run on CPU: +//! let tts = Qwen3Tts::from_config(ModelConfig::default().with_precision(ModelPrecision::Fp32))?; +//! +//! // INT4 from a local directory, run on CUDA: +//! let tts = Qwen3Tts::from_config( +//! ModelConfig::default() +//! .with_dir("models/qwen3-tts-1.7b") +//! .with_execution_provider(ExecutionProvider::Cuda), +//! )?; //! //! let request = SynthesizeRequest::new("Hello, world"); //! let audio = tts.synthesize(&request)?; //! ``` -use std::path::Path; +use std::path::PathBuf; use wavekat_core::AudioFrame; @@ -24,11 +32,109 @@ use crate::error::TtsError; use crate::traits::TtsBackend; use crate::types::{SynthesizeRequest, VoiceInfo}; +use std::sync::Once; + +use tokenizer::{IM_END, IM_START, NEWLINE}; + +static WARNED_NO_INSTRUCTION: Once = Once::new(); + mod download; mod model; mod sampler; mod tokenizer; +/// ONNX model precision variant. +/// +/// Selects which quantized model files to download and load. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum ModelPrecision { + /// Weight-only INT4 quantized โ€” smaller download, faster load. Default. + #[default] + Int4, + /// Full FP32 โ€” larger download, no quantization error. + Fp32, +} + +impl ModelPrecision { + pub(crate) fn subdir(self) -> &'static str { + match self { + Self::Int4 => "int4", + Self::Fp32 => "fp32", + } + } +} + +/// ONNX execution provider (inference hardware backend). +/// +/// Selecting a provider that is unavailable at runtime causes an error at load +/// time rather than silently falling back. Use [`ExecutionProvider::Cpu`] (the +/// default) if you need guaranteed availability. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum ExecutionProvider { + /// CPU inference via ONNX Runtime. Always available. Default. + #[default] + Cpu, + /// NVIDIA CUDA GPU inference. Requires an ORT build with CUDA support. + Cuda, + /// Apple CoreML (macOS / iOS). Requires an ORT build with CoreML support. + CoreMl, +} + +/// Model loading configuration for [`Qwen3Tts`]. +/// +/// All fields default to sensible values: INT4 quantization, CPU inference, +/// and auto-download from HF Hub. +/// +/// # Examples +/// +/// ```rust,no_run +/// # use wavekat_tts::backends::qwen3_tts::{ModelConfig, ModelPrecision, ExecutionProvider}; +/// // INT4, CPU, auto-download (equivalent to ModelConfig::default()) +/// let config = ModelConfig::default(); +/// +/// // FP32 from a local directory +/// let config = ModelConfig::default() +/// .with_precision(ModelPrecision::Fp32) +/// .with_dir("models/qwen3-tts-1.7b"); +/// +/// // INT4, CUDA, auto-download +/// let config = ModelConfig::default() +/// .with_execution_provider(ExecutionProvider::Cuda); +/// ``` +#[derive(Debug, Clone, Default)] +pub struct ModelConfig { + /// Weight quantization variant (determines which ONNX files to load). + pub precision: ModelPrecision, + /// Inference hardware backend. + pub execution_provider: ExecutionProvider, + /// Local model directory. `None` = resolve via `WAVEKAT_MODEL_DIR` env var, + /// then auto-download from HF Hub. + pub model_dir: Option, +} + +impl ModelConfig { + /// Set a local model directory, bypassing HF Hub download. + /// + /// The directory must mirror the HF repo layout: + /// `int4/` or `fp32/`, `embeddings/`, `tokenizer/`. + pub fn with_dir(mut self, dir: impl Into) -> Self { + self.model_dir = Some(dir.into()); + self + } + + /// Set the model precision (quantization variant). + pub fn with_precision(mut self, precision: ModelPrecision) -> Self { + self.precision = precision; + self + } + + /// Set the ONNX execution provider. + pub fn with_execution_provider(mut self, ep: ExecutionProvider) -> Self { + self.execution_provider = ep; + self + } +} + /// Qwen3-TTS backend using ONNX Runtime. pub struct Qwen3Tts { model: model::Model, @@ -36,31 +142,25 @@ pub struct Qwen3Tts { } impl Qwen3Tts { - /// Create a new backend, auto-downloading model files if needed. - /// - /// Model files (~3.8 GB) are cached at: - /// - `$WAVEKAT_MODEL_DIR` if set, otherwise - /// - `$XDG_CACHE_HOME/wavekat/qwen3-tts-0.6b/`, otherwise - /// - `$HOME/.cache/wavekat/qwen3-tts-0.6b/` + /// Create a new backend with default config (INT4, CPU, auto-download). /// - /// Use [`from_dir`](Self::from_dir) to skip auto-download and load from - /// a specific directory. + /// Files are cached by the HF Hub client (default `~/.cache/huggingface/hub/`). + /// Set `HF_HOME` to change the cache root, `HF_TOKEN` for authentication, or + /// `WAVEKAT_MODEL_DIR` to load from a local directory and skip all downloads. pub fn new() -> Result { - let model_dir = download::ensure_model_dir()?; - Self::from_dir(model_dir) + Self::from_config(ModelConfig::default()) } - /// Load the model from a directory containing ONNX files and embeddings. + /// Create a new backend with the given [`ModelConfig`]. /// - /// Expected files: - /// - `talker_prefill.onnx`, `talker_decode.onnx`, `code_predictor.onnx`, `vocoder.onnx` - /// - `text_embedding.npy`, `text_projection_fc1_weight.npy`, etc. - /// - `talker_codec_embedding.npy`, `cp_codec_embedding_{0..14}.npy` - /// - `vocab.json`, `merges.txt` - pub fn from_dir(model_dir: impl AsRef) -> Result { - let model_dir = model_dir.as_ref(); - let model = model::Model::load(model_dir)?; - let tokenizer = tokenizer::Tokenizer::new(model_dir)?; + /// Model files are resolved in priority order: + /// 1. `config.model_dir` (if set) + /// 2. `WAVEKAT_MODEL_DIR` environment variable + /// 3. Auto-download from HF Hub + pub fn from_config(config: ModelConfig) -> Result { + let model_dir = download::resolve_model_dir(&config)?; + let model = model::Model::load(model_dir.as_ref(), &config)?; + let tokenizer = tokenizer::Tokenizer::new(&model_dir)?; Ok(Self { model, tokenizer }) } } @@ -69,7 +169,33 @@ impl TtsBackend for Qwen3Tts { fn synthesize(&self, request: &SynthesizeRequest) -> Result, TtsError> { let tokens = self.tokenizer.encode(request.text)?; let language = request.language.unwrap_or("en"); - self.model.synthesize(&tokens, language) + + if request.instruction.is_none() { + WARNED_NO_INSTRUCTION.call_once(|| { + eprintln!( + "wavekat-tts warning: Qwen3-TTS is a VoiceDesign model โ€” \ + synthesize quality may be inconsistent without a style instruction. \ + Set `SynthesizeRequest::with_instruction` to control voice style." + ); + }); + } + + let instruction_tokens = if let Some(instr) = request.instruction { + let mut toks = vec![IM_START]; + toks.extend(self.tokenizer.encode("user")?); + toks.push(NEWLINE); + toks.extend(self.tokenizer.encode("")?); + toks.extend(self.tokenizer.encode(instr)?); + toks.extend(self.tokenizer.encode("")?); + toks.push(IM_END); + toks.push(NEWLINE); + Some(toks) + } else { + None + }; + + self.model + .synthesize(&tokens, language, instruction_tokens.as_deref()) } fn voices(&self) -> Result, TtsError> { diff --git a/crates/wavekat-tts/src/backends/qwen3_tts/model.rs b/crates/wavekat-tts/src/backends/qwen3_tts/model.rs index 4be7990..ca0d99a 100644 --- a/crates/wavekat-tts/src/backends/qwen3_tts/model.rs +++ b/crates/wavekat-tts/src/backends/qwen3_tts/model.rs @@ -11,7 +11,7 @@ use crate::TtsError; use super::sampler::{self, SamplerConfig}; use super::tokenizer::{self, ASSISTANT, IM_START, NEWLINE, TTS_BOS, TTS_EOS, TTS_PAD}; -// Codec control token IDs +// Codec control token IDs (from config.json) const CODEC_PAD: i64 = 2148; const CODEC_BOS: i64 = 2149; const CODEC_THINK: i64 = 2154; @@ -21,27 +21,29 @@ const CODEC_THINK_EOS: i64 = 2157; /// Talker output: (logits, hidden_state, kv_keys, kv_values). type TalkerOutput = (Vec, Array3, Array5, Array5); -// Model dimensions (0.6B-12Hz) -const HIDDEN_DIM: usize = 1024; +// Model dimensions โ€” Qwen3-TTS-12Hz-1.7B-VoiceDesign +const HIDDEN_DIM: usize = 2048; const NUM_LAYERS: usize = 28; const NUM_KV_HEADS: usize = 8; const HEAD_DIM: usize = 128; +const TALKER_VOCAB_SIZE: usize = 3072; const CP_NUM_LAYERS: usize = 5; +const CP_NUM_KV_HEADS: usize = 8; const NUM_CP_GROUPS: usize = 15; // codebook groups 1-15 const SAMPLE_RATE: u32 = 24000; const CODEC_EOS: i64 = 2150; -const MAX_NEW_TOKENS: usize = 2048; +const MAX_NEW_TOKENS: usize = 8192; -/// Sampling defaults matching the reference implementation. +/// Sampling defaults from config.json generate_config. const TALKER_SAMPLER: SamplerConfig = SamplerConfig { - temperature: 0.7, - top_p: 0.8, + temperature: 0.9, + top_k: 50, repetition_penalty: 1.05, }; const CP_SAMPLER: SamplerConfig = SamplerConfig { - temperature: 0.2, - top_p: 0.5, + temperature: 0.9, + top_k: 50, repetition_penalty: 1.0, }; @@ -57,48 +59,70 @@ pub struct Model { // Embedding tables (immutable after construction) text_embedding: Array2, // (vocab, 2048) - text_proj_fc1_weight: Array2, // (1024, 2048) - text_proj_fc1_bias: Array1, // (1024,) - text_proj_fc2_weight: Array2, // (1024, 1024) - text_proj_fc2_bias: Array1, // (1024,) - talker_codec_embedding: Array2, // (3072, 1024) - cp_codec_embeddings: Vec>, // 15 ร— (2048, 1024) + text_proj_fc1_weight: Array2, // (2048, 2048) + text_proj_fc1_bias: Array1, // (2048,) + text_proj_fc2_weight: Array2, // (2048, 2048) + text_proj_fc2_bias: Array1, // (2048,) + talker_codec_embedding: Array2, // (3072, 2048) + cp_codec_embeddings: Vec>, // 15 ร— (2048, 2048) // Precomputed - tts_pad_embed: Array1, // (1024,) projected tts_pad text embedding + tts_pad_embed: Array1, // (2048,) projected tts_pad text embedding } impl Model { /// Load all ONNX sessions and embedding tables from `model_dir`. - pub fn load(model_dir: &Path) -> Result { + /// + /// Expected layout (matches the HF repo): + /// - `model_dir/{int4,fp32}/talker_prefill.onnx` (+ .data), etc. + /// - `model_dir/embeddings/text_embedding.npy`, etc. + pub fn load(model_dir: &Path, config: &super::ModelConfig) -> Result { let load_session = |name: &str| -> Result { - let path = model_dir.join(name); - Session::builder() - .map_err(|e| TtsError::Model(format!("session builder error: {e}")))? + let path = model_dir.join(config.precision.subdir()).join(name); + let builder = Session::builder() + .map_err(|e| TtsError::Model(format!("session builder error: {e}")))?; + apply_execution_provider(builder, config.execution_provider)? .commit_from_file(&path) .map_err(|e| TtsError::Model(format!("failed to load {name}: {e}"))) }; + eprint!("Loading talker prefill ... "); let talker_prefill = load_session("talker_prefill.onnx")?; + eprintln!("done"); + + eprint!("Loading talker decode ... "); let talker_decode = load_session("talker_decode.onnx")?; + eprintln!("done"); + + eprint!("Loading code predictor ... "); let code_predictor = load_session("code_predictor.onnx")?; - let vocoder = load_session("vocoder.onnx")?; + eprintln!("done"); - let text_embedding = load_npy2(model_dir, "text_embedding.npy")?; - let text_proj_fc1_weight = load_npy2(model_dir, "text_projection_fc1_weight.npy")?; - let text_proj_fc1_bias = load_npy1(model_dir, "text_projection_fc1_bias.npy")?; - let text_proj_fc2_weight = load_npy2(model_dir, "text_projection_fc2_weight.npy")?; - let text_proj_fc2_bias = load_npy1(model_dir, "text_projection_fc2_bias.npy")?; - let talker_codec_embedding = load_npy2(model_dir, "talker_codec_embedding.npy")?; + eprint!("Loading vocoder ... "); + let vocoder = load_session("vocoder.onnx")?; + eprintln!("done"); + + eprint!("Loading embeddings ... "); + let text_embedding = load_npy2(model_dir, "embeddings/text_embedding.npy")?; + let text_proj_fc1_weight = + load_npy2(model_dir, "embeddings/text_projection_fc1_weight.npy")?; + let text_proj_fc1_bias = load_npy1(model_dir, "embeddings/text_projection_fc1_bias.npy")?; + let text_proj_fc2_weight = + load_npy2(model_dir, "embeddings/text_projection_fc2_weight.npy")?; + let text_proj_fc2_bias = load_npy1(model_dir, "embeddings/text_projection_fc2_bias.npy")?; + let talker_codec_embedding = load_npy2(model_dir, "embeddings/talker_codec_embedding.npy")?; let mut cp_codec_embeddings = Vec::with_capacity(NUM_CP_GROUPS); for i in 0..NUM_CP_GROUPS { cp_codec_embeddings.push(load_npy2( model_dir, - &format!("cp_codec_embedding_{i}.npy"), + &format!("embeddings/cp_codec_embedding_{i}.npy"), )?); } + eprintln!("done"); + eprintln!("Model ready."); + let tts_pad_raw = text_embedding.row(TTS_PAD as usize).to_owned(); let tts_pad_embed = text_project( &tts_pad_raw, @@ -125,18 +149,23 @@ impl Model { } /// Run the full synthesis pipeline: prefill โ†’ decode โ†’ code predict โ†’ vocoder. + /// + /// `instruction_tokens` โ€” optional user-turn prefix for VoiceDesign control. + /// When `Some`, these tokens are embedded (text_proj only) at the start of + /// the prefill before the role prefix. pub fn synthesize( &self, text_tokens: &[u32], language: &str, + instruction_tokens: Option<&[u32]>, ) -> Result, TtsError> { let lang_id = tokenizer::language_id(language) .ok_or_else(|| TtsError::UnsupportedLanguage(language.to_string()))?; - let (prefill_embeds, trailing) = self.build_prefill_embeds(text_tokens, lang_id)?; + let prefill_embeds = self.build_prefill_embeds(text_tokens, lang_id, instruction_tokens)?; let prefill_len = prefill_embeds.shape()[1]; - // Run talker prefill + // Run talker prefill โ€” returns last-position logits let (logits, hidden_states, mut past_keys, mut past_values) = self.run_talker_prefill(&prefill_embeds, prefill_len)?; @@ -145,7 +174,6 @@ impl Model { let mut talker_past_tokens: Vec = Vec::new(); let mut current_logits = logits; - // Hidden state for code predictor: starts from prefill last position let mut current_hidden = hidden_states .slice(s![0, prefill_len - 1.., ..]) .to_owned() @@ -153,11 +181,12 @@ impl Model { .map_err(|e| TtsError::Synthesis(format!("reshape hidden: {e}")))?; for step in 0..MAX_NEW_TOKENS { + // Suppress CODEC_EOS for the first 2 steps (min_new_tokens=2) let group0 = sampler::sample( ¤t_logits, &TALKER_SAMPLER, &talker_past_tokens, - sampler::talker_mask, + |tok| sampler::talker_mask(tok) || (step < 2 && tok == CODEC_EOS as usize), ) as i64; if group0 == CODEC_EOS { @@ -171,17 +200,13 @@ impl Model { self.run_code_predictor(¤t_hidden, &mut codes)?; all_codes.push(codes); - // Build next talker input: sum of all 16 group embeddings + trailing text + // Build next talker input: sum of all 16 group embeddings + tts_pad (non-streaming) let mut next_embed = self.talker_codec_embedding.row(group0 as usize).to_owned(); for g in 0..NUM_CP_GROUPS { let cp_embed = self.cp_codec_embeddings[g].row(codes[g + 1] as usize); next_embed += &cp_embed; } - if step < trailing.len() { - next_embed += &trailing[step]; - } else { - next_embed += &self.tts_pad_embed; - } + next_embed += &self.tts_pad_embed; let next_embed = next_embed .into_shape_with_order((1, 1, HIDDEN_DIM)) @@ -219,53 +244,60 @@ impl Model { ) } - /// Build prefill embeddings and trailing text hidden states. + /// Build prefill embeddings (non-streaming: all text embedded in prefill). /// - /// Prefill layout (matching the C# reference): /// ```text - /// [im_start, assistant, \n] โ€” role prefix (text proj only, no codec) - /// [think, think_bos, lang, think_eos, speaker(=pad)] โ€” tts_pad_embed + codec_embed - /// [tts_bos_embed + codec_embed(pad)] โ€” transition marker - /// [text_proj(first_text) + codec_embed(bos)] โ€” first text token enters here + /// [instr_tok ร— M] โ€” user turn (text_proj only, optional) + /// [im_start, assistant, \n] โ€” role prefix (text proj only) + /// [think, think_bos, lang_id, think_eos] โ€” codec prefix (tts_pad + codec_embed) + /// [tts_bos + codec_pad] โ€” transition + /// [text_proj(tok) + codec_pad] ร— N โ€” all text tokens + /// [text_proj(TTS_EOS) + codec_pad] โ€” TTS_EOS + /// [tts_pad + codec_bos] โ€” final /// ``` - /// - /// Trailing text hidden: `[text_proj(tok) for tok in text[1:]] + [text_proj(TTS_EOS)]` - /// These are consumed one per decode step; after exhaustion, tts_pad_embed is used. fn build_prefill_embeds( &self, text_tokens: &[u32], lang_id: i64, - ) -> Result<(Array3, Vec>), TtsError> { + instruction_tokens: Option<&[u32]>, + ) -> Result, TtsError> { + let codec_pad_embed = self + .talker_codec_embedding + .row(CODEC_PAD as usize) + .to_owned(); + let codec_bos_embed = self + .talker_codec_embedding + .row(CODEC_BOS as usize) + .to_owned(); let tts_bos_embed = self.text_project_token(TTS_BOS); + let tts_eos_embed = self.text_project_token(TTS_EOS); - let first_text = if text_tokens.is_empty() { - TTS_PAD - } else { - text_tokens[0] - }; + // VoiceDesign codec prefix โ€” no speaker slot + let codec_prefix = [CODEC_THINK, CODEC_THINK_BOS, lang_id, CODEC_THINK_EOS]; - // Codec prefix: [think, think_bos, lang_id, think_eos, speaker_id(=pad)] - let codec_prefix = [ - CODEC_THINK, - CODEC_THINK_BOS, - lang_id, - CODEC_THINK_EOS, - CODEC_PAD, - ]; - - // 3 role + 5 codec prefix + 1 transition + 1 first text = 10 - let seq_len = 3 + codec_prefix.len() + 2; + let instr_len = instruction_tokens.map_or(0, |t| t.len()); + // M instruction + 3 role + 4 codec_prefix + 1 transition + N text + 1 TTS_EOS + 1 final + let seq_len = instr_len + 3 + codec_prefix.len() + 1 + text_tokens.len() + 1 + 1; let mut embeds = Array3::::zeros((1, seq_len, HIDDEN_DIM)); let mut pos = 0; - // 1. Role prefix: text_project only, no codec component + // 0. Instruction / user-turn tokens: text_proj only (VoiceDesign control) + if let Some(instr_toks) = instruction_tokens { + for &tok in instr_toks { + let embed = self.text_project_token(tok); + embeds.slice_mut(s![0, pos, ..]).assign(&embed); + pos += 1; + } + } + + // 1. Role prefix: text_proj only for &tok in &[IM_START, ASSISTANT, NEWLINE] { let embed = self.text_project_token(tok); embeds.slice_mut(s![0, pos, ..]).assign(&embed); pos += 1; } - // 2. Codec prefix: tts_pad_embed + codec_embed(token) + // 2. Codec prefix: tts_pad + codec_embed(token) for &codec_tok in &codec_prefix { let mut embed = self.tts_pad_embed.clone(); embed += &self.talker_codec_embedding.row(codec_tok as usize); @@ -273,36 +305,39 @@ impl Model { pos += 1; } - // 3. Transition: tts_bos_embed + codec_embed(pad) + // 3. Transition: tts_bos + codec_pad { - let mut embed = tts_bos_embed; - embed += &self.talker_codec_embedding.row(CODEC_PAD as usize); + let embed = &tts_bos_embed + &codec_pad_embed; embeds.slice_mut(s![0, pos, ..]).assign(&embed); pos += 1; } - // 4. First text token + codec_embed(bos) + // 4. All text tokens: text_proj(tok) + codec_pad + for &tok in text_tokens { + let embed = self.text_project_token(tok) + &codec_pad_embed; + embeds.slice_mut(s![0, pos, ..]).assign(&embed); + pos += 1; + } + + // 5. TTS_EOS: text_proj(TTS_EOS) + codec_pad { - let mut embed = self.text_project_token(first_text); - embed += &self.talker_codec_embedding.row(CODEC_BOS as usize); + let embed = tts_eos_embed + &codec_pad_embed; embeds.slice_mut(s![0, pos, ..]).assign(&embed); + pos += 1; } - // Trailing text hidden: remaining text tokens + TTS_EOS - let mut trailing = Vec::new(); - if text_tokens.len() > 1 { - for &tok in &text_tokens[1..] { - trailing.push(self.text_project_token(tok)); - } + // 6. Final: tts_pad + codec_bos + { + let embed = &self.tts_pad_embed + &codec_bos_embed; + embeds.slice_mut(s![0, pos, ..]).assign(&embed); } - trailing.push(self.text_project_token(TTS_EOS)); - Ok((embeds, trailing)) + Ok(embeds) } /// Run talker_prefill.onnx. /// - /// Returns (logits, hidden_states, past_keys, past_values). + /// Returns (last-position logits, hidden_states, past_keys, past_values). fn run_talker_prefill( &self, inputs_embeds: &Array3, @@ -334,13 +369,13 @@ impl Model { ]) .map_err(|e| TtsError::Synthesis(format!("talker prefill failed: {e}")))?; - // Logits: last position (1, 1, 3072) โ†’ flat Vec + // Logits: (1, T, 3072) โ€” extract only the last position let (_, logits_data) = outputs[0] .try_extract_tensor::() .map_err(|e| TtsError::Synthesis(format!("extract logits: {e}")))?; - let logits: Vec = logits_data.to_vec(); + let logits: Vec = logits_data[logits_data.len() - TALKER_VOCAB_SIZE..].to_vec(); - // Hidden states: (1, T, 1024) + // Hidden states: (1, T, 2048) let (_, hidden_data) = outputs[1] .try_extract_tensor::() .map_err(|e| TtsError::Synthesis(format!("extract hidden: {e}")))?; @@ -361,7 +396,6 @@ impl Model { .try_extract_tensor::() .map_err(|e| TtsError::Synthesis(format!("extract val layer {layer}: {e}")))?; - // (1, 8, T, 128) โ†’ insert axis 0 โ†’ (1, 1, 8, T, 128) let key_arr = ndarray::ArrayD::from_shape_vec( vec![1, NUM_KV_HEADS, seq_len, HEAD_DIM], key_data.to_vec(), @@ -401,7 +435,7 @@ impl Model { /// Run talker_decode.onnx for a single step. fn run_talker_decode( &self, - inputs_embeds: &Array3, // (1, 1, 1024) + inputs_embeds: &Array3, // (1, 1, 2048) total_seq: usize, position: i64, past_keys: &Array5, // (28, 1, 8, past_seq, 128) @@ -465,9 +499,12 @@ impl Model { } /// Run the code predictor to fill codebook groups 1-15. + /// + /// The code_predictor.onnx includes the small_to_mtp projection internally, + /// so host code passes 2048-dim embeddings directly. fn run_code_predictor( &self, - hidden_state: &Array3, // (1, 1, 1024) + hidden_state: &Array3, // (1, 1, 2048) codes: &mut [i64; 16], ) -> Result<(), TtsError> { let group0_embed = self @@ -477,14 +514,15 @@ impl Model { .into_shape_with_order((1, 1, HIDDEN_DIM)) .map_err(|e| TtsError::Synthesis(format!("reshape group0 embed: {e}")))?; - // First call: concat(hidden_state, group0_embed) โ†’ (1, 2, 1024) + // First call: concat(hidden_state, group0_embed) โ†’ (1, 2, 2048) let first_input = concatenate(Axis(1), &[hidden_state.view(), group0_embed.view()]) .map_err(|e| TtsError::Synthesis(format!("concat cp input: {e}")))?; // Empty KV cache: (5, 1, 8, 0, 128) - let mut cp_past_keys = Array5::::zeros((CP_NUM_LAYERS, 1, NUM_KV_HEADS, 0, HEAD_DIM)); + let mut cp_past_keys = + Array5::::zeros((CP_NUM_LAYERS, 1, CP_NUM_KV_HEADS, 0, HEAD_DIM)); let mut cp_past_values = - Array5::::zeros((CP_NUM_LAYERS, 1, NUM_KV_HEADS, 0, HEAD_DIM)); + Array5::::zeros((CP_NUM_LAYERS, 1, CP_NUM_KV_HEADS, 0, HEAD_DIM)); let mut cp_input = first_input; let mut session = self.code_predictor.lock().unwrap(); @@ -516,8 +554,8 @@ impl Model { let (_, logits_data) = outputs[0] .try_extract_tensor::() .map_err(|e| TtsError::Synthesis(format!("extract cp logits: {e}")))?; - let vocab_size = 2048; - let last_logits = &logits_data[logits_data.len() - vocab_size..]; + let cp_vocab_size = 2048; + let last_logits = &logits_data[logits_data.len() - cp_vocab_size..]; let token = sampler::sample(last_logits, &CP_SAMPLER, &[], sampler::no_mask) as i64; codes[group_idx + 1] = token; @@ -533,12 +571,12 @@ impl Model { .map_err(|e| TtsError::Synthesis(format!("extract cp values: {e}")))?; cp_past_keys = Array5::from_shape_vec( - (CP_NUM_LAYERS, 1, NUM_KV_HEADS, seq_so_far, HEAD_DIM), + (CP_NUM_LAYERS, 1, CP_NUM_KV_HEADS, seq_so_far, HEAD_DIM), keys_data.to_vec(), ) .map_err(|e| TtsError::Synthesis(format!("reshape cp keys: {e}")))?; cp_past_values = Array5::from_shape_vec( - (CP_NUM_LAYERS, 1, NUM_KV_HEADS, seq_so_far, HEAD_DIM), + (CP_NUM_LAYERS, 1, CP_NUM_KV_HEADS, seq_so_far, HEAD_DIM), values_data.to_vec(), ) .map_err(|e| TtsError::Synthesis(format!("reshape cp values: {e}")))?; @@ -583,9 +621,9 @@ impl Model { // Trim leading silence produced by the think phase. let start = waveform.iter().position(|&s| s.abs() > 0.01).unwrap_or(0); - let trimmed = &waveform[start..]; + let trimmed = waveform[start..].to_vec(); - Ok(AudioFrame::new(trimmed, SAMPLE_RATE).into_owned()) + Ok(AudioFrame::from_vec(trimmed, SAMPLE_RATE)) } } @@ -593,7 +631,27 @@ impl Model { // Helpers // --------------------------------------------------------------------------- -/// SiLU-gated MLP text projection: 2048 โ†’ 1024. +/// Register the requested execution provider on a session builder. +/// +/// CPU is the ORT default โ€” no registration needed. CUDA and CoreML require +/// an ORT build that includes those providers; otherwise ORT will return an error. +fn apply_execution_provider( + builder: ort::session::builder::SessionBuilder, + ep: super::ExecutionProvider, +) -> Result { + use ort::execution_providers::{CUDAExecutionProvider, CoreMLExecutionProvider}; + match ep { + super::ExecutionProvider::Cpu => Ok(builder), + super::ExecutionProvider::Cuda => builder + .with_execution_providers([CUDAExecutionProvider::default().build()]) + .map_err(|e| TtsError::Model(format!("CUDA execution provider error: {e}"))), + super::ExecutionProvider::CoreMl => builder + .with_execution_providers([CoreMLExecutionProvider::default().build()]) + .map_err(|e| TtsError::Model(format!("CoreML execution provider error: {e}"))), + } +} + +/// SiLU-gated MLP text projection: 2048 โ†’ 2048. fn text_project( input: &Array1, fc1_weight: &Array2, diff --git a/crates/wavekat-tts/src/backends/qwen3_tts/sampler.rs b/crates/wavekat-tts/src/backends/qwen3_tts/sampler.rs index 6105b62..e1e5c60 100644 --- a/crates/wavekat-tts/src/backends/qwen3_tts/sampler.rs +++ b/crates/wavekat-tts/src/backends/qwen3_tts/sampler.rs @@ -4,14 +4,15 @@ use rand::Rng; #[derive(Debug, Clone)] pub struct SamplerConfig { pub temperature: f32, - pub top_p: f32, + pub top_k: usize, pub repetition_penalty: f32, } -/// Sample a token index from logits using top-p (nucleus) sampling. +/// Sample a token index from logits using top-k sampling with temperature and +/// optional repetition penalty. /// /// `mask` is called with a token index and returns `true` if the token should -/// be suppressed (set to -inf before softmax). +/// be suppressed (forced to -inf before sampling). pub fn sample( logits: &[f32], config: &SamplerConfig, @@ -48,7 +49,16 @@ pub fn sample( } } - // 4. Softmax + // 4. Top-k filtering: zero out all but the k highest scores + if config.top_k > 0 && config.top_k < scores.len() { + let mut indexed: Vec<(usize, f32)> = scores.iter().copied().enumerate().collect(); + indexed.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); + for &(i, _) in &indexed[config.top_k..] { + scores[i] = f32::NEG_INFINITY; + } + } + + // 5. Softmax let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max); let mut probs: Vec = scores.iter().map(|&s| (s - max).exp()).collect(); let sum: f32 = probs.iter().sum(); @@ -58,37 +68,24 @@ pub fn sample( } } - // 5. Top-p filtering - let mut indexed: Vec<(usize, f32)> = probs.iter().copied().enumerate().collect(); - indexed.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); - - let mut cumulative = 0.0; - let mut cutoff = indexed.len(); - for (i, &(_, p)) in indexed.iter().enumerate() { - cumulative += p; - if cumulative >= config.top_p { - cutoff = i + 1; - break; - } - } - let candidates = &indexed[..cutoff]; - - // Renormalize - let cand_sum: f32 = candidates.iter().map(|&(_, p)| p).sum(); - // 6. Sample let mut rng = rand::rng(); - let r: f32 = rng.random::() * cand_sum; + let r: f32 = rng.random::(); let mut accum = 0.0; - for &(idx, p) in candidates { + for (i, &p) in probs.iter().enumerate() { accum += p; if accum >= r { - return idx; + return i; } } - // Fallback: return the highest-probability token - candidates[0].0 + // Fallback: highest-probability token + probs + .iter() + .enumerate() + .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal)) + .map(|(i, _)| i) + .unwrap_or(0) } /// Logit mask for the Talker LM (group 0). @@ -127,7 +124,7 @@ mod tests { let logits = vec![0.0; 100]; let config = SamplerConfig { temperature: 1.0, - top_p: 0.9, + top_k: 50, repetition_penalty: 1.0, }; let idx = sample(&logits, &config, &[], no_mask); @@ -136,11 +133,10 @@ mod tests { #[test] fn sample_respects_mask() { - // All logits equal, but mask everything except token 5 let logits = vec![0.0; 10]; let config = SamplerConfig { temperature: 1.0, - top_p: 1.0, + top_k: 0, repetition_penalty: 1.0, }; let idx = sample(&logits, &config, &[], |i| i != 5); diff --git a/crates/wavekat-tts/src/backends/qwen3_tts/tokenizer.rs b/crates/wavekat-tts/src/backends/qwen3_tts/tokenizer.rs index 6f5b703..7f8dc32 100644 --- a/crates/wavekat-tts/src/backends/qwen3_tts/tokenizer.rs +++ b/crates/wavekat-tts/src/backends/qwen3_tts/tokenizer.rs @@ -4,6 +4,7 @@ use crate::TtsError; // Special text token IDs (Qwen3 BPE vocab) pub const IM_START: u32 = 151644; +pub const IM_END: u32 = 151645; pub const ASSISTANT: u32 = 77091; pub const NEWLINE: u32 = 198; pub const TTS_BOS: u32 = 151672; @@ -33,10 +34,10 @@ pub struct Tokenizer { } impl Tokenizer { - /// Build tokenizer from `vocab.json` + `merges.txt` in `model_dir`. + /// Build tokenizer from `tokenizer/vocab.json` + `tokenizer/merges.txt` in `model_dir`. pub fn new(model_dir: &Path) -> Result { - let vocab_path = model_dir.join("vocab.json"); - let merges_path = model_dir.join("merges.txt"); + let vocab_path = model_dir.join("tokenizer").join("vocab.json"); + let merges_path = model_dir.join("tokenizer").join("merges.txt"); let bpe = tokenizers::models::bpe::BPE::from_file( &vocab_path.to_string_lossy(), diff --git a/crates/wavekat-tts/src/lib.rs b/crates/wavekat-tts/src/lib.rs index 3bec35b..eed98de 100644 --- a/crates/wavekat-tts/src/lib.rs +++ b/crates/wavekat-tts/src/lib.rs @@ -22,10 +22,10 @@ //! //! # Feature flags //! -//! | Feature | Backend | Chinese | Requires | -//! |---------|---------|---------|----------| -//! | `qwen3-tts` | Qwen3-TTS (ONNX) | Excellent | ONNX model download | -//! | `cosyvoice` | CosyVoice (ONNX) | Excellent | ONNX model download | +//! | Feature | Backend | Multilingual | Requires | +//! |---------|---------|-------------|----------| +//! | `qwen3-tts` | Qwen3-TTS (ONNX) | 10 languages | ONNX model download | +//! | `cosyvoice` | CosyVoice (ONNX) | Yes | ONNX model download | //! //! # Quick start //! diff --git a/crates/wavekat-tts/src/types.rs b/crates/wavekat-tts/src/types.rs index 7685562..847d16f 100644 --- a/crates/wavekat-tts/src/types.rs +++ b/crates/wavekat-tts/src/types.rs @@ -1,7 +1,8 @@ /// A TTS synthesis request. /// /// Backend-agnostic parameters that describe what to synthesize. -/// Each backend interprets `voice` and `language` according to its own catalog. +/// Each backend interprets `voice`, `instruction`, and `language` according to +/// its own capabilities; unsupported fields are silently ignored. #[derive(Debug, Clone)] pub struct SynthesizeRequest<'a> { /// Text to synthesize. @@ -9,14 +10,30 @@ pub struct SynthesizeRequest<'a> { /// Voice identifier (backend-specific). /// - /// For Edge-TTS: `"zh-CN-XiaoxiaoNeural"`, `"zh-CN-YunxiNeural"`, etc. - /// For Kokoro: `"af_heart"`, `"zf_xiaobei"`, etc. + /// Used by backends with a fixed speaker catalog: + /// - Edge-TTS: `"zh-CN-XiaoxiaoNeural"`, `"zh-CN-YunxiNeural"`, โ€ฆ + /// - Kokoro: `"af_heart"`, `"zf_xiaobei"`, โ€ฆ + /// /// `None` uses the backend's default voice. pub voice: Option<&'a str>, + /// Free-form voice instruction / style prompt. + /// + /// Used by instruction-following backends (e.g. Qwen3-TTS VoiceDesign). + /// The text describes how the model should speak: + /// + /// ```text + /// "Speak in a calm, professional tone." + /// "Narrate with warmth and a gentle pace." + /// "Respond with high energy and enthusiasm!" + /// ``` + /// + /// `None` lets the backend use its default voice character. + pub instruction: Option<&'a str>, + /// Language / locale code. /// - /// E.g. `"zh-CN"`, `"en-US"`, `"ja-JP"`. + /// E.g. `"zh"`, `"en"`, `"ja"`. /// `None` uses the backend's default or auto-detects. pub language: Option<&'a str>, @@ -33,17 +50,24 @@ impl<'a> SynthesizeRequest<'a> { Self { text, voice: None, + instruction: None, language: None, speed: None, } } - /// Set the voice. + /// Set the voice identifier. pub fn with_voice(mut self, voice: &'a str) -> Self { self.voice = Some(voice); self } + /// Set the voice instruction / style prompt. + pub fn with_instruction(mut self, instruction: &'a str) -> Self { + self.instruction = Some(instruction); + self + } + /// Set the language. pub fn with_language(mut self, language: &'a str) -> Self { self.language = Some(language); diff --git a/docs/02-proposed-core-wav-io.md b/docs/02-proposed-core-wav-io.md new file mode 100644 index 0000000..c908f09 --- /dev/null +++ b/docs/02-proposed-core-wav-io.md @@ -0,0 +1,126 @@ +# Proposed addition to wavekat-core: WAV I/O + +## Problem + +Every crate in the WaveKat ecosystem that reads or writes WAV files (wavekat-vad, +wavekat-turn, wavekat-tts) independently reaches for `hound` and repeats the same +boilerplate. Any spec change (e.g. bits-per-sample, channel count) must be updated +in multiple places. + +```rust +// Current: every crate does this manually +let spec = hound::WavSpec { + channels: 1, + sample_rate: audio.sample_rate(), + bits_per_sample: 32, + sample_format: hound::SampleFormat::Float, +}; +let mut writer = hound::WavWriter::create(path, spec)?; +for &sample in audio.samples() { + writer.write_sample(sample)?; +} +writer.finalize()?; +``` + +## Proposed change + +Add a `hound` feature flag to wavekat-core that extends `AudioFrame` with two +methods: + +```toml +# wavekat-core Cargo.toml +[features] +hound = ["dep:hound"] + +[dependencies] +hound = { version = "3.5", optional = true } +``` + +```rust +// In audio.rs, behind #[cfg(feature = "hound")]: + +impl AudioFrame<'_> { + /// Write this frame to a WAV file at `path`. + /// + /// Always writes mono f32 PCM at the frame's native sample rate. + /// + /// # Example + /// + /// ```no_run + /// use wavekat_core::AudioFrame; + /// + /// let frame = AudioFrame::from_vec(vec![0.0f32; 16000], 16000); + /// frame.write_wav("output.wav").unwrap(); + /// ``` + #[cfg(feature = "hound")] + pub fn write_wav(&self, path: impl AsRef) -> Result<(), hound::Error> { + let spec = hound::WavSpec { + channels: 1, + sample_rate: self.sample_rate, + bits_per_sample: 32, + sample_format: hound::SampleFormat::Float, + }; + let mut writer = hound::WavWriter::create(path, spec)?; + for &sample in self.samples() { + writer.write_sample(sample)?; + } + writer.finalize() + } +} + +impl AudioFrame<'static> { + /// Read a mono WAV file and return an owned `AudioFrame`. + /// + /// Accepts both f32 and i16 WAV files. i16 samples are normalised to + /// `[-1.0, 1.0]` (divided by 32768). + /// + /// # Example + /// + /// ```no_run + /// use wavekat_core::AudioFrame; + /// + /// let frame = AudioFrame::from_wav("input.wav").unwrap(); + /// println!("{} Hz, {} samples", frame.sample_rate(), frame.len()); + /// ``` + #[cfg(feature = "hound")] + pub fn from_wav(path: impl AsRef) -> Result { + let mut reader = hound::WavReader::open(path)?; + let spec = reader.spec(); + let sample_rate = spec.sample_rate; + let samples: Vec = match spec.sample_format { + hound::SampleFormat::Float => { + reader.samples::().map(|s| s.unwrap()).collect() + } + hound::SampleFormat::Int => { + reader.samples::().map(|s| s.unwrap() as f32 / 32768.0).collect() + } + }; + Ok(AudioFrame::from_vec(samples, sample_rate)) + } +} +``` + +## Impact + +- **Zero breaking changes** โ€” purely additive, opt-in via feature flag +- Consumers opt in: `wavekat-core = { version = "0.0.5", features = ["hound"] }` +- All examples and tests across wavekat-vad, wavekat-turn, wavekat-tts can drop + their own `hound` boilerplate and use the canonical implementation +- One place to maintain the WAV spec (mono, f32, native sample rate) + +## Tests + +```rust +#[cfg(feature = "hound")] +#[test] +fn wav_round_trip() { + let original = AudioFrame::from_vec(vec![0.5f32, -0.5, 0.0, 1.0], 16000); + let path = std::env::temp_dir().join("wavekat_test.wav"); + original.write_wav(&path).unwrap(); + let loaded = AudioFrame::from_wav(&path).unwrap(); + assert_eq!(loaded.sample_rate(), 16000); + for (a, b) in original.samples().iter().zip(loaded.samples()) { + assert!((a - b).abs() < 1e-6, "sample mismatch: {a} vs {b}"); + } +} +``` diff --git a/docs/04-qwen3-tts-1.7b-migration.md b/docs/04-qwen3-tts-1.7b-migration.md new file mode 100644 index 0000000..2337083 --- /dev/null +++ b/docs/04-qwen3-tts-1.7b-migration.md @@ -0,0 +1,185 @@ +# Qwen3-TTS: 1.7B VoiceDesign Migration + +Documents the changes made in `feat/qwen3-tts-1.7b-int4` to migrate the +Qwen3-TTS backend from the third-party 0.6B ONNX repo to the WaveKat-owned +1.7B VoiceDesign ONNX repo with INT4 quantization by default. + +## What changed and why + +| | Before | After | +|---|---|---| +| **Model** | Qwen3-TTS-12Hz-0.6B-Base | Qwen3-TTS-12Hz-1.7B-VoiceDesign | +| **HF repo** | `elbruno/Qwen3-TTS-12Hz-0.6B-Base-ONNX` | `wavekat/Qwen3-TTS-1.7B-VoiceDesign-ONNX` | +| **ONNX variant** | FP32 | INT4 weight-only (RTN, block=128) | +| **Download** | Custom HTTP + manual cache | `hf-hub` crate (standard HF cache) | +| **Prefill mode** | Streaming (text fed per decode step) | Non-streaming (all text in prefill) | +| **Sampling** | top-p | top-k | + +The 1.7B VoiceDesign model produces significantly higher quality audio, particularly +for longer utterances and non-English languages. The INT4 models are ~4ร— smaller than +FP32 (talker: 5.3 GB โ†’ 1.4 GB) with negligible quality loss. + +## HF repo structure + +`wavekat/Qwen3-TTS-1.7B-VoiceDesign-ONNX` mirrors the export output from +`tools/qwen3-tts-onnx/`. The Rust backend downloads and uses only the INT4 variant: + +``` +{hf_snapshot_root}/ +โ”œโ”€โ”€ config.json # model dimensions, token IDs, sampling config +โ”œโ”€โ”€ int4/ # INT4 weight-only quantized ONNX models +โ”‚ โ”œโ”€โ”€ talker_prefill.onnx + .onnx.data (~1.4 GB) +โ”‚ โ”œโ”€โ”€ talker_decode.onnx + .onnx.data (~1.4 GB) +โ”‚ โ”œโ”€โ”€ code_predictor.onnx + .onnx.data (~322 MB) +โ”‚ โ””โ”€โ”€ vocoder.onnx + .onnx.data (~558 MB) +โ”œโ”€โ”€ embeddings/ # pre-extracted weights as .npy +โ”‚ โ”œโ”€โ”€ text_embedding.npy +โ”‚ โ”œโ”€โ”€ text_projection_fc1_{weight,bias}.npy +โ”‚ โ”œโ”€โ”€ text_projection_fc2_{weight,bias}.npy +โ”‚ โ”œโ”€โ”€ talker_codec_embedding.npy +โ”‚ โ””โ”€โ”€ cp_codec_embedding_{0..14}.npy +โ””โ”€โ”€ tokenizer/ + โ”œโ”€โ”€ vocab.json + โ””โ”€โ”€ merges.txt +``` + +`embeddings/small_to_mtp_projection_{weight,bias}.npy` also exist in the repo +but are not downloaded by the Rust backend โ€” the `small_to_mtp` projection is +baked into `code_predictor.onnx`. + +## Download: hf-hub + +The `ureq` dependency and custom HTTP download logic were replaced with +`hf-hub = "0.5"` (the official Rust HF Hub client). + +`download::ensure_model_dir()` calls `repo.get(filename)` for each file in +the hardcoded `MODEL_FILES` list. HF Hub handles caching, LFS resolution, and +redirect following transparently. hf-hub 0.5 fixes the relative-`Location` +redirect handling that broke 0.3 with HuggingFace's `/api/resolve-cache/` backend. + +**Cache location**: `$HF_HOME/hub/` (default `~/.cache/huggingface/hub/`). +The function returns the snapshot root directory, which has the repo's subdirectory +layout intact. + +**Environment variables**: +| Variable | Purpose | +|---|---| +| `WAVEKAT_MODEL_DIR` | Skip HF Hub; load from this local path directly | +| `HF_TOKEN` | Authentication for private/gated repos | +| `HF_HOME` | Override cache root (also sets the token file location) | +| `HF_ENDPOINT` | Override the HuggingFace endpoint URL | + +> **Note on `HF_TOKEN`**: hf-hub 0.5 does not natively read `HF_TOKEN` +> from the environment (it reads `$HF_HOME/token`, written by +> `huggingface-cli login`). `ensure_model_dir()` bridges this by passing +> `HF_TOKEN` to `ApiBuilder::with_token` when the env var is set. + +## Model dimensions (1.7B vs 0.6B) + +| Parameter | 0.6B | 1.7B | +|---|---|---| +| `HIDDEN_DIM` | 1024 | **2048** | +| `NUM_LAYERS` | 28 | 28 | +| `NUM_KV_HEADS` | 8 | 8 | +| `HEAD_DIM` | 128 | 128 | +| `CP_NUM_LAYERS` | 5 | 5 | +| `CP_NUM_KV_HEADS` | 8 | 8 | +| `MAX_NEW_TOKENS` | 2048 | **8192** | + +Only `HIDDEN_DIM` changes (1024 โ†’ 2048). All KV cache shapes, the code predictor +architecture, and the vocoder are identical between the two sizes. + +## Sampling (top-k replacing top-p) + +`SamplerConfig` now has `top_k: usize` instead of `top_p: f32`. +Values from `config.json`: + +| | Talker (group 0) | Code Predictor (groups 1-15) | +|---|---|---| +| `temperature` | 0.9 | 0.9 | +| `top_k` | 50 | 50 | +| `repetition_penalty` | 1.05 | 1.0 | + +The previous 0.6B values (temp=0.7, top_p=0.8 for talker; temp=0.2, top_p=0.5 +for CP) were derived from an older reference and did not match the official +`config.json`. + +Additionally, `min_new_tokens=2` is now enforced: CODEC_EOS is masked for the +first two decode steps, matching `generate_onnx.py`. + +## Prefill: non-streaming mode + +The 0.6B implementation used a **streaming** prefill: only the first text token +was included in the prefill; the rest were fed one-per-decode-step as a "trailing" +vector added to the codec embedding sum. + +The 1.7B implementation uses **non-streaming** prefill to match `generate_onnx.py`. +All text tokens are embedded in the prefill sequence: + +``` +[im_start, assistant, \n] โ€” role (text_proj only) +[think, think_bos, lang_id, think_eos] โ€” codec prefix (tts_pad + codec_embed) +[tts_bos + codec_pad] โ€” transition +[text_proj(tok) + codec_pad] ร— N โ€” all text tokens +[text_proj(TTS_EOS) + codec_pad] โ€” TTS_EOS +[tts_pad + codec_bos] โ€” final +``` + +The decode loop trailing is always `tts_pad_embed` (a constant) rather than +per-step text projections. + +Non-streaming gives the Talker LM full visibility over the complete text from the +first token, which improves prosody and accuracy on longer inputs. + +**Note**: the 0.6B codec prefix had 5 tokens (included a speaker slot = CODEC_PAD). +VoiceDesign has no predefined speakers, so the codec prefix is 4 tokens. + +## Bug fix: prefill logits extraction + +`run_talker_prefill` previously returned `logits_data.to_vec()` which includes +logits for **all** T prefill positions โ€” a flat vector of T ร— 3072 elements. +The sampler then treated this as a 3072-token vocab, producing wrong samples. + +Fixed by slicing only the last position: + +```rust +let logits: Vec = logits_data[logits_data.len() - TALKER_VOCAB_SIZE..].to_vec(); +``` + +The decode step was unaffected (it always returns shape `(1, 1, 3072)` = 3072 elements). + +## code_predictor.onnx and small_to_mtp + +For the 1.7B model, the Talker hidden size is 2048 but the Code Predictor +transformer runs at 1024. The `small_to_mtp_projection` (Linear 2048 โ†’ 1024) +is **baked into `code_predictor.onnx`** rather than applied in host code. + +As a result: +- Host code passes `(1, 2, 2048)` to the code predictor on the first call + (concat of talker hidden and group-0 codec embedding, both 2048-dim) +- Subsequent calls pass `(1, 1, 2048)` from `cp_codec_embeddings[g]` (shape 2048ร—2048) +- No host-side projection step is needed + +`small_to_mtp_projection_{weight,bias}.npy` are present in the HF repo for +reference but are not downloaded or used by the Rust backend. + +## Usage + +```bash +# Auto-download via HF Hub and synthesize +cargo run --example synthesize --features qwen3-tts,hound -- "Hello, world!" + +# Interactive mode +cargo run --example synthesize --features qwen3-tts,hound -- -i + +# Load from manually downloaded snapshot +WAVEKAT_MODEL_DIR=/path/to/snapshot \ + cargo run --example synthesize --features qwen3-tts,hound -- "Hello" + +# With a VoiceDesign instruction +cargo run --example synthesize --features qwen3-tts,hound -- \ + --instruction "Speak in a calm, professional tone." "Hello, world!" +``` + +The first run downloads ~3.7 GB of INT4 ONNX files and ~0.6 GB of embeddings. +Subsequent runs load directly from the HF Hub cache. diff --git a/tools/qwen3-tts-onnx/README.md b/tools/qwen3-tts-onnx/README.md index f6568c7..de5d9fc 100644 --- a/tools/qwen3-tts-onnx/README.md +++ b/tools/qwen3-tts-onnx/README.md @@ -51,11 +51,11 @@ python generate_onnx.py --variant int4 \ --instruct "Speak in a warm and friendly female voice" \ -o output_int4.wav -# Chinese -python generate_onnx.py --variant int4 --lang chinese \ - --text "่ฎฉๆฏไธ€ๅฎถๅฐไผไธš๏ผŒ้ƒฝๆ‹ฅๆœ‰ๅคงไผไธš็š„ๅฃฐ้Ÿณใ€‚" \ - --instruct "Speak in a warm and professional female voice" \ - -o output_zh.wav +# With a custom style instruction +python generate_onnx.py --variant int4 \ + --text "Give every small business the voice of a big one." \ + --instruct "Speak with confidence and warmth" \ + -o output_styled.wav ``` ## Model Architecture