diff --git a/runner/app/live/infer.py b/runner/app/live/infer.py index dcb30a60e..7bbb77e59 100644 --- a/runner/app/live/infer.py +++ b/runner/app/live/infer.py @@ -57,20 +57,22 @@ async def main( ): loop = asyncio.get_event_loop() loop.set_exception_handler(asyncio_exception_handler) - process = ProcessGuardian(pipeline, params or {}) # Only initialize the streamer if we have a protocol and URLs to connect to streamer = None if stream_protocol and subscribe_url and publish_url: + width = params.get('width') + height = params.get('height') if stream_protocol == "trickle": protocol = TrickleProtocol( - subscribe_url, publish_url, control_url, events_url + subscribe_url, publish_url, control_url, events_url, + width=width, height=height ) elif stream_protocol == "zeromq": protocol = ZeroMQProtocol(subscribe_url, publish_url) else: raise ValueError(f"Unsupported protocol: {stream_protocol}") - streamer = PipelineStreamer(protocol, process, request_id, stream_id) + streamer = PipelineStreamer(protocol, process, request_id, stream_id, width=width, height=height) api = None try: diff --git a/runner/app/live/pipelines/comfyui.py b/runner/app/live/pipelines/comfyui.py index 2fe9655c3..9e5dd5d50 100644 --- a/runner/app/live/pipelines/comfyui.py +++ b/runner/app/live/pipelines/comfyui.py @@ -2,29 +2,35 @@ import json import torch import asyncio -from typing import Union +from typing import Union, Optional, Tuple from pydantic import BaseModel, field_validator import pathlib from .interface import Pipeline from comfystream.client import ComfyStreamClient from trickle import VideoFrame, VideoOutput +from utils import ComfyUtils import logging COMFY_UI_WORKSPACE_ENV = "COMFY_UI_WORKSPACE" WARMUP_RUNS = 1 -_default_workflow_path = pathlib.Path(__file__).parent.absolute() / "comfyui_default_workflow.json" -with open(_default_workflow_path, 'r') as f: - DEFAULT_WORKFLOW_JSON = json.load(f) +def get_default_workflow_json(): + _default_workflow_path = pathlib.Path(__file__).parent.absolute() / "comfyui_default_workflow.json" + with open(_default_workflow_path, 'r') as f: + return json.load(f) +# Get the default workflow json during startup +DEFAULT_WORKFLOW_JSON = get_default_workflow_json() class ComfyUIParams(BaseModel): class Config: extra = "forbid" prompt: Union[str, dict] = DEFAULT_WORKFLOW_JSON + width: Optional[int] = None + height: Optional[int] = None @field_validator('prompt') @classmethod @@ -53,6 +59,9 @@ def __init__(self): self.client = ComfyStreamClient(cwd=comfy_ui_workspace) self.params: ComfyUIParams self.video_incoming_frames: asyncio.Queue[VideoOutput] = asyncio.Queue() + self.width = ComfyUtils.DEFAULT_WIDTH + self.height = ComfyUtils.DEFAULT_HEIGHT + self.pause_input = False async def initialize(self, **params): new_params = ComfyUIParams(**params) @@ -60,10 +69,19 @@ async def initialize(self, **params): # TODO: currently its a single prompt, but need to support multiple prompts await self.client.set_prompts([new_params.prompt]) self.params = new_params - - # Warm up the pipeline + + # Get dimensions from params or environment variable + width = new_params.width + height = new_params.height + + # Fallback to default dimensions if not found + width = width or ComfyUtils.DEFAULT_WIDTH + height = height or ComfyUtils.DEFAULT_HEIGHT + + # Warm up the pipeline with the workflow dimensions + logging.info(f"Warming up pipeline with dimensions: {width}x{height}") dummy_frame = VideoFrame(None, 0, 0) - dummy_frame.side_data.input = torch.randn(1, 512, 512, 3) + dummy_frame.side_data.input = torch.randn(1, height, width, 3) for _ in range(WARMUP_RUNS): self.client.put_video_input(dummy_frame) @@ -71,6 +89,8 @@ async def initialize(self, **params): logging.info("Pipeline initialization and warmup complete") async def put_video_frame(self, frame: VideoFrame, request_id: str): + if self.pause_input: + return tensor = frame.tensor if tensor.is_cuda: # Clone the tensor to be able to send it on comfystream internal queue @@ -99,6 +119,28 @@ async def update_params(self, **params): self.params = new_params async def stop(self): - logging.info("Stopping ComfyUI pipeline") - await self.client.cleanup() - logging.info("ComfyUI pipeline stopped") + try: + self.pause_input = True + logging.info("Stopping ComfyUI pipeline") + await self.client.cleanup(unload_models=False) + # Wait for the pipeline to stop + await asyncio.sleep(1) + # Clear the video_incoming_frames queue + while not self.video_incoming_frames.empty(): + try: + frame = self.video_incoming_frames.get_nowait() + # Ensure any CUDA tensors are properly handled + if frame.tensor is not None and frame.tensor.is_cuda: + frame.tensor.cpu() + except asyncio.QueueEmpty: + break + + # Force CUDA cache clear + if torch.cuda.is_available(): + torch.cuda.empty_cache() + except Exception as e: + logging.error(f"Error stopping ComfyUI pipeline: {e}") + finally: + self.pause_input = False + + logging.info("ComfyUI pipeline stopped") \ No newline at end of file diff --git a/runner/app/live/pipelines/comfyui_default_workflow.json b/runner/app/live/pipelines/comfyui_default_workflow.json index 36ecd9f77..fe997eeb6 100644 --- a/runner/app/live/pipelines/comfyui_default_workflow.json +++ b/runner/app/live/pipelines/comfyui_default_workflow.json @@ -20,7 +20,7 @@ }, "3": { "inputs": { - "unet_name": "static-dreamshaper8_SD15_$stat-b-1-h-512-w-512_00001_.engine", + "unet_name": "static-dreamshaper8_SD15_$stat-b-1-h-704-w-384_00001_.engine", "model_type": "SD15" }, "class_type": "TensorRTLoader", @@ -146,8 +146,8 @@ }, "16": { "inputs": { - "width": 512, - "height": 512, + "width": 384, + "height": 704, "batch_size": 1 }, "class_type": "EmptyLatentImage", diff --git a/runner/app/live/pipelines/noop.py b/runner/app/live/pipelines/noop.py index 350543cc0..ca588471c 100644 --- a/runner/app/live/pipelines/noop.py +++ b/runner/app/live/pipelines/noop.py @@ -1,7 +1,7 @@ import logging import asyncio from PIL import Image - +import torch from .interface import Pipeline from trickle import VideoFrame, VideoOutput @@ -26,3 +26,18 @@ async def update_params(self, **params): async def stop(self): logging.info("Stopping pipeline") + + # Clear the frame queue and move any CUDA tensors to CPU + while not self.frame_queue.empty(): + try: + frame = self.frame_queue.get_nowait() + if frame.tensor.is_cuda: + frame.tensor.cpu() # Move tensor to CPU before deletion + except asyncio.QueueEmpty: + break + except Exception as e: + logging.error(f"Error clearing frame queue: {e}") + + # Force CUDA cache clear + if torch.cuda.is_available(): + torch.cuda.empty_cache() diff --git a/runner/app/live/streamer/process.py b/runner/app/live/streamer/process.py index a11c0958b..25109e7b7 100644 --- a/runner/app/live/streamer/process.py +++ b/runner/app/live/streamer/process.py @@ -11,6 +11,7 @@ from pipelines import load_pipeline, Pipeline from log import config_logging, config_logging_fields, log_timing from trickle import InputFrame, AudioFrame, VideoFrame, OutputFrame, VideoOutput, AudioOutput +from utils import ComfyUtils class PipelineProcess: @staticmethod @@ -24,6 +25,7 @@ def start(pipeline_name: str, params: dict): def __init__(self, pipeline_name: str): self.pipeline_name = pipeline_name + self.pipeline = None # Initialize pipeline as None self.ctx = mp.get_context("spawn") self.input_queue = self.ctx.Queue(maxsize=2) @@ -165,9 +167,12 @@ async def _initialize_pipeline(self): logging.info("PipelineProcess: No params found in param_update_queue, loading with default params") with log_timing(f"PipelineProcess: Pipeline loading with {params}"): - pipeline = load_pipeline(self.pipeline_name) - await pipeline.initialize(**params) - return pipeline + self.pipeline = load_pipeline(self.pipeline_name) + + # TODO: We may need to call reset_stream when resolution is changed and start the pipeline again + # Changing the engine causes issues, maybe cleanup related + await self.pipeline.initialize(**params) + return self.pipeline except Exception as e: self._report_error(f"Error loading pipeline: {e}") if not params: @@ -177,19 +182,19 @@ async def _initialize_pipeline(self): with log_timing( f"PipelineProcess: Pipeline loading with default params due to error with params: {params}" ): - pipeline = load_pipeline(self.pipeline_name) - await pipeline.initialize() - return pipeline + self.pipeline = load_pipeline(self.pipeline_name) + await self.pipeline.initialize() + return self.pipeline except Exception as e: self._report_error(f"Error loading pipeline with default params: {e}") raise async def _run_pipeline_loops(self): - pipeline = await self._initialize_pipeline() + await self._initialize_pipeline() self.pipeline_initialized.set() - input_task = asyncio.create_task(self._input_loop(pipeline)) - output_task = asyncio.create_task(self._output_loop(pipeline)) - param_task = asyncio.create_task(self._param_update_loop(pipeline)) + input_task = asyncio.create_task(self._input_loop()) + output_task = asyncio.create_task(self._output_loop()) + param_task = asyncio.create_task(self._param_update_loop()) async def wait_for_stop(): while not self.is_done(): @@ -205,17 +210,17 @@ async def wait_for_stop(): for task in tasks: task.cancel() await asyncio.gather(*tasks, return_exceptions=True) - await self._cleanup_pipeline(pipeline) + await self._cleanup_pipeline() logging.info("PipelineProcess: _run_pipeline_loops finished.") - async def _input_loop(self, pipeline: Pipeline): + async def _input_loop(self): while not self.is_done(): try: input_frame = await asyncio.to_thread(self.input_queue.get, timeout=0.1) if isinstance(input_frame, VideoFrame): input_frame.log_timestamps["pre_process_frame"] = time.time() - await pipeline.put_video_frame(input_frame, self.request_id) + await self.pipeline.put_video_frame(input_frame, self.request_id) elif isinstance(input_frame, AudioFrame): self._try_queue_put(self.output_queue, AudioOutput([input_frame], self.request_id)) except queue.Empty: @@ -224,10 +229,10 @@ async def _input_loop(self, pipeline: Pipeline): except Exception as e: self._report_error(f"Error processing input frame: {e}") - async def _output_loop(self, pipeline: Pipeline): + async def _output_loop(self): while not self.is_done(): try: - output = await pipeline.get_processed_video_frame() + output = await self.pipeline.get_processed_video_frame() if isinstance(output, VideoOutput) and not output.tensor.is_cuda and torch.cuda.is_available(): output = output.replace_tensor(output.tensor.cuda()) output.log_timestamps["post_process_frame"] = time.time() @@ -235,19 +240,18 @@ async def _output_loop(self, pipeline: Pipeline): except Exception as e: self._report_error(f"Error processing output frame: {e}") - async def _param_update_loop(self, pipeline: Pipeline): + async def _param_update_loop(self): while not self.is_done(): try: params = await asyncio.to_thread(self.param_update_queue.get, timeout=0.1) if self._handle_logging_params(params): logging.info(f"PipelineProcess: Updating pipeline parameters: {params}") - await pipeline.update_params(**params) + await self.pipeline.update_params(**params) except queue.Empty: - # Timeout ensures the non-daemon threads from to_thread can exit if task is cancelled continue except Exception as e: - self._report_error(f"Error updating params: {e}") + self._report_error(f"Error updating parameters: {e}") def _report_error(self, error_msg: str): error_event = { @@ -257,12 +261,12 @@ def _report_error(self, error_msg: str): logging.error(error_msg) self._try_queue_put(self.error_queue, error_event) - async def _cleanup_pipeline(self, pipeline): - if pipeline is not None: + async def _cleanup_pipeline(self): + if self.pipeline: try: - await pipeline.stop() + await self.pipeline.stop() except Exception as e: - logging.error(f"Error stopping pipeline: {e}") + self._report_error(f"Error cleaning up pipeline: {e}") def _setup_logging(self): level = ( diff --git a/runner/app/live/streamer/process_guardian.py b/runner/app/live/streamer/process_guardian.py index eb3fe933c..432f3b051 100644 --- a/runner/app/live/streamer/process_guardian.py +++ b/runner/app/live/streamer/process_guardian.py @@ -7,6 +7,7 @@ from trickle import InputFrame, OutputFrame from .process import PipelineProcess from .status import PipelineState, PipelineStatus, InferenceStatus, InputStatus +from utils import ComfyUtils FPS_LOG_INTERVAL = 10.0 @@ -19,6 +20,9 @@ class StreamerCallbacks(abc.ABC): @abc.abstractmethod def is_stream_running(self) -> bool: ... + + +class ProcessCallbacks(abc.ABC): @abc.abstractmethod async def emit_monitoring_event(self, event_data: dict) -> None: ... @@ -44,6 +48,7 @@ def __init__( ): self.pipeline = pipeline self.initial_params = params + self.width, self.height = ComfyUtils.get_latent_image_dimensions(params.get('prompt')) self.streamer: StreamerCallbacks = _NoopStreamerCallbacks() self.process: Optional[PipelineProcess] = None @@ -82,6 +87,24 @@ async def reset_stream( ): if not self.process: raise RuntimeError("Process not running") + + # Check if resolution has changed + new_width = params.get("width", None) + new_height = params.get("height", None) + if (new_width is None or new_height is None): + new_width, new_height = ComfyUtils.DEFAULT_WIDTH, ComfyUtils.DEFAULT_HEIGHT + + # If resolution changed, we need to restart the process (does not work for comfyui) + if (new_width != self.width or new_height != self.height): + logging.info(f"Resolution changed from {self.width}x{self.height} to {new_width}x{new_height}, restarting process") + self.width = new_width + self.height = new_height + await self.process._cleanup_pipeline() + await self.stop() + # Create new process with current pipeline name and params + params.update({"width": new_width, "height": new_height}) + self.process = PipelineProcess.start(self.pipeline, params) + self.status.start_time = time.time() self.status.input_status = InputStatus() self.input_fps_counter.reset() @@ -89,6 +112,8 @@ async def reset_stream( self.streamer = streamer or _NoopStreamerCallbacks() self.process.reset_stream(request_id, manifest_id, stream_id) + self.process.update_params(params) + await self.update_params(params) self.status.update_state(PipelineState.ONLINE) @@ -310,7 +335,7 @@ async def _monitor_loop(self): # Hot fix: the comfyui pipeline process is having trouble shutting down and causes restarts not to recover. # So we skip the restart here and move the state to ERROR so the worker will restart the whole container. # TODO: Remove this exception once pipeline shutdown is fixed and restarting process is useful again. - raise Exception("Skipping process restart due to pipeline shutdown issues") + #raise Exception("Skipping process restart due to pipeline shutdown issues") await self._restart_process() except Exception: logging.exception("Failed to stop streamer and restart process. Moving to ERROR state", stack_info=True) diff --git a/runner/app/live/streamer/protocol/trickle.py b/runner/app/live/streamer/protocol/trickle.py index e2178e182..eee25c46f 100644 --- a/runner/app/live/streamer/protocol/trickle.py +++ b/runner/app/live/streamer/protocol/trickle.py @@ -7,12 +7,12 @@ from PIL import Image from trickle import media, TricklePublisher, TrickleSubscriber, InputFrame, OutputFrame, AudioFrame, AudioOutput - +from utils import ComfyUtils from .protocol import StreamProtocol from .last_value_cache import LastValueCache class TrickleProtocol(StreamProtocol): - def __init__(self, subscribe_url: str, publish_url: str, control_url: Optional[str] = None, events_url: Optional[str] = None): + def __init__(self, subscribe_url: str, publish_url: str, control_url: Optional[str] = None, events_url: Optional[str] = None, width: Optional[int] = ComfyUtils.DEFAULT_WIDTH, height: Optional[int] = ComfyUtils.DEFAULT_HEIGHT): self.subscribe_url = subscribe_url self.publish_url = publish_url self.control_url = control_url @@ -23,13 +23,15 @@ def __init__(self, subscribe_url: str, publish_url: str, control_url: Optional[s self.events_publisher = None self.subscribe_task = None self.publish_task = None + self.width = width + self.height = height async def start(self): self.subscribe_queue = queue.Queue[InputFrame]() self.publish_queue = queue.Queue[OutputFrame]() metadata_cache = LastValueCache[dict]() # to pass video metadata from decoder to encoder self.subscribe_task = asyncio.create_task( - media.run_subscribe(self.subscribe_url, self.subscribe_queue.put, metadata_cache.put, self.emit_monitoring_event) + media.run_subscribe(self.subscribe_url, self.subscribe_queue.put, metadata_cache.put, self.emit_monitoring_event, self.width, self.height) ) self.publish_task = asyncio.create_task( media.run_publish(self.publish_url, self.publish_queue.get, metadata_cache.get, self.emit_monitoring_event) diff --git a/runner/app/live/streamer/streamer.py b/runner/app/live/streamer/streamer.py index 2fa7b1b38..da239c201 100644 --- a/runner/app/live/streamer/streamer.py +++ b/runner/app/live/streamer/streamer.py @@ -13,6 +13,7 @@ from .protocol.protocol import StreamProtocol from .status import timestamp_to_ms from trickle import AudioFrame, VideoFrame, OutputFrame, AudioOutput, VideoOutput +from utils import ComfyUtils fps_log_interval = 10 status_report_interval = 10 @@ -25,6 +26,8 @@ def __init__( request_id: str, manifest_id: str, stream_id: str, + width: int = ComfyUtils.DEFAULT_WIDTH, + height: int = ComfyUtils.DEFAULT_HEIGHT, ): self.protocol = protocol self.process = process @@ -37,6 +40,8 @@ def __init__( self.request_id = request_id self.manifest_id = manifest_id self.stream_id = stream_id + self.width = width + self.height = height async def start(self, params: dict): if self.tasks_supervisor_task: @@ -46,7 +51,14 @@ async def start(self, params: dict): self.request_id, self.manifest_id, self.stream_id, params, self ) + # Update dimensions from process after reset_stream + self.width = self.process.width + self.height = self.process.height + logging.info(f"Streamer: Updated dimensions to {self.width}x{self.height} from process") + self.stop_event.clear() + self.protocol.width = self.width + self.protocol.height = self.height await self.protocol.start() # We need a bunch of concurrent tasks to run the streamer. So we start them all in background and then also start @@ -55,11 +67,11 @@ async def start(self, params: dict): run_in_background("ingress_loop", self.run_ingress_loop()), run_in_background("egress_loop", self.run_egress_loop()), run_in_background("report_status_loop", self.report_status_loop()), - run_in_background("control_loop", self.run_control_loop()), - ] + ] # auxiliary tasks that are not critical to the supervisor, but which we want to run # TODO: maybe remove this since we had to move the control loop to main tasks - self.auxiliary_tasks: list[asyncio.Task] = [] + self.auxiliary_tasks: list[asyncio.Task] = [run_in_background("control_loop", self.run_control_loop()), + ] self.tasks_supervisor_task = run_in_background( "tasks_supervisor", self.tasks_supervisor() ) diff --git a/runner/app/live/trickle/decoder.py b/runner/app/live/trickle/decoder.py index 1f94f30a0..0eede088c 100644 --- a/runner/app/live/trickle/decoder.py +++ b/runner/app/live/trickle/decoder.py @@ -11,7 +11,7 @@ MAX_FRAMERATE=24 -def decode_av(pipe_input, frame_callback, put_metadata): +def decode_av(pipe_input, frame_callback, put_metadata, output_width, output_height): """ Reads from a pipe (or file-like object). @@ -56,6 +56,8 @@ def decode_av(pipe_input, frame_callback, put_metadata): "sar": video_stream.codec_context.sample_aspect_ratio, "dar": video_stream.codec_context.display_aspect_ratio, "format": str(video_stream.codec_context.format), + "output_width": output_width, + "output_height": output_height, } if video_metadata is None and audio_metadata is None: @@ -93,7 +95,6 @@ def decode_av(pipe_input, frame_callback, put_metadata): frame = cast(av.VideoFrame, frame) if frame.pts is None: continue - # drop frames that come in too fast # TODO also check timing relative to wall clock pts_time = frame.time @@ -106,25 +107,34 @@ def decode_av(pipe_input, frame_callback, put_metadata): else: # not delayed, so use prev pts to allow more jitter next_pts_time = next_pts_time + frame_interval - - h = 512 - w = int((512 * frame.width / frame.height) / 2) * 2 # force divisible by 2 - if frame.height > frame.width: - w = 512 - h = int((512 * frame.height / frame.width) / 2) * 2 - frame = reformatter.reformat(frame, format='rgba', width=w, height=h) - + # Convert frame to image image = frame.to_image() if image.mode != "RGB": image = image.convert("RGB") width, height = image.size - if (width, height) != (512, 512): - # Crop to the center square if image not already square - square_size = 512 - start_x = width // 2 - square_size // 2 - start_y = height // 2 - square_size // 2 - image = image.crop((start_x, start_y, start_x + square_size, start_y + square_size)) + # Calculate aspect ratios + input_ratio = width / height + output_ratio = output_width / output_height + + if input_ratio != output_ratio: + # Need to crop to match output aspect ratio + if input_ratio > output_ratio: + # Input is wider than output - crop width + new_width = int(height * output_ratio) + start_x = (width - new_width) // 2 + image = image.crop((start_x, 0, start_x + new_width, height)) + else: + # Input is taller than output - crop height + new_height = int(width / output_ratio) + start_y = (height - new_height) // 2 + image = image.crop((0, start_y, width, start_y + new_height)) + + # Resize to final dimensions + if (output_width, output_height) != image.size: + image = image.resize((output_width, output_height)) + + # Convert to tensor image_np = np.array(image).astype(np.float32) / 255.0 tensor = torch.tensor(image_np).unsqueeze(0) diff --git a/runner/app/live/trickle/encoder.py b/runner/app/live/trickle/encoder.py index 648c78952..bf22c80f3 100644 --- a/runner/app/live/trickle/encoder.py +++ b/runner/app/live/trickle/encoder.py @@ -56,7 +56,9 @@ def custom_io_open(url: str, flags: int, options: dict): if video_meta and video_codec: # Add a new stream to the output using the desired video codec - video_opts = { 'video_size':'512x512', 'bf':'0' } + output_width = video_meta['output_width'] + output_height = video_meta['output_height'] + video_opts = { 'video_size':f'{output_width}x{output_height}', 'bf':'0' } if video_codec == 'libx264': video_opts = video_opts | { 'preset':'superfast', 'tune':'zerolatency', 'forced-idr':'1' } output_video_stream = output_container.add_stream(video_codec, options=video_opts) diff --git a/runner/app/live/trickle/media.py b/runner/app/live/trickle/media.py index 72ae59004..2c401cc65 100644 --- a/runner/app/live/trickle/media.py +++ b/runner/app/live/trickle/media.py @@ -16,12 +16,12 @@ MAX_ENCODER_RETRIES = 3 ENCODER_RETRY_RESET_SECONDS = 120 # reset retry counter after 2 minutes -async def run_subscribe(subscribe_url: str, image_callback, put_metadata, monitoring_callback): +async def run_subscribe(subscribe_url: str, image_callback, put_metadata, monitoring_callback, output_width, output_height): # TODO add some pre-processing parameters, eg image size try: in_pipe, out_pipe = os.pipe() write_fd = await AsyncifyFdWriter(out_pipe) - parse_task = asyncio.create_task(decode_in(in_pipe, image_callback, put_metadata, write_fd)) + parse_task = asyncio.create_task(decode_in(in_pipe, image_callback, put_metadata, write_fd, output_width, output_height)) subscribe_task = asyncio.create_task(subscribe(subscribe_url, write_fd, monitoring_callback)) await asyncio.gather(subscribe_task, parse_task) logging.info("run_subscribe complete") @@ -74,13 +74,13 @@ async def AsyncifyFdWriter(write_fd): writer = asyncio.StreamWriter(write_transport, write_protocol, None, loop) return writer -async def decode_in(in_pipe, frame_callback, put_metadata, write_fd): +async def decode_in(in_pipe, frame_callback, put_metadata, write_fd, output_width, output_height): def decode_runner(): retry_count = 0 last_retry_time = time.time() while retry_count < MAX_DECODER_RETRIES: try: - decode_av(f"pipe:{in_pipe}", frame_callback, put_metadata) + decode_av(f"pipe:{in_pipe}", frame_callback, put_metadata, output_width, output_height) break # clean exit except Exception as e: msg = str(e) diff --git a/runner/app/live/utils/__init__.py b/runner/app/live/utils/__init__.py new file mode 100644 index 000000000..f04013e59 --- /dev/null +++ b/runner/app/live/utils/__init__.py @@ -0,0 +1,3 @@ +from .comfy_utils import ComfyUtils + +__all__ = ['ComfyUtils'] \ No newline at end of file diff --git a/runner/app/live/utils/comfy_utils.py b/runner/app/live/utils/comfy_utils.py new file mode 100644 index 000000000..2dc8b5147 --- /dev/null +++ b/runner/app/live/utils/comfy_utils.py @@ -0,0 +1,40 @@ +import logging +import json + +class ComfyUtils: + DEFAULT_WIDTH = 384 + DEFAULT_HEIGHT = 704 + + @staticmethod + def get_latent_image_dimensions(workflow: dict | str | None) -> tuple[int, int]: + """Get dimensions from the EmptyLatentImage node in the workflow. + + Args: + workflow: The workflow JSON dictionary + + Returns: + Tuple of (width, height) from the latent image. Returns default dimensions if not found or on error. + """ + + if workflow is None: + return ComfyUtils.DEFAULT_WIDTH, ComfyUtils.DEFAULT_HEIGHT + + if isinstance(workflow, str): + workflow = json.loads(workflow) + + try: + for node_id, node in workflow.items(): + if node.get("class_type") == "EmptyLatentImage": + inputs = node.get("inputs", {}) + width = inputs.get("width") + height = inputs.get("height") + if width is not None and height is not None: + return width, height + logging.warning("Incomplete dimensions in latent image node") + break + except Exception as e: + logging.warning(f"Failed to extract dimensions from workflow: {e}") + + # Return defaults if dimensions not found or on any error + logging.info(f"Using default dimensions {ComfyUtils.DEFAULT_WIDTH}x{ComfyUtils.DEFAULT_HEIGHT}") + return ComfyUtils.DEFAULT_WIDTH, ComfyUtils.DEFAULT_HEIGHT \ No newline at end of file diff --git a/runner/app/main.py b/runner/app/main.py index 372e407de..582fb3f5c 100644 --- a/runner/app/main.py +++ b/runner/app/main.py @@ -26,8 +26,17 @@ async def lifespan(app: FastAPI): pipeline = os.environ["PIPELINE"] model_id = os.environ["MODEL_ID"] - - app.pipeline = load_pipeline(pipeline, model_id) + dimensions = os.environ.get("DIMENSIONS", "512x512") + if dimensions is not None: + try: + width, height = map(int, dimensions.split("x")) + if width % 64 != 0 or height % 64 != 0: + raise ValueError(f"Width and height must be divisible by 64, got {width}x{height}") + except ValueError as e: + logger.error(f"Invalid DIMENSIONS format. Expected 'WIDTHxHEIGHT' but got '{dimensions}'") + raise + + app.pipeline = load_pipeline(pipeline, model_id, dimensions) app.include_router(load_route(pipeline)) app.hardware_info_service.log_gpu_compute_info() @@ -42,7 +51,7 @@ async def lifespan(app: FastAPI): logger.info("Shutting down") -def load_pipeline(pipeline: str, model_id: str) -> any: +def load_pipeline(pipeline: str, model_id: str, dimensions: str | None = None) -> any: match pipeline: case "text-to-image": from app.pipelines.text_to_image import TextToImagePipeline @@ -81,7 +90,7 @@ def load_pipeline(pipeline: str, model_id: str) -> any: case "live-video-to-video": from app.pipelines.live_video_to_video import LiveVideoToVideoPipeline - return LiveVideoToVideoPipeline(model_id) + return LiveVideoToVideoPipeline(model_id, dimensions) case "text-to-speech": from app.pipelines.text_to_speech import TextToSpeechPipeline diff --git a/runner/app/pipelines/live_video_to_video.py b/runner/app/pipelines/live_video_to_video.py index 09896b4ff..b828e9693 100644 --- a/runner/app/pipelines/live_video_to_video.py +++ b/runner/app/pipelines/live_video_to_video.py @@ -18,11 +18,16 @@ proc_status_important_fields = ["State", "VmRSS", "VmSize", "Threads", "voluntary_ctxt_switches", "nonvoluntary_ctxt_switches", "CoreDumping"] class LiveVideoToVideoPipeline(Pipeline): - def __init__(self, model_id: str): + def __init__(self, model_id: str, dimensions: str | None = None): self.version = os.getenv("VERSION", "undefined") self.model_id = model_id self.model_dir = get_model_dir() self.torch_device = get_torch_device() + self.dimensions = dimensions + if dimensions: + self.width, self.height = map(int, dimensions.split("x")) + else: + self.width, self.height = 512, 512 # Default values self.infer_script_path = ( Path(__file__).parent.parent / "live" / "infer.py" ) @@ -34,6 +39,9 @@ def __call__( # type: ignore ): if not self.process: raise RuntimeError("Pipeline process not running") + + # TODO: remove this once we have a better way to parse dimensions/dynamically reload process + params.update({"width": self.width, "height": self.height}) max_retries = 10 thrown_ex = None @@ -109,11 +117,14 @@ def start_process(self): logging.info("Starting pipeline process") cmd = [sys.executable, str(self.infer_script_path)] cmd.extend(["--pipeline", self.model_id]) # we use the model_id as the pipeline name for now + cmd.extend(["--initial-params", json.dumps({"width": self.width, "height": self.height})]) cmd.extend(["--http-port", "8888"]) # TODO: set torch device from self.torch_device env = os.environ.copy() env["HUGGINGFACE_HUB_CACHE"] = str(self.model_dir) + if self.dimensions: + env["DIMENSIONS"] = self.dimensions try: self.process = subprocess.Popen( diff --git a/runner/docker/Dockerfile.live-base-comfyui b/runner/docker/Dockerfile.live-base-comfyui index fb42d186d..ab5395a3f 100644 --- a/runner/docker/Dockerfile.live-base-comfyui +++ b/runner/docker/Dockerfile.live-base-comfyui @@ -1,4 +1,4 @@ -ARG BASE_IMAGE=livepeer/comfyui-base@sha256:2d5ecad6bf24bb73831c6c87a33a7503f48a88ef2d490609068505a945d91146 +ARG BASE_IMAGE=livepeer/comfyui-base@sha256:a9ecd7be5cb93cd8d90e41e2e8759bb307b45b1dc20a19dfd40c3b4844e6097b FROM ${BASE_IMAGE} # -----------------------------------------------------------------------------