service/inference.py: 8 additions, 1 deletion
@@ -33,7 +33,14 @@

 @app.get("/health")
 async def health() -> Response:
-    return Response(status_code=200)
+    global model
+    try:
+        await model.engine.check_health()
+    except Exception as e:
+        logger.critical(f"Error in model health check: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+    return Response(content="OK", status_code=200)
 
 @app.post("/tts", response_model=TTSResponse)
 async def text_to_speech(requests: TTSRequest):
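The health endpoint now surfaces vLLM engine health instead of returning an unconditional 200. A minimal client-side probe, sketched under the assumption that the service listens on localhost:8000:

import requests

# A 200 with body "OK" means model.engine.check_health() passed;
# a 500 now indicates the vLLM engine itself is unhealthy.
resp = requests.get('http://localhost:8000/health', timeout=5)
print(resp.status_code, resp.text)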
service/models.py: 2 additions, 2 deletions
@@ -23,8 +23,8 @@ class TTSRequest(BaseModel):
     speaker: Speakers
 
 class TTSMetrics(BaseModel):
-    time_to_first_token: List[float]
-    time_to_last_token: List[float]
+    time_to_first_token: float
+    time_to_last_token: float
     time_to_decode_audio: float
     input_tokens: List[int]
     decoding_tokens: List[int]
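Because generation is now sequential per chunk, the first/last token timings collapse to scalars (time to the first chunk's first token; generation time summed across chunks), while the token counts stay per-chunk lists. A standalone sketch of the updated model; generate_end_to_end_time mirrors the constructor call in service/tts.py below, and the sample values are hypothetical:

from typing import List
from pydantic import BaseModel

class TTSMetrics(BaseModel):
    time_to_first_token: float      # seconds until the first chunk's first token
    time_to_last_token: float       # generation time summed across all chunks
    time_to_decode_audio: float
    input_tokens: List[int]         # prompt length per chunk
    decoding_tokens: List[int]      # generated tokens per chunk
    generate_end_to_end_time: float

# Hypothetical values for a two-chunk request:
metrics = TTSMetrics(
    time_to_first_token=0.08,
    time_to_last_token=2.41,
    time_to_decode_audio=0.35,
    input_tokens=[96, 210],
    decoding_tokens=[512, 480],
    generate_end_to_end_time=2.93,
)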
service/tts.py: 87 additions, 57 deletions
@@ -2,13 +2,14 @@
 sys.path.append('omni/')
 
 import time
+import uuid
 import torch
 import numpy as np
 from transformers import MimiModel, AutoTokenizer
 from typing import List, Dict, Any, Tuple, Optional
 from functools import partial
 
-from vllm import SamplingParams, AsyncEngineArgs
+from vllm import SamplingParams, AsyncEngineArgs, RequestOutput
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.inputs import TokensPrompt

Expand Down Expand Up @@ -66,9 +67,22 @@ def __init__(self, model_path, device):
logits_processors=logits_processors
)

def prepare_tokens(self, incoming_text, speaker) -> List[int]:
incoming_tokens = self.text_tokenizer.encode(incoming_text)
def prepare_tokens(self, incoming_text, speaker, prompt_tokens: dict = None) -> Tuple[List[int], List[int]]:

if prompt_tokens:
incoming_tokens = self.text_tokenizer.encode(' ' + incoming_text)
input_tokens = np.hstack([
self.text_modality_token,
prompt_tokens[TEXT],
incoming_tokens,
self.convert_token,
self.acoustic_modality_token,
self.text_tokenizer.encode(speaker),
prompt_tokens[MIMI]
])
return incoming_tokens, input_tokens.tolist()

incoming_tokens = self.text_tokenizer.encode(incoming_text)
input_tokens = np.hstack([
self.text_modality_token,
incoming_tokens,
@@ -77,78 +91,84 @@ def prepare_tokens(self, incoming_text, speaker) -> List[int]:
             self.text_tokenizer.encode(speaker)
         ])
 
-        return input_tokens.tolist()
+        return incoming_tokens, input_tokens.tolist()

+    def get_generation_output(self, output: RequestOutput) -> Tuple[np.ndarray, np.ndarray, Dict[str, Any]]:
+        output_tokens = np.array(output.outputs[0].token_ids)
+        end = np.where(output_tokens == self.stop_token[0])[0]
+
+        if len(end) > 0:
+            end = end[0]
+        else:
+            end = len(output_tokens)
+
+        output_tokens = output_tokens[:end]
+
+        output_token_ids = output_tokens.copy()
+        output_tokens = output_tokens - cfg.OFFSET[MIMI]
+        output_tokens = deserialize_tokens(output_tokens)
+
+        assert np.all(output_tokens >= 0), f'Negative token index generated'
+
+        metrics = {
+            'time_to_first_token': output.metrics.first_token_time - output.metrics.first_scheduled_time,
+            'time_to_last_token': output.metrics.finished_time - output.metrics.first_scheduled_time,
+            'input_tokens': len(output.prompt_token_ids),
+            'decoding_tokens': len(output.outputs[0].token_ids)
+        }
+
+        return output_token_ids, output_tokens, metrics

     async def generate_async(self,
             text: str,
             speaker: Optional[str] = '[spkr_hifi_tts_9017]',
-            request_id: Optional[str] = None
+            request_id: Optional[str] = None,
+            max_context_words: Optional[int] = 10
         ) -> Dict[str, Any]:

         start_time = time.time()
-        batch_text = sanitize_text(text)
-        input_tokens = [self.prepare_tokens(text, speaker) for text in batch_text]
+        batch_text = sanitize_text(text, max_context_words)
+        prompt_tokens = {}
+        overall_metrics = []
+        mimi_tokens = []
 
         logger.info(f'Texts after preprocessing: {batch_text}, {speaker}', extra={'request_id': request_id})
-        logger.info(f'Input tokens shape: {len(input_tokens)} and batch size: {len(batch_text)}', extra={'request_id': request_id})
-
-        prompt = TokensPrompt(prompt_token_ids=input_tokens[0])
-
-        results_generator = self.engine.generate(
-            prompt=prompt,
-            sampling_params=self.sampling_params,
-            request_id=request_id
-        )
+        for text in batch_text:
+            text_token_ids, input_tokens = self.prepare_tokens(text, speaker, prompt_tokens)
+            logger.info(f'Input tokens shape: {len(input_tokens)}', extra={'request_id': request_id})
+            prompt = TokensPrompt(prompt_token_ids=input_tokens)
 
-        preds = []
+            results_generator = self.engine.generate(
+                prompt=prompt,
+                sampling_params=self.sampling_params,
+                request_id=str(uuid.uuid4())
+            )
 
-        async for request_output in results_generator:
-            if request_output.finished:
-                preds.append(request_output)
+            async for request_output in results_generator:
+                if request_output.finished:
+                    output = request_output
 
-        mimi_tokens = []
+            output_token_ids, output_tokens, generation_metrics = self.get_generation_output(output=output)
+            logger.info(f'Output tokens shape: {output_tokens.shape}', extra={'request_id': request_id})
 
-        metrics = {
-            'time_to_first_token': [],
-            'time_to_last_token': [],
-            'input_tokens': [],
-            'decoding_tokens': []
-        }
+            overall_metrics.append(generation_metrics)
+            mimi_tokens.append(output_tokens)
 
-        for idx, request_output in enumerate(preds):
-            o = np.array(request_output.outputs[0].token_ids)
-            end = np.where(o == self.stop_token[0])[0]
-            if len(end) > 0:
-                end = end[0]
-            else:
-                end = len(o)
-            o = o[:end]
-            o = o - cfg.OFFSET[MIMI]
-            o = deserialize_tokens(o)
-            assert np.all(o >= 0), f'Negative token index generated for batch {idx}'
-
-            metrics['time_to_first_token'].append(
-                request_output.metrics.first_token_time - request_output.metrics.first_scheduled_time
-            )
-            metrics['time_to_last_token'].append(
-                request_output.metrics.finished_time - request_output.metrics.first_scheduled_time
-            )
-            metrics['input_tokens'].append(len(request_output.prompt_token_ids))
-            metrics['decoding_tokens'].append(len(request_output.outputs[0].token_ids))
-
-            mimi_tokens.append(o)
+            prompt_tokens = {
+                TEXT: text_token_ids,
+                MIMI: output_token_ids
+            }
 
         mimi_tokens = np.concatenate(mimi_tokens, axis=1)
         logger.info(f'Mimi tokens shape: {mimi_tokens.shape}')
 
         audio, decode_time = self.decode_audio(mimi_tokens)
 
         metrics = TTSMetrics(
-            time_to_first_token=metrics['time_to_first_token'],
-            time_to_last_token=metrics['time_to_last_token'],
+            time_to_first_token=overall_metrics[0]['time_to_first_token'],
+            time_to_last_token=sum([x['time_to_last_token'] for x in overall_metrics]),
             time_to_decode_audio=decode_time,
-            input_tokens=metrics['input_tokens'],
-            decoding_tokens=metrics['decoding_tokens'],
+            input_tokens=[x['input_tokens'] for x in overall_metrics],
+            decoding_tokens=[x['decoding_tokens'] for x in overall_metrics],
             generate_end_to_end_time=time.time()-start_time
         )

Expand All @@ -167,8 +187,18 @@ def decode_audio(self, audio_tokens) -> Tuple[np.ndarray, float]:

         return audio, end_time - start_time
 
-if __name__ == '__main__':
-    tts = TTS('cmeraki/mimi_tts_hf', 'cuda:0')
-    result = tts.generate('Long ago, in a distant kingdom between emerald hills and sapphire lakes, magic flowed freely. A wise king ruled, ensuring peace and prosperity.')
+
+async def main():
+    model = TTS('cmeraki/mimi_tts_hf_stage', 'cuda:0')
+    result = await model.generate_async(
+        'Long ago, in a distant kingdom between emerald hills and sapphire lakes, magic flowed freely. This is a second sentence.',
+        speaker='[spkr_hifi_tts_9017]',
+        request_id=str(uuid.uuid4()),
+        max_context_words=20
+    )
+
+    print(result['metrics'])
+
+if __name__ == '__main__':
+    import asyncio
+    asyncio.run(main())
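The heart of the change: chunks are no longer batched through the engine in a single call. They run sequentially, and each finished chunk's text and Mimi tokens are fed back through prompt_tokens to condition the next chunk, keeping speech continuous across chunk boundaries. An engine-free sketch of that loop shape, where generate_chunk is a hypothetical stand-in for prepare_tokens + engine.generate + get_generation_output:

from typing import Dict, List, Tuple

def generate_chunk(text: str, prompt: Dict[str, List[int]]) -> Tuple[List[int], List[int]]:
    # Stand-in: returns fake (text_token_ids, mimi_token_ids) for illustration.
    text_ids = [len(w) for w in text.split()]
    start = len(prompt.get('mimi', []))
    mimi_ids = list(range(start, start + 4))
    return text_ids, mimi_ids

chunks = ['long ago in a distant kingdom', 'magic flowed freely']
prompt: Dict[str, List[int]] = {}   # empty for the first chunk
all_mimi: List[List[int]] = []

for chunk in chunks:
    text_ids, mimi_ids = generate_chunk(chunk, prompt)
    all_mimi.append(mimi_ids)
    # The chunk just generated becomes the next chunk's prompt prefix,
    # so generation continues in context instead of starting cold.
    prompt = {'text': text_ids, 'mimi': mimi_ids}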
service/utils.py: 21 additions, 9 deletions
@@ -19,31 +19,47 @@ def deserialize_tokens(tokens):
     return acoustic_tokens
 
 
-def sanitize_text(text: str) -> list[str]:
+def sanitize_text(text: str, max_context_words: int) -> list[str]:
     """
     Sanitize text to be used for TTS
 
     Args:
         text (str): Text to sanitize
+        max_context_words (int): Maximum number of words in a sentence
 
     Returns:
         list[str]: List of sentences, split by punctuation (., !, ?)
     """
     text = text.lower()
 
     # Remove more than one newlines and tabs
     text = re.sub(r'\n+', ' ', text)
     text = re.sub(r'[ \t]+', ' ', text)
 
     # Remove non-alphanumeric characters except for , . ? !
     # allowed_pattern = r'[^a-z0-9\s,\.?\n\!]'
     # text = re.sub(allowed_pattern, '', text)
 
     # Remove more than one punctuation mark
     text = re.sub(r'([,\.?])+', r'\1', text)
 
-    # pattern = r'([.!?])'
-    # segments = re.split(pattern, text)
+    # Split sentences by max context length
+    total_words = text.split(' ')
+    sentences = []
+    current_sentence = ''
+
+    for i in range(0, len(total_words), max_context_words):
+        current_sentence = ' '.join(total_words[i:i+max_context_words])
+        sentences.append(current_sentence.strip())
 
-    sentences = [text.strip()]
     return sentences
 
-    # current_sentence = ''
+    # Split sentences by punctuation (., !, ?)
+    pattern = r'([.!?])'
+    segments = re.split(pattern, text)
+
+    sentences = []
+    current_sentence = ''
+
+    for segment in segments:
+        current_sentence += segment
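With the punctuation-based splitting bypassed by the early return, chunk boundaries are now purely word-count based. A usage example, assuming the patched function above is importable from service.utils (expected output shown as a comment):

from service.utils import sanitize_text

chunks = sanitize_text('The quick brown fox jumps over the lazy dog near the river bank', max_context_words=5)
# ['the quick brown fox jumps', 'over the lazy dog near', 'the river bank']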
@@ -77,15 +93,11 @@ def alternative_logits_processor(
     Returns:
         torch.Tensor - Processed logits. Shape (vocab_size)
     """
-    logger.debug(f'Stop token: {stop_token}')
-
     codebook_indices = len(past_token_ids) % num_codebooks
 
     start_idx = offset + codebook_indices * codebook_size
     end_idx = offset + (codebook_indices + 1) * codebook_size
 
-    logger.debug(f'Past_token_ids: {len(past_token_ids)}, codebook indices: {codebook_indices}, start idx: {start_idx}, end idx: {end_idx}')
-
     mask = torch.zeros_like(logits)
     mask[start_idx:end_idx] = 1
     mask[stop_token] = 1
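For context, the masking logic above (unchanged by this PR apart from the dropped debug logs) cycles a hard constraint across codebooks: at decode step t, only tokens from codebook t mod num_codebooks, plus the stop token, remain reachable. A toy-sized sketch of that arithmetic; the sizes here are made up, the real values come from cfg:

import torch

num_codebooks, codebook_size, offset = 4, 8, 0   # toy values, not the real config
stop_token = 33                                   # assumed final vocab slot
past_token_ids = [5, 12, 20]                      # three tokens generated so far

codebook_indices = len(past_token_ids) % num_codebooks      # -> 3
start_idx = offset + codebook_indices * codebook_size       # -> 24
end_idx = offset + (codebook_indices + 1) * codebook_size   # -> 32

logits = torch.randn(34)
mask = torch.zeros_like(logits)
mask[start_idx:end_idx] = 1   # only codebook 3's slots stay live
mask[stop_token] = 1          # stop token is always allowed
logits = logits.masked_fill(mask == 0, float('-inf'))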