diff --git a/runner/app/pipelines/text_to_speech.py b/runner/app/pipelines/text_to_speech.py index 097a76529..7d758d0c0 100644 --- a/runner/app/pipelines/text_to_speech.py +++ b/runner/app/pipelines/text_to_speech.py @@ -1,76 +1,137 @@ import io import logging +import os +import tempfile +import torch import soundfile as sf -import torch -from parler_tts import ParlerTTSForConditionalGeneration -from transformers import AutoTokenizer +from pathlib import Path from app.pipelines.base import Pipeline from app.pipelines.utils import get_model_dir, get_torch_device -from app.utils.errors import InferenceError +from app.utils.errors import ( + InferenceError, + ModelOOMError, + GenerationError, +) logger = logging.getLogger(__name__) class TextToSpeechPipeline(Pipeline): def __init__(self, model_id: str): + """Instantiate a Chatterbox-based text-to-speech pipeline. + + The constructor attempts to load a local checkpoint from the Hugging Face + cache first (to avoid repeated downloads). If none is found, it falls + back to fetching the weights from the Hub. All heavy objects live on + the device returned by ``get_torch_device`` (GPU if available). + """ self.device = get_torch_device() self.model_id = model_id - kwargs = {"cache_dir": get_model_dir()} - self.model = ParlerTTSForConditionalGeneration.from_pretrained( - model_id, - **kwargs, - ).to(self.device) + try: + from chatterbox.src.chatterbox.tts import ChatterboxTTS # type: ignore + except ModuleNotFoundError: # pragma: no cover – optional + ChatterboxTTS = None # type: ignore - self.tokenizer = AutoTokenizer.from_pretrained( - model_id, - torch_dtype=torch.bfloat16, - attn_implementation="flash_attention_2", - **kwargs, - ) + if ChatterboxTTS is None: + raise ImportError( + "ChatterboxTTS requested but chatterbox package not installed." + ) - def _generate_speech(self, text: str, tts_steering: str) -> io.BytesIO: - """Generate speech from text input using the text-to-speech model. 
+ # Try to locate a local checkpoint first. The huggingface cache layout is: + # /models--/snapshots// + safe_model_id = self.model_id.replace("/", "--") + snapshots_root = Path(get_model_dir()) / f"models--{safe_model_id}" / "snapshots" + ckpt_dir: Path | None = None + if snapshots_root.exists(): + # Choose the most recently modified snapshot directory if multiple exist. + snapshot_dirs = [d for d in snapshots_root.iterdir() if d.is_dir()] + if snapshot_dirs: + # Sort by modification time, newest first. + snapshot_dirs.sort(key=lambda p: p.stat().st_mtime, reverse=True) + if len(snapshot_dirs) > 1: + logger.warning( + "Multiple snapshots found for %s; using the most recent: %s", + self.model_id, + snapshot_dirs[0], + ) + ckpt_dir = snapshot_dirs[0] + + if ckpt_dir and ckpt_dir.exists(): + logger.info("Loading ChatterboxTTS from local checkpoint: %s", ckpt_dir) + self.model = ChatterboxTTS.from_local(ckpt_dir, self.device) + else: + logger.info("No local checkpoint found for %s — downloading from HuggingFace Hub.", self.model_id) + self.model = ChatterboxTTS.from_pretrained(self.device) + + # Sample rate attribute is exposed for audio writing. + self.sample_rate = getattr(self.model, "sr", 44100) + if not hasattr(self.model, "sr"): + logger.warning("Chatterbox model does not expose 'sr'; defaulting to 44100 Hz.") + + def _generate_speech( + self, + text: str, + audio_prompt_bytes: bytes | None = None, + ) -> io.BytesIO: + """Generate speech audio from text. Args: - text: Text input for speech generation. - tts_steering: Description of speaker to steer text to speech generation. + text: Input text to synthesise. + audio_prompt_path: Optional path to a reference audio clip used by the + model for voice cloning / style transfer. Returns: - buffer: BytesIO buffer containing the generated audio. + BytesIO: In-memory WAV file containing the generated speech. 
""" with torch.no_grad(): - input_ids = self.tokenizer(tts_steering, return_tensors="pt").input_ids.to( - self.device - ) - prompt_input_ids = self.tokenizer(text, return_tensors="pt").input_ids.to( - self.device - ) - - generation = self.model.generate( - input_ids=input_ids, prompt_input_ids=prompt_input_ids - ) - generated_audio = generation.cpu().numpy().squeeze() + # The Chatterbox generate method returns either a Tensor of shape + # (1, n) or a NumPy-like 1-D array representing audio samples. + if audio_prompt_bytes: + with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_audio: + tmp_audio.write(audio_prompt_bytes) + temp_path = tmp_audio.name + try: + wav = self.model.generate(text, audio_prompt_path=temp_path) + except Exception as e: + raise GenerationError(original_exception=e) from e + finally: + try: + os.remove(temp_path) + except OSError: + pass + else: + try: + wav = self.model.generate(text) + except Exception as e: + raise GenerationError(original_exception=e) from e + audio = wav.squeeze(0).cpu().numpy() if torch.is_tensor(wav) else wav buffer = io.BytesIO() - sf.write(buffer, generated_audio, samplerate=44100, format="WAV") + sf.write(buffer, audio, samplerate=self.sample_rate, format="WAV") buffer.seek(0) - del input_ids, prompt_input_ids, generation, generated_audio - return buffer def __call__(self, params) -> io.BytesIO: try: - output = self._generate_speech(params.text, params.description) + audio_prompt_bytes: bytes | None = getattr(params, "audio_prompt_base64", None) + output = self._generate_speech( + params.text, + audio_prompt_bytes=audio_prompt_bytes, + ) except torch.cuda.OutOfMemoryError as e: - raise e + # Translate low-level OOM into domain-specific error. + raise ModelOOMError(original_exception=e) from e except Exception as e: - raise InferenceError(original_exception=e) + # If it's already a GenerationError, bubble up unchanged; otherwise wrap. 
+ if isinstance(e, GenerationError): + raise + raise GenerationError(original_exception=e) from e return output def __str__(self) -> str: - return f"TextToSpeechPipeline model_id={self.model_id}" + return f"TextToSpeechPipeline(Chatterbox) model_id={self.model_id}" diff --git a/runner/app/routes/text_to_speech.py b/runner/app/routes/text_to_speech.py index 808a60281..4a2959b82 100644 --- a/runner/app/routes/text_to_speech.py +++ b/runner/app/routes/text_to_speech.py @@ -1,13 +1,14 @@ +import base64 import logging import os import time -from typing import Annotated, Dict, Tuple, Union - +from typing import Annotated, Dict, Tuple, Union, Optional import torch from fastapi import APIRouter, Depends, status from fastapi.responses import JSONResponse +from fastapi.concurrency import run_in_threadpool from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, validator from app.dependencies import get_pipeline from app.pipelines.base import Pipeline @@ -23,6 +24,10 @@ logger = logging.getLogger(__name__) +# ---------------- Validation limits ---------------- +MAX_TEXT_LEN = 10000 # maximum characters for text +MAX_AUDIO_BYTES = 10 * 1024 * 1024 # 10 MB + # Pipeline specific error handling configuration. PIPELINE_ERROR_CONFIG: Dict[str, Tuple[Union[str, None], int]] = { # Specific error types. @@ -32,31 +37,61 @@ ) } - class TextToSpeechParams(BaseModel): + # TODO: Make model_id and other None properties optional once Go codegen tool # supports OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373 model_id: Annotated[ str, Field( default="", - description="Hugging Face model ID used for text to speech generation.", + description="Optional Hugging Face model ID for text-to-speech generation. 
If omitted, the pipelines configured model is used.", ), ] text: Annotated[ - str, Field(default="", description=("Text input for speech generation.")) - ] - description: Annotated[ str, Field( - default=( - "A male speaker delivers a slightly expressive and animated speech " - "with a moderate speed and pitch." + default="Hi, there my name is AI.", + description="Text input for speech generation.", + ), + ] + audio_prompt_base64: Annotated[ + bytes, + Field( + default="", + description=( + "Optional base64-encoded audio data for voice cloning reference. Provide as base64-encoded string; it will be decoded server-side. Must be a valid audio file format like WAV or MP3." ), - description=("Description of speaker to steer text to speech generation."), ), ] + @validator("text") + def validate_text_length(cls, v): + if not v.strip(): + raise ValueError("text must not be empty") + if len(v) > MAX_TEXT_LEN: + raise ValueError(f"text exceeds {MAX_TEXT_LEN} characters") + return v + + @validator("audio_prompt_base64", pre=True) + def validate_and_decode_audio_prompt(cls, v): + """Decode base64 audio once during validation and enforce size limits.""" + if v is None: + return None + try: + # Accept already-bytes input for flexibility. + if isinstance(v, (bytes, bytearray)): + decoded = bytes(v) + else: + decoded = base64.b64decode(v) + if len(decoded) > MAX_AUDIO_BYTES: + raise ValueError( + f"decoded audio data exceeds {MAX_AUDIO_BYTES / (1024 * 1024):.1f} MB" + ) + return decoded + except Exception as e: + raise ValueError(f"invalid base64 audio data: {e}") from e + RESPONSES = { status.HTTP_200_OK: { @@ -74,13 +109,16 @@ class TextToSpeechParams(BaseModel): } + + + @router.post( "/text-to-speech", response_model=AudioResponse, responses=RESPONSES, description=( - "Generate a text-to-speech audio file based on the provided text input and " - "speaker description." + "Generate a text-to-speech audio file based on the provided text input." 
+        " Optionally include base64-encoded audio for voice cloning." ), operation_id="genTextToSpeech", summary="Text To Speech", @@ -116,19 +154,22 @@ async def text_to_speech( content=http_error("Invalid bearer token."), ) - if params.model_id != "" and params.model_id != pipeline.model_id: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content=http_error( - f"pipeline configured with {pipeline.model_id} but called with " - f"{params.model_id}" - ), + + + # If a model_id is supplied and differs from the pipeline’s current model, log a warning. + if params.model_id and params.model_id != pipeline.model_id: + logger.warning( + "Requested model_id %s differs from pipeline model_id %s — proceeding with current pipeline.", + params.model_id, + pipeline.model_id, ) + try: - start = time.time() - out = pipeline(params) - logger.info(f"TextToSpeechPipeline took {time.time() - start} seconds.") + start_time = time.time() + output = await run_in_threadpool(pipeline, params) + end_time = time.time() + logger.info(f"TextToSpeechPipeline took {end_time - start_time} seconds.") except Exception as e: if isinstance(e, torch.cuda.OutOfMemoryError): # TODO: Investigate why not all VRAM memory is cleared. 
@@ -140,4 +181,5 @@ async def text_to_speech( custom_error_config=PIPELINE_ERROR_CONFIG, ) - return {"audio": {"url": audio_to_data_url(out)}} + + return {"audio": {"url": audio_to_data_url(output)}} diff --git a/runner/app/utils/errors.py b/runner/app/utils/errors.py index 1dfab4661..70e852f72 100644 --- a/runner/app/utils/errors.py +++ b/runner/app/utils/errors.py @@ -15,3 +15,14 @@ def __init__(self, message="Error during model execution", original_exception=No message = f"{message}: {original_exception}" super().__init__(message) self.original_exception = original_exception + + + +class ModelOOMError(InferenceError): + """Exception raised when the model runs out of memory.""" + pass + + +class GenerationError(InferenceError): + """Exception raised for general errors in the generation process.""" + pass diff --git a/runner/dl_checkpoints.sh b/runner/dl_checkpoints.sh index 459bce768..0b4d8904f 100755 --- a/runner/dl_checkpoints.sh +++ b/runner/dl_checkpoints.sh @@ -50,7 +50,7 @@ function download_beta_models() { # Download custom pipeline models. huggingface-cli download facebook/sam2-hiera-large --include "*.pt" "*.yaml" --cache-dir models - huggingface-cli download parler-tts/parler-tts-large-v1 --include "*.safetensors" "*.json" "*.model" --cache-dir models + huggingface-cli download ResembleAI/chatterbox --include "*.pt" "*.json" "*.model" --cache-dir models printf "\nDownloading token-gated models...\n" diff --git a/runner/docker/Dockerfile.text_to_speech b/runner/docker/Dockerfile.text_to_speech index ad7458bcf..00a70225e 100644 --- a/runner/docker/Dockerfile.text_to_speech +++ b/runner/docker/Dockerfile.text_to_speech @@ -1,22 +1,22 @@ ARG BASE_IMAGE=livepeer/ai-runner:base FROM ${BASE_IMAGE} -# Install CUDA Toolkit to enable flash attention. +# Install CUDA Toolkit to enable flash attention and audio processing dependencies. 
RUN apt-get update && \ apt-get install -y --no-install-recommends \ cuda-toolkit-12-1 \ - g++ && \ + g++ \ + libsndfile1 \ + libsndfile1-dev && \ rm -rf /var/lib/apt/lists/* -RUN pip install --no-cache-dir \ - ninja \ - transformers==4.43.3 \ - peft \ - deepcache \ - soundfile \ - g2p-en \ - flash_attn==2.5.6 \ - git+https://github.com/huggingface/parler-tts.git@5d0aca9753ab74ded179732f5bd797f7a8c6f8ee +# Clone Chatterbox, checkout specific commit, create pyproject.toml, and install it +RUN git clone https://huggingface.co/spaces/ResembleAI/Chatterbox /opt/Chatterbox && \ + cd /opt/Chatterbox && \ + git checkout bf4bbc30226326884e0d57b1e45ec0550683300f && \ + echo '[build-system]\nrequires = ["setuptools"]\nbuild-backend = "setuptools.build_meta"\n\n[project]\nname = "chatterbox"\nversion = "0.1.0"' > pyproject.toml && \ + pip install --no-cache-dir -e . && \ + pip install --no-cache-dir -r requirements.txt # Override base working directory to ensure the correct working directory. WORKDIR /app diff --git a/runner/gateway.openapi.yaml b/runner/gateway.openapi.yaml index 1033b945c..d3fad2299 100644 --- a/runner/gateway.openapi.yaml +++ b/runner/gateway.openapi.yaml @@ -466,7 +466,7 @@ paths: - generate summary: Text To Speech description: Generate a text-to-speech audio file based on the provided text - input and speaker description. + input.Optionally include base64-encoded audio for voice cloning. operationId: genTextToSpeech requestBody: content: @@ -1214,19 +1214,22 @@ components: model_id: type: string title: Model Id - description: Hugging Face model ID used for text to speech generation. + description: Optional Hugging Face model ID for text-to-speech generation. + If omitted, the pipelines configured model is used. default: '' text: type: string title: Text description: Text input for speech generation. - default: '' - description: + default: Hi, there my name is AI. 
+ audio_prompt_base64: type: string - title: Description - description: Description of speaker to steer text to speech generation. - default: A male speaker delivers a slightly expressive and animated speech - with a moderate speed and pitch. + format: binary + title: Audio Prompt Base64 + description: Optional base64-encoded audio data for voice cloning reference. + Provide as base64-encoded string; it will be decoded server-side. Must + be a valid audio file format like WAV or MP3. + default: '' type: object title: TextToSpeechParams required: diff --git a/runner/openapi.yaml b/runner/openapi.yaml index 6c1a3f589..61b49523c 100644 --- a/runner/openapi.yaml +++ b/runner/openapi.yaml @@ -466,7 +466,7 @@ paths: - generate summary: Text To Speech description: Generate a text-to-speech audio file based on the provided text - input and speaker description. + input.Optionally include base64-encoded audio for voice cloning. operationId: genTextToSpeech requestBody: content: @@ -1370,19 +1370,22 @@ components: model_id: type: string title: Model Id - description: Hugging Face model ID used for text to speech generation. + description: Optional Hugging Face model ID for text-to-speech generation. + If omitted, the pipelines configured model is used. default: '' text: type: string title: Text description: Text input for speech generation. - default: '' - description: + default: Hi, there my name is AI. + audio_prompt_base64: type: string - title: Description - description: Description of speaker to steer text to speech generation. - default: A male speaker delivers a slightly expressive and animated speech - with a moderate speed and pitch. + format: binary + title: Audio Prompt Base64 + description: Optional base64-encoded audio data for voice cloning reference. + Provide as base64-encoded string; it will be decoded server-side. Must + be a valid audio file format like WAV or MP3. + default: '' type: object title: TextToSpeechParams ValidationError: