From 22d40e7a0635c064335c33a28ed43c035693bd83 Mon Sep 17 00:00:00 2001 From: li <> Date: Mon, 15 Sep 2025 10:22:45 +0800 Subject: [PATCH] feat: get node info --- backend/service/link_agent_tools.py | 23 ++- backend/service/workflow_rewrite_tools.py | 6 +- backend/utils/comfy_gateway.py | 173 +++++++++++++++++++++- 3 files changed, 197 insertions(+), 5 deletions(-) diff --git a/backend/service/link_agent_tools.py b/backend/service/link_agent_tools.py index aa734b08..c0e6773b 100644 --- a/backend/service/link_agent_tools.py +++ b/backend/service/link_agent_tools.py @@ -7,7 +7,7 @@ from agents.tool import function_tool from ..utils.request_context import get_session_id from ..dao.workflow_table import get_workflow_data, save_workflow_data -from ..utils.comfy_gateway import get_object_info +from ..utils.comfy_gateway import get_object_info, get_object_info_by_missed_class_list from ..utils.logger import log @function_tool @@ -33,8 +33,29 @@ async def analyze_missing_connections() -> str: if not workflow_data: return json.dumps({"error": "No workflow data found for this session"}) + log.info(f"analyze_missing_connections: workflow_data: {workflow_data}") + object_info = await get_object_info() + log.info(f"analyze_missing_connections: object_info: {object_info}") + + # 缺失的node class + missed_node_class = [] + # 缺失的object info + missed_object_info = {} + for node_id, node_data in workflow_data.items(): + node_class = node_data.get("class_type") + log.info(f"node: node_id:{node_id}, node_class: {node_class}") + if node_class not in object_info: + log.info(f"missed node: node_id:{node_id}, node_class: {node_class}") + missed_node_class.append(node_class) + + log.info(f"analyze_missing_connections: missed_node_class: {missed_node_class}") + + if missed_node_class: + missed_object_info = await get_object_info_by_missed_class_list(missed_node_class) + object_info.update(missed_object_info) + analysis_result = { "missing_connections": [], "possible_connections": [], diff --git a/backend/service/workflow_rewrite_tools.py b/backend/service/workflow_rewrite_tools.py index e68e4293..38b60755 100644 --- a/backend/service/workflow_rewrite_tools.py +++ b/backend/service/workflow_rewrite_tools.py @@ -9,7 +9,7 @@ from .workflow_rewrite_agent_simple import rewrite_workflow_simple from ..dao.workflow_table import get_workflow_data, save_workflow_data, get_workflow_data_ui, get_workflow_data_by_id -from ..utils.comfy_gateway import get_object_info +from ..utils.comfy_gateway import get_object_info, get_object_info_by_class_list from ..utils.request_context import get_rewrite_context, get_session_id from ..utils.logger import log @@ -68,7 +68,7 @@ def get_current_workflow() -> str: async def get_node_info(node_class: str) -> str: """获取节点的详细信息,包括输入输出参数""" try: - object_info = await get_object_info() + object_info = await get_object_info_by_class_list([node_class]) if node_class in object_info: node_info_str = json.dumps(object_info[node_class], ensure_ascii=False) get_rewrite_context().node_infos[node_class] = node_info_str @@ -89,7 +89,7 @@ async def get_node_info(node_class: str) -> str: async def get_node_infos(node_class_list: list[str]) -> str: """获取多个节点的详细信息,包括输入输出参数。只做最小化有必要的查询,不要查询所有节点。尽量不要超过5个""" try: - object_info = await get_object_info() + object_info = await get_object_info_by_class_list(node_class_list) node_infos = {} for node_class in node_class_list: if node_class in object_info: diff --git a/backend/utils/comfy_gateway.py b/backend/utils/comfy_gateway.py index a366dbe1..897179d8 100644 --- a/backend/utils/comfy_gateway.py +++ b/backend/utils/comfy_gateway.py @@ -18,7 +18,8 @@ import folder_paths import server import aiohttp - +from concurrent.futures import ThreadPoolExecutor, wait +from ..utils.logger import log class ComfyGateway: """ComfyUI API Gateway for Python backend - uses internal functions instead of HTTP requests""" @@ -346,6 +347,176 @@ async def run_prompt(json_data: Dict[str, Any], base_url: Optional[str] = None) gateway = ComfyGateway(base_url) return await gateway.run_prompt(json_data) +def node_info(node) -> Dict[str, Any]: + info = {} + info['name'] = node_class + info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class + info['description'] = node['description'] if hasattr(node,'description') else '' + info['python_module'] = getattr(node, "RELATIVE_PYTHON_MODULE", "nodes") + info['category'] = node['category'] if hasattr(node, 'category') else 'sd' + if hasattr(node, "input_types"): + input_types = json.dumps(node['input_types'], ensure_ascii=False) + if input_types: + info['input'] = input_types + info['input_order'] = {key: list(value.keys()) for (key, value) in input_types.items()} + else: + info['input'] = {} + info['input_order'] = [] + else: + info['input'] = {} + info['input_order'] = [] + + if hasattr(node, "return_names"): + info['output_name'] = json.dumps(node['return_names'], ensure_ascii=False) + else: + info['output_name'] = [] + + if hasattr(node, "output_is_list"): + info['output_is_list'] = node['output_is_list'] + else: + info['output_is_list'] = [False] * len(info['output']) + + if hasattr(node, "output_name"): + info['output_name'] = node['output_name'] + + if hasattr(node, 'OUTPUT_NODE') and node['output_node'] == True: + info['output_node'] = True + else: + info['output_node'] = False + + if hasattr(node, 'OUTPUT_TOOLTIPS'): + info['output_tooltips'] = node['output_tooltips'] + + if getattr(node, "DEPRECATED", False): + info['deprecated'] = True + if getattr(node, "EXPERIMENTAL", False): + info['experimental'] = True + + if hasattr(node, 'API_NODE'): + info['api_node'] = node['api_node'] + return info + +async def get_version(node_class: str = None) -> str: + try: + # Create a timeout configuration + timeout = aiohttp.ClientTimeout(total=30) # 30 second timeout + + # Build URL - either specific node or all nodes + url = f"https://api.comfy.org/nodes/{node_class}/versions" + + # Make HTTP request to /api/object_info endpoint + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.get(url) as response: + if response.status == 200: + result = await response.json() + versionData = result[0] if result[0] else None + version = versionData.get("version") if versionData.get("version") else None + return version + else: + logging.error(f"Failed to get object info: HTTP {response.status}") + return None + + except aiohttp.ClientConnectionError as e: + logging.error(f"Connection error in get_object_info_from_comfyui_by_class: {e}") + return None + except aiohttp.ClientTimeout as e: + logging.error(f"Timeout error in get_object_info_from_comfyui_by_class: {e}") + return None + except Exception as e: + logging.error(f"Error getting get_object_info_from_comfyui_by_class: {e}") + return None + +async def get_object_info_from_comfyapi_by_class(node_class: str = None) -> Dict[str, Any]: + """Standalone function to get object info from comfyapi by class - HTTP call to ComfyUI /api/object_info/{node_class} endpoint""" + try: + # Create a timeout configuration + timeout = aiohttp.ClientTimeout(total=30) # 30 second timeout + + version = await get_version(node_class) + + # Build URL - either specific node or all nodes + url = f"https://api.comfy.org/nodes/{node_class}/versions/{version}/comfy-nodes" + + # Make HTTP request to /api/object_info endpoint + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.get(url) as response: + if response.status == 200: + return await response.json() + else: + logging.error(f"Failed to get object info: HTTP {response.status}") + return {} + + except aiohttp.ClientConnectionError as e: + logging.error(f"Connection error in get_object_info_from_comfyui_by_class: {e}") + return {} + except aiohttp.ClientTimeout as e: + logging.error(f"Timeout error in get_object_info_from_comfyui_by_class: {e}") + return {} + except Exception as e: + logging.error(f"Error getting get_object_info_from_comfyui_by_class: {e}") + return {} + +async def get_object_info_by_missed_class_list(node_class_list: list[str] = None) -> Dict[str, Any]: + result: Dict[str, Any] = {} + # Limit concurrency to avoid excessive outbound requests + semaphore = asyncio.Semaphore(10) + + async def fetch_and_merge(node_class: str) -> None: + async with semaphore: + try: + data = await get_object_info_from_comfyapi_by_class(node_class) + if not isinstance(data, dict): + return + + # for d in data["comfy-nodes"]: + + # Prefer exact key if present + if node_class in data: + result[node_class] = data[node_class] + return + + # Otherwise, merge any matching requested keys from the payload + for k, v in data.items(): + if k in node_class_list and k not in result: + result[k] = v + + # Fallback: keep the whole payload under the class key if still missing + if node_class not in result: + result[node_class] = data + except Exception as e: + logging.error(f"Error fetching object info for {node_class}: {e}") + + await asyncio.gather(*(fetch_and_merge(cls) for cls in node_class_list)) + return result + +async def get_object_info_by_class_list(node_class_list: list[str] = None, base_url: Optional[str] = None) -> Dict[str, Any]: + """Get object info for a list of node classes. + + Strategy: + - Fetch local object info once and take entries for classes that already exist. + - For missing classes, concurrently fetch via get_object_info_from_comfyapi_by_class. + - Merge results and return a dict keyed by class names. + """ + # Start with whatever the local Comfy server knows + local_object_info = await get_object_info(base_url) or {} + + if not node_class_list: + return local_object_info + + result: Dict[str, Any] = {} + for cls in node_class_list: + if cls in local_object_info: + result[cls] = local_object_info[cls] + + # Determine which classes we still need + missing_classes = [cls for cls in node_class_list if cls not in result] + log.info(f"get_object_info_by_class_list: node_class_list: {node_class_list}") + log.info(f"get_object_info_by_class_list: missing_classes: {missing_classes}") + if not missing_classes: + return result + + result.update(await get_object_info_by_missed_class_list(missing_classes)) + return result async def get_object_info(base_url: Optional[str] = None) -> Dict[str, Any]: """Standalone function to get object info - HTTP call to ComfyUI /api/object_info endpoint"""