137 changes: 99 additions & 38 deletions runner/app/pipelines/text_to_speech.py
@@ -1,76 +1,137 @@
import io
import logging
import os
import tempfile
import torch

import soundfile as sf
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
from pathlib import Path

from app.pipelines.base import Pipeline
from app.pipelines.utils import get_model_dir, get_torch_device
from app.utils.errors import InferenceError
from app.utils.errors import (
InferenceError,
ModelOOMError,
GenerationError,
)

logger = logging.getLogger(__name__)


class TextToSpeechPipeline(Pipeline):
def __init__(self, model_id: str):
"""Instantiate a Chatterbox-based text-to-speech pipeline.

The constructor attempts to load a local checkpoint from the Hugging Face
cache first (to avoid repeated downloads). If none is found, it falls
back to fetching the weights from the Hub. All heavy objects live on
the device returned by ``get_torch_device`` (GPU if available).
"""
self.device = get_torch_device()
self.model_id = model_id
kwargs = {"cache_dir": get_model_dir()}

self.model = ParlerTTSForConditionalGeneration.from_pretrained(
model_id,
**kwargs,
).to(self.device)
try:
from chatterbox.src.chatterbox.tts import ChatterboxTTS # type: ignore
except ModuleNotFoundError: # pragma: no cover – optional
ChatterboxTTS = None # type: ignore

self.tokenizer = AutoTokenizer.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
**kwargs,
)
if ChatterboxTTS is None:
raise ImportError(
"ChatterboxTTS requested but chatterbox package not installed."
)

def _generate_speech(self, text: str, tts_steering: str) -> io.BytesIO:
"""Generate speech from text input using the text-to-speech model.
# Try to locate a local checkpoint first. The Hugging Face cache layout is:
# <hf_cache>/models--<model_id, with "/" replaced by "--">/snapshots/<hash>/
safe_model_id = self.model_id.replace("/", "--")
snapshots_root = Path(get_model_dir()) / f"models--{safe_model_id}" / "snapshots"
ckpt_dir: Path | None = None
if snapshots_root.exists():
# Choose the most recently modified snapshot directory if multiple exist.
snapshot_dirs = [d for d in snapshots_root.iterdir() if d.is_dir()]
if snapshot_dirs:
# Sort by modification time, newest first.
snapshot_dirs.sort(key=lambda p: p.stat().st_mtime, reverse=True)
if len(snapshot_dirs) > 1:
logger.warning(
"Multiple snapshots found for %s; using the most recent: %s",
self.model_id,
snapshot_dirs[0],
)
ckpt_dir = snapshot_dirs[0]

if ckpt_dir and ckpt_dir.exists():
logger.info("Loading ChatterboxTTS from local checkpoint: %s", ckpt_dir)
self.model = ChatterboxTTS.from_local(ckpt_dir, self.device)
else:
logger.info("No local checkpoint found for %s — downloading from HuggingFace Hub.", self.model_id)
self.model = ChatterboxTTS.from_pretrained(self.device)

# Sample rate attribute is exposed for audio writing.
self.sample_rate = getattr(self.model, "sr", 44100)
if not hasattr(self.model, "sr"):
logger.warning("Chatterbox model does not expose 'sr'; defaulting to 44100 Hz.")

def _generate_speech(
self,
text: str,
audio_prompt_bytes: bytes | None = None,
) -> io.BytesIO:
"""Generate speech audio from text.

Args:
text: Text input for speech generation.
tts_steering: Description of speaker to steer text to speech generation.
text: Input text to synthesise.
audio_prompt_bytes: Optional raw bytes of a reference audio clip used by
the model for voice cloning / style transfer.

Returns:
buffer: BytesIO buffer containing the generated audio.
BytesIO: In-memory WAV file containing the generated speech.
"""
with torch.no_grad():
input_ids = self.tokenizer(tts_steering, return_tensors="pt").input_ids.to(
self.device
)
prompt_input_ids = self.tokenizer(text, return_tensors="pt").input_ids.to(
self.device
)

generation = self.model.generate(
input_ids=input_ids, prompt_input_ids=prompt_input_ids
)
generated_audio = generation.cpu().numpy().squeeze()
# The Chatterbox generate method returns either a Tensor of shape
# (1, n) or a NumPy-like 1-D array representing audio samples.
if audio_prompt_bytes:
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_audio:
tmp_audio.write(audio_prompt_bytes)
temp_path = tmp_audio.name
try:
wav = self.model.generate(text, audio_prompt_path=temp_path)
except Exception as e:
raise GenerationError(original_exception=e) from e
finally:
try:
os.remove(temp_path)
except OSError:
pass
else:
try:
wav = self.model.generate(text)
except Exception as e:
raise GenerationError(original_exception=e) from e
audio = wav.squeeze(0).cpu().numpy() if torch.is_tensor(wav) else wav

buffer = io.BytesIO()
sf.write(buffer, generated_audio, samplerate=44100, format="WAV")
sf.write(buffer, audio, samplerate=self.sample_rate, format="WAV")
buffer.seek(0)

del input_ids, prompt_input_ids, generation, generated_audio

return buffer
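As a usage sketch (hypothetical driver code, not part of this PR; it assumes the chatterbox package is installed and a checkpoint is available locally or downloadable):

```python
import io
import soundfile as sf

pipe = TextToSpeechPipeline("ResembleAI/chatterbox")
buf: io.BytesIO = pipe._generate_speech("Hello from Chatterbox.")
audio, sr = sf.read(buf)  # decode the in-memory WAV back to samples
print(audio.shape, sr)    # 1-D sample array at the model's sample rate
```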

def __call__(self, params) -> io.BytesIO:
try:
output = self._generate_speech(params.text, params.description)
audio_prompt_bytes: bytes | None = getattr(params, "audio_prompt_base64", None)
output = self._generate_speech(
params.text,
audio_prompt_bytes=audio_prompt_bytes,
)
except torch.cuda.OutOfMemoryError as e:
raise e
# Translate a low-level CUDA OOM into a domain-specific error.
raise ModelOOMError(original_exception=e) from e
except Exception as e:
raise InferenceError(original_exception=e)
# If it's already a GenerationError, bubble up unchanged; otherwise wrap.
if isinstance(e, GenerationError):
raise
raise GenerationError(original_exception=e) from e

return output
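Because `ModelOOMError` and `GenerationError` both subclass `InferenceError`, callers can branch on the most specific type first. A minimal sketch of the intended handling (the HTTP status codes here are illustrative, not taken from this PR):

```python
try:
    audio = pipeline(params)
except ModelOOMError:
    status_code = 503  # out of VRAM; shed load or retry with smaller input
except GenerationError:
    status_code = 500  # model-side failure during synthesis
except InferenceError:
    status_code = 500  # any other inference problem
```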

def __str__(self) -> str:
return f"TextToSpeechPipeline model_id={self.model_id}"
return f"TextToSpeechPipeline(Chatterbox) model_id={self.model_id}"
92 changes: 67 additions & 25 deletions runner/app/routes/text_to_speech.py
@@ -1,13 +1,14 @@
import base64
import logging
import os
import time
from typing import Annotated, Dict, Tuple, Union

from typing import Annotated, Dict, Tuple, Union, Optional
import torch
from fastapi import APIRouter, Depends, status
from fastapi.responses import JSONResponse
from fastapi.concurrency import run_in_threadpool
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, validator

from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
@@ -23,6 +24,10 @@

logger = logging.getLogger(__name__)

# ---------------- Validation limits ----------------
MAX_TEXT_LEN = 10000 # maximum characters for text
MAX_AUDIO_BYTES = 10 * 1024 * 1024 # 10 MB

# Pipeline specific error handling configuration.
PIPELINE_ERROR_CONFIG: Dict[str, Tuple[Union[str, None], int]] = {
# Specific error types.
@@ -32,31 +37,61 @@
)
}


class TextToSpeechParams(BaseModel):

# TODO: Make model_id and other None properties optional once Go codegen tool
# supports OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373
model_id: Annotated[
str,
Field(
default="",
description="Hugging Face model ID used for text to speech generation.",
description="Optional Hugging Face model ID for text-to-speech generation. If omitted, the pipelines configured model is used.",
),
]
text: Annotated[
str, Field(default="", description=("Text input for speech generation."))
]
description: Annotated[
str,
Field(
default=(
"A male speaker delivers a slightly expressive and animated speech "
"with a moderate speed and pitch."
default="Hi, there my name is AI.",
description="Text input for speech generation.",
),
]
audio_prompt_base64: Annotated[
bytes,
Field(
default="",
description=(
"Optional base64-encoded audio data for voice cloning reference. Provide as base64-encoded string; it will be decoded server-side. Must be a valid audio file format like WAV or MP3."
),
description=("Description of speaker to steer text to speech generation."),
),
]

@validator("text")
def validate_text_length(cls, v):
if not v.strip():
raise ValueError("text must not be empty")
if len(v) > MAX_TEXT_LEN:
raise ValueError(f"text exceeds {MAX_TEXT_LEN} characters")
return v

@validator("audio_prompt_base64", pre=True)
def validate_and_decode_audio_prompt(cls, v):
"""Decode base64 audio once during validation and enforce size limits."""
if v is None:
return None
# Accept already-bytes input for flexibility.
if isinstance(v, (bytes, bytearray)):
decoded = bytes(v)
else:
try:
decoded = base64.b64decode(v)
except Exception as e:
raise ValueError(f"invalid base64 audio data: {e}") from e
# The size check sits outside the try block so this ValueError is not
# re-wrapped as a base64 decoding failure.
if len(decoded) > MAX_AUDIO_BYTES:
raise ValueError(
f"decoded audio data exceeds {MAX_AUDIO_BYTES / (1024 * 1024):.1f} MB"
)
return decoded


RESPONSES = {
status.HTTP_200_OK: {
@@ -74,13 +109,16 @@ class TextToSpeechParams(BaseModel):
}


@router.post(
"/text-to-speech",
response_model=AudioResponse,
responses=RESPONSES,
description=(
"Generate a text-to-speech audio file based on the provided text input and "
"speaker description."
"Generate a text-to-speech audio file based on the provided text input."
"Optionally include base64-encoded audio for voice cloning."
),
operation_id="genTextToSpeech",
summary="Text To Speech",
Expand Down Expand Up @@ -116,19 +154,22 @@ async def text_to_speech(
content=http_error("Invalid bearer token."),
)

if params.model_id != "" and params.model_id != pipeline.model_id:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=http_error(
f"pipeline configured with {pipeline.model_id} but called with "
f"{params.model_id}"
),


# If a model_id is supplied and differs from the pipeline’s current model, log a warning.
if params.model_id and params.model_id != pipeline.model_id:
logger.warning(
"Requested model_id %s differs from pipeline model_id %s — proceeding with current pipeline.",
params.model_id,
pipeline.model_id,
)


try:
start = time.time()
out = pipeline(params)
logger.info(f"TextToSpeechPipeline took {time.time() - start} seconds.")
start_time = time.time()
output = await run_in_threadpool(pipeline, params)
end_time = time.time()
logger.info(f"TextToSpeechPipeline took {end_time - start_time} seconds.")
except Exception as e:
if isinstance(e, torch.cuda.OutOfMemoryError):
# TODO: Investigate why not all VRAM memory is cleared.
@@ -140,4 +181,5 @@
custom_error_config=PIPELINE_ERROR_CONFIG,
)

return {"audio": {"url": audio_to_data_url(out)}}

return {"audio": {"url": audio_to_data_url(output)}}
11 changes: 11 additions & 0 deletions runner/app/utils/errors.py
@@ -15,3 +15,14 @@ def __init__(self, message="Error during model execution", original_exception=None):
message = f"{message}: {original_exception}"
super().__init__(message)
self.original_exception = original_exception



class ModelOOMError(InferenceError):
"""Exception raised when the model runs out of memory."""
pass


class GenerationError(InferenceError):
"""Exception raised for general errors in the generation process."""
pass
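Both subclasses inherit `InferenceError.__init__`, so the wrapped exception is folded into the message and kept on `.original_exception`. A minimal sketch:

```python
try:
    raise RuntimeError("CUDA error: out of memory")
except RuntimeError as e:
    err = ModelOOMError(original_exception=e)

print(err)                     # Error during model execution: CUDA error: out of memory
print(err.original_exception)  # the wrapped RuntimeError
```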
2 changes: 1 addition & 1 deletion runner/dl_checkpoints.sh
@@ -50,7 +50,7 @@ function download_beta_models() {

# Download custom pipeline models.
huggingface-cli download facebook/sam2-hiera-large --include "*.pt" "*.yaml" --cache-dir models
huggingface-cli download parler-tts/parler-tts-large-v1 --include "*.safetensors" "*.json" "*.model" --cache-dir models
huggingface-cli download ResembleAI/chatterbox --include "*.pt" "*.json" "*.model" --cache-dir models
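For reference, the same download can be done from Python with huggingface_hub (a sketch mirroring the CLI invocation above):

```python
from huggingface_hub import snapshot_download

# Equivalent of the huggingface-cli call above.
snapshot_download(
    "ResembleAI/chatterbox",
    allow_patterns=["*.pt", "*.json", "*.model"],
    cache_dir="models",
)
```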

printf "\nDownloading token-gated models...\n"

22 changes: 11 additions & 11 deletions runner/docker/Dockerfile.text_to_speech
@@ -1,22 +1,22 @@
ARG BASE_IMAGE=livepeer/ai-runner:base
FROM ${BASE_IMAGE}

# Install CUDA Toolkit to enable flash attention.
# Install the CUDA Toolkit (for flash attention) and audio-processing dependencies.
RUN apt-get update && \
apt-get install -y --no-install-recommends \
cuda-toolkit-12-1 \
g++ && \
g++ \
libsndfile1 \
libsndfile1-dev && \
rm -rf /var/lib/apt/lists/*

RUN pip install --no-cache-dir \
ninja \
transformers==4.43.3 \
peft \
deepcache \
soundfile \
g2p-en \
flash_attn==2.5.6 \
git+https://github.com/huggingface/parler-tts.git@5d0aca9753ab74ded179732f5bd797f7a8c6f8ee
# Clone Chatterbox, check out a pinned commit, create a pyproject.toml, and install it.
RUN git clone https://huggingface.co/spaces/ResembleAI/Chatterbox /opt/Chatterbox && \
cd /opt/Chatterbox && \
git checkout bf4bbc30226326884e0d57b1e45ec0550683300f && \
printf '[build-system]\nrequires = ["setuptools"]\nbuild-backend = "setuptools.build_meta"\n\n[project]\nname = "chatterbox"\nversion = "0.1.0"\n' > pyproject.toml && \
pip install --no-cache-dir -e . && \
pip install --no-cache-dir -r requirements.txt

# Override base working directory to ensure the correct working directory.
WORKDIR /app
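A quick sanity check that the editable install is importable inside the built image (a sketch; the import path matches the one used by the pipeline):

```python
# Run inside the container, e.g. via `docker run --rm <image> python - <<'EOF' ... EOF`
from chatterbox.src.chatterbox.tts import ChatterboxTTS  # noqa: F401
print("chatterbox import OK")
```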