From d1d94ec29edf462f51281545488af86db2d242e5 Mon Sep 17 00:00:00 2001 From: Rahul D Shetty <35rahuldshetty@gmail.com> Date: Sat, 28 Feb 2026 21:55:36 +0530 Subject: [PATCH 1/2] feat: Rust implementation of kittentts-rs Signed-off-by: Rahul D Shetty <35rahuldshetty@gmail.com> --- .gitignore | 5 +- .gitmodules | 3 + kittentts-rs/.gitignore | 7 + kittentts-rs/Cargo.toml | 20 + kittentts-rs/README.md | 43 + kittentts-rs/crates/espeak-ng-sys/Cargo.toml | 21 + kittentts-rs/crates/espeak-ng-sys/build.rs | 68 ++ kittentts-rs/crates/espeak-ng-sys/espeak-ng | 1 + kittentts-rs/crates/espeak-ng-sys/src/lib.rs | 5 + kittentts-rs/crates/espeak-ng-sys/wrapper.h | 2 + kittentts-rs/src/config.rs | 31 + kittentts-rs/src/espeak.rs | 65 ++ kittentts-rs/src/lib.rs | 15 + kittentts-rs/src/main.rs | 75 ++ kittentts-rs/src/model.rs | 362 +++++++ kittentts-rs/src/preprocess.rs | 947 +++++++++++++++++++ kittentts-rs/src/text_cleaner.rs | 53 ++ 17 files changed, 1722 insertions(+), 1 deletion(-) create mode 100644 .gitmodules create mode 100644 kittentts-rs/.gitignore create mode 100644 kittentts-rs/Cargo.toml create mode 100644 kittentts-rs/README.md create mode 100644 kittentts-rs/crates/espeak-ng-sys/Cargo.toml create mode 100644 kittentts-rs/crates/espeak-ng-sys/build.rs create mode 160000 kittentts-rs/crates/espeak-ng-sys/espeak-ng create mode 100644 kittentts-rs/crates/espeak-ng-sys/src/lib.rs create mode 100644 kittentts-rs/crates/espeak-ng-sys/wrapper.h create mode 100644 kittentts-rs/src/config.rs create mode 100644 kittentts-rs/src/espeak.rs create mode 100644 kittentts-rs/src/lib.rs create mode 100644 kittentts-rs/src/main.rs create mode 100644 kittentts-rs/src/model.rs create mode 100644 kittentts-rs/src/preprocess.rs create mode 100644 kittentts-rs/src/text_cleaner.rs diff --git a/.gitignore b/.gitignore index 15631bd..8eed2e6 100644 --- a/.gitignore +++ b/.gitignore @@ -241,4 +241,7 @@ marimo/_lsp/ __marimo__/ # Streamlit -.streamlit/secrets.toml \ No newline at end of file 
+.streamlit/secrets.toml + + +kitten-tts-mini-0.8/ \ No newline at end of file diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..8a3807b --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "kittentts-rs/crates/espeak-ng-sys/espeak-ng"] + path = kittentts-rs/crates/espeak-ng-sys/espeak-ng + url = https://github.com/espeak-ng/espeak-ng diff --git a/kittentts-rs/.gitignore b/kittentts-rs/.gitignore new file mode 100644 index 0000000..43dcd0e --- /dev/null +++ b/kittentts-rs/.gitignore @@ -0,0 +1,7 @@ +/target/ +Cargo.lock +espeak-ng.dll +*.wav +*.onnx +*.npz +.DS_Store diff --git a/kittentts-rs/Cargo.toml b/kittentts-rs/Cargo.toml new file mode 100644 index 0000000..45740d4 --- /dev/null +++ b/kittentts-rs/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "kittentts-rs" +version = "0.1.0" +edition = "2021" +description = "Ultra-lightweight text-to-speech inference in Rust — port of KittenTTS" +license = "MIT" + +[dependencies] +ort = { version = "2.0.0-rc.11", features = ["ndarray"] } +ndarray = "0.16" +ndarray-npy = { version = "0.9", features = ["npz"] } +hound = "3.5" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +regex = "1.10" +clap = { version = "4.4", features = ["derive"] } +anyhow = "1.0" +once_cell = "1.19" +fancy-regex = "0.13" +espeak-ng-sys = { path = "crates/espeak-ng-sys" } diff --git a/kittentts-rs/README.md b/kittentts-rs/README.md new file mode 100644 index 0000000..baa1039 --- /dev/null +++ b/kittentts-rs/README.md @@ -0,0 +1,43 @@ +# KittenTTS-RS + +Ultra-lightweight text-to-speech inference in Rust using ONNX Runtime and eSpeak-NG. + +This is a Rust port of the KittenTTS project. + +## Features + +- **Fast & Efficient**: Low-latency synthesis via ONNX Runtime (`ort`). +- **High Quality**: Accurate pronunciation using IPA phonemes via eSpeak-NG. +- **Self-Contained**: eSpeak-NG is built-in as a submodule—no system installation required. 
+- **Configurable**: Multiple voices, speed control, and text preprocessing. + +## Prerequisites + +- **Rust**: [Install Rust](https://www.rust-lang.org/tools/install) +- **ONNX Model**: A directory containing `config.json`, `.onnx` model, and `voices.npz`. + +## Running Locally + +```bash +# download model weights (onnx) +git clone https://huggingface.co/KittenML/kitten-tts-mini-0.8 + +cargo run -- \ + --text "Hello, I am KittenTTS-RS!" \ + --model-dir "../kitten-tts-mini-0.8" \ + --voice "Luna" \ + --output "output.wav" +``` + +### Options + +- `--text`: Text to synthesize. +- `--model-dir`: Directory containing ONNX model files. +- `--voice`: Voice name (e.g., "Luna", "Leo"). +- `--speed`: Speech speed (default: 1.0). +- `--output`: Output file path (default: `output.wav`). +- `--no-clean`: Skip text preprocessing. + +--- + +*Coded with ❤️ by Antigravity.* diff --git a/kittentts-rs/crates/espeak-ng-sys/Cargo.toml b/kittentts-rs/crates/espeak-ng-sys/Cargo.toml new file mode 100644 index 0000000..cbb97f9 --- /dev/null +++ b/kittentts-rs/crates/espeak-ng-sys/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "espeak-ng-sys" +version = "0.1.0" +edition = "2021" +links = "espeak-ng" +include = [ + "src/**/*", + "espeak-ng/espeak-ng-data/**/*", + "espeak-ng/dictsource/**/*", + "build.rs", + "wrapper.h", +] + +[dependencies] +# Piper uses these for path resolution +once_cell = "1.19" + +[build-dependencies] +cmake = "0.1" +bindgen = "0.70" +glob = "0.3" diff --git a/kittentts-rs/crates/espeak-ng-sys/build.rs b/kittentts-rs/crates/espeak-ng-sys/build.rs new file mode 100644 index 0000000..cf65838 --- /dev/null +++ b/kittentts-rs/crates/espeak-ng-sys/build.rs @@ -0,0 +1,68 @@ +use std::env; +use std::path::PathBuf; + +fn main() { + let dst = cmake::Config::new("espeak-ng") + .define("BUILD_SHARED_LIBS", "OFF") + .define("USE_LIBPCAUDIO", "OFF") + .define("USE_KLATT", "OFF") + .define("USE_MBROLA", "OFF") + .define("USE_ASYNC", "OFF") + .define("ENABLE_TESTS", "OFF") + 
.build(); + + let profile = env::var("PROFILE").unwrap_or_else(|_| "debug".to_string()); + let cmake_config = if profile == "release" { + "Release" + } else { + "Debug" + }; + + // espeak-ng builds sub-libraries in specific subdirectories + let mut search_paths = vec![ + dst.join("lib"), + dst.join("build/src/ucd-tools"), + dst.join("build/src/speechPlayer"), + ]; + + if cfg!(target_os = "windows") { + // MSVC adds a config-specific subdirectory + search_paths.push(dst.join("build/src/ucd-tools").join(cmake_config)); + search_paths.push(dst.join("build/src/speechPlayer").join(cmake_config)); + // Some versions might put espeak-ng.lib in lib/Debug too + search_paths.push(dst.join("lib").join(cmake_config)); + } + + for path in &search_paths { + if path.exists() { + println!("cargo:rustc-link-search=native={}", path.display()); + } + } + + println!("cargo:rustc-link-lib=static=espeak-ng"); + println!("cargo:rustc-link-lib=static=ucd"); + // Piper links speechPlayer too + println!("cargo:rustc-link-lib=static=speechPlayer"); + + if cfg!(target_os = "windows") { + println!("cargo:rustc-link-lib=dylib=user32"); + println!("cargo:rustc-link-lib=dylib=shell32"); + } else if cfg!(target_os = "linux") { + println!("cargo:rustc-link-lib=dylib=stdc++"); + } else if cfg!(target_os = "macos") { + println!("cargo:rustc-link-lib=framework=Foundation"); + println!("cargo:rustc-link-lib=dylib=c++"); + } + + let bindings = bindgen::Builder::default() + .header("wrapper.h") + .clang_arg("-Iespeak-ng/src/include") + .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) + .generate() + .expect("Unable to generate bindings"); + + let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); + bindings + .write_to_file(out_path.join("bindings.rs")) + .expect("Couldn't write bindings!"); +} diff --git a/kittentts-rs/crates/espeak-ng-sys/espeak-ng b/kittentts-rs/crates/espeak-ng-sys/espeak-ng new file mode 160000 index 0000000..c204e61 --- /dev/null +++ 
b/kittentts-rs/crates/espeak-ng-sys/espeak-ng @@ -0,0 +1 @@ +Subproject commit c204e6183497f4b6a7f30c6aac69deb86d6f341c diff --git a/kittentts-rs/crates/espeak-ng-sys/src/lib.rs b/kittentts-rs/crates/espeak-ng-sys/src/lib.rs new file mode 100644 index 0000000..024d8f4 --- /dev/null +++ b/kittentts-rs/crates/espeak-ng-sys/src/lib.rs @@ -0,0 +1,5 @@ +#![allow(non_upper_case_globals)] +#![allow(non_camel_case_types)] +#![allow(non_snake_case)] + +include!(concat!(env!("OUT_DIR"), "/bindings.rs")); diff --git a/kittentts-rs/crates/espeak-ng-sys/wrapper.h b/kittentts-rs/crates/espeak-ng-sys/wrapper.h new file mode 100644 index 0000000..dc9a7b8 --- /dev/null +++ b/kittentts-rs/crates/espeak-ng-sys/wrapper.h @@ -0,0 +1,2 @@ +#include "espeak-ng/speak_lib.h" +#include "espeak-ng/espeak_ng.h" diff --git a/kittentts-rs/src/config.rs b/kittentts-rs/src/config.rs new file mode 100644 index 0000000..6d6dd89 --- /dev/null +++ b/kittentts-rs/src/config.rs @@ -0,0 +1,31 @@ +use serde::Deserialize; +use std::collections::HashMap; +use std::path::Path; + +/// Model configuration matching the JSON config.json format. +#[derive(Debug, Deserialize)] +pub struct ModelConfig { + pub name: String, + pub version: String, + #[serde(rename = "type")] + pub model_type: String, + pub model: String, + pub voices: String, + pub model_file: String, + #[serde(default)] + pub speed_priors: HashMap, + #[serde(default)] + pub voice_aliases: HashMap, +} + +impl ModelConfig { + /// Load configuration from a config.json file. 
+ pub fn load(path: &Path) -> anyhow::Result { + let data = std::fs::read_to_string(path)?; + let config: ModelConfig = serde_json::from_str(&data)?; + if config.model_type != "ONNX1" && config.model_type != "ONNX2" { + anyhow::bail!("Unsupported model type: {}", config.model_type); + } + Ok(config) + } +} diff --git a/kittentts-rs/src/espeak.rs b/kittentts-rs/src/espeak.rs new file mode 100644 index 0000000..4ad0e0d --- /dev/null +++ b/kittentts-rs/src/espeak.rs @@ -0,0 +1,65 @@ +use espeak_ng_sys as espeak_ng; +use std::ffi::{CStr, CString}; +use std::os::raw::c_void; +use std::ptr; + +pub struct Espeak { + _private: (), +} + +impl Espeak { + pub fn new(data_path: Option<&str>) -> anyhow::Result { + let c_path = data_path.map(|s| CString::new(s).unwrap()); + + let path_ptr = c_path.as_ref().map(|s| s.as_ptr()).unwrap_or(ptr::null()); + + // output = Retrieval, buflength = 0, options = 0 + let sample_rate = unsafe { + espeak_ng::espeak_Initialize( + espeak_ng::espeak_AUDIO_OUTPUT_AUDIO_OUTPUT_RETRIEVAL, + 0, + path_ptr, + 0, + ) + }; + + if sample_rate <= 0 { + anyhow::bail!("Failed to initialize eSpeak-NG (returned {})", sample_rate); + } + + Ok(Espeak { _private: () }) + } + + pub fn set_voice(&self, name: &str) -> anyhow::Result<()> { + let c_name = CString::new(name)?; + let result = unsafe { espeak_ng::espeak_SetVoiceByName(c_name.as_ptr()) }; + if result == espeak_ng::espeak_ERROR_EE_OK { + Ok(()) + } else { + anyhow::bail!("Failed to set eSpeak voice to {}", name) + } + } + + pub fn text_to_phonemes(&self, text: &str) -> anyhow::Result { + let c_text = CString::new(text)?; + let mut text_ptr = c_text.as_ptr() as *const c_void; + + // textmode 1 = UTF8, phonememode 2 = IPA phonemes + let result_ptr = unsafe { espeak_ng::espeak_TextToPhonemes(&mut text_ptr, 1, 2) }; + + if result_ptr.is_null() { + anyhow::bail!("eSpeak failed to generate phonemes"); + } + + let c_str = unsafe { CStr::from_ptr(result_ptr) }; + Ok(c_str.to_string_lossy().into_owned()) + } +} + 
+impl Drop for Espeak { + fn drop(&mut self) { + unsafe { + espeak_ng::espeak_Terminate(); + } + } +} diff --git a/kittentts-rs/src/lib.rs b/kittentts-rs/src/lib.rs new file mode 100644 index 0000000..efbeb1e --- /dev/null +++ b/kittentts-rs/src/lib.rs @@ -0,0 +1,15 @@ +//! # kittentts-rs +//! +//! Ultra-lightweight text-to-speech inference in Rust. +//! +//! A Rust port of the [KittenTTS](https://github.com/kittenml/kittentts) Python +//! package, using `ort` for ONNX Runtime inference and `espeak-ng-sys` for phonemisation. + +pub mod config; +pub mod espeak; +pub mod model; +pub mod preprocess; +pub mod text_cleaner; + +// Re-export the main model type for convenience. +pub use model::KittenTTSModel; diff --git a/kittentts-rs/src/main.rs b/kittentts-rs/src/main.rs new file mode 100644 index 0000000..9089caa --- /dev/null +++ b/kittentts-rs/src/main.rs @@ -0,0 +1,75 @@ +use anyhow::Result; +use clap::Parser; +use std::path::PathBuf; + +use kittentts_rs::KittenTTSModel; + +/// kittentts-rs — Ultra-lightweight text-to-speech inference in Rust. +#[derive(Parser, Debug)] +#[command(name = "kittentts-rs", version, about)] +struct Cli { + /// Path to the model directory (containing config.json, ONNX model, and voices.npz) + #[arg(long)] + model_dir: PathBuf, + + /// Path to eSpeak-NG data directory (optional) + #[arg(long)] + espeak_data: Option, + + /// Text to synthesize + #[arg(long)] + text: String, + + /// Voice to use (e.g. 
"Leo", "Bella", "Bruno", or internal names like "expr-voice-5-m") + #[arg(long, default_value = "Leo")] + voice: String, + + /// Speech speed (1.0 = normal) + #[arg(long, default_value_t = 1.0)] + speed: f32, + + /// Output WAV file path + #[arg(long, default_value = "output.wav")] + output: PathBuf, + + /// Audio sample rate in Hz + #[arg(long, default_value_t = 24000)] + sample_rate: u32, + + /// Disable text preprocessing / cleaning + #[arg(long, default_value_t = false)] + no_clean: bool, +} + +fn main() -> Result<()> { + let cli = Cli::parse(); + + let espeak_path = cli + .espeak_data + .map(|p| p.to_string_lossy().into_owned()) + .or_else(|| { + let local_path = std::env::current_dir().ok()?.join("espeak-ng"); + if local_path.exists() { + Some(local_path.to_string_lossy().into_owned()) + } else { + None + } + }); + + println!("Loading model from {} ...", cli.model_dir.display()); + let mut model = KittenTTSModel::from_dir(&cli.model_dir, espeak_path.as_deref())?; + + println!("Generating audio for: \"{}\"", cli.text); + println!("Voice: {}, Speed: {}", cli.voice, cli.speed); + + model.generate_to_file( + &cli.text, + &cli.output, + &cli.voice, + cli.speed, + cli.sample_rate, + !cli.no_clean, + )?; + + Ok(()) +} diff --git a/kittentts-rs/src/model.rs b/kittentts-rs/src/model.rs new file mode 100644 index 0000000..85f6ac4 --- /dev/null +++ b/kittentts-rs/src/model.rs @@ -0,0 +1,362 @@ +//! ONNX model inference and voice embedding loading. +//! +//! Port of `onnx_model.py` — loads the ONNX session, reads voice embeddings +//! from an NPZ archive, phonemises input text via eSpeak, and runs inference. 
+ +use std::collections::HashMap; +use std::path::Path; + +use anyhow::{Context, Result}; +use ndarray::{s, Array2}; +use ndarray_npy::NpzReader; +use ort::session::Session; +use ort::value::Value; +use regex::Regex; + +use crate::config::ModelConfig; +use crate::espeak::Espeak; +use crate::preprocess::TextPreprocessor; +use crate::text_cleaner::TextCleaner; + +// ───────────────────────────────────────────── +// Helper functions +// ───────────────────────────────────────────── + +/// Basic English tokenizer: split on word / non-word boundaries. +fn basic_english_tokenize(text: &str) -> Vec { + let re = Regex::new(r"\w+|[^\w\s]").unwrap(); + re.find_iter(text).map(|m| m.as_str().to_string()).collect() +} + +/// Ensure text ends with punctuation; append a comma if needed. +fn ensure_punctuation(text: &str) -> String { + let text = text.trim(); + if text.is_empty() { + return String::new(); + } + let last = text.chars().last().unwrap(); + if ".!?,;:".contains(last) { + text.to_string() + } else { + format!("{},", text) + } +} + +/// Split text into chunks for processing long texts. 
+fn chunk_text(text: &str, max_len: usize) -> Vec { + let re = Regex::new(r"[.!?]+").unwrap(); + let sentences: Vec<&str> = re.split(text).collect(); + let mut chunks = Vec::new(); + + for sentence in sentences { + let sentence = sentence.trim(); + if sentence.is_empty() { + continue; + } + if sentence.len() <= max_len { + chunks.push(ensure_punctuation(sentence)); + } else { + let words: Vec<&str> = sentence.split_whitespace().collect(); + let mut temp_chunk = String::new(); + for word in words { + if temp_chunk.len() + word.len() + 1 <= max_len { + if !temp_chunk.is_empty() { + temp_chunk.push(' '); + } + temp_chunk.push_str(word); + } else { + if !temp_chunk.is_empty() { + chunks.push(ensure_punctuation(temp_chunk.trim())); + } + temp_chunk = word.to_string(); + } + } + if !temp_chunk.is_empty() { + chunks.push(ensure_punctuation(temp_chunk.trim())); + } + } + } + chunks +} + +// ───────────────────────────────────────────── +// KittenTTS Model +// ───────────────────────────────────────────── + +/// Core KittenTTS ONNX model for text-to-speech synthesis. +pub struct KittenTTSModel { + session: Session, + voices: HashMap>, + text_cleaner: TextCleaner, + preprocessor: TextPreprocessor, + espeak: Espeak, + + // Voice metadata + pub available_voices: Vec, + pub all_voice_names: Vec, + pub voice_aliases: HashMap, + pub speed_priors: HashMap, +} + +impl KittenTTSModel { + /// Load a model from a directory containing config.json, the ONNX model, and voices.npz. + pub fn from_dir(model_dir: &Path, espeak_data_path: Option<&str>) -> Result { + let config_path = model_dir.join("config.json"); + let config = ModelConfig::load(&config_path).context("Failed to load config.json")?; + + let model_path = model_dir.join(&config.model_file); + let voices_path = model_dir.join(&config.voices); + + Self::new( + &model_path, + &voices_path, + config.speed_priors, + config.voice_aliases, + espeak_data_path, + ) + } + + /// Create a new model from explicit file paths. 
+ pub fn new( + model_path: &Path, + voices_path: &Path, + speed_priors: HashMap, + voice_aliases: HashMap, + espeak_data_path: Option<&str>, + ) -> Result { + // Load ONNX session + let session = Session::builder()? + .commit_from_file(model_path) + .context("Failed to load ONNX model")?; + + // Load voice embeddings from NPZ + let voices = Self::load_voices(voices_path)?; + + let available_voices = vec![ + "expr-voice-2-m", + "expr-voice-2-f", + "expr-voice-3-m", + "expr-voice-3-f", + "expr-voice-4-m", + "expr-voice-4-f", + "expr-voice-5-m", + "expr-voice-5-f", + ] + .into_iter() + .map(String::from) + .collect(); + + let all_voice_names = vec![ + "Bella", "Jasper", "Luna", "Bruno", "Rosie", "Hugo", "Kiki", "Leo", + ] + .into_iter() + .map(String::from) + .collect(); + + let text_cleaner = TextCleaner::new(); + let preprocessor = TextPreprocessor::default(); + let espeak = Espeak::new(espeak_data_path).context("Failed to initialize eSpeak-NG")?; + + Ok(KittenTTSModel { + session, + voices, + text_cleaner, + preprocessor, + espeak, + available_voices, + all_voice_names, + voice_aliases, + speed_priors, + }) + } + + /// Load all voice embeddings from a .npz file. + fn load_voices(path: &Path) -> Result>> { + let file = std::fs::File::open(path).context("Failed to open voices.npz")?; + let mut npz = NpzReader::new(file).context("Failed to parse voices.npz")?; + + let mut voices = HashMap::new(); + let names = npz.names().context("Failed to read NPZ entry names")?; + + for name in &names { + // Strip the .npy extension that ndarray-npy may include + let clean_name = name.trim_end_matches(".npy").to_string(); + let arr: Array2 = npz + .by_name(&clean_name) + .context(format!("Failed to read voice array '{}'", clean_name))?; + voices.insert(clean_name, arr); + } + + Ok(voices) + } + + /// Resolve a voice name/alias to the internal voice key. 
+ fn resolve_voice<'a>(&'a self, voice: &'a str) -> Result<&'a str> { + let resolved = self + .voice_aliases + .get(voice) + .map(|s| s.as_str()) + .unwrap_or(voice); + if !self.available_voices.contains(&resolved.to_string()) { + anyhow::bail!( + "Voice '{}' not available. Choose from: {:?}", + voice, + self.all_voice_names + ); + } + Ok(resolved) + } + + /// Phonemise text using eSpeak-NG. + fn phonemise(&self, text: &str) -> Result { + self.espeak + .set_voice("en-us") + .context("Failed to set eSpeak voice to en-us")?; + + self.espeak + .text_to_phonemes(text) + .context("eSpeak phonemisation failed") + } + + /// Prepare ONNX model inputs from text and voice parameters. + fn prepare_inputs( + &self, + text: &str, + voice: &str, + speed: f32, + ) -> Result<(Vec, Array2, f32)> { + let voice = self.resolve_voice(voice)?; + + let mut speed = speed; + if let Some(&prior) = self.speed_priors.get(voice) { + speed *= prior; + } + + // Phonemise + let phonemes_raw = self.phonemise(text)?; + let tokens_str = basic_english_tokenize(&phonemes_raw); + let phonemes_joined = tokens_str.join(" "); + + #[cfg(debug_assertions)] + println!("Phonemes: {}", phonemes_joined); + + let mut tokens = self.text_cleaner.encode(&phonemes_joined); + + #[cfg(debug_assertions)] + println!("Tokens: {:?}", tokens); + + // Add start and end tokens + tokens.insert(0, 0); // start token + tokens.push(10); // end marker + tokens.push(0); // pad + + // Get voice reference embedding + let voice_arr = self + .voices + .get(voice) + .ok_or_else(|| anyhow::anyhow!("Voice embedding '{}' not found in NPZ", voice))?; + + let ref_id = text.len().min(voice_arr.nrows() - 1); + let ref_s = voice_arr.slice(s![ref_id..ref_id + 1, ..]).to_owned(); + + Ok((tokens, ref_s, speed)) + } + + /// Generate speech for a single text chunk. 
+ fn generate_single_chunk(&mut self, text: &str, voice: &str, speed: f32) -> Result> { + let (tokens, ref_s, speed) = self.prepare_inputs(text, voice, speed)?; + + let input_ids_shape = vec![1, tokens.len()]; + let input_ids_vec = tokens; + + let speed_shape = vec![1]; + let speed_vec = vec![speed]; + + let ref_s_shape = vec![ref_s.shape()[0], ref_s.shape()[1]]; + let ref_s_vec: Vec = ref_s.iter().copied().collect(); + + let inputs = ort::inputs![ + "input_ids" => Value::from_array((input_ids_shape, input_ids_vec))?, + "style" => Value::from_array((ref_s_shape, ref_s_vec))?, + "speed" => Value::from_array((speed_shape, speed_vec))?, + ]; + + let outputs = self.session.run(inputs)?; + + // Extract audio output (first output tensor) + let (_, audio_data) = outputs[0] + .try_extract_tensor::() + .context("Failed to extract audio tensor")?; + + let total_len = audio_data.len(); + + // Trim last 5000 samples (same as Python) + let trim_len = 5000.min(total_len); + let trimmed_len = total_len - trim_len; + + let audio: Vec = audio_data.iter().take(trimmed_len).copied().collect(); + + Ok(audio) + } + + /// Generate speech from text. + /// + /// Automatically chunks long texts, generates each chunk, and concatenates. + pub fn generate( + &mut self, + text: &str, + voice: &str, + speed: f32, + clean_text: bool, + ) -> Result> { + let text = if clean_text { + self.preprocessor.process(text) + } else { + text.to_string() + }; + + let chunks = chunk_text(&text, 400); + let mut all_audio = Vec::new(); + + for chunk in &chunks { + let audio = self.generate_single_chunk(chunk, voice, speed)?; + all_audio.extend(audio); + } + + Ok(all_audio) + } + + /// Generate speech and save to a WAV file. 
+ pub fn generate_to_file( + &mut self, + text: &str, + output_path: &Path, + voice: &str, + speed: f32, + sample_rate: u32, + clean_text: bool, + ) -> Result<()> { + let audio = self.generate(text, voice, speed, clean_text)?; + + let spec = hound::WavSpec { + channels: 1, + sample_rate, + bits_per_sample: 32, + sample_format: hound::SampleFormat::Float, + }; + + let mut writer = + hound::WavWriter::create(output_path, spec).context("Failed to create WAV file")?; + + for &sample in &audio { + writer + .write_sample(sample) + .context("Failed to write audio sample")?; + } + + writer.finalize().context("Failed to finalise WAV file")?; + + println!("Audio saved to {}", output_path.display()); + Ok(()) + } +} diff --git a/kittentts-rs/src/preprocess.rs b/kittentts-rs/src/preprocess.rs new file mode 100644 index 0000000..d19cf79 --- /dev/null +++ b/kittentts-rs/src/preprocess.rs @@ -0,0 +1,947 @@ +//! Text preprocessing pipeline for TTS input normalisation. +//! +//! Port of the Python `preprocess.py` — converts numbers, currencies, times, +//! ordinals, contractions, etc. into their spoken-word equivalents so the +//! phonemiser receives clean English text. 
+ +use fancy_regex::Regex; +use once_cell::sync::Lazy; + +// ───────────────────────────────────────────── +// Number → Words conversion +// ───────────────────────────────────────────── + +const ONES: &[&str] = &[ + "", + "one", + "two", + "three", + "four", + "five", + "six", + "seven", + "eight", + "nine", + "ten", + "eleven", + "twelve", + "thirteen", + "fourteen", + "fifteen", + "sixteen", + "seventeen", + "eighteen", + "nineteen", +]; + +const TENS: &[&str] = &[ + "", "", "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety", +]; + +const SCALE: &[&str] = &["", "thousand", "million", "billion", "trillion"]; + +fn three_digits_to_words(n: u64) -> String { + if n == 0 { + return String::new(); + } + let mut parts = Vec::new(); + let hundreds = n / 100; + let remainder = n % 100; + if hundreds > 0 { + parts.push(format!("{} hundred", ONES[hundreds as usize])); + } + if remainder < 20 { + if remainder > 0 { + parts.push(ONES[remainder as usize].to_string()); + } + } else { + let tens_word = TENS[(remainder / 10) as usize]; + let ones_word = ONES[(remainder % 10) as usize]; + if ones_word.is_empty() { + parts.push(tens_word.to_string()); + } else { + parts.push(format!("{}-{}", tens_word, ones_word)); + } + } + parts.join(" ") +} + +/// Convert an integer to its English word representation. +pub fn number_to_words(n: i64) -> String { + if n == 0 { + return "zero".to_string(); + } + if n < 0 { + return format!("negative {}", number_to_words(-n)); + } + let n = n as u64; + + // X00–X999 read as "X hundred" (e.g. 
1200 → "twelve hundred") + if (100..=9999).contains(&n) && n % 100 == 0 && n % 1000 != 0 { + let hundreds = n / 100; + if hundreds < 20 { + return format!("{} hundred", ONES[hundreds as usize]); + } + } + + let mut parts = Vec::new(); + let mut remaining = n; + for scale in SCALE { + let chunk = remaining % 1000; + if chunk > 0 { + let chunk_words = three_digits_to_words(chunk); + if scale.is_empty() { + parts.push(chunk_words); + } else { + parts.push(format!("{} {}", chunk_words, scale)); + } + } + remaining /= 1000; + if remaining == 0 { + break; + } + } + parts.reverse(); + parts.join(" ") +} + +/// Convert a float to words, reading decimal digits individually. +pub fn float_to_words(value: &str, decimal_sep: &str) -> String { + let mut text = value.to_string(); + let negative = text.starts_with('-'); + if negative { + text = text[1..].to_string(); + } + + let result = if text.contains('.') { + let mut split = text.splitn(2, '.'); + let int_part = split.next().unwrap_or("0"); + let dec_part = split.next().unwrap_or("0"); + let int_words = if int_part.is_empty() { + "zero".to_string() + } else { + number_to_words(int_part.parse::().unwrap_or(0)) + }; + let digit_names = [ + "zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine", + ]; + let dec_words: Vec<&str> = dec_part + .chars() + .map(|c| { + let d = c.to_digit(10).unwrap_or(0) as usize; + digit_names[d] + }) + .collect(); + format!("{} {} {}", int_words, decimal_sep, dec_words.join(" ")) + } else { + number_to_words(text.parse::().unwrap_or(0)) + }; + + if negative { + format!("negative {}", result) + } else { + result + } +} + +fn roman_to_int(s: &str) -> i64 { + let val = |ch: char| -> i64 { + match ch.to_ascii_uppercase() { + 'I' => 1, + 'V' => 5, + 'X' => 10, + 'L' => 50, + 'C' => 100, + 'D' => 500, + 'M' => 1000, + _ => 0, + } + }; + let mut result: i64 = 0; + let mut prev: i64 = 0; + for ch in s.chars().rev() { + let curr = val(ch); + if curr >= prev { + result += curr; + } 
else { + result -= curr; + } + prev = curr; + } + result +} + +// ───────────────────────────────────────────── +// Ordinal helpers +// ───────────────────────────────────────────── + +fn ordinal_suffix(n: i64) -> String { + let word = number_to_words(n); + + let (prefix, last, joiner) = if word.contains('-') { + let idx = word.rfind('-').unwrap(); + (word[..idx].to_string(), word[idx + 1..].to_string(), "-") + } else if word.contains(' ') { + let idx = word.rfind(' ').unwrap(); + (word[..idx].to_string(), word[idx + 1..].to_string(), " ") + } else { + (String::new(), word.clone(), "") + }; + + let ordinal_exceptions: &[(&str, &str)] = &[ + ("one", "first"), + ("two", "second"), + ("three", "third"), + ("four", "fourth"), + ("five", "fifth"), + ("six", "sixth"), + ("seven", "seventh"), + ("eight", "eighth"), + ("nine", "ninth"), + ("twelve", "twelfth"), + ]; + + let last_ord = ordinal_exceptions + .iter() + .find(|(base, _)| *base == last) + .map(|(_, ord)| ord.to_string()) + .unwrap_or_else(|| { + if last.ends_with('t') { + format!("{}h", last) + } else if last.ends_with('e') { + format!("{}th", &last[..last.len() - 1]) + } else { + format!("{}th", last) + } + }); + + if prefix.is_empty() { + last_ord + } else { + format!("{}{}{}", prefix, joiner, last_ord) + } +} + +// ───────────────────────────────────────────── +// Compiled regex patterns (lazy statics) +// ───────────────────────────────────────────── + +static RE_URL: Lazy = Lazy::new(|| Regex::new(r"https?://\S+|www\.\S+").unwrap()); +static RE_EMAIL: Lazy = + Lazy::new(|| Regex::new(r"(?i)\b[\w.+\-]+@[\w\-]+\.[a-z]{2,}\b").unwrap()); +static RE_HTML: Lazy = Lazy::new(|| Regex::new(r"<[^>]+>").unwrap()); +static RE_HASHTAG: Lazy = Lazy::new(|| Regex::new(r"#\w+").unwrap()); +static RE_MENTION: Lazy = Lazy::new(|| Regex::new(r"@\w+").unwrap()); +static RE_PUNCT: Lazy = + Lazy::new(|| Regex::new(r"[^\w\s.,?!;:\-\u{2014}\u{2013}\u{2026}]").unwrap()); +static RE_SPACES: Lazy = Lazy::new(|| 
Regex::new(r"\s+").unwrap()); +static RE_NUMBER: Lazy = + Lazy::new(|| Regex::new(r"(? = Lazy::new(|| Regex::new(r"(?i)\b(\d+)(st|nd|rd|th)\b").unwrap()); +static RE_PERCENT: Lazy = Lazy::new(|| Regex::new(r"(-?[\d,]+(?:\.\d+)?)\s*%").unwrap()); +static RE_CURRENCY: Lazy = Lazy::new(|| { + Regex::new(r"([$€£¥₹₩₿])\s*([\d,]+(?:\.\d+)?)\s*([KMBT])?(?![a-zA-Z\d])").unwrap() +}); +static RE_TIME: Lazy = + Lazy::new(|| Regex::new(r"(?i)\b(\d{1,2}):(\d{2})(?::(\d{2}))?\s*(am|pm)?\b").unwrap()); +static RE_RANGE: Lazy = Lazy::new(|| Regex::new(r"(? = + Lazy::new(|| Regex::new(r"\b([a-zA-Z][a-zA-Z0-9]*)-(\d[\d.]*)(?=[^\d.]|$)").unwrap()); +static RE_UNIT: Lazy = Lazy::new(|| { + Regex::new(r"(?i)(\d+(?:\.\d+)?)\s*(km|kg|mg|ml|gb|mb|kb|tb|hz|khz|mhz|ghz|mph|kph|°[cCfF]|[cCfF]°|ms|ns|µs)\b").unwrap() +}); +static RE_SCALE: Lazy = + Lazy::new(|| Regex::new(r"(? = Lazy::new(|| { + Regex::new(r"(? = Lazy::new(|| Regex::new(r"\b(\d+)\s*/\s*(\d+)\b").unwrap()); +static RE_DECADE: Lazy = Lazy::new(|| Regex::new(r"\b(\d{1,3})0s\b").unwrap()); +static RE_LEAD_DEC: Lazy = Lazy::new(|| Regex::new(r"(? = Lazy::new(|| { + Regex::new(r"\b(M{0,4})(CM|CD|D?C{0,3})(XC|XL|L?X{0,3})(IX|IV|V?I{0,3})\b").unwrap() +}); +static RE_IP: Lazy = + Lazy::new(|| Regex::new(r"\b(\d{1,3})\.(\d{1,3})\.(\d{1,3})\.(\d{1,3})\b").unwrap()); +// Phone patterns +static RE_PHONE_11: Lazy = Lazy::new(|| { + Regex::new(r"(? = + Lazy::new(|| Regex::new(r"(? = + Lazy::new(|| Regex::new(r"(? 
= Lazy::new(|| Regex::new(r"(?i)\bcan't\b").unwrap()); +static RE_WONT: Lazy = Lazy::new(|| Regex::new(r"(?i)\bwon't\b").unwrap()); +static RE_SHANT: Lazy = Lazy::new(|| Regex::new(r"(?i)\bshan't\b").unwrap()); +static RE_AINT: Lazy = Lazy::new(|| Regex::new(r"(?i)\bain't\b").unwrap()); +static RE_LETS: Lazy = Lazy::new(|| Regex::new(r"(?i)\blet's\b").unwrap()); +static RE_ITS: Lazy = Lazy::new(|| Regex::new(r"(?i)\bit's\b").unwrap()); +static RE_NT: Lazy = Lazy::new(|| Regex::new(r"(?i)\b(\w+)n't\b").unwrap()); +static RE_RE: Lazy = Lazy::new(|| Regex::new(r"(?i)\b(\w+)'re\b").unwrap()); +static RE_VE: Lazy = Lazy::new(|| Regex::new(r"(?i)\b(\w+)'ve\b").unwrap()); +static RE_LL: Lazy = Lazy::new(|| Regex::new(r"(?i)\b(\w+)'ll\b").unwrap()); +static RE_D: Lazy = Lazy::new(|| Regex::new(r"(?i)\b(\w+)'d\b").unwrap()); +static RE_M: Lazy = Lazy::new(|| Regex::new(r"(?i)\b(\w+)'m\b").unwrap()); + +static RE_TITLE_WORDS: Lazy = Lazy::new(|| { + Regex::new(r"(?i)\b(war|chapter|part|volume|act|scene|book|section|article|king|queen|pope|louis|henry|edward|george|william|james|phase|round|level|stage|class|type|version|episode|season)\b").unwrap() +}); + +// ───────────────────────────────────────────── +// Expansion functions +// ───────────────────────────────────────────── + +fn expand_ordinals(text: &str) -> String { + RE_ORDINAL + .replace_all(text, |caps: &fancy_regex::Captures| { + let n: i64 = caps[1].parse().unwrap_or(0); + ordinal_suffix(n) + }) + .into_owned() +} + +fn expand_percentages(text: &str) -> String { + RE_PERCENT + .replace_all(text, |caps: &fancy_regex::Captures| { + let raw = caps[1].replace(',', ""); + if raw.contains('.') { + format!("{} percent", float_to_words(&raw, "point")) + } else { + format!( + "{} percent", + number_to_words(raw.parse::().unwrap_or(0)) + ) + } + }) + .to_string() +} + +fn expand_currency(text: &str) -> String { + let currency_name = |sym: &str| -> &str { + match sym { + "$" => "dollar", + "€" => "euro", + "£" => "pound", + 
"¥" => "yen", + "₹" => "rupee", + "₩" => "won", + "₿" => "bitcoin", + _ => "", + } + }; + let scale_map = |s: &str| -> &str { + match s { + "K" => "thousand", + "M" => "million", + "B" => "billion", + "T" => "trillion", + _ => "", + } + }; + + RE_CURRENCY + .replace_all(text, |caps: &fancy_regex::Captures| { + let symbol = &caps[1]; + let raw = caps[2].replace(',', ""); + let scale_suffix = caps.get(3).map(|m| m.as_str()); + let unit = currency_name(symbol); + + if let Some(suffix) = scale_suffix { + let scale_word = scale_map(suffix); + let num = if raw.contains('.') { + float_to_words(&raw, "point") + } else { + number_to_words(raw.parse::().unwrap_or(0)) + }; + let plural = if !unit.is_empty() { + format!("{}s", unit) + } else { + String::new() + }; + format!("{} {} {}", num, scale_word, plural) + .trim() + .to_string() + } else if raw.contains('.') { + let mut split = raw.splitn(2, '.'); + let int_part = split.next().unwrap_or("0"); + let dec_part = split.next().unwrap_or("0"); + let dec_str = format!("{:0<2}", &dec_part[..dec_part.len().min(2)]); + let dec_val: i64 = dec_str.parse().unwrap_or(0); + let int_val: i64 = int_part.parse().unwrap_or(0); + let int_words = number_to_words(int_val); + let mut result = if !unit.is_empty() { + format!("{} {}s", int_words, unit) + } else { + int_words + }; + if dec_val > 0 { + let cents = number_to_words(dec_val); + let cent_suffix = if dec_val != 1 { "cents" } else { "cent" }; + result = format!("{} and {} {}", result, cents, cent_suffix); + } + result + } else { + let val: i64 = raw.parse().unwrap_or(0); + let words = number_to_words(val); + if !unit.is_empty() { + let plural = if val != 1 { "s" } else { "" }; + format!("{} {}{}", words, unit, plural) + } else { + words + } + } + }) + .to_string() +} + +fn expand_time(text: &str) -> String { + RE_TIME + .replace_all(text, |caps: &fancy_regex::Captures| { + let h: i64 = caps[1].parse().unwrap_or(0); + let mins: i64 = caps[2].parse().unwrap_or(0); + let suffix = caps + 
.get(4) + .map(|m| format!(" {}", m.as_str().to_lowercase())) + .unwrap_or_default(); + let h_words = number_to_words(h); + if mins == 0 { + if caps.get(4).is_some() { + format!("{}{}", h_words, suffix) + } else { + format!("{} hundred{}", h_words, suffix) + } + } else if mins < 10 { + format!("{} oh {}{}", h_words, number_to_words(mins), suffix) + } else { + format!("{} {}{}", h_words, number_to_words(mins), suffix) + } + }) + .to_string() +} + +fn expand_ranges(text: &str) -> String { + RE_RANGE + .replace_all(text, |caps: &fancy_regex::Captures| { + let lo = number_to_words(caps[1].parse::().unwrap_or(0)); + let hi = number_to_words(caps[2].parse::().unwrap_or(0)); + format!("{} to {}", lo, hi) + }) + .to_string() +} + +fn expand_model_names(text: &str) -> String { + RE_MODEL_VER + .replace_all(text, |caps: &fancy_regex::Captures| { + format!("{} {}", &caps[1], &caps[2]) + }) + .to_string() +} + +fn expand_units_map(u: &str) -> &str { + match u { + "km" => "kilometers", + "kg" => "kilograms", + "mg" => "milligrams", + "ml" => "milliliters", + "gb" => "gigabytes", + "mb" => "megabytes", + "kb" => "kilobytes", + "tb" => "terabytes", + "hz" => "hertz", + "khz" => "kilohertz", + "mhz" => "megahertz", + "ghz" => "gigahertz", + "mph" => "miles per hour", + "kph" => "kilometers per hour", + "ms" => "milliseconds", + "ns" => "nanoseconds", + "µs" => "microseconds", + "°c" | "c°" => "degrees Celsius", + "°f" | "f°" => "degrees Fahrenheit", + _ => u, + } +} + +fn expand_units(text: &str) -> String { + RE_UNIT + .replace_all(text, |caps: &fancy_regex::Captures| { + let raw = &caps[1]; + let unit_lower = caps[2].to_lowercase(); + let expanded = expand_units_map(&unit_lower); + let num = if raw.contains('.') { + float_to_words(raw, "point") + } else { + number_to_words(raw.parse::().unwrap_or(0)) + }; + format!("{} {}", num, expanded) + }) + .to_string() +} + +fn expand_scale_map(s: &str) -> &str { + match s { + "K" => "thousand", + "M" => "million", + "B" => "billion", + 
"T" => "trillion", + _ => s, + } +} + +fn expand_scale_suffixes(text: &str) -> String { + RE_SCALE + .replace_all(text, |caps: &fancy_regex::Captures| { + let raw = &caps[1]; + let suffix = &caps[2]; + let scale_word = expand_scale_map(suffix); + let num = if raw.contains('.') { + float_to_words(raw, "point") + } else { + number_to_words(raw.parse::().unwrap_or(0)) + }; + format!("{} {}", num, scale_word) + }) + .to_string() +} + +fn expand_scientific_notation(text: &str) -> String { + RE_SCI + .replace_all(text, |caps: &fancy_regex::Captures| { + let coeff_raw = &caps[1]; + let exp: i64 = caps[2].parse().unwrap_or(0); + let coeff_words = if coeff_raw.contains('.') { + float_to_words(coeff_raw, "point") + } else { + number_to_words(coeff_raw.parse::().unwrap_or(0)) + }; + let exp_words = number_to_words(exp.abs()); + let sign = if exp < 0 { "negative " } else { "" }; + format!("{} times ten to the {}{}", coeff_words, sign, exp_words) + }) + .to_string() +} + +fn expand_fractions(text: &str) -> String { + RE_FRACTION + .replace_all(text, |caps: &fancy_regex::Captures| { + let num: i64 = caps[1].parse().unwrap_or(0); + let den: i64 = caps[2].parse().unwrap_or(1); + if den == 0 { + return caps[0].to_string(); + } + let num_words = number_to_words(num); + let denom_word = if den == 2 { + if num == 1 { + "half".to_string() + } else { + "halves".to_string() + } + } else if den == 4 { + if num == 1 { + "quarter".to_string() + } else { + "quarters".to_string() + } + } else { + let base = ordinal_suffix(den); + if num != 1 { + format!("{}s", base) + } else { + base + } + }; + format!("{} {}", num_words, denom_word) + }) + .to_string() +} + +fn expand_decades(text: &str) -> String { + let decade_map = |d: u64| -> &'static str { + match d { + 0 => "hundreds", + 1 => "tens", + 2 => "twenties", + 3 => "thirties", + 4 => "forties", + 5 => "fifties", + 6 => "sixties", + 7 => "seventies", + 8 => "eighties", + 9 => "nineties", + _ => "", + } + }; + RE_DECADE + .replace_all(text, 
|caps: &fancy_regex::Captures| { + let base: u64 = caps[1].parse().unwrap_or(0); + let decade_digit = base % 10; + let decade_word = decade_map(decade_digit); + if base < 10 { + decade_word.to_string() + } else { + let century_part = base / 10; + format!("{} {}", number_to_words(century_part as i64), decade_word) + } + }) + .to_string() +} + +fn expand_ip_addresses(text: &str) -> String { + let digit_name = |c: char| -> &'static str { + match c { + '0' => "zero", + '1' => "one", + '2' => "two", + '3' => "three", + '4' => "four", + '5' => "five", + '6' => "six", + '7' => "seven", + '8' => "eight", + '9' => "nine", + _ => "", + } + }; + let octet_to_words = + |s: &str| -> String { s.chars().map(digit_name).collect::>().join(" ") }; + + RE_IP + .replace_all(text, |caps: &fancy_regex::Captures| { + let parts: Vec = (1..=4).map(|i| octet_to_words(&caps[i])).collect(); + parts.join(" dot ") + }) + .to_string() +} + +fn expand_phone_numbers(text: &str) -> String { + let digit_name = |c: char| -> &'static str { + match c { + '0' => "zero", + '1' => "one", + '2' => "two", + '3' => "three", + '4' => "four", + '5' => "five", + '6' => "six", + '7' => "seven", + '8' => "eight", + '9' => "nine", + _ => "", + } + }; + let digits_to_words = + |s: &str| -> String { s.chars().map(digit_name).collect::>().join(" ") }; + + // 11-digit + let text = RE_PHONE_11 + .replace_all(&text, |caps: &fancy_regex::Captures| { + format!( + "{} {} {} {}", + digits_to_words(&caps[1]), + digits_to_words(&caps[2]), + digits_to_words(&caps[3]), + digits_to_words(&caps[4]) + ) + }) + .to_string(); + // 10-digit + let text = RE_PHONE_10 + .replace_all(&text, |caps: &fancy_regex::Captures| { + format!( + "{} {} {}", + digits_to_words(&caps[1]), + digits_to_words(&caps[2]), + digits_to_words(&caps[3]) + ) + }) + .to_string(); + // 7-digit + RE_PHONE_7 + .replace_all(&text, |caps: &fancy_regex::Captures| { + format!( + "{} {}", + digits_to_words(&caps[1]), + digits_to_words(&caps[2]) + ) + }) + .to_string() 
+} + +fn expand_contractions(text: &str) -> String { + let mut t = RE_CANT.replace_all(text, "cannot").to_string(); + t = RE_WONT.replace_all(&t, "will not").to_string(); + t = RE_SHANT.replace_all(&t, "shall not").to_string(); + t = RE_AINT.replace_all(&t, "is not").to_string(); + t = RE_LETS.replace_all(&t, "let us").to_string(); + t = RE_ITS.replace_all(&t, "it is").to_string(); + t = RE_NT.replace_all(&t, "$1 not").to_string(); + t = RE_RE.replace_all(&t, "$1 are").to_string(); + t = RE_VE.replace_all(&t, "$1 have").to_string(); + t = RE_LL.replace_all(&t, "$1 will").to_string(); + t = RE_D.replace_all(&t, "$1 would").to_string(); + t = RE_M.replace_all(&t, "$1 am").to_string(); + t +} + +fn normalize_leading_decimals(text: &str) -> String { + // -.5 → -0.5 + let re_neg = Regex::new(r"(? String { + RE_ROMAN + .replace_all(text, |caps: &fancy_regex::Captures| { + let roman = caps[0].to_string(); + if roman.trim().is_empty() { + return roman; + } + // Skip single ambiguous letters + if roman.len() == 1 && "IVX".contains(&roman) { + let start = caps.get(0).unwrap().start(); + let preceding = &text[start.saturating_sub(30)..start]; + if !RE_TITLE_WORDS.is_match(preceding).unwrap_or(false) { + return roman; + } + } + let val = roman_to_int(&roman); + if val == 0 { + return roman; + } + number_to_words(val) + }) + .to_string() +} + +fn replace_numbers(text: &str, replace_floats: bool) -> String { + RE_NUMBER + .replace_all(text, |caps: &fancy_regex::Captures| { + let raw = caps[0].replace(',', ""); + if raw.contains('.') && replace_floats { + float_to_words(&raw, "point") + } else { + match raw.parse::() { + Ok(v) => number_to_words(v as i64), + Err(_) => caps[0].to_string(), + } + } + }) + .to_string() +} + +fn remove_urls(text: &str) -> String { + RE_URL.replace_all(text, "").trim().to_string() +} + +fn remove_emails(text: &str) -> String { + RE_EMAIL.replace_all(text, "").trim().to_string() +} + +fn remove_html_tags(text: &str) -> String { + 
RE_HTML.replace_all(text, " ").to_string() +} + +fn remove_hashtags(text: &str) -> String { + RE_HASHTAG.replace_all(text, "").to_string() +} + +fn remove_mentions(text: &str) -> String { + RE_MENTION.replace_all(text, "").to_string() +} + +fn remove_punctuation(text: &str) -> String { + RE_PUNCT.replace_all(text, " ").to_string() +} + +fn remove_extra_whitespace(text: &str) -> String { + RE_SPACES.replace_all(text, " ").trim().to_string() +} + +fn normalize_unicode(text: &str) -> String { + // Rust strings are always valid UTF-8; we do NFC normalisation + use std::iter::FromIterator; + // Simple approach: we just return the string as-is since full NFC + // requires the `unicode-normalization` crate. For TTS purposes this is acceptable. + String::from_iter(text.chars()) +} + +fn remove_accents(text: &str) -> String { + // Simplified: strip combining marks after NFD-ish decomposition. + // For a full solution we'd use unicode-normalization crate. + text.chars() + .filter(|c| !('\u{0300}'..='\u{036f}').contains(c)) + .collect() +} + +fn to_lowercase(text: &str) -> String { + text.to_lowercase() +} + +// ───────────────────────────────────────────── +// Configurable pipeline +// ───────────────────────────────────────────── + +/// Configurable text preprocessing pipeline for TTS input normalisation. 
+pub struct TextPreprocessor { + pub lowercase: bool, + pub do_replace_numbers: bool, + pub replace_floats: bool, + pub do_expand_contractions: bool, + pub do_expand_model_names: bool, + pub do_expand_ordinals: bool, + pub do_expand_percentages: bool, + pub do_expand_currency: bool, + pub do_expand_time: bool, + pub do_expand_ranges: bool, + pub do_expand_units: bool, + pub do_expand_scale_suffixes: bool, + pub do_expand_scientific_notation: bool, + pub do_expand_fractions: bool, + pub do_expand_decades: bool, + pub do_expand_phone_numbers: bool, + pub do_expand_ip_addresses: bool, + pub do_normalize_leading_decimals: bool, + pub do_expand_roman_numerals: bool, + pub do_remove_urls: bool, + pub do_remove_emails: bool, + pub do_remove_html: bool, + pub do_remove_hashtags: bool, + pub do_remove_mentions: bool, + pub do_remove_punctuation: bool, + pub do_normalize_unicode: bool, + pub do_remove_accents: bool, + pub do_remove_extra_whitespace: bool, +} + +impl Default for TextPreprocessor { + /// Default matches the Python `TextPreprocessor(remove_punctuation=False)` used in the model. 
+ fn default() -> Self { + TextPreprocessor { + lowercase: true, + do_replace_numbers: true, + replace_floats: true, + do_expand_contractions: true, + do_expand_model_names: true, + do_expand_ordinals: true, + do_expand_percentages: true, + do_expand_currency: true, + do_expand_time: true, + do_expand_ranges: true, + do_expand_units: true, + do_expand_scale_suffixes: true, + do_expand_scientific_notation: true, + do_expand_fractions: true, + do_expand_decades: true, + do_expand_phone_numbers: true, + do_expand_ip_addresses: true, + do_normalize_leading_decimals: true, + do_expand_roman_numerals: false, + do_remove_urls: true, + do_remove_emails: true, + do_remove_html: true, + do_remove_hashtags: false, + do_remove_mentions: false, + do_remove_punctuation: false, // model default + do_normalize_unicode: true, + do_remove_accents: false, + do_remove_extra_whitespace: true, + } + } +} + +impl TextPreprocessor { + /// Process text through the configured pipeline, exactly matching the Python + /// `TextPreprocessor.process()` ordering. 
+ pub fn process(&self, text: &str) -> String { + let mut text = text.to_string(); + + if self.do_normalize_unicode { + text = normalize_unicode(&text); + } + if self.do_remove_html { + text = remove_html_tags(&text); + } + if self.do_remove_urls { + text = remove_urls(&text); + } + if self.do_remove_emails { + text = remove_emails(&text); + } + if self.do_remove_hashtags { + text = remove_hashtags(&text); + } + if self.do_remove_mentions { + text = remove_mentions(&text); + } + if self.do_expand_contractions { + text = expand_contractions(&text); + } + if self.do_expand_ip_addresses { + text = expand_ip_addresses(&text); + } + if self.do_normalize_leading_decimals { + text = normalize_leading_decimals(&text); + } + if self.do_expand_currency { + text = expand_currency(&text); + } + if self.do_expand_percentages { + text = expand_percentages(&text); + } + if self.do_expand_scientific_notation { + text = expand_scientific_notation(&text); + } + if self.do_expand_time { + text = expand_time(&text); + } + if self.do_expand_ordinals { + text = expand_ordinals(&text); + } + if self.do_expand_units { + text = expand_units(&text); + } + if self.do_expand_scale_suffixes { + text = expand_scale_suffixes(&text); + } + if self.do_expand_fractions { + text = expand_fractions(&text); + } + if self.do_expand_decades { + text = expand_decades(&text); + } + if self.do_expand_phone_numbers { + text = expand_phone_numbers(&text); + } + if self.do_expand_ranges { + text = expand_ranges(&text); + } + if self.do_expand_model_names { + text = expand_model_names(&text); + } + if self.do_expand_roman_numerals { + text = expand_roman_numerals(&text); + } + if self.do_replace_numbers { + text = replace_numbers(&text, self.replace_floats); + } + if self.do_remove_accents { + text = remove_accents(&text); + } + if self.do_remove_punctuation { + text = remove_punctuation(&text); + } + if self.lowercase { + text = to_lowercase(&text); + } + if self.do_remove_extra_whitespace { + text = 
remove_extra_whitespace(&text); + } + + text + } +} diff --git a/kittentts-rs/src/text_cleaner.rs b/kittentts-rs/src/text_cleaner.rs new file mode 100644 index 0000000..78ed8da --- /dev/null +++ b/kittentts-rs/src/text_cleaner.rs @@ -0,0 +1,53 @@ +use std::collections::HashMap; + +/// Maps phoneme/symbol characters to integer token IDs. +/// +/// Replicates the Python `TextCleaner` class: builds a lookup from the same +/// ordered symbol table (pad + punctuation + ASCII letters + IPA letters). +pub struct TextCleaner { + word_index: HashMap, +} + +impl TextCleaner { + pub fn new() -> Self { + let pad = "$"; + let punctuation = ";:,.!?¡¿—…\"«»\"\" "; + let letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; + let letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"; + + let mut symbols: Vec = Vec::new(); + for c in pad.chars() { + symbols.push(c); + } + for c in punctuation.chars() { + symbols.push(c); + } + for c in letters.chars() { + symbols.push(c); + } + for c in letters_ipa.chars() { + symbols.push(c); + } + + let mut word_index = HashMap::new(); + for (i, &ch) in symbols.iter().enumerate() { + word_index.insert(ch, i as i64); + } + + TextCleaner { word_index } + } + + /// Convert a phoneme string to a sequence of token IDs. + /// Characters not in the symbol table are silently skipped. 
+ pub fn encode(&self, text: &str) -> Vec { + text.chars() + .filter_map(|c| self.word_index.get(&c).copied()) + .collect() + } +} + +impl Default for TextCleaner { + fn default() -> Self { + Self::new() + } +} From 88870e85bd589679544c21e8ff09b2e313cc4700 Mon Sep 17 00:00:00 2001 From: Rahul D Shetty <35rahuldshetty@gmail.com> Date: Sun, 1 Mar 2026 14:19:49 +0530 Subject: [PATCH 2/2] handle punctuations Signed-off-by: Rahul D Shetty <35rahuldshetty@gmail.com> --- kittentts-rs/src/espeak.rs | 36 +++++++++++++++++++++++++------- kittentts-rs/src/model.rs | 13 +++--------- kittentts-rs/src/text_cleaner.rs | 10 +++++---- 3 files changed, 38 insertions(+), 21 deletions(-) diff --git a/kittentts-rs/src/espeak.rs b/kittentts-rs/src/espeak.rs index 4ad0e0d..9441292 100644 --- a/kittentts-rs/src/espeak.rs +++ b/kittentts-rs/src/espeak.rs @@ -1,6 +1,5 @@ use espeak_ng_sys as espeak_ng; use std::ffi::{CStr, CString}; -use std::os::raw::c_void; use std::ptr; pub struct Espeak { @@ -42,17 +41,40 @@ impl Espeak { pub fn text_to_phonemes(&self, text: &str) -> anyhow::Result { let c_text = CString::new(text)?; - let mut text_ptr = c_text.as_ptr() as *const c_void; + let mut text_ptr = c_text.as_ptr() as *const std::os::raw::c_char; + let mut all_phonemes = String::new(); - // textmode 1 = UTF8, phonememode 2 = IPA phonemes - let result_ptr = unsafe { espeak_ng::espeak_TextToPhonemes(&mut text_ptr, 1, 2) }; + while !text_ptr.is_null() && unsafe { *text_ptr != 0 } { + let initial_ptr = text_ptr; + let mut current_ptr = text_ptr as *const std::os::raw::c_void; - if result_ptr.is_null() { + // textmode 1 = UTF8, phonememode 66 = IPA + Punctuation + let result_ptr = unsafe { espeak_ng::espeak_TextToPhonemes(&mut current_ptr, 1, 66) }; + + if !result_ptr.is_null() { + let c_str = unsafe { CStr::from_ptr(result_ptr) }; + let phonemes = c_str.to_string_lossy(); + if !all_phonemes.is_empty() + && !all_phonemes.ends_with(' ') + && !phonemes.starts_with(' ') + { + all_phonemes.push(' '); + 
} + all_phonemes.push_str(&phonemes); + } + + text_ptr = current_ptr as *const std::os::raw::c_char; + + if text_ptr.is_null() || text_ptr == initial_ptr { + break; + } + } + + if all_phonemes.is_empty() { anyhow::bail!("eSpeak failed to generate phonemes"); } - let c_str = unsafe { CStr::from_ptr(result_ptr) }; - Ok(c_str.to_string_lossy().into_owned()) + Ok(all_phonemes) } } diff --git a/kittentts-rs/src/model.rs b/kittentts-rs/src/model.rs index 85f6ac4..9782500 100644 --- a/kittentts-rs/src/model.rs +++ b/kittentts-rs/src/model.rs @@ -44,12 +44,11 @@ fn ensure_punctuation(text: &str) -> String { /// Split text into chunks for processing long texts. fn chunk_text(text: &str, max_len: usize) -> Vec { - let re = Regex::new(r"[.!?]+").unwrap(); - let sentences: Vec<&str> = re.split(text).collect(); + let re = Regex::new(r"([^.!?]+[.!?]+)").unwrap(); let mut chunks = Vec::new(); - for sentence in sentences { - let sentence = sentence.trim(); + for caps in re.captures_iter(text) { + let sentence = caps[1].trim(); if sentence.is_empty() { continue; } @@ -237,14 +236,8 @@ impl KittenTTSModel { let tokens_str = basic_english_tokenize(&phonemes_raw); let phonemes_joined = tokens_str.join(" "); - #[cfg(debug_assertions)] - println!("Phonemes: {}", phonemes_joined); - let mut tokens = self.text_cleaner.encode(&phonemes_joined); - #[cfg(debug_assertions)] - println!("Tokens: {:?}", tokens); - // Add start and end tokens tokens.insert(0, 0); // start token tokens.push(10); // end marker diff --git a/kittentts-rs/src/text_cleaner.rs b/kittentts-rs/src/text_cleaner.rs index 78ed8da..89423c0 100644 --- a/kittentts-rs/src/text_cleaner.rs +++ b/kittentts-rs/src/text_cleaner.rs @@ -14,11 +14,10 @@ impl TextCleaner { let punctuation = ";:,.!?¡¿—…\"«»\"\" "; let letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; let letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"; + let extra = "|‖"; - let 
mut symbols: Vec = Vec::new(); - for c in pad.chars() { - symbols.push(c); - } + let mut symbols = Vec::new(); + symbols.push(pad.chars().next().unwrap()); for c in punctuation.chars() { symbols.push(c); } @@ -28,6 +27,9 @@ impl TextCleaner { for c in letters_ipa.chars() { symbols.push(c); } + for c in extra.chars() { + symbols.push(c); + } let mut word_index = HashMap::new(); for (i, &ch) in symbols.iter().enumerate() {