From 2c02ec5ce5526f1c3198d722008fc0673945d6d9 Mon Sep 17 00:00:00 2001 From: Jason Stone Date: Fri, 20 Jun 2025 14:33:47 +0000 Subject: [PATCH 1/8] feat: integrate Chatterbox TTS model with support for voice cloning --- runner/app/pipelines/text_to_speech.py | 139 +++++++++++++----- runner/app/routes/text_to_speech.py | 90 +++++++++++- runner/app/utils/errors.py | 15 ++ runner/dl_checkpoints.sh | 1 + .../Dockerfile.text_to_speech:chatterbox | 25 ++++ runner/gateway.openapi.yaml | 53 +++---- runner/openapi.yaml | 51 ++++--- 7 files changed, 282 insertions(+), 92 deletions(-) create mode 100644 runner/docker/Dockerfile.text_to_speech:chatterbox diff --git a/runner/app/pipelines/text_to_speech.py b/runner/app/pipelines/text_to_speech.py index 097a76529..6e29d08b7 100644 --- a/runner/app/pipelines/text_to_speech.py +++ b/runner/app/pipelines/text_to_speech.py @@ -1,10 +1,10 @@ import io import logging +import torch import soundfile as sf -import torch -from parler_tts import ParlerTTSForConditionalGeneration from transformers import AutoTokenizer +from pathlib import Path from app.pipelines.base import Pipeline from app.pipelines.utils import get_model_dir, get_torch_device @@ -15,56 +15,120 @@ class TextToSpeechPipeline(Pipeline): def __init__(self, model_id: str): + """Instantiate a TTS pipeline. + + Depending on *model_id*, either the Parler-TTS or Chatterbox model will be + loaded. We treat any model_id that contains the substring "chatterbox" (case + insensitive) as a request for Chatterbox and default to Parler otherwise. 
+ """ self.device = get_torch_device() self.model_id = model_id - kwargs = {"cache_dir": get_model_dir()} - - self.model = ParlerTTSForConditionalGeneration.from_pretrained( - model_id, - **kwargs, - ).to(self.device) - - self.tokenizer = AutoTokenizer.from_pretrained( - model_id, - torch_dtype=torch.bfloat16, - attn_implementation="flash_attention_2", - **kwargs, - ) - - def _generate_speech(self, text: str, tts_steering: str) -> io.BytesIO: + self._is_chatterbox = "chatterbox" in model_id.lower() + + if self._is_chatterbox: + try: + from chatterbox.src.chatterbox.tts import ChatterboxTTS # type: ignore + except ModuleNotFoundError: # pragma: no cover – optional + ChatterboxTTS = None # type: ignore + + if ChatterboxTTS is None: + raise ImportError( + "ChatterboxTTS requested but chatterbox package not installed." + ) + + # Try using locally downloaded checkpoints first (downloaded via + # `dl_checkpoints.sh`). + base_cache = Path(get_model_dir()) / "models--ResembleAI--chatterbox" / "snapshots" + ckpt_dir = None + if base_cache.exists(): + # Grab first snapshot (should only be one) as the checkpoint dir. + for d in base_cache.iterdir(): + if d.is_dir(): + ckpt_dir = d + break + + if ckpt_dir is not None and ckpt_dir.exists(): + self.model = ChatterboxTTS.from_local(ckpt_dir, self.device) + else: + # Fallback to HF hub (will also cache under above path). + self.model = ChatterboxTTS.from_pretrained(self.device) + + # Sample rate attribute is exposed for audio writing. + self.sample_rate = getattr(self.model, "sr", 44100) + + else: + # Now import Parler-TTS which will use our patched register method + from parler_tts import ParlerTTSForConditionalGeneration # type: ignore + + # Default to Parler-TTS. 
+ kwargs = {"cache_dir": get_model_dir()} + + self.model = ParlerTTSForConditionalGeneration.from_pretrained( + model_id, + **kwargs, + ).to(self.device) + + self.tokenizer = AutoTokenizer.from_pretrained( + model_id, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + **kwargs, + ) + self.sample_rate = 44100 + + def _generate_speech( + self, + text: str, + tts_steering: str, + audio_prompt_path: str | None = None, + ) -> io.BytesIO: """Generate speech from text input using the text-to-speech model. Args: text: Text input for speech generation. tts_steering: Description of speaker to steer text to speech generation. + audio_prompt_path: Optional path to an audio prompt. Returns: buffer: BytesIO buffer containing the generated audio. """ - with torch.no_grad(): - input_ids = self.tokenizer(tts_steering, return_tensors="pt").input_ids.to( - self.device - ) - prompt_input_ids = self.tokenizer(text, return_tensors="pt").input_ids.to( - self.device - ) - - generation = self.model.generate( - input_ids=input_ids, prompt_input_ids=prompt_input_ids - ) - generated_audio = generation.cpu().numpy().squeeze() - - buffer = io.BytesIO() - sf.write(buffer, generated_audio, samplerate=44100, format="WAV") - buffer.seek(0) - - del input_ids, prompt_input_ids, generation, generated_audio + # Branch based on backend. 
+ if self._is_chatterbox: + with torch.no_grad(): + # Chatterbox generate returns a tensor with shape (1, n) or (n,) + wav = self.model.generate(text, audio_prompt_path=audio_prompt_path) + audio = wav.squeeze(0).cpu().numpy() if torch.is_tensor(wav) else wav + + buffer = io.BytesIO() + sf.write(buffer, audio, samplerate=self.sample_rate, format="WAV") + buffer.seek(0) + else: + with torch.no_grad(): + input_ids = self.tokenizer(tts_steering, return_tensors="pt").input_ids.to( + self.device + ) + prompt_input_ids = self.tokenizer(text, return_tensors="pt").input_ids.to( + self.device + ) + + generation = self.model.generate( + input_ids=input_ids, prompt_input_ids=prompt_input_ids + ) + generated_audio = generation.cpu().numpy().squeeze() + + buffer = io.BytesIO() + sf.write(buffer, generated_audio, samplerate=self.sample_rate, format="WAV") + buffer.seek(0) return buffer def __call__(self, params) -> io.BytesIO: try: - output = self._generate_speech(params.text, params.description) + output = self._generate_speech( + params.text, + params.description, + getattr(params, "audio_prompt_path", None), + ) except torch.cuda.OutOfMemoryError as e: raise e except Exception as e: @@ -73,4 +137,5 @@ def __call__(self, params) -> io.BytesIO: return output def __str__(self) -> str: - return f"TextToSpeechPipeline model_id={self.model_id}" + backend = "Chatterbox" if self._is_chatterbox else "ParlerTTS" + return f"TextToSpeechPipeline({backend}) model_id={self.model_id}" diff --git a/runner/app/routes/text_to_speech.py b/runner/app/routes/text_to_speech.py index 808a60281..a96c6f54b 100644 --- a/runner/app/routes/text_to_speech.py +++ b/runner/app/routes/text_to_speech.py @@ -2,12 +2,13 @@ import os import time from typing import Annotated, Dict, Tuple, Union +from pathlib import Path import torch -from fastapi import APIRouter, Depends, status +from fastapi import APIRouter, Depends, status, UploadFile, File, Form from fastapi.responses import JSONResponse from 
fastapi.security import HTTPAuthorizationCredentials, HTTPBearer -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, validator from app.dependencies import get_pipeline from app.pipelines.base import Pipeline @@ -15,6 +16,7 @@ AudioResponse, HTTPError, audio_to_data_url, + file_exceeds_max_size, handle_pipeline_exception, http_error, ) @@ -23,6 +25,11 @@ logger = logging.getLogger(__name__) +# ---------------- Validation limits ---------------- +MAX_TEXT_LEN = 10000 # maximum characters for text +MAX_DESC_LEN = 1000 # maximum characters for description +MAX_AUDIO_BYTES = 10 * 1024 * 1024 # 10 MB + # Pipeline specific error handling configuration. PIPELINE_ERROR_CONFIG: Dict[str, Tuple[Union[str, None], int]] = { # Specific error types. @@ -32,22 +39,30 @@ ) } - class TextToSpeechParams(BaseModel): # TODO: Make model_id and other None properties optional once Go codegen tool # supports OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373 model_id: Annotated[ str, Field( - default="", + default="ResembleAI/chatterbox", description="Hugging Face model ID used for text to speech generation.", ), ] text: Annotated[ - str, Field(default="", description=("Text input for speech generation.")) + str, + Field( + default=( + "When it was all over, the remaing animals, except for the pigs and dogs, " + "crept away in a body. They were shaken and miserable. They did not know " + "which was more shocking - the treachery of the animals who had leagued " + "themselves with Snowball, or the cruel retribution they had just witnessed." 
+ ), + description="Text input for speech generation.", + ), ] description: Annotated[ - str, + str | None, Field( default=( "A male speaker delivers a slightly expressive and animated speech " @@ -56,6 +71,29 @@ class TextToSpeechParams(BaseModel): description=("Description of speaker to steer text to speech generation."), ), ] + audio_prompt_path: Annotated[ + str | None, + Field( + default=None, + description=( + "Optional path or URL to a reference audio clip for voice cloning (only used when model_id refers to Chatterbox)." + ), + ), + ] + + @validator("text") + def validate_text_length(cls, v): + if not v.strip(): + raise ValueError("text must not be empty") + if len(v) > MAX_TEXT_LEN: + raise ValueError(f"text exceeds {MAX_TEXT_LEN} characters") + return v + + @validator("description") + def validate_description_length(cls, v): + if len(v) > MAX_DESC_LEN: + raise ValueError(f"description exceeds {MAX_DESC_LEN} characters") + return v RESPONSES = { @@ -74,6 +112,20 @@ class TextToSpeechParams(BaseModel): } +def _multipart_params( + model_id: str = Form(""), + text: str = Form(""), + description: str = Form( + "A male speaker delivers a slightly expressive and animated speech with a moderate speed and pitch." 
+ ), +) -> "TextToSpeechParams": + return TextToSpeechParams( + model_id=model_id, + text=text, + description=description, + ) + + @router.post( "/text-to-speech", response_model=AudioResponse, @@ -94,7 +146,8 @@ class TextToSpeechParams(BaseModel): include_in_schema=False, ) async def text_to_speech( - params: TextToSpeechParams, + params: TextToSpeechParams = Depends(_multipart_params), + audio_prompt: UploadFile | None = File(None, description="Optional reference audio file for Chatterbox voice cloning"), pipeline: Pipeline = Depends(get_pipeline), token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)), ): @@ -116,6 +169,13 @@ async def text_to_speech( content=http_error("Invalid bearer token."), ) + # Check uploaded file size early + if audio_prompt is not None and file_exceeds_max_size(audio_prompt, MAX_AUDIO_BYTES): + return JSONResponse( + status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, + content=http_error("audio_prompt too large; 10 MB max"), + ) + if params.model_id != "" and params.model_id != pipeline.model_id: return JSONResponse( status_code=status.HTTP_400_BAD_REQUEST, @@ -125,6 +185,16 @@ async def text_to_speech( ), ) + # Handle uploaded audio prompt. 
+ temp_path = None + if audio_prompt is not None: + suffix = Path(audio_prompt.filename or "audio").suffix or ".wav" + import tempfile + with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp: + tmp.write(await audio_prompt.read()) + temp_path = tmp.name + params.audio_prompt_path = temp_path # override + try: start = time.time() out = pipeline(params) @@ -139,5 +209,11 @@ async def text_to_speech( default_error_message="Text-to-speech pipeline error.", custom_error_config=PIPELINE_ERROR_CONFIG, ) + finally: + if temp_path: + try: + os.remove(temp_path) + except OSError: + pass return {"audio": {"url": audio_to_data_url(out)}} diff --git a/runner/app/utils/errors.py b/runner/app/utils/errors.py index 1dfab4661..12650bd1a 100644 --- a/runner/app/utils/errors.py +++ b/runner/app/utils/errors.py @@ -15,3 +15,18 @@ def __init__(self, message="Error during model execution", original_exception=No message = f"{message}: {original_exception}" super().__init__(message) self.original_exception = original_exception + + +class InvalidInputError(InferenceError): + """Exception raised when input validation fails.""" + pass + + +class ModelOOMError(InferenceError): + """Exception raised when the model runs out of memory.""" + pass + + +class GenerationError(InferenceError): + """Exception raised for general errors in the generation process.""" + pass diff --git a/runner/dl_checkpoints.sh b/runner/dl_checkpoints.sh index 459bce768..e2a068b57 100755 --- a/runner/dl_checkpoints.sh +++ b/runner/dl_checkpoints.sh @@ -51,6 +51,7 @@ function download_beta_models() { # Download custom pipeline models. 
huggingface-cli download facebook/sam2-hiera-large --include "*.pt" "*.yaml" --cache-dir models huggingface-cli download parler-tts/parler-tts-large-v1 --include "*.safetensors" "*.json" "*.model" --cache-dir models + huggingface-cli download ResembleAI/chatterbox --include "*.safetensors" "*.json" "*.model" --cache-dir models printf "\nDownloading token-gated models...\n" diff --git a/runner/docker/Dockerfile.text_to_speech:chatterbox b/runner/docker/Dockerfile.text_to_speech:chatterbox new file mode 100644 index 000000000..025d4b0a4 --- /dev/null +++ b/runner/docker/Dockerfile.text_to_speech:chatterbox @@ -0,0 +1,25 @@ +ARG BASE_IMAGE=livepeer/ai-runner:base +FROM ${BASE_IMAGE} + +# Install CUDA Toolkit to enable flash attention. +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + cuda-toolkit-12-1 \ + g++ && \ + rm -rf /var/lib/apt/lists/* + +# Clone Chatterbox, checkout specific commit, create pyproject.toml, and install it +RUN git clone https://huggingface.co/spaces/ResembleAI/Chatterbox /opt/Chatterbox && \ + cd /opt/Chatterbox && \ + git checkout bf4bbc30226326884e0d57b1e45ec0550683300f && \ + echo '[build-system]\nrequires = ["setuptools"]\nbuild-backend = "setuptools.build_meta"\n\n[project]\nname = "chatterbox"\nversion = "0.1.0"' > pyproject.toml && \ + pip install --no-cache-dir -e . && \ + pip install --no-cache-dir -r requirements.txt + +# Override base working directory to ensure the correct working directory. +WORKDIR /app + +# Copy app directory to avoid rebuilding the base image during development. 
+COPY app/ /app/app + +CMD ["uvicorn", "app.main:app", "--log-config", "app/cfg/uvicorn_logging_config.json", "--host", "", "--port", "8000"] diff --git a/runner/gateway.openapi.yaml b/runner/gateway.openapi.yaml index 1033b945c..8dd2ecb80 100644 --- a/runner/gateway.openapi.yaml +++ b/runner/gateway.openapi.yaml @@ -470,10 +470,11 @@ paths: operationId: genTextToSpeech requestBody: content: - application/json: + multipart/form-data: schema: - $ref: '#/components/schemas/TextToSpeechParams' - required: true + allOf: + - $ref: '#/components/schemas/Body_genTextToSpeech' + title: Body responses: '200': description: Successful Response @@ -769,6 +770,30 @@ components: - image - model_id title: Body_genSegmentAnything2 + Body_genTextToSpeech: + properties: + audio_prompt: + anyOf: + - type: string + format: binary + - type: 'null' + title: Audio Prompt + description: Optional reference audio file for Chatterbox voice cloning + model_id: + type: string + title: Model Id + default: '' + text: + type: string + title: Text + default: '' + description: + type: string + title: Description + default: A male speaker delivers a slightly expressive and animated speech + with a moderate speed and pitch. + type: object + title: Body_genTextToSpeech Body_genUpscale: properties: prompt: @@ -1209,28 +1234,6 @@ components: - prompt - model_id title: TextToImageParams - TextToSpeechParams: - properties: - model_id: - type: string - title: Model Id - description: Hugging Face model ID used for text to speech generation. - default: '' - text: - type: string - title: Text - description: Text input for speech generation. - default: '' - description: - type: string - title: Description - description: Description of speaker to steer text to speech generation. - default: A male speaker delivers a slightly expressive and animated speech - with a moderate speed and pitch. 
- type: object - title: TextToSpeechParams - required: - - model_id ValidationError: properties: loc: diff --git a/runner/openapi.yaml b/runner/openapi.yaml index 6c1a3f589..a25cd416d 100644 --- a/runner/openapi.yaml +++ b/runner/openapi.yaml @@ -470,10 +470,11 @@ paths: operationId: genTextToSpeech requestBody: content: - application/json: + multipart/form-data: schema: - $ref: '#/components/schemas/TextToSpeechParams' - required: true + allOf: + - $ref: '#/components/schemas/Body_genTextToSpeech' + title: Body responses: '200': description: Successful Response @@ -813,6 +814,30 @@ components: required: - image title: Body_genSegmentAnything2 + Body_genTextToSpeech: + properties: + audio_prompt: + anyOf: + - type: string + format: binary + - type: 'null' + title: Audio Prompt + description: Optional reference audio file for Chatterbox voice cloning + model_id: + type: string + title: Model Id + default: '' + text: + type: string + title: Text + default: '' + description: + type: string + title: Description + default: A male speaker delivers a slightly expressive and animated speech + with a moderate speed and pitch. + type: object + title: Body_genTextToSpeech Body_genUpscale: properties: prompt: @@ -1365,26 +1390,6 @@ components: required: - prompt title: TextToImageParams - TextToSpeechParams: - properties: - model_id: - type: string - title: Model Id - description: Hugging Face model ID used for text to speech generation. - default: '' - text: - type: string - title: Text - description: Text input for speech generation. - default: '' - description: - type: string - title: Description - description: Description of speaker to steer text to speech generation. - default: A male speaker delivers a slightly expressive and animated speech - with a moderate speed and pitch. 
- type: object - title: TextToSpeechParams ValidationError: properties: loc: From 3758d2031975cdef28bd06eee831ecf046bdc7b6 Mon Sep 17 00:00:00 2001 From: Jason Stone Date: Fri, 20 Jun 2025 19:15:11 +0000 Subject: [PATCH 2/8] feat: integrate Chatterbox TTS model with Docker and API endpoints --- runner/app/pipelines/text_to_speech.py | 186 +++++++++--------- runner/app/routes/text_to_speech.py | 124 +++++------- runner/app/routes/utils.py | 24 --- runner/app/utils/errors.py | 4 - runner/dl_checkpoints.sh | 1 - runner/docker/Dockerfile.text_to_speech | 22 +-- .../Dockerfile.text_to_speech:chatterbox | 25 --- 7 files changed, 149 insertions(+), 237 deletions(-) delete mode 100644 runner/docker/Dockerfile.text_to_speech:chatterbox diff --git a/runner/app/pipelines/text_to_speech.py b/runner/app/pipelines/text_to_speech.py index 6e29d08b7..7d758d0c0 100644 --- a/runner/app/pipelines/text_to_speech.py +++ b/runner/app/pipelines/text_to_speech.py @@ -1,141 +1,137 @@ import io import logging +import os +import tempfile import torch import soundfile as sf -from transformers import AutoTokenizer from pathlib import Path from app.pipelines.base import Pipeline from app.pipelines.utils import get_model_dir, get_torch_device -from app.utils.errors import InferenceError +from app.utils.errors import ( + InferenceError, + ModelOOMError, + GenerationError, +) logger = logging.getLogger(__name__) class TextToSpeechPipeline(Pipeline): def __init__(self, model_id: str): - """Instantiate a TTS pipeline. + """Instantiate a Chatterbox-based text-to-speech pipeline. - Depending on *model_id*, either the Parler-TTS or Chatterbox model will be - loaded. We treat any model_id that contains the substring "chatterbox" (case - insensitive) as a request for Chatterbox and default to Parler otherwise. + The constructor attempts to load a local checkpoint from the Hugging Face + cache first (to avoid repeated downloads). If none is found, it falls + back to fetching the weights from the Hub. 
All heavy objects live on + the device returned by ``get_torch_device`` (GPU if available). """ self.device = get_torch_device() self.model_id = model_id - self._is_chatterbox = "chatterbox" in model_id.lower() - - if self._is_chatterbox: - try: - from chatterbox.src.chatterbox.tts import ChatterboxTTS # type: ignore - except ModuleNotFoundError: # pragma: no cover – optional - ChatterboxTTS = None # type: ignore - - if ChatterboxTTS is None: - raise ImportError( - "ChatterboxTTS requested but chatterbox package not installed." - ) - - # Try using locally downloaded checkpoints first (downloaded via - # `dl_checkpoints.sh`). - base_cache = Path(get_model_dir()) / "models--ResembleAI--chatterbox" / "snapshots" - ckpt_dir = None - if base_cache.exists(): - # Grab first snapshot (should only be one) as the checkpoint dir. - for d in base_cache.iterdir(): - if d.is_dir(): - ckpt_dir = d - break - - if ckpt_dir is not None and ckpt_dir.exists(): - self.model = ChatterboxTTS.from_local(ckpt_dir, self.device) - else: - # Fallback to HF hub (will also cache under above path). - self.model = ChatterboxTTS.from_pretrained(self.device) - # Sample rate attribute is exposed for audio writing. - self.sample_rate = getattr(self.model, "sr", 44100) + try: + from chatterbox.src.chatterbox.tts import ChatterboxTTS # type: ignore + except ModuleNotFoundError: # pragma: no cover – optional + ChatterboxTTS = None # type: ignore - else: - # Now import Parler-TTS which will use our patched register method - from parler_tts import ParlerTTSForConditionalGeneration # type: ignore - - # Default to Parler-TTS. 
- kwargs = {"cache_dir": get_model_dir()} - - self.model = ParlerTTSForConditionalGeneration.from_pretrained( - model_id, - **kwargs, - ).to(self.device) - - self.tokenizer = AutoTokenizer.from_pretrained( - model_id, - torch_dtype=torch.bfloat16, - attn_implementation="flash_attention_2", - **kwargs, + if ChatterboxTTS is None: + raise ImportError( + "ChatterboxTTS requested but chatterbox package not installed." ) - self.sample_rate = 44100 + + # Try to locate a local checkpoint first. The huggingface cache layout is: + # /models--/snapshots// + safe_model_id = self.model_id.replace("/", "--") + snapshots_root = Path(get_model_dir()) / f"models--{safe_model_id}" / "snapshots" + ckpt_dir: Path | None = None + if snapshots_root.exists(): + # Choose the most recently modified snapshot directory if multiple exist. + snapshot_dirs = [d for d in snapshots_root.iterdir() if d.is_dir()] + if snapshot_dirs: + # Sort by modification time, newest first. + snapshot_dirs.sort(key=lambda p: p.stat().st_mtime, reverse=True) + if len(snapshot_dirs) > 1: + logger.warning( + "Multiple snapshots found for %s; using the most recent: %s", + self.model_id, + snapshot_dirs[0], + ) + ckpt_dir = snapshot_dirs[0] + + if ckpt_dir and ckpt_dir.exists(): + logger.info("Loading ChatterboxTTS from local checkpoint: %s", ckpt_dir) + self.model = ChatterboxTTS.from_local(ckpt_dir, self.device) + else: + logger.info("No local checkpoint found for %s — downloading from HuggingFace Hub.", self.model_id) + self.model = ChatterboxTTS.from_pretrained(self.device) + + # Sample rate attribute is exposed for audio writing. 
+ self.sample_rate = getattr(self.model, "sr", 44100) + if not hasattr(self.model, "sr"): + logger.warning("Chatterbox model does not expose 'sr'; defaulting to 44100 Hz.") def _generate_speech( self, text: str, - tts_steering: str, - audio_prompt_path: str | None = None, + audio_prompt_bytes: bytes | None = None, ) -> io.BytesIO: - """Generate speech from text input using the text-to-speech model. + """Generate speech audio from text. Args: - text: Text input for speech generation. - tts_steering: Description of speaker to steer text to speech generation. - audio_prompt_path: Optional path to an audio prompt. + text: Input text to synthesise. + audio_prompt_path: Optional path to a reference audio clip used by the + model for voice cloning / style transfer. Returns: - buffer: BytesIO buffer containing the generated audio. + BytesIO: In-memory WAV file containing the generated speech. """ - # Branch based on backend. - if self._is_chatterbox: - with torch.no_grad(): - # Chatterbox generate returns a tensor with shape (1, n) or (n,) - wav = self.model.generate(text, audio_prompt_path=audio_prompt_path) - audio = wav.squeeze(0).cpu().numpy() if torch.is_tensor(wav) else wav - - buffer = io.BytesIO() - sf.write(buffer, audio, samplerate=self.sample_rate, format="WAV") - buffer.seek(0) - else: - with torch.no_grad(): - input_ids = self.tokenizer(tts_steering, return_tensors="pt").input_ids.to( - self.device - ) - prompt_input_ids = self.tokenizer(text, return_tensors="pt").input_ids.to( - self.device - ) - - generation = self.model.generate( - input_ids=input_ids, prompt_input_ids=prompt_input_ids - ) - generated_audio = generation.cpu().numpy().squeeze() - - buffer = io.BytesIO() - sf.write(buffer, generated_audio, samplerate=self.sample_rate, format="WAV") - buffer.seek(0) + with torch.no_grad(): + # The Chatterbox generate method returns either a Tensor of shape + # (1, n) or a NumPy-like 1-D array representing audio samples. 
+ if audio_prompt_bytes: + with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_audio: + tmp_audio.write(audio_prompt_bytes) + temp_path = tmp_audio.name + try: + wav = self.model.generate(text, audio_prompt_path=temp_path) + except Exception as e: + raise GenerationError(original_exception=e) from e + finally: + try: + os.remove(temp_path) + except OSError: + pass + else: + try: + wav = self.model.generate(text) + except Exception as e: + raise GenerationError(original_exception=e) from e + audio = wav.squeeze(0).cpu().numpy() if torch.is_tensor(wav) else wav + + buffer = io.BytesIO() + sf.write(buffer, audio, samplerate=self.sample_rate, format="WAV") + buffer.seek(0) return buffer def __call__(self, params) -> io.BytesIO: try: + audio_prompt_bytes: bytes | None = getattr(params, "audio_prompt_base64", None) output = self._generate_speech( params.text, - params.description, - getattr(params, "audio_prompt_path", None), + audio_prompt_bytes=audio_prompt_bytes, ) except torch.cuda.OutOfMemoryError as e: - raise e + # Translate low-level OOM into domain-specific error. + raise ModelOOMError(original_exception=e) from e except Exception as e: - raise InferenceError(original_exception=e) + # If it's already a GenerationError, bubble up unchanged; otherwise wrap. 
+ if isinstance(e, GenerationError): + raise + raise GenerationError(original_exception=e) from e return output def __str__(self) -> str: - backend = "Chatterbox" if self._is_chatterbox else "ParlerTTS" - return f"TextToSpeechPipeline({backend}) model_id={self.model_id}" + return f"TextToSpeechPipeline(Chatterbox) model_id={self.model_id}" diff --git a/runner/app/routes/text_to_speech.py b/runner/app/routes/text_to_speech.py index a96c6f54b..f3efbae1c 100644 --- a/runner/app/routes/text_to_speech.py +++ b/runner/app/routes/text_to_speech.py @@ -1,11 +1,10 @@ +import base64 import logging import os import time -from typing import Annotated, Dict, Tuple, Union -from pathlib import Path - +from typing import Annotated, Dict, Tuple, Union, Optional import torch -from fastapi import APIRouter, Depends, status, UploadFile, File, Form +from fastapi import APIRouter, Depends, status from fastapi.responses import JSONResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from pydantic import BaseModel, Field, validator @@ -16,7 +15,6 @@ AudioResponse, HTTPError, audio_to_data_url, - file_exceeds_max_size, handle_pipeline_exception, http_error, ) @@ -27,7 +25,6 @@ # ---------------- Validation limits ---------------- MAX_TEXT_LEN = 10000 # maximum characters for text -MAX_DESC_LEN = 1000 # maximum characters for description MAX_AUDIO_BYTES = 10 * 1024 * 1024 # 10 MB # Pipeline specific error handling configuration. @@ -40,13 +37,14 @@ } class TextToSpeechParams(BaseModel): + # TODO: Make model_id and other None properties optional once Go codegen tool # supports OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373 model_id: Annotated[ - str, + Optional[str], Field( - default="ResembleAI/chatterbox", - description="Hugging Face model ID used for text to speech generation.", + default=None, + description="Optional Hugging Face model ID for text-to-speech generation. 
If omitted, the pipeline’s configured model is used.", ), ] text: Annotated[ @@ -61,26 +59,16 @@ class TextToSpeechParams(BaseModel): description="Text input for speech generation.", ), ] - description: Annotated[ - str | None, - Field( - default=( - "A male speaker delivers a slightly expressive and animated speech " - "with a moderate speed and pitch." - ), - description=("Description of speaker to steer text to speech generation."), - ), - ] - audio_prompt_path: Annotated[ - str | None, + audio_prompt_base64: Annotated[ + bytes | None, Field( default=None, description=( - "Optional path or URL to a reference audio clip for voice cloning (only used when model_id refers to Chatterbox)." + "Optional base64-encoded audio data for voice cloning reference. Provide as base64-encoded string; it will be decoded server-side. Must be a valid audio file format like WAV or MP3." ), ), ] - + @validator("text") def validate_text_length(cls, v): if not v.strip(): @@ -89,11 +77,24 @@ def validate_text_length(cls, v): raise ValueError(f"text exceeds {MAX_TEXT_LEN} characters") return v - @validator("description") - def validate_description_length(cls, v): - if len(v) > MAX_DESC_LEN: - raise ValueError(f"description exceeds {MAX_DESC_LEN} characters") - return v + @validator("audio_prompt_base64", pre=True) + def validate_and_decode_audio_prompt(cls, v): + """Decode base64 audio once during validation and enforce size limits.""" + if v is None: + return None + try: + # Accept already-bytes input for flexibility. 
+ if isinstance(v, (bytes, bytearray)): + decoded = bytes(v) + else: + decoded = base64.b64decode(v) + if len(decoded) > MAX_AUDIO_BYTES: + raise ValueError( + f"decoded audio data exceeds {MAX_AUDIO_BYTES / (1024 * 1024):.1f} MB" + ) + return decoded + except Exception as e: + raise ValueError(f"invalid base64 audio data: {e}") from e RESPONSES = { @@ -112,18 +113,7 @@ def validate_description_length(cls, v): } -def _multipart_params( - model_id: str = Form(""), - text: str = Form(""), - description: str = Form( - "A male speaker delivers a slightly expressive and animated speech with a moderate speed and pitch." - ), -) -> "TextToSpeechParams": - return TextToSpeechParams( - model_id=model_id, - text=text, - description=description, - ) + @router.post( @@ -131,8 +121,8 @@ def _multipart_params( response_model=AudioResponse, responses=RESPONSES, description=( - "Generate a text-to-speech audio file based on the provided text input and " - "speaker description." + "Generate a text-to-speech audio file based on the provided text input." + "Optionally include base64-encoded audio for voice cloning." 
), operation_id="genTextToSpeech", summary="Text To Speech", @@ -146,8 +136,7 @@ def _multipart_params( include_in_schema=False, ) async def text_to_speech( - params: TextToSpeechParams = Depends(_multipart_params), - audio_prompt: UploadFile | None = File(None, description="Optional reference audio file for Chatterbox voice cloning"), + params: TextToSpeechParams, pipeline: Pipeline = Depends(get_pipeline), token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)), ): @@ -169,36 +158,22 @@ async def text_to_speech( content=http_error("Invalid bearer token."), ) - # Check uploaded file size early - if audio_prompt is not None and file_exceeds_max_size(audio_prompt, MAX_AUDIO_BYTES): - return JSONResponse( - status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, - content=http_error("audio_prompt too large; 10 MB max"), - ) - if params.model_id != "" and params.model_id != pipeline.model_id: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content=http_error( - f"pipeline configured with {pipeline.model_id} but called with " - f"{params.model_id}" - ), + + # If a model_id is supplied and differs from the pipeline’s current model, log a warning. + if params.model_id and params.model_id != pipeline.model_id: + logger.warning( + "Requested model_id %s differs from pipeline model_id %s — proceeding with current pipeline.", + params.model_id, + pipeline.model_id, ) - # Handle uploaded audio prompt. 
- temp_path = None - if audio_prompt is not None: - suffix = Path(audio_prompt.filename or "audio").suffix or ".wav" - import tempfile - with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp: - tmp.write(await audio_prompt.read()) - temp_path = tmp.name - params.audio_prompt_path = temp_path # override try: - start = time.time() - out = pipeline(params) - logger.info(f"TextToSpeechPipeline took {time.time() - start} seconds.") + start_time = time.time() + output = pipeline(params) + end_time = time.time() + logger.info(f"TextToSpeechPipeline took {end_time - start_time} seconds.") except Exception as e: if isinstance(e, torch.cuda.OutOfMemoryError): # TODO: Investigate why not all VRAM memory is cleared. @@ -209,11 +184,6 @@ async def text_to_speech( default_error_message="Text-to-speech pipeline error.", custom_error_config=PIPELINE_ERROR_CONFIG, ) - finally: - if temp_path: - try: - os.remove(temp_path) - except OSError: - pass - - return {"audio": {"url": audio_to_data_url(out)}} + + + return {"audio": {"url": audio_to_data_url(output)}} diff --git a/runner/app/routes/utils.py b/runner/app/routes/utils.py index 5ed185f8b..98822a266 100644 --- a/runner/app/routes/utils.py +++ b/runner/app/routes/utils.py @@ -207,30 +207,6 @@ def audio_to_data_url(buffer: io.BytesIO, format: str = "wav") -> str: return f"data:audio/{format};base64,{base64_audio}" -def file_exceeds_max_size( - input_file: UploadFile, max_size: int = 10 * 1024 * 1024 -) -> bool: - """Checks if the uploaded file exceeds the specified maximum size. - - Args: - input_file: The uploaded file to check. - max_size: The maximum allowed file size in bytes. Defaults to 10 MB. - - Returns: - True if the file exceeds the maximum size, False otherwise. - """ - try: - if input_file.file: - # Get size by moving the cursor to the end of the file and back. 
- input_file.file.seek(0, os.SEEK_END) - file_size = input_file.file.tell() - input_file.file.seek(0) - return file_size > max_size - except Exception as e: - print(f"Error checking file size: {e}") - return False - - def json_str_to_np_array( data: Optional[str], var_name: Optional[str] = None ) -> Optional[np.ndarray]: diff --git a/runner/app/utils/errors.py b/runner/app/utils/errors.py index 12650bd1a..70e852f72 100644 --- a/runner/app/utils/errors.py +++ b/runner/app/utils/errors.py @@ -17,10 +17,6 @@ def __init__(self, message="Error during model execution", original_exception=No self.original_exception = original_exception -class InvalidInputError(InferenceError): - """Exception raised when input validation fails.""" - pass - class ModelOOMError(InferenceError): """Exception raised when the model runs out of memory.""" diff --git a/runner/dl_checkpoints.sh b/runner/dl_checkpoints.sh index e2a068b57..167ec7178 100755 --- a/runner/dl_checkpoints.sh +++ b/runner/dl_checkpoints.sh @@ -50,7 +50,6 @@ function download_beta_models() { # Download custom pipeline models. huggingface-cli download facebook/sam2-hiera-large --include "*.pt" "*.yaml" --cache-dir models - huggingface-cli download parler-tts/parler-tts-large-v1 --include "*.safetensors" "*.json" "*.model" --cache-dir models huggingface-cli download ResembleAI/chatterbox --include "*.safetensors" "*.json" "*.model" --cache-dir models printf "\nDownloading token-gated models...\n" diff --git a/runner/docker/Dockerfile.text_to_speech b/runner/docker/Dockerfile.text_to_speech index ad7458bcf..00a70225e 100644 --- a/runner/docker/Dockerfile.text_to_speech +++ b/runner/docker/Dockerfile.text_to_speech @@ -1,22 +1,22 @@ ARG BASE_IMAGE=livepeer/ai-runner:base FROM ${BASE_IMAGE} -# Install CUDA Toolkit to enable flash attention. +# Install CUDA Toolkit to enable flash attention and audio processing dependencies. 
RUN apt-get update && \ apt-get install -y --no-install-recommends \ cuda-toolkit-12-1 \ - g++ && \ + g++ \ + libsndfile1 \ + libsndfile1-dev && \ rm -rf /var/lib/apt/lists/* -RUN pip install --no-cache-dir \ - ninja \ - transformers==4.43.3 \ - peft \ - deepcache \ - soundfile \ - g2p-en \ - flash_attn==2.5.6 \ - git+https://github.com/huggingface/parler-tts.git@5d0aca9753ab74ded179732f5bd797f7a8c6f8ee +# Clone Chatterbox, checkout specific commit, create pyproject.toml, and install it +RUN git clone https://huggingface.co/spaces/ResembleAI/Chatterbox /opt/Chatterbox && \ + cd /opt/Chatterbox && \ + git checkout bf4bbc30226326884e0d57b1e45ec0550683300f && \ + echo '[build-system]\nrequires = ["setuptools"]\nbuild-backend = "setuptools.build_meta"\n\n[project]\nname = "chatterbox"\nversion = "0.1.0"' > pyproject.toml && \ + pip install --no-cache-dir -e . && \ + pip install --no-cache-dir -r requirements.txt # Override base working directory to ensure the correct working directory. WORKDIR /app diff --git a/runner/docker/Dockerfile.text_to_speech:chatterbox b/runner/docker/Dockerfile.text_to_speech:chatterbox deleted file mode 100644 index 025d4b0a4..000000000 --- a/runner/docker/Dockerfile.text_to_speech:chatterbox +++ /dev/null @@ -1,25 +0,0 @@ -ARG BASE_IMAGE=livepeer/ai-runner:base -FROM ${BASE_IMAGE} - -# Install CUDA Toolkit to enable flash attention. -RUN apt-get update && \ - apt-get install -y --no-install-recommends \ - cuda-toolkit-12-1 \ - g++ && \ - rm -rf /var/lib/apt/lists/* - -# Clone Chatterbox, checkout specific commit, create pyproject.toml, and install it -RUN git clone https://huggingface.co/spaces/ResembleAI/Chatterbox /opt/Chatterbox && \ - cd /opt/Chatterbox && \ - git checkout bf4bbc30226326884e0d57b1e45ec0550683300f && \ - echo '[build-system]\nrequires = ["setuptools"]\nbuild-backend = "setuptools.build_meta"\n\n[project]\nname = "chatterbox"\nversion = "0.1.0"' > pyproject.toml && \ - pip install --no-cache-dir -e . 
&& \ - pip install --no-cache-dir -r requirements.txt - -# Override base working directory to ensure the correct working directory. -WORKDIR /app - -# Copy app directory to avoid rebuilding the base image during development. -COPY app/ /app/app - -CMD ["uvicorn", "app.main:app", "--log-config", "app/cfg/uvicorn_logging_config.json", "--host", "", "--port", "8000"] From 46e96365923b52dddb71c04916efc4938752a664 Mon Sep 17 00:00:00 2001 From: Jason Stone Date: Fri, 20 Jun 2025 19:19:59 +0000 Subject: [PATCH 3/8] feat: update text-to-speech API to use base64 encoded audio prompts --- runner/app/routes/audio_to_text.py | 7 ------- runner/app/routes/utils.py | 1 - 2 files changed, 8 deletions(-) diff --git a/runner/app/routes/audio_to_text.py b/runner/app/routes/audio_to_text.py index d0e1b8470..0e2d8fb3b 100644 --- a/runner/app/routes/audio_to_text.py +++ b/runner/app/routes/audio_to_text.py @@ -12,7 +12,6 @@ from app.routes.utils import ( HTTPError, TextResponse, - file_exceeds_max_size, get_media_duration_ffmpeg, handle_pipeline_exception, http_error, @@ -134,12 +133,6 @@ async def audio_to_text( ), ) - if file_exceeds_max_size(audio, 50 * 1024 * 1024): - return JSONResponse( - status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, - content=http_error("File size exceeds limit."), - ) - try: duration = parse_key_from_metadata(metadata, "duration", float) if duration is None: diff --git a/runner/app/routes/utils.py b/runner/app/routes/utils.py index 98822a266..c44a8dab7 100644 --- a/runner/app/routes/utils.py +++ b/runner/app/routes/utils.py @@ -206,7 +206,6 @@ def audio_to_data_url(buffer: io.BytesIO, format: str = "wav") -> str: base64_audio = base64.b64encode(buffer.read()).decode("utf-8") return f"data:audio/{format};base64,{base64_audio}" - def json_str_to_np_array( data: Optional[str], var_name: Optional[str] = None ) -> Optional[np.ndarray]: From cbf6ece3216fc4ac90dda7d51231b1437616b43b Mon Sep 17 00:00:00 2001 From: Jason Stone Date: Fri, 20 Jun 2025 19:37:27 
+0000 Subject: [PATCH 4/8] feat: add file size limit check to audio_to_text endpoint --- runner/app/routes/audio_to_text.py | 7 ++++ runner/app/routes/utils.py | 25 ++++++++++++ runner/gateway.openapi.yaml | 64 ++++++++++++++++-------------- runner/openapi.yaml | 62 +++++++++++++++-------------- 4 files changed, 100 insertions(+), 58 deletions(-) diff --git a/runner/app/routes/audio_to_text.py b/runner/app/routes/audio_to_text.py index 0e2d8fb3b..d0e1b8470 100644 --- a/runner/app/routes/audio_to_text.py +++ b/runner/app/routes/audio_to_text.py @@ -12,6 +12,7 @@ from app.routes.utils import ( HTTPError, TextResponse, + file_exceeds_max_size, get_media_duration_ffmpeg, handle_pipeline_exception, http_error, @@ -133,6 +134,12 @@ async def audio_to_text( ), ) + if file_exceeds_max_size(audio, 50 * 1024 * 1024): + return JSONResponse( + status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, + content=http_error("File size exceeds limit."), + ) + try: duration = parse_key_from_metadata(metadata, "duration", float) if duration is None: diff --git a/runner/app/routes/utils.py b/runner/app/routes/utils.py index c44a8dab7..5ed185f8b 100644 --- a/runner/app/routes/utils.py +++ b/runner/app/routes/utils.py @@ -206,6 +206,31 @@ def audio_to_data_url(buffer: io.BytesIO, format: str = "wav") -> str: base64_audio = base64.b64encode(buffer.read()).decode("utf-8") return f"data:audio/{format};base64,{base64_audio}" + +def file_exceeds_max_size( + input_file: UploadFile, max_size: int = 10 * 1024 * 1024 +) -> bool: + """Checks if the uploaded file exceeds the specified maximum size. + + Args: + input_file: The uploaded file to check. + max_size: The maximum allowed file size in bytes. Defaults to 10 MB. + + Returns: + True if the file exceeds the maximum size, False otherwise. + """ + try: + if input_file.file: + # Get size by moving the cursor to the end of the file and back. 
+ input_file.file.seek(0, os.SEEK_END) + file_size = input_file.file.tell() + input_file.file.seek(0) + return file_size > max_size + except Exception as e: + print(f"Error checking file size: {e}") + return False + + def json_str_to_np_array( data: Optional[str], var_name: Optional[str] = None ) -> Optional[np.ndarray]: diff --git a/runner/gateway.openapi.yaml b/runner/gateway.openapi.yaml index 8dd2ecb80..9c68bcdae 100644 --- a/runner/gateway.openapi.yaml +++ b/runner/gateway.openapi.yaml @@ -466,15 +466,14 @@ paths: - generate summary: Text To Speech description: Generate a text-to-speech audio file based on the provided text - input and speaker description. + input.Optionally include base64-encoded audio for voice cloning. operationId: genTextToSpeech requestBody: content: - multipart/form-data: + application/json: schema: - allOf: - - $ref: '#/components/schemas/Body_genTextToSpeech' - title: Body + $ref: '#/components/schemas/TextToSpeechParams' + required: true responses: '200': description: Successful Response @@ -770,30 +769,6 @@ components: - image - model_id title: Body_genSegmentAnything2 - Body_genTextToSpeech: - properties: - audio_prompt: - anyOf: - - type: string - format: binary - - type: 'null' - title: Audio Prompt - description: Optional reference audio file for Chatterbox voice cloning - model_id: - type: string - title: Model Id - default: '' - text: - type: string - title: Text - default: '' - description: - type: string - title: Description - default: A male speaker delivers a slightly expressive and animated speech - with a moderate speed and pitch. 
- type: object - title: Body_genTextToSpeech Body_genUpscale: properties: prompt: @@ -1234,6 +1209,37 @@ components: - prompt - model_id title: TextToImageParams + TextToSpeechParams: + properties: + model_id: + anyOf: + - type: string + - type: 'null' + title: Model Id + description: "Optional Hugging Face model ID for text-to-speech generation.\ + \ If omitted, the pipeline\u2019s configured model is used." + text: + type: string + title: Text + description: Text input for speech generation. + default: When it was all over, the remaing animals, except for the pigs + and dogs, crept away in a body. They were shaken and miserable. They did + not know which was more shocking - the treachery of the animals who had + leagued themselves with Snowball, or the cruel retribution they had just + witnessed. + audio_prompt_base64: + anyOf: + - type: string + format: binary + - type: 'null' + title: Audio Prompt Base64 + description: Optional base64-encoded audio data for voice cloning reference. + Provide as base64-encoded string; it will be decoded server-side. Must + be a valid audio file format like WAV or MP3. + type: object + title: TextToSpeechParams + required: + - model_id ValidationError: properties: loc: diff --git a/runner/openapi.yaml b/runner/openapi.yaml index a25cd416d..d1a322238 100644 --- a/runner/openapi.yaml +++ b/runner/openapi.yaml @@ -466,15 +466,14 @@ paths: - generate summary: Text To Speech description: Generate a text-to-speech audio file based on the provided text - input and speaker description. + input.Optionally include base64-encoded audio for voice cloning. 
operationId: genTextToSpeech requestBody: content: - multipart/form-data: + application/json: schema: - allOf: - - $ref: '#/components/schemas/Body_genTextToSpeech' - title: Body + $ref: '#/components/schemas/TextToSpeechParams' + required: true responses: '200': description: Successful Response @@ -814,30 +813,6 @@ components: required: - image title: Body_genSegmentAnything2 - Body_genTextToSpeech: - properties: - audio_prompt: - anyOf: - - type: string - format: binary - - type: 'null' - title: Audio Prompt - description: Optional reference audio file for Chatterbox voice cloning - model_id: - type: string - title: Model Id - default: '' - text: - type: string - title: Text - default: '' - description: - type: string - title: Description - default: A male speaker delivers a slightly expressive and animated speech - with a moderate speed and pitch. - type: object - title: Body_genTextToSpeech Body_genUpscale: properties: prompt: @@ -1390,6 +1365,35 @@ components: required: - prompt title: TextToImageParams + TextToSpeechParams: + properties: + model_id: + anyOf: + - type: string + - type: 'null' + title: Model Id + description: "Optional Hugging Face model ID for text-to-speech generation.\ + \ If omitted, the pipeline\u2019s configured model is used." + text: + type: string + title: Text + description: Text input for speech generation. + default: When it was all over, the remaing animals, except for the pigs + and dogs, crept away in a body. They were shaken and miserable. They did + not know which was more shocking - the treachery of the animals who had + leagued themselves with Snowball, or the cruel retribution they had just + witnessed. + audio_prompt_base64: + anyOf: + - type: string + format: binary + - type: 'null' + title: Audio Prompt Base64 + description: Optional base64-encoded audio data for voice cloning reference. + Provide as base64-encoded string; it will be decoded server-side. Must + be a valid audio file format like WAV or MP3. 
+ type: object title: TextToSpeechParams ValidationError: properties: loc: From 39cea1831f2f5f8efafde943ec831a29fbf426d1 Mon Sep 17 00:00:00 2001 From: Jason Stone Date: Fri, 20 Jun 2025 19:39:36 +0000 Subject: [PATCH 5/8] chore: fix a small formatting issue --- runner/app/routes/text_to_speech.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runner/app/routes/text_to_speech.py b/runner/app/routes/text_to_speech.py index f3efbae1c..21d2145c1 100644 --- a/runner/app/routes/text_to_speech.py +++ b/runner/app/routes/text_to_speech.py @@ -44,7 +44,7 @@ class TextToSpeechParams(BaseModel): Optional[str], Field( default=None, - description="Optional Hugging Face model ID for text-to-speech generation. If omitted, the pipeline’s configured model is used.", + description="Optional Hugging Face model ID for text-to-speech generation. If omitted, the pipelines configured model is used.", ), ] text: Annotated[ From 5a276427a5889d5ea39cf8dbf075baae4afd79ae Mon Sep 17 00:00:00 2001 From: Jason Stone Date: Fri, 20 Jun 2025 19:40:46 +0000 Subject: [PATCH 6/8] chore: regenerate api bindings --- runner/gateway.openapi.yaml | 4 ++-- runner/openapi.yaml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/runner/gateway.openapi.yaml b/runner/gateway.openapi.yaml index 9c68bcdae..3580c3df4 100644 --- a/runner/gateway.openapi.yaml +++ b/runner/gateway.openapi.yaml @@ -1216,8 +1216,8 @@ components: - type: string - type: 'null' title: Model Id - description: "Optional Hugging Face model ID for text-to-speech generation.\ - \ If omitted, the pipeline\u2019s configured model is used." + description: Optional Hugging Face model ID for text-to-speech generation. + If omitted, the pipelines configured model is used.
text: type: string title: Text diff --git a/runner/openapi.yaml b/runner/openapi.yaml index d1a322238..85b700ae6 100644 --- a/runner/openapi.yaml +++ b/runner/openapi.yaml @@ -1372,8 +1372,8 @@ components: - type: string - type: 'null' title: Model Id - description: "Optional Hugging Face model ID for text-to-speech generation.\ - \ If omitted, the pipeline\u2019s configured model is used." + description: Optional Hugging Face model ID for text-to-speech generation. + If omitted, the pipelines configured model is used. text: type: string title: Text From ea261424362fec1d8bbca3efb8825d9a0f6650c5 Mon Sep 17 00:00:00 2001 From: Jason Stone Date: Sat, 21 Jun 2025 15:49:42 +0000 Subject: [PATCH 7/8] refactor: update text-to-speech API params and add async request handling --- runner/app/routes/text_to_speech.py | 18 +++++++----------- runner/gateway.openapi.yaml | 18 ++++++------------ runner/openapi.yaml | 18 ++++++------------ 3 files changed, 19 insertions(+), 35 deletions(-) diff --git a/runner/app/routes/text_to_speech.py b/runner/app/routes/text_to_speech.py index 21d2145c1..4a2959b82 100644 --- a/runner/app/routes/text_to_speech.py +++ b/runner/app/routes/text_to_speech.py @@ -6,6 +6,7 @@ import torch from fastapi import APIRouter, Depends, status from fastapi.responses import JSONResponse +from fastapi.concurrency import run_in_threadpool from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from pydantic import BaseModel, Field, validator @@ -41,28 +42,23 @@ class TextToSpeechParams(BaseModel): # TODO: Make model_id and other None properties optional once Go codegen tool # supports OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373 model_id: Annotated[ - Optional[str], + str, Field( - default=None, + default="", description="Optional Hugging Face model ID for text-to-speech generation. 
If omitted, the pipelines configured model is used.", ), ] text: Annotated[ str, Field( - default=( - "When it was all over, the remaing animals, except for the pigs and dogs, " - "crept away in a body. They were shaken and miserable. They did not know " - "which was more shocking - the treachery of the animals who had leagued " - "themselves with Snowball, or the cruel retribution they had just witnessed." - ), + default="Hi, there my name is AI.", description="Text input for speech generation.", ), ] audio_prompt_base64: Annotated[ - bytes | None, + bytes, Field( - default=None, + default="", description=( "Optional base64-encoded audio data for voice cloning reference. Provide as base64-encoded string; it will be decoded server-side. Must be a valid audio file format like WAV or MP3." ), @@ -171,7 +167,7 @@ async def text_to_speech( try: start_time = time.time() - output = pipeline(params) + output = await run_in_threadpool(pipeline, params) end_time = time.time() logger.info(f"TextToSpeechPipeline took {end_time - start_time} seconds.") except Exception as e: diff --git a/runner/gateway.openapi.yaml b/runner/gateway.openapi.yaml index 3580c3df4..d3fad2299 100644 --- a/runner/gateway.openapi.yaml +++ b/runner/gateway.openapi.yaml @@ -1212,30 +1212,24 @@ components: TextToSpeechParams: properties: model_id: - anyOf: - - type: string - - type: 'null' + type: string title: Model Id description: Optional Hugging Face model ID for text-to-speech generation. If omitted, the pipelines configured model is used. + default: '' text: type: string title: Text description: Text input for speech generation. - default: When it was all over, the remaing animals, except for the pigs - and dogs, crept away in a body. They were shaken and miserable. They did - not know which was more shocking - the treachery of the animals who had - leagued themselves with Snowball, or the cruel retribution they had just - witnessed. + default: Hi, there my name is AI. 
audio_prompt_base64: - anyOf: - - type: string - format: binary - - type: 'null' + type: string + format: binary title: Audio Prompt Base64 description: Optional base64-encoded audio data for voice cloning reference. Provide as base64-encoded string; it will be decoded server-side. Must be a valid audio file format like WAV or MP3. + default: '' type: object title: TextToSpeechParams required: diff --git a/runner/openapi.yaml b/runner/openapi.yaml index 85b700ae6..61b49523c 100644 --- a/runner/openapi.yaml +++ b/runner/openapi.yaml @@ -1368,30 +1368,24 @@ components: TextToSpeechParams: properties: model_id: - anyOf: - - type: string - - type: 'null' + type: string title: Model Id description: Optional Hugging Face model ID for text-to-speech generation. If omitted, the pipelines configured model is used. + default: '' text: type: string title: Text description: Text input for speech generation. - default: When it was all over, the remaing animals, except for the pigs - and dogs, crept away in a body. They were shaken and miserable. They did - not know which was more shocking - the treachery of the animals who had - leagued themselves with Snowball, or the cruel retribution they had just - witnessed. + default: Hi, there my name is AI. audio_prompt_base64: - anyOf: - - type: string - format: binary - - type: 'null' + type: string + format: binary title: Audio Prompt Base64 description: Optional base64-encoded audio data for voice cloning reference. Provide as base64-encoded string; it will be decoded server-side. Must be a valid audio file format like WAV or MP3. 
+ default: '' type: object title: TextToSpeechParams ValidationError: From d84d607249a61c8a0aade78a6bcc553abf30bbb2 Mon Sep 17 00:00:00 2001 From: Jason Stone Date: Sun, 22 Jun 2025 12:47:41 +0000 Subject: [PATCH 8/8] minor change, as the current version of huggingface in the chatterbox expects the weights of the models to be in .pt format and would throw an error if they were in .safetensors. It does download the weights automatically if none are found --- runner/dl_checkpoints.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runner/dl_checkpoints.sh b/runner/dl_checkpoints.sh index 167ec7178..0b4d8904f 100755 --- a/runner/dl_checkpoints.sh +++ b/runner/dl_checkpoints.sh @@ -50,7 +50,7 @@ function download_beta_models() { # Download custom pipeline models. huggingface-cli download facebook/sam2-hiera-large --include "*.pt" "*.yaml" --cache-dir models - huggingface-cli download ResembleAI/chatterbox --include "*.safetensors" "*.json" "*.model" --cache-dir models + huggingface-cli download ResembleAI/chatterbox --include "*.pt" "*.json" "*.model" --cache-dir models printf "\nDownloading token-gated models...\n"