From dcbdcfec1634ff8af64464cbe39a814ca3fd3911 Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Tue, 18 Mar 2025 16:08:02 -0400 Subject: [PATCH 01/42] Preliminary work for ComfyUI native API integration. --- server/app_api.py | 438 ++++++++++++++++++ server/pipeline_api.py | 195 ++++++++ src/comfystream/client_api.py | 836 ++++++++++++++++++++++++++++++++++ src/comfystream/utils_api.py | 245 ++++++++++ 4 files changed, 1714 insertions(+) create mode 100644 server/app_api.py create mode 100644 server/pipeline_api.py create mode 100644 src/comfystream/client_api.py create mode 100644 src/comfystream/utils_api.py diff --git a/server/app_api.py b/server/app_api.py new file mode 100644 index 00000000..6158a540 --- /dev/null +++ b/server/app_api.py @@ -0,0 +1,438 @@ +import argparse +import asyncio +import json +import logging +import os +import sys + +import torch + +# Initialize CUDA before any other imports to prevent core dump. +if torch.cuda.is_available(): + torch.cuda.init() + + +from aiohttp import web +from aiortc import ( + MediaStreamTrack, + RTCConfiguration, + RTCIceServer, + RTCPeerConnection, + RTCSessionDescription, +) +from aiortc.codecs import h264 +from aiortc.rtcrtpsender import RTCRtpSender +from pipeline_api import Pipeline # TODO: Better integration (Are we replacing pipeline with pipeline_api?) +from twilio.rest import Client +from utils import patch_loop_datagram, add_prefix_to_app_routes, FPSMeter +from metrics import MetricsManager, StreamStatsManager +import time + +logger = logging.getLogger(__name__) +logging.getLogger("aiortc.rtcrtpsender").setLevel(logging.WARNING) +logging.getLogger("aiortc.rtcrtpreceiver").setLevel(logging.WARNING) + + +MAX_BITRATE = 2000000 +MIN_BITRATE = 2000000 + + +class VideoStreamTrack(MediaStreamTrack): + """video stream track that processes video frames using a pipeline. + + Attributes: + kind (str): The kind of media, which is "video" for this class. 
+ track (MediaStreamTrack): The underlying media stream track. + pipeline (Pipeline): The processing pipeline to apply to each video frame. + """ + + kind = "video" + + def __init__(self, track: MediaStreamTrack, pipeline: Pipeline): + """Initialize the VideoStreamTrack. + + Args: + track: The underlying media stream track. + pipeline: The processing pipeline to apply to each video frame. + """ + super().__init__() + self.track = track + self.pipeline = pipeline + self.fps_meter = FPSMeter( + metrics_manager=app["metrics_manager"], track_id=track.id + ) + self.running = True + self.collect_task = asyncio.create_task(self.collect_frames()) + + # Add cleanup when track ends + @track.on("ended") + async def on_ended(): + logger.info("Source video track ended, stopping collection") + await cancel_collect_frames(self) + + async def collect_frames(self): + """Collect video frames from the underlying track and pass them to + the processing pipeline. Stops when track ends or connection closes. + """ + try: + while self.running: + try: + frame = await self.track.recv() + await self.pipeline.put_video_frame(frame) + except asyncio.CancelledError: + logger.info("Frame collection cancelled") + break + except Exception as e: + if "MediaStreamError" in str(type(e)): + logger.info("Media stream ended") + else: + logger.error(f"Error collecting video frames: {str(e)}") + self.running = False + break + + # Perform cleanup outside the exception handler + logger.info("Video frame collection stopped") + except asyncio.CancelledError: + logger.info("Frame collection task cancelled") + except Exception as e: + logger.error(f"Unexpected error in frame collection: {str(e)}") + finally: + await self.pipeline.cleanup() + + async def recv(self): + """Receive a processed video frame from the pipeline, increment the frame + count for FPS calculation and return the processed frame to the client. 
+ """ + processed_frame = await self.pipeline.get_processed_video_frame() + + # Increment the frame count to calculate FPS. + await self.fps_meter.increment_frame_count() + + return processed_frame + + +class AudioStreamTrack(MediaStreamTrack): + kind = "audio" + + def __init__(self, track: MediaStreamTrack, pipeline): + super().__init__() + self.track = track + self.pipeline = pipeline + self.running = True + self.collect_task = asyncio.create_task(self.collect_frames()) + + # Add cleanup when track ends + @track.on("ended") + async def on_ended(): + logger.info("Source audio track ended, stopping collection") + await cancel_collect_frames(self) + + async def collect_frames(self): + """Collect audio frames from the underlying track and pass them to + the processing pipeline. Stops when track ends or connection closes. + """ + try: + while self.running: + try: + frame = await self.track.recv() + await self.pipeline.put_audio_frame(frame) + except asyncio.CancelledError: + logger.info("Audio frame collection cancelled") + break + except Exception as e: + if "MediaStreamError" in str(type(e)): + logger.info("Media stream ended") + else: + logger.error(f"Error collecting audio frames: {str(e)}") + self.running = False + break + + # Perform cleanup outside the exception handler + logger.info("Audio frame collection stopped") + except asyncio.CancelledError: + logger.info("Frame collection task cancelled") + except Exception as e: + logger.error(f"Unexpected error in audio frame collection: {str(e)}") + finally: + await self.pipeline.cleanup() + + async def recv(self): + return await self.pipeline.get_processed_audio_frame() + + +def force_codec(pc, sender, forced_codec): + kind = forced_codec.split("/")[0] + codecs = RTCRtpSender.getCapabilities(kind).codecs + transceiver = next(t for t in pc.getTransceivers() if t.sender == sender) + codecPrefs = [codec for codec in codecs if codec.mimeType == forced_codec] + transceiver.setCodecPreferences(codecPrefs) + + +def 
get_twilio_token(): + account_sid = os.getenv("TWILIO_ACCOUNT_SID") + auth_token = os.getenv("TWILIO_AUTH_TOKEN") + + if account_sid is None or auth_token is None: + return None + + client = Client(account_sid, auth_token) + + token = client.tokens.create() + + return token + + +def get_ice_servers(): + ice_servers = [] + + token = get_twilio_token() + if token is not None: + # Use Twilio TURN servers + for server in token.ice_servers: + if server["url"].startswith("turn:"): + turn = RTCIceServer( + urls=[server["urls"]], + credential=server["credential"], + username=server["username"], + ) + ice_servers.append(turn) + + return ice_servers + + +async def offer(request): + pipeline = request.app["pipeline"] + pcs = request.app["pcs"] + + params = await request.json() + + await pipeline.set_prompts(params["prompts"]) + + offer_params = params["offer"] + offer = RTCSessionDescription(sdp=offer_params["sdp"], type=offer_params["type"]) + + ice_servers = get_ice_servers() + if len(ice_servers) > 0: + pc = RTCPeerConnection( + configuration=RTCConfiguration(iceServers=get_ice_servers()) + ) + else: + pc = RTCPeerConnection() + + pcs.add(pc) + + tracks = {"video": None, "audio": None} + + # Only add video transceiver if video is present in the offer + if "m=video" in offer.sdp: + # Prefer h264 + transceiver = pc.addTransceiver("video") + caps = RTCRtpSender.getCapabilities("video") + prefs = list(filter(lambda x: x.name == "H264", caps.codecs)) + transceiver.setCodecPreferences(prefs) + + # Monkey patch max and min bitrate to ensure constant bitrate + h264.MAX_BITRATE = MAX_BITRATE + h264.MIN_BITRATE = MIN_BITRATE + + # Handle control channel from client + @pc.on("datachannel") + def on_datachannel(channel): + if channel.label == "control": + + @channel.on("message") + async def on_message(message): + try: + params = json.loads(message) + + if params.get("type") == "get_nodes": + nodes_info = await pipeline.get_nodes_info() + response = {"type": "nodes_info", "nodes": 
nodes_info} + channel.send(json.dumps(response)) + elif params.get("type") == "update_prompts": + if "prompts" not in params: + logger.warning( + "[Control] Missing prompt in update_prompt message" + ) + return + await pipeline.update_prompts(params["prompts"]) + response = {"type": "prompts_updated", "success": True} + channel.send(json.dumps(response)) + else: + logger.warning( + "[Server] Invalid message format - missing required fields" + ) + except json.JSONDecodeError: + logger.error("[Server] Invalid JSON received") + except Exception as e: + logger.error(f"[Server] Error processing message: {str(e)}") + + @pc.on("track") + def on_track(track): + logger.info(f"Track received: {track.kind}") + if track.kind == "video": + videoTrack = VideoStreamTrack(track, pipeline) + tracks["video"] = videoTrack + sender = pc.addTrack(videoTrack) + + # Store video track in app for stats. + stream_id = track.id + request.app["video_tracks"][stream_id] = videoTrack + + codec = "video/H264" + force_codec(pc, sender, codec) + elif track.kind == "audio": + audioTrack = AudioStreamTrack(track, pipeline) + tracks["audio"] = audioTrack + pc.addTrack(audioTrack) + + @track.on("ended") + async def on_ended(): + logger.info(f"{track.kind} track ended") + request.app["video_tracks"].pop(track.id, None) + + @pc.on("connectionstatechange") + async def on_connectionstatechange(): + logger.info(f"Connection state is: {pc.connectionState}") + if pc.connectionState == "failed": + await pc.close() + pcs.discard(pc) + elif pc.connectionState == "closed": + await pc.close() + pcs.discard(pc) + + await pc.setRemoteDescription(offer) + + if "m=audio" in pc.remoteDescription.sdp: + await pipeline.warm_audio() + if "m=video" in pc.remoteDescription.sdp: + await pipeline.warm_video() + + answer = await pc.createAnswer() + await pc.setLocalDescription(answer) + + return web.Response( + content_type="application/json", + text=json.dumps( + {"sdp": pc.localDescription.sdp, "type": 
pc.localDescription.type} + ), + ) + +async def cancel_collect_frames(track): + track.running = False + if hasattr(track, 'collect_task') is not None and not track.collect_task.done(): + try: + track.collect_task.cancel() + await track.collect_task + except (asyncio.CancelledError): + pass + +async def set_prompt(request): + pipeline = request.app["pipeline"] + + prompt = await request.json() + await pipeline.set_prompts(prompt) + + return web.Response(content_type="application/json", text="OK") + + +def health(_): + return web.Response(content_type="application/json", text="OK") + + +async def on_startup(app: web.Application): + if app["media_ports"]: + patch_loop_datagram(app["media_ports"]) + + app["pipeline"] = Pipeline( + cwd=app["workspace"], disable_cuda_malloc=True, gpu_only=True, preview_method='none' + ) + app["pcs"] = set() + app["video_tracks"] = {} + + +async def on_shutdown(app: web.Application): + pcs = app["pcs"] + coros = [pc.close() for pc in pcs] + await asyncio.gather(*coros) + pcs.clear() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run comfystream server") + parser.add_argument("--port", default=8889, help="Set the signaling port") + parser.add_argument( + "--media-ports", default=None, help="Set the UDP ports for WebRTC media" + ) + parser.add_argument("--host", default="127.0.0.1", help="Set the host") + parser.add_argument( + "--workspace", default=None, required=True, help="Set Comfy workspace" + ) + parser.add_argument( + "--log-level", + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + help="Set the logging level", + ) + parser.add_argument( + "--monitor", + default=False, + action="store_true", + help="Start a Prometheus metrics endpoint for monitoring.", + ) + parser.add_argument( + "--stream-id-label", + default=False, + action="store_true", + help="Include stream ID as a label in Prometheus metrics.", + ) + args = parser.parse_args() + + logging.basicConfig( + 
level=args.log_level.upper(), + format="%(asctime)s [%(levelname)s] %(message)s", + datefmt="%H:%M:%S", + ) + + app = web.Application() + app["media_ports"] = args.media_ports.split(",") if args.media_ports else None + app["workspace"] = args.workspace + + app.on_startup.append(on_startup) + app.on_shutdown.append(on_shutdown) + + app.router.add_get("/", health) + app.router.add_get("/health", health) + + # WebRTC signalling and control routes. + app.router.add_post("/offer", offer) + app.router.add_post("/prompt", set_prompt) + + # Add routes for getting stream statistics. + stream_stats_manager = StreamStatsManager(app) + app.router.add_get( + "/streams/stats", stream_stats_manager.collect_all_stream_metrics + ) + app.router.add_get( + "/stream/{stream_id}/stats", stream_stats_manager.collect_stream_metrics_by_id + ) + + # Add Prometheus metrics endpoint. + app["metrics_manager"] = MetricsManager(include_stream_id=args.stream_id_label) + if args.monitor: + app["metrics_manager"].enable() + logger.info( + f"Monitoring enabled - Prometheus metrics available at: " + f"http://{args.host}:{args.port}/metrics" + ) + app.router.add_get("/metrics", app["metrics_manager"].metrics_handler) + + # Add hosted platform route prefix. + # NOTE: This ensures that the local and hosted experiences have consistent routes. 
+ add_prefix_to_app_routes(app, "/live") + + def force_print(*args, **kwargs): + print(*args, **kwargs, flush=True) + sys.stdout.flush() + + web.run_app(app, host=args.host, port=int(args.port), print=force_print) diff --git a/server/pipeline_api.py b/server/pipeline_api.py new file mode 100644 index 00000000..0f04e773 --- /dev/null +++ b/server/pipeline_api.py @@ -0,0 +1,195 @@ +import av +import torch +import numpy as np +import asyncio +import logging +import time +from PIL import Image +from io import BytesIO + +from typing import Any, Dict, Union, List +from comfystream.client_api import ComfyStreamClient +from comfystream import tensor_cache + +WARMUP_RUNS = 5 +logger = logging.getLogger(__name__) + + +class Pipeline: + def __init__(self, **kwargs): + self.client = ComfyStreamClient(**kwargs) + self.video_incoming_frames = asyncio.Queue() + self.audio_incoming_frames = asyncio.Queue() + + self.processed_audio_buffer = np.array([], dtype=np.int16) + + async def warm_video(self): + """Warm up the video pipeline with dummy frames""" + logger.info("Warming up video pipeline...") + + # Create a properly formatted dummy frame (random color pattern) + # Using standard tensor shape: BCHW [1, 3, 512, 512] + tensor = torch.rand(1, 3, 512, 512) # Random values in [0,1] + + # Create a dummy frame and attach the tensor as side_data + dummy_frame = av.VideoFrame(width=512, height=512, format="rgb24") + dummy_frame.side_data.input = tensor + + # Process a few frames for warmup + for i in range(WARMUP_RUNS): + logger.info(f"Video warmup iteration {i+1}/{WARMUP_RUNS}") + self.client.put_video_input(dummy_frame) + await self.client.get_video_output() + + logger.info("Video pipeline warmup complete") + + async def warm_audio(self): + dummy_frame = av.AudioFrame() + dummy_frame.side_data.input = np.random.randint(-32768, 32767, int(48000 * 0.5), dtype=np.int16) # TODO: adds a lot of delay if it doesn't match the buffer size, is warmup needed? 
+ dummy_frame.sample_rate = 48000 + + for _ in range(WARMUP_RUNS): + self.client.put_audio_input(dummy_frame) + await self.client.get_audio_output() + + async def set_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]): + if isinstance(prompts, list): + await self.client.set_prompts(prompts) + else: + await self.client.set_prompts([prompts]) + + async def update_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]): + if isinstance(prompts, list): + await self.client.update_prompts(prompts) + else: + await self.client.update_prompts([prompts]) + + async def put_video_frame(self, frame: av.VideoFrame): + frame.side_data.input = self.video_preprocess(frame) + frame.side_data.skipped = False # Different from LoadTensor, we don't skip frames here + self.client.put_video_input(frame) + await self.video_incoming_frames.put(frame) + + async def put_audio_frame(self, frame: av.AudioFrame): + frame.side_data.input = self.audio_preprocess(frame) + frame.side_data.skipped = False + self.client.put_audio_input(frame) + await self.audio_incoming_frames.put(frame) + + def video_preprocess(self, frame: av.VideoFrame) -> Union[torch.Tensor, np.ndarray]: + """Convert input video frame to tensor in consistent BCHW format""" + try: + frame_np = frame.to_ndarray(format="rgb24") + frame_np = frame_np.astype(np.float32) / 255.0 + tensor = torch.from_numpy(frame_np) + + # TODO: Necessary? 
+ if len(tensor.shape) == 3 and tensor.shape[2] == 3: # HWC format + tensor = tensor.permute(2, 0, 1).unsqueeze(0) # -> BCHW + + # Ensure values are in range [0,1] + if tensor.min() < 0 or tensor.max() > 1: + logger.warning(f"Clamping preprocessing tensor: min={tensor.min().item()}, max={tensor.max().item()}") + tensor = torch.clamp(tensor, 0, 1) + + return tensor + + except Exception as e: + logger.error(f"Error in video_preprocess: {e}") + # Return a default tensor in case of error + return torch.zeros(1, 3, frame.height, frame.width) + + def audio_preprocess(self, frame: av.AudioFrame) -> Union[torch.Tensor, np.ndarray]: + return frame.to_ndarray().ravel().reshape(-1, 2).mean(axis=1).astype(np.int16) + + def video_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.VideoFrame: + """Convert tensor to VideoFrame format""" + try: + # Ensure output is a tensor + if isinstance(output, np.ndarray): + output = torch.from_numpy(output) + + # Convert from BCHW to HWC format for video frame + if len(output.shape) == 4: # BCHW format + output = output.squeeze(0) # Remove batch dimension + if output.shape[0] == 3: # CHW format + output = output.permute(1, 2, 0) # Convert to HWC + + # Convert to numpy array in correct format for VideoFrame + frame_np = (output * 255.0).clamp(0, 255).to(dtype=torch.uint8).cpu().numpy() + + # Create VideoFrame with RGB format + video_frame = av.VideoFrame.from_ndarray(frame_np, format='rgb24') + + logger.info(f"Created video frame with shape: {frame_np.shape}") + return video_frame + + except Exception as e: + logger.error(f"Error in video_postprocess: {str(e)}") + # Return a black frame as fallback + return av.VideoFrame(width=512, height=512, format='rgb24') + + def audio_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.AudioFrame: + return av.AudioFrame.from_ndarray(np.repeat(output, 2).reshape(1, -1)) + + async def get_processed_video_frame(self): + """Get processed video frame from output queue and match it 
with input frame""" + try: + # Get the frame from the incoming queue first + frame = await self.video_incoming_frames.get() + + while frame.side_data.skipped: + frame = await self.video_incoming_frames.get() + + # Get the processed frame from the output queue + logger.info("Getting video output") + out_tensor = await self.client.get_video_output() + + # If there are more frames in the output queue, drain them to get the most recent + # This helps with synchronization when processing is faster than display + while not tensor_cache.image_outputs.empty(): + try: + newer_tensor = await asyncio.wait_for(self.client.get_video_output(), 0.01) + out_tensor = newer_tensor # Use the most recent frame + logger.info("Using more recent frame from output queue") + except asyncio.TimeoutError: + break + + logger.info(f"Received output tensor with shape: {out_tensor.shape if hasattr(out_tensor, 'shape') else 'unknown'}") + + # Process the output tensor + processed_frame = self.video_postprocess(out_tensor) + processed_frame.pts = frame.pts + processed_frame.time_base = frame.time_base + + return processed_frame + + except Exception as e: + logger.error(f"Error in get_processed_video_frame: {str(e)}") + # Create a black frame as fallback + black_frame = av.VideoFrame(width=512, height=512, format='rgb24') + return black_frame + + async def get_processed_audio_frame(self): + # TODO: make it generic to support purely generative audio cases and also add frame skipping + frame = await self.audio_incoming_frames.get() + if frame.samples > len(self.processed_audio_buffer): + out_tensor = await self.client.get_audio_output() + self.processed_audio_buffer = np.concatenate([self.processed_audio_buffer, out_tensor]) + out_data = self.processed_audio_buffer[:frame.samples] + self.processed_audio_buffer = self.processed_audio_buffer[frame.samples:] + + processed_frame = self.audio_postprocess(out_data) + processed_frame.pts = frame.pts + processed_frame.time_base = frame.time_base + 
processed_frame.sample_rate = frame.sample_rate + + return processed_frame + + async def get_nodes_info(self) -> Dict[str, Any]: + """Get information about all nodes in the current prompt including metadata.""" + nodes_info = await self.client.get_available_nodes() + return nodes_info + + async def cleanup(self): + await self.client.cleanup() \ No newline at end of file diff --git a/src/comfystream/client_api.py b/src/comfystream/client_api.py new file mode 100644 index 00000000..5e23e391 --- /dev/null +++ b/src/comfystream/client_api.py @@ -0,0 +1,836 @@ +import asyncio +import json +import uuid +import websockets +import base64 +import aiohttp +import logging +import torch +import numpy as np +from io import BytesIO +from PIL import Image +from typing import List, Dict, Any, Optional, Union +import random +import time + +from comfystream import tensor_cache +from comfystream.utils_api import convert_prompt + +logger = logging.getLogger(__name__) + +class ComfyStreamClient: + def __init__(self, host: str = "127.0.0.1", port: int = 8198, **kwargs): + """ + Initialize the ComfyStream client to use the ComfyUI API. 
+ + Args: + host: The hostname or IP address of the ComfyUI server + port: The port number of the ComfyUI server + **kwargs: Additional configuration parameters + """ + self.host = host + self.port = port + self.server_address = f"ws://{host}:{port}/ws" + self.api_base_url = f"http://{host}:{port}/api" + self.client_id = kwargs.get('client_id', str(uuid.uuid4())) + self.api_version = kwargs.get('api_version', "1.0.0") + self.ws = None + self.current_prompts = [] + self.running_prompts = {} + self.cleanup_lock = asyncio.Lock() + + # WebSocket connection + self._ws_listener_task = None + self.execution_complete_event = asyncio.Event() + self.execution_started = False + self._prompt_id = None + + # Configure logging + if 'log_level' in kwargs: + logger.setLevel(kwargs['log_level']) + + # Enable debug mode + self.debug = kwargs.get('debug', True) + + logger.info(f"ComfyStreamClient initialized with host: {host}, port: {port}, client_id: {self.client_id}") + + async def set_prompts(self, prompts: List[Dict]): + """Set prompts and run them (compatible with original interface)""" + # Convert prompts (this already randomizes seeds, but we'll enhance it) + self.current_prompts = [convert_prompt(prompt) for prompt in prompts] + + # Create tasks for each prompt + for idx in range(len(self.current_prompts)): + task = asyncio.create_task(self.run_prompt(idx)) + self.running_prompts[idx] = task + + logger.info(f"Set {len(self.current_prompts)} prompts for execution") + + async def update_prompts(self, prompts: List[Dict]): + """Update existing prompts (compatible with original interface)""" + if len(prompts) != len(self.current_prompts): + raise ValueError( + "Number of updated prompts must match the number of currently running prompts." 
+ ) + self.current_prompts = [convert_prompt(prompt) for prompt in prompts] + logger.info(f"Updated {len(self.current_prompts)} prompts") + + async def run_prompt(self, prompt_index: int): + """Run a prompt continuously, processing new frames as they arrive""" + logger.info(f"Running prompt {prompt_index}") + + # Make sure WebSocket is connected + await self._connect_websocket() + + # Always set execution complete at start to allow first frame to be processed + self.execution_complete_event.set() + logger.info("Setting execution_complete_event to TRUE at start") + + try: + while True: + # Wait until we have tensor data available before sending prompt + if tensor_cache.image_inputs.empty(): + await asyncio.sleep(0.01) # Reduced sleep time for faster checking + continue + + # Clear event before sending a new prompt + if self.execution_complete_event.is_set(): + # Reset execution state for next frame + self.execution_complete_event.clear() + logger.info("Setting execution_complete_event to FALSE before executing prompt") + + # Queue the prompt with the current frame + await self._execute_prompt(prompt_index) + + # Wait for execution completion with timeout + try: + logger.info("Waiting for execution to complete (max 10 seconds)...") + await asyncio.wait_for(self.execution_complete_event.wait(), timeout=10.0) + logger.info("Execution complete, ready for next frame") + except asyncio.TimeoutError: + logger.error("Timeout waiting for execution, forcing continuation") + self.execution_complete_event.set() + else: + # If execution is not complete, check again shortly + await asyncio.sleep(0.01) # Short sleep to prevent CPU spinning + + except asyncio.CancelledError: + logger.info(f"Prompt {prompt_index} execution cancelled") + raise + except Exception as e: + logger.error(f"Error in run_prompt: {str(e)}") + raise + + async def _connect_websocket(self): + """Connect to the ComfyUI WebSocket endpoint""" + try: + if self.ws is not None and self.ws.open: + return self.ws + + # 
Close existing connection if any + if self.ws is not None: + try: + await self.ws.close() + except: + pass + self.ws = None + + logger.info(f"Connecting to WebSocket at {self.server_address}?clientId={self.client_id}") + + # Set a reasonable timeout for connection + websocket_timeout = 10.0 # seconds + + try: + # Connect with proper error handling + self.ws = await websockets.connect( + f"{self.server_address}?clientId={self.client_id}", + ping_interval=5, + ping_timeout=10, + close_timeout=5, + max_size=None, # No limit on message size + ssl=None + ) + + logger.info("WebSocket connected successfully") + + # Start the listener task if not already running + if self._ws_listener_task is None or self._ws_listener_task.done(): + self._ws_listener_task = asyncio.create_task(self._ws_listener()) + logger.info("Started WebSocket listener task") + + return self.ws + + except (websockets.exceptions.WebSocketException, ConnectionError, OSError) as e: + logger.error(f"WebSocket connection error: {e}") + self.ws = None + # Signal execution complete to prevent hanging if connection fails + self.execution_complete_event.set() + # Retry after a delay + await asyncio.sleep(1) + return await self._connect_websocket() + + except Exception as e: + logger.error(f"Unexpected error in _connect_websocket: {e}") + self.ws = None + # Signal execution complete to prevent hanging + self.execution_complete_event.set() + return None + + async def _ws_listener(self): + """Listen for WebSocket messages and process them""" + try: + logger.info(f"WebSocket listener started") + while True: + if self.ws is None: + try: + await self._connect_websocket() + except Exception as e: + logger.error(f"Error connecting to WebSocket: {e}") + await asyncio.sleep(1) + continue + + try: + # Receive and process messages + message = await self.ws.recv() + + if isinstance(message, str): + # Process JSON messages + await self._handle_text_message(message) + else: + # Handle binary data - likely image preview or 
tensor data + await self._handle_binary_message(message) + + except websockets.exceptions.ConnectionClosed: + logger.info("WebSocket connection closed") + self.ws = None + await asyncio.sleep(1) + except Exception as e: + logger.error(f"Error in WebSocket listener: {e}") + await asyncio.sleep(1) + + except asyncio.CancelledError: + logger.info("WebSocket listener cancelled") + raise + except Exception as e: + logger.error(f"Unexpected error in WebSocket listener: {e}") + + async def _handle_text_message(self, message: str): + """Process text (JSON) messages from the WebSocket""" + try: + data = json.loads(message) + message_type = data.get("type", "unknown") + + # logger.info(f"Received message type: {message_type}") + + # Handle different message types + if message_type == "status": + pass + ''' + # Status message with comfy_ui's queue information + queue_remaining = data.get("data", {}).get("queue_remaining", 0) + exec_info = data.get("data", {}).get("exec_info", {}) + if queue_remaining == 0 and not exec_info: + logger.info("Queue empty, no active execution") + else: + logger.info(f"Queue status: {queue_remaining} items remaining") + ''' + + elif message_type == "progress": + if "data" in data and "value" in data["data"]: + progress = data["data"]["value"] + max_value = data["data"].get("max", 100) + # Log the progress for debugging + # logger.info(f"Progress: {progress}/{max_value}") + + elif message_type == "execution_start": + self.execution_started = True + if "data" in data and "prompt_id" in data["data"]: + self._prompt_id = data["data"]["prompt_id"] + # logger.info(f"Execution started for prompt {self._prompt_id}") + + elif message_type == "executing": + self.execution_started = True + if "data" in data: + if "prompt_id" in data["data"]: + self._prompt_id = data["data"]["prompt_id"] + if "node" in data["data"]: + node_id = data["data"]["node"] + # ogger.info(f"Executing node: {node_id}") + + elif message_type in ["execution_cached", "execution_error", 
"execution_complete", "execution_interrupted"]: + # logger.info(f"{message_type} message received for prompt {self._prompt_id}") + #self.execution_started = False + + # Always signal completion for these terminal states + # self.execution_complete_event.set() + # logger.info(f"Set execution_complete_event from {message_type}") + pass + + elif message_type == "executed": + # This is sent when a node is completely done + if "data" in data and "node_id" in data["data"]: + node_id = data["data"]["node_id"] + logger.info(f"Node execution complete: {node_id}") + + # Check if this is our SaveTensorAPI node + if "SaveTensorAPI" in str(node_id): + logger.info("SaveTensorAPI node executed, checking for tensor data") + # The binary data should come separately via websocket + + # If we've been running for too long without tensor data, force completion + elif self.execution_started and not self.execution_complete_event.is_set(): + # Check if this was the last node + if data.get("data", {}).get("remaining", 0) == 0: + # logger.info("All nodes executed but no tensor data received, forcing completion") + # self.execution_complete_event.set() + pass + + elif message_type == "executed_node" and "output" in data.get("data", {}): + node_id = data.get("data", {}).get("node_id") + output_data = data.get("data", {}).get("output", {}) + prompt_id = data.get("data", {}).get("prompt_id", "unknown") + + logger.info(f"Node {node_id} executed in prompt {prompt_id}") + + ''' + # Check if this is from ETN_SendImageWebSocket node + if "ui" in output_data and "images" in output_data["ui"]: + images_info = output_data["ui"]["images"] + logger.info(f"Found image output from ETN_SendImageWebSocket in node {node_id}") + + # Images will be received via binary websocket messages after this event + # The binary handler will take care of them + pass + + # Keep existing handling for tensor data + elif "ui" in output_data and "tensor" in output_data["ui"]: + tensor_info = output_data["ui"]["tensor"] + 
    async def _handle_binary_message(self, binary_data):
        """Process a binary message received on the ComfyUI WebSocket.

        Wire layout (as observed in logs): an 8-byte header of two
        little-endian uint32s (event type, then format), followed by the
        payload. When the payload's magic bytes identify a JPEG or PNG it is
        decoded, converted to a float BCHW tensor in [0, 1], and pushed onto
        ``tensor_cache.image_outputs``; on any failure a zero tensor is queued
        instead so downstream consumers never starve.

        NOTE(review): ``execution_complete_event`` is set at the very top so
        the next frame can be queued without waiting for decoding — this
        ordering is load-bearing for pipeline throughput.

        Args:
            binary_data: Raw bytes of the WebSocket binary frame.
        """
        try:
            # Log binary message information
            # logger.info(f"Received binary message of size: {len(binary_data)} bytes")

            # Signal execution is complete, queue next frame
            self.execution_complete_event.set()

            # Binary messages in ComfyUI start with a header
            # First 8 bytes are used for header information
            if len(binary_data) <= 8:
                logger.warning(f"Binary message too short: {len(binary_data)} bytes")
                return

            # Extract header data based on the actual format observed in logs
            # Header bytes (hex): 0000000100000001 - this appears to be the format in use
            # NOTE(review): read little-endian here, while JS DataView.getUint32
            # defaults to big-endian — confirm against the browser client.
            event_type = int.from_bytes(binary_data[:4], byteorder='little')
            format_type = int.from_bytes(binary_data[4:8], byteorder='little')
            data = binary_data[8:]

            # Log header details
            logger.info(f"Binary message header: event_type={event_type}, format_type={format_type}, data_size={len(data)} bytes")
            #logger.info(f"Header bytes (hex): {binary_data[:8].hex()}")

            # Check if this is an image (JPEG starts with 0xFF, 0xD8, PNG starts with 0x89, 0x50)
            is_jpeg = data[:2] == b'\xff\xd8'
            is_png = data[:4] == b'\x89\x50\x4e\x47'

            if is_jpeg or is_png:
                image_format = "JPEG" if is_jpeg else "PNG"
                logger.info(f"Detected {image_format} image based on magic bytes")

                # Create a NEW binary message with the expected header format for the JavaScript client
                # The JavaScript expects: [0:4]=1 (PREVIEW_IMAGE), [4:8]=1 (JPEG format) or [4:8]=2 (PNG format)
                # This matches exactly what the JS code is looking for:
                # const event = dataView.getUint32(0); // event type (1 = PREVIEW_IMAGE)
                # const format = dataView.getUint32(4); // format (1 = JPEG, 2 = PNG)
                js_event_type = (1).to_bytes(4, byteorder='little')  # PREVIEW_IMAGE = 1
                js_format_type = (1 if is_jpeg else 2).to_bytes(4, byteorder='little')
                transformed_data = js_event_type + js_format_type + data

                # Forward to WebSocket client if connected
                # if self.ws:
                #     await self.ws.send(transformed_data)
                #     logger.info(f"Sent transformed {image_format} image data to WebSocket with correct JS header format")
                #else:
                #     logger.error("WebSocket not connected, cannot forward image to JS client")

                # Process the image for our pipeline
                try:
                    # Decode the image
                    img = Image.open(BytesIO(data))
                    logger.info(f"Successfully decoded image: size={img.size}, mode={img.mode}, format={img.format}")

                    # Convert to RGB if not already
                    if img.mode != "RGB":
                        img = img.convert("RGB")
                        logger.info(f"Converted image to RGB mode")

                    # Save image to temp folder as a file
                    # TESTING
                    '''
                    import os
                    import tempfile
                    temp_folder = os.path.join(tempfile.gettempdir(), "comfyui_images")
                    os.makedirs(temp_folder, exist_ok=True)
                    img_path = os.path.join(temp_folder, f"comfyui_image_{time.time()}.png")
                    img.save(img_path)
                    logger.info(f"Saved image to {img_path}")
                    '''

                    # Convert to tensor (normalize to [0,1] range for consistency)
                    img_np = np.array(img).astype(np.float32) / 255.0
                    tensor = torch.from_numpy(img_np)

                    # CRITICAL: Ensure dimensions are correctly understood
                    # The tensor should be in HWC format initially from PIL/numpy
                    logger.info(f"Initial tensor shape from image: {tensor.shape}")

                    # Convert from HWC to BCHW format for consistency with model expectations
                    if len(tensor.shape) == 3 and tensor.shape[2] == 3:  # HWC format (H,W,3)
                        tensor = tensor.permute(2, 0, 1).unsqueeze(0)  # -> BCHW (1,3,H,W)
                        logger.info(f"Converted to BCHW tensor with shape: {tensor.shape}")

                    # Check for NaN or Inf values
                    if torch.isnan(tensor).any() or torch.isinf(tensor).any():
                        logger.warning("Tensor contains NaN or Inf values! Replacing with zeros")
                        tensor = torch.nan_to_num(tensor, nan=0.0, posinf=1.0, neginf=0.0)

                    # Log detailed tensor info for debugging
                    logger.info(f"Final tensor with shape: {tensor.shape}, "
                                f"min={tensor.min().item()}, max={tensor.max().item()}, "
                                f"mean={tensor.mean().item()}")

                    # Add to output queue without waiting
                    tensor_cache.image_outputs.put_nowait(tensor)
                    logger.info(f"Added tensor to output queue, queue size: {tensor_cache.image_outputs.qsize()}")
                    return

                except Exception as img_error:
                    logger.error(f"Error processing image: {img_error}", exc_info=True)

            # If we get here, we couldn't process the image
            logger.warning("Failed to process image, creating default tensor")
            default_tensor = torch.zeros(1, 3, 512, 512)
            tensor_cache.image_outputs.put_nowait(default_tensor)
            self.execution_complete_event.set()

        except Exception as e:
            logger.error(f"Error handling binary message: {e}", exc_info=True)
            # Set execution complete event to avoid hanging
            self.execution_complete_event.set()
    async def _execute_prompt(self, prompt_index: int):
        """Execute one of ``self.current_prompts`` via the ComfyUI HTTP API.

        Steps:
          1. Randomize every ``seed``/``noise_seed`` input and stamp
             cache-busting parameters onto LoadImageBase64-style nodes so
             ComfyUI never serves a cached result.
          2. Drain ``tensor_cache.image_inputs`` down to the newest frame,
             convert it to a base64 PNG, and inject it into every
             ETN_LoadImageBase64 / LoadImageBase64 node.
          3. POST the prompt to ``{api_base_url}/prompt``.

        On every early-exit or error path ``execution_complete_event`` is set
        so the caller's wait loop cannot hang.

        Args:
            prompt_index: Index into ``self.current_prompts``.
        """
        try:
            # Get the prompt to execute
            prompt = self.current_prompts[prompt_index]

            # Ensure all seed values are randomized for every execution
            # This forces ComfyUI to not use cached results
            for node_id, node in prompt.items():
                if isinstance(node, dict) and "inputs" in node:
                    if "seed" in node["inputs"]:
                        # Generate a truly random seed each time
                        random_seed = random.randint(0, 18446744073709551615)
                        node["inputs"]["seed"] = random_seed
                        logger.info(f"Randomized seed to {random_seed} for node {node_id}")

                    # Also randomize noise_seed if present
                    if "noise_seed" in node["inputs"]:
                        noise_seed = random.randint(0, 18446744073709551615)
                        node["inputs"]["noise_seed"] = noise_seed
                        logger.info(f"Randomized noise_seed to {noise_seed} for node {node_id}")

            # Add a timestamp parameter to each node to prevent caching
            # This is a "hidden" trick to force ComfyUI to consider each execution unique
            timestamp = int(time.time() * 1000)  # millisecond timestamp
            for node_id, node in prompt.items():
                if isinstance(node, dict) and "inputs" in node:
                    # Add a timestamp parameter to ETN_LoadImageBase64 nodes
                    if node.get("class_type") in ["ETN_LoadImageBase64", "LoadImageBase64"]:
                        # Add a unique cache-busting parameter
                        node["inputs"]["_timestamp"] = timestamp
                        logger.info(f"Added timestamp {timestamp} to node {node_id}")

            # Check if we have a frame waiting to be processed
            if not tensor_cache.image_inputs.empty():
                logger.info("Found tensor in input queue, preparing for API")
                # Get the frame from the cache - make sure to get the most recent frame
                # (older frames are deliberately dropped by this drain loop)
                while not tensor_cache.image_inputs.empty():
                    frame_or_tensor = tensor_cache.image_inputs.get_nowait()

                # Find ETN_LoadImageBase64 nodes
                load_image_nodes = []
                for node_id, node in prompt.items():
                    if isinstance(node, dict) and node.get("class_type") == "ETN_LoadImageBase64":
                        load_image_nodes.append(node_id)

                if not load_image_nodes:
                    # Also check for regular LoadImageBase64 nodes as fallback
                    for node_id, node in prompt.items():
                        if isinstance(node, dict) and node.get("class_type") == "LoadImageBase64":
                            load_image_nodes.append(node_id)

                if not load_image_nodes:
                    logger.warning("No ETN_LoadImageBase64 or LoadImageBase64 nodes found in the prompt")
                    self.execution_complete_event.set()  # Signal completion
                    return
                else:
                    # Convert the frame/tensor to base64 and include directly in the prompt
                    try:
                        # Get the actual tensor data - handle different input types
                        tensor = None

                        # Check if it's a PyAV VideoFrame with preprocessed tensor in side_data
                        if hasattr(frame_or_tensor, 'side_data') and hasattr(frame_or_tensor.side_data, 'input'):
                            tensor = frame_or_tensor.side_data.input
                            logger.info(f"Using preprocessed tensor from frame.side_data.input with shape {tensor.shape}")
                        # Check if it's a PyTorch tensor
                        elif isinstance(frame_or_tensor, torch.Tensor):
                            tensor = frame_or_tensor
                            logger.info(f"Using tensor directly with shape {tensor.shape}")
                        # Check if it's a numpy array
                        elif isinstance(frame_or_tensor, np.ndarray):
                            tensor = torch.from_numpy(frame_or_tensor).float()
                            logger.info(f"Converted numpy array to tensor with shape {tensor.shape}")
                        else:
                            # If it's a PyAV frame without preprocessed data, convert it
                            try:
                                if hasattr(frame_or_tensor, 'to_ndarray'):
                                    frame_np = frame_or_tensor.to_ndarray(format="rgb24").astype(np.float32) / 255.0
                                    tensor = torch.from_numpy(frame_np).unsqueeze(0)
                                    logger.info(f"Converted PyAV frame to tensor with shape {tensor.shape}")
                                else:
                                    logger.error(f"Unsupported frame type: {type(frame_or_tensor)}")
                                    self.execution_complete_event.set()
                                    return
                            except Exception as e:
                                logger.error(f"Error converting frame to tensor: {e}")
                                self.execution_complete_event.set()
                                return

                        if tensor is None:
                            logger.error("Failed to get valid tensor data from input")
                            self.execution_complete_event.set()
                            return

                        # Now process the tensor (which should be a proper PyTorch tensor)
                        # Ensure it's a tensor on CPU and detached
                        tensor = tensor.detach().cpu().float()

                        # Handle different formats
                        if len(tensor.shape) == 4:  # BCHW format (batch)
                            tensor = tensor[0]  # Take first image from batch

                        # Ensure it's in CHW format
                        if len(tensor.shape) == 3 and tensor.shape[2] == 3:  # HWC format
                            tensor = tensor.permute(2, 0, 1)  # Convert to CHW

                        # Convert to PIL image for saving
                        tensor_np = (tensor.permute(1, 2, 0) * 255).clamp(0, 255).numpy().astype(np.uint8)
                        img = Image.fromarray(tensor_np)

                        # Save as PNG to BytesIO and convert to base64 string
                        buffer = BytesIO()
                        img.save(buffer, format="PNG")
                        buffer.seek(0)

                        # Encode as base64 - for ETN_LoadImageBase64, we need the raw base64 string
                        img_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
                        logger.info(f"Created base64 string of length: {len(img_base64)}")

                        # Update all ETN_LoadImageBase64 nodes with the base64 data
                        for node_id in load_image_nodes:
                            prompt[node_id]["inputs"]["image"] = img_base64
                            # Add a small random suffix to image data to prevent caching
                            rand_suffix = str(random.randint(1, 1000000))
                            prompt[node_id]["inputs"]["_cache_buster"] = rand_suffix
                            logger.info(f"Updated node {node_id} with base64 string and cache buster {rand_suffix}")

                    except Exception as e:
                        logger.error(f"Error converting tensor to base64: {e}")
                        # Signal execution complete in case of error
                        self.execution_complete_event.set()
                        return
            else:
                logger.info("No tensor in input queue, skipping prompt execution")
                self.execution_complete_event.set()  # Signal completion
                return

            # Execute the prompt via API
            async with aiohttp.ClientSession() as session:
                api_url = f"{self.api_base_url}/prompt"
                payload = {
                    "prompt": prompt,
                    "client_id": self.client_id
                }

                # Send the request
                logger.info(f"Sending prompt to {api_url}")
                async with session.post(api_url, json=payload) as response:
                    if response.status == 200:
                        result = await response.json()
                        self._prompt_id = result.get("prompt_id")
                        logger.info(f"Prompt queued with ID: {self._prompt_id}")
                        self.execution_started = True
                    else:
                        error_text = await response.text()
                        logger.error(f"Error queueing prompt: {response.status} - {error_text}")
                        # Signal execution complete in case of error
                        self.execution_complete_event.set()

        except aiohttp.ClientError as e:
            logger.error(f"Client error queueing prompt: {e}")
            self.execution_complete_event.set()
        except Exception as e:
            logger.error(f"Error executing prompt: {e}")
            # Signal execution complete in case of error
            self.execution_complete_event.set()
    async def _send_tensor_via_websocket(self, tensor):
        """Send a tensor to ComfyUI over the WebSocket as a PNG image.

        The tensor is normalized to a 3xHxW CHW layout (batches are reduced to
        their first image; malformed shapes are replaced with a zero 3x512x512
        tensor), encoded as PNG, and prefixed with an 8-byte header: uint32
        message type 3 (binary tensor) then uint32 node id 0 (any waiting
        LoadTensorAPI node). On any failure ``execution_complete_event`` is
        set so callers do not hang.

        Args:
            tensor: ``torch.Tensor`` or ``np.ndarray``; assumed to hold
                values in [0, 1] — TODO confirm for all producers.
        """
        try:
            if self.ws is None:
                logger.error("WebSocket not connected, cannot send tensor")
                self.execution_complete_event.set()  # Prevent hanging
                return

            # Convert the tensor to image format for sending
            if isinstance(tensor, np.ndarray):
                tensor = torch.from_numpy(tensor).float()

            # Ensure on CPU and correct format
            tensor = tensor.detach().cpu().float()

            # Prepare binary data
            if len(tensor.shape) == 4:  # BCHW format (batch of images)
                if tensor.shape[0] > 1:
                    logger.info(f"Taking first image from batch of {tensor.shape[0]}")
                tensor = tensor[0]  # Take first image if batch

            # Ensure CHW format (3 channels)
            if len(tensor.shape) == 3:
                if tensor.shape[0] != 3 and tensor.shape[2] == 3:  # HWC format
                    tensor = tensor.permute(2, 0, 1)  # Convert to CHW
                elif tensor.shape[0] != 3:
                    logger.warning(f"Tensor doesn't have 3 channels: {tensor.shape}. Creating standard tensor.")
                    # Create a standard RGB tensor
                    tensor = torch.zeros(3, 512, 512)
            else:
                logger.warning(f"Tensor has unexpected shape: {tensor.shape}. Creating standard tensor.")
                # Create a standard RGB tensor
                tensor = torch.zeros(3, 512, 512)

            # Check tensor dimensions and log detailed info
            logger.info(f"Original tensor for WS: shape={tensor.shape}, min={tensor.min().item():.4f}, max={tensor.max().item():.4f}")

            # CRITICAL FIX: The issue is with the shape - no need to resize if dimensions are fine
            if tensor.shape[1] < 64 or tensor.shape[2] < 64:
                logger.warning(f"Tensor dimensions too small: {tensor.shape}. Resizing to 512x512")
                import torch.nn.functional as F
                tensor = tensor.unsqueeze(0)  # Add batch dimension for interpolate
                tensor = F.interpolate(tensor, size=(512, 512), mode='bilinear', align_corners=False)
                tensor = tensor.squeeze(0)  # Remove batch dimension after resize
                logger.info(f"Resized tensor to: {tensor.shape}")

            # Check for NaN or Inf values
            if torch.isnan(tensor).any() or torch.isinf(tensor).any():
                logger.warning("Tensor contains NaN or Inf values! Replacing with zeros.")
                tensor = torch.nan_to_num(tensor, nan=0.0, posinf=1.0, neginf=0.0)

            # Convert to image (HWC for PIL)
            tensor_np = (tensor.permute(1, 2, 0) * 255).clamp(0, 255).numpy().astype(np.uint8)
            img = Image.fromarray(tensor_np)

            logger.info(f"Converted to PIL image with dimensions: {img.size}")

            # Convert to PNG
            buffer = BytesIO()
            img.save(buffer, format="PNG")
            buffer.seek(0)
            img_bytes = buffer.getvalue()

            # CRITICAL FIX: We need to send the binary data with a proper node ID prefix
            # LoadTensorAPI node expects this header format to identify the target node
            # The first 4 bytes are the message type (3 for binary tensor) and the next 4 are the node ID
            # Since we don't know the exact node ID, we'll use a generic one that will be interpreted as
            # "send this to the currently waiting LoadTensorAPI node"

            # Build header (8 bytes total)
            header = bytearray()
            # Message type 3 (custom binary tensor data)
            header.extend((3).to_bytes(4, byteorder='little'))
            # Generic node ID (0 means "send to whatever node is waiting")
            header.extend((0).to_bytes(4, byteorder='little'))

            # Combine header and image data
            full_data = header + img_bytes

            # Send binary data via websocket
            await self.ws.send(full_data)
            logger.info(f"Sent tensor as PNG image via websocket with proper header, size: {len(full_data)} bytes, image dimensions: {img.size}")

        except Exception as e:
            logger.error(f"Error sending tensor via websocket: {e}")

            # Signal execution complete in case of error
            self.execution_complete_event.set()

    async def cleanup(self):
        """Clean up resources.

        Serialized by ``cleanup_lock``: cancels and awaits every running
        prompt task, closes the WebSocket, cancels the WebSocket listener
        task, then drains the tensor queues.
        """
        async with self.cleanup_lock:
            # Cancel all running tasks
            for task in self.running_prompts.values():
                if not task.done():
                    task.cancel()
                    try:
                        await task
                    except asyncio.CancelledError:
                        pass
            self.running_prompts.clear()

            # Close WebSocket connection
            if self.ws:
                try:
                    await self.ws.close()
                except Exception as e:
                    logger.error(f"Error closing WebSocket: {e}")
                self.ws = None

            # Cancel WebSocket listener task
            if self._ws_listener_task and not self._ws_listener_task.done():
                self._ws_listener_task.cancel()
                try:
                    await self._ws_listener_task
                except asyncio.CancelledError:
                    pass
                self._ws_listener_task = None

            await self.cleanup_queues()
            logger.info("Client cleanup complete")

    async def cleanup_queues(self):
        """Drain all tensor queues (inputs synchronously, outputs via await).

        NOTE(review): input queues are drained with blocking ``get()`` while
        output queues use ``await`` + bare ``except`` — presumably the inputs
        are thread-safe queues and the outputs asyncio queues; confirm the
        queue types in ``tensor_cache``.
        """
        while not tensor_cache.image_inputs.empty():
            tensor_cache.image_inputs.get()

        while not tensor_cache.audio_inputs.empty():
            tensor_cache.audio_inputs.get()

        while tensor_cache.image_outputs.qsize() > 0:
            try:
                await tensor_cache.image_outputs.get()
            except:
                pass

        while tensor_cache.audio_outputs.qsize() > 0:
            try:
                await tensor_cache.audio_outputs.get()
            except:
                pass

        logger.info("Tensor queues cleared")
    def put_video_input(self, tensor: Union[torch.Tensor, np.ndarray]):
        """
        Put a video TENSOR into the tensor cache for processing.

        When the input queue is full the single oldest frame is dropped first
        (mirrors the behavior in client.py), so producers never block.

        Args:
            tensor: Video frame as a tensor (or numpy array)
        """
        try:
            # Only remove one frame if the queue is full (like in client.py)
            if tensor_cache.image_inputs.full():
                tensor_cache.image_inputs.get_nowait()

            # Ensure tensor is detached if it's a torch tensor
            if isinstance(tensor, torch.Tensor):
                tensor = tensor.detach()

            tensor_cache.image_inputs.put(tensor)

        except Exception as e:
            logger.error(f"Error in put_video_input: {e}")

    def put_audio_input(self, frame):
        """Put audio frame into tensor cache."""
        tensor_cache.audio_inputs.put(frame)

    async def get_video_output(self):
        """Await and return the next processed video tensor from the cache."""
        logger.info("Waiting for processed tensor from output queue")
        result = await tensor_cache.image_outputs.get()
        logger.info(f"Got processed tensor from output queue: shape={result.shape if hasattr(result, 'shape') else 'unknown'}")
        return result

    async def get_audio_output(self):
        """Await and return the next processed audio frame from the cache."""
        return await tensor_cache.audio_outputs.get()

    async def get_available_nodes(self):
        """Get metadata and available nodes info for current prompts.

        Queries ``{api_base_url}/object_info`` and, on HTTP 200, builds a
        ``{prompt_index: {node_id: {class_type, inputs}}}`` summary from
        ``self.current_prompts``. Returns ``{}`` on any error.

        NOTE(review): the decoded ``data`` from /object_info is never used —
        the summary is built solely from the local prompts, so input types
        are reported as 'unknown'. Either use ``data`` or drop the request.
        """
        try:
            async with aiohttp.ClientSession() as session:
                url = f"{self.api_base_url}/object_info"
                async with session.get(url) as response:
                    if response.status == 200:
                        data = await response.json()

                        # Format node info similar to the embedded client response
                        all_prompts_nodes_info = {}

                        for prompt_index, prompt in enumerate(self.current_prompts):
                            nodes_info = {}

                            for node_id, node in prompt.items():
                                class_type = node.get('class_type')
                                if class_type:
                                    nodes_info[node_id] = {
                                        'class_type': class_type,
                                        'inputs': {}
                                    }

                                    if 'inputs' in node:
                                        for input_name, input_value in node['inputs'].items():
                                            nodes_info[node_id]['inputs'][input_name] = {
                                                'value': input_value,
                                                'type': 'unknown'  # We don't have type information
                                            }

                            all_prompts_nodes_info[prompt_index] = nodes_info

                        return all_prompts_nodes_info

                    else:
                        logger.error(f"Error getting node info: {response.status}")
                        return {}
        except Exception as e:
            logger.error(f"Error getting node info: {str(e)}")
            return {}
def create_load_tensor_node():
    """Build a LoadTensorAPI node stub.

    Returns:
        An API-format node dict; ``tensor_data`` is left empty and is filled
        in at runtime.
    """
    return {
        "inputs": {
            "tensor_data": ""  # Empty tensor data that will be filled at runtime
        },
        "class_type": "LoadTensorAPI",
        "_meta": {"title": "Load Tensor (API)"},
    }


def create_save_tensor_node(inputs: Dict[Any, Any]):
    """Create a SaveTensorAPI node with proper input formatting.

    Args:
        inputs: Node inputs; ``inputs["images"]`` should be a
            ``[node_id, output_index]`` pair. Anything else falls back to
            ``["", 0]`` with a warning.

    Returns:
        An API-format SaveTensorAPI node dict (PNG, quality 95).
    """
    # Make sure images input is properly formatted [node_id, output_index]
    images_input = inputs.get("images")

    # If images input is not properly formatted as [node_id, output_index]
    if not isinstance(images_input, list) or len(images_input) != 2:
        print(f"Warning: Invalid images input format: {images_input}, using default")
        images_input = ["", 0]  # Default empty value

    return {
        "inputs": {
            "images": images_input,  # Should be [node_id, output_index]
            "format": "png",  # Better default than JPG for quality
            "quality": 95,
        },
        "class_type": "SaveTensorAPI",
        "_meta": {"title": "Save Tensor (API)"},
    }


def convert_prompt(prompt):
    """Prepare a ComfyUI API-format prompt for execution.

    Randomizes every node-level ``seed`` input so ComfyUI does not serve
    cached results. Works on a deep copy so the caller's prompt dict is left
    untouched (the previous version mutated it in place even though ``copy``
    was already imported for this purpose).

    Args:
        prompt: Mapping of node-id -> node description (API format). Entries
            that are not dicts or have no ``inputs`` are passed through as-is.

    Returns:
        A new prompt dict with randomized seeds.
    """
    logging.info("Converting prompt: %s", prompt)

    # Copy first so repeated conversions of the same prompt are independent.
    prompt = copy.deepcopy(prompt)

    # Set random seeds for any seed nodes
    for key, node in prompt.items():
        if not isinstance(node, dict) or "inputs" not in node:
            continue

        # Check if this node has a seed input directly
        if "seed" in node.get("inputs", {}):
            # Same range as JavaScript's Math.random() * 18446744073709552000
            random_seed = random.randint(0, 18446744073709551615)
            node["inputs"]["seed"] = random_seed
            print(f"Set random seed {random_seed} for node {key}")

    return prompt
    def audio_preprocess(self, frame: av.AudioFrame) -> Union[torch.Tensor, np.ndarray]:
        """Downmix an interleaved stereo frame to mono int16 samples.

        Pairs of samples are averaged (reshape(-1, 2).mean(axis=1)), so this
        assumes 2-channel interleaved input — TODO confirm for all sources.
        """
        return frame.to_ndarray().ravel().reshape(-1, 2).mean(axis=1).astype(np.int16)

    # Works with ComfyUI Native
    def video_preprocess(self, frame: av.VideoFrame) -> Union[torch.Tensor, np.ndarray]:
        """Convert a video frame to a (1, H, W, 3) float tensor in [0, 1]."""
        frame_np = frame.to_ndarray(format="rgb24").astype(np.float32) / 255.0
        return torch.from_numpy(frame_np).unsqueeze(0)

    ''' Converts HWC format (height, width, channels) to VideoFrame.
    def video_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.VideoFrame:
        return av.VideoFrame.from_ndarray(
            (output * 255.0).clamp(0, 255).to(dtype=torch.uint8).squeeze(0).cpu().numpy()
        )
    '''

    '''Converts BCHW tensor [1,3,H,W] to VideoFrame. Assumes values are in range [0,1].'''
    def video_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.VideoFrame:
        # squeeze batch -> CHW, permute -> HWC, scale to uint8 for rgb24.
        return av.VideoFrame.from_ndarray(
            (output.squeeze(0).permute(1, 2, 0) * 255.0)
            .clamp(0, 255)
            .to(dtype=torch.uint8)
            .cpu()
            .numpy(),
            format='rgb24'
        )

    def audio_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.AudioFrame:
        """Duplicate mono samples into interleaved stereo and wrap as a frame."""
        return av.AudioFrame.from_ndarray(np.repeat(output, 2).reshape(1, -1))
# borrowed from Acly's comfyui-tooling-nodes
# https://github.com/Acly/comfyui-tooling-nodes/blob/main/nodes.py

# TODO: I think we can receive tensor data directly from the pipeline through the /prompt endpoint as JSON
# This may be more efficient than sending base64 encoded images through the websocket and
# allow for alternative data formats.

from PIL import Image
import base64
import numpy as np
import torch
from io import BytesIO


class LoadImageBase64:
    """Decode a base64-encoded image string into an IMAGE tensor and MASK."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"image": ("STRING", {"multiline": False})}}

    RETURN_TYPES = ("IMAGE", "MASK")
    CATEGORY = "external_tooling"
    FUNCTION = "load_image"

    def load_image(self, image):
        """Return ``(image_tensor, mask)`` for a base64 string.

        The image tensor is float32 in [0, 1] with shape (1, H, W, 3); the
        mask comes from the alpha channel when present, otherwise None.
        """
        raw = base64.b64decode(image)
        pil_img = Image.open(BytesIO(raw))

        # Alpha channel (if any) becomes the mask, normalized to [0, 1].
        if "A" in pil_img.getbands():
            alpha = np.array(pil_img.getchannel("A")).astype(np.float32) / 255.0
            mask = torch.from_numpy(alpha)
        else:
            mask = None

        rgb = np.array(pil_img.convert("RGB")).astype(np.float32) / 255.0
        return (torch.from_numpy(rgb)[None,], mask)
+ +import numpy as np +from PIL import Image +from server import PromptServer, BinaryEventTypes + +class SendImageWebsocket: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "images": ("IMAGE",), + "format": (["PNG", "JPEG"], {"default": "PNG"}), + } + } + + RETURN_TYPES = () + FUNCTION = "send_images" + OUTPUT_NODE = True + CATEGORY = "external_tooling" + + def send_images(self, images, format): + results = [] + for tensor in images: + array = 255.0 * tensor.cpu().numpy() + image = Image.fromarray(np.clip(array, 0, 255).astype(np.uint8)) + + server = PromptServer.instance + server.send_sync( + BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, + [format, image, None], + server.client_id, + ) + results.append({ + "source": "websocket", + "content-type": f"image/{format.lower()}", + "type": "output", + }) + + return {"ui": {"images": results}} \ No newline at end of file diff --git a/server/pipeline_api.py b/server/pipeline_api.py index bfaf4909..a6eb1205 100644 --- a/server/pipeline_api.py +++ b/server/pipeline_api.py @@ -3,10 +3,10 @@ import numpy as np import asyncio import logging +import time from typing import Any, Dict, Union, List from comfystream.client_api import ComfyStreamClient -from comfystream import tensor_cache WARMUP_RUNS = 5 logger = logging.getLogger(__name__) @@ -19,6 +19,10 @@ def __init__(self, **kwargs): self.audio_incoming_frames = asyncio.Queue() self.processed_audio_buffer = np.array([], dtype=np.int16) + self.last_frame_time = 0 + + # TODO: Not sure if this is needed - should match to UI selected FPS + self.min_frame_interval = 1/30 # Limit to 30 FPS async def warm_video(self): """Warm up the video pipeline with dummy frames""" @@ -62,6 +66,11 @@ async def update_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any await self.client.update_prompts([prompts]) async def put_video_frame(self, frame: av.VideoFrame): + current_time = time.time() + if current_time - self.last_frame_time < self.min_frame_interval: + return # 
Skip frame if too soon + + self.last_frame_time = current_time frame.side_data.input = self.video_preprocess(frame) frame.side_data.skipped = False # Different from LoadTensor, we don't skip frames here self.client.put_video_input(frame) @@ -78,8 +87,16 @@ def audio_preprocess(self, frame: av.AudioFrame) -> Union[torch.Tensor, np.ndarr # Works with ComfyUI Native def video_preprocess(self, frame: av.VideoFrame) -> Union[torch.Tensor, np.ndarray]: - frame_np = frame.to_ndarray(format="rgb24").astype(np.float32) / 255.0 - return torch.from_numpy(frame_np).unsqueeze(0) + # Convert directly to tensor, avoiding intermediate numpy array when possible + if hasattr(frame, 'to_tensor'): + tensor = frame.to_tensor() + else: + # If direct tensor conversion not available, use numpy + frame_np = frame.to_ndarray(format="rgb24") + tensor = torch.from_numpy(frame_np) + + # Normalize to [0,1] range and add batch dimension + return tensor.float().div(255.0).unsqueeze(0) ''' Converts HWC format (height, width, channels) to VideoFrame. 
def video_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.VideoFrame: @@ -103,31 +120,21 @@ def audio_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.Audio return av.AudioFrame.from_ndarray(np.repeat(output, 2).reshape(1, -1)) async def get_processed_video_frame(self): - """Get processed video frame from output queue and match it with input frame""" try: # Get the frame from the incoming queue first frame = await self.video_incoming_frames.get() - while frame.side_data.skipped: + # Skip frames if we're falling behind + while not self.video_incoming_frames.empty(): + # Get newer frame and mark old one as skipped + frame.side_data.skipped = True frame = await self.video_incoming_frames.get() + logger.info("Skipped older frame to catch up") # Get the processed frame from the output queue - logger.info("Getting video output") out_tensor = await self.client.get_video_output() - # If there are more frames in the output queue, drain them to get the most recent - # This helps with synchronization when processing is faster than display - while not tensor_cache.image_outputs.empty(): - try: - newer_tensor = await asyncio.wait_for(self.client.get_video_output(), 0.01) - out_tensor = newer_tensor # Use the most recent frame - logger.info("Using more recent frame from output queue") - except asyncio.TimeoutError: - break - - logger.info(f"Received output tensor with shape: {out_tensor.shape if hasattr(out_tensor, 'shape') else 'unknown'}") - - # Process the output tensor + # Process only the most recent frame processed_frame = self.video_postprocess(out_tensor) processed_frame.pts = frame.pts processed_frame.time_base = frame.time_base diff --git a/src/comfystream/client_api.py b/src/comfystream/client_api.py index 5e23e391..624de56e 100644 --- a/src/comfystream/client_api.py +++ b/src/comfystream/client_api.py @@ -255,11 +255,11 @@ async def _handle_text_message(self, message: str): self._prompt_id = data["data"]["prompt_id"] if "node" in 
data["data"]: node_id = data["data"]["node"] - # ogger.info(f"Executing node: {node_id}") + # logger.info(f"Executing node: {node_id}") elif message_type in ["execution_cached", "execution_error", "execution_complete", "execution_interrupted"]: # logger.info(f"{message_type} message received for prompt {self._prompt_id}") - #self.execution_started = False + # self.execution_started = False # Always signal completion for these terminal states # self.execution_complete_event.set() @@ -335,214 +335,90 @@ async def _handle_text_message(self, message: str): async def _handle_binary_message(self, binary_data): """Process binary messages from the WebSocket""" try: - # Log binary message information - # logger.info(f"Received binary message of size: {len(binary_data)} bytes") - - # Signal execution is complete, queue next frame - self.execution_complete_event.set() - - # Binary messages in ComfyUI start with a header - # First 8 bytes are used for header information + # Early return if message is too short if len(binary_data) <= 8: - logger.warning(f"Binary message too short: {len(binary_data)} bytes") + self.execution_complete_event.set() return - # Extract header data based on the actual format observed in logs - # Header bytes (hex): 0000000100000001 - this appears to be the format in use + # Extract header data only when needed event_type = int.from_bytes(binary_data[:4], byteorder='little') format_type = int.from_bytes(binary_data[4:8], byteorder='little') data = binary_data[8:] - # Log header details - logger.info(f"Binary message header: event_type={event_type}, format_type={format_type}, data_size={len(data)} bytes") - #logger.info(f"Header bytes (hex): {binary_data[:8].hex()}") - - # Check if this is an image (JPEG starts with 0xFF, 0xD8, PNG starts with 0x89, 0x50) - is_jpeg = data[:2] == b'\xff\xd8' - is_png = data[:4] == b'\x89\x50\x4e\x47' + # Quick check for image format + is_image = data[:2] in [b'\xff\xd8', b'\x89\x50'] + if not is_image: + 
self.execution_complete_event.set() + return - if is_jpeg or is_png: - image_format = "JPEG" if is_jpeg else "PNG" - logger.info(f"Detected {image_format} image based on magic bytes") - - # Create a NEW binary message with the expected header format for the JavaScript client - # The JavaScript expects: [0:4]=1 (PREVIEW_IMAGE), [4:8]=1 (JPEG format) or [4:8]=2 (PNG format) - # This matches exactly what the JS code is looking for: - # const event = dataView.getUint32(0); // event type (1 = PREVIEW_IMAGE) - # const format = dataView.getUint32(4); // format (1 = JPEG, 2 = PNG) - js_event_type = (1).to_bytes(4, byteorder='little') # PREVIEW_IMAGE = 1 - js_format_type = (1 if is_jpeg else 2).to_bytes(4, byteorder='little') - transformed_data = js_event_type + js_format_type + data + # Process image data directly + try: + img = Image.open(BytesIO(data)) + if img.mode != "RGB": + img = img.convert("RGB") + + with torch.no_grad(): + tensor = torch.from_numpy(np.array(img)).float().permute(2, 0, 1).unsqueeze(0) / 255.0 - # Forward to WebSocket client if connected - # if self.ws: - # await self.ws.send(transformed_data) - # logger.info(f"Sent transformed {image_format} image data to WebSocket with correct JS header format") - #else: - # logger.error("WebSocket not connected, cannot forward image to JS client") + # Add to output queue without waiting + tensor_cache.image_outputs.put_nowait(tensor) + self.execution_complete_event.set() - # Process the image for our pipeline - try: - # Decode the image - img = Image.open(BytesIO(data)) - logger.info(f"Successfully decoded image: size={img.size}, mode={img.mode}, format={img.format}") - - # Convert to RGB if not already - if img.mode != "RGB": - img = img.convert("RGB") - logger.info(f"Converted image to RGB mode") - - # Save image to temp folder as a file - # TESTING - ''' - import os - import tempfile - temp_folder = os.path.join(tempfile.gettempdir(), "comfyui_images") - os.makedirs(temp_folder, exist_ok=True) - img_path = 
os.path.join(temp_folder, f"comfyui_image_{time.time()}.png") - img.save(img_path) - logger.info(f"Saved image to {img_path}") - ''' - - # Convert to tensor (normalize to [0,1] range for consistency) - img_np = np.array(img).astype(np.float32) / 255.0 - tensor = torch.from_numpy(img_np) - - # CRITICAL: Ensure dimensions are correctly understood - # The tensor should be in HWC format initially from PIL/numpy - logger.info(f"Initial tensor shape from image: {tensor.shape}") - - # Convert from HWC to BCHW format for consistency with model expectations - if len(tensor.shape) == 3 and tensor.shape[2] == 3: # HWC format (H,W,3) - tensor = tensor.permute(2, 0, 1).unsqueeze(0) # -> BCHW (1,3,H,W) - logger.info(f"Converted to BCHW tensor with shape: {tensor.shape}") - - # Check for NaN or Inf values - if torch.isnan(tensor).any() or torch.isinf(tensor).any(): - logger.warning("Tensor contains NaN or Inf values! Replacing with zeros") - tensor = torch.nan_to_num(tensor, nan=0.0, posinf=1.0, neginf=0.0) - - # Log detailed tensor info for debugging - logger.info(f"Final tensor with shape: {tensor.shape}, " - f"min={tensor.min().item()}, max={tensor.max().item()}, " - f"mean={tensor.mean().item()}") - - # Add to output queue without waiting - tensor_cache.image_outputs.put_nowait(tensor) - logger.info(f"Added tensor to output queue, queue size: {tensor_cache.image_outputs.qsize()}") - return + except Exception as img_error: + logger.error(f"Error processing image: {img_error}") + self.execution_complete_event.set() - except Exception as img_error: - logger.error(f"Error processing image: {img_error}", exc_info=True) - - # If we get here, we couldn't process the image - logger.warning("Failed to process image, creating default tensor") - default_tensor = torch.zeros(1, 3, 512, 512) - tensor_cache.image_outputs.put_nowait(default_tensor) - self.execution_complete_event.set() - except Exception as e: - logger.error(f"Error handling binary message: {e}", exc_info=True) - # Set 
execution complete event to avoid hanging + logger.error(f"Error handling binary message: {e}") self.execution_complete_event.set() async def _execute_prompt(self, prompt_index: int): - """Execute a prompt via the ComfyUI API""" try: # Get the prompt to execute prompt = self.current_prompts[prompt_index] - # Ensure all seed values are randomized for every execution - # This forces ComfyUI to not use cached results - for node_id, node in prompt.items(): - if isinstance(node, dict) and "inputs" in node: - if "seed" in node["inputs"]: - # Generate a truly random seed each time - random_seed = random.randint(0, 18446744073709551615) - node["inputs"]["seed"] = random_seed - logger.info(f"Randomized seed to {random_seed} for node {node_id}") - - # Also randomize noise_seed if present - if "noise_seed" in node["inputs"]: - noise_seed = random.randint(0, 18446744073709551615) - node["inputs"]["noise_seed"] = noise_seed - logger.info(f"Randomized noise_seed to {noise_seed} for node {node_id}") - - # Add a timestamp parameter to each node to prevent caching - # This is a "hidden" trick to force ComfyUI to consider each execution unique - timestamp = int(time.time() * 1000) # millisecond timestamp - for node_id, node in prompt.items(): - if isinstance(node, dict) and "inputs" in node: - # Add a timestamp parameter to ETN_LoadImageBase64 nodes - if node.get("class_type") in ["ETN_LoadImageBase64", "LoadImageBase64"]: - # Add a unique cache-busting parameter - node["inputs"]["_timestamp"] = timestamp - logger.info(f"Added timestamp {timestamp} to node {node_id}") - # Check if we have a frame waiting to be processed if not tensor_cache.image_inputs.empty(): logger.info("Found tensor in input queue, preparing for API") - # Get the frame from the cache - make sure to get the most recent frame + # Get the most recent frame only + frame_or_tensor = None while not tensor_cache.image_inputs.empty(): frame_or_tensor = tensor_cache.image_inputs.get_nowait() - # Find ETN_LoadImageBase64 
nodes + # Find ETN_LoadImageBase64 nodes first load_image_nodes = [] for node_id, node in prompt.items(): - if isinstance(node, dict) and node.get("class_type") == "ETN_LoadImageBase64": + if isinstance(node, dict) and node.get("class_type") in ["ETN_LoadImageBase64", "LoadImageBase64"]: load_image_nodes.append(node_id) if not load_image_nodes: - # Also check for regular LoadImageBase64 nodes as fallback - for node_id, node in prompt.items(): - if isinstance(node, dict) and node.get("class_type") == "LoadImageBase64": - load_image_nodes.append(node_id) - - if not load_image_nodes: - logger.warning("No ETN_LoadImageBase64 or LoadImageBase64 nodes found in the prompt") - self.execution_complete_event.set() # Signal completion + logger.warning("No LoadImageBase64 nodes found in the prompt") + self.execution_complete_event.set() return - else: - # Convert the frame/tensor to base64 and include directly in the prompt - try: - # Get the actual tensor data - handle different input types - tensor = None - - # Check if it's a PyAV VideoFrame with preprocessed tensor in side_data - if hasattr(frame_or_tensor, 'side_data') and hasattr(frame_or_tensor.side_data, 'input'): - tensor = frame_or_tensor.side_data.input - logger.info(f"Using preprocessed tensor from frame.side_data.input with shape {tensor.shape}") - # Check if it's a PyTorch tensor - elif isinstance(frame_or_tensor, torch.Tensor): - tensor = frame_or_tensor - logger.info(f"Using tensor directly with shape {tensor.shape}") - # Check if it's a numpy array - elif isinstance(frame_or_tensor, np.ndarray): - tensor = torch.from_numpy(frame_or_tensor).float() - logger.info(f"Converted numpy array to tensor with shape {tensor.shape}") - else: - # If it's a PyAV frame without preprocessed data, convert it - try: - if hasattr(frame_or_tensor, 'to_ndarray'): - frame_np = frame_or_tensor.to_ndarray(format="rgb24").astype(np.float32) / 255.0 - tensor = torch.from_numpy(frame_np).unsqueeze(0) - logger.info(f"Converted PyAV frame 
to tensor with shape {tensor.shape}") - else: - logger.error(f"Unsupported frame type: {type(frame_or_tensor)}") - self.execution_complete_event.set() - return - except Exception as e: - logger.error(f"Error converting frame to tensor: {e}") - self.execution_complete_event.set() - return - - if tensor is None: - logger.error("Failed to get valid tensor data from input") - self.execution_complete_event.set() - return - - # Now process the tensor (which should be a proper PyTorch tensor) - # Ensure it's a tensor on CPU and detached + + # Process the tensor ONLY if we have nodes to send it to + try: + # Get the actual tensor data - handle different input types + tensor = None + + # Handle different input types efficiently + if hasattr(frame_or_tensor, 'side_data') and hasattr(frame_or_tensor.side_data, 'input'): + tensor = frame_or_tensor.side_data.input + elif isinstance(frame_or_tensor, torch.Tensor): + tensor = frame_or_tensor + elif isinstance(frame_or_tensor, np.ndarray): + tensor = torch.from_numpy(frame_or_tensor).float() + elif hasattr(frame_or_tensor, 'to_ndarray'): + frame_np = frame_or_tensor.to_ndarray(format="rgb24").astype(np.float32) / 255.0 + tensor = torch.from_numpy(frame_np).unsqueeze(0) + + if tensor is None: + logger.error("Failed to get valid tensor data from input") + self.execution_complete_event.set() + return + + # Process tensor format only once + with torch.no_grad(): tensor = tensor.detach().cpu().float() # Handle different formats @@ -553,65 +429,52 @@ async def _execute_prompt(self, prompt_index: int): if len(tensor.shape) == 3 and tensor.shape[2] == 3: # HWC format tensor = tensor.permute(2, 0, 1) # Convert to CHW - # Convert to PIL image for saving + # Convert to PIL image for base64 ONLY ONCE tensor_np = (tensor.permute(1, 2, 0) * 255).clamp(0, 255).numpy().astype(np.uint8) img = Image.fromarray(tensor_np) - # Save as PNG to BytesIO and convert to base64 string + # Convert to base64 ONCE for all nodes buffer = BytesIO() 
img.save(buffer, format="PNG") buffer.seek(0) - - # Encode as base64 - for ETN_LoadImageBase64, we need the raw base64 string img_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8') - logger.info(f"Created base64 string of length: {len(img_base64)}") - - # Update all ETN_LoadImageBase64 nodes with the base64 data - for node_id in load_image_nodes: - prompt[node_id]["inputs"]["image"] = img_base64 - # Add a small random suffix to image data to prevent caching - rand_suffix = str(random.randint(1, 1000000)) - prompt[node_id]["inputs"]["_cache_buster"] = rand_suffix - logger.info(f"Updated node {node_id} with base64 string and cache buster {rand_suffix}") - except Exception as e: - logger.error(f"Error converting tensor to base64: {e}") - # Signal execution complete in case of error - self.execution_complete_event.set() - return + # Update all nodes with the SAME base64 string + timestamp = int(time.time() * 1000) + for node_id in load_image_nodes: + prompt[node_id]["inputs"]["image"] = img_base64 + prompt[node_id]["inputs"]["_timestamp"] = timestamp + # Use timestamp as cache buster instead of random number + prompt[node_id]["inputs"]["_cache_buster"] = str(timestamp) + + except Exception as e: + logger.error(f"Error converting tensor to base64: {e}") + self.execution_complete_event.set() + return + + # Execute the prompt via API + async with aiohttp.ClientSession() as session: + api_url = f"{self.api_base_url}/prompt" + payload = { + "prompt": prompt, + "client_id": self.client_id + } + + async with session.post(api_url, json=payload) as response: + if response.status == 200: + result = await response.json() + self._prompt_id = result.get("prompt_id") + self.execution_started = True + else: + error_text = await response.text() + logger.error(f"Error queueing prompt: {response.status} - {error_text}") + self.execution_complete_event.set() else: logger.info("No tensor in input queue, skipping prompt execution") - self.execution_complete_event.set() # Signal 
completion - return - - # Execute the prompt via API - async with aiohttp.ClientSession() as session: - api_url = f"{self.api_base_url}/prompt" - payload = { - "prompt": prompt, - "client_id": self.client_id - } - - # Send the request - logger.info(f"Sending prompt to {api_url}") - async with session.post(api_url, json=payload) as response: - if response.status == 200: - result = await response.json() - self._prompt_id = result.get("prompt_id") - logger.info(f"Prompt queued with ID: {self._prompt_id}") - self.execution_started = True - else: - error_text = await response.text() - logger.error(f"Error queueing prompt: {response.status} - {error_text}") - # Signal execution complete in case of error - self.execution_complete_event.set() + self.execution_complete_event.set() - except aiohttp.ClientError as e: - logger.error(f"Client error queueing prompt: {e}") - self.execution_complete_event.set() except Exception as e: logger.error(f"Error executing prompt: {e}") - # Signal execution complete in case of error self.execution_complete_event.set() async def _send_tensor_via_websocket(self, tensor): diff --git a/src/comfystream/utils_api.py b/src/comfystream/utils_api.py index 8b93ec7f..5a932ef1 100644 --- a/src/comfystream/utils_api.py +++ b/src/comfystream/utils_api.py @@ -15,7 +15,6 @@ def create_load_tensor_node(): "_meta": {"title": "Load Tensor (API)"}, } - def create_save_tensor_node(inputs: Dict[Any, Any]): """Create a SaveTensorAPI node with proper input formatting""" # Make sure images input is properly formatted [node_id, output_index] @@ -51,79 +50,29 @@ def convert_prompt(prompt): random_seed = random.randint(0, 18446744073709551615) node["inputs"]["seed"] = random_seed print(f"Set random seed {random_seed} for node {key}") + + + # Replace LoadImage with LoadImageBase64 + for key, node in prompt.items(): + if node.get("class_type") == "LoadImage": + node["class_type"] = "LoadImageBase64" + + # Replace SaveImage/PreviewImage with SendImageWebsocket + for 
key, node in prompt.items(): + if node.get("class_type") in ["SaveImage", "PreviewImage"]: + node["class_type"] = "SendImageWebsocket" + # Ensure format is set + if "format" not in node["inputs"]: + node["inputs"]["format"] = "PNG" # Set default format return prompt ''' -def convert_prompt(prompt): +def convert_prompt(prompt: PromptDictInput) -> Prompt: + # Validate the schema + Prompt.validate(prompt) - # Check if this is a ComfyUI web UI format prompt with 'nodes' and 'links' - if isinstance(prompt, dict) and 'nodes' in prompt and 'links' in prompt: - # Convert the web UI prompt format to the API format - api_prompt = {} - - # Process each node - for node in prompt['nodes']: - node_id = str(node['id']) - - # Create a node entry in the API format - api_prompt[node_id] = { - 'class_type': node.get('type', 'Unknown'), - 'inputs': {}, - '_meta': { - 'title': node.get('type', 'Unknown') - } - } - - # Process inputs - if 'inputs' in node: - for input_data in node['inputs']: - input_name = input_data.get('name') - link_id = input_data.get('link') - - if input_name and link_id is not None: - # Find the source of this link - for link in prompt['links']: - if link[0] == link_id: # link ID matches - # Get source node and output slot - source_node_id = str(link[1]) - source_slot = link[3] - - # Add to inputs - api_prompt[node_id]['inputs'][input_name] = [ - source_node_id, - source_slot - ] - break - # If no link found, set to empty value - if input_name not in api_prompt[node_id]['inputs']: - api_prompt[node_id]['inputs'][input_name] = None - - # Process widget values - if 'widgets_values' in node: - for i, widget_value in enumerate(node.get('widgets_values', [])): - # Try to determine widget name from properties or use index - widget_name = f"widget_{i}" - # Add to inputs - api_prompt[node_id]['inputs'][widget_name] = widget_value - - # Use this converted prompt instead - prompt = api_prompt - - # Now process as normal API format prompt prompt = copy.deepcopy(prompt) - - # 
Set random seeds for any seed nodes - for key, node in prompt.items(): - if not isinstance(node, dict) or "inputs" not in node: - continue - - # Check if this node has a seed input directly - if "seed" in node.get("inputs", {}): - # Generate a random seed (same range as JavaScript's Math.random() * 18446744073709552000) - random_seed = random.randint(0, 18446744073709551615) - node["inputs"]["seed"] = random_seed - print(f"Set random seed {random_seed} for node {key}") num_primary_inputs = 0 num_inputs = 0 @@ -134,112 +83,54 @@ def convert_prompt(prompt): "LoadImage": [], "PreviewImage": [], "SaveImage": [], - "LoadTensor": [], - "SaveTensor": [], - "LoadImageBase64": [], - "LoadTensorAPI": [], - "SaveTensorAPI": [], } for key, node in prompt.items(): - if not isinstance(node, dict): - continue + class_type = node.get("class_type") + + # Collect keys for nodes that might need to be replaced + if class_type in keys: + keys[class_type].append(key) - class_type = node.get("class_type", "") - - # Track primary input and output nodes - if class_type in ["PrimaryInput", "PrimaryInputImage"]: + # Count inputs and outputs + if class_type == "PrimaryInputLoadImage": num_primary_inputs += 1 - keys["PrimaryInputLoadImage"].append(key) - elif class_type in ["LoadImage", "LoadTensor", "LoadAudioTensor", "LoadImageBase64", "LoadTensorAPI"]: + elif class_type in ["LoadImage", "LoadTensor", "LoadAudioTensor"]: num_inputs += 1 - if class_type == "LoadImage": - keys["LoadImage"].append(key) - elif class_type == "LoadTensor": - keys["LoadTensor"].append(key) - elif class_type == "LoadImageBase64": - keys["LoadImageBase64"].append(key) - elif class_type == "LoadTensorAPI": - keys["LoadTensorAPI"].append(key) - elif class_type in ["PreviewImage", "SaveImage", "SaveTensor", "SaveAudioTensor", "SendImageWebSocket", "SaveTensorAPI"]: + elif class_type in ["PreviewImage", "SaveImage", "SaveTensor", "SaveAudioTensor"]: num_outputs += 1 - if class_type == "PreviewImage": - 
keys["PreviewImage"].append(key) - elif class_type == "SaveImage": - keys["SaveImage"].append(key) - elif class_type == "SaveTensor": - keys["SaveTensor"].append(key) - elif class_type == "SaveTensorAPI": - keys["SaveTensorAPI"].append(key) - - print(f"Found {num_primary_inputs} primary inputs, {num_inputs} inputs, {num_outputs} outputs") - - # Set up connection for video feeds by replacing LoadImage with LoadImageBase64 - if num_inputs == 0 and num_primary_inputs == 0: - # Add a LoadTensorAPI node - new_key = "999990" - prompt[new_key] = create_load_tensor_node() - keys["LoadTensorAPI"].append(new_key) - print("Added LoadTensorAPI node for tensor input") - elif len(keys["LoadTensor"]) > 0 and len(keys["LoadTensorAPI"]) == 0: - # Replace LoadTensor with LoadTensorAPI if found - for key in keys["LoadTensor"]: - prompt[key] = create_load_tensor_node() - keys["LoadTensorAPI"].append(key) - print("Replaced LoadTensor with LoadTensorAPI") - - # Set up connection for output if needed + + # Only handle single primary input + if num_primary_inputs > 1: + raise Exception("too many primary inputs in prompt") + + # If there are no primary inputs, only handle single input + if num_primary_inputs == 0 and num_inputs > 1: + raise Exception("too many inputs in prompt") + + # Only handle single output for now + if num_outputs > 1: + raise Exception("too many outputs in prompt") + + if num_primary_inputs + num_inputs == 0: + raise Exception("missing input") + if num_outputs == 0: - # Find nodes that produce images - image_output_nodes = [] - for key, node in prompt.items(): - if isinstance(node, dict): - class_type = node.get("class_type", "") - # Look for nodes that typically output images - if any(output_type in class_type.lower() for output_type in ["vae", "decode", "img", "image", "upscale", "sample"]): - image_output_nodes.append(key) - # Also check if the node's RETURN_TYPES includes IMAGE - elif "outputs" in node and isinstance(node["outputs"], dict): - for output_name, 
output_type in node["outputs"].items(): - if "IMAGE" in output_type: - image_output_nodes.append(key) - break - - # If we found image output nodes, connect SaveTensorAPI to them - if image_output_nodes: - for i, node_key in enumerate(image_output_nodes): - new_key = f"999991_{i}" - prompt[new_key] = create_save_tensor_node({"images": [node_key, 0]}) - print(f"Added SaveTensorAPI node connected to {node_key}") - else: - # Try to find the last node in the chain as fallback - last_node = None - max_id = -1 - for key, node in prompt.items(): - if isinstance(node, dict): - try: - node_id = int(key) - if node_id > max_id: - max_id = node_id - last_node = key - except ValueError: - pass - - if last_node: - # Add a SaveTensorAPI node - new_key = "999991" - prompt[new_key] = create_save_tensor_node({"images": [last_node, 0]}) - print(f"Added SaveTensorAPI node connected to {last_node}") - else: - print("Warning: Could not find a suitable node to connect SaveTensorAPI to") - - # Make sure all SaveTensorAPI nodes have proper configuration - for key, node in prompt.items(): - if isinstance(node, dict) and node.get("class_type") == "SaveTensorAPI": - # Ensure format is set to PNG for optimal compatibility - if "inputs" in node: - node["inputs"]["format"] = "png" - - # Return the modified prompt + raise Exception("missing output") + + # Replace nodes + for key in keys["PrimaryInputLoadImage"]: + prompt[key] = create_load_tensor_node() + + if num_primary_inputs == 0 and len(keys["LoadImage"]) == 1: + prompt[keys["LoadImage"][0]] = create_load_tensor_node() + + for key in keys["PreviewImage"] + keys["SaveImage"]: + node = prompt[key] + prompt[key] = create_save_tensor_node(node["inputs"]) + + # Validate the processed prompt input + prompt = Prompt.validate(prompt) + return prompt ''' \ No newline at end of file From d66dac4dbbdc94d7ff7f394ee9b6b0aae78aeca0 Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Tue, 25 Mar 2025 15:48:50 -0400 Subject: [PATCH 04/42] Preliminary work on 
multi-Comfy server inference. --- server/app_api.py | 14 ++- server/pipeline_api.py | 246 ++++++++++++++++++++++++++++++++--------- 2 files changed, 207 insertions(+), 53 deletions(-) diff --git a/server/app_api.py b/server/app_api.py index 6158a540..94fc36af 100644 --- a/server/app_api.py +++ b/server/app_api.py @@ -26,7 +26,6 @@ from twilio.rest import Client from utils import patch_loop_datagram, add_prefix_to_app_routes, FPSMeter from metrics import MetricsManager, StreamStatsManager -import time logger = logging.getLogger(__name__) logging.getLogger("aiortc.rtcrtpsender").setLevel(logging.WARNING) @@ -345,7 +344,11 @@ async def on_startup(app: web.Application): patch_loop_datagram(app["media_ports"]) app["pipeline"] = Pipeline( - cwd=app["workspace"], disable_cuda_malloc=True, gpu_only=True, preview_method='none' + config_path=app["config_file"], + cwd=app["workspace"], + disable_cuda_malloc=True, + gpu_only=True, + preview_method='none' ) app["pcs"] = set() app["video_tracks"] = {} @@ -374,6 +377,12 @@ async def on_shutdown(app: web.Application): choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], help="Set the logging level", ) + parser.add_argument( + "--config-file", + type=str, + default=None, + help="Path to TOML configuration file for Comfy servers" + ) parser.add_argument( "--monitor", default=False, @@ -397,6 +406,7 @@ async def on_shutdown(app: web.Application): app = web.Application() app["media_ports"] = args.media_ports.split(",") if args.media_ports else None app["workspace"] = args.workspace + app["config_file"] = args.config_file app.on_startup.append(on_startup) app.on_shutdown.append(on_shutdown) diff --git a/server/pipeline_api.py b/server/pipeline_api.py index a6eb1205..6849c265 100644 --- a/server/pipeline_api.py +++ b/server/pipeline_api.py @@ -4,88 +4,214 @@ import asyncio import logging import time +import random +from collections import deque -from typing import Any, Dict, Union, List +from typing import Any, Dict, Union, List, 
Optional, Deque from comfystream.client_api import ComfyStreamClient +from config import ComfyConfig WARMUP_RUNS = 5 logger = logging.getLogger(__name__) -class Pipeline: - def __init__(self, **kwargs): - self.client = ComfyStreamClient(**kwargs) +class MultiServerPipeline: + def __init__(self, config_path: Optional[str] = None, **kwargs): + # Load server configurations + self.config = ComfyConfig(config_path) + self.servers = self.config.get_servers() + + # Create client for each server + self.clients = [] + for server_config in self.servers: + client_kwargs = kwargs.copy() + client_kwargs.update(server_config) + self.clients.append(ComfyStreamClient(**client_kwargs)) + + logger.info(f"Initialized {len(self.clients)} ComfyUI clients") + self.video_incoming_frames = asyncio.Queue() self.audio_incoming_frames = asyncio.Queue() - + + # Queue for processed frames from all clients + self.processed_video_frames = asyncio.Queue() + + # Track which client gets each frame (round-robin) + self.current_client_index = 0 + self.client_frame_mapping = {} # Maps frame_id -> client_index + + # Buffer to store frames in order of original pts + self.frame_output_buffer: Deque = deque() + + # Audio processing self.processed_audio_buffer = np.array([], dtype=np.int16) self.last_frame_time = 0 - # TODO: Not sure if this is needed - should match to UI selected FPS + # Frame rate limiting self.min_frame_interval = 1/30 # Limit to 30 FPS + + # Create background task for collecting processed frames + self.running = True + self.collector_task = asyncio.create_task(self._collect_processed_frames()) + + async def _collect_processed_frames(self): + """Background task to collect processed frames from all clients""" + try: + while self.running: + for i, client in enumerate(self.clients): + try: + # Non-blocking check if client has output ready + if hasattr(client, '_prompt_id') and client._prompt_id is not None: + # Get frame without waiting + try: + # Use wait_for with small timeout to avoid 
blocking + out_tensor = await asyncio.wait_for( + client.get_video_output(), + timeout=0.01 + ) + + # Find which original frame this corresponds to + # (using a simple approach here - could be improved) + # In real implementation, need to track which frames went to which client + frame_ids = [frame_id for frame_id, client_idx in + self.client_frame_mapping.items() if client_idx == i] + + if frame_ids: + # Use the oldest frame ID for this client + frame_id = min(frame_ids) + # Store the processed tensor along with original frame ID for ordering + await self.processed_video_frames.put((frame_id, out_tensor)) + # Remove the mapping + self.client_frame_mapping.pop(frame_id, None) + logger.info(f"Collected processed frame from client {i}, frame_id: {frame_id}") + except asyncio.TimeoutError: + # No frame ready yet, continue + pass + except Exception as e: + logger.error(f"Error collecting frame from client {i}: {e}") + + # Small sleep to avoid CPU spinning + await asyncio.sleep(0.01) + except asyncio.CancelledError: + logger.info("Frame collector task cancelled") + except Exception as e: + logger.error(f"Unexpected error in frame collector: {e}") async def warm_video(self): - """Warm up the video pipeline with dummy frames""" + """Warm up the video pipeline with dummy frames for each client""" logger.info("Warming up video pipeline...") - # Create a properly formatted dummy frame (random color pattern) - # Using standard tensor shape: BCHW [1, 3, 512, 512] + # Create a properly formatted dummy frame tensor = torch.rand(1, 3, 512, 512) # Random values in [0,1] - - # Create a dummy frame and attach the tensor as side_data dummy_frame = av.VideoFrame(width=512, height=512, format="rgb24") dummy_frame.side_data.input = tensor - # Process a few frames for warmup - for i in range(WARMUP_RUNS): - logger.info(f"Video warmup iteration {i+1}/{WARMUP_RUNS}") - self.client.put_video_input(dummy_frame) - await self.client.get_video_output() - + # Warm up each client + warmup_tasks = 
[] + for i, client in enumerate(self.clients): + warmup_tasks.append(self._warm_client_video(client, i, dummy_frame)) + + # Wait for all warmup tasks to complete + await asyncio.gather(*warmup_tasks) logger.info("Video pipeline warmup complete") + + async def _warm_client_video(self, client, client_index, dummy_frame): + """Warm up a single client""" + logger.info(f"Warming up client {client_index}") + for i in range(WARMUP_RUNS): + logger.info(f"Client {client_index} warmup iteration {i+1}/{WARMUP_RUNS}") + client.put_video_input(dummy_frame) + try: + await asyncio.wait_for(client.get_video_output(), timeout=5.0) + except asyncio.TimeoutError: + logger.warning(f"Timeout waiting for warmup frame from client {client_index}") + except Exception as e: + logger.error(f"Error warming client {client_index}: {e}") async def warm_audio(self): + # For now, only use the first client for audio + if not self.clients: + logger.warning("No clients available for audio warmup") + return + dummy_frame = av.AudioFrame() - dummy_frame.side_data.input = np.random.randint(-32768, 32767, int(48000 * 0.5), dtype=np.int16) # TODO: adds a lot of delay if it doesn't match the buffer size, is warmup needed? 
+ dummy_frame.side_data.input = np.random.randint(-32768, 32767, int(48000 * 0.5), dtype=np.int16) dummy_frame.sample_rate = 48000 for _ in range(WARMUP_RUNS): - self.client.put_audio_input(dummy_frame) - await self.client.get_audio_output() + self.clients[0].put_audio_input(dummy_frame) + await self.clients[0].get_audio_output() async def set_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]): - if isinstance(prompts, list): - await self.client.set_prompts(prompts) - else: - await self.client.set_prompts([prompts]) + """Set the same prompts for all clients""" + if isinstance(prompts, dict): + prompts = [prompts] + + # Set prompts for each client + tasks = [] + for client in self.clients: + tasks.append(client.set_prompts(prompts)) + + await asyncio.gather(*tasks) + logger.info(f"Set prompts for {len(self.clients)} clients") async def update_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]): - if isinstance(prompts, list): - await self.client.update_prompts(prompts) - else: - await self.client.update_prompts([prompts]) + """Update prompts for all clients""" + if isinstance(prompts, dict): + prompts = [prompts] + + # Update prompts for each client + tasks = [] + for client in self.clients: + tasks.append(client.update_prompts(prompts)) + + await asyncio.gather(*tasks) + logger.info(f"Updated prompts for {len(self.clients)} clients") async def put_video_frame(self, frame: av.VideoFrame): + """Distribute video frames among clients using round-robin""" current_time = time.time() if current_time - self.last_frame_time < self.min_frame_interval: return # Skip frame if too soon self.last_frame_time = current_time + + # Generate a unique frame ID + frame_id = int(time.time() * 1000000) # Microseconds as ID + frame.side_data.frame_id = frame_id + + # Preprocess the frame frame.side_data.input = self.video_preprocess(frame) - frame.side_data.skipped = False # Different from LoadTensor, we don't skip frames here - 
self.client.put_video_input(frame) + frame.side_data.skipped = False + + # Select the next client in round-robin fashion + client_index = self.current_client_index + self.current_client_index = (self.current_client_index + 1) % len(self.clients) + + # Store mapping of which client is processing this frame + self.client_frame_mapping[frame_id] = client_index + + # Send frame to the selected client + self.clients[client_index].put_video_input(frame) + + # Also add to the incoming queue for reference await self.video_incoming_frames.put(frame) + + logger.info(f"Sent frame {frame_id} to client {client_index}") async def put_audio_frame(self, frame: av.AudioFrame): + # For now, only use the first client for audio + if not self.clients: + return + frame.side_data.input = self.audio_preprocess(frame) frame.side_data.skipped = False - self.client.put_audio_input(frame) + self.clients[0].put_audio_input(frame) await self.audio_incoming_frames.put(frame) def audio_preprocess(self, frame: av.AudioFrame) -> Union[torch.Tensor, np.ndarray]: return frame.to_ndarray().ravel().reshape(-1, 2).mean(axis=1).astype(np.int16) - # Works with ComfyUI Native def video_preprocess(self, frame: av.VideoFrame) -> Union[torch.Tensor, np.ndarray]: # Convert directly to tensor, avoiding intermediate numpy array when possible if hasattr(frame, 'to_tensor'): @@ -97,15 +223,7 @@ def video_preprocess(self, frame: av.VideoFrame) -> Union[torch.Tensor, np.ndarr # Normalize to [0,1] range and add batch dimension return tensor.float().div(255.0).unsqueeze(0) - - ''' Converts HWC format (height, width, channels) to VideoFrame. - def video_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.VideoFrame: - return av.VideoFrame.from_ndarray( - (output * 255.0).clamp(0, 255).to(dtype=torch.uint8).squeeze(0).cpu().numpy() - ) - ''' - '''Converts BCHW tensor [1,3,H,W] to VideoFrame. 
Assumes values are in range [0,1].''' def video_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.VideoFrame: return av.VideoFrame.from_ndarray( (output.squeeze(0).permute(1, 2, 0) * 255.0) @@ -121,7 +239,7 @@ def audio_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.Audio async def get_processed_video_frame(self): try: - # Get the frame from the incoming queue first + # Get the frame from the incoming queue first to maintain timing frame = await self.video_incoming_frames.get() # Skip frames if we're falling behind @@ -130,11 +248,11 @@ async def get_processed_video_frame(self): frame.side_data.skipped = True frame = await self.video_incoming_frames.get() logger.info("Skipped older frame to catch up") - - # Get the processed frame from the output queue - out_tensor = await self.client.get_video_output() - # Process only the most recent frame + # Get the processed frame from our output queue + frame_id, out_tensor = await self.processed_video_frames.get() + + # Process the frame processed_frame = self.video_postprocess(out_tensor) processed_frame.pts = frame.pts processed_frame.time_base = frame.time_base @@ -148,10 +266,14 @@ async def get_processed_video_frame(self): return black_frame async def get_processed_audio_frame(self): - # TODO: make it generic to support purely generative audio cases and also add frame skipping + # Only use the first client for audio + if not self.clients: + logger.warning("No clients available for audio processing") + return av.AudioFrame(format='s16', layout='mono', samples=1024) + frame = await self.audio_incoming_frames.get() if frame.samples > len(self.processed_audio_buffer): - out_tensor = await self.client.get_audio_output() + out_tensor = await self.clients[0].get_audio_output() self.processed_audio_buffer = np.concatenate([self.processed_audio_buffer, out_tensor]) out_data = self.processed_audio_buffer[:frame.samples] self.processed_audio_buffer = self.processed_audio_buffer[frame.samples:] @@ 
-164,9 +286,31 @@ async def get_processed_audio_frame(self): return processed_frame async def get_nodes_info(self) -> Dict[str, Any]: - """Get information about all nodes in the current prompt including metadata.""" - nodes_info = await self.client.get_available_nodes() - return nodes_info + """Get information about nodes from the first client""" + if not self.clients: + return {} + return await self.clients[0].get_available_nodes() async def cleanup(self): - await self.client.cleanup() \ No newline at end of file + """Clean up all clients and background tasks""" + self.running = False + + # Cancel collector task + if hasattr(self, 'collector_task') and not self.collector_task.done(): + self.collector_task.cancel() + try: + await self.collector_task + except asyncio.CancelledError: + pass + + # Clean up all clients + cleanup_tasks = [] + for client in self.clients: + cleanup_tasks.append(client.cleanup()) + + await asyncio.gather(*cleanup_tasks) + logger.info("All clients cleaned up") + + +# For backwards compatibility, maintain the original Pipeline name +Pipeline = MultiServerPipeline \ No newline at end of file From d3b0b3c24704d7e146b48de10ca1145cae204171 Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Tue, 25 Mar 2025 17:19:27 -0400 Subject: [PATCH 05/42] Work on frame timing and management, added max_frame_wait argument, added config for server management. 
--- configs/comfy.toml | 13 ++++ requirements.txt | 1 + server/app_api.py | 10 ++++ server/config.py | 45 ++++++++++++++ server/pipeline_api.py | 110 +++++++++++++++++++++++++++++----- src/comfystream/client_api.py | 23 ++++--- 6 files changed, 174 insertions(+), 28 deletions(-) create mode 100644 configs/comfy.toml create mode 100644 server/config.py diff --git a/configs/comfy.toml b/configs/comfy.toml new file mode 100644 index 00000000..5d278541 --- /dev/null +++ b/configs/comfy.toml @@ -0,0 +1,13 @@ +# Configuration for multiple ComfyUI servers + +[[servers]] +host = "127.0.0.1" +port = 8188 +client_id = "client1" + +# Adding more servers: + +# [[servers]] +# host = "127.0.0.1" +# port = 8189 +# client_id = "client2" diff --git a/requirements.txt b/requirements.txt index 4a7e68ad..56fe8b22 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,5 +3,6 @@ comfyui @ git+https://github.com/hiddenswitch/ComfyUI.git@ce3583ad42c024b8f060d0 aiortc aiohttp toml +tomli twilio prometheus_client diff --git a/server/app_api.py b/server/app_api.py index 94fc36af..b7b56e56 100644 --- a/server/app_api.py +++ b/server/app_api.py @@ -345,6 +345,7 @@ async def on_startup(app: web.Application): app["pipeline"] = Pipeline( config_path=app["config_file"], + max_frame_wait_ms=app["max_frame_wait"], cwd=app["workspace"], disable_cuda_malloc=True, gpu_only=True, @@ -353,6 +354,8 @@ async def on_startup(app: web.Application): app["pcs"] = set() app["video_tracks"] = {} + app["max_frame_wait"] = args.max_frame_wait + async def on_shutdown(app: web.Application): pcs = app["pcs"] @@ -395,6 +398,12 @@ async def on_shutdown(app: web.Application): action="store_true", help="Include stream ID as a label in Prometheus metrics.", ) + parser.add_argument( + "--max-frame-wait", + type=int, + default=500, + help="Maximum time to wait for a frame in milliseconds before dropping it" + ) args = parser.parse_args() logging.basicConfig( @@ -407,6 +416,7 @@ async def on_shutdown(app: 
web.Application): app["media_ports"] = args.media_ports.split(",") if args.media_ports else None app["workspace"] = args.workspace app["config_file"] = args.config_file + app["max_frame_wait"] = args.max_frame_wait app.on_startup.append(on_startup) app.on_shutdown.append(on_shutdown) diff --git a/server/config.py b/server/config.py new file mode 100644 index 00000000..7f066643 --- /dev/null +++ b/server/config.py @@ -0,0 +1,45 @@ +import tomli +import logging +from typing import List, Dict, Any, Optional + +logger = logging.getLogger(__name__) + +class ComfyConfig: + def __init__(self, config_path: Optional[str] = None): + self.servers = [] + self.config_path = config_path + if config_path: + self.load_config(config_path) + else: + # Default to single local server if no config provided + self.servers = [{"host": "127.0.0.1", "port": 8188}] + + def load_config(self, config_path: str): + """Load server configuration from TOML file""" + try: + with open(config_path, "rb") as f: + config = tomli.load(f) + + # Extract server configurations + if "servers" in config: + self.servers = config["servers"] + logger.info(f"Loaded {len(self.servers)} server configurations") + else: + logger.warning("No servers defined in config, using default") + self.servers = [{"host": "127.0.0.1", "port": 8188}] + + # Validate each server has required fields + for i, server in enumerate(self.servers): + if "host" not in server or "port" not in server: + logger.warning(f"Server {i} missing host or port, using defaults") + server["host"] = server.get("host", "127.0.0.1") + server["port"] = server.get("port", 8188) + + except Exception as e: + logger.error(f"Error loading config from {config_path}: {e}") + # Fall back to default server + self.servers = [{"host": "127.0.0.1", "port": 8188}] + + def get_servers(self) -> List[Dict[str, Any]]: + """Return list of server configurations""" + return self.servers \ No newline at end of file diff --git a/server/pipeline_api.py b/server/pipeline_api.py
index 6849c265..d2aedf51 100644 --- a/server/pipeline_api.py +++ b/server/pipeline_api.py @@ -5,7 +5,7 @@ import logging import time import random -from collections import deque +from collections import deque, OrderedDict from typing import Any, Dict, Union, List, Optional, Deque from comfystream.client_api import ComfyStreamClient @@ -16,7 +16,7 @@ class MultiServerPipeline: - def __init__(self, config_path: Optional[str] = None, **kwargs): + def __init__(self, config_path: Optional[str] = None, max_frame_wait_ms: int = 500, **kwargs): # Load server configurations self.config = ComfyConfig(config_path) self.servers = self.config.get_servers() @@ -40,8 +40,10 @@ def __init__(self, config_path: Optional[str] = None, **kwargs): self.current_client_index = 0 self.client_frame_mapping = {} # Maps frame_id -> client_index - # Buffer to store frames in order of original pts - self.frame_output_buffer: Deque = deque() + # Frame ordering and timing + self.max_frame_wait_ms = max_frame_wait_ms # Max time to wait for a frame before dropping + self.next_expected_frame_id = None # Track expected frame ID + self.ordered_frames = OrderedDict() # Buffer for ordering frames (frame_id -> (timestamp, tensor)) # Audio processing self.processed_audio_buffer = np.array([], dtype=np.int16) @@ -71,16 +73,17 @@ async def _collect_processed_frames(self): ) # Find which original frame this corresponds to - # (using a simple approach here - could be improved) - # In real implementation, need to track which frames went to which client frame_ids = [frame_id for frame_id, client_idx in self.client_frame_mapping.items() if client_idx == i] if frame_ids: # Use the oldest frame ID for this client frame_id = min(frame_ids) - # Store the processed tensor along with original frame ID for ordering - await self.processed_video_frames.put((frame_id, out_tensor)) + + # Store frame with timestamp for ordering + timestamp = time.time() + await self._add_frame_to_ordered_buffer(frame_id, timestamp, 
out_tensor) + # Remove the mapping self.client_frame_mapping.pop(frame_id, None) logger.info(f"Collected processed frame from client {i}, frame_id: {frame_id}") @@ -90,6 +93,9 @@ async def _collect_processed_frames(self): except Exception as e: logger.error(f"Error collecting frame from client {i}: {e}") + # Check for frames that have waited too long + await self._check_frame_timeouts() + # Small sleep to avoid CPU spinning await asyncio.sleep(0.01) except asyncio.CancelledError: @@ -97,6 +103,70 @@ async def _collect_processed_frames(self): except Exception as e: logger.error(f"Unexpected error in frame collector: {e}") + async def _add_frame_to_ordered_buffer(self, frame_id, timestamp, tensor): + """Add a processed frame to the ordered buffer""" + self.ordered_frames[frame_id] = (timestamp, tensor) + + # If this is the first frame, set the next expected frame ID + if self.next_expected_frame_id is None: + self.next_expected_frame_id = frame_id + + # Check if we can release any frames now + await self._release_ordered_frames() + + async def _release_ordered_frames(self): + """Process ordered frames and put them in the output queue""" + # If we don't have a next expected frame yet, can't do anything + if self.next_expected_frame_id is None: + return + + # Check if the next expected frame is in our buffer + while self.ordered_frames and self.next_expected_frame_id in self.ordered_frames: + # Get the frame + timestamp, tensor = self.ordered_frames.pop(self.next_expected_frame_id) + + # Put it in the output queue + await self.processed_video_frames.put((self.next_expected_frame_id, tensor)) + logger.info(f"Released frame {self.next_expected_frame_id} to output queue") + + # Update the next expected frame ID to the next sequential ID if possible + # (or the lowest frame ID in our buffer) + if self.ordered_frames: + self.next_expected_frame_id = min(self.ordered_frames.keys()) + else: + # If no more frames, keep the last ID + 1 as next expected + 
self.next_expected_frame_id += 1 + + async def _check_frame_timeouts(self): + """Check for frames that have waited too long and handle them""" + if not self.ordered_frames or self.next_expected_frame_id is None: + return + + current_time = time.time() + + # If the next expected frame has timed out, skip it and move on + if self.next_expected_frame_id in self.ordered_frames: + timestamp, _ = self.ordered_frames[self.next_expected_frame_id] + wait_time_ms = (current_time - timestamp) * 1000 + + if wait_time_ms > self.max_frame_wait_ms: + logger.warning(f"Frame {self.next_expected_frame_id} exceeded max wait time, releasing anyway") + await self._release_ordered_frames() + + # Check if we're missing the next expected frame and it's been too long + elif self.ordered_frames: + # The next frame we're expecting isn't in the buffer + # Check how long we've been waiting since the oldest frame in the buffer + oldest_frame_id = min(self.ordered_frames.keys()) + oldest_timestamp, _ = self.ordered_frames[oldest_frame_id] + wait_time_ms = (current_time - oldest_timestamp) * 1000 + + # If we've waited too long, skip the missing frame(s) + if wait_time_ms > self.max_frame_wait_ms: + logger.warning(f"Missing frame {self.next_expected_frame_id}, skipping to {oldest_frame_id}") + self.next_expected_frame_id = oldest_frame_id + await self._release_ordered_frames() + async def warm_video(self): """Warm up the video pipeline with dummy frames for each client""" logger.info("Warming up video pipeline...") @@ -176,8 +246,13 @@ async def put_video_frame(self, frame: av.VideoFrame): self.last_frame_time = current_time - # Generate a unique frame ID - frame_id = int(time.time() * 1000000) # Microseconds as ID + # Generate a unique frame ID - use sequential IDs for better ordering + if not hasattr(self, 'next_frame_id'): + self.next_frame_id = 1 + + frame_id = self.next_frame_id + self.next_frame_id += 1 + frame.side_data.frame_id = frame_id # Preprocess the frame @@ -195,7 +270,7 @@ async 
def put_video_frame(self, frame: av.VideoFrame): self.clients[client_index].put_video_input(frame) # Also add to the incoming queue for reference - await self.video_incoming_frames.put(frame) + await self.video_incoming_frames.put((frame_id, frame)) logger.info(f"Sent frame {frame_id} to client {client_index}") @@ -239,18 +314,21 @@ def audio_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.Audio async def get_processed_video_frame(self): try: - # Get the frame from the incoming queue first to maintain timing - frame = await self.video_incoming_frames.get() + # Get the original frame from the incoming queue first to maintain timing + frame_id, frame = await self.video_incoming_frames.get() # Skip frames if we're falling behind while not self.video_incoming_frames.empty(): # Get newer frame and mark old one as skipped frame.side_data.skipped = True - frame = await self.video_incoming_frames.get() - logger.info("Skipped older frame to catch up") + frame_id, frame = await self.video_incoming_frames.get() + logger.info(f"Skipped older frame {frame_id} to catch up") # Get the processed frame from our output queue - frame_id, out_tensor = await self.processed_video_frames.get() + processed_frame_id, out_tensor = await self.processed_video_frames.get() + + if processed_frame_id != frame_id: + logger.warning(f"Frame ID mismatch: expected {frame_id}, got {processed_frame_id}") # Process the frame processed_frame = self.video_postprocess(out_tensor) diff --git a/src/comfystream/client_api.py b/src/comfystream/client_api.py index 624de56e..40801205 100644 --- a/src/comfystream/client_api.py +++ b/src/comfystream/client_api.py @@ -84,7 +84,6 @@ async def run_prompt(self, prompt_index: int): # Always set execution complete at start to allow first frame to be processed self.execution_complete_event.set() - logger.info("Setting execution_complete_event to TRUE at start") try: while True: @@ -97,16 +96,15 @@ async def run_prompt(self, prompt_index: int): if 
self.execution_complete_event.is_set(): # Reset execution state for next frame self.execution_complete_event.clear() - logger.info("Setting execution_complete_event to FALSE before executing prompt") # Queue the prompt with the current frame await self._execute_prompt(prompt_index) # Wait for execution completion with timeout try: - logger.info("Waiting for execution to complete (max 10 seconds)...") + # logger.info("Waiting for execution to complete (max 10 seconds)...") await asyncio.wait_for(self.execution_complete_event.wait(), timeout=10.0) - logger.info("Execution complete, ready for next frame") + # logger.info("Execution complete, ready for next frame") except asyncio.TimeoutError: logger.error("Timeout waiting for execution, forcing continuation") self.execution_complete_event.set() @@ -379,7 +377,7 @@ async def _execute_prompt(self, prompt_index: int): # Check if we have a frame waiting to be processed if not tensor_cache.image_inputs.empty(): - logger.info("Found tensor in input queue, preparing for API") + # logger.info("Found tensor in input queue, preparing for API") # Get the most recent frame only frame_or_tensor = None while not tensor_cache.image_inputs.empty(): @@ -514,15 +512,16 @@ async def _send_tensor_via_websocket(self, tensor): # Check tensor dimensions and log detailed info logger.info(f"Original tensor for WS: shape={tensor.shape}, min={tensor.min().item():.4f}, max={tensor.max().item():.4f}") - # CRITICAL FIX: The issue is with the shape - no need to resize if dimensions are fine - if tensor.shape[1] < 64 or tensor.shape[2] < 64: - logger.warning(f"Tensor dimensions too small: {tensor.shape}. 
Resizing to 512x512") + # Always ensure consistent 512x512 dimensions + ''' + if tensor.shape[1] != 512 or tensor.shape[2] != 512: + logger.info(f"Resizing tensor from {tensor.shape} to standard 512x512") import torch.nn.functional as F tensor = tensor.unsqueeze(0) # Add batch dimension for interpolate tensor = F.interpolate(tensor, size=(512, 512), mode='bilinear', align_corners=False) tensor = tensor.squeeze(0) # Remove batch dimension after resize - logger.info(f"Resized tensor to: {tensor.shape}") - + ''' + # Check for NaN or Inf values if torch.isnan(tensor).any() or torch.isinf(tensor).any(): logger.warning("Tensor contains NaN or Inf values! Replacing with zeros.") @@ -648,9 +647,9 @@ def put_audio_input(self, frame): async def get_video_output(self): """Get processed video frame from tensor cache""" - logger.info("Waiting for processed tensor from output queue") + # logger.info("Waiting for processed tensor from output queue") result = await tensor_cache.image_outputs.get() - logger.info(f"Got processed tensor from output queue: shape={result.shape if hasattr(result, 'shape') else 'unknown'}") + # logger.info(f"Got processed tensor from output queue: shape={result.shape if hasattr(result, 'shape') else 'unknown'}") return result async def get_audio_output(self): From 63014e3f4e6a0d6c8684589053297d9b74facc0c Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Tue, 25 Mar 2025 17:37:31 -0400 Subject: [PATCH 06/42] Cleaned up prompt manipulation with custom nodes. 
--- src/comfystream/utils_api.py | 93 +++++++++++++++++++++++++++++++----- 1 file changed, 81 insertions(+), 12 deletions(-) diff --git a/src/comfystream/utils_api.py b/src/comfystream/utils_api.py index 5a932ef1..d207f4e3 100644 --- a/src/comfystream/utils_api.py +++ b/src/comfystream/utils_api.py @@ -15,6 +15,15 @@ def create_load_tensor_node(): "_meta": {"title": "Load Tensor (API)"}, } +def create_load_image_node(): + return { + "inputs": { + "image": "" # Should be "image" not "image_data" to match LoadImageBase64 + }, + "class_type": "LoadImageBase64", + "_meta": {"title": "Load Image Base64 (ComfyStream)"}, + } + def create_save_tensor_node(inputs: Dict[Any, Any]): """Create a SaveTensorAPI node with proper input formatting""" # Make sure images input is properly formatted [node_id, output_index] @@ -35,10 +44,38 @@ def create_save_tensor_node(inputs: Dict[Any, Any]): "_meta": {"title": "Save Tensor (API)"}, } -def convert_prompt(prompt): +def create_save_image_node(inputs: Dict[Any, Any]): + # Get the correct image input reference + images_input = inputs.get("images", inputs.get("image")) + + # If not properly formatted, use default + if not images_input: + images_input = ["", 0] # Default empty value + + return { + "inputs": { + "images": images_input, + "format": "PNG" # Default format + }, + "class_type": "SendImageWebsocket", + "_meta": {"title": "Send Image Websocket (ComfyStream)"}, + } +def convert_prompt(prompt): logging.info("Converting prompt: %s", prompt) + # Initialize counters + num_primary_inputs = 0 + num_inputs = 0 + num_outputs = 0 + + keys = { + "PrimaryInputLoadImage": [], + "LoadImage": [], + "PreviewImage": [], + "SaveImage": [], + } + # Set random seeds for any seed nodes for key, node in prompt.items(): if not isinstance(node, dict) or "inputs" not in node: @@ -50,20 +87,52 @@ def convert_prompt(prompt): random_seed = random.randint(0, 18446744073709551615) node["inputs"]["seed"] = random_seed print(f"Set random seed {random_seed} for 
node {key}") + + for key, node in prompt.items(): + class_type = node.get("class_type") + # Collect keys for nodes that might need to be replaced + if class_type in keys: + keys[class_type].append(key) - # Replace LoadImage with LoadImageBase64 - for key, node in prompt.items(): - if node.get("class_type") == "LoadImage": - node["class_type"] = "LoadImageBase64" + # Count inputs and outputs + if class_type == "PrimaryInputLoadImage": + num_primary_inputs += 1 + elif class_type in ["LoadImage", "LoadImageBase64"]: + num_inputs += 1 + elif class_type in ["PreviewImage", "SaveImage", "SendImageWebsocket"]: + num_outputs += 1 - # Replace SaveImage/PreviewImage with SendImageWebsocket - for key, node in prompt.items(): - if node.get("class_type") in ["SaveImage", "PreviewImage"]: - node["class_type"] = "SendImageWebsocket" - # Ensure format is set - if "format" not in node["inputs"]: - node["inputs"]["format"] = "PNG" # Set default format + # Only handle single primary input + if num_primary_inputs > 1: + raise Exception("too many primary inputs in prompt") + + # If there are no primary inputs, only handle single input + if num_primary_inputs == 0 and num_inputs > 1: + raise Exception("too many inputs in prompt") + + # Only handle single output for now + if num_outputs > 1: + raise Exception("too many outputs in prompt") + + if num_primary_inputs + num_inputs == 0: + raise Exception("missing input") + + if num_outputs == 0: + raise Exception("missing output") + + # Replace nodes with proper implementations + for key in keys["PrimaryInputLoadImage"]: + prompt[key] = create_load_image_node() + + if num_primary_inputs == 0 and len(keys["LoadImage"]) == 1: + prompt[keys["LoadImage"][0]] = create_load_image_node() + + for key in keys["PreviewImage"] + keys["SaveImage"]: + node = prompt[key] + prompt[key] = create_save_image_node(node["inputs"]) + + # TODO: Validate the processed prompt input return prompt From 15b51a8283369080f1fdbd6af33b5ac3e460e1ea Mon Sep 17 00:00:00 2001 
From: BuffMcBigHuge Date: Tue, 25 Mar 2025 18:40:17 -0400 Subject: [PATCH 07/42] Added frame tracking, add frame timing stability, added mismatched frame size handling, commented out some logging. --- server/pipeline_api.py | 51 ++++++++----- src/comfystream/client_api.py | 136 ++++++++++++++++++++++++++-------- 2 files changed, 136 insertions(+), 51 deletions(-) diff --git a/server/pipeline_api.py b/server/pipeline_api.py index d2aedf51..1ff1826a 100644 --- a/server/pipeline_api.py +++ b/server/pipeline_api.py @@ -67,26 +67,36 @@ async def _collect_processed_frames(self): # Get frame without waiting try: # Use wait_for with small timeout to avoid blocking - out_tensor = await asyncio.wait_for( + result = await asyncio.wait_for( client.get_video_output(), timeout=0.01 ) - # Find which original frame this corresponds to - frame_ids = [frame_id for frame_id, client_idx in - self.client_frame_mapping.items() if client_idx == i] - - if frame_ids: - # Use the oldest frame ID for this client - frame_id = min(frame_ids) - - # Store frame with timestamp for ordering - timestamp = time.time() - await self._add_frame_to_ordered_buffer(frame_id, timestamp, out_tensor) + # Check if result is already a tuple with frame_id + if isinstance(result, tuple) and len(result) == 2: + frame_id, out_tensor = result + # logger.info(f"Got result with embedded frame_id: {frame_id}") + else: + out_tensor = result + # Find which original frame this corresponds to using our mapping + frame_ids = [frame_id for frame_id, client_idx in + self.client_frame_mapping.items() if client_idx == i] - # Remove the mapping - self.client_frame_mapping.pop(frame_id, None) - logger.info(f"Collected processed frame from client {i}, frame_id: {frame_id}") + if frame_ids: + # Use the oldest frame ID for this client + frame_id = min(frame_ids) + else: + # If no mapping found, log warning and continue + logger.warning(f"No frame_id mapping found for tensor from client {i}") + continue + + # Store frame with 
timestamp for ordering + timestamp = time.time() + await self._add_frame_to_ordered_buffer(frame_id, timestamp, out_tensor) + + # Remove the mapping + self.client_frame_mapping.pop(frame_id, None) + # logger.info(f"Collected processed frame from client {i}, frame_id: {frame_id}") except asyncio.TimeoutError: # No frame ready yet, continue pass @@ -127,7 +137,7 @@ async def _release_ordered_frames(self): # Put it in the output queue await self.processed_video_frames.put((self.next_expected_frame_id, tensor)) - logger.info(f"Released frame {self.next_expected_frame_id} to output queue") + # logger.info(f"Released frame {self.next_expected_frame_id} to output queue") # Update the next expected frame ID to the next sequential ID if possible # (or the lowest frame ID in our buffer) @@ -163,7 +173,7 @@ async def _check_frame_timeouts(self): # If we've waited too long, skip the missing frame(s) if wait_time_ms > self.max_frame_wait_ms: - logger.warning(f"Missing frame {self.next_expected_frame_id}, skipping to {oldest_frame_id}") + # logger.warning(f"Missing frame {self.next_expected_frame_id}, skipping to {oldest_frame_id}") self.next_expected_frame_id = oldest_frame_id await self._release_ordered_frames() @@ -272,7 +282,7 @@ async def put_video_frame(self, frame: av.VideoFrame): # Also add to the incoming queue for reference await self.video_incoming_frames.put((frame_id, frame)) - logger.info(f"Sent frame {frame_id} to client {client_index}") + # logger.info(f"Sent frame {frame_id} to client {client_index}") async def put_audio_frame(self, frame: av.AudioFrame): # For now, only use the first client for audio @@ -322,13 +332,14 @@ async def get_processed_video_frame(self): # Get newer frame and mark old one as skipped frame.side_data.skipped = True frame_id, frame = await self.video_incoming_frames.get() - logger.info(f"Skipped older frame {frame_id} to catch up") + # logger.info(f"Skipped older frame {frame_id} to catch up") # Get the processed frame from our output 
queue processed_frame_id, out_tensor = await self.processed_video_frames.get() if processed_frame_id != frame_id: - logger.warning(f"Frame ID mismatch: expected {frame_id}, got {processed_frame_id}") + # logger.warning(f"Frame ID mismatch: expected {frame_id}, got {processed_frame_id}") + pass # Process the frame processed_frame = self.video_postprocess(out_tensor) diff --git a/src/comfystream/client_api.py b/src/comfystream/client_api.py index 40801205..f8725aed 100644 --- a/src/comfystream/client_api.py +++ b/src/comfystream/client_api.py @@ -45,6 +45,10 @@ def __init__(self, host: str = "127.0.0.1", port: int = 8198, **kwargs): self.execution_started = False self._prompt_id = None + # Add frame tracking + self._current_frame_id = None # Track the current frame being processed + self._frame_id_mapping = {} # Map prompt_ids to frame_ids + # Configure logging if 'log_level' in kwargs: logger.setLevel(kwargs['log_level']) @@ -358,8 +362,23 @@ async def _handle_binary_message(self, binary_data): with torch.no_grad(): tensor = torch.from_numpy(np.array(img)).float().permute(2, 0, 1).unsqueeze(0) / 255.0 - # Add to output queue without waiting - tensor_cache.image_outputs.put_nowait(tensor) + # Try to get frame_id from mapping using current prompt_id + frame_id = None + if hasattr(self, '_prompt_id') and self._prompt_id in self._frame_id_mapping: + frame_id = self._frame_id_mapping.get(self._prompt_id) + # logger.info(f"Using frame_id {frame_id} from prompt_id {self._prompt_id}") + elif hasattr(self, '_current_frame_id') and self._current_frame_id is not None: + frame_id = self._current_frame_id + # logger.info(f"Using current frame_id {frame_id}") + + # Add to output queue - include frame_id if available + if frame_id is not None: + tensor_cache.image_outputs.put_nowait((frame_id, tensor)) + # logger.info(f"Added tensor with frame_id {frame_id} to output queue") + else: + tensor_cache.image_outputs.put_nowait(tensor) + #logger.info("Added tensor without frame_id to 
output queue") + self.execution_complete_event.set() except Exception as img_error: @@ -377,16 +396,26 @@ async def _execute_prompt(self, prompt_index: int): # Check if we have a frame waiting to be processed if not tensor_cache.image_inputs.empty(): - # logger.info("Found tensor in input queue, preparing for API") # Get the most recent frame only frame_or_tensor = None while not tensor_cache.image_inputs.empty(): frame_or_tensor = tensor_cache.image_inputs.get_nowait() + # Extract frame ID if available in side_data + frame_id = None + if hasattr(frame_or_tensor, 'side_data'): + # Try to get frame_id from side_data + if hasattr(frame_or_tensor.side_data, 'frame_id'): + frame_id = frame_or_tensor.side_data.frame_id + logger.info(f"Found frame_id in side_data: {frame_id}") + + # Store current frame ID for binary message handler to use + self._current_frame_id = frame_id + # Find ETN_LoadImageBase64 nodes first load_image_nodes = [] for node_id, node in prompt.items(): - if isinstance(node, dict) and node.get("class_type") in ["ETN_LoadImageBase64", "LoadImageBase64"]: + if isinstance(node, dict) and node.get("class_type") in ["LoadImageBase64"]: load_image_nodes.append(node_id) if not load_image_nodes: @@ -415,34 +444,64 @@ async def _execute_prompt(self, prompt_index: int): self.execution_complete_event.set() return - # Process tensor format only once + # Process tensor format only once - streamlined for speed and reliability with torch.no_grad(): - tensor = tensor.detach().cpu().float() - - # Handle different formats - if len(tensor.shape) == 4: # BCHW format (batch) - tensor = tensor[0] # Take first image from batch - - # Ensure it's in CHW format - if len(tensor.shape) == 3 and tensor.shape[2] == 3: # HWC format - tensor = tensor.permute(2, 0, 1) # Convert to CHW - - # Convert to PIL image for base64 ONLY ONCE - tensor_np = (tensor.permute(1, 2, 0) * 255).clamp(0, 255).numpy().astype(np.uint8) - img = Image.fromarray(tensor_np) + # Fast tensor normalization to 
ensure consistent output + try: + # TODO: Why is the UI sending different sizes? Should be fixed no? This breaks tensorrt + # I'm sometimes seeing (BCHW): torch.Size([1, 384, 384, 3]), H=384, W=3 + # Ensure minimum size of 512x512 + + # Handle batch dimension if present + if len(tensor.shape) == 4: # BCHW format + tensor = tensor[0] # Take first image from batch + + # Normalize to CHW format consistently + if len(tensor.shape) == 3 and tensor.shape[2] == 3: # HWC format + tensor = tensor.permute(2, 0, 1) # Convert to CHW + + # Handle single-channel case + if len(tensor.shape) == 3 and tensor.shape[0] == 1: + tensor = tensor.repeat(3, 1, 1) # Convert grayscale to RGB + + # Ensure tensor is on CPU + if tensor.is_cuda: + tensor = tensor.cpu() + + # Always resize to 512x512 for consistency (faster than checking dimensions first) + tensor = tensor.unsqueeze(0) # Add batch dim for interpolate + tensor = torch.nn.functional.interpolate( + tensor, size=(512, 512), mode='bilinear', align_corners=False + ) + tensor = tensor[0] # Remove batch dimension + + # Direct conversion to PIL without intermediate numpy step for speed + tensor_np = (tensor.permute(1, 2, 0).clamp(0, 1) * 255).to(torch.uint8).numpy() + img = Image.fromarray(tensor_np) + + # Fast JPEG encoding with balanced quality + buffer = BytesIO() + img.save(buffer, format="JPEG", quality=90, optimize=True) + buffer.seek(0) + img_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8') + + except Exception as e: + logger.warning(f"Error in tensor processing: {e}, creating fallback image") + # Create a standard 512x512 placeholder if anything fails + img = Image.new('RGB', (512, 512), color=(100, 149, 237)) + buffer = BytesIO() + img.save(buffer, format="JPEG", quality=90) + buffer.seek(0) + img_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8') - # Convert to base64 ONCE for all nodes - buffer = BytesIO() - img.save(buffer, format="PNG") - buffer.seek(0) - img_base64 = 
base64.b64encode(buffer.getvalue()).decode('utf-8') + # Add timestamp for cache busting (once, outside the try/except) + timestamp = int(time.time() * 1000) # Update all nodes with the SAME base64 string - timestamp = int(time.time() * 1000) for node_id in load_image_nodes: prompt[node_id]["inputs"]["image"] = img_base64 prompt[node_id]["inputs"]["_timestamp"] = timestamp - # Use timestamp as cache buster instead of random number + # Use timestamp as cache buster prompt[node_id]["inputs"]["_cache_buster"] = str(timestamp) except Exception as e: @@ -462,6 +521,12 @@ async def _execute_prompt(self, prompt_index: int): if response.status == 200: result = await response.json() self._prompt_id = result.get("prompt_id") + + # Map prompt_id to frame_id for later retrieval + if frame_id is not None: + self._frame_id_mapping[self._prompt_id] = frame_id + # logger.info(f"Mapped prompt_id {self._prompt_id} to frame_id {frame_id}") + self.execution_started = True else: error_text = await response.text() @@ -493,7 +558,8 @@ async def _send_tensor_via_websocket(self, tensor): # Prepare binary data if len(tensor.shape) == 4: # BCHW format (batch of images) if tensor.shape[0] > 1: - logger.info(f"Taking first image from batch of {tensor.shape[0]}") + # logger.info(f"Taking first image from batch of {tensor.shape[0]}") + pass tensor = tensor[0] # Take first image if batch # Ensure CHW format (3 channels) @@ -510,7 +576,7 @@ async def _send_tensor_via_websocket(self, tensor): tensor = torch.zeros(3, 512, 512) # Check tensor dimensions and log detailed info - logger.info(f"Original tensor for WS: shape={tensor.shape}, min={tensor.min().item():.4f}, max={tensor.max().item():.4f}") + # logger.info(f"Original tensor for WS: shape={tensor.shape}, min={tensor.min().item():.4f}, max={tensor.max().item():.4f}") # Always ensure consistent 512x512 dimensions ''' @@ -557,7 +623,7 @@ async def _send_tensor_via_websocket(self, tensor): # Send binary data via websocket await 
self.ws.send(full_data) - logger.info(f"Sent tensor as PNG image via websocket with proper header, size: {len(full_data)} bytes, image dimensions: {img.size}") + # logger.info(f"Sent tensor as PNG image via websocket with proper header, size: {len(full_data)} bytes, image dimensions: {img.size}") except Exception as e: logger.error(f"Error sending tensor via websocket: {e}") @@ -647,10 +713,18 @@ def put_audio_input(self, frame): async def get_video_output(self): """Get processed video frame from tensor cache""" - # logger.info("Waiting for processed tensor from output queue") result = await tensor_cache.image_outputs.get() - # logger.info(f"Got processed tensor from output queue: shape={result.shape if hasattr(result, 'shape') else 'unknown'}") - return result + + # Check if the result is a tuple with frame_id + if isinstance(result, tuple) and len(result) == 2: + frame_id, tensor = result + # logger.info(f"Got processed tensor from output queue with frame_id {frame_id}") + # Return both the frame_id and tensor to help with ordering in the pipeline + return frame_id, tensor + else: + # If it's not a tuple with frame_id, just return the tensor + # logger.info("Got processed tensor from output queue without frame_id") + return result async def get_audio_output(self): """Get processed audio frame from tensor cache""" From 80636b648443ae3e4d210bd85c7e49cdf6759bfb Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Tue, 1 Apr 2025 12:15:56 -0400 Subject: [PATCH 08/42] Removed requirement for workspace in app startup. 
--- server/app_api.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/server/app_api.py b/server/app_api.py index b7b56e56..b5676aae 100644 --- a/server/app_api.py +++ b/server/app_api.py @@ -11,7 +11,6 @@ if torch.cuda.is_available(): torch.cuda.init() - from aiohttp import web from aiortc import ( MediaStreamTrack, @@ -317,6 +316,7 @@ async def on_connectionstatechange(): ), ) + async def cancel_collect_frames(track): track.running = False if hasattr(track, 'collect_task') is not None and not track.collect_task.done(): @@ -326,6 +326,7 @@ async def cancel_collect_frames(track): except (asyncio.CancelledError): pass + async def set_prompt(request): pipeline = request.app["pipeline"] @@ -346,7 +347,6 @@ async def on_startup(app: web.Application): app["pipeline"] = Pipeline( config_path=app["config_file"], max_frame_wait_ms=app["max_frame_wait"], - cwd=app["workspace"], disable_cuda_malloc=True, gpu_only=True, preview_method='none' @@ -371,9 +371,6 @@ async def on_shutdown(app: web.Application): "--media-ports", default=None, help="Set the UDP ports for WebRTC media" ) parser.add_argument("--host", default="127.0.0.1", help="Set the host") - parser.add_argument( - "--workspace", default=None, required=True, help="Set Comfy workspace" - ) parser.add_argument( "--log-level", default="INFO", @@ -414,7 +411,6 @@ async def on_shutdown(app: web.Application): app = web.Application() app["media_ports"] = args.media_ports.split(",") if args.media_ports else None - app["workspace"] = args.workspace app["config_file"] = args.config_file app["max_frame_wait"] = args.max_frame_wait From 2cdb6faaa8dfe7e693f30edb6d23f92511df7b56 Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Tue, 1 Apr 2025 15:07:20 -0400 Subject: [PATCH 09/42] Cleanup of logging, added log_level argument, testing of send tensor websocket node. 
--- nodes/native_utils/__init__.py | 7 +- nodes/native_utils/send_tensor_websocket.py | 289 ++++++++++++++++++++ server/app_api.py | 8 +- server/pipeline_api.py | 16 +- src/comfystream/client_api.py | 90 ++---- src/comfystream/utils_api.py | 97 ++----- 6 files changed, 354 insertions(+), 153 deletions(-) create mode 100644 nodes/native_utils/send_tensor_websocket.py diff --git a/nodes/native_utils/__init__.py b/nodes/native_utils/__init__.py index e7e7789c..472e5411 100644 --- a/nodes/native_utils/__init__.py +++ b/nodes/native_utils/__init__.py @@ -1,16 +1,19 @@ from .load_image_base64 import LoadImageBase64 from .send_image_websocket import SendImageWebsocket +from .send_tensor_websocket import SendTensorWebSocket # This dictionary is used by ComfyUI to register the nodes NODE_CLASS_MAPPINGS = { "LoadImageBase64": LoadImageBase64, - "SendImageWebsocket": SendImageWebsocket + "SendImageWebsocket": SendImageWebsocket, + "SendTensorWebSocket": SendTensorWebSocket } # This dictionary provides display names for the nodes in the UI NODE_DISPLAY_NAME_MAPPINGS = { "LoadImageBase64": "Load Image Base64 (ComfyStream)", - "SendImageWebsocket": "Send Image Websocket (ComfyStream)" + "SendImageWebsocket": "Send Image Websocket (ComfyStream)", + "SendTensorWebSocket": "Save Tensor WebSocket (ComfyStream)" } # Export these variables for ComfyUI to use diff --git a/nodes/native_utils/send_tensor_websocket.py b/nodes/native_utils/send_tensor_websocket.py new file mode 100644 index 00000000..33104896 --- /dev/null +++ b/nodes/native_utils/send_tensor_websocket.py @@ -0,0 +1,289 @@ +import torch +import numpy as np +import base64 +import logging +import json +import traceback +import sys + +logger = logging.getLogger(__name__) + +# Log when the module is loaded +logger.debug("------------------ SendTensorWebSocket Module Loaded ------------------") + +class SendTensorWebSocket: + def __init__(self): + # Output directory is not needed as we send via WebSocket + 
logger.debug("SendTensorWebSocket instance created") + pass + + @classmethod + def INPUT_TYPES(cls): + logger.debug("SendTensorWebSocket.INPUT_TYPES called") + return { + "required": { + # Accept IMAGE input (typical output from VAE Decode) + "tensor": ("IMAGE", ), + }, + "hidden": { + # These are needed for ComfyUI execution context + "prompt": "PROMPT", + "extra_pnginfo": "EXTRA_PNGINFO" + }, + } + + RETURN_TYPES = () # No direct output connection to other nodes + FUNCTION = "save_tensor" + OUTPUT_NODE = True + CATEGORY = "ComfyStream/native" + + def save_tensor(self, tensor, prompt=None, extra_pnginfo=None): + logger.debug("========== SendTensorWebSocket.save_tensor STARTED ==========") + logger.info(f"SendTensorWebSocket received input. Type: {type(tensor)}") + logger.debug(f"SendTensorWebSocket node is processing tensor with id: {id(tensor)}") + + # Log memory usage for debugging + if torch.cuda.is_available(): + try: + logger.debug(f"CUDA memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB") + logger.debug(f"CUDA memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB") + except Exception as e: + logger.error(f"Error checking CUDA memory: {e}") + + if tensor is None: + logger.error("SendTensorWebSocket received None tensor.") + # Return error directly without ui nesting + return {"comfystream_tensor_output": {"error": "Input tensor was None"}} + + try: + # Log details about the tensor before processing + logger.debug(f"Process tensor of type: {type(tensor)}") + + if isinstance(tensor, torch.Tensor): + logger.debug("Processing torch.Tensor...") + logger.info(f"Input tensor details: shape={tensor.shape}, dtype={tensor.dtype}, device={tensor.device}") + + # Additional handling for IMAGE-type tensors (0-1 float values, BCHW format) + if len(tensor.shape) == 4: # BCHW format (batch) + logger.debug(f"Tensor is batched (BCHW): {tensor.shape}") + logger.info(f"Tensor appears to be IMAGE batch. 
Min: {tensor.min().item()}, Max: {tensor.max().item()}") + logger.debug(f"First batch slice: min={tensor[0].min().item()}, max={tensor[0].max().item()}") + tensor = tensor[0] # Select first image from batch + logger.debug(f"Selected first batch element. New shape: {tensor.shape}") + + if len(tensor.shape) == 3: # CHW format (single image) + logger.debug(f"Tensor is CHW format: {tensor.shape}") + logger.info(f"Tensor appears to be single IMAGE. Min: {tensor.min().item()}, Max: {tensor.max().item()}") + + # Log first few values for debugging + logger.debug(f"First few values: {tensor.flatten()[:10].tolist()}") + + # Ensure the tensor is on CPU and detached + logger.debug(f"Moving tensor to CPU. Current device: {tensor.device}") + try: + tensor = tensor.cpu().detach() + logger.debug(f"Tensor moved to CPU successfully: {tensor.device}") + except Exception as e: + logger.error(f"Error moving tensor to CPU: {e}") + logger.error(traceback.format_exc()) + return {"comfystream_tensor_output": {"error": f"CPU transfer error: {str(e)}"}} + + # Convert to numpy + logger.debug("Converting tensor to numpy array...") + try: + np_array = tensor.numpy() + logger.debug(f"Conversion to numpy successful: shape={np_array.shape}, dtype={np_array.dtype}") + logger.debug(f"NumPy array memory usage: {np_array.nbytes / 1024**2:.2f} MB") + except Exception as e: + logger.error(f"Error converting tensor to numpy: {e}") + logger.error(traceback.format_exc()) + return {"comfystream_tensor_output": {"error": f"NumPy conversion error: {str(e)}"}} + + # Encode the tensor + logger.debug("Converting numpy array to bytes...") + try: + tensor_bytes = np_array.tobytes() + logger.debug(f"Tensor converted to bytes: {len(tensor_bytes)} bytes") + except Exception as e: + logger.error(f"Error converting numpy array to bytes: {e}") + logger.error(traceback.format_exc()) + return {"comfystream_tensor_output": {"error": f"Bytes conversion error: {str(e)}"}} + + logger.debug("Encoding bytes to base64...") + 
try: + b64_data = base64.b64encode(tensor_bytes).decode('utf-8') + b64_size = len(b64_data) + logger.debug(f"Base64 encoding successful: {b64_size} characters") + if b64_size > 100: + logger.debug(f"Base64 sample: {b64_data[:50]}...{b64_data[-50:]}") + except Exception as e: + logger.error(f"Error encoding to base64: {e}") + logger.error(traceback.format_exc()) + return {"comfystream_tensor_output": {"error": f"Base64 encoding error: {str(e)}"}} + + # Prepare metadata + shape = list(np_array.shape) + dtype = str(np_array.dtype) + + logger.info(f"SendTensorWebSocket prepared tensor: shape={shape}, dtype={dtype}") + + # Construct the return value with simplified structure (no ui nesting) + success_output = { + "comfystream_tensor_output": { + "b64_data": b64_data, + "shape": shape, + "dtype": dtype + } + } + + # Log the structure of the output (avoid logging the actual base64 data which is large) + output_structure = { + "comfystream_tensor_output": { + "b64_data": f"(base64 string of {b64_size} bytes)", + "shape": shape, + "dtype": dtype + } + } + logger.info(f"SendTensorWebSocket returning SUCCESS data structure: {json.dumps(output_structure)}") + logger.debug("========== SendTensorWebSocket.save_tensor COMPLETED SUCCESSFULLY ==========") + + return success_output + + elif isinstance(tensor, np.ndarray): + logger.debug("Processing numpy.ndarray...") + logger.info(f"Input is numpy array: shape={tensor.shape}, dtype={tensor.dtype}") + + # Log memory details + logger.debug(f"NumPy array memory usage: {tensor.nbytes / 1024**2:.2f} MB") + logger.debug(f"First few values: {tensor.flatten()[:10].tolist()}") + + # Handle numpy array directly + logger.debug("Converting numpy array to bytes...") + try: + tensor_bytes = tensor.tobytes() + logger.debug(f"NumPy array converted to bytes: {len(tensor_bytes)} bytes") + except Exception as e: + logger.error(f"Error converting numpy array to bytes: {e}") + logger.error(traceback.format_exc()) + return {"comfystream_tensor_output": 
{"error": f"NumPy to bytes error: {str(e)}"}} + + logger.debug("Encoding bytes to base64...") + try: + b64_data = base64.b64encode(tensor_bytes).decode('utf-8') + b64_size = len(b64_data) + logger.debug(f"Base64 encoding successful: {b64_size} characters") + if b64_size > 100: + logger.debug(f"Base64 sample: {b64_data[:50]}...{b64_data[-50:]}") + except Exception as e: + logger.error(f"Error encoding numpy to base64: {e}") + logger.error(traceback.format_exc()) + return {"comfystream_tensor_output": {"error": f"NumPy base64 encoding error: {str(e)}"}} + + shape = list(tensor.shape) + dtype = str(tensor.dtype) + + logger.debug("Constructing success output for numpy array...") + success_output = { + "comfystream_tensor_output": { + "b64_data": b64_data, + "shape": shape, + "dtype": dtype + } + } + logger.info(f"SendTensorWebSocket returning SUCCESS from numpy array: shape={shape}, dtype={dtype}") + logger.debug("========== SendTensorWebSocket.save_tensor COMPLETED SUCCESSFULLY ==========") + return success_output + + elif isinstance(tensor, list): + logger.debug("Processing list input...") + logger.info(f"Input is a list of length {len(tensor)}") + + if len(tensor) > 0: + first_item = tensor[0] + logger.debug(f"First item type: {type(first_item)}") + + if isinstance(first_item, torch.Tensor): + logger.debug("Processing first tensor from list...") + logger.debug(f"First tensor details: shape={first_item.shape}, dtype={first_item.dtype}, device={first_item.device}") + + # Log first few values + logger.debug(f"First few values: {first_item.flatten()[:10].tolist()}") + + # Process first tensor in the list + try: + logger.debug("Moving tensor to CPU and detaching...") + np_array = first_item.cpu().detach().numpy() + logger.debug(f"Conversion successful: shape={np_array.shape}, dtype={np_array.dtype}") + except Exception as e: + logger.error(f"Error processing first tensor in list: {e}") + logger.error(traceback.format_exc()) + return {"comfystream_tensor_output": 
{"error": f"List tensor processing error: {str(e)}"}} + + try: + logger.debug("Converting numpy array to bytes...") + tensor_bytes = np_array.tobytes() + logger.debug(f"Converted to bytes: {len(tensor_bytes)} bytes") + except Exception as e: + logger.error(f"Error converting list tensor to bytes: {e}") + logger.error(traceback.format_exc()) + return {"comfystream_tensor_output": {"error": f"List tensor bytes conversion error: {str(e)}"}} + + try: + logger.debug("Encoding bytes to base64...") + b64_data = base64.b64encode(tensor_bytes).decode('utf-8') + b64_size = len(b64_data) + logger.debug(f"Base64 encoding successful: {b64_size} characters") + except Exception as e: + logger.error(f"Error encoding list tensor to base64: {e}") + logger.error(traceback.format_exc()) + return {"comfystream_tensor_output": {"error": f"List tensor base64 encoding error: {str(e)}"}} + + shape = list(np_array.shape) + dtype = str(np_array.dtype) + + logger.debug("Constructing success output for list tensor...") + success_output = { + "comfystream_tensor_output": { + "b64_data": b64_data, + "shape": shape, + "dtype": dtype + } + } + logger.info(f"SendTensorWebSocket returning SUCCESS from list's first tensor: shape={shape}, dtype={dtype}") + logger.debug("========== SendTensorWebSocket.save_tensor COMPLETED SUCCESSFULLY ==========") + return success_output + else: + logger.error(f"First item in list is not a tensor but {type(first_item)}") + if hasattr(first_item, '__dict__'): + logger.debug(f"First item attributes: {dir(first_item)}") + + # If we got here, couldn't process the list + logger.error(f"Unable to process list input: invalid content types") + list_types = [type(x).__name__ for x in tensor[:3]] + error_msg = f"Unsupported list content: {list_types}..." 
+ logger.debug("========== SendTensorWebSocket.save_tensor FAILED ==========") + return {"comfystream_tensor_output": {"error": error_msg}} + + else: + # Unsupported type + error_msg = f"Unsupported tensor type: {type(tensor)}" + logger.error(error_msg) + if hasattr(tensor, '__dict__'): + logger.debug(f"Tensor attributes: {dir(tensor)}") + + logger.debug("========== SendTensorWebSocket.save_tensor FAILED ==========") + return {"comfystream_tensor_output": {"error": error_msg}} + + except Exception as e: + logger.exception(f"Error serializing tensor in SendTensorWebSocket: {e}") + + # Get detailed exception info + exc_type, exc_value, exc_traceback = sys.exc_info() + tb_lines = traceback.format_exception(exc_type, exc_value, exc_traceback) + tb_text = ''.join(tb_lines) + logger.debug(f"Exception traceback:\n{tb_text}") + + error_output = {"comfystream_tensor_output": {"error": f"{str(e)} - See save_tensor_websocket_debug.log for details"}} + logger.info(f"SendTensorWebSocket returning ERROR data: {error_output}") + logger.debug("========== SendTensorWebSocket.save_tensor FAILED WITH EXCEPTION ==========") + return error_output \ No newline at end of file diff --git a/server/app_api.py b/server/app_api.py index b5676aae..03fa4eda 100644 --- a/server/app_api.py +++ b/server/app_api.py @@ -372,8 +372,9 @@ async def on_shutdown(app: web.Application): ) parser.add_argument("--host", default="127.0.0.1", help="Set the host") parser.add_argument( - "--log-level", - default="INFO", + "--log-level", "--log_level", + dest="log_level", + default="WARNING", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], help="Set the logging level", ) @@ -409,6 +410,9 @@ async def on_shutdown(app: web.Application): datefmt="%H:%M:%S", ) + # Set logger level based on command line arguments + logger.setLevel(getattr(logging, args.log_level.upper())) + app = web.Application() app["media_ports"] = args.media_ports.split(",") if args.media_ports else None app["config_file"] = 
args.config_file diff --git a/server/pipeline_api.py b/server/pipeline_api.py index 1ff1826a..23381e81 100644 --- a/server/pipeline_api.py +++ b/server/pipeline_api.py @@ -5,7 +5,7 @@ import logging import time import random -from collections import deque, OrderedDict +from collections import OrderedDict from typing import Any, Dict, Union, List, Optional, Deque from comfystream.client_api import ComfyStreamClient @@ -75,7 +75,7 @@ async def _collect_processed_frames(self): # Check if result is already a tuple with frame_id if isinstance(result, tuple) and len(result) == 2: frame_id, out_tensor = result - # logger.info(f"Got result with embedded frame_id: {frame_id}") + logger.info(f"Got result with embedded frame_id: {frame_id}") else: out_tensor = result # Find which original frame this corresponds to using our mapping @@ -96,7 +96,7 @@ async def _collect_processed_frames(self): # Remove the mapping self.client_frame_mapping.pop(frame_id, None) - # logger.info(f"Collected processed frame from client {i}, frame_id: {frame_id}") + logger.info(f"Collected processed frame from client {i}, frame_id: {frame_id}") except asyncio.TimeoutError: # No frame ready yet, continue pass @@ -137,7 +137,7 @@ async def _release_ordered_frames(self): # Put it in the output queue await self.processed_video_frames.put((self.next_expected_frame_id, tensor)) - # logger.info(f"Released frame {self.next_expected_frame_id} to output queue") + logger.info(f"Released frame {self.next_expected_frame_id} to output queue") # Update the next expected frame ID to the next sequential ID if possible # (or the lowest frame ID in our buffer) @@ -173,7 +173,7 @@ async def _check_frame_timeouts(self): # If we've waited too long, skip the missing frame(s) if wait_time_ms > self.max_frame_wait_ms: - # logger.warning(f"Missing frame {self.next_expected_frame_id}, skipping to {oldest_frame_id}") + logger.warning(f"Missing frame {self.next_expected_frame_id}, skipping to {oldest_frame_id}") 
self.next_expected_frame_id = oldest_frame_id await self._release_ordered_frames() @@ -282,7 +282,7 @@ async def put_video_frame(self, frame: av.VideoFrame): # Also add to the incoming queue for reference await self.video_incoming_frames.put((frame_id, frame)) - # logger.info(f"Sent frame {frame_id} to client {client_index}") + logger.debug(f"Sent frame {frame_id} to client {client_index}") async def put_audio_frame(self, frame: av.AudioFrame): # For now, only use the first client for audio @@ -332,13 +332,13 @@ async def get_processed_video_frame(self): # Get newer frame and mark old one as skipped frame.side_data.skipped = True frame_id, frame = await self.video_incoming_frames.get() - # logger.info(f"Skipped older frame {frame_id} to catch up") + logger.info(f"Skipped older frame {frame_id} to catch up") # Get the processed frame from our output queue processed_frame_id, out_tensor = await self.processed_video_frames.get() if processed_frame_id != frame_id: - # logger.warning(f"Frame ID mismatch: expected {frame_id}, got {processed_frame_id}") + logger.warning(f"Frame ID mismatch: expected {frame_id}, got {processed_frame_id}") pass # Process the frame diff --git a/src/comfystream/client_api.py b/src/comfystream/client_api.py index f8725aed..2fd14bb1 100644 --- a/src/comfystream/client_api.py +++ b/src/comfystream/client_api.py @@ -106,9 +106,9 @@ async def run_prompt(self, prompt_index: int): # Wait for execution completion with timeout try: - # logger.info("Waiting for execution to complete (max 10 seconds)...") + logger.info("Waiting for execution to complete (max 10 seconds)...") await asyncio.wait_for(self.execution_complete_event.wait(), timeout=10.0) - # logger.info("Execution complete, ready for next frame") + logger.info("Execution complete, ready for next frame") except asyncio.TimeoutError: logger.error("Timeout waiting for execution, forcing continuation") self.execution_complete_event.set() @@ -139,9 +139,6 @@ async def _connect_websocket(self): 
logger.info(f"Connecting to WebSocket at {self.server_address}?clientId={self.client_id}") - # Set a reasonable timeout for connection - websocket_timeout = 10.0 # seconds - try: # Connect with proper error handling self.ws = await websockets.connect( @@ -222,12 +219,12 @@ async def _handle_text_message(self, message: str): data = json.loads(message) message_type = data.get("type", "unknown") - # logger.info(f"Received message type: {message_type}") + logger.debug(f"Received message type: {message_type}") + logger.debug(f"{data}") + ''' # Handle different message types if message_type == "status": - pass - ''' # Status message with comfy_ui's queue information queue_remaining = data.get("data", {}).get("queue_remaining", 0) exec_info = data.get("data", {}).get("exec_info", {}) @@ -235,20 +232,19 @@ async def _handle_text_message(self, message: str): logger.info("Queue empty, no active execution") else: logger.info(f"Queue status: {queue_remaining} items remaining") - ''' elif message_type == "progress": if "data" in data and "value" in data["data"]: progress = data["data"]["value"] max_value = data["data"].get("max", 100) # Log the progress for debugging - # logger.info(f"Progress: {progress}/{max_value}") + logger.info(f"Progress: {progress}/{max_value}") elif message_type == "execution_start": self.execution_started = True if "data" in data and "prompt_id" in data["data"]: self._prompt_id = data["data"]["prompt_id"] - # logger.info(f"Execution started for prompt {self._prompt_id}") + logger.info(f"Execution started for prompt {self._prompt_id}") elif message_type == "executing": self.execution_started = True @@ -257,18 +253,19 @@ async def _handle_text_message(self, message: str): self._prompt_id = data["data"]["prompt_id"] if "node" in data["data"]: node_id = data["data"]["node"] - # logger.info(f"Executing node: {node_id}") + logger.info(f"Executing node: {node_id}") elif message_type in ["execution_cached", "execution_error", "execution_complete", 
"execution_interrupted"]: - # logger.info(f"{message_type} message received for prompt {self._prompt_id}") + logger.info(f"{message_type} message received for prompt {self._prompt_id}") # self.execution_started = False # Always signal completion for these terminal states # self.execution_complete_event.set() - # logger.info(f"Set execution_complete_event from {message_type}") + logger.info(f"Set execution_complete_event from {message_type}") pass + ''' - elif message_type == "executed": + if message_type == "executed": # This is sent when a node is completely done if "data" in data and "node_id" in data["data"]: node_id = data["data"]["node_id"] @@ -283,50 +280,9 @@ async def _handle_text_message(self, message: str): elif self.execution_started and not self.execution_complete_event.is_set(): # Check if this was the last node if data.get("data", {}).get("remaining", 0) == 0: - # logger.info("All nodes executed but no tensor data received, forcing completion") # self.execution_complete_event.set() pass - - elif message_type == "executed_node" and "output" in data.get("data", {}): - node_id = data.get("data", {}).get("node_id") - output_data = data.get("data", {}).get("output", {}) - prompt_id = data.get("data", {}).get("prompt_id", "unknown") - - logger.info(f"Node {node_id} executed in prompt {prompt_id}") - - ''' - # Check if this is from ETN_SendImageWebSocket node - if "ui" in output_data and "images" in output_data["ui"]: - images_info = output_data["ui"]["images"] - logger.info(f"Found image output from ETN_SendImageWebSocket in node {node_id}") - - # Images will be received via binary websocket messages after this event - # The binary handler will take care of them - pass - - # Keep existing handling for tensor data - elif "ui" in output_data and "tensor" in output_data["ui"]: - tensor_info = output_data["ui"]["tensor"] - tensor_id = tensor_info.get("tensor_id", "unknown") - logger.info(f"Found tensor data with ID: {tensor_id} in node {node_id}") - - # Decode 
the tensor data - tensor_data = await self._decode_tensor_data(tensor_info) - if tensor_data is not None: - # Add to output queue without waiting to unblock event loop - tensor_cache.image_outputs.put_nowait(tensor_data) - logger.info(f"Added tensor to output queue, shape: {tensor_data.shape}") - - # IMPORTANT: Immediately signal that we can proceed with the next frame - # when we receive tensor data, don't wait - logger.info("Received tensor data, immediately signaling execution complete") - self.execution_complete_event.set() - logger.info("Set execution_complete_event after processing tensor data") - else: - logger.error("Failed to decode tensor data") - # Signal completion even if decoding failed to prevent hanging - self.execution_complete_event.set() - ''' + except json.JSONDecodeError: logger.error(f"Invalid JSON message: {message[:100]}...") except Exception as e: @@ -366,18 +322,18 @@ async def _handle_binary_message(self, binary_data): frame_id = None if hasattr(self, '_prompt_id') and self._prompt_id in self._frame_id_mapping: frame_id = self._frame_id_mapping.get(self._prompt_id) - # logger.info(f"Using frame_id {frame_id} from prompt_id {self._prompt_id}") + logger.info(f"Using frame_id {frame_id} from prompt_id {self._prompt_id}") elif hasattr(self, '_current_frame_id') and self._current_frame_id is not None: frame_id = self._current_frame_id - # logger.info(f"Using current frame_id {frame_id}") + logger.info(f"Using current frame_id {frame_id}") # Add to output queue - include frame_id if available if frame_id is not None: tensor_cache.image_outputs.put_nowait((frame_id, tensor)) - # logger.info(f"Added tensor with frame_id {frame_id} to output queue") + logger.debug(f"Added tensor with frame_id {frame_id} to output queue") else: tensor_cache.image_outputs.put_nowait(tensor) - #logger.info("Added tensor without frame_id to output queue") + logger.debug("Added tensor without frame_id to output queue") self.execution_complete_event.set() @@ -525,7 
+481,7 @@ async def _execute_prompt(self, prompt_index: int): # Map prompt_id to frame_id for later retrieval if frame_id is not None: self._frame_id_mapping[self._prompt_id] = frame_id - # logger.info(f"Mapped prompt_id {self._prompt_id} to frame_id {frame_id}") + logger.info(f"Mapped prompt_id {self._prompt_id} to frame_id {frame_id}") self.execution_started = True else: @@ -558,7 +514,7 @@ async def _send_tensor_via_websocket(self, tensor): # Prepare binary data if len(tensor.shape) == 4: # BCHW format (batch of images) if tensor.shape[0] > 1: - # logger.info(f"Taking first image from batch of {tensor.shape[0]}") + logger.info(f"Taking first image from batch of {tensor.shape[0]}") pass tensor = tensor[0] # Take first image if batch @@ -576,7 +532,7 @@ async def _send_tensor_via_websocket(self, tensor): tensor = torch.zeros(3, 512, 512) # Check tensor dimensions and log detailed info - # logger.info(f"Original tensor for WS: shape={tensor.shape}, min={tensor.min().item():.4f}, max={tensor.max().item():.4f}") + logger.info(f"Original tensor for WS: shape={tensor.shape}, min={tensor.min().item():.4f}, max={tensor.max().item():.4f}") # Always ensure consistent 512x512 dimensions ''' @@ -623,7 +579,7 @@ async def _send_tensor_via_websocket(self, tensor): # Send binary data via websocket await self.ws.send(full_data) - # logger.info(f"Sent tensor as PNG image via websocket with proper header, size: {len(full_data)} bytes, image dimensions: {img.size}") + logger.info(f"Sent tensor as PNG image via websocket with proper header, size: {len(full_data)} bytes, image dimensions: {img.size}") except Exception as e: logger.error(f"Error sending tensor via websocket: {e}") @@ -718,12 +674,12 @@ async def get_video_output(self): # Check if the result is a tuple with frame_id if isinstance(result, tuple) and len(result) == 2: frame_id, tensor = result - # logger.info(f"Got processed tensor from output queue with frame_id {frame_id}") + logger.info(f"Got processed tensor from 
output queue with frame_id {frame_id}") # Return both the frame_id and tensor to help with ordering in the pipeline return frame_id, tensor else: # If it's not a tuple with frame_id, just return the tensor - # logger.info("Got processed tensor from output queue without frame_id") + logger.info("Got processed tensor from output queue without frame_id") return result async def get_audio_output(self): diff --git a/src/comfystream/utils_api.py b/src/comfystream/utils_api.py index d207f4e3..ba31ccd4 100644 --- a/src/comfystream/utils_api.py +++ b/src/comfystream/utils_api.py @@ -15,7 +15,7 @@ def create_load_tensor_node(): "_meta": {"title": "Load Tensor (API)"}, } -def create_load_image_node(): +def create_load_image_base64_node(): return { "inputs": { "image": "" # Should be "image" not "image_data" to match LoadImageBase64 @@ -44,7 +44,7 @@ def create_save_tensor_node(inputs: Dict[Any, Any]): "_meta": {"title": "Save Tensor (API)"}, } -def create_save_image_node(inputs: Dict[Any, Any]): +def create_send_image_websocket_node(inputs: Dict[Any, Any]): # Get the correct image input reference images_input = inputs.get("images", inputs.get("image")) @@ -61,6 +61,22 @@ def create_save_image_node(inputs: Dict[Any, Any]): "_meta": {"title": "Send Image Websocket (ComfyStream)"}, } +def create_send_tensor_websocket_node(inputs: Dict[Any, Any]): + # Get the correct image input reference + tensor_input = inputs.get("images", inputs.get("tensor")) + + if not tensor_input: + logging.warning("No valid tensor input found for SendTensorWebSocket node") + tensor_input = ["", 0] # Default empty value + + return { + "inputs": { + "tensor": tensor_input + }, + "class_type": "SendTensorWebSocket", + "_meta": {"title": "Save Tensor WebSocket (ComfyStream)"}, + } + def convert_prompt(prompt): logging.info("Converting prompt: %s", prompt) @@ -100,7 +116,7 @@ def convert_prompt(prompt): num_primary_inputs += 1 elif class_type in ["LoadImage", "LoadImageBase64"]: num_inputs += 1 - elif 
class_type in ["PreviewImage", "SaveImage", "SendImageWebsocket"]: + elif class_type in ["PreviewImage", "SaveImage", "SendImageWebsocket", "SendTensorWebSocket"]: num_outputs += 1 # Only handle single primary input @@ -123,83 +139,16 @@ def convert_prompt(prompt): # Replace nodes with proper implementations for key in keys["PrimaryInputLoadImage"]: - prompt[key] = create_load_image_node() + prompt[key] = create_load_image_base64_node() if num_primary_inputs == 0 and len(keys["LoadImage"]) == 1: - prompt[keys["LoadImage"][0]] = create_load_image_node() + prompt[keys["LoadImage"][0]] = create_load_image_base64_node() for key in keys["PreviewImage"] + keys["SaveImage"]: node = prompt[key] - prompt[key] = create_save_image_node(node["inputs"]) + # prompt[key] = create_save_image_node(node["inputs"]) + prompt[key] = create_send_image_websocket_node(node["inputs"]) # TESTING # TODO: Validate the processed prompt input return prompt - -''' -def convert_prompt(prompt: PromptDictInput) -> Prompt: - # Validate the schema - Prompt.validate(prompt) - - prompt = copy.deepcopy(prompt) - - num_primary_inputs = 0 - num_inputs = 0 - num_outputs = 0 - - keys = { - "PrimaryInputLoadImage": [], - "LoadImage": [], - "PreviewImage": [], - "SaveImage": [], - } - - for key, node in prompt.items(): - class_type = node.get("class_type") - - # Collect keys for nodes that might need to be replaced - if class_type in keys: - keys[class_type].append(key) - - # Count inputs and outputs - if class_type == "PrimaryInputLoadImage": - num_primary_inputs += 1 - elif class_type in ["LoadImage", "LoadTensor", "LoadAudioTensor"]: - num_inputs += 1 - elif class_type in ["PreviewImage", "SaveImage", "SaveTensor", "SaveAudioTensor"]: - num_outputs += 1 - - # Only handle single primary input - if num_primary_inputs > 1: - raise Exception("too many primary inputs in prompt") - - # If there are no primary inputs, only handle single input - if num_primary_inputs == 0 and num_inputs > 1: - raise Exception("too 
many inputs in prompt") - - # Only handle single output for now - if num_outputs > 1: - raise Exception("too many outputs in prompt") - - if num_primary_inputs + num_inputs == 0: - raise Exception("missing input") - - if num_outputs == 0: - raise Exception("missing output") - - # Replace nodes - for key in keys["PrimaryInputLoadImage"]: - prompt[key] = create_load_tensor_node() - - if num_primary_inputs == 0 and len(keys["LoadImage"]) == 1: - prompt[keys["LoadImage"][0]] = create_load_tensor_node() - - for key in keys["PreviewImage"] + keys["SaveImage"]: - node = prompt[key] - prompt[key] = create_save_tensor_node(node["inputs"]) - - # Validate the processed prompt input - prompt = Prompt.validate(prompt) - - return prompt -''' \ No newline at end of file From c4d2ea177d0cb6a136e390140f8da380109d190d Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Tue, 1 Apr 2025 15:10:48 -0400 Subject: [PATCH 10/42] Setting a few logs to debug. --- server/pipeline_api.py | 2 +- src/comfystream/client_api.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/server/pipeline_api.py b/server/pipeline_api.py index 23381e81..805f8734 100644 --- a/server/pipeline_api.py +++ b/server/pipeline_api.py @@ -75,7 +75,7 @@ async def _collect_processed_frames(self): # Check if result is already a tuple with frame_id if isinstance(result, tuple) and len(result) == 2: frame_id, out_tensor = result - logger.info(f"Got result with embedded frame_id: {frame_id}") + logger.debug(f"Got result with embedded frame_id: {frame_id}") else: out_tensor = result # Find which original frame this corresponds to using our mapping diff --git a/src/comfystream/client_api.py b/src/comfystream/client_api.py index 2fd14bb1..d4a2ed7c 100644 --- a/src/comfystream/client_api.py +++ b/src/comfystream/client_api.py @@ -106,9 +106,9 @@ async def run_prompt(self, prompt_index: int): # Wait for execution completion with timeout try: - logger.info("Waiting for execution to complete (max 10 
seconds)...") + logger.debug("Waiting for execution to complete (max 10 seconds)...") await asyncio.wait_for(self.execution_complete_event.wait(), timeout=10.0) - logger.info("Execution complete, ready for next frame") + logger.debug("Execution complete, ready for next frame") except asyncio.TimeoutError: logger.error("Timeout waiting for execution, forcing continuation") self.execution_complete_event.set() @@ -322,10 +322,10 @@ async def _handle_binary_message(self, binary_data): frame_id = None if hasattr(self, '_prompt_id') and self._prompt_id in self._frame_id_mapping: frame_id = self._frame_id_mapping.get(self._prompt_id) - logger.info(f"Using frame_id {frame_id} from prompt_id {self._prompt_id}") + logger.debug(f"Using frame_id {frame_id} from prompt_id {self._prompt_id}") elif hasattr(self, '_current_frame_id') and self._current_frame_id is not None: frame_id = self._current_frame_id - logger.info(f"Using current frame_id {frame_id}") + logger.debug(f"Using current frame_id {frame_id}") # Add to output queue - include frame_id if available if frame_id is not None: From 0ee6222c89183f80ed6f1ee4d9ce55bbf0a26bf8 Mon Sep 17 00:00:00 2001 From: Buff Date: Tue, 1 Apr 2025 15:14:12 -0400 Subject: [PATCH 11/42] Update requirements.txt Co-authored-by: John | Elite Encoder --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 56fe8b22..24c6c32a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,5 +4,6 @@ aiortc aiohttp toml tomli +websockets twilio prometheus_client From 61a03fad16fdfee881c24ef677edc69b2292a06a Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Tue, 1 Apr 2025 16:45:31 -0400 Subject: [PATCH 12/42] Added native nodes into root nodes. 
--- nodes/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nodes/__init__.py b/nodes/__init__.py index 46aa93c6..4682f0d6 100644 --- a/nodes/__init__.py +++ b/nodes/__init__.py @@ -3,6 +3,7 @@ from .audio_utils import * from .tensor_utils import * from .video_stream_utils import * +from .native_utils import * from .api import * from .web import * @@ -11,7 +12,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {} # Import and update mappings from submodules -for module in [audio_utils, tensor_utils, video_stream_utils, api, web]: +for module in [audio_utils, tensor_utils, video_stream_utils, api, web, native_utils]: if hasattr(module, 'NODE_CLASS_MAPPINGS'): NODE_CLASS_MAPPINGS.update(module.NODE_CLASS_MAPPINGS) if hasattr(module, 'NODE_DISPLAY_NAME_MAPPINGS'): From 77e28ffd5205965e546389df4e238372a9784d87 Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Tue, 8 Apr 2025 14:54:58 -0400 Subject: [PATCH 13/42] Rebuilt get_available_nodes using native ComfyUI api for retrofit to the ui, cleanup of tensor code. 
--- server/pipeline_api.py | 17 +- src/comfystream/client_api.py | 326 +++++++++++++++++++--------------- src/comfystream/utils_api.py | 4 +- 3 files changed, 190 insertions(+), 157 deletions(-) diff --git a/server/pipeline_api.py b/server/pipeline_api.py index 805f8734..fdf8a14e 100644 --- a/server/pipeline_api.py +++ b/server/pipeline_api.py @@ -173,7 +173,7 @@ async def _check_frame_timeouts(self): # If we've waited too long, skip the missing frame(s) if wait_time_ms > self.max_frame_wait_ms: - logger.warning(f"Missing frame {self.next_expected_frame_id}, skipping to {oldest_frame_id}") + logger.debug(f"Missing frame {self.next_expected_frame_id}, skipping to {oldest_frame_id}") self.next_expected_frame_id = oldest_frame_id await self._release_ordered_frames() @@ -338,7 +338,7 @@ async def get_processed_video_frame(self): processed_frame_id, out_tensor = await self.processed_video_frames.get() if processed_frame_id != frame_id: - logger.warning(f"Frame ID mismatch: expected {frame_id}, got {processed_frame_id}") + logger.debug(f"Frame ID mismatch: expected {frame_id}, got {processed_frame_id}") pass # Process the frame @@ -373,13 +373,14 @@ async def get_processed_audio_frame(self): processed_frame.sample_rate = frame.sample_rate return processed_frame - + async def get_nodes_info(self) -> Dict[str, Any]: - """Get information about nodes from the first client""" - if not self.clients: - return {} - return await self.clients[0].get_available_nodes() - + """Get information about all nodes in the current prompt including metadata.""" + # Note that we pull the node info from the first client (as they should all be the same) + # TODO: This is just retrofitting the functionality of the comfy embedded client, there could be major improvements here + nodes_info = await self.clients[0].get_available_nodes() + return nodes_info + async def cleanup(self): """Clean up all clients and background tasks""" self.running = False diff --git a/src/comfystream/client_api.py 
b/src/comfystream/client_api.py index d4a2ed7c..f47ebb3f 100644 --- a/src/comfystream/client_api.py +++ b/src/comfystream/client_api.py @@ -496,97 +496,6 @@ async def _execute_prompt(self, prompt_index: int): logger.error(f"Error executing prompt: {e}") self.execution_complete_event.set() - async def _send_tensor_via_websocket(self, tensor): - """Send tensor data via the websocket connection""" - try: - if self.ws is None: - logger.error("WebSocket not connected, cannot send tensor") - self.execution_complete_event.set() # Prevent hanging - return - - # Convert the tensor to image format for sending - if isinstance(tensor, np.ndarray): - tensor = torch.from_numpy(tensor).float() - - # Ensure on CPU and correct format - tensor = tensor.detach().cpu().float() - - # Prepare binary data - if len(tensor.shape) == 4: # BCHW format (batch of images) - if tensor.shape[0] > 1: - logger.info(f"Taking first image from batch of {tensor.shape[0]}") - pass - tensor = tensor[0] # Take first image if batch - - # Ensure CHW format (3 channels) - if len(tensor.shape) == 3: - if tensor.shape[0] != 3 and tensor.shape[2] == 3: # HWC format - tensor = tensor.permute(2, 0, 1) # Convert to CHW - elif tensor.shape[0] != 3: - logger.warning(f"Tensor doesn't have 3 channels: {tensor.shape}. Creating standard tensor.") - # Create a standard RGB tensor - tensor = torch.zeros(3, 512, 512) - else: - logger.warning(f"Tensor has unexpected shape: {tensor.shape}. 
Creating standard tensor.") - # Create a standard RGB tensor - tensor = torch.zeros(3, 512, 512) - - # Check tensor dimensions and log detailed info - logger.info(f"Original tensor for WS: shape={tensor.shape}, min={tensor.min().item():.4f}, max={tensor.max().item():.4f}") - - # Always ensure consistent 512x512 dimensions - ''' - if tensor.shape[1] != 512 or tensor.shape[2] != 512: - logger.info(f"Resizing tensor from {tensor.shape} to standard 512x512") - import torch.nn.functional as F - tensor = tensor.unsqueeze(0) # Add batch dimension for interpolate - tensor = F.interpolate(tensor, size=(512, 512), mode='bilinear', align_corners=False) - tensor = tensor.squeeze(0) # Remove batch dimension after resize - ''' - - # Check for NaN or Inf values - if torch.isnan(tensor).any() or torch.isinf(tensor).any(): - logger.warning("Tensor contains NaN or Inf values! Replacing with zeros.") - tensor = torch.nan_to_num(tensor, nan=0.0, posinf=1.0, neginf=0.0) - - # Convert to image (HWC for PIL) - tensor_np = (tensor.permute(1, 2, 0) * 255).clamp(0, 255).numpy().astype(np.uint8) - img = Image.fromarray(tensor_np) - - logger.info(f"Converted to PIL image with dimensions: {img.size}") - - # Convert to PNG - buffer = BytesIO() - img.save(buffer, format="PNG") - buffer.seek(0) - img_bytes = buffer.getvalue() - - # CRITICAL FIX: We need to send the binary data with a proper node ID prefix - # LoadTensorAPI node expects this header format to identify the target node - # The first 4 bytes are the message type (3 for binary tensor) and the next 4 are the node ID - # Since we don't know the exact node ID, we'll use a generic one that will be interpreted as - # "send this to the currently waiting LoadTensorAPI node" - - # Build header (8 bytes total) - header = bytearray() - # Message type 3 (custom binary tensor data) - header.extend((3).to_bytes(4, byteorder='little')) - # Generic node ID (0 means "send to whatever node is waiting") - header.extend((0).to_bytes(4, 
byteorder='little')) - - # Combine header and image data - full_data = header + img_bytes - - # Send binary data via websocket - await self.ws.send(full_data) - logger.info(f"Sent tensor as PNG image via websocket with proper header, size: {len(full_data)} bytes, image dimensions: {img.size}") - - except Exception as e: - logger.error(f"Error sending tensor via websocket: {e}") - - # Signal execution complete in case of error - self.execution_complete_event.set() - async def cleanup(self): """Clean up resources""" async with self.cleanup_lock: @@ -642,26 +551,10 @@ async def cleanup_queues(self): logger.info("Tensor queues cleared") - def put_video_input(self, tensor: Union[torch.Tensor, np.ndarray]): - """ - Put a video TENSOR into the tensor cache for processing. - - Args: - tensor: Video frame as a tensor (or numpy array) - """ - try: - # Only remove one frame if the queue is full (like in client.py) - if tensor_cache.image_inputs.full(): - tensor_cache.image_inputs.get_nowait() - - # Ensure tensor is detached if it's a torch tensor - if isinstance(tensor, torch.Tensor): - tensor = tensor.detach() - - tensor_cache.image_inputs.put(tensor) - - except Exception as e: - logger.error(f"Error in put_video_input: {e}") + def put_video_input(self, frame): + if tensor_cache.image_inputs.full(): + tensor_cache.image_inputs.get(block=True) + tensor_cache.image_inputs.put(frame) def put_audio_input(self, frame): """Put audio frame into tensor cache""" @@ -685,44 +578,183 @@ async def get_video_output(self): async def get_audio_output(self): """Get processed audio frame from tensor cache""" return await tensor_cache.audio_outputs.get() + + async def get_available_nodes(self) -> Dict[int, Dict[str, Any]]: + """ + Retrieves detailed information about the nodes used in the current prompts + by querying the ComfyUI /object_info API endpoint. 
+ + Returns: + A dictionary where keys are prompt indices and values are dictionaries + mapping node IDs to their information, matching the required UI format. - async def get_available_nodes(self): - """Get metadata and available nodes info for current prompts""" - try: - async with aiohttp.ClientSession() as session: - url = f"{self.api_base_url}/object_info" - async with session.get(url) as response: - if response.status == 200: - data = await response.json() - - # Format node info similar to the embedded client response - all_prompts_nodes_info = {} - - for prompt_index, prompt in enumerate(self.current_prompts): - nodes_info = {} + The idea of this function is to replicate the functionality of comfy embedded client import_all_nodes_in_workspace + TODO: Why not support ckpt_name and lora_name as dropdown selectors on UI? + """ + + if not self.current_prompts: + logger.warning("No current prompts set. Cannot get node info.") + return {} + + all_prompts_nodes_info: Dict[int, Dict[str, Any]] = {} + all_needed_class_types = set() + + # Collect all unique class types across all prompts first + for prompt in self.current_prompts: + for node in prompt.values(): + if isinstance(node, dict) and 'class_type' in node: + all_needed_class_types.add(node['class_type']) + + class_info_cache: Dict[str, Any] = {} + + async with aiohttp.ClientSession() as session: + fetch_tasks = [] + for class_type in all_needed_class_types: + api_url = f"{self.api_base_url}/object_info/{class_type}" + fetch_tasks.append(self._fetch_object_info(session, api_url, class_type)) + + results = await asyncio.gather(*fetch_tasks, return_exceptions=True) + + # Populate cache from results + for result in results: + if isinstance(result, tuple) and len(result) == 2: + class_type, info = result + if info: + class_info_cache[class_type] = info + elif isinstance(result, Exception): + logger.error(f"An exception occurred during object_info fetch task: {result}") + + # Now, build the output structure for each 
prompt + for prompt_index, prompt in enumerate(self.current_prompts): + nodes_info: Dict[str, Any] = {} + for node_id, node_data in prompt.items(): + if not isinstance(node_data, dict) or 'class_type' not in node_data: + logger.debug(f"Skipping invalid node data for node_id {node_id} in prompt {prompt_index}") + continue + + class_type = node_data['class_type'] + # Let's skip the native api i/o nodes for now, subject to change + if class_type in ['LoadImageBase64', 'SendImageWebsocket']: + continue + + node_info = { + 'class_type': class_type, + 'inputs': {} + } + + specific_class_info = class_info_cache.get(class_type) + + if specific_class_info and 'input' in specific_class_info: + input_definitions = {} + required_inputs = specific_class_info['input'].get('required', {}) + optional_inputs = specific_class_info['input'].get('optional', {}) + + if isinstance(required_inputs, dict): + input_definitions.update(required_inputs) + if isinstance(optional_inputs, dict): + input_definitions.update(optional_inputs) + + if 'inputs' in node_data and isinstance(node_data['inputs'], dict): + for input_name, input_value in node_data['inputs'].items(): + input_def = input_definitions.get(input_name) - for node_id, node in prompt.items(): - class_type = node.get('class_type') - if class_type: - nodes_info[node_id] = { - 'class_type': class_type, - 'inputs': {} - } - - if 'inputs' in node: - for input_name, input_value in node['inputs'].items(): - nodes_info[node_id]['inputs'][input_name] = { - 'value': input_value, - 'type': 'unknown' # We don't have type information - } + # Format the input value as a tuple if it's a list with node references + if isinstance(input_value, list) and len(input_value) == 2 and isinstance(input_value[0], str) and isinstance(input_value[1], int): + input_value = tuple(input_value) # Convert [node_id, output_index] to (node_id, output_index) + + # Create Enum-like objects for certain types + def create_enum_format(type_name): + # Format the type as + 
return f"" - all_prompts_nodes_info[prompt_index] = nodes_info - - return all_prompts_nodes_info - + input_details = { + 'value': input_value, + 'type': 'unknown', # Default type + 'min': None, + 'max': None, + 'widget': None # Default, all widgets should be None to match format + } + + # Parse the definition tuple/list if valid + if isinstance(input_def, (list, tuple)) and len(input_def) > 0: + config = None + # Check for config dict as the second element + if len(input_def) > 1 and isinstance(input_def[1], dict): + config = input_def[1] + + # Check for COMBO type (first element is list/tuple of options) + if input_name in ['ckpt_name', 'lora_name']: + # For checkpoint and lora names, use STRING type instead of combo list + input_details['type'] = create_enum_format('STRING') + elif isinstance(input_def[0], (list, tuple)): + input_details['type'] = input_def[0] # Type is the list of options + # Don't set widget for combo + else: + # Regular type (string or enum) + input_type_raw = input_def[0] + # Keep raw type name for certain types to match format + if hasattr(input_type_raw, 'name'): + # Special handling for CLIP and STRING to match expected format + type_name = str(input_type_raw.name) + if type_name in ('CLIP', 'STRING'): + # Create Enum-like format that matches format in desired output + input_details['type'] = create_enum_format(type_name) + else: + input_details['type'] = type_name + else: + # For non-enum types + input_details['type'] = str(input_type_raw) + + # Extract constraints/widget from config if it exists + if config: + for key in ['min', 'max']: # Only include these, skip widget/step/round + if key in config: + input_details[key] = config[key] + + node_info['inputs'][input_name] = input_details else: - logger.error(f"Error getting node info: {response.status}") - return {} + logger.debug(f"Node {node_id} ({class_type}) has no 'inputs' dictionary.") + elif class_type not in class_info_cache: + logger.warning(f"No cached info found for class_type: 
{class_type} (node_id: {node_id}).") + else: + logger.debug(f"Class info for {class_type} does not contain an 'input' key.") + # If class info exists but no 'input' key, still add node with empty inputs dict + + nodes_info[node_id] = node_info + + # Only add if there are any nodes after filtering + if nodes_info: + all_prompts_nodes_info[prompt_index] = nodes_info + + return all_prompts_nodes_info + + async def _fetch_object_info(self, session: aiohttp.ClientSession, url: str, class_type: str) -> Optional[tuple[str, Any]]: + """Helper function to fetch object info for a single class type.""" + try: + logger.debug(f"Fetching object info for: {class_type} from {url}") + async with session.get(url) as response: + if response.status == 200: + try: + data = await response.json() + # Extract the actual node info from the nested structure + if class_type in data and isinstance(data[class_type], dict): + node_specific_info = data[class_type] + logger.debug(f"Successfully fetched and extracted info for {class_type}") + return class_type, node_specific_info + else: + logger.error(f"Unexpected response structure for {class_type}. Key missing or not a dict. Response: {data}") + + except aiohttp.ContentTypeError: + logger.error(f"Failed to decode JSON for {class_type}. Status: {response.status}, Content-Type: {response.headers.get('Content-Type')}, Response: {await response.text()[:200]}...") # Log beginning of text + except json.JSONDecodeError as e: + logger.error(f"Invalid JSON received for {class_type}. 
Status: {response.status}, Error: {e}, Response: {await response.text()[:200]}...") + else: + error_text = await response.text() + logger.error(f"Error fetching info for {class_type}: {response.status} - {error_text[:200]}...") + except aiohttp.ClientError as e: + logger.error(f"HTTP client error fetching info for {class_type} ({url}): {e}") except Exception as e: - logger.error(f"Error getting node info: {str(e)}") - return {} \ No newline at end of file + logger.error(f"Unexpected error fetching info for {class_type} ({url}): {e}") + + # Return class_type and None if any error occurred + return class_type, None \ No newline at end of file diff --git a/src/comfystream/utils_api.py b/src/comfystream/utils_api.py index ba31ccd4..dbbb6790 100644 --- a/src/comfystream/utils_api.py +++ b/src/comfystream/utils_api.py @@ -2,9 +2,9 @@ import random from typing import Dict, Any -# from comfy.api.components.schema.prompt import Prompt, PromptDictInput import logging +logger = logging.getLogger(__name__) def create_load_tensor_node(): return { @@ -102,7 +102,7 @@ def convert_prompt(prompt): # Generate a random seed (same range as JavaScript's Math.random() * 18446744073709552000) random_seed = random.randint(0, 18446744073709551615) node["inputs"]["seed"] = random_seed - print(f"Set random seed {random_seed} for node {key}") + logger.debug(f"Set random seed {random_seed} for node {key}") for key, node in prompt.items(): class_type = node.get("class_type") From 13b511a3de0ce48dbadcfb9328452a78e7b4e846 Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Tue, 8 Apr 2025 16:39:35 -0400 Subject: [PATCH 14/42] Modified base64 processing to use torchvision instead of numpy intermediate step, moved prompt execution strategy to `execution_start` event, moved buffer to self variable to avoid reinitalization. 
--- requirements.txt | 1 + src/comfystream/client_api.py | 134 ++++++++++++++++++++-------------- 2 files changed, 80 insertions(+), 55 deletions(-) diff --git a/requirements.txt b/requirements.txt index 24c6c32a..f8d1b46d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ tomli websockets twilio prometheus_client +torchvision diff --git a/src/comfystream/client_api.py b/src/comfystream/client_api.py index f47ebb3f..5dc9ace1 100644 --- a/src/comfystream/client_api.py +++ b/src/comfystream/client_api.py @@ -15,11 +15,12 @@ from comfystream import tensor_cache from comfystream.utils_api import convert_prompt +from torchvision.transforms.functional import to_pil_image logger = logging.getLogger(__name__) class ComfyStreamClient: - def __init__(self, host: str = "127.0.0.1", port: int = 8198, **kwargs): + def __init__(self, host: str = "127.0.0.1", port: int = 8188, **kwargs): """ Initialize the ComfyStream client to use the ComfyUI API. @@ -38,11 +39,10 @@ def __init__(self, host: str = "127.0.0.1", port: int = 8198, **kwargs): self.current_prompts = [] self.running_prompts = {} self.cleanup_lock = asyncio.Lock() - + self.buffer = BytesIO() # WebSocket connection self._ws_listener_task = None self.execution_complete_event = asyncio.Event() - self.execution_started = False self._prompt_id = None # Add frame tracking @@ -219,11 +219,22 @@ async def _handle_text_message(self, message: str): data = json.loads(message) message_type = data.get("type", "unknown") - logger.debug(f"Received message type: {message_type}") - logger.debug(f"{data}") - + # logger.info(f"Received message type: {message_type}") + logger.info(f"{data}") + + # Example output + ''' + 15:15:58 [INFO] Received message type: executing + 15:15:58 [INFO] {'type': 'executing', 'data': {'node': '18', 'display_node': '18', 'prompt_id': '6f983049-dca4-4935-9f36-d2bff7b744fa'}} + 15:15:58 [INFO] Received message type: executed + 15:15:58 [INFO] {'type': 'executed', 'data': {'node': '18', 
'display_node': '18', 'output': {'images': [{'source': 'websocket', 'content-type': 'image/png', 'type': 'output'}]}, 'prompt_id': '6f983049-dca4-4935-9f36-d2bff7b744fa'}} + 15:15:58 [INFO] Received message type: execution_success + 15:15:58 [INFO] {'type': 'execution_success', 'data': {'prompt_id': '6f983049-dca4-4935-9f36-d2bff7b744fa', 'timestamp': 1744139758250}} + ''' + + # Handle different message types to have fun with! + ''' - # Handle different message types if message_type == "status": # Status message with comfy_ui's queue information queue_remaining = data.get("data", {}).get("queue_remaining", 0) @@ -232,56 +243,60 @@ async def _handle_text_message(self, message: str): logger.info("Queue empty, no active execution") else: logger.info(f"Queue status: {queue_remaining} items remaining") - - elif message_type == "progress": + ''' + + ''' + if message_type == "progress": if "data" in data and "value" in data["data"]: progress = data["data"]["value"] max_value = data["data"].get("max", 100) # Log the progress for debugging logger.info(f"Progress: {progress}/{max_value}") - - elif message_type == "execution_start": - self.execution_started = True + ''' + + if message_type == "execution_start": if "data" in data and "prompt_id" in data["data"]: self._prompt_id = data["data"]["prompt_id"] logger.info(f"Execution started for prompt {self._prompt_id}") - - elif message_type == "executing": - self.execution_started = True + + # Let's queue the next prompt here! 
+ self.execution_complete_event.set() + + ''' + if message_type == "executing": if "data" in data: if "prompt_id" in data["data"]: self._prompt_id = data["data"]["prompt_id"] if "node" in data["data"]: node_id = data["data"]["node"] logger.info(f"Executing node: {node_id}") - - elif message_type in ["execution_cached", "execution_error", "execution_complete", "execution_interrupted"]: - logger.info(f"{message_type} message received for prompt {self._prompt_id}") - # self.execution_started = False - + + # Let's check which node_id is a LoadImageBase64 node + # and set the execution complete event for that node + for prompt_index, prompt in enumerate(self.current_prompts): + for node_id, node in prompt.items(): + if (node_id == executing_node_id and isinstance(node, dict) and node.get("class_type") in ["LoadImageBase64"]): + logger.info(f"Setting execution complete event for LoadImageBase64 node {node_id}") + self.execution_complete_event.set() + break + ''' + + ''' + if message_type == "executed": + # This is sent when a node is completely done + if "data" in data and "node" in data["data"]: + node_id = data["data"]["node"] + logger.info(f"Node execution complete: {node_id}") + ''' + + ''' + if message_type in ["execution_cached", "execution_error", "execution_complete", "execution_interrupted"]: + logger.info(f"{message_type} message received for prompt {self._prompt_id}") # Always signal completion for these terminal states # self.execution_complete_event.set() logger.info(f"Set execution_complete_event from {message_type}") pass ''' - - if message_type == "executed": - # This is sent when a node is completely done - if "data" in data and "node_id" in data["data"]: - node_id = data["data"]["node_id"] - logger.info(f"Node execution complete: {node_id}") - - # Check if this is our SaveTensorAPI node - if "SaveTensorAPI" in str(node_id): - logger.info("SaveTensorAPI node executed, checking for tensor data") - # The binary data should come separately via websocket - - 
# If we've been running for too long without tensor data, force completion - elif self.execution_started and not self.execution_complete_event.is_set(): - # Check if this was the last node - if data.get("data", {}).get("remaining", 0) == 0: - # self.execution_complete_event.set() - pass except json.JSONDecodeError: logger.error(f"Invalid JSON message: {message[:100]}...") @@ -295,7 +310,7 @@ async def _handle_binary_message(self, binary_data): try: # Early return if message is too short if len(binary_data) <= 8: - self.execution_complete_event.set() + # self.execution_complete_event.set() return # Extract header data only when needed @@ -306,7 +321,7 @@ async def _handle_binary_message(self, binary_data): # Quick check for image format is_image = data[:2] in [b'\xff\xd8', b'\x89\x50'] if not is_image: - self.execution_complete_event.set() + # self.execution_complete_event.set() return # Process image data directly @@ -335,15 +350,16 @@ async def _handle_binary_message(self, binary_data): tensor_cache.image_outputs.put_nowait(tensor) logger.debug("Added tensor without frame_id to output queue") - self.execution_complete_event.set() + # We will execute the next prompt from message_type == "execution_start" instead + # self.execution_complete_event.set() except Exception as img_error: logger.error(f"Error processing image: {img_error}") - self.execution_complete_event.set() + # self.execution_complete_event.set() except Exception as e: logger.error(f"Error handling binary message: {e}") - self.execution_complete_event.set() + # self.execution_complete_event.set() async def _execute_prompt(self, prompt_index: int): try: @@ -430,25 +446,35 @@ async def _execute_prompt(self, prompt_index: int): tensor, size=(512, 512), mode='bilinear', align_corners=False ) tensor = tensor[0] # Remove batch dimension - + + # ==== + # PIL method + ''' # Direct conversion to PIL without intermediate numpy step for speed tensor_np = (tensor.permute(1, 2, 0).clamp(0, 1) * 
255).to(torch.uint8).numpy() img = Image.fromarray(tensor_np) + img.save(self.buffer, format="JPEG", quality=90, optimize=True) + ''' - # Fast JPEG encoding with balanced quality - buffer = BytesIO() - img.save(buffer, format="JPEG", quality=90, optimize=True) - buffer.seek(0) - img_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8') + # ==== + # torchvision method (more performant - TODO: need to test further) + # Direct conversion to PIL without intermediate numpy step + # Fast JPEG encoding with reduced quality for better performance + tensor_pil = to_pil_image(tensor.clamp(0, 1)) + tensor_pil.save(self.buffer, format="JPEG", quality=75, optimize=True) + # ==== + + self.buffer.seek(0) + img_base64 = base64.b64encode(self.buffer.getvalue()).decode('utf-8') except Exception as e: logger.warning(f"Error in tensor processing: {e}, creating fallback image") # Create a standard 512x512 placeholder if anything fails img = Image.new('RGB', (512, 512), color=(100, 149, 237)) - buffer = BytesIO() - img.save(buffer, format="JPEG", quality=90) - buffer.seek(0) - img_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8') + self.buffer = BytesIO() + img.save(self.buffer, format="JPEG", quality=90) + self.buffer.seek(0) + img_base64 = base64.b64encode(self.buffer.getvalue()).decode('utf-8') # Add timestamp for cache busting (once, outside the try/except) timestamp = int(time.time() * 1000) @@ -482,8 +508,6 @@ async def _execute_prompt(self, prompt_index: int): if frame_id is not None: self._frame_id_mapping[self._prompt_id] = frame_id logger.info(f"Mapped prompt_id {self._prompt_id} to frame_id {frame_id}") - - self.execution_started = True else: error_text = await response.text() logger.error(f"Error queueing prompt: {response.status} - {error_text}") From c9aff80fbdac2a52106a1f7b7ae193b72891e709 Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Tue, 15 Apr 2025 19:06:01 -0400 Subject: [PATCH 15/42] Built Comfy subprocess spawn client mode, built dynamic 
output pacer to improve frame buffer, modified comfy arg handling. --- server/app_api.py | 59 +++++++++- server/pipeline_api.py | 205 +++++++++++++++++++++++++++++----- src/comfystream/client_api.py | 177 ++++++++++++++++++++++++++--- 3 files changed, 388 insertions(+), 53 deletions(-) diff --git a/server/app_api.py b/server/app_api.py index f5010790..7fd03cc9 100644 --- a/server/app_api.py +++ b/server/app_api.py @@ -202,8 +202,34 @@ async def offer(request): pipeline = request.app["pipeline"] pcs = request.app["pcs"] + # Check if clients are initialized, and initialize them if not + if not pipeline.clients: + logger.info("Clients not initialized yet, starting clients...") + await pipeline.start_clients() + # Check if any clients with spawn=True need to have servers started + elif pipeline.client_mode == "spawn": + start_tasks = [] + for client in pipeline.clients: + if client.spawn and (not hasattr(client, '_comfyui_proc') or client._comfyui_proc is None): + start_tasks.append(client.start_server()) + + # Start any servers that need to be started + if start_tasks: + logger.info(f"Starting ComfyUI servers for new workflow...") + await asyncio.gather(*start_tasks) + logger.info(f"Started {len(start_tasks)} ComfyUI servers") + + # Get parameters params = await request.json() - + + # When a client reconnects after refresh, we need to clear certain pipeline state + # but NOT restart the ComfyUI servers/clients + # Reset the frame tracking, but keep the servers running + pipeline.next_expected_frame_id = None + pipeline.ordered_frames.clear() + pipeline.next_frame_id = 1 # Reset frame ID counter for new connection + pipeline.client_frame_mapping.clear() + await pipeline.set_prompts(params["prompts"]) offer_params = params["offer"] @@ -369,17 +395,25 @@ async def on_startup(app: web.Application): if app["media_ports"]: patch_loop_datagram(app["media_ports"]) + # ComfyUI args have been moved to the client constructor app["pipeline"] = Pipeline( width=512, height=512, - 
cwd=app["workspace"], - disable_cuda_malloc=True, - gpu_only=True, - preview_method='none', comfyui_inference_log_level=app.get("comfui_inference_log_level", None), config_path=app["config_file"], max_frame_wait_ms=app["max_frame_wait"], + client_mode=app["client_mode"], + workspace=app["workspace"], + workers=app["workers"], ) + + # Start the clients during initialization + # await app["pipeline"].start_clients() + + # Wait for pipeline startup to complete (which starts the ComfyUI servers) + if hasattr(app["pipeline"], "startup_task"): + await app["pipeline"].startup_task + app["pcs"] = set() app["video_tracks"] = {} @@ -392,7 +426,6 @@ async def on_shutdown(app: web.Application): await asyncio.gather(*coros) pcs.clear() - if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run comfystream server") parser.add_argument("--port", default=8889, help="Set the signaling port") @@ -446,6 +479,18 @@ async def on_shutdown(app: web.Application): choices=logging._nameToLevel.keys(), help="Set the logging level for ComfyUI inference", ) + parser.add_argument( + "--client-mode", + choices=["toml", "spawn"], + default="toml", + help="How to create ComfyUI clients: 'toml' (from config file) or 'spawn' (spawn processes directly)", + ) + parser.add_argument( + "--workers", + type=int, + default=2, + help="Number of worker processes to spawn when using --client-mode=spawn" + ) args = parser.parse_args() logging.basicConfig( @@ -462,6 +507,8 @@ async def on_shutdown(app: web.Application): app["workspace"] = args.workspace app["config_file"] = args.config_file app["max_frame_wait"] = args.max_frame_wait + app["client_mode"] = args.client_mode + app["workers"] = args.workers app.on_startup.append(on_startup) app.on_shutdown.append(on_shutdown) diff --git a/server/pipeline_api.py b/server/pipeline_api.py index de802ba8..aa29b9d0 100644 --- a/server/pipeline_api.py +++ b/server/pipeline_api.py @@ -6,10 +6,13 @@ import time import random from collections import 
OrderedDict +import collections +import os +import socket from typing import Any, Dict, Union, List, Optional, Deque from comfystream.client_api import ComfyStreamClient -from utils import temporary_log_level +from utils import temporary_log_level # Not sure exactly what this does from config import ComfyConfig WARMUP_RUNS = 5 @@ -17,31 +20,52 @@ class MultiServerPipeline: - def __init__(self, width=512, height=512, comfyui_inference_log_level: int = None, config_path: Optional[str] = None, max_frame_wait_ms: int = 500, **kwargs): + def __init__( + self, + width: int = 512, + height: int = 512, + workers: int = 2, + comfyui_inference_log_level: int = None, + config_path: Optional[str] = None, + max_frame_wait_ms: int = 500, + client_mode: str = "toml", + workspace: str = None + ): """Initialize the pipeline with the given configuration. Args: width: The width of the video frames. height: The height of the video frames. + workers: The number of ComfyUI clients to spin up (if client_mode is "spawn"). comfyui_inference_log_level: The logging level for ComfyUI inference. Defaults to None, using the global ComfyUI log level. - **kwargs: Additional arguments to pass to the ComfyStreamClient + config_path: The path to the ComfyUI config toml file (if client_mode is "toml"). + max_frame_wait_ms: The maximum number of milliseconds to wait for a frame before dropping it. + client_mode: The mode to use for the ComfyUI clients. + "toml": Use a config file to describe clients. + "spawn": Spawn ComfyUI clients as external processes. """ - - # Load server configurations - self.config = ComfyConfig(config_path) - self.servers = self.config.get_servers() - - # Create client for each server + + # There are two methods for starting the clients: + # 1. client_mode == "toml" -> Use a config file to describe clients. + # 2. client_mode == "spawn" -> Spawn ComfyUI clients as external processes. 
+ self.clients = [] - for server_config in self.servers: - client_kwargs = kwargs.copy() - client_kwargs.update(server_config) - self.clients.append(ComfyStreamClient(**client_kwargs)) - - self.width = kwargs.get("width", 512) - self.height = kwargs.get("height", 512) + self.workspace = workspace + self.client_mode = client_mode - logger.info(f"Initialized {len(self.clients)} ComfyUI clients") + if (client_mode == "toml"): + # Load server configurations + self.config = ComfyConfig(config_path) + self.servers = self.config.get_servers() + elif (client_mode == "spawn"): + # Set the number of workers to spawn + self.workers = workers + + # Started in /offer + # self.start_clients() + + self.width = width + self.height = height self.video_incoming_frames = asyncio.Queue() self.audio_incoming_frames = asyncio.Queue() @@ -72,6 +96,11 @@ def __init__(self, width=512, height=512, comfyui_inference_log_level: int = Non self.running = True self.collector_task = asyncio.create_task(self._collect_processed_frames()) + self.output_interval = 1/30 # Start with 30 FPS + self.last_output_time = None + self.frame_interval_history = collections.deque(maxlen=30) + self.output_pacer_task = asyncio.create_task(self._dynamic_output_pacer()) + async def _collect_processed_frames(self): """Background task to collect processed frames from all clients""" try: @@ -141,26 +170,15 @@ async def _add_frame_to_ordered_buffer(self, frame_id, timestamp, tensor): await self._release_ordered_frames() async def _release_ordered_frames(self): - """Process ordered frames and put them in the output queue""" - # If we don't have a next expected frame yet, can't do anything if self.next_expected_frame_id is None: return - - # Check if the next expected frame is in our buffer - while self.ordered_frames and self.next_expected_frame_id in self.ordered_frames: - # Get the frame + if self.ordered_frames and self.next_expected_frame_id in self.ordered_frames: timestamp, tensor = 
self.ordered_frames.pop(self.next_expected_frame_id) - - # Put it in the output queue await self.processed_video_frames.put((self.next_expected_frame_id, tensor)) logger.info(f"Released frame {self.next_expected_frame_id} to output queue") - - # Update the next expected frame ID to the next sequential ID if possible - # (or the lowest frame ID in our buffer) if self.ordered_frames: self.next_expected_frame_id = min(self.ordered_frames.keys()) else: - # If no more frames, keep the last ID + 1 as next expected self.next_expected_frame_id += 1 async def _check_frame_timeouts(self): @@ -349,12 +367,13 @@ async def get_processed_video_frame(self): frame_id, frame = await self.video_incoming_frames.get() # Skip frames if we're falling behind + ''' while not self.video_incoming_frames.empty(): # Get newer frame and mark old one as skipped frame.side_data.skipped = True frame_id, frame = await self.video_incoming_frames.get() logger.info(f"Skipped older frame {frame_id} to catch up") - + ''' # Get the processed frame from our output queue processed_frame_id, out_tensor = await self.processed_video_frames.get() @@ -422,6 +441,132 @@ async def cleanup(self): await asyncio.gather(*cleanup_tasks) logger.info("All clients cleaned up") + async def _dynamic_output_pacer(self): + while self.running: + # Only release if the next expected frame is available + if self.next_expected_frame_id is not None and self.next_expected_frame_id in self.ordered_frames: + timestamp, tensor = self.ordered_frames.pop(self.next_expected_frame_id) + now = time.time() + + # Calculate dynamic interval based on output history + if self.last_output_time is not None: + actual_interval = now - self.last_output_time + self.frame_interval_history.append(actual_interval) + avg_interval = sum(self.frame_interval_history) / len(self.frame_interval_history) + self.output_interval = avg_interval + self.last_output_time = now + + await self.processed_video_frames.put((self.next_expected_frame_id, tensor)) + 
logger.info(f"Released frame {self.next_expected_frame_id} to output queue") + + # Update next expected frame ID + if self.ordered_frames: + self.next_expected_frame_id = min(self.ordered_frames.keys()) + else: + self.next_expected_frame_id += 1 + + # Sleep for the dynamic interval, but don't sleep negative time + await asyncio.sleep(max(self.output_interval, 0.001)) + else: + # No frame ready, wait a bit and check again + await asyncio.sleep(0.005) + async def start_clients(self): + """Start the clients based on the client_mode (TOML or spawn)""" + logger.info(f"Starting clients with mode: {self.client_mode}") + + self.clients = [] + + if hasattr(self, 'client_mode') and self.client_mode == "toml": + # Use config file to create clients + for server_config in self.servers: + client_kwargs = server_config.copy() + self.clients.append(ComfyStreamClient(**client_kwargs)) + + elif hasattr(self, 'client_mode') and self.client_mode == "spawn": + # Spin up clients as external processes + ports = [8195 + i for i in range(self.workers)] + + for i in range(self.workers): + client = ComfyStreamClient( + host="127.0.0.1", + port=ports[i], + spawn=True, + comfyui_path=os.path.join(self.workspace, "main.py"), + workspace=self.workspace, + comfyui_args=[ + "--disable-cuda-malloc", + "--gpu-only", + "--preview-method", "none", + "--listen", + "--cuda-device", "0", + "--fast", + "--enable-cors-header", "*", + "--port", str(ports[i]), + "--disable-xformers", + ], + ) + self.clients.append(client) + + else: + raise ValueError(f"Unknown client_mode: {getattr(self, 'client_mode', 'None')}") + + # Start all ComfyUI servers in parallel if in spawn mode + if hasattr(self, 'client_mode') and self.client_mode == "spawn": + # First, launch all server processes in parallel + for client in self.clients: + if client.spawn: + client._launch_comfyui_server() + + # Now create async functions to check server readiness + async def check_server_ready(client, timeout=60, check_interval=0.5): + 
"""Async version of waiting for server to be ready""" + logger.info(f"Waiting for ComfyUI server on port {client.port} to be ready...") + + start_time = time.time() + while time.time() - start_time < timeout: + # Check if process is still running + if client._comfyui_proc and client._comfyui_proc.poll() is not None: + return_code = client._comfyui_proc.poll() + logger.error(f"ComfyUI process exited with code {return_code} before it was ready") + raise RuntimeError(f"ComfyUI process exited with code {return_code}") + + # Try to connect to the server + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(2) + result = sock.connect_ex((client.host, client.port)) + sock.close() + + if result == 0: + logger.info(f"ComfyUI server on port {client.port} is now accepting connections") + return + except Exception: + pass + + # Sleep and try again + await asyncio.sleep(check_interval) + + # If we get here, the server didn't start in time + logger.error(f"Timed out waiting for ComfyUI server on port {client.port}") + if client._comfyui_proc: + client._comfyui_proc.terminate() + client._comfyui_proc = None + raise RuntimeError(f"Timed out waiting for ComfyUI server on port {client.port}") + + # Wait for all servers to be ready in parallel + wait_tasks = [] + for client in self.clients: + if client.spawn: + wait_tasks.append(check_server_ready(client)) + + if wait_tasks: + logger.info(f"Waiting for {len(wait_tasks)} ComfyUI servers to become ready...") + await asyncio.gather(*wait_tasks) + logger.info(f"All {len(wait_tasks)} ComfyUI servers are ready") + + logger.info(f"Initialized {len(self.clients)} clients") + return self.clients + # For backwards compatibility, maintain the original Pipeline name Pipeline = MultiServerPipeline \ No newline at end of file diff --git a/src/comfystream/client_api.py b/src/comfystream/client_api.py index 1be876ac..6956d5a6 100644 --- a/src/comfystream/client_api.py +++ b/src/comfystream/client_api.py @@ -12,6 +12,9 
@@ from typing import List, Dict, Any, Optional, Union import random import time +import subprocess +import os +import socket from comfystream import tensor_cache from comfystream.utils_api import convert_prompt @@ -20,41 +23,51 @@ logger = logging.getLogger(__name__) class ComfyStreamClient: - def __init__(self, host: str = "127.0.0.1", port: int = 8188, **kwargs): + def __init__( + self, + host: str = "127.0.0.1", + port: int = 8188, + spawn: bool = False, + comfyui_path: str = None, + comfyui_args: list = None, + workspace: str = None + ): """ Initialize the ComfyStream client to use the ComfyUI API. Args: host: The hostname or IP address of the ComfyUI server port: The port number of the ComfyUI server - **kwargs: Additional configuration parameters + spawn: If True, launch a ComfyUI server when start_server is called + comfyui_path: Path to the ComfyUI main.py file (required if spawn=True) + comfyui_args: Additional arguments for ComfyUI + workspace: The workspace directory for ComfyUI """ self.host = host self.port = port + self.spawn = spawn + self.comfyui_path = comfyui_path + self.comfyui_args = comfyui_args or [] + self.workspace = workspace + self._comfyui_proc = None + + # Server launch is deferred to start_server method + self.server_address = f"ws://{host}:{port}/ws" self.api_base_url = f"http://{host}:{port}/api" - self.client_id = kwargs.get('client_id', str(uuid.uuid4())) + self.client_id = str(uuid.uuid4()) self.ws = None self.current_prompts = [] self.running_prompts = {} self.cleanup_lock = asyncio.Lock() self.buffer = BytesIO() - # WebSocket connection - self._ws_listener_task = None self.execution_complete_event = asyncio.Event() + + self._ws_listener_task = None self._prompt_id = None - - # Add frame tracking self._current_frame_id = None # Track the current frame being processed self._frame_id_mapping = {} # Map prompt_ids to frame_ids - - # Configure logging - if 'log_level' in kwargs: - logger.setLevel(kwargs['log_level']) - - # Enable 
debug mode - self.debug = kwargs.get('debug', True) - + logger.info(f"ComfyStreamClient initialized with host: {host}, port: {port}, client_id: {self.client_id}") async def set_prompts(self, prompts: List[Dict]): @@ -520,7 +533,7 @@ async def _execute_prompt(self, prompt_index: int): self.execution_complete_event.set() async def cleanup(self): - """Clean up resources""" + """Clean up resources, including terminating spawned ComfyUI process""" async with self.cleanup_lock: # Cancel all running tasks for task in self.running_prompts.values(): @@ -550,6 +563,26 @@ async def cleanup(self): self._ws_listener_task = None await self.cleanup_queues() + + # Terminate the ComfyUI process if we spawned it + if self.spawn and self._comfyui_proc: + logger.info(f"Terminating ComfyUI process (PID: {self._comfyui_proc.pid})") + try: + self._comfyui_proc.terminate() + try: + # Wait for the process to terminate gracefully + exit_code = self._comfyui_proc.wait(timeout=10) + logger.info(f"ComfyUI process exited with code {exit_code}") + except subprocess.TimeoutExpired: + # If it doesn't terminate gracefully, kill it + logger.warning("ComfyUI process did not terminate gracefully, killing...") + self._comfyui_proc.kill() + self._comfyui_proc.wait() + except Exception as e: + logger.error(f"Error terminating ComfyUI process: {e}") + finally: + self._comfyui_proc = None + logger.info("Client cleanup complete") async def cleanup_queues(self): @@ -751,6 +784,17 @@ def create_enum_format(type_name): return all_prompts_nodes_info + async def start_server(self): + """Launch the ComfyUI server if spawn is True""" + if self.spawn: + if not self.comfyui_path: + raise ValueError("comfyui_path must be provided when spawn=True") + self._launch_comfyui_server() + self._wait_for_server_ready() + logger.info("ComfyUI server started successfully") + else: + logger.info("Using existing ComfyUI server (spawn=False)") + async def _fetch_object_info(self, session: aiohttp.ClientSession, url: str, 
class_type: str) -> Optional[tuple[str, Any]]: """Helper function to fetch object info for a single class type.""" try: @@ -780,4 +824,103 @@ async def _fetch_object_info(self, session: aiohttp.ClientSession, url: str, cla logger.error(f"Unexpected error fetching info for {class_type} ({url}): {e}") # Return class_type and None if any error occurred - return class_type, None \ No newline at end of file + return class_type, None + + def _launch_comfyui_server(self): + """Launch ComfyUI as a subprocess""" + logger.info(f"Spawning ComfyUI server on port {self.port}...") + + # Build the command with just the basics + cmd = [ + "python", self.comfyui_path, + ] + + # Add the arguments from comfyui_args if provided + if self.comfyui_args: + cmd.extend(self.comfyui_args) + else: + # Only add default arguments if comfyui_args was not provided + cmd.extend([ + "--listen", + "--port", str(self.port), + "--fast", + "--enable-cors-header", "*", + "--disable-xformers", + "--preview-method", "none" + ]) + + # Add workspace if provided and not in comfyui_args + if self.workspace: + cmd.extend(["--dir", self.workspace]) + + # Check if CUDA is available and add device argument + if hasattr(torch, 'cuda') and torch.cuda.is_available(): + cuda_device = os.environ.get("CUDA_VISIBLE_DEVICES", "0") + cmd.extend(["--cuda-device", cuda_device]) + + # Always ensure port is set correctly (override if provided in comfyui_args) + # Remove any existing --port argument + if "--port" in cmd: + port_index = cmd.index("--port") + # Remove both the flag and its value + if port_index + 1 < len(cmd): + cmd.pop(port_index + 1) + cmd.pop(port_index) + + # Add our port + cmd.extend(["--port", str(self.port)]) + + # Start the process + try: + logger.info(f"Starting ComfyUI with command: {' '.join(cmd)}") + self._comfyui_proc = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=os.environ.copy(), + text=True, + bufsize=1, # Line buffered + ) + + # Start a thread to log 
output + def log_output(stream, level): + for line in iter(stream.readline, ''): + # TODO: Handle error logs from comfy + logger.debug(f"ComfyUI[{self.port}]: {line.strip()}") + + import threading + threading.Thread(target=log_output, args=(self._comfyui_proc.stdout, logging.INFO), daemon=True).start() + threading.Thread(target=log_output, args=(self._comfyui_proc.stderr, logging.INFO), daemon=True).start() + + logger.info(f"Started ComfyUI process with PID {self._comfyui_proc.pid}") + except Exception as e: + logger.error(f"Failed to spawn ComfyUI: {e}") + raise + + def _wait_for_server_ready(self, timeout=60, check_interval=0.5): + """Wait until the ComfyUI server is accepting connections""" + logger.info(f"Waiting for ComfyUI server on port {self.port} to be ready...") + + start_time = time.time() + while time.time() - start_time < timeout: + # Check if process is still running + if self._comfyui_proc and self._comfyui_proc.poll() is not None: + return_code = self._comfyui_proc.poll() + logger.error(f"ComfyUI process exited with code {return_code} before it was ready") + raise RuntimeError(f"ComfyUI process exited with code {return_code}") + + # Try to connect to the server + try: + with socket.create_connection((self.host, self.port), timeout=2): + logger.info(f"ComfyUI server on port {self.port} is now accepting connections") + return + except (ConnectionRefusedError, socket.timeout, OSError): + # Sleep and try again + time.sleep(check_interval) + + # If we get here, the server didn't start in time + logger.error(f"Timed out waiting for ComfyUI server on port {self.port}") + if self._comfyui_proc: + self._comfyui_proc.terminate() + self._comfyui_proc = None + raise RuntimeError(f"Timed out waiting for ComfyUI server on port {self.port}") \ No newline at end of file From 937384fcd14bddb8c6006f5e440b94cb597bfa71 Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Tue, 15 Apr 2025 19:47:20 -0400 Subject: [PATCH 16/42] Small fix, cleanup. 
--- server/app_api.py | 557 ------------------------ server/pipeline_api.py | 572 ------------------------- src/comfystream/server/pipeline_api.py | 7 +- 3 files changed, 5 insertions(+), 1131 deletions(-) delete mode 100644 server/app_api.py delete mode 100644 server/pipeline_api.py diff --git a/server/app_api.py b/server/app_api.py deleted file mode 100644 index 7fd03cc9..00000000 --- a/server/app_api.py +++ /dev/null @@ -1,557 +0,0 @@ -import argparse -import asyncio -import json -import logging -import os -import sys - -import torch - -# Initialize CUDA before any other imports to prevent core dump. -if torch.cuda.is_available(): - torch.cuda.init() - -from aiohttp import web -from aiortc import ( - MediaStreamTrack, - RTCConfiguration, - RTCIceServer, - RTCPeerConnection, - RTCSessionDescription, -) -from aiortc.codecs import h264 -from aiortc.rtcrtpsender import RTCRtpSender -from pipeline_api import Pipeline # TODO: Better integration (Are we replacing pipeline with pipeline_api?) -from twilio.rest import Client -from utils import patch_loop_datagram, add_prefix_to_app_routes, FPSMeter -from metrics import MetricsManager, StreamStatsManager - -logger = logging.getLogger(__name__) -logging.getLogger("aiortc.rtcrtpsender").setLevel(logging.WARNING) -logging.getLogger("aiortc.rtcrtpreceiver").setLevel(logging.WARNING) - - -MAX_BITRATE = 2000000 -MIN_BITRATE = 2000000 - - -class VideoStreamTrack(MediaStreamTrack): - """video stream track that processes video frames using a pipeline. - - Attributes: - kind (str): The kind of media, which is "video" for this class. - track (MediaStreamTrack): The underlying media stream track. - pipeline (Pipeline): The processing pipeline to apply to each video frame. - """ - - kind = "video" - - def __init__(self, track: MediaStreamTrack, pipeline: Pipeline): - """Initialize the VideoStreamTrack. - - Args: - track: The underlying media stream track. - pipeline: The processing pipeline to apply to each video frame. 
- """ - super().__init__() - self.track = track - self.pipeline = pipeline - self.fps_meter = FPSMeter( - metrics_manager=app["metrics_manager"], track_id=track.id - ) - self.running = True - self.collect_task = asyncio.create_task(self.collect_frames()) - - # Add cleanup when track ends - @track.on("ended") - async def on_ended(): - logger.info("Source video track ended, stopping collection") - await cancel_collect_frames(self) - - async def collect_frames(self): - """Collect video frames from the underlying track and pass them to - the processing pipeline. Stops when track ends or connection closes. - """ - try: - while self.running: - try: - frame = await self.track.recv() - await self.pipeline.put_video_frame(frame) - except asyncio.CancelledError: - logger.info("Frame collection cancelled") - break - except Exception as e: - if "MediaStreamError" in str(type(e)): - logger.info("Media stream ended") - else: - logger.error(f"Error collecting video frames: {str(e)}") - self.running = False - break - - # Perform cleanup outside the exception handler - logger.info("Video frame collection stopped") - except asyncio.CancelledError: - logger.info("Frame collection task cancelled") - except Exception as e: - logger.error(f"Unexpected error in frame collection: {str(e)}") - finally: - await self.pipeline.cleanup() - - async def recv(self): - """Receive a processed video frame from the pipeline, increment the frame - count for FPS calculation and return the processed frame to the client. - """ - processed_frame = await self.pipeline.get_processed_video_frame() - - # Increment the frame count to calculate FPS. 
- await self.fps_meter.increment_frame_count() - - return processed_frame - - -class AudioStreamTrack(MediaStreamTrack): - kind = "audio" - - def __init__(self, track: MediaStreamTrack, pipeline): - super().__init__() - self.track = track - self.pipeline = pipeline - self.running = True - self.collect_task = asyncio.create_task(self.collect_frames()) - - # Add cleanup when track ends - @track.on("ended") - async def on_ended(): - logger.info("Source audio track ended, stopping collection") - await cancel_collect_frames(self) - - async def collect_frames(self): - """Collect audio frames from the underlying track and pass them to - the processing pipeline. Stops when track ends or connection closes. - """ - try: - while self.running: - try: - frame = await self.track.recv() - await self.pipeline.put_audio_frame(frame) - except asyncio.CancelledError: - logger.info("Audio frame collection cancelled") - break - except Exception as e: - if "MediaStreamError" in str(type(e)): - logger.info("Media stream ended") - else: - logger.error(f"Error collecting audio frames: {str(e)}") - self.running = False - break - - # Perform cleanup outside the exception handler - logger.info("Audio frame collection stopped") - except asyncio.CancelledError: - logger.info("Frame collection task cancelled") - except Exception as e: - logger.error(f"Unexpected error in audio frame collection: {str(e)}") - finally: - await self.pipeline.cleanup() - - async def recv(self): - return await self.pipeline.get_processed_audio_frame() - - -def force_codec(pc, sender, forced_codec): - kind = forced_codec.split("/")[0] - codecs = RTCRtpSender.getCapabilities(kind).codecs - transceiver = next(t for t in pc.getTransceivers() if t.sender == sender) - codecPrefs = [codec for codec in codecs if codec.mimeType == forced_codec] - transceiver.setCodecPreferences(codecPrefs) - - -def get_twilio_token(): - account_sid = os.getenv("TWILIO_ACCOUNT_SID") - auth_token = os.getenv("TWILIO_AUTH_TOKEN") - - if 
account_sid is None or auth_token is None: - return None - - client = Client(account_sid, auth_token) - - token = client.tokens.create() - - return token - - -def get_ice_servers(): - ice_servers = [] - - token = get_twilio_token() - if token is not None: - # Use Twilio TURN servers - for server in token.ice_servers: - if server["url"].startswith("turn:"): - turn = RTCIceServer( - urls=[server["urls"]], - credential=server["credential"], - username=server["username"], - ) - ice_servers.append(turn) - - return ice_servers - - -async def offer(request): - pipeline = request.app["pipeline"] - pcs = request.app["pcs"] - - # Check if clients are initialized, and initialize them if not - if not pipeline.clients: - logger.info("Clients not initialized yet, starting clients...") - await pipeline.start_clients() - # Check if any clients with spawn=True need to have servers started - elif pipeline.client_mode == "spawn": - start_tasks = [] - for client in pipeline.clients: - if client.spawn and (not hasattr(client, '_comfyui_proc') or client._comfyui_proc is None): - start_tasks.append(client.start_server()) - - # Start any servers that need to be started - if start_tasks: - logger.info(f"Starting ComfyUI servers for new workflow...") - await asyncio.gather(*start_tasks) - logger.info(f"Started {len(start_tasks)} ComfyUI servers") - - # Get parameters - params = await request.json() - - # When a client reconnects after refresh, we need to clear certain pipeline state - # but NOT restart the ComfyUI servers/clients - # Reset the frame tracking, but keep the servers running - pipeline.next_expected_frame_id = None - pipeline.ordered_frames.clear() - pipeline.next_frame_id = 1 # Reset frame ID counter for new connection - pipeline.client_frame_mapping.clear() - - await pipeline.set_prompts(params["prompts"]) - - offer_params = params["offer"] - offer = RTCSessionDescription(sdp=offer_params["sdp"], type=offer_params["type"]) - - ice_servers = get_ice_servers() - if 
len(ice_servers) > 0: - pc = RTCPeerConnection( - configuration=RTCConfiguration(iceServers=get_ice_servers()) - ) - else: - pc = RTCPeerConnection() - - pcs.add(pc) - - tracks = {"video": None, "audio": None} - - # Flag to track if we've received resolution update - resolution_received = {"value": False} - - # Only add video transceiver if video is present in the offer - if "m=video" in offer.sdp: - # Prefer h264 - transceiver = pc.addTransceiver("video") - caps = RTCRtpSender.getCapabilities("video") - prefs = list(filter(lambda x: x.name == "H264", caps.codecs)) - transceiver.setCodecPreferences(prefs) - - # Monkey patch max and min bitrate to ensure constant bitrate - h264.MAX_BITRATE = MAX_BITRATE - h264.MIN_BITRATE = MIN_BITRATE - - # Handle control channel from client - @pc.on("datachannel") - def on_datachannel(channel): - if channel.label == "control": - - @channel.on("message") - async def on_message(message): - try: - params = json.loads(message) - - if params.get("type") == "get_nodes": - nodes_info = await pipeline.get_nodes_info() - response = {"type": "nodes_info", "nodes": nodes_info} - channel.send(json.dumps(response)) - elif params.get("type") == "update_prompts": - if "prompts" not in params: - logger.warning( - "[Control] Missing prompt in update_prompt message" - ) - return - await pipeline.update_prompts(params["prompts"]) - response = {"type": "prompts_updated", "success": True} - channel.send(json.dumps(response)) - elif params.get("type") == "update_resolution": - if "width" not in params or "height" not in params: - logger.warning("[Control] Missing width or height in update_resolution message") - return - # Update pipeline resolution for future frames - pipeline.width = params["width"] - pipeline.height = params["height"] - logger.info(f"[Control] Updated resolution to {params['width']}x{params['height']}") - - # Mark that we've received resolution - resolution_received["value"] = True - - # Warm the video pipeline with the new 
resolution - if "m=video" in pc.remoteDescription.sdp: - await pipeline.warm_video() - - response = { - "type": "resolution_updated", - "success": True - } - channel.send(json.dumps(response)) - else: - logger.warning( - "[Server] Invalid message format - missing required fields" - ) - except json.JSONDecodeError: - logger.error("[Server] Invalid JSON received") - except Exception as e: - logger.error(f"[Server] Error processing message: {str(e)}") - - @pc.on("track") - def on_track(track): - logger.info(f"Track received: {track.kind}") - if track.kind == "video": - videoTrack = VideoStreamTrack(track, pipeline) - tracks["video"] = videoTrack - sender = pc.addTrack(videoTrack) - - # Store video track in app for stats. - stream_id = track.id - request.app["video_tracks"][stream_id] = videoTrack - - codec = "video/H264" - force_codec(pc, sender, codec) - elif track.kind == "audio": - audioTrack = AudioStreamTrack(track, pipeline) - tracks["audio"] = audioTrack - pc.addTrack(audioTrack) - - @track.on("ended") - async def on_ended(): - logger.info(f"{track.kind} track ended") - request.app["video_tracks"].pop(track.id, None) - - @pc.on("connectionstatechange") - async def on_connectionstatechange(): - logger.info(f"Connection state is: {pc.connectionState}") - if pc.connectionState == "failed": - await pc.close() - pcs.discard(pc) - elif pc.connectionState == "closed": - await pc.close() - pcs.discard(pc) - - await pc.setRemoteDescription(offer) - - # Only warm audio here, video warming happens after resolution update - if "m=audio" in pc.remoteDescription.sdp: - await pipeline.warm_audio() - - # We no longer warm video here - it will be warmed after receiving resolution - - answer = await pc.createAnswer() - await pc.setLocalDescription(answer) - - return web.Response( - content_type="application/json", - text=json.dumps( - {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type} - ), - ) - - -async def cancel_collect_frames(track): - track.running = False - 
if hasattr(track, 'collect_task') is not None and not track.collect_task.done(): - try: - track.collect_task.cancel() - await track.collect_task - except (asyncio.CancelledError): - pass - - -async def set_prompt(request): - pipeline = request.app["pipeline"] - - prompt = await request.json() - await pipeline.set_prompts(prompt) - - return web.Response(content_type="application/json", text="OK") - - -def health(_): - return web.Response(content_type="application/json", text="OK") - - -async def on_startup(app: web.Application): - if app["media_ports"]: - patch_loop_datagram(app["media_ports"]) - - # ComfyUI args have been moved to the client constructor - app["pipeline"] = Pipeline( - width=512, - height=512, - comfyui_inference_log_level=app.get("comfui_inference_log_level", None), - config_path=app["config_file"], - max_frame_wait_ms=app["max_frame_wait"], - client_mode=app["client_mode"], - workspace=app["workspace"], - workers=app["workers"], - ) - - # Start the clients during initialization - # await app["pipeline"].start_clients() - - # Wait for pipeline startup to complete (which starts the ComfyUI servers) - if hasattr(app["pipeline"], "startup_task"): - await app["pipeline"].startup_task - - app["pcs"] = set() - app["video_tracks"] = {} - - app["max_frame_wait"] = args.max_frame_wait - - -async def on_shutdown(app: web.Application): - pcs = app["pcs"] - coros = [pc.close() for pc in pcs] - await asyncio.gather(*coros) - pcs.clear() - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Run comfystream server") - parser.add_argument("--port", default=8889, help="Set the signaling port") - parser.add_argument( - "--media-ports", default=None, help="Set the UDP ports for WebRTC media" - ) - parser.add_argument("--host", default="127.0.0.1", help="Set the host") - parser.add_argument( - "--workspace", default=None, required=True, help="Set Comfy workspace" - ) - parser.add_argument( - "--log-level", "--log_level", - dest="log_level", - 
default="WARNING", - choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], - help="Set the logging level", - ) - parser.add_argument( - "--config-file", - type=str, - default=None, - help="Path to TOML configuration file for Comfy servers" - ) - parser.add_argument( - "--monitor", - default=False, - action="store_true", - help="Start a Prometheus metrics endpoint for monitoring.", - ) - parser.add_argument( - "--stream-id-label", - default=False, - action="store_true", - help="Include stream ID as a label in Prometheus metrics.", - ) - parser.add_argument( - "--max-frame-wait", - type=int, - default=500, - help="Maximum time to wait for a frame in milliseconds before dropping it" - ) - parser.add_argument( - "--comfyui-log-level", - default=None, - choices=logging._nameToLevel.keys(), - help="Set the global logging level for ComfyUI", - ) - parser.add_argument( - "--comfyui-inference-log-level", - default=None, - choices=logging._nameToLevel.keys(), - help="Set the logging level for ComfyUI inference", - ) - parser.add_argument( - "--client-mode", - choices=["toml", "spawn"], - default="toml", - help="How to create ComfyUI clients: 'toml' (from config file) or 'spawn' (spawn processes directly)", - ) - parser.add_argument( - "--workers", - type=int, - default=2, - help="Number of worker processes to spawn when using --client-mode=spawn" - ) - args = parser.parse_args() - - logging.basicConfig( - level=args.log_level.upper(), - format="%(asctime)s [%(levelname)s] %(message)s", - datefmt="%H:%M:%S", - ) - - # Set logger level based on command line arguments - logger.setLevel(getattr(logging, args.log_level.upper())) - - app = web.Application() - app["media_ports"] = args.media_ports.split(",") if args.media_ports else None - app["workspace"] = args.workspace - app["config_file"] = args.config_file - app["max_frame_wait"] = args.max_frame_wait - app["client_mode"] = args.client_mode - app["workers"] = args.workers - - app.on_startup.append(on_startup) - 
app.on_shutdown.append(on_shutdown) - - app.router.add_get("/", health) - app.router.add_get("/health", health) - - # WebRTC signalling and control routes. - app.router.add_post("/offer", offer) - app.router.add_post("/prompt", set_prompt) - - # Add routes for getting stream statistics. - stream_stats_manager = StreamStatsManager(app) - app.router.add_get( - "/streams/stats", stream_stats_manager.collect_all_stream_metrics - ) - app.router.add_get( - "/stream/{stream_id}/stats", stream_stats_manager.collect_stream_metrics_by_id - ) - - # Add Prometheus metrics endpoint. - app["metrics_manager"] = MetricsManager(include_stream_id=args.stream_id_label) - if args.monitor: - app["metrics_manager"].enable() - logger.info( - f"Monitoring enabled - Prometheus metrics available at: " - f"http://{args.host}:{args.port}/metrics" - ) - app.router.add_get("/metrics", app["metrics_manager"].metrics_handler) - - # Add hosted platform route prefix. - # NOTE: This ensures that the local and hosted experiences have consistent routes. - add_prefix_to_app_routes(app, "/live") - - def force_print(*args, **kwargs): - print(*args, **kwargs, flush=True) - sys.stdout.flush() - - # Allow overriding of ComyfUI log levels. 
- if args.comfyui_log_level: - log_level = logging._nameToLevel.get(args.comfyui_log_level.upper()) - logging.getLogger("comfy").setLevel(log_level) - if args.comfyui_inference_log_level: - app["comfui_inference_log_level"] = args.comfyui_inference_log_level - - web.run_app(app, host=args.host, port=int(args.port), print=force_print) diff --git a/server/pipeline_api.py b/server/pipeline_api.py deleted file mode 100644 index aa29b9d0..00000000 --- a/server/pipeline_api.py +++ /dev/null @@ -1,572 +0,0 @@ -import av -import torch -import numpy as np -import asyncio -import logging -import time -import random -from collections import OrderedDict -import collections -import os -import socket - -from typing import Any, Dict, Union, List, Optional, Deque -from comfystream.client_api import ComfyStreamClient -from utils import temporary_log_level # Not sure exactly what this does -from config import ComfyConfig - -WARMUP_RUNS = 5 -logger = logging.getLogger(__name__) - - -class MultiServerPipeline: - def __init__( - self, - width: int = 512, - height: int = 512, - workers: int = 2, - comfyui_inference_log_level: int = None, - config_path: Optional[str] = None, - max_frame_wait_ms: int = 500, - client_mode: str = "toml", - workspace: str = None - ): - """Initialize the pipeline with the given configuration. - Args: - width: The width of the video frames. - height: The height of the video frames. - workers: The number of ComfyUI clients to spin up (if client_mode is "spawn"). - comfyui_inference_log_level: The logging level for ComfyUI inference. - Defaults to None, using the global ComfyUI log level. - config_path: The path to the ComfyUI config toml file (if client_mode is "toml"). - max_frame_wait_ms: The maximum number of milliseconds to wait for a frame before dropping it. - client_mode: The mode to use for the ComfyUI clients. - "toml": Use a config file to describe clients. - "spawn": Spawn ComfyUI clients as external processes. 
- """ - - # There are two methods for starting the clients: - # 1. client_mode == "toml" -> Use a config file to describe clients. - # 2. client_mode == "spawn" -> Spawn ComfyUI clients as external processes. - - self.clients = [] - self.workspace = workspace - self.client_mode = client_mode - - if (client_mode == "toml"): - # Load server configurations - self.config = ComfyConfig(config_path) - self.servers = self.config.get_servers() - elif (client_mode == "spawn"): - # Set the number of workers to spawn - self.workers = workers - - # Started in /offer - # self.start_clients() - - self.width = width - self.height = height - - self.video_incoming_frames = asyncio.Queue() - self.audio_incoming_frames = asyncio.Queue() - - # Queue for processed frames from all clients - self.processed_video_frames = asyncio.Queue() - - # Track which client gets each frame (round-robin) - self.current_client_index = 0 - self.client_frame_mapping = {} # Maps frame_id -> client_index - - # Frame ordering and timing - self.max_frame_wait_ms = max_frame_wait_ms # Max time to wait for a frame before dropping - self.next_expected_frame_id = None # Track expected frame ID - self.ordered_frames = OrderedDict() # Buffer for ordering frames (frame_id -> (timestamp, tensor)) - - # Audio processing - self.processed_audio_buffer = np.array([], dtype=np.int16) - self.last_frame_time = 0 - - # ComfyUI inference log level - self._comfyui_inference_log_level = comfyui_inference_log_level - - # Frame rate limiting - self.min_frame_interval = 1/30 # Limit to 30 FPS - - # Create background task for collecting processed frames - self.running = True - self.collector_task = asyncio.create_task(self._collect_processed_frames()) - - self.output_interval = 1/30 # Start with 30 FPS - self.last_output_time = None - self.frame_interval_history = collections.deque(maxlen=30) - self.output_pacer_task = asyncio.create_task(self._dynamic_output_pacer()) - - async def _collect_processed_frames(self): - """Background 
task to collect processed frames from all clients""" - try: - while self.running: - for i, client in enumerate(self.clients): - try: - # Non-blocking check if client has output ready - if hasattr(client, '_prompt_id') and client._prompt_id is not None: - # Get frame without waiting - try: - # Use wait_for with small timeout to avoid blocking - result = await asyncio.wait_for( - client.get_video_output(), - timeout=0.01 - ) - - # Check if result is already a tuple with frame_id - if isinstance(result, tuple) and len(result) == 2: - frame_id, out_tensor = result - logger.debug(f"Got result with embedded frame_id: {frame_id}") - else: - out_tensor = result - # Find which original frame this corresponds to using our mapping - frame_ids = [frame_id for frame_id, client_idx in - self.client_frame_mapping.items() if client_idx == i] - - if frame_ids: - # Use the oldest frame ID for this client - frame_id = min(frame_ids) - else: - # If no mapping found, log warning and continue - logger.warning(f"No frame_id mapping found for tensor from client {i}") - continue - - # Store frame with timestamp for ordering - timestamp = time.time() - await self._add_frame_to_ordered_buffer(frame_id, timestamp, out_tensor) - - # Remove the mapping - self.client_frame_mapping.pop(frame_id, None) - logger.info(f"Collected processed frame from client {i}, frame_id: {frame_id}") - except asyncio.TimeoutError: - # No frame ready yet, continue - pass - except Exception as e: - logger.error(f"Error collecting frame from client {i}: {e}") - - # Check for frames that have waited too long - await self._check_frame_timeouts() - - # Small sleep to avoid CPU spinning - await asyncio.sleep(0.01) - except asyncio.CancelledError: - logger.info("Frame collector task cancelled") - except Exception as e: - logger.error(f"Unexpected error in frame collector: {e}") - - async def _add_frame_to_ordered_buffer(self, frame_id, timestamp, tensor): - """Add a processed frame to the ordered buffer""" - 
self.ordered_frames[frame_id] = (timestamp, tensor) - - # If this is the first frame, set the next expected frame ID - if self.next_expected_frame_id is None: - self.next_expected_frame_id = frame_id - - # Check if we can release any frames now - await self._release_ordered_frames() - - async def _release_ordered_frames(self): - if self.next_expected_frame_id is None: - return - if self.ordered_frames and self.next_expected_frame_id in self.ordered_frames: - timestamp, tensor = self.ordered_frames.pop(self.next_expected_frame_id) - await self.processed_video_frames.put((self.next_expected_frame_id, tensor)) - logger.info(f"Released frame {self.next_expected_frame_id} to output queue") - if self.ordered_frames: - self.next_expected_frame_id = min(self.ordered_frames.keys()) - else: - self.next_expected_frame_id += 1 - - async def _check_frame_timeouts(self): - """Check for frames that have waited too long and handle them""" - if not self.ordered_frames or self.next_expected_frame_id is None: - return - - current_time = time.time() - - # If the next expected frame has timed out, skip it and move on - if self.next_expected_frame_id in self.ordered_frames: - timestamp, _ = self.ordered_frames[self.next_expected_frame_id] - wait_time_ms = (current_time - timestamp) * 1000 - - if wait_time_ms > self.max_frame_wait_ms: - logger.warning(f"Frame {self.next_expected_frame_id} exceeded max wait time, releasing anyway") - await self._release_ordered_frames() - - # Check if we're missing the next expected frame and it's been too long - elif self.ordered_frames: - # The next frame we're expecting isn't in the buffer - # Check how long we've been waiting since the oldest frame in the buffer - oldest_frame_id = min(self.ordered_frames.keys()) - oldest_timestamp, _ = self.ordered_frames[oldest_frame_id] - wait_time_ms = (current_time - oldest_timestamp) * 1000 - - # If we've waited too long, skip the missing frame(s) - if wait_time_ms > self.max_frame_wait_ms: - 
logger.debug(f"Missing frame {self.next_expected_frame_id}, skipping to {oldest_frame_id}") - self.next_expected_frame_id = oldest_frame_id - await self._release_ordered_frames() - - async def warm_video(self): - # Create dummy frame with the CURRENT resolution settings (which might have been updated via control channel) - - # Create a properly formatted dummy frame - ''' - tensor = torch.rand(1, 3, 512, 512) # Random values in [0,1] - dummy_frame = av.VideoFrame(width=512, height=512, format="rgb24") - dummy_frame.side_data.input = tensor - ''' - dummy_frame = av.VideoFrame() - dummy_frame.side_data.input = torch.randn(1, self.height, self.width, 3) - - logger.info(f"Warming video pipeline with resolution {self.width}x{self.height}") - - # Warm up each client - warmup_tasks = [] - for i, client in enumerate(self.clients): - warmup_tasks.append(self._warm_client_video(client, i, dummy_frame)) - - # Wait for all warmup tasks to complete - await asyncio.gather(*warmup_tasks) - logger.info("Video pipeline warmup complete") - - async def _warm_client_video(self, client, client_index, dummy_frame): - """Warm up a single client""" - logger.info(f"Warming up client {client_index}") - for i in range(WARMUP_RUNS): - logger.info(f"Client {client_index} warmup iteration {i+1}/{WARMUP_RUNS}") - client.put_video_input(dummy_frame) - try: - await asyncio.wait_for(client.get_video_output(), timeout=5.0) - except asyncio.TimeoutError: - logger.warning(f"Timeout waiting for warmup frame from client {client_index}") - except Exception as e: - logger.error(f"Error warming client {client_index}: {e}") - - async def warm_audio(self): - # For now, only use the first client for audio - if not self.clients: - logger.warning("No clients available for audio warmup") - return - - dummy_frame = av.AudioFrame() - dummy_frame.side_data.input = np.random.randint(-32768, 32767, int(48000 * 0.5), dtype=np.int16) - dummy_frame.sample_rate = 48000 - - for _ in range(WARMUP_RUNS): - 
self.clients[0].put_audio_input(dummy_frame) - await self.clients[0].get_audio_output() - - async def set_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]): - """Set the same prompts for all clients""" - if isinstance(prompts, dict): - prompts = [prompts] - - # Set prompts for each client - tasks = [] - for client in self.clients: - tasks.append(client.set_prompts(prompts)) - - await asyncio.gather(*tasks) - logger.info(f"Set prompts for {len(self.clients)} clients") - - async def update_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]): - """Update prompts for all clients""" - if isinstance(prompts, dict): - prompts = [prompts] - - # Update prompts for each client - tasks = [] - for client in self.clients: - tasks.append(client.update_prompts(prompts)) - - await asyncio.gather(*tasks) - logger.info(f"Updated prompts for {len(self.clients)} clients") - - async def put_video_frame(self, frame: av.VideoFrame): - """Distribute video frames among clients using round-robin""" - current_time = time.time() - if current_time - self.last_frame_time < self.min_frame_interval: - return # Skip frame if too soon - - self.last_frame_time = current_time - - # Generate a unique frame ID - use sequential IDs for better ordering - if not hasattr(self, 'next_frame_id'): - self.next_frame_id = 1 - - frame_id = self.next_frame_id - self.next_frame_id += 1 - - frame.side_data.frame_id = frame_id - - # Preprocess the frame - frame.side_data.input = self.video_preprocess(frame) - frame.side_data.skipped = False - - # Select the next client in round-robin fashion - client_index = self.current_client_index - self.current_client_index = (self.current_client_index + 1) % len(self.clients) - - # Store mapping of which client is processing this frame - self.client_frame_mapping[frame_id] = client_index - - # Send frame to the selected client - self.clients[client_index].put_video_input(frame) - - # Also add to the incoming queue for reference - await 
self.video_incoming_frames.put((frame_id, frame)) - - logger.debug(f"Sent frame {frame_id} to client {client_index}") - - async def put_audio_frame(self, frame: av.AudioFrame): - # For now, only use the first client for audio - if not self.clients: - return - - frame.side_data.input = self.audio_preprocess(frame) - frame.side_data.skipped = False - self.clients[0].put_audio_input(frame) - await self.audio_incoming_frames.put(frame) - - def audio_preprocess(self, frame: av.AudioFrame) -> Union[torch.Tensor, np.ndarray]: - return frame.to_ndarray().ravel().reshape(-1, 2).mean(axis=1).astype(np.int16) - - def video_preprocess(self, frame: av.VideoFrame) -> Union[torch.Tensor, np.ndarray]: - # Convert directly to tensor, avoiding intermediate numpy array when possible - if hasattr(frame, 'to_tensor'): - tensor = frame.to_tensor() - else: - # If direct tensor conversion not available, use numpy - frame_np = frame.to_ndarray(format="rgb24") - tensor = torch.from_numpy(frame_np) - - # Normalize to [0,1] range and add batch dimension - return tensor.float().div(255.0).unsqueeze(0) - - def video_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.VideoFrame: - return av.VideoFrame.from_ndarray( - (output.squeeze(0).permute(1, 2, 0) * 255.0) - .clamp(0, 255) - .to(dtype=torch.uint8) - .cpu() - .numpy(), - format='rgb24' - ) - - def audio_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.AudioFrame: - return av.AudioFrame.from_ndarray(np.repeat(output, 2).reshape(1, -1)) - - async def get_processed_video_frame(self): - try: - # Get the original frame from the incoming queue first to maintain timing - frame_id, frame = await self.video_incoming_frames.get() - - # Skip frames if we're falling behind - ''' - while not self.video_incoming_frames.empty(): - # Get newer frame and mark old one as skipped - frame.side_data.skipped = True - frame_id, frame = await self.video_incoming_frames.get() - logger.info(f"Skipped older frame {frame_id} to catch up") 
- ''' - # Get the processed frame from our output queue - processed_frame_id, out_tensor = await self.processed_video_frames.get() - - if processed_frame_id != frame_id: - logger.debug(f"Frame ID mismatch: expected {frame_id}, got {processed_frame_id}") - pass - - # Process the frame - processed_frame = self.video_postprocess(out_tensor) - processed_frame.pts = frame.pts - processed_frame.time_base = frame.time_base - - return processed_frame - - except Exception as e: - logger.error(f"Error in get_processed_video_frame: {str(e)}") - # Create a black frame as fallback - black_frame = av.VideoFrame(width=self.width, height=self.height, format='rgb24') - return black_frame - - async def get_processed_audio_frame(self): - # Only use the first client for audio - if not self.clients: - logger.warning("No clients available for audio processing") - return av.AudioFrame(format='s16', layout='mono', samples=1024) - - frame = await self.audio_incoming_frames.get() - if frame.samples > len(self.processed_audio_buffer): - out_tensor = await self.clients[0].get_audio_output() - self.processed_audio_buffer = np.concatenate([self.processed_audio_buffer, out_tensor]) - out_data = self.processed_audio_buffer[:frame.samples] - self.processed_audio_buffer = self.processed_audio_buffer[frame.samples:] - - processed_frame = self.audio_postprocess(out_data) - processed_frame.pts = frame.pts - processed_frame.time_base = frame.time_base - processed_frame.sample_rate = frame.sample_rate - - return processed_frame - - async def get_nodes_info(self) -> Dict[str, Any]: - """Get information about all nodes in the current prompt including metadata.""" - # Note that we pull the node info from the first client (as they should all be the same) - # TODO: This is just retrofitting the functionality of the comfy embedded client, there could be major improvements here - nodes_info = await self.clients[0].get_available_nodes() - return nodes_info - - async def cleanup(self): - """Clean up all clients 
and background tasks""" - self.running = False - - # Cancel collector task - if hasattr(self, 'collector_task') and not self.collector_task.done(): - self.collector_task.cancel() - try: - await self.collector_task - except asyncio.CancelledError: - pass - - # Clean up all clients - cleanup_tasks = [] - for client in self.clients: - cleanup_tasks.append(client.cleanup()) - - await asyncio.gather(*cleanup_tasks) - logger.info("All clients cleaned up") - - async def _dynamic_output_pacer(self): - while self.running: - # Only release if the next expected frame is available - if self.next_expected_frame_id is not None and self.next_expected_frame_id in self.ordered_frames: - timestamp, tensor = self.ordered_frames.pop(self.next_expected_frame_id) - now = time.time() - - # Calculate dynamic interval based on output history - if self.last_output_time is not None: - actual_interval = now - self.last_output_time - self.frame_interval_history.append(actual_interval) - avg_interval = sum(self.frame_interval_history) / len(self.frame_interval_history) - self.output_interval = avg_interval - self.last_output_time = now - - await self.processed_video_frames.put((self.next_expected_frame_id, tensor)) - logger.info(f"Released frame {self.next_expected_frame_id} to output queue") - - # Update next expected frame ID - if self.ordered_frames: - self.next_expected_frame_id = min(self.ordered_frames.keys()) - else: - self.next_expected_frame_id += 1 - - # Sleep for the dynamic interval, but don't sleep negative time - await asyncio.sleep(max(self.output_interval, 0.001)) - else: - # No frame ready, wait a bit and check again - await asyncio.sleep(0.005) - - async def start_clients(self): - """Start the clients based on the client_mode (TOML or spawn)""" - logger.info(f"Starting clients with mode: {self.client_mode}") - - self.clients = [] - - if hasattr(self, 'client_mode') and self.client_mode == "toml": - # Use config file to create clients - for server_config in self.servers: - 
client_kwargs = server_config.copy() - self.clients.append(ComfyStreamClient(**client_kwargs)) - - elif hasattr(self, 'client_mode') and self.client_mode == "spawn": - # Spin up clients as external processes - ports = [8195 + i for i in range(self.workers)] - - for i in range(self.workers): - client = ComfyStreamClient( - host="127.0.0.1", - port=ports[i], - spawn=True, - comfyui_path=os.path.join(self.workspace, "main.py"), - workspace=self.workspace, - comfyui_args=[ - "--disable-cuda-malloc", - "--gpu-only", - "--preview-method", "none", - "--listen", - "--cuda-device", "0", - "--fast", - "--enable-cors-header", "*", - "--port", str(ports[i]), - "--disable-xformers", - ], - ) - self.clients.append(client) - - else: - raise ValueError(f"Unknown client_mode: {getattr(self, 'client_mode', 'None')}") - - # Start all ComfyUI servers in parallel if in spawn mode - if hasattr(self, 'client_mode') and self.client_mode == "spawn": - # First, launch all server processes in parallel - for client in self.clients: - if client.spawn: - client._launch_comfyui_server() - - # Now create async functions to check server readiness - async def check_server_ready(client, timeout=60, check_interval=0.5): - """Async version of waiting for server to be ready""" - logger.info(f"Waiting for ComfyUI server on port {client.port} to be ready...") - - start_time = time.time() - while time.time() - start_time < timeout: - # Check if process is still running - if client._comfyui_proc and client._comfyui_proc.poll() is not None: - return_code = client._comfyui_proc.poll() - logger.error(f"ComfyUI process exited with code {return_code} before it was ready") - raise RuntimeError(f"ComfyUI process exited with code {return_code}") - - # Try to connect to the server - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(2) - result = sock.connect_ex((client.host, client.port)) - sock.close() - - if result == 0: - logger.info(f"ComfyUI server on port {client.port} is now 
accepting connections") - return - except Exception: - pass - - # Sleep and try again - await asyncio.sleep(check_interval) - - # If we get here, the server didn't start in time - logger.error(f"Timed out waiting for ComfyUI server on port {client.port}") - if client._comfyui_proc: - client._comfyui_proc.terminate() - client._comfyui_proc = None - raise RuntimeError(f"Timed out waiting for ComfyUI server on port {client.port}") - - # Wait for all servers to be ready in parallel - wait_tasks = [] - for client in self.clients: - if client.spawn: - wait_tasks.append(check_server_ready(client)) - - if wait_tasks: - logger.info(f"Waiting for {len(wait_tasks)} ComfyUI servers to become ready...") - await asyncio.gather(*wait_tasks) - logger.info(f"All {len(wait_tasks)} ComfyUI servers are ready") - - logger.info(f"Initialized {len(self.clients)} clients") - return self.clients - -# For backwards compatibility, maintain the original Pipeline name -Pipeline = MultiServerPipeline \ No newline at end of file diff --git a/src/comfystream/server/pipeline_api.py b/src/comfystream/server/pipeline_api.py index aa29b9d0..9155faf8 100644 --- a/src/comfystream/server/pipeline_api.py +++ b/src/comfystream/server/pipeline_api.py @@ -480,8 +480,11 @@ async def start_clients(self): if hasattr(self, 'client_mode') and self.client_mode == "toml": # Use config file to create clients for server_config in self.servers: - client_kwargs = server_config.copy() - self.clients.append(ComfyStreamClient(**client_kwargs)) + self.clients.append(ComfyStreamClient( + host=server_config["host"], + port=server_config["port"], + spawn=False, + )) elif hasattr(self, 'client_mode') and self.client_mode == "spawn": # Spin up clients as external processes From f7326c41b1ca0001e91160901d5c9cb362b5ba32 Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Wed, 16 Apr 2025 12:19:55 -0400 Subject: [PATCH 17/42] Added cuda-devices and workers-start-port params for multi-gpu spawning on same machine. 
--- src/comfystream/__init__.py | 3 -- src/comfystream/server/app_api.py | 16 +++++++ src/comfystream/server/pipeline_api.py | 66 ++++++++++++++++---------- 3 files changed, 56 insertions(+), 29 deletions(-) diff --git a/src/comfystream/__init__.py b/src/comfystream/__init__.py index 5e1b9fac..b58bf2e4 100644 --- a/src/comfystream/__init__.py +++ b/src/comfystream/__init__.py @@ -1,7 +1,6 @@ from .client import ComfyStreamClient from .pipeline import Pipeline from .server.utils import temporary_log_level -from .server.app import VideoStreamTrack, AudioStreamTrack from .server.utils import FPSMeter from .server.metrics import MetricsManager, StreamStatsManager @@ -9,8 +8,6 @@ 'ComfyStreamClient', 'Pipeline', 'temporary_log_level', - 'VideoStreamTrack', - 'AudioStreamTrack', 'FPSMeter', 'MetricsManager', 'StreamStatsManager' diff --git a/src/comfystream/server/app_api.py b/src/comfystream/server/app_api.py index 7fd03cc9..70dcb4cf 100644 --- a/src/comfystream/server/app_api.py +++ b/src/comfystream/server/app_api.py @@ -405,6 +405,8 @@ async def on_startup(app: web.Application): client_mode=app["client_mode"], workspace=app["workspace"], workers=app["workers"], + cuda_devices=app["cuda_devices"], + workers_start_port=app.get("workers_start_port", 8195), ) # Start the clients during initialization @@ -491,6 +493,18 @@ async def on_shutdown(app: web.Application): default=2, help="Number of worker processes to spawn when using --client-mode=spawn" ) + parser.add_argument( + "--cuda-devices", + type=str, + default='0', + help="Comma-separated list of CUDA devices to use" + ) + parser.add_argument( + "--workers-start-port", + type=int, + default=8195, + help="Starting port number for worker processes" + ) args = parser.parse_args() logging.basicConfig( @@ -509,6 +523,8 @@ async def on_shutdown(app: web.Application): app["max_frame_wait"] = args.max_frame_wait app["client_mode"] = args.client_mode app["workers"] = args.workers + app["cuda_devices"] = args.cuda_devices + 
app["workers_start_port"] = args.workers_start_port app.on_startup.append(on_startup) app.on_shutdown.append(on_shutdown) diff --git a/src/comfystream/server/pipeline_api.py b/src/comfystream/server/pipeline_api.py index 9155faf8..c77eb8c5 100644 --- a/src/comfystream/server/pipeline_api.py +++ b/src/comfystream/server/pipeline_api.py @@ -24,12 +24,14 @@ def __init__( self, width: int = 512, height: int = 512, - workers: int = 2, comfyui_inference_log_level: int = None, config_path: Optional[str] = None, max_frame_wait_ms: int = 500, client_mode: str = "toml", - workspace: str = None + workspace: str = None, + workers: int = 2, + cuda_devices: str = '0', + workers_start_port: int = 8195, ): """Initialize the pipeline with the given configuration. Args: @@ -43,6 +45,8 @@ def __init__( client_mode: The mode to use for the ComfyUI clients. "toml": Use a config file to describe clients. "spawn": Spawn ComfyUI clients as external processes. + workers_start_port: The starting port number for worker processes (default: 8195). + cuda_devices: The list of CUDA devices to use for the ComfyUI clients. """ # There are two methods for starting the clients: @@ -54,14 +58,20 @@ def __init__( self.client_mode = client_mode if (client_mode == "toml"): + # TOML Mode: Use a config file to describe existing ComfyUI Instances + # Load server configurations self.config = ComfyConfig(config_path) self.servers = self.config.get_servers() elif (client_mode == "spawn"): - # Set the number of workers to spawn + # SPAWN Mode: Spawn new ComfyUI Instances automatically + self.workers = workers + self.workers_start_port = workers_start_port + self.cuda_devices = cuda_devices - # Started in /offer + # Clients started in /offer (this is due to when the page refreshes, the clients automatically close) + # TODO: Perhaps a better way would be to keep the the clients alive while the server is alive? 
# self.start_clients() self.width = width @@ -488,28 +498,32 @@ async def start_clients(self): elif hasattr(self, 'client_mode') and self.client_mode == "spawn": # Spin up clients as external processes - ports = [8195 + i for i in range(self.workers)] - - for i in range(self.workers): - client = ComfyStreamClient( - host="127.0.0.1", - port=ports[i], - spawn=True, - comfyui_path=os.path.join(self.workspace, "main.py"), - workspace=self.workspace, - comfyui_args=[ - "--disable-cuda-malloc", - "--gpu-only", - "--preview-method", "none", - "--listen", - "--cuda-device", "0", - "--fast", - "--enable-cors-header", "*", - "--port", str(ports[i]), - "--disable-xformers", - ], - ) - self.clients.append(client) + ports = [] + cuda_device_list = [d.strip() for d in str(self.cuda_devices).split(',') if d.strip()] + for device_idx, cuda_device in enumerate(cuda_device_list): + for worker_idx in range(self.workers): + port = self.workers_start_port + len(ports) + ports.append(port) + client = ComfyStreamClient( + host="127.0.0.1", + port=port, + spawn=True, + comfyui_path=os.path.join(self.workspace, "main.py"), + workspace=self.workspace, + comfyui_args=[ + "--disable-cuda-malloc", + "--gpu-only", + "--preview-method", "none", + "--listen", + "--cuda-device", str(cuda_device), + "--fast", + "--enable-cors-header", "*", + "--port", str(port), + "--disable-xformers", + ], + ) + self.clients.append(client) + logger.info(f"Created worker {worker_idx+1}/{self.workers} for CUDA device {cuda_device} on port {port}") else: raise ValueError(f"Unknown client_mode: {getattr(self, 'client_mode', 'None')}") From 0c1aa0eeec4d1e9a79cd6bbd6e2a1d39e0641205 Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Tue, 22 Apr 2025 12:52:24 -0400 Subject: [PATCH 18/42] Fixed issue with cleanup not properly resetting the clients for subsequent runs. 
--- src/comfystream/client_api.py | 92 ++++++++++++-------------- src/comfystream/server/app_api.py | 14 +--- src/comfystream/server/pipeline_api.py | 82 +++++++++++++++++++---- 3 files changed, 112 insertions(+), 76 deletions(-) diff --git a/src/comfystream/client_api.py b/src/comfystream/client_api.py index 6956d5a6..eae4f69d 100644 --- a/src/comfystream/client_api.py +++ b/src/comfystream/client_api.py @@ -533,57 +533,49 @@ async def _execute_prompt(self, prompt_index: int): self.execution_complete_event.set() async def cleanup(self): - """Clean up resources, including terminating spawned ComfyUI process""" - async with self.cleanup_lock: - # Cancel all running tasks - for task in self.running_prompts.values(): - if not task.done(): - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - self.running_prompts.clear() - - # Close WebSocket connection - if self.ws: - try: - await self.ws.close() - except Exception as e: - logger.error(f"Error closing WebSocket: {e}") + """Clean up resources and reset connection state completely.""" + logger.info("Performing client cleanup and connection reset") + + # Cancel the WebSocket listener task + if self._ws_listener_task is not None and not self._ws_listener_task.done(): + self._ws_listener_task.cancel() + try: + await self._ws_listener_task + except asyncio.CancelledError: + pass + self._ws_listener_task = None + + # Close WebSocket connection + if self.ws is not None: + try: + await self.ws.close() + except Exception as e: + logger.error(f"Error closing WebSocket: {e}") + finally: self.ws = None - - # Cancel WebSocket listener task - if self._ws_listener_task and not self._ws_listener_task.done(): - self._ws_listener_task.cancel() - try: - await self._ws_listener_task - except asyncio.CancelledError: - pass - self._ws_listener_task = None - - await self.cleanup_queues() - - # Terminate the ComfyUI process if we spawned it - if self.spawn and self._comfyui_proc: - logger.info(f"Terminating ComfyUI 
process (PID: {self._comfyui_proc.pid})") - try: - self._comfyui_proc.terminate() - try: - # Wait for the process to terminate gracefully - exit_code = self._comfyui_proc.wait(timeout=10) - logger.info(f"ComfyUI process exited with code {exit_code}") - except subprocess.TimeoutExpired: - # If it doesn't terminate gracefully, kill it - logger.warning("ComfyUI process did not terminate gracefully, killing...") - self._comfyui_proc.kill() - self._comfyui_proc.wait() - except Exception as e: - logger.error(f"Error terminating ComfyUI process: {e}") - finally: - self._comfyui_proc = None - - logger.info("Client cleanup complete") + + # Reset all state variables + self._prompt_id = None + self._current_frame_id = None + self._frame_id_mapping = {} + self.current_prompts = [] + + # Cancel any running prompt tasks + for task in self.running_prompts.values(): + if not task.done(): + task.cancel() + self.running_prompts = {} + + # Reset the execution event + self.execution_complete_event.set() + + # Clean up queues + await self.cleanup_queues() + + # Reset buffer + self.buffer = BytesIO() + + logger.info("Client cleanup completed, connection will be reestablished on next use") async def cleanup_queues(self): """Clean up tensor queues""" diff --git a/src/comfystream/server/app_api.py b/src/comfystream/server/app_api.py index 70dcb4cf..2ef16eff 100644 --- a/src/comfystream/server/app_api.py +++ b/src/comfystream/server/app_api.py @@ -206,19 +206,7 @@ async def offer(request): if not pipeline.clients: logger.info("Clients not initialized yet, starting clients...") await pipeline.start_clients() - # Check if any clients with spawn=True need to have servers started - elif pipeline.client_mode == "spawn": - start_tasks = [] - for client in pipeline.clients: - if client.spawn and (not hasattr(client, '_comfyui_proc') or client._comfyui_proc is None): - start_tasks.append(client.start_server()) - - # Start any servers that need to be started - if start_tasks: - 
logger.info(f"Starting ComfyUI servers for new workflow...") - await asyncio.gather(*start_tasks) - logger.info(f"Started {len(start_tasks)} ComfyUI servers") - + # Get parameters params = await request.json() diff --git a/src/comfystream/server/pipeline_api.py b/src/comfystream/server/pipeline_api.py index c77eb8c5..5a3340e0 100644 --- a/src/comfystream/server/pipeline_api.py +++ b/src/comfystream/server/pipeline_api.py @@ -432,24 +432,80 @@ async def get_nodes_info(self) -> Dict[str, Any]: return nodes_info async def cleanup(self): - """Clean up all clients and background tasks""" - self.running = False + """Clean up resources used by the pipeline.""" + logger.info("Performing complete pipeline cleanup") - # Cancel collector task - if hasattr(self, 'collector_task') and not self.collector_task.done(): - self.collector_task.cancel() + # Cancel the dynamic output pacer task if it exists + if hasattr(self, "_pacer_task") and self._pacer_task is not None: + self._pacer_task.cancel() try: - await self.collector_task + await self._pacer_task except asyncio.CancelledError: pass + self._pacer_task = None - # Clean up all clients - cleanup_tasks = [] - for client in self.clients: - cleanup_tasks.append(client.cleanup()) - - await asyncio.gather(*cleanup_tasks) - logger.info("All clients cleaned up") + # Cancel any frame timeout tasks + if hasattr(self, "_timeout_task") and self._timeout_task is not None: + self._timeout_task.cancel() + try: + await self._timeout_task + except asyncio.CancelledError: + pass + self._timeout_task = None + + # Reset frame tracking state + self.next_expected_frame_id = None + self.ordered_frames.clear() + self.next_frame_id = 1 # Reset frame ID counter for new connection + self.client_frame_mapping.clear() + + # Clear any queued frames + while not self.video_incoming_frames.empty(): + try: + self.video_incoming_frames.get_nowait() + except asyncio.QueueEmpty: + break + + # Reset client state and connections + for i, client in 
enumerate(self.clients): + if client: + # Clean up client resources + try: + await client.cleanup() + except Exception as e: + logger.error(f"Error during client {i} cleanup: {e}") + + # Reset client connection status + if hasattr(client, 'ws_connected'): + client.ws_connected = False + + # Clear any client-specific execution state + if hasattr(client, 'prompt_executing'): + client.prompt_executing = False + + # Mark clients as needing reinitialization + self.clients_initialized = False + + # Clear any cached prompt mappings + if hasattr(self, "_prompt_ids"): + self._prompt_ids = {} + + # Reset warmup state + if hasattr(self, "_warmup_complete"): + self._warmup_complete = False + + # Reset any frame buffers + if hasattr(self, "_frame_buffer"): + self._frame_buffer.clear() + + # Ensure dynamic state like frame rate trackers are reset + if hasattr(self, "_last_frame_time"): + self._last_frame_time = None + + # Reset output counters + self.output_counter = 0 + + logger.info("Pipeline cleanup completed, clients will be reinitialized on next connection") async def _dynamic_output_pacer(self): while self.running: From 6c07c6582fdfef2a5ac93a34aea9aecc8b0eb1e1 Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Tue, 22 Apr 2025 21:45:20 -0400 Subject: [PATCH 19/42] Better error handling for Comfy instances via spawn, reorganization of app, pipeline and config files. 
--- {src/comfystream/server => server}/app_api.py | 7 +- src/comfystream/client_api.py | 15 ++- src/comfystream/{server => }/pipeline_api.py | 105 ++++++++++++------ src/comfystream/server/{ => utils}/config.py | 0 4 files changed, 86 insertions(+), 41 deletions(-) rename {src/comfystream/server => server}/app_api.py (98%) rename src/comfystream/{server => }/pipeline_api.py (86%) rename src/comfystream/server/{ => utils}/config.py (100%) diff --git a/src/comfystream/server/app_api.py b/server/app_api.py similarity index 98% rename from src/comfystream/server/app_api.py rename to server/app_api.py index 2ef16eff..b5ed4a7d 100644 --- a/src/comfystream/server/app_api.py +++ b/server/app_api.py @@ -21,10 +21,11 @@ ) from aiortc.codecs import h264 from aiortc.rtcrtpsender import RTCRtpSender -from pipeline_api import Pipeline # TODO: Better integration (Are we replacing pipeline with pipeline_api?) from twilio.rest import Client -from utils import patch_loop_datagram, add_prefix_to_app_routes, FPSMeter -from metrics import MetricsManager, StreamStatsManager + +from comfystream.pipeline_api import Pipeline +from comfystream.server.utils import patch_loop_datagram, add_prefix_to_app_routes, FPSMeter +from comfystream.server.metrics import MetricsManager, StreamStatsManager logger = logging.getLogger(__name__) logging.getLogger("aiortc.rtcrtpsender").setLevel(logging.WARNING) diff --git a/src/comfystream/client_api.py b/src/comfystream/client_api.py index eae4f69d..17decbf5 100644 --- a/src/comfystream/client_api.py +++ b/src/comfystream/client_api.py @@ -396,7 +396,7 @@ async def _execute_prompt(self, prompt_index: int): # Store current frame ID for binary message handler to use self._current_frame_id = frame_id - # Find ETN_LoadImageBase64 nodes first + # Find LoadImageBase64 nodes first load_image_nodes = [] for node_id, node in prompt.items(): if isinstance(node, dict) and node.get("class_type") in ["LoadImageBase64"]: @@ -833,12 +833,14 @@ def 
_launch_comfyui_server(self): else: # Only add default arguments if comfyui_args was not provided cmd.extend([ + "--disable-cuda-malloc", # Helps prevent CUDA memory issues + "--gpu-only", # Use GPU for all operations when possible + "--preview-method", "none", # Disable previews to save memory "--listen", "--port", str(self.port), "--fast", "--enable-cors-header", "*", - "--disable-xformers", - "--preview-method", "none" + "--disable-xformers", # More compatible with some systems ]) # Add workspace if provided and not in comfyui_args @@ -877,8 +879,11 @@ def _launch_comfyui_server(self): # Start a thread to log output def log_output(stream, level): for line in iter(stream.readline, ''): - # TODO: Handle error logs from comfy - logger.debug(f"ComfyUI[{self.port}]: {line.strip()}") + # Check for known error patterns and log them at appropriate levels + if "error" in line.lower() or "exception" in line.lower(): + logger.error(f"ComfyUI[{self.port}]: {line.strip()}") + else: + logger.debug(f"ComfyUI[{self.port}]: {line.strip()}") import threading threading.Thread(target=log_output, args=(self._comfyui_proc.stdout, logging.INFO), daemon=True).start() diff --git a/src/comfystream/server/pipeline_api.py b/src/comfystream/pipeline_api.py similarity index 86% rename from src/comfystream/server/pipeline_api.py rename to src/comfystream/pipeline_api.py index 5a3340e0..5c4ff303 100644 --- a/src/comfystream/server/pipeline_api.py +++ b/src/comfystream/pipeline_api.py @@ -12,8 +12,8 @@ from typing import Any, Dict, Union, List, Optional, Deque from comfystream.client_api import ComfyStreamClient -from utils import temporary_log_level # Not sure exactly what this does -from config import ComfyConfig +from comfystream.server.utils import temporary_log_level # Not sure exactly what this does +from comfystream.server.utils.config import ComfyConfig WARMUP_RUNS = 5 logger = logging.getLogger(__name__) @@ -592,40 +592,79 @@ async def start_clients(self): 
client._launch_comfyui_server() # Now create async functions to check server readiness - async def check_server_ready(client, timeout=60, check_interval=0.5): - """Async version of waiting for server to be ready""" + async def check_server_ready(client, timeout=60, check_interval=0.5, max_retries=3): + """Async version of waiting for server to be ready with retry logic""" logger.info(f"Waiting for ComfyUI server on port {client.port} to be ready...") - start_time = time.time() - while time.time() - start_time < timeout: - # Check if process is still running - if client._comfyui_proc and client._comfyui_proc.poll() is not None: - return_code = client._comfyui_proc.poll() - logger.error(f"ComfyUI process exited with code {return_code} before it was ready") - raise RuntimeError(f"ComfyUI process exited with code {return_code}") - - # Try to connect to the server - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(2) - result = sock.connect_ex((client.host, client.port)) - sock.close() - - if result == 0: - logger.info(f"ComfyUI server on port {client.port} is now accepting connections") - return - except Exception: - pass - - # Sleep and try again - await asyncio.sleep(check_interval) + retries = 0 + while retries <= max_retries: + start_time = time.time() + while time.time() - start_time < timeout: + # Check if process is still running + if client._comfyui_proc and client._comfyui_proc.poll() is not None: + return_code = client._comfyui_proc.poll() + logger.error(f"ComfyUI process exited with code {return_code} before it was ready") + + # If we still have retries left, restart the process + if retries < max_retries: + retries += 1 + logger.info(f"Retrying ComfyUI server on port {client.port} (attempt {retries}/{max_retries})") + + # Kill any zombie process + if client._comfyui_proc: + try: + client._comfyui_proc.terminate() + except Exception: + pass + + # Start a new process + client._launch_comfyui_server() + await asyncio.sleep(2) # 
Give it a moment to start + break # Break inner loop to restart timeout + else: + # We're out of retries + raise RuntimeError(f"ComfyUI process exited with code {return_code} after {max_retries} retries") + + # Try to connect to the server + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(2) + result = sock.connect_ex((client.host, client.port)) + sock.close() + + if result == 0: + logger.info(f"ComfyUI server on port {client.port} is now accepting connections") + return + except Exception: + pass + + # Sleep and try again + await asyncio.sleep(check_interval) - # If we get here, the server didn't start in time - logger.error(f"Timed out waiting for ComfyUI server on port {client.port}") - if client._comfyui_proc: - client._comfyui_proc.terminate() - client._comfyui_proc = None - raise RuntimeError(f"Timed out waiting for ComfyUI server on port {client.port}") + # If we break out of the inner loop due to a restart, continue + # If we break out due to timeout, increment retries and try again + if time.time() - start_time >= timeout: + retries += 1 + if retries <= max_retries: + logger.info(f"Timed out waiting for ComfyUI server on port {client.port}, retrying (attempt {retries}/{max_retries})") + + # Kill any zombie process + if client._comfyui_proc: + try: + client._comfyui_proc.terminate() + except Exception: + pass + + # Start a new process + client._launch_comfyui_server() + await asyncio.sleep(2) # Give it a moment to start + else: + # We're out of retries + logger.error(f"Timed out waiting for ComfyUI server on port {client.port} after {max_retries} retries") + if client._comfyui_proc: + client._comfyui_proc.terminate() + client._comfyui_proc = None + raise RuntimeError(f"Timed out waiting for ComfyUI server on port {client.port} after {max_retries} retries") # Wait for all servers to be ready in parallel wait_tasks = [] diff --git a/src/comfystream/server/config.py b/src/comfystream/server/utils/config.py similarity index 
100% rename from src/comfystream/server/config.py rename to src/comfystream/server/utils/config.py From 93308b6e6a2f904b1f1fce8fd170a229c4ae5242 Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Thu, 24 Apr 2025 13:00:06 -0400 Subject: [PATCH 20/42] Fix to linux subprocess command. --- src/comfystream/client_api.py | 2 +- src/comfystream/pipeline_api.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/comfystream/client_api.py b/src/comfystream/client_api.py index 17decbf5..4c20fe86 100644 --- a/src/comfystream/client_api.py +++ b/src/comfystream/client_api.py @@ -839,7 +839,7 @@ def _launch_comfyui_server(self): "--listen", "--port", str(self.port), "--fast", - "--enable-cors-header", "*", + "--enable-cors-header", "\"*\"", "--disable-xformers", # More compatible with some systems ]) diff --git a/src/comfystream/pipeline_api.py b/src/comfystream/pipeline_api.py index 5c4ff303..79321785 100644 --- a/src/comfystream/pipeline_api.py +++ b/src/comfystream/pipeline_api.py @@ -573,7 +573,7 @@ async def start_clients(self): "--listen", "--cuda-device", str(cuda_device), "--fast", - "--enable-cors-header", "*", + "--enable-cors-header", "\"*\"", "--port", str(port), "--disable-xformers", ], From 98b78e89ceb85fe6a44c876e87fe93cd039e79c3 Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Fri, 25 Apr 2025 16:30:49 -0400 Subject: [PATCH 21/42] Added spawned comfy specific logging, modified client logging, modifications to spawning instances, better handling of misconfigured workspace. 
--- server/app_api.py | 18 ++- src/comfystream/client_api.py | 220 +++++++++++++----------------- src/comfystream/pipeline_api.py | 233 +++++++++++++------------------- 3 files changed, 210 insertions(+), 261 deletions(-) diff --git a/server/app_api.py b/server/app_api.py index b5ed4a7d..a5330bca 100644 --- a/server/app_api.py +++ b/server/app_api.py @@ -206,7 +206,17 @@ async def offer(request): # Check if clients are initialized, and initialize them if not if not pipeline.clients: logger.info("Clients not initialized yet, starting clients...") - await pipeline.start_clients() + results = await pipeline.start_clients() + + # Check if there was an error during startup + if results is None and hasattr(pipeline, 'startup_error') and pipeline.startup_error: + error_message = pipeline.startup_error + logger.error(f"Failed to initialize clients: {error_message}") + return web.Response( + status=500, + content_type="application/json", + text=json.dumps({"error": f"Failed to start ComfyUI: {error_message}"}) + ) # Get parameters params = await request.json() @@ -396,7 +406,11 @@ async def on_startup(app: web.Application): workers=app["workers"], cuda_devices=app["cuda_devices"], workers_start_port=app.get("workers_start_port", 8195), + comfyui_log_level=app.get("comfyui_log_level", None), ) + + if (app.get("client_mode") == "spawn" and app.get("comfyui_log_level") is None): + print("To see spawned ComfyUI logs, add --comfyui_log_level=DEBUG") # Start the clients during initialization # await app["pipeline"].start_clients() @@ -553,9 +567,11 @@ def force_print(*args, **kwargs): sys.stdout.flush() # Allow overriding of ComyfUI log levels. 
+ # TODO: This will have to pipe to spawn clients if args.comfyui_log_level: log_level = logging._nameToLevel.get(args.comfyui_log_level.upper()) logging.getLogger("comfy").setLevel(log_level) + app["comfyui_log_level"] = args.comfyui_log_level if args.comfyui_inference_log_level: app["comfui_inference_log_level"] = args.comfyui_inference_log_level diff --git a/src/comfystream/client_api.py b/src/comfystream/client_api.py index 4c20fe86..7975384c 100644 --- a/src/comfystream/client_api.py +++ b/src/comfystream/client_api.py @@ -30,7 +30,8 @@ def __init__( spawn: bool = False, comfyui_path: str = None, comfyui_args: list = None, - workspace: str = None + workspace: str = None, + comfyui_log_level: str = None, ): """ Initialize the ComfyStream client to use the ComfyUI API. @@ -42,6 +43,7 @@ def __init__( comfyui_path: Path to the ComfyUI main.py file (required if spawn=True) comfyui_args: Additional arguments for ComfyUI workspace: The workspace directory for ComfyUI + comfyui_log_level: The logging level for ComfyUI """ self.host = host self.port = port @@ -67,8 +69,10 @@ def __init__( self._prompt_id = None self._current_frame_id = None # Track the current frame being processed self._frame_id_mapping = {} # Map prompt_ids to frame_ids + + self.comfyui_log_level = comfyui_log_level - logger.info(f"ComfyStreamClient initialized with host: {host}, port: {port}, client_id: {self.client_id}") + logger.info(f"[Client[{self.port}]: ComfyStreamClient initialized with host: {host}, port: {port}, client_id: {self.client_id}") async def set_prompts(self, prompts: List[Dict]): """Set prompts and run them (compatible with original interface)""" @@ -80,7 +84,7 @@ async def set_prompts(self, prompts: List[Dict]): task = asyncio.create_task(self.run_prompt(idx)) self.running_prompts[idx] = task - logger.info(f"Set {len(self.current_prompts)} prompts for execution") + logger.info(f"[Client[{self.port}]: Set {len(self.current_prompts)} prompts for execution") async def 
update_prompts(self, prompts: List[Dict]): """Update existing prompts (compatible with original interface)""" @@ -89,11 +93,11 @@ async def update_prompts(self, prompts: List[Dict]): "Number of updated prompts must match the number of currently running prompts." ) self.current_prompts = [convert_prompt(prompt) for prompt in prompts] - logger.info(f"Updated {len(self.current_prompts)} prompts") + logger.info(f"[Client[{self.port}]: Updated {len(self.current_prompts)} prompts") async def run_prompt(self, prompt_index: int): """Run a prompt continuously, processing new frames as they arrive""" - logger.info(f"Running prompt {prompt_index}") + logger.info(f"[Client[{self.port}]: Running prompt {prompt_index}") # Make sure WebSocket is connected await self._connect_websocket() @@ -129,7 +133,7 @@ async def run_prompt(self, prompt_index: int): await asyncio.sleep(0.01) # Short sleep to prevent CPU spinning except asyncio.CancelledError: - logger.info(f"Prompt {prompt_index} execution cancelled") + logger.info(f"[Client[{self.port}]: Prompt {prompt_index} execution cancelled") raise except Exception as e: logger.error(f"Error in run_prompt: {str(e)}") @@ -149,7 +153,7 @@ async def _connect_websocket(self): pass self.ws = None - logger.info(f"Connecting to WebSocket at {self.server_address}?clientId={self.client_id}") + logger.info(f"[Client[{self.port}]: Connecting to WebSocket at {self.server_address}?clientId={self.client_id}") try: # Connect with proper error handling @@ -162,12 +166,12 @@ async def _connect_websocket(self): ssl=None ) - logger.info("WebSocket connected successfully") + logger.info(f"[Client[{self.port}]: WebSocket connected successfully") # Start the listener task if not already running if self._ws_listener_task is None or self._ws_listener_task.done(): self._ws_listener_task = asyncio.create_task(self._ws_listener()) - logger.info("Started WebSocket listener task") + logger.info(f"[Client[{self.port}]: Started WebSocket listener task") return self.ws 
@@ -190,7 +194,7 @@ async def _connect_websocket(self): async def _ws_listener(self): """Listen for WebSocket messages and process them""" try: - logger.info(f"WebSocket listener started") + logger.info(f"[Client[{self.port}]: WebSocket listener started") while True: if self.ws is None: try: @@ -212,7 +216,7 @@ async def _ws_listener(self): await self._handle_binary_message(message) except websockets.exceptions.ConnectionClosed: - logger.info("WebSocket connection closed") + logger.info(f"[Client[{self.port}]: WebSocket connection closed") self.ws = None await asyncio.sleep(1) except Exception as e: @@ -220,7 +224,7 @@ async def _ws_listener(self): await asyncio.sleep(1) except asyncio.CancelledError: - logger.info("WebSocket listener cancelled") + logger.info(f"[Client[{self.port}]: WebSocket listener cancelled") raise except Exception as e: logger.error(f"Unexpected error in WebSocket listener: {e}") @@ -231,8 +235,8 @@ async def _handle_text_message(self, message: str): data = json.loads(message) message_type = data.get("type", "unknown") - # logger.info(f"Received message type: {message_type}") - logger.debug(f"{data}") + # logger.info(f"[Client[{self.port}]: Received message type: {message_type}") + # logger.debug(f"{data}") # Example output ''' @@ -269,7 +273,7 @@ async def _handle_text_message(self, message: str): if message_type == "execution_start": if "data" in data and "prompt_id" in data["data"]: self._prompt_id = data["data"]["prompt_id"] - logger.info(f"Execution started for prompt {self._prompt_id}") + logger.debug(f"[Client[{self.port}]: Execution started for prompt {self._prompt_id}") # Let's queue the next prompt here! 
self.execution_complete_event.set() @@ -391,7 +395,7 @@ async def _execute_prompt(self, prompt_index: int): # Try to get frame_id from side_data if hasattr(frame_or_tensor.side_data, 'frame_id'): frame_id = frame_or_tensor.side_data.frame_id - logger.info(f"Found frame_id in side_data: {frame_id}") + logger.debug(f"Found frame_id in side_data: {frame_id}") # Store current frame ID for binary message handler to use self._current_frame_id = frame_id @@ -519,13 +523,13 @@ async def _execute_prompt(self, prompt_index: int): # Map prompt_id to frame_id for later retrieval if frame_id is not None: self._frame_id_mapping[self._prompt_id] = frame_id - logger.info(f"Mapped prompt_id {self._prompt_id} to frame_id {frame_id}") + logger.debug(f"Mapped prompt_id {self._prompt_id} to frame_id {frame_id}") else: error_text = await response.text() logger.error(f"Error queueing prompt: {response.status} - {error_text}") self.execution_complete_event.set() else: - logger.info("No tensor in input queue, skipping prompt execution") + logger.debug("No tensor in input queue, skipping prompt execution") self.execution_complete_event.set() except Exception as e: @@ -534,7 +538,7 @@ async def _execute_prompt(self, prompt_index: int): async def cleanup(self): """Clean up resources and reset connection state completely.""" - logger.info("Performing client cleanup and connection reset") + logger.info(f"[Client[{self.port}]: Performing client cleanup and connection reset") # Cancel the WebSocket listener task if self._ws_listener_task is not None and not self._ws_listener_task.done(): @@ -575,7 +579,7 @@ async def cleanup(self): # Reset buffer self.buffer = BytesIO() - logger.info("Client cleanup completed, connection will be reestablished on next use") + logger.info(f"[Client[{self.port}]: Client cleanup completed, connection will be reestablished on next use") async def cleanup_queues(self): """Clean up tensor queues""" @@ -597,7 +601,7 @@ async def cleanup_queues(self): except: pass - 
logger.info("Tensor queues cleared") + logger.info(f"[Client[{self.port}]: Tensor queues cleared") def put_video_input(self, frame): if tensor_cache.image_inputs.full(): @@ -615,12 +619,12 @@ async def get_video_output(self): # Check if the result is a tuple with frame_id if isinstance(result, tuple) and len(result) == 2: frame_id, tensor = result - logger.info(f"Got processed tensor from output queue with frame_id {frame_id}") + logger.debug(f"[Client[{self.port}]: Got processed tensor from output queue with frame_id {frame_id}") # Return both the frame_id and tensor to help with ordering in the pipeline return frame_id, tensor else: # If it's not a tuple with frame_id, just return the tensor - logger.info("Got processed tensor from output queue without frame_id") + logger.debug("Got processed tensor from output queue without frame_id") return result async def get_audio_output(self): @@ -776,97 +780,34 @@ def create_enum_format(type_name): return all_prompts_nodes_info - async def start_server(self): - """Launch the ComfyUI server if spawn is True""" - if self.spawn: - if not self.comfyui_path: - raise ValueError("comfyui_path must be provided when spawn=True") - self._launch_comfyui_server() - self._wait_for_server_ready() - logger.info("ComfyUI server started successfully") - else: - logger.info("Using existing ComfyUI server (spawn=False)") - - async def _fetch_object_info(self, session: aiohttp.ClientSession, url: str, class_type: str) -> Optional[tuple[str, Any]]: - """Helper function to fetch object info for a single class type.""" - try: - logger.debug(f"Fetching object info for: {class_type} from {url}") - async with session.get(url) as response: - if response.status == 200: - try: - data = await response.json() - # Extract the actual node info from the nested structure - if class_type in data and isinstance(data[class_type], dict): - node_specific_info = data[class_type] - logger.debug(f"Successfully fetched and extracted info for {class_type}") - return 
class_type, node_specific_info - else: - logger.error(f"Unexpected response structure for {class_type}. Key missing or not a dict. Response: {data}") - - except aiohttp.ContentTypeError: - logger.error(f"Failed to decode JSON for {class_type}. Status: {response.status}, Content-Type: {response.headers.get('Content-Type')}, Response: {await response.text()[:200]}...") # Log beginning of text - except json.JSONDecodeError as e: - logger.error(f"Invalid JSON received for {class_type}. Status: {response.status}, Error: {e}, Response: {await response.text()[:200]}...") - else: - error_text = await response.text() - logger.error(f"Error fetching info for {class_type}: {response.status} - {error_text[:200]}...") - except aiohttp.ClientError as e: - logger.error(f"HTTP client error fetching info for {class_type} ({url}): {e}") - except Exception as e: - logger.error(f"Unexpected error fetching info for {class_type} ({url}): {e}") - - # Return class_type and None if any error occurred - return class_type, None - - def _launch_comfyui_server(self): + def launch_comfyui_server(self): """Launch ComfyUI as a subprocess""" - logger.info(f"Spawning ComfyUI server on port {self.port}...") - - # Build the command with just the basics - cmd = [ - "python", self.comfyui_path, - ] - - # Add the arguments from comfyui_args if provided - if self.comfyui_args: - cmd.extend(self.comfyui_args) - else: - # Only add default arguments if comfyui_args was not provided - cmd.extend([ - "--disable-cuda-malloc", # Helps prevent CUDA memory issues - "--gpu-only", # Use GPU for all operations when possible - "--preview-method", "none", # Disable previews to save memory - "--listen", - "--port", str(self.port), - "--fast", - "--enable-cors-header", "\"*\"", - "--disable-xformers", # More compatible with some systems - ]) - - # Add workspace if provided and not in comfyui_args - if self.workspace: - cmd.extend(["--dir", self.workspace]) - - # Check if CUDA is available and add device argument - if 
hasattr(torch, 'cuda') and torch.cuda.is_available(): - cuda_device = os.environ.get("CUDA_VISIBLE_DEVICES", "0") - cmd.extend(["--cuda-device", cuda_device]) - - # Always ensure port is set correctly (override if provided in comfyui_args) - # Remove any existing --port argument - if "--port" in cmd: - port_index = cmd.index("--port") - # Remove both the flag and its value - if port_index + 1 < len(cmd): - cmd.pop(port_index + 1) - cmd.pop(port_index) - - # Add our port - cmd.extend(["--port", str(self.port)]) + logger.info(f"[Client[{self.port}]: Spawning ComfyUI server...") # Start the process try: - logger.info(f"Starting ComfyUI with command: {' '.join(cmd)}") + # Build the command with just the basics + cmd = [ + "python", self.comfyui_path, + ] + + # Add the arguments from comfyui_args (retreived from pipeline_api) + ''' + [ + "--disable-cuda-malloc", + "--gpu-only", + "--preview-method", "none", + "--listen", + "--cuda-device", str(cuda_device), + "--fast", + "--enable-cors-header", "\"*\"", + "--port", str(port), + "--disable-xformers", + ] + ''' + cmd.extend(self.comfyui_args) + + logger.info(f"[Client[{self.port}]: Starting ComfyUI with command: {' '.join(cmd)}") self._comfyui_proc = subprocess.Popen( cmd, stdout=subprocess.PIPE, @@ -877,26 +818,24 @@ def _launch_comfyui_server(self): ) # Start a thread to log output - def log_output(stream, level): + def log_output(stream): for line in iter(stream.readline, ''): - # Check for known error patterns and log them at appropriate levels - if "error" in line.lower() or "exception" in line.lower(): - logger.error(f"ComfyUI[{self.port}]: {line.strip()}") - else: - logger.debug(f"ComfyUI[{self.port}]: {line.strip()}") + # TODO: Handle different log levels? 
+ if self.comfyui_log_level == "DEBUG": + logger.info(f"ComfyUI[{self.port}]: {line.strip()}") import threading - threading.Thread(target=log_output, args=(self._comfyui_proc.stdout, logging.INFO), daemon=True).start() - threading.Thread(target=log_output, args=(self._comfyui_proc.stderr, logging.INFO), daemon=True).start() + threading.Thread(target=log_output, args=(self._comfyui_proc.stdout,), daemon=True).start() + threading.Thread(target=log_output, args=(self._comfyui_proc.stderr,), daemon=True).start() - logger.info(f"Started ComfyUI process with PID {self._comfyui_proc.pid}") + logger.info(f"[Client[{self.port}]: Started ComfyUI process with PID {self._comfyui_proc.pid}") except Exception as e: logger.error(f"Failed to spawn ComfyUI: {e}") raise - def _wait_for_server_ready(self, timeout=60, check_interval=0.5): + def wait_for_server_ready(self, timeout=60, check_interval=0.5): """Wait until the ComfyUI server is accepting connections""" - logger.info(f"Waiting for ComfyUI server on port {self.port} to be ready...") + logger.info(f"[Client[{self.port}]: Waiting for ComfyUI server to be ready...") start_time = time.time() while time.time() - start_time < timeout: @@ -909,15 +848,48 @@ def _wait_for_server_ready(self, timeout=60, check_interval=0.5): # Try to connect to the server try: with socket.create_connection((self.host, self.port), timeout=2): - logger.info(f"ComfyUI server on port {self.port} is now accepting connections") + logger.info(f"[Client[{self.port}]: ComfyUI server is now accepting connections") return except (ConnectionRefusedError, socket.timeout, OSError): # Sleep and try again time.sleep(check_interval) # If we get here, the server didn't start in time - logger.error(f"Timed out waiting for ComfyUI server on port {self.port}") + logger.error(f"[Client[{self.port}]: Timed out waiting for ComfyUI server") + if self._comfyui_proc: self._comfyui_proc.terminate() self._comfyui_proc = None - raise RuntimeError(f"Timed out waiting for ComfyUI 
server on port {self.port}") \ No newline at end of file + + raise RuntimeError(f"Timed out waiting for ComfyUI server on port {self.port}") + + async def _fetch_object_info(self, session: aiohttp.ClientSession, url: str, class_type: str) -> Optional[tuple[str, Any]]: + """Helper function to fetch object info for a single class type.""" + try: + logger.debug(f"Fetching object info for: {class_type} from {url}") + async with session.get(url) as response: + if response.status == 200: + try: + data = await response.json() + # Extract the actual node info from the nested structure + if class_type in data and isinstance(data[class_type], dict): + node_specific_info = data[class_type] + logger.debug(f"Successfully fetched and extracted info for {class_type}") + return class_type, node_specific_info + else: + logger.error(f"Unexpected response structure for {class_type}. Key missing or not a dict. Response: {data}") + + except aiohttp.ContentTypeError: + logger.error(f"Failed to decode JSON for {class_type}. Status: {response.status}, Content-Type: {response.headers.get('Content-Type')}, Response: {await response.text()[:200]}...") # Log beginning of text + except json.JSONDecodeError as e: + logger.error(f"Invalid JSON received for {class_type}. 
Status: {response.status}, Error: {e}, Response: {await response.text()[:200]}...") + else: + error_text = await response.text() + logger.error(f"Error fetching info for {class_type}: {response.status} - {error_text[:200]}...") + except aiohttp.ClientError as e: + logger.error(f"HTTP client error fetching info for {class_type} ({url}): {e}") + except Exception as e: + logger.error(f"Unexpected error fetching info for {class_type} ({url}): {e}") + + # Return class_type and None if any error occurred + return class_type, None \ No newline at end of file diff --git a/src/comfystream/pipeline_api.py b/src/comfystream/pipeline_api.py index 79321785..a4bb49ea 100644 --- a/src/comfystream/pipeline_api.py +++ b/src/comfystream/pipeline_api.py @@ -32,6 +32,7 @@ def __init__( workers: int = 2, cuda_devices: str = '0', workers_start_port: int = 8195, + comfyui_log_level: str = None, ): """Initialize the pipeline with the given configuration. Args: @@ -47,6 +48,7 @@ def __init__( "spawn": Spawn ComfyUI clients as external processes. workers_start_port: The starting port number for worker processes (default: 8195). cuda_devices: The list of CUDA devices to use for the ComfyUI clients. 
+ comfyui_log_level: The logging level for ComfyUI """ # There are two methods for starting the clients: @@ -110,6 +112,8 @@ def __init__( self.last_output_time = None self.frame_interval_history = collections.deque(maxlen=30) self.output_pacer_task = asyncio.create_task(self._dynamic_output_pacer()) + + self.comfyui_log_level = comfyui_log_level async def _collect_processed_frames(self): """Background task to collect processed frames from all clients""" @@ -151,7 +155,7 @@ async def _collect_processed_frames(self): # Remove the mapping self.client_frame_mapping.pop(frame_id, None) - logger.info(f"Collected processed frame from client {i}, frame_id: {frame_id}") + logger.debug(f"Collected processed frame from client {i}, frame_id: {frame_id}") except asyncio.TimeoutError: # No frame ready yet, continue pass @@ -185,7 +189,7 @@ async def _release_ordered_frames(self): if self.ordered_frames and self.next_expected_frame_id in self.ordered_frames: timestamp, tensor = self.ordered_frames.pop(self.next_expected_frame_id) await self.processed_video_frames.put((self.next_expected_frame_id, tensor)) - logger.info(f"Released frame {self.next_expected_frame_id} to output queue") + logger.debug(f"Released frame {self.next_expected_frame_id} to output queue") if self.ordered_frames: self.next_expected_frame_id = min(self.ordered_frames.keys()) else: @@ -523,7 +527,7 @@ async def _dynamic_output_pacer(self): self.last_output_time = now await self.processed_video_frames.put((self.next_expected_frame_id, tensor)) - logger.info(f"Released frame {self.next_expected_frame_id} to output queue") + logger.debug(f"Released frame {self.next_expected_frame_id} to output queue") # Update next expected frame ID if self.ordered_frames: @@ -542,143 +546,100 @@ async def start_clients(self): logger.info(f"Starting clients with mode: {self.client_mode}") self.clients = [] + self.startup_error = None - if hasattr(self, 'client_mode') and self.client_mode == "toml": - # Use config file to create 
clients - for server_config in self.servers: - self.clients.append(ComfyStreamClient( - host=server_config["host"], - port=server_config["port"], - spawn=False, - )) - - elif hasattr(self, 'client_mode') and self.client_mode == "spawn": - # Spin up clients as external processes - ports = [] - cuda_device_list = [d.strip() for d in str(self.cuda_devices).split(',') if d.strip()] - for device_idx, cuda_device in enumerate(cuda_device_list): - for worker_idx in range(self.workers): - port = self.workers_start_port + len(ports) - ports.append(port) - client = ComfyStreamClient( - host="127.0.0.1", - port=port, - spawn=True, - comfyui_path=os.path.join(self.workspace, "main.py"), - workspace=self.workspace, - comfyui_args=[ - "--disable-cuda-malloc", - "--gpu-only", - "--preview-method", "none", - "--listen", - "--cuda-device", str(cuda_device), - "--fast", - "--enable-cors-header", "\"*\"", - "--port", str(port), - "--disable-xformers", - ], - ) - self.clients.append(client) - logger.info(f"Created worker {worker_idx+1}/{self.workers} for CUDA device {cuda_device} on port {port}") + try: + if hasattr(self, 'client_mode') and self.client_mode == "toml": + # Use config file to create clients + for server_config in self.servers: + self.clients.append(ComfyStreamClient( + host=server_config["host"], + port=server_config["port"], + spawn=False, + comfyui_log_level=self.comfyui_log_level, + )) + + elif hasattr(self, 'client_mode') and self.client_mode == "spawn": + # Spin up clients as external processes + ports = [] + cuda_device_list = [d.strip() for d in str(self.cuda_devices).split(',') if d.strip()] + for device_idx, cuda_device in enumerate(cuda_device_list): + for worker_idx in range(self.workers): + port = self.workers_start_port + len(ports) + ports.append(port) + client = ComfyStreamClient( + host="127.0.0.1", + port=port, + spawn=True, + comfyui_path=os.path.join(self.workspace, "main.py"), + workspace=self.workspace, + comfyui_args=[ + "--disable-cuda-malloc", + 
"--gpu-only", + "--preview-method", "none", + "--listen", + "--cuda-device", str(cuda_device), + "--fast", + "--enable-cors-header", "\"*\"", + "--port", str(port), + "--disable-xformers", + ], + comfyui_log_level=self.comfyui_log_level, + ) + self.clients.append(client) + logger.info(f"Created worker {worker_idx+1}/{self.workers} for CUDA device {cuda_device} on port {port}") - else: - raise ValueError(f"Unknown client_mode: {getattr(self, 'client_mode', 'None')}") - - # Start all ComfyUI servers in parallel if in spawn mode - if hasattr(self, 'client_mode') and self.client_mode == "spawn": - # First, launch all server processes in parallel - for client in self.clients: - if client.spawn: - client._launch_comfyui_server() + else: + raise ValueError(f"Unknown client_mode: {getattr(self, 'client_mode', 'None')}") - # Now create async functions to check server readiness - async def check_server_ready(client, timeout=60, check_interval=0.5, max_retries=3): - """Async version of waiting for server to be ready with retry logic""" - logger.info(f"Waiting for ComfyUI server on port {client.port} to be ready...") - - retries = 0 - while retries <= max_retries: - start_time = time.time() - while time.time() - start_time < timeout: - # Check if process is still running - if client._comfyui_proc and client._comfyui_proc.poll() is not None: - return_code = client._comfyui_proc.poll() - logger.error(f"ComfyUI process exited with code {return_code} before it was ready") - - # If we still have retries left, restart the process - if retries < max_retries: - retries += 1 - logger.info(f"Retrying ComfyUI server on port {client.port} (attempt {retries}/{max_retries})") - - # Kill any zombie process - if client._comfyui_proc: - try: - client._comfyui_proc.terminate() - except Exception: - pass - - # Start a new process - client._launch_comfyui_server() - await asyncio.sleep(2) # Give it a moment to start - break # Break inner loop to restart timeout - else: - # We're out of retries - 
raise RuntimeError(f"ComfyUI process exited with code {return_code} after {max_retries} retries") - - # Try to connect to the server - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(2) - result = sock.connect_ex((client.host, client.port)) - sock.close() - - if result == 0: - logger.info(f"ComfyUI server on port {client.port} is now accepting connections") - return - except Exception: - pass - - # Sleep and try again - await asyncio.sleep(check_interval) - - # If we break out of the inner loop due to a restart, continue - # If we break out due to timeout, increment retries and try again - if time.time() - start_time >= timeout: - retries += 1 - if retries <= max_retries: - logger.info(f"Timed out waiting for ComfyUI server on port {client.port}, retrying (attempt {retries}/{max_retries})") - - # Kill any zombie process - if client._comfyui_proc: - try: - client._comfyui_proc.terminate() - except Exception: - pass - - # Start a new process - client._launch_comfyui_server() - await asyncio.sleep(2) # Give it a moment to start - else: - # We're out of retries - logger.error(f"Timed out waiting for ComfyUI server on port {client.port} after {max_retries} retries") - if client._comfyui_proc: + # Start all ComfyUI servers in parallel if in spawn mode + if hasattr(self, 'client_mode') and self.client_mode == "spawn": + try: + # Get all spawn clients + spawn_clients = [client for client in self.clients if client.spawn] + if spawn_clients: + logger.info(f"Starting {len(spawn_clients)} ComfyUI servers in parallel") + + # First validate all clients (keeping original validation logic) + for client in spawn_clients: + # These checks are from the original start_server method + if not client.comfyui_path: + raise ValueError("comfyui_path must be provided when spawn=True") + if not os.path.exists(client.comfyui_path): + raise FileNotFoundError(f"ComfyUI path does not exist: {client.comfyui_path}") + + # Start all server processes WITHOUT waiting 
for them to be ready + for client in spawn_clients: + client.launch_comfyui_server() + + # Now wait for all servers to be ready in parallel using thread pool + await asyncio.gather(*[ + asyncio.to_thread(client.wait_for_server_ready) + for client in spawn_clients + ]) + + except Exception as e: + # Clean up any clients that might have started + for client in self.clients: + if hasattr(client, '_comfyui_proc') and client._comfyui_proc: + try: client._comfyui_proc.terminate() - client._comfyui_proc = None - raise RuntimeError(f"Timed out waiting for ComfyUI server on port {client.port} after {max_retries} retries") - - # Wait for all servers to be ready in parallel - wait_tasks = [] - for client in self.clients: - if client.spawn: - wait_tasks.append(check_server_ready(client)) + except: + pass + + self.clients = [] + self.startup_error = str(e) + logger.error(f"Failed to start ComfyUI servers: {e}") + return None - if wait_tasks: - logger.info(f"Waiting for {len(wait_tasks)} ComfyUI servers to become ready...") - await asyncio.gather(*wait_tasks) - logger.info(f"All {len(wait_tasks)} ComfyUI servers are ready") - - logger.info(f"Initialized {len(self.clients)} clients") - return self.clients - + logger.info(f"Initialized {len(self.clients)} clients") + return self.clients + + except Exception as e: + self.startup_error = str(e) + logger.error(f"Error starting clients: {e}") + self.clients = [] + return None + # For backwards compatibility, maintain the original Pipeline name Pipeline = MultiServerPipeline \ No newline at end of file From 5b3ac66af681662e4da7ac946b654c359e54fb11 Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Tue, 29 Apr 2025 17:45:36 -0400 Subject: [PATCH 22/42] Modification to logging handler for subprocesses. 
--- src/comfystream/client_api.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/src/comfystream/client_api.py b/src/comfystream/client_api.py index 7975384c..5d328287 100644 --- a/src/comfystream/client_api.py +++ b/src/comfystream/client_api.py @@ -807,22 +807,38 @@ def launch_comfyui_server(self): ''' cmd.extend(self.comfyui_args) + # Set up environment with proper encoding and ANSI support + env = os.environ.copy() + env.update({ + 'PYTHONIOENCODING': 'utf-8', + 'PYTHONLEGACYWINDOWSSTDIO': 'utf-8', + 'FORCE_COLOR': '1' + }) + logger.info(f"[Client[{self.port}]: Starting ComfyUI with command: {' '.join(cmd)}") self._comfyui_proc = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - env=os.environ.copy(), + env=env, text=True, + encoding='utf-8', + errors='replace', bufsize=1, # Line buffered ) - # Start a thread to log output def log_output(stream): for line in iter(stream.readline, ''): - # TODO: Handle different log levels? - if self.comfyui_log_level == "DEBUG": - logger.info(f"ComfyUI[{self.port}]: {line.strip()}") + try: + if self.comfyui_log_level == "DEBUG": + # Strip ANSI codes if they cause problems + message = line.strip() + # Optional: Remove ANSI codes if they still cause issues + # import re + # message = re.sub(r'\033\[[0-9;]*[mGKH]', '', message) + logger.info(f"ComfyUI[{self.port}]: {message}") + except Exception as e: + logger.error(f"Error logging output: {e}") import threading threading.Thread(target=log_output, args=(self._comfyui_proc.stdout,), daemon=True).start() From 07a9c51217114e49309b20da8921b44652cf0d64 Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Fri, 2 May 2025 15:30:26 -0400 Subject: [PATCH 23/42] Added optional ROOT_DIR environment variable to help with build_trt script. 
--- src/comfystream/scripts/build_trt.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/comfystream/scripts/build_trt.py b/src/comfystream/scripts/build_trt.py index fa0a0f24..866d7de6 100644 --- a/src/comfystream/scripts/build_trt.py +++ b/src/comfystream/scripts/build_trt.py @@ -11,9 +11,9 @@ # $> python src/comfystream/scripts/build_trt.py --model /ComfyUI/models/checkpoints/SD1.5/dreamshaper-8.safetensors --out-engine /ComfyUI/output/tensorrt/static-dreamshaper8_SD15_$stat-b-1-h-512-w-512_00001_.engine # Paths path explicitly to use the downloaded comfyUI installation on root -ROOT_DIR="/workspace" -COMFYUI_DIR = "/workspace/ComfyUI" -timing_cache_path = "/workspace/ComfyUI/output/tensorrt/timing_cache" +ROOT_DIR = os.environ.get("ROOT_DIR", "/workspace") +COMFYUI_DIR = os.path.join(ROOT_DIR, "ComfyUI") +timing_cache_path = os.path.join(ROOT_DIR, "ComfyUI/output/tensorrt/timing_cache") if ROOT_DIR not in sys.path: sys.path.insert(0, ROOT_DIR) @@ -21,9 +21,9 @@ sys.path.insert(0, COMFYUI_DIR) comfy_dirs = [ - "/workspace/ComfyUI/", - "/workspace/ComfyUI/comfy", - "/workspace/ComfyUI/comfy_extras" + COMFYUI_DIR, + os.path.join(COMFYUI_DIR, "comfy"), + os.path.join(COMFYUI_DIR, "comfy_extras") ] for comfy_dir in comfy_dirs: From fb44f6d2d9ae767278f8062bdf12bbe13939543c Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Tue, 6 May 2025 18:48:34 -0400 Subject: [PATCH 24/42] Code cleanup, modification to frame timing mechanism, added frame logging utility. 
--- server/app_api.py | 8 +- src/comfystream/client_api.py | 205 ++++++++++++++----------------- src/comfystream/frame_logging.py | 170 +++++++++++++++++++++++++ src/comfystream/pipeline_api.py | 157 +++++++++++------------ 4 files changed, 333 insertions(+), 207 deletions(-) create mode 100644 src/comfystream/frame_logging.py diff --git a/server/app_api.py b/server/app_api.py index a5330bca..64666b3d 100644 --- a/server/app_api.py +++ b/server/app_api.py @@ -158,7 +158,6 @@ async def collect_frames(self): async def recv(self): return await self.pipeline.get_processed_audio_frame() - def force_codec(pc, sender, forced_codec): kind = forced_codec.split("/")[0] codecs = RTCRtpSender.getCapabilities(kind).codecs @@ -398,10 +397,9 @@ async def on_startup(app: web.Application): app["pipeline"] = Pipeline( width=512, height=512, - comfyui_inference_log_level=app.get("comfui_inference_log_level", None), config_path=app["config_file"], max_frame_wait_ms=app["max_frame_wait"], - client_mode=app["client_mode"], + client_mode=app["client_mode"], workspace=app["workspace"], workers=app["workers"], cuda_devices=app["cuda_devices"], @@ -517,6 +515,7 @@ async def on_shutdown(app: web.Application): ) # Set logger level based on command line arguments + print(f"Setting log level to {args.log_level.upper()}") logger.setLevel(getattr(logging, args.log_level.upper())) app = web.Application() @@ -567,12 +566,11 @@ def force_print(*args, **kwargs): sys.stdout.flush() # Allow overriding of ComyfUI log levels. 
- # TODO: This will have to pipe to spawn clients if args.comfyui_log_level: log_level = logging._nameToLevel.get(args.comfyui_log_level.upper()) logging.getLogger("comfy").setLevel(log_level) app["comfyui_log_level"] = args.comfyui_log_level if args.comfyui_inference_log_level: - app["comfui_inference_log_level"] = args.comfyui_inference_log_level + app["comfyui_inference_log_level"] = args.comfyui_inference_log_level web.run_app(app, host=args.host, port=int(args.port), print=force_print) diff --git a/src/comfystream/client_api.py b/src/comfystream/client_api.py index 5d328287..61e57102 100644 --- a/src/comfystream/client_api.py +++ b/src/comfystream/client_api.py @@ -358,14 +358,9 @@ async def _handle_binary_message(self, binary_data): frame_id = self._current_frame_id logger.debug(f"Using current frame_id {frame_id}") - # Add to output queue - include frame_id if available - if frame_id is not None: - tensor_cache.image_outputs.put_nowait((frame_id, tensor)) - logger.debug(f"Added tensor with frame_id {frame_id} to output queue") - else: - tensor_cache.image_outputs.put_nowait(tensor) - logger.debug("Added tensor without frame_id to output queue") - + tensor_cache.image_outputs.put_nowait((frame_id, tensor)) + logger.debug(f"Added tensor with frame_id {frame_id} to output queue") + # We will execute the next prompt from message_type == "execution_start" instead # self.execution_complete_event.set() @@ -385,20 +380,24 @@ async def _execute_prompt(self, prompt_index: int): # Check if we have a frame waiting to be processed if not tensor_cache.image_inputs.empty(): # Get the most recent frame only - frame_or_tensor = None + frame = None while not tensor_cache.image_inputs.empty(): - frame_or_tensor = tensor_cache.image_inputs.get_nowait() - - # Extract frame ID if available in side_data - frame_id = None - if hasattr(frame_or_tensor, 'side_data'): - # Try to get frame_id from side_data - if hasattr(frame_or_tensor.side_data, 'frame_id'): - frame_id = 
frame_or_tensor.side_data.frame_id - logger.debug(f"Found frame_id in side_data: {frame_id}") - - # Store current frame ID for binary message handler to use - self._current_frame_id = frame_id + frame = tensor_cache.image_inputs.get_nowait() + + self._current_frame_id = getattr(frame.side_data, 'frame_id', None) + + if self._current_frame_id is None: + logger.error("No frame_id found in side_data") + self.execution_complete_event.set() + return + + if not (hasattr(frame, 'side_data') and hasattr(frame.side_data, 'input')): + logger.error( + "Frame object from queue ('tensor_cache.image_inputs') is not structured as " + "expected (missing side_data.input). Skipping processing for this frame." + ) + self.execution_complete_event.set() # Allow next cycle + return # Find LoadImageBase64 nodes first load_image_nodes = [] @@ -413,87 +412,72 @@ async def _execute_prompt(self, prompt_index: int): # Process the tensor ONLY if we have nodes to send it to try: - # Get the actual tensor data - handle different input types - tensor = None - - # Handle different input types efficiently - if hasattr(frame_or_tensor, 'side_data') and hasattr(frame_or_tensor.side_data, 'input'): - tensor = frame_or_tensor.side_data.input - elif isinstance(frame_or_tensor, torch.Tensor): - tensor = frame_or_tensor - elif isinstance(frame_or_tensor, np.ndarray): - tensor = torch.from_numpy(frame_or_tensor).float() - elif hasattr(frame_or_tensor, 'to_ndarray'): - frame_np = frame_or_tensor.to_ndarray(format="rgb24").astype(np.float32) / 255.0 - tensor = torch.from_numpy(frame_np).unsqueeze(0) - + tensor = getattr(frame.side_data, 'input', None) + if tensor is None: - logger.error("Failed to get valid tensor data from input") + logger.error("No tensor found in side_data") self.execution_complete_event.set() return - - # Process tensor format only once - streamlined for speed and reliability - with torch.no_grad(): - # Fast tensor normalization to ensure consistent output - try: - # TODO: Why is the UI 
sending different sizes? Should be fixed no? This breaks tensorrt - # I'm sometimes seeing (BCHW): torch.Size([1, 384, 384, 3]), H=384, W=3 - # Ensure minimum size of 512x512 - - # Handle batch dimension if present - if len(tensor.shape) == 4: # BCHW format - tensor = tensor[0] # Take first image from batch - - # Normalize to CHW format consistently - if len(tensor.shape) == 3 and tensor.shape[2] == 3: # HWC format - tensor = tensor.permute(2, 0, 1) # Convert to CHW - - # Handle single-channel case - if len(tensor.shape) == 3 and tensor.shape[0] == 1: - tensor = tensor.repeat(3, 1, 1) # Convert grayscale to RGB - - # Ensure tensor is on CPU - if tensor.is_cuda: - tensor = tensor.cpu() - - # Always resize to 512x512 for consistency (faster than checking dimensions first) - tensor = tensor.unsqueeze(0) # Add batch dim for interpolate - tensor = torch.nn.functional.interpolate( - tensor, size=(512, 512), mode='bilinear', align_corners=False - ) - tensor = tensor[0] # Remove batch dimension - - # ==== - # PIL method - ''' - # Direct conversion to PIL without intermediate numpy step for speed - tensor_np = (tensor.permute(1, 2, 0).clamp(0, 1) * 255).to(torch.uint8).numpy() - img = Image.fromarray(tensor_np) - img.save(self.buffer, format="JPEG", quality=90, optimize=True) - ''' - - # ==== - # torchvision method (more performant - TODO: need to test further) - # Direct conversion to PIL without intermediate numpy step - # Fast JPEG encoding with reduced quality for better performance - tensor_pil = to_pil_image(tensor.clamp(0, 1)) - tensor_pil.save(self.buffer, format="JPEG", quality=75, optimize=True) - # ==== - - self.buffer.seek(0) - img_base64 = base64.b64encode(self.buffer.getvalue()).decode('utf-8') - - except Exception as e: - logger.warning(f"Error in tensor processing: {e}, creating fallback image") - # Create a standard 512x512 placeholder if anything fails - img = Image.new('RGB', (512, 512), color=(100, 149, 237)) - self.buffer = BytesIO() - 
img.save(self.buffer, format="JPEG", quality=90) - self.buffer.seek(0) - img_base64 = base64.b64encode(self.buffer.getvalue()).decode('utf-8') + + try: + # TODO: Why is the UI sending different sizes? Should be fixed no? This breaks tensorrt + # I'm sometimes seeing (BCHW): torch.Size([1, 384, 384, 3]), H=384, W=3 + # Ensure minimum size of 512x512 + + # Handle batch dimension if present + if len(tensor.shape) == 4: # BCHW format + tensor = tensor[0] # Take first image from batch + + # Normalize to CHW format consistently + if len(tensor.shape) == 3 and tensor.shape[2] == 3: # HWC format + tensor = tensor.permute(2, 0, 1) # Convert to CHW + + # Handle single-channel case + if len(tensor.shape) == 3 and tensor.shape[0] == 1: + tensor = tensor.repeat(3, 1, 1) # Convert grayscale to RGB - # Add timestamp for cache busting (once, outside the try/except) - timestamp = int(time.time() * 1000) + # Ensure tensor is on CPU + if tensor.is_cuda: + tensor = tensor.cpu() + + # Always resize to 512x512 for consistency (faster than checking dimensions first) + tensor = tensor.unsqueeze(0) # Add batch dim for interpolate + tensor = torch.nn.functional.interpolate( + tensor, size=(512, 512), mode='bilinear', align_corners=False + ) + tensor = tensor[0] # Remove batch dimension + + # ==== + # PIL method + ''' + # Direct conversion to PIL without intermediate numpy step for speed + tensor_np = (tensor.permute(1, 2, 0).clamp(0, 1) * 255).to(torch.uint8).numpy() + img = Image.fromarray(tensor_np) + img.save(self.buffer, format="JPEG", quality=90, optimize=True) + ''' + + # ==== + # torchvision method (more performant - TODO: need to test further) + # Direct conversion to PIL without intermediate numpy step + # Fast JPEG encoding with reduced quality for better performance + tensor_pil = to_pil_image(tensor.clamp(0, 1)) + tensor_pil.save(self.buffer, format="JPEG", quality=75, optimize=True) + # ==== + + self.buffer.seek(0) + img_base64 = 
base64.b64encode(self.buffer.getvalue()).decode('utf-8') + + except Exception as e: + logger.warning(f"Error in tensor processing: {e}, creating fallback image") + # Create a standard 512x512 placeholder if anything fails + img = Image.new('RGB', (512, 512), color=(100, 149, 237)) + self.buffer = BytesIO() + img.save(self.buffer, format="JPEG", quality=90) + self.buffer.seek(0) + img_base64 = base64.b64encode(self.buffer.getvalue()).decode('utf-8') + + # Add timestamp for cache busting (once, outside the try/except) + timestamp = int(time.time() * 1000) # Update all nodes with the SAME base64 string for node_id in load_image_nodes: @@ -520,10 +504,9 @@ async def _execute_prompt(self, prompt_index: int): result = await response.json() self._prompt_id = result.get("prompt_id") - # Map prompt_id to frame_id for later retrieval - if frame_id is not None: - self._frame_id_mapping[self._prompt_id] = frame_id - logger.debug(f"Mapped prompt_id {self._prompt_id} to frame_id {frame_id}") + self._frame_id_mapping[self._prompt_id] = self._current_frame_id + # logger.debug(f"Mapped prompt_id {self._prompt_id} to frame_id {self._current_frame_id}") + else: error_text = await response.text() logger.error(f"Error queueing prompt: {response.status} - {error_text}") @@ -614,18 +597,10 @@ def put_audio_input(self, frame): async def get_video_output(self): """Get processed video frame from tensor cache""" - result = await tensor_cache.image_outputs.get() - - # Check if the result is a tuple with frame_id - if isinstance(result, tuple) and len(result) == 2: - frame_id, tensor = result - logger.debug(f"[Client[{self.port}]: Got processed tensor from output queue with frame_id {frame_id}") - # Return both the frame_id and tensor to help with ordering in the pipeline - return frame_id, tensor - else: - # If it's not a tuple with frame_id, just return the tensor - logger.debug("Got processed tensor from output queue without frame_id") - return result + frame_id, tensor = await 
tensor_cache.image_outputs.get() + logger.debug(f"[Client[{self.port}]: Got processed tensor from output queue with frame_id {frame_id}") + # Return both the frame_id and tensor to help with ordering in the pipeline + return frame_id, tensor async def get_audio_output(self): """Get processed audio frame from tensor cache""" @@ -814,7 +789,7 @@ def launch_comfyui_server(self): 'PYTHONLEGACYWINDOWSSTDIO': 'utf-8', 'FORCE_COLOR': '1' }) - + logger.info(f"[Client[{self.port}]: Starting ComfyUI with command: {' '.join(cmd)}") self._comfyui_proc = subprocess.Popen( cmd, diff --git a/src/comfystream/frame_logging.py b/src/comfystream/frame_logging.py new file mode 100644 index 00000000..748eef78 --- /dev/null +++ b/src/comfystream/frame_logging.py @@ -0,0 +1,170 @@ +import csv +import os +import time +from typing import Optional, Dict, Any +import argparse +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt + +def log_frame_timing( + frame_id: Optional[int], + frame_received_time: Optional[float], + frame_processed_time: Optional[float], + client_index: Optional[int] = None, + additional_metadata: Optional[Dict[str, Any]] = None, + csv_path: str = "frame_logs.csv" +): + """ + Log frame timing information to a CSV file. + Args: + frame_id: The unique identifier for the frame. + frame_received_time: Timestamp when the frame was received by pipeline. + frame_processed_time: Timestamp when the frame was processed. + client_index: Index of the client that processed this frame. + additional_metadata: Any additional data to log (will be converted to string). + csv_path: Path to the CSV file (default: 'frame_logs.csv'). 
+ """ + latency = None + if frame_received_time is not None and frame_processed_time is not None: + latency = frame_processed_time - frame_received_time + + # Calculate absolute time + current_time = time.time() + + # Convert additional metadata to string if present + metadata_str = str(additional_metadata) if additional_metadata else None + + file_exists = os.path.isfile(csv_path) + with open(csv_path, "a", newline="") as csvfile: + writer = csv.writer(csvfile) + if not file_exists: + header = [ + "log_timestamp", "frame_id", "frame_received_time", + "frame_processed_time", "latency_ms", "client_index", "metadata" + ] + writer.writerow(header) + + # Calculate latency in milliseconds for logging + latency_ms = None + if frame_received_time is not None and frame_processed_time is not None: + latency_ms = (frame_processed_time - frame_received_time) * 1000 + + writer.writerow([ + current_time, frame_id, frame_received_time, frame_processed_time, + latency_ms, client_index, metadata_str + ]) + + +def plot_frame_metrics(csv_path: str = "frame_logs.csv"): + if not os.path.isfile(csv_path): + print(f"CSV file '{csv_path}' not found.") + return + + df = pd.read_csv(csv_path) + if df.empty: + print(f"CSV file '{csv_path}' is empty.") + return + + # Drop rows with missing times or frame_id + df = df.dropna(subset=["frame_id", "frame_received_time", "frame_processed_time"]) + if df.empty: + print("No valid timing data in CSV after dropping NA in essential time columns.") + return + + # Sort by frame_id to ensure correct interval calculations + df = df.sort_values("frame_id").reset_index(drop=True) + + # Calculate time since start of stream (first frame) + stream_start_time = df["frame_received_time"].min() + df["time_since_start"] = df["frame_received_time"] - stream_start_time + + # Calculate intervals between consecutive frames + # Input interval: Time between a frame and the previous frame being received + df["input_interval_s"] = df["frame_received_time"].diff() + # 
Output interval: Time between a frame and the previous frame being processed + df["output_interval_s"] = df["frame_processed_time"].diff() + + # Handle potential zero or negative intervals (e.g., from duplicate timestamps or sorting issues if not by frame_id) + # These would lead to infinite or meaningless FPS, so set them to NaN. + df.loc[df["input_interval_s"] <= 0, "input_interval_s"] = np.nan + df.loc[df["output_interval_s"] <= 0, "output_interval_s"] = np.nan + + # Calculate FPS (Frames Per Second) + # FPS is the reciprocal of the interval in seconds. + df["input_fps"] = 1.0 / df["input_interval_s"] + df["output_fps"] = 1.0 / df["output_interval_s"] + + # Rolling statistics for smoothing + window = 30 # Increased window for smoother, more stable FPS and jitter + df["input_fps_smooth"] = df["input_fps"].rolling(window, min_periods=1).mean() + df["output_fps_smooth"] = df["output_fps"].rolling(window, min_periods=1).mean() + + # Jitter: Standard deviation of intervals (in milliseconds) + # This measures the variation in frame arrival/processing times. + df["input_jitter_ms"] = df["input_interval_s"].rolling(window, min_periods=1).std() * 1000 + df["output_jitter_ms"] = df["output_interval_s"].rolling(window, min_periods=1).std() * 1000 + + # Latency: Time taken from frame reception to frame processing (in milliseconds) + # Use pre-calculated 'latency_ms' if available and valid, otherwise recalculate. 
+ if "latency_ms" not in df.columns or df["latency_ms"].isnull().all(): + print("Recalculating latency_ms as it's not found or all null in CSV.") + df["latency_ms"] = (df["frame_processed_time"] - df["frame_received_time"]) * 1000 + else: + # Ensure it's numeric if it came from CSV + df["latency_ms"] = pd.to_numeric(df["latency_ms"], errors='coerce') + + df["latency_ms_smooth"] = df["latency_ms"].rolling(window, min_periods=1).mean() + + # Create visualization - reduced to 3 plots (removed frame distribution) + fig, axs = plt.subplots(3, 1, figsize=(14, 12), sharex=True) + + # FPS plot + axs[0].plot(df["time_since_start"], df["input_fps_smooth"], label="Input FPS (smooth)", color="blue", linewidth=2) + axs[0].plot(df["time_since_start"], df["output_fps_smooth"], label="Output FPS (smooth)", color="green", linewidth=2) + axs[0].set_ylabel("FPS") + axs[0].set_title("Frame Rate") + axs[0].legend() + axs[0].grid(True) + axs[0].set_xlim(left=0) # Start x-axis at 0 + + # Jitter plot + axs[1].plot(df["time_since_start"], df["input_jitter_ms"], label="Input Jitter (ms)", color="blue", alpha=0.7) + axs[1].plot(df["time_since_start"], df["output_jitter_ms"], label="Output Jitter (ms)", color="green", alpha=0.7) + axs[1].set_ylabel("Jitter (ms)") + axs[1].set_title("Frame Timing Jitter (Rolling StdDev of Intervals)") + axs[1].legend() + axs[1].grid(True) + axs[1].set_xlim(left=0) # Start x-axis at 0 + + # Latency plot - show per-frame latency rather than cumulative + axs[2].plot(df["time_since_start"], df["latency_ms"], label="Per-Frame Latency (ms)", color="red", alpha=0.5) + axs[2].plot(df["time_since_start"], df["latency_ms_smooth"], label="Smoothed Latency (ms)", color="darkred", linewidth=2) + axs[2].set_ylabel("Latency (ms)") + axs[2].set_title("Frame Processing Latency") + axs[2].legend() + axs[2].grid(True) + axs[2].set_xlim(left=0) # Start x-axis at 0 + + # Set y-axis limit to better visualize actual per-frame latency without accumulation + max_latency = 
df["latency_ms"].quantile(0.99) # Use 99th percentile to avoid outliers + axs[2].set_ylim(0, max_latency * 1.1) # Add 10% margin + + # Add x-axis label and make it visible + fig.text(0.5, 0.04, 'Time (seconds)', ha='center', va='center', fontsize=12) + + plt.tight_layout() + plt.subplots_adjust(bottom=0.07) # Add space for x-axis label + plt.savefig(csv_path.replace('.csv', '.png')) + plt.show() + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Plot frame timing metrics from CSV logs.") + parser.add_argument( + "--frame-logs", + type=str, + default="frame_logs.csv", + help="Path to the frame timing CSV log file (default: frame_logs.csv)" + ) + args = parser.parse_args() + plot_frame_metrics(args.frame_logs) diff --git a/src/comfystream/pipeline_api.py b/src/comfystream/pipeline_api.py index a4bb49ea..e9387e2d 100644 --- a/src/comfystream/pipeline_api.py +++ b/src/comfystream/pipeline_api.py @@ -4,16 +4,16 @@ import asyncio import logging import time -import random from collections import OrderedDict import collections import os -import socket +import fractions from typing import Any, Dict, Union, List, Optional, Deque from comfystream.client_api import ComfyStreamClient from comfystream.server.utils import temporary_log_level # Not sure exactly what this does from comfystream.server.utils.config import ComfyConfig +from comfystream.frame_logging import log_frame_timing WARMUP_RUNS = 5 logger = logging.getLogger(__name__) @@ -24,7 +24,6 @@ def __init__( self, width: int = 512, height: int = 512, - comfyui_inference_log_level: int = None, config_path: Optional[str] = None, max_frame_wait_ms: int = 500, client_mode: str = "toml", @@ -39,8 +38,6 @@ def __init__( width: The width of the video frames. height: The height of the video frames. workers: The number of ComfyUI clients to spin up (if client_mode is "spawn"). - comfyui_inference_log_level: The logging level for ComfyUI inference. - Defaults to None, using the global ComfyUI log level. 
config_path: The path to the ComfyUI config toml file (if client_mode is "toml"). max_frame_wait_ms: The maximum number of milliseconds to wait for a frame before dropping it. client_mode: The mode to use for the ComfyUI clients. @@ -86,6 +83,7 @@ def __init__( self.processed_video_frames = asyncio.Queue() # Track which client gets each frame (round-robin) + self.last_frame_time = 0 self.current_client_index = 0 self.client_frame_mapping = {} # Maps frame_id -> client_index @@ -96,10 +94,6 @@ def __init__( # Audio processing self.processed_audio_buffer = np.array([], dtype=np.int16) - self.last_frame_time = 0 - - # ComfyUI inference log level - self._comfyui_inference_log_level = comfyui_inference_log_level # Frame rate limiting self.min_frame_interval = 1/30 # Limit to 30 FPS @@ -123,39 +117,21 @@ async def _collect_processed_frames(self): try: # Non-blocking check if client has output ready if hasattr(client, '_prompt_id') and client._prompt_id is not None: - # Get frame without waiting try: # Use wait_for with small timeout to avoid blocking - result = await asyncio.wait_for( + frame_id, out_tensor = await asyncio.wait_for( client.get_video_output(), - timeout=0.01 + timeout=0.001 ) - # Check if result is already a tuple with frame_id - if isinstance(result, tuple) and len(result) == 2: - frame_id, out_tensor = result - logger.debug(f"Got result with embedded frame_id: {frame_id}") - else: - out_tensor = result - # Find which original frame this corresponds to using our mapping - frame_ids = [frame_id for frame_id, client_idx in - self.client_frame_mapping.items() if client_idx == i] - - if frame_ids: - # Use the oldest frame ID for this client - frame_id = min(frame_ids) - else: - # If no mapping found, log warning and continue - logger.warning(f"No frame_id mapping found for tensor from client {i}") - continue - # Store frame with timestamp for ordering - timestamp = time.time() - await self._add_frame_to_ordered_buffer(frame_id, timestamp, out_tensor) + 
current_time = time.time() + await self._add_frame_to_ordered_buffer(frame_id, current_time, out_tensor) # Remove the mapping self.client_frame_mapping.pop(frame_id, None) - logger.debug(f"Collected processed frame from client {i}, frame_id: {frame_id}") + + # logger.debug(f"Collected processed frame from client {i}, frame_id: {frame_id}") except asyncio.TimeoutError: # No frame ready yet, continue pass @@ -186,14 +162,14 @@ async def _add_frame_to_ordered_buffer(self, frame_id, timestamp, tensor): async def _release_ordered_frames(self): if self.next_expected_frame_id is None: return - if self.ordered_frames and self.next_expected_frame_id in self.ordered_frames: + + # Only release frames in strict sequential order + while self.ordered_frames and self.next_expected_frame_id in self.ordered_frames: timestamp, tensor = self.ordered_frames.pop(self.next_expected_frame_id) await self.processed_video_frames.put((self.next_expected_frame_id, tensor)) logger.debug(f"Released frame {self.next_expected_frame_id} to output queue") - if self.ordered_frames: - self.next_expected_frame_id = min(self.ordered_frames.keys()) - else: - self.next_expected_frame_id += 1 + # Always increment to next sequential frame ID + self.next_expected_frame_id += 1 async def _check_frame_timeouts(self): """Check for frames that have waited too long and handle them""" @@ -208,8 +184,11 @@ async def _check_frame_timeouts(self): wait_time_ms = (current_time - timestamp) * 1000 if wait_time_ms > self.max_frame_wait_ms: - logger.warning(f"Frame {self.next_expected_frame_id} exceeded max wait time, releasing anyway") - await self._release_ordered_frames() + # logger.warning(f"Frame {self.next_expected_frame_id} exceeded max wait time, releasing anyway") + # await self._release_ordered_frames() + + # Remove frame + self.ordered_frames.pop(self.next_expected_frame_id) # Check if we're missing the next expected frame and it's been too long elif self.ordered_frames: @@ -227,15 +206,11 @@ async def 
_check_frame_timeouts(self): async def warm_video(self): # Create dummy frame with the CURRENT resolution settings (which might have been updated via control channel) - - # Create a properly formatted dummy frame - ''' + tensor = torch.rand(1, 3, 512, 512) # Random values in [0,1] dummy_frame = av.VideoFrame(width=512, height=512, format="rgb24") dummy_frame.side_data.input = tensor - ''' - dummy_frame = av.VideoFrame() - dummy_frame.side_data.input = torch.randn(1, self.height, self.width, 3) + dummy_frame.side_data.frame_received_time = time.time() logger.info(f"Warming video pipeline with resolution {self.width}x{self.height}") @@ -251,6 +226,11 @@ async def warm_video(self): async def _warm_client_video(self, client, client_index, dummy_frame): """Warm up a single client""" logger.info(f"Warming up client {client_index}") + + # Set frame input as dummyframe with side_data.input set to a random tensor + dummy_frame.side_data.input = torch.randn(1, self.height, self.width, 3) + dummy_frame.side_data.frame_id = -1 + for i in range(WARMUP_RUNS): logger.info(f"Client {client_index} warmup iteration {i+1}/{WARMUP_RUNS}") client.put_video_input(dummy_frame) @@ -283,6 +263,7 @@ async def set_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]] # Set prompts for each client tasks = [] for client in self.clients: + logger.info(f"Setting prompts for client {client.port}") tasks.append(client.set_prompts(prompts)) await asyncio.gather(*tasks) @@ -302,11 +283,15 @@ async def update_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any logger.info(f"Updated prompts for {len(self.clients)} clients") async def put_video_frame(self, frame: av.VideoFrame): - """Distribute video frames among clients using round-robin""" + ''' Put a video frame into the pipeline round-robin to all clients ''' current_time = time.time() + + ''' if current_time - self.last_frame_time < self.min_frame_interval: + print(f"Skipping frame due to rate limiting: {current_time - 
self.last_frame_time} seconds since last frame") return # Skip frame if too soon - + ''' + self.last_frame_time = current_time # Generate a unique frame ID - use sequential IDs for better ordering @@ -316,28 +301,28 @@ async def put_video_frame(self, frame: av.VideoFrame): frame_id = self.next_frame_id self.next_frame_id += 1 - frame.side_data.frame_id = frame_id - - # Preprocess the frame - frame.side_data.input = self.video_preprocess(frame) - frame.side_data.skipped = False - # Select the next client in round-robin fashion client_index = self.current_client_index self.current_client_index = (self.current_client_index + 1) % len(self.clients) # Store mapping of which client is processing this frame self.client_frame_mapping[frame_id] = client_index + + # Set side data for the frame + frame.side_data.input = self.video_preprocess(frame) + frame.side_data.frame_id = frame_id + frame.side_data.skipped = False + frame.side_data.frame_received_time = time.time() + frame.side_data.client_index = client_index # Send frame to the selected client self.clients[client_index].put_video_input(frame) + await self.video_incoming_frames.put(frame) - # Also add to the incoming queue for reference - await self.video_incoming_frames.put((frame_id, frame)) - - logger.debug(f"Sent frame {frame_id} to client {client_index}") - async def put_audio_frame(self, frame: av.AudioFrame): + ''' Not implemented yet ''' + return + # For now, only use the first client for audio if not self.clients: return @@ -351,16 +336,16 @@ def audio_preprocess(self, frame: av.AudioFrame) -> Union[torch.Tensor, np.ndarr return frame.to_ndarray().ravel().reshape(-1, 2).mean(axis=1).astype(np.int16) def video_preprocess(self, frame: av.VideoFrame) -> Union[torch.Tensor, np.ndarray]: - # Convert directly to tensor, avoiding intermediate numpy array when possible - if hasattr(frame, 'to_tensor'): - tensor = frame.to_tensor() - else: - # If direct tensor conversion not available, use numpy - frame_np = 
frame.to_ndarray(format="rgb24") - tensor = torch.from_numpy(frame_np) + """Preprocess a video frame before processing. - # Normalize to [0,1] range and add batch dimension - return tensor.float().div(255.0).unsqueeze(0) + Args: + frame: The video frame to preprocess + + Returns: + The preprocessed frame as a tensor or numpy array + """ + frame_np = frame.to_ndarray(format="rgb24").astype(np.float32) / 255.0 + return torch.from_numpy(frame_np).unsqueeze(0) def video_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.VideoFrame: return av.VideoFrame.from_ndarray( @@ -377,28 +362,23 @@ def audio_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.Audio async def get_processed_video_frame(self): try: - # Get the original frame from the incoming queue first to maintain timing - frame_id, frame = await self.video_incoming_frames.get() + frame = await self.video_incoming_frames.get() - # Skip frames if we're falling behind - ''' - while not self.video_incoming_frames.empty(): - # Get newer frame and mark old one as skipped - frame.side_data.skipped = True - frame_id, frame = await self.video_incoming_frames.get() - logger.info(f"Skipped older frame {frame_id} to catch up") - ''' # Get the processed frame from our output queue processed_frame_id, out_tensor = await self.processed_video_frames.get() - if processed_frame_id != frame_id: - logger.debug(f"Frame ID mismatch: expected {frame_id}, got {processed_frame_id}") - pass - # Process the frame processed_frame = self.video_postprocess(out_tensor) processed_frame.pts = frame.pts processed_frame.time_base = frame.time_base + + # Log frame timing asynchronously + log_frame_timing( + frame_id=processed_frame_id, + frame_received_time=frame.side_data.frame_received_time, + frame_processed_time=time.time(), + client_index=frame.side_data.client_index, + ) return processed_frame @@ -406,6 +386,12 @@ async def get_processed_video_frame(self): logger.error(f"Error in get_processed_video_frame: 
{str(e)}") # Create a black frame as fallback black_frame = av.VideoFrame(width=self.width, height=self.height, format='rgb24') + + # Set timestamps to avoid TypeError during encoding + # Use default values that work with the aiortc encoding pipeline + black_frame.pts = 0 + black_frame.time_base = fractions.Fraction(1, 90000) # Standard video timebase + return black_frame async def get_processed_audio_frame(self): @@ -529,11 +515,8 @@ async def _dynamic_output_pacer(self): await self.processed_video_frames.put((self.next_expected_frame_id, tensor)) logger.debug(f"Released frame {self.next_expected_frame_id} to output queue") - # Update next expected frame ID - if self.ordered_frames: - self.next_expected_frame_id = min(self.ordered_frames.keys()) - else: - self.next_expected_frame_id += 1 + # Always increment to next sequential frame ID + self.next_expected_frame_id += 1 # Sleep for the dynamic interval, but don't sleep negative time await asyncio.sleep(max(self.output_interval, 0.001)) From 9dc280b71ac6869ac879ab337a0206329bfb6d43 Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Wed, 7 May 2025 18:14:41 -0400 Subject: [PATCH 25/42] Frame logging development, removal of obsolete pipeline code, logging tests. 
--- src/comfystream/frame_logging.py | 400 +++++++++++++++++++++++-------- src/comfystream/pipeline_api.py | 52 +++- 2 files changed, 341 insertions(+), 111 deletions(-) diff --git a/src/comfystream/frame_logging.py b/src/comfystream/frame_logging.py index 748eef78..a069b33f 100644 --- a/src/comfystream/frame_logging.py +++ b/src/comfystream/frame_logging.py @@ -1,170 +1,366 @@ +# frame_logging.py +# Developed by @buffmcbighuge (Marco Tundo) + +# You can generate graphs from the log file: +# python frame_logging.py --frame-logs frame_logs1.csv,frame_logs_2.csv,frame_logs_3.csv + import csv import os import time -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, List, Tuple import argparse import numpy as np import pandas as pd import matplotlib.pyplot as plt +import matplotlib.colors as mcolors def log_frame_timing( frame_id: Optional[int], frame_received_time: Optional[float], + frame_process_start_time: Optional[float], frame_processed_time: Optional[float], client_index: Optional[int] = None, additional_metadata: Optional[Dict[str, Any]] = None, csv_path: str = "frame_logs.csv" ): """ - Log frame timing information to a CSV file. + Log frame timing information to a CSV file with simplified metrics. Args: - frame_id: The unique identifier for the frame. - frame_received_time: Timestamp when the frame was received by pipeline. - frame_processed_time: Timestamp when the frame was processed. - client_index: Index of the client that processed this frame. - additional_metadata: Any additional data to log (will be converted to string). - csv_path: Path to the CSV file (default: 'frame_logs.csv'). 
+ frame_id: The unique identifier for the frame + frame_received_time: Timestamp when the frame was received by pipeline + frame_process_start_time: Timestamp when processing began + frame_processed_time: Timestamp when processing completed + client_index: Index of the client that processed this frame + additional_metadata: Any additional data to log + csv_path: Path to the CSV file """ - latency = None - if frame_received_time is not None and frame_processed_time is not None: - latency = frame_processed_time - frame_received_time - - # Calculate absolute time - current_time = time.time() + # Calculate processing latency + processing_latency = None + if frame_process_start_time is not None and frame_processed_time is not None: + processing_latency = (frame_processed_time - frame_process_start_time) * 1000 # Convert additional metadata to string if present metadata_str = str(additional_metadata) if additional_metadata else None + # Calculate absolute time for logging + current_time = time.time() + + # Determine if this is an input-only frame or a processed frame + is_processed = frame_process_start_time is not None and frame_processed_time is not None + frame_type = "processed" if is_processed else "input" + + # Prepare data based on frame type + if is_processed: + # For processed frames, include all columns + header = [ + "log_timestamp", "frame_id", "frame_type", + "frame_received_time", "frame_process_start_time", "frame_processed_time", + "processing_latency_ms", "client_index", "metadata" + ] + data = [ + current_time, frame_id, frame_type, + frame_received_time, frame_process_start_time, frame_processed_time, + processing_latency, client_index, metadata_str + ] + else: + # For input frames, only include relevant columns (skip processing-related columns) + header = [ + "log_timestamp", "frame_id", "frame_type", + "frame_received_time", "client_index", "metadata" + ] + data = [ + current_time, frame_id, frame_type, + frame_received_time, client_index, 
metadata_str + ] + file_exists = os.path.isfile(csv_path) - with open(csv_path, "a", newline="") as csvfile: - writer = csv.writer(csvfile) - if not file_exists: - header = [ - "log_timestamp", "frame_id", "frame_received_time", - "frame_processed_time", "latency_ms", "client_index", "metadata" - ] - writer.writerow(header) - - # Calculate latency in milliseconds for logging - latency_ms = None - if frame_received_time is not None and frame_processed_time is not None: - latency_ms = (frame_processed_time - frame_received_time) * 1000 + file_empty = file_exists and os.path.getsize(csv_path) == 0 - writer.writerow([ - current_time, frame_id, frame_received_time, frame_processed_time, - latency_ms, client_index, metadata_str - ]) + # Use pandas to handle the CSV file, which handles mixed column formats better + if not file_exists or file_empty: + # If file doesn't exist or is empty, create a new one with the full header + # This ensures the file always has all possible columns defined + full_header = [ + "log_timestamp", "frame_id", "frame_type", + "frame_received_time", "frame_process_start_time", "frame_processed_time", + "processing_latency_ms", "client_index", "metadata" + ] + pd.DataFrame(columns=full_header).to_csv(csv_path, index=False) + # Now append the data + df = pd.DataFrame([dict(zip(header, data))]) + df.to_csv(csv_path, mode='a', header=False, index=False, columns=header) -def plot_frame_metrics(csv_path: str = "frame_logs.csv"): +def process_log_file(csv_path: str) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, float]: + """ + Process a single log file and return the processed dataframes + """ if not os.path.isfile(csv_path): print(f"CSV file '{csv_path}' not found.") - return + return None, None, None, 0 df = pd.read_csv(csv_path) if df.empty: print(f"CSV file '{csv_path}' is empty.") - return + return None, None, None, 0 - # Drop rows with missing times or frame_id - df = df.dropna(subset=["frame_id", "frame_received_time", 
"frame_processed_time"]) + # Drop rows with missing essential times + df = df.dropna(subset=["frame_id", "frame_received_time"]) if df.empty: print("No valid timing data in CSV after dropping NA in essential time columns.") - return + return None, None, None, 0 - # Sort by frame_id to ensure correct interval calculations + # Sort by frame_id and calculate time since start df = df.sort_values("frame_id").reset_index(drop=True) - - # Calculate time since start of stream (first frame) stream_start_time = df["frame_received_time"].min() df["time_since_start"] = df["frame_received_time"] - stream_start_time + + # Separate input and processed frames based on frame_type column + if "frame_type" in df.columns: + input_df = df[df["frame_type"] == "input"].copy() + processed_df = df[df["frame_type"] == "processed"].copy() + else: + # Backward compatibility - separate based on process timestamps + input_df = df[df["frame_process_start_time"].isna()].copy() + processed_df = df.dropna(subset=["frame_processed_time"]).copy() + + # Calculate time of processed frames relative to stream start + if not processed_df.empty: + processed_df.loc[:, "output_time_relative"] = processed_df["frame_processed_time"] - stream_start_time + + # Create a consistent timeline with fixed intervals based on overall activity + max_input_time = df["time_since_start"].max() if not df["time_since_start"].empty else 0 + max_output_time = processed_df["output_time_relative"].max() if not processed_df.empty else 0 + + max_time = max(max_input_time, max_output_time) + time_range = np.arange(0, int(max_time) + 1) + + # Initialize FPS arrays for consistent timeline + input_fps_counts = np.zeros(len(time_range)) + output_fps_counts = np.zeros(len(time_range)) + + # Count frames in each 1-second interval + for t_idx, t_sec in enumerate(time_range): + # For input FPS, count input frames by received time + if not input_df.empty: + input_mask = (input_df["time_since_start"] >= t_sec) & (input_df["time_since_start"] 
< t_sec + 1) + input_fps_counts[t_idx] = input_mask.sum() + + # For output FPS, count processed frames by processed time + if not processed_df.empty: + output_mask = (processed_df["output_time_relative"] >= t_sec) & (processed_df["output_time_relative"] < t_sec + 1) + output_fps_counts[t_idx] = output_mask.sum() + + fps_df = pd.DataFrame({ + "time_bin": time_range, + "input_fps": input_fps_counts, + "output_fps": output_fps_counts + }) + + # Apply smoothing + smoothing_window = 3 + fps_df["input_fps_smooth"] = fps_df["input_fps"].rolling(smoothing_window, min_periods=1).mean() + fps_df["output_fps_smooth"] = fps_df["output_fps"].rolling(smoothing_window, min_periods=1).mean() + + # Calculate frame intervals for input and output frames separately + # Only calculate intervals for the same frame type + if not input_df.empty: + input_df = input_df.sort_values("frame_received_time").reset_index(drop=True) + input_df.loc[:, "input_interval_s"] = input_df["frame_received_time"].diff() + input_df.loc[input_df["input_interval_s"] <= 0, "input_interval_s"] = np.nan + input_df.loc[:, "input_time_bin"] = input_df["time_since_start"].astype(int) + + if not processed_df.empty: + processed_df = processed_df.sort_values("frame_processed_time").reset_index(drop=True) + processed_df.loc[:, "output_interval_s"] = processed_df["frame_processed_time"].diff() + processed_df.loc[processed_df["output_interval_s"] <= 0, "output_interval_s"] = np.nan + processed_df.loc[:, "output_time_bin"] = processed_df["output_time_relative"].astype(int) + + # Calculate jitter as the standard deviation of frame intervals in each time bin + input_jitter = np.full(len(time_range), np.nan) + output_jitter = np.full(len(time_range), np.nan) + + for t_idx, t_sec in enumerate(time_range): + # Input jitter - variation in input frame arrival times + if not input_df.empty: + intervals = input_df.loc[input_df["input_time_bin"] == t_sec, "input_interval_s"] + if len(intervals.dropna()) > 1: + std_dev = 
intervals.std() * 1000 # Convert to ms + input_jitter[t_idx] = std_dev + + # Output jitter - variation in processed frame completion times + if not processed_df.empty: + intervals = processed_df.loc[processed_df["output_time_bin"] == t_sec, "output_interval_s"] + if len(intervals.dropna()) > 1: + std_dev = intervals.std() * 1000 # Convert to ms + output_jitter[t_idx] = std_dev + + jitter_df = pd.DataFrame({ + "time_bin": time_range, + "input_jitter_ms": input_jitter, + "output_jitter_ms": output_jitter + }) - # Calculate intervals between consecutive frames - # Input interval: Time between a frame and the previous frame being received - df["input_interval_s"] = df["frame_received_time"].diff() - # Output interval: Time between a frame and the previous frame being processed - df["output_interval_s"] = df["frame_processed_time"].diff() - - # Handle potential zero or negative intervals (e.g., from duplicate timestamps or sorting issues if not by frame_id) - # These would lead to infinite or meaningless FPS, so set them to NaN. - df.loc[df["input_interval_s"] <= 0, "input_interval_s"] = np.nan - df.loc[df["output_interval_s"] <= 0, "output_interval_s"] = np.nan - - # Calculate FPS (Frames Per Second) - # FPS is the reciprocal of the interval in seconds. - df["input_fps"] = 1.0 / df["input_interval_s"] - df["output_fps"] = 1.0 / df["output_interval_s"] - - # Rolling statistics for smoothing - window = 30 # Increased window for smoother, more stable FPS and jitter - df["input_fps_smooth"] = df["input_fps"].rolling(window, min_periods=1).mean() - df["output_fps_smooth"] = df["output_fps"].rolling(window, min_periods=1).mean() - - # Jitter: Standard deviation of intervals (in milliseconds) - # This measures the variation in frame arrival/processing times. 
- df["input_jitter_ms"] = df["input_interval_s"].rolling(window, min_periods=1).std() * 1000 - df["output_jitter_ms"] = df["output_interval_s"].rolling(window, min_periods=1).std() * 1000 - - # Latency: Time taken from frame reception to frame processing (in milliseconds) - # Use pre-calculated 'latency_ms' if available and valid, otherwise recalculate. - if "latency_ms" not in df.columns or df["latency_ms"].isnull().all(): - print("Recalculating latency_ms as it's not found or all null in CSV.") - df["latency_ms"] = (df["frame_processed_time"] - df["frame_received_time"]) * 1000 + # Aggregate processing latency by time bin + if not processed_df.empty: + avg_latencies = processed_df.groupby("output_time_bin").agg({ + "processing_latency_ms": "mean" + }).reset_index() + + latency_df = pd.DataFrame({"time_bin": time_range}) + latency_df = pd.merge( + latency_df, + avg_latencies.rename(columns={"output_time_bin": "time_bin"}), + on="time_bin", + how="left" + ) else: - # Ensure it's numeric if it came from CSV - df["latency_ms"] = pd.to_numeric(df["latency_ms"], errors='coerce') - - df["latency_ms_smooth"] = df["latency_ms"].rolling(window, min_periods=1).mean() + latency_df = pd.DataFrame({ + "time_bin": time_range, + "processing_latency_ms": np.nan + }) + + return fps_df, jitter_df, latency_df, max_time - # Create visualization - reduced to 3 plots (removed frame distribution) +def plot_multiple_frame_metrics(csv_paths: List[str]): + """ + Plot metrics from multiple log files on the same charts + """ + if not csv_paths: + print("No CSV files provided.") + return + + # Process each log file + all_data = [] + max_time_overall = 0 + + for csv_path in csv_paths: + fps_df, jitter_df, latency_df, max_time = process_log_file(csv_path) + if fps_df is not None: + all_data.append({ + 'path': csv_path, + 'fps_df': fps_df, + 'jitter_df': jitter_df, + 'latency_df': latency_df, + 'max_time': max_time + }) + max_time_overall = max(max_time_overall, max_time) + + if not all_data: + 
print("No valid data found in any of the provided CSV files.") + return + + # Create visualization with 3 subplots fig, axs = plt.subplots(3, 1, figsize=(14, 12), sharex=True) - - # FPS plot - axs[0].plot(df["time_since_start"], df["input_fps_smooth"], label="Input FPS (smooth)", color="blue", linewidth=2) - axs[0].plot(df["time_since_start"], df["output_fps_smooth"], label="Output FPS (smooth)", color="green", linewidth=2) + + # Generate a list of distinct colors for multiple datasets + # Use a subset of tab colors for better distinction + tab_colors = list(mcolors.TABLEAU_COLORS.values()) + + for i, data in enumerate(all_data): + # Get colors for this dataset + input_color = tab_colors[i % len(tab_colors)] + output_color = tab_colors[(i + len(tab_colors)//2) % len(tab_colors)] + + # Extract the filename without path and extension for legend + file_label = os.path.splitext(os.path.basename(data['path']))[0] + + # 1. FPS Plot + axs[0].plot( + data['fps_df']["time_bin"], + data['fps_df']["input_fps_smooth"], + label=f"Input FPS - {file_label}", + color=input_color, + linewidth=2 + ) + axs[0].plot( + data['fps_df']["time_bin"], + data['fps_df']["output_fps_smooth"], + label=f"Output FPS - {file_label}", + color=output_color, + linewidth=2, + linestyle='--' + ) + + # 2. Jitter Plot + axs[1].plot( + data['jitter_df']["time_bin"], + data['jitter_df']["input_jitter_ms"], + label=f"Input Jitter - {file_label}", + color=input_color, + alpha=0.7 + ) + axs[1].plot( + data['jitter_df']["time_bin"], + data['jitter_df']["output_jitter_ms"], + label=f"Output Jitter - {file_label}", + color=output_color, + alpha=0.7, + linestyle='--' + ) + + # 3. 
Processing Latency Plot + axs[2].plot( + data['latency_df']["time_bin"], + data['latency_df']["processing_latency_ms"], + label=f"Processing Latency - {file_label}", + color=output_color, + alpha=0.7 + ) + + # Configure axes and labels axs[0].set_ylabel("FPS") axs[0].set_title("Frame Rate") axs[0].legend() axs[0].grid(True) - axs[0].set_xlim(left=0) # Start x-axis at 0 - - # Jitter plot - axs[1].plot(df["time_since_start"], df["input_jitter_ms"], label="Input Jitter (ms)", color="blue", alpha=0.7) - axs[1].plot(df["time_since_start"], df["output_jitter_ms"], label="Output Jitter (ms)", color="green", alpha=0.7) + axs[0].set_xlim(left=0, right=max_time_overall if max_time_overall > 0 else 1) + axs[0].set_ylim(bottom=0) + axs[1].set_ylabel("Jitter (ms)") - axs[1].set_title("Frame Timing Jitter (Rolling StdDev of Intervals)") + axs[1].set_title("Frame Timing Jitter") axs[1].legend() axs[1].grid(True) - axs[1].set_xlim(left=0) # Start x-axis at 0 - - # Latency plot - show per-frame latency rather than cumulative - axs[2].plot(df["time_since_start"], df["latency_ms"], label="Per-Frame Latency (ms)", color="red", alpha=0.5) - axs[2].plot(df["time_since_start"], df["latency_ms_smooth"], label="Smoothed Latency (ms)", color="darkred", linewidth=2) + axs[1].set_xlim(left=0, right=max_time_overall if max_time_overall > 0 else 1) + axs[2].set_ylabel("Latency (ms)") - axs[2].set_title("Frame Processing Latency") + axs[2].set_title("Processing Latency") axs[2].legend() axs[2].grid(True) - axs[2].set_xlim(left=0) # Start x-axis at 0 - - # Set y-axis limit to better visualize actual per-frame latency without accumulation - max_latency = df["latency_ms"].quantile(0.99) # Use 99th percentile to avoid outliers - axs[2].set_ylim(0, max_latency * 1.1) # Add 10% margin + axs[2].set_xlim(left=0, right=max_time_overall if max_time_overall > 0 else 1) - # Add x-axis label and make it visible + # Add x-axis label fig.text(0.5, 0.04, 'Time (seconds)', ha='center', va='center', fontsize=12) 
plt.tight_layout() - plt.subplots_adjust(bottom=0.07) # Add space for x-axis label - plt.savefig(csv_path.replace('.csv', '.png')) + plt.subplots_adjust(bottom=0.07) + + # Save combined plot + output_filename = "combined_frame_logs.png" + plt.savefig(output_filename) + print(f"Combined plot saved as {output_filename}") plt.show() +def plot_frame_metrics(csv_path: str = "frame_logs.csv"): + """ + Plot metrics from a single log file for backward compatibility + """ + # For single files, just call the multiple processing function with a list of one item + plot_multiple_frame_metrics([csv_path]) + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Plot frame timing metrics from CSV logs.") parser.add_argument( "--frame-logs", type=str, default="frame_logs.csv", - help="Path to the frame timing CSV log file (default: frame_logs.csv)" + help="Path to the frame timing CSV log file(s) (comma-separated for multiple files)" ) args = parser.parse_args() - plot_frame_metrics(args.frame_logs) + + # Check if multiple files are specified + csv_paths = [path.strip() for path in args.frame_logs.split(',')] + + if len(csv_paths) > 1: + plot_multiple_frame_metrics(csv_paths) + else: + plot_frame_metrics(csv_paths[0]) diff --git a/src/comfystream/pipeline_api.py b/src/comfystream/pipeline_api.py index e9387e2d..a479744a 100644 --- a/src/comfystream/pipeline_api.py +++ b/src/comfystream/pipeline_api.py @@ -105,7 +105,7 @@ def __init__( self.output_interval = 1/30 # Start with 30 FPS self.last_output_time = None self.frame_interval_history = collections.deque(maxlen=30) - self.output_pacer_task = asyncio.create_task(self._dynamic_output_pacer()) + # self.output_pacer_task = asyncio.create_task(self._dynamic_output_pacer()) self.comfyui_log_level = comfyui_log_level @@ -312,9 +312,21 @@ async def put_video_frame(self, frame: av.VideoFrame): frame.side_data.input = self.video_preprocess(frame) frame.side_data.frame_id = frame_id frame.side_data.skipped = False - 
frame.side_data.frame_received_time = time.time() + + # Set receive time + frame.side_data.frame_received_time = current_time frame.side_data.client_index = client_index - + + # Log frame at input time to properly track input FPS + log_frame_timing( + frame_id=frame_id, + frame_received_time=current_time, + frame_process_start_time=None, + frame_processed_time=None, + client_index=client_index, + csv_path="frame_logs.csv" + ) + # Send frame to the selected client self.clients[client_index].put_video_input(frame) await self.video_incoming_frames.put(frame) @@ -363,23 +375,42 @@ def audio_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.Audio async def get_processed_video_frame(self): try: frame = await self.video_incoming_frames.get() - + + # Set process start time just before processing + frame.side_data.frame_process_start_time = time.time() + # Get the processed frame from our output queue processed_frame_id, out_tensor = await self.processed_video_frames.get() - - # Process the frame + + # if (processed_frame_id != frame.side_data.frame_id): + # logger.warning(f"Processed frame ID {processed_frame_id} does not match expected frame ID {frame.side_data.frame_id}") + + # The processed frame and the video_incoming_frame is never the same + ''' + Processed frame ID 45 does not match expected frame ID 6 + Processed frame ID 47 does not match expected frame ID 7 + Processed frame ID 49 does not match expected frame ID 8 + ''' + # What does this mean? 
+ + # Record the time when processing is complete + frame_processed_time = time.time() + + # Process the frame (post-processing) processed_frame = self.video_postprocess(out_tensor) processed_frame.pts = frame.pts processed_frame.time_base = frame.time_base - # Log frame timing asynchronously + # Log frame timing with simplified metrics log_frame_timing( frame_id=processed_frame_id, frame_received_time=frame.side_data.frame_received_time, - frame_processed_time=time.time(), + frame_process_start_time=frame.side_data.frame_process_start_time, + frame_processed_time=frame_processed_time, client_index=frame.side_data.client_index, + csv_path="frame_logs.csv" ) - + return processed_frame except Exception as e: @@ -497,6 +528,8 @@ async def cleanup(self): logger.info("Pipeline cleanup completed, clients will be reinitialized on next connection") + # This may not be needed anymore - more work is req to balance frame timing + ''' async def _dynamic_output_pacer(self): while self.running: # Only release if the next expected frame is available @@ -523,6 +556,7 @@ async def _dynamic_output_pacer(self): else: # No frame ready, wait a bit and check again await asyncio.sleep(0.005) + ''' async def start_clients(self): """Start the clients based on the client_mode (TOML or spawn)""" From 262ebb68028856a4ddcb8fa46769ae45a124db44 Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Wed, 7 May 2025 23:02:22 -0400 Subject: [PATCH 26/42] Added frame_log_file as argument to select logging file, moved logging to task queue, fixed issues. --- server/app_api.py | 35 +++++++++++---- src/comfystream/pipeline_api.py | 76 ++++++++++++++++++++++++--------- 2 files changed, 82 insertions(+), 29 deletions(-) diff --git a/server/app_api.py b/server/app_api.py index 64666b3d..708241eb 100644 --- a/server/app_api.py +++ b/server/app_api.py @@ -10,6 +10,7 @@ # Initialize CUDA before any other imports to prevent core dump. 
if torch.cuda.is_available(): torch.cuda.init() + torch.cuda.empty_cache() from aiohttp import web from aiortc import ( @@ -405,11 +406,9 @@ async def on_startup(app: web.Application): cuda_devices=app["cuda_devices"], workers_start_port=app.get("workers_start_port", 8195), comfyui_log_level=app.get("comfyui_log_level", None), + frame_log_file=app.get("frame_log_file", None), ) - if (app.get("client_mode") == "spawn" and app.get("comfyui_log_level") is None): - print("To see spawned ComfyUI logs, add --comfyui_log_level=DEBUG") - # Start the clients during initialization # await app["pipeline"].start_clients() @@ -440,8 +439,7 @@ async def on_shutdown(app: web.Application): "--workspace", default=None, required=True, help="Set Comfy workspace" ) parser.add_argument( - "--log-level", "--log_level", - dest="log_level", + "--log-level", default="WARNING", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], help="Set the logging level", @@ -506,6 +504,12 @@ async def on_shutdown(app: web.Application): default=8195, help="Starting port number for worker processes" ) + parser.add_argument( + "--frame-log-file", + type=str, + default=None, + help="Filename for frame timing log (optional)" + ) args = parser.parse_args() logging.basicConfig( @@ -514,10 +518,6 @@ async def on_shutdown(app: web.Application): datefmt="%H:%M:%S", ) - # Set logger level based on command line arguments - print(f"Setting log level to {args.log_level.upper()}") - logger.setLevel(getattr(logging, args.log_level.upper())) - app = web.Application() app["media_ports"] = args.media_ports.split(",") if args.media_ports else None app["workspace"] = args.workspace @@ -527,6 +527,7 @@ async def on_shutdown(app: web.Application): app["workers"] = args.workers app["cuda_devices"] = args.cuda_devices app["workers_start_port"] = args.workers_start_port + app["frame_log_file"] = args.frame_log_file app.on_startup.append(on_startup) app.on_shutdown.append(on_shutdown) @@ -573,4 +574,20 @@ def 
force_print(*args, **kwargs): if args.comfyui_inference_log_level: app["comfyui_inference_log_level"] = args.comfyui_inference_log_level + print("\n\nComfystream Options:") + + print(f"Client Mode: {app.get('client_mode')}") + print(f"Log Level: {args.log_level.upper()}") + if (app.get("client_mode") == "spawn" and app.get("comfyui_log_level") is None): + print("To see spawned ComfyUI logs, add --comfyui_log_level=DEBUG") + else: + print(f"ComfyUI Log Level: {app.get('comfyui_log_level')}") + if (app.get("frame_log_file") is None): + print("To set a frame log file, add --frame_log_file=filename.csv") + else: + print(f"Frame Log File: {app.get('frame_log_file')}") + print("\n\n") + + logger.setLevel(getattr(logging, args.log_level.upper())) + web.run_app(app, host=args.host, port=int(args.port), print=force_print) diff --git a/src/comfystream/pipeline_api.py b/src/comfystream/pipeline_api.py index a479744a..b6a0deee 100644 --- a/src/comfystream/pipeline_api.py +++ b/src/comfystream/pipeline_api.py @@ -32,6 +32,7 @@ def __init__( cuda_devices: str = '0', workers_start_port: int = 8195, comfyui_log_level: str = None, + frame_log_file: Optional[str] = None, ): """Initialize the pipeline with the given configuration. Args: @@ -46,8 +47,9 @@ def __init__( workers_start_port: The starting port number for worker processes (default: 8195). cuda_devices: The list of CUDA devices to use for the ComfyUI clients. comfyui_log_level: The logging level for ComfyUI + frame_log_file: The filename for the frame timing log (optional). """ - + # There are two methods for starting the clients: # 1. client_mode == "toml" -> Use a config file to describe clients. # 2. client_mode == "spawn" -> Spawn ComfyUI clients as external processes. 
@@ -108,7 +110,15 @@ def __init__( # self.output_pacer_task = asyncio.create_task(self._dynamic_output_pacer()) self.comfyui_log_level = comfyui_log_level - + + # Add a queue for frame log entries + self.frame_log_file = frame_log_file + self.frame_log_queue = None # Initialize to None by default + + if self.frame_log_file: + self.frame_log_queue = asyncio.Queue() + self.frame_logger_task = asyncio.create_task(self._process_frame_logs()) + async def _collect_processed_frames(self): """Background task to collect processed frames from all clients""" try: @@ -235,7 +245,7 @@ async def _warm_client_video(self, client, client_index, dummy_frame): logger.info(f"Client {client_index} warmup iteration {i+1}/{WARMUP_RUNS}") client.put_video_input(dummy_frame) try: - await asyncio.wait_for(client.get_video_output(), timeout=5.0) + await asyncio.wait_for(client.get_video_output(), timeout=30) except asyncio.TimeoutError: logger.warning(f"Timeout waiting for warmup frame from client {client_index}") except Exception as e: @@ -318,14 +328,15 @@ async def put_video_frame(self, frame: av.VideoFrame): frame.side_data.client_index = client_index # Log frame at input time to properly track input FPS - log_frame_timing( - frame_id=frame_id, - frame_received_time=current_time, - frame_process_start_time=None, - frame_processed_time=None, - client_index=client_index, - csv_path="frame_logs.csv" - ) + if self.frame_log_file: + await self.frame_log_queue.put({ + 'frame_id': frame_id, + 'frame_received_time': frame.side_data.frame_received_time, + 'frame_process_start_time': None, + 'frame_processed_time': None, + 'client_index': frame.side_data.client_index, + 'csv_path': self.frame_log_file + }) # Send frame to the selected client self.clients[client_index].put_video_input(frame) @@ -377,7 +388,7 @@ async def get_processed_video_frame(self): frame = await self.video_incoming_frames.get() # Set process start time just before processing - frame.side_data.frame_process_start_time = 
time.time() + frame_process_start_time = time.time() # Get the processed frame from our output queue processed_frame_id, out_tensor = await self.processed_video_frames.get() @@ -402,14 +413,15 @@ async def get_processed_video_frame(self): processed_frame.time_base = frame.time_base # Log frame timing with simplified metrics - log_frame_timing( - frame_id=processed_frame_id, - frame_received_time=frame.side_data.frame_received_time, - frame_process_start_time=frame.side_data.frame_process_start_time, - frame_processed_time=frame_processed_time, - client_index=frame.side_data.client_index, - csv_path="frame_logs.csv" - ) + if self.frame_log_file: + await self.frame_log_queue.put({ + 'frame_id': processed_frame_id, + 'frame_received_time': frame.side_data.frame_received_time, + 'frame_process_start_time': frame_process_start_time, + 'frame_processed_time': frame_processed_time, + 'client_index': frame.side_data.client_index, + 'csv_path': self.frame_log_file + }) return processed_frame @@ -526,6 +538,14 @@ async def cleanup(self): # Reset output counters self.output_counter = 0 + # Cancel frame logger task if it exists + if hasattr(self, 'frame_logger_task') and self.frame_logger_task: + self.frame_logger_task.cancel() + try: + await self.frame_logger_task + except asyncio.CancelledError: + pass + logger.info("Pipeline cleanup completed, clients will be reinitialized on next connection") # This may not be needed anymore - more work is req to balance frame timing @@ -658,5 +678,21 @@ async def start_clients(self): self.clients = [] return None + async def _process_frame_logs(self): + """Background task to process frame logs from queue""" + while self.running: + try: + # Get log entry from queue + log_entry = await self.frame_log_queue.get() + + log_frame_timing(**log_entry) + + # Mark task as done + self.frame_log_queue.task_done() + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Error in frame logging: {e}") + # For backwards 
compatibility, maintain the original Pipeline name Pipeline = MultiServerPipeline \ No newline at end of file From 1304fd1a748a7918b949a9280b0674ea14aff985 Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Wed, 7 May 2025 23:03:27 -0400 Subject: [PATCH 27/42] Added frame file logging to embedded client. --- server/app.py | 10 ++++++ src/comfystream/pipeline.py | 72 +++++++++++++++++++++++++++++++++++-- 2 files changed, 80 insertions(+), 2 deletions(-) diff --git a/server/app.py b/server/app.py index a52f8061..ee5512ca 100644 --- a/server/app.py +++ b/server/app.py @@ -9,6 +9,7 @@ # Initialize CUDA before any other imports to prevent core dump. if torch.cuda.is_available(): torch.cuda.init() + torch.cuda.empty_cache() from aiohttp import web from aiortc import ( @@ -375,6 +376,7 @@ async def on_startup(app: web.Application): gpu_only=True, preview_method='none', comfyui_inference_log_level=app.get("comfui_inference_log_level", None), + frame_log_file=app.get("frame_log_file", None), ) app["pcs"] = set() app["video_tracks"] = {} @@ -427,6 +429,12 @@ async def on_shutdown(app: web.Application): choices=logging._nameToLevel.keys(), help="Set the logging level for ComfyUI inference", ) + parser.add_argument( + "--frame-log-file", + type=str, + default=None, + help="Filename for frame timing log (optional)" + ) args = parser.parse_args() logging.basicConfig( @@ -438,6 +446,8 @@ async def on_shutdown(app: web.Application): app = web.Application() app["media_ports"] = args.media_ports.split(",") if args.media_ports else None app["workspace"] = args.workspace + app["frame_log_file"] = args.frame_log_file + app.on_startup.append(on_startup) app.on_shutdown.append(on_shutdown) diff --git a/src/comfystream/pipeline.py b/src/comfystream/pipeline.py index a5776dfc..e4c79bf8 100644 --- a/src/comfystream/pipeline.py +++ b/src/comfystream/pipeline.py @@ -3,10 +3,12 @@ import numpy as np import asyncio import logging +import time from typing import Any, Dict, Union, List, Optional 
from comfystream.client import ComfyStreamClient from comfystream.server.utils import temporary_log_level +from comfystream.frame_logging import log_frame_timing WARMUP_RUNS = 5 @@ -22,7 +24,7 @@ class Pipeline: """ def __init__(self, width: int = 512, height: int = 512, - comfyui_inference_log_level: Optional[int] = None, **kwargs): + comfyui_inference_log_level: Optional[int] = None, frame_log_file: Optional[str] = None, **kwargs): """Initialize the pipeline with the given configuration. Args: @@ -43,6 +45,16 @@ def __init__(self, width: int = 512, height: int = 512, self._comfyui_inference_log_level = comfyui_inference_log_level + # Add a queue for frame log entries + self.running = True + self.next_expected_frame_id = 0 + self.frame_log_file = frame_log_file + self.frame_log_queue = None # Initialize to None by default + + if self.frame_log_file: + self.frame_log_queue = asyncio.Queue() + self.frame_logger_task = asyncio.create_task(self._process_frame_logs()) + async def warm_video(self): """Warm up the video processing pipeline with dummy frames.""" # Create dummy frame with the CURRENT resolution settings @@ -93,8 +105,25 @@ async def put_video_frame(self, frame: av.VideoFrame): Args: frame: The video frame to process """ + current_time = time.time() frame.side_data.input = self.video_preprocess(frame) frame.side_data.skipped = True + frame.side_data.frame_received_time = current_time + frame.side_data.frame_id = self.next_expected_frame_id + frame.side_data.client_index = -1 + self.next_expected_frame_id += 1 + + # Log frame at input time to properly track input FPS + if self.frame_log_file: + await self.frame_log_queue.put({ + 'frame_id': frame.side_data.frame_id, + 'frame_received_time': frame.side_data.frame_received_time, + 'frame_process_start_time': None, + 'frame_processed_time': None, + 'client_index': frame.side_data.client_index, + 'csv_path': self.frame_log_file + }) + self.client.put_video_input(frame) await self.video_incoming_frames.put(frame) 
@@ -163,6 +192,8 @@ async def get_processed_video_frame(self) -> av.VideoFrame: Returns: The processed video frame """ + frame_process_start_time = time.time() + async with temporary_log_level("comfy", self._comfyui_inference_log_level): out_tensor = await self.client.get_video_output() frame = await self.video_incoming_frames.get() @@ -172,6 +203,19 @@ async def get_processed_video_frame(self) -> av.VideoFrame: processed_frame = self.video_postprocess(out_tensor) processed_frame.pts = frame.pts processed_frame.time_base = frame.time_base + + frame_processed_time = time.time() + + # Log frame timing with simplified metrics + if self.frame_log_file: + await self.frame_log_queue.put({ + 'frame_id': frame.side_data.frame_id, + 'frame_received_time': frame.side_data.frame_received_time, + 'frame_process_start_time': frame_process_start_time, + 'frame_processed_time': frame_processed_time, + 'client_index': frame.side_data.client_index, + 'csv_path': self.frame_log_file + }) return processed_frame @@ -207,4 +251,28 @@ async def get_nodes_info(self) -> Dict[str, Any]: async def cleanup(self): """Clean up resources used by the pipeline.""" - await self.client.cleanup() \ No newline at end of file + + # Cancel frame logger task if it exists + if hasattr(self, 'frame_logger_task') and self.frame_logger_task: + self.frame_logger_task.cancel() + try: + await self.frame_logger_task + except asyncio.CancelledError: + pass + + await self.client.cleanup() + + async def _process_frame_logs(self): + """Background task to process frame logs from queue""" + while self.running: + try: + # Get log entry from queue + log_entry = await self.frame_log_queue.get() + log_frame_timing(**log_entry) + + # Mark task as done + self.frame_log_queue.task_done() + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Error in frame logging: {e}") \ No newline at end of file From 25397411db144b10ba969abe3fc2966772a83fb6 Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: 
Tue, 20 May 2025 19:21:49 -0400 Subject: [PATCH 28/42] First version of multi working. --- server/app_multi.py | 503 ++++++++++++++++++++++++++ src/comfystream/client_multi.py | 435 ++++++++++++++++++++++ src/comfystream/frame_proxy.py | 27 ++ src/comfystream/pipeline_multi.py | 294 +++++++++++++++ src/comfystream/tensor_cache_multi.py | 87 +++++ 5 files changed, 1346 insertions(+) create mode 100644 server/app_multi.py create mode 100644 src/comfystream/client_multi.py create mode 100644 src/comfystream/frame_proxy.py create mode 100644 src/comfystream/pipeline_multi.py create mode 100644 src/comfystream/tensor_cache_multi.py diff --git a/server/app_multi.py b/server/app_multi.py new file mode 100644 index 00000000..9ea4d8f1 --- /dev/null +++ b/server/app_multi.py @@ -0,0 +1,503 @@ +import argparse +import asyncio +import json +import logging +import os +import sys +import torch + +# Initialize CUDA before any other imports to prevent core dump. +if torch.cuda.is_available(): + torch.cuda.init() + torch.cuda.empty_cache() + +from aiohttp import web +from aiortc import ( + MediaStreamTrack, + RTCConfiguration, + RTCIceServer, + RTCPeerConnection, + RTCSessionDescription, +) +from aiortc.codecs import h264 +from aiortc.rtcrtpsender import RTCRtpSender +from comfystream.pipeline_multi import Pipeline +from twilio.rest import Client +from comfystream.server.utils import patch_loop_datagram, add_prefix_to_app_routes, FPSMeter +from comfystream.server.metrics import MetricsManager, StreamStatsManager +import time + +logger = logging.getLogger(__name__) +logging.getLogger("aiortc.rtcrtpsender").setLevel(logging.WARNING) +logging.getLogger("aiortc.rtcrtpreceiver").setLevel(logging.WARNING) + + +MAX_BITRATE = 2000000 +MIN_BITRATE = 2000000 + + +class VideoStreamTrack(MediaStreamTrack): + """video stream track that processes video frames using a pipeline. + + Attributes: + kind (str): The kind of media, which is "video" for this class. 
+ track (MediaStreamTrack): The underlying media stream track. + pipeline (Pipeline): The processing pipeline to apply to each video frame. + """ + + kind = "video" + + def __init__(self, track: MediaStreamTrack, pipeline: Pipeline): + """Initialize the VideoStreamTrack. + + Args: + track: The underlying media stream track. + pipeline: The processing pipeline to apply to each video frame. + """ + super().__init__() + self.track = track + self.pipeline = pipeline + self.fps_meter = FPSMeter( + metrics_manager=app["metrics_manager"], track_id=track.id + ) + self.running = True + self.collect_task = asyncio.create_task(self.collect_frames()) + + # Add cleanup when track ends + @track.on("ended") + async def on_ended(): + logger.info("[App] Source video track ended, stopping collection") + await cancel_collect_frames(self) + + async def collect_frames(self): + """Collect video frames from the underlying track and pass them to + the processing pipeline. Stops when track ends or connection closes. + """ + try: + while self.running: + try: + frame = await self.track.recv() + await self.pipeline.put_video_frame(frame) + except asyncio.CancelledError: + logger.info("[App] Frame collection cancelled") + break + except Exception as e: + if "MediaStreamError" in str(type(e)): + logger.info("[App] Media stream ended") + else: + logger.error(f"[App] Error collecting video frames: {str(e)}") + self.running = False + break + + # Perform cleanup outside the exception handler + logger.info("[App] Video frame collection stopped") + except asyncio.CancelledError: + logger.info("[App] Frame collection task cancelled") + except Exception as e: + logger.error(f"[App] Unexpected error in frame collection: {str(e)}") + finally: + await self.pipeline.cleanup() + + async def recv(self): + """Receive a processed video frame from the pipeline, increment the frame + count for FPS calculation and return the processed frame to the client. 
+ """ + processed_frame = await self.pipeline.get_processed_video_frame() + + # Increment the frame count to calculate FPS. + await self.fps_meter.increment_frame_count() + + return processed_frame + + +class AudioStreamTrack(MediaStreamTrack): + kind = "audio" + + def __init__(self, track: MediaStreamTrack, pipeline): + super().__init__() + self.track = track + self.pipeline = pipeline + self.running = True + self.collect_task = asyncio.create_task(self.collect_frames()) + + # Add cleanup when track ends + @track.on("ended") + async def on_ended(): + logger.info("Source audio track ended, stopping collection") + await cancel_collect_frames(self) + + async def collect_frames(self): + """Collect audio frames from the underlying track and pass them to + the processing pipeline. Stops when track ends or connection closes. + """ + try: + while self.running: + try: + frame = await self.track.recv() + await self.pipeline.put_audio_frame(frame) + except asyncio.CancelledError: + logger.info("[App] Audio frame collection cancelled") + break + except Exception as e: + if "MediaStreamError" in str(type(e)): + logger.info("[App] Media stream ended") + else: + logger.error(f"[App] Error collecting audio frames: {str(e)}") + self.running = False + break + + # Perform cleanup outside the exception handler + logger.info("[App] Audio frame collection stopped") + except asyncio.CancelledError: + logger.info("[App] Frame collection task cancelled") + except Exception as e: + logger.error(f"[App] Unexpected error in audio frame collection: {str(e)}") + finally: + await self.pipeline.cleanup() + + async def recv(self): + return await self.pipeline.get_processed_audio_frame() + + +def force_codec(pc, sender, forced_codec): + kind = forced_codec.split("/")[0] + codecs = RTCRtpSender.getCapabilities(kind).codecs + transceiver = next(t for t in pc.getTransceivers() if t.sender == sender) + codecPrefs = [codec for codec in codecs if codec.mimeType == forced_codec] + 
def get_twilio_token():
    """Fetch a Twilio network-traversal token.

    Returns:
        The Twilio token object, or ``None`` when TWILIO_ACCOUNT_SID /
        TWILIO_AUTH_TOKEN are not both present in the environment.
    """
    sid = os.getenv("TWILIO_ACCOUNT_SID")
    secret = os.getenv("TWILIO_AUTH_TOKEN")

    if sid is None or secret is None:
        return None

    return Client(sid, secret).tokens.create()


def get_ice_servers():
    """Build the list of TURN ICE servers advertised by Twilio.

    Returns:
        A list of ``RTCIceServer`` objects; empty when no Twilio
        credentials are configured.
    """
    servers = []

    token = get_twilio_token()
    if token is not None:
        # Keep only TURN entries; STUN is covered by aiortc defaults.
        for entry in token.ice_servers:
            if entry["url"].startswith("turn:"):
                servers.append(
                    RTCIceServer(
                        urls=[entry["urls"]],
                        credential=entry["credential"],
                        username=entry["username"],
                    )
                )

    return servers
json.loads(message) + + if params.get("type") == "get_nodes": + nodes_info = await pipeline.get_nodes_info() + response = {"type": "nodes_info", "nodes": nodes_info} + channel.send(json.dumps(response)) + elif params.get("type") == "update_prompts": + if "prompts" not in params: + logger.warning( + "[Control] Missing prompt in update_prompt message" + ) + return + await pipeline.update_prompts(params["prompts"]) + response = {"type": "prompts_updated", "success": True} + channel.send(json.dumps(response)) + elif params.get("type") == "update_resolution": + if "width" not in params or "height" not in params: + logger.warning("[Control] Missing width or height in update_resolution message") + return + # Update pipeline resolution for future frames + pipeline.width = params["width"] + pipeline.height = params["height"] + logger.info(f"[Control] Updated resolution to {params['width']}x{params['height']}") + + # Mark that we've received resolution + resolution_received["value"] = True + + # Warm the video pipeline with the new resolution + if "m=video" in pc.remoteDescription.sdp: + await pipeline.warm_video() + + response = { + "type": "resolution_updated", + "success": True + } + channel.send(json.dumps(response)) + else: + logger.warning( + "[App] Invalid message format - missing required fields" + ) + except json.JSONDecodeError: + logger.error("[App] Invalid JSON received") + except Exception as e: + logger.error(f"[App] Error processing message: {str(e)}") + + @pc.on("track") + def on_track(track): + logger.info(f"[App] Track received: {track.kind}") + if track.kind == "video": + videoTrack = VideoStreamTrack(track, pipeline) + tracks["video"] = videoTrack + sender = pc.addTrack(videoTrack) + + # Store video track in app for stats. 
async def cancel_collect_frames(track):
    """Stop a stream track's background frame-collection task and await it.

    Args:
        track: VideoStreamTrack or AudioStreamTrack wrapper whose
            ``collect_task`` should be cancelled.
    """
    track.running = False
    # Bug fix: the original condition was ``hasattr(...) is not None``,
    # which is always True because hasattr() returns a bool — a track
    # without a collect_task attribute then raised AttributeError.
    task = getattr(track, "collect_task", None)
    if task is not None and not task.done():
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass


async def set_prompt(request):
    """HTTP handler: replace the pipeline prompts with the request body."""
    pipeline = request.app["pipeline"]

    prompt = await request.json()
    await pipeline.set_prompts(prompt)

    return web.Response(content_type="application/json", text="OK")


def health(_):
    """Liveness endpoint used by load balancers and the hosted platform."""
    return web.Response(content_type="application/json", text="OK")
gpu_only=True, + preview_method='none', + max_workers=app["workers"], + comfyui_inference_log_level=app.get("comfui_inference_log_level", None), + frame_log_file=app.get("frame_log_file", None), + ) + app["pcs"] = set() + app["video_tracks"] = {} + + +async def on_shutdown(app: web.Application): + pcs = app["pcs"] + coros = [pc.close() for pc in pcs] + await asyncio.gather(*coros) + pcs.clear() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run comfystream server") + parser.add_argument("--port", default=8889, help="Set the signaling port") + parser.add_argument( + "--media-ports", default=None, help="Set the UDP ports for WebRTC media" + ) + parser.add_argument("--host", default="127.0.0.1", help="Set the host") + parser.add_argument( + "--workspace", default=None, required=True, help="Set Comfy workspace" + ) + parser.add_argument( + "--log-level", + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + help="Set the logging level", + ) + parser.add_argument( + "--monitor", + default=False, + action="store_true", + help="Start a Prometheus metrics endpoint for monitoring.", + ) + parser.add_argument( + "--stream-id-label", + default=False, + action="store_true", + help="Include stream ID as a label in Prometheus metrics.", + ) + parser.add_argument( + "--comfyui-log-level", + default=None, + choices=logging._nameToLevel.keys(), + help="Set the global logging level for ComfyUI", + ) + parser.add_argument( + "--comfyui-inference-log-level", + default=None, + choices=logging._nameToLevel.keys(), + help="Set the logging level for ComfyUI inference", + ) + parser.add_argument( + "--frame-log-file", + type=str, + default=None, + help="Filename for frame timing log (optional)" + ) + parser.add_argument( + "--workers", + type=int, + default=1, + help="Number of workers to run", + ) + args = parser.parse_args() + + logging.basicConfig( + level=args.log_level.upper(), + format="%(asctime)s [%(levelname)s] 
%(message)s", + datefmt="%H:%M:%S", + ) + + app = web.Application() + app["media_ports"] = args.media_ports.split(",") if args.media_ports else None + app["workspace"] = args.workspace + app["frame_log_file"] = args.frame_log_file + app["workers"] = args.workers + + app.on_startup.append(on_startup) + app.on_shutdown.append(on_shutdown) + + app.router.add_get("/", health) + app.router.add_get("/health", health) + + # WebRTC signalling and control routes. + app.router.add_post("/offer", offer) + app.router.add_post("/prompt", set_prompt) + + # Add routes for getting stream statistics. + stream_stats_manager = StreamStatsManager(app) + app.router.add_get( + "/streams/stats", stream_stats_manager.collect_all_stream_metrics + ) + app.router.add_get( + "/stream/{stream_id}/stats", stream_stats_manager.collect_stream_metrics_by_id + ) + + # Add Prometheus metrics endpoint. + app["metrics_manager"] = MetricsManager(include_stream_id=args.stream_id_label) + if args.monitor: + app["metrics_manager"].enable() + logger.info( + f"Monitoring enabled - Prometheus metrics available at: " + f"http://{args.host}:{args.port}/metrics" + ) + app.router.add_get("/metrics", app["metrics_manager"].metrics_handler) + + # Add hosted platform route prefix. + # NOTE: This ensures that the local and hosted experiences have consistent routes. + add_prefix_to_app_routes(app, "/live") + + def force_print(*args, **kwargs): + print(*args, **kwargs, flush=True) + sys.stdout.flush() + + # Allow overriding of ComyfUI log levels. 
+ if args.comfyui_log_level: + log_level = logging._nameToLevel.get(args.comfyui_log_level.upper()) + logging.getLogger("comfy").setLevel(log_level) + if args.comfyui_inference_log_level: + app["comfui_inference_log_level"] = args.comfyui_inference_log_level + + web.run_app(app, host=args.host, port=int(args.port), print=force_print) diff --git a/src/comfystream/client_multi.py b/src/comfystream/client_multi.py new file mode 100644 index 00000000..707751fa --- /dev/null +++ b/src/comfystream/client_multi.py @@ -0,0 +1,435 @@ +import asyncio +import logging +from typing import List, Union +import multiprocessing as mp +import os +import sys +import numpy as np +import torch +import av + +from comfystream.utils import convert_prompt +from comfystream.tensor_cache_multi import init_tensor_cache + +from comfy.cli_args_types import Configuration +from comfy.distributed.executors import ProcessPoolExecutor # Use ComfyUI's executor +from comfy.api.components.schema.prompt import PromptDictInput +from comfy.client.embedded_comfy_client import EmbeddedComfyClient +from comfystream.frame_proxy import FrameProxy + +logger = logging.getLogger(__name__) + +def _test_worker_init(): + """Test function to verify worker process initialization.""" + return os.getpid() + +class ComfyStreamClient: + def __init__(self, max_workers: int = 1, executor_type: str = "process", **kwargs): + logger.info(f"[ComfyStreamClient] Main Process ID: {os.getpid()}") + logger.info("[ComfyStreamClient] __init__ start, max_workers:", max_workers, "executor_type:", executor_type) + + # Store default dimensions + self.width = kwargs.get('width', 512) + self.height = kwargs.get('height', 512) + + # Ensure workspace path is absolute + if 'cwd' in kwargs and not os.path.isabs(kwargs['cwd']): + kwargs['cwd'] = os.path.abspath(kwargs['cwd']) + logger.info(f"[ComfyStreamClient] Converted workspace path to absolute: {kwargs['cwd']}") + + logger.info("[ComfyStreamClient] Config kwargs:", kwargs) + + try: + 
self.config = Configuration(**kwargs) + print("[ComfyStreamClient] Configuration created") + + if executor_type == "process": + logger.info("[ComfyStreamClient] Initializing process executor") + ctx = mp.get_context("spawn") + logger.info(f"[ComfyStreamClient] Using multiprocessing context: {ctx.get_start_method()}") + + manager = ctx.Manager() + logger.info("[ComfyStreamClient] Created multiprocessing context and manager") + + # Create queues with a reasonable size limit to prevent memory issues + self.image_inputs = manager.Queue(maxsize=30) + self.image_outputs = manager.Queue(maxsize=30) + self.audio_inputs = manager.Queue(maxsize=10) + self.audio_outputs = manager.Queue(maxsize=10) + logger.info("[ComfyStreamClient] Created manager queues") + + logger.info("[ComfyStreamClient] About to create ProcessPoolExecutor...") + try: + # Create executor first + executor = ProcessPoolExecutor( + max_workers=max_workers, + initializer=init_tensor_cache, + initargs=(self.image_inputs, self.image_outputs, self.audio_inputs, self.audio_outputs) + ) + logger.info("[ComfyStreamClient] ProcessPoolExecutor created successfully") + + # Create EmbeddedComfyClient with the executor + logger.info("[ComfyStreamClient] Creating EmbeddedComfyClient with executor") + self.comfy_client = EmbeddedComfyClient(self.config, executor=executor) + logger.info("[ComfyStreamClient] EmbeddedComfyClient created successfully") + + # Submit a test task to ensure worker processes are initialized + logger.info("[ComfyStreamClient] Testing worker process initialization...") + test_future = executor.submit(_test_worker_init) # Use the named function instead of lambda + try: + worker_pid = test_future.result(timeout=30) # 30 second timeout + logger.info(f"[ComfyStreamClient] Worker process initialized successfully (PID: {worker_pid})") + except Exception as e: + logger.info(f"[ComfyStreamClient] Error initializing worker process: {str(e)}") + raise + + except Exception as e: + 
logger.info(f"[ComfyStreamClient] Error during initialization: {str(e)}") + logger.info(f"[ComfyStreamClient] Error type: {type(e)}") + import traceback + logger.info(f"[ComfyStreamClient] Error traceback: {traceback.format_exc()}") + raise + + else: + logger.info("[ComfyStreamClient] Using default executor") + logger.info("[ComfyStreamClient] Creating EmbeddedComfyClient in main process") + self.comfy_client = EmbeddedComfyClient(self.config) + logger.info("[ComfyStreamClient] EmbeddedComfyClient created in main process") + + self.running_prompts = {} + self.current_prompts = [] + self.cleanup_lock = asyncio.Lock() + logger.info("[ComfyStreamClient] __init__ complete") + + except Exception as e: + logger.info(f"[ComfyStreamClient] Error during initialization: {str(e)}") + logger.info(f"[ComfyStreamClient] Error type: {type(e)}") + import traceback + logger.info(f"[ComfyStreamClient] Error traceback: {traceback.format_exc()}") + raise + + async def set_prompts(self, prompts: List[PromptDictInput]): + logger.info("set_prompts start") + self.current_prompts = [convert_prompt(prompt) for prompt in prompts] + for idx in range(len(self.current_prompts)): + logger.info(f"Scheduling run_prompt for idx {idx}") + task = asyncio.create_task(self.run_prompt(idx)) + self.running_prompts[idx] = task + logger.info("set_prompts end") + + async def update_prompts(self, prompts: List[PromptDictInput]): + # TODO: currently under the assumption that only already running prompts are updated + if len(prompts) != len(self.current_prompts): + raise ValueError( + "Number of updated prompts must match the number of currently running prompts." 
+ ) + self.current_prompts = [convert_prompt(prompt) for prompt in prompts] + + async def run_prompt(self, prompt_index: int): + logger.info(f"[ComfyStreamClient] Starting run_prompt for index {prompt_index}") + while True: + try: + logger.debug(f"[ComfyStreamClient] Queueing prompt {prompt_index}") + await self.comfy_client.queue_prompt(self.current_prompts[prompt_index]) + logger.debug(f"[ComfyStreamClient] Prompt {prompt_index} queued successfully") + except Exception as e: + logger.error(f"[ComfyStreamClient] Error in run_prompt {prompt_index}: {str(e)}") + await self.cleanup() + raise + + async def cleanup(self): + async with self.cleanup_lock: + tasks_to_cancel = list(self.running_prompts.values()) + for task in tasks_to_cancel: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + self.running_prompts.clear() + + if self.comfy_client.is_running: + try: + await self.comfy_client.__aexit__() + except Exception as e: + logger.error(f"Error during ComfyClient cleanup: {e}") + + + await self.cleanup_queues() + logger.info("Client cleanup complete") + + + async def cleanup_queues(self): + # TODO: add for audio as well + while not self.image_inputs.empty(): + self.image_inputs.get() + + while not self.image_outputs.empty(): + self.image_outputs.get() + + def put_video_input(self, frame, width=None, height=None, pts=None, time_base=None): + # logger.debug(f"[ComfyStreamClient] Putting video input: {type(frame)}") + try: + if isinstance(frame, FrameProxy): + # Already a FrameProxy, ensure tensor is in BCHW format and on CPU + tensor = frame.side_data.input + if len(tensor.shape) == 3: # CHW format + tensor = tensor.unsqueeze(0) # Add batch dimension -> BCHW + elif len(tensor.shape) == 4: # Already BCHW + tensor = tensor # Keep as is + else: + raise ValueError(f"Unexpected tensor shape: {tensor.shape}") + tensor = tensor.cpu() + proxy = FrameProxy( + tensor=tensor, + width=frame.width, + height=frame.height, + pts=frame.pts, + 
time_base=frame.time_base + ) + elif hasattr(frame, "to_ndarray") and hasattr(frame, "width") and hasattr(frame, "height"): + # It's an av.VideoFrame, convert to BCHW format + frame_np = frame.to_ndarray(format="rgb24").astype(np.float32) / 255.0 + tensor = torch.from_numpy(frame_np).permute(2, 0, 1).unsqueeze(0) # Convert to BCHW + proxy = FrameProxy( + tensor=tensor, + width=frame.width, + height=frame.height, + pts=getattr(frame, 'pts', None), + time_base=getattr(frame, 'time_base', None) + ) + else: + # Assume it's a tensor, require width/height + if width is None or height is None: + raise ValueError("Width and height must be provided for raw tensors") + tensor = frame + if len(tensor.shape) == 3: # CHW format + tensor = tensor.unsqueeze(0) # Add batch dimension -> BCHW + elif len(tensor.shape) == 4: # Already BCHW + tensor = tensor # Keep as is + else: + raise ValueError(f"Unexpected tensor shape: {tensor.shape}") + tensor = tensor.cpu() + proxy = FrameProxy( + tensor=tensor, + width=width, + height=height, + pts=pts, + time_base=time_base + ) + + if self.image_inputs.full(): + try: + self.image_inputs.get_nowait() + except Exception: + pass + self.image_inputs.put_nowait(proxy) + logger.debug(f"[ComfyStreamClient] Video input queued.") + except Exception as e: + logger.info(f"[ComfyStreamClient] Error putting video frame: {str(e)}") + + def put_audio_input(self, frame): + self.audio_inputs.put(frame) + + async def get_video_output(self): + try: + tensor = await asyncio.wait_for( + asyncio.get_event_loop().run_in_executor(None, self.image_outputs.get), + timeout=5.0 + ) + # Add format conversion here + if len(tensor.shape) == 4 and tensor.shape[1] != 3: # If BHWC format + tensor = tensor.permute(0, 3, 1, 2) # Convert BHWC to BCHW + return tensor + except asyncio.TimeoutError: + return torch.zeros((1, 3, self.height, self.width), dtype=torch.float32) # Return BCHW format + except Exception as e: + logger.info(f"[ComfyStreamClient] Error getting video output: 
{str(e)}") + return torch.zeros((1, 3, self.height, self.width), dtype=torch.float32) # Return BCHW format + + async def get_audio_output(self): + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, self.audio_outputs.get) + + async def get_available_nodes(self): + """Get metadata and available nodes info in a single pass""" + # TODO: make it for for multiple prompts + if not self.running_prompts: + return {} + + try: + from comfy.nodes.package import import_all_nodes_in_workspace + nodes = import_all_nodes_in_workspace() + + all_prompts_nodes_info = {} + + for prompt_index, prompt in enumerate(self.current_prompts): + # Get set of class types we need metadata for, excluding LoadTensor and SaveTensor + needed_class_types = { + node.get('class_type') + for node in prompt.values() + } + remaining_nodes = { + node_id + for node_id, node in prompt.items() + } + nodes_info = {} + + # Only process nodes until we've found all the ones we need + for class_type, node_class in nodes.NODE_CLASS_MAPPINGS.items(): + if not remaining_nodes: # Exit early if we've found all needed nodes + break + + if class_type not in needed_class_types: + continue + + # Get metadata for this node type (same as original get_node_metadata) + input_data = node_class.INPUT_TYPES() if hasattr(node_class, 'INPUT_TYPES') else {} + input_info = {} + + # Process required inputs + if 'required' in input_data: + for name, value in input_data['required'].items(): + if isinstance(value, tuple): + if len(value) == 1 and isinstance(value[0], list): + # Handle combo box case where value is ([option1, option2, ...],) + input_info[name] = { + 'type': 'combo', + 'value': value[0], # The list of options becomes the value + } + elif len(value) == 2: + input_type, config = value + input_info[name] = { + 'type': input_type, + 'required': True, + 'min': config.get('min', None), + 'max': config.get('max', None), + 'widget': config.get('widget', None) + } + elif len(value) == 1: + # Handle simple 
type case like ('IMAGE',) + input_info[name] = { + 'type': value[0] + } + else: + logger.error(f"Unexpected structure for required input {name}: {value}") + + # Process optional inputs with same logic + if 'optional' in input_data: + for name, value in input_data['optional'].items(): + if isinstance(value, tuple): + if len(value) == 1 and isinstance(value[0], list): + # Handle combo box case where value is ([option1, option2, ...],) + input_info[name] = { + 'type': 'combo', + 'value': value[0], # The list of options becomes the value + } + elif len(value) == 2: + input_type, config = value + input_info[name] = { + 'type': input_type, + 'required': False, + 'min': config.get('min', None), + 'max': config.get('max', None), + 'widget': config.get('widget', None) + } + elif len(value) == 1: + # Handle simple type case like ('IMAGE',) + input_info[name] = { + 'type': value[0] + } + else: + logger.error(f"Unexpected structure for optional input {name}: {value}") + + # Now process any nodes in our prompt that use this class_type + for node_id in list(remaining_nodes): + node = prompt[node_id] + if node.get('class_type') != class_type: + continue + + node_info = { + 'class_type': class_type, + 'inputs': {} + } + + if 'inputs' in node: + for input_name, input_value in node['inputs'].items(): + input_metadata = input_info.get(input_name, {}) + node_info['inputs'][input_name] = { + 'value': input_value, + 'type': input_metadata.get('type', 'unknown'), + 'min': input_metadata.get('min', None), + 'max': input_metadata.get('max', None), + 'widget': input_metadata.get('widget', None) + } + # For combo type inputs, include the list of options + if input_metadata.get('type') == 'combo': + node_info['inputs'][input_name]['value'] = input_metadata.get('value', []) + + nodes_info[node_id] = node_info + remaining_nodes.remove(node_id) + + all_prompts_nodes_info[prompt_index] = nodes_info + + return all_prompts_nodes_info + + except Exception as e: + logger.error(f"Error getting node 
info: {str(e)}") + return {} + +def execute_prompt_in_worker(config_dict, prompt): + """Execute a prompt in the worker process""" + logger.info(f"[execute_prompt_in_worker] Starting in process {os.getpid()}") + try: + import os + import sys + import torch + from comfy.cli_args_types import Configuration + from comfy.client.embedded_comfy_client import EmbeddedComfyClient + + # On Windows, we need to ensure the working directory is correct + if sys.platform == 'win32': + # Get the workspace directory from config + workspace = config_dict.get('cwd', '..\\..') + logger.info(f"[execute_prompt_in_worker] Setting working directory to: {workspace}") + os.chdir(workspace) + + # Ensure Python path includes the workspace + if workspace not in sys.path: + sys.path.insert(0, workspace) + logger.info(f"[execute_prompt_in_worker] Added {workspace} to Python path") + + logger.info(f"[execute_prompt_in_worker] Current working directory: {os.getcwd()}") + logger.info(f"[execute_prompt_in_worker] Python path: {sys.path}") + + # Create a new client in the worker process + logger.info("[execute_prompt_in_worker] Creating configuration") + config = Configuration(**config_dict) + + logger.info("[execute_prompt_in_worker] Creating EmbeddedComfyClient") + # Try to initialize CUDA before creating the client + if torch.cuda.is_available(): + logger.info(f"[execute_prompt_in_worker] CUDA device count: {torch.cuda.device_count()}") + # Set the device explicitly + torch.cuda.set_device(0) + logger.info(f"[execute_prompt_in_worker] Set CUDA device to: {torch.cuda.current_device()}") + + client = EmbeddedComfyClient(config) + + # Execute the prompt + logger.info("[execute_prompt_in_worker] Setting up event loop") + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + logger.info("[execute_prompt_in_worker] Queueing prompt") + loop.run_until_complete(client.queue_prompt(prompt)) + logger.info("[execute_prompt_in_worker] Prompt queued successfully") + finally: + 
logger.info("[execute_prompt_in_worker] Closing event loop") + loop.close() + except Exception as e: + logger.info(f"[execute_prompt_in_worker] Error: {str(e)}") + logger.info(f"[execute_prompt_in_worker] Error type: {type(e)}") + import traceback + logger.info(f"[execute_prompt_in_worker] Error traceback: {traceback.format_exc()}") + raise \ No newline at end of file diff --git a/src/comfystream/frame_proxy.py b/src/comfystream/frame_proxy.py new file mode 100644 index 00000000..e985e50f --- /dev/null +++ b/src/comfystream/frame_proxy.py @@ -0,0 +1,27 @@ +import torch +import numpy as np + +class SideData: + pass + +class FrameProxy: + def __init__(self, tensor, width, height, pts=None, time_base=None): + self.width = width + self.height = height + self.pts = pts + self.time_base = time_base + self.side_data = SideData() + self.side_data.input = tensor.clone().cpu() + self.side_data.skipped = True + + @staticmethod + def avframe_to_frameproxy(frame): + frame_np = frame.to_ndarray(format="rgb24").astype(np.float32) / 255.0 + tensor = torch.from_numpy(frame_np).unsqueeze(0) + return FrameProxy( + tensor=tensor.clone().cpu(), + width=frame.width, + height=frame.height, + pts=getattr(frame, "pts", None), + time_base=getattr(frame, "time_base", None) + ) \ No newline at end of file diff --git a/src/comfystream/pipeline_multi.py b/src/comfystream/pipeline_multi.py new file mode 100644 index 00000000..08dd9431 --- /dev/null +++ b/src/comfystream/pipeline_multi.py @@ -0,0 +1,294 @@ +import av +import torch +import numpy as np +import asyncio +import logging +import time +from typing import Any, Dict, Union, List, Optional + +from comfystream.client_multi import ComfyStreamClient +from comfystream.server.utils import temporary_log_level +from comfystream.frame_logging import log_frame_timing +from comfystream.frame_proxy import FrameProxy + +WARMUP_RUNS = 5 + +logger = logging.getLogger(__name__) + + +class Pipeline: + """A pipeline for processing video and audio frames 
using ComfyUI. + + This class provides a high-level interface for processing video and audio frames + through a ComfyUI-based processing pipeline. It handles frame preprocessing, + postprocessing, and queue management. + """ + + def __init__(self, width: int = 512, height: int = 512, + comfyui_inference_log_level: Optional[int] = None, frame_log_file: Optional[str] = None, **kwargs): + """Initialize the pipeline with the given configuration. + + Args: + width: Width of the video frames (default: 512) + height: Height of the video frames (default: 512) + comfyui_inference_log_level: The logging level for ComfyUI inference. + Defaults to None, using the global ComfyUI log level. + **kwargs: Additional arguments to pass to the ComfyStreamClient + """ + self.client = ComfyStreamClient(**kwargs) + self.width = width + self.height = height + + self.video_incoming_frames = asyncio.Queue() + self.audio_incoming_frames = asyncio.Queue() + + self.processed_audio_buffer = np.array([], dtype=np.int16) + + self._comfyui_inference_log_level = comfyui_inference_log_level + + # Add a queue for frame log entries + self.running = True + self.next_expected_frame_id = 0 + self.frame_log_file = frame_log_file + self.frame_log_queue = None # Initialize to None by default + + if self.frame_log_file: + self.frame_log_queue = asyncio.Queue() + self.frame_logger_task = asyncio.create_task(self._process_frame_logs()) + + async def initialize(self, prompts): + await self.set_prompts(prompts) + await self.warm_video() + + async def warm_video(self): + logger.info("[PipelineMulti] Starting warmup...") + for i in range(WARMUP_RUNS): + dummy_tensor = torch.randn(1, self.height, self.width, 3) + dummy_proxy = FrameProxy( + tensor=dummy_tensor, + width=self.width, + height=self.height, + pts=None, + time_base=None + ) + logger.debug(f"[PipelineMulti] Warmup: putting dummy frame {i+1}/{WARMUP_RUNS}") + self.client.put_video_input(dummy_proxy) + out = await self.client.get_video_output() + 
logger.debug(f"[PipelineMulti] Warmup: got output for dummy frame {i+1}/{WARMUP_RUNS}: shape={getattr(out, 'shape', None)}") + logger.info("[PipelineMulti] Warmup complete.") + + async def warm_audio(self): + """Warm up the audio processing pipeline with dummy frames.""" + dummy_frame = av.AudioFrame() + dummy_frame.side_data.input = np.random.randint(-32768, 32767, int(48000 * 0.5), dtype=np.int16) # TODO: adds a lot of delay if it doesn't match the buffer size, is warmup needed? + dummy_frame.sample_rate = 48000 + + for _ in range(WARMUP_RUNS): + self.client.put_audio_input(dummy_frame) + await self.client.get_audio_output() + + async def set_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]): + """Set the processing prompts for the pipeline. + + Args: + prompts: Either a single prompt dictionary or a list of prompt dictionaries + """ + if isinstance(prompts, list): + await self.client.set_prompts(prompts) + else: + await self.client.set_prompts([prompts]) + + async def update_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]): + """Update the existing processing prompts. 
+ + Args: + prompts: Either a single prompt dictionary or a list of prompt dictionaries + """ + if isinstance(prompts, list): + await self.client.update_prompts(prompts) + else: + await self.client.update_prompts([prompts]) + + async def put_video_frame(self, frame: av.VideoFrame): + current_time = time.time() + frame.side_data.input = self.video_preprocess(frame) + frame.side_data.skipped = True + frame.side_data.frame_received_time = current_time + frame.side_data.frame_id = self.next_expected_frame_id + frame.side_data.client_index = -1 + self.next_expected_frame_id += 1 + + # Log frame at input time to properly track input FPS + if self.frame_log_file: + await self.frame_log_queue.put({ + 'frame_id': frame.side_data.frame_id, + 'frame_received_time': frame.side_data.frame_received_time, + 'frame_process_start_time': None, + 'frame_processed_time': None, + 'client_index': frame.side_data.client_index, + 'csv_path': self.frame_log_file + }) + + self.client.put_video_input(frame) + await self.video_incoming_frames.put(frame) + + async def put_audio_frame(self, frame: av.AudioFrame): + """Queue an audio frame for processing. + + Args: + frame: The audio frame to process + """ + frame.side_data.input = self.audio_preprocess(frame) + frame.side_data.skipped = True + self.client.put_audio_input(frame) + await self.audio_incoming_frames.put(frame) + + def video_preprocess(self, frame: av.VideoFrame) -> Union[torch.Tensor, np.ndarray]: + """Preprocess a video frame before processing. + + Args: + frame: The video frame to preprocess + + Returns: + The preprocessed frame as a tensor or numpy array + """ + frame_np = frame.to_ndarray(format="rgb24").astype(np.float32) / 255.0 + return torch.from_numpy(frame_np).unsqueeze(0) + + def audio_preprocess(self, frame: av.AudioFrame) -> Union[torch.Tensor, np.ndarray]: + """Preprocess an audio frame before processing. 
+ + Args: + frame: The audio frame to preprocess + + Returns: + The preprocessed frame as a tensor or numpy array + """ + return frame.to_ndarray().ravel().reshape(-1, 2).mean(axis=1).astype(np.int16) + + + def video_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.VideoFrame: + """Postprocess a tensor in BCHW format back to video frame.""" + # First ensure we have a tensor + if isinstance(output, np.ndarray): + output = torch.from_numpy(output) + + # Handle different tensor formats + if len(output.shape) == 4: # BCHW or BHWC format + if output.shape[1] != 3: # If BHWC format + output = output.permute(0, 3, 1, 2) # Convert BHWC to BCHW + output = output[0] # Take first image from batch -> CHW + elif len(output.shape) != 3: # Should be CHW at this point + raise ValueError(f"Unexpected tensor shape after batch removal: {output.shape}") + + # Convert CHW to HWC for video frame + output = output.permute(1, 2, 0) # CHW -> HWC + + # Convert to numpy and create video frame + return av.VideoFrame.from_ndarray( + (output * 255.0).clamp(0, 255).to(dtype=torch.uint8).cpu().numpy(), + format='rgb24' + ) + + def audio_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.AudioFrame: + """Postprocess an audio frame after processing. 
+ + Args: + output: The processed output tensor or numpy array + + Returns: + The postprocessed audio frame + """ + return av.AudioFrame.from_ndarray(np.repeat(output, 2).reshape(1, -1)) + + # TODO: make it generic to support purely generative video cases + async def get_processed_video_frame(self) -> av.VideoFrame: + logger.info("[PipelineMulti] get_processed_video_frame called") + logger.debug("[PipelineMulti] Waiting for processed video frame...") + frame_process_start_time = time.time() + + # Get the input frame first + frame = await self.video_incoming_frames.get() + + # Then get the output tensor + async with temporary_log_level("comfy", self._comfyui_inference_log_level): + out_tensor = await self.client.get_video_output() + + # Process the frame + processed_frame = self.video_postprocess(out_tensor) + processed_frame.pts = frame.pts + processed_frame.time_base = frame.time_base + + frame_processed_time = time.time() + + # Log frame timing + if self.frame_log_file: + await self.frame_log_queue.put({ + 'frame_id': frame.side_data.frame_id, + 'frame_received_time': frame.side_data.frame_received_time, + 'frame_process_start_time': frame_process_start_time, + 'frame_processed_time': frame_processed_time, + 'client_index': frame.side_data.client_index, + 'csv_path': self.frame_log_file + }) + + logger.info("[PipelineMulti] get_processed_video_frame returning frame") + return processed_frame + + async def get_processed_audio_frame(self) -> av.AudioFrame: + """Get the next processed audio frame. 
+ + Returns: + The processed audio frame + """ + frame = await self.audio_incoming_frames.get() + if frame.samples > len(self.processed_audio_buffer): + async with temporary_log_level("comfy", self._comfyui_inference_log_level): + out_tensor = await self.client.get_audio_output() + self.processed_audio_buffer = np.concatenate([self.processed_audio_buffer, out_tensor]) + out_data = self.processed_audio_buffer[:frame.samples] + self.processed_audio_buffer = self.processed_audio_buffer[frame.samples:] + + processed_frame = self.audio_postprocess(out_data) + processed_frame.pts = frame.pts + processed_frame.time_base = frame.time_base + processed_frame.sample_rate = frame.sample_rate + + return processed_frame + + async def get_nodes_info(self) -> Dict[str, Any]: + """Get information about all nodes in the current prompt including metadata. + + Returns: + Dictionary containing node information + """ + nodes_info = await self.client.get_available_nodes() + return nodes_info + + async def cleanup(self): + """Clean up resources used by the pipeline.""" + + # Cancel frame logger task if it exists + if hasattr(self, 'frame_logger_task') and self.frame_logger_task: + self.frame_logger_task.cancel() + try: + await self.frame_logger_task + except asyncio.CancelledError: + pass + + await self.client.cleanup() + + async def _process_frame_logs(self): + """Background task to process frame logs from queue""" + while self.running: + try: + # Get log entry from queue + log_entry = await self.frame_log_queue.get() + log_frame_timing(**log_entry) + + # Mark task as done + self.frame_log_queue.task_done() + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Error in frame logging: {e}") \ No newline at end of file diff --git a/src/comfystream/tensor_cache_multi.py b/src/comfystream/tensor_cache_multi.py new file mode 100644 index 00000000..90a0f02f --- /dev/null +++ b/src/comfystream/tensor_cache_multi.py @@ -0,0 +1,87 @@ +# TODO: add better frame 
management, improve eviction policy fifo might not be the best, skip alternate frames instead +# TODO: also make the tensor_cache solution backward compatible for when not using process pool -- after the multi process solution is stable +from comfystream import tensor_cache +import queue +import torch +import asyncio +from queue import Queue +from asyncio import Queue as AsyncQueue + +image_inputs = None +image_outputs = None + +audio_inputs = None +audio_outputs = None + +# Create wrapper classes that match the interface of the original queues +class MultiProcessInputQueue: + def __init__(self, mp_queue): + self.queue = mp_queue + + def get(self, block=True, timeout=None): + return self.queue.get(block=block, timeout=timeout) + + def get_nowait(self): + return self.queue.get_nowait() + + def put(self, item, block=True, timeout=None): + return self.queue.put(item, block=block, timeout=timeout) + + def put_nowait(self, item): + return self.queue.put_nowait(item) + + def empty(self): + return self.queue.empty() + + def full(self): + return self.queue.full() + +class MultiProcessOutputQueue: + def __init__(self, mp_queue): + self.queue = mp_queue + + async def get(self): + # Convert synchronous get to async + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, self.queue.get) + + async def put(self, item): + # Convert synchronous put to async + loop = asyncio.get_event_loop() + # Ensure tensor is on CPU before sending + if torch.is_tensor(item): + item = item.cpu() + return await loop.run_in_executor(None, self.queue.put, item) + + def put_nowait(self, item): + try: + # Ensure tensor is on CPU before sending + if torch.is_tensor(item): + item = item.cpu() + self.queue.put_nowait(item) + except queue.Full: + try: + self.queue.get_nowait() # Drop oldest + self.queue.put_nowait(item) + except Exception: + pass # If still full, drop this frame + +def init_tensor_cache(image_inputs, image_outputs, audio_inputs, audio_outputs): + """Initialize the 
tensor cache for a worker process. + + Args: + image_inputs: Multiprocessing Queue for input images + image_outputs: Multiprocessing Queue for output images + audio_inputs: Multiprocessing Queue for input audio + audio_outputs: Multiprocessing Queue for output audio + """ + print("[init_tensor_cache] Setting up tensor_cache queues in worker") + + # Replace the queues with our wrapped versions that match the original interface + tensor_cache.image_inputs = MultiProcessInputQueue(image_inputs) + tensor_cache.image_outputs = MultiProcessOutputQueue(image_outputs) + tensor_cache.audio_inputs = MultiProcessInputQueue(audio_inputs) + tensor_cache.audio_outputs = MultiProcessOutputQueue(audio_outputs) + + print("[init_tensor_cache] tensor_cache.image_outputs id:", id(tensor_cache.image_outputs)) + print("[init_tensor_cache] Initialization complete") \ No newline at end of file From ad4476687ccccf82f19e75c614d8f6fa11cdbaa7 Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Tue, 20 May 2025 21:52:51 -0400 Subject: [PATCH 29/42] Simplify video input, small fix. 
--- src/comfystream/client_multi.py | 66 ++++++--------------------------- 1 file changed, 12 insertions(+), 54 deletions(-) diff --git a/src/comfystream/client_multi.py b/src/comfystream/client_multi.py index 707751fa..32dfaf8c 100644 --- a/src/comfystream/client_multi.py +++ b/src/comfystream/client_multi.py @@ -26,7 +26,7 @@ def _test_worker_init(): class ComfyStreamClient: def __init__(self, max_workers: int = 1, executor_type: str = "process", **kwargs): logger.info(f"[ComfyStreamClient] Main Process ID: {os.getpid()}") - logger.info("[ComfyStreamClient] __init__ start, max_workers:", max_workers, "executor_type:", executor_type) + logger.info(f"[ComfyStreamClient] __init__ start, max_workers: {max_workers}, executor_type: {executor_type}") # Store default dimensions self.width = kwargs.get('width', 512) @@ -167,62 +167,22 @@ async def cleanup_queues(self): while not self.image_outputs.empty(): self.image_outputs.get() - def put_video_input(self, frame, width=None, height=None, pts=None, time_base=None): - # logger.debug(f"[ComfyStreamClient] Putting video input: {type(frame)}") + def put_video_input(self, frame): try: + # Check if frame is FrameProxy if isinstance(frame, FrameProxy): - # Already a FrameProxy, ensure tensor is in BCHW format and on CPU - tensor = frame.side_data.input - if len(tensor.shape) == 3: # CHW format - tensor = tensor.unsqueeze(0) # Add batch dimension -> BCHW - elif len(tensor.shape) == 4: # Already BCHW - tensor = tensor # Keep as is - else: - raise ValueError(f"Unexpected tensor shape: {tensor.shape}") - tensor = tensor.cpu() - proxy = FrameProxy( - tensor=tensor, - width=frame.width, - height=frame.height, - pts=frame.pts, - time_base=frame.time_base - ) - elif hasattr(frame, "to_ndarray") and hasattr(frame, "width") and hasattr(frame, "height"): - # It's an av.VideoFrame, convert to BCHW format - frame_np = frame.to_ndarray(format="rgb24").astype(np.float32) / 255.0 - tensor = torch.from_numpy(frame_np).permute(2, 0, 
1).unsqueeze(0) # Convert to BCHW - proxy = FrameProxy( - tensor=tensor, - width=frame.width, - height=frame.height, - pts=getattr(frame, 'pts', None), - time_base=getattr(frame, 'time_base', None) - ) + proxy = frame + # Otherwise create a proxy (assuming frame is av.VideoFrame as in pipeline.py) else: - # Assume it's a tensor, require width/height - if width is None or height is None: - raise ValueError("Width and height must be provided for raw tensors") - tensor = frame - if len(tensor.shape) == 3: # CHW format - tensor = tensor.unsqueeze(0) # Add batch dimension -> BCHW - elif len(tensor.shape) == 4: # Already BCHW - tensor = tensor # Keep as is - else: - raise ValueError(f"Unexpected tensor shape: {tensor.shape}") - tensor = tensor.cpu() - proxy = FrameProxy( - tensor=tensor, - width=width, - height=height, - pts=pts, - time_base=time_base - ) - + proxy = FrameProxy.avframe_to_frameproxy(frame) + + # Handle queue being full if self.image_inputs.full(): try: self.image_inputs.get_nowait() except Exception: pass + self.image_inputs.put_nowait(proxy) logger.debug(f"[ComfyStreamClient] Video input queued.") except Exception as e: @@ -237,15 +197,13 @@ async def get_video_output(self): asyncio.get_event_loop().run_in_executor(None, self.image_outputs.get), timeout=5.0 ) - # Add format conversion here - if len(tensor.shape) == 4 and tensor.shape[1] != 3: # If BHWC format - tensor = tensor.permute(0, 3, 1, 2) # Convert BHWC to BCHW + # No need for permutation here - tensor should already be in the right format return tensor except asyncio.TimeoutError: - return torch.zeros((1, 3, self.height, self.width), dtype=torch.float32) # Return BCHW format + return torch.zeros((1, 3, self.height, self.width), dtype=torch.float32) except Exception as e: logger.info(f"[ComfyStreamClient] Error getting video output: {str(e)}") - return torch.zeros((1, 3, self.height, self.width), dtype=torch.float32) # Return BCHW format + return torch.zeros((1, 3, self.height, self.width), 
dtype=torch.float32) async def get_audio_output(self): loop = asyncio.get_event_loop() From cbeb89ee44b19d06cc3df260a05e6d0899b78f8e Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Mon, 26 May 2025 12:21:31 -0400 Subject: [PATCH 30/42] Restructure of distribution logic, added pipeline/tensor_cache logging, attempt to fix tensor_rt directory retrieval in comfyui. --- server/app_multi.py | 1 + src/comfystream/client_multi.py | 117 +++++++++++++++++++------- src/comfystream/pipeline_multi.py | 5 +- src/comfystream/tensor_cache_multi.py | 17 ++-- 4 files changed, 104 insertions(+), 36 deletions(-) diff --git a/server/app_multi.py b/server/app_multi.py index 9ea4d8f1..20d60a60 100644 --- a/server/app_multi.py +++ b/server/app_multi.py @@ -378,6 +378,7 @@ async def on_startup(app: web.Application): max_workers=app["workers"], comfyui_inference_log_level=app.get("comfui_inference_log_level", None), frame_log_file=app.get("frame_log_file", None), + base_directory=app["workspace"] ) app["pcs"] = set() app["video_tracks"] = {} diff --git a/src/comfystream/client_multi.py b/src/comfystream/client_multi.py index 32dfaf8c..3b5fa6c8 100644 --- a/src/comfystream/client_multi.py +++ b/src/comfystream/client_multi.py @@ -99,6 +99,10 @@ def __init__(self, max_workers: int = 1, executor_type: str = "process", **kwarg self.running_prompts = {} self.current_prompts = [] self.cleanup_lock = asyncio.Lock() + self.max_workers = max_workers + self.worker_tasks = [] + self.next_worker = 0 + self.distribution_lock = asyncio.Lock() logger.info("[ComfyStreamClient] __init__ complete") except Exception as e: @@ -111,54 +115,106 @@ def __init__(self, max_workers: int = 1, executor_type: str = "process", **kwarg async def set_prompts(self, prompts: List[PromptDictInput]): logger.info("set_prompts start") self.current_prompts = [convert_prompt(prompt) for prompt in prompts] - for idx in range(len(self.current_prompts)): - logger.info(f"Scheduling run_prompt for idx {idx}") - task = 
asyncio.create_task(self.run_prompt(idx)) - self.running_prompts[idx] = task + + # Start the distribution manager + distribution_task = asyncio.create_task(self.distribute_frames()) + self.running_prompts[-1] = distribution_task # Use -1 as a special key for the manager logger.info("set_prompts end") - async def update_prompts(self, prompts: List[PromptDictInput]): - # TODO: currently under the assumption that only already running prompts are updated - if len(prompts) != len(self.current_prompts): - raise ValueError( - "Number of updated prompts must match the number of currently running prompts." - ) - self.current_prompts = [convert_prompt(prompt) for prompt in prompts] - - async def run_prompt(self, prompt_index: int): - logger.info(f"[ComfyStreamClient] Starting run_prompt for index {prompt_index}") + async def distribute_frames(self): + """Manager that distributes frames across workers in round-robin fashion""" + logger.info(f"[ComfyStreamClient] Starting frame distribution manager") + + # Initialize worker tasks + self.worker_tasks = [] + for worker_id in range(self.max_workers): + worker_task = asyncio.create_task(self.worker_loop(worker_id)) + self.worker_tasks.append(worker_task) + self.running_prompts[worker_id] = worker_task + + # Keep the manager running to monitor workers + while True: + await asyncio.sleep(1.0) # Check periodically + # Restart any crashed workers + for worker_id, task in enumerate(self.worker_tasks): + if task.done(): + logger.warning(f"Worker {worker_id} crashed, restarting") + new_task = asyncio.create_task(self.worker_loop(worker_id)) + self.worker_tasks[worker_id] = new_task + self.running_prompts[worker_id] = new_task + + async def worker_loop(self, worker_id: int): + """Worker process that handles frames""" + logger.info(f"[ComfyStreamClient] Worker {worker_id} started - PID: {os.getpid()}") + while True: - try: - logger.debug(f"[ComfyStreamClient] Queueing prompt {prompt_index}") - await 
self.comfy_client.queue_prompt(self.current_prompts[prompt_index]) - logger.debug(f"[ComfyStreamClient] Prompt {prompt_index} queued successfully") - except Exception as e: - logger.error(f"[ComfyStreamClient] Error in run_prompt {prompt_index}: {str(e)}") - await self.cleanup() - raise + # Check if this worker should process the next frame + async with self.distribution_lock: + should_process = (self.next_worker == worker_id) + if should_process: + # Move to next worker for round-robin distribution + self.next_worker = (self.next_worker + 1) % self.max_workers + + if should_process: + # Get prompt - cycle through available prompts + prompt_index = worker_id % len(self.current_prompts) + prompt = self.current_prompts[prompt_index] + + try: + logger.debug(f"[ComfyStreamClient] Worker {worker_id} processing frame with prompt {prompt_index} - PID: {os.getpid()}") + # Process a single frame + await self.process_frame_with_prompt(prompt) + except Exception as e: + logger.error(f"[ComfyStreamClient] Error in worker {worker_id}: {str(e)}") + await asyncio.sleep(0.1) # Avoid tight error loop + else: + # Small wait to avoid busy waiting + await asyncio.sleep(0.01) + + async def process_frame_with_prompt(self, prompt): + """Process a single frame with the given prompt""" + try: + # Get frame from input queue if available + if not self.image_inputs.empty(): + # Non-blocking queue check + frame = None + try: + frame = self.image_inputs.get_nowait() + except Exception: + pass + + if frame is not None: + # Queue the prompt to process this frame + await self.comfy_client.queue_prompt(prompt) + + # Assuming result will eventually appear in output queue + # You may need logic to match inputs with outputs + + else: + # No frames to process, small wait + await asyncio.sleep(0.01) + + except Exception as e: + logger.error(f"[ComfyStreamClient] Error processing frame: {str(e)}") async def cleanup(self): async with self.cleanup_lock: - tasks_to_cancel = 
list(self.running_prompts.values()) - for task in tasks_to_cancel: + for task in self.worker_tasks: task.cancel() try: await task except asyncio.CancelledError: pass - self.running_prompts.clear() - + if self.comfy_client.is_running: try: await self.comfy_client.__aexit__() except Exception as e: logger.error(f"Error during ComfyClient cleanup: {e}") - await self.cleanup_queues() logger.info("Client cleanup complete") - async def cleanup_queues(self): # TODO: add for audio as well while not self.image_inputs.empty(): @@ -193,16 +249,19 @@ def put_audio_input(self, frame): async def get_video_output(self): try: + logger.debug(f"[ComfyStreamClient] get_video_output called - PID: {os.getpid()}") tensor = await asyncio.wait_for( asyncio.get_event_loop().run_in_executor(None, self.image_outputs.get), timeout=5.0 ) + logger.debug(f"[ComfyStreamClient] get_video_output returning tensor - PID: {os.getpid()}") # No need for permutation here - tensor should already be in the right format return tensor except asyncio.TimeoutError: + logger.warning(f"[ComfyStreamClient] get_video_output timeout - PID: {os.getpid()}") return torch.zeros((1, 3, self.height, self.width), dtype=torch.float32) except Exception as e: - logger.info(f"[ComfyStreamClient] Error getting video output: {str(e)}") + logger.info(f"[ComfyStreamClient] Error getting video output: {str(e)} - PID: {os.getpid()}") return torch.zeros((1, 3, self.height, self.width), dtype=torch.float32) async def get_audio_output(self): diff --git a/src/comfystream/pipeline_multi.py b/src/comfystream/pipeline_multi.py index 08dd9431..cc48ae27 100644 --- a/src/comfystream/pipeline_multi.py +++ b/src/comfystream/pipeline_multi.py @@ -4,6 +4,7 @@ import asyncio import logging import time +import os from typing import Any, Dict, Union, List, Optional from comfystream.client_multi import ComfyStreamClient @@ -203,7 +204,7 @@ def audio_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.Audio # TODO: make it generic to 
support purely generative video cases async def get_processed_video_frame(self) -> av.VideoFrame: - logger.info("[PipelineMulti] get_processed_video_frame called") + logger.info(f"[PipelineMulti] get_processed_video_frame called - PID: {os.getpid()}") logger.debug("[PipelineMulti] Waiting for processed video frame...") frame_process_start_time = time.time() @@ -232,7 +233,7 @@ async def get_processed_video_frame(self) -> av.VideoFrame: 'csv_path': self.frame_log_file }) - logger.info("[PipelineMulti] get_processed_video_frame returning frame") + logger.info(f"[PipelineMulti] get_processed_video_frame returning frame - PID: {os.getpid()}") return processed_frame async def get_processed_audio_frame(self) -> av.AudioFrame: diff --git a/src/comfystream/tensor_cache_multi.py b/src/comfystream/tensor_cache_multi.py index 90a0f02f..2d291822 100644 --- a/src/comfystream/tensor_cache_multi.py +++ b/src/comfystream/tensor_cache_multi.py @@ -4,6 +4,7 @@ import queue import torch import asyncio +import os from queue import Queue from asyncio import Queue as AsyncQueue @@ -19,10 +20,14 @@ def __init__(self, mp_queue): self.queue = mp_queue def get(self, block=True, timeout=None): - return self.queue.get(block=block, timeout=timeout) + result = self.queue.get(block=block, timeout=timeout) + print(f"[MultiProcessInputQueue] Frame retrieved by worker PID: {os.getpid()}") + return result def get_nowait(self): - return self.queue.get_nowait() + result = self.queue.get_nowait() + print(f"[MultiProcessInputQueue] Frame retrieved (nowait) by worker PID: {os.getpid()}") + return result def put(self, item, block=True, timeout=None): return self.queue.put(item, block=block, timeout=timeout) @@ -51,6 +56,7 @@ async def put(self, item): # Ensure tensor is on CPU before sending if torch.is_tensor(item): item = item.cpu() + print(f"[MultiProcessOutputQueue] Frame sent from worker PID: {os.getpid()}") return await loop.run_in_executor(None, self.queue.put, item) def put_nowait(self, item): @@ 
-58,6 +64,7 @@ def put_nowait(self, item): # Ensure tensor is on CPU before sending if torch.is_tensor(item): item = item.cpu() + print(f"[MultiProcessOutputQueue] Frame sent (nowait) from worker PID: {os.getpid()}") self.queue.put_nowait(item) except queue.Full: try: @@ -75,7 +82,7 @@ def init_tensor_cache(image_inputs, image_outputs, audio_inputs, audio_outputs): audio_inputs: Multiprocessing Queue for input audio audio_outputs: Multiprocessing Queue for output audio """ - print("[init_tensor_cache] Setting up tensor_cache queues in worker") + print(f"[init_tensor_cache] Setting up tensor_cache queues in worker - PID: {os.getpid()}") # Replace the queues with our wrapped versions that match the original interface tensor_cache.image_inputs = MultiProcessInputQueue(image_inputs) @@ -83,5 +90,5 @@ def init_tensor_cache(image_inputs, image_outputs, audio_inputs, audio_outputs): tensor_cache.audio_inputs = MultiProcessInputQueue(audio_inputs) tensor_cache.audio_outputs = MultiProcessOutputQueue(audio_outputs) - print("[init_tensor_cache] tensor_cache.image_outputs id:", id(tensor_cache.image_outputs)) - print("[init_tensor_cache] Initialization complete") \ No newline at end of file + print(f"[init_tensor_cache] tensor_cache.image_outputs id: {id(tensor_cache.image_outputs)} - PID: {os.getpid()}") + print(f"[init_tensor_cache] Initialization complete - PID: {os.getpid()}") \ No newline at end of file From 9d0ee6e4677b8ea5969872b5a2659acf795ebc55 Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Mon, 26 May 2025 21:17:28 -0400 Subject: [PATCH 31/42] Modified queue size, re-worked prompting, testing queue sizes, logging. 
--- src/comfystream/client_multi.py | 89 ++++++++++----------------------- 1 file changed, 27 insertions(+), 62 deletions(-) diff --git a/src/comfystream/client_multi.py b/src/comfystream/client_multi.py index 3b5fa6c8..a0e78339 100644 --- a/src/comfystream/client_multi.py +++ b/src/comfystream/client_multi.py @@ -50,12 +50,11 @@ def __init__(self, max_workers: int = 1, executor_type: str = "process", **kwarg manager = ctx.Manager() logger.info("[ComfyStreamClient] Created multiprocessing context and manager") - - # Create queues with a reasonable size limit to prevent memory issues - self.image_inputs = manager.Queue(maxsize=30) - self.image_outputs = manager.Queue(maxsize=30) - self.audio_inputs = manager.Queue(maxsize=10) - self.audio_outputs = manager.Queue(maxsize=10) + + self.image_inputs = manager.Queue(maxsize=50) + self.image_outputs = manager.Queue(maxsize=50) + self.audio_inputs = manager.Queue(maxsize=50) + self.audio_outputs = manager.Queue(maxsize=50) logger.info("[ComfyStreamClient] Created manager queues") logger.info("[ComfyStreamClient] About to create ProcessPoolExecutor...") @@ -144,58 +143,25 @@ async def distribute_frames(self): self.running_prompts[worker_id] = new_task async def worker_loop(self, worker_id: int): - """Worker process that handles frames""" - logger.info(f"[ComfyStreamClient] Worker {worker_id} started - PID: {os.getpid()}") + """Worker process that continuously processes prompts""" + logger.info(f"[Worker {worker_id}] Started - PID: {os.getpid()}") + + # Get prompt for this worker + prompt_index = worker_id % len(self.current_prompts) + prompt = self.current_prompts[prompt_index] + frame_count = 0 while True: - # Check if this worker should process the next frame - async with self.distribution_lock: - should_process = (self.next_worker == worker_id) - if should_process: - # Move to next worker for round-robin distribution - self.next_worker = (self.next_worker + 1) % self.max_workers - - if should_process: - # Get prompt - 
cycle through available prompts - prompt_index = worker_id % len(self.current_prompts) - prompt = self.current_prompts[prompt_index] - - try: - logger.debug(f"[ComfyStreamClient] Worker {worker_id} processing frame with prompt {prompt_index} - PID: {os.getpid()}") - # Process a single frame - await self.process_frame_with_prompt(prompt) - except Exception as e: - logger.error(f"[ComfyStreamClient] Error in worker {worker_id}: {str(e)}") - await asyncio.sleep(0.1) # Avoid tight error loop - else: - # Small wait to avoid busy waiting - await asyncio.sleep(0.01) - - async def process_frame_with_prompt(self, prompt): - """Process a single frame with the given prompt""" - try: - # Get frame from input queue if available - if not self.image_inputs.empty(): - # Non-blocking queue check - frame = None - try: - frame = self.image_inputs.get_nowait() - except Exception: - pass - - if frame is not None: - # Queue the prompt to process this frame - await self.comfy_client.queue_prompt(prompt) - - # Assuming result will eventually appear in output queue - # You may need logic to match inputs with outputs - - else: - # No frames to process, small wait - await asyncio.sleep(0.01) - - except Exception as e: - logger.error(f"[ComfyStreamClient] Error processing frame: {str(e)}") + try: + logger.debug(f"[Worker {worker_id}] Starting prompt execution {frame_count}") + # Continuously execute the prompt + # The LoadTensor node will block until a frame is available + await self.comfy_client.queue_prompt(prompt) + frame_count += 1 + logger.info(f"[Worker {worker_id}] Completed prompt execution {frame_count}") + except Exception as e: + logger.error(f"[Worker {worker_id}] Error on frame {frame_count}: {str(e)}") + await asyncio.sleep(0.1) async def cleanup(self): async with self.cleanup_lock: @@ -228,21 +194,21 @@ def put_video_input(self, frame): # Check if frame is FrameProxy if isinstance(frame, FrameProxy): proxy = frame - # Otherwise create a proxy (assuming frame is av.VideoFrame as 
in pipeline.py) else: proxy = FrameProxy.avframe_to_frameproxy(frame) # Handle queue being full if self.image_inputs.full(): + # logger.warning(f"[ComfyStreamClient] Input queue full, dropping oldest frame") try: self.image_inputs.get_nowait() except Exception: pass self.image_inputs.put_nowait(proxy) - logger.debug(f"[ComfyStreamClient] Video input queued.") + # logger.info(f"[ComfyStreamClient] Video input queued. Queue size: {self.image_inputs.qsize()}") except Exception as e: - logger.info(f"[ComfyStreamClient] Error putting video frame: {str(e)}") + logger.error(f"[ComfyStreamClient] Error putting video frame: {str(e)}") def put_audio_input(self, frame): self.audio_inputs.put(frame) @@ -254,14 +220,13 @@ async def get_video_output(self): asyncio.get_event_loop().run_in_executor(None, self.image_outputs.get), timeout=5.0 ) - logger.debug(f"[ComfyStreamClient] get_video_output returning tensor - PID: {os.getpid()}") - # No need for permutation here - tensor should already be in the right format + logger.info(f"[ComfyStreamClient] get_video_output returning tensor: {tensor.shape} - PID: {os.getpid()}") return tensor except asyncio.TimeoutError: logger.warning(f"[ComfyStreamClient] get_video_output timeout - PID: {os.getpid()}") return torch.zeros((1, 3, self.height, self.width), dtype=torch.float32) except Exception as e: - logger.info(f"[ComfyStreamClient] Error getting video output: {str(e)} - PID: {os.getpid()}") + logger.error(f"[ComfyStreamClient] Error getting video output: {str(e)} - PID: {os.getpid()}") return torch.zeros((1, 3, self.height, self.width), dtype=torch.float32) async def get_audio_output(self): From 5bdc1dfeb0b11e2ace4ba265248424080652f02d Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Tue, 27 May 2025 12:57:59 -0400 Subject: [PATCH 32/42] Better terminal close handling, fixes to kwargs sent to client. 
--- server/app_multi.py | 85 ++++++++++++++++++++++++++++--- src/comfystream/client_multi.py | 7 ++- src/comfystream/pipeline_multi.py | 19 +++++-- 3 files changed, 96 insertions(+), 15 deletions(-) diff --git a/server/app_multi.py b/server/app_multi.py index 20d60a60..4f8557ec 100644 --- a/server/app_multi.py +++ b/server/app_multi.py @@ -5,6 +5,7 @@ import os import sys import torch +import signal # Initialize CUDA before any other imports to prevent core dump. if torch.cuda.is_available(): @@ -31,6 +32,8 @@ logging.getLogger("aiortc.rtcrtpsender").setLevel(logging.WARNING) logging.getLogger("aiortc.rtcrtpreceiver").setLevel(logging.WARNING) +# Global variable to track the app for signal handling +app_instance = None MAX_BITRATE = 2000000 MIN_BITRATE = 2000000 @@ -371,24 +374,74 @@ async def on_startup(app: web.Application): app["pipeline"] = Pipeline( width=512, height=512, + max_workers=app["workers"], + comfyui_inference_log_level=app.get("comfui_inference_log_level", None), + frame_log_file=app.get("frame_log_file", None), cwd=app["workspace"], disable_cuda_malloc=True, gpu_only=True, preview_method='none', - max_workers=app["workers"], - comfyui_inference_log_level=app.get("comfui_inference_log_level", None), - frame_log_file=app.get("frame_log_file", None), - base_directory=app["workspace"] ) app["pcs"] = set() app["video_tracks"] = {} async def on_shutdown(app: web.Application): + logger.info("Starting server shutdown...") + + # Clean up pipeline first + if "pipeline" in app: + try: + await app["pipeline"].cleanup() + logger.info("Pipeline cleanup completed") + except Exception as e: + logger.error(f"Error during pipeline cleanup: {e}") + + # Clean up peer connections pcs = app["pcs"] - coros = [pc.close() for pc in pcs] - await asyncio.gather(*coros) - pcs.clear() + if pcs: + logger.info(f"Closing {len(pcs)} peer connections...") + coros = [pc.close() for pc in pcs] + await asyncio.gather(*coros, return_exceptions=True) + pcs.clear() + logger.info("Peer 
connections closed") + + logger.info("Server shutdown completed") + + +def signal_handler(signum, frame): + """Handle SIGINT (Ctrl+C) gracefully""" + logger.info(f"Received signal {signum}, initiating graceful shutdown...") + + if app_instance: + # Get the current event loop + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + # Schedule the shutdown coroutine + asyncio.create_task(shutdown_server()) + else: + # If no loop is running, run the shutdown directly + asyncio.run(shutdown_server()) + except RuntimeError: + # If we can't get the loop, force exit + logger.warning("Could not get event loop, forcing exit") + os._exit(1) + else: + logger.warning("No app instance found, forcing exit") + os._exit(1) + +async def shutdown_server(): + """Perform graceful server shutdown""" + try: + if app_instance: + await on_shutdown(app_instance) + logger.info("Graceful shutdown completed") + except Exception as e: + logger.error(f"Error during graceful shutdown: {e}") + finally: + # Force exit after cleanup attempt + os._exit(0) if __name__ == "__main__": @@ -445,6 +498,11 @@ async def on_shutdown(app: web.Application): ) args = parser.parse_args() + # Set up signal handlers + signal.signal(signal.SIGINT, signal_handler) + if hasattr(signal, 'SIGTERM'): + signal.signal(signal.SIGTERM, signal_handler) + logging.basicConfig( level=args.log_level.upper(), format="%(asctime)s [%(levelname)s] %(message)s", @@ -452,6 +510,7 @@ async def on_shutdown(app: web.Application): ) app = web.Application() + app_instance = app # Store reference for signal handler app["media_ports"] = args.media_ports.split(",") if args.media_ports else None app["workspace"] = args.workspace app["frame_log_file"] = args.frame_log_file @@ -501,4 +560,14 @@ def force_print(*args, **kwargs): if args.comfyui_inference_log_level: app["comfui_inference_log_level"] = args.comfyui_inference_log_level - web.run_app(app, host=args.host, port=int(args.port), print=force_print) + try: + web.run_app(app, 
host=args.host, port=int(args.port), print=force_print) + except KeyboardInterrupt: + logger.info("Received KeyboardInterrupt, shutting down...") + finally: + # Ensure cleanup happens even if web.run_app doesn't call on_shutdown + if app_instance: + try: + asyncio.run(on_shutdown(app_instance)) + except Exception as e: + logger.error(f"Error in final cleanup: {e}") diff --git a/src/comfystream/client_multi.py b/src/comfystream/client_multi.py index a0e78339..4dbb895e 100644 --- a/src/comfystream/client_multi.py +++ b/src/comfystream/client_multi.py @@ -24,7 +24,10 @@ def _test_worker_init(): return os.getpid() class ComfyStreamClient: - def __init__(self, max_workers: int = 1, executor_type: str = "process", **kwargs): + def __init__(self, + max_workers: int = 1, + executor_type: str = "process", + **kwargs): logger.info(f"[ComfyStreamClient] Main Process ID: {os.getpid()}") logger.info(f"[ComfyStreamClient] __init__ start, max_workers: {max_workers}, executor_type: {executor_type}") @@ -37,7 +40,7 @@ def __init__(self, max_workers: int = 1, executor_type: str = "process", **kwarg kwargs['cwd'] = os.path.abspath(kwargs['cwd']) logger.info(f"[ComfyStreamClient] Converted workspace path to absolute: {kwargs['cwd']}") - logger.info("[ComfyStreamClient] Config kwargs:", kwargs) + logger.info("[ComfyStreamClient] Config kwargs: %s", kwargs) try: self.config = Configuration(**kwargs) diff --git a/src/comfystream/pipeline_multi.py b/src/comfystream/pipeline_multi.py index cc48ae27..4b7c6516 100644 --- a/src/comfystream/pipeline_multi.py +++ b/src/comfystream/pipeline_multi.py @@ -25,18 +25,27 @@ class Pipeline: postprocessing, and queue management. 
""" - def __init__(self, width: int = 512, height: int = 512, - comfyui_inference_log_level: Optional[int] = None, frame_log_file: Optional[str] = None, **kwargs): + def __init__(self, + width: int = 512, + height: int = 512, + max_workers: int = 1, + comfyui_inference_log_level: Optional[int] = None, + frame_log_file: Optional[str] = None, + **kwargs): """Initialize the pipeline with the given configuration. Args: width: Width of the video frames (default: 512) height: Height of the video frames (default: 512) + max_workers: Number of worker processes (default: 1) comfyui_inference_log_level: The logging level for ComfyUI inference. - Defaults to None, using the global ComfyUI log level. - **kwargs: Additional arguments to pass to the ComfyStreamClient + frame_log_file: Path to frame timing log file + **kwargs: Additional arguments to pass to the ComfyStreamClient (cwd, disable_cuda_malloc, etc.) """ - self.client = ComfyStreamClient(**kwargs) + self.client = ComfyStreamClient( + max_workers=max_workers, + executor_type="process", + **kwargs) self.width = width self.height = height From 63cebb085944538823da06e721e370baa82d5d7d Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Wed, 28 May 2025 12:34:58 -0400 Subject: [PATCH 33/42] Fixed issue with pipeline reset on workflow change or UI refresh, removed extreanous executor type param, commented out some logging. 
--- server/app_multi.py | 32 ++- src/comfystream/client_multi.py | 283 ++++++++++++++++---------- src/comfystream/pipeline_multi.py | 13 +- src/comfystream/tensor_cache_multi.py | 8 +- 4 files changed, 222 insertions(+), 114 deletions(-) diff --git a/server/app_multi.py b/server/app_multi.py index 4f8557ec..f4d6f82e 100644 --- a/server/app_multi.py +++ b/server/app_multi.py @@ -389,7 +389,7 @@ async def on_startup(app: web.Application): async def on_shutdown(app: web.Application): logger.info("Starting server shutdown...") - # Clean up pipeline first + # Clean up pipeline first - this should terminate worker processes if "pipeline" in app: try: await app["pipeline"].cleanup() @@ -424,13 +424,37 @@ def signal_handler(signum, frame): # If no loop is running, run the shutdown directly asyncio.run(shutdown_server()) except RuntimeError: - # If we can't get the loop, force exit - logger.warning("Could not get event loop, forcing exit") - os._exit(1) + # If we can't get the loop, force cleanup and exit + logger.warning("Could not get event loop, forcing cleanup and exit") + force_cleanup_and_exit() else: logger.warning("No app instance found, forcing exit") os._exit(1) +def force_cleanup_and_exit(): + """Force cleanup of processes and exit""" + try: + # Try to cleanup the pipeline synchronously + if app_instance and "pipeline" in app_instance: + pipeline = app_instance["pipeline"] + if hasattr(pipeline, 'client') and hasattr(pipeline.client, 'comfy_client'): + # Force shutdown of the executor + if hasattr(pipeline.client.comfy_client, 'executor'): + executor = pipeline.client.comfy_client.executor + if hasattr(executor, 'shutdown'): + logger.info("Force shutting down executor...") + executor.shutdown(wait=False) + if hasattr(executor, '_processes'): + # Terminate all worker processes + for process in executor._processes: + if process.is_alive(): + logger.info(f"Terminating worker process {process.pid}") + process.terminate() + except Exception as e: + 
logger.error(f"Error in force cleanup: {e}") + finally: + os._exit(1) + async def shutdown_server(): """Perform graceful server shutdown""" try: diff --git a/src/comfystream/client_multi.py b/src/comfystream/client_multi.py index 4dbb895e..5b6a591c 100644 --- a/src/comfystream/client_multi.py +++ b/src/comfystream/client_multi.py @@ -12,7 +12,7 @@ from comfystream.tensor_cache_multi import init_tensor_cache from comfy.cli_args_types import Configuration -from comfy.distributed.executors import ProcessPoolExecutor # Use ComfyUI's executor +from comfy.distributed.executors import ProcessPoolExecutor # Use ComfyUI's executor wrapper from comfy.api.components.schema.prompt import PromptDictInput from comfy.client.embedded_comfy_client import EmbeddedComfyClient from comfystream.frame_proxy import FrameProxy @@ -26,10 +26,9 @@ def _test_worker_init(): class ComfyStreamClient: def __init__(self, max_workers: int = 1, - executor_type: str = "process", **kwargs): logger.info(f"[ComfyStreamClient] Main Process ID: {os.getpid()}") - logger.info(f"[ComfyStreamClient] __init__ start, max_workers: {max_workers}, executor_type: {executor_type}") + logger.info(f"[ComfyStreamClient] __init__ start, max_workers: {max_workers}") # Store default dimensions self.width = kwargs.get('width', 512) @@ -46,104 +45,112 @@ def __init__(self, self.config = Configuration(**kwargs) print("[ComfyStreamClient] Configuration created") - if executor_type == "process": - logger.info("[ComfyStreamClient] Initializing process executor") - ctx = mp.get_context("spawn") - logger.info(f"[ComfyStreamClient] Using multiprocessing context: {ctx.get_start_method()}") - - manager = ctx.Manager() - logger.info("[ComfyStreamClient] Created multiprocessing context and manager") - - self.image_inputs = manager.Queue(maxsize=50) - self.image_outputs = manager.Queue(maxsize=50) - self.audio_inputs = manager.Queue(maxsize=50) - self.audio_outputs = manager.Queue(maxsize=50) - logger.info("[ComfyStreamClient] Created 
manager queues") - - logger.info("[ComfyStreamClient] About to create ProcessPoolExecutor...") - try: - # Create executor first - executor = ProcessPoolExecutor( - max_workers=max_workers, - initializer=init_tensor_cache, - initargs=(self.image_inputs, self.image_outputs, self.audio_inputs, self.audio_outputs) - ) - logger.info("[ComfyStreamClient] ProcessPoolExecutor created successfully") - - # Create EmbeddedComfyClient with the executor - logger.info("[ComfyStreamClient] Creating EmbeddedComfyClient with executor") - self.comfy_client = EmbeddedComfyClient(self.config, executor=executor) - logger.info("[ComfyStreamClient] EmbeddedComfyClient created successfully") - - # Submit a test task to ensure worker processes are initialized - logger.info("[ComfyStreamClient] Testing worker process initialization...") - test_future = executor.submit(_test_worker_init) # Use the named function instead of lambda - try: - worker_pid = test_future.result(timeout=30) # 30 second timeout - logger.info(f"[ComfyStreamClient] Worker process initialized successfully (PID: {worker_pid})") - except Exception as e: - logger.info(f"[ComfyStreamClient] Error initializing worker process: {str(e)}") - raise - - except Exception as e: - logger.info(f"[ComfyStreamClient] Error during initialization: {str(e)}") - logger.info(f"[ComfyStreamClient] Error type: {type(e)}") - import traceback - logger.info(f"[ComfyStreamClient] Error traceback: {traceback.format_exc()}") - raise - - else: - logger.info("[ComfyStreamClient] Using default executor") - logger.info("[ComfyStreamClient] Creating EmbeddedComfyClient in main process") - self.comfy_client = EmbeddedComfyClient(self.config) - logger.info("[ComfyStreamClient] EmbeddedComfyClient created in main process") - - self.running_prompts = {} - self.current_prompts = [] - self.cleanup_lock = asyncio.Lock() - self.max_workers = max_workers - self.worker_tasks = [] - self.next_worker = 0 - self.distribution_lock = asyncio.Lock() - 
logger.info("[ComfyStreamClient] __init__ complete") - + logger.info("[ComfyStreamClient] Initializing process executor") + ctx = mp.get_context("spawn") + logger.info(f"[ComfyStreamClient] Using multiprocessing context: {ctx.get_start_method()}") + + manager = ctx.Manager() + logger.info("[ComfyStreamClient] Created multiprocessing context and manager") + + self.image_inputs = manager.Queue(maxsize=50) + self.image_outputs = manager.Queue(maxsize=50) + self.audio_inputs = manager.Queue(maxsize=50) + self.audio_outputs = manager.Queue(maxsize=50) + logger.info("[ComfyStreamClient] Created manager queues") + + logger.info("[ComfyStreamClient] About to create ProcessPoolExecutor...") + + executor = ProcessPoolExecutor( + max_workers=max_workers, + initializer=init_tensor_cache, + initargs=(self.image_inputs, self.image_outputs, self.audio_inputs, self.audio_outputs) + ) + logger.info("[ComfyStreamClient] ProcessPoolExecutor created successfully") + + # Create EmbeddedComfyClient with the executor + logger.info("[ComfyStreamClient] Creating EmbeddedComfyClient with executor") + self.comfy_client = EmbeddedComfyClient(self.config, executor=executor) + logger.info("[ComfyStreamClient] EmbeddedComfyClient created successfully") + + # Submit a test task to ensure worker processes are initialized + logger.info("[ComfyStreamClient] Testing worker process initialization...") + test_future = executor.submit(_test_worker_init) + try: + worker_pid = test_future.result(timeout=30) # 30 second timeout + logger.info(f"[ComfyStreamClient] Worker process initialized successfully (PID: {worker_pid})") + except Exception as e: + logger.info(f"[ComfyStreamClient] Error initializing worker process: {str(e)}") + raise + except Exception as e: logger.info(f"[ComfyStreamClient] Error during initialization: {str(e)}") logger.info(f"[ComfyStreamClient] Error type: {type(e)}") import traceback logger.info(f"[ComfyStreamClient] Error traceback: {traceback.format_exc()}") raise + + 
self.running_prompts = {} + self.current_prompts = [] + self.cleanup_lock = asyncio.Lock() + self.max_workers = max_workers + self.worker_tasks = [] + self.next_worker = 0 + self.distribution_lock = asyncio.Lock() + self.shutting_down = False # Add shutdown flag + self.distribution_task = None # Track distribution task + logger.info("[ComfyStreamClient] Initialized successfully") async def set_prompts(self, prompts: List[PromptDictInput]): logger.info("set_prompts start") self.current_prompts = [convert_prompt(prompt) for prompt in prompts] - # Start the distribution manager - distribution_task = asyncio.create_task(self.distribute_frames()) - self.running_prompts[-1] = distribution_task # Use -1 as a special key for the manager + # Start the distribution manager only if not already running + if self.distribution_task is None or self.distribution_task.done(): + self.shutting_down = False # Reset shutdown flag + self.distribution_task = asyncio.create_task(self.distribute_frames()) + self.running_prompts[-1] = self.distribution_task # Use -1 as a special key for the manager logger.info("set_prompts end") async def distribute_frames(self): """Manager that distributes frames across workers in round-robin fashion""" logger.info(f"[ComfyStreamClient] Starting frame distribution manager") - # Initialize worker tasks - self.worker_tasks = [] - for worker_id in range(self.max_workers): - worker_task = asyncio.create_task(self.worker_loop(worker_id)) - self.worker_tasks.append(worker_task) - self.running_prompts[worker_id] = worker_task - - # Keep the manager running to monitor workers - while True: - await asyncio.sleep(1.0) # Check periodically - # Restart any crashed workers - for worker_id, task in enumerate(self.worker_tasks): - if task.done(): - logger.warning(f"Worker {worker_id} crashed, restarting") - new_task = asyncio.create_task(self.worker_loop(worker_id)) - self.worker_tasks[worker_id] = new_task - self.running_prompts[worker_id] = new_task + try: + # 
Initialize worker tasks + self.worker_tasks = [] + for worker_id in range(self.max_workers): + worker_task = asyncio.create_task(self.worker_loop(worker_id)) + self.worker_tasks.append(worker_task) + self.running_prompts[worker_id] = worker_task + + # Keep the manager running to monitor workers + while not self.shutting_down: + await asyncio.sleep(1.0) # Check periodically + + # Only restart crashed workers if we're not shutting down + if not self.shutting_down: + for worker_id, task in enumerate(self.worker_tasks): + if task.done(): + # Check if the task was cancelled (graceful shutdown) or crashed + if task.cancelled(): + logger.info(f"Worker {worker_id} was cancelled (graceful shutdown)") + else: + # Check if there was an exception + try: + task.result() + logger.info(f"Worker {worker_id} completed normally") + except Exception as e: + logger.warning(f"Worker {worker_id} crashed with error: {e}, restarting") + new_task = asyncio.create_task(self.worker_loop(worker_id)) + self.worker_tasks[worker_id] = new_task + self.running_prompts[worker_id] = new_task + + except asyncio.CancelledError: + logger.info("[ComfyStreamClient] Distribution manager cancelled") + except Exception as e: + logger.error(f"[ComfyStreamClient] Error in distribution manager: {e}") + finally: + logger.info("[ComfyStreamClient] Distribution manager stopped") async def worker_loop(self, worker_id: int): """Worker process that continuously processes prompts""" @@ -154,35 +161,107 @@ async def worker_loop(self, worker_id: int): prompt = self.current_prompts[prompt_index] frame_count = 0 - while True: - try: - logger.debug(f"[Worker {worker_id}] Starting prompt execution {frame_count}") - # Continuously execute the prompt - # The LoadTensor node will block until a frame is available - await self.comfy_client.queue_prompt(prompt) - frame_count += 1 - logger.info(f"[Worker {worker_id}] Completed prompt execution {frame_count}") - except Exception as e: - logger.error(f"[Worker {worker_id}] Error on 
frame {frame_count}: {str(e)}") - await asyncio.sleep(0.1) + try: + while not self.shutting_down: + try: + # Check if we should stop before processing + if self.shutting_down: + break + + # Continuously execute the prompt + # The LoadTensor node will block until a frame is available + await self.comfy_client.queue_prompt(prompt) + frame_count += 1 + except asyncio.CancelledError: + logger.info(f"[Worker {worker_id}] Cancelled after {frame_count} frames") + break + except Exception as e: + if self.shutting_down: + logger.info(f"[Worker {worker_id}] Stopping due to shutdown") + break + logger.error(f"[Worker {worker_id}] Error on frame {frame_count}: {str(e)}") + await asyncio.sleep(0.1) + except asyncio.CancelledError: + logger.info(f"[Worker {worker_id}] Task cancelled") + finally: + logger.info(f"[Worker {worker_id}] Stopped after processing {frame_count} frames") async def cleanup(self): async with self.cleanup_lock: - for task in self.worker_tasks: - task.cancel() + logger.info("[ComfyStreamClient] Starting cleanup...") + + # Set shutdown flag to stop workers gracefully + self.shutting_down = True + + # Cancel distribution task first + if self.distribution_task and not self.distribution_task.done(): + self.distribution_task.cancel() try: - await task + await self.distribution_task except asyncio.CancelledError: pass - if self.comfy_client.is_running: + # Cancel all worker tasks + for task in self.worker_tasks: + if not task.done(): + task.cancel() + + # Wait for tasks to complete cancellation + if self.worker_tasks: try: - await self.comfy_client.__aexit__() + await asyncio.gather(*self.worker_tasks, return_exceptions=True) + logger.info("[ComfyStreamClient] All worker tasks stopped") + except Exception as e: + logger.error(f"Error waiting for worker tasks: {e}") + + # Clear the tasks list + self.worker_tasks.clear() + self.running_prompts.clear() + + # Cleanup the ComfyUI client and its executor + if hasattr(self, 'comfy_client') and 
self.comfy_client.is_running: + try: + # Get the executor before closing the client + executor = getattr(self.comfy_client, 'executor', None) + + # Close the client first + await self.comfy_client.__aexit__(None, None, None) + logger.info("[ComfyStreamClient] ComfyUI client closed") + + # Then shutdown the executor and terminate processes + if executor: + logger.info("[ComfyStreamClient] Shutting down executor...") + # Shutdown the executor + executor.shutdown(wait=False) + + # Force terminate any remaining processes + if hasattr(executor, '_processes'): + for process in executor._processes: + if process.is_alive(): + logger.info(f"[ComfyStreamClient] Terminating worker process {process.pid}") + process.terminate() + # Give it a moment to terminate gracefully + try: + process.join(timeout=2.0) + except: + pass + # Force kill if still alive + if process.is_alive(): + logger.warning(f"[ComfyStreamClient] Force killing worker process {process.pid}") + process.kill() + + logger.info("[ComfyStreamClient] Executor shutdown completed") + except Exception as e: logger.error(f"Error during ComfyClient cleanup: {e}") await self.cleanup_queues() - logger.info("Client cleanup complete") + + # Reset state for potential reuse + self.shutting_down = False + self.distribution_task = None + + logger.info("[ComfyStreamClient] Client cleanup complete") async def cleanup_queues(self): # TODO: add for audio as well @@ -223,7 +302,7 @@ async def get_video_output(self): asyncio.get_event_loop().run_in_executor(None, self.image_outputs.get), timeout=5.0 ) - logger.info(f"[ComfyStreamClient] get_video_output returning tensor: {tensor.shape} - PID: {os.getpid()}") + # logger.info(f"[ComfyStreamClient] get_video_output returning tensor: {tensor.shape} - PID: {os.getpid()}") return tensor except asyncio.TimeoutError: logger.warning(f"[ComfyStreamClient] get_video_output timeout - PID: {os.getpid()}") diff --git a/src/comfystream/pipeline_multi.py b/src/comfystream/pipeline_multi.py index 
4b7c6516..e526ef4b 100644 --- a/src/comfystream/pipeline_multi.py +++ b/src/comfystream/pipeline_multi.py @@ -44,7 +44,6 @@ def __init__(self, """ self.client = ComfyStreamClient( max_workers=max_workers, - executor_type="process", **kwargs) self.width = width self.height = height @@ -213,8 +212,7 @@ def audio_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.Audio # TODO: make it generic to support purely generative video cases async def get_processed_video_frame(self) -> av.VideoFrame: - logger.info(f"[PipelineMulti] get_processed_video_frame called - PID: {os.getpid()}") - logger.debug("[PipelineMulti] Waiting for processed video frame...") + # logger.info(f"[PipelineMulti] get_processed_video_frame called - PID: {os.getpid()}") frame_process_start_time = time.time() # Get the input frame first @@ -242,7 +240,7 @@ async def get_processed_video_frame(self) -> av.VideoFrame: 'csv_path': self.frame_log_file }) - logger.info(f"[PipelineMulti] get_processed_video_frame returning frame - PID: {os.getpid()}") + # logger.info(f"[PipelineMulti] get_processed_video_frame returning frame - PID: {os.getpid()}") return processed_frame async def get_processed_audio_frame(self) -> av.AudioFrame: @@ -277,6 +275,10 @@ async def get_nodes_info(self) -> Dict[str, Any]: async def cleanup(self): """Clean up resources used by the pipeline.""" + logger.info("[PipelineMulti] Starting pipeline cleanup...") + + # Set running flag to false to stop frame processing + self.running = False # Cancel frame logger task if it exists if hasattr(self, 'frame_logger_task') and self.frame_logger_task: @@ -286,7 +288,10 @@ async def cleanup(self): except asyncio.CancelledError: pass + # Clean up the client (this will gracefully shutdown workers) await self.client.cleanup() + + logger.info("[PipelineMulti] Pipeline cleanup complete") async def _process_frame_logs(self): """Background task to process frame logs from queue""" diff --git a/src/comfystream/tensor_cache_multi.py 
b/src/comfystream/tensor_cache_multi.py index 2d291822..22bd99a5 100644 --- a/src/comfystream/tensor_cache_multi.py +++ b/src/comfystream/tensor_cache_multi.py @@ -21,12 +21,12 @@ def __init__(self, mp_queue): def get(self, block=True, timeout=None): result = self.queue.get(block=block, timeout=timeout) - print(f"[MultiProcessInputQueue] Frame retrieved by worker PID: {os.getpid()}") + # print(f"[MultiProcessInputQueue] Frame retrieved by worker PID: {os.getpid()}") return result def get_nowait(self): result = self.queue.get_nowait() - print(f"[MultiProcessInputQueue] Frame retrieved (nowait) by worker PID: {os.getpid()}") + # print(f"[MultiProcessInputQueue] Frame retrieved (nowait) by worker PID: {os.getpid()}") return result def put(self, item, block=True, timeout=None): @@ -56,7 +56,7 @@ async def put(self, item): # Ensure tensor is on CPU before sending if torch.is_tensor(item): item = item.cpu() - print(f"[MultiProcessOutputQueue] Frame sent from worker PID: {os.getpid()}") + # print(f"[MultiProcessOutputQueue] Frame sent from worker PID: {os.getpid()}") return await loop.run_in_executor(None, self.queue.put, item) def put_nowait(self, item): @@ -64,7 +64,7 @@ def put_nowait(self, item): # Ensure tensor is on CPU before sending if torch.is_tensor(item): item = item.cpu() - print(f"[MultiProcessOutputQueue] Frame sent (nowait) from worker PID: {os.getpid()}") + # print(f"[MultiProcessOutputQueue] Frame sent (nowait) from worker PID: {os.getpid()}") self.queue.put_nowait(item) except queue.Full: try: From a642b05a32c175fcba2134ff1732b5bf4b97845c Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Wed, 28 May 2025 13:21:01 -0400 Subject: [PATCH 34/42] Fixes to logging, attempt at fixing root path issue for models. 
--- server/app_multi.py | 5 +++-- src/comfystream/client_multi.py | 17 +++++++++-------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/server/app_multi.py b/server/app_multi.py index f4d6f82e..0e3d28f5 100644 --- a/server/app_multi.py +++ b/server/app_multi.py @@ -375,7 +375,7 @@ async def on_startup(app: web.Application): width=512, height=512, max_workers=app["workers"], - comfyui_inference_log_level=app.get("comfui_inference_log_level", None), + comfyui_inference_log_level=app.get("comfyui_inference_log_level", None), frame_log_file=app.get("frame_log_file", None), cwd=app["workspace"], disable_cuda_malloc=True, @@ -582,7 +582,8 @@ def force_print(*args, **kwargs): log_level = logging._nameToLevel.get(args.comfyui_log_level.upper()) logging.getLogger("comfy").setLevel(log_level) if args.comfyui_inference_log_level: - app["comfui_inference_log_level"] = args.comfyui_inference_log_level + log_level = logging._nameToLevel.get(args.comfyui_inference_log_level.upper()) + app["comfyui_inference_log_level"] = log_level try: web.run_app(app, host=args.host, port=int(args.port), print=force_print) diff --git a/src/comfystream/client_multi.py b/src/comfystream/client_multi.py index 5b6a591c..625d088a 100644 --- a/src/comfystream/client_multi.py +++ b/src/comfystream/client_multi.py @@ -35,9 +35,11 @@ def __init__(self, self.height = kwargs.get('height', 512) # Ensure workspace path is absolute - if 'cwd' in kwargs and not os.path.isabs(kwargs['cwd']): - kwargs['cwd'] = os.path.abspath(kwargs['cwd']) - logger.info(f"[ComfyStreamClient] Converted workspace path to absolute: {kwargs['cwd']}") + if 'cwd' in kwargs: + if not os.path.isabs(kwargs['cwd']): + # Convert relative path to absolute path from current working directory + kwargs['cwd'] = os.path.abspath(kwargs['cwd']) + logger.info(f"[ComfyStreamClient] Using absolute workspace path: {kwargs['cwd']}") logger.info("[ComfyStreamClient] Config kwargs: %s", kwargs) @@ -451,10 +453,10 @@ def 
execute_prompt_in_worker(config_dict, prompt): from comfy.cli_args_types import Configuration from comfy.client.embedded_comfy_client import EmbeddedComfyClient - # On Windows, we need to ensure the working directory is correct - if sys.platform == 'win32': - # Get the workspace directory from config - workspace = config_dict.get('cwd', '..\\..') + # Ensure the working directory is correct for all platforms + workspace = config_dict.get('cwd') + if workspace: + # The workspace should already be an absolute path from the main process logger.info(f"[execute_prompt_in_worker] Setting working directory to: {workspace}") os.chdir(workspace) @@ -464,7 +466,6 @@ def execute_prompt_in_worker(config_dict, prompt): logger.info(f"[execute_prompt_in_worker] Added {workspace} to Python path") logger.info(f"[execute_prompt_in_worker] Current working directory: {os.getcwd()}") - logger.info(f"[execute_prompt_in_worker] Python path: {sys.path}") # Create a new client in the worker process logger.info("[execute_prompt_in_worker] Creating configuration") From cee5445c560bb4e53af3fa1cffe607285f95e4d4 Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Wed, 28 May 2025 22:20:44 -0400 Subject: [PATCH 35/42] Attempts to fix tensorrt directory issue, logging and testing development. 
--- src/comfystream/client_multi.py | 107 +++++++++---------- src/comfystream/tensor_cache_multi.py | 143 ++++++++++++++++++++++++-- 2 files changed, 179 insertions(+), 71 deletions(-) diff --git a/src/comfystream/client_multi.py b/src/comfystream/client_multi.py index 625d088a..e30fc70e 100644 --- a/src/comfystream/client_multi.py +++ b/src/comfystream/client_multi.py @@ -12,7 +12,7 @@ from comfystream.tensor_cache_multi import init_tensor_cache from comfy.cli_args_types import Configuration -from comfy.distributed.executors import ProcessPoolExecutor # Use ComfyUI's executor wrapper +from comfy.distributed.process_pool_executor import ProcessPoolExecutor from comfy.api.components.schema.prompt import PromptDictInput from comfy.client.embedded_comfy_client import EmbeddedComfyClient from comfystream.frame_proxy import FrameProxy @@ -33,7 +33,7 @@ def __init__(self, # Store default dimensions self.width = kwargs.get('width', 512) self.height = kwargs.get('height', 512) - + # Ensure workspace path is absolute if 'cwd' in kwargs: if not os.path.isabs(kwargs['cwd']): @@ -41,12 +41,16 @@ def __init__(self, kwargs['cwd'] = os.path.abspath(kwargs['cwd']) logger.info(f"[ComfyStreamClient] Using absolute workspace path: {kwargs['cwd']}") + # Register TensorRT paths in main process BEFORE creating ComfyUI client + self._register_tensorrt_paths_main_process(kwargs.get('cwd')) + logger.info("[ComfyStreamClient] Config kwargs: %s", kwargs) try: - self.config = Configuration(**kwargs) - print("[ComfyStreamClient] Configuration created") - + self.config = Configuration(**kwargs) + logger.info("[ComfyStreamClient] Configuration created") + logger.info(f"[ComfyStreamClient] Current working directory: {os.getcwd()}") + logger.info("[ComfyStreamClient] Initializing process executor") ctx = mp.get_context("spawn") logger.info(f"[ComfyStreamClient] Using multiprocessing context: {ctx.get_start_method()}") @@ -65,7 +69,7 @@ def __init__(self, executor = ProcessPoolExecutor( 
max_workers=max_workers, initializer=init_tensor_cache, - initargs=(self.image_inputs, self.image_outputs, self.audio_inputs, self.audio_outputs) + initargs=(self.image_inputs, self.image_outputs, self.audio_inputs, self.audio_outputs, kwargs.get('cwd')) ) logger.info("[ComfyStreamClient] ProcessPoolExecutor created successfully") @@ -443,58 +447,41 @@ async def get_available_nodes(self): logger.error(f"Error getting node info: {str(e)}") return {} -def execute_prompt_in_worker(config_dict, prompt): - """Execute a prompt in the worker process""" - logger.info(f"[execute_prompt_in_worker] Starting in process {os.getpid()}") - try: - import os - import sys - import torch - from comfy.cli_args_types import Configuration - from comfy.client.embedded_comfy_client import EmbeddedComfyClient - - # Ensure the working directory is correct for all platforms - workspace = config_dict.get('cwd') - if workspace: - # The workspace should already be an absolute path from the main process - logger.info(f"[execute_prompt_in_worker] Setting working directory to: {workspace}") - os.chdir(workspace) - - # Ensure Python path includes the workspace - if workspace not in sys.path: - sys.path.insert(0, workspace) - logger.info(f"[execute_prompt_in_worker] Added {workspace} to Python path") - - logger.info(f"[execute_prompt_in_worker] Current working directory: {os.getcwd()}") - - # Create a new client in the worker process - logger.info("[execute_prompt_in_worker] Creating configuration") - config = Configuration(**config_dict) - - logger.info("[execute_prompt_in_worker] Creating EmbeddedComfyClient") - # Try to initialize CUDA before creating the client - if torch.cuda.is_available(): - logger.info(f"[execute_prompt_in_worker] CUDA device count: {torch.cuda.device_count()}") - # Set the device explicitly - torch.cuda.set_device(0) - logger.info(f"[execute_prompt_in_worker] Set CUDA device to: {torch.cuda.current_device()}") - - client = EmbeddedComfyClient(config) - - # Execute the 
prompt - logger.info("[execute_prompt_in_worker] Setting up event loop") - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + def _register_tensorrt_paths_main_process(self, workspace_path): + """Register TensorRT paths in the main process for validation""" try: - logger.info("[execute_prompt_in_worker] Queueing prompt") - loop.run_until_complete(client.queue_prompt(prompt)) - logger.info("[execute_prompt_in_worker] Prompt queued successfully") - finally: - logger.info("[execute_prompt_in_worker] Closing event loop") - loop.close() - except Exception as e: - logger.info(f"[execute_prompt_in_worker] Error: {str(e)}") - logger.info(f"[execute_prompt_in_worker] Error type: {type(e)}") - import traceback - logger.info(f"[execute_prompt_in_worker] Error traceback: {traceback.format_exc()}") - raise \ No newline at end of file + from comfy.cmd import folder_paths + + if workspace_path: + base_dir = workspace_path + tensorrt_models_dir = os.path.join(base_dir, "models", "tensorrt") + tensorrt_outputs_dir = os.path.join(base_dir, "outputs", "tensorrt") + else: + tensorrt_models_dir = os.path.join(folder_paths.models_dir, "tensorrt") + tensorrt_outputs_dir = os.path.join(folder_paths.models_dir, "outputs", "tensorrt") + + logger.info(f"[ComfyStreamClient] Registering TensorRT paths in main process") + logger.info(f"[ComfyStreamClient] TensorRT models dir: {tensorrt_models_dir}") + logger.info(f"[ComfyStreamClient] TensorRT outputs dir: {tensorrt_outputs_dir}") + + # Register TensorRT paths + if "tensorrt" in folder_paths.folder_names_and_paths: + existing_paths = folder_paths.folder_names_and_paths["tensorrt"][0] + for path in [tensorrt_models_dir, tensorrt_outputs_dir]: + if path not in existing_paths: + existing_paths.append(path) + folder_paths.folder_names_and_paths["tensorrt"][1].add(".engine") + else: + folder_paths.folder_names_and_paths["tensorrt"] = ( + [tensorrt_models_dir, tensorrt_outputs_dir], + {".engine"} + ) + + # Verify registration + 
available_files = folder_paths.get_filename_list("tensorrt") + logger.info(f"[ComfyStreamClient] Main process TensorRT files: {available_files}") + + except Exception as e: + logger.error(f"[ComfyStreamClient] Error registering TensorRT paths in main process: {e}") + import traceback + logger.error(f"[ComfyStreamClient] Traceback: {traceback.format_exc()}") \ No newline at end of file diff --git a/src/comfystream/tensor_cache_multi.py b/src/comfystream/tensor_cache_multi.py index 22bd99a5..32debf7b 100644 --- a/src/comfystream/tensor_cache_multi.py +++ b/src/comfystream/tensor_cache_multi.py @@ -1,12 +1,13 @@ # TODO: add better frame management, improve eviction policy fifo might not be the best, skip alternate frames instead # TODO: also make the tensor_cache solution backward compatible for when not using process pool -- after the multi process solution is stable from comfystream import tensor_cache +import logging import queue import torch import asyncio import os -from queue import Queue -from asyncio import Queue as AsyncQueue +from comfy.cmd import folder_paths +logger = logging.getLogger(__name__) image_inputs = None image_outputs = None @@ -21,12 +22,12 @@ def __init__(self, mp_queue): def get(self, block=True, timeout=None): result = self.queue.get(block=block, timeout=timeout) - # print(f"[MultiProcessInputQueue] Frame retrieved by worker PID: {os.getpid()}") + # logger.info(f"[MultiProcessInputQueue] Frame retrieved by worker PID: {os.getpid()}") return result def get_nowait(self): result = self.queue.get_nowait() - # print(f"[MultiProcessInputQueue] Frame retrieved (nowait) by worker PID: {os.getpid()}") + # logger.info(f"[MultiProcessInputQueue] Frame retrieved (nowait) by worker PID: {os.getpid()}") return result def put(self, item, block=True, timeout=None): @@ -56,7 +57,7 @@ async def put(self, item): # Ensure tensor is on CPU before sending if torch.is_tensor(item): item = item.cpu() - # print(f"[MultiProcessOutputQueue] Frame sent from worker PID: 
{os.getpid()}") + # logger.info(f"[MultiProcessOutputQueue] Frame sent from worker PID: {os.getpid()}") return await loop.run_in_executor(None, self.queue.put, item) def put_nowait(self, item): @@ -64,7 +65,7 @@ def put_nowait(self, item): # Ensure tensor is on CPU before sending if torch.is_tensor(item): item = item.cpu() - # print(f"[MultiProcessOutputQueue] Frame sent (nowait) from worker PID: {os.getpid()}") + # logger.info(f"[MultiProcessOutputQueue] Frame sent (nowait) from worker PID: {os.getpid()}") self.queue.put_nowait(item) except queue.Full: try: @@ -73,7 +74,7 @@ def put_nowait(self, item): except Exception: pass # If still full, drop this frame -def init_tensor_cache(image_inputs, image_outputs, audio_inputs, audio_outputs): +def init_tensor_cache(image_inputs, image_outputs, audio_inputs, audio_outputs, workspace_path=None): """Initialize the tensor cache for a worker process. Args: @@ -81,14 +82,134 @@ def init_tensor_cache(image_inputs, image_outputs, audio_inputs, audio_outputs): image_outputs: Multiprocessing Queue for output images audio_inputs: Multiprocessing Queue for input audio audio_outputs: Multiprocessing Queue for output audio + workspace_path: The ComfyUI workspace path (should be C:\sd\ComfyUI-main) """ - print(f"[init_tensor_cache] Setting up tensor_cache queues in worker - PID: {os.getpid()}") - + logger.info(f"[init_tensor_cache] Setting up tensor_cache queues in worker - PID: {os.getpid()}") + logger.info(f"[init_tensor_cache] Workspace path: {workspace_path}") + logger.info(f"[init_tensor_cache] Current working directory: {os.getcwd()}") + + # Initialize folder_paths in worker process + # Another attempt to fix the tensorrt paths issue via ProcessPoolExecutor + ''' + try: + # Import both possible folder_paths modules + from comfy.cmd import folder_paths as comfy_folder_paths + + # Also try to import the direct folder_paths (which TensorRT loader uses) + import sys + try: + import folder_paths as direct_folder_paths + 
logger.info("[init_tensor_cache] Successfully imported direct folder_paths") + except ImportError: + # If direct import fails, create an alias + sys.modules['folder_paths'] = comfy_folder_paths + direct_folder_paths = comfy_folder_paths + logger.info("[init_tensor_cache] Created folder_paths alias to comfy.cmd.folder_paths") + + logger.info(f"[init_tensor_cache] comfy_folder_paths.models_dir: {comfy_folder_paths.models_dir}") + logger.info(f"[init_tensor_cache] direct_folder_paths.models_dir: {direct_folder_paths.models_dir}") + + # Use the workspace_path as the base directory for TensorRT paths + if workspace_path: + base_dir = workspace_path + else: + # Fallback to the parent directory of models_dir + base_dir = os.path.dirname(comfy_folder_paths.models_dir) + + # Set up both models/tensorrt and outputs/tensorrt directories + tensorrt_models_dir = os.path.join(base_dir, "models", "tensorrt") + tensorrt_outputs_dir = os.path.join(base_dir, "outputs", "tensorrt") + + logger.info(f"[init_tensor_cache] TensorRT models folder: {tensorrt_models_dir}") + logger.info(f"[init_tensor_cache] TensorRT outputs folder: {tensorrt_outputs_dir}") + logger.info(f"[init_tensor_cache] Models dir exists: {os.path.exists(tensorrt_models_dir)}") + logger.info(f"[init_tensor_cache] Outputs dir exists: {os.path.exists(tensorrt_outputs_dir)}") + + # Register TensorRT paths in BOTH folder_paths modules + tensorrt_config = ([tensorrt_models_dir, tensorrt_outputs_dir], {".engine"}) + + # Update comfy.cmd.folder_paths + comfy_folder_paths.folder_names_and_paths["tensorrt"] = tensorrt_config + logger.info("[init_tensor_cache] Registered TensorRT paths in comfy.cmd.folder_paths") + + # Update direct folder_paths (which TensorRT loader uses) + direct_folder_paths.folder_names_and_paths["tensorrt"] = tensorrt_config + logger.info("[init_tensor_cache] Registered TensorRT paths in direct folder_paths") + + # Also update any existing modules in sys.modules + for module_name, module in 
sys.modules.items(): + if (module_name.endswith('folder_paths') or module_name == 'folder_paths') and hasattr(module, 'folder_names_and_paths'): + module.folder_names_and_paths["tensorrt"] = tensorrt_config + logger.info(f"[init_tensor_cache] Updated TensorRT paths in {module_name}") + + # Verify the registration worked + logger.info(f"[init_tensor_cache] comfy_folder_paths TensorRT files: {comfy_folder_paths.get_filename_list('tensorrt')}") + logger.info(f"[init_tensor_cache] direct_folder_paths TensorRT files: {direct_folder_paths.get_filename_list('tensorrt')}") + + except Exception as e: + logger.error(f"[init_tensor_cache] Error initializing folder_paths: {e}") + import traceback + logger.error(f"[init_tensor_cache] Traceback: {traceback.format_exc()}") + ''' + # Replace the queues with our wrapped versions that match the original interface tensor_cache.image_inputs = MultiProcessInputQueue(image_inputs) tensor_cache.image_outputs = MultiProcessOutputQueue(image_outputs) tensor_cache.audio_inputs = MultiProcessInputQueue(audio_inputs) tensor_cache.audio_outputs = MultiProcessOutputQueue(audio_outputs) - print(f"[init_tensor_cache] tensor_cache.image_outputs id: {id(tensor_cache.image_outputs)} - PID: {os.getpid()}") - print(f"[init_tensor_cache] Initialization complete - PID: {os.getpid()}") \ No newline at end of file + logger.info(f"[init_tensor_cache] tensor_cache.image_outputs id: {id(tensor_cache.image_outputs)} - PID: {os.getpid()}") + logger.info(f"[init_tensor_cache] Initialization complete - PID: {os.getpid()}") + + return os.getpid() # Return PID for verification + +# THis was an attempt to fix the tensorrt paths issue via ProcessPoolExecutor +''' +def register_tensorrt_paths(workspace_path=None): + """Register TensorRT paths in folder_paths at import time""" + try: + # Use workspace_path if provided, otherwise fall back to folder_paths.models_dir + if workspace_path: + base_dir = workspace_path + tensorrt_models_dir = os.path.join(base_dir, 
"models", "tensorrt") + else: + # Create tensorrt subdirectory in the models directory + tensorrt_models_dir = os.path.join(folder_paths.models_dir, "tensorrt") + + print(f"[TensorRT] workspace_path: {workspace_path}") + print(f"[TensorRT] folder_paths.models_dir: {folder_paths.models_dir}") + print(f"[TensorRT] Registering paths:") + print(f"[TensorRT] - Models: {tensorrt_models_dir}") + + if "tensorrt" in folder_paths.folder_names_and_paths: + # Update existing registration + existing_paths = folder_paths.folder_names_and_paths["tensorrt"][0] + if tensorrt_models_dir not in existing_paths: + existing_paths.append(tensorrt_models_dir) + folder_paths.folder_names_and_paths["tensorrt"][1].add(".engine") + else: + # Create new registration (same as Depth-Anything approach) + folder_paths.folder_names_and_paths["tensorrt"] = ( + [tensorrt_models_dir], + {".engine"} + ) + + # Verify registration + available_files = folder_paths.get_filename_list("tensorrt") + print(f"[TensorRT] Available engine files: {available_files}") + + except Exception as e: + print(f"[TensorRT] Error registering paths: {e}") + import traceback + traceback.print_exc() + # Fallback to original behavior + if "tensorrt" in folder_paths.folder_names_and_paths: + folder_paths.folder_names_and_paths["tensorrt"][0].append( + os.path.join(folder_paths.models_dir, "tensorrt")) + folder_paths.folder_names_and_paths["tensorrt"][1].add(".engine") + else: + folder_paths.folder_names_and_paths["tensorrt"] = ( + [os.path.join(folder_paths.models_dir, "tensorrt")], + {".engine"} + ) +''' \ No newline at end of file From 7754fbb07f8af296263487d299d97fa44fa9435e Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Thu, 5 Jun 2025 15:09:25 -0400 Subject: [PATCH 36/42] Refactored prompt updating from UI to worker processes, added and removed some logging. 
--- src/comfystream/client_multi.py | 238 ++++++++++++++++-------------- src/comfystream/pipeline_multi.py | 15 +- 2 files changed, 138 insertions(+), 115 deletions(-) diff --git a/src/comfystream/client_multi.py b/src/comfystream/client_multi.py index e30fc70e..6ec1f876 100644 --- a/src/comfystream/client_multi.py +++ b/src/comfystream/client_multi.py @@ -159,13 +159,9 @@ async def distribute_frames(self): logger.info("[ComfyStreamClient] Distribution manager stopped") async def worker_loop(self, worker_id: int): - """Worker process that continuously processes prompts""" + """Worker loop that continuously processes prompts and picks up updates""" logger.info(f"[Worker {worker_id}] Started - PID: {os.getpid()}") - # Get prompt for this worker - prompt_index = worker_id % len(self.current_prompts) - prompt = self.current_prompts[prompt_index] - frame_count = 0 try: while not self.shutting_down: @@ -173,11 +169,14 @@ async def worker_loop(self, worker_id: int): # Check if we should stop before processing if self.shutting_down: break - - # Continuously execute the prompt - # The LoadTensor node will block until a frame is available - await self.comfy_client.queue_prompt(prompt) + + prompt_index = worker_id % len(self.current_prompts) + current_prompt = self.current_prompts[prompt_index] + + # Execute the current prompt + await self.comfy_client.queue_prompt(current_prompt) frame_count += 1 + except asyncio.CancelledError: logger.info(f"[Worker {worker_id}] Cancelled after {frame_count} frames") break @@ -322,19 +321,30 @@ async def get_audio_output(self): return await loop.run_in_executor(None, self.audio_outputs.get) async def get_available_nodes(self): - """Get metadata and available nodes info in a single pass""" - # TODO: make it for for multiple prompts - if not self.running_prompts: + """Get metadata and available nodes info in the main process""" + logger.info("[ComfyStreamClient] get_available_nodes called") + + if not self.running_prompts or not 
self.current_prompts: + logger.info("[ComfyStreamClient] No running prompts or current prompts, returning empty") return {} try: + logger.info("[ComfyStreamClient] Starting node info gathering...") + + # Run node info gathering in the main process instead of worker process + # This avoids interfering with the ongoing video processing from comfy.nodes.package import import_all_nodes_in_workspace + logger.info("[ComfyStreamClient] Imported import_all_nodes_in_workspace") + nodes = import_all_nodes_in_workspace() - + logger.info(f"[ComfyStreamClient] Loaded {len(nodes.NODE_CLASS_MAPPINGS)} node classes") + all_prompts_nodes_info = {} for prompt_index, prompt in enumerate(self.current_prompts): - # Get set of class types we need metadata for, excluding LoadTensor and SaveTensor + logger.info(f"[ComfyStreamClient] Processing prompt {prompt_index}") + + # Get set of class types we need metadata for needed_class_types = { node.get('class_type') for node in prompt.values() @@ -344,107 +354,109 @@ async def get_available_nodes(self): for node_id, node in prompt.items() } nodes_info = {} + + logger.info(f"[ComfyStreamClient] Need metadata for {len(needed_class_types)} class types") - # Only process nodes until we've found all the ones we need + # Process nodes to get metadata for class_type, node_class in nodes.NODE_CLASS_MAPPINGS.items(): - if not remaining_nodes: # Exit early if we've found all needed nodes + if not remaining_nodes: break if class_type not in needed_class_types: continue - # Get metadata for this node type (same as original get_node_metadata) - input_data = node_class.INPUT_TYPES() if hasattr(node_class, 'INPUT_TYPES') else {} - input_info = {} - - # Process required inputs - if 'required' in input_data: - for name, value in input_data['required'].items(): - if isinstance(value, tuple): - if len(value) == 1 and isinstance(value[0], list): - # Handle combo box case where value is ([option1, option2, ...],) - input_info[name] = { - 'type': 'combo', - 
'value': value[0], # The list of options becomes the value - } - elif len(value) == 2: - input_type, config = value - input_info[name] = { - 'type': input_type, - 'required': True, - 'min': config.get('min', None), - 'max': config.get('max', None), - 'widget': config.get('widget', None) - } - elif len(value) == 1: - # Handle simple type case like ('IMAGE',) - input_info[name] = { - 'type': value[0] + try: + # Get metadata for this node type + input_data = node_class.INPUT_TYPES() if hasattr(node_class, 'INPUT_TYPES') else {} + input_info = {} + + # Process required inputs + if 'required' in input_data: + for name, value in input_data['required'].items(): + if isinstance(value, tuple): + if len(value) == 1 and isinstance(value[0], list): + input_info[name] = { + 'type': 'combo', + 'value': value[0], + } + elif len(value) == 2: + input_type, config = value + input_info[name] = { + 'type': input_type, + 'required': True, + 'min': config.get('min', None), + 'max': config.get('max', None), + 'widget': config.get('widget', None) + } + elif len(value) == 1: + input_info[name] = { + 'type': value[0] + } + + # Process optional inputs + if 'optional' in input_data: + for name, value in input_data['optional'].items(): + if isinstance(value, tuple): + if len(value) == 1 and isinstance(value[0], list): + input_info[name] = { + 'type': 'combo', + 'value': value[0], + } + elif len(value) == 2: + input_type, config = value + input_info[name] = { + 'type': input_type, + 'required': False, + 'min': config.get('min', None), + 'max': config.get('max', None), + 'widget': config.get('widget', None) + } + elif len(value) == 1: + input_info[name] = { + 'type': value[0] + } + + # Process nodes in prompt that use this class_type + for node_id in list(remaining_nodes): + node = prompt[node_id] + if node.get('class_type') != class_type: + continue + + node_info = { + 'class_type': class_type, + 'inputs': {} + } + + if 'inputs' in node: + for input_name, input_value in node['inputs'].items(): 
+ input_metadata = input_info.get(input_name, {}) + node_info['inputs'][input_name] = { + 'value': input_value, + 'type': input_metadata.get('type', 'unknown'), + 'min': input_metadata.get('min', None), + 'max': input_metadata.get('max', None), + 'widget': input_metadata.get('widget', None) } - else: - logger.error(f"Unexpected structure for required input {name}: {value}") - - # Process optional inputs with same logic - if 'optional' in input_data: - for name, value in input_data['optional'].items(): - if isinstance(value, tuple): - if len(value) == 1 and isinstance(value[0], list): - # Handle combo box case where value is ([option1, option2, ...],) - input_info[name] = { - 'type': 'combo', - 'value': value[0], # The list of options becomes the value - } - elif len(value) == 2: - input_type, config = value - input_info[name] = { - 'type': input_type, - 'required': False, - 'min': config.get('min', None), - 'max': config.get('max', None), - 'widget': config.get('widget', None) - } - elif len(value) == 1: - # Handle simple type case like ('IMAGE',) - input_info[name] = { - 'type': value[0] - } - else: - logger.error(f"Unexpected structure for optional input {name}: {value}") - - # Now process any nodes in our prompt that use this class_type - for node_id in list(remaining_nodes): - node = prompt[node_id] - if node.get('class_type') != class_type: - continue - - node_info = { - 'class_type': class_type, - 'inputs': {} - } - - if 'inputs' in node: - for input_name, input_value in node['inputs'].items(): - input_metadata = input_info.get(input_name, {}) - node_info['inputs'][input_name] = { - 'value': input_value, - 'type': input_metadata.get('type', 'unknown'), - 'min': input_metadata.get('min', None), - 'max': input_metadata.get('max', None), - 'widget': input_metadata.get('widget', None) - } - # For combo type inputs, include the list of options - if input_metadata.get('type') == 'combo': - node_info['inputs'][input_name]['value'] = input_metadata.get('value', []) - 
- nodes_info[node_id] = node_info - remaining_nodes.remove(node_id) - - all_prompts_nodes_info[prompt_index] = nodes_info + if input_metadata.get('type') == 'combo': + node_info['inputs'][input_name]['value'] = input_metadata.get('value', []) + + nodes_info[node_id] = node_info + remaining_nodes.remove(node_id) + + except Exception as e: + logger.error(f"[ComfyStreamClient] Error processing class_type {class_type}: {e}") + continue + + all_prompts_nodes_info[prompt_index] = nodes_info + logger.info(f"[ComfyStreamClient] Completed prompt {prompt_index}, found {len(nodes_info)} nodes") + logger.info(f"[ComfyStreamClient] get_available_nodes completed successfully, returning {len(all_prompts_nodes_info)} prompts") return all_prompts_nodes_info except Exception as e: - logger.error(f"Error getting node info: {str(e)}") + logger.error(f"[ComfyStreamClient] Error getting node info: {str(e)}") + import traceback + logger.error(f"[ComfyStreamClient] Traceback: {traceback.format_exc()}") return {} def _register_tensorrt_paths_main_process(self, workspace_path): @@ -460,9 +472,9 @@ def _register_tensorrt_paths_main_process(self, workspace_path): tensorrt_models_dir = os.path.join(folder_paths.models_dir, "tensorrt") tensorrt_outputs_dir = os.path.join(folder_paths.models_dir, "outputs", "tensorrt") - logger.info(f"[ComfyStreamClient] Registering TensorRT paths in main process") - logger.info(f"[ComfyStreamClient] TensorRT models dir: {tensorrt_models_dir}") - logger.info(f"[ComfyStreamClient] TensorRT outputs dir: {tensorrt_outputs_dir}") + # logger.info(f"[ComfyStreamClient] Registering TensorRT paths in main process") + # logger.info(f"[ComfyStreamClient] TensorRT models dir: {tensorrt_models_dir}") + # logger.info(f"[ComfyStreamClient] TensorRT outputs dir: {tensorrt_outputs_dir}") # Register TensorRT paths if "tensorrt" in folder_paths.folder_names_and_paths: @@ -478,10 +490,18 @@ def _register_tensorrt_paths_main_process(self, workspace_path): ) # Verify registration - 
available_files = folder_paths.get_filename_list("tensorrt") - logger.info(f"[ComfyStreamClient] Main process TensorRT files: {available_files}") + # available_files = folder_paths.get_filename_list("tensorrt") + # logger.info(f"[ComfyStreamClient] Main process TensorRT files: {available_files}") except Exception as e: logger.error(f"[ComfyStreamClient] Error registering TensorRT paths in main process: {e}") import traceback - logger.error(f"[ComfyStreamClient] Traceback: {traceback.format_exc()}") \ No newline at end of file + logger.error(f"[ComfyStreamClient] Traceback: {traceback.format_exc()}") + + async def update_prompts(self, prompts: List[PromptDictInput]): + """Update the existing processing prompts without restarting workers.""" + + # Simply update the current prompts - worker loops will pick up changes on next iteration + self.current_prompts = [convert_prompt(prompt) for prompt in prompts] + + logger.info("[ComfyStreamClient] Prompts updated") \ No newline at end of file diff --git a/src/comfystream/pipeline_multi.py b/src/comfystream/pipeline_multi.py index e526ef4b..32446d3a 100644 --- a/src/comfystream/pipeline_multi.py +++ b/src/comfystream/pipeline_multi.py @@ -57,7 +57,7 @@ def __init__(self, # Add a queue for frame log entries self.running = True - self.next_expected_frame_id = 0 + self.next_expected_frame_id = 0 # Initialize to 0 instead of None self.frame_log_file = frame_log_file self.frame_log_queue = None # Initialize to None by default @@ -108,21 +108,24 @@ async def set_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]] await self.client.set_prompts([prompts]) async def update_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]): - """Update the existing processing prompts. 
- - Args: - prompts: Either a single prompt dictionary or a list of prompt dictionaries - """ + """Update the existing processing prompts.""" if isinstance(prompts, list): await self.client.update_prompts(prompts) else: await self.client.update_prompts([prompts]) + + logger.info("Prompts updated") async def put_video_frame(self, frame: av.VideoFrame): current_time = time.time() frame.side_data.input = self.video_preprocess(frame) frame.side_data.skipped = True frame.side_data.frame_received_time = current_time + + # Initialize frame ID if it's None (safety check) + if self.next_expected_frame_id is None: + self.next_expected_frame_id = 0 + frame.side_data.frame_id = self.next_expected_frame_id frame.side_data.client_index = -1 self.next_expected_frame_id += 1 From e5d26ea7e56f7ea0880532e5fabafdc25665c03f Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Thu, 5 Jun 2025 15:26:14 -0400 Subject: [PATCH 37/42] Merge fixes. --- server/app_api.py | 5 ++++- server/app_multi.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/server/app_api.py b/server/app_api.py index 708241eb..7de8f2de 100644 --- a/server/app_api.py +++ b/server/app_api.py @@ -281,7 +281,10 @@ async def on_message(message): "[Control] Missing prompt in update_prompt message" ) return - await pipeline.update_prompts(params["prompts"]) + try: + await pipeline.update_prompts(params["prompts"]) + except Exception as e: + logger.error(f"Error updating prompt: {str(e)}") response = {"type": "prompts_updated", "success": True} channel.send(json.dumps(response)) elif params.get("type") == "update_resolution": diff --git a/server/app_multi.py b/server/app_multi.py index 0e3d28f5..ff4392b1 100644 --- a/server/app_multi.py +++ b/server/app_multi.py @@ -260,7 +260,10 @@ async def on_message(message): "[Control] Missing prompt in update_prompt message" ) return - await pipeline.update_prompts(params["prompts"]) + try: + await pipeline.update_prompts(params["prompts"]) + except Exception as e: 
+ logger.error(f"Error updating prompt: {str(e)}") response = {"type": "prompts_updated", "success": True} channel.send(json.dumps(response)) elif params.get("type") == "update_resolution": From b0516d2cb83b77c8afed0a72c7a4517d475c1a6e Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Thu, 5 Jun 2025 15:49:06 -0400 Subject: [PATCH 38/42] Revert of node retrieval testing. --- src/comfystream/client_multi.py | 207 +++++++++++++++----------------- 1 file changed, 97 insertions(+), 110 deletions(-) diff --git a/src/comfystream/client_multi.py b/src/comfystream/client_multi.py index 6ec1f876..23cc9f02 100644 --- a/src/comfystream/client_multi.py +++ b/src/comfystream/client_multi.py @@ -42,7 +42,7 @@ def __init__(self, logger.info(f"[ComfyStreamClient] Using absolute workspace path: {kwargs['cwd']}") # Register TensorRT paths in main process BEFORE creating ComfyUI client - self._register_tensorrt_paths_main_process(kwargs.get('cwd')) + self.register_tensorrt_paths_main_process(kwargs.get('cwd')) logger.info("[ComfyStreamClient] Config kwargs: %s", kwargs) @@ -319,32 +319,21 @@ async def get_video_output(self): async def get_audio_output(self): loop = asyncio.get_event_loop() return await loop.run_in_executor(None, self.audio_outputs.get) - + async def get_available_nodes(self): - """Get metadata and available nodes info in the main process""" - logger.info("[ComfyStreamClient] get_available_nodes called") - - if not self.running_prompts or not self.current_prompts: - logger.info("[ComfyStreamClient] No running prompts or current prompts, returning empty") + """Get metadata and available nodes info in a single pass""" + # TODO: make it for for multiple prompts + if not self.running_prompts: return {} try: - logger.info("[ComfyStreamClient] Starting node info gathering...") - - # Run node info gathering in the main process instead of worker process - # This avoids interfering with the ongoing video processing from comfy.nodes.package import import_all_nodes_in_workspace 
- logger.info("[ComfyStreamClient] Imported import_all_nodes_in_workspace") - nodes = import_all_nodes_in_workspace() - logger.info(f"[ComfyStreamClient] Loaded {len(nodes.NODE_CLASS_MAPPINGS)} node classes") - + all_prompts_nodes_info = {} for prompt_index, prompt in enumerate(self.current_prompts): - logger.info(f"[ComfyStreamClient] Processing prompt {prompt_index}") - - # Get set of class types we need metadata for + # Get set of class types we need metadata for, excluding LoadTensor and SaveTensor needed_class_types = { node.get('class_type') for node in prompt.values() @@ -354,112 +343,110 @@ async def get_available_nodes(self): for node_id, node in prompt.items() } nodes_info = {} - - logger.info(f"[ComfyStreamClient] Need metadata for {len(needed_class_types)} class types") - # Process nodes to get metadata + # Only process nodes until we've found all the ones we need for class_type, node_class in nodes.NODE_CLASS_MAPPINGS.items(): - if not remaining_nodes: + if not remaining_nodes: # Exit early if we've found all needed nodes break if class_type not in needed_class_types: continue - try: - # Get metadata for this node type - input_data = node_class.INPUT_TYPES() if hasattr(node_class, 'INPUT_TYPES') else {} - input_info = {} - - # Process required inputs - if 'required' in input_data: - for name, value in input_data['required'].items(): - if isinstance(value, tuple): - if len(value) == 1 and isinstance(value[0], list): - input_info[name] = { - 'type': 'combo', - 'value': value[0], - } - elif len(value) == 2: - input_type, config = value - input_info[name] = { - 'type': input_type, - 'required': True, - 'min': config.get('min', None), - 'max': config.get('max', None), - 'widget': config.get('widget', None) - } - elif len(value) == 1: - input_info[name] = { - 'type': value[0] - } - - # Process optional inputs - if 'optional' in input_data: - for name, value in input_data['optional'].items(): - if isinstance(value, tuple): - if len(value) == 1 and 
isinstance(value[0], list): - input_info[name] = { - 'type': 'combo', - 'value': value[0], - } - elif len(value) == 2: - input_type, config = value - input_info[name] = { - 'type': input_type, - 'required': False, - 'min': config.get('min', None), - 'max': config.get('max', None), - 'widget': config.get('widget', None) - } - elif len(value) == 1: - input_info[name] = { - 'type': value[0] - } - - # Process nodes in prompt that use this class_type - for node_id in list(remaining_nodes): - node = prompt[node_id] - if node.get('class_type') != class_type: - continue - - node_info = { - 'class_type': class_type, - 'inputs': {} - } - - if 'inputs' in node: - for input_name, input_value in node['inputs'].items(): - input_metadata = input_info.get(input_name, {}) - node_info['inputs'][input_name] = { - 'value': input_value, - 'type': input_metadata.get('type', 'unknown'), - 'min': input_metadata.get('min', None), - 'max': input_metadata.get('max', None), - 'widget': input_metadata.get('widget', None) + # Get metadata for this node type (same as original get_node_metadata) + input_data = node_class.INPUT_TYPES() if hasattr(node_class, 'INPUT_TYPES') else {} + input_info = {} + + # Process required inputs + if 'required' in input_data: + for name, value in input_data['required'].items(): + if isinstance(value, tuple): + if len(value) == 1 and isinstance(value[0], list): + # Handle combo box case where value is ([option1, option2, ...],) + input_info[name] = { + 'type': 'combo', + 'value': value[0], # The list of options becomes the value } - if input_metadata.get('type') == 'combo': - node_info['inputs'][input_name]['value'] = input_metadata.get('value', []) - - nodes_info[node_id] = node_info - remaining_nodes.remove(node_id) - - except Exception as e: - logger.error(f"[ComfyStreamClient] Error processing class_type {class_type}: {e}") - continue - - all_prompts_nodes_info[prompt_index] = nodes_info - logger.info(f"[ComfyStreamClient] Completed prompt {prompt_index}, found 
{len(nodes_info)} nodes") + elif len(value) == 2: + input_type, config = value + input_info[name] = { + 'type': input_type, + 'required': True, + 'min': config.get('min', None), + 'max': config.get('max', None), + 'widget': config.get('widget', None) + } + elif len(value) == 1: + # Handle simple type case like ('IMAGE',) + input_info[name] = { + 'type': value[0] + } + else: + logger.error(f"Unexpected structure for required input {name}: {value}") + + # Process optional inputs with same logic + if 'optional' in input_data: + for name, value in input_data['optional'].items(): + if isinstance(value, tuple): + if len(value) == 1 and isinstance(value[0], list): + # Handle combo box case where value is ([option1, option2, ...],) + input_info[name] = { + 'type': 'combo', + 'value': value[0], # The list of options becomes the value + } + elif len(value) == 2: + input_type, config = value + input_info[name] = { + 'type': input_type, + 'required': False, + 'min': config.get('min', None), + 'max': config.get('max', None), + 'widget': config.get('widget', None) + } + elif len(value) == 1: + # Handle simple type case like ('IMAGE',) + input_info[name] = { + 'type': value[0] + } + else: + logger.error(f"Unexpected structure for optional input {name}: {value}") + + # Now process any nodes in our prompt that use this class_type + for node_id in list(remaining_nodes): + node = prompt[node_id] + if node.get('class_type') != class_type: + continue + + node_info = { + 'class_type': class_type, + 'inputs': {} + } + + if 'inputs' in node: + for input_name, input_value in node['inputs'].items(): + input_metadata = input_info.get(input_name, {}) + node_info['inputs'][input_name] = { + 'value': input_value, + 'type': input_metadata.get('type', 'unknown'), + 'min': input_metadata.get('min', None), + 'max': input_metadata.get('max', None), + 'widget': input_metadata.get('widget', None) + } + # For combo type inputs, include the list of options + if input_metadata.get('type') == 'combo': + 
node_info['inputs'][input_name]['value'] = input_metadata.get('value', []) + + nodes_info[node_id] = node_info + remaining_nodes.remove(node_id) + + all_prompts_nodes_info[prompt_index] = nodes_info - logger.info(f"[ComfyStreamClient] get_available_nodes completed successfully, returning {len(all_prompts_nodes_info)} prompts") return all_prompts_nodes_info except Exception as e: - logger.error(f"[ComfyStreamClient] Error getting node info: {str(e)}") - import traceback - logger.error(f"[ComfyStreamClient] Traceback: {traceback.format_exc()}") + logger.error(f"Error getting node info: {str(e)}") return {} - def _register_tensorrt_paths_main_process(self, workspace_path): + def register_tensorrt_paths_main_process(self, workspace_path): """Register TensorRT paths in the main process for validation""" try: from comfy.cmd import folder_paths From d70982e78470a77df5fd5e453a91236884f88a1b Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Thu, 5 Jun 2025 16:22:35 -0400 Subject: [PATCH 39/42] Added node cache system to stop process pool interruption, improve response frame management. 
--- server/app_multi.py | 8 ++ src/comfystream/client_multi.py | 33 ++++-- src/comfystream/pipeline_multi.py | 144 +++++++++++++++++++++++--- src/comfystream/tensor_cache_multi.py | 58 +++++++++-- 4 files changed, 212 insertions(+), 31 deletions(-) diff --git a/server/app_multi.py b/server/app_multi.py index ff4392b1..1abe1974 100644 --- a/server/app_multi.py +++ b/server/app_multi.py @@ -380,6 +380,7 @@ async def on_startup(app: web.Application): max_workers=app["workers"], comfyui_inference_log_level=app.get("comfyui_inference_log_level", None), frame_log_file=app.get("frame_log_file", None), + max_frame_wait_ms=app.get("max_frame_wait", 500), cwd=app["workspace"], disable_cuda_malloc=True, gpu_only=True, @@ -523,6 +524,12 @@ async def shutdown_server(): default=1, help="Number of workers to run", ) + parser.add_argument( + "--max-frame-wait", + type=int, + default=500, + help="Maximum time to wait for a frame before dropping it (milliseconds)", + ) args = parser.parse_args() # Set up signal handlers @@ -542,6 +549,7 @@ async def shutdown_server(): app["workspace"] = args.workspace app["frame_log_file"] = args.frame_log_file app["workers"] = args.workers + app["max_frame_wait"] = args.max_frame_wait app.on_startup.append(on_startup) app.on_shutdown.append(on_shutdown) diff --git a/src/comfystream/client_multi.py b/src/comfystream/client_multi.py index 23cc9f02..b55b7b9a 100644 --- a/src/comfystream/client_multi.py +++ b/src/comfystream/client_multi.py @@ -44,6 +44,9 @@ def __init__(self, # Register TensorRT paths in main process BEFORE creating ComfyUI client self.register_tensorrt_paths_main_process(kwargs.get('cwd')) + # Cache nodes information in main process to avoid ProcessPoolExecutor conflicts + self._initialize_nodes_cache() + logger.info("[ComfyStreamClient] Config kwargs: %s", kwargs) try: @@ -321,19 +324,20 @@ async def get_audio_output(self): return await loop.run_in_executor(None, self.audio_outputs.get) async def get_available_nodes(self): - """Get 
metadata and available nodes info in a single pass""" - # TODO: make it for for multiple prompts - if not self.running_prompts: + """Get metadata and available nodes info using cached nodes to avoid ProcessPoolExecutor conflicts""" + if not self.current_prompts: return {} - try: - from comfy.nodes.package import import_all_nodes_in_workspace - nodes = import_all_nodes_in_workspace() + # Use cached nodes instead of calling import_all_nodes_in_workspace from worker process + if self._nodes is None: + logger.warning("[ComfyStreamClient] Nodes cache not available, returning empty result") + return {} + try: all_prompts_nodes_info = {} for prompt_index, prompt in enumerate(self.current_prompts): - # Get set of class types we need metadata for, excluding LoadTensor and SaveTensor + # Get set of class types we need metadata for needed_class_types = { node.get('class_type') for node in prompt.values() @@ -345,7 +349,7 @@ async def get_available_nodes(self): nodes_info = {} # Only process nodes until we've found all the ones we need - for class_type, node_class in nodes.NODE_CLASS_MAPPINGS.items(): + for class_type, node_class in self._nodes.NODE_CLASS_MAPPINGS.items(): if not remaining_nodes: # Exit early if we've found all needed nodes break @@ -491,4 +495,15 @@ async def update_prompts(self, prompts: List[PromptDictInput]): # Simply update the current prompts - worker loops will pick up changes on next iteration self.current_prompts = [convert_prompt(prompt) for prompt in prompts] - logger.info("[ComfyStreamClient] Prompts updated") \ No newline at end of file + logger.info("[ComfyStreamClient] Prompts updated") + + def _initialize_nodes_cache(self): + """Initialize nodes cache in main process to avoid ProcessPoolExecutor conflicts""" + try: + logger.info("[ComfyStreamClient] Initializing nodes cache in main process...") + from comfy.nodes.package import import_all_nodes_in_workspace + self._nodes = import_all_nodes_in_workspace() + logger.info(f"[ComfyStreamClient] 
Cached {len(self._nodes.NODE_CLASS_MAPPINGS)} node types") + except Exception as e: + logger.error(f"[ComfyStreamClient] Error initializing nodes cache: {e}") + self._nodes = None \ No newline at end of file diff --git a/src/comfystream/pipeline_multi.py b/src/comfystream/pipeline_multi.py index 32446d3a..b12b4b22 100644 --- a/src/comfystream/pipeline_multi.py +++ b/src/comfystream/pipeline_multi.py @@ -5,6 +5,7 @@ import logging import time import os +from collections import OrderedDict from typing import Any, Dict, Union, List, Optional from comfystream.client_multi import ComfyStreamClient @@ -30,7 +31,8 @@ def __init__(self, height: int = 512, max_workers: int = 1, comfyui_inference_log_level: Optional[int] = None, - frame_log_file: Optional[str] = None, + frame_log_file: Optional[str] = None, + max_frame_wait_ms: int = 500, **kwargs): """Initialize the pipeline with the given configuration. @@ -40,6 +42,7 @@ def __init__(self, max_workers: Number of worker processes (default: 1) comfyui_inference_log_level: The logging level for ComfyUI inference. frame_log_file: Path to frame timing log file + max_frame_wait_ms: Maximum time to wait for a frame before dropping it (default: 500) **kwargs: Additional arguments to pass to the ComfyStreamClient (cwd, disable_cuda_malloc, etc.) 
""" self.client = ComfyStreamClient( @@ -51,13 +54,19 @@ def __init__(self, self.video_incoming_frames = asyncio.Queue() self.audio_incoming_frames = asyncio.Queue() + # Frame ordering system (similar to pipeline_api.py) + self.ordered_frames = OrderedDict() # frame_id -> (timestamp, tensor, original_frame) + self.next_expected_frame_id = 0 + self.input_frame_counter = 0 # Separate counter for input frames + self.max_frame_wait_ms = max_frame_wait_ms + self.processed_video_frames = asyncio.Queue() + self.processed_audio_buffer = np.array([], dtype=np.int16) self._comfyui_inference_log_level = comfyui_inference_log_level # Add a queue for frame log entries self.running = True - self.next_expected_frame_id = 0 # Initialize to 0 instead of None self.frame_log_file = frame_log_file self.frame_log_queue = None # Initialize to None by default @@ -65,6 +74,93 @@ def __init__(self, self.frame_log_queue = asyncio.Queue() self.frame_logger_task = asyncio.create_task(self._process_frame_logs()) + # Start background task for collecting and ordering frames + self.collector_task = asyncio.create_task(self._collect_processed_frames()) + + async def _collect_processed_frames(self): + """Background task to collect processed frames and maintain order""" + try: + while self.running: + try: + # Get output from client (this should now return frame_id and tensor) + output = await asyncio.wait_for(self.client.get_video_output(), timeout=0.1) + + if output is not None: + # If client returns just tensor (backward compatibility) + if isinstance(output, torch.Tensor): + # For backward compatibility, assume sequential processing + frame_id = self.next_expected_frame_id + tensor = output + else: + # New format: (frame_id, tensor) + frame_id, tensor = output + + current_time = time.time() + await self._add_frame_to_ordered_buffer(frame_id, current_time, tensor) + + except asyncio.TimeoutError: + # No frame ready, continue + pass + except Exception as e: + logger.error(f"Error collecting 
processed frame: {e}") + + # Check for frames that have waited too long + await self._check_frame_timeouts() + + # Small sleep to avoid CPU spinning + await asyncio.sleep(0.01) + + except asyncio.CancelledError: + logger.info("[PipelineMulti] Frame collector task cancelled") + except Exception as e: + logger.error(f"[PipelineMulti] Unexpected error in frame collector: {e}") + + async def _add_frame_to_ordered_buffer(self, frame_id, timestamp, tensor): + """Add a processed frame to the ordered buffer""" + self.ordered_frames[frame_id] = (timestamp, tensor) + + # Check if we can release any frames now + await self._release_ordered_frames() + + async def _release_ordered_frames(self): + """Release frames in sequential order""" + # Only release frames in strict sequential order + while self.ordered_frames and self.next_expected_frame_id in self.ordered_frames: + timestamp, tensor = self.ordered_frames.pop(self.next_expected_frame_id) + await self.processed_video_frames.put((self.next_expected_frame_id, tensor)) + logger.debug(f"[PipelineMulti] Released frame {self.next_expected_frame_id} to output queue") + self.next_expected_frame_id += 1 + + async def _check_frame_timeouts(self): + """Check for frames that have waited too long and handle them""" + if not self.ordered_frames: + return + + current_time = time.time() + + # If the next expected frame has timed out, skip it and move on + if self.next_expected_frame_id in self.ordered_frames: + timestamp, _ = self.ordered_frames[self.next_expected_frame_id] + wait_time_ms = (current_time - timestamp) * 1000 + + if wait_time_ms > self.max_frame_wait_ms: + logger.debug(f"[PipelineMulti] Frame {self.next_expected_frame_id} exceeded max wait time, releasing anyway") + await self._release_ordered_frames() + + # Check if we're missing the next expected frame and it's been too long + elif self.ordered_frames: + # The next frame we're expecting isn't in the buffer + # Check how long we've been waiting since the oldest frame in the 
buffer + oldest_frame_id = min(self.ordered_frames.keys()) + oldest_timestamp, _ = self.ordered_frames[oldest_frame_id] + wait_time_ms = (current_time - oldest_timestamp) * 1000 + + # If we've waited too long, skip the missing frame(s) + if wait_time_ms > self.max_frame_wait_ms: + logger.debug(f"[PipelineMulti] Missing frame {self.next_expected_frame_id}, skipping to {oldest_frame_id}") + self.next_expected_frame_id = oldest_frame_id + await self._release_ordered_frames() + async def initialize(self, prompts): await self.set_prompts(prompts) await self.warm_video() @@ -80,10 +176,18 @@ async def warm_video(self): pts=None, time_base=None ) + # Set frame_id for warmup frames (negative to distinguish from real frames) + dummy_proxy.side_data.frame_id = -(i + 1) logger.debug(f"[PipelineMulti] Warmup: putting dummy frame {i+1}/{WARMUP_RUNS}") self.client.put_video_input(dummy_proxy) - out = await self.client.get_video_output() - logger.debug(f"[PipelineMulti] Warmup: got output for dummy frame {i+1}/{WARMUP_RUNS}: shape={getattr(out, 'shape', None)}") + + # For warmup, we don't need to wait for ordered output + try: + out = await asyncio.wait_for(self.client.get_video_output(), timeout=30.0) + logger.debug(f"[PipelineMulti] Warmup: got output for dummy frame {i+1}/{WARMUP_RUNS}") + except asyncio.TimeoutError: + logger.warning(f"[PipelineMulti] Warmup frame {i+1} timed out") + logger.info("[PipelineMulti] Warmup complete.") async def warm_audio(self): @@ -122,13 +226,11 @@ async def put_video_frame(self, frame: av.VideoFrame): frame.side_data.skipped = True frame.side_data.frame_received_time = current_time - # Initialize frame ID if it's None (safety check) - if self.next_expected_frame_id is None: - self.next_expected_frame_id = 0 - - frame.side_data.frame_id = self.next_expected_frame_id + # Assign frame ID and increment counter + frame_id = self.input_frame_counter + frame.side_data.frame_id = frame_id frame.side_data.client_index = -1 - self.next_expected_frame_id 
+= 1 + self.input_frame_counter += 1 # Log frame at input time to properly track input FPS if self.frame_log_file: @@ -215,15 +317,13 @@ def audio_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.Audio # TODO: make it generic to support purely generative video cases async def get_processed_video_frame(self) -> av.VideoFrame: - # logger.info(f"[PipelineMulti] get_processed_video_frame called - PID: {os.getpid()}") frame_process_start_time = time.time() # Get the input frame first frame = await self.video_incoming_frames.get() - # Then get the output tensor - async with temporary_log_level("comfy", self._comfyui_inference_log_level): - out_tensor = await self.client.get_video_output() + # Get the processed frame from our ordered output queue + processed_frame_id, out_tensor = await self.processed_video_frames.get() # Process the frame processed_frame = self.video_postprocess(out_tensor) @@ -235,7 +335,7 @@ async def get_processed_video_frame(self) -> av.VideoFrame: # Log frame timing if self.frame_log_file: await self.frame_log_queue.put({ - 'frame_id': frame.side_data.frame_id, + 'frame_id': processed_frame_id, 'frame_received_time': frame.side_data.frame_received_time, 'frame_process_start_time': frame_process_start_time, 'frame_processed_time': frame_processed_time, @@ -243,7 +343,6 @@ async def get_processed_video_frame(self) -> av.VideoFrame: 'csv_path': self.frame_log_file }) - # logger.info(f"[PipelineMulti] get_processed_video_frame returning frame - PID: {os.getpid()}") return processed_frame async def get_processed_audio_frame(self) -> av.AudioFrame: @@ -283,6 +382,14 @@ async def cleanup(self): # Set running flag to false to stop frame processing self.running = False + # Cancel collector task + if hasattr(self, 'collector_task') and self.collector_task: + self.collector_task.cancel() + try: + await self.collector_task + except asyncio.CancelledError: + pass + # Cancel frame logger task if it exists if hasattr(self, 'frame_logger_task') and 
self.frame_logger_task: self.frame_logger_task.cancel() @@ -291,6 +398,11 @@ async def cleanup(self): except asyncio.CancelledError: pass + # Clear ordered frames buffer + self.ordered_frames.clear() + self.next_expected_frame_id = 0 + self.input_frame_counter = 0 + # Clean up the client (this will gracefully shutdown workers) await self.client.cleanup() diff --git a/src/comfystream/tensor_cache_multi.py b/src/comfystream/tensor_cache_multi.py index 32debf7b..d50f558e 100644 --- a/src/comfystream/tensor_cache_multi.py +++ b/src/comfystream/tensor_cache_multi.py @@ -15,6 +15,16 @@ audio_inputs = None audio_outputs = None +# Global frame ID tracking for worker processes +current_frame_id = None +frame_id_mapping = {} # Maps tensor id to frame_id + +class FrameData: + """Wrapper class to carry frame metadata through the processing pipeline""" + def __init__(self, tensor, frame_id=None): + self.tensor = tensor + self.frame_id = frame_id + # Create wrapper classes that match the interface of the original queues class MultiProcessInputQueue: def __init__(self, mp_queue): @@ -22,12 +32,24 @@ def __init__(self, mp_queue): def get(self, block=True, timeout=None): result = self.queue.get(block=block, timeout=timeout) - # logger.info(f"[MultiProcessInputQueue] Frame retrieved by worker PID: {os.getpid()}") + + # Extract frame metadata and store it globally for this worker + global current_frame_id + if hasattr(result, 'side_data') and hasattr(result.side_data, 'frame_id'): + current_frame_id = result.side_data.frame_id + # logger.info(f"[MultiProcessInputQueue] Frame {current_frame_id} retrieved by worker PID: {os.getpid()}") + return result def get_nowait(self): result = self.queue.get_nowait() - # logger.info(f"[MultiProcessInputQueue] Frame retrieved (nowait) by worker PID: {os.getpid()}") + + # Extract frame metadata and store it globally for this worker + global current_frame_id + if hasattr(result, 'side_data') and hasattr(result.side_data, 'frame_id'): + 
current_frame_id = result.side_data.frame_id + # logger.info(f"[MultiProcessInputQueue] Frame {current_frame_id} retrieved (nowait) by worker PID: {os.getpid()}") + return result def put(self, item, block=True, timeout=None): @@ -49,7 +71,15 @@ def __init__(self, mp_queue): async def get(self): # Convert synchronous get to async loop = asyncio.get_event_loop() - return await loop.run_in_executor(None, self.queue.get) + result = await loop.run_in_executor(None, self.queue.get) + + # Check if this is a tuple with frame_id + if isinstance(result, tuple) and len(result) == 2: + frame_id, tensor = result + return (frame_id, tensor) + else: + # Backward compatibility - return just the tensor + return result async def put(self, item): # Convert synchronous put to async @@ -62,15 +92,31 @@ async def put(self, item): def put_nowait(self, item): try: + # Check if we have a current frame ID to associate with this output + global current_frame_id + # Ensure tensor is on CPU before sending if torch.is_tensor(item): item = item.cpu() - # logger.info(f"[MultiProcessOutputQueue] Frame sent (nowait) from worker PID: {os.getpid()}") - self.queue.put_nowait(item) + + # If we have a frame ID, send it as a tuple + if current_frame_id is not None: + output_data = (current_frame_id, item) + # logger.info(f"[MultiProcessOutputQueue] Frame {current_frame_id} sent (nowait) from worker PID: {os.getpid()}") + else: + output_data = item + # logger.info(f"[MultiProcessOutputQueue] Frame sent (nowait) without ID from worker PID: {os.getpid()}") + + self.queue.put_nowait(output_data) except queue.Full: try: self.queue.get_nowait() # Drop oldest - self.queue.put_nowait(item) + # Try again with the same logic + if current_frame_id is not None: + output_data = (current_frame_id, item) + else: + output_data = item + self.queue.put_nowait(output_data) except Exception: pass # If still full, drop this frame From e22ceff9a1b9b4a339c7bcea003f99bf6627440f Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: 
Thu, 5 Jun 2025 16:39:25 -0400 Subject: [PATCH 40/42] Rebuilt frame processing management for smoother playback. --- server/app_multi.py | 8 -- src/comfystream/client_multi.py | 19 ++-- src/comfystream/pipeline_multi.py | 147 +++++++------------------- src/comfystream/tensor_cache_multi.py | 30 ++---- 4 files changed, 49 insertions(+), 155 deletions(-) diff --git a/server/app_multi.py b/server/app_multi.py index 1abe1974..ff4392b1 100644 --- a/server/app_multi.py +++ b/server/app_multi.py @@ -380,7 +380,6 @@ async def on_startup(app: web.Application): max_workers=app["workers"], comfyui_inference_log_level=app.get("comfyui_inference_log_level", None), frame_log_file=app.get("frame_log_file", None), - max_frame_wait_ms=app.get("max_frame_wait", 500), cwd=app["workspace"], disable_cuda_malloc=True, gpu_only=True, @@ -524,12 +523,6 @@ async def shutdown_server(): default=1, help="Number of workers to run", ) - parser.add_argument( - "--max-frame-wait", - type=int, - default=500, - help="Maximum time to wait for a frame before dropping it (milliseconds)", - ) args = parser.parse_args() # Set up signal handlers @@ -549,7 +542,6 @@ async def shutdown_server(): app["workspace"] = args.workspace app["frame_log_file"] = args.frame_log_file app["workers"] = args.workers - app["max_frame_wait"] = args.max_frame_wait app.on_startup.append(on_startup) app.on_shutdown.append(on_shutdown) diff --git a/src/comfystream/client_multi.py b/src/comfystream/client_multi.py index b55b7b9a..2e942c30 100644 --- a/src/comfystream/client_multi.py +++ b/src/comfystream/client_multi.py @@ -162,37 +162,30 @@ async def distribute_frames(self): logger.info("[ComfyStreamClient] Distribution manager stopped") async def worker_loop(self, worker_id: int): - """Worker loop that continuously processes prompts and picks up updates""" - logger.info(f"[Worker {worker_id}] Started - PID: {os.getpid()}") + """Simple worker loop - just process frames continuously""" + logger.info(f"[Worker {worker_id}] 
Started") frame_count = 0 try: while not self.shutting_down: try: - # Check if we should stop before processing - if self.shutting_down: - break - + # Simple round-robin prompt selection prompt_index = worker_id % len(self.current_prompts) current_prompt = self.current_prompts[prompt_index] - # Execute the current prompt + # Just process the prompt await self.comfy_client.queue_prompt(current_prompt) frame_count += 1 except asyncio.CancelledError: - logger.info(f"[Worker {worker_id}] Cancelled after {frame_count} frames") break except Exception as e: if self.shutting_down: - logger.info(f"[Worker {worker_id}] Stopping due to shutdown") break - logger.error(f"[Worker {worker_id}] Error on frame {frame_count}: {str(e)}") + logger.error(f"[Worker {worker_id}] Error: {e}") await asyncio.sleep(0.1) - except asyncio.CancelledError: - logger.info(f"[Worker {worker_id}] Task cancelled") finally: - logger.info(f"[Worker {worker_id}] Stopped after processing {frame_count} frames") + logger.info(f"[Worker {worker_id}] Processed {frame_count} frames") async def cleanup(self): async with self.cleanup_lock: diff --git a/src/comfystream/pipeline_multi.py b/src/comfystream/pipeline_multi.py index b12b4b22..eae6a7f3 100644 --- a/src/comfystream/pipeline_multi.py +++ b/src/comfystream/pipeline_multi.py @@ -32,7 +32,6 @@ def __init__(self, max_workers: int = 1, comfyui_inference_log_level: Optional[int] = None, frame_log_file: Optional[str] = None, - max_frame_wait_ms: int = 500, **kwargs): """Initialize the pipeline with the given configuration. @@ -42,7 +41,6 @@ def __init__(self, max_workers: Number of worker processes (default: 1) comfyui_inference_log_level: The logging level for ComfyUI inference. frame_log_file: Path to frame timing log file - max_frame_wait_ms: Maximum time to wait for a frame before dropping it (default: 500) **kwargs: Additional arguments to pass to the ComfyStreamClient (cwd, disable_cuda_malloc, etc.) 
""" self.client = ComfyStreamClient( @@ -54,12 +52,12 @@ def __init__(self, self.video_incoming_frames = asyncio.Queue() self.audio_incoming_frames = asyncio.Queue() - # Frame ordering system (similar to pipeline_api.py) - self.ordered_frames = OrderedDict() # frame_id -> (timestamp, tensor, original_frame) - self.next_expected_frame_id = 0 - self.input_frame_counter = 0 # Separate counter for input frames - self.max_frame_wait_ms = max_frame_wait_ms - self.processed_video_frames = asyncio.Queue() + # Remove complex frame ordering - just use a simple buffer + self.output_buffer = asyncio.Queue(maxsize=6) # Small buffer to smooth output + self.input_frame_counter = 0 + + # Simple frame collection without ordering + self.collector_task = asyncio.create_task(self._collect_frames_simple()) self.processed_audio_buffer = np.array([], dtype=np.int16) @@ -74,92 +72,32 @@ def __init__(self, self.frame_log_queue = asyncio.Queue() self.frame_logger_task = asyncio.create_task(self._process_frame_logs()) - # Start background task for collecting and ordering frames - self.collector_task = asyncio.create_task(self._collect_processed_frames()) - - async def _collect_processed_frames(self): - """Background task to collect processed frames and maintain order""" + async def _collect_frames_simple(self): + """Simple frame collector - no ordering, just buffer""" try: while self.running: try: - # Get output from client (this should now return frame_id and tensor) - output = await asyncio.wait_for(self.client.get_video_output(), timeout=0.1) - - if output is not None: - # If client returns just tensor (backward compatibility) - if isinstance(output, torch.Tensor): - # For backward compatibility, assume sequential processing - frame_id = self.next_expected_frame_id - tensor = output - else: - # New format: (frame_id, tensor) - frame_id, tensor = output - - current_time = time.time() - await self._add_frame_to_ordered_buffer(frame_id, current_time, tensor) - + tensor = await 
asyncio.wait_for(self.client.get_video_output(), timeout=0.1) + if tensor is not None: + # Just put it in the buffer, don't worry about order + try: + self.output_buffer.put_nowait(tensor) + except asyncio.QueueFull: + # Drop oldest frame if buffer is full + try: + self.output_buffer.get_nowait() + self.output_buffer.put_nowait(tensor) + except asyncio.QueueEmpty: + pass except asyncio.TimeoutError: - # No frame ready, continue pass except Exception as e: - logger.error(f"Error collecting processed frame: {e}") - - # Check for frames that have waited too long - await self._check_frame_timeouts() + logger.error(f"Error collecting frame: {e}") - # Small sleep to avoid CPU spinning - await asyncio.sleep(0.01) + await asyncio.sleep(0.001) # Minimal sleep except asyncio.CancelledError: - logger.info("[PipelineMulti] Frame collector task cancelled") - except Exception as e: - logger.error(f"[PipelineMulti] Unexpected error in frame collector: {e}") - - async def _add_frame_to_ordered_buffer(self, frame_id, timestamp, tensor): - """Add a processed frame to the ordered buffer""" - self.ordered_frames[frame_id] = (timestamp, tensor) - - # Check if we can release any frames now - await self._release_ordered_frames() - - async def _release_ordered_frames(self): - """Release frames in sequential order""" - # Only release frames in strict sequential order - while self.ordered_frames and self.next_expected_frame_id in self.ordered_frames: - timestamp, tensor = self.ordered_frames.pop(self.next_expected_frame_id) - await self.processed_video_frames.put((self.next_expected_frame_id, tensor)) - logger.debug(f"[PipelineMulti] Released frame {self.next_expected_frame_id} to output queue") - self.next_expected_frame_id += 1 - - async def _check_frame_timeouts(self): - """Check for frames that have waited too long and handle them""" - if not self.ordered_frames: - return - - current_time = time.time() - - # If the next expected frame has timed out, skip it and move on - if 
self.next_expected_frame_id in self.ordered_frames: - timestamp, _ = self.ordered_frames[self.next_expected_frame_id] - wait_time_ms = (current_time - timestamp) * 1000 - - if wait_time_ms > self.max_frame_wait_ms: - logger.debug(f"[PipelineMulti] Frame {self.next_expected_frame_id} exceeded max wait time, releasing anyway") - await self._release_ordered_frames() - - # Check if we're missing the next expected frame and it's been too long - elif self.ordered_frames: - # The next frame we're expecting isn't in the buffer - # Check how long we've been waiting since the oldest frame in the buffer - oldest_frame_id = min(self.ordered_frames.keys()) - oldest_timestamp, _ = self.ordered_frames[oldest_frame_id] - wait_time_ms = (current_time - oldest_timestamp) * 1000 - - # If we've waited too long, skip the missing frame(s) - if wait_time_ms > self.max_frame_wait_ms: - logger.debug(f"[PipelineMulti] Missing frame {self.next_expected_frame_id}, skipping to {oldest_frame_id}") - self.next_expected_frame_id = oldest_frame_id - await self._release_ordered_frames() + pass async def initialize(self, prompts): await self.set_prompts(prompts) @@ -232,17 +170,6 @@ async def put_video_frame(self, frame: av.VideoFrame): frame.side_data.client_index = -1 self.input_frame_counter += 1 - # Log frame at input time to properly track input FPS - if self.frame_log_file: - await self.frame_log_queue.put({ - 'frame_id': frame.side_data.frame_id, - 'frame_received_time': frame.side_data.frame_received_time, - 'frame_process_start_time': None, - 'frame_processed_time': None, - 'client_index': frame.side_data.client_index, - 'csv_path': self.frame_log_file - }) - self.client.put_video_input(frame) await self.video_incoming_frames.put(frame) @@ -317,32 +244,35 @@ def audio_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.Audio # TODO: make it generic to support purely generative video cases async def get_processed_video_frame(self) -> av.VideoFrame: - frame_process_start_time = 
time.time() - - # Get the input frame first + # Get input frame frame = await self.video_incoming_frames.get() - # Get the processed frame from our ordered output queue - processed_frame_id, out_tensor = await self.processed_video_frames.get() + # Get any available output (don't match input to output) + try: + out_tensor = await asyncio.wait_for(self.output_buffer.get(), timeout=1.0) + except asyncio.TimeoutError: + # Fallback: return input frame to maintain stream + logger.warning("Output timeout, using input frame") + return frame - # Process the frame + # Process and return processed_frame = self.video_postprocess(out_tensor) processed_frame.pts = frame.pts processed_frame.time_base = frame.time_base frame_processed_time = time.time() - # Log frame timing + # Log frame at input time to properly track input FPS if self.frame_log_file: await self.frame_log_queue.put({ - 'frame_id': processed_frame_id, + 'frame_id': frame.side_data.frame_id, 'frame_received_time': frame.side_data.frame_received_time, - 'frame_process_start_time': frame_process_start_time, + 'frame_process_start_time': 0, # TODO: We dont know the start time of the frame processing 'frame_processed_time': frame_processed_time, 'client_index': frame.side_data.client_index, 'csv_path': self.frame_log_file }) - + return processed_frame async def get_processed_audio_frame(self) -> av.AudioFrame: @@ -398,11 +328,6 @@ async def cleanup(self): except asyncio.CancelledError: pass - # Clear ordered frames buffer - self.ordered_frames.clear() - self.next_expected_frame_id = 0 - self.input_frame_counter = 0 - # Clean up the client (this will gracefully shutdown workers) await self.client.cleanup() diff --git a/src/comfystream/tensor_cache_multi.py b/src/comfystream/tensor_cache_multi.py index d50f558e..ebcf113b 100644 --- a/src/comfystream/tensor_cache_multi.py +++ b/src/comfystream/tensor_cache_multi.py @@ -92,33 +92,17 @@ async def put(self, item): def put_nowait(self, item): try: - # Check if we have a 
current frame ID to associate with this output - global current_frame_id - - # Ensure tensor is on CPU before sending + # Ensure tensor is on CPU if torch.is_tensor(item): item = item.cpu() - - # If we have a frame ID, send it as a tuple - if current_frame_id is not None: - output_data = (current_frame_id, item) - # logger.info(f"[MultiProcessOutputQueue] Frame {current_frame_id} sent (nowait) from worker PID: {os.getpid()}") - else: - output_data = item - # logger.info(f"[MultiProcessOutputQueue] Frame sent (nowait) without ID from worker PID: {os.getpid()}") - - self.queue.put_nowait(output_data) + self.queue.put_nowait(item) except queue.Full: + # Simple: drop one old frame and try again try: - self.queue.get_nowait() # Drop oldest - # Try again with the same logic - if current_frame_id is not None: - output_data = (current_frame_id, item) - else: - output_data = item - self.queue.put_nowait(output_data) - except Exception: - pass # If still full, drop this frame + self.queue.get_nowait() + self.queue.put_nowait(item) + except: + pass # If still fails, just drop this frame def init_tensor_cache(image_inputs, image_outputs, audio_inputs, audio_outputs, workspace_path=None): """Initialize the tensor cache for a worker process. From 2773c9a4070ddb06dbc16b090c4991433f80ff1b Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Tue, 24 Jun 2025 21:02:33 -0400 Subject: [PATCH 41/42] Removal of extreanous files, cleanup, modified _multi as primary files. 
--- configs/comfy.toml | 13 - requirements.txt | 2 +- server/app.py | 167 ++++- server/app_api.py | 596 ----------------- server/app_multi.py | 601 ----------------- src/comfystream/client.py | 426 ++++++++++--- src/comfystream/client_api.py | 886 -------------------------- src/comfystream/client_multi.py | 502 --------------- src/comfystream/frame_logging.py | 366 ----------- src/comfystream/pipeline.py | 193 ++++-- src/comfystream/pipeline_api.py | 698 -------------------- src/comfystream/pipeline_multi.py | 349 ---------- src/comfystream/tensor_cache.py | 131 +++- src/comfystream/tensor_cache_multi.py | 245 ------- src/comfystream/utils_api.py | 154 ----- 15 files changed, 728 insertions(+), 4601 deletions(-) delete mode 100644 configs/comfy.toml delete mode 100644 server/app_api.py delete mode 100644 server/app_multi.py delete mode 100644 src/comfystream/client_api.py delete mode 100644 src/comfystream/client_multi.py delete mode 100644 src/comfystream/frame_logging.py delete mode 100644 src/comfystream/pipeline_api.py delete mode 100644 src/comfystream/pipeline_multi.py delete mode 100644 src/comfystream/tensor_cache_multi.py delete mode 100644 src/comfystream/utils_api.py diff --git a/configs/comfy.toml b/configs/comfy.toml deleted file mode 100644 index 5d278541..00000000 --- a/configs/comfy.toml +++ /dev/null @@ -1,13 +0,0 @@ -# Configuration for multiple ComfyUI servers - -[[servers]] -host = "127.0.0.1" -port = 8188 -client_id = "client1" - -# Adding more servers: - -# [[servers]] -# host = "127.0.0.1" -# port = 8189 -# client_id = "client2" diff --git a/requirements.txt b/requirements.txt index 595ad47c..4f6373ac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ asyncio -comfyui @ git+https://github.com/hiddenswitch/ComfyUI.git@ce3583ad42c024b8f060d0002cbe20c265da6dc8 +comfyui @ git+https://github.com/hiddenswitch/ComfyUI.git@e034d0bb24b0d23e3c40419c68689464dec67690 aiortc aiohttp aiohttp_cors diff --git a/server/app.py b/server/app.py 
index d24f2e78..b22a836c 100644 --- a/server/app.py +++ b/server/app.py @@ -7,13 +7,13 @@ import time import secrets import torch +import signal # Initialize CUDA before any other imports to prevent core dump. if torch.cuda.is_available(): torch.cuda.init() torch.cuda.empty_cache() - - + from aiohttp import web, MultipartWriter from aiohttp_cors import setup as setup_cors, ResourceOptions from aiohttp import web @@ -38,6 +38,8 @@ logging.getLogger("aiortc.rtcrtpsender").setLevel(logging.WARNING) logging.getLogger("aiortc.rtcrtpreceiver").setLevel(logging.WARNING) +# Global variable to track the app for signal handling +app_instance = None MAX_BITRATE = 2000000 MIN_BITRATE = 2000000 @@ -73,7 +75,7 @@ def __init__(self, track: MediaStreamTrack, pipeline: Pipeline): # Add cleanup when track ends @track.on("ended") async def on_ended(): - logger.info("Source video track ended, stopping collection") + logger.info("[App] Source video track ended, stopping collection") await cancel_collect_frames(self) async def collect_frames(self): @@ -86,22 +88,22 @@ async def collect_frames(self): frame = await self.track.recv() await self.pipeline.put_video_frame(frame) except asyncio.CancelledError: - logger.info("Frame collection cancelled") + logger.info("[App] Frame collection cancelled") break except Exception as e: if "MediaStreamError" in str(type(e)): - logger.info("Media stream ended") + logger.info("[App] Media stream ended") else: - logger.error(f"Error collecting video frames: {str(e)}") + logger.error(f"[App] Error collecting video frames: {str(e)}") self.running = False break # Perform cleanup outside the exception handler - logger.info("Video frame collection stopped") + logger.info("[App] Video frame collection stopped") except asyncio.CancelledError: - logger.info("Frame collection task cancelled") + logger.info("[App] Frame collection task cancelled") except Exception as e: - logger.error(f"Unexpected error in frame collection: {str(e)}") + logger.error(f"[App] 
Unexpected error in frame collection: {str(e)}") finally: await self.pipeline.cleanup() @@ -111,14 +113,15 @@ async def recv(self): """ processed_frame = await self.pipeline.get_processed_video_frame() - # Update the frame buffer with the processed frame + # Update the frame buffer with the processed frame try: from frame_buffer import FrameBuffer frame_buffer = FrameBuffer.get_instance() frame_buffer.update_frame(processed_frame) except Exception as e: # Don't let frame buffer errors affect the main pipeline - print(f"Error updating frame buffer: {e}") + print(f"[App] Error updating frame buffer: {e}") + # Increment the frame count to calculate FPS. await self.fps_meter.increment_frame_count() @@ -152,22 +155,22 @@ async def collect_frames(self): frame = await self.track.recv() await self.pipeline.put_audio_frame(frame) except asyncio.CancelledError: - logger.info("Audio frame collection cancelled") + logger.info("[App] Audio frame collection cancelled") break except Exception as e: if "MediaStreamError" in str(type(e)): - logger.info("Media stream ended") + logger.info("[App] Media stream ended") else: - logger.error(f"Error collecting audio frames: {str(e)}") + logger.error(f"[App] Error collecting audio frames: {str(e)}") self.running = False break # Perform cleanup outside the exception handler - logger.info("Audio frame collection stopped") + logger.info("[App] Audio frame collection stopped") except asyncio.CancelledError: - logger.info("Frame collection task cancelled") + logger.info("[App] Frame collection task cancelled") except Exception as e: - logger.error(f"Unexpected error in audio frame collection: {str(e)}") + logger.error(f"[App] Unexpected error in audio frame collection: {str(e)}") finally: await self.pipeline.cleanup() @@ -302,16 +305,16 @@ async def on_message(message): channel.send(json.dumps(response)) else: logger.warning( - "[Server] Invalid message format - missing required fields" + "[App] Invalid message format - missing required 
fields" ) except json.JSONDecodeError: - logger.error("[Server] Invalid JSON received") + logger.error("[App] Invalid JSON received") except Exception as e: - logger.error(f"[Server] Error processing message: {str(e)}") + logger.error(f"[App] Error processing message: {str(e)}") @pc.on("track") def on_track(track): - logger.info(f"Track received: {track.kind}") + logger.info(f"[App] Track received: {track.kind}") if track.kind == "video": videoTrack = VideoStreamTrack(track, pipeline) tracks["video"] = videoTrack @@ -330,12 +333,12 @@ def on_track(track): @track.on("ended") async def on_ended(): - logger.info(f"{track.kind} track ended") + logger.info(f"[App] {track.kind} track ended") request.app["video_tracks"].pop(track.id, None) @pc.on("connectionstatechange") async def on_connectionstatechange(): - logger.info(f"Connection state is: {pc.connectionState}") + logger.info(f"[App] Connection state is: {pc.connectionState}") if pc.connectionState == "failed": await pc.close() pcs.discard(pc) @@ -377,7 +380,7 @@ async def set_prompt(request): await pipeline.set_prompts(prompt) return web.Response(content_type="application/json", text="OK") - + def health(_): return web.Response(content_type="application/json", text="OK") @@ -390,22 +393,97 @@ async def on_startup(app: web.Application): app["pipeline"] = Pipeline( width=512, height=512, + max_workers=app["workers"], cwd=app["workspace"], disable_cuda_malloc=True, gpu_only=True, preview_method='none', comfyui_inference_log_level=app.get("comfui_inference_log_level", None), - frame_log_file=app.get("frame_log_file", None), ) app["pcs"] = set() app["video_tracks"] = {} async def on_shutdown(app: web.Application): + logger.info("Starting server shutdown...") + + # Clean up pipeline first - this should terminate worker processes + if "pipeline" in app: + try: + await app["pipeline"].cleanup() + logger.info("Pipeline cleanup completed") + except Exception as e: + logger.error(f"Error during pipeline cleanup: {e}") + + # 
Clean up peer connections pcs = app["pcs"] - coros = [pc.close() for pc in pcs] - await asyncio.gather(*coros) - pcs.clear() + if pcs: + logger.info(f"Closing {len(pcs)} peer connections...") + coros = [pc.close() for pc in pcs] + await asyncio.gather(*coros, return_exceptions=True) + pcs.clear() + logger.info("Peer connections closed") + + logger.info("Server shutdown completed") + + +def signal_handler(signum, frame): + """Handle SIGINT (Ctrl+C) gracefully""" + logger.info(f"Received signal {signum}, initiating graceful shutdown...") + + if app_instance: + # Get the current event loop + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + # Schedule the shutdown coroutine + asyncio.create_task(shutdown_server()) + else: + # If no loop is running, run the shutdown directly + asyncio.run(shutdown_server()) + except RuntimeError: + # If we can't get the loop, force cleanup and exit + logger.warning("Could not get event loop, forcing cleanup and exit") + force_cleanup_and_exit() + else: + logger.warning("No app instance found, forcing exit") + os._exit(1) + +def force_cleanup_and_exit(): + """Force cleanup of processes and exit""" + try: + # Try to cleanup the pipeline synchronously + if app_instance and "pipeline" in app_instance: + pipeline = app_instance["pipeline"] + if hasattr(pipeline, 'client') and hasattr(pipeline.client, 'comfy_client'): + # Force shutdown of the executor + if hasattr(pipeline.client.comfy_client, 'executor'): + executor = pipeline.client.comfy_client.executor + if hasattr(executor, 'shutdown'): + logger.info("Force shutting down executor...") + executor.shutdown(wait=False) + if hasattr(executor, '_processes'): + # Terminate all worker processes + for process in executor._processes: + if process.is_alive(): + logger.info(f"Terminating worker process {process.pid}") + process.terminate() + except Exception as e: + logger.error(f"Error in force cleanup: {e}") + finally: + os._exit(1) + +async def shutdown_server(): + """Perform 
graceful server shutdown""" + try: + if app_instance: + await on_shutdown(app_instance) + logger.info("Graceful shutdown completed") + except Exception as e: + logger.error(f"Error during graceful shutdown: {e}") + finally: + # Force exit after cleanup attempt + os._exit(0) if __name__ == "__main__": @@ -454,8 +532,19 @@ async def on_shutdown(app: web.Application): default=None, help="Filename for frame timing log (optional)" ) + parser.add_argument( + "--workers", + type=int, + default=1, + help="Number of workers to run", + ) args = parser.parse_args() + # Set up signal handlers + signal.signal(signal.SIGINT, signal_handler) + if hasattr(signal, 'SIGTERM'): + signal.signal(signal.SIGTERM, signal_handler) + logging.basicConfig( level=args.log_level.upper(), format="%(asctime)s [%(levelname)s] %(message)s", @@ -463,9 +552,9 @@ async def on_shutdown(app: web.Application): ) app = web.Application() + app_instance = app # Store reference for signal handler app["media_ports"] = args.media_ports.split(",") if args.media_ports else None app["workspace"] = args.workspace - app["frame_log_file"] = args.frame_log_file # Setup CORS cors = setup_cors(app, defaults={ @@ -476,6 +565,7 @@ async def on_shutdown(app: web.Application): allow_methods=["GET", "POST", "OPTIONS"] ) }) + app["workers"] = args.workers app.on_startup.append(on_startup) app.on_shutdown.append(on_shutdown) @@ -486,7 +576,7 @@ async def on_shutdown(app: web.Application): # WebRTC signalling and control routes. 
app.router.add_post("/offer", offer) app.router.add_post("/prompt", set_prompt) - + # Setup HTTP streaming routes setup_routes(app, cors) @@ -525,6 +615,17 @@ def force_print(*args, **kwargs): log_level = logging._nameToLevel.get(args.comfyui_log_level.upper()) logging.getLogger("comfy").setLevel(log_level) if args.comfyui_inference_log_level: - app["comfui_inference_log_level"] = args.comfyui_inference_log_level - - web.run_app(app, host=args.host, port=int(args.port), print=force_print) + log_level = logging._nameToLevel.get(args.comfyui_inference_log_level.upper()) + app["comfyui_inference_log_level"] = log_level + + try: + web.run_app(app, host=args.host, port=int(args.port), print=force_print) + except KeyboardInterrupt: + logger.info("Received KeyboardInterrupt, shutting down...") + finally: + # Ensure cleanup happens even if web.run_app doesn't call on_shutdown + if app_instance: + try: + asyncio.run(on_shutdown(app_instance)) + except Exception as e: + logger.error(f"Error in final cleanup: {e}") diff --git a/server/app_api.py b/server/app_api.py deleted file mode 100644 index 7de8f2de..00000000 --- a/server/app_api.py +++ /dev/null @@ -1,596 +0,0 @@ -import argparse -import asyncio -import json -import logging -import os -import sys - -import torch - -# Initialize CUDA before any other imports to prevent core dump. 
-if torch.cuda.is_available(): - torch.cuda.init() - torch.cuda.empty_cache() - -from aiohttp import web -from aiortc import ( - MediaStreamTrack, - RTCConfiguration, - RTCIceServer, - RTCPeerConnection, - RTCSessionDescription, -) -from aiortc.codecs import h264 -from aiortc.rtcrtpsender import RTCRtpSender -from twilio.rest import Client - -from comfystream.pipeline_api import Pipeline -from comfystream.server.utils import patch_loop_datagram, add_prefix_to_app_routes, FPSMeter -from comfystream.server.metrics import MetricsManager, StreamStatsManager - -logger = logging.getLogger(__name__) -logging.getLogger("aiortc.rtcrtpsender").setLevel(logging.WARNING) -logging.getLogger("aiortc.rtcrtpreceiver").setLevel(logging.WARNING) - - -MAX_BITRATE = 2000000 -MIN_BITRATE = 2000000 - - -class VideoStreamTrack(MediaStreamTrack): - """video stream track that processes video frames using a pipeline. - - Attributes: - kind (str): The kind of media, which is "video" for this class. - track (MediaStreamTrack): The underlying media stream track. - pipeline (Pipeline): The processing pipeline to apply to each video frame. - """ - - kind = "video" - - def __init__(self, track: MediaStreamTrack, pipeline: Pipeline): - """Initialize the VideoStreamTrack. - - Args: - track: The underlying media stream track. - pipeline: The processing pipeline to apply to each video frame. - """ - super().__init__() - self.track = track - self.pipeline = pipeline - self.fps_meter = FPSMeter( - metrics_manager=app["metrics_manager"], track_id=track.id - ) - self.running = True - self.collect_task = asyncio.create_task(self.collect_frames()) - - # Add cleanup when track ends - @track.on("ended") - async def on_ended(): - logger.info("Source video track ended, stopping collection") - await cancel_collect_frames(self) - - async def collect_frames(self): - """Collect video frames from the underlying track and pass them to - the processing pipeline. Stops when track ends or connection closes. 
- """ - try: - while self.running: - try: - frame = await self.track.recv() - await self.pipeline.put_video_frame(frame) - except asyncio.CancelledError: - logger.info("Frame collection cancelled") - break - except Exception as e: - if "MediaStreamError" in str(type(e)): - logger.info("Media stream ended") - else: - logger.error(f"Error collecting video frames: {str(e)}") - self.running = False - break - - # Perform cleanup outside the exception handler - logger.info("Video frame collection stopped") - except asyncio.CancelledError: - logger.info("Frame collection task cancelled") - except Exception as e: - logger.error(f"Unexpected error in frame collection: {str(e)}") - finally: - await self.pipeline.cleanup() - - async def recv(self): - """Receive a processed video frame from the pipeline, increment the frame - count for FPS calculation and return the processed frame to the client. - """ - processed_frame = await self.pipeline.get_processed_video_frame() - - # Increment the frame count to calculate FPS. - await self.fps_meter.increment_frame_count() - - return processed_frame - - -class AudioStreamTrack(MediaStreamTrack): - kind = "audio" - - def __init__(self, track: MediaStreamTrack, pipeline): - super().__init__() - self.track = track - self.pipeline = pipeline - self.running = True - self.collect_task = asyncio.create_task(self.collect_frames()) - - # Add cleanup when track ends - @track.on("ended") - async def on_ended(): - logger.info("Source audio track ended, stopping collection") - await cancel_collect_frames(self) - - async def collect_frames(self): - """Collect audio frames from the underlying track and pass them to - the processing pipeline. Stops when track ends or connection closes. 
- """ - try: - while self.running: - try: - frame = await self.track.recv() - await self.pipeline.put_audio_frame(frame) - except asyncio.CancelledError: - logger.info("Audio frame collection cancelled") - break - except Exception as e: - if "MediaStreamError" in str(type(e)): - logger.info("Media stream ended") - else: - logger.error(f"Error collecting audio frames: {str(e)}") - self.running = False - break - - # Perform cleanup outside the exception handler - logger.info("Audio frame collection stopped") - except asyncio.CancelledError: - logger.info("Frame collection task cancelled") - except Exception as e: - logger.error(f"Unexpected error in audio frame collection: {str(e)}") - finally: - await self.pipeline.cleanup() - - async def recv(self): - return await self.pipeline.get_processed_audio_frame() - -def force_codec(pc, sender, forced_codec): - kind = forced_codec.split("/")[0] - codecs = RTCRtpSender.getCapabilities(kind).codecs - transceiver = next(t for t in pc.getTransceivers() if t.sender == sender) - codecPrefs = [codec for codec in codecs if codec.mimeType == forced_codec] - transceiver.setCodecPreferences(codecPrefs) - - -def get_twilio_token(): - account_sid = os.getenv("TWILIO_ACCOUNT_SID") - auth_token = os.getenv("TWILIO_AUTH_TOKEN") - - if account_sid is None or auth_token is None: - return None - - client = Client(account_sid, auth_token) - - token = client.tokens.create() - - return token - - -def get_ice_servers(): - ice_servers = [] - - token = get_twilio_token() - if token is not None: - # Use Twilio TURN servers - for server in token.ice_servers: - if server["url"].startswith("turn:"): - turn = RTCIceServer( - urls=[server["urls"]], - credential=server["credential"], - username=server["username"], - ) - ice_servers.append(turn) - - return ice_servers - - -async def offer(request): - pipeline = request.app["pipeline"] - pcs = request.app["pcs"] - - # Check if clients are initialized, and initialize them if not - if not pipeline.clients: - 
logger.info("Clients not initialized yet, starting clients...") - results = await pipeline.start_clients() - - # Check if there was an error during startup - if results is None and hasattr(pipeline, 'startup_error') and pipeline.startup_error: - error_message = pipeline.startup_error - logger.error(f"Failed to initialize clients: {error_message}") - return web.Response( - status=500, - content_type="application/json", - text=json.dumps({"error": f"Failed to start ComfyUI: {error_message}"}) - ) - - # Get parameters - params = await request.json() - - # When a client reconnects after refresh, we need to clear certain pipeline state - # but NOT restart the ComfyUI servers/clients - # Reset the frame tracking, but keep the servers running - pipeline.next_expected_frame_id = None - pipeline.ordered_frames.clear() - pipeline.next_frame_id = 1 # Reset frame ID counter for new connection - pipeline.client_frame_mapping.clear() - - await pipeline.set_prompts(params["prompts"]) - - offer_params = params["offer"] - offer = RTCSessionDescription(sdp=offer_params["sdp"], type=offer_params["type"]) - - ice_servers = get_ice_servers() - if len(ice_servers) > 0: - pc = RTCPeerConnection( - configuration=RTCConfiguration(iceServers=get_ice_servers()) - ) - else: - pc = RTCPeerConnection() - - pcs.add(pc) - - tracks = {"video": None, "audio": None} - - # Flag to track if we've received resolution update - resolution_received = {"value": False} - - # Only add video transceiver if video is present in the offer - if "m=video" in offer.sdp: - # Prefer h264 - transceiver = pc.addTransceiver("video") - caps = RTCRtpSender.getCapabilities("video") - prefs = list(filter(lambda x: x.name == "H264", caps.codecs)) - transceiver.setCodecPreferences(prefs) - - # Monkey patch max and min bitrate to ensure constant bitrate - h264.MAX_BITRATE = MAX_BITRATE - h264.MIN_BITRATE = MIN_BITRATE - - # Handle control channel from client - @pc.on("datachannel") - def on_datachannel(channel): - if 
channel.label == "control": - - @channel.on("message") - async def on_message(message): - try: - params = json.loads(message) - - if params.get("type") == "get_nodes": - nodes_info = await pipeline.get_nodes_info() - response = {"type": "nodes_info", "nodes": nodes_info} - channel.send(json.dumps(response)) - elif params.get("type") == "update_prompts": - if "prompts" not in params: - logger.warning( - "[Control] Missing prompt in update_prompt message" - ) - return - try: - await pipeline.update_prompts(params["prompts"]) - except Exception as e: - logger.error(f"Error updating prompt: {str(e)}") - response = {"type": "prompts_updated", "success": True} - channel.send(json.dumps(response)) - elif params.get("type") == "update_resolution": - if "width" not in params or "height" not in params: - logger.warning("[Control] Missing width or height in update_resolution message") - return - # Update pipeline resolution for future frames - pipeline.width = params["width"] - pipeline.height = params["height"] - logger.info(f"[Control] Updated resolution to {params['width']}x{params['height']}") - - # Mark that we've received resolution - resolution_received["value"] = True - - # Warm the video pipeline with the new resolution - if "m=video" in pc.remoteDescription.sdp: - await pipeline.warm_video() - - response = { - "type": "resolution_updated", - "success": True - } - channel.send(json.dumps(response)) - else: - logger.warning( - "[Server] Invalid message format - missing required fields" - ) - except json.JSONDecodeError: - logger.error("[Server] Invalid JSON received") - except Exception as e: - logger.error(f"[Server] Error processing message: {str(e)}") - - @pc.on("track") - def on_track(track): - logger.info(f"Track received: {track.kind}") - if track.kind == "video": - videoTrack = VideoStreamTrack(track, pipeline) - tracks["video"] = videoTrack - sender = pc.addTrack(videoTrack) - - # Store video track in app for stats. 
- stream_id = track.id - request.app["video_tracks"][stream_id] = videoTrack - - codec = "video/H264" - force_codec(pc, sender, codec) - elif track.kind == "audio": - audioTrack = AudioStreamTrack(track, pipeline) - tracks["audio"] = audioTrack - pc.addTrack(audioTrack) - - @track.on("ended") - async def on_ended(): - logger.info(f"{track.kind} track ended") - request.app["video_tracks"].pop(track.id, None) - - @pc.on("connectionstatechange") - async def on_connectionstatechange(): - logger.info(f"Connection state is: {pc.connectionState}") - if pc.connectionState == "failed": - await pc.close() - pcs.discard(pc) - elif pc.connectionState == "closed": - await pc.close() - pcs.discard(pc) - - await pc.setRemoteDescription(offer) - - # Only warm audio here, video warming happens after resolution update - if "m=audio" in pc.remoteDescription.sdp: - await pipeline.warm_audio() - - # We no longer warm video here - it will be warmed after receiving resolution - - answer = await pc.createAnswer() - await pc.setLocalDescription(answer) - - return web.Response( - content_type="application/json", - text=json.dumps( - {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type} - ), - ) - - -async def cancel_collect_frames(track): - track.running = False - if hasattr(track, 'collect_task') is not None and not track.collect_task.done(): - try: - track.collect_task.cancel() - await track.collect_task - except (asyncio.CancelledError): - pass - - -async def set_prompt(request): - pipeline = request.app["pipeline"] - - prompt = await request.json() - await pipeline.set_prompts(prompt) - - return web.Response(content_type="application/json", text="OK") - - -def health(_): - return web.Response(content_type="application/json", text="OK") - - -async def on_startup(app: web.Application): - if app["media_ports"]: - patch_loop_datagram(app["media_ports"]) - - # ComfyUI args have been moved to the client constructor - app["pipeline"] = Pipeline( - width=512, - height=512, - 
config_path=app["config_file"], - max_frame_wait_ms=app["max_frame_wait"], - client_mode=app["client_mode"], - workspace=app["workspace"], - workers=app["workers"], - cuda_devices=app["cuda_devices"], - workers_start_port=app.get("workers_start_port", 8195), - comfyui_log_level=app.get("comfyui_log_level", None), - frame_log_file=app.get("frame_log_file", None), - ) - - # Start the clients during initialization - # await app["pipeline"].start_clients() - - # Wait for pipeline startup to complete (which starts the ComfyUI servers) - if hasattr(app["pipeline"], "startup_task"): - await app["pipeline"].startup_task - - app["pcs"] = set() - app["video_tracks"] = {} - - app["max_frame_wait"] = args.max_frame_wait - - -async def on_shutdown(app: web.Application): - pcs = app["pcs"] - coros = [pc.close() for pc in pcs] - await asyncio.gather(*coros) - pcs.clear() - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Run comfystream server") - parser.add_argument("--port", default=8889, help="Set the signaling port") - parser.add_argument( - "--media-ports", default=None, help="Set the UDP ports for WebRTC media" - ) - parser.add_argument("--host", default="127.0.0.1", help="Set the host") - parser.add_argument( - "--workspace", default=None, required=True, help="Set Comfy workspace" - ) - parser.add_argument( - "--log-level", - default="WARNING", - choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], - help="Set the logging level", - ) - parser.add_argument( - "--config-file", - type=str, - default=None, - help="Path to TOML configuration file for Comfy servers" - ) - parser.add_argument( - "--monitor", - default=False, - action="store_true", - help="Start a Prometheus metrics endpoint for monitoring.", - ) - parser.add_argument( - "--stream-id-label", - default=False, - action="store_true", - help="Include stream ID as a label in Prometheus metrics.", - ) - parser.add_argument( - "--max-frame-wait", - type=int, - default=500, - help="Maximum 
time to wait for a frame in milliseconds before dropping it" - ) - parser.add_argument( - "--comfyui-log-level", - default=None, - choices=logging._nameToLevel.keys(), - help="Set the global logging level for ComfyUI", - ) - parser.add_argument( - "--comfyui-inference-log-level", - default=None, - choices=logging._nameToLevel.keys(), - help="Set the logging level for ComfyUI inference", - ) - parser.add_argument( - "--client-mode", - choices=["toml", "spawn"], - default="toml", - help="How to create ComfyUI clients: 'toml' (from config file) or 'spawn' (spawn processes directly)", - ) - parser.add_argument( - "--workers", - type=int, - default=2, - help="Number of worker processes to spawn when using --client-mode=spawn" - ) - parser.add_argument( - "--cuda-devices", - type=str, - default='0', - help="Comma-separated list of CUDA devices to use" - ) - parser.add_argument( - "--workers-start-port", - type=int, - default=8195, - help="Starting port number for worker processes" - ) - parser.add_argument( - "--frame-log-file", - type=str, - default=None, - help="Filename for frame timing log (optional)" - ) - args = parser.parse_args() - - logging.basicConfig( - level=args.log_level.upper(), - format="%(asctime)s [%(levelname)s] %(message)s", - datefmt="%H:%M:%S", - ) - - app = web.Application() - app["media_ports"] = args.media_ports.split(",") if args.media_ports else None - app["workspace"] = args.workspace - app["config_file"] = args.config_file - app["max_frame_wait"] = args.max_frame_wait - app["client_mode"] = args.client_mode - app["workers"] = args.workers - app["cuda_devices"] = args.cuda_devices - app["workers_start_port"] = args.workers_start_port - app["frame_log_file"] = args.frame_log_file - - app.on_startup.append(on_startup) - app.on_shutdown.append(on_shutdown) - - app.router.add_get("/", health) - app.router.add_get("/health", health) - - # WebRTC signalling and control routes. 
- app.router.add_post("/offer", offer) - app.router.add_post("/prompt", set_prompt) - - # Add routes for getting stream statistics. - stream_stats_manager = StreamStatsManager(app) - app.router.add_get( - "/streams/stats", stream_stats_manager.collect_all_stream_metrics - ) - app.router.add_get( - "/stream/{stream_id}/stats", stream_stats_manager.collect_stream_metrics_by_id - ) - - # Add Prometheus metrics endpoint. - app["metrics_manager"] = MetricsManager(include_stream_id=args.stream_id_label) - if args.monitor: - app["metrics_manager"].enable() - logger.info( - f"Monitoring enabled - Prometheus metrics available at: " - f"http://{args.host}:{args.port}/metrics" - ) - app.router.add_get("/metrics", app["metrics_manager"].metrics_handler) - - # Add hosted platform route prefix. - # NOTE: This ensures that the local and hosted experiences have consistent routes. - add_prefix_to_app_routes(app, "/live") - - def force_print(*args, **kwargs): - print(*args, **kwargs, flush=True) - sys.stdout.flush() - - # Allow overriding of ComyfUI log levels. 
- if args.comfyui_log_level: - log_level = logging._nameToLevel.get(args.comfyui_log_level.upper()) - logging.getLogger("comfy").setLevel(log_level) - app["comfyui_log_level"] = args.comfyui_log_level - if args.comfyui_inference_log_level: - app["comfyui_inference_log_level"] = args.comfyui_inference_log_level - - print("\n\nComfystream Options:") - - print(f"Client Mode: {app.get('client_mode')}") - print(f"Log Level: {args.log_level.upper()}") - if (app.get("client_mode") == "spawn" and app.get("comfyui_log_level") is None): - print("To see spawned ComfyUI logs, add --comfyui_log_level=DEBUG") - else: - print(f"ComfyUI Log Level: {app.get('comfyui_log_level')}") - if (app.get("frame_log_file") is None): - print("To set a frame log file, add --frame_log_file=filename.csv") - else: - print(f"Frame Log File: {app.get('frame_log_file')}") - print("\n\n") - - logger.setLevel(getattr(logging, args.log_level.upper())) - - web.run_app(app, host=args.host, port=int(args.port), print=force_print) diff --git a/server/app_multi.py b/server/app_multi.py deleted file mode 100644 index ff4392b1..00000000 --- a/server/app_multi.py +++ /dev/null @@ -1,601 +0,0 @@ -import argparse -import asyncio -import json -import logging -import os -import sys -import torch -import signal - -# Initialize CUDA before any other imports to prevent core dump. 
-if torch.cuda.is_available(): - torch.cuda.init() - torch.cuda.empty_cache() - -from aiohttp import web -from aiortc import ( - MediaStreamTrack, - RTCConfiguration, - RTCIceServer, - RTCPeerConnection, - RTCSessionDescription, -) -from aiortc.codecs import h264 -from aiortc.rtcrtpsender import RTCRtpSender -from comfystream.pipeline_multi import Pipeline -from twilio.rest import Client -from comfystream.server.utils import patch_loop_datagram, add_prefix_to_app_routes, FPSMeter -from comfystream.server.metrics import MetricsManager, StreamStatsManager -import time - -logger = logging.getLogger(__name__) -logging.getLogger("aiortc.rtcrtpsender").setLevel(logging.WARNING) -logging.getLogger("aiortc.rtcrtpreceiver").setLevel(logging.WARNING) - -# Global variable to track the app for signal handling -app_instance = None - -MAX_BITRATE = 2000000 -MIN_BITRATE = 2000000 - - -class VideoStreamTrack(MediaStreamTrack): - """video stream track that processes video frames using a pipeline. - - Attributes: - kind (str): The kind of media, which is "video" for this class. - track (MediaStreamTrack): The underlying media stream track. - pipeline (Pipeline): The processing pipeline to apply to each video frame. - """ - - kind = "video" - - def __init__(self, track: MediaStreamTrack, pipeline: Pipeline): - """Initialize the VideoStreamTrack. - - Args: - track: The underlying media stream track. - pipeline: The processing pipeline to apply to each video frame. 
- """ - super().__init__() - self.track = track - self.pipeline = pipeline - self.fps_meter = FPSMeter( - metrics_manager=app["metrics_manager"], track_id=track.id - ) - self.running = True - self.collect_task = asyncio.create_task(self.collect_frames()) - - # Add cleanup when track ends - @track.on("ended") - async def on_ended(): - logger.info("[App] Source video track ended, stopping collection") - await cancel_collect_frames(self) - - async def collect_frames(self): - """Collect video frames from the underlying track and pass them to - the processing pipeline. Stops when track ends or connection closes. - """ - try: - while self.running: - try: - frame = await self.track.recv() - await self.pipeline.put_video_frame(frame) - except asyncio.CancelledError: - logger.info("[App] Frame collection cancelled") - break - except Exception as e: - if "MediaStreamError" in str(type(e)): - logger.info("[App] Media stream ended") - else: - logger.error(f"[App] Error collecting video frames: {str(e)}") - self.running = False - break - - # Perform cleanup outside the exception handler - logger.info("[App] Video frame collection stopped") - except asyncio.CancelledError: - logger.info("[App] Frame collection task cancelled") - except Exception as e: - logger.error(f"[App] Unexpected error in frame collection: {str(e)}") - finally: - await self.pipeline.cleanup() - - async def recv(self): - """Receive a processed video frame from the pipeline, increment the frame - count for FPS calculation and return the processed frame to the client. - """ - processed_frame = await self.pipeline.get_processed_video_frame() - - # Increment the frame count to calculate FPS. 
- await self.fps_meter.increment_frame_count() - - return processed_frame - - -class AudioStreamTrack(MediaStreamTrack): - kind = "audio" - - def __init__(self, track: MediaStreamTrack, pipeline): - super().__init__() - self.track = track - self.pipeline = pipeline - self.running = True - self.collect_task = asyncio.create_task(self.collect_frames()) - - # Add cleanup when track ends - @track.on("ended") - async def on_ended(): - logger.info("Source audio track ended, stopping collection") - await cancel_collect_frames(self) - - async def collect_frames(self): - """Collect audio frames from the underlying track and pass them to - the processing pipeline. Stops when track ends or connection closes. - """ - try: - while self.running: - try: - frame = await self.track.recv() - await self.pipeline.put_audio_frame(frame) - except asyncio.CancelledError: - logger.info("[App] Audio frame collection cancelled") - break - except Exception as e: - if "MediaStreamError" in str(type(e)): - logger.info("[App] Media stream ended") - else: - logger.error(f"[App] Error collecting audio frames: {str(e)}") - self.running = False - break - - # Perform cleanup outside the exception handler - logger.info("[App] Audio frame collection stopped") - except asyncio.CancelledError: - logger.info("[App] Frame collection task cancelled") - except Exception as e: - logger.error(f"[App] Unexpected error in audio frame collection: {str(e)}") - finally: - await self.pipeline.cleanup() - - async def recv(self): - return await self.pipeline.get_processed_audio_frame() - - -def force_codec(pc, sender, forced_codec): - kind = forced_codec.split("/")[0] - codecs = RTCRtpSender.getCapabilities(kind).codecs - transceiver = next(t for t in pc.getTransceivers() if t.sender == sender) - codecPrefs = [codec for codec in codecs if codec.mimeType == forced_codec] - transceiver.setCodecPreferences(codecPrefs) - - -def get_twilio_token(): - account_sid = os.getenv("TWILIO_ACCOUNT_SID") - auth_token = 
os.getenv("TWILIO_AUTH_TOKEN") - - if account_sid is None or auth_token is None: - return None - - client = Client(account_sid, auth_token) - - token = client.tokens.create() - - return token - - -def get_ice_servers(): - ice_servers = [] - - token = get_twilio_token() - if token is not None: - # Use Twilio TURN servers - for server in token.ice_servers: - if server["url"].startswith("turn:"): - turn = RTCIceServer( - urls=[server["urls"]], - credential=server["credential"], - username=server["username"], - ) - ice_servers.append(turn) - - return ice_servers - - -async def offer(request): - pipeline = request.app["pipeline"] - pcs = request.app["pcs"] - - params = await request.json() - - await pipeline.set_prompts(params["prompts"]) - - offer_params = params["offer"] - offer = RTCSessionDescription(sdp=offer_params["sdp"], type=offer_params["type"]) - - ice_servers = get_ice_servers() - if len(ice_servers) > 0: - pc = RTCPeerConnection( - configuration=RTCConfiguration(iceServers=get_ice_servers()) - ) - else: - pc = RTCPeerConnection() - - pcs.add(pc) - - tracks = {"video": None, "audio": None} - - # Flag to track if we've received resolution update - resolution_received = {"value": False} - - # Only add video transceiver if video is present in the offer - if "m=video" in offer.sdp: - # Prefer h264 - transceiver = pc.addTransceiver("video") - caps = RTCRtpSender.getCapabilities("video") - prefs = list(filter(lambda x: x.name == "H264", caps.codecs)) - transceiver.setCodecPreferences(prefs) - - # Monkey patch max and min bitrate to ensure constant bitrate - h264.MAX_BITRATE = MAX_BITRATE - h264.MIN_BITRATE = MIN_BITRATE - - # Handle control channel from client - @pc.on("datachannel") - def on_datachannel(channel): - if channel.label == "control": - - @channel.on("message") - async def on_message(message): - try: - params = json.loads(message) - - if params.get("type") == "get_nodes": - nodes_info = await pipeline.get_nodes_info() - response = {"type": 
"nodes_info", "nodes": nodes_info} - channel.send(json.dumps(response)) - elif params.get("type") == "update_prompts": - if "prompts" not in params: - logger.warning( - "[Control] Missing prompt in update_prompt message" - ) - return - try: - await pipeline.update_prompts(params["prompts"]) - except Exception as e: - logger.error(f"Error updating prompt: {str(e)}") - response = {"type": "prompts_updated", "success": True} - channel.send(json.dumps(response)) - elif params.get("type") == "update_resolution": - if "width" not in params or "height" not in params: - logger.warning("[Control] Missing width or height in update_resolution message") - return - # Update pipeline resolution for future frames - pipeline.width = params["width"] - pipeline.height = params["height"] - logger.info(f"[Control] Updated resolution to {params['width']}x{params['height']}") - - # Mark that we've received resolution - resolution_received["value"] = True - - # Warm the video pipeline with the new resolution - if "m=video" in pc.remoteDescription.sdp: - await pipeline.warm_video() - - response = { - "type": "resolution_updated", - "success": True - } - channel.send(json.dumps(response)) - else: - logger.warning( - "[App] Invalid message format - missing required fields" - ) - except json.JSONDecodeError: - logger.error("[App] Invalid JSON received") - except Exception as e: - logger.error(f"[App] Error processing message: {str(e)}") - - @pc.on("track") - def on_track(track): - logger.info(f"[App] Track received: {track.kind}") - if track.kind == "video": - videoTrack = VideoStreamTrack(track, pipeline) - tracks["video"] = videoTrack - sender = pc.addTrack(videoTrack) - - # Store video track in app for stats. 
- stream_id = track.id - request.app["video_tracks"][stream_id] = videoTrack - - codec = "video/H264" - force_codec(pc, sender, codec) - elif track.kind == "audio": - audioTrack = AudioStreamTrack(track, pipeline) - tracks["audio"] = audioTrack - pc.addTrack(audioTrack) - - @track.on("ended") - async def on_ended(): - logger.info(f"[App] {track.kind} track ended") - request.app["video_tracks"].pop(track.id, None) - - @pc.on("connectionstatechange") - async def on_connectionstatechange(): - logger.info(f"[App] Connection state is: {pc.connectionState}") - if pc.connectionState == "failed": - await pc.close() - pcs.discard(pc) - elif pc.connectionState == "closed": - await pc.close() - pcs.discard(pc) - - await pc.setRemoteDescription(offer) - - # Only warm audio here, video warming happens after resolution update - if "m=audio" in pc.remoteDescription.sdp: - await pipeline.warm_audio() - - # We no longer warm video here - it will be warmed after receiving resolution - - answer = await pc.createAnswer() - await pc.setLocalDescription(answer) - - return web.Response( - content_type="application/json", - text=json.dumps( - {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type} - ), - ) - -async def cancel_collect_frames(track): - track.running = False - if hasattr(track, 'collect_task') is not None and not track.collect_task.done(): - try: - track.collect_task.cancel() - await track.collect_task - except (asyncio.CancelledError): - pass - -async def set_prompt(request): - pipeline = request.app["pipeline"] - - prompt = await request.json() - await pipeline.set_prompts(prompt) - - return web.Response(content_type="application/json", text="OK") - - -def health(_): - return web.Response(content_type="application/json", text="OK") - - -async def on_startup(app: web.Application): - if app["media_ports"]: - patch_loop_datagram(app["media_ports"]) - - app["pipeline"] = Pipeline( - width=512, - height=512, - max_workers=app["workers"], - 
comfyui_inference_log_level=app.get("comfyui_inference_log_level", None), - frame_log_file=app.get("frame_log_file", None), - cwd=app["workspace"], - disable_cuda_malloc=True, - gpu_only=True, - preview_method='none', - ) - app["pcs"] = set() - app["video_tracks"] = {} - - -async def on_shutdown(app: web.Application): - logger.info("Starting server shutdown...") - - # Clean up pipeline first - this should terminate worker processes - if "pipeline" in app: - try: - await app["pipeline"].cleanup() - logger.info("Pipeline cleanup completed") - except Exception as e: - logger.error(f"Error during pipeline cleanup: {e}") - - # Clean up peer connections - pcs = app["pcs"] - if pcs: - logger.info(f"Closing {len(pcs)} peer connections...") - coros = [pc.close() for pc in pcs] - await asyncio.gather(*coros, return_exceptions=True) - pcs.clear() - logger.info("Peer connections closed") - - logger.info("Server shutdown completed") - - -def signal_handler(signum, frame): - """Handle SIGINT (Ctrl+C) gracefully""" - logger.info(f"Received signal {signum}, initiating graceful shutdown...") - - if app_instance: - # Get the current event loop - try: - loop = asyncio.get_event_loop() - if loop.is_running(): - # Schedule the shutdown coroutine - asyncio.create_task(shutdown_server()) - else: - # If no loop is running, run the shutdown directly - asyncio.run(shutdown_server()) - except RuntimeError: - # If we can't get the loop, force cleanup and exit - logger.warning("Could not get event loop, forcing cleanup and exit") - force_cleanup_and_exit() - else: - logger.warning("No app instance found, forcing exit") - os._exit(1) - -def force_cleanup_and_exit(): - """Force cleanup of processes and exit""" - try: - # Try to cleanup the pipeline synchronously - if app_instance and "pipeline" in app_instance: - pipeline = app_instance["pipeline"] - if hasattr(pipeline, 'client') and hasattr(pipeline.client, 'comfy_client'): - # Force shutdown of the executor - if 
hasattr(pipeline.client.comfy_client, 'executor'): - executor = pipeline.client.comfy_client.executor - if hasattr(executor, 'shutdown'): - logger.info("Force shutting down executor...") - executor.shutdown(wait=False) - if hasattr(executor, '_processes'): - # Terminate all worker processes - for process in executor._processes: - if process.is_alive(): - logger.info(f"Terminating worker process {process.pid}") - process.terminate() - except Exception as e: - logger.error(f"Error in force cleanup: {e}") - finally: - os._exit(1) - -async def shutdown_server(): - """Perform graceful server shutdown""" - try: - if app_instance: - await on_shutdown(app_instance) - logger.info("Graceful shutdown completed") - except Exception as e: - logger.error(f"Error during graceful shutdown: {e}") - finally: - # Force exit after cleanup attempt - os._exit(0) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Run comfystream server") - parser.add_argument("--port", default=8889, help="Set the signaling port") - parser.add_argument( - "--media-ports", default=None, help="Set the UDP ports for WebRTC media" - ) - parser.add_argument("--host", default="127.0.0.1", help="Set the host") - parser.add_argument( - "--workspace", default=None, required=True, help="Set Comfy workspace" - ) - parser.add_argument( - "--log-level", - default="INFO", - choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], - help="Set the logging level", - ) - parser.add_argument( - "--monitor", - default=False, - action="store_true", - help="Start a Prometheus metrics endpoint for monitoring.", - ) - parser.add_argument( - "--stream-id-label", - default=False, - action="store_true", - help="Include stream ID as a label in Prometheus metrics.", - ) - parser.add_argument( - "--comfyui-log-level", - default=None, - choices=logging._nameToLevel.keys(), - help="Set the global logging level for ComfyUI", - ) - parser.add_argument( - "--comfyui-inference-log-level", - default=None, - 
choices=logging._nameToLevel.keys(), - help="Set the logging level for ComfyUI inference", - ) - parser.add_argument( - "--frame-log-file", - type=str, - default=None, - help="Filename for frame timing log (optional)" - ) - parser.add_argument( - "--workers", - type=int, - default=1, - help="Number of workers to run", - ) - args = parser.parse_args() - - # Set up signal handlers - signal.signal(signal.SIGINT, signal_handler) - if hasattr(signal, 'SIGTERM'): - signal.signal(signal.SIGTERM, signal_handler) - - logging.basicConfig( - level=args.log_level.upper(), - format="%(asctime)s [%(levelname)s] %(message)s", - datefmt="%H:%M:%S", - ) - - app = web.Application() - app_instance = app # Store reference for signal handler - app["media_ports"] = args.media_ports.split(",") if args.media_ports else None - app["workspace"] = args.workspace - app["frame_log_file"] = args.frame_log_file - app["workers"] = args.workers - - app.on_startup.append(on_startup) - app.on_shutdown.append(on_shutdown) - - app.router.add_get("/", health) - app.router.add_get("/health", health) - - # WebRTC signalling and control routes. - app.router.add_post("/offer", offer) - app.router.add_post("/prompt", set_prompt) - - # Add routes for getting stream statistics. - stream_stats_manager = StreamStatsManager(app) - app.router.add_get( - "/streams/stats", stream_stats_manager.collect_all_stream_metrics - ) - app.router.add_get( - "/stream/{stream_id}/stats", stream_stats_manager.collect_stream_metrics_by_id - ) - - # Add Prometheus metrics endpoint. - app["metrics_manager"] = MetricsManager(include_stream_id=args.stream_id_label) - if args.monitor: - app["metrics_manager"].enable() - logger.info( - f"Monitoring enabled - Prometheus metrics available at: " - f"http://{args.host}:{args.port}/metrics" - ) - app.router.add_get("/metrics", app["metrics_manager"].metrics_handler) - - # Add hosted platform route prefix. - # NOTE: This ensures that the local and hosted experiences have consistent routes. 
- add_prefix_to_app_routes(app, "/live") - - def force_print(*args, **kwargs): - print(*args, **kwargs, flush=True) - sys.stdout.flush() - - # Allow overriding of ComyfUI log levels. - if args.comfyui_log_level: - log_level = logging._nameToLevel.get(args.comfyui_log_level.upper()) - logging.getLogger("comfy").setLevel(log_level) - if args.comfyui_inference_log_level: - log_level = logging._nameToLevel.get(args.comfyui_inference_log_level.upper()) - app["comfyui_inference_log_level"] = log_level - - try: - web.run_app(app, host=args.host, port=int(args.port), print=force_print) - except KeyboardInterrupt: - logger.info("Received KeyboardInterrupt, shutting down...") - finally: - # Ensure cleanup happens even if web.run_app doesn't call on_shutdown - if app_instance: - try: - asyncio.run(on_shutdown(app_instance)) - except Exception as e: - logger.error(f"Error in final cleanup: {e}") diff --git a/src/comfystream/client.py b/src/comfystream/client.py index ca0c8751..3e402549 100644 --- a/src/comfystream/client.py +++ b/src/comfystream/client.py @@ -1,126 +1,336 @@ import asyncio -from typing import List import logging +from typing import List, Union +import multiprocessing as mp +import os +import sys +import numpy as np +import torch +import av -from comfystream import tensor_cache from comfystream.utils import convert_prompt +from comfystream.tensor_cache import init_tensor_cache -from comfy.api.components.schema.prompt import PromptDictInput from comfy.cli_args_types import Configuration +from comfy.distributed.process_pool_executor import ProcessPoolExecutor +from comfy.api.components.schema.prompt import PromptDictInput from comfy.client.embedded_comfy_client import EmbeddedComfyClient +from comfystream.frame_proxy import FrameProxy logger = logging.getLogger(__name__) +def _test_worker_init(): + """Test function to verify worker process initialization.""" + return os.getpid() class ComfyStreamClient: - def __init__(self, max_workers: int = 1, **kwargs): - 
config = Configuration(**kwargs) - self.comfy_client = EmbeddedComfyClient(config, max_workers=max_workers) - self.running_prompts = {} # To be used for cancelling tasks + def __init__(self, + max_workers: int = 1, + **kwargs): + logger.info(f"[ComfyStreamClient] Main Process ID: {os.getpid()}") + logger.info(f"[ComfyStreamClient] __init__ start, max_workers: {max_workers}") + + # Store default dimensions + self.width = kwargs.get('width', 512) + self.height = kwargs.get('height', 512) + + # Ensure workspace path is absolute + if 'cwd' in kwargs: + if not os.path.isabs(kwargs['cwd']): + # Convert relative path to absolute path from current working directory + kwargs['cwd'] = os.path.abspath(kwargs['cwd']) + logger.info(f"[ComfyStreamClient] Using absolute workspace path: {kwargs['cwd']}") + + # Register TensorRT paths in main process BEFORE creating ComfyUI client + self.register_tensorrt_paths_main_process(kwargs.get('cwd')) + + # Cache nodes information in main process to avoid ProcessPoolExecutor conflicts + self._initialize_nodes_cache() + + logger.info("[ComfyStreamClient] Config kwargs: %s", kwargs) + + try: + self.config = Configuration(**kwargs) + logger.info("[ComfyStreamClient] Configuration created") + logger.info(f"[ComfyStreamClient] Current working directory: {os.getcwd()}") + + logger.info("[ComfyStreamClient] Initializing process executor") + ctx = mp.get_context("spawn") + logger.info(f"[ComfyStreamClient] Using multiprocessing context: {ctx.get_start_method()}") + + manager = ctx.Manager() + logger.info("[ComfyStreamClient] Created multiprocessing context and manager") + + self.image_inputs = manager.Queue(maxsize=50) + self.image_outputs = manager.Queue(maxsize=50) + self.audio_inputs = manager.Queue(maxsize=50) + self.audio_outputs = manager.Queue(maxsize=50) + logger.info("[ComfyStreamClient] Created manager queues") + + logger.info("[ComfyStreamClient] About to create ProcessPoolExecutor...") + + executor = ProcessPoolExecutor( + 
max_workers=max_workers, + initializer=init_tensor_cache, + initargs=(self.image_inputs, self.image_outputs, self.audio_inputs, self.audio_outputs, kwargs.get('cwd')) + ) + logger.info("[ComfyStreamClient] ProcessPoolExecutor created successfully") + + # Create EmbeddedComfyClient with the executor + logger.info("[ComfyStreamClient] Creating EmbeddedComfyClient with executor") + self.comfy_client = EmbeddedComfyClient(self.config, executor=executor) + logger.info("[ComfyStreamClient] EmbeddedComfyClient created successfully") + + # Submit a test task to ensure worker processes are initialized + logger.info("[ComfyStreamClient] Testing worker process initialization...") + test_future = executor.submit(_test_worker_init) + try: + worker_pid = test_future.result(timeout=30) # 30 second timeout + logger.info(f"[ComfyStreamClient] Worker process initialized successfully (PID: {worker_pid})") + except Exception as e: + logger.info(f"[ComfyStreamClient] Error initializing worker process: {str(e)}") + raise + + except Exception as e: + logger.info(f"[ComfyStreamClient] Error during initialization: {str(e)}") + logger.info(f"[ComfyStreamClient] Error type: {type(e)}") + import traceback + logger.info(f"[ComfyStreamClient] Error traceback: {traceback.format_exc()}") + raise + + self.running_prompts = {} self.current_prompts = [] - self._cleanup_lock = asyncio.Lock() - self._prompt_update_lock = asyncio.Lock() + self.cleanup_lock = asyncio.Lock() + self.max_workers = max_workers + self.worker_tasks = [] + self.next_worker = 0 + self.distribution_lock = asyncio.Lock() + self.shutting_down = False # Add shutdown flag + self.distribution_task = None # Track distribution task + logger.info("[ComfyStreamClient] Initialized successfully") async def set_prompts(self, prompts: List[PromptDictInput]): - await self.cancel_running_prompts() + logger.info("set_prompts start") self.current_prompts = [convert_prompt(prompt) for prompt in prompts] - for idx in 
range(len(self.current_prompts)): - task = asyncio.create_task(self.run_prompt(idx)) - self.running_prompts[idx] = task + + # Start the distribution manager only if not already running + if self.distribution_task is None or self.distribution_task.done(): + self.shutting_down = False # Reset shutdown flag + self.distribution_task = asyncio.create_task(self.distribute_frames()) + self.running_prompts[-1] = self.distribution_task # Use -1 as a special key for the manager + logger.info("set_prompts end") + + async def distribute_frames(self): + """Manager that distributes frames across workers in round-robin fashion""" + logger.info(f"[ComfyStreamClient] Starting frame distribution manager") + + try: + # Initialize worker tasks + self.worker_tasks = [] + for worker_id in range(self.max_workers): + worker_task = asyncio.create_task(self.worker_loop(worker_id)) + self.worker_tasks.append(worker_task) + self.running_prompts[worker_id] = worker_task + + # Keep the manager running to monitor workers + while not self.shutting_down: + await asyncio.sleep(1.0) # Check periodically + + # Only restart crashed workers if we're not shutting down + if not self.shutting_down: + for worker_id, task in enumerate(self.worker_tasks): + if task.done(): + # Check if the task was cancelled (graceful shutdown) or crashed + if task.cancelled(): + logger.info(f"Worker {worker_id} was cancelled (graceful shutdown)") + else: + # Check if there was an exception + try: + task.result() + logger.info(f"Worker {worker_id} completed normally") + except Exception as e: + logger.warning(f"Worker {worker_id} crashed with error: {e}, restarting") + new_task = asyncio.create_task(self.worker_loop(worker_id)) + self.worker_tasks[worker_id] = new_task + self.running_prompts[worker_id] = new_task + + except asyncio.CancelledError: + logger.info("[ComfyStreamClient] Distribution manager cancelled") + except Exception as e: + logger.error(f"[ComfyStreamClient] Error in distribution manager: {e}") + finally: + 
logger.info("[ComfyStreamClient] Distribution manager stopped") - async def update_prompts(self, prompts: List[PromptDictInput]): - async with self._prompt_update_lock: - # TODO: currently under the assumption that only already running prompts are updated - if len(prompts) != len(self.current_prompts): - raise ValueError( - "Number of updated prompts must match the number of currently running prompts." - ) - # Validation step before updating the prompt, only meant for a single prompt for now - for idx, prompt in enumerate(prompts): - converted_prompt = convert_prompt(prompt) + async def worker_loop(self, worker_id: int): + """Simple worker loop - just process frames continuously""" + logger.info(f"[Worker {worker_id}] Started") + + frame_count = 0 + try: + while not self.shutting_down: try: - await self.comfy_client.queue_prompt(converted_prompt) - self.current_prompts[idx] = converted_prompt + # Simple round-robin prompt selection + prompt_index = worker_id % len(self.current_prompts) + current_prompt = self.current_prompts[prompt_index] + + # Just process the prompt + await self.comfy_client.queue_prompt(current_prompt) + frame_count += 1 + + except asyncio.CancelledError: + break except Exception as e: - raise Exception(f"Prompt update failed: {str(e)}") from e + if self.shutting_down: + break + logger.error(f"[Worker {worker_id}] Error: {e}") + await asyncio.sleep(0.1) + finally: + logger.info(f"[Worker {worker_id}] Processed {frame_count} frames") - async def run_prompt(self, prompt_index: int): - while True: - async with self._prompt_update_lock: + async def cleanup(self): + async with self.cleanup_lock: + logger.info("[ComfyStreamClient] Starting cleanup...") + + # Set shutdown flag to stop workers gracefully + self.shutting_down = True + + # Cancel distribution task first + if self.distribution_task and not self.distribution_task.done(): + self.distribution_task.cancel() try: - await self.comfy_client.queue_prompt(self.current_prompts[prompt_index]) + await 
self.distribution_task except asyncio.CancelledError: - raise + pass + + # Cancel all worker tasks + for task in self.worker_tasks: + if not task.done(): + task.cancel() + + # Wait for tasks to complete cancellation + if self.worker_tasks: + try: + await asyncio.gather(*self.worker_tasks, return_exceptions=True) + logger.info("[ComfyStreamClient] All worker tasks stopped") except Exception as e: - await self.cleanup() - logger.error(f"Error running prompt: {str(e)}") - raise - - async def cleanup(self): - await self.cancel_running_prompts() - async with self._cleanup_lock: - if self.comfy_client.is_running: + logger.error(f"Error waiting for worker tasks: {e}") + + # Clear the tasks list + self.worker_tasks.clear() + self.running_prompts.clear() + + # Cleanup the ComfyUI client and its executor + if hasattr(self, 'comfy_client') and self.comfy_client.is_running: try: - await self.comfy_client.__aexit__() + # Get the executor before closing the client + executor = getattr(self.comfy_client, 'executor', None) + + # Close the client first + await self.comfy_client.__aexit__(None, None, None) + logger.info("[ComfyStreamClient] ComfyUI client closed") + + # Then shutdown the executor and terminate processes + if executor: + logger.info("[ComfyStreamClient] Shutting down executor...") + # Shutdown the executor + executor.shutdown(wait=False) + + # Force terminate any remaining processes + if hasattr(executor, '_processes'): + for process in executor._processes: + if process.is_alive(): + logger.info(f"[ComfyStreamClient] Terminating worker process {process.pid}") + process.terminate() + # Give it a moment to terminate gracefully + try: + process.join(timeout=2.0) + except: + pass + # Force kill if still alive + if process.is_alive(): + logger.warning(f"[ComfyStreamClient] Force killing worker process {process.pid}") + process.kill() + + logger.info("[ComfyStreamClient] Executor shutdown completed") + except Exception as e: logger.error(f"Error during ComfyClient cleanup: 
{e}") await self.cleanup_queues() - logger.info("Client cleanup complete") - - async def cancel_running_prompts(self): - async with self._cleanup_lock: - tasks_to_cancel = list(self.running_prompts.values()) - for task in tasks_to_cancel: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - self.running_prompts.clear() + + # Reset state for potential reuse + self.shutting_down = False + self.distribution_task = None + + logger.info("[ComfyStreamClient] Client cleanup complete") - async def cleanup_queues(self): - while not tensor_cache.image_inputs.empty(): - tensor_cache.image_inputs.get() - - while not tensor_cache.audio_inputs.empty(): - tensor_cache.audio_inputs.get() + # TODO: add for audio as well + while not self.image_inputs.empty(): + self.image_inputs.get() - while not tensor_cache.image_outputs.empty(): - await tensor_cache.image_outputs.get() - - while not tensor_cache.audio_outputs.empty(): - await tensor_cache.audio_outputs.get() + while not self.image_outputs.empty(): + self.image_outputs.get() def put_video_input(self, frame): - if tensor_cache.image_inputs.full(): - tensor_cache.image_inputs.get(block=True) - tensor_cache.image_inputs.put(frame) - + try: + # Check if frame is FrameProxy + if isinstance(frame, FrameProxy): + proxy = frame + else: + proxy = FrameProxy.avframe_to_frameproxy(frame) + + # Handle queue being full + if self.image_inputs.full(): + # logger.warning(f"[ComfyStreamClient] Input queue full, dropping oldest frame") + try: + self.image_inputs.get_nowait() + except Exception: + pass + + self.image_inputs.put_nowait(proxy) + # logger.info(f"[ComfyStreamClient] Video input queued. 
Queue size: {self.image_inputs.qsize()}") + except Exception as e: + logger.error(f"[ComfyStreamClient] Error putting video frame: {str(e)}") + def put_audio_input(self, frame): - tensor_cache.audio_inputs.put(frame) + self.audio_inputs.put(frame) async def get_video_output(self): - return await tensor_cache.image_outputs.get() + try: + logger.debug(f"[ComfyStreamClient] get_video_output called - PID: {os.getpid()}") + tensor = await asyncio.wait_for( + asyncio.get_event_loop().run_in_executor(None, self.image_outputs.get), + timeout=5.0 + ) + # logger.info(f"[ComfyStreamClient] get_video_output returning tensor: {tensor.shape} - PID: {os.getpid()}") + return tensor + except asyncio.TimeoutError: + logger.warning(f"[ComfyStreamClient] get_video_output timeout - PID: {os.getpid()}") + return torch.zeros((1, 3, self.height, self.width), dtype=torch.float32) + except Exception as e: + logger.error(f"[ComfyStreamClient] Error getting video output: {str(e)} - PID: {os.getpid()}") + return torch.zeros((1, 3, self.height, self.width), dtype=torch.float32) async def get_audio_output(self): - return await tensor_cache.audio_outputs.get() - + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, self.audio_outputs.get) + async def get_available_nodes(self): - """Get metadata and available nodes info in a single pass""" - # TODO: make it for for multiple prompts - if not self.running_prompts: + """Get metadata and available nodes info using cached nodes to avoid ProcessPoolExecutor conflicts""" + if not self.current_prompts: return {} - try: - from comfy.nodes.package import import_all_nodes_in_workspace - nodes = import_all_nodes_in_workspace() + # Use cached nodes instead of calling import_all_nodes_in_workspace from worker process + if self._nodes is None: + logger.warning("[ComfyStreamClient] Nodes cache not available, returning empty result") + return {} + try: all_prompts_nodes_info = {} for prompt_index, prompt in enumerate(self.current_prompts): - 
# Get set of class types we need metadata for, excluding LoadTensor and SaveTensor + # Get set of class types we need metadata for needed_class_types = { node.get('class_type') for node in prompt.values() @@ -132,7 +342,7 @@ async def get_available_nodes(self): nodes_info = {} # Only process nodes until we've found all the ones we need - for class_type, node_class in nodes.NODE_CLASS_MAPPINGS.items(): + for class_type, node_class in self._nodes.NODE_CLASS_MAPPINGS.items(): if not remaining_nodes: # Exit early if we've found all needed nodes break @@ -232,3 +442,61 @@ async def get_available_nodes(self): except Exception as e: logger.error(f"Error getting node info: {str(e)}") return {} + + def register_tensorrt_paths_main_process(self, workspace_path): + """Register TensorRT paths in the main process for validation""" + try: + from comfy.cmd import folder_paths + + if workspace_path: + base_dir = workspace_path + tensorrt_models_dir = os.path.join(base_dir, "models", "tensorrt") + tensorrt_outputs_dir = os.path.join(base_dir, "outputs", "tensorrt") + else: + tensorrt_models_dir = os.path.join(folder_paths.models_dir, "tensorrt") + tensorrt_outputs_dir = os.path.join(folder_paths.models_dir, "outputs", "tensorrt") + + # logger.info(f"[ComfyStreamClient] Registering TensorRT paths in main process") + # logger.info(f"[ComfyStreamClient] TensorRT models dir: {tensorrt_models_dir}") + # logger.info(f"[ComfyStreamClient] TensorRT outputs dir: {tensorrt_outputs_dir}") + + # Register TensorRT paths + if "tensorrt" in folder_paths.folder_names_and_paths: + existing_paths = folder_paths.folder_names_and_paths["tensorrt"][0] + for path in [tensorrt_models_dir, tensorrt_outputs_dir]: + if path not in existing_paths: + existing_paths.append(path) + folder_paths.folder_names_and_paths["tensorrt"][1].add(".engine") + else: + folder_paths.folder_names_and_paths["tensorrt"] = ( + [tensorrt_models_dir, tensorrt_outputs_dir], + {".engine"} + ) + + # Verify registration + # 
available_files = folder_paths.get_filename_list("tensorrt") + # logger.info(f"[ComfyStreamClient] Main process TensorRT files: {available_files}") + + except Exception as e: + logger.error(f"[ComfyStreamClient] Error registering TensorRT paths in main process: {e}") + import traceback + logger.error(f"[ComfyStreamClient] Traceback: {traceback.format_exc()}") + + async def update_prompts(self, prompts: List[PromptDictInput]): + """Update the existing processing prompts without restarting workers.""" + + # Simply update the current prompts - worker loops will pick up changes on next iteration + self.current_prompts = [convert_prompt(prompt) for prompt in prompts] + + logger.info("[ComfyStreamClient] Prompts updated") + + def _initialize_nodes_cache(self): + """Initialize nodes cache in main process to avoid ProcessPoolExecutor conflicts""" + try: + logger.info("[ComfyStreamClient] Initializing nodes cache in main process...") + from comfy.nodes.package import import_all_nodes_in_workspace + self._nodes = import_all_nodes_in_workspace() + logger.info(f"[ComfyStreamClient] Cached {len(self._nodes.NODE_CLASS_MAPPINGS)} node types") + except Exception as e: + logger.error(f"[ComfyStreamClient] Error initializing nodes cache: {e}") + self._nodes = None \ No newline at end of file diff --git a/src/comfystream/client_api.py b/src/comfystream/client_api.py deleted file mode 100644 index 61e57102..00000000 --- a/src/comfystream/client_api.py +++ /dev/null @@ -1,886 +0,0 @@ -import asyncio -import json -import uuid -import websockets -import base64 -import aiohttp -import logging -import torch -import numpy as np -from io import BytesIO -from PIL import Image -from typing import List, Dict, Any, Optional, Union -import random -import time -import subprocess -import os -import socket - -from comfystream import tensor_cache -from comfystream.utils_api import convert_prompt -from torchvision.transforms.functional import to_pil_image - -logger = logging.getLogger(__name__) - 
-class ComfyStreamClient: - def __init__( - self, - host: str = "127.0.0.1", - port: int = 8188, - spawn: bool = False, - comfyui_path: str = None, - comfyui_args: list = None, - workspace: str = None, - comfyui_log_level: str = None, - ): - """ - Initialize the ComfyStream client to use the ComfyUI API. - - Args: - host: The hostname or IP address of the ComfyUI server - port: The port number of the ComfyUI server - spawn: If True, launch a ComfyUI server when start_server is called - comfyui_path: Path to the ComfyUI main.py file (required if spawn=True) - comfyui_args: Additional arguments for ComfyUI - workspace: The workspace directory for ComfyUI - comfyui_log_level: The logging level for ComfyUI - """ - self.host = host - self.port = port - self.spawn = spawn - self.comfyui_path = comfyui_path - self.comfyui_args = comfyui_args or [] - self.workspace = workspace - self._comfyui_proc = None - - # Server launch is deferred to start_server method - - self.server_address = f"ws://{host}:{port}/ws" - self.api_base_url = f"http://{host}:{port}/api" - self.client_id = str(uuid.uuid4()) - self.ws = None - self.current_prompts = [] - self.running_prompts = {} - self.cleanup_lock = asyncio.Lock() - self.buffer = BytesIO() - self.execution_complete_event = asyncio.Event() - - self._ws_listener_task = None - self._prompt_id = None - self._current_frame_id = None # Track the current frame being processed - self._frame_id_mapping = {} # Map prompt_ids to frame_ids - - self.comfyui_log_level = comfyui_log_level - - logger.info(f"[Client[{self.port}]: ComfyStreamClient initialized with host: {host}, port: {port}, client_id: {self.client_id}") - - async def set_prompts(self, prompts: List[Dict]): - """Set prompts and run them (compatible with original interface)""" - # Convert prompts (this already randomizes seeds, but we'll enhance it) - self.current_prompts = [convert_prompt(prompt) for prompt in prompts] - - # Create tasks for each prompt - for idx in 
range(len(self.current_prompts)): - task = asyncio.create_task(self.run_prompt(idx)) - self.running_prompts[idx] = task - - logger.info(f"[Client[{self.port}]: Set {len(self.current_prompts)} prompts for execution") - - async def update_prompts(self, prompts: List[Dict]): - """Update existing prompts (compatible with original interface)""" - if len(prompts) != len(self.current_prompts): - raise ValueError( - "Number of updated prompts must match the number of currently running prompts." - ) - self.current_prompts = [convert_prompt(prompt) for prompt in prompts] - logger.info(f"[Client[{self.port}]: Updated {len(self.current_prompts)} prompts") - - async def run_prompt(self, prompt_index: int): - """Run a prompt continuously, processing new frames as they arrive""" - logger.info(f"[Client[{self.port}]: Running prompt {prompt_index}") - - # Make sure WebSocket is connected - await self._connect_websocket() - - # Always set execution complete at start to allow first frame to be processed - self.execution_complete_event.set() - - try: - while True: - # Wait until we have tensor data available before sending prompt - if tensor_cache.image_inputs.empty(): - await asyncio.sleep(0.01) # Reduced sleep time for faster checking - continue - - # Clear event before sending a new prompt - if self.execution_complete_event.is_set(): - # Reset execution state for next frame - self.execution_complete_event.clear() - - # Queue the prompt with the current frame - await self._execute_prompt(prompt_index) - - # Wait for execution completion with timeout - try: - logger.debug("Waiting for execution to complete (max 10 seconds)...") - await asyncio.wait_for(self.execution_complete_event.wait(), timeout=10.0) - logger.debug("Execution complete, ready for next frame") - except asyncio.TimeoutError: - logger.error("Timeout waiting for execution, forcing continuation") - self.execution_complete_event.set() - else: - # If execution is not complete, check again shortly - await 
asyncio.sleep(0.01) # Short sleep to prevent CPU spinning - - except asyncio.CancelledError: - logger.info(f"[Client[{self.port}]: Prompt {prompt_index} execution cancelled") - raise - except Exception as e: - logger.error(f"Error in run_prompt: {str(e)}") - raise - - async def _connect_websocket(self): - """Connect to the ComfyUI WebSocket endpoint""" - try: - if self.ws is not None and self.ws.open: - return self.ws - - # Close existing connection if any - if self.ws is not None: - try: - await self.ws.close() - except: - pass - self.ws = None - - logger.info(f"[Client[{self.port}]: Connecting to WebSocket at {self.server_address}?clientId={self.client_id}") - - try: - # Connect with proper error handling - self.ws = await websockets.connect( - f"{self.server_address}?clientId={self.client_id}", - ping_interval=5, - ping_timeout=10, - close_timeout=5, - max_size=None, # No limit on message size - ssl=None - ) - - logger.info(f"[Client[{self.port}]: WebSocket connected successfully") - - # Start the listener task if not already running - if self._ws_listener_task is None or self._ws_listener_task.done(): - self._ws_listener_task = asyncio.create_task(self._ws_listener()) - logger.info(f"[Client[{self.port}]: Started WebSocket listener task") - - return self.ws - - except (websockets.exceptions.WebSocketException, ConnectionError, OSError) as e: - logger.error(f"WebSocket connection error: {e}") - self.ws = None - # Signal execution complete to prevent hanging if connection fails - self.execution_complete_event.set() - # Retry after a delay - await asyncio.sleep(1) - return await self._connect_websocket() - - except Exception as e: - logger.error(f"Unexpected error in _connect_websocket: {e}") - self.ws = None - # Signal execution complete to prevent hanging - self.execution_complete_event.set() - return None - - async def _ws_listener(self): - """Listen for WebSocket messages and process them""" - try: - logger.info(f"[Client[{self.port}]: WebSocket listener 
started") - while True: - if self.ws is None: - try: - await self._connect_websocket() - except Exception as e: - logger.error(f"Error connecting to WebSocket: {e}") - await asyncio.sleep(1) - continue - - try: - # Receive and process messages - message = await self.ws.recv() - - if isinstance(message, str): - # Process JSON messages - await self._handle_text_message(message) - else: - # Handle binary data - likely image preview or tensor data - await self._handle_binary_message(message) - - except websockets.exceptions.ConnectionClosed: - logger.info(f"[Client[{self.port}]: WebSocket connection closed") - self.ws = None - await asyncio.sleep(1) - except Exception as e: - logger.error(f"Error in WebSocket listener: {e}") - await asyncio.sleep(1) - - except asyncio.CancelledError: - logger.info(f"[Client[{self.port}]: WebSocket listener cancelled") - raise - except Exception as e: - logger.error(f"Unexpected error in WebSocket listener: {e}") - - async def _handle_text_message(self, message: str): - """Process text (JSON) messages from the WebSocket""" - try: - data = json.loads(message) - message_type = data.get("type", "unknown") - - # logger.info(f"[Client[{self.port}]: Received message type: {message_type}") - # logger.debug(f"{data}") - - # Example output - ''' - 15:15:58 [INFO] Received message type: executing - 15:15:58 [INFO] {'type': 'executing', 'data': {'node': '18', 'display_node': '18', 'prompt_id': '6f983049-dca4-4935-9f36-d2bff7b744fa'}} - 15:15:58 [INFO] Received message type: executed - 15:15:58 [INFO] {'type': 'executed', 'data': {'node': '18', 'display_node': '18', 'output': {'images': [{'source': 'websocket', 'content-type': 'image/png', 'type': 'output'}]}, 'prompt_id': '6f983049-dca4-4935-9f36-d2bff7b744fa'}} - 15:15:58 [INFO] Received message type: execution_success - 15:15:58 [INFO] {'type': 'execution_success', 'data': {'prompt_id': '6f983049-dca4-4935-9f36-d2bff7b744fa', 'timestamp': 1744139758250}} - ''' - - # Handle different message 
types to have fun with! - - ''' - if message_type == "status": - # Status message with comfy_ui's queue information - queue_remaining = data.get("data", {}).get("queue_remaining", 0) - exec_info = data.get("data", {}).get("exec_info", {}) - if queue_remaining == 0 and not exec_info: - logger.info("Queue empty, no active execution") - else: - logger.info(f"Queue status: {queue_remaining} items remaining") - ''' - - ''' - if message_type == "progress": - if "data" in data and "value" in data["data"]: - progress = data["data"]["value"] - max_value = data["data"].get("max", 100) - # Log the progress for debugging - logger.info(f"Progress: {progress}/{max_value}") - ''' - - if message_type == "execution_start": - if "data" in data and "prompt_id" in data["data"]: - self._prompt_id = data["data"]["prompt_id"] - logger.debug(f"[Client[{self.port}]: Execution started for prompt {self._prompt_id}") - - # Let's queue the next prompt here! - self.execution_complete_event.set() - - ''' - if message_type == "executing": - if "data" in data: - if "prompt_id" in data["data"]: - self._prompt_id = data["data"]["prompt_id"] - if "node" in data["data"]: - node_id = data["data"]["node"] - logger.info(f"Executing node: {node_id}") - - # Let's check which node_id is a LoadImageBase64 node - # and set the execution complete event for that node - for prompt_index, prompt in enumerate(self.current_prompts): - for node_id, node in prompt.items(): - if (node_id == executing_node_id and isinstance(node, dict) and node.get("class_type") in ["LoadImageBase64"]): - logger.info(f"Setting execution complete event for LoadImageBase64 node {node_id}") - self.execution_complete_event.set() - break - ''' - - ''' - if message_type == "executed": - # This is sent when a node is completely done - if "data" in data and "node" in data["data"]: - node_id = data["data"]["node"] - logger.info(f"Node execution complete: {node_id}") - ''' - - ''' - if message_type in ["execution_cached", "execution_error", 
"execution_complete", "execution_interrupted"]: - logger.info(f"{message_type} message received for prompt {self._prompt_id}") - # Always signal completion for these terminal states - # self.execution_complete_event.set() - logger.info(f"Set execution_complete_event from {message_type}") - pass - ''' - - except json.JSONDecodeError: - logger.error(f"Invalid JSON message: {message[:100]}...") - except Exception as e: - logger.error(f"Error handling WebSocket message: {e}") - # Signal completion on error to prevent hanging - self.execution_complete_event.set() - - async def _handle_binary_message(self, binary_data): - """Process binary messages from the WebSocket""" - try: - # Early return if message is too short - if len(binary_data) <= 8: - # self.execution_complete_event.set() - return - - # Extract header data only when needed - event_type = int.from_bytes(binary_data[:4], byteorder='little') - format_type = int.from_bytes(binary_data[4:8], byteorder='little') - data = binary_data[8:] - - # Quick check for image format - is_image = data[:2] in [b'\xff\xd8', b'\x89\x50'] - if not is_image: - # self.execution_complete_event.set() - return - - # Process image data directly - try: - img = Image.open(BytesIO(data)) - if img.mode != "RGB": - img = img.convert("RGB") - - with torch.no_grad(): - tensor = torch.from_numpy(np.array(img)).float().permute(2, 0, 1).unsqueeze(0) / 255.0 - - # Try to get frame_id from mapping using current prompt_id - frame_id = None - if hasattr(self, '_prompt_id') and self._prompt_id in self._frame_id_mapping: - frame_id = self._frame_id_mapping.get(self._prompt_id) - logger.debug(f"Using frame_id {frame_id} from prompt_id {self._prompt_id}") - elif hasattr(self, '_current_frame_id') and self._current_frame_id is not None: - frame_id = self._current_frame_id - logger.debug(f"Using current frame_id {frame_id}") - - tensor_cache.image_outputs.put_nowait((frame_id, tensor)) - logger.debug(f"Added tensor with frame_id {frame_id} to output queue") 
- - # We will execute the next prompt from message_type == "execution_start" instead - # self.execution_complete_event.set() - - except Exception as img_error: - logger.error(f"Error processing image: {img_error}") - # self.execution_complete_event.set() - - except Exception as e: - logger.error(f"Error handling binary message: {e}") - # self.execution_complete_event.set() - - async def _execute_prompt(self, prompt_index: int): - try: - # Get the prompt to execute - prompt = self.current_prompts[prompt_index] - - # Check if we have a frame waiting to be processed - if not tensor_cache.image_inputs.empty(): - # Get the most recent frame only - frame = None - while not tensor_cache.image_inputs.empty(): - frame = tensor_cache.image_inputs.get_nowait() - - self._current_frame_id = getattr(frame.side_data, 'frame_id', None) - - if self._current_frame_id is None: - logger.error("No frame_id found in side_data") - self.execution_complete_event.set() - return - - if not (hasattr(frame, 'side_data') and hasattr(frame.side_data, 'input')): - logger.error( - "Frame object from queue ('tensor_cache.image_inputs') is not structured as " - "expected (missing side_data.input). Skipping processing for this frame." - ) - self.execution_complete_event.set() # Allow next cycle - return - - # Find LoadImageBase64 nodes first - load_image_nodes = [] - for node_id, node in prompt.items(): - if isinstance(node, dict) and node.get("class_type") in ["LoadImageBase64"]: - load_image_nodes.append(node_id) - - if not load_image_nodes: - logger.warning("No LoadImageBase64 nodes found in the prompt") - self.execution_complete_event.set() - return - - # Process the tensor ONLY if we have nodes to send it to - try: - tensor = getattr(frame.side_data, 'input', None) - - if tensor is None: - logger.error("No tensor found in side_data") - self.execution_complete_event.set() - return - - try: - # TODO: Why is the UI sending different sizes? Should be fixed no? 
This breaks tensorrt - # I'm sometimes seeing (BCHW): torch.Size([1, 384, 384, 3]), H=384, W=3 - # Ensure minimum size of 512x512 - - # Handle batch dimension if present - if len(tensor.shape) == 4: # BCHW format - tensor = tensor[0] # Take first image from batch - - # Normalize to CHW format consistently - if len(tensor.shape) == 3 and tensor.shape[2] == 3: # HWC format - tensor = tensor.permute(2, 0, 1) # Convert to CHW - - # Handle single-channel case - if len(tensor.shape) == 3 and tensor.shape[0] == 1: - tensor = tensor.repeat(3, 1, 1) # Convert grayscale to RGB - - # Ensure tensor is on CPU - if tensor.is_cuda: - tensor = tensor.cpu() - - # Always resize to 512x512 for consistency (faster than checking dimensions first) - tensor = tensor.unsqueeze(0) # Add batch dim for interpolate - tensor = torch.nn.functional.interpolate( - tensor, size=(512, 512), mode='bilinear', align_corners=False - ) - tensor = tensor[0] # Remove batch dimension - - # ==== - # PIL method - ''' - # Direct conversion to PIL without intermediate numpy step for speed - tensor_np = (tensor.permute(1, 2, 0).clamp(0, 1) * 255).to(torch.uint8).numpy() - img = Image.fromarray(tensor_np) - img.save(self.buffer, format="JPEG", quality=90, optimize=True) - ''' - - # ==== - # torchvision method (more performant - TODO: need to test further) - # Direct conversion to PIL without intermediate numpy step - # Fast JPEG encoding with reduced quality for better performance - tensor_pil = to_pil_image(tensor.clamp(0, 1)) - tensor_pil.save(self.buffer, format="JPEG", quality=75, optimize=True) - # ==== - - self.buffer.seek(0) - img_base64 = base64.b64encode(self.buffer.getvalue()).decode('utf-8') - - except Exception as e: - logger.warning(f"Error in tensor processing: {e}, creating fallback image") - # Create a standard 512x512 placeholder if anything fails - img = Image.new('RGB', (512, 512), color=(100, 149, 237)) - self.buffer = BytesIO() - img.save(self.buffer, format="JPEG", quality=90) - 
self.buffer.seek(0) - img_base64 = base64.b64encode(self.buffer.getvalue()).decode('utf-8') - - # Add timestamp for cache busting (once, outside the try/except) - timestamp = int(time.time() * 1000) - - # Update all nodes with the SAME base64 string - for node_id in load_image_nodes: - prompt[node_id]["inputs"]["image"] = img_base64 - prompt[node_id]["inputs"]["_timestamp"] = timestamp - # Use timestamp as cache buster - prompt[node_id]["inputs"]["_cache_buster"] = str(timestamp) - - except Exception as e: - logger.error(f"Error converting tensor to base64: {e}") - self.execution_complete_event.set() - return - - # Execute the prompt via API - async with aiohttp.ClientSession() as session: - api_url = f"{self.api_base_url}/prompt" - payload = { - "prompt": prompt, - "client_id": self.client_id - } - - async with session.post(api_url, json=payload) as response: - if response.status == 200: - result = await response.json() - self._prompt_id = result.get("prompt_id") - - self._frame_id_mapping[self._prompt_id] = self._current_frame_id - # logger.debug(f"Mapped prompt_id {self._prompt_id} to frame_id {self._current_frame_id}") - - else: - error_text = await response.text() - logger.error(f"Error queueing prompt: {response.status} - {error_text}") - self.execution_complete_event.set() - else: - logger.debug("No tensor in input queue, skipping prompt execution") - self.execution_complete_event.set() - - except Exception as e: - logger.error(f"Error executing prompt: {e}") - self.execution_complete_event.set() - - async def cleanup(self): - """Clean up resources and reset connection state completely.""" - logger.info(f"[Client[{self.port}]: Performing client cleanup and connection reset") - - # Cancel the WebSocket listener task - if self._ws_listener_task is not None and not self._ws_listener_task.done(): - self._ws_listener_task.cancel() - try: - await self._ws_listener_task - except asyncio.CancelledError: - pass - self._ws_listener_task = None - - # Close WebSocket 
connection - if self.ws is not None: - try: - await self.ws.close() - except Exception as e: - logger.error(f"Error closing WebSocket: {e}") - finally: - self.ws = None - - # Reset all state variables - self._prompt_id = None - self._current_frame_id = None - self._frame_id_mapping = {} - self.current_prompts = [] - - # Cancel any running prompt tasks - for task in self.running_prompts.values(): - if not task.done(): - task.cancel() - self.running_prompts = {} - - # Reset the execution event - self.execution_complete_event.set() - - # Clean up queues - await self.cleanup_queues() - - # Reset buffer - self.buffer = BytesIO() - - logger.info(f"[Client[{self.port}]: Client cleanup completed, connection will be reestablished on next use") - - async def cleanup_queues(self): - """Clean up tensor queues""" - while not tensor_cache.image_inputs.empty(): - tensor_cache.image_inputs.get() - - while not tensor_cache.audio_inputs.empty(): - tensor_cache.audio_inputs.get() - - while tensor_cache.image_outputs.qsize() > 0: - try: - await tensor_cache.image_outputs.get() - except: - pass - - while tensor_cache.audio_outputs.qsize() > 0: - try: - await tensor_cache.audio_outputs.get() - except: - pass - - logger.info(f"[Client[{self.port}]: Tensor queues cleared") - - def put_video_input(self, frame): - if tensor_cache.image_inputs.full(): - tensor_cache.image_inputs.get(block=True) - tensor_cache.image_inputs.put(frame) - - def put_audio_input(self, frame): - """Put audio frame into tensor cache""" - tensor_cache.audio_inputs.put(frame) - - async def get_video_output(self): - """Get processed video frame from tensor cache""" - frame_id, tensor = await tensor_cache.image_outputs.get() - logger.debug(f"[Client[{self.port}]: Got processed tensor from output queue with frame_id {frame_id}") - # Return both the frame_id and tensor to help with ordering in the pipeline - return frame_id, tensor - - async def get_audio_output(self): - """Get processed audio frame from tensor cache""" - 
return await tensor_cache.audio_outputs.get() - - async def get_available_nodes(self) -> Dict[int, Dict[str, Any]]: - """ - Retrieves detailed information about the nodes used in the current prompts - by querying the ComfyUI /object_info API endpoint. - - Returns: - A dictionary where keys are prompt indices and values are dictionaries - mapping node IDs to their information, matching the required UI format. - - The idea of this function is to replicate the functionality of comfy embedded client import_all_nodes_in_workspace - TODO: Why not support ckpt_name and lora_name as dropdown selectors on UI? - """ - - if not self.current_prompts: - logger.warning("No current prompts set. Cannot get node info.") - return {} - - all_prompts_nodes_info: Dict[int, Dict[str, Any]] = {} - all_needed_class_types = set() - - # Collect all unique class types across all prompts first - for prompt in self.current_prompts: - for node in prompt.values(): - if isinstance(node, dict) and 'class_type' in node: - all_needed_class_types.add(node['class_type']) - - class_info_cache: Dict[str, Any] = {} - - async with aiohttp.ClientSession() as session: - fetch_tasks = [] - for class_type in all_needed_class_types: - api_url = f"{self.api_base_url}/object_info/{class_type}" - fetch_tasks.append(self._fetch_object_info(session, api_url, class_type)) - - results = await asyncio.gather(*fetch_tasks, return_exceptions=True) - - # Populate cache from results - for result in results: - if isinstance(result, tuple) and len(result) == 2: - class_type, info = result - if info: - class_info_cache[class_type] = info - elif isinstance(result, Exception): - logger.error(f"An exception occurred during object_info fetch task: {result}") - - # Now, build the output structure for each prompt - for prompt_index, prompt in enumerate(self.current_prompts): - nodes_info: Dict[str, Any] = {} - for node_id, node_data in prompt.items(): - if not isinstance(node_data, dict) or 'class_type' not in node_data: - 
logger.debug(f"Skipping invalid node data for node_id {node_id} in prompt {prompt_index}") - continue - - class_type = node_data['class_type'] - # Let's skip the native api i/o nodes for now, subject to change - if class_type in ['LoadImageBase64', 'SendImageWebsocket']: - continue - - node_info = { - 'class_type': class_type, - 'inputs': {} - } - - specific_class_info = class_info_cache.get(class_type) - - if specific_class_info and 'input' in specific_class_info: - input_definitions = {} - required_inputs = specific_class_info['input'].get('required', {}) - optional_inputs = specific_class_info['input'].get('optional', {}) - - if isinstance(required_inputs, dict): - input_definitions.update(required_inputs) - if isinstance(optional_inputs, dict): - input_definitions.update(optional_inputs) - - if 'inputs' in node_data and isinstance(node_data['inputs'], dict): - for input_name, input_value in node_data['inputs'].items(): - input_def = input_definitions.get(input_name) - - # Format the input value as a tuple if it's a list with node references - if isinstance(input_value, list) and len(input_value) == 2 and isinstance(input_value[0], str) and isinstance(input_value[1], int): - input_value = tuple(input_value) # Convert [node_id, output_index] to (node_id, output_index) - - # Create Enum-like objects for certain types - def create_enum_format(type_name): - # Format the type as - return f"" - - input_details = { - 'value': input_value, - 'type': 'unknown', # Default type - 'min': None, - 'max': None, - 'widget': None # Default, all widgets should be None to match format - } - - # Parse the definition tuple/list if valid - if isinstance(input_def, (list, tuple)) and len(input_def) > 0: - config = None - # Check for config dict as the second element - if len(input_def) > 1 and isinstance(input_def[1], dict): - config = input_def[1] - - # Check for COMBO type (first element is list/tuple of options) - if input_name in ['ckpt_name', 'lora_name']: - # For checkpoint and 
lora names, use STRING type instead of combo list - input_details['type'] = create_enum_format('STRING') - elif isinstance(input_def[0], (list, tuple)): - input_details['type'] = input_def[0] # Type is the list of options - # Don't set widget for combo - else: - # Regular type (string or enum) - input_type_raw = input_def[0] - # Keep raw type name for certain types to match format - if hasattr(input_type_raw, 'name'): - # Special handling for CLIP and STRING to match expected format - type_name = str(input_type_raw.name) - if type_name in ('CLIP', 'STRING'): - # Create Enum-like format that matches format in desired output - input_details['type'] = create_enum_format(type_name) - else: - input_details['type'] = type_name - else: - # For non-enum types - input_details['type'] = str(input_type_raw) - - # Extract constraints/widget from config if it exists - if config: - for key in ['min', 'max']: # Only include these, skip widget/step/round - if key in config: - input_details[key] = config[key] - - node_info['inputs'][input_name] = input_details - else: - logger.debug(f"Node {node_id} ({class_type}) has no 'inputs' dictionary.") - elif class_type not in class_info_cache: - logger.warning(f"No cached info found for class_type: {class_type} (node_id: {node_id}).") - else: - logger.debug(f"Class info for {class_type} does not contain an 'input' key.") - # If class info exists but no 'input' key, still add node with empty inputs dict - - nodes_info[node_id] = node_info - - # Only add if there are any nodes after filtering - if nodes_info: - all_prompts_nodes_info[prompt_index] = nodes_info - - return all_prompts_nodes_info - - def launch_comfyui_server(self): - """Launch ComfyUI as a subprocess""" - logger.info(f"[Client[{self.port}]: Spawning ComfyUI server...") - - # Start the process - try: - # Build the command with just the basics - cmd = [ - "python", self.comfyui_path, - ] - - # Add the arguments from comfyui_args (retreived from pipeline_api) - ''' - [ - 
"--disable-cuda-malloc", - "--gpu-only", - "--preview-method", "none", - "--listen", - "--cuda-device", str(cuda_device), - "--fast", - "--enable-cors-header", "\"*\"", - "--port", str(port), - "--disable-xformers", - ] - ''' - cmd.extend(self.comfyui_args) - - # Set up environment with proper encoding and ANSI support - env = os.environ.copy() - env.update({ - 'PYTHONIOENCODING': 'utf-8', - 'PYTHONLEGACYWINDOWSSTDIO': 'utf-8', - 'FORCE_COLOR': '1' - }) - - logger.info(f"[Client[{self.port}]: Starting ComfyUI with command: {' '.join(cmd)}") - self._comfyui_proc = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - env=env, - text=True, - encoding='utf-8', - errors='replace', - bufsize=1, # Line buffered - ) - - def log_output(stream): - for line in iter(stream.readline, ''): - try: - if self.comfyui_log_level == "DEBUG": - # Strip ANSI codes if they cause problems - message = line.strip() - # Optional: Remove ANSI codes if they still cause issues - # import re - # message = re.sub(r'\033\[[0-9;]*[mGKH]', '', message) - logger.info(f"ComfyUI[{self.port}]: {message}") - except Exception as e: - logger.error(f"Error logging output: {e}") - - import threading - threading.Thread(target=log_output, args=(self._comfyui_proc.stdout,), daemon=True).start() - threading.Thread(target=log_output, args=(self._comfyui_proc.stderr,), daemon=True).start() - - logger.info(f"[Client[{self.port}]: Started ComfyUI process with PID {self._comfyui_proc.pid}") - except Exception as e: - logger.error(f"Failed to spawn ComfyUI: {e}") - raise - - def wait_for_server_ready(self, timeout=60, check_interval=0.5): - """Wait until the ComfyUI server is accepting connections""" - logger.info(f"[Client[{self.port}]: Waiting for ComfyUI server to be ready...") - - start_time = time.time() - while time.time() - start_time < timeout: - # Check if process is still running - if self._comfyui_proc and self._comfyui_proc.poll() is not None: - return_code = 
self._comfyui_proc.poll() - logger.error(f"ComfyUI process exited with code {return_code} before it was ready") - raise RuntimeError(f"ComfyUI process exited with code {return_code}") - - # Try to connect to the server - try: - with socket.create_connection((self.host, self.port), timeout=2): - logger.info(f"[Client[{self.port}]: ComfyUI server is now accepting connections") - return - except (ConnectionRefusedError, socket.timeout, OSError): - # Sleep and try again - time.sleep(check_interval) - - # If we get here, the server didn't start in time - logger.error(f"[Client[{self.port}]: Timed out waiting for ComfyUI server") - - if self._comfyui_proc: - self._comfyui_proc.terminate() - self._comfyui_proc = None - - raise RuntimeError(f"Timed out waiting for ComfyUI server on port {self.port}") - - async def _fetch_object_info(self, session: aiohttp.ClientSession, url: str, class_type: str) -> Optional[tuple[str, Any]]: - """Helper function to fetch object info for a single class type.""" - try: - logger.debug(f"Fetching object info for: {class_type} from {url}") - async with session.get(url) as response: - if response.status == 200: - try: - data = await response.json() - # Extract the actual node info from the nested structure - if class_type in data and isinstance(data[class_type], dict): - node_specific_info = data[class_type] - logger.debug(f"Successfully fetched and extracted info for {class_type}") - return class_type, node_specific_info - else: - logger.error(f"Unexpected response structure for {class_type}. Key missing or not a dict. Response: {data}") - - except aiohttp.ContentTypeError: - logger.error(f"Failed to decode JSON for {class_type}. Status: {response.status}, Content-Type: {response.headers.get('Content-Type')}, Response: {await response.text()[:200]}...") # Log beginning of text - except json.JSONDecodeError as e: - logger.error(f"Invalid JSON received for {class_type}. 
Status: {response.status}, Error: {e}, Response: {await response.text()[:200]}...") - else: - error_text = await response.text() - logger.error(f"Error fetching info for {class_type}: {response.status} - {error_text[:200]}...") - except aiohttp.ClientError as e: - logger.error(f"HTTP client error fetching info for {class_type} ({url}): {e}") - except Exception as e: - logger.error(f"Unexpected error fetching info for {class_type} ({url}): {e}") - - # Return class_type and None if any error occurred - return class_type, None \ No newline at end of file diff --git a/src/comfystream/client_multi.py b/src/comfystream/client_multi.py deleted file mode 100644 index 2e942c30..00000000 --- a/src/comfystream/client_multi.py +++ /dev/null @@ -1,502 +0,0 @@ -import asyncio -import logging -from typing import List, Union -import multiprocessing as mp -import os -import sys -import numpy as np -import torch -import av - -from comfystream.utils import convert_prompt -from comfystream.tensor_cache_multi import init_tensor_cache - -from comfy.cli_args_types import Configuration -from comfy.distributed.process_pool_executor import ProcessPoolExecutor -from comfy.api.components.schema.prompt import PromptDictInput -from comfy.client.embedded_comfy_client import EmbeddedComfyClient -from comfystream.frame_proxy import FrameProxy - -logger = logging.getLogger(__name__) - -def _test_worker_init(): - """Test function to verify worker process initialization.""" - return os.getpid() - -class ComfyStreamClient: - def __init__(self, - max_workers: int = 1, - **kwargs): - logger.info(f"[ComfyStreamClient] Main Process ID: {os.getpid()}") - logger.info(f"[ComfyStreamClient] __init__ start, max_workers: {max_workers}") - - # Store default dimensions - self.width = kwargs.get('width', 512) - self.height = kwargs.get('height', 512) - - # Ensure workspace path is absolute - if 'cwd' in kwargs: - if not os.path.isabs(kwargs['cwd']): - # Convert relative path to absolute path from current working 
directory - kwargs['cwd'] = os.path.abspath(kwargs['cwd']) - logger.info(f"[ComfyStreamClient] Using absolute workspace path: {kwargs['cwd']}") - - # Register TensorRT paths in main process BEFORE creating ComfyUI client - self.register_tensorrt_paths_main_process(kwargs.get('cwd')) - - # Cache nodes information in main process to avoid ProcessPoolExecutor conflicts - self._initialize_nodes_cache() - - logger.info("[ComfyStreamClient] Config kwargs: %s", kwargs) - - try: - self.config = Configuration(**kwargs) - logger.info("[ComfyStreamClient] Configuration created") - logger.info(f"[ComfyStreamClient] Current working directory: {os.getcwd()}") - - logger.info("[ComfyStreamClient] Initializing process executor") - ctx = mp.get_context("spawn") - logger.info(f"[ComfyStreamClient] Using multiprocessing context: {ctx.get_start_method()}") - - manager = ctx.Manager() - logger.info("[ComfyStreamClient] Created multiprocessing context and manager") - - self.image_inputs = manager.Queue(maxsize=50) - self.image_outputs = manager.Queue(maxsize=50) - self.audio_inputs = manager.Queue(maxsize=50) - self.audio_outputs = manager.Queue(maxsize=50) - logger.info("[ComfyStreamClient] Created manager queues") - - logger.info("[ComfyStreamClient] About to create ProcessPoolExecutor...") - - executor = ProcessPoolExecutor( - max_workers=max_workers, - initializer=init_tensor_cache, - initargs=(self.image_inputs, self.image_outputs, self.audio_inputs, self.audio_outputs, kwargs.get('cwd')) - ) - logger.info("[ComfyStreamClient] ProcessPoolExecutor created successfully") - - # Create EmbeddedComfyClient with the executor - logger.info("[ComfyStreamClient] Creating EmbeddedComfyClient with executor") - self.comfy_client = EmbeddedComfyClient(self.config, executor=executor) - logger.info("[ComfyStreamClient] EmbeddedComfyClient created successfully") - - # Submit a test task to ensure worker processes are initialized - logger.info("[ComfyStreamClient] Testing worker process 
initialization...") - test_future = executor.submit(_test_worker_init) - try: - worker_pid = test_future.result(timeout=30) # 30 second timeout - logger.info(f"[ComfyStreamClient] Worker process initialized successfully (PID: {worker_pid})") - except Exception as e: - logger.info(f"[ComfyStreamClient] Error initializing worker process: {str(e)}") - raise - - except Exception as e: - logger.info(f"[ComfyStreamClient] Error during initialization: {str(e)}") - logger.info(f"[ComfyStreamClient] Error type: {type(e)}") - import traceback - logger.info(f"[ComfyStreamClient] Error traceback: {traceback.format_exc()}") - raise - - self.running_prompts = {} - self.current_prompts = [] - self.cleanup_lock = asyncio.Lock() - self.max_workers = max_workers - self.worker_tasks = [] - self.next_worker = 0 - self.distribution_lock = asyncio.Lock() - self.shutting_down = False # Add shutdown flag - self.distribution_task = None # Track distribution task - logger.info("[ComfyStreamClient] Initialized successfully") - - async def set_prompts(self, prompts: List[PromptDictInput]): - logger.info("set_prompts start") - self.current_prompts = [convert_prompt(prompt) for prompt in prompts] - - # Start the distribution manager only if not already running - if self.distribution_task is None or self.distribution_task.done(): - self.shutting_down = False # Reset shutdown flag - self.distribution_task = asyncio.create_task(self.distribute_frames()) - self.running_prompts[-1] = self.distribution_task # Use -1 as a special key for the manager - logger.info("set_prompts end") - - async def distribute_frames(self): - """Manager that distributes frames across workers in round-robin fashion""" - logger.info(f"[ComfyStreamClient] Starting frame distribution manager") - - try: - # Initialize worker tasks - self.worker_tasks = [] - for worker_id in range(self.max_workers): - worker_task = asyncio.create_task(self.worker_loop(worker_id)) - self.worker_tasks.append(worker_task) - 
self.running_prompts[worker_id] = worker_task - - # Keep the manager running to monitor workers - while not self.shutting_down: - await asyncio.sleep(1.0) # Check periodically - - # Only restart crashed workers if we're not shutting down - if not self.shutting_down: - for worker_id, task in enumerate(self.worker_tasks): - if task.done(): - # Check if the task was cancelled (graceful shutdown) or crashed - if task.cancelled(): - logger.info(f"Worker {worker_id} was cancelled (graceful shutdown)") - else: - # Check if there was an exception - try: - task.result() - logger.info(f"Worker {worker_id} completed normally") - except Exception as e: - logger.warning(f"Worker {worker_id} crashed with error: {e}, restarting") - new_task = asyncio.create_task(self.worker_loop(worker_id)) - self.worker_tasks[worker_id] = new_task - self.running_prompts[worker_id] = new_task - - except asyncio.CancelledError: - logger.info("[ComfyStreamClient] Distribution manager cancelled") - except Exception as e: - logger.error(f"[ComfyStreamClient] Error in distribution manager: {e}") - finally: - logger.info("[ComfyStreamClient] Distribution manager stopped") - - async def worker_loop(self, worker_id: int): - """Simple worker loop - just process frames continuously""" - logger.info(f"[Worker {worker_id}] Started") - - frame_count = 0 - try: - while not self.shutting_down: - try: - # Simple round-robin prompt selection - prompt_index = worker_id % len(self.current_prompts) - current_prompt = self.current_prompts[prompt_index] - - # Just process the prompt - await self.comfy_client.queue_prompt(current_prompt) - frame_count += 1 - - except asyncio.CancelledError: - break - except Exception as e: - if self.shutting_down: - break - logger.error(f"[Worker {worker_id}] Error: {e}") - await asyncio.sleep(0.1) - finally: - logger.info(f"[Worker {worker_id}] Processed {frame_count} frames") - - async def cleanup(self): - async with self.cleanup_lock: - logger.info("[ComfyStreamClient] Starting 
cleanup...") - - # Set shutdown flag to stop workers gracefully - self.shutting_down = True - - # Cancel distribution task first - if self.distribution_task and not self.distribution_task.done(): - self.distribution_task.cancel() - try: - await self.distribution_task - except asyncio.CancelledError: - pass - - # Cancel all worker tasks - for task in self.worker_tasks: - if not task.done(): - task.cancel() - - # Wait for tasks to complete cancellation - if self.worker_tasks: - try: - await asyncio.gather(*self.worker_tasks, return_exceptions=True) - logger.info("[ComfyStreamClient] All worker tasks stopped") - except Exception as e: - logger.error(f"Error waiting for worker tasks: {e}") - - # Clear the tasks list - self.worker_tasks.clear() - self.running_prompts.clear() - - # Cleanup the ComfyUI client and its executor - if hasattr(self, 'comfy_client') and self.comfy_client.is_running: - try: - # Get the executor before closing the client - executor = getattr(self.comfy_client, 'executor', None) - - # Close the client first - await self.comfy_client.__aexit__(None, None, None) - logger.info("[ComfyStreamClient] ComfyUI client closed") - - # Then shutdown the executor and terminate processes - if executor: - logger.info("[ComfyStreamClient] Shutting down executor...") - # Shutdown the executor - executor.shutdown(wait=False) - - # Force terminate any remaining processes - if hasattr(executor, '_processes'): - for process in executor._processes: - if process.is_alive(): - logger.info(f"[ComfyStreamClient] Terminating worker process {process.pid}") - process.terminate() - # Give it a moment to terminate gracefully - try: - process.join(timeout=2.0) - except: - pass - # Force kill if still alive - if process.is_alive(): - logger.warning(f"[ComfyStreamClient] Force killing worker process {process.pid}") - process.kill() - - logger.info("[ComfyStreamClient] Executor shutdown completed") - - except Exception as e: - logger.error(f"Error during ComfyClient cleanup: {e}") 
- - await self.cleanup_queues() - - # Reset state for potential reuse - self.shutting_down = False - self.distribution_task = None - - logger.info("[ComfyStreamClient] Client cleanup complete") - - async def cleanup_queues(self): - # TODO: add for audio as well - while not self.image_inputs.empty(): - self.image_inputs.get() - - while not self.image_outputs.empty(): - self.image_outputs.get() - - def put_video_input(self, frame): - try: - # Check if frame is FrameProxy - if isinstance(frame, FrameProxy): - proxy = frame - else: - proxy = FrameProxy.avframe_to_frameproxy(frame) - - # Handle queue being full - if self.image_inputs.full(): - # logger.warning(f"[ComfyStreamClient] Input queue full, dropping oldest frame") - try: - self.image_inputs.get_nowait() - except Exception: - pass - - self.image_inputs.put_nowait(proxy) - # logger.info(f"[ComfyStreamClient] Video input queued. Queue size: {self.image_inputs.qsize()}") - except Exception as e: - logger.error(f"[ComfyStreamClient] Error putting video frame: {str(e)}") - - def put_audio_input(self, frame): - self.audio_inputs.put(frame) - - async def get_video_output(self): - try: - logger.debug(f"[ComfyStreamClient] get_video_output called - PID: {os.getpid()}") - tensor = await asyncio.wait_for( - asyncio.get_event_loop().run_in_executor(None, self.image_outputs.get), - timeout=5.0 - ) - # logger.info(f"[ComfyStreamClient] get_video_output returning tensor: {tensor.shape} - PID: {os.getpid()}") - return tensor - except asyncio.TimeoutError: - logger.warning(f"[ComfyStreamClient] get_video_output timeout - PID: {os.getpid()}") - return torch.zeros((1, 3, self.height, self.width), dtype=torch.float32) - except Exception as e: - logger.error(f"[ComfyStreamClient] Error getting video output: {str(e)} - PID: {os.getpid()}") - return torch.zeros((1, 3, self.height, self.width), dtype=torch.float32) - - async def get_audio_output(self): - loop = asyncio.get_event_loop() - return await loop.run_in_executor(None, 
self.audio_outputs.get) - - async def get_available_nodes(self): - """Get metadata and available nodes info using cached nodes to avoid ProcessPoolExecutor conflicts""" - if not self.current_prompts: - return {} - - # Use cached nodes instead of calling import_all_nodes_in_workspace from worker process - if self._nodes is None: - logger.warning("[ComfyStreamClient] Nodes cache not available, returning empty result") - return {} - - try: - all_prompts_nodes_info = {} - - for prompt_index, prompt in enumerate(self.current_prompts): - # Get set of class types we need metadata for - needed_class_types = { - node.get('class_type') - for node in prompt.values() - } - remaining_nodes = { - node_id - for node_id, node in prompt.items() - } - nodes_info = {} - - # Only process nodes until we've found all the ones we need - for class_type, node_class in self._nodes.NODE_CLASS_MAPPINGS.items(): - if not remaining_nodes: # Exit early if we've found all needed nodes - break - - if class_type not in needed_class_types: - continue - - # Get metadata for this node type (same as original get_node_metadata) - input_data = node_class.INPUT_TYPES() if hasattr(node_class, 'INPUT_TYPES') else {} - input_info = {} - - # Process required inputs - if 'required' in input_data: - for name, value in input_data['required'].items(): - if isinstance(value, tuple): - if len(value) == 1 and isinstance(value[0], list): - # Handle combo box case where value is ([option1, option2, ...],) - input_info[name] = { - 'type': 'combo', - 'value': value[0], # The list of options becomes the value - } - elif len(value) == 2: - input_type, config = value - input_info[name] = { - 'type': input_type, - 'required': True, - 'min': config.get('min', None), - 'max': config.get('max', None), - 'widget': config.get('widget', None) - } - elif len(value) == 1: - # Handle simple type case like ('IMAGE',) - input_info[name] = { - 'type': value[0] - } - else: - logger.error(f"Unexpected structure for required input {name}: 
{value}") - - # Process optional inputs with same logic - if 'optional' in input_data: - for name, value in input_data['optional'].items(): - if isinstance(value, tuple): - if len(value) == 1 and isinstance(value[0], list): - # Handle combo box case where value is ([option1, option2, ...],) - input_info[name] = { - 'type': 'combo', - 'value': value[0], # The list of options becomes the value - } - elif len(value) == 2: - input_type, config = value - input_info[name] = { - 'type': input_type, - 'required': False, - 'min': config.get('min', None), - 'max': config.get('max', None), - 'widget': config.get('widget', None) - } - elif len(value) == 1: - # Handle simple type case like ('IMAGE',) - input_info[name] = { - 'type': value[0] - } - else: - logger.error(f"Unexpected structure for optional input {name}: {value}") - - # Now process any nodes in our prompt that use this class_type - for node_id in list(remaining_nodes): - node = prompt[node_id] - if node.get('class_type') != class_type: - continue - - node_info = { - 'class_type': class_type, - 'inputs': {} - } - - if 'inputs' in node: - for input_name, input_value in node['inputs'].items(): - input_metadata = input_info.get(input_name, {}) - node_info['inputs'][input_name] = { - 'value': input_value, - 'type': input_metadata.get('type', 'unknown'), - 'min': input_metadata.get('min', None), - 'max': input_metadata.get('max', None), - 'widget': input_metadata.get('widget', None) - } - # For combo type inputs, include the list of options - if input_metadata.get('type') == 'combo': - node_info['inputs'][input_name]['value'] = input_metadata.get('value', []) - - nodes_info[node_id] = node_info - remaining_nodes.remove(node_id) - - all_prompts_nodes_info[prompt_index] = nodes_info - - return all_prompts_nodes_info - - except Exception as e: - logger.error(f"Error getting node info: {str(e)}") - return {} - - def register_tensorrt_paths_main_process(self, workspace_path): - """Register TensorRT paths in the main process 
for validation""" - try: - from comfy.cmd import folder_paths - - if workspace_path: - base_dir = workspace_path - tensorrt_models_dir = os.path.join(base_dir, "models", "tensorrt") - tensorrt_outputs_dir = os.path.join(base_dir, "outputs", "tensorrt") - else: - tensorrt_models_dir = os.path.join(folder_paths.models_dir, "tensorrt") - tensorrt_outputs_dir = os.path.join(folder_paths.models_dir, "outputs", "tensorrt") - - # logger.info(f"[ComfyStreamClient] Registering TensorRT paths in main process") - # logger.info(f"[ComfyStreamClient] TensorRT models dir: {tensorrt_models_dir}") - # logger.info(f"[ComfyStreamClient] TensorRT outputs dir: {tensorrt_outputs_dir}") - - # Register TensorRT paths - if "tensorrt" in folder_paths.folder_names_and_paths: - existing_paths = folder_paths.folder_names_and_paths["tensorrt"][0] - for path in [tensorrt_models_dir, tensorrt_outputs_dir]: - if path not in existing_paths: - existing_paths.append(path) - folder_paths.folder_names_and_paths["tensorrt"][1].add(".engine") - else: - folder_paths.folder_names_and_paths["tensorrt"] = ( - [tensorrt_models_dir, tensorrt_outputs_dir], - {".engine"} - ) - - # Verify registration - # available_files = folder_paths.get_filename_list("tensorrt") - # logger.info(f"[ComfyStreamClient] Main process TensorRT files: {available_files}") - - except Exception as e: - logger.error(f"[ComfyStreamClient] Error registering TensorRT paths in main process: {e}") - import traceback - logger.error(f"[ComfyStreamClient] Traceback: {traceback.format_exc()}") - - async def update_prompts(self, prompts: List[PromptDictInput]): - """Update the existing processing prompts without restarting workers.""" - - # Simply update the current prompts - worker loops will pick up changes on next iteration - self.current_prompts = [convert_prompt(prompt) for prompt in prompts] - - logger.info("[ComfyStreamClient] Prompts updated") - - def _initialize_nodes_cache(self): - """Initialize nodes cache in main process to avoid 
ProcessPoolExecutor conflicts""" - try: - logger.info("[ComfyStreamClient] Initializing nodes cache in main process...") - from comfy.nodes.package import import_all_nodes_in_workspace - self._nodes = import_all_nodes_in_workspace() - logger.info(f"[ComfyStreamClient] Cached {len(self._nodes.NODE_CLASS_MAPPINGS)} node types") - except Exception as e: - logger.error(f"[ComfyStreamClient] Error initializing nodes cache: {e}") - self._nodes = None \ No newline at end of file diff --git a/src/comfystream/frame_logging.py b/src/comfystream/frame_logging.py deleted file mode 100644 index a069b33f..00000000 --- a/src/comfystream/frame_logging.py +++ /dev/null @@ -1,366 +0,0 @@ -# frame_logging.py -# Developed by @buffmcbighuge (Marco Tundo) - -# You can generate graphs from the log file: -# python frame_logging.py --frame-logs frame_logs1.csv,frame_logs_2.csv,frame_logs_3.csv - -import csv -import os -import time -from typing import Optional, Dict, Any, List, Tuple -import argparse -import numpy as np -import pandas as pd -import matplotlib.pyplot as plt -import matplotlib.colors as mcolors - -def log_frame_timing( - frame_id: Optional[int], - frame_received_time: Optional[float], - frame_process_start_time: Optional[float], - frame_processed_time: Optional[float], - client_index: Optional[int] = None, - additional_metadata: Optional[Dict[str, Any]] = None, - csv_path: str = "frame_logs.csv" -): - """ - Log frame timing information to a CSV file with simplified metrics. 
- Args: - frame_id: The unique identifier for the frame - frame_received_time: Timestamp when the frame was received by pipeline - frame_process_start_time: Timestamp when processing began - frame_processed_time: Timestamp when processing completed - client_index: Index of the client that processed this frame - additional_metadata: Any additional data to log - csv_path: Path to the CSV file - """ - # Calculate processing latency - processing_latency = None - if frame_process_start_time is not None and frame_processed_time is not None: - processing_latency = (frame_processed_time - frame_process_start_time) * 1000 - - # Convert additional metadata to string if present - metadata_str = str(additional_metadata) if additional_metadata else None - - # Calculate absolute time for logging - current_time = time.time() - - # Determine if this is an input-only frame or a processed frame - is_processed = frame_process_start_time is not None and frame_processed_time is not None - frame_type = "processed" if is_processed else "input" - - # Prepare data based on frame type - if is_processed: - # For processed frames, include all columns - header = [ - "log_timestamp", "frame_id", "frame_type", - "frame_received_time", "frame_process_start_time", "frame_processed_time", - "processing_latency_ms", "client_index", "metadata" - ] - data = [ - current_time, frame_id, frame_type, - frame_received_time, frame_process_start_time, frame_processed_time, - processing_latency, client_index, metadata_str - ] - else: - # For input frames, only include relevant columns (skip processing-related columns) - header = [ - "log_timestamp", "frame_id", "frame_type", - "frame_received_time", "client_index", "metadata" - ] - data = [ - current_time, frame_id, frame_type, - frame_received_time, client_index, metadata_str - ] - - file_exists = os.path.isfile(csv_path) - file_empty = file_exists and os.path.getsize(csv_path) == 0 - - # Use pandas to handle the CSV file, which handles mixed column formats 
better - if not file_exists or file_empty: - # If file doesn't exist or is empty, create a new one with the full header - # This ensures the file always has all possible columns defined - full_header = [ - "log_timestamp", "frame_id", "frame_type", - "frame_received_time", "frame_process_start_time", "frame_processed_time", - "processing_latency_ms", "client_index", "metadata" - ] - pd.DataFrame(columns=full_header).to_csv(csv_path, index=False) - - # Now append the data - df = pd.DataFrame([dict(zip(header, data))]) - df.to_csv(csv_path, mode='a', header=False, index=False, columns=header) - -def process_log_file(csv_path: str) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, float]: - """ - Process a single log file and return the processed dataframes - """ - if not os.path.isfile(csv_path): - print(f"CSV file '{csv_path}' not found.") - return None, None, None, 0 - - df = pd.read_csv(csv_path) - if df.empty: - print(f"CSV file '{csv_path}' is empty.") - return None, None, None, 0 - - # Drop rows with missing essential times - df = df.dropna(subset=["frame_id", "frame_received_time"]) - if df.empty: - print("No valid timing data in CSV after dropping NA in essential time columns.") - return None, None, None, 0 - - # Sort by frame_id and calculate time since start - df = df.sort_values("frame_id").reset_index(drop=True) - stream_start_time = df["frame_received_time"].min() - df["time_since_start"] = df["frame_received_time"] - stream_start_time - - # Separate input and processed frames based on frame_type column - if "frame_type" in df.columns: - input_df = df[df["frame_type"] == "input"].copy() - processed_df = df[df["frame_type"] == "processed"].copy() - else: - # Backward compatibility - separate based on process timestamps - input_df = df[df["frame_process_start_time"].isna()].copy() - processed_df = df.dropna(subset=["frame_processed_time"]).copy() - - # Calculate time of processed frames relative to stream start - if not processed_df.empty: - 
processed_df.loc[:, "output_time_relative"] = processed_df["frame_processed_time"] - stream_start_time - - # Create a consistent timeline with fixed intervals based on overall activity - max_input_time = df["time_since_start"].max() if not df["time_since_start"].empty else 0 - max_output_time = processed_df["output_time_relative"].max() if not processed_df.empty else 0 - - max_time = max(max_input_time, max_output_time) - time_range = np.arange(0, int(max_time) + 1) - - # Initialize FPS arrays for consistent timeline - input_fps_counts = np.zeros(len(time_range)) - output_fps_counts = np.zeros(len(time_range)) - - # Count frames in each 1-second interval - for t_idx, t_sec in enumerate(time_range): - # For input FPS, count input frames by received time - if not input_df.empty: - input_mask = (input_df["time_since_start"] >= t_sec) & (input_df["time_since_start"] < t_sec + 1) - input_fps_counts[t_idx] = input_mask.sum() - - # For output FPS, count processed frames by processed time - if not processed_df.empty: - output_mask = (processed_df["output_time_relative"] >= t_sec) & (processed_df["output_time_relative"] < t_sec + 1) - output_fps_counts[t_idx] = output_mask.sum() - - fps_df = pd.DataFrame({ - "time_bin": time_range, - "input_fps": input_fps_counts, - "output_fps": output_fps_counts - }) - - # Apply smoothing - smoothing_window = 3 - fps_df["input_fps_smooth"] = fps_df["input_fps"].rolling(smoothing_window, min_periods=1).mean() - fps_df["output_fps_smooth"] = fps_df["output_fps"].rolling(smoothing_window, min_periods=1).mean() - - # Calculate frame intervals for input and output frames separately - # Only calculate intervals for the same frame type - if not input_df.empty: - input_df = input_df.sort_values("frame_received_time").reset_index(drop=True) - input_df.loc[:, "input_interval_s"] = input_df["frame_received_time"].diff() - input_df.loc[input_df["input_interval_s"] <= 0, "input_interval_s"] = np.nan - input_df.loc[:, "input_time_bin"] = 
input_df["time_since_start"].astype(int) - - if not processed_df.empty: - processed_df = processed_df.sort_values("frame_processed_time").reset_index(drop=True) - processed_df.loc[:, "output_interval_s"] = processed_df["frame_processed_time"].diff() - processed_df.loc[processed_df["output_interval_s"] <= 0, "output_interval_s"] = np.nan - processed_df.loc[:, "output_time_bin"] = processed_df["output_time_relative"].astype(int) - - # Calculate jitter as the standard deviation of frame intervals in each time bin - input_jitter = np.full(len(time_range), np.nan) - output_jitter = np.full(len(time_range), np.nan) - - for t_idx, t_sec in enumerate(time_range): - # Input jitter - variation in input frame arrival times - if not input_df.empty: - intervals = input_df.loc[input_df["input_time_bin"] == t_sec, "input_interval_s"] - if len(intervals.dropna()) > 1: - std_dev = intervals.std() * 1000 # Convert to ms - input_jitter[t_idx] = std_dev - - # Output jitter - variation in processed frame completion times - if not processed_df.empty: - intervals = processed_df.loc[processed_df["output_time_bin"] == t_sec, "output_interval_s"] - if len(intervals.dropna()) > 1: - std_dev = intervals.std() * 1000 # Convert to ms - output_jitter[t_idx] = std_dev - - jitter_df = pd.DataFrame({ - "time_bin": time_range, - "input_jitter_ms": input_jitter, - "output_jitter_ms": output_jitter - }) - - # Aggregate processing latency by time bin - if not processed_df.empty: - avg_latencies = processed_df.groupby("output_time_bin").agg({ - "processing_latency_ms": "mean" - }).reset_index() - - latency_df = pd.DataFrame({"time_bin": time_range}) - latency_df = pd.merge( - latency_df, - avg_latencies.rename(columns={"output_time_bin": "time_bin"}), - on="time_bin", - how="left" - ) - else: - latency_df = pd.DataFrame({ - "time_bin": time_range, - "processing_latency_ms": np.nan - }) - - return fps_df, jitter_df, latency_df, max_time - -def plot_multiple_frame_metrics(csv_paths: List[str]): - """ - 
Plot metrics from multiple log files on the same charts - """ - if not csv_paths: - print("No CSV files provided.") - return - - # Process each log file - all_data = [] - max_time_overall = 0 - - for csv_path in csv_paths: - fps_df, jitter_df, latency_df, max_time = process_log_file(csv_path) - if fps_df is not None: - all_data.append({ - 'path': csv_path, - 'fps_df': fps_df, - 'jitter_df': jitter_df, - 'latency_df': latency_df, - 'max_time': max_time - }) - max_time_overall = max(max_time_overall, max_time) - - if not all_data: - print("No valid data found in any of the provided CSV files.") - return - - # Create visualization with 3 subplots - fig, axs = plt.subplots(3, 1, figsize=(14, 12), sharex=True) - - # Generate a list of distinct colors for multiple datasets - # Use a subset of tab colors for better distinction - tab_colors = list(mcolors.TABLEAU_COLORS.values()) - - for i, data in enumerate(all_data): - # Get colors for this dataset - input_color = tab_colors[i % len(tab_colors)] - output_color = tab_colors[(i + len(tab_colors)//2) % len(tab_colors)] - - # Extract the filename without path and extension for legend - file_label = os.path.splitext(os.path.basename(data['path']))[0] - - # 1. FPS Plot - axs[0].plot( - data['fps_df']["time_bin"], - data['fps_df']["input_fps_smooth"], - label=f"Input FPS - {file_label}", - color=input_color, - linewidth=2 - ) - axs[0].plot( - data['fps_df']["time_bin"], - data['fps_df']["output_fps_smooth"], - label=f"Output FPS - {file_label}", - color=output_color, - linewidth=2, - linestyle='--' - ) - - # 2. Jitter Plot - axs[1].plot( - data['jitter_df']["time_bin"], - data['jitter_df']["input_jitter_ms"], - label=f"Input Jitter - {file_label}", - color=input_color, - alpha=0.7 - ) - axs[1].plot( - data['jitter_df']["time_bin"], - data['jitter_df']["output_jitter_ms"], - label=f"Output Jitter - {file_label}", - color=output_color, - alpha=0.7, - linestyle='--' - ) - - # 3. 
Processing Latency Plot - axs[2].plot( - data['latency_df']["time_bin"], - data['latency_df']["processing_latency_ms"], - label=f"Processing Latency - {file_label}", - color=output_color, - alpha=0.7 - ) - - # Configure axes and labels - axs[0].set_ylabel("FPS") - axs[0].set_title("Frame Rate") - axs[0].legend() - axs[0].grid(True) - axs[0].set_xlim(left=0, right=max_time_overall if max_time_overall > 0 else 1) - axs[0].set_ylim(bottom=0) - - axs[1].set_ylabel("Jitter (ms)") - axs[1].set_title("Frame Timing Jitter") - axs[1].legend() - axs[1].grid(True) - axs[1].set_xlim(left=0, right=max_time_overall if max_time_overall > 0 else 1) - - axs[2].set_ylabel("Latency (ms)") - axs[2].set_title("Processing Latency") - axs[2].legend() - axs[2].grid(True) - axs[2].set_xlim(left=0, right=max_time_overall if max_time_overall > 0 else 1) - - # Add x-axis label - fig.text(0.5, 0.04, 'Time (seconds)', ha='center', va='center', fontsize=12) - - plt.tight_layout() - plt.subplots_adjust(bottom=0.07) - - # Save combined plot - output_filename = "combined_frame_logs.png" - plt.savefig(output_filename) - print(f"Combined plot saved as {output_filename}") - plt.show() - -def plot_frame_metrics(csv_path: str = "frame_logs.csv"): - """ - Plot metrics from a single log file for backward compatibility - """ - # For single files, just call the multiple processing function with a list of one item - plot_multiple_frame_metrics([csv_path]) - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Plot frame timing metrics from CSV logs.") - parser.add_argument( - "--frame-logs", - type=str, - default="frame_logs.csv", - help="Path to the frame timing CSV log file(s) (comma-separated for multiple files)" - ) - args = parser.parse_args() - - # Check if multiple files are specified - csv_paths = [path.strip() for path in args.frame_logs.split(',')] - - if len(csv_paths) > 1: - plot_multiple_frame_metrics(csv_paths) - else: - plot_frame_metrics(csv_paths[0]) diff --git 
a/src/comfystream/pipeline.py b/src/comfystream/pipeline.py index e4c79bf8..3f2907bc 100644 --- a/src/comfystream/pipeline.py +++ b/src/comfystream/pipeline.py @@ -4,11 +4,13 @@ import asyncio import logging import time +import os +from collections import OrderedDict from typing import Any, Dict, Union, List, Optional from comfystream.client import ComfyStreamClient from comfystream.server.utils import temporary_log_level -from comfystream.frame_logging import log_frame_timing +from comfystream.frame_proxy import FrameProxy WARMUP_RUNS = 5 @@ -23,49 +25,100 @@ class Pipeline: postprocessing, and queue management. """ - def __init__(self, width: int = 512, height: int = 512, - comfyui_inference_log_level: Optional[int] = None, frame_log_file: Optional[str] = None, **kwargs): + def __init__(self, + width: int = 512, + height: int = 512, + max_workers: int = 1, + comfyui_inference_log_level: Optional[int] = None, + **kwargs): """Initialize the pipeline with the given configuration. Args: width: Width of the video frames (default: 512) height: Height of the video frames (default: 512) + max_workers: Number of worker processes (default: 1) comfyui_inference_log_level: The logging level for ComfyUI inference. - Defaults to None, using the global ComfyUI log level. - **kwargs: Additional arguments to pass to the ComfyStreamClient + **kwargs: Additional arguments to pass to the ComfyStreamClient (cwd, disable_cuda_malloc, etc.) 
""" - self.client = ComfyStreamClient(**kwargs) + self.client = ComfyStreamClient( + max_workers=max_workers, + **kwargs) self.width = width self.height = height self.video_incoming_frames = asyncio.Queue() self.audio_incoming_frames = asyncio.Queue() + self.output_buffer = asyncio.Queue(maxsize=6) # Small buffer to smooth output + self.input_frame_counter = 0 + + # Simple frame collection without ordering + self.collector_task = asyncio.create_task(self._collect_frames_simple()) + self.processed_audio_buffer = np.array([], dtype=np.int16) self._comfyui_inference_log_level = comfyui_inference_log_level + self.next_expected_frame_id = 0 + # Add a queue for frame log entries self.running = True - self.next_expected_frame_id = 0 - self.frame_log_file = frame_log_file - self.frame_log_queue = None # Initialize to None by default - if self.frame_log_file: - self.frame_log_queue = asyncio.Queue() - self.frame_logger_task = asyncio.create_task(self._process_frame_logs()) + async def _collect_frames_simple(self): + """Simple frame collector - no ordering, just buffer""" + try: + while self.running: + try: + tensor = await asyncio.wait_for(self.client.get_video_output(), timeout=0.1) + if tensor is not None: + # Just put it in the buffer, don't worry about order + try: + self.output_buffer.put_nowait(tensor) + except asyncio.QueueFull: + # Drop oldest frame if buffer is full + try: + self.output_buffer.get_nowait() + self.output_buffer.put_nowait(tensor) + except asyncio.QueueEmpty: + pass + except asyncio.TimeoutError: + pass + except Exception as e: + logger.error(f"Error collecting frame: {e}") + + await asyncio.sleep(0.001) # Minimal sleep + + except asyncio.CancelledError: + pass - async def warm_video(self): - """Warm up the video processing pipeline with dummy frames.""" - # Create dummy frame with the CURRENT resolution settings - dummy_frame = av.VideoFrame() - dummy_frame.side_data.input = torch.randn(1, self.height, self.width, 3) - - logger.info(f"Warming video 
pipeline with resolution {self.width}x{self.height}") + async def initialize(self, prompts): + await self.set_prompts(prompts) + await self.warm_video() - for _ in range(WARMUP_RUNS): - self.client.put_video_input(dummy_frame) - await self.client.get_video_output() + async def warm_video(self): + logger.info("[PipelineMulti] Starting warmup...") + for i in range(WARMUP_RUNS): + dummy_tensor = torch.randn(1, self.height, self.width, 3) + dummy_proxy = FrameProxy( + tensor=dummy_tensor, + width=self.width, + height=self.height, + pts=None, + time_base=None + ) + # Set frame_id for warmup frames (negative to distinguish from real frames) + dummy_proxy.side_data.frame_id = -(i + 1) + logger.debug(f"[PipelineMulti] Warmup: putting dummy frame {i+1}/{WARMUP_RUNS}") + self.client.put_video_input(dummy_proxy) + + # For warmup, we don't need to wait for ordered output + try: + out = await asyncio.wait_for(self.client.get_video_output(), timeout=30.0) + logger.debug(f"[PipelineMulti] Warmup: got output for dummy frame {i+1}/{WARMUP_RUNS}") + except asyncio.TimeoutError: + logger.warning(f"[PipelineMulti] Warmup frame {i+1} timed out") + + logger.info("[PipelineMulti] Warmup complete.") async def warm_audio(self): """Warm up the audio processing pipeline with dummy frames.""" @@ -98,6 +151,8 @@ async def update_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any await self.client.update_prompts(prompts) else: await self.client.update_prompts([prompts]) + + logger.info("[PipelineMulti] Prompts updated") async def put_video_frame(self, frame: av.VideoFrame): """Queue a video frame for processing. 
@@ -109,20 +164,13 @@ async def put_video_frame(self, frame: av.VideoFrame): frame.side_data.input = self.video_preprocess(frame) frame.side_data.skipped = True frame.side_data.frame_received_time = current_time - frame.side_data.frame_id = self.next_expected_frame_id + + # Assign frame ID and increment counter + frame_id = self.input_frame_counter + frame.side_data.frame_id = frame_id frame.side_data.client_index = -1 self.next_expected_frame_id += 1 - - # Log frame at input time to properly track input FPS - if self.frame_log_file: - await self.frame_log_queue.put({ - 'frame_id': frame.side_data.frame_id, - 'frame_received_time': frame.side_data.frame_received_time, - 'frame_process_start_time': None, - 'frame_processed_time': None, - 'client_index': frame.side_data.client_index, - 'csv_path': self.frame_log_file - }) + self.input_frame_counter += 1 self.client.put_video_input(frame) await self.video_incoming_frames.put(frame) @@ -161,6 +209,7 @@ def audio_preprocess(self, frame: av.AudioFrame) -> Union[torch.Tensor, np.ndarr """ return frame.to_ndarray().ravel().reshape(-1, 2).mean(axis=1).astype(np.int16) + def video_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.VideoFrame: """Postprocess a video frame after processing. 
@@ -170,8 +219,27 @@ def video_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.Video Returns: The postprocessed video frame """ + + # First ensure we have a tensor + if isinstance(output, np.ndarray): + output = torch.from_numpy(output) + + # Handle different tensor formats + if len(output.shape) == 4: # BCHW or BHWC format + if output.shape[1] != 3: # If BHWC format + output = output.permute(0, 3, 1, 2) # Convert BHWC to BCHW + output = output[0] # Take first image from batch -> CHW + elif len(output.shape) != 3: # Should be CHW at this point + raise ValueError(f"Unexpected tensor shape after batch removal: {output.shape}") + + # Convert CHW to HWC for video frame + output = output.permute(1, 2, 0) # CHW -> HWC + + # Convert to numpy and create video frame return av.VideoFrame.from_ndarray( - (output * 255.0).clamp(0, 255).to(dtype=torch.uint8).squeeze(0).cpu().numpy() + (output * 255.0).clamp(0, 255).to(dtype=torch.uint8).squeeze(0).cpu().numpy(), + # (output * 255.0).clamp(0, 255).to(dtype=torch.uint8).cpu().numpy(), + format='rgb24' ) def audio_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.AudioFrame: @@ -192,31 +260,26 @@ async def get_processed_video_frame(self) -> av.VideoFrame: Returns: The processed video frame """ - frame_process_start_time = time.time() async with temporary_log_level("comfy", self._comfyui_inference_log_level): out_tensor = await self.client.get_video_output() frame = await self.video_incoming_frames.get() while frame.side_data.skipped: frame = await self.video_incoming_frames.get() - + + # Get any available output (don't match input to output) + try: + out_tensor = await asyncio.wait_for(self.output_buffer.get(), timeout=1.0) + except asyncio.TimeoutError: + # Fallback: return input frame to maintain stream + logger.warning("Output timeout, using input frame") + return frame + + # Process and return processed_frame = self.video_postprocess(out_tensor) processed_frame.pts = frame.pts 
processed_frame.time_base = frame.time_base - frame_processed_time = time.time() - - # Log frame timing with simplified metrics - if self.frame_log_file: - await self.frame_log_queue.put({ - 'frame_id': frame.side_data.frame_id, - 'frame_received_time': frame.side_data.frame_received_time, - 'frame_process_start_time': frame_process_start_time, - 'frame_processed_time': frame_processed_time, - 'client_index': frame.side_data.client_index, - 'csv_path': self.frame_log_file - }) - return processed_frame async def get_processed_audio_frame(self) -> av.AudioFrame: @@ -251,28 +314,20 @@ async def get_nodes_info(self) -> Dict[str, Any]: async def cleanup(self): """Clean up resources used by the pipeline.""" + logger.info("[PipelineMulti] Starting pipeline cleanup...") + + # Set running flag to false to stop frame processing + self.running = False - # Cancel frame logger task if it exists - if hasattr(self, 'frame_logger_task') and self.frame_logger_task: - self.frame_logger_task.cancel() + # Cancel collector task + if hasattr(self, 'collector_task') and self.collector_task: + self.collector_task.cancel() try: - await self.frame_logger_task + await self.collector_task except asyncio.CancelledError: pass + # Clean up the client (this will gracefully shutdown workers) await self.client.cleanup() - - async def _process_frame_logs(self): - """Background task to process frame logs from queue""" - while self.running: - try: - # Get log entry from queue - log_entry = await self.frame_log_queue.get() - log_frame_timing(**log_entry) - - # Mark task as done - self.frame_log_queue.task_done() - except asyncio.CancelledError: - break - except Exception as e: - logger.error(f"Error in frame logging: {e}") \ No newline at end of file + + logger.info("[PipelineMulti] Pipeline cleanup complete") \ No newline at end of file diff --git a/src/comfystream/pipeline_api.py b/src/comfystream/pipeline_api.py deleted file mode 100644 index b6a0deee..00000000 --- a/src/comfystream/pipeline_api.py 
+++ /dev/null @@ -1,698 +0,0 @@ -import av -import torch -import numpy as np -import asyncio -import logging -import time -from collections import OrderedDict -import collections -import os -import fractions - -from typing import Any, Dict, Union, List, Optional, Deque -from comfystream.client_api import ComfyStreamClient -from comfystream.server.utils import temporary_log_level # Not sure exactly what this does -from comfystream.server.utils.config import ComfyConfig -from comfystream.frame_logging import log_frame_timing - -WARMUP_RUNS = 5 -logger = logging.getLogger(__name__) - - -class MultiServerPipeline: - def __init__( - self, - width: int = 512, - height: int = 512, - config_path: Optional[str] = None, - max_frame_wait_ms: int = 500, - client_mode: str = "toml", - workspace: str = None, - workers: int = 2, - cuda_devices: str = '0', - workers_start_port: int = 8195, - comfyui_log_level: str = None, - frame_log_file: Optional[str] = None, - ): - """Initialize the pipeline with the given configuration. - Args: - width: The width of the video frames. - height: The height of the video frames. - workers: The number of ComfyUI clients to spin up (if client_mode is "spawn"). - config_path: The path to the ComfyUI config toml file (if client_mode is "toml"). - max_frame_wait_ms: The maximum number of milliseconds to wait for a frame before dropping it. - client_mode: The mode to use for the ComfyUI clients. - "toml": Use a config file to describe clients. - "spawn": Spawn ComfyUI clients as external processes. - workers_start_port: The starting port number for worker processes (default: 8195). - cuda_devices: The list of CUDA devices to use for the ComfyUI clients. - comfyui_log_level: The logging level for ComfyUI - frame_log_file: The filename for the frame timing log (optional). - """ - - # There are two methods for starting the clients: - # 1. client_mode == "toml" -> Use a config file to describe clients. - # 2. 
client_mode == "spawn" -> Spawn ComfyUI clients as external processes. - - self.clients = [] - self.workspace = workspace - self.client_mode = client_mode - - if (client_mode == "toml"): - # TOML Mode: Use a config file to describe existing ComfyUI Instances - - # Load server configurations - self.config = ComfyConfig(config_path) - self.servers = self.config.get_servers() - elif (client_mode == "spawn"): - # SPAWN Mode: Spawn new ComfyUI Instances automatically - - self.workers = workers - self.workers_start_port = workers_start_port - self.cuda_devices = cuda_devices - - # Clients started in /offer (this is due to when the page refreshes, the clients automatically close) - # TODO: Perhaps a better way would be to keep the the clients alive while the server is alive? - # self.start_clients() - - self.width = width - self.height = height - - self.video_incoming_frames = asyncio.Queue() - self.audio_incoming_frames = asyncio.Queue() - - # Queue for processed frames from all clients - self.processed_video_frames = asyncio.Queue() - - # Track which client gets each frame (round-robin) - self.last_frame_time = 0 - self.current_client_index = 0 - self.client_frame_mapping = {} # Maps frame_id -> client_index - - # Frame ordering and timing - self.max_frame_wait_ms = max_frame_wait_ms # Max time to wait for a frame before dropping - self.next_expected_frame_id = None # Track expected frame ID - self.ordered_frames = OrderedDict() # Buffer for ordering frames (frame_id -> (timestamp, tensor)) - - # Audio processing - self.processed_audio_buffer = np.array([], dtype=np.int16) - - # Frame rate limiting - self.min_frame_interval = 1/30 # Limit to 30 FPS - - # Create background task for collecting processed frames - self.running = True - self.collector_task = asyncio.create_task(self._collect_processed_frames()) - - self.output_interval = 1/30 # Start with 30 FPS - self.last_output_time = None - self.frame_interval_history = collections.deque(maxlen=30) - # 
self.output_pacer_task = asyncio.create_task(self._dynamic_output_pacer()) - - self.comfyui_log_level = comfyui_log_level - - # Add a queue for frame log entries - self.frame_log_file = frame_log_file - self.frame_log_queue = None # Initialize to None by default - - if self.frame_log_file: - self.frame_log_queue = asyncio.Queue() - self.frame_logger_task = asyncio.create_task(self._process_frame_logs()) - - async def _collect_processed_frames(self): - """Background task to collect processed frames from all clients""" - try: - while self.running: - for i, client in enumerate(self.clients): - try: - # Non-blocking check if client has output ready - if hasattr(client, '_prompt_id') and client._prompt_id is not None: - try: - # Use wait_for with small timeout to avoid blocking - frame_id, out_tensor = await asyncio.wait_for( - client.get_video_output(), - timeout=0.001 - ) - - # Store frame with timestamp for ordering - current_time = time.time() - await self._add_frame_to_ordered_buffer(frame_id, current_time, out_tensor) - - # Remove the mapping - self.client_frame_mapping.pop(frame_id, None) - - # logger.debug(f"Collected processed frame from client {i}, frame_id: {frame_id}") - except asyncio.TimeoutError: - # No frame ready yet, continue - pass - except Exception as e: - logger.error(f"Error collecting frame from client {i}: {e}") - - # Check for frames that have waited too long - await self._check_frame_timeouts() - - # Small sleep to avoid CPU spinning - await asyncio.sleep(0.01) - except asyncio.CancelledError: - logger.info("Frame collector task cancelled") - except Exception as e: - logger.error(f"Unexpected error in frame collector: {e}") - - async def _add_frame_to_ordered_buffer(self, frame_id, timestamp, tensor): - """Add a processed frame to the ordered buffer""" - self.ordered_frames[frame_id] = (timestamp, tensor) - - # If this is the first frame, set the next expected frame ID - if self.next_expected_frame_id is None: - self.next_expected_frame_id = 
frame_id - - # Check if we can release any frames now - await self._release_ordered_frames() - - async def _release_ordered_frames(self): - if self.next_expected_frame_id is None: - return - - # Only release frames in strict sequential order - while self.ordered_frames and self.next_expected_frame_id in self.ordered_frames: - timestamp, tensor = self.ordered_frames.pop(self.next_expected_frame_id) - await self.processed_video_frames.put((self.next_expected_frame_id, tensor)) - logger.debug(f"Released frame {self.next_expected_frame_id} to output queue") - # Always increment to next sequential frame ID - self.next_expected_frame_id += 1 - - async def _check_frame_timeouts(self): - """Check for frames that have waited too long and handle them""" - if not self.ordered_frames or self.next_expected_frame_id is None: - return - - current_time = time.time() - - # If the next expected frame has timed out, skip it and move on - if self.next_expected_frame_id in self.ordered_frames: - timestamp, _ = self.ordered_frames[self.next_expected_frame_id] - wait_time_ms = (current_time - timestamp) * 1000 - - if wait_time_ms > self.max_frame_wait_ms: - # logger.warning(f"Frame {self.next_expected_frame_id} exceeded max wait time, releasing anyway") - # await self._release_ordered_frames() - - # Remove frame - self.ordered_frames.pop(self.next_expected_frame_id) - - # Check if we're missing the next expected frame and it's been too long - elif self.ordered_frames: - # The next frame we're expecting isn't in the buffer - # Check how long we've been waiting since the oldest frame in the buffer - oldest_frame_id = min(self.ordered_frames.keys()) - oldest_timestamp, _ = self.ordered_frames[oldest_frame_id] - wait_time_ms = (current_time - oldest_timestamp) * 1000 - - # If we've waited too long, skip the missing frame(s) - if wait_time_ms > self.max_frame_wait_ms: - logger.debug(f"Missing frame {self.next_expected_frame_id}, skipping to {oldest_frame_id}") - self.next_expected_frame_id = 
oldest_frame_id - await self._release_ordered_frames() - - async def warm_video(self): - # Create dummy frame with the CURRENT resolution settings (which might have been updated via control channel) - - tensor = torch.rand(1, 3, 512, 512) # Random values in [0,1] - dummy_frame = av.VideoFrame(width=512, height=512, format="rgb24") - dummy_frame.side_data.input = tensor - dummy_frame.side_data.frame_received_time = time.time() - - logger.info(f"Warming video pipeline with resolution {self.width}x{self.height}") - - # Warm up each client - warmup_tasks = [] - for i, client in enumerate(self.clients): - warmup_tasks.append(self._warm_client_video(client, i, dummy_frame)) - - # Wait for all warmup tasks to complete - await asyncio.gather(*warmup_tasks) - logger.info("Video pipeline warmup complete") - - async def _warm_client_video(self, client, client_index, dummy_frame): - """Warm up a single client""" - logger.info(f"Warming up client {client_index}") - - # Set frame input as dummyframe with side_data.input set to a random tensor - dummy_frame.side_data.input = torch.randn(1, self.height, self.width, 3) - dummy_frame.side_data.frame_id = -1 - - for i in range(WARMUP_RUNS): - logger.info(f"Client {client_index} warmup iteration {i+1}/{WARMUP_RUNS}") - client.put_video_input(dummy_frame) - try: - await asyncio.wait_for(client.get_video_output(), timeout=30) - except asyncio.TimeoutError: - logger.warning(f"Timeout waiting for warmup frame from client {client_index}") - except Exception as e: - logger.error(f"Error warming client {client_index}: {e}") - - async def warm_audio(self): - # For now, only use the first client for audio - if not self.clients: - logger.warning("No clients available for audio warmup") - return - - dummy_frame = av.AudioFrame() - dummy_frame.side_data.input = np.random.randint(-32768, 32767, int(48000 * 0.5), dtype=np.int16) - dummy_frame.sample_rate = 48000 - - for _ in range(WARMUP_RUNS): - self.clients[0].put_audio_input(dummy_frame) - await 
self.clients[0].get_audio_output() - - async def set_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]): - """Set the same prompts for all clients""" - if isinstance(prompts, dict): - prompts = [prompts] - - # Set prompts for each client - tasks = [] - for client in self.clients: - logger.info(f"Setting prompts for client {client.port}") - tasks.append(client.set_prompts(prompts)) - - await asyncio.gather(*tasks) - logger.info(f"Set prompts for {len(self.clients)} clients") - - async def update_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]): - """Update prompts for all clients""" - if isinstance(prompts, dict): - prompts = [prompts] - - # Update prompts for each client - tasks = [] - for client in self.clients: - tasks.append(client.update_prompts(prompts)) - - await asyncio.gather(*tasks) - logger.info(f"Updated prompts for {len(self.clients)} clients") - - async def put_video_frame(self, frame: av.VideoFrame): - ''' Put a video frame into the pipeline round-robin to all clients ''' - current_time = time.time() - - ''' - if current_time - self.last_frame_time < self.min_frame_interval: - print(f"Skipping frame due to rate limiting: {current_time - self.last_frame_time} seconds since last frame") - return # Skip frame if too soon - ''' - - self.last_frame_time = current_time - - # Generate a unique frame ID - use sequential IDs for better ordering - if not hasattr(self, 'next_frame_id'): - self.next_frame_id = 1 - - frame_id = self.next_frame_id - self.next_frame_id += 1 - - # Select the next client in round-robin fashion - client_index = self.current_client_index - self.current_client_index = (self.current_client_index + 1) % len(self.clients) - - # Store mapping of which client is processing this frame - self.client_frame_mapping[frame_id] = client_index - - # Set side data for the frame - frame.side_data.input = self.video_preprocess(frame) - frame.side_data.frame_id = frame_id - frame.side_data.skipped = False - - # Set 
receive time - frame.side_data.frame_received_time = current_time - frame.side_data.client_index = client_index - - # Log frame at input time to properly track input FPS - if self.frame_log_file: - await self.frame_log_queue.put({ - 'frame_id': frame_id, - 'frame_received_time': frame.side_data.frame_received_time, - 'frame_process_start_time': None, - 'frame_processed_time': None, - 'client_index': frame.side_data.client_index, - 'csv_path': self.frame_log_file - }) - - # Send frame to the selected client - self.clients[client_index].put_video_input(frame) - await self.video_incoming_frames.put(frame) - - async def put_audio_frame(self, frame: av.AudioFrame): - ''' Not implemented yet ''' - return - - # For now, only use the first client for audio - if not self.clients: - return - - frame.side_data.input = self.audio_preprocess(frame) - frame.side_data.skipped = False - self.clients[0].put_audio_input(frame) - await self.audio_incoming_frames.put(frame) - - def audio_preprocess(self, frame: av.AudioFrame) -> Union[torch.Tensor, np.ndarray]: - return frame.to_ndarray().ravel().reshape(-1, 2).mean(axis=1).astype(np.int16) - - def video_preprocess(self, frame: av.VideoFrame) -> Union[torch.Tensor, np.ndarray]: - """Preprocess a video frame before processing. 
- - Args: - frame: The video frame to preprocess - - Returns: - The preprocessed frame as a tensor or numpy array - """ - frame_np = frame.to_ndarray(format="rgb24").astype(np.float32) / 255.0 - return torch.from_numpy(frame_np).unsqueeze(0) - - def video_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.VideoFrame: - return av.VideoFrame.from_ndarray( - (output.squeeze(0).permute(1, 2, 0) * 255.0) - .clamp(0, 255) - .to(dtype=torch.uint8) - .cpu() - .numpy(), - format='rgb24' - ) - - def audio_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.AudioFrame: - return av.AudioFrame.from_ndarray(np.repeat(output, 2).reshape(1, -1)) - - async def get_processed_video_frame(self): - try: - frame = await self.video_incoming_frames.get() - - # Set process start time just before processing - frame_process_start_time = time.time() - - # Get the processed frame from our output queue - processed_frame_id, out_tensor = await self.processed_video_frames.get() - - # if (processed_frame_id != frame.side_data.frame_id): - # logger.warning(f"Processed frame ID {processed_frame_id} does not match expected frame ID {frame.side_data.frame_id}") - - # The processed frame and the video_incoming_frame is never the same - ''' - Processed frame ID 45 does not match expected frame ID 6 - Processed frame ID 47 does not match expected frame ID 7 - Processed frame ID 49 does not match expected frame ID 8 - ''' - # What does this mean? 
- - # Record the time when processing is complete - frame_processed_time = time.time() - - # Process the frame (post-processing) - processed_frame = self.video_postprocess(out_tensor) - processed_frame.pts = frame.pts - processed_frame.time_base = frame.time_base - - # Log frame timing with simplified metrics - if self.frame_log_file: - await self.frame_log_queue.put({ - 'frame_id': processed_frame_id, - 'frame_received_time': frame.side_data.frame_received_time, - 'frame_process_start_time': frame_process_start_time, - 'frame_processed_time': frame_processed_time, - 'client_index': frame.side_data.client_index, - 'csv_path': self.frame_log_file - }) - - return processed_frame - - except Exception as e: - logger.error(f"Error in get_processed_video_frame: {str(e)}") - # Create a black frame as fallback - black_frame = av.VideoFrame(width=self.width, height=self.height, format='rgb24') - - # Set timestamps to avoid TypeError during encoding - # Use default values that work with the aiortc encoding pipeline - black_frame.pts = 0 - black_frame.time_base = fractions.Fraction(1, 90000) # Standard video timebase - - return black_frame - - async def get_processed_audio_frame(self): - # Only use the first client for audio - if not self.clients: - logger.warning("No clients available for audio processing") - return av.AudioFrame(format='s16', layout='mono', samples=1024) - - frame = await self.audio_incoming_frames.get() - if frame.samples > len(self.processed_audio_buffer): - out_tensor = await self.clients[0].get_audio_output() - self.processed_audio_buffer = np.concatenate([self.processed_audio_buffer, out_tensor]) - out_data = self.processed_audio_buffer[:frame.samples] - self.processed_audio_buffer = self.processed_audio_buffer[frame.samples:] - - processed_frame = self.audio_postprocess(out_data) - processed_frame.pts = frame.pts - processed_frame.time_base = frame.time_base - processed_frame.sample_rate = frame.sample_rate - - return processed_frame - - async def 
get_nodes_info(self) -> Dict[str, Any]: - """Get information about all nodes in the current prompt including metadata.""" - # Note that we pull the node info from the first client (as they should all be the same) - # TODO: This is just retrofitting the functionality of the comfy embedded client, there could be major improvements here - nodes_info = await self.clients[0].get_available_nodes() - return nodes_info - - async def cleanup(self): - """Clean up resources used by the pipeline.""" - logger.info("Performing complete pipeline cleanup") - - # Cancel the dynamic output pacer task if it exists - if hasattr(self, "_pacer_task") and self._pacer_task is not None: - self._pacer_task.cancel() - try: - await self._pacer_task - except asyncio.CancelledError: - pass - self._pacer_task = None - - # Cancel any frame timeout tasks - if hasattr(self, "_timeout_task") and self._timeout_task is not None: - self._timeout_task.cancel() - try: - await self._timeout_task - except asyncio.CancelledError: - pass - self._timeout_task = None - - # Reset frame tracking state - self.next_expected_frame_id = None - self.ordered_frames.clear() - self.next_frame_id = 1 # Reset frame ID counter for new connection - self.client_frame_mapping.clear() - - # Clear any queued frames - while not self.video_incoming_frames.empty(): - try: - self.video_incoming_frames.get_nowait() - except asyncio.QueueEmpty: - break - - # Reset client state and connections - for i, client in enumerate(self.clients): - if client: - # Clean up client resources - try: - await client.cleanup() - except Exception as e: - logger.error(f"Error during client {i} cleanup: {e}") - - # Reset client connection status - if hasattr(client, 'ws_connected'): - client.ws_connected = False - - # Clear any client-specific execution state - if hasattr(client, 'prompt_executing'): - client.prompt_executing = False - - # Mark clients as needing reinitialization - self.clients_initialized = False - - # Clear any cached prompt mappings - 
if hasattr(self, "_prompt_ids"): - self._prompt_ids = {} - - # Reset warmup state - if hasattr(self, "_warmup_complete"): - self._warmup_complete = False - - # Reset any frame buffers - if hasattr(self, "_frame_buffer"): - self._frame_buffer.clear() - - # Ensure dynamic state like frame rate trackers are reset - if hasattr(self, "_last_frame_time"): - self._last_frame_time = None - - # Reset output counters - self.output_counter = 0 - - # Cancel frame logger task if it exists - if hasattr(self, 'frame_logger_task') and self.frame_logger_task: - self.frame_logger_task.cancel() - try: - await self.frame_logger_task - except asyncio.CancelledError: - pass - - logger.info("Pipeline cleanup completed, clients will be reinitialized on next connection") - - # This may not be needed anymore - more work is req to balance frame timing - ''' - async def _dynamic_output_pacer(self): - while self.running: - # Only release if the next expected frame is available - if self.next_expected_frame_id is not None and self.next_expected_frame_id in self.ordered_frames: - timestamp, tensor = self.ordered_frames.pop(self.next_expected_frame_id) - now = time.time() - - # Calculate dynamic interval based on output history - if self.last_output_time is not None: - actual_interval = now - self.last_output_time - self.frame_interval_history.append(actual_interval) - avg_interval = sum(self.frame_interval_history) / len(self.frame_interval_history) - self.output_interval = avg_interval - self.last_output_time = now - - await self.processed_video_frames.put((self.next_expected_frame_id, tensor)) - logger.debug(f"Released frame {self.next_expected_frame_id} to output queue") - - # Always increment to next sequential frame ID - self.next_expected_frame_id += 1 - - # Sleep for the dynamic interval, but don't sleep negative time - await asyncio.sleep(max(self.output_interval, 0.001)) - else: - # No frame ready, wait a bit and check again - await asyncio.sleep(0.005) - ''' - - async def 
start_clients(self): - """Start the clients based on the client_mode (TOML or spawn)""" - logger.info(f"Starting clients with mode: {self.client_mode}") - - self.clients = [] - self.startup_error = None - - try: - if hasattr(self, 'client_mode') and self.client_mode == "toml": - # Use config file to create clients - for server_config in self.servers: - self.clients.append(ComfyStreamClient( - host=server_config["host"], - port=server_config["port"], - spawn=False, - comfyui_log_level=self.comfyui_log_level, - )) - - elif hasattr(self, 'client_mode') and self.client_mode == "spawn": - # Spin up clients as external processes - ports = [] - cuda_device_list = [d.strip() for d in str(self.cuda_devices).split(',') if d.strip()] - for device_idx, cuda_device in enumerate(cuda_device_list): - for worker_idx in range(self.workers): - port = self.workers_start_port + len(ports) - ports.append(port) - client = ComfyStreamClient( - host="127.0.0.1", - port=port, - spawn=True, - comfyui_path=os.path.join(self.workspace, "main.py"), - workspace=self.workspace, - comfyui_args=[ - "--disable-cuda-malloc", - "--gpu-only", - "--preview-method", "none", - "--listen", - "--cuda-device", str(cuda_device), - "--fast", - "--enable-cors-header", "\"*\"", - "--port", str(port), - "--disable-xformers", - ], - comfyui_log_level=self.comfyui_log_level, - ) - self.clients.append(client) - logger.info(f"Created worker {worker_idx+1}/{self.workers} for CUDA device {cuda_device} on port {port}") - - else: - raise ValueError(f"Unknown client_mode: {getattr(self, 'client_mode', 'None')}") - - # Start all ComfyUI servers in parallel if in spawn mode - if hasattr(self, 'client_mode') and self.client_mode == "spawn": - try: - # Get all spawn clients - spawn_clients = [client for client in self.clients if client.spawn] - if spawn_clients: - logger.info(f"Starting {len(spawn_clients)} ComfyUI servers in parallel") - - # First validate all clients (keeping original validation logic) - for client in 
spawn_clients: - # These checks are from the original start_server method - if not client.comfyui_path: - raise ValueError("comfyui_path must be provided when spawn=True") - if not os.path.exists(client.comfyui_path): - raise FileNotFoundError(f"ComfyUI path does not exist: {client.comfyui_path}") - - # Start all server processes WITHOUT waiting for them to be ready - for client in spawn_clients: - client.launch_comfyui_server() - - # Now wait for all servers to be ready in parallel using thread pool - await asyncio.gather(*[ - asyncio.to_thread(client.wait_for_server_ready) - for client in spawn_clients - ]) - - except Exception as e: - # Clean up any clients that might have started - for client in self.clients: - if hasattr(client, '_comfyui_proc') and client._comfyui_proc: - try: - client._comfyui_proc.terminate() - except: - pass - - self.clients = [] - self.startup_error = str(e) - logger.error(f"Failed to start ComfyUI servers: {e}") - return None - - logger.info(f"Initialized {len(self.clients)} clients") - return self.clients - - except Exception as e: - self.startup_error = str(e) - logger.error(f"Error starting clients: {e}") - self.clients = [] - return None - - async def _process_frame_logs(self): - """Background task to process frame logs from queue""" - while self.running: - try: - # Get log entry from queue - log_entry = await self.frame_log_queue.get() - - log_frame_timing(**log_entry) - - # Mark task as done - self.frame_log_queue.task_done() - except asyncio.CancelledError: - break - except Exception as e: - logger.error(f"Error in frame logging: {e}") - -# For backwards compatibility, maintain the original Pipeline name -Pipeline = MultiServerPipeline \ No newline at end of file diff --git a/src/comfystream/pipeline_multi.py b/src/comfystream/pipeline_multi.py deleted file mode 100644 index eae6a7f3..00000000 --- a/src/comfystream/pipeline_multi.py +++ /dev/null @@ -1,349 +0,0 @@ -import av -import torch -import numpy as np -import asyncio 
-import logging -import time -import os -from collections import OrderedDict -from typing import Any, Dict, Union, List, Optional - -from comfystream.client_multi import ComfyStreamClient -from comfystream.server.utils import temporary_log_level -from comfystream.frame_logging import log_frame_timing -from comfystream.frame_proxy import FrameProxy - -WARMUP_RUNS = 5 - -logger = logging.getLogger(__name__) - - -class Pipeline: - """A pipeline for processing video and audio frames using ComfyUI. - - This class provides a high-level interface for processing video and audio frames - through a ComfyUI-based processing pipeline. It handles frame preprocessing, - postprocessing, and queue management. - """ - - def __init__(self, - width: int = 512, - height: int = 512, - max_workers: int = 1, - comfyui_inference_log_level: Optional[int] = None, - frame_log_file: Optional[str] = None, - **kwargs): - """Initialize the pipeline with the given configuration. - - Args: - width: Width of the video frames (default: 512) - height: Height of the video frames (default: 512) - max_workers: Number of worker processes (default: 1) - comfyui_inference_log_level: The logging level for ComfyUI inference. - frame_log_file: Path to frame timing log file - **kwargs: Additional arguments to pass to the ComfyStreamClient (cwd, disable_cuda_malloc, etc.) 
- """ - self.client = ComfyStreamClient( - max_workers=max_workers, - **kwargs) - self.width = width - self.height = height - - self.video_incoming_frames = asyncio.Queue() - self.audio_incoming_frames = asyncio.Queue() - - # Remove complex frame ordering - just use a simple buffer - self.output_buffer = asyncio.Queue(maxsize=6) # Small buffer to smooth output - self.input_frame_counter = 0 - - # Simple frame collection without ordering - self.collector_task = asyncio.create_task(self._collect_frames_simple()) - - self.processed_audio_buffer = np.array([], dtype=np.int16) - - self._comfyui_inference_log_level = comfyui_inference_log_level - - # Add a queue for frame log entries - self.running = True - self.frame_log_file = frame_log_file - self.frame_log_queue = None # Initialize to None by default - - if self.frame_log_file: - self.frame_log_queue = asyncio.Queue() - self.frame_logger_task = asyncio.create_task(self._process_frame_logs()) - - async def _collect_frames_simple(self): - """Simple frame collector - no ordering, just buffer""" - try: - while self.running: - try: - tensor = await asyncio.wait_for(self.client.get_video_output(), timeout=0.1) - if tensor is not None: - # Just put it in the buffer, don't worry about order - try: - self.output_buffer.put_nowait(tensor) - except asyncio.QueueFull: - # Drop oldest frame if buffer is full - try: - self.output_buffer.get_nowait() - self.output_buffer.put_nowait(tensor) - except asyncio.QueueEmpty: - pass - except asyncio.TimeoutError: - pass - except Exception as e: - logger.error(f"Error collecting frame: {e}") - - await asyncio.sleep(0.001) # Minimal sleep - - except asyncio.CancelledError: - pass - - async def initialize(self, prompts): - await self.set_prompts(prompts) - await self.warm_video() - - async def warm_video(self): - logger.info("[PipelineMulti] Starting warmup...") - for i in range(WARMUP_RUNS): - dummy_tensor = torch.randn(1, self.height, self.width, 3) - dummy_proxy = FrameProxy( - 
tensor=dummy_tensor, - width=self.width, - height=self.height, - pts=None, - time_base=None - ) - # Set frame_id for warmup frames (negative to distinguish from real frames) - dummy_proxy.side_data.frame_id = -(i + 1) - logger.debug(f"[PipelineMulti] Warmup: putting dummy frame {i+1}/{WARMUP_RUNS}") - self.client.put_video_input(dummy_proxy) - - # For warmup, we don't need to wait for ordered output - try: - out = await asyncio.wait_for(self.client.get_video_output(), timeout=30.0) - logger.debug(f"[PipelineMulti] Warmup: got output for dummy frame {i+1}/{WARMUP_RUNS}") - except asyncio.TimeoutError: - logger.warning(f"[PipelineMulti] Warmup frame {i+1} timed out") - - logger.info("[PipelineMulti] Warmup complete.") - - async def warm_audio(self): - """Warm up the audio processing pipeline with dummy frames.""" - dummy_frame = av.AudioFrame() - dummy_frame.side_data.input = np.random.randint(-32768, 32767, int(48000 * 0.5), dtype=np.int16) # TODO: adds a lot of delay if it doesn't match the buffer size, is warmup needed? - dummy_frame.sample_rate = 48000 - - for _ in range(WARMUP_RUNS): - self.client.put_audio_input(dummy_frame) - await self.client.get_audio_output() - - async def set_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]): - """Set the processing prompts for the pipeline. 
- - Args: - prompts: Either a single prompt dictionary or a list of prompt dictionaries - """ - if isinstance(prompts, list): - await self.client.set_prompts(prompts) - else: - await self.client.set_prompts([prompts]) - - async def update_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]): - """Update the existing processing prompts.""" - if isinstance(prompts, list): - await self.client.update_prompts(prompts) - else: - await self.client.update_prompts([prompts]) - - logger.info("Prompts updated") - - async def put_video_frame(self, frame: av.VideoFrame): - current_time = time.time() - frame.side_data.input = self.video_preprocess(frame) - frame.side_data.skipped = True - frame.side_data.frame_received_time = current_time - - # Assign frame ID and increment counter - frame_id = self.input_frame_counter - frame.side_data.frame_id = frame_id - frame.side_data.client_index = -1 - self.input_frame_counter += 1 - - self.client.put_video_input(frame) - await self.video_incoming_frames.put(frame) - - async def put_audio_frame(self, frame: av.AudioFrame): - """Queue an audio frame for processing. - - Args: - frame: The audio frame to process - """ - frame.side_data.input = self.audio_preprocess(frame) - frame.side_data.skipped = True - self.client.put_audio_input(frame) - await self.audio_incoming_frames.put(frame) - - def video_preprocess(self, frame: av.VideoFrame) -> Union[torch.Tensor, np.ndarray]: - """Preprocess a video frame before processing. - - Args: - frame: The video frame to preprocess - - Returns: - The preprocessed frame as a tensor or numpy array - """ - frame_np = frame.to_ndarray(format="rgb24").astype(np.float32) / 255.0 - return torch.from_numpy(frame_np).unsqueeze(0) - - def audio_preprocess(self, frame: av.AudioFrame) -> Union[torch.Tensor, np.ndarray]: - """Preprocess an audio frame before processing. 
- - Args: - frame: The audio frame to preprocess - - Returns: - The preprocessed frame as a tensor or numpy array - """ - return frame.to_ndarray().ravel().reshape(-1, 2).mean(axis=1).astype(np.int16) - - - def video_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.VideoFrame: - """Postprocess a tensor in BCHW format back to video frame.""" - # First ensure we have a tensor - if isinstance(output, np.ndarray): - output = torch.from_numpy(output) - - # Handle different tensor formats - if len(output.shape) == 4: # BCHW or BHWC format - if output.shape[1] != 3: # If BHWC format - output = output.permute(0, 3, 1, 2) # Convert BHWC to BCHW - output = output[0] # Take first image from batch -> CHW - elif len(output.shape) != 3: # Should be CHW at this point - raise ValueError(f"Unexpected tensor shape after batch removal: {output.shape}") - - # Convert CHW to HWC for video frame - output = output.permute(1, 2, 0) # CHW -> HWC - - # Convert to numpy and create video frame - return av.VideoFrame.from_ndarray( - (output * 255.0).clamp(0, 255).to(dtype=torch.uint8).cpu().numpy(), - format='rgb24' - ) - - def audio_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.AudioFrame: - """Postprocess an audio frame after processing. 
- - Args: - output: The processed output tensor or numpy array - - Returns: - The postprocessed audio frame - """ - return av.AudioFrame.from_ndarray(np.repeat(output, 2).reshape(1, -1)) - - # TODO: make it generic to support purely generative video cases - async def get_processed_video_frame(self) -> av.VideoFrame: - # Get input frame - frame = await self.video_incoming_frames.get() - - # Get any available output (don't match input to output) - try: - out_tensor = await asyncio.wait_for(self.output_buffer.get(), timeout=1.0) - except asyncio.TimeoutError: - # Fallback: return input frame to maintain stream - logger.warning("Output timeout, using input frame") - return frame - - # Process and return - processed_frame = self.video_postprocess(out_tensor) - processed_frame.pts = frame.pts - processed_frame.time_base = frame.time_base - - frame_processed_time = time.time() - - # Log frame at input time to properly track input FPS - if self.frame_log_file: - await self.frame_log_queue.put({ - 'frame_id': frame.side_data.frame_id, - 'frame_received_time': frame.side_data.frame_received_time, - 'frame_process_start_time': 0, # TODO: We dont know the start time of the frame processing - 'frame_processed_time': frame_processed_time, - 'client_index': frame.side_data.client_index, - 'csv_path': self.frame_log_file - }) - - return processed_frame - - async def get_processed_audio_frame(self) -> av.AudioFrame: - """Get the next processed audio frame. 
- - Returns: - The processed audio frame - """ - frame = await self.audio_incoming_frames.get() - if frame.samples > len(self.processed_audio_buffer): - async with temporary_log_level("comfy", self._comfyui_inference_log_level): - out_tensor = await self.client.get_audio_output() - self.processed_audio_buffer = np.concatenate([self.processed_audio_buffer, out_tensor]) - out_data = self.processed_audio_buffer[:frame.samples] - self.processed_audio_buffer = self.processed_audio_buffer[frame.samples:] - - processed_frame = self.audio_postprocess(out_data) - processed_frame.pts = frame.pts - processed_frame.time_base = frame.time_base - processed_frame.sample_rate = frame.sample_rate - - return processed_frame - - async def get_nodes_info(self) -> Dict[str, Any]: - """Get information about all nodes in the current prompt including metadata. - - Returns: - Dictionary containing node information - """ - nodes_info = await self.client.get_available_nodes() - return nodes_info - - async def cleanup(self): - """Clean up resources used by the pipeline.""" - logger.info("[PipelineMulti] Starting pipeline cleanup...") - - # Set running flag to false to stop frame processing - self.running = False - - # Cancel collector task - if hasattr(self, 'collector_task') and self.collector_task: - self.collector_task.cancel() - try: - await self.collector_task - except asyncio.CancelledError: - pass - - # Cancel frame logger task if it exists - if hasattr(self, 'frame_logger_task') and self.frame_logger_task: - self.frame_logger_task.cancel() - try: - await self.frame_logger_task - except asyncio.CancelledError: - pass - - # Clean up the client (this will gracefully shutdown workers) - await self.client.cleanup() - - logger.info("[PipelineMulti] Pipeline cleanup complete") - - async def _process_frame_logs(self): - """Background task to process frame logs from queue""" - while self.running: - try: - # Get log entry from queue - log_entry = await self.frame_log_queue.get() - 
log_frame_timing(**log_entry) - - # Mark task as done - self.frame_log_queue.task_done() - except asyncio.CancelledError: - break - except Exception as e: - logger.error(f"Error in frame logging: {e}") \ No newline at end of file diff --git a/src/comfystream/tensor_cache.py b/src/comfystream/tensor_cache.py index 0216f73b..1cc3c73b 100644 --- a/src/comfystream/tensor_cache.py +++ b/src/comfystream/tensor_cache.py @@ -1,14 +1,127 @@ +from comfystream import tensor_cache +import logging +import queue import torch -import numpy as np +import asyncio +import os +logger = logging.getLogger(__name__) -from queue import Queue -from asyncio import Queue as AsyncQueue +image_inputs = None +image_outputs = None -from typing import Union +audio_inputs = None +audio_outputs = None -# TODO: improve eviction policy fifo might not be the best, skip alternate frames instead -image_inputs: Queue[Union[torch.Tensor, np.ndarray]] = Queue(maxsize=1) -image_outputs: AsyncQueue[Union[torch.Tensor, np.ndarray]] = AsyncQueue() +# Global frame ID tracking for worker processes +current_frame_id = None +frame_id_mapping = {} # Maps tensor id to frame_id -audio_inputs: Queue[Union[torch.Tensor, np.ndarray]] = Queue() -audio_outputs: AsyncQueue[Union[torch.Tensor, np.ndarray]] = AsyncQueue() +class FrameData: + """Wrapper class to carry frame metadata through the processing pipeline""" + def __init__(self, tensor, frame_id=None): + self.tensor = tensor + self.frame_id = frame_id + +# Create wrapper classes that match the interface of the original queues +class MultiProcessInputQueue: + def __init__(self, mp_queue): + self.queue = mp_queue + + def get(self, block=True, timeout=None): + result = self.queue.get(block=block, timeout=timeout) + + # Extract frame metadata and store it globally for this worker + global current_frame_id + if hasattr(result, 'side_data') and hasattr(result.side_data, 'frame_id'): + current_frame_id = result.side_data.frame_id + # logger.info(f"[MultiProcessInputQueue] 
Frame {current_frame_id} retrieved by worker PID: {os.getpid()}") + + return result + + def get_nowait(self): + result = self.queue.get_nowait() + + # Extract frame metadata and store it globally for this worker + global current_frame_id + if hasattr(result, 'side_data') and hasattr(result.side_data, 'frame_id'): + current_frame_id = result.side_data.frame_id + # logger.info(f"[MultiProcessInputQueue] Frame {current_frame_id} retrieved (nowait) by worker PID: {os.getpid()}") + + return result + + def put(self, item, block=True, timeout=None): + return self.queue.put(item, block=block, timeout=timeout) + + def put_nowait(self, item): + return self.queue.put_nowait(item) + + def empty(self): + return self.queue.empty() + + def full(self): + return self.queue.full() + +class MultiProcessOutputQueue: + def __init__(self, mp_queue): + self.queue = mp_queue + + async def get(self): + # Convert synchronous get to async + loop = asyncio.get_event_loop() + result = await loop.run_in_executor(None, self.queue.get) + + # Check if this is a tuple with frame_id + if isinstance(result, tuple) and len(result) == 2: + frame_id, tensor = result + return (frame_id, tensor) + else: + # Backward compatibility - return just the tensor + return result + + async def put(self, item): + # Convert synchronous put to async + loop = asyncio.get_event_loop() + # Ensure tensor is on CPU before sending + if torch.is_tensor(item): + item = item.cpu() + # logger.info(f"[MultiProcessOutputQueue] Frame sent from worker PID: {os.getpid()}") + return await loop.run_in_executor(None, self.queue.put, item) + + def put_nowait(self, item): + try: + # Ensure tensor is on CPU + if torch.is_tensor(item): + item = item.cpu() + self.queue.put_nowait(item) + except queue.Full: + # Simple: drop one old frame and try again + try: + self.queue.get_nowait() + self.queue.put_nowait(item) + except: + pass # If still fails, just drop this frame + +def init_tensor_cache(image_inputs, image_outputs, audio_inputs, 
audio_outputs, workspace_path=None): + """Initialize the tensor cache for a worker process. + + Args: + image_inputs: Multiprocessing Queue for input images + image_outputs: Multiprocessing Queue for output images + audio_inputs: Multiprocessing Queue for input audio + audio_outputs: Multiprocessing Queue for output audio + workspace_path: The ComfyUI workspace path (should be C:\sd\ComfyUI-main) + """ + logger.info(f"[init_tensor_cache] Setting up tensor_cache queues in worker - PID: {os.getpid()}") + logger.info(f"[init_tensor_cache] Workspace path: {workspace_path}") + logger.info(f"[init_tensor_cache] Current working directory: {os.getcwd()}") + + # Replace the queues with our wrapped versions that match the original interface + tensor_cache.image_inputs = MultiProcessInputQueue(image_inputs) + tensor_cache.image_outputs = MultiProcessOutputQueue(image_outputs) + tensor_cache.audio_inputs = MultiProcessInputQueue(audio_inputs) + tensor_cache.audio_outputs = MultiProcessOutputQueue(audio_outputs) + + logger.info(f"[init_tensor_cache] tensor_cache.image_outputs id: {id(tensor_cache.image_outputs)} - PID: {os.getpid()}") + logger.info(f"[init_tensor_cache] Initialization complete - PID: {os.getpid()}") + + return os.getpid() # Return PID for verification \ No newline at end of file diff --git a/src/comfystream/tensor_cache_multi.py b/src/comfystream/tensor_cache_multi.py deleted file mode 100644 index ebcf113b..00000000 --- a/src/comfystream/tensor_cache_multi.py +++ /dev/null @@ -1,245 +0,0 @@ -# TODO: add better frame management, improve eviction policy fifo might not be the best, skip alternate frames instead -# TODO: also make the tensor_cache solution backward compatible for when not using process pool -- after the multi process solution is stable -from comfystream import tensor_cache -import logging -import queue -import torch -import asyncio -import os -from comfy.cmd import folder_paths -logger = logging.getLogger(__name__) - -image_inputs = None 
-image_outputs = None - -audio_inputs = None -audio_outputs = None - -# Global frame ID tracking for worker processes -current_frame_id = None -frame_id_mapping = {} # Maps tensor id to frame_id - -class FrameData: - """Wrapper class to carry frame metadata through the processing pipeline""" - def __init__(self, tensor, frame_id=None): - self.tensor = tensor - self.frame_id = frame_id - -# Create wrapper classes that match the interface of the original queues -class MultiProcessInputQueue: - def __init__(self, mp_queue): - self.queue = mp_queue - - def get(self, block=True, timeout=None): - result = self.queue.get(block=block, timeout=timeout) - - # Extract frame metadata and store it globally for this worker - global current_frame_id - if hasattr(result, 'side_data') and hasattr(result.side_data, 'frame_id'): - current_frame_id = result.side_data.frame_id - # logger.info(f"[MultiProcessInputQueue] Frame {current_frame_id} retrieved by worker PID: {os.getpid()}") - - return result - - def get_nowait(self): - result = self.queue.get_nowait() - - # Extract frame metadata and store it globally for this worker - global current_frame_id - if hasattr(result, 'side_data') and hasattr(result.side_data, 'frame_id'): - current_frame_id = result.side_data.frame_id - # logger.info(f"[MultiProcessInputQueue] Frame {current_frame_id} retrieved (nowait) by worker PID: {os.getpid()}") - - return result - - def put(self, item, block=True, timeout=None): - return self.queue.put(item, block=block, timeout=timeout) - - def put_nowait(self, item): - return self.queue.put_nowait(item) - - def empty(self): - return self.queue.empty() - - def full(self): - return self.queue.full() - -class MultiProcessOutputQueue: - def __init__(self, mp_queue): - self.queue = mp_queue - - async def get(self): - # Convert synchronous get to async - loop = asyncio.get_event_loop() - result = await loop.run_in_executor(None, self.queue.get) - - # Check if this is a tuple with frame_id - if 
isinstance(result, tuple) and len(result) == 2: - frame_id, tensor = result - return (frame_id, tensor) - else: - # Backward compatibility - return just the tensor - return result - - async def put(self, item): - # Convert synchronous put to async - loop = asyncio.get_event_loop() - # Ensure tensor is on CPU before sending - if torch.is_tensor(item): - item = item.cpu() - # logger.info(f"[MultiProcessOutputQueue] Frame sent from worker PID: {os.getpid()}") - return await loop.run_in_executor(None, self.queue.put, item) - - def put_nowait(self, item): - try: - # Ensure tensor is on CPU - if torch.is_tensor(item): - item = item.cpu() - self.queue.put_nowait(item) - except queue.Full: - # Simple: drop one old frame and try again - try: - self.queue.get_nowait() - self.queue.put_nowait(item) - except: - pass # If still fails, just drop this frame - -def init_tensor_cache(image_inputs, image_outputs, audio_inputs, audio_outputs, workspace_path=None): - """Initialize the tensor cache for a worker process. 
- - Args: - image_inputs: Multiprocessing Queue for input images - image_outputs: Multiprocessing Queue for output images - audio_inputs: Multiprocessing Queue for input audio - audio_outputs: Multiprocessing Queue for output audio - workspace_path: The ComfyUI workspace path (should be C:\sd\ComfyUI-main) - """ - logger.info(f"[init_tensor_cache] Setting up tensor_cache queues in worker - PID: {os.getpid()}") - logger.info(f"[init_tensor_cache] Workspace path: {workspace_path}") - logger.info(f"[init_tensor_cache] Current working directory: {os.getcwd()}") - - # Initialize folder_paths in worker process - # Another attempt to fix the tensorrt paths issue via ProcessPoolExecutor - ''' - try: - # Import both possible folder_paths modules - from comfy.cmd import folder_paths as comfy_folder_paths - - # Also try to import the direct folder_paths (which TensorRT loader uses) - import sys - try: - import folder_paths as direct_folder_paths - logger.info("[init_tensor_cache] Successfully imported direct folder_paths") - except ImportError: - # If direct import fails, create an alias - sys.modules['folder_paths'] = comfy_folder_paths - direct_folder_paths = comfy_folder_paths - logger.info("[init_tensor_cache] Created folder_paths alias to comfy.cmd.folder_paths") - - logger.info(f"[init_tensor_cache] comfy_folder_paths.models_dir: {comfy_folder_paths.models_dir}") - logger.info(f"[init_tensor_cache] direct_folder_paths.models_dir: {direct_folder_paths.models_dir}") - - # Use the workspace_path as the base directory for TensorRT paths - if workspace_path: - base_dir = workspace_path - else: - # Fallback to the parent directory of models_dir - base_dir = os.path.dirname(comfy_folder_paths.models_dir) - - # Set up both models/tensorrt and outputs/tensorrt directories - tensorrt_models_dir = os.path.join(base_dir, "models", "tensorrt") - tensorrt_outputs_dir = os.path.join(base_dir, "outputs", "tensorrt") - - logger.info(f"[init_tensor_cache] TensorRT models folder: 
{tensorrt_models_dir}") - logger.info(f"[init_tensor_cache] TensorRT outputs folder: {tensorrt_outputs_dir}") - logger.info(f"[init_tensor_cache] Models dir exists: {os.path.exists(tensorrt_models_dir)}") - logger.info(f"[init_tensor_cache] Outputs dir exists: {os.path.exists(tensorrt_outputs_dir)}") - - # Register TensorRT paths in BOTH folder_paths modules - tensorrt_config = ([tensorrt_models_dir, tensorrt_outputs_dir], {".engine"}) - - # Update comfy.cmd.folder_paths - comfy_folder_paths.folder_names_and_paths["tensorrt"] = tensorrt_config - logger.info("[init_tensor_cache] Registered TensorRT paths in comfy.cmd.folder_paths") - - # Update direct folder_paths (which TensorRT loader uses) - direct_folder_paths.folder_names_and_paths["tensorrt"] = tensorrt_config - logger.info("[init_tensor_cache] Registered TensorRT paths in direct folder_paths") - - # Also update any existing modules in sys.modules - for module_name, module in sys.modules.items(): - if (module_name.endswith('folder_paths') or module_name == 'folder_paths') and hasattr(module, 'folder_names_and_paths'): - module.folder_names_and_paths["tensorrt"] = tensorrt_config - logger.info(f"[init_tensor_cache] Updated TensorRT paths in {module_name}") - - # Verify the registration worked - logger.info(f"[init_tensor_cache] comfy_folder_paths TensorRT files: {comfy_folder_paths.get_filename_list('tensorrt')}") - logger.info(f"[init_tensor_cache] direct_folder_paths TensorRT files: {direct_folder_paths.get_filename_list('tensorrt')}") - - except Exception as e: - logger.error(f"[init_tensor_cache] Error initializing folder_paths: {e}") - import traceback - logger.error(f"[init_tensor_cache] Traceback: {traceback.format_exc()}") - ''' - - # Replace the queues with our wrapped versions that match the original interface - tensor_cache.image_inputs = MultiProcessInputQueue(image_inputs) - tensor_cache.image_outputs = MultiProcessOutputQueue(image_outputs) - tensor_cache.audio_inputs = 
MultiProcessInputQueue(audio_inputs) - tensor_cache.audio_outputs = MultiProcessOutputQueue(audio_outputs) - - logger.info(f"[init_tensor_cache] tensor_cache.image_outputs id: {id(tensor_cache.image_outputs)} - PID: {os.getpid()}") - logger.info(f"[init_tensor_cache] Initialization complete - PID: {os.getpid()}") - - return os.getpid() # Return PID for verification - -# THis was an attempt to fix the tensorrt paths issue via ProcessPoolExecutor -''' -def register_tensorrt_paths(workspace_path=None): - """Register TensorRT paths in folder_paths at import time""" - try: - # Use workspace_path if provided, otherwise fall back to folder_paths.models_dir - if workspace_path: - base_dir = workspace_path - tensorrt_models_dir = os.path.join(base_dir, "models", "tensorrt") - else: - # Create tensorrt subdirectory in the models directory - tensorrt_models_dir = os.path.join(folder_paths.models_dir, "tensorrt") - - print(f"[TensorRT] workspace_path: {workspace_path}") - print(f"[TensorRT] folder_paths.models_dir: {folder_paths.models_dir}") - print(f"[TensorRT] Registering paths:") - print(f"[TensorRT] - Models: {tensorrt_models_dir}") - - if "tensorrt" in folder_paths.folder_names_and_paths: - # Update existing registration - existing_paths = folder_paths.folder_names_and_paths["tensorrt"][0] - if tensorrt_models_dir not in existing_paths: - existing_paths.append(tensorrt_models_dir) - folder_paths.folder_names_and_paths["tensorrt"][1].add(".engine") - else: - # Create new registration (same as Depth-Anything approach) - folder_paths.folder_names_and_paths["tensorrt"] = ( - [tensorrt_models_dir], - {".engine"} - ) - - # Verify registration - available_files = folder_paths.get_filename_list("tensorrt") - print(f"[TensorRT] Available engine files: {available_files}") - - except Exception as e: - print(f"[TensorRT] Error registering paths: {e}") - import traceback - traceback.print_exc() - # Fallback to original behavior - if "tensorrt" in folder_paths.folder_names_and_paths: 
- folder_paths.folder_names_and_paths["tensorrt"][0].append( - os.path.join(folder_paths.models_dir, "tensorrt")) - folder_paths.folder_names_and_paths["tensorrt"][1].add(".engine") - else: - folder_paths.folder_names_and_paths["tensorrt"] = ( - [os.path.join(folder_paths.models_dir, "tensorrt")], - {".engine"} - ) -''' \ No newline at end of file diff --git a/src/comfystream/utils_api.py b/src/comfystream/utils_api.py deleted file mode 100644 index dbbb6790..00000000 --- a/src/comfystream/utils_api.py +++ /dev/null @@ -1,154 +0,0 @@ -import copy -import random - -from typing import Dict, Any - -import logging -logger = logging.getLogger(__name__) - -def create_load_tensor_node(): - return { - "inputs": { - "tensor_data": "" # Empty tensor data that will be filled at runtime - }, - "class_type": "LoadTensorAPI", - "_meta": {"title": "Load Tensor (API)"}, - } - -def create_load_image_base64_node(): - return { - "inputs": { - "image": "" # Should be "image" not "image_data" to match LoadImageBase64 - }, - "class_type": "LoadImageBase64", - "_meta": {"title": "Load Image Base64 (ComfyStream)"}, - } - -def create_save_tensor_node(inputs: Dict[Any, Any]): - """Create a SaveTensorAPI node with proper input formatting""" - # Make sure images input is properly formatted [node_id, output_index] - images_input = inputs.get("images") - - # If images input is not properly formatted as [node_id, output_index] - if not isinstance(images_input, list) or len(images_input) != 2: - print(f"Warning: Invalid images input format: {images_input}, using default") - images_input = ["", 0] # Default empty value - - return { - "inputs": { - "images": images_input, # Should be [node_id, output_index] - "format": "png", # Better default than JPG for quality - "quality": 95 - }, - "class_type": "SaveTensorAPI", - "_meta": {"title": "Save Tensor (API)"}, - } - -def create_send_image_websocket_node(inputs: Dict[Any, Any]): - # Get the correct image input reference - images_input = 
inputs.get("images", inputs.get("image")) - - # If not properly formatted, use default - if not images_input: - images_input = ["", 0] # Default empty value - - return { - "inputs": { - "images": images_input, - "format": "PNG" # Default format - }, - "class_type": "SendImageWebsocket", - "_meta": {"title": "Send Image Websocket (ComfyStream)"}, - } - -def create_send_tensor_websocket_node(inputs: Dict[Any, Any]): - # Get the correct image input reference - tensor_input = inputs.get("images", inputs.get("tensor")) - - if not tensor_input: - logging.warning("No valid tensor input found for SendTensorWebSocket node") - tensor_input = ["", 0] # Default empty value - - return { - "inputs": { - "tensor": tensor_input - }, - "class_type": "SendTensorWebSocket", - "_meta": {"title": "Save Tensor WebSocket (ComfyStream)"}, - } - -def convert_prompt(prompt): - logging.info("Converting prompt: %s", prompt) - - # Initialize counters - num_primary_inputs = 0 - num_inputs = 0 - num_outputs = 0 - - keys = { - "PrimaryInputLoadImage": [], - "LoadImage": [], - "PreviewImage": [], - "SaveImage": [], - } - - # Set random seeds for any seed nodes - for key, node in prompt.items(): - if not isinstance(node, dict) or "inputs" not in node: - continue - - # Check if this node has a seed input directly - if "seed" in node.get("inputs", {}): - # Generate a random seed (same range as JavaScript's Math.random() * 18446744073709552000) - random_seed = random.randint(0, 18446744073709551615) - node["inputs"]["seed"] = random_seed - logger.debug(f"Set random seed {random_seed} for node {key}") - - for key, node in prompt.items(): - class_type = node.get("class_type") - - # Collect keys for nodes that might need to be replaced - if class_type in keys: - keys[class_type].append(key) - - # Count inputs and outputs - if class_type == "PrimaryInputLoadImage": - num_primary_inputs += 1 - elif class_type in ["LoadImage", "LoadImageBase64"]: - num_inputs += 1 - elif class_type in ["PreviewImage", 
"SaveImage", "SendImageWebsocket", "SendTensorWebSocket"]: - num_outputs += 1 - - # Only handle single primary input - if num_primary_inputs > 1: - raise Exception("too many primary inputs in prompt") - - # If there are no primary inputs, only handle single input - if num_primary_inputs == 0 and num_inputs > 1: - raise Exception("too many inputs in prompt") - - # Only handle single output for now - if num_outputs > 1: - raise Exception("too many outputs in prompt") - - if num_primary_inputs + num_inputs == 0: - raise Exception("missing input") - - if num_outputs == 0: - raise Exception("missing output") - - # Replace nodes with proper implementations - for key in keys["PrimaryInputLoadImage"]: - prompt[key] = create_load_image_base64_node() - - if num_primary_inputs == 0 and len(keys["LoadImage"]) == 1: - prompt[keys["LoadImage"][0]] = create_load_image_base64_node() - - for key in keys["PreviewImage"] + keys["SaveImage"]: - node = prompt[key] - # prompt[key] = create_save_image_node(node["inputs"]) - prompt[key] = create_send_image_websocket_node(node["inputs"]) # TESTING - - # TODO: Validate the processed prompt input - - return prompt From 02e8d0824c513ab4309fbfe2f322bcf04c591b1e Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Tue, 24 Jun 2025 21:37:50 -0400 Subject: [PATCH 42/42] Testing - small fix with merge. 
--- src/comfystream/client.py | 2 +- src/comfystream/pipeline.py | 6 +----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/comfystream/client.py b/src/comfystream/client.py index 3e402549..f142f6a9 100644 --- a/src/comfystream/client.py +++ b/src/comfystream/client.py @@ -298,7 +298,7 @@ def put_audio_input(self, frame): async def get_video_output(self): try: - logger.debug(f"[ComfyStreamClient] get_video_output called - PID: {os.getpid()}") + # logger.info(f"[ComfyStreamClient] get_video_output called - PID: {os.getpid()}") tensor = await asyncio.wait_for( asyncio.get_event_loop().run_in_executor(None, self.image_outputs.get), timeout=5.0 diff --git a/src/comfystream/pipeline.py b/src/comfystream/pipeline.py index 3f2907bc..3b39229f 100644 --- a/src/comfystream/pipeline.py +++ b/src/comfystream/pipeline.py @@ -260,12 +260,8 @@ async def get_processed_video_frame(self) -> av.VideoFrame: Returns: The processed video frame """ - - async with temporary_log_level("comfy", self._comfyui_inference_log_level): - out_tensor = await self.client.get_video_output() + # Get input frame frame = await self.video_incoming_frames.get() - while frame.side_data.skipped: - frame = await self.video_incoming_frames.get() # Get any available output (don't match input to output) try: