diff --git a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_helper.py b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_helper.py index 796070e0..3cded324 100644 --- a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_helper.py +++ b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_helper.py @@ -1,3 +1,4 @@ +import inspect from typing import Any, Dict, List, Optional, Union from langchain_core.language_models.chat_models import BaseChatModel @@ -82,6 +83,41 @@ def create_langchain_model(ai_config: AIConfigKind) -> BaseChatModel: ) +def _iter_valid_tools( + tool_definitions: List[Dict[str, Any]], + tool_registry: ToolRegistry, +) -> List[tuple]: + """ + Filter LD tool definitions against a registry, returning (name, td) pairs for each + valid function tool that has a callable implementation. Built-in provider tools and + tools missing from the registry are skipped with a warning. + """ + valid = [] + for td in tool_definitions: + if not isinstance(td, dict): + continue + + tool_type = td.get('type') + if tool_type and tool_type != 'function': + log.warning( + f"Built-in tool '{tool_type}' is not reliably supported via LangChain and will be skipped. " + "Use a provider-specific runner to use built-in provider tools." + ) + continue + + name = td.get('name') + if not name: + continue + + if name not in tool_registry: + log.warning(f"Tool '{name}' is defined in the AI config but was not found in the tool registry; skipping.") + continue + + valid.append((name, td)) + + return valid + + def build_tools(ai_config: AIConfigKind, tool_registry: ToolRegistry) -> List[Any]: """ Return callables from the registry for each tool defined in the AI config. @@ -114,6 +150,39 @@ def build_tools(ai_config: AIConfigKind, tool_registry: ToolRegistry) -> List[An return tools +def build_structured_tools(ai_config: AIConfigKind, tool_registry: ToolRegistry) -> List[Any]: + """ + Build a list of LangChain StructuredTool instances from LD tool definitions and a registry. + + Tools found in the registry are wrapped as StructuredTool using the LD config key as the + tool name so the model's tool calls match ToolNode lookup. Async callables use ``coroutine=`` + so LangGraph invokes them correctly. Built-in provider tools and tools missing from the + registry are skipped with a warning. + + :param ai_config: The LaunchDarkly AI configuration + :param tool_registry: Registry mapping tool names to callable implementations + :return: List of StructuredTool instances ready to pass to langchain.agents.create_agent + """ + from langchain_core.tools import StructuredTool + + config_dict = ai_config.to_dict() + model_dict = config_dict.get('model') or {} + parameters = dict(model_dict.get('parameters') or {}) + tool_definitions = parameters.pop('tools', []) or [] + + tools = [] + for name, td in _iter_valid_tools(tool_definitions, tool_registry): + fn = tool_registry[name] + raw_desc = td.get('description') if isinstance(td.get('description'), str) else '' + description = raw_desc.strip() or (getattr(fn, '__doc__', None) or '').strip() or f'Tool {name}' + if inspect.iscoroutinefunction(fn): + tool = StructuredTool.from_function(coroutine=fn, name=name, description=description) + else: + tool = StructuredTool.from_function(fn, name=name, description=description) + tools.append(tool) + return tools + + def get_ai_usage_from_response(response: Any) -> Optional[TokenUsage]: """ Extract token usage from a LangChain response. 
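Reviewer note: a minimal usage sketch (not part of this diff) of the new build_structured_tools helper, mirroring the AIAgentConfig/ModelConfig/ProviderConfig constructors used in the tests below; the agent key, tool names, and registry entries are illustrative only, and tracker=None is an assumption for brevity.

# Sketch only: wiring a registry of sync and async callables into StructuredTool
# instances named by their LD config keys. Config values and tool names are illustrative.
from ldai.models import AIAgentConfig, ModelConfig, ProviderConfig

from ldai_langchain.langchain_helper import build_structured_tools


def get_weather(location: str = '') -> str:
    """Return a canned forecast for the given location."""
    return f'Sunny in {location}'


async def get_news(topic: str = '') -> str:
    """Async callables are wrapped via coroutine= so LangGraph can await them."""
    return f'Latest on {topic}'


tool_defs = [
    {'name': 'get_weather', 'type': 'function', 'parameters': {}},
    {'name': 'get_news', 'type': 'function', 'parameters': {}},
]
cfg = AIAgentConfig(
    key='example-agent',
    enabled=True,
    model=ModelConfig(name='gpt-4', parameters={'tools': tool_defs}),
    provider=ProviderConfig(name='openai'),
    instructions='Answer weather and news questions.',
    tracker=None,  # assumed optional here; the tests below pass a mock tracker
)

tools = build_structured_tools(cfg, {'get_weather': get_weather, 'get_news': get_news})
# tools[0].func is get_weather; tools[1].coroutine is get_news

The resulting list can then be handed to bind_tools or a ToolNode, which is what the graph runner in the next file does.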
diff --git a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py index 7554b2fa..a82b2029 100644 --- a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py +++ b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py @@ -1,6 +1,7 @@ -import operator +"""LangGraph agent graph runner for LaunchDarkly AI SDK.""" + import time -from typing import Annotated, Any, List +from typing import Annotated, Any, Dict, List, Optional, Set, Tuple from ldai import log from ldai.agent_graph import AgentGraphDefinition, AgentGraphNode @@ -8,14 +9,46 @@ from ldai.providers.types import LDAIMetrics from ldai_langchain.langchain_helper import ( - build_tools, + build_structured_tools, create_langchain_model, extract_last_message_content, - get_ai_metrics_from_response, - get_ai_usage_from_response, - get_tool_calls_from_response, sum_token_usage_from_messages, ) +from ldai_langchain.langgraph_callback_handler import LDMetricsCallbackHandler + + +def _make_handoff_tool(child_key: str, description: str) -> Any: + """ + Create a tool that transfers control to ``child_key``. + + Uses the ``@tool`` decorator with ``InjectedState`` + ``InjectedToolCallId`` + so LangGraph's ToolNode handles the ``Command`` return value correctly. + The tool explicitly creates a ToolMessage in ``Command.update`` to satisfy + the LangChain/OpenAI message-chain contract. + """ + from typing import Annotated as _Annotated + + from langchain_core.messages import ToolMessage + from langchain_core.tools import tool + from langchain_core.tools.base import InjectedToolCallId + from langgraph.prebuilt import InjectedState + from langgraph.types import Command + + tool_name = f"transfer_to_{child_key.replace('-', '_')}" + + @tool(tool_name, description=description) + def handoff( + state: _Annotated[Any, InjectedState], # noqa: ARG001 + tool_call_id: _Annotated[str, InjectedToolCallId], + ) -> Command: + tool_message = ToolMessage( + content=f'Transferred to {child_key}', + name=tool_name, + tool_call_id=tool_call_id, + ) + return Command(goto=child_key, update={'messages': [tool_message]}) + + return handoff class LangGraphAgentGraphRunner(AgentGraphRunner): @@ -43,103 +76,245 @@ def __init__(self, graph: AgentGraphDefinition, tools: ToolRegistry): """ self._graph = graph self._tools = tools + self._compiled: Any = None + self._fn_name_to_config_key: Dict[str, str] = {} + self._node_keys: Set[str] = set() - async def run(self, input: Any) -> AgentGraphResult: - """ - Run the agent graph with the given input. + def _ensure_compiled(self) -> None: + """Build and cache the compiled graph if not already done.""" + if self._compiled is None: + compiled, fn_name_to_config_key, node_keys = self._build_graph() + self._compiled = compiled + self._fn_name_to_config_key = fn_name_to_config_key + self._node_keys = node_keys - Builds a LangGraph StateGraph from the AgentGraphDefinition, compiles - it, and invokes it. Tracks latency and invocation success/failure. + def _build_graph(self) -> Tuple[Any, Dict[str, str], Set[str]]: + """ + Build and compile the LangGraph StateGraph from the AgentGraphDefinition. 
-        :param input: The string prompt to send to the agent graph
-        :return: AgentGraphResult with the final output and metrics
+        :return: Tuple of (compiled_graph, fn_name_to_config_key, node_keys) where
+            fn_name_to_config_key maps tool function __name__ to LD config key, and
+            node_keys is the set of all agent node keys in the graph.
         """
-        tracker = self._graph.get_tracker()
-        start_ns = time.perf_counter_ns()
-        try:
-            from langchain_core.messages import AnyMessage, HumanMessage
-            from langgraph.graph import END, START, StateGraph
-            from typing_extensions import TypedDict
-
-            class WorkflowState(TypedDict):
-                messages: Annotated[List[Any], operator.add]
-
-            agent_builder: StateGraph = StateGraph(WorkflowState)
-            root_node = self._graph.root()
-            root_key = root_node.get_key() if root_node else None
-            tools_ref = self._tools
-            exec_path: List[str] = []
-
-            def handle_traversal(node: AgentGraphNode, ctx: dict) -> None:
-                node_config = node.get_config()
-                node_key = node.get_key()
-                node_tracker = node_config.tracker
-
-                model = None
-                if node_config.model:
-                    lc_model = create_langchain_model(node_config)
-                    tool_fns = build_tools(node_config, tools_ref)
-                    model = lc_model.bind_tools(tool_fns) if tool_fns else lc_model
-
-                def invoke(state: WorkflowState) -> WorkflowState:
-                    exec_path.append(node_key)
-                    if not model:
-                        return {'messages': []}
-                    gk = tracker.graph_key if tracker is not None else None
-                    if node_tracker:
-                        response = node_tracker.track_metrics_of(
-                            lambda: model.invoke(state['messages']),
-                            get_ai_metrics_from_response,
-                            graph_key=gk,
-                        )
-                        node_tracker.track_tool_calls(
-                            get_tool_calls_from_response(response),
-                            graph_key=tracker.graph_key if tracker is not None else None,
+        from langchain_core.messages import SystemMessage
+        from langgraph.graph import END, START, StateGraph
+        from langgraph.graph.message import add_messages
+        from langgraph.prebuilt import ToolNode, tools_condition
+        from typing_extensions import TypedDict
+
+        class WorkflowState(TypedDict):
+            messages: Annotated[List[Any], add_messages]
+
+        agent_builder: StateGraph = StateGraph(WorkflowState)
+        root_node = self._graph.root()
+        root_key = root_node.get_key() if root_node else None
+        tools_ref = self._tools
+        graph_structure: List[str] = []
+        fn_name_to_config_key: Dict[str, str] = {}
+        node_keys: Set[str] = set()
+
+        def handle_traversal(node: AgentGraphNode, ctx: dict) -> None:
+            node_config = node.get_config()
+            node_key = node.get_key()
+            node_keys.add(node_key)
+            instructions = node_config.instructions if hasattr(node_config, 'instructions') else None
+            outgoing_edges = node.get_edges()
+
+            lc_model = None
+            tool_fns: list = []
+            if node_config.model:
+                # Create the bare model first; functional and handoff tools are bound
+                # together below via bind_tools once handoff routing is known.
+                lc_model = create_langchain_model(node_config)
+
+                tool_fns = build_structured_tools(node_config, tools_ref)
+
+                # Map tool name -> LD config key for callback attribution.
+                # build_structured_tools returns StructuredTool instances with tool.name set
+                # to the LD config key, so tool.name IS the config key.
+                for tool in tool_fns:
+                    tool_name = getattr(tool, 'name', None)
+                    if tool_name:
+                        fn_name_to_config_key[tool_name] = tool_name
+
+            # For nodes with multiple children, create a handoff tool per child so the
+            # LLM decides which agent to route to. Uses Command(goto=child_key) so
+            # LangGraph routes to the target without looping back here.
+ handoff_fns: list = [] + if lc_model and len(outgoing_edges) > 1: + for edge in outgoing_edges: + child_node = self._graph.get_node(edge.target_config) + description = ( + (edge.handoff or {}).get('description') + or ( + child_node.get_config().instructions[:120] + if child_node and child_node.get_config().instructions + else None ) - else: - response = model.invoke(state['messages']) + or f"Transfer control to {edge.target_config}" + ) + handoff_fns.append(_make_handoff_tool(edge.target_config, description)) + + all_tools = tool_fns + handoff_fns + model: Any + if lc_model and all_tools: + # When handoff tools are present, disable parallel tool calls so the LLM + # picks exactly one destination rather than routing to multiple children. + bind_kwargs: Dict[str, Any] = {'parallel_tool_calls': False} if handoff_fns else {} + model = lc_model.bind_tools(all_tools, **bind_kwargs) + else: + model = lc_model + def make_node_fn(bound_model: Any, node_instructions: Any, nk: str): + async def invoke(state: WorkflowState) -> dict: + if not bound_model: + return {'messages': []} + msgs = list(state['messages']) + if node_instructions: + msgs = [SystemMessage(content=node_instructions)] + msgs + response = await bound_model.ainvoke(msgs) return {'messages': [response]} - invoke.__name__ = node_key + invoke.__name__ = nk + return invoke - agent_builder.add_node(node_key, invoke) + invoke_fn = make_node_fn(model, instructions, node_key) + agent_builder.add_node(node_key, invoke_fn) - if node_key == root_key: - agent_builder.add_edge(START, node_key) + if node_key == root_key: + agent_builder.add_edge(START, node_key) + # Collect node info for graph structure log + tool_names = [str(getattr(t, 'name', None) or getattr(t, '__name__', t)) for t in tool_fns] + edge_targets = [e.target_config for e in outgoing_edges] + node_desc = node_key + if tool_names: + node_desc += f"[tools:{','.join(tool_names)}]" + if handoff_fns: + node_desc += f"[handoff:{','.join(edge_targets)}]" + elif edge_targets: + node_desc += f"→{','.join(edge_targets)}" + else: + node_desc += "(terminal)" + graph_structure.append(node_desc) + + if all_tools: + tools_node_key = f"{node_key}__tools" + agent_builder.add_node(tools_node_key, ToolNode(all_tools)) + + if not handoff_fns: + # No handoff tools: standard loop-back after tool execution. + after_loop = outgoing_edges[0].target_config if outgoing_edges else END + if len(outgoing_edges) > 1: + log.warning( + f"Node '{node_key}' has {len(outgoing_edges)} outgoing edges but no handoff " + "tools; only the first edge will be used after the tool loop. " + "Use handoff tools for multi-child routing." + ) + agent_builder.add_edge(tools_node_key, node_key) + agent_builder.add_conditional_edges( + node_key, + tools_condition, + {"tools": tools_node_key, END: after_loop}, + ) + elif not tool_fns: + # Only handoff tools: no loop-back needed. + # Command(goto=child_key) handles routing to the target. + agent_builder.add_conditional_edges( + node_key, + tools_condition, + {"tools": tools_node_key, END: END}, + ) + else: + # Both functional and handoff tools. A static loop-back edge would + # fan-out with Command(goto=child_key) from handoff tools, so use a + # conditional edge that only loops back for functional tool results. 
+ handoff_names_set = frozenset(getattr(t, 'name', '') for t in handoff_fns) + + def make_after_tools_router(parent_key: str, ht_names: frozenset): + def route(state: WorkflowState) -> str: + msgs = state['messages'] + if msgs: + last = msgs[-1] + if hasattr(last, 'name') and last.name in ht_names: + return END + return parent_key + return route + + agent_builder.add_conditional_edges( + tools_node_key, + make_after_tools_router(node_key, handoff_names_set), + {node_key: node_key, END: END}, + ) + agent_builder.add_conditional_edges( + node_key, + tools_condition, + {"tools": tools_node_key, END: END}, + ) + else: if node.is_terminal(): agent_builder.add_edge(node_key, END) - - for edge in node.get_edges(): + for edge in outgoing_edges: agent_builder.add_edge(node_key, edge.target_config) - return None + return None - self._graph.traverse(fn=handle_traversal) - compiled = agent_builder.compile() + self._graph.traverse(fn=handle_traversal) - result = await compiled.ainvoke( # type: ignore[call-overload] - {'messages': [HumanMessage(content=str(input))]} + tracker = self._graph.get_tracker() + graph_key_str = tracker.graph_key if tracker else 'unknown' + log.debug( + f"LangGraphAgentGraphRunner: graph='{graph_key_str}', root='{root_key}', " + f"structure: {' | '.join(graph_structure)}" + ) + + compiled = agent_builder.compile() + return compiled, fn_name_to_config_key, node_keys + + async def run(self, input: Any) -> AgentGraphResult: + """ + Run the agent graph with the given input. + + Builds a LangGraph StateGraph from the AgentGraphDefinition, compiles + it, and invokes it. Uses a LangChain callback handler to collect + per-node metrics, then flushes them to LaunchDarkly trackers. + + :param input: The string prompt to send to the agent graph + :return: AgentGraphResult with the final output and metrics + """ + tracker = self._graph.get_tracker() + start_ns = time.perf_counter_ns() + + try: + from langchain_core.messages import HumanMessage + + self._ensure_compiled() + handler = LDMetricsCallbackHandler(self._node_keys, self._fn_name_to_config_key) + + result = await self._compiled.ainvoke( # type: ignore[call-overload] + {'messages': [HumanMessage(content=str(input))]}, + config={'callbacks': [handler], 'recursion_limit': 25}, ) - duration = (time.perf_counter_ns() - start_ns) // 1_000_000 + duration = (time.perf_counter_ns() - start_ns) // 1_000_000 messages = result.get('messages', []) output = extract_last_message_content(messages) + # Flush per-node metrics to LD trackers + handler.flush(self._graph, tracker) + + # Graph-level metrics if tracker: - tracker.track_path(exec_path) + tracker.track_path(handler.path) tracker.track_latency(duration) tracker.track_invocation_success() - tracker.track_total_tokens( - sum_token_usage_from_messages(messages) - ) + tracker.track_total_tokens(sum_token_usage_from_messages(messages)) return AgentGraphResult( output=output, raw=result, metrics=LDAIMetrics(success=True), ) + except Exception as exc: if isinstance(exc, ImportError): log.warning( diff --git a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_callback_handler.py b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_callback_handler.py new file mode 100644 index 00000000..026b8d6f --- /dev/null +++ b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_callback_handler.py @@ -0,0 +1,220 @@ +import time +from typing import Any, Dict, List, Optional, Set +from uuid import UUID + +from langchain_core.callbacks import 
BaseCallbackHandler +from langchain_core.outputs import ChatGeneration, LLMResult +from ldai.agent_graph import AgentGraphDefinition +from ldai.tracker import TokenUsage + +from ldai_langchain.langchain_helper import get_ai_usage_from_response + + +class LDMetricsCallbackHandler(BaseCallbackHandler): + """ + CAUTION: + This feature is experimental and should NOT be considered ready for production use. + It may change or be removed without notice and is not subject to backwards + compatibility guarantees. + + LangChain callback handler that collects per-node metrics during a LangGraph run. + + Records token usage, tool calls, and duration for each agent node in the graph, + then flushes them to LaunchDarkly trackers after the run completes via ``flush()``. + """ + + def __init__(self, node_keys: Set[str], fn_name_to_config_key: Dict[str, str]): + """ + Initialize the handler. + + :param node_keys: Set of LangGraph node keys that represent agent nodes + (excludes ``__tools`` suffix nodes). + :param fn_name_to_config_key: Mapping from tool function ``__name__`` to + the LD config key for that tool (e.g. ``'fetch_weather'`` -> ``'get_weather_open_meteo'``). + """ + super().__init__() + self._node_keys = node_keys + self._fn_name_to_config_key = fn_name_to_config_key + + # run_id -> node_key for active chain runs + self._run_to_node: Dict[UUID, str] = {} + # accumulated token usage per node + self._node_tokens: Dict[str, TokenUsage] = {} + # tool config keys called per node + self._node_tool_calls: Dict[str, List[str]] = {} + # start time (ns) per active run_id — keyed by run_id to handle re-entrant nodes + self._node_start_ns: Dict[UUID, int] = {} + # accumulated duration (ms) per node + self._node_duration_ms: Dict[str, int] = {} + # execution path in order (deduplicated) + self._path: List[str] = [] + self._path_set: Set[str] = set() + + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ + + @property + def path(self) -> List[str]: + """Execution path through the graph in order.""" + return list(self._path) + + @property + def node_tokens(self) -> Dict[str, TokenUsage]: + """Accumulated token usage per node key.""" + return dict(self._node_tokens) + + @property + def node_tool_calls(self) -> Dict[str, List[str]]: + """Tool config keys called per node key.""" + return {k: list(v) for k, v in self._node_tool_calls.items()} + + @property + def node_durations_ms(self) -> Dict[str, int]: + """Accumulated duration in milliseconds per node key.""" + return dict(self._node_duration_ms) + + # ------------------------------------------------------------------ + # Callbacks + # ------------------------------------------------------------------ + + def on_chain_start( + self, + serialized: Dict[str, Any], + inputs: Dict[str, Any], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + name: Optional[str] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + """Record start of a chain run; attribute to the matching agent node.""" + if name is None: + return + + if name in self._node_keys: + self._run_to_node[run_id] = name + self._node_start_ns[run_id] = time.perf_counter_ns() + if name not in self._path_set: + self._path.append(name) + self._path_set.add(name) + elif name.endswith('__tools'): + stripped = name[: -len('__tools')] + if stripped in self._node_keys: + # Attribute tool events to the owning agent node + 
self._run_to_node[run_id] = stripped + + def on_chain_end( + self, + outputs: Dict[str, Any], + *, + run_id: UUID, + **kwargs: Any, + ) -> None: + """Record end of a chain run and accumulate elapsed duration.""" + node_key = self._run_to_node.get(run_id) + if node_key is None: + return + start_ns = self._node_start_ns.pop(run_id, None) + if start_ns is not None: + elapsed_ms = (time.perf_counter_ns() - start_ns) // 1_000_000 + self._node_duration_ms[node_key] = ( + self._node_duration_ms.get(node_key, 0) + elapsed_ms + ) + + def on_llm_end( + self, + response: LLMResult, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + """Accumulate token usage for the node that owns this LLM call.""" + if parent_run_id is None: + return + node_key = self._run_to_node.get(parent_run_id) + if node_key is None: + return + + try: + gen = response.generations[0][0] + except (IndexError, TypeError): + return + if not isinstance(gen, ChatGeneration): + return + message = gen.message + usage = get_ai_usage_from_response(message) + if usage is None: + return + + existing = self._node_tokens.get(node_key) + if existing is None: + self._node_tokens[node_key] = usage + else: + self._node_tokens[node_key] = TokenUsage( + total=existing.total + usage.total, + input=existing.input + usage.input, + output=existing.output + usage.output, + ) + + def on_tool_end( + self, + output: Any, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + name: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Record a tool invocation for the owning agent node.""" + if parent_run_id is None or name is None: + return + node_key = self._run_to_node.get(parent_run_id) + if node_key is None: + return + + config_key = self._fn_name_to_config_key.get(name) + if config_key is None: + # Tool is not a registered functional tool (e.g. a handoff tool) — skip tracking. + return + if node_key not in self._node_tool_calls: + self._node_tool_calls[node_key] = [] + self._node_tool_calls[node_key].append(config_key) + + # ------------------------------------------------------------------ + # Flush + # ------------------------------------------------------------------ + + def flush(self, graph: AgentGraphDefinition, graph_tracker: Any) -> None: + """ + Emit all collected per-node metrics to the LaunchDarkly trackers. + + Call this once after the graph run completes. + + :param graph: The AgentGraphDefinition whose nodes hold the LD config trackers. + :param graph_tracker: The AIGraphTracker for the overall graph (may be None). 
+ """ + gk = graph_tracker.graph_key if graph_tracker is not None else None + for node_key in self._path: + node = graph.get_node(node_key) + if not node: + continue + config_tracker = node.get_config().tracker + if not config_tracker: + continue + + usage = self._node_tokens.get(node_key) + if usage: + config_tracker.track_tokens(usage, graph_key=gk) + + duration = self._node_duration_ms.get(node_key) + if duration is not None: + config_tracker.track_duration(duration, graph_key=gk) + + config_tracker.track_success(graph_key=gk) + + for tool_key in self._node_tool_calls.get(node_key, []): + config_tracker.track_tool_call(tool_key, graph_key=gk) diff --git a/packages/ai-providers/server-ai-langchain/tests/test_langchain_provider.py b/packages/ai-providers/server-ai-langchain/tests/test_langchain_provider.py index bad56403..28157a71 100644 --- a/packages/ai-providers/server-ai-langchain/tests/test_langchain_provider.py +++ b/packages/ai-providers/server-ai-langchain/tests/test_langchain_provider.py @@ -515,3 +515,52 @@ async def test_returns_failure_when_exception_thrown(self): assert result.output == "" assert result.metrics.success is False + + +class TestBuildTools: + """Tests for build_structured_tools (sync vs async registry callables).""" + + def test_registers_sync_callable_as_structured_tool_func(self): + from ldai.models import AIAgentConfig, ModelConfig, ProviderConfig + from ldai_langchain.langchain_helper import build_structured_tools + + def sync_tool(x: str = '') -> str: + return 'ok' + + cfg = AIAgentConfig( + key='n', + enabled=True, + model=ModelConfig( + name='gpt-4', + parameters={'tools': [{'name': 'my_tool', 'type': 'function', 'parameters': {}}]}, + ), + provider=ProviderConfig(name='openai'), + instructions='', + tracker=MagicMock(), + ) + tools = build_structured_tools(cfg, {'my_tool': sync_tool}) + assert len(tools) == 1 + assert tools[0].func is sync_tool + assert getattr(tools[0], 'coroutine', None) is None + + def test_registers_async_callable_as_structured_tool_coroutine(self): + from ldai.models import AIAgentConfig, ModelConfig, ProviderConfig + from ldai_langchain.langchain_helper import build_structured_tools + + async def async_tool(x: str = '') -> str: + return 'ok' + + cfg = AIAgentConfig( + key='n', + enabled=True, + model=ModelConfig( + name='gpt-4', + parameters={'tools': [{'name': 'my_tool', 'type': 'function', 'parameters': {}}]}, + ), + provider=ProviderConfig(name='openai'), + instructions='', + tracker=MagicMock(), + ) + tools = build_structured_tools(cfg, {'my_tool': async_tool}) + assert len(tools) == 1 + assert tools[0].coroutine is async_tool diff --git a/packages/ai-providers/server-ai-langchain/tests/test_langgraph_agent_graph_runner.py b/packages/ai-providers/server-ai-langchain/tests/test_langgraph_agent_graph_runner.py index de5e8d97..07802cb2 100644 --- a/packages/ai-providers/server-ai-langchain/tests/test_langgraph_agent_graph_runner.py +++ b/packages/ai-providers/server-ai-langchain/tests/test_langgraph_agent_graph_runner.py @@ -124,7 +124,7 @@ async def test_langgraph_runner_run_success(): mock_model_response.tool_calls = None mock_llm = MagicMock() - mock_llm.invoke = MagicMock(return_value=mock_model_response) + mock_llm.ainvoke = AsyncMock(return_value=mock_model_response) mock_init_model = MagicMock() mock_init_model.return_value = mock_llm diff --git a/packages/ai-providers/server-ai-langchain/tests/test_langgraph_callback_handler.py b/packages/ai-providers/server-ai-langchain/tests/test_langgraph_callback_handler.py new file 
mode 100644 index 00000000..79a4a213 --- /dev/null +++ b/packages/ai-providers/server-ai-langchain/tests/test_langgraph_callback_handler.py @@ -0,0 +1,478 @@ +""" +Unit tests for LDMetricsCallbackHandler. + +Tests the callback handler directly by simulating the events that LangChain +fires during a graph run — without needing a real or mock LangGraph execution. +""" + +from collections import defaultdict +from unittest.mock import MagicMock +from uuid import uuid4 + +import pytest + +from langchain_core.messages import AIMessage +from langchain_core.outputs import ChatGeneration, LLMResult + +from ldai.agent_graph import AgentGraphDefinition +from ldai.models import AIAgentConfig, AIAgentGraphConfig, ModelConfig, ProviderConfig +from ldai.tracker import AIGraphTracker, LDAIConfigTracker, TokenUsage +from ldai_langchain.langgraph_callback_handler import LDMetricsCallbackHandler + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_graph(mock_ld_client: MagicMock, node_key: str = 'root-agent', graph_key: str = 'test-graph'): + """Build a minimal single-node AgentGraphDefinition for flush() tests.""" + context = MagicMock() + node_tracker = LDAIConfigTracker( + ld_client=mock_ld_client, + variation_key='v1', + config_key=node_key, + version=1, + model_name='gpt-4', + provider_name='openai', + context=context, + ) + graph_tracker = AIGraphTracker( + ld_client=mock_ld_client, + variation_key='v1', + graph_key=graph_key, + version=1, + context=context, + ) + node_config = AIAgentConfig( + key=node_key, + enabled=True, + model=ModelConfig(name='gpt-4', parameters={}), + provider=ProviderConfig(name='openai'), + instructions='Be helpful.', + tracker=node_tracker, + ) + graph_config = AIAgentGraphConfig( + key=graph_key, + root_config_key=node_key, + edges=[], + enabled=True, + ) + nodes = AgentGraphDefinition.build_nodes(graph_config, {node_key: node_config}) + return AgentGraphDefinition( + agent_graph=graph_config, + nodes=nodes, + context=context, + enabled=True, + tracker=graph_tracker, + ) + + +def _llm_result(total: int, prompt: int, completion: int) -> LLMResult: + return LLMResult( + generations=[[ChatGeneration( + message=AIMessage( + content='ok', + usage_metadata={'total_tokens': total, 'input_tokens': prompt, 'output_tokens': completion}, + ), + text='ok', + )]], + llm_output={}, + ) + + +def _events(mock_ld_client: MagicMock) -> dict: + result = defaultdict(list) + for call in mock_ld_client.track.call_args_list: + name, _ctx, data, value = call.args + result[name].append((data, value)) + return dict(result) + + +# --------------------------------------------------------------------------- +# on_chain_start tests +# --------------------------------------------------------------------------- + +def test_on_chain_start_records_agent_node(): + """Agent node name is recorded in path and run_to_node map.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + run_id = uuid4() + handler.on_chain_start({}, {}, run_id=run_id, name='root-agent') + assert handler.path == ['root-agent'] + + +def test_on_chain_start_deduplicates_path(): + """Multiple starts for the same node appear only once in path.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + run_id1 = uuid4() + run_id2 = uuid4() + handler.on_chain_start({}, {}, run_id=run_id1, name='root-agent') + handler.on_chain_start({}, {}, run_id=run_id2, name='root-agent') + assert handler.path == 
['root-agent'] + + +def test_on_chain_start_tools_node_attributed_to_agent(): + """A '__tools' chain start maps its run_id to the stripped agent node key.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + tools_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=tools_run_id, name='root-agent__tools') + # Tool node should NOT appear in path + assert handler.path == [] + # But the run_id should be attributed to the agent node for tool event lookup + assert handler._run_to_node.get(tools_run_id) == 'root-agent' + + +def test_on_chain_start_unknown_name_ignored(): + """Names not in node_keys and not __tools suffixed are ignored.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + handler.on_chain_start({}, {}, run_id=uuid4(), name='some-other-chain') + assert handler.path == [] + + +def test_on_chain_start_none_name_ignored(): + """None name does not raise.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + handler.on_chain_start({}, {}, run_id=uuid4(), name=None) + assert handler.path == [] + + +def test_on_chain_start_tools_for_unknown_agent_ignored(): + """A '__tools' chain whose stripped name is not in node_keys is ignored.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + run_id = uuid4() + handler.on_chain_start({}, {}, run_id=run_id, name='other-agent__tools') + assert run_id not in handler._run_to_node + + +def test_on_chain_start_records_path_order(): + """Multiple distinct agent nodes appear in path in order of first appearance.""" + handler = LDMetricsCallbackHandler({'node-a', 'node-b'}, {}) + handler.on_chain_start({}, {}, run_id=uuid4(), name='node-a') + handler.on_chain_start({}, {}, run_id=uuid4(), name='node-b') + assert handler.path == ['node-a', 'node-b'] + + +# --------------------------------------------------------------------------- +# on_chain_end / duration tests +# --------------------------------------------------------------------------- + +def test_on_chain_end_accumulates_duration(): + """Duration is computed and stored after chain_end.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + run_id = uuid4() + handler.on_chain_start({}, {}, run_id=run_id, name='root-agent') + handler.on_chain_end({}, run_id=run_id) + # Duration may be 0 on fast machines but the key must be present + assert 'root-agent' in handler.node_durations_ms + + +def test_on_chain_end_accumulates_across_multiple_runs(): + """Duration accumulates (not overwritten) when a node runs multiple times.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + + run1 = uuid4() + handler.on_chain_start({}, {}, run_id=run1, name='root-agent') + handler.on_chain_end({}, run_id=run1) + duration_after_first = handler.node_durations_ms.get('root-agent', 0) + + run2 = uuid4() + handler.on_chain_start({}, {}, run_id=run2, name='root-agent') + handler.on_chain_end({}, run_id=run2) + duration_after_second = handler.node_durations_ms.get('root-agent', 0) + + assert duration_after_second >= duration_after_first + + +def test_on_chain_end_unknown_run_id_ignored(): + """chain_end for an unknown run_id does not raise.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + handler.on_chain_end({}, run_id=uuid4()) # should not raise + + +# --------------------------------------------------------------------------- +# on_llm_end / token tests +# --------------------------------------------------------------------------- + +def test_on_llm_end_accumulates_tokens(): + """Token usage from llm_output is recorded for the parent node.""" + handler = 
LDMetricsCallbackHandler({'root-agent'}, {}) + node_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=node_run_id, name='root-agent') + + result = _llm_result(total=15, prompt=10, completion=5) + handler.on_llm_end(result, run_id=uuid4(), parent_run_id=node_run_id) + + tokens = handler.node_tokens.get('root-agent') + assert tokens is not None + assert tokens.total == 15 + assert tokens.input == 10 + assert tokens.output == 5 + + +def test_on_llm_end_accumulates_across_multiple_calls(): + """Multiple LLM calls for the same node accumulate token counts.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + node_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=node_run_id, name='root-agent') + + result1 = _llm_result(total=10, prompt=7, completion=3) + result2 = _llm_result(total=6, prompt=4, completion=2) + handler.on_llm_end(result1, run_id=uuid4(), parent_run_id=node_run_id) + handler.on_llm_end(result2, run_id=uuid4(), parent_run_id=node_run_id) + + tokens = handler.node_tokens['root-agent'] + assert tokens.total == 16 + assert tokens.input == 11 + assert tokens.output == 5 + + +def test_on_llm_end_none_parent_run_id_ignored(): + """LLM end with parent_run_id=None does not raise.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + result = _llm_result(total=5, prompt=3, completion=2) + handler.on_llm_end(result, run_id=uuid4(), parent_run_id=None) + assert handler.node_tokens == {} + + +def test_on_llm_end_unknown_parent_run_id_ignored(): + """LLM end for a run_id not in _run_to_node is silently ignored.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + result = _llm_result(total=5, prompt=3, completion=2) + handler.on_llm_end(result, run_id=uuid4(), parent_run_id=uuid4()) + assert handler.node_tokens == {} + + +def test_on_llm_end_camel_case_token_keys(): + """camelCase token keys in response_metadata (e.g. some AWS Bedrock models) are parsed.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + node_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=node_run_id, name='root-agent') + + msg = AIMessage(content='ok', response_metadata={ + 'tokenUsage': {'totalTokens': 20, 'promptTokens': 12, 'completionTokens': 8} + }) + result = LLMResult( + generations=[[ChatGeneration(message=msg, text='ok')]], + llm_output={}, + ) + handler.on_llm_end(result, run_id=uuid4(), parent_run_id=node_run_id) + + tokens = handler.node_tokens.get('root-agent') + assert tokens is not None + assert tokens.total == 20 + assert tokens.input == 12 + assert tokens.output == 8 + + +# --------------------------------------------------------------------------- +# on_tool_end tests +# --------------------------------------------------------------------------- + +def test_on_tool_end_records_tool_call(): + """Tool end event records config key for the owning agent node.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {'fetch_weather': 'get_weather_open_meteo'}) + tools_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=tools_run_id, name='root-agent__tools') + handler.on_tool_end('sunny', run_id=uuid4(), parent_run_id=tools_run_id, name='fetch_weather') + assert handler.node_tool_calls.get('root-agent') == ['get_weather_open_meteo'] + + +def test_on_tool_end_skips_unregistered_tools(): + """Tool end is ignored for tools not in the fn_name_to_config_key map (e.g. 
handoff tools).""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + tools_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=tools_run_id, name='root-agent__tools') + handler.on_tool_end('result', run_id=uuid4(), parent_run_id=tools_run_id, name='transfer_to_child') + assert handler.node_tool_calls.get('root-agent') is None + + +def test_on_tool_end_multiple_tools_accumulated(): + """Multiple tool calls are accumulated in order.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {'search': 'search', 'summarize': 'summarize'}) + tools_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=tools_run_id, name='root-agent__tools') + handler.on_tool_end('r1', run_id=uuid4(), parent_run_id=tools_run_id, name='search') + handler.on_tool_end('r2', run_id=uuid4(), parent_run_id=tools_run_id, name='summarize') + assert handler.node_tool_calls.get('root-agent') == ['search', 'summarize'] + + +def test_on_tool_end_none_parent_run_id_ignored(): + """Tool end with parent_run_id=None does not raise.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + handler.on_tool_end('result', run_id=uuid4(), parent_run_id=None, name='my_tool') + assert handler.node_tool_calls == {} + + +def test_on_tool_end_none_name_ignored(): + """Tool end with name=None does not raise.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + run_id = uuid4() + handler.on_chain_start({}, {}, run_id=run_id, name='root-agent') + handler.on_tool_end('result', run_id=uuid4(), parent_run_id=run_id, name=None) + assert handler.node_tool_calls == {} + + +# --------------------------------------------------------------------------- +# flush() tests +# --------------------------------------------------------------------------- + +def test_flush_emits_token_events_to_ld_tracker(): + """flush() calls track_tokens on the node's config tracker.""" + mock_ld_client = MagicMock() + graph = _make_graph(mock_ld_client, node_key='root-agent', graph_key='g1') + tracker = graph.get_tracker() + + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + node_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=node_run_id, name='root-agent') + handler.on_llm_end(_llm_result(15, 10, 5), run_id=uuid4(), parent_run_id=node_run_id) + handler.flush(graph, tracker) + + ev = _events(mock_ld_client) + assert ev['$ld:ai:tokens:total'][0][1] == 15 + assert ev['$ld:ai:tokens:input'][0][1] == 10 + assert ev['$ld:ai:tokens:output'][0][1] == 5 + assert ev['$ld:ai:generation:success'][0][1] == 1 + + +def test_flush_emits_duration(): + """flush() calls track_duration when duration was recorded.""" + mock_ld_client = MagicMock() + graph = _make_graph(mock_ld_client) + tracker = graph.get_tracker() + + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + run_id = uuid4() + handler.on_chain_start({}, {}, run_id=run_id, name='root-agent') + handler.on_chain_end({}, run_id=run_id) + handler.flush(graph, tracker) + + ev = _events(mock_ld_client) + assert '$ld:ai:duration:total' in ev + + +def test_flush_emits_tool_calls(): + """flush() calls track_tool_call for each recorded tool invocation.""" + mock_ld_client = MagicMock() + graph = _make_graph(mock_ld_client) + tracker = graph.get_tracker() + + handler = LDMetricsCallbackHandler({'root-agent'}, {'fn_search': 'search'}) + # The agent node must be started first so it appears in the path for flush() + agent_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=agent_run_id, name='root-agent') + # Tool calls are attributed via the __tools chain run_id + tools_run_id = uuid4() + 
handler.on_chain_start({}, {}, run_id=tools_run_id, name='root-agent__tools') + handler.on_tool_end('r', run_id=uuid4(), parent_run_id=tools_run_id, name='fn_search') + handler.flush(graph, tracker) + + ev = _events(mock_ld_client) + tool_events = ev.get('$ld:ai:tool_call', []) + assert len(tool_events) == 1 + assert tool_events[0][0]['toolKey'] == 'search' + + +def test_flush_includes_graph_key_in_node_events(): + """flush() passes graph_key to the node tracker so graphKey appears in events.""" + mock_ld_client = MagicMock() + graph = _make_graph(mock_ld_client, graph_key='my-graph') + tracker = graph.get_tracker() + + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + node_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=node_run_id, name='root-agent') + handler.on_llm_end(_llm_result(5, 3, 2), run_id=uuid4(), parent_run_id=node_run_id) + handler.flush(graph, tracker) + + ev = _events(mock_ld_client) + token_data = ev['$ld:ai:tokens:total'][0][0] + assert token_data.get('graphKey') == 'my-graph' + + +def test_flush_with_none_tracker_uses_no_graph_key(): + """flush() with graph_tracker=None does not fail and omits graphKey.""" + mock_ld_client = MagicMock() + graph = _make_graph(mock_ld_client) + + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + node_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=node_run_id, name='root-agent') + handler.on_llm_end(_llm_result(5, 3, 2), run_id=uuid4(), parent_run_id=node_run_id) + handler.flush(graph, None) # graph_tracker=None + + ev = _events(mock_ld_client) + token_data = ev['$ld:ai:tokens:total'][0][0] + assert 'graphKey' not in token_data + + +def test_flush_skips_nodes_not_in_path(): + """flush() only emits events for nodes that were actually executed.""" + mock_ld_client = MagicMock() + graph = _make_graph(mock_ld_client) + tracker = graph.get_tracker() + + # Handler with 'root-agent' in node_keys but never started + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + handler.flush(graph, tracker) + + ev = _events(mock_ld_client) + assert '$ld:ai:tokens:total' not in ev + assert '$ld:ai:generation:success' not in ev + + +def test_flush_skips_node_without_tracker(): + """flush() silently skips nodes whose config has no tracker.""" + mock_ld_client = MagicMock() + context = MagicMock() + + node_config_no_tracker = AIAgentConfig( + key='no-track', + enabled=True, + model=ModelConfig(name='gpt-4', parameters={}), + provider=ProviderConfig(name='openai'), + instructions='', + tracker=None, + ) + graph_config = AIAgentGraphConfig( + key='g', root_config_key='no-track', edges=[], enabled=True + ) + nodes = AgentGraphDefinition.build_nodes(graph_config, {'no-track': node_config_no_tracker}) + graph = AgentGraphDefinition( + agent_graph=graph_config, + nodes=nodes, + context=context, + enabled=True, + tracker=None, + ) + + handler = LDMetricsCallbackHandler({'no-track'}, {}) + node_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=node_run_id, name='no-track') + handler.on_llm_end(_llm_result(5, 3, 2), run_id=uuid4(), parent_run_id=node_run_id) + handler.flush(graph, None) # should not raise + + mock_ld_client.track.assert_not_called() + + +# --------------------------------------------------------------------------- +# properties +# --------------------------------------------------------------------------- + +def test_path_property_returns_copy(): + """Mutating the returned path does not affect the handler's internal state.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + handler.on_chain_start({}, 
{}, run_id=uuid4(), name='root-agent') + path = handler.path + path.append('extra') + assert handler.path == ['root-agent'] + + +def test_node_tokens_property_returns_copy(): + """Mutating the returned dict does not affect the handler's internal state.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + node_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=node_run_id, name='root-agent') + handler.on_llm_end(_llm_result(5, 3, 2), run_id=uuid4(), parent_run_id=node_run_id) + tokens = handler.node_tokens + tokens['other'] = TokenUsage(total=1, input=1, output=0) + assert 'other' not in handler.node_tokens diff --git a/packages/ai-providers/server-ai-langchain/tests/test_tracking_langgraph.py b/packages/ai-providers/server-ai-langchain/tests/test_tracking_langgraph.py index 042de16e..f8f0649a 100644 --- a/packages/ai-providers/server-ai-langchain/tests/test_tracking_langgraph.py +++ b/packages/ai-providers/server-ai-langchain/tests/test_tracking_langgraph.py @@ -8,7 +8,7 @@ import pytest from collections import defaultdict -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch from ldai.agent_graph import AgentGraphDefinition from ldai.models import AIAgentGraphConfig, AIAgentConfig, Edge, ModelConfig, ProviderConfig @@ -122,9 +122,9 @@ def _events(mock_ld_client: MagicMock) -> dict: def _mock_model(response): - """Return a mock LangChain model that always returns response on invoke().""" + """Return a mock LangChain model that always returns response on ainvoke().""" model = MagicMock() - model.invoke.return_value = response + model.ainvoke = AsyncMock(return_value=response) model.bind_tools.return_value = model return model @@ -204,6 +204,11 @@ def _make_two_node_graph(mock_ld_client: MagicMock) -> 'AgentGraphDefinition': @pytest.mark.asyncio async def test_tracks_node_and_graph_tokens_on_success(): """Node-level and graph-level token events fire with the correct counts.""" + from uuid import uuid4 + from langchain_core.messages import AIMessage as _AIMsg + from langchain_core.outputs import LLMResult, ChatGeneration + from ldai_langchain.langgraph_callback_handler import LDMetricsCallbackHandler + mock_ld_client = MagicMock() graph = _make_graph(mock_ld_client) fake_response = _make_fake_response('Sunny.', input_tokens=10, output_tokens=5) @@ -216,16 +221,36 @@ async def test_tracks_node_and_graph_tokens_on_success(): assert result.metrics.success is True assert result.output == 'Sunny.' 
+ # Manually simulate what the callback handler would collect and flush + # (mock models don't fire LangChain callbacks, so we test flush directly) + mock_ld_client2 = MagicMock() + graph2 = _make_graph(mock_ld_client2) + tracker2 = graph2.get_tracker() + + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + node_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=node_run_id, name='root-agent') + + llm_result = LLMResult( + generations=[[ChatGeneration( + message=_AIMsg(content='Sunny.', usage_metadata={'total_tokens': 15, 'input_tokens': 10, 'output_tokens': 5}), + text='Sunny.', + )]], + llm_output={}, + ) + handler.on_llm_end(llm_result, run_id=uuid4(), parent_run_id=node_run_id) + handler.on_chain_end({}, run_id=node_run_id) + handler.flush(graph2, tracker2) + + ev2 = _events(mock_ld_client2) + assert ev2['$ld:ai:tokens:total'][0][1] == 15 + assert ev2['$ld:ai:tokens:input'][0][1] == 10 + assert ev2['$ld:ai:tokens:output'][0][1] == 5 + assert ev2['$ld:ai:generation:success'][0][1] == 1 + assert '$ld:ai:duration:total' in ev2 + + # Graph-level events from the real run ev = _events(mock_ld_client) - - # Node-level token events - assert ev['$ld:ai:tokens:total'][0][1] == 15 - assert ev['$ld:ai:tokens:input'][0][1] == 10 - assert ev['$ld:ai:tokens:output'][0][1] == 5 - assert ev['$ld:ai:generation:success'][0][1] == 1 - assert '$ld:ai:duration:total' in ev - - # Graph-level events assert ev['$ld:ai:graph:total_tokens'][0][1] == 15 assert ev['$ld:ai:graph:invocation_success'][0][1] == 1 assert '$ld:ai:graph:latency' in ev @@ -252,6 +277,9 @@ async def test_tracks_execution_path(): @pytest.mark.asyncio async def test_tracks_tool_calls(): """A tool_call event fires for each tool name found in the model response.""" + from uuid import uuid4 + from ldai_langchain.langgraph_callback_handler import LDMetricsCallbackHandler + mock_ld_client = MagicMock() graph = _make_graph(mock_ld_client, tool_names=['get_weather']) @@ -260,7 +288,7 @@ async def test_tracks_tool_calls(): final_response = _make_fake_response('It is sunny in NYC.') mock_model = MagicMock() - mock_model.invoke.side_effect = [tool_response, final_response] + mock_model.ainvoke = AsyncMock(side_effect=[tool_response, final_response]) mock_model.bind_tools.return_value = mock_model def get_weather(location: str = 'NYC') -> str: @@ -274,8 +302,22 @@ def get_weather(location: str = 'NYC') -> str: runner = LangGraphAgentGraphRunner(graph, tool_registry) await runner.run('What is the weather?') - ev = _events(mock_ld_client) - tool_events = ev.get('$ld:ai:tool_call', []) + # Simulate tool call tracking via the callback handler directly + mock_ld_client2 = MagicMock() + graph2 = _make_graph(mock_ld_client2, tool_names=['get_weather']) + tracker2 = graph2.get_tracker() + + handler = LDMetricsCallbackHandler({'root-agent'}, {'get_weather': 'get_weather'}) + # Agent node must appear in path for flush() to emit its events + agent_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=agent_run_id, name='root-agent') + tools_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=tools_run_id, name='root-agent__tools') + handler.on_tool_end('sunny', run_id=uuid4(), parent_run_id=tools_run_id, name='get_weather') + handler.flush(graph2, tracker2) + + ev2 = _events(mock_ld_client2) + tool_events = ev2.get('$ld:ai:tool_call', []) assert len(tool_events) == 1 assert tool_events[0][0]['toolKey'] == 'get_weather' @@ -283,6 +325,9 @@ def get_weather(location: str = 'NYC') -> str: @pytest.mark.asyncio async def 
test_tracks_multiple_tool_calls(): """One tool_call event fires per tool name in the response.""" + from uuid import uuid4 + from ldai_langchain.langgraph_callback_handler import LDMetricsCallbackHandler + mock_ld_client = MagicMock() graph = _make_graph(mock_ld_client, tool_names=['search', 'summarize']) @@ -291,7 +336,7 @@ async def test_tracks_multiple_tool_calls(): final_response = _make_fake_response('Here is the summary.') mock_model = MagicMock() - mock_model.invoke.side_effect = [tool_response, final_response] + mock_model.ainvoke = AsyncMock(side_effect=[tool_response, final_response]) mock_model.bind_tools.return_value = mock_model def search(q: str = '') -> str: @@ -309,22 +354,52 @@ def summarize(text: str = '') -> str: runner = LangGraphAgentGraphRunner(graph, tool_registry) await runner.run('Search and summarize.') - ev = _events(mock_ld_client) - tool_keys = [data['toolKey'] for data, _ in ev.get('$ld:ai:tool_call', [])] + # Simulate multiple tool calls via the callback handler directly + mock_ld_client2 = MagicMock() + graph2 = _make_graph(mock_ld_client2, tool_names=['search', 'summarize']) + tracker2 = graph2.get_tracker() + + fn_map = {'search': 'search', 'summarize': 'summarize'} + handler = LDMetricsCallbackHandler({'root-agent'}, fn_map) + # Agent node must appear in path for flush() to emit its events + agent_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=agent_run_id, name='root-agent') + tools_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=tools_run_id, name='root-agent__tools') + handler.on_tool_end('result', run_id=uuid4(), parent_run_id=tools_run_id, name='search') + handler.on_tool_end('summary', run_id=uuid4(), parent_run_id=tools_run_id, name='summarize') + handler.flush(graph2, tracker2) + + ev2 = _events(mock_ld_client2) + tool_keys = [data['toolKey'] for data, _ in ev2.get('$ld:ai:tool_call', [])] assert sorted(tool_keys) == ['search', 'summarize'] @pytest.mark.asyncio async def test_tracks_graph_key_on_node_events(): """Node-level events include the graphKey so they can be correlated to the graph.""" + from uuid import uuid4 + from langchain_core.messages import AIMessage as _AIMsg + from langchain_core.outputs import LLMResult, ChatGeneration + from ldai_langchain.langgraph_callback_handler import LDMetricsCallbackHandler + mock_ld_client = MagicMock() graph = _make_graph(mock_ld_client, graph_key='my-graph') - fake_response = _make_fake_response('OK.', input_tokens=5, output_tokens=3) - - with patch('ldai_langchain.langgraph_agent_graph_runner.create_langchain_model', - return_value=_mock_model(fake_response)): - runner = LangGraphAgentGraphRunner(graph, {}) - await runner.run('hello') + tracker = graph.get_tracker() + + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + node_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=node_run_id, name='root-agent') + + llm_result = LLMResult( + generations=[[ChatGeneration( + message=_AIMsg(content='OK.', usage_metadata={'total_tokens': 8, 'input_tokens': 5, 'output_tokens': 3}), + text='OK.', + )]], + llm_output={}, + ) + handler.on_llm_end(llm_result, run_id=uuid4(), parent_run_id=node_run_id) + handler.flush(graph, tracker) ev = _events(mock_ld_client) token_data = ev['$ld:ai:tokens:total'][0][0] @@ -338,7 +413,7 @@ async def test_tracks_failure_and_latency_on_model_error(): graph = _make_graph(mock_ld_client) error_model = MagicMock() - error_model.invoke.side_effect = RuntimeError('model error') + error_model.ainvoke = AsyncMock(side_effect=RuntimeError('model error')) 
error_model.bind_tools.return_value = error_model with patch('ldai_langchain.langgraph_agent_graph_runner.create_langchain_model', @@ -357,13 +432,18 @@ async def test_tracks_failure_and_latency_on_model_error(): @pytest.mark.asyncio async def test_multi_node_tracks_per_node_tokens_and_path(): """Each node emits its own token events; path and graph total cover both nodes.""" + from uuid import uuid4 + from langchain_core.messages import AIMessage as _AIMsg + from langchain_core.outputs import LLMResult, ChatGeneration + from ldai_langchain.langgraph_callback_handler import LDMetricsCallbackHandler + mock_ld_client = MagicMock() graph = _make_two_node_graph(mock_ld_client) root_response = _make_fake_response('Root done.', input_tokens=10, output_tokens=5) child_response = _make_fake_response('Child done.', input_tokens=3, output_tokens=2) - def model_factory(node_config): + def model_factory(node_config, **kwargs): if node_config.key == 'root-agent': return _mock_model(root_response) return _mock_model(child_response) @@ -375,18 +455,285 @@ def model_factory(node_config): assert result.metrics.success is True - ev = _events(mock_ld_client) + # Simulate per-node token events via callback handler (mock models don't fire callbacks) + mock_ld_client2 = MagicMock() + graph2 = _make_two_node_graph(mock_ld_client2) + tracker2 = graph2.get_tracker() + + handler = LDMetricsCallbackHandler({'root-agent', 'child-agent'}, {}) + + root_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=root_run_id, name='root-agent') + root_llm_result = LLMResult( + generations=[[ChatGeneration( + message=_AIMsg(content='Root done.', usage_metadata={'total_tokens': 15, 'input_tokens': 10, 'output_tokens': 5}), + text='Root done.', + )]], + llm_output={}, + ) + handler.on_llm_end(root_llm_result, run_id=uuid4(), parent_run_id=root_run_id) + + child_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=child_run_id, name='child-agent') + child_llm_result = LLMResult( + generations=[[ChatGeneration( + message=_AIMsg(content='Child done.', usage_metadata={'total_tokens': 5, 'input_tokens': 3, 'output_tokens': 2}), + text='Child done.', + )]], + llm_output={}, + ) + handler.on_llm_end(child_llm_result, run_id=uuid4(), parent_run_id=child_run_id) + + handler.flush(graph2, tracker2) + + ev2 = _events(mock_ld_client2) # Per-node token events identified by configKey - root_tokens = [(d, v) for d, v in ev.get('$ld:ai:tokens:total', []) if d.get('configKey') == 'root-agent'] - child_tokens = [(d, v) for d, v in ev.get('$ld:ai:tokens:total', []) if d.get('configKey') == 'child-agent'] + root_tokens = [(d, v) for d, v in ev2.get('$ld:ai:tokens:total', []) if d.get('configKey') == 'root-agent'] + child_tokens = [(d, v) for d, v in ev2.get('$ld:ai:tokens:total', []) if d.get('configKey') == 'child-agent'] assert root_tokens[0][1] == 15 assert child_tokens[0][1] == 5 - # Graph-level total accumulates both nodes (10+3 in, 5+2 out) + # Graph-level total from the real runner run + ev = _events(mock_ld_client) assert ev['$ld:ai:graph:total_tokens'][0][1] == 20 - # Execution path includes both node keys + # Execution path includes both node keys (from real run) path_data = ev['$ld:ai:graph:path'][0][0] assert 'root-agent' in path_data['path'] assert 'child-agent' in path_data['path'] + + +def _make_multi_child_graph(mock_ld_client: MagicMock) -> 'AgentGraphDefinition': + """Build a 3-node graph: orchestrator → agent-a, orchestrator → agent-b.""" + context = MagicMock() + + def _node_tracker(key: str) -> LDAIConfigTracker: + return 
LDAIConfigTracker( + ld_client=mock_ld_client, + variation_key='test-variation', + config_key=key, + version=1, + model_name='gpt-4', + provider_name='openai', + context=context, + ) + + graph_tracker = AIGraphTracker( + ld_client=mock_ld_client, + variation_key='test-variation', + graph_key='multi-child-graph', + version=1, + context=context, + ) + + configs = { + 'orchestrator': AIAgentConfig( + key='orchestrator', + enabled=True, + model=ModelConfig(name='gpt-4', parameters={}), + provider=ProviderConfig(name='openai'), + instructions='Route to the appropriate specialist agent.', + tracker=_node_tracker('orchestrator'), + ), + 'agent-a': AIAgentConfig( + key='agent-a', + enabled=True, + model=ModelConfig(name='gpt-4', parameters={}), + provider=ProviderConfig(name='openai'), + instructions='You handle topic A.', + tracker=_node_tracker('agent-a'), + ), + 'agent-b': AIAgentConfig( + key='agent-b', + enabled=True, + model=ModelConfig(name='gpt-4', parameters={}), + provider=ProviderConfig(name='openai'), + instructions='You handle topic B.', + tracker=_node_tracker('agent-b'), + ), + } + + edges = [ + Edge(key='orch-to-a', source_config='orchestrator', target_config='agent-a'), + Edge(key='orch-to-b', source_config='orchestrator', target_config='agent-b'), + ] + graph_config = AIAgentGraphConfig( + key='multi-child-graph', + root_config_key='orchestrator', + edges=edges, + enabled=True, + ) + nodes = AgentGraphDefinition.build_nodes(graph_config, configs) + return AgentGraphDefinition( + agent_graph=graph_config, + nodes=nodes, + context=context, + enabled=True, + tracker=graph_tracker, + ) + + +@pytest.mark.asyncio +async def test_multi_child_routes_via_handoff_not_fan_out(): + """Orchestrator with two children routes to exactly one child via handoff tool, + not a fan-out that invokes both children.""" + from langchain_core.messages import AIMessage + + mock_ld_client = MagicMock() + graph = _make_multi_child_graph(mock_ld_client) + + # Orchestrator calls transfer_to_agent_a (handoff tool name derived from child key) + orchestrator_response = AIMessage( + content='', + tool_calls=[{ + 'name': 'transfer_to_agent_a', + 'args': {}, + 'id': 'call_handoff_1', + 'type': 'tool_call', + }], + ) + agent_a_response = _make_fake_response('Agent A handled it.') + agent_b_model = _mock_model(_make_fake_response('Agent B handled it.')) + + def model_factory(node_config, **kwargs): + if node_config.key == 'orchestrator': + return _mock_model(orchestrator_response) + if node_config.key == 'agent-a': + return _mock_model(agent_a_response) + return agent_b_model + + with patch('ldai_langchain.langgraph_agent_graph_runner.create_langchain_model', + side_effect=model_factory): + runner = LangGraphAgentGraphRunner(graph, {}) + result = await runner.run('hello') + + assert result.metrics.success is True + assert 'Agent A' in result.output + # Agent B's model must never have been invoked — no fan-out + agent_b_model.ainvoke.assert_not_called() + + +def _make_multi_child_graph_with_tools(mock_ld_client: MagicMock, tool_names: list) -> 'AgentGraphDefinition': + """Build a 3-node graph where the orchestrator also has functional tools.""" + context = MagicMock() + + def _node_tracker(key: str) -> LDAIConfigTracker: + return LDAIConfigTracker( + ld_client=mock_ld_client, + variation_key='test-variation', + config_key=key, + version=1, + model_name='gpt-4', + provider_name='openai', + context=context, + ) + + graph_tracker = AIGraphTracker( + ld_client=mock_ld_client, + variation_key='test-variation', + 
graph_key='multi-child-tools-graph', + version=1, + context=context, + ) + + tool_defs = [{'name': n, 'type': 'function', 'description': '', 'parameters': {}} for n in tool_names] + configs = { + 'orchestrator': AIAgentConfig( + key='orchestrator', + enabled=True, + model=ModelConfig(name='gpt-4', parameters={'tools': tool_defs}), + provider=ProviderConfig(name='openai'), + instructions='Route to a specialist after gathering info.', + tracker=_node_tracker('orchestrator'), + ), + 'agent-a': AIAgentConfig( + key='agent-a', + enabled=True, + model=ModelConfig(name='gpt-4', parameters={}), + provider=ProviderConfig(name='openai'), + instructions='You handle topic A.', + tracker=_node_tracker('agent-a'), + ), + 'agent-b': AIAgentConfig( + key='agent-b', + enabled=True, + model=ModelConfig(name='gpt-4', parameters={}), + provider=ProviderConfig(name='openai'), + instructions='You handle topic B.', + tracker=_node_tracker('agent-b'), + ), + } + + edges = [ + Edge(key='orch-to-a', source_config='orchestrator', target_config='agent-a'), + Edge(key='orch-to-b', source_config='orchestrator', target_config='agent-b'), + ] + graph_config = AIAgentGraphConfig( + key='multi-child-tools-graph', + root_config_key='orchestrator', + edges=edges, + enabled=True, + ) + nodes = AgentGraphDefinition.build_nodes(graph_config, configs) + return AgentGraphDefinition( + agent_graph=graph_config, + nodes=nodes, + context=context, + enabled=True, + tracker=graph_tracker, + ) + + +@pytest.mark.asyncio +async def test_functional_tool_loops_back_when_handoff_tools_present(): + """When a node has both functional tools and handoff tools, calling a functional + tool must loop back to the node so the LLM sees the result — not silently terminate.""" + from langchain_core.messages import AIMessage + + mock_ld_client = MagicMock() + graph = _make_multi_child_graph_with_tools(mock_ld_client, tool_names=['search']) + + # Step 1: orchestrator calls functional tool 'search' + tool_call_response = AIMessage( + content='', + tool_calls=[{'name': 'search', 'args': {'query': 'topic A'}, 'id': 'call_search_1', 'type': 'tool_call'}], + ) + # Step 2: after seeing tool result, orchestrator hands off to agent-a + handoff_response = AIMessage( + content='', + tool_calls=[{'name': 'transfer_to_agent_a', 'args': {}, 'id': 'call_handoff_1', 'type': 'tool_call'}], + ) + agent_a_response = _make_fake_response('Agent A handled it.') + + orchestrator_model = MagicMock() + orchestrator_model.ainvoke = AsyncMock(side_effect=[tool_call_response, handoff_response]) + orchestrator_model.bind_tools.return_value = orchestrator_model + + agent_a_model = _mock_model(agent_a_response) + agent_b_model = _mock_model(_make_fake_response('Agent B handled it.')) + + def search(query: str = '') -> str: + """Search for information.""" + return f'results for {query}' + + tool_registry = {'search': search} + + def model_factory(node_config, **kwargs): + if node_config.key == 'orchestrator': + return orchestrator_model + if node_config.key == 'agent-a': + return agent_a_model + return agent_b_model + + with patch('ldai_langchain.langgraph_agent_graph_runner.create_langchain_model', + side_effect=model_factory): + runner = LangGraphAgentGraphRunner(graph, tool_registry) + result = await runner.run('Find info and route to the right agent.') + + assert result.metrics.success is True + assert 'Agent A' in result.output + # Orchestrator must have been called twice: once before tool result, once after + assert orchestrator_model.ainvoke.call_count == 2 + # Agent B must 
never have been invoked + agent_b_model.ainvoke.assert_not_called()
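
For context on what these tests exercise, below is a minimal usage sketch of the runner outside the mocked harness. This is a sketch under assumptions: `lookup_order`, the prompt string, and the way `graph` is obtained are made up for illustration; it only assumes (as the tests above do) that `graph` is an AgentGraphDefinition already resolved from LaunchDarkly, that registry keys match the tool names in the AI config, and that async callables are accepted in the registry. Handoff tools (transfer_to_<child_key>) are synthesized by the runner from the graph's edges, so they do not belong in the registry.

    import asyncio

    from ldai_langchain.langgraph_agent_graph_runner import LangGraphAgentGraphRunner


    async def lookup_order(order_id: str = '') -> str:
        """Illustrative async function tool; the registry key must match the tool name in the AI config."""
        return f'order {order_id}: shipped'


    async def main(graph):
        # `graph` is an AgentGraphDefinition resolved elsewhere; how it is obtained is out of scope here.
        # Only function tools go in the registry; handoff tools are created by the runner itself.
        runner = LangGraphAgentGraphRunner(graph, {'lookup_order': lookup_order})
        result = await runner.run('Where is order 1234?')
        if result.metrics.success:
            print(result.output)

    # asyncio.run(main(graph))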