From fa42e82dc76d038bfa1351c39cae765bc13343e5 Mon Sep 17 00:00:00 2001 From: romit <11757603+romitjain@users.noreply.github.com> Date: Mon, 4 Nov 2024 15:20:11 +0530 Subject: [PATCH 1/6] Added long infer handling --- service/tts.py | 105 ++++++++++++++++++++++++++++++----------------- service/utils.py | 14 ++----- 2 files changed, 71 insertions(+), 48 deletions(-) diff --git a/service/tts.py b/service/tts.py index 16e7438..a3c09bb 100644 --- a/service/tts.py +++ b/service/tts.py @@ -8,7 +8,7 @@ from typing import List, Dict, Any, Tuple, Optional from functools import partial -from vllm import SamplingParams, AsyncEngineArgs +from vllm import SamplingParams, AsyncEngineArgs, RequestOutput from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.inputs import TokensPrompt @@ -66,9 +66,21 @@ def __init__(self, model_path, device): logits_processors=logits_processors ) - def prepare_tokens(self, incoming_text, speaker) -> List[int]: + def prepare_tokens(self, incoming_text, speaker, prompt_tokens: dict = None) -> List[int]: incoming_tokens = self.text_tokenizer.encode(incoming_text) + if prompt_tokens: + input_tokens = np.hstack([ + self.text_modality_token, + prompt_tokens[TEXT], + incoming_tokens, + self.convert_token, + self.acoustic_modality_token, + self.text_tokenizer.encode(speaker), + prompt_tokens[MIMI], + ]) + return input_tokens.tolist() + input_tokens = np.hstack([ self.text_modality_token, incoming_tokens, @@ -79,34 +91,8 @@ def prepare_tokens(self, incoming_text, speaker) -> List[int]: return input_tokens.tolist() - async def generate_async(self, - text: str, - speaker: Optional[str] = '[spkr_hifi_tts_9017]', - request_id: Optional[str] = None - ) -> Dict[str, Any]: - - start_time = time.time() - batch_text = sanitize_text(text) - input_tokens = [self.prepare_tokens(text, speaker) for text in batch_text] - - logger.info(f'Texts after preprocessing: {batch_text}, {speaker}', extra={'request_id': request_id}) - logger.info(f'Input tokens shape: {len(input_tokens)} and batch size: {len(batch_text)}', extra={'request_id': request_id}) - - prompt = TokensPrompt(prompt_token_ids=input_tokens[0]) - - results_generator = self.engine.generate( - prompt=prompt, - sampling_params=self.sampling_params, - request_id=request_id - ) - - preds = [] - - async for request_output in results_generator: - if request_output.finished: - preds.append(request_output) - - mimi_tokens = [] + def get_generation_output(self, output: List[RequestOutput]) -> Tuple[np.ndarray, Dict[str, Any]]: + output_tokens = [] metrics = { 'time_to_first_token': [], @@ -115,16 +101,19 @@ async def generate_async(self, 'decoding_tokens': [] } - for idx, request_output in enumerate(preds): + for idx, request_output in enumerate(output): o = np.array(request_output.outputs[0].token_ids) end = np.where(o == self.stop_token[0])[0] + if len(end) > 0: end = end[0] else: end = len(o) + o = o[:end] o = o - cfg.OFFSET[MIMI] o = deserialize_tokens(o) + assert np.all(o >= 0), f'Negative token index generated for batch {idx}' metrics['time_to_first_token'].append( @@ -136,19 +125,59 @@ async def generate_async(self, metrics['input_tokens'].append(len(request_output.prompt_token_ids)) metrics['decoding_tokens'].append(len(request_output.outputs[0].token_ids)) - mimi_tokens.append(o) + output_tokens.append(o) + + output_tokens = np.concatenate(output_tokens, axis=1) + return output_tokens, metrics + + async def generate_async(self, + text: str, + speaker: Optional[str] = '[spkr_hifi_tts_9017]', + request_id: Optional[str] = None + 
) -> Dict[str, Any]: + + start_time = time.time() + batch_text = sanitize_text(text) + prompt_tokens = {} + overall_metrics = [] + mimi_tokens = [] + + logger.info(f'Texts after preprocessing: {batch_text}, {speaker}', extra={'request_id': request_id}) + + for text in batch_text: + input_tokens = self.prepare_tokens(text, speaker, prompt_tokens) + logger.info(f'Input tokens shape: {len(input_tokens)}', extra={'request_id': request_id}) + prompt = TokensPrompt(prompt_token_ids=input_tokens) + + results_generator = self.engine.generate( + prompt=prompt, + sampling_params=self.sampling_params, + request_id=request_id + ) + + async for request_output in results_generator: + if request_output.finished: + output = request_output + + output_tokens, generation_metrics = self.get_generation_output(output=output) + logger.info(f'Output tokens shape: {output_tokens.shape}', extra={'request_id': request_id}) + + overall_metrics.append(generation_metrics) + mimi_tokens.append(output_tokens) - mimi_tokens = np.concatenate(mimi_tokens, axis=1) - logger.info(f'Mimi tokens shape: {mimi_tokens.shape}') + prompt_tokens = { + TEXT: input_tokens, + MIMI: output_tokens + } audio, decode_time = self.decode_audio(mimi_tokens) metrics = TTSMetrics( - time_to_first_token=metrics['time_to_first_token'], - time_to_last_token=metrics['time_to_last_token'], + time_to_first_token=overall_metrics[0]['time_to_first_token'], + time_to_last_token=sum([x['time_to_last_token'] for x in overall_metrics]), time_to_decode_audio=decode_time, - input_tokens=metrics['input_tokens'], - decoding_tokens=metrics['decoding_tokens'], + input_tokens=[x['input_tokens'] for x in overall_metrics], + decoding_tokens=[x['decoding_tokens'] for x in overall_metrics], generate_end_to_end_time=time.time()-start_time ) diff --git a/service/utils.py b/service/utils.py index 80be61f..fb3a5b0 100644 --- a/service/utils.py +++ b/service/utils.py @@ -37,13 +37,11 @@ def sanitize_text(text: str) -> list[str]: # text = re.sub(allowed_pattern, '', text) text = re.sub(r'([,\.?])+', r'\1', text) - # pattern = r'([.!?])' - # segments = re.split(pattern, text) + pattern = r'([.!?])' + segments = re.split(pattern, text) - sentences = [text.strip()] - return sentences - - # current_sentence = '' + sentences = [] + current_sentence = '' for segment in segments: current_sentence += segment @@ -77,15 +75,11 @@ def alternative_logits_processor( Returns: torch.Tensor - Processed logits. 
Shape (vocab_size) """ - logger.debug(f'Stop token: {stop_token}') - codebook_indices = len(past_token_ids) % num_codebooks start_idx = offset + codebook_indices * codebook_size end_idx = offset + (codebook_indices + 1) * codebook_size - logger.debug(f'Past_token_ids: {len(past_token_ids)}, codebook indices: {codebook_indices}, start idx: {start_idx}, end idx: {end_idx}') - mask = torch.zeros_like(logits) mask[start_idx:end_idx] = 1 mask[stop_token] = 1 From 103e30f3abac82d14ef722549f4012fcb398ddb3 Mon Sep 17 00:00:00 2001 From: romitjain <11757603+romitjain@users.noreply.github.com> Date: Mon, 4 Nov 2024 12:43:15 +0000 Subject: [PATCH 2/6] Updated long infer --- service/inference.py | 9 ++++- service/models.py | 4 +- service/tts.py | 89 ++++++++++++++++++++++---------------------- 3 files changed, 54 insertions(+), 48 deletions(-) diff --git a/service/inference.py b/service/inference.py index 90bf171..6aafe0c 100644 --- a/service/inference.py +++ b/service/inference.py @@ -33,7 +33,14 @@ @app.get("/health") async def health() -> Response: - return Response(status_code=200) + global model + try: + await model.engine.check_health() + except Exception as e: + logger.critical(f"Error in model health check: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + return Response(content="OK", status_code=200) @app.post("/tts", response_model=TTSResponse) async def text_to_speech(requests: TTSRequest): diff --git a/service/models.py b/service/models.py index 0efd013..2e5a3b5 100644 --- a/service/models.py +++ b/service/models.py @@ -34,8 +34,8 @@ class TTSRequest(BaseModel): speaker: Speakers class TTSMetrics(BaseModel): - time_to_first_token: List[float] - time_to_last_token: List[float] + time_to_first_token: float + time_to_last_token: float time_to_decode_audio: float input_tokens: List[int] decoding_tokens: List[int] diff --git a/service/tts.py b/service/tts.py index a3c09bb..1fea221 100644 --- a/service/tts.py +++ b/service/tts.py @@ -2,6 +2,7 @@ sys.path.append('omni/') import time +import uuid import torch import numpy as np from transformers import MimiModel, AutoTokenizer @@ -66,10 +67,10 @@ def __init__(self, model_path, device): logits_processors=logits_processors ) - def prepare_tokens(self, incoming_text, speaker, prompt_tokens: dict = None) -> List[int]: - incoming_tokens = self.text_tokenizer.encode(incoming_text) + def prepare_tokens(self, incoming_text, speaker, prompt_tokens: dict = None) -> Tuple[List[int], List[int]]: if prompt_tokens: + incoming_tokens = self.text_tokenizer.encode(' ' + incoming_text) input_tokens = np.hstack([ self.text_modality_token, prompt_tokens[TEXT], @@ -77,10 +78,11 @@ def prepare_tokens(self, incoming_text, speaker, prompt_tokens: dict = None) -> self.convert_token, self.acoustic_modality_token, self.text_tokenizer.encode(speaker), - prompt_tokens[MIMI], + prompt_tokens[MIMI] ]) - return input_tokens.tolist() + return incoming_tokens, input_tokens.tolist() + incoming_tokens = self.text_tokenizer.encode(incoming_text) input_tokens = np.hstack([ self.text_modality_token, incoming_tokens, @@ -89,46 +91,33 @@ def prepare_tokens(self, incoming_text, speaker, prompt_tokens: dict = None) -> self.text_tokenizer.encode(speaker) ]) - return input_tokens.tolist() + return incoming_tokens, input_tokens.tolist() - def get_generation_output(self, output: List[RequestOutput]) -> Tuple[np.ndarray, Dict[str, Any]]: - output_tokens = [] + def get_generation_output(self, output: RequestOutput) -> Tuple[np.ndarray, np.ndarray, Dict[str, Any]]: + output_tokens 
= np.array(output.outputs[0].token_ids) + end = np.where(output_tokens == self.stop_token[0])[0] - metrics = { - 'time_to_first_token': [], - 'time_to_last_token': [], - 'input_tokens': [], - 'decoding_tokens': [] - } + if len(end) > 0: + end = end[0] + else: + end = len(output_tokens) - for idx, request_output in enumerate(output): - o = np.array(request_output.outputs[0].token_ids) - end = np.where(o == self.stop_token[0])[0] + output_tokens = output_tokens[:end] - if len(end) > 0: - end = end[0] - else: - end = len(o) + output_token_ids = output_tokens.copy() + output_tokens = output_tokens - cfg.OFFSET[MIMI] + output_tokens = deserialize_tokens(output_tokens) - o = o[:end] - o = o - cfg.OFFSET[MIMI] - o = deserialize_tokens(o) - - assert np.all(o >= 0), f'Negative token index generated for batch {idx}' - - metrics['time_to_first_token'].append( - request_output.metrics.first_token_time - request_output.metrics.first_scheduled_time - ) - metrics['time_to_last_token'].append( - request_output.metrics.finished_time - request_output.metrics.first_scheduled_time - ) - metrics['input_tokens'].append(len(request_output.prompt_token_ids)) - metrics['decoding_tokens'].append(len(request_output.outputs[0].token_ids)) + assert np.all(output_tokens >= 0), f'Negative token index generated' - output_tokens.append(o) + metrics = { + 'time_to_first_token': output.metrics.first_token_time - output.metrics.first_scheduled_time, + 'time_to_last_token': output.metrics.finished_time - output.metrics.first_scheduled_time, + 'input_tokens': len(output.prompt_token_ids), + 'decoding_tokens': len(output.outputs[0].token_ids) + } - output_tokens = np.concatenate(output_tokens, axis=1) - return output_tokens, metrics + return output_token_ids, output_tokens, metrics async def generate_async(self, text: str, @@ -145,31 +134,32 @@ async def generate_async(self, logger.info(f'Texts after preprocessing: {batch_text}, {speaker}', extra={'request_id': request_id}) for text in batch_text: - input_tokens = self.prepare_tokens(text, speaker, prompt_tokens) + text_token_ids, input_tokens = self.prepare_tokens(text, speaker, prompt_tokens) logger.info(f'Input tokens shape: {len(input_tokens)}', extra={'request_id': request_id}) prompt = TokensPrompt(prompt_token_ids=input_tokens) results_generator = self.engine.generate( prompt=prompt, sampling_params=self.sampling_params, - request_id=request_id + request_id=str(uuid.uuid4()) ) async for request_output in results_generator: if request_output.finished: output = request_output - output_tokens, generation_metrics = self.get_generation_output(output=output) + output_token_ids, output_tokens, generation_metrics = self.get_generation_output(output=output) logger.info(f'Output tokens shape: {output_tokens.shape}', extra={'request_id': request_id}) overall_metrics.append(generation_metrics) mimi_tokens.append(output_tokens) prompt_tokens = { - TEXT: input_tokens, - MIMI: output_tokens + TEXT: text_token_ids, + MIMI: output_token_ids } + mimi_tokens = np.concatenate(mimi_tokens, axis=1) audio, decode_time = self.decode_audio(mimi_tokens) metrics = TTSMetrics( @@ -196,8 +186,17 @@ def decode_audio(self, audio_tokens) -> Tuple[np.ndarray, float]: return audio, end_time - start_time -if __name__ == '__main__': - tts = TTS('cmeraki/mimi_tts_hf', 'cuda:0') - result = tts.generate('Long ago, in a distant kingdom between emerald hills and sapphire lakes, magic flowed freely. 
A wise king ruled, ensuring peace and prosperity.') + +async def main(): + model = TTS('cmeraki/mimi_tts_hf', 'cuda:0') + result = await model.generate_async( + 'Long ago, in a distant kingdom between emerald hills and sapphire lakes, magic flowed freely. This is a second sentence.', + speaker='[spkr_hifi_tts_9017]', + request_id=str(uuid.uuid4()) + ) print(result['metrics']) + +if __name__ == '__main__': + import asyncio + asyncio.run(main()) From 03d335596d9918ca43e7ba17b999ce611137132d Mon Sep 17 00:00:00 2001 From: romitjain <11757603+romitjain@users.noreply.github.com> Date: Tue, 5 Nov 2024 09:02:20 +0000 Subject: [PATCH 3/6] Added long infer at word boundary --- service/tts.py | 10 ++++++---- service/utils.py | 20 +++++++++++++++++++- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/service/tts.py b/service/tts.py index 1fea221..bd5883f 100644 --- a/service/tts.py +++ b/service/tts.py @@ -122,11 +122,12 @@ def get_generation_output(self, output: RequestOutput) -> Tuple[np.ndarray, np.n async def generate_async(self, text: str, speaker: Optional[str] = '[spkr_hifi_tts_9017]', - request_id: Optional[str] = None + request_id: Optional[str] = None, + max_context_words: Optional[int] = 10 ) -> Dict[str, Any]: start_time = time.time() - batch_text = sanitize_text(text) + batch_text = sanitize_text(text, max_context_words) prompt_tokens = {} overall_metrics = [] mimi_tokens = [] @@ -188,11 +189,12 @@ def decode_audio(self, audio_tokens) -> Tuple[np.ndarray, float]: async def main(): - model = TTS('cmeraki/mimi_tts_hf', 'cuda:0') + model = TTS('cmeraki/mimi_tts_hf_stage', 'cuda:0') result = await model.generate_async( 'Long ago, in a distant kingdom between emerald hills and sapphire lakes, magic flowed freely. This is a second sentence.', speaker='[spkr_hifi_tts_9017]', - request_id=str(uuid.uuid4()) + request_id=str(uuid.uuid4()), + max_context_words=20 ) print(result['metrics']) diff --git a/service/utils.py b/service/utils.py index fb3a5b0..c55e29f 100644 --- a/service/utils.py +++ b/service/utils.py @@ -19,24 +19,42 @@ def deserialize_tokens(tokens): return acoustic_tokens -def sanitize_text(text: str) -> list[str]: +def sanitize_text(text: str, max_context_words: int) -> list[str]: """ Sanitize text to be used for TTS Args: text (str): Text to sanitize + max_context_words (int): Maximum number of words in a sentence Returns: list[str]: List of sentences, split by punctuation (., !, ?) """ text = text.lower() + + # Remove more than one newlines and tabs text = re.sub(r'\n+', ' ', text) text = re.sub(r'[ \t]+', ' ', text) + # Remove non-alphanumeric characters except for , . ? ! # allowed_pattern = r'[^a-z0-9\s,\.?\n\!]' # text = re.sub(allowed_pattern, '', text) + + # Remove more than one punctuation mark text = re.sub(r'([,\.?])+', r'\1', text) + # Split sentences by max context length + total_words = text.split(' ') + sentences = [] + current_sentence = '' + + for i in range(0, len(total_words), max_context_words): + current_sentence = ' '.join(total_words[i:i+max_context_words]) + sentences.append(current_sentence.strip()) + + return sentences + + # Split sentences by punctuation (., !, ?) 
pattern = r'([.!?])'
    segments = re.split(pattern, text)
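For reference, the word-boundary splitting that sanitize_text() performs after
"Added long infer at word boundary" reduces to the following standalone sketch.
It is a minimal reconstruction from the hunks above, not a drop-in copy, and
the name chunk_text is illustrative:

    import re

    def chunk_text(text: str, max_context_words: int = 10) -> list[str]:
        # Normalize: lowercase, collapse newlines/tabs/spaces, dedupe punctuation
        text = text.lower()
        text = re.sub(r'\n+', ' ', text)
        text = re.sub(r'[ \t]+', ' ', text)
        text = re.sub(r'([,\.?])+', r'\1', text)

        # Greedily group words into chunks of at most max_context_words words
        words = text.split(' ')
        return [
            ' '.join(words[i:i + max_context_words]).strip()
            for i in range(0, len(words), max_context_words)
        ]

    assert chunk_text('One two three four five six.', 4) == ['one two three four', 'five six.']

Note that the split is purely positional, every max_context_words words, which
is what the "word boundary" in the subject line refers to; the earlier
punctuation-based splitting below the early return is left unreachable.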
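generate_async() then runs these chunks sequentially, feeding each chunk's text
tokens and raw mimi token ids back in as prompt_tokens so that every chunk is
conditioned on the audio already generated. Schematically (a sketch against the
TTS class from these patches; run_engine stands in for the awaited vLLM
engine.generate() call and is not a real method):

    prompt_tokens = {}
    mimi_chunks = []

    for chunk in chunk_text(text, max_context_words):
        # prepare_tokens() prepends prompt_tokens[TEXT]/[MIMI] when present
        text_ids, input_ids = tts.prepare_tokens(chunk, speaker, prompt_tokens)
        output = run_engine(input_ids)  # awaited vLLM generation, one request per chunk
        raw_ids, codes, metrics = tts.get_generation_output(output)

        mimi_chunks.append(codes)
        # Condition the next chunk on this chunk's text and audio tokens
        prompt_tokens = {TEXT: text_ids, MIMI: raw_ids}

    audio, decode_time = tts.decode_audio(np.concatenate(mimi_chunks, axis=1))

This is why get_generation_output() returns both output_token_ids (still
carrying cfg.OFFSET[MIMI], suitable for re-prompting) and the deserialized
codebook array (suitable for decoding).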
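Finally, the alternative_logits_processor touched in the first patch constrains
decoding so that the token at position i can only come from codebook
i % num_codebooks. The mask it builds can be sketched as below; the constants
are illustrative (the real offset and codebook_size come from the model config),
and applying the mask with -inf is one reasonable choice rather than
necessarily the repository's exact implementation:

    import torch

    def mask_to_active_codebook(past_token_ids: list[int], logits: torch.Tensor,
                                num_codebooks: int = 8, codebook_size: int = 2048,
                                offset: int = 0, stop_token: int = 0) -> torch.Tensor:
        # Tokens are serialized codebook-by-codebook, so the position modulo
        # num_codebooks selects which slice of the vocabulary is legal next.
        codebook_index = len(past_token_ids) % num_codebooks
        start = offset + codebook_index * codebook_size
        end = offset + (codebook_index + 1) * codebook_size

        mask = torch.zeros_like(logits, dtype=torch.bool)
        mask[start:end] = True
        mask[stop_token] = True  # the stop token stays reachable at every step
        return logits.masked_fill(~mask, float('-inf'))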