diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..e945b77
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,32 @@
+FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+RUN apt-get update && apt-get install -y \
+    python3 \
+    python3-pip \
+    ffmpeg \
+    libsm6 \
+    libxext6 \
+    git \
+    wget \
+    && rm -rf /var/lib/apt/lists/*
+
+WORKDIR /app
+
+# Copy requirements
+COPY requirements.txt /app/
+
+# Install PyTorch
+RUN pip3 install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
+
+# Install other requirements
+RUN pip3 install -r requirements.txt
+
+# Copy application
+COPY . /app/
+
+# Expose Gradio port
+EXPOSE 7860
+
+CMD ["python3", "gradio_interface.py", "--server_name", "0.0.0.0", "--port", "7860"]
\ No newline at end of file
diff --git a/download_checkpoints.sh b/download_checkpoints.sh
index cd6eddc..2746f1a 100644
--- a/download_checkpoints.sh
+++ b/download_checkpoints.sh
@@ -1,5 +1,9 @@
-pip install gdown
-
 gdown --id 1rvWuM12cyvNvBQNCLmG4Fr2L1rpjQBF0
 mv float.pth checkpoints/
+hf download r-f/wav2vec-english-speech-emotion-recognition \
+    --local-dir ./checkpoints/wav2vec-english-speech-emotion-recognition \
+    --include "*"
+hf download facebook/wav2vec2-base-960h \
+    --local-dir ./checkpoints/wav2vec2-base-960h \
+    --include "*"
diff --git a/generate.py b/generate.py
index 9ca93f8..b0ee0f6 100644
--- a/generate.py
+++ b/generate.py
@@ -1,7 +1,7 @@
 """
 Inference Stage 2
 """
-
+import os
 import os, torch, random, cv2, torchvision, subprocess, librosa, datetime, tempfile, face_alignment
 import numpy as np
 import albumentations as A
@@ -25,7 +25,10 @@ def __init__(self, opt):
         self.fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, flip_input=False)
 
         # wav2vec2 audio preprocessor
-        self.wav2vec_preprocessor = Wav2Vec2FeatureExtractor.from_pretrained(opt.wav2vec_model_path, local_files_only=True)
+        if os.path.exists(opt.wav2vec_model_path):
+            self.wav2vec_preprocessor = Wav2Vec2FeatureExtractor.from_pretrained(opt.wav2vec_model_path, local_files_only=True)
+        else:
+            self.wav2vec_preprocessor = Wav2Vec2FeatureExtractor.from_pretrained(opt.wav2vec_model_path)
 
         # image transform
         self.transform = A.Compose([
@@ -173,9 +176,10 @@ def initialize(self, parser):
         parser.add_argument('--res_video_path', default=None, type=str, help='res video path')
         parser.add_argument('--ckpt_path',
-            default="/home/nvadmin/workspace/taek/float-pytorch/checkpoints/float.pth", type=str, help='checkpoint path')
+            default="./checkpoints/float.pth", type=str, help='checkpoint path')
         parser.add_argument('--res_dir', default="./results", type=str, help='result dir')
 
+        return parser
diff --git a/gradio_interface.py b/gradio_interface.py
new file mode 100644
index 0000000..b6cd270
--- /dev/null
+++ b/gradio_interface.py
@@ -0,0 +1,493 @@
+"""
+FLOAT - Gradio Interface for Audio-Driven Talking Face Generation
+"""
+
+import os
+import gradio as gr
+import datetime
+from pathlib import Path
+
+# Import the inference components
+from generate import InferenceAgent, InferenceOptions
+
+
+class GradioInterface:
+    def __init__(self):
+        # Initialize options with defaults
+        self.opt = InferenceOptions().parse()
+        self.opt.rank, self.opt.ngpus = 0, 1
+
+        # Create results directory
+        os.makedirs(self.opt.res_dir, exist_ok=True)
+
+        # Initialize the inference agent
+        print("Loading FLOAT model...")
+        self.agent = InferenceAgent(self.opt)
+        print("Model loaded successfully!")
+
+    def generate_video(
+        self,
ref_image, + audio_file, + emotion, + a_cfg_scale, + r_cfg_scale, + e_cfg_scale, + nfe, + seed, + no_crop, + progress=gr.Progress() + ): + """ + Generate talking face video from reference image and audio + """ + try: + progress(0, desc="Preparing inputs...") + + # Validate inputs + if ref_image is None: + return None, "āŒ Please upload a reference image" + if audio_file is None: + return None, "āŒ Please upload an audio file" + + # Generate output filename + video_name = Path(ref_image).stem + audio_name = Path(audio_file).stem + call_time = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + + res_video_path = os.path.join( + self.opt.res_dir, + f"{call_time}-{video_name}-{audio_name}-nfe{nfe}-seed{seed}-acfg{a_cfg_scale}-ecfg{e_cfg_scale}-{emotion}.mp4" + ) + + progress(0.3, desc="Running inference...") + + # Run inference + output_path = self.agent.run_inference( + res_video_path=res_video_path, + ref_path=ref_image, + audio_path=audio_file, + a_cfg_scale=a_cfg_scale, + r_cfg_scale=r_cfg_scale, + e_cfg_scale=e_cfg_scale, + emo=emotion, + nfe=nfe, + no_crop=no_crop, + seed=seed, + verbose=True + ) + + progress(1.0, desc="Complete!") + + status_msg = f"āœ… Video generated successfully!\nšŸ“ Saved to: {output_path}" + return output_path, status_msg + + except Exception as e: + error_msg = f"āŒ Error during generation: {str(e)}" + print(error_msg) + return None, error_msg + + +def create_interface(): + """Create and configure the Gradio interface""" + + interface = GradioInterface() + + # Custom CSS for better styling + custom_css = """ + .gradio-container { + font-family: 'Arial', sans-serif; + } + .main-header { + text-align: center; + margin-bottom: 2rem; + } + .status-box { + padding: 1rem; + border-radius: 8px; + margin-top: 1rem; + } + """ + + with gr.Blocks(css=custom_css, title="FLOAT - Talking Face Generation") as demo: + + gr.Markdown( + """ + # šŸŽ­ FLOAT: Audio-Driven Talking Face Generation + + Generate realistic talking face videos from a reference image and audio file. + Upload your inputs and adjust the parameters below to create your video. 
+ """ + ) + + with gr.Row(): + # Left column - Inputs + with gr.Column(scale=1): + gr.Markdown("### šŸ“„ Input Files") + + ref_image = gr.Image( + label="Reference Image", + type="filepath", + sources=["upload"], + height=300 + ) + gr.Markdown("*Upload a clear frontal face image*") + + audio_file = gr.Audio( + label="Audio File", + type="filepath", + sources=["upload"] + ) + gr.Markdown("*Upload the audio/speech file*") + + with gr.Accordion("āš™ļø Advanced Options", open=False): + emotion = gr.Dropdown( + choices=['S2E', 'angry', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise'], + value='S2E', + label="Emotion Control", + info="Choose target emotion or 'S2E' for speech-to-emotion" + ) + + no_crop = gr.Checkbox( + label="Skip Face Cropping", + value=False, + info="Enable if image is already cropped and aligned" + ) + + seed = gr.Slider( + minimum=0, + maximum=10000, + value=25, + step=1, + label="Random Seed", + info="Set seed for reproducible results" + ) + + # Right column - Parameters & Output + with gr.Column(scale=1): + gr.Markdown("### šŸŽ›ļø Generation Parameters") + + with gr.Group(): + nfe = gr.Slider( + minimum=1, + maximum=50, + value=10, + step=1, + label="Number of Function Evaluations (NFE)", + info="Higher = better quality but slower (10-20 recommended)" + ) + + a_cfg_scale = gr.Slider( + minimum=0.0, + maximum=5.0, + value=2.0, + step=0.1, + label="Audio CFG Scale", + info="Audio guidance strength (1.5-3.0 recommended)" + ) + + r_cfg_scale = gr.Slider( + minimum=0.0, + maximum=3.0, + value=1.0, + step=0.1, + label="Reference CFG Scale", + info="Reference image guidance strength" + ) + + e_cfg_scale = gr.Slider( + minimum=0.0, + maximum=3.0, + value=1.0, + step=0.1, + label="Emotion CFG Scale", + info="Emotion control guidance strength" + ) + + generate_btn = gr.Button( + "šŸš€ Generate Video", + variant="primary", + size="lg" + ) + + status_output = gr.Textbox( + label="Status", + placeholder="Status messages will appear here...", + lines=3 + ) + + video_output = gr.Video( + label="Generated Video", + height=400 + ) + + # Parameter presets + with gr.Accordion("šŸ“‹ Parameter Presets", open=False): + gr.Markdown( + """ + ### Quick Presets + + **Fast Preview** (NFE=5, A_CFG=2.0) + - Quick generation for testing + - Lower quality but fast + + **Balanced** (NFE=10, A_CFG=2.0) ⭐ *Default* + - Good balance of quality and speed + - Recommended for most uses + + **High Quality** (NFE=20, A_CFG=2.5) + - Best quality output + - Slower generation time + + **Expressive** (NFE=15, E_CFG=1.5) + - Enhanced emotional expressions + - Good for dramatic content + """ + ) + + with gr.Row(): + preset_fast = gr.Button("⚔ Fast Preview") + preset_balanced = gr.Button("āš–ļø Balanced") + preset_quality = gr.Button("šŸ’Ž High Quality") + preset_expressive = gr.Button("šŸŽ­ Expressive") + + # Information section + with gr.Accordion("ā„¹ļø Help & Information", open=False): + gr.Markdown( + """ + ## How to Use + + 1. **Upload Reference Image**: Choose a clear, frontal face image (512x512 recommended) + 2. **Upload Audio**: Select the audio file for lip-sync generation + 3. **Adjust Parameters**: Modify generation settings or use presets + 4. **Generate**: Click the generate button and wait for processing + + ## Parameter Guide + + - **NFE**: Controls generation steps. Higher = better quality but slower + - **Audio CFG**: Controls how closely video follows audio. Higher = stricter sync + - **Reference CFG**: Controls identity preservation. 
Higher = more similar to reference + - **Emotion CFG**: Controls emotion expression strength + - **Emotion**: Choose specific emotion or 'S2E' for automatic emotion from speech + + ## Tips + + - Use high-quality, well-lit reference images for best results + - Audio should be clear with minimal background noise + - Start with default parameters and adjust based on results + - Enable "Skip Face Cropping" only if your image is pre-processed + + ## Supported Formats + + - **Images**: JPG, PNG + - **Audio**: WAV, MP3, M4A, FLAC + """ + ) + + # Event handlers + def set_fast_preset(): + return 5, 2.0, 1.0, 1.0 + + def set_balanced_preset(): + return 10, 2.0, 1.0, 1.0 + + def set_quality_preset(): + return 20, 2.5, 1.0, 1.0 + + def set_expressive_preset(): + return 15, 2.0, 1.0, 1.5 + + # Connect preset buttons + preset_fast.click( + fn=set_fast_preset, + outputs=[nfe, a_cfg_scale, r_cfg_scale, e_cfg_scale] + ) + + preset_balanced.click( + fn=set_balanced_preset, + outputs=[nfe, a_cfg_scale, r_cfg_scale, e_cfg_scale] + ) + + preset_quality.click( + fn=set_quality_preset, + outputs=[nfe, a_cfg_scale, r_cfg_scale, e_cfg_scale] + ) + + preset_expressive.click( + fn=set_expressive_preset, + outputs=[nfe, a_cfg_scale, r_cfg_scale, e_cfg_scale] + ) + + # Connect generate button + generate_btn.click( + fn=interface.generate_video, + inputs=[ + ref_image, + audio_file, + emotion, + a_cfg_scale, + r_cfg_scale, + e_cfg_scale, + nfe, + seed, + no_crop + ], + outputs=[video_output, status_output] + ) + + # Examples + gr.Markdown("### šŸ“š Example Configurations") + gr.Examples( + examples=[ + ["S2E", 2.0, 1.0, 1.0, 10, 25, False], + ["happy", 2.5, 1.0, 1.5, 15, 42, False], + ["neutral", 2.0, 1.2, 1.0, 20, 100, False], + ], + inputs=[emotion, a_cfg_scale, r_cfg_scale, e_cfg_scale, nfe, seed, no_crop], + label="Try these parameter combinations" + ) + + return demo + + +def parse_launch_args(): + """Parse command-line arguments for Gradio launch""" + import argparse + + parser = argparse.ArgumentParser(description='Launch FLOAT Gradio Interface') + + # Server options + parser.add_argument('--port', type=int, default=7860, + help='Port to run the server on (default: 7860)') + parser.add_argument('--server_name', type=str, default="0.0.0.0", + help='Server name/IP to bind to (default: 0.0.0.0)') + parser.add_argument('--share', action='store_true', + help='Create a public share link') + parser.add_argument('--auth', type=str, default=None, + help='Username and password for authentication, format: "username:password"') + parser.add_argument('--auth_message', type=str, default="Please login to access FLOAT", + help='Message to display on login page') + + # Model options + parser.add_argument('--ckpt_path', type=str, + default="./checkpoints/float.pth", + help='Path to model checkpoint') + parser.add_argument('--res_dir', type=str, default="./results", + help='Directory to save generated videos') + parser.add_argument('--wav2vec_model_path', type=str, default="facebook/wav2vec2-base-960h", + help='Path to wav2vec2 model') + + # Interface options + parser.add_argument('--debug', action='store_true', + help='Enable debug mode with detailed error messages') + parser.add_argument('--queue', action='store_true', default=True, + help='Enable request queuing (default: True)') + parser.add_argument('--max_threads', type=int, default=4, + help='Maximum number of concurrent threads (default: 4)') + parser.add_argument('--inbrowser', action='store_true', + help='Automatically open in browser') + 
parser.add_argument('--prevent_thread_lock', action='store_true', + help='Prevent thread lock (useful for debugging)') + + # Advanced options + parser.add_argument('--ssl_keyfile', type=str, default=None, + help='Path to SSL key file for HTTPS') + parser.add_argument('--ssl_certfile', type=str, default=None, + help='Path to SSL certificate file for HTTPS') + parser.add_argument('--ssl_keyfile_password', type=str, default=None, + help='Password for SSL key file') + parser.add_argument('--favicon_path', type=str, default=None, + help='Path to custom favicon') + + return parser.parse_args() + + +if __name__ == "__main__": + # Parse launch arguments + args = parse_launch_args() + + # Override inference options with command-line args if provided + import sys + inference_args = [] + if args.ckpt_path: + inference_args.extend(['--ckpt_path', args.ckpt_path]) + if args.res_dir: + inference_args.extend(['--res_dir', args.res_dir]) + if args.wav2vec_model_path: + inference_args.extend(['--wav2vec_model_path', args.wav2vec_model_path]) + + # Temporarily modify sys.argv for InferenceOptions + original_argv = sys.argv.copy() + sys.argv = [sys.argv[0]] + inference_args + + # Create and launch the interface + demo = create_interface() + + # Restore original argv + sys.argv = original_argv + + # Parse authentication if provided + auth_tuple = None + if args.auth: + try: + username, password = args.auth.split(':', 1) + auth_tuple = (username, password) + print(f"šŸ”’ Authentication enabled for user: {username}") + except ValueError: + print("āš ļø Invalid auth format. Use 'username:password'") + + # Print launch information + print("\n" + "="*60) + print("šŸš€ Launching FLOAT Gradio Interface") + print("="*60) + print(f"šŸ“ Server: {args.server_name}:{args.port}") + print(f"šŸ”— Local URL: http://localhost:{args.port}") + if args.share: + print("🌐 Public sharing: Enabled") + if auth_tuple: + print(f"šŸ”’ Authentication: Enabled") + print(f"šŸ’¾ Results directory: {args.res_dir}") + print(f"šŸ¤– Model checkpoint: {args.ckpt_path}") + print("="*60 + "\n") + + # Launch configuration + launch_kwargs = { + 'server_name': args.server_name, + 'server_port': args.port, + 'share': args.share, + 'show_error': True, + 'inbrowser': args.inbrowser, + 'prevent_thread_lock': args.prevent_thread_lock, + } + + # Add optional parameters + if auth_tuple: + launch_kwargs['auth'] = auth_tuple + launch_kwargs['auth_message'] = args.auth_message + + if args.ssl_keyfile and args.ssl_certfile: + launch_kwargs['ssl_keyfile'] = args.ssl_keyfile + launch_kwargs['ssl_certfile'] = args.ssl_certfile + if args.ssl_keyfile_password: + launch_kwargs['ssl_keyfile_password'] = args.ssl_keyfile_password + print("šŸ” HTTPS enabled") + + if args.favicon_path: + launch_kwargs['favicon_path'] = args.favicon_path + + if args.debug: + launch_kwargs['debug'] = True + print("šŸ› Debug mode enabled") + + # Launch the demo + try: + if args.queue: + demo.queue(max_size=args.max_threads) + + demo.launch(**launch_kwargs) + except KeyboardInterrupt: + print("\n\nšŸ‘‹ Shutting down gracefully...") + except Exception as e: + print(f"\nāŒ Error launching interface: {e}") + if args.debug: + raise diff --git a/main.py b/main.py new file mode 100644 index 0000000..35aa72a --- /dev/null +++ b/main.py @@ -0,0 +1,673 @@ +""" +FastAPI WebSocket API for FLOAT - Audio-Driven Talking Face Generation +Save this file as: api.py +Run with:cc +""" +import os +import json +import base64 +import tempfile +import datetime +from pathlib import Path +from typing import 
Optional, Dict +import asyncio +from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException +from fastapi.responses import JSONResponse +from pydantic import BaseModel, Field +import uvloop + +# Import the inference components +from generate import InferenceAgent, InferenceOptions + +# Use uvloop for better performance +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + +# Configuration +DEFAULT_IMAGE_PATH = "img.jpg" +RESULTS_DIR = "./results" +MAX_CONCURRENT_GENERATIONS = 3 # Maximum number of concurrent video generations + +# Initialize FastAPI app +app = FastAPI( + title="FLOAT WebSocket API", + description="Audio-Driven Talking Face Generation via WebSocket and REST API", + version="1.0.0" +) + +# Pydantic models for POST request +class AudioData(BaseModel): + content: str = Field(..., description="Base64 encoded audio data") + ext: str = Field(..., description="Audio file extension (e.g., 'wav', 'mp3')") + +class ImageData(BaseModel): + content: str = Field(..., description="Base64 encoded image data") + ext: str = Field(..., description="Image file extension (e.g., 'jpg', 'png')") + +class GenerationParams(BaseModel): + emotion: str = Field("S2E", description="Emotion control: S2E, angry, disgust, fear, happy, neutral, sad, surprise") + a_cfg_scale: float = Field(2.0, ge=0.0, le=5.0, description="Audio CFG scale") + r_cfg_scale: float = Field(1.0, ge=0.0, le=3.0, description="Reference CFG scale") + e_cfg_scale: float = Field(1.0, ge=0.0, le=3.0, description="Emotion CFG scale") + nfe: int = Field(10, ge=1, le=50, description="Number of function evaluations") + seed: int = Field(25, ge=0, description="Random seed") + no_crop: bool = Field(False, description="Skip face cropping") + +class GenerationRequest(BaseModel): + audio: AudioData + image: Optional[ImageData] = None + params: Optional[GenerationParams] = None + +class BatchGenerationRequest(BaseModel): + audios: list[AudioData] = Field(..., description="List of audio files to process") + image: Optional[ImageData] = Field(None, description="Single image to use for all audios") + images: Optional[list[ImageData]] = Field(None, description="List of images (one per audio)") + params: Optional[GenerationParams] = Field(None, description="Generation parameters to use for all") + + class Config: + schema_extra = { + "example": { + "audios": [ + {"content": "base64_audio_1", "ext": "wav"}, + {"content": "base64_audio_2", "ext": "wav"} + ], + "image": {"content": "base64_image", "ext": "jpg"}, + "params": {"emotion": "happy", "nfe": 10} + } + } + +# Global model instance and semaphore +agent = None +opt = None +generation_semaphore = None # Will be initialized on startup +active_generations = 0 # Track active generation count +generation_lock = asyncio.Lock() # Lock for updating counter + +@app.on_event("startup") +async def startup_event(): + """Initialize model on startup""" + global agent, opt, generation_semaphore + + print("\n" + "="*60) + print("šŸš€ Starting FLOAT WebSocket API") + print("="*60) + + # Initialize the semaphore for concurrent generation control + generation_semaphore = asyncio.Semaphore(MAX_CONCURRENT_GENERATIONS) + print(f"šŸ”’ Concurrent generation limit: {MAX_CONCURRENT_GENERATIONS}") + + # Initialize options with empty args to use defaults + import sys + original_argv = sys.argv.copy() + sys.argv = [sys.argv[0]] # Keep only script name + + opt = InferenceOptions().parse() + opt.rank, opt.ngpus = 0, 1 + opt.res_dir = RESULTS_DIR + + sys.argv = original_argv + + # Create results directory + 
os.makedirs(opt.res_dir, exist_ok=True) + + # Load model + print("šŸ¤– Loading FLOAT model...") + agent = InferenceAgent(opt) + print("āœ… Model loaded successfully!") + + # Check default image + if os.path.exists(DEFAULT_IMAGE_PATH): + print(f"šŸ“ø Default image: {DEFAULT_IMAGE_PATH}") + else: + print(f"āš ļø Warning: Default image not found at {DEFAULT_IMAGE_PATH}") + + print(f"šŸ’¾ Results directory: {opt.res_dir}") + print("="*60) + print("āœ… Server ready! Waiting for connections...") + print("="*60 + "\n") + +@app.get("/") +async def root(): + """Health check endpoint""" + return { + "status": "online", + "service": "FLOAT WebSocket API", + "version": "1.0.0", + "max_concurrent_generations": MAX_CONCURRENT_GENERATIONS, + "active_generations": active_generations, + "endpoints": { + "websocket": "/ws", + "rest_api": "/generate", + "batch_api": "/generate/batch", + "health": "/health", + "status": "/status" + } + } + +@app.get("/health") +async def health_check(): + """Detailed health check""" + return { + "status": "healthy", + "model_loaded": agent is not None, + "default_image_exists": os.path.exists(DEFAULT_IMAGE_PATH), + "results_dir": opt.res_dir if opt else None, + "max_concurrent_generations": MAX_CONCURRENT_GENERATIONS, + "active_generations": active_generations + } + +@app.get("/status") +async def get_status(): + """Get current generation status""" + available_slots = MAX_CONCURRENT_GENERATIONS - active_generations + return { + "max_concurrent_generations": MAX_CONCURRENT_GENERATIONS, + "active_generations": active_generations, + "available_slots": available_slots, + "queue_full": available_slots == 0 + } + +@app.post("/generate") +async def generate_video(request: GenerationRequest): + """ + REST API endpoint for video generation + + Request body: + { + "audio": {"content": "base64_data", "ext": "wav"}, + "image": {"content": "base64_data", "ext": "jpg"}, // Optional + "params": { + "emotion": "S2E", + "a_cfg_scale": 2.0, + "r_cfg_scale": 1.0, + "e_cfg_scale": 1.0, + "nfe": 10, + "seed": 25, + "no_crop": false + } + } + + Returns: + { + "status": "success", + "video": "base64_encoded_video", + "video_path": "/path/to/video.mp4", + "message": "Video generated successfully", + "params_used": {...} + } + """ + if agent is None: + raise HTTPException(status_code=503, detail="Model not loaded yet") + + try: + # Convert Pydantic models to dict format expected by process_generation_request + request_data = { + "audio": { + "content": request.audio.content, + "ext": request.audio.ext + } + } + + # Add image if provided + if request.image: + request_data["image"] = { + "content": request.image.content, + "ext": request.image.ext + } + + # Add params if provided + if request.params: + request_data["params"] = { + "emotion": request.params.emotion, + "a_cfg_scale": request.params.a_cfg_scale, + "r_cfg_scale": request.params.r_cfg_scale, + "e_cfg_scale": request.params.e_cfg_scale, + "nfe": request.params.nfe, + "seed": request.params.seed, + "no_crop": request.params.no_crop + } + + # Process the request with semaphore control + result = await process_generation_request(request_data) + + # Return error if processing failed + if result.get("status") == "error": + raise HTTPException(status_code=500, detail=result.get("error", "Generation failed")) + + return result + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") + +@app.post("/generate/batch") +async def generate_video_batch(request: 
BatchGenerationRequest): + """ + Batch video generation endpoint - process multiple audios at once + + Request body: + { + "audios": [ + {"content": "base64_audio_1", "ext": "wav"}, + {"content": "base64_audio_2", "ext": "wav"} + ], + "image": {"content": "base64_image", "ext": "jpg"}, // Optional: single image for all + "images": [ // Optional: one image per audio (overrides "image") + {"content": "base64_image_1", "ext": "jpg"}, + {"content": "base64_image_2", "ext": "jpg"} + ], + "params": { + "emotion": "S2E", + "nfe": 10, + ... + } + } + + Returns: + { + "status": "success", + "total": 2, + "successful": 2, + "failed": 0, + "results": [ + { + "index": 0, + "status": "success", + "video": "base64_video_1", + "video_path": "/path/to/video1.mp4", + ... + }, + { + "index": 1, + "status": "success", + "video": "base64_video_2", + "video_path": "/path/to/video2.mp4", + ... + } + ] + } + """ + if agent is None: + raise HTTPException(status_code=503, detail="Model not loaded yet") + + try: + # Validate input + num_audios = len(request.audios) + if num_audios == 0: + raise HTTPException(status_code=400, detail="At least one audio file is required") + + # Check if using per-audio images or single image + if request.images: + if len(request.images) != num_audios: + raise HTTPException( + status_code=400, + detail=f"Number of images ({len(request.images)}) must match number of audios ({num_audios})" + ) + use_multiple_images = True + else: + use_multiple_images = False + + print(f"\nšŸ“¦ Processing batch of {num_audios} audio files...") + print(f"šŸ”’ Concurrent limit: {MAX_CONCURRENT_GENERATIONS}, Active: {active_generations}") + + # Process each audio (semaphore will automatically control concurrency) + tasks = [] + + for idx, audio in enumerate(request.audios): + # Build request for this audio + single_request = { + "audio": { + "content": audio.content, + "ext": audio.ext + } + } + + # Determine which image to use + if use_multiple_images: + single_request["image"] = { + "content": request.images[idx].content, + "ext": request.images[idx].ext + } + elif request.image: + single_request["image"] = { + "content": request.image.content, + "ext": request.image.ext + } + + # Add params if provided + if request.params: + single_request["params"] = { + "emotion": request.params.emotion, + "a_cfg_scale": request.params.a_cfg_scale, + "r_cfg_scale": request.params.r_cfg_scale, + "e_cfg_scale": request.params.e_cfg_scale, + "nfe": request.params.nfe, + "seed": request.params.seed, + "no_crop": request.params.no_crop + } + + # Create task for this generation + task = process_single_batch_item(idx, single_request, num_audios) + tasks.append(task) + + # Wait for all tasks to complete (concurrency controlled by semaphore) + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Process results + successful = 0 + failed = 0 + processed_results = [] + + for idx, result in enumerate(results): + if isinstance(result, Exception): + failed += 1 + processed_results.append({ + "index": idx, + "status": "error", + "error": str(result), + "message": f"Failed to process audio {idx + 1}" + }) + else: + if result.get("status") == "success": + successful += 1 + else: + failed += 1 + processed_results.append(result) + + print(f"\nšŸ“Š Batch processing complete: {successful} successful, {failed} failed") + + # Return batch results + return { + "status": "success" if failed == 0 else "partial", + "total": num_audios, + "successful": successful, + "failed": failed, + "results": processed_results, + "message": 
f"Processed {num_audios} audios: {successful} successful, {failed} failed" + } + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Batch processing error: {str(e)}") + +async def process_single_batch_item(idx: int, request_data: Dict, total: int) -> Dict: + """Process a single batch item with logging""" + print(f"\nšŸŽµ Queuing audio {idx + 1}/{total}...") + + try: + result = await process_generation_request(request_data) + result["index"] = idx + + if result.get("status") == "success": + print(f"āœ… Audio {idx + 1} completed successfully") + else: + print(f"āŒ Audio {idx + 1} failed: {result.get('error', 'Unknown error')}") + + return result + + except Exception as e: + print(f"āŒ Audio {idx + 1} failed with exception: {str(e)}") + return { + "index": idx, + "status": "error", + "error": str(e), + "message": f"Failed to process audio {idx + 1}" + } + +def save_base64_file(base64_content: str, extension: str, prefix: str = "temp") -> str: + """Save base64 content to a temporary file""" + temp_file = tempfile.NamedTemporaryFile( + delete=False, + suffix=f".{extension}", + prefix=f"{prefix}_" + ) + + file_data = base64.b64decode(base64_content) + temp_file.write(file_data) + temp_file.close() + + return temp_file.name + +def read_file_as_base64(file_path: str) -> str: + """Read file and return as base64 string""" + with open(file_path, 'rb') as f: + return base64.b64encode(f.read()).decode('utf-8') + +async def process_generation_request(request_data: Dict) -> Dict: + """ + Process video generation request with semaphore-based concurrency control + + Expected format: + { + "audio": {"content": "base64_data", "ext": "wav"}, + "image": {"content": "base64_data", "ext": "jpg"}, # Optional + "params": { + "emotion": "S2E", + "a_cfg_scale": 2.0, + "r_cfg_scale": 1.0, + "e_cfg_scale": 1.0, + "nfe": 10, + "seed": 25, + "no_crop": false + } + } + """ + global active_generations + + # Acquire semaphore to limit concurrent generations + async with generation_semaphore: + # Update active generation counter + async with generation_lock: + active_generations += 1 + current_count = active_generations + + print(f"šŸ”’ Generation slot acquired ({current_count}/{MAX_CONCURRENT_GENERATIONS} active)") + + temp_files = [] + + try: + # Validate audio + if "audio" not in request_data: + return { + "status": "error", + "error": "Missing 'audio' field in request" + } + + audio_data = request_data["audio"] + if "content" not in audio_data or "ext" not in audio_data: + return { + "status": "error", + "error": "Audio must have 'content' and 'ext' fields" + } + + # Save audio to temp file + audio_ext = audio_data["ext"].lstrip('.') + audio_path = save_base64_file(audio_data["content"], audio_ext, "audio") + temp_files.append(audio_path) + print(f"šŸ’¾ Saved audio: {audio_path}") + + # Handle image + if "image" in request_data and request_data["image"]: + image_data = request_data["image"] + if "content" not in image_data or "ext" not in image_data: + return { + "status": "error", + "error": "Image must have 'content' and 'ext' fields" + } + + image_ext = image_data["ext"].lstrip('.') + image_path = save_base64_file(image_data["content"], image_ext, "image") + temp_files.append(image_path) + print(f"šŸ“ø Using uploaded image: {image_path}") + else: + if not os.path.exists(DEFAULT_IMAGE_PATH): + return { + "status": "error", + "error": f"Default image not found: {DEFAULT_IMAGE_PATH}" + } + image_path = DEFAULT_IMAGE_PATH + print(f"šŸ“ø Using default image: 
{image_path}") + + # Extract parameters + params = request_data.get("params", {}) + emotion = params.get("emotion", "S2E") + a_cfg_scale = params.get("a_cfg_scale", 2.0) + r_cfg_scale = params.get("r_cfg_scale", 1.0) + e_cfg_scale = params.get("e_cfg_scale", 1.0) + nfe = params.get("nfe", 10) + seed = params.get("seed", 25) + no_crop = params.get("no_crop", False) + + # Generate output path + call_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + res_video_path = os.path.join( + opt.res_dir, + f"{call_time}_generated_nfe{nfe}_seed{seed}.mp4" + ) + + print(f"šŸŽ¬ Generating video...") + print(f" Emotion: {emotion}, NFE: {nfe}, Seed: {seed}") + + # Run inference in thread pool to avoid blocking + loop = asyncio.get_event_loop() + output_path = await loop.run_in_executor( + None, + lambda: agent.run_inference( + res_video_path=res_video_path, + ref_path=image_path, + audio_path=audio_path, + a_cfg_scale=a_cfg_scale, + r_cfg_scale=r_cfg_scale, + e_cfg_scale=e_cfg_scale, + emo=emotion, + nfe=nfe, + no_crop=no_crop, + seed=seed, + verbose=True + ) + ) + + print(f"āœ… Video generated: {output_path}") + + # Read video as base64 + video_base64 = read_file_as_base64(output_path) + + return { + "status": "success", + "video": video_base64, + "video_path": output_path, + "message": "Video generated successfully", + "params_used": { + "emotion": emotion, + "nfe": nfe, + "seed": seed, + "a_cfg_scale": a_cfg_scale, + "r_cfg_scale": r_cfg_scale, + "e_cfg_scale": e_cfg_scale + } + } + + except Exception as e: + print(f"āŒ Error during processing: {str(e)}") + import traceback + traceback.print_exc() + + return { + "status": "error", + "error": str(e), + "message": "Failed to generate video" + } + + finally: + # Cleanup temp files + for temp_file in temp_files: + try: + if os.path.exists(temp_file): + os.remove(temp_file) + print(f"šŸ—‘ļø Cleaned up: {temp_file}") + except Exception as e: + print(f"āš ļø Failed to delete {temp_file}: {e}") + + # Release generation slot + async with generation_lock: + active_generations -= 1 + remaining = active_generations + + print(f"šŸ”“ Generation slot released ({remaining}/{MAX_CONCURRENT_GENERATIONS} active)") + +@app.websocket("/ws") +async def websocket_endpoint(websocket: WebSocket): + """WebSocket endpoint for video generation""" + await websocket.accept() + + client_id = f"{websocket.client.host}:{websocket.client.port}" + print(f"\nšŸ”Œ New connection from {client_id}") + + try: + # Send welcome message + await websocket.send_json({ + "status": "connected", + "message": "Connected to FLOAT WebSocket API", + "version": "1.0.0", + "max_concurrent_generations": MAX_CONCURRENT_GENERATIONS + }) + + # Process messages + while True: + try: + # Receive message + message = await websocket.receive_text() + print(f"šŸ“Ø Received message from {client_id}") + + # Parse JSON + request_data = json.loads(message) + + # Send processing acknowledgment + await websocket.send_json({ + "status": "processing", + "message": "Processing your request...", + "active_generations": active_generations, + "max_concurrent": MAX_CONCURRENT_GENERATIONS + }) + + # Process request (semaphore will control concurrency) + response = await process_generation_request(request_data) + + # Send response + await websocket.send_json(response) + print(f"āœ‰ļø Sent response to {client_id}") + + except json.JSONDecodeError as e: + await websocket.send_json({ + "status": "error", + "error": f"Invalid JSON: {str(e)}" + }) + print(f"āš ļø JSON decode error from {client_id}") + + except Exception as 
e: + await websocket.send_json({ + "status": "error", + "error": f"Processing error: {str(e)}" + }) + print(f"āŒ Error handling message from {client_id}: {e}") + + except WebSocketDisconnect: + print(f"šŸ”Œ Client disconnected: {client_id}") + + except Exception as e: + print(f"āŒ WebSocket error with {client_id}: {e}") + + finally: + print(f"šŸ‘‹ Connection closed: {client_id}\n") + +# For running with uvicorn (alternative to hypercorn) +"""if __name__ == "__main__": + import uvicorn + uvicorn.run( + "main:app", + host="0.0.0.0", + port=8000, + reload=False, + log_level="info" + )""" \ No newline at end of file diff --git a/models/float/FLOAT.py b/models/float/FLOAT.py index 776fb58..bf33b2b 100644 --- a/models/float/FLOAT.py +++ b/models/float/FLOAT.py @@ -1,3 +1,4 @@ +import os import torch, math import torch.nn as nn import torch.nn.functional as F @@ -184,8 +185,12 @@ def __init__(self, opt): self.num_frames_for_clip = int(opt.wav2vec_sec * self.opt.fps) self.num_prev_frames = int(opt.num_prev_frames) + if os.path.exists(opt.wav2vec_model_path): + + self.wav2vec2 = Wav2VecModel.from_pretrained(opt.wav2vec_model_path, local_files_only = True) + else: + self.wav2vec2 = Wav2VecModel.from_pretrained(opt.wav2vec_model_path) - self.wav2vec2 = Wav2VecModel.from_pretrained(opt.wav2vec_model_path, local_files_only = True) self.wav2vec2.feature_extractor._freeze_parameters() for name, param in self.wav2vec2.named_parameters(): @@ -233,20 +238,38 @@ def inference(self, a: torch.Tensor, seq_len:int) -> torch.Tensor: class Audio2Emotion(nn.Module): def __init__(self, opt): super().__init__() - self.wav2vec2_for_emotion = Wav2Vec2ForSpeechClassification.from_pretrained(opt.audio2emotion_path, local_files_only=True) + + # Load pretrained model (local preferred) + if os.path.exists(opt.audio2emotion_path): + self.wav2vec2_for_emotion = Wav2Vec2ForSpeechClassification.from_pretrained( + opt.audio2emotion_path, + local_files_only=True + ) + else: + self.wav2vec2_for_emotion = Wav2Vec2ForSpeechClassification.from_pretrained( + opt.audio2emotion_path + ) + self.wav2vec2_for_emotion.eval() - - # seven labels - self.id2label = {0: "angry", 1: "disgust", 2: "fear", 3: "happy", - 4: "neutral", 5: "sad", 6: "surprise"} + + # seven labels + self.id2label = { + 0: "angry", + 1: "disgust", + 2: "fear", + 3: "happy", + 4: "neutral", + 5: "sad", + 6: "surprise" + } self.label2id = {v: k for k, v in self.id2label.items()} @torch.no_grad() def predict_emotion(self, a: torch.Tensor, prev_a: torch.Tensor = None) -> torch.Tensor: + # Concatenate previous audio if provided if prev_a is not None: a = torch.cat([prev_a, a], dim=1) - logits = self.wav2vec2_for_emotion.forward(a).logits - return F.softmax(logits, dim=1) # scores -####################################################### \ No newline at end of file + logits = self.wav2vec2_for_emotion(a).logits + return F.softmax(logits, dim=1) diff --git a/requirements.txt b/requirements.txt index 0a38d5d..9d2d3e6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,14 +1,56 @@ -pyyaml -opencv-python -pandas -tqdm -matplotlib -flow-vis -librosa +# Core Deep Learning - PyTorch Stack +torch==2.0.1 +torchvision==0.15.2 +torchaudio==2.0.2 + +# Core Dependencies +numpy==1.24.4 + +# Transformers & NLP transformers==4.30.2 + +# Computer Vision +opencv-python==4.8.0.74 +timm==1.0.9 +face-alignment==1.4.1 + +# Image Augmentation albumentations==1.4.15 albucore==0.0.16 -torchdiffeq==0.2.5 -timm==1.0.9 -face_alignment==1.4.1 + +# Audio Processing +librosa==0.10.1 + +# Video 
Processing
+av==12.0.0
+
+# ODE Solvers (flow matching sampling)
+torchdiffeq==0.2.5
+
+# Data Processing & Visualization
+pandas==2.0.3
+matplotlib==3.7.2
+flow-vis==0.1
+pyyaml==6.0.1
+tqdm==4.65.0
+gdown==5.2.0
+
+# Web Interface - Gradio
+gradio==3.50.2
+
+# Web API
+fastapi==0.122.0
+hypercorn==0.15.0
+uvicorn==0.24.0
+
+# Async performance
+uvloop==0.19.0
+
+# WebSocket support
+websockets>=10.0,<12.0
+
+# HTTP client (if needed)
+httpx==0.25.1
+
+# Pydantic for data validation
+pydantic==2.7.0
+pydantic-settings==2.1.0
+matplotlib_inline==0.1.7
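
Both gradio_interface.py and main.py are thin wrappers around the same InferenceAgent from generate.py, so the pipeline can also be exercised with no web layer at all. The sketch below shows that shared call path under a few stated assumptions: the checkpoints fetched by download_checkpoints.sh sit under ./checkpoints/, and sample.jpg / speech.wav are placeholder input files, not assets shipped with this PR. The run_inference keyword arguments are copied from the calls this PR already makes; on top of this call, the two front-ends add only input plumbing (temp-file handling, and in main.py an asyncio semaphore that caps concurrent generations at MAX_CONCURRENT_GENERATIONS).

# Minimal sketch: drive InferenceAgent directly (no Gradio, no FastAPI).
import datetime
import os
import sys

from generate import InferenceAgent, InferenceOptions

# InferenceOptions reads sys.argv, so pass only the flags we want to override
# (the same trick main.py and gradio_interface.py use at startup).
sys.argv = [sys.argv[0], "--ckpt_path", "./checkpoints/float.pth", "--res_dir", "./results"]
opt = InferenceOptions().parse()
opt.rank, opt.ngpus = 0, 1                      # single-GPU setup, as in both front-ends
os.makedirs(opt.res_dir, exist_ok=True)

agent = InferenceAgent(opt)

stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
output_path = agent.run_inference(
    res_video_path=os.path.join(opt.res_dir, f"{stamp}_demo.mp4"),
    ref_path="sample.jpg",                      # placeholder reference face image
    audio_path="speech.wav",                    # placeholder driving audio
    a_cfg_scale=2.0,
    r_cfg_scale=1.0,
    e_cfg_scale=1.0,
    emo="S2E",                                  # or one of the seven emotion labels
    nfe=10,
    no_crop=False,
    seed=25,
    verbose=True,
)
print("Video written to", output_path)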
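
For the REST API in main.py, a minimal client sketch follows. Assumptions: the app is served at http://localhost:8000, for example via the commented-out uvicorn runner at the bottom of main.py (uvicorn with main:app, host 0.0.0.0, port 8000); speech.wav and face.jpg are placeholder inputs; httpx is used because it is already pinned in requirements.txt. The /generate/batch endpoint takes the same audio/image/params shapes wrapped in an audios list (with either one shared image or a per-audio images list) and returns per-item results, and /status reports how many of the MAX_CONCURRENT_GENERATIONS slots are in use.

# Minimal sketch of a client for POST /generate.
import base64

import httpx


def b64(path: str) -> str:
    """Read a file and return its base64-encoded contents as a string."""
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")


payload = {
    "audio": {"content": b64("speech.wav"), "ext": "wav"},
    "image": {"content": b64("face.jpg"), "ext": "jpg"},  # optional; omit to fall back to the server-side img.jpg
    "params": {"emotion": "S2E", "a_cfg_scale": 2.0, "nfe": 10, "seed": 25},
}

# Generation can take a while, so disable the client timeout.
resp = httpx.post("http://localhost:8000/generate", json=payload, timeout=None)
resp.raise_for_status()
result = resp.json()

if result["status"] == "success":
    # The video is returned base64-encoded next to its server-side path.
    with open("generated.mp4", "wb") as f:
        f.write(base64.b64decode(result["video"]))
    print("Saved generated.mp4; server copy at", result["video_path"])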
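
The WebSocket endpoint /ws accepts the same JSON shape as a text frame. A sketch follows under the same host/port assumption, using the websockets package pinned in requirements.txt. The server sends a connected greeting and a processing acknowledgement before the final result, so the client reads frames until it sees a terminal status; with no image field, the server falls back to DEFAULT_IMAGE_PATH (img.jpg).

# Minimal sketch of a client for the /ws WebSocket endpoint.
import asyncio
import base64
import json

import websockets


async def generate(audio_path: str) -> dict:
    with open(audio_path, "rb") as f:
        audio_b64 = base64.b64encode(f.read()).decode("utf-8")

    request = {
        "audio": {"content": audio_b64, "ext": "wav"},
        "params": {"emotion": "happy", "nfe": 10},
    }

    # max_size=None lifts the default 1 MiB frame limit so the base64 video fits.
    async with websockets.connect("ws://localhost:8000/ws", max_size=None) as ws:
        await ws.send(json.dumps(request))
        while True:
            message = json.loads(await ws.recv())
            print("server:", message.get("status"), "-", message.get("message", ""))
            if message.get("status") in ("success", "error"):
                return message


result = asyncio.run(generate("speech.wav"))
if result["status"] == "success":
    with open("generated_ws.mp4", "wb") as f:
        f.write(base64.b64decode(result["video"]))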