From b4f06f3bd98fc1772dd4b95da7485af86a5bd770 Mon Sep 17 00:00:00 2001 From: weedge Date: Wed, 19 Feb 2025 23:02:42 +0800 Subject: [PATCH 01/37] feat: add flow_cache and hift_cache Signed-off-by: weedge --- cosyvoice/cli/cosyvoice.py | 6 +- cosyvoice/cli/model.py | 101 +++++++++++++++++++++++++- cosyvoice/flow/flow.py | 30 ++++---- cosyvoice/flow/flow_matching.py | 30 +++++--- cosyvoice/hifigan/generator.py | 39 ++++------ tts.py | 123 ++++++++++++++++---------------- 6 files changed, 214 insertions(+), 115 deletions(-) diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py index b8c0b4b..4054071 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import threading import uuid import time from tqdm import tqdm @@ -23,7 +24,6 @@ class CosyVoice: - def __init__( self, model_dir, @@ -55,9 +55,7 @@ def token_to_wav_offline( ): tts_mel = self.model.flow.inference( token=speech_token.to(self.model.device), - token_len=torch.tensor([speech_token.size(1)], dtype=torch.int32).to( - self.model.device - ), + token_len=torch.tensor([speech_token.size(1)], dtype=torch.int32).to(self.model.device), prompt_token=prompt_token.to(self.model.device), prompt_token_len=prompt_token_len.to(self.model.device), prompt_feat=speech_feat.to(self.model.device), diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index b284d9e..0e430e9 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -11,11 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import threading import torch +from torch.nn import functional as F +from cosyvoice.utils.common import fade_in_out -class CosyVoiceModel: +class CosyVoiceModel: def __init__( self, flow: torch.nn.Module, @@ -25,8 +28,104 @@ def __init__( self.flow = flow self.hift = hift + # dict used to store session related variable + self.lock = threading.Lock() # dict lock + self.mel_overlap_dict = {} + self.flow_cache_dict = {} + self.hift_cache_dict = {} + def load(self, flow_model, hift_model): self.flow.load_state_dict(torch.load(flow_model, map_location=self.device)) self.flow.to(self.device).eval() self.hift.load_state_dict(torch.load(hift_model, map_location=self.device)) self.hift.to(self.device).eval() + + def token2wav( + self, + token, + prompt_token, + prompt_feat, + embedding, + session_id, + finalize=False, + speed=1.0, + is_flow_cache=False, + is_hift_cache=False, + ): + if is_flow_cache is True and session_id not in self.flow_cache_dict: + with self.lock: + self.mel_overlap_dict[session_id] = torch.zeros(1, 80, 0) + self.flow_cache_dict[session_id] = torch.zeros(1, 80, 0, 2) + + if is_hift_cache is True and session_id not in self.hift_cache_dict: + with self.lock: + self.hift_cache_dict[session_id] = None + + tts_mel, flow_cache = self.flow.inference( + token=token.to(self.device), + token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), + prompt_token=prompt_token.to(self.device), + prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to( + self.device + ), + prompt_feat=prompt_feat.to(self.device), + prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device), + embedding=embedding.to(self.device), + flow_cache=self.flow_cache_dict[session_id] if is_flow_cache else None, + ) + 
self.flow_cache_dict[session_id] = flow_cache if is_flow_cache else None + + # mel overlap fade in out + if is_flow_cache and self.mel_overlap_dict[session_id].shape[2] != 0: + tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[session_id], self.mel_window) + + hift_cache_source = None + if is_hift_cache is True: + if self.hift_cache_dict[session_id] is not None: + # append hift cache + hift_cache_mel, hift_cache_source = ( + self.hift_cache_dict[session_id]["mel"], + self.hift_cache_dict[session_id]["source"], + ) + tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2) + else: + hift_cache_source = torch.zeros(1, 1, 0) + + # keep overlap mel and hift cache + if finalize is False: + if is_flow_cache is True: + self.mel_overlap_dict[session_id] = tts_mel[:, :, -self.mel_overlap_len :] + + tts_mel = tts_mel[:, :, : -self.mel_overlap_len] + tts_speech, tts_source = self.hift.inference( + speech_feat=tts_mel, cache_source=hift_cache_source + ) + + if is_hift_cache is True: + if self.hift_cache_dict[session_id] is not None: + tts_speech = fade_in_out( + tts_speech, self.hift_cache_dict[session_id]["speech"], self.speech_window + ) + self.hift_cache_dict[session_id] = { + "mel": tts_mel[:, :, -self.mel_cache_len :], + "source": tts_source[:, :, -self.source_cache_len :], + "speech": tts_speech[:, -self.source_cache_len :], + } + + tts_speech = tts_speech[:, : -self.source_cache_len] + else: + if speed != 1.0: + if is_hift_cache is True: + assert ( + self.hift_cache_dict[session_id] is None + ), "speed change only support non-stream inference mode" + tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode="linear") + tts_speech, tts_source = self.hift.inference( + speech_feat=tts_mel, cache_source=hift_cache_source + ) + if is_hift_cache is True and self.hift_cache_dict[session_id] is not None: + tts_speech = fade_in_out( + tts_speech, self.hift_cache_dict[session_id]["speech"], self.speech_window + ) + + return tts_speech diff --git a/cosyvoice/flow/flow.py b/cosyvoice/flow/flow.py index ab9a812..b2fe43b 100644 --- a/cosyvoice/flow/flow.py +++ b/cosyvoice/flow/flow.py @@ -121,9 +121,7 @@ def forward( conds = conds.transpose(1, 2) mask = (~make_pad_mask(feat_len)).to(h) - feat = F.interpolate( - feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest" - ).squeeze(dim=1) + feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1) loss, _ = self.decoder.compute_loss( feat.transpose(1, 2).contiguous(), mask.unsqueeze(1), @@ -143,6 +141,7 @@ def inference( prompt_feat, prompt_feat_len, embedding, + flow_cache, ): assert token.shape[0] == 1 # xvec projection @@ -159,11 +158,14 @@ def inference( token = self.input_embedding(torch.clamp(token, min=0)) h, _ = self.encoder.inference(token, token_len) h = self.encoder_proj(h) - mel_len1, mel_len2 = prompt_feat.shape[1], int( - token_len2 - / self.input_frame_rate - * self.mel_feat_conf["sampling_rate"] - / self.mel_feat_conf["hop_size"] + mel_len1, mel_len2 = ( + prompt_feat.shape[1], + int( + token_len2 + / self.input_frame_rate + * self.mel_feat_conf["sampling_rate"] + / self.mel_feat_conf["hop_size"] + ), ) h, _ = self.length_regulator.inference( @@ -174,23 +176,21 @@ def inference( ) # get conditions - conds = torch.zeros( - [1, mel_len1 + mel_len2, self.output_size], device=token.device - ) + conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device) conds[:, :mel_len1] = prompt_feat conds = conds.transpose(1, 2) # mask = (~make_pad_mask(torch.tensor([mel_len1 + 
mel_len2]))).to(h) - mask = torch.ones( - [1, mel_len1 + mel_len2], device=h.device, dtype=torch.bfloat16 - ) + mask = torch.ones([1, mel_len1 + mel_len2], device=h.device, dtype=torch.bfloat16) feat = self.decoder( mu=h.transpose(1, 2).contiguous(), mask=mask.unsqueeze(1), spks=embedding, cond=conds, n_timesteps=10, + prompt_len=mel_len1, + flow_cache=flow_cache, ) feat = feat[:, :, mel_len1:] assert feat.shape[2] == mel_len2 - return feat + return feat, flow_cache diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py index d29673f..7803eab 100644 --- a/cosyvoice/flow/flow_matching.py +++ b/cosyvoice/flow/flow_matching.py @@ -51,6 +51,9 @@ def forward( temperature=1.0, spks=None, cond=None, + prompt_len=0, + # flow_cache=torch.zeros(1, 80, 0, 2), + flow_cache=None, ): """Forward diffusion @@ -69,13 +72,24 @@ def forward( sample: generated mel-spectrogram shape: (batch_size, n_feats, mel_timesteps) """ - z = torch.randn_like(mu) * temperature + z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature + + if flow_cache is not None: + cache_size = flow_cache.shape[2] + # fix prompt and overlap part mu and z + if cache_size != 0: + z[:, :, :cache_size] = flow_cache[:, :, :, 0] + mu[:, :, :cache_size] = flow_cache[:, :, :, 1] + z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2) + mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2) + flow_cache = torch.stack([z_cache, mu_cache], dim=-1) + t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype) if self.t_scheduler == "cosine": t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) return self.solve_euler( z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond - ) + ), flow_cache @torch.inference_mode() def capture_inference(self, seq_len_to_capture=list(range(128, 512, 8))): @@ -96,9 +110,7 @@ def capture_inference(self, seq_len_to_capture=list(range(128, 512, 8))): static_mask = torch.ones( 1, 1, seq_len, device=torch.device("cuda"), dtype=torch.bfloat16 ) - static_spks = torch.randn( - 1, 80, device=torch.device("cuda"), dtype=torch.bfloat16 - ) + static_spks = torch.randn(1, 80, device=torch.device("cuda"), dtype=torch.bfloat16) static_cond = torch.randn( 1, 80, seq_len, device=torch.device("cuda"), dtype=torch.float32 ) @@ -231,9 +243,7 @@ def _solve_euler_impl(self, x, t_span, mu, mask, spks, cond): mu_double = torch.cat([mu, torch.zeros_like(mu)], dim=0) t_double = torch.cat([t, t], dim=0) spks_double = ( - torch.cat([spks, torch.zeros_like(spks)], dim=0) - if spks is not None - else None + torch.cat([spks, torch.zeros_like(spks)], dim=0) if spks is not None else None ) cond_double = torch.cat([cond, torch.zeros_like(cond)], dim=0) @@ -309,7 +319,5 @@ def compute_loss(self, x1, mask, mu, spks=None, cond=None): cond = cond * cfg_mask.view(-1, 1, 1) pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond) - loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / ( - torch.sum(mask) * u.shape[1] - ) + loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1]) return loss, y diff --git a/cosyvoice/hifigan/generator.py b/cosyvoice/hifigan/generator.py index 4d02c03..f27596a 100644 --- a/cosyvoice/hifigan/generator.py +++ b/cosyvoice/hifigan/generator.py @@ -146,17 +146,15 @@ def forward(self, f0): :return: [B, 1, sample_len] """ - F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to( - f0.device - ) + F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device) for i in 
range(self.harmonic_num + 1): F_mat[:, i : i + 1, :] = f0 * (i + 1) / self.sampling_rate theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1) u_dist = Uniform(low=-np.pi, high=np.pi) - phase_vec = u_dist.sample( - sample_shape=(f0.size(0), self.harmonic_num + 1, 1) - ).to(F_mat.device) + phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to( + F_mat.device + ) phase_vec[:, 0, :] = 0 # generate sine waveforms @@ -322,9 +320,7 @@ def __init__( ): if u == 1: self.source_downs.append( - Conv1d( - istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1 - ) + Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1) ) else: self.source_downs.append( @@ -337,21 +333,15 @@ def __init__( ) ) - self.source_resblocks.append( - ResBlock(base_channels // (2 ** (i + 1)), k, d) - ) + self.source_resblocks.append(ResBlock(base_channels // (2 ** (i + 1)), k, d)) self.resblocks = nn.ModuleList() for i in range(len(self.ups)): ch = base_channels // (2 ** (i + 1)) - for _, (k, d) in enumerate( - zip(resblock_kernel_sizes, resblock_dilation_sizes) - ): + for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): self.resblocks.append(ResBlock(ch, k, d)) - self.conv_post = weight_norm( - Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3) - ) + self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3)) self.ups.apply(init_weights) self.conv_post.apply(init_weights) self.reflection_pad = nn.ReflectionPad1d((1, 0)) @@ -491,6 +481,11 @@ def inference( curr_seq_len = mel.shape[2] f0 = self.f0_predictor(mel) s = self._f02source(f0) + + # use cache_source to avoid glitch + if cache_source is not None and cache_source.shape[2] != 0: + s[:, :, : cache_source.shape[2]] = cache_source + s_stft_real, s_stft_imag = self._stft(s.squeeze(1)) s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1) @@ -533,13 +528,9 @@ def inference( @torch.inference_mode() def capture_inference(self, seq_len_to_capture=[64, 128, 256, 512, 1024]): start_time = time.time() - print( - f"capture inference for HiFTGenerator with seq_len_to_capture: {seq_len_to_capture}" - ) + print(f"capture inference for HiFTGenerator with seq_len_to_capture: {seq_len_to_capture}") for seq_len in seq_len_to_capture: - mel = torch.randn( - 1, 80, seq_len, device=torch.device("cuda"), dtype=torch.float32 - ) + mel = torch.randn(1, 80, seq_len, device=torch.device("cuda"), dtype=torch.float32) f0 = self.f0_predictor(mel) s = self._f02source(f0) s_stft_real, s_stft_imag = self._stft(s.squeeze(1)) diff --git a/tts.py b/tts.py index 8276344..43d60bc 100644 --- a/tts.py +++ b/tts.py @@ -12,9 +12,7 @@ class RepetitionAwareLogitsProcessor(LogitsProcessor): - def __call__( - self, input_ids: torch.LongTensor, scores: torch.FloatTensor - ) -> torch.FloatTensor: + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: window_size = 10 threshold = 0.1 @@ -43,15 +41,9 @@ def __init__( device_map="cuda", trust_remote_code=True, ) - self.tokenizer = AutoTokenizer.from_pretrained( - model_path, trust_remote_code=True - ) - self.common_cosy_model = CosyVoice( - os.path.join(model_path, "CosyVoice-300M-25Hz") - ) - self.music_cosy_model = CosyVoice( - os.path.join(model_path, "CosyVoice-300M-25Hz-Music") - ) + self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + self.common_cosy_model = CosyVoice(os.path.join(model_path, "CosyVoice-300M-25Hz")) + self.music_cosy_model = CosyVoice(os.path.join(model_path, 
"CosyVoice-300M-25Hz-Music")) self.encoder = encoder self.sys_prompt_dict = { "sys_prompt_for_rap": "请参考对话历史里的音色,用RAP方式将文本内容大声说唱出来。", @@ -62,35 +54,9 @@ def __init__( self.register_speakers() def __call__(self, text: str, prompt_speaker: str, clone_dict: dict | None = None): - if clone_dict: - clone_prompt_code, clone_prompt_token, clone_prompt_token_len, clone_speech_feat, clone_speech_feat_len, clone_speech_embedding = ( - self.preprocess_prompt_wav(clone_dict['wav_path']) - ) - prompt_speaker = clone_dict['speaker'] - self.speakers_info[prompt_speaker] = { - "prompt_text": clone_dict['prompt_text'], - "prompt_code": clone_prompt_code, - "cosy_speech_feat": clone_speech_feat.to(torch.bfloat16), - "cosy_speech_feat_len": clone_speech_feat_len, - "cosy_speech_embedding": clone_speech_embedding.to(torch.bfloat16), - "cosy_prompt_token": clone_prompt_token, - "cosy_prompt_token_len": clone_prompt_token_len, - } - - instruction_name = self.detect_instruction_name(text) - prompt_speaker_info = self.speakers_info[prompt_speaker] - - if instruction_name in ("RAP", "哼唱"): - if not clone_dict: - prompt_speaker_info = self.speakers_info[ - f"{prompt_speaker}{instruction_name}" - ] - cosy_model = self.music_cosy_model - else: - cosy_model = self.common_cosy_model - - if clone_dict: - prompt_speaker = '' + prompt_speaker_info, prompt_speaker, cosy_model = self.prepare_prompt( + text, prompt_speaker, clone_dict=clone_dict + ) token_ids = self.tokenize( text, @@ -126,9 +92,14 @@ def register_speakers(self): for speaker_id, prompt_text in speakers_info.items(): prompt_wav_path = f"speakers/{speaker_id}_prompt.wav" - prompt_code, prompt_token, prompt_token_len, speech_feat, speech_feat_len, speech_embedding = ( - self.preprocess_prompt_wav(prompt_wav_path) - ) + ( + prompt_code, + prompt_token, + prompt_token_len, + speech_feat, + speech_feat_len, + speech_embedding, + ) = self.preprocess_prompt_wav(prompt_wav_path) self.speakers_info[speaker_id] = { "prompt_text": prompt_text, @@ -149,9 +120,7 @@ def detect_instruction_name(self, text): instruction_name = instruction.strip("()()") return instruction_name - def tokenize( - self, text: str, prompt_text: str, prompt_speaker: str, prompt_code: list - ): + def tokenize(self, text: str, prompt_text: str, prompt_speaker: str, prompt_code: list): rap_or_vocal = self.detect_instruction_name(text) in ("RAP", "哼唱") if rap_or_vocal: @@ -197,23 +166,21 @@ def tokenize( ) return history - def preprocess_prompt_wav(self, prompt_wav_path : str): + def preprocess_prompt_wav(self, prompt_wav_path: str): prompt_wav, prompt_wav_sr = torchaudio.load(prompt_wav_path) - prompt_wav_16k = torchaudio.transforms.Resample( - orig_freq=prompt_wav_sr, new_freq=16000 - )(prompt_wav) - prompt_wav_22k = torchaudio.transforms.Resample( - orig_freq=prompt_wav_sr, new_freq=22050 - )(prompt_wav) - - speech_feat, speech_feat_len = ( - self.common_cosy_model.frontend._extract_speech_feat(prompt_wav_22k) + prompt_wav_16k = torchaudio.transforms.Resample(orig_freq=prompt_wav_sr, new_freq=16000)( + prompt_wav ) - speech_embedding = self.common_cosy_model.frontend._extract_spk_embedding( - prompt_wav_16k + prompt_wav_22k = torchaudio.transforms.Resample(orig_freq=prompt_wav_sr, new_freq=22050)( + prompt_wav ) + speech_feat, speech_feat_len = self.common_cosy_model.frontend._extract_speech_feat( + prompt_wav_22k + ) + speech_embedding = self.common_cosy_model.frontend._extract_spk_embedding(prompt_wav_16k) + prompt_code, _, _ = self.encoder.wav2token(prompt_wav, prompt_wav_sr) prompt_token = 
torch.tensor([prompt_code], dtype=torch.long) - 65536 prompt_token_len = torch.tensor([prompt_token.shape[1]], dtype=torch.long) @@ -225,4 +192,40 @@ def preprocess_prompt_wav(self, prompt_wav_path : str): speech_feat, speech_feat_len, speech_embedding, - ) \ No newline at end of file + ) + + def prepare_prompt(self, text: str, prompt_speaker: str, clone_dict: dict | None = None): + if clone_dict: + ( + clone_prompt_code, + clone_prompt_token, + clone_prompt_token_len, + clone_speech_feat, + clone_speech_feat_len, + clone_speech_embedding, + ) = self.preprocess_prompt_wav(clone_dict["wav_path"]) + prompt_speaker = clone_dict["speaker"] + self.speakers_info[prompt_speaker] = { + "prompt_text": clone_dict["prompt_text"], + "prompt_code": clone_prompt_code, + "cosy_speech_feat": clone_speech_feat.to(torch.bfloat16), + "cosy_speech_feat_len": clone_speech_feat_len, + "cosy_speech_embedding": clone_speech_embedding.to(torch.bfloat16), + "cosy_prompt_token": clone_prompt_token, + "cosy_prompt_token_len": clone_prompt_token_len, + } + + instruction_name = self.detect_instruction_name(text) + prompt_speaker_info = self.speakers_info[prompt_speaker] + + if instruction_name in ("RAP", "哼唱"): + if not clone_dict: + prompt_speaker_info = self.speakers_info[f"{prompt_speaker}{instruction_name}"] + cosy_model = self.music_cosy_model + else: + cosy_model = self.common_cosy_model + + if clone_dict: + prompt_speaker = "" + + return prompt_speaker, prompt_speaker_info, cosy_model From 1b096e86fc7465b04debab186a9a52421148af7c Mon Sep 17 00:00:00 2001 From: weedge Date: Thu, 20 Feb 2025 17:26:54 +0800 Subject: [PATCH 02/37] feat: add tts static batch stream generate waveform Signed-off-by: weedge --- streamer.py | 41 ++++++++++ tts.py | 224 ++++++++++++++++++++++++++++++++++++++-------------- utils.py | 77 +++++++++++++++--- 3 files changed, 271 insertions(+), 71 deletions(-) create mode 100644 streamer.py diff --git a/streamer.py b/streamer.py new file mode 100644 index 0000000..57bce40 --- /dev/null +++ b/streamer.py @@ -0,0 +1,41 @@ +from queue import Queue + +from transformers.generation.streamers import BaseStreamer + + +class TokenStreamer(BaseStreamer): + def __init__(self, skip_prompt: bool = False, timeout=None): + self.skip_prompt = skip_prompt + + # variables used in the streaming process + self.token_queue = Queue() + self.stop_signal = None + self.next_tokens_are_prompt = True + self.timeout = timeout + + def put(self, value): + if len(value.shape) > 1 and value.shape[0] > 1: + raise ValueError("TextStreamer only supports batch size 1") + elif len(value.shape) > 1: + value = value[0] + # print(value) + + if self.skip_prompt and self.next_tokens_are_prompt: + self.next_tokens_are_prompt = False + return + + for token in value.tolist(): + self.token_queue.put(token) + + def end(self): + self.token_queue.put(self.stop_signal) + + def __iter__(self): + return self + + def __next__(self): + value = self.token_queue.get(timeout=self.timeout) + if value == self.stop_signal: + raise StopIteration() + else: + return value diff --git a/tts.py b/tts.py index e0c7036..82a37d8 100644 --- a/tts.py +++ b/tts.py @@ -1,6 +1,10 @@ +import math import os import re import json +from threading import Thread, Lock +import uuid + import torchaudio import torch @@ -9,12 +13,11 @@ from transformers.generation.utils import LogitsProcessorList from cosyvoice.cli.cosyvoice import CosyVoice +from streamer import TokenStreamer class RepetitionAwareLogitsProcessor(LogitsProcessor): - def __call__( - self, input_ids: 
torch.LongTensor, scores: torch.FloatTensor - ) -> torch.FloatTensor: + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: window_size = 10 threshold = 0.1 @@ -36,22 +39,20 @@ def __init__( self, model_path, encoder, + device_map=None, + stream_factor=2, + **kwargs, ): self.llm = AutoModelForCausalLM.from_pretrained( model_path, torch_dtype=torch.bfloat16, - device_map="cuda", + device_map="cuda" if not device_map else device_map, trust_remote_code=True, + **kwargs, ) - self.tokenizer = AutoTokenizer.from_pretrained( - model_path, trust_remote_code=True - ) - self.common_cosy_model = CosyVoice( - os.path.join(model_path, "CosyVoice-300M-25Hz") - ) - self.music_cosy_model = CosyVoice( - os.path.join(model_path, "CosyVoice-300M-25Hz-Music") - ) + self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + self.common_cosy_model = CosyVoice(os.path.join(model_path, "CosyVoice-300M-25Hz")) + self.music_cosy_model = CosyVoice(os.path.join(model_path, "CosyVoice-300M-25Hz-Music")) self.encoder = encoder self.sys_prompt_dict = { "sys_prompt_for_rap": "请参考对话历史里的音色,用RAP方式将文本内容大声说唱出来。", @@ -61,36 +62,20 @@ def __init__( } self.register_speakers() - def __call__(self, text: str, prompt_speaker: str, clone_dict: dict | None = None): - if clone_dict: - clone_prompt_code, clone_prompt_token, clone_prompt_token_len, clone_speech_feat, clone_speech_feat_len, clone_speech_embedding = ( - self.preprocess_prompt_wav(clone_dict['wav_path']) - ) - prompt_speaker = clone_dict['speaker'] - self.speakers_info[prompt_speaker] = { - "prompt_text": clone_dict['prompt_text'], - "prompt_code": clone_prompt_code, - "cosy_speech_feat": clone_speech_feat.to(torch.bfloat16), - "cosy_speech_feat_len": clone_speech_feat_len, - "cosy_speech_embedding": clone_speech_embedding.to(torch.bfloat16), - "cosy_prompt_token": clone_prompt_token, - "cosy_prompt_token_len": clone_prompt_token_len, - } + self.streamer = TokenStreamer + assert ( + stream_factor >= 2 + ), "stream_factor must >=2 increase for better speech quality, but rft slow (speech quality vs rft)" + self.stream_factor = stream_factor # >=2 increase for better speech quality, but rft slow (speech quality vs rft) - instruction_name = self.detect_instruction_name(text) - prompt_speaker_info = self.speakers_info[prompt_speaker] - - if instruction_name in ("RAP", "哼唱"): - if not clone_dict: - prompt_speaker_info = self.speakers_info[ - f"{prompt_speaker}{instruction_name}" - ] - cosy_model = self.music_cosy_model - else: - cosy_model = self.common_cosy_model + # session ctx dict with lock, maybe need a session class + self.session_lm_generat_lock = Lock() + self.session_lm_generated_ids = {} # session_id: ids(ptr) - if clone_dict: - prompt_speaker = '' + def __call__(self, text: str, prompt_speaker: str, clone_dict: dict | None = None): + prompt_speaker, prompt_speaker_info, cosy_model = self.preprocess_promt( + text, prompt_speaker, clone_dict=clone_dict + ) token_ids = self.tokenize( text, @@ -126,9 +111,14 @@ def register_speakers(self): for speaker_id, prompt_text in speakers_info.items(): prompt_wav_path = f"speakers/{speaker_id}_prompt.wav" - prompt_code, prompt_token, prompt_token_len, speech_feat, speech_feat_len, speech_embedding = ( - self.preprocess_prompt_wav(prompt_wav_path) - ) + ( + prompt_code, + prompt_token, + prompt_token_len, + speech_feat, + speech_feat_len, + speech_embedding, + ) = self.preprocess_prompt_wav(prompt_wav_path) self.speakers_info[speaker_id] = { "prompt_text": 
prompt_text, @@ -149,9 +139,7 @@ def detect_instruction_name(self, text): instruction_name = instruction.strip("()()") return instruction_name - def tokenize( - self, text: str, prompt_text: str, prompt_speaker: str, prompt_code: list - ): + def tokenize(self, text: str, prompt_text: str, prompt_speaker: str, prompt_code: list): rap_or_vocal = self.detect_instruction_name(text) in ("RAP", "哼唱") if rap_or_vocal: @@ -197,24 +185,22 @@ def tokenize( ) return history - def preprocess_prompt_wav(self, prompt_wav_path : str): + def preprocess_prompt_wav(self, prompt_wav_path: str): prompt_wav, prompt_wav_sr = torchaudio.load(prompt_wav_path) if prompt_wav.shape[0] > 1: prompt_wav = prompt_wav.mean(dim=0, keepdim=True) # 将多通道音频转换为单通道 - prompt_wav_16k = torchaudio.transforms.Resample( - orig_freq=prompt_wav_sr, new_freq=16000 - )(prompt_wav) - prompt_wav_22k = torchaudio.transforms.Resample( - orig_freq=prompt_wav_sr, new_freq=22050 - )(prompt_wav) - - speech_feat, speech_feat_len = ( - self.common_cosy_model.frontend._extract_speech_feat(prompt_wav_22k) + prompt_wav_16k = torchaudio.transforms.Resample(orig_freq=prompt_wav_sr, new_freq=16000)( + prompt_wav ) - speech_embedding = self.common_cosy_model.frontend._extract_spk_embedding( - prompt_wav_16k + prompt_wav_22k = torchaudio.transforms.Resample(orig_freq=prompt_wav_sr, new_freq=22050)( + prompt_wav ) + speech_feat, speech_feat_len = self.common_cosy_model.frontend._extract_speech_feat( + prompt_wav_22k + ) + speech_embedding = self.common_cosy_model.frontend._extract_spk_embedding(prompt_wav_16k) + prompt_code, _, _ = self.encoder.wav2token(prompt_wav, prompt_wav_sr) prompt_token = torch.tensor([prompt_code], dtype=torch.long) - 65536 prompt_token_len = torch.tensor([prompt_token.shape[1]], dtype=torch.long) @@ -226,4 +212,122 @@ def preprocess_prompt_wav(self, prompt_wav_path : str): speech_feat, speech_feat_len, speech_embedding, - ) \ No newline at end of file + ) + + def preprocess_promt(self, text: str, prompt_speaker: str, clone_dict: dict | None = None): + if clone_dict: + ( + clone_prompt_code, + clone_prompt_token, + clone_prompt_token_len, + clone_speech_feat, + clone_speech_feat_len, + clone_speech_embedding, + ) = self.preprocess_prompt_wav(clone_dict["wav_path"]) + prompt_speaker = clone_dict["speaker"] + self.speakers_info[prompt_speaker] = { + "prompt_text": clone_dict["prompt_text"], + "prompt_code": clone_prompt_code, + "cosy_speech_feat": clone_speech_feat.to(torch.bfloat16), + "cosy_speech_feat_len": clone_speech_feat_len, + "cosy_speech_embedding": clone_speech_embedding.to(torch.bfloat16), + "cosy_prompt_token": clone_prompt_token, + "cosy_prompt_token_len": clone_prompt_token_len, + } + + instruction_name = self.detect_instruction_name(text) + prompt_speaker_info = self.speakers_info[prompt_speaker] + + if instruction_name in ("RAP", "哼唱"): + if not clone_dict: + prompt_speaker_info = self.speakers_info[f"{prompt_speaker}{instruction_name}"] + cosy_model = self.music_cosy_model + else: + cosy_model = self.common_cosy_model + + return prompt_speaker, prompt_speaker_info, cosy_model + + def static_batch_stream( + self, + text: str, + prompt_speaker: str, + clone_dict: dict | None = None, + session_id: str = str(uuid.uuid4()), + ): + """ + - step1 lm stream generate token + - static batch size to gen waveform + - flow: audio vq tokens to mel + - hifi: mel to waveform + """ + prompt_speaker, prompt_speaker_info, cosy_model = self.preprocess_promt( + text, prompt_speaker, clone_dict=clone_dict + ) + + token_ids = 
self.tokenize( + text, + prompt_speaker_info["prompt_text"], + prompt_speaker, + prompt_speaker_info["prompt_code"], + ) + + generation_kwargs = dict( + input_ids=torch.tensor([token_ids]).to(torch.long).to("cuda"), + eos_token_id=3, + streamer=self.streamer, + max_length=8192, + temperature=0.7, + do_sample=True, + logits_processor=LogitsProcessorList([RepetitionAwareLogitsProcessor()]), + ) + # print("generation_kwargs", generation_kwargs) + + thread = Thread(target=self.llm.generate, kwargs=generation_kwargs) + thread.start() + + with self.session_lm_generat_lock: + self.session_lm_generated_ids[session_id] = [] + + batch_size = math.ceil(self.stream_factor * cosy_model.model.flow.input_frame_rate) + for token_id in self.streamer: + # print(token_id, end=",", flush=True) + if token_id == 3: # skip <|EOT|>, break + break + self.session_lm_generated_ids[session_id].append(token_id) + if len(self.session_lm_generated_ids[session_id]) % batch_size == 0: + batch = ( + torch.tensor(self.session_lm_generated_ids[session_id]) + .unsqueeze(0) + .to(self.device) + ) # [T] -> [1,T] + # Process each batch + sub_tts_speech = self.common_cosy_model.token_to_wav_offline( + batch, + prompt_speaker_info["cosy_speech_feat"].to(torch.bfloat16), + prompt_speaker_info["cosy_speech_feat_len"], + prompt_speaker_info["cosy_prompt_token"], + prompt_speaker_info["cosy_prompt_token_len"], + prompt_speaker_info["cosy_speech_embedding"].to(torch.bfloat16), + ) + yield {"tts_speech": sub_tts_speech} + with self.session_lm_generat_lock: + self.session_lm_generated_ids[session_id] = [] + + if len(self.session_lm_generated_ids[session_id]) > 0: + batch = ( + torch.tensor(self.session_lm_generated_ids[session_id]).unsqueeze(0).to(self.device) + ) # [T] -> [1,T] + # Process each batch + sub_tts_speech = self.common_cosy_model.token_to_wav_offline( + batch, + prompt_speaker_info["cosy_speech_feat"].to(torch.bfloat16), + prompt_speaker_info["cosy_speech_feat_len"], + prompt_speaker_info["cosy_prompt_token"], + prompt_speaker_info["cosy_prompt_token_len"], + prompt_speaker_info["cosy_speech_embedding"].to(torch.bfloat16), + ) + yield {"tts_speech": sub_tts_speech} + + with self.lock: + self.session_lm_generated_ids.pop(session_id) + torch.cuda.empty_cache() diff --git a/utils.py b/utils.py index 2c98513..f4cb260 100644 --- a/utils.py +++ b/utils.py @@ -33,18 +33,14 @@ def trim_silence(audio, sr, keep_left_time=0.05, keep_right_time=0.22, hop_size= if start_idx > 0: trim_wav = trim_wav[start_idx:] else: - trim_wav = np.pad( - trim_wav, (abs(start_idx), 0), mode="constant", constant_values=0.0 - ) + trim_wav = np.pad(trim_wav, (abs(start_idx), 0), mode="constant", constant_values=0.0) wav_len = len(trim_wav) out_len = int(num_frames * hop_size + (keep_left_time + keep_right_time) * sr) if out_len < wav_len: trim_wav = trim_wav[:out_len] else: - trim_wav = np.pad( - trim_wav, (0, (out_len - wav_len)), mode="constant", constant_values=0.0 - ) + trim_wav = np.pad(trim_wav, (0, (out_len - wav_len)), mode="constant", constant_values=0.0) return trim_wav @@ -109,9 +105,7 @@ def audio_resample(audio16bit_torch, result_sr, target_sample_rate): def norm_audio(audio16bit_torch): # 直接 归一化处理。 audio16bit_torch = audio16bit_torch.numpy() - audio16bit_torch = ( - audio16bit_torch / np.abs(audio16bit_torch).max() * 32767 - ).astype(np.int16) + audio16bit_torch = (audio16bit_torch / np.abs(audio16bit_torch).max() * 32767).astype(np.int16) audio16bit_torch = torch.from_numpy(audio16bit_torch) return audio16bit_torch @@ -142,8 +136,7 @@ def 
energy_norm_fn(wav): def get_audio_tokens(audio_tokens: str) -> list[int]: audio_tokens = audio_tokens.split(">", "")) + 65536 - for token in audio_tokens + int(token.replace("", "")) + 65536 for token in audio_tokens ] return audio_tokens @@ -152,3 +145,65 @@ def load_audio(audio_path: str): audio_wav, sr = torchaudio.load(audio_path) audio_wav = audio_wav.mean(dim=0, keepdim=True) return audio_wav, sr + + +def splite_batches(tensor_audio_token_ids: torch.Tensor, batch_size: int): + """ + splite batches of audio token IDs. + # Assuming tensor_audio_token_ids is already defined as in your original code + # and has shape (1, sequence_length) + + Args: + tensor_audio_token_ids: A tensor of audio token IDs. + batch_size: The desired batch size. + + Returns: + A list of tensors, where each tensor is a batch of audio token IDs. + """ + sequence_length = tensor_audio_token_ids.shape[1] + num_batches = (sequence_length + batch_size - 1) // batch_size + batched_ids = [] + + for i in range(num_batches): + start_index = i * batch_size + end_index = min((i + 1) * batch_size, sequence_length) + batch = tensor_audio_token_ids[:, start_index:end_index] + batched_ids.append(batch) + + return batched_ids + + +def merge_tensors(sub_tts_speechs: torch.Tensor): + """ + Merges a list of tensors into a single tensor. + # Assuming sub_tts_speechs is a list of tensors + # and all tensors in the list have the same number of channels + # and has shape (1, sequence_length) + # but possibly different lengths. + + Args: + sub_tts_speechs: A list of tensors. + + Returns: + A single tensor with all the sub tensors concatenated along the time dimension. + Returns None if the input list is empty or contains tensors with inconsistent shapes. + """ + if not sub_tts_speechs: + return None + + num_channels = sub_tts_speechs[0].shape[0] + total_length = sum(tensor.shape[1] for tensor in sub_tts_speechs) + merged_tensor = torch.empty( + num_channels, total_length, dtype=sub_tts_speechs[0].dtype, device=sub_tts_speechs[0].device + ) + current_position = 0 + + for tensor in sub_tts_speechs: + if tensor.shape[0] != num_channels: + print("Error: Tensors have inconsistent number of channels.") + return None + + merged_tensor[:, current_position : current_position + tensor.shape[1]] = tensor + current_position += tensor.shape[1] + + return merged_tensor From a79c5af8b9fef378bd7042ef8a7192821aa681f7 Mon Sep 17 00:00:00 2001 From: weedge Date: Thu, 20 Feb 2025 17:39:30 +0800 Subject: [PATCH 03/37] feat: add tts stream inference Signed-off-by: weedge --- tts_inference_stream.py | 43 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 tts_inference_stream.py diff --git a/tts_inference_stream.py b/tts_inference_stream.py new file mode 100644 index 0000000..fdfeea4 --- /dev/null +++ b/tts_inference_stream.py @@ -0,0 +1,43 @@ +import torchaudio +import argparse +from tts import StepAudioTTS +from tokenizer import StepAudioTokenizer +from utils import load_audio +import os + + +def main(): + parser = argparse.ArgumentParser(description="StepAudio Stream Inference") + parser.add_argument("--model-path", type=str, required=True, help="Base path for model files") + parser.add_argument( + "--synthesis-type", type=str, default="tts", help="Use tts or Clone for Synthesis" + ) + parser.add_argument( + "--output-path", type=str, required=True, help="Output path for synthesis audios" + ) + parser.add_argument( + "--stream", type=str, default="static_batch", help="synthesis audios with streaming" + ) + args = 
parser.parse_args() + os.makedirs(f"{args.output_path}", exist_ok=True) + + encoder = StepAudioTokenizer(f"{args.model_path}/Step-Audio-Tokenizer") + tts_engine = StepAudioTTS(f"{args.model_path}/Step-Audio-TTS-3B", encoder) + + if args.synthesis_type == "tts": + text = "(RAP)君不见黄河之水天上来,奔流到海不复回。君不见高堂明镜悲白发,朝如青丝暮成雪。人生得意须尽欢,莫使金樽空对月。天生我材必有用,千金散尽还复来。" + output_audio, sr = tts_engine.static_batch_stream(text, "Tingting") + torchaudio.save(f"{args.output_path}/output_tts.wav", output_audio, sr) + else: + clone_speaker = { + "speaker": "test", + "prompt_text": "叫做秋风起蟹脚痒,啊,什么意思呢?就是说这秋风一起啊,螃蟹就该上市了。", + "wav_path": "examples/prompt_wav_yuqian.wav", + } + text_clone = "万物之始,大道至简,衍化至繁。君不见黄河之水天上来,奔流到海不复回。君不见高堂明镜悲白发,朝如青丝暮成雪。人生得意须尽欢,莫使金樽空对月。天生我材必有用,千金散尽还复来。" + output_audio, sr = tts_engine.static_batch_stream(text_clone, "", clone_speaker) + torchaudio.save(f"{args.output_path}/output_clone.wav", output_audio, sr) + + +if __name__ == "__main__": + main() From c8677282927009c820aae69415f64e1d61aae06d Mon Sep 17 00:00:00 2001 From: weedge Date: Thu, 20 Feb 2025 18:00:59 +0800 Subject: [PATCH 04/37] feat: add tts static batch stream to merge Signed-off-by: weedge --- tts.py | 4 ++-- tts_inference_stream.py | 22 +++++++++++++++++----- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/tts.py b/tts.py index 82a37d8..03fc0d6 100644 --- a/tts.py +++ b/tts.py @@ -309,7 +309,7 @@ def static_batch_stream( prompt_speaker_info["cosy_prompt_token_len"], prompt_speaker_info["cosy_speech_embedding"].to(torch.bfloat16), ) - yield {"tts_speech": sub_tts_speech} + yield {"tts_speech": sub_tts_speech, "sample_rate": 22050} with self.session_lm_generat_lock: self.session_lm_generated_ids[session_id] = [] @@ -326,7 +326,7 @@ def static_batch_stream( prompt_speaker_info["cosy_prompt_token_len"], prompt_speaker_info["cosy_speech_embedding"].to(torch.bfloat16), ) - yield {"tts_speech": sub_tts_speech} + yield {"tts_speech": sub_tts_speech, "sample_rate": 22050} with self.lock: self.session_lm_generated_ids.pop(session_id) diff --git a/tts_inference_stream.py b/tts_inference_stream.py index fdfeea4..9f6fe01 100644 --- a/tts_inference_stream.py +++ b/tts_inference_stream.py @@ -2,7 +2,7 @@ import argparse from tts import StepAudioTTS from tokenizer import StepAudioTokenizer -from utils import load_audio +from utils import merge_tensors import os @@ -26,8 +26,14 @@ def main(): if args.synthesis_type == "tts": text = "(RAP)君不见黄河之水天上来,奔流到海不复回。君不见高堂明镜悲白发,朝如青丝暮成雪。人生得意须尽欢,莫使金樽空对月。天生我材必有用,千金散尽还复来。" - output_audio, sr = tts_engine.static_batch_stream(text, "Tingting") - torchaudio.save(f"{args.output_path}/output_tts.wav", output_audio, sr) + batch_stream = tts_engine.static_batch_stream(text, "Tingting") + sub_tts_speechs = [] + sr = 22050 + for item in batch_stream: + sr = item["sample_rate"] + sub_tts_speechs.append(item["tts_speech"]) + output_audio = merge_tensors(sub_tts_speechs) # [1,T] + torchaudio.save(f"{args.output_path}/output_tts_stream.wav", output_audio, sr) else: clone_speaker = { "speaker": "test", @@ -35,8 +41,14 @@ def main(): "wav_path": "examples/prompt_wav_yuqian.wav", } text_clone = "万物之始,大道至简,衍化至繁。君不见黄河之水天上来,奔流到海不复回。君不见高堂明镜悲白发,朝如青丝暮成雪。人生得意须尽欢,莫使金樽空对月。天生我材必有用,千金散尽还复来。" - output_audio, sr = tts_engine.static_batch_stream(text_clone, "", clone_speaker) - torchaudio.save(f"{args.output_path}/output_clone.wav", output_audio, sr) + batch_stream = tts_engine.static_batch_stream(text_clone, "", clone_speaker) + sub_tts_speechs = [] + sr = 22050 + for item in batch_stream: + sr = item["sample_rate"] + 
sub_tts_speechs.append(item["tts_speech"]) + output_audio = merge_tensors(sub_tts_speechs) # [1,T] + torchaudio.save(f"{args.output_path}/output_clone_stream.wav", output_audio, sr) if __name__ == "__main__": From 46f22e9bada70d06acf0bb0043dfef8dbc76f86e Mon Sep 17 00:00:00 2001 From: weedge Date: Thu, 20 Feb 2025 18:22:12 +0800 Subject: [PATCH 05/37] streamer skip prefix prompt Signed-off-by: weedge --- tts.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tts.py b/tts.py index 03fc0d6..5618acf 100644 --- a/tts.py +++ b/tts.py @@ -62,7 +62,7 @@ def __init__( } self.register_speakers() - self.streamer = TokenStreamer + self.streamer = TokenStreamer(skip_prompt=True) assert ( stream_factor >= 2 ), "stream_factor must >=2 increase for better speech quality, but rft slow (speech quality vs rft)" @@ -290,7 +290,7 @@ def static_batch_stream( batch_size = math.ceil(self.stream_factor * cosy_model.model.flow.input_frame_rate) for token_id in self.streamer: - # print(token_id, end=",", flush=True) + print(token_id, end=",", flush=True) if token_id == 3: # skip <|EOT|>, break break self.session_lm_generated_ids[session_id].append(token_id) @@ -298,10 +298,10 @@ def static_batch_stream( batch = ( torch.tensor(self.session_lm_generated_ids[session_id]) .unsqueeze(0) - .to(self.device) + .to(cosy_model.model.device) ) # [T] -> [1,T] # Process each batch - sub_tts_speech = self.common_cosy_model.token_to_wav_offline( + sub_tts_speech = cosy_model.token_to_wav_offline( batch, prompt_speaker_info["cosy_speech_feat"].to(torch.bfloat16), prompt_speaker_info["cosy_speech_feat_len"], @@ -315,10 +315,12 @@ def static_batch_stream( if len(self.session_lm_generated_ids[session_id]) > 0: batch = ( - torch.tensor(self.session_lm_generated_ids[session_id]).unsqueeze(0).to(self.device) + torch.tensor(self.session_lm_generated_ids[session_id]) + .unsqueeze(0) + .to(cosy_model.model.device) ) # [T] -> [1,T] # Process each batch - sub_tts_speech = self.common_cosy_model.token_to_wav_offline( + sub_tts_speech = cosy_model.token_to_wav_offline( batch, prompt_speaker_info["cosy_speech_feat"].to(torch.bfloat16), prompt_speaker_info["cosy_speech_feat_len"], From 332b5f934e4bca80ea14fdaeaa71ec0bcd393733 Mon Sep 17 00:00:00 2001 From: weedge Date: Thu, 20 Feb 2025 19:09:39 +0800 Subject: [PATCH 06/37] fix: step1lm generated token ids - 65536 to get vq codes Signed-off-by: weedge --- tts.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tts.py b/tts.py index 5618acf..514ef3a 100644 --- a/tts.py +++ b/tts.py @@ -299,6 +299,7 @@ def static_batch_stream( torch.tensor(self.session_lm_generated_ids[session_id]) .unsqueeze(0) .to(cosy_model.model.device) + - 65536 ) # [T] -> [1,T] # Process each batch sub_tts_speech = cosy_model.token_to_wav_offline( @@ -318,6 +319,7 @@ def static_batch_stream( torch.tensor(self.session_lm_generated_ids[session_id]) .unsqueeze(0) .to(cosy_model.model.device) + - 65536 ) # [T] -> [1,T] # Process each batch sub_tts_speech = cosy_model.token_to_wav_offline( From c5314f2dfc8bb5909e99dff0158005fe94a48007 Mon Sep 17 00:00:00 2001 From: weedge Date: Thu, 20 Feb 2025 19:15:27 +0800 Subject: [PATCH 07/37] fix: lock Signed-off-by: weedge --- tts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tts.py b/tts.py index 514ef3a..3c09e21 100644 --- a/tts.py +++ b/tts.py @@ -332,6 +332,6 @@ def static_batch_stream( ) yield {"tts_speech": sub_tts_speech, "sample_rate": 22050} - with self.lock: + with self.session_lm_generat_lock: 
self.session_lm_generated_ids.pop(session_id) torch.cuda.empty_cache() From 0634c4372b8aa13cbb32a7ca093ddacac36fa2c5 Mon Sep 17 00:00:00 2001 From: weedge Date: Thu, 20 Feb 2025 19:33:43 +0800 Subject: [PATCH 08/37] feat: add stream_factor cmd param and TTS_TEXT env param Signed-off-by: weedge --- tts.py | 9 +++++---- tts_inference_stream.py | 18 ++++++++++++++---- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/tts.py b/tts.py index 3c09e21..d966a60 100644 --- a/tts.py +++ b/tts.py @@ -39,8 +39,8 @@ def __init__( self, model_path, encoder, - device_map=None, - stream_factor=2, + device_map: str | dict | None = None, + stream_factor: int = 2, **kwargs, ): self.llm = AutoModelForCausalLM.from_pretrained( @@ -263,6 +263,7 @@ def static_batch_stream( prompt_speaker, prompt_speaker_info, cosy_model = self.preprocess_promt( text, prompt_speaker, clone_dict=clone_dict ) + output_audio_sample_rate = cosy_model.model.hift.sampling_rate token_ids = self.tokenize( text, @@ -310,7 +311,7 @@ def static_batch_stream( prompt_speaker_info["cosy_prompt_token_len"], prompt_speaker_info["cosy_speech_embedding"].to(torch.bfloat16), ) - yield {"tts_speech": sub_tts_speech, "sample_rate": 22050} + yield {"tts_speech": sub_tts_speech, "sample_rate": output_audio_sample_rate} with self.session_lm_generat_lock: self.session_lm_generated_ids[session_id] = [] @@ -330,7 +331,7 @@ def static_batch_stream( prompt_speaker_info["cosy_prompt_token_len"], prompt_speaker_info["cosy_speech_embedding"].to(torch.bfloat16), ) - yield {"tts_speech": sub_tts_speech, "sample_rate": 22050} + yield {"tts_speech": sub_tts_speech, "sample_rate": output_audio_sample_rate} with self.session_lm_generat_lock: self.session_lm_generated_ids.pop(session_id) diff --git a/tts_inference_stream.py b/tts_inference_stream.py index 9f6fe01..67191a5 100644 --- a/tts_inference_stream.py +++ b/tts_inference_stream.py @@ -1,9 +1,14 @@ -import torchaudio +import os import argparse -from tts import StepAudioTTS +from dotenv import load_dotenv + +import torchaudio + from tokenizer import StepAudioTokenizer from utils import merge_tensors -import os +from tts import StepAudioTTS + +load_dotenv(override=True) def main(): @@ -16,7 +21,10 @@ def main(): "--output-path", type=str, required=True, help="Output path for synthesis audios" ) parser.add_argument( - "--stream", type=str, default="static_batch", help="synthesis audios with streaming" + "--stream", type=str, default="static_batch", help="Synthesis audios with streaming" + ) + parser.add_argument( + "--stream_factor", type=int, default=2, help="Synthesis audios stream factor" ) args = parser.parse_args() os.makedirs(f"{args.output_path}", exist_ok=True) @@ -26,6 +34,7 @@ def main(): if args.synthesis_type == "tts": text = "(RAP)君不见黄河之水天上来,奔流到海不复回。君不见高堂明镜悲白发,朝如青丝暮成雪。人生得意须尽欢,莫使金樽空对月。天生我材必有用,千金散尽还复来。" + text = os.getenv("TTS_TEXT", text) batch_stream = tts_engine.static_batch_stream(text, "Tingting") sub_tts_speechs = [] sr = 22050 @@ -41,6 +50,7 @@ def main(): "wav_path": "examples/prompt_wav_yuqian.wav", } text_clone = "万物之始,大道至简,衍化至繁。君不见黄河之水天上来,奔流到海不复回。君不见高堂明镜悲白发,朝如青丝暮成雪。人生得意须尽欢,莫使金樽空对月。天生我材必有用,千金散尽还复来。" + text_clone = os.getenv("TTS_TEXT", text_clone) batch_stream = tts_engine.static_batch_stream(text_clone, "", clone_speaker) sub_tts_speechs = [] sr = 22050 From de512a48fd05d647bf5a4f1cca897e01602fc036 Mon Sep 17 00:00:00 2001 From: weedge Date: Thu, 20 Feb 2025 19:36:03 +0800 Subject: [PATCH 09/37] feat: add stream-factor cmd param and close debug print Signed-off-by: weedge --- 
tts.py | 2 +- tts_inference_stream.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tts.py b/tts.py index d966a60..9073d2c 100644 --- a/tts.py +++ b/tts.py @@ -291,7 +291,7 @@ def static_batch_stream( batch_size = math.ceil(self.stream_factor * cosy_model.model.flow.input_frame_rate) for token_id in self.streamer: - print(token_id, end=",", flush=True) + # print(token_id, end=",", flush=True) if token_id == 3: # skip <|EOT|>, break break self.session_lm_generated_ids[session_id].append(token_id) diff --git a/tts_inference_stream.py b/tts_inference_stream.py index 67191a5..67881c5 100644 --- a/tts_inference_stream.py +++ b/tts_inference_stream.py @@ -24,7 +24,7 @@ def main(): "--stream", type=str, default="static_batch", help="Synthesis audios with streaming" ) parser.add_argument( - "--stream_factor", type=int, default=2, help="Synthesis audios stream factor" + "--stream-factor", type=int, default=2, help="Synthesis audios stream factor" ) args = parser.parse_args() os.makedirs(f"{args.output_path}", exist_ok=True) From a62ae841f966d0ed8ef9002f1f7939de879a1412 Mon Sep 17 00:00:00 2001 From: weedge Date: Thu, 20 Feb 2025 19:47:54 +0800 Subject: [PATCH 10/37] fix: fast path to check params Signed-off-by: weedge --- requirements.txt | 1 + tts.py | 21 +++++++++++---------- tts_inference_stream.py | 6 +++++- 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/requirements.txt b/requirements.txt index c719357..046d5c9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,3 +18,4 @@ pillow sentencepiece funasr>=1.1.3 protobuf==5.29.3 +python-dotenv diff --git a/tts.py b/tts.py index 9073d2c..8cc7e3e 100644 --- a/tts.py +++ b/tts.py @@ -43,6 +43,17 @@ def __init__( stream_factor: int = 2, **kwargs, ): + # fast path to check params + assert ( + stream_factor >= 2 + ), "stream_factor must >=2 increase for better speech quality, but rft slow (speech quality vs rft)" + self.stream_factor = stream_factor # >=2 increase for better speech quality, but rft slow (speech quality vs rft) + self.streamer = TokenStreamer(skip_prompt=True) + + # session ctx dict with lock, maybe need a session class + self.session_lm_generat_lock = Lock() + self.session_lm_generated_ids = {} # session_id: ids(ptr) + self.llm = AutoModelForCausalLM.from_pretrained( model_path, torch_dtype=torch.bfloat16, @@ -62,16 +73,6 @@ def __init__( } self.register_speakers() - self.streamer = TokenStreamer(skip_prompt=True) - assert ( - stream_factor >= 2 - ), "stream_factor must >=2 increase for better speech quality, but rft slow (speech quality vs rft)" - self.stream_factor = stream_factor # >=2 increase for better speech quality, but rft slow (speech quality vs rft) - - # session ctx dict with lock, maybe need a session class - self.session_lm_generat_lock = Lock() - self.session_lm_generated_ids = {} # session_id: ids(ptr) - def __call__(self, text: str, prompt_speaker: str, clone_dict: dict | None = None): prompt_speaker, prompt_speaker_info, cosy_model = self.preprocess_promt( text, prompt_speaker, clone_dict=clone_dict diff --git a/tts_inference_stream.py b/tts_inference_stream.py index 67881c5..a75a263 100644 --- a/tts_inference_stream.py +++ b/tts_inference_stream.py @@ -30,7 +30,11 @@ def main(): os.makedirs(f"{args.output_path}", exist_ok=True) encoder = StepAudioTokenizer(f"{args.model_path}/Step-Audio-Tokenizer") - tts_engine = StepAudioTTS(f"{args.model_path}/Step-Audio-TTS-3B", encoder) + tts_engine = StepAudioTTS( + f"{args.model_path}/Step-Audio-TTS-3B", + encoder, + 
stream_factor=args.stream_factor, + ) if args.synthesis_type == "tts": text = "(RAP)君不见黄河之水天上来,奔流到海不复回。君不见高堂明镜悲白发,朝如青丝暮成雪。人生得意须尽欢,莫使金樽空对月。天生我材必有用,千金散尽还复来。" From e9d1c7ce5f3ca37d5c0556b42d698d5bde32cecf Mon Sep 17 00:00:00 2001 From: weedge Date: Thu, 20 Feb 2025 23:41:21 +0800 Subject: [PATCH 11/37] fix typo Signed-off-by: weedge --- tts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tts.py b/tts.py index 8cc7e3e..e33a925 100644 --- a/tts.py +++ b/tts.py @@ -46,8 +46,8 @@ def __init__( # fast path to check params assert ( stream_factor >= 2 - ), "stream_factor must >=2 increase for better speech quality, but rft slow (speech quality vs rft)" - self.stream_factor = stream_factor # >=2 increase for better speech quality, but rft slow (speech quality vs rft) + ), "stream_factor must >=2 increase for better speech quality, but rtf slow (speech quality vs rtf)" + self.stream_factor = stream_factor # >=2 increase for better speech quality, but rtf slow (speech quality vs rtf) self.streamer = TokenStreamer(skip_prompt=True) # session ctx dict with lock, maybe need a session class From 075feefc69c5397a96ffdeaa3e83bdd7f13bc1e6 Mon Sep 17 00:00:00 2001 From: weedge Date: Fri, 21 Feb 2025 00:10:46 +0800 Subject: [PATCH 12/37] fix typo Signed-off-by: weedge --- tts.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tts.py b/tts.py index e33a925..3047e1b 100644 --- a/tts.py +++ b/tts.py @@ -74,7 +74,7 @@ def __init__( self.register_speakers() def __call__(self, text: str, prompt_speaker: str, clone_dict: dict | None = None): - prompt_speaker, prompt_speaker_info, cosy_model = self.preprocess_promt( + prompt_speaker, prompt_speaker_info, cosy_model = self.preprocess_prompt( text, prompt_speaker, clone_dict=clone_dict ) @@ -215,7 +215,7 @@ def preprocess_prompt_wav(self, prompt_wav_path: str): speech_embedding, ) - def preprocess_promt(self, text: str, prompt_speaker: str, clone_dict: dict | None = None): + def preprocess_prompt(self, text: str, prompt_speaker: str, clone_dict: dict | None = None): if clone_dict: ( clone_prompt_code, @@ -261,7 +261,7 @@ def static_batch_stream( - flow: audio vq tokens to mel - hifi: mel to waveform """ - prompt_speaker, prompt_speaker_info, cosy_model = self.preprocess_promt( + prompt_speaker, prompt_speaker_info, cosy_model = self.preprocess_prompt( text, prompt_speaker, clone_dict=clone_dict ) output_audio_sample_rate = cosy_model.model.hift.sampling_rate From 7ce0a88996e91fe5b4956ab1f9447869ce60944a Mon Sep 17 00:00:00 2001 From: weedge Date: Fri, 21 Feb 2025 17:35:36 +0800 Subject: [PATCH 13/37] fix: tts instance share streamer -> gen session streamer Signed-off-by: weedge --- tts.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tts.py b/tts.py index 3047e1b..a6dbb9b 100644 --- a/tts.py +++ b/tts.py @@ -48,7 +48,6 @@ def __init__( stream_factor >= 2 ), "stream_factor must >=2 increase for better speech quality, but rtf slow (speech quality vs rtf)" self.stream_factor = stream_factor # >=2 increase for better speech quality, but rtf slow (speech quality vs rtf) - self.streamer = TokenStreamer(skip_prompt=True) # session ctx dict with lock, maybe need a session class self.session_lm_generat_lock = Lock() @@ -248,6 +247,7 @@ def preprocess_prompt(self, text: str, prompt_speaker: str, clone_dict: dict | N return prompt_speaker, prompt_speaker_info, cosy_model + @torch.inference_mode() def static_batch_stream( self, text: str, @@ -273,10 +273,13 @@ def static_batch_stream( 
prompt_speaker_info["prompt_code"], ) + # session streamer + streamer = TokenStreamer(skip_prompt=True) + generation_kwargs = dict( input_ids=torch.tensor([token_ids]).to(torch.long).to("cuda"), eos_token_id=3, - streamer=self.streamer, + streamer=streamer, max_length=8192, temperature=0.7, do_sample=True, @@ -291,7 +294,7 @@ def static_batch_stream( self.session_lm_generated_ids[session_id] = [] batch_size = math.ceil(self.stream_factor * cosy_model.model.flow.input_frame_rate) - for token_id in self.streamer: + for token_id in streamer: # print(token_id, end=",", flush=True) if token_id == 3: # skip <|EOT|>, break break From 90a23e0e325dc13aa2489b907aac628545c9b4dd Mon Sep 17 00:00:00 2001 From: weedge Date: Sat, 22 Feb 2025 22:28:53 +0800 Subject: [PATCH 14/37] add speaker_file_path Signed-off-by: weedge --- tts.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tts.py b/tts.py index 33f62ca..ce85499 100644 --- a/tts.py +++ b/tts.py @@ -83,7 +83,7 @@ def __init__( "sys_prompt_wo_spk": '作为一名卓越的声优演员,你的任务是根据文本中()或()括号内标注的情感、语种或方言、音乐哼唱、语音调整等标签,以丰富细腻的情感和自然顺畅的语调来朗读文本。\n# 情感标签涵盖了多种情绪状态,包括但不限于:\n- "高兴1"\n- "高兴2"\n- "生气1"\n- "生气2"\n- "悲伤1"\n- "撒娇1"\n\n# 语种或方言标签包含多种语言或方言,包括但不限于:\n- "中文"\n- "英文"\n- "韩语"\n- "日语"\n- "四川话"\n- "粤语"\n- "广东话"\n\n# 音乐哼唱标签包含多种类型歌曲哼唱,包括但不限于:\n- "RAP"\n- "哼唱"\n\n# 语音调整标签,包括但不限于:\n- "慢速1"\n- "慢速2"\n- "快速1"\n- "快速2"\n\n请在朗读时,根据这些情感标签的指示,调整你的情感、语气、语调和哼唱节奏,以确保文本的情感和意义得到准确而生动的传达,如果没有()或()括号,则根据文本语义内容自由演绎。', "sys_prompt_with_spk": '作为一名卓越的声优演员,你的任务是根据文本中()或()括号内标注的情感、语种或方言、音乐哼唱、语音调整等标签,以丰富细腻的情感和自然顺畅的语调来朗读文本。\n# 情感标签涵盖了多种情绪状态,包括但不限于:\n- "高兴1"\n- "高兴2"\n- "生气1"\n- "生气2"\n- "悲伤1"\n- "撒娇1"\n\n# 语种或方言标签包含多种语言或方言,包括但不限于:\n- "中文"\n- "英文"\n- "韩语"\n- "日语"\n- "四川话"\n- "粤语"\n- "广东话"\n\n# 音乐哼唱标签包含多种类型歌曲哼唱,包括但不限于:\n- "RAP"\n- "哼唱"\n\n# 语音调整标签,包括但不限于:\n- "慢速1"\n- "慢速2"\n- "快速1"\n- "快速2"\n\n请在朗读时,使用[{}]的声音,根据这些情感标签的指示,调整你的情感、语气、语调和哼唱节奏,以确保文本的情感和意义得到准确而生动的传达,如果没有()或()括号,则根据文本语义内容自由演绎。', } - self.register_speakers() + self.register_speakers(file_path=kwargs.get("speaker_file_path", "speakers/speakers_info.json")) def __call__(self, text: str, prompt_speaker: str, clone_dict: dict | None = None): prompt_speaker, prompt_speaker_info, cosy_model = self.preprocess_prompt( @@ -116,10 +116,10 @@ def __call__(self, text: str, prompt_speaker: str, clone_dict: dict | None = Non 22050, ) - def register_speakers(self): + def register_speakers(self, file_path:str="speakers/speakers_info.json"): self.speakers_info = {} - with open("speakers/speakers_info.json", "r") as f: + with open(file_path, "r") as f: speakers_info = json.load(f) for speaker_id, prompt_text in speakers_info.items(): From b3376689d416a894255cc97ee241d8e3f16635ee Mon Sep 17 00:00:00 2001 From: weedge Date: Sat, 22 Feb 2025 22:41:12 +0800 Subject: [PATCH 15/37] add speaker_file_path Signed-off-by: weedge --- tts.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tts.py b/tts.py index ce85499..b47e4eb 100644 --- a/tts.py +++ b/tts.py @@ -41,6 +41,7 @@ def __init__( encoder, device_map: str | dict | None = None, stream_factor: int = 2, + speaker_file_path: str = "speakers/speakers_info.json", **kwargs, ): # fast path to check params @@ -83,7 +84,7 @@ def __init__( "sys_prompt_wo_spk": '作为一名卓越的声优演员,你的任务是根据文本中()或()括号内标注的情感、语种或方言、音乐哼唱、语音调整等标签,以丰富细腻的情感和自然顺畅的语调来朗读文本。\n# 情感标签涵盖了多种情绪状态,包括但不限于:\n- "高兴1"\n- "高兴2"\n- "生气1"\n- "生气2"\n- "悲伤1"\n- "撒娇1"\n\n# 语种或方言标签包含多种语言或方言,包括但不限于:\n- "中文"\n- "英文"\n- "韩语"\n- "日语"\n- "四川话"\n- "粤语"\n- "广东话"\n\n# 音乐哼唱标签包含多种类型歌曲哼唱,包括但不限于:\n- "RAP"\n- "哼唱"\n\n# 语音调整标签,包括但不限于:\n- 
"慢速1"\n- "慢速2"\n- "快速1"\n- "快速2"\n\n请在朗读时,根据这些情感标签的指示,调整你的情感、语气、语调和哼唱节奏,以确保文本的情感和意义得到准确而生动的传达,如果没有()或()括号,则根据文本语义内容自由演绎。', "sys_prompt_with_spk": '作为一名卓越的声优演员,你的任务是根据文本中()或()括号内标注的情感、语种或方言、音乐哼唱、语音调整等标签,以丰富细腻的情感和自然顺畅的语调来朗读文本。\n# 情感标签涵盖了多种情绪状态,包括但不限于:\n- "高兴1"\n- "高兴2"\n- "生气1"\n- "生气2"\n- "悲伤1"\n- "撒娇1"\n\n# 语种或方言标签包含多种语言或方言,包括但不限于:\n- "中文"\n- "英文"\n- "韩语"\n- "日语"\n- "四川话"\n- "粤语"\n- "广东话"\n\n# 音乐哼唱标签包含多种类型歌曲哼唱,包括但不限于:\n- "RAP"\n- "哼唱"\n\n# 语音调整标签,包括但不限于:\n- "慢速1"\n- "慢速2"\n- "快速1"\n- "快速2"\n\n请在朗读时,使用[{}]的声音,根据这些情感标签的指示,调整你的情感、语气、语调和哼唱节奏,以确保文本的情感和意义得到准确而生动的传达,如果没有()或()括号,则根据文本语义内容自由演绎。', } - self.register_speakers(file_path=kwargs.get("speaker_file_path", "speakers/speakers_info.json")) + self.register_speakers(file_path=speaker_file_path) def __call__(self, text: str, prompt_speaker: str, clone_dict: dict | None = None): prompt_speaker, prompt_speaker_info, cosy_model = self.preprocess_prompt( From 1d8cb6e61e238585567e59a2950c63de06dea8c1 Mon Sep 17 00:00:00 2001 From: weedge Date: Sat, 22 Feb 2025 22:49:08 +0800 Subject: [PATCH 16/37] add speaker_file_path Signed-off-by: weedge --- tts.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/tts.py b/tts.py index b47e4eb..fb1c04e 100644 --- a/tts.py +++ b/tts.py @@ -41,9 +41,9 @@ def __init__( encoder, device_map: str | dict | None = None, stream_factor: int = 2, - speaker_file_path: str = "speakers/speakers_info.json", **kwargs, ): + speaker_file_path = kwargs.pop("speaker_file_path", "speakers/speakers_info.json") # fast path to check params assert ( stream_factor >= 2 @@ -58,11 +58,26 @@ def __init__( # if not, please manually set LD_LIBRARY_PATH=xxx/python3.10/site-packages/nvidia/cuda_nvrtc/lib try: if torch.__version__ >= "2.5": - torch.ops.load_library(os.path.join(model_path, 'lib/liboptimus_ths-torch2.5-cu124.cpython-310-x86_64-linux-gnu.so')) + torch.ops.load_library( + os.path.join( + model_path, + "lib/liboptimus_ths-torch2.5-cu124.cpython-310-x86_64-linux-gnu.so", + ) + ) elif torch.__version__ >= "2.3": - torch.ops.load_library(os.path.join(model_path, 'lib/liboptimus_ths-torch2.3-cu121.cpython-310-x86_64-linux-gnu.so')) + torch.ops.load_library( + os.path.join( + model_path, + "lib/liboptimus_ths-torch2.3-cu121.cpython-310-x86_64-linux-gnu.so", + ) + ) elif torch.__version__ >= "2.2": - torch.ops.load_library(os.path.join(model_path, 'lib/liboptimus_ths-torch2.2-cu121.cpython-310-x86_64-linux-gnu.so')) + torch.ops.load_library( + os.path.join( + model_path, + "lib/liboptimus_ths-torch2.2-cu121.cpython-310-x86_64-linux-gnu.so", + ) + ) print("Load optimus_ths successfully and flash attn would be enabled") except Exception as err: print(f"Fail to load optimus_ths and flash attn is disabled: {err}") @@ -117,7 +132,7 @@ def __call__(self, text: str, prompt_speaker: str, clone_dict: dict | None = Non 22050, ) - def register_speakers(self, file_path:str="speakers/speakers_info.json"): + def register_speakers(self, file_path: str = "speakers/speakers_info.json"): self.speakers_info = {} with open(file_path, "r") as f: From 68752ba7591fe9092a59f435ccdf15973bc2552d Mon Sep 17 00:00:00 2001 From: weedge Date: Sat, 22 Feb 2025 23:11:02 +0800 Subject: [PATCH 17/37] add speaker_file_path Signed-off-by: weedge --- tts.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tts.py b/tts.py index fb1c04e..6c3a512 100644 --- a/tts.py +++ b/tts.py @@ -43,7 +43,6 @@ def __init__( stream_factor: int = 2, **kwargs, ): - speaker_file_path = kwargs.pop("speaker_file_path", 
"speakers/speakers_info.json") # fast path to check params assert ( stream_factor >= 2 @@ -99,7 +98,7 @@ def __init__( "sys_prompt_wo_spk": '作为一名卓越的声优演员,你的任务是根据文本中()或()括号内标注的情感、语种或方言、音乐哼唱、语音调整等标签,以丰富细腻的情感和自然顺畅的语调来朗读文本。\n# 情感标签涵盖了多种情绪状态,包括但不限于:\n- "高兴1"\n- "高兴2"\n- "生气1"\n- "生气2"\n- "悲伤1"\n- "撒娇1"\n\n# 语种或方言标签包含多种语言或方言,包括但不限于:\n- "中文"\n- "英文"\n- "韩语"\n- "日语"\n- "四川话"\n- "粤语"\n- "广东话"\n\n# 音乐哼唱标签包含多种类型歌曲哼唱,包括但不限于:\n- "RAP"\n- "哼唱"\n\n# 语音调整标签,包括但不限于:\n- "慢速1"\n- "慢速2"\n- "快速1"\n- "快速2"\n\n请在朗读时,根据这些情感标签的指示,调整你的情感、语气、语调和哼唱节奏,以确保文本的情感和意义得到准确而生动的传达,如果没有()或()括号,则根据文本语义内容自由演绎。', "sys_prompt_with_spk": '作为一名卓越的声优演员,你的任务是根据文本中()或()括号内标注的情感、语种或方言、音乐哼唱、语音调整等标签,以丰富细腻的情感和自然顺畅的语调来朗读文本。\n# 情感标签涵盖了多种情绪状态,包括但不限于:\n- "高兴1"\n- "高兴2"\n- "生气1"\n- "生气2"\n- "悲伤1"\n- "撒娇1"\n\n# 语种或方言标签包含多种语言或方言,包括但不限于:\n- "中文"\n- "英文"\n- "韩语"\n- "日语"\n- "四川话"\n- "粤语"\n- "广东话"\n\n# 音乐哼唱标签包含多种类型歌曲哼唱,包括但不限于:\n- "RAP"\n- "哼唱"\n\n# 语音调整标签,包括但不限于:\n- "慢速1"\n- "慢速2"\n- "快速1"\n- "快速2"\n\n请在朗读时,使用[{}]的声音,根据这些情感标签的指示,调整你的情感、语气、语调和哼唱节奏,以确保文本的情感和意义得到准确而生动的传达,如果没有()或()括号,则根据文本语义内容自由演绎。', } - self.register_speakers(file_path=speaker_file_path) + self.register_speakers() def __call__(self, text: str, prompt_speaker: str, clone_dict: dict | None = None): prompt_speaker, prompt_speaker_info, cosy_model = self.preprocess_prompt( @@ -132,14 +131,16 @@ def __call__(self, text: str, prompt_speaker: str, clone_dict: dict | None = Non 22050, ) - def register_speakers(self, file_path: str = "speakers/speakers_info.json"): + def register_speakers(self): self.speakers_info = {} + cur_dir= os.path.dirname(os.path.abspath(__file__)) + file_path: str = os.path.join(cur_dir,"speakers/speakers_info.json") with open(file_path, "r") as f: speakers_info = json.load(f) for speaker_id, prompt_text in speakers_info.items(): - prompt_wav_path = f"speakers/{speaker_id}_prompt.wav" + prompt_wav_path = os.path.join(cur_dir,f"speakers/{speaker_id}_prompt.wav") ( prompt_code, prompt_token, From a833d67a27c19fa37664a5b1d48cae99815f36fb Mon Sep 17 00:00:00 2001 From: weedge Date: Sun, 23 Feb 2025 00:33:22 +0800 Subject: [PATCH 18/37] feat: add modal run step tts/voice Signed-off-by: weedge --- tts.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tts.py b/tts.py index 6c3a512..59f3402 100644 --- a/tts.py +++ b/tts.py @@ -135,7 +135,9 @@ def register_speakers(self): self.speakers_info = {} cur_dir= os.path.dirname(os.path.abspath(__file__)) + print(cur_dir) file_path: str = os.path.join(cur_dir,"speakers/speakers_info.json") + print(file_path) with open(file_path, "r") as f: speakers_info = json.load(f) From 3f17784a30e346b9832304c696409ed7873f1c09 Mon Sep 17 00:00:00 2001 From: weedge Date: Sun, 23 Feb 2025 18:20:59 +0800 Subject: [PATCH 19/37] add ThreadSafeDict for tts lm,flow,hift session gen Signed-off-by: weedge --- cosyvoice/cli/model.py | 104 +++++++++++++++++++++----------------- cosyvoice/utils/common.py | 28 +++++++--- tts.py | 26 +++++----- 3 files changed, 93 insertions(+), 65 deletions(-) diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index 0e430e9..1e525f7 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -11,11 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import threading +import logging + +import numpy as np import torch from torch.nn import functional as F -from cosyvoice.utils.common import fade_in_out +from cosyvoice.utils.common import fade_in_out, ThreadSafeDict class CosyVoiceModel: @@ -29,10 +31,20 @@ def __init__( self.hift = hift # dict used to store session related variable - self.lock = threading.Lock() # dict lock - self.mel_overlap_dict = {} - self.flow_cache_dict = {} - self.hift_cache_dict = {} + self.mel_overlap_dict = ThreadSafeDict() + self.flow_cache_dict = ThreadSafeDict() + self.hift_cache_dict = ThreadSafeDict() + + # mel fade in out + self.mel_overlap_len = int( + self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256 + ) + self.mel_window = np.hamming(2 * self.mel_overlap_len) + # hift cache + self.mel_cache_len = 20 + self.source_cache_len = int(self.mel_cache_len * 256) + # speech fade in out + self.speech_window = np.hamming(2 * self.source_cache_len) def load(self, flow_model, hift_model): self.flow.load_state_dict(torch.load(flow_model, map_location=self.device)) @@ -49,17 +61,13 @@ def token2wav( session_id, finalize=False, speed=1.0, - is_flow_cache=False, - is_hift_cache=False, ): - if is_flow_cache is True and session_id not in self.flow_cache_dict: - with self.lock: - self.mel_overlap_dict[session_id] = torch.zeros(1, 80, 0) - self.flow_cache_dict[session_id] = torch.zeros(1, 80, 0, 2) + if session_id not in self.flow_cache_dict: + self.mel_overlap_dict.set(session_id, torch.zeros(1, 80, 0)) + self.flow_cache_dict.set(session_id, torch.zeros(1, 80, 0, 2)) - if is_hift_cache is True and session_id not in self.hift_cache_dict: - with self.lock: - self.hift_cache_dict[session_id] = None + if session_id not in self.hift_cache_dict: + self.hift_cache_dict.set(session_id, None) tts_mel, flow_cache = self.flow.inference( token=token.to(self.device), @@ -71,61 +79,67 @@ def token2wav( prompt_feat=prompt_feat.to(self.device), prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device), embedding=embedding.to(self.device), - flow_cache=self.flow_cache_dict[session_id] if is_flow_cache else None, + flow_cache=self.flow_cache_dict.get(session_id), ) - self.flow_cache_dict[session_id] = flow_cache if is_flow_cache else None + self.flow_cache_dict.set(session_id, flow_cache) # mel overlap fade in out - if is_flow_cache and self.mel_overlap_dict[session_id].shape[2] != 0: - tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[session_id], self.mel_window) + if self.mel_overlap_dict.get(session_id).shape[2] != 0: + tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict.get(session_id), self.mel_window) hift_cache_source = None - if is_hift_cache is True: - if self.hift_cache_dict[session_id] is not None: - # append hift cache - hift_cache_mel, hift_cache_source = ( - self.hift_cache_dict[session_id]["mel"], - self.hift_cache_dict[session_id]["source"], - ) - tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2) - else: - hift_cache_source = torch.zeros(1, 1, 0) + if self.hift_cache_dict.get(session_id) is not None: + # append hift cache + hift_cache_mel, hift_cache_source = ( + self.hift_cache_dict.get(session_id)["mel"], + self.hift_cache_dict.get(session_id)["source"], + ) + tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2) + else: + hift_cache_source = torch.zeros(1, 1, 0) # keep overlap mel and hift cache if finalize is False: - if is_flow_cache is True: - self.mel_overlap_dict[session_id] = tts_mel[:, :, -self.mel_overlap_len :] + self.mel_overlap_dict.set(session_id, 
tts_mel[:, :, -self.mel_overlap_len :]) tts_mel = tts_mel[:, :, : -self.mel_overlap_len] tts_speech, tts_source = self.hift.inference( speech_feat=tts_mel, cache_source=hift_cache_source ) - if is_hift_cache is True: - if self.hift_cache_dict[session_id] is not None: - tts_speech = fade_in_out( - tts_speech, self.hift_cache_dict[session_id]["speech"], self.speech_window - ) - self.hift_cache_dict[session_id] = { + if self.hift_cache_dict.get(session_id) is not None: + tts_speech = fade_in_out( + tts_speech, self.hift_cache_dict.get(session_id)["speech"], self.speech_window + ) + self.hift_cache_dict.set( + session_id, + { "mel": tts_mel[:, :, -self.mel_cache_len :], "source": tts_source[:, :, -self.source_cache_len :], "speech": tts_speech[:, -self.source_cache_len :], - } + }, + ) tts_speech = tts_speech[:, : -self.source_cache_len] - else: + + logging.info("tts_speech: {}".format(tts_speech.shape)) + else: # finalize if speed != 1.0: - if is_hift_cache is True: - assert ( - self.hift_cache_dict[session_id] is None - ), "speed change only support non-stream inference mode" + assert ( + self.hift_cache_dict.get(session_id) is None + ), "speed change only support non-stream inference mode" tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode="linear") tts_speech, tts_source = self.hift.inference( speech_feat=tts_mel, cache_source=hift_cache_source ) - if is_hift_cache is True and self.hift_cache_dict[session_id] is not None: + if self.hift_cache_dict.get(session_id) is not None: tts_speech = fade_in_out( - tts_speech, self.hift_cache_dict[session_id]["speech"], self.speech_window + tts_speech, self.hift_cache_dict.get(session_id)["speech"], self.speech_window ) + self.mel_overlap_dict.pop(session_id) + self.hift_cache_dict.pop(session_id) + self.flow_cache_dict.pop(session_id) + logging.info("finalize tts_speech: {}".format(tts_speech.shape)) + return tts_speech diff --git a/cosyvoice/utils/common.py b/cosyvoice/utils/common.py index e9611b6..c038452 100644 --- a/cosyvoice/utils/common.py +++ b/cosyvoice/utils/common.py @@ -16,6 +16,7 @@ """Unility functions for Transformer.""" import random +import threading from typing import List import numpy as np @@ -88,9 +89,7 @@ def th_accuracy( pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1) ).argmax(2) mask = pad_targets != ignore_label - numerator = torch.sum( - pad_pred.masked_select(mask) == pad_targets.masked_select(mask) - ) + numerator = torch.sum(pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) denominator = torch.sum(mask) return (numerator / denominator).detach() @@ -129,9 +128,7 @@ def ras_sampling( def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25): prob, indices = [], [] cum_prob = 0.0 - sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort( - descending=True, stable=True - ) + sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True) for i in range(len(sorted_idx)): # sampling both top-p and numbers. 
if cum_prob < top_p and len(prob) < top_k: @@ -167,3 +164,22 @@ def set_all_random_seed(seed): np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) + + +class ThreadSafeDict: + def __init__(self): + self._dict = {} + # 使用 RLock 可重入锁,避免死锁 + self._lock = threading.RLock() + + def get(self, key, default=None): + with self._lock: + return self._dict.get(key, default) + + def set(self, key, value): + with self._lock: + self._dict[key] = value + + def pop(self, key): + with self._lock: + return self._dict.pop(key, None) diff --git a/tts.py b/tts.py index 59f3402..fd45c58 100644 --- a/tts.py +++ b/tts.py @@ -13,6 +13,7 @@ from transformers.generation.utils import LogitsProcessorList from cosyvoice.cli.cosyvoice import CosyVoice +from cosyvoice.utils.common import ThreadSafeDict from streamer import TokenStreamer @@ -50,8 +51,7 @@ def __init__( self.stream_factor = stream_factor # >=2 increase for better speech quality, but rtf slow (speech quality vs rtf) # session ctx dict with lock, maybe need a session class - self.session_lm_generat_lock = Lock() - self.session_lm_generated_ids = {} # session_id: ids(ptr) + self.session_lm_generated_ids = ThreadSafeDict() # session_id: ids(ptr) # load optimus_ths for flash attention, make sure LD_LIBRARY_PATH has `nvidia/cuda_nvrtc/lib` # if not, please manually set LD_LIBRARY_PATH=xxx/python3.10/site-packages/nvidia/cuda_nvrtc/lib @@ -134,15 +134,15 @@ def __call__(self, text: str, prompt_speaker: str, clone_dict: dict | None = Non def register_speakers(self): self.speakers_info = {} - cur_dir= os.path.dirname(os.path.abspath(__file__)) + cur_dir = os.path.dirname(os.path.abspath(__file__)) print(cur_dir) - file_path: str = os.path.join(cur_dir,"speakers/speakers_info.json") + file_path: str = os.path.join(cur_dir, "speakers/speakers_info.json") print(file_path) with open(file_path, "r") as f: speakers_info = json.load(f) for speaker_id, prompt_text in speakers_info.items(): - prompt_wav_path = os.path.join(cur_dir,f"speakers/{speaker_id}_prompt.wav") + prompt_wav_path = os.path.join(cur_dir, f"speakers/{speaker_id}_prompt.wav") ( prompt_code, prompt_token, @@ -322,18 +322,17 @@ def static_batch_stream( thread = Thread(target=self.llm.generate, kwargs=generation_kwargs) thread.start() - with self.session_lm_generat_lock: - self.session_lm_generated_ids[session_id] = [] + self.session_lm_generated_ids.set(session_id, []) batch_size = math.ceil(self.stream_factor * cosy_model.model.flow.input_frame_rate) for token_id in streamer: # print(token_id, end=",", flush=True) if token_id == 3: # skip <|EOT|>, break break - self.session_lm_generated_ids[session_id].append(token_id) - if len(self.session_lm_generated_ids[session_id]) % batch_size == 0: + self.session_lm_generated_ids.get(session_id).append(token_id) + if len(self.session_lm_generated_ids.get(session_id)) % batch_size == 0: batch = ( - torch.tensor(self.session_lm_generated_ids[session_id]) + torch.tensor(self.session_lm_generated_ids.get(session_id)) .unsqueeze(0) .to(cosy_model.model.device) - 65536 @@ -348,12 +347,11 @@ def static_batch_stream( prompt_speaker_info["cosy_speech_embedding"].to(torch.bfloat16), ) yield {"tts_speech": sub_tts_speech, "sample_rate": output_audio_sample_rate} - with self.session_lm_generat_lock: - self.session_lm_generated_ids[session_id] = [] + self.session_lm_generated_ids.set(session_id, []) - if len(self.session_lm_generated_ids[session_id]) > 0: + if len(self.session_lm_generated_ids.get(session_id)) > 0: batch = ( - 
torch.tensor(self.session_lm_generated_ids[session_id]) + torch.tensor(self.session_lm_generated_ids.get(session_id)) .unsqueeze(0) .to(cosy_model.model.device) - 65536 From 73fb1494f50cb7b96f7cd94f14c6ab120b3bd250 Mon Sep 17 00:00:00 2001 From: weedge Date: Sun, 23 Feb 2025 18:33:25 +0800 Subject: [PATCH 20/37] fix: token_overlap_len Signed-off-by: weedge --- cosyvoice/cli/model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index 1e525f7..a9f4d19 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -35,6 +35,7 @@ def __init__( self.flow_cache_dict = ThreadSafeDict() self.hift_cache_dict = ThreadSafeDict() + self.token_overlap_len = 20 # mel fade in out self.mel_overlap_len = int( self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256 From b4a2cf6b95767ead66d08dd2af470fb72e56d3f4 Mon Sep 17 00:00:00 2001 From: weedge Date: Sun, 23 Feb 2025 18:47:25 +0800 Subject: [PATCH 21/37] fix: token_overlap_len Signed-off-by: weedge --- cosyvoice/cli/cosyvoice.py | 2 +- cosyvoice/flow/flow.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py index 4054071..a5ed673 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -53,7 +53,7 @@ def token_to_wav_offline( prompt_token_len, embedding, ): - tts_mel = self.model.flow.inference( + tts_mel, _ = self.model.flow.inference( token=speech_token.to(self.model.device), token_len=torch.tensor([speech_token.size(1)], dtype=torch.int32).to(self.model.device), prompt_token=prompt_token.to(self.model.device), diff --git a/cosyvoice/flow/flow.py b/cosyvoice/flow/flow.py index b2fe43b..efea2eb 100644 --- a/cosyvoice/flow/flow.py +++ b/cosyvoice/flow/flow.py @@ -141,7 +141,7 @@ def inference( prompt_feat, prompt_feat_len, embedding, - flow_cache, + flow_cache=None, ): assert token.shape[0] == 1 # xvec projection From ec9d5c6b8cf329d1e70d41ed6f3a88eb5e961a66 Mon Sep 17 00:00:00 2001 From: weedge Date: Sun, 23 Feb 2025 19:01:26 +0800 Subject: [PATCH 22/37] remove print Signed-off-by: weedge --- cosyvoice/flow/flow.py | 2 +- tts.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/cosyvoice/flow/flow.py b/cosyvoice/flow/flow.py index efea2eb..fc62c05 100644 --- a/cosyvoice/flow/flow.py +++ b/cosyvoice/flow/flow.py @@ -182,7 +182,7 @@ def inference( # mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h) mask = torch.ones([1, mel_len1 + mel_len2], device=h.device, dtype=torch.bfloat16) - feat = self.decoder( + feat, flow_cache = self.decoder( mu=h.transpose(1, 2).contiguous(), mask=mask.unsqueeze(1), spks=embedding, diff --git a/tts.py b/tts.py index fd45c58..01d77c0 100644 --- a/tts.py +++ b/tts.py @@ -135,9 +135,7 @@ def register_speakers(self): self.speakers_info = {} cur_dir = os.path.dirname(os.path.abspath(__file__)) - print(cur_dir) file_path: str = os.path.join(cur_dir, "speakers/speakers_info.json") - print(file_path) with open(file_path, "r") as f: speakers_info = json.load(f) From 71a83d74d7ebb1f4dbfb833e3b9b0ba381e648f4 Mon Sep 17 00:00:00 2001 From: weedge Date: Sun, 23 Feb 2025 19:50:26 +0800 Subject: [PATCH 23/37] feat: add token2wav with flow hift session Signed-off-by: weedge --- tts.py | 38 ++++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/tts.py b/tts.py index 01d77c0..ea74c4a 100644 --- a/tts.py +++ b/tts.py @@ -336,34 +336,36 @@ def static_batch_stream( - 65536 ) # [T] -> [1,T] # Process each 
batch - sub_tts_speech = cosy_model.token_to_wav_offline( + sub_tts_speech = cosy_model.model.token2wav( batch, prompt_speaker_info["cosy_speech_feat"].to(torch.bfloat16), prompt_speaker_info["cosy_speech_feat_len"], prompt_speaker_info["cosy_prompt_token"], prompt_speaker_info["cosy_prompt_token_len"], prompt_speaker_info["cosy_speech_embedding"].to(torch.bfloat16), + finalize=False, ) yield {"tts_speech": sub_tts_speech, "sample_rate": output_audio_sample_rate} self.session_lm_generated_ids.set(session_id, []) - if len(self.session_lm_generated_ids.get(session_id)) > 0: - batch = ( - torch.tensor(self.session_lm_generated_ids.get(session_id)) - .unsqueeze(0) - .to(cosy_model.model.device) - - 65536 - ) # [T] -> [1,T] - # Process each batch - sub_tts_speech = cosy_model.token_to_wav_offline( - batch, - prompt_speaker_info["cosy_speech_feat"].to(torch.bfloat16), - prompt_speaker_info["cosy_speech_feat_len"], - prompt_speaker_info["cosy_prompt_token"], - prompt_speaker_info["cosy_prompt_token_len"], - prompt_speaker_info["cosy_speech_embedding"].to(torch.bfloat16), - ) - yield {"tts_speech": sub_tts_speech, "sample_rate": output_audio_sample_rate} + self.session_lm_generated_ids.set(session_id, []) + batch = ( + torch.tensor(self.session_lm_generated_ids.get(session_id)) + .unsqueeze(0) + .to(cosy_model.model.device) + - 65536 + ) # [T] -> [1,T] + # Process each batch + sub_tts_speech = cosy_model.model.token2wav( + batch, + prompt_speaker_info["cosy_speech_feat"].to(torch.bfloat16), + prompt_speaker_info["cosy_speech_feat_len"], + prompt_speaker_info["cosy_prompt_token"], + prompt_speaker_info["cosy_prompt_token_len"], + prompt_speaker_info["cosy_speech_embedding"].to(torch.bfloat16), + finalize=True, + ) + yield {"tts_speech": sub_tts_speech, "sample_rate": output_audio_sample_rate} with self.session_lm_generat_lock: self.session_lm_generated_ids.pop(session_id) From a25e16daa91cafa8d423f152b85448ddc93f4641 Mon Sep 17 00:00:00 2001 From: weedge Date: Sun, 23 Feb 2025 20:05:35 +0800 Subject: [PATCH 24/37] remove python-dotenv Signed-off-by: weedge --- requirements.txt | 1 - tts.py | 2 ++ tts_inference_stream.py | 3 --- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/requirements.txt b/requirements.txt index 046d5c9..c719357 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,4 +18,3 @@ pillow sentencepiece funasr>=1.1.3 protobuf==5.29.3 -python-dotenv diff --git a/tts.py b/tts.py index ea74c4a..0e97a7a 100644 --- a/tts.py +++ b/tts.py @@ -343,6 +343,7 @@ def static_batch_stream( prompt_speaker_info["cosy_prompt_token"], prompt_speaker_info["cosy_prompt_token_len"], prompt_speaker_info["cosy_speech_embedding"].to(torch.bfloat16), + session_id, finalize=False, ) yield {"tts_speech": sub_tts_speech, "sample_rate": output_audio_sample_rate} @@ -363,6 +364,7 @@ def static_batch_stream( prompt_speaker_info["cosy_prompt_token"], prompt_speaker_info["cosy_prompt_token_len"], prompt_speaker_info["cosy_speech_embedding"].to(torch.bfloat16), + session_id, finalize=True, ) yield {"tts_speech": sub_tts_speech, "sample_rate": output_audio_sample_rate} diff --git a/tts_inference_stream.py b/tts_inference_stream.py index a75a263..704975a 100644 --- a/tts_inference_stream.py +++ b/tts_inference_stream.py @@ -1,6 +1,5 @@ import os import argparse -from dotenv import load_dotenv import torchaudio @@ -8,8 +7,6 @@ from utils import merge_tensors from tts import StepAudioTTS -load_dotenv(override=True) - def main(): parser = argparse.ArgumentParser(description="StepAudio Stream Inference") 
From 8ba35a173242767a0c3ba85d893c4c79247d75b3 Mon Sep 17 00:00:00 2001 From: weedge Date: Sun, 23 Feb 2025 20:14:11 +0800 Subject: [PATCH 25/37] fix token2wav Signed-off-by: weedge --- tts.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tts.py b/tts.py index 0e97a7a..00cb352 100644 --- a/tts.py +++ b/tts.py @@ -338,10 +338,8 @@ def static_batch_stream( # Process each batch sub_tts_speech = cosy_model.model.token2wav( batch, - prompt_speaker_info["cosy_speech_feat"].to(torch.bfloat16), - prompt_speaker_info["cosy_speech_feat_len"], prompt_speaker_info["cosy_prompt_token"], - prompt_speaker_info["cosy_prompt_token_len"], + prompt_speaker_info["cosy_speech_feat"].to(torch.bfloat16), prompt_speaker_info["cosy_speech_embedding"].to(torch.bfloat16), session_id, finalize=False, @@ -359,10 +357,8 @@ def static_batch_stream( # Process each batch sub_tts_speech = cosy_model.model.token2wav( batch, - prompt_speaker_info["cosy_speech_feat"].to(torch.bfloat16), - prompt_speaker_info["cosy_speech_feat_len"], prompt_speaker_info["cosy_prompt_token"], - prompt_speaker_info["cosy_prompt_token_len"], + prompt_speaker_info["cosy_speech_feat"].to(torch.bfloat16), prompt_speaker_info["cosy_speech_embedding"].to(torch.bfloat16), session_id, finalize=True, From 733e052383e2f9090c60de693c6b9c74a6269550 Mon Sep 17 00:00:00 2001 From: weedge Date: Sun, 23 Feb 2025 20:21:25 +0800 Subject: [PATCH 26/37] fix token2wav Signed-off-by: weedge --- cosyvoice/cli/model.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index a9f4d19..41e4598 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -63,13 +63,10 @@ def token2wav( finalize=False, speed=1.0, ): - if session_id not in self.flow_cache_dict: + if self.flow_cache_dict.get(session_id) is None: self.mel_overlap_dict.set(session_id, torch.zeros(1, 80, 0)) self.flow_cache_dict.set(session_id, torch.zeros(1, 80, 0, 2)) - if session_id not in self.hift_cache_dict: - self.hift_cache_dict.set(session_id, None) - tts_mel, flow_cache = self.flow.inference( token=token.to(self.device), token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), From 42ff2bfa66c90913826aa4e4915607c3c91bc9ba Mon Sep 17 00:00:00 2001 From: weedge Date: Sun, 23 Feb 2025 20:30:29 +0800 Subject: [PATCH 27/37] fix token2wav Signed-off-by: weedge --- tts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tts.py b/tts.py index 00cb352..7437fe5 100644 --- a/tts.py +++ b/tts.py @@ -347,7 +347,7 @@ def static_batch_stream( yield {"tts_speech": sub_tts_speech, "sample_rate": output_audio_sample_rate} self.session_lm_generated_ids.set(session_id, []) - self.session_lm_generated_ids.set(session_id, []) + self.session_lm_generated_ids.set(session_id, [65536]) batch = ( torch.tensor(self.session_lm_generated_ids.get(session_id)) .unsqueeze(0) From 73897a3bfac7f4b524371f41d65d82e5911af362 Mon Sep 17 00:00:00 2001 From: weedge Date: Sun, 23 Feb 2025 20:59:40 +0800 Subject: [PATCH 28/37] fix token2wav mel Signed-off-by: weedge --- cosyvoice/cli/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index 41e4598..b58d651 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -102,7 +102,7 @@ def token2wav( tts_mel = tts_mel[:, :, : -self.mel_overlap_len] tts_speech, tts_source = self.hift.inference( - speech_feat=tts_mel, cache_source=hift_cache_source + mel=tts_mel, 
cache_source=hift_cache_source ) if self.hift_cache_dict.get(session_id) is not None: @@ -128,7 +128,7 @@ def token2wav( ), "speed change only support non-stream inference mode" tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode="linear") tts_speech, tts_source = self.hift.inference( - speech_feat=tts_mel, cache_source=hift_cache_source + mel=tts_mel, cache_source=hift_cache_source ) if self.hift_cache_dict.get(session_id) is not None: tts_speech = fade_in_out( From 37c6eee17e6b494c4388711d904fafa86541f899 Mon Sep 17 00:00:00 2001 From: weedge Date: Sun, 23 Feb 2025 21:12:05 +0800 Subject: [PATCH 29/37] fix: flow infer return mel float Signed-off-by: weedge --- cosyvoice/flow/flow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cosyvoice/flow/flow.py b/cosyvoice/flow/flow.py index fc62c05..a4fa4e5 100644 --- a/cosyvoice/flow/flow.py +++ b/cosyvoice/flow/flow.py @@ -193,4 +193,4 @@ def inference( ) feat = feat[:, :, mel_len1:] assert feat.shape[2] == mel_len2 - return feat, flow_cache + return feat.float(), flow_cache From 5a426dc4301dab8f79db5db242c6d210725a34c8 Mon Sep 17 00:00:00 2001 From: weedge Date: Sun, 23 Feb 2025 21:15:19 +0800 Subject: [PATCH 30/37] fix: lock Signed-off-by: weedge --- tts.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tts.py b/tts.py index 7437fe5..7b06f2f 100644 --- a/tts.py +++ b/tts.py @@ -365,6 +365,5 @@ def static_batch_stream( ) yield {"tts_speech": sub_tts_speech, "sample_rate": output_audio_sample_rate} - with self.session_lm_generat_lock: - self.session_lm_generated_ids.pop(session_id) + self.session_lm_generated_ids.pop(session_id) torch.cuda.empty_cache() From 8a2b3f3a5c27dadd70c898fbe668ff31fee17ba6 Mon Sep 17 00:00:00 2001 From: weedge Date: Sun, 23 Feb 2025 21:27:49 +0800 Subject: [PATCH 31/37] token2wav load to cpu Signed-off-by: weedge --- cosyvoice/cli/model.py | 2 +- tts.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index b58d651..d35ef74 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -140,4 +140,4 @@ def token2wav( self.flow_cache_dict.pop(session_id) logging.info("finalize tts_speech: {}".format(tts_speech.shape)) - return tts_speech + return tts_speech.cpu() diff --git a/tts.py b/tts.py index 7b06f2f..445e3cb 100644 --- a/tts.py +++ b/tts.py @@ -347,7 +347,9 @@ def static_batch_stream( yield {"tts_speech": sub_tts_speech, "sample_rate": output_audio_sample_rate} self.session_lm_generated_ids.set(session_id, []) - self.session_lm_generated_ids.set(session_id, [65536]) + if len(self.session_lm_generated_ids.get(session_id)) == 0: # end to finalize + self.session_lm_generated_ids.set(session_id, [65536]) + batch = ( torch.tensor(self.session_lm_generated_ids.get(session_id)) .unsqueeze(0) From bb0e08e73eab43ab81b6a758933e4af7ecd033e2 Mon Sep 17 00:00:00 2001 From: weedge Date: Mon, 24 Feb 2025 09:50:41 +0800 Subject: [PATCH 32/37] feat: add token overlap to gen Signed-off-by: weedge --- tts.py | 28 +++++++++++++++++++++------- tts_inference_stream.py | 4 ++-- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/tts.py b/tts.py index 445e3cb..4d77b83 100644 --- a/tts.py +++ b/tts.py @@ -42,13 +42,15 @@ def __init__( encoder, device_map: str | dict | None = None, stream_factor: int = 2, + max_stream_factor: int = 2, **kwargs, ): # fast path to check params assert ( stream_factor >= 2 - ), "stream_factor must >=2 increase for better speech quality, but rtf slow (speech quality 
vs rtf)" + ), f"stream_factor must >=2 increase for better speech quality, but rtf slow (speech quality vs rtf)" self.stream_factor = stream_factor # >=2 increase for better speech quality, but rtf slow (speech quality vs rtf) + self.token_max_hop_len = max_stream_factor * self.flow.input_frame_rate # session ctx dict with lock, maybe need a session class self.session_lm_generated_ids = ThreadSafeDict() # session_id: ids(ptr) @@ -278,7 +280,7 @@ def preprocess_prompt(self, text: str, prompt_speaker: str, clone_dict: dict | N return prompt_speaker, prompt_speaker_info, cosy_model @torch.inference_mode() - def static_batch_stream( + def batch_stream( self, text: str, prompt_speaker: str, @@ -287,7 +289,9 @@ def static_batch_stream( ): """ - step1 lm stream generate token - - static batch size to gen waveform + - batch size to gen waveform + - when max_stream_factor > stream_factor, dynamic batch size to gen waveform + - when max_stream_factor <= stream_factor, static batch size to gen waveform - flow: audio vq tokens to mel - hifi: mel to waveform """ @@ -328,9 +332,17 @@ def static_batch_stream( if token_id == 3: # skip <|EOT|>, break break self.session_lm_generated_ids.get(session_id).append(token_id) - if len(self.session_lm_generated_ids.get(session_id)) % batch_size == 0: + # if len(self.session_lm_generated_ids.get(session_id)) % batch_size == 0: + if ( + len(self.session_lm_generated_ids.get(session_id)) + >= batch_size + self.token_overlap_len + ): batch = ( - torch.tensor(self.session_lm_generated_ids.get(session_id)) + torch.tensor( + self.session_lm_generated_ids.get(session_id)[ + : batch_size + self.token_overlap_len + ] + ) .unsqueeze(0) .to(cosy_model.model.device) - 65536 @@ -345,9 +357,11 @@ def static_batch_stream( finalize=False, ) yield {"tts_speech": sub_tts_speech, "sample_rate": output_audio_sample_rate} - self.session_lm_generated_ids.set(session_id, []) + self.session_lm_generated_ids.set( + session_id, self.session_lm_generated_ids.get(session_id)[batch_size:] + ) - if len(self.session_lm_generated_ids.get(session_id)) == 0: # end to finalize + if len(self.session_lm_generated_ids.get(session_id)) == 0: # end to finalize self.session_lm_generated_ids.set(session_id, [65536]) batch = ( diff --git a/tts_inference_stream.py b/tts_inference_stream.py index 704975a..96b347f 100644 --- a/tts_inference_stream.py +++ b/tts_inference_stream.py @@ -36,7 +36,7 @@ def main(): if args.synthesis_type == "tts": text = "(RAP)君不见黄河之水天上来,奔流到海不复回。君不见高堂明镜悲白发,朝如青丝暮成雪。人生得意须尽欢,莫使金樽空对月。天生我材必有用,千金散尽还复来。" text = os.getenv("TTS_TEXT", text) - batch_stream = tts_engine.static_batch_stream(text, "Tingting") + batch_stream = tts_engine.batch_stream(text, "Tingting") sub_tts_speechs = [] sr = 22050 for item in batch_stream: @@ -52,7 +52,7 @@ def main(): } text_clone = "万物之始,大道至简,衍化至繁。君不见黄河之水天上来,奔流到海不复回。君不见高堂明镜悲白发,朝如青丝暮成雪。人生得意须尽欢,莫使金樽空对月。天生我材必有用,千金散尽还复来。" text_clone = os.getenv("TTS_TEXT", text_clone) - batch_stream = tts_engine.static_batch_stream(text_clone, "", clone_speaker) + batch_stream = tts_engine.batch_stream(text_clone, "", clone_speaker) sub_tts_speechs = [] sr = 22050 for item in batch_stream: From 9866e8ca9fb37a084eb7c2e92fe0ff72305dd528 Mon Sep 17 00:00:00 2001 From: weedge Date: Mon, 24 Feb 2025 10:16:16 +0800 Subject: [PATCH 33/37] feat: add max_stream_factor for dynamic batch stream Signed-off-by: weedge --- cosyvoice/cli/cosyvoice.py | 7 ++++++- cosyvoice/cli/model.py | 6 ++---- tts.py | 30 ++++++++++++++++++++++++++---- 3 files changed, 34 insertions(+), 9 deletions(-) diff 
--git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py index a5ed673..bbe2136 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -27,6 +27,7 @@ class CosyVoice: def __init__( self, model_dir, + token_overlap_len: int = 20, ): self.model_dir = model_dir with open("{}/cosyvoice.yaml".format(model_dir), "r") as f: @@ -36,7 +37,11 @@ def __init__( "{}/campplus.onnx".format(model_dir), "{}/speech_tokenizer_v1.onnx".format(model_dir), ) - self.model = CosyVoiceModel(configs["flow"], configs["hift"]) + self.model = CosyVoiceModel( + configs["flow"], + configs["hift"], + token_overlap_len=token_overlap_len, + ) self.model.load( "{}/flow.pt".format(model_dir), "{}/hift.pt".format(model_dir), diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index d35ef74..28336af 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -25,6 +25,7 @@ def __init__( self, flow: torch.nn.Module, hift: torch.nn.Module, + token_overlap_len: int = 20, ): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.flow = flow @@ -35,11 +36,8 @@ def __init__( self.flow_cache_dict = ThreadSafeDict() self.hift_cache_dict = ThreadSafeDict() - self.token_overlap_len = 20 # mel fade in out - self.mel_overlap_len = int( - self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256 - ) + self.mel_overlap_len = int(token_overlap_len / self.flow.input_frame_rate * 22050 / 256) self.mel_window = np.hamming(2 * self.mel_overlap_len) # hift cache self.mel_cache_len = 20 diff --git a/tts.py b/tts.py index 4d77b83..1c68381 100644 --- a/tts.py +++ b/tts.py @@ -1,3 +1,4 @@ +import logging import math import os import re @@ -42,15 +43,26 @@ def __init__( encoder, device_map: str | dict | None = None, stream_factor: int = 2, + stream_scale_factor: float = 1.0, max_stream_factor: int = 2, + token_overlap_len: int = 20, **kwargs, ): # fast path to check params + # rtf and decoding related assert ( stream_factor >= 2 ), f"stream_factor must >=2 increase for better speech quality, but rtf slow (speech quality vs rtf)" - self.stream_factor = stream_factor # >=2 increase for better speech quality, but rtf slow (speech quality vs rtf) + self.stream_factor = stream_factor self.token_max_hop_len = max_stream_factor * self.flow.input_frame_rate + assert ( + stream_scale_factor >= 1.0 + ), "stream_scale_factor should be greater than 1, change it according to your actual rtf" + self.stream_scale_factor = stream_scale_factor # scale speed + assert ( + token_overlap_len >= 0 + ), "token_overlap_len should be greater than 0, change it according to your actual rtf" + self.token_overlap_len = token_overlap_len # session ctx dict with lock, maybe need a session class self.session_lm_generated_ids = ThreadSafeDict() # session_id: ids(ptr) @@ -91,8 +103,14 @@ def __init__( **kwargs, ) self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) - self.common_cosy_model = CosyVoice(os.path.join(model_path, "CosyVoice-300M-25Hz")) - self.music_cosy_model = CosyVoice(os.path.join(model_path, "CosyVoice-300M-25Hz-Music")) + self.common_cosy_model = CosyVoice( + os.path.join(model_path, "CosyVoice-300M-25Hz"), + token_overlap_len=token_overlap_len, + ) + self.music_cosy_model = CosyVoice( + os.path.join(model_path, "CosyVoice-300M-25Hz-Music"), + token_overlap_len=token_overlap_len, + ) self.encoder = encoder self.sys_prompt_dict = { "sys_prompt_for_rap": "请参考对话历史里的音色,用RAP方式将文本内容大声说唱出来。", @@ -293,7 +311,7 @@ def batch_stream( - when max_stream_factor > 
stream_factor, dynamic batch size to gen waveform - when max_stream_factor <= stream_factor, static batch size to gen waveform - flow: audio vq tokens to mel - - hifi: mel to waveform + - hift: mel to waveform """ prompt_speaker, prompt_speaker_info, cosy_model = self.preprocess_prompt( text, prompt_speaker, clone_dict=clone_dict @@ -327,6 +345,7 @@ def batch_stream( self.session_lm_generated_ids.set(session_id, []) batch_size = math.ceil(self.stream_factor * cosy_model.model.flow.input_frame_rate) + logging.info(f"init batch_size: {batch_size}") for token_id in streamer: # print(token_id, end=",", flush=True) if token_id == 3: # skip <|EOT|>, break @@ -360,6 +379,9 @@ def batch_stream( self.session_lm_generated_ids.set( session_id, self.session_lm_generated_ids.get(session_id)[batch_size:] ) + # increase token_hop_len for better speech quality + batch_size = min(self.token_max_hop_len, int(batch_size * self.stream_scale_factor)) + logging.info(f"increase batch_size: {batch_size}") if len(self.session_lm_generated_ids.get(session_id)) == 0: # end to finalize self.session_lm_generated_ids.set(session_id, [65536]) From 6b56e20afb33503de42410b4d05f87358bd975b3 Mon Sep 17 00:00:00 2001 From: weedge Date: Mon, 24 Feb 2025 10:28:04 +0800 Subject: [PATCH 34/37] change device_map default auto Signed-off-by: weedge --- tts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tts.py b/tts.py index 1c68381..1277a6f 100644 --- a/tts.py +++ b/tts.py @@ -98,7 +98,7 @@ def __init__( self.llm = AutoModelForCausalLM.from_pretrained( model_path, torch_dtype=torch.bfloat16, - device_map="cuda" if not device_map else device_map, + device_map="auto" if not device_map else device_map, trust_remote_code=True, **kwargs, ) From 680e93463fae899a29c5747397c8e95579568011 Mon Sep 17 00:00:00 2001 From: weedge Date: Mon, 24 Feb 2025 10:33:35 +0800 Subject: [PATCH 35/37] change max_batch_size Signed-off-by: weedge --- tts.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tts.py b/tts.py index 1277a6f..8533b19 100644 --- a/tts.py +++ b/tts.py @@ -54,7 +54,7 @@ def __init__( stream_factor >= 2 ), f"stream_factor must >=2 increase for better speech quality, but rtf slow (speech quality vs rtf)" self.stream_factor = stream_factor - self.token_max_hop_len = max_stream_factor * self.flow.input_frame_rate + self.max_stream_factor = max_stream_factor assert ( stream_scale_factor >= 1.0 ), "stream_scale_factor should be greater than 1, change it according to your actual rtf" @@ -344,8 +344,9 @@ def batch_stream( self.session_lm_generated_ids.set(session_id, []) + max_batch_size = math.ceil(self.max_stream_factor * cosy_model.model.flow.input_frame_rate) batch_size = math.ceil(self.stream_factor * cosy_model.model.flow.input_frame_rate) - logging.info(f"init batch_size: {batch_size}") + logging.info(f"init batch_size: {batch_size} max_batch_size: {max_batch_size}") for token_id in streamer: # print(token_id, end=",", flush=True) if token_id == 3: # skip <|EOT|>, break @@ -380,7 +381,7 @@ def batch_stream( session_id, self.session_lm_generated_ids.get(session_id)[batch_size:] ) # increase token_hop_len for better speech quality - batch_size = min(self.token_max_hop_len, int(batch_size * self.stream_scale_factor)) + batch_size = min(max_batch_size, int(batch_size * self.stream_scale_factor)) logging.info(f"increase batch_size: {batch_size}") if len(self.session_lm_generated_ids.get(session_id)) == 0: # end to finalize From 376d7628017a58a5bbf6dec144667bac4cf44173 Mon Sep 17 00:00:00 2001 
From: weedge Date: Mon, 24 Feb 2025 10:44:28 +0800 Subject: [PATCH 36/37] add: stream_factor stream_scale_factor max_stream_factor token_overlap_len for tts inference stream params Signed-off-by: weedge --- tts_inference_stream.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tts_inference_stream.py b/tts_inference_stream.py index 96b347f..78a3a85 100644 --- a/tts_inference_stream.py +++ b/tts_inference_stream.py @@ -23,6 +23,18 @@ def main(): parser.add_argument( "--stream-factor", type=int, default=2, help="Synthesis audios stream factor" ) + parser.add_argument( + "--max-stream-factor", type=int, default=2, help="Synthesis audios max stream factor" + ) + parser.add_argument( + "--stream-scale-factor", type=float, default=1.0, help="Synthesis audios stream scale factor" + ) + parser.add_argument( + "--max-stream-factor", type=int, default=2, help="Synthesis audios max stream factor" + ) + parser.add_argument( + "--token-overlap-len", type=int, default=20, help="Synthesis audios token overlap len" + ) args = parser.parse_args() os.makedirs(f"{args.output_path}", exist_ok=True) @@ -31,6 +43,9 @@ def main(): f"{args.model_path}/Step-Audio-TTS-3B", encoder, stream_factor=args.stream_factor, + stream_scale_factor=args.stream_scale_factor, + max_stream_factor=args.max_stream_factor, + token_overlap_len=args.token_overlap_len, ) if args.synthesis_type == "tts": From 7f623b089f9238ec25283eac10f373709240acc1 Mon Sep 17 00:00:00 2001 From: weedge Date: Mon, 24 Feb 2025 11:10:05 +0800 Subject: [PATCH 37/37] fix tts_inference_stream params Signed-off-by: weedge --- tts_inference_stream.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tts_inference_stream.py b/tts_inference_stream.py index 78a3a85..08869a7 100644 --- a/tts_inference_stream.py +++ b/tts_inference_stream.py @@ -23,9 +23,6 @@ def main(): parser.add_argument( "--stream-factor", type=int, default=2, help="Synthesis audios stream factor" ) - parser.add_argument( - "--max-stream-factor", type=int, default=2, help="Synthesis audios max stream factor" - ) parser.add_argument( "--stream-scale-factor", type=float, default=1.0, help="Synthesis audios stream scale factor" )
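
A minimal sketch of how a caller might consume the streaming generator introduced in this series (it mirrors the flow of tts_inference_stream.py above; the tts_engine argument is assumed to be an already-initialized StepAudioTTS instance, and the speaker name and output path are placeholders):

    import torchaudio

    from utils import merge_tensors  # same helper tts_inference_stream.py imports


    def synthesize_stream(tts_engine, text: str, speaker: str = "Tingting",
                          out_path: str = "stream_out.wav") -> None:
        chunks = []
        sample_rate = 22050
        # batch_stream yields {"tts_speech": Tensor, "sample_rate": int} per chunk;
        # each chunk can be played or forwarded as soon as it arrives for low first-packet latency
        for item in tts_engine.batch_stream(text, speaker):
            sample_rate = item["sample_rate"]
            chunks.append(item["tts_speech"])
        torchaudio.save(out_path, merge_tensors(chunks), sample_rate)

The per-session state is cleaned up once the final chunk is produced: token2wav pops its flow/hift/mel-overlap caches on finalize, and batch_stream pops the LM token buffer, so concurrent sessions do not share state.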
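
The chunk-size schedule added in the later patches is easiest to see with concrete numbers: the hop starts at stream_factor * input_frame_rate tokens, is multiplied by stream_scale_factor after every yielded chunk, and is capped at max_stream_factor * input_frame_rate. A small sketch with assumed values (input_frame_rate=25 matches the 25Hz CosyVoice checkpoints loaded above; the factors themselves are illustrative rather than the defaults):

    import math

    input_frame_rate = 25      # CosyVoice-300M-25Hz
    stream_factor = 2          # default
    max_stream_factor = 5      # illustrative (default 2, i.e. a static hop)
    stream_scale_factor = 1.5  # illustrative (default 1.0, i.e. no growth)

    batch_size = math.ceil(stream_factor * input_frame_rate)          # 50 tokens per chunk to start
    max_batch_size = math.ceil(max_stream_factor * input_frame_rate)  # capped at 125 tokens

    sizes = []
    for _ in range(5):
        sizes.append(batch_size)
        batch_size = min(max_batch_size, int(batch_size * stream_scale_factor))
    print(sizes)  # [50, 75, 112, 125, 125]: small early chunks for latency, larger later ones for quality/RTF

Note that token2wav is only invoked once batch_size + token_overlap_len tokens have accumulated, and only batch_size of them are consumed from the buffer, so consecutive chunks overlap by token_overlap_len tokens for the mel crossfade.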
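
In CosyVoiceModel.token2wav the mel (and waveform) seams between chunks are smoothed with fade_in_out and a Hamming window; that helper lives in cosyvoice/utils/common.py and is not part of this series, but conceptually it is a plain overlap crossfade along the last dimension, roughly like the sketch below (names and details are illustrative, not the project's actual implementation):

    import numpy as np
    import torch


    def crossfade(new_chunk: torch.Tensor, prev_tail: torch.Tensor, window: np.ndarray) -> torch.Tensor:
        # window has length 2*n, e.g. np.hamming(2 * mel_overlap_len) as built in model.py above;
        # the rising half weights the incoming chunk and the falling half weights the tail kept
        # from the previous chunk, so energy stays roughly constant across the seam.
        n = window.shape[0] // 2
        win = torch.from_numpy(window).to(new_chunk.dtype)
        out = new_chunk.clone()
        out[..., :n] = new_chunk[..., :n] * win[:n] + prev_tail[..., -n:] * win[n:]
        return out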