23 changes: 22 additions & 1 deletion backend/service/link_agent_tools.py
@@ -7,7 +7,7 @@
from agents.tool import function_tool
from ..utils.request_context import get_session_id
from ..dao.workflow_table import get_workflow_data, save_workflow_data
from ..utils.comfy_gateway import get_object_info
from ..utils.comfy_gateway import get_object_info, get_object_info_by_missed_class_list
from ..utils.logger import log

@function_tool
@@ -33,8 +33,29 @@ async def analyze_missing_connections() -> str:
if not workflow_data:
return json.dumps({"error": "No workflow data found for this session"})

log.info(f"analyze_missing_connections: workflow_data: {workflow_data}")

object_info = await get_object_info()

log.info(f"analyze_missing_connections: object_info: {object_info}")

    # Node classes missing from the local object info
missed_node_class = []
    # Object info fetched for the missing node classes
missed_object_info = {}
for node_id, node_data in workflow_data.items():
node_class = node_data.get("class_type")
log.info(f"node: node_id:{node_id}, node_class: {node_class}")
if node_class not in object_info:
log.info(f"missed node: node_id:{node_id}, node_class: {node_class}")
missed_node_class.append(node_class)

log.info(f"analyze_missing_connections: missed_node_class: {missed_node_class}")

if missed_node_class:
missed_object_info = await get_object_info_by_missed_class_list(missed_node_class)
object_info.update(missed_object_info)

analysis_result = {
"missing_connections": [],
"possible_connections": [],
6 changes: 3 additions & 3 deletions backend/service/workflow_rewrite_tools.py
@@ -9,7 +9,7 @@
from .workflow_rewrite_agent_simple import rewrite_workflow_simple

from ..dao.workflow_table import get_workflow_data, save_workflow_data, get_workflow_data_ui, get_workflow_data_by_id
from ..utils.comfy_gateway import get_object_info
from ..utils.comfy_gateway import get_object_info, get_object_info_by_class_list
from ..utils.request_context import get_rewrite_context, get_session_id
from ..utils.logger import log

@@ -68,7 +68,7 @@ def get_current_workflow() -> str:
async def get_node_info(node_class: str) -> str:
"""获取节点的详细信息,包括输入输出参数"""
try:
object_info = await get_object_info()
object_info = await get_object_info_by_class_list([node_class])
if node_class in object_info:
node_info_str = json.dumps(object_info[node_class], ensure_ascii=False)
get_rewrite_context().node_infos[node_class] = node_info_str
@@ -89,7 +89,7 @@ async def get_node_infos(node_class_list: list[str]) -> str:
async def get_node_infos(node_class_list: list[str]) -> str:
"""获取多个节点的详细信息,包括输入输出参数。只做最小化有必要的查询,不要查询所有节点。尽量不要超过5个"""
try:
object_info = await get_object_info()
object_info = await get_object_info_by_class_list(node_class_list)
node_infos = {}
for node_class in node_class_list:
if node_class in object_info:
173 changes: 172 additions & 1 deletion backend/utils/comfy_gateway.py
@@ -18,7 +18,8 @@
import folder_paths
import server
import aiohttp

import asyncio
from ..utils.logger import log

class ComfyGateway:
"""ComfyUI API Gateway for Python backend - uses internal functions instead of HTTP requests"""
@@ -346,6 +347,176 @@ async def run_prompt(json_data: Dict[str, Any], base_url: Optional[str] = None)
gateway = ComfyGateway(base_url)
return await gateway.run_prompt(json_data)

def node_info(node: Dict[str, Any], node_class: str) -> Dict[str, Any]:
    """Normalize a node payload (e.g. a comfy-nodes entry fetched from the comfy.org
    registry) into the same shape as a ComfyUI /object_info entry.

    The payload is treated as a plain dict; missing fields fall back to defaults.
    """
    info = {}
    info['name'] = node_class
    info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS.get(node_class, node_class)
    info['description'] = node.get('description', '')
    info['python_module'] = node.get('python_module', 'nodes')
    info['category'] = node.get('category', 'sd')

    input_types = node.get('input_types') or {}
    if isinstance(input_types, str):
        # Some payloads serialize input_types as a JSON string; decode it first
        input_types = json.loads(input_types)
    info['input'] = input_types
    info['input_order'] = {key: list(value.keys()) for (key, value) in input_types.items()}

    # 'return_types' is assumed to accompany 'return_names' in the payload
    info['output'] = node.get('return_types', [])
    info['output_name'] = node.get('output_name', node.get('return_names', info['output']))
    info['output_is_list'] = node.get('output_is_list', [False] * len(info['output']))
    info['output_node'] = bool(node.get('output_node', False))

    if 'output_tooltips' in node:
        info['output_tooltips'] = node['output_tooltips']
    if node.get('deprecated', False):
        info['deprecated'] = True
    if node.get('experimental', False):
        info['experimental'] = True
    if 'api_node' in node:
        info['api_node'] = node['api_node']
    return info
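
# A minimal usage sketch (illustrative only): the sample payload below assumes
# field names for a comfy-nodes entry; it is not a documented comfy.org response.
#
#   sample = {
#       "description": "Blend two images together",
#       "category": "image/postprocessing",
#       "input_types": {"required": {"image1": ("IMAGE",), "image2": ("IMAGE",)}},
#       "return_types": ["IMAGE"],
#       "return_names": ["IMAGE"],
#   }
#   node_info(sample, "ImageBlend")  # -> dict shaped like an /object_info entry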

async def get_version(node_class: str) -> Optional[str]:
    """Get the latest published version of a node package from the comfy.org registry."""
    try:
        # Create a timeout configuration
        timeout = aiohttp.ClientTimeout(total=30)  # 30 second timeout

        # Versions endpoint for the requested node package
        url = f"https://api.comfy.org/nodes/{node_class}/versions"

        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.get(url) as response:
                if response.status == 200:
                    result = await response.json()
                    version_data = result[0] if result else None
                    return version_data.get("version") if version_data else None
                else:
                    logging.error(f"Failed to get versions for {node_class}: HTTP {response.status}")
                    return None

    except aiohttp.ClientConnectionError as e:
        logging.error(f"Connection error in get_version: {e}")
        return None
    except asyncio.TimeoutError as e:
        logging.error(f"Timeout error in get_version: {e}")
        return None
    except Exception as e:
        logging.error(f"Error in get_version: {e}")
        return None
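
# Note (inferred from the parsing above, not from a documented schema): the
# versions endpoint is expected to return a list of version records, e.g.
# [{"version": "1.0.3", ...}, ...], and get_version returns the "version"
# field of the first entry, or None when nothing is available.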

async def get_object_info_from_comfyapi_by_class(node_class: str) -> Dict[str, Any]:
    """Get object info for a single node class from the comfy.org registry
    (https://api.comfy.org/nodes/{node_class}/versions/{version}/comfy-nodes)."""
    try:
        # Create a timeout configuration
        timeout = aiohttp.ClientTimeout(total=30)  # 30 second timeout

        version = await get_version(node_class)
        if version is None:
            return {}

        # Fetch the comfy-nodes payload for the resolved version
        url = f"https://api.comfy.org/nodes/{node_class}/versions/{version}/comfy-nodes"

        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.get(url) as response:
                if response.status == 200:
                    return await response.json()
                else:
                    logging.error(f"Failed to get object info for {node_class}: HTTP {response.status}")
                    return {}

    except aiohttp.ClientConnectionError as e:
        logging.error(f"Connection error in get_object_info_from_comfyapi_by_class: {e}")
        return {}
    except asyncio.TimeoutError as e:
        logging.error(f"Timeout error in get_object_info_from_comfyapi_by_class: {e}")
        return {}
    except Exception as e:
        logging.error(f"Error in get_object_info_from_comfyapi_by_class: {e}")
        return {}

async def get_object_info_by_missed_class_list(node_class_list: Optional[list[str]] = None) -> Dict[str, Any]:
    """Fetch object info from the comfy.org registry for node classes the local server does not know."""
    result: Dict[str, Any] = {}
    if not node_class_list:
        return result

    # Limit concurrency to avoid excessive outbound requests
    semaphore = asyncio.Semaphore(10)

    async def fetch_and_merge(node_class: str) -> None:
        async with semaphore:
            try:
                data = await get_object_info_from_comfyapi_by_class(node_class)
                if not isinstance(data, dict) or not data:
                    return

                # Prefer the exact key if present
                if node_class in data:
                    result[node_class] = data[node_class]
                    return

                # Otherwise, merge any matching requested keys from the payload
                for k, v in data.items():
                    if k in node_class_list and k not in result:
                        result[k] = v

                # Fallback: keep the whole payload under the class key if still missing
                if node_class not in result:
                    result[node_class] = data
            except Exception as e:
                logging.error(f"Error fetching object info for {node_class}: {e}")

    await asyncio.gather(*(fetch_and_merge(cls) for cls in node_class_list))
    return result

async def get_object_info_by_class_list(node_class_list: list[str] = None, base_url: Optional[str] = None) -> Dict[str, Any]:
"""Get object info for a list of node classes.

Strategy:
- Fetch local object info once and take entries for classes that already exist.
- For missing classes, concurrently fetch via get_object_info_from_comfyapi_by_class.
- Merge results and return a dict keyed by class names.
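
    Example (illustrative; the class names below are just examples):
        info = await get_object_info_by_class_list(["KSampler", "SomeCustomNode"])
        # "KSampler" would come from the local /object_info; "SomeCustomNode" would be
        # fetched from the comfy.org registry only if the local server does not know it.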
"""
# Start with whatever the local Comfy server knows
local_object_info = await get_object_info(base_url) or {}

if not node_class_list:
return local_object_info

result: Dict[str, Any] = {}
for cls in node_class_list:
if cls in local_object_info:
result[cls] = local_object_info[cls]

# Determine which classes we still need
missing_classes = [cls for cls in node_class_list if cls not in result]
log.info(f"get_object_info_by_class_list: node_class_list: {node_class_list}")
log.info(f"get_object_info_by_class_list: missing_classes: {missing_classes}")
if not missing_classes:
return result

result.update(await get_object_info_by_missed_class_list(missing_classes))
return result

async def get_object_info(base_url: Optional[str] = None) -> Dict[str, Any]:
"""Standalone function to get object info - HTTP call to ComfyUI /api/object_info endpoint"""