diff --git a/nodes/__init__.py b/nodes/__init__.py index 46aa93c6..4682f0d6 100644 --- a/nodes/__init__.py +++ b/nodes/__init__.py @@ -3,6 +3,7 @@ from .audio_utils import * from .tensor_utils import * from .video_stream_utils import * +from .native_utils import * from .api import * from .web import * @@ -11,7 +12,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {} # Import and update mappings from submodules -for module in [audio_utils, tensor_utils, video_stream_utils, api, web]: +for module in [audio_utils, tensor_utils, video_stream_utils, api, web, native_utils]: if hasattr(module, 'NODE_CLASS_MAPPINGS'): NODE_CLASS_MAPPINGS.update(module.NODE_CLASS_MAPPINGS) if hasattr(module, 'NODE_DISPLAY_NAME_MAPPINGS'): diff --git a/nodes/native_utils/__init__.py b/nodes/native_utils/__init__.py new file mode 100644 index 00000000..472e5411 --- /dev/null +++ b/nodes/native_utils/__init__.py @@ -0,0 +1,20 @@ +from .load_image_base64 import LoadImageBase64 +from .send_image_websocket import SendImageWebsocket +from .send_tensor_websocket import SendTensorWebSocket + +# This dictionary is used by ComfyUI to register the nodes +NODE_CLASS_MAPPINGS = { + "LoadImageBase64": LoadImageBase64, + "SendImageWebsocket": SendImageWebsocket, + "SendTensorWebSocket": SendTensorWebSocket +} + +# This dictionary provides display names for the nodes in the UI +NODE_DISPLAY_NAME_MAPPINGS = { + "LoadImageBase64": "Load Image Base64 (ComfyStream)", + "SendImageWebsocket": "Send Image Websocket (ComfyStream)", + "SendTensorWebSocket": "Save Tensor WebSocket (ComfyStream)" +} + +# Export these variables for ComfyUI to use +__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] diff --git a/nodes/native_utils/load_image_base64.py b/nodes/native_utils/load_image_base64.py new file mode 100644 index 00000000..f46a90df --- /dev/null +++ b/nodes/native_utils/load_image_base64.py @@ -0,0 +1,37 @@ +# borrowed from Acly's comfyui-tooling-nodes +# 
https://github.com/Acly/comfyui-tooling-nodes/blob/main/nodes.py + +# TODO: I think we can receive tensor data directly from the pipeline through the /prompt endpoint as JSON +# This may be more efficient than sending base64 encoded images through the websocket and +# allow for alternative data formats. + +from PIL import Image +import base64 +import numpy as np +import torch +from io import BytesIO + +class LoadImageBase64: + @classmethod + def INPUT_TYPES(s): + return {"required": {"image": ("STRING", {"multiline": False})}} + + RETURN_TYPES = ("IMAGE", "MASK") + CATEGORY = "external_tooling" + FUNCTION = "load_image" + + def load_image(self, image): + imgdata = base64.b64decode(image) + img = Image.open(BytesIO(imgdata)) + + if "A" in img.getbands(): + mask = np.array(img.getchannel("A")).astype(np.float32) / 255.0 + mask = torch.from_numpy(mask) + else: + mask = None + + img = img.convert("RGB") + img = np.array(img).astype(np.float32) / 255.0 + img = torch.from_numpy(img)[None,] + + return (img, mask) \ No newline at end of file diff --git a/nodes/native_utils/send_image_websocket.py b/nodes/native_utils/send_image_websocket.py new file mode 100644 index 00000000..590d3b7e --- /dev/null +++ b/nodes/native_utils/send_image_websocket.py @@ -0,0 +1,44 @@ +# borrowed from Acly's comfyui-tooling-nodes +# https://github.com/Acly/comfyui-tooling-nodes/blob/main/nodes.py + +# TODO: I think we can send tensor data directly to the pipeline in the websocket response. +# Worth talking to ComfyAnonymous about this. 
+ +import numpy as np +from PIL import Image +from server import PromptServer, BinaryEventTypes + +class SendImageWebsocket: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "images": ("IMAGE",), + "format": (["PNG", "JPEG"], {"default": "PNG"}), + } + } + + RETURN_TYPES = () + FUNCTION = "send_images" + OUTPUT_NODE = True + CATEGORY = "external_tooling" + + def send_images(self, images, format): + results = [] + for tensor in images: + array = 255.0 * tensor.cpu().numpy() + image = Image.fromarray(np.clip(array, 0, 255).astype(np.uint8)) + + server = PromptServer.instance + server.send_sync( + BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, + [format, image, None], + server.client_id, + ) + results.append({ + "source": "websocket", + "content-type": f"image/{format.lower()}", + "type": "output", + }) + + return {"ui": {"images": results}} \ No newline at end of file diff --git a/nodes/native_utils/send_tensor_websocket.py b/nodes/native_utils/send_tensor_websocket.py new file mode 100644 index 00000000..33104896 --- /dev/null +++ b/nodes/native_utils/send_tensor_websocket.py @@ -0,0 +1,289 @@ +import torch +import numpy as np +import base64 +import logging +import json +import traceback +import sys + +logger = logging.getLogger(__name__) + +# Log when the module is loaded +logger.debug("------------------ SendTensorWebSocket Module Loaded ------------------") + +class SendTensorWebSocket: + def __init__(self): + # Output directory is not needed as we send via WebSocket + logger.debug("SendTensorWebSocket instance created") + pass + + @classmethod + def INPUT_TYPES(cls): + logger.debug("SendTensorWebSocket.INPUT_TYPES called") + return { + "required": { + # Accept IMAGE input (typical output from VAE Decode) + "tensor": ("IMAGE", ), + }, + "hidden": { + # These are needed for ComfyUI execution context + "prompt": "PROMPT", + "extra_pnginfo": "EXTRA_PNGINFO" + }, + } + + RETURN_TYPES = () # No direct output connection to other nodes + FUNCTION = 
"save_tensor" + OUTPUT_NODE = True + CATEGORY = "ComfyStream/native" + + def save_tensor(self, tensor, prompt=None, extra_pnginfo=None): + logger.debug("========== SendTensorWebSocket.save_tensor STARTED ==========") + logger.info(f"SendTensorWebSocket received input. Type: {type(tensor)}") + logger.debug(f"SendTensorWebSocket node is processing tensor with id: {id(tensor)}") + + # Log memory usage for debugging + if torch.cuda.is_available(): + try: + logger.debug(f"CUDA memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB") + logger.debug(f"CUDA memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB") + except Exception as e: + logger.error(f"Error checking CUDA memory: {e}") + + if tensor is None: + logger.error("SendTensorWebSocket received None tensor.") + # Return error directly without ui nesting + return {"comfystream_tensor_output": {"error": "Input tensor was None"}} + + try: + # Log details about the tensor before processing + logger.debug(f"Process tensor of type: {type(tensor)}") + + if isinstance(tensor, torch.Tensor): + logger.debug("Processing torch.Tensor...") + logger.info(f"Input tensor details: shape={tensor.shape}, dtype={tensor.dtype}, device={tensor.device}") + + # Additional handling for IMAGE-type tensors (0-1 float values, BCHW format) + if len(tensor.shape) == 4: # BCHW format (batch) + logger.debug(f"Tensor is batched (BCHW): {tensor.shape}") + logger.info(f"Tensor appears to be IMAGE batch. Min: {tensor.min().item()}, Max: {tensor.max().item()}") + logger.debug(f"First batch slice: min={tensor[0].min().item()}, max={tensor[0].max().item()}") + tensor = tensor[0] # Select first image from batch + logger.debug(f"Selected first batch element. New shape: {tensor.shape}") + + if len(tensor.shape) == 3: # CHW format (single image) + logger.debug(f"Tensor is CHW format: {tensor.shape}") + logger.info(f"Tensor appears to be single IMAGE. 
Min: {tensor.min().item()}, Max: {tensor.max().item()}") + + # Log first few values for debugging + logger.debug(f"First few values: {tensor.flatten()[:10].tolist()}") + + # Ensure the tensor is on CPU and detached + logger.debug(f"Moving tensor to CPU. Current device: {tensor.device}") + try: + tensor = tensor.cpu().detach() + logger.debug(f"Tensor moved to CPU successfully: {tensor.device}") + except Exception as e: + logger.error(f"Error moving tensor to CPU: {e}") + logger.error(traceback.format_exc()) + return {"comfystream_tensor_output": {"error": f"CPU transfer error: {str(e)}"}} + + # Convert to numpy + logger.debug("Converting tensor to numpy array...") + try: + np_array = tensor.numpy() + logger.debug(f"Conversion to numpy successful: shape={np_array.shape}, dtype={np_array.dtype}") + logger.debug(f"NumPy array memory usage: {np_array.nbytes / 1024**2:.2f} MB") + except Exception as e: + logger.error(f"Error converting tensor to numpy: {e}") + logger.error(traceback.format_exc()) + return {"comfystream_tensor_output": {"error": f"NumPy conversion error: {str(e)}"}} + + # Encode the tensor + logger.debug("Converting numpy array to bytes...") + try: + tensor_bytes = np_array.tobytes() + logger.debug(f"Tensor converted to bytes: {len(tensor_bytes)} bytes") + except Exception as e: + logger.error(f"Error converting numpy array to bytes: {e}") + logger.error(traceback.format_exc()) + return {"comfystream_tensor_output": {"error": f"Bytes conversion error: {str(e)}"}} + + logger.debug("Encoding bytes to base64...") + try: + b64_data = base64.b64encode(tensor_bytes).decode('utf-8') + b64_size = len(b64_data) + logger.debug(f"Base64 encoding successful: {b64_size} characters") + if b64_size > 100: + logger.debug(f"Base64 sample: {b64_data[:50]}...{b64_data[-50:]}") + except Exception as e: + logger.error(f"Error encoding to base64: {e}") + logger.error(traceback.format_exc()) + return {"comfystream_tensor_output": {"error": f"Base64 encoding error: {str(e)}"}} + 
+ # Prepare metadata + shape = list(np_array.shape) + dtype = str(np_array.dtype) + + logger.info(f"SendTensorWebSocket prepared tensor: shape={shape}, dtype={dtype}") + + # Construct the return value with simplified structure (no ui nesting) + success_output = { + "comfystream_tensor_output": { + "b64_data": b64_data, + "shape": shape, + "dtype": dtype + } + } + + # Log the structure of the output (avoid logging the actual base64 data which is large) + output_structure = { + "comfystream_tensor_output": { + "b64_data": f"(base64 string of {b64_size} bytes)", + "shape": shape, + "dtype": dtype + } + } + logger.info(f"SendTensorWebSocket returning SUCCESS data structure: {json.dumps(output_structure)}") + logger.debug("========== SendTensorWebSocket.save_tensor COMPLETED SUCCESSFULLY ==========") + + return success_output + + elif isinstance(tensor, np.ndarray): + logger.debug("Processing numpy.ndarray...") + logger.info(f"Input is numpy array: shape={tensor.shape}, dtype={tensor.dtype}") + + # Log memory details + logger.debug(f"NumPy array memory usage: {tensor.nbytes / 1024**2:.2f} MB") + logger.debug(f"First few values: {tensor.flatten()[:10].tolist()}") + + # Handle numpy array directly + logger.debug("Converting numpy array to bytes...") + try: + tensor_bytes = tensor.tobytes() + logger.debug(f"NumPy array converted to bytes: {len(tensor_bytes)} bytes") + except Exception as e: + logger.error(f"Error converting numpy array to bytes: {e}") + logger.error(traceback.format_exc()) + return {"comfystream_tensor_output": {"error": f"NumPy to bytes error: {str(e)}"}} + + logger.debug("Encoding bytes to base64...") + try: + b64_data = base64.b64encode(tensor_bytes).decode('utf-8') + b64_size = len(b64_data) + logger.debug(f"Base64 encoding successful: {b64_size} characters") + if b64_size > 100: + logger.debug(f"Base64 sample: {b64_data[:50]}...{b64_data[-50:]}") + except Exception as e: + logger.error(f"Error encoding numpy to base64: {e}") + 
logger.error(traceback.format_exc()) + return {"comfystream_tensor_output": {"error": f"NumPy base64 encoding error: {str(e)}"}} + + shape = list(tensor.shape) + dtype = str(tensor.dtype) + + logger.debug("Constructing success output for numpy array...") + success_output = { + "comfystream_tensor_output": { + "b64_data": b64_data, + "shape": shape, + "dtype": dtype + } + } + logger.info(f"SendTensorWebSocket returning SUCCESS from numpy array: shape={shape}, dtype={dtype}") + logger.debug("========== SendTensorWebSocket.save_tensor COMPLETED SUCCESSFULLY ==========") + return success_output + + elif isinstance(tensor, list): + logger.debug("Processing list input...") + logger.info(f"Input is a list of length {len(tensor)}") + + if len(tensor) > 0: + first_item = tensor[0] + logger.debug(f"First item type: {type(first_item)}") + + if isinstance(first_item, torch.Tensor): + logger.debug("Processing first tensor from list...") + logger.debug(f"First tensor details: shape={first_item.shape}, dtype={first_item.dtype}, device={first_item.device}") + + # Log first few values + logger.debug(f"First few values: {first_item.flatten()[:10].tolist()}") + + # Process first tensor in the list + try: + logger.debug("Moving tensor to CPU and detaching...") + np_array = first_item.cpu().detach().numpy() + logger.debug(f"Conversion successful: shape={np_array.shape}, dtype={np_array.dtype}") + except Exception as e: + logger.error(f"Error processing first tensor in list: {e}") + logger.error(traceback.format_exc()) + return {"comfystream_tensor_output": {"error": f"List tensor processing error: {str(e)}"}} + + try: + logger.debug("Converting numpy array to bytes...") + tensor_bytes = np_array.tobytes() + logger.debug(f"Converted to bytes: {len(tensor_bytes)} bytes") + except Exception as e: + logger.error(f"Error converting list tensor to bytes: {e}") + logger.error(traceback.format_exc()) + return {"comfystream_tensor_output": {"error": f"List tensor bytes conversion error: 
{str(e)}"}} + + try: + logger.debug("Encoding bytes to base64...") + b64_data = base64.b64encode(tensor_bytes).decode('utf-8') + b64_size = len(b64_data) + logger.debug(f"Base64 encoding successful: {b64_size} characters") + except Exception as e: + logger.error(f"Error encoding list tensor to base64: {e}") + logger.error(traceback.format_exc()) + return {"comfystream_tensor_output": {"error": f"List tensor base64 encoding error: {str(e)}"}} + + shape = list(np_array.shape) + dtype = str(np_array.dtype) + + logger.debug("Constructing success output for list tensor...") + success_output = { + "comfystream_tensor_output": { + "b64_data": b64_data, + "shape": shape, + "dtype": dtype + } + } + logger.info(f"SendTensorWebSocket returning SUCCESS from list's first tensor: shape={shape}, dtype={dtype}") + logger.debug("========== SendTensorWebSocket.save_tensor COMPLETED SUCCESSFULLY ==========") + return success_output + else: + logger.error(f"First item in list is not a tensor but {type(first_item)}") + if hasattr(first_item, '__dict__'): + logger.debug(f"First item attributes: {dir(first_item)}") + + # If we got here, couldn't process the list + logger.error(f"Unable to process list input: invalid content types") + list_types = [type(x).__name__ for x in tensor[:3]] + error_msg = f"Unsupported list content: {list_types}..." 
+ logger.debug("========== SendTensorWebSocket.save_tensor FAILED ==========") + return {"comfystream_tensor_output": {"error": error_msg}} + + else: + # Unsupported type + error_msg = f"Unsupported tensor type: {type(tensor)}" + logger.error(error_msg) + if hasattr(tensor, '__dict__'): + logger.debug(f"Tensor attributes: {dir(tensor)}") + + logger.debug("========== SendTensorWebSocket.save_tensor FAILED ==========") + return {"comfystream_tensor_output": {"error": error_msg}} + + except Exception as e: + logger.exception(f"Error serializing tensor in SendTensorWebSocket: {e}") + + # Get detailed exception info + exc_type, exc_value, exc_traceback = sys.exc_info() + tb_lines = traceback.format_exception(exc_type, exc_value, exc_traceback) + tb_text = ''.join(tb_lines) + logger.debug(f"Exception traceback:\n{tb_text}") + + error_output = {"comfystream_tensor_output": {"error": f"{str(e)} - See save_tensor_websocket_debug.log for details"}} + logger.info(f"SendTensorWebSocket returning ERROR data: {error_output}") + logger.debug("========== SendTensorWebSocket.save_tensor FAILED WITH EXCEPTION ==========") + return error_output \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index d25686ad..4f6373ac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,12 @@ asyncio -comfyui @ git+https://github.com/hiddenswitch/ComfyUI.git@ce3583ad42c024b8f060d0002cbe20c265da6dc8 +comfyui @ git+https://github.com/hiddenswitch/ComfyUI.git@e034d0bb24b0d23e3c40419c68689464dec67690 aiortc aiohttp aiohttp_cors toml +tomli +websockets twilio prometheus_client librosa +torchvision \ No newline at end of file diff --git a/server/app.py b/server/app.py index 34ff5af3..b22a836c 100644 --- a/server/app.py +++ b/server/app.py @@ -7,12 +7,13 @@ import time import secrets import torch +import signal # Initialize CUDA before any other imports to prevent core dump. 
if torch.cuda.is_available(): torch.cuda.init() - - + torch.cuda.empty_cache() + from aiohttp import web, MultipartWriter from aiohttp_cors import setup as setup_cors, ResourceOptions from aiohttp import web @@ -37,6 +38,8 @@ logging.getLogger("aiortc.rtcrtpsender").setLevel(logging.WARNING) logging.getLogger("aiortc.rtcrtpreceiver").setLevel(logging.WARNING) +# Global variable to track the app for signal handling +app_instance = None MAX_BITRATE = 2000000 MIN_BITRATE = 2000000 @@ -72,7 +75,7 @@ def __init__(self, track: MediaStreamTrack, pipeline: Pipeline): # Add cleanup when track ends @track.on("ended") async def on_ended(): - logger.info("Source video track ended, stopping collection") + logger.info("[App] Source video track ended, stopping collection") await cancel_collect_frames(self) async def collect_frames(self): @@ -85,22 +88,22 @@ async def collect_frames(self): frame = await self.track.recv() await self.pipeline.put_video_frame(frame) except asyncio.CancelledError: - logger.info("Frame collection cancelled") + logger.info("[App] Frame collection cancelled") break except Exception as e: if "MediaStreamError" in str(type(e)): - logger.info("Media stream ended") + logger.info("[App] Media stream ended") else: - logger.error(f"Error collecting video frames: {str(e)}") + logger.error(f"[App] Error collecting video frames: {str(e)}") self.running = False break # Perform cleanup outside the exception handler - logger.info("Video frame collection stopped") + logger.info("[App] Video frame collection stopped") except asyncio.CancelledError: - logger.info("Frame collection task cancelled") + logger.info("[App] Frame collection task cancelled") except Exception as e: - logger.error(f"Unexpected error in frame collection: {str(e)}") + logger.error(f"[App] Unexpected error in frame collection: {str(e)}") finally: await self.pipeline.cleanup() @@ -110,14 +113,15 @@ async def recv(self): """ processed_frame = await self.pipeline.get_processed_video_frame() - # Update 
the frame buffer with the processed frame + # Update the frame buffer with the processed frame try: from frame_buffer import FrameBuffer frame_buffer = FrameBuffer.get_instance() frame_buffer.update_frame(processed_frame) except Exception as e: # Don't let frame buffer errors affect the main pipeline - print(f"Error updating frame buffer: {e}") + print(f"[App] Error updating frame buffer: {e}") + # Increment the frame count to calculate FPS. await self.fps_meter.increment_frame_count() @@ -151,22 +155,22 @@ async def collect_frames(self): frame = await self.track.recv() await self.pipeline.put_audio_frame(frame) except asyncio.CancelledError: - logger.info("Audio frame collection cancelled") + logger.info("[App] Audio frame collection cancelled") break except Exception as e: if "MediaStreamError" in str(type(e)): - logger.info("Media stream ended") + logger.info("[App] Media stream ended") else: - logger.error(f"Error collecting audio frames: {str(e)}") + logger.error(f"[App] Error collecting audio frames: {str(e)}") self.running = False break # Perform cleanup outside the exception handler - logger.info("Audio frame collection stopped") + logger.info("[App] Audio frame collection stopped") except asyncio.CancelledError: - logger.info("Frame collection task cancelled") + logger.info("[App] Frame collection task cancelled") except Exception as e: - logger.error(f"Unexpected error in audio frame collection: {str(e)}") + logger.error(f"[App] Unexpected error in audio frame collection: {str(e)}") finally: await self.pipeline.cleanup() @@ -301,16 +305,16 @@ async def on_message(message): channel.send(json.dumps(response)) else: logger.warning( - "[Server] Invalid message format - missing required fields" + "[App] Invalid message format - missing required fields" ) except json.JSONDecodeError: - logger.error("[Server] Invalid JSON received") + logger.error("[App] Invalid JSON received") except Exception as e: - logger.error(f"[Server] Error processing message: {str(e)}") 
+ logger.error(f"[App] Error processing message: {str(e)}") @pc.on("track") def on_track(track): - logger.info(f"Track received: {track.kind}") + logger.info(f"[App] Track received: {track.kind}") if track.kind == "video": videoTrack = VideoStreamTrack(track, pipeline) tracks["video"] = videoTrack @@ -329,12 +333,12 @@ def on_track(track): @track.on("ended") async def on_ended(): - logger.info(f"{track.kind} track ended") + logger.info(f"[App] {track.kind} track ended") request.app["video_tracks"].pop(track.id, None) @pc.on("connectionstatechange") async def on_connectionstatechange(): - logger.info(f"Connection state is: {pc.connectionState}") + logger.info(f"[App] Connection state is: {pc.connectionState}") if pc.connectionState == "failed": await pc.close() pcs.discard(pc) @@ -376,7 +380,7 @@ async def set_prompt(request): await pipeline.set_prompts(prompt) return web.Response(content_type="application/json", text="OK") - + def health(_): return web.Response(content_type="application/json", text="OK") @@ -389,6 +393,7 @@ async def on_startup(app: web.Application): app["pipeline"] = Pipeline( width=512, height=512, + max_workers=app["workers"], cwd=app["workspace"], disable_cuda_malloc=True, gpu_only=True, @@ -400,10 +405,85 @@ async def on_startup(app: web.Application): async def on_shutdown(app: web.Application): + logger.info("Starting server shutdown...") + + # Clean up pipeline first - this should terminate worker processes + if "pipeline" in app: + try: + await app["pipeline"].cleanup() + logger.info("Pipeline cleanup completed") + except Exception as e: + logger.error(f"Error during pipeline cleanup: {e}") + + # Clean up peer connections pcs = app["pcs"] - coros = [pc.close() for pc in pcs] - await asyncio.gather(*coros) - pcs.clear() + if pcs: + logger.info(f"Closing {len(pcs)} peer connections...") + coros = [pc.close() for pc in pcs] + await asyncio.gather(*coros, return_exceptions=True) + pcs.clear() + logger.info("Peer connections closed") + + 
logger.info("Server shutdown completed") + + +def signal_handler(signum, frame): + """Handle SIGINT (Ctrl+C) gracefully""" + logger.info(f"Received signal {signum}, initiating graceful shutdown...") + + if app_instance: + # Get the current event loop + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + # Schedule the shutdown coroutine + asyncio.create_task(shutdown_server()) + else: + # If no loop is running, run the shutdown directly + asyncio.run(shutdown_server()) + except RuntimeError: + # If we can't get the loop, force cleanup and exit + logger.warning("Could not get event loop, forcing cleanup and exit") + force_cleanup_and_exit() + else: + logger.warning("No app instance found, forcing exit") + os._exit(1) + +def force_cleanup_and_exit(): + """Force cleanup of processes and exit""" + try: + # Try to cleanup the pipeline synchronously + if app_instance and "pipeline" in app_instance: + pipeline = app_instance["pipeline"] + if hasattr(pipeline, 'client') and hasattr(pipeline.client, 'comfy_client'): + # Force shutdown of the executor + if hasattr(pipeline.client.comfy_client, 'executor'): + executor = pipeline.client.comfy_client.executor + if hasattr(executor, 'shutdown'): + logger.info("Force shutting down executor...") + executor.shutdown(wait=False) + if hasattr(executor, '_processes'): + # Terminate all worker processes + for process in executor._processes: + if process.is_alive(): + logger.info(f"Terminating worker process {process.pid}") + process.terminate() + except Exception as e: + logger.error(f"Error in force cleanup: {e}") + finally: + os._exit(1) + +async def shutdown_server(): + """Perform graceful server shutdown""" + try: + if app_instance: + await on_shutdown(app_instance) + logger.info("Graceful shutdown completed") + except Exception as e: + logger.error(f"Error during graceful shutdown: {e}") + finally: + # Force exit after cleanup attempt + os._exit(0) if __name__ == "__main__": @@ -446,8 +526,25 @@ async def 
on_shutdown(app: web.Application): choices=logging._nameToLevel.keys(), help="Set the logging level for ComfyUI inference", ) + parser.add_argument( + "--frame-log-file", + type=str, + default=None, + help="Filename for frame timing log (optional)" + ) + parser.add_argument( + "--workers", + type=int, + default=1, + help="Number of workers to run", + ) args = parser.parse_args() + # Set up signal handlers + signal.signal(signal.SIGINT, signal_handler) + if hasattr(signal, 'SIGTERM'): + signal.signal(signal.SIGTERM, signal_handler) + logging.basicConfig( level=args.log_level.upper(), format="%(asctime)s [%(levelname)s] %(message)s", @@ -455,9 +552,10 @@ async def on_shutdown(app: web.Application): ) app = web.Application() + app_instance = app # Store reference for signal handler app["media_ports"] = args.media_ports.split(",") if args.media_ports else None app["workspace"] = args.workspace - + # Setup CORS cors = setup_cors(app, defaults={ "*": ResourceOptions( @@ -467,6 +565,7 @@ async def on_shutdown(app: web.Application): allow_methods=["GET", "POST", "OPTIONS"] ) }) + app["workers"] = args.workers app.on_startup.append(on_startup) app.on_shutdown.append(on_shutdown) @@ -477,7 +576,7 @@ async def on_shutdown(app: web.Application): # WebRTC signalling and control routes. 
app.router.add_post("/offer", offer) app.router.add_post("/prompt", set_prompt) - + # Setup HTTP streaming routes setup_routes(app, cors) @@ -516,6 +615,17 @@ def force_print(*args, **kwargs): log_level = logging._nameToLevel.get(args.comfyui_log_level.upper()) logging.getLogger("comfy").setLevel(log_level) if args.comfyui_inference_log_level: - app["comfui_inference_log_level"] = args.comfyui_inference_log_level - - web.run_app(app, host=args.host, port=int(args.port), print=force_print) + log_level = logging._nameToLevel.get(args.comfyui_inference_log_level.upper()) + app["comfyui_inference_log_level"] = log_level + + try: + web.run_app(app, host=args.host, port=int(args.port), print=force_print) + except KeyboardInterrupt: + logger.info("Received KeyboardInterrupt, shutting down...") + finally: + # Ensure cleanup happens even if web.run_app doesn't call on_shutdown + if app_instance: + try: + asyncio.run(on_shutdown(app_instance)) + except Exception as e: + logger.error(f"Error in final cleanup: {e}") diff --git a/src/comfystream/client.py b/src/comfystream/client.py index ca0c8751..f142f6a9 100644 --- a/src/comfystream/client.py +++ b/src/comfystream/client.py @@ -1,126 +1,336 @@ import asyncio -from typing import List import logging +from typing import List, Union +import multiprocessing as mp +import os +import sys +import numpy as np +import torch +import av -from comfystream import tensor_cache from comfystream.utils import convert_prompt +from comfystream.tensor_cache import init_tensor_cache -from comfy.api.components.schema.prompt import PromptDictInput from comfy.cli_args_types import Configuration +from comfy.distributed.process_pool_executor import ProcessPoolExecutor +from comfy.api.components.schema.prompt import PromptDictInput from comfy.client.embedded_comfy_client import EmbeddedComfyClient +from comfystream.frame_proxy import FrameProxy logger = logging.getLogger(__name__) +def _test_worker_init(): + """Test function to verify worker process 
initialization.""" + return os.getpid() class ComfyStreamClient: - def __init__(self, max_workers: int = 1, **kwargs): - config = Configuration(**kwargs) - self.comfy_client = EmbeddedComfyClient(config, max_workers=max_workers) - self.running_prompts = {} # To be used for cancelling tasks + def __init__(self, + max_workers: int = 1, + **kwargs): + logger.info(f"[ComfyStreamClient] Main Process ID: {os.getpid()}") + logger.info(f"[ComfyStreamClient] __init__ start, max_workers: {max_workers}") + + # Store default dimensions + self.width = kwargs.get('width', 512) + self.height = kwargs.get('height', 512) + + # Ensure workspace path is absolute + if 'cwd' in kwargs: + if not os.path.isabs(kwargs['cwd']): + # Convert relative path to absolute path from current working directory + kwargs['cwd'] = os.path.abspath(kwargs['cwd']) + logger.info(f"[ComfyStreamClient] Using absolute workspace path: {kwargs['cwd']}") + + # Register TensorRT paths in main process BEFORE creating ComfyUI client + self.register_tensorrt_paths_main_process(kwargs.get('cwd')) + + # Cache nodes information in main process to avoid ProcessPoolExecutor conflicts + self._initialize_nodes_cache() + + logger.info("[ComfyStreamClient] Config kwargs: %s", kwargs) + + try: + self.config = Configuration(**kwargs) + logger.info("[ComfyStreamClient] Configuration created") + logger.info(f"[ComfyStreamClient] Current working directory: {os.getcwd()}") + + logger.info("[ComfyStreamClient] Initializing process executor") + ctx = mp.get_context("spawn") + logger.info(f"[ComfyStreamClient] Using multiprocessing context: {ctx.get_start_method()}") + + manager = ctx.Manager() + logger.info("[ComfyStreamClient] Created multiprocessing context and manager") + + self.image_inputs = manager.Queue(maxsize=50) + self.image_outputs = manager.Queue(maxsize=50) + self.audio_inputs = manager.Queue(maxsize=50) + self.audio_outputs = manager.Queue(maxsize=50) + logger.info("[ComfyStreamClient] Created manager queues") + + 
logger.info("[ComfyStreamClient] About to create ProcessPoolExecutor...") + + executor = ProcessPoolExecutor( + max_workers=max_workers, + initializer=init_tensor_cache, + initargs=(self.image_inputs, self.image_outputs, self.audio_inputs, self.audio_outputs, kwargs.get('cwd')) + ) + logger.info("[ComfyStreamClient] ProcessPoolExecutor created successfully") + + # Create EmbeddedComfyClient with the executor + logger.info("[ComfyStreamClient] Creating EmbeddedComfyClient with executor") + self.comfy_client = EmbeddedComfyClient(self.config, executor=executor) + logger.info("[ComfyStreamClient] EmbeddedComfyClient created successfully") + + # Submit a test task to ensure worker processes are initialized + logger.info("[ComfyStreamClient] Testing worker process initialization...") + test_future = executor.submit(_test_worker_init) + try: + worker_pid = test_future.result(timeout=30) # 30 second timeout + logger.info(f"[ComfyStreamClient] Worker process initialized successfully (PID: {worker_pid})") + except Exception as e: + logger.info(f"[ComfyStreamClient] Error initializing worker process: {str(e)}") + raise + + except Exception as e: + logger.info(f"[ComfyStreamClient] Error during initialization: {str(e)}") + logger.info(f"[ComfyStreamClient] Error type: {type(e)}") + import traceback + logger.info(f"[ComfyStreamClient] Error traceback: {traceback.format_exc()}") + raise + + self.running_prompts = {} self.current_prompts = [] - self._cleanup_lock = asyncio.Lock() - self._prompt_update_lock = asyncio.Lock() + self.cleanup_lock = asyncio.Lock() + self.max_workers = max_workers + self.worker_tasks = [] + self.next_worker = 0 + self.distribution_lock = asyncio.Lock() + self.shutting_down = False # Add shutdown flag + self.distribution_task = None # Track distribution task + logger.info("[ComfyStreamClient] Initialized successfully") async def set_prompts(self, prompts: List[PromptDictInput]): - await self.cancel_running_prompts() + logger.info("set_prompts start") 
self.current_prompts = [convert_prompt(prompt) for prompt in prompts] - for idx in range(len(self.current_prompts)): - task = asyncio.create_task(self.run_prompt(idx)) - self.running_prompts[idx] = task + + # Start the distribution manager only if not already running + if self.distribution_task is None or self.distribution_task.done(): + self.shutting_down = False # Reset shutdown flag + self.distribution_task = asyncio.create_task(self.distribute_frames()) + self.running_prompts[-1] = self.distribution_task # Use -1 as a special key for the manager + logger.info("set_prompts end") + + async def distribute_frames(self): + """Manager that distributes frames across workers in round-robin fashion""" + logger.info(f"[ComfyStreamClient] Starting frame distribution manager") + + try: + # Initialize worker tasks + self.worker_tasks = [] + for worker_id in range(self.max_workers): + worker_task = asyncio.create_task(self.worker_loop(worker_id)) + self.worker_tasks.append(worker_task) + self.running_prompts[worker_id] = worker_task + + # Keep the manager running to monitor workers + while not self.shutting_down: + await asyncio.sleep(1.0) # Check periodically + + # Only restart crashed workers if we're not shutting down + if not self.shutting_down: + for worker_id, task in enumerate(self.worker_tasks): + if task.done(): + # Check if the task was cancelled (graceful shutdown) or crashed + if task.cancelled(): + logger.info(f"Worker {worker_id} was cancelled (graceful shutdown)") + else: + # Check if there was an exception + try: + task.result() + logger.info(f"Worker {worker_id} completed normally") + except Exception as e: + logger.warning(f"Worker {worker_id} crashed with error: {e}, restarting") + new_task = asyncio.create_task(self.worker_loop(worker_id)) + self.worker_tasks[worker_id] = new_task + self.running_prompts[worker_id] = new_task + + except asyncio.CancelledError: + logger.info("[ComfyStreamClient] Distribution manager cancelled") + except Exception as e: + 
logger.error(f"[ComfyStreamClient] Error in distribution manager: {e}") + finally: + logger.info("[ComfyStreamClient] Distribution manager stopped") - async def update_prompts(self, prompts: List[PromptDictInput]): - async with self._prompt_update_lock: - # TODO: currently under the assumption that only already running prompts are updated - if len(prompts) != len(self.current_prompts): - raise ValueError( - "Number of updated prompts must match the number of currently running prompts." - ) - # Validation step before updating the prompt, only meant for a single prompt for now - for idx, prompt in enumerate(prompts): - converted_prompt = convert_prompt(prompt) + async def worker_loop(self, worker_id: int): + """Simple worker loop - just process frames continuously""" + logger.info(f"[Worker {worker_id}] Started") + + frame_count = 0 + try: + while not self.shutting_down: try: - await self.comfy_client.queue_prompt(converted_prompt) - self.current_prompts[idx] = converted_prompt + # Simple round-robin prompt selection + prompt_index = worker_id % len(self.current_prompts) + current_prompt = self.current_prompts[prompt_index] + + # Just process the prompt + await self.comfy_client.queue_prompt(current_prompt) + frame_count += 1 + + except asyncio.CancelledError: + break except Exception as e: - raise Exception(f"Prompt update failed: {str(e)}") from e + if self.shutting_down: + break + logger.error(f"[Worker {worker_id}] Error: {e}") + await asyncio.sleep(0.1) + finally: + logger.info(f"[Worker {worker_id}] Processed {frame_count} frames") - async def run_prompt(self, prompt_index: int): - while True: - async with self._prompt_update_lock: + async def cleanup(self): + async with self.cleanup_lock: + logger.info("[ComfyStreamClient] Starting cleanup...") + + # Set shutdown flag to stop workers gracefully + self.shutting_down = True + + # Cancel distribution task first + if self.distribution_task and not self.distribution_task.done(): + self.distribution_task.cancel() 
try: - await self.comfy_client.queue_prompt(self.current_prompts[prompt_index]) + await self.distribution_task except asyncio.CancelledError: - raise + pass + + # Cancel all worker tasks + for task in self.worker_tasks: + if not task.done(): + task.cancel() + + # Wait for tasks to complete cancellation + if self.worker_tasks: + try: + await asyncio.gather(*self.worker_tasks, return_exceptions=True) + logger.info("[ComfyStreamClient] All worker tasks stopped") except Exception as e: - await self.cleanup() - logger.error(f"Error running prompt: {str(e)}") - raise - - async def cleanup(self): - await self.cancel_running_prompts() - async with self._cleanup_lock: - if self.comfy_client.is_running: + logger.error(f"Error waiting for worker tasks: {e}") + + # Clear the tasks list + self.worker_tasks.clear() + self.running_prompts.clear() + + # Cleanup the ComfyUI client and its executor + if hasattr(self, 'comfy_client') and self.comfy_client.is_running: try: - await self.comfy_client.__aexit__() + # Get the executor before closing the client + executor = getattr(self.comfy_client, 'executor', None) + + # Close the client first + await self.comfy_client.__aexit__(None, None, None) + logger.info("[ComfyStreamClient] ComfyUI client closed") + + # Then shutdown the executor and terminate processes + if executor: + logger.info("[ComfyStreamClient] Shutting down executor...") + # Shutdown the executor + executor.shutdown(wait=False) + + # Force terminate any remaining processes + if hasattr(executor, '_processes'): + for process in executor._processes: + if process.is_alive(): + logger.info(f"[ComfyStreamClient] Terminating worker process {process.pid}") + process.terminate() + # Give it a moment to terminate gracefully + try: + process.join(timeout=2.0) + except: + pass + # Force kill if still alive + if process.is_alive(): + logger.warning(f"[ComfyStreamClient] Force killing worker process {process.pid}") + process.kill() + + logger.info("[ComfyStreamClient] Executor 
shutdown completed") + except Exception as e: logger.error(f"Error during ComfyClient cleanup: {e}") await self.cleanup_queues() - logger.info("Client cleanup complete") - - async def cancel_running_prompts(self): - async with self._cleanup_lock: - tasks_to_cancel = list(self.running_prompts.values()) - for task in tasks_to_cancel: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - self.running_prompts.clear() + + # Reset state for potential reuse + self.shutting_down = False + self.distribution_task = None + + logger.info("[ComfyStreamClient] Client cleanup complete") - async def cleanup_queues(self): - while not tensor_cache.image_inputs.empty(): - tensor_cache.image_inputs.get() - - while not tensor_cache.audio_inputs.empty(): - tensor_cache.audio_inputs.get() + # TODO: add for audio as well + while not self.image_inputs.empty(): + self.image_inputs.get() - while not tensor_cache.image_outputs.empty(): - await tensor_cache.image_outputs.get() - - while not tensor_cache.audio_outputs.empty(): - await tensor_cache.audio_outputs.get() + while not self.image_outputs.empty(): + self.image_outputs.get() def put_video_input(self, frame): - if tensor_cache.image_inputs.full(): - tensor_cache.image_inputs.get(block=True) - tensor_cache.image_inputs.put(frame) - + try: + # Check if frame is FrameProxy + if isinstance(frame, FrameProxy): + proxy = frame + else: + proxy = FrameProxy.avframe_to_frameproxy(frame) + + # Handle queue being full + if self.image_inputs.full(): + # logger.warning(f"[ComfyStreamClient] Input queue full, dropping oldest frame") + try: + self.image_inputs.get_nowait() + except Exception: + pass + + self.image_inputs.put_nowait(proxy) + # logger.info(f"[ComfyStreamClient] Video input queued. 
Queue size: {self.image_inputs.qsize()}") + except Exception as e: + logger.error(f"[ComfyStreamClient] Error putting video frame: {str(e)}") + def put_audio_input(self, frame): - tensor_cache.audio_inputs.put(frame) + self.audio_inputs.put(frame) async def get_video_output(self): - return await tensor_cache.image_outputs.get() + try: + # logger.info(f"[ComfyStreamClient] get_video_output called - PID: {os.getpid()}") + tensor = await asyncio.wait_for( + asyncio.get_event_loop().run_in_executor(None, self.image_outputs.get), + timeout=5.0 + ) + # logger.info(f"[ComfyStreamClient] get_video_output returning tensor: {tensor.shape} - PID: {os.getpid()}") + return tensor + except asyncio.TimeoutError: + logger.warning(f"[ComfyStreamClient] get_video_output timeout - PID: {os.getpid()}") + return torch.zeros((1, 3, self.height, self.width), dtype=torch.float32) + except Exception as e: + logger.error(f"[ComfyStreamClient] Error getting video output: {str(e)} - PID: {os.getpid()}") + return torch.zeros((1, 3, self.height, self.width), dtype=torch.float32) async def get_audio_output(self): - return await tensor_cache.audio_outputs.get() - + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, self.audio_outputs.get) + async def get_available_nodes(self): - """Get metadata and available nodes info in a single pass""" - # TODO: make it for for multiple prompts - if not self.running_prompts: + """Get metadata and available nodes info using cached nodes to avoid ProcessPoolExecutor conflicts""" + if not self.current_prompts: return {} - try: - from comfy.nodes.package import import_all_nodes_in_workspace - nodes = import_all_nodes_in_workspace() + # Use cached nodes instead of calling import_all_nodes_in_workspace from worker process + if self._nodes is None: + logger.warning("[ComfyStreamClient] Nodes cache not available, returning empty result") + return {} + try: all_prompts_nodes_info = {} for prompt_index, prompt in enumerate(self.current_prompts): - 
# Get set of class types we need metadata for, excluding LoadTensor and SaveTensor + # Get set of class types we need metadata for needed_class_types = { node.get('class_type') for node in prompt.values() @@ -132,7 +342,7 @@ async def get_available_nodes(self): nodes_info = {} # Only process nodes until we've found all the ones we need - for class_type, node_class in nodes.NODE_CLASS_MAPPINGS.items(): + for class_type, node_class in self._nodes.NODE_CLASS_MAPPINGS.items(): if not remaining_nodes: # Exit early if we've found all needed nodes break @@ -232,3 +442,61 @@ async def get_available_nodes(self): except Exception as e: logger.error(f"Error getting node info: {str(e)}") return {} + + def register_tensorrt_paths_main_process(self, workspace_path): + """Register TensorRT paths in the main process for validation""" + try: + from comfy.cmd import folder_paths + + if workspace_path: + base_dir = workspace_path + tensorrt_models_dir = os.path.join(base_dir, "models", "tensorrt") + tensorrt_outputs_dir = os.path.join(base_dir, "outputs", "tensorrt") + else: + tensorrt_models_dir = os.path.join(folder_paths.models_dir, "tensorrt") + tensorrt_outputs_dir = os.path.join(folder_paths.models_dir, "outputs", "tensorrt") + + # logger.info(f"[ComfyStreamClient] Registering TensorRT paths in main process") + # logger.info(f"[ComfyStreamClient] TensorRT models dir: {tensorrt_models_dir}") + # logger.info(f"[ComfyStreamClient] TensorRT outputs dir: {tensorrt_outputs_dir}") + + # Register TensorRT paths + if "tensorrt" in folder_paths.folder_names_and_paths: + existing_paths = folder_paths.folder_names_and_paths["tensorrt"][0] + for path in [tensorrt_models_dir, tensorrt_outputs_dir]: + if path not in existing_paths: + existing_paths.append(path) + folder_paths.folder_names_and_paths["tensorrt"][1].add(".engine") + else: + folder_paths.folder_names_and_paths["tensorrt"] = ( + [tensorrt_models_dir, tensorrt_outputs_dir], + {".engine"} + ) + + # Verify registration + # 
available_files = folder_paths.get_filename_list("tensorrt") + # logger.info(f"[ComfyStreamClient] Main process TensorRT files: {available_files}") + + except Exception as e: + logger.error(f"[ComfyStreamClient] Error registering TensorRT paths in main process: {e}") + import traceback + logger.error(f"[ComfyStreamClient] Traceback: {traceback.format_exc()}") + + async def update_prompts(self, prompts: List[PromptDictInput]): + """Update the existing processing prompts without restarting workers.""" + + # Simply update the current prompts - worker loops will pick up changes on next iteration + self.current_prompts = [convert_prompt(prompt) for prompt in prompts] + + logger.info("[ComfyStreamClient] Prompts updated") + + def _initialize_nodes_cache(self): + """Initialize nodes cache in main process to avoid ProcessPoolExecutor conflicts""" + try: + logger.info("[ComfyStreamClient] Initializing nodes cache in main process...") + from comfy.nodes.package import import_all_nodes_in_workspace + self._nodes = import_all_nodes_in_workspace() + logger.info(f"[ComfyStreamClient] Cached {len(self._nodes.NODE_CLASS_MAPPINGS)} node types") + except Exception as e: + logger.error(f"[ComfyStreamClient] Error initializing nodes cache: {e}") + self._nodes = None \ No newline at end of file diff --git a/src/comfystream/frame_proxy.py b/src/comfystream/frame_proxy.py new file mode 100644 index 00000000..e985e50f --- /dev/null +++ b/src/comfystream/frame_proxy.py @@ -0,0 +1,27 @@ +import torch +import numpy as np + +class SideData: + pass + +class FrameProxy: + def __init__(self, tensor, width, height, pts=None, time_base=None): + self.width = width + self.height = height + self.pts = pts + self.time_base = time_base + self.side_data = SideData() + self.side_data.input = tensor.clone().cpu() + self.side_data.skipped = True + + @staticmethod + def avframe_to_frameproxy(frame): + frame_np = frame.to_ndarray(format="rgb24").astype(np.float32) / 255.0 + tensor = 
torch.from_numpy(frame_np).unsqueeze(0) + return FrameProxy( + tensor=tensor.clone().cpu(), + width=frame.width, + height=frame.height, + pts=getattr(frame, "pts", None), + time_base=getattr(frame, "time_base", None) + ) \ No newline at end of file diff --git a/src/comfystream/pipeline.py b/src/comfystream/pipeline.py index a5776dfc..3b39229f 100644 --- a/src/comfystream/pipeline.py +++ b/src/comfystream/pipeline.py @@ -3,10 +3,14 @@ import numpy as np import asyncio import logging +import time +import os +from collections import OrderedDict from typing import Any, Dict, Union, List, Optional from comfystream.client import ComfyStreamClient from comfystream.server.utils import temporary_log_level +from comfystream.frame_proxy import FrameProxy WARMUP_RUNS = 5 @@ -21,39 +25,100 @@ class Pipeline: postprocessing, and queue management. """ - def __init__(self, width: int = 512, height: int = 512, - comfyui_inference_log_level: Optional[int] = None, **kwargs): + def __init__(self, + width: int = 512, + height: int = 512, + max_workers: int = 1, + comfyui_inference_log_level: Optional[int] = None, + **kwargs): """Initialize the pipeline with the given configuration. Args: width: Width of the video frames (default: 512) height: Height of the video frames (default: 512) + max_workers: Number of worker processes (default: 1) comfyui_inference_log_level: The logging level for ComfyUI inference. - Defaults to None, using the global ComfyUI log level. - **kwargs: Additional arguments to pass to the ComfyStreamClient + **kwargs: Additional arguments to pass to the ComfyStreamClient (cwd, disable_cuda_malloc, etc.) 
""" - self.client = ComfyStreamClient(**kwargs) + self.client = ComfyStreamClient( + max_workers=max_workers, + **kwargs) self.width = width self.height = height self.video_incoming_frames = asyncio.Queue() self.audio_incoming_frames = asyncio.Queue() + self.output_buffer = asyncio.Queue(maxsize=6) # Small buffer to smooth output + self.input_frame_counter = 0 + + # Simple frame collection without ordering + self.collector_task = asyncio.create_task(self._collect_frames_simple()) + self.processed_audio_buffer = np.array([], dtype=np.int16) self._comfyui_inference_log_level = comfyui_inference_log_level - async def warm_video(self): - """Warm up the video processing pipeline with dummy frames.""" - # Create dummy frame with the CURRENT resolution settings - dummy_frame = av.VideoFrame() - dummy_frame.side_data.input = torch.randn(1, self.height, self.width, 3) - - logger.info(f"Warming video pipeline with resolution {self.width}x{self.height}") + self.next_expected_frame_id = 0 - for _ in range(WARMUP_RUNS): - self.client.put_video_input(dummy_frame) - await self.client.get_video_output() + # Add a queue for frame log entries + self.running = True + + async def _collect_frames_simple(self): + """Simple frame collector - no ordering, just buffer""" + try: + while self.running: + try: + tensor = await asyncio.wait_for(self.client.get_video_output(), timeout=0.1) + if tensor is not None: + # Just put it in the buffer, don't worry about order + try: + self.output_buffer.put_nowait(tensor) + except asyncio.QueueFull: + # Drop oldest frame if buffer is full + try: + self.output_buffer.get_nowait() + self.output_buffer.put_nowait(tensor) + except asyncio.QueueEmpty: + pass + except asyncio.TimeoutError: + pass + except Exception as e: + logger.error(f"Error collecting frame: {e}") + + await asyncio.sleep(0.001) # Minimal sleep + + except asyncio.CancelledError: + pass + + async def initialize(self, prompts): + await self.set_prompts(prompts) + await self.warm_video() + + 
async def warm_video(self): + logger.info("[PipelineMulti] Starting warmup...") + for i in range(WARMUP_RUNS): + dummy_tensor = torch.randn(1, self.height, self.width, 3) + dummy_proxy = FrameProxy( + tensor=dummy_tensor, + width=self.width, + height=self.height, + pts=None, + time_base=None + ) + # Set frame_id for warmup frames (negative to distinguish from real frames) + dummy_proxy.side_data.frame_id = -(i + 1) + logger.debug(f"[PipelineMulti] Warmup: putting dummy frame {i+1}/{WARMUP_RUNS}") + self.client.put_video_input(dummy_proxy) + + # For warmup, we don't need to wait for ordered output + try: + out = await asyncio.wait_for(self.client.get_video_output(), timeout=30.0) + logger.debug(f"[PipelineMulti] Warmup: got output for dummy frame {i+1}/{WARMUP_RUNS}") + except asyncio.TimeoutError: + logger.warning(f"[PipelineMulti] Warmup frame {i+1} timed out") + + logger.info("[PipelineMulti] Warmup complete.") async def warm_audio(self): """Warm up the audio processing pipeline with dummy frames.""" @@ -86,6 +151,8 @@ async def update_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any await self.client.update_prompts(prompts) else: await self.client.update_prompts([prompts]) + + logger.info("[PipelineMulti] Prompts updated") async def put_video_frame(self, frame: av.VideoFrame): """Queue a video frame for processing. 
@@ -93,8 +160,18 @@ async def put_video_frame(self, frame: av.VideoFrame): Args: frame: The video frame to process """ + current_time = time.time() frame.side_data.input = self.video_preprocess(frame) frame.side_data.skipped = True + frame.side_data.frame_received_time = current_time + + # Assign frame ID and increment counter + frame_id = self.input_frame_counter + frame.side_data.frame_id = frame_id + frame.side_data.client_index = -1 + self.next_expected_frame_id += 1 + self.input_frame_counter += 1 + self.client.put_video_input(frame) await self.video_incoming_frames.put(frame) @@ -132,6 +209,7 @@ def audio_preprocess(self, frame: av.AudioFrame) -> Union[torch.Tensor, np.ndarr """ return frame.to_ndarray().ravel().reshape(-1, 2).mean(axis=1).astype(np.int16) + def video_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.VideoFrame: """Postprocess a video frame after processing. @@ -141,8 +219,27 @@ def video_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.Video Returns: The postprocessed video frame """ + + # First ensure we have a tensor + if isinstance(output, np.ndarray): + output = torch.from_numpy(output) + + # Handle different tensor formats + if len(output.shape) == 4: # BCHW or BHWC format + if output.shape[1] != 3: # If BHWC format + output = output.permute(0, 3, 1, 2) # Convert BHWC to BCHW + output = output[0] # Take first image from batch -> CHW + elif len(output.shape) != 3: # Should be CHW at this point + raise ValueError(f"Unexpected tensor shape after batch removal: {output.shape}") + + # Convert CHW to HWC for video frame + output = output.permute(1, 2, 0) # CHW -> HWC + + # Convert to numpy and create video frame return av.VideoFrame.from_ndarray( - (output * 255.0).clamp(0, 255).to(dtype=torch.uint8).squeeze(0).cpu().numpy() + (output * 255.0).clamp(0, 255).to(dtype=torch.uint8).squeeze(0).cpu().numpy(), + # (output * 255.0).clamp(0, 255).to(dtype=torch.uint8).cpu().numpy(), + format='rgb24' ) def 
audio_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.AudioFrame: @@ -163,16 +260,22 @@ async def get_processed_video_frame(self) -> av.VideoFrame: Returns: The processed video frame """ - async with temporary_log_level("comfy", self._comfyui_inference_log_level): - out_tensor = await self.client.get_video_output() + # Get input frame frame = await self.video_incoming_frames.get() - while frame.side_data.skipped: - frame = await self.video_incoming_frames.get() - + + # Get any available output (don't match input to output) + try: + out_tensor = await asyncio.wait_for(self.output_buffer.get(), timeout=1.0) + except asyncio.TimeoutError: + # Fallback: return input frame to maintain stream + logger.warning("Output timeout, using input frame") + return frame + + # Process and return processed_frame = self.video_postprocess(out_tensor) processed_frame.pts = frame.pts processed_frame.time_base = frame.time_base - + return processed_frame async def get_processed_audio_frame(self) -> av.AudioFrame: @@ -207,4 +310,20 @@ async def get_nodes_info(self) -> Dict[str, Any]: async def cleanup(self): """Clean up resources used by the pipeline.""" - await self.client.cleanup() \ No newline at end of file + logger.info("[PipelineMulti] Starting pipeline cleanup...") + + # Set running flag to false to stop frame processing + self.running = False + + # Cancel collector task + if hasattr(self, 'collector_task') and self.collector_task: + self.collector_task.cancel() + try: + await self.collector_task + except asyncio.CancelledError: + pass + + # Clean up the client (this will gracefully shutdown workers) + await self.client.cleanup() + + logger.info("[PipelineMulti] Pipeline cleanup complete") \ No newline at end of file diff --git a/src/comfystream/scripts/build_trt.py b/src/comfystream/scripts/build_trt.py index fa0a0f24..866d7de6 100644 --- a/src/comfystream/scripts/build_trt.py +++ b/src/comfystream/scripts/build_trt.py @@ -11,9 +11,9 @@ # $> python 
src/comfystream/scripts/build_trt.py --model /ComfyUI/models/checkpoints/SD1.5/dreamshaper-8.safetensors --out-engine /ComfyUI/output/tensorrt/static-dreamshaper8_SD15_$stat-b-1-h-512-w-512_00001_.engine # Paths path explicitly to use the downloaded comfyUI installation on root -ROOT_DIR="/workspace" -COMFYUI_DIR = "/workspace/ComfyUI" -timing_cache_path = "/workspace/ComfyUI/output/tensorrt/timing_cache" +ROOT_DIR = os.environ.get("ROOT_DIR", "/workspace") +COMFYUI_DIR = os.path.join(ROOT_DIR, "ComfyUI") +timing_cache_path = os.path.join(ROOT_DIR, "ComfyUI/output/tensorrt/timing_cache") if ROOT_DIR not in sys.path: sys.path.insert(0, ROOT_DIR) @@ -21,9 +21,9 @@ sys.path.insert(0, COMFYUI_DIR) comfy_dirs = [ - "/workspace/ComfyUI/", - "/workspace/ComfyUI/comfy", - "/workspace/ComfyUI/comfy_extras" + COMFYUI_DIR, + os.path.join(COMFYUI_DIR, "comfy"), + os.path.join(COMFYUI_DIR, "comfy_extras") ] for comfy_dir in comfy_dirs: diff --git a/src/comfystream/server/utils/config.py b/src/comfystream/server/utils/config.py new file mode 100644 index 00000000..7f066643 --- /dev/null +++ b/src/comfystream/server/utils/config.py @@ -0,0 +1,45 @@ +import tomli +import logging +from typing import List, Dict, Any, Optional + +logger = logging.getLogger(__name__) + +class ComfyConfig: + def __init__(self, config_path: Optional[str] = None): + self.servers = [] + self.config_path = config_path + if config_path: + self.load_config(config_path) + else: + # Default to single local server if no config provided + self.servers = [{"host": "127.0.0.1", "port": 8188}] + + def load_config(self, config_path: str): + """Load server configuration from TOML file""" + try: + with open(config_path, "rb") as f: + config = tomli.load(f) + + # Extract server configurations + if "servers" in config: + self.servers = config["servers"] + logger.info(f"Loaded {len(self.servers)} server configurations") + else: + logger.warning("No servers defined in config, using default") + self.servers = [{"host": 
"127.0.0.1", "port": 8198}] + + # Validate each server has required fields + for i, server in enumerate(self.servers): + if "host" not in server or "port" not in server: + logger.warning(f"Server {i} missing host or port, using defaults") + server["host"] = server.get("host", "127.0.0.1") + server["port"] = server.get("port", 8198) + + except Exception as e: + logger.error(f"Error loading config from {config_path}: {e}") + # Fall back to default server + self.servers = [{"host": "127.0.0.1", "port": 8198}] + + def get_servers(self) -> List[Dict[str, Any]]: + """Return list of server configurations""" + return self.servers \ No newline at end of file diff --git a/src/comfystream/tensor_cache.py b/src/comfystream/tensor_cache.py index 0216f73b..1cc3c73b 100644 --- a/src/comfystream/tensor_cache.py +++ b/src/comfystream/tensor_cache.py @@ -1,14 +1,127 @@ +from comfystream import tensor_cache +import logging +import queue import torch -import numpy as np +import asyncio +import os +logger = logging.getLogger(__name__) -from queue import Queue -from asyncio import Queue as AsyncQueue +image_inputs = None +image_outputs = None -from typing import Union +audio_inputs = None +audio_outputs = None -# TODO: improve eviction policy fifo might not be the best, skip alternate frames instead -image_inputs: Queue[Union[torch.Tensor, np.ndarray]] = Queue(maxsize=1) -image_outputs: AsyncQueue[Union[torch.Tensor, np.ndarray]] = AsyncQueue() +# Global frame ID tracking for worker processes +current_frame_id = None +frame_id_mapping = {} # Maps tensor id to frame_id -audio_inputs: Queue[Union[torch.Tensor, np.ndarray]] = Queue() -audio_outputs: AsyncQueue[Union[torch.Tensor, np.ndarray]] = AsyncQueue() +class FrameData: + """Wrapper class to carry frame metadata through the processing pipeline""" + def __init__(self, tensor, frame_id=None): + self.tensor = tensor + self.frame_id = frame_id + +# Create wrapper classes that match the interface of the original queues +class 
MultiProcessInputQueue: + def __init__(self, mp_queue): + self.queue = mp_queue + + def get(self, block=True, timeout=None): + result = self.queue.get(block=block, timeout=timeout) + + # Extract frame metadata and store it globally for this worker + global current_frame_id + if hasattr(result, 'side_data') and hasattr(result.side_data, 'frame_id'): + current_frame_id = result.side_data.frame_id + # logger.info(f"[MultiProcessInputQueue] Frame {current_frame_id} retrieved by worker PID: {os.getpid()}") + + return result + + def get_nowait(self): + result = self.queue.get_nowait() + + # Extract frame metadata and store it globally for this worker + global current_frame_id + if hasattr(result, 'side_data') and hasattr(result.side_data, 'frame_id'): + current_frame_id = result.side_data.frame_id + # logger.info(f"[MultiProcessInputQueue] Frame {current_frame_id} retrieved (nowait) by worker PID: {os.getpid()}") + + return result + + def put(self, item, block=True, timeout=None): + return self.queue.put(item, block=block, timeout=timeout) + + def put_nowait(self, item): + return self.queue.put_nowait(item) + + def empty(self): + return self.queue.empty() + + def full(self): + return self.queue.full() + +class MultiProcessOutputQueue: + def __init__(self, mp_queue): + self.queue = mp_queue + + async def get(self): + # Convert synchronous get to async + loop = asyncio.get_event_loop() + result = await loop.run_in_executor(None, self.queue.get) + + # Check if this is a tuple with frame_id + if isinstance(result, tuple) and len(result) == 2: + frame_id, tensor = result + return (frame_id, tensor) + else: + # Backward compatibility - return just the tensor + return result + + async def put(self, item): + # Convert synchronous put to async + loop = asyncio.get_event_loop() + # Ensure tensor is on CPU before sending + if torch.is_tensor(item): + item = item.cpu() + # logger.info(f"[MultiProcessOutputQueue] Frame sent from worker PID: {os.getpid()}") + return await 
loop.run_in_executor(None, self.queue.put, item) + + def put_nowait(self, item): + try: + # Ensure tensor is on CPU + if torch.is_tensor(item): + item = item.cpu() + self.queue.put_nowait(item) + except queue.Full: + # Simple: drop one old frame and try again + try: + self.queue.get_nowait() + self.queue.put_nowait(item) + except: + pass # If still fails, just drop this frame + +def init_tensor_cache(image_inputs, image_outputs, audio_inputs, audio_outputs, workspace_path=None): + """Initialize the tensor cache for a worker process. + + Args: + image_inputs: Multiprocessing Queue for input images + image_outputs: Multiprocessing Queue for output images + audio_inputs: Multiprocessing Queue for input audio + audio_outputs: Multiprocessing Queue for output audio + workspace_path: The ComfyUI workspace path (should be C:\sd\ComfyUI-main) + """ + logger.info(f"[init_tensor_cache] Setting up tensor_cache queues in worker - PID: {os.getpid()}") + logger.info(f"[init_tensor_cache] Workspace path: {workspace_path}") + logger.info(f"[init_tensor_cache] Current working directory: {os.getcwd()}") + + # Replace the queues with our wrapped versions that match the original interface + tensor_cache.image_inputs = MultiProcessInputQueue(image_inputs) + tensor_cache.image_outputs = MultiProcessOutputQueue(image_outputs) + tensor_cache.audio_inputs = MultiProcessInputQueue(audio_inputs) + tensor_cache.audio_outputs = MultiProcessOutputQueue(audio_outputs) + + logger.info(f"[init_tensor_cache] tensor_cache.image_outputs id: {id(tensor_cache.image_outputs)} - PID: {os.getpid()}") + logger.info(f"[init_tensor_cache] Initialization complete - PID: {os.getpid()}") + + return os.getpid() # Return PID for verification \ No newline at end of file