diff --git a/.devcontainer/post-create.sh b/.devcontainer/post-create.sh index 84787613b..bc2fdfa6a 100755 --- a/.devcontainer/post-create.sh +++ b/.devcontainer/post-create.sh @@ -5,7 +5,7 @@ cd /workspace/comfystream # Install Comfystream in editable mode. echo -e "\e[32mInstalling Comfystream in editable mode...\e[0m" -/workspace/miniconda3/envs/comfystream/bin/python3 -m pip install -e . --root-user-action=ignore > /dev/null +/workspace/miniconda3/envs/comfystream/bin/python3 -m pip install -e . -c src/comfystream/scripts/constraints.txt --root-user-action=ignore > /dev/null # Install npm packages if needed if [ ! -d "/workspace/comfystream/ui/node_modules" ]; then diff --git a/.editorconfig b/.editorconfig index c39bef288..17f476700 100644 --- a/.editorconfig +++ b/.editorconfig @@ -13,7 +13,9 @@ insert_final_newline = true insert_final_newline = unset [*.py] +indent_style = space indent_size = 4 +trim_trailing_whitespace = false [workflows/comfy*/*.json] insert_final_newline = unset diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml index 68f374747..1d79dfe0c 100644 --- a/.github/workflows/docker.yaml +++ b/.github/workflows/docker.yaml @@ -27,7 +27,7 @@ jobs: runs-on: [self-hosted, linux, gpu] steps: - name: Check out code - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 ref: ${{ github.event.pull_request.head.sha }} @@ -93,7 +93,7 @@ jobs: runs-on: [self-hosted, linux, amd64] steps: - name: Check out code - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 ref: ${{ github.event.pull_request.head.sha }} diff --git a/.github/workflows/opencv-cuda-artifact.yml b/.github/workflows/opencv-cuda-artifact.yml index c5deab076..df7071903 100644 --- a/.github/workflows/opencv-cuda-artifact.yml +++ b/.github/workflows/opencv-cuda-artifact.yml @@ -29,7 +29,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 ref: ${{ github.event.pull_request.head.sha || github.sha }} @@ -101,7 +101,7 @@ jobs: rm -rf test-extract - name: Upload OpenCV CUDA Release Artifact - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v5 with: name: opencv-cuda-release-python${{ env.PYTHON_VERSION }}-cuda${{ env.CUDA_VERSION }}-${{ github.sha }} path: | @@ -141,7 +141,7 @@ jobs: EOF - name: Upload Release Info - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v5 with: name: release-info-python${{ env.PYTHON_VERSION }}-cuda${{ env.CUDA_VERSION }}-${{ github.sha }} path: release-info.txt @@ -155,22 +155,22 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Download artifacts - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v6 with: name: opencv-cuda-release-python${{ env.PYTHON_VERSION }}-cuda${{ env.CUDA_VERSION }}-${{ github.sha }} path: ./artifacts - name: Download release info - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v6 with: name: release-info-python${{ env.PYTHON_VERSION }}-cuda${{ env.CUDA_VERSION }}-${{ github.sha }} path: ./artifacts - name: Create Release Draft - uses: softprops/action-gh-release@v1 + uses: softprops/action-gh-release@v2 with: tag_name: opencv-cuda-v${{ env.PYTHON_VERSION }}-${{ env.CUDA_VERSION }}-${{ github.run_number }} name: OpenCV CUDA Release - Python ${{ env.PYTHON_VERSION }} CUDA ${{ env.CUDA_VERSION }} diff --git a/.github/workflows/publish-comfyui-node.yaml b/.github/workflows/publish-comfyui-node.yaml new file mode 100644 index 000000000..0b609aee3 --- /dev/null +++ b/.github/workflows/publish-comfyui-node.yaml @@ -0,0 +1,26 @@ +name: Publish ComfyUI Custom Node + +on: + workflow_dispatch: + +permissions: + contents: read + issues: write + +jobs: + publish-comfyui-node: + name: Publish Custom Node to ComfyUI registry + runs-on: ubuntu-latest + # Ensure this only runs on main branch + if: ${{ github.ref == 'refs/heads/main' }} + steps: + - name: Check out code + uses: actions/checkout@v6 + with: + submodules: true + + - name: Publish Custom Node + uses: Comfy-Org/publish-node-action@v1 + with: + ## Add your own personal access token to your Github Repository secrets and reference it here. + personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index b46ff8ba2..58d270929 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -12,9 +12,9 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 - - uses: actions/setup-node@v5 + - uses: actions/setup-node@v6 with: node-version: 18 cache: npm @@ -38,7 +38,7 @@ jobs: cd - - name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v5 with: name: release-artifacts path: releases/ diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 98d27d262..98baa6cf3 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -19,20 +19,20 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out code - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 ref: ${{ github.event.pull_request.head.sha }} # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL - uses: github/codeql-action/init@v3 + uses: github/codeql-action/init@v4 with: languages: typescript,javascript,python config-file: ./.github/codeql-config.yaml - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v3 + uses: github/codeql-action/analyze@v4 editorconfig: @@ -40,7 +40,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out code - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: # Check https://github.com/livepeer/go-livepeer/pull/1891 # for ref value discussion @@ -59,14 +59,14 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out code - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: # Check https://github.com/livepeer/go-livepeer/pull/1891 # for ref value discussion ref: ${{ github.event.pull_request.head.sha }} - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.12' cache: pip diff --git a/.husky/pre-commit b/.husky/pre-commit new file mode 100755 index 000000000..d9a28aeb7 --- /dev/null +++ b/.husky/pre-commit @@ -0,0 +1,3 @@ +#!/bin/sh +cd ui && npx lint-staged + diff --git a/__init__.py b/__init__.py index 1215db9b3..7782d363f 100644 --- a/__init__.py +++ b/__init__.py @@ -4,4 +4,4 @@ # Import and expose node classes from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS -__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] +__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] diff --git a/benchmark.py b/benchmark.py index 359a8c237..8e096a77f 100644 --- a/benchmark.py +++ b/benchmark.py @@ -1,15 +1,16 @@ -import av +import argparse +import asyncio import json -import time -import torch import logging -import asyncio -import argparse +import time + +import av import numpy as np +import torch from comfystream.client import ComfyStreamClient -logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s') +logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger() @@ -22,14 +23,30 @@ def create_dummy_video_frame(width, height): async def main(): parser = argparse.ArgumentParser(description="Benchmark ComfyStreamClient workflow execution.") - parser.add_argument("--workflow-path", default="./workflows/comfystream/tensor-utils-example-api.json", help="Path to the workflow JSON file.") + parser.add_argument( + "--workflow-path", + default="./workflows/comfystream/tensor-utils-example-api.json", + help="Path to the workflow JSON file.", + ) parser.add_argument("--num-requests", type=int, default=100, help="Number of requests to send.") - parser.add_argument("--fps", type=float, default=None, help="Frames per second for FPS-based benchmarking.") - parser.add_argument("--cwd", default="/workspace/ComfyUI", help="Current working directory for ComfyStreamClient.") + parser.add_argument( + "--fps", type=float, default=None, help="Frames per second for FPS-based benchmarking." + ) + parser.add_argument( + "--cwd", + default="/workspace/ComfyUI", + help="Current working directory for ComfyStreamClient.", + ) parser.add_argument("--width", type=int, default=512, help="Width of dummy video frames.") parser.add_argument("--height", type=int, default=512, help="Height of dummy video frames.") - parser.add_argument("--verbose", action="store_true", help="Enable verbose logging (shows progress for each request).") - parser.add_argument("--warmup-runs", type=int, default=5, help="Number of warm-up runs before benchmarking.") + parser.add_argument( + "--verbose", + action="store_true", + help="Enable verbose logging (shows progress for each request).", + ) + parser.add_argument( + "--warmup-runs", type=int, default=5, help="Number of warm-up runs before benchmarking." + ) args = parser.parse_args() @@ -44,7 +61,9 @@ async def main(): client = ComfyStreamClient(cwd=args.cwd) await client.set_prompts([prompt]) - logger.info(f"Starting benchmark with workflow: {args.workflow_path}, requests: {args.num_requests}, resolution: {args.width}x{args.height}, warmup runs: {args.warmup_runs}") + logger.info( + f"Starting benchmark with workflow: {args.workflow_path}, requests: {args.num_requests}, resolution: {args.width}x{args.height}, warmup runs: {args.warmup_runs}" + ) if args.warmup_runs > 0: logger.info(f"Running {args.warmup_runs} warm-up runs...") @@ -66,11 +85,13 @@ async def main(): await client.get_video_output() request_end_time = time.time() round_trip_times.append(request_end_time - request_start_time) - logger.debug(f"Request {i+1}/{args.num_requests} completed in {round_trip_times[-1]:.4f} seconds") + logger.debug( + f"Request {i + 1}/{args.num_requests} completed in {round_trip_times[-1]:.4f} seconds" + ) end_time = time.time() total_time = end_time - start_time - output_fps = args.num_requests / total_time if total_time > 0 else float('inf') + output_fps = args.num_requests / total_time if total_time > 0 else float("inf") # Calculate percentiles for sequential mode p50_rtt = np.percentile(round_trip_times, 50) @@ -79,17 +100,17 @@ async def main(): p95_rtt = np.percentile(round_trip_times, 95) p99_rtt = np.percentile(round_trip_times, 99) - print("\n" + "="*40) + print("\n" + "=" * 40) print("FPS Results:") - print("="*40) + print("=" * 40) print(f"Total requests: {args.num_requests}") print(f"Total time: {total_time:.4f} seconds") print(f"Actual Output FPS:{output_fps:.2f}") print(f"Total requests: {args.num_requests}") print(f"Total time: {total_time:.4f} seconds") - print("\n" + "="*40) + print("\n" + "=" * 40) print("Latency Results:") - print("="*40) + print("=" * 40) print(f"Average: {np.mean(round_trip_times):.4f}") print(f"Min: {np.min(round_trip_times):.4f}") print(f"Max: {np.max(round_trip_times):.4f}") @@ -99,7 +120,6 @@ async def main(): print(f"P95: {p95_rtt:.4f}") print(f"P99: {p99_rtt:.4f}") - else: # This is mainly used to stress test the ComfyUI client, gives us a good idea on how frame skipping etc is working on the client end. logger.info(f"Running FPS-based benchmark at {args.fps} FPS...") @@ -118,9 +138,11 @@ async def collect_outputs_task(): last_output_receive_time = time.time() received_frames_count += 1 - logger.debug(f"Received output frame {received_frames_count} at {last_output_receive_time - start_time:.4f} seconds") + logger.debug( + f"Received output frame {received_frames_count} at {last_output_receive_time - start_time:.4f} seconds" + ) except asyncio.TimeoutError: - logger.debug(f"Output collection task timed out after waiting for 5 seconds.") + logger.debug("Output collection task timed out after waiting for 5 seconds.") break except Exception as e: logger.debug(f"Output collection task finished due to exception: {e}") @@ -140,7 +162,9 @@ async def collect_outputs_task(): request_send_time = time.time() client.put_video_input(frame) - logger.debug(f"Sent request {i+1}/{args.num_requests} at {request_send_time - start_time:.4f} seconds") + logger.debug( + f"Sent request {i + 1}/{args.num_requests} at {request_send_time - start_time:.4f} seconds" + ) await output_collector_task @@ -150,11 +174,11 @@ async def collect_outputs_task(): elif received_frames_count == 0: output_fps = 0.0 else: - output_fps = float('inf') + output_fps = float("inf") - print("\n" + "="*40) + print("\n" + "=" * 40) print("FPS Results:") - print("="*40) + print("=" * 40) print(f"Target Input FPS: {args.fps:.2f}") print(f"Actual Output FPS:{output_fps:.2f} ({received_frames_count} frames received)") print(f"Total requests: {args.num_requests}") @@ -162,4 +186,4 @@ async def collect_outputs_task(): if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/docker/Dockerfile.base b/docker/Dockerfile.base index 878f4c9bd..3ee9d95e8 100644 --- a/docker/Dockerfile.base +++ b/docker/Dockerfile.base @@ -46,6 +46,16 @@ RUN mkdir -p /workspace/comfystream && \ RUN conda run -n comfystream --no-capture-output pip install --upgrade pip && \ conda run -n comfystream --no-capture-output pip install wheel +# Remove system cuDNN to avoid version conflicts with PyTorch-bundled cuDNN +# The base image includes cuDNN 9.8, but PyTorch 2.7+ bundles cuDNN 9.7.1 +# Mixed versions cause CUDNN_STATUS_SUBLIBRARY_VERSION_MISMATCH errors +RUN apt-get remove --purge -y libcudnn9-cuda-12 libcudnn9-dev-cuda-12 || true && \ + apt-get autoremove -y && \ + rm -rf /var/lib/apt/lists/* + +# Install cuDNN 9.7.1 via conda to match PyTorch's bundled version +RUN conda install -n comfystream -y -c nvidia -c conda-forge cudnn=9.7.1 cuda-version=12.8 + # Copy only files needed for setup COPY ./src/comfystream/scripts /workspace/comfystream/src/comfystream/scripts COPY ./configs /workspace/comfystream/configs @@ -110,3 +120,8 @@ RUN conda config --set auto_activate_base false && \ RUN echo "conda activate comfystream" >> ~/.bashrc WORKDIR /workspace/comfystream + +# Run ComfyStream BYOC server from /workspace/ComfyUI within the comfystream conda env +ENTRYPOINT ["conda", "run", "--no-capture-output", "-n", "comfystream", "--cwd", "/workspace/ComfyUI", "python", "/workspace/comfystream/server/byoc.py"] +# Default args; can be overridden/appended at runtime +CMD ["--workspace", "/workspace/ComfyUI", "--host", "0.0.0.0", "--port", "8000"] diff --git a/example.py b/example.py index 6d68aac13..9f7567306 100644 --- a/example.py +++ b/example.py @@ -1,7 +1,8 @@ -import torch import asyncio import json +import torch + from comfystream.client import ComfyStreamClient diff --git a/install.py b/install.py index 385d160fd..7552ea8a9 100644 --- a/install.py +++ b/install.py @@ -1,19 +1,20 @@ -import os -import subprocess import argparse import logging +import os import pathlib +import subprocess import sys -import tarfile import tempfile import urllib.request -import toml import zipfile + +import toml from comfy_compatibility.workspace import auto_patch_workspace_and_restart -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) + def get_project_version(workspace: str) -> str: """Read project version from pyproject.toml""" pyproject_path = os.path.join(workspace, "pyproject.toml") @@ -25,19 +26,25 @@ def get_project_version(workspace: str) -> str: logger.error(f"Failed to read version from pyproject.toml: {e}") return "unknown" + def download_and_extract_ui_files(version: str): """Download and extract UI files to the workspace""" output_dir = os.path.join(os.getcwd(), "nodes", "web", "static") pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True) - base_url = urllib.parse.urljoin("https://github.com/livepeer/comfystream/releases/download/", f"v{version}/comfystream-uikit.zip") - fallback_url = "https://github.com/livepeer/comfystream/releases/latest/download/comfystream-uikit.zip" - + base_url = urllib.parse.urljoin( + "https://github.com/livepeer/comfystream/releases/download/", + f"v{version}/comfystream-uikit.zip", + ) + fallback_url = ( + "https://github.com/livepeer/comfystream/releases/latest/download/comfystream-uikit.zip" + ) + # Create a temporary directory instead of a temporary file with tempfile.TemporaryDirectory() as temp_dir: # Define the path for the downloaded file download_path = os.path.join(temp_dir, "comfystream-uikit.zip") - + # Download zip file logger.info(f"Downloading {base_url}") try: @@ -53,23 +60,27 @@ def download_and_extract_ui_files(version: str): else: logger.error(f"Error downloading package: {e}") raise - + # Extract contents try: logger.info(f"Extracting files to {output_dir}") - with zipfile.ZipFile(download_path, 'r') as zip_ref: + with zipfile.ZipFile(download_path, "r") as zip_ref: zip_ref.extractall(path=output_dir) except Exception as e: logger.error(f"Error extracting files: {e}") raise + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Install custom node requirements") parser.add_argument( - "--workspace", default=os.environ.get('COMFY_UI_WORKSPACE', None), required=False, help="Set Comfy workspace" + "--workspace", + default=os.environ.get("COMFY_UI_WORKSPACE", None), + required=False, + help="Set Comfy workspace", ) args = parser.parse_args() - + workspace = args.workspace if workspace is None: # Look up to 3 directories up for ComfyUI @@ -93,12 +104,14 @@ def download_and_extract_ui_files(version: str): subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "."]) if workspace is None: - logger.warning("No ComfyUI workspace found. Please specify a valid workspace path to fully install") - + logger.warning( + "No ComfyUI workspace found. Please specify a valid workspace path to fully install" + ) + if workspace is not None: logger.info("Patching ComfyUI workspace...") auto_patch_workspace_and_restart(workspace) - + logger.info("Downloading and extracting UI files...") version = get_project_version(os.getcwd()) download_and_extract_ui_files(version) diff --git a/nodes/__init__.py b/nodes/__init__.py index e3ca8f2ae..a590e9ccc 100644 --- a/nodes/__init__.py +++ b/nodes/__init__.py @@ -1,5 +1,7 @@ """ComfyStream nodes package""" -from comfy_compatibility.imports import ImportContext, SITE_PACKAGES, MAIN_PY + +from comfy_compatibility.imports import MAIN_PY, SITE_PACKAGES, ImportContext + with ImportContext("comfy", "comfy_extras", order=[SITE_PACKAGES, MAIN_PY]): from .audio_utils import * from .tensor_utils import * @@ -13,15 +15,16 @@ # Import and update mappings from submodules for module in [audio_utils, tensor_utils, video_stream_utils, api, web]: - if hasattr(module, 'NODE_CLASS_MAPPINGS'): + if hasattr(module, "NODE_CLASS_MAPPINGS"): NODE_CLASS_MAPPINGS.update(module.NODE_CLASS_MAPPINGS) - if hasattr(module, 'NODE_DISPLAY_NAME_MAPPINGS'): + if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS"): NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS) # Web directory for UI components import os + WEB_DIRECTORY = os.path.join(os.path.dirname(os.path.realpath(__file__)), "web") NODE_DISPLAY_NAME_MAPPINGS["ComfyStreamLauncher"] = "Launch ComfyStream 🚀" -__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] +__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] diff --git a/nodes/api/__init__.py b/nodes/api/__init__.py index ebc480859..eb6dbdf81 100644 --- a/nodes/api/__init__.py +++ b/nodes/api/__init__.py @@ -1,17 +1,22 @@ """ComfyStream API implementation""" + +import logging import os -import webbrowser -from server import PromptServer -from aiohttp import web import pathlib -import logging +import webbrowser + import aiohttp -from ..server_manager import LocalComfyStreamServer +from aiohttp import web + +from server import PromptServer + from .. import settings_storage +from ..server_manager import LocalComfyStreamServer routes = None server_manager = None + # Middleware to add Cache-Control: no-cache for index.html @web.middleware async def cache_control_middleware(request, handler): @@ -20,19 +25,22 @@ async def cache_control_middleware(request, handler): target_path = f"{STATIC_ROUTE}/index.html" # Log comparison details if request.path == target_path: - response.headers['Cache-Control'] = 'no-cache' - logging.debug(f"[CacheMiddleware] Added Cache-Control: no-cache for {request.path}") # Kept debug log + response.headers["Cache-Control"] = "no-cache" + logging.debug( + f"[CacheMiddleware] Added Cache-Control: no-cache for {request.path}" + ) # Kept debug log return response + # Only set up routes if we're in the main ComfyUI instance -if hasattr(PromptServer.instance, 'routes') and hasattr(PromptServer.instance.routes, 'static'): +if hasattr(PromptServer.instance, "routes") and hasattr(PromptServer.instance.routes, "static"): routes = PromptServer.instance.routes - + # Get the path to the static build directory - STATIC_DIR = pathlib.Path(__file__).parent.parent.parent / "nodes" / "web" / "static" - + STATIC_DIR = pathlib.Path(__file__).parent.parent.parent / "nodes" / "web" / "static" + # Dynamically determine the extension name from the directory structure - extension_name = "comfystream" # Define a local default for the try/except block + extension_name = "comfystream" # Define a local default for the try/except block try: # Get the parent directory of the current file # Then navigate up to get the extension root directory @@ -41,7 +49,9 @@ async def cache_control_middleware(request, handler): extension_name = EXTENSION_ROOT.name logging.info(f"Detected extension name: {extension_name}") except Exception as e: - logging.warning(f"Failed to get extension name dynamically: {e}, using fallback '{extension_name}'") + logging.warning( + f"Failed to get extension name dynamically: {e}, using fallback '{extension_name}'" + ) # Fallback name is already set by the initial local default # Define module-level constants AFTER determination @@ -53,30 +63,34 @@ async def cache_control_middleware(request, handler): routes.static(STATIC_ROUTE, str(STATIC_DIR), append_version=False, follow_symlinks=True) # Add the cache control middleware to the app - if hasattr(PromptServer.instance, 'app'): + if hasattr(PromptServer.instance, "app"): PromptServer.instance.app.middlewares.append(cache_control_middleware) logging.info(f"Added ComfyStream cache control middleware for {STATIC_ROUTE}/index.html") else: - logging.warning("Could not add ComfyStream cache control middleware: PromptServer.instance.app not found.") + logging.warning( + "Could not add ComfyStream cache control middleware: PromptServer.instance.app not found." + ) # Create server manager instance server_manager = LocalComfyStreamServer() - - @routes.get('/comfystream/extension_info') + + @routes.get("/comfystream/extension_info") async def get_extension_info(request): """Return extension information including name and paths""" try: - return web.json_response({ - "success": True, - "extension_name": EXTENSION_NAME, - "static_route": STATIC_ROUTE, - "ui_url": f"{STATIC_ROUTE}/index.html" - }) + return web.json_response( + { + "success": True, + "extension_name": EXTENSION_NAME, + "static_route": STATIC_ROUTE, + "ui_url": f"{STATIC_ROUTE}/index.html", + } + ) except Exception as e: logging.error(f"Error getting extension info: {str(e)}") return web.json_response({"success": False, "error": str(e)}, status=500) - - @routes.post('/api/offer') + + @routes.post("/api/offer") async def proxy_offer(request): """Proxy offer requests to the ComfyStream server""" try: @@ -89,64 +103,65 @@ async def proxy_offer(request): async with session.post( f"{target_url}/offer", json={"prompts": data.get("prompts"), "offer": data.get("offer")}, - headers={"Content-Type": "application/json"} + headers={"Content-Type": "application/json"}, ) as response: if not response.ok: return web.json_response( - {"error": f"Server error: {response.status}"}, - status=response.status + {"error": f"Server error: {response.status}"}, status=response.status ) return web.json_response(await response.json()) except Exception as e: logging.error(f"Error proxying offer: {str(e)}") return web.json_response({"error": str(e)}, status=500) - @routes.post('/comfystream/control') + @routes.post("/comfystream/control") async def control_server(request): """Handle server control requests""" try: data = await request.json() action = data.get("action") settings = data.get("settings", {}) - + # Extract host and port from settings if provided host = settings.get("host") if settings else None port = settings.get("port") if settings else None - + if action == "status": # Simply return the current server status - return web.json_response({ - "success": True, - "status": server_manager.get_status() - }) + return web.json_response({"success": True, "status": server_manager.get_status()}) elif action == "start": success = await server_manager.start(port=port, host=host) - return web.json_response({ - "success": success, - "status": server_manager.get_status() - }) + return web.json_response( + {"success": success, "status": server_manager.get_status()} + ) elif action == "stop": try: success = await server_manager.stop() - return web.json_response({ - "success": success, - "status": server_manager.get_status() - }) + return web.json_response( + {"success": success, "status": server_manager.get_status()} + ) except Exception as e: logging.error(f"Error stopping server: {str(e)}") # Force cleanup even if the normal stop fails server_manager.cleanup() - return web.json_response({ - "success": True, - "status": {"running": False, "port": None, "host": None, "pid": None, "type": "local"}, - "message": "Forced server shutdown due to error" - }) + return web.json_response( + { + "success": True, + "status": { + "running": False, + "port": None, + "host": None, + "pid": None, + "type": "local", + }, + "message": "Forced server shutdown due to error", + } + ) elif action == "restart": success = await server_manager.restart(port=port, host=host) - return web.json_response({ - "success": success, - "status": server_manager.get_status() - }) + return web.json_response( + {"success": success, "status": server_manager.get_status()} + ) else: return web.json_response({"error": "Invalid action"}, status=400) except Exception as e: @@ -155,17 +170,25 @@ async def control_server(request): if data and data.get("action") == "stop": try: server_manager.cleanup() - return web.json_response({ - "success": True, - "status": {"running": False, "port": None, "host": None, "pid": None, "type": "local"}, - "message": "Forced server shutdown due to error" - }) + return web.json_response( + { + "success": True, + "status": { + "running": False, + "port": None, + "host": None, + "pid": None, + "type": "local", + }, + "message": "Forced server shutdown due to error", + } + ) except Exception as cleanup_error: logging.error(f"Error during forced cleanup: {str(cleanup_error)}") - + return web.json_response({"error": str(e)}, status=500) - @routes.get('/comfystream/settings') + @routes.get("/comfystream/settings") async def get_settings(request): """Get ComfyStream settings""" try: @@ -174,63 +197,59 @@ async def get_settings(request): except Exception as e: logging.error(f"Error getting settings: {str(e)}") return web.json_response({"error": str(e)}, status=500) - - @routes.post('/comfystream/settings') + + @routes.post("/comfystream/settings") async def update_settings(request): """Update ComfyStream settings""" try: data = await request.json() success = settings_storage.update_settings(data) - return web.json_response({ - "success": success, - "settings": settings_storage.load_settings() - }) + return web.json_response( + {"success": success, "settings": settings_storage.load_settings()} + ) except Exception as e: logging.error(f"Error updating settings: {str(e)}") return web.json_response({"error": str(e)}, status=500) - - @routes.post('/comfystream/settings/configuration') + + @routes.post("/comfystream/settings/configuration") async def manage_configuration(request): """Add, remove, or select a configuration""" try: data = await request.json() action = data.get("action") - + if action == "add": name = data.get("name") host = data.get("host") port = data.get("port") if not name or not host or not port: return web.json_response({"error": "Missing required parameters"}, status=400) - + success = settings_storage.add_configuration(name, host, port) - return web.json_response({ - "success": success, - "settings": settings_storage.load_settings() - }) - + return web.json_response( + {"success": success, "settings": settings_storage.load_settings()} + ) + elif action == "remove": index = data.get("index") if index is None: return web.json_response({"error": "Missing index parameter"}, status=400) - + success = settings_storage.remove_configuration(index) - return web.json_response({ - "success": success, - "settings": settings_storage.load_settings() - }) - + return web.json_response( + {"success": success, "settings": settings_storage.load_settings()} + ) + elif action == "select": index = data.get("index") if index is None: return web.json_response({"error": "Missing index parameter"}, status=400) - + success = settings_storage.select_configuration(index) - return web.json_response({ - "success": success, - "settings": settings_storage.load_settings() - }) - + return web.json_response( + {"success": success, "settings": settings_storage.load_settings()} + ) + else: return web.json_response({"error": "Invalid action"}, status=400) except Exception as e: diff --git a/nodes/audio_utils/__init__.py b/nodes/audio_utils/__init__.py index bc090689a..f9c84d1f8 100644 --- a/nodes/audio_utils/__init__.py +++ b/nodes/audio_utils/__init__.py @@ -1,6 +1,6 @@ from .load_audio_tensor import LoadAudioTensor -from .save_audio_tensor import SaveAudioTensor from .pitch_shift import PitchShifter +from .save_audio_tensor import SaveAudioTensor NODE_CLASS_MAPPINGS = { "LoadAudioTensor": LoadAudioTensor, diff --git a/nodes/audio_utils/load_audio_tensor.py b/nodes/audio_utils/load_audio_tensor.py index eed09eea9..919a75991 100644 --- a/nodes/audio_utils/load_audio_tensor.py +++ b/nodes/audio_utils/load_audio_tensor.py @@ -1,8 +1,10 @@ +import queue + import numpy as np import torch -import queue + from comfystream import tensor_cache -from comfystream.exceptions import ComfyStreamInputTimeoutError, ComfyStreamAudioBufferError +from comfystream.exceptions import ComfyStreamAudioBufferError, ComfyStreamInputTimeoutError class LoadAudioTensor: @@ -11,33 +13,36 @@ class LoadAudioTensor: RETURN_NAMES = ("audio",) FUNCTION = "execute" DESCRIPTION = "Load audio tensor from ComfyStream input with timeout." - + def __init__(self): self.audio_buffer = np.empty(0, dtype=np.int16) self.buffer_samples = None self.sample_rate = None self.leftover = np.empty(0, dtype=np.int16) - + @classmethod def INPUT_TYPES(cls): return { "required": { - "buffer_size": ("FLOAT", { - "default": 500.0, - "tooltip": "Audio buffer size in milliseconds" - }), + "buffer_size": ( + "FLOAT", + {"default": 500.0, "tooltip": "Audio buffer size in milliseconds"}, + ), }, "optional": { - "timeout_seconds": ("FLOAT", { - "default": 1.0, - "min": 0.1, - "max": 30.0, - "step": 0.1, - "tooltip": "Timeout in seconds" - }), - } + "timeout_seconds": ( + "FLOAT", + { + "default": 1.0, + "min": 0.1, + "max": 30.0, + "step": 0.1, + "tooltip": "Timeout in seconds", + }, + ), + }, } - + @classmethod def IS_CHANGED(cls, **kwargs): return float("nan") @@ -52,37 +57,45 @@ def execute(self, buffer_size: float, timeout_seconds: float = 1.0): self.leftover = frame.side_data.input except queue.Empty: raise ComfyStreamInputTimeoutError("audio", timeout_seconds) - + # Use leftover data if available if self.leftover.shape[0] >= self.buffer_samples: - buffered_audio = self.leftover[:self.buffer_samples] - self.leftover = self.leftover[self.buffer_samples:] + buffered_audio = self.leftover[: self.buffer_samples] + self.leftover = self.leftover[self.buffer_samples :] else: # Collect more audio chunks chunks = [self.leftover] if self.leftover.size > 0 else [] total_samples = self.leftover.shape[0] - + while total_samples < self.buffer_samples: try: frame = tensor_cache.audio_inputs.get(block=True, timeout=timeout_seconds) if frame.sample_rate != self.sample_rate: - raise ValueError(f"Sample rate mismatch: expected {self.sample_rate}Hz, got {frame.sample_rate}Hz") + raise ValueError( + f"Sample rate mismatch: expected {self.sample_rate}Hz, got {frame.sample_rate}Hz" + ) chunks.append(frame.side_data.input) total_samples += frame.side_data.input.shape[0] except queue.Empty: - raise ComfyStreamAudioBufferError(timeout_seconds, self.buffer_samples, total_samples) - + raise ComfyStreamAudioBufferError( + timeout_seconds, self.buffer_samples, total_samples + ) + merged_audio = np.concatenate(chunks, dtype=np.int16) - buffered_audio = merged_audio[:self.buffer_samples] - self.leftover = merged_audio[self.buffer_samples:] if merged_audio.shape[0] > self.buffer_samples else np.empty(0, dtype=np.int16) - + buffered_audio = merged_audio[: self.buffer_samples] + self.leftover = ( + merged_audio[self.buffer_samples :] + if merged_audio.shape[0] > self.buffer_samples + else np.empty(0, dtype=np.int16) + ) + # Convert to ComfyUI AUDIO format waveform_tensor = torch.from_numpy(buffered_audio.astype(np.float32) / 32768.0) - + # Ensure proper tensor shape: (batch, channels, samples) if waveform_tensor.dim() == 1: waveform_tensor = waveform_tensor.unsqueeze(0).unsqueeze(0) elif waveform_tensor.dim() == 2: waveform_tensor = waveform_tensor.unsqueeze(0) - - return ({"waveform": waveform_tensor, "sample_rate": self.sample_rate},) \ No newline at end of file + + return ({"waveform": waveform_tensor, "sample_rate": self.sample_rate},) diff --git a/nodes/audio_utils/pitch_shift.py b/nodes/audio_utils/pitch_shift.py index 2fba9ee59..6e2456a2a 100644 --- a/nodes/audio_utils/pitch_shift.py +++ b/nodes/audio_utils/pitch_shift.py @@ -1,23 +1,19 @@ -import numpy as np import librosa +import numpy as np import torch + class PitchShifter: CATEGORY = "audio_utils" RETURN_TYPES = ("AUDIO",) FUNCTION = "execute" - + @classmethod def INPUT_TYPES(cls): return { "required": { "audio": ("AUDIO",), - "pitch_shift": ("FLOAT", { - "default": 4.0, - "min": 0.0, - "max": 12.0, - "step": 0.5 - }), + "pitch_shift": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 12.0, "step": 0.5}), } } @@ -29,37 +25,36 @@ def execute(self, audio, pitch_shift): # Extract waveform and sample rate from AUDIO format waveform = audio["waveform"] sample_rate = audio["sample_rate"] - + # Convert tensor to numpy and ensure proper format for librosa if isinstance(waveform, torch.Tensor): audio_numpy = waveform.squeeze().cpu().numpy() else: audio_numpy = waveform.squeeze() - + # Ensure float32 format and proper normalization for librosa processing if audio_numpy.dtype != np.float32: audio_numpy = audio_numpy.astype(np.float32) - + # Check if data needs normalization (librosa expects [-1, 1] range) max_abs_val = np.abs(audio_numpy).max() if max_abs_val > 1.0: # Data appears to be in int16 range, normalize it audio_numpy = audio_numpy / 32768.0 - + # Apply pitch shift - shifted_audio = librosa.effects.pitch_shift(y=audio_numpy, sr=sample_rate, n_steps=pitch_shift) - + shifted_audio = librosa.effects.pitch_shift( + y=audio_numpy, sr=sample_rate, n_steps=pitch_shift + ) + # Convert back to tensor and restore original shape shifted_tensor = torch.from_numpy(shifted_audio).float() if waveform.dim() == 3: # (batch, channels, samples) shifted_tensor = shifted_tensor.unsqueeze(0).unsqueeze(0) - elif waveform.dim() == 2: # (channels, samples) + elif waveform.dim() == 2: # (channels, samples) shifted_tensor = shifted_tensor.unsqueeze(0) - + # Return AUDIO format - result_audio = { - "waveform": shifted_tensor, - "sample_rate": sample_rate - } - + result_audio = {"waveform": shifted_tensor, "sample_rate": sample_rate} + return (result_audio,) diff --git a/nodes/audio_utils/save_audio_tensor.py b/nodes/audio_utils/save_audio_tensor.py index 6b7b0281c..eed9bf6ad 100644 --- a/nodes/audio_utils/save_audio_tensor.py +++ b/nodes/audio_utils/save_audio_tensor.py @@ -1,20 +1,17 @@ import numpy as np + from comfystream import tensor_cache + class SaveAudioTensor: CATEGORY = "audio_utils" RETURN_TYPES = () FUNCTION = "execute" OUTPUT_NODE = True - @classmethod def INPUT_TYPES(s): - return { - "required": { - "audio": ("AUDIO",) - } - } + return {"required": {"audio": ("AUDIO",)}} @classmethod def IS_CHANGED(s): @@ -23,24 +20,24 @@ def IS_CHANGED(s): def execute(self, audio): # Extract waveform tensor from AUDIO format waveform = audio["waveform"] - + # Convert to numpy and flatten for pipeline compatibility - if hasattr(waveform, 'cpu'): + if hasattr(waveform, "cpu"): # PyTorch tensor waveform_numpy = waveform.squeeze().cpu().numpy() else: # Already numpy waveform_numpy = waveform.squeeze() - + # Ensure 1D array for pipeline buffer concatenation if waveform_numpy.ndim > 1: waveform_numpy = waveform_numpy.flatten() - + # Convert to int16 if needed (pipeline expects int16) if waveform_numpy.dtype == np.float32: waveform_numpy = (waveform_numpy * 32767).astype(np.int16) elif waveform_numpy.dtype != np.int16: waveform_numpy = waveform_numpy.astype(np.int16) - + tensor_cache.audio_outputs.put_nowait(waveform_numpy) return (audio,) diff --git a/nodes/server_manager.py b/nodes/server_manager.py index 9b55f4f57..cd29bc07f 100644 --- a/nodes/server_manager.py +++ b/nodes/server_manager.py @@ -1,88 +1,83 @@ """ComfyStream server management module""" -import os -import sys -import subprocess -import socket -import signal + +import asyncio import atexit import logging -import urllib.request +import os +import socket +import subprocess +import sys +import threading import urllib.error -from pathlib import Path -import time -import asyncio +import urllib.request from abc import ABC, abstractmethod -import threading +from pathlib import Path # Configure logging to output to console -logging.basicConfig( - level=logging.INFO, - format='[ComfyStream] %(message)s', - stream=sys.stdout -) +logging.basicConfig(level=logging.INFO, format="[ComfyStream] %(message)s", stream=sys.stdout) # Set up Windows specific event loop policy -if sys.platform == 'win32': - if hasattr(asyncio, 'WindowsSelectorEventLoopPolicy'): +if sys.platform == "win32": + if hasattr(asyncio, "WindowsSelectorEventLoopPolicy"): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - + class ComfyStreamServerBase(ABC): """Abstract base class for ComfyStream server management""" - + def __init__(self, host="0.0.0.0", port=None): self.host = host self.port = port self.is_running = False logging.info(f"Initializing {self.__class__.__name__}") - + @abstractmethod async def start(self, port=None, host=None) -> bool: """Start the ComfyStream server - + Args: port: Optional port to use. If None, implementation should choose a port. host: Optional host to use. If None, implementation should use the default host. - + Returns: bool: True if server started successfully, False otherwise """ pass - + @abstractmethod async def stop(self) -> bool: """Stop the ComfyStream server - + Returns: bool: True if server stopped successfully, False otherwise """ pass - + @abstractmethod def get_status(self) -> dict: """Get current server status - + Returns: dict: Server status information """ pass - + @abstractmethod def check_server_health(self) -> bool: """Check if server is responding to health checks - + Returns: bool: True if server is healthy, False otherwise """ pass - + async def restart(self, port=None, host=None) -> bool: """Restart the ComfyStream server - + Args: port: Optional port to use. If None, use the current port. host: Optional host to use. If None, use the current host. - + Returns: bool: True if server restarted successfully, False otherwise """ @@ -93,11 +88,18 @@ async def restart(self, port=None, host=None) -> bool: await self.stop() return await self.start(port=port_to_use, host=host_to_use) + class LocalComfyStreamServer(ComfyStreamServerBase): """Local ComfyStream server implementation""" - - def __init__(self, host="0.0.0.0", start_port=8889, max_port=65535, - health_check_timeout=30, health_check_interval=1): + + def __init__( + self, + host="0.0.0.0", + start_port=8889, + max_port=65535, + health_check_timeout=30, + health_check_interval=1, + ): super().__init__(host=host) self.process = None self.start_port = start_port @@ -105,7 +107,7 @@ def __init__(self, host="0.0.0.0", start_port=8889, max_port=65535, self.health_check_timeout = health_check_timeout self.health_check_interval = health_check_interval atexit.register(self.cleanup) - + def find_available_port(self): """Find an available port starting from start_port""" port = self.start_port @@ -123,7 +125,7 @@ def check_server_health(self): """Check if server is responding to health checks""" if not self.port: return False - + url = f"http://{self.host}:{self.port}" try: response = urllib.request.urlopen(url) @@ -133,7 +135,7 @@ def check_server_health(self): def log_subprocess_output(self, pipe, level): """Log the output from the subprocess to the logger.""" - for line in iter(pipe.readline, b''): + for line in iter(pipe.readline, b""): logging.log(level, line.decode().strip()) async def start(self, port=None, host=None): @@ -146,37 +148,48 @@ async def start(self, port=None, host=None): self.port = port or self.find_available_port() if host is not None: self.host = host - + # Get the path to the ComfyStream server directory and script server_dir = Path(__file__).parent.parent / "server" server_script = server_dir / "app.py" logging.info(f"Server script: {server_script}") - + # Get ComfyUI workspace path (which is where we'll run from) comfyui_workspace = Path(__file__).parent.parent.parent.parent logging.info(f"ComfyUI workspace: {comfyui_workspace}") - + # Use the system Python (which should have ComfyStream installed) - cmd = [sys.executable, "-u", str(server_script), - "--port", str(self.port), - "--host", str(self.host), - "--workspace", str(comfyui_workspace)] - + cmd = [ + sys.executable, + "-u", + str(server_script), + "--port", + str(self.port), + "--host", + str(self.host), + "--workspace", + str(comfyui_workspace), + ] + logging.info(f"Starting server with command: {' '.join(cmd)}") - + # Start process with output going to pipes self.process = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=str(comfyui_workspace), # Run from ComfyUI root - env={**os.environ, 'PYTHONUNBUFFERED': '1'} + env={**os.environ, "PYTHONUNBUFFERED": "1"}, ) - + # Start threads to log stdout and stderr - threading.Thread(target=self.log_subprocess_output, args=(self.process.stdout, logging.INFO)).start() - threading.Thread(target=self.log_subprocess_output, args=(self.process.stderr, logging.ERROR)).start() - + threading.Thread( + target=self.log_subprocess_output, args=(self.process.stdout, logging.INFO) + ).start() + threading.Thread( + target=self.log_subprocess_output, args=(self.process.stderr, logging.ERROR) + ).start() + # Wait for server to start responding logging.info("Waiting for server to start...") for _ in range(self.health_check_timeout): @@ -185,15 +198,19 @@ async def start(self, port=None, host=None): break await asyncio.sleep(self.health_check_interval) else: - raise RuntimeError(f"Server failed to start after {self.health_check_timeout} seconds") - + raise RuntimeError( + f"Server failed to start after {self.health_check_timeout} seconds" + ) + if self.process.poll() is not None: raise RuntimeError(f"Server failed to start (exit code: {self.process.poll()})") - + self.is_running = True - logging.info(f"ComfyStream server started on port {self.port} (PID: {self.process.pid})") + logging.info( + f"ComfyStream server started on port {self.port} (PID: {self.process.pid})" + ) return True - + except Exception as e: logging.error(f"Error starting ComfyStream server: {str(e)}") self.cleanup() @@ -204,7 +221,7 @@ async def stop(self): if not self.is_running: logging.info("Server is not running") return False - + try: self.cleanup() logging.info("ComfyStream server stopped") @@ -220,7 +237,7 @@ def get_status(self): "port": self.port, "host": self.host, "pid": self.process.pid if self.process else None, - "type": "local" + "type": "local", } logging.info(f"Server status: {status}") return status @@ -247,4 +264,4 @@ def cleanup(self): except Exception as e: logging.error(f"Error cleaning up server process: {str(e)}") self.process = None - self.is_running = False \ No newline at end of file + self.is_running = False diff --git a/nodes/settings_storage.py b/nodes/settings_storage.py index 5d888b611..8b9ecf54b 100644 --- a/nodes/settings_storage.py +++ b/nodes/settings_storage.py @@ -1,53 +1,53 @@ """ComfyStream server-side settings storage module""" -import os + import json import logging +import os import threading from pathlib import Path # Configure logging -logging.basicConfig( - level=logging.INFO, - format='[ComfyStream Settings] %(message)s' -) +logging.basicConfig(level=logging.INFO, format="[ComfyStream Settings] %(message)s") # Default settings DEFAULT_SETTINGS = { "host": "0.0.0.0", "port": 8889, "configurations": [], - "selectedConfigIndex": -1 + "selectedConfigIndex": -1, } # Lock for thread-safe file operations settings_lock = threading.Lock() + def get_settings_file_path(): """Get the path to the settings file""" # Store settings in the extension directory extension_dir = Path(__file__).parent.parent settings_dir = extension_dir / "settings" - + # Create settings directory if it doesn't exist os.makedirs(settings_dir, exist_ok=True) - + return settings_dir / "comfystream_settings.json" + def load_settings(): """Load settings from file""" settings_file = get_settings_file_path() - + with settings_lock: try: if settings_file.exists(): - with open(settings_file, 'r') as f: + with open(settings_file, "r") as f: settings = json.load(f) - + # Ensure all default keys exist for key, value in DEFAULT_SETTINGS.items(): if key not in settings: settings[key] = value - + return settings else: return DEFAULT_SETTINGS.copy() @@ -55,53 +55,57 @@ def load_settings(): logging.error(f"Error loading settings: {str(e)}") return DEFAULT_SETTINGS.copy() + def save_settings(settings): """Save settings to file""" settings_file = get_settings_file_path() - + with settings_lock: try: - with open(settings_file, 'w') as f: + with open(settings_file, "w") as f: json.dump(settings, f, indent=2) return True except Exception as e: logging.error(f"Error saving settings: {str(e)}") return False + def update_settings(new_settings): """Update settings with new values""" current_settings = load_settings() - + # Update only the keys that are provided for key, value in new_settings.items(): current_settings[key] = value - + return save_settings(current_settings) + def add_configuration(name, host, port): """Add a new configuration""" settings = load_settings() - + # Create the new configuration config = {"name": name, "host": host, "port": port} - + # Add to configurations list settings["configurations"].append(config) - + # Save updated settings return save_settings(settings) + def remove_configuration(index): """Remove a configuration by index""" settings = load_settings() - + if index < 0 or index >= len(settings["configurations"]): logging.error(f"Invalid configuration index: {index}") return False - + # Remove the configuration settings["configurations"].pop(index) - + # Update selectedConfigIndex if needed if settings["selectedConfigIndex"] == index: # The selected config was deleted @@ -109,25 +113,26 @@ def remove_configuration(index): elif settings["selectedConfigIndex"] > index: # The selected config is after the deleted one, adjust index settings["selectedConfigIndex"] -= 1 - + # Save updated settings return save_settings(settings) + def select_configuration(index): """Select a configuration by index""" settings = load_settings() - + if index == -1 or (index >= 0 and index < len(settings["configurations"])): settings["selectedConfigIndex"] = index - + # If a valid configuration is selected, update host and port if index >= 0: config = settings["configurations"][index] settings["host"] = config["host"] settings["port"] = config["port"] - + # Save updated settings return save_settings(settings) else: logging.error(f"Invalid configuration index: {index}") - return False \ No newline at end of file + return False diff --git a/nodes/tensor_utils/load_tensor.py b/nodes/tensor_utils/load_tensor.py index 9923a996a..a2fb59408 100644 --- a/nodes/tensor_utils/load_tensor.py +++ b/nodes/tensor_utils/load_tensor.py @@ -1,5 +1,5 @@ -import torch import queue + from comfystream import tensor_cache from comfystream.exceptions import ComfyStreamInputTimeoutError @@ -14,13 +14,16 @@ class LoadTensor: def INPUT_TYPES(cls): return { "optional": { - "timeout_seconds": ("FLOAT", { - "default": 1.0, - "min": 0.1, - "max": 30.0, - "step": 0.1, - "tooltip": "Timeout in seconds" - }), + "timeout_seconds": ( + "FLOAT", + { + "default": 1.0, + "min": 0.1, + "max": 30.0, + "step": 0.1, + "tooltip": "Timeout in seconds", + }, + ), } } diff --git a/nodes/tensor_utils/save_text_tensor.py b/nodes/tensor_utils/save_text_tensor.py index 098887e07..defb3aaa8 100644 --- a/nodes/tensor_utils/save_text_tensor.py +++ b/nodes/tensor_utils/save_text_tensor.py @@ -1,5 +1,6 @@ from comfystream import tensor_cache + class SaveTextTensor: CATEGORY = "text_utils" RETURN_TYPES = () @@ -13,8 +14,11 @@ def INPUT_TYPES(s): "data": ("STRING",), # Accept text string as input. }, "optional": { - "remove_linebreaks": ("BOOLEAN", {"default": True}), # Remove whitespace and line breaks - } + "remove_linebreaks": ( + "BOOLEAN", + {"default": True}, + ), # Remove whitespace and line breaks + }, } @classmethod @@ -23,7 +27,7 @@ def IS_CHANGED(s, **kwargs): def execute(self, data, remove_linebreaks=True): if remove_linebreaks: - result_text = data.replace('\n', '').replace('\r', '') + result_text = data.replace("\n", "").replace("\r", "") else: result_text = data tensor_cache.text_outputs.put_nowait(result_text) diff --git a/nodes/video_stream_utils/__init__.py b/nodes/video_stream_utils/__init__.py index 787f84ce9..f48631984 100644 --- a/nodes/video_stream_utils/__init__.py +++ b/nodes/video_stream_utils/__init__.py @@ -1,6 +1,6 @@ """Video stream utility nodes for ComfyStream""" -from .primary_input_load_image import PrimaryInputLoadImage +from .primary_input_load_image import PrimaryInputLoadImage NODE_CLASS_MAPPINGS = {"PrimaryInputLoadImage": PrimaryInputLoadImage} NODE_DISPLAY_NAME_MAPPINGS = {} diff --git a/nodes/web/__init__.py b/nodes/web/__init__.py index ef41b57aa..caa6b49a5 100644 --- a/nodes/web/__init__.py +++ b/nodes/web/__init__.py @@ -1,7 +1,10 @@ """ComfyStream Web UI nodes""" + import os + import folder_paths + # Define a simple Python class for the UI Preview node class ComfyStreamUIPreview: """ @@ -9,29 +12,24 @@ class ComfyStreamUIPreview: It's needed for ComfyUI to properly register and execute the node. The actual implementation is in the JavaScript file. """ + @classmethod def INPUT_TYPES(cls): - return { - "required": {}, - "optional": {} - } - + return {"required": {}, "optional": {}} + RETURN_TYPES = () - + FUNCTION = "execute" CATEGORY = "ComfyStream" - + def execute(self): # This function doesn't do anything as the real work is done in JavaScript # But we need to return something to satisfy the ComfyUI node execution system return ("UI Preview Node Executed",) + # Register the node class -NODE_CLASS_MAPPINGS = { - "ComfyStreamUIPreview": ComfyStreamUIPreview -} +NODE_CLASS_MAPPINGS = {"ComfyStreamUIPreview": ComfyStreamUIPreview} # Display names for the nodes -NODE_DISPLAY_NAME_MAPPINGS = { - "ComfyStreamUIPreview": "ComfyStream UI Preview" -} \ No newline at end of file +NODE_DISPLAY_NAME_MAPPINGS = {"ComfyStreamUIPreview": "ComfyStream UI Preview"} diff --git a/pyproject.toml b/pyproject.toml index ecf846f78..e22a59559 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,15 +1,15 @@ [build-system] -requires = ["setuptools>=64.0.0", "wheel"] +requires = ["setuptools>=64.0.0,<81", "wheel"] build-backend = "setuptools.build_meta" [project] name = "comfystream" description = "Build Live AI Video with ComfyUI" -version = "0.1.6" +version = "0.1.7" license = { file = "LICENSE" } dependencies = [ "asyncio", - "pytrickle @ git+https://github.com/livepeer/pytrickle.git@v0.1.4", + "pytrickle @ git+https://github.com/livepeer/pytrickle.git@v0.1.5", "comfyui @ git+https://github.com/hiddenswitch/ComfyUI.git@e62df3a8811d8c652a195d4669f4fb27f6c9a9ba", "aiortc", "aiohttp", @@ -21,7 +21,7 @@ dependencies = [ ] [project.optional-dependencies] -dev = ["pytest", "pytest-cov"] +dev = ["pytest", "pytest-cov", "ruff"] [project.urls] Repository = "https://github.com/yondonfu/comfystream" @@ -37,3 +37,23 @@ packages = {find = {where = ["src", "nodes"]}} [tool.setuptools.dynamic] dependencies = {file = ["requirements.txt"]} + +[tool.ruff] +line-length = 100 +target-version = "py312" + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = [ + "E501", # let the formatter handle long lines + "E402", # module level import not at top (required for ComfyUI) + "E722", # bare except (existing code patterns) + "F401", # imported but unused (many re-exports) + "F403", # star imports (ComfyUI pattern) + "F405", # may be undefined from star imports + "F841", # assigned but never used (temporary) +] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" diff --git a/requirements.txt b/requirements.txt index b35bb52cc..da59730dc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ asyncio -pytrickle @ git+https://github.com/livepeer/pytrickle.git@v0.1.4 +pytrickle @ git+https://github.com/livepeer/pytrickle.git@v0.1.5 comfyui @ git+https://github.com/hiddenswitch/ComfyUI.git@e62df3a8811d8c652a195d4669f4fb27f6c9a9ba aiortc aiohttp diff --git a/scripts/monitor_pid_resources.py b/scripts/monitor_pid_resources.py index a7de85636..25ffc97a0 100644 --- a/scripts/monitor_pid_resources.py +++ b/scripts/monitor_pid_resources.py @@ -1,16 +1,17 @@ """A Python script to monitor system resources for a given PID and optionally create a py-spy profiler report.""" -import psutil -import pynvml -import time +import csv import subprocess -import click import threading -import csv +import time from pathlib import Path from typing import List +import click +import psutil +import pynvml + def is_running_inside_container(): """Detects if the script is running inside a container.""" @@ -131,30 +132,20 @@ def find_pid_by_name(name: str) -> int: for proc in psutil.process_iter(["pid", "name", "cmdline"]): if proc.info["cmdline"] and name in proc.info["cmdline"]: found_pid = proc.info["pid"] - click.echo( - click.style(f"Found process '{name}' with PID {found_pid}.", fg="green") - ) + click.echo(click.style(f"Found process '{name}' with PID {found_pid}.", fg="green")) return found_pid click.echo(click.style(f"Error: Process with name '{name}' not found.", fg="red")) return None @click.command() -@click.option( - "--pid", type=str, default="auto", help='Process ID or "auto" to find by name' -) -@click.option( - "--name", type=str, default="app.py", help="Process name (default: app.py)" -) +@click.option("--pid", type=str, default="auto", help='Process ID or "auto" to find by name') +@click.option("--name", type=str, default="app.py", help="Process name (default: app.py)") @click.option("--interval", type=int, default=2, help="Monitoring interval (seconds)") -@click.option( - "--duration", type=int, default=30, help="Total monitoring duration (seconds)" -) +@click.option("--duration", type=int, default=30, help="Total monitoring duration (seconds)") @click.option("--output", type=str, default=None, help="File to save logs (optional)") @click.option("--spy", is_flag=True, help="Enable py-spy profiling") -@click.option( - "--spy-output", type=str, default="pyspy_profile.svg", help="Py-Spy output file" -) +@click.option("--spy-output", type=str, default="pyspy_profile.svg", help="Py-Spy output file") @click.option( "--host-pid", type=int, @@ -213,21 +204,15 @@ def monitor_resources( click.echo(click.style(f"Error: Process with PID {pid} not found.", fg="red")) return - click.echo( - click.style(f"Monitoring PID {pid} for {duration} seconds...", fg="green") - ) + click.echo(click.style(f"Monitoring PID {pid} for {duration} seconds...", fg="green")) def run_py_spy(): """Run py-spy profiler for deep profiling.""" click.echo(click.style("Running py-spy for deep profiling...", fg="green")) spy_cmd = f"py-spy record -o {spy_output} --pid {pid} --duration {duration}" try: - subprocess.run( - spy_cmd, shell=True, check=True, capture_output=True, text=True - ) - click.echo( - click.style(f"Py-Spy flame graph saved to {spy_output}", fg="green") - ) + subprocess.run(spy_cmd, shell=True, check=True, capture_output=True, text=True) + click.echo(click.style(f"Py-Spy flame graph saved to {spy_output}", fg="green")) except subprocess.CalledProcessError as e: click.echo(click.style(f"Error running py-spy: {e.stderr}", fg="red")) diff --git a/server/app.py b/server/app.py index 46f39ebb9..b93e35ee4 100644 --- a/server/app.py +++ b/server/app.py @@ -4,8 +4,7 @@ import logging import os import sys -import time -import secrets + import torch # Initialize CUDA before any other imports to prevent core dump. @@ -20,15 +19,16 @@ RTCPeerConnection, RTCSessionDescription, ) + # Import HTTP streaming modules from aiortc.codecs import h264 from aiortc.rtcrtpsender import RTCRtpSender -from comfystream.pipeline import Pipeline from twilio.rest import Client -from comfystream.server.utils import patch_loop_datagram, add_prefix_to_app_routes, FPSMeter + from comfystream.exceptions import ComfyStreamTimeoutFilter +from comfystream.pipeline import Pipeline from comfystream.server.metrics import MetricsManager, StreamStatsManager -import time +from comfystream.server.utils import FPSMeter, add_prefix_to_app_routes, patch_loop_datagram logger = logging.getLogger(__name__) logging.getLogger("aiortc.rtcrtpsender").setLevel(logging.WARNING) @@ -61,12 +61,10 @@ def __init__(self, track: MediaStreamTrack, pipeline: Pipeline): super().__init__() self.track = track self.pipeline = pipeline - self.fps_meter = FPSMeter( - metrics_manager=app["metrics_manager"], track_id=track.id - ) + self.fps_meter = FPSMeter(metrics_manager=app["metrics_manager"], track_id=track.id) self.running = True self.collect_task = asyncio.create_task(self.collect_frames()) - + # Add cleanup when track ends @track.on("ended") async def on_ended(): @@ -92,15 +90,13 @@ async def collect_frames(self): logger.error(f"Error collecting video frames: {str(e)}") self.running = False break - + # Perform cleanup outside the exception handler logger.info("Video frame collection stopped") except asyncio.CancelledError: logger.info("Frame collection task cancelled") except Exception as e: logger.error(f"Unexpected error in frame collection: {str(e)}") - finally: - await self.pipeline.cleanup() async def recv(self): """Receive a processed video frame from the pipeline, increment the frame @@ -116,6 +112,7 @@ async def recv(self): class NoopVideoStreamTrack(MediaStreamTrack): """Simple passthrough video track that bypasses pipeline processing.""" + kind = "video" def __init__(self, track: MediaStreamTrack): @@ -135,6 +132,7 @@ async def recv(self): class NoopAudioStreamTrack(MediaStreamTrack): """Simple passthrough audio track that bypasses pipeline processing.""" + kind = "audio" def __init__(self, track: MediaStreamTrack): @@ -162,7 +160,7 @@ def __init__(self, track: MediaStreamTrack, pipeline): self.running = True logger.info(f"AudioStreamTrack created for track {track.id}") self.collect_task = asyncio.create_task(self.collect_frames()) - + # Add cleanup when track ends @track.on("ended") async def on_ended(): @@ -188,19 +186,18 @@ async def collect_frames(self): logger.error(f"Error collecting audio frames: {str(e)}") self.running = False break - + # Perform cleanup outside the exception handler logger.info("Audio frame collection stopped") except asyncio.CancelledError: logger.info("Frame collection task cancelled") except Exception as e: logger.error(f"Unexpected error in audio frame collection: {str(e)}") - finally: - await self.pipeline.cleanup() async def recv(self): return await self.pipeline.get_processed_audio_frame() + def force_codec(pc, sender, forced_codec): kind = forced_codec.split("/")[0] codecs = RTCRtpSender.getCapabilities(kind).codecs @@ -246,39 +243,42 @@ async def offer(request): pcs = request.app["pcs"] params = await request.json() - + # Check if this is noop mode (no prompts provided) prompts = params.get("prompts") is_noop_mode = not prompts - - if is_noop_mode: - logger.info("[Offer] No prompts provided - entering noop passthrough mode") - else: - await pipeline.set_prompts(prompts) - logger.info("[Offer] Set workflow prompts") - - # Set resolution if provided in the offer + resolution = params.get("resolution") if resolution: pipeline.width = resolution["width"] pipeline.height = resolution["height"] - logger.info(f"[Offer] Set pipeline resolution to {resolution['width']}x{resolution['height']}") + logger.info( + f"[Offer] Set pipeline resolution to {resolution['width']}x{resolution['height']}" + ) + + if is_noop_mode: + logger.info("[Offer] No prompts provided - entering noop passthrough mode") + else: + await pipeline.apply_prompts( + prompts, + skip_warmup=False, + ) + await pipeline.start_streaming() + logger.info("[Offer] Set workflow prompts, warmed pipeline, and started execution") offer_params = params["offer"] offer = RTCSessionDescription(sdp=offer_params["sdp"], type=offer_params["type"]) - + ice_servers = get_ice_servers() if len(ice_servers) > 0: - pc = RTCPeerConnection( - configuration=RTCConfiguration(iceServers=get_ice_servers()) - ) + pc = RTCPeerConnection(configuration=RTCConfiguration(iceServers=get_ice_servers())) else: pc = RTCPeerConnection() pcs.add(pc) tracks = {"video": None, "audio": None} - + # Flag to track if we've received resolution update resolution_received = {"value": False} @@ -311,56 +311,140 @@ def on_datachannel(channel): @channel.on("message") async def on_message(message): - try: - params = json.loads(message) + def send_json(payload): + channel.send(json.dumps(payload)) + + def send_success_response(response_type, **extra): + payload = {"type": response_type, "success": True} + payload.update(extra) + send_json(payload) + + def send_error_response(response_type, error_message, **extra): + payload = { + "type": response_type, + "success": False, + "error": error_message, + } + payload.update(extra) + send_json(payload) + + async def handle_get_nodes(_params): + nodes_info = await pipeline.get_nodes_info() + send_json({"type": "nodes_info", "nodes": nodes_info}) + + async def handle_update_prompts(_params): + if "prompts" not in _params: + logger.warning("[Control] Missing prompt in update_prompt message") + send_error_response( + "prompts_updated", "Missing 'prompts' in control message" + ) + return + try: + await pipeline.update_prompts(_params["prompts"]) + except Exception as e: + logger.error(f"Error updating prompt: {str(e)}") + send_error_response("prompts_updated", str(e)) + return + send_success_response("prompts_updated") + + async def handle_update_resolution(_params): + width = _params.get("width") + height = _params.get("height") + if width is None or height is None: + logger.warning( + "[Control] Missing width or height in update_resolution message" + ) + send_error_response( + "resolution_updated", + "Missing 'width' or 'height' in control message", + ) + return - if params.get("type") == "get_nodes": - nodes_info = await pipeline.get_nodes_info() - response = {"type": "nodes_info", "nodes": nodes_info} - channel.send(json.dumps(response)) - elif params.get("type") == "update_prompts": - if "prompts" not in params: - logger.warning( - "[Control] Missing prompt in update_prompt message" - ) + if is_noop_mode: + logger.info( + f"[Control] Noop mode - resolution update to {width}x{height} (no pipeline involved)" + ) + else: + # Update pipeline resolution for future frames + pipeline.width = width + pipeline.height = height + logger.info(f"[Control] Updated resolution to {width}x{height}") + + # Mark that we've received resolution + resolution_received["value"] = True + + if is_noop_mode: + logger.info("[Control] Noop mode - no warmup needed") + else: + # Note: Video warmup now happens during offer, not here + logger.info( + "[Control] Resolution updated - warmup was already performed during offer" + ) + + send_success_response("resolution_updated") + + async def handle_pause_prompts(_params): + if is_noop_mode: + logger.info("[Control] Noop mode - no prompts to pause") + else: + try: + await pipeline.pause_prompts() + logger.info("[Control] Paused prompt execution") + except Exception as e: + logger.error(f"[Control] Error pausing prompts: {str(e)}") + send_error_response("prompts_paused", str(e)) return + send_success_response("prompts_paused") + + async def handle_resume_prompts(_params): + if is_noop_mode: + logger.info("[Control] Noop mode - no prompts to resume") + else: try: - await pipeline.update_prompts(params["prompts"]) + await pipeline.start_streaming() + logger.info("[Control] Resumed prompt execution") except Exception as e: - logger.error(f"Error updating prompt: {str(e)}") - response = {"type": "prompts_updated", "success": True} - channel.send(json.dumps(response)) - elif params.get("type") == "update_resolution": - if "width" not in params or "height" not in params: - logger.warning("[Control] Missing width or height in update_resolution message") + logger.error(f"[Control] Error resuming prompts: {str(e)}") + send_error_response("prompts_resumed", str(e)) return - - if is_noop_mode: - logger.info(f"[Control] Noop mode - resolution update to {params['width']}x{params['height']} (no pipeline involved)") - else: - # Update pipeline resolution for future frames - pipeline.width = params["width"] - pipeline.height = params["height"] - logger.info(f"[Control] Updated resolution to {params['width']}x{params['height']}") - - # Mark that we've received resolution - resolution_received["value"] = True - - if is_noop_mode: - logger.info("[Control] Noop mode - no warmup needed") - else: - # Note: Video warmup now happens during offer, not here - logger.info("[Control] Resolution updated - warmup was already performed during offer") - - response = { - "type": "resolution_updated", - "success": True - } - channel.send(json.dumps(response)) + send_success_response("prompts_resumed") + + async def handle_stop_prompts(_params): + if is_noop_mode: + logger.info("[Control] Noop mode - no prompts to stop") else: - logger.warning( - "[Server] Invalid message format - missing required fields" - ) + try: + await pipeline.stop_prompts(cleanup=False) + logger.info("[Control] Stopped prompt execution") + except Exception as e: + logger.error(f"[Control] Error stopping prompts: {str(e)}") + send_error_response("prompts_stopped", str(e)) + return + send_success_response("prompts_stopped") + + handlers = { + "get_nodes": handle_get_nodes, + "update_prompts": handle_update_prompts, + "update_resolution": handle_update_resolution, + "pause_prompts": handle_pause_prompts, + "resume_prompts": handle_resume_prompts, + "stop_prompts": handle_stop_prompts, + } + + try: + params = json.loads(message) + message_type = params.get("type") + + if not message_type: + logger.warning("[Server] Control message missing 'type'") + return + + handler = handlers.get(message_type) + if handler is None: + logger.warning(f"[Server] Unsupported control message: {message_type}") + return + + await handler(params) except json.JSONDecodeError: logger.error("[Server] Invalid JSON received") except Exception as e: @@ -369,12 +453,16 @@ async def on_message(message): elif channel.label == "data": if is_noop_mode: logger.debug("[TextChannel] Noop mode - skipping text output forwarding") + # In noop mode, just acknowledge the data channel but don't forward anything @channel.on("open") def on_data_channel_open(): - logger.debug("[TextChannel] Data channel opened in noop mode (no text forwarding)") + logger.debug( + "[TextChannel] Data channel opened in noop mode (no text forwarding)" + ) else: if pipeline.produces_text_output(): + async def forward_text(): try: while channel.readyState == "open": @@ -389,7 +477,9 @@ async def forward_text(): try: channel.send(json.dumps({"type": "text", "data": text})) except Exception as e: - logger.debug(f"[TextChannel] Send failed, stopping forwarder: {e}") + logger.debug( + f"[TextChannel] Send failed, stopping forwarder: {e}" + ) break except asyncio.CancelledError: logger.debug("[TextChannel] Forward text task cancelled") @@ -408,6 +498,7 @@ def _remove_forward_task(t): tasks = request.app.get("data_channel_tasks") if tasks is not None: tasks.discard(t) + forward_task.add_done_callback(_remove_forward_task) # Ensure cancellation on channel close event @@ -419,20 +510,22 @@ def on_data_channel_close(): if not t.done(): t.cancel() else: - logger.debug("[TextChannel] Workflow has no text outputs; not starting forward_text") + logger.debug( + "[TextChannel] Workflow has no text outputs; not starting forward_text" + ) @pc.on("track") def on_track(track): logger.info(f"Track received: {track.kind} (readyState: {track.readyState})") - + # Check if we already have a track of this type to avoid duplicate track errors if track.kind == "video" and tracks["video"] is not None: - logger.debug(f"Video track already exists, ignoring duplicate track event") + logger.debug("Video track already exists, ignoring duplicate track event") return elif track.kind == "audio" and tracks["audio"] is not None: - logger.debug(f"Audio track already exists, ignoring duplicate track event") + logger.debug("Audio track already exists, ignoring duplicate track event") return - + if track.kind == "video": if is_noop_mode: # Use simple passthrough track that bypasses pipeline @@ -442,7 +535,7 @@ def on_track(track): # Always use pipeline processing - it handles passthrough internally based on workflow videoTrack = VideoStreamTrack(track, pipeline) logger.info("[Pipeline] Using video processing pipeline") - + tracks["video"] = videoTrack sender = pc.addTrack(videoTrack) @@ -453,11 +546,10 @@ def on_track(track): codec = "video/H264" force_codec(pc, sender, codec) - - + elif track.kind == "audio": logger.info(f"Creating audio track for track {track.id}") - + if is_noop_mode: # Use simple passthrough track that bypasses pipeline audioTrack = NoopAudioStreamTrack(track) @@ -466,10 +558,10 @@ def on_track(track): # Always use pipeline processing - it handles passthrough internally based on workflow audioTrack = AudioStreamTrack(track, pipeline) logger.info("[Pipeline] Using audio processing pipeline") - + tracks["audio"] = audioTrack sender = pc.addTrack(audioTrack) - logger.debug(f"Audio track added to peer connection") + logger.debug("Audio track added to peer connection") @track.on("ended") async def on_ended(): @@ -488,6 +580,13 @@ async def on_connectionstatechange(): if not task.done(): task.cancel() request.app["data_channel_tasks"].clear() + # Cleanup pipeline once per connection (not per track) + if not is_noop_mode: + try: + await pipeline.stop_prompts(cleanup=True) + logger.info("Pipeline cleanup completed for failed connection") + except Exception as e: + logger.error(f"Error during pipeline cleanup on connection failure: {e}") elif pc.connectionState == "closed": await pc.close() pcs.discard(pc) @@ -497,6 +596,13 @@ async def on_connectionstatechange(): if not task.done(): task.cancel() request.app["data_channel_tasks"].clear() + # Cleanup pipeline once per connection (not per track) + if not is_noop_mode: + try: + await pipeline.stop_prompts(cleanup=True) + logger.info("Pipeline cleanup completed for closed connection") + except Exception as e: + logger.error(f"Error during pipeline cleanup on connection close: {e}") await pc.setRemoteDescription(offer) @@ -504,18 +610,12 @@ async def on_connectionstatechange(): transceivers = pc.getTransceivers() logger.debug(f"[Offer] After negotiation - Total transceivers: {len(transceivers)}") for i, t in enumerate(transceivers): - logger.debug(f"[Offer] Transceiver {i}: {t.kind} - direction: {t.direction} - currentDirection: {t.currentDirection}") + logger.debug( + f"[Offer] Transceiver {i}: {t.kind} - direction: {t.direction} - currentDirection: {t.currentDirection}" + ) # Warm up the pipeline based on detected modalities and SDP content (skip in noop mode) - if not is_noop_mode: - if "m=video" in pc.remoteDescription.sdp and pipeline.accepts_video_input(): - logger.info("[Offer] Warming up video pipeline") - await pipeline.warm_video() - - if "m=audio" in pc.remoteDescription.sdp and pipeline.accepts_audio_input(): - logger.info("[Offer] Warming up audio pipeline") - await pipeline.warm_audio() - else: + if is_noop_mode: logger.debug("[Offer] Skipping pipeline warmup in noop mode") answer = await pc.createAnswer() @@ -523,28 +623,29 @@ async def on_connectionstatechange(): return web.Response( content_type="application/json", - text=json.dumps( - {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type} - ), + text=json.dumps({"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}), ) + async def cancel_collect_frames(track): track.running = False - if hasattr(track, 'collect_task') is not None and not track.collect_task.done(): + if track.collect_task and not track.collect_task.done(): try: track.collect_task.cancel() await track.collect_task - except (asyncio.CancelledError): + except asyncio.CancelledError: pass + async def set_prompt(request): pipeline = request.app["pipeline"] prompt = await request.json() - await pipeline.set_prompts(prompt) + await pipeline.apply_prompts(prompt) return web.Response(content_type="application/json", text="OK") + def health(_): return web.Response(content_type="application/json", text="OK") @@ -556,13 +657,14 @@ async def on_startup(app: web.Application): app["pipeline"] = Pipeline( width=512, height=512, - cwd=app["workspace"], - disable_cuda_malloc=True, - gpu_only=True, - preview_method='none', - comfyui_inference_log_level=app.get("comfui_inference_log_level", None), - blacklist_nodes=["ComfyUI-Manager"] + cwd=app["workspace"], + disable_cuda_malloc=True, + gpu_only=True, + preview_method="none", + comfyui_inference_log_level=app.get("comfyui_inference_log_level", None), + blacklist_custom_nodes=["ComfyUI-Manager"], ) + await app["pipeline"].initialize() app["pcs"] = set() app["video_tracks"] = {} @@ -577,13 +679,9 @@ async def on_shutdown(app: web.Application): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run comfystream server") parser.add_argument("--port", default=8889, help="Set the signaling port") - parser.add_argument( - "--media-ports", default=None, help="Set the UDP ports for WebRTC media" - ) + parser.add_argument("--media-ports", default=None, help="Set the UDP ports for WebRTC media") parser.add_argument("--host", default="127.0.0.1", help="Set the host") - parser.add_argument( - "--workspace", default=None, required=True, help="Set Comfy workspace" - ) + parser.add_argument("--workspace", default=None, required=True, help="Set Comfy workspace") parser.add_argument( "--log-level", default="INFO", @@ -635,12 +733,10 @@ async def on_shutdown(app: web.Application): # WebRTC signalling and control routes. app.router.add_post("/offer", offer) app.router.add_post("/prompt", set_prompt) - + # Add routes for getting stream statistics. stream_stats_manager = StreamStatsManager(app) - app.router.add_get( - "/streams/stats", stream_stats_manager.collect_all_stream_metrics - ) + app.router.add_get("/streams/stats", stream_stats_manager.collect_all_stream_metrics) app.router.add_get( "/stream/{stream_id}/stats", stream_stats_manager.collect_stream_metrics_by_id ) @@ -667,10 +763,12 @@ def force_print(*args, **kwargs): if args.comfyui_log_level: log_level = logging._nameToLevel.get(args.comfyui_log_level.upper()) logging.getLogger("comfy").setLevel(log_level) - + # Add ComfyStream timeout filter to suppress verbose execution logging - logging.getLogger("comfy.cmd.execution").addFilter(ComfyStreamTimeoutFilter()) + timeout_filter = ComfyStreamTimeoutFilter() + logging.getLogger("comfy.cmd.execution").addFilter(timeout_filter) + logging.getLogger("comfystream").addFilter(timeout_filter) if args.comfyui_inference_log_level: - app["comfui_inference_log_level"] = args.comfyui_inference_log_level + app["comfyui_inference_log_level"] = args.comfyui_inference_log_level web.run_app(app, host=args.host, port=int(args.port), print=force_print) diff --git a/server/byoc.py b/server/byoc.py index 895667482..3f8f3470c 100644 --- a/server/byoc.py +++ b/server/byoc.py @@ -5,55 +5,35 @@ import sys import torch + # Initialize CUDA before any other imports to prevent core dump. if torch.cuda.is_available(): torch.cuda.init() from aiohttp import web +from frame_processor import ComfyStreamFrameProcessor +from pytrickle.frame_overlay import OverlayConfig, OverlayMode +from pytrickle.frame_skipper import FrameSkipConfig from pytrickle.stream_processor import StreamProcessor from pytrickle.utils.register import RegisterCapability -from pytrickle.frame_skipper import FrameSkipConfig -from frame_processor import ComfyStreamFrameProcessor + from comfystream.exceptions import ComfyStreamTimeoutFilter logger = logging.getLogger(__name__) - -async def register_orchestrator(orch_url=None, orch_secret=None, capability_name=None, host="127.0.0.1", port=8889): - """Register capability with orchestrator if configured.""" - try: - orch_url = orch_url or os.getenv("ORCH_URL") - orch_secret = orch_secret or os.getenv("ORCH_SECRET") - - if orch_url and orch_secret: - os.environ.update({ - "CAPABILITY_NAME": capability_name or os.getenv("CAPABILITY_NAME") or "comfystream-processor", - "CAPABILITY_DESCRIPTION": "ComfyUI streaming processor", - "CAPABILITY_URL": f"http://{host}:{port}", - "CAPABILITY_CAPACITY": "1", - "ORCH_URL": orch_url, - "ORCH_SECRET": orch_secret - }) - - # Pass through explicit capability_name to ensure CLI/env override takes effect - result = await RegisterCapability.register( - logger=logger, - capability_name=capability_name - ) - if result: - logger.info(f"Registered capability: {result.geturl()}") - except Exception as e: - logger.error(f"Orchestrator registration failed: {e}") +DEFAULT_WITHHELD_TIMEOUT_SECONDS = 0.5 def main(): parser = argparse.ArgumentParser( description="Run comfystream server in BYOC (Bring Your Own Compute) mode using pytrickle." ) - parser.add_argument("--port", default=8889, help="Set the server port") - parser.add_argument("--host", default="127.0.0.1", help="Set the host") + parser.add_argument("--port", default=8000, help="Set the server port") + parser.add_argument("--host", default="0.0.0.0", help="Set the host") parser.add_argument( - "--workspace", default=None, required=True, help="Set Comfy workspace" + "--workspace", + default=os.getcwd() + "/../ComfyUI", + help="Set Comfy workspace (Default: ../ComfyUI)", ) parser.add_argument( "--log-level", @@ -73,21 +53,6 @@ def main(): choices=logging._nameToLevel.keys(), help="Set the logging level for ComfyUI inference", ) - parser.add_argument( - "--orch-url", - default=None, - help="Orchestrator URL for capability registration", - ) - parser.add_argument( - "--orch-secret", - default=None, - help="Orchestrator secret for capability registration", - ) - parser.add_argument( - "--capability-name", - default=None, - help="Name for this capability (default: comfystream-processor)", - ) parser.add_argument( "--disable-frame-skip", default=False, @@ -113,21 +78,28 @@ def main(): format="%(asctime)s [%(levelname)s] %(message)s", datefmt="%H:%M:%S", ) + logging.getLogger("comfy.model_detection").setLevel(logging.WARNING) # Allow overriding of ComfyUI log levels. if args.comfyui_log_level: log_level = logging._nameToLevel.get(args.comfyui_log_level.upper()) logging.getLogger("comfy").setLevel(log_level) - + # Add ComfyStream timeout filter to suppress verbose execution logging - logging.getLogger("comfy.cmd.execution").addFilter(ComfyStreamTimeoutFilter()) + timeout_filter = ComfyStreamTimeoutFilter() + logging.getLogger("comfy.cmd.execution").addFilter(timeout_filter) + logging.getLogger("comfystream").addFilter(timeout_filter) def force_print(*args, **kwargs): print(*args, **kwargs, flush=True) sys.stdout.flush() logger.info("Starting ComfyStream BYOC server with pytrickle StreamProcessor...") - + logger.info( + "Send initial workflow parameters (width/height/prompts/warmup) via /stream/start " + "params; runtime updates now apply incremental changes only." + ) + # Create frame processor with configuration frame_processor = ComfyStreamFrameProcessor( width=args.width, @@ -135,10 +107,12 @@ def force_print(*args, **kwargs): workspace=args.workspace, disable_cuda_malloc=True, gpu_only=True, - preview_method='none', - comfyui_inference_log_level=args.comfyui_inference_log_level + preview_method="none", + blacklist_custom_nodes=["ComfyUI-Manager"], + logging_level=args.comfyui_log_level, + comfyui_inference_log_level=args.comfyui_inference_log_level, ) - + # Create frame skip configuration only if enabled frame_skip_config = None if args.disable_frame_skip: @@ -146,67 +120,73 @@ def force_print(*args, **kwargs): else: frame_skip_config = FrameSkipConfig() logger.info("Frame skipping enabled: adaptive skipping based on queue sizes") - + # Create StreamProcessor with frame processor processor = StreamProcessor( video_processor=frame_processor.process_video_async, audio_processor=frame_processor.process_audio_async, model_loader=frame_processor.load_model, param_updater=frame_processor.update_params, + on_stream_start=frame_processor.on_stream_start, on_stream_stop=frame_processor.on_stream_stop, # Align processor name with capability for consistent logs - name=(args.capability_name or os.getenv("CAPABILITY_NAME") or "comfystream-processor"), + name=(os.getenv("CAPABILITY_NAME") or "comfystream"), port=int(args.port), host=args.host, frame_skip_config=frame_skip_config, + overlay_config=OverlayConfig( + mode=OverlayMode.PROGRESSBAR, + message="Loading...", + enabled=True, + auto_timeout_seconds=DEFAULT_WITHHELD_TIMEOUT_SECONDS, + frame_count_to_disable=20, + ), # Ensure server metadata reflects the desired capability name - capability_name=(args.capability_name or os.getenv("CAPABILITY_NAME") or "comfystream-processor") + capability_name=(os.getenv("CAPABILITY_NAME") or "comfystream"), + # server_kwargs... + route_prefix="/", ) # Set the stream processor reference for text data publishing frame_processor.set_stream_processor(processor) - - # Create async startup function to load model - async def load_model_on_startup(app): - await processor._frame_processor.load_model() - + + logger.info("Startup warmup runs automatically as part of on_stream_start.") + # Create async startup function for orchestrator registration async def register_orchestrator_startup(app): - await register_orchestrator( - orch_url=args.orch_url, - orch_secret=args.orch_secret, - capability_name=args.capability_name, - host=args.host, - port=args.port - ) - - # Add model loading and registration to startup hooks - processor.server.app.on_startup.append(load_model_on_startup) - processor.server.app.on_startup.append(register_orchestrator_startup) - - # Add warmup endpoint: accepts same body as prompts update - async def warmup_handler(request): try: - body = await request.json() + orch_url = os.getenv("ORCH_URL") + + if orch_url and os.getenv("ORCH_SECRET", None): + # CAPABILITY_URL always overrides host:port from args + capability_url = os.getenv("CAPABILITY_URL") or f"http://{args.host}:{args.port}" + + os.environ.update( + { + "CAPABILITY_NAME": os.getenv("CAPABILITY_NAME") or "comfystream", + "CAPABILITY_DESCRIPTION": "ComfyUI streaming processor", + "CAPABILITY_URL": capability_url, + "CAPABILITY_CAPACITY": "1", + "ORCH_URL": orch_url, + "ORCH_SECRET": os.getenv("ORCH_SECRET", None), + } + ) + + result = await RegisterCapability.register( + logger=logger, capability_name=os.getenv("CAPABILITY_NAME") or "comfystream" + ) + if result: + logger.info(f"Registered capability: {result.geturl()}") + # Clear ORCH_SECRET from environment after use for security + os.environ.pop("ORCH_SECRET", None) except Exception as e: - logger.error(f"Invalid JSON in warmup request: {e}") - return web.json_response({"error": "Invalid JSON"}, status=400) - try: - # Inject sentinel to trigger warmup inside update_params on the model thread - if isinstance(body, dict): - body["warmup"] = True - else: - body = {"warmup": True} - # Fire-and-forget: do not await warmup; update_params will schedule it - asyncio.get_running_loop().create_task(frame_processor.update_params(body)) - return web.json_response({"status": "accepted"}) - except Exception as e: - logger.error(f"Warmup failed: {e}") - return web.json_response({"error": str(e)}, status=500) + logger.error(f"Orchestrator registration failed: {e}") + # Clear ORCH_SECRET from environment even on error + os.environ.pop("ORCH_SECRET", None) + + # Add registration to startup hooks + processor.server.app.on_startup.append(register_orchestrator_startup) - # Mount at same API namespace as StreamProcessor defaults - processor.server.add_route("POST", "/api/stream/warmup", warmup_handler) - # Run the processor processor.run() diff --git a/server/frame_processor.py b/server/frame_processor.py index 39272313b..befb5c608 100644 --- a/server/frame_processor.py +++ b/server/frame_processor.py @@ -2,13 +2,19 @@ import json import logging import os -from typing import List +from typing import Any, Dict, List, Optional, Union -import numpy as np from pytrickle.frame_processor import FrameProcessor -from pytrickle.frames import VideoFrame, AudioFrame +from pytrickle.frames import AudioFrame, VideoFrame +from pytrickle.stream_processor import VideoProcessingResult + from comfystream.pipeline import Pipeline -from comfystream.utils import convert_prompt, ComfyStreamParamsUpdateRequest +from comfystream.pipeline_state import PipelineState +from comfystream.utils import ( + ComfyStreamParamsUpdateRequest, + convert_prompt, + normalize_stream_params, +) logger = logging.getLogger(__name__) @@ -16,17 +22,19 @@ class ComfyStreamFrameProcessor(FrameProcessor): """ Integrated ComfyStream FrameProcessor for pytrickle. - + This class wraps the ComfyStream Pipeline to work with pytrickle's streaming architecture. """ def __init__(self, text_poll_interval: float = 0.25, **load_params): """Initialize with load parameters for pipeline creation. - + Args: text_poll_interval: Interval in seconds to poll for text outputs (default: 0.25) **load_params: Parameters for pipeline creation """ + super().__init__() + self.pipeline = None self._load_params = load_params self._text_poll_interval = text_poll_interval @@ -35,13 +43,64 @@ def __init__(self, text_poll_interval: float = 0.25, **load_params): self._text_forward_task = None self._background_tasks = [] self._stop_event = asyncio.Event() - super().__init__() + + async def _apply_stream_start_prompt(self, prompt_value: Any) -> bool: + if not self.pipeline: + logger.debug("Cannot apply stream start prompt without pipeline") + return False + + # Parse prompt payload from various formats + prompt_dict = None + if prompt_value is None: + pass + elif isinstance(prompt_value, dict): + prompt_dict = prompt_value + elif isinstance(prompt_value, list): + for candidate in prompt_value: + if isinstance(candidate, dict): + prompt_dict = candidate + break + elif isinstance(prompt_value, str): + prompt_str = prompt_value.strip() + if prompt_str: + try: + parsed = json.loads(prompt_str) + if isinstance(parsed, dict): + prompt_dict = parsed + else: + logger.warning("Parsed prompt payload is %s, expected dict", type(parsed)) + except json.JSONDecodeError: + logger.error("Stream start prompt is not valid JSON") + else: + logger.warning("Unsupported prompt payload type: %s", type(prompt_value)) + + if not isinstance(prompt_dict, dict): + logger.warning("Skipping prompt application due to invalid payload") + return False + + try: + await self._process_prompts(prompt_dict, skip_warmup=True) + return True + except Exception: + logger.exception("Failed to apply stream start prompt") + raise + + def _workflow_has_video(self) -> bool: + """Return True if current workflow is expected to produce video output.""" + if not self.pipeline: + return False + try: + capabilities = self.pipeline.get_workflow_io_capabilities() + return bool(capabilities.get("video", {}).get("output", False)) + except Exception: + logger.debug("Unable to determine workflow video capability", exc_info=True) + return False def set_stream_processor(self, stream_processor): """Set reference to StreamProcessor for data publishing.""" self._stream_processor = stream_processor logger.info("StreamProcessor reference set for text data publishing") - + def _setup_text_monitoring(self): """Set up background text forwarding from the pipeline.""" try: @@ -75,7 +134,9 @@ async def _forward_text_loop(): if self._stream_processor: success = await self._stream_processor.send_data(text) if not success: - logger.debug("Text send failed; stopping text forwarder") + logger.debug( + "Text send failed; stopping text forwarder" + ) break except asyncio.CancelledError: logger.debug("Text forwarder task cancelled") @@ -105,7 +166,7 @@ async def _stop_text_forwarder(self) -> None: except Exception: logger.debug("Error while awaiting text forwarder cancellation", exc_info=True) self._text_forward_task = None - + async def on_stream_stop(self): """Called when stream stops - cleanup background tasks.""" logger.info("Stream stopped, cleaning up background tasks") @@ -114,10 +175,10 @@ async def on_stream_stop(self): self._stop_event.set() # Stop the ComfyStream client's prompt execution - if self.pipeline and self.pipeline.client: + if self.pipeline: logger.info("Stopping ComfyStream client prompt execution") try: - await self.pipeline.client.cleanup() + await self.pipeline.stop_prompts(cleanup=True) except Exception as e: logger.error(f"Error stopping ComfyStream client: {e}") @@ -144,44 +205,89 @@ async def on_stream_stop(self): self._background_tasks.clear() logger.info("All background tasks cleaned up") - + def _reset_stop_event(self): """Reset the stop event for a new stream.""" self._stop_event.clear() + async def on_stream_start(self, params: Optional[Dict[str, Any]] = None): + """Handle stream start lifecycle events.""" + logger.info("Stream starting") + self._reset_stop_event() + logger.info(f"Stream start params: {params}") + + if not self.pipeline: + logger.debug("Stream start requested before pipeline initialization") + return + + stream_params = normalize_stream_params(params) + prompt_payload = stream_params.pop("prompts", None) + if prompt_payload is None: + prompt_payload = stream_params.pop("prompt", None) + + if prompt_payload: + try: + await self._apply_stream_start_prompt(prompt_payload) + except Exception: + logger.exception("Failed to apply stream start prompt") + return + + if not self.pipeline.state_manager.is_initialized(): + logger.info("Pipeline not initialized; waiting for prompts before streaming") + return + + if stream_params: + try: + await self.update_params(stream_params) + except Exception: + logger.exception("Failed to process stream start parameters") + return + + try: + if ( + self.pipeline.state != PipelineState.STREAMING + and self.pipeline.state_manager.can_stream() + ): + await self.pipeline.start_streaming() + + if self.pipeline.produces_text_output(): + self._setup_text_monitoring() + else: + await self._stop_text_forwarder() + except Exception: + logger.exception("Failed during stream start", exc_info=True) + async def load_model(self, **kwargs): """Load model and initialize the pipeline.""" params = {**self._load_params, **kwargs} - + if self.pipeline is None: self.pipeline = Pipeline( - width=int(params.get('width', 512)), - height=int(params.get('height', 512)), - cwd=params.get('workspace', os.getcwd()), - disable_cuda_malloc=params.get('disable_cuda_malloc', True), - gpu_only=params.get('gpu_only', True), - preview_method=params.get('preview_method', 'none'), - comfyui_inference_log_level=params.get('comfyui_inference_log_level'), - blacklist_nodes=["ComfyUI-Manager"] + width=int(params.get("width", 512)), + height=int(params.get("height", 512)), + cwd=params.get("workspace", os.getcwd()), + disable_cuda_malloc=params.get("disable_cuda_malloc", True), + gpu_only=params.get("gpu_only", True), + preview_method=params.get("preview_method", "none"), + comfyui_inference_log_level=params.get("comfyui_inference_log_level", "INFO"), + logging_level=params.get("comfyui_inference_log_level", "INFO"), + blacklist_custom_nodes=["ComfyUI-Manager"], ) + await self.pipeline.initialize() async def warmup(self): """Warm up the pipeline.""" if not self.pipeline: logger.warning("Warmup requested before pipeline initialization") return - + logger.info("Running pipeline warmup...") try: capabilities = self.pipeline.get_workflow_io_capabilities() logger.info(f"Detected I/O capabilities: {capabilities}") - - if capabilities.get("video", {}).get("input") or capabilities.get("video", {}).get("output"): - await self.pipeline.warm_video() - - if capabilities.get("audio", {}).get("input") or capabilities.get("audio", {}).get("output"): - await self.pipeline.warm_audio() - + + await self.pipeline.warmup() + except Exception as e: logger.error(f"Warmup failed: {e}") @@ -197,23 +303,43 @@ def _schedule_warmup(self) -> None: except Exception: logger.warning("Failed to schedule warmup", exc_info=True) - async def process_video_async(self, frame: VideoFrame) -> VideoFrame: - """Process video frame through ComfyStream Pipeline.""" + async def process_video_async( + self, frame: VideoFrame + ) -> Union[VideoFrame, VideoProcessingResult]: + """Process video frame through ComfyStream Pipeline. + + Returns VideoProcessingResult.WITHHELD to trigger pytrickle's automatic overlay when + processed frames are not yet available. + """ try: - + if not self.pipeline: + return frame + + # If pipeline ingestion is paused, withhold frame so pytrickle renders the overlay + if not self.pipeline.is_ingest_enabled(): + return VideoProcessingResult.WITHHELD + # Convert pytrickle VideoFrame to av.VideoFrame av_frame = frame.to_av_frame(frame.tensor) av_frame.pts = frame.timestamp av_frame.time_base = frame.time_base - + # Process through pipeline await self.pipeline.put_video_frame(av_frame) - processed_av_frame = await self.pipeline.get_processed_video_frame() - - # Convert back to pytrickle VideoFrame - processed_frame = VideoFrame.from_av_frame_with_timing(processed_av_frame, frame) - return processed_frame - + + # Try to get processed frame with short timeout + try: + processed_av_frame = await asyncio.wait_for( + self.pipeline.get_processed_video_frame(), + timeout=self._stream_processor.overlay_config.auto_timeout_seconds, + ) + processed_frame = VideoFrame.from_av_frame_with_timing(processed_av_frame, frame) + return processed_frame + + except asyncio.TimeoutError: + # No frame ready yet - return withheld sentinel to trigger overlay + return VideoProcessingResult.WITHHELD + except Exception as e: logger.error(f"Video processing failed: {e}") return frame @@ -223,14 +349,20 @@ async def process_audio_async(self, frame: AudioFrame) -> List[AudioFrame]: try: if not self.pipeline: return [frame] - - # Audio processing needed - use pipeline + + # If pipeline ingestion is paused, passthrough audio + if not self.pipeline.is_ingest_enabled(): + frame.side_data.skipped = True + frame.side_data.passthrough = True + return [frame] + + # Audio processing - use pipeline av_frame = frame.to_av_frame() await self.pipeline.put_audio_frame(av_frame) processed_av_frame = await self.pipeline.get_processed_audio_frame() processed_frame = AudioFrame.from_av_audio(processed_av_frame) return [processed_frame] - + except Exception as e: logger.error(f"Audio processing failed: {e}") return [frame] @@ -239,44 +371,59 @@ async def update_params(self, params: dict): """Update processing parameters.""" if not self.pipeline: return - - # Handle list input - take first element - if isinstance(params, list) and params: - params = params[0] - + + params_payload: Dict[str, Any] = {} + if isinstance(params, list): + params = params[0] if params else {} + + if isinstance(params, dict): + params_payload = dict(params) + elif params is None: + params_payload = {} + else: + logger.warning("Unsupported params type for update_params: %s", type(params)) + return + + if not params_payload: + return + # Validate parameters using the centralized validation - validated = ComfyStreamParamsUpdateRequest(**params).model_dump() + validated = ComfyStreamParamsUpdateRequest(**params_payload).model_dump() logger.info(f"Parameter validation successful, keys: {list(validated.keys())}") - + # Process prompts if provided if "prompts" in validated and validated["prompts"]: - await self._process_prompts(validated["prompts"]) - + await self._process_prompts(validated["prompts"], skip_warmup=True) + # Update pipeline dimensions if "width" in validated: self.pipeline.width = int(validated["width"]) if "height" in validated: self.pipeline.height = int(validated["height"]) - - # Schedule warmup if requested - if validated.get("warmup", False): - self._schedule_warmup() - - async def _process_prompts(self, prompts): + async def _process_prompts(self, prompts, *, skip_warmup: bool = False): """Process and set prompts in the pipeline.""" + if not self.pipeline: + logger.warning("Prompt update requested before pipeline initialization") + return try: converted = convert_prompt(prompts, return_dict=True) - - # Set prompts in pipeline - await self.pipeline.set_prompts([converted]) - logger.info(f"Prompts set successfully: {list(prompts.keys())}") - - # Update text monitoring based on workflow capabilities + + await self.pipeline.apply_prompts( + [converted], + skip_warmup=skip_warmup, + ) + + if self.pipeline.state_manager.can_stream(): + await self.pipeline.start_streaming() + + logger.info(f"Prompts applied successfully: {list(prompts.keys())}") + if self.pipeline.produces_text_output(): self._setup_text_monitoring() else: await self._stop_text_forwarder() - + except Exception as e: logger.error(f"Failed to process prompts: {e}") + raise diff --git a/setup.py b/setup.py index fc1f76c84..606849326 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,3 @@ from setuptools import setup -setup() \ No newline at end of file +setup() diff --git a/src/comfystream/__init__.py b/src/comfystream/__init__.py index 8aee624cf..398a9273a 100644 --- a/src/comfystream/__init__.py +++ b/src/comfystream/__init__.py @@ -1,17 +1,19 @@ from .client import ComfyStreamClient +from .exceptions import ComfyStreamAudioBufferError, ComfyStreamInputTimeoutError from .pipeline import Pipeline -from .server.utils import temporary_log_level -from .server.utils import FPSMeter +from .pipeline_state import PipelineState, PipelineStateManager from .server.metrics import MetricsManager, StreamStatsManager -from .exceptions import ComfyStreamInputTimeoutError, ComfyStreamAudioBufferError +from .server.utils import FPSMeter, temporary_log_level __all__ = [ - 'ComfyStreamClient', - 'Pipeline', - 'temporary_log_level', - 'FPSMeter', - 'MetricsManager', - 'StreamStatsManager', - 'ComfyStreamInputTimeoutError', - 'ComfyStreamAudioBufferError' + "ComfyStreamClient", + "Pipeline", + "PipelineState", + "PipelineStateManager", + "temporary_log_level", + "FPSMeter", + "MetricsManager", + "StreamStatsManager", + "ComfyStreamInputTimeoutError", + "ComfyStreamAudioBufferError", ] diff --git a/src/comfystream/client.py b/src/comfystream/client.py index 7686fca18..de5432448 100644 --- a/src/comfystream/client.py +++ b/src/comfystream/client.py @@ -1,15 +1,16 @@ import asyncio -from typing import List +import contextlib import logging - -from comfystream import tensor_cache -from comfystream.utils import convert_prompt -from comfystream.exceptions import ComfyStreamInputTimeoutError +from typing import List from comfy.api.components.schema.prompt import PromptDictInput from comfy.cli_args_types import Configuration from comfy.client.embedded_comfy_client import EmbeddedComfyClient +from comfystream import tensor_cache +from comfystream.exceptions import ComfyStreamInputTimeoutError +from comfystream.utils import convert_prompt, get_default_workflow + logger = logging.getLogger(__name__) @@ -17,34 +18,37 @@ class ComfyStreamClient: def __init__(self, max_workers: int = 1, **kwargs): config = Configuration(**kwargs) self.comfy_client = EmbeddedComfyClient(config, max_workers=max_workers) - self.running_prompts = {} # To be used for cancelling tasks self.current_prompts = [] self._cleanup_lock = asyncio.Lock() self._prompt_update_lock = asyncio.Lock() - self._stop_event = asyncio.Event() + + # PromptRunner state + self._shutdown_event = asyncio.Event() + self._run_enabled_event = asyncio.Event() + self._runner_task = None async def set_prompts(self, prompts: List[PromptDictInput]): """Set new prompts, replacing any existing ones. - + Args: prompts: List of prompt dictionaries to set - + Raises: ValueError: If prompts list is empty Exception: If prompt conversion or validation fails """ if not prompts: raise ValueError("Cannot set empty prompts list") - - # Cancel existing prompts first to avoid conflicts - await self.cancel_running_prompts() - # Reset stop event for new prompts - self._stop_event.clear() + + # Pause runner while swapping prompts to avoid interleaving + was_running = self._run_enabled_event.is_set() + self._run_enabled_event.clear() self.current_prompts = [convert_prompt(prompt) for prompt in prompts] - logger.info(f"Queuing {len(self.current_prompts)} prompt(s) for execution") - for idx in range(len(self.current_prompts)): - task = asyncio.create_task(self.run_prompt(idx)) - self.running_prompts[idx] = task + logger.info(f"Configured {len(self.current_prompts)} prompt(s)") + # Ensure runner exists (IDLE until resumed) + await self.ensure_prompt_tasks_running() + if was_running: + self._run_enabled_event.set() async def update_prompts(self, prompts: List[PromptDictInput]): async with self._prompt_update_lock: @@ -57,34 +61,63 @@ async def update_prompts(self, prompts: List[PromptDictInput]): for idx, prompt in enumerate(prompts): converted_prompt = convert_prompt(prompt) try: + # Lightweight validation by queueing is retained for compatibility await self.comfy_client.queue_prompt(converted_prompt) self.current_prompts[idx] = converted_prompt except Exception as e: raise Exception(f"Prompt update failed: {str(e)}") from e - async def run_prompt(self, prompt_index: int): - while not self._stop_event.is_set(): - async with self._prompt_update_lock: - try: - await self.comfy_client.queue_prompt(self.current_prompts[prompt_index]) - except asyncio.CancelledError: - raise - except ComfyStreamInputTimeoutError: - # Timeout errors are expected during stream switching - just continue - logger.info(f"Input for prompt {prompt_index} timed out, continuing") - continue - except Exception as e: - await self.cleanup() - logger.error(f"Error running prompt: {str(e)}") - raise + async def ensure_prompt_tasks_running(self): + # Ensure the single runner task exists (does not force running) + if self._runner_task and not self._runner_task.done(): + return + if not self.current_prompts: + return + self._shutdown_event.clear() + self._runner_task = asyncio.create_task(self._runner_loop()) + + async def _runner_loop(self): + try: + while not self._shutdown_event.is_set(): + # IDLE until running is enabled + await self._run_enabled_event.wait() + # Snapshot prompts without holding the lock during network I/O + async with self._prompt_update_lock: + prompts_snapshot = list(self.current_prompts) + for prompt_index, prompt in enumerate(prompts_snapshot): + if self._shutdown_event.is_set() or not self._run_enabled_event.is_set(): + break + try: + await self.comfy_client.queue_prompt(prompt) + except asyncio.CancelledError: + raise + except ComfyStreamInputTimeoutError: + logger.info(f"Input for prompt {prompt_index} timed out, continuing") + continue + except Exception as e: + logger.error(f"Error running prompt: {str(e)}") + logger.info("Stopping prompt execution and returning to passthrough mode") + + # Stop running and switch to default passthrough workflow + await self._fallback_to_passthrough() + break + except asyncio.CancelledError: + pass async def cleanup(self): - # Set stop event to signal prompt loops to exit - self._stop_event.set() - - await self.cancel_running_prompts() + # Signal runner to shutdown + self._shutdown_event.set() + if self._runner_task: + self._runner_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._runner_task + self._runner_task = None + + # Pause running + self._run_enabled_event.clear() + async with self._cleanup_lock: - if self.comfy_client.is_running: + if getattr(self.comfy_client, "is_running", False): try: await self.comfy_client.__aexit__() except Exception as e: @@ -93,18 +126,32 @@ async def cleanup(self): await self.cleanup_queues() logger.info("Client cleanup complete") - async def cancel_running_prompts(self): - async with self._cleanup_lock: - tasks_to_cancel = list(self.running_prompts.values()) - for task in tasks_to_cancel: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - self.running_prompts.clear() + def pause_prompts(self): + """Pause prompt execution loops without canceling underlying tasks.""" + self._run_enabled_event.clear() + logger.debug("Prompt execution paused") + + async def resume_prompts(self): + """Resume prompt execution loops.""" + await self.ensure_prompt_tasks_running() + self._run_enabled_event.set() + logger.debug("Prompt execution resumed") + + async def stop_prompts(self, cleanup: bool = False): + """Stop running prompts by canceling their tasks. + + Args: + cleanup: If True, perform full cleanup including queue clearing and + client shutdown. If False, only cancel prompt tasks. + """ + await self.stop_prompts_immediately() + + if cleanup: + await self.cleanup() + logger.info("Prompts stopped with full cleanup") + else: + logger.debug("Prompts stopped (tasks cancelled)") - async def cleanup_queues(self): while not tensor_cache.image_inputs.empty(): tensor_cache.image_inputs.get() @@ -121,20 +168,50 @@ async def cleanup_queues(self): while not tensor_cache.text_outputs.empty(): await tensor_cache.text_outputs.get() + async def stop_prompts_immediately(self): + """Cancel the runner task to immediately stop any in-flight prompt execution.""" + self._run_enabled_event.clear() + if self._runner_task: + self._runner_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._runner_task + self._runner_task = None + + async def _fallback_to_passthrough(self): + """Switch to default passthrough workflow when an error occurs.""" + try: + # Pause the runner + self._run_enabled_event.clear() + + # Set to default passthrough workflow + default_workflow = get_default_workflow() + async with self._prompt_update_lock: + self.current_prompts = [convert_prompt(default_workflow)] + + logger.info("Switched to default passthrough workflow") + + # Resume the runner with passthrough workflow + self._run_enabled_event.set() + + except Exception as e: + logger.error(f"Failed to fallback to passthrough: {str(e)}") + # If fallback fails, just pause execution + self._run_enabled_event.clear() + def put_video_input(self, frame): if tensor_cache.image_inputs.full(): tensor_cache.image_inputs.get(block=True) tensor_cache.image_inputs.put(frame) - + def put_audio_input(self, frame): tensor_cache.audio_inputs.put(frame) async def get_video_output(self): return await tensor_cache.image_outputs.get() - + async def get_audio_output(self): return await tensor_cache.audio_outputs.get() - + async def get_text_output(self): try: return tensor_cache.text_outputs.get_nowait() @@ -149,25 +226,20 @@ async def get_text_output(self): async def get_available_nodes(self): """Get metadata and available nodes info in a single pass""" # TODO: make it for for multiple prompts - if not self.running_prompts: + if not self.current_prompts: return {} try: from comfy.nodes.package import import_all_nodes_in_workspace + nodes = import_all_nodes_in_workspace() all_prompts_nodes_info = {} - + for prompt_index, prompt in enumerate(self.current_prompts): # Get set of class types we need metadata for, excluding LoadTensor and SaveTensor - needed_class_types = { - node.get('class_type') - for node in prompt.values() - } - remaining_nodes = { - node_id - for node_id, node in prompt.items() - } + needed_class_types = {node.get("class_type") for node in prompt.values()} + remaining_nodes = {node_id for node_id, node in prompt.items()} nodes_info = {} # Only process nodes until we've found all the ones we need @@ -179,87 +251,88 @@ async def get_available_nodes(self): continue # Get metadata for this node type (same as original get_node_metadata) - input_data = node_class.INPUT_TYPES() if hasattr(node_class, 'INPUT_TYPES') else {} + input_data = ( + node_class.INPUT_TYPES() if hasattr(node_class, "INPUT_TYPES") else {} + ) input_info = {} # Process required inputs - if 'required' in input_data: - for name, value in input_data['required'].items(): + if "required" in input_data: + for name, value in input_data["required"].items(): if isinstance(value, tuple): if len(value) == 1 and isinstance(value[0], list): # Handle combo box case where value is ([option1, option2, ...],) input_info[name] = { - 'type': 'combo', - 'value': value[0], # The list of options becomes the value + "type": "combo", + "value": value[0], # The list of options becomes the value } elif len(value) == 2: input_type, config = value input_info[name] = { - 'type': input_type, - 'required': True, - 'min': config.get('min', None), - 'max': config.get('max', None), - 'widget': config.get('widget', None) + "type": input_type, + "required": True, + "min": config.get("min", None), + "max": config.get("max", None), + "widget": config.get("widget", None), } elif len(value) == 1: # Handle simple type case like ('IMAGE',) - input_info[name] = { - 'type': value[0] - } + input_info[name] = {"type": value[0]} else: - logger.error(f"Unexpected structure for required input {name}: {value}") + logger.error( + f"Unexpected structure for required input {name}: {value}" + ) # Process optional inputs with same logic - if 'optional' in input_data: - for name, value in input_data['optional'].items(): + if "optional" in input_data: + for name, value in input_data["optional"].items(): if isinstance(value, tuple): if len(value) == 1 and isinstance(value[0], list): # Handle combo box case where value is ([option1, option2, ...],) input_info[name] = { - 'type': 'combo', - 'value': value[0], # The list of options becomes the value + "type": "combo", + "value": value[0], # The list of options becomes the value } elif len(value) == 2: input_type, config = value input_info[name] = { - 'type': input_type, - 'required': False, - 'min': config.get('min', None), - 'max': config.get('max', None), - 'widget': config.get('widget', None) + "type": input_type, + "required": False, + "min": config.get("min", None), + "max": config.get("max", None), + "widget": config.get("widget", None), } elif len(value) == 1: # Handle simple type case like ('IMAGE',) - input_info[name] = { - 'type': value[0] - } + input_info[name] = {"type": value[0]} else: - logger.error(f"Unexpected structure for optional input {name}: {value}") + logger.error( + f"Unexpected structure for optional input {name}: {value}" + ) # Now process any nodes in our prompt that use this class_type for node_id in list(remaining_nodes): node = prompt[node_id] - if node.get('class_type') != class_type: + if node.get("class_type") != class_type: continue - node_info = { - 'class_type': class_type, - 'inputs': {} - } + node_info = {"class_type": class_type, "inputs": {}} - if 'inputs' in node: - for input_name, input_value in node['inputs'].items(): + if "inputs" in node: + for input_name, input_value in node["inputs"].items(): input_metadata = input_info.get(input_name, {}) - node_info['inputs'][input_name] = { - 'value': input_value, - 'type': input_metadata.get('type', 'unknown'), - 'min': input_metadata.get('min', None), - 'max': input_metadata.get('max', None), - 'widget': input_metadata.get('widget', None) + node_info["inputs"][input_name] = { + "value": input_value, + "type": input_metadata.get("type", "unknown"), + "min": input_metadata.get("min", None), + "max": input_metadata.get("max", None), + "widget": input_metadata.get("widget", None), } # For combo type inputs, include the list of options - if input_metadata.get('type') == 'combo': - node_info['inputs'][input_name]['value'] = input_metadata.get('value', []) + if input_metadata.get("type") == "combo": + node_info["inputs"][input_name]["value"] = input_metadata.get( + "value", [] + ) nodes_info[node_id] = node_info remaining_nodes.remove(node_id) @@ -270,4 +343,4 @@ async def get_available_nodes(self): except Exception as e: logger.error(f"Error getting node info: {str(e)}") - return {} \ No newline at end of file + return {} diff --git a/src/comfystream/exceptions.py b/src/comfystream/exceptions.py index aaf20c844..53a97e870 100644 --- a/src/comfystream/exceptions.py +++ b/src/comfystream/exceptions.py @@ -1,17 +1,15 @@ """ComfyStream specific exceptions.""" import logging -from typing import Dict, Any, Optional +from typing import Any, Dict, Optional def log_comfystream_error( - exception: Exception, - logger: Optional[logging.Logger] = None, - level: int = logging.ERROR + exception: Exception, logger: Optional[logging.Logger] = None, level: int = logging.ERROR ) -> None: """ Centralized logging function for ComfyStream exceptions. - + Args: exception: The exception to log logger: Optional logger to use (defaults to module logger) @@ -19,7 +17,7 @@ def log_comfystream_error( """ if logger is None: logger = logging.getLogger(__name__) - + # If it's a ComfyStream timeout error with structured details, use its logging method if isinstance(exception, ComfyStreamInputTimeoutError): exception.log_error(logger) @@ -30,33 +28,27 @@ def log_comfystream_error( class ComfyStreamInputTimeoutError(Exception): """Raised when input tensors are not available within timeout.""" - + def __init__( - self, - input_type: str, - timeout_seconds: float, - details: Optional[Dict[str, Any]] = None + self, input_type: str, timeout_seconds: float, details: Optional[Dict[str, Any]] = None ): self.input_type = input_type self.timeout_seconds = timeout_seconds self.details = details or {} message = f"No {input_type} frames available after {timeout_seconds}s timeout" super().__init__(message) - + def get_log_details(self) -> Dict[str, Any]: """Get structured details for logging.""" - base_details = { - "input_type": self.input_type, - "timeout_seconds": self.timeout_seconds - } + base_details = {"input_type": self.input_type, "timeout_seconds": self.timeout_seconds} base_details.update(self.details) return base_details - + def log_error(self, logger: Optional[logging.Logger] = None) -> None: """Log the error with detailed information.""" if logger is None: logger = logging.getLogger(__name__) - + details = self.get_log_details() detail_str = ", ".join(f"{k}={v}" for k, v in details.items()) logger.error(f"ComfyStream timeout error: {str(self)} | Details: {detail_str}") @@ -64,23 +56,18 @@ def log_error(self, logger: Optional[logging.Logger] = None) -> None: class ComfyStreamAudioBufferError(ComfyStreamInputTimeoutError): """Audio buffer insufficient data error.""" - - def __init__( - self, - timeout_seconds: float, - needed_samples: int, - available_samples: int - ): + + def __init__(self, timeout_seconds: float, needed_samples: int, available_samples: int): self.needed_samples = needed_samples self.available_samples = available_samples - + # Pass audio-specific details to the base class audio_details = { "needed_samples": needed_samples, "available_samples": available_samples, } super().__init__("audio", timeout_seconds, details=audio_details) - + def get_log_details(self) -> Dict[str, Any]: """Get structured details for logging, with audio-specific formatting.""" details = super().get_log_details() @@ -89,37 +76,46 @@ def get_log_details(self) -> Dict[str, Any]: class ComfyStreamTimeoutFilter(logging.Filter): """Filter to suppress verbose ComfyUI execution logs for ComfyStream timeout exceptions.""" - + def filter(self, record): """Filter out ComfyUI execution error logs for ComfyStream timeout exceptions.""" try: # Only filter ERROR level messages from ComfyUI execution system if record.levelno != logging.ERROR: return True - - # Check if this is from ComfyUI execution system - if not (record.name.startswith("comfy") and ("execution" in record.name or record.name == "comfy")): + + # Determine if this record is from a logger we want to inspect for timeout suppression + is_comfy_execution_logger = record.name.startswith("comfy") and ( + "execution" in record.name or record.name == "comfy" + ) + is_comfystream_logger = record.name.startswith("comfystream") + + if not (is_comfy_execution_logger or is_comfystream_logger): return True - + # Get the full message including any exception info message = record.getMessage() - + # Simple check: if this log contains ComfyStreamAudioBufferError or ComfyStreamInputTimeoutError, suppress it - if ("ComfyStreamAudioBufferError" in message or - "ComfyStreamInputTimeoutError" in message): + if ( + "ComfyStreamAudioBufferError" in message + or "ComfyStreamInputTimeoutError" in message + ): return False - + # Also check the exception info if present if record.exc_info and record.exc_info[1]: exc_str = str(record.exc_info[1]) exc_type = str(type(record.exc_info[1])) - - if ("ComfyStreamAudioBufferError" in exc_str or - "ComfyStreamInputTimeoutError" in exc_str or - "ComfyStreamAudioBufferError" in exc_type or - "ComfyStreamInputTimeoutError" in exc_type): + + if ( + "ComfyStreamAudioBufferError" in exc_str + or "ComfyStreamInputTimeoutError" in exc_str + or "ComfyStreamAudioBufferError" in exc_type + or "ComfyStreamInputTimeoutError" in exc_type + ): return False - + return True except Exception as e: # If filter fails, allow the log through and print the error diff --git a/src/comfystream/modalities.py b/src/comfystream/modalities.py index fccfabad5..ded168296 100644 --- a/src/comfystream/modalities.py +++ b/src/comfystream/modalities.py @@ -1,27 +1,29 @@ -from typing import Dict, Any, Set, Union, List, TypedDict +from typing import Any, Dict, List, Set, TypedDict, Union class ModalityIO(TypedDict): """Input/output capabilities for a single modality.""" + input: bool output: bool + class WorkflowModality(TypedDict): """Workflow modality detection result mapping modalities to their I/O capabilities.""" + video: ModalityIO audio: ModalityIO text: ModalityIO + # Centralized node type definitions NODE_TYPES = { # Video nodes "video_input": {"LoadTensor", "PrimaryInputLoadImage", "LoadImage"}, "video_output": {"SaveTensor", "PreviewImage", "SaveImage"}, - # Audio nodes "audio_input": {"LoadAudioTensor"}, "audio_output": {"SaveAudioTensor"}, - # Text nodes "text_input": set(), # No text input nodes currently "text_output": {"SaveTextTensor"}, @@ -29,7 +31,9 @@ class WorkflowModality(TypedDict): # Flatten all input and output node types for easier checking all_input_nodes = NODE_TYPES["video_input"] | NODE_TYPES["audio_input"] | NODE_TYPES["text_input"] -all_output_nodes = NODE_TYPES["video_output"] | NODE_TYPES["audio_output"] | NODE_TYPES["text_output"] +all_output_nodes = ( + NODE_TYPES["video_output"] | NODE_TYPES["audio_output"] | NODE_TYPES["text_output"] +) # Modality mappings derived from NODE_TYPES MODALITY_MAPPINGS = { @@ -55,41 +59,45 @@ class WorkflowModality(TypedDict): "SaveImage": "output_replacement", } + def get_node_counts_by_type(prompt: Dict[Any, Any]) -> Dict[str, int]: """Count nodes by their functional types (primary inputs, inputs, outputs).""" counts = {"primary_inputs": 0, "inputs": 0, "outputs": 0} - + for node in prompt.values(): class_type = node.get("class_type") - + if class_type == "PrimaryInputLoadImage": counts["primary_inputs"] += 1 elif class_type in all_input_nodes: counts["inputs"] += 1 elif class_type in all_output_nodes: counts["outputs"] += 1 - + return counts + def get_convertible_node_keys(prompt: Dict[Any, Any]) -> Dict[str, List[str]]: """Collect keys of nodes that need conversion, organized by node type.""" keys = {node_type: [] for node_type in CONVERTIBLE_NODES.keys()} - + for key, node in prompt.items(): class_type = node.get("class_type") if class_type in keys: keys[class_type].append(key) - + return keys + def create_empty_workflow_modality() -> WorkflowModality: """Create an empty WorkflowModality with all capabilities set to False.""" return { "video": {"input": False, "output": False}, "audio": {"input": False, "output": False}, - "text": {"input": False, "output": False}, + "text": {"input": False, "output": False}, } + def _merge_workflow_modalities(base: WorkflowModality, other: WorkflowModality) -> WorkflowModality: """Merge two WorkflowModality objects using logical OR for all capabilities.""" for modality in base: @@ -97,6 +105,7 @@ def _merge_workflow_modalities(base: WorkflowModality, other: WorkflowModality) base[modality][direction] = base[modality][direction] or other[modality][direction] return base + def detect_io_points(prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]) -> WorkflowModality: """Detect input/output presence per modality for a workflow. @@ -117,7 +126,7 @@ def detect_io_points(prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]) -> Wo # Scan nodes and detect modality I/O points using centralized mappings for node in prompts.values(): class_type = node.get("class_type", "") - + for modality, directions in MODALITY_MAPPINGS.items(): if class_type in directions["input"]: result[modality]["input"] = True @@ -126,17 +135,18 @@ def detect_io_points(prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]) -> Wo return result + def detect_prompt_modalities(prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]) -> Set[str]: """Detect which modalities are used by a workflow. - + Returns a set of modality names that have either input or output nodes. This is used by the pipeline to determine which modalities need processing. """ io_points = detect_io_points(prompts) modalities = set() - + for modality, capabilities in io_points.items(): if capabilities["input"] or capabilities["output"]: modalities.add(modality) - + return modalities diff --git a/src/comfystream/pipeline.py b/src/comfystream/pipeline.py index 3e4febaf1..cf2302de7 100644 --- a/src/comfystream/pipeline.py +++ b/src/comfystream/pipeline.py @@ -1,42 +1,63 @@ -import av -import torch -import numpy as np import asyncio import logging -from typing import Any, Dict, Union, List, Optional, Set +from typing import Any, Dict, List, Optional, Set, Union + +import av +import numpy as np +import torch from comfystream.client import ComfyStreamClient +from comfystream.pipeline_state import PipelineState, PipelineStateManager from comfystream.server.utils import temporary_log_level -from .modalities import detect_prompt_modalities, detect_io_points, WorkflowModality -from .modalities import create_empty_workflow_modality - +from comfystream.utils import get_default_workflow + +from .modalities import ( + WorkflowModality, + create_empty_workflow_modality, + detect_io_points, + detect_prompt_modalities, +) + WARMUP_RUNS = 5 +BOOTSTRAP_TIMEOUT_SECONDS = 30.0 logger = logging.getLogger(__name__) class Pipeline: """A pipeline for processing video and audio frames using ComfyUI. - + This class provides a high-level interface for processing video and audio frames through a ComfyUI-based processing pipeline. It handles frame preprocessing, postprocessing, and queue management. """ - - def __init__(self, width: int = 512, height: int = 512, - comfyui_inference_log_level: Optional[int] = None, **kwargs): + + def __init__( + self, + width: int = 512, + height: int = 512, + comfyui_inference_log_level: Optional[int] = None, + auto_warmup: bool = False, + bootstrap_default_prompt: bool = True, + **kwargs, + ): """Initialize the pipeline with the given configuration. - + Args: width: Width of the video frames (default: 512) height: Height of the video frames (default: 512) comfyui_inference_log_level: The logging level for ComfyUI inference. Defaults to None, using the global ComfyUI log level. + auto_warmup: Whether to run warmup automatically after prompts are set. + bootstrap_default_prompt: Whether to run the default workflow once during + initialization to start ComfyUI before prompts are applied. **kwargs: Additional arguments to pass to the ComfyStreamClient """ self.client = ComfyStreamClient(**kwargs) self.width = width self.height = height + self.auto_warmup = auto_warmup + self.bootstrap_default_prompt = bootstrap_default_prompt self.video_incoming_frames = asyncio.Queue() self.audio_incoming_frames = asyncio.Queue() @@ -46,6 +67,143 @@ def __init__(self, width: int = 512, height: int = 512, self._comfyui_inference_log_level = comfyui_inference_log_level self._cached_modalities: Optional[Set[str]] = None self._cached_io_capabilities: Optional[WorkflowModality] = None + self.state_manager = PipelineStateManager(self.client) + self._bootstrap_completed = False + self._initialize_lock = asyncio.Lock() + self._ingest_enabled = True + self._prompt_update_lock = asyncio.Lock() + + @property + def state(self) -> PipelineState: + """Expose current pipeline state.""" + return self.state_manager.state + + async def initialize(self): + """Run optional bootstrap workflow to start ComfyUI before prompts are set.""" + if self._bootstrap_completed or not self.bootstrap_default_prompt: + return + + async with self._initialize_lock: + if self._bootstrap_completed or not self.bootstrap_default_prompt: + return + + logger.info("Bootstrapping ComfyUI with default workflow") + await self._run_bootstrap_prompt() + self._bootstrap_completed = True + + async def _run_bootstrap_prompt(self): + """Run the default workflow once with a dummy frame to start ComfyUI.""" + default_workflow = get_default_workflow() + logger.debug("Running default workflow bootstrap prompt") + + try: + await self.client.set_prompts([default_workflow]) + await self.client.resume_prompts() + + dummy_frame = av.VideoFrame() + dummy_frame.side_data.input = torch.randn(1, self.height, self.width, 3) + self.client.put_video_input(dummy_frame) + + await asyncio.wait_for( + self.client.get_video_output(), + timeout=BOOTSTRAP_TIMEOUT_SECONDS, + ) + logger.info("Bootstrap prompt completed successfully") + except asyncio.TimeoutError as exc: + logger.error("Timeout while waiting for bootstrap prompt output") + raise RuntimeError("Bootstrap prompt timed out while waiting for output") from exc + finally: + try: + await self.client.stop_prompts(cleanup=False) + except Exception: + logger.debug("Failed to stop bootstrap prompts cleanly", exc_info=True) + + self.client.current_prompts = [] + self._cached_modalities = None + self._cached_io_capabilities = None + + try: + await self.client.cleanup_queues() + except Exception: + logger.debug("Failed to clear tensor caches after bootstrap prompt", exc_info=True) + + async def warmup( + self, + *, + warm_video: Optional[bool] = None, + warm_audio: Optional[bool] = None, + ): + """Run warmup for selected modalities while managing pipeline state.""" + if not self.state_manager.is_initialized(): + raise RuntimeError("Cannot warm up pipeline before prompts are initialized") + + state_before = self.state + transitioned = False + warmup_successful = False + + try: + if state_before != PipelineState.STREAMING: + await self.state_manager.transition_to(PipelineState.INITIALIZING) + transitioned = True + + await self._run_warmup( + warm_video=warm_video, + warm_audio=warm_audio, + ) + warmup_successful = True + except Exception: + await self.state_manager.transition_to(PipelineState.ERROR) + raise + finally: + if transitioned and warmup_successful: + try: + await self.state_manager.transition_to(PipelineState.READY) + except Exception: + logger.exception("Failed to transition pipeline to READY after warmup") + warmup_successful = False + + if warmup_successful and state_before == PipelineState.STREAMING: + try: + await self.state_manager.transition_to(PipelineState.STREAMING) + except Exception: + logger.exception("Failed to restore STREAMING state after warmup") + + async def _run_warmup( + self, + *, + warm_video: Optional[bool] = None, + warm_audio: Optional[bool] = None, + ): + """Run warmup routines for video and audio as requested.""" + capabilities = self.get_workflow_io_capabilities() + + video_config = capabilities.get("video", {}) + audio_config = capabilities.get("audio", {}) + + should_warm_video = ( + warm_video + if warm_video is not None + else bool(video_config.get("input") or video_config.get("output")) + ) + should_warm_audio = ( + warm_audio + if warm_audio is not None + else bool(audio_config.get("input") or audio_config.get("output")) + ) + + if should_warm_video: + logger.debug("Running video warmup routine") + await self.warm_video() + + if should_warm_audio: + logger.debug("Running audio warmup routine") + await self.warm_audio() + + logger.info( + "Pipeline warmup completed (video=%s, audio=%s)", + should_warm_video, + should_warm_audio, + ) async def warm_video(self): """Warm up the video processing pipeline with dummy frames.""" @@ -53,16 +211,16 @@ async def warm_video(self): if not self.accepts_video_input(): logger.debug("Skipping video warmup - workflow doesn't accept video input") return - + # Create dummy frame with the CURRENT resolution settings dummy_frame = av.VideoFrame() dummy_frame.side_data.input = torch.randn(1, self.height, self.width, 3) - + logger.debug(f"Warming video pipeline with resolution {self.width}x{self.height}") for _ in range(WARMUP_RUNS): self.client.put_video_input(dummy_frame) - + # Wait on the outputs that the workflow actually produces if self.produces_video_output(): await self.client.get_video_output() @@ -77,14 +235,16 @@ async def warm_audio(self): if not self.accepts_audio_input(): logger.debug("Skipping audio warmup - workflow doesn't accept audio input") return - + dummy_frame = av.AudioFrame() - dummy_frame.side_data.input = np.random.randint(-32768, 32768, int(48000 * 0.5), dtype=np.int16) + dummy_frame.side_data.input = np.random.randint( + -32768, 32768, int(48000 * 0.5), dtype=np.int16 + ) dummy_frame.sample_rate = 48000 for _ in range(WARMUP_RUNS): self.client.put_audio_input(dummy_frame) - + # Wait on the outputs that the workflow actually produces if self.produces_video_output(): await self.client.get_video_output() @@ -93,39 +253,245 @@ async def warm_audio(self): if self.produces_text_output(): await self.client.get_text_output() - async def set_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]): + async def set_prompts( + self, + prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]], + *, + skip_warmup: bool = False, + ): """Set the processing prompts for the pipeline. - + Args: prompts: Either a single prompt dictionary or a list of prompt dictionaries + skip_warmup: Skip automatic warmup even if auto_warmup is enabled """ - if isinstance(prompts, list): - await self.client.set_prompts(prompts) - else: - await self.client.set_prompts([prompts]) - - # Clear cached modalities and I/O capabilities when prompts change - self._cached_modalities = None - self._cached_io_capabilities = None + try: + prompt_list = prompts if isinstance(prompts, list) else [prompts] + await self.client.set_prompts(prompt_list) - async def update_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]): + # Refresh cached modalities and I/O capabilities from the new prompts + self._cached_modalities = detect_prompt_modalities(self.client.current_prompts) + self._cached_io_capabilities = detect_io_points(self.client.current_prompts) + + should_warmup = self.auto_warmup and not skip_warmup + if should_warmup: + await self.state_manager.transition_to(PipelineState.INITIALIZING) + try: + await self._run_warmup() + except Exception: + await self.state_manager.transition_to(PipelineState.ERROR) + raise + + await self.state_manager.transition_to(PipelineState.READY) + except Exception: + logger.exception("Failed to set pipeline prompts") + try: + await self.state_manager.transition_to(PipelineState.ERROR) + except ValueError: + logger.debug("Skipping ERROR transition due to invalid state") + except Exception: + logger.exception("Failed to transition pipeline to ERROR state") + raise + + async def update_prompts( + self, + prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]], + *, + skip_warmup: bool = False, + ): """Update the existing processing prompts. - + Args: prompts: Either a single prompt dictionary or a list of prompt dictionaries + skip_warmup: Skip automatic warmup even if auto_warmup is enabled + """ + was_streaming = self.state == PipelineState.STREAMING + should_warmup = self.auto_warmup and not skip_warmup + + try: + if was_streaming and should_warmup: + await self.state_manager.transition_to(PipelineState.READY) + + prompt_list = prompts if isinstance(prompts, list) else [prompts] + await self.client.update_prompts(prompt_list) + + # Refresh cached modalities and I/O capabilities from the updated prompts + self._cached_modalities = detect_prompt_modalities(self.client.current_prompts) + self._cached_io_capabilities = detect_io_points(self.client.current_prompts) + + if should_warmup: + await self.state_manager.transition_to(PipelineState.INITIALIZING) + try: + await self._run_warmup() + except Exception: + await self.state_manager.transition_to(PipelineState.ERROR) + raise + await self.state_manager.transition_to(PipelineState.READY) + + if was_streaming and self.state != PipelineState.STREAMING: + await self.state_manager.transition_to(PipelineState.STREAMING) + except Exception: + logger.exception("Failed to update pipeline prompts") + try: + await self.state_manager.transition_to(PipelineState.ERROR) + except ValueError: + logger.debug("Skipping ERROR transition due to invalid state") + except Exception: + logger.exception("Failed to transition pipeline to ERROR state") + raise + + def disable_ingest(self) -> None: + """Temporarily disable ingestion of new frames into the pipeline.""" + self._ingest_enabled = False + + def enable_ingest(self) -> None: + """Re-enable ingestion of new frames into the pipeline.""" + self._ingest_enabled = True + + def is_ingest_enabled(self) -> bool: + """Check if the pipeline is currently ingesting new frames.""" + return self._ingest_enabled + + async def apply_prompts( + self, + prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]], + *, + skip_warmup: bool = False, + warm_video: Optional[bool] = None, + warm_audio: Optional[bool] = None, + ) -> WorkflowModality: + """Atomically replace prompts while coordinating runner, queues, and state. + + This helper orchestrates prompt swaps by pausing streaming, cancelling any + in-flight prompt execution, clearing input queues, applying the new prompts, + warming the pipeline (unless explicitly skipped), and finally resuming + streaming if it was active beforehand. + + Args: + prompts: Prompt dictionary or list of prompt dictionaries to apply. + skip_warmup: If True, skip automatic warmup after applying prompts. + warm_video: Optional override for video warmup (None = auto-detect). + warm_audio: Optional override for audio warmup (None = auto-detect). + + Returns: + WorkflowModality describing I/O capabilities detected from the new prompts. + """ + prompt_list = prompts if isinstance(prompts, list) else [prompts] + + async with self._prompt_update_lock: + was_streaming = self.state == PipelineState.STREAMING + was_initialized = self.state_manager.is_initialized() + restart_streaming = False + capabilities: WorkflowModality | None = None + self.disable_ingest() + + try: + if was_streaming: + await self.pause_prompts() + + if was_initialized: + await self.stop_prompts_immediately() + + await self._clear_pipeline_queues() + await self.client.cleanup_queues() + + await self.set_prompts(prompt_list, skip_warmup=True) + + capabilities = self.get_workflow_io_capabilities() + video_capability = capabilities.get("video", {}) + audio_capability = capabilities.get("audio", {}) + + has_video_io = bool(video_capability.get("input") or video_capability.get("output")) + has_audio_io = bool(audio_capability.get("input") or audio_capability.get("output")) + + if not skip_warmup: + await self.warmup( + warm_video=warm_video if warm_video is not None else has_video_io, + warm_audio=warm_audio if warm_audio is not None else has_audio_io, + ) + + restart_streaming = was_streaming and self.state_manager.can_stream() + + except Exception: + raise + finally: + self.enable_ingest() + if restart_streaming and self.state_manager.can_stream(): + await self.start_streaming() + + return capabilities if capabilities is not None else self.get_workflow_io_capabilities() + + async def start_streaming(self): + """Enable prompt execution for active streaming.""" + if not self.state_manager.can_stream(): + raise RuntimeError(f"Cannot start streaming in state: {self.state.name}") + + await self.state_manager.transition_to(PipelineState.STREAMING) + + async def stop_streaming(self): + """Pause prompt execution while keeping prompts loaded.""" + if self.state == PipelineState.STREAMING: + await self.state_manager.transition_to(PipelineState.READY) + + async def pause_prompts(self): + """Pause prompt execution loops without canceling tasks.""" + await self.stop_streaming() + + async def resume_prompts(self): + """Resume paused prompt execution loops.""" + await self.start_streaming() + + def are_prompts_running(self) -> bool: + """Check if prompts are currently running. + + Returns: + True if prompts are enabled and running, False otherwise """ - if isinstance(prompts, list): - await self.client.update_prompts(prompts) + return self.state == PipelineState.STREAMING + + async def stop_prompts(self, cleanup: bool = False): + """Stop running prompts by canceling their tasks. + + Args: + cleanup: If True, perform full cleanup including queue clearing and + client shutdown. If False, only cancel prompt tasks. + """ + if self.state in {PipelineState.STREAMING, PipelineState.INITIALIZING}: + await self.state_manager.transition_to(PipelineState.READY) + + await self.client.stop_prompts(cleanup=cleanup) + + # Clear cached modalities and I/O capabilities when prompts are stopped + if cleanup: + self._cached_modalities = None + self._cached_io_capabilities = None + # Clear pipeline queues for full cleanup + await self._clear_pipeline_queues() + try: + await self.state_manager.transition_to(PipelineState.UNINITIALIZED) + except Exception: + logger.exception("Failed to transition pipeline to UNINITIALIZED during cleanup") else: - await self.client.update_prompts([prompts]) - - # Clear cached modalities and I/O capabilities when prompts change - self._cached_modalities = None - self._cached_io_capabilities = None + try: + await self.state_manager.transition_to(PipelineState.READY) + except ValueError: + logger.debug("Skipping READY transition due to invalid state") + except Exception: + logger.exception("Failed to ensure READY state after stopping prompts") + + async def stop_prompts_immediately(self): + """Cancel prompt execution tasks without full cleanup.""" + await self.client.stop_prompts_immediately() + try: + await self.state_manager.transition_to(PipelineState.READY) + except ValueError: + logger.debug("Skipping READY transition during immediate stop") + except Exception: + logger.exception("Failed to ensure READY state during immediate stop") async def put_video_frame(self, frame: av.VideoFrame): """Queue a video frame for processing. - + Args: frame: The video frame to process """ @@ -146,7 +512,7 @@ async def put_video_frame(self, frame: av.VideoFrame): async def put_audio_frame(self, frame: av.AudioFrame, preprocess: bool = True): """Queue an audio frame for processing. - + Args: frame: The audio frame to process """ @@ -168,33 +534,37 @@ async def put_audio_frame(self, frame: av.AudioFrame, preprocess: bool = True): def video_preprocess(self, frame: av.VideoFrame) -> torch.Tensor: """Preprocess a video frame before processing. - + Args: frame: The video frame to preprocess - + Returns: The preprocessed frame as a tensor or numpy array """ frame_np = frame.to_ndarray(format="rgb24").astype(np.float32) / 255.0 return torch.from_numpy(frame_np).unsqueeze(0) - + def audio_preprocess(self, frame: av.AudioFrame) -> np.ndarray: """Preprocess an audio frame before processing. - + Args: frame: The audio frame to preprocess - + Returns: The preprocessed frame as a numpy array with int16 dtype """ audio_data = frame.to_ndarray() - + # Handle multi-dimensional audio data - if audio_data.ndim == 2 and audio_data.shape[0] == 1 and audio_data.shape[0] <= audio_data.shape[1]: + if ( + audio_data.ndim == 2 + and audio_data.shape[0] == 1 + and audio_data.shape[0] <= audio_data.shape[1] + ): audio_data = audio_data.ravel().reshape(-1, 2).mean(axis=1) elif audio_data.ndim > 1: audio_data = audio_data.mean(axis=0) - + # Ensure we always return int16 data if audio_data.dtype in [np.float32, np.float64]: # Check if data is normalized (-1.0 to 1.0 range) @@ -209,15 +579,15 @@ def audio_preprocess(self, frame: av.AudioFrame) -> np.ndarray: else: # Already integer data - ensure it's int16 audio_data = audio_data.astype(np.int16) - + return audio_data - + def video_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.VideoFrame: """Postprocess a video frame after processing. - + Args: output: The processed output tensor or numpy array - + Returns: The postprocessed video frame """ @@ -227,32 +597,32 @@ def video_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.Video def audio_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.AudioFrame: """Postprocess an audio frame after processing. - + Args: output: The processed output tensor or numpy array - + Returns: The postprocessed audio frame """ return av.AudioFrame.from_ndarray(np.repeat(output, 2).reshape(1, -1)) - + # TODO: make it generic to support purely generative video cases async def get_processed_video_frame(self) -> av.VideoFrame: """Get the next processed video frame. - + Returns: The processed video frame, or original frame if no processing needed """ frame = await self.video_incoming_frames.get() - + # Skip frames that were marked as skipped - while frame.side_data.skipped and not hasattr(frame.side_data, 'passthrough'): + while frame.side_data.skipped and not hasattr(frame.side_data, "passthrough"): frame = await self.video_incoming_frames.get() - + # If this is a passthrough frame (no video output from workflow), return original - if hasattr(frame.side_data, 'passthrough') and frame.side_data.passthrough: + if hasattr(frame.side_data, "passthrough") and frame.side_data.passthrough: return frame - + # Get processed output from client async with temporary_log_level("comfy", self._comfyui_inference_log_level): out_tensor = await self.client.get_video_output() @@ -260,12 +630,12 @@ async def get_processed_video_frame(self) -> av.VideoFrame: processed_frame = self.video_postprocess(out_tensor) processed_frame.pts = frame.pts processed_frame.time_base = frame.time_base - + return processed_frame async def get_processed_audio_frame(self) -> av.AudioFrame: """Get the next processed audio frame. - + Returns: The processed audio frame, or original frame if no processing needed """ @@ -276,125 +646,124 @@ async def get_processed_audio_frame(self) -> av.AudioFrame: logger.debug("No audio frames available - generating silence frame") # Generate a silent audio frame to prevent blocking silent_frame = av.AudioFrame.from_ndarray( - np.zeros((1, 1024), dtype=np.int16), - format='s16', - layout='mono' + np.zeros((1, 1024), dtype=np.int16), format="s16", layout="mono" ) silent_frame.sample_rate = 48000 return silent_frame - + # If this is a passthrough frame (no audio output from workflow), return original - if hasattr(frame.side_data, 'passthrough') and frame.side_data.passthrough: + if hasattr(frame.side_data, "passthrough") and frame.side_data.passthrough: return frame - + # Process audio if needed if frame.samples > len(self.processed_audio_buffer): async with temporary_log_level("comfy", self._comfyui_inference_log_level): out_tensor = await self.client.get_audio_output() self.processed_audio_buffer = np.concatenate([self.processed_audio_buffer, out_tensor]) - - out_data = self.processed_audio_buffer[:frame.samples] - self.processed_audio_buffer = self.processed_audio_buffer[frame.samples:] + + out_data = self.processed_audio_buffer[: frame.samples] + self.processed_audio_buffer = self.processed_audio_buffer[frame.samples :] processed_frame = self.audio_postprocess(out_data) processed_frame.pts = frame.pts processed_frame.time_base = frame.time_base processed_frame.sample_rate = frame.sample_rate - + return processed_frame - + async def get_text_output(self) -> str | None: """Get the next text output from the pipeline. - + Returns: The processed text output, or empty string if no text output produced """ # If workflow doesn't produce text output, return empty string immediately if not self.produces_text_output(): return None - + async with temporary_log_level("comfy", self._comfyui_inference_log_level): out_text = await self.client.get_text_output() - + return out_text - + async def get_nodes_info(self) -> Dict[str, Any]: """Get information about all nodes in the current prompt including metadata. - + Returns: Dictionary containing node information """ nodes_info = await self.client.get_available_nodes() return nodes_info - + def get_workflow_io_capabilities(self) -> WorkflowModality: """Get the I/O capabilities for each modality in the current workflow. - + Returns: WorkflowModality mapping each modality to its input/output capabilities """ if self._cached_io_capabilities is None: - if not hasattr(self.client, 'current_prompts') or not self.client.current_prompts: - # Return empty capabilities if no prompts - return create_empty_workflow_modality() - - self._cached_io_capabilities = detect_io_points(self.client.current_prompts) - + if not hasattr(self.client, "current_prompts") or not self.client.current_prompts: + # Cache empty capabilities if no prompts to avoid repeated checks + self._cached_io_capabilities = create_empty_workflow_modality() + else: + self._cached_io_capabilities = detect_io_points(self.client.current_prompts) + return self._cached_io_capabilities def get_workflow_modalities(self) -> Set[str]: """Get the modalities required by the current workflow. - + Returns: Set of modality strings: {'video', 'audio', 'text'} """ if self._cached_modalities is None: - if not hasattr(self.client, 'current_prompts') or not self.client.current_prompts: - return set() - - self._cached_modalities = detect_prompt_modalities(self.client.current_prompts) - + if not hasattr(self.client, "current_prompts") or not self.client.current_prompts: + # Cache empty set if no prompts to avoid repeated checks + self._cached_modalities = set() + else: + self._cached_modalities = detect_prompt_modalities(self.client.current_prompts) + return self._cached_modalities - + def get_modalities(self) -> Set[str]: """Alias for get_workflow_modalities for compatibility.""" return self.get_workflow_modalities() - + def requires_video(self) -> bool: """Check if the workflow requires video processing.""" return "video" in self.get_workflow_modalities() - + def requires_audio(self) -> bool: """Check if the workflow requires audio processing.""" return "audio" in self.get_workflow_modalities() - + def requires_text(self) -> bool: """Check if the workflow requires text processing.""" return "text" in self.get_workflow_modalities() - + def accepts_video_input(self) -> bool: """Check if the workflow accepts video input.""" return self.get_workflow_io_capabilities()["video"]["input"] - + def accepts_audio_input(self) -> bool: """Check if the workflow accepts audio input.""" return self.get_workflow_io_capabilities()["audio"]["input"] - + def produces_video_output(self) -> bool: """Check if the workflow produces video output.""" return self.get_workflow_io_capabilities()["video"]["output"] - + def produces_audio_output(self) -> bool: """Check if the workflow produces audio output.""" return self.get_workflow_io_capabilities()["audio"]["output"] - + def produces_text_output(self) -> bool: """Check if the workflow produces text output.""" return self.get_workflow_io_capabilities()["text"]["output"] - + async def cleanup(self): """Clean up resources used by the pipeline. - + This includes: - Canceling running prompts - Clearing all queues (video, audio, tensor caches) @@ -402,19 +771,26 @@ async def cleanup(self): - Clearing cached modalities """ logger.debug("Starting pipeline cleanup") - + # Clear cached modalities and I/O capabilities since we're resetting self._cached_modalities = None self._cached_io_capabilities = None - + # Clear pipeline queues await self._clear_pipeline_queues() - + # Cleanup client (this handles prompt cancellation and tensor cache cleanup) await self.client.cleanup() - + + try: + await self.state_manager.transition_to(PipelineState.UNINITIALIZED) + except ValueError: + logger.debug("Skipping UNINITIALIZED transition during cleanup") + except Exception: + logger.exception("Failed to transition pipeline to UNINITIALIZED during cleanup") + logger.debug("Pipeline cleanup completed") - + async def _clear_pipeline_queues(self): """Clear the pipeline's internal frame queues.""" # Clear video frame queue @@ -423,15 +799,15 @@ async def _clear_pipeline_queues(self): self.video_incoming_frames.get_nowait() except asyncio.QueueEmpty: break - - # Clear audio frame queue + + # Clear audio frame queue while not self.audio_incoming_frames.empty(): try: self.audio_incoming_frames.get_nowait() except asyncio.QueueEmpty: break - + # Reset audio buffer self.processed_audio_buffer = np.array([], dtype=np.int16) - - logger.debug("Pipeline queues cleared") + + logger.debug("Pipeline queues cleared") diff --git a/src/comfystream/pipeline_state.py b/src/comfystream/pipeline_state.py new file mode 100644 index 000000000..0b2f93b94 --- /dev/null +++ b/src/comfystream/pipeline_state.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +import asyncio +import logging +from enum import Enum, auto +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from comfystream.client import ComfyStreamClient + +logger = logging.getLogger(__name__) + + +class PipelineState(Enum): + """Pipeline lifecycle states.""" + + UNINITIALIZED = auto() + INITIALIZING = auto() + READY = auto() + STREAMING = auto() + ERROR = auto() + + +class PipelineStateManager: + """Manages pipeline state transitions and runner lifecycle.""" + + def __init__(self, client: ComfyStreamClient): + self.client = client + self._state = PipelineState.UNINITIALIZED + self._state_lock = asyncio.Lock() + + @property + def state(self) -> PipelineState: + return self._state + + async def transition_to(self, new_state: PipelineState): + """Transition to a new state with automatic runner management.""" + # Fast path for no-op transitions + if new_state == self._state: + logger.debug("Pipeline state unchanged: %s", new_state.name) + return + + async with self._state_lock: + if new_state == self._state: + logger.debug("Pipeline state unchanged (locked): %s", new_state.name) + return + + old_state = self._state + + if not self._is_valid_transition(old_state, new_state): + raise ValueError(f"Invalid transition: {old_state.name} -> {new_state.name}") + + await self._on_exit_state(old_state) + self._state = new_state + await self._on_enter_state(new_state) + + logger.info("Pipeline state: %s -> %s", old_state.name, new_state.name) + + def _is_valid_transition(self, from_state: PipelineState, to_state: PipelineState) -> bool: + """Define valid state transitions.""" + valid_transitions = { + PipelineState.UNINITIALIZED: { + PipelineState.INITIALIZING, + PipelineState.READY, + PipelineState.ERROR, + }, + PipelineState.INITIALIZING: { + PipelineState.READY, + PipelineState.ERROR, + }, + PipelineState.READY: { + PipelineState.INITIALIZING, + PipelineState.STREAMING, + PipelineState.UNINITIALIZED, + PipelineState.ERROR, + }, + PipelineState.STREAMING: { + PipelineState.READY, + PipelineState.ERROR, + }, + PipelineState.ERROR: { + PipelineState.INITIALIZING, + PipelineState.READY, + PipelineState.UNINITIALIZED, + }, + } + + return to_state in valid_transitions.get(from_state, set()) + + async def _on_enter_state(self, state: PipelineState): + """Actions executed when entering a state.""" + if state == PipelineState.INITIALIZING: + await self.client.resume_prompts() + elif state == PipelineState.READY: + self.client.pause_prompts() + elif state == PipelineState.STREAMING: + await self.client.resume_prompts() + elif state == PipelineState.ERROR: + self.client.pause_prompts() + elif state == PipelineState.UNINITIALIZED: + self.client.pause_prompts() + + async def _on_exit_state(self, _state: PipelineState): + """Actions executed when exiting a state.""" + return + + def can_stream(self) -> bool: + """Check if pipeline is ready to stream.""" + return self._state in {PipelineState.READY, PipelineState.STREAMING} + + def is_initialized(self) -> bool: + """Check if pipeline has been initialized with prompts.""" + return self._state != PipelineState.UNINITIALIZED diff --git a/src/comfystream/scripts/__init__.py b/src/comfystream/scripts/__init__.py index 9273e2423..abcb156fa 100644 --- a/src/comfystream/scripts/__init__.py +++ b/src/comfystream/scripts/__init__.py @@ -1,7 +1,7 @@ -from . import utils from utils import get_config_path, load_model_config -from .setup_nodes import run_setup_nodes + +from . import utils from .setup_models import run_setup_models +from .setup_nodes import run_setup_nodes """Setup scripts for ComfyUI streaming server""" - diff --git a/src/comfystream/scripts/build_trt.py b/src/comfystream/scripts/build_trt.py index 3ef1ac7c3..b329ff20d 100644 --- a/src/comfystream/scripts/build_trt.py +++ b/src/comfystream/scripts/build_trt.py @@ -1,9 +1,9 @@ #!/usr/bin/env python3 +import argparse import os import sys import time -import argparse # Reccomended running from comfystream conda environment # in devcontainer from the workspace/ directory, or comfystream/ if you've checked out the repo @@ -11,7 +11,7 @@ # $> python src/comfystream/scripts/build_trt.py --model /ComfyUI/models/checkpoints/SD1.5/dreamshaper-8.safetensors --out-engine /ComfyUI/output/tensorrt/static-dreamshaper8_SD15_$stat-b-1-h-512-w-512_00001_.engine # Paths path explicitly to use the downloaded comfyUI installation on root -ROOT_DIR="/workspace" +ROOT_DIR = "/workspace" COMFYUI_DIR = "/workspace/ComfyUI" timing_cache_path = "/workspace/ComfyUI/output/tensorrt/timing_cache" @@ -22,15 +22,16 @@ import comfy import comfy.model_management - -from ComfyUI.custom_nodes.ComfyUI_TensorRT.models.supported_models import detect_version_from_model, get_helper_from_model +from ComfyUI.custom_nodes.ComfyUI_TensorRT.models.supported_models import ( + detect_version_from_model, + get_helper_from_model, +) from ComfyUI.custom_nodes.ComfyUI_TensorRT.onnx_utils.export import export_onnx from ComfyUI.custom_nodes.ComfyUI_TensorRT.tensorrt_diffusion_model import TRTDiffusionBackbone + def parse_args(): - parser = argparse.ArgumentParser( - description="Build a TensorRT engine from a ComfyUI model." - ) + parser = argparse.ArgumentParser(description="Build a TensorRT engine from a ComfyUI model.") parser.add_argument( "--model", type=str, @@ -63,10 +64,18 @@ def parse_args(): ) # Dynamic Engine Optional Args - parser.add_argument("--min-width", type=int, default=None, help="Minimum width for dynamic shape (optional)") - parser.add_argument("--min-height", type=int, default=None, help="Minimum height for dynamic shape (optional)") - parser.add_argument("--max-width", type=int, default=None, help="Maximum width for dynamic shape (optional)") - parser.add_argument("--max-height", type=int, default=None, help="Maximum height for dynamic shape (optional)") + parser.add_argument( + "--min-width", type=int, default=None, help="Minimum width for dynamic shape (optional)" + ) + parser.add_argument( + "--min-height", type=int, default=None, help="Minimum height for dynamic shape (optional)" + ) + parser.add_argument( + "--max-width", type=int, default=None, help="Maximum width for dynamic shape (optional)" + ) + parser.add_argument( + "--max-height", type=int, default=None, help="Maximum height for dynamic shape (optional)" + ) parser.add_argument( "--context", @@ -81,12 +90,11 @@ def parse_args(): help="If set, attempts to export the ONNX with FP8 transformations (Flux or standard).", ) parser.add_argument( - "--verbose", - action="store_true", - help="Enable more logging / debug prints." + "--verbose", action="store_true", help="Enable more logging / debug prints." ) return parser.parse_args() + def build_trt_engine( model_path: str, engine_out_path: str, @@ -100,7 +108,7 @@ def build_trt_engine( context_opt: int = 1, num_video_frames: int = 14, fp8: bool = False, - verbose: bool = False + verbose: bool = False, ): """ 1) Load the model from ComfyUI by path or name @@ -121,8 +129,10 @@ def build_trt_engine( if verbose: print(f"[INFO] Starting build for model: {model_path}") print(f" Output Engine Path: {engine_out_path}") - print(f" (batch={batch_size_opt}, H={height_opt}, W={width_opt}, context={context_opt}, " - f"num_video_frames={num_video_frames}, fp8={fp8})") + print( + f" (batch={batch_size_opt}, H={height_opt}, W={width_opt}, context={context_opt}, " + f"num_video_frames={num_video_frames}, fp8={fp8})" + ) # 1) Load model in GPU: comfy.model_management.unload_all_models() @@ -130,11 +140,9 @@ def build_trt_engine( loaded_model = comfy.sd.load_diffusion_model(model_path, model_options={}) if loaded_model is None: raise ValueError("Failed to load model.") - + comfy.model_management.load_models_gpu( - [loaded_model], - force_patch_weights=True, - force_full_load=True + [loaded_model], force_patch_weights=True, force_full_load=True ) # 2) Export to ONNX at the desired shape @@ -151,19 +159,19 @@ def build_trt_engine( print(f"[INFO] Exporting ONNX to: {onnx_path}") export_onnx( - model = loaded_model, - path = onnx_path, - batch_size = batch_size_opt, - height = height_opt, - width = width_opt, - num_video_frames = num_video_frames, - context_multiplier = context_opt, - fp8 = fp8, + model=loaded_model, + path=onnx_path, + batch_size=batch_size_opt, + height=height_opt, + width=width_opt, + num_video_frames=num_video_frames, + context_multiplier=context_opt, + fp8=fp8, ) # 3) Build the TRT engine model_version = detect_version_from_model(loaded_model) - model_helper = get_helper_from_model(loaded_model) + model_helper = get_helper_from_model(loaded_model) trt_model = TRTDiffusionBackbone(model_helper) @@ -171,20 +179,20 @@ def build_trt_engine( is_dynamic = all(v is not None for v in [min_width, max_width, min_height, max_height]) min_config = { "batch_size": batch_size_opt, - "height": min_height if is_dynamic else height_opt, - "width": min_width if is_dynamic else width_opt, + "height": min_height if is_dynamic else height_opt, + "width": min_width if is_dynamic else width_opt, "context_len": context_opt * model_helper.context_len, } opt_config = { "batch_size": batch_size_opt, - "height": height_opt, - "width": width_opt, + "height": height_opt, + "width": width_opt, "context_len": context_opt * model_helper.context_len, } max_config = { "batch_size": batch_size_opt, - "height": max_height if is_dynamic else height_opt, - "width": max_width if is_dynamic else width_opt, + "height": max_height if is_dynamic else height_opt, + "width": max_width if is_dynamic else width_opt, "context_len": context_opt * model_helper.context_len, } @@ -197,12 +205,12 @@ def build_trt_engine( print(f"[INFO] Building engine -> {engine_out_path}") success = trt_model.build( - onnx_path = onnx_path, - engine_path = engine_out_path, - timing_cache_path = timing_cache_path, - opt_config = opt_config, - min_config = min_config, - max_config = max_config, + onnx_path=onnx_path, + engine_path=engine_out_path, + timing_cache_path=timing_cache_path, + opt_config=opt_config, + min_config=min_config, + max_config=max_config, ) if not success: raise RuntimeError("[ERROR] TensorRT engine build failed") @@ -224,18 +232,18 @@ def build_trt_engine( def main(): args = parse_args() build_trt_engine( - model_path = args.model, - engine_out_path = args.out_engine, - batch_size_opt = args.batch_size, - height_opt = args.height, - width_opt = args.width, - min_width = args.min_width, - min_height = args.min_height, - max_width = args.max_width, - max_height = args.max_height, - context_opt = args.context, - fp8 = args.fp8, - verbose = args.verbose + model_path=args.model, + engine_out_path=args.out_engine, + batch_size_opt=args.batch_size, + height_opt=args.height, + width_opt=args.width, + min_width=args.min_width, + min_height=args.min_height, + max_width=args.max_width, + max_height=args.max_height, + context_opt=args.context, + fp8=args.fp8, + verbose=args.verbose, ) diff --git a/src/comfystream/scripts/setup_models.py b/src/comfystream/scripts/setup_models.py index 9360a542b..50a186f46 100644 --- a/src/comfystream/scripts/setup_models.py +++ b/src/comfystream/scripts/setup_models.py @@ -1,35 +1,38 @@ +import argparse import os from pathlib import Path + import requests -from tqdm import tqdm import yaml -import argparse +from tqdm import tqdm from utils import get_config_path, load_model_config + def parse_args(): - parser = argparse.ArgumentParser(description='Setup ComfyUI models') - parser.add_argument('--workspace', - default=os.environ.get('COMFY_UI_WORKSPACE', os.path.expanduser('~/comfyui')), - help='ComfyUI workspace directory (default: ~/comfyui or $COMFY_UI_WORKSPACE)') + parser = argparse.ArgumentParser(description="Setup ComfyUI models") + parser.add_argument( + "--workspace", + default=os.environ.get("COMFY_UI_WORKSPACE", os.path.expanduser("~/comfyui")), + help="ComfyUI workspace directory (default: ~/comfyui or $COMFY_UI_WORKSPACE)", + ) return parser.parse_args() + def download_file(url, destination, description=None): """Download a file with progress bar, follow redirects, and detect LFS pointer files""" - headers = { - "User-Agent": "Mozilla/5.0" - } + headers = {"User-Agent": "Mozilla/5.0"} with requests.get(url, stream=True, headers=headers, allow_redirects=True) as response: response.raise_for_status() - total_size = int(response.headers.get('content-length', 0)) + total_size = int(response.headers.get("content-length", 0)) desc = description or os.path.basename(destination) - progress_bar = tqdm(total=total_size, unit='iB', unit_scale=True, desc=desc) + progress_bar = tqdm(total=total_size, unit="iB", unit_scale=True, desc=desc) destination = Path(destination) destination.parent.mkdir(parents=True, exist_ok=True) - with open(destination, 'wb') as file: + with open(destination, "wb") as file: for chunk in response.iter_content(chunk_size=1024): if chunk: file.write(chunk) @@ -38,17 +41,18 @@ def download_file(url, destination, description=None): # Verify that we didn't just write a Git LFS pointer if destination.stat().st_size < 100: - with open(destination, 'r', errors='ignore') as f: + with open(destination, "r", errors="ignore") as f: content = f.read() - if 'git-lfs' in content.lower(): + if "git-lfs" in content.lower(): print(f"❌ LFS pointer detected in {destination}. Deleting.") destination.unlink() raise ValueError(f"LFS pointer detected. Failed to download: {url}") + def setup_model_files(workspace_dir, config_path=None): """Download and setup required model files based on configuration""" if config_path is None: - config_path = get_config_path('models.yaml') + config_path = get_config_path("models.yaml") try: config = load_model_config(config_path) except FileNotFoundError: @@ -61,34 +65,31 @@ def setup_model_files(workspace_dir, config_path=None): models_path = workspace_dir / "models" base_path = workspace_dir - for _, model_info in config['models'].items(): + for _, model_info in config["models"].items(): # Determine the full path based on whether it's in custom_nodes or models - if model_info['path'].startswith('custom_nodes/'): - full_path = base_path / model_info['path'] + if model_info["path"].startswith("custom_nodes/"): + full_path = base_path / model_info["path"] else: - full_path = models_path / model_info['path'] + full_path = models_path / model_info["path"] if not full_path.exists(): print(f"Downloading {model_info['name']}...") - download_file( - model_info['url'], - full_path, - f"Downloading {model_info['name']}" - ) + download_file(model_info["url"], full_path, f"Downloading {model_info['name']}") print(f"Downloaded {model_info['name']} to {full_path}") # Handle any extra files (like configs) - if 'extra_files' in model_info: - for extra in model_info['extra_files']: - extra_path = models_path / extra['path'] + if "extra_files" in model_info: + for extra in model_info["extra_files"]: + extra_path = models_path / extra["path"] if not extra_path.exists(): download_file( - extra['url'], + extra["url"], extra_path, - f"Downloading {os.path.basename(extra['path'])}" + f"Downloading {os.path.basename(extra['path'])}", ) print("Models download completed!") + def setup_directories(workspace_dir): """Create required directories in the workspace""" # Create base directories @@ -119,6 +120,7 @@ def setup_directories(workspace_dir): subdir = models_dir / dir_name subdir.mkdir(parents=True, exist_ok=True) + def setup_models(): args = parse_args() workspace_dir = Path(args.workspace) @@ -126,4 +128,5 @@ def setup_models(): setup_directories(workspace_dir) setup_model_files(workspace_dir) + setup_models() diff --git a/src/comfystream/scripts/setup_nodes.py b/src/comfystream/scripts/setup_nodes.py index 418e55f86..2aca10772 100755 --- a/src/comfystream/scripts/setup_nodes.py +++ b/src/comfystream/scripts/setup_nodes.py @@ -1,9 +1,10 @@ +import argparse import os import subprocess import sys from pathlib import Path + import yaml -import argparse from utils import get_config_path, load_model_config @@ -77,7 +78,9 @@ def install_custom_nodes(workspace_dir, config_path=None, pull_branches=False): print(f"Updating {node_info['name']} to latest {node_info['branch']}...") subprocess.run(["git", "-C", dir_name, "fetch", "origin"], check=True) subprocess.run(["git", "-C", dir_name, "checkout", node_info["branch"]], check=True) - subprocess.run(["git", "-C", dir_name, "pull", "origin", node_info["branch"]], check=True) + subprocess.run( + ["git", "-C", dir_name, "pull", "origin", node_info["branch"]], check=True + ) else: print(f"{node_info['name']} already exists, skipping clone.") diff --git a/src/comfystream/scripts/utils.py b/src/comfystream/scripts/utils.py index a7b37f2df..5ce4ab351 100644 --- a/src/comfystream/scripts/utils.py +++ b/src/comfystream/scripts/utils.py @@ -1,12 +1,14 @@ -import yaml from pathlib import Path +import yaml + + def get_config_path(filename): """Get the absolute path to a config file""" config_path = Path("configs") / filename if not config_path.exists(): print(f"Warning: Config file {filename} not found at {config_path}") - print(f"Available files in configs/:") + print("Available files in configs/:") try: for f in Path("configs").glob("*"): print(f" - {f.name}") @@ -15,7 +17,8 @@ def get_config_path(filename): raise FileNotFoundError(f"Config file {filename} not found at {config_path}") return config_path + def load_model_config(config_path): """Load model configuration from YAML file""" - with open(config_path, 'r') as f: - return yaml.safe_load(f) \ No newline at end of file + with open(config_path, "r") as f: + return yaml.safe_load(f) diff --git a/src/comfystream/server/metrics/prometheus_metrics.py b/src/comfystream/server/metrics/prometheus_metrics.py index 080bc2940..e1aec64e5 100644 --- a/src/comfystream/server/metrics/prometheus_metrics.py +++ b/src/comfystream/server/metrics/prometheus_metrics.py @@ -1,9 +1,10 @@ """Prometheus metrics utilities.""" -from prometheus_client import Gauge, generate_latest -from aiohttp import web from typing import Optional +from aiohttp import web +from prometheus_client import Gauge, generate_latest + class MetricsManager: """Manages Prometheus metrics collection.""" @@ -18,9 +19,7 @@ def __init__(self, include_stream_id: bool = False): self._include_stream_id = include_stream_id base_labels = ["stream_id"] if include_stream_id else [] - self._fps_gauge = Gauge( - "stream_fps", "Frames per second of the stream", base_labels - ) + self._fps_gauge = Gauge("stream_fps", "Frames per second of the stream", base_labels) def enable(self): """Enable Prometheus metrics collection.""" diff --git a/src/comfystream/server/metrics/stream_stats.py b/src/comfystream/server/metrics/stream_stats.py index 8dc2ab19e..40d88c782 100644 --- a/src/comfystream/server/metrics/stream_stats.py +++ b/src/comfystream/server/metrics/stream_stats.py @@ -1,7 +1,8 @@ """Handles real-time video stream statistics (non-Prometheus, JSON API).""" -from typing import Any, Dict import json +from typing import Any, Dict + from aiohttp import web from aiortc import MediaStreamTrack @@ -17,9 +18,7 @@ def __init__(self, app: web.Application): """ self._app = app - async def collect_video_metrics( - self, video_track: MediaStreamTrack - ) -> Dict[str, Any]: + async def collect_video_metrics(self, video_track: MediaStreamTrack) -> Dict[str, Any]: """Collects real-time statistics for a video track. Args: diff --git a/src/comfystream/server/utils/__init__.py b/src/comfystream/server/utils/__init__.py index daa71bb1e..f64357c32 100644 --- a/src/comfystream/server/utils/__init__.py +++ b/src/comfystream/server/utils/__init__.py @@ -1,2 +1,2 @@ -from .utils import patch_loop_datagram, add_prefix_to_app_routes, temporary_log_level from .fps_meter import FPSMeter +from .utils import add_prefix_to_app_routes, patch_loop_datagram, temporary_log_level diff --git a/src/comfystream/server/utils/fps_meter.py b/src/comfystream/server/utils/fps_meter.py index 87e75d461..ee86772b2 100644 --- a/src/comfystream/server/utils/fps_meter.py +++ b/src/comfystream/server/utils/fps_meter.py @@ -4,6 +4,7 @@ import logging import time from collections import deque + from comfystream.server.metrics import MetricsManager logger = logging.getLogger(__name__) @@ -35,11 +36,7 @@ async def _calculate_fps_loop(self): current_time = time.monotonic() if self._last_fps_calculation_time is not None: time_diff = current_time - self._last_fps_calculation_time - self._fps = ( - self._fps_interval_frame_count / time_diff - if time_diff > 0 - else 0.0 - ) + self._fps = self._fps_interval_frame_count / time_diff if time_diff > 0 else 0.0 self._fps_measurements.append( { "timestamp": current_time - self._fps_loop_start_time, @@ -92,8 +89,7 @@ async def average_fps(self) -> float: """ async with self._lock: return ( - sum(m["fps"] for m in self._fps_measurements) - / len(self._fps_measurements) + sum(m["fps"] for m in self._fps_measurements) / len(self._fps_measurements) if self._fps_measurements else self._fps ) @@ -106,9 +102,6 @@ async def last_fps_calculation_time(self) -> float: The elapsed time in seconds since the last FPS calculation. """ async with self._lock: - if ( - self._last_fps_calculation_time is None - or self._fps_loop_start_time is None - ): + if self._last_fps_calculation_time is None or self._fps_loop_start_time is None: return 0.0 return self._last_fps_calculation_time - self._fps_loop_start_time diff --git a/src/comfystream/server/utils/utils.py b/src/comfystream/server/utils/utils.py index 96e6661e9..baa7ff6d4 100644 --- a/src/comfystream/server/utils/utils.py +++ b/src/comfystream/server/utils/utils.py @@ -1,12 +1,13 @@ """General utility functions.""" import asyncio +import logging import random import types -import logging -from aiohttp import web -from typing import List, Tuple from contextlib import asynccontextmanager +from typing import List, Tuple + +from aiohttp import web logger = logging.getLogger(__name__) @@ -30,9 +31,7 @@ async def create_datagram_endpoint( protocol_factory, local_addr=local_addr, **kwargs ) if local_addr is None: - return await old_create_datagram_endpoint( - protocol_factory, local_addr=None, **kwargs - ) + return await old_create_datagram_endpoint(protocol_factory, local_addr=None, **kwargs) # if port is not specified make it use our range ports = list(local_ports) random.shuffle(ports) @@ -83,4 +82,3 @@ async def temporary_log_level(logger_name: str, level: int): finally: if level is not None: logger.setLevel(original_level) - diff --git a/src/comfystream/tensor_cache.py b/src/comfystream/tensor_cache.py index 5cd54332e..609f98b8f 100644 --- a/src/comfystream/tensor_cache.py +++ b/src/comfystream/tensor_cache.py @@ -1,11 +1,10 @@ -import torch -import numpy as np - -from queue import Queue from asyncio import Queue as AsyncQueue - +from queue import Queue from typing import Union +import numpy as np +import torch + # TODO: improve eviction policy fifo might not be the best, skip alternate frames instead image_inputs: Queue[Union[torch.Tensor, np.ndarray]] = Queue(maxsize=1) image_outputs: AsyncQueue[Union[torch.Tensor, np.ndarray]] = AsyncQueue() diff --git a/src/comfystream/utils.py b/src/comfystream/utils.py index c2cecd05b..1fd035a8f 100644 --- a/src/comfystream/utils.py +++ b/src/comfystream/utils.py @@ -1,16 +1,17 @@ import copy -import json -import os -import logging import importlib -from typing import Dict, Any, List, Tuple, Optional, Union -from pytrickle.api import StreamParamsUpdateRequest +import json +from typing import Any, Dict + from comfy.api.components.schema.prompt import Prompt, PromptDictInput +from pytrickle.api import StreamParamsUpdateRequest + from .modalities import ( - get_node_counts_by_type, get_convertible_node_keys, + get_node_counts_by_type, ) + def create_load_tensor_node(): return { "inputs": {}, @@ -18,6 +19,7 @@ def create_load_tensor_node(): "_meta": {"title": "LoadTensor"}, } + def create_save_tensor_node(inputs: Dict[Any, Any]): return { "inputs": inputs, @@ -25,6 +27,7 @@ def create_save_tensor_node(inputs: Dict[Any, Any]): "_meta": {"title": "SaveTensor"}, } + def _validate_prompt_constraints(counts: Dict[str, int]) -> None: """Validate that the prompt meets the required constraints.""" if counts["primary_inputs"] > 1: @@ -42,6 +45,7 @@ def _validate_prompt_constraints(counts: Dict[str, int]) -> None: if counts["outputs"] == 0: raise Exception("missing output") + def convert_prompt(prompt: PromptDictInput, return_dict: bool = False) -> Prompt: """Convert a prompt by replacing specific node types with tensor equivalents.""" try: @@ -49,7 +53,7 @@ def convert_prompt(prompt: PromptDictInput, return_dict: bool = False) -> Prompt importlib.import_module("comfy.api.components.schema.prompt_node") except Exception: pass - + """Convert and validate a ComfyUI workflow prompt.""" Prompt.validate(prompt) prompt = copy.deepcopy(prompt) @@ -57,7 +61,7 @@ def convert_prompt(prompt: PromptDictInput, return_dict: bool = False) -> Prompt # Count nodes and validate constraints counts = get_node_counts_by_type(prompt) _validate_prompt_constraints(counts) - + # Collect nodes that need conversion convertible_keys = get_convertible_node_keys(prompt) @@ -73,78 +77,86 @@ def convert_prompt(prompt: PromptDictInput, return_dict: bool = False) -> Prompt for key in convertible_keys["PreviewImage"] + convertible_keys["SaveImage"]: node = prompt[key] prompt[key] = create_save_tensor_node(node["inputs"]) - # Return dict if requested (for downstream components that expect plain dicts) if return_dict: return prompt # Already a plain dict at this point - + # Validate the processed prompt and return Pydantic object return Prompt.validate(prompt) + class ComfyStreamParamsUpdateRequest(StreamParamsUpdateRequest): """ComfyStream parameter validation.""" - + def __init__(self, **data): # Handle prompts parameter if "prompts" in data: prompts = data["prompts"] - + # Parse JSON string if needed if isinstance(prompts, str) and prompts.strip(): try: prompts = json.loads(prompts) except json.JSONDecodeError: data.pop("prompts") - + # Handle list - use first valid dict elif isinstance(prompts, list): prompts = next((p for p in prompts if isinstance(p, dict)), None) if not prompts: data.pop("prompts") - + # Validate prompts if "prompts" in data and isinstance(prompts, dict): try: data["prompts"] = convert_prompt(prompts, return_dict=True) except Exception: data.pop("prompts") - + # Call parent constructor super().__init__(**data) - + @classmethod def model_validate(cls, obj): return cls(**obj) - + def model_dump(self): return super().model_dump() + +def normalize_stream_params(params: Any) -> Dict[str, Any]: + """Normalize stream parameters from various formats to a dict. + + Args: + params: Parameters in dict, list, or other format + + Returns: + Dict containing normalized parameters, empty dict if invalid + """ + if params is None: + return {} + if isinstance(params, dict): + return dict(params) + if isinstance(params, list): + for candidate in params: + if isinstance(candidate, dict): + return dict(candidate) + return {} + return {} + + def get_default_workflow() -> dict: """Return the default workflow as a dictionary for warmup. - + Returns: dict: Default workflow dictionary """ return { "1": { - "inputs": { - "images": [ - "2", - 0 - ] - }, + "inputs": {"images": ["2", 0]}, "class_type": "SaveTensor", - "_meta": { - "title": "SaveTensor" - } + "_meta": {"title": "SaveTensor"}, }, - "2": { - "inputs": {}, - "class_type": "LoadTensor", - "_meta": { - "title": "LoadTensor" - } - } + "2": {"inputs": {}, "class_type": "LoadTensor", "_meta": {"title": "LoadTensor"}}, } - diff --git a/test/test_utils.py b/test/test_utils.py index 30c990caa..70b4209cd 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,6 +1,6 @@ import pytest - from comfy.api.components.schema.prompt import Prompt + from comfystream.utils import convert_prompt diff --git a/ui/package-lock.json b/ui/package-lock.json index ca977d805..5b65c1b9b 100644 --- a/ui/package-lock.json +++ b/ui/package-lock.json @@ -1,12 +1,12 @@ { "name": "ui", - "version": "0.1.6", + "version": "0.1.7", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "ui", - "version": "0.1.6", + "version": "0.1.7", "dependencies": { "@hookform/resolvers": "^3.9.1", "@radix-ui/react-dialog": "^1.1.6", @@ -5265,9 +5265,9 @@ "license": "MIT" }, "node_modules/js-yaml": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.0.tgz", - "integrity": "sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==", + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.1.tgz", + "integrity": "sha512-qQKT4zQxXl8lLwBtHMWwaTcGfFOZviOJet3Oy/xmGk2gZH677CJM9EvtfdSkgWcATZhj/55JZ0rmy3myCT5lsA==", "dev": true, "license": "MIT", "dependencies": { diff --git a/ui/package.json b/ui/package.json index fc4435e6a..7cc8128b7 100644 --- a/ui/package.json +++ b/ui/package.json @@ -1,6 +1,6 @@ { "name": "ui", - "version": "0.1.6", + "version": "0.1.7", "private": true, "scripts": { "dev": "cross-env NEXT_PUBLIC_DEV=true next dev", @@ -9,7 +9,7 @@ "start": "next start", "lint": "next lint", "format": "prettier --write .", - "prepare": "cd .. && husky && husky install ui/.husky" + "prepare": "cd .. && husky ui/.husky" }, "dependencies": { "@hookform/resolvers": "^3.9.1", @@ -52,6 +52,14 @@ }, "lint-staged": { "*.{js,ts,tsx}": "eslint --cache --fix", - "*.{js,ts,tsx,css,md}": "prettier --write" + "*.{js,ts,tsx,css,md}": "prettier --write", + "../*.py": [ + "ruff check --fix", + "ruff format" + ], + "../{src,server,nodes,scripts}/**/*.py": [ + "ruff check --fix", + "ruff format" + ] } } diff --git a/ui/src/components/room.tsx b/ui/src/components/room.tsx index df03ed50f..19f1ce35d 100644 --- a/ui/src/components/room.tsx +++ b/ui/src/components/room.tsx @@ -56,8 +56,20 @@ function useToast() { } // Wrapper component to access peer context -function TranscriptionViewerWrapper() { +function TranscriptionViewerWrapper({ onFirstTextOutput }: { onFirstTextOutput: () => void }) { const peer = usePeerContext(); + const hasNotifiedRef = useRef(false); + + // Watch for first text output and notify parent + useEffect(() => { + if (peer?.textOutputData && peer.textOutputData.trim() && !hasNotifiedRef.current) { + // Check if it's not a warmup message + if (!peer.textOutputData.includes('__WARMUP_SENTINEL__')) { + hasNotifiedRef.current = true; + onFirstTextOutput(); + } + } + }, [peer?.textOutputData, onFirstTextOutput]); return ( { const [isRecordingsPanelOpen, setIsRecordingsPanelOpen] = useState(false); // Transcription state - const [isTranscriptionPanelOpen, setIsTranscriptionPanelOpen] = useState(true); + const [isTranscriptionPanelOpen, setIsTranscriptionPanelOpen] = useState(false); + const [hasReceivedTextOutput, setHasReceivedTextOutput] = useState(false); // Helper to get timestamped filenames const getFilename = (type: 'input' | 'output', extension: string) => { @@ -416,6 +429,7 @@ export const Room = () => { const handleDisconnected = useCallback(() => { setIsConnected(false); setIsComfyUIReady(false); + setHasReceivedTextOutput(false); // Reset text output state showToast("Stream disconnected", "error"); }, [showToast]); @@ -569,6 +583,14 @@ export const Room = () => { const [isControlPanelOpen, setIsControlPanelOpen] = useState(false); + // Callback to handle first text output received + const handleFirstTextOutput = useCallback(() => { + if (!hasReceivedTextOutput && !isTranscriptionPanelOpen) { + setHasReceivedTextOutput(true); + setIsTranscriptionPanelOpen(true); + } + }, [hasReceivedTextOutput, isTranscriptionPanelOpen]); + return (
{ {isConnected && (
- +
)}