diff --git a/cosyvoice2/flow/decoder_dit.py b/cosyvoice2/flow/decoder_dit.py index cb80edc..15c8c1e 100644 --- a/cosyvoice2/flow/decoder_dit.py +++ b/cosyvoice2/flow/decoder_dit.py @@ -403,9 +403,6 @@ def __init__( self.inference_buffers_chunk = {} self.max_size_chunk = {} - self.register_buffer('att_cache_buffer', torch.zeros((16, 2, 8, 1000, 128)), persistent=False) - self.register_buffer('cnn_cache_buffer', torch.zeros((16, 2, 1024, 2)), persistent=False) - def initialize_weights(self): # Initialize transformer layers: def _basic_init(module): @@ -432,7 +429,7 @@ def _basic_init(module): def _init_cuda_graph_chunk(self): # get dtype, device from registered buffer - dtype, device = self.cnn_cache_buffer.dtype, self.cnn_cache_buffer.device + dtype, device = self.in_proj.weight.dtype, self.in_proj.weight.device # init cuda graph for streaming forward with torch.no_grad(): for chunk_size in [30, 48, 96]: @@ -491,6 +488,7 @@ def forward(self, x, mask, mu, t, spks=None, cond=None): # time t = self.t_embedder(t).unsqueeze(1) # (b, 1, c) + x = pack([x, mu], "b * t")[0] if spks is not None: spks = repeat(spks, "b c -> b c t", t=x.shape[-1]) @@ -541,16 +539,14 @@ def forward_chunk(self, # create fake cache if cnn_cache is None: - cnn_cache = [None] * len(self.blocks) + cnn_cache = torch.zeros(len(self.blocks), x.shape[0], 1024, 2, dtype=x.dtype, device=x.device) if att_cache is None: - att_cache = [None] * len(self.blocks) - if att_cache[0] is not None: - last_att_len = att_cache.shape[3] - else: - last_att_len = 0 + att_cache = torch.zeros(len(self.blocks), x.shape[0], self.blocks[0].attn.num_heads, 0, self.blocks[0].attn.head_dim * 2, dtype=x.dtype, device=x.device) + + last_att_len = att_cache.shape[3] chunk_size = x.shape[2] mask = torch.ones(x.shape[0], chunk_size, last_att_len+chunk_size, dtype=torch.bool, device=x.device) - if self.use_cuda_graph and att_cache[0] is not None and chunk_size in self.graph_chunk and last_att_len <= self.max_size_chunk[chunk_size]: + if self.use_cuda_graph and att_cache is not None and chunk_size in self.graph_chunk and last_att_len <= self.max_size_chunk[chunk_size]: padded_mask = torch.zeros((2, chunk_size, self.max_size_chunk[chunk_size]+chunk_size), dtype=mask.dtype, device=mask.device) padded_mask[:, :, :mask.shape[-1]] = mask padded_att_cache = torch.zeros((16, 2, 8, self.max_size_chunk[chunk_size], 128), dtype=att_cache.dtype, device=att_cache.device) @@ -566,20 +562,22 @@ def forward_chunk(self, new_att_cache = self.inference_buffers_chunk[chunk_size]['static_outputs'][2][:, :, :, :chunk_size+last_att_len, :] else: mask = None - x = self.blocks_forward_chunk(x, t, mask, cnn_cache, att_cache, self.cnn_cache_buffer, self.att_cache_buffer) - new_cnn_cache = self.cnn_cache_buffer - new_att_cache = self.att_cache_buffer[:, :, :, :last_att_len+chunk_size, :] + x, new_cnn_cache, new_att_cache = self.blocks_forward_chunk(x, t, mask, cnn_cache, att_cache) return x, new_cnn_cache, new_att_cache - def blocks_forward_chunk(self, x, t, mask, cnn_cache=None, att_cache=None, cnn_cache_buffer=None, att_cache_buffer=None): + def blocks_forward_chunk(self, x, t, mask, cnn_cache, att_cache): x = x.transpose(1, 2) x = self.in_proj(x) + + new_cnn_caches = [] + new_att_caches = [] + for b_idx, block in enumerate(self.blocks): - x, this_new_cnn_cache, this_new_att_cache \ - = block.forward_chunk(x, t, cnn_cache[b_idx], att_cache[b_idx], mask) - cnn_cache_buffer[b_idx] = this_new_cnn_cache - att_cache_buffer[b_idx][:, :, :this_new_att_cache.shape[2], :] = this_new_att_cache + 
x, this_new_cnn_cache, this_new_att_cache = block.forward_chunk(x, t, cnn_cache[b_idx], att_cache[b_idx], mask) + new_cnn_caches.append(this_new_cnn_cache) + new_att_caches.append(this_new_att_cache) + x = self.final_layer(x, t) x = x.transpose(1, 2) - return x + return x, torch.stack(new_cnn_caches), torch.stack(new_att_caches) diff --git a/cosyvoice2/flow/flow.py b/cosyvoice2/flow/flow.py index f252d9b..d1aa4ac 100644 --- a/cosyvoice2/flow/flow.py +++ b/cosyvoice2/flow/flow.py @@ -65,39 +65,29 @@ def scatter_cuda_graph(self, enable_cuda_graph: bool): def inference(self, token, token_len, - prompt_token, - prompt_token_len, prompt_feat, prompt_feat_len, embedding, n_timesteps: int = 10, ): - assert token.shape[0] == 1 - # xvec projection embedding = F.normalize(embedding, dim=1) embedding = self.spk_embed_affine_layer(embedding) - - # concat text and prompt_text - token_len = prompt_token_len + token_len - token = torch.concat([prompt_token, token], dim=1) mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding) token = self.input_embedding(torch.clamp(token, min=0)) * mask # token encode - h, _ = self.encoder.forward(token, token_len) + h, h_lengths = self.encoder.forward(token, token_len) h = self.encoder_proj(h) - # condition - mel_len1 = prompt_feat.shape[1] - mel_len2 = h.shape[1] - prompt_feat.shape[1] - conds = torch.zeros_like(h) - conds[:, :mel_len1] = prompt_feat + for i, j in enumerate(prompt_feat_len): + conds[i, :j] = prompt_feat[i, :j] conds = conds.transpose(1, 2).contiguous() - mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h) + h_lengths = h_lengths.sum(dim=-1).squeeze(dim=1) + mask = (~make_pad_mask(h_lengths, max_len=h.shape[1])).to(h) feat = self.decoder.forward( mu=h.transpose(1, 2).contiguous(), @@ -107,9 +97,7 @@ def inference(self, n_timesteps=n_timesteps, ) - feat = feat[:, :, mel_len1:] - assert feat.shape[2] == mel_len2 - return feat + return feat.float(), h_lengths @torch.inference_mode() def setup_cache(self, @@ -149,7 +137,7 @@ def setup_cache(self, feat, estimator_cnn_cache, estimator_att_cache = self.decoder.forward_chunk( mu = h.transpose(1, 2).contiguous(), spks = spk, - cond = mel.transpose(1, 2).contiguous(), + cond = mel.transpose(1, 2).contiguous().to(h.dtype), n_timesteps = n_timesteps, temperature = 1.0, cnn_cache = None, diff --git a/cosyvoice2/flow/flow_matching.py b/cosyvoice2/flow/flow_matching.py index 900c71e..79d10b4 100644 --- a/cosyvoice2/flow/flow_matching.py +++ b/cosyvoice2/flow/flow_matching.py @@ -20,6 +20,13 @@ from cosyvoice2.utils.mask import make_pad_mask +def get_data_ptr(tensor: torch.Tensor, dummy_buffer: torch.Tensor): + if tensor.numel() == 0: + return dummy_buffer.data_ptr() + else: + return tensor.contiguous().data_ptr() + + """ Inference wrapper """ @@ -32,13 +39,42 @@ def __init__(self, estimator: DiT, inference_cfg_rate:float=0.7): # a maximum of 600s self.register_buffer('rand_noise', torch.randn([1, self.out_channels, 50 * 600]), persistent=False) - self.register_buffer('cnn_cache_buffer', torch.zeros(16, 16, 2, 1024, 2), persistent=False) - self.register_buffer('att_cache_buffer', torch.zeros(16, 16, 2, 8, 1000, 128), persistent=False) + self.register_buffer('dummy_buffer', torch.zeros(1), persistent=False) def scatter_cuda_graph(self, enable_cuda_graph: bool): if enable_cuda_graph: self.estimator._init_cuda_graph_all() + def forward_estimator(self, x, mask, mu, t, spks, cond): + if isinstance(self.estimator, torch.nn.Module): + return self.estimator(x, mask, mu, t, spks, cond) + else: + 
[estimator, stream], trt_engine = self.estimator.acquire_estimator() + # NOTE need to synchronize when switching stream + torch.cuda.current_stream().synchronize() + batch_size = x.size(0) + with stream: + estimator.set_input_shape('x', (batch_size, 80, x.size(2))) + estimator.set_input_shape('mask', (batch_size, 1, x.size(2))) + estimator.set_input_shape('mu', (batch_size, 80, x.size(2))) + estimator.set_input_shape('t', (batch_size,)) + estimator.set_input_shape('spks', (batch_size, 80)) + estimator.set_input_shape('cond', (batch_size, 80, x.size(2))) + data_ptrs = [x.contiguous().data_ptr(), + mask.contiguous().data_ptr(), + mu.contiguous().data_ptr(), + t.contiguous().data_ptr(), + spks.contiguous().data_ptr(), + cond.contiguous().data_ptr(), + x.data_ptr()] + for i, j in enumerate(data_ptrs): + estimator.set_tensor_address(trt_engine.get_tensor_name(i), j) + # run trt engine + assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True + torch.cuda.current_stream().synchronize() + self.estimator.release_estimator(estimator, stream) + return x + def solve_euler(self, x, t_span, mu, mask, spks, cond): """ Fixed euler solver for ODEs. @@ -55,7 +91,8 @@ def solve_euler(self, x, t_span, mu, mask, spks, cond): cond: Not used but kept for future purposes """ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] - t = t.unsqueeze(dim=0) + t_in = torch.zeros([x.shape[0] * 2], device=x.device, dtype=x.dtype) + assert self.inference_cfg_rate > 0, 'inference_cfg_rate better > 0' # constant during denoising @@ -65,18 +102,16 @@ def solve_euler(self, x, t_span, mu, mask, spks, cond): cond_in = torch.cat([cond, torch.zeros_like(cond)], dim=0) for step in range(1, len(t_span)): - x_in = torch.cat([x, x], dim=0) - t_in = torch.cat([t, t], dim=0) + t_in.fill_(t) - dphi_dt = self.estimator.forward( - x_in, - mask_in, - mu_in, - t_in, + dphi_dt = self.forward_estimator( + x_in, mask_in, + mu_in, t_in, spks_in, cond_in, ) + dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0) dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt) x = x + dt * dphi_dt @@ -88,12 +123,63 @@ def solve_euler(self, x, t_span, mu, mask, spks, cond): @torch.inference_mode() def forward(self, mu, mask, spks, cond, n_timesteps=10, temperature=1.0): - z = self.rand_noise[:, :, :mu.size(2)] * temperature + z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype) # cosine scheduling t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) return self.solve_euler(z, t_span, mu, mask, spks, cond) + + def forward_estimator_chunk(self, x, mu, t, spks, cond, cnn_cache, att_cache): + if isinstance(self.estimator, torch.nn.Module): + dphi_dt, this_new_cnn_cache, this_new_att_cache = self.estimator.forward_chunk( + x = x, + mu = mu, + t = t, + spks = spks, + cond = cond, + cnn_cache = cnn_cache, + att_cache = att_cache, + ) + return dphi_dt, this_new_cnn_cache, this_new_att_cache + else: + [estimator, stream], trt_engine = self.estimator.acquire_estimator() + # NOTE need to synchronize when switching stream + torch.cuda.current_stream().synchronize() + batch_size = x.size(0) + with stream: + estimator.set_input_shape('x', (batch_size, 80, x.size(2))) + estimator.set_input_shape('mu', (batch_size, 80, x.size(2))) + estimator.set_input_shape('t', (batch_size,)) + estimator.set_input_shape('spks', (batch_size, 80)) + estimator.set_input_shape('cond', (batch_size, 80, x.size(2))) + 
estimator.set_input_shape('cnn_cache', cnn_cache.shape) + estimator.set_input_shape('att_cache', att_cache.shape) + new_cnn_cache = torch.empty_like(cnn_cache) + new_att_cache_shape = list(att_cache.shape) + new_att_cache_shape[3] += x.size(2) + new_att_cache = torch.empty(new_att_cache_shape, device=att_cache.device, dtype=x.dtype) + data_ptrs = [x.contiguous().data_ptr(), + mu.contiguous().data_ptr(), + t.contiguous().data_ptr(), + spks.contiguous().data_ptr(), + cond.contiguous().data_ptr(), + cnn_cache.contiguous().data_ptr(), + get_data_ptr(att_cache, self.dummy_buffer), + x.data_ptr(), + new_cnn_cache.data_ptr(), + get_data_ptr(new_att_cache, self.dummy_buffer)] + + for i, j in enumerate(data_ptrs): + estimator.set_tensor_address(trt_engine.get_tensor_name(i), j) + # run trt engine + assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True + torch.cuda.current_stream().synchronize() + self.estimator.release_estimator(estimator, stream) + + return x, new_cnn_cache, new_att_cache + + def solve_euler_chunk(self, x:torch.Tensor, t_span:torch.Tensor, @@ -122,14 +208,18 @@ def solve_euler_chunk(self, assert self.inference_cfg_rate > 0, 'cfg rate should be > 0' t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] - t = t.unsqueeze(dim=0) # (b,) + t_in = torch.zeros([x.shape[0] * 2], device=x.device, dtype=x.dtype) # setup initial cache if cnn_cache is None: cnn_cache = [None for _ in range(len(t_span)-1)] + cnn_cache = torch.zeros((len(t_span)-1, 16, x.shape[0] * 2, 1024, 2), device=x.device, dtype=x.dtype) if att_cache is None: att_cache = [None for _ in range(len(t_span)-1)] + att_cache = torch.empty((len(t_span)-1, 16, x.shape[0] * 2, 8, 0, 128), device=x.device, dtype=x.dtype) # next chunk's cache at each timestep + new_cnn_caches = [] + new_att_caches = [] if att_cache[0] is not None: last_att_len = att_cache.shape[4] @@ -146,15 +236,19 @@ def solve_euler_chunk(self, this_att_cache = att_cache[step-1] this_cnn_cache = cnn_cache[step-1] - dphi_dt, this_new_cnn_cache, this_new_att_cache = self.estimator.forward_chunk( - x = x.repeat(2, 1, 1), + x_in = x.repeat(2, 1, 1) + t_in.fill_(t) + + dphi_dt, this_new_cnn_cache, this_new_att_cache = self.forward_estimator_chunk( + x = x_in, mu = mu_in, - t = t.repeat(2), + t = t_in, spks = spks_in, cond = cond_in, cnn_cache = this_cnn_cache, att_cache = this_att_cache, ) + dphi_dt, cfg_dphi_dt = dphi_dt.chunk(2, dim=0) dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt) x = x + dt * dphi_dt @@ -162,11 +256,11 @@ def solve_euler_chunk(self, if step < len(t_span) - 1: dt = t_span[step + 1] - t - self.cnn_cache_buffer[step-1] = this_new_cnn_cache - self.att_cache_buffer[step-1][:, :, :, :x.shape[2]+last_att_len, :] = this_new_att_cache + new_cnn_caches.append(this_new_cnn_cache) + new_att_caches.append(this_new_att_cache) - cnn_cache = self.cnn_cache_buffer - att_cache = self.att_cache_buffer[:, :, :, :, :x.shape[2]+last_att_len, :] + cnn_cache = torch.stack(new_cnn_caches) + att_cache = torch.stack(new_att_caches) return x, cnn_cache, att_cache @torch.inference_mode() @@ -190,6 +284,7 @@ def forward_chunk(self, # get offset from att_cache offset = att_cache.shape[4] if att_cache is not None else 0 z = self.rand_noise[:, :, offset:offset+mu.size(2)] * temperature + z = z.to(mu.dtype) t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype) # cosine scheduling t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) diff --git a/examples-vllm-stream.py 
b/examples-vllm-stream.py index 866fb45..a681ae6 100644 --- a/examples-vllm-stream.py +++ b/examples-vllm-stream.py @@ -43,7 +43,6 @@ def stream_client(model, history, tools, token2wav=None, output_stream=None, pro model = StepAudio2(api_url, model_name) token2wav = Token2wav('Step-Audio-2-mini/token2wav') tokens = [1493, 4299, 4218, 2049, 528, 2752, 4850, 4569, 4575, 6372, 2127, 4068, 2312, 4993, 4769, 2300, 226, 2175, 2160, 2152, 6311, 6065, 4859, 5102, 4615, 6534, 6426, 1763, 2249, 2209, 5938, 1725, 6048, 3816, 6058, 958, 63, 4460, 5914, 2379, 735, 5319, 4593, 2328, 890, 35, 751, 1483, 1484, 1483, 2112, 303, 4753, 2301, 5507, 5588, 5261, 5744, 5501, 2341, 2001, 2252, 2344, 1860, 2031, 414, 4366, 4366, 6059, 5300, 4814, 5092, 5100, 1923, 3054, 4320, 4296, 2148, 4371, 5831, 5084, 5027, 4946, 4946, 2678, 575, 575, 521, 518, 638, 1367, 2804, 3402, 4299] - token2wav.set_stream_cache(prompt_wav) token2wav.stream(tokens[:CHUNK_SIZE + token2wav.flow.pre_lookahead_len], prompt_wav=prompt_wav) # Warm up output_stream = Path('output-stream.pcm') diff --git a/token2wav.py b/token2wav.py index ad4163e..aa64b96 100644 --- a/token2wav.py +++ b/token2wav.py @@ -1,17 +1,28 @@ -import io - +""" Example Usage + CUDA_VISIBLE_DEVICES=0 \ + python3 token2wav.py +""" import torch -import torchaudio -import s3tokenizer -import onnxruntime -import numpy as np - -import torchaudio.compliance.kaldi as kaldi from flashcosyvoice.modules.hifigan import HiFTGenerator from flashcosyvoice.utils.audio import mel_spectrogram +import torchaudio.compliance.kaldi as kaldi +import onnxruntime +import s3tokenizer +from torch.utils.data import DataLoader +from datasets import load_dataset +import torchaudio +import os +import logging +import argparse +import queue +import time +import numpy as np from hyperpyyaml import load_hyperpyyaml +import io +from pathlib import Path +import wave -def fade_in_out(fade_in_mel:torch.Tensor, fade_out_mel:torch.Tensor, window:torch.Tensor): +def fade_in_out(fade_in_mel: torch.Tensor, fade_out_mel: torch.Tensor, window: torch.Tensor): """perform fade_in_out in tensor style """ mel_overlap_len = int(window.shape[0] / 2) @@ -22,129 +33,405 @@ def fade_in_out(fade_in_mel:torch.Tensor, fade_out_mel:torch.Tensor, window:torc return fade_in_mel -class Token2wav(): +def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype): + import tensorrt as trt + logging.info("Converting onnx to trt...") + network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + logger = trt.Logger(trt.Logger.INFO) + builder = trt.Builder(logger) + network = builder.create_network(network_flags) + parser = trt.OnnxParser(network, logger) + config = builder.create_builder_config() + # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB + if dtype == torch.float16: + config.set_flag(trt.BuilderFlag.FP16) + + profile = builder.create_optimization_profile() + # load onnx model + with open(onnx_model, "rb") as f: + if not parser.parse(f.read()): + for error in range(parser.num_errors): + print(parser.get_error(error)) + raise ValueError('failed to parse {}'.format(onnx_model)) + # set input shapes + for i in range(len(trt_kwargs['input_names'])): + profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i]) + if dtype == torch.float16: + tensor_dtype = trt.DataType.HALF + elif dtype == torch.bfloat16: + tensor_dtype = trt.DataType.BF16 + elif dtype == torch.float32: + tensor_dtype = trt.DataType.FLOAT + else: + 
raise ValueError('invalid dtype {}'.format(dtype)) + # set input and output data type + for i in range(network.num_inputs): + input_tensor = network.get_input(i) + input_tensor.dtype = tensor_dtype + for i in range(network.num_outputs): + output_tensor = network.get_output(i) + output_tensor.dtype = tensor_dtype + config.add_optimization_profile(profile) + engine_bytes = builder.build_serialized_network(network, config) + # save trt engine + with open(trt_model, "wb") as f: + f.write(engine_bytes) + logging.info("Succesfully convert onnx to trt...") + + +class TrtContextWrapper: + def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'): + self.trt_context_pool = queue.Queue(maxsize=trt_concurrent) + self.trt_engine = trt_engine + self.device = device + for _ in range(trt_concurrent): + trt_context = trt_engine.create_execution_context() + trt_stream = torch.cuda.stream(torch.cuda.Stream(torch.device(device))) + assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent) + self.trt_context_pool.put([trt_context, trt_stream]) + assert self.trt_context_pool.empty() is False, 'no avaialbe estimator context' + + def acquire_estimator(self): + return self.trt_context_pool.get(), self.trt_engine + + def release_estimator(self, context, stream): + self.trt_context_pool.put([context, stream]) + + +class Token2wav(torch.nn.Module): + def __init__(self, model_dir: str, enable_trt: bool = False, device_id: int = 0, streaming: bool = False, dtype: torch.dtype = torch.float16): + super().__init__() + self.device_id = device_id + self.device = f"cuda:{device_id}" + with open(f"{model_dir}/flow.yaml", "r") as f: + configs = load_hyperpyyaml(f) + self.flow = configs['flow'] - def __init__(self, model_path, float16=False): - self.float16 = float16 + self.dtype = dtype + self.flow.to(self.dtype) - self.audio_tokenizer = s3tokenizer.load_model(f"{model_path}/speech_tokenizer_v2_25hz.onnx").cuda().eval() + self.flow.load_state_dict(torch.load(f"{model_dir}/flow.pt", map_location="cpu", weights_only=True), strict=True) + self.flow.to(self.device).eval() + + self.hift = HiFTGenerator() + hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{model_dir}/hift.pt", map_location="cpu", weights_only=True).items()} + self.hift.load_state_dict(hift_state_dict, strict=True) + self.hift.to(self.device).eval() option = onnxruntime.SessionOptions() option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL option.intra_op_num_threads = 1 - self.spk_model = onnxruntime.InferenceSession(f"{model_path}/campplus.onnx", sess_options=option, providers=["CPUExecutionProvider"]) - - with open(f"{model_path}/flow.yaml", "r") as f: - configs = load_hyperpyyaml(f) - self.flow = configs['flow'] - if float16: - self.flow.half() - self.flow.load_state_dict(torch.load(f"{model_path}/flow.pt", map_location="cpu", weights_only=True), strict=True) - self.flow.cuda().eval() - - self.hift = HiFTGenerator() - hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{model_path}/hift.pt", map_location="cpu", weights_only=True).items()} - self.hift.load_state_dict(hift_state_dict, strict=True) - self.hift.cuda().eval() + self.spk_model = onnxruntime.InferenceSession( + f"{model_dir}/campplus.onnx", sess_options=option, + providers=["CPUExecutionProvider"]) + self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2_25hz.onnx").to(self.device).eval() + + if 
enable_trt: + if streaming: + self.load_trt( + f'{model_dir}/flow.decoder.estimator.{self.dtype}.static_batch.chunk.plan', + f'{model_dir}/flow.decoder.estimator.chunk.fp32.static_batch.onnx', + 1, + self.dtype, streaming + ) + else: + self.load_trt( + f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.plan', + f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx', + 1, + self.dtype + ) + self.load_spk_trt( + f'{model_dir}/campplus.fp32.trt', + f'{model_dir}/campplus.onnx', + 1, + False + ) - self.cache = {} + self.streaming_flow_cache = {} + self.speaker_cache = {} - # stream conf self.mel_cache_len = 8 # hard-coded, 160ms self.source_cache_len = int(self.mel_cache_len * 480) # 50hz mel -> 24kHz wave self.speech_window = torch.from_numpy(np.hamming(2 * self.source_cache_len)).cuda() - # hifigan cache + # hifigan cache for streaming tts self.hift_cache_dict = {} + def forward_spk_embedding(self, spk_feat): + if isinstance(self.spk_model, onnxruntime.InferenceSession): + return self.spk_model.run( + None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()} + )[0].flatten().tolist() + else: + [spk_model, stream], trt_engine = self.spk_model.acquire_estimator() + # NOTE need to synchronize when switching stream + with torch.cuda.device(self.device_id): + torch.cuda.current_stream().synchronize() + spk_feat = spk_feat.unsqueeze(dim=0).to(self.device) + batch_size = spk_feat.size(0) + + with stream: + spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80)) + output_tensor = torch.empty((batch_size, 192), device=spk_feat.device) + + data_ptrs = [spk_feat.contiguous().data_ptr(), + output_tensor.contiguous().data_ptr()] + for i, j in enumerate(data_ptrs): + + spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j) + # run trt engine + assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True + torch.cuda.current_stream().synchronize() + self.spk_model.release_estimator(spk_model, stream) + + return output_tensor.cpu().numpy().flatten().tolist() + + def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True): + if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0: + trt_kwargs = self.get_spk_trt_kwargs() + convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, torch.float32) + import tensorrt as trt + with open(spk_model, 'rb') as f: + spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) + assert spk_engine is not None, 'failed to load trt {}'.format(spk_model) + self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device) + + def get_spk_trt_kwargs(self): + min_shape = [(1, 4, 80)] + opt_shape = [(1, 500, 80)] + max_shape = [(1, 3000, 80)] + input_names = ["input"] + return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} + + def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent=1, dtype=torch.float16, streaming=False): + assert torch.cuda.is_available(), 'tensorrt only supports gpu!' + assert os.path.exists(flow_decoder_onnx_model), f'Please use tools/export_onnx.py or tools/export_onnx_streaming.py to export onnx model for token2wav first.' 
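+        # NOTE: if the .plan engine file below is missing or empty, it is built here from the exported ONNX model (a one-time step that can take a while)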
+ if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0: + opt_batch_size = 2 + max_batch_size = 16 + if streaming: + opt_batch_size, max_batch_size = 1, 1 # only support batch size 1 for streaming tts + trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_batch_size=opt_batch_size, max_batch_size=max_batch_size, streaming=streaming) + convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, dtype) + del self.flow.decoder.estimator + import tensorrt as trt + with open(flow_decoder_estimator_model, 'rb') as f: + estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) + assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model) + self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device) + + def get_trt_kwargs_dynamic_batch(self, opt_batch_size=2, max_batch_size=64, streaming=False): + if streaming: + min_shape = [(2, 80, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80), (16, 2, 1024, 2), (16, 2, 8, 0, 128)] + opt_shape = [ + (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500), + (opt_batch_size * 2,), (opt_batch_size * 2, 80), (16, opt_batch_size * 2, 1024, 2), + (16, opt_batch_size * 2, 8, 100, 128) + ] + max_shape = [ + (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), + (max_batch_size * 2,), (max_batch_size * 2, 80), (16, max_batch_size * 2, 1024, 2), + (16, max_batch_size * 2, 8, 1000, 128) + ] + input_names = ["x", "mu", "cond", "t", "spks", "cnn_cache", "att_cache"] + else: + min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)] + opt_shape = [ + (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 1, 500), (opt_batch_size * 2, 80, 500), + (opt_batch_size * 2, 80, 500), (opt_batch_size * 2,), (opt_batch_size * 2, 80) + ] + max_shape = [ + (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 1, 3000), (max_batch_size * 2, 80, 3000), + (max_batch_size * 2, 80, 3000), (max_batch_size * 2,), (max_batch_size * 2, 80) + ] + input_names = ["x", "mask", "mu", "cond", "t", "spks"] + return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} + + def prompt_audio_tokenization(self, prompt_audios_list: list[torch.Tensor]) -> list[list[int]]: + prompt_speech_tokens_list, prompt_speech_mels_list = [], [] + for audio in prompt_audios_list: + assert len(audio.shape) == 1 + log_mel = s3tokenizer.log_mel_spectrogram(audio) # [num_mels, T] + prompt_speech_mels_list.append(log_mel) + prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(prompt_speech_mels_list) + prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize( + prompt_mels_for_llm.to(self.device), prompt_mels_lens_for_llm.to(self.device) + ) + for i in range(len(prompt_speech_tokens)): + speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist() + prompt_speech_tokens_list.append(speech_tokens_i) + return prompt_speech_tokens_list + + def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor: + spk_emb_for_flow = [] + for audio in prompt_audios_list: + assert len(audio.shape) == 1 + spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000) + spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True) + spk_emb = self.forward_spk_embedding(spk_feat) + + 
spk_emb_for_flow.append(spk_emb) + spk_emb_for_flow = torch.tensor(spk_emb_for_flow) + if self.dtype != torch.float32: + spk_emb_for_flow = spk_emb_for_flow.to(self.dtype) + return spk_emb_for_flow + + def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]): + prompt_mels_for_flow = [] + prompt_mels_lens_for_flow = [] + for audio, sample_rate in zip(prompt_audios_list, prompt_audios_sample_rate): + assert len(audio.shape) == 1 + audio = audio.unsqueeze(0) + if sample_rate != 24000: + audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio) + mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0) # [T, num_mels] + mel_len = mel.shape[0] + prompt_mels_for_flow.append(mel) + prompt_mels_lens_for_flow.append(mel_len) + prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence( + prompt_mels_for_flow, batch_first=True, padding_value=0 + ) # [B, T', num_mels=80] + prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow) + return prompt_mels_for_flow, prompt_mels_lens_for_flow + + def forward_flow(self, prompt_speech_tokens_list: list[list[int]], + generated_speech_tokens_list: list[list[int]], + prompt_mels_for_flow: torch.Tensor, + prompt_mels_lens_for_flow: torch.Tensor, + spk_emb_for_flow: torch.Tensor): + batch_size = prompt_mels_for_flow.shape[0] + flow_inputs = [] + flow_inputs_lens = [] + for prompt_speech_tokens, generated_speech_tokens in zip(prompt_speech_tokens_list, generated_speech_tokens_list): + flow_inputs.append(torch.tensor(prompt_speech_tokens + generated_speech_tokens)) + flow_inputs_lens.append(len(prompt_speech_tokens) + len(generated_speech_tokens)) + + flow_inputs = torch.nn.utils.rnn.pad_sequence(flow_inputs, batch_first=True, padding_value=0) + flow_inputs_lens = torch.tensor(flow_inputs_lens) + + with torch.amp.autocast(self.device, dtype=torch.float16): + generated_mels, generated_mels_lens = self.flow.inference( + flow_inputs.to(self.device), flow_inputs_lens.to(self.device), + prompt_mels_for_flow.to(self.device), prompt_mels_lens_for_flow.to(self.device), spk_emb_for_flow.to(self.device), 10 + ) - def _prepare_prompt(self, prompt_wav): - audio = s3tokenizer.load_audio(prompt_wav, sr=16000) # [T] - mels = s3tokenizer.log_mel_spectrogram(audio) - mels, mels_lens = s3tokenizer.padding([mels]) - prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize(mels.cuda(), mels_lens.cuda()) - - spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000) - spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True) - spk_emb = torch.tensor(self.spk_model.run( - None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()} - )[0], device='cuda') - - audio, sample_rate = torchaudio.load(prompt_wav, backend='soundfile') - audio = audio.mean(dim=0, keepdim=True) # [1, T] - if sample_rate != 24000: - audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio) - prompt_mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0) # [T, num_mels] - prompt_mels = prompt_mel.unsqueeze(0).cuda() - prompt_mels_lens = torch.tensor([prompt_mels.shape[1]], dtype=torch.int32, device='cuda') - prompt_mels = torch.nn.functional.pad(prompt_mels, (0, 0, 0, prompt_speech_tokens.shape[1] * self.flow.up_rate - prompt_mels.shape[1]), mode='replicate') - return prompt_speech_tokens, prompt_speech_tokens_lens, spk_emb, prompt_mels, prompt_mels_lens - - def __call__(self, generated_speech_tokens, prompt_wav): - if prompt_wav 
not in self.cache: - self.cache[prompt_wav] = self._prepare_prompt(prompt_wav) - prompt_speech_tokens, prompt_speech_tokens_lens, spk_emb, prompt_mels, prompt_mels_lens = self.cache[prompt_wav] - - generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda') - generated_speech_tokens_lens = torch.tensor([generated_speech_tokens.shape[1]], dtype=torch.int32, device='cuda') + return generated_mels, generated_mels_lens + + def forward_hift(self, generated_mels: torch.Tensor, generated_mels_lens: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor): + batch_size = generated_mels.shape[0] + generated_wavs = [] + for i in range(batch_size): + mel = generated_mels[i, :, prompt_mels_lens_for_flow[i].item():generated_mels_lens[i].item()].unsqueeze(0) + wav, _ = self.hift(speech_feat=mel) + generated_wavs.append(wav) + return generated_wavs + + @torch.inference_mode() + def forward( + self, generated_speech_tokens: list[int], prompt_wav: str): + generated_speech_tokens_list = [generated_speech_tokens] + audio = s3tokenizer.load_audio(prompt_wav, sr=16000) + prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio([audio], [16000]) + generated_mels, generated_mels_lens = self.forward_flow( + prompt_speech_tokens_list, generated_speech_tokens_list, + prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow + ) - with torch.amp.autocast("cuda", dtype=torch.float16 if self.float16 else torch.float32): - mel = self.flow.inference(generated_speech_tokens, generated_speech_tokens_lens, - prompt_speech_tokens, prompt_speech_tokens_lens, - prompt_mels, prompt_mels_lens, spk_emb, 10) + generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow) - wav, _ = self.hift(speech_feat=mel) + wav = generated_wavs[0] output = io.BytesIO() torchaudio.save(output, wav.cpu(), sample_rate=24000, format='wav') return output.getvalue() - def set_stream_cache(self, prompt_wav): - if prompt_wav not in self.cache: - self.cache[prompt_wav] = self._prepare_prompt(prompt_wav) - prompt_speech_tokens, prompt_speech_tokens_lens, spk_emb, prompt_mels, prompt_mels_lens = self.cache[prompt_wav] - self.stream_cache = self.flow.setup_cache( - torch.cat([prompt_speech_tokens, prompt_speech_tokens[:, :3]], dim=1), - prompt_mels, spk_emb, n_timesteps=10) - - # hift cache - self.hift_cache_dict = dict( - mel = torch.zeros(1, prompt_mels.shape[2], 0, device='cuda'), - source = torch.zeros(1, 1, 0, device='cuda'), - speech = torch.zeros(1, 0, device='cuda'), - ) + def prepare_prompt_audio( + self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int] + ): + assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate) + prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list) - def stream(self, generated_speech_tokens, prompt_wav, last_chunk=False): - if prompt_wav not in self.cache: - self.cache[prompt_wav] = self._prepare_prompt(prompt_wav) - prompt_speech_tokens, prompt_speech_tokens_lens, spk_emb, prompt_mels, prompt_mels_lens = self.cache[prompt_wav] + prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(prompt_audios_list, prompt_audios_sample_rate) - generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda') - generated_speech_tokens_lens = torch.tensor([generated_speech_tokens.shape[1]], dtype=torch.int32, device='cuda') - - if self.stream_cache is None: - raise 
ValueError("stream_cache is not set") - - with torch.amp.autocast("cuda", dtype=torch.float16 if self.float16 else torch.float32): - chunk_mel, self.stream_cache = self.flow.inference_chunk( - token=generated_speech_tokens, - spk=spk_emb, - cache=self.stream_cache, - last_chunk=last_chunk, - n_timesteps=10, + spk_emb_for_flow = self.get_spk_emb(prompt_audios_list) + return prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow + + def get_prompt_audio_cache_for_streaming_tts( + self, prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow + ): + assert len(prompt_speech_tokens_list) == 1, "only support batch size 1 for streaming tts" + for i, prompt_speech_tokens in enumerate(prompt_speech_tokens_list): + prompt_speech_tokens_list[i] = torch.tensor(prompt_speech_tokens + prompt_speech_tokens_list[i][:3]) + prompt_speech_tokens_tensor = torch.nn.utils.rnn.pad_sequence(prompt_speech_tokens_list, batch_first=True, padding_value=0) + + cache = self.flow.setup_cache( + prompt_speech_tokens_tensor.to(self.device), + prompt_mels_for_flow.to(self.device), + spk_emb_for_flow.to(self.device), + n_timesteps=10 + ) + new_cache = {k: v.clone() for k, v in cache.items()} + # Hack: this is a hack to avoid in-place changes to the cache['estimator_att_cache'] and cache['estimator_cnn_cache'] + return new_cache + + @torch.inference_mode() + def stream( + self, generated_speech_tokens: list[int], prompt_wav: str, last_chunk: bool = False, + ): + speaker_id = prompt_wav + request_id = prompt_wav + prompt_audio_sample_rate = 16000 + prompt_audio = s3tokenizer.load_audio(prompt_wav, sr=prompt_audio_sample_rate) + if speaker_id not in self.speaker_cache: + assert prompt_audio is not None, "prompt_audio is required for new speaker" + assert prompt_audio_sample_rate == 16000 + + prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio([prompt_audio], [prompt_audio_sample_rate]) + + token_len = min(int(prompt_mels_for_flow.shape[1] / 2), len(prompt_speech_tokens_list[0])) + prompt_mels_for_flow = prompt_mels_for_flow[:, :2 * token_len].contiguous() + prompt_speech_tokens_list[0] = prompt_speech_tokens_list[0][:token_len] + + prompt_audio_dict = {'spk_emb_for_flow': spk_emb_for_flow, 'prompt_mels_for_flow': prompt_mels_for_flow} + + cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow) + self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict} + + if request_id not in self.streaming_flow_cache: + self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()} + self.hift_cache_dict[request_id] = dict( + mel=torch.zeros(1, 80, 0, device='cuda'), + source=torch.zeros(1, 1, 0, device='cuda'), + speech=torch.zeros(1, 0, device='cuda'), ) - if self.stream_cache['estimator_att_cache'].shape[4] > (prompt_mels.shape[1] + 100): - self.stream_cache['estimator_att_cache'] = torch.cat([ - self.stream_cache['estimator_att_cache'][:, :, :, :, :prompt_mels.shape[1]], - self.stream_cache['estimator_att_cache'][:, :, :, :, -100:], + + current_request_cache = self.streaming_flow_cache[request_id] + + current_prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict'] + generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda') + + chunk_mel, 
new_streaming_flow_cache = self.flow.inference_chunk( + token=generated_speech_tokens, + spk=current_prompt_audio_dict['spk_emb_for_flow'].to(self.device), + cache=current_request_cache, + last_chunk=last_chunk, + n_timesteps=10, + ) + + self.streaming_flow_cache[request_id] = new_streaming_flow_cache + + if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (current_prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100): + self.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.cat([ + self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :current_prompt_audio_dict['prompt_mels_for_flow'].shape[1]], + self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, -100:], ], dim=4) - - # vocoder cache - hift_cache_mel = self.hift_cache_dict['mel'] - hift_cache_source = self.hift_cache_dict['source'] - hift_cache_speech = self.hift_cache_dict['speech'] - mel = torch.concat([hift_cache_mel, chunk_mel], dim=2) + + hift_cache_mel = self.hift_cache_dict[request_id]['mel'].clone() + hift_cache_source = self.hift_cache_dict[request_id]['source'].clone() + hift_cache_speech = self.hift_cache_dict[request_id]['speech'].clone() + mel = torch.concat([hift_cache_mel, chunk_mel], dim=2).clone() speech, source = self.hift(mel, hift_cache_source) @@ -153,14 +440,18 @@ def stream(self, generated_speech_tokens, prompt_wav, last_chunk=False): speech = fade_in_out(speech, hift_cache_speech, self.speech_window) # update vocoder cache - self.hift_cache_dict = dict( - mel = mel[..., -self.mel_cache_len:].clone().detach(), - source = source[:, :, -self.source_cache_len:].clone().detach(), - speech = speech[:, -self.source_cache_len:].clone().detach(), + self.hift_cache_dict[request_id] = dict( + mel=mel[..., -self.mel_cache_len:].clone().detach(), + source=source[:, :, -self.source_cache_len:].clone().detach(), + speech=speech[:, -self.source_cache_len:].clone().detach(), ) if not last_chunk: speech = speech[:, :-self.source_cache_len] + if last_chunk: + assert request_id in self.streaming_flow_cache + self.streaming_flow_cache.pop(request_id) + self.hift_cache_dict.pop(request_id) wav_np = speech.cpu().numpy() # Clip to [-1, 1] to avoid overflow, then scale to int16 wav_np = np.clip(wav_np, -1.0, 1.0) @@ -168,10 +459,34 @@ def stream(self, generated_speech_tokens, prompt_wav, last_chunk=False): pcm_bytes = wav_int16.tobytes() return pcm_bytes -if __name__ == '__main__': - token2wav = Token2wav('Step-Audio-2-mini/token2wav') - +if __name__ == "__main__": tokens = [1493, 4299, 4218, 2049, 528, 2752, 4850, 4569, 4575, 6372, 2127, 4068, 2312, 4993, 4769, 2300, 226, 2175, 2160, 2152, 6311, 6065, 4859, 5102, 4615, 6534, 6426, 1763, 2249, 2209, 5938, 1725, 6048, 3816, 6058, 958, 63, 4460, 5914, 2379, 735, 5319, 4593, 2328, 890, 35, 751, 1483, 1484, 1483, 2112, 303, 4753, 2301, 5507, 5588, 5261, 5744, 5501, 2341, 2001, 2252, 2344, 1860, 2031, 414, 4366, 4366, 6059, 5300, 4814, 5092, 5100, 1923, 3054, 4320, 4296, 2148, 4371, 5831, 5084, 5027, 4946, 4946, 2678, 575, 575, 521, 518, 638, 1367, 2804, 3402, 4299] - audio = token2wav(tokens, 'assets/default_male.wav') - with open('assets/give_me_a_brief_introduction_to_the_great_wall.wav', 'wb') as f: - f.write(audio) + # offline token2wav + # token2wav = Token2wav('Step-Audio-2-mini/token2wav', enable_trt=True) + # audio = token2wav(tokens, 'assets/default_male.wav') + # with open('assets/give_me_a_brief_introduction_to_the_great_wall.wav', 'wb') as f: + # f.write(audio) + + # streaming token2wav using pytorch 
+ # token2wav = Token2wav('Step-Audio-2-mini/token2wav') + + # streaming token2wav using tensorrt + token2wav = Token2wav('Step-Audio-2-mini/token2wav', enable_trt=True, streaming=True) + audio_first_chunk = token2wav.stream(tokens[:25 + token2wav.flow.pre_lookahead_len], prompt_wav='assets/default_male.wav') + audio_last_chunk = token2wav.stream(tokens[25 + token2wav.flow.pre_lookahead_len:], prompt_wav='assets/default_male.wav', last_chunk=True) + + + output_stream = Path('output-stream.pcm') + output_stream.unlink(missing_ok=True) + with open(output_stream, 'wb') as f: + f.write(audio_first_chunk) + f.write(audio_last_chunk) + + with open(output_stream, 'rb') as f: + pcm = f.read() + wav_path = output_stream.with_suffix('.wav') + with wave.open(str(wav_path), 'wb') as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(24000) + wf.writeframes(pcm) \ No newline at end of file diff --git a/tools/export_onnx_offline_token2wav.py b/tools/export_onnx_offline_token2wav.py new file mode 100644 index 0000000..1f64e36 --- /dev/null +++ b/tools/export_onnx_offline_token2wav.py @@ -0,0 +1,130 @@ +# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com) +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script is used to export the offline token2wav model to onnx. 
+python3 tools/export_onnx_offline_token2wav.py +""" + +from __future__ import print_function + +import argparse +import logging +logging.getLogger('matplotlib').setLevel(logging.WARNING) +import os +import sys +import onnxruntime +import random +import torch +from tqdm import tqdm +from hyperpyyaml import load_hyperpyyaml +import sys +import os +# add ../ to python path +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +def get_args(): + parser = argparse.ArgumentParser(description='export your model for deployment') + parser.add_argument('--model_dir', + type=str, + default='Step-Audio-2-mini/token2wav', + help='local path') + parser.add_argument('--onnx_model', + type=str, + default='flow.decoder.estimator.fp32.dynamic_batch.onnx', + help='onnx model name') + args = parser.parse_args() + print(args) + return args + +def get_dummy_input(batch_size, seq_len, out_channels, device): + x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) + mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device) + mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) + t = torch.rand((batch_size), dtype=torch.float32, device=device) + spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device) + cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) + return x, mask, mu, t, spks, cond + +@torch.no_grad() +def main(): + args = get_args() + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + + with open(f"{args.model_dir}/flow.yaml", "r") as f: + configs = load_hyperpyyaml(f) + flow_model = configs['flow'] + + device = torch.device('cuda') + + + # 1. export flow decoder estimator + flow_model.load_state_dict(torch.load(f"{args.model_dir}/flow.pt", map_location="cpu", weights_only=True), strict=True) + estimator = flow_model.decoder.estimator + estimator.eval() + estimator.to(device) + + + batch_size, seq_len = 2, 256 + out_channels = flow_model.decoder.estimator.out_channels + x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device) + torch.onnx.export( + estimator, + (x, mask, mu, t, spks, cond), + f'{args.model_dir}/{args.onnx_model}', + export_params=True, + opset_version=18, + do_constant_folding=True, + input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'], + output_names=['estimator_out'], + dynamic_axes={ + 'x': {0: 'batch_size', 2: 'seq_len'}, + 'mask': {0: 'batch_size', 2: 'seq_len'}, + 'mu': {0: 'batch_size', 2: 'seq_len'}, + 'cond': {0: 'batch_size', 2: 'seq_len'}, + 't': {0: 'batch_size'}, + 'spks': {0: 'batch_size'}, + 'estimator_out': {0: 'batch_size', 2: 'seq_len'}, + + } + ) + + # 2. 
test computation consistency + option = onnxruntime.SessionOptions() + option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + option.intra_op_num_threads = 1 + providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider'] + estimator_onnx = onnxruntime.InferenceSession(f'{args.model_dir}/{args.onnx_model}', + sess_options=option, providers=providers) + + for _ in tqdm(range(10)): + x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device) + output_pytorch = estimator(x, mask, mu, t, spks, cond) + ort_inputs = { + 'x': x.cpu().numpy(), + 'mask': mask.cpu().numpy(), + 'mu': mu.cpu().numpy(), + 't': t.cpu().numpy(), + 'spks': spks.cpu().numpy(), + 'cond': cond.cpu().numpy() + } + output_onnx = estimator_onnx.run(None, ort_inputs)[0] + torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4) + logging.info('successfully export estimator') + + +if __name__ == "__main__": + main() + diff --git a/tools/export_onnx_streaming_token2wav.py b/tools/export_onnx_streaming_token2wav.py new file mode 100644 index 0000000..5f4362d --- /dev/null +++ b/tools/export_onnx_streaming_token2wav.py @@ -0,0 +1,161 @@ +# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com) +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script is used to export the streaming token2wav model to onnx. 
+python3 tools/export_onnx_streaming_token2wav.py +""" + +from __future__ import print_function + +import argparse +import logging +logging.getLogger('matplotlib').setLevel(logging.WARNING) +import os +import sys +import onnxruntime +import random +import torch +from tqdm import tqdm +from hyperpyyaml import load_hyperpyyaml + +import sys +import os +# add ../ to python path +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +def get_args(): + parser = argparse.ArgumentParser(description='export your model for deployment') + parser.add_argument('--model_dir', + type=str, + default='Step-Audio-2-mini/token2wav', + help='local path') + parser.add_argument('--onnx_model', + type=str, + default='flow.decoder.estimator.chunk.fp32.static_batch.onnx', + help='onnx model name') + args = parser.parse_args() + print(args) + return args + +def get_dummy_input_chunk(batch_size, seq_len, prev_seq_len, out_channels, estimator, device): + x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) + mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) + t = torch.rand((batch_size), dtype=torch.float32, device=device) + spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device) + cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) + + depth = len(estimator.blocks) + num_heads = estimator.blocks[0].attn.num_heads + head_dim = estimator.blocks[0].attn.head_dim + cnn_channels = estimator.blocks[0].conv.in_channels + estimator.blocks[0].conv.out_channels + + cnn_cache = torch.rand((depth, batch_size, cnn_channels, 2), dtype=torch.float32, device=device) + att_cache = torch.rand((depth, batch_size, num_heads, prev_seq_len, head_dim * 2), dtype=torch.float32, device=device) + return x, mu, t, spks, cond, cnn_cache, att_cache + +class DiTChunkWrapper(torch.nn.Module): + def __init__(self, dit_model): + super().__init__() + self.dit_model = dit_model + + def forward(self, x, mu, t, spks, cond, cnn_cache, att_cache): + return self.dit_model.forward_chunk(x, mu, t, spks, cond, cnn_cache, att_cache) + + +@torch.no_grad() +def main(): + args = get_args() + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + + with open(f"{args.model_dir}/flow.yaml", "r") as f: + configs = load_hyperpyyaml(f) + flow_model = configs['flow'] + + device = torch.device('cuda') + + + # 1. 
export flow decoder estimator for chunk processing + flow_model.load_state_dict(torch.load(f"{args.model_dir}/flow.pt", map_location="cpu", weights_only=True), strict=True) + estimator = flow_model.decoder.estimator + estimator.eval() + estimator.to(device) + + estimator_chunk_wrapper = DiTChunkWrapper(estimator) + + batch_size, seq_len, prev_seq_len = 2, 500, 100 + out_channels = flow_model.decoder.estimator.out_channels + dummy_inputs = get_dummy_input_chunk(batch_size, seq_len, prev_seq_len, out_channels, estimator, device) + (x, mu, t, spks, cond, cnn_cache, att_cache) = dummy_inputs + + torch.onnx.export( + estimator_chunk_wrapper, + dummy_inputs, + f'{args.model_dir}/{args.onnx_model}', + export_params=True, + opset_version=18, + do_constant_folding=True, + input_names=['x', 'mu', 't', 'spks', 'cond', 'cnn_cache', 'att_cache'], + output_names=['output', 'new_cnn_cache', 'new_att_cache'], + dynamic_axes={ + 'x': {0: 'batch_size', 2: 'seq_len'}, + 'mu': {0: 'batch_size', 2: 'seq_len'}, + 'cond': {0: 'batch_size', 2: 'seq_len'}, + 't': {0: 'batch_size'}, + 'spks': {0: 'batch_size'}, + 'cnn_cache': {1: 'batch_size'}, + 'att_cache': {1: 'batch_size', 3: 'prev_seq_len'}, + 'output': {0: 'batch_size', 2: 'seq_len'}, + 'new_cnn_cache': {1: 'batch_size'}, + 'new_att_cache': {1: 'batch_size', 3: 'total_seq_len'}, + } + ) + + # 2. test computation consistency + option = onnxruntime.SessionOptions() + option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + option.intra_op_num_threads = 1 + providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider'] + estimator_onnx = onnxruntime.InferenceSession(f'{args.model_dir}/{args.onnx_model}', + sess_options=option, providers=providers) + + for _ in tqdm(range(10)): + seq_len = random.randint(16, 512) + prev_seq_len = random.randint(16, 1024) + dummy_inputs = get_dummy_input_chunk(batch_size, seq_len, prev_seq_len, out_channels, estimator, device) + (x, mu, t, spks, cond, cnn_cache, att_cache) = dummy_inputs + + output_pytorch, new_cnn_cache_pytorch, new_att_cache_pytorch = estimator_chunk_wrapper(*dummy_inputs) + + ort_inputs = { + 'x': x.cpu().numpy(), + 'mu': mu.cpu().numpy(), + 't': t.cpu().numpy(), + 'spks': spks.cpu().numpy(), + 'cond': cond.cpu().numpy(), + 'cnn_cache': cnn_cache.cpu().numpy(), + 'att_cache': att_cache.cpu().numpy(), + } + output_onnx, new_cnn_cache_onnx, new_att_cache_onnx = estimator_onnx.run(None, ort_inputs) + + torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4) + torch.testing.assert_allclose(new_cnn_cache_pytorch, torch.from_numpy(new_cnn_cache_onnx).to(device), rtol=1e-2, atol=1e-4) + torch.testing.assert_allclose(new_att_cache_pytorch, torch.from_numpy(new_att_cache_onnx).to(device), rtol=1e-2, atol=1e-4) + + logging.info('successfully export chunk-wise estimator') + + +if __name__ == "__main__": + main() diff --git a/tools/tensorrt_token2wav.md b/tools/tensorrt_token2wav.md new file mode 100644 index 0000000..8888d93 --- /dev/null +++ b/tools/tensorrt_token2wav.md @@ -0,0 +1,103 @@ +# Accelerating StepAudio2 Token2wav with NVIDIA TensorRT + +This document provides instructions on how to use NVIDIA TensorRT to accelerate the Token2wav module in StepAudio2 for both offline and streaming inference. + +## Preparation + +### 1. Install Dependencies + +Install the necessary packages using pip. For GPU acceleration with TensorRT, use `onnxruntime-gpu`. 
+ +```bash +pip install tensorrt onnxruntime-gpu +``` + +### 2. Export ONNX Models + +You need to export the PyTorch models to ONNX format. There are separate scripts for offline (dynamic batch) and streaming (static batch) modes. + +**For Offline Inference:** +```bash +python3 tools/export_onnx_offline_token2wav.py +``` + +**For Streaming Inference:** +```bash +python3 tools/export_onnx_streaming_token2wav.py +``` + +## Usage + +### Offline Inference + +Here is an example of how to use the TensorRT-accelerated Token2wav model for offline inference. + +```python +from token2wav import Token2wav +import wave + +# The tokens to be converted to speech +tokens = [1493, 4299, 4218, 2049, 528, 2752, 4850, 4569, 4575, 6372, 2127, 4068, 2312, 4993, 4769, 2300, 226, 2175, 2160, 2152, 6311, 6065, 4859, 5102, 4615, 6534, 6426, 1763, 2249, 2209, 5938, 1725, 6048, 3816, 6058, 958, 63, 4460, 5914, 2379, 735, 5319, 4593, 2328, 890, 35, 751, 1483, 1484, 1483, 2112, 303, 4753, 2301, 5507, 5588, 5261, 5744, 5501, 2341, 2001, 2252, 2344, 1860, 2031, 414, 4366, 4366, 6059, 5300, 4814, 5092, 5100, 1923, 3054, 4320, 4296, 2148, 4371, 5831, 5084, 5027, 4946, 4946, 2678, 575, 575, 521, 518, 638, 1367, 2804, 3402, 4299] + +# Initialize Token2wav with TensorRT enabled +token2wav = Token2wav('Step-Audio-2-mini/token2wav', enable_trt=True) + +# Generate audio +audio_bytes = token2wav(tokens, 'assets/default_male.wav') + +# Save the generated audio to a file +with open('output_offline.wav', 'wb') as f: + f.write(audio_bytes) +``` + +### Streaming Inference + +For streaming inference, you can process tokens in chunks. + +```python +from token2wav import Token2wav +from pathlib import Path +import wave + +tokens = [1493, 4299, 4218, 2049, 528, 2752, 4850, 4569, 4575, 6372, 2127, 4068, 2312, 4993, 4769, 2300, 226, 2175, 2160, 2152, 6311, 6065, 4859, 5102, 4615, 6534, 6426, 1763, 2249, 2209, 5938, 1725, 6048, 3816, 6058, 958, 63, 4460, 5914, 2379, 735, 5319, 4593, 2328, 890, 35, 751, 1483, 1484, 1483, 2112, 303, 4753, 2301, 5507, 5588, 5261, 5744, 5501, 2341, 2001, 2252, 2344, 1860, 2031, 414, 4366, 4366, 6059, 5300, 4814, 5092, 5100, 1923, 3054, 4320, 4296, 2148, 4371, 5831, 5084, 5027, 4946, 4946, 2678, 575, 575, 521, 518, 638, 1367, 2804, 3402, 4299] + +# Initialize Token2wav for streaming with TensorRT +token2wav = Token2wav('Step-Audio-2-mini/token2wav', enable_trt=True, streaming=True) + +# Process the first chunk of tokens +audio_first_chunk = token2wav.stream(tokens[:25 + token2wav.flow.pre_lookahead_len], prompt_wav='assets/default_male.wav') + +# Process the remaining tokens as the last chunk +audio_last_chunk = token2wav.stream(tokens[25 + token2wav.flow.pre_lookahead_len:], prompt_wav='assets/default_male.wav', last_chunk=True) + +# Save the streaming output to a PCM file +output_stream = Path('output-stream.pcm') +output_stream.unlink(missing_ok=True) +with open(output_stream, 'wb') as f: + f.write(audio_first_chunk) + f.write(audio_last_chunk) + +# Convert PCM to WAV +with open(output_stream, 'rb') as f: + pcm = f.read() +wav_path = output_stream.with_suffix('.wav') +with wave.open(str(wav_path), 'wb') as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(24000) + wf.writeframes(pcm) + +``` + +## Benchmark + +The following benchmark was conducted on an NVIDIA L20 GPU, generating 26 audio clips with a total length of 170 seconds. RTF (Real-Time Factor) is calculated as `Cost Time / Total Audio Length`. 
+
+| Method    | Note                                      | Cost Time     | RTF   |
+|-----------|-------------------------------------------|---------------|-------|
+| Offline   | batch=1, PyTorch                          | 4.32 seconds  | 0.025 |
+| Offline   | batch=1, TensorRT enabled                 | 2.09 seconds  | 0.012 |
+| Offline   | batch=2, PyTorch                          | 3.77 seconds  | 0.022 |
+| Offline   | batch=2, TensorRT enabled                 | 1.97 seconds  | 0.012 |
+| Streaming | batch=1, chunk_size = 1 second, PyTorch   | 20.3 seconds  | 0.119 |
+| Streaming | batch=1, chunk_size = 1 second, TensorRT  | 12.96 seconds | 0.076 |
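+
+As a quick sanity check, the RTF column can be reproduced directly from the formula above (illustrative snippet only, not part of the repository code):
+
+```python
+# Recompute RTF = cost_time / total_audio_length for two benchmark rows.
+TOTAL_AUDIO_SECONDS = 170.0  # 26 clips, 170 seconds of audio in total
+
+for name, cost_time in [
+    ("Offline, batch=1, TensorRT", 2.09),
+    ("Streaming, chunk=1 s, TensorRT", 12.96),
+]:
+    print(f"{name}: RTF = {cost_time / TOTAL_AUDIO_SECONDS:.3f}")  # ~0.012 and ~0.076
+```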