diff --git a/.devcontainer/post-create.sh b/.devcontainer/post-create.sh index 84787613b..bc2fdfa6a 100755 --- a/.devcontainer/post-create.sh +++ b/.devcontainer/post-create.sh @@ -5,7 +5,7 @@ cd /workspace/comfystream # Install Comfystream in editable mode. echo -e "\e[32mInstalling Comfystream in editable mode...\e[0m" -/workspace/miniconda3/envs/comfystream/bin/python3 -m pip install -e . --root-user-action=ignore > /dev/null +/workspace/miniconda3/envs/comfystream/bin/python3 -m pip install -e . -c src/comfystream/scripts/constraints.txt --root-user-action=ignore > /dev/null # Install npm packages if needed if [ ! -d "/workspace/comfystream/ui/node_modules" ]; then diff --git a/.editorconfig b/.editorconfig index c39bef288..17f476700 100644 --- a/.editorconfig +++ b/.editorconfig @@ -13,7 +13,9 @@ insert_final_newline = true insert_final_newline = unset [*.py] +indent_style = space indent_size = 4 +trim_trailing_whitespace = false [workflows/comfy*/*.json] insert_final_newline = unset diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml index 608a0edc4..f231448d7 100644 --- a/.github/workflows/docker.yaml +++ b/.github/workflows/docker.yaml @@ -9,6 +9,12 @@ on: - main tags: - "v*" + workflow_dispatch: + inputs: + nodes_config: + description: "Custom nodes config filename or path for base image build" + required: false + default: "nodes.yaml" concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} @@ -27,7 +33,7 @@ jobs: runs-on: [self-hosted, linux, gpu] steps: - name: Check out code - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 ref: ${{ github.event.pull_request.head.sha }} @@ -78,33 +84,12 @@ jobs: file: docker/Dockerfile.base build-args: | CACHEBUST=${{ github.run_id }} + NODES_CONFIG=${{ github.event_name == 'workflow_dispatch' && github.event.inputs.nodes_config || 'nodes.yaml' }} labels: ${{ steps.meta.outputs.labels }} annotations: ${{ steps.meta.outputs.annotations }} 
cache-from: type=registry,ref=livepeer/comfyui-base:build-cache cache-to: type=registry,mode=max,ref=livepeer/comfyui-base:build-cache - trigger: - name: Trigger ai-runner workflow - needs: base - if: ${{ github.repository == 'livepeer/comfystream' }} - runs-on: ubuntu-latest - steps: - - name: Send workflow dispatch event to ai-runner - uses: actions/github-script@v7 - with: - github-token: ${{ secrets.CI_GITHUB_TOKEN }} - script: | - await github.rest.actions.createWorkflowDispatch({ - owner: context.repo.owner, - repo: "ai-runner", - workflow_id: "comfyui-trigger.yaml", - ref: "main", - inputs: { - "comfyui-base-digest": "${{ needs.base.outputs.image-digest }}", - "triggering-branch": "${{ github.head_ref || github.ref_name }}", - }, - }); - comfystream: name: comfystream image needs: base @@ -115,7 +100,7 @@ jobs: runs-on: [self-hosted, linux, amd64] steps: - name: Check out code - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 ref: ${{ github.event.pull_request.head.sha }} diff --git a/.github/workflows/opencv-cuda-artifact.yml b/.github/workflows/opencv-cuda-artifact.yml new file mode 100644 index 000000000..df7071903 --- /dev/null +++ b/.github/workflows/opencv-cuda-artifact.yml @@ -0,0 +1,184 @@ +name: Build OpenCV CUDA Artifact + +on: + workflow_dispatch: + inputs: + python_version: + description: 'Python version to build' + required: false + default: '3.12' + type: string + cuda_version: + description: 'CUDA version to build' + required: false + default: '12.8' + type: string + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +env: + PYTHON_VERSION: ${{ github.event.inputs.python_version || '3.12' }} + CUDA_VERSION: ${{ github.event.inputs.cuda_version || '12.8' }} + +jobs: + build-opencv-artifact: + name: Build OpenCV CUDA Artifact + runs-on: [self-hosted, linux, gpu] + + steps: + - name: Checkout code + uses: actions/checkout@v6 + with: + fetch-depth: 0 + ref: 
${{ github.event.pull_request.head.sha || github.sha }} + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Build OpenCV CUDA Docker image + uses: docker/build-push-action@v6 + with: + context: . + file: docker/Dockerfile.opencv + build-args: | + BASE_IMAGE=nvidia/cuda:${{ env.CUDA_VERSION }}.1-cudnn-devel-ubuntu22.04 + PYTHON_VERSION=${{ env.PYTHON_VERSION }} + CUDA_VERSION=${{ env.CUDA_VERSION }} + tags: opencv-cuda-artifact:latest + load: true + cache-from: type=gha + cache-to: type=gha,mode=max + + - name: Extract OpenCV libraries from Docker container + run: | + echo "Creating temporary container..." + docker create --name opencv-extract opencv-cuda-artifact:latest + + echo "Creating workspace directory..." + mkdir -p ./opencv-artifacts + + # Try to copy from system installation + docker cp opencv-extract:/usr/local/lib/python${{ env.PYTHON_VERSION }}/site-packages/cv2 ./opencv-artifacts/cv2 || echo "cv2 not found in system site-packages" + + echo "Copying OpenCV source directories..." + # Copy opencv and opencv_contrib source directories + docker cp opencv-extract:/workspace/opencv ./opencv-artifacts/ || echo "opencv source not found" + docker cp opencv-extract:/workspace/opencv_contrib ./opencv-artifacts/ || echo "opencv_contrib source not found" + + echo "Cleaning up container..." + docker rm opencv-extract + + echo "Contents of opencv-artifacts:" + ls -la ./opencv-artifacts/ + + - name: Create tarball artifact + run: | + echo "Creating opencv-cuda-release.tar.gz..." + cd ./opencv-artifacts + tar -czf ../opencv-cuda-release.tar.gz . || echo "Failed to create tarball" + cd .. + + echo "Generating checksums..." + sha256sum opencv-cuda-release.tar.gz > opencv-cuda-release.tar.gz.sha256 + md5sum opencv-cuda-release.tar.gz > opencv-cuda-release.tar.gz.md5 + + echo "Verifying archive contents..." 
+ echo "Archive size: $(ls -lh opencv-cuda-release.tar.gz | awk '{print $5}')" + echo "First 20 files in archive:" + tar -tzf opencv-cuda-release.tar.gz | head -20 + + - name: Extract and verify tarball + run: | + echo "Testing tarball extraction..." + mkdir -p test-extract + cd test-extract + tar -xzf ../opencv-cuda-release.tar.gz + echo "Extracted contents:" + find . -maxdepth 2 -type d | sort + cd .. + rm -rf test-extract + + - name: Upload OpenCV CUDA Release Artifact + uses: actions/upload-artifact@v5 + with: + name: opencv-cuda-release-python${{ env.PYTHON_VERSION }}-cuda${{ env.CUDA_VERSION }}-${{ github.sha }} + path: | + opencv-cuda-release.tar.gz + opencv-cuda-release.tar.gz.sha256 + opencv-cuda-release.tar.gz.md5 + retention-days: 30 + + - name: Create Release Notes + run: | + cat > release-info.txt << EOF + OpenCV CUDA Release Artifact + + Build Details: + - Python Version: ${{ env.PYTHON_VERSION }} + - CUDA Version: ${{ env.CUDA_VERSION }} + - OpenCV Version: 4.11.0 + - Built on: $(date -u) + - Commit SHA: ${{ github.sha }} + + Contents: + - cv2: Python OpenCV module with CUDA support + - opencv: OpenCV source code + - opencv_contrib: OpenCV contrib modules source + - lib: Compiled OpenCV libraries + - include: OpenCV header files + + Installation: + 1. Download opencv-cuda-release.tar.gz + 2. Extract: tar -xzf opencv-cuda-release.tar.gz + 3. Copy cv2 to your Python environment's site-packages + 4. 
Ensure CUDA libraries are in your system PATH + + Checksums: + SHA256: $(cat opencv-cuda-release.tar.gz.sha256) + MD5: $(cat opencv-cuda-release.tar.gz.md5) + EOF + + - name: Upload Release Info + uses: actions/upload-artifact@v5 + with: + name: release-info-python${{ env.PYTHON_VERSION }}-cuda${{ env.CUDA_VERSION }}-${{ github.sha }} + path: release-info.txt + retention-days: 30 + + create-release-draft: + name: Create Release Draft + needs: build-opencv-artifact + runs-on: ubuntu-latest + if: github.event_name == 'push' && github.ref == 'refs/heads/main' + + steps: + - name: Checkout code + uses: actions/checkout@v6 + + - name: Download artifacts + uses: actions/download-artifact@v6 + with: + name: opencv-cuda-release-python${{ env.PYTHON_VERSION }}-cuda${{ env.CUDA_VERSION }}-${{ github.sha }} + path: ./artifacts + + - name: Download release info + uses: actions/download-artifact@v6 + with: + name: release-info-python${{ env.PYTHON_VERSION }}-cuda${{ env.CUDA_VERSION }}-${{ github.sha }} + path: ./artifacts + + - name: Create Release Draft + uses: softprops/action-gh-release@v2 + with: + tag_name: opencv-cuda-v${{ env.PYTHON_VERSION }}-${{ env.CUDA_VERSION }}-${{ github.run_number }} + name: OpenCV CUDA Release - Python ${{ env.PYTHON_VERSION }} CUDA ${{ env.CUDA_VERSION }} + body_path: ./artifacts/release-info.txt + draft: true + files: | + ./artifacts/opencv-cuda-release.tar.gz + ./artifacts/opencv-cuda-release.tar.gz.sha256 + ./artifacts/opencv-cuda-release.tar.gz.md5 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} \ No newline at end of file diff --git a/.github/workflows/publish-comfyui-node.yaml b/.github/workflows/publish-comfyui-node.yaml new file mode 100644 index 000000000..0b609aee3 --- /dev/null +++ b/.github/workflows/publish-comfyui-node.yaml @@ -0,0 +1,26 @@ +name: Publish ComfyUI Custom Node + +on: + workflow_dispatch: + +permissions: + contents: read + issues: write + +jobs: + publish-comfyui-node: + name: Publish Custom Node to ComfyUI 
registry + runs-on: ubuntu-latest + # Ensure this only runs on main branch + if: ${{ github.ref == 'refs/heads/main' }} + steps: + - name: Check out code + uses: actions/checkout@v6 + with: + submodules: true + + - name: Publish Custom Node + uses: Comfy-Org/publish-node-action@v1 + with: + ## Add your own personal access token to your Github Repository secrets and reference it here. + personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 01900b71a..58d270929 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -12,9 +12,9 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 - - uses: actions/setup-node@v4 + - uses: actions/setup-node@v6 with: node-version: 18 cache: npm @@ -38,7 +38,7 @@ jobs: cd - - name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v5 with: name: release-artifacts path: releases/ diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 98d27d262..98baa6cf3 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -19,20 +19,20 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out code - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 ref: ${{ github.event.pull_request.head.sha }} # Initializes the CodeQL tools for scanning. 
- name: Initialize CodeQL - uses: github/codeql-action/init@v3 + uses: github/codeql-action/init@v4 with: languages: typescript,javascript,python config-file: ./.github/codeql-config.yaml - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v3 + uses: github/codeql-action/analyze@v4 editorconfig: @@ -40,7 +40,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out code - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: # Check https://github.com/livepeer/go-livepeer/pull/1891 # for ref value discussion @@ -59,14 +59,14 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out code - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: # Check https://github.com/livepeer/go-livepeer/pull/1891 # for ref value discussion ref: ${{ github.event.pull_request.head.sha }} - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.12' cache: pip diff --git a/.husky/pre-commit b/.husky/pre-commit new file mode 100755 index 000000000..d9a28aeb7 --- /dev/null +++ b/.husky/pre-commit @@ -0,0 +1,3 @@ +#!/bin/sh +cd ui && npx lint-staged + diff --git a/.vscode/launch.json b/.vscode/launch.json index 4d442c585..f05e02f5e 100755 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -56,7 +56,7 @@ "env": { "ORCH_URL": "https://172.17.0.1:9995", "ORCH_SECRET": "orch-secret", - "CAPABILITY_NAME": "comfystream-byoc-processor", + "CAPABILITY_NAME": "comfystream", "CAPABILITY_DESCRIPTION": "ComfyUI streaming processor for BYOC mode", "CAPABILITY_URL": "http://172.17.0.1:8000", "CAPABILITY_PRICE_PER_UNIT": "0", diff --git a/README.md b/README.md index 52c864f13..a9edf4be5 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ This repo also includes a WebRTC server and UI that uses comfystream to support Refer to [.devcontainer/README.md](.devcontainer/README.md) to setup ComfyStream in a devcontainer using a pre-configured ComfyUI docker environment. 
-For other installation options, refer to [Install ComfyUI and ComfyStream](https://pipelines.livepeer.org/docs/technical/install/local-testing) in the Livepeer pipelines documentation. +For other installation options, refer to [Install ComfyUI and ComfyStream](https://docs.comfystream.org/technical/get-started/install) in the ComfyStream documentation. For additional information, refer to the remaining sections below. @@ -35,7 +35,7 @@ For additional information, refer to the remaining sections below. You can quickly deploy ComfyStream using the docker image `livepeer/comfystream` -Refer to the documentation at [https://pipelines.livepeer.org/docs/technical/getting-started/install-comfystream](https://pipelines.livepeer.org/docs/technical/getting-started/install-comfystream) for instructions to run locally or on a remote server. +Refer to the documentation at [https://docs.comfystream.org/technical/get-started/install](https://docs.comfystream.org/technical/get-started/install) for instructions to run locally or on a remote server. 
#### RunPod diff --git a/__init__.py b/__init__.py index 1215db9b3..7782d363f 100644 --- a/__init__.py +++ b/__init__.py @@ -4,4 +4,4 @@ # Import and expose node classes from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS -__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] +__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] diff --git a/benchmark.py b/benchmark.py index 359a8c237..8e096a77f 100644 --- a/benchmark.py +++ b/benchmark.py @@ -1,15 +1,16 @@ -import av +import argparse +import asyncio import json -import time -import torch import logging -import asyncio -import argparse +import time + +import av import numpy as np +import torch from comfystream.client import ComfyStreamClient -logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s') +logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger() @@ -22,14 +23,30 @@ def create_dummy_video_frame(width, height): async def main(): parser = argparse.ArgumentParser(description="Benchmark ComfyStreamClient workflow execution.") - parser.add_argument("--workflow-path", default="./workflows/comfystream/tensor-utils-example-api.json", help="Path to the workflow JSON file.") + parser.add_argument( + "--workflow-path", + default="./workflows/comfystream/tensor-utils-example-api.json", + help="Path to the workflow JSON file.", + ) parser.add_argument("--num-requests", type=int, default=100, help="Number of requests to send.") - parser.add_argument("--fps", type=float, default=None, help="Frames per second for FPS-based benchmarking.") - parser.add_argument("--cwd", default="/workspace/ComfyUI", help="Current working directory for ComfyStreamClient.") + parser.add_argument( + "--fps", type=float, default=None, help="Frames per second for FPS-based benchmarking." 
+ ) + parser.add_argument( + "--cwd", + default="/workspace/ComfyUI", + help="Current working directory for ComfyStreamClient.", + ) parser.add_argument("--width", type=int, default=512, help="Width of dummy video frames.") parser.add_argument("--height", type=int, default=512, help="Height of dummy video frames.") - parser.add_argument("--verbose", action="store_true", help="Enable verbose logging (shows progress for each request).") - parser.add_argument("--warmup-runs", type=int, default=5, help="Number of warm-up runs before benchmarking.") + parser.add_argument( + "--verbose", + action="store_true", + help="Enable verbose logging (shows progress for each request).", + ) + parser.add_argument( + "--warmup-runs", type=int, default=5, help="Number of warm-up runs before benchmarking." + ) args = parser.parse_args() @@ -44,7 +61,9 @@ async def main(): client = ComfyStreamClient(cwd=args.cwd) await client.set_prompts([prompt]) - logger.info(f"Starting benchmark with workflow: {args.workflow_path}, requests: {args.num_requests}, resolution: {args.width}x{args.height}, warmup runs: {args.warmup_runs}") + logger.info( + f"Starting benchmark with workflow: {args.workflow_path}, requests: {args.num_requests}, resolution: {args.width}x{args.height}, warmup runs: {args.warmup_runs}" + ) if args.warmup_runs > 0: logger.info(f"Running {args.warmup_runs} warm-up runs...") @@ -66,11 +85,13 @@ async def main(): await client.get_video_output() request_end_time = time.time() round_trip_times.append(request_end_time - request_start_time) - logger.debug(f"Request {i+1}/{args.num_requests} completed in {round_trip_times[-1]:.4f} seconds") + logger.debug( + f"Request {i + 1}/{args.num_requests} completed in {round_trip_times[-1]:.4f} seconds" + ) end_time = time.time() total_time = end_time - start_time - output_fps = args.num_requests / total_time if total_time > 0 else float('inf') + output_fps = args.num_requests / total_time if total_time > 0 else float("inf") # Calculate 
percentiles for sequential mode p50_rtt = np.percentile(round_trip_times, 50) @@ -79,17 +100,17 @@ async def main(): p95_rtt = np.percentile(round_trip_times, 95) p99_rtt = np.percentile(round_trip_times, 99) - print("\n" + "="*40) + print("\n" + "=" * 40) print("FPS Results:") - print("="*40) + print("=" * 40) print(f"Total requests: {args.num_requests}") print(f"Total time: {total_time:.4f} seconds") print(f"Actual Output FPS:{output_fps:.2f}") print(f"Total requests: {args.num_requests}") print(f"Total time: {total_time:.4f} seconds") - print("\n" + "="*40) + print("\n" + "=" * 40) print("Latency Results:") - print("="*40) + print("=" * 40) print(f"Average: {np.mean(round_trip_times):.4f}") print(f"Min: {np.min(round_trip_times):.4f}") print(f"Max: {np.max(round_trip_times):.4f}") @@ -99,7 +120,6 @@ async def main(): print(f"P95: {p95_rtt:.4f}") print(f"P99: {p99_rtt:.4f}") - else: # This is mainly used to stress test the ComfyUI client, gives us a good idea on how frame skipping etc is working on the client end. 
logger.info(f"Running FPS-based benchmark at {args.fps} FPS...") @@ -118,9 +138,11 @@ async def collect_outputs_task(): last_output_receive_time = time.time() received_frames_count += 1 - logger.debug(f"Received output frame {received_frames_count} at {last_output_receive_time - start_time:.4f} seconds") + logger.debug( + f"Received output frame {received_frames_count} at {last_output_receive_time - start_time:.4f} seconds" + ) except asyncio.TimeoutError: - logger.debug(f"Output collection task timed out after waiting for 5 seconds.") + logger.debug("Output collection task timed out after waiting for 5 seconds.") break except Exception as e: logger.debug(f"Output collection task finished due to exception: {e}") @@ -140,7 +162,9 @@ async def collect_outputs_task(): request_send_time = time.time() client.put_video_input(frame) - logger.debug(f"Sent request {i+1}/{args.num_requests} at {request_send_time - start_time:.4f} seconds") + logger.debug( + f"Sent request {i + 1}/{args.num_requests} at {request_send_time - start_time:.4f} seconds" + ) await output_collector_task @@ -150,11 +174,11 @@ async def collect_outputs_task(): elif received_frames_count == 0: output_fps = 0.0 else: - output_fps = float('inf') + output_fps = float("inf") - print("\n" + "="*40) + print("\n" + "=" * 40) print("FPS Results:") - print("="*40) + print("=" * 40) print(f"Target Input FPS: {args.fps:.2f}") print(f"Actual Output FPS:{output_fps:.2f} ({received_frames_count} frames received)") print(f"Total requests: {args.num_requests}") @@ -162,4 +186,4 @@ async def collect_outputs_task(): if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/configs/QUICK_REFERENCE.md b/configs/QUICK_REFERENCE.md new file mode 100644 index 000000000..a4a12481b --- /dev/null +++ b/configs/QUICK_REFERENCE.md @@ -0,0 +1,74 @@ +# Quick Reference: Model Configuration + +## Single File vs Directory Download + +### Single File (Default) +```yaml +my-model: + 
name: "My Model" + url: "https://huggingface.co/user/repo/resolve/main/file.safetensors" + path: "loras/model.safetensors" +``` + +### Directory (Add `is_directory: true`) +```yaml +my-directory: + name: "My Directory" + url: "https://huggingface.co/user/repo/tree/main/folder" + path: "models/folder" + is_directory: true # ← Add this! +``` + +## URL Patterns + +| Download Type | URL Pattern | Example | +|---------------|-------------|---------| +| **Single File** | `/resolve/` | `https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter_sd15.safetensors` | +| **Directory** | `/tree/` | `https://huggingface.co/h94/IP-Adapter/tree/main/models/image_encoder` | + +## Common Model Paths + +| Model Type | Path Pattern | +|------------|--------------| +| Checkpoints | `checkpoints/SD1.5/` | +| LoRAs | `loras/SD1.5/` | +| ControlNet | `controlnet/` | +| VAE | `vae/` or `vae_approx/` | +| IP-Adapter | `ipadapter/` | +| Text Encoders | `text_encoders/CLIPText/` | +| TensorRT/ONNX | `tensorrt/` | + +## IP-Adapter Example + +```yaml +models: + # Single file - IP-Adapter model + ip-adapter-sd15: + name: "IP Adapter SD15" + url: "https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter_sd15.safetensors" + path: "ipadapter/ip-adapter_sd15.safetensors" + + # Directory - CLIP image encoder + clip-image-encoder: + name: "CLIP Image Encoder" + url: "https://huggingface.co/h94/IP-Adapter/tree/main/models/image_encoder" + path: "ipadapter/models/image_encoder" + is_directory: true +``` + +## Usage + +```bash +# Use a config +python src/comfystream/scripts/setup_models.py --config my-config.yaml + +# Use default config (models.yaml) +python src/comfystream/scripts/setup_models.py +``` + +## See Also + +- [DIRECTORY_DOWNLOADS.md](../DIRECTORY_DOWNLOADS.md) - Detailed directory download guide +- [models-ipadapter-example.yaml](models-ipadapter-example.yaml) - Complete working example +- [README.md](README.md) - Full configuration reference + diff --git 
a/configs/models-ipadapter.yaml b/configs/models-ipadapter.yaml new file mode 100644 index 000000000..b6f225d2e --- /dev/null +++ b/configs/models-ipadapter.yaml @@ -0,0 +1,45 @@ +models: + # Example: IP-Adapter setup with directory download + + # Single file download (regular) + ip-adapter-plus-sd15: + name: "IP Adapter SD15" + url: "https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter-plus_sd15.safetensors" + path: "ipadapter/ip-adapter-plus_sd15.safetensors" + type: "ipadapter" + extra_files: + - url: "https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter-plus_sd15.bin" + path: "ipadapter/ip-adapter-plus_sd15.bin" + + clip-image-encoder: + name: "CLIP Image Encoder" + url: "https://huggingface.co/h94/IP-Adapter/resolve/main/models/image_encoder/model.safetensors" + path: "ipadapter/image_encoder/model.safetensors" + type: "image_encoder" + extra_files: + - url: "https://huggingface.co/h94/IP-Adapter/resolve/main/models/image_encoder/config.json" + path: "ipadapter/image_encoder/config.json" + + # Base model + sd-turbo: + name: "SD-Turbo" + url: "https://huggingface.co/stabilityai/sd-turbo/resolve/main/sd_turbo.safetensors" + path: "checkpoints/SD1.5/sd_turbo.safetensors" + type: "checkpoint" + + PixelArtRedmond15V-PixelArt-PIXARFK.safetensors: + name: "PixelArtRedmond15V-PixelArt-PIXARFK" + url: "https://huggingface.co/artificialguybr/pixelartredmond-1-5v-pixel-art-loras-for-sd-1-5/resolve/ab43d9e2cf8c9240189f01e9cdc4ca341362500c/PixelArtRedmond15V-PixelArt-PIXARFK.safetensors" + path: "loras/SD1.5/PixelArtRedmond15V-PixelArt-PIXARFK.safetensors" + type: "lora" + + # TAESD for fast VAE + taesd: + name: "TAESD" + url: "https://huggingface.co/madebyollin/taesd/resolve/main/taesd_decoder.safetensors" + path: "vae_approx/taesd_decoder.safetensors" + type: "vae_approx" + extra_files: + - url: "https://huggingface.co/madebyollin/taesd/resolve/main/taesd_encoder.safetensors" + path: "vae_approx/taesd_encoder.safetensors" + diff --git 
a/configs/models.yaml b/configs/models.yaml index 09149f954..bbcf0cdc7 100644 --- a/configs/models.yaml +++ b/configs/models.yaml @@ -29,19 +29,19 @@ models: # TAESD models taesd: name: "TAESD" - url: "https://raw.githubusercontent.com/madebyollin/taesd/main/taesd_decoder.pth" - path: "vae_approx/taesd_decoder.pth" + url: "https://huggingface.co/madebyollin/taesd/resolve/main/taesd_decoder.safetensors" + path: "vae_approx/taesd_decoder.safetensors" type: "vae_approx" extra_files: - - url: "https://raw.githubusercontent.com/madebyollin/taesd/main/taesd_encoder.pth" - path: "vae_approx/taesd_encoder.pth" + - url: "https://huggingface.co/madebyollin/taesd/resolve/main/taesd_encoder.safetensors" + path: "vae_approx/taesd_encoder.safetensors" # ControlNet models controlnet-depth: name: "ControlNet Depth" url: "https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11f1p_sd15_depth_fp16.safetensors" path: "controlnet/control_v11f1p_sd15_depth_fp16.safetensors" - type: "controlnet" + type: "controlnet" controlnet-mediapipe-face: name: "ControlNet MediaPipe Face" @@ -74,8 +74,82 @@ models: path: "text_encoders/CLIPText/model.fp16.safetensors" type: "text_encoder" + # JoyVASA models for ComfyUI-FasterLivePortrait + joyvasa_motion_generator: + name: "JoyVASA Motion Generator" + url: "https://huggingface.co/jdh-algo/JoyVASA/resolve/main/motion_generator/motion_generator_hubert_chinese.pt?download=true" + path: "liveportrait_onnx/joyvasa_models/motion_generator_hubert_chinese.pt" + type: "torch" + + joyvasa_audio_model: + name: "JoyVASA Hubert Chinese" + url: "https://huggingface.co/TencentGameMate/chinese-hubert-base/resolve/main/chinese-hubert-base-fairseq-ckpt.pt?download=true" + path: "liveportrait_onnx/joyvasa_models/chinese-hubert-base-fairseq-ckpt.pt" + type: "torch" + + joyvasa_motion_template: + name: "JoyVASA Motion Template" + url: 
"https://huggingface.co/jdh-algo/JoyVASA/resolve/main/motion_template/motion_template.pkl?download=true" + path: "liveportrait_onnx/joyvasa_models/motion_template.pkl" + type: "pickle" + + # LivePortrait ONNX models - only necessary to build TRT engines + warping_spade: + name: "WarpingSpadeModel" + url: "https://huggingface.co/warmshao/FasterLivePortrait/resolve/main/liveportrait_onnx/warping_spade-fix.onnx?download=true" + path: "liveportrait_onnx/warping_spade-fix.onnx" + type: "onnx" + + motion_extractor: + name: "MotionExtractorModel" + url: "https://huggingface.co/warmshao/FasterLivePortrait/resolve/main/liveportrait_onnx/motion_extractor.onnx?download=true" + path: "liveportrait_onnx/motion_extractor.onnx" + type: "onnx" + + landmark: + name: "LandmarkModel" + url: "https://huggingface.co/warmshao/FasterLivePortrait/resolve/main/liveportrait_onnx/landmark.onnx?download=true" + path: "liveportrait_onnx/landmark.onnx" + type: "onnx" + + face_analysis_retinaface: + name: "FaceAnalysisModel - RetinaFace" + url: "https://huggingface.co/warmshao/FasterLivePortrait/resolve/main/liveportrait_onnx/retinaface_det_static.onnx?download=true" + path: "liveportrait_onnx/retinaface_det_static.onnx" + type: "onnx" + + face_analysis_2dpose: + name: "FaceAnalysisModel - 2DPose" + url: "https://huggingface.co/warmshao/FasterLivePortrait/resolve/main/liveportrait_onnx/face_2dpose_106_static.onnx?download=true" + path: "liveportrait_onnx/face_2dpose_106_static.onnx" + type: "onnx" + + appearance_feature_extractor: + name: "AppearanceFeatureExtractorModel" + url: "https://huggingface.co/warmshao/FasterLivePortrait/resolve/main/liveportrait_onnx/appearance_feature_extractor.onnx?download=true" + path: "liveportrait_onnx/appearance_feature_extractor.onnx" + type: "onnx" + + stitching: + name: "StitchingModel" + url: "https://huggingface.co/warmshao/FasterLivePortrait/resolve/main/liveportrait_onnx/stitching.onnx?download=true" + path: "liveportrait_onnx/stitching.onnx" + type: 
"onnx" + + stitching_eye_retarget: + name: "StitchingModel (Eye Retargeting)" + url: "https://huggingface.co/warmshao/FasterLivePortrait/resolve/main/liveportrait_onnx/stitching_eye.onnx?download=true" + path: "liveportrait_onnx/stitching_eye.onnx" + type: "onnx" + + stitching_lip_retarget: + name: "StitchingModel (Lip Retargeting)" + url: "https://huggingface.co/warmshao/FasterLivePortrait/resolve/main/liveportrait_onnx/stitching_lip.onnx?download=true" + path: "liveportrait_onnx/stitching_lip.onnx" + type: "onnx" + sd-turbo: name: "SD-Turbo" url: "https://huggingface.co/stabilityai/sd-turbo/resolve/main/sd_turbo.safetensors" path: "checkpoints/SD1.5/sd_turbo.safetensors" - type: "checkpoint" \ No newline at end of file + type: "checkpoint" diff --git a/configs/nodes-streamdiffusion.yaml b/configs/nodes-streamdiffusion.yaml new file mode 100644 index 000000000..23cebb603 --- /dev/null +++ b/configs/nodes-streamdiffusion.yaml @@ -0,0 +1,37 @@ +nodes: + # Minimal node configuration for faster builds + comfyui-tensorrt: + name: "ComfyUI TensorRT" + url: "https://github.com/yondonfu/ComfyUI_TensorRT.git" + branch: "quantization_with_controlnet_fixes" + type: "tensorrt" + dependencies: + - "tensorrt==10.12.0.36" + + comfyui-streamdiffusion: + name: "ComfyUI StreamDiffusion" + url: "https://github.com/RUFFY-369/ComfyUI-StreamDiffusion" + branch: "main" + type: "tensorrt" + + comfyui-torch-compile: + name: "ComfyUI Torch Compile" + url: "https://github.com/yondonfu/ComfyUI-Torch-Compile" + type: "tensorrt" + + comfyui_controlnet_aux: + name: "ComfyUI ControlNet Auxiliary" + url: "https://github.com/Fannovel16/comfyui_controlnet_aux" + type: "controlnet" + + comfyui-stream-pack: + name: "ComfyUI Stream Pack" + url: "https://github.com/livepeer/ComfyUI-Stream-Pack" + branch: "main" + type: "utility" + + rgthree-comfy: + name: "rgthree Comfy" + url: "https://github.com/rgthree/rgthree-comfy.git" + type: "utility" + diff --git a/configs/nodes.yaml b/configs/nodes.yaml index 
49d422a57..fb7e77a85 100644 --- a/configs/nodes.yaml +++ b/configs/nodes.yaml @@ -5,8 +5,6 @@ nodes: url: "https://github.com/yondonfu/ComfyUI_TensorRT.git" branch: "quantization_with_controlnet_fixes" type: "tensorrt" - dependencies: - - "tensorrt==10.12.0.36" comfyui-depthanything-tensorrt: name: "ComfyUI DepthAnything TensorRT" @@ -19,6 +17,12 @@ nodes: branch: "main" type: "tensorrt" + comfyui-fasterliveportrait: + name: "ComfyUI FasterLivePortrait" + url: "https://github.com/pschroedl/ComfyUI-FasterLivePortrait.git" + branch: "main" + type: "tensorrt" + # Ryan's nodes comfyui-ryanontheinside: name: "ComfyUI RyanOnTheInside" diff --git a/docker/Dockerfile.base b/docker/Dockerfile.base index c8f6b7ff1..9bedd7185 100644 --- a/docker/Dockerfile.base +++ b/docker/Dockerfile.base @@ -1,32 +1,29 @@ ARG BASE_IMAGE=nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04 \ CONDA_VERSION=latest \ - PYTHON_VERSION=3.12 + PYTHON_VERSION=3.12 \ + NODES_CONFIG=nodes.yaml FROM "${BASE_IMAGE}" ARG CONDA_VERSION \ - PYTHON_VERSION + PYTHON_VERSION \ + NODES_CONFIG ENV DEBIAN_FRONTEND=noninteractive \ + TensorRT_ROOT=/opt/TensorRT-10.12.0.36 \ CONDA_VERSION="${CONDA_VERSION}" \ PATH="/workspace/miniconda3/bin:${PATH}" \ PYTHON_VERSION="${PYTHON_VERSION}" # System dependencies RUN apt update && apt install -yqq --no-install-recommends \ - git \ - wget \ - nano \ - socat \ - libsndfile1 \ - build-essential \ - llvm \ - tk-dev \ - libglvnd-dev \ - cmake \ - swig \ - libprotobuf-dev \ - protobuf-compiler \ + git wget nano socat \ + libsndfile1 build-essential llvm tk-dev \ + libglvnd-dev cmake swig libprotobuf-dev \ + protobuf-compiler libcairo2-dev libpango1.0-dev libgdk-pixbuf2.0-dev \ + libffi-dev libgirepository1.0-dev pkg-config libgflags-dev \ + libgoogle-glog-dev libjpeg-dev libavcodec-dev libavformat-dev \ + libavutil-dev libswscale-dev \ && rm -rf /var/lib/apt/lists/* #enable opengl support with nvidia gpu @@ -51,40 +48,68 @@ RUN mkdir -p /workspace/comfystream && \ RUN conda run -n 
comfystream --no-capture-output pip install --upgrade pip && \ conda run -n comfystream --no-capture-output pip install wheel +RUN apt-get remove --purge -y libcudnn9-cuda-12 libcudnn9-dev-cuda-12 || true && \ + apt-get autoremove -y && \ + rm -rf /var/lib/apt/lists/* + +# Install numpy<2.0.0 first +# to ensure numpy 2.0 is not installed automatically by another package +RUN conda run -n comfystream --no-capture-output pip install "numpy<2.0.0" + +# Install cuDNN 9.8 via conda to match base system version +# Caution: Mixed versions installed in environment (system/python) can cause CUDNN_STATUS_SUBLIBRARY_VERSION_MISMATCH errors +RUN conda install -n comfystream -y -c nvidia -c conda-forge cudnn=9.8 cuda-version=12.8 + # Copy only files needed for setup COPY ./src/comfystream/scripts /workspace/comfystream/src/comfystream/scripts COPY ./configs /workspace/comfystream/configs +# TensorRT SDK +WORKDIR /opt +RUN wget --progress=dot:giga \ + https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.12.0/tars/TensorRT-10.12.0.36.Linux.x86_64-gnu.cuda-12.9.tar.gz \ + && tar -xzf TensorRT-10.12.0.36.Linux.x86_64-gnu.cuda-12.9.tar.gz \ + && rm TensorRT-10.12.0.36.Linux.x86_64-gnu.cuda-12.9.tar.gz + +# Link libraries and update linker cache +RUN echo "${TensorRT_ROOT}/lib" > /etc/ld.so.conf.d/tensorrt.conf \ + && ldconfig + +# Install matching TensorRT Python bindings for CPython 3.12 +RUN conda run -n comfystream pip install --no-cache-dir \ + ${TensorRT_ROOT}/python/tensorrt-10.12.0.36-cp312-none-linux_x86_64.whl + # Clone ComfyUI -RUN git clone --branch v0.3.56 --depth 1 https://github.com/comfyanonymous/ComfyUI.git /workspace/ComfyUI +RUN git clone --branch v0.3.60 --depth 1 https://github.com/comfyanonymous/ComfyUI.git /workspace/ComfyUI +RUN git clone https://github.com/Comfy-Org/ComfyUI-Manager.git /workspace/ComfyUI/custom_nodes/ComfyUI-Manager # Copy ComfyStream files into ComfyUI COPY . 
/workspace/comfystream -RUN conda run -n comfystream --cwd /workspace/comfystream --no-capture-output pip install -r ./src/comfystream/scripts/constraints.txt +RUN conda run -n comfystream --cwd /workspace/comfystream --no-capture-output pip install -r src/comfystream/scripts/constraints.txt # Copy comfystream and example workflows to ComfyUI COPY ./workflows/comfyui/* /workspace/ComfyUI/user/default/workflows/ COPY ./test/example-512x512.png /workspace/ComfyUI/input +COPY ./docker/entrypoint.sh /workspace/comfystream/docker/entrypoint.sh # Install ComfyUI requirements -RUN conda run -n comfystream --no-capture-output --cwd /workspace/ComfyUI pip install -r requirements.txt --root-user-action=ignore +RUN conda run -n comfystream --no-capture-output --cwd /workspace/ComfyUI pip install -r requirements.txt --constraint /workspace/comfystream/src/comfystream/scripts/constraints.txt --root-user-action=ignore # Install ComfyStream requirements RUN ln -s /workspace/comfystream /workspace/ComfyUI/custom_nodes/comfystream -RUN conda run -n comfystream --no-capture-output --cwd /workspace/comfystream pip install -e . --root-user-action=ignore +RUN conda run -n comfystream --no-capture-output --cwd /workspace/comfystream pip install -e . 
--constraint src/comfystream/scripts/constraints.txt --root-user-action=ignore RUN conda run -n comfystream --no-capture-output --cwd /workspace/comfystream python install.py --workspace /workspace/ComfyUI # Accept a build-arg that lets CI force-invalidate setup_nodes.py ARG CACHEBUST=static ENV CACHEBUST=${CACHEBUST} -# Run setup_nodes -RUN conda run -n comfystream --no-capture-output --cwd /workspace/comfystream python src/comfystream/scripts/setup_nodes.py --workspace /workspace/ComfyUI +# Run setup_nodes with custom config if specified +RUN conda run -n comfystream --no-capture-output --cwd /workspace/comfystream python src/comfystream/scripts/setup_nodes.py --workspace /workspace/ComfyUI --config ${NODES_CONFIG} -RUN conda run -n comfystream --no-capture-output pip install "numpy<2.0.0" - -RUN conda run -n comfystream --no-capture-output pip install --no-cache-dir xformers==0.0.30 --no-deps +# Setup opencv with CUDA support +RUN conda run -n comfystream --no-capture-output --cwd /workspace/comfystream --no-capture-output docker/entrypoint.sh --opencv-cuda # Configure no environment activation by default RUN conda config --set auto_activate_base false && \ @@ -94,3 +119,8 @@ RUN conda config --set auto_activate_base false && \ RUN echo "conda activate comfystream" >> ~/.bashrc WORKDIR /workspace/comfystream + +# Run ComfyStream BYOC server from /workspace/ComfyUI within the comfystream conda env +ENTRYPOINT ["conda", "run", "--no-capture-output", "-n", "comfystream", "--cwd", "/workspace/ComfyUI", "python", "/workspace/comfystream/server/byoc.py"] +# Default args; can be overridden/appended at runtime +CMD ["--workspace", "/workspace/ComfyUI", "--host", "0.0.0.0", "--port", "8000"] diff --git a/docker/Dockerfile.opencv b/docker/Dockerfile.opencv new file mode 100644 index 000000000..848db1fe8 --- /dev/null +++ b/docker/Dockerfile.opencv @@ -0,0 +1,124 @@ +ARG BASE_IMAGE=nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04 \ + CONDA_VERSION=latest \ + PYTHON_VERSION=3.12 
\ + CUDA_VERSION=12.8 + +FROM "${BASE_IMAGE}" + +ARG CONDA_VERSION \ + PYTHON_VERSION \ + CUDA_VERSION + +ENV DEBIAN_FRONTEND=noninteractive \ + CONDA_VERSION="${CONDA_VERSION}" \ + PATH="/workspace/miniconda3/bin:${PATH}" \ + PYTHON_VERSION="${PYTHON_VERSION}" \ + CUDA_VERSION="${CUDA_VERSION}" + +# System dependencies +RUN apt update && apt install -yqq \ + git \ + wget \ + nano \ + socat \ + libsndfile1 \ + build-essential \ + llvm \ + tk-dev \ + cmake \ + libgflags-dev \ + libgoogle-glog-dev \ + libjpeg-dev \ + libavcodec-dev \ + libavformat-dev \ + libavutil-dev \ + libswscale-dev && \ + rm -rf /var/lib/apt/lists/* + +RUN mkdir -p /workspace/comfystream && \ + wget "https://repo.anaconda.com/miniconda/Miniconda3-${CONDA_VERSION}-Linux-x86_64.sh" -O /tmp/miniconda.sh && \ + bash /tmp/miniconda.sh -b -p /workspace/miniconda3 && \ + eval "$(/workspace/miniconda3/bin/conda shell.bash hook)" && \ + conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main && \ + conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r && \ + conda create -n comfystream python="${PYTHON_VERSION}" -c conda-forge -y && \ + rm /tmp/miniconda.sh && \ + conda run -n comfystream --no-capture-output pip install numpy==1.26.4 aiortc aiohttp requests tqdm pyyaml --root-user-action=ignore + +# Clone ComfyUI +ADD --link https://github.com/comfyanonymous/ComfyUI.git /workspace/ComfyUI + +# OpenCV with CUDA support +WORKDIR /workspace + +# Clone OpenCV repositories +RUN git clone --depth 1 --branch 4.11.0 https://github.com/opencv/opencv.git && \ + git clone --depth 1 --branch 4.11.0 https://github.com/opencv/opencv_contrib.git + +# Create build directory +RUN mkdir -p /workspace/opencv/build + +# Create a toolchain file with absolute path +RUN echo '# Custom toolchain file to exclude Conda paths\n\ +\n\ +# Set system compilers\n\ +set(CMAKE_C_COMPILER "/usr/bin/gcc")\n\ +set(CMAKE_CXX_COMPILER "/usr/bin/g++")\n\ +\n\ +# Set system root 
directories\n\ +set(CMAKE_FIND_ROOT_PATH "/usr")\n\ +set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER)\n\ +set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY)\n\ +set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY)\n\ +set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY)\n\ +\n\ +# Explicitly exclude Conda paths\n\ +list(APPEND CMAKE_IGNORE_PATH \n\ + "/workspace/miniconda3"\n\ + "/workspace/miniconda3/envs"\n\ + "/workspace/miniconda3/envs/comfystream"\n\ + "/workspace/miniconda3/envs/comfystream/lib"\n\ +)\n\ +\n\ +# Set RPATH settings\n\ +set(CMAKE_SKIP_BUILD_RPATH FALSE)\n\ +set(CMAKE_BUILD_WITH_INSTALL_RPATH FALSE)\n\ +set(CMAKE_INSTALL_RPATH "/usr/local/lib:/usr/lib/x86_64-linux-gnu")\n\ +set(PYTHON_LIBRARY "/workspace/miniconda3/envs/comfystream/lib/")\n\ +set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE)' > /workspace/custom_toolchain.cmake + +# Set environment variables for OpenCV +RUN echo 'export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH' >> /root/.bashrc + +# Build and install OpenCV with CUDA support +RUN cd /workspace/opencv/build && \ + # Build OpenCV + cmake \ + -D CMAKE_TOOLCHAIN_FILE=/workspace/custom_toolchain.cmake \ + -D CMAKE_BUILD_TYPE=RELEASE \ + -D CMAKE_INSTALL_PREFIX=/usr/local \ + -D WITH_CUDA=ON \ + -D WITH_CUDNN=ON \ + -D WITH_CUBLAS=ON \ + -D WITH_TBB=ON \ + -D CUDA_ARCH_LIST="8.0+PTX" \ + -D OPENCV_DNN_CUDA=ON \ + -D OPENCV_ENABLE_NONFREE=ON \ + -D CUDA_TOOLKIT_ROOT_DIR=/usr/local/cuda \ + -D OPENCV_EXTRA_MODULES_PATH=/workspace/opencv_contrib/modules \ + -D PYTHON3_EXECUTABLE=/workspace/miniconda3/envs/comfystream/bin/python3.12 \ + -D PYTHON_INCLUDE_DIR=/workspace/miniconda3/envs/comfystream/include/python3.12 \ + -D PYTHON_LIBRARY=/workspace/miniconda3/envs/comfystream/lib/libpython3.12.so \ + -D HAVE_opencv_python3=ON \ + -D WITH_NVCUVID=OFF \ + -D WITH_NVCUVENC=OFF \ + .. 
&& \ + make -j$(nproc) && \ + make install && \ + ldconfig + +# Configure no environment activation by default +RUN conda config --set auto_activate_base false && \ + conda init bash + +WORKDIR /workspace/comfystream diff --git a/docker/README.md b/docker/README.md index ad691aceb..cabb2fd48 100644 --- a/docker/README.md +++ b/docker/README.md @@ -1,4 +1,4 @@ -# ComfyStream Docker +# ComfyStream Docker Build Configuration This folder contains the Docker files that can be used to run ComfyStream in a containerized fashion or to work on the codebase within a dev container. This README contains the general usage instructions while the [Devcontainer Readme](../.devcontainer/README.md) contains instructions on how to use Comfystream inside a dev container and get quickly started with your development journey. @@ -7,21 +7,48 @@ This folder contains the Docker files that can be used to run ComfyStream in a c - [Dockerfile](Dockerfile) - The main Dockerfile that can be used to run ComfyStream in a containerized fashion. - [Dockerfile.base](Dockerfile.base) - The base Dockerfile that can be used to build the base image for ComfyStream. -## Pre-requisites +## Building with Custom Nodes Configuration -- [Docker](https://docs.docker.com/get-docker/) -- [Nvidia Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) +The base Docker image supports specifying a custom nodes configuration file during build time using the `NODES_CONFIG` build argument. -## Usage +### Usage -### Build the Base Image +#### Default build (uses `nodes.yaml`) +```bash +docker build -t livepeer/comfyui-base -f docker/Dockerfile . +``` -To build the base image, run the following command: +#### Build with custom config from configs directory +```bash +docker build -f docker/Dockerfile.base \ + --build-arg NODES_CONFIG=nodes-streamdiffusion.yaml \ + -t comfyui-base:streamdiffusion . 
+``` +#### Build with config from absolute path ```bash -docker build -t livepeer/comfyui-base -f docker/Dockerfile.base . +docker build -f docker/Dockerfile.base \ + --build-arg NODES_CONFIG=/path/to/custom-nodes.yaml \ + -t comfyui-base:custom . ``` +### Available Build Arguments + +| Argument | Default | Description | +|----------|---------|-------------| +| `BASE_IMAGE` | `nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04` | Base CUDA image | +| `CONDA_VERSION` | `latest` | Miniconda version | +| `PYTHON_VERSION` | `3.12` | Python version | +| `NODES_CONFIG` | `nodes.yaml` | Nodes configuration file (filename or path) | +| `CACHEBUST` | `static` | Cache invalidation for node setup | + +### Configuration Files in configs/ + +- **`nodes.yaml`** - Full node configuration (default) +- **`nodes-streamdiffusion.yaml`** - Minimal set of nodes for faster builds + +### Examples + ### Build the Main Image To build the main image, run the following command: diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh index ef55c77e2..e6d44463a 100755 --- a/docker/entrypoint.sh +++ b/docker/entrypoint.sh @@ -130,6 +130,14 @@ if [ "$1" = "--build-engines" ]; then echo "Engine for DepthAnything2 (large) already exists at ${DEPTH_ANYTHING_DIR}/${DEPTH_ANYTHING_ENGINE_LARGE}, skipping..." fi + # Build Engines for FasterLivePortrait + if [ ! -f "$FASTERLIVEPORTRAIT_DIR/warping_spade-fix.trt" ]; then + cd "$FASTERLIVEPORTRAIT_DIR" + bash /workspace/ComfyUI/custom_nodes/ComfyUI-FasterLivePortrait/scripts/build_fasterliveportrait_trt.sh "${FASTERLIVEPORTRAIT_DIR}" "${FASTERLIVEPORTRAIT_DIR}" "${FASTERLIVEPORTRAIT_DIR}" + else + echo "Engines for FasterLivePortrait already exists, skipping..." + fi + # Build Engine for StreamDiffusion if [ ! 
-f "$TENSORRT_DIR/StreamDiffusion-engines/stabilityai/sd-turbo--lcm_lora-True--tiny_vae-True--max_batch-3--min_batch-3--mode-img2img/unet.engine.opt.onnx" ]; then cd /workspace/ComfyUI/custom_nodes/ComfyUI-StreamDiffusion @@ -158,7 +166,7 @@ if [ "$1" = "--opencv-cuda" ]; then if [ ! -f "/workspace/comfystream/opencv-cuda-release.tar.gz" ]; then # Download and extract OpenCV CUDA build DOWNLOAD_NAME="opencv-cuda-release.tar.gz" - wget -q -O "$DOWNLOAD_NAME" https://github.com/JJassonn69/ComfyUI-Stream-Pack/releases/download/v2/opencv-cuda-release.tar.gz + wget -q -O "$DOWNLOAD_NAME" https://github.com/JJassonn69/ComfyUI-Stream-Pack/releases/download/v2.1/opencv-cuda-release.tar.gz tar -xzf "$DOWNLOAD_NAME" -C /workspace/comfystream/ rm "$DOWNLOAD_NAME" else @@ -166,15 +174,6 @@ if [ "$1" = "--opencv-cuda" ]; then fi # Install required libraries - apt-get update && apt-get install -y \ - libgflags-dev \ - libgoogle-glog-dev \ - libjpeg-dev \ - libavcodec-dev \ - libavformat-dev \ - libavutil-dev \ - libswscale-dev - # Remove existing cv2 package SITE_PACKAGES_DIR="/workspace/miniconda3/envs/comfystream/lib/python3.12/site-packages" rm -rf "${SITE_PACKAGES_DIR}/cv2"* diff --git a/example.py b/example.py index 6d68aac13..9f7567306 100644 --- a/example.py +++ b/example.py @@ -1,7 +1,8 @@ -import torch import asyncio import json +import torch + from comfystream.client import ComfyStreamClient diff --git a/install.py b/install.py index 385d160fd..7552ea8a9 100644 --- a/install.py +++ b/install.py @@ -1,19 +1,20 @@ -import os -import subprocess import argparse import logging +import os import pathlib +import subprocess import sys -import tarfile import tempfile import urllib.request -import toml import zipfile + +import toml from comfy_compatibility.workspace import auto_patch_workspace_and_restart -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - 
%(message)s") logger = logging.getLogger(__name__) + def get_project_version(workspace: str) -> str: """Read project version from pyproject.toml""" pyproject_path = os.path.join(workspace, "pyproject.toml") @@ -25,19 +26,25 @@ def get_project_version(workspace: str) -> str: logger.error(f"Failed to read version from pyproject.toml: {e}") return "unknown" + def download_and_extract_ui_files(version: str): """Download and extract UI files to the workspace""" output_dir = os.path.join(os.getcwd(), "nodes", "web", "static") pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True) - base_url = urllib.parse.urljoin("https://github.com/livepeer/comfystream/releases/download/", f"v{version}/comfystream-uikit.zip") - fallback_url = "https://github.com/livepeer/comfystream/releases/latest/download/comfystream-uikit.zip" - + base_url = urllib.parse.urljoin( + "https://github.com/livepeer/comfystream/releases/download/", + f"v{version}/comfystream-uikit.zip", + ) + fallback_url = ( + "https://github.com/livepeer/comfystream/releases/latest/download/comfystream-uikit.zip" + ) + # Create a temporary directory instead of a temporary file with tempfile.TemporaryDirectory() as temp_dir: # Define the path for the downloaded file download_path = os.path.join(temp_dir, "comfystream-uikit.zip") - + # Download zip file logger.info(f"Downloading {base_url}") try: @@ -53,23 +60,27 @@ def download_and_extract_ui_files(version: str): else: logger.error(f"Error downloading package: {e}") raise - + # Extract contents try: logger.info(f"Extracting files to {output_dir}") - with zipfile.ZipFile(download_path, 'r') as zip_ref: + with zipfile.ZipFile(download_path, "r") as zip_ref: zip_ref.extractall(path=output_dir) except Exception as e: logger.error(f"Error extracting files: {e}") raise + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Install custom node requirements") parser.add_argument( - "--workspace", default=os.environ.get('COMFY_UI_WORKSPACE', None), 
required=False, help="Set Comfy workspace" + "--workspace", + default=os.environ.get("COMFY_UI_WORKSPACE", None), + required=False, + help="Set Comfy workspace", ) args = parser.parse_args() - + workspace = args.workspace if workspace is None: # Look up to 3 directories up for ComfyUI @@ -93,12 +104,14 @@ def download_and_extract_ui_files(version: str): subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "."]) if workspace is None: - logger.warning("No ComfyUI workspace found. Please specify a valid workspace path to fully install") - + logger.warning( + "No ComfyUI workspace found. Please specify a valid workspace path to fully install" + ) + if workspace is not None: logger.info("Patching ComfyUI workspace...") auto_patch_workspace_and_restart(workspace) - + logger.info("Downloading and extracting UI files...") version = get_project_version(os.getcwd()) download_and_extract_ui_files(version) diff --git a/nodes/__init__.py b/nodes/__init__.py index e3ca8f2ae..a590e9ccc 100644 --- a/nodes/__init__.py +++ b/nodes/__init__.py @@ -1,5 +1,7 @@ """ComfyStream nodes package""" -from comfy_compatibility.imports import ImportContext, SITE_PACKAGES, MAIN_PY + +from comfy_compatibility.imports import MAIN_PY, SITE_PACKAGES, ImportContext + with ImportContext("comfy", "comfy_extras", order=[SITE_PACKAGES, MAIN_PY]): from .audio_utils import * from .tensor_utils import * @@ -13,15 +15,16 @@ # Import and update mappings from submodules for module in [audio_utils, tensor_utils, video_stream_utils, api, web]: - if hasattr(module, 'NODE_CLASS_MAPPINGS'): + if hasattr(module, "NODE_CLASS_MAPPINGS"): NODE_CLASS_MAPPINGS.update(module.NODE_CLASS_MAPPINGS) - if hasattr(module, 'NODE_DISPLAY_NAME_MAPPINGS'): + if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS"): NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS) # Web directory for UI components import os + WEB_DIRECTORY = os.path.join(os.path.dirname(os.path.realpath(__file__)), "web") 
NODE_DISPLAY_NAME_MAPPINGS["ComfyStreamLauncher"] = "Launch ComfyStream 🚀" -__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] +__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] diff --git a/nodes/api/__init__.py b/nodes/api/__init__.py index ebc480859..eb6dbdf81 100644 --- a/nodes/api/__init__.py +++ b/nodes/api/__init__.py @@ -1,17 +1,22 @@ """ComfyStream API implementation""" + +import logging import os -import webbrowser -from server import PromptServer -from aiohttp import web import pathlib -import logging +import webbrowser + import aiohttp -from ..server_manager import LocalComfyStreamServer +from aiohttp import web + +from server import PromptServer + from .. import settings_storage +from ..server_manager import LocalComfyStreamServer routes = None server_manager = None + # Middleware to add Cache-Control: no-cache for index.html @web.middleware async def cache_control_middleware(request, handler): @@ -20,19 +25,22 @@ async def cache_control_middleware(request, handler): target_path = f"{STATIC_ROUTE}/index.html" # Log comparison details if request.path == target_path: - response.headers['Cache-Control'] = 'no-cache' - logging.debug(f"[CacheMiddleware] Added Cache-Control: no-cache for {request.path}") # Kept debug log + response.headers["Cache-Control"] = "no-cache" + logging.debug( + f"[CacheMiddleware] Added Cache-Control: no-cache for {request.path}" + ) # Kept debug log return response + # Only set up routes if we're in the main ComfyUI instance -if hasattr(PromptServer.instance, 'routes') and hasattr(PromptServer.instance.routes, 'static'): +if hasattr(PromptServer.instance, "routes") and hasattr(PromptServer.instance.routes, "static"): routes = PromptServer.instance.routes - + # Get the path to the static build directory - STATIC_DIR = pathlib.Path(__file__).parent.parent.parent / "nodes" / "web" / "static" - + STATIC_DIR = pathlib.Path(__file__).parent.parent.parent / "nodes" / "web" / "static" + # Dynamically determine 
the extension name from the directory structure - extension_name = "comfystream" # Define a local default for the try/except block + extension_name = "comfystream" # Define a local default for the try/except block try: # Get the parent directory of the current file # Then navigate up to get the extension root directory @@ -41,7 +49,9 @@ async def cache_control_middleware(request, handler): extension_name = EXTENSION_ROOT.name logging.info(f"Detected extension name: {extension_name}") except Exception as e: - logging.warning(f"Failed to get extension name dynamically: {e}, using fallback '{extension_name}'") + logging.warning( + f"Failed to get extension name dynamically: {e}, using fallback '{extension_name}'" + ) # Fallback name is already set by the initial local default # Define module-level constants AFTER determination @@ -53,30 +63,34 @@ async def cache_control_middleware(request, handler): routes.static(STATIC_ROUTE, str(STATIC_DIR), append_version=False, follow_symlinks=True) # Add the cache control middleware to the app - if hasattr(PromptServer.instance, 'app'): + if hasattr(PromptServer.instance, "app"): PromptServer.instance.app.middlewares.append(cache_control_middleware) logging.info(f"Added ComfyStream cache control middleware for {STATIC_ROUTE}/index.html") else: - logging.warning("Could not add ComfyStream cache control middleware: PromptServer.instance.app not found.") + logging.warning( + "Could not add ComfyStream cache control middleware: PromptServer.instance.app not found." 
+ ) # Create server manager instance server_manager = LocalComfyStreamServer() - - @routes.get('/comfystream/extension_info') + + @routes.get("/comfystream/extension_info") async def get_extension_info(request): """Return extension information including name and paths""" try: - return web.json_response({ - "success": True, - "extension_name": EXTENSION_NAME, - "static_route": STATIC_ROUTE, - "ui_url": f"{STATIC_ROUTE}/index.html" - }) + return web.json_response( + { + "success": True, + "extension_name": EXTENSION_NAME, + "static_route": STATIC_ROUTE, + "ui_url": f"{STATIC_ROUTE}/index.html", + } + ) except Exception as e: logging.error(f"Error getting extension info: {str(e)}") return web.json_response({"success": False, "error": str(e)}, status=500) - - @routes.post('/api/offer') + + @routes.post("/api/offer") async def proxy_offer(request): """Proxy offer requests to the ComfyStream server""" try: @@ -89,64 +103,65 @@ async def proxy_offer(request): async with session.post( f"{target_url}/offer", json={"prompts": data.get("prompts"), "offer": data.get("offer")}, - headers={"Content-Type": "application/json"} + headers={"Content-Type": "application/json"}, ) as response: if not response.ok: return web.json_response( - {"error": f"Server error: {response.status}"}, - status=response.status + {"error": f"Server error: {response.status}"}, status=response.status ) return web.json_response(await response.json()) except Exception as e: logging.error(f"Error proxying offer: {str(e)}") return web.json_response({"error": str(e)}, status=500) - @routes.post('/comfystream/control') + @routes.post("/comfystream/control") async def control_server(request): """Handle server control requests""" try: data = await request.json() action = data.get("action") settings = data.get("settings", {}) - + # Extract host and port from settings if provided host = settings.get("host") if settings else None port = settings.get("port") if settings else None - + if action == "status": # Simply 
return the current server status - return web.json_response({ - "success": True, - "status": server_manager.get_status() - }) + return web.json_response({"success": True, "status": server_manager.get_status()}) elif action == "start": success = await server_manager.start(port=port, host=host) - return web.json_response({ - "success": success, - "status": server_manager.get_status() - }) + return web.json_response( + {"success": success, "status": server_manager.get_status()} + ) elif action == "stop": try: success = await server_manager.stop() - return web.json_response({ - "success": success, - "status": server_manager.get_status() - }) + return web.json_response( + {"success": success, "status": server_manager.get_status()} + ) except Exception as e: logging.error(f"Error stopping server: {str(e)}") # Force cleanup even if the normal stop fails server_manager.cleanup() - return web.json_response({ - "success": True, - "status": {"running": False, "port": None, "host": None, "pid": None, "type": "local"}, - "message": "Forced server shutdown due to error" - }) + return web.json_response( + { + "success": True, + "status": { + "running": False, + "port": None, + "host": None, + "pid": None, + "type": "local", + }, + "message": "Forced server shutdown due to error", + } + ) elif action == "restart": success = await server_manager.restart(port=port, host=host) - return web.json_response({ - "success": success, - "status": server_manager.get_status() - }) + return web.json_response( + {"success": success, "status": server_manager.get_status()} + ) else: return web.json_response({"error": "Invalid action"}, status=400) except Exception as e: @@ -155,17 +170,25 @@ async def control_server(request): if data and data.get("action") == "stop": try: server_manager.cleanup() - return web.json_response({ - "success": True, - "status": {"running": False, "port": None, "host": None, "pid": None, "type": "local"}, - "message": "Forced server shutdown due to error" - }) + return 
web.json_response( + { + "success": True, + "status": { + "running": False, + "port": None, + "host": None, + "pid": None, + "type": "local", + }, + "message": "Forced server shutdown due to error", + } + ) except Exception as cleanup_error: logging.error(f"Error during forced cleanup: {str(cleanup_error)}") - + return web.json_response({"error": str(e)}, status=500) - @routes.get('/comfystream/settings') + @routes.get("/comfystream/settings") async def get_settings(request): """Get ComfyStream settings""" try: @@ -174,63 +197,59 @@ async def get_settings(request): except Exception as e: logging.error(f"Error getting settings: {str(e)}") return web.json_response({"error": str(e)}, status=500) - - @routes.post('/comfystream/settings') + + @routes.post("/comfystream/settings") async def update_settings(request): """Update ComfyStream settings""" try: data = await request.json() success = settings_storage.update_settings(data) - return web.json_response({ - "success": success, - "settings": settings_storage.load_settings() - }) + return web.json_response( + {"success": success, "settings": settings_storage.load_settings()} + ) except Exception as e: logging.error(f"Error updating settings: {str(e)}") return web.json_response({"error": str(e)}, status=500) - - @routes.post('/comfystream/settings/configuration') + + @routes.post("/comfystream/settings/configuration") async def manage_configuration(request): """Add, remove, or select a configuration""" try: data = await request.json() action = data.get("action") - + if action == "add": name = data.get("name") host = data.get("host") port = data.get("port") if not name or not host or not port: return web.json_response({"error": "Missing required parameters"}, status=400) - + success = settings_storage.add_configuration(name, host, port) - return web.json_response({ - "success": success, - "settings": settings_storage.load_settings() - }) - + return web.json_response( + {"success": success, "settings": 
settings_storage.load_settings()} + ) + elif action == "remove": index = data.get("index") if index is None: return web.json_response({"error": "Missing index parameter"}, status=400) - + success = settings_storage.remove_configuration(index) - return web.json_response({ - "success": success, - "settings": settings_storage.load_settings() - }) - + return web.json_response( + {"success": success, "settings": settings_storage.load_settings()} + ) + elif action == "select": index = data.get("index") if index is None: return web.json_response({"error": "Missing index parameter"}, status=400) - + success = settings_storage.select_configuration(index) - return web.json_response({ - "success": success, - "settings": settings_storage.load_settings() - }) - + return web.json_response( + {"success": success, "settings": settings_storage.load_settings()} + ) + else: return web.json_response({"error": "Invalid action"}, status=400) except Exception as e: diff --git a/nodes/audio_utils/__init__.py b/nodes/audio_utils/__init__.py index bc090689a..f9c84d1f8 100644 --- a/nodes/audio_utils/__init__.py +++ b/nodes/audio_utils/__init__.py @@ -1,6 +1,6 @@ from .load_audio_tensor import LoadAudioTensor -from .save_audio_tensor import SaveAudioTensor from .pitch_shift import PitchShifter +from .save_audio_tensor import SaveAudioTensor NODE_CLASS_MAPPINGS = { "LoadAudioTensor": LoadAudioTensor, diff --git a/nodes/audio_utils/load_audio_tensor.py b/nodes/audio_utils/load_audio_tensor.py index ece7fca31..919a75991 100644 --- a/nodes/audio_utils/load_audio_tensor.py +++ b/nodes/audio_utils/load_audio_tensor.py @@ -1,71 +1,101 @@ +import queue + import numpy as np import torch from comfystream import tensor_cache +from comfystream.exceptions import ComfyStreamAudioBufferError, ComfyStreamInputTimeoutError + class LoadAudioTensor: - CATEGORY = "audio_utils" + CATEGORY = "ComfyStream/Loaders" RETURN_TYPES = ("AUDIO",) RETURN_NAMES = ("audio",) FUNCTION = "execute" - + DESCRIPTION = "Load audio 
tensor from ComfyStream input with timeout." + def __init__(self): self.audio_buffer = np.empty(0, dtype=np.int16) self.buffer_samples = None self.sample_rate = None - + self.leftover = np.empty(0, dtype=np.int16) + @classmethod - def INPUT_TYPES(s): + def INPUT_TYPES(cls): return { "required": { - "buffer_size": ("FLOAT", {"default": 500.0}), - } + "buffer_size": ( + "FLOAT", + {"default": 500.0, "tooltip": "Audio buffer size in milliseconds"}, + ), + }, + "optional": { + "timeout_seconds": ( + "FLOAT", + { + "default": 1.0, + "min": 0.1, + "max": 30.0, + "step": 0.1, + "tooltip": "Timeout in seconds", + }, + ), + }, } - + @classmethod - def IS_CHANGED(**kwargs): + def IS_CHANGED(cls, **kwargs): return float("nan") - - def execute(self, buffer_size): + + def execute(self, buffer_size: float, timeout_seconds: float = 1.0): + # Initialize if needed if self.sample_rate is None or self.buffer_samples is None: - frame = tensor_cache.audio_inputs.get(block=True) - self.sample_rate = frame.sample_rate - self.buffer_samples = int(self.sample_rate * buffer_size / 1000) - self.leftover = frame.side_data.input - - if self.leftover.shape[0] < self.buffer_samples: + try: + frame = tensor_cache.audio_inputs.get(block=True, timeout=timeout_seconds) + self.sample_rate = frame.sample_rate + self.buffer_samples = int(self.sample_rate * buffer_size / 1000) + self.leftover = frame.side_data.input + except queue.Empty: + raise ComfyStreamInputTimeoutError("audio", timeout_seconds) + + # Use leftover data if available + if self.leftover.shape[0] >= self.buffer_samples: + buffered_audio = self.leftover[: self.buffer_samples] + self.leftover = self.leftover[self.buffer_samples :] + else: + # Collect more audio chunks chunks = [self.leftover] if self.leftover.size > 0 else [] total_samples = self.leftover.shape[0] - + while total_samples < self.buffer_samples: - frame = tensor_cache.audio_inputs.get(block=True) - if frame.sample_rate != self.sample_rate: - raise ValueError("Sample rate 
mismatch") - chunks.append(frame.side_data.input) - total_samples += frame.side_data.input.shape[0] - + try: + frame = tensor_cache.audio_inputs.get(block=True, timeout=timeout_seconds) + if frame.sample_rate != self.sample_rate: + raise ValueError( + f"Sample rate mismatch: expected {self.sample_rate}Hz, got {frame.sample_rate}Hz" + ) + chunks.append(frame.side_data.input) + total_samples += frame.side_data.input.shape[0] + except queue.Empty: + raise ComfyStreamAudioBufferError( + timeout_seconds, self.buffer_samples, total_samples + ) + merged_audio = np.concatenate(chunks, dtype=np.int16) - buffered_audio = merged_audio[:self.buffer_samples] - self.leftover = merged_audio[self.buffer_samples:] - else: - buffered_audio = self.leftover[:self.buffer_samples] - self.leftover = self.leftover[self.buffer_samples:] - - # Convert numpy array to torch tensor and normalize int16 to float32 + buffered_audio = merged_audio[: self.buffer_samples] + self.leftover = ( + merged_audio[self.buffer_samples :] + if merged_audio.shape[0] > self.buffer_samples + else np.empty(0, dtype=np.int16) + ) + + # Convert to ComfyUI AUDIO format waveform_tensor = torch.from_numpy(buffered_audio.astype(np.float32) / 32768.0) - + # Ensure proper tensor shape: (batch, channels, samples) if waveform_tensor.dim() == 1: - # Mono: (samples,) -> (1, 1, samples) waveform_tensor = waveform_tensor.unsqueeze(0).unsqueeze(0) elif waveform_tensor.dim() == 2: - # Assume (channels, samples) and add batch dimension waveform_tensor = waveform_tensor.unsqueeze(0) - - # Return AUDIO dictionary format - audio_dict = { - "waveform": waveform_tensor, - "sample_rate": self.sample_rate - } - - return (audio_dict,) + + return ({"waveform": waveform_tensor, "sample_rate": self.sample_rate},) diff --git a/nodes/audio_utils/pitch_shift.py b/nodes/audio_utils/pitch_shift.py index 2fba9ee59..6e2456a2a 100644 --- a/nodes/audio_utils/pitch_shift.py +++ b/nodes/audio_utils/pitch_shift.py @@ -1,23 +1,19 @@ -import numpy as np 
import librosa +import numpy as np import torch + class PitchShifter: CATEGORY = "audio_utils" RETURN_TYPES = ("AUDIO",) FUNCTION = "execute" - + @classmethod def INPUT_TYPES(cls): return { "required": { "audio": ("AUDIO",), - "pitch_shift": ("FLOAT", { - "default": 4.0, - "min": 0.0, - "max": 12.0, - "step": 0.5 - }), + "pitch_shift": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 12.0, "step": 0.5}), } } @@ -29,37 +25,36 @@ def execute(self, audio, pitch_shift): # Extract waveform and sample rate from AUDIO format waveform = audio["waveform"] sample_rate = audio["sample_rate"] - + # Convert tensor to numpy and ensure proper format for librosa if isinstance(waveform, torch.Tensor): audio_numpy = waveform.squeeze().cpu().numpy() else: audio_numpy = waveform.squeeze() - + # Ensure float32 format and proper normalization for librosa processing if audio_numpy.dtype != np.float32: audio_numpy = audio_numpy.astype(np.float32) - + # Check if data needs normalization (librosa expects [-1, 1] range) max_abs_val = np.abs(audio_numpy).max() if max_abs_val > 1.0: # Data appears to be in int16 range, normalize it audio_numpy = audio_numpy / 32768.0 - + # Apply pitch shift - shifted_audio = librosa.effects.pitch_shift(y=audio_numpy, sr=sample_rate, n_steps=pitch_shift) - + shifted_audio = librosa.effects.pitch_shift( + y=audio_numpy, sr=sample_rate, n_steps=pitch_shift + ) + # Convert back to tensor and restore original shape shifted_tensor = torch.from_numpy(shifted_audio).float() if waveform.dim() == 3: # (batch, channels, samples) shifted_tensor = shifted_tensor.unsqueeze(0).unsqueeze(0) - elif waveform.dim() == 2: # (channels, samples) + elif waveform.dim() == 2: # (channels, samples) shifted_tensor = shifted_tensor.unsqueeze(0) - + # Return AUDIO format - result_audio = { - "waveform": shifted_tensor, - "sample_rate": sample_rate - } - + result_audio = {"waveform": shifted_tensor, "sample_rate": sample_rate} + return (result_audio,) diff --git 
a/nodes/audio_utils/save_audio_tensor.py b/nodes/audio_utils/save_audio_tensor.py index 6b7b0281c..eed9bf6ad 100644 --- a/nodes/audio_utils/save_audio_tensor.py +++ b/nodes/audio_utils/save_audio_tensor.py @@ -1,20 +1,17 @@ import numpy as np + from comfystream import tensor_cache + class SaveAudioTensor: CATEGORY = "audio_utils" RETURN_TYPES = () FUNCTION = "execute" OUTPUT_NODE = True - @classmethod def INPUT_TYPES(s): - return { - "required": { - "audio": ("AUDIO",) - } - } + return {"required": {"audio": ("AUDIO",)}} @classmethod def IS_CHANGED(s): @@ -23,24 +20,24 @@ def IS_CHANGED(s): def execute(self, audio): # Extract waveform tensor from AUDIO format waveform = audio["waveform"] - + # Convert to numpy and flatten for pipeline compatibility - if hasattr(waveform, 'cpu'): + if hasattr(waveform, "cpu"): # PyTorch tensor waveform_numpy = waveform.squeeze().cpu().numpy() else: # Already numpy waveform_numpy = waveform.squeeze() - + # Ensure 1D array for pipeline buffer concatenation if waveform_numpy.ndim > 1: waveform_numpy = waveform_numpy.flatten() - + # Convert to int16 if needed (pipeline expects int16) if waveform_numpy.dtype == np.float32: waveform_numpy = (waveform_numpy * 32767).astype(np.int16) elif waveform_numpy.dtype != np.int16: waveform_numpy = waveform_numpy.astype(np.int16) - + tensor_cache.audio_outputs.put_nowait(waveform_numpy) return (audio,) diff --git a/nodes/server_manager.py b/nodes/server_manager.py index 9b55f4f57..cd29bc07f 100644 --- a/nodes/server_manager.py +++ b/nodes/server_manager.py @@ -1,88 +1,83 @@ """ComfyStream server management module""" -import os -import sys -import subprocess -import socket -import signal + +import asyncio import atexit import logging -import urllib.request +import os +import socket +import subprocess +import sys +import threading import urllib.error -from pathlib import Path -import time -import asyncio +import urllib.request from abc import ABC, abstractmethod -import threading +from pathlib import 
Path # Configure logging to output to console -logging.basicConfig( - level=logging.INFO, - format='[ComfyStream] %(message)s', - stream=sys.stdout -) +logging.basicConfig(level=logging.INFO, format="[ComfyStream] %(message)s", stream=sys.stdout) # Set up Windows specific event loop policy -if sys.platform == 'win32': - if hasattr(asyncio, 'WindowsSelectorEventLoopPolicy'): +if sys.platform == "win32": + if hasattr(asyncio, "WindowsSelectorEventLoopPolicy"): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - + class ComfyStreamServerBase(ABC): """Abstract base class for ComfyStream server management""" - + def __init__(self, host="0.0.0.0", port=None): self.host = host self.port = port self.is_running = False logging.info(f"Initializing {self.__class__.__name__}") - + @abstractmethod async def start(self, port=None, host=None) -> bool: """Start the ComfyStream server - + Args: port: Optional port to use. If None, implementation should choose a port. host: Optional host to use. If None, implementation should use the default host. - + Returns: bool: True if server started successfully, False otherwise """ pass - + @abstractmethod async def stop(self) -> bool: """Stop the ComfyStream server - + Returns: bool: True if server stopped successfully, False otherwise """ pass - + @abstractmethod def get_status(self) -> dict: """Get current server status - + Returns: dict: Server status information """ pass - + @abstractmethod def check_server_health(self) -> bool: """Check if server is responding to health checks - + Returns: bool: True if server is healthy, False otherwise """ pass - + async def restart(self, port=None, host=None) -> bool: """Restart the ComfyStream server - + Args: port: Optional port to use. If None, use the current port. host: Optional host to use. If None, use the current host. 
- + Returns: bool: True if server restarted successfully, False otherwise """ @@ -93,11 +88,18 @@ async def restart(self, port=None, host=None) -> bool: await self.stop() return await self.start(port=port_to_use, host=host_to_use) + class LocalComfyStreamServer(ComfyStreamServerBase): """Local ComfyStream server implementation""" - - def __init__(self, host="0.0.0.0", start_port=8889, max_port=65535, - health_check_timeout=30, health_check_interval=1): + + def __init__( + self, + host="0.0.0.0", + start_port=8889, + max_port=65535, + health_check_timeout=30, + health_check_interval=1, + ): super().__init__(host=host) self.process = None self.start_port = start_port @@ -105,7 +107,7 @@ def __init__(self, host="0.0.0.0", start_port=8889, max_port=65535, self.health_check_timeout = health_check_timeout self.health_check_interval = health_check_interval atexit.register(self.cleanup) - + def find_available_port(self): """Find an available port starting from start_port""" port = self.start_port @@ -123,7 +125,7 @@ def check_server_health(self): """Check if server is responding to health checks""" if not self.port: return False - + url = f"http://{self.host}:{self.port}" try: response = urllib.request.urlopen(url) @@ -133,7 +135,7 @@ def check_server_health(self): def log_subprocess_output(self, pipe, level): """Log the output from the subprocess to the logger.""" - for line in iter(pipe.readline, b''): + for line in iter(pipe.readline, b""): logging.log(level, line.decode().strip()) async def start(self, port=None, host=None): @@ -146,37 +148,48 @@ async def start(self, port=None, host=None): self.port = port or self.find_available_port() if host is not None: self.host = host - + # Get the path to the ComfyStream server directory and script server_dir = Path(__file__).parent.parent / "server" server_script = server_dir / "app.py" logging.info(f"Server script: {server_script}") - + # Get ComfyUI workspace path (which is where we'll run from) comfyui_workspace = 
Path(__file__).parent.parent.parent.parent logging.info(f"ComfyUI workspace: {comfyui_workspace}") - + # Use the system Python (which should have ComfyStream installed) - cmd = [sys.executable, "-u", str(server_script), - "--port", str(self.port), - "--host", str(self.host), - "--workspace", str(comfyui_workspace)] - + cmd = [ + sys.executable, + "-u", + str(server_script), + "--port", + str(self.port), + "--host", + str(self.host), + "--workspace", + str(comfyui_workspace), + ] + logging.info(f"Starting server with command: {' '.join(cmd)}") - + # Start process with output going to pipes self.process = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=str(comfyui_workspace), # Run from ComfyUI root - env={**os.environ, 'PYTHONUNBUFFERED': '1'} + env={**os.environ, "PYTHONUNBUFFERED": "1"}, ) - + # Start threads to log stdout and stderr - threading.Thread(target=self.log_subprocess_output, args=(self.process.stdout, logging.INFO)).start() - threading.Thread(target=self.log_subprocess_output, args=(self.process.stderr, logging.ERROR)).start() - + threading.Thread( + target=self.log_subprocess_output, args=(self.process.stdout, logging.INFO) + ).start() + threading.Thread( + target=self.log_subprocess_output, args=(self.process.stderr, logging.ERROR) + ).start() + # Wait for server to start responding logging.info("Waiting for server to start...") for _ in range(self.health_check_timeout): @@ -185,15 +198,19 @@ async def start(self, port=None, host=None): break await asyncio.sleep(self.health_check_interval) else: - raise RuntimeError(f"Server failed to start after {self.health_check_timeout} seconds") - + raise RuntimeError( + f"Server failed to start after {self.health_check_timeout} seconds" + ) + if self.process.poll() is not None: raise RuntimeError(f"Server failed to start (exit code: {self.process.poll()})") - + self.is_running = True - logging.info(f"ComfyStream server started on port {self.port} (PID: {self.process.pid})") + 
logging.info( + f"ComfyStream server started on port {self.port} (PID: {self.process.pid})" + ) return True - + except Exception as e: logging.error(f"Error starting ComfyStream server: {str(e)}") self.cleanup() @@ -204,7 +221,7 @@ async def stop(self): if not self.is_running: logging.info("Server is not running") return False - + try: self.cleanup() logging.info("ComfyStream server stopped") @@ -220,7 +237,7 @@ def get_status(self): "port": self.port, "host": self.host, "pid": self.process.pid if self.process else None, - "type": "local" + "type": "local", } logging.info(f"Server status: {status}") return status @@ -247,4 +264,4 @@ def cleanup(self): except Exception as e: logging.error(f"Error cleaning up server process: {str(e)}") self.process = None - self.is_running = False \ No newline at end of file + self.is_running = False diff --git a/nodes/settings_storage.py b/nodes/settings_storage.py index 5d888b611..8b9ecf54b 100644 --- a/nodes/settings_storage.py +++ b/nodes/settings_storage.py @@ -1,53 +1,53 @@ """ComfyStream server-side settings storage module""" -import os + import json import logging +import os import threading from pathlib import Path # Configure logging -logging.basicConfig( - level=logging.INFO, - format='[ComfyStream Settings] %(message)s' -) +logging.basicConfig(level=logging.INFO, format="[ComfyStream Settings] %(message)s") # Default settings DEFAULT_SETTINGS = { "host": "0.0.0.0", "port": 8889, "configurations": [], - "selectedConfigIndex": -1 + "selectedConfigIndex": -1, } # Lock for thread-safe file operations settings_lock = threading.Lock() + def get_settings_file_path(): """Get the path to the settings file""" # Store settings in the extension directory extension_dir = Path(__file__).parent.parent settings_dir = extension_dir / "settings" - + # Create settings directory if it doesn't exist os.makedirs(settings_dir, exist_ok=True) - + return settings_dir / "comfystream_settings.json" + def load_settings(): """Load settings from 
file""" settings_file = get_settings_file_path() - + with settings_lock: try: if settings_file.exists(): - with open(settings_file, 'r') as f: + with open(settings_file, "r") as f: settings = json.load(f) - + # Ensure all default keys exist for key, value in DEFAULT_SETTINGS.items(): if key not in settings: settings[key] = value - + return settings else: return DEFAULT_SETTINGS.copy() @@ -55,53 +55,57 @@ def load_settings(): logging.error(f"Error loading settings: {str(e)}") return DEFAULT_SETTINGS.copy() + def save_settings(settings): """Save settings to file""" settings_file = get_settings_file_path() - + with settings_lock: try: - with open(settings_file, 'w') as f: + with open(settings_file, "w") as f: json.dump(settings, f, indent=2) return True except Exception as e: logging.error(f"Error saving settings: {str(e)}") return False + def update_settings(new_settings): """Update settings with new values""" current_settings = load_settings() - + # Update only the keys that are provided for key, value in new_settings.items(): current_settings[key] = value - + return save_settings(current_settings) + def add_configuration(name, host, port): """Add a new configuration""" settings = load_settings() - + # Create the new configuration config = {"name": name, "host": host, "port": port} - + # Add to configurations list settings["configurations"].append(config) - + # Save updated settings return save_settings(settings) + def remove_configuration(index): """Remove a configuration by index""" settings = load_settings() - + if index < 0 or index >= len(settings["configurations"]): logging.error(f"Invalid configuration index: {index}") return False - + # Remove the configuration settings["configurations"].pop(index) - + # Update selectedConfigIndex if needed if settings["selectedConfigIndex"] == index: # The selected config was deleted @@ -109,25 +113,26 @@ def remove_configuration(index): elif settings["selectedConfigIndex"] > index: # The selected config is after the 
deleted one, adjust index settings["selectedConfigIndex"] -= 1 - + # Save updated settings return save_settings(settings) + def select_configuration(index): """Select a configuration by index""" settings = load_settings() - + if index == -1 or (index >= 0 and index < len(settings["configurations"])): settings["selectedConfigIndex"] = index - + # If a valid configuration is selected, update host and port if index >= 0: config = settings["configurations"][index] settings["host"] = config["host"] settings["port"] = config["port"] - + # Save updated settings return save_settings(settings) else: logging.error(f"Invalid configuration index: {index}") - return False \ No newline at end of file + return False diff --git a/nodes/tensor_utils/load_tensor.py b/nodes/tensor_utils/load_tensor.py index c39fe8a1d..a2fb59408 100644 --- a/nodes/tensor_utils/load_tensor.py +++ b/nodes/tensor_utils/load_tensor.py @@ -1,20 +1,40 @@ +import queue + from comfystream import tensor_cache +from comfystream.exceptions import ComfyStreamInputTimeoutError class LoadTensor: - CATEGORY = "tensor_utils" + CATEGORY = "ComfyStream/Loaders" RETURN_TYPES = ("IMAGE",) FUNCTION = "execute" + DESCRIPTION = "Load image tensor from ComfyStream input with timeout." 
@classmethod - def INPUT_TYPES(s): - return {} + def INPUT_TYPES(cls): + return { + "optional": { + "timeout_seconds": ( + "FLOAT", + { + "default": 1.0, + "min": 0.1, + "max": 30.0, + "step": 0.1, + "tooltip": "Timeout in seconds", + }, + ), + } + } @classmethod - def IS_CHANGED(): + def IS_CHANGED(cls, **kwargs): return float("nan") - def execute(self): - frame = tensor_cache.image_inputs.get(block=True) - frame.side_data.skipped = False - return (frame.side_data.input,) + def execute(self, timeout_seconds: float = 1.0): + try: + frame = tensor_cache.image_inputs.get(block=True, timeout=timeout_seconds) + frame.side_data.skipped = False + return (frame.side_data.input,) + except queue.Empty: + raise ComfyStreamInputTimeoutError("video", timeout_seconds) diff --git a/nodes/tensor_utils/save_text_tensor.py b/nodes/tensor_utils/save_text_tensor.py index 098887e07..defb3aaa8 100644 --- a/nodes/tensor_utils/save_text_tensor.py +++ b/nodes/tensor_utils/save_text_tensor.py @@ -1,5 +1,6 @@ from comfystream import tensor_cache + class SaveTextTensor: CATEGORY = "text_utils" RETURN_TYPES = () @@ -13,8 +14,11 @@ def INPUT_TYPES(s): "data": ("STRING",), # Accept text string as input. 
}, "optional": { - "remove_linebreaks": ("BOOLEAN", {"default": True}), # Remove whitespace and line breaks - } + "remove_linebreaks": ( + "BOOLEAN", + {"default": True}, + ), # Remove whitespace and line breaks + }, } @classmethod @@ -23,7 +27,7 @@ def IS_CHANGED(s, **kwargs): def execute(self, data, remove_linebreaks=True): if remove_linebreaks: - result_text = data.replace('\n', '').replace('\r', '') + result_text = data.replace("\n", "").replace("\r", "") else: result_text = data tensor_cache.text_outputs.put_nowait(result_text) diff --git a/nodes/video_stream_utils/__init__.py b/nodes/video_stream_utils/__init__.py index 787f84ce9..f48631984 100644 --- a/nodes/video_stream_utils/__init__.py +++ b/nodes/video_stream_utils/__init__.py @@ -1,6 +1,6 @@ """Video stream utility nodes for ComfyStream""" -from .primary_input_load_image import PrimaryInputLoadImage +from .primary_input_load_image import PrimaryInputLoadImage NODE_CLASS_MAPPINGS = {"PrimaryInputLoadImage": PrimaryInputLoadImage} NODE_DISPLAY_NAME_MAPPINGS = {} diff --git a/nodes/web/__init__.py b/nodes/web/__init__.py index ef41b57aa..caa6b49a5 100644 --- a/nodes/web/__init__.py +++ b/nodes/web/__init__.py @@ -1,7 +1,10 @@ """ComfyStream Web UI nodes""" + import os + import folder_paths + # Define a simple Python class for the UI Preview node class ComfyStreamUIPreview: """ @@ -9,29 +12,24 @@ class ComfyStreamUIPreview: It's needed for ComfyUI to properly register and execute the node. The actual implementation is in the JavaScript file. 
""" + @classmethod def INPUT_TYPES(cls): - return { - "required": {}, - "optional": {} - } - + return {"required": {}, "optional": {}} + RETURN_TYPES = () - + FUNCTION = "execute" CATEGORY = "ComfyStream" - + def execute(self): # This function doesn't do anything as the real work is done in JavaScript # But we need to return something to satisfy the ComfyUI node execution system return ("UI Preview Node Executed",) + # Register the node class -NODE_CLASS_MAPPINGS = { - "ComfyStreamUIPreview": ComfyStreamUIPreview -} +NODE_CLASS_MAPPINGS = {"ComfyStreamUIPreview": ComfyStreamUIPreview} # Display names for the nodes -NODE_DISPLAY_NAME_MAPPINGS = { - "ComfyStreamUIPreview": "ComfyStream UI Preview" -} \ No newline at end of file +NODE_DISPLAY_NAME_MAPPINGS = {"ComfyStreamUIPreview": "ComfyStream UI Preview"} diff --git a/pyproject.toml b/pyproject.toml index 50e59935d..e22a59559 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,16 +1,16 @@ [build-system] -requires = ["setuptools>=64.0.0", "wheel"] +requires = ["setuptools>=64.0.0,<81", "wheel"] build-backend = "setuptools.build_meta" [project] name = "comfystream" description = "Build Live AI Video with ComfyUI" -version = "0.1.5" +version = "0.1.7" license = { file = "LICENSE" } dependencies = [ "asyncio", - "pytrickle @ git+https://github.com/livepeer/pytrickle.git@de37bea74679fa5db46b656a83c9b7240fc597b6", - "comfyui @ git+https://github.com/hiddenswitch/ComfyUI.git@58622c7e91cb5cc2bca985d713db55e5681ff316", + "pytrickle @ git+https://github.com/livepeer/pytrickle.git@v0.1.5", + "comfyui @ git+https://github.com/hiddenswitch/ComfyUI.git@e62df3a8811d8c652a195d4669f4fb27f6c9a9ba", "aiortc", "aiohttp", "aiohttp_cors", @@ -21,7 +21,7 @@ dependencies = [ ] [project.optional-dependencies] -dev = ["pytest", "pytest-cov"] +dev = ["pytest", "pytest-cov", "ruff"] [project.urls] Repository = "https://github.com/yondonfu/comfystream" @@ -37,3 +37,23 @@ packages = {find = {where = ["src", "nodes"]}} 
[tool.setuptools.dynamic] dependencies = {file = ["requirements.txt"]} + +[tool.ruff] +line-length = 100 +target-version = "py312" + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = [ + "E501", # let the formatter handle long lines + "E402", # module level import not at top (required for ComfyUI) + "E722", # bare except (existing code patterns) + "F401", # imported but unused (many re-exports) + "F403", # star imports (ComfyUI pattern) + "F405", # may be undefined from star imports + "F841", # assigned but never used (temporary) +] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" diff --git a/requirements.txt b/requirements.txt index 790900bb1..da59730dc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ asyncio -pytrickle @ git+https://github.com/livepeer/pytrickle.git@de37bea74679fa5db46b656a83c9b7240fc597b6 -comfyui @ git+https://github.com/hiddenswitch/ComfyUI.git@58622c7e91cb5cc2bca985d713db55e5681ff316 +pytrickle @ git+https://github.com/livepeer/pytrickle.git@v0.1.5 +comfyui @ git+https://github.com/hiddenswitch/ComfyUI.git@e62df3a8811d8c652a195d4669f4fb27f6c9a9ba aiortc aiohttp aiohttp_cors diff --git a/scripts/monitor_pid_resources.py b/scripts/monitor_pid_resources.py index a7de85636..25ffc97a0 100644 --- a/scripts/monitor_pid_resources.py +++ b/scripts/monitor_pid_resources.py @@ -1,16 +1,17 @@ """A Python script to monitor system resources for a given PID and optionally create a py-spy profiler report.""" -import psutil -import pynvml -import time +import csv import subprocess -import click import threading -import csv +import time from pathlib import Path from typing import List +import click +import psutil +import pynvml + def is_running_inside_container(): """Detects if the script is running inside a container.""" @@ -131,30 +132,20 @@ def find_pid_by_name(name: str) -> int: for proc in psutil.process_iter(["pid", "name", "cmdline"]): if proc.info["cmdline"] and name in proc.info["cmdline"]: found_pid = 
proc.info["pid"] - click.echo( - click.style(f"Found process '{name}' with PID {found_pid}.", fg="green") - ) + click.echo(click.style(f"Found process '{name}' with PID {found_pid}.", fg="green")) return found_pid click.echo(click.style(f"Error: Process with name '{name}' not found.", fg="red")) return None @click.command() -@click.option( - "--pid", type=str, default="auto", help='Process ID or "auto" to find by name' -) -@click.option( - "--name", type=str, default="app.py", help="Process name (default: app.py)" -) +@click.option("--pid", type=str, default="auto", help='Process ID or "auto" to find by name') +@click.option("--name", type=str, default="app.py", help="Process name (default: app.py)") @click.option("--interval", type=int, default=2, help="Monitoring interval (seconds)") -@click.option( - "--duration", type=int, default=30, help="Total monitoring duration (seconds)" -) +@click.option("--duration", type=int, default=30, help="Total monitoring duration (seconds)") @click.option("--output", type=str, default=None, help="File to save logs (optional)") @click.option("--spy", is_flag=True, help="Enable py-spy profiling") -@click.option( - "--spy-output", type=str, default="pyspy_profile.svg", help="Py-Spy output file" -) +@click.option("--spy-output", type=str, default="pyspy_profile.svg", help="Py-Spy output file") @click.option( "--host-pid", type=int, @@ -213,21 +204,15 @@ def monitor_resources( click.echo(click.style(f"Error: Process with PID {pid} not found.", fg="red")) return - click.echo( - click.style(f"Monitoring PID {pid} for {duration} seconds...", fg="green") - ) + click.echo(click.style(f"Monitoring PID {pid} for {duration} seconds...", fg="green")) def run_py_spy(): """Run py-spy profiler for deep profiling.""" click.echo(click.style("Running py-spy for deep profiling...", fg="green")) spy_cmd = f"py-spy record -o {spy_output} --pid {pid} --duration {duration}" try: - subprocess.run( - spy_cmd, shell=True, check=True, capture_output=True, 
text=True - ) - click.echo( - click.style(f"Py-Spy flame graph saved to {spy_output}", fg="green") - ) + subprocess.run(spy_cmd, shell=True, check=True, capture_output=True, text=True) + click.echo(click.style(f"Py-Spy flame graph saved to {spy_output}", fg="green")) except subprocess.CalledProcessError as e: click.echo(click.style(f"Error running py-spy: {e.stderr}", fg="red")) diff --git a/scripts/requirements.txt b/scripts/requirements.txt index 6ea80ee2e..eb7cee7f7 100644 --- a/scripts/requirements.txt +++ b/scripts/requirements.txt @@ -8,4 +8,4 @@ bcrypt rich # Profiler psutil -pynvml +nvidia-ml-py diff --git a/server/app.py b/server/app.py index a3a42fc44..b93e35ee4 100644 --- a/server/app.py +++ b/server/app.py @@ -4,17 +4,13 @@ import logging import os import sys -import time -import secrets + import torch # Initialize CUDA before any other imports to prevent core dump. if torch.cuda.is_available(): torch.cuda.init() - -from aiohttp import web, MultipartWriter -from aiohttp_cors import setup as setup_cors, ResourceOptions from aiohttp import web from aiortc import ( MediaStreamTrack, @@ -23,15 +19,16 @@ RTCPeerConnection, RTCSessionDescription, ) + # Import HTTP streaming modules -from http_streaming.routes import setup_routes from aiortc.codecs import h264 from aiortc.rtcrtpsender import RTCRtpSender -from comfystream.pipeline import Pipeline from twilio.rest import Client -from comfystream.server.utils import patch_loop_datagram, add_prefix_to_app_routes, FPSMeter + +from comfystream.exceptions import ComfyStreamTimeoutFilter +from comfystream.pipeline import Pipeline from comfystream.server.metrics import MetricsManager, StreamStatsManager -import time +from comfystream.server.utils import FPSMeter, add_prefix_to_app_routes, patch_loop_datagram logger = logging.getLogger(__name__) logging.getLogger("aiortc.rtcrtpsender").setLevel(logging.WARNING) @@ -64,12 +61,10 @@ def __init__(self, track: MediaStreamTrack, pipeline: Pipeline): super().__init__() 
self.track = track self.pipeline = pipeline - self.fps_meter = FPSMeter( - metrics_manager=app["metrics_manager"], track_id=track.id - ) + self.fps_meter = FPSMeter(metrics_manager=app["metrics_manager"], track_id=track.id) self.running = True self.collect_task = asyncio.create_task(self.collect_frames()) - + # Add cleanup when track ends @track.on("ended") async def on_ended(): @@ -95,15 +90,13 @@ async def collect_frames(self): logger.error(f"Error collecting video frames: {str(e)}") self.running = False break - + # Perform cleanup outside the exception handler logger.info("Video frame collection stopped") except asyncio.CancelledError: logger.info("Frame collection task cancelled") except Exception as e: logger.error(f"Unexpected error in frame collection: {str(e)}") - finally: - await self.pipeline.cleanup() async def recv(self): """Receive a processed video frame from the pipeline, increment the frame @@ -111,15 +104,6 @@ async def recv(self): """ processed_frame = await self.pipeline.get_processed_video_frame() - # Update the frame buffer with the processed frame - try: - from frame_buffer import FrameBuffer - frame_buffer = FrameBuffer.get_instance() - frame_buffer.update_frame(processed_frame) - except Exception as e: - # Don't let frame buffer errors affect the main pipeline - print(f"Error updating frame buffer: {e}") - # Increment the frame count to calculate FPS. 
await self.fps_meter.increment_frame_count() @@ -128,6 +112,7 @@ async def recv(self): class NoopVideoStreamTrack(MediaStreamTrack): """Simple passthrough video track that bypasses pipeline processing.""" + kind = "video" def __init__(self, track: MediaStreamTrack): @@ -147,6 +132,7 @@ async def recv(self): class NoopAudioStreamTrack(MediaStreamTrack): """Simple passthrough audio track that bypasses pipeline processing.""" + kind = "audio" def __init__(self, track: MediaStreamTrack): @@ -174,7 +160,7 @@ def __init__(self, track: MediaStreamTrack, pipeline): self.running = True logger.info(f"AudioStreamTrack created for track {track.id}") self.collect_task = asyncio.create_task(self.collect_frames()) - + # Add cleanup when track ends @track.on("ended") async def on_ended(): @@ -200,19 +186,18 @@ async def collect_frames(self): logger.error(f"Error collecting audio frames: {str(e)}") self.running = False break - + # Perform cleanup outside the exception handler logger.info("Audio frame collection stopped") except asyncio.CancelledError: logger.info("Frame collection task cancelled") except Exception as e: logger.error(f"Unexpected error in audio frame collection: {str(e)}") - finally: - await self.pipeline.cleanup() async def recv(self): return await self.pipeline.get_processed_audio_frame() + def force_codec(pc, sender, forced_codec): kind = forced_codec.split("/")[0] codecs = RTCRtpSender.getCapabilities(kind).codecs @@ -258,39 +243,42 @@ async def offer(request): pcs = request.app["pcs"] params = await request.json() - + # Check if this is noop mode (no prompts provided) prompts = params.get("prompts") is_noop_mode = not prompts - - if is_noop_mode: - logger.info("[Offer] No prompts provided - entering noop passthrough mode") - else: - await pipeline.set_prompts(prompts) - logger.info("[Offer] Set workflow prompts") - - # Set resolution if provided in the offer + resolution = params.get("resolution") if resolution: pipeline.width = resolution["width"] 
pipeline.height = resolution["height"] - logger.info(f"[Offer] Set pipeline resolution to {resolution['width']}x{resolution['height']}") + logger.info( + f"[Offer] Set pipeline resolution to {resolution['width']}x{resolution['height']}" + ) + + if is_noop_mode: + logger.info("[Offer] No prompts provided - entering noop passthrough mode") + else: + await pipeline.apply_prompts( + prompts, + skip_warmup=False, + ) + await pipeline.start_streaming() + logger.info("[Offer] Set workflow prompts, warmed pipeline, and started execution") offer_params = params["offer"] offer = RTCSessionDescription(sdp=offer_params["sdp"], type=offer_params["type"]) - + ice_servers = get_ice_servers() if len(ice_servers) > 0: - pc = RTCPeerConnection( - configuration=RTCConfiguration(iceServers=get_ice_servers()) - ) + pc = RTCPeerConnection(configuration=RTCConfiguration(iceServers=get_ice_servers())) else: pc = RTCPeerConnection() pcs.add(pc) tracks = {"video": None, "audio": None} - + # Flag to track if we've received resolution update resolution_received = {"value": False} @@ -323,56 +311,140 @@ def on_datachannel(channel): @channel.on("message") async def on_message(message): - try: - params = json.loads(message) + def send_json(payload): + channel.send(json.dumps(payload)) + + def send_success_response(response_type, **extra): + payload = {"type": response_type, "success": True} + payload.update(extra) + send_json(payload) + + def send_error_response(response_type, error_message, **extra): + payload = { + "type": response_type, + "success": False, + "error": error_message, + } + payload.update(extra) + send_json(payload) + + async def handle_get_nodes(_params): + nodes_info = await pipeline.get_nodes_info() + send_json({"type": "nodes_info", "nodes": nodes_info}) + + async def handle_update_prompts(_params): + if "prompts" not in _params: + logger.warning("[Control] Missing prompt in update_prompt message") + send_error_response( + "prompts_updated", "Missing 'prompts' in control 
message" + ) + return + try: + await pipeline.update_prompts(_params["prompts"]) + except Exception as e: + logger.error(f"Error updating prompt: {str(e)}") + send_error_response("prompts_updated", str(e)) + return + send_success_response("prompts_updated") + + async def handle_update_resolution(_params): + width = _params.get("width") + height = _params.get("height") + if width is None or height is None: + logger.warning( + "[Control] Missing width or height in update_resolution message" + ) + send_error_response( + "resolution_updated", + "Missing 'width' or 'height' in control message", + ) + return + + if is_noop_mode: + logger.info( + f"[Control] Noop mode - resolution update to {width}x{height} (no pipeline involved)" + ) + else: + # Update pipeline resolution for future frames + pipeline.width = width + pipeline.height = height + logger.info(f"[Control] Updated resolution to {width}x{height}") + + # Mark that we've received resolution + resolution_received["value"] = True - if params.get("type") == "get_nodes": - nodes_info = await pipeline.get_nodes_info() - response = {"type": "nodes_info", "nodes": nodes_info} - channel.send(json.dumps(response)) - elif params.get("type") == "update_prompts": - if "prompts" not in params: - logger.warning( - "[Control] Missing prompt in update_prompt message" - ) + if is_noop_mode: + logger.info("[Control] Noop mode - no warmup needed") + else: + # Note: Video warmup now happens during offer, not here + logger.info( + "[Control] Resolution updated - warmup was already performed during offer" + ) + + send_success_response("resolution_updated") + + async def handle_pause_prompts(_params): + if is_noop_mode: + logger.info("[Control] Noop mode - no prompts to pause") + else: + try: + await pipeline.pause_prompts() + logger.info("[Control] Paused prompt execution") + except Exception as e: + logger.error(f"[Control] Error pausing prompts: {str(e)}") + send_error_response("prompts_paused", str(e)) return + 
send_success_response("prompts_paused") + + async def handle_resume_prompts(_params): + if is_noop_mode: + logger.info("[Control] Noop mode - no prompts to resume") + else: try: - await pipeline.update_prompts(params["prompts"]) + await pipeline.start_streaming() + logger.info("[Control] Resumed prompt execution") except Exception as e: - logger.error(f"Error updating prompt: {str(e)}") - response = {"type": "prompts_updated", "success": True} - channel.send(json.dumps(response)) - elif params.get("type") == "update_resolution": - if "width" not in params or "height" not in params: - logger.warning("[Control] Missing width or height in update_resolution message") + logger.error(f"[Control] Error resuming prompts: {str(e)}") + send_error_response("prompts_resumed", str(e)) return - - if is_noop_mode: - logger.info(f"[Control] Noop mode - resolution update to {params['width']}x{params['height']} (no pipeline involved)") - else: - # Update pipeline resolution for future frames - pipeline.width = params["width"] - pipeline.height = params["height"] - logger.info(f"[Control] Updated resolution to {params['width']}x{params['height']}") - - # Mark that we've received resolution - resolution_received["value"] = True - - if is_noop_mode: - logger.info("[Control] Noop mode - no warmup needed") - else: - # Note: Video warmup now happens during offer, not here - logger.info("[Control] Resolution updated - warmup was already performed during offer") - - response = { - "type": "resolution_updated", - "success": True - } - channel.send(json.dumps(response)) + send_success_response("prompts_resumed") + + async def handle_stop_prompts(_params): + if is_noop_mode: + logger.info("[Control] Noop mode - no prompts to stop") else: - logger.warning( - "[Server] Invalid message format - missing required fields" - ) + try: + await pipeline.stop_prompts(cleanup=False) + logger.info("[Control] Stopped prompt execution") + except Exception as e: + logger.error(f"[Control] Error stopping 
prompts: {str(e)}") + send_error_response("prompts_stopped", str(e)) + return + send_success_response("prompts_stopped") + + handlers = { + "get_nodes": handle_get_nodes, + "update_prompts": handle_update_prompts, + "update_resolution": handle_update_resolution, + "pause_prompts": handle_pause_prompts, + "resume_prompts": handle_resume_prompts, + "stop_prompts": handle_stop_prompts, + } + + try: + params = json.loads(message) + message_type = params.get("type") + + if not message_type: + logger.warning("[Server] Control message missing 'type'") + return + + handler = handlers.get(message_type) + if handler is None: + logger.warning(f"[Server] Unsupported control message: {message_type}") + return + + await handler(params) except json.JSONDecodeError: logger.error("[Server] Invalid JSON received") except Exception as e: @@ -381,12 +453,16 @@ async def on_message(message): elif channel.label == "data": if is_noop_mode: logger.debug("[TextChannel] Noop mode - skipping text output forwarding") + # In noop mode, just acknowledge the data channel but don't forward anything @channel.on("open") def on_data_channel_open(): - logger.debug("[TextChannel] Data channel opened in noop mode (no text forwarding)") + logger.debug( + "[TextChannel] Data channel opened in noop mode (no text forwarding)" + ) else: if pipeline.produces_text_output(): + async def forward_text(): try: while channel.readyState == "open": @@ -401,7 +477,9 @@ async def forward_text(): try: channel.send(json.dumps({"type": "text", "data": text})) except Exception as e: - logger.debug(f"[TextChannel] Send failed, stopping forwarder: {e}") + logger.debug( + f"[TextChannel] Send failed, stopping forwarder: {e}" + ) break except asyncio.CancelledError: logger.debug("[TextChannel] Forward text task cancelled") @@ -420,6 +498,7 @@ def _remove_forward_task(t): tasks = request.app.get("data_channel_tasks") if tasks is not None: tasks.discard(t) + forward_task.add_done_callback(_remove_forward_task) # Ensure 
cancellation on channel close event @@ -431,20 +510,22 @@ def on_data_channel_close(): if not t.done(): t.cancel() else: - logger.debug("[TextChannel] Workflow has no text outputs; not starting forward_text") + logger.debug( + "[TextChannel] Workflow has no text outputs; not starting forward_text" + ) @pc.on("track") def on_track(track): logger.info(f"Track received: {track.kind} (readyState: {track.readyState})") - + # Check if we already have a track of this type to avoid duplicate track errors if track.kind == "video" and tracks["video"] is not None: - logger.debug(f"Video track already exists, ignoring duplicate track event") + logger.debug("Video track already exists, ignoring duplicate track event") return elif track.kind == "audio" and tracks["audio"] is not None: - logger.debug(f"Audio track already exists, ignoring duplicate track event") + logger.debug("Audio track already exists, ignoring duplicate track event") return - + if track.kind == "video": if is_noop_mode: # Use simple passthrough track that bypasses pipeline @@ -454,7 +535,7 @@ def on_track(track): # Always use pipeline processing - it handles passthrough internally based on workflow videoTrack = VideoStreamTrack(track, pipeline) logger.info("[Pipeline] Using video processing pipeline") - + tracks["video"] = videoTrack sender = pc.addTrack(videoTrack) @@ -465,11 +546,10 @@ def on_track(track): codec = "video/H264" force_codec(pc, sender, codec) - - + elif track.kind == "audio": logger.info(f"Creating audio track for track {track.id}") - + if is_noop_mode: # Use simple passthrough track that bypasses pipeline audioTrack = NoopAudioStreamTrack(track) @@ -478,10 +558,10 @@ def on_track(track): # Always use pipeline processing - it handles passthrough internally based on workflow audioTrack = AudioStreamTrack(track, pipeline) logger.info("[Pipeline] Using audio processing pipeline") - + tracks["audio"] = audioTrack sender = pc.addTrack(audioTrack) - logger.debug(f"Audio track added to peer 
connection") + logger.debug("Audio track added to peer connection") @track.on("ended") async def on_ended(): @@ -500,6 +580,13 @@ async def on_connectionstatechange(): if not task.done(): task.cancel() request.app["data_channel_tasks"].clear() + # Cleanup pipeline once per connection (not per track) + if not is_noop_mode: + try: + await pipeline.stop_prompts(cleanup=True) + logger.info("Pipeline cleanup completed for failed connection") + except Exception as e: + logger.error(f"Error during pipeline cleanup on connection failure: {e}") elif pc.connectionState == "closed": await pc.close() pcs.discard(pc) @@ -509,6 +596,13 @@ async def on_connectionstatechange(): if not task.done(): task.cancel() request.app["data_channel_tasks"].clear() + # Cleanup pipeline once per connection (not per track) + if not is_noop_mode: + try: + await pipeline.stop_prompts(cleanup=True) + logger.info("Pipeline cleanup completed for closed connection") + except Exception as e: + logger.error(f"Error during pipeline cleanup on connection close: {e}") await pc.setRemoteDescription(offer) @@ -516,18 +610,12 @@ async def on_connectionstatechange(): transceivers = pc.getTransceivers() logger.debug(f"[Offer] After negotiation - Total transceivers: {len(transceivers)}") for i, t in enumerate(transceivers): - logger.debug(f"[Offer] Transceiver {i}: {t.kind} - direction: {t.direction} - currentDirection: {t.currentDirection}") + logger.debug( + f"[Offer] Transceiver {i}: {t.kind} - direction: {t.direction} - currentDirection: {t.currentDirection}" + ) # Warm up the pipeline based on detected modalities and SDP content (skip in noop mode) - if not is_noop_mode: - if "m=video" in pc.remoteDescription.sdp and pipeline.accepts_video_input(): - logger.info("[Offer] Warming up video pipeline") - await pipeline.warm_video() - - if "m=audio" in pc.remoteDescription.sdp and pipeline.accepts_audio_input(): - logger.info("[Offer] Warming up audio pipeline") - await pipeline.warm_audio() - else: + if 
is_noop_mode: logger.debug("[Offer] Skipping pipeline warmup in noop mode") answer = await pc.createAnswer() @@ -535,28 +623,29 @@ async def on_connectionstatechange(): return web.Response( content_type="application/json", - text=json.dumps( - {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type} - ), + text=json.dumps({"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}), ) + async def cancel_collect_frames(track): track.running = False - if hasattr(track, 'collect_task') is not None and not track.collect_task.done(): + if track.collect_task and not track.collect_task.done(): try: track.collect_task.cancel() await track.collect_task - except (asyncio.CancelledError): + except asyncio.CancelledError: pass + async def set_prompt(request): pipeline = request.app["pipeline"] prompt = await request.json() - await pipeline.set_prompts(prompt) + await pipeline.apply_prompts(prompt) return web.Response(content_type="application/json", text="OK") + def health(_): return web.Response(content_type="application/json", text="OK") @@ -568,12 +657,14 @@ async def on_startup(app: web.Application): app["pipeline"] = Pipeline( width=512, height=512, - cwd=app["workspace"], - disable_cuda_malloc=True, - gpu_only=True, - preview_method='none', - comfyui_inference_log_level=app.get("comfui_inference_log_level", None), + cwd=app["workspace"], + disable_cuda_malloc=True, + gpu_only=True, + preview_method="none", + comfyui_inference_log_level=app.get("comfyui_inference_log_level", None), + blacklist_custom_nodes=["ComfyUI-Manager"], ) + await app["pipeline"].initialize() app["pcs"] = set() app["video_tracks"] = {} @@ -588,13 +679,9 @@ async def on_shutdown(app: web.Application): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run comfystream server") parser.add_argument("--port", default=8889, help="Set the signaling port") - parser.add_argument( - "--media-ports", default=None, help="Set the UDP ports for WebRTC media" - ) + 
parser.add_argument("--media-ports", default=None, help="Set the UDP ports for WebRTC media") parser.add_argument("--host", default="127.0.0.1", help="Set the host") - parser.add_argument( - "--workspace", default=None, required=True, help="Set Comfy workspace" - ) + parser.add_argument("--workspace", default=None, required=True, help="Set Comfy workspace") parser.add_argument( "--log-level", default="INFO", @@ -636,16 +723,6 @@ async def on_shutdown(app: web.Application): app = web.Application() app["media_ports"] = args.media_ports.split(",") if args.media_ports else None app["workspace"] = args.workspace - - # Setup CORS - cors = setup_cors(app, defaults={ - "*": ResourceOptions( - allow_credentials=True, - expose_headers="*", - allow_headers="*", - allow_methods=["GET", "POST", "OPTIONS"] - ) - }) app.on_startup.append(on_startup) app.on_shutdown.append(on_shutdown) @@ -656,18 +733,10 @@ async def on_shutdown(app: web.Application): # WebRTC signalling and control routes. app.router.add_post("/offer", offer) app.router.add_post("/prompt", set_prompt) - - # Setup HTTP streaming routes - setup_routes(app, cors) - - # Serve static files from the public directory - app.router.add_static("/", path=os.path.join(os.path.dirname(__file__), "public"), name="static") # Add routes for getting stream statistics. 
stream_stats_manager = StreamStatsManager(app) - app.router.add_get( - "/streams/stats", stream_stats_manager.collect_all_stream_metrics - ) + app.router.add_get("/streams/stats", stream_stats_manager.collect_all_stream_metrics) app.router.add_get( "/stream/{stream_id}/stats", stream_stats_manager.collect_stream_metrics_by_id ) @@ -694,7 +763,12 @@ def force_print(*args, **kwargs): if args.comfyui_log_level: log_level = logging._nameToLevel.get(args.comfyui_log_level.upper()) logging.getLogger("comfy").setLevel(log_level) + + # Add ComfyStream timeout filter to suppress verbose execution logging + timeout_filter = ComfyStreamTimeoutFilter() + logging.getLogger("comfy.cmd.execution").addFilter(timeout_filter) + logging.getLogger("comfystream").addFilter(timeout_filter) if args.comfyui_inference_log_level: - app["comfui_inference_log_level"] = args.comfyui_inference_log_level + app["comfyui_inference_log_level"] = args.comfyui_inference_log_level web.run_app(app, host=args.host, port=int(args.port), print=force_print) diff --git a/server/byoc.py b/server/byoc.py index 0735674b4..3f8f3470c 100644 --- a/server/byoc.py +++ b/server/byoc.py @@ -5,54 +5,35 @@ import sys import torch + # Initialize CUDA before any other imports to prevent core dump. 
if torch.cuda.is_available(): torch.cuda.init() from aiohttp import web +from frame_processor import ComfyStreamFrameProcessor +from pytrickle.frame_overlay import OverlayConfig, OverlayMode +from pytrickle.frame_skipper import FrameSkipConfig from pytrickle.stream_processor import StreamProcessor from pytrickle.utils.register import RegisterCapability -from pytrickle.frame_skipper import FrameSkipConfig -from frame_processor import ComfyStreamFrameProcessor -logger = logging.getLogger(__name__) +from comfystream.exceptions import ComfyStreamTimeoutFilter +logger = logging.getLogger(__name__) -async def register_orchestrator(orch_url=None, orch_secret=None, capability_name=None, host="127.0.0.1", port=8889): - """Register capability with orchestrator if configured.""" - try: - orch_url = orch_url or os.getenv("ORCH_URL") - orch_secret = orch_secret or os.getenv("ORCH_SECRET") - - if orch_url and orch_secret: - os.environ.update({ - "CAPABILITY_NAME": capability_name or os.getenv("CAPABILITY_NAME") or "comfystream-processor", - "CAPABILITY_DESCRIPTION": "ComfyUI streaming processor", - "CAPABILITY_URL": f"http://{host}:{port}", - "CAPABILITY_CAPACITY": "1", - "ORCH_URL": orch_url, - "ORCH_SECRET": orch_secret - }) - - # Pass through explicit capability_name to ensure CLI/env override takes effect - result = await RegisterCapability.register( - logger=logger, - capability_name=capability_name - ) - if result: - logger.info(f"Registered capability: {result.geturl()}") - except Exception as e: - logger.error(f"Orchestrator registration failed: {e}") +DEFAULT_WITHHELD_TIMEOUT_SECONDS = 0.5 def main(): parser = argparse.ArgumentParser( description="Run comfystream server in BYOC (Bring Your Own Compute) mode using pytrickle." 
) - parser.add_argument("--port", default=8889, help="Set the server port") - parser.add_argument("--host", default="127.0.0.1", help="Set the host") + parser.add_argument("--port", default=8000, help="Set the server port") + parser.add_argument("--host", default="0.0.0.0", help="Set the host") parser.add_argument( - "--workspace", default=None, required=True, help="Set Comfy workspace" + "--workspace", + default=os.getcwd() + "/../ComfyUI", + help="Set Comfy workspace (Default: ../ComfyUI)", ) parser.add_argument( "--log-level", @@ -72,21 +53,6 @@ def main(): choices=logging._nameToLevel.keys(), help="Set the logging level for ComfyUI inference", ) - parser.add_argument( - "--orch-url", - default=None, - help="Orchestrator URL for capability registration", - ) - parser.add_argument( - "--orch-secret", - default=None, - help="Orchestrator secret for capability registration", - ) - parser.add_argument( - "--capability-name", - default=None, - help="Name for this capability (default: comfystream-processor)", - ) parser.add_argument( "--disable-frame-skip", default=False, @@ -112,18 +78,28 @@ def main(): format="%(asctime)s [%(levelname)s] %(message)s", datefmt="%H:%M:%S", ) + logging.getLogger("comfy.model_detection").setLevel(logging.WARNING) # Allow overriding of ComfyUI log levels. 
if args.comfyui_log_level: log_level = logging._nameToLevel.get(args.comfyui_log_level.upper()) logging.getLogger("comfy").setLevel(log_level) + # Add ComfyStream timeout filter to suppress verbose execution logging + timeout_filter = ComfyStreamTimeoutFilter() + logging.getLogger("comfy.cmd.execution").addFilter(timeout_filter) + logging.getLogger("comfystream").addFilter(timeout_filter) + def force_print(*args, **kwargs): print(*args, **kwargs, flush=True) sys.stdout.flush() logger.info("Starting ComfyStream BYOC server with pytrickle StreamProcessor...") - + logger.info( + "Send initial workflow parameters (width/height/prompts/warmup) via /stream/start " + "params; runtime updates now apply incremental changes only." + ) + # Create frame processor with configuration frame_processor = ComfyStreamFrameProcessor( width=args.width, @@ -131,10 +107,12 @@ def force_print(*args, **kwargs): workspace=args.workspace, disable_cuda_malloc=True, gpu_only=True, - preview_method='none', - comfyui_inference_log_level=args.comfyui_inference_log_level + preview_method="none", + blacklist_custom_nodes=["ComfyUI-Manager"], + logging_level=args.comfyui_log_level, + comfyui_inference_log_level=args.comfyui_inference_log_level, ) - + # Create frame skip configuration only if enabled frame_skip_config = None if args.disable_frame_skip: @@ -142,67 +120,73 @@ def force_print(*args, **kwargs): else: frame_skip_config = FrameSkipConfig() logger.info("Frame skipping enabled: adaptive skipping based on queue sizes") - + # Create StreamProcessor with frame processor processor = StreamProcessor( video_processor=frame_processor.process_video_async, audio_processor=frame_processor.process_audio_async, model_loader=frame_processor.load_model, param_updater=frame_processor.update_params, + on_stream_start=frame_processor.on_stream_start, on_stream_stop=frame_processor.on_stream_stop, # Align processor name with capability for consistent logs - name=(args.capability_name or 
os.getenv("CAPABILITY_NAME") or "comfystream-processor"), + name=(os.getenv("CAPABILITY_NAME") or "comfystream"), port=int(args.port), host=args.host, frame_skip_config=frame_skip_config, + overlay_config=OverlayConfig( + mode=OverlayMode.PROGRESSBAR, + message="Loading...", + enabled=True, + auto_timeout_seconds=DEFAULT_WITHHELD_TIMEOUT_SECONDS, + frame_count_to_disable=20, + ), # Ensure server metadata reflects the desired capability name - capability_name=(args.capability_name or os.getenv("CAPABILITY_NAME") or "comfystream-processor") + capability_name=(os.getenv("CAPABILITY_NAME") or "comfystream"), + # server_kwargs... + route_prefix="/", ) # Set the stream processor reference for text data publishing frame_processor.set_stream_processor(processor) - - # Create async startup function to load model - async def load_model_on_startup(app): - await processor._frame_processor.load_model() - + + logger.info("Startup warmup runs automatically as part of on_stream_start.") + # Create async startup function for orchestrator registration async def register_orchestrator_startup(app): - await register_orchestrator( - orch_url=args.orch_url, - orch_secret=args.orch_secret, - capability_name=args.capability_name, - host=args.host, - port=args.port - ) - - # Add model loading and registration to startup hooks - processor.server.app.on_startup.append(load_model_on_startup) - processor.server.app.on_startup.append(register_orchestrator_startup) - - # Add warmup endpoint: accepts same body as prompts update - async def warmup_handler(request): try: - body = await request.json() + orch_url = os.getenv("ORCH_URL") + + if orch_url and os.getenv("ORCH_SECRET", None): + # CAPABILITY_URL always overrides host:port from args + capability_url = os.getenv("CAPABILITY_URL") or f"http://{args.host}:{args.port}" + + os.environ.update( + { + "CAPABILITY_NAME": os.getenv("CAPABILITY_NAME") or "comfystream", + "CAPABILITY_DESCRIPTION": "ComfyUI streaming processor", + "CAPABILITY_URL": 
capability_url, + "CAPABILITY_CAPACITY": "1", + "ORCH_URL": orch_url, + "ORCH_SECRET": os.getenv("ORCH_SECRET", None), + } + ) + + result = await RegisterCapability.register( + logger=logger, capability_name=os.getenv("CAPABILITY_NAME") or "comfystream" + ) + if result: + logger.info(f"Registered capability: {result.geturl()}") + # Clear ORCH_SECRET from environment after use for security + os.environ.pop("ORCH_SECRET", None) except Exception as e: - logger.error(f"Invalid JSON in warmup request: {e}") - return web.json_response({"error": "Invalid JSON"}, status=400) - try: - # Inject sentinel to trigger warmup inside update_params on the model thread - if isinstance(body, dict): - body["warmup"] = True - else: - body = {"warmup": True} - # Fire-and-forget: do not await warmup; update_params will schedule it - asyncio.get_running_loop().create_task(frame_processor.update_params(body)) - return web.json_response({"status": "accepted"}) - except Exception as e: - logger.error(f"Warmup failed: {e}") - return web.json_response({"error": str(e)}, status=500) + logger.error(f"Orchestrator registration failed: {e}") + # Clear ORCH_SECRET from environment even on error + os.environ.pop("ORCH_SECRET", None) + + # Add registration to startup hooks + processor.server.app.on_startup.append(register_orchestrator_startup) - # Mount at same API namespace as StreamProcessor defaults - processor.server.add_route("POST", "/api/stream/warmup", warmup_handler) - # Run the processor processor.run() diff --git a/server/frame_buffer.py b/server/frame_buffer.py deleted file mode 100644 index 2a16407ae..000000000 --- a/server/frame_buffer.py +++ /dev/null @@ -1,42 +0,0 @@ -import threading -import time -import numpy as np -import cv2 -import av -from typing import Optional - -class FrameBuffer: - _instance = None - - @classmethod - def get_instance(cls): - if cls._instance is None: - cls._instance = FrameBuffer() - return cls._instance - - def __init__(self): - self.current_frame = None - 
self.frame_lock = threading.Lock() - self.last_update_time = 0 - self.quality = 70 # JPEG quality (0-100) - - def update_frame(self, frame): - """Update the current frame in the buffer""" - with self.frame_lock: - # Convert frame to numpy array if it's an av.VideoFrame - if isinstance(frame, av.VideoFrame): - frame_np = frame.to_ndarray(format="rgb24") - else: - frame_np = frame - - # Store the frame as a JPEG-encoded bytes object for efficient serving - _, jpeg_frame = cv2.imencode('.jpg', cv2.cvtColor(frame_np, cv2.COLOR_RGB2BGR), - [cv2.IMWRITE_JPEG_QUALITY, self.quality]) - - self.current_frame = jpeg_frame.tobytes() - self.last_update_time = time.time() - - def get_current_frame(self) -> Optional[bytes]: - """Get the current frame from the buffer""" - with self.frame_lock: - return self.current_frame diff --git a/server/frame_processor.py b/server/frame_processor.py index bac139d4a..befb5c608 100644 --- a/server/frame_processor.py +++ b/server/frame_processor.py @@ -2,13 +2,19 @@ import json import logging import os -from typing import List +from typing import Any, Dict, List, Optional, Union -import numpy as np from pytrickle.frame_processor import FrameProcessor -from pytrickle.frames import VideoFrame, AudioFrame +from pytrickle.frames import AudioFrame, VideoFrame +from pytrickle.stream_processor import VideoProcessingResult + from comfystream.pipeline import Pipeline -from comfystream.utils import convert_prompt, ComfyStreamParamsUpdateRequest +from comfystream.pipeline_state import PipelineState +from comfystream.utils import ( + ComfyStreamParamsUpdateRequest, + convert_prompt, + normalize_stream_params, +) logger = logging.getLogger(__name__) @@ -16,17 +22,19 @@ class ComfyStreamFrameProcessor(FrameProcessor): """ Integrated ComfyStream FrameProcessor for pytrickle. - + This class wraps the ComfyStream Pipeline to work with pytrickle's streaming architecture. 
""" def __init__(self, text_poll_interval: float = 0.25, **load_params): """Initialize with load parameters for pipeline creation. - + Args: text_poll_interval: Interval in seconds to poll for text outputs (default: 0.25) **load_params: Parameters for pipeline creation """ + super().__init__() + self.pipeline = None self._load_params = load_params self._text_poll_interval = text_poll_interval @@ -35,13 +43,64 @@ def __init__(self, text_poll_interval: float = 0.25, **load_params): self._text_forward_task = None self._background_tasks = [] self._stop_event = asyncio.Event() - super().__init__() + + async def _apply_stream_start_prompt(self, prompt_value: Any) -> bool: + if not self.pipeline: + logger.debug("Cannot apply stream start prompt without pipeline") + return False + + # Parse prompt payload from various formats + prompt_dict = None + if prompt_value is None: + pass + elif isinstance(prompt_value, dict): + prompt_dict = prompt_value + elif isinstance(prompt_value, list): + for candidate in prompt_value: + if isinstance(candidate, dict): + prompt_dict = candidate + break + elif isinstance(prompt_value, str): + prompt_str = prompt_value.strip() + if prompt_str: + try: + parsed = json.loads(prompt_str) + if isinstance(parsed, dict): + prompt_dict = parsed + else: + logger.warning("Parsed prompt payload is %s, expected dict", type(parsed)) + except json.JSONDecodeError: + logger.error("Stream start prompt is not valid JSON") + else: + logger.warning("Unsupported prompt payload type: %s", type(prompt_value)) + + if not isinstance(prompt_dict, dict): + logger.warning("Skipping prompt application due to invalid payload") + return False + + try: + await self._process_prompts(prompt_dict, skip_warmup=True) + return True + except Exception: + logger.exception("Failed to apply stream start prompt") + raise + + def _workflow_has_video(self) -> bool: + """Return True if current workflow is expected to produce video output.""" + if not self.pipeline: + return False + try: 
+ capabilities = self.pipeline.get_workflow_io_capabilities() + return bool(capabilities.get("video", {}).get("output", False)) + except Exception: + logger.debug("Unable to determine workflow video capability", exc_info=True) + return False def set_stream_processor(self, stream_processor): """Set reference to StreamProcessor for data publishing.""" self._stream_processor = stream_processor logger.info("StreamProcessor reference set for text data publishing") - + def _setup_text_monitoring(self): """Set up background text forwarding from the pipeline.""" try: @@ -75,7 +134,9 @@ async def _forward_text_loop(): if self._stream_processor: success = await self._stream_processor.send_data(text) if not success: - logger.debug("Text send failed; stopping text forwarder") + logger.debug( + "Text send failed; stopping text forwarder" + ) break except asyncio.CancelledError: logger.debug("Text forwarder task cancelled") @@ -105,7 +166,7 @@ async def _stop_text_forwarder(self) -> None: except Exception: logger.debug("Error while awaiting text forwarder cancellation", exc_info=True) self._text_forward_task = None - + async def on_stream_stop(self): """Called when stream stops - cleanup background tasks.""" logger.info("Stream stopped, cleaning up background tasks") @@ -113,6 +174,14 @@ async def on_stream_stop(self): # Set stop event to signal all background tasks to stop self._stop_event.set() + # Stop the ComfyStream client's prompt execution + if self.pipeline: + logger.info("Stopping ComfyStream client prompt execution") + try: + await self.pipeline.stop_prompts(cleanup=True) + except Exception as e: + logger.error(f"Error stopping ComfyStream client: {e}") + # Stop text forwarder await self._stop_text_forwarder() @@ -136,50 +205,89 @@ async def on_stream_stop(self): self._background_tasks.clear() logger.info("All background tasks cleaned up") - + def _reset_stop_event(self): """Reset the stop event for a new stream.""" self._stop_event.clear() + async def 
on_stream_start(self, params: Optional[Dict[str, Any]] = None): + """Handle stream start lifecycle events.""" + logger.info("Stream starting") + self._reset_stop_event() + logger.info(f"Stream start params: {params}") + + if not self.pipeline: + logger.debug("Stream start requested before pipeline initialization") + return + + stream_params = normalize_stream_params(params) + prompt_payload = stream_params.pop("prompts", None) + if prompt_payload is None: + prompt_payload = stream_params.pop("prompt", None) + + if prompt_payload: + try: + await self._apply_stream_start_prompt(prompt_payload) + except Exception: + logger.exception("Failed to apply stream start prompt") + return + + if not self.pipeline.state_manager.is_initialized(): + logger.info("Pipeline not initialized; waiting for prompts before streaming") + return + + if stream_params: + try: + await self.update_params(stream_params) + except Exception: + logger.exception("Failed to process stream start parameters") + return + + try: + if ( + self.pipeline.state != PipelineState.STREAMING + and self.pipeline.state_manager.can_stream() + ): + await self.pipeline.start_streaming() + + if self.pipeline.produces_text_output(): + self._setup_text_monitoring() + else: + await self._stop_text_forwarder() + except Exception: + logger.exception("Failed during stream start", exc_info=True) + async def load_model(self, **kwargs): """Load model and initialize the pipeline.""" params = {**self._load_params, **kwargs} - + if self.pipeline is None: self.pipeline = Pipeline( - width=int(params.get('width', 512)), - height=int(params.get('height', 512)), - cwd=params.get('workspace', os.getcwd()), - disable_cuda_malloc=params.get('disable_cuda_malloc', True), - gpu_only=params.get('gpu_only', True), - preview_method=params.get('preview_method', 'none'), - comfyui_inference_log_level=params.get('comfyui_inference_log_level'), + width=int(params.get("width", 512)), + height=int(params.get("height", 512)), + 
cwd=params.get("workspace", os.getcwd()), + disable_cuda_malloc=params.get("disable_cuda_malloc", True), + gpu_only=params.get("gpu_only", True), + preview_method=params.get("preview_method", "none"), + comfyui_inference_log_level=params.get("comfyui_inference_log_level", "INFO"), + logging_level=params.get("comfyui_inference_log_level", "INFO"), + blacklist_custom_nodes=["ComfyUI-Manager"], ) + await self.pipeline.initialize() async def warmup(self): - """Public warmup method that triggers pipeline warmup.""" + """Warm up the pipeline.""" if not self.pipeline: logger.warning("Warmup requested before pipeline initialization") return - + logger.info("Running pipeline warmup...") - """Run pipeline warmup.""" try: capabilities = self.pipeline.get_workflow_io_capabilities() - logger.info(f"Detected I/O capabilities for warmup: {capabilities}") - - # Warm video if there are video inputs or outputs - if capabilities.get("video", {}).get("input") or capabilities.get("video", {}).get("output"): - logger.info("Running video warmup...") - await self.pipeline.warm_video() - logger.info("Video warmup completed") - - # Warm audio if there are audio inputs or outputs - if capabilities.get("audio", {}).get("input") or capabilities.get("audio", {}).get("output"): - logger.info("Running audio warmup...") - await self.pipeline.warm_audio() - logger.info("Audio warmup completed") - + logger.info(f"Detected I/O capabilities: {capabilities}") + + await self.pipeline.warmup() + except Exception as e: logger.error(f"Warmup failed: {e}") @@ -195,23 +303,43 @@ def _schedule_warmup(self) -> None: except Exception: logger.warning("Failed to schedule warmup", exc_info=True) - async def process_video_async(self, frame: VideoFrame) -> VideoFrame: - """Process video frame through ComfyStream Pipeline.""" + async def process_video_async( + self, frame: VideoFrame + ) -> Union[VideoFrame, VideoProcessingResult]: + """Process video frame through ComfyStream Pipeline. 
+ + Returns VideoProcessingResult.WITHHELD to trigger pytrickle's automatic overlay when + processed frames are not yet available. + """ try: - + if not self.pipeline: + return frame + + # If pipeline ingestion is paused, withhold frame so pytrickle renders the overlay + if not self.pipeline.is_ingest_enabled(): + return VideoProcessingResult.WITHHELD + # Convert pytrickle VideoFrame to av.VideoFrame av_frame = frame.to_av_frame(frame.tensor) av_frame.pts = frame.timestamp av_frame.time_base = frame.time_base - + # Process through pipeline await self.pipeline.put_video_frame(av_frame) - processed_av_frame = await self.pipeline.get_processed_video_frame() - - # Convert back to pytrickle VideoFrame - processed_frame = VideoFrame.from_av_frame_with_timing(processed_av_frame, frame) - return processed_frame - + + # Try to get processed frame with short timeout + try: + processed_av_frame = await asyncio.wait_for( + self.pipeline.get_processed_video_frame(), + timeout=self._stream_processor.overlay_config.auto_timeout_seconds, + ) + processed_frame = VideoFrame.from_av_frame_with_timing(processed_av_frame, frame) + return processed_frame + + except asyncio.TimeoutError: + # No frame ready yet - return withheld sentinel to trigger overlay + return VideoProcessingResult.WITHHELD + except Exception as e: logger.error(f"Video processing failed: {e}") return frame @@ -221,14 +349,20 @@ async def process_audio_async(self, frame: AudioFrame) -> List[AudioFrame]: try: if not self.pipeline: return [frame] - - # Audio processing needed - use pipeline + + # If pipeline ingestion is paused, passthrough audio + if not self.pipeline.is_ingest_enabled(): + frame.side_data.skipped = True + frame.side_data.passthrough = True + return [frame] + + # Audio processing - use pipeline av_frame = frame.to_av_frame() await self.pipeline.put_audio_frame(av_frame) processed_av_frame = await self.pipeline.get_processed_audio_frame() processed_frame = AudioFrame.from_av_audio(processed_av_frame) 
return [processed_frame] - + except Exception as e: logger.error(f"Audio processing failed: {e}") return [frame] @@ -237,44 +371,59 @@ async def update_params(self, params: dict): """Update processing parameters.""" if not self.pipeline: return - - # Handle list input - take first element - if isinstance(params, list) and params: - params = params[0] - + + params_payload: Dict[str, Any] = {} + if isinstance(params, list): + params = params[0] if params else {} + + if isinstance(params, dict): + params_payload = dict(params) + elif params is None: + params_payload = {} + else: + logger.warning("Unsupported params type for update_params: %s", type(params)) + return + + if not params_payload: + return + # Validate parameters using the centralized validation - validated = ComfyStreamParamsUpdateRequest(**params).model_dump() + validated = ComfyStreamParamsUpdateRequest(**params_payload).model_dump() logger.info(f"Parameter validation successful, keys: {list(validated.keys())}") - + # Process prompts if provided if "prompts" in validated and validated["prompts"]: - await self._process_prompts(validated["prompts"]) - + await self._process_prompts(validated["prompts"], skip_warmup=True) + # Update pipeline dimensions if "width" in validated: self.pipeline.width = int(validated["width"]) if "height" in validated: self.pipeline.height = int(validated["height"]) - - # Schedule warmup if requested - if validated.get("warmup", False): - self._schedule_warmup() - - async def _process_prompts(self, prompts): + async def _process_prompts(self, prompts, *, skip_warmup: bool = False): """Process and set prompts in the pipeline.""" + if not self.pipeline: + logger.warning("Prompt update requested before pipeline initialization") + return try: converted = convert_prompt(prompts, return_dict=True) - - # Set prompts in pipeline - await self.pipeline.set_prompts([converted]) - logger.info(f"Prompts set successfully: {list(prompts.keys())}") - - # Update text monitoring based on workflow 
capabilities + + await self.pipeline.apply_prompts( + [converted], + skip_warmup=skip_warmup, + ) + + if self.pipeline.state_manager.can_stream(): + await self.pipeline.start_streaming() + + logger.info(f"Prompts applied successfully: {list(prompts.keys())}") + if self.pipeline.produces_text_output(): self._setup_text_monitoring() else: await self._stop_text_forwarder() - + except Exception as e: logger.error(f"Failed to process prompts: {e}") + raise diff --git a/server/http_streaming/__init__.py b/server/http_streaming/__init__.py deleted file mode 100644 index 4fad17f79..000000000 --- a/server/http_streaming/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -""" -HTTP Streaming module for ComfyStream. - -This module contains components for token management and HTTP streaming routes. -""" diff --git a/server/http_streaming/routes.py b/server/http_streaming/routes.py deleted file mode 100644 index ac309bae5..000000000 --- a/server/http_streaming/routes.py +++ /dev/null @@ -1,69 +0,0 @@ -""" -HTTP streaming routes for ComfyStream. - -This module contains the routes for HTTP streaming. 
-""" -import asyncio -import logging -from aiohttp import web -from frame_buffer import FrameBuffer -from .tokens import cleanup_expired_sessions, validate_token, create_stream_token - -logger = logging.getLogger(__name__) - -async def stream_mjpeg(request): - """Serve an MJPEG stream with token validation""" - # Clean up expired sessions - cleanup_expired_sessions() - - stream_id = request.query.get("token") - - # Validate the stream token - is_valid, error_message = validate_token(stream_id) - if not is_valid: - return web.Response(status=403, text=error_message) - - frame_buffer = FrameBuffer.get_instance() - - # Use a fixed frame delay for 30 FPS - frame_delay = 1.0 / 30 - - response = web.StreamResponse( - status=200, - reason='OK', - headers={ - 'Content-Type': 'multipart/x-mixed-replace; boundary=frame', - 'Cache-Control': 'no-cache', - 'Connection': 'close', - } - ) - await response.prepare(request) - - try: - while True: - jpeg_frame = frame_buffer.get_current_frame() - if jpeg_frame is not None: - await response.write( - b'--frame\r\n' - b'Content-Type: image/jpeg\r\n\r\n' + jpeg_frame + b'\r\n' - ) - await asyncio.sleep(frame_delay) - except (ConnectionResetError, asyncio.CancelledError): - logger.info("MJPEG stream connection closed") - except Exception as e: - logger.error(f"Error in MJPEG stream: {e}") - finally: - return response - -def setup_routes(app, cors): - """Setup HTTP streaming routes - - Args: - app: The aiohttp web application - cors: The CORS setup object - """ - # Stream token endpoints - cors.add(app.router.add_post("/api/stream-token", create_stream_token)) - - # Stream endpoint with token validation - cors.add(app.router.add_get("/api/stream", stream_mjpeg)) diff --git a/server/http_streaming/tokens.py b/server/http_streaming/tokens.py deleted file mode 100644 index d424cf36d..000000000 --- a/server/http_streaming/tokens.py +++ /dev/null @@ -1,86 +0,0 @@ -""" -Token management system for ComfyStream HTTP streaming. 
- -This module handles the creation, validation, and management of stream tokens. -""" -import time -import secrets -import logging -from aiohttp import web - -logger = logging.getLogger(__name__) - -# Constants -SESSION_CLEANUP_INTERVAL = 60 # Clean up expired sessions every 60 seconds - -# Global token storage -active_stream_sessions = {} -last_cleanup_time = 0 - -def cleanup_expired_sessions(): - """Clean up expired stream sessions""" - global active_stream_sessions, last_cleanup_time - - current_time = time.time() - - # Only clean up if it's been at least SESSION_CLEANUP_INTERVAL since last cleanup - if current_time - last_cleanup_time < SESSION_CLEANUP_INTERVAL: - return - - # Update the last cleanup time - last_cleanup_time = current_time - - # Find expired sessions - expired_sessions = [sid for sid, expires in active_stream_sessions.items() if current_time > expires] - - # Remove expired sessions - for sid in expired_sessions: - logger.info(f"Removing expired session: {sid[:8]}...") - del active_stream_sessions[sid] - - if expired_sessions: - logger.info(f"Cleaned up {len(expired_sessions)} expired sessions. {len(active_stream_sessions)} active sessions remaining.") - -async def create_stream_token(request): - """Create a unique stream token for secure access to the stream""" - global active_stream_sessions - - # Clean up expired sessions - cleanup_expired_sessions() - - current_time = time.time() - - # Generate a new unique token - stream_id = secrets.token_urlsafe(32) - expires_at = current_time + 3600 # 1 hour from now - - # Store the new session - active_stream_sessions[stream_id] = expires_at - - logger.info(f"Generated new stream token: {stream_id[:8]}... 
({len(active_stream_sessions)} active sessions)") - - return web.json_response({ - "stream_id": stream_id, - "expires_at": int(expires_at) - }) - -def validate_token(token): - """Validate a stream token and return whether it's valid - - Args: - token: The token to validate - - Returns: - tuple: (is_valid, error_message) - """ - if not token or token not in active_stream_sessions: - return False, "Invalid stream token" - - # Check if token is expired - current_time = time.time() - if current_time > active_stream_sessions[token]: - # Remove expired token - del active_stream_sessions[token] - return False, "Stream token expired" - - return True, None diff --git a/server/public/stream.html b/server/public/stream.html deleted file mode 100644 index 536781f97..000000000 --- a/server/public/stream.html +++ /dev/null @@ -1,60 +0,0 @@ - - - - - - ComfyStream - OBS Capture - - - - - Video Stream - - diff --git a/setup.py b/setup.py index fc1f76c84..606849326 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,3 @@ from setuptools import setup -setup() \ No newline at end of file +setup() diff --git a/src/comfystream/__init__.py b/src/comfystream/__init__.py index b58bf2e44..398a9273a 100644 --- a/src/comfystream/__init__.py +++ b/src/comfystream/__init__.py @@ -1,14 +1,19 @@ from .client import ComfyStreamClient +from .exceptions import ComfyStreamAudioBufferError, ComfyStreamInputTimeoutError from .pipeline import Pipeline -from .server.utils import temporary_log_level -from .server.utils import FPSMeter +from .pipeline_state import PipelineState, PipelineStateManager from .server.metrics import MetricsManager, StreamStatsManager +from .server.utils import FPSMeter, temporary_log_level __all__ = [ - 'ComfyStreamClient', - 'Pipeline', - 'temporary_log_level', - 'FPSMeter', - 'MetricsManager', - 'StreamStatsManager' + "ComfyStreamClient", + "Pipeline", + "PipelineState", + "PipelineStateManager", + "temporary_log_level", + "FPSMeter", + "MetricsManager", + "StreamStatsManager", + 
"ComfyStreamInputTimeoutError", + "ComfyStreamAudioBufferError", ] diff --git a/src/comfystream/client.py b/src/comfystream/client.py index b5c7dca7f..de5432448 100644 --- a/src/comfystream/client.py +++ b/src/comfystream/client.py @@ -1,14 +1,16 @@ import asyncio -from typing import List +import contextlib import logging - -from comfystream import tensor_cache -from comfystream.utils import convert_prompt +from typing import List from comfy.api.components.schema.prompt import PromptDictInput from comfy.cli_args_types import Configuration from comfy.client.embedded_comfy_client import EmbeddedComfyClient +from comfystream import tensor_cache +from comfystream.exceptions import ComfyStreamInputTimeoutError +from comfystream.utils import convert_prompt, get_default_workflow + logger = logging.getLogger(__name__) @@ -16,34 +18,37 @@ class ComfyStreamClient: def __init__(self, max_workers: int = 1, **kwargs): config = Configuration(**kwargs) self.comfy_client = EmbeddedComfyClient(config, max_workers=max_workers) - self.running_prompts = {} # To be used for cancelling tasks self.current_prompts = [] self._cleanup_lock = asyncio.Lock() self._prompt_update_lock = asyncio.Lock() - self._stop_event = asyncio.Event() + + # PromptRunner state + self._shutdown_event = asyncio.Event() + self._run_enabled_event = asyncio.Event() + self._runner_task = None async def set_prompts(self, prompts: List[PromptDictInput]): """Set new prompts, replacing any existing ones. 
- + Args: prompts: List of prompt dictionaries to set - + Raises: ValueError: If prompts list is empty Exception: If prompt conversion or validation fails """ if not prompts: raise ValueError("Cannot set empty prompts list") - - # Cancel existing prompts first to avoid conflicts - await self.cancel_running_prompts() - # Reset stop event for new prompts - self._stop_event.clear() + + # Pause runner while swapping prompts to avoid interleaving + was_running = self._run_enabled_event.is_set() + self._run_enabled_event.clear() self.current_prompts = [convert_prompt(prompt) for prompt in prompts] - logger.info(f"Queuing {len(self.current_prompts)} prompt(s) for execution") - for idx in range(len(self.current_prompts)): - task = asyncio.create_task(self.run_prompt(idx)) - self.running_prompts[idx] = task + logger.info(f"Configured {len(self.current_prompts)} prompt(s)") + # Ensure runner exists (IDLE until resumed) + await self.ensure_prompt_tasks_running() + if was_running: + self._run_enabled_event.set() async def update_prompts(self, prompts: List[PromptDictInput]): async with self._prompt_update_lock: @@ -56,30 +61,63 @@ async def update_prompts(self, prompts: List[PromptDictInput]): for idx, prompt in enumerate(prompts): converted_prompt = convert_prompt(prompt) try: + # Lightweight validation by queueing is retained for compatibility await self.comfy_client.queue_prompt(converted_prompt) self.current_prompts[idx] = converted_prompt except Exception as e: raise Exception(f"Prompt update failed: {str(e)}") from e - async def run_prompt(self, prompt_index: int): - while not self._stop_event.is_set(): - async with self._prompt_update_lock: - try: - await self.comfy_client.queue_prompt(self.current_prompts[prompt_index]) - except asyncio.CancelledError: - raise - except Exception as e: - await self.cleanup() - logger.error(f"Error running prompt: {str(e)}") - raise + async def ensure_prompt_tasks_running(self): + # Ensure the single runner task exists (does not force 
running) + if self._runner_task and not self._runner_task.done(): + return + if not self.current_prompts: + return + self._shutdown_event.clear() + self._runner_task = asyncio.create_task(self._runner_loop()) + + async def _runner_loop(self): + try: + while not self._shutdown_event.is_set(): + # IDLE until running is enabled + await self._run_enabled_event.wait() + # Snapshot prompts without holding the lock during network I/O + async with self._prompt_update_lock: + prompts_snapshot = list(self.current_prompts) + for prompt_index, prompt in enumerate(prompts_snapshot): + if self._shutdown_event.is_set() or not self._run_enabled_event.is_set(): + break + try: + await self.comfy_client.queue_prompt(prompt) + except asyncio.CancelledError: + raise + except ComfyStreamInputTimeoutError: + logger.info(f"Input for prompt {prompt_index} timed out, continuing") + continue + except Exception as e: + logger.error(f"Error running prompt: {str(e)}") + logger.info("Stopping prompt execution and returning to passthrough mode") + + # Stop running and switch to default passthrough workflow + await self._fallback_to_passthrough() + break + except asyncio.CancelledError: + pass async def cleanup(self): - # Set stop event to signal prompt loops to exit - self._stop_event.set() - - await self.cancel_running_prompts() + # Signal runner to shutdown + self._shutdown_event.set() + if self._runner_task: + self._runner_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._runner_task + self._runner_task = None + + # Pause running + self._run_enabled_event.clear() + async with self._cleanup_lock: - if self.comfy_client.is_running: + if getattr(self.comfy_client, "is_running", False): try: await self.comfy_client.__aexit__() except Exception as e: @@ -88,18 +126,32 @@ async def cleanup(self): await self.cleanup_queues() logger.info("Client cleanup complete") - async def cancel_running_prompts(self): - async with self._cleanup_lock: - tasks_to_cancel = 
list(self.running_prompts.values()) - for task in tasks_to_cancel: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - self.running_prompts.clear() + def pause_prompts(self): + """Pause prompt execution loops without canceling underlying tasks.""" + self._run_enabled_event.clear() + logger.debug("Prompt execution paused") + + async def resume_prompts(self): + """Resume prompt execution loops.""" + await self.ensure_prompt_tasks_running() + self._run_enabled_event.set() + logger.debug("Prompt execution resumed") + + async def stop_prompts(self, cleanup: bool = False): + """Stop running prompts by canceling their tasks. + + Args: + cleanup: If True, perform full cleanup including queue clearing and + client shutdown. If False, only cancel prompt tasks. + """ + await self.stop_prompts_immediately() + + if cleanup: + await self.cleanup() + logger.info("Prompts stopped with full cleanup") + else: + logger.debug("Prompts stopped (tasks cancelled)") - async def cleanup_queues(self): while not tensor_cache.image_inputs.empty(): tensor_cache.image_inputs.get() @@ -116,20 +168,50 @@ async def cleanup_queues(self): while not tensor_cache.text_outputs.empty(): await tensor_cache.text_outputs.get() + async def stop_prompts_immediately(self): + """Cancel the runner task to immediately stop any in-flight prompt execution.""" + self._run_enabled_event.clear() + if self._runner_task: + self._runner_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._runner_task + self._runner_task = None + + async def _fallback_to_passthrough(self): + """Switch to default passthrough workflow when an error occurs.""" + try: + # Pause the runner + self._run_enabled_event.clear() + + # Set to default passthrough workflow + default_workflow = get_default_workflow() + async with self._prompt_update_lock: + self.current_prompts = [convert_prompt(default_workflow)] + + logger.info("Switched to default passthrough workflow") + + # Resume the runner 
with passthrough workflow + self._run_enabled_event.set() + + except Exception as e: + logger.error(f"Failed to fallback to passthrough: {str(e)}") + # If fallback fails, just pause execution + self._run_enabled_event.clear() + def put_video_input(self, frame): if tensor_cache.image_inputs.full(): tensor_cache.image_inputs.get(block=True) tensor_cache.image_inputs.put(frame) - + def put_audio_input(self, frame): tensor_cache.audio_inputs.put(frame) async def get_video_output(self): return await tensor_cache.image_outputs.get() - + async def get_audio_output(self): return await tensor_cache.audio_outputs.get() - + async def get_text_output(self): try: return tensor_cache.text_outputs.get_nowait() @@ -144,25 +226,20 @@ async def get_text_output(self): async def get_available_nodes(self): """Get metadata and available nodes info in a single pass""" # TODO: make it for for multiple prompts - if not self.running_prompts: + if not self.current_prompts: return {} try: from comfy.nodes.package import import_all_nodes_in_workspace + nodes = import_all_nodes_in_workspace() all_prompts_nodes_info = {} - + for prompt_index, prompt in enumerate(self.current_prompts): # Get set of class types we need metadata for, excluding LoadTensor and SaveTensor - needed_class_types = { - node.get('class_type') - for node in prompt.values() - } - remaining_nodes = { - node_id - for node_id, node in prompt.items() - } + needed_class_types = {node.get("class_type") for node in prompt.values()} + remaining_nodes = {node_id for node_id, node in prompt.items()} nodes_info = {} # Only process nodes until we've found all the ones we need @@ -174,87 +251,88 @@ async def get_available_nodes(self): continue # Get metadata for this node type (same as original get_node_metadata) - input_data = node_class.INPUT_TYPES() if hasattr(node_class, 'INPUT_TYPES') else {} + input_data = ( + node_class.INPUT_TYPES() if hasattr(node_class, "INPUT_TYPES") else {} + ) input_info = {} # Process required inputs - if 
'required' in input_data: - for name, value in input_data['required'].items(): + if "required" in input_data: + for name, value in input_data["required"].items(): if isinstance(value, tuple): if len(value) == 1 and isinstance(value[0], list): # Handle combo box case where value is ([option1, option2, ...],) input_info[name] = { - 'type': 'combo', - 'value': value[0], # The list of options becomes the value + "type": "combo", + "value": value[0], # The list of options becomes the value } elif len(value) == 2: input_type, config = value input_info[name] = { - 'type': input_type, - 'required': True, - 'min': config.get('min', None), - 'max': config.get('max', None), - 'widget': config.get('widget', None) + "type": input_type, + "required": True, + "min": config.get("min", None), + "max": config.get("max", None), + "widget": config.get("widget", None), } elif len(value) == 1: # Handle simple type case like ('IMAGE',) - input_info[name] = { - 'type': value[0] - } + input_info[name] = {"type": value[0]} else: - logger.error(f"Unexpected structure for required input {name}: {value}") + logger.error( + f"Unexpected structure for required input {name}: {value}" + ) # Process optional inputs with same logic - if 'optional' in input_data: - for name, value in input_data['optional'].items(): + if "optional" in input_data: + for name, value in input_data["optional"].items(): if isinstance(value, tuple): if len(value) == 1 and isinstance(value[0], list): # Handle combo box case where value is ([option1, option2, ...],) input_info[name] = { - 'type': 'combo', - 'value': value[0], # The list of options becomes the value + "type": "combo", + "value": value[0], # The list of options becomes the value } elif len(value) == 2: input_type, config = value input_info[name] = { - 'type': input_type, - 'required': False, - 'min': config.get('min', None), - 'max': config.get('max', None), - 'widget': config.get('widget', None) + "type": input_type, + "required": False, + "min": 
config.get("min", None), + "max": config.get("max", None), + "widget": config.get("widget", None), } elif len(value) == 1: # Handle simple type case like ('IMAGE',) - input_info[name] = { - 'type': value[0] - } + input_info[name] = {"type": value[0]} else: - logger.error(f"Unexpected structure for optional input {name}: {value}") + logger.error( + f"Unexpected structure for optional input {name}: {value}" + ) # Now process any nodes in our prompt that use this class_type for node_id in list(remaining_nodes): node = prompt[node_id] - if node.get('class_type') != class_type: + if node.get("class_type") != class_type: continue - node_info = { - 'class_type': class_type, - 'inputs': {} - } + node_info = {"class_type": class_type, "inputs": {}} - if 'inputs' in node: - for input_name, input_value in node['inputs'].items(): + if "inputs" in node: + for input_name, input_value in node["inputs"].items(): input_metadata = input_info.get(input_name, {}) - node_info['inputs'][input_name] = { - 'value': input_value, - 'type': input_metadata.get('type', 'unknown'), - 'min': input_metadata.get('min', None), - 'max': input_metadata.get('max', None), - 'widget': input_metadata.get('widget', None) + node_info["inputs"][input_name] = { + "value": input_value, + "type": input_metadata.get("type", "unknown"), + "min": input_metadata.get("min", None), + "max": input_metadata.get("max", None), + "widget": input_metadata.get("widget", None), } # For combo type inputs, include the list of options - if input_metadata.get('type') == 'combo': - node_info['inputs'][input_name]['value'] = input_metadata.get('value', []) + if input_metadata.get("type") == "combo": + node_info["inputs"][input_name]["value"] = input_metadata.get( + "value", [] + ) nodes_info[node_id] = node_info remaining_nodes.remove(node_id) diff --git a/src/comfystream/exceptions.py b/src/comfystream/exceptions.py new file mode 100644 index 000000000..53a97e870 --- /dev/null +++ b/src/comfystream/exceptions.py @@ -0,0 +1,123 
@@ +"""ComfyStream specific exceptions.""" + +import logging +from typing import Any, Dict, Optional + + +def log_comfystream_error( + exception: Exception, logger: Optional[logging.Logger] = None, level: int = logging.ERROR +) -> None: + """ + Centralized logging function for ComfyStream exceptions. + + Args: + exception: The exception to log + logger: Optional logger to use (defaults to module logger) + level: Log level (defaults to ERROR) + """ + if logger is None: + logger = logging.getLogger(__name__) + + # If it's a ComfyStream timeout error with structured details, use its logging method + if isinstance(exception, ComfyStreamInputTimeoutError): + exception.log_error(logger) + else: + # For other exceptions, provide basic logging + logger.log(level, f"ComfyStream error: {type(exception).__name__}: {str(exception)}") + + +class ComfyStreamInputTimeoutError(Exception): + """Raised when input tensors are not available within timeout.""" + + def __init__( + self, input_type: str, timeout_seconds: float, details: Optional[Dict[str, Any]] = None + ): + self.input_type = input_type + self.timeout_seconds = timeout_seconds + self.details = details or {} + message = f"No {input_type} frames available after {timeout_seconds}s timeout" + super().__init__(message) + + def get_log_details(self) -> Dict[str, Any]: + """Get structured details for logging.""" + base_details = {"input_type": self.input_type, "timeout_seconds": self.timeout_seconds} + base_details.update(self.details) + return base_details + + def log_error(self, logger: Optional[logging.Logger] = None) -> None: + """Log the error with detailed information.""" + if logger is None: + logger = logging.getLogger(__name__) + + details = self.get_log_details() + detail_str = ", ".join(f"{k}={v}" for k, v in details.items()) + logger.error(f"ComfyStream timeout error: {str(self)} | Details: {detail_str}") + + +class ComfyStreamAudioBufferError(ComfyStreamInputTimeoutError): + """Audio buffer insufficient data 
error.""" + + def __init__(self, timeout_seconds: float, needed_samples: int, available_samples: int): + self.needed_samples = needed_samples + self.available_samples = available_samples + + # Pass audio-specific details to the base class + audio_details = { + "needed_samples": needed_samples, + "available_samples": available_samples, + } + super().__init__("audio", timeout_seconds, details=audio_details) + + def get_log_details(self) -> Dict[str, Any]: + """Get structured details for logging, with audio-specific formatting.""" + details = super().get_log_details() + return details + + +class ComfyStreamTimeoutFilter(logging.Filter): + """Filter to suppress verbose ComfyUI execution logs for ComfyStream timeout exceptions.""" + + def filter(self, record): + """Filter out ComfyUI execution error logs for ComfyStream timeout exceptions.""" + try: + # Only filter ERROR level messages from ComfyUI execution system + if record.levelno != logging.ERROR: + return True + + # Determine if this record is from a logger we want to inspect for timeout suppression + is_comfy_execution_logger = record.name.startswith("comfy") and ( + "execution" in record.name or record.name == "comfy" + ) + is_comfystream_logger = record.name.startswith("comfystream") + + if not (is_comfy_execution_logger or is_comfystream_logger): + return True + + # Get the full message including any exception info + message = record.getMessage() + + # Simple check: if this log contains ComfyStreamAudioBufferError or ComfyStreamInputTimeoutError, suppress it + if ( + "ComfyStreamAudioBufferError" in message + or "ComfyStreamInputTimeoutError" in message + ): + return False + + # Also check the exception info if present + if record.exc_info and record.exc_info[1]: + exc_str = str(record.exc_info[1]) + exc_type = str(type(record.exc_info[1])) + + if ( + "ComfyStreamAudioBufferError" in exc_str + or "ComfyStreamInputTimeoutError" in exc_str + or "ComfyStreamAudioBufferError" in exc_type + or 
"ComfyStreamInputTimeoutError" in exc_type + ): + return False + + return True + except Exception as e: + # If filter fails, allow the log through and print the error + print(f"[FILTER ERROR] Filter failed: {e}") + return True diff --git a/src/comfystream/modalities.py b/src/comfystream/modalities.py index fccfabad5..ded168296 100644 --- a/src/comfystream/modalities.py +++ b/src/comfystream/modalities.py @@ -1,27 +1,29 @@ -from typing import Dict, Any, Set, Union, List, TypedDict +from typing import Any, Dict, List, Set, TypedDict, Union class ModalityIO(TypedDict): """Input/output capabilities for a single modality.""" + input: bool output: bool + class WorkflowModality(TypedDict): """Workflow modality detection result mapping modalities to their I/O capabilities.""" + video: ModalityIO audio: ModalityIO text: ModalityIO + # Centralized node type definitions NODE_TYPES = { # Video nodes "video_input": {"LoadTensor", "PrimaryInputLoadImage", "LoadImage"}, "video_output": {"SaveTensor", "PreviewImage", "SaveImage"}, - # Audio nodes "audio_input": {"LoadAudioTensor"}, "audio_output": {"SaveAudioTensor"}, - # Text nodes "text_input": set(), # No text input nodes currently "text_output": {"SaveTextTensor"}, @@ -29,7 +31,9 @@ class WorkflowModality(TypedDict): # Flatten all input and output node types for easier checking all_input_nodes = NODE_TYPES["video_input"] | NODE_TYPES["audio_input"] | NODE_TYPES["text_input"] -all_output_nodes = NODE_TYPES["video_output"] | NODE_TYPES["audio_output"] | NODE_TYPES["text_output"] +all_output_nodes = ( + NODE_TYPES["video_output"] | NODE_TYPES["audio_output"] | NODE_TYPES["text_output"] +) # Modality mappings derived from NODE_TYPES MODALITY_MAPPINGS = { @@ -55,41 +59,45 @@ class WorkflowModality(TypedDict): "SaveImage": "output_replacement", } + def get_node_counts_by_type(prompt: Dict[Any, Any]) -> Dict[str, int]: """Count nodes by their functional types (primary inputs, inputs, outputs).""" counts = {"primary_inputs": 0, 
"inputs": 0, "outputs": 0} - + for node in prompt.values(): class_type = node.get("class_type") - + if class_type == "PrimaryInputLoadImage": counts["primary_inputs"] += 1 elif class_type in all_input_nodes: counts["inputs"] += 1 elif class_type in all_output_nodes: counts["outputs"] += 1 - + return counts + def get_convertible_node_keys(prompt: Dict[Any, Any]) -> Dict[str, List[str]]: """Collect keys of nodes that need conversion, organized by node type.""" keys = {node_type: [] for node_type in CONVERTIBLE_NODES.keys()} - + for key, node in prompt.items(): class_type = node.get("class_type") if class_type in keys: keys[class_type].append(key) - + return keys + def create_empty_workflow_modality() -> WorkflowModality: """Create an empty WorkflowModality with all capabilities set to False.""" return { "video": {"input": False, "output": False}, "audio": {"input": False, "output": False}, - "text": {"input": False, "output": False}, + "text": {"input": False, "output": False}, } + def _merge_workflow_modalities(base: WorkflowModality, other: WorkflowModality) -> WorkflowModality: """Merge two WorkflowModality objects using logical OR for all capabilities.""" for modality in base: @@ -97,6 +105,7 @@ def _merge_workflow_modalities(base: WorkflowModality, other: WorkflowModality) base[modality][direction] = base[modality][direction] or other[modality][direction] return base + def detect_io_points(prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]) -> WorkflowModality: """Detect input/output presence per modality for a workflow. 
@@ -117,7 +126,7 @@ def detect_io_points(prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]) -> Wo # Scan nodes and detect modality I/O points using centralized mappings for node in prompts.values(): class_type = node.get("class_type", "") - + for modality, directions in MODALITY_MAPPINGS.items(): if class_type in directions["input"]: result[modality]["input"] = True @@ -126,17 +135,18 @@ def detect_io_points(prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]) -> Wo return result + def detect_prompt_modalities(prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]) -> Set[str]: """Detect which modalities are used by a workflow. - + Returns a set of modality names that have either input or output nodes. This is used by the pipeline to determine which modalities need processing. """ io_points = detect_io_points(prompts) modalities = set() - + for modality, capabilities in io_points.items(): if capabilities["input"] or capabilities["output"]: modalities.add(modality) - + return modalities diff --git a/src/comfystream/pipeline.py b/src/comfystream/pipeline.py index 3e4febaf1..cf2302de7 100644 --- a/src/comfystream/pipeline.py +++ b/src/comfystream/pipeline.py @@ -1,42 +1,63 @@ -import av -import torch -import numpy as np import asyncio import logging -from typing import Any, Dict, Union, List, Optional, Set +from typing import Any, Dict, List, Optional, Set, Union + +import av +import numpy as np +import torch from comfystream.client import ComfyStreamClient +from comfystream.pipeline_state import PipelineState, PipelineStateManager from comfystream.server.utils import temporary_log_level -from .modalities import detect_prompt_modalities, detect_io_points, WorkflowModality -from .modalities import create_empty_workflow_modality - +from comfystream.utils import get_default_workflow + +from .modalities import ( + WorkflowModality, + create_empty_workflow_modality, + detect_io_points, + detect_prompt_modalities, +) + WARMUP_RUNS = 5 +BOOTSTRAP_TIMEOUT_SECONDS = 30.0 
logger = logging.getLogger(__name__) class Pipeline: """A pipeline for processing video and audio frames using ComfyUI. - + This class provides a high-level interface for processing video and audio frames through a ComfyUI-based processing pipeline. It handles frame preprocessing, postprocessing, and queue management. """ - - def __init__(self, width: int = 512, height: int = 512, - comfyui_inference_log_level: Optional[int] = None, **kwargs): + + def __init__( + self, + width: int = 512, + height: int = 512, + comfyui_inference_log_level: Optional[int] = None, + auto_warmup: bool = False, + bootstrap_default_prompt: bool = True, + **kwargs, + ): """Initialize the pipeline with the given configuration. - + Args: width: Width of the video frames (default: 512) height: Height of the video frames (default: 512) comfyui_inference_log_level: The logging level for ComfyUI inference. Defaults to None, using the global ComfyUI log level. + auto_warmup: Whether to run warmup automatically after prompts are set. + bootstrap_default_prompt: Whether to run the default workflow once during + initialization to start ComfyUI before prompts are applied. 
**kwargs: Additional arguments to pass to the ComfyStreamClient """ self.client = ComfyStreamClient(**kwargs) self.width = width self.height = height + self.auto_warmup = auto_warmup + self.bootstrap_default_prompt = bootstrap_default_prompt self.video_incoming_frames = asyncio.Queue() self.audio_incoming_frames = asyncio.Queue() @@ -46,6 +67,143 @@ def __init__(self, width: int = 512, height: int = 512, self._comfyui_inference_log_level = comfyui_inference_log_level self._cached_modalities: Optional[Set[str]] = None self._cached_io_capabilities: Optional[WorkflowModality] = None + self.state_manager = PipelineStateManager(self.client) + self._bootstrap_completed = False + self._initialize_lock = asyncio.Lock() + self._ingest_enabled = True + self._prompt_update_lock = asyncio.Lock() + + @property + def state(self) -> PipelineState: + """Expose current pipeline state.""" + return self.state_manager.state + + async def initialize(self): + """Run optional bootstrap workflow to start ComfyUI before prompts are set.""" + if self._bootstrap_completed or not self.bootstrap_default_prompt: + return + + async with self._initialize_lock: + if self._bootstrap_completed or not self.bootstrap_default_prompt: + return + + logger.info("Bootstrapping ComfyUI with default workflow") + await self._run_bootstrap_prompt() + self._bootstrap_completed = True + + async def _run_bootstrap_prompt(self): + """Run the default workflow once with a dummy frame to start ComfyUI.""" + default_workflow = get_default_workflow() + logger.debug("Running default workflow bootstrap prompt") + + try: + await self.client.set_prompts([default_workflow]) + await self.client.resume_prompts() + + dummy_frame = av.VideoFrame() + dummy_frame.side_data.input = torch.randn(1, self.height, self.width, 3) + self.client.put_video_input(dummy_frame) + + await asyncio.wait_for( + self.client.get_video_output(), + timeout=BOOTSTRAP_TIMEOUT_SECONDS, + ) + logger.info("Bootstrap prompt completed successfully") + 
except asyncio.TimeoutError as exc: + logger.error("Timeout while waiting for bootstrap prompt output") + raise RuntimeError("Bootstrap prompt timed out while waiting for output") from exc + finally: + try: + await self.client.stop_prompts(cleanup=False) + except Exception: + logger.debug("Failed to stop bootstrap prompts cleanly", exc_info=True) + + self.client.current_prompts = [] + self._cached_modalities = None + self._cached_io_capabilities = None + + try: + await self.client.cleanup_queues() + except Exception: + logger.debug("Failed to clear tensor caches after bootstrap prompt", exc_info=True) + + async def warmup( + self, + *, + warm_video: Optional[bool] = None, + warm_audio: Optional[bool] = None, + ): + """Run warmup for selected modalities while managing pipeline state.""" + if not self.state_manager.is_initialized(): + raise RuntimeError("Cannot warm up pipeline before prompts are initialized") + + state_before = self.state + transitioned = False + warmup_successful = False + + try: + if state_before != PipelineState.STREAMING: + await self.state_manager.transition_to(PipelineState.INITIALIZING) + transitioned = True + + await self._run_warmup( + warm_video=warm_video, + warm_audio=warm_audio, + ) + warmup_successful = True + except Exception: + await self.state_manager.transition_to(PipelineState.ERROR) + raise + finally: + if transitioned and warmup_successful: + try: + await self.state_manager.transition_to(PipelineState.READY) + except Exception: + logger.exception("Failed to transition pipeline to READY after warmup") + warmup_successful = False + + if warmup_successful and state_before == PipelineState.STREAMING: + try: + await self.state_manager.transition_to(PipelineState.STREAMING) + except Exception: + logger.exception("Failed to restore STREAMING state after warmup") + + async def _run_warmup( + self, + *, + warm_video: Optional[bool] = None, + warm_audio: Optional[bool] = None, + ): + """Run warmup routines for video and audio as 
requested.""" + capabilities = self.get_workflow_io_capabilities() + + video_config = capabilities.get("video", {}) + audio_config = capabilities.get("audio", {}) + + should_warm_video = ( + warm_video + if warm_video is not None + else bool(video_config.get("input") or video_config.get("output")) + ) + should_warm_audio = ( + warm_audio + if warm_audio is not None + else bool(audio_config.get("input") or audio_config.get("output")) + ) + + if should_warm_video: + logger.debug("Running video warmup routine") + await self.warm_video() + + if should_warm_audio: + logger.debug("Running audio warmup routine") + await self.warm_audio() + + logger.info( + "Pipeline warmup completed (video=%s, audio=%s)", + should_warm_video, + should_warm_audio, + ) async def warm_video(self): """Warm up the video processing pipeline with dummy frames.""" @@ -53,16 +211,16 @@ async def warm_video(self): if not self.accepts_video_input(): logger.debug("Skipping video warmup - workflow doesn't accept video input") return - + # Create dummy frame with the CURRENT resolution settings dummy_frame = av.VideoFrame() dummy_frame.side_data.input = torch.randn(1, self.height, self.width, 3) - + logger.debug(f"Warming video pipeline with resolution {self.width}x{self.height}") for _ in range(WARMUP_RUNS): self.client.put_video_input(dummy_frame) - + # Wait on the outputs that the workflow actually produces if self.produces_video_output(): await self.client.get_video_output() @@ -77,14 +235,16 @@ async def warm_audio(self): if not self.accepts_audio_input(): logger.debug("Skipping audio warmup - workflow doesn't accept audio input") return - + dummy_frame = av.AudioFrame() - dummy_frame.side_data.input = np.random.randint(-32768, 32768, int(48000 * 0.5), dtype=np.int16) + dummy_frame.side_data.input = np.random.randint( + -32768, 32768, int(48000 * 0.5), dtype=np.int16 + ) dummy_frame.sample_rate = 48000 for _ in range(WARMUP_RUNS): self.client.put_audio_input(dummy_frame) - + # Wait on the outputs 
that the workflow actually produces if self.produces_video_output(): await self.client.get_video_output() @@ -93,39 +253,245 @@ async def warm_audio(self): if self.produces_text_output(): await self.client.get_text_output() - async def set_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]): + async def set_prompts( + self, + prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]], + *, + skip_warmup: bool = False, + ): """Set the processing prompts for the pipeline. - + Args: prompts: Either a single prompt dictionary or a list of prompt dictionaries + skip_warmup: Skip automatic warmup even if auto_warmup is enabled """ - if isinstance(prompts, list): - await self.client.set_prompts(prompts) - else: - await self.client.set_prompts([prompts]) - - # Clear cached modalities and I/O capabilities when prompts change - self._cached_modalities = None - self._cached_io_capabilities = None + try: + prompt_list = prompts if isinstance(prompts, list) else [prompts] + await self.client.set_prompts(prompt_list) - async def update_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]): + # Refresh cached modalities and I/O capabilities from the new prompts + self._cached_modalities = detect_prompt_modalities(self.client.current_prompts) + self._cached_io_capabilities = detect_io_points(self.client.current_prompts) + + should_warmup = self.auto_warmup and not skip_warmup + if should_warmup: + await self.state_manager.transition_to(PipelineState.INITIALIZING) + try: + await self._run_warmup() + except Exception: + await self.state_manager.transition_to(PipelineState.ERROR) + raise + + await self.state_manager.transition_to(PipelineState.READY) + except Exception: + logger.exception("Failed to set pipeline prompts") + try: + await self.state_manager.transition_to(PipelineState.ERROR) + except ValueError: + logger.debug("Skipping ERROR transition due to invalid state") + except Exception: + logger.exception("Failed to transition pipeline to ERROR state") 
+ raise + + async def update_prompts( + self, + prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]], + *, + skip_warmup: bool = False, + ): """Update the existing processing prompts. - + Args: prompts: Either a single prompt dictionary or a list of prompt dictionaries + skip_warmup: Skip automatic warmup even if auto_warmup is enabled + """ + was_streaming = self.state == PipelineState.STREAMING + should_warmup = self.auto_warmup and not skip_warmup + + try: + if was_streaming and should_warmup: + await self.state_manager.transition_to(PipelineState.READY) + + prompt_list = prompts if isinstance(prompts, list) else [prompts] + await self.client.update_prompts(prompt_list) + + # Refresh cached modalities and I/O capabilities from the updated prompts + self._cached_modalities = detect_prompt_modalities(self.client.current_prompts) + self._cached_io_capabilities = detect_io_points(self.client.current_prompts) + + if should_warmup: + await self.state_manager.transition_to(PipelineState.INITIALIZING) + try: + await self._run_warmup() + except Exception: + await self.state_manager.transition_to(PipelineState.ERROR) + raise + await self.state_manager.transition_to(PipelineState.READY) + + if was_streaming and self.state != PipelineState.STREAMING: + await self.state_manager.transition_to(PipelineState.STREAMING) + except Exception: + logger.exception("Failed to update pipeline prompts") + try: + await self.state_manager.transition_to(PipelineState.ERROR) + except ValueError: + logger.debug("Skipping ERROR transition due to invalid state") + except Exception: + logger.exception("Failed to transition pipeline to ERROR state") + raise + + def disable_ingest(self) -> None: + """Temporarily disable ingestion of new frames into the pipeline.""" + self._ingest_enabled = False + + def enable_ingest(self) -> None: + """Re-enable ingestion of new frames into the pipeline.""" + self._ingest_enabled = True + + def is_ingest_enabled(self) -> bool: + """Check if the pipeline is 
currently ingesting new frames.""" + return self._ingest_enabled + + async def apply_prompts( + self, + prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]], + *, + skip_warmup: bool = False, + warm_video: Optional[bool] = None, + warm_audio: Optional[bool] = None, + ) -> WorkflowModality: + """Atomically replace prompts while coordinating runner, queues, and state. + + This helper orchestrates prompt swaps by pausing streaming, cancelling any + in-flight prompt execution, clearing input queues, applying the new prompts, + warming the pipeline (unless explicitly skipped), and finally resuming + streaming if it was active beforehand. + + Args: + prompts: Prompt dictionary or list of prompt dictionaries to apply. + skip_warmup: If True, skip automatic warmup after applying prompts. + warm_video: Optional override for video warmup (None = auto-detect). + warm_audio: Optional override for audio warmup (None = auto-detect). + + Returns: + WorkflowModality describing I/O capabilities detected from the new prompts. 
+ """ + prompt_list = prompts if isinstance(prompts, list) else [prompts] + + async with self._prompt_update_lock: + was_streaming = self.state == PipelineState.STREAMING + was_initialized = self.state_manager.is_initialized() + restart_streaming = False + capabilities: WorkflowModality | None = None + self.disable_ingest() + + try: + if was_streaming: + await self.pause_prompts() + + if was_initialized: + await self.stop_prompts_immediately() + + await self._clear_pipeline_queues() + await self.client.cleanup_queues() + + await self.set_prompts(prompt_list, skip_warmup=True) + + capabilities = self.get_workflow_io_capabilities() + video_capability = capabilities.get("video", {}) + audio_capability = capabilities.get("audio", {}) + + has_video_io = bool(video_capability.get("input") or video_capability.get("output")) + has_audio_io = bool(audio_capability.get("input") or audio_capability.get("output")) + + if not skip_warmup: + await self.warmup( + warm_video=warm_video if warm_video is not None else has_video_io, + warm_audio=warm_audio if warm_audio is not None else has_audio_io, + ) + + restart_streaming = was_streaming and self.state_manager.can_stream() + + except Exception: + raise + finally: + self.enable_ingest() + if restart_streaming and self.state_manager.can_stream(): + await self.start_streaming() + + return capabilities if capabilities is not None else self.get_workflow_io_capabilities() + + async def start_streaming(self): + """Enable prompt execution for active streaming.""" + if not self.state_manager.can_stream(): + raise RuntimeError(f"Cannot start streaming in state: {self.state.name}") + + await self.state_manager.transition_to(PipelineState.STREAMING) + + async def stop_streaming(self): + """Pause prompt execution while keeping prompts loaded.""" + if self.state == PipelineState.STREAMING: + await self.state_manager.transition_to(PipelineState.READY) + + async def pause_prompts(self): + """Pause prompt execution loops without canceling 
tasks.""" + await self.stop_streaming() + + async def resume_prompts(self): + """Resume paused prompt execution loops.""" + await self.start_streaming() + + def are_prompts_running(self) -> bool: + """Check if prompts are currently running. + + Returns: + True if prompts are enabled and running, False otherwise """ - if isinstance(prompts, list): - await self.client.update_prompts(prompts) + return self.state == PipelineState.STREAMING + + async def stop_prompts(self, cleanup: bool = False): + """Stop running prompts by canceling their tasks. + + Args: + cleanup: If True, perform full cleanup including queue clearing and + client shutdown. If False, only cancel prompt tasks. + """ + if self.state in {PipelineState.STREAMING, PipelineState.INITIALIZING}: + await self.state_manager.transition_to(PipelineState.READY) + + await self.client.stop_prompts(cleanup=cleanup) + + # Clear cached modalities and I/O capabilities when prompts are stopped + if cleanup: + self._cached_modalities = None + self._cached_io_capabilities = None + # Clear pipeline queues for full cleanup + await self._clear_pipeline_queues() + try: + await self.state_manager.transition_to(PipelineState.UNINITIALIZED) + except Exception: + logger.exception("Failed to transition pipeline to UNINITIALIZED during cleanup") else: - await self.client.update_prompts([prompts]) - - # Clear cached modalities and I/O capabilities when prompts change - self._cached_modalities = None - self._cached_io_capabilities = None + try: + await self.state_manager.transition_to(PipelineState.READY) + except ValueError: + logger.debug("Skipping READY transition due to invalid state") + except Exception: + logger.exception("Failed to ensure READY state after stopping prompts") + + async def stop_prompts_immediately(self): + """Cancel prompt execution tasks without full cleanup.""" + await self.client.stop_prompts_immediately() + try: + await self.state_manager.transition_to(PipelineState.READY) + except ValueError: + 
logger.debug("Skipping READY transition during immediate stop") + except Exception: + logger.exception("Failed to ensure READY state during immediate stop") async def put_video_frame(self, frame: av.VideoFrame): """Queue a video frame for processing. - + Args: frame: The video frame to process """ @@ -146,7 +512,7 @@ async def put_video_frame(self, frame: av.VideoFrame): async def put_audio_frame(self, frame: av.AudioFrame, preprocess: bool = True): """Queue an audio frame for processing. - + Args: frame: The audio frame to process """ @@ -168,33 +534,37 @@ async def put_audio_frame(self, frame: av.AudioFrame, preprocess: bool = True): def video_preprocess(self, frame: av.VideoFrame) -> torch.Tensor: """Preprocess a video frame before processing. - + Args: frame: The video frame to preprocess - + Returns: The preprocessed frame as a tensor or numpy array """ frame_np = frame.to_ndarray(format="rgb24").astype(np.float32) / 255.0 return torch.from_numpy(frame_np).unsqueeze(0) - + def audio_preprocess(self, frame: av.AudioFrame) -> np.ndarray: """Preprocess an audio frame before processing. 
- + Args: frame: The audio frame to preprocess - + Returns: The preprocessed frame as a numpy array with int16 dtype """ audio_data = frame.to_ndarray() - + # Handle multi-dimensional audio data - if audio_data.ndim == 2 and audio_data.shape[0] == 1 and audio_data.shape[0] <= audio_data.shape[1]: + if ( + audio_data.ndim == 2 + and audio_data.shape[0] == 1 + and audio_data.shape[0] <= audio_data.shape[1] + ): audio_data = audio_data.ravel().reshape(-1, 2).mean(axis=1) elif audio_data.ndim > 1: audio_data = audio_data.mean(axis=0) - + # Ensure we always return int16 data if audio_data.dtype in [np.float32, np.float64]: # Check if data is normalized (-1.0 to 1.0 range) @@ -209,15 +579,15 @@ def audio_preprocess(self, frame: av.AudioFrame) -> np.ndarray: else: # Already integer data - ensure it's int16 audio_data = audio_data.astype(np.int16) - + return audio_data - + def video_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.VideoFrame: """Postprocess a video frame after processing. - + Args: output: The processed output tensor or numpy array - + Returns: The postprocessed video frame """ @@ -227,32 +597,32 @@ def video_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.Video def audio_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.AudioFrame: """Postprocess an audio frame after processing. - + Args: output: The processed output tensor or numpy array - + Returns: The postprocessed audio frame """ return av.AudioFrame.from_ndarray(np.repeat(output, 2).reshape(1, -1)) - + # TODO: make it generic to support purely generative video cases async def get_processed_video_frame(self) -> av.VideoFrame: """Get the next processed video frame. 
- + Returns: The processed video frame, or original frame if no processing needed """ frame = await self.video_incoming_frames.get() - + # Skip frames that were marked as skipped - while frame.side_data.skipped and not hasattr(frame.side_data, 'passthrough'): + while frame.side_data.skipped and not hasattr(frame.side_data, "passthrough"): frame = await self.video_incoming_frames.get() - + # If this is a passthrough frame (no video output from workflow), return original - if hasattr(frame.side_data, 'passthrough') and frame.side_data.passthrough: + if hasattr(frame.side_data, "passthrough") and frame.side_data.passthrough: return frame - + # Get processed output from client async with temporary_log_level("comfy", self._comfyui_inference_log_level): out_tensor = await self.client.get_video_output() @@ -260,12 +630,12 @@ async def get_processed_video_frame(self) -> av.VideoFrame: processed_frame = self.video_postprocess(out_tensor) processed_frame.pts = frame.pts processed_frame.time_base = frame.time_base - + return processed_frame async def get_processed_audio_frame(self) -> av.AudioFrame: """Get the next processed audio frame. 
- + Returns: The processed audio frame, or original frame if no processing needed """ @@ -276,125 +646,124 @@ async def get_processed_audio_frame(self) -> av.AudioFrame: logger.debug("No audio frames available - generating silence frame") # Generate a silent audio frame to prevent blocking silent_frame = av.AudioFrame.from_ndarray( - np.zeros((1, 1024), dtype=np.int16), - format='s16', - layout='mono' + np.zeros((1, 1024), dtype=np.int16), format="s16", layout="mono" ) silent_frame.sample_rate = 48000 return silent_frame - + # If this is a passthrough frame (no audio output from workflow), return original - if hasattr(frame.side_data, 'passthrough') and frame.side_data.passthrough: + if hasattr(frame.side_data, "passthrough") and frame.side_data.passthrough: return frame - + # Process audio if needed if frame.samples > len(self.processed_audio_buffer): async with temporary_log_level("comfy", self._comfyui_inference_log_level): out_tensor = await self.client.get_audio_output() self.processed_audio_buffer = np.concatenate([self.processed_audio_buffer, out_tensor]) - - out_data = self.processed_audio_buffer[:frame.samples] - self.processed_audio_buffer = self.processed_audio_buffer[frame.samples:] + + out_data = self.processed_audio_buffer[: frame.samples] + self.processed_audio_buffer = self.processed_audio_buffer[frame.samples :] processed_frame = self.audio_postprocess(out_data) processed_frame.pts = frame.pts processed_frame.time_base = frame.time_base processed_frame.sample_rate = frame.sample_rate - + return processed_frame - + async def get_text_output(self) -> str | None: """Get the next text output from the pipeline. 
- + Returns: The processed text output, or empty string if no text output produced """ # If workflow doesn't produce text output, return empty string immediately if not self.produces_text_output(): return None - + async with temporary_log_level("comfy", self._comfyui_inference_log_level): out_text = await self.client.get_text_output() - + return out_text - + async def get_nodes_info(self) -> Dict[str, Any]: """Get information about all nodes in the current prompt including metadata. - + Returns: Dictionary containing node information """ nodes_info = await self.client.get_available_nodes() return nodes_info - + def get_workflow_io_capabilities(self) -> WorkflowModality: """Get the I/O capabilities for each modality in the current workflow. - + Returns: WorkflowModality mapping each modality to its input/output capabilities """ if self._cached_io_capabilities is None: - if not hasattr(self.client, 'current_prompts') or not self.client.current_prompts: - # Return empty capabilities if no prompts - return create_empty_workflow_modality() - - self._cached_io_capabilities = detect_io_points(self.client.current_prompts) - + if not hasattr(self.client, "current_prompts") or not self.client.current_prompts: + # Cache empty capabilities if no prompts to avoid repeated checks + self._cached_io_capabilities = create_empty_workflow_modality() + else: + self._cached_io_capabilities = detect_io_points(self.client.current_prompts) + return self._cached_io_capabilities def get_workflow_modalities(self) -> Set[str]: """Get the modalities required by the current workflow. 
- + Returns: Set of modality strings: {'video', 'audio', 'text'} """ if self._cached_modalities is None: - if not hasattr(self.client, 'current_prompts') or not self.client.current_prompts: - return set() - - self._cached_modalities = detect_prompt_modalities(self.client.current_prompts) - + if not hasattr(self.client, "current_prompts") or not self.client.current_prompts: + # Cache empty set if no prompts to avoid repeated checks + self._cached_modalities = set() + else: + self._cached_modalities = detect_prompt_modalities(self.client.current_prompts) + return self._cached_modalities - + def get_modalities(self) -> Set[str]: """Alias for get_workflow_modalities for compatibility.""" return self.get_workflow_modalities() - + def requires_video(self) -> bool: """Check if the workflow requires video processing.""" return "video" in self.get_workflow_modalities() - + def requires_audio(self) -> bool: """Check if the workflow requires audio processing.""" return "audio" in self.get_workflow_modalities() - + def requires_text(self) -> bool: """Check if the workflow requires text processing.""" return "text" in self.get_workflow_modalities() - + def accepts_video_input(self) -> bool: """Check if the workflow accepts video input.""" return self.get_workflow_io_capabilities()["video"]["input"] - + def accepts_audio_input(self) -> bool: """Check if the workflow accepts audio input.""" return self.get_workflow_io_capabilities()["audio"]["input"] - + def produces_video_output(self) -> bool: """Check if the workflow produces video output.""" return self.get_workflow_io_capabilities()["video"]["output"] - + def produces_audio_output(self) -> bool: """Check if the workflow produces audio output.""" return self.get_workflow_io_capabilities()["audio"]["output"] - + def produces_text_output(self) -> bool: """Check if the workflow produces text output.""" return self.get_workflow_io_capabilities()["text"]["output"] - + async def cleanup(self): """Clean up resources used by the 
pipeline. - + This includes: - Canceling running prompts - Clearing all queues (video, audio, tensor caches) @@ -402,19 +771,26 @@ async def cleanup(self): - Clearing cached modalities """ logger.debug("Starting pipeline cleanup") - + # Clear cached modalities and I/O capabilities since we're resetting self._cached_modalities = None self._cached_io_capabilities = None - + # Clear pipeline queues await self._clear_pipeline_queues() - + # Cleanup client (this handles prompt cancellation and tensor cache cleanup) await self.client.cleanup() - + + try: + await self.state_manager.transition_to(PipelineState.UNINITIALIZED) + except ValueError: + logger.debug("Skipping UNINITIALIZED transition during cleanup") + except Exception: + logger.exception("Failed to transition pipeline to UNINITIALIZED during cleanup") + logger.debug("Pipeline cleanup completed") - + async def _clear_pipeline_queues(self): """Clear the pipeline's internal frame queues.""" # Clear video frame queue @@ -423,15 +799,15 @@ async def _clear_pipeline_queues(self): self.video_incoming_frames.get_nowait() except asyncio.QueueEmpty: break - - # Clear audio frame queue + + # Clear audio frame queue while not self.audio_incoming_frames.empty(): try: self.audio_incoming_frames.get_nowait() except asyncio.QueueEmpty: break - + # Reset audio buffer self.processed_audio_buffer = np.array([], dtype=np.int16) - - logger.debug("Pipeline queues cleared") + + logger.debug("Pipeline queues cleared") diff --git a/src/comfystream/pipeline_state.py b/src/comfystream/pipeline_state.py new file mode 100644 index 000000000..0b2f93b94 --- /dev/null +++ b/src/comfystream/pipeline_state.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +import asyncio +import logging +from enum import Enum, auto +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from comfystream.client import ComfyStreamClient + +logger = logging.getLogger(__name__) + + +class PipelineState(Enum): + """Pipeline lifecycle states.""" + + 
UNINITIALIZED = auto() + INITIALIZING = auto() + READY = auto() + STREAMING = auto() + ERROR = auto() + + +class PipelineStateManager: + """Manages pipeline state transitions and runner lifecycle.""" + + def __init__(self, client: ComfyStreamClient): + self.client = client + self._state = PipelineState.UNINITIALIZED + self._state_lock = asyncio.Lock() + + @property + def state(self) -> PipelineState: + return self._state + + async def transition_to(self, new_state: PipelineState): + """Transition to a new state with automatic runner management.""" + # Fast path for no-op transitions + if new_state == self._state: + logger.debug("Pipeline state unchanged: %s", new_state.name) + return + + async with self._state_lock: + if new_state == self._state: + logger.debug("Pipeline state unchanged (locked): %s", new_state.name) + return + + old_state = self._state + + if not self._is_valid_transition(old_state, new_state): + raise ValueError(f"Invalid transition: {old_state.name} -> {new_state.name}") + + await self._on_exit_state(old_state) + self._state = new_state + await self._on_enter_state(new_state) + + logger.info("Pipeline state: %s -> %s", old_state.name, new_state.name) + + def _is_valid_transition(self, from_state: PipelineState, to_state: PipelineState) -> bool: + """Define valid state transitions.""" + valid_transitions = { + PipelineState.UNINITIALIZED: { + PipelineState.INITIALIZING, + PipelineState.READY, + PipelineState.ERROR, + }, + PipelineState.INITIALIZING: { + PipelineState.READY, + PipelineState.ERROR, + }, + PipelineState.READY: { + PipelineState.INITIALIZING, + PipelineState.STREAMING, + PipelineState.UNINITIALIZED, + PipelineState.ERROR, + }, + PipelineState.STREAMING: { + PipelineState.READY, + PipelineState.ERROR, + }, + PipelineState.ERROR: { + PipelineState.INITIALIZING, + PipelineState.READY, + PipelineState.UNINITIALIZED, + }, + } + + return to_state in valid_transitions.get(from_state, set()) + + async def _on_enter_state(self, state: 
PipelineState): + """Actions executed when entering a state.""" + if state == PipelineState.INITIALIZING: + await self.client.resume_prompts() + elif state == PipelineState.READY: + self.client.pause_prompts() + elif state == PipelineState.STREAMING: + await self.client.resume_prompts() + elif state == PipelineState.ERROR: + self.client.pause_prompts() + elif state == PipelineState.UNINITIALIZED: + self.client.pause_prompts() + + async def _on_exit_state(self, _state: PipelineState): + """Actions executed when exiting a state.""" + return + + def can_stream(self) -> bool: + """Check if pipeline is ready to stream.""" + return self._state in {PipelineState.READY, PipelineState.STREAMING} + + def is_initialized(self) -> bool: + """Check if pipeline has been initialized with prompts.""" + return self._state != PipelineState.UNINITIALIZED diff --git a/src/comfystream/scripts/README.md b/src/comfystream/scripts/README.md index 1d95d49fa..e593a376d 100644 --- a/src/comfystream/scripts/README.md +++ b/src/comfystream/scripts/README.md @@ -22,11 +22,23 @@ python src/comfystream/scripts/setup_nodes.py --workspace /path/to/comfyui ``` > The optional flag `--pull-branches` can be used to ensure the latest git changes are pulled for any custom nodes defined with a `branch` in nodes.yaml +#### Using a custom nodes configuration +```bash +python src/comfystream/scripts/setup_nodes.py --workspace /path/to/comfyui --config nodes-streamdiffusion.yaml +``` +> The `--config` flag accepts a filename (searches in `configs/`), relative path, or absolute path to a custom nodes configuration file + ### Download models and compile tensorrt engines ```bash python src/comfystream/scripts/setup_models.py --workspace /path/to/comfyui ``` +#### Using a custom models configuration +```bash +python src/comfystream/scripts/setup_models.py --workspace /path/to/comfyui --config models-minimal.yaml +``` +> The `--config` flag accepts a filename (searches in `configs/`), relative path, or absolute 
path to a custom models configuration file + ## Configuration Examples ### Custom Nodes (nodes.yaml) @@ -55,6 +67,10 @@ models: type: "checkpoint" ``` +> You can create custom model configurations for different use cases. See `configs/models-minimal.yaml` and `configs/models-pixelart.yaml` for examples. + +**Directory Downloads:** The script now supports downloading entire directories from HuggingFace! Add `is_directory: true` to your config. See `configs/models-ipadapter-example.yaml` for examples or read [DIRECTORY_DOWNLOADS.md](../../../DIRECTORY_DOWNLOADS.md) for the full guide. + ## Directory Structure ```sh diff --git a/src/comfystream/scripts/__init__.py b/src/comfystream/scripts/__init__.py index 9273e2423..abcb156fa 100644 --- a/src/comfystream/scripts/__init__.py +++ b/src/comfystream/scripts/__init__.py @@ -1,7 +1,7 @@ -from . import utils from utils import get_config_path, load_model_config -from .setup_nodes import run_setup_nodes + +from . import utils from .setup_models import run_setup_models +from .setup_nodes import run_setup_nodes """Setup scripts for ComfyUI streaming server""" - diff --git a/src/comfystream/scripts/build_trt.py b/src/comfystream/scripts/build_trt.py index 3ef1ac7c3..b329ff20d 100644 --- a/src/comfystream/scripts/build_trt.py +++ b/src/comfystream/scripts/build_trt.py @@ -1,9 +1,9 @@ #!/usr/bin/env python3 +import argparse import os import sys import time -import argparse # Reccomended running from comfystream conda environment # in devcontainer from the workspace/ directory, or comfystream/ if you've checked out the repo @@ -11,7 +11,7 @@ # $> python src/comfystream/scripts/build_trt.py --model /ComfyUI/models/checkpoints/SD1.5/dreamshaper-8.safetensors --out-engine /ComfyUI/output/tensorrt/static-dreamshaper8_SD15_$stat-b-1-h-512-w-512_00001_.engine # Paths path explicitly to use the downloaded comfyUI installation on root -ROOT_DIR="/workspace" +ROOT_DIR = "/workspace" COMFYUI_DIR = "/workspace/ComfyUI" timing_cache_path = 
"/workspace/ComfyUI/output/tensorrt/timing_cache" @@ -22,15 +22,16 @@ import comfy import comfy.model_management - -from ComfyUI.custom_nodes.ComfyUI_TensorRT.models.supported_models import detect_version_from_model, get_helper_from_model +from ComfyUI.custom_nodes.ComfyUI_TensorRT.models.supported_models import ( + detect_version_from_model, + get_helper_from_model, +) from ComfyUI.custom_nodes.ComfyUI_TensorRT.onnx_utils.export import export_onnx from ComfyUI.custom_nodes.ComfyUI_TensorRT.tensorrt_diffusion_model import TRTDiffusionBackbone + def parse_args(): - parser = argparse.ArgumentParser( - description="Build a TensorRT engine from a ComfyUI model." - ) + parser = argparse.ArgumentParser(description="Build a TensorRT engine from a ComfyUI model.") parser.add_argument( "--model", type=str, @@ -63,10 +64,18 @@ def parse_args(): ) # Dynamic Engine Optional Args - parser.add_argument("--min-width", type=int, default=None, help="Minimum width for dynamic shape (optional)") - parser.add_argument("--min-height", type=int, default=None, help="Minimum height for dynamic shape (optional)") - parser.add_argument("--max-width", type=int, default=None, help="Maximum width for dynamic shape (optional)") - parser.add_argument("--max-height", type=int, default=None, help="Maximum height for dynamic shape (optional)") + parser.add_argument( + "--min-width", type=int, default=None, help="Minimum width for dynamic shape (optional)" + ) + parser.add_argument( + "--min-height", type=int, default=None, help="Minimum height for dynamic shape (optional)" + ) + parser.add_argument( + "--max-width", type=int, default=None, help="Maximum width for dynamic shape (optional)" + ) + parser.add_argument( + "--max-height", type=int, default=None, help="Maximum height for dynamic shape (optional)" + ) parser.add_argument( "--context", @@ -81,12 +90,11 @@ def parse_args(): help="If set, attempts to export the ONNX with FP8 transformations (Flux or standard).", ) parser.add_argument( - 
"--verbose", - action="store_true", - help="Enable more logging / debug prints." + "--verbose", action="store_true", help="Enable more logging / debug prints." ) return parser.parse_args() + def build_trt_engine( model_path: str, engine_out_path: str, @@ -100,7 +108,7 @@ def build_trt_engine( context_opt: int = 1, num_video_frames: int = 14, fp8: bool = False, - verbose: bool = False + verbose: bool = False, ): """ 1) Load the model from ComfyUI by path or name @@ -121,8 +129,10 @@ def build_trt_engine( if verbose: print(f"[INFO] Starting build for model: {model_path}") print(f" Output Engine Path: {engine_out_path}") - print(f" (batch={batch_size_opt}, H={height_opt}, W={width_opt}, context={context_opt}, " - f"num_video_frames={num_video_frames}, fp8={fp8})") + print( + f" (batch={batch_size_opt}, H={height_opt}, W={width_opt}, context={context_opt}, " + f"num_video_frames={num_video_frames}, fp8={fp8})" + ) # 1) Load model in GPU: comfy.model_management.unload_all_models() @@ -130,11 +140,9 @@ def build_trt_engine( loaded_model = comfy.sd.load_diffusion_model(model_path, model_options={}) if loaded_model is None: raise ValueError("Failed to load model.") - + comfy.model_management.load_models_gpu( - [loaded_model], - force_patch_weights=True, - force_full_load=True + [loaded_model], force_patch_weights=True, force_full_load=True ) # 2) Export to ONNX at the desired shape @@ -151,19 +159,19 @@ def build_trt_engine( print(f"[INFO] Exporting ONNX to: {onnx_path}") export_onnx( - model = loaded_model, - path = onnx_path, - batch_size = batch_size_opt, - height = height_opt, - width = width_opt, - num_video_frames = num_video_frames, - context_multiplier = context_opt, - fp8 = fp8, + model=loaded_model, + path=onnx_path, + batch_size=batch_size_opt, + height=height_opt, + width=width_opt, + num_video_frames=num_video_frames, + context_multiplier=context_opt, + fp8=fp8, ) # 3) Build the TRT engine model_version = detect_version_from_model(loaded_model) - model_helper 
= get_helper_from_model(loaded_model) + model_helper = get_helper_from_model(loaded_model) trt_model = TRTDiffusionBackbone(model_helper) @@ -171,20 +179,20 @@ def build_trt_engine( is_dynamic = all(v is not None for v in [min_width, max_width, min_height, max_height]) min_config = { "batch_size": batch_size_opt, - "height": min_height if is_dynamic else height_opt, - "width": min_width if is_dynamic else width_opt, + "height": min_height if is_dynamic else height_opt, + "width": min_width if is_dynamic else width_opt, "context_len": context_opt * model_helper.context_len, } opt_config = { "batch_size": batch_size_opt, - "height": height_opt, - "width": width_opt, + "height": height_opt, + "width": width_opt, "context_len": context_opt * model_helper.context_len, } max_config = { "batch_size": batch_size_opt, - "height": max_height if is_dynamic else height_opt, - "width": max_width if is_dynamic else width_opt, + "height": max_height if is_dynamic else height_opt, + "width": max_width if is_dynamic else width_opt, "context_len": context_opt * model_helper.context_len, } @@ -197,12 +205,12 @@ def build_trt_engine( print(f"[INFO] Building engine -> {engine_out_path}") success = trt_model.build( - onnx_path = onnx_path, - engine_path = engine_out_path, - timing_cache_path = timing_cache_path, - opt_config = opt_config, - min_config = min_config, - max_config = max_config, + onnx_path=onnx_path, + engine_path=engine_out_path, + timing_cache_path=timing_cache_path, + opt_config=opt_config, + min_config=min_config, + max_config=max_config, ) if not success: raise RuntimeError("[ERROR] TensorRT engine build failed") @@ -224,18 +232,18 @@ def build_trt_engine( def main(): args = parse_args() build_trt_engine( - model_path = args.model, - engine_out_path = args.out_engine, - batch_size_opt = args.batch_size, - height_opt = args.height, - width_opt = args.width, - min_width = args.min_width, - min_height = args.min_height, - max_width = args.max_width, - max_height = 
args.max_height, - context_opt = args.context, - fp8 = args.fp8, - verbose = args.verbose + model_path=args.model, + engine_out_path=args.out_engine, + batch_size_opt=args.batch_size, + height_opt=args.height, + width_opt=args.width, + min_width=args.min_width, + min_height=args.min_height, + max_width=args.max_width, + max_height=args.max_height, + context_opt=args.context, + fp8=args.fp8, + verbose=args.verbose, ) diff --git a/src/comfystream/scripts/constraints.txt b/src/comfystream/scripts/constraints.txt index 529030f2c..50b5c6b30 100644 --- a/src/comfystream/scripts/constraints.txt +++ b/src/comfystream/scripts/constraints.txt @@ -1,12 +1,15 @@ --extra-index-url https://download.pytorch.org/whl/cu128 --extra-index-url https://pypi.nvidia.com numpy<2.0.0 -torch==2.7.1+cu128 -torchvision==0.22.1+cu128 -torchaudio==2.7.1+cu128 +torch==2.8.0+cu128 +torchvision==0.23.0+cu128 +torchaudio==2.8.0+cu128 tensorrt==10.12.0.36 tensorrt-cu12==10.12.0.36 +xformers==0.0.32.post2 onnx==1.18.0 onnxruntime==1.22.0 onnxruntime-gpu==1.22.0 -onnxmltools==1.14.0 \ No newline at end of file +onnxmltools==1.14.0 +huggingface-hub>=0.20.0 +cuda-python<13.0 diff --git a/src/comfystream/scripts/setup_models.py b/src/comfystream/scripts/setup_models.py index 9360a542b..0fb8e98bc 100644 --- a/src/comfystream/scripts/setup_models.py +++ b/src/comfystream/scripts/setup_models.py @@ -1,35 +1,45 @@ +import argparse import os +import sys from pathlib import Path + import requests -from tqdm import tqdm import yaml -import argparse +from huggingface_hub import hf_hub_download, snapshot_download +from tqdm import tqdm from utils import get_config_path, load_model_config + def parse_args(): - parser = argparse.ArgumentParser(description='Setup ComfyUI models') - parser.add_argument('--workspace', - default=os.environ.get('COMFY_UI_WORKSPACE', os.path.expanduser('~/comfyui')), - help='ComfyUI workspace directory (default: ~/comfyui or $COMFY_UI_WORKSPACE)') + parser = 
argparse.ArgumentParser(description="Setup ComfyUI models") + parser.add_argument( + "--workspace", + default=os.environ.get("COMFY_UI_WORKSPACE", os.path.expanduser("~/comfyui")), + help="ComfyUI workspace directory (default: ~/comfyui or $COMFY_UI_WORKSPACE)", + ) + parser.add_argument( + "--config", + default=None, + help="Path to custom models config file (default: configs/models.yaml). Can be a filename (searches in configs/), or an absolute/relative path.", + ) return parser.parse_args() + def download_file(url, destination, description=None): """Download a file with progress bar, follow redirects, and detect LFS pointer files""" - headers = { - "User-Agent": "Mozilla/5.0" - } + headers = {"User-Agent": "Mozilla/5.0"} with requests.get(url, stream=True, headers=headers, allow_redirects=True) as response: response.raise_for_status() - total_size = int(response.headers.get('content-length', 0)) + total_size = int(response.headers.get("content-length", 0)) desc = description or os.path.basename(destination) - progress_bar = tqdm(total=total_size, unit='iB', unit_scale=True, desc=desc) + progress_bar = tqdm(total=total_size, unit="iB", unit_scale=True, desc=desc) destination = Path(destination) destination.parent.mkdir(parents=True, exist_ok=True) - with open(destination, 'wb') as file: + with open(destination, "wb") as file: for chunk in response.iter_content(chunk_size=1024): if chunk: file.write(chunk) @@ -38,17 +48,40 @@ def download_file(url, destination, description=None): # Verify that we didn't just write a Git LFS pointer if destination.stat().st_size < 100: - with open(destination, 'r', errors='ignore') as f: + with open(destination, "r", errors="ignore") as f: content = f.read() - if 'git-lfs' in content.lower(): + if "git-lfs" in content.lower(): print(f"❌ LFS pointer detected in {destination}. Deleting.") destination.unlink() raise ValueError(f"LFS pointer detected. 
Failed to download: {url}") + +def download_hf_directory(repo_id, subfolder, destination, description=None): + """Download an entire directory from HuggingFace Hub""" + destination = Path(destination) + destination.mkdir(parents=True, exist_ok=True) + + desc = description or f"Downloading {repo_id}/{subfolder}" + print(f"{desc}...") + + try: + # Download the specific subfolder to the destination + snapshot_download( + repo_id=repo_id, + allow_patterns=f"{subfolder}/*", + local_dir=destination.parent, + local_dir_use_symlinks=False, + ) + print(f"✓ Downloaded {repo_id}/{subfolder} to {destination}") + except Exception as e: + print(f"❌ Error downloading {repo_id}/{subfolder}: {e}") + raise + + def setup_model_files(workspace_dir, config_path=None): """Download and setup required model files based on configuration""" if config_path is None: - config_path = get_config_path('models.yaml') + config_path = get_config_path("models.yaml") try: config = load_model_config(config_path) except FileNotFoundError: @@ -61,34 +94,57 @@ def setup_model_files(workspace_dir, config_path=None): models_path = workspace_dir / "models" base_path = workspace_dir - for _, model_info in config['models'].items(): + for _, model_info in config["models"].items(): # Determine the full path based on whether it's in custom_nodes or models - if model_info['path'].startswith('custom_nodes/'): - full_path = base_path / model_info['path'] + if model_info["path"].startswith("custom_nodes/"): + full_path = base_path / model_info["path"] else: - full_path = models_path / model_info['path'] + full_path = models_path / model_info["path"] if not full_path.exists(): print(f"Downloading {model_info['name']}...") - download_file( - model_info['url'], - full_path, - f"Downloading {model_info['name']}" - ) - print(f"Downloaded {model_info['name']} to {full_path}") + + # Check if this is a HuggingFace directory download + if model_info.get("is_directory", False): + # Parse HuggingFace URL to extract repo_id and 
subfolder + # Format: https://huggingface.co/{repo_id}/tree/main/{subfolder} + # Or: https://huggingface.co/{repo_id}/blob/main/{subfolder} + url = model_info["url"] + if "huggingface.co" in url: + parts = url.split("huggingface.co/")[-1].split("/") + if len(parts) >= 4 and (parts[2] in ["tree", "blob"]): + repo_id = f"{parts[0]}/{parts[1]}" + subfolder = "/".join(parts[4:]) if len(parts) > 4 else parts[3] + download_hf_directory( + repo_id=repo_id, + subfolder=subfolder, + destination=full_path, + description=f"Downloading {model_info['name']}", + ) + else: + print(f"❌ Invalid HuggingFace URL format: {url}") + continue + else: + print(f"❌ Directory download only supports HuggingFace URLs: {url}") + continue + else: + # Regular file download + download_file(model_info["url"], full_path, f"Downloading {model_info['name']}") + print(f"Downloaded {model_info['name']} to {full_path}") # Handle any extra files (like configs) - if 'extra_files' in model_info: - for extra in model_info['extra_files']: - extra_path = models_path / extra['path'] + if "extra_files" in model_info: + for extra in model_info["extra_files"]: + extra_path = models_path / extra["path"] if not extra_path.exists(): download_file( - extra['url'], + extra["url"], extra_path, - f"Downloading {os.path.basename(extra['path'])}" + f"Downloading {os.path.basename(extra['path'])}", ) print("Models download completed!") + def setup_directories(workspace_dir): """Create required directories in the workspace""" # Create base directories @@ -111,19 +167,37 @@ def setup_directories(workspace_dir): "checkpoints/SD1.5", "controlnet", "vae", + "vae_approx", "tensorrt", "unet", + "loras/SD1.5", + "ipadapter", + "text_encoders/CLIPText", + "liveportrait_onnx/joyvasa_models", "LLM", ] for dir_name in model_dirs: subdir = models_dir / dir_name subdir.mkdir(parents=True, exist_ok=True) + def setup_models(): args = parse_args() workspace_dir = Path(args.workspace) + # Resolve config path if provided + config_path = None 
+ if args.config: + config_path = Path(args.config) + # If it's just a filename, look in configs directory + if not config_path.is_absolute() and "/" not in str(config_path): + config_path = Path("configs") / config_path + if not config_path.exists(): + print(f"Error: Config file not found at {config_path}") + sys.exit(1) + setup_directories(workspace_dir) - setup_model_files(workspace_dir) + setup_model_files(workspace_dir, config_path=config_path) + + setup_models() diff --git a/src/comfystream/scripts/setup_nodes.py b/src/comfystream/scripts/setup_nodes.py index 418e55f86..05d2fab5a 100755 --- a/src/comfystream/scripts/setup_nodes.py +++ b/src/comfystream/scripts/setup_nodes.py @@ -1,9 +1,10 @@ +import argparse import os import subprocess import sys from pathlib import Path + import yaml -import argparse from utils import get_config_path, load_model_config @@ -20,6 +21,11 @@ def parse_args(): default=False, help="Update existing nodes to their specified branches", ) + parser.add_argument( + "--config", + default=None, + help="Path to custom nodes config file (default: configs/nodes.yaml). 
Can be a filename (searches in configs/), or an absolute/relative path.", + ) return parser.parse_args() @@ -77,7 +83,9 @@ def install_custom_nodes(workspace_dir, config_path=None, pull_branches=False): print(f"Updating {node_info['name']} to latest {node_info['branch']}...") subprocess.run(["git", "-C", dir_name, "fetch", "origin"], check=True) subprocess.run(["git", "-C", dir_name, "checkout", node_info["branch"]], check=True) - subprocess.run(["git", "-C", dir_name, "pull", "origin", node_info["branch"]], check=True) + subprocess.run( + ["git", "-C", dir_name, "pull", "origin", node_info["branch"]], check=True + ) else: print(f"{node_info['name']} already exists, skipping clone.") @@ -120,9 +128,20 @@ def setup_nodes(): args = parse_args() workspace_dir = Path(args.workspace) + # Resolve config path if provided + config_path = None + if args.config: + config_path = Path(args.config) + # If it's just a filename, look in configs directory + if not config_path.is_absolute() and "/" not in str(config_path): + config_path = Path("configs") / config_path + if not config_path.exists(): + print(f"Error: Config file not found at {config_path}") + sys.exit(1) + setup_environment(workspace_dir) setup_directories(workspace_dir) - install_custom_nodes(workspace_dir, pull_branches=args.pull_branches) + install_custom_nodes(workspace_dir, config_path=config_path, pull_branches=args.pull_branches) if __name__ == "__main__": diff --git a/src/comfystream/scripts/utils.py b/src/comfystream/scripts/utils.py index a7b37f2df..5ce4ab351 100644 --- a/src/comfystream/scripts/utils.py +++ b/src/comfystream/scripts/utils.py @@ -1,12 +1,14 @@ -import yaml from pathlib import Path +import yaml + + def get_config_path(filename): """Get the absolute path to a config file""" config_path = Path("configs") / filename if not config_path.exists(): print(f"Warning: Config file {filename} not found at {config_path}") - print(f"Available files in configs/:") + print("Available files in configs/:") 
try: for f in Path("configs").glob("*"): print(f" - {f.name}") @@ -15,7 +17,8 @@ def get_config_path(filename): raise FileNotFoundError(f"Config file {filename} not found at {config_path}") return config_path + def load_model_config(config_path): """Load model configuration from YAML file""" - with open(config_path, 'r') as f: - return yaml.safe_load(f) \ No newline at end of file + with open(config_path, "r") as f: + return yaml.safe_load(f) diff --git a/src/comfystream/server/metrics/prometheus_metrics.py b/src/comfystream/server/metrics/prometheus_metrics.py index 080bc2940..e1aec64e5 100644 --- a/src/comfystream/server/metrics/prometheus_metrics.py +++ b/src/comfystream/server/metrics/prometheus_metrics.py @@ -1,9 +1,10 @@ """Prometheus metrics utilities.""" -from prometheus_client import Gauge, generate_latest -from aiohttp import web from typing import Optional +from aiohttp import web +from prometheus_client import Gauge, generate_latest + class MetricsManager: """Manages Prometheus metrics collection.""" @@ -18,9 +19,7 @@ def __init__(self, include_stream_id: bool = False): self._include_stream_id = include_stream_id base_labels = ["stream_id"] if include_stream_id else [] - self._fps_gauge = Gauge( - "stream_fps", "Frames per second of the stream", base_labels - ) + self._fps_gauge = Gauge("stream_fps", "Frames per second of the stream", base_labels) def enable(self): """Enable Prometheus metrics collection.""" diff --git a/src/comfystream/server/metrics/stream_stats.py b/src/comfystream/server/metrics/stream_stats.py index 8dc2ab19e..40d88c782 100644 --- a/src/comfystream/server/metrics/stream_stats.py +++ b/src/comfystream/server/metrics/stream_stats.py @@ -1,7 +1,8 @@ """Handles real-time video stream statistics (non-Prometheus, JSON API).""" -from typing import Any, Dict import json +from typing import Any, Dict + from aiohttp import web from aiortc import MediaStreamTrack @@ -17,9 +18,7 @@ def __init__(self, app: web.Application): """ self._app = app 
- async def collect_video_metrics( - self, video_track: MediaStreamTrack - ) -> Dict[str, Any]: + async def collect_video_metrics(self, video_track: MediaStreamTrack) -> Dict[str, Any]: """Collects real-time statistics for a video track. Args: diff --git a/src/comfystream/server/utils/__init__.py b/src/comfystream/server/utils/__init__.py index daa71bb1e..f64357c32 100644 --- a/src/comfystream/server/utils/__init__.py +++ b/src/comfystream/server/utils/__init__.py @@ -1,2 +1,2 @@ -from .utils import patch_loop_datagram, add_prefix_to_app_routes, temporary_log_level from .fps_meter import FPSMeter +from .utils import add_prefix_to_app_routes, patch_loop_datagram, temporary_log_level diff --git a/src/comfystream/server/utils/fps_meter.py b/src/comfystream/server/utils/fps_meter.py index 87e75d461..ee86772b2 100644 --- a/src/comfystream/server/utils/fps_meter.py +++ b/src/comfystream/server/utils/fps_meter.py @@ -4,6 +4,7 @@ import logging import time from collections import deque + from comfystream.server.metrics import MetricsManager logger = logging.getLogger(__name__) @@ -35,11 +36,7 @@ async def _calculate_fps_loop(self): current_time = time.monotonic() if self._last_fps_calculation_time is not None: time_diff = current_time - self._last_fps_calculation_time - self._fps = ( - self._fps_interval_frame_count / time_diff - if time_diff > 0 - else 0.0 - ) + self._fps = self._fps_interval_frame_count / time_diff if time_diff > 0 else 0.0 self._fps_measurements.append( { "timestamp": current_time - self._fps_loop_start_time, @@ -92,8 +89,7 @@ async def average_fps(self) -> float: """ async with self._lock: return ( - sum(m["fps"] for m in self._fps_measurements) - / len(self._fps_measurements) + sum(m["fps"] for m in self._fps_measurements) / len(self._fps_measurements) if self._fps_measurements else self._fps ) @@ -106,9 +102,6 @@ async def last_fps_calculation_time(self) -> float: The elapsed time in seconds since the last FPS calculation. 
""" async with self._lock: - if ( - self._last_fps_calculation_time is None - or self._fps_loop_start_time is None - ): + if self._last_fps_calculation_time is None or self._fps_loop_start_time is None: return 0.0 return self._last_fps_calculation_time - self._fps_loop_start_time diff --git a/src/comfystream/server/utils/utils.py b/src/comfystream/server/utils/utils.py index c7a7ac304..baa7ff6d4 100644 --- a/src/comfystream/server/utils/utils.py +++ b/src/comfystream/server/utils/utils.py @@ -1,12 +1,13 @@ """General utility functions.""" import asyncio +import logging import random import types -import logging -from aiohttp import web -from typing import List, Tuple from contextlib import asynccontextmanager +from typing import List, Tuple + +from aiohttp import web logger = logging.getLogger(__name__) @@ -30,9 +31,7 @@ async def create_datagram_endpoint( protocol_factory, local_addr=local_addr, **kwargs ) if local_addr is None: - return await old_create_datagram_endpoint( - protocol_factory, local_addr=None, **kwargs - ) + return await old_create_datagram_endpoint(protocol_factory, local_addr=None, **kwargs) # if port is not specified make it use our range ports = list(local_ports) random.shuffle(ports) diff --git a/src/comfystream/tensor_cache.py b/src/comfystream/tensor_cache.py index 5cd54332e..609f98b8f 100644 --- a/src/comfystream/tensor_cache.py +++ b/src/comfystream/tensor_cache.py @@ -1,11 +1,10 @@ -import torch -import numpy as np - -from queue import Queue from asyncio import Queue as AsyncQueue - +from queue import Queue from typing import Union +import numpy as np +import torch + # TODO: improve eviction policy fifo might not be the best, skip alternate frames instead image_inputs: Queue[Union[torch.Tensor, np.ndarray]] = Queue(maxsize=1) image_outputs: AsyncQueue[Union[torch.Tensor, np.ndarray]] = AsyncQueue() diff --git a/src/comfystream/utils.py b/src/comfystream/utils.py index 7d8800c4a..1fd035a8f 100644 --- a/src/comfystream/utils.py +++ 
b/src/comfystream/utils.py @@ -1,16 +1,17 @@ import copy -import json -import os -import logging import importlib -from typing import Dict, Any, List, Tuple, Optional, Union -from pytrickle.api import StreamParamsUpdateRequest +import json +from typing import Any, Dict + from comfy.api.components.schema.prompt import Prompt, PromptDictInput +from pytrickle.api import StreamParamsUpdateRequest + from .modalities import ( - get_node_counts_by_type, get_convertible_node_keys, + get_node_counts_by_type, ) + def create_load_tensor_node(): return { "inputs": {}, @@ -18,6 +19,7 @@ def create_load_tensor_node(): "_meta": {"title": "LoadTensor"}, } + def create_save_tensor_node(inputs: Dict[Any, Any]): return { "inputs": inputs, @@ -25,6 +27,7 @@ def create_save_tensor_node(inputs: Dict[Any, Any]): "_meta": {"title": "SaveTensor"}, } + def _validate_prompt_constraints(counts: Dict[str, int]) -> None: """Validate that the prompt meets the required constraints.""" if counts["primary_inputs"] > 1: @@ -42,6 +45,7 @@ def _validate_prompt_constraints(counts: Dict[str, int]) -> None: if counts["outputs"] == 0: raise Exception("missing output") + def convert_prompt(prompt: PromptDictInput, return_dict: bool = False) -> Prompt: """Convert a prompt by replacing specific node types with tensor equivalents.""" try: @@ -49,7 +53,7 @@ def convert_prompt(prompt: PromptDictInput, return_dict: bool = False) -> Prompt importlib.import_module("comfy.api.components.schema.prompt_node") except Exception: pass - + """Convert and validate a ComfyUI workflow prompt.""" Prompt.validate(prompt) prompt = copy.deepcopy(prompt) @@ -57,7 +61,7 @@ def convert_prompt(prompt: PromptDictInput, return_dict: bool = False) -> Prompt # Count nodes and validate constraints counts = get_node_counts_by_type(prompt) _validate_prompt_constraints(counts) - + # Collect nodes that need conversion convertible_keys = get_convertible_node_keys(prompt) @@ -77,73 +81,82 @@ def convert_prompt(prompt: PromptDictInput, 
return_dict: bool = False) -> Prompt # Return dict if requested (for downstream components that expect plain dicts) if return_dict: return prompt # Already a plain dict at this point - + # Validate the processed prompt and return Pydantic object return Prompt.validate(prompt) + class ComfyStreamParamsUpdateRequest(StreamParamsUpdateRequest): """ComfyStream parameter validation.""" - + def __init__(self, **data): # Handle prompts parameter if "prompts" in data: prompts = data["prompts"] - + # Parse JSON string if needed if isinstance(prompts, str) and prompts.strip(): try: prompts = json.loads(prompts) except json.JSONDecodeError: data.pop("prompts") - + # Handle list - use first valid dict elif isinstance(prompts, list): prompts = next((p for p in prompts if isinstance(p, dict)), None) if not prompts: data.pop("prompts") - + # Validate prompts if "prompts" in data and isinstance(prompts, dict): try: data["prompts"] = convert_prompt(prompts, return_dict=True) except Exception: data.pop("prompts") - + # Call parent constructor super().__init__(**data) - + @classmethod def model_validate(cls, obj): return cls(**obj) - + def model_dump(self): return super().model_dump() + +def normalize_stream_params(params: Any) -> Dict[str, Any]: + """Normalize stream parameters from various formats to a dict. + + Args: + params: Parameters in dict, list, or other format + + Returns: + Dict containing normalized parameters, empty dict if invalid + """ + if params is None: + return {} + if isinstance(params, dict): + return dict(params) + if isinstance(params, list): + for candidate in params: + if isinstance(candidate, dict): + return dict(candidate) + return {} + return {} + + def get_default_workflow() -> dict: """Return the default workflow as a dictionary for warmup. 
- + Returns: dict: Default workflow dictionary """ return { "1": { - "inputs": { - "images": [ - "2", - 0 - ] - }, + "inputs": {"images": ["2", 0]}, "class_type": "SaveTensor", - "_meta": { - "title": "SaveTensor" - } + "_meta": {"title": "SaveTensor"}, }, - "2": { - "inputs": {}, - "class_type": "LoadTensor", - "_meta": { - "title": "LoadTensor" - } - } + "2": {"inputs": {}, "class_type": "LoadTensor", "_meta": {"title": "LoadTensor"}}, } - diff --git a/test/test_utils.py b/test/test_utils.py index 30c990caa..70b4209cd 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,6 +1,6 @@ import pytest - from comfy.api.components.schema.prompt import Prompt + from comfystream.utils import convert_prompt diff --git a/ui/package-lock.json b/ui/package-lock.json index 1261f8422..5b65c1b9b 100644 --- a/ui/package-lock.json +++ b/ui/package-lock.json @@ -1,12 +1,12 @@ { "name": "ui", - "version": "0.1.5", + "version": "0.1.7", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "ui", - "version": "0.1.5", + "version": "0.1.7", "dependencies": { "@hookform/resolvers": "^3.9.1", "@radix-ui/react-dialog": "^1.1.6", @@ -2367,9 +2367,9 @@ } }, "node_modules/@typescript-eslint/typescript-estree/node_modules/brace-expansion": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.1.tgz", - "integrity": "sha512-XnAIvQ8eM+kC6aULx6wuQiwVsnzsi9d3WxzV3FpWTGA19F621kwdbsAcFKXgKUHZWsy+mY6iL1sHTxWEFCytDA==", + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", + "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", "dev": true, "license": "MIT", "dependencies": { @@ -2866,9 +2866,9 @@ } }, "node_modules/brace-expansion": { - "version": "1.1.11", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.11.tgz", - "integrity": 
"sha512-iCuPHDFgrHX7H2vEI/5xpz07zSHB00TpugqhmYtVmMO6518mCuRMoOYFldEBl0g187ufozdaHgWKcYFb61qGiA==", + "version": "1.1.12", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz", + "integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==", "dev": true, "license": "MIT", "dependencies": { @@ -5265,9 +5265,9 @@ "license": "MIT" }, "node_modules/js-yaml": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.0.tgz", - "integrity": "sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==", + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.1.tgz", + "integrity": "sha512-qQKT4zQxXl8lLwBtHMWwaTcGfFOZviOJet3Oy/xmGk2gZH677CJM9EvtfdSkgWcATZhj/55JZ0rmy3myCT5lsA==", "dev": true, "license": "MIT", "dependencies": { @@ -7389,9 +7389,9 @@ } }, "node_modules/sucrase/node_modules/brace-expansion": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.1.tgz", - "integrity": "sha512-XnAIvQ8eM+kC6aULx6wuQiwVsnzsi9d3WxzV3FpWTGA19F621kwdbsAcFKXgKUHZWsy+mY6iL1sHTxWEFCytDA==", + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", + "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", "license": "MIT", "dependencies": { "balanced-match": "^1.0.0" diff --git a/ui/package.json b/ui/package.json index 666c6a5f1..7cc8128b7 100644 --- a/ui/package.json +++ b/ui/package.json @@ -1,6 +1,6 @@ { "name": "ui", - "version": "0.1.5", + "version": "0.1.7", "private": true, "scripts": { "dev": "cross-env NEXT_PUBLIC_DEV=true next dev", @@ -9,7 +9,7 @@ "start": "next start", "lint": "next lint", "format": "prettier --write .", - "prepare": "cd .. && husky && husky install ui/.husky" + "prepare": "cd .. 
&& husky ui/.husky" }, "dependencies": { "@hookform/resolvers": "^3.9.1", @@ -52,6 +52,14 @@ }, "lint-staged": { "*.{js,ts,tsx}": "eslint --cache --fix", - "*.{js,ts,tsx,css,md}": "prettier --write" + "*.{js,ts,tsx,css,md}": "prettier --write", + "../*.py": [ + "ruff check --fix", + "ruff format" + ], + "../{src,server,nodes,scripts}/**/*.py": [ + "ruff check --fix", + "ruff format" + ] } } diff --git a/ui/src/app/webrtc-preview/page.tsx b/ui/src/app/webrtc-preview/page.tsx new file mode 100644 index 000000000..b2d2045a2 --- /dev/null +++ b/ui/src/app/webrtc-preview/page.tsx @@ -0,0 +1,176 @@ +"use client"; +// WebRTC Preview Popup Page (client-only) + +import React, { useEffect, useRef, useState, useCallback } from "react"; + +const POLL_INTERVAL_MS = 300; +const MAX_ATTEMPTS = 200; // ~60s + +export default function WebRTCPopupPage() { + const videoRef = useRef(null); + const localStreamRef = useRef(null); + const parentStreamRef = useRef(null); + const clonedIdsRef = useRef>(new Set()); + const attemptsRef = useRef(0); + const intervalRef = useRef(null); + const [status, setStatus] = useState("Initializing…"); + + const clearIntervalInternal = useCallback(() => { + if (intervalRef.current !== null) { + window.clearInterval(intervalRef.current); + intervalRef.current = null; + } + }, []); + + const scheduleClose = useCallback((delay = 800) => { + window.setTimeout(() => { + try { window.close(); } catch { /* noop */ } + }, delay); + }, []); + + const validateOpener = useCallback((): boolean => { + try { + if (!window.opener) { + setStatus("Opener lost. Closing…"); + scheduleClose(); + return false; + } + void window.opener.location.href; // cross-origin check + return true; + } catch { + setStatus("Cross-origin opener. 
Closing…"); + scheduleClose(); + return false; + } + }, [scheduleClose]); + + const attachVideoIfNeeded = () => { + if (!localStreamRef.current) return; + const video = videoRef.current; + if (video && video.srcObject !== localStreamRef.current) { + video.srcObject = localStreamRef.current; + } + }; + + const cloneTracks = useCallback(() => { + if (!validateOpener()) return; + // @ts-ignore - global from opener context + const parentStream: MediaStream | undefined = window.opener?.__comfystreamRemoteStream; + if (!parentStream) { + setStatus("Waiting for stream…"); + return; + } + if (!localStreamRef.current) { + localStreamRef.current = new MediaStream(); + } + // Parent stream changed -> reset + if (parentStreamRef.current && parentStreamRef.current !== parentStream) { + localStreamRef.current.getTracks().forEach(t => { try { t.stop(); } catch { /* */ } }); + localStreamRef.current = new MediaStream(); + clonedIdsRef.current.clear(); + } + parentStreamRef.current = parentStream; + + let added = false; + parentStream.getTracks().forEach(src => { + if (src.readyState === "ended") return; + if (!clonedIdsRef.current.has(src.id)) { + try { + const clone = src.clone(); + clone.addEventListener("ended", () => { + clonedIdsRef.current.delete(src.id); + }); + localStreamRef.current!.addTrack(clone); + clonedIdsRef.current.add(src.id); + added = true; + } catch { + /* skip */ + } + } + }); + // Cleanup ended clones + localStreamRef.current.getTracks().forEach(t => { + if (t.readyState === "ended") { + localStreamRef.current!.removeTrack(t); + try { t.stop(); } catch { /* */ } + } + }); + if (added) { + attachVideoIfNeeded(); + setStatus("Live"); + videoRef.current?.play().catch(() => {}); + } + }, [validateOpener]); + + useEffect(() => { + if (typeof window === "undefined") return; // safety + attemptsRef.current = 0; + setStatus("Initializing…"); + + const tick = () => { + attemptsRef.current += 1; + if (!validateOpener()) { + clearIntervalInternal(); + return; + } + // 
@ts-ignore + if (!window.opener.__comfystreamRemoteStream) { + setStatus("Parent stream ended"); + clearIntervalInternal(); + scheduleClose(1200); + return; + } + cloneTracks(); + if (attemptsRef.current >= MAX_ATTEMPTS && (!localStreamRef.current || localStreamRef.current.getTracks().length === 0)) { + setStatus("Timeout waiting for stream"); + clearIntervalInternal(); + scheduleClose(1500); + } + }; + + intervalRef.current = window.setInterval(tick, POLL_INTERVAL_MS); + cloneTracks(); + + const beforeUnload = () => { + clearIntervalInternal(); + localStreamRef.current?.getTracks().forEach(t => { try { t.stop(); } catch { /* */ } }); + }; + window.addEventListener("beforeunload", beforeUnload); + return () => { + window.removeEventListener("beforeunload", beforeUnload); + clearIntervalInternal(); + localStreamRef.current?.getTracks().forEach(t => { try { t.stop(); } catch { /* */ } }); + }; + }, [cloneTracks, clearIntervalInternal, validateOpener, scheduleClose]); + + return ( +
+
+ ); +} diff --git a/ui/src/components/room.tsx b/ui/src/components/room.tsx index 85c6a8f85..19f1ce35d 100644 --- a/ui/src/components/room.tsx +++ b/ui/src/components/room.tsx @@ -56,8 +56,20 @@ function useToast() { } // Wrapper component to access peer context -function TranscriptionViewerWrapper() { +function TranscriptionViewerWrapper({ onFirstTextOutput }: { onFirstTextOutput: () => void }) { const peer = usePeerContext(); + const hasNotifiedRef = useRef(false); + + // Watch for first text output and notify parent + useEffect(() => { + if (peer?.textOutputData && peer.textOutputData.trim() && !hasNotifiedRef.current) { + // Check if it's not a warmup message + if (!peer.textOutputData.includes('__WARMUP_SENTINEL__')) { + hasNotifiedRef.current = true; + onFirstTextOutput(); + } + } + }, [peer?.textOutputData, onFirstTextOutput]); return ( void; onComfyUIReady: () => void; resolution: { width: number; height: number }; - backendUrl: string; onOutputStreamReady: (stream: MediaStream | null) => void; prompts: Prompt[] | null; } -function Stage({ connected, onStreamReady, onComfyUIReady, resolution, backendUrl, onOutputStreamReady, prompts }: StageProps) { +function Stage({ connected, onStreamReady, onComfyUIReady, resolution, onOutputStreamReady, prompts }: StageProps) { const { remoteStream, peerConnection } = usePeerContext(); const [frameRate, setFrameRate] = useState(0); // Add state and refs for tracking frames @@ -310,7 +321,7 @@ function Stage({ connected, onStreamReady, onComfyUIReady, resolution, backendUr )} {/* Add StreamControlIcon at the bottom right corner of the video box */} - + ); } @@ -343,7 +354,8 @@ export const Room = () => { const [isRecordingsPanelOpen, setIsRecordingsPanelOpen] = useState(false); // Transcription state - const [isTranscriptionPanelOpen, setIsTranscriptionPanelOpen] = useState(true); + const [isTranscriptionPanelOpen, setIsTranscriptionPanelOpen] = useState(false); + const [hasReceivedTextOutput, setHasReceivedTextOutput] 
= useState(false); // Helper to get timestamped filenames const getFilename = (type: 'input' | 'output', extension: string) => { @@ -417,6 +429,7 @@ export const Room = () => { const handleDisconnected = useCallback(() => { setIsConnected(false); setIsComfyUIReady(false); + setHasReceivedTextOutput(false); // Reset text output state showToast("Stream disconnected", "error"); }, [showToast]); @@ -570,6 +583,14 @@ export const Room = () => { const [isControlPanelOpen, setIsControlPanelOpen] = useState(false); + // Callback to handle first text output received + const handleFirstTextOutput = useCallback(() => { + if (!hasReceivedTextOutput && !isTranscriptionPanelOpen) { + setHasReceivedTextOutput(true); + setIsTranscriptionPanelOpen(true); + } + }, [hasReceivedTextOutput, isTranscriptionPanelOpen]); + return (
{ { {isConnected && (
- +
)} diff --git a/ui/src/components/stream-control.tsx b/ui/src/components/stream-control.tsx index 3cd26583b..e71e674f6 100644 --- a/ui/src/components/stream-control.tsx +++ b/ui/src/components/stream-control.tsx @@ -1,95 +1,68 @@ import * as React from "react"; -import { useState } from "react"; +import { useState, useCallback } from "react"; interface StreamControlProps { className?: string; - backendUrl: string; } -export function StreamControl({ className = "", backendUrl }: StreamControlProps) { +export function StreamControl({ className = "" }: StreamControlProps) { const [isLoading, setIsLoading] = useState(false); - - // Generate the stream URL with a unique streamID from the server - const getStreamUrl = async (): Promise => { - try { - // Validate backendUrl - if (!backendUrl) { - console.error("Backend URL is not configured."); - throw new Error("Backend URL is not configured in settings."); - } - // Parse base URL from the provided backendUrl - let baseUrl: string; + // Open popup which polls opener for stream and clones tracks locally (no postMessage MediaStream cloning) + const openWebRTCPopup = useCallback(() => { + const features = 'width=1024,height=1024'; + const getBasePath = (): string => { try { - // The origin property gives us "http://hostname:port" - baseUrl = new URL(backendUrl).origin; - } catch (e) { - console.error("Invalid backend URL configured:", backendUrl, e); - throw new Error(`Invalid backend URL configured: ${backendUrl}`); - } - - // Check if we're in a hosted environment by looking at the current URL - // This might need adjustment depending on how hosted environments are detected - const isHosted = window.location.pathname.includes('/live'); - const pathPrefix = isHosted ? 
'/live' : ''; - - // Request a unique streamID from the server using the derived baseUrl - const response = await fetch(`${baseUrl}${pathPrefix}/api/stream-token`, { - method: 'POST', - headers: { - 'Content-Type': 'application/json' + const scripts = document.querySelectorAll('script[src]'); + for (const s of Array.from(scripts)) { + const src = (s as HTMLScriptElement).src; + // Look for /_next/static/ which precedes hashed chunks + const idx = src.indexOf('/_next/static/'); + if (idx !== -1) { + const urlObj = new URL(src); + const before = urlObj.pathname.substring(0, urlObj.pathname.indexOf('/_next/static/')); + if (before !== undefined) { + return before.replace(/\/$/, ''); + } + } } - }); - - if (!response.ok) { - const errorData = await response.json().catch(() => ({})); - throw new Error(errorData.error || `Failed to get stream token: ${response.status}`); - } - - const data = await response.json(); - const streamId = data.stream_id; - - // Return the URL with the unique streamID, using the derived baseUrl - // Note: Token will be removed from URL in a later step - return `${baseUrl}${pathPrefix}/stream.html?token=${streamId}`; - } catch (error) { - console.error('Error getting stream URL:', error); - return null; - } - }; - - // Open the stream in a new window - const openStreamWindow = async () => { - try { - setIsLoading(true); - const streamUrl = await getStreamUrl(); - - if (!streamUrl) { - throw new Error('Failed to get stream URL'); - } - - const newWindow = window.open(streamUrl, 'ComfyStream OBS Capture', 'width=1024,height=1024'); - - if (!newWindow) { - throw new Error('Failed to open stream window. Please check your popup blocker settings.'); - } - } catch (error) { - console.error('Error opening stream window:', error); - alert(error instanceof Error ? error.message : 'Failed to open stream window. 
Please try again.'); - } finally { - setIsLoading(false); + } catch { /* ignore */ } + + try { + const { pathname } = window.location; + // If pathname points to a file (no trailing slash and contains a dot), strip file portion + if (/\.[a-zA-Z0-9]{2,8}$/.test(pathname.split('/').pop() || '')) { + const parts = pathname.split('/'); + parts.pop(); + return parts.join('/') || '/'; + } + return pathname.replace(/\/$/, ''); + } catch { /* ignore */ } + + return ''; + }; + + const basePath = getBasePath(); + const isDev = process.env.NEXT_PUBLIC_DEV === 'true'; + const previewPath = (basePath ? basePath : '') + (isDev ? '/webrtc-preview' : '/webrtc-preview.html'); + + const popup = window.open(previewPath, 'comfystream_preview', features) || window.open(previewPath); + if (!popup) { + alert('Popup blocked. Please allow popups for this site.'); } + }, []); + + const openStreamWindow = () => { + openWebRTCPopup(); }; return (