38 changes: 18 additions & 20 deletions cosyvoice2/flow/decoder_dit.py
@@ -403,9 +403,6 @@ def __init__(
self.inference_buffers_chunk = {}
self.max_size_chunk = {}

self.register_buffer('att_cache_buffer', torch.zeros((16, 2, 8, 1000, 128)), persistent=False)
self.register_buffer('cnn_cache_buffer', torch.zeros((16, 2, 1024, 2)), persistent=False)

def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
@@ -432,7 +429,7 @@ def _basic_init(module):

def _init_cuda_graph_chunk(self):
# get dtype, device from registered buffer
dtype, device = self.cnn_cache_buffer.dtype, self.cnn_cache_buffer.device
dtype, device = self.in_proj.weight.dtype, self.in_proj.weight.device
# init cuda graph for streaming forward
with torch.no_grad():
for chunk_size in [30, 48, 96]:
@@ -491,6 +488,7 @@ def forward(self, x, mask, mu, t, spks=None, cond=None):

# time
t = self.t_embedder(t).unsqueeze(1) # (b, 1, c)

x = pack([x, mu], "b * t")[0]
if spks is not None:
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
@@ -541,16 +539,14 @@ def forward_chunk(self,

# create fake cache
if cnn_cache is None:
cnn_cache = [None] * len(self.blocks)
cnn_cache = torch.zeros(len(self.blocks), x.shape[0], 1024, 2, dtype=x.dtype, device=x.device)
if att_cache is None:
att_cache = [None] * len(self.blocks)
if att_cache[0] is not None:
last_att_len = att_cache.shape[3]
else:
last_att_len = 0
att_cache = torch.zeros(len(self.blocks), x.shape[0], self.blocks[0].attn.num_heads, 0, self.blocks[0].attn.head_dim * 2, dtype=x.dtype, device=x.device)

last_att_len = att_cache.shape[3]
chunk_size = x.shape[2]
mask = torch.ones(x.shape[0], chunk_size, last_att_len+chunk_size, dtype=torch.bool, device=x.device)
if self.use_cuda_graph and att_cache[0] is not None and chunk_size in self.graph_chunk and last_att_len <= self.max_size_chunk[chunk_size]:
if self.use_cuda_graph and att_cache is not None and chunk_size in self.graph_chunk and last_att_len <= self.max_size_chunk[chunk_size]:
padded_mask = torch.zeros((2, chunk_size, self.max_size_chunk[chunk_size]+chunk_size), dtype=mask.dtype, device=mask.device)
padded_mask[:, :, :mask.shape[-1]] = mask
padded_att_cache = torch.zeros((16, 2, 8, self.max_size_chunk[chunk_size], 128), dtype=att_cache.dtype, device=att_cache.device)
@@ -566,20 +562,22 @@ def forward_chunk(self,
new_att_cache = self.inference_buffers_chunk[chunk_size]['static_outputs'][2][:, :, :, :chunk_size+last_att_len, :]
else:
mask = None
x = self.blocks_forward_chunk(x, t, mask, cnn_cache, att_cache, self.cnn_cache_buffer, self.att_cache_buffer)
new_cnn_cache = self.cnn_cache_buffer
new_att_cache = self.att_cache_buffer[:, :, :, :last_att_len+chunk_size, :]
x, new_cnn_cache, new_att_cache = self.blocks_forward_chunk(x, t, mask, cnn_cache, att_cache)

return x, new_cnn_cache, new_att_cache

def blocks_forward_chunk(self, x, t, mask, cnn_cache=None, att_cache=None, cnn_cache_buffer=None, att_cache_buffer=None):
def blocks_forward_chunk(self, x, t, mask, cnn_cache, att_cache):
x = x.transpose(1, 2)
x = self.in_proj(x)

new_cnn_caches = []
new_att_caches = []

for b_idx, block in enumerate(self.blocks):
x, this_new_cnn_cache, this_new_att_cache \
= block.forward_chunk(x, t, cnn_cache[b_idx], att_cache[b_idx], mask)
cnn_cache_buffer[b_idx] = this_new_cnn_cache
att_cache_buffer[b_idx][:, :, :this_new_att_cache.shape[2], :] = this_new_att_cache
x, this_new_cnn_cache, this_new_att_cache = block.forward_chunk(x, t, cnn_cache[b_idx], att_cache[b_idx], mask)
new_cnn_caches.append(this_new_cnn_cache)
new_att_caches.append(this_new_att_cache)

x = self.final_layer(x, t)
x = x.transpose(1, 2)
return x
return x, torch.stack(new_cnn_caches), torch.stack(new_att_caches)
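
A minimal sketch of driving the refactored chunk API from the caller side, now that `forward_chunk` returns the stacked per-block caches instead of writing them into persistent registered buffers. The keyword arguments mirror the `estimator.forward_chunk(...)` call in `flow_matching.py`; `dit`, the chunk iterables, `t`, and `spks` are placeholders, and the classifier-free-guidance batch doubling done by the real solver is omitted:

```python
import torch

@torch.no_grad()
def stream_dit_chunks(dit, noise_chunks, mu_chunks, cond_chunks, t, spks):
    """Feed consecutive chunks, threading the returned caches into the next call."""
    cnn_cache, att_cache = None, None   # first call builds zero-length caches internally
    outputs = []
    for x, mu, cond in zip(noise_chunks, mu_chunks, cond_chunks):
        # x / mu / cond: (batch, 80, chunk_len); caches come back stacked over blocks:
        #   cnn_cache: (num_blocks, batch, 1024, 2)
        #   att_cache: (num_blocks, batch, num_heads, kv_len, 2 * head_dim)
        x, cnn_cache, att_cache = dit.forward_chunk(
            x=x, mu=mu, t=t, spks=spks, cond=cond,
            cnn_cache=cnn_cache, att_cache=att_cache)
        outputs.append(x)
    return torch.cat(outputs, dim=-1)
```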
26 changes: 7 additions & 19 deletions cosyvoice2/flow/flow.py
@@ -65,39 +65,29 @@ def scatter_cuda_graph(self, enable_cuda_graph: bool):
def inference(self,
token,
token_len,
prompt_token,
prompt_token_len,
prompt_feat,
prompt_feat_len,
embedding,
n_timesteps: int = 10,
):
assert token.shape[0] == 1

# xvec projection
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)

# concat text and prompt_text
token_len = prompt_token_len + token_len
token = torch.concat([prompt_token, token], dim=1)

mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
token = self.input_embedding(torch.clamp(token, min=0)) * mask

# token encode
h, _ = self.encoder.forward(token, token_len)
h, h_lengths = self.encoder.forward(token, token_len)
h = self.encoder_proj(h)

# condition
mel_len1 = prompt_feat.shape[1]
mel_len2 = h.shape[1] - prompt_feat.shape[1]

conds = torch.zeros_like(h)
conds[:, :mel_len1] = prompt_feat
for i, j in enumerate(prompt_feat_len):
conds[i, :j] = prompt_feat[i, :j]
conds = conds.transpose(1, 2).contiguous()

mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
h_lengths = h_lengths.sum(dim=-1).squeeze(dim=1)
mask = (~make_pad_mask(h_lengths, max_len=h.shape[1])).to(h)

feat = self.decoder.forward(
mu=h.transpose(1, 2).contiguous(),
Expand All @@ -107,9 +97,7 @@ def inference(self,
n_timesteps=n_timesteps,
)

feat = feat[:, :, mel_len1:]
assert feat.shape[2] == mel_len2
return feat
return feat.float(), h_lengths

@torch.inference_mode()
def setup_cache(self,
@@ -149,7 +137,7 @@ def setup_cache(self,
feat, estimator_cnn_cache, estimator_att_cache = self.decoder.forward_chunk(
mu = h.transpose(1, 2).contiguous(),
spks = spk,
cond = mel.transpose(1, 2).contiguous(),
cond = mel.transpose(1, 2).contiguous().to(h.dtype),
n_timesteps = n_timesteps,
temperature = 1.0,
cnn_cache = None,
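
With `inference` now returning the padded mel batch together with `h_lengths` (rather than the internally prompt-trimmed mel), downstream code trims each sample to its own valid length. A self-contained sketch of that trimming step, using synthetic tensors in place of the real `(feat, h_lengths)` pair; this is not code from the PR:

```python
import torch

def trim_to_lengths(feat: torch.Tensor, lengths: torch.Tensor):
    """Split a padded (batch, 80, T_max) mel batch into per-sample (80, T_i) tensors."""
    return [feat[i, :, :int(n)] for i, n in enumerate(lengths.tolist())]

# synthetic stand-ins for the (feat, h_lengths) returned by inference()
feat = torch.zeros(2, 80, 500)          # float32 per the new .float() cast
h_lengths = torch.tensor([500, 320])
mels = trim_to_lengths(feat, h_lengths)
assert [m.shape[-1] for m in mels] == [500, 320]
```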
133 changes: 114 additions & 19 deletions cosyvoice2/flow/flow_matching.py
@@ -20,6 +20,13 @@
from cosyvoice2.utils.mask import make_pad_mask


def get_data_ptr(tensor: torch.Tensor, dummy_buffer: torch.Tensor):
if tensor.numel() == 0:
return dummy_buffer.data_ptr()
else:
return tensor.contiguous().data_ptr()


"""
Inference wrapper
"""
@@ -32,13 +39,42 @@ def __init__(self, estimator: DiT, inference_cfg_rate:float=0.7):
# a maximum of 600s
self.register_buffer('rand_noise', torch.randn([1, self.out_channels, 50 * 600]), persistent=False)

self.register_buffer('cnn_cache_buffer', torch.zeros(16, 16, 2, 1024, 2), persistent=False)
self.register_buffer('att_cache_buffer', torch.zeros(16, 16, 2, 8, 1000, 128), persistent=False)
self.register_buffer('dummy_buffer', torch.zeros(1), persistent=False)

def scatter_cuda_graph(self, enable_cuda_graph: bool):
if enable_cuda_graph:
self.estimator._init_cuda_graph_all()

def forward_estimator(self, x, mask, mu, t, spks, cond):
if isinstance(self.estimator, torch.nn.Module):
return self.estimator(x, mask, mu, t, spks, cond)
else:
[estimator, stream], trt_engine = self.estimator.acquire_estimator()
# NOTE need to synchronize when switching stream
torch.cuda.current_stream().synchronize()
batch_size = x.size(0)
with stream:
estimator.set_input_shape('x', (batch_size, 80, x.size(2)))
estimator.set_input_shape('mask', (batch_size, 1, x.size(2)))
estimator.set_input_shape('mu', (batch_size, 80, x.size(2)))
estimator.set_input_shape('t', (batch_size,))
estimator.set_input_shape('spks', (batch_size, 80))
estimator.set_input_shape('cond', (batch_size, 80, x.size(2)))
data_ptrs = [x.contiguous().data_ptr(),
mask.contiguous().data_ptr(),
mu.contiguous().data_ptr(),
t.contiguous().data_ptr(),
spks.contiguous().data_ptr(),
cond.contiguous().data_ptr(),
x.data_ptr()]
for i, j in enumerate(data_ptrs):
estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
# run trt engine
assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
torch.cuda.current_stream().synchronize()
self.estimator.release_estimator(estimator, stream)
return x

def solve_euler(self, x, t_span, mu, mask, spks, cond):
"""
Fixed euler solver for ODEs.
@@ -55,7 +91,8 @@ def solve_euler(self, x, t_span, mu, mask, spks, cond):
cond: Not used but kept for future purposes
"""
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
t = t.unsqueeze(dim=0)
t_in = torch.zeros([x.shape[0] * 2], device=x.device, dtype=x.dtype)

assert self.inference_cfg_rate > 0, 'inference_cfg_rate better > 0'

# constant during denoising
@@ -65,18 +102,16 @@ def solve_euler(self, x, t_span, mu, mask, spks, cond):
cond_in = torch.cat([cond, torch.zeros_like(cond)], dim=0)

for step in range(1, len(t_span)):

x_in = torch.cat([x, x], dim=0)
t_in = torch.cat([t, t], dim=0)
t_in.fill_(t)

dphi_dt = self.estimator.forward(
x_in,
mask_in,
mu_in,
t_in,
dphi_dt = self.forward_estimator(
x_in, mask_in,
mu_in, t_in,
spks_in,
cond_in,
)

dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
x = x + dt * dphi_dt
@@ -88,12 +123,63 @@ def solve_euler(self, x, t_span, mu, mask, spks, cond):

@torch.inference_mode()
def forward(self, mu, mask, spks, cond, n_timesteps=10, temperature=1.0):
z = self.rand_noise[:, :, :mu.size(2)] * temperature
z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
# cosine scheduling
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
return self.solve_euler(z, t_span, mu, mask, spks, cond)


def forward_estimator_chunk(self, x, mu, t, spks, cond, cnn_cache, att_cache):
if isinstance(self.estimator, torch.nn.Module):
dphi_dt, this_new_cnn_cache, this_new_att_cache = self.estimator.forward_chunk(
x = x,
mu = mu,
t = t,
spks = spks,
cond = cond,
cnn_cache = cnn_cache,
att_cache = att_cache,
)
return dphi_dt, this_new_cnn_cache, this_new_att_cache
else:
[estimator, stream], trt_engine = self.estimator.acquire_estimator()
# NOTE need to synchronize when switching stream
torch.cuda.current_stream().synchronize()
batch_size = x.size(0)
with stream:
estimator.set_input_shape('x', (batch_size, 80, x.size(2)))
estimator.set_input_shape('mu', (batch_size, 80, x.size(2)))
estimator.set_input_shape('t', (batch_size,))
estimator.set_input_shape('spks', (batch_size, 80))
estimator.set_input_shape('cond', (batch_size, 80, x.size(2)))
estimator.set_input_shape('cnn_cache', cnn_cache.shape)
estimator.set_input_shape('att_cache', att_cache.shape)
new_cnn_cache = torch.empty_like(cnn_cache)
new_att_cache_shape = list(att_cache.shape)
new_att_cache_shape[3] += x.size(2)
new_att_cache = torch.empty(new_att_cache_shape, device=att_cache.device, dtype=x.dtype)
data_ptrs = [x.contiguous().data_ptr(),
mu.contiguous().data_ptr(),
t.contiguous().data_ptr(),
spks.contiguous().data_ptr(),
cond.contiguous().data_ptr(),
cnn_cache.contiguous().data_ptr(),
get_data_ptr(att_cache, self.dummy_buffer),
x.data_ptr(),
new_cnn_cache.data_ptr(),
get_data_ptr(new_att_cache, self.dummy_buffer)]

for i, j in enumerate(data_ptrs):
estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
# run trt engine
assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
torch.cuda.current_stream().synchronize()
self.estimator.release_estimator(estimator, stream)

return x, new_cnn_cache, new_att_cache


def solve_euler_chunk(self,
x:torch.Tensor,
t_span:torch.Tensor,
@@ -122,14 +208,18 @@ def solve_euler_chunk(self,
assert self.inference_cfg_rate > 0, 'cfg rate should be > 0'

t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
t = t.unsqueeze(dim=0) # (b,)
t_in = torch.zeros([x.shape[0] * 2], device=x.device, dtype=x.dtype)

# setup initial cache
if cnn_cache is None:
cnn_cache = [None for _ in range(len(t_span)-1)]
cnn_cache = torch.zeros((len(t_span)-1, 16, x.shape[0] * 2, 1024, 2), device=x.device, dtype=x.dtype)
if att_cache is None:
att_cache = [None for _ in range(len(t_span)-1)]
att_cache = torch.empty((len(t_span)-1, 16, x.shape[0] * 2, 8, 0, 128), device=x.device, dtype=x.dtype)
# next chunk's cache at each timestep
new_cnn_caches = []
new_att_caches = []

if att_cache[0] is not None:
last_att_len = att_cache.shape[4]
@@ -146,27 +236,31 @@ def solve_euler_chunk(self,
this_att_cache = att_cache[step-1]
this_cnn_cache = cnn_cache[step-1]

dphi_dt, this_new_cnn_cache, this_new_att_cache = self.estimator.forward_chunk(
x = x.repeat(2, 1, 1),
x_in = x.repeat(2, 1, 1)
t_in.fill_(t)

dphi_dt, this_new_cnn_cache, this_new_att_cache = self.forward_estimator_chunk(
x = x_in,
mu = mu_in,
t = t.repeat(2),
t = t_in,
spks = spks_in,
cond = cond_in,
cnn_cache = this_cnn_cache,
att_cache = this_att_cache,
)

dphi_dt, cfg_dphi_dt = dphi_dt.chunk(2, dim=0)
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
x = x + dt * dphi_dt
t = t + dt
if step < len(t_span) - 1:
dt = t_span[step + 1] - t

self.cnn_cache_buffer[step-1] = this_new_cnn_cache
self.att_cache_buffer[step-1][:, :, :, :x.shape[2]+last_att_len, :] = this_new_att_cache
new_cnn_caches.append(this_new_cnn_cache)
new_att_caches.append(this_new_att_cache)

cnn_cache = self.cnn_cache_buffer
att_cache = self.att_cache_buffer[:, :, :, :, :x.shape[2]+last_att_len, :]
cnn_cache = torch.stack(new_cnn_caches)
att_cache = torch.stack(new_att_caches)
return x, cnn_cache, att_cache

@torch.inference_mode()
@@ -190,6 +284,7 @@ def forward_chunk(self,
# get offset from att_cache
offset = att_cache.shape[4] if att_cache is not None else 0
z = self.rand_noise[:, :, offset:offset+mu.size(2)] * temperature
z = z.to(mu.dtype)
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
# cosine scheduling
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
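
For reference, the cosine schedule applied to `t_span` in both `forward` and `forward_chunk` is t_k = 1 - cos(pi * k / (2N)), which makes the ODE solver take finer steps near t = 0. A small self-contained check of the resulting grid for the default `n_timesteps=10` (values rounded):

```python
import torch

n_timesteps = 10
t_span = torch.linspace(0, 1, n_timesteps + 1)
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
print([round(t, 4) for t in t_span.tolist()])
# ~ [0.0, 0.0123, 0.0489, 0.109, 0.191, 0.2929, 0.4122, 0.546, 0.691, 0.8436, 1.0]
# step sizes grow from ~0.012 at the start to ~0.156 at the end of the solve
```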
1 change: 0 additions & 1 deletion examples-vllm-stream.py
@@ -43,7 +43,6 @@ def stream_client(model, history, tools, token2wav=None, output_stream=None, pro
model = StepAudio2(api_url, model_name)
token2wav = Token2wav('Step-Audio-2-mini/token2wav')
tokens = [1493, 4299, 4218, 2049, 528, 2752, 4850, 4569, 4575, 6372, 2127, 4068, 2312, 4993, 4769, 2300, 226, 2175, 2160, 2152, 6311, 6065, 4859, 5102, 4615, 6534, 6426, 1763, 2249, 2209, 5938, 1725, 6048, 3816, 6058, 958, 63, 4460, 5914, 2379, 735, 5319, 4593, 2328, 890, 35, 751, 1483, 1484, 1483, 2112, 303, 4753, 2301, 5507, 5588, 5261, 5744, 5501, 2341, 2001, 2252, 2344, 1860, 2031, 414, 4366, 4366, 6059, 5300, 4814, 5092, 5100, 1923, 3054, 4320, 4296, 2148, 4371, 5831, 5084, 5027, 4946, 4946, 2678, 575, 575, 521, 518, 638, 1367, 2804, 3402, 4299]
token2wav.set_stream_cache(prompt_wav)
token2wav.stream(tokens[:CHUNK_SIZE + token2wav.flow.pre_lookahead_len], prompt_wav=prompt_wav) # Warm up

output_stream = Path('output-stream.pcm')