diff --git a/runner/app/pipelines/image_to_video.py b/runner/app/pipelines/image_to_video.py index 96841ba16..ec1a854ef 100644 --- a/runner/app/pipelines/image_to_video.py +++ b/runner/app/pipelines/image_to_video.py @@ -1,11 +1,12 @@ import logging +import inspect import os import time -from typing import List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Type import PIL import torch -from diffusers import StableVideoDiffusionPipeline +from diffusers import DiffusionPipeline, LTXImageToVideoPipeline, StableVideoDiffusionPipeline from huggingface_hub import file_download from PIL import ImageFile @@ -22,6 +23,8 @@ class ImageToVideoPipeline(Pipeline): def __init__(self, model_id: str): + self.pipeline_name = "" + self.model_id = model_id kwargs = {"cache_dir": get_model_dir()} @@ -41,8 +44,21 @@ def __init__(self, model_id: str): kwargs["torch_dtype"] = torch.float16 kwargs["variant"] = "fp16" - self.ldm = StableVideoDiffusionPipeline.from_pretrained(model_id, **kwargs) - self.ldm.to(get_torch_device()) + logger.info("Loading DiffusionPipeline for model_id: %s", model_id) + self.ldm = DiffusionPipeline.from_pretrained(model_id, **kwargs) + + if any(substring in model_id.lower() for substring in ("ltx-video", "ltx")): + logger.info("Adjusting to LTXImageToVideoPipeline for model_id: %s", model_id) + self.ldm = LTXImageToVideoPipeline.from_pipe(self.ldm) + self.ldm.enable_model_cpu_offload() + self.ldm.vae.enable_slicing() + LOW_VRAM = os.getenv("USE_LOW_VRAM", "false") + if LOW_VRAM == "true": + self.ldm.enable_sequential_cpu_offload() + else: + self.ldm.to(get_torch_device()) + + self.pipeline_name = type(self.ldm).__name__ sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true" deepcache_enabled = os.getenv("DEEPCACHE", "").strip().lower() == "true" @@ -52,7 +68,9 @@ def __init__(self, model_id: str): "as it may lead to suboptimal performance. Please disable one of them." 
) - if sfast_enabled: + if sfast_enabled and self.pipeline_name == "LTXImageToVideoPipeline": + logger.warning("StableFast optimization is not compatible with LTXImageToVideoPipeline so, skipping.") + elif sfast_enabled: logger.info( "ImageToVideoPipeline will be dynamically compiled with stable-fast " "for %s", model_id, ) @@ -95,9 +113,11 @@ def __init__(self, model_id: str): ) logger.info("Total warmup time: %s seconds", total_time) - if deepcache_enabled: + if deepcache_enabled and self.pipeline_name == "LTXImageToVideoPipeline": + logger.warning("DeepCache optimization is not compatible with LTXImageToVideoPipeline so, skipping.") + elif deepcache_enabled: logger.info( - "TextToImagePipeline will be optimized with DeepCache for %s", + "ImageToVideoPipeline will be optimized with DeepCache for %s", model_id, ) from app.pipelines.optim.deepcache import enable_deepcache @@ -132,6 +152,13 @@ def __call__( ): del kwargs["num_inference_steps"] + if self.pipeline_name == "LTXImageToVideoPipeline": + pipeline_class = LTXImageToVideoPipeline + elif self.pipeline_name == "StableVideoDiffusionPipeline": + pipeline_class = StableVideoDiffusionPipeline + + kwargs = self._filter_valid_kwargs(pipeline_class, kwargs) + if safety_check: _, has_nsfw_concept = self._safety_checker.check_nsfw_images([image]) else: @@ -146,5 +173,14 @@ def __call__( return outputs.frames, has_nsfw_concept + @staticmethod + def _filter_valid_kwargs(pipeline_class: Type, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """ + Filters the kwargs to just include keys that are necessary for the pipeline_class. 
+ """ + + valid_kwargs = inspect.signature(pipeline_class.__call__).parameters.keys() + return {k: v for k, v in kwargs.items() if k in valid_kwargs} + def __str__(self) -> str: return f"ImageToVideoPipeline model_id={self.model_id}" diff --git a/runner/app/routes/image_to_video.py b/runner/app/routes/image_to_video.py index eb64a3ef1..7a410c6a3 100644 --- a/runner/app/routes/image_to_video.py +++ b/runner/app/routes/image_to_video.py @@ -74,6 +74,19 @@ async def image_to_video( UploadFile, File(description="Uploaded image to generate a video from."), ], + prompt: Annotated[ + str, + Form(description="Text prompt(s) to guide video generation for prompt accepting models.") + ] = "", + negative_prompt: Annotated[ + str, + Form( + description=( + "Text prompt(s) to guide what to exclude from video generation for prompt accepting models. " + "Ignored if guidance_scale < 1." + ) + ), + ] = "", model_id: Annotated[ str, Form(description="Hugging Face model ID used for video generation.") ] = "", @@ -123,6 +136,9 @@ async def image_to_video( ) ), ] = 25, # NOTE: Hardcoded due to varying pipeline values. + num_frames: Annotated[ + int, Form(description="The number of video frames to generate.") + ] = 25, # NOTE: Added `25` as default value to consider for `stable-video-diffusion-img2vid-xt` model having smaller default value than LTX-V in its pipeline. 
pipeline: Pipeline = Depends(get_pipeline), token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)), ): @@ -159,6 +175,9 @@ async def image_to_video( try: batch_frames, has_nsfw_concept = pipeline( image=Image.open(image.file).convert("RGB"), + prompt=prompt, + negative_prompt=negative_prompt, + num_frames=num_frames, height=height, width=width, fps=fps, diff --git a/runner/dl_checkpoints.sh b/runner/dl_checkpoints.sh index dfc4ff46c..b1da6b9be 100755 --- a/runner/dl_checkpoints.sh +++ b/runner/dl_checkpoints.sh @@ -77,8 +77,9 @@ function download_all_models() { huggingface-cli download SG161222/Realistic_Vision_V6.0_B1_noVAE --include "*.fp16.safetensors" "*.json" "*.txt" "*.bin" --exclude ".onnx" ".onnx_data" --cache-dir models huggingface-cli download black-forest-labs/FLUX.1-schnell --include "*.safetensors" "*.json" "*.txt" "*.model" --exclude ".onnx" ".onnx_data" --cache-dir models - # Download image-to-video models. - huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt --include "*.fp16.safetensors" "*.json" --cache-dir models + # Download image-to-video models. + huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt --include "*.fp16.safetensors" "*.json" --cache-dir models + huggingface-cli download Lightricks/LTX-Video --include "*.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models # Download image-to-text models. huggingface-cli download Salesforce/blip-image-captioning-large --include "*.safetensors" "*.json" --cache-dir models diff --git a/runner/gateway.openapi.yaml b/runner/gateway.openapi.yaml index c2243160a..2ceb7e391 100644 --- a/runner/gateway.openapi.yaml +++ b/runner/gateway.openapi.yaml @@ -661,6 +661,18 @@ components: format: binary title: Image description: Uploaded image to generate a video from. + prompt: + type: string + title: Prompt + description: Text prompt(s) to guide video generation for prompt accepting + models. 
+ default: '' + negative_prompt: + type: string + title: Negative Prompt + description: Text prompt(s) to guide what to exclude from video generation + for prompt accepting models. Ignored if guidance_scale < 1. + default: '' model_id: type: string title: Model Id @@ -709,6 +721,11 @@ components: description: Number of denoising steps. More steps usually lead to higher quality images but slower inference. Modulated by strength. default: 25 + num_frames: + type: integer + title: Num Frames + description: The number of video frames to generate. + default: 25 type: object required: - image diff --git a/runner/openapi.yaml b/runner/openapi.yaml index e3aba929d..51a01fe52 100644 --- a/runner/openapi.yaml +++ b/runner/openapi.yaml @@ -707,6 +707,18 @@ components: format: binary title: Image description: Uploaded image to generate a video from. + prompt: + type: string + title: Prompt + description: Text prompt(s) to guide video generation for prompt accepting + models. + default: '' + negative_prompt: + type: string + title: Negative Prompt + description: Text prompt(s) to guide what to exclude from video generation + for prompt accepting models. Ignored if guidance_scale < 1. + default: '' model_id: type: string title: Model Id @@ -755,6 +767,11 @@ components: description: Number of denoising steps. More steps usually lead to higher quality images but slower inference. Modulated by strength. default: 25 + num_frames: + type: integer + title: Num Frames + description: The number of video frames to generate. 
+ default: 25 type: object required: - image diff --git a/runner/requirements.txt b/runner/requirements.txt index 648e4ea0a..bf71a1a02 100644 --- a/runner/requirements.txt +++ b/runner/requirements.txt @@ -1,12 +1,12 @@ -diffusers==0.31.0 +diffusers==0.33.1 accelerate==0.30.1 -transformers==4.43.3 +transformers==4.51.3 fastapi==0.111.0 pydantic==2.7.2 Pillow==10.3.0 python-multipart==0.0.9 uvicorn==0.30.0 -huggingface_hub==0.23.2 +huggingface_hub>=0.27.0 xformers==0.0.23 triton>=2.1.0 peft==0.11.1