Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 50 additions & 17 deletions runner/app/pipelines/text_to_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,35 +4,44 @@
import soundfile as sf
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
from transformers import AutoTokenizer, pipeline

from app.pipelines.base import Pipeline
from app.pipelines.utils import get_model_dir, get_torch_device
from app.utils.errors import InferenceError

logger = logging.getLogger(__name__)

PARLER_MODEL_ID = "parler-tts-large-v1"

class TextToSpeechPipeline(Pipeline):
def __init__(self, model_id: str):
    """Initialize the text-to-speech pipeline for the given model.

    Parler models get a dedicated ParlerTTS model plus tokenizer; any
    other model id falls back to the generic Transformers
    ``text-to-speech`` pipeline.

    Args:
        model_id: Hugging Face model identifier (e.g. a Parler-TTS id or
            any model supported by the ``text-to-speech`` pipeline task).
    """
    self.device = get_torch_device()
    self.model_id = model_id
    cache_dir = get_model_dir()

    # Set up the correct pipeline for the given model_id.
    # NOTE(review): the original span concatenated the pre-change and
    # post-change diff hunks, loading the Parler model unconditionally
    # before this branch; only the branch-guarded loads are kept here.
    if PARLER_MODEL_ID in model_id.lower():
        self.model = ParlerTTSForConditionalGeneration.from_pretrained(
            model_id,
            cache_dir=cache_dir,
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            cache_dir=cache_dir,
        )
    else:
        self.pipe = pipeline(
            "text-to-speech",
            model=model_id,
            device=self.device,
            model_kwargs={"cache_dir": cache_dir},
        )

def _generate_speech(self, text: str, tts_steering: str) -> io.BytesIO:
"""Generate speech from text input using the text-to-speech model.
def _generate_speech_parler(self, text: str, tts_steering: str) -> io.BytesIO:
"""Generate speech from text input using the Parler TTS model.

Args:
text: Text input for speech generation.
Expand All @@ -59,12 +68,36 @@ def _generate_speech(self, text: str, tts_steering: str) -> io.BytesIO:
buffer.seek(0)

del input_ids, prompt_input_ids, generation, generated_audio
return buffer

def _generate_speech(self, text: str) -> io.BytesIO:
    """Run the generic TTS pipeline on ``text`` and return WAV audio.

    Args:
        text (str): The text to convert to speech.

    Returns:
        io.BytesIO: In-memory WAV file, rewound to the start for reading.
    """
    result = self.pipe(text)
    # Drop any singleton channel/batch dimensions before encoding.
    audio = result["audio"].squeeze()

    wav_buffer = io.BytesIO()
    sf.write(
        wav_buffer,
        audio,
        samplerate=result["sampling_rate"],
        format="WAV",
    )
    wav_buffer.seek(0)
    return wav_buffer

def __call__(self, params) -> io.BytesIO:
try:
output = self._generate_speech(params.text, params.description)
if PARLER_MODEL_ID in self.model_id.lower():
output = self._generate_speech_parler(params.text, params.description)
else:
output = self._generate_speech(params.text)
except torch.cuda.OutOfMemoryError as e:
raise e
except Exception as e:
Expand Down