From dcbdcfec1634ff8af64464cbe39a814ca3fd3911 Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Tue, 18 Mar 2025 16:08:02 -0400 Subject: [PATCH 1/7] Preliminary work for ComfyUI native API integration. --- server/app_api.py | 438 ++++++++++++++++++ server/pipeline_api.py | 195 ++++++++ src/comfystream/client_api.py | 836 ++++++++++++++++++++++++++++++++++ src/comfystream/utils_api.py | 245 ++++++++++ 4 files changed, 1714 insertions(+) create mode 100644 server/app_api.py create mode 100644 server/pipeline_api.py create mode 100644 src/comfystream/client_api.py create mode 100644 src/comfystream/utils_api.py diff --git a/server/app_api.py b/server/app_api.py new file mode 100644 index 00000000..6158a540 --- /dev/null +++ b/server/app_api.py @@ -0,0 +1,438 @@ +import argparse +import asyncio +import json +import logging +import os +import sys + +import torch + +# Initialize CUDA before any other imports to prevent core dump. +if torch.cuda.is_available(): + torch.cuda.init() + + +from aiohttp import web +from aiortc import ( + MediaStreamTrack, + RTCConfiguration, + RTCIceServer, + RTCPeerConnection, + RTCSessionDescription, +) +from aiortc.codecs import h264 +from aiortc.rtcrtpsender import RTCRtpSender +from pipeline_api import Pipeline # TODO: Better integration (Are we replacing pipeline with pipeline_api?) +from twilio.rest import Client +from utils import patch_loop_datagram, add_prefix_to_app_routes, FPSMeter +from metrics import MetricsManager, StreamStatsManager +import time + +logger = logging.getLogger(__name__) +logging.getLogger("aiortc.rtcrtpsender").setLevel(logging.WARNING) +logging.getLogger("aiortc.rtcrtpreceiver").setLevel(logging.WARNING) + + +MAX_BITRATE = 2000000 +MIN_BITRATE = 2000000 + + +class VideoStreamTrack(MediaStreamTrack): + """video stream track that processes video frames using a pipeline. + + Attributes: + kind (str): The kind of media, which is "video" for this class. + track (MediaStreamTrack): The underlying media stream track. 
+ pipeline (Pipeline): The processing pipeline to apply to each video frame. + """ + + kind = "video" + + def __init__(self, track: MediaStreamTrack, pipeline: Pipeline): + """Initialize the VideoStreamTrack. + + Args: + track: The underlying media stream track. + pipeline: The processing pipeline to apply to each video frame. + """ + super().__init__() + self.track = track + self.pipeline = pipeline + self.fps_meter = FPSMeter( + metrics_manager=app["metrics_manager"], track_id=track.id + ) + self.running = True + self.collect_task = asyncio.create_task(self.collect_frames()) + + # Add cleanup when track ends + @track.on("ended") + async def on_ended(): + logger.info("Source video track ended, stopping collection") + await cancel_collect_frames(self) + + async def collect_frames(self): + """Collect video frames from the underlying track and pass them to + the processing pipeline. Stops when track ends or connection closes. + """ + try: + while self.running: + try: + frame = await self.track.recv() + await self.pipeline.put_video_frame(frame) + except asyncio.CancelledError: + logger.info("Frame collection cancelled") + break + except Exception as e: + if "MediaStreamError" in str(type(e)): + logger.info("Media stream ended") + else: + logger.error(f"Error collecting video frames: {str(e)}") + self.running = False + break + + # Perform cleanup outside the exception handler + logger.info("Video frame collection stopped") + except asyncio.CancelledError: + logger.info("Frame collection task cancelled") + except Exception as e: + logger.error(f"Unexpected error in frame collection: {str(e)}") + finally: + await self.pipeline.cleanup() + + async def recv(self): + """Receive a processed video frame from the pipeline, increment the frame + count for FPS calculation and return the processed frame to the client. + """ + processed_frame = await self.pipeline.get_processed_video_frame() + + # Increment the frame count to calculate FPS. 
+ await self.fps_meter.increment_frame_count() + + return processed_frame + + +class AudioStreamTrack(MediaStreamTrack): + kind = "audio" + + def __init__(self, track: MediaStreamTrack, pipeline): + super().__init__() + self.track = track + self.pipeline = pipeline + self.running = True + self.collect_task = asyncio.create_task(self.collect_frames()) + + # Add cleanup when track ends + @track.on("ended") + async def on_ended(): + logger.info("Source audio track ended, stopping collection") + await cancel_collect_frames(self) + + async def collect_frames(self): + """Collect audio frames from the underlying track and pass them to + the processing pipeline. Stops when track ends or connection closes. + """ + try: + while self.running: + try: + frame = await self.track.recv() + await self.pipeline.put_audio_frame(frame) + except asyncio.CancelledError: + logger.info("Audio frame collection cancelled") + break + except Exception as e: + if "MediaStreamError" in str(type(e)): + logger.info("Media stream ended") + else: + logger.error(f"Error collecting audio frames: {str(e)}") + self.running = False + break + + # Perform cleanup outside the exception handler + logger.info("Audio frame collection stopped") + except asyncio.CancelledError: + logger.info("Frame collection task cancelled") + except Exception as e: + logger.error(f"Unexpected error in audio frame collection: {str(e)}") + finally: + await self.pipeline.cleanup() + + async def recv(self): + return await self.pipeline.get_processed_audio_frame() + + +def force_codec(pc, sender, forced_codec): + kind = forced_codec.split("/")[0] + codecs = RTCRtpSender.getCapabilities(kind).codecs + transceiver = next(t for t in pc.getTransceivers() if t.sender == sender) + codecPrefs = [codec for codec in codecs if codec.mimeType == forced_codec] + transceiver.setCodecPreferences(codecPrefs) + + +def get_twilio_token(): + account_sid = os.getenv("TWILIO_ACCOUNT_SID") + auth_token = os.getenv("TWILIO_AUTH_TOKEN") + + if 
account_sid is None or auth_token is None: + return None + + client = Client(account_sid, auth_token) + + token = client.tokens.create() + + return token + + +def get_ice_servers(): + ice_servers = [] + + token = get_twilio_token() + if token is not None: + # Use Twilio TURN servers + for server in token.ice_servers: + if server["url"].startswith("turn:"): + turn = RTCIceServer( + urls=[server["urls"]], + credential=server["credential"], + username=server["username"], + ) + ice_servers.append(turn) + + return ice_servers + + +async def offer(request): + pipeline = request.app["pipeline"] + pcs = request.app["pcs"] + + params = await request.json() + + await pipeline.set_prompts(params["prompts"]) + + offer_params = params["offer"] + offer = RTCSessionDescription(sdp=offer_params["sdp"], type=offer_params["type"]) + + ice_servers = get_ice_servers() + if len(ice_servers) > 0: + pc = RTCPeerConnection( + configuration=RTCConfiguration(iceServers=ice_servers)  # reuse the list fetched above; calling get_ice_servers() again would request another Twilio token + ) + else: + pc = RTCPeerConnection() + + pcs.add(pc) + + tracks = {"video": None, "audio": None} + + # Only add video transceiver if video is present in the offer + if "m=video" in offer.sdp: + # Prefer h264 + transceiver = pc.addTransceiver("video") + caps = RTCRtpSender.getCapabilities("video") + prefs = list(filter(lambda x: x.name == "H264", caps.codecs)) + transceiver.setCodecPreferences(prefs) + + # Monkey patch max and min bitrate to ensure constant bitrate + h264.MAX_BITRATE = MAX_BITRATE + h264.MIN_BITRATE = MIN_BITRATE + + # Handle control channel from client + @pc.on("datachannel") + def on_datachannel(channel): + if channel.label == "control": + + @channel.on("message") + async def on_message(message): + try: + params = json.loads(message) + + if params.get("type") == "get_nodes": + nodes_info = await pipeline.get_nodes_info() + response = {"type": "nodes_info", "nodes": nodes_info} + channel.send(json.dumps(response)) + elif params.get("type") == "update_prompts": + if "prompts" not in 
params: + logger.warning( + "[Control] Missing prompt in update_prompt message" + ) + return + await pipeline.update_prompts(params["prompts"]) + response = {"type": "prompts_updated", "success": True} + channel.send(json.dumps(response)) + else: + logger.warning( + "[Server] Invalid message format - missing required fields" + ) + except json.JSONDecodeError: + logger.error("[Server] Invalid JSON received") + except Exception as e: + logger.error(f"[Server] Error processing message: {str(e)}") + + @pc.on("track") + def on_track(track): + logger.info(f"Track received: {track.kind}") + if track.kind == "video": + videoTrack = VideoStreamTrack(track, pipeline) + tracks["video"] = videoTrack + sender = pc.addTrack(videoTrack) + + # Store video track in app for stats. + stream_id = track.id + request.app["video_tracks"][stream_id] = videoTrack + + codec = "video/H264" + force_codec(pc, sender, codec) + elif track.kind == "audio": + audioTrack = AudioStreamTrack(track, pipeline) + tracks["audio"] = audioTrack + pc.addTrack(audioTrack) + + @track.on("ended") + async def on_ended(): + logger.info(f"{track.kind} track ended") + request.app["video_tracks"].pop(track.id, None) + + @pc.on("connectionstatechange") + async def on_connectionstatechange(): + logger.info(f"Connection state is: {pc.connectionState}") + if pc.connectionState == "failed": + await pc.close() + pcs.discard(pc) + elif pc.connectionState == "closed": + await pc.close() + pcs.discard(pc) + + await pc.setRemoteDescription(offer) + + if "m=audio" in pc.remoteDescription.sdp: + await pipeline.warm_audio() + if "m=video" in pc.remoteDescription.sdp: + await pipeline.warm_video() + + answer = await pc.createAnswer() + await pc.setLocalDescription(answer) + + return web.Response( + content_type="application/json", + text=json.dumps( + {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type} + ), + ) + +async def cancel_collect_frames(track): + track.running = False + if hasattr(track, 'collect_task') 
and track.collect_task is not None and not track.collect_task.done():  # hasattr() returns a bool, so the task value itself must be checked for None + try: + track.collect_task.cancel() + await track.collect_task + except (asyncio.CancelledError): + pass + +async def set_prompt(request): + pipeline = request.app["pipeline"] + + prompt = await request.json() + await pipeline.set_prompts(prompt) + + return web.Response(content_type="application/json", text="OK") + + +def health(_): + return web.Response(content_type="application/json", text="OK") + + +async def on_startup(app: web.Application): + if app["media_ports"]: + patch_loop_datagram(app["media_ports"]) + + app["pipeline"] = Pipeline( + cwd=app["workspace"], disable_cuda_malloc=True, gpu_only=True, preview_method='none' + ) + app["pcs"] = set() + app["video_tracks"] = {} + + +async def on_shutdown(app: web.Application): + pcs = app["pcs"] + coros = [pc.close() for pc in pcs] + await asyncio.gather(*coros) + pcs.clear() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run comfystream server") + parser.add_argument("--port", default=8889, help="Set the signaling port") + parser.add_argument( + "--media-ports", default=None, help="Set the UDP ports for WebRTC media" + ) + parser.add_argument("--host", default="127.0.0.1", help="Set the host") + parser.add_argument( + "--workspace", default=None, required=True, help="Set Comfy workspace" + ) + parser.add_argument( + "--log-level", + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + help="Set the logging level", + ) + parser.add_argument( + "--monitor", + default=False, + action="store_true", + help="Start a Prometheus metrics endpoint for monitoring.", + ) + parser.add_argument( + "--stream-id-label", + default=False, + action="store_true", + help="Include stream ID as a label in Prometheus metrics.", + ) + args = parser.parse_args() + + logging.basicConfig( + level=args.log_level.upper(), + format="%(asctime)s [%(levelname)s] %(message)s", + datefmt="%H:%M:%S", + ) + + app = web.Application() + 
app["media_ports"] = args.media_ports.split(",") if args.media_ports else None + app["workspace"] = args.workspace + + app.on_startup.append(on_startup) + app.on_shutdown.append(on_shutdown) + + app.router.add_get("/", health) + app.router.add_get("/health", health) + + # WebRTC signalling and control routes. + app.router.add_post("/offer", offer) + app.router.add_post("/prompt", set_prompt) + + # Add routes for getting stream statistics. + stream_stats_manager = StreamStatsManager(app) + app.router.add_get( + "/streams/stats", stream_stats_manager.collect_all_stream_metrics + ) + app.router.add_get( + "/stream/{stream_id}/stats", stream_stats_manager.collect_stream_metrics_by_id + ) + + # Add Prometheus metrics endpoint. + app["metrics_manager"] = MetricsManager(include_stream_id=args.stream_id_label) + if args.monitor: + app["metrics_manager"].enable() + logger.info( + f"Monitoring enabled - Prometheus metrics available at: " + f"http://{args.host}:{args.port}/metrics" + ) + app.router.add_get("/metrics", app["metrics_manager"].metrics_handler) + + # Add hosted platform route prefix. + # NOTE: This ensures that the local and hosted experiences have consistent routes. 
+ add_prefix_to_app_routes(app, "/live") + + def force_print(*args, **kwargs): + print(*args, **kwargs, flush=True) + sys.stdout.flush() + + web.run_app(app, host=args.host, port=int(args.port), print=force_print) diff --git a/server/pipeline_api.py b/server/pipeline_api.py new file mode 100644 index 00000000..0f04e773 --- /dev/null +++ b/server/pipeline_api.py @@ -0,0 +1,195 @@ +import av +import torch +import numpy as np +import asyncio +import logging +import time +from PIL import Image +from io import BytesIO + +from typing import Any, Dict, Union, List +from comfystream.client_api import ComfyStreamClient +from comfystream import tensor_cache + +WARMUP_RUNS = 5 +logger = logging.getLogger(__name__) + + +class Pipeline: + def __init__(self, **kwargs): + self.client = ComfyStreamClient(**kwargs) + self.video_incoming_frames = asyncio.Queue() + self.audio_incoming_frames = asyncio.Queue() + + self.processed_audio_buffer = np.array([], dtype=np.int16) + + async def warm_video(self): + """Warm up the video pipeline with dummy frames""" + logger.info("Warming up video pipeline...") + + # Create a properly formatted dummy frame (random color pattern) + # Using standard tensor shape: BCHW [1, 3, 512, 512] + tensor = torch.rand(1, 3, 512, 512) # Random values in [0,1] + + # Create a dummy frame and attach the tensor as side_data + dummy_frame = av.VideoFrame(width=512, height=512, format="rgb24") + dummy_frame.side_data.input = tensor + + # Process a few frames for warmup + for i in range(WARMUP_RUNS): + logger.info(f"Video warmup iteration {i+1}/{WARMUP_RUNS}") + self.client.put_video_input(dummy_frame) + await self.client.get_video_output() + + logger.info("Video pipeline warmup complete") + + async def warm_audio(self): + dummy_frame = av.AudioFrame() + dummy_frame.side_data.input = np.random.randint(-32768, 32767, int(48000 * 0.5), dtype=np.int16) # TODO: adds a lot of delay if it doesn't match the buffer size, is warmup needed? 
+ dummy_frame.sample_rate = 48000 + + for _ in range(WARMUP_RUNS): + self.client.put_audio_input(dummy_frame) + await self.client.get_audio_output() + + async def set_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]): + if isinstance(prompts, list): + await self.client.set_prompts(prompts) + else: + await self.client.set_prompts([prompts]) + + async def update_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]): + if isinstance(prompts, list): + await self.client.update_prompts(prompts) + else: + await self.client.update_prompts([prompts]) + + async def put_video_frame(self, frame: av.VideoFrame): + frame.side_data.input = self.video_preprocess(frame) + frame.side_data.skipped = False # Different from LoadTensor, we don't skip frames here + self.client.put_video_input(frame) + await self.video_incoming_frames.put(frame) + + async def put_audio_frame(self, frame: av.AudioFrame): + frame.side_data.input = self.audio_preprocess(frame) + frame.side_data.skipped = False + self.client.put_audio_input(frame) + await self.audio_incoming_frames.put(frame) + + def video_preprocess(self, frame: av.VideoFrame) -> Union[torch.Tensor, np.ndarray]: + """Convert input video frame to tensor in consistent BCHW format""" + try: + frame_np = frame.to_ndarray(format="rgb24") + frame_np = frame_np.astype(np.float32) / 255.0 + tensor = torch.from_numpy(frame_np) + + # TODO: Necessary? 
+ if len(tensor.shape) == 3 and tensor.shape[2] == 3: # HWC format + tensor = tensor.permute(2, 0, 1).unsqueeze(0) # -> BCHW + + # Ensure values are in range [0,1] + if tensor.min() < 0 or tensor.max() > 1: + logger.warning(f"Clamping preprocessing tensor: min={tensor.min().item()}, max={tensor.max().item()}") + tensor = torch.clamp(tensor, 0, 1) + + return tensor + + except Exception as e: + logger.error(f"Error in video_preprocess: {e}") + # Return a default tensor in case of error + return torch.zeros(1, 3, frame.height, frame.width) + + def audio_preprocess(self, frame: av.AudioFrame) -> Union[torch.Tensor, np.ndarray]: + return frame.to_ndarray().ravel().reshape(-1, 2).mean(axis=1).astype(np.int16) + + def video_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.VideoFrame: + """Convert tensor to VideoFrame format""" + try: + # Ensure output is a tensor + if isinstance(output, np.ndarray): + output = torch.from_numpy(output) + + # Convert from BCHW to HWC format for video frame + if len(output.shape) == 4: # BCHW format + output = output.squeeze(0) # Remove batch dimension + if output.shape[0] == 3: # CHW format + output = output.permute(1, 2, 0) # Convert to HWC + + # Convert to numpy array in correct format for VideoFrame + frame_np = (output * 255.0).clamp(0, 255).to(dtype=torch.uint8).cpu().numpy() + + # Create VideoFrame with RGB format + video_frame = av.VideoFrame.from_ndarray(frame_np, format='rgb24') + + logger.info(f"Created video frame with shape: {frame_np.shape}") + return video_frame + + except Exception as e: + logger.error(f"Error in video_postprocess: {str(e)}") + # Return a black frame as fallback + return av.VideoFrame(width=512, height=512, format='rgb24') + + def audio_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.AudioFrame: + return av.AudioFrame.from_ndarray(np.repeat(output, 2).reshape(1, -1)) + + async def get_processed_video_frame(self): + """Get processed video frame from output queue and match it 
with input frame""" + try: + # Get the frame from the incoming queue first + frame = await self.video_incoming_frames.get() + + while frame.side_data.skipped: + frame = await self.video_incoming_frames.get() + + # Get the processed frame from the output queue + logger.info("Getting video output") + out_tensor = await self.client.get_video_output() + + # If there are more frames in the output queue, drain them to get the most recent + # This helps with synchronization when processing is faster than display + while not tensor_cache.image_outputs.empty(): + try: + newer_tensor = await asyncio.wait_for(self.client.get_video_output(), 0.01) + out_tensor = newer_tensor # Use the most recent frame + logger.info("Using more recent frame from output queue") + except asyncio.TimeoutError: + break + + logger.info(f"Received output tensor with shape: {out_tensor.shape if hasattr(out_tensor, 'shape') else 'unknown'}") + + # Process the output tensor + processed_frame = self.video_postprocess(out_tensor) + processed_frame.pts = frame.pts + processed_frame.time_base = frame.time_base + + return processed_frame + + except Exception as e: + logger.error(f"Error in get_processed_video_frame: {str(e)}") + # Create a black frame as fallback + black_frame = av.VideoFrame(width=512, height=512, format='rgb24') + return black_frame + + async def get_processed_audio_frame(self): + # TODO: make it generic to support purely generative audio cases and also add frame skipping + frame = await self.audio_incoming_frames.get() + if frame.samples > len(self.processed_audio_buffer): + out_tensor = await self.client.get_audio_output() + self.processed_audio_buffer = np.concatenate([self.processed_audio_buffer, out_tensor]) + out_data = self.processed_audio_buffer[:frame.samples] + self.processed_audio_buffer = self.processed_audio_buffer[frame.samples:] + + processed_frame = self.audio_postprocess(out_data) + processed_frame.pts = frame.pts + processed_frame.time_base = frame.time_base + 
processed_frame.sample_rate = frame.sample_rate + + return processed_frame + + async def get_nodes_info(self) -> Dict[str, Any]: + """Get information about all nodes in the current prompt including metadata.""" + nodes_info = await self.client.get_available_nodes() + return nodes_info + + async def cleanup(self): + await self.client.cleanup() \ No newline at end of file diff --git a/src/comfystream/client_api.py b/src/comfystream/client_api.py new file mode 100644 index 00000000..5e23e391 --- /dev/null +++ b/src/comfystream/client_api.py @@ -0,0 +1,836 @@ +import asyncio +import json +import uuid +import websockets +import base64 +import aiohttp +import logging +import torch +import numpy as np +from io import BytesIO +from PIL import Image +from typing import List, Dict, Any, Optional, Union +import random +import time + +from comfystream import tensor_cache +from comfystream.utils_api import convert_prompt + +logger = logging.getLogger(__name__) + +class ComfyStreamClient: + def __init__(self, host: str = "127.0.0.1", port: int = 8198, **kwargs): + """ + Initialize the ComfyStream client to use the ComfyUI API. 
+ + Args: + host: The hostname or IP address of the ComfyUI server + port: The port number of the ComfyUI server + **kwargs: Additional configuration parameters + """ + self.host = host + self.port = port + self.server_address = f"ws://{host}:{port}/ws" + self.api_base_url = f"http://{host}:{port}/api" + self.client_id = kwargs.get('client_id', str(uuid.uuid4())) + self.api_version = kwargs.get('api_version', "1.0.0") + self.ws = None + self.current_prompts = [] + self.running_prompts = {} + self.cleanup_lock = asyncio.Lock() + + # WebSocket connection + self._ws_listener_task = None + self.execution_complete_event = asyncio.Event() + self.execution_started = False + self._prompt_id = None + + # Configure logging + if 'log_level' in kwargs: + logger.setLevel(kwargs['log_level']) + + # Enable debug mode + self.debug = kwargs.get('debug', True) + + logger.info(f"ComfyStreamClient initialized with host: {host}, port: {port}, client_id: {self.client_id}") + + async def set_prompts(self, prompts: List[Dict]): + """Set prompts and run them (compatible with original interface)""" + # Convert prompts (this already randomizes seeds, but we'll enhance it) + self.current_prompts = [convert_prompt(prompt) for prompt in prompts] + + # Create tasks for each prompt + for idx in range(len(self.current_prompts)): + task = asyncio.create_task(self.run_prompt(idx)) + self.running_prompts[idx] = task + + logger.info(f"Set {len(self.current_prompts)} prompts for execution") + + async def update_prompts(self, prompts: List[Dict]): + """Update existing prompts (compatible with original interface)""" + if len(prompts) != len(self.current_prompts): + raise ValueError( + "Number of updated prompts must match the number of currently running prompts." 
+ ) + self.current_prompts = [convert_prompt(prompt) for prompt in prompts] + logger.info(f"Updated {len(self.current_prompts)} prompts") + + async def run_prompt(self, prompt_index: int): + """Run a prompt continuously, processing new frames as they arrive""" + logger.info(f"Running prompt {prompt_index}") + + # Make sure WebSocket is connected + await self._connect_websocket() + + # Always set execution complete at start to allow first frame to be processed + self.execution_complete_event.set() + logger.info("Setting execution_complete_event to TRUE at start") + + try: + while True: + # Wait until we have tensor data available before sending prompt + if tensor_cache.image_inputs.empty(): + await asyncio.sleep(0.01) # Reduced sleep time for faster checking + continue + + # Clear event before sending a new prompt + if self.execution_complete_event.is_set(): + # Reset execution state for next frame + self.execution_complete_event.clear() + logger.info("Setting execution_complete_event to FALSE before executing prompt") + + # Queue the prompt with the current frame + await self._execute_prompt(prompt_index) + + # Wait for execution completion with timeout + try: + logger.info("Waiting for execution to complete (max 10 seconds)...") + await asyncio.wait_for(self.execution_complete_event.wait(), timeout=10.0) + logger.info("Execution complete, ready for next frame") + except asyncio.TimeoutError: + logger.error("Timeout waiting for execution, forcing continuation") + self.execution_complete_event.set() + else: + # If execution is not complete, check again shortly + await asyncio.sleep(0.01) # Short sleep to prevent CPU spinning + + except asyncio.CancelledError: + logger.info(f"Prompt {prompt_index} execution cancelled") + raise + except Exception as e: + logger.error(f"Error in run_prompt: {str(e)}") + raise + + async def _connect_websocket(self): + """Connect to the ComfyUI WebSocket endpoint""" + try: + if self.ws is not None and self.ws.open: + return self.ws + + # 
Close existing connection if any + if self.ws is not None: + try: + await self.ws.close() + except Exception:  # best-effort close; a bare except would also swallow asyncio.CancelledError + pass + self.ws = None + + logger.info(f"Connecting to WebSocket at {self.server_address}?clientId={self.client_id}") + + # Set a reasonable timeout for connection + websocket_timeout = 10.0 # seconds + + try: + # Connect with proper error handling + self.ws = await websockets.connect( + f"{self.server_address}?clientId={self.client_id}", + ping_interval=5, + ping_timeout=10, + close_timeout=5, + max_size=None, # No limit on message size + ssl=None + ) + + logger.info("WebSocket connected successfully") + + # Start the listener task if not already running + if self._ws_listener_task is None or self._ws_listener_task.done(): + self._ws_listener_task = asyncio.create_task(self._ws_listener()) + logger.info("Started WebSocket listener task") + + return self.ws + + except (websockets.exceptions.WebSocketException, ConnectionError, OSError) as e: + logger.error(f"WebSocket connection error: {e}") + self.ws = None + # Signal execution complete to prevent hanging if connection fails + self.execution_complete_event.set() + # Retry after a delay + await asyncio.sleep(1) + return await self._connect_websocket() + + except Exception as e: + logger.error(f"Unexpected error in _connect_websocket: {e}") + self.ws = None + # Signal execution complete to prevent hanging + self.execution_complete_event.set() + return None + + async def _ws_listener(self): + """Listen for WebSocket messages and process them""" + try: + logger.info(f"WebSocket listener started") + while True: + if self.ws is None: + try: + await self._connect_websocket() + except Exception as e: + logger.error(f"Error connecting to WebSocket: {e}") + await asyncio.sleep(1) + continue + + try: + # Receive and process messages + message = await self.ws.recv() + + if isinstance(message, str): + # Process JSON messages + await self._handle_text_message(message) + else: + # Handle binary data - likely image preview or 
tensor data + await self._handle_binary_message(message) + + except websockets.exceptions.ConnectionClosed: + logger.info("WebSocket connection closed") + self.ws = None + await asyncio.sleep(1) + except Exception as e: + logger.error(f"Error in WebSocket listener: {e}") + await asyncio.sleep(1) + + except asyncio.CancelledError: + logger.info("WebSocket listener cancelled") + raise + except Exception as e: + logger.error(f"Unexpected error in WebSocket listener: {e}") + + async def _handle_text_message(self, message: str): + """Process text (JSON) messages from the WebSocket""" + try: + data = json.loads(message) + message_type = data.get("type", "unknown") + + # logger.info(f"Received message type: {message_type}") + + # Handle different message types + if message_type == "status": + pass + ''' + # Status message with comfy_ui's queue information + queue_remaining = data.get("data", {}).get("queue_remaining", 0) + exec_info = data.get("data", {}).get("exec_info", {}) + if queue_remaining == 0 and not exec_info: + logger.info("Queue empty, no active execution") + else: + logger.info(f"Queue status: {queue_remaining} items remaining") + ''' + + elif message_type == "progress": + if "data" in data and "value" in data["data"]: + progress = data["data"]["value"] + max_value = data["data"].get("max", 100) + # Log the progress for debugging + # logger.info(f"Progress: {progress}/{max_value}") + + elif message_type == "execution_start": + self.execution_started = True + if "data" in data and "prompt_id" in data["data"]: + self._prompt_id = data["data"]["prompt_id"] + # logger.info(f"Execution started for prompt {self._prompt_id}") + + elif message_type == "executing": + self.execution_started = True + if "data" in data: + if "prompt_id" in data["data"]: + self._prompt_id = data["data"]["prompt_id"] + if "node" in data["data"]: + node_id = data["data"]["node"] + # ogger.info(f"Executing node: {node_id}") + + elif message_type in ["execution_cached", "execution_error", 
"execution_complete", "execution_interrupted"]: + # logger.info(f"{message_type} message received for prompt {self._prompt_id}") + #self.execution_started = False + + # Always signal completion for these terminal states + # self.execution_complete_event.set() + # logger.info(f"Set execution_complete_event from {message_type}") + pass + + elif message_type == "executed": + # This is sent when a node is completely done + if "data" in data and "node_id" in data["data"]: + node_id = data["data"]["node_id"] + logger.info(f"Node execution complete: {node_id}") + + # Check if this is our SaveTensorAPI node + if "SaveTensorAPI" in str(node_id): + logger.info("SaveTensorAPI node executed, checking for tensor data") + # The binary data should come separately via websocket + + # If we've been running for too long without tensor data, force completion + elif self.execution_started and not self.execution_complete_event.is_set(): + # Check if this was the last node + if data.get("data", {}).get("remaining", 0) == 0: + # logger.info("All nodes executed but no tensor data received, forcing completion") + # self.execution_complete_event.set() + pass + + elif message_type == "executed_node" and "output" in data.get("data", {}): + node_id = data.get("data", {}).get("node_id") + output_data = data.get("data", {}).get("output", {}) + prompt_id = data.get("data", {}).get("prompt_id", "unknown") + + logger.info(f"Node {node_id} executed in prompt {prompt_id}") + + ''' + # Check if this is from ETN_SendImageWebSocket node + if "ui" in output_data and "images" in output_data["ui"]: + images_info = output_data["ui"]["images"] + logger.info(f"Found image output from ETN_SendImageWebSocket in node {node_id}") + + # Images will be received via binary websocket messages after this event + # The binary handler will take care of them + pass + + # Keep existing handling for tensor data + elif "ui" in output_data and "tensor" in output_data["ui"]: + tensor_info = output_data["ui"]["tensor"] + 
async def _handle_binary_message(self, binary_data):
    """Decode a binary WebSocket frame and queue the image it carries.

    The frame is read as an 8-byte header (two 4-byte unsigned integers:
    event type, then image format) followed by the payload. A JPEG or PNG
    payload is decoded, converted to a float32 BCHW tensor in [0, 1] and
    pushed onto ``tensor_cache.image_outputs``.

    ``self.execution_complete_event`` is set as soon as any binary frame
    arrives so the next input frame can be queued while this one is being
    decoded. A payload that is not a decodable JPEG/PNG yields a black
    1x3x512x512 placeholder tensor so readers of the output queue still
    receive a frame.

    Args:
        binary_data: Raw bytes of the WebSocket message.
    """
    try:
        # Signal execution is complete, queue next frame
        self.execution_complete_event.set()

        # The first 8 bytes are header; anything shorter carries no payload.
        if len(binary_data) <= 8:
            logger.warning(f"Binary message too short: {len(binary_data)} bytes")
            return

        # The header bytes observed in the logs (00 00 00 01 00 00 00 01) only
        # decode to the small event/format codes 1/1 when read big-endian.
        event_type = int.from_bytes(binary_data[:4], byteorder='big')
        format_type = int.from_bytes(binary_data[4:8], byteorder='big')
        data = binary_data[8:]

        logger.info(
            f"Binary message header: event_type={event_type}, "
            f"format_type={format_type}, data_size={len(data)} bytes"
        )

        # Sniff the payload by magic bytes rather than trusting the header.
        is_jpeg = data[:2] == b'\xff\xd8'
        is_png = data[:4] == b'\x89\x50\x4e\x47'

        if is_jpeg or is_png:
            image_format = "JPEG" if is_jpeg else "PNG"
            logger.info(f"Detected {image_format} image based on magic bytes")

            try:
                img = Image.open(BytesIO(data))
                logger.info(
                    f"Successfully decoded image: size={img.size}, "
                    f"mode={img.mode}, format={img.format}"
                )

                if img.mode != "RGB":
                    img = img.convert("RGB")
                    logger.info("Converted image to RGB mode")

                # Normalize to [0, 1]; PIL/numpy give HWC layout.
                img_np = np.array(img).astype(np.float32) / 255.0
                tensor = torch.from_numpy(img_np)
                logger.info(f"Initial tensor shape from image: {tensor.shape}")

                # HWC (H, W, 3) -> BCHW (1, 3, H, W)
                if len(tensor.shape) == 3 and tensor.shape[2] == 3:
                    tensor = tensor.permute(2, 0, 1).unsqueeze(0)
                    logger.info(f"Converted to BCHW tensor with shape: {tensor.shape}")

                if torch.isnan(tensor).any() or torch.isinf(tensor).any():
                    logger.warning("Tensor contains NaN or Inf values! Replacing with zeros")
                    tensor = torch.nan_to_num(tensor, nan=0.0, posinf=1.0, neginf=0.0)

                logger.info(
                    f"Final tensor with shape: {tensor.shape}, "
                    f"min={tensor.min().item()}, max={tensor.max().item()}, "
                    f"mean={tensor.mean().item()}"
                )

                # put_nowait so the event loop is never blocked by a consumer.
                tensor_cache.image_outputs.put_nowait(tensor)
                logger.info(
                    f"Added tensor to output queue, queue size: "
                    f"{tensor_cache.image_outputs.qsize()}"
                )
                return

            except Exception as img_error:
                # Fall through to the placeholder frame below.
                logger.error(f"Error processing image: {img_error}", exc_info=True)

        # Not an image, or decoding failed: emit a black placeholder frame.
        logger.warning("Failed to process image, creating default tensor")
        tensor_cache.image_outputs.put_nowait(torch.zeros(1, 3, 512, 512))
        self.execution_complete_event.set()

    except Exception as e:
        logger.error(f"Error handling binary message: {e}", exc_info=True)
        # Set execution complete event to avoid hanging
        self.execution_complete_event.set()
async def _execute_prompt(self, prompt_index: int):
    """Queue ``self.current_prompts[prompt_index]`` on the ComfyUI server.

    Steps:
      1. Randomize seeds and stamp image-loader nodes so every submission
         differs from the previous one (cache busting).
      2. Take the most recent frame from ``tensor_cache.image_inputs``,
         encode it as a base64 PNG and inject it into the prompt's
         ``ETN_LoadImageBase64`` nodes (``LoadImageBase64`` as fallback).
      3. POST the prompt to ``{api_base_url}/prompt``.

    The stored prompt dict is modified in place. On every path that does not
    successfully queue a prompt, ``self.execution_complete_event`` is set so
    the caller waiting on it does not hang.

    Args:
        prompt_index: Index into ``self.current_prompts``.
    """
    try:
        prompt = self.current_prompts[prompt_index]
        self._bust_prompt_cache(prompt)

        if tensor_cache.image_inputs.empty():
            logger.info("No tensor in input queue, skipping prompt execution")
            self.execution_complete_event.set()
            return

        logger.info("Found tensor in input queue, preparing for API")
        # Drain the queue: only the most recent frame is worth processing.
        frame_or_tensor = None
        while not tensor_cache.image_inputs.empty():
            frame_or_tensor = tensor_cache.image_inputs.get_nowait()

        load_image_nodes = self._find_load_image_nodes(prompt)
        if not load_image_nodes:
            logger.warning("No ETN_LoadImageBase64 or LoadImageBase64 nodes found in the prompt")
            self.execution_complete_event.set()
            return

        try:
            tensor = self._input_to_tensor(frame_or_tensor)
            if tensor is None:
                logger.error("Failed to get valid tensor data from input")
                self.execution_complete_event.set()
                return

            img_base64 = self._tensor_to_png_base64(tensor)
            logger.info(f"Created base64 string of length: {len(img_base64)}")

            for node_id in load_image_nodes:
                node_inputs = prompt[node_id]["inputs"]
                node_inputs["image"] = img_base64
                # Extra random input so identical frames still differ per run.
                rand_suffix = str(random.randint(1, 1000000))
                node_inputs["_cache_buster"] = rand_suffix
                logger.info(
                    f"Updated node {node_id} with base64 string and cache buster {rand_suffix}"
                )
        except Exception as e:
            logger.error(f"Error converting tensor to base64: {e}", exc_info=True)
            self.execution_complete_event.set()
            return

        async with aiohttp.ClientSession() as session:
            api_url = f"{self.api_base_url}/prompt"
            payload = {"prompt": prompt, "client_id": self.client_id}

            logger.info(f"Sending prompt to {api_url}")
            async with session.post(api_url, json=payload) as response:
                if response.status == 200:
                    result = await response.json()
                    self._prompt_id = result.get("prompt_id")
                    logger.info(f"Prompt queued with ID: {self._prompt_id}")
                    self.execution_started = True
                else:
                    error_text = await response.text()
                    logger.error(f"Error queueing prompt: {response.status} - {error_text}")
                    self.execution_complete_event.set()

    except aiohttp.ClientError as e:
        logger.error(f"Client error queueing prompt: {e}")
        self.execution_complete_event.set()
    except Exception as e:
        logger.error(f"Error executing prompt: {e}", exc_info=True)
        self.execution_complete_event.set()

def _bust_prompt_cache(self, prompt):
    """Randomize seed inputs and timestamp image-loader nodes, in place."""
    timestamp = int(time.time() * 1000)  # millisecond timestamp
    for node_id, node in prompt.items():
        if not isinstance(node, dict) or "inputs" not in node:
            continue
        node_inputs = node["inputs"]

        for seed_key in ("seed", "noise_seed"):
            if seed_key in node_inputs:
                new_seed = random.randint(0, 18446744073709551615)
                node_inputs[seed_key] = new_seed
                logger.info(f"Randomized {seed_key} to {new_seed} for node {node_id}")

        if node.get("class_type") in ("ETN_LoadImageBase64", "LoadImageBase64"):
            node_inputs["_timestamp"] = timestamp
            logger.info(f"Added timestamp {timestamp} to node {node_id}")

def _find_load_image_nodes(self, prompt):
    """Return ids of ETN_LoadImageBase64 nodes, else LoadImageBase64 nodes, else []."""
    for class_type in ("ETN_LoadImageBase64", "LoadImageBase64"):
        node_ids = [
            node_id
            for node_id, node in prompt.items()
            if isinstance(node, dict) and node.get("class_type") == class_type
        ]
        if node_ids:
            return node_ids
    return []

def _input_to_tensor(self, frame_or_tensor):
    """Turn a queued input (frame, tensor or ndarray) into a torch tensor, or None."""
    if hasattr(frame_or_tensor, 'side_data') and hasattr(frame_or_tensor.side_data, 'input'):
        tensor = frame_or_tensor.side_data.input
        if isinstance(tensor, np.ndarray):
            # Later steps call tensor methods, so normalize ndarrays here.
            tensor = torch.from_numpy(tensor)
        logger.info(f"Using preprocessed tensor from frame.side_data.input with shape {tensor.shape}")
        return tensor

    if isinstance(frame_or_tensor, torch.Tensor):
        logger.info(f"Using tensor directly with shape {frame_or_tensor.shape}")
        return frame_or_tensor

    if isinstance(frame_or_tensor, np.ndarray):
        tensor = torch.from_numpy(frame_or_tensor).float()
        logger.info(f"Converted numpy array to tensor with shape {tensor.shape}")
        return tensor

    if hasattr(frame_or_tensor, 'to_ndarray'):
        # Frame without a preprocessed tensor: decode it to RGB in [0, 1].
        frame_np = frame_or_tensor.to_ndarray(format="rgb24").astype(np.float32) / 255.0
        tensor = torch.from_numpy(frame_np).unsqueeze(0)
        logger.info(f"Converted PyAV frame to tensor with shape {tensor.shape}")
        return tensor

    logger.error(f"Unsupported frame type: {type(frame_or_tensor)}")
    return None

def _tensor_to_png_base64(self, tensor):
    """Encode an image tensor with values in [0, 1] as a base64 PNG string."""
    tensor = tensor.detach().cpu().float()

    if len(tensor.shape) == 4:
        tensor = tensor[0]  # first image of the batch

    # A trailing dimension of 3 is treated as HWC; anything else as CHW.
    if len(tensor.shape) == 3 and tensor.shape[2] == 3:
        hwc = tensor
    else:
        hwc = tensor.permute(1, 2, 0)

    pixels = (hwc * 255).clamp(0, 255).numpy().astype(np.uint8)
    buffer = BytesIO()
    Image.fromarray(pixels).save(buffer, format="PNG")
    # The loader nodes take the raw base64 string (no data-URI prefix).
    return base64.b64encode(buffer.getvalue()).decode('utf-8')
async def _send_tensor_via_websocket(self, tensor):
    """Encode ``tensor`` as a PNG and send it over the open WebSocket.

    The input (torch tensor or numpy array) is reduced to a single CHW image.
    Inputs that are not 3-channel, or not 3-D after removing the batch axis,
    are replaced by a black 3x512x512 image. Images smaller than 64 pixels on
    either side are resized to 512x512, and non-finite values are sanitized.

    The message is an 8-byte little-endian header (message type 3, node id 0)
    followed by the PNG bytes.
    NOTE(review): the header layout follows the original author's comment
    about what the receiving node expects - confirm against that node.

    ``self.execution_complete_event`` is set when there is no socket or when
    sending fails, so the caller does not hang.
    """
    try:
        if self.ws is None:
            logger.error("WebSocket not connected, cannot send tensor")
            self.execution_complete_event.set()  # Prevent hanging
            return

        if isinstance(tensor, np.ndarray):
            tensor = torch.from_numpy(tensor).float()
        tensor = tensor.detach().cpu().float()

        # Collapse a batch to its first image.
        if tensor.dim() == 4:
            batch_size = tensor.shape[0]
            if batch_size > 1:
                logger.info(f"Taking first image from batch of {batch_size}")
            tensor = tensor[0]

        # Normalize to CHW with 3 channels, or fall back to a black image.
        if tensor.dim() != 3:
            logger.warning(f"Tensor has unexpected shape: {tensor.shape}. Creating standard tensor.")
            tensor = torch.zeros(3, 512, 512)
        elif tensor.shape[0] != 3:
            if tensor.shape[2] == 3:  # HWC format
                tensor = tensor.permute(2, 0, 1)
            else:
                logger.warning(f"Tensor doesn't have 3 channels: {tensor.shape}. Creating standard tensor.")
                tensor = torch.zeros(3, 512, 512)

        logger.info(f"Original tensor for WS: shape={tensor.shape}, min={tensor.min().item():.4f}, max={tensor.max().item():.4f}")

        # Only resize when the image is too small to be useful.
        _, height, width = tensor.shape
        if height < 64 or width < 64:
            logger.warning(f"Tensor dimensions too small: {tensor.shape}. Resizing to 512x512")
            # interpolate() works on batched input, hence the temporary batch axis.
            resized = torch.nn.functional.interpolate(
                tensor.unsqueeze(0), size=(512, 512), mode='bilinear', align_corners=False
            )
            tensor = resized.squeeze(0)
            logger.info(f"Resized tensor to: {tensor.shape}")

        if not torch.isfinite(tensor).all():
            logger.warning("Tensor contains NaN or Inf values! Replacing with zeros.")
            tensor = torch.nan_to_num(tensor, nan=0.0, posinf=1.0, neginf=0.0)

        # CHW float in [0, 1] -> HWC uint8 for PIL.
        pixels = (tensor.permute(1, 2, 0) * 255).clamp(0, 255).numpy().astype(np.uint8)
        img = Image.fromarray(pixels)
        logger.info(f"Converted to PIL image with dimensions: {img.size}")

        png_buffer = BytesIO()
        img.save(png_buffer, format="PNG")
        img_bytes = png_buffer.getvalue()

        # 8-byte header: message type 3, then generic node id 0.
        header = bytearray(
            (3).to_bytes(4, byteorder='little') + (0).to_bytes(4, byteorder='little')
        )
        full_data = header + img_bytes

        await self.ws.send(full_data)
        logger.info(f"Sent tensor as PNG image via websocket with proper header, size: {len(full_data)} bytes, image dimensions: {img.size}")

    except Exception as e:
        logger.error(f"Error sending tensor via websocket: {e}")

        # Signal execution complete in case of error
        self.execution_complete_event.set()
async def cleanup(self):
    """Cancel running work, close the WebSocket and empty the tensor queues.

    Serialized by ``self.cleanup_lock`` so concurrent callers do not tear
    down the same resources twice.
    """
    async with self.cleanup_lock:
        # Cancel all running prompt tasks and wait for each to unwind.
        for task in self.running_prompts.values():
            if task.done():
                continue
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        self.running_prompts.clear()

        # Close WebSocket connection; a failed close is logged, not raised.
        if self.ws:
            try:
                await self.ws.close()
            except Exception as e:
                logger.error(f"Error closing WebSocket: {e}")
            self.ws = None

        # Cancel WebSocket listener task
        if self._ws_listener_task and not self._ws_listener_task.done():
            self._ws_listener_task.cancel()
            try:
                await self._ws_listener_task
            except asyncio.CancelledError:
                pass
        self._ws_listener_task = None

        await self.cleanup_queues()
        logger.info("Client cleanup complete")

async def cleanup_queues(self):
    """Discard everything currently held in the four tensor-cache queues.

    Items are removed with ``get_nowait`` so this coroutine never blocks the
    event loop on an empty queue and never swallows task cancellation.
    """
    import queue  # local import: only needed for the Empty exception type

    def _drain(q):
        # Both queue flavours expose get_nowait(); each raises its own
        # "empty" exception, so accept either.
        while True:
            try:
                q.get_nowait()
            except (queue.Empty, asyncio.QueueEmpty):
                return

    _drain(tensor_cache.image_inputs)
    _drain(tensor_cache.audio_inputs)
    _drain(tensor_cache.image_outputs)
    _drain(tensor_cache.audio_outputs)

    logger.info("Tensor queues cleared")
def put_video_input(self, tensor: Union[torch.Tensor, np.ndarray]):
    """
    Put a video TENSOR into the tensor cache for processing.

    If the input queue is full, the oldest entry is dropped first. Errors are
    logged rather than raised.

    Args:
        tensor: Video frame as a tensor (or numpy array)
    """
    try:
        # Only remove one frame if the queue is full (like in client.py)
        if tensor_cache.image_inputs.full():
            tensor_cache.image_inputs.get_nowait()

        # Detach torch tensors so no autograd graph is kept alive by the queue.
        if isinstance(tensor, torch.Tensor):
            tensor = tensor.detach()

        tensor_cache.image_inputs.put(tensor)

    except Exception as e:
        logger.error(f"Error in put_video_input: {e}")

def put_audio_input(self, frame):
    """Put audio frame into tensor cache"""
    tensor_cache.audio_inputs.put(frame)

async def get_video_output(self):
    """Wait for and return the next processed video frame from the tensor cache."""
    logger.info("Waiting for processed tensor from output queue")
    result = await tensor_cache.image_outputs.get()
    logger.info(f"Got processed tensor from output queue: shape={result.shape if hasattr(result, 'shape') else 'unknown'}")
    return result

async def get_audio_output(self):
    """Wait for and return the next processed audio frame from the tensor cache."""
    return await tensor_cache.audio_outputs.get()

async def get_available_nodes(self):
    """Describe the nodes of every current prompt.

    Requests ``{api_base_url}/object_info`` and, on HTTP 200, returns a dict
    mapping each prompt index to ``{node_id: {'class_type': ..., 'inputs':
    {name: {'value': ..., 'type': 'unknown'}}}}``. Nodes without a
    ``class_type`` and prompt entries that are not dicts are skipped.

    Returns:
        The mapping described above, or ``{}`` if the request fails or the
        server answers with a non-200 status.
    """
    try:
        async with aiohttp.ClientSession() as session:
            url = f"{self.api_base_url}/object_info"
            async with session.get(url) as response:
                if response.status != 200:
                    logger.error(f"Error getting node info: {response.status}")
                    return {}

                # The payload is parsed (so an invalid response still yields
                # {}) but is not yet used to fill in input types.
                await response.json()

                all_prompts_nodes_info = {}
                for prompt_index, prompt in enumerate(self.current_prompts):
                    nodes_info = {}
                    for node_id, node in prompt.items():
                        # Skip malformed entries instead of failing the whole
                        # listing, as the other prompt walkers in this file do.
                        if not isinstance(node, dict):
                            continue
                        class_type = node.get('class_type')
                        if not class_type:
                            continue
                        nodes_info[node_id] = {
                            'class_type': class_type,
                            'inputs': {
                                input_name: {
                                    'value': input_value,
                                    'type': 'unknown',  # We don't have type information
                                }
                                for input_name, input_value in node.get('inputs', {}).items()
                            },
                        }
                    all_prompts_nodes_info[prompt_index] = nodes_info

                return all_prompts_nodes_info
    except Exception as e:
        logger.error(f"Error getting node info: {str(e)}")
        return {}
def create_load_tensor_node():
    """Return a fresh ``LoadTensorAPI`` node whose tensor data is filled at runtime."""
    return {
        "inputs": {
            "tensor_data": ""  # Empty tensor data that will be filled at runtime
        },
        "class_type": "LoadTensorAPI",
        "_meta": {"title": "Load Tensor (API)"},
    }


def create_save_tensor_node(inputs: Dict[Any, Any]):
    """Create a SaveTensorAPI node with proper input formatting.

    Args:
        inputs: Mapping that may carry an ``"images"`` link of the form
            ``[node_id, output_index]``.

    Returns:
        A node dict. If ``inputs["images"]`` is missing or is not a
        two-element list, the placeholder link ``["", 0]`` is used instead.
    """
    images_input = inputs.get("images")

    # A link must be exactly [node_id, output_index].
    if not isinstance(images_input, list) or len(images_input) != 2:
        logging.warning("Invalid images input format: %s, using default", images_input)
        images_input = ["", 0]  # Default empty value

    return {
        "inputs": {
            "images": images_input,  # Should be [node_id, output_index]
            "format": "png",  # Better default than JPG for quality
            "quality": 95
        },
        "class_type": "SaveTensorAPI",
        "_meta": {"title": "Save Tensor (API)"},
    }


def convert_prompt(prompt):
    """Return a copy of ``prompt`` with every ``seed`` input randomized.

    The caller's prompt is left untouched; the returned dict is a deep copy.
    Entries that are not dicts, or that have no ``"inputs"``, are copied
    through unchanged.

    Args:
        prompt: API-format prompt mapping node ids to node dicts.

    Returns:
        The converted prompt.
    """
    logging.info("Converting prompt: %s", prompt)

    # Work on a copy so repeated conversions never alter the caller's data.
    prompt = copy.deepcopy(prompt)

    for key, node in prompt.items():
        if not isinstance(node, dict) or "inputs" not in node:
            continue

        if "seed" in node["inputs"]:
            # Full unsigned 64-bit range.
            random_seed = random.randint(0, 18446744073709551615)
            node["inputs"]["seed"] = random_seed
            logging.info("Set random seed %s for node %s", random_seed, key)

    return prompt
seed (same range as JavaScript's Math.random() * 18446744073709552000) + random_seed = random.randint(0, 18446744073709551615) + node["inputs"]["seed"] = random_seed + print(f"Set random seed {random_seed} for node {key}") + + num_primary_inputs = 0 + num_inputs = 0 + num_outputs = 0 + + keys = { + "PrimaryInputLoadImage": [], + "LoadImage": [], + "PreviewImage": [], + "SaveImage": [], + "LoadTensor": [], + "SaveTensor": [], + "LoadImageBase64": [], + "LoadTensorAPI": [], + "SaveTensorAPI": [], + } + + for key, node in prompt.items(): + if not isinstance(node, dict): + continue + + class_type = node.get("class_type", "") + + # Track primary input and output nodes + if class_type in ["PrimaryInput", "PrimaryInputImage"]: + num_primary_inputs += 1 + keys["PrimaryInputLoadImage"].append(key) + elif class_type in ["LoadImage", "LoadTensor", "LoadAudioTensor", "LoadImageBase64", "LoadTensorAPI"]: + num_inputs += 1 + if class_type == "LoadImage": + keys["LoadImage"].append(key) + elif class_type == "LoadTensor": + keys["LoadTensor"].append(key) + elif class_type == "LoadImageBase64": + keys["LoadImageBase64"].append(key) + elif class_type == "LoadTensorAPI": + keys["LoadTensorAPI"].append(key) + elif class_type in ["PreviewImage", "SaveImage", "SaveTensor", "SaveAudioTensor", "SendImageWebSocket", "SaveTensorAPI"]: + num_outputs += 1 + if class_type == "PreviewImage": + keys["PreviewImage"].append(key) + elif class_type == "SaveImage": + keys["SaveImage"].append(key) + elif class_type == "SaveTensor": + keys["SaveTensor"].append(key) + elif class_type == "SaveTensorAPI": + keys["SaveTensorAPI"].append(key) + + print(f"Found {num_primary_inputs} primary inputs, {num_inputs} inputs, {num_outputs} outputs") + + # Set up connection for video feeds by replacing LoadImage with LoadImageBase64 + if num_inputs == 0 and num_primary_inputs == 0: + # Add a LoadTensorAPI node + new_key = "999990" + prompt[new_key] = create_load_tensor_node() + keys["LoadTensorAPI"].append(new_key) + 
print("Added LoadTensorAPI node for tensor input") + elif len(keys["LoadTensor"]) > 0 and len(keys["LoadTensorAPI"]) == 0: + # Replace LoadTensor with LoadTensorAPI if found + for key in keys["LoadTensor"]: + prompt[key] = create_load_tensor_node() + keys["LoadTensorAPI"].append(key) + print("Replaced LoadTensor with LoadTensorAPI") + + # Set up connection for output if needed + if num_outputs == 0: + # Find nodes that produce images + image_output_nodes = [] + for key, node in prompt.items(): + if isinstance(node, dict): + class_type = node.get("class_type", "") + # Look for nodes that typically output images + if any(output_type in class_type.lower() for output_type in ["vae", "decode", "img", "image", "upscale", "sample"]): + image_output_nodes.append(key) + # Also check if the node's RETURN_TYPES includes IMAGE + elif "outputs" in node and isinstance(node["outputs"], dict): + for output_name, output_type in node["outputs"].items(): + if "IMAGE" in output_type: + image_output_nodes.append(key) + break + + # If we found image output nodes, connect SaveTensorAPI to them + if image_output_nodes: + for i, node_key in enumerate(image_output_nodes): + new_key = f"999991_{i}" + prompt[new_key] = create_save_tensor_node({"images": [node_key, 0]}) + print(f"Added SaveTensorAPI node connected to {node_key}") + else: + # Try to find the last node in the chain as fallback + last_node = None + max_id = -1 + for key, node in prompt.items(): + if isinstance(node, dict): + try: + node_id = int(key) + if node_id > max_id: + max_id = node_id + last_node = key + except ValueError: + pass + + if last_node: + # Add a SaveTensorAPI node + new_key = "999991" + prompt[new_key] = create_save_tensor_node({"images": [last_node, 0]}) + print(f"Added SaveTensorAPI node connected to {last_node}") + else: + print("Warning: Could not find a suitable node to connect SaveTensorAPI to") + + # Make sure all SaveTensorAPI nodes have proper configuration + for key, node in prompt.items(): + if 
isinstance(node, dict) and node.get("class_type") == "SaveTensorAPI": + # Ensure format is set to PNG for optimal compatibility + if "inputs" in node: + node["inputs"]["format"] = "png" + + # Return the modified prompt + return prompt +''' \ No newline at end of file From b3a95ad30941320ee4ca16d62d472b80eeaf8b5b Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Tue, 25 Mar 2025 14:43:28 -0400 Subject: [PATCH 2/7] Cleanup of pre/post processing of frames. --- server/pipeline_api.py | 74 +++++++++++++----------------------------- 1 file changed, 22 insertions(+), 52 deletions(-) diff --git a/server/pipeline_api.py b/server/pipeline_api.py index 0f04e773..bfaf4909 100644 --- a/server/pipeline_api.py +++ b/server/pipeline_api.py @@ -3,9 +3,6 @@ import numpy as np import asyncio import logging -import time -from PIL import Image -from io import BytesIO from typing import Any, Dict, Union, List from comfystream.client_api import ComfyStreamClient @@ -76,62 +73,35 @@ async def put_audio_frame(self, frame: av.AudioFrame): self.client.put_audio_input(frame) await self.audio_incoming_frames.put(frame) - def video_preprocess(self, frame: av.VideoFrame) -> Union[torch.Tensor, np.ndarray]: - """Convert input video frame to tensor in consistent BCHW format""" - try: - frame_np = frame.to_ndarray(format="rgb24") - frame_np = frame_np.astype(np.float32) / 255.0 - tensor = torch.from_numpy(frame_np) - - # TODO: Necessary? 
- if len(tensor.shape) == 3 and tensor.shape[2] == 3: # HWC format - tensor = tensor.permute(2, 0, 1).unsqueeze(0) # -> BCHW - - # Ensure values are in range [0,1] - if tensor.min() < 0 or tensor.max() > 1: - logger.warning(f"Clamping preprocessing tensor: min={tensor.min().item()}, max={tensor.max().item()}") - tensor = torch.clamp(tensor, 0, 1) - - return tensor - - except Exception as e: - logger.error(f"Error in video_preprocess: {e}") - # Return a default tensor in case of error - return torch.zeros(1, 3, frame.height, frame.width) - def audio_preprocess(self, frame: av.AudioFrame) -> Union[torch.Tensor, np.ndarray]: return frame.to_ndarray().ravel().reshape(-1, 2).mean(axis=1).astype(np.int16) + # Works with ComfyUI Native + def video_preprocess(self, frame: av.VideoFrame) -> Union[torch.Tensor, np.ndarray]: + frame_np = frame.to_ndarray(format="rgb24").astype(np.float32) / 255.0 + return torch.from_numpy(frame_np).unsqueeze(0) + + ''' Converts HWC format (height, width, channels) to VideoFrame. 
def video_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.VideoFrame:
    """Convert a BCHW image ``[1, 3, H, W]`` to an RGB ``av.VideoFrame``.

    Values are assumed to be in the range [0, 1]; they are scaled to
    [0, 255], clamped and cast to uint8.

    Args:
        output: Processed frame as a torch tensor or a numpy array.

    Returns:
        The frame in ``rgb24`` format.
    """
    # The annotation allows ndarrays, but the conversion below uses tensor
    # methods, so normalize the input type first.
    if isinstance(output, np.ndarray):
        output = torch.from_numpy(output)

    hwc = output.squeeze(0).permute(1, 2, 0)  # BCHW -> HWC
    pixels = (hwc * 255.0).clamp(0, 255).to(dtype=torch.uint8).cpu().numpy()
    return av.VideoFrame.from_ndarray(pixels, format='rgb24')

def audio_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.AudioFrame:
    """Build an audio frame by repeating each sample twice and flattening to one row."""
    return av.AudioFrame.from_ndarray(np.repeat(output, 2).reshape(1, -1))
class LoadImageBase64:
    """ComfyUI node that decodes a base64-encoded image string.

    Outputs an IMAGE tensor (RGB, float32 in [0, 1], with a leading batch
    axis) and a MASK taken from the alpha channel, or ``None`` when the
    image has no alpha band.
    """

    RETURN_TYPES = ("IMAGE", "MASK")
    CATEGORY = "external_tooling"
    FUNCTION = "load_image"

    @classmethod
    def INPUT_TYPES(s):
        """Declare the single required single-line string input ``image``."""
        image_spec = ("STRING", {"multiline": False})
        return {"required": {"image": image_spec}}

    def load_image(self, image):
        """Decode ``image`` (base64 text) and return ``(image_tensor, mask_or_None)``."""
        decoded = Image.open(BytesIO(base64.b64decode(image)))

        # The alpha channel, when present, becomes the mask.
        mask = None
        if "A" in decoded.getbands():
            alpha = np.array(decoded.getchannel("A")).astype(np.float32) / 255.0
            mask = torch.from_numpy(alpha)

        rgb = np.array(decoded.convert("RGB")).astype(np.float32) / 255.0
        batched = torch.from_numpy(rgb)[None,]

        return (batched, mask)
+ +import numpy as np +from PIL import Image +from server import PromptServer, BinaryEventTypes + +class SendImageWebsocket: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "images": ("IMAGE",), + "format": (["PNG", "JPEG"], {"default": "PNG"}), + } + } + + RETURN_TYPES = () + FUNCTION = "send_images" + OUTPUT_NODE = True + CATEGORY = "external_tooling" + + def send_images(self, images, format): + results = [] + for tensor in images: + array = 255.0 * tensor.cpu().numpy() + image = Image.fromarray(np.clip(array, 0, 255).astype(np.uint8)) + + server = PromptServer.instance + server.send_sync( + BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, + [format, image, None], + server.client_id, + ) + results.append({ + "source": "websocket", + "content-type": f"image/{format.lower()}", + "type": "output", + }) + + return {"ui": {"images": results}} \ No newline at end of file diff --git a/server/pipeline_api.py b/server/pipeline_api.py index bfaf4909..a6eb1205 100644 --- a/server/pipeline_api.py +++ b/server/pipeline_api.py @@ -3,10 +3,10 @@ import numpy as np import asyncio import logging +import time from typing import Any, Dict, Union, List from comfystream.client_api import ComfyStreamClient -from comfystream import tensor_cache WARMUP_RUNS = 5 logger = logging.getLogger(__name__) @@ -19,6 +19,10 @@ def __init__(self, **kwargs): self.audio_incoming_frames = asyncio.Queue() self.processed_audio_buffer = np.array([], dtype=np.int16) + self.last_frame_time = 0 + + # TODO: Not sure if this is needed - should match to UI selected FPS + self.min_frame_interval = 1/30 # Limit to 30 FPS async def warm_video(self): """Warm up the video pipeline with dummy frames""" @@ -62,6 +66,11 @@ async def update_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any await self.client.update_prompts([prompts]) async def put_video_frame(self, frame: av.VideoFrame): + current_time = time.time() + if current_time - self.last_frame_time < self.min_frame_interval: + return # 
Skip frame if too soon + + self.last_frame_time = current_time frame.side_data.input = self.video_preprocess(frame) frame.side_data.skipped = False # Different from LoadTensor, we don't skip frames here self.client.put_video_input(frame) @@ -78,8 +87,16 @@ def audio_preprocess(self, frame: av.AudioFrame) -> Union[torch.Tensor, np.ndarr # Works with ComfyUI Native def video_preprocess(self, frame: av.VideoFrame) -> Union[torch.Tensor, np.ndarray]: - frame_np = frame.to_ndarray(format="rgb24").astype(np.float32) / 255.0 - return torch.from_numpy(frame_np).unsqueeze(0) + # Convert directly to tensor, avoiding intermediate numpy array when possible + if hasattr(frame, 'to_tensor'): + tensor = frame.to_tensor() + else: + # If direct tensor conversion not available, use numpy + frame_np = frame.to_ndarray(format="rgb24") + tensor = torch.from_numpy(frame_np) + + # Normalize to [0,1] range and add batch dimension + return tensor.float().div(255.0).unsqueeze(0) ''' Converts HWC format (height, width, channels) to VideoFrame. 
def video_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.VideoFrame: @@ -103,31 +120,21 @@ def audio_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.Audio return av.AudioFrame.from_ndarray(np.repeat(output, 2).reshape(1, -1)) async def get_processed_video_frame(self): - """Get processed video frame from output queue and match it with input frame""" try: # Get the frame from the incoming queue first frame = await self.video_incoming_frames.get() - while frame.side_data.skipped: + # Skip frames if we're falling behind + while not self.video_incoming_frames.empty(): + # Get newer frame and mark old one as skipped + frame.side_data.skipped = True frame = await self.video_incoming_frames.get() + logger.info("Skipped older frame to catch up") # Get the processed frame from the output queue - logger.info("Getting video output") out_tensor = await self.client.get_video_output() - # If there are more frames in the output queue, drain them to get the most recent - # This helps with synchronization when processing is faster than display - while not tensor_cache.image_outputs.empty(): - try: - newer_tensor = await asyncio.wait_for(self.client.get_video_output(), 0.01) - out_tensor = newer_tensor # Use the most recent frame - logger.info("Using more recent frame from output queue") - except asyncio.TimeoutError: - break - - logger.info(f"Received output tensor with shape: {out_tensor.shape if hasattr(out_tensor, 'shape') else 'unknown'}") - - # Process the output tensor + # Process only the most recent frame processed_frame = self.video_postprocess(out_tensor) processed_frame.pts = frame.pts processed_frame.time_base = frame.time_base diff --git a/src/comfystream/client_api.py b/src/comfystream/client_api.py index 5e23e391..624de56e 100644 --- a/src/comfystream/client_api.py +++ b/src/comfystream/client_api.py @@ -255,11 +255,11 @@ async def _handle_text_message(self, message: str): self._prompt_id = data["data"]["prompt_id"] if "node" in 
data["data"]: node_id = data["data"]["node"] - # ogger.info(f"Executing node: {node_id}") + # logger.info(f"Executing node: {node_id}") elif message_type in ["execution_cached", "execution_error", "execution_complete", "execution_interrupted"]: # logger.info(f"{message_type} message received for prompt {self._prompt_id}") - #self.execution_started = False + # self.execution_started = False # Always signal completion for these terminal states # self.execution_complete_event.set() @@ -335,214 +335,90 @@ async def _handle_text_message(self, message: str): async def _handle_binary_message(self, binary_data): """Process binary messages from the WebSocket""" try: - # Log binary message information - # logger.info(f"Received binary message of size: {len(binary_data)} bytes") - - # Signal execution is complete, queue next frame - self.execution_complete_event.set() - - # Binary messages in ComfyUI start with a header - # First 8 bytes are used for header information + # Early return if message is too short if len(binary_data) <= 8: - logger.warning(f"Binary message too short: {len(binary_data)} bytes") + self.execution_complete_event.set() return - # Extract header data based on the actual format observed in logs - # Header bytes (hex): 0000000100000001 - this appears to be the format in use + # Extract header data only when needed event_type = int.from_bytes(binary_data[:4], byteorder='little') format_type = int.from_bytes(binary_data[4:8], byteorder='little') data = binary_data[8:] - # Log header details - logger.info(f"Binary message header: event_type={event_type}, format_type={format_type}, data_size={len(data)} bytes") - #logger.info(f"Header bytes (hex): {binary_data[:8].hex()}") - - # Check if this is an image (JPEG starts with 0xFF, 0xD8, PNG starts with 0x89, 0x50) - is_jpeg = data[:2] == b'\xff\xd8' - is_png = data[:4] == b'\x89\x50\x4e\x47' + # Quick check for image format + is_image = data[:2] in [b'\xff\xd8', b'\x89\x50'] + if not is_image: + 
self.execution_complete_event.set() + return - if is_jpeg or is_png: - image_format = "JPEG" if is_jpeg else "PNG" - logger.info(f"Detected {image_format} image based on magic bytes") - - # Create a NEW binary message with the expected header format for the JavaScript client - # The JavaScript expects: [0:4]=1 (PREVIEW_IMAGE), [4:8]=1 (JPEG format) or [4:8]=2 (PNG format) - # This matches exactly what the JS code is looking for: - # const event = dataView.getUint32(0); // event type (1 = PREVIEW_IMAGE) - # const format = dataView.getUint32(4); // format (1 = JPEG, 2 = PNG) - js_event_type = (1).to_bytes(4, byteorder='little') # PREVIEW_IMAGE = 1 - js_format_type = (1 if is_jpeg else 2).to_bytes(4, byteorder='little') - transformed_data = js_event_type + js_format_type + data + # Process image data directly + try: + img = Image.open(BytesIO(data)) + if img.mode != "RGB": + img = img.convert("RGB") + + with torch.no_grad(): + tensor = torch.from_numpy(np.array(img)).float().permute(2, 0, 1).unsqueeze(0) / 255.0 - # Forward to WebSocket client if connected - # if self.ws: - # await self.ws.send(transformed_data) - # logger.info(f"Sent transformed {image_format} image data to WebSocket with correct JS header format") - #else: - # logger.error("WebSocket not connected, cannot forward image to JS client") + # Add to output queue without waiting + tensor_cache.image_outputs.put_nowait(tensor) + self.execution_complete_event.set() - # Process the image for our pipeline - try: - # Decode the image - img = Image.open(BytesIO(data)) - logger.info(f"Successfully decoded image: size={img.size}, mode={img.mode}, format={img.format}") - - # Convert to RGB if not already - if img.mode != "RGB": - img = img.convert("RGB") - logger.info(f"Converted image to RGB mode") - - # Save image to temp folder as a file - # TESTING - ''' - import os - import tempfile - temp_folder = os.path.join(tempfile.gettempdir(), "comfyui_images") - os.makedirs(temp_folder, exist_ok=True) - img_path = 
os.path.join(temp_folder, f"comfyui_image_{time.time()}.png") - img.save(img_path) - logger.info(f"Saved image to {img_path}") - ''' - - # Convert to tensor (normalize to [0,1] range for consistency) - img_np = np.array(img).astype(np.float32) / 255.0 - tensor = torch.from_numpy(img_np) - - # CRITICAL: Ensure dimensions are correctly understood - # The tensor should be in HWC format initially from PIL/numpy - logger.info(f"Initial tensor shape from image: {tensor.shape}") - - # Convert from HWC to BCHW format for consistency with model expectations - if len(tensor.shape) == 3 and tensor.shape[2] == 3: # HWC format (H,W,3) - tensor = tensor.permute(2, 0, 1).unsqueeze(0) # -> BCHW (1,3,H,W) - logger.info(f"Converted to BCHW tensor with shape: {tensor.shape}") - - # Check for NaN or Inf values - if torch.isnan(tensor).any() or torch.isinf(tensor).any(): - logger.warning("Tensor contains NaN or Inf values! Replacing with zeros") - tensor = torch.nan_to_num(tensor, nan=0.0, posinf=1.0, neginf=0.0) - - # Log detailed tensor info for debugging - logger.info(f"Final tensor with shape: {tensor.shape}, " - f"min={tensor.min().item()}, max={tensor.max().item()}, " - f"mean={tensor.mean().item()}") - - # Add to output queue without waiting - tensor_cache.image_outputs.put_nowait(tensor) - logger.info(f"Added tensor to output queue, queue size: {tensor_cache.image_outputs.qsize()}") - return + except Exception as img_error: + logger.error(f"Error processing image: {img_error}") + self.execution_complete_event.set() - except Exception as img_error: - logger.error(f"Error processing image: {img_error}", exc_info=True) - - # If we get here, we couldn't process the image - logger.warning("Failed to process image, creating default tensor") - default_tensor = torch.zeros(1, 3, 512, 512) - tensor_cache.image_outputs.put_nowait(default_tensor) - self.execution_complete_event.set() - except Exception as e: - logger.error(f"Error handling binary message: {e}", exc_info=True) - # Set 
execution complete event to avoid hanging + logger.error(f"Error handling binary message: {e}") self.execution_complete_event.set() async def _execute_prompt(self, prompt_index: int): - """Execute a prompt via the ComfyUI API""" try: # Get the prompt to execute prompt = self.current_prompts[prompt_index] - # Ensure all seed values are randomized for every execution - # This forces ComfyUI to not use cached results - for node_id, node in prompt.items(): - if isinstance(node, dict) and "inputs" in node: - if "seed" in node["inputs"]: - # Generate a truly random seed each time - random_seed = random.randint(0, 18446744073709551615) - node["inputs"]["seed"] = random_seed - logger.info(f"Randomized seed to {random_seed} for node {node_id}") - - # Also randomize noise_seed if present - if "noise_seed" in node["inputs"]: - noise_seed = random.randint(0, 18446744073709551615) - node["inputs"]["noise_seed"] = noise_seed - logger.info(f"Randomized noise_seed to {noise_seed} for node {node_id}") - - # Add a timestamp parameter to each node to prevent caching - # This is a "hidden" trick to force ComfyUI to consider each execution unique - timestamp = int(time.time() * 1000) # millisecond timestamp - for node_id, node in prompt.items(): - if isinstance(node, dict) and "inputs" in node: - # Add a timestamp parameter to ETN_LoadImageBase64 nodes - if node.get("class_type") in ["ETN_LoadImageBase64", "LoadImageBase64"]: - # Add a unique cache-busting parameter - node["inputs"]["_timestamp"] = timestamp - logger.info(f"Added timestamp {timestamp} to node {node_id}") - # Check if we have a frame waiting to be processed if not tensor_cache.image_inputs.empty(): logger.info("Found tensor in input queue, preparing for API") - # Get the frame from the cache - make sure to get the most recent frame + # Get the most recent frame only + frame_or_tensor = None while not tensor_cache.image_inputs.empty(): frame_or_tensor = tensor_cache.image_inputs.get_nowait() - # Find ETN_LoadImageBase64 
nodes + # Find ETN_LoadImageBase64 nodes first load_image_nodes = [] for node_id, node in prompt.items(): - if isinstance(node, dict) and node.get("class_type") == "ETN_LoadImageBase64": + if isinstance(node, dict) and node.get("class_type") in ["ETN_LoadImageBase64", "LoadImageBase64"]: load_image_nodes.append(node_id) if not load_image_nodes: - # Also check for regular LoadImageBase64 nodes as fallback - for node_id, node in prompt.items(): - if isinstance(node, dict) and node.get("class_type") == "LoadImageBase64": - load_image_nodes.append(node_id) - - if not load_image_nodes: - logger.warning("No ETN_LoadImageBase64 or LoadImageBase64 nodes found in the prompt") - self.execution_complete_event.set() # Signal completion + logger.warning("No LoadImageBase64 nodes found in the prompt") + self.execution_complete_event.set() return - else: - # Convert the frame/tensor to base64 and include directly in the prompt - try: - # Get the actual tensor data - handle different input types - tensor = None - - # Check if it's a PyAV VideoFrame with preprocessed tensor in side_data - if hasattr(frame_or_tensor, 'side_data') and hasattr(frame_or_tensor.side_data, 'input'): - tensor = frame_or_tensor.side_data.input - logger.info(f"Using preprocessed tensor from frame.side_data.input with shape {tensor.shape}") - # Check if it's a PyTorch tensor - elif isinstance(frame_or_tensor, torch.Tensor): - tensor = frame_or_tensor - logger.info(f"Using tensor directly with shape {tensor.shape}") - # Check if it's a numpy array - elif isinstance(frame_or_tensor, np.ndarray): - tensor = torch.from_numpy(frame_or_tensor).float() - logger.info(f"Converted numpy array to tensor with shape {tensor.shape}") - else: - # If it's a PyAV frame without preprocessed data, convert it - try: - if hasattr(frame_or_tensor, 'to_ndarray'): - frame_np = frame_or_tensor.to_ndarray(format="rgb24").astype(np.float32) / 255.0 - tensor = torch.from_numpy(frame_np).unsqueeze(0) - logger.info(f"Converted PyAV frame 
to tensor with shape {tensor.shape}") - else: - logger.error(f"Unsupported frame type: {type(frame_or_tensor)}") - self.execution_complete_event.set() - return - except Exception as e: - logger.error(f"Error converting frame to tensor: {e}") - self.execution_complete_event.set() - return - - if tensor is None: - logger.error("Failed to get valid tensor data from input") - self.execution_complete_event.set() - return - - # Now process the tensor (which should be a proper PyTorch tensor) - # Ensure it's a tensor on CPU and detached + + # Process the tensor ONLY if we have nodes to send it to + try: + # Get the actual tensor data - handle different input types + tensor = None + + # Handle different input types efficiently + if hasattr(frame_or_tensor, 'side_data') and hasattr(frame_or_tensor.side_data, 'input'): + tensor = frame_or_tensor.side_data.input + elif isinstance(frame_or_tensor, torch.Tensor): + tensor = frame_or_tensor + elif isinstance(frame_or_tensor, np.ndarray): + tensor = torch.from_numpy(frame_or_tensor).float() + elif hasattr(frame_or_tensor, 'to_ndarray'): + frame_np = frame_or_tensor.to_ndarray(format="rgb24").astype(np.float32) / 255.0 + tensor = torch.from_numpy(frame_np).unsqueeze(0) + + if tensor is None: + logger.error("Failed to get valid tensor data from input") + self.execution_complete_event.set() + return + + # Process tensor format only once + with torch.no_grad(): tensor = tensor.detach().cpu().float() # Handle different formats @@ -553,65 +429,52 @@ async def _execute_prompt(self, prompt_index: int): if len(tensor.shape) == 3 and tensor.shape[2] == 3: # HWC format tensor = tensor.permute(2, 0, 1) # Convert to CHW - # Convert to PIL image for saving + # Convert to PIL image for base64 ONLY ONCE tensor_np = (tensor.permute(1, 2, 0) * 255).clamp(0, 255).numpy().astype(np.uint8) img = Image.fromarray(tensor_np) - # Save as PNG to BytesIO and convert to base64 string + # Convert to base64 ONCE for all nodes buffer = BytesIO() 
img.save(buffer, format="PNG") buffer.seek(0) - - # Encode as base64 - for ETN_LoadImageBase64, we need the raw base64 string img_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8') - logger.info(f"Created base64 string of length: {len(img_base64)}") - - # Update all ETN_LoadImageBase64 nodes with the base64 data - for node_id in load_image_nodes: - prompt[node_id]["inputs"]["image"] = img_base64 - # Add a small random suffix to image data to prevent caching - rand_suffix = str(random.randint(1, 1000000)) - prompt[node_id]["inputs"]["_cache_buster"] = rand_suffix - logger.info(f"Updated node {node_id} with base64 string and cache buster {rand_suffix}") - except Exception as e: - logger.error(f"Error converting tensor to base64: {e}") - # Signal execution complete in case of error - self.execution_complete_event.set() - return + # Update all nodes with the SAME base64 string + timestamp = int(time.time() * 1000) + for node_id in load_image_nodes: + prompt[node_id]["inputs"]["image"] = img_base64 + prompt[node_id]["inputs"]["_timestamp"] = timestamp + # Use timestamp as cache buster instead of random number + prompt[node_id]["inputs"]["_cache_buster"] = str(timestamp) + + except Exception as e: + logger.error(f"Error converting tensor to base64: {e}") + self.execution_complete_event.set() + return + + # Execute the prompt via API + async with aiohttp.ClientSession() as session: + api_url = f"{self.api_base_url}/prompt" + payload = { + "prompt": prompt, + "client_id": self.client_id + } + + async with session.post(api_url, json=payload) as response: + if response.status == 200: + result = await response.json() + self._prompt_id = result.get("prompt_id") + self.execution_started = True + else: + error_text = await response.text() + logger.error(f"Error queueing prompt: {response.status} - {error_text}") + self.execution_complete_event.set() else: logger.info("No tensor in input queue, skipping prompt execution") - self.execution_complete_event.set() # Signal 
completion - return - - # Execute the prompt via API - async with aiohttp.ClientSession() as session: - api_url = f"{self.api_base_url}/prompt" - payload = { - "prompt": prompt, - "client_id": self.client_id - } - - # Send the request - logger.info(f"Sending prompt to {api_url}") - async with session.post(api_url, json=payload) as response: - if response.status == 200: - result = await response.json() - self._prompt_id = result.get("prompt_id") - logger.info(f"Prompt queued with ID: {self._prompt_id}") - self.execution_started = True - else: - error_text = await response.text() - logger.error(f"Error queueing prompt: {response.status} - {error_text}") - # Signal execution complete in case of error - self.execution_complete_event.set() + self.execution_complete_event.set() - except aiohttp.ClientError as e: - logger.error(f"Client error queueing prompt: {e}") - self.execution_complete_event.set() except Exception as e: logger.error(f"Error executing prompt: {e}") - # Signal execution complete in case of error self.execution_complete_event.set() async def _send_tensor_via_websocket(self, tensor): diff --git a/src/comfystream/utils_api.py b/src/comfystream/utils_api.py index 8b93ec7f..5a932ef1 100644 --- a/src/comfystream/utils_api.py +++ b/src/comfystream/utils_api.py @@ -15,7 +15,6 @@ def create_load_tensor_node(): "_meta": {"title": "Load Tensor (API)"}, } - def create_save_tensor_node(inputs: Dict[Any, Any]): """Create a SaveTensorAPI node with proper input formatting""" # Make sure images input is properly formatted [node_id, output_index] @@ -51,79 +50,29 @@ def convert_prompt(prompt): random_seed = random.randint(0, 18446744073709551615) node["inputs"]["seed"] = random_seed print(f"Set random seed {random_seed} for node {key}") + + + # Replace LoadImage with LoadImageBase64 + for key, node in prompt.items(): + if node.get("class_type") == "LoadImage": + node["class_type"] = "LoadImageBase64" + + # Replace SaveImage/PreviewImage with SendImageWebsocket + for 
key, node in prompt.items(): + if node.get("class_type") in ["SaveImage", "PreviewImage"]: + node["class_type"] = "SendImageWebsocket" + # Ensure format is set + if "format" not in node["inputs"]: + node["inputs"]["format"] = "PNG" # Set default format return prompt ''' -def convert_prompt(prompt): +def convert_prompt(prompt: PromptDictInput) -> Prompt: + # Validate the schema + Prompt.validate(prompt) - # Check if this is a ComfyUI web UI format prompt with 'nodes' and 'links' - if isinstance(prompt, dict) and 'nodes' in prompt and 'links' in prompt: - # Convert the web UI prompt format to the API format - api_prompt = {} - - # Process each node - for node in prompt['nodes']: - node_id = str(node['id']) - - # Create a node entry in the API format - api_prompt[node_id] = { - 'class_type': node.get('type', 'Unknown'), - 'inputs': {}, - '_meta': { - 'title': node.get('type', 'Unknown') - } - } - - # Process inputs - if 'inputs' in node: - for input_data in node['inputs']: - input_name = input_data.get('name') - link_id = input_data.get('link') - - if input_name and link_id is not None: - # Find the source of this link - for link in prompt['links']: - if link[0] == link_id: # link ID matches - # Get source node and output slot - source_node_id = str(link[1]) - source_slot = link[3] - - # Add to inputs - api_prompt[node_id]['inputs'][input_name] = [ - source_node_id, - source_slot - ] - break - # If no link found, set to empty value - if input_name not in api_prompt[node_id]['inputs']: - api_prompt[node_id]['inputs'][input_name] = None - - # Process widget values - if 'widgets_values' in node: - for i, widget_value in enumerate(node.get('widgets_values', [])): - # Try to determine widget name from properties or use index - widget_name = f"widget_{i}" - # Add to inputs - api_prompt[node_id]['inputs'][widget_name] = widget_value - - # Use this converted prompt instead - prompt = api_prompt - - # Now process as normal API format prompt prompt = copy.deepcopy(prompt) - - # 
Set random seeds for any seed nodes - for key, node in prompt.items(): - if not isinstance(node, dict) or "inputs" not in node: - continue - - # Check if this node has a seed input directly - if "seed" in node.get("inputs", {}): - # Generate a random seed (same range as JavaScript's Math.random() * 18446744073709552000) - random_seed = random.randint(0, 18446744073709551615) - node["inputs"]["seed"] = random_seed - print(f"Set random seed {random_seed} for node {key}") num_primary_inputs = 0 num_inputs = 0 @@ -134,112 +83,54 @@ def convert_prompt(prompt): "LoadImage": [], "PreviewImage": [], "SaveImage": [], - "LoadTensor": [], - "SaveTensor": [], - "LoadImageBase64": [], - "LoadTensorAPI": [], - "SaveTensorAPI": [], } for key, node in prompt.items(): - if not isinstance(node, dict): - continue + class_type = node.get("class_type") + + # Collect keys for nodes that might need to be replaced + if class_type in keys: + keys[class_type].append(key) - class_type = node.get("class_type", "") - - # Track primary input and output nodes - if class_type in ["PrimaryInput", "PrimaryInputImage"]: + # Count inputs and outputs + if class_type == "PrimaryInputLoadImage": num_primary_inputs += 1 - keys["PrimaryInputLoadImage"].append(key) - elif class_type in ["LoadImage", "LoadTensor", "LoadAudioTensor", "LoadImageBase64", "LoadTensorAPI"]: + elif class_type in ["LoadImage", "LoadTensor", "LoadAudioTensor"]: num_inputs += 1 - if class_type == "LoadImage": - keys["LoadImage"].append(key) - elif class_type == "LoadTensor": - keys["LoadTensor"].append(key) - elif class_type == "LoadImageBase64": - keys["LoadImageBase64"].append(key) - elif class_type == "LoadTensorAPI": - keys["LoadTensorAPI"].append(key) - elif class_type in ["PreviewImage", "SaveImage", "SaveTensor", "SaveAudioTensor", "SendImageWebSocket", "SaveTensorAPI"]: + elif class_type in ["PreviewImage", "SaveImage", "SaveTensor", "SaveAudioTensor"]: num_outputs += 1 - if class_type == "PreviewImage": - 
keys["PreviewImage"].append(key) - elif class_type == "SaveImage": - keys["SaveImage"].append(key) - elif class_type == "SaveTensor": - keys["SaveTensor"].append(key) - elif class_type == "SaveTensorAPI": - keys["SaveTensorAPI"].append(key) - - print(f"Found {num_primary_inputs} primary inputs, {num_inputs} inputs, {num_outputs} outputs") - - # Set up connection for video feeds by replacing LoadImage with LoadImageBase64 - if num_inputs == 0 and num_primary_inputs == 0: - # Add a LoadTensorAPI node - new_key = "999990" - prompt[new_key] = create_load_tensor_node() - keys["LoadTensorAPI"].append(new_key) - print("Added LoadTensorAPI node for tensor input") - elif len(keys["LoadTensor"]) > 0 and len(keys["LoadTensorAPI"]) == 0: - # Replace LoadTensor with LoadTensorAPI if found - for key in keys["LoadTensor"]: - prompt[key] = create_load_tensor_node() - keys["LoadTensorAPI"].append(key) - print("Replaced LoadTensor with LoadTensorAPI") - - # Set up connection for output if needed + + # Only handle single primary input + if num_primary_inputs > 1: + raise Exception("too many primary inputs in prompt") + + # If there are no primary inputs, only handle single input + if num_primary_inputs == 0 and num_inputs > 1: + raise Exception("too many inputs in prompt") + + # Only handle single output for now + if num_outputs > 1: + raise Exception("too many outputs in prompt") + + if num_primary_inputs + num_inputs == 0: + raise Exception("missing input") + if num_outputs == 0: - # Find nodes that produce images - image_output_nodes = [] - for key, node in prompt.items(): - if isinstance(node, dict): - class_type = node.get("class_type", "") - # Look for nodes that typically output images - if any(output_type in class_type.lower() for output_type in ["vae", "decode", "img", "image", "upscale", "sample"]): - image_output_nodes.append(key) - # Also check if the node's RETURN_TYPES includes IMAGE - elif "outputs" in node and isinstance(node["outputs"], dict): - for output_name, 
output_type in node["outputs"].items(): - if "IMAGE" in output_type: - image_output_nodes.append(key) - break - - # If we found image output nodes, connect SaveTensorAPI to them - if image_output_nodes: - for i, node_key in enumerate(image_output_nodes): - new_key = f"999991_{i}" - prompt[new_key] = create_save_tensor_node({"images": [node_key, 0]}) - print(f"Added SaveTensorAPI node connected to {node_key}") - else: - # Try to find the last node in the chain as fallback - last_node = None - max_id = -1 - for key, node in prompt.items(): - if isinstance(node, dict): - try: - node_id = int(key) - if node_id > max_id: - max_id = node_id - last_node = key - except ValueError: - pass - - if last_node: - # Add a SaveTensorAPI node - new_key = "999991" - prompt[new_key] = create_save_tensor_node({"images": [last_node, 0]}) - print(f"Added SaveTensorAPI node connected to {last_node}") - else: - print("Warning: Could not find a suitable node to connect SaveTensorAPI to") - - # Make sure all SaveTensorAPI nodes have proper configuration - for key, node in prompt.items(): - if isinstance(node, dict) and node.get("class_type") == "SaveTensorAPI": - # Ensure format is set to PNG for optimal compatibility - if "inputs" in node: - node["inputs"]["format"] = "png" - - # Return the modified prompt + raise Exception("missing output") + + # Replace nodes + for key in keys["PrimaryInputLoadImage"]: + prompt[key] = create_load_tensor_node() + + if num_primary_inputs == 0 and len(keys["LoadImage"]) == 1: + prompt[keys["LoadImage"][0]] = create_load_tensor_node() + + for key in keys["PreviewImage"] + keys["SaveImage"]: + node = prompt[key] + prompt[key] = create_save_tensor_node(node["inputs"]) + + # Validate the processed prompt input + prompt = Prompt.validate(prompt) + return prompt ''' \ No newline at end of file From d66dac4dbbdc94d7ff7f394ee9b6b0aae78aeca0 Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Tue, 25 Mar 2025 15:48:50 -0400 Subject: [PATCH 4/7] Preliminary work on 
multi-Comfy server inference. --- server/app_api.py | 14 ++- server/pipeline_api.py | 246 ++++++++++++++++++++++++++++++++--------- 2 files changed, 207 insertions(+), 53 deletions(-) diff --git a/server/app_api.py b/server/app_api.py index 6158a540..94fc36af 100644 --- a/server/app_api.py +++ b/server/app_api.py @@ -26,7 +26,6 @@ from twilio.rest import Client from utils import patch_loop_datagram, add_prefix_to_app_routes, FPSMeter from metrics import MetricsManager, StreamStatsManager -import time logger = logging.getLogger(__name__) logging.getLogger("aiortc.rtcrtpsender").setLevel(logging.WARNING) @@ -345,7 +344,11 @@ async def on_startup(app: web.Application): patch_loop_datagram(app["media_ports"]) app["pipeline"] = Pipeline( - cwd=app["workspace"], disable_cuda_malloc=True, gpu_only=True, preview_method='none' + config_path=app["config_file"], + cwd=app["workspace"], + disable_cuda_malloc=True, + gpu_only=True, + preview_method='none' ) app["pcs"] = set() app["video_tracks"] = {} @@ -374,6 +377,12 @@ async def on_shutdown(app: web.Application): choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], help="Set the logging level", ) + parser.add_argument( + "--config-file", + type=str, + default=None, + help="Path to TOML configuration file for Comfy servers" + ) parser.add_argument( "--monitor", default=False, @@ -397,6 +406,7 @@ async def on_shutdown(app: web.Application): app = web.Application() app["media_ports"] = args.media_ports.split(",") if args.media_ports else None app["workspace"] = args.workspace + app["config_file"] = args.config_file app.on_startup.append(on_startup) app.on_shutdown.append(on_shutdown) diff --git a/server/pipeline_api.py b/server/pipeline_api.py index a6eb1205..6849c265 100644 --- a/server/pipeline_api.py +++ b/server/pipeline_api.py @@ -4,88 +4,214 @@ import asyncio import logging import time +import random +from collections import deque -from typing import Any, Dict, Union, List +from typing import Any, Dict, Union, List, 
Optional, Deque from comfystream.client_api import ComfyStreamClient +from config import ComfyConfig WARMUP_RUNS = 5 logger = logging.getLogger(__name__) -class Pipeline: - def __init__(self, **kwargs): - self.client = ComfyStreamClient(**kwargs) +class MultiServerPipeline: + def __init__(self, config_path: Optional[str] = None, **kwargs): + # Load server configurations + self.config = ComfyConfig(config_path) + self.servers = self.config.get_servers() + + # Create client for each server + self.clients = [] + for server_config in self.servers: + client_kwargs = kwargs.copy() + client_kwargs.update(server_config) + self.clients.append(ComfyStreamClient(**client_kwargs)) + + logger.info(f"Initialized {len(self.clients)} ComfyUI clients") + self.video_incoming_frames = asyncio.Queue() self.audio_incoming_frames = asyncio.Queue() - + + # Queue for processed frames from all clients + self.processed_video_frames = asyncio.Queue() + + # Track which client gets each frame (round-robin) + self.current_client_index = 0 + self.client_frame_mapping = {} # Maps frame_id -> client_index + + # Buffer to store frames in order of original pts + self.frame_output_buffer: Deque = deque() + + # Audio processing self.processed_audio_buffer = np.array([], dtype=np.int16) self.last_frame_time = 0 - # TODO: Not sure if this is needed - should match to UI selected FPS + # Frame rate limiting self.min_frame_interval = 1/30 # Limit to 30 FPS + + # Create background task for collecting processed frames + self.running = True + self.collector_task = asyncio.create_task(self._collect_processed_frames()) + + async def _collect_processed_frames(self): + """Background task to collect processed frames from all clients""" + try: + while self.running: + for i, client in enumerate(self.clients): + try: + # Non-blocking check if client has output ready + if hasattr(client, '_prompt_id') and client._prompt_id is not None: + # Get frame without waiting + try: + # Use wait_for with small timeout to avoid 
blocking + out_tensor = await asyncio.wait_for( + client.get_video_output(), + timeout=0.01 + ) + + # Find which original frame this corresponds to + # (using a simple approach here - could be improved) + # In real implementation, need to track which frames went to which client + frame_ids = [frame_id for frame_id, client_idx in + self.client_frame_mapping.items() if client_idx == i] + + if frame_ids: + # Use the oldest frame ID for this client + frame_id = min(frame_ids) + # Store the processed tensor along with original frame ID for ordering + await self.processed_video_frames.put((frame_id, out_tensor)) + # Remove the mapping + self.client_frame_mapping.pop(frame_id, None) + logger.info(f"Collected processed frame from client {i}, frame_id: {frame_id}") + except asyncio.TimeoutError: + # No frame ready yet, continue + pass + except Exception as e: + logger.error(f"Error collecting frame from client {i}: {e}") + + # Small sleep to avoid CPU spinning + await asyncio.sleep(0.01) + except asyncio.CancelledError: + logger.info("Frame collector task cancelled") + except Exception as e: + logger.error(f"Unexpected error in frame collector: {e}") async def warm_video(self): - """Warm up the video pipeline with dummy frames""" + """Warm up the video pipeline with dummy frames for each client""" logger.info("Warming up video pipeline...") - # Create a properly formatted dummy frame (random color pattern) - # Using standard tensor shape: BCHW [1, 3, 512, 512] + # Create a properly formatted dummy frame tensor = torch.rand(1, 3, 512, 512) # Random values in [0,1] - - # Create a dummy frame and attach the tensor as side_data dummy_frame = av.VideoFrame(width=512, height=512, format="rgb24") dummy_frame.side_data.input = tensor - # Process a few frames for warmup - for i in range(WARMUP_RUNS): - logger.info(f"Video warmup iteration {i+1}/{WARMUP_RUNS}") - self.client.put_video_input(dummy_frame) - await self.client.get_video_output() - + # Warm up each client + warmup_tasks = 
[] + for i, client in enumerate(self.clients): + warmup_tasks.append(self._warm_client_video(client, i, dummy_frame)) + + # Wait for all warmup tasks to complete + await asyncio.gather(*warmup_tasks) logger.info("Video pipeline warmup complete") + + async def _warm_client_video(self, client, client_index, dummy_frame): + """Warm up a single client""" + logger.info(f"Warming up client {client_index}") + for i in range(WARMUP_RUNS): + logger.info(f"Client {client_index} warmup iteration {i+1}/{WARMUP_RUNS}") + client.put_video_input(dummy_frame) + try: + await asyncio.wait_for(client.get_video_output(), timeout=5.0) + except asyncio.TimeoutError: + logger.warning(f"Timeout waiting for warmup frame from client {client_index}") + except Exception as e: + logger.error(f"Error warming client {client_index}: {e}") async def warm_audio(self): + # For now, only use the first client for audio + if not self.clients: + logger.warning("No clients available for audio warmup") + return + dummy_frame = av.AudioFrame() - dummy_frame.side_data.input = np.random.randint(-32768, 32767, int(48000 * 0.5), dtype=np.int16) # TODO: adds a lot of delay if it doesn't match the buffer size, is warmup needed? 
+ dummy_frame.side_data.input = np.random.randint(-32768, 32767, int(48000 * 0.5), dtype=np.int16) dummy_frame.sample_rate = 48000 for _ in range(WARMUP_RUNS): - self.client.put_audio_input(dummy_frame) - await self.client.get_audio_output() + self.clients[0].put_audio_input(dummy_frame) + await self.clients[0].get_audio_output() async def set_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]): - if isinstance(prompts, list): - await self.client.set_prompts(prompts) - else: - await self.client.set_prompts([prompts]) + """Set the same prompts for all clients""" + if isinstance(prompts, dict): + prompts = [prompts] + + # Set prompts for each client + tasks = [] + for client in self.clients: + tasks.append(client.set_prompts(prompts)) + + await asyncio.gather(*tasks) + logger.info(f"Set prompts for {len(self.clients)} clients") async def update_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]): - if isinstance(prompts, list): - await self.client.update_prompts(prompts) - else: - await self.client.update_prompts([prompts]) + """Update prompts for all clients""" + if isinstance(prompts, dict): + prompts = [prompts] + + # Update prompts for each client + tasks = [] + for client in self.clients: + tasks.append(client.update_prompts(prompts)) + + await asyncio.gather(*tasks) + logger.info(f"Updated prompts for {len(self.clients)} clients") async def put_video_frame(self, frame: av.VideoFrame): + """Distribute video frames among clients using round-robin""" current_time = time.time() if current_time - self.last_frame_time < self.min_frame_interval: return # Skip frame if too soon self.last_frame_time = current_time + + # Generate a unique frame ID + frame_id = int(time.time() * 1000000) # Microseconds as ID + frame.side_data.frame_id = frame_id + + # Preprocess the frame frame.side_data.input = self.video_preprocess(frame) - frame.side_data.skipped = False # Different from LoadTensor, we don't skip frames here - 
self.client.put_video_input(frame) + frame.side_data.skipped = False + + # Select the next client in round-robin fashion + client_index = self.current_client_index + self.current_client_index = (self.current_client_index + 1) % len(self.clients) + + # Store mapping of which client is processing this frame + self.client_frame_mapping[frame_id] = client_index + + # Send frame to the selected client + self.clients[client_index].put_video_input(frame) + + # Also add to the incoming queue for reference await self.video_incoming_frames.put(frame) + + logger.info(f"Sent frame {frame_id} to client {client_index}") async def put_audio_frame(self, frame: av.AudioFrame): + # For now, only use the first client for audio + if not self.clients: + return + frame.side_data.input = self.audio_preprocess(frame) frame.side_data.skipped = False - self.client.put_audio_input(frame) + self.clients[0].put_audio_input(frame) await self.audio_incoming_frames.put(frame) def audio_preprocess(self, frame: av.AudioFrame) -> Union[torch.Tensor, np.ndarray]: return frame.to_ndarray().ravel().reshape(-1, 2).mean(axis=1).astype(np.int16) - # Works with ComfyUI Native def video_preprocess(self, frame: av.VideoFrame) -> Union[torch.Tensor, np.ndarray]: # Convert directly to tensor, avoiding intermediate numpy array when possible if hasattr(frame, 'to_tensor'): @@ -97,15 +223,7 @@ def video_preprocess(self, frame: av.VideoFrame) -> Union[torch.Tensor, np.ndarr # Normalize to [0,1] range and add batch dimension return tensor.float().div(255.0).unsqueeze(0) - - ''' Converts HWC format (height, width, channels) to VideoFrame. - def video_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.VideoFrame: - return av.VideoFrame.from_ndarray( - (output * 255.0).clamp(0, 255).to(dtype=torch.uint8).squeeze(0).cpu().numpy() - ) - ''' - '''Converts BCHW tensor [1,3,H,W] to VideoFrame. 
Assumes values are in range [0,1].''' def video_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.VideoFrame: return av.VideoFrame.from_ndarray( (output.squeeze(0).permute(1, 2, 0) * 255.0) @@ -121,7 +239,7 @@ def audio_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.Audio async def get_processed_video_frame(self): try: - # Get the frame from the incoming queue first + # Get the frame from the incoming queue first to maintain timing frame = await self.video_incoming_frames.get() # Skip frames if we're falling behind @@ -130,11 +248,11 @@ async def get_processed_video_frame(self): frame.side_data.skipped = True frame = await self.video_incoming_frames.get() logger.info("Skipped older frame to catch up") - - # Get the processed frame from the output queue - out_tensor = await self.client.get_video_output() - # Process only the most recent frame + # Get the processed frame from our output queue + frame_id, out_tensor = await self.processed_video_frames.get() + + # Process the frame processed_frame = self.video_postprocess(out_tensor) processed_frame.pts = frame.pts processed_frame.time_base = frame.time_base @@ -148,10 +266,14 @@ async def get_processed_video_frame(self): return black_frame async def get_processed_audio_frame(self): - # TODO: make it generic to support purely generative audio cases and also add frame skipping + # Only use the first client for audio + if not self.clients: + logger.warning("No clients available for audio processing") + return av.AudioFrame(format='s16', layout='mono', samples=1024) + frame = await self.audio_incoming_frames.get() if frame.samples > len(self.processed_audio_buffer): - out_tensor = await self.client.get_audio_output() + out_tensor = await self.clients[0].get_audio_output() self.processed_audio_buffer = np.concatenate([self.processed_audio_buffer, out_tensor]) out_data = self.processed_audio_buffer[:frame.samples] self.processed_audio_buffer = self.processed_audio_buffer[frame.samples:] @@ 
-164,9 +286,31 @@ async def get_processed_audio_frame(self): return processed_frame async def get_nodes_info(self) -> Dict[str, Any]: - """Get information about all nodes in the current prompt including metadata.""" - nodes_info = await self.client.get_available_nodes() - return nodes_info + """Get information about nodes from the first client""" + if not self.clients: + return {} + return await self.clients[0].get_available_nodes() async def cleanup(self): - await self.client.cleanup() \ No newline at end of file + """Clean up all clients and background tasks""" + self.running = False + + # Cancel collector task + if hasattr(self, 'collector_task') and not self.collector_task.done(): + self.collector_task.cancel() + try: + await self.collector_task + except asyncio.CancelledError: + pass + + # Clean up all clients + cleanup_tasks = [] + for client in self.clients: + cleanup_tasks.append(client.cleanup()) + + await asyncio.gather(*cleanup_tasks) + logger.info("All clients cleaned up") + + +# For backwards compatibility, maintain the original Pipeline name +Pipeline = MultiServerPipeline \ No newline at end of file From d3b0b3c24704d7e146b48de10ca1145cae204171 Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Tue, 25 Mar 2025 17:19:27 -0400 Subject: [PATCH 5/7] Work on frame timing and management, added max_frame_wait argument, added config for server management. 
--- configs/comfy.toml | 13 ++++ requirements.txt | 1 + server/app_api.py | 10 ++++ server/config.py | 45 ++++++++++++++ server/pipeline_api.py | 110 +++++++++++++++++++++++++++++----- src/comfystream/client_api.py | 23 ++++--- 6 files changed, 174 insertions(+), 28 deletions(-) create mode 100644 configs/comfy.toml create mode 100644 server/config.py diff --git a/configs/comfy.toml b/configs/comfy.toml new file mode 100644 index 00000000..5d278541 --- /dev/null +++ b/configs/comfy.toml @@ -0,0 +1,13 @@ +# Configuration for multiple ComfyUI servers + +[[servers]] +host = "127.0.0.1" +port = 8188 +client_id = "client1" + +# Adding more servers: + +# [[servers]] +# host = "127.0.0.1" +# port = 8189 +# client_id = "client2" diff --git a/requirements.txt b/requirements.txt index 4a7e68ad..56fe8b22 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,5 +3,6 @@ comfyui @ git+https://github.com/hiddenswitch/ComfyUI.git@ce3583ad42c024b8f060d0 aiortc aiohttp toml +tomli twilio prometheus_client diff --git a/server/app_api.py b/server/app_api.py index 94fc36af..b7b56e56 100644 --- a/server/app_api.py +++ b/server/app_api.py @@ -345,6 +345,7 @@ async def on_startup(app: web.Application): app["pipeline"] = Pipeline( config_path=app["config_file"], + max_frame_wait_ms=app["max_frame_wait"], cwd=app["workspace"], disable_cuda_malloc=True, gpu_only=True, @@ -353,6 +354,8 @@ async def on_startup(app: web.Application): app["pcs"] = set() app["video_tracks"] = {} + app["max_frame_wait"] = args.max_frame_wait + async def on_shutdown(app: web.Application): pcs = app["pcs"] @@ -395,6 +398,12 @@ async def on_shutdown(app: web.Application): action="store_true", help="Include stream ID as a label in Prometheus metrics.", ) + parser.add_argument( + "--max-frame-wait", + type=int, + default=500, + help="Maximum time to wait for a frame in milliseconds before dropping it" + ) args = parser.parse_args() logging.basicConfig( @@ -407,6 +416,7 @@ async def on_shutdown(app: 
web.Application): app["media_ports"] = args.media_ports.split(",") if args.media_ports else None app["workspace"] = args.workspace app["config_file"] = args.config_file + app["max_frame_wait"] = args.max_frame_wait app.on_startup.append(on_startup) app.on_shutdown.append(on_shutdown) diff --git a/server/config.py b/server/config.py new file mode 100644 index 00000000..7f066643 --- /dev/null +++ b/server/config.py @@ -0,0 +1,45 @@ +import tomli +import logging +from typing import List, Dict, Any, Optional + +logger = logging.getLogger(__name__) + +class ComfyConfig: + def __init__(self, config_path: Optional[str] = None): + self.servers = [] + self.config_path = config_path + if config_path: + self.load_config(config_path) + else: + # Default to single local server if no config provided + self.servers = [{"host": "127.0.0.1", "port": 8188}] + + def load_config(self, config_path: str): + """Load server configuration from TOML file""" + try: + with open(config_path, "rb") as f: + config = tomli.load(f) + + # Extract server configurations + if "servers" in config: + self.servers = config["servers"] + logger.info(f"Loaded {len(self.servers)} server configurations") + else: + logger.warning("No servers defined in config, using default") + self.servers = [{"host": "127.0.0.1", "port": 8198}] + + # Validate each server has required fields + for i, server in enumerate(self.servers): + if "host" not in server or "port" not in server: + logger.warning(f"Server {i} missing host or port, using defaults") + server["host"] = server.get("host", "127.0.0.1") + server["port"] = server.get("port", 8198) + + except Exception as e: + logger.error(f"Error loading config from {config_path}: {e}") + # Fall back to default server + self.servers = [{"host": "127.0.0.1", "port": 8198}] + + def get_servers(self) -> List[Dict[str, Any]]: + """Return list of server configurations""" + return self.servers \ No newline at end of file diff --git a/server/pipeline_api.py b/server/pipeline_api.py 
index 6849c265..d2aedf51 100644 --- a/server/pipeline_api.py +++ b/server/pipeline_api.py @@ -5,7 +5,7 @@ import logging import time import random -from collections import deque +from collections import deque, OrderedDict from typing import Any, Dict, Union, List, Optional, Deque from comfystream.client_api import ComfyStreamClient @@ -16,7 +16,7 @@ class MultiServerPipeline: - def __init__(self, config_path: Optional[str] = None, **kwargs): + def __init__(self, config_path: Optional[str] = None, max_frame_wait_ms: int = 500, **kwargs): # Load server configurations self.config = ComfyConfig(config_path) self.servers = self.config.get_servers() @@ -40,8 +40,10 @@ def __init__(self, config_path: Optional[str] = None, **kwargs): self.current_client_index = 0 self.client_frame_mapping = {} # Maps frame_id -> client_index - # Buffer to store frames in order of original pts - self.frame_output_buffer: Deque = deque() + # Frame ordering and timing + self.max_frame_wait_ms = max_frame_wait_ms # Max time to wait for a frame before dropping + self.next_expected_frame_id = None # Track expected frame ID + self.ordered_frames = OrderedDict() # Buffer for ordering frames (frame_id -> (timestamp, tensor)) # Audio processing self.processed_audio_buffer = np.array([], dtype=np.int16) @@ -71,16 +73,17 @@ async def _collect_processed_frames(self): ) # Find which original frame this corresponds to - # (using a simple approach here - could be improved) - # In real implementation, need to track which frames went to which client frame_ids = [frame_id for frame_id, client_idx in self.client_frame_mapping.items() if client_idx == i] if frame_ids: # Use the oldest frame ID for this client frame_id = min(frame_ids) - # Store the processed tensor along with original frame ID for ordering - await self.processed_video_frames.put((frame_id, out_tensor)) + + # Store frame with timestamp for ordering + timestamp = time.time() + await self._add_frame_to_ordered_buffer(frame_id, timestamp, 
out_tensor) + # Remove the mapping self.client_frame_mapping.pop(frame_id, None) logger.info(f"Collected processed frame from client {i}, frame_id: {frame_id}") @@ -90,6 +93,9 @@ async def _collect_processed_frames(self): except Exception as e: logger.error(f"Error collecting frame from client {i}: {e}") + # Check for frames that have waited too long + await self._check_frame_timeouts() + # Small sleep to avoid CPU spinning await asyncio.sleep(0.01) except asyncio.CancelledError: @@ -97,6 +103,70 @@ async def _collect_processed_frames(self): except Exception as e: logger.error(f"Unexpected error in frame collector: {e}") + async def _add_frame_to_ordered_buffer(self, frame_id, timestamp, tensor): + """Add a processed frame to the ordered buffer""" + self.ordered_frames[frame_id] = (timestamp, tensor) + + # If this is the first frame, set the next expected frame ID + if self.next_expected_frame_id is None: + self.next_expected_frame_id = frame_id + + # Check if we can release any frames now + await self._release_ordered_frames() + + async def _release_ordered_frames(self): + """Process ordered frames and put them in the output queue""" + # If we don't have a next expected frame yet, can't do anything + if self.next_expected_frame_id is None: + return + + # Check if the next expected frame is in our buffer + while self.ordered_frames and self.next_expected_frame_id in self.ordered_frames: + # Get the frame + timestamp, tensor = self.ordered_frames.pop(self.next_expected_frame_id) + + # Put it in the output queue + await self.processed_video_frames.put((self.next_expected_frame_id, tensor)) + logger.info(f"Released frame {self.next_expected_frame_id} to output queue") + + # Update the next expected frame ID to the next sequential ID if possible + # (or the lowest frame ID in our buffer) + if self.ordered_frames: + self.next_expected_frame_id = min(self.ordered_frames.keys()) + else: + # If no more frames, keep the last ID + 1 as next expected + 
self.next_expected_frame_id += 1 + + async def _check_frame_timeouts(self): + """Check for frames that have waited too long and handle them""" + if not self.ordered_frames or self.next_expected_frame_id is None: + return + + current_time = time.time() + + # If the next expected frame has timed out, skip it and move on + if self.next_expected_frame_id in self.ordered_frames: + timestamp, _ = self.ordered_frames[self.next_expected_frame_id] + wait_time_ms = (current_time - timestamp) * 1000 + + if wait_time_ms > self.max_frame_wait_ms: + logger.warning(f"Frame {self.next_expected_frame_id} exceeded max wait time, releasing anyway") + await self._release_ordered_frames() + + # Check if we're missing the next expected frame and it's been too long + elif self.ordered_frames: + # The next frame we're expecting isn't in the buffer + # Check how long we've been waiting since the oldest frame in the buffer + oldest_frame_id = min(self.ordered_frames.keys()) + oldest_timestamp, _ = self.ordered_frames[oldest_frame_id] + wait_time_ms = (current_time - oldest_timestamp) * 1000 + + # If we've waited too long, skip the missing frame(s) + if wait_time_ms > self.max_frame_wait_ms: + logger.warning(f"Missing frame {self.next_expected_frame_id}, skipping to {oldest_frame_id}") + self.next_expected_frame_id = oldest_frame_id + await self._release_ordered_frames() + async def warm_video(self): """Warm up the video pipeline with dummy frames for each client""" logger.info("Warming up video pipeline...") @@ -176,8 +246,13 @@ async def put_video_frame(self, frame: av.VideoFrame): self.last_frame_time = current_time - # Generate a unique frame ID - frame_id = int(time.time() * 1000000) # Microseconds as ID + # Generate a unique frame ID - use sequential IDs for better ordering + if not hasattr(self, 'next_frame_id'): + self.next_frame_id = 1 + + frame_id = self.next_frame_id + self.next_frame_id += 1 + frame.side_data.frame_id = frame_id # Preprocess the frame @@ -195,7 +270,7 @@ async 
def put_video_frame(self, frame: av.VideoFrame): self.clients[client_index].put_video_input(frame) # Also add to the incoming queue for reference - await self.video_incoming_frames.put(frame) + await self.video_incoming_frames.put((frame_id, frame)) logger.info(f"Sent frame {frame_id} to client {client_index}") @@ -239,18 +314,21 @@ def audio_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.Audio async def get_processed_video_frame(self): try: - # Get the frame from the incoming queue first to maintain timing - frame = await self.video_incoming_frames.get() + # Get the original frame from the incoming queue first to maintain timing + frame_id, frame = await self.video_incoming_frames.get() # Skip frames if we're falling behind while not self.video_incoming_frames.empty(): # Get newer frame and mark old one as skipped frame.side_data.skipped = True - frame = await self.video_incoming_frames.get() - logger.info("Skipped older frame to catch up") + frame_id, frame = await self.video_incoming_frames.get() + logger.info(f"Skipped older frame {frame_id} to catch up") # Get the processed frame from our output queue - frame_id, out_tensor = await self.processed_video_frames.get() + processed_frame_id, out_tensor = await self.processed_video_frames.get() + + if processed_frame_id != frame_id: + logger.warning(f"Frame ID mismatch: expected {frame_id}, got {processed_frame_id}") # Process the frame processed_frame = self.video_postprocess(out_tensor) diff --git a/src/comfystream/client_api.py b/src/comfystream/client_api.py index 624de56e..40801205 100644 --- a/src/comfystream/client_api.py +++ b/src/comfystream/client_api.py @@ -84,7 +84,6 @@ async def run_prompt(self, prompt_index: int): # Always set execution complete at start to allow first frame to be processed self.execution_complete_event.set() - logger.info("Setting execution_complete_event to TRUE at start") try: while True: @@ -97,16 +96,15 @@ async def run_prompt(self, prompt_index: int): if 
self.execution_complete_event.is_set(): # Reset execution state for next frame self.execution_complete_event.clear() - logger.info("Setting execution_complete_event to FALSE before executing prompt") # Queue the prompt with the current frame await self._execute_prompt(prompt_index) # Wait for execution completion with timeout try: - logger.info("Waiting for execution to complete (max 10 seconds)...") + # logger.info("Waiting for execution to complete (max 10 seconds)...") await asyncio.wait_for(self.execution_complete_event.wait(), timeout=10.0) - logger.info("Execution complete, ready for next frame") + # logger.info("Execution complete, ready for next frame") except asyncio.TimeoutError: logger.error("Timeout waiting for execution, forcing continuation") self.execution_complete_event.set() @@ -379,7 +377,7 @@ async def _execute_prompt(self, prompt_index: int): # Check if we have a frame waiting to be processed if not tensor_cache.image_inputs.empty(): - logger.info("Found tensor in input queue, preparing for API") + # logger.info("Found tensor in input queue, preparing for API") # Get the most recent frame only frame_or_tensor = None while not tensor_cache.image_inputs.empty(): @@ -514,15 +512,16 @@ async def _send_tensor_via_websocket(self, tensor): # Check tensor dimensions and log detailed info logger.info(f"Original tensor for WS: shape={tensor.shape}, min={tensor.min().item():.4f}, max={tensor.max().item():.4f}") - # CRITICAL FIX: The issue is with the shape - no need to resize if dimensions are fine - if tensor.shape[1] < 64 or tensor.shape[2] < 64: - logger.warning(f"Tensor dimensions too small: {tensor.shape}. 
Resizing to 512x512") + # Always ensure consistent 512x512 dimensions + ''' + if tensor.shape[1] != 512 or tensor.shape[2] != 512: + logger.info(f"Resizing tensor from {tensor.shape} to standard 512x512") import torch.nn.functional as F tensor = tensor.unsqueeze(0) # Add batch dimension for interpolate tensor = F.interpolate(tensor, size=(512, 512), mode='bilinear', align_corners=False) tensor = tensor.squeeze(0) # Remove batch dimension after resize - logger.info(f"Resized tensor to: {tensor.shape}") - + ''' + # Check for NaN or Inf values if torch.isnan(tensor).any() or torch.isinf(tensor).any(): logger.warning("Tensor contains NaN or Inf values! Replacing with zeros.") @@ -648,9 +647,9 @@ def put_audio_input(self, frame): async def get_video_output(self): """Get processed video frame from tensor cache""" - logger.info("Waiting for processed tensor from output queue") + # logger.info("Waiting for processed tensor from output queue") result = await tensor_cache.image_outputs.get() - logger.info(f"Got processed tensor from output queue: shape={result.shape if hasattr(result, 'shape') else 'unknown'}") + # logger.info(f"Got processed tensor from output queue: shape={result.shape if hasattr(result, 'shape') else 'unknown'}") return result async def get_audio_output(self): From 63014e3f4e6a0d6c8684589053297d9b74facc0c Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Tue, 25 Mar 2025 17:37:31 -0400 Subject: [PATCH 6/7] Cleaned up prompt manipulation with custom nodes. 
--- src/comfystream/utils_api.py | 93 +++++++++++++++++++++++++++++++----- 1 file changed, 81 insertions(+), 12 deletions(-) diff --git a/src/comfystream/utils_api.py b/src/comfystream/utils_api.py index 5a932ef1..d207f4e3 100644 --- a/src/comfystream/utils_api.py +++ b/src/comfystream/utils_api.py @@ -15,6 +15,15 @@ def create_load_tensor_node(): "_meta": {"title": "Load Tensor (API)"}, } +def create_load_image_node(): + return { + "inputs": { + "image": "" # Should be "image" not "image_data" to match LoadImageBase64 + }, + "class_type": "LoadImageBase64", + "_meta": {"title": "Load Image Base64 (ComfyStream)"}, + } + def create_save_tensor_node(inputs: Dict[Any, Any]): """Create a SaveTensorAPI node with proper input formatting""" # Make sure images input is properly formatted [node_id, output_index] @@ -35,10 +44,38 @@ def create_save_tensor_node(inputs: Dict[Any, Any]): "_meta": {"title": "Save Tensor (API)"}, } -def convert_prompt(prompt): +def create_save_image_node(inputs: Dict[Any, Any]): + # Get the correct image input reference + images_input = inputs.get("images", inputs.get("image")) + + # If not properly formatted, use default + if not images_input: + images_input = ["", 0] # Default empty value + + return { + "inputs": { + "images": images_input, + "format": "PNG" # Default format + }, + "class_type": "SendImageWebsocket", + "_meta": {"title": "Send Image Websocket (ComfyStream)"}, + } +def convert_prompt(prompt): logging.info("Converting prompt: %s", prompt) + # Initialize counters + num_primary_inputs = 0 + num_inputs = 0 + num_outputs = 0 + + keys = { + "PrimaryInputLoadImage": [], + "LoadImage": [], + "PreviewImage": [], + "SaveImage": [], + } + # Set random seeds for any seed nodes for key, node in prompt.items(): if not isinstance(node, dict) or "inputs" not in node: @@ -50,20 +87,52 @@ def convert_prompt(prompt): random_seed = random.randint(0, 18446744073709551615) node["inputs"]["seed"] = random_seed print(f"Set random seed {random_seed} for 
node {key}") + + for key, node in prompt.items(): + class_type = node.get("class_type") + # Collect keys for nodes that might need to be replaced + if class_type in keys: + keys[class_type].append(key) - # Replace LoadImage with LoadImageBase64 - for key, node in prompt.items(): - if node.get("class_type") == "LoadImage": - node["class_type"] = "LoadImageBase64" + # Count inputs and outputs + if class_type == "PrimaryInputLoadImage": + num_primary_inputs += 1 + elif class_type in ["LoadImage", "LoadImageBase64"]: + num_inputs += 1 + elif class_type in ["PreviewImage", "SaveImage", "SendImageWebsocket"]: + num_outputs += 1 - # Replace SaveImage/PreviewImage with SendImageWebsocket - for key, node in prompt.items(): - if node.get("class_type") in ["SaveImage", "PreviewImage"]: - node["class_type"] = "SendImageWebsocket" - # Ensure format is set - if "format" not in node["inputs"]: - node["inputs"]["format"] = "PNG" # Set default format + # Only handle single primary input + if num_primary_inputs > 1: + raise Exception("too many primary inputs in prompt") + + # If there are no primary inputs, only handle single input + if num_primary_inputs == 0 and num_inputs > 1: + raise Exception("too many inputs in prompt") + + # Only handle single output for now + if num_outputs > 1: + raise Exception("too many outputs in prompt") + + if num_primary_inputs + num_inputs == 0: + raise Exception("missing input") + + if num_outputs == 0: + raise Exception("missing output") + + # Replace nodes with proper implementations + for key in keys["PrimaryInputLoadImage"]: + prompt[key] = create_load_image_node() + + if num_primary_inputs == 0 and len(keys["LoadImage"]) == 1: + prompt[keys["LoadImage"][0]] = create_load_image_node() + + for key in keys["PreviewImage"] + keys["SaveImage"]: + node = prompt[key] + prompt[key] = create_save_image_node(node["inputs"]) + + # TODO: Validate the processed prompt input return prompt From 15b51a8283369080f1fdbd6af33b5ac3e460e1ea Mon Sep 17 00:00:00 2001 
From: BuffMcBigHuge Date: Tue, 25 Mar 2025 18:40:17 -0400 Subject: [PATCH 7/7] Added frame tracking, add frame timing stability, added mismatched frame size handling, commented out some logging. --- server/pipeline_api.py | 51 ++++++++----- src/comfystream/client_api.py | 136 ++++++++++++++++++++++++++-------- 2 files changed, 136 insertions(+), 51 deletions(-) diff --git a/server/pipeline_api.py b/server/pipeline_api.py index d2aedf51..1ff1826a 100644 --- a/server/pipeline_api.py +++ b/server/pipeline_api.py @@ -67,26 +67,36 @@ async def _collect_processed_frames(self): # Get frame without waiting try: # Use wait_for with small timeout to avoid blocking - out_tensor = await asyncio.wait_for( + result = await asyncio.wait_for( client.get_video_output(), timeout=0.01 ) - # Find which original frame this corresponds to - frame_ids = [frame_id for frame_id, client_idx in - self.client_frame_mapping.items() if client_idx == i] - - if frame_ids: - # Use the oldest frame ID for this client - frame_id = min(frame_ids) - - # Store frame with timestamp for ordering - timestamp = time.time() - await self._add_frame_to_ordered_buffer(frame_id, timestamp, out_tensor) + # Check if result is already a tuple with frame_id + if isinstance(result, tuple) and len(result) == 2: + frame_id, out_tensor = result + # logger.info(f"Got result with embedded frame_id: {frame_id}") + else: + out_tensor = result + # Find which original frame this corresponds to using our mapping + frame_ids = [frame_id for frame_id, client_idx in + self.client_frame_mapping.items() if client_idx == i] - # Remove the mapping - self.client_frame_mapping.pop(frame_id, None) - logger.info(f"Collected processed frame from client {i}, frame_id: {frame_id}") + if frame_ids: + # Use the oldest frame ID for this client + frame_id = min(frame_ids) + else: + # If no mapping found, log warning and continue + logger.warning(f"No frame_id mapping found for tensor from client {i}") + continue + + # Store frame with 
timestamp for ordering + timestamp = time.time() + await self._add_frame_to_ordered_buffer(frame_id, timestamp, out_tensor) + + # Remove the mapping + self.client_frame_mapping.pop(frame_id, None) + # logger.info(f"Collected processed frame from client {i}, frame_id: {frame_id}") except asyncio.TimeoutError: # No frame ready yet, continue pass @@ -127,7 +137,7 @@ async def _release_ordered_frames(self): # Put it in the output queue await self.processed_video_frames.put((self.next_expected_frame_id, tensor)) - logger.info(f"Released frame {self.next_expected_frame_id} to output queue") + # logger.info(f"Released frame {self.next_expected_frame_id} to output queue") # Update the next expected frame ID to the next sequential ID if possible # (or the lowest frame ID in our buffer) @@ -163,7 +173,7 @@ async def _check_frame_timeouts(self): # If we've waited too long, skip the missing frame(s) if wait_time_ms > self.max_frame_wait_ms: - logger.warning(f"Missing frame {self.next_expected_frame_id}, skipping to {oldest_frame_id}") + # logger.warning(f"Missing frame {self.next_expected_frame_id}, skipping to {oldest_frame_id}") self.next_expected_frame_id = oldest_frame_id await self._release_ordered_frames() @@ -272,7 +282,7 @@ async def put_video_frame(self, frame: av.VideoFrame): # Also add to the incoming queue for reference await self.video_incoming_frames.put((frame_id, frame)) - logger.info(f"Sent frame {frame_id} to client {client_index}") + # logger.info(f"Sent frame {frame_id} to client {client_index}") async def put_audio_frame(self, frame: av.AudioFrame): # For now, only use the first client for audio @@ -322,13 +332,14 @@ async def get_processed_video_frame(self): # Get newer frame and mark old one as skipped frame.side_data.skipped = True frame_id, frame = await self.video_incoming_frames.get() - logger.info(f"Skipped older frame {frame_id} to catch up") + # logger.info(f"Skipped older frame {frame_id} to catch up") # Get the processed frame from our output 
queue processed_frame_id, out_tensor = await self.processed_video_frames.get() if processed_frame_id != frame_id: - logger.warning(f"Frame ID mismatch: expected {frame_id}, got {processed_frame_id}") + # logger.warning(f"Frame ID mismatch: expected {frame_id}, got {processed_frame_id}") + pass # Process the frame processed_frame = self.video_postprocess(out_tensor) diff --git a/src/comfystream/client_api.py b/src/comfystream/client_api.py index 40801205..f8725aed 100644 --- a/src/comfystream/client_api.py +++ b/src/comfystream/client_api.py @@ -45,6 +45,10 @@ def __init__(self, host: str = "127.0.0.1", port: int = 8198, **kwargs): self.execution_started = False self._prompt_id = None + # Add frame tracking + self._current_frame_id = None # Track the current frame being processed + self._frame_id_mapping = {} # Map prompt_ids to frame_ids + # Configure logging if 'log_level' in kwargs: logger.setLevel(kwargs['log_level']) @@ -358,8 +362,23 @@ async def _handle_binary_message(self, binary_data): with torch.no_grad(): tensor = torch.from_numpy(np.array(img)).float().permute(2, 0, 1).unsqueeze(0) / 255.0 - # Add to output queue without waiting - tensor_cache.image_outputs.put_nowait(tensor) + # Try to get frame_id from mapping using current prompt_id + frame_id = None + if hasattr(self, '_prompt_id') and self._prompt_id in self._frame_id_mapping: + frame_id = self._frame_id_mapping.get(self._prompt_id) + # logger.info(f"Using frame_id {frame_id} from prompt_id {self._prompt_id}") + elif hasattr(self, '_current_frame_id') and self._current_frame_id is not None: + frame_id = self._current_frame_id + # logger.info(f"Using current frame_id {frame_id}") + + # Add to output queue - include frame_id if available + if frame_id is not None: + tensor_cache.image_outputs.put_nowait((frame_id, tensor)) + # logger.info(f"Added tensor with frame_id {frame_id} to output queue") + else: + tensor_cache.image_outputs.put_nowait(tensor) + #logger.info("Added tensor without frame_id to 
output queue") + self.execution_complete_event.set() except Exception as img_error: @@ -377,16 +396,26 @@ async def _execute_prompt(self, prompt_index: int): # Check if we have a frame waiting to be processed if not tensor_cache.image_inputs.empty(): - # logger.info("Found tensor in input queue, preparing for API") # Get the most recent frame only frame_or_tensor = None while not tensor_cache.image_inputs.empty(): frame_or_tensor = tensor_cache.image_inputs.get_nowait() + # Extract frame ID if available in side_data + frame_id = None + if hasattr(frame_or_tensor, 'side_data'): + # Try to get frame_id from side_data + if hasattr(frame_or_tensor.side_data, 'frame_id'): + frame_id = frame_or_tensor.side_data.frame_id + logger.info(f"Found frame_id in side_data: {frame_id}") + + # Store current frame ID for binary message handler to use + self._current_frame_id = frame_id + # Find ETN_LoadImageBase64 nodes first load_image_nodes = [] for node_id, node in prompt.items(): - if isinstance(node, dict) and node.get("class_type") in ["ETN_LoadImageBase64", "LoadImageBase64"]: + if isinstance(node, dict) and node.get("class_type") in ["LoadImageBase64"]: load_image_nodes.append(node_id) if not load_image_nodes: @@ -415,34 +444,64 @@ async def _execute_prompt(self, prompt_index: int): self.execution_complete_event.set() return - # Process tensor format only once + # Process tensor format only once - streamlined for speed and reliability with torch.no_grad(): - tensor = tensor.detach().cpu().float() - - # Handle different formats - if len(tensor.shape) == 4: # BCHW format (batch) - tensor = tensor[0] # Take first image from batch - - # Ensure it's in CHW format - if len(tensor.shape) == 3 and tensor.shape[2] == 3: # HWC format - tensor = tensor.permute(2, 0, 1) # Convert to CHW - - # Convert to PIL image for base64 ONLY ONCE - tensor_np = (tensor.permute(1, 2, 0) * 255).clamp(0, 255).numpy().astype(np.uint8) - img = Image.fromarray(tensor_np) + # Fast tensor normalization to 
ensure consistent output + try: + # TODO: Why is the UI sending different sizes? Should be fixed no? This breaks tensorrt + # I'm sometimes seeing (BCHW): torch.Size([1, 384, 384, 3]), H=384, W=3 + # Ensure minimum size of 512x512 + + # Handle batch dimension if present + if len(tensor.shape) == 4: # BCHW format + tensor = tensor[0] # Take first image from batch + + # Normalize to CHW format consistently + if len(tensor.shape) == 3 and tensor.shape[2] == 3: # HWC format + tensor = tensor.permute(2, 0, 1) # Convert to CHW + + # Handle single-channel case + if len(tensor.shape) == 3 and tensor.shape[0] == 1: + tensor = tensor.repeat(3, 1, 1) # Convert grayscale to RGB + + # Ensure tensor is on CPU + if tensor.is_cuda: + tensor = tensor.cpu() + + # Always resize to 512x512 for consistency (faster than checking dimensions first) + tensor = tensor.unsqueeze(0) # Add batch dim for interpolate + tensor = torch.nn.functional.interpolate( + tensor, size=(512, 512), mode='bilinear', align_corners=False + ) + tensor = tensor[0] # Remove batch dimension + + # Direct conversion to PIL without intermediate numpy step for speed + tensor_np = (tensor.permute(1, 2, 0).clamp(0, 1) * 255).to(torch.uint8).numpy() + img = Image.fromarray(tensor_np) + + # Fast JPEG encoding with balanced quality + buffer = BytesIO() + img.save(buffer, format="JPEG", quality=90, optimize=True) + buffer.seek(0) + img_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8') + + except Exception as e: + logger.warning(f"Error in tensor processing: {e}, creating fallback image") + # Create a standard 512x512 placeholder if anything fails + img = Image.new('RGB', (512, 512), color=(100, 149, 237)) + buffer = BytesIO() + img.save(buffer, format="JPEG", quality=90) + buffer.seek(0) + img_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8') - # Convert to base64 ONCE for all nodes - buffer = BytesIO() - img.save(buffer, format="PNG") - buffer.seek(0) - img_base64 = 
base64.b64encode(buffer.getvalue()).decode('utf-8') + # Add timestamp for cache busting (once, outside the try/except) + timestamp = int(time.time() * 1000) # Update all nodes with the SAME base64 string - timestamp = int(time.time() * 1000) for node_id in load_image_nodes: prompt[node_id]["inputs"]["image"] = img_base64 prompt[node_id]["inputs"]["_timestamp"] = timestamp - # Use timestamp as cache buster instead of random number + # Use timestamp as cache buster prompt[node_id]["inputs"]["_cache_buster"] = str(timestamp) except Exception as e: @@ -462,6 +521,12 @@ async def _execute_prompt(self, prompt_index: int): if response.status == 200: result = await response.json() self._prompt_id = result.get("prompt_id") + + # Map prompt_id to frame_id for later retrieval + if frame_id is not None: + self._frame_id_mapping[self._prompt_id] = frame_id + # logger.info(f"Mapped prompt_id {self._prompt_id} to frame_id {frame_id}") + self.execution_started = True else: error_text = await response.text() @@ -493,7 +558,8 @@ async def _send_tensor_via_websocket(self, tensor): # Prepare binary data if len(tensor.shape) == 4: # BCHW format (batch of images) if tensor.shape[0] > 1: - logger.info(f"Taking first image from batch of {tensor.shape[0]}") + # logger.info(f"Taking first image from batch of {tensor.shape[0]}") + pass tensor = tensor[0] # Take first image if batch # Ensure CHW format (3 channels) @@ -510,7 +576,7 @@ async def _send_tensor_via_websocket(self, tensor): tensor = torch.zeros(3, 512, 512) # Check tensor dimensions and log detailed info - logger.info(f"Original tensor for WS: shape={tensor.shape}, min={tensor.min().item():.4f}, max={tensor.max().item():.4f}") + # logger.info(f"Original tensor for WS: shape={tensor.shape}, min={tensor.min().item():.4f}, max={tensor.max().item():.4f}") # Always ensure consistent 512x512 dimensions ''' @@ -557,7 +623,7 @@ async def _send_tensor_via_websocket(self, tensor): # Send binary data via websocket await 
self.ws.send(full_data) - logger.info(f"Sent tensor as PNG image via websocket with proper header, size: {len(full_data)} bytes, image dimensions: {img.size}") + # logger.info(f"Sent tensor as PNG image via websocket with proper header, size: {len(full_data)} bytes, image dimensions: {img.size}") except Exception as e: logger.error(f"Error sending tensor via websocket: {e}") @@ -647,10 +713,18 @@ def put_audio_input(self, frame): async def get_video_output(self): """Get processed video frame from tensor cache""" - # logger.info("Waiting for processed tensor from output queue") result = await tensor_cache.image_outputs.get() - # logger.info(f"Got processed tensor from output queue: shape={result.shape if hasattr(result, 'shape') else 'unknown'}") - return result + + # Check if the result is a tuple with frame_id + if isinstance(result, tuple) and len(result) == 2: + frame_id, tensor = result + # logger.info(f"Got processed tensor from output queue with frame_id {frame_id}") + # Return both the frame_id and tensor to help with ordering in the pipeline + return frame_id, tensor + else: + # If it's not a tuple with frame_id, just return the tensor + # logger.info("Got processed tensor from output queue without frame_id") + return result async def get_audio_output(self): """Get processed audio frame from tensor cache"""