diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py index b8c0b4b..bbe2136 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import threading import uuid import time from tqdm import tqdm @@ -23,10 +24,10 @@ class CosyVoice: - def __init__( self, model_dir, + token_overlap_len: int = 20, ): self.model_dir = model_dir with open("{}/cosyvoice.yaml".format(model_dir), "r") as f: @@ -36,7 +37,11 @@ def __init__( "{}/campplus.onnx".format(model_dir), "{}/speech_tokenizer_v1.onnx".format(model_dir), ) - self.model = CosyVoiceModel(configs["flow"], configs["hift"]) + self.model = CosyVoiceModel( + configs["flow"], + configs["hift"], + token_overlap_len=token_overlap_len, + ) self.model.load( "{}/flow.pt".format(model_dir), "{}/hift.pt".format(model_dir), @@ -53,11 +58,9 @@ def token_to_wav_offline( prompt_token_len, embedding, ): - tts_mel = self.model.flow.inference( + tts_mel, _ = self.model.flow.inference( token=speech_token.to(self.model.device), - token_len=torch.tensor([speech_token.size(1)], dtype=torch.int32).to( - self.model.device - ), + token_len=torch.tensor([speech_token.size(1)], dtype=torch.int32).to(self.model.device), prompt_token=prompt_token.to(self.model.device), prompt_token_len=prompt_token_len.to(self.model.device), prompt_feat=speech_feat.to(self.model.device), diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index b284d9e..28336af 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -11,22 +11,131 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import logging + +import numpy as np import torch +from torch.nn import functional as F +from cosyvoice.utils.common import fade_in_out, ThreadSafeDict -class CosyVoiceModel: +class CosyVoiceModel: def __init__( self, flow: torch.nn.Module, hift: torch.nn.Module, + token_overlap_len: int = 20, ): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.flow = flow self.hift = hift + # dict used to store session related variable + self.mel_overlap_dict = ThreadSafeDict() + self.flow_cache_dict = ThreadSafeDict() + self.hift_cache_dict = ThreadSafeDict() + + # mel fade in out + self.mel_overlap_len = int(token_overlap_len / self.flow.input_frame_rate * 22050 / 256) + self.mel_window = np.hamming(2 * self.mel_overlap_len) + # hift cache + self.mel_cache_len = 20 + self.source_cache_len = int(self.mel_cache_len * 256) + # speech fade in out + self.speech_window = np.hamming(2 * self.source_cache_len) + def load(self, flow_model, hift_model): self.flow.load_state_dict(torch.load(flow_model, map_location=self.device)) self.flow.to(self.device).eval() self.hift.load_state_dict(torch.load(hift_model, map_location=self.device)) self.hift.to(self.device).eval() + + def token2wav( + self, + token, + prompt_token, + prompt_feat, + embedding, + session_id, + finalize=False, + speed=1.0, + ): + if self.flow_cache_dict.get(session_id) is None: + self.mel_overlap_dict.set(session_id, torch.zeros(1, 80, 0)) + self.flow_cache_dict.set(session_id, torch.zeros(1, 80, 0, 2)) + + tts_mel, flow_cache = self.flow.inference( + token=token.to(self.device), + token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), + prompt_token=prompt_token.to(self.device), + prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to( + self.device + ), + prompt_feat=prompt_feat.to(self.device), + prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device), + embedding=embedding.to(self.device), + flow_cache=self.flow_cache_dict.get(session_id), + ) + self.flow_cache_dict.set(session_id, flow_cache) + + # mel overlap fade in out + if self.mel_overlap_dict.get(session_id).shape[2] != 0: + tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict.get(session_id), self.mel_window) + + hift_cache_source = None + if self.hift_cache_dict.get(session_id) is not None: + # append hift cache + hift_cache_mel, hift_cache_source = ( + self.hift_cache_dict.get(session_id)["mel"], + self.hift_cache_dict.get(session_id)["source"], + ) + tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2) + else: + hift_cache_source = torch.zeros(1, 1, 0) + + # keep overlap mel and hift cache + if finalize is False: + self.mel_overlap_dict.set(session_id, tts_mel[:, :, -self.mel_overlap_len :]) + + tts_mel = tts_mel[:, :, : -self.mel_overlap_len] + tts_speech, tts_source = self.hift.inference( + mel=tts_mel, cache_source=hift_cache_source + ) + + if self.hift_cache_dict.get(session_id) is not None: + tts_speech = fade_in_out( + tts_speech, self.hift_cache_dict.get(session_id)["speech"], self.speech_window + ) + self.hift_cache_dict.set( + session_id, + { + "mel": tts_mel[:, :, -self.mel_cache_len :], + "source": tts_source[:, :, -self.source_cache_len :], + "speech": tts_speech[:, -self.source_cache_len :], + }, + ) + + tts_speech = tts_speech[:, : -self.source_cache_len] + + logging.info("tts_speech: {}".format(tts_speech.shape)) + else: # finalize + if speed != 1.0: + assert ( + self.hift_cache_dict.get(session_id) is None + ), "speed change only support non-stream 
inference mode" + tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode="linear") + tts_speech, tts_source = self.hift.inference( + mel=tts_mel, cache_source=hift_cache_source + ) + if self.hift_cache_dict.get(session_id) is not None: + tts_speech = fade_in_out( + tts_speech, self.hift_cache_dict.get(session_id)["speech"], self.speech_window + ) + + self.mel_overlap_dict.pop(session_id) + self.hift_cache_dict.pop(session_id) + self.flow_cache_dict.pop(session_id) + logging.info("finalize tts_speech: {}".format(tts_speech.shape)) + + return tts_speech.cpu() diff --git a/cosyvoice/flow/flow.py b/cosyvoice/flow/flow.py index ab9a812..a4fa4e5 100644 --- a/cosyvoice/flow/flow.py +++ b/cosyvoice/flow/flow.py @@ -121,9 +121,7 @@ def forward( conds = conds.transpose(1, 2) mask = (~make_pad_mask(feat_len)).to(h) - feat = F.interpolate( - feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest" - ).squeeze(dim=1) + feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1) loss, _ = self.decoder.compute_loss( feat.transpose(1, 2).contiguous(), mask.unsqueeze(1), @@ -143,6 +141,7 @@ def inference( prompt_feat, prompt_feat_len, embedding, + flow_cache=None, ): assert token.shape[0] == 1 # xvec projection @@ -159,11 +158,14 @@ def inference( token = self.input_embedding(torch.clamp(token, min=0)) h, _ = self.encoder.inference(token, token_len) h = self.encoder_proj(h) - mel_len1, mel_len2 = prompt_feat.shape[1], int( - token_len2 - / self.input_frame_rate - * self.mel_feat_conf["sampling_rate"] - / self.mel_feat_conf["hop_size"] + mel_len1, mel_len2 = ( + prompt_feat.shape[1], + int( + token_len2 + / self.input_frame_rate + * self.mel_feat_conf["sampling_rate"] + / self.mel_feat_conf["hop_size"] + ), ) h, _ = self.length_regulator.inference( @@ -174,23 +176,21 @@ def inference( ) # get conditions - conds = torch.zeros( - [1, mel_len1 + mel_len2, self.output_size], device=token.device - ) + conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device) conds[:, :mel_len1] = prompt_feat conds = conds.transpose(1, 2) # mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h) - mask = torch.ones( - [1, mel_len1 + mel_len2], device=h.device, dtype=torch.bfloat16 - ) - feat = self.decoder( + mask = torch.ones([1, mel_len1 + mel_len2], device=h.device, dtype=torch.bfloat16) + feat, flow_cache = self.decoder( mu=h.transpose(1, 2).contiguous(), mask=mask.unsqueeze(1), spks=embedding, cond=conds, n_timesteps=10, + prompt_len=mel_len1, + flow_cache=flow_cache, ) feat = feat[:, :, mel_len1:] assert feat.shape[2] == mel_len2 - return feat + return feat.float(), flow_cache diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py index d29673f..7803eab 100644 --- a/cosyvoice/flow/flow_matching.py +++ b/cosyvoice/flow/flow_matching.py @@ -51,6 +51,9 @@ def forward( temperature=1.0, spks=None, cond=None, + prompt_len=0, + # flow_cache=torch.zeros(1, 80, 0, 2), + flow_cache=None, ): """Forward diffusion @@ -69,13 +72,24 @@ def forward( sample: generated mel-spectrogram shape: (batch_size, n_feats, mel_timesteps) """ - z = torch.randn_like(mu) * temperature + z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature + + if flow_cache is not None: + cache_size = flow_cache.shape[2] + # fix prompt and overlap part mu and z + if cache_size != 0: + z[:, :, :cache_size] = flow_cache[:, :, :, 0] + mu[:, :, :cache_size] = flow_cache[:, :, :, 1] + z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2) 
+ mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2) + flow_cache = torch.stack([z_cache, mu_cache], dim=-1) + t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype) if self.t_scheduler == "cosine": t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) return self.solve_euler( z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond - ) + ), flow_cache @torch.inference_mode() def capture_inference(self, seq_len_to_capture=list(range(128, 512, 8))): @@ -96,9 +110,7 @@ def capture_inference(self, seq_len_to_capture=list(range(128, 512, 8))): static_mask = torch.ones( 1, 1, seq_len, device=torch.device("cuda"), dtype=torch.bfloat16 ) - static_spks = torch.randn( - 1, 80, device=torch.device("cuda"), dtype=torch.bfloat16 - ) + static_spks = torch.randn(1, 80, device=torch.device("cuda"), dtype=torch.bfloat16) static_cond = torch.randn( 1, 80, seq_len, device=torch.device("cuda"), dtype=torch.float32 ) @@ -231,9 +243,7 @@ def _solve_euler_impl(self, x, t_span, mu, mask, spks, cond): mu_double = torch.cat([mu, torch.zeros_like(mu)], dim=0) t_double = torch.cat([t, t], dim=0) spks_double = ( - torch.cat([spks, torch.zeros_like(spks)], dim=0) - if spks is not None - else None + torch.cat([spks, torch.zeros_like(spks)], dim=0) if spks is not None else None ) cond_double = torch.cat([cond, torch.zeros_like(cond)], dim=0) @@ -309,7 +319,5 @@ def compute_loss(self, x1, mask, mu, spks=None, cond=None): cond = cond * cfg_mask.view(-1, 1, 1) pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond) - loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / ( - torch.sum(mask) * u.shape[1] - ) + loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1]) return loss, y diff --git a/cosyvoice/hifigan/generator.py b/cosyvoice/hifigan/generator.py index 4d02c03..f27596a 100644 --- a/cosyvoice/hifigan/generator.py +++ b/cosyvoice/hifigan/generator.py @@ -146,17 +146,15 @@ def forward(self, f0): :return: [B, 1, sample_len] """ - F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to( - f0.device - ) + F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device) for i in range(self.harmonic_num + 1): F_mat[:, i : i + 1, :] = f0 * (i + 1) / self.sampling_rate theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1) u_dist = Uniform(low=-np.pi, high=np.pi) - phase_vec = u_dist.sample( - sample_shape=(f0.size(0), self.harmonic_num + 1, 1) - ).to(F_mat.device) + phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to( + F_mat.device + ) phase_vec[:, 0, :] = 0 # generate sine waveforms @@ -322,9 +320,7 @@ def __init__( ): if u == 1: self.source_downs.append( - Conv1d( - istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1 - ) + Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1) ) else: self.source_downs.append( @@ -337,21 +333,15 @@ def __init__( ) ) - self.source_resblocks.append( - ResBlock(base_channels // (2 ** (i + 1)), k, d) - ) + self.source_resblocks.append(ResBlock(base_channels // (2 ** (i + 1)), k, d)) self.resblocks = nn.ModuleList() for i in range(len(self.ups)): ch = base_channels // (2 ** (i + 1)) - for _, (k, d) in enumerate( - zip(resblock_kernel_sizes, resblock_dilation_sizes) - ): + for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): self.resblocks.append(ResBlock(ch, k, d)) - self.conv_post = weight_norm( - Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3) - ) + self.conv_post = 
weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3)) self.ups.apply(init_weights) self.conv_post.apply(init_weights) self.reflection_pad = nn.ReflectionPad1d((1, 0)) @@ -491,6 +481,11 @@ def inference( curr_seq_len = mel.shape[2] f0 = self.f0_predictor(mel) s = self._f02source(f0) + + # use cache_source to avoid glitch + if cache_source is not None and cache_source.shape[2] != 0: + s[:, :, : cache_source.shape[2]] = cache_source + s_stft_real, s_stft_imag = self._stft(s.squeeze(1)) s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1) @@ -533,13 +528,9 @@ def inference( @torch.inference_mode() def capture_inference(self, seq_len_to_capture=[64, 128, 256, 512, 1024]): start_time = time.time() - print( - f"capture inference for HiFTGenerator with seq_len_to_capture: {seq_len_to_capture}" - ) + print(f"capture inference for HiFTGenerator with seq_len_to_capture: {seq_len_to_capture}") for seq_len in seq_len_to_capture: - mel = torch.randn( - 1, 80, seq_len, device=torch.device("cuda"), dtype=torch.float32 - ) + mel = torch.randn(1, 80, seq_len, device=torch.device("cuda"), dtype=torch.float32) f0 = self.f0_predictor(mel) s = self._f02source(f0) s_stft_real, s_stft_imag = self._stft(s.squeeze(1)) diff --git a/cosyvoice/utils/common.py b/cosyvoice/utils/common.py index e9611b6..c038452 100644 --- a/cosyvoice/utils/common.py +++ b/cosyvoice/utils/common.py @@ -16,6 +16,7 @@ """Unility functions for Transformer.""" import random +import threading from typing import List import numpy as np @@ -88,9 +89,7 @@ def th_accuracy( pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1) ).argmax(2) mask = pad_targets != ignore_label - numerator = torch.sum( - pad_pred.masked_select(mask) == pad_targets.masked_select(mask) - ) + numerator = torch.sum(pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) denominator = torch.sum(mask) return (numerator / denominator).detach() @@ -129,9 +128,7 @@ def ras_sampling( def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25): prob, indices = [], [] cum_prob = 0.0 - sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort( - descending=True, stable=True - ) + sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True) for i in range(len(sorted_idx)): # sampling both top-p and numbers. 
if cum_prob < top_p and len(prob) < top_k: @@ -167,3 +164,22 @@ def set_all_random_seed(seed): np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) + + +class ThreadSafeDict: + def __init__(self): + self._dict = {} + # use a re-entrant RLock to avoid deadlocks + self._lock = threading.RLock() + + def get(self, key, default=None): + with self._lock: + return self._dict.get(key, default) + + def set(self, key, value): + with self._lock: + self._dict[key] = value + + def pop(self, key): + with self._lock: + return self._dict.pop(key, None) diff --git a/streamer.py b/streamer.py new file mode 100644 index 0000000..57bce40 --- /dev/null +++ b/streamer.py @@ -0,0 +1,41 @@ +from queue import Queue + +from transformers.generation.streamers import BaseStreamer + + +class TokenStreamer(BaseStreamer): + def __init__(self, skip_prompt: bool = False, timeout=None): + self.skip_prompt = skip_prompt + + # variables used in the streaming process + self.token_queue = Queue() + self.stop_signal = None + self.next_tokens_are_prompt = True + self.timeout = timeout + + def put(self, value): + if len(value.shape) > 1 and value.shape[0] > 1: + raise ValueError("TokenStreamer only supports batch size 1") + elif len(value.shape) > 1: + value = value[0] + # print(value) + + if self.skip_prompt and self.next_tokens_are_prompt: + self.next_tokens_are_prompt = False + return + + for token in value.tolist(): + self.token_queue.put(token) + + def end(self): + self.token_queue.put(self.stop_signal) + + def __iter__(self): + return self + + def __next__(self): + value = self.token_queue.get(timeout=self.timeout) + if value == self.stop_signal: + raise StopIteration() + else: + return value diff --git a/tts.py b/tts.py index 0122f22..8533b19 100644 --- a/tts.py +++ b/tts.py @@ -1,6 +1,11 @@ +import logging +import math import os import re import json +from threading import Thread, Lock +import uuid + import torchaudio import torch @@ -9,12 +14,12 @@ from transformers.generation.utils import LogitsProcessorList from cosyvoice.cli.cosyvoice import CosyVoice +from cosyvoice.utils.common import ThreadSafeDict +from streamer import TokenStreamer class RepetitionAwareLogitsProcessor(LogitsProcessor): - def __call__( - self, input_ids: torch.LongTensor, scores: torch.FloatTensor - ) -> torch.FloatTensor: + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: window_size = 10 threshold = 0.1 @@ -36,16 +41,56 @@ def __init__( self, model_path, encoder, + device_map: str | dict | None = None, + stream_factor: int = 2, + stream_scale_factor: float = 1.0, + max_stream_factor: int = 2, + token_overlap_len: int = 20, + **kwargs, ): + # fast path to check params + # rtf and decoding related + assert ( + stream_factor >= 2 + ), "stream_factor must be >= 2; larger values improve speech quality at the cost of a higher RTF" + self.stream_factor = stream_factor + self.max_stream_factor = max_stream_factor + assert ( + stream_scale_factor >= 1.0 + ), "stream_scale_factor should be at least 1.0; tune it according to your actual RTF" + self.stream_scale_factor = stream_scale_factor # scale speed + assert ( + token_overlap_len >= 0 + ), "token_overlap_len should be non-negative; tune it according to your actual RTF" + self.token_overlap_len = token_overlap_len + + # session context dict with lock; a dedicated session class may be cleaner + self.session_lm_generated_ids = ThreadSafeDict() # session_id: ids(ptr) + + # load optimus_ths for flash attention, make sure LD_LIBRARY_PATH has `nvidia/cuda_nvrtc/lib` #
if not, please manually set LD_LIBRARY_PATH=xxx/python3.10/site-packages/nvidia/cuda_nvrtc/lib try: if torch.__version__ >= "2.5": - torch.ops.load_library(os.path.join(model_path, 'lib/liboptimus_ths-torch2.5-cu124.cpython-310-x86_64-linux-gnu.so')) + torch.ops.load_library( + os.path.join( + model_path, + "lib/liboptimus_ths-torch2.5-cu124.cpython-310-x86_64-linux-gnu.so", + ) + ) elif torch.__version__ >= "2.3": - torch.ops.load_library(os.path.join(model_path, 'lib/liboptimus_ths-torch2.3-cu121.cpython-310-x86_64-linux-gnu.so')) + torch.ops.load_library( + os.path.join( + model_path, + "lib/liboptimus_ths-torch2.3-cu121.cpython-310-x86_64-linux-gnu.so", + ) + ) elif torch.__version__ >= "2.2": - torch.ops.load_library(os.path.join(model_path, 'lib/liboptimus_ths-torch2.2-cu121.cpython-310-x86_64-linux-gnu.so')) + torch.ops.load_library( + os.path.join( + model_path, + "lib/liboptimus_ths-torch2.2-cu121.cpython-310-x86_64-linux-gnu.so", + ) + ) print("Load optimus_ths successfully and flash attn would be enabled") except Exception as err: print(f"Fail to load optimus_ths and flash attn is disabled: {err}") @@ -53,17 +98,18 @@ def __init__( self.llm = AutoModelForCausalLM.from_pretrained( model_path, torch_dtype=torch.bfloat16, - device_map="cuda", + device_map="auto" if not device_map else device_map, trust_remote_code=True, + **kwargs, ) - self.tokenizer = AutoTokenizer.from_pretrained( - model_path, trust_remote_code=True - ) + self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) self.common_cosy_model = CosyVoice( - os.path.join(model_path, "CosyVoice-300M-25Hz") + os.path.join(model_path, "CosyVoice-300M-25Hz"), + token_overlap_len=token_overlap_len, ) self.music_cosy_model = CosyVoice( - os.path.join(model_path, "CosyVoice-300M-25Hz-Music") + os.path.join(model_path, "CosyVoice-300M-25Hz-Music"), + token_overlap_len=token_overlap_len, ) self.encoder = encoder self.sys_prompt_dict = { @@ -75,35 +121,9 @@ def __init__( self.register_speakers() def __call__(self, text: str, prompt_speaker: str, clone_dict: dict | None = None): - if clone_dict: - clone_prompt_code, clone_prompt_token, clone_prompt_token_len, clone_speech_feat, clone_speech_feat_len, clone_speech_embedding = ( - self.preprocess_prompt_wav(clone_dict['wav_path']) - ) - prompt_speaker = clone_dict['speaker'] - self.speakers_info[prompt_speaker] = { - "prompt_text": clone_dict['prompt_text'], - "prompt_code": clone_prompt_code, - "cosy_speech_feat": clone_speech_feat.to(torch.bfloat16), - "cosy_speech_feat_len": clone_speech_feat_len, - "cosy_speech_embedding": clone_speech_embedding.to(torch.bfloat16), - "cosy_prompt_token": clone_prompt_token, - "cosy_prompt_token_len": clone_prompt_token_len, - } - - instruction_name = self.detect_instruction_name(text) - prompt_speaker_info = self.speakers_info[prompt_speaker] - - if instruction_name in ("RAP", "哼唱"): - if not clone_dict: - prompt_speaker_info = self.speakers_info[ - f"{prompt_speaker}{instruction_name}" - ] - cosy_model = self.music_cosy_model - else: - cosy_model = self.common_cosy_model - - if clone_dict: - prompt_speaker = '' + prompt_speaker, prompt_speaker_info, cosy_model = self.preprocess_prompt( + text, prompt_speaker, clone_dict=clone_dict + ) token_ids = self.tokenize( text, @@ -134,14 +154,21 @@ def __call__(self, text: str, prompt_speaker: str, clone_dict: dict | None = Non def register_speakers(self): self.speakers_info = {} - with open("speakers/speakers_info.json", "r") as f: + cur_dir = 
os.path.dirname(os.path.abspath(__file__)) + file_path: str = os.path.join(cur_dir, "speakers/speakers_info.json") + with open(file_path, "r") as f: speakers_info = json.load(f) for speaker_id, prompt_text in speakers_info.items(): - prompt_wav_path = f"speakers/{speaker_id}_prompt.wav" - prompt_code, prompt_token, prompt_token_len, speech_feat, speech_feat_len, speech_embedding = ( - self.preprocess_prompt_wav(prompt_wav_path) - ) + prompt_wav_path = os.path.join(cur_dir, f"speakers/{speaker_id}_prompt.wav") + ( + prompt_code, + prompt_token, + prompt_token_len, + speech_feat, + speech_feat_len, + speech_embedding, + ) = self.preprocess_prompt_wav(prompt_wav_path) self.speakers_info[speaker_id] = { "prompt_text": prompt_text, @@ -162,9 +189,7 @@ def detect_instruction_name(self, text): instruction_name = instruction.strip("()()") return instruction_name - def tokenize( - self, text: str, prompt_text: str, prompt_speaker: str, prompt_code: list - ): + def tokenize(self, text: str, prompt_text: str, prompt_speaker: str, prompt_code: list): rap_or_vocal = self.detect_instruction_name(text) in ("RAP", "哼唱") if rap_or_vocal: @@ -210,23 +235,21 @@ def tokenize( ) return history - def preprocess_prompt_wav(self, prompt_wav_path : str): + def preprocess_prompt_wav(self, prompt_wav_path: str): prompt_wav, prompt_wav_sr = torchaudio.load(prompt_wav_path) if prompt_wav.shape[0] > 1: prompt_wav = prompt_wav.mean(dim=0, keepdim=True) # convert multi-channel audio to mono - prompt_wav_16k = torchaudio.transforms.Resample( - orig_freq=prompt_wav_sr, new_freq=16000 - )(prompt_wav) - prompt_wav_22k = torchaudio.transforms.Resample( - orig_freq=prompt_wav_sr, new_freq=22050 - )(prompt_wav) - - speech_feat, speech_feat_len = ( - self.common_cosy_model.frontend._extract_speech_feat(prompt_wav_22k) + prompt_wav_16k = torchaudio.transforms.Resample(orig_freq=prompt_wav_sr, new_freq=16000)( + prompt_wav ) - speech_embedding = self.common_cosy_model.frontend._extract_spk_embedding( - prompt_wav_16k + prompt_wav_22k = torchaudio.transforms.Resample(orig_freq=prompt_wav_sr, new_freq=22050)( + prompt_wav + ) + + speech_feat, speech_feat_len = self.common_cosy_model.frontend._extract_speech_feat( + prompt_wav_22k ) + speech_embedding = self.common_cosy_model.frontend._extract_spk_embedding(prompt_wav_16k) prompt_code, _, _ = self.encoder.wav2token(prompt_wav, prompt_wav_sr) prompt_token = torch.tensor([prompt_code], dtype=torch.long) - 65536 @@ -239,4 +262,147 @@ def preprocess_prompt_wav(self, prompt_wav_path : str): speech_feat, speech_feat_len, speech_embedding, - ) \ No newline at end of file + ) + + def preprocess_prompt(self, text: str, prompt_speaker: str, clone_dict: dict | None = None): + if clone_dict: + ( + clone_prompt_code, + clone_prompt_token, + clone_prompt_token_len, + clone_speech_feat, + clone_speech_feat_len, + clone_speech_embedding, + ) = self.preprocess_prompt_wav(clone_dict["wav_path"]) + prompt_speaker = clone_dict["speaker"] + self.speakers_info[prompt_speaker] = { + "prompt_text": clone_dict["prompt_text"], + "prompt_code": clone_prompt_code, + "cosy_speech_feat": clone_speech_feat.to(torch.bfloat16), + "cosy_speech_feat_len": clone_speech_feat_len, + "cosy_speech_embedding": clone_speech_embedding.to(torch.bfloat16), + "cosy_prompt_token": clone_prompt_token, + "cosy_prompt_token_len": clone_prompt_token_len, + } + + instruction_name = self.detect_instruction_name(text) + prompt_speaker_info = self.speakers_info[prompt_speaker] + + if instruction_name in ("RAP", "哼唱"): + if not clone_dict: + prompt_speaker_info 
= self.speakers_info[f"{prompt_speaker}{instruction_name}"] + cosy_model = self.music_cosy_model + else: + cosy_model = self.common_cosy_model + + return prompt_speaker, prompt_speaker_info, cosy_model + + @torch.inference_mode() + def batch_stream( + self, + text: str, + prompt_speaker: str, + clone_dict: dict | None = None, + session_id: str = str(uuid.uuid4()), + ): + """ + - step1 lm stream generate token + - batch size to gen waveform + - when max_stream_factor > stream_factor, dynamic batch size to gen waveform + - when max_stream_factor <= stream_factor, static batch size to gen waveform + - flow: audio vq tokens to mel + - hift: mel to waveform + """ + prompt_speaker, prompt_speaker_info, cosy_model = self.preprocess_prompt( + text, prompt_speaker, clone_dict=clone_dict + ) + output_audio_sample_rate = cosy_model.model.hift.sampling_rate + + token_ids = self.tokenize( + text, + prompt_speaker_info["prompt_text"], + prompt_speaker, + prompt_speaker_info["prompt_code"], + ) + + # session streamer + streamer = TokenStreamer(skip_prompt=True) + + generation_kwargs = dict( + input_ids=torch.tensor([token_ids]).to(torch.long).to("cuda"), + eos_token_id=3, + streamer=streamer, + max_length=8192, + temperature=0.7, + do_sample=True, + logits_processor=LogitsProcessorList([RepetitionAwareLogitsProcessor()]), + ) + # print("generation_kwargs", generation_kwargs) + + thread = Thread(target=self.llm.generate, kwargs=generation_kwargs) + thread.start() + + self.session_lm_generated_ids.set(session_id, []) + + max_batch_size = math.ceil(self.max_stream_factor * cosy_model.model.flow.input_frame_rate) + batch_size = math.ceil(self.stream_factor * cosy_model.model.flow.input_frame_rate) + logging.info(f"init batch_size: {batch_size} max_batch_size: {max_batch_size}") + for token_id in streamer: + # print(token_id, end=",", flush=True) + if token_id == 3: # skip <|EOT|>, break + break + self.session_lm_generated_ids.get(session_id).append(token_id) + # if len(self.session_lm_generated_ids.get(session_id)) % batch_size == 0: + if ( + len(self.session_lm_generated_ids.get(session_id)) + >= batch_size + self.token_overlap_len + ): + batch = ( + torch.tensor( + self.session_lm_generated_ids.get(session_id)[ + : batch_size + self.token_overlap_len + ] + ) + .unsqueeze(0) + .to(cosy_model.model.device) + - 65536 + ) # [T] -> [1,T] + # Process each batch + sub_tts_speech = cosy_model.model.token2wav( + batch, + prompt_speaker_info["cosy_prompt_token"], + prompt_speaker_info["cosy_speech_feat"].to(torch.bfloat16), + prompt_speaker_info["cosy_speech_embedding"].to(torch.bfloat16), + session_id, + finalize=False, + ) + yield {"tts_speech": sub_tts_speech, "sample_rate": output_audio_sample_rate} + self.session_lm_generated_ids.set( + session_id, self.session_lm_generated_ids.get(session_id)[batch_size:] + ) + # increase token_hop_len for better speech quality + batch_size = min(max_batch_size, int(batch_size * self.stream_scale_factor)) + logging.info(f"increase batch_size: {batch_size}") + + if len(self.session_lm_generated_ids.get(session_id)) == 0: # end to finalize + self.session_lm_generated_ids.set(session_id, [65536]) + + batch = ( + torch.tensor(self.session_lm_generated_ids.get(session_id)) + .unsqueeze(0) + .to(cosy_model.model.device) + - 65536 + ) # [T] -> [1,T] + # Process each batch + sub_tts_speech = cosy_model.model.token2wav( + batch, + prompt_speaker_info["cosy_prompt_token"], + prompt_speaker_info["cosy_speech_feat"].to(torch.bfloat16), + 
prompt_speaker_info["cosy_speech_embedding"].to(torch.bfloat16), + session_id, + finalize=True, + ) + yield {"tts_speech": sub_tts_speech, "sample_rate": output_audio_sample_rate} + + self.session_lm_generated_ids.pop(session_id) + torch.cuda.empty_cache() diff --git a/tts_inference_stream.py b/tts_inference_stream.py new file mode 100644 index 0000000..08869a7 --- /dev/null +++ b/tts_inference_stream.py @@ -0,0 +1,78 @@ +import os +import argparse + +import torchaudio + +from tokenizer import StepAudioTokenizer +from utils import merge_tensors +from tts import StepAudioTTS + + +def main(): + parser = argparse.ArgumentParser(description="StepAudio Stream Inference") + parser.add_argument("--model-path", type=str, required=True, help="Base path for model files") + parser.add_argument( + "--synthesis-type", type=str, default="tts", help="Use tts or Clone for Synthesis" + ) + parser.add_argument( + "--output-path", type=str, required=True, help="Output path for synthesis audios" + ) + parser.add_argument( + "--stream", type=str, default="static_batch", help="Synthesis audios with streaming" + ) + parser.add_argument( + "--stream-factor", type=int, default=2, help="Synthesis audios stream factor" + ) + parser.add_argument( + "--stream-scale-factor", type=float, default=1.0, help="Synthesis audios stream scale factor" + ) + parser.add_argument( + "--max-stream-factor", type=int, default=2, help="Synthesis audios max stream factor" + ) + parser.add_argument( + "--token-overlap-len", type=int, default=20, help="Synthesis audios token overlap len" + ) + args = parser.parse_args() + os.makedirs(f"{args.output_path}", exist_ok=True) + + encoder = StepAudioTokenizer(f"{args.model_path}/Step-Audio-Tokenizer") + tts_engine = StepAudioTTS( + f"{args.model_path}/Step-Audio-TTS-3B", + encoder, + stream_factor=args.stream_factor, + stream_scale_factor=args.stream_scale_factor, + max_stream_factor=args.max_stream_factor, + token_overlap_len=args.token_overlap_len, + ) + + if args.synthesis_type == "tts": + text = "(RAP)君不见黄河之水天上来,奔流到海不复回。君不见高堂明镜悲白发,朝如青丝暮成雪。人生得意须尽欢,莫使金樽空对月。天生我材必有用,千金散尽还复来。" + text = os.getenv("TTS_TEXT", text) + batch_stream = tts_engine.batch_stream(text, "Tingting") + sub_tts_speechs = [] + sr = 22050 + for item in batch_stream: + sr = item["sample_rate"] + sub_tts_speechs.append(item["tts_speech"]) + output_audio = merge_tensors(sub_tts_speechs) # [1,T] + torchaudio.save(f"{args.output_path}/output_tts_stream.wav", output_audio, sr) + else: + clone_speaker = { + "speaker": "test", + "prompt_text": "叫做秋风起蟹脚痒,啊,什么意思呢?就是说这秋风一起啊,螃蟹就该上市了。", + "wav_path": "examples/prompt_wav_yuqian.wav", + } + text_clone = "万物之始,大道至简,衍化至繁。君不见黄河之水天上来,奔流到海不复回。君不见高堂明镜悲白发,朝如青丝暮成雪。人生得意须尽欢,莫使金樽空对月。天生我材必有用,千金散尽还复来。" + text_clone = os.getenv("TTS_TEXT", text_clone) + batch_stream = tts_engine.batch_stream(text_clone, "", clone_speaker) + sub_tts_speechs = [] + sr = 22050 + for item in batch_stream: + sr = item["sample_rate"] + sub_tts_speechs.append(item["tts_speech"]) + output_audio = merge_tensors(sub_tts_speechs) # [1,T] + torchaudio.save(f"{args.output_path}/output_clone_stream.wav", output_audio, sr) + + +if __name__ == "__main__": + main() diff --git a/utils.py b/utils.py index 2c98513..f4cb260 100644 --- a/utils.py +++ b/utils.py @@ -33,18 +33,14 @@ def trim_silence(audio, sr, keep_left_time=0.05, keep_right_time=0.22, hop_size= if start_idx > 0: trim_wav = trim_wav[start_idx:] else: - trim_wav = np.pad( - trim_wav, (abs(start_idx), 0), mode="constant", constant_values=0.0 - ) + trim_wav = np.pad(trim_wav, 
(abs(start_idx), 0), mode="constant", constant_values=0.0) wav_len = len(trim_wav) out_len = int(num_frames * hop_size + (keep_left_time + keep_right_time) * sr) if out_len < wav_len: trim_wav = trim_wav[:out_len] else: - trim_wav = np.pad( - trim_wav, (0, (out_len - wav_len)), mode="constant", constant_values=0.0 - ) + trim_wav = np.pad(trim_wav, (0, (out_len - wav_len)), mode="constant", constant_values=0.0) return trim_wav @@ -109,9 +105,7 @@ def audio_resample(audio16bit_torch, result_sr, target_sample_rate): def norm_audio(audio16bit_torch): # peak-normalize directly to the int16 range audio16bit_torch = audio16bit_torch.numpy() - audio16bit_torch = ( - audio16bit_torch / np.abs(audio16bit_torch).max() * 32767 - ).astype(np.int16) + audio16bit_torch = (audio16bit_torch / np.abs(audio16bit_torch).max() * 32767).astype(np.int16) audio16bit_torch = torch.from_numpy(audio16bit_torch) return audio16bit_torch @@ -142,8 +136,7 @@ def energy_norm_fn(wav): def get_audio_tokens(audio_tokens: str) -> list[int]: audio_tokens = audio_tokens.split("><audio_") audio_tokens = [ - int(token.replace("<audio_", "").replace(">", "")) + 65536 - for token in audio_tokens + int(token.replace("<audio_", "").replace(">", "")) + 65536 for token in audio_tokens ] return audio_tokens @@ -152,3 +145,65 @@ def load_audio(audio_path: str): audio_wav, sr = torchaudio.load(audio_path) audio_wav = audio_wav.mean(dim=0, keepdim=True) return audio_wav, sr + + +def splite_batches(tensor_audio_token_ids: torch.Tensor, batch_size: int): + """ + Split a tensor of audio token IDs into batches. + The input is expected to be a single sequence + with shape (1, sequence_length). + + Args: + tensor_audio_token_ids: A tensor of audio token IDs. + batch_size: The desired batch size. + + Returns: + A list of tensors, where each tensor is a batch of audio token IDs. + """ + sequence_length = tensor_audio_token_ids.shape[1] + num_batches = (sequence_length + batch_size - 1) // batch_size + batched_ids = [] + + for i in range(num_batches): + start_index = i * batch_size + end_index = min((i + 1) * batch_size, sequence_length) + batch = tensor_audio_token_ids[:, start_index:end_index] + batched_ids.append(batch) + + return batched_ids + + +def merge_tensors(sub_tts_speechs: list[torch.Tensor]): + """ + Merges a list of tensors into a single tensor. + sub_tts_speechs is expected to be a list of waveform chunks + that all have the same number of channels, + each shaped (num_channels, sequence_length), + but possibly of different lengths. + + Args: + sub_tts_speechs: A list of tensors. + + Returns: + A single tensor with all the sub tensors concatenated along the time dimension. + Returns None if the input list is empty or contains tensors with inconsistent shapes. + """ + if not sub_tts_speechs: + return None + + num_channels = sub_tts_speechs[0].shape[0] + total_length = sum(tensor.shape[1] for tensor in sub_tts_speechs) + merged_tensor = torch.empty( + num_channels, total_length, dtype=sub_tts_speechs[0].dtype, device=sub_tts_speechs[0].device + ) + current_position = 0 + + for tensor in sub_tts_speechs: + if tensor.shape[0] != num_channels: + print("Error: Tensors have inconsistent number of channels.") + return None + + merged_tensor[:, current_position : current_position + tensor.shape[1]] = tensor + current_position += tensor.shape[1] + + return merged_tensor
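
A note on the crossfade helper: `fade_in_out` is imported from cosyvoice/utils/common.py in cosyvoice/cli/model.py but its body is not part of this patch. For reviewers unfamiliar with it, here is a minimal sketch of the assumed behaviour (a Hamming-window crossfade where the rising half of the window weights the head of the incoming chunk and the falling half weights the cached tail of the previous chunk); the name `fade_in_out_sketch` and this exact formulation are illustrative, the upstream implementation may differ in details.

import numpy as np
import torch

def fade_in_out_sketch(fade_in_chunk: torch.Tensor, fade_out_cache: torch.Tensor, window: np.ndarray) -> torch.Tensor:
    # window has length 2 * overlap_len (e.g. np.hamming(2 * mel_overlap_len) in CosyVoiceModel.__init__);
    # blend the first overlap_len frames of the new chunk with the cached tail of the previous chunk
    # so that chunk boundaries do not produce audible clicks.
    overlap_len = window.shape[0] // 2
    win = torch.from_numpy(window).to(fade_in_chunk)  # match dtype and device of the chunk
    out = fade_in_chunk.clone()
    out[..., :overlap_len] = (
        fade_in_chunk[..., :overlap_len] * win[:overlap_len]
        + fade_out_cache[..., -overlap_len:] * win[overlap_len:]
    )
    return out

The `...` indexing lets the same sketch apply to both the mel case (shape [1, 80, T] with `mel_window`) and the waveform case (shape [1, T] with `speech_window`) used in `token2wav`.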
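
The per-session caches (`mel_overlap_dict`, `flow_cache_dict`, `hift_cache_dict` in CosyVoiceModel, and `session_lm_generated_ids` in StepAudioTTS) are keyed by `session_id` so that several requests can stream through the same model instance concurrently. A small, self-contained illustration of that access pattern using only the `ThreadSafeDict` added in cosyvoice/utils/common.py; the worker function is hypothetical and stands in for a per-request streaming loop.

import threading
import uuid

from cosyvoice.utils.common import ThreadSafeDict

session_state = ThreadSafeDict()

def hypothetical_session_worker(n_chunks: int) -> None:
    # every request owns a unique key, so concurrent workers never touch each other's cache
    session_id = str(uuid.uuid4())
    session_state.set(session_id, [])
    for chunk_idx in range(n_chunks):
        session_state.get(session_id).append(chunk_idx)  # mutate only this session's entry
    assert session_state.get(session_id) == list(range(n_chunks))
    session_state.pop(session_id)  # always clean up, mirroring the finalize path of token2wav

threads = [threading.Thread(target=hypothetical_session_worker, args=(100,)) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()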
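
For intuition about the `stream_factor` / `token_overlap_len` trade-off that the asserts in `StepAudioTTS.__init__` refer to: `batch_stream` only emits its first waveform chunk once `ceil(stream_factor * input_frame_rate) + token_overlap_len` LM tokens have accumulated, so larger values give the flow decoder more context and a bigger overlap (better quality) at the price of a later first chunk and a higher RTF. A rough back-of-the-envelope helper follows; the 25 Hz frame rate matches the CosyVoice-300M-25Hz checkpoints loaded above, while the LM throughput figure is purely an assumed example.

import math

def first_chunk_token_budget(stream_factor: int = 2, token_overlap_len: int = 20, input_frame_rate: int = 25) -> int:
    # mirrors batch_stream: waveform decoding starts once batch_size + token_overlap_len tokens are buffered
    batch_size = math.ceil(stream_factor * input_frame_rate)
    return batch_size + token_overlap_len

tokens_needed = first_chunk_token_budget()   # 70 tokens with the defaults above
audio_covered_s = tokens_needed / 25         # roughly 2.8 s of audio covered by the first decode
assumed_lm_tokens_per_s = 60                 # hypothetical LM generation speed, not a measured number
print(tokens_needed, audio_covered_s, tokens_needed / assumed_lm_tokens_per_s)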