From 7a39031f4e61233284a989b7ffb2a818dacc958f Mon Sep 17 00:00:00 2001 From: Biraj Date: Sat, 21 Mar 2026 11:08:01 +0530 Subject: [PATCH] fix(chunking): keep Dr./Mr./a.m./p.m. inline instead of splitting sentences Before: - `chunk_text()` split on every `.` / `!` / `?` with no abbreviation handling. - `Dr. Sharma...` was treated like `Dr` + new sentence. - `9:30 a.m.` and `5:15 p.m.` were split at the periods inside the abbreviation. - That made TTS pause or slow down after `Dr`, `a`, and `m`, as if they were full sentence boundaries. After: - Protect common inline abbreviations before sentence splitting and restore them afterward. - `Dr.`, `Mr.`, `Mrs.`, `Ms.`, `Prof.`, `a.m.`, and `p.m.` now stay inside the same sentence/chunk. - TTS no longer inserts sentence-break pauses around those abbreviations. - Real sentence-ending punctuation still splits normally. --- kittentts/onnx_model.py | 129 +++++++++++++++++++++++++++++----------- 1 file changed, 93 insertions(+), 36 deletions(-) diff --git a/kittentts/onnx_model.py b/kittentts/onnx_model.py index 88009fc..83bfc44 100644 --- a/kittentts/onnx_model.py +++ b/kittentts/onnx_model.py @@ -1,38 +1,73 @@ +import re + from misaki import en, espeak import numpy as np import phonemizer import soundfile as sf import onnxruntime as ort + from .preprocess import TextPreprocessor + def basic_english_tokenize(text): """Basic English tokenizer that splits on whitespace and punctuation.""" - import re tokens = re.findall(r"\w+|[^\w\s]", text) return tokens + def ensure_punctuation(text): """Ensure text ends with punctuation. 
If not, add a comma.""" text = text.strip() if not text: return text - if text[-1] not in '.!?,;:': - text = text + ',' + if text[-1] not in ".!?,;:": + text = text + "," + return text + + +def _protect_inline_abbreviations(text): + """Protect period-bearing abbreviations from sentence splitting.""" + + placeholders = {} + patterns = [ + (r"\ba\.m\.(?=\s|$|[,;:!?])", "a.m."), + (r"\bp\.m\.(?=\s|$|[,;:!?])", "p.m."), + (r"\bDr\.(?=\s|$)", "Dr."), + (r"\bMr\.(?=\s|$)", "Mr."), + (r"\bMrs\.(?=\s|$)", "Mrs."), + (r"\bMs\.(?=\s|$)", "Ms."), + (r"\bProf\.(?=\s|$)", "Prof."), + ] + + for index, (pattern, original) in enumerate(patterns): + placeholder = f"__ABBR_{index}__" + text, count = re.subn(pattern, placeholder, text, flags=re.IGNORECASE) + if count: + placeholders[placeholder] = original + + return text, placeholders + + +def _restore_inline_abbreviations(text, placeholders): + """Restore protected abbreviations after sentence splitting.""" + for placeholder, original in placeholders.items(): + text = text.replace(placeholder, original) return text def chunk_text(text, max_len=400): """Split text into chunks for processing long texts.""" - import re - - sentences = re.split(r'[.!?]+', text) + + protected_text, placeholders = _protect_inline_abbreviations(text) + + sentences = re.split(r"[.!?]+", protected_text) chunks = [] - + for sentence in sentences: - sentence = sentence.strip() + sentence = _restore_inline_abbreviations(sentence, placeholders).strip() if not sentence: continue - + if len(sentence) <= max_len: chunks.append(ensure_punctuation(sentence)) else: @@ -48,7 +83,7 @@ def chunk_text(text, max_len=400): temp_chunk = word if temp_chunk: chunks.append(ensure_punctuation(temp_chunk.strip())) - + return chunks @@ -56,11 +91,11 @@ class TextCleaner: def __init__(self, dummy=None): _pad = "$" _punctuation = ';:,.!?¡¿—…"«»"" ' - _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' + _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" 
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) - + dicts = {} for i in range(len(symbols)): dicts[symbols[i]] = i @@ -78,7 +113,13 @@ def __call__(self, text): class KittenTTS_1_Onnx: - def __init__(self, model_path="kitten_tts_nano_preview.onnx", voices_path="voices.npz", speed_priors={}, voice_aliases={}): + def __init__( + self, + model_path="kitten_tts_nano_preview.onnx", + voices_path="voices.npz", + speed_priors={}, + voice_aliases={}, + ): """Initialize KittenTTS with model and voice data. Args: @@ -104,7 +145,7 @@ def __init__(self, model_path="kitten_tts_nano_preview.onnx", voices_path="voice self.voice_aliases = voice_aliases self.preprocessor = TextPreprocessor(remove_punctuation=False) - + def _prepare_inputs(self, text: str, voice: str, speed: float = 1.0) -> dict: """Prepare ONNX model inputs from text and voice parameters.""" if voice in self.voice_aliases: @@ -112,65 +153,82 @@ def _prepare_inputs(self, text: str, voice: str, speed: float = 1.0) -> dict: if voice not in self.available_voices: raise ValueError(f"Voice '{voice}' not available. 
Choose from: {self.available_voices}") - + if voice in self.speed_priors: speed = speed * self.speed_priors[voice] - + # Phonemize the input text phonemes_list = self.phonemizer.phonemize([text]) - + # Process phonemes to get token IDs phonemes = basic_english_tokenize(phonemes_list[0]) - phonemes = ' '.join(phonemes) + phonemes = " ".join(phonemes) tokens = self.text_cleaner(phonemes) - + # Add start and end tokens tokens.insert(0, 0) tokens.append(10) tokens.append(0) - + input_ids = np.array([tokens], dtype=np.int64) - ref_id = min(len(text), self.voices[voice].shape[0] - 1) - ref_s = self.voices[voice][ref_id:ref_id+1] - + ref_id = min(len(text), self.voices[voice].shape[0] - 1) + ref_s = self.voices[voice][ref_id : ref_id + 1] + return { "input_ids": input_ids, "style": ref_s, "speed": np.array([speed], dtype=np.float32), } - - def generate(self, text: str, voice: str = "expr-voice-5-m", speed: float = 1.0, clean_text: bool=True) -> np.ndarray: + + def generate( + self, + text: str, + voice: str = "expr-voice-5-m", + speed: float = 1.0, + clean_text: bool = True, + ) -> np.ndarray: + out_chunks = [] if clean_text: text = self.preprocessor(text) + for text_chunk in chunk_text(text): - out_chunks.append(self.generate_single_chunk(text_chunk, voice, speed)) + chunk = self.generate_single_chunk(text_chunk, voice, speed) + out_chunks.append(chunk) return np.concatenate(out_chunks, axis=-1) - def generate_single_chunk(self, text: str, voice: str = "expr-voice-5-m", speed: float = 1.0) -> np.ndarray: + def generate_single_chunk( + self, text: str, voice: str = "expr-voice-5-m", speed: float = 1.0 + ) -> np.ndarray: """Synthesize speech from text. 
- + Args: text: Input text to synthesize voice: Voice to use for synthesis speed: Speech speed (1.0 = normal) - + Returns: Audio data as numpy array """ onnx_inputs = self._prepare_inputs(text, voice, speed) - outputs = self.session.run(None, onnx_inputs) - + # Trim audio audio = outputs[0][..., :-5000] return audio - - def generate_to_file(self, text: str, output_path: str, voice: str = "expr-voice-5-m", - speed: float = 1.0, sample_rate: int = 24000, clean_text: bool=True) -> None: + + def generate_to_file( + self, + text: str, + output_path: str, + voice: str = "expr-voice-5-m", + speed: float = 1.0, + sample_rate: int = 24000, + clean_text: bool = True, + ) -> None: """Synthesize speech and save to file. - + Args: text: Input text to synthesize output_path: Path to save the audio file @@ -182,4 +240,3 @@ def generate_to_file(self, text: str, output_path: str, voice: str = "expr-voice audio = self.generate(text, voice, speed, clean_text=clean_text) sf.write(output_path, audio, sample_rate) print(f"Audio saved to {output_path}") -