From d91a2a289bd631ec6e2f733ef103f6a37a73de23 Mon Sep 17 00:00:00 2001 From: Mohamed Emam <126331291+mohamed-em2m@users.noreply.github.com> Date: Tue, 25 Nov 2025 06:19:43 +0000 Subject: [PATCH 01/20] create gradio interface --- Dockerfile | 32 +++ app.py | 493 +++++++++++++++++++++++++++++++++++++++++++++++ requirements.txt | 129 ++++++++++++- 3 files changed, 644 insertions(+), 10 deletions(-) create mode 100644 Dockerfile create mode 100644 app.py diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..e945b77 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,32 @@ +FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04 + +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt-get update && apt-get install -y \ + python3.8 \ + python3-pip \ + ffmpeg \ + libsm6 \ + libxext6 \ + git \ + wget \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +# Copy requirements +COPY requirements.txt /app/ + +# Install PyTorch +RUN pip3 install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118 + +# Install other requirements +RUN pip3 install -r requirements.txt + +# Copy application +COPY . /app/ + +# Expose Gradio port +EXPOSE 7860 + +CMD ["python3", "gradio_app.py", "--server_name", "0.0.0.0", "--port", "7860"] \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000..45d4d4a --- /dev/null +++ b/app.py @@ -0,0 +1,493 @@ +""" +FLOAT - Gradio Interface for Audio-Driven Talking Face Generation +""" + +import os +import gradio as gr +import datetime +from pathlib import Path + +# Import the inference components +from XXXz import InferenceAgent, InferenceOptions + + +class GradioInterface: + def __init__(self): + # Initialize options with defaults + self.opt = InferenceOptions().parse() + self.opt.rank, self.opt.ngpus = 0, 1 + + # Create results directory + os.makedirs(self.opt.res_dir, exist_ok=True) + + # Initialize the inference agent + print("Loading FLOAT model...") + self.agent = InferenceAgent(self.opt) + print("Model loaded successfully!") + + def generate_video( + self, + ref_image, + audio_file, + emotion, + a_cfg_scale, + r_cfg_scale, + e_cfg_scale, + nfe, + seed, + no_crop, + progress=gr.Progress() + ): + """ + Generate talking face video from reference image and audio + """ + try: + progress(0, desc="Preparing inputs...") + + # Validate inputs + if ref_image is None: + return None, "āŒ Please upload a reference image" + if audio_file is None: + return None, "āŒ Please upload an audio file" + + # Generate output filename + video_name = Path(ref_image).stem + audio_name = Path(audio_file).stem + call_time = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + + res_video_path = os.path.join( + self.opt.res_dir, + f"{call_time}-{video_name}-{audio_name}-nfe{nfe}-seed{seed}-acfg{a_cfg_scale}-ecfg{e_cfg_scale}-{emotion}.mp4" + ) + + progress(0.3, desc="Running inference...") + + # Run inference + output_path = self.agent.run_inference( + res_video_path=res_video_path, + ref_path=ref_image, + audio_path=audio_file, + a_cfg_scale=a_cfg_scale, + r_cfg_scale=r_cfg_scale, + e_cfg_scale=e_cfg_scale, + emo=emotion, + nfe=nfe, + no_crop=no_crop, + seed=seed, + verbose=True + ) + + progress(1.0, desc="Complete!") + + status_msg = f"āœ… Video generated successfully!\nšŸ“ Saved to: {output_path}" + return output_path, status_msg + + except Exception as e: + error_msg = f"āŒ Error during generation: {str(e)}" + print(error_msg) + return None, error_msg + + +def create_interface(): + """Create and configure the Gradio 
interface""" + + interface = GradioInterface() + + # Custom CSS for better styling + custom_css = """ + .gradio-container { + font-family: 'Arial', sans-serif; + } + .main-header { + text-align: center; + margin-bottom: 2rem; + } + .status-box { + padding: 1rem; + border-radius: 8px; + margin-top: 1rem; + } + """ + + with gr.Blocks(css=custom_css, title="FLOAT - Talking Face Generation") as demo: + + gr.Markdown( + """ + # šŸŽ­ FLOAT: Audio-Driven Talking Face Generation + + Generate realistic talking face videos from a reference image and audio file. + Upload your inputs and adjust the parameters below to create your video. + """ + ) + + with gr.Row(): + # Left column - Inputs + with gr.Column(scale=1): + gr.Markdown("### šŸ“„ Input Files") + + ref_image = gr.Image( + label="Reference Image", + type="filepath", + sources=["upload"], + height=300 + ) + gr.Markdown("*Upload a clear frontal face image*") + + audio_file = gr.Audio( + label="Audio File", + type="filepath", + sources=["upload"] + ) + gr.Markdown("*Upload the audio/speech file*") + + with gr.Accordion("āš™ļø Advanced Options", open=False): + emotion = gr.Dropdown( + choices=['S2E', 'angry', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise'], + value='S2E', + label="Emotion Control", + info="Choose target emotion or 'S2E' for speech-to-emotion" + ) + + no_crop = gr.Checkbox( + label="Skip Face Cropping", + value=False, + info="Enable if image is already cropped and aligned" + ) + + seed = gr.Slider( + minimum=0, + maximum=10000, + value=25, + step=1, + label="Random Seed", + info="Set seed for reproducible results" + ) + + # Right column - Parameters & Output + with gr.Column(scale=1): + gr.Markdown("### šŸŽ›ļø Generation Parameters") + + with gr.Group(): + nfe = gr.Slider( + minimum=1, + maximum=50, + value=10, + step=1, + label="Number of Function Evaluations (NFE)", + info="Higher = better quality but slower (10-20 recommended)" + ) + + a_cfg_scale = gr.Slider( + minimum=0.0, + maximum=5.0, + value=2.0, + step=0.1, + label="Audio CFG Scale", + info="Audio guidance strength (1.5-3.0 recommended)" + ) + + r_cfg_scale = gr.Slider( + minimum=0.0, + maximum=3.0, + value=1.0, + step=0.1, + label="Reference CFG Scale", + info="Reference image guidance strength" + ) + + e_cfg_scale = gr.Slider( + minimum=0.0, + maximum=3.0, + value=1.0, + step=0.1, + label="Emotion CFG Scale", + info="Emotion control guidance strength" + ) + + generate_btn = gr.Button( + "šŸš€ Generate Video", + variant="primary", + size="lg" + ) + + status_output = gr.Textbox( + label="Status", + placeholder="Status messages will appear here...", + lines=3 + ) + + video_output = gr.Video( + label="Generated Video", + height=400 + ) + + # Parameter presets + with gr.Accordion("šŸ“‹ Parameter Presets", open=False): + gr.Markdown( + """ + ### Quick Presets + + **Fast Preview** (NFE=5, A_CFG=2.0) + - Quick generation for testing + - Lower quality but fast + + **Balanced** (NFE=10, A_CFG=2.0) ⭐ *Default* + - Good balance of quality and speed + - Recommended for most uses + + **High Quality** (NFE=20, A_CFG=2.5) + - Best quality output + - Slower generation time + + **Expressive** (NFE=15, E_CFG=1.5) + - Enhanced emotional expressions + - Good for dramatic content + """ + ) + + with gr.Row(): + preset_fast = gr.Button("⚔ Fast Preview") + preset_balanced = gr.Button("āš–ļø Balanced") + preset_quality = gr.Button("šŸ’Ž High Quality") + preset_expressive = gr.Button("šŸŽ­ Expressive") + + # Information section + with gr.Accordion("ā„¹ļø Help & Information", 
open=False): + gr.Markdown( + """ + ## How to Use + + 1. **Upload Reference Image**: Choose a clear, frontal face image (512x512 recommended) + 2. **Upload Audio**: Select the audio file for lip-sync generation + 3. **Adjust Parameters**: Modify generation settings or use presets + 4. **Generate**: Click the generate button and wait for processing + + ## Parameter Guide + + - **NFE**: Controls generation steps. Higher = better quality but slower + - **Audio CFG**: Controls how closely video follows audio. Higher = stricter sync + - **Reference CFG**: Controls identity preservation. Higher = more similar to reference + - **Emotion CFG**: Controls emotion expression strength + - **Emotion**: Choose specific emotion or 'S2E' for automatic emotion from speech + + ## Tips + + - Use high-quality, well-lit reference images for best results + - Audio should be clear with minimal background noise + - Start with default parameters and adjust based on results + - Enable "Skip Face Cropping" only if your image is pre-processed + + ## Supported Formats + + - **Images**: JPG, PNG + - **Audio**: WAV, MP3, M4A, FLAC + """ + ) + + # Event handlers + def set_fast_preset(): + return 5, 2.0, 1.0, 1.0 + + def set_balanced_preset(): + return 10, 2.0, 1.0, 1.0 + + def set_quality_preset(): + return 20, 2.5, 1.0, 1.0 + + def set_expressive_preset(): + return 15, 2.0, 1.0, 1.5 + + # Connect preset buttons + preset_fast.click( + fn=set_fast_preset, + outputs=[nfe, a_cfg_scale, r_cfg_scale, e_cfg_scale] + ) + + preset_balanced.click( + fn=set_balanced_preset, + outputs=[nfe, a_cfg_scale, r_cfg_scale, e_cfg_scale] + ) + + preset_quality.click( + fn=set_quality_preset, + outputs=[nfe, a_cfg_scale, r_cfg_scale, e_cfg_scale] + ) + + preset_expressive.click( + fn=set_expressive_preset, + outputs=[nfe, a_cfg_scale, r_cfg_scale, e_cfg_scale] + ) + + # Connect generate button + generate_btn.click( + fn=interface.generate_video, + inputs=[ + ref_image, + audio_file, + emotion, + a_cfg_scale, + r_cfg_scale, + e_cfg_scale, + nfe, + seed, + no_crop + ], + outputs=[video_output, status_output] + ) + + # Examples + gr.Markdown("### šŸ“š Example Configurations") + gr.Examples( + examples=[ + ["S2E", 2.0, 1.0, 1.0, 10, 25, False], + ["happy", 2.5, 1.0, 1.5, 15, 42, False], + ["neutral", 2.0, 1.2, 1.0, 20, 100, False], + ], + inputs=[emotion, a_cfg_scale, r_cfg_scale, e_cfg_scale, nfe, seed, no_crop], + label="Try these parameter combinations" + ) + + return demo + + +def parse_launch_args(): + """Parse command-line arguments for Gradio launch""" + import argparse + + parser = argparse.ArgumentParser(description='Launch FLOAT Gradio Interface') + + # Server options + parser.add_argument('--port', type=int, default=7860, + help='Port to run the server on (default: 7860)') + parser.add_argument('--server_name', type=str, default="0.0.0.0", + help='Server name/IP to bind to (default: 0.0.0.0)') + parser.add_argument('--share', action='store_true', + help='Create a public share link') + parser.add_argument('--auth', type=str, default=None, + help='Username and password for authentication, format: "username:password"') + parser.add_argument('--auth_message', type=str, default="Please login to access FLOAT", + help='Message to display on login page') + + # Model options + parser.add_argument('--ckpt_path', type=str, + default="/home/nvadmin/workspace/taek/float-pytorch/checkpoints/float.pth", + help='Path to model checkpoint') + parser.add_argument('--res_dir', type=str, default="./results", + help='Directory to save generated videos') + 
parser.add_argument('--wav2vec_model_path', type=str, default="facebook/wav2vec2-base-960h", + help='Path to wav2vec2 model') + + # Interface options + parser.add_argument('--debug', action='store_true', + help='Enable debug mode with detailed error messages') + parser.add_argument('--queue', action='store_true', default=True, + help='Enable request queuing (default: True)') + parser.add_argument('--max_threads', type=int, default=4, + help='Maximum number of concurrent threads (default: 4)') + parser.add_argument('--inbrowser', action='store_true', + help='Automatically open in browser') + parser.add_argument('--prevent_thread_lock', action='store_true', + help='Prevent thread lock (useful for debugging)') + + # Advanced options + parser.add_argument('--ssl_keyfile', type=str, default=None, + help='Path to SSL key file for HTTPS') + parser.add_argument('--ssl_certfile', type=str, default=None, + help='Path to SSL certificate file for HTTPS') + parser.add_argument('--ssl_keyfile_password', type=str, default=None, + help='Password for SSL key file') + parser.add_argument('--favicon_path', type=str, default=None, + help='Path to custom favicon') + + return parser.parse_args() + + +if __name__ == "__main__": + # Parse launch arguments + args = parse_launch_args() + + # Override inference options with command-line args if provided + import sys + inference_args = [] + if args.ckpt_path: + inference_args.extend(['--ckpt_path', args.ckpt_path]) + if args.res_dir: + inference_args.extend(['--res_dir', args.res_dir]) + if args.wav2vec_model_path: + inference_args.extend(['--wav2vec_model_path', args.wav2vec_model_path]) + + # Temporarily modify sys.argv for InferenceOptions + original_argv = sys.argv.copy() + sys.argv = [sys.argv[0]] + inference_args + + # Create and launch the interface + demo = create_interface() + + # Restore original argv + sys.argv = original_argv + + # Parse authentication if provided + auth_tuple = None + if args.auth: + try: + username, password = args.auth.split(':', 1) + auth_tuple = (username, password) + print(f"šŸ”’ Authentication enabled for user: {username}") + except ValueError: + print("āš ļø Invalid auth format. 
Use 'username:password'") + + # Print launch information + print("\n" + "="*60) + print("šŸš€ Launching FLOAT Gradio Interface") + print("="*60) + print(f"šŸ“ Server: {args.server_name}:{args.port}") + print(f"šŸ”— Local URL: http://localhost:{args.port}") + if args.share: + print("🌐 Public sharing: Enabled") + if auth_tuple: + print(f"šŸ”’ Authentication: Enabled") + print(f"šŸ’¾ Results directory: {args.res_dir}") + print(f"šŸ¤– Model checkpoint: {args.ckpt_path}") + print("="*60 + "\n") + + # Launch configuration + launch_kwargs = { + 'server_name': args.server_name, + 'server_port': args.port, + 'share': args.share, + 'show_error': True, + 'inbrowser': args.inbrowser, + 'prevent_thread_lock': args.prevent_thread_lock, + } + + # Add optional parameters + if auth_tuple: + launch_kwargs['auth'] = auth_tuple + launch_kwargs['auth_message'] = args.auth_message + + if args.ssl_keyfile and args.ssl_certfile: + launch_kwargs['ssl_keyfile'] = args.ssl_keyfile + launch_kwargs['ssl_certfile'] = args.ssl_certfile + if args.ssl_keyfile_password: + launch_kwargs['ssl_keyfile_password'] = args.ssl_keyfile_password + print("šŸ” HTTPS enabled") + + if args.favicon_path: + launch_kwargs['favicon_path'] = args.favicon_path + + if args.debug: + launch_kwargs['debug'] = True + print("šŸ› Debug mode enabled") + + # Launch the demo + try: + if args.queue: + demo.queue(max_size=args.max_threads) + + demo.launch(**launch_kwargs) + except KeyboardInterrupt: + print("\n\nšŸ‘‹ Shutting down gracefully...") + except Exception as e: + print(f"\nāŒ Error launching interface: {e}") + if args.debug: + raise \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 0a38d5d..a64ab14 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,14 +1,123 @@ -pyyaml -opencv-python -pandas -tqdm -matplotlib -flow-vis -librosa +# ============================================ +# FLOAT - Complete Pinned Requirements +# Official FLOAT Repository Compatible Versions +# Python 3.8.5 (as specified in official repo) +# Tested on Linux, A100 GPU, and V100 GPU +# ============================================ + +# Core Deep Learning - PyTorch Stack (Official FLOAT versions) +torch==2.0.1 +torchvision==0.15.2 +torchaudio==2.0.2 + +# NumPy (compatible with PyTorch 2.0.1 and Python 3.8.5) +numpy==1.24.4 + +# Transformers & NLP (Official FLOAT requirement) transformers==4.30.2 +tokenizers==0.13.3 +huggingface-hub==0.16.4 +safetensors==0.3.1 + +# Computer Vision +opencv-python==4.8.0.74 +timm==1.0.9 +face-alignment==1.4.1 +scikit-image==0.21.0 + +# Image Augmentation (Official FLOAT requirement) albumentations==1.4.15 albucore==0.0.16 -torchdiffeq==0.2.5 -timm==1.0.9 -face_alignment==1.4.1 +qudida==0.0.4 + +# Audio Processing (Official FLOAT requirement) +librosa==0.10.1 +soundfile==0.12.1 +audioread==3.0.0 +resampy==0.4.2 +pooch==1.7.0 + +# Video Processing (Official FLOAT requirement) av==12.0.0 +imageio==2.31.1 +imageio-ffmpeg==0.4.8 + +# Math & Physics (Official FLOAT requirement) +torchdiffeq==0.2.5 +scipy==1.11.2 + +# Data Processing & Visualization (Official FLOAT requirements) +pandas==2.0.3 +matplotlib==3.7.2 +seaborn==0.12.2 +flow-vis==0.1 +pyyaml==6.0.1 +tqdm==4.65.0 + +# Web Interface - Gradio (Compatible with Python 3.8.5) +gradio==3.50.2 + +# Gradio Dependencies (pinned for stability with Python 3.8.5) +aiofiles==23.2.1 +aiohttp==3.9.5 +altair==5.0.1 +fastapi==0.103.2 +ffmpy==0.3.1 +httpx==0.25.0 +jinja2==3.1.2 +orjson==3.9.9 +pillow==10.0.1 +pydantic==1.10.13 +pydub==0.25.1 +python-multipart==0.0.6 
+requests==2.31.0 +semantic-version==2.10.0 +uvicorn==0.23.2 +websockets==11.0.3 + +# Additional ML utilities +scikit-learn==1.3.2 +numba==0.57.1 +llvmlite==0.40.1 +joblib==1.3.2 + +# System & File Handling +psutil==5.9.6 +filelock==3.13.1 +fsspec==2023.10.0 +packaging==23.2 + +# Type Checking & Validation +typing-extensions==4.8.0 +pydantic==1.10.13 + +# Network & HTTP +certifi==2023.11.17 +charset-normalizer==3.3.2 +idna==3.6 +urllib3==2.1.0 + +# Misc Core Dependencies +regex==2023.10.3 +sympy==1.12 +markupsafe==2.1.3 +six==1.16.0 +python-dateutil==2.8.2 +pytz==2023.3.post1 +tzdata==2023.3 + +# Image Processing +decorator==5.1.1 +networkx==3.1 +lazy-loader==0.3 +tifffile==2023.7.10 + +# Audio/Video Codecs +cffi==1.16.0 +pycparser==2.21 +msgpack==1.0.7 +soxr==0.3.7 + +# Optional but recommended for better performance +einops==0.7.0 +kornia==0.7.0 \ No newline at end of file From 13707f2a56a978178aa291e1da79823716d5321d Mon Sep 17 00:00:00 2001 From: Mohamed Emam <126331291+mohamed-em2m@users.noreply.github.com> Date: Tue, 25 Nov 2025 18:33:07 +0200 Subject: [PATCH 02/20] Update requirements.txt to simplify dependencies --- requirements.txt | 103 +++++------------------------------------------ 1 file changed, 9 insertions(+), 94 deletions(-) diff --git a/requirements.txt b/requirements.txt index a64ab14..6f67934 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,123 +1,38 @@ -# ============================================ -# FLOAT - Complete Pinned Requirements -# Official FLOAT Repository Compatible Versions -# Python 3.8.5 (as specified in official repo) -# Tested on Linux, A100 GPU, and V100 GPU -# ============================================ - -# Core Deep Learning - PyTorch Stack (Official FLOAT versions) +# Core Deep Learning - PyTorch Stack torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 -# NumPy (compatible with PyTorch 2.0.1 and Python 3.8.5) +# Core Dependencies numpy==1.24.4 -# Transformers & NLP (Official FLOAT requirement) +# Transformers & NLP transformers==4.30.2 -tokenizers==0.13.3 -huggingface-hub==0.16.4 -safetensors==0.3.1 # Computer Vision opencv-python==4.8.0.74 timm==1.0.9 face-alignment==1.4.1 -scikit-image==0.21.0 -# Image Augmentation (Official FLOAT requirement) +# Image Augmentation albumentations==1.4.15 albucore==0.0.16 -qudida==0.0.4 -# Audio Processing (Official FLOAT requirement) +# Audio Processing librosa==0.10.1 -soundfile==0.12.1 -audioread==3.0.0 -resampy==0.4.2 -pooch==1.7.0 -# Video Processing (Official FLOAT requirement) +# Video Processing av==12.0.0 -imageio==2.31.1 -imageio-ffmpeg==0.4.8 -# Math & Physics (Official FLOAT requirement) +# Math & Physics torchdiffeq==0.2.5 -scipy==1.11.2 -# Data Processing & Visualization (Official FLOAT requirements) +# Data Processing & Visualization pandas==2.0.3 matplotlib==3.7.2 -seaborn==0.12.2 flow-vis==0.1 pyyaml==6.0.1 tqdm==4.65.0 -# Web Interface - Gradio (Compatible with Python 3.8.5) +# Web Interface - Gradio gradio==3.50.2 - -# Gradio Dependencies (pinned for stability with Python 3.8.5) -aiofiles==23.2.1 -aiohttp==3.9.5 -altair==5.0.1 -fastapi==0.103.2 -ffmpy==0.3.1 -httpx==0.25.0 -jinja2==3.1.2 -orjson==3.9.9 -pillow==10.0.1 -pydantic==1.10.13 -pydub==0.25.1 -python-multipart==0.0.6 -requests==2.31.0 -semantic-version==2.10.0 -uvicorn==0.23.2 -websockets==11.0.3 - -# Additional ML utilities -scikit-learn==1.3.2 -numba==0.57.1 -llvmlite==0.40.1 -joblib==1.3.2 - -# System & File Handling -psutil==5.9.6 -filelock==3.13.1 -fsspec==2023.10.0 -packaging==23.2 - -# Type Checking & Validation 
-typing-extensions==4.8.0 -pydantic==1.10.13 - -# Network & HTTP -certifi==2023.11.17 -charset-normalizer==3.3.2 -idna==3.6 -urllib3==2.1.0 - -# Misc Core Dependencies -regex==2023.10.3 -sympy==1.12 -markupsafe==2.1.3 -six==1.16.0 -python-dateutil==2.8.2 -pytz==2023.3.post1 -tzdata==2023.3 - -# Image Processing -decorator==5.1.1 -networkx==3.1 -lazy-loader==0.3 -tifffile==2023.7.10 - -# Audio/Video Codecs -cffi==1.16.0 -pycparser==2.21 -msgpack==1.0.7 -soxr==0.3.7 - -# Optional but recommended for better performance -einops==0.7.0 -kornia==0.7.0 \ No newline at end of file From 941b124fdb71402bbcbaa979ebc6ca0b8a157992 Mon Sep 17 00:00:00 2001 From: Mohamed Emam <126331291+mohamed-em2m@users.noreply.github.com> Date: Tue, 25 Nov 2025 18:54:59 +0200 Subject: [PATCH 03/20] Fix import statement for inference components --- app.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/app.py b/app.py index 45d4d4a..812dcc8 100644 --- a/app.py +++ b/app.py @@ -8,7 +8,7 @@ from pathlib import Path # Import the inference components -from XXXz import InferenceAgent, InferenceOptions +from inference import InferenceAgent, InferenceOptions class GradioInterface: @@ -490,4 +490,4 @@ def parse_launch_args(): except Exception as e: print(f"\nāŒ Error launching interface: {e}") if args.debug: - raise \ No newline at end of file + raise From 9b714f41eb35315a4195af3215dd2b9be21cfd37 Mon Sep 17 00:00:00 2001 From: Mohamed Emam <126331291+mohamed-em2m@users.noreply.github.com> Date: Tue, 25 Nov 2025 19:25:00 +0200 Subject: [PATCH 04/20] Remove pip install gdown from script --- download_checkpoints.sh | 2 -- 1 file changed, 2 deletions(-) diff --git a/download_checkpoints.sh b/download_checkpoints.sh index cd6eddc..5c3072c 100644 --- a/download_checkpoints.sh +++ b/download_checkpoints.sh @@ -1,5 +1,3 @@ -pip install gdown - gdown --id 1rvWuM12cyvNvBQNCLmG4Fr2L1rpjQBF0 mv float.pth checkpoints/ From 9ef06f06def6c8bd75fe0fa197e06bab1f93a20b Mon Sep 17 00:00:00 2001 From: Mohamed Emam <126331291+mohamed-em2m@users.noreply.github.com> Date: Tue, 25 Nov 2025 19:26:00 +0200 Subject: [PATCH 05/20] Add gdown package version 5.2.0 to requirements --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 6f67934..34f1fd3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -33,6 +33,7 @@ matplotlib==3.7.2 flow-vis==0.1 pyyaml==6.0.1 tqdm==4.65.0 +gdown==5.2.0 # Web Interface - Gradio gradio==3.50.2 From ae2268385b8ed4042864bcd239eff563d9990c96 Mon Sep 17 00:00:00 2001 From: Mohamed Emam <126331291+mohamed-em2m@users.noreply.github.com> Date: Tue, 25 Nov 2025 19:26:46 +0200 Subject: [PATCH 06/20] Change import from inference to generate module --- app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app.py b/app.py index 812dcc8..3a5a10d 100644 --- a/app.py +++ b/app.py @@ -8,7 +8,7 @@ from pathlib import Path # Import the inference components -from inference import InferenceAgent, InferenceOptions +from generate import InferenceAgent, InferenceOptions class GradioInterface: From 07c6d34797c0e121e58cedd6dd066b691a8eb2bf Mon Sep 17 00:00:00 2001 From: Mohamed Emam <126331291+mohamed-em2m@users.noreply.github.com> Date: Wed, 26 Nov 2025 01:37:37 +0200 Subject: [PATCH 07/20] Change default checkpoint path in app.py --- app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app.py b/app.py index 3a5a10d..b6cd270 100644 --- a/app.py +++ b/app.py @@ -370,7 +370,7 @@ def parse_launch_args(): # Model options 
parser.add_argument('--ckpt_path', type=str, - default="/home/nvadmin/workspace/taek/float-pytorch/checkpoints/float.pth", + default="./checkpoints/float.pth", help='Path to model checkpoint') parser.add_argument('--res_dir', type=str, default="./results", help='Directory to save generated videos') From c2555376029920ef40dc5e1ed746fc107093fd00 Mon Sep 17 00:00:00 2001 From: Mohamed Emam <126331291+mohamed-em2m@users.noreply.github.com> Date: Wed, 26 Nov 2025 01:44:42 +0200 Subject: [PATCH 08/20] Check audio2emotion_path existence before loading model Add check for existence of audio2emotion_path before loading model. --- models/float/FLOAT.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/models/float/FLOAT.py b/models/float/FLOAT.py index 776fb58..0cf65c0 100644 --- a/models/float/FLOAT.py +++ b/models/float/FLOAT.py @@ -1,3 +1,4 @@ +import os import torch, math import torch.nn as nn import torch.nn.functional as F @@ -233,7 +234,11 @@ def inference(self, a: torch.Tensor, seq_len:int) -> torch.Tensor: class Audio2Emotion(nn.Module): def __init__(self, opt): super().__init__() - self.wav2vec2_for_emotion = Wav2Vec2ForSpeechClassification.from_pretrained(opt.audio2emotion_path, local_files_only=True) + if os.path.exists(opt.audio2emotion_path): + self.wav2vec2_for_emotion = Wav2Vec2ForSpeechClassification.from_pretrained(opt.audio2emotion_path, local_files_only=True) + else: + self.wav2vec2_for_emotion = Wav2Vec2ForSpeechClassification.from_pretrained(opt.audio2emotion_path) + self.wav2vec2_for_emotion.eval() # seven labels @@ -249,4 +254,4 @@ def predict_emotion(self, a: torch.Tensor, prev_a: torch.Tensor = None) -> torch logits = self.wav2vec2_for_emotion.forward(a).logits return F.softmax(logits, dim=1) # scores -####################################################### \ No newline at end of file +####################################################### From c36e85191fed2d214c0c76aba4cced1bcd340d14 Mon Sep 17 00:00:00 2001 From: Mohamed Emam <126331291+mohamed-em2m@users.noreply.github.com> Date: Wed, 26 Nov 2025 01:58:36 +0200 Subject: [PATCH 09/20] Refactor Wav2Vec model loading with path checks --- models/float/FLOAT.py | 42 ++++++++++++++++++++++++++++++------------ 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/models/float/FLOAT.py b/models/float/FLOAT.py index 0cf65c0..bf33b2b 100644 --- a/models/float/FLOAT.py +++ b/models/float/FLOAT.py @@ -185,8 +185,12 @@ def __init__(self, opt): self.num_frames_for_clip = int(opt.wav2vec_sec * self.opt.fps) self.num_prev_frames = int(opt.num_prev_frames) + if os.path.exists(opt.wav2vec_model_path): + + self.wav2vec2 = Wav2VecModel.from_pretrained(opt.wav2vec_model_path, local_files_only = True) + else: + self.wav2vec2 = Wav2VecModel.from_pretrained(opt.wav2vec_model_path) - self.wav2vec2 = Wav2VecModel.from_pretrained(opt.wav2vec_model_path, local_files_only = True) self.wav2vec2.feature_extractor._freeze_parameters() for name, param in self.wav2vec2.named_parameters(): @@ -234,24 +238,38 @@ def inference(self, a: torch.Tensor, seq_len:int) -> torch.Tensor: class Audio2Emotion(nn.Module): def __init__(self, opt): super().__init__() - if os.path.exists(opt.audio2emotion_path): - self.wav2vec2_for_emotion = Wav2Vec2ForSpeechClassification.from_pretrained(opt.audio2emotion_path, local_files_only=True) - else: - self.wav2vec2_for_emotion = Wav2Vec2ForSpeechClassification.from_pretrained(opt.audio2emotion_path) + + # Load pretrained model (local preferred) + if os.path.exists(opt.audio2emotion_path): + 
self.wav2vec2_for_emotion = Wav2Vec2ForSpeechClassification.from_pretrained( + opt.audio2emotion_path, + local_files_only=True + ) + else: + self.wav2vec2_for_emotion = Wav2Vec2ForSpeechClassification.from_pretrained( + opt.audio2emotion_path + ) self.wav2vec2_for_emotion.eval() - - # seven labels - self.id2label = {0: "angry", 1: "disgust", 2: "fear", 3: "happy", - 4: "neutral", 5: "sad", 6: "surprise"} + + # seven labels + self.id2label = { + 0: "angry", + 1: "disgust", + 2: "fear", + 3: "happy", + 4: "neutral", + 5: "sad", + 6: "surprise" + } self.label2id = {v: k for k, v in self.id2label.items()} @torch.no_grad() def predict_emotion(self, a: torch.Tensor, prev_a: torch.Tensor = None) -> torch.Tensor: + # Concatenate previous audio if provided if prev_a is not None: a = torch.cat([prev_a, a], dim=1) - logits = self.wav2vec2_for_emotion.forward(a).logits - return F.softmax(logits, dim=1) # scores -####################################################### + logits = self.wav2vec2_for_emotion(a).logits + return F.softmax(logits, dim=1) From 017dde07b32415c7f4e7decbbeb7da72cf9114f6 Mon Sep 17 00:00:00 2001 From: Mohamed Emam <126331291+mohamed-em2m@users.noreply.github.com> Date: Wed, 26 Nov 2025 02:35:55 +0200 Subject: [PATCH 10/20] Add commands to download additional models --- download_checkpoints.sh | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/download_checkpoints.sh b/download_checkpoints.sh index 5c3072c..8cc51eb 100644 --- a/download_checkpoints.sh +++ b/download_checkpoints.sh @@ -1,3 +1,9 @@ gdown --id 1rvWuM12cyvNvBQNCLmG4Fr2L1rpjQBF0 mv float.pth checkpoints/ +hf download r-f/wav2vec-english-speech-emotion-recognition \ + --local-dir ./checkpoints/wav2vec-english-speech-emotion-recognition \ + --include "*" +hf download facebook/wav2vec2-base-960h \ + --local-dir ./checkpoints/facebook--wav2vec2-base-960h \ + --include "*" From 8dcafa1dd42b540a92d6940e4ef2bb8f75fb09fe Mon Sep 17 00:00:00 2001 From: Mohamed Emam <126331291+mohamed-em2m@users.noreply.github.com> Date: Wed, 26 Nov 2025 03:44:31 +0200 Subject: [PATCH 11/20] Refactor wav2vec model path handling and defaults --- generate.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/generate.py b/generate.py index 9ca93f8..2b36c3c 100644 --- a/generate.py +++ b/generate.py @@ -1,7 +1,7 @@ """ Inference Stage 2 """ - +import os import os, torch, random, cv2, torchvision, subprocess, librosa, datetime, tempfile, face_alignment import numpy as np import albumentations as A @@ -25,7 +25,10 @@ def __init__(self, opt): self.fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, flip_input=False) # wav2vec2 audio preprocessor - self.wav2vec_preprocessor = Wav2Vec2FeatureExtractor.from_pretrained(opt.wav2vec_model_path, local_files_only=True) + if os.path.exists(opt.wav2vec_model_path): + self.wav2vec_preprocessor = Wav2Vec2FeatureExtractor.from_pretrained(opt.wav2vec_model_path, local_files_only=True) + else: + self.wav2vec_preprocessor = Wav2Vec2FeatureExtractor.from_pretrained(opt.wav2vec_model_path) # image transform self.transform = A.Compose([ @@ -173,9 +176,11 @@ def initialize(self, parser): parser.add_argument('--res_video_path', default=None, type=str, help='res video path') parser.add_argument('--ckpt_path', - default="/home/nvadmin/workspace/taek/float-pytorch/checkpoints/float.pth", type=str, help='checkpoint path') + default="./checkpoints/float.pth", type=str, help='checkpoint path') parser.add_argument('--res_dir', default="./results", type=str, help='result dir') + 
parser.add_argument('--wav2vec_model_path', + default="./checkpoints/facebook--wav2vec2-base-960h", type=str, help='wav2vec_model_path) return parser From 21dfc585bfd89aabf111f65383c42b2bc5018f8d Mon Sep 17 00:00:00 2001 From: Mohamed Emam <126331291+mohamed-em2m@users.noreply.github.com> Date: Wed, 26 Nov 2025 03:49:03 +0200 Subject: [PATCH 12/20] Update generate.py --- generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/generate.py b/generate.py index 2b36c3c..0e74154 100644 --- a/generate.py +++ b/generate.py @@ -180,7 +180,7 @@ def initialize(self, parser): parser.add_argument('--res_dir', default="./results", type=str, help='result dir') parser.add_argument('--wav2vec_model_path', - default="./checkpoints/facebook--wav2vec2-base-960h", type=str, help='wav2vec_model_path) + default="./checkpoints/facebook--wav2vec2-base-960h", type=str, help='wav2vec_model_path') return parser From f03c6531f69bab4bae25c8dd803b65d82cbd842a Mon Sep 17 00:00:00 2001 From: Mohamed Emam <126331291+mohamed-em2m@users.noreply.github.com> Date: Wed, 26 Nov 2025 04:00:37 +0200 Subject: [PATCH 13/20] Remove wav2vec_model_path from argument parser Removed wav2vec_model_path argument from parser. --- generate.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/generate.py b/generate.py index 0e74154..b0ee0f6 100644 --- a/generate.py +++ b/generate.py @@ -179,8 +179,7 @@ def initialize(self, parser): default="./checkpoints/float.pth", type=str, help='checkpoint path') parser.add_argument('--res_dir', default="./results", type=str, help='result dir') - parser.add_argument('--wav2vec_model_path', - default="./checkpoints/facebook--wav2vec2-base-960h", type=str, help='wav2vec_model_path') + return parser From d908959ebbd9908a4f2ae483e08f302273cddf38 Mon Sep 17 00:00:00 2001 From: Mohamed Emam <126331291+mohamed-em2m@users.noreply.github.com> Date: Wed, 26 Nov 2025 04:09:00 +0200 Subject: [PATCH 14/20] Fix local directory path for wav2vec2 model download --- download_checkpoints.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/download_checkpoints.sh b/download_checkpoints.sh index 8cc51eb..2746f1a 100644 --- a/download_checkpoints.sh +++ b/download_checkpoints.sh @@ -5,5 +5,5 @@ hf download r-f/wav2vec-english-speech-emotion-recognition \ --local-dir ./checkpoints/wav2vec-english-speech-emotion-recognition \ --include "*" hf download facebook/wav2vec2-base-960h \ - --local-dir ./checkpoints/facebook--wav2vec2-base-960h \ + --local-dir ./checkpoints/wav2vec2-base-960h \ --include "*" From 61724be55c69449df7d17971b7b45b9a7e5b8890 Mon Sep 17 00:00:00 2001 From: Mohamed Emam <126331291+mohamed-em2m@users.noreply.github.com> Date: Wed, 26 Nov 2025 04:10:06 +0200 Subject: [PATCH 15/20] Update app.py From 3563fc26aeb63c4c6ec92d4d4df3eb1ba62798b7 Mon Sep 17 00:00:00 2001 From: Mohamed Emam <126331291+mohamed-em2m@users.noreply.github.com> Date: Wed, 26 Nov 2025 05:39:27 +0200 Subject: [PATCH 16/20] Rename app.py to gradio_interface.py --- app.py => gradio_interface.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename app.py => gradio_interface.py (100%) diff --git a/app.py b/gradio_interface.py similarity index 100% rename from app.py rename to gradio_interface.py From 3893f5152c9fdd54bd21abbb4b1ab354b0bc12aa Mon Sep 17 00:00:00 2001 From: Mohamed Emam <126331291+mohamed-em2m@users.noreply.github.com> Date: Wed, 26 Nov 2025 05:40:19 +0200 Subject: [PATCH 17/20] Add web API dependencies to requirements.txt Added FastAPI, Hypercorn, Uvicorn, Uvloop, 
Websockets, and Httpx to requirements. --- requirements.txt | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/requirements.txt b/requirements.txt index 34f1fd3..01edb6a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -37,3 +37,10 @@ gdown==5.2.0 # Web Interface - Gradio gradio==3.50.2 +# Web API +fastapi==0.104.1 +hypercorn==0.15.0 +uvicorn==0.24.0 +uvloop==0.19.0 +websockets==12.0 +httpx==0.25.1 From bb2e43f63b1688a376b7e412b0bc0ee7a4d33231 Mon Sep 17 00:00:00 2001 From: Mohamed Emam <126331291+mohamed-em2m@users.noreply.github.com> Date: Wed, 26 Nov 2025 05:43:19 +0200 Subject: [PATCH 18/20] Update dependencies in requirements.txt Updated FastAPI to version 0.122.0 and adjusted WebSocket version constraints. Added Pydantic and Pydantic Settings with specified versions. --- requirements.txt | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 01edb6a..84c3e13 100644 --- a/requirements.txt +++ b/requirements.txt @@ -38,9 +38,18 @@ gdown==5.2.0 # Web Interface - Gradio gradio==3.50.2 # Web API -fastapi==0.104.1 +fastapi==0.122.0 hypercorn==0.15.0 uvicorn==0.24.0 + +# Async performance uvloop==0.19.0 -websockets==12.0 + +# WebSocket support +websockets>=10.0,<12.0 + +# HTTP client (if needed) httpx==0.25.1 +# Pydantic for data validation +pydantic==2.7.0 +pydantic-settings==2.1.0 From 3192dcf5f47d2adae9866a0ee3c880149903aa44 Mon Sep 17 00:00:00 2001 From: Mohamed Emam <126331291+mohamed-em2m@users.noreply.github.com> Date: Wed, 26 Nov 2025 21:48:09 +0000 Subject: [PATCH 19/20] add api system --- main.py | 673 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 673 insertions(+) create mode 100644 main.py diff --git a/main.py b/main.py new file mode 100644 index 0000000..35aa72a --- /dev/null +++ b/main.py @@ -0,0 +1,673 @@ +""" +FastAPI WebSocket API for FLOAT - Audio-Driven Talking Face Generation +Save this file as: api.py +Run with:cc +""" +import os +import json +import base64 +import tempfile +import datetime +from pathlib import Path +from typing import Optional, Dict +import asyncio +from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException +from fastapi.responses import JSONResponse +from pydantic import BaseModel, Field +import uvloop + +# Import the inference components +from generate import InferenceAgent, InferenceOptions + +# Use uvloop for better performance +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + +# Configuration +DEFAULT_IMAGE_PATH = "img.jpg" +RESULTS_DIR = "./results" +MAX_CONCURRENT_GENERATIONS = 3 # Maximum number of concurrent video generations + +# Initialize FastAPI app +app = FastAPI( + title="FLOAT WebSocket API", + description="Audio-Driven Talking Face Generation via WebSocket and REST API", + version="1.0.0" +) + +# Pydantic models for POST request +class AudioData(BaseModel): + content: str = Field(..., description="Base64 encoded audio data") + ext: str = Field(..., description="Audio file extension (e.g., 'wav', 'mp3')") + +class ImageData(BaseModel): + content: str = Field(..., description="Base64 encoded image data") + ext: str = Field(..., description="Image file extension (e.g., 'jpg', 'png')") + +class GenerationParams(BaseModel): + emotion: str = Field("S2E", description="Emotion control: S2E, angry, disgust, fear, happy, neutral, sad, surprise") + a_cfg_scale: float = Field(2.0, ge=0.0, le=5.0, description="Audio CFG scale") + r_cfg_scale: float = Field(1.0, ge=0.0, le=3.0, description="Reference CFG scale") + 
e_cfg_scale: float = Field(1.0, ge=0.0, le=3.0, description="Emotion CFG scale") + nfe: int = Field(10, ge=1, le=50, description="Number of function evaluations") + seed: int = Field(25, ge=0, description="Random seed") + no_crop: bool = Field(False, description="Skip face cropping") + +class GenerationRequest(BaseModel): + audio: AudioData + image: Optional[ImageData] = None + params: Optional[GenerationParams] = None + +class BatchGenerationRequest(BaseModel): + audios: list[AudioData] = Field(..., description="List of audio files to process") + image: Optional[ImageData] = Field(None, description="Single image to use for all audios") + images: Optional[list[ImageData]] = Field(None, description="List of images (one per audio)") + params: Optional[GenerationParams] = Field(None, description="Generation parameters to use for all") + + class Config: + schema_extra = { + "example": { + "audios": [ + {"content": "base64_audio_1", "ext": "wav"}, + {"content": "base64_audio_2", "ext": "wav"} + ], + "image": {"content": "base64_image", "ext": "jpg"}, + "params": {"emotion": "happy", "nfe": 10} + } + } + +# Global model instance and semaphore +agent = None +opt = None +generation_semaphore = None # Will be initialized on startup +active_generations = 0 # Track active generation count +generation_lock = asyncio.Lock() # Lock for updating counter + +@app.on_event("startup") +async def startup_event(): + """Initialize model on startup""" + global agent, opt, generation_semaphore + + print("\n" + "="*60) + print("šŸš€ Starting FLOAT WebSocket API") + print("="*60) + + # Initialize the semaphore for concurrent generation control + generation_semaphore = asyncio.Semaphore(MAX_CONCURRENT_GENERATIONS) + print(f"šŸ”’ Concurrent generation limit: {MAX_CONCURRENT_GENERATIONS}") + + # Initialize options with empty args to use defaults + import sys + original_argv = sys.argv.copy() + sys.argv = [sys.argv[0]] # Keep only script name + + opt = InferenceOptions().parse() + opt.rank, opt.ngpus = 0, 1 + opt.res_dir = RESULTS_DIR + + sys.argv = original_argv + + # Create results directory + os.makedirs(opt.res_dir, exist_ok=True) + + # Load model + print("šŸ¤– Loading FLOAT model...") + agent = InferenceAgent(opt) + print("āœ… Model loaded successfully!") + + # Check default image + if os.path.exists(DEFAULT_IMAGE_PATH): + print(f"šŸ“ø Default image: {DEFAULT_IMAGE_PATH}") + else: + print(f"āš ļø Warning: Default image not found at {DEFAULT_IMAGE_PATH}") + + print(f"šŸ’¾ Results directory: {opt.res_dir}") + print("="*60) + print("āœ… Server ready! 
Waiting for connections...") + print("="*60 + "\n") + +@app.get("/") +async def root(): + """Health check endpoint""" + return { + "status": "online", + "service": "FLOAT WebSocket API", + "version": "1.0.0", + "max_concurrent_generations": MAX_CONCURRENT_GENERATIONS, + "active_generations": active_generations, + "endpoints": { + "websocket": "/ws", + "rest_api": "/generate", + "batch_api": "/generate/batch", + "health": "/health", + "status": "/status" + } + } + +@app.get("/health") +async def health_check(): + """Detailed health check""" + return { + "status": "healthy", + "model_loaded": agent is not None, + "default_image_exists": os.path.exists(DEFAULT_IMAGE_PATH), + "results_dir": opt.res_dir if opt else None, + "max_concurrent_generations": MAX_CONCURRENT_GENERATIONS, + "active_generations": active_generations + } + +@app.get("/status") +async def get_status(): + """Get current generation status""" + available_slots = MAX_CONCURRENT_GENERATIONS - active_generations + return { + "max_concurrent_generations": MAX_CONCURRENT_GENERATIONS, + "active_generations": active_generations, + "available_slots": available_slots, + "queue_full": available_slots == 0 + } + +@app.post("/generate") +async def generate_video(request: GenerationRequest): + """ + REST API endpoint for video generation + + Request body: + { + "audio": {"content": "base64_data", "ext": "wav"}, + "image": {"content": "base64_data", "ext": "jpg"}, // Optional + "params": { + "emotion": "S2E", + "a_cfg_scale": 2.0, + "r_cfg_scale": 1.0, + "e_cfg_scale": 1.0, + "nfe": 10, + "seed": 25, + "no_crop": false + } + } + + Returns: + { + "status": "success", + "video": "base64_encoded_video", + "video_path": "/path/to/video.mp4", + "message": "Video generated successfully", + "params_used": {...} + } + """ + if agent is None: + raise HTTPException(status_code=503, detail="Model not loaded yet") + + try: + # Convert Pydantic models to dict format expected by process_generation_request + request_data = { + "audio": { + "content": request.audio.content, + "ext": request.audio.ext + } + } + + # Add image if provided + if request.image: + request_data["image"] = { + "content": request.image.content, + "ext": request.image.ext + } + + # Add params if provided + if request.params: + request_data["params"] = { + "emotion": request.params.emotion, + "a_cfg_scale": request.params.a_cfg_scale, + "r_cfg_scale": request.params.r_cfg_scale, + "e_cfg_scale": request.params.e_cfg_scale, + "nfe": request.params.nfe, + "seed": request.params.seed, + "no_crop": request.params.no_crop + } + + # Process the request with semaphore control + result = await process_generation_request(request_data) + + # Return error if processing failed + if result.get("status") == "error": + raise HTTPException(status_code=500, detail=result.get("error", "Generation failed")) + + return result + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") + +@app.post("/generate/batch") +async def generate_video_batch(request: BatchGenerationRequest): + """ + Batch video generation endpoint - process multiple audios at once + + Request body: + { + "audios": [ + {"content": "base64_audio_1", "ext": "wav"}, + {"content": "base64_audio_2", "ext": "wav"} + ], + "image": {"content": "base64_image", "ext": "jpg"}, // Optional: single image for all + "images": [ // Optional: one image per audio (overrides "image") + {"content": "base64_image_1", "ext": "jpg"}, + {"content": "base64_image_2", "ext": 
"jpg"} + ], + "params": { + "emotion": "S2E", + "nfe": 10, + ... + } + } + + Returns: + { + "status": "success", + "total": 2, + "successful": 2, + "failed": 0, + "results": [ + { + "index": 0, + "status": "success", + "video": "base64_video_1", + "video_path": "/path/to/video1.mp4", + ... + }, + { + "index": 1, + "status": "success", + "video": "base64_video_2", + "video_path": "/path/to/video2.mp4", + ... + } + ] + } + """ + if agent is None: + raise HTTPException(status_code=503, detail="Model not loaded yet") + + try: + # Validate input + num_audios = len(request.audios) + if num_audios == 0: + raise HTTPException(status_code=400, detail="At least one audio file is required") + + # Check if using per-audio images or single image + if request.images: + if len(request.images) != num_audios: + raise HTTPException( + status_code=400, + detail=f"Number of images ({len(request.images)}) must match number of audios ({num_audios})" + ) + use_multiple_images = True + else: + use_multiple_images = False + + print(f"\nšŸ“¦ Processing batch of {num_audios} audio files...") + print(f"šŸ”’ Concurrent limit: {MAX_CONCURRENT_GENERATIONS}, Active: {active_generations}") + + # Process each audio (semaphore will automatically control concurrency) + tasks = [] + + for idx, audio in enumerate(request.audios): + # Build request for this audio + single_request = { + "audio": { + "content": audio.content, + "ext": audio.ext + } + } + + # Determine which image to use + if use_multiple_images: + single_request["image"] = { + "content": request.images[idx].content, + "ext": request.images[idx].ext + } + elif request.image: + single_request["image"] = { + "content": request.image.content, + "ext": request.image.ext + } + + # Add params if provided + if request.params: + single_request["params"] = { + "emotion": request.params.emotion, + "a_cfg_scale": request.params.a_cfg_scale, + "r_cfg_scale": request.params.r_cfg_scale, + "e_cfg_scale": request.params.e_cfg_scale, + "nfe": request.params.nfe, + "seed": request.params.seed, + "no_crop": request.params.no_crop + } + + # Create task for this generation + task = process_single_batch_item(idx, single_request, num_audios) + tasks.append(task) + + # Wait for all tasks to complete (concurrency controlled by semaphore) + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Process results + successful = 0 + failed = 0 + processed_results = [] + + for idx, result in enumerate(results): + if isinstance(result, Exception): + failed += 1 + processed_results.append({ + "index": idx, + "status": "error", + "error": str(result), + "message": f"Failed to process audio {idx + 1}" + }) + else: + if result.get("status") == "success": + successful += 1 + else: + failed += 1 + processed_results.append(result) + + print(f"\nšŸ“Š Batch processing complete: {successful} successful, {failed} failed") + + # Return batch results + return { + "status": "success" if failed == 0 else "partial", + "total": num_audios, + "successful": successful, + "failed": failed, + "results": processed_results, + "message": f"Processed {num_audios} audios: {successful} successful, {failed} failed" + } + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Batch processing error: {str(e)}") + +async def process_single_batch_item(idx: int, request_data: Dict, total: int) -> Dict: + """Process a single batch item with logging""" + print(f"\nšŸŽµ Queuing audio {idx + 1}/{total}...") + + try: + result = await 
process_generation_request(request_data) + result["index"] = idx + + if result.get("status") == "success": + print(f"āœ… Audio {idx + 1} completed successfully") + else: + print(f"āŒ Audio {idx + 1} failed: {result.get('error', 'Unknown error')}") + + return result + + except Exception as e: + print(f"āŒ Audio {idx + 1} failed with exception: {str(e)}") + return { + "index": idx, + "status": "error", + "error": str(e), + "message": f"Failed to process audio {idx + 1}" + } + +def save_base64_file(base64_content: str, extension: str, prefix: str = "temp") -> str: + """Save base64 content to a temporary file""" + temp_file = tempfile.NamedTemporaryFile( + delete=False, + suffix=f".{extension}", + prefix=f"{prefix}_" + ) + + file_data = base64.b64decode(base64_content) + temp_file.write(file_data) + temp_file.close() + + return temp_file.name + +def read_file_as_base64(file_path: str) -> str: + """Read file and return as base64 string""" + with open(file_path, 'rb') as f: + return base64.b64encode(f.read()).decode('utf-8') + +async def process_generation_request(request_data: Dict) -> Dict: + """ + Process video generation request with semaphore-based concurrency control + + Expected format: + { + "audio": {"content": "base64_data", "ext": "wav"}, + "image": {"content": "base64_data", "ext": "jpg"}, # Optional + "params": { + "emotion": "S2E", + "a_cfg_scale": 2.0, + "r_cfg_scale": 1.0, + "e_cfg_scale": 1.0, + "nfe": 10, + "seed": 25, + "no_crop": false + } + } + """ + global active_generations + + # Acquire semaphore to limit concurrent generations + async with generation_semaphore: + # Update active generation counter + async with generation_lock: + active_generations += 1 + current_count = active_generations + + print(f"šŸ”’ Generation slot acquired ({current_count}/{MAX_CONCURRENT_GENERATIONS} active)") + + temp_files = [] + + try: + # Validate audio + if "audio" not in request_data: + return { + "status": "error", + "error": "Missing 'audio' field in request" + } + + audio_data = request_data["audio"] + if "content" not in audio_data or "ext" not in audio_data: + return { + "status": "error", + "error": "Audio must have 'content' and 'ext' fields" + } + + # Save audio to temp file + audio_ext = audio_data["ext"].lstrip('.') + audio_path = save_base64_file(audio_data["content"], audio_ext, "audio") + temp_files.append(audio_path) + print(f"šŸ’¾ Saved audio: {audio_path}") + + # Handle image + if "image" in request_data and request_data["image"]: + image_data = request_data["image"] + if "content" not in image_data or "ext" not in image_data: + return { + "status": "error", + "error": "Image must have 'content' and 'ext' fields" + } + + image_ext = image_data["ext"].lstrip('.') + image_path = save_base64_file(image_data["content"], image_ext, "image") + temp_files.append(image_path) + print(f"šŸ“ø Using uploaded image: {image_path}") + else: + if not os.path.exists(DEFAULT_IMAGE_PATH): + return { + "status": "error", + "error": f"Default image not found: {DEFAULT_IMAGE_PATH}" + } + image_path = DEFAULT_IMAGE_PATH + print(f"šŸ“ø Using default image: {image_path}") + + # Extract parameters + params = request_data.get("params", {}) + emotion = params.get("emotion", "S2E") + a_cfg_scale = params.get("a_cfg_scale", 2.0) + r_cfg_scale = params.get("r_cfg_scale", 1.0) + e_cfg_scale = params.get("e_cfg_scale", 1.0) + nfe = params.get("nfe", 10) + seed = params.get("seed", 25) + no_crop = params.get("no_crop", False) + + # Generate output path + call_time = 
datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + res_video_path = os.path.join( + opt.res_dir, + f"{call_time}_generated_nfe{nfe}_seed{seed}.mp4" + ) + + print(f"šŸŽ¬ Generating video...") + print(f" Emotion: {emotion}, NFE: {nfe}, Seed: {seed}") + + # Run inference in thread pool to avoid blocking + loop = asyncio.get_event_loop() + output_path = await loop.run_in_executor( + None, + lambda: agent.run_inference( + res_video_path=res_video_path, + ref_path=image_path, + audio_path=audio_path, + a_cfg_scale=a_cfg_scale, + r_cfg_scale=r_cfg_scale, + e_cfg_scale=e_cfg_scale, + emo=emotion, + nfe=nfe, + no_crop=no_crop, + seed=seed, + verbose=True + ) + ) + + print(f"āœ… Video generated: {output_path}") + + # Read video as base64 + video_base64 = read_file_as_base64(output_path) + + return { + "status": "success", + "video": video_base64, + "video_path": output_path, + "message": "Video generated successfully", + "params_used": { + "emotion": emotion, + "nfe": nfe, + "seed": seed, + "a_cfg_scale": a_cfg_scale, + "r_cfg_scale": r_cfg_scale, + "e_cfg_scale": e_cfg_scale + } + } + + except Exception as e: + print(f"āŒ Error during processing: {str(e)}") + import traceback + traceback.print_exc() + + return { + "status": "error", + "error": str(e), + "message": "Failed to generate video" + } + + finally: + # Cleanup temp files + for temp_file in temp_files: + try: + if os.path.exists(temp_file): + os.remove(temp_file) + print(f"šŸ—‘ļø Cleaned up: {temp_file}") + except Exception as e: + print(f"āš ļø Failed to delete {temp_file}: {e}") + + # Release generation slot + async with generation_lock: + active_generations -= 1 + remaining = active_generations + + print(f"šŸ”“ Generation slot released ({remaining}/{MAX_CONCURRENT_GENERATIONS} active)") + +@app.websocket("/ws") +async def websocket_endpoint(websocket: WebSocket): + """WebSocket endpoint for video generation""" + await websocket.accept() + + client_id = f"{websocket.client.host}:{websocket.client.port}" + print(f"\nšŸ”Œ New connection from {client_id}") + + try: + # Send welcome message + await websocket.send_json({ + "status": "connected", + "message": "Connected to FLOAT WebSocket API", + "version": "1.0.0", + "max_concurrent_generations": MAX_CONCURRENT_GENERATIONS + }) + + # Process messages + while True: + try: + # Receive message + message = await websocket.receive_text() + print(f"šŸ“Ø Received message from {client_id}") + + # Parse JSON + request_data = json.loads(message) + + # Send processing acknowledgment + await websocket.send_json({ + "status": "processing", + "message": "Processing your request...", + "active_generations": active_generations, + "max_concurrent": MAX_CONCURRENT_GENERATIONS + }) + + # Process request (semaphore will control concurrency) + response = await process_generation_request(request_data) + + # Send response + await websocket.send_json(response) + print(f"āœ‰ļø Sent response to {client_id}") + + except json.JSONDecodeError as e: + await websocket.send_json({ + "status": "error", + "error": f"Invalid JSON: {str(e)}" + }) + print(f"āš ļø JSON decode error from {client_id}") + + except Exception as e: + await websocket.send_json({ + "status": "error", + "error": f"Processing error: {str(e)}" + }) + print(f"āŒ Error handling message from {client_id}: {e}") + + except WebSocketDisconnect: + print(f"šŸ”Œ Client disconnected: {client_id}") + + except Exception as e: + print(f"āŒ WebSocket error with {client_id}: {e}") + + finally: + print(f"šŸ‘‹ Connection closed: {client_id}\n") + +# For running with 
uvicorn (alternative to hypercorn) +"""if __name__ == "__main__": + import uvicorn + uvicorn.run( + "main:app", + host="0.0.0.0", + port=8000, + reload=False, + log_level="info" + )""" \ No newline at end of file From b78f57c21e8d1331bff7adbf04449a0f2bd19ad6 Mon Sep 17 00:00:00 2001 From: Mohamed Emam <126331291+mohamed-em2m@users.noreply.github.com> Date: Thu, 27 Nov 2025 01:40:16 +0200 Subject: [PATCH 20/20] Add matplotlib_inline version 0.1.7 to requirements --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 84c3e13..9d2d3e6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -53,3 +53,4 @@ httpx==0.25.1 # Pydantic for data validation pydantic==2.7.0 pydantic-settings==2.1.0 +matplotlib_inline==0.1.7
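
Usage sketch for the API added in PATCH 19: the request/response format documented in main.py is easiest to see from a client, so a minimal REST client for the POST /generate endpoint follows. It assumes the API is served locally on port 8000, matching the commented-out uvicorn block at the end of main.py; `speech.wav`, `face.jpg`, and `output.mp4` are placeholder paths, and httpx is used because it is pinned in requirements.txt.

# rest_client_example.py -- sketch only; the URL and file paths below are placeholders
import base64

import httpx

BASE_URL = "http://localhost:8000"   # assumes a local `uvicorn main:app --port 8000` deployment


def b64(path: str) -> str:
    """The API expects file contents as base64 strings (see AudioData / ImageData)."""
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")


payload = {
    "audio": {"content": b64("speech.wav"), "ext": "wav"},
    "image": {"content": b64("face.jpg"), "ext": "jpg"},  # optional; server falls back to img.jpg
    "params": {                                           # mirrors GenerationParams defaults
        "emotion": "S2E",
        "a_cfg_scale": 2.0,
        "r_cfg_scale": 1.0,
        "e_cfg_scale": 1.0,
        "nfe": 10,
        "seed": 25,
        "no_crop": False,
    },
}

# Sampling plus video encoding can take a while, so use a generous timeout.
resp = httpx.post(f"{BASE_URL}/generate", json=payload, timeout=600.0)
resp.raise_for_status()
result = resp.json()

if result.get("status") == "success":
    # The generated video is returned base64-encoded next to its server-side path.
    with open("output.mp4", "wb") as f:
        f.write(base64.b64decode(result["video"]))
    print("Saved output.mp4; server path:", result.get("video_path"))
else:
    print("Generation failed:", result.get("error"))

The /generate/batch endpoint accepts the same audio/image/params objects, wrapped in an `audios` list (plus either a single `image` or a per-audio `images` list), and returns one result entry per audio.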
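
The WebSocket path at /ws speaks the same JSON schema: the server sends a "connected" banner on accept, a "processing" acknowledgment per request, and then the final result. A minimal asyncio client sketch follows, under the same localhost:8000 assumption; it relies on the `websockets` package pinned in requirements.txt, and `speech.wav` / `output_ws.mp4` are again placeholders.

# ws_client_example.py -- sketch only; URL and file names are placeholders
import asyncio
import base64
import json

import websockets

WS_URL = "ws://localhost:8000/ws"   # assumes a local deployment on port 8000


async def generate(audio_path: str = "speech.wav") -> None:
    with open(audio_path, "rb") as f:
        audio_b64 = base64.b64encode(f.read()).decode("utf-8")

    request = {
        "audio": {"content": audio_b64, "ext": "wav"},
        # "image" omitted on purpose: the server falls back to DEFAULT_IMAGE_PATH (img.jpg)
        "params": {"emotion": "happy", "nfe": 10, "seed": 25},
    }

    # max_size=None lifts the 1 MiB default frame limit, since the reply embeds a base64 video.
    async with websockets.connect(WS_URL, max_size=None) as ws:
        print(json.loads(await ws.recv()))          # {"status": "connected", ...}
        await ws.send(json.dumps(request))

        while True:
            msg = json.loads(await ws.recv())
            if msg.get("status") == "processing":   # acknowledgment before inference starts
                continue
            if msg.get("status") == "success":
                with open("output_ws.mp4", "wb") as f:
                    f.write(base64.b64decode(msg["video"]))
                print("Saved output_ws.mp4")
            else:
                print("Error:", msg.get("error"))
            break


if __name__ == "__main__":
    asyncio.run(generate())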