From 0c27a8d1af2e0e94ac402da95535e4790e123fc8 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 9 Sep 2025 18:21:45 +0800 Subject: [PATCH 01/12] support batching and trt for token2wav --- cosyvoice2/flow/decoder_dit.py | 1 + cosyvoice2/flow/flow.py | 29 +-- cosyvoice2/flow/flow_matching.py | 49 ++++- token2wav_batch.py | 336 +++++++++++++++++++++++++++++++ 4 files changed, 393 insertions(+), 22 deletions(-) create mode 100644 token2wav_batch.py diff --git a/cosyvoice2/flow/decoder_dit.py b/cosyvoice2/flow/decoder_dit.py index cb80edc..b3615e3 100644 --- a/cosyvoice2/flow/decoder_dit.py +++ b/cosyvoice2/flow/decoder_dit.py @@ -491,6 +491,7 @@ def forward(self, x, mask, mu, t, spks=None, cond=None): # time t = self.t_embedder(t).unsqueeze(1) # (b, 1, c) + x = pack([x, mu], "b * t")[0] if spks is not None: spks = repeat(spks, "b c -> b c t", t=x.shape[-1]) diff --git a/cosyvoice2/flow/flow.py b/cosyvoice2/flow/flow.py index f252d9b..c9e4170 100644 --- a/cosyvoice2/flow/flow.py +++ b/cosyvoice2/flow/flow.py @@ -65,39 +65,42 @@ def scatter_cuda_graph(self, enable_cuda_graph: bool): def inference(self, token, token_len, - prompt_token, - prompt_token_len, + # prompt_token, + # prompt_token_len, prompt_feat, prompt_feat_len, embedding, n_timesteps: int = 10, ): - assert token.shape[0] == 1 + # assert token.shape[0] == 1 # xvec projection embedding = F.normalize(embedding, dim=1) embedding = self.spk_embed_affine_layer(embedding) # concat text and prompt_text - token_len = prompt_token_len + token_len - token = torch.concat([prompt_token, token], dim=1) + # token_len = prompt_token_len + token_len + # token = torch.concat([prompt_token, token], dim=1) mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding) token = self.input_embedding(torch.clamp(token, min=0)) * mask # token encode - h, _ = self.encoder.forward(token, token_len) + h, h_lengths = self.encoder.forward(token, token_len) h = self.encoder_proj(h) # condition - mel_len1 = prompt_feat.shape[1] - mel_len2 = h.shape[1] - prompt_feat.shape[1] + # mel_len1 = prompt_feat.shape[1] + # mel_len2 = h.shape[1] - prompt_feat.shape[1] conds = torch.zeros_like(h) - conds[:, :mel_len1] = prompt_feat + # conds[:, :mel_len1] = prompt_feat + for i, j in enumerate(prompt_feat_len): + conds[i, :j] = prompt_feat[i, :j] conds = conds.transpose(1, 2).contiguous() - mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h) + h_lengths = h_lengths.sum(dim=-1).squeeze(dim=1) + mask = (~make_pad_mask(h_lengths, max_len=h.shape[1])).to(h) feat = self.decoder.forward( mu=h.transpose(1, 2).contiguous(), @@ -107,9 +110,9 @@ def inference(self, n_timesteps=n_timesteps, ) - feat = feat[:, :, mel_len1:] - assert feat.shape[2] == mel_len2 - return feat + # feat = feat[:, :, mel_len1:] + # assert feat.shape[2] == mel_len2 + return feat.float(), h_lengths @torch.inference_mode() def setup_cache(self, diff --git a/cosyvoice2/flow/flow_matching.py b/cosyvoice2/flow/flow_matching.py index 900c71e..cbcfcfb 100644 --- a/cosyvoice2/flow/flow_matching.py +++ b/cosyvoice2/flow/flow_matching.py @@ -39,6 +39,36 @@ def scatter_cuda_graph(self, enable_cuda_graph: bool): if enable_cuda_graph: self.estimator._init_cuda_graph_all() + def forward_estimator(self, x, mask, mu, t, spks, cond): + if isinstance(self.estimator, torch.nn.Module): + return self.estimator(x, mask, mu, t, spks, cond) + else: + [estimator, stream], trt_engine = self.estimator.acquire_estimator() + # NOTE need to synchronize when switching stream + torch.cuda.current_stream().synchronize() + batch_size 
= x.size(0) + with stream: + estimator.set_input_shape('x', (batch_size, 80, x.size(2))) + estimator.set_input_shape('mask', (batch_size, 1, x.size(2))) + estimator.set_input_shape('mu', (batch_size, 80, x.size(2))) + estimator.set_input_shape('t', (batch_size,)) + estimator.set_input_shape('spks', (batch_size, 80)) + estimator.set_input_shape('cond', (batch_size, 80, x.size(2))) + data_ptrs = [x.contiguous().data_ptr(), + mask.contiguous().data_ptr(), + mu.contiguous().data_ptr(), + t.contiguous().data_ptr(), + spks.contiguous().data_ptr(), + cond.contiguous().data_ptr(), + x.data_ptr()] + for i, j in enumerate(data_ptrs): + estimator.set_tensor_address(trt_engine.get_tensor_name(i), j) + # run trt engine + assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True + torch.cuda.current_stream().synchronize() + self.estimator.release_estimator(estimator, stream) + return x + def solve_euler(self, x, t_span, mu, mask, spks, cond): """ Fixed euler solver for ODEs. @@ -55,7 +85,9 @@ def solve_euler(self, x, t_span, mu, mask, spks, cond): cond: Not used but kept for future purposes """ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] - t = t.unsqueeze(dim=0) + # t = t.unsqueeze(dim=0) + t_in = torch.zeros([x.shape[0] * 2], device=x.device, dtype=x.dtype) + assert self.inference_cfg_rate > 0, 'inference_cfg_rate better > 0' # constant during denoising @@ -65,18 +97,16 @@ def solve_euler(self, x, t_span, mu, mask, spks, cond): cond_in = torch.cat([cond, torch.zeros_like(cond)], dim=0) for step in range(1, len(t_span)): - x_in = torch.cat([x, x], dim=0) - t_in = torch.cat([t, t], dim=0) + t_in.fill_(t) - dphi_dt = self.estimator.forward( - x_in, - mask_in, - mu_in, - t_in, + dphi_dt = self.forward_estimator( + x_in, mask_in, + mu_in, t_in, spks_in, cond_in, ) + dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0) dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt) x = x + dt * dphi_dt @@ -88,7 +118,8 @@ def solve_euler(self, x, t_span, mu, mask, spks, cond): @torch.inference_mode() def forward(self, mu, mask, spks, cond, n_timesteps=10, temperature=1.0): - z = self.rand_noise[:, :, :mu.size(2)] * temperature + # z = self.rand_noise[:, :, :mu.size(2)] * temperature + z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype) # cosine scheduling t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) diff --git a/token2wav_batch.py b/token2wav_batch.py new file mode 100644 index 0000000..168cd64 --- /dev/null +++ b/token2wav_batch.py @@ -0,0 +1,336 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" Example Usage + CUDA_VISIBLE_DEVICES=0 \ + python3 token2wav.py --enable-trt || exit 1 +""" +import torch +# from flashcosyvoice.modules.flow import CausalMaskedDiffWithXvec +from flashcosyvoice.modules.hifigan import HiFTGenerator +from flashcosyvoice.utils.audio import mel_spectrogram +import torchaudio.compliance.kaldi as kaldi +import onnxruntime +import s3tokenizer +from torch.utils.data import DataLoader +from datasets import load_dataset +import torchaudio +import os +import logging +import argparse +import queue +import time + +from hyperpyyaml import load_hyperpyyaml +def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16): + import tensorrt as trt + logging.info("Converting onnx to trt...") + network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + logger = trt.Logger(trt.Logger.INFO) + builder = trt.Builder(logger) + network = builder.create_network(network_flags) + parser = trt.OnnxParser(network, logger) + config = builder.create_builder_config() + # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB + if fp16: + config.set_flag(trt.BuilderFlag.FP16) + profile = builder.create_optimization_profile() + # load onnx model + with open(onnx_model, "rb") as f: + if not parser.parse(f.read()): + for error in range(parser.num_errors): + print(parser.get_error(error)) + raise ValueError('failed to parse {}'.format(onnx_model)) + # set input shapes + for i in range(len(trt_kwargs['input_names'])): + profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i]) + tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT + # set input and output data type + for i in range(network.num_inputs): + input_tensor = network.get_input(i) + input_tensor.dtype = tensor_dtype + for i in range(network.num_outputs): + output_tensor = network.get_output(i) + output_tensor.dtype = tensor_dtype + config.add_optimization_profile(profile) + engine_bytes = builder.build_serialized_network(network, config) + # save trt engine + with open(trt_model, "wb") as f: + f.write(engine_bytes) + logging.info("Succesfully convert onnx to trt...") + +class TrtContextWrapper: + def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'): + self.trt_context_pool = queue.Queue(maxsize=trt_concurrent) + self.trt_engine = trt_engine + self.device = device + for _ in range(trt_concurrent): + trt_context = trt_engine.create_execution_context() + trt_stream = torch.cuda.stream(torch.cuda.Stream(torch.device(device))) + assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent) + self.trt_context_pool.put([trt_context, trt_stream]) + assert self.trt_context_pool.empty() is False, 'no avaialbe estimator context' + + def acquire_estimator(self): + return self.trt_context_pool.get(), self.trt_engine + + def release_estimator(self, context, stream): + self.trt_context_pool.put([context, stream]) + +class CosyVoice2_Token2Wav(torch.nn.Module): + def __init__(self, model_dir: str = "./CosyVoice2-0.5B", enable_trt: bool = False, device_id: int = 0): + super().__init__() + self.device_id = device_id + self.device = f"cuda:{device_id}" + with open(f"{model_dir}/flow.yaml", "r") as f: + configs = load_hyperpyyaml(f) + self.flow = configs['flow'] + # self.flow = CausalMaskedDiffWithXvec() + self.flow.half() + self.flow.load_state_dict(torch.load(f"{model_dir}/flow.pt", map_location="cpu", weights_only=True), 
strict=True) + self.flow.to(self.device).eval() + + self.hift = HiFTGenerator() + hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{model_dir}/hift.pt", map_location="cpu", weights_only=True).items()} + self.hift.load_state_dict(hift_state_dict, strict=True) + self.hift.to(self.device).eval() + + option = onnxruntime.SessionOptions() + option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + option.intra_op_num_threads = 1 + self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option, + providers=["CPUExecutionProvider"]) + + self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2_25hz.onnx").to(self.device).eval() + + gpu="l20" + if enable_trt: + self.load_trt(f'{model_dir}/flow.decoder.estimator.fp16.dynamic_batch.{gpu}.plan', + f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx', + 1, + True) + self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt', + f'{model_dir}/campplus.onnx', + 1, + False) + + + def forward_spk_embedding(self, spk_feat): + if isinstance(self.spk_model, onnxruntime.InferenceSession): + return self.spk_model.run( + None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()} + )[0].flatten().tolist() + else: + [spk_model, stream], trt_engine = self.spk_model.acquire_estimator() + # NOTE need to synchronize when switching stream + with torch.cuda.device(self.device_id): + torch.cuda.current_stream().synchronize() + spk_feat = spk_feat.unsqueeze(dim=0).to(self.device) + batch_size = spk_feat.size(0) + + with stream: + spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80)) + output_tensor = torch.empty((batch_size, 192), device=spk_feat.device) + + data_ptrs = [spk_feat.contiguous().data_ptr(), + output_tensor.contiguous().data_ptr()] + for i, j in enumerate(data_ptrs): + + spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j) + # run trt engine + assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True + torch.cuda.current_stream().synchronize() + self.spk_model.release_estimator(spk_model, stream) + + return output_tensor.cpu().numpy().flatten().tolist() + + def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True): + if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0: + trt_kwargs = self.get_spk_trt_kwargs() + convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, fp16) + import tensorrt as trt + with open(spk_model, 'rb') as f: + spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) + assert spk_engine is not None, 'failed to load trt {}'.format(spk_model) + self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device) + + def get_spk_trt_kwargs(self): + min_shape = [(1, 4, 80)] + opt_shape = [(1, 500, 80)] + max_shape = [(1, 3000, 80)] + input_names = ["input"] + return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} + + def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent=1, fp16=True): + assert torch.cuda.is_available(), 'tensorrt only supports gpu!' 
+ if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0: + trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_batch_size=2, max_batch_size=16) + convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, fp16) + del self.flow.decoder.estimator + import tensorrt as trt + with open(flow_decoder_estimator_model, 'rb') as f: + estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) + assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model) + self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device) + + def get_trt_kwargs_dynamic_batch(self, opt_batch_size=2, max_batch_size=64): + min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)] + opt_shape = [(opt_batch_size*2, 80, 500), (opt_batch_size*2, 1, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2,), (opt_batch_size*2, 80)] + max_shape = [(max_batch_size*2, 80, 3000), (max_batch_size*2, 1, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2,), (max_batch_size*2, 80)] + input_names = ["x", "mask", "mu", "cond", "t", "spks"] + return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} + + def prompt_audio_tokenization(self, prompt_audios_list: list[torch.Tensor]) -> list[list[int]]: + prompt_speech_tokens_list, prompt_speech_mels_list = [], [] + for audio in prompt_audios_list: + assert len(audio.shape) == 1 + log_mel = s3tokenizer.log_mel_spectrogram(audio) # [num_mels, T] + prompt_speech_mels_list.append(log_mel) + prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(prompt_speech_mels_list) + prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize( + prompt_mels_for_llm.to(self.device), prompt_mels_lens_for_llm.to(self.device) + ) + for i in range(len(prompt_speech_tokens)): + speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist() + prompt_speech_tokens_list.append(speech_tokens_i) + return prompt_speech_tokens_list + + def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor: + spk_emb_for_flow = [] + for audio in prompt_audios_list: + assert len(audio.shape) == 1 + spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000) + spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True) + spk_emb = self.forward_spk_embedding(spk_feat) + + spk_emb_for_flow.append(spk_emb) + spk_emb_for_flow = torch.tensor(spk_emb_for_flow) + return spk_emb_for_flow + + def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]): + prompt_mels_for_flow = [] + prompt_mels_lens_for_flow = [] + for audio, sample_rate in zip(prompt_audios_list, prompt_audios_sample_rate): + assert len(audio.shape) == 1 + audio = audio.unsqueeze(0) + if sample_rate != 24000: + audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio) + mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0) # [T, num_mels] + mel_len = mel.shape[0] + prompt_mels_for_flow.append(mel) + prompt_mels_lens_for_flow.append(mel_len) + prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0) # [B, T', num_mels=80] + prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow) + return prompt_mels_for_flow, 
prompt_mels_lens_for_flow + + def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor): + batch_size = prompt_mels_for_flow.shape[0] + flow_inputs = [] + flow_inputs_lens = [] + for prompt_speech_tokens, generated_speech_tokens in zip(prompt_speech_tokens_list, generated_speech_tokens_list): + flow_inputs.append(torch.tensor(prompt_speech_tokens + generated_speech_tokens)) + flow_inputs_lens.append(len(prompt_speech_tokens) + len(generated_speech_tokens)) + + flow_inputs = torch.nn.utils.rnn.pad_sequence(flow_inputs, batch_first=True, padding_value=0) + flow_inputs_lens = torch.tensor(flow_inputs_lens) + + with torch.amp.autocast(self.device, dtype=torch.float16): + generated_mels, generated_mels_lens = self.flow.inference( + flow_inputs.to(self.device), flow_inputs_lens.to(self.device), + prompt_mels_for_flow.to(self.device), prompt_mels_lens_for_flow.to(self.device), spk_emb_for_flow.to(self.device), 10 + ) + + return generated_mels, generated_mels_lens + + def forward_hift(self, generated_mels: torch.Tensor, generated_mels_lens: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor): + batch_size = generated_mels.shape[0] + generated_wavs = [] + for i in range(batch_size): + mel = generated_mels[i, :, prompt_mels_lens_for_flow[i].item():generated_mels_lens[i].item()].unsqueeze(0) + wav, _ = self.hift(speech_feat=mel) + generated_wavs.append(wav) + return generated_wavs + + + @torch.inference_mode() + def forward( + self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int] + ): + # assert all item in prompt_audios_sample_rate is 16000 + assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate) + + + prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list) + + prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(prompt_audios_list, prompt_audios_sample_rate) + + spk_emb_for_flow = self.get_spk_emb(prompt_audios_list) + + generated_mels, generated_mels_lens = self.forward_flow(prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow) + + generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow) + + return generated_wavs + + +def collate_fn(batch): + ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], [] + for i, item in enumerate(batch): + generated_speech_tokens_list.append(item['target_audio_cosy2_tokens']) + audio = torch.from_numpy(item['prompt_audio']['array']).float() + prompt_audios_list.append(audio) + prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate']) + ids.append(item['id']) + + return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--enable-trt", action="store_true") + parser.add_argument("--model-dir", type=str, default="./Step-Audio-2-mini/token2wav") + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--output-dir", type=str, default="generated_wavs") + parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts") + parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs, performance statistics will only be collected from 
the last epoch") + return parser.parse_args() + +if __name__ == "__main__": + args = get_args() + model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt) + # mkdir output_dir if not exists + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + dataset_name = "yuekai/seed_tts_cosy2" + + dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True) + + + data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0) + + + for epoch in range(args.warmup): + start_time = time.time() + + for batch in data_loader: + ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch + + generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate) + + + for id, wav in zip(ids, generated_wavs): + torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000) + + end_time = time.time() + epoch_time = end_time - start_time + print(f"Measurement epoch time taken: {epoch_time:.4f} seconds") \ No newline at end of file From 9e9433a93eaa7521cffd29dede004939a5a2806f Mon Sep 17 00:00:00 2001 From: yuekaiz Date: Thu, 18 Sep 2025 15:13:24 +0800 Subject: [PATCH 02/12] add streaming trt support --- cosyvoice2/flow/flow.py | 2 +- cosyvoice2/flow/flow_matching.py | 78 +++++++++++++++- token2wav_batch.py | 154 ++++++++++++++++++++++++++----- token2wav_streaming.py | 122 ++++++++++++++++++++++++ 4 files changed, 328 insertions(+), 28 deletions(-) create mode 100644 token2wav_streaming.py diff --git a/cosyvoice2/flow/flow.py b/cosyvoice2/flow/flow.py index c9e4170..51d7492 100644 --- a/cosyvoice2/flow/flow.py +++ b/cosyvoice2/flow/flow.py @@ -152,7 +152,7 @@ def setup_cache(self, feat, estimator_cnn_cache, estimator_att_cache = self.decoder.forward_chunk( mu = h.transpose(1, 2).contiguous(), spks = spk, - cond = mel.transpose(1, 2).contiguous(), + cond = mel.transpose(1, 2).contiguous().to(h.dtype), n_timesteps = n_timesteps, temperature = 1.0, cnn_cache = None, diff --git a/cosyvoice2/flow/flow_matching.py b/cosyvoice2/flow/flow_matching.py index cbcfcfb..ffdb7d7 100644 --- a/cosyvoice2/flow/flow_matching.py +++ b/cosyvoice2/flow/flow_matching.py @@ -20,6 +20,13 @@ from cosyvoice2.utils.mask import make_pad_mask +def get_data_ptr(tensor: torch.Tensor, dummy_buffer: torch.Tensor): + if tensor.numel() == 0: + return dummy_buffer.data_ptr() + else: + return tensor.contiguous().data_ptr() + + """ Inference wrapper """ @@ -34,6 +41,7 @@ def __init__(self, estimator: DiT, inference_cfg_rate:float=0.7): self.register_buffer('cnn_cache_buffer', torch.zeros(16, 16, 2, 1024, 2), persistent=False) self.register_buffer('att_cache_buffer', torch.zeros(16, 16, 2, 8, 1000, 128), persistent=False) + self.register_buffer('dummy_buffer', torch.zeros(1), persistent=False) def scatter_cuda_graph(self, enable_cuda_graph: bool): if enable_cuda_graph: @@ -125,6 +133,58 @@ def forward(self, mu, mask, spks, cond, n_timesteps=10, temperature=1.0): t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) return self.solve_euler(z, t_span, mu, mask, spks, cond) + + def forward_estimator_chunk(self, x, mu, t, spks, cond, cnn_cache, att_cache): + if isinstance(self.estimator, torch.nn.Module): + dphi_dt, this_new_cnn_cache, this_new_att_cache = self.estimator.forward_chunk( + x = x, + mu = mu, + t = t, + spks = spks, + cond = cond, + cnn_cache = cnn_cache, + att_cache = att_cache, + ) + return dphi_dt, this_new_cnn_cache, this_new_att_cache + else: + 
[estimator, stream], trt_engine = self.estimator.acquire_estimator() + # NOTE need to synchronize when switching stream + torch.cuda.current_stream().synchronize() + batch_size = x.size(0) + with stream: + estimator.set_input_shape('x', (batch_size, 80, x.size(2))) + # estimator.set_input_shape('mask', (batch_size, 1, x.size(2))) + estimator.set_input_shape('mu', (batch_size, 80, x.size(2))) + estimator.set_input_shape('t', (batch_size,)) + estimator.set_input_shape('spks', (batch_size, 80)) + estimator.set_input_shape('cond', (batch_size, 80, x.size(2))) + estimator.set_input_shape('cnn_cache', cnn_cache.shape) + estimator.set_input_shape('att_cache', att_cache.shape) + new_cnn_cache = torch.empty_like(cnn_cache) + new_att_cache_shape = list(att_cache.shape) + new_att_cache_shape[3] += x.size(2) + new_att_cache = torch.empty(new_att_cache_shape, device=att_cache.device, dtype=x.dtype) + data_ptrs = [x.contiguous().data_ptr(), + mu.contiguous().data_ptr(), + t.contiguous().data_ptr(), + spks.contiguous().data_ptr(), + cond.contiguous().data_ptr(), + cnn_cache.contiguous().data_ptr(), + get_data_ptr(att_cache, self.dummy_buffer), + x.data_ptr(), + new_cnn_cache.data_ptr(), + get_data_ptr(new_att_cache, self.dummy_buffer)] + + for i, j in enumerate(data_ptrs): + estimator.set_tensor_address(trt_engine.get_tensor_name(i), j) + # run trt engine + assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True + torch.cuda.current_stream().synchronize() + self.estimator.release_estimator(estimator, stream) + + return x, new_cnn_cache, new_att_cache + + def solve_euler_chunk(self, x:torch.Tensor, t_span:torch.Tensor, @@ -153,13 +213,16 @@ def solve_euler_chunk(self, assert self.inference_cfg_rate > 0, 'cfg rate should be > 0' t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] - t = t.unsqueeze(dim=0) # (b,) + # t = t.unsqueeze(dim=0) # (b,) + t_in = torch.zeros([x.shape[0] * 2], device=x.device, dtype=x.dtype) # setup initial cache if cnn_cache is None: cnn_cache = [None for _ in range(len(t_span)-1)] + cnn_cache = torch.zeros((len(t_span)-1, 16, x.shape[0] * 2, 1024, 2), device=x.device, dtype=x.dtype) if att_cache is None: att_cache = [None for _ in range(len(t_span)-1)] + att_cache = torch.empty((len(t_span)-1, 16, x.shape[0] * 2, 8, 0, 128), device=x.device, dtype=x.dtype) # next chunk's cache at each timestep if att_cache[0] is not None: @@ -177,15 +240,19 @@ def solve_euler_chunk(self, this_att_cache = att_cache[step-1] this_cnn_cache = cnn_cache[step-1] - dphi_dt, this_new_cnn_cache, this_new_att_cache = self.estimator.forward_chunk( - x = x.repeat(2, 1, 1), + x_in = x.repeat(2, 1, 1) + t_in.fill_(t) + + dphi_dt, this_new_cnn_cache, this_new_att_cache = self.forward_estimator_chunk( + x = x_in, mu = mu_in, - t = t.repeat(2), + t = t_in, spks = spks_in, cond = cond_in, cnn_cache = this_cnn_cache, att_cache = this_att_cache, ) + dphi_dt, cfg_dphi_dt = dphi_dt.chunk(2, dim=0) dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt) x = x + dt * dphi_dt @@ -221,6 +288,9 @@ def forward_chunk(self, # get offset from att_cache offset = att_cache.shape[4] if att_cache is not None else 0 z = self.rand_noise[:, :, offset:offset+mu.size(2)] * temperature + + z = z.to(mu.dtype) + t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype) # cosine scheduling t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) diff --git a/token2wav_batch.py b/token2wav_batch.py index 168cd64..d71615c 100644 --- a/token2wav_batch.py +++ 
b/token2wav_batch.py @@ -33,7 +33,8 @@ import time from hyperpyyaml import load_hyperpyyaml -def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16): + +def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype): import tensorrt as trt logging.info("Converting onnx to trt...") network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) @@ -43,8 +44,12 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16): parser = trt.OnnxParser(network, logger) config = builder.create_builder_config() # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB - if fp16: + if dtype == torch.float16: config.set_flag(trt.BuilderFlag.FP16) + elif dtype == torch.bfloat16: + config.set_flag(trt.BuilderFlag.BF16) + elif dtype == torch.float32: + config.set_flag(trt.BuilderFlag.FP32) profile = builder.create_optimization_profile() # load onnx model with open(onnx_model, "rb") as f: @@ -55,7 +60,14 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16): # set input shapes for i in range(len(trt_kwargs['input_names'])): profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i]) - tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT + if dtype == torch.float16: + tensor_dtype = trt.DataType.HALF + elif dtype == torch.bfloat16: + tensor_dtype = trt.DataType.BF16 + elif dtype == torch.float32: + tensor_dtype = trt.DataType.FLOAT + else: + raise ValueError('invalid dtype {}'.format(dtype)) # set input and output data type for i in range(network.num_inputs): input_tensor = network.get_input(i) @@ -89,15 +101,17 @@ def release_estimator(self, context, stream): self.trt_context_pool.put([context, stream]) class CosyVoice2_Token2Wav(torch.nn.Module): - def __init__(self, model_dir: str = "./CosyVoice2-0.5B", enable_trt: bool = False, device_id: int = 0): + def __init__(self, model_dir: str, enable_trt: bool = False, device_id: int = 0, streaming: bool = False, dtype: torch.dtype = torch.float16): super().__init__() self.device_id = device_id self.device = f"cuda:{device_id}" with open(f"{model_dir}/flow.yaml", "r") as f: configs = load_hyperpyyaml(f) self.flow = configs['flow'] - # self.flow = CausalMaskedDiffWithXvec() - self.flow.half() + + self.dtype = dtype + self.flow.to(self.dtype) + self.flow.load_state_dict(torch.load(f"{model_dir}/flow.pt", map_location="cpu", weights_only=True), strict=True) self.flow.to(self.device).eval() @@ -116,16 +130,25 @@ def __init__(self, model_dir: str = "./CosyVoice2-0.5B", enable_trt: bool = Fals gpu="l20" if enable_trt: - self.load_trt(f'{model_dir}/flow.decoder.estimator.fp16.dynamic_batch.{gpu}.plan', - f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx', - 1, - True) + if streaming: + self.load_trt(f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.chunk.{gpu}.plan', + f'{model_dir}/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx', + 1, + self.dtype, streaming) + else: + self.load_trt(f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.{gpu}.plan', + f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx', + 1, + self.dtype) self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt', f'{model_dir}/campplus.onnx', 1, False) + self.streaming_cache = {} + + def forward_spk_embedding(self, spk_feat): if isinstance(self.spk_model, onnxruntime.InferenceSession): return self.spk_model.run( @@ -172,11 +195,15 @@ def get_spk_trt_kwargs(self): input_names = ["input"] return 
{'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} - def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent=1, fp16=True): + def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent=1, dtype=torch.float16, streaming=False): assert torch.cuda.is_available(), 'tensorrt only supports gpu!' if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0: - trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_batch_size=2, max_batch_size=16) - convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, fp16) + opt_batch_size = 2 + max_batch_size = 16 + if streaming: + opt_batch_size, max_batch_size = 1, 1 # only support batch size 1 for streaming tts + trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_batch_size=opt_batch_size, max_batch_size=max_batch_size, streaming=streaming) + convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, dtype) del self.flow.decoder.estimator import tensorrt as trt with open(flow_decoder_estimator_model, 'rb') as f: @@ -184,11 +211,17 @@ def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_co assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model) self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device) - def get_trt_kwargs_dynamic_batch(self, opt_batch_size=2, max_batch_size=64): - min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)] - opt_shape = [(opt_batch_size*2, 80, 500), (opt_batch_size*2, 1, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2,), (opt_batch_size*2, 80)] - max_shape = [(max_batch_size*2, 80, 3000), (max_batch_size*2, 1, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2,), (max_batch_size*2, 80)] - input_names = ["x", "mask", "mu", "cond", "t", "spks"] + def get_trt_kwargs_dynamic_batch(self, opt_batch_size=2, max_batch_size=64, streaming=False): + if streaming: + min_shape = [(2, 80, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80), (16, 2, 1024, 2), (16, 2, 8, 0, 128)] + opt_shape = [(opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2,), (opt_batch_size*2, 80), (16, opt_batch_size*2, 1024, 2), (16, opt_batch_size*2, 8, 100, 128)] + max_shape = [(max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2,), (max_batch_size*2, 80), (16, max_batch_size*2, 1024, 2), (16, max_batch_size*2, 8, 1000, 128)] + input_names = ["x", "mu", "cond", "t", "spks", "cnn_cache", "att_cache"] + else: + min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)] + opt_shape = [(opt_batch_size*2, 80, 500), (opt_batch_size*2, 1, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2,), (opt_batch_size*2, 80)] + max_shape = [(max_batch_size*2, 80, 3000), (max_batch_size*2, 1, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2,), (max_batch_size*2, 80)] + input_names = ["x", "mask", "mu", "cond", "t", "spks"] return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} def prompt_audio_tokenization(self, prompt_audios_list: list[torch.Tensor]) -> list[list[int]]: @@ -215,7 +248,9 @@ def get_spk_emb(self, prompt_audios_list: 
list[torch.Tensor]) -> torch.Tensor: spk_emb = self.forward_spk_embedding(spk_feat) spk_emb_for_flow.append(spk_emb) - spk_emb_for_flow = torch.tensor(spk_emb_for_flow) + spk_emb_for_flow = torch.tensor(spk_emb_for_flow) + if self.dtype != torch.float32: + spk_emb_for_flow = spk_emb_for_flow.to(self.dtype) return spk_emb_for_flow def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]): @@ -271,18 +306,91 @@ def forward( assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate) + prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio(prompt_audios_list, prompt_audios_sample_rate) + + generated_mels, generated_mels_lens = self.forward_flow(prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow) + + generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow) + + return generated_wavs + + def prepare_prompt_audio( + self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int] + ): + # assert all item in prompt_audios_sample_rate is 16000 + assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate) + + prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list) prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(prompt_audios_list, prompt_audios_sample_rate) spk_emb_for_flow = self.get_spk_emb(prompt_audios_list) + + return prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow - generated_mels, generated_mels_lens = self.forward_flow(prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow) - generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow) - - return generated_wavs + def get_prompt_audio_cache_for_streaming_tts( + self, prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow + ): + assert len(prompt_speech_tokens_list) == 1, "only support batch size 1 for streaming tts" + for i, prompt_speech_tokens in enumerate(prompt_speech_tokens_list): + prompt_speech_tokens_list[i] = torch.tensor(prompt_speech_tokens + prompt_speech_tokens_list[i][:3]) + prompt_speech_tokens_tensor = torch.nn.utils.rnn.pad_sequence(prompt_speech_tokens_list, batch_first=True, padding_value=0) + + cache = self.flow.setup_cache( + prompt_speech_tokens_tensor.to(self.device), + prompt_mels_for_flow.to(self.device), + spk_emb_for_flow.to(self.device), + n_timesteps=10 + ) + + # cache dict's tensor batch dim is 1 for now + return cache + + + @torch.inference_mode() + def forward_streaming( + self, generated_speech_tokens: list[int], prompt_audio: torch.Tensor, prompt_audio_sample_rate: int, last_chunk: bool + ): + + assert prompt_audio_sample_rate == 16000 + + if not self.streaming_cache: + prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio([prompt_audio], [prompt_audio_sample_rate]) + + + token_len = min(int(prompt_mels_for_flow.shape[1] / 2), len(prompt_speech_tokens_list[0])) + prompt_mels_for_flow = prompt_mels_for_flow[:, :2 * token_len].contiguous() + prompt_speech_tokens_list[0] = prompt_speech_tokens_list[0][:token_len] + + + cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, 
prompt_mels_lens_for_flow, spk_emb_for_flow) + prompt_audio_dict = {'spk_emb_for_flow': spk_emb_for_flow, 'prompt_mels_for_flow': prompt_mels_for_flow} + self.streaming_cache = cache_dict | prompt_audio_dict + + generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda') + + chunk_mel, streaming_cache = self.flow.inference_chunk( + token=generated_speech_tokens, + spk=self.streaming_cache['spk_emb_for_flow'].to(self.device), + cache=self.streaming_cache, + last_chunk=last_chunk, + n_timesteps=10, + ) + prompt_audio_dict = {'spk_emb_for_flow': self.streaming_cache['spk_emb_for_flow'], 'prompt_mels_for_flow': self.streaming_cache['prompt_mels_for_flow']} + self.streaming_cache = streaming_cache | prompt_audio_dict + if self.streaming_cache['estimator_att_cache'].shape[4] > (self.streaming_cache['prompt_mels_for_flow'].shape[1] + 100): + self.streaming_cache['estimator_att_cache'] = torch.cat([ + self.streaming_cache['estimator_att_cache'][:, :, :, :, :self.streaming_cache['prompt_mels_for_flow'].shape[1]], + self.streaming_cache['estimator_att_cache'][:, :, :, :, -100:], + ], dim=4) + + + wav, _ = self.hift(speech_feat=chunk_mel.to(torch.float32)) + + return wav def collate_fn(batch): ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], [] diff --git a/token2wav_streaming.py b/token2wav_streaming.py new file mode 100644 index 0000000..abc647f --- /dev/null +++ b/token2wav_streaming.py @@ -0,0 +1,122 @@ +import torch +import os +import argparse +from datasets import load_dataset +from torch.utils.data import DataLoader +import numpy as np +import torchaudio + +from token2wav_batch import CosyVoice2_Token2Wav +import soundfile as sf + +def collate_fn(batch): + ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], [] + prompt_speech_tokens_list, prompt_text_list = [], [] + for i, item in enumerate(batch): + generated_speech_tokens_list.append(item['target_audio_cosy2_tokens']) + audio = torch.from_numpy(item['prompt_audio']['array']).float() + prompt_audios_list.append(audio) + prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate']) + ids.append(item['id']) + prompt_speech_tokens_list.append(item['prompt_audio_cosy2_tokens']) + prompt_text_list.append(item['prompt_text']) + + return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--enable-trt", action="store_true") + parser.add_argument("--model-dir", type=str, default="./Step-Audio-2-mini/token2wav") + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--output-dir", type=str, default="generated_wavs") + parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts") + parser.add_argument("--dataset-name", type=str, default="yuekai/seed_tts_cosy2") + return parser.parse_args() + + +def fake_generated_id_iter(generated_speech_tokens_list): + for i in range(len(generated_speech_tokens_list)): + yield generated_speech_tokens_list[i] + + + +if __name__ == "__main__": + args = get_args() + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + dataset_name = args.dataset_name + dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True) + data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, 
num_workers=0) + + token2wav_model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt, streaming=True) + + flow_pre_lookahead_len = 3 + CHUNK_SIZE = 25 + OVERLAP_SIZE = 0 + + for batch in data_loader: + tts_speech_list = [] + ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list = batch + + id, generated_speech_tokens, prompt_audio, prompt_audio_sample_rate = ids[0], generated_speech_tokens_list[0], prompt_audios_list[0], prompt_audios_sample_rate[0] + assert prompt_audio_sample_rate == 16000 + + prompt_text = prompt_text_list[0] + prompt_speech_tokens = prompt_speech_tokens_list[0] + + + # generated_ids_iter = fake_generated_id_iter(generated_speech_tokens) + + semantic_token_ids_arr, token_offset = [], 0 + flow_prompt_speech_token_len = len(prompt_speech_tokens) + + buffer = generated_speech_tokens + output_wavs = [] + while True: + + if len(buffer) >= CHUNK_SIZE + token2wav_model.flow.pre_lookahead_len: + wavs = token2wav_model.forward_streaming(buffer[:CHUNK_SIZE + token2wav_model.flow.pre_lookahead_len], prompt_audio, prompt_audio_sample_rate, False) + buffer = buffer[CHUNK_SIZE - OVERLAP_SIZE:] + + output_wavs.append(wavs) + + else: + wavs = token2wav_model.forward_streaming(buffer, prompt_audio, prompt_audio_sample_rate, True) + output_wavs.append(wavs) + token2wav_model.streaming_cache = None + break + + # tts_speech = torch.cat(output_wavs, dim=-1) + # torchaudio.save(os.path.join(args.output_dir, f"{id}.wav"), tts_speech.cpu(), 24000) + + for i, wav in enumerate(output_wavs): + output_wavs[i] = wav.cpu().numpy().squeeze() + + + audios = output_wavs + + # cross_fade_samples = int(0.16 * 24000) + # fade_out = np.linspace(1, 0, cross_fade_samples) + # fade_in = np.linspace(0, 1, cross_fade_samples) + # reconstructed_audio = audios[0][:-cross_fade_samples] # Start with first chunk minus overlap + # for i in range(1, len(audios)): + # # Cross-fade section + # cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in + + # audios[i - 1][-cross_fade_samples:] * fade_out) + # # Middle section of the current chunk + # middle_part = audios[i][cross_fade_samples:-cross_fade_samples] + # # Concatenate + # reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part]) + # # Add the last part of the final chunk + # reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]]) + + + reconstructed_audio = np.concatenate(audios) + # Save reconstructed audio + sf.write(os.path.join(args.output_dir, f"{id}.wav"), reconstructed_audio, 24000, "PCM_16") + + + print(f"Saved {id}") + From 699749443e060d52ae1fa8467eb201cfb95cef9a Mon Sep 17 00:00:00 2001 From: root Date: Thu, 18 Sep 2025 18:10:19 +0800 Subject: [PATCH 03/12] add speaker cache and runtime streaming request cache --- token2wav_batch.py | 42 ++++++++++++-------- token2wav_streaming.py | 89 ++++++++++++++++++------------------------ 2 files changed, 65 insertions(+), 66 deletions(-) diff --git a/token2wav_batch.py b/token2wav_batch.py index d71615c..47f5cf0 100644 --- a/token2wav_batch.py +++ b/token2wav_batch.py @@ -147,6 +147,7 @@ def __init__(self, model_dir: str, enable_trt: bool = False, device_id: int = 0, self.streaming_cache = {} + self.speaker_cache = {} def forward_spk_embedding(self, spk_feat): @@ -351,44 +352,53 @@ def get_prompt_audio_cache_for_streaming_tts( @torch.inference_mode() def forward_streaming( - self, generated_speech_tokens: list[int], prompt_audio: 
torch.Tensor, prompt_audio_sample_rate: int, last_chunk: bool + self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000 ): - assert prompt_audio_sample_rate == 16000 + if speaker_id not in self.speaker_cache: + assert prompt_audio is not None, "prompt_audio is required for new speaker" + assert prompt_audio_sample_rate == 16000 - - if not self.streaming_cache: prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio([prompt_audio], [prompt_audio_sample_rate]) - token_len = min(int(prompt_mels_for_flow.shape[1] / 2), len(prompt_speech_tokens_list[0])) prompt_mels_for_flow = prompt_mels_for_flow[:, :2 * token_len].contiguous() prompt_speech_tokens_list[0] = prompt_speech_tokens_list[0][:token_len] - cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow) prompt_audio_dict = {'spk_emb_for_flow': spk_emb_for_flow, 'prompt_mels_for_flow': prompt_mels_for_flow} - self.streaming_cache = cache_dict | prompt_audio_dict + + self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict} + if request_id not in self.streaming_cache: + self.streaming_cache[request_id] = self.speaker_cache[speaker_id]['cache_dict'].copy() + + current_request_cache = self.streaming_cache[request_id] + prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict'] generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda') - chunk_mel, streaming_cache = self.flow.inference_chunk( + chunk_mel, new_streaming_cache = self.flow.inference_chunk( token=generated_speech_tokens, - spk=self.streaming_cache['spk_emb_for_flow'].to(self.device), - cache=self.streaming_cache, + spk=prompt_audio_dict['spk_emb_for_flow'].to(self.device), + cache=current_request_cache, last_chunk=last_chunk, n_timesteps=10, ) - prompt_audio_dict = {'spk_emb_for_flow': self.streaming_cache['spk_emb_for_flow'], 'prompt_mels_for_flow': self.streaming_cache['prompt_mels_for_flow']} - self.streaming_cache = streaming_cache | prompt_audio_dict - if self.streaming_cache['estimator_att_cache'].shape[4] > (self.streaming_cache['prompt_mels_for_flow'].shape[1] + 100): - self.streaming_cache['estimator_att_cache'] = torch.cat([ - self.streaming_cache['estimator_att_cache'][:, :, :, :, :self.streaming_cache['prompt_mels_for_flow'].shape[1]], - self.streaming_cache['estimator_att_cache'][:, :, :, :, -100:], + + self.streaming_cache[request_id] = new_streaming_cache + + if self.streaming_cache[request_id]['estimator_att_cache'].shape[4] > (prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100): + self.streaming_cache[request_id]['estimator_att_cache'] = torch.cat([ + self.streaming_cache[request_id]['estimator_att_cache'][:, :, :, :, :prompt_audio_dict['prompt_mels_for_flow'].shape[1]], + self.streaming_cache[request_id]['estimator_att_cache'][:, :, :, :, -100:], ], dim=4) wav, _ = self.hift(speech_feat=chunk_mel.to(torch.float32)) + + if last_chunk: + if request_id in self.streaming_cache: + del self.streaming_cache[request_id] return wav diff --git a/token2wav_streaming.py b/token2wav_streaming.py index abc647f..342b423 100644 --- a/token2wav_streaming.py +++ b/token2wav_streaming.py @@ -5,7 +5,7 @@ from torch.utils.data import DataLoader import numpy as np import torchaudio - +import time from token2wav_batch 
import CosyVoice2_Token2Wav import soundfile as sf @@ -56,67 +56,56 @@ def fake_generated_id_iter(generated_speech_tokens_list): CHUNK_SIZE = 25 OVERLAP_SIZE = 0 - for batch in data_loader: - tts_speech_list = [] - ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list = batch - - id, generated_speech_tokens, prompt_audio, prompt_audio_sample_rate = ids[0], generated_speech_tokens_list[0], prompt_audios_list[0], prompt_audios_sample_rate[0] - assert prompt_audio_sample_rate == 16000 + warmup_times = 3 + for _ in range(warmup_times): + start_time = time.time() + for batch in data_loader: + tts_speech_list = [] + ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list = batch - prompt_text = prompt_text_list[0] - prompt_speech_tokens = prompt_speech_tokens_list[0] + id, generated_speech_tokens, prompt_audio, prompt_audio_sample_rate = ids[0], generated_speech_tokens_list[0], prompt_audios_list[0], prompt_audios_sample_rate[0] + assert prompt_audio_sample_rate == 16000 + prompt_text = prompt_text_list[0] + prompt_speech_tokens = prompt_speech_tokens_list[0] - # generated_ids_iter = fake_generated_id_iter(generated_speech_tokens) - semantic_token_ids_arr, token_offset = [], 0 - flow_prompt_speech_token_len = len(prompt_speech_tokens) - - buffer = generated_speech_tokens - output_wavs = [] - while True: + # generated_ids_iter = fake_generated_id_iter(generated_speech_tokens) - if len(buffer) >= CHUNK_SIZE + token2wav_model.flow.pre_lookahead_len: - wavs = token2wav_model.forward_streaming(buffer[:CHUNK_SIZE + token2wav_model.flow.pre_lookahead_len], prompt_audio, prompt_audio_sample_rate, False) - buffer = buffer[CHUNK_SIZE - OVERLAP_SIZE:] + semantic_token_ids_arr, token_offset = [], 0 + flow_prompt_speech_token_len = len(prompt_speech_tokens) + + buffer = generated_speech_tokens + output_wavs = [] + while True: - output_wavs.append(wavs) + if len(buffer) >= CHUNK_SIZE + token2wav_model.flow.pre_lookahead_len: + wavs = token2wav_model.forward_streaming(buffer[:CHUNK_SIZE + token2wav_model.flow.pre_lookahead_len], False, request_id=id, speaker_id=id, prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate) + buffer = buffer[CHUNK_SIZE - OVERLAP_SIZE:] - else: - wavs = token2wav_model.forward_streaming(buffer, prompt_audio, prompt_audio_sample_rate, True) - output_wavs.append(wavs) - token2wav_model.streaming_cache = None - break + output_wavs.append(wavs) - # tts_speech = torch.cat(output_wavs, dim=-1) - # torchaudio.save(os.path.join(args.output_dir, f"{id}.wav"), tts_speech.cpu(), 24000) + else: + wavs = token2wav_model.forward_streaming(buffer, True, request_id=id, speaker_id=id, prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate) + output_wavs.append(wavs) + break - for i, wav in enumerate(output_wavs): - output_wavs[i] = wav.cpu().numpy().squeeze() + # tts_speech = torch.cat(output_wavs, dim=-1) + # torchaudio.save(os.path.join(args.output_dir, f"{id}.wav"), tts_speech.cpu(), 24000) + for i, wav in enumerate(output_wavs): + output_wavs[i] = wav.cpu().numpy().squeeze() - audios = output_wavs - # cross_fade_samples = int(0.16 * 24000) - # fade_out = np.linspace(1, 0, cross_fade_samples) - # fade_in = np.linspace(0, 1, cross_fade_samples) - # reconstructed_audio = audios[0][:-cross_fade_samples] # Start with first chunk minus overlap - # for i in range(1, len(audios)): - # # Cross-fade section - # 
cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in + - # audios[i - 1][-cross_fade_samples:] * fade_out) - # # Middle section of the current chunk - # middle_part = audios[i][cross_fade_samples:-cross_fade_samples] - # # Concatenate - # reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part]) - # # Add the last part of the final chunk - # reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]]) - - - reconstructed_audio = np.concatenate(audios) - # Save reconstructed audio - sf.write(os.path.join(args.output_dir, f"{id}.wav"), reconstructed_audio, 24000, "PCM_16") + audios = output_wavs + reconstructed_audio = np.concatenate(audios) + # Save reconstructed audio + sf.write(os.path.join(args.output_dir, f"{id}.wav"), reconstructed_audio, 24000, "PCM_16") - print(f"Saved {id}") + print(f"Saved {id}") + end_time = time.time() + if _ == 0: + token2wav_model.speaker_cache = {} + print(f"Warmup time: {end_time - start_time} seconds") From 4e517602ba0207b4aa16c73742f14833fa552250 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 18 Sep 2025 18:42:29 +0800 Subject: [PATCH 04/12] support vocoder cache --- token2wav_batch.py | 72 ++++++++++++++++++++++++++++++++++++---------- 1 file changed, 57 insertions(+), 15 deletions(-) diff --git a/token2wav_batch.py b/token2wav_batch.py index 47f5cf0..69db946 100644 --- a/token2wav_batch.py +++ b/token2wav_batch.py @@ -31,9 +31,20 @@ import argparse import queue import time - +import numpy as np from hyperpyyaml import load_hyperpyyaml + +def fade_in_out(fade_in_mel:torch.Tensor, fade_out_mel:torch.Tensor, window:torch.Tensor): + """perform fade_in_out in tensor style + """ + mel_overlap_len = int(window.shape[0] / 2) + fade_in_mel = fade_in_mel.clone() + fade_in_mel[..., :mel_overlap_len] = \ + fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \ + fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:] + return fade_in_mel + def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype): import tensorrt as trt logging.info("Converting onnx to trt...") @@ -146,9 +157,15 @@ def __init__(self, model_dir: str, enable_trt: bool = False, device_id: int = 0, False) - self.streaming_cache = {} + self.streaming_flow_cache = {} self.speaker_cache = {} + self.mel_cache_len = 8 # hard-coded, 160ms + self.source_cache_len = int(self.mel_cache_len * 480) # 50hz mel -> 24kHz wave + self.speech_window = torch.from_numpy(np.hamming(2 * self.source_cache_len)).cuda() + + # hifigan cache for streaming tts + self.hift_cache_dict = {} def forward_spk_embedding(self, spk_feat): if isinstance(self.spk_model, onnxruntime.InferenceSession): @@ -370,14 +387,19 @@ def forward_streaming( self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict} - if request_id not in self.streaming_cache: - self.streaming_cache[request_id] = self.speaker_cache[speaker_id]['cache_dict'].copy() + if request_id not in self.streaming_flow_cache: + self.streaming_flow_cache[request_id] = self.speaker_cache[speaker_id]['cache_dict'].copy() + self.hift_cache_dict[request_id] = dict( + mel = torch.zeros(1, 80, 0, device='cuda'), + source = torch.zeros(1, 1, 0, device='cuda'), + speech = torch.zeros(1, 0, device='cuda'), + ) - current_request_cache = self.streaming_cache[request_id] + current_request_cache = self.streaming_flow_cache[request_id] prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict'] generated_speech_tokens = 
torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda') - chunk_mel, new_streaming_cache = self.flow.inference_chunk( + chunk_mel, new_streaming_flow_cache = self.flow.inference_chunk( token=generated_speech_tokens, spk=prompt_audio_dict['spk_emb_for_flow'].to(self.device), cache=current_request_cache, @@ -385,22 +407,42 @@ def forward_streaming( n_timesteps=10, ) - self.streaming_cache[request_id] = new_streaming_cache + self.streaming_flow_cache[request_id] = new_streaming_flow_cache - if self.streaming_cache[request_id]['estimator_att_cache'].shape[4] > (prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100): - self.streaming_cache[request_id]['estimator_att_cache'] = torch.cat([ - self.streaming_cache[request_id]['estimator_att_cache'][:, :, :, :, :prompt_audio_dict['prompt_mels_for_flow'].shape[1]], - self.streaming_cache[request_id]['estimator_att_cache'][:, :, :, :, -100:], + if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100): + self.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.cat([ + self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :prompt_audio_dict['prompt_mels_for_flow'].shape[1]], + self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, -100:], ], dim=4) - wav, _ = self.hift(speech_feat=chunk_mel.to(torch.float32)) + + hift_cache_mel = self.hift_cache_dict[request_id]['mel'] + hift_cache_source = self.hift_cache_dict[request_id]['source'] + hift_cache_speech = self.hift_cache_dict[request_id]['speech'] + mel = torch.concat([hift_cache_mel, chunk_mel], dim=2) + + speech, source = self.hift(mel, hift_cache_source) + + # overlap speech smooth + if hift_cache_speech.shape[-1] > 0: + speech = fade_in_out(speech, hift_cache_speech, self.speech_window) + + # update vocoder cache + self.hift_cache_dict[request_id] = dict( + mel = mel[..., -self.mel_cache_len:].clone().detach(), + source = source[:, :, -self.source_cache_len:].clone().detach(), + speech = speech[:, -self.source_cache_len:].clone().detach(), + ) + if not last_chunk: + speech = speech[:, :-self.source_cache_len] if last_chunk: - if request_id in self.streaming_cache: - del self.streaming_cache[request_id] + assert request_id in self.streaming_flow_cache + self.streaming_flow_cache.pop(request_id) + self.hift_cache_dict.pop(request_id) - return wav + return speech def collate_fn(batch): ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], [] From 6a012972c778bb18f96a95f0023a7cbc05086139 Mon Sep 17 00:00:00 2001 From: yuekaiz Date: Thu, 18 Sep 2025 18:45:46 +0800 Subject: [PATCH 05/12] rename files --- token2wav_streaming.py => benchmark_streaming_token2wav.py | 2 +- token2wav_batch.py => token2wav_dit.py | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename token2wav_streaming.py => benchmark_streaming_token2wav.py (98%) rename token2wav_batch.py => token2wav_dit.py (100%) diff --git a/token2wav_streaming.py b/benchmark_streaming_token2wav.py similarity index 98% rename from token2wav_streaming.py rename to benchmark_streaming_token2wav.py index 342b423..70b1199 100644 --- a/token2wav_streaming.py +++ b/benchmark_streaming_token2wav.py @@ -6,7 +6,7 @@ import numpy as np import torchaudio import time -from token2wav_batch import CosyVoice2_Token2Wav +from token2wav_dit import CosyVoice2_Token2Wav import soundfile as sf def collate_fn(batch): diff --git a/token2wav_batch.py b/token2wav_dit.py similarity index 100% 
rename from token2wav_batch.py rename to token2wav_dit.py From 28984ee18f66a9e223ea78e78ef69ca14f3b0c5c Mon Sep 17 00:00:00 2001 From: yuekaiz Date: Wed, 24 Sep 2025 15:13:27 +0800 Subject: [PATCH 06/12] fix att buffer shallow copy issue --- cosyvoice2/flow/decoder_dit.py | 37 ++++---- cosyvoice2/flow/flow_matching.py | 12 +-- export_onnx.py | 125 +++++++++++++++++++++++++ export_onnx_chunk.py | 155 +++++++++++++++++++++++++++++++ token2wav_dit.py | 37 ++++---- 5 files changed, 323 insertions(+), 43 deletions(-) create mode 100644 export_onnx.py create mode 100644 export_onnx_chunk.py diff --git a/cosyvoice2/flow/decoder_dit.py b/cosyvoice2/flow/decoder_dit.py index b3615e3..15c8c1e 100644 --- a/cosyvoice2/flow/decoder_dit.py +++ b/cosyvoice2/flow/decoder_dit.py @@ -403,9 +403,6 @@ def __init__( self.inference_buffers_chunk = {} self.max_size_chunk = {} - self.register_buffer('att_cache_buffer', torch.zeros((16, 2, 8, 1000, 128)), persistent=False) - self.register_buffer('cnn_cache_buffer', torch.zeros((16, 2, 1024, 2)), persistent=False) - def initialize_weights(self): # Initialize transformer layers: def _basic_init(module): @@ -432,7 +429,7 @@ def _basic_init(module): def _init_cuda_graph_chunk(self): # get dtype, device from registered buffer - dtype, device = self.cnn_cache_buffer.dtype, self.cnn_cache_buffer.device + dtype, device = self.in_proj.weight.dtype, self.in_proj.weight.device # init cuda graph for streaming forward with torch.no_grad(): for chunk_size in [30, 48, 96]: @@ -542,16 +539,14 @@ def forward_chunk(self, # create fake cache if cnn_cache is None: - cnn_cache = [None] * len(self.blocks) + cnn_cache = torch.zeros(len(self.blocks), x.shape[0], 1024, 2, dtype=x.dtype, device=x.device) if att_cache is None: - att_cache = [None] * len(self.blocks) - if att_cache[0] is not None: - last_att_len = att_cache.shape[3] - else: - last_att_len = 0 + att_cache = torch.zeros(len(self.blocks), x.shape[0], self.blocks[0].attn.num_heads, 0, self.blocks[0].attn.head_dim * 2, dtype=x.dtype, device=x.device) + + last_att_len = att_cache.shape[3] chunk_size = x.shape[2] mask = torch.ones(x.shape[0], chunk_size, last_att_len+chunk_size, dtype=torch.bool, device=x.device) - if self.use_cuda_graph and att_cache[0] is not None and chunk_size in self.graph_chunk and last_att_len <= self.max_size_chunk[chunk_size]: + if self.use_cuda_graph and att_cache is not None and chunk_size in self.graph_chunk and last_att_len <= self.max_size_chunk[chunk_size]: padded_mask = torch.zeros((2, chunk_size, self.max_size_chunk[chunk_size]+chunk_size), dtype=mask.dtype, device=mask.device) padded_mask[:, :, :mask.shape[-1]] = mask padded_att_cache = torch.zeros((16, 2, 8, self.max_size_chunk[chunk_size], 128), dtype=att_cache.dtype, device=att_cache.device) @@ -567,20 +562,22 @@ def forward_chunk(self, new_att_cache = self.inference_buffers_chunk[chunk_size]['static_outputs'][2][:, :, :, :chunk_size+last_att_len, :] else: mask = None - x = self.blocks_forward_chunk(x, t, mask, cnn_cache, att_cache, self.cnn_cache_buffer, self.att_cache_buffer) - new_cnn_cache = self.cnn_cache_buffer - new_att_cache = self.att_cache_buffer[:, :, :, :last_att_len+chunk_size, :] + x, new_cnn_cache, new_att_cache = self.blocks_forward_chunk(x, t, mask, cnn_cache, att_cache) return x, new_cnn_cache, new_att_cache - def blocks_forward_chunk(self, x, t, mask, cnn_cache=None, att_cache=None, cnn_cache_buffer=None, att_cache_buffer=None): + def blocks_forward_chunk(self, x, t, mask, cnn_cache, att_cache): x = x.transpose(1, 2) x = 
self.in_proj(x) + + new_cnn_caches = [] + new_att_caches = [] + for b_idx, block in enumerate(self.blocks): - x, this_new_cnn_cache, this_new_att_cache \ - = block.forward_chunk(x, t, cnn_cache[b_idx], att_cache[b_idx], mask) - cnn_cache_buffer[b_idx] = this_new_cnn_cache - att_cache_buffer[b_idx][:, :, :this_new_att_cache.shape[2], :] = this_new_att_cache + x, this_new_cnn_cache, this_new_att_cache = block.forward_chunk(x, t, cnn_cache[b_idx], att_cache[b_idx], mask) + new_cnn_caches.append(this_new_cnn_cache) + new_att_caches.append(this_new_att_cache) + x = self.final_layer(x, t) x = x.transpose(1, 2) - return x + return x, torch.stack(new_cnn_caches), torch.stack(new_att_caches) diff --git a/cosyvoice2/flow/flow_matching.py b/cosyvoice2/flow/flow_matching.py index ffdb7d7..701370b 100644 --- a/cosyvoice2/flow/flow_matching.py +++ b/cosyvoice2/flow/flow_matching.py @@ -39,8 +39,6 @@ def __init__(self, estimator: DiT, inference_cfg_rate:float=0.7): # a maximum of 600s self.register_buffer('rand_noise', torch.randn([1, self.out_channels, 50 * 600]), persistent=False) - self.register_buffer('cnn_cache_buffer', torch.zeros(16, 16, 2, 1024, 2), persistent=False) - self.register_buffer('att_cache_buffer', torch.zeros(16, 16, 2, 8, 1000, 128), persistent=False) self.register_buffer('dummy_buffer', torch.zeros(1), persistent=False) def scatter_cuda_graph(self, enable_cuda_graph: bool): @@ -224,6 +222,8 @@ def solve_euler_chunk(self, att_cache = [None for _ in range(len(t_span)-1)] att_cache = torch.empty((len(t_span)-1, 16, x.shape[0] * 2, 8, 0, 128), device=x.device, dtype=x.dtype) # next chunk's cache at each timestep + new_cnn_caches = [] + new_att_caches = [] if att_cache[0] is not None: last_att_len = att_cache.shape[4] @@ -260,11 +260,11 @@ def solve_euler_chunk(self, if step < len(t_span) - 1: dt = t_span[step + 1] - t - self.cnn_cache_buffer[step-1] = this_new_cnn_cache - self.att_cache_buffer[step-1][:, :, :, :x.shape[2]+last_att_len, :] = this_new_att_cache + new_cnn_caches.append(this_new_cnn_cache) + new_att_caches.append(this_new_att_cache) - cnn_cache = self.cnn_cache_buffer - att_cache = self.att_cache_buffer[:, :, :, :, :x.shape[2]+last_att_len, :] + cnn_cache = torch.stack(new_cnn_caches) + att_cache = torch.stack(new_att_caches) return x, cnn_cache, att_cache @torch.inference_mode() diff --git a/export_onnx.py b/export_onnx.py new file mode 100644 index 0000000..9d21444 --- /dev/null +++ b/export_onnx.py @@ -0,0 +1,125 @@ +# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com) +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
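+""" Example Usage (a sketch assuming the defaults in get_args() below):
+    python3 export_onnx.py \
+        --model_dir Step-Audio-2-mini/token2wav \
+        --onnx_model flow.decoder.estimator.fp32.dynamic_batch.onnx
+The exported graph takes (x, mask, mu, t, spks, cond) with dynamic batch_size and
+seq_len axes; it is the ONNX that CosyVoice2_Token2Wav.load_trt() converts into a
+TensorRT plan for the offline (non-streaming) path when --enable-trt is set.
+"""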
+ +from __future__ import print_function + +import argparse +import logging +logging.getLogger('matplotlib').setLevel(logging.WARNING) +import os +import sys +import onnxruntime +import random +import torch +from tqdm import tqdm +from hyperpyyaml import load_hyperpyyaml + + +def get_dummy_input(batch_size, seq_len, out_channels, device): + x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) + mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device) + mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) + t = torch.rand((batch_size), dtype=torch.float32, device=device) + spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device) + cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) + return x, mask, mu, t, spks, cond + + +def get_args(): + parser = argparse.ArgumentParser(description='export your model for deployment') + parser.add_argument('--model_dir', + type=str, + default='Step-Audio-2-mini/token2wav', + help='local path') + parser.add_argument('--onnx_model', + type=str, + default='flow.decoder.estimator.fp32.dynamic_batch.onnx', + help='onnx model name') + args = parser.parse_args() + print(args) + return args + + +@torch.no_grad() +def main(): + args = get_args() + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + + with open(f"{args.model_dir}/flow.yaml", "r") as f: + configs = load_hyperpyyaml(f) + flow_model = configs['flow'] + + device = torch.device('cuda') + + + # 1. export flow decoder estimator + flow_model.load_state_dict(torch.load(f"{args.model_dir}/flow.pt", map_location="cpu", weights_only=True), strict=True) + estimator = flow_model.decoder.estimator + estimator.eval() + estimator.to(device) + + + batch_size, seq_len = 2, 256 + out_channels = flow_model.decoder.estimator.out_channels + x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device) + torch.onnx.export( + estimator, + (x, mask, mu, t, spks, cond), + f'{args.model_dir}/{args.onnx_model}', + export_params=True, + opset_version=18, + do_constant_folding=True, + input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'], + output_names=['estimator_out'], + dynamic_axes={ + 'x': {0: 'batch_size', 2: 'seq_len'}, + 'mask': {0: 'batch_size', 2: 'seq_len'}, + 'mu': {0: 'batch_size', 2: 'seq_len'}, + 'cond': {0: 'batch_size', 2: 'seq_len'}, + 't': {0: 'batch_size'}, + 'spks': {0: 'batch_size'}, + 'estimator_out': {0: 'batch_size', 2: 'seq_len'}, + + } + ) + + # 2. 
test computation consistency + option = onnxruntime.SessionOptions() + option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + option.intra_op_num_threads = 1 + providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider'] + estimator_onnx = onnxruntime.InferenceSession(f'{args.model_dir}/{args.onnx_model}', + sess_options=option, providers=providers) + + for _ in tqdm(range(10)): + x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device) + output_pytorch = estimator(x, mask, mu, t, spks, cond) + ort_inputs = { + 'x': x.cpu().numpy(), + 'mask': mask.cpu().numpy(), + 'mu': mu.cpu().numpy(), + 't': t.cpu().numpy(), + 'spks': spks.cpu().numpy(), + 'cond': cond.cpu().numpy() + } + output_onnx = estimator_onnx.run(None, ort_inputs)[0] + torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4) + logging.info('successfully export estimator') + + +if __name__ == "__main__": + main() + diff --git a/export_onnx_chunk.py b/export_onnx_chunk.py new file mode 100644 index 0000000..e6a99e3 --- /dev/null +++ b/export_onnx_chunk.py @@ -0,0 +1,155 @@ +# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com) +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
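+""" Example Usage (a sketch assuming the defaults in get_args() below):
+    python3 export_onnx_chunk.py \
+        --model_dir Step-Audio-2-mini/token2wav \
+        --onnx_model flow.decoder.estimator.chunk.fp32.dynamic_batch.onnx
+Unlike export_onnx.py, this wraps DiT.forward_chunk(), so the graph also carries
+cnn_cache (depth, batch, 1024, 2) and att_cache (depth, batch, num_heads,
+prev_seq_len, head_dim * 2) for stateful streaming inference. The streaming branch
+of CosyVoice2_Token2Wav.load_trt() expects a *.simplify.onnx variant of this file;
+that extra simplification pass (e.g. with onnx-simplifier) is assumed here and is
+not part of this script.
+"""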
+ +from __future__ import print_function + +import argparse +import logging +logging.getLogger('matplotlib').setLevel(logging.WARNING) +import os +import sys +import onnxruntime +import random +import torch +from tqdm import tqdm +from hyperpyyaml import load_hyperpyyaml + + +def get_dummy_input_chunk(batch_size, seq_len, prev_seq_len, out_channels, estimator, device): + x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) + mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) + t = torch.rand((batch_size), dtype=torch.float32, device=device) + spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device) + cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) + + depth = len(estimator.blocks) + num_heads = estimator.blocks[0].attn.num_heads + head_dim = estimator.blocks[0].attn.head_dim + cnn_channels = estimator.blocks[0].conv.in_channels + estimator.blocks[0].conv.out_channels + + cnn_cache = torch.rand((depth, batch_size, cnn_channels, 2), dtype=torch.float32, device=device) + att_cache = torch.rand((depth, batch_size, num_heads, prev_seq_len, head_dim * 2), dtype=torch.float32, device=device) + return x, mu, t, spks, cond, cnn_cache, att_cache + + +def get_args(): + parser = argparse.ArgumentParser(description='export your model for deployment') + parser.add_argument('--model_dir', + type=str, + default='Step-Audio-2-mini/token2wav', + help='local path') + parser.add_argument('--onnx_model', + type=str, + default='flow.decoder.estimator.chunk.fp32.dynamic_batch.onnx', + help='onnx model name') + args = parser.parse_args() + print(args) + return args + + +class DiTChunkWrapper(torch.nn.Module): + def __init__(self, dit_model): + super().__init__() + self.dit_model = dit_model + + def forward(self, x, mu, t, spks, cond, cnn_cache, att_cache): + return self.dit_model.forward_chunk(x, mu, t, spks, cond, cnn_cache, att_cache) + + +@torch.no_grad() +def main(): + args = get_args() + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + + with open(f"{args.model_dir}/flow.yaml", "r") as f: + configs = load_hyperpyyaml(f) + flow_model = configs['flow'] + + device = torch.device('cuda') + + + # 1. 
export flow decoder estimator for chunk processing + flow_model.load_state_dict(torch.load(f"{args.model_dir}/flow.pt", map_location="cpu", weights_only=True), strict=True) + estimator = flow_model.decoder.estimator + estimator.eval() + estimator.to(device) + + estimator_chunk_wrapper = DiTChunkWrapper(estimator) + + batch_size, seq_len, prev_seq_len = 2, 500, 100 + out_channels = flow_model.decoder.estimator.out_channels + dummy_inputs = get_dummy_input_chunk(batch_size, seq_len, prev_seq_len, out_channels, estimator, device) + (x, mu, t, spks, cond, cnn_cache, att_cache) = dummy_inputs + + torch.onnx.export( + estimator_chunk_wrapper, + dummy_inputs, + f'{args.model_dir}/{args.onnx_model}', + export_params=True, + opset_version=18, + do_constant_folding=True, + input_names=['x', 'mu', 't', 'spks', 'cond', 'cnn_cache', 'att_cache'], + output_names=['output', 'new_cnn_cache', 'new_att_cache'], + dynamic_axes={ + 'x': {0: 'batch_size', 2: 'seq_len'}, + 'mu': {0: 'batch_size', 2: 'seq_len'}, + 'cond': {0: 'batch_size', 2: 'seq_len'}, + 't': {0: 'batch_size'}, + 'spks': {0: 'batch_size'}, + 'cnn_cache': {1: 'batch_size'}, + 'att_cache': {1: 'batch_size', 3: 'prev_seq_len'}, + 'output': {0: 'batch_size', 2: 'seq_len'}, + 'new_cnn_cache': {1: 'batch_size'}, + 'new_att_cache': {1: 'batch_size', 3: 'total_seq_len'}, + } + ) + + # 2. test computation consistency + option = onnxruntime.SessionOptions() + option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + option.intra_op_num_threads = 1 + providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider'] + estimator_onnx = onnxruntime.InferenceSession(f'{args.model_dir}/{args.onnx_model}', + sess_options=option, providers=providers) + + for _ in tqdm(range(10)): + seq_len = random.randint(16, 512) + prev_seq_len = random.randint(16, 1024) + dummy_inputs = get_dummy_input_chunk(batch_size, seq_len, prev_seq_len, out_channels, estimator, device) + (x, mu, t, spks, cond, cnn_cache, att_cache) = dummy_inputs + + output_pytorch, new_cnn_cache_pytorch, new_att_cache_pytorch = estimator_chunk_wrapper(*dummy_inputs) + + ort_inputs = { + 'x': x.cpu().numpy(), + 'mu': mu.cpu().numpy(), + 't': t.cpu().numpy(), + 'spks': spks.cpu().numpy(), + 'cond': cond.cpu().numpy(), + 'cnn_cache': cnn_cache.cpu().numpy(), + 'att_cache': att_cache.cpu().numpy(), + } + output_onnx, new_cnn_cache_onnx, new_att_cache_onnx = estimator_onnx.run(None, ort_inputs) + + torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4) + torch.testing.assert_allclose(new_cnn_cache_pytorch, torch.from_numpy(new_cnn_cache_onnx).to(device), rtol=1e-2, atol=1e-4) + torch.testing.assert_allclose(new_att_cache_pytorch, torch.from_numpy(new_att_cache_onnx).to(device), rtol=1e-2, atol=1e-4) + + logging.info('successfully export chunk-wise estimator') + + +if __name__ == "__main__": + main() diff --git a/token2wav_dit.py b/token2wav_dit.py index 69db946..9389c08 100644 --- a/token2wav_dit.py +++ b/token2wav_dit.py @@ -362,16 +362,15 @@ def get_prompt_audio_cache_for_streaming_tts( spk_emb_for_flow.to(self.device), n_timesteps=10 ) - - # cache dict's tensor batch dim is 1 for now - return cache + new_cache = {k: v.clone() for k, v in cache.items()} + # Hack: this is a hack to avoid in-place changes to the cache['estimator_att_cache'] and cache['estimator_cnn_cache'] + return new_cache @torch.inference_mode() def forward_streaming( self, generated_speech_tokens: list[int], last_chunk: bool, 
request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000 - ): - + ): if speaker_id not in self.speaker_cache: assert prompt_audio is not None, "prompt_audio is required for new speaker" assert prompt_audio_sample_rate == 16000 @@ -382,13 +381,14 @@ def forward_streaming( prompt_mels_for_flow = prompt_mels_for_flow[:, :2 * token_len].contiguous() prompt_speech_tokens_list[0] = prompt_speech_tokens_list[0][:token_len] - cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow) prompt_audio_dict = {'spk_emb_for_flow': spk_emb_for_flow, 'prompt_mels_for_flow': prompt_mels_for_flow} - + + cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow) self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict} + print(f"speaker_id {speaker_id} added to cache") if request_id not in self.streaming_flow_cache: - self.streaming_flow_cache[request_id] = self.speaker_cache[speaker_id]['cache_dict'].copy() + self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()} self.hift_cache_dict[request_id] = dict( mel = torch.zeros(1, 80, 0, device='cuda'), source = torch.zeros(1, 1, 0, device='cuda'), @@ -396,12 +396,14 @@ def forward_streaming( ) current_request_cache = self.streaming_flow_cache[request_id] - prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict'] + + current_prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict'] generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda') + chunk_mel, new_streaming_flow_cache = self.flow.inference_chunk( token=generated_speech_tokens, - spk=prompt_audio_dict['spk_emb_for_flow'].to(self.device), + spk=current_prompt_audio_dict['spk_emb_for_flow'].to(self.device), cache=current_request_cache, last_chunk=last_chunk, n_timesteps=10, @@ -409,18 +411,19 @@ def forward_streaming( self.streaming_flow_cache[request_id] = new_streaming_flow_cache - if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100): + + if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (current_prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100): self.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.cat([ - self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :prompt_audio_dict['prompt_mels_for_flow'].shape[1]], + self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :current_prompt_audio_dict['prompt_mels_for_flow'].shape[1]], self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, -100:], ], dim=4) - hift_cache_mel = self.hift_cache_dict[request_id]['mel'] - hift_cache_source = self.hift_cache_dict[request_id]['source'] - hift_cache_speech = self.hift_cache_dict[request_id]['speech'] - mel = torch.concat([hift_cache_mel, chunk_mel], dim=2) + hift_cache_mel = self.hift_cache_dict[request_id]['mel'].clone() + hift_cache_source = self.hift_cache_dict[request_id]['source'].clone() + hift_cache_speech = self.hift_cache_dict[request_id]['speech'].clone() + mel = torch.concat([hift_cache_mel, chunk_mel], dim=2).clone() speech, source = self.hift(mel, hift_cache_source) @@ -441,7 +444,7 @@ def forward_streaming( assert 
request_id in self.streaming_flow_cache self.streaming_flow_cache.pop(request_id) self.hift_cache_dict.pop(request_id) - + return speech def collate_fn(batch): From c9b95f46251e9654093c53067ed987b87de929be Mon Sep 17 00:00:00 2001 From: yuekaiz Date: Tue, 21 Oct 2025 15:49:23 +0800 Subject: [PATCH 07/12] lint --- token2wav_dit.py | 145 +++++++++++++++++++++++++---------------------- 1 file changed, 78 insertions(+), 67 deletions(-) diff --git a/token2wav_dit.py b/token2wav_dit.py index 9389c08..1c6c423 100644 --- a/token2wav_dit.py +++ b/token2wav_dit.py @@ -35,7 +35,7 @@ from hyperpyyaml import load_hyperpyyaml -def fade_in_out(fade_in_mel:torch.Tensor, fade_out_mel:torch.Tensor, window:torch.Tensor): +def fade_in_out(fade_in_mel: torch.Tensor, fade_out_mel: torch.Tensor, window: torch.Tensor): """perform fade_in_out in tensor style """ mel_overlap_len = int(window.shape[0] / 2) @@ -45,6 +45,7 @@ def fade_in_out(fade_in_mel:torch.Tensor, fade_out_mel:torch.Tensor, window:torc fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:] return fade_in_mel + def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype): import tensorrt as trt logging.info("Converting onnx to trt...") @@ -57,10 +58,7 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype): # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB if dtype == torch.float16: config.set_flag(trt.BuilderFlag.FP16) - elif dtype == torch.bfloat16: - config.set_flag(trt.BuilderFlag.BF16) - elif dtype == torch.float32: - config.set_flag(trt.BuilderFlag.FP32) + profile = builder.create_optimization_profile() # load onnx model with open(onnx_model, "rb") as f: @@ -93,6 +91,7 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype): f.write(engine_bytes) logging.info("Succesfully convert onnx to trt...") + class TrtContextWrapper: def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'): self.trt_context_pool = queue.Queue(maxsize=trt_concurrent) @@ -111,6 +110,7 @@ def acquire_estimator(self): def release_estimator(self, context, stream): self.trt_context_pool.put([context, stream]) + class CosyVoice2_Token2Wav(torch.nn.Module): def __init__(self, model_dir: str, enable_trt: bool = False, device_id: int = 0, streaming: bool = False, dtype: torch.dtype = torch.float16): super().__init__() @@ -134,28 +134,33 @@ def __init__(self, model_dir: str, enable_trt: bool = False, device_id: int = 0, option = onnxruntime.SessionOptions() option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL option.intra_op_num_threads = 1 - self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option, - providers=["CPUExecutionProvider"]) - + self.spk_model = onnxruntime.InferenceSession( + f"{model_dir}/campplus.onnx", sess_options=option, + providers=["CPUExecutionProvider"]) self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2_25hz.onnx").to(self.device).eval() - gpu="l20" + gpu = "l20" if enable_trt: if streaming: - self.load_trt(f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.chunk.{gpu}.plan', - f'{model_dir}/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx', - 1, - self.dtype, streaming) + self.load_trt( + f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.chunk.{gpu}.plan', + f'{model_dir}/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx', + 1, + self.dtype, streaming + ) else: - 
self.load_trt(f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.{gpu}.plan', - f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx', - 1, - self.dtype) - self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt', - f'{model_dir}/campplus.onnx', - 1, - False) - + self.load_trt( + f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.{gpu}.plan', + f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx', + 1, + self.dtype + ) + self.load_spk_trt( + f'{model_dir}/campplus.{gpu}.fp32.trt', + f'{model_dir}/campplus.onnx', + 1, + False + ) self.streaming_flow_cache = {} self.speaker_cache = {} @@ -199,7 +204,7 @@ def forward_spk_embedding(self, spk_feat): def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True): if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0: trt_kwargs = self.get_spk_trt_kwargs() - convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, fp16) + convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, torch.float32) import tensorrt as trt with open(spk_model, 'rb') as f: spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) @@ -219,7 +224,7 @@ def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_co opt_batch_size = 2 max_batch_size = 16 if streaming: - opt_batch_size, max_batch_size = 1, 1 # only support batch size 1 for streaming tts + opt_batch_size, max_batch_size = 1, 1 # only support batch size 1 for streaming tts trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_batch_size=opt_batch_size, max_batch_size=max_batch_size, streaming=streaming) convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, dtype) del self.flow.decoder.estimator @@ -232,13 +237,27 @@ def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_co def get_trt_kwargs_dynamic_batch(self, opt_batch_size=2, max_batch_size=64, streaming=False): if streaming: min_shape = [(2, 80, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80), (16, 2, 1024, 2), (16, 2, 8, 0, 128)] - opt_shape = [(opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2,), (opt_batch_size*2, 80), (16, opt_batch_size*2, 1024, 2), (16, opt_batch_size*2, 8, 100, 128)] - max_shape = [(max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2,), (max_batch_size*2, 80), (16, max_batch_size*2, 1024, 2), (16, max_batch_size*2, 8, 1000, 128)] + opt_shape = [ + (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500), + (opt_batch_size * 2,), (opt_batch_size * 2, 80), (16, opt_batch_size * 2, 1024, 2), + (16, opt_batch_size * 2, 8, 100, 128) + ] + max_shape = [ + (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), + (max_batch_size * 2,), (max_batch_size * 2, 80), (16, max_batch_size * 2, 1024, 2), + (16, max_batch_size * 2, 8, 1000, 128) + ] input_names = ["x", "mu", "cond", "t", "spks", "cnn_cache", "att_cache"] else: min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)] - opt_shape = [(opt_batch_size*2, 80, 500), (opt_batch_size*2, 1, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2,), (opt_batch_size*2, 80)] - max_shape = [(max_batch_size*2, 80, 3000), (max_batch_size*2, 1, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2,), (max_batch_size*2, 80)] + opt_shape = [ + (opt_batch_size * 2, 80, 500), 
(opt_batch_size * 2, 1, 500), (opt_batch_size * 2, 80, 500), + (opt_batch_size * 2, 80, 500), (opt_batch_size * 2,), (opt_batch_size * 2, 80) + ] + max_shape = [ + (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 1, 3000), (max_batch_size * 2, 80, 3000), + (max_batch_size * 2, 80, 3000), (max_batch_size * 2,), (max_batch_size * 2, 80) + ] input_names = ["x", "mask", "mu", "cond", "t", "spks"] return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} @@ -256,7 +275,7 @@ def prompt_audio_tokenization(self, prompt_audios_list: list[torch.Tensor]) -> l speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist() prompt_speech_tokens_list.append(speech_tokens_i) return prompt_speech_tokens_list - + def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor: spk_emb_for_flow = [] for audio in prompt_audios_list: @@ -266,11 +285,11 @@ def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor: spk_emb = self.forward_spk_embedding(spk_feat) spk_emb_for_flow.append(spk_emb) - spk_emb_for_flow = torch.tensor(spk_emb_for_flow) + spk_emb_for_flow = torch.tensor(spk_emb_for_flow) if self.dtype != torch.float32: spk_emb_for_flow = spk_emb_for_flow.to(self.dtype) return spk_emb_for_flow - + def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]): prompt_mels_for_flow = [] prompt_mels_lens_for_flow = [] @@ -283,11 +302,17 @@ def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_ mel_len = mel.shape[0] prompt_mels_for_flow.append(mel) prompt_mels_lens_for_flow.append(mel_len) - prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0) # [B, T', num_mels=80] + prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence( + prompt_mels_for_flow, batch_first=True, padding_value=0 + ) # [B, T', num_mels=80] prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow) return prompt_mels_for_flow, prompt_mels_lens_for_flow - - def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor): + + def forward_flow(self, prompt_speech_tokens_list: list[list[int]], + generated_speech_tokens_list: list[list[int]], + prompt_mels_for_flow: torch.Tensor, + prompt_mels_lens_for_flow: torch.Tensor, + spk_emb_for_flow: torch.Tensor): batch_size = prompt_mels_for_flow.shape[0] flow_inputs = [] flow_inputs_lens = [] @@ -315,39 +340,34 @@ def forward_hift(self, generated_mels: torch.Tensor, generated_mels_lens: torch. 
generated_wavs.append(wav) return generated_wavs - @torch.inference_mode() def forward( self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int] ): - # assert all item in prompt_audios_sample_rate is 16000 assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate) - prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio(prompt_audios_list, prompt_audios_sample_rate) - generated_mels, generated_mels_lens = self.forward_flow(prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow) + generated_mels, generated_mels_lens = self.forward_flow( + prompt_speech_tokens_list, generated_speech_tokens_list, + prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow + ) generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow) - return generated_wavs def prepare_prompt_audio( self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int] ): - # assert all item in prompt_audios_sample_rate is 16000 assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate) - prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list) prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(prompt_audios_list, prompt_audios_sample_rate) spk_emb_for_flow = self.get_spk_emb(prompt_audios_list) - return prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow - def get_prompt_audio_cache_for_streaming_tts( self, prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow ): @@ -366,11 +386,10 @@ def get_prompt_audio_cache_for_streaming_tts( # Hack: this is a hack to avoid in-place changes to the cache['estimator_att_cache'] and cache['estimator_cnn_cache'] return new_cache - @torch.inference_mode() def forward_streaming( self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000 - ): + ): if speaker_id not in self.speaker_cache: assert prompt_audio is not None, "prompt_audio is required for new speaker" assert prompt_audio_sample_rate == 16000 @@ -385,14 +404,13 @@ def forward_streaming( cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow) self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict} - print(f"speaker_id {speaker_id} added to cache") if request_id not in self.streaming_flow_cache: self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()} self.hift_cache_dict[request_id] = dict( - mel = torch.zeros(1, 80, 0, device='cuda'), - source = torch.zeros(1, 1, 0, device='cuda'), - speech = torch.zeros(1, 0, device='cuda'), + mel=torch.zeros(1, 80, 0, device='cuda'), + source=torch.zeros(1, 1, 0, device='cuda'), + speech=torch.zeros(1, 0, device='cuda'), ) current_request_cache = self.streaming_flow_cache[request_id] @@ -400,7 +418,6 @@ def forward_streaming( current_prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict'] generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda') - chunk_mel, new_streaming_flow_cache = self.flow.inference_chunk( 
token=generated_speech_tokens, spk=current_prompt_audio_dict['spk_emb_for_flow'].to(self.device), @@ -411,15 +428,12 @@ def forward_streaming( self.streaming_flow_cache[request_id] = new_streaming_flow_cache - if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (current_prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100): self.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.cat([ self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :current_prompt_audio_dict['prompt_mels_for_flow'].shape[1]], self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, -100:], ], dim=4) - - hift_cache_mel = self.hift_cache_dict[request_id]['mel'].clone() hift_cache_source = self.hift_cache_dict[request_id]['source'].clone() hift_cache_speech = self.hift_cache_dict[request_id]['speech'].clone() @@ -433,9 +447,9 @@ def forward_streaming( # update vocoder cache self.hift_cache_dict[request_id] = dict( - mel = mel[..., -self.mel_cache_len:].clone().detach(), - source = source[:, :, -self.source_cache_len:].clone().detach(), - speech = speech[:, -self.source_cache_len:].clone().detach(), + mel=mel[..., -self.mel_cache_len:].clone().detach(), + source=source[:, :, -self.source_cache_len:].clone().detach(), + speech=speech[:, -self.source_cache_len:].clone().detach(), ) if not last_chunk: speech = speech[:, :-self.source_cache_len] @@ -447,17 +461,19 @@ def forward_streaming( return speech + def collate_fn(batch): ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], [] - for i, item in enumerate(batch): + for item in batch: generated_speech_tokens_list.append(item['target_audio_cosy2_tokens']) - audio = torch.from_numpy(item['prompt_audio']['array']).float() + audio = torch.from_numpy(item['prompt_audio']['array']).float() prompt_audios_list.append(audio) prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate']) ids.append(item['id']) return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate + def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--enable-trt", action="store_true") @@ -468,32 +484,27 @@ def get_args(): parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs, performance statistics will only be collected from the last epoch") return parser.parse_args() + if __name__ == "__main__": args = get_args() model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt) - # mkdir output_dir if not exists if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) dataset_name = "yuekai/seed_tts_cosy2" dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True) - data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0) - - - for epoch in range(args.warmup): + + for _ in range(args.warmup): start_time = time.time() - for batch in data_loader: ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate) - for id, wav in zip(ids, generated_wavs): torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000) - end_time = time.time() epoch_time = end_time - start_time - print(f"Measurement epoch time taken: {epoch_time:.4f} seconds") \ No newline at end of file + print(f"Measurement epoch time taken: {epoch_time:.4f} seconds") From 
96ae8882c0f487a072ae00d994f5a8fd312707d6 Mon Sep 17 00:00:00 2001 From: yuekaiz Date: Tue, 21 Oct 2025 21:07:47 +0800 Subject: [PATCH 08/12] code clean --- benchmark_streaming_token2wav.py | 111 ---- cosyvoice2/flow/flow.py | 15 - cosyvoice2/flow/flow_matching.py | 5 - token2wav.py | 585 ++++++++++++++---- token2wav_dit.py | 510 --------------- .../export_onnx_offline_token2wav.py | 0 .../export_onnx_streaming_token2wav.py | 0 7 files changed, 467 insertions(+), 759 deletions(-) delete mode 100644 benchmark_streaming_token2wav.py delete mode 100644 token2wav_dit.py rename export_onnx.py => tools/export_onnx_offline_token2wav.py (100%) rename export_onnx_chunk.py => tools/export_onnx_streaming_token2wav.py (100%) diff --git a/benchmark_streaming_token2wav.py b/benchmark_streaming_token2wav.py deleted file mode 100644 index 70b1199..0000000 --- a/benchmark_streaming_token2wav.py +++ /dev/null @@ -1,111 +0,0 @@ -import torch -import os -import argparse -from datasets import load_dataset -from torch.utils.data import DataLoader -import numpy as np -import torchaudio -import time -from token2wav_dit import CosyVoice2_Token2Wav -import soundfile as sf - -def collate_fn(batch): - ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], [] - prompt_speech_tokens_list, prompt_text_list = [], [] - for i, item in enumerate(batch): - generated_speech_tokens_list.append(item['target_audio_cosy2_tokens']) - audio = torch.from_numpy(item['prompt_audio']['array']).float() - prompt_audios_list.append(audio) - prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate']) - ids.append(item['id']) - prompt_speech_tokens_list.append(item['prompt_audio_cosy2_tokens']) - prompt_text_list.append(item['prompt_text']) - - return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--enable-trt", action="store_true") - parser.add_argument("--model-dir", type=str, default="./Step-Audio-2-mini/token2wav") - parser.add_argument("--batch-size", type=int, default=1) - parser.add_argument("--output-dir", type=str, default="generated_wavs") - parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts") - parser.add_argument("--dataset-name", type=str, default="yuekai/seed_tts_cosy2") - return parser.parse_args() - - -def fake_generated_id_iter(generated_speech_tokens_list): - for i in range(len(generated_speech_tokens_list)): - yield generated_speech_tokens_list[i] - - - -if __name__ == "__main__": - args = get_args() - - if not os.path.exists(args.output_dir): - os.makedirs(args.output_dir) - - dataset_name = args.dataset_name - dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True) - data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0) - - token2wav_model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt, streaming=True) - - flow_pre_lookahead_len = 3 - CHUNK_SIZE = 25 - OVERLAP_SIZE = 0 - - warmup_times = 3 - for _ in range(warmup_times): - start_time = time.time() - for batch in data_loader: - tts_speech_list = [] - ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list = batch - - id, generated_speech_tokens, prompt_audio, prompt_audio_sample_rate = ids[0], generated_speech_tokens_list[0], 
prompt_audios_list[0], prompt_audios_sample_rate[0] - assert prompt_audio_sample_rate == 16000 - - prompt_text = prompt_text_list[0] - prompt_speech_tokens = prompt_speech_tokens_list[0] - - - # generated_ids_iter = fake_generated_id_iter(generated_speech_tokens) - - semantic_token_ids_arr, token_offset = [], 0 - flow_prompt_speech_token_len = len(prompt_speech_tokens) - - buffer = generated_speech_tokens - output_wavs = [] - while True: - - if len(buffer) >= CHUNK_SIZE + token2wav_model.flow.pre_lookahead_len: - wavs = token2wav_model.forward_streaming(buffer[:CHUNK_SIZE + token2wav_model.flow.pre_lookahead_len], False, request_id=id, speaker_id=id, prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate) - buffer = buffer[CHUNK_SIZE - OVERLAP_SIZE:] - - output_wavs.append(wavs) - - else: - wavs = token2wav_model.forward_streaming(buffer, True, request_id=id, speaker_id=id, prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate) - output_wavs.append(wavs) - break - - # tts_speech = torch.cat(output_wavs, dim=-1) - # torchaudio.save(os.path.join(args.output_dir, f"{id}.wav"), tts_speech.cpu(), 24000) - - for i, wav in enumerate(output_wavs): - output_wavs[i] = wav.cpu().numpy().squeeze() - - - audios = output_wavs - reconstructed_audio = np.concatenate(audios) - # Save reconstructed audio - sf.write(os.path.join(args.output_dir, f"{id}.wav"), reconstructed_audio, 24000, "PCM_16") - - - print(f"Saved {id}") - end_time = time.time() - if _ == 0: - token2wav_model.speaker_cache = {} - print(f"Warmup time: {end_time - start_time} seconds") - diff --git a/cosyvoice2/flow/flow.py b/cosyvoice2/flow/flow.py index 51d7492..d1aa4ac 100644 --- a/cosyvoice2/flow/flow.py +++ b/cosyvoice2/flow/flow.py @@ -65,22 +65,14 @@ def scatter_cuda_graph(self, enable_cuda_graph: bool): def inference(self, token, token_len, - # prompt_token, - # prompt_token_len, prompt_feat, prompt_feat_len, embedding, n_timesteps: int = 10, ): - # assert token.shape[0] == 1 - # xvec projection embedding = F.normalize(embedding, dim=1) embedding = self.spk_embed_affine_layer(embedding) - - # concat text and prompt_text - # token_len = prompt_token_len + token_len - # token = torch.concat([prompt_token, token], dim=1) mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding) token = self.input_embedding(torch.clamp(token, min=0)) * mask @@ -89,12 +81,7 @@ def inference(self, h, h_lengths = self.encoder.forward(token, token_len) h = self.encoder_proj(h) - # condition - # mel_len1 = prompt_feat.shape[1] - # mel_len2 = h.shape[1] - prompt_feat.shape[1] - conds = torch.zeros_like(h) - # conds[:, :mel_len1] = prompt_feat for i, j in enumerate(prompt_feat_len): conds[i, :j] = prompt_feat[i, :j] conds = conds.transpose(1, 2).contiguous() @@ -110,8 +97,6 @@ def inference(self, n_timesteps=n_timesteps, ) - # feat = feat[:, :, mel_len1:] - # assert feat.shape[2] == mel_len2 return feat.float(), h_lengths @torch.inference_mode() diff --git a/cosyvoice2/flow/flow_matching.py b/cosyvoice2/flow/flow_matching.py index 701370b..f136b87 100644 --- a/cosyvoice2/flow/flow_matching.py +++ b/cosyvoice2/flow/flow_matching.py @@ -91,7 +91,6 @@ def solve_euler(self, x, t_span, mu, mask, spks, cond): cond: Not used but kept for future purposes """ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] - # t = t.unsqueeze(dim=0) t_in = torch.zeros([x.shape[0] * 2], device=x.device, dtype=x.dtype) assert self.inference_cfg_rate > 0, 'inference_cfg_rate better > 0' @@ -124,7 +123,6 @@ def solve_euler(self, x, 
t_span, mu, mask, spks, cond): @torch.inference_mode() def forward(self, mu, mask, spks, cond, n_timesteps=10, temperature=1.0): - # z = self.rand_noise[:, :, :mu.size(2)] * temperature z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype) # cosine scheduling @@ -211,7 +209,6 @@ def solve_euler_chunk(self, assert self.inference_cfg_rate > 0, 'cfg rate should be > 0' t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] - # t = t.unsqueeze(dim=0) # (b,) t_in = torch.zeros([x.shape[0] * 2], device=x.device, dtype=x.dtype) # setup initial cache @@ -288,9 +285,7 @@ def forward_chunk(self, # get offset from att_cache offset = att_cache.shape[4] if att_cache is not None else 0 z = self.rand_noise[:, :, offset:offset+mu.size(2)] * temperature - z = z.to(mu.dtype) - t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype) # cosine scheduling t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) diff --git a/token2wav.py b/token2wav.py index ad4163e..8bdd1d6 100644 --- a/token2wav.py +++ b/token2wav.py @@ -1,14 +1,37 @@ -import io - +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
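+# In addition to the offline example below, the streaming path is driven chunk by
+# chunk. A hypothetical caller (mirroring the chunking loop of the deleted
+# benchmark_streaming_token2wav.py; ids and tensors are placeholders) looks like:
+#
+#     model = CosyVoice2_Token2Wav(model_dir, enable_trt=True, streaming=True)
+#     CHUNK_SIZE = 25
+#     buffer = generated_speech_tokens          # list[int] from the LLM
+#     wavs = []                                 # prompt_audio: 16 kHz mono tensor
+#     while len(buffer) >= CHUNK_SIZE + model.flow.pre_lookahead_len:
+#         wavs.append(model.forward_streaming(
+#             buffer[:CHUNK_SIZE + model.flow.pre_lookahead_len], False,
+#             request_id="req0", speaker_id="spk0",
+#             prompt_audio=prompt_audio, prompt_audio_sample_rate=16000))
+#         buffer = buffer[CHUNK_SIZE:]
+#     wavs.append(model.forward_streaming(
+#         buffer, True, request_id="req0", speaker_id="spk0",
+#         prompt_audio=prompt_audio, prompt_audio_sample_rate=16000))
+#     audio_24k = torch.cat(wavs, dim=-1)       # [1, T] waveform at 24 kHz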
+""" Example Usage + CUDA_VISIBLE_DEVICES=0 \ + python3 token2wav.py --enable-trt || exit 1 +""" import torch -import torchaudio -import s3tokenizer -import onnxruntime -import numpy as np - -import torchaudio.compliance.kaldi as kaldi +# from flashcosyvoice.modules.flow import CausalMaskedDiffWithXvec from flashcosyvoice.modules.hifigan import HiFTGenerator from flashcosyvoice.utils.audio import mel_spectrogram +import torchaudio.compliance.kaldi as kaldi +import onnxruntime +import s3tokenizer +from torch.utils.data import DataLoader +from datasets import load_dataset +import torchaudio +import os +import logging +import argparse +import queue +import time +import numpy as np from hyperpyyaml import load_hyperpyyaml def fade_in_out(fade_in_mel:torch.Tensor, fade_out_mel:torch.Tensor, window:torch.Tensor): @@ -22,122 +45,403 @@ def fade_in_out(fade_in_mel:torch.Tensor, fade_out_mel:torch.Tensor, window:torc return fade_in_mel -class Token2wav(): - - def __init__(self, model_path, float16=False): - self.float16 = float16 - - self.audio_tokenizer = s3tokenizer.load_model(f"{model_path}/speech_tokenizer_v2_25hz.onnx").cuda().eval() +def fade_in_out(fade_in_mel: torch.Tensor, fade_out_mel: torch.Tensor, window: torch.Tensor): + """perform fade_in_out in tensor style + """ + mel_overlap_len = int(window.shape[0] / 2) + fade_in_mel = fade_in_mel.clone() + fade_in_mel[..., :mel_overlap_len] = \ + fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \ + fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:] + return fade_in_mel - option = onnxruntime.SessionOptions() - option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL - option.intra_op_num_threads = 1 - self.spk_model = onnxruntime.InferenceSession(f"{model_path}/campplus.onnx", sess_options=option, providers=["CPUExecutionProvider"]) - with open(f"{model_path}/flow.yaml", "r") as f: +def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype): + import tensorrt as trt + logging.info("Converting onnx to trt...") + network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + logger = trt.Logger(trt.Logger.INFO) + builder = trt.Builder(logger) + network = builder.create_network(network_flags) + parser = trt.OnnxParser(network, logger) + config = builder.create_builder_config() + # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB + if dtype == torch.float16: + config.set_flag(trt.BuilderFlag.FP16) + + profile = builder.create_optimization_profile() + # load onnx model + with open(onnx_model, "rb") as f: + if not parser.parse(f.read()): + for error in range(parser.num_errors): + print(parser.get_error(error)) + raise ValueError('failed to parse {}'.format(onnx_model)) + # set input shapes + for i in range(len(trt_kwargs['input_names'])): + profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i]) + if dtype == torch.float16: + tensor_dtype = trt.DataType.HALF + elif dtype == torch.bfloat16: + tensor_dtype = trt.DataType.BF16 + elif dtype == torch.float32: + tensor_dtype = trt.DataType.FLOAT + else: + raise ValueError('invalid dtype {}'.format(dtype)) + # set input and output data type + for i in range(network.num_inputs): + input_tensor = network.get_input(i) + input_tensor.dtype = tensor_dtype + for i in range(network.num_outputs): + output_tensor = network.get_output(i) + output_tensor.dtype = tensor_dtype + config.add_optimization_profile(profile) + engine_bytes = 
builder.build_serialized_network(network, config) + # save trt engine + with open(trt_model, "wb") as f: + f.write(engine_bytes) + logging.info("Succesfully convert onnx to trt...") + + +class TrtContextWrapper: + def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'): + self.trt_context_pool = queue.Queue(maxsize=trt_concurrent) + self.trt_engine = trt_engine + self.device = device + for _ in range(trt_concurrent): + trt_context = trt_engine.create_execution_context() + trt_stream = torch.cuda.stream(torch.cuda.Stream(torch.device(device))) + assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent) + self.trt_context_pool.put([trt_context, trt_stream]) + assert self.trt_context_pool.empty() is False, 'no avaialbe estimator context' + + def acquire_estimator(self): + return self.trt_context_pool.get(), self.trt_engine + + def release_estimator(self, context, stream): + self.trt_context_pool.put([context, stream]) + + +class CosyVoice2_Token2Wav(torch.nn.Module): + def __init__(self, model_dir: str, enable_trt: bool = False, device_id: int = 0, streaming: bool = False, dtype: torch.dtype = torch.float16): + super().__init__() + self.device_id = device_id + self.device = f"cuda:{device_id}" + with open(f"{model_dir}/flow.yaml", "r") as f: configs = load_hyperpyyaml(f) self.flow = configs['flow'] - if float16: - self.flow.half() - self.flow.load_state_dict(torch.load(f"{model_path}/flow.pt", map_location="cpu", weights_only=True), strict=True) - self.flow.cuda().eval() + + self.dtype = dtype + self.flow.to(self.dtype) + + self.flow.load_state_dict(torch.load(f"{model_dir}/flow.pt", map_location="cpu", weights_only=True), strict=True) + self.flow.to(self.device).eval() self.hift = HiFTGenerator() - hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{model_path}/hift.pt", map_location="cpu", weights_only=True).items()} + hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{model_dir}/hift.pt", map_location="cpu", weights_only=True).items()} self.hift.load_state_dict(hift_state_dict, strict=True) - self.hift.cuda().eval() + self.hift.to(self.device).eval() + + option = onnxruntime.SessionOptions() + option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + option.intra_op_num_threads = 1 + self.spk_model = onnxruntime.InferenceSession( + f"{model_dir}/campplus.onnx", sess_options=option, + providers=["CPUExecutionProvider"]) + self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2_25hz.onnx").to(self.device).eval() + + gpu = "l20" + if enable_trt: + if streaming: + self.load_trt( + f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.chunk.{gpu}.plan', + f'{model_dir}/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx', + 1, + self.dtype, streaming + ) + else: + self.load_trt( + f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.{gpu}.plan', + f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx', + 1, + self.dtype + ) + self.load_spk_trt( + f'{model_dir}/campplus.{gpu}.fp32.trt', + f'{model_dir}/campplus.onnx', + 1, + False + ) - self.cache = {} + self.streaming_flow_cache = {} + self.speaker_cache = {} - # stream conf self.mel_cache_len = 8 # hard-coded, 160ms self.source_cache_len = int(self.mel_cache_len * 480) # 50hz mel -> 24kHz wave self.speech_window = torch.from_numpy(np.hamming(2 * self.source_cache_len)).cuda() - # hifigan cache + # 
hifigan cache for streaming tts self.hift_cache_dict = {} + def forward_spk_embedding(self, spk_feat): + if isinstance(self.spk_model, onnxruntime.InferenceSession): + return self.spk_model.run( + None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()} + )[0].flatten().tolist() + else: + [spk_model, stream], trt_engine = self.spk_model.acquire_estimator() + # NOTE need to synchronize when switching stream + with torch.cuda.device(self.device_id): + torch.cuda.current_stream().synchronize() + spk_feat = spk_feat.unsqueeze(dim=0).to(self.device) + batch_size = spk_feat.size(0) + + with stream: + spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80)) + output_tensor = torch.empty((batch_size, 192), device=spk_feat.device) + + data_ptrs = [spk_feat.contiguous().data_ptr(), + output_tensor.contiguous().data_ptr()] + for i, j in enumerate(data_ptrs): + + spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j) + # run trt engine + assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True + torch.cuda.current_stream().synchronize() + self.spk_model.release_estimator(spk_model, stream) + + return output_tensor.cpu().numpy().flatten().tolist() + + def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True): + if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0: + trt_kwargs = self.get_spk_trt_kwargs() + convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, torch.float32) + import tensorrt as trt + with open(spk_model, 'rb') as f: + spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) + assert spk_engine is not None, 'failed to load trt {}'.format(spk_model) + self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device) + + def get_spk_trt_kwargs(self): + min_shape = [(1, 4, 80)] + opt_shape = [(1, 500, 80)] + max_shape = [(1, 3000, 80)] + input_names = ["input"] + return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} + + def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent=1, dtype=torch.float16, streaming=False): + assert torch.cuda.is_available(), 'tensorrt only supports gpu!' 
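+        # The engine is (re)built only when the .plan file is missing or empty;
+        # otherwise the serialized plan cached on disk is deserialized below. The
+        # plan name embeds the dtype and a GPU tag (hard-coded to "l20" in
+        # __init__), so regenerate or rename it when targeting another GPU or
+        # precision.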
+ if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0: + opt_batch_size = 2 + max_batch_size = 16 + if streaming: + opt_batch_size, max_batch_size = 1, 1 # only support batch size 1 for streaming tts + trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_batch_size=opt_batch_size, max_batch_size=max_batch_size, streaming=streaming) + convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, dtype) + del self.flow.decoder.estimator + import tensorrt as trt + with open(flow_decoder_estimator_model, 'rb') as f: + estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) + assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model) + self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device) + + def get_trt_kwargs_dynamic_batch(self, opt_batch_size=2, max_batch_size=64, streaming=False): + if streaming: + min_shape = [(2, 80, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80), (16, 2, 1024, 2), (16, 2, 8, 0, 128)] + opt_shape = [ + (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500), + (opt_batch_size * 2,), (opt_batch_size * 2, 80), (16, opt_batch_size * 2, 1024, 2), + (16, opt_batch_size * 2, 8, 100, 128) + ] + max_shape = [ + (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), + (max_batch_size * 2,), (max_batch_size * 2, 80), (16, max_batch_size * 2, 1024, 2), + (16, max_batch_size * 2, 8, 1000, 128) + ] + input_names = ["x", "mu", "cond", "t", "spks", "cnn_cache", "att_cache"] + else: + min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)] + opt_shape = [ + (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 1, 500), (opt_batch_size * 2, 80, 500), + (opt_batch_size * 2, 80, 500), (opt_batch_size * 2,), (opt_batch_size * 2, 80) + ] + max_shape = [ + (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 1, 3000), (max_batch_size * 2, 80, 3000), + (max_batch_size * 2, 80, 3000), (max_batch_size * 2,), (max_batch_size * 2, 80) + ] + input_names = ["x", "mask", "mu", "cond", "t", "spks"] + return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} + + def prompt_audio_tokenization(self, prompt_audios_list: list[torch.Tensor]) -> list[list[int]]: + prompt_speech_tokens_list, prompt_speech_mels_list = [], [] + for audio in prompt_audios_list: + assert len(audio.shape) == 1 + log_mel = s3tokenizer.log_mel_spectrogram(audio) # [num_mels, T] + prompt_speech_mels_list.append(log_mel) + prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(prompt_speech_mels_list) + prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize( + prompt_mels_for_llm.to(self.device), prompt_mels_lens_for_llm.to(self.device) + ) + for i in range(len(prompt_speech_tokens)): + speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist() + prompt_speech_tokens_list.append(speech_tokens_i) + return prompt_speech_tokens_list + + def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor: + spk_emb_for_flow = [] + for audio in prompt_audios_list: + assert len(audio.shape) == 1 + spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000) + spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True) + spk_emb = self.forward_spk_embedding(spk_feat) + + 
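+            # forward_spk_embedding() returns a plain list of 192 floats (the
+            # campplus speaker embedding) for both the onnxruntime and TensorRT
+            # backends, so torch.tensor() below stacks the per-utterance vectors
+            # into a [B, 192] matrix that flow.spk_embed_affine_layer later projects.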
spk_emb_for_flow.append(spk_emb) + spk_emb_for_flow = torch.tensor(spk_emb_for_flow) + if self.dtype != torch.float32: + spk_emb_for_flow = spk_emb_for_flow.to(self.dtype) + return spk_emb_for_flow + + def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]): + prompt_mels_for_flow = [] + prompt_mels_lens_for_flow = [] + for audio, sample_rate in zip(prompt_audios_list, prompt_audios_sample_rate): + assert len(audio.shape) == 1 + audio = audio.unsqueeze(0) + if sample_rate != 24000: + audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio) + mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0) # [T, num_mels] + mel_len = mel.shape[0] + prompt_mels_for_flow.append(mel) + prompt_mels_lens_for_flow.append(mel_len) + prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence( + prompt_mels_for_flow, batch_first=True, padding_value=0 + ) # [B, T', num_mels=80] + prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow) + return prompt_mels_for_flow, prompt_mels_lens_for_flow + + def forward_flow(self, prompt_speech_tokens_list: list[list[int]], + generated_speech_tokens_list: list[list[int]], + prompt_mels_for_flow: torch.Tensor, + prompt_mels_lens_for_flow: torch.Tensor, + spk_emb_for_flow: torch.Tensor): + batch_size = prompt_mels_for_flow.shape[0] + flow_inputs = [] + flow_inputs_lens = [] + for prompt_speech_tokens, generated_speech_tokens in zip(prompt_speech_tokens_list, generated_speech_tokens_list): + flow_inputs.append(torch.tensor(prompt_speech_tokens + generated_speech_tokens)) + flow_inputs_lens.append(len(prompt_speech_tokens) + len(generated_speech_tokens)) + + flow_inputs = torch.nn.utils.rnn.pad_sequence(flow_inputs, batch_first=True, padding_value=0) + flow_inputs_lens = torch.tensor(flow_inputs_lens) + + with torch.amp.autocast(self.device, dtype=torch.float16): + generated_mels, generated_mels_lens = self.flow.inference( + flow_inputs.to(self.device), flow_inputs_lens.to(self.device), + prompt_mels_for_flow.to(self.device), prompt_mels_lens_for_flow.to(self.device), spk_emb_for_flow.to(self.device), 10 + ) - def _prepare_prompt(self, prompt_wav): - audio = s3tokenizer.load_audio(prompt_wav, sr=16000) # [T] - mels = s3tokenizer.log_mel_spectrogram(audio) - mels, mels_lens = s3tokenizer.padding([mels]) - prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize(mels.cuda(), mels_lens.cuda()) - - spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000) - spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True) - spk_emb = torch.tensor(self.spk_model.run( - None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()} - )[0], device='cuda') - - audio, sample_rate = torchaudio.load(prompt_wav, backend='soundfile') - audio = audio.mean(dim=0, keepdim=True) # [1, T] - if sample_rate != 24000: - audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio) - prompt_mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0) # [T, num_mels] - prompt_mels = prompt_mel.unsqueeze(0).cuda() - prompt_mels_lens = torch.tensor([prompt_mels.shape[1]], dtype=torch.int32, device='cuda') - prompt_mels = torch.nn.functional.pad(prompt_mels, (0, 0, 0, prompt_speech_tokens.shape[1] * self.flow.up_rate - prompt_mels.shape[1]), mode='replicate') - return prompt_speech_tokens, prompt_speech_tokens_lens, spk_emb, prompt_mels, prompt_mels_lens - - def __call__(self, generated_speech_tokens, prompt_wav): - if prompt_wav 
not in self.cache: - self.cache[prompt_wav] = self._prepare_prompt(prompt_wav) - prompt_speech_tokens, prompt_speech_tokens_lens, spk_emb, prompt_mels, prompt_mels_lens = self.cache[prompt_wav] + return generated_mels, generated_mels_lens - generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda') - generated_speech_tokens_lens = torch.tensor([generated_speech_tokens.shape[1]], dtype=torch.int32, device='cuda') - - with torch.amp.autocast("cuda", dtype=torch.float16 if self.float16 else torch.float32): - mel = self.flow.inference(generated_speech_tokens, generated_speech_tokens_lens, - prompt_speech_tokens, prompt_speech_tokens_lens, - prompt_mels, prompt_mels_lens, spk_emb, 10) - - wav, _ = self.hift(speech_feat=mel) - output = io.BytesIO() - torchaudio.save(output, wav.cpu(), sample_rate=24000, format='wav') - - return output.getvalue() - - def set_stream_cache(self, prompt_wav): - if prompt_wav not in self.cache: - self.cache[prompt_wav] = self._prepare_prompt(prompt_wav) - prompt_speech_tokens, prompt_speech_tokens_lens, spk_emb, prompt_mels, prompt_mels_lens = self.cache[prompt_wav] - self.stream_cache = self.flow.setup_cache( - torch.cat([prompt_speech_tokens, prompt_speech_tokens[:, :3]], dim=1), - prompt_mels, spk_emb, n_timesteps=10) - - # hift cache - self.hift_cache_dict = dict( - mel = torch.zeros(1, prompt_mels.shape[2], 0, device='cuda'), - source = torch.zeros(1, 1, 0, device='cuda'), - speech = torch.zeros(1, 0, device='cuda'), + def forward_hift(self, generated_mels: torch.Tensor, generated_mels_lens: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor): + batch_size = generated_mels.shape[0] + generated_wavs = [] + for i in range(batch_size): + mel = generated_mels[i, :, prompt_mels_lens_for_flow[i].item():generated_mels_lens[i].item()].unsqueeze(0) + wav, _ = self.hift(speech_feat=mel) + generated_wavs.append(wav) + return generated_wavs + + @torch.inference_mode() + def forward( + self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int] + ): + assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate) + + prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio(prompt_audios_list, prompt_audios_sample_rate) + + generated_mels, generated_mels_lens = self.forward_flow( + prompt_speech_tokens_list, generated_speech_tokens_list, + prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow ) + generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow) + return generated_wavs - def stream(self, generated_speech_tokens, prompt_wav, last_chunk=False): - if prompt_wav not in self.cache: - self.cache[prompt_wav] = self._prepare_prompt(prompt_wav) - prompt_speech_tokens, prompt_speech_tokens_lens, spk_emb, prompt_mels, prompt_mels_lens = self.cache[prompt_wav] + def prepare_prompt_audio( + self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int] + ): + assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate) - generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda') - generated_speech_tokens_lens = torch.tensor([generated_speech_tokens.shape[1]], dtype=torch.int32, device='cuda') - - if self.stream_cache is None: - raise ValueError("stream_cache is not set") - - with torch.amp.autocast("cuda", dtype=torch.float16 if self.float16 else torch.float32): 
- chunk_mel, self.stream_cache = self.flow.inference_chunk( - token=generated_speech_tokens, - spk=spk_emb, - cache=self.stream_cache, - last_chunk=last_chunk, - n_timesteps=10, + prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list) + + prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(prompt_audios_list, prompt_audios_sample_rate) + + spk_emb_for_flow = self.get_spk_emb(prompt_audios_list) + return prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow + + def get_prompt_audio_cache_for_streaming_tts( + self, prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow + ): + assert len(prompt_speech_tokens_list) == 1, "only support batch size 1 for streaming tts" + for i, prompt_speech_tokens in enumerate(prompt_speech_tokens_list): + prompt_speech_tokens_list[i] = torch.tensor(prompt_speech_tokens + prompt_speech_tokens_list[i][:3]) + prompt_speech_tokens_tensor = torch.nn.utils.rnn.pad_sequence(prompt_speech_tokens_list, batch_first=True, padding_value=0) + + cache = self.flow.setup_cache( + prompt_speech_tokens_tensor.to(self.device), + prompt_mels_for_flow.to(self.device), + spk_emb_for_flow.to(self.device), + n_timesteps=10 + ) + new_cache = {k: v.clone() for k, v in cache.items()} + # Hack: this is a hack to avoid in-place changes to the cache['estimator_att_cache'] and cache['estimator_cnn_cache'] + return new_cache + + @torch.inference_mode() + def forward_streaming( + self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000 + ): + if speaker_id not in self.speaker_cache: + assert prompt_audio is not None, "prompt_audio is required for new speaker" + assert prompt_audio_sample_rate == 16000 + + prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio([prompt_audio], [prompt_audio_sample_rate]) + + token_len = min(int(prompt_mels_for_flow.shape[1] / 2), len(prompt_speech_tokens_list[0])) + prompt_mels_for_flow = prompt_mels_for_flow[:, :2 * token_len].contiguous() + prompt_speech_tokens_list[0] = prompt_speech_tokens_list[0][:token_len] + + prompt_audio_dict = {'spk_emb_for_flow': spk_emb_for_flow, 'prompt_mels_for_flow': prompt_mels_for_flow} + + cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow) + self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict} + + if request_id not in self.streaming_flow_cache: + self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()} + self.hift_cache_dict[request_id] = dict( + mel=torch.zeros(1, 80, 0, device='cuda'), + source=torch.zeros(1, 1, 0, device='cuda'), + speech=torch.zeros(1, 0, device='cuda'), ) - if self.stream_cache['estimator_att_cache'].shape[4] > (prompt_mels.shape[1] + 100): - self.stream_cache['estimator_att_cache'] = torch.cat([ - self.stream_cache['estimator_att_cache'][:, :, :, :, :prompt_mels.shape[1]], - self.stream_cache['estimator_att_cache'][:, :, :, :, -100:], + + current_request_cache = self.streaming_flow_cache[request_id] + + current_prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict'] + generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda') + + chunk_mel, 
new_streaming_flow_cache = self.flow.inference_chunk( + token=generated_speech_tokens, + spk=current_prompt_audio_dict['spk_emb_for_flow'].to(self.device), + cache=current_request_cache, + last_chunk=last_chunk, + n_timesteps=10, + ) + + self.streaming_flow_cache[request_id] = new_streaming_flow_cache + + if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (current_prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100): + self.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.cat([ + self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :current_prompt_audio_dict['prompt_mels_for_flow'].shape[1]], + self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, -100:], ], dim=4) # vocoder cache @@ -146,6 +450,11 @@ def stream(self, generated_speech_tokens, prompt_wav, last_chunk=False): hift_cache_speech = self.hift_cache_dict['speech'] mel = torch.concat([hift_cache_mel, chunk_mel], dim=2) + hift_cache_mel = self.hift_cache_dict[request_id]['mel'].clone() + hift_cache_source = self.hift_cache_dict[request_id]['source'].clone() + hift_cache_speech = self.hift_cache_dict[request_id]['speech'].clone() + mel = torch.concat([hift_cache_mel, chunk_mel], dim=2).clone() + speech, source = self.hift(mel, hift_cache_source) # overlap speech smooth @@ -153,25 +462,65 @@ def stream(self, generated_speech_tokens, prompt_wav, last_chunk=False): speech = fade_in_out(speech, hift_cache_speech, self.speech_window) # update vocoder cache - self.hift_cache_dict = dict( - mel = mel[..., -self.mel_cache_len:].clone().detach(), - source = source[:, :, -self.source_cache_len:].clone().detach(), - speech = speech[:, -self.source_cache_len:].clone().detach(), + self.hift_cache_dict[request_id] = dict( + mel=mel[..., -self.mel_cache_len:].clone().detach(), + source=source[:, :, -self.source_cache_len:].clone().detach(), + speech=speech[:, -self.source_cache_len:].clone().detach(), ) if not last_chunk: speech = speech[:, :-self.source_cache_len] - wav_np = speech.cpu().numpy() - # Clip to [-1, 1] to avoid overflow, then scale to int16 - wav_np = np.clip(wav_np, -1.0, 1.0) - wav_int16 = (wav_np * 32767.0).astype(' 24kHz wave - self.speech_window = torch.from_numpy(np.hamming(2 * self.source_cache_len)).cuda() - - # hifigan cache for streaming tts - self.hift_cache_dict = {} - - def forward_spk_embedding(self, spk_feat): - if isinstance(self.spk_model, onnxruntime.InferenceSession): - return self.spk_model.run( - None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()} - )[0].flatten().tolist() - else: - [spk_model, stream], trt_engine = self.spk_model.acquire_estimator() - # NOTE need to synchronize when switching stream - with torch.cuda.device(self.device_id): - torch.cuda.current_stream().synchronize() - spk_feat = spk_feat.unsqueeze(dim=0).to(self.device) - batch_size = spk_feat.size(0) - - with stream: - spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80)) - output_tensor = torch.empty((batch_size, 192), device=spk_feat.device) - - data_ptrs = [spk_feat.contiguous().data_ptr(), - output_tensor.contiguous().data_ptr()] - for i, j in enumerate(data_ptrs): - - spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j) - # run trt engine - assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True - torch.cuda.current_stream().synchronize() - self.spk_model.release_estimator(spk_model, stream) - - return output_tensor.cpu().numpy().flatten().tolist() - - def 
load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True): - if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0: - trt_kwargs = self.get_spk_trt_kwargs() - convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, torch.float32) - import tensorrt as trt - with open(spk_model, 'rb') as f: - spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) - assert spk_engine is not None, 'failed to load trt {}'.format(spk_model) - self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device) - - def get_spk_trt_kwargs(self): - min_shape = [(1, 4, 80)] - opt_shape = [(1, 500, 80)] - max_shape = [(1, 3000, 80)] - input_names = ["input"] - return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} - - def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent=1, dtype=torch.float16, streaming=False): - assert torch.cuda.is_available(), 'tensorrt only supports gpu!' - if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0: - opt_batch_size = 2 - max_batch_size = 16 - if streaming: - opt_batch_size, max_batch_size = 1, 1 # only support batch size 1 for streaming tts - trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_batch_size=opt_batch_size, max_batch_size=max_batch_size, streaming=streaming) - convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, dtype) - del self.flow.decoder.estimator - import tensorrt as trt - with open(flow_decoder_estimator_model, 'rb') as f: - estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) - assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model) - self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device) - - def get_trt_kwargs_dynamic_batch(self, opt_batch_size=2, max_batch_size=64, streaming=False): - if streaming: - min_shape = [(2, 80, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80), (16, 2, 1024, 2), (16, 2, 8, 0, 128)] - opt_shape = [ - (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500), - (opt_batch_size * 2,), (opt_batch_size * 2, 80), (16, opt_batch_size * 2, 1024, 2), - (16, opt_batch_size * 2, 8, 100, 128) - ] - max_shape = [ - (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), - (max_batch_size * 2,), (max_batch_size * 2, 80), (16, max_batch_size * 2, 1024, 2), - (16, max_batch_size * 2, 8, 1000, 128) - ] - input_names = ["x", "mu", "cond", "t", "spks", "cnn_cache", "att_cache"] - else: - min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)] - opt_shape = [ - (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 1, 500), (opt_batch_size * 2, 80, 500), - (opt_batch_size * 2, 80, 500), (opt_batch_size * 2,), (opt_batch_size * 2, 80) - ] - max_shape = [ - (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 1, 3000), (max_batch_size * 2, 80, 3000), - (max_batch_size * 2, 80, 3000), (max_batch_size * 2,), (max_batch_size * 2, 80) - ] - input_names = ["x", "mask", "mu", "cond", "t", "spks"] - return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} - - def prompt_audio_tokenization(self, prompt_audios_list: list[torch.Tensor]) -> list[list[int]]: - prompt_speech_tokens_list, prompt_speech_mels_list = [], [] - for audio 
in prompt_audios_list: - assert len(audio.shape) == 1 - log_mel = s3tokenizer.log_mel_spectrogram(audio) # [num_mels, T] - prompt_speech_mels_list.append(log_mel) - prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(prompt_speech_mels_list) - prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize( - prompt_mels_for_llm.to(self.device), prompt_mels_lens_for_llm.to(self.device) - ) - for i in range(len(prompt_speech_tokens)): - speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist() - prompt_speech_tokens_list.append(speech_tokens_i) - return prompt_speech_tokens_list - - def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor: - spk_emb_for_flow = [] - for audio in prompt_audios_list: - assert len(audio.shape) == 1 - spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000) - spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True) - spk_emb = self.forward_spk_embedding(spk_feat) - - spk_emb_for_flow.append(spk_emb) - spk_emb_for_flow = torch.tensor(spk_emb_for_flow) - if self.dtype != torch.float32: - spk_emb_for_flow = spk_emb_for_flow.to(self.dtype) - return spk_emb_for_flow - - def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]): - prompt_mels_for_flow = [] - prompt_mels_lens_for_flow = [] - for audio, sample_rate in zip(prompt_audios_list, prompt_audios_sample_rate): - assert len(audio.shape) == 1 - audio = audio.unsqueeze(0) - if sample_rate != 24000: - audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio) - mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0) # [T, num_mels] - mel_len = mel.shape[0] - prompt_mels_for_flow.append(mel) - prompt_mels_lens_for_flow.append(mel_len) - prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence( - prompt_mels_for_flow, batch_first=True, padding_value=0 - ) # [B, T', num_mels=80] - prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow) - return prompt_mels_for_flow, prompt_mels_lens_for_flow - - def forward_flow(self, prompt_speech_tokens_list: list[list[int]], - generated_speech_tokens_list: list[list[int]], - prompt_mels_for_flow: torch.Tensor, - prompt_mels_lens_for_flow: torch.Tensor, - spk_emb_for_flow: torch.Tensor): - batch_size = prompt_mels_for_flow.shape[0] - flow_inputs = [] - flow_inputs_lens = [] - for prompt_speech_tokens, generated_speech_tokens in zip(prompt_speech_tokens_list, generated_speech_tokens_list): - flow_inputs.append(torch.tensor(prompt_speech_tokens + generated_speech_tokens)) - flow_inputs_lens.append(len(prompt_speech_tokens) + len(generated_speech_tokens)) - - flow_inputs = torch.nn.utils.rnn.pad_sequence(flow_inputs, batch_first=True, padding_value=0) - flow_inputs_lens = torch.tensor(flow_inputs_lens) - - with torch.amp.autocast(self.device, dtype=torch.float16): - generated_mels, generated_mels_lens = self.flow.inference( - flow_inputs.to(self.device), flow_inputs_lens.to(self.device), - prompt_mels_for_flow.to(self.device), prompt_mels_lens_for_flow.to(self.device), spk_emb_for_flow.to(self.device), 10 - ) - - return generated_mels, generated_mels_lens - - def forward_hift(self, generated_mels: torch.Tensor, generated_mels_lens: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor): - batch_size = generated_mels.shape[0] - generated_wavs = [] - for i in range(batch_size): - mel = generated_mels[i, :, prompt_mels_lens_for_flow[i].item():generated_mels_lens[i].item()].unsqueeze(0) - 
wav, _ = self.hift(speech_feat=mel) - generated_wavs.append(wav) - return generated_wavs - - @torch.inference_mode() - def forward( - self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int] - ): - assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate) - - prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio(prompt_audios_list, prompt_audios_sample_rate) - - generated_mels, generated_mels_lens = self.forward_flow( - prompt_speech_tokens_list, generated_speech_tokens_list, - prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow - ) - - generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow) - return generated_wavs - - def prepare_prompt_audio( - self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int] - ): - assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate) - - prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list) - - prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(prompt_audios_list, prompt_audios_sample_rate) - - spk_emb_for_flow = self.get_spk_emb(prompt_audios_list) - return prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow - - def get_prompt_audio_cache_for_streaming_tts( - self, prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow - ): - assert len(prompt_speech_tokens_list) == 1, "only support batch size 1 for streaming tts" - for i, prompt_speech_tokens in enumerate(prompt_speech_tokens_list): - prompt_speech_tokens_list[i] = torch.tensor(prompt_speech_tokens + prompt_speech_tokens_list[i][:3]) - prompt_speech_tokens_tensor = torch.nn.utils.rnn.pad_sequence(prompt_speech_tokens_list, batch_first=True, padding_value=0) - - cache = self.flow.setup_cache( - prompt_speech_tokens_tensor.to(self.device), - prompt_mels_for_flow.to(self.device), - spk_emb_for_flow.to(self.device), - n_timesteps=10 - ) - new_cache = {k: v.clone() for k, v in cache.items()} - # Hack: this is a hack to avoid in-place changes to the cache['estimator_att_cache'] and cache['estimator_cnn_cache'] - return new_cache - - @torch.inference_mode() - def forward_streaming( - self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000 - ): - if speaker_id not in self.speaker_cache: - assert prompt_audio is not None, "prompt_audio is required for new speaker" - assert prompt_audio_sample_rate == 16000 - - prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio([prompt_audio], [prompt_audio_sample_rate]) - - token_len = min(int(prompt_mels_for_flow.shape[1] / 2), len(prompt_speech_tokens_list[0])) - prompt_mels_for_flow = prompt_mels_for_flow[:, :2 * token_len].contiguous() - prompt_speech_tokens_list[0] = prompt_speech_tokens_list[0][:token_len] - - prompt_audio_dict = {'spk_emb_for_flow': spk_emb_for_flow, 'prompt_mels_for_flow': prompt_mels_for_flow} - - cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow) - self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict} - - if request_id not in self.streaming_flow_cache: 
- self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()} - self.hift_cache_dict[request_id] = dict( - mel=torch.zeros(1, 80, 0, device='cuda'), - source=torch.zeros(1, 1, 0, device='cuda'), - speech=torch.zeros(1, 0, device='cuda'), - ) - - current_request_cache = self.streaming_flow_cache[request_id] - - current_prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict'] - generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda') - - chunk_mel, new_streaming_flow_cache = self.flow.inference_chunk( - token=generated_speech_tokens, - spk=current_prompt_audio_dict['spk_emb_for_flow'].to(self.device), - cache=current_request_cache, - last_chunk=last_chunk, - n_timesteps=10, - ) - - self.streaming_flow_cache[request_id] = new_streaming_flow_cache - - if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (current_prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100): - self.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.cat([ - self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :current_prompt_audio_dict['prompt_mels_for_flow'].shape[1]], - self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, -100:], - ], dim=4) - - hift_cache_mel = self.hift_cache_dict[request_id]['mel'].clone() - hift_cache_source = self.hift_cache_dict[request_id]['source'].clone() - hift_cache_speech = self.hift_cache_dict[request_id]['speech'].clone() - mel = torch.concat([hift_cache_mel, chunk_mel], dim=2).clone() - - speech, source = self.hift(mel, hift_cache_source) - - # overlap speech smooth - if hift_cache_speech.shape[-1] > 0: - speech = fade_in_out(speech, hift_cache_speech, self.speech_window) - - # update vocoder cache - self.hift_cache_dict[request_id] = dict( - mel=mel[..., -self.mel_cache_len:].clone().detach(), - source=source[:, :, -self.source_cache_len:].clone().detach(), - speech=speech[:, -self.source_cache_len:].clone().detach(), - ) - if not last_chunk: - speech = speech[:, :-self.source_cache_len] - - if last_chunk: - assert request_id in self.streaming_flow_cache - self.streaming_flow_cache.pop(request_id) - self.hift_cache_dict.pop(request_id) - - return speech - - -def collate_fn(batch): - ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], [] - for item in batch: - generated_speech_tokens_list.append(item['target_audio_cosy2_tokens']) - audio = torch.from_numpy(item['prompt_audio']['array']).float() - prompt_audios_list.append(audio) - prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate']) - ids.append(item['id']) - - return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--enable-trt", action="store_true") - parser.add_argument("--model-dir", type=str, default="./Step-Audio-2-mini/token2wav") - parser.add_argument("--batch-size", type=int, default=1) - parser.add_argument("--output-dir", type=str, default="generated_wavs") - parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts") - parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs, performance statistics will only be collected from the last epoch") - return parser.parse_args() - - -if __name__ == "__main__": - args = get_args() - model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt) - 
if not os.path.exists(args.output_dir): - os.makedirs(args.output_dir) - dataset_name = "yuekai/seed_tts_cosy2" - - dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True) - - data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0) - - for _ in range(args.warmup): - start_time = time.time() - for batch in data_loader: - ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch - - generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate) - - for id, wav in zip(ids, generated_wavs): - torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000) - end_time = time.time() - epoch_time = end_time - start_time - print(f"Measurement epoch time taken: {epoch_time:.4f} seconds") diff --git a/export_onnx.py b/tools/export_onnx_offline_token2wav.py similarity index 100% rename from export_onnx.py rename to tools/export_onnx_offline_token2wav.py diff --git a/export_onnx_chunk.py b/tools/export_onnx_streaming_token2wav.py similarity index 100% rename from export_onnx_chunk.py rename to tools/export_onnx_streaming_token2wav.py From 459f67ee6a1894b7ffc34042aabb1f5ed6ec1a29 Mon Sep 17 00:00:00 2001 From: yuekaiz Date: Tue, 21 Oct 2025 21:55:06 +0800 Subject: [PATCH 09/12] align with the original token2wav inferface --- cosyvoice2/flow/flow_matching.py | 1 - examples-vllm-stream.py | 1 - token2wav.py | 93 +++++++++++--------------------- 3 files changed, 32 insertions(+), 63 deletions(-) diff --git a/cosyvoice2/flow/flow_matching.py b/cosyvoice2/flow/flow_matching.py index f136b87..79d10b4 100644 --- a/cosyvoice2/flow/flow_matching.py +++ b/cosyvoice2/flow/flow_matching.py @@ -149,7 +149,6 @@ def forward_estimator_chunk(self, x, mu, t, spks, cond, cnn_cache, att_cache): batch_size = x.size(0) with stream: estimator.set_input_shape('x', (batch_size, 80, x.size(2))) - # estimator.set_input_shape('mask', (batch_size, 1, x.size(2))) estimator.set_input_shape('mu', (batch_size, 80, x.size(2))) estimator.set_input_shape('t', (batch_size,)) estimator.set_input_shape('spks', (batch_size, 80)) diff --git a/examples-vllm-stream.py b/examples-vllm-stream.py index 866fb45..a681ae6 100644 --- a/examples-vllm-stream.py +++ b/examples-vllm-stream.py @@ -43,7 +43,6 @@ def stream_client(model, history, tools, token2wav=None, output_stream=None, pro model = StepAudio2(api_url, model_name) token2wav = Token2wav('Step-Audio-2-mini/token2wav') tokens = [1493, 4299, 4218, 2049, 528, 2752, 4850, 4569, 4575, 6372, 2127, 4068, 2312, 4993, 4769, 2300, 226, 2175, 2160, 2152, 6311, 6065, 4859, 5102, 4615, 6534, 6426, 1763, 2249, 2209, 5938, 1725, 6048, 3816, 6058, 958, 63, 4460, 5914, 2379, 735, 5319, 4593, 2328, 890, 35, 751, 1483, 1484, 1483, 2112, 303, 4753, 2301, 5507, 5588, 5261, 5744, 5501, 2341, 2001, 2252, 2344, 1860, 2031, 414, 4366, 4366, 6059, 5300, 4814, 5092, 5100, 1923, 3054, 4320, 4296, 2148, 4371, 5831, 5084, 5027, 4946, 4946, 2678, 575, 575, 521, 518, 638, 1367, 2804, 3402, 4299] - token2wav.set_stream_cache(prompt_wav) token2wav.stream(tokens[:CHUNK_SIZE + token2wav.flow.pre_lookahead_len], prompt_wav=prompt_wav) # Warm up output_stream = Path('output-stream.pcm') diff --git a/token2wav.py b/token2wav.py index 8bdd1d6..4a7444e 100644 --- a/token2wav.py +++ b/token2wav.py @@ -121,7 +121,7 @@ def release_estimator(self, context, stream): self.trt_context_pool.put([context, stream]) -class CosyVoice2_Token2Wav(torch.nn.Module): +class 
Token2Wav(torch.nn.Module): def __init__(self, model_dir: str, enable_trt: bool = False, device_id: int = 0, streaming: bool = False, dtype: torch.dtype = torch.float16): super().__init__() self.device_id = device_id @@ -149,24 +149,23 @@ def __init__(self, model_dir: str, enable_trt: bool = False, device_id: int = 0, providers=["CPUExecutionProvider"]) self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2_25hz.onnx").to(self.device).eval() - gpu = "l20" if enable_trt: if streaming: self.load_trt( - f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.chunk.{gpu}.plan', + f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.chunk.plan', f'{model_dir}/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx', 1, self.dtype, streaming ) else: self.load_trt( - f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.{gpu}.plan', + f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.plan', f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx', 1, self.dtype ) self.load_spk_trt( - f'{model_dir}/campplus.{gpu}.fp32.trt', + f'{model_dir}/campplus.fp32.trt', f'{model_dir}/campplus.onnx', 1, False @@ -230,6 +229,7 @@ def get_spk_trt_kwargs(self): def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent=1, dtype=torch.float16, streaming=False): assert torch.cuda.is_available(), 'tensorrt only supports gpu!' + assert os.path.exists(flow_decoder_onnx_model), f'Please use tools/export_onnx_offline_token2wav.py or tools/export_onnx_streaming_token2wav.py to export the onnx model for token2wav first.' if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0: opt_batch_size = 2 max_batch_size = 16 @@ -352,19 +352,22 @@ def forward_hift(self, generated_mels: torch.Tensor, generated_mels_lens: torch.
@torch.inference_mode() def forward( - self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int] - ): - assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate) - - prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio(prompt_audios_list, prompt_audios_sample_rate) - + self, generated_speech_tokens: list[int], prompt_wav: str): + generated_speech_tokens_list = [generated_speech_tokens] + audio = s3tokenizer.load_audio(prompt_wav, sr=16000) + prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio([audio], [16000]) generated_mels, generated_mels_lens = self.forward_flow( prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow ) generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow) - return generated_wavs + + wav = generated_wavs[0] + output = io.BytesIO() + torchaudio.save(output, wav.cpu(), sample_rate=24000, format='wav') + + return output.getvalue() def prepare_prompt_audio( self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int] @@ -397,9 +400,13 @@ def get_prompt_audio_cache_for_streaming_tts( return new_cache @torch.inference_mode() - def forward_streaming( - self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000 + def stream( + self, generated_speech_tokens: list[int], prompt_wav: str, last_chunk: bool = False, ): + speaker_id = prompt_wav + request_id = prompt_wav + prompt_audio_sample_rate = 16000 + prompt_audio = s3tokenizer.load_audio(prompt_wav, sr=prompt_audio_sample_rate) if speaker_id not in self.speaker_cache: assert prompt_audio is not None, "prompt_audio is required for new speaker" assert prompt_audio_sample_rate == 16000 @@ -474,53 +481,17 @@ def forward_streaming( assert request_id in self.streaming_flow_cache self.streaming_flow_cache.pop(request_id) self.hift_cache_dict.pop(request_id) - - return speech - - -def collate_fn(batch): - ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], [] - for item in batch: - generated_speech_tokens_list.append(item['target_audio_cosy2_tokens']) - audio = torch.from_numpy(item['prompt_audio']['array']).float() - prompt_audios_list.append(audio) - prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate']) - ids.append(item['id']) - - return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--enable-trt", action="store_true") - parser.add_argument("--model-dir", type=str, default="./Step-Audio-2-mini/token2wav") - parser.add_argument("--batch-size", type=int, default=1) - parser.add_argument("--output-dir", type=str, default="generated_wavs") - parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts") - parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs, performance statistics will only be collected from the last epoch") - return parser.parse_args() - + wav_np = speech.cpu().numpy() + # Clip to [-1, 1] to avoid overflow, then scale to int16 + wav_np = np.clip(wav_np, -1.0, 1.0) + wav_int16 = (wav_np * 32767.0).astype(' Date: Wed, 22 Oct 2025 
10:49:13 +0800 Subject: [PATCH 10/12] add benchmark results --- token2wav.py | 55 +++++++----- tools/export_onnx_offline_token2wav.py | 27 +++--- tools/export_onnx_streaming_token2wav.py | 38 +++++---- tools/tensorrt_token2wav.md | 103 +++++++++++++++++++++++ 4 files changed, 176 insertions(+), 47 deletions(-) create mode 100644 tools/tensorrt_token2wav.md diff --git a/token2wav.py b/token2wav.py index 4a7444e..c0448dc 100644 --- a/token2wav.py +++ b/token2wav.py @@ -17,7 +17,6 @@ python3 token2wav.py --enable-trt || exit 1 """ import torch -# from flashcosyvoice.modules.flow import CausalMaskedDiffWithXvec from flashcosyvoice.modules.hifigan import HiFTGenerator from flashcosyvoice.utils.audio import mel_spectrogram import torchaudio.compliance.kaldi as kaldi @@ -33,17 +32,9 @@ import time import numpy as np from hyperpyyaml import load_hyperpyyaml - -def fade_in_out(fade_in_mel:torch.Tensor, fade_out_mel:torch.Tensor, window:torch.Tensor): - """perform fade_in_out in tensor style - """ - mel_overlap_len = int(window.shape[0] / 2) - fade_in_mel = fade_in_mel.clone() - fade_in_mel[..., :mel_overlap_len] = \ - fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \ - fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:] - return fade_in_mel - +import io +from pathlib import Path +import wave def fade_in_out(fade_in_mel: torch.Tensor, fade_out_mel: torch.Tensor, window: torch.Tensor): """perform fade_in_out in tensor style @@ -121,7 +112,7 @@ def release_estimator(self, context, stream): self.trt_context_pool.put([context, stream]) -class Token2Wav(torch.nn.Module): +class Token2wav(torch.nn.Module): def __init__(self, model_dir: str, enable_trt: bool = False, device_id: int = 0, streaming: bool = False, dtype: torch.dtype = torch.float16): super().__init__() self.device_id = device_id @@ -152,8 +143,8 @@ def __init__(self, model_dir: str, enable_trt: bool = False, device_id: int = 0, if enable_trt: if streaming: self.load_trt( - f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.chunk.plan', - f'{model_dir}/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx', + f'{model_dir}/flow.decoder.estimator.{self.dtype}.static_batch.chunk.plan', + f'{model_dir}/flow.decoder.estimator.chunk.fp32.static_batch.onnx', 1, self.dtype, streaming ) @@ -489,9 +480,33 @@ def stream( return pcm_bytes if __name__ == "__main__": - token2wav = Token2wav('Step-Audio-2-mini/token2wav') - tokens = [1493, 4299, 4218, 2049, 528, 2752, 4850, 4569, 4575, 6372, 2127, 4068, 2312, 4993, 4769, 2300, 226, 2175, 2160, 2152, 6311, 6065, 4859, 5102, 4615, 6534, 6426, 1763, 2249, 2209, 5938, 1725, 6048, 3816, 6058, 958, 63, 4460, 5914, 2379, 735, 5319, 4593, 2328, 890, 35, 751, 1483, 1484, 1483, 2112, 303, 4753, 2301, 5507, 5588, 5261, 5744, 5501, 2341, 2001, 2252, 2344, 1860, 2031, 414, 4366, 4366, 6059, 5300, 4814, 5092, 5100, 1923, 3054, 4320, 4296, 2148, 4371, 5831, 5084, 5027, 4946, 4946, 2678, 575, 575, 521, 518, 638, 1367, 2804, 3402, 4299] - audio = token2wav(tokens, 'assets/default_male.wav') - with open('assets/give_me_a_brief_introduction_to_the_great_wall.wav', 'wb') as f: - f.write(audio) \ No newline at end of file + # offline token2wav + # token2wav = Token2wav('Step-Audio-2-mini/token2wav', enable_trt=True) + # audio = token2wav(tokens, 'assets/default_male.wav') + # with open('assets/give_me_a_brief_introduction_to_the_great_wall.wav', 'wb') as f: + # f.write(audio) + + # streaming token2wav using pytorch + # token2wav = Token2wav('Step-Audio-2-mini/token2wav') + + # 
streaming token2wav using tensorrt + token2wav = Token2wav('Step-Audio-2-mini/token2wav', enable_trt=True, streaming=True) + audio_first_chunk = token2wav.stream(tokens[:25 + token2wav.flow.pre_lookahead_len], prompt_wav='assets/default_male.wav') + audio_last_chunk = token2wav.stream(tokens[25 + token2wav.flow.pre_lookahead_len:], prompt_wav='assets/default_male.wav', last_chunk=True) + + + output_stream = Path('output-stream.pcm') + output_stream.unlink(missing_ok=True) + with open(output_stream, 'wb') as f: + f.write(audio_first_chunk) + f.write(audio_last_chunk) + + with open(output_stream, 'rb') as f: + pcm = f.read() + wav_path = output_stream.with_suffix('.wav') + with wave.open(str(wav_path), 'wb') as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(24000) + wf.writeframes(pcm) \ No newline at end of file diff --git a/tools/export_onnx_offline_token2wav.py b/tools/export_onnx_offline_token2wav.py index 9d21444..1f64e36 100644 --- a/tools/export_onnx_offline_token2wav.py +++ b/tools/export_onnx_offline_token2wav.py @@ -12,6 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +This script is used to export the offline token2wav model to onnx. +python3 tools/export_onnx_offline_token2wav.py +""" from __future__ import print_function @@ -25,17 +29,10 @@ import torch from tqdm import tqdm from hyperpyyaml import load_hyperpyyaml - - -def get_dummy_input(batch_size, seq_len, out_channels, device): - x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) - mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device) - mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) - t = torch.rand((batch_size), dtype=torch.float32, device=device) - spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device) - cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) - return x, mask, mu, t, spks, cond - +import sys +import os +# add ../ to python path +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) def get_args(): parser = argparse.ArgumentParser(description='export your model for deployment') @@ -51,6 +48,14 @@ def get_args(): print(args) return args +def get_dummy_input(batch_size, seq_len, out_channels, device): + x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) + mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device) + mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) + t = torch.rand((batch_size), dtype=torch.float32, device=device) + spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device) + cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) + return x, mask, mu, t, spks, cond @torch.no_grad() def main(): diff --git a/tools/export_onnx_streaming_token2wav.py b/tools/export_onnx_streaming_token2wav.py index e6a99e3..5f4362d 100644 --- a/tools/export_onnx_streaming_token2wav.py +++ b/tools/export_onnx_streaming_token2wav.py @@ -12,6 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +This script is used to export the streaming token2wav model to onnx. 
+python3 tools/export_onnx_streaming_token2wav.py +""" from __future__ import print_function @@ -26,6 +30,24 @@ from tqdm import tqdm from hyperpyyaml import load_hyperpyyaml +import sys +import os +# add ../ to python path +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +def get_args(): + parser = argparse.ArgumentParser(description='export your model for deployment') + parser.add_argument('--model_dir', + type=str, + default='Step-Audio-2-mini/token2wav', + help='local path') + parser.add_argument('--onnx_model', + type=str, + default='flow.decoder.estimator.chunk.fp32.static_batch.onnx', + help='onnx model name') + args = parser.parse_args() + print(args) + return args def get_dummy_input_chunk(batch_size, seq_len, prev_seq_len, out_channels, estimator, device): x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) @@ -43,22 +65,6 @@ def get_dummy_input_chunk(batch_size, seq_len, prev_seq_len, out_channels, estim att_cache = torch.rand((depth, batch_size, num_heads, prev_seq_len, head_dim * 2), dtype=torch.float32, device=device) return x, mu, t, spks, cond, cnn_cache, att_cache - -def get_args(): - parser = argparse.ArgumentParser(description='export your model for deployment') - parser.add_argument('--model_dir', - type=str, - default='Step-Audio-2-mini/token2wav', - help='local path') - parser.add_argument('--onnx_model', - type=str, - default='flow.decoder.estimator.chunk.fp32.dynamic_batch.onnx', - help='onnx model name') - args = parser.parse_args() - print(args) - return args - - class DiTChunkWrapper(torch.nn.Module): def __init__(self, dit_model): super().__init__() diff --git a/tools/tensorrt_token2wav.md b/tools/tensorrt_token2wav.md new file mode 100644 index 0000000..8888d93 --- /dev/null +++ b/tools/tensorrt_token2wav.md @@ -0,0 +1,103 @@ +# Accelerating StepAudio2 Token2wav with NVIDIA TensorRT + +This document provides instructions on how to use NVIDIA TensorRT to accelerate the Token2wav module in StepAudio2 for both offline and streaming inference. + +## Preparation + +### 1. Install Dependencies + +Install the necessary packages using pip. For GPU acceleration with TensorRT, use `onnxruntime-gpu`. + +```bash +pip install tensorrt onnxruntime-gpu +``` + +### 2. Export ONNX Models + +You need to export the PyTorch models to ONNX format. There are separate scripts for offline (dynamic batch) and streaming (static batch) modes. + +**For Offline Inference:** +```bash +python3 tools/export_onnx_offline_token2wav.py +``` + +**For Streaming Inference:** +```bash +python3 tools/export_onnx_streaming_token2wav.py +``` + +## Usage + +### Offline Inference + +Here is an example of how to use the TensorRT-accelerated Token2wav model for offline inference. 
+ +```python +from token2wav import Token2wav +import wave + +# The tokens to be converted to speech +tokens = [1493, 4299, 4218, 2049, 528, 2752, 4850, 4569, 4575, 6372, 2127, 4068, 2312, 4993, 4769, 2300, 226, 2175, 2160, 2152, 6311, 6065, 4859, 5102, 4615, 6534, 6426, 1763, 2249, 2209, 5938, 1725, 6048, 3816, 6058, 958, 63, 4460, 5914, 2379, 735, 5319, 4593, 2328, 890, 35, 751, 1483, 1484, 1483, 2112, 303, 4753, 2301, 5507, 5588, 5261, 5744, 5501, 2341, 2001, 2252, 2344, 1860, 2031, 414, 4366, 4366, 6059, 5300, 4814, 5092, 5100, 1923, 3054, 4320, 4296, 2148, 4371, 5831, 5084, 5027, 4946, 4946, 2678, 575, 575, 521, 518, 638, 1367, 2804, 3402, 4299] + +# Initialize Token2wav with TensorRT enabled +token2wav = Token2wav('Step-Audio-2-mini/token2wav', enable_trt=True) + +# Generate audio +audio_bytes = token2wav(tokens, 'assets/default_male.wav') + +# Save the generated audio to a file +with open('output_offline.wav', 'wb') as f: + f.write(audio_bytes) +``` + +### Streaming Inference + +For streaming inference, you can process tokens in chunks. + +```python +from token2wav import Token2wav +from pathlib import Path +import wave + +tokens = [1493, 4299, 4218, 2049, 528, 2752, 4850, 4569, 4575, 6372, 2127, 4068, 2312, 4993, 4769, 2300, 226, 2175, 2160, 2152, 6311, 6065, 4859, 5102, 4615, 6534, 6426, 1763, 2249, 2209, 5938, 1725, 6048, 3816, 6058, 958, 63, 4460, 5914, 2379, 735, 5319, 4593, 2328, 890, 35, 751, 1483, 1484, 1483, 2112, 303, 4753, 2301, 5507, 5588, 5261, 5744, 5501, 2341, 2001, 2252, 2344, 1860, 2031, 414, 4366, 4366, 6059, 5300, 4814, 5092, 5100, 1923, 3054, 4320, 4296, 2148, 4371, 5831, 5084, 5027, 4946, 4946, 2678, 575, 575, 521, 518, 638, 1367, 2804, 3402, 4299] + +# Initialize Token2wav for streaming with TensorRT +token2wav = Token2wav('Step-Audio-2-mini/token2wav', enable_trt=True, streaming=True) + +# Process the first chunk of tokens +audio_first_chunk = token2wav.stream(tokens[:25 + token2wav.flow.pre_lookahead_len], prompt_wav='assets/default_male.wav') + +# Process the remaining tokens as the last chunk +audio_last_chunk = token2wav.stream(tokens[25 + token2wav.flow.pre_lookahead_len:], prompt_wav='assets/default_male.wav', last_chunk=True) + +# Save the streaming output to a PCM file +output_stream = Path('output-stream.pcm') +output_stream.unlink(missing_ok=True) +with open(output_stream, 'wb') as f: + f.write(audio_first_chunk) + f.write(audio_last_chunk) + +# Convert PCM to WAV +with open(output_stream, 'rb') as f: + pcm = f.read() +wav_path = output_stream.with_suffix('.wav') +with wave.open(str(wav_path), 'wb') as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(24000) + wf.writeframes(pcm) + +``` + +## Benchmark + +The following benchmark was conducted on an NVIDIA L20 GPU, generating 26 audio clips with a total length of 170 seconds. RTF (Real-Time Factor) is calculated as `Cost Time / Total Audio Length`. 
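+ +For example, reading the first two rows below: the offline PyTorch run (batch=1) generates the 170 seconds of audio in 4.32 seconds, so RTF = 4.32 / 170 ≈ 0.025, while the TensorRT run at 2.09 seconds gives RTF ≈ 0.012; values well below 1.0 mean generation is much faster than real time.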
+ +| Method | Note | Cost Time | RTF | +|-----------|-------------------------------------|----------------|---------| +| Offline | batch=1, PyTorch | 4.32 seconds | 0.025 | +| Offline | batch=1, TensorRT enabled | 2.09 seconds | 0.012 | +| Offline | batch=2, PyTorch | 3.77 seconds | 0.022 | +| Offline | batch=2, TensorRT enabled | 1.97 seconds | 0.012 | +| Streaming | batch=1, chunk_size = 1 second, PyTorch | 20.3 seconds | 0.119 | +| Streaming | batch=1, chunk_size = 1 second, TensorRT | 12.96 seconds | 0.076 | From e2a16c07397855fa92da8e010dce0d0fc860bdd4 Mon Sep 17 00:00:00 2001 From: yuekaiz Date: Wed, 22 Oct 2025 10:59:03 +0800 Subject: [PATCH 11/12] clean code --- token2wav.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/token2wav.py b/token2wav.py index c0448dc..18a7f82 100644 --- a/token2wav.py +++ b/token2wav.py @@ -441,12 +441,6 @@ def stream( self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :current_prompt_audio_dict['prompt_mels_for_flow'].shape[1]], self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, -100:], ], dim=4) - - # vocoder cache - hift_cache_mel = self.hift_cache_dict['mel'] - hift_cache_source = self.hift_cache_dict['source'] - hift_cache_speech = self.hift_cache_dict['speech'] - mel = torch.concat([hift_cache_mel, chunk_mel], dim=2) hift_cache_mel = self.hift_cache_dict[request_id]['mel'].clone() hift_cache_source = self.hift_cache_dict[request_id]['source'].clone() From 204e5879bdc69f21ab187b3470602b5e79d9b4e5 Mon Sep 17 00:00:00 2001 From: yuekaiz Date: Wed, 22 Oct 2025 11:06:54 +0800 Subject: [PATCH 12/12] remove license --- token2wav.py | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/token2wav.py b/token2wav.py index 18a7f82..aa64b96 100644 --- a/token2wav.py +++ b/token2wav.py @@ -1,20 +1,6 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. """ Example Usage CUDA_VISIBLE_DEVICES=0 \ - python3 token2wav.py --enable-trt || exit 1 + python3 token2wav.py """ import torch from flashcosyvoice.modules.hifigan import HiFTGenerator