From 2f4ef13b60304dae49a870d021f83864a12a5d3c Mon Sep 17 00:00:00 2001
From: Jacob Segal
Date: Sat, 19 Apr 2025 00:06:13 -0700
Subject: [PATCH 1/2] WIP support for async execution functions

This commit adds support for node execution functions defined as async.
When a node's execution function is defined as async, we can continue
executing other nodes while it is processing.

Standard uses of `await` should "just work", but people will still have
to be careful if they spawn actual threads. Because torch doesn't really
have async/await versions of its functions, this won't particularly help
with most locally-executing nodes, but it does help with nodes that make
e.g. web requests to other machines. (A usage sketch is included in the
notes at the end of this series.)

Remaining work:
1. The UI doesn't properly display multiple concurrent node executions.
2. I probably need to do some work to handle the case where one node
   hits an out-of-memory error (or other exception) while another node
   is concurrently executing.
3. If people are doing node expansion within an async function and using
   the `GraphBuilder`, they'll need to provide a manual prefix. The
   default one won't necessarily work properly in this case. This looks
   easy to fix in Python 3.12+ with contextvars, but I have to figure
   out how to do it in 3.11 and earlier. (A possible contextvars
   approach is sketched in the notes at the end of this series.)
---
 comfy_execution/graph.py |  20 +++++++-
 execution.py             | 103 ++++++++++++++++++++++++++++++++++++---
 2 files changed, 115 insertions(+), 8 deletions(-)

diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py
index a2799b52e102..c79243e1eefe 100644
--- a/comfy_execution/graph.py
+++ b/comfy_execution/graph.py
@@ -2,6 +2,7 @@
 from typing import Type, Literal
 
 import nodes
+import asyncio
 from comfy_execution.graph_utils import is_link
 from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
 
@@ -100,6 +101,8 @@ def __init__(self, dynprompt):
         self.pendingNodes = {}
         self.blockCount = {} # Number of nodes this node is directly blocked by
         self.blocking = {} # Which nodes are blocked by this node
+        self.externalBlocks = 0
+        self.unblockedEvent = asyncio.Event()
 
     def get_input_info(self, unique_id, input_name):
         class_type = self.dynprompt.get_node(unique_id)["class_type"]
@@ -153,6 +156,16 @@ def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None):
             for link in links:
                 self.add_strong_link(*link)
 
+    def add_external_block(self, node_id):
+        assert node_id in self.blockCount, "Can't add external block to a node that isn't pending"
+        self.externalBlocks += 1
+        self.blockCount[node_id] += 1
+        def unblock():
+            self.externalBlocks -= 1
+            self.blockCount[node_id] -= 1
+            self.unblockedEvent.set()
+        return unblock
+
     def is_cached(self, node_id):
         return False
 
@@ -181,11 +194,16 @@ def __init__(self, dynprompt, output_cache):
     def is_cached(self, node_id):
         return self.output_cache.get(node_id) is not None
 
-    def stage_node_execution(self):
+    async def stage_node_execution(self):
         assert self.staged_node_id is None
         if self.is_empty():
             return None, None, None
         available = self.get_ready_nodes()
+        while len(available) == 0 and self.externalBlocks > 0:
+            # Wait for an external block to be released
+            await self.unblockedEvent.wait()
+            self.unblockedEvent.clear()
+            available = self.get_ready_nodes()
         if len(available) == 0:
             cycled_nodes = self.get_nodes_in_cycle()
             # Because cycles composed entirely of static nodes are caught during initial validation,
diff --git a/execution.py b/execution.py
index d09102f55247..e721d6e016fc 100644
--- a/execution.py
+++ b/execution.py
@@ -8,6 +8,7 @@
 from enum import Enum
 import inspect
 from typing import List, Literal, NamedTuple, Optional
+import asyncio
 
 import torch
 import nodes
@@ -192,6 +193,63 @@ def process_inputs(inputs, index=None, input_is_list=False):
             process_inputs(input_dict, i)
     return results
 
+async def _async_map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
+    # check if node wants the lists
+    input_is_list = getattr(obj, "INPUT_IS_LIST", False)
+
+    if len(input_data_all) == 0:
+        max_len_input = 0
+    else:
+        max_len_input = max(len(x) for x in input_data_all.values())
+
+    # get a slice of inputs, repeat last input when list isn't long enough
+    def slice_dict(d, i):
+        return {k: v[i if len(v) > i else -1] for k, v in d.items()}
+
+    results = []
+    async def process_inputs(inputs, index=None, input_is_list=False):
+        if allow_interrupt:
+            nodes.before_node_execution()
+        execution_block = None
+        for k, v in inputs.items():
+            if input_is_list:
+                for e in v:
+                    if isinstance(e, ExecutionBlocker):
+                        v = e
+                        break
+            if isinstance(v, ExecutionBlocker):
+                execution_block = execution_block_cb(v) if execution_block_cb else v
+                break
+        if execution_block is None:
+            if pre_execute_cb is not None and index is not None:
+                pre_execute_cb(index)
+            f = getattr(obj, func)
+            if inspect.iscoroutinefunction(f):
+                task = asyncio.create_task(f(**inputs))
+                # Yield once so the task can start; if it completes without suspending, use its result immediately
+                await asyncio.sleep(0)
+                if task.done():
+                    result = task.result()
+                    results.append(result)
+                else:
+                    results.append(task)
+            else:
+                result = f(**inputs)
+                results.append(result)
+        else:
+            results.append(execution_block)
+
+    if input_is_list:
+        await process_inputs(input_data_all, 0, input_is_list=input_is_list)
+    elif max_len_input == 0:
+        await process_inputs({})
+    else:
+        for i in range(max_len_input):
+            input_dict = slice_dict(input_data_all, i)
+            await process_inputs(input_dict, i)
+    return results
+
+
 def merge_result_data(results, obj):
     # check which outputs need concatenating
     output = []
@@ -213,11 +271,18 @@ def merge_result_data(results, obj):
         output.append([o[i] for o in results])
     return output
 
-def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
+async def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
+    return_values = await _async_map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
+    has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
+    if has_pending_task:
+        return return_values, {}, False, has_pending_task
+    output, ui, has_subgraph = get_output_from_returns(return_values, obj)
+    return output, ui, has_subgraph, False
+
+def get_output_from_returns(return_values, obj):
     results = []
     uis = []
     subgraph_results = []
-    return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
     has_subgraph = False
     for i in range(len(return_values)):
         r = return_values[i]
@@ -251,6 +316,10 @@ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb
     else:
         output = []
     ui = dict()
+    # NOTE: There may be a pre-existing bug here. If we're performing a subgraph expansion,
+    # we probably shouldn't be returning UI values yet, because they'll get cached without
+    # the completed subgraphs. It's an edge case (no known nodes use both subgraph expansion
+    # and custom UI outputs), but it could become a problem in the future.
     if len(uis) > 0:
         ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
     return output, ui, has_subgraph
@@ -263,7 +332,7 @@ def format_value(x):
     else:
         return str(x)
 
-def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
+async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes):
     unique_id = current_item
     real_node_id = dynprompt.get_real_node_id(unique_id)
     display_node_id = dynprompt.get_display_node_id(unique_id)
@@ -279,7 +348,11 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
 
     input_data_all = None
     try:
-        if unique_id in pending_subgraph_results:
+        if unique_id in pending_async_nodes:
+            results = [r.result() if isinstance(r, asyncio.Task) else r for r in pending_async_nodes[unique_id]]
+            del pending_async_nodes[unique_id]
+            output_data, output_ui, has_subgraph = get_output_from_returns(results, class_def)
+        elif unique_id in pending_subgraph_results:
             cached_results = pending_subgraph_results[unique_id]
             resolved_outputs = []
             for is_subgraph, result in cached_results:
@@ -341,8 +414,18 @@ def execution_block_cb(block):
                 else:
                     return block
             def pre_execute_cb(call_index):
+                # TODO - How to handle this for async functions? The contextvars-based fix seems to require Python 3.12.
                 GraphBuilder.set_default_prefix(unique_id, call_index, 0)
-            output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
+            output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
+            if has_pending_tasks:
+                pending_async_nodes[unique_id] = output_data
+                unblock = execution_list.add_external_block(unique_id)
+                async def await_completion():
+                    tasks = [x for x in output_data if isinstance(x, asyncio.Task)]
+                    await asyncio.gather(*tasks)
+                    unblock()
+                asyncio.create_task(await_completion())
+                return (ExecutionResult.PENDING, None, None)
         if len(output_ui) > 0:
             caches.ui.set(unique_id, {
                 "meta": {
@@ -481,6 +564,11 @@ def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, e
         self.add_message("execution_error", mes, broadcast=False)
 
     def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
+        asyncio_loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(asyncio_loop)
+        asyncio_loop.run_until_complete(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
+
+    async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
         nodes.interrupt_processing(False)
 
         if "client_id" in extra_data:
@@ -508,6 +596,7 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
                                   { "nodes": cached_nodes, "prompt_id": prompt_id}, broadcast=False)
 
         pending_subgraph_results = {}
+        pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
         executed = set()
         execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
         current_outputs = self.caches.outputs.all_node_ids()
@@ -515,12 +604,12 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
             execution_list.add_node(node_id)
 
         while not execution_list.is_empty():
-            node_id, error, ex = execution_list.stage_node_execution()
+            node_id, error, ex = await execution_list.stage_node_execution()
             if error is not None:
                 self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
                 break
 
-            result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
+            result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
             self.success = result != ExecutionResult.FAILURE
             if result == ExecutionResult.FAILURE:
                 self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)

From 9b4963e723721fa5ffe537e8e5c128eaba4a1685 Mon Sep 17 00:00:00 2001
From: Jacob Segal
Date: Sat, 19 Apr 2025 01:01:53 -0700
Subject: [PATCH 2/2] Add an example "Sleep" node to test async

---
 comfy/utils.py |  5 +++--
 main.py        |  8 ++++++--
 nodes.py       | 31 +++++++++++++++++++++++++++++++
 3 files changed, 40 insertions(+), 4 deletions(-)

diff --git a/comfy/utils.py b/comfy/utils.py
index a826e41bf939..f0e93b8bea79 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -992,11 +992,12 @@ def set_progress_bar_global_hook(function):
     PROGRESS_BAR_HOOK = function
 
 class ProgressBar:
-    def __init__(self, total):
+    def __init__(self, total, node_id=None):
         global PROGRESS_BAR_HOOK
         self.total = total
         self.current = 0
         self.hook = PROGRESS_BAR_HOOK
+        self.node_id = node_id
 
     def update_absolute(self, value, total=None, preview=None):
         if total is not None:
@@ -1005,7 +1006,7 @@ def update_absolute(self, value, total=None, preview=None):
             value = self.total
         self.current = value
         if self.hook is not None:
-            self.hook(self.current, self.total, preview)
+            self.hook(self.current, self.total, preview, node_id=self.node_id)
 
     def update(self, value):
         self.update_absolute(self.current + value)
diff --git a/main.py b/main.py
index ac9d24b7b82c..61d5fca7e68a 100644
--- a/main.py
+++ b/main.py
@@ -227,9 +227,13 @@ async def run(server_instance, address='', port=8188, verbose=True, call_on_star
 
 
 def hijack_progress(server_instance):
-    def hook(value, total, preview_image):
+    def hook(value, total, preview_image, prompt_id=None, node_id=None):
         comfy.model_management.throw_exception_if_processing_interrupted()
-        progress = {"value": value, "max": total, "prompt_id": server_instance.last_prompt_id, "node": server_instance.last_node_id}
+        if prompt_id is None:
+            prompt_id = server_instance.last_prompt_id
+        if node_id is None:
+            node_id = server_instance.last_node_id
+        progress = {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id}
         server_instance.send_sync("progress", progress, server_instance.client_id)
         if preview_image is not None:
             server_instance.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server_instance.client_id)
diff --git a/nodes.py b/nodes.py
index d4082d19d430..de93e9741f09 100644
--- a/nodes.py
+++ b/nodes.py
@@ -46,6 +46,35 @@ def interrupt_processing(value=True):
 
 MAX_RESOLUTION=16384
 
+class TestSleep(ComfyNodeABC):
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "value": (IO.ANY, {}),
+                "seconds": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 9999.0, "step": 0.01, "tooltip": "The amount of seconds to sleep."}),
+            },
+            "hidden": {
+                "unique_id": "UNIQUE_ID",
+            },
+        }
+    RETURN_TYPES = (IO.ANY,)
+    FUNCTION = "sleep"
+
+    CATEGORY = "_for_testing"
+
+    async def sleep(self, value, seconds, unique_id):
+        import asyncio
+        pbar = comfy.utils.ProgressBar(seconds, node_id=unique_id)
+        start = time.time()
+        expiration = start + seconds
+        now = start
+        while now < expiration:
+            now = time.time()
+            pbar.update_absolute(now - start)
+            await asyncio.sleep(0.01)
+        return (value,)
+
 class CLIPTextEncode(ComfyNodeABC):
     @classmethod
     def INPUT_TYPES(s) -> InputTypeDict:
@@ -1941,6 +1970,7 @@ def expand_image(self, image, left, top, right, bottom, feathering):
 
 
 NODE_CLASS_MAPPINGS = {
+    "TestSleep": TestSleep,
     "KSampler": KSampler,
     "CheckpointLoaderSimple": CheckpointLoaderSimple,
     "CLIPTextEncode": CLIPTextEncode,
@@ -2011,6 +2041,7 @@ def expand_image(self, image, left, top, right, bottom, feathering):
 }
 
 NODE_DISPLAY_NAME_MAPPINGS = {
+    "TestSleep": "Test Sleep",
    # Sampling
    "KSampler": "KSampler",
    "KSamplerAdvanced": "KSampler (Advanced)",
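
Notes (usage sketches, not part of the patches above):

1. What an async node looks like in practice. This is a hypothetical
   example, not a node added by these patches; it follows the same
   registration conventions as TestSleep above, and the class name,
   category, and inputs are illustrative. The only thing the new
   executor cares about is that the FUNCTION target is declared async.

    import asyncio

    class TestParallelSleep(ComfyNodeABC):  # hypothetical, for illustration
        @classmethod
        def INPUT_TYPES(cls):
            return {
                "required": {
                    "value": (IO.ANY, {}),
                    "seconds": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 9999.0, "step": 0.01}),
                },
            }
        RETURN_TYPES = (IO.ANY,)
        FUNCTION = "run"
        CATEGORY = "_for_testing"

        async def run(self, value, seconds):
            # The two sleeps stand in for slow I/O such as web requests to
            # other machines. They run concurrently, so this node finishes
            # in about `seconds`, not 2 * seconds, and the executor can
            # stage other ready nodes while this coroutine is suspended.
            await asyncio.gather(asyncio.sleep(seconds), asyncio.sleep(seconds))
            return (value,)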
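
2. A possible direction for remaining-work item 3 (GraphBuilder
   prefixes). This is an untested sketch, not part of the patches: it
   assumes that setting a ContextVar immediately before
   asyncio.create_task() is sufficient, since each new task copies the
   current context at creation time (contextvars has integrated with
   asyncio this way since Python 3.7). All names below are illustrative.

    import asyncio
    import contextvars

    # Hypothetical per-task replacement for GraphBuilder's process-global
    # default prefix.
    default_prefix = contextvars.ContextVar("graph_default_prefix", default="")

    def pre_execute_cb(call_index):
        # Record the prefix in the current context instead of mutating
        # global state...
        default_prefix.set(f"{call_index}.")

    async def node_body():
        # ...and read the per-task copy back inside the node's coroutine.
        return default_prefix.get()

    async def main():
        pre_execute_cb(0)
        t0 = asyncio.create_task(node_body())  # t0's context captures "0."
        pre_execute_cb(1)
        t1 = asyncio.create_task(node_body())  # t1's context captures "1."
        print(await asyncio.gather(t0, t1))    # ['0.', '1.']

    asyncio.run(main())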